diff --git a/chain/chain-primitives/src/error.rs b/chain/chain-primitives/src/error.rs index f394a31addf..866eb237794 100644 --- a/chain/chain-primitives/src/error.rs +++ b/chain/chain-primitives/src/error.rs @@ -199,7 +199,7 @@ pub enum Error { InvalidBlockMerkleRoot, /// Invalid split shard ids. #[error("Invalid Split Shard Ids when resharding. shard_id: {0}, parent_shard_id: {1}")] - InvalidSplitShardsIds(u64, u64), + InvalidSplitShardsIds(ShardId, ShardId), /// Someone is not a validator. Usually happens in signature verification #[error("Not A Validator: {0}")] NotAValidator(String), diff --git a/chain/chain/src/blocks_delay_tracker.rs b/chain/chain/src/blocks_delay_tracker.rs index 32bd73b236d..fa0803718fa 100644 --- a/chain/chain/src/blocks_delay_tracker.rs +++ b/chain/chain/src/blocks_delay_tracker.rs @@ -2,6 +2,7 @@ use near_async::time::{Clock, Instant, Utc}; use near_epoch_manager::EpochManagerAdapter; use near_primitives::block::{Block, Tip}; use near_primitives::hash::CryptoHash; +use near_primitives::shard_layout::ShardLayout; use near_primitives::sharding::{ChunkHash, ShardChunkHeader}; use near_primitives::types::{BlockHeight, ShardId}; use near_primitives::views::{ @@ -289,7 +290,12 @@ impl BlocksDelayTracker { } } - pub fn finish_block_processing(&mut self, block_hash: &CryptoHash, new_head: Option) { + pub fn finish_block_processing( + &mut self, + shard_layout: &ShardLayout, + block_hash: &CryptoHash, + new_head: Option, + ) { if let Some(processed_block) = self.blocks.get_mut(&block_hash) { processed_block.processed_timestamp = Some(self.clock.now()); } @@ -297,10 +303,11 @@ impl BlocksDelayTracker { if let Some(processed_block) = self.blocks.get(&block_hash) { let chunks = processed_block.chunks.clone(); self.update_block_metrics(processed_block); - for (shard_id, chunk_hash) in chunks.into_iter().enumerate() { + for (shard_index, chunk_hash) in chunks.into_iter().enumerate() { if let Some(chunk_hash) = chunk_hash { if let Some(processed_chunk) = self.chunks.get(&chunk_hash) { - self.update_chunk_metrics(processed_chunk, shard_id as ShardId); + let shard_id = shard_layout.get_shard_id(shard_index); + self.update_chunk_metrics(processed_chunk, shard_id); } } } diff --git a/chain/chain/src/chain.rs b/chain/chain/src/chain.rs index e46ff494642..5b790cf7aa0 100644 --- a/chain/chain/src/chain.rs +++ b/chain/chain/src/chain.rs @@ -612,23 +612,26 @@ impl Chain { pub fn genesis_chunk_extra( &self, + shard_layout: &ShardLayout, shard_id: ShardId, genesis_protocol_version: ProtocolVersion, congestion_info: Option, ) -> Result { - let shard_index = shard_id as usize; + let shard_index = shard_layout.get_shard_index(shard_id); let state_root = *get_genesis_state_roots(self.chain_store.store())? .ok_or_else(|| Error::Other("genesis state roots do not exist in the db".to_owned()))? .get(shard_index) .ok_or_else(|| { - Error::Other(format!("genesis state root does not exist for shard {shard_index}")) + Error::Other(format!("genesis state root does not exist for shard id {shard_id} shard index {shard_index}")) })?; let gas_limit = self .genesis .chunks() .get(shard_index) .ok_or_else(|| { - Error::Other(format!("genesis chunk does not exist for shard {shard_index}")) + Error::Other(format!( + "genesis chunk does not exist for shard {shard_id} shard index {shard_index}" + )) })? 
.gas_limit(); Ok(Self::create_genesis_chunk_extra( @@ -780,13 +783,16 @@ impl Chain { Ok(None) } else { debug!(target: "chain", "Downloading state for {:?}, I'm {:?}", shards_to_state_sync, me); + let epoch_id = block.header().epoch_id(); + let shard_layout = self.epoch_manager.get_shard_layout(&epoch_id)?; let state_sync_info = StateSyncInfo { epoch_tail_hash: *block.header().hash(), shards: shards_to_state_sync .iter() .map(|shard_id| { - let chunk = &prev_block.chunks()[*shard_id as usize]; + let shard_index = shard_layout.get_shard_index(*shard_id); + let chunk = &prev_block.chunks()[shard_index]; ShardInfo(*shard_id, chunk.chunk_hash()) }) .collect(), @@ -812,14 +818,18 @@ impl Chain { genesis_block: &Block, block: &Block, ) -> Result<(), Error> { - for (shard_id, chunk_header) in block.chunks().iter().enumerate() { + let epoch_id = block.header().epoch_id(); + let shard_layout = epoch_manager.get_shard_layout(&epoch_id)?; + + for (shard_index, chunk_header) in block.chunks().iter().enumerate() { + let shard_id = shard_layout.get_shard_id(shard_index); if chunk_header.height_created() == genesis_block.header().height() { // Special case: genesis chunks can be in non-genesis blocks and don't have a signature // We must verify that content matches and signature is empty. // TODO: this code will not work when genesis block has different number of chunks as the current block // https://github.com/near/nearcore/issues/4908 let chunks = genesis_block.chunks(); - let genesis_chunk = chunks.get(shard_id); + let genesis_chunk = chunks.get(shard_index); let genesis_chunk = genesis_chunk.ok_or_else(|| { Error::InvalidChunk(format!( "genesis chunk not found for shard {}, genesis block has {} chunks", @@ -841,7 +851,7 @@ impl Chain { ))); } } else if chunk_header.height_created() == block.header().height() { - if chunk_header.shard_id() != shard_id as ShardId { + if chunk_header.shard_id() != shard_id { return Err(Error::InvalidShardId(chunk_header.shard_id())); } if !epoch_manager.verify_chunk_header_signature( @@ -1301,15 +1311,19 @@ impl Chain { } let mut missing = vec![]; let block_height = block.header().height(); - for (shard_id, chunk_header) in block.chunks().iter().enumerate() { + + let epoch_id = block.header().epoch_id(); + let shard_layout = self.epoch_manager.get_shard_layout(&epoch_id)?; + + for (shard_index, chunk_header) in block.chunks().iter().enumerate() { + let shard_id = shard_layout.get_shard_id(shard_index); // Check if any chunks are invalid in this block. if let Some(encoded_chunk) = self.chain_store.is_invalid_chunk(&chunk_header.chunk_hash())? 
{ let merkle_paths = Block::compute_chunk_headers_root(block.chunks().iter()).1; - let merkle_proof = merkle_paths - .get(shard_id) - .ok_or_else(|| Error::InvalidShardId(shard_id as ShardId))?; + let merkle_proof = + merkle_paths.get(shard_index).ok_or_else(|| Error::InvalidShardId(shard_id))?; let chunk_proof = ChunkProofs { block_header: borsh::to_vec(&block.header()).expect("Failed to serialize"), merkle_proof: merkle_proof.clone(), @@ -1319,7 +1333,6 @@ impl Chain { }; return Err(Error::InvalidChunkProofs(Box::new(chunk_proof))); } - let shard_id = shard_id as ShardId; if chunk_header.is_new_chunk(block_height) { let chunk_hash = chunk_header.chunk_hash(); @@ -1897,15 +1910,20 @@ impl Chain { height = block.header().height()) .entered(); + let epoch_id = block.header().epoch_id(); + let shard_layout = self.epoch_manager.get_shard_layout(&epoch_id)?; + let prev_head = self.chain_store.head()?; let is_caught_up = block_preprocess_info.is_caught_up; let provenance = block_preprocess_info.provenance.clone(); let block_start_processing_time = block_preprocess_info.block_start_processing_time; // TODO(#8055): this zip relies on the ordering of the apply_results. + // TODO(wacban): do the above todo for (shard_id, apply_result) in apply_results.iter() { + let shard_index = shard_layout.get_shard_index(*shard_id); if let Err(err) = apply_result { if err.is_bad_data() { - let chunk = block.chunks()[*shard_id as usize].clone(); + let chunk = block.chunks()[shard_index].clone(); block_processing_artifacts.invalid_chunks.push(chunk); } } @@ -1950,7 +1968,7 @@ impl Chain { // during catchup of this block. care_about_shard }; - tracing::debug!(target: "chain", shard_id, need_storage_update, "Updating storage"); + tracing::debug!(target: "chain", ?shard_id, need_storage_update, "Updating storage"); if need_storage_update { // TODO(#12019): consider adding to catchup flow. @@ -2004,7 +2022,12 @@ impl Chain { .as_seconds_f64() .max(0.0), ); - self.blocks_delay_tracker.finish_block_processing(&block_hash, new_head.clone()); + let shard_layout = self.epoch_manager.get_shard_layout(epoch_id)?; + self.blocks_delay_tracker.finish_block_processing( + &shard_layout, + &block_hash, + new_head.clone(), + ); timer.observe_duration(); let _timer = CryptoHashTimer::new_with_start( @@ -2449,12 +2472,16 @@ impl Chain { if sync_block_epoch_id == sync_prev_block.header().epoch_id() { return Err(sync_hash_not_first_hash(sync_hash)); } + + let epoch_id = sync_prev_block.header().epoch_id(); + let shard_layout = self.epoch_manager.get_shard_layout(&epoch_id)?; + let shard_index = shard_layout.get_shard_index(shard_id); + // Chunk header here is the same chunk header as at the `current` height. let sync_prev_hash = sync_prev_block.hash(); let chunks = sync_prev_block.chunks(); - let chunk_header = chunks - .get(shard_id as usize) - .ok_or_else(|| Error::InvalidShardId(shard_id as ShardId))?; + let chunk_header = + chunks.get(shard_index).ok_or_else(|| Error::InvalidShardId(shard_id))?; let (chunk_headers_root, chunk_proofs) = merklize( &sync_prev_block .chunks() @@ -2467,10 +2494,8 @@ impl Chain { assert_eq!(&chunk_headers_root, sync_prev_block.header().chunk_headers_root()); let chunk = self.get_chunk_clone_from_header(chunk_header)?; - let chunk_proof = chunk_proofs - .get(shard_id as usize) - .ok_or_else(|| Error::InvalidShardId(shard_id as ShardId))? 
- .clone(); + let chunk_proof = + chunk_proofs.get(shard_index).ok_or_else(|| Error::InvalidShardId(shard_id))?.clone(); let block_header = self.get_block_header_on_chain_by_height(&sync_hash, chunk_header.height_included())?; @@ -2481,8 +2506,8 @@ impl Chain { Ok(prev_block) => { let prev_chunk_header = prev_block .chunks() - .get(shard_id as usize) - .ok_or_else(|| Error::InvalidShardId(shard_id as ShardId))? + .get(shard_index) + .ok_or_else(|| Error::InvalidShardId(shard_id))? .clone(); let (prev_chunk_headers_root, prev_chunk_proofs) = merklize( &prev_block @@ -2496,8 +2521,8 @@ impl Chain { assert_eq!(&prev_chunk_headers_root, prev_block.header().chunk_headers_root()); let prev_chunk_proof = prev_chunk_proofs - .get(shard_id as usize) - .ok_or_else(|| Error::InvalidShardId(shard_id as ShardId))? + .get(shard_index) + .ok_or_else(|| Error::InvalidShardId(shard_id))? .clone(); let prev_chunk_height_included = prev_chunk_header.height_included(); @@ -2548,18 +2573,18 @@ impl Chain { let ReceiptProof(receipts, shard_proof) = receipt_proof; let ShardProof { from_shard_id, to_shard_id: _, proof } = shard_proof; let receipts_hash = CryptoHash::hash_borsh(ReceiptList(shard_id, receipts)); - let from_shard_id = *from_shard_id as usize; + let from_shard_index = shard_layout.get_shard_index(*from_shard_id); - let root_proof = block.chunks()[from_shard_id].prev_outgoing_receipts_root(); + let root_proof = block.chunks()[from_shard_index].prev_outgoing_receipts_root(); root_proofs_cur - .push(RootProof(root_proof, block_receipts_proofs[from_shard_id].clone())); + .push(RootProof(root_proof, block_receipts_proofs[from_shard_index].clone())); // Make sure we send something reasonable. assert_eq!(block_header.prev_chunk_outgoing_receipts_root(), &block_receipts_root); assert!(verify_path(root_proof, proof, &receipts_hash)); assert!(verify_path( block_receipts_root, - &block_receipts_proofs[from_shard_id], + &block_receipts_proofs[from_shard_index], &root_proof, )); } @@ -2632,7 +2657,7 @@ impl Chain { let _span = tracing::debug_span!( target: "sync", "get_state_response_part", - shard_id, + ?shard_id, part_id, ?sync_hash) .entered(); @@ -2647,6 +2672,7 @@ impl Chain { .log_storage_error("block has already been checked for existence")?; let header = block.header(); let epoch_id = block.header().epoch_id(); + let shard_layout = self.epoch_manager.get_shard_layout(epoch_id)?; let shard_ids = self.epoch_manager.shard_ids(epoch_id)?; if !shard_ids.contains(&shard_id) { return Err(shard_id_out_of_bounds(shard_id)); @@ -2655,10 +2681,11 @@ impl Chain { if epoch_id == prev_block.header().epoch_id() { return Err(sync_hash_not_first_hash(sync_hash)); } + let shard_index = shard_layout.get_shard_index(shard_id); let state_root = prev_block .chunks() - .get(shard_id as usize) - .ok_or_else(|| Error::InvalidShardId(shard_id as ShardId))? + .get(shard_index) + .ok_or_else(|| Error::InvalidShardId(shard_id))? 
.prev_state_root(); let prev_hash = *prev_block.hash(); let prev_prev_hash = *prev_block.header().prev_hash(); @@ -3105,9 +3132,14 @@ impl Chain { ) -> Result<(), Error> { if !validate_transactions_order(chunk.transactions()) { let merkle_paths = Block::compute_chunk_headers_root(block.chunks().iter()).1; + let epoch_id = block.header().epoch_id(); + let shard_layout = self.epoch_manager.get_shard_layout(&epoch_id)?; + let shard_id = chunk.shard_id(); + let shard_index = shard_layout.get_shard_index(shard_id); + let chunk_proof = ChunkProofs { block_header: borsh::to_vec(&block.header()).expect("Failed to serialize"), - merkle_proof: merkle_paths[chunk.shard_id() as usize].clone(), + merkle_proof: merkle_paths[shard_index].clone(), chunk: MaybeEncodedShardChunk::Decoded(chunk.clone()).into(), }; return Err(Error::InvalidChunkProofs(Box::new(chunk_proof))); @@ -3423,12 +3455,14 @@ impl Chain { block: &Block, chunk_header: &ShardChunkHeader, ) -> Result { - let chunk_shard_id = chunk_header.shard_id(); + let epoch_id = block.header().epoch_id(); + let shard_layout = self.epoch_manager.get_shard_layout(&epoch_id)?; + let shard_id = chunk_header.shard_id(); + let shard_index = shard_layout.get_shard_index(shard_id); let prev_merkle_proofs = Block::compute_chunk_headers_root(prev_block.chunks().iter()).1; let merkle_proofs = Block::compute_chunk_headers_root(block.chunks().iter()).1; - let prev_chunk = self - .get_chunk_clone_from_header(&prev_block.chunks()[chunk_shard_id as usize].clone()) - .unwrap(); + let prev_chunk = + self.get_chunk_clone_from_header(&prev_block.chunks()[shard_index].clone()).unwrap(); // TODO (#6316): enable storage proof generation // let prev_chunk_header = &prev_block.chunks()[chunk_shard_id as usize]; @@ -3479,8 +3513,8 @@ impl Chain { Ok(ChunkState { prev_block_header: borsh::to_vec(&prev_block.header())?, block_header: borsh::to_vec(&block.header())?, - prev_merkle_proof: prev_merkle_proofs[chunk_shard_id as usize].clone(), - merkle_proof: merkle_proofs[chunk_shard_id as usize].clone(), + prev_merkle_proof: prev_merkle_proofs[shard_index].clone(), + merkle_proof: merkle_proofs[shard_index].clone(), prev_chunk, chunk_header: chunk_header.clone(), partial_state: PartialState::TrieValues(vec![]), @@ -3590,13 +3624,17 @@ impl Chain { let prev_chunk_headers = Chain::get_prev_chunk_headers(self.epoch_manager.as_ref(), prev_block)?; + let epoch_id = block.header().epoch_id(); + let shard_layout = self.epoch_manager.get_shard_layout(&epoch_id)?; + let mut maybe_jobs = vec![]; - for (shard_id, (chunk_header, prev_chunk_header)) in + for (shard_index, (chunk_header, prev_chunk_header)) in block.chunks().iter().zip(prev_chunk_headers.iter()).enumerate() { // XXX: This is a bit questionable -- sandbox state patching works // only for a single shard. This so far has been enough. 
let state_patch = state_patch.take(); + let shard_id = shard_layout.get_shard_id(shard_index); let storage_context = StorageContext { storage_data_source: StorageDataSource::Db, state_patch }; @@ -3606,7 +3644,7 @@ impl Chain { prev_block, chunk_header, prev_chunk_header, - shard_id as ShardId, + shard_id, mode, incoming_receipts, storage_context, @@ -3621,10 +3659,14 @@ impl Chain { Ok(None) => {} Err(err) => { if err.is_bad_data() { + let epoch_id = block.header().epoch_id(); + let shard_layout = self.epoch_manager.get_shard_layout(&epoch_id)?; + let shard_index = shard_layout.get_shard_index(shard_id); + let chunk_header = block .chunks() - .get(shard_id) - .ok_or_else(|| Error::InvalidShardId(shard_id as ShardId))? + .get(shard_index) + .ok_or_else(|| Error::InvalidShardId(shard_id))? .clone(); invalid_chunks.push(chunk_header); } @@ -3674,7 +3716,7 @@ impl Chain { prev_chunk_header: &ShardChunkHeader, shard_id: ShardId, mode: ApplyChunksMode, - incoming_receipts: &HashMap>, + incoming_receipts: &HashMap>, storage_context: StorageContext, ) -> Result, Error> { let _span = tracing::debug_span!(target: "chain", "get_update_shard_job").entered(); @@ -3713,7 +3755,7 @@ impl Chain { ?err, prev_block_hash=?prev_hash, block_hash=?block.header().hash(), - shard_id, + ?shard_id, prev_chunk_height_included, ?prev_chunk_extra, ?chunk_header, @@ -3888,6 +3930,7 @@ fn get_genesis_congestion_infos_impl( let genesis_prev_hash = CryptoHash::default(); let genesis_epoch_id = epoch_manager.get_epoch_id_from_prev_block(&genesis_prev_hash)?; let genesis_protocol_version = epoch_manager.get_epoch_protocol_version(&genesis_epoch_id)?; + let genesis_shard_layout = epoch_manager.get_shard_layout(&genesis_epoch_id)?; // If congestion control is not enabled at the genesis block, we return None (congestion info) for each shard. 
if !ProtocolFeature::CongestionControl.enabled(genesis_protocol_version) { return Ok(std::iter::repeat(None).take(state_roots.len()).collect()); @@ -3900,8 +3943,8 @@ fn get_genesis_congestion_infos_impl( } let mut new_infos = vec![]; - for (shard_id, &state_root) in state_roots.iter().enumerate() { - let shard_id = shard_id as ShardId; + for (shard_index, &state_root) in state_roots.iter().enumerate() { + let shard_id = genesis_shard_layout.get_shard_id(shard_index); let congestion_info = get_genesis_congestion_info( runtime, genesis_protocol_version, @@ -4323,9 +4366,9 @@ impl Chain { let block = self.get_block(&block_hash)?; let chunks = block.chunks(); for &shard_id in shard_ids.iter() { - let chunk_header = &chunks - .get(shard_id as usize) - .ok_or_else(|| Error::InvalidShardId(shard_id as ShardId))?; + let shard_index = shard_layout.get_shard_index(shard_id); + let chunk_header = + &chunks.get(shard_index).ok_or_else(|| Error::InvalidShardId(shard_id))?; if chunk_header.height_included() == block.header().height() { return Ok(Some((block_hash, shard_id))); } @@ -4475,11 +4518,12 @@ impl Chain { ) -> Result, Error> { let epoch_id = epoch_manager.get_epoch_id_from_prev_block(prev_block.hash())?; let shard_ids = epoch_manager.shard_ids(&epoch_id)?; + let prev_shard_ids = epoch_manager.get_prev_shard_ids(prev_block.hash(), shard_ids)?; - let chunks = prev_block.chunks(); + let prev_chunks = prev_block.chunks(); Ok(prev_shard_ids .into_iter() - .map(|shard_id| chunks.get(shard_id as usize).unwrap().clone()) + .map(|(_, shard_index)| prev_chunks.get(shard_index).unwrap().clone()) .collect()) } @@ -4488,11 +4532,12 @@ impl Chain { prev_block: &Block, shard_id: ShardId, ) -> Result { - let prev_shard_id = epoch_manager.get_prev_shard_id(prev_block.hash(), shard_id)?; + let (prev_shard_id, prev_shard_index) = + epoch_manager.get_prev_shard_id(prev_block.hash(), shard_id)?; Ok(prev_block .chunks() - .get(prev_shard_id as usize) - .ok_or(Error::InvalidShardId(shard_id))? + .get(prev_shard_index) + .ok_or(Error::InvalidShardId(prev_shard_id))? 
.clone()) } diff --git a/chain/chain/src/chain_update.rs b/chain/chain/src/chain_update.rs index 9d0a089d767..d1fd7bbf170 100644 --- a/chain/chain/src/chain_update.rs +++ b/chain/chain/src/chain_update.rs @@ -209,7 +209,7 @@ impl<'a> ChainUpdate<'a> { let prev_hash = block.header().prev_hash(); let results = apply_chunks_results.into_iter().map(|(shard_id, x)| { if let Err(err) = &x { - warn!(target: "chain", shard_id, hash = %block.hash(), %err, "Error in applying chunk for block"); + warn!(target: "chain", ?shard_id, hash = %block.hash(), %err, "Error in applying chunk for block"); } x }).collect::, Error>>()?; @@ -465,7 +465,7 @@ impl<'a> ChainUpdate<'a> { shard_state_header: ShardStateSyncResponseHeader, ) -> Result { let _span = - tracing::debug_span!(target: "sync", "chain_update_set_state_finalize", shard_id, ?sync_hash).entered(); + tracing::debug_span!(target: "sync", "chain_update_set_state_finalize", ?shard_id, ?sync_hash).entered(); let (chunk, incoming_receipts_proofs) = match shard_state_header { ShardStateSyncResponseHeader::V1(shard_state_header) => ( ShardChunk::V1(shard_state_header.chunk), @@ -596,7 +596,7 @@ impl<'a> ChainUpdate<'a> { sync_hash: CryptoHash, ) -> Result { let _span = - tracing::debug_span!(target: "sync", "set_state_finalize_on_height", height, shard_id) + tracing::debug_span!(target: "sync", "set_state_finalize_on_height", height, ?shard_id) .entered(); let block_header_result = self.chain_store_update.get_block_header_on_chain_by_height(&sync_hash, height); diff --git a/chain/chain/src/garbage_collection.rs b/chain/chain/src/garbage_collection.rs index 876617fb357..777d83edaef 100644 --- a/chain/chain/src/garbage_collection.rs +++ b/chain/chain/src/garbage_collection.rs @@ -585,9 +585,11 @@ impl<'a> ChainStoreUpdate<'a> { let block = self.get_block(&block_hash).expect("block data is not expected to be already cleaned"); let height = block.header().height(); + let epoch_id = block.header().epoch_id(); + let shard_layout = epoch_manager.get_shard_layout(epoch_id).expect("epoch id must exist"); // 2. Delete shard_id-indexed data (Receipts, State Headers and Parts, etc.) - for shard_id in 0..block.header().chunk_mask().len() as ShardId { + for shard_id in shard_layout.shard_ids() { let block_shard_id = get_block_shard_id(&block_hash, shard_id); self.gc_outgoing_receipts(&block_hash, shard_id); self.gc_col(DBCol::IncomingReceipts, &block_shard_id); @@ -678,11 +680,11 @@ impl<'a> ChainStoreUpdate<'a> { self.get_block(&block_hash).expect("block data is not expected to be already cleaned"); let epoch_id = block.header().epoch_id(); - let head_height = block.header().height(); + let shard_layout = epoch_manager.get_shard_layout(epoch_id).expect("epoch id must exist"); // 1. Delete shard_id-indexed data (TrieChanges, Receipts, ChunkExtra, State Headers and Parts, FlatStorage data) - for shard_id in 0..block.header().chunk_mask().len() as ShardId { + for shard_id in shard_layout.shard_ids() { let shard_uid = epoch_manager.shard_id_to_uid(shard_id, epoch_id).unwrap(); let block_shard_id = get_block_shard_uid(&block_hash, &shard_uid); @@ -833,6 +835,8 @@ impl<'a> ChainStoreUpdate<'a> { for chunk_header in block.chunks().iter().filter(|h| h.height_included() == block.header().height()) { + // It is ok to use the shard id from the header because it is a new + // chunk. An old chunk may have the shard id from the parent shard. 
let shard_id = chunk_header.shard_id(); let outcome_ids = self.chain_store().get_outcomes_by_block_hash_and_shard_id(block_hash, shard_id)?; diff --git a/chain/chain/src/migrations.rs b/chain/chain/src/migrations.rs index f778ac4a245..afdbf7f3f43 100644 --- a/chain/chain/src/migrations.rs +++ b/chain/chain/src/migrations.rs @@ -29,7 +29,7 @@ pub fn check_if_block_is_first_with_chunk_of_version( if is_first_epoch_with_protocol_version(epoch_manager, prev_block_hash)? { // Compare only epochs because we already know that current epoch is the first one with current protocol version // convert shard id to shard id of previous epoch because number of shards may change - let shard_id = epoch_manager.get_prev_shard_ids(prev_block_hash, vec![shard_id])?[0]; + let (shard_id, _) = epoch_manager.get_prev_shard_ids(prev_block_hash, vec![shard_id])?[0]; let prev_epoch_id = chain_store.get_epoch_id_of_last_block_with_chunk( epoch_manager, prev_block_hash, diff --git a/chain/chain/src/runtime/migrations.rs b/chain/chain/src/runtime/migrations.rs index c2855ac3a59..827fc3f934f 100644 --- a/chain/chain/src/runtime/migrations.rs +++ b/chain/chain/src/runtime/migrations.rs @@ -33,6 +33,7 @@ mod tests { use near_mainnet_res::mainnet_restored_receipts; use near_mainnet_res::mainnet_storage_usage_delta; use near_primitives::hash::hash; + use near_primitives::types::new_shard_id_tmp; #[test] fn test_migration_data() { @@ -55,7 +56,10 @@ mod tests { "48ZMJukN7RzvyJSW9MJ5XmyQkQFfjy2ZxPRaDMMHqUcT" ); let mainnet_migration_data = load_migration_data(near_primitives::chains::MAINNET); - assert_eq!(mainnet_migration_data.restored_receipts.get(&0u64).unwrap().len(), 383); + assert_eq!( + mainnet_migration_data.restored_receipts.get(&new_shard_id_tmp(0)).unwrap().len(), + 383 + ); let testnet_migration_data = load_migration_data(near_primitives::chains::TESTNET); assert!(testnet_migration_data.restored_receipts.is_empty()); } diff --git a/chain/chain/src/runtime/mod.rs b/chain/chain/src/runtime/mod.rs index 3713aa3e8bf..c901466c41a 100644 --- a/chain/chain/src/runtime/mod.rs +++ b/chain/chain/src/runtime/mod.rs @@ -29,8 +29,8 @@ use near_primitives::state_part::PartId; use near_primitives::transaction::SignedTransaction; use near_primitives::trie_key::TrieKey; use near_primitives::types::{ - AccountId, Balance, BlockHeight, EpochHeight, EpochId, EpochInfoProvider, Gas, MerkleHash, - ShardId, StateChangeCause, StateRoot, StateRootNode, + new_shard_id_tmp, AccountId, Balance, BlockHeight, EpochHeight, EpochId, EpochInfoProvider, + Gas, MerkleHash, ShardId, StateChangeCause, StateRoot, StateRootNode, }; use near_primitives::version::{ProtocolFeature, ProtocolVersion}; use near_primitives::views::{ @@ -223,7 +223,7 @@ impl NightshadeRuntime { epoch_manager.get_epoch_id_from_prev_block(prev_hash).map_err(Error::from)?; let shard_version = epoch_manager.get_shard_layout(&epoch_id).map_err(Error::from)?.version(); - Ok(ShardUId { version: shard_version, shard_id: shard_id as u32 }) + Ok(ShardUId { version: shard_version, shard_id: new_shard_id_tmp(shard_id) as u32 }) } fn get_shard_uid_from_epoch_id( @@ -234,7 +234,7 @@ impl NightshadeRuntime { let epoch_manager = self.epoch_manager.read(); let shard_version = epoch_manager.get_shard_layout(epoch_id).map_err(Error::from)?.version(); - Ok(ShardUId { version: shard_version, shard_id: shard_id as u32 }) + Ok(ShardUId { version: shard_version, shard_id: new_shard_id_tmp(shard_id) as u32 }) } fn account_id_to_shard_uid( @@ -566,7 +566,7 @@ impl NightshadeRuntime { target: 
"runtime", "obtain_state_part", part_id = part_id.idx, - shard_id, + ?shard_id, %prev_hash, num_parts = part_id.total) .entered(); @@ -953,7 +953,7 @@ impl RuntimeAdapter for NightshadeRuntime { } } - #[instrument(target = "runtime", level = "info", skip_all, fields(shard_id = chunk.shard_id))] + #[instrument(target = "runtime", level = "info", skip_all, fields(shard_id = ?chunk.shard_id))] fn apply_chunk( &self, storage_config: RuntimeStorageConfig, @@ -1191,7 +1191,7 @@ impl RuntimeAdapter for NightshadeRuntime { target: "runtime", "obtain_state_part", part_id = part_id.idx, - shard_id, + ?shard_id, %prev_hash, ?state_root, num_parts = part_id.total) @@ -1369,7 +1369,7 @@ fn chunk_tx_gas_limit( protocol_version: u32, runtime_config: &RuntimeConfig, prev_block: &PrepareTransactionsBlockContext, - shard_id: u64, + shard_id: ShardId, gas_limit: u64, ) -> u64 { if !ProtocolFeature::CongestionControl.enabled(protocol_version) { diff --git a/chain/chain/src/runtime/tests.rs b/chain/chain/src/runtime/tests.rs index 494778e2b7d..10b2ed7f971 100644 --- a/chain/chain/src/runtime/tests.rs +++ b/chain/chain/src/runtime/tests.rs @@ -214,12 +214,13 @@ impl TestEnv { // TODO(congestion_control): pass down prev block info and read congestion info from there // For now, just use default. let prev_block_hash = self.head.last_block_hash; - let state_root = self.state_roots[shard_id as usize]; + let epoch_id = self.epoch_manager.get_epoch_id_from_prev_block(&prev_block_hash).unwrap(); + let shard_layout = self.epoch_manager.get_shard_layout(&epoch_id).unwrap(); + let shard_index = shard_layout.get_shard_index(shard_id); + let state_root = self.state_roots[shard_index]; let gas_limit = u64::MAX; let height = self.head.height + 1; let block_timestamp = 0; - let epoch_id = - self.epoch_manager.get_epoch_id_from_prev_block(&prev_block_hash).unwrap_or_default(); let protocol_version = self.epoch_manager.get_epoch_protocol_version(&epoch_id).unwrap(); let gas_price = self.runtime.genesis_config.min_gas_price; let congestion_info = if !ProtocolFeature::CongestionControl.enabled(protocol_version) { @@ -316,19 +317,21 @@ impl TestEnv { ) { let new_hash = hash(&[(self.head.height + 1) as u8]); let shard_ids = self.epoch_manager.shard_ids(&self.head.epoch_id).unwrap(); + let shard_layout = self.epoch_manager.get_shard_layout(&self.head.epoch_id).unwrap(); assert_eq!(transactions.len(), shard_ids.len()); assert_eq!(chunk_mask.len(), shard_ids.len()); let mut all_proposals = vec![]; let mut all_receipts = vec![]; for shard_id in shard_ids { + let shard_index = shard_layout.get_shard_index(shard_id); let (state_root, proposals, receipts) = self.update_runtime( shard_id, new_hash, - &transactions[shard_id as usize], + &transactions[shard_index], self.last_receipts.get(&shard_id).map_or(&[], |v| v.as_slice()), challenges_result.clone(), ); - self.state_roots[shard_id as usize] = state_root; + self.state_roots[shard_index] = state_root; all_receipts.extend(receipts); all_proposals.append(&mut proposals.clone()); self.last_shard_proposals.insert(shard_id, proposals); @@ -391,9 +394,11 @@ impl TestEnv { &self.head.epoch_id, ) .unwrap(); + let shard_layout = self.epoch_manager.get_shard_layout(&self.head.epoch_id).unwrap(); + let shard_index = shard_layout.get_shard_index(shard_id); let shard_uid = self.epoch_manager.shard_id_to_uid(shard_id, &self.head.epoch_id).unwrap(); self.runtime - .view_account(&shard_uid, self.state_roots[shard_id as usize], account_id) + .view_account(&shard_uid, self.state_roots[shard_index], 
account_id) .unwrap() .into() } @@ -727,9 +732,12 @@ fn test_state_sync() { let block_hash = hash(&[env.head.height as u8]); let state_part = env .runtime - .obtain_state_part(0, &block_hash, &env.state_roots[0], PartId::new(0, 1)) + .obtain_state_part(new_shard_id_tmp(0), &block_hash, &env.state_roots[0], PartId::new(0, 1)) + .unwrap(); + let root_node = env + .runtime + .get_state_root_node(new_shard_id_tmp(0), &block_hash, &env.state_roots[0]) .unwrap(); - let root_node = env.runtime.get_state_root_node(0, &block_hash, &env.state_roots[0]).unwrap(); let mut new_env = TestEnv::new(vec![validators], 2, false); for i in 1..=2 { let prev_hash = hash(&[new_env.head.height as u8]); @@ -786,7 +794,13 @@ fn test_state_sync() { let epoch_id = &new_env.head.epoch_id; new_env .runtime - .apply_state_part(0, &env.state_roots[0], PartId::new(0, 1), &state_part, epoch_id) + .apply_state_part( + new_shard_id_tmp(0), + &env.state_roots[0], + PartId::new(0, 1), + &state_part, + epoch_id, + ) .unwrap(); new_env.state_roots[0] = env.state_roots[0]; for _ in 3..=5 { @@ -827,9 +841,9 @@ fn test_get_validator_info() { let height = env.head.height; let em = env.runtime.epoch_manager.read(); let bp = em.get_block_producer_info(&epoch_id, height).unwrap(); - let cp = em.get_chunk_producer_info(&epoch_id, height, 0).unwrap(); + let cp = em.get_chunk_producer_info(&epoch_id, height, new_shard_id_tmp(0)).unwrap(); let stateless_validators = - em.get_chunk_validator_assignments(&epoch_id, 0, height).ok(); + em.get_chunk_validator_assignments(&epoch_id, new_shard_id_tmp(0), height).ok(); if let Some(vs) = stateless_validators { if vs.contains(&validators[0]) { @@ -876,7 +890,7 @@ fn test_get_validator_info() { public_key: block_producers[0].public_key(), is_slashed: false, stake: TESTING_INIT_STAKE, - shards: vec![0], + shards: vec![new_shard_id_tmp(0)], num_produced_blocks: expected_blocks[0], num_expected_blocks: expected_blocks[0], num_produced_chunks: expected_chunks[0], @@ -893,7 +907,7 @@ fn test_get_validator_info() { public_key: block_producers[1].public_key(), is_slashed: false, stake: TESTING_INIT_STAKE, - shards: vec![0], + shards: vec![new_shard_id_tmp(0)], num_produced_blocks: expected_blocks[1], num_expected_blocks: expected_blocks[1], num_produced_chunks: expected_chunks[1], @@ -911,13 +925,13 @@ fn test_get_validator_info() { account_id: "test1".parse().unwrap(), public_key: block_producers[0].public_key(), stake: TESTING_INIT_STAKE, - shards: vec![0], + shards: vec![new_shard_id_tmp(0)], }, NextEpochValidatorInfo { account_id: "test2".parse().unwrap(), public_key: block_producers[1].public_key(), stake: TESTING_INIT_STAKE, - shards: vec![0], + shards: vec![new_shard_id_tmp(0)], }, ]; let response = env @@ -988,7 +1002,7 @@ fn test_get_validator_info() { account_id: "test2".parse().unwrap(), public_key: block_producers[1].public_key(), stake: TESTING_INIT_STAKE, - shards: vec![0], + shards: vec![new_shard_id_tmp(0)], }] ); assert!(response.current_proposals.is_empty()); @@ -1461,13 +1475,13 @@ fn test_flat_state_usage() { let env = TestEnv::new(vec![vec!["test1".parse().unwrap()]], 4, false); let trie = env .runtime - .get_trie_for_shard(0, &env.head.prev_block_hash, Trie::EMPTY_ROOT, true) + .get_trie_for_shard(new_shard_id_tmp(0), &env.head.prev_block_hash, Trie::EMPTY_ROOT, true) .unwrap(); assert!(trie.has_flat_storage_chunk_view()); let trie = env .runtime - .get_view_trie_for_shard(0, &env.head.prev_block_hash, Trie::EMPTY_ROOT) + .get_view_trie_for_shard(new_shard_id_tmp(0), 
&env.head.prev_block_hash, Trie::EMPTY_ROOT) .unwrap(); assert!(!trie.has_flat_storage_chunk_view()); } @@ -1505,9 +1519,14 @@ fn test_trie_and_flat_state_equality() { // - using view state, which should never use flat state let head_prev_block_hash = env.head.prev_block_hash; let state_root = env.state_roots[0]; - let state = env.runtime.get_trie_for_shard(0, &head_prev_block_hash, state_root, true).unwrap(); - let view_state = - env.runtime.get_view_trie_for_shard(0, &head_prev_block_hash, state_root).unwrap(); + let state = env + .runtime + .get_trie_for_shard(new_shard_id_tmp(0), &head_prev_block_hash, state_root, true) + .unwrap(); + let view_state = env + .runtime + .get_view_trie_for_shard(new_shard_id_tmp(0), &head_prev_block_hash, state_root) + .unwrap(); let trie_key = TrieKey::Account { account_id: validators[1].clone() }; let key = trie_key.to_vec(); @@ -1651,7 +1670,7 @@ fn prepare_transactions( transaction_groups: &mut dyn TransactionGroupIterator, storage_config: RuntimeStorageConfig, ) -> Result { - let shard_id = 0; + let shard_id = new_shard_id_tmp(0); let block = chain.get_block(&env.head.prev_block_hash).unwrap(); let congestion_info = block.block_congestion_info(); @@ -1773,7 +1792,7 @@ fn test_prepare_transactions_empty_storage_proof() { #[test] #[cfg_attr(not(feature = "test_features"), ignore)] fn test_storage_proof_garbage() { - let shard_id = 0; + let shard_id = new_shard_id_tmp(0); let signer = create_test_signer("test1"); let env = TestEnv::new(vec![vec![signer.validator_id().clone()]], 100, false); let garbage_size_mb = 50usize; diff --git a/chain/chain/src/state_snapshot_actor.rs b/chain/chain/src/state_snapshot_actor.rs index 27756e4796f..c56d39016d3 100644 --- a/chain/chain/src/state_snapshot_actor.rs +++ b/chain/chain/src/state_snapshot_actor.rs @@ -5,7 +5,7 @@ use near_performance_metrics_macros::perf; use near_primitives::block::Block; use near_primitives::hash::CryptoHash; use near_primitives::shard_layout::ShardUId; -use near_primitives::types::{EpochHeight, ShardId}; +use near_primitives::types::EpochHeight; use near_store::flat::FlatStorageManager; use near_store::ShardTries; use std::sync::Arc; @@ -92,7 +92,7 @@ impl StateSnapshotActor { NetworkRequests::SnapshotHostInfo { sync_hash: prev_block_hash, epoch_height, - shards: res_shard_uids.iter().map(|uid| uid.shard_id as ShardId).collect(), + shards: res_shard_uids.iter().map(|uid| uid.shard_id.into()).collect(), }, )); } diff --git a/chain/chain/src/stateless_validation/chunk_endorsement.rs b/chain/chain/src/stateless_validation/chunk_endorsement.rs index 1531e7d65eb..d2f0c1eb28a 100644 --- a/chain/chain/src/stateless_validation/chunk_endorsement.rs +++ b/chain/chain/src/stateless_validation/chunk_endorsement.rs @@ -43,8 +43,10 @@ pub fn validate_chunk_endorsements_in_block( } let epoch_id = epoch_manager.get_epoch_id_from_prev_block(block.header().prev_hash())?; + let shard_layout = epoch_manager.get_shard_layout(&epoch_id)?; for (chunk_header, signatures) in block.chunks().iter().zip(block.chunk_endorsements()) { let shard_id = chunk_header.shard_id(); + let shard_index = shard_layout.get_shard_index(shard_id); // For old chunks, we optimize the block by not including the chunk endorsements. if chunk_header.height_included() != block.header().height() { if !signatures.is_empty() { @@ -117,14 +119,14 @@ pub fn validate_chunk_endorsements_in_block( // Validate the chunk endorsements bitmap (if present) in the block header against the endorsement signatures in the body. 
if let Some(endorsements_bitmap) = endorsements_bitmap { // Bitmap's length must be equal to the min bytes needed to encode one bit per validator assignment. - if endorsements_bitmap.len(shard_id).unwrap() != signatures.len().div_ceil(8) * 8 { + if endorsements_bitmap.len(shard_index).unwrap() != signatures.len().div_ceil(8) * 8 { return Err(Error::InvalidChunkEndorsementBitmap(format!( "Bitmap's length {} is inconsistent with the number of signatures {} for shard {} ", - endorsements_bitmap.len(shard_id).unwrap(), signatures.len(), shard_id, + endorsements_bitmap.len(shard_index).unwrap(), signatures.len(), shard_id, ))); } // Bits in the bitmap must match the existence of signature for the corresponding validator in the body. - for (bit, signature) in endorsements_bitmap.iter(shard_id).zip(signatures.iter()) { + for (bit, signature) in endorsements_bitmap.iter(shard_index).zip(signatures.iter()) { if bit != signature.is_some() { return Err(Error::InvalidChunkEndorsementBitmap( format!("Chunk endorsement bit in header does not match endorsement in body. shard={}, bit={}, signature={}", @@ -132,7 +134,7 @@ pub fn validate_chunk_endorsements_in_block( } } // All extra positions after the assignments must be left as false. - for value in endorsements_bitmap.iter(shard_id).skip(signatures.len()) { + for value in endorsements_bitmap.iter(shard_index).skip(signatures.len()) { if value { return Err(Error::InvalidChunkEndorsementBitmap( format!("Extra positions in the bitmap after {} validator assignments are not all false for shard {}", @@ -157,6 +159,7 @@ pub fn validate_chunk_endorsements_in_header( ))); }; let epoch_id = epoch_manager.get_epoch_id_from_prev_block(header.prev_hash())?; + let shard_layout = epoch_manager.get_shard_layout(&epoch_id)?; let shard_ids = epoch_manager.get_shard_layout(&epoch_id)?.shard_ids().collect_vec(); if chunk_endorsements.num_shards() != shard_ids.len() { return Err(Error::InvalidChunkEndorsementBitmap( @@ -165,12 +168,13 @@ pub fn validate_chunk_endorsements_in_header( } let chunk_mask = header.chunk_mask(); for shard_id in shard_ids.into_iter() { + let shard_index = shard_layout.get_shard_index(shard_id); // For old chunks, we optimize the block and its header by not including the chunk endorsements and // corresponding bitmaps. Thus, we expect that the bitmap is empty for shard with no new chunk. - if chunk_mask[shard_id as usize] != (chunk_endorsements.len(shard_id).unwrap() > 0) { + if chunk_mask[shard_index] != (chunk_endorsements.len(shard_index).unwrap() > 0) { return Err(Error::InvalidChunkEndorsementBitmap(format!( "Bitmap must be non-empty iff shard {} has new chunk in the block. Chunk mask={}, Bitmap length={}", - shard_id, chunk_mask[shard_id as usize], chunk_endorsements.len(shard_id).unwrap(), + shard_id, chunk_mask[shard_index], chunk_endorsements.len(shard_index).unwrap(), ))); } } diff --git a/chain/chain/src/stateless_validation/chunk_validation.rs b/chain/chain/src/stateless_validation/chunk_validation.rs index 8100d159e16..c796696f989 100644 --- a/chain/chain/src/stateless_validation/chunk_validation.rs +++ b/chain/chain/src/stateless_validation/chunk_validation.rs @@ -53,6 +53,8 @@ impl MainTransition { pub fn shard_id(&self) -> ShardId { match self { Self::Genesis { shard_id, .. } => *shard_id, + // It is ok to use the shard id from the header because it is a new + // chunk. An old chunk may have the shard id from the parent shard. 
Self::NewChunk(data) => data.chunk_header.shard_id(), } } @@ -112,7 +114,13 @@ pub fn pre_validate_chunk_state_witness( runtime_adapter: &dyn RuntimeAdapter, ) -> Result { let store = chain.chain_store(); + let epoch_id = state_witness.epoch_id; + let shard_layout = epoch_manager.get_shard_layout(&epoch_id)?; + + // It is ok to use the shard id from the header because it is a new + // chunk. An old chunk may have the shard id from the parent shard. let shard_id = state_witness.chunk_header.shard_id(); + let shard_index = shard_layout.get_shard_index(shard_id); // First, go back through the blockchain history to locate the last new chunk // and last last new chunk for the shard. @@ -128,7 +136,7 @@ pub fn pre_validate_chunk_state_witness( loop { let block = store.get_block(&block_hash)?; let chunks = block.chunks(); - let Some(chunk) = chunks.get(shard_id as usize) else { + let Some(chunk) = chunks.get(shard_index) else { return Err(Error::InvalidChunkStateWitness(format!( "Shard {} does not exist in block {:?}", shard_id, block_hash @@ -167,8 +175,7 @@ pub fn pre_validate_chunk_state_witness( let last_chunk_block = blocks_after_last_last_chunk.first().ok_or_else(|| { Error::Other("blocks_after_last_last_chunk is empty, this should be impossible!".into()) })?; - let last_new_chunk_tx_root = - last_chunk_block.chunks().get(shard_id as usize).unwrap().tx_root(); + let last_new_chunk_tx_root = last_chunk_block.chunks().get(shard_index).unwrap().tx_root(); if last_new_chunk_tx_root != tx_root_from_state_witness { return Err(Error::InvalidChunkStateWitness(format!( "Transaction root {:?} does not match expected transaction root {:?}", @@ -216,17 +223,22 @@ pub fn pre_validate_chunk_state_witness( let main_transition_params = if last_chunk_block.header().is_genesis() { let epoch_id = last_chunk_block.header().epoch_id(); + let shard_layout = epoch_manager.get_shard_layout(&epoch_id)?; let congestion_info = last_chunk_block .block_congestion_info() .get(&shard_id) .map(|info| info.congestion_info); let genesis_protocol_version = epoch_manager.get_epoch_protocol_version(&epoch_id)?; - let chunk_extra = - chain.genesis_chunk_extra(shard_id, genesis_protocol_version, congestion_info)?; + let chunk_extra = chain.genesis_chunk_extra( + &shard_layout, + shard_id, + genesis_protocol_version, + congestion_info, + )?; MainTransition::Genesis { chunk_extra, block_hash: *last_chunk_block.hash(), shard_id } } else { MainTransition::NewChunk(NewChunkData { - chunk_header: last_chunk_block.chunks().get(shard_id as usize).unwrap().clone(), + chunk_header: last_chunk_block.chunks().get(shard_index).unwrap().clone(), transactions: state_witness.transactions.clone(), receipts: receipts_to_apply, block: Chain::get_apply_chunk_block_context( @@ -527,7 +539,7 @@ impl Chain { let height_created = witness.chunk_header.height_created(); let chunk_hash = witness.chunk_header.chunk_hash(); let parent_span = tracing::debug_span!( - target: "chain", "shadow_validate", shard_id, height_created); + target: "chain", "shadow_validate", ?shard_id, height_created); let (encoded_witness, raw_witness_size) = { let shard_id_label = shard_id.to_string(); let encode_timer = @@ -554,7 +566,7 @@ impl Chain { pre_validate_chunk_state_witness(&witness, &self, epoch_manager, runtime_adapter)?; tracing::debug!( parent: &parent_span, - shard_id, + ?shard_id, ?chunk_hash, witness_size = encoded_witness.size_bytes(), raw_witness_size, @@ -580,7 +592,7 @@ impl Chain { Ok(()) => { tracing::debug!( parent: &parent_span, - shard_id, + ?shard_id, 
?chunk_hash, validation_elapsed = ?validation_start.elapsed(), "completed shadow chunk validation" @@ -592,7 +604,7 @@ impl Chain { tracing::error!( parent: &parent_span, ?err, - shard_id, + ?shard_id, ?chunk_hash, "shadow chunk validation failed" ); diff --git a/chain/chain/src/stateless_validation/state_transition_data.rs b/chain/chain/src/stateless_validation/state_transition_data.rs index 9bf05e23753..f57d0f35c47 100644 --- a/chain/chain/src/stateless_validation/state_transition_data.rs +++ b/chain/chain/src/stateless_validation/state_transition_data.rs @@ -23,8 +23,11 @@ impl Chain { return Ok(()); } let final_block = chain_store.get_block(&final_block_hash)?; - let final_block_chunk_created_heights = - final_block.chunks().iter().map(|chunk| chunk.height_created()).collect::>(); + let final_block_chunk_created_heights = final_block + .chunks() + .iter() + .map(|chunk| (chunk.shard_id(), chunk.height_created())) + .collect::>(); clear_before_last_final_block(chain_store, &final_block_chunk_created_heights)?; Ok(()) } @@ -38,7 +41,7 @@ impl Chain { /// TODO(resharding): this doesn't work after shard layout change fn clear_before_last_final_block( chain_store: &ChainStore, - last_final_block_chunk_created_heights: &[BlockHeight], + last_final_block_chunk_created_heights: &[(ShardId, BlockHeight)], ) -> Result<(), Error> { let mut start_heights = if let Some(start_heights) = chain_store @@ -56,10 +59,7 @@ fn clear_before_last_final_block( "garbage collecting state transition data" ); let mut store_update = chain_store.store().store_update(); - for (shard_index, &last_final_block_height) in - last_final_block_chunk_created_heights.iter().enumerate() - { - let shard_id = shard_index as ShardId; + for &(shard_id, last_final_block_height) in last_final_block_chunk_created_heights.iter() { let start_height = *start_heights.get(&shard_id).unwrap_or(&last_final_block_height); let mut potentially_deleted_count = 0; for height in start_height..last_final_block_height { @@ -72,7 +72,7 @@ fn clear_before_last_final_block( } tracing::debug!( target: "state_transition_data", - shard_id, + ?shard_id, start_height, potentially_deleted_count, "garbage collected state transition data for shard" @@ -116,7 +116,7 @@ mod tests { use near_primitives::block_header::{BlockHeader, BlockHeaderInnerLite, BlockHeaderV4}; use near_primitives::hash::{hash, CryptoHash}; use near_primitives::stateless_validation::stored_chunk_state_transition_data::StoredChunkStateTransitionData; - use near_primitives::types::{BlockHeight, EpochId, ShardId}; + use near_primitives::types::{new_shard_id_tmp, BlockHeight, EpochId, ShardId}; use near_primitives::utils::{get_block_shard_id, get_block_shard_id_rev, index_to_bytes}; use near_store::db::STATE_TRANSITION_START_HEIGHTS; use near_store::test_utils::create_test_store; @@ -127,7 +127,7 @@ mod tests { #[test] fn initial_state_transition_data_gc() { - let shard_id = 0; + let shard_id = new_shard_id_tmp(0); let block_at_1 = hash(&[1]); let block_at_2 = hash(&[2]); let block_at_3 = hash(&[3]); @@ -136,8 +136,9 @@ mod tests { for (hash, height) in [(block_at_1, 1), (block_at_2, 2), (block_at_3, 3)] { save_state_transition_data(&store, hash, height, shard_id); } - clear_before_last_final_block(&create_chain_store(&store), &[final_height]).unwrap(); - check_start_heights(&store, vec![final_height]); + clear_before_last_final_block(&create_chain_store(&store), &[(shard_id, final_height)]) + .unwrap(); + check_start_heights(&store, vec![(shard_id, final_height)]); 
check_existing_state_transition_data( &store, vec![(block_at_2, shard_id), (block_at_3, shard_id)], @@ -145,34 +146,27 @@ mod tests { } #[test] fn multiple_state_transition_data_gc() { - let shard_id = 0; + let shard_id = new_shard_id_tmp(0); let store = create_test_store(); let chain_store = create_chain_store(&store); save_state_transition_data(&store, hash(&[1]), 1, shard_id); save_state_transition_data(&store, hash(&[2]), 2, shard_id); - clear_before_last_final_block(&chain_store, &[2]).unwrap(); + clear_before_last_final_block(&chain_store, &[(shard_id, 2)]).unwrap(); let block_at_3 = hash(&[3]); let final_height = 3; save_state_transition_data(&store, block_at_3, final_height, shard_id); - clear_before_last_final_block(&chain_store, &[3]).unwrap(); - check_start_heights(&store, vec![final_height]); + clear_before_last_final_block(&chain_store, &[(shard_id, 3)]).unwrap(); + check_start_heights(&store, vec![(shard_id, final_height)]); check_existing_state_transition_data(&store, vec![(block_at_3, shard_id)]); } #[track_caller] - fn check_start_heights(store: &Store, expected: Vec) { + fn check_start_heights(store: &Store, expected: Vec<(ShardId, BlockHeight)>) { let start_heights = store .get_ser::(DBCol::Misc, STATE_TRANSITION_START_HEIGHTS) .unwrap() .unwrap(); - assert_eq!( - start_heights, - expected - .into_iter() - .enumerate() - .map(|(i, h)| (i as ShardId, h)) - .collect::>() - ); + assert_eq!(start_heights, expected.into_iter().collect::>()); } #[track_caller] diff --git a/chain/chain/src/store/latest_witnesses.rs b/chain/chain/src/store/latest_witnesses.rs index 13020e86145..2b61626b749 100644 --- a/chain/chain/src/store/latest_witnesses.rs +++ b/chain/chain/src/store/latest_witnesses.rs @@ -113,7 +113,7 @@ impl ChainStore { target: "client", "save_latest_chunk_state_witness", witness_height = witness.chunk_header.height_created(), - witness_shard = witness.chunk_header.shard_id(), + witness_shard = ?witness.chunk_header.shard_id(), ) .entered(); @@ -173,7 +173,7 @@ impl ChainStore { OsRng.fill_bytes(&mut random_uuid); let key = LatestWitnessesKey { height: witness.chunk_header.height_created(), - shard_id: witness.chunk_header.shard_id(), + shard_id: witness.chunk_header.shard_id().into(), epoch_id: witness.epoch_id, witness_size: serialized_witness_size, random_uuid, diff --git a/chain/chain/src/store/mod.rs b/chain/chain/src/store/mod.rs index 877b0c1f375..9ccfbf373ee 100644 --- a/chain/chain/src/store/mod.rs +++ b/chain/chain/src/store/mod.rs @@ -243,8 +243,8 @@ pub trait ChainStoreAccess { target: "chain", version = shard_layout.version(), prev_version = prev_shard_layout.version(), - shard_id, - parent_shard_id, + ?shard_id, + ?parent_shard_id, "crossing epoch boundary with shard layout change, updating shard id" ); shard_id = parent_shard_id; @@ -349,18 +349,22 @@ pub trait ChainStoreAccess { shard_id: ShardId, ) -> Result { let mut candidate_hash = *hash; + let block_header = self.get_block_header(&candidate_hash)?; + let shard_layout = epoch_manager.get_shard_layout(block_header.epoch_id())?; let mut shard_id = shard_id; + let mut shard_index = shard_layout.get_shard_index(shard_id); loop { let block_header = self.get_block_header(&candidate_hash)?; if *block_header .chunk_mask() - .get(shard_id as usize) - .ok_or_else(|| Error::InvalidShardId(shard_id as ShardId))? + .get(shard_index) + .ok_or_else(|| Error::InvalidShardId(shard_id))? 
{ break Ok(*block_header.epoch_id()); } candidate_hash = *block_header.prev_hash(); - shard_id = epoch_manager.get_prev_shard_ids(&candidate_hash, vec![shard_id])?[0]; + (shard_id, shard_index) = + epoch_manager.get_prev_shard_ids(&candidate_hash, vec![shard_id])?[0]; } } } @@ -370,7 +374,7 @@ pub trait ChainStoreAccess { /// incoming receipts and the shard layout changed. fn filter_incoming_receipts_for_shard( target_shard_layout: &ShardLayout, - target_shard_id: u64, + target_shard_id: ShardId, receipt_proofs: Arc>, ) -> Vec { let mut filtered_receipt_proofs = vec![]; @@ -586,10 +590,10 @@ impl ChainStore { receipts: &mut Vec, protocol_version: ProtocolVersion, shard_layout: &ShardLayout, - shard_id: u64, - receipts_shard_id: u64, + shard_id: ShardId, + receipts_shard_id: ShardId, ) -> Result<(), Error> { - tracing::trace!(target: "resharding", ?protocol_version, shard_id, receipts_shard_id, "reassign_outgoing_receipts_for_resharding"); + tracing::trace!(target: "resharding", ?protocol_version, ?shard_id, ?receipts_shard_id, "reassign_outgoing_receipts_for_resharding"); // If simple nightshade v2 is enabled and stable use that. // Same reassignment of outgoing receipts works for simple nightshade v3 if checked_feature!("stable", SimpleNightshadeV2, protocol_version) { @@ -2173,9 +2177,9 @@ impl<'a> ChainStoreUpdate<'a> { source_store.get_chunk_extra(block_hash, &shard_uid)?.clone(), ); } - for (shard_id, chunk_header) in block.chunks().iter().enumerate() { + for (shard_index, chunk_header) in block.chunks().iter().enumerate() { + let shard_id = shard_layout.get_shard_id(shard_index); let chunk_hash = chunk_header.chunk_hash(); - let shard_id = shard_id as u64; chain_store_update .chain_store_cache_update .chunks diff --git a/chain/chain/src/store_validator/validate.rs b/chain/chain/src/store_validator/validate.rs index 743a293e499..61163f7f10c 100644 --- a/chain/chain/src/store_validator/validate.rs +++ b/chain/chain/src/store_validator/validate.rs @@ -578,8 +578,9 @@ pub(crate) fn trie_changes_chunk_extra_exists( // 5. 
There should be ShardChunk with ShardId `shard_id` let shard_id = shard_uid.shard_id(); + let shard_index = shard_layout.get_shard_index(shard_id); let chunks = block.chunks(); - if let Some(chunk_header) = chunks.get(shard_id as usize) { + if let Some(chunk_header) = chunks.get(shard_index) { // if the chunk is not a new chunk, skip the check if chunk_header.height_included() != block.header().height() { return Ok(()); diff --git a/chain/chain/src/test_utils/kv_runtime.rs b/chain/chain/src/test_utils/kv_runtime.rs index cc2e76bdca6..a4ed03ac276 100644 --- a/chain/chain/src/test_utils/kv_runtime.rs +++ b/chain/chain/src/test_utils/kv_runtime.rs @@ -42,8 +42,8 @@ use near_primitives::transaction::{ }; use near_primitives::types::validator_stake::ValidatorStake; use near_primitives::types::{ - AccountId, ApprovalStake, Balance, BlockHeight, EpochHeight, EpochId, Nonce, NumShards, - ShardId, StateRoot, StateRootNode, ValidatorInfoIdentifier, + shard_id_as_u32, AccountId, ApprovalStake, Balance, BlockHeight, EpochHeight, EpochId, Nonce, + NumShards, ShardId, ShardIndex, StateRoot, StateRootNode, ValidatorInfoIdentifier, }; use near_primitives::version::{ProtocolFeature, ProtocolVersion, PROTOCOL_VERSION}; use near_primitives::views::{ @@ -170,14 +170,15 @@ impl MockEpochManager { }) .collect(); - let validators_per_shard = block_producers.len() as ShardId / vs.validator_groups; - let coef = block_producers.len() as ShardId / vs.num_shards; + let validators_per_shard = block_producers.len() / vs.validator_groups as usize; + let coef = block_producers.len() / vs.num_shards as usize; let chunk_producers: Vec> = (0..vs.num_shards) - .map(|shard_id| { - let offset = (shard_id * coef / validators_per_shard * validators_per_shard) - as usize; - block_producers[offset..offset + validators_per_shard as usize].to_vec() + .map(|shard_index| { + let shard_index = shard_index as usize; + let offset = + shard_index * coef / validators_per_shard * validators_per_shard; + block_producers[offset..offset + validators_per_shard].to_vec() }) .collect(); @@ -289,8 +290,8 @@ impl MockEpochManager { &self.validators_by_valset[valset].block_producers } - fn get_chunk_producers(&self, valset: usize, shard_id: ShardId) -> Vec { - self.validators_by_valset[valset].chunk_producers[shard_id as usize].clone() + fn get_chunk_producers(&self, valset: usize, shard_index: ShardIndex) -> Vec { + self.validators_by_valset[valset].chunk_producers[shard_index].clone() } fn get_valset_for_epoch(&self, epoch_id: &EpochId) -> Result { @@ -423,8 +424,8 @@ impl EpochManagerAdapter for MockEpochManager { self.hash_to_valset.write().unwrap().contains_key(epoch_id) } - fn shard_ids(&self, _epoch_id: &EpochId) -> Result, EpochError> { - Ok((0..self.num_shards).collect()) + fn shard_ids(&self, epoch_id: &EpochId) -> Result, EpochError> { + Ok(self.get_shard_layout(epoch_id)?.shard_ids().collect()) } fn num_total_parts(&self) -> usize { @@ -463,7 +464,7 @@ impl EpochManagerAdapter for MockEpochManager { shard_id: ShardId, _epoch_id: &EpochId, ) -> Result { - Ok(ShardUId { version: 0, shard_id: shard_id as u32 }) + Ok(ShardUId { version: 0, shard_id: shard_id_as_u32(shard_id) }) } fn get_block_info(&self, _hash: &CryptoHash) -> Result, EpochError> { @@ -587,18 +588,33 @@ impl EpochManagerAdapter for MockEpochManager { fn get_prev_shard_ids( &self, - _prev_hash: &CryptoHash, + prev_hash: &CryptoHash, shard_ids: Vec, - ) -> Result, Error> { - Ok(shard_ids) + ) -> Result, Error> { + let mut prev_shard_ids = vec![]; + let shard_layout = 
self.get_shard_layout_from_prev_block(prev_hash)?; + for shard_id in shard_ids { + // This is not correct if there was a resharding event in between + // the previous and current block. + let prev_shard_id = shard_id; + let prev_shard_index = shard_layout.get_shard_index(prev_shard_id); + prev_shard_ids.push((prev_shard_id, prev_shard_index)); + } + + Ok(prev_shard_ids) } fn get_prev_shard_id( &self, - _prev_hash: &CryptoHash, + prev_hash: &CryptoHash, shard_id: ShardId, - ) -> Result { - Ok(shard_id) + ) -> Result<(ShardId, ShardIndex), Error> { + let shard_layout = self.get_shard_layout_from_prev_block(prev_hash)?; + // This is not correct if there was a resharding event in between + // the previous and current block. + let prev_shard_id = shard_id; + let prev_shard_index = shard_layout.get_shard_index(prev_shard_id); + Ok((prev_shard_id, prev_shard_index)) } fn get_shard_layout_from_prev_block( @@ -728,8 +744,10 @@ impl EpochManagerAdapter for MockEpochManager { shard_id: ShardId, ) -> Result { let valset = self.get_valset_for_epoch(epoch_id)?; - let chunk_producers = self.get_chunk_producers(valset, shard_id); - let index = (shard_id + height + 1) as usize % chunk_producers.len(); + let shard_layout = self.get_shard_layout(epoch_id)?; + let shard_index = shard_layout.get_shard_index(shard_id); + let chunk_producers = self.get_chunk_producers(valset, shard_index); + let index = (shard_index + height as usize + 1) % chunk_producers.len(); Ok(chunk_producers[index].account_id().clone()) } @@ -977,7 +995,9 @@ impl EpochManagerAdapter for MockEpochManager { // we check if we care about a shard. Please do not remove the unwrap, fix the logic of // the calling function. let epoch_valset = self.get_valset_for_epoch(&epoch_id).unwrap(); - let chunk_producers = self.get_chunk_producers(epoch_valset, shard_id); + let shard_layout = self.get_shard_layout(&epoch_id)?; + let shard_index = shard_layout.get_shard_index(shard_id); + let chunk_producers = self.get_chunk_producers(epoch_valset, shard_index); for validator in chunk_producers { if validator.account_id() == account_id { return Ok(true); @@ -996,7 +1016,9 @@ impl EpochManagerAdapter for MockEpochManager { // we check if we care about a shard. Please do not remove the unwrap, fix the logic of // the calling function. let epoch_valset = self.get_epoch_and_valset(*parent_hash).unwrap(); - let chunk_producers = self.get_chunk_producers(epoch_valset.1, shard_id); + let shard_layout = self.get_shard_layout_from_prev_block(parent_hash)?; + let shard_index = shard_layout.get_shard_index(shard_id); + let chunk_producers = self.get_chunk_producers(epoch_valset.1, shard_index); for validator in chunk_producers { if validator.account_id() == account_id { return Ok(true); @@ -1015,8 +1037,12 @@ impl EpochManagerAdapter for MockEpochManager { // we check if we care about a shard. Please do not remove the unwrap, fix the logic of // the calling function. 
@@ -1015,8 +1037,12 @@
         // we check if we care about a shard. Please do not remove the unwrap, fix the logic of
         // the calling function.
         let epoch_valset = self.get_epoch_and_valset(*parent_hash).unwrap();
-        let chunk_producers = self
-            .get_chunk_producers((epoch_valset.1 + 1) % self.validators_by_valset.len(), shard_id);
+        let shard_layout = self.get_shard_layout_from_prev_block(parent_hash)?;
+        let shard_index = shard_layout.get_shard_index(shard_id);
+        let chunk_producers = self.get_chunk_producers(
+            (epoch_valset.1 + 1) % self.validators_by_valset.len(),
+            shard_index,
+        );
         for validator in chunk_producers {
             if validator.account_id() == account_id {
                 return Ok(true);
@@ -1077,9 +1103,10 @@ impl RuntimeAdapter for KeyValueRuntime {
         state_root: StateRoot,
         _use_flat_storage: bool,
     ) -> Result<Trie, Error> {
-        Ok(self
-            .tries
-            .get_trie_for_shard(ShardUId { version: 0, shard_id: shard_id as u32 }, state_root))
+        Ok(self.tries.get_trie_for_shard(
+            ShardUId { version: 0, shard_id: shard_id_as_u32(shard_id) },
+            state_root,
+        ))
     }
 
     fn get_flat_storage_manager(&self) -> near_store::flat::FlatStorageManager {
@@ -1093,7 +1120,7 @@ impl RuntimeAdapter for KeyValueRuntime {
         state_root: StateRoot,
     ) -> Result<Trie, Error> {
         Ok(self.tries.get_view_trie_for_shard(
-            ShardUId { version: 0, shard_id: shard_id as u32 },
+            ShardUId { version: 0, shard_id: shard_id_as_u32(shard_id) },
             state_root,
         ))
     }
@@ -1278,7 +1305,7 @@ impl RuntimeAdapter for KeyValueRuntime {
         Ok(ApplyChunkResult {
             trie_changes: WrappedTrieChanges::new(
                 self.get_tries(),
-                ShardUId { version: 0, shard_id: shard_id as u32 },
+                ShardUId { version: 0, shard_id: shard_id_as_u32(shard_id) },
                 TrieChanges::empty(state_root),
                 Default::default(),
                 block.block_hash,
diff --git a/chain/chain/src/types.rs b/chain/chain/src/types.rs
index 7f18ea65c61..1063ec140d2 100644
--- a/chain/chain/src/types.rs
+++ b/chain/chain/src/types.rs
@@ -540,6 +540,7 @@ mod tests {
     use near_primitives::merkle::verify_path;
     use near_primitives::test_utils::{create_test_signer, TestBlockBuilder};
     use near_primitives::transaction::{ExecutionMetadata, ExecutionOutcome, ExecutionStatus};
+    use near_primitives::types::new_shard_id_tmp;
     use near_primitives::version::PROTOCOL_VERSION;
     use std::sync::Arc;
 
@@ -547,7 +548,7 @@ mod tests {
     #[test]
     fn test_block_produce() {
-        let shard_ids: Vec<_> = (0..32).collect();
+        let shard_ids: Vec<_> = (0..32).map(new_shard_id_tmp).collect();
         let genesis_chunks = genesis_chunks(
             vec![Trie::EMPTY_ROOT],
             vec![Default::default(); shard_ids.len()],
diff --git a/chain/chain/src/update_shard.rs b/chain/chain/src/update_shard.rs
index b1d760c9547..5f171bcc7db 100644
--- a/chain/chain/src/update_shard.rs
+++ b/chain/chain/src/update_shard.rs
@@ -133,7 +133,7 @@ pub fn apply_new_chunk(
         target: "chain",
         parent: parent_span,
         "apply_new_chunk",
-        shard_id,
+        ?shard_id,
         ?apply_reason)
     .entered();
     let gas_limit = chunk_header.gas_limit();
@@ -182,7 +182,7 @@ pub fn apply_old_chunk(
         target: "chain",
         parent: parent_span,
         "apply_old_chunk",
-        shard_id,
+        ?shard_id,
         ?apply_reason)
     .entered();
diff --git a/chain/chunks/src/chunk_cache.rs b/chain/chunks/src/chunk_cache.rs
index 2ac1f0b0e70..4c830378857 100644
--- a/chain/chunks/src/chunk_cache.rs
+++ b/chain/chunks/src/chunk_cache.rs
@@ -274,12 +274,13 @@ mod tests {
     use near_crypto::KeyType;
     use near_primitives::hash::CryptoHash;
     use near_primitives::sharding::{PartialEncodedChunkV2, ShardChunkHeader, ShardChunkHeaderV2};
+    use near_primitives::types::{new_shard_id_tmp, ShardId};
     use near_primitives::validator_signer::InMemoryValidatorSigner;
 
     use crate::chunk_cache::EncodedChunksCache;
     use crate::shards_manager_actor::ChunkRequestInfo;
 
-    fn create_chunk_header(height: u64, shard_id: u64) -> ShardChunkHeader {
+    fn create_chunk_header(height: u64, shard_id: ShardId) -> ShardChunkHeader {
         let signer =
             InMemoryValidatorSigner::from_random("test".parse().unwrap(), KeyType::ED25519);
         ShardChunkHeader::V2(ShardChunkHeaderV2::new(
@@ -303,8 +304,8 @@ mod tests {
     #[test]
     fn test_incomplete_chunks() {
         let mut cache = EncodedChunksCache::new();
-        let header0 = create_chunk_header(1, 0);
-        let header1 = create_chunk_header(1, 1);
+        let header0 = create_chunk_header(1, new_shard_id_tmp(0));
+        let header1 = create_chunk_header(1, new_shard_id_tmp(1));
         cache.get_or_insert_from_header(&header0);
         cache.merge_in_partial_encoded_chunk(&PartialEncodedChunkV2 {
             header: header1.clone(),
@@ -327,7 +328,7 @@ mod tests {
     #[test]
     fn test_cache_removal() {
         let mut cache = EncodedChunksCache::new();
-        let header = create_chunk_header(1, 0);
+        let header = create_chunk_header(1, new_shard_id_tmp(0));
         let partial_encoded_chunk =
             PartialEncodedChunkV2 { header: header, parts: vec![], prev_outgoing_receipts: vec![] };
         cache.merge_in_partial_encoded_chunk(&partial_encoded_chunk);
diff --git a/chain/chunks/src/client.rs b/chain/chunks/src/client.rs
index 8e00394519f..a898bacabdd 100644
--- a/chain/chunks/src/client.rs
+++ b/chain/chunks/src/client.rs
@@ -161,7 +161,7 @@ mod tests {
         hash::CryptoHash,
         shard_layout::{account_id_to_shard_uid, ShardLayout},
         transaction::SignedTransaction,
-        types::AccountId,
+        types::{new_shard_id_tmp, shard_id_as_u32, AccountId, ShardId},
     };
     use near_store::ShardUId;
     use rand::{rngs::StdRng, seq::SliceRandom, SeedableRng};
@@ -171,11 +171,12 @@ mod tests {
     #[test]
     fn test_random_seed_with_shard_id() {
-        let seed0 = ShardedTransactionPool::random_seed(&TEST_SEED, 0);
-        let seed10 = ShardedTransactionPool::random_seed(&TEST_SEED, 10);
-        let seed256 = ShardedTransactionPool::random_seed(&TEST_SEED, 256);
-        let seed1000 = ShardedTransactionPool::random_seed(&TEST_SEED, 1000);
-        let seed1000000 = ShardedTransactionPool::random_seed(&TEST_SEED, 1_000_000);
+        let seed0 = ShardedTransactionPool::random_seed(&TEST_SEED, new_shard_id_tmp(0));
+        let seed10 = ShardedTransactionPool::random_seed(&TEST_SEED, new_shard_id_tmp(10));
+        let seed256 = ShardedTransactionPool::random_seed(&TEST_SEED, new_shard_id_tmp(256));
+        let seed1000 = ShardedTransactionPool::random_seed(&TEST_SEED, new_shard_id_tmp(1000));
+        let seed1000000 =
+            ShardedTransactionPool::random_seed(&TEST_SEED, new_shard_id_tmp(1_000_000));
         assert_ne!(seed0, seed10);
         assert_ne!(seed0, seed256);
         assert_ne!(seed0, seed1000);
@@ -196,12 +197,13 @@ mod tests {
 
         let mut pool = ShardedTransactionPool::new(TEST_SEED, None);
 
-        let mut shard_id_to_accounts = HashMap::new();
-        shard_id_to_accounts.insert(0, vec!["aaa", "abcd", "a-a-a-a-a"]);
-        shard_id_to_accounts.insert(1, vec!["aurora"]);
-        shard_id_to_accounts.insert(2, vec!["aurora-0", "bob", "kkk"]);
+        let mut shard_id_to_accounts: HashMap<ShardId, Vec<&str>> = HashMap::new();
+        shard_id_to_accounts.insert(new_shard_id_tmp(0), vec!["aaa", "abcd", "a-a-a-a-a"]);
+        shard_id_to_accounts.insert(new_shard_id_tmp(1), vec!["aurora"]);
+        shard_id_to_accounts.insert(new_shard_id_tmp(2), vec!["aurora-0", "bob", "kkk"]);
         // this shard is split, make sure there are accounts for both shards 3' and 4'
-        shard_id_to_accounts.insert(3, vec!["mmm", "rrr", "sweat", "ttt", "www", "zzz"]);
+        shard_id_to_accounts
+            .insert(new_shard_id_tmp(3), vec!["mmm", "rrr", "sweat", "ttt", "www", "zzz"]);
 
         let deposit = 222;
 
@@ -234,8 +236,10 @@ mod tests {
                 CryptoHash::default(),
             );
 
-            let shard_uid =
-                ShardUId { shard_id: signer_shard_id as u32, version: old_shard_layout.version() };
+            let shard_uid = ShardUId {
+                shard_id: shard_id_as_u32(signer_shard_id),
+                version: old_shard_layout.version(),
+            };
             pool.insert_transaction(shard_uid, tx);
         }
 
@@ -250,7 +254,7 @@ mod tests {
         {
             let shard_ids: Vec<_> = new_shard_layout.shard_ids().collect();
             for &shard_id in shard_ids.iter() {
-                let shard_id = shard_id as u32;
+                let shard_id = shard_id_as_u32(shard_id);
                 let shard_uid = ShardUId { shard_id, version: new_shard_layout.version() };
                 let pool = pool.pool_for_shard(shard_uid);
                 let pool_len = pool.len();
@@ -260,7 +264,7 @@ mod tests {
 
         let mut total = 0;
         for shard_id in shard_ids {
-            let shard_id = shard_id as u32;
+            let shard_id = shard_id_as_u32(shard_id);
             let shard_uid = ShardUId { shard_id, version: new_shard_layout.version() };
             let mut pool_iter = pool.get_pool_iterator(shard_uid).unwrap();
             while let Some(group) = pool_iter.next() {
diff --git a/chain/chunks/src/logic.rs b/chain/chunks/src/logic.rs
index 5f483094501..daf09dcfa24 100644
--- a/chain/chunks/src/logic.rs
+++ b/chain/chunks/src/logic.rs
@@ -110,8 +110,8 @@ pub fn make_outgoing_receipts_proofs(
     let mut receipts_by_shard =
         Chain::group_receipts_by_shard(outgoing_receipts.to_vec(), &shard_layout);
 
-    let it = proofs.into_iter().enumerate().map(move |(proof_shard_id, proof)| {
-        let proof_shard_id = proof_shard_id as u64;
+    let it = proofs.into_iter().enumerate().map(move |(proof_shard_index, proof)| {
+        let proof_shard_id = shard_layout.get_shard_id(proof_shard_index);
         let receipts = receipts_by_shard.remove(&proof_shard_id).unwrap_or_else(Vec::new);
         let shard_proof =
             ShardProof { from_shard_id: shard_id, to_shard_id: proof_shard_id, proof };
@@ -174,7 +174,7 @@ pub fn decode_encoded_chunk(
         target: "chunks",
         "decode_encoded_chunk",
         height_included = encoded_chunk.cloned_header().height_included(),
-        shard_id = encoded_chunk.cloned_header().shard_id(),
+        shard_id = ?encoded_chunk.cloned_header().shard_id(),
         ?chunk_hash)
     .entered();
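`make_outgoing_receipts_proofs` shows the conversion in the other direction: `enumerate` yields a `ShardIndex`, which must be mapped to a `ShardId` before it can key the receipts map. Roughly the shape of that loop, with string stand-ins for proofs and receipts and the stand-in `ShardLayout` from the first sketch:

```rust
use std::collections::HashMap;

// Assumes the stand-in ShardId and ShardLayout defined in the first sketch.
fn receipts_per_proof(
    shard_layout: &ShardLayout,
    proofs: Vec<&'static str>, // stand-in: one merkle proof per shard index
    mut receipts_by_shard: HashMap<ShardId, Vec<&'static str>>,
) -> Vec<(ShardId, Vec<&'static str>)> {
    proofs
        .into_iter()
        .enumerate()
        .map(|(proof_shard_index, _proof)| {
            // The enumerate position is an index; the receipts map is keyed by id.
            let proof_shard_id = shard_layout.get_shard_id(proof_shard_index);
            let receipts = receipts_by_shard.remove(&proof_shard_id).unwrap_or_default();
            (proof_shard_id, receipts)
        })
        .collect()
}
```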
diff --git a/chain/chunks/src/shards_manager_actor.rs b/chain/chunks/src/shards_manager_actor.rs
index 6829b9f3ee2..700261f9a06 100644
--- a/chain/chunks/src/shards_manager_actor.rs
+++ b/chain/chunks/src/shards_manager_actor.rs
@@ -524,7 +524,7 @@ impl ShardsManagerActor {
         debug!(
             target: "chunks",
             ?part_ords,
-            shard_id,
+            ?shard_id,
             ?target_account,
             prefer_peer,
             "Requesting parts",
@@ -684,18 +684,18 @@ impl ShardsManagerActor {
             target: "chunks",
             "request_chunk_single",
             ?chunk_hash,
-            shard_id,
+            ?shard_id,
             height_created = height)
         .entered();
 
         if self.requested_partial_encoded_chunks.contains_key(&chunk_hash) {
-            debug!(target: "chunks", height, shard_id, ?chunk_hash, "Not requesting chunk, already being requested.");
+            debug!(target: "chunks", height, ?shard_id, ?chunk_hash, "Not requesting chunk, already being requested.");
             return;
         }
 
         if let Some(entry) = self.encoded_chunks.get(&chunk_header.chunk_hash()) {
             if entry.complete {
-                debug!(target: "chunks", height, shard_id, ?chunk_hash, "Not requesting chunk, already complete.");
+                debug!(target: "chunks", height, ?shard_id, ?chunk_hash, "Not requesting chunk, already complete.");
                 return;
             }
         } else {
@@ -703,7 +703,7 @@ impl ShardsManagerActor {
            // However, if the chunk had just been processed and marked as complete, it might have
            // been removed from the cache if it is out of horizon. So in this case, the chunk is
            // already complete and we don't need to request anything.
-            debug!(target: "chunks", height, shard_id, ?chunk_hash, "Not requesting chunk, already complete and GC-ed.");
+            debug!(target: "chunks", height, ?shard_id, ?chunk_hash, "Not requesting chunk, already complete and GC-ed.");
             return;
         }
 
@@ -721,7 +721,7 @@ impl ShardsManagerActor {
         );
 
         if mark_only {
-            debug!(target: "chunks", height, shard_id, ?chunk_hash, "Marked the chunk as being requested but did not send the request yet.");
+            debug!(target: "chunks", height, ?shard_id, ?chunk_hash, "Marked the chunk as being requested but did not send the request yet.");
             return;
         }
 
@@ -749,7 +749,7 @@ impl ShardsManagerActor {
         // we want to give some time for any `PartialEncodedChunkForward` messages to arrive
         // before we send requests.
         if !should_wait_for_chunk_forwarding || fetch_from_archival || old_block {
-            debug!(target: "chunks", height, shard_id, ?chunk_hash, "Requesting.");
+            debug!(target: "chunks", height, ?shard_id, ?chunk_hash, "Requesting.");
             let request_result = self.request_partial_encoded_chunk(
                 height,
                 &ancestor_hash,
@@ -1108,7 +1108,7 @@ impl ShardsManagerActor {
         target: "chunks",
         "check_chunk_complete",
         height_included = chunk.cloned_header().height_included(),
-        shard_id = chunk.cloned_header().shard_id(),
+        shard_id = ?chunk.cloned_header().shard_id(),
         chunk_hash = ?chunk.chunk_hash())
     .entered();
@@ -1471,7 +1471,7 @@ impl ShardsManagerActor {
         target: "chunks",
         "process_partial_encoded_chunk",
         ?chunk_hash,
-        shard_id = header.shard_id(),
+        shard_id = ?header.shard_id(),
         height_created = header.height_created(),
         height_included = header.height_included())
     .entered();
@@ -2263,7 +2263,7 @@ mod test {
     use near_network::types::NetworkRequests;
     use near_primitives::block::Tip;
     use near_primitives::hash::{hash, CryptoHash};
-    use near_primitives::types::EpochId;
+    use near_primitives::types::{new_shard_id_tmp, EpochId};
     use near_primitives::validator_signer::EmptyValidatorSigner;
     use near_store::test_utils::create_test_store;
     use std::sync::Arc;
@@ -2322,7 +2322,7 @@ mod test {
                 height: 0,
                 ancestor_hash: Default::default(),
                 prev_block_hash: Default::default(),
-                shard_id: 0,
+                shard_id: new_shard_id_tmp(0),
                 added,
                 last_requested: added,
             },
diff --git a/chain/chunks/src/test_utils.rs b/chain/chunks/src/test_utils.rs
index 4415b467a6d..8d659afa4c7 100644
--- a/chain/chunks/src/test_utils.rs
+++ b/chain/chunks/src/test_utils.rs
@@ -15,7 +15,7 @@ use near_primitives::sharding::{
     ShardChunkHeader,
 };
 use near_primitives::test_utils::create_test_signer;
-use near_primitives::types::MerkleHash;
+use near_primitives::types::{new_shard_id_tmp, MerkleHash};
 use near_primitives::types::{AccountId, EpochId, ShardId};
 use near_primitives::version::{ProtocolFeature, PROTOCOL_VERSION};
 use near_store::adapter::chunk_store::ChunkStoreAdapter;
@@ -92,7 +92,7 @@ impl ChunkTestFixture {
         let (mock_parent_hash, mock_height) =
             if orphan_chunk { (CryptoHash::hash_bytes(&[]), 2) } else { (mock_ancestor_hash, 1) };
         // setting this to 2 instead of 0 so that when chunk producers
-        let mock_shard_id: ShardId = 0;
+        let mock_shard_id: ShardId = new_shard_id_tmp(0);
         let mock_epoch_id =
             epoch_manager.get_epoch_id_from_prev_block(&mock_ancestor_hash).unwrap();
         let mock_chunk_producer =
diff --git a/chain/client-primitives/src/debug.rs b/chain/client-primitives/src/debug.rs
index aacf1128e06..29b7bd314cb 100644
--- a/chain/client-primitives/src/debug.rs
+++ b/chain/client-primitives/src/debug.rs
@@ -2,7 +2,7 @@
 //! without backwards compatibility of JSON encoding.
 use crate::types::StatusError;
 use near_primitives::congestion_info::CongestionInfo;
-use near_primitives::types::EpochId;
+use near_primitives::types::{EpochId, ShardId};
 use near_primitives::views::{
     CatchupStatusView, ChainProcessingInfo, EpochValidatorInfo, RequestedStatePartsView,
     SyncStatusView,
@@ -143,7 +143,7 @@ pub struct ProductionAtHeight {
     // None if we are not responsible for producing this block.
     pub block_production: Option<BlockProduction>,
     // Map from shard_id to chunk that we are responsible to produce at this height
-    pub chunk_production: HashMap<u64, ChunkProduction>,
+    pub chunk_production: HashMap<ShardId, ChunkProduction>,
 }
 
 // Information about the approvals that we received.
diff --git a/chain/client-primitives/src/types.rs b/chain/client-primitives/src/types.rs
index 2baeb99161d..b2d97aacd58 100644
--- a/chain/client-primitives/src/types.rs
+++ b/chain/client-primitives/src/types.rs
@@ -479,7 +479,7 @@ pub enum GetChunkError {
     #[error("Block either has never been observed on the node or has been garbage collected: {error_message}")]
     UnknownBlock { error_message: String },
     #[error("Shard ID {shard_id} is invalid")]
-    InvalidShardId { shard_id: u64 },
+    InvalidShardId { shard_id: ShardId },
     #[error("Chunk with hash {chunk_hash:?} has never been observed on this node")]
     UnknownChunk { chunk_hash: ChunkHash },
     // NOTE: Currently, the underlying errors are too broad, and while we tried to handle
diff --git a/chain/client/src/chunk_distribution_network.rs b/chain/client/src/chunk_distribution_network.rs
index b66b0dc2dce..e27fa436ab8 100644
--- a/chain/client/src/chunk_distribution_network.rs
+++ b/chain/client/src/chunk_distribution_network.rs
@@ -226,6 +226,7 @@ mod tests {
             PartialEncodedChunkV2, ShardChunkHeaderInner, ShardChunkHeaderInnerV3,
             ShardChunkHeaderV3,
         },
+        types::new_shard_id_tmp,
         validator_signer::EmptyValidatorSigner,
     };
     use std::{collections::HashMap, convert::Infallible, future::Future};
@@ -235,7 +236,7 @@ mod tests {
     fn test_request_chunks() {
         let (mock_sender, mut message_receiver) = mpsc::unbounded_channel();
         let mut client = MockClient::default();
-        let missing_chunk = mock_shard_chunk(0, 0);
+        let missing_chunk = mock_shard_chunk(0, 0u64.into());
         let mut blocks_delay_tracker = BlocksDelayTracker::new(Clock::real());
         let shards_manager = MockSender::new(mock_sender);
         let shards_manager_adapter = shards_manager.into_sender();
@@ -309,8 +310,8 @@ mod tests {
         // When chunks are known by the client, the shards manager
         // is told to process the chunk directly
-        let known_chunk_1 = mock_shard_chunk(1, 0);
-        let known_chunk_2 = mock_shard_chunk(2, 0);
+        let known_chunk_1 = mock_shard_chunk(1, new_shard_id_tmp(0));
+        let known_chunk_2 = mock_shard_chunk(2, new_shard_id_tmp(0));
         client.publish_chunk(&known_chunk_1).now_or_never();
         client.publish_chunk(&known_chunk_2).now_or_never();
         let blocks_missing_chunks = vec![BlockMissingChunks {
@@ -392,7 +393,7 @@ mod tests {
         });
     }
 
-    fn mock_shard_chunk(height: u64, shard_id: u64) -> PartialEncodedChunk {
+    fn mock_shard_chunk(height: u64, shard_id: ShardId) -> PartialEncodedChunk {
         let prev_block_hash =
             hash(&[height.to_le_bytes().as_slice(), shard_id.to_le_bytes().as_slice()].concat());
         let mut mock_hashes = MockHashes::new(prev_block_hash);
diff --git a/chain/client/src/client.rs b/chain/client/src/client.rs
index 9a13abeba08..6bbff7d0994 100644
--- a/chain/client/src/client.rs
+++ b/chain/client/src/client.rs
@@ -150,7 +150,7 @@ pub struct Client {
     /// A mapping from a block for which a state sync is underway for the next epoch, and the object
     /// storing the current status of the state sync and blocks catch up
     pub catchup_state_syncs:
-        HashMap<CryptoHash, (StateSync, HashMap<u64, ShardSyncDownload>, BlocksCatchUpState)>,
+        HashMap<CryptoHash, (StateSync, HashMap<ShardId, ShardSyncDownload>, BlocksCatchUpState)>,
     /// Keeps track of information needed to perform the initial Epoch Sync
     pub epoch_sync: EpochSync,
     /// Keeps track of syncing headers.
@@ -428,8 +428,9 @@ impl Client {
         block: &Block,
     ) -> Result<(), Error> {
         let epoch_id = self.epoch_manager.get_epoch_id(block.hash())?;
-        for (shard_id, chunk_header) in block.chunks().iter().enumerate() {
-            let shard_id = shard_id as ShardId;
+        let shard_layout = self.epoch_manager.get_shard_layout(&epoch_id)?;
+        for (shard_index, chunk_header) in block.chunks().iter().enumerate() {
+            let shard_id = shard_layout.get_shard_id(shard_index);
             let shard_uid = self.epoch_manager.shard_id_to_uid(shard_id, &epoch_id)?;
             if block.header().height() == chunk_header.height_included() {
                 if cares_about_shard_this_or_next_epoch(
@@ -458,8 +459,10 @@ impl Client {
         block: &Block,
     ) -> Result<(), Error> {
         let epoch_id = self.epoch_manager.get_epoch_id(block.hash())?;
-        for (shard_id, chunk_header) in block.chunks().iter().enumerate() {
-            let shard_id = shard_id as ShardId;
+        let shard_layout = self.epoch_manager.get_shard_layout(&epoch_id)?;
+
+        for (shard_index, chunk_header) in block.chunks().iter().enumerate() {
+            let shard_id = shard_layout.get_shard_id(shard_index);
             let shard_uid = self.epoch_manager.shard_id_to_uid(shard_id, &epoch_id)?;
 
             if block.header().height() == chunk_header.height_included() {
@@ -726,7 +729,7 @@ impl Client {
             BlockProductionTracker::construct_chunk_collection_info(
                 height,
                 &epoch_id,
-                chunk_headers.len() as ShardId,
+                chunk_headers.len(),
                 &new_chunks,
                 self.epoch_manager.as_ref(),
                 &self.chunk_inclusion_tracker,
             )?,
         );
 
         // Collect new chunk headers and endorsements.
+        let shard_layout = self.epoch_manager.get_shard_layout(&epoch_id)?;
         for (shard_id, chunk_hash) in new_chunks {
+            let shard_index = shard_layout.get_shard_index(shard_id);
             let (mut chunk_header, chunk_endorsement) =
                 self.chunk_inclusion_tracker.get_chunk_header_and_endorsements(&chunk_hash)?;
             *chunk_header.height_included_mut() = height;
             *chunk_headers
-                .get_mut(shard_id as usize)
+                .get_mut(shard_index)
                 .ok_or_else(|| near_chain_primitives::Error::InvalidShardId(shard_id))? =
                 chunk_header;
             *chunk_endorsements
-                .get_mut(shard_id as usize)
+                .get_mut(shard_index)
                 .ok_or_else(|| near_chain_primitives::Error::InvalidShardId(shard_id))? =
                 chunk_endorsement;
         }
@@ -831,7 +836,7 @@ impl Client {
                 me = ?signer.as_ref().validator_id(),
                 ?chunk_proposer,
                 next_height,
-                shard_id,
+                ?shard_id,
                 "Not producing chunk. Not chunk producer for next chunk.");
             return Ok(None);
         }
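The "collect new chunk headers" fix pairs an id-keyed map with an index-ordered vector. A sketch of that placement step under the same stand-in types as before:

```rust
use std::collections::HashMap;

// Assumes the stand-in ShardId and ShardLayout defined in the first sketch.
fn place_chunk_headers(
    shard_layout: &ShardLayout,
    new_chunks: HashMap<ShardId, &'static str>, // stand-in for ShardChunkHeader
    chunk_headers: &mut Vec<&'static str>,      // ordered by ShardIndex
) -> Result<(), String> {
    for (shard_id, header) in new_chunks {
        // Translate the id-keyed entry into the index-ordered slot.
        let shard_index = shard_layout.get_shard_index(shard_id);
        *chunk_headers
            .get_mut(shard_index)
            .ok_or_else(|| format!("invalid shard id {:?}", shard_id))? = header;
    }
    Ok(())
}
```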
@@ -863,7 +868,7 @@
             let prev_prev_hash = *self.chain.get_block_header(&prev_block_hash)?.prev_hash();
             if !self.chain.prev_block_is_caught_up(&prev_prev_hash, &prev_block_hash)? {
                 // See comment in similar snippet in `produce_block`
-                debug!(target: "client", shard_id, next_height, "Produce chunk: prev block is not caught up");
+                debug!(target: "client", ?shard_id, next_height, "Produce chunk: prev block is not caught up");
                 return Err(Error::ChunkProducer(
                     "State for the epoch is not downloaded yet, skipping chunk production"
                         .to_string(),
@@ -871,7 +876,7 @@
             }
         }
 
-        debug!(target: "client", me = ?validator_signer.validator_id(), next_height, shard_id, "Producing chunk");
+        debug!(target: "client", me = ?validator_signer.validator_id(), next_height, ?shard_id, "Producing chunk");
 
         let shard_uid = self.epoch_manager.shard_id_to_uid(shard_id, epoch_id)?;
         let chunk_extra = self
             .chain
             .get_chunk_extra(&prev_block_hash, &shard_uid)
             .map_err(|err| Error::ChunkProducer(format!("No chunk extra available: {}", err)))?;
 
-        let prev_shard_id = self.epoch_manager.get_prev_shard_id(prev_block.hash(), shard_id)?;
+        let (prev_shard_id, prev_shard_index) =
+            self.epoch_manager.get_prev_shard_id(prev_block.hash(), shard_id)?;
         let last_chunk_header =
-            prev_block.chunks().get(prev_shard_id as usize).cloned().ok_or_else(|| {
+            prev_block.chunks().get(prev_shard_index).cloned().ok_or_else(|| {
                 Error::ChunkProducer(format!(
                     "No last chunk in prev_block_hash {:?}, prev_shard_id: {}",
                     prev_block_hash, prev_shard_id
@@ -1022,7 +1028,7 @@ impl Client {
         chunk_extra: &ChunkExtra,
     ) -> Result<PreparedTransactions, Error> {
         let Self { chain, sharded_tx_pool, runtime_adapter: runtime, .. } = self;
-        let shard_id = shard_uid.shard_id as ShardId;
+        let shard_id = shard_uid.shard_id();
         let prepared_transactions = if let Some(mut iter) =
             sharded_tx_pool.get_pool_iterator(shard_uid)
         {
@@ -1453,8 +1459,18 @@
         let chunk_header = partial_chunk.cloned_header();
         self.chain.blocks_delay_tracker.mark_chunk_completed(&chunk_header);
+
+        // TODO(#10569) We would like a proper error handling here instead of `expect`.
+        let parent_hash = *chunk_header.prev_block_hash();
+        let shard_layout = self
+            .epoch_manager
+            .get_shard_layout_from_prev_block(&parent_hash)
+            .expect("Could not obtain shard layout");
+
+        let shard_id = partial_chunk.shard_id();
+        let shard_index = shard_layout.get_shard_index(shard_id);
         self.block_production_info
-            .record_chunk_collected(partial_chunk.height_created(), partial_chunk.shard_id());
+            .record_chunk_collected(partial_chunk.height_created(), shard_index);
 
         // TODO(#10569) We would like a proper error handling here instead of `expect`.
         persist_chunk(partial_chunk, shard_chunk, self.chain.mut_chain_store())
@@ -2210,7 +2226,7 @@
             validators.remove(account_id);
         }
         for validator in validators {
-            trace!(target: "client", me = ?signer.as_ref().map(|bp| bp.validator_id()), ?tx, ?validator, shard_id, "Routing a transaction");
+            trace!(target: "client", me = ?signer.as_ref().map(|bp| bp.validator_id()), ?tx, ?validator, ?shard_id, "Routing a transaction");
 
             // Send message to network to actually forward transaction.
             self.network_adapter.send(PeerManagerMessageRequest::NetworkRequests(
@@ -2392,7 +2408,7 @@
             // forward to current epoch validators,
             // possibly forward to next epoch validators
             if self.active_validator(shard_id, signer)? {
-                trace!(target: "client", account = ?me, shard_id, tx_hash = ?tx.get_hash(), is_forwarded, "Recording a transaction.");
+                trace!(target: "client", account = ?me, ?shard_id, tx_hash = ?tx.get_hash(), is_forwarded, "Recording a transaction.");
                 metrics::TRANSACTION_RECEIVED_VALIDATOR.inc();
 
                 if !is_forwarded {
@@ -2400,12 +2416,12 @@
                 }
                 Ok(ProcessTxResponse::ValidTx)
             } else if !is_forwarded {
-                trace!(target: "client", shard_id, tx_hash = ?tx.get_hash(), "Forwarding a transaction.");
+                trace!(target: "client", ?shard_id, tx_hash = ?tx.get_hash(), "Forwarding a transaction.");
                 metrics::TRANSACTION_RECEIVED_NON_VALIDATOR.inc();
                 self.forward_tx(&epoch_id, tx, signer)?;
                 Ok(ProcessTxResponse::RequestRouted)
             } else {
-                trace!(target: "client", shard_id, tx_hash = ?tx.get_hash(), "Non-validator received a forwarded transaction, dropping it.");
+                trace!(target: "client", ?shard_id, tx_hash = ?tx.get_hash(), "Non-validator received a forwarded transaction, dropping it.");
                 metrics::TRANSACTION_RECEIVED_NON_VALIDATOR_FORWARDED.inc();
                 Ok(ProcessTxResponse::NoResponse)
             }
@@ -2414,7 +2430,7 @@
             Ok(ProcessTxResponse::DoesNotTrackShard)
         } else if is_forwarded {
             // Received forwarded transaction but we are not tracking the shard
-            debug!(target: "client", ?me, shard_id, tx_hash = ?tx.get_hash(), "Received forwarded transaction but no tracking shard");
+            debug!(target: "client", ?me, ?shard_id, tx_hash = ?tx.get_hash(), "Received forwarded transaction but no tracking shard");
             Ok(ProcessTxResponse::NoResponse)
         } else {
             // We are not tracking this shard, so there is no way to validate this tx. Just rerouting.
@@ -2497,7 +2513,7 @@
         debug!(target: "catchup", ?me, ?sync_hash, progress_per_shard = ?format_shard_sync_phase_per_shard(&shards_to_split, false), "Catchup");
 
         let use_colour = matches!(self.config.log_summary_style, LogSummaryStyle::Colored);
-        let tracking_shards: Vec<u64> =
+        let tracking_shards: Vec<ShardId> =
             state_sync_info.shards.iter().map(|tuple| tuple.0).collect();
         // Notify each shard to sync.
         if notify_state_sync {
@@ -2583,7 +2599,7 @@
         sync_hash: CryptoHash,
         state_sync_info: &StateSyncInfo,
         me: &Option<AccountId>,
-    ) -> Result<HashMap<u64, ShardSyncDownload>, Error> {
+    ) -> Result<HashMap<ShardId, ShardSyncDownload>, Error> {
         let prev_hash = *self.chain.get_block(&sync_hash)?.header().prev_hash();
         let need_to_reshard = self.epoch_manager.will_shard_layout_change(&prev_hash)?;
@@ -2596,7 +2612,7 @@
         let shards_to_split = state_sync_info
             .shards
             .iter()
-            .filter_map(|ShardInfo(shard_id, _)| self.should_split_shard(shard_id, me, prev_hash))
+            .filter_map(|ShardInfo(shard_id, _)| self.should_split_shard(*shard_id, me, prev_hash))
             .collect();
         Ok(shards_to_split)
     }
@@ -2605,11 +2621,10 @@
     /// track it.
     fn should_split_shard(
         &mut self,
-        shard_id: &u64,
+        shard_id: ShardId,
         me: &Option<AccountId>,
         prev_hash: CryptoHash,
-    ) -> Option<(u64, ShardSyncDownload)> {
-        let shard_id = *shard_id;
+    ) -> Option<(ShardId, ShardSyncDownload)> {
         if self.shard_tracker.care_about_shard(me.as_ref(), &prev_hash, shard_id, true) {
             let shard_sync_download = ShardSyncDownload {
                 downloads: vec![],
diff --git a/chain/client/src/client_actor.rs b/chain/client/src/client_actor.rs
index f43121d4a3d..ea909642fac 100644
--- a/chain/client/src/client_actor.rs
+++ b/chain/client/src/client_actor.rs
@@ -65,7 +65,7 @@ use near_primitives::block::Tip;
 use near_primitives::block_header::ApprovalType;
 use near_primitives::hash::CryptoHash;
 use near_primitives::network::{AnnounceAccount, PeerId};
-use near_primitives::types::{AccountId, BlockHeight, EpochId};
+use near_primitives::types::{AccountId, BlockHeight, EpochId, ShardId};
 use near_primitives::unwrap_or_return;
 use near_primitives::utils::MaybeValidated;
 use near_primitives::validator_signer::ValidatorSigner;
@@ -1912,7 +1912,7 @@ impl ClientActorInner {
         &mut self,
         epoch_id: EpochId,
         sync_hash: CryptoHash,
-        shards_to_sync: &Vec<u64>,
+        shards_to_sync: &Vec<ShardId>,
     ) {
         let shard_layout =
             self.client.epoch_manager.get_shard_layout(&epoch_id).expect("Cannot get shard layout");
diff --git a/chain/client/src/debug.rs b/chain/client/src/debug.rs
index 47ca3fc1e4d..c3c59579848 100644
--- a/chain/client/src/debug.rs
+++ b/chain/client/src/debug.rs
@@ -21,7 +21,9 @@ use near_performance_metrics_macros::perf;
 use near_primitives::congestion_info::CongestionControl;
 use near_primitives::state_sync::get_num_state_parts;
 use near_primitives::stateless_validation::chunk_endorsement::ChunkEndorsement;
-use near_primitives::types::{AccountId, BlockHeight, NumShards, ShardId, ValidatorInfoIdentifier};
+use near_primitives::types::{
+    AccountId, BlockHeight, NumShards, ShardId, ShardIndex, ValidatorInfoIdentifier,
+};
 use near_primitives::{
     hash::CryptoHash,
     state_sync::{ShardStateSyncResponseHeader, StateHeaderKey},
@@ -104,11 +106,11 @@ impl BlockProductionTracker {
     /// Record chunk collected after a block is produced if the block didn't include a chunk for the shard.
     /// If called before the block was produced, nothing happens.
-    pub(crate) fn record_chunk_collected(&mut self, height: BlockHeight, shard_id: ShardId) {
+    pub(crate) fn record_chunk_collected(&mut self, height: BlockHeight, shard_index: ShardIndex) {
         if let Some(block_production) = self.0.get_mut(&height) {
             let chunk_collections = &mut block_production.chunks_collection_time;
             // Check that chunk_collection is set and we haven't received this chunk yet.
-            if let Some(chunk_collection) = chunk_collections.get_mut(shard_id as usize) {
+            if let Some(chunk_collection) = chunk_collections.get_mut(shard_index) {
                 if chunk_collection.received_time.is_none() {
                     chunk_collection.received_time = Some(Clock::real().now_utc());
                 }
@@ -121,13 +123,15 @@ impl BlockProductionTracker {
     pub(crate) fn construct_chunk_collection_info(
         block_height: BlockHeight,
         epoch_id: &EpochId,
-        num_shards: ShardId,
+        num_shards: usize,
         new_chunks: &HashMap<ShardId, ChunkHash>,
         epoch_manager: &dyn EpochManagerAdapter,
         chunk_inclusion_tracker: &ChunkInclusionTracker,
     ) -> Result<Vec<ChunkCollection>, Error> {
         let mut chunk_collection_info = vec![];
-        for shard_id in 0..num_shards {
+        for shard_index in 0..num_shards {
+            let shard_layout = epoch_manager.get_shard_layout(epoch_id)?;
+            let shard_id = shard_layout.get_shard_id(shard_index);
             if let Some(chunk_hash) = new_chunks.get(&shard_id) {
                 let (chunk_producer, received_time) =
                     chunk_inclusion_tracker.get_chunk_producer_and_received_time(chunk_hash)?;
@@ -228,6 +232,8 @@ impl ClientActorInner {
         let block = self.client.chain.get_block_by_height(epoch_start_height)?;
         let epoch_id = block.header().epoch_id();
+        let shard_layout = self.client.epoch_manager.get_shard_layout(&epoch_id)?;
+
         let (validators, chunk_only_producers) =
             self.get_producers_for_epoch(&epoch_id, &current_block)?;
@@ -235,9 +241,10 @@
             .chunks()
             .iter()
             .enumerate()
-            .map(|(shard_id, chunk)| {
+            .map(|(shard_index, chunk)| {
+                let shard_id = shard_layout.get_shard_id(shard_index);
                 let state_root_node = self.client.runtime_adapter.get_state_root_node(
-                    shard_id as u64,
+                    shard_id,
                     block.hash(),
                     &chunk.prev_state_root(),
                 );
@@ -252,9 +259,10 @@
             })
             .collect();
 
-        let state_header_exists: Vec<bool> = (0..block.chunks().len())
+        let state_header_exists: Vec<bool> = shard_layout
+            .shard_ids()
             .map(|shard_id| {
-                let key = borsh::to_vec(&StateHeaderKey(shard_id as u64, *block.hash()));
+                let key = borsh::to_vec(&StateHeaderKey(shard_id, *block.hash()));
                 match key {
                     Ok(key) => {
                         matches!(
@@ -490,7 +498,7 @@
             });
 
         DebugChunkStatus {
-            shard_id: chunk.shard_id(),
+            shard_id: chunk.shard_id().into(),
             chunk_hash: chunk.chunk_hash(),
             chunk_producer: self
                 .client
diff --git a/chain/client/src/info.rs b/chain/client/src/info.rs
index 3f3d3d8e10e..ab62c3b9d3d 100644
--- a/chain/client/src/info.rs
+++ b/chain/client/src/info.rs
@@ -210,6 +210,7 @@ impl InfoHelper {
         let epoch_info = client.epoch_manager.get_epoch_info(&head.epoch_id);
         let blocks_in_epoch = client.config.epoch_length;
         let shard_ids = client.epoch_manager.shard_ids(&head.epoch_id).unwrap_or_default();
+        let shard_layout = client.epoch_manager.get_shard_layout(&head.epoch_id).unwrap();
         if let Ok(epoch_info) = epoch_info {
             metrics::VALIDATORS_CHUNKS_EXPECTED_IN_EPOCH.reset();
             metrics::VALIDATORS_BLOCKS_EXPECTED_IN_EPOCH.reset();
@@ -250,10 +251,11 @@ impl InfoHelper {
                 });
 
             for shard_id in shard_ids {
+                let shard_index = shard_layout.get_shard_index(shard_id);
                 let mut stake_per_cp = HashMap::<AccountId, Balance>::new();
                 stake_sum = 0;
                 let chunk_producers_settlement = &epoch_info.chunk_producers_settlement();
-                let chunk_producers = chunk_producers_settlement.get(shard_id as usize);
+                let chunk_producers = chunk_producers_settlement.get(shard_index);
                 let Some(chunk_producers) = chunk_producers else {
                     tracing::warn!(target: "stats", ?shard_id, ?chunk_producers_settlement, "invalid shard id, not found in the shard settlement");
                     continue;
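`construct_chunk_collection_info` now iterates positions and derives ids for map lookups, the mirror image of the placement sketch earlier. Roughly:

```rust
use std::collections::HashMap;

// Assumes the stand-in ShardId and ShardLayout defined in the first sketch.
fn chunk_collection_shards(
    shard_layout: &ShardLayout,
    num_shards: usize,
    new_chunks: &HashMap<ShardId, &'static str>, // stand-in for ChunkHash
) -> Vec<(ShardId, bool)> {
    (0..num_shards)
        .map(|shard_index| {
            // Iterate by position, but look up the id-keyed map by id.
            let shard_id = shard_layout.get_shard_id(shard_index);
            (shard_id, new_chunks.contains_key(&shard_id))
        })
        .collect()
}
```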
diff --git a/chain/client/src/stateless_validation/chunk_endorsement/tracker_v1.rs b/chain/client/src/stateless_validation/chunk_endorsement/tracker_v1.rs
index 30dc76f41ff..44b2954c147 100644
--- a/chain/client/src/stateless_validation/chunk_endorsement/tracker_v1.rs
+++ b/chain/client/src/stateless_validation/chunk_endorsement/tracker_v1.rs
@@ -90,7 +90,7 @@ impl ChunkEndorsementTracker {
         chunk_header: &ShardChunkHeader,
         endorsement: ChunkEndorsementV1,
     ) -> Result<(), Error> {
-        let _span = tracing::debug_span!(target: "client", "process_chunk_endorsement", chunk_hash=?chunk_header.chunk_hash(), shard_id=chunk_header.shard_id()).entered();
+        let _span = tracing::debug_span!(target: "client", "process_chunk_endorsement", chunk_hash=?chunk_header.chunk_hash(), shard_id=?chunk_header.shard_id()).entered();
         // Validate the endorsement before locking the mutex to improve performance.
         if !self.epoch_manager.verify_chunk_endorsement(&chunk_header, &endorsement)? {
             tracing::error!(target: "client", ?endorsement, "Invalid chunk endorsement.");
diff --git a/chain/client/src/stateless_validation/chunk_validator/mod.rs b/chain/client/src/stateless_validation/chunk_validator/mod.rs
index 6395464efca..23e6fdf43bb 100644
--- a/chain/client/src/stateless_validation/chunk_validator/mod.rs
+++ b/chain/client/src/stateless_validation/chunk_validator/mod.rs
@@ -214,7 +214,7 @@ pub(crate) fn send_chunk_endorsement_to_block_producers(
     tracing::debug!(
         target: "client",
         chunk_hash=?chunk_hash,
-        shard_id=chunk_header.shard_id(),
+        shard_id=?chunk_header.shard_id(),
         ?block_producers,
         "send_chunk_endorsement",
     );
@@ -243,7 +243,7 @@ impl Client {
         tracing::debug!(
             target: "client",
             chunk_hash=?witness.chunk_header.chunk_hash(),
-            shard_id=witness.chunk_header.shard_id(),
+            shard_id=?witness.chunk_header.shard_id(),
             "process_chunk_state_witness",
         );
diff --git a/chain/client/src/stateless_validation/chunk_validator/orphan_witness_handling.rs b/chain/client/src/stateless_validation/chunk_validator/orphan_witness_handling.rs
index ed68fa977ad..96a55131d2b 100644
--- a/chain/client/src/stateless_validation/chunk_validator/orphan_witness_handling.rs
+++ b/chain/client/src/stateless_validation/chunk_validator/orphan_witness_handling.rs
@@ -35,7 +35,7 @@ impl Client {
         let _span = tracing::debug_span!(target: "client",
             "handle_orphan_state_witness",
             witness_height,
-            witness_shard,
+            ?witness_shard,
             witness_chunk = ?chunk_header.chunk_hash(),
             witness_prev_block = ?chunk_header.prev_block_hash(),
         )
         .entered();
@@ -63,7 +63,7 @@ impl Client {
             tracing::warn!(
                 target: "client",
                 witness_height,
-                witness_shard,
+                ?witness_shard,
                 witness_chunk = ?chunk_header.chunk_hash(),
                 witness_prev_block = ?chunk_header.prev_block_hash(),
                 witness_size,
@@ -87,7 +87,7 @@ impl Client {
             tracing::debug!(
                 target: "client",
                 witness_height = header.height_created(),
-                witness_shard = header.shard_id(),
+                witness_shard = ?header.shard_id(),
                 witness_chunk = ?header.chunk_hash(),
                 witness_prev_block = ?header.prev_block_hash(),
                 "Processing an orphaned ChunkStateWitness, its previous block has arrived."
diff --git a/chain/client/src/stateless_validation/chunk_validator/orphan_witness_pool.rs b/chain/client/src/stateless_validation/chunk_validator/orphan_witness_pool.rs
index 0429875b794..0a15da403e5 100644
--- a/chain/client/src/stateless_validation/chunk_validator/orphan_witness_pool.rs
+++ b/chain/client/src/stateless_validation/chunk_validator/orphan_witness_pool.rs
@@ -56,7 +56,7 @@ impl OrphanStateWitnessPool {
             tracing::debug!(
                 target: "client",
                 ejected_witness_height = header.height_created(),
-                ejected_witness_shard = header.shard_id(),
+                ejected_witness_shard = ?header.shard_id(),
                 ejected_witness_chunk = ?header.chunk_hash(),
                 ejected_witness_prev_block = ?header.prev_block_hash(),
                 "Ejecting an orphaned ChunkStateWitness from the cache due to capacity limit. It will not be processed."
@@ -101,7 +101,7 @@ impl OrphanStateWitnessPool {
                     target: "client",
                     final_height,
                     ejected_witness_height = witness_height,
-                    ejected_witness_shard = cache_key.shard_id,
+                    ejected_witness_shard = ?cache_key.shard_id,
                     ejected_witness_chunk = ?header.chunk_hash(),
                     ejected_witness_prev_block = ?header.prev_block_hash(),
                     "Ejecting an orphaned ChunkStateWitness from the cache because it's below \
@@ -180,7 +180,7 @@ mod tests {
    use near_primitives::hash::{hash, CryptoHash};
    use near_primitives::sharding::{ShardChunkHeader, ShardChunkHeaderInner};
    use near_primitives::stateless_validation::state_witness::ChunkStateWitness;
-    use near_primitives::types::{BlockHeight, ShardId};
+    use near_primitives::types::{new_shard_id_tmp, BlockHeight, ShardId};
 
     use super::OrphanStateWitnessPool;
 
@@ -253,10 +253,10 @@ mod tests {
     fn basic() {
         let mut pool = OrphanStateWitnessPool::new(10);
 
-        let witness1 = make_witness(100, 1, block(99), 0);
-        let witness2 = make_witness(100, 2, block(99), 0);
-        let witness3 = make_witness(101, 1, block(100), 0);
-        let witness4 = make_witness(101, 2, block(100), 0);
+        let witness1 = make_witness(100, new_shard_id_tmp(1), block(99), 0);
+        let witness2 = make_witness(100, new_shard_id_tmp(2), block(99), 0);
+        let witness3 = make_witness(101, new_shard_id_tmp(1), block(100), 0);
+        let witness4 = make_witness(101, new_shard_id_tmp(2), block(100), 0);
 
         pool.add_orphan_state_witness(witness1.clone(), 0);
         pool.add_orphan_state_witness(witness2.clone(), 0);
@@ -280,8 +280,8 @@ mod tests {
 
         // The old witness is replaced when the awaited block is the same
         {
-            let witness1 = make_witness(100, 1, block(99), 0);
-            let witness2 = make_witness(100, 1, block(99), 1);
+            let witness1 = make_witness(100, new_shard_id_tmp(1), block(99), 0);
+            let witness2 = make_witness(100, new_shard_id_tmp(1), block(99), 1);
             pool.add_orphan_state_witness(witness1, 0);
             pool.add_orphan_state_witness(witness2.clone(), 0);
@@ -291,8 +291,8 @@ mod tests {
 
         // The old witness is replaced when the awaited block is different, waiting_for_block is cleaned as expected
         {
-            let witness3 = make_witness(102, 1, block(100), 0);
-            let witness4 = make_witness(102, 1, block(101), 0);
+            let witness3 = make_witness(102, new_shard_id_tmp(1), block(100), 0);
+            let witness4 = make_witness(102, new_shard_id_tmp(1), block(101), 0);
             pool.add_orphan_state_witness(witness3, 0);
             pool.add_orphan_state_witness(witness4.clone(), 0);
@@ -311,9 +311,9 @@ mod tests {
     fn limited_capacity() {
         let mut pool = OrphanStateWitnessPool::new(2);
 
-        let witness1 = make_witness(102, 1, block(101), 0);
-        let witness2 = make_witness(101, 1, block(100), 0);
-        let witness3 = make_witness(101, 2, block(100), 0);
+        let witness1 = make_witness(102, new_shard_id_tmp(1), block(101), 0);
+        let witness2 = make_witness(101, new_shard_id_tmp(1), block(100), 0);
+        let witness3 = make_witness(101, new_shard_id_tmp(2), block(100), 0);
 
         pool.add_orphan_state_witness(witness1, 0);
         pool.add_orphan_state_witness(witness2.clone(), 0);
@@ -337,7 +337,7 @@ mod tests {
         let mut pool = OrphanStateWitnessPool::new(10);
 
         let large_shard_id = ShardId::MAX;
-        let witness = make_witness(101, large_shard_id, block(99), 0);
+        let witness = make_witness(101, large_shard_id.into(), block(99), 0);
         pool.add_orphan_state_witness(witness.clone(), 0);
 
         let waiting_for_99 = pool.take_state_witnesses_waiting_for_block(&block(99));
@@ -351,10 +351,10 @@ mod tests {
     fn remove_below_height() {
         let mut pool = OrphanStateWitnessPool::new(10);
 
-        let witness1 = make_witness(100, 1, block(99), 0);
-        let witness2 = make_witness(101, 1, block(100), 0);
-        let witness3 = make_witness(102, 1, block(101), 0);
-        let witness4 = make_witness(103, 1, block(102), 0);
+        let witness1 = make_witness(100, new_shard_id_tmp(1), block(99), 0);
+        let witness2 = make_witness(101, new_shard_id_tmp(1), block(100), 0);
+        let witness3 = make_witness(102, new_shard_id_tmp(1), block(101), 0);
+        let witness4 = make_witness(103, new_shard_id_tmp(1), block(102), 0);
 
         pool.add_orphan_state_witness(witness1, 0);
         pool.add_orphan_state_witness(witness2.clone(), 0);
@@ -382,10 +382,10 @@ mod tests {
     #[test]
     fn destructor_doesnt_crash() {
         let mut pool = OrphanStateWitnessPool::new(10);
-        pool.add_orphan_state_witness(make_witness(100, 0, block(99), 0), 0);
-        pool.add_orphan_state_witness(make_witness(100, 2, block(99), 0), 0);
-        pool.add_orphan_state_witness(make_witness(100, 2, block(99), 0), 1);
-        pool.add_orphan_state_witness(make_witness(101, 0, block(100), 0), 0);
+        pool.add_orphan_state_witness(make_witness(100, new_shard_id_tmp(0), block(99), 0), 0);
+        pool.add_orphan_state_witness(make_witness(100, new_shard_id_tmp(2), block(99), 0), 0);
+        pool.add_orphan_state_witness(make_witness(100, new_shard_id_tmp(2), block(99), 0), 1);
+        pool.add_orphan_state_witness(make_witness(101, new_shard_id_tmp(0), block(100), 0), 0);
         std::mem::drop(pool);
     }
 
@@ -395,24 +395,24 @@ mod tests {
         let mut pool = OrphanStateWitnessPool::new(5);
 
         // Witnesses for shards 0, 1, 2, 3 at height 100, looking for block 99
-        let witness0 = make_witness(100, 0, block(99), 0);
-        let witness1 = make_witness(100, 1, block(99), 0);
-        let witness2 = make_witness(100, 2, block(99), 0);
-        let witness3 = make_witness(100, 3, block(99), 0);
+        let witness0 = make_witness(100, new_shard_id_tmp(0), block(99), 0);
+        let witness1 = make_witness(100, new_shard_id_tmp(1), block(99), 0);
+        let witness2 = make_witness(100, new_shard_id_tmp(2), block(99), 0);
+        let witness3 = make_witness(100, new_shard_id_tmp(3), block(99), 0);
         pool.add_orphan_state_witness(witness0, 0);
         pool.add_orphan_state_witness(witness1, 0);
         pool.add_orphan_state_witness(witness2, 0);
         pool.add_orphan_state_witness(witness3, 0);
 
         // Another witness on shard 1, height 100. Should replace witness1
-        let witness5 = make_witness(100, 1, block(99), 1);
+        let witness5 = make_witness(100, new_shard_id_tmp(1), block(99), 1);
         pool.add_orphan_state_witness(witness5.clone(), 0);
 
         // Witnesses for shards 0, 1, 2, 3 at height 101, looking for block 100
-        let witness6 = make_witness(101, 0, block(100), 0);
-        let witness7 = make_witness(101, 1, block(100), 0);
-        let witness8 = make_witness(101, 2, block(100), 0);
-        let witness9 = make_witness(101, 3, block(100), 0);
+        let witness6 = make_witness(101, new_shard_id_tmp(0), block(100), 0);
+        let witness7 = make_witness(101, new_shard_id_tmp(1), block(100), 0);
+        let witness8 = make_witness(101, new_shard_id_tmp(2), block(100), 0);
+        let witness9 = make_witness(101, new_shard_id_tmp(3), block(100), 0);
         pool.add_orphan_state_witness(witness6, 0);
         pool.add_orphan_state_witness(witness7.clone(), 0);
         pool.add_orphan_state_witness(witness8.clone(), 0);
@@ -424,9 +424,9 @@ mod tests {
         assert_contents(looking_for_99, vec![witness5]);
 
         // Let's add a few more witnesses
-        let witness10 = make_witness(102, 1, block(101), 0);
-        let witness11 = make_witness(102, 4, block(100), 0);
-        let witness12 = make_witness(102, 1, block(77), 0);
+        let witness10 = make_witness(102, new_shard_id_tmp(1), block(101), 0);
+        let witness11 = make_witness(102, new_shard_id_tmp(4), block(100), 0);
+        let witness12 = make_witness(102, new_shard_id_tmp(1), block(77), 0);
         pool.add_orphan_state_witness(witness10, 0);
         pool.add_orphan_state_witness(witness11.clone(), 0);
         pool.add_orphan_state_witness(witness12.clone(), 0);
diff --git a/chain/client/src/stateless_validation/partial_witness/partial_witness_tracker.rs b/chain/client/src/stateless_validation/partial_witness/partial_witness_tracker.rs
index 9b608f1a033..4b3ff3a8e08 100644
--- a/chain/client/src/stateless_validation/partial_witness/partial_witness_tracker.rs
+++ b/chain/client/src/stateless_validation/partial_witness/partial_witness_tracker.rs
@@ -158,7 +158,7 @@ impl PartialEncodedStateWitnessTracker {
                 tracing::error!(
                     target: "client",
                     ?err,
-                    shard_id = key.shard_id,
+                    shard_id = ?key.shard_id,
                     height_created = key.height_created,
                     "Failed to reed solomon decode witness parts. Maybe malicious or corrupt data."
                 );
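The test updates above replace bare integer literals with `new_shard_id_tmp(..)`. A plausible shape for such a transitional constructor (hypothetical stand-in; the real helper lives in `near_primitives::types`):

```rust
// Stand-in newtype, mirroring the first sketch.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
struct ShardId(u64);

// Temporary escape hatch while call sites migrate from bare u64 to ShardId;
// keeping it a named function makes the remaining call sites easy to grep.
fn new_shard_id_tmp(id: u64) -> ShardId {
    ShardId(id)
}

fn main() {
    let witness_shard = new_shard_id_tmp(1);
    assert_eq!(witness_shard, ShardId(1));
}
```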
diff --git a/chain/client/src/stateless_validation/shadow_validate.rs b/chain/client/src/stateless_validation/shadow_validate.rs
index d5c240f8a07..33df5655de8 100644
--- a/chain/client/src/stateless_validation/shadow_validate.rs
+++ b/chain/client/src/stateless_validation/shadow_validate.rs
@@ -17,11 +17,15 @@ impl Client {
         tracing::debug!(target: "client", ?block_hash, "shadow validation for block chunks");
         let prev_block = self.chain.get_block(block.header().prev_hash())?;
         let prev_block_chunks = prev_block.chunks();
-        for chunk in
-            block.chunks().iter().filter(|chunk| chunk.is_new_chunk(block.header().height()))
+        for (shard_index, chunk) in block
+            .chunks()
+            .iter()
+            .enumerate()
+            .filter(|(_, chunk)| chunk.is_new_chunk(block.header().height()))
         {
             let chunk = self.chain.get_chunk_clone_from_header(chunk)?;
-            let prev_chunk_header = prev_block_chunks.get(chunk.shard_id() as usize).unwrap();
+            // TODO(resharding) This doesn't work if shard layout changes.
+            let prev_chunk_header = prev_block_chunks.get(shard_index).unwrap();
             if let Err(err) =
                 self.shadow_validate_chunk(prev_block.header(), prev_chunk_header, &chunk)
             {
@@ -30,7 +34,7 @@ impl Client {
                 tracing::error!(
                     target: "client",
                     ?err,
-                    shard_id = chunk.shard_id(),
+                    shard_id = ?chunk.shard_id(),
                     ?block_hash,
                     "shadow chunk validation failed"
                 );
diff --git a/chain/client/src/stateless_validation/state_witness_producer.rs b/chain/client/src/stateless_validation/state_witness_producer.rs
index 94ade12c855..ce714155d29 100644
--- a/chain/client/src/stateless_validation/state_witness_producer.rs
+++ b/chain/client/src/stateless_validation/state_witness_producer.rs
@@ -256,15 +256,14 @@ impl Client {
         let mut source_receipt_proofs = HashMap::new();
         for receipt_proof_response in incoming_receipt_proofs {
             let from_block = self.chain.chain_store().get_block(&receipt_proof_response.0)?;
+            let shard_layout =
+                self.epoch_manager.get_shard_layout(from_block.header().epoch_id())?;
             for proof in receipt_proof_response.1.iter() {
-                let from_shard_id: usize = proof
-                    .1
-                    .from_shard_id
-                    .try_into()
-                    .map_err(|_| Error::Other("Couldn't convert u64 to usize!".into()))?;
+                let from_shard_id = proof.1.from_shard_id;
+                let from_shard_index = shard_layout.get_shard_index(from_shard_id);
                 let from_chunk_hash = from_block
                     .chunks()
-                    .get(from_shard_id)
+                    .get(from_shard_index)
                     .ok_or_else(|| Error::InvalidShardId(proof.1.from_shard_id))?
                     .chunk_hash();
                 let insert_res =
diff --git a/chain/client/src/stateless_validation/state_witness_tracker.rs b/chain/client/src/stateless_validation/state_witness_tracker.rs
index 2297e98a6d6..c1d8face71f 100644
--- a/chain/client/src/stateless_validation/state_witness_tracker.rs
+++ b/chain/client/src/stateless_validation/state_witness_tracker.rs
@@ -153,7 +153,7 @@ mod state_witness_tracker_tests {
     use near_async::time::{Duration, FakeClock, Utc};
     use near_primitives::hash::hash;
     use near_primitives::stateless_validation::state_witness::ChunkStateWitness;
-    use near_primitives::types::ShardId;
+    use near_primitives::types::new_shard_id_tmp;
 
     const NUM_VALIDATORS: usize = 3;
 
@@ -205,7 +205,7 @@ mod state_witness_tracker_tests {
     }
 
     fn dummy_witness() -> ChunkStateWitness {
-        ChunkStateWitness::new_dummy(100, 2 as ShardId, hash("fake hash".as_bytes()))
+        ChunkStateWitness::new_dummy(100, new_shard_id_tmp(2), hash("fake hash".as_bytes()))
     }
 
     fn dummy_clock() -> FakeClock {
diff --git a/chain/client/src/sync/external.rs b/chain/client/src/sync/external.rs
index af8c08ec510..451a95e43a2 100644
--- a/chain/client/src/sync/external.rs
+++ b/chain/client/src/sync/external.rs
@@ -143,7 +143,7 @@ impl ExternalConnection {
         match self {
             ExternalConnection::S3 { bucket } => {
                 bucket.put_object(&location, data).await?;
-                tracing::debug!(target: "state_sync_dump", shard_id, part_length = data.len(), ?location, ?file_type, "Wrote a state part to S3");
+                tracing::debug!(target: "state_sync_dump", ?shard_id, part_length = data.len(), ?location, ?file_type, "Wrote a state part to S3");
                 Ok(())
             }
             ExternalConnection::Filesystem { root_dir } => {
@@ -157,7 +157,7 @@ impl ExternalConnection {
                     .truncate(true)
                     .open(&path)?;
                 file.write_all(data)?;
-                tracing::debug!(target: "state_sync_dump", shard_id, part_length = data.len(), ?location, ?file_type, "Wrote a state part to a file");
+                tracing::debug!(target: "state_sync_dump", ?shard_id, part_length = data.len(), ?location, ?file_type, "Wrote a state part to a file");
                 Ok(())
             }
             ExternalConnection::GCS { gcs_client, bucket, .. } => {
@@ -165,7 +165,7 @@ impl ExternalConnection {
                     .object()
                     .create(bucket, data.to_vec(), location, "application/octet-stream")
                     .await?;
-                tracing::debug!(target: "state_sync_dump", shard_id, part_length = data.len(), ?location, ?file_type, "Wrote a state part to GCS");
+                tracing::debug!(target: "state_sync_dump", ?shard_id, part_length = data.len(), ?location, ?file_type, "Wrote a state part to GCS");
                 Ok(())
             }
         }
@@ -194,7 +194,7 @@ impl ExternalConnection {
             ExternalConnection::S3 { bucket } => {
                 let prefix = format!("{}/", directory_path);
                 let list_results = bucket.list(prefix.clone(), Some("/".to_string())).await?;
-                tracing::debug!(target: "state_sync_dump", shard_id, ?directory_path, "List state parts in s3");
+                tracing::debug!(target: "state_sync_dump", ?shard_id, ?directory_path, "List state parts in s3");
                 let mut file_names = vec![];
                 for res in list_results {
                     for obj in res.contents {
@@ -205,7 +205,7 @@ impl ExternalConnection {
             ExternalConnection::Filesystem { root_dir } => {
                 let path = root_dir.join(directory_path);
-                tracing::debug!(target: "state_sync_dump", shard_id, ?path, "List state parts in local directory");
+                tracing::debug!(target: "state_sync_dump", ?shard_id, ?path, "List state parts in local directory");
                 std::fs::create_dir_all(&path)?;
                 let mut file_names = vec![];
                 let files = std::fs::read_dir(&path)?;
@@ -217,7 +217,7 @@ impl ExternalConnection {
             ExternalConnection::GCS { gcs_client, bucket, .. } => {
                 let prefix = format!("{}/", directory_path);
-                tracing::debug!(target: "state_sync_dump", shard_id, ?directory_path, "List state parts in GCS");
+                tracing::debug!(target: "state_sync_dump", ?shard_id, ?directory_path, "List state parts in GCS");
                 Ok(gcs_client
                     .object()
                     .list(
@@ -277,7 +277,7 @@ pub fn external_storage_location(
     chain_id: &str,
     epoch_id: &EpochId,
     epoch_height: u64,
-    shard_id: u64,
+    shard_id: ShardId,
     file_type: &StateFileType,
 ) -> String {
     format!(
@@ -291,7 +291,7 @@ pub fn external_storage_location_directory(
     chain_id: &str,
     epoch_id: &EpochId,
     epoch_height: u64,
-    shard_id: u64,
+    shard_id: ShardId,
     obj_type: &StateFileType,
 ) -> String {
     location_prefix(chain_id, epoch_height, epoch_id, shard_id, obj_type)
@@ -301,7 +301,7 @@ pub fn location_prefix(
     chain_id: &str,
     epoch_height: u64,
     epoch_id: &EpochId,
-    shard_id: u64,
+    shard_id: ShardId,
     obj_type: &StateFileType,
 ) -> String {
     match obj_type {
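With `shard_id: ShardId` in the location helpers, the id is formatted into the object path rather than passed around as a raw `u64`. A sketch with a hypothetical path layout (illustration only; the real layout is whatever `location_prefix` above produces):

```rust
// Stand-in ShardId with a Display impl, as the typed helpers require.
#[derive(Clone, Copy, Debug)]
struct ShardId(u64);

impl std::fmt::Display for ShardId {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "{}", self.0)
    }
}

// Hypothetical path shape for illustration only.
fn state_part_location(chain_id: &str, epoch_height: u64, shard_id: ShardId, part_id: u64) -> String {
    format!("chain_id={chain_id}/epoch_height={epoch_height}/shard_id={shard_id}/state_part_{part_id:06}")
}

fn main() {
    let loc = state_part_location("mainnet", 100, ShardId(3), 7);
    assert!(loc.contains("shard_id=3"));
}
```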
- let files = rt.block_on(async { connection.list_objects(0, &dir).await.unwrap() }); + let files = rt + .block_on(async { connection.list_objects(new_shard_id_tmp(0), &dir).await.unwrap() }); tracing::debug!("Files after upload: {:?}", files); assert_eq!(files.into_iter().filter(|x| *x == filename).collect::>().len(), 1); // And the data should match generates data. - let download_data = rt - .block_on(async { connection.get_file(0, &full_filename, &file_type).await.unwrap() }); + let download_data = rt.block_on(async { + connection.get_file(new_shard_id_tmp(0), &full_filename, &file_type).await.unwrap() + }); assert_eq!(download_data, data); // Also try to download some data at nonexistent location and expect to fail. let filename = random_string(8); let full_filename = format!("{}/{}", dir, filename); - let download_data = - rt.block_on(async { connection.get_file(0, &full_filename, &file_type).await }); + let download_data = rt.block_on(async { + connection.get_file(new_shard_id_tmp(0), &full_filename, &file_type).await + }); assert!(download_data.is_err(), "{:?}", download_data); } } diff --git a/chain/client/src/sync/state.rs b/chain/client/src/sync/state.rs index d4dccac3e7e..48d74ff5236 100644 --- a/chain/client/src/sync/state.rs +++ b/chain/client/src/sync/state.rs @@ -49,7 +49,9 @@ use near_primitives::state_part::PartId; use near_primitives::state_sync::{ ShardStateSyncResponse, ShardStateSyncResponseHeader, StatePartKey, }; -use near_primitives::types::{AccountId, EpochHeight, EpochId, ShardId, StateRoot}; +use near_primitives::types::{ + shard_id_as_u32, AccountId, EpochHeight, EpochId, ShardId, StateRoot, +}; use near_store::DBCol; use rand::seq::SliceRandom; use rand::thread_rng; @@ -204,7 +206,7 @@ impl StateSync { &mut self, me: &Option, sync_hash: CryptoHash, - sync_status: &mut HashMap, + sync_status: &mut HashMap, chain: &mut Chain, epoch_manager: &dyn EpochManagerAdapter, highest_height_peers: &[HighestHeightPeerInfo], @@ -232,7 +234,7 @@ impl StateSync { for shard_id in tracking_shards { let version = prev_shard_layout.version(); - let shard_uid = ShardUId { version, shard_id: shard_id as u32 }; + let shard_uid = ShardUId { version, shard_id: shard_id_as_u32(shard_id) }; let mut download_timeout = false; let mut run_shard_state_download = false; let shard_sync_download = sync_status.entry(shard_id).or_insert_with(|| { @@ -344,7 +346,7 @@ impl StateSync { &mut self, chain: &mut Chain, sync_hash: CryptoHash, - shard_sync: &mut HashMap, + shard_sync: &mut HashMap, ) { for StateSyncGetFileResult { sync_hash: msg_sync_hash, shard_id, part_id, result } in self.state_parts_mpsc_rx.try_iter() @@ -539,7 +541,7 @@ impl StateSync { // Currently it is assumed that one of the direct peers of the node is able to generate // the shard header. 
let peer_id = possible_targets.choose(&mut thread_rng()).cloned().unwrap(); - tracing::debug!(target: "sync", ?peer_id, shard_id, ?sync_hash, ?possible_targets, "request_shard_header"); + tracing::debug!(target: "sync", ?peer_id, ?shard_id, ?sync_hash, ?possible_targets, "request_shard_header"); assert!(header_download.run_me.load(Ordering::SeqCst)); header_download.run_me.store(false, Ordering::SeqCst); header_download.state_requests_count += 1; @@ -668,7 +670,7 @@ impl StateSync { &mut self, me: &Option, sync_hash: CryptoHash, - sync_status: &mut HashMap, + sync_status: &mut HashMap, chain: &mut Chain, epoch_manager: &dyn EpochManagerAdapter, highest_height_peers: &[HighestHeightPeerInfo], @@ -722,7 +724,7 @@ impl StateSync { &mut self, shard_sync_download: &mut ShardSyncDownload, hash: CryptoHash, - shard_id: u64, + shard_id: ShardId, state_response: ShardStateSyncResponse, chain: &mut Chain, ) { @@ -1314,6 +1316,7 @@ mod test { use near_primitives::state_sync::{ CachedParts, ShardStateSyncResponseHeader, ShardStateSyncResponseV3, }; + use near_primitives::types::new_shard_id_tmp; use near_primitives::{test_utils::TestBlockBuilder, types::EpochId}; #[test] @@ -1356,7 +1359,8 @@ mod test { } let request_hash = &chain.head().unwrap().last_block_hash; - let state_sync_header = chain.get_state_response_header(0, *request_hash).unwrap(); + let state_sync_header = + chain.get_state_response_header(new_shard_id_tmp(0), *request_hash).unwrap(); let state_sync_header = match state_sync_header { ShardStateSyncResponseHeader::V1(_) => panic!("Invalid header"), ShardStateSyncResponseHeader::V2(internal) => internal, @@ -1370,7 +1374,7 @@ mod test { genesis_id: Default::default(), highest_block_height: chain.epoch_length + 10, highest_block_hash: Default::default(), - tracked_shards: vec![0], + tracked_shards: vec![new_shard_id_tmp(0)], archival: false, }; @@ -1383,7 +1387,7 @@ mod test { &mut chain, kv.as_ref(), &[highest_height_peer_info], - vec![0], + vec![new_shard_id_tmp(0)], &noop().into_sender(), &noop().into_sender(), &ActixArbiterHandleFutureSpawner(Arbiter::new().handle()), @@ -1398,7 +1402,7 @@ mod test { assert_eq!( NetworkRequests::StateRequestHeader { - shard_id: 0, + shard_id: new_shard_id_tmp(0), sync_hash: *request_hash, peer_id: peer_id.clone(), }, @@ -1406,7 +1410,7 @@ mod test { ); assert_eq!(1, new_shard_sync.len()); - let download = new_shard_sync.get(&0).unwrap(); + let download = new_shard_sync.get(&new_shard_id_tmp(0)).unwrap(); assert_eq!(download.status, ShardSyncStatus::StateDownloadHeader); @@ -1430,14 +1434,14 @@ mod test { }); state_sync.update_download_on_state_response_message( - &mut new_shard_sync.get_mut(&0).unwrap(), + &mut new_shard_sync.get_mut(&new_shard_id_tmp(0)).unwrap(), *request_hash, - 0, + new_shard_id_tmp(0), state_response, &mut chain, ); - let download = new_shard_sync.get(&0).unwrap(); + let download = new_shard_sync.get(&new_shard_id_tmp(0)).unwrap(); assert_eq!(download.status, ShardSyncStatus::StateDownloadHeader); // Download should be marked as done. 
assert_eq!(download.downloads[0].done, true); diff --git a/chain/client/src/sync_jobs_actor.rs b/chain/client/src/sync_jobs_actor.rs index 176151823ad..5d7e854fb3a 100644 --- a/chain/client/src/sync_jobs_actor.rs +++ b/chain/client/src/sync_jobs_actor.rs @@ -9,7 +9,6 @@ use near_chain::chain::{ use near_performance_metrics_macros::perf; use near_primitives::state_part::PartId; use near_primitives::state_sync::StatePartKey; -use near_primitives::types::ShardId; use near_store::adapter::StoreUpdateAdapter; use near_store::DBCol; @@ -73,7 +72,7 @@ impl SyncJobsActor { tracing::debug_span!(target: "sync", "apply_parts").entered(); let store = msg.runtime_adapter.store(); - let shard_id = msg.shard_uid.shard_id as ShardId; + let shard_id = msg.shard_uid.shard_id(); for part_id in 0..msg.num_parts { let key = borsh::to_vec(&StatePartKey(msg.sync_hash, shard_id, part_id))?; let part = store.get(DBCol::StateParts, &key)?.unwrap(); @@ -124,7 +123,7 @@ impl SyncJobsActor { // Unload mem-trie (in case it is still loaded) before we apply state parts. msg.runtime_adapter.get_tries().unload_mem_trie(&msg.shard_uid); - let shard_id = msg.shard_uid.shard_id as ShardId; + let shard_id = msg.shard_uid.shard_id(); match self.clear_flat_state(&msg) { Err(err) => { self.client_sender.send(ApplyStatePartsResponse { diff --git a/chain/client/src/test_utils/client.rs b/chain/client/src/test_utils/client.rs index f8e8063738b..639e946e116 100644 --- a/chain/client/src/test_utils/client.rs +++ b/chain/client/src/test_utils/client.rs @@ -22,7 +22,7 @@ use near_primitives::merkle::{merklize, PartialMerkleTree}; use near_primitives::sharding::{EncodedShardChunk, ShardChunk}; use near_primitives::stateless_validation::chunk_endorsement::ChunkEndorsementV1; use near_primitives::transaction::SignedTransaction; -use near_primitives::types::{BlockHeight, ShardId}; +use near_primitives::types::{new_shard_id_tmp, BlockHeight, ShardId}; use near_primitives::utils::MaybeValidated; use near_primitives::version::PROTOCOL_VERSION; use num_rational::Ratio; @@ -159,7 +159,7 @@ fn create_chunk_on_height_for_shard( } pub fn create_chunk_on_height(client: &mut Client, next_height: BlockHeight) -> ProduceChunkResult { - create_chunk_on_height_for_shard(client, next_height, 0) + create_chunk_on_height_for_shard(client, next_height, new_shard_id_tmp(0)) } pub fn create_chunk_with_transactions( @@ -190,7 +190,7 @@ pub fn create_chunk( last_block.header().epoch_id(), last_block.chunks()[0].clone(), next_height, - 0, + new_shard_id_tmp(0), signer.as_ref(), ) .unwrap() diff --git a/chain/client/src/test_utils/setup.rs b/chain/client/src/test_utils/setup.rs index c6c669dc8c5..c8490c9070d 100644 --- a/chain/client/src/test_utils/setup.rs +++ b/chain/client/src/test_utils/setup.rs @@ -56,7 +56,9 @@ use near_primitives::epoch_info::RngSeed; use near_primitives::hash::{hash, CryptoHash}; use near_primitives::network::PeerId; use near_primitives::test_utils::create_test_signer; -use near_primitives::types::{AccountId, BlockHeightDelta, EpochId, NumBlocks, NumSeats}; +use near_primitives::types::{ + new_shard_id_tmp, AccountId, BlockHeightDelta, EpochId, NumBlocks, NumSeats, +}; use near_primitives::validator_signer::{EmptyValidatorSigner, ValidatorSigner}; use near_primitives::version::PROTOCOL_VERSION; use near_store::adapter::StoreAdapter; @@ -448,7 +450,10 @@ fn process_peer_manager_message_default( height: last_height[i], hash: CryptoHash::default(), }), - tracked_shards: vec![0, 1, 2, 3], + tracked_shards: vec![0, 1, 2, 3] + .into_iter() + 
.map(new_shard_id_tmp) + .collect(), archival: true, }, }, diff --git a/chain/client/src/test_utils/test_env.rs b/chain/client/src/test_utils/test_env.rs index 4d612566f48..770d8f65c15 100644 --- a/chain/client/src/test_utils/test_env.rs +++ b/chain/client/src/test_utils/test_env.rs @@ -524,8 +524,10 @@ impl TestEnv { let last_block = client.chain.get_block(&head.last_block_hash).unwrap(); let shard_id = client.epoch_manager.account_id_to_shard_id(&account_id, &head.epoch_id).unwrap(); + let shard_layout = client.epoch_manager.get_shard_layout(&head.epoch_id).unwrap(); + let shard_index = shard_layout.get_shard_index(shard_id); let shard_uid = client.epoch_manager.shard_id_to_uid(shard_id, &head.epoch_id).unwrap(); - let last_chunk_header = &last_block.chunks()[shard_id as usize]; + let last_chunk_header = &last_block.chunks()[shard_index]; for i in 0..self.clients.len() { let tracks_shard = self.clients[i] @@ -582,7 +584,9 @@ impl TestEnv { let shard_id = client.epoch_manager.account_id_to_shard_id(&account_id, &head.epoch_id).unwrap(); let shard_uid = client.epoch_manager.shard_id_to_uid(shard_id, &head.epoch_id).unwrap(); - let last_chunk_header = &last_block.chunks()[shard_id as usize]; + let shard_layout = client.epoch_manager.get_shard_layout(&head.epoch_id).unwrap(); + let shard_index = shard_layout.get_shard_index(shard_id); + let last_chunk_header = &last_block.chunks()[shard_index]; let response = client .runtime_adapter .query( diff --git a/chain/client/src/test_utils/test_loop.rs b/chain/client/src/test_utils/test_loop.rs index 0dd8b8c7767..cb0c461a4da 100644 --- a/chain/client/src/test_utils/test_loop.rs +++ b/chain/client/src/test_utils/test_loop.rs @@ -58,7 +58,9 @@ where let shard_id = client.epoch_manager.account_id_to_shard_id(&account_id, &head.epoch_id).unwrap(); let shard_uid = client.epoch_manager.shard_id_to_uid(shard_id, &head.epoch_id).unwrap(); - let last_chunk_header = &last_block.chunks()[shard_id as usize]; + let shard_layout = client.epoch_manager.get_shard_layout(&head.epoch_id).unwrap(); + let shard_index = shard_layout.get_shard_index(shard_id); + let last_chunk_header = &last_block.chunks()[shard_index]; client .runtime_adapter diff --git a/chain/client/src/tests/bug_repros.rs b/chain/client/src/tests/bug_repros.rs index b239d69c965..5f14ff54e87 100644 --- a/chain/client/src/tests/bug_repros.rs +++ b/chain/client/src/tests/bug_repros.rs @@ -109,28 +109,30 @@ fn repro_1183() { for from in ["test1", "test2", "test3", "test4"].iter() { for to in ["test1", "test2", "test3", "test4"].iter() { let (from, to) = (from.parse().unwrap(), to.parse().unwrap()); - connectors1.write().unwrap()[account_id_to_shard_id(&from, 4) as usize] - .client_actor - .do_send( - ProcessTxRequest { - transaction: SignedTransaction::send_money( - block.header().height() * 16 + nonce_delta, + // This test uses the V0 shard layout so it's ok to + // cast ShardId to ShardIndex. 
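The comment above is the crux of the whole migration: a ShardId is a stable label while a ShardIndex is a position in the current ShardLayout, and only in the V0 layout do the two coincide. A toy illustration of why the cast stops being valid after a resharding (the layouts here are made up for the example):

    fn shard_index_for(layout_shard_ids: &[u64], shard_id: u64) -> Option<usize> {
        layout_shard_ids.iter().position(|&id| id == shard_id)
    }

    fn main() {
        let v0 = [0u64, 1, 2, 3]; // V0: ids are contiguous, so id == index
        assert_eq!(shard_index_for(&v0, 2), Some(2));
        let post_resharding = [0u64, 1, 4, 5]; // hypothetical: shard 2 split into 4 and 5
        assert_eq!(shard_index_for(&post_resharding, 4), Some(2)); // id != index
    }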
+ let shard_id = account_id_to_shard_id(&from, 4); + let shard_index = shard_id as usize; + connectors1.write().unwrap()[shard_index].client_actor.do_send( + ProcessTxRequest { + transaction: SignedTransaction::send_money( + block.header().height() * 16 + nonce_delta, + from.clone(), + to, + &InMemorySigner::from_seed( from.clone(), - to, - &InMemorySigner::from_seed( - from.clone(), - KeyType::ED25519, - from.as_ref(), - ) - .into(), - 1, - *block.header().prev_hash(), - ), - is_forwarded: false, - check_only: false, - } - .with_span_context(), - ); + KeyType::ED25519, + from.as_ref(), + ) + .into(), + 1, + *block.header().prev_hash(), + ), + is_forwarded: false, + check_only: false, + } + .with_span_context(), + ); nonce_delta += 1 } } diff --git a/chain/client/src/tests/catching_up.rs b/chain/client/src/tests/catching_up.rs index f49db6989ec..16f78a624d8 100644 --- a/chain/client/src/tests/catching_up.rs +++ b/chain/client/src/tests/catching_up.rs @@ -23,7 +23,7 @@ use near_primitives::network::PeerId; use near_primitives::receipt::Receipt; use near_primitives::sharding::ChunkHash; use near_primitives::transaction::SignedTransaction; -use near_primitives::types::{AccountId, BlockHeight, BlockHeightDelta, BlockReference}; +use near_primitives::types::{AccountId, BlockHeight, BlockHeightDelta, BlockReference, ShardId}; use near_primitives::views::QueryRequest; use near_primitives::views::QueryResponseKind::ViewAccount; @@ -99,7 +99,7 @@ enum ReceiptsSyncPhases { #[derive(Debug, Clone, PartialEq, Eq, BorshSerialize, BorshDeserialize)] pub struct StateRequestStruct { - pub shard_id: u64, + pub shard_id: ShardId, pub sync_hash: CryptoHash, pub sync_prev_prev_hash: Option<CryptoHash>, pub part_id: Option<u64>, @@ -714,7 +714,8 @@ fn test_all_chunks_accepted_common( let verbose = false; - let seen_chunk_same_sender = Arc::new(RwLock::new(HashSet::<(AccountId, u64, u64)>::new())); + let seen_chunk_same_sender = + Arc::new(RwLock::new(HashSet::<(AccountId, u64, ShardId)>::new())); let requested = Arc::new(RwLock::new(HashSet::<(AccountId, Vec<u64>, ChunkHash)>::new())); let responded = Arc::new(RwLock::new(HashSet::<(CryptoHash, Vec<u64>, ChunkHash)>::new())); diff --git a/chain/client/src/tests/cross_shard_tx.rs b/chain/client/src/tests/cross_shard_tx.rs index 7467758d81f..8a26ad41f54 100644 --- a/chain/client/src/tests/cross_shard_tx.rs +++ b/chain/client/src/tests/cross_shard_tx.rs @@ -189,8 +189,11 @@ fn test_cross_shard_tx_callback( let balances1 = balances; let observed_balances1 = observed_balances; let presumable_epoch1 = presumable_epoch.clone(); - let actor = &connectors_[account_id_to_shard_id(&account_id, 8) as usize + (*presumable_epoch.read().unwrap() * 8) % 24] + // This test uses the V0 shard layout so it's ok to cast ShardId to + // ShardIndex. + let shard_id = account_id_to_shard_id(&account_id, 8); + let shard_index = shard_id as usize; + let actor = &connectors_[shard_index + (*presumable_epoch.read().unwrap() * 8) % 24] .view_client_actor; let actor = actor.send( Query::new( @@ -254,10 +257,15 @@ fn test_cross_shard_tx_callback( let amount = (5 + iteration_local) as u128; let next_nonce = nonce.fetch_add(1, Ordering::Relaxed); + // This test uses the V0 shard layout so it's ok to cast ShardId to + // ShardIndex.
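The connector lookups in these cross-shard tests combine a shard index with an epoch offset, as in `connectors_[shard_index + (*presumable_epoch.read().unwrap() * 8) % 24]`. Since `%` binds tighter than `+`, the modulo applies only to the epoch term. A small model of the arithmetic (the 8 shards and 24 connectors mirror the test; the loop is illustrative):

    fn main() {
        let num_shards = 8usize;
        let num_connectors = 24usize; // three epochs' worth of validator clients
        let shard_index = 5usize;
        for presumable_epoch in 0..4 {
            // `%` applies to the epoch offset only, so the index wraps per epoch window.
            let idx = shard_index + (presumable_epoch * num_shards) % num_connectors;
            println!("epoch {presumable_epoch} -> connector {idx}"); // 5, 13, 21, 5
        }
    }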
+ let shard_id = account_id_to_shard_id(&validators[from], 8); + let shard_index = shard_id as usize; + send_tx( validators.len(), connectors.clone(), - account_id_to_shard_id(&validators[from], 8) as usize, + shard_index, validators[from].clone(), validators[to].clone(), amount, @@ -287,8 +295,14 @@ fn test_cross_shard_tx_callback( let presumable_epoch1 = presumable_epoch.clone(); let account_id1 = validators[i].clone(); let block_stats1 = block_stats.clone(); - let actor = &connectors_[account_id_to_shard_id(&validators[i], 8) as usize - + (*presumable_epoch.read().unwrap() * 8) % 24] + + // This test uses the V0 shard layout so it's ok to cast ShardId to + // ShardIndex. + let shard_id = account_id_to_shard_id(&validators[i], 8); + let shard_index = shard_id as usize; + + let actor = &connectors_ + [shard_index + (*presumable_epoch.read().unwrap() * 8) % 24] .view_client_actor; let actor = actor.send( Query::new( @@ -341,8 +355,13 @@ fn test_cross_shard_tx_callback( let connectors_ = connectors.write().unwrap(); let connectors1 = connectors.clone(); let presumable_epoch1 = presumable_epoch.clone(); - let actor = &connectors_[account_id_to_shard_id(&account_id, 8) as usize - + (*presumable_epoch.read().unwrap() * 8) % 24] + + // This test uses the V0 shard layout so it's ok to cast ShardId to + // ShardIndex. + let shard_id = account_id_to_shard_id(&account_id, 8); + let shard_index = shard_id as usize; + + let actor = &connectors_[shard_index + (*presumable_epoch.read().unwrap() * 8) % 24] .view_client_actor; let actor = actor.send( Query::new( @@ -498,9 +517,14 @@ fn test_cross_shard_tx_common( let presumable_epoch1 = presumable_epoch.clone(); let account_id1 = validators[i].clone(); let block_stats1 = block_stats.clone(); - let actor = &connectors_[account_id_to_shard_id(&validators[i], 8) as usize - + *presumable_epoch.read().unwrap() * 8] - .view_client_actor; + + // This test uses the V0 shard layout so it's ok to cast ShardId to + // ShardIndex. 
+ let shard_id = account_id_to_shard_id(&validators[i], 8); + let shard_index = shard_id as usize; + + let actor = + &connectors_[shard_index + *presumable_epoch.read().unwrap() * 8].view_client_actor; let actor = actor.send( Query::new( BlockReference::latest(), diff --git a/chain/client/src/tests/process_blocks.rs b/chain/client/src/tests/process_blocks.rs index 2419c0fa7d2..481e244206f 100644 --- a/chain/client/src/tests/process_blocks.rs +++ b/chain/client/src/tests/process_blocks.rs @@ -11,6 +11,7 @@ use near_primitives::network::PeerId; use near_primitives::sharding::ShardChunkHeader; use near_primitives::sharding::ShardChunkHeaderV3; use near_primitives::test_utils::create_test_signer; +use near_primitives::types::new_shard_id_tmp; use near_primitives::types::validator_stake::ValidatorStake; use near_primitives::utils::MaybeValidated; use near_primitives::version::{ProtocolFeature, PROTOCOL_VERSION}; @@ -78,7 +79,7 @@ fn test_bad_shard_id() { chunk.encoded_merkle_root(), chunk.encoded_length(), 2, - 1, + new_shard_id_tmp(1), chunk.prev_gas_used(), chunk.gas_limit(), chunk.prev_balance_burnt(), @@ -102,7 +103,11 @@ fn test_bad_shard_id() { let err = env.clients[0] .process_block_test(MaybeValidated::from(block), Provenance::NONE) .unwrap_err(); - assert_matches!(err, near_chain::Error::InvalidShardId(1)); + if let near_chain::Error::InvalidShardId(shard_id) = err { + assert!(shard_id == new_shard_id_tmp(1)); + } else { + panic!("Expected InvalidShardId error, got {:?}", err); + } } /// Test that if a block's content (vrf_value) is corrupted, the invalid block will not affect the node's block processing diff --git a/chain/client/src/tests/query_client.rs b/chain/client/src/tests/query_client.rs index ac4a5bbbfb2..bc2214e202b 100644 --- a/chain/client/src/tests/query_client.rs +++ b/chain/client/src/tests/query_client.rs @@ -20,7 +20,7 @@ use near_primitives::block::{Block, BlockHeader}; use near_primitives::merkle::PartialMerkleTree; use near_primitives::test_utils::create_test_signer; use near_primitives::transaction::SignedTransaction; -use near_primitives::types::{BlockId, BlockReference, EpochId}; +use near_primitives::types::{new_shard_id_tmp, BlockId, BlockReference, EpochId}; use near_primitives::version::PROTOCOL_VERSION; use near_primitives::views::{QueryRequest, QueryResponseKind}; use num_rational::Ratio; @@ -210,7 +210,7 @@ fn test_execution_outcome_for_chunk() { .unwrap() .unwrap(); assert_eq!(execution_outcomes_in_block.len(), 1); - let outcomes = execution_outcomes_in_block.remove(&0).unwrap(); + let outcomes = execution_outcomes_in_block.remove(&new_shard_id_tmp(0)).unwrap(); assert_eq!(outcomes[0].id, tx_hash); System::current().stop(); }); @@ -249,7 +249,7 @@ fn test_state_request() { for _ in 0..30 { let res = view_client .send( - StateRequestHeader { shard_id: 0, sync_hash: block_hash } + StateRequestHeader { shard_id: new_shard_id_tmp(0), sync_hash: block_hash } .with_span_context(), ) .await @@ -258,14 +258,15 @@ fn test_state_request() { } // immediately query again, should be rejected + let shard_id = new_shard_id_tmp(0); let res = view_client - .send(StateRequestHeader { shard_id: 0, sync_hash: block_hash }.with_span_context()) + .send(StateRequestHeader { shard_id, sync_hash: block_hash }.with_span_context()) .await .unwrap(); assert!(res.is_none()); actix::clock::sleep(std::time::Duration::from_secs(40)).await; let res = view_client - .send(StateRequestHeader { shard_id: 0, sync_hash: block_hash }.with_span_context()) + .send(StateRequestHeader { shard_id, 
sync_hash: block_hash }.with_span_context()) .await .unwrap(); assert!(res.is_some()); diff --git a/chain/client/src/view_client_actor.rs b/chain/client/src/view_client_actor.rs index b5f12b26b9a..33f66d08d72 100644 --- a/chain/client/src/view_client_actor.rs +++ b/chain/client/src/view_client_actor.rs @@ -299,6 +299,7 @@ impl ViewClientActorInner { let head = self.chain.head()?; let epoch_id = self.epoch_manager.get_epoch_id(&head.last_block_hash)?; let epoch_info: Arc<EpochInfo> = self.epoch_manager.get_epoch_info(&epoch_id)?; + let shard_layout = self.epoch_manager.get_shard_layout(&epoch_id)?; let shard_ids = self.epoch_manager.shard_ids(&epoch_id)?; let cur_block_info = self.epoch_manager.get_block_info(&head.last_block_hash)?; let next_epoch_start_height = @@ -309,13 +310,17 @@ impl ViewClientActorInner { let mut start_block_of_window: Option<BlockHeight> = None; let last_block_of_epoch = next_epoch_start_height - 1; + // This loop does not go beyond the current epoch so it is valid to use + // the EpochInfo and ShardLayout from the current epoch. for block_height in head.height..next_epoch_start_height { let bp = epoch_info.sample_block_producer(block_height); let bp = epoch_info.get_validator(bp).account_id().clone(); let cps: Vec<AccountId> = shard_ids .iter() .map(|&shard_id| { - let cp = epoch_info.sample_chunk_producer(block_height, shard_id).unwrap(); + let cp = epoch_info + .sample_chunk_producer(&shard_layout, shard_id, block_height) + .unwrap(); let cp = epoch_info.get_validator(cp).account_id().clone(); cp }) @@ -751,9 +756,12 @@ fn get_chunk_from_block( shard_id: ShardId, chain: &Chain, ) -> Result<ShardChunk, near_chain::Error> { + let epoch_id = block.header().epoch_id(); + let shard_layout = chain.epoch_manager.get_shard_layout(epoch_id)?; + let shard_index = shard_layout.get_shard_index(shard_id); let chunk_header = block .chunks() - .get(shard_id as usize) + .get(shard_index) .ok_or_else(|| near_chain::Error::InvalidShardId(shard_id))?
.clone(); let chunk_hash = chunk_header.chunk_hash(); @@ -1076,10 +1084,13 @@ impl Handler<GetExecutionOutcome> for ViewClientActorInner { let mut outcome_proof = outcome; let epoch_id = *self.chain.get_block(&outcome_proof.block_hash)?.header().epoch_id(); + let shard_layout = + self.epoch_manager.get_shard_layout(&epoch_id).into_chain_error()?; let target_shard_id = self .epoch_manager .account_id_to_shard_id(&account_id, &epoch_id) .into_chain_error()?; + let target_shard_index = shard_layout.get_shard_index(target_shard_id); let res = self.chain.get_next_block_hash_with_new_chunk( &outcome_proof.block_hash, target_shard_id, @@ -1095,7 +1106,7 @@ impl Handler<GetExecutionOutcome> for ViewClientActorInner { .iter() .map(|header| header.prev_outcome_root()) .collect::<Vec<_>>(); - if target_shard_id >= (outcome_roots.len() as u64) { + if target_shard_index >= outcome_roots.len() { return Err(GetExecutionOutcomeError::InconsistentState { number_or_shards: outcome_roots.len(), execution_outcome_shard_id: target_shard_id, @@ -1103,8 +1114,7 @@ impl Handler<GetExecutionOutcome> for ViewClientActorInner { } Ok(GetExecutionOutcomeResponse { outcome_proof: outcome_proof.into(), - outcome_root_proof: merklize(&outcome_roots).1[target_shard_id as usize] - .clone(), + outcome_root_proof: merklize(&outcome_roots).1[target_shard_index].clone(), }) } else { Err(GetExecutionOutcomeError::NotConfirmed { transaction_or_receipt_id: id }) @@ -1361,7 +1371,7 @@ impl Handler<StateRequestHeader> for ViewClientActorInner { let header = match header { ShardStateSyncResponseHeader::V2(inner) => inner, _ => { - tracing::error!(target: "sync", ?sync_hash, shard_id, "Invalid state sync header format"); + tracing::error!(target: "sync", ?sync_hash, ?shard_id, "Invalid state sync header format"); return None; } }; @@ -1409,16 +1419,16 @@ impl Handler<StateRequestPart> for ViewClientActorInner { let part = match self.chain.get_state_response_part(shard_id, part_id, sync_hash) { Ok(part) => Some((part_id, part)), Err(err) => { - error!(target: "sync", ?err, ?sync_hash, shard_id, part_id, "Cannot build state part"); + error!(target: "sync", ?err, ?sync_hash, ?shard_id, part_id, "Cannot build state part"); None } }; - tracing::trace!(target: "sync", ?sync_hash, shard_id, part_id, "Finished computation for state request part"); + tracing::trace!(target: "sync", ?sync_hash, ?shard_id, part_id, "Finished computation for state request part"); part } Ok(false) => { - warn!(target: "sync", ?sync_hash, shard_id, "sync_hash didn't pass validation, possible malicious behavior"); + warn!(target: "sync", ?sync_hash, ?shard_id, "sync_hash didn't pass validation, possible malicious behavior"); // Do not respond, possible malicious behavior. return None; } diff --git a/chain/epoch-manager/src/adapter.rs b/chain/epoch-manager/src/adapter.rs index 1dbe8de16a7..12cf5ac2860 100644 --- a/chain/epoch-manager/src/adapter.rs +++ b/chain/epoch-manager/src/adapter.rs @@ -18,7 +18,7 @@ use near_primitives::stateless_validation::validator_assignment::ChunkValidatorA use near_primitives::stateless_validation::ChunkProductionKey; use near_primitives::types::validator_stake::ValidatorStake; use near_primitives::types::{ - AccountId, ApprovalStake, Balance, BlockHeight, EpochHeight, EpochId, ShardId, + AccountId, ApprovalStake, Balance, BlockHeight, EpochHeight, EpochId, ShardId, ShardIndex, ValidatorInfoIdentifier, }; use near_primitives::version::ProtocolVersion; @@ -118,11 +118,13 @@ pub trait EpochManagerAdapter: Send + Sync { /// resharding happened and some shards were split.
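Per the TODO below, `get_prev_shard_ids` now hands back `(ShardId, ShardIndex)` pairs so callers can both name the parent shard and index into the previous block's chunks. A self-contained toy model of the no-resharding path (the real code goes through `ShardLayout`; the slice here stands in for it):

    // Ids are stable labels; indices are positions in the previous layout.
    fn prev_pairs(prev_layout: &[u64], shard_ids: &[u64]) -> Vec<(u64, usize)> {
        shard_ids
            .iter()
            .map(|&id| (id, prev_layout.iter().position(|&p| p == id).unwrap()))
            .collect()
    }

    fn main() {
        let prev_layout = [0u64, 1, 2, 3]; // no resharding: same layout as current
        assert_eq!(prev_pairs(&prev_layout, &[2, 3]), vec![(2, 2), (3, 3)]);
    }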
/// If there was no resharding, it just returns `shard_ids` as is, without any validation. /// The resulting Vec will always be of the same length as the `shard_ids` argument. + /// + /// TODO(wacban) - rename to reflect the new return type fn get_prev_shard_ids( &self, prev_hash: &CryptoHash, shard_ids: Vec<ShardId>, - ) -> Result<Vec<ShardId>, Error>; + ) -> Result<Vec<(ShardId, ShardIndex)>, Error>; /// For a `ShardId` in the current block, returns its parent `ShardId` /// from previous block. @@ -130,11 +132,13 @@ pub trait EpochManagerAdapter: Send + Sync { /// Most of the times parent of the shard is the shard itself, unless a /// resharding happened and some shards were split. /// If there was no resharding, it just returns the `shard_id` as is, without any validation. + /// + /// TODO(wacban) - rename to reflect the new return type fn get_prev_shard_id( &self, prev_hash: &CryptoHash, shard_id: ShardId, - ) -> Result<ShardId, Error>; + ) -> Result<(ShardId, ShardIndex), Error>; /// Get shard layout given hash of previous block. fn get_shard_layout_from_prev_block( @@ -596,9 +600,9 @@ impl EpochManagerAdapter for EpochManagerHandle { &self, prev_hash: &CryptoHash, shard_ids: Vec<ShardId>, - ) -> Result<Vec<ShardId>, Error> { + ) -> Result<Vec<(ShardId, ShardIndex)>, Error> { + let shard_layout = self.get_shard_layout_from_prev_block(prev_hash)?; if self.is_next_block_epoch_start(prev_hash)? { - let shard_layout = self.get_shard_layout_from_prev_block(prev_hash)?; let prev_shard_layout = self.get_shard_layout(&self.get_epoch_id(prev_hash)?)?; if prev_shard_layout != shard_layout { return Ok(shard_ids @@ -611,22 +615,27 @@ impl EpochManagerAdapter for EpochManagerHandle { shard_layout, parent_shard_id ); - parent_shard_id + let parent_shard_index = prev_shard_layout.get_shard_index(parent_shard_id); + (parent_shard_id, parent_shard_index) }) }) .collect::<Result<Vec<_>, Error>>()?); } } - Ok(shard_ids) + + Ok(shard_ids + .iter() + .map(|&shard_id| (shard_id, shard_layout.get_shard_index(shard_id))) + .collect()) } fn get_prev_shard_id( &self, prev_hash: &CryptoHash, shard_id: ShardId, - ) -> Result<ShardId, Error> { + ) -> Result<(ShardId, ShardIndex), Error> { + let shard_layout = self.get_shard_layout_from_prev_block(prev_hash)?; if self.is_next_block_epoch_start(prev_hash)?
{ - let shard_layout = self.get_shard_layout_from_prev_block(prev_hash)?; let prev_shard_layout = self.get_shard_layout(&self.get_epoch_id(prev_hash)?)?; if prev_shard_layout != shard_layout { let parent_shard_id = shard_layout.get_parent_shard_id(shard_id)?; @@ -636,10 +645,11 @@ impl EpochManagerAdapter for EpochManagerHandle { shard_layout, parent_shard_id ); - return Ok(parent_shard_id); + let parent_shard_index = prev_shard_layout.get_shard_index(parent_shard_id); + return Ok((parent_shard_id, parent_shard_index)); } } - Ok(shard_id) + Ok((shard_id, shard_layout.get_shard_index(shard_id))) } fn get_shard_layout_from_prev_block( diff --git a/chain/epoch-manager/src/lib.rs b/chain/epoch-manager/src/lib.rs index c0c5036247d..92cdbe3b1e4 100644 --- a/chain/epoch-manager/src/lib.rs +++ b/chain/epoch-manager/src/lib.rs @@ -1087,15 +1087,17 @@ impl EpochManager { } let epoch_info = self.get_epoch_info(epoch_id)?; + let shard_layout = self.get_shard_layout(epoch_id)?; let chunk_validators_per_shard = epoch_info.sample_chunk_validators(height); - for (shard_id, chunk_validators) in chunk_validators_per_shard.into_iter().enumerate() { + for (shard_index, chunk_validators) in chunk_validators_per_shard.into_iter().enumerate() { let chunk_validators = chunk_validators .into_iter() .map(|(validator_id, assignment_weight)| { (epoch_info.get_validator(validator_id).take_account_id(), assignment_weight) }) .collect(); - let cache_key = (*epoch_id, shard_id as ShardId, height); + let shard_id = shard_layout.get_shard_id(shard_index); + let cache_key = (*epoch_id, shard_id, height); self.chunk_validators_cache .put(cache_key, Arc::new(ChunkValidatorAssignments::new(chunk_validators))); } @@ -1180,7 +1182,9 @@ impl EpochManager { shard_id: ShardId, ) -> Result<ValidatorStake, EpochError> { let epoch_info = self.get_epoch_info(epoch_id)?; - let validator_id = Self::chunk_producer_from_info(&epoch_info, height, shard_id)?; + let shard_layout = self.get_shard_layout(epoch_id)?; + let validator_id = + Self::chunk_producer_from_info(&epoch_info, &shard_layout, shard_id, height)?; Ok(epoch_info.get_validator(validator_id)) } @@ -1239,9 +1243,13 @@ impl EpochManager { shard_id: ShardId, ) -> Result<bool, EpochError> { let epoch_info = self.get_epoch_info(&epoch_id)?; + + let shard_layout = self.get_shard_layout(&epoch_id)?; + let shard_index = shard_layout.get_shard_index(shard_id); + let chunk_producers_settlement = epoch_info.chunk_producers_settlement(); let chunk_producers = chunk_producers_settlement - .get(shard_id as usize) + .get(shard_index) .ok_or_else(|| EpochError::ShardingError(format!("invalid shard id {shard_id}")))?; for validator_id in chunk_producers.iter() { if epoch_info.validator_account_id(*validator_id) == account_id { @@ -1458,16 +1466,18 @@ impl EpochManager { ValidatorInfoIdentifier::BlockHash(ref b) => self.get_epoch_id(b)?, }; let cur_epoch_info = self.get_epoch_info(&epoch_id)?; + let cur_shard_layout = self.get_shard_layout(&epoch_id)?; let epoch_height = cur_epoch_info.epoch_height(); let epoch_start_height = self.get_epoch_start_from_epoch_id(&epoch_id)?; let mut validator_to_shard = (0..cur_epoch_info.validators_len()) .map(|_| HashSet::default()) .collect::<Vec<HashSet<ShardId>>>(); - for (shard_id, validators) in + for (shard_index, validators) in cur_epoch_info.chunk_producers_settlement().into_iter().enumerate() { + let shard_id = cur_shard_layout.get_shard_id(shard_index); for validator_id in validators { - validator_to_shard[*validator_id as usize].insert(shard_id as ShardId); + validator_to_shard[*validator_id as
usize].insert(shard_id); } } @@ -1630,14 +1640,16 @@ impl EpochManager { }; let next_epoch_info = self.get_epoch_info(&next_epoch_id)?; + let next_shard_layout = self.get_shard_layout(&next_epoch_id)?; let mut next_validator_to_shard = (0..next_epoch_info.validators_len()) .map(|_| HashSet::default()) .collect::<Vec<HashSet<u64>>>(); - for (shard_id, validators) in + for (shard_index, validators) in next_epoch_info.chunk_producers_settlement().iter().enumerate() { + let shard_id = next_shard_layout.get_shard_id(shard_index); for validator_id in validators { - next_validator_to_shard[*validator_id as usize].insert(shard_id as u64); + next_validator_to_shard[*validator_id as usize].insert(shard_id); } } let next_validators = next_epoch_info @@ -1742,10 +1754,11 @@ impl EpochManager { #[inline] pub(crate) fn chunk_producer_from_info( epoch_info: &EpochInfo, - height: BlockHeight, + shard_layout: &ShardLayout, shard_id: ShardId, + height: BlockHeight, ) -> Result<ValidatorId, EpochError> { - epoch_info.sample_chunk_producer(height, shard_id).ok_or_else(|| { + epoch_info.sample_chunk_producer(shard_layout, shard_id, height).ok_or_else(|| { EpochError::ChunkProducerSelectionError(format!( "Invalid shard {shard_id} for height {height}" )) @@ -2034,6 +2047,7 @@ impl EpochManager { let epoch_id = *self.get_block_info(block_hash)?.epoch_id(); let epoch_info = self.get_epoch_info(&epoch_id)?; + let shard_layout = self.get_shard_layout(&epoch_id)?; let mut aggregator = EpochInfoAggregator::new(epoch_id, *block_hash); let mut cur_hash = *block_hash; @@ -2089,7 +2103,7 @@ impl EpochManager { }; let block_info = self.get_block_info(&cur_hash)?; - aggregator.update_tail(&block_info, &epoch_info, prev_height); + aggregator.update_tail(&block_info, &epoch_info, &shard_layout, prev_height); if prev_hash == self.epoch_info_aggregator.last_block_hash { // We’ve reached sync point of the old aggregator. If old diff --git a/chain/epoch-manager/src/shard_assignment.rs b/chain/epoch-manager/src/shard_assignment.rs index 39e88f64d21..53e421571a4 100644 --- a/chain/epoch-manager/src/shard_assignment.rs +++ b/chain/epoch-manager/src/shard_assignment.rs @@ -1,7 +1,8 @@ use crate::EpochInfo; use crate::RngSeed; use near_primitives::types::validator_stake::ValidatorStake; -use near_primitives::types::{Balance, NumShards, ShardId}; +use near_primitives::types::ShardIndex; +use near_primitives::types::{Balance, NumShards}; use near_primitives::utils::min_heap::{MinHeap, PeekMut}; use rand::Rng; use std::collections::{BTreeSet, HashMap, HashSet}; @@ -22,21 +23,48 @@ impl HasStake for ValidatorStake { } } +/// A helper struct to maintain the shard assignment sorted by the number of +/// validators assigned to each shard. +#[derive(Debug, PartialEq, Eq, PartialOrd, Ord)] +struct ValidatorsFirstShardAssignmentItem { + validators: usize, + stake: Balance, + shard_index: ShardIndex, +} + +type ValidatorsFirstShardAssignment = MinHeap<ValidatorsFirstShardAssignmentItem>; + +/// A helper struct to maintain the shard assignment sorted by the stake +/// assigned to each shard.
+#[derive(Debug, PartialEq, Eq, PartialOrd, Ord)] +struct StakeFirstShardAssignmentItem { + stake: Balance, + validators: usize, + shard_index: ShardIndex, +} + +type StakeFirstShardAssignment = MinHeap<StakeFirstShardAssignmentItem>; + +impl From<ValidatorsFirstShardAssignmentItem> for StakeFirstShardAssignmentItem { + fn from(v: ValidatorsFirstShardAssignmentItem) -> Self { + Self { validators: v.validators, stake: v.stake, shard_index: v.shard_index } + } +} + fn assign_to_satisfy_shards_inner<T: HasStake, I: Iterator<Item = (usize, T)>>( - shard_index: &mut MinHeap<(usize, Balance, ShardId)>, + shard_assignment: &mut ValidatorsFirstShardAssignment, result: &mut Vec<Vec<T>>, cp_iter: &mut I, min_validators_per_shard: usize, ) { - let mut buffer = Vec::with_capacity(shard_index.len()); - // Stores (shard_id, cp_index) meaning that cp at cp_index has already been - // added to shard shard_id. Used to make sure we don’t add a cp to the same + let mut buffer = Vec::with_capacity(shard_assignment.len()); + // Stores (shard_index, cp_index) meaning that cp at cp_index has already been + // added to shard shard_index. Used to make sure we don’t add a cp to the same // shard multiple times. - let mut seen = std::collections::HashSet::<(ShardId, usize)>::with_capacity( - result.len() * min_validators_per_shard, - ); + let seen_capacity = result.len() * min_validators_per_shard; + let mut seen = HashSet::<(ShardIndex, usize)>::with_capacity(seen_capacity); - while shard_index.peek().unwrap().0 < min_validators_per_shard { + while shard_assignment.peek().unwrap().validators < min_validators_per_shard { // cp_iter is an infinite cycle iterator so getting next value can never // fail. cp_index is index of each element in the iterator but the // indexing is done before cycling thus the same cp always gets the same @@ -45,26 +73,26 @@ fn assign_to_satisfy_shards_inner { // No shards left which don’t already contain this chunk // producer. Skip it and move to another producer. break; } - Some(top) if top.0 >= min_validators_per_shard => { - // `shard_index` is sorted by number of chunk producers, + Some(top) if top.validators >= min_validators_per_shard => { + // `shard_assignment` is sorted by number of chunk producers, // thus all remaining shards have min_validators_per_shard // producers already assigned to them. Don’t assign current // one to any shard and move to next cp. break; } - Some(mut top) if seen.insert((top.2, cp_index)) => { + Some(mut top) if seen.insert((top.shard_index, cp_index)) => { // Chunk producer is not yet assigned to the shard and the // shard still needs more producers. Assign `cp` to it and // move to next one. - top.0 += 1; - top.1 += cp.get_stake(); - result[usize::try_from(top.2).unwrap()].push(cp); + top.validators += 1; + top.stake += cp.get_stake(); + result[top.shard_index].push(cp); break; } Some(top) => { @@ -78,7 +106,7 @@ fn assign_to_satisfy_shards_inner( let mut result: Vec<Vec<T>> = (0..num_shards).map(|_| Vec::new()).collect(); // Initially, sort by number of validators first so we fill shards up. - let mut shard_index: MinHeap<(usize, Balance, ShardId)> = - (0..num_shards).map(|s| (0, 0, s)).collect(); + let mut shard_assignment: ValidatorsFirstShardAssignment = (0..num_shards) + .map(|shard_index| shard_index as usize) + .map(|shard_index| ValidatorsFirstShardAssignmentItem { + validators: 0, + stake: 0, + shard_index, + }) + .collect(); // Distribute chunk producers until all shards have at least the // minimum requested number. If there are not enough validators to satisfy // that requirement, assign some of the validators to multiple shards.
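The two item structs work because `#[derive(Ord)]` compares fields lexicographically, top to bottom: putting `validators` first makes the heap order shards by validator count, while `StakeFirstShardAssignmentItem` leads with `stake`. A compact demonstration using only the standard library (std's `BinaryHeap` is a max-heap, so `Reverse` stands in for nearcore's `MinHeap` wrapper, which is assumed to behave similarly):

    use std::cmp::Reverse;
    use std::collections::BinaryHeap;

    #[derive(Debug, PartialEq, Eq, PartialOrd, Ord)]
    struct Item {
        validators: usize,
        stake: u128,
        shard_index: usize,
    }

    fn main() {
        let mut heap = BinaryHeap::new();
        heap.push(Reverse(Item { validators: 2, stake: 10, shard_index: 0 }));
        heap.push(Reverse(Item { validators: 1, stake: 99, shard_index: 1 }));
        // Fewest validators wins, regardless of the larger stake.
        assert_eq!(heap.pop().unwrap().0.shard_index, 1);
    }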
let mut chunk_producers = chunk_producers.into_iter().enumerate().cycle(); assign_to_satisfy_shards_inner( - &mut shard_index, + &mut shard_assignment, &mut result, &mut chunk_producers, min_validators_per_shard, @@ -153,7 +187,11 @@ struct ShardSetItem { /// /// Caller must guarantee that `min_validators_per_shard` is achievable and /// `prev_chunk_producers_assignment` corresponds to the same number of shards. +/// /// TODO(resharding) - implement shard assignment +/// The current shard assignment works fully based on the ShardIndex. During +/// resharding those indices will change and the assignment will move many +/// validators to different shards. This should be avoided. fn assign_to_balance_shards( chunk_producers: Vec<ValidatorStake>, num_shards: NumShards, @@ -304,8 +342,12 @@ pub(crate) fn assign_chunk_producers_to_shards( pub(crate) mod old_validator_selection { use crate::shard_assignment::{assign_to_satisfy_shards_inner, HasStake, NotEnoughValidators}; - use near_primitives::types::{Balance, NumShards, ShardId}; - use near_primitives::utils::min_heap::MinHeap; + use near_primitives::types::NumShards; + + use super::{ + StakeFirstShardAssignment, StakeFirstShardAssignmentItem, ValidatorsFirstShardAssignment, + ValidatorsFirstShardAssignmentItem, + }; /// Assign chunk producers (a.k.a. validators) to shards. The i-th element /// of the output corresponds to the validators assigned to the i-th shard. @@ -344,15 +386,21 @@ pub(crate) mod old_validator_selection { let mut result: Vec<Vec<T>> = (0..num_shards).map(|_| Vec::new()).collect(); // Initially, sort by number of validators first so we fill shards up. - let mut shard_index: MinHeap<(usize, Balance, ShardId)> = - (0..num_shards).map(|s| (0, 0, s)).collect(); + let mut shard_assignment: ValidatorsFirstShardAssignment = (0..num_shards) + .map(|shard_index| shard_index as usize) + .map(|shard_index| ValidatorsFirstShardAssignmentItem { + validators: 0, + stake: 0, + shard_index, + }) + .collect(); // First, distribute chunk producers until all shards have at least the // minimum requested number. If there are not enough validators to satisfy // that requirement, assign some of the validators to multiple shards. let mut chunk_producers = chunk_producers.into_iter().enumerate().cycle(); assign_to_satisfy_shards_inner( - &mut shard_index, + &mut shard_assignment, &mut result, &mut chunk_producers, min_validators_per_shard, @@ -364,20 +412,21 @@ pub(crate) mod old_validator_selection { num_chunk_producers.saturating_sub(num_shards as usize * min_validators_per_shard); if remaining_producers > 0 { // Re-index shards to favour lowest stake first.
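After every shard reaches the minimum, the leftover producers are distributed by stake instead: the heap is rebuilt stake-first (via the `From` impl above) and each remaining producer lands on the currently lightest shard. A minimal model of that greedy phase:

    use std::cmp::Reverse;
    use std::collections::BinaryHeap;

    fn main() {
        // (total stake, shard index); Reverse yields min-by-stake ordering.
        let mut heap: BinaryHeap<Reverse<(u128, usize)>> =
            [(100u128, 0usize), (90, 1)].into_iter().map(Reverse).collect();
        for stake in [81u128, 73] {
            let Reverse((total, shard)) = heap.pop().unwrap();
            println!("assign {stake} to shard {shard}"); // 81 -> shard 1, 73 -> shard 0
            heap.push(Reverse((total + stake, shard)));
        }
    }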
- let mut shard_index: MinHeap<(Balance, usize, ShardId)> = shard_index - .into_iter() - .map(|(count, stake, shard_id)| (stake, count, shard_id)) - .collect(); + let mut shard_assignment: StakeFirstShardAssignment = + shard_assignment.into_iter().map(Into::into).collect(); for (_, cp) in chunk_producers.take(remaining_producers) { - let (least_stake, least_validator_count, shard_id) = - shard_index.pop().expect("shard_index should never be empty"); - shard_index.push(( - least_stake + cp.get_stake(), - least_validator_count + 1, - shard_id, - )); - result[usize::try_from(shard_id).unwrap()].push(cp); + let StakeFirstShardAssignmentItem { + stake: least_stake, + validators: least_validator_count, + shard_index, + } = shard_assignment.pop().expect("shard_assignment should never be empty"); + shard_assignment.push(StakeFirstShardAssignmentItem { + stake: least_stake + cp.get_stake(), + validators: least_validator_count + 1, + shard_index, + }); + result[shard_index].push(cp); } } @@ -390,7 +439,7 @@ mod tests { use crate::shard_assignment::{assign_chunk_producers_to_shards, NotEnoughValidators}; use crate::RngSeed; use near_primitives::types::validator_stake::ValidatorStake; - use near_primitives::types::{AccountId, Balance, NumShards}; + use near_primitives::types::{AccountId, Balance, NumShards, ShardIndex}; use std::collections::{HashMap, HashSet}; const EXPONENTIAL_STAKES: [Balance; 12] = [100, 90, 81, 73, 66, 59, 53, 48, 43, 39, 35, 31]; @@ -486,12 +535,12 @@ mod tests { let mut assignments = assignments .into_iter() .enumerate() - .map(|(shard_id, cps)| { + .map(|(shard_index, cps)| { // All shards must have at least min_validators_per_shard validators. assert!( cps.len() >= min_validators_per_shard, "Shard {} has only {} chunk producers; expected at least {}", - shard_id, + shard_index, cps.len(), min_validators_per_shard ); @@ -500,7 +549,7 @@ mod tests { assert_eq!( cps.len(), cps.iter().map(|cp| cp.0).collect::<HashSet<_>>().len(), "Shard {} contains duplicate chunk producers: {:?}", - shard_id, + shard_index, cps ); // If all is good, aggregate as (cps_count, total_stake) pair. @@ -520,12 +569,12 @@ mod tests { / (stakes.len() as Balance); let assignment = assign_shards(stakes, num_shards, min_validators_per_shard) .expect("There should have been enough validators"); - for (shard_id, &cps) in assignment.iter().enumerate() { + for (shard_index, &cps) in assignment.iter().enumerate() { // Validator distribution should be even.
assert_eq!( validators_per_shard, cps.0, "Shard {} has {} validators, expected {}", - shard_id, cps.0, validators_per_shard + shard_index, cps.0, validators_per_shard ); // Stake distribution should be even @@ -533,7 +582,7 @@ mod tests { assert!( diff.abs() < diff_tolerance, "Shard {}'s stake {} is {} away from average; expected less than {} away", - shard_id, + shard_index, cps.1, diff.abs(), diff_tolerance @@ -724,12 +773,12 @@ mod tests { assert_eq!(assignment, target_assignment); } - fn validator_to_shard(assignment: &[Vec<ValidatorStake>]) -> HashMap<AccountId, usize> { + fn validator_to_shard(assignment: &[Vec<ValidatorStake>]) -> HashMap<AccountId, ShardIndex> { assignment .iter() .enumerate() - .flat_map(|(shard_id, cps)| { - cps.iter().map(move |cp| (cp.account_id().clone(), shard_id)) + .flat_map(|(shard_index, cps)| { + cps.iter().map(move |cp| (cp.account_id().clone(), shard_index)) }) .collect() } diff --git a/chain/epoch-manager/src/shard_tracker.rs b/chain/epoch-manager/src/shard_tracker.rs index 49cc1abf95c..59b2a8f593a 100644 --- a/chain/epoch-manager/src/shard_tracker.rs +++ b/chain/epoch-manager/src/shard_tracker.rs @@ -81,11 +81,13 @@ impl ShardTracker { shard_layout.shard_ids().map(|_| false).collect(); for account_id in tracked_accounts { let shard_id = account_id_to_shard_id(account_id, &shard_layout); - tracking_mask[shard_id as usize] = true; + let shard_index = shard_layout.get_shard_index(shard_id); + tracking_mask[shard_index] = true; } tracking_mask }); - Ok(tracking_mask.get(shard_id as usize).copied().unwrap_or(false)) + let shard_index = shard_layout.get_shard_index(shard_id); + Ok(tracking_mask.get(shard_index).copied().unwrap_or(false)) } TrackedConfig::AllShards => Ok(true), TrackedConfig::Schedule(schedule) => { @@ -209,13 +211,16 @@ mod tests { use crate::shard_tracker::TrackedConfig; use crate::test_utils::hash_range; use crate::{EpochManager, EpochManagerAdapter, EpochManagerHandle, RewardCalculator}; + use itertools::Itertools; use near_crypto::{KeyType, PublicKey}; use near_primitives::epoch_block_info::BlockInfo; use near_primitives::epoch_manager::{AllEpochConfig, EpochConfig}; use near_primitives::hash::CryptoHash; use near_primitives::shard_layout::ShardLayout; use near_primitives::types::validator_stake::ValidatorStake; - use near_primitives::types::{BlockHeight, EpochId, NumShards, ProtocolVersion, ShardId}; + use near_primitives::types::{ + new_shard_id_tmp, BlockHeight, EpochId, NumShards, ProtocolVersion, ShardId, + }; use near_primitives::version::ProtocolFeature::SimpleNightshade; use near_primitives::version::PROTOCOL_VERSION; use near_store::test_utils::create_test_store; @@ -334,7 +339,7 @@ mod tests { #[test] fn test_track_accounts() { - let shard_ids: Vec<_> = (0..4).collect(); + let shard_ids = (0..4).map(new_shard_id_tmp).collect_vec(); let epoch_manager = get_epoch_manager(PROTOCOL_VERSION, shard_ids.len() as NumShards, false); let shard_layout = epoch_manager.read().get_shard_layout(&EpochId::default()).unwrap(); @@ -359,7 +364,7 @@ mod tests { #[test] fn test_track_all_shards() { - let shard_ids: Vec<_> = (0..4).collect(); + let shard_ids = (0..4).map(new_shard_id_tmp).collect_vec(); let epoch_manager = get_epoch_manager(PROTOCOL_VERSION, shard_ids.len() as NumShards, false); let tracker = ShardTracker::new(TrackedConfig::AllShards, Arc::new(epoch_manager)); @@ -378,17 +383,21 @@ mod tests { #[test] fn test_track_schedule() { // Creates a ShardTracker that changes every epoch tracked shards.
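The schedule-based tracker below rotates its tracked set as epochs advance; judging by this test, the subset is selected by epoch height modulo the schedule length. A hedged model of that selection rule (the indexing scheme is inferred from the test, not quoted from the implementation):

    fn tracked(schedule: &[Vec<u64>], epoch_height: u64) -> &[u64] {
        &schedule[epoch_height as usize % schedule.len()]
    }

    fn main() {
        let schedule = vec![vec![0, 1], vec![1, 2], vec![2, 3]];
        assert_eq!(tracked(&schedule, 0).to_vec(), vec![0, 1]);
        assert_eq!(tracked(&schedule, 4).to_vec(), vec![1, 2]); // 4 % 3 == 1
    }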
- let shard_ids: Vec<_> = (0..4).collect(); + let shard_ids = (0..4).map(new_shard_id_tmp).collect_vec(); + let epoch_manager = Arc::new(get_epoch_manager(PROTOCOL_VERSION, shard_ids.len() as NumShards, false)); - let subset1 = HashSet::from([0, 1]); - let subset2 = HashSet::from([1, 2]); - let subset3 = HashSet::from([2, 3]); + let subset1: HashSet<ShardId> = + HashSet::from([0, 1]).into_iter().map(new_shard_id_tmp).collect(); + let subset2: HashSet<ShardId> = + HashSet::from([1, 2]).into_iter().map(new_shard_id_tmp).collect(); + let subset3: HashSet<ShardId> = + HashSet::from([2, 3]).into_iter().map(new_shard_id_tmp).collect(); let tracker = ShardTracker::new( TrackedConfig::Schedule(vec![ subset1.clone().into_iter().collect(), - subset2.clone().into_iter().collect(), - subset3.clone().into_iter().collect(), + subset2.clone().into_iter().map(Into::into).collect(), + subset3.clone().into_iter().map(Into::into).collect(), ]), epoch_manager.clone(), ); diff --git a/chain/epoch-manager/src/tests/mod.rs b/chain/epoch-manager/src/tests/mod.rs index 6bc992de46c..80542627ec5 100644 --- a/chain/epoch-manager/src/tests/mod.rs +++ b/chain/epoch-manager/src/tests/mod.rs @@ -26,6 +26,7 @@ use near_primitives::stateless_validation::partial_witness::PartialEncodedStateW use near_primitives::types::ValidatorKickoutReason::{ NotEnoughBlocks, NotEnoughChunkEndorsements, NotEnoughChunks, }; +use near_primitives::types::{new_shard_id_tmp, ShardIndex}; use near_primitives::validator_signer::ValidatorSigner; use near_primitives::version::ProtocolFeature::{self, SimpleNightshade}; use near_primitives::version::PROTOCOL_VERSION; @@ -882,12 +883,13 @@ fn test_reward_multiple_shards() { for height in 1..(2 * epoch_length) { let i = height as usize; let epoch_id = epoch_manager.get_epoch_id_from_prev_block(&h[i - 1]).unwrap(); + let shard_layout = epoch_manager.get_shard_layout(&epoch_id).unwrap(); // test1 skips its chunks in the first epoch let chunk_mask = (0..num_shards) .map(|shard_index| { - let expected_chunk_producer = epoch_manager - .get_chunk_producer_info(&epoch_id, height, shard_index as u64) - .unwrap(); + let shard_id = shard_layout.get_shard_id(shard_index as ShardIndex); + let expected_chunk_producer = + epoch_manager.get_chunk_producer_info(&epoch_id, height, shard_id).unwrap(); if expected_chunk_producer.account_id() == "test1" && epoch_id == init_epoch_id { expected_chunks += 1; false @@ -1092,11 +1094,17 @@ fn test_expected_chunks_prev_block_not_produced() { let height = i as u64; let epoch_id = epoch_manager.get_epoch_id_from_prev_block(&prev_block).unwrap(); let epoch_info = epoch_manager.get_epoch_info(&epoch_id).unwrap().clone(); + let shard_layout = epoch_manager.get_shard_layout(&epoch_id).unwrap(); let block_producer = EpochManager::block_producer_from_info(&epoch_info, height); let prev_block_info = epoch_manager.get_block_info(&prev_block).unwrap(); let prev_height = prev_block_info.height(); - let expected_chunk_producer = - EpochManager::chunk_producer_from_info(&epoch_info, prev_height + 1, 0).unwrap(); + let expected_chunk_producer = EpochManager::chunk_producer_from_info( + &epoch_info, + &shard_layout, + new_shard_id_tmp(0), + prev_height + 1, + ) + .unwrap(); // test1 does not produce blocks during first epoch if block_producer == 0 && epoch_id == initial_epoch_id { expected += 1; @@ -1491,15 +1499,20 @@ fn test_chunk_producer_kickout() { let height = height as u64; let epoch_id = em.get_epoch_id_from_prev_block(prev_block).unwrap(); let epoch_info = em.get_epoch_info(&epoch_id).unwrap().clone(); let
shard_layout = em.get_shard_layout(&epoch_id).unwrap(); let chunk_mask = (0..4) - .map(|shard_id| { + .map(|shard_index| { if height >= epoch_length { return true; } - - let chunk_producer = - EpochManager::chunk_producer_from_info(&epoch_info, height, shard_id as u64) - .unwrap(); + let shard_id = shard_layout.get_shard_id(shard_index); + let chunk_producer = EpochManager::chunk_producer_from_info( + &epoch_info, + &shard_layout, + shard_id, + height, + ) + .unwrap(); // test1 skips chunks if chunk_producer == 0 { expected += 1; @@ -1636,20 +1649,24 @@ fn test_chunk_validator_kickout_using_endorsement_stats() { for (prev_block, (height, curr_block)) in hashes.iter().zip(hashes.iter().enumerate().skip(1)) { let height = height as u64; let epoch_id = em.get_epoch_id_from_prev_block(prev_block).unwrap(); + let shard_layout = em.get_shard_layout(&epoch_id).unwrap(); // All chunks are produced. let chunk_mask = vec![true; num_shards as usize]; // Prepare the chunk endorsements so that "test2" misses some of the endorsements. let mut bitmap = ChunkEndorsementsBitmap::new(num_shards as usize); - for shard_id in 0..num_shards { + for shard_id in shard_layout.shard_ids() { let chunk_validators = em .get_chunk_validator_assignments(&epoch_id, shard_id, height) .unwrap() .ordered_chunk_validators(); + let shard_index = shard_layout.get_shard_index(shard_id); bitmap.add_endorsements( - shard_id, + shard_index, chunk_validators .iter() - .map(|account| account.as_str() != "test2" || (height + shard_id) % 2 == 0) + .map(|account| { + account.as_str() != "test2" || (height + shard_index as u64) % 2 == 0 + }) .collect(), ) } @@ -2584,13 +2601,13 @@ fn test_validator_kickout_determinism() { (4, ChunkStats::new_with_endorsement(89, 100)), ]); let chunk_stats_tracker1 = HashMap::from([ - (0, chunk_stats0.clone().into_iter().collect()), - (1, chunk_stats1.clone().into_iter().collect()), + (new_shard_id_tmp(0), chunk_stats0.clone().into_iter().collect()), + (new_shard_id_tmp(1), chunk_stats1.clone().into_iter().collect()), ]); let chunk_stats0: Vec<_> = chunk_stats0.into_iter().rev().collect(); let chunk_stats_tracker2 = HashMap::from([ - (0, chunk_stats0.into_iter().collect()), - (1, chunk_stats1.into_iter().collect()), + (new_shard_id_tmp(0), chunk_stats0.into_iter().collect()), + (new_shard_id_tmp(1), chunk_stats1.into_iter().collect()), ]); let (_validator_stats, kickouts1) = EpochManager::compute_validators_to_reward_and_kickout( &epoch_config, @@ -2653,8 +2670,8 @@ fn test_chunk_validators_with_different_endorsement_ratio() { (3, ChunkStats::new_with_endorsement(60, 100)), ]); let chunk_stats_tracker = HashMap::from([ - (0, chunk_stats0.into_iter().collect()), - (1, chunk_stats1.into_iter().collect()), + (new_shard_id_tmp(0), chunk_stats0.into_iter().collect()), + (new_shard_id_tmp(1), chunk_stats1.into_iter().collect()), ]); let (_validator_stats, kickouts) = EpochManager::compute_validators_to_reward_and_kickout( &epoch_config, @@ -2715,8 +2732,8 @@ fn test_chunk_validators_with_same_endorsement_ratio_and_different_stake() { (3, ChunkStats::new_with_endorsement(65, 100)), ]); let chunk_stats_tracker = HashMap::from([ - (0, chunk_stats0.into_iter().collect()), - (1, chunk_stats1.into_iter().collect()), + (new_shard_id_tmp(0), chunk_stats0.into_iter().collect()), + (new_shard_id_tmp(1), chunk_stats1.into_iter().collect()), ]); let (_validator_stats, kickouts) = EpochManager::compute_validators_to_reward_and_kickout( &epoch_config, @@ -2777,8 +2794,8 @@ fn 
test_chunk_validators_with_same_endorsement_ratio_and_stake() { (3, ChunkStats::new_with_endorsement(65, 100)), ]); let chunk_stats_tracker = HashMap::from([ - (0, chunk_stats0.into_iter().collect()), - (1, chunk_stats1.into_iter().collect()), + (new_shard_id_tmp(0), chunk_stats0.into_iter().collect()), + (new_shard_id_tmp(1), chunk_stats1.into_iter().collect()), ]); let (_validator_stats, kickouts) = EpochManager::compute_validators_to_reward_and_kickout( &epoch_config, @@ -2826,7 +2843,7 @@ fn test_validator_kickout_sanity() { ]); let chunk_stats_tracker = HashMap::from([ ( - 0, + new_shard_id_tmp(0), HashMap::from([ (0, ChunkStats::new_with_production(100, 100)), ( @@ -2844,7 +2861,7 @@ fn test_validator_kickout_sanity() { ]), ), ( - 1, + new_shard_id_tmp(1), HashMap::from([ (0, ChunkStats::new_with_production(70, 100)), ( @@ -2964,7 +2981,7 @@ fn test_chunk_endorsement_stats() { ]), &HashMap::from([ ( - 0, + new_shard_id_tmp(0), HashMap::from([ (0, ChunkStats::new(100, 100, 100, 100)), (1, ChunkStats::new(90, 100, 100, 100)), @@ -2973,7 +2990,7 @@ fn test_chunk_endorsement_stats() { ]), ), ( - 1, + new_shard_id_tmp(1), HashMap::from([ (0, ChunkStats::new(95, 100, 100, 100)), (1, ChunkStats::new(95, 100, 90, 100)), @@ -3043,16 +3060,16 @@ fn test_max_kickout_stake_ratio() { // validator 3 doesn't need to produce any block or chunk (3, ValidatorStats { produced: 0, expected: 0 }), ]); - let chunk_stats = HashMap::from([ + let chunk_stats_tracker = HashMap::from([ ( - 0, + new_shard_id_tmp(0), HashMap::from([ (0, ChunkStats::new_with_production(0, 100)), (1, ChunkStats::new_with_production(0, 100)), ]), ), ( - 1, + new_shard_id_tmp(1), HashMap::from([ (2, ChunkStats::new_with_production(100, 100)), (4, ChunkStats::new_with_production(50, 100)), @@ -3065,7 +3082,7 @@ fn test_max_kickout_stake_ratio() { &epoch_config, &epoch_info, &block_stats, - &chunk_stats, + &chunk_stats_tracker, &HashMap::new(), &prev_validator_kickout, ); @@ -3125,7 +3142,7 @@ fn test_max_kickout_stake_ratio() { &epoch_config, &epoch_info, &block_stats, - &chunk_stats, + &chunk_stats_tracker, &HashMap::new(), &prev_validator_kickout, ); @@ -3173,9 +3190,9 @@ fn test_chunk_validator_kickout( (2, ValidatorStats { produced: 90, expected: 100 }), (3, ValidatorStats { produced: 0, expected: 0 }), ]); - let chunk_stats = HashMap::from([ + let chunk_stats_tracker = HashMap::from([ ( - 0, + new_shard_id_tmp(0), HashMap::from([ (0, ChunkStats::new_with_production(90, 100)), (1, ChunkStats::new_with_production(90, 100)), @@ -3185,7 +3202,7 @@ fn test_chunk_validator_kickout( ]), ), ( - 1, + new_shard_id_tmp(1), HashMap::from([ (0, ChunkStats::new_with_production(90, 100)), (2, ChunkStats::new_with_production(90, 100)), @@ -3204,7 +3221,7 @@ fn test_chunk_validator_kickout( &epoch_config, &epoch_info, &block_stats, - &chunk_stats, + &chunk_stats_tracker, &HashMap::new(), &prev_validator_kickout, ); @@ -3251,9 +3268,9 @@ fn test_block_and_chunk_producer_not_kicked_out_for_low_endorsements() { (1, ValidatorStats { produced: 90, expected: 100 }), (2, ValidatorStats { produced: 90, expected: 100 }), ]); - let chunk_stats = HashMap::from([ + let chunk_stats_tracker = HashMap::from([ ( - 0, + new_shard_id_tmp(0), HashMap::from([ (0, ChunkStats::new(90, 100, 10, 100)), (1, ChunkStats::new(90, 100, 10, 100)), @@ -3261,7 +3278,7 @@ fn test_block_and_chunk_producer_not_kicked_out_for_low_endorsements() { ]), ), ( - 1, + new_shard_id_tmp(1), HashMap::from([ (0, ChunkStats::new(90, 100, 10, 100)), (1, ChunkStats::new(90, 100, 10, 100)), @@ 
-3276,7 +3293,7 @@ fn test_block_and_chunk_producer_not_kicked_out_for_low_endorsements() { &epoch_config, &epoch_info, &block_stats, - &chunk_stats, + &chunk_stats_tracker, &HashMap::new(), &HashMap::new(), ); @@ -3295,7 +3312,7 @@ fn test_chunk_header(h: &[CryptoHash], signer: &ValidatorSigner) -> ShardChunkHe h[2], 0, 1, - 0, + new_shard_id_tmp(0), 0, 0, 0, @@ -3329,7 +3346,7 @@ fn test_verify_chunk_endorsements() { // verify if we have one chunk validator let chunk_validator_assignments = - &epoch_manager.get_chunk_validator_assignments(&epoch_id, 0, 1).unwrap(); + &epoch_manager.get_chunk_validator_assignments(&epoch_id, new_shard_id_tmp(0), 1).unwrap(); assert_eq!(chunk_validator_assignments.ordered_chunk_validators().len(), 1); assert!(chunk_validator_assignments.contains(&account_id)); diff --git a/chain/epoch-manager/src/tests/random_epochs.rs b/chain/epoch-manager/src/tests/random_epochs.rs index 8d8bdc4f329..ac0a46b0687 100644 --- a/chain/epoch-manager/src/tests/random_epochs.rs +++ b/chain/epoch-manager/src/tests/random_epochs.rs @@ -325,7 +325,8 @@ fn verify_block_stats( { let aggregator = epoch_manager.get_epoch_info_aggregator_upto_last(&block_hashes[i]).unwrap(); - let epoch_info = epoch_manager.get_epoch_info(block_infos[i].epoch_id()).unwrap(); + let epoch_id = block_infos[i].epoch_id(); + let epoch_info = epoch_manager.get_epoch_info(epoch_id).unwrap(); for key in aggregator.block_tracker.keys().copied() { assert!(key < epoch_info.validators_iter().len() as u64); } @@ -340,7 +341,10 @@ fn verify_block_stats( aggregator.block_tracker.values().map(|value| value.expected).sum::<u64>(); assert_eq!(sum_produced, blocks_in_epoch); assert_eq!(sum_expected, blocks_in_epoch_expected); - for shard_id in 0..(aggregator.shard_tracker.len() as u64) { + // TODO: The following sophisticated check doesn't do anything. The + // shard tracker is empty because the chunk mask in all block infos + // is empty. + for &shard_id in aggregator.shard_tracker.keys() { let sum_produced = aggregator .shard_tracker .get(&shard_id) diff --git a/chain/epoch-manager/src/types.rs b/chain/epoch-manager/src/types.rs index 22cda704587..aebac696fb7 100644 --- a/chain/epoch-manager/src/types.rs +++ b/chain/epoch-manager/src/types.rs @@ -3,6 +3,7 @@ use itertools::Itertools; use near_primitives::epoch_block_info::BlockInfo; use near_primitives::epoch_info::EpochInfo; use near_primitives::hash::CryptoHash; +use near_primitives::shard_layout::ShardLayout; use near_primitives::types::validator_stake::ValidatorStake; use near_primitives::types::{ AccountId, BlockHeight, ChunkStats, EpochId, ShardId, ValidatorId, ValidatorStats, }; @@ -69,6 +70,7 @@ impl EpochInfoAggregator { &mut self, block_info: &BlockInfo, epoch_info: &EpochInfo, + shard_layout: &ShardLayout, prev_block_height: BlockHeight, ) { let _span = @@ -105,12 +107,13 @@ impl EpochInfoAggregator { // TODO(#11900): Call EpochManager::get_chunk_validator_assignments to access the cached validator assignments.
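The aggregator below checks endorsement bitmaps whose length is `chunk_validators.len().div_ceil(8) * 8`: one bit per chunk validator, padded up to whole bytes. The padding rule in isolation:

    fn expected_bitmap_len(num_validators: usize) -> usize {
        num_validators.div_ceil(8) * 8
    }

    fn main() {
        assert_eq!(expected_bitmap_len(10), 16); // 10 bits round up to 2 bytes
        assert_eq!(expected_bitmap_len(8), 8); // already byte-aligned
    }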
let chunk_validator_assignment = epoch_info.sample_chunk_validators(prev_block_height + 1); - for (i, mask) in block_info.chunk_mask().iter().enumerate() { - let shard_id: ShardId = i as ShardId; + for (shard_index, mask) in block_info.chunk_mask().iter().enumerate() { + let shard_id = shard_layout.get_shard_id(shard_index); let chunk_producer_id = EpochManager::chunk_producer_from_info( epoch_info, + shard_layout, + shard_id, prev_block_height + 1, - i as ShardId, ) .unwrap(); let tracker = self.shard_tracker.entry(shard_id).or_insert_with(HashMap::new); @@ -123,7 +126,7 @@ impl EpochInfoAggregator { debug!( target: "epoch_tracker", chunk_validator = ?epoch_info.validator_account_id(chunk_producer_id), - shard_id = i, + ?shard_id, block_height = prev_block_height + 1, "Missed chunk"); } @@ -132,7 +135,7 @@ impl EpochInfoAggregator { .or_insert_with(|| ChunkStats::new_with_production(u64::from(*mask), 1)); let chunk_validators = chunk_validator_assignment - .get(i) + .get(shard_index) .map_or::<&[(u64, u128)], _>(&[], Vec::as_slice) .iter() .map(|(id, _)| *id) @@ -148,14 +151,14 @@ impl EpochInfoAggregator { // For old chunks, we optimize the block and its header by not including the chunk endorsements and // corresponding bitmaps. Thus, we expect that the bitmap is non-empty for new chunks only. if *mask { - debug_assert!(chunk_endorsements.len(shard_id).unwrap() == chunk_validators.len().div_ceil(8) * 8, - "Chunk endorsement bitmap length is inconsistent with number of chunk validators. Bitmap length={}, num validators={}, shard_id={}", - chunk_endorsements.len(shard_id).unwrap(), chunk_validators.len(), shard_id); - chunk_endorsements.iter(shard_id) + debug_assert!(chunk_endorsements.len(shard_index).unwrap() == chunk_validators.len().div_ceil(8) * 8, + "Chunk endorsement bitmap length is inconsistent with number of chunk validators. Bitmap length={}, num validators={}, shard_index={}", + chunk_endorsements.len(shard_index).unwrap(), chunk_validators.len(), shard_index); + chunk_endorsements.iter(shard_index) } else { - debug_assert_eq!(chunk_endorsements.len(shard_id).unwrap(), 0, - "Chunk endorsement bitmap must be empty for missing chunk. Bitmap length={}, shard_id={}", - chunk_endorsements.len(shard_id).unwrap(), shard_id); + debug_assert_eq!(chunk_endorsements.len(shard_index).unwrap(), 0, + "Chunk endorsement bitmap must be empty for missing chunk. Bitmap length={}, shard_index={}", + chunk_endorsements.len(shard_index).unwrap(), shard_index); Box::new(std::iter::repeat(false).take(chunk_validators.len())) } } else { diff --git a/chain/epoch-manager/src/validator_selection.rs b/chain/epoch-manager/src/validator_selection.rs index dc9716ff71b..0e58c0221b9 100644 --- a/chain/epoch-manager/src/validator_selection.rs +++ b/chain/epoch-manager/src/validator_selection.rs @@ -591,12 +591,12 @@ mod old_validator_selection { all_validators.push(bp.clone()); } - let shard_ids: Vec<_> = epoch_config.shard_layout.shard_ids().collect(); + let num_shards = epoch_config.shard_layout.shard_ids().count(); if chunk_producers.is_empty() { // All validators tried to unstake? return Err(EpochError::NotEnoughValidators { num_validators: 0u64, - num_shards: shard_ids.len() as NumShards, + num_shards: num_shards as u64, }); } @@ -605,11 +605,9 @@ mod old_validator_selection { // each validator as even as possible). 
Note that in prod configuration number of seats // per shard is the same as maximal number of block producers, so normally all // validators would be assigned to all chunks - let chunk_producers_settlement = shard_ids - .iter() - .map(|&shard_id| shard_id as usize) - .map(|shard_id| { - (0..epoch_config.num_block_producer_seats_per_shard[shard_id] + let chunk_producers_settlement = (0..num_shards) + .map(|shard_index| { + (0..epoch_config.num_block_producer_seats_per_shard[shard_index] .min(block_producers_settlement.len() as u64)) .map(|_| { let res = block_producers_settlement[id]; @@ -637,6 +635,7 @@ mod tests { use near_primitives::epoch_manager::ValidatorSelectionConfig; use near_primitives::shard_layout::ShardLayout; use near_primitives::types::validator_stake::ValidatorStake; + use near_primitives::types::ShardIndex; use near_primitives::version::PROTOCOL_VERSION; use num_rational::Ratio; @@ -964,12 +963,15 @@ mod tests { ) .unwrap(); - for shard_id in 0..num_shards { + let shard_layout = &epoch_config.shard_layout; + for shard_index in 0..num_shards { + let shard_index = shard_index as ShardIndex; + let shard_id = shard_layout.get_shard_id(shard_index); for h in 0..100_000 { - let cp = epoch_info.sample_chunk_producer(h, shard_id); + let cp = epoch_info.sample_chunk_producer(shard_layout, shard_id, h); // Don't read too much into this. The reason the ValidatorId always // equals the ShardId is because the validators are assigned to shards in order. - assert_eq!(cp, Some(shard_id)) + assert_eq!(cp, Some(shard_index as u64)) } } @@ -992,10 +994,10 @@ mod tests { ) .unwrap(); - for shard_id in 0..num_shards { + for shard_id in shard_layout.shard_ids() { let mut counts: [i32; 2] = [0, 0]; for h in 0..100_000 { - let cp = epoch_info.sample_chunk_producer(h, shard_id).unwrap(); + let cp = epoch_info.sample_chunk_producer(shard_layout, shard_id, h).unwrap(); // if ValidatorId is in the second half then it is the lower // stake validator (because they are sorted by decreasing stake). 
let index = if cp >= num_shards { 1 } else { 0 }; diff --git a/chain/indexer/src/streamer/mod.rs b/chain/indexer/src/streamer/mod.rs index 3e2c2dbb0cc..45b01ccd042 100644 --- a/chain/indexer/src/streamer/mod.rs +++ b/chain/indexer/src/streamer/mod.rs @@ -79,8 +79,7 @@ pub async fn build_streamer_message( let chunks = fetch_block_chunks(&client, &block).await?; let protocol_config_view = fetch_protocol_config(&client, block.header.hash).await?; - let num_shards = protocol_config_view.num_block_producer_seats_per_shard.len() - as near_primitives::types::NumShards; + let shard_ids = protocol_config_view.shard_layout.shard_ids(); let runtime_config_store = near_parameters::RuntimeConfigStore::new(None); let runtime_config = runtime_config_store.get_config(protocol_config_view.protocol_version); @@ -92,7 +91,7 @@ pub async fn build_streamer_message( near_primitives::types::EpochId(block.header.epoch_id), ) .await?; - let mut indexer_shards = (0..num_shards) + let mut indexer_shards = shard_ids .map(|shard_id| IndexerShard { shard_id, chunk: None, @@ -101,12 +100,10 @@ pub async fn build_streamer_message( }) .collect::>(); - for chunk in chunks { + for (shard_index, chunk) in chunks.into_iter().enumerate() { let views::ChunkView { transactions, author, header, receipts: chunk_non_local_receipts } = chunk; - let shard_id = header.shard_id as usize; - let mut outcomes = shards_outcomes .remove(&header.shard_id) .expect("Execution outcomes for given shard should be present"); @@ -236,9 +233,9 @@ pub async fn build_streamer_message( chunk_receipts.extend(chunk_non_local_receipts); - indexer_shards[shard_id].receipt_execution_outcomes = receipt_execution_outcomes; + indexer_shards[shard_index].receipt_execution_outcomes = receipt_execution_outcomes; // Put the chunk into corresponding indexer shard - indexer_shards[shard_id].chunk = Some(IndexerChunkView { + indexer_shards[shard_index].chunk = Some(IndexerChunkView { author, header, transactions: indexer_transactions, @@ -250,12 +247,13 @@ pub async fn build_streamer_message( // chunks and we end up with non-empty `shards_outcomes` we want to be sure we put them into IndexerShard // That might happen before the fix https://github.com/near/nearcore/pull/4228 for (shard_id, outcomes) in shards_outcomes { - indexer_shards[shard_id as usize].receipt_execution_outcomes.extend( - outcomes.into_iter().map(|outcome| IndexerExecutionOutcomeWithReceipt { + let shard_index = protocol_config_view.shard_layout.get_shard_index(shard_id); + indexer_shards[shard_index].receipt_execution_outcomes.extend(outcomes.into_iter().map( + |outcome| IndexerExecutionOutcomeWithReceipt { execution_outcome: outcome.execution_outcome, receipt: outcome.receipt.expect("`receipt` must be present at this moment"), - }), - ) + }, + )) } Ok(StreamerMessage { block, shards: indexer_shards }) diff --git a/chain/jsonrpc-primitives/src/types/chunks.rs b/chain/jsonrpc-primitives/src/types/chunks.rs index d571e5ea32f..de5ebcff138 100644 --- a/chain/jsonrpc-primitives/src/types/chunks.rs +++ b/chain/jsonrpc-primitives/src/types/chunks.rs @@ -34,6 +34,7 @@ pub enum RpcChunkError { #[serde(skip_serializing)] error_message: String, }, + // TODO Should use ShardId instead of u64 #[error("Shard id {shard_id} does not exist")] InvalidShardId { shard_id: u64 }, #[error("Chunk with hash {chunk_hash:?} has never been observed on this node")] diff --git a/chain/jsonrpc/jsonrpc-tests/tests/rpc_query.rs b/chain/jsonrpc/jsonrpc-tests/tests/rpc_query.rs index dc34436601a..727febd5987 100644 --- 
a/chain/jsonrpc/jsonrpc-tests/tests/rpc_query.rs +++ b/chain/jsonrpc/jsonrpc-tests/tests/rpc_query.rs @@ -15,7 +15,7 @@ use near_network::test_utils::wait_or_timeout; use near_o11y::testonly::init_test_logger; use near_primitives::account::{AccessKey, AccessKeyPermission}; use near_primitives::hash::CryptoHash; -use near_primitives::types::{BlockId, BlockReference, EpochId, SyncCheckpoint}; +use near_primitives::types::{new_shard_id_tmp, BlockId, BlockReference, EpochId, SyncCheckpoint}; use near_primitives::views::QueryRequest; use near_time::Clock; @@ -90,7 +90,10 @@ fn test_block_query() { #[test] fn test_chunk_by_hash() { test_with_client!(test_utils::NodeType::NonValidator, client, async move { - let chunk = client.chunk(ChunkId::BlockShardId(BlockId::Height(0), 0u64)).await.unwrap(); + let chunk = client + .chunk(ChunkId::BlockShardId(BlockId::Height(0), new_shard_id_tmp(0))) + .await + .unwrap(); assert_eq!(chunk.author, "test1"); assert_eq!(chunk.header.balance_burnt, 0); assert_eq!(chunk.header.chunk_hash.as_ref().len(), 32); @@ -104,7 +107,7 @@ fn test_chunk_by_hash() { assert_eq!(chunk.header.prev_block_hash.as_ref().len(), 32); assert_eq!(chunk.header.prev_state_root.as_ref().len(), 32); assert_eq!(chunk.header.rent_paid, 0); - assert_eq!(chunk.header.shard_id, 0); + assert_eq!(chunk.header.shard_id, new_shard_id_tmp(0)); assert!(if let Signature::ED25519(_) = chunk.header.signature { true } else { false }); assert_eq!(chunk.header.tx_root.as_ref(), &[0; 32]); assert_eq!(chunk.header.validator_proposals, vec![]); @@ -118,7 +121,8 @@ fn test_chunk_by_hash() { #[test] fn test_chunk_invalid_shard_id() { test_with_client!(test_utils::NodeType::NonValidator, client, async move { - let chunk = client.chunk(ChunkId::BlockShardId(BlockId::Height(0), 100)).await; + let chunk = + client.chunk(ChunkId::BlockShardId(BlockId::Height(0), new_shard_id_tmp(100))).await; match chunk { Ok(_) => panic!("should result in an error"), Err(e) => { @@ -649,7 +653,7 @@ fn test_get_chunk_with_object_in_params() { assert_eq!(chunk.header.prev_block_hash.as_ref().len(), 32); assert_eq!(chunk.header.prev_state_root.as_ref().len(), 32); assert_eq!(chunk.header.rent_paid, 0); - assert_eq!(chunk.header.shard_id, 0); + assert_eq!(chunk.header.shard_id, new_shard_id_tmp(0)); assert!(if let Signature::ED25519(_) = chunk.header.signature { true } else { false }); assert_eq!(chunk.header.tx_root.as_ref(), &[0; 32]); assert_eq!(chunk.header.validator_proposals, vec![]); diff --git a/chain/jsonrpc/src/api/chunks.rs b/chain/jsonrpc/src/api/chunks.rs index badf0bf4654..b4034c01ba1 100644 --- a/chain/jsonrpc/src/api/chunks.rs +++ b/chain/jsonrpc/src/api/chunks.rs @@ -57,7 +57,9 @@ impl RpcFrom for RpcChunkError { match error { GetChunkError::IOError { error_message } => Self::InternalError { error_message }, GetChunkError::UnknownBlock { error_message } => Self::UnknownBlock { error_message }, - GetChunkError::InvalidShardId { shard_id } => Self::InvalidShardId { shard_id }, + GetChunkError::InvalidShardId { shard_id } => { + Self::InvalidShardId { shard_id: shard_id.into() } + } GetChunkError::UnknownChunk { chunk_hash } => Self::UnknownChunk { chunk_hash }, GetChunkError::Unreachable { ref error_message } => { tracing::warn!(target: "jsonrpc", "Unreachable error occurred: {}", error_message); diff --git a/chain/network/src/network_protocol/proto_conv/handshake.rs b/chain/network/src/network_protocol/proto_conv/handshake.rs index 11003597f14..7dc67077de8 100644 --- 
a/chain/network/src/network_protocol/proto_conv/handshake.rs +++ b/chain/network/src/network_protocol/proto_conv/handshake.rs @@ -42,7 +42,7 @@ impl From<&PeerChainInfoV2> for proto::PeerChainInfo { Self { genesis_id: MF::some((&x.genesis_id).into()), height: x.height, - tracked_shards: x.tracked_shards.clone(), + tracked_shards: x.tracked_shards.clone().into_iter().map(Into::into).collect(), archival: x.archival, ..Self::default() } @@ -55,7 +55,7 @@ impl TryFrom<&proto::PeerChainInfo> for PeerChainInfoV2 { Ok(Self { genesis_id: try_from_required(&p.genesis_id).map_err(Self::Error::GenesisId)?, height: p.height, - tracked_shards: p.tracked_shards.clone(), + tracked_shards: p.tracked_shards.clone().into_iter().map(Into::into).collect(), archival: p.archival, }) } diff --git a/chain/network/src/network_protocol/proto_conv/peer_message.rs b/chain/network/src/network_protocol/proto_conv/peer_message.rs index b73a66d7966..78e732c4a63 100644 --- a/chain/network/src/network_protocol/proto_conv/peer_message.rs +++ b/chain/network/src/network_protocol/proto_conv/peer_message.rs @@ -179,7 +179,7 @@ impl From<&SnapshotHostInfo> for proto::SnapshotHostInfo { peer_id: MF::some((&x.peer_id).into()), sync_hash: MF::some((&x.sync_hash).into()), epoch_height: x.epoch_height, - shards: x.shards.clone(), + shards: x.shards.clone().into_iter().map(Into::into).collect(), signature: MF::some((&x.signature).into()), ..Default::default() } @@ -193,7 +193,7 @@ impl TryFrom<&proto::SnapshotHostInfo> for SnapshotHostInfo { peer_id: try_from_required(&x.peer_id).map_err(Self::Error::PeerId)?, sync_hash: try_from_required(&x.sync_hash).map_err(Self::Error::SyncHash)?, epoch_height: x.epoch_height, - shards: x.shards.clone(), + shards: x.shards.clone().into_iter().map(Into::into).collect(), signature: try_from_required(&x.signature).map_err(Self::Error::Signature)?, }) } @@ -313,14 +313,14 @@ impl From<&PeerMessage> for proto::PeerMessage { PeerMessage::SyncSnapshotHosts(ssh) => ProtoMT::SyncSnapshotHosts(ssh.into()), PeerMessage::StateRequestHeader(shard_id, sync_hash) => { ProtoMT::StateRequestHeader(proto::StateRequestHeader { - shard_id: *shard_id, + shard_id: (*shard_id).into(), sync_hash: MF::some(sync_hash.into()), ..Default::default() }) } PeerMessage::StateRequestPart(shard_id, sync_hash, part_id) => { ProtoMT::StateRequestPart(proto::StateRequestPart { - shard_id: *shard_id, + shard_id: (*shard_id).into(), sync_hash: MF::some(sync_hash.into()), part_id: *part_id, ..Default::default() @@ -477,11 +477,11 @@ impl TryFrom<&proto::PeerMessage> for PeerMessage { Challenge::try_from_slice(&c.borsh).map_err(Self::Error::Challenge)?, ), ProtoMT::StateRequestHeader(srh) => PeerMessage::StateRequestHeader( - srh.shard_id, + srh.shard_id.into(), try_from_required(&srh.sync_hash).map_err(Self::Error::BlockRequest)?, ), ProtoMT::StateRequestPart(srp) => PeerMessage::StateRequestPart( - srp.shard_id, + srp.shard_id.into(), try_from_required(&srp.sync_hash).map_err(Self::Error::BlockRequest)?, srp.part_id, ), diff --git a/chain/network/src/network_protocol/testonly.rs b/chain/network/src/network_protocol/testonly.rs index 7a2762c61b6..6a34083e3ef 100644 --- a/chain/network/src/network_protocol/testonly.rs +++ b/chain/network/src/network_protocol/testonly.rs @@ -18,7 +18,7 @@ use near_primitives::sharding::{ ChunkHash, EncodedShardChunkBody, PartialEncodedChunkPart, ShardChunk, }; use near_primitives::transaction::SignedTransaction; -use near_primitives::types::{AccountId, BlockHeight, EpochId, StateRoot}; +use 
near_primitives::types::{new_shard_id_tmp, AccountId, BlockHeight, EpochId, StateRoot}; use near_primitives::validator_signer::{InMemoryValidatorSigner, ValidatorSigner}; use near_primitives::version; use rand::distributions::Standard; @@ -211,7 +211,7 @@ impl ChunkSet { Self { chunks: HashMap::default() } } pub fn make(&mut self) -> Vec { - let shard_ids: Vec<_> = (0..4).collect(); + let shard_ids: Vec<_> = (0..4).into_iter().map(new_shard_id_tmp).collect(); // TODO: these are always genesis chunks. // Consider making this more realistic. let chunks = genesis_chunks( diff --git a/chain/network/src/peer_manager/peer_manager_actor.rs b/chain/network/src/peer_manager/peer_manager_actor.rs index 9bfbf775f1c..b4739b85fa1 100644 --- a/chain/network/src/peer_manager/peer_manager_actor.rs +++ b/chain/network/src/peer_manager/peer_manager_actor.rs @@ -1373,7 +1373,7 @@ impl actix::Handler for PeerManagerActor { peer_id: h.peer_id.clone(), sync_hash: h.sync_hash, epoch_height: h.epoch_height, - shards: h.shards.clone(), + shards: h.shards.clone().into_iter().map(Into::into).collect(), }) .collect::>(), }), diff --git a/chain/network/src/peer_manager/tests/snapshot_hosts.rs b/chain/network/src/peer_manager/tests/snapshot_hosts.rs index 1ea2afbfcba..09c56a18816 100644 --- a/chain/network/src/peer_manager/tests/snapshot_hosts.rs +++ b/chain/network/src/peer_manager/tests/snapshot_hosts.rs @@ -10,12 +10,14 @@ use crate::types::NetworkRequests; use crate::types::PeerManagerMessageRequest; use crate::types::PeerMessage; use crate::{network_protocol::testonly as data, peer::testonly::PeerHandle}; +use itertools::Itertools; use near_async::time; use near_crypto::SecretKey; use near_o11y::testonly::init_test_logger; use near_o11y::WithSpanContextExt; use near_primitives::hash::CryptoHash; use near_primitives::network::PeerId; +use near_primitives::types::new_shard_id_tmp; use near_primitives::types::EpochHeight; use near_primitives::types::ShardId; use peer_manager::testonly::FDS_PER_PEER; @@ -32,10 +34,10 @@ fn make_snapshot_host_info( rng: &mut impl Rng, ) -> Arc { let epoch_height: EpochHeight = rng.gen::(); - let max_shard_id: ShardId = 32; + let max_shard_id = 32; let shards_num: usize = rng.gen_range(1..16); - let mut shards: Vec = (0..max_shard_id).choose_multiple(rng, shards_num); - shards.sort(); + let shards = (0..max_shard_id).choose_multiple(rng, shards_num); + let shards = shards.into_iter().sorted().map(new_shard_id_tmp).collect(); let sync_hash = CryptoHash::hash_borsh(epoch_height); Arc::new(SnapshotHostInfo::new(peer_id.clone(), sync_hash, epoch_height, shards, secret_key)) } @@ -251,7 +253,7 @@ async fn too_many_shards_not_broadcast() { tracing::info!(target:"test", "Send an invalid SyncSnapshotHosts message from peer1. 
One of the host infos has more shard ids than allowed."); let too_many_shards: Vec = - (0..(MAX_SHARDS_PER_SNAPSHOT_HOST_INFO as u64 + 1)).collect(); + (0..(MAX_SHARDS_PER_SNAPSHOT_HOST_INFO as u64 + 1)).map(Into::into).collect(); let invalid_info = Arc::new(SnapshotHostInfo::new( peer1_config.node_id(), CryptoHash::hash_borsh(rng.gen::()), @@ -369,11 +371,12 @@ async fn large_shard_id_in_cache() { let peer1 = pm.start_inbound(chain.clone(), peer1_config.clone()).await.handshake(clock).await; tracing::info!(target:"test", "Send a SnapshotHostInfo message with very large shard ids."); + let max_shard_id: ShardId = ShardId::MAX; let big_shard_info = Arc::new(SnapshotHostInfo::new( peer1_config.node_id(), CryptoHash::hash_borsh(1234_u64), 1234, - vec![0, 1232232, ShardId::MAX - 1, ShardId::MAX], + vec![0, 1232232, max_shard_id - 1, max_shard_id].into_iter().map(Into::into).collect(), &peer1_config.node_key, )); @@ -419,7 +422,7 @@ async fn too_many_shards_truncate() { tracing::info!(target:"test", "Ask peer manager to send out an invalid SyncSnapshotHosts message. The info has more shard ids than allowed."); // Create a list of shards with twice as many shard ids as is allowed let too_many_shards: Vec = - (0..(2 * MAX_SHARDS_PER_SNAPSHOT_HOST_INFO as u64)).collect(); + (0..(2 * MAX_SHARDS_PER_SNAPSHOT_HOST_INFO as u64)).map(Into::into).collect(); let sync_hash = CryptoHash::hash_borsh(rng.gen::()); let epoch_height: EpochHeight = rng.gen(); @@ -442,9 +445,9 @@ async fn too_many_shards_truncate() { // The list of shards should contain MAX_SHARDS_PER_SNAPSHOT_HOST_INFO randomly sampled, unique shard ids taken from too_many_shards assert_eq!(info.shards.len(), MAX_SHARDS_PER_SNAPSHOT_HOST_INFO); - for shard_id in &info.shards { + for &shard_id in &info.shards { // Shard ids are taken from the original vector - assert!(*shard_id < 2 * MAX_SHARDS_PER_SNAPSHOT_HOST_INFO as u64); + assert!(shard_id < 2 * MAX_SHARDS_PER_SNAPSHOT_HOST_INFO as u64); } // The shard_ids are sorted and unique (no two elements are equal, hence the < condition instead of <=) assert!(info.shards.windows(2).all(|twoelems| twoelems[0] < twoelems[1])); diff --git a/chain/network/src/raw/tests.rs b/chain/network/src/raw/tests.rs index f7601b06e5a..49b5adb00df 100644 --- a/chain/network/src/raw/tests.rs +++ b/chain/network/src/raw/tests.rs @@ -8,6 +8,7 @@ use near_crypto::{KeyType, SecretKey}; use near_o11y::testonly::init_test_logger; use near_primitives::hash::CryptoHash; use near_primitives::network::PeerId; +use near_primitives::types::new_shard_id_tmp; use std::sync::Arc; #[tokio::test] @@ -38,7 +39,7 @@ async fn test_raw_conn_pings() { &genesis_id.chain_id, genesis_id.hash, 0, - vec![0], + vec![new_shard_id_tmp(0)], time::Duration::SECOND, ) .await @@ -99,7 +100,7 @@ async fn test_raw_conn_state_parts() { &genesis_id.chain_id, genesis_id.hash, 0, - vec![0], + vec![new_shard_id_tmp(0)], time::Duration::SECOND, ) .await @@ -110,9 +111,13 @@ async fn test_raw_conn_state_parts() { // But the fake node simply ignores the block hash. 
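// Two styles of building ShardId values recur in these tests while ShardId
// is still a u64 alias; both are equivalent (sketch, illustrative only):
//
//     let one: ShardId = new_shard_id_tmp(0);
//     let many: Vec<ShardId> = (0..4u64).map(Into::into).collect();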
let block_hash = CryptoHash::new(); for part_id in 0..num_parts { - conn.send_message(raw::DirectMessage::StateRequestPart(0, block_hash, part_id)) - .await - .unwrap(); + conn.send_message(raw::DirectMessage::StateRequestPart( + new_shard_id_tmp(0), + block_hash, + part_id, + )) + .await + .unwrap(); } let mut part_id_received = -1i64; @@ -174,7 +179,7 @@ async fn test_listener() { &genesis_id.chain_id, genesis_id.hash, 0, - vec![0], + vec![new_shard_id_tmp(0)], false, time::Duration::SECOND, ) diff --git a/chain/network/src/snapshot_hosts/tests.rs b/chain/network/src/snapshot_hosts/tests.rs index 79b010938a5..0ab4fda763d 100644 --- a/chain/network/src/snapshot_hosts/tests.rs +++ b/chain/network/src/snapshot_hosts/tests.rs @@ -5,12 +5,13 @@ use crate::snapshot_hosts::{priority_score, Config, SnapshotHostInfoError, Snaps use crate::testonly::assert_is_superset; use crate::testonly::{make_rng, AsSet as _}; use crate::types::SnapshotHostInfo; +use itertools::Itertools; use near_crypto::SecretKey; use near_o11y::testonly::init_test_logger; use near_primitives::hash::CryptoHash; use near_primitives::network::PeerId; -use near_primitives::types::EpochHeight; use near_primitives::types::ShardId; +use near_primitives::types::{new_shard_id_tmp, EpochHeight}; use rand::Rng; use std::collections::HashSet; use std::sync::Arc; @@ -52,17 +53,19 @@ async fn happy_path() { let cache = SnapshotHostsCache::new(config); assert_eq!(cache.get_hosts().len(), 0); // initially empty + let sid_vec = |v: &[u64]| v.iter().cloned().map(Into::into).collect_vec(); + // initial insert - let info0 = Arc::new(make_snapshot_host_info(&peer0, 123, vec![0, 1, 2, 3], &key0)); - let info1 = Arc::new(make_snapshot_host_info(&peer1, 123, vec![2], &key1)); + let info0 = Arc::new(make_snapshot_host_info(&peer0, 123, sid_vec(&[0, 1, 2, 3]), &key0)); + let info1 = Arc::new(make_snapshot_host_info(&peer1, 123, sid_vec(&[2]), &key1)); let res = cache.insert(vec![info0.clone(), info1.clone()]).await; assert_eq!([&info0, &info1].as_set(), unwrap(&res).as_set()); assert_eq!([&info0, &info1].as_set(), cache.get_hosts().iter().collect::>()); // second insert with various types of updates - let info0new = Arc::new(make_snapshot_host_info(&peer0, 124, vec![1, 3], &key0)); - let info1old = Arc::new(make_snapshot_host_info(&peer1, 122, vec![0, 1, 2, 3], &key1)); - let info2 = Arc::new(make_snapshot_host_info(&peer2, 123, vec![2], &key2)); + let info0new = Arc::new(make_snapshot_host_info(&peer0, 124, sid_vec(&[1, 3]), &key0)); + let info1old = Arc::new(make_snapshot_host_info(&peer1, 122, sid_vec(&[0, 1, 2, 3]), &key1)); + let info2 = Arc::new(make_snapshot_host_info(&peer2, 123, sid_vec(&[2]), &key2)); let res = cache.insert(vec![info0new.clone(), info1old.clone(), info2.clone()]).await; assert_eq!([&info0new, &info2].as_set(), unwrap(&res).as_set()); assert_eq!( @@ -86,8 +89,11 @@ async fn invalid_signature() { let config = Config { snapshot_hosts_cache_size: 100, part_selection_cache_batch_size: 1 }; let cache = SnapshotHostsCache::new(config); - let info0_invalid_sig = Arc::new(make_snapshot_host_info(&peer0, 1, vec![0, 1, 2, 3], &key1)); - let info1 = Arc::new(make_snapshot_host_info(&peer1, 1, vec![0, 1, 2, 3], &key1)); + let sid_vec = |v: &[u64]| v.iter().cloned().map(Into::into).collect_vec(); + + let shards = sid_vec(&[0, 1, 2, 3]); + let info0_invalid_sig = Arc::new(make_snapshot_host_info(&peer0, 1, shards.clone(), &key1)); + let info1 = Arc::new(make_snapshot_host_info(&peer1, 1, shards, &key1)); let res = 
cache.insert(vec![info0_invalid_sig.clone(), info1.clone()]).await; // invalid signature => InvalidSignature assert_eq!( @@ -119,12 +125,14 @@ async fn too_many_shards() { let config = Config { snapshot_hosts_cache_size: 100, part_selection_cache_batch_size: 1 }; let cache = SnapshotHostsCache::new(config); + let sid_vec = |v: &[u64]| v.iter().cloned().map(Into::into).collect_vec(); + // info0 is valid - let info0 = Arc::new(make_snapshot_host_info(&peer0, 1, vec![0, 1, 2, 3], &key0)); + let info0 = Arc::new(make_snapshot_host_info(&peer0, 1, sid_vec(&[0, 1, 2, 3]), &key0)); // info1 is invalid - it has more shard ids than MAX_SHARDS_PER_SNAPSHOT_HOST_INFO let too_many_shards: Vec = - (0..(MAX_SHARDS_PER_SNAPSHOT_HOST_INFO as u64 + 1)).collect(); + (0..(MAX_SHARDS_PER_SNAPSHOT_HOST_INFO as u64 + 1)).into_iter().map(Into::into).collect(); let info1 = Arc::new(make_snapshot_host_info(&peer1, 1, too_many_shards, &key1)); // info1.verify() should fail @@ -155,8 +163,10 @@ async fn duplicate_peer_id() { let config = Config { snapshot_hosts_cache_size: 100, part_selection_cache_batch_size: 1 }; let cache = SnapshotHostsCache::new(config); - let info00 = Arc::new(make_snapshot_host_info(&peer0, 1, vec![0, 1, 2, 3], &key0)); - let info01 = Arc::new(make_snapshot_host_info(&peer0, 2, vec![0, 3], &key0)); + let sid_vec = |v: &[u64]| v.iter().cloned().map(Into::into).collect_vec(); + + let info00 = Arc::new(make_snapshot_host_info(&peer0, 1, sid_vec(&[0, 1, 2, 3]), &key0)); + let info01 = Arc::new(make_snapshot_host_info(&peer0, 2, sid_vec(&[0, 3]), &key0)); let res = cache.insert(vec![info00.clone(), info01.clone()]).await; // duplicate peer ids => DuplicatePeerId assert_eq!(Some(SnapshotHostInfoError::DuplicatePeerId), res.1); @@ -182,19 +192,21 @@ async fn test_lru_eviction() { let config = Config { snapshot_hosts_cache_size: 2, part_selection_cache_batch_size: 1 }; let cache = SnapshotHostsCache::new(config); + let sid_vec = |v: &[u64]| v.iter().cloned().map(Into::into).collect_vec(); + // initial inserts to capacity - let info0 = Arc::new(make_snapshot_host_info(&peer0, 123, vec![0, 1, 2, 3], &key0)); + let info0 = Arc::new(make_snapshot_host_info(&peer0, 123, sid_vec(&[0, 1, 2, 3]), &key0)); let res = cache.insert(vec![info0.clone()]).await; assert_eq!([&info0].as_set(), unwrap(&res).as_set()); assert_eq!([&info0].as_set(), cache.get_hosts().iter().collect::>()); - let info1 = Arc::new(make_snapshot_host_info(&peer1, 123, vec![2], &key1)); + let info1 = Arc::new(make_snapshot_host_info(&peer1, 123, sid_vec(&[2]), &key1)); let res = cache.insert(vec![info1.clone()]).await; assert_eq!([&info1].as_set(), unwrap(&res).as_set()); assert_eq!([&info0, &info1].as_set(), cache.get_hosts().iter().collect::>()); // insert past capacity - let info2 = Arc::new(make_snapshot_host_info(&peer2, 123, vec![1, 3], &key2)); + let info2 = Arc::new(make_snapshot_host_info(&peer2, 123, sid_vec(&[1, 3]), &key2)); let res = cache.insert(vec![info2.clone()]).await; // check that the new data is accepted assert_eq!([&info2].as_set(), unwrap(&res).as_set()); @@ -318,7 +330,7 @@ async fn run_select_peer_test( assert!(err.is_none()); } SelectPeerAction::CallSelect(wanted) => { - let peer = cache.select_host_for_part(sync_hash, 0, part_id); + let peer = cache.select_host_for_part(sync_hash, new_shard_id_tmp(0), part_id); let wanted = match wanted { Some(idx) => Some(&peers[*idx].peer_id), None => None, @@ -326,9 +338,10 @@ async fn run_select_peer_test( assert!(peer.as_ref() == wanted, "got: {:?} want: {:?}", &peer, 
&wanted);
                }
                SelectPeerAction::PartReceived => {
-                    assert!(cache.has_selector(0, part_id));
-                    cache.part_received(0, part_id);
-                    assert!(!cache.has_selector(0, part_id));
+                    let shard_id = new_shard_id_tmp(0);
+                    assert!(cache.has_selector(shard_id, part_id));
+                    cache.part_received(shard_id, part_id);
+                    assert!(!cache.has_selector(shard_id, part_id));
                }
            }
        }
@@ -342,11 +355,15 @@ async fn test_select_peer() {
     let part_id = 0;
     let num_peers = SELECT_PEER_CASES.iter().map(|t| t.num_peers).max().unwrap();
     let mut peers = Vec::with_capacity(num_peers);
+
+    let sid_vec = |v: &[u64]| v.iter().cloned().map(Into::into).collect_vec();
+
     for _ in 0..num_peers {
         let key = data::make_secret_key(&mut rng);
         let peer_id = PeerId::new(key.public_key());
-        let score = priority_score(&peer_id, 0u64, part_id);
-        let info = Arc::new(SnapshotHostInfo::new(peer_id, sync_hash, 123, vec![0, 1, 2, 3], &key));
+        let score = priority_score(&peer_id, new_shard_id_tmp(0), part_id);
+        let info =
+            Arc::new(SnapshotHostInfo::new(peer_id, sync_hash, 123, sid_vec(&[0, 1, 2, 3]), &key));
         peers.push((info, score));
     }
     peers.sort_by(|(_linfo, lscore), (_rinfo, rscore)| {
diff --git a/core/primitives-core/src/types.rs b/core/primitives-core/src/types.rs
index 7cc5e279fff..04267a9c491 100644
--- a/core/primitives-core/src/types.rs
+++ b/core/primitives-core/src/types.rs
@@ -18,8 +18,6 @@ pub type Nonce = u64;
 pub type BlockHeight = u64;
 /// Height of the epoch.
 pub type EpochHeight = u64;
-/// Shard index, from 0 to NUM_SHARDS - 1.
-pub type ShardId = u64;
 /// Balance is type for storing amounts of tokens.
 pub type Balance = u128;
 /// Gas is a type for storing amount of gas.
@@ -45,3 +43,159 @@ pub type ReceiptIndex = usize;
 pub type PromiseId = Vec<ReceiptIndex>;
 
 pub type ProtocolVersion = u32;
+
+/// The shard identifier. The ShardId is currently being migrated to a newtype -
+/// please see the new ShardId definition below.
+pub type ShardId = u64;
+
+/// The ShardIndex is the index of the shard in an array of shard data.
+/// Historically the ShardId was always in the range 0..NUM_SHARDS and was used
+/// as the shard index. This is no longer the case, and the ShardIndex should be
+/// used instead.
+pub type ShardIndex = usize;
+
+// TODO(wacban) This is a temporary solution to aid the transition to having
+// ShardId as a newtype. It should be replaced / removed / inlined once the
+// transition is complete.
+pub const fn new_shard_id_tmp(id: u64) -> ShardId {
+    id
+}
+
+// TODO(wacban) This is a temporary solution to aid the transition to having
+// ShardId as a newtype. It should be replaced / removed / inlined once the
+// transition is complete.
+pub fn new_shard_id_vec_tmp(vec: &[u64]) -> Vec<ShardId> {
+    vec.iter().copied().map(new_shard_id_tmp).collect()
+}
+
+// TODO(wacban) This is a temporary solution to aid the transition to having
+// ShardId as a newtype. It should be replaced / removed / inlined once the
+// transition is complete.
+pub const fn shard_id_as_u32(id: ShardId) -> u32 {
+    id as u32
+}
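A quick sketch of how the transitional helpers above are meant to be used while ShardId is still a u64 alias; the function name and values are illustrative, not part of the patch:

    use near_primitives_core::types::{new_shard_id_tmp, new_shard_id_vec_tmp, shard_id_as_u32, ShardId};

    // Every construction site goes through a helper, so call sites are easy
    // to find and migrate once ShardId becomes a real newtype.
    fn transitional_helpers_sketch() {
        let shard_id: ShardId = new_shard_id_tmp(3);
        let shard_ids: Vec<ShardId> = new_shard_id_vec_tmp(&[0, 1, 2, 3]);
        let narrow: u32 = shard_id_as_u32(shard_id); // e.g. for ShardUId fields
        assert_eq!((shard_id, shard_ids.len(), narrow), (3, 4, 3));
    }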
+// TODO(wacban) Complete the transition to ShardId as a newtype.
+// /// The shard identifier. It may be an arbitrary number - it does not need to be
+// /// a number in the range 0..NUM_SHARDS. The shard ids do not need to be
+// /// sequential or contiguous.
+// ///
+// /// The shard id is wrapped in a newtype to prevent the old pattern of using
+// /// indices in range 0..NUM_SHARDS and casting to ShardId. Once the transition
+// /// is fully complete it potentially may be simplified to a regular type alias.
+// #[derive(
+//     arbitrary::Arbitrary,
+//     borsh::BorshSerialize,
+//     borsh::BorshDeserialize,
+//     serde::Serialize,
+//     serde::Deserialize,
+//     Hash,
+//     Clone,
+//     Copy,
+//     Debug,
+//     PartialEq,
+//     Eq,
+//     PartialOrd,
+//     Ord,
+// )]
+// pub struct ShardId(u64);
+
+// impl ShardId {
+//     /// Create a new shard id. Please note that this function should not be used
+//     /// to convert a shard index (a number in 0..num_shards range) to ShardId.
+//     /// Instead the ShardId should be obtained from the shard_layout.
+//     ///
+//     /// ```
+//     /// // BAD USAGE:
+//     /// for shard_index in 1..num_shards {
+//     ///     let shard_id = ShardId::new(shard_index); // Incorrect!!!
+//     /// }
+//     /// ```
+//     /// ```
+//     /// // GOOD USAGE 1:
+//     /// for shard_index in 1..num_shards {
+//     ///     let shard_id = shard_layout.get_shard_id(shard_index);
+//     /// }
+//     /// // GOOD USAGE 2:
+//     /// for shard_id in shard_layout.shard_ids() {
+//     ///     // the shard_id can be used directly, no conversion needed
+//     /// }
+//     /// ```
+//     pub const fn new(id: u64) -> Self {
+//         Self(id)
+//     }
+
+//     /// Get the numerical value of the shard id. This should not be used as an
+//     /// index into an array, as the shard id may be any arbitrary number.
+//     pub fn get(self) -> u64 {
+//         self.0
+//     }
+
+//     pub fn to_le_bytes(self) -> [u8; 8] {
+//         self.0.to_le_bytes()
+//     }
+
+//     pub fn from_le_bytes(bytes: [u8; 8]) -> Self {
+//         Self(u64::from_le_bytes(bytes))
+//     }
+
+//     // TODO This is not great, in ShardUId shard_id is u32.
+//     // Currently used for some metrics so kinda ok.
+//     pub fn max() -> Self {
+//         Self(u64::MAX)
+//     }
+// }
+
+// impl Display for ShardId {
+//     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+//         write!(f, "{}", self.0)
+//     }
+// }
+
+// impl From<u64> for ShardId {
+//     fn from(id: u64) -> Self {
+//         Self(id)
+//     }
+// }
+
+// impl Into<u64> for ShardId {
+//     fn into(self) -> u64 {
+//         self.0
+//     }
+// }
+
+// impl From<u32> for ShardId {
+//     fn from(id: u32) -> Self {
+//         Self(id as u64)
+//     }
+// }
+
+// impl Into<u32> for ShardId {
+//     fn into(self) -> u32 {
+//         self.0 as u32
+//     }
+// }
+
+// impl From<i32> for ShardId {
+//     fn from(id: i32) -> Self {
+//         Self(id as u64)
+//     }
+// }
+
+// impl From<usize> for ShardId {
+//     fn from(id: usize) -> Self {
+//         Self(id as u64)
+//     }
+// }
+
+// impl From<u16> for ShardId {
+//     fn from(id: u16) -> Self {
+//         Self(id as u64)
+//     }
+// }
+
+// impl Into<u16> for ShardId {
+//     fn into(self) -> u16 {
+//         self.0 as u16
+//     }
+// }
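The block.rs changes that follow rely repeatedly on one invariant: per-shard vectors such as shard_ids, state_roots and congestion_infos are aligned by position, so code enumerates once to obtain the ShardIndex and uses it to address every parallel vector. A minimal standalone sketch, with u64 and [u8; 32] standing in for ShardId and StateRoot:

    // Sketch only: iterate aligned per-shard vectors by a single index.
    fn for_each_genesis_shard(
        shard_ids: &[u64],
        state_roots: &[[u8; 32]],
        mut f: impl FnMut(usize, u64, [u8; 32]),
    ) {
        assert_eq!(shard_ids.len(), state_roots.len());
        for (shard_index, &shard_id) in shard_ids.iter().enumerate() {
            // shard_id names the shard; shard_index addresses the arrays.
            f(shard_index, shard_id, state_roots[shard_index]);
        }
    }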
diff --git a/core/primitives/src/block.rs b/core/primitives/src/block.rs
index 826801ab67e..8816c28c155 100644
--- a/core/primitives/src/block.rs
+++ b/core/primitives/src/block.rs
@@ -14,6 +14,7 @@ use crate::sharding::{ChunkHashHeight, ShardChunkHeader, ShardChunkHeaderV1};
 use crate::types::{Balance, BlockHeight, EpochId, Gas};
 use crate::version::{ProtocolVersion, SHARD_CHUNK_HEADER_UPGRADE_VERSION};
 use borsh::{BorshDeserialize, BorshSerialize};
+use near_primitives_core::types::{ShardId, ShardIndex};
 use near_schema_checker_lib::ProtocolSchema;
 use near_time::Utc;
 use primitive_types::U256;
@@ -88,6 +89,7 @@ pub enum Block {
 #[cfg(feature = "solomon")]
 type ShardChunkReedSolomon = reed_solomon_erasure::galois_8::ReedSolomon;
 
+/// The shard_ids, state_roots and congestion_infos must be in the same order.
 #[cfg(feature = "solomon")]
 pub fn genesis_chunks(
     state_roots: Vec<StateRoot>,
@@ -110,10 +112,9 @@ pub fn genesis_chunks(
     let num = shard_ids.len();
     assert_eq!(state_roots.len(), num);
 
-    for shard_id in 0..num {
-        let state_root = state_roots[shard_id];
-        let congestion_info = congestion_infos[shard_id];
-        let shard_id = shard_id as crate::types::ShardId;
+    for (shard_index, &shard_id) in shard_ids.iter().enumerate() {
+        let state_root = state_roots[shard_index];
+        let congestion_info = congestion_infos[shard_index];
 
         let encoded_chunk = genesis_chunk(
             &rs,
@@ -140,7 +141,7 @@ fn genesis_chunk(
     genesis_protocol_version: u32,
     genesis_height: u64,
     initial_gas_limit: u64,
-    shard_id: u64,
+    shard_id: ShardId,
     state_root: CryptoHash,
     congestion_info: Option<CongestionInfo>,
 ) -> crate::sharding::EncodedShardChunk {
@@ -781,7 +782,7 @@ impl<'a> ExactSizeIterator for VersionedChunksIter<'a> {
     }
 }
 
-impl<'a> Index<usize> for ChunksCollection<'a> {
+impl<'a> Index<ShardIndex> for ChunksCollection<'a> {
     type Output = ShardChunkHeader;
 
     /// Deprecated. Please use get instead, it's safer.
@@ -808,7 +809,7 @@ impl<'a> ChunksCollection<'a> {
         }
     }
 
-    pub fn get(&self, index: usize) -> Option<&ShardChunkHeader> {
+    pub fn get(&self, index: ShardIndex) -> Option<&ShardChunkHeader> {
         match self {
             ChunksCollection::V1(chunks) => chunks.get(index),
             ChunksCollection::V2(chunks) => chunks.get(index),
diff --git a/core/primitives/src/congestion_info.rs b/core/primitives/src/congestion_info.rs
index 319774fa6f5..5488d5059b4 100644
--- a/core/primitives/src/congestion_info.rs
+++ b/core/primitives/src/congestion_info.rs
@@ -499,7 +499,9 @@ impl ShardAcceptsTransactions {
 
 #[cfg(test)]
 mod tests {
+    use itertools::Itertools;
     use near_parameters::RuntimeConfigStore;
+    use near_primitives_core::types::new_shard_id_tmp;
     use near_primitives_core::version::{ProtocolFeature, PROTOCOL_VERSION};
 
     use super::*;
@@ -576,7 +578,12 @@ mod tests {
         assert_eq!(0.0, congestion_control.outgoing_congestion());
         assert_eq!(0.0, congestion_control.congestion_level());
-        assert!(config.max_outgoing_gas.abs_diff(congestion_control.outgoing_gas_limit(0)) <= 1);
+        assert!(
+            config
+                .max_outgoing_gas
+                .abs_diff(congestion_control.outgoing_gas_limit(new_shard_id_tmp(0)))
+                <= 1
+        );
         assert!(config.max_tx_gas.abs_diff(congestion_control.process_tx_limit()) <= 1);
         assert!(congestion_control.shard_accepts_transactions().is_yes());
@@ -599,7 +606,7 @@ mod tests {
         let control = CongestionControl::new(config, info, 0);
         assert_eq!(1.0, control.congestion_level());
         // fully congested, no more forwarding allowed
-        assert_eq!(0, control.outgoing_gas_limit(1));
+        assert_eq!(0, control.outgoing_gas_limit(new_shard_id_tmp(1)));
         assert!(control.shard_accepts_transactions().is_no());
         // processing to other shards is not restricted by memory congestion
         assert_eq!(config.max_tx_gas, control.process_tx_limit());
@@ -613,7 +620,7 @@
         assert_eq!(
             (0.5 * config.min_outgoing_gas as f64 + 0.5 * config.max_outgoing_gas as f64) as u64,
-            control.outgoing_gas_limit(1)
+            control.outgoing_gas_limit(new_shard_id_tmp(1))
         );
         // at 50%, still no new transactions are allowed
         assert!(control.shard_accepts_transactions().is_no());
@@ -627,7 +634,7 @@
         assert_eq!(
             (0.125 * config.min_outgoing_gas as f64 + 0.875 * config.max_outgoing_gas as f64)
                 as u64,
-            control.outgoing_gas_limit(1)
+            control.outgoing_gas_limit(new_shard_id_tmp(1))
         );
         // at 12.5%, new transactions are allowed (threshold is 0.25)
         assert!(control.shard_accepts_transactions().is_yes());
@@ -651,7 +658,7 @@
         let control =
CongestionControl::new(config, info, 0); assert_eq!(1.0, control.congestion_level()); // fully congested, no more forwarding allowed - assert_eq!(0, control.outgoing_gas_limit(1)); + assert_eq!(0, control.outgoing_gas_limit(new_shard_id_tmp(1))); assert!(control.shard_accepts_transactions().is_no()); // processing to other shards is restricted by own incoming congestion assert_eq!(config.min_tx_gas, control.process_tx_limit()); @@ -665,7 +672,7 @@ mod tests { assert_eq!( (0.5 * config.min_outgoing_gas as f64 + 0.5 * config.max_outgoing_gas as f64) as u64, - control.outgoing_gas_limit(1) + control.outgoing_gas_limit(new_shard_id_tmp(1)) ); // at 50%, still no new transactions to us are allowed assert!(control.shard_accepts_transactions().is_no()); @@ -684,7 +691,7 @@ mod tests { assert_eq!( (0.125 * config.min_outgoing_gas as f64 + 0.875 * config.max_outgoing_gas as f64) as u64, - control.outgoing_gas_limit(1) + control.outgoing_gas_limit(new_shard_id_tmp(1)) ); // at 12.5%, new transactions are allowed (threshold is 0.25) assert!(control.shard_accepts_transactions().is_yes()); @@ -711,7 +718,7 @@ mod tests { let control = CongestionControl::new(config, info, 0); assert_eq!(1.0, control.congestion_level()); // fully congested, no more forwarding allowed - assert_eq!(0, control.outgoing_gas_limit(1)); + assert_eq!(0, control.outgoing_gas_limit(new_shard_id_tmp(1))); assert!(control.shard_accepts_transactions().is_no()); // processing to other shards is not restricted by own outgoing congestion assert_eq!(config.max_tx_gas, control.process_tx_limit()); @@ -722,7 +729,7 @@ mod tests { assert_eq!(0.5, control.congestion_level()); assert_eq!( (0.5 * config.min_outgoing_gas as f64 + 0.5 * config.max_outgoing_gas as f64) as u64, - control.outgoing_gas_limit(1) + control.outgoing_gas_limit(new_shard_id_tmp(1)) ); // at 50%, still no new transactions to us are allowed assert!(control.shard_accepts_transactions().is_no()); @@ -734,7 +741,7 @@ mod tests { assert_eq!( (0.125 * config.min_outgoing_gas as f64 + 0.875 * config.max_outgoing_gas as f64) as u64, - control.outgoing_gas_limit(1) + control.outgoing_gas_limit(new_shard_id_tmp(1)) ); // at 12.5%, new transactions are allowed (threshold is 0.25) assert!(control.shard_accepts_transactions().is_yes()); @@ -802,8 +809,8 @@ mod tests { let mut info = CongestionInfo::default(); info.add_buffered_receipt_gas(config.max_congestion_outgoing_gas / 2).unwrap(); - let shard = 2; - let all_shards = [0, 1, 2, 3, 4]; + let shard = new_shard_id_tmp(2); + let all_shards = [0, 1, 2, 3, 4].into_iter().map(new_shard_id_tmp).collect_vec(); // Test without missed chunks congestion. 
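Every expected_outgoing_limit in these assertions is the same linear interpolation between max_outgoing_gas (no congestion) and min_outgoing_gas (full congestion). A hedged standalone rendering of that interpolation, which the hunks below call mix(); the real helper's exact rounding may differ:

    // congestion_level in [0.0, 1.0]: 0.0 yields max, 1.0 yields min.
    fn mix(max: u64, min: u64, congestion_level: f64) -> u64 {
        ((1.0 - congestion_level) * max as f64 + congestion_level * min as f64) as u64
    }

    // E.g. at level 0.5 the limit is the midpoint of min and max, matching
    // the `0.5 * min_outgoing_gas + 0.5 * max_outgoing_gas` assertions above.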
@@ -813,7 +820,7 @@ mod tests { let expected_outgoing_limit = 0.5 * config.min_outgoing_gas as f64 + 0.5 * config.max_outgoing_gas as f64; - for shard in all_shards { + for &shard in &all_shards { assert_eq!(control.outgoing_gas_limit(shard), expected_outgoing_limit as u64); } @@ -825,7 +832,7 @@ mod tests { let expected_outgoing_limit = mix(config.max_outgoing_gas, config.min_outgoing_gas, 0.8) as f64; - for shard in all_shards { + for &shard in &all_shards { assert_eq!(control.outgoing_gas_limit(shard), expected_outgoing_limit as u64); } diff --git a/core/primitives/src/epoch_info.rs b/core/primitives/src/epoch_info.rs index bcc14c12c49..47e6277d247 100644 --- a/core/primitives/src/epoch_info.rs +++ b/core/primitives/src/epoch_info.rs @@ -3,6 +3,7 @@ use smart_default::SmartDefault; use std::collections::{BTreeMap, HashMap}; use crate::rand::WeightedIndex; +use crate::shard_layout::ShardLayout; use crate::types::validator_stake::{ValidatorStake, ValidatorStakeIter}; use crate::types::{AccountId, ValidatorKickoutReason, ValidatorStakeV1}; use crate::validator_mandates::ValidatorMandates; @@ -601,35 +602,35 @@ impl EpochInfo { pub fn sample_chunk_producer( &self, - height: BlockHeight, + shard_layout: &ShardLayout, shard_id: ShardId, + height: BlockHeight, ) -> Option { + let shard_index = shard_layout.get_shard_index(shard_id); match &self { Self::V1(v1) => { let cp_settlement = &v1.chunk_producers_settlement; - let shard_cps = cp_settlement.get(shard_id as usize)?; + let shard_cps = cp_settlement.get(shard_index)?; shard_cps.get((height as u64 % (shard_cps.len() as u64)) as usize).copied() } Self::V2(v2) => { let cp_settlement = &v2.chunk_producers_settlement; - let shard_cps = cp_settlement.get(shard_id as usize)?; + let shard_cps = cp_settlement.get(shard_index)?; shard_cps.get((height as u64 % (shard_cps.len() as u64)) as usize).copied() } Self::V3(v3) => { let protocol_version = self.protocol_version(); let seed = Self::chunk_produce_seed(protocol_version, &v3.rng_seed, height, shard_id); - let shard_id = shard_id as usize; - let sample = v3.chunk_producers_sampler.get(shard_id)?.sample(seed); - v3.chunk_producers_settlement.get(shard_id)?.get(sample).copied() + let sample = v3.chunk_producers_sampler.get(shard_index)?.sample(seed); + v3.chunk_producers_settlement.get(shard_index)?.get(sample).copied() } Self::V4(v4) => { let protocol_version = self.protocol_version(); let seed = Self::chunk_produce_seed(protocol_version, &v4.rng_seed, height, shard_id); - let shard_id = shard_id as usize; - let sample = v4.chunk_producers_sampler.get(shard_id)?.sample(seed); - v4.chunk_producers_settlement.get(shard_id)?.get(sample).copied() + let sample = v4.chunk_producers_sampler.get(shard_index)?.sample(seed); + v4.chunk_producers_settlement.get(shard_index)?.get(sample).copied() } } } diff --git a/core/primitives/src/shard_layout.rs b/core/primitives/src/shard_layout.rs index 1a15b88161a..717b462c8db 100644 --- a/core/primitives/src/shard_layout.rs +++ b/core/primitives/src/shard_layout.rs @@ -2,9 +2,9 @@ use crate::hash::CryptoHash; use crate::types::{AccountId, NumShards}; use borsh::{BorshDeserialize, BorshSerialize}; use itertools::Itertools; -use near_primitives_core::types::ShardId; +use near_primitives_core::types::{ShardId, ShardIndex}; use near_schema_checker_lib::ProtocolSchema; -use std::collections::{BTreeMap, HashMap}; +use std::collections::BTreeMap; use std::{fmt, str}; /// This file implements two data structure `ShardLayout` and `ShardUId` @@ -88,6 +88,18 @@ type 
ShardsSplitMapV2 = BTreeMap<ShardId, Vec<ShardId>>;
 /// A mapping from the child shard to the parent shard.
 type ShardsParentMapV2 = BTreeMap<ShardId, ShardId>;
 
+fn new_shard_ids_vec(shard_ids: Vec<u64>) -> Vec<ShardId> {
+    shard_ids.into_iter().map(Into::into).collect()
+}
+
+fn new_shards_split_map(shards_split_map: Vec<Vec<u64>>) -> ShardsSplitMap {
+    shards_split_map.into_iter().map(new_shard_ids_vec).collect()
+}
+
+fn new_shards_split_map_v2(shards_split_map: BTreeMap<u64, Vec<u64>>) -> ShardsSplitMapV2 {
+    shards_split_map.into_iter().map(|(k, v)| (k.into(), new_shard_ids_vec(v))).collect()
+}
+
 #[derive(serde::Serialize, serde::Deserialize, Clone, Debug, PartialEq, Eq)]
 pub struct ShardLayoutV1 {
     /// The boundary accounts are the accounts on boundaries between shards.
@@ -110,14 +122,14 @@ impl ShardLayoutV1 {
     // In this shard layout the accounts are divided into ranges, each range is
     // mapped to a shard. The shards are contiguous and start from 0.
     fn account_id_to_shard_id(&self, account_id: &AccountId) -> ShardId {
-        let mut shard_id: ShardId = 0;
+        let mut shard_id: u64 = 0;
         for boundary_account in &self.boundary_accounts {
             if account_id < boundary_account {
                 break;
             }
             shard_id += 1;
         }
-        shard_id
+        shard_id.into()
     }
 }
 
@@ -138,9 +150,15 @@ pub struct ShardLayoutV2 {
     ///
     /// The shard id at index i corresponds to the shard with account range:
     /// [boundary_accounts[i -1], boundary_accounts[i]).
-    ///
     shard_ids: Vec<ShardId>,
 
+    /// The mapping from shard id to shard index.
+    id_to_index_map: BTreeMap<ShardId, ShardIndex>,
+
+    /// The mapping from shard index to shard id.
+    /// TODO(wacban) this is identical to the shard_ids, remove it.
+    index_to_id_map: BTreeMap<ShardIndex, ShardId>,
+
     /// A mapping from the parent shard to child shards. Maps shards from the
     /// previous shard layout to shards that they split to in this shard layout.
     shards_split_map: Option<ShardsSplitMapV2>,
@@ -192,16 +210,17 @@ impl ShardLayout {
         version: ShardVersion,
     ) -> Self {
         let to_parent_shard_map = if let Some(shards_split_map) = &shards_split_map {
-            let mut to_parent_shard_map = HashMap::new();
+            let mut to_parent_shard_map = BTreeMap::new();
             let num_shards = (boundary_accounts.len() + 1) as NumShards;
             for (parent_shard_id, shard_ids) in shards_split_map.iter().enumerate() {
+                let parent_shard_id = parent_shard_id as u64;
                 for &shard_id in shard_ids {
-                    let prev = to_parent_shard_map.insert(shard_id, parent_shard_id as ShardId);
+                    let prev = to_parent_shard_map.insert(shard_id, parent_shard_id);
                     assert!(prev.is_none(), "no shard should appear in the map twice");
                     assert!(shard_id < num_shards, "shard id should be valid");
                 }
             }
-            Some((0..num_shards).map(|shard_id| to_parent_shard_map[&shard_id]).collect())
+            Some((0..num_shards).map(|shard_id| to_parent_shard_map[&shard_id.into()]).collect())
         } else {
             None
         };
@@ -225,10 +244,19 @@ impl ShardLayout {
         assert_eq!(boundary_accounts.len() + 1, shard_ids.len());
         assert_eq!(boundary_accounts, boundary_accounts.iter().sorted().cloned().collect_vec());
 
+        let mut id_to_index_map = BTreeMap::new();
+        let mut index_to_id_map = BTreeMap::new();
+        for (shard_index, &shard_id) in shard_ids.iter().enumerate() {
+            id_to_index_map.insert(shard_id, shard_index);
+            index_to_id_map.insert(shard_index, shard_id);
+        }
+
         let Some(shards_split_map) = shards_split_map else {
             return Self::V2(ShardLayoutV2 {
                 boundary_accounts,
                 shard_ids,
+                id_to_index_map,
+                index_to_id_map,
                 shards_split_map: None,
                 shards_parent_map: None,
                 version: VERSION,
@@ -253,6 +281,8 @@
         Self::V2(ShardLayoutV2 {
             boundary_accounts,
             shard_ids,
+            id_to_index_map,
+            index_to_id_map,
             shards_split_map,
             shards_parent_map,
             version: VERSION,
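Since id_to_index_map and index_to_id_map are built as inverses from the same shard_ids vector, a roundtrip property holds for every layout version; a small sketch of the invariant (not a test from the patch):

    use near_primitives::shard_layout::ShardLayout;

    // Holds for V0/V1 (where id and index coincide numerically) and for V2
    // (where the two BTreeMaps invert each other); shard_ids() yields the
    // shard ids in index order.
    fn assert_id_index_roundtrip(layout: &ShardLayout) {
        for (shard_index, shard_id) in layout.shard_ids().enumerate() {
            assert_eq!(layout.get_shard_index(shard_id), shard_index);
            assert_eq!(layout.get_shard_id(shard_index), shard_id);
        }
    }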
@@ -263,7 +293,7 @@ impl ShardLayout { pub fn v1_test() -> Self { ShardLayout::v1( vec!["abc", "foo", "test0"].into_iter().map(|s| s.parse().unwrap()).collect(), - Some(vec![vec![0, 1, 2, 3]]), + Some(new_shards_split_map(vec![vec![0, 1, 2, 3]])), 1, ) } @@ -275,7 +305,7 @@ impl ShardLayout { .into_iter() .map(|s| s.parse().unwrap()) .collect(), - Some(vec![vec![0, 1, 2, 3]]), + Some(new_shards_split_map(vec![vec![0, 1, 2, 3]])), 1, ) } @@ -287,7 +317,7 @@ impl ShardLayout { .into_iter() .map(|s| s.parse().unwrap()) .collect(), - Some(vec![vec![0], vec![1], vec![2], vec![3, 4]]), + Some(new_shards_split_map(vec![vec![0], vec![1], vec![2], vec![3, 4]])), 2, ) } @@ -305,7 +335,7 @@ impl ShardLayout { .into_iter() .map(|s| s.parse().unwrap()) .collect(), - Some(vec![vec![0], vec![1], vec![2, 3], vec![4], vec![5]]), + Some(new_shards_split_map(vec![vec![0], vec![1], vec![2, 3], vec![4], vec![5]])), 3, ) } @@ -331,6 +361,7 @@ impl ShardLayout { ]; let shard_ids = vec![0, 1, 6, 7, 3, 4, 5]; + let shard_ids = new_shard_ids_vec(shard_ids); let shards_split_map = BTreeMap::from([ (0, vec![0]), @@ -340,6 +371,7 @@ impl ShardLayout { (4, vec![4]), (5, vec![5]), ]); + let shards_split_map = new_shards_split_map_v2(shards_split_map); let shards_split_map = Some(shards_split_map); ShardLayout::v2(boundary_accounts, shard_ids, shards_split_map) @@ -361,7 +393,14 @@ impl ShardLayout { .into_iter() .map(|s| s.parse().unwrap()) .collect(), - Some(vec![vec![0], vec![1], vec![2], vec![3], vec![4, 5], vec![6]]), + Some(new_shards_split_map(vec![ + vec![0], + vec![1], + vec![2], + vec![3], + vec![4, 5], + vec![6], + ])), 4, ) } @@ -380,7 +419,10 @@ impl ShardLayout { match self { Self::V0(_) => None, Self::V1(v1) => match &v1.shards_split_map { - Some(shards_split_map) => shards_split_map.get(parent_shard_id as usize).cloned(), + Some(shards_split_map) => { + let parent_shard_index = parent_shard_id as usize; + shards_split_map.get(parent_shard_index).cloned() + } None => None, }, Self::V2(v2) => match &v2.shards_split_map { @@ -403,7 +445,10 @@ impl ShardLayout { Self::V1(v1) => match &v1.to_parent_shard_map { // we can safely unwrap here because the construction of to_parent_shard_map guarantees // that every shard has a parent shard - Some(to_parent_shard_map) => *to_parent_shard_map.get(shard_id as usize).unwrap(), + Some(to_parent_shard_map) => { + let shard_index = self.get_shard_index(shard_id); + *to_parent_shard_map.get(shard_index).unwrap() + } None => panic!("shard_layout has no parent shard"), }, Self::V2(v2) => match &v2.shards_parent_map { @@ -441,8 +486,8 @@ impl ShardLayout { pub fn shard_ids(&self) -> impl Iterator + '_ { match self { - Self::V0(_) => (0..self.num_shards()).collect_vec().into_iter(), - Self::V1(_) => (0..self.num_shards()).collect_vec().into_iter(), + Self::V0(_) => (0..self.num_shards()).map(Into::into).collect_vec().into_iter(), + Self::V1(_) => (0..self.num_shards()).map(Into::into).collect_vec().into_iter(), Self::V2(v2) => v2.shard_ids.clone().into_iter(), } } @@ -452,6 +497,26 @@ impl ShardLayout { pub fn shard_uids(&self) -> impl Iterator + '_ { self.shard_ids().map(|shard_id| ShardUId::from_shard_id_and_layout(shard_id, self)) } + + /// Returns the shard index for a given shard id. The shard index should be + /// used when indexing into an array of chunk data. 
+ pub fn get_shard_index(&self, shard_id: ShardId) -> ShardIndex { + match self { + Self::V0(_) => shard_id as ShardIndex, + Self::V1(_) => shard_id as ShardIndex, + Self::V2(v2) => v2.id_to_index_map[&shard_id], + } + } + + /// Get the shard id for a given shard index. The shard id should be used to + /// identify the shard and starting from the ShardLayoutV2 it is unique. + pub fn get_shard_id(&self, shard_index: usize) -> ShardId { + match self { + Self::V0(_) => shard_index as ShardId, + Self::V1(_) => shard_index as ShardId, + Self::V2(v2) => v2.index_to_id_map[&shard_index], + } + } } /// Maps an account to the shard that it belongs to given a shard_layout @@ -464,7 +529,8 @@ pub fn account_id_to_shard_id(account_id: &AccountId, shard_layout: &ShardLayout ShardLayout::V0(ShardLayoutV0 { num_shards, .. }) => { let hash = CryptoHash::hash_bytes(account_id.as_bytes()); let (bytes, _) = stdx::split_array::<32, 8, 24>(hash.as_bytes()); - u64::from_le_bytes(*bytes) % num_shards + let shard_id = u64::from_le_bytes(*bytes) % num_shards; + shard_id.into() } ShardLayout::V1(v1) => v1.account_id_to_shard_id(account_id), ShardLayout::V2(v2) => v2.account_id_to_shard_id(account_id), @@ -536,7 +602,7 @@ impl ShardUId { /// Returns shard id pub fn shard_id(&self) -> ShardId { - ShardId::from(self.shard_id) + self.shard_id.into() } } @@ -679,9 +745,12 @@ impl<'de> serde::de::Visitor<'de> for ShardUIdVisitor { #[cfg(test)] mod tests { use crate::epoch_manager::{AllEpochConfig, EpochConfig, ValidatorSelectionConfig}; - use crate::shard_layout::{account_id_to_shard_id, ShardLayout, ShardLayoutV1, ShardUId}; + use crate::shard_layout::{ + account_id_to_shard_id, new_shard_ids_vec, new_shards_split_map, ShardLayout, + ShardLayoutV1, ShardUId, + }; use itertools::Itertools; - use near_primitives_core::types::ProtocolVersion; + use near_primitives_core::types::{new_shard_id_tmp, ProtocolVersion}; use near_primitives_core::types::{AccountId, ShardId}; use near_primitives_core::version::{ProtocolFeature, PROTOCOL_VERSION}; use rand::distributions::Alphanumeric; @@ -689,7 +758,7 @@ mod tests { use rand::{Rng, SeedableRng}; use std::collections::{BTreeMap, HashMap}; - use super::{ShardVersion, ShardsSplitMap}; + use super::{new_shards_split_map_v2, ShardVersion, ShardsSplitMap}; // The old ShardLayoutV1, before fixed shards were removed. 
tests only #[derive(serde::Serialize, serde::Deserialize, Clone, Debug, PartialEq, Eq)] @@ -750,8 +819,8 @@ mod tests { fn test_shard_layout_v0() { let num_shards = 4; let shard_layout = ShardLayout::v0(num_shards, 0); - let mut shard_id_distribution: HashMap<_, _> = - shard_layout.shard_ids().map(|shard_id| (shard_id, 0)).collect(); + let mut shard_id_distribution: HashMap = + shard_layout.shard_ids().map(|shard_id| (shard_id.into(), 0)).collect(); let mut rng = StdRng::from_seed([0; 32]); for _i in 0..1000 { let s: Vec = (&mut rng).sample_iter(&Alphanumeric).take(10).collect(); @@ -759,7 +828,7 @@ mod tests { let account_id = s.to_lowercase().parse().unwrap(); let shard_id = account_id_to_shard_id(&account_id, &shard_layout); assert!(shard_id < num_shards); - *shard_id_distribution.get_mut(&shard_id).unwrap() += 1; + *shard_id_distribution.get_mut(&shard_id.into()).unwrap() += 1; } let expected_distribution: HashMap<_, _> = [(0, 247), (1, 268), (2, 233), (3, 252)].into_iter().collect(); @@ -770,38 +839,39 @@ mod tests { fn test_shard_layout_v1() { let shard_layout = ShardLayout::v1( parse_account_ids(&["aurora", "bar", "foo", "foo.baz", "paz"]), - Some(vec![vec![0, 1, 2], vec![3, 4, 5]]), + Some(new_shards_split_map(vec![vec![0, 1, 2], vec![3, 4, 5]])), 1, ); assert_eq!( - shard_layout.get_children_shards_uids(0).unwrap(), + shard_layout.get_children_shards_uids(new_shard_id_tmp(0)).unwrap(), (0..3).map(|x| ShardUId { version: 1, shard_id: x }).collect::>() ); assert_eq!( - shard_layout.get_children_shards_uids(1).unwrap(), + shard_layout.get_children_shards_uids(new_shard_id_tmp(1)).unwrap(), (3..6).map(|x| ShardUId { version: 1, shard_id: x }).collect::>() ); for x in 0..3 { - assert_eq!(shard_layout.get_parent_shard_id(x).unwrap(), 0); - assert_eq!(shard_layout.get_parent_shard_id(x + 3).unwrap(), 1); + assert_eq!(shard_layout.get_parent_shard_id(new_shard_id_tmp(x)).unwrap(), 0); + assert_eq!(shard_layout.get_parent_shard_id(new_shard_id_tmp(x + 3)).unwrap(), 1); } - assert_eq!(account_id_to_shard_id(&"aurora".parse().unwrap(), &shard_layout), 1); - assert_eq!(account_id_to_shard_id(&"foo.aurora".parse().unwrap(), &shard_layout), 3); - assert_eq!(account_id_to_shard_id(&"bar.foo.aurora".parse().unwrap(), &shard_layout), 2); - assert_eq!(account_id_to_shard_id(&"bar".parse().unwrap(), &shard_layout), 2); - assert_eq!(account_id_to_shard_id(&"bar.bar".parse().unwrap(), &shard_layout), 2); - assert_eq!(account_id_to_shard_id(&"foo".parse().unwrap(), &shard_layout), 3); - assert_eq!(account_id_to_shard_id(&"baz.foo".parse().unwrap(), &shard_layout), 2); - assert_eq!(account_id_to_shard_id(&"foo.baz".parse().unwrap(), &shard_layout), 4); - assert_eq!(account_id_to_shard_id(&"a.foo.baz".parse().unwrap(), &shard_layout), 0); - - assert_eq!(account_id_to_shard_id(&"aaa".parse().unwrap(), &shard_layout), 0); - assert_eq!(account_id_to_shard_id(&"abc".parse().unwrap(), &shard_layout), 0); - assert_eq!(account_id_to_shard_id(&"bbb".parse().unwrap(), &shard_layout), 2); - assert_eq!(account_id_to_shard_id(&"foo.goo".parse().unwrap(), &shard_layout), 4); - assert_eq!(account_id_to_shard_id(&"goo".parse().unwrap(), &shard_layout), 4); - assert_eq!(account_id_to_shard_id(&"zoo".parse().unwrap(), &shard_layout), 5); + let aid = |s: &str| s.parse().unwrap(); + assert_eq!(account_id_to_shard_id(&aid("aurora"), &shard_layout), 1); + assert_eq!(account_id_to_shard_id(&aid("foo.aurora"), &shard_layout), 3); + assert_eq!(account_id_to_shard_id(&aid("bar.foo.aurora"), &shard_layout), 2); + 
assert_eq!(account_id_to_shard_id(&aid("bar"), &shard_layout), 2); + assert_eq!(account_id_to_shard_id(&aid("bar.bar"), &shard_layout), 2); + assert_eq!(account_id_to_shard_id(&aid("foo"), &shard_layout), 3); + assert_eq!(account_id_to_shard_id(&aid("baz.foo"), &shard_layout), 2); + assert_eq!(account_id_to_shard_id(&aid("foo.baz"), &shard_layout), 4); + assert_eq!(account_id_to_shard_id(&aid("a.foo.baz"), &shard_layout), 0); + + assert_eq!(account_id_to_shard_id(&aid("aaa"), &shard_layout), 0); + assert_eq!(account_id_to_shard_id(&aid("abc"), &shard_layout), 0); + assert_eq!(account_id_to_shard_id(&aid("bbb"), &shard_layout), 2); + assert_eq!(account_id_to_shard_id(&aid("foo.goo"), &shard_layout), 4); + assert_eq!(account_id_to_shard_id(&aid("goo"), &shard_layout), 4); + assert_eq!(account_id_to_shard_id(&aid("zoo"), &shard_layout), 5); } // check that after removing the fixed shards from the shard layout v1 @@ -812,8 +882,8 @@ mod tests { let old = OldShardLayoutV1 { fixed_shards: vec![], boundary_accounts: parse_account_ids(&["aaa", "bbb"]), - shards_split_map: Some(vec![vec![0, 1, 2]]), - to_parent_shard_map: Some(vec![0, 0, 0]), + shards_split_map: Some(new_shards_split_map(vec![vec![0, 1, 2]])), + to_parent_shard_map: Some(new_shard_ids_vec(vec![0, 0, 0])), version: 1, }; let json = serde_json::to_string_pretty(&old).unwrap(); @@ -847,7 +917,7 @@ mod tests { assert_eq!(account_id_to_shard_id(&"ppp".parse().unwrap(), &shard_layout), 7); // check shard ids - assert_eq!(shard_layout.shard_ids().collect_vec(), vec![3, 8, 4, 7]); + assert_eq!(shard_layout.shard_ids().collect_vec(), new_shard_ids_vec(vec![3, 8, 4, 7])); // check shard uids let version = 3; @@ -855,15 +925,24 @@ mod tests { assert_eq!(shard_layout.shard_uids().collect_vec(), vec![u(3), u(8), u(4), u(7)]); // check parent - assert_eq!(shard_layout.get_parent_shard_id(3).unwrap(), 3); - assert_eq!(shard_layout.get_parent_shard_id(8).unwrap(), 1); - assert_eq!(shard_layout.get_parent_shard_id(4).unwrap(), 4); - assert_eq!(shard_layout.get_parent_shard_id(7).unwrap(), 1); + assert_eq!(shard_layout.get_parent_shard_id(new_shard_id_tmp(3)).unwrap(), 3); + assert_eq!(shard_layout.get_parent_shard_id(new_shard_id_tmp(8)).unwrap(), 1); + assert_eq!(shard_layout.get_parent_shard_id(new_shard_id_tmp(4)).unwrap(), 4); + assert_eq!(shard_layout.get_parent_shard_id(new_shard_id_tmp(7)).unwrap(), 1); // check child - assert_eq!(shard_layout.get_children_shards_ids(1).unwrap(), vec![7, 8]); - assert_eq!(shard_layout.get_children_shards_ids(3).unwrap(), vec![3]); - assert_eq!(shard_layout.get_children_shards_ids(4).unwrap(), vec![4]); + assert_eq!( + shard_layout.get_children_shards_ids(new_shard_id_tmp(1)).unwrap(), + new_shard_ids_vec(vec![7, 8]) + ); + assert_eq!( + shard_layout.get_children_shards_ids(new_shard_id_tmp(3)).unwrap(), + new_shard_ids_vec(vec![3]) + ); + assert_eq!( + shard_layout.get_children_shards_ids(new_shard_id_tmp(4)).unwrap(), + new_shard_ids_vec(vec![4]) + ); } fn get_test_shard_layout_v2() -> ShardLayout { @@ -873,10 +952,12 @@ mod tests { let boundary_accounts = vec![b0, b1, b2]; let shard_ids = vec![3, 8, 4, 7]; + let shard_ids = new_shard_ids_vec(shard_ids); // the mapping from parent to the child // shard 1 is split into shards 7 & 8 while other shards stay the same let shards_split_map = BTreeMap::from([(1, vec![7, 8]), (3, vec![3]), (4, vec![4])]); + let shards_split_map = new_shards_split_map_v2(shards_split_map); let shards_split_map = Some(shards_split_map); ShardLayout::v2(boundary_accounts, shard_ids, 
shards_split_map) @@ -1020,6 +1101,24 @@ mod tests { 4, 5 ], + "id_to_index_map": { + "0": 0, + "1": 1, + "3": 4, + "4": 5, + "5": 6, + "6": 2, + "7": 3 + }, + "index_to_id_map": { + "0": 0, + "1": 1, + "2": 6, + "3": 7, + "4": 3, + "5": 4, + "6": 5 + }, "shards_split_map": { "0": [ 0 diff --git a/core/primitives/src/sharding.rs b/core/primitives/src/sharding.rs index a9a1c7fa6e9..4e5e26cd867 100644 --- a/core/primitives/src/sharding.rs +++ b/core/primitives/src/sharding.rs @@ -1213,7 +1213,7 @@ impl EncodedShardChunk { "decode_chunk", data_parts, height_included = self.cloned_header().height_included(), - shard_id = self.cloned_header().shard_id(), + shard_id = ?self.cloned_header().shard_id(), chunk_hash = ?self.chunk_hash()) .entered(); diff --git a/core/primitives/src/stateless_validation/chunk_endorsements_bitmap.rs b/core/primitives/src/stateless_validation/chunk_endorsements_bitmap.rs index 26874117ecd..ae28fa9efb5 100644 --- a/core/primitives/src/stateless_validation/chunk_endorsements_bitmap.rs +++ b/core/primitives/src/stateless_validation/chunk_endorsements_bitmap.rs @@ -1,6 +1,5 @@ use bitvec::prelude::*; use borsh::{BorshDeserialize, BorshSerialize}; -use near_primitives_core::types::ShardId; use near_schema_checker_lib::ProtocolSchema; /// Represents a collection of bitmaps, one per shard, to store whether the endorsements from the chunk validators have been received. @@ -56,21 +55,21 @@ impl ChunkEndorsementsBitmap { // Creates an endorsement bitmap for all the shards. pub fn from_endorsements(shards_to_endorsements: Vec<Vec<bool>>) -> Self { let mut bitmap = ChunkEndorsementsBitmap::new(shards_to_endorsements.len()); - for (shard_id, endorsements) in shards_to_endorsements.into_iter().enumerate() { - bitmap.add_endorsements(shard_id as ShardId, endorsements); + for (shard_index, endorsements) in shards_to_endorsements.into_iter().enumerate() { + bitmap.add_endorsements(shard_index, endorsements); } bitmap } /// Adds the provided endorsements to the bitmap for the specified shard. - pub fn add_endorsements(&mut self, shard_id: ShardId, endorsements: Vec<bool>) { + pub fn add_endorsements(&mut self, shard_index: usize, endorsements: Vec<bool>) { let bitvec: BitVecType = endorsements.iter().collect(); - self.inner[shard_id as usize] = bitvec.into(); + self.inner[shard_index] = bitvec.into(); } /// Returns an iterator over the endorsements (yields true if the endorsement for the respective position was received). - pub fn iter(&self, shard_id: ShardId) -> Box<dyn Iterator<Item = bool>> { - let bitvec = BitVecType::from_vec(self.inner[shard_id as usize].clone()); + pub fn iter(&self, shard_index: usize) -> Box<dyn Iterator<Item = bool>> { + let bitvec = BitVecType::from_vec(self.inner[shard_index].clone()); Box::new(bitvec.into_iter()) } @@ -81,8 +80,8 @@ impl ChunkEndorsementsBitmap { /// Returns the full length of the bitmap for a given shard. /// Note that the size may be greater than the number of validator assignments. - pub fn len(&self, shard_id: ShardId) -> Option<usize> { - self.inner.get(shard_id as usize).map(|v| v.len() * 8) + pub fn len(&self, shard_index: usize) -> Option<usize> { + self.inner.get(shard_index).map(|v| v.len() * 8) } }
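A note on the API above: `ChunkEndorsementsBitmap::inner` is purely positional, which is why the methods now take a `usize` shard index rather than a `ShardId`. A minimal standalone sketch of that shape (toy types only; the real implementation packs each row into bitvec-backed bytes):

    // Toy positional endorsement bitmap: rows are addressed by shard *index*
    // (the position in the current shard layout), never by shard id.
    struct ChunkEndorsementsBitmap {
        inner: Vec<Vec<bool>>, // one row of endorsement flags per shard index
    }

    impl ChunkEndorsementsBitmap {
        fn new(num_shards: usize) -> Self {
            Self { inner: vec![Vec::new(); num_shards] }
        }
        fn add_endorsements(&mut self, shard_index: usize, endorsements: Vec<bool>) {
            self.inner[shard_index] = endorsements;
        }
        fn iter(&self, shard_index: usize) -> impl Iterator<Item = bool> + '_ {
            self.inner[shard_index].iter().copied()
        }
    }

    fn main() {
        // With layout shard ids [3, 8, 4, 7], shard id 8 lives at index 1.
        let mut bitmap = ChunkEndorsementsBitmap::new(4);
        bitmap.add_endorsements(1, vec![true, false, true]);
        assert_eq!(bitmap.iter(1).collect::<Vec<_>>(), vec![true, false, true]);
    }

The signature change makes the positional contract explicit: callers such as the replay code at the end of this diff translate ids to indices before touching a row.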
@@ -90,7 +89,6 @@ mod tests { use super::ChunkEndorsementsBitmap; use itertools::Itertools; - use near_primitives_core::types::ShardId; use rand::Rng; const NUM_SHARDS: usize = 4; @@ -102,9 +100,9 @@ expected_endorsements: &Vec<Vec<bool>>, ) { // Endorsements from the bitmap iterator must match the endorsements given previously. - for (shard_id, endorsements) in expected_endorsements.iter().enumerate() { - let num_bits = bitmap.len(shard_id as ShardId).unwrap(); - let bits = bitmap.iter(shard_id as ShardId).collect_vec(); + for (shard_index, endorsements) in expected_endorsements.iter().enumerate() { + let num_bits = bitmap.len(shard_index).unwrap(); + let bits = bitmap.iter(shard_index).collect_vec(); // Number of bits must be equal to the size of the bit iterator for the corresponding shard. assert_eq!(num_bits, bits.len()); // Bitmap must contain the minimal number of bits to represent the endorsements. @@ -121,13 +119,13 @@ mod tests { let mut rng = rand::thread_rng(); let mut bitmap = ChunkEndorsementsBitmap::new(NUM_SHARDS); let mut expected_endorsements = vec![]; - for shard_id in 0..NUM_SHARDS { + for shard_index in 0..NUM_SHARDS { let mut endorsements = vec![false; num_assignments]; for _ in 0..num_produced { endorsements[rng.gen_range(0..num_assignments)] = true; } expected_endorsements.push(endorsements.clone()); - bitmap.add_endorsements(shard_id as ShardId, endorsements); + bitmap.add_endorsements(shard_index, endorsements); } // Check before serialization. assert_bitmap(&bitmap, num_assignments, &expected_endorsements); diff --git a/core/store/benches/finalize_bench.rs b/core/store/benches/finalize_bench.rs index aedce9232ec..49d7915ada4 100644 --- a/core/store/benches/finalize_bench.rs +++ b/core/store/benches/finalize_bench.rs @@ -30,7 +30,7 @@ use near_primitives::sharding::{ ShardChunkV2, ShardProof, }; use near_primitives::transaction::{Action, FunctionCallAction, SignedTransaction}; -use near_primitives::types::AccountId; +use near_primitives::types::{new_shard_id_tmp, AccountId, ShardId}; use near_primitives::validator_signer::InMemoryValidatorSigner; use near_primitives::version::{ProtocolFeature, PROTOCOL_VERSION}; use near_store::DBCol; @@ -115,7 +115,7 @@ fn create_benchmark_receipts() -> Vec<Receipt> { ] } -fn create_chunk_header(height: u64, shard_id: u64) -> ShardChunkHeader { +fn create_chunk_header(height: u64, shard_id: ShardId) -> ShardChunkHeader { let congestion_info = ProtocolFeature::CongestionControl .enabled(PROTOCOL_VERSION) .then_some(CongestionInfo::default()); @@ -177,7 +177,7 @@ fn create_shard_chunk( ) -> ShardChunk { ShardChunk::V2(ShardChunkV2 { chunk_hash: chunk_hash.clone(), - header: create_chunk_header(0, 0), + header: create_chunk_header(0, new_shard_id_tmp(0)), transactions, prev_outgoing_receipts: receipts, }) @@ -198,7 +198,7 @@ fn create_encoded_shard_chunk( Default::default(), Default::default(), Default::default(), - Default::default(), + new_shard_id_tmp(0), Default::default(), Default::default(), Default::default(), @@ -231,8 +231,8 @@ fn encoded_chunk_to_partial_encoded_chunk( let receipt_proofs = proofs .into_iter() .enumerate() - .map(move |(proof_shard_id, proof)| { - let proof_shard_id = proof_shard_id as u64; + .map(move |(proof_shard_index, proof)| { + let proof_shard_id = shard_layout.get_shard_id(proof_shard_index); let receipts = receipts_by_shard.remove(&proof_shard_id).unwrap_or_else(Vec::new); let shard_proof = ShardProof { from_shard_id: shard_id, to_shard_id: proof_shard_id, proof };
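The flat storage diff below converts every bare `shard_id` logging field to the `?shard_id` form. This prepares for the newtype: `tracing`'s bare-field syntax requires the value to implement `tracing::Value` (which `u64` does, but a wrapper struct would not), while the `?` sigil captures the field through `Debug`. A minimal sketch, assuming only the `tracing` crate as a dependency:

    use tracing::debug;

    // Toy stand-in for the planned newtype; note that only Debug is derived.
    #[derive(Debug, Clone, Copy)]
    struct ShardId(u64);

    fn move_flat_head(shard_id: ShardId, flat_head_height: u64) {
        // `?shard_id` records the field via its Debug representation; the bare
        // `shard_id` syntax would require ShardId to implement tracing::Value.
        debug!(target: "store", ?shard_id, flat_head_height, "Moving flat head");
    }

    fn main() {
        move_flat_head(ShardId(3), 100);
    }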
diff --git a/core/store/src/flat/storage.rs b/core/store/src/flat/storage.rs index 172928d35eb..79f7001aacf 100644 --- a/core/store/src/flat/storage.rs +++ b/core/store/src/flat/storage.rs @@ -130,7 +130,7 @@ impl FlatStorageInner { if blocks.len() >= Self::HOPS_LIMIT { warn!( target: "chain", - shard_id = self.shard_uid.shard_id(), + shard_id = ?self.shard_uid.shard_id(), flat_head_height = flat_head.height, cached_deltas = self.deltas.len(), num_hops = blocks.len(), @@ -160,7 +160,7 @@ impl FlatStorageInner { if cached_changes_size_bytes >= Self::CACHED_CHANGES_SIZE_LIMIT { warn!( target: "chain", - shard_id = self.shard_uid.shard_id(), + shard_id = ?self.shard_uid.shard_id(), flat_head_height = self.flat_head.height, cached_deltas, %cached_changes_size_bytes, @@ -380,7 +380,7 @@ impl FlatStorage { let shard_uid = guard.shard_uid; let shard_id = shard_uid.shard_id(); - tracing::debug!(target: "store", flat_head = ?guard.flat_head.hash, ?new_head, shard_id, "Moving flat head"); + tracing::debug!(target: "store", flat_head = ?guard.flat_head.hash, ?new_head, ?shard_id, "Moving flat head"); let blocks = guard.get_blocks_to_head(&new_head)?; for block_hash in blocks.into_iter().rev() { diff --git a/core/store/src/genesis/initialization.rs b/core/store/src/genesis/initialization.rs index bc9b1c9a65c..210106e91dd 100644 --- a/core/store/src/genesis/initialization.rs +++ b/core/store/src/genesis/initialization.rs @@ -2,7 +2,11 @@ //! We first check if store has the genesis hash and state_roots, if not, we go ahead with initialization use rayon::prelude::*; -use std::{collections::HashSet, fs, path::Path}; +use std::{ + collections::{BTreeMap, HashSet}, + fs, + path::Path, +}; use borsh::BorshDeserialize; use near_chain_configs::{Genesis, GenesisContents}; @@ -110,17 +114,18 @@ fn genesis_state_from_genesis( let runtime_config_store = RuntimeConfigStore::for_chain_id(&genesis.config.chain_id); let runtime_config = runtime_config_store.get_config(genesis.config.protocol_version); let storage_usage_config = &runtime_config.fees.storage_usage_config; + let shard_ids: Vec<_> = shard_layout.shard_ids().collect(); let shard_uids: Vec<_> = shard_layout.shard_uids().collect(); - // note that here we are depending on the behavior that shard_layout.shard_uids() returns an iterator - // in order by shard id from 0 to num_shards() - let mut shard_account_ids: Vec<HashSet<AccountId>> = - shard_uids.iter().map(|_| HashSet::new()).collect(); + + let mut shard_account_ids: BTreeMap<ShardId, HashSet<AccountId>> = + shard_ids.iter().map(|&shard_id| (shard_id, HashSet::new())).collect(); let mut has_protocol_account = false; info!(target: "store","distributing records to shards"); genesis.for_each_record(|record: &StateRecord| { - shard_account_ids[state_record_to_shard_id(record, &shard_layout) as usize] - .insert(state_record_to_account_id(record).clone()); + let shard_id = state_record_to_shard_id(record, &shard_layout); + let account_id = state_record_to_account_id(record).clone(); + shard_account_ids.get_mut(&shard_id).unwrap().insert(account_id); if let StateRecord::Account { account_id, ..
} = record { if account_id == &genesis.config.protocol_treasury_account { has_protocol_account = true; @@ -165,7 +170,7 @@ fn genesis_state_from_genesis( &validators, storage_usage_config, genesis, - shard_account_ids[shard_id as usize].clone(), + shard_account_ids[&shard_id].clone(), ) }) .collect() diff --git a/core/store/src/trie/prefetching_trie_storage.rs b/core/store/src/trie/prefetching_trie_storage.rs index 31e831e331c..bd886ab32b6 100644 --- a/core/store/src/trie/prefetching_trie_storage.rs +++ b/core/store/src/trie/prefetching_trie_storage.rs @@ -593,12 +593,13 @@ mod tests_utils { mod tests { use super::{PrefetchStagingArea, PrefetcherResult}; use near_primitives::hash::CryptoHash; + use near_primitives::types::new_shard_id_tmp; #[test] fn test_prefetch_staging_area_blocking_get_after_update() { let key = CryptoHash::hash_bytes(&[1, 2, 3]); let value: std::sync::Arc<[u8]> = vec![4, 5, 6].into(); - let prefetch_staging_area = PrefetchStagingArea::new(0); + let prefetch_staging_area = PrefetchStagingArea::new(new_shard_id_tmp(0)); assert!(matches!( prefetch_staging_area.get_or_set_fetching(key), PrefetcherResult::SlotReserved diff --git a/core/store/src/trie/shard_tries.rs b/core/store/src/trie/shard_tries.rs index 92977af9469..d804b49dcb4 100644 --- a/core/store/src/trie/shard_tries.rs +++ b/core/store/src/trie/shard_tries.rs @@ -286,7 +286,7 @@ impl ShardTries { level = "trace", target = "store::trie::shard_tries", "ShardTries::apply_insertions", - fields(num_insertions = trie_changes.insertions().len(), shard_id = shard_uid.shard_id()), + fields(num_insertions = trie_changes.insertions().len(), shard_id = ?shard_uid.shard_id()), skip_all, )] pub fn apply_insertions( @@ -309,7 +309,7 @@ impl ShardTries { level = "trace", target = "store::trie::shard_tries", "ShardTries::apply_deletions", - fields(num_deletions = trie_changes.deletions().len(), shard_id = shard_uid.shard_id()), + fields(num_deletions = trie_changes.deletions().len(), shard_id = ?shard_uid.shard_id()), skip_all, )] pub fn apply_deletions( @@ -553,7 +553,7 @@ impl WrappedTrieChanges { level = "debug", target = "store::trie::shard_tries", "ShardTries::state_changes_into", - fields(num_state_changes = self.state_changes.len(), shard_id = self.shard_uid.shard_id()), + fields(num_state_changes = self.state_changes.len(), shard_id = ?self.shard_uid.shard_id()), skip_all, )] pub fn state_changes_into(&mut self, store_update: &mut TrieStoreUpdateAdapter) { diff --git a/core/store/src/trie/state_parts.rs b/core/store/src/trie/state_parts.rs index 47b54740534..1f20c3c5ea0 100644 --- a/core/store/src/trie/state_parts.rs +++ b/core/store/src/trie/state_parts.rs @@ -130,7 +130,7 @@ impl Trie { ) -> Result<(PartialState, Vec, Vec), StorageError> { let shard_id: ShardId = self.flat_storage_chunk_view.as_ref().map_or( ShardId::MAX, // Fake value for metrics. - |chunk_view| chunk_view.shard_uid().shard_id as ShardId, + |chunk_view| chunk_view.shard_uid().shard_id(), ); let _span = tracing::debug_span!( target: "state-parts", @@ -184,7 +184,7 @@ impl Trie { ) -> Result { let shard_id: ShardId = self.flat_storage_chunk_view.as_ref().map_or( ShardId::MAX, // Fake value for metrics. 
- |chunk_view| chunk_view.shard_uid().shard_id as ShardId, + |chunk_view| chunk_view.shard_uid().shard_id(), ); let _span = tracing::debug_span!( target: "state-parts", diff --git a/core/store/src/trie/trie_storage.rs b/core/store/src/trie/trie_storage.rs index 61ac07fa9b0..51a7b048a4d 100644 --- a/core/store/src/trie/trie_storage.rs +++ b/core/store/src/trie/trie_storage.rs @@ -612,7 +612,7 @@ mod trie_cache_tests { use crate::{StoreConfig, TrieCache, TrieConfig}; use near_primitives::hash::hash; use near_primitives::shard_layout::ShardUId; - use near_primitives::types::ShardId; + use near_primitives::types::{new_shard_id_tmp, shard_id_as_u32, ShardId}; fn put_value(cache: &mut TrieCacheInner, value: &[u8]) { cache.put(hash(value), value.into()); @@ -622,7 +622,8 @@ mod trie_cache_tests { fn test_size_limit() { let value_size_sum = 5; let memory_overhead = 2 * TrieCacheInner::PER_ENTRY_OVERHEAD; - let mut cache = TrieCacheInner::new(100, value_size_sum + memory_overhead, 0, false); + let mut cache = + TrieCacheInner::new(100, value_size_sum + memory_overhead, new_shard_id_tmp(0), false); // Add three values. Before each put, condition on total size should not be triggered. put_value(&mut cache, &[1, 1]); assert_eq!(cache.current_total_size(), 2 + TrieCacheInner::PER_ENTRY_OVERHEAD); @@ -640,7 +641,7 @@ mod trie_cache_tests { #[test] fn test_deletions_queue() { - let mut cache = TrieCacheInner::new(2, 1000, 0, false); + let mut cache = TrieCacheInner::new(2, 1000, new_shard_id_tmp(0), false); // Add two values to the cache. put_value(&mut cache, &[1]); put_value(&mut cache, &[1, 1]); @@ -659,7 +660,7 @@ mod trie_cache_tests { fn test_cache_capacity() { let capacity = 2; let total_size_limit = TrieCacheInner::PER_ENTRY_OVERHEAD * capacity; - let mut cache = TrieCacheInner::new(100, total_size_limit, 0, false); + let mut cache = TrieCacheInner::new(100, total_size_limit, new_shard_id_tmp(0), false); put_value(&mut cache, &[1]); put_value(&mut cache, &[2]); put_value(&mut cache, &[3]); @@ -672,7 +673,7 @@ mod trie_cache_tests { #[test] fn test_small_memory_limit() { let total_size_limit = 1; - let mut cache = TrieCacheInner::new(100, total_size_limit, 0, false); + let mut cache = TrieCacheInner::new(100, total_size_limit, new_shard_id_tmp(0), false); put_value(&mut cache, &[1, 2, 3]); put_value(&mut cache, &[2, 3, 4]); put_value(&mut cache, &[3, 4, 5]); @@ -699,10 +700,10 @@ mod trie_cache_tests { store_config.view_trie_cache.per_shard_max_bytes.insert(s0, S0_VIEW_SIZE); let trie_config = TrieConfig::from_store_config(&store_config); - check_cache_size(&trie_config, 1, false, DEFAULT_SIZE); - check_cache_size(&trie_config, 0, false, S0_SIZE); - check_cache_size(&trie_config, 1, true, DEFAULT_VIEW_SIZE); - check_cache_size(&trie_config, 0, true, S0_VIEW_SIZE); + check_cache_size(&trie_config, new_shard_id_tmp(1), false, DEFAULT_SIZE); + check_cache_size(&trie_config, new_shard_id_tmp(0), false, S0_SIZE); + check_cache_size(&trie_config, new_shard_id_tmp(1), true, DEFAULT_VIEW_SIZE); + check_cache_size(&trie_config, new_shard_id_tmp(0), true, S0_VIEW_SIZE); } #[track_caller] @@ -712,7 +713,7 @@ mod trie_cache_tests { is_view: bool, expected_size: bytesize::ByteSize, ) { - let shard_uid = ShardUId { version: 0, shard_id: shard_id as u32 }; + let shard_uid = ShardUId { version: 0, shard_id: shard_id_as_u32(shard_id) }; let trie_cache = TrieCache::new(&trie_config, shard_uid, is_view); assert_eq!(expected_size.as_u64(), trie_cache.lock().total_size_limit); assert_eq!(is_view, 
trie_cache.lock().is_view); diff --git a/genesis-tools/genesis-csv-to-json/src/csv_to_json_configs.rs b/genesis-tools/genesis-csv-to-json/src/csv_to_json_configs.rs index b5b1fd18f65..5f1695cef89 100644 --- a/genesis-tools/genesis-csv-to-json/src/csv_to_json_configs.rs +++ b/genesis-tools/genesis-csv-to-json/src/csv_to_json_configs.rs @@ -5,7 +5,7 @@ use near_chain_configs::{ MIN_GAS_PRICE, NEAR_BASE, NUM_BLOCKS_PER_YEAR, NUM_BLOCK_PRODUCER_SEATS, PROTOCOL_REWARD_RATE, PROTOCOL_UPGRADE_STAKE_THRESHOLD, TRANSACTION_VALIDITY_PERIOD, }; -use near_primitives::types::{Balance, NumShards, ShardId}; +use near_primitives::types::{new_shard_id_tmp, Balance, NumShards, ShardId}; use near_primitives::utils::get_num_seats_per_shard; use near_primitives::version::PROTOCOL_VERSION; use nearcore::config::{Config, CONFIG_FILENAME, NODE_KEY_FILE}; @@ -14,7 +14,16 @@ use std::fs::File; use std::path::Path; const ACCOUNTS_FILE: &str = "accounts.csv"; -const SHARDS: &'static [ShardId] = &[0, 1, 2, 3, 4, 5, 6, 7]; +const SHARDS: &'static [ShardId] = &[ + new_shard_id_tmp(0), + new_shard_id_tmp(1), + new_shard_id_tmp(2), + new_shard_id_tmp(3), + new_shard_id_tmp(4), + new_shard_id_tmp(5), + new_shard_id_tmp(6), + new_shard_id_tmp(7), +]; fn verify_total_supply(total_supply: Balance, chain_id: &str) { if chain_id == near_primitives::chains::MAINNET { diff --git a/genesis-tools/genesis-csv-to-json/src/main.rs b/genesis-tools/genesis-csv-to-json/src/main.rs index 4553970cfe8..7bd9b40f229 100644 --- a/genesis-tools/genesis-csv-to-json/src/main.rs +++ b/genesis-tools/genesis-csv-to-json/src/main.rs @@ -35,7 +35,7 @@ fn main() { if s.is_empty() { HashSet::default() } else { - s.split(',').map(|v| v.parse::<u64>().unwrap()).collect() + s.split(',').map(|v| v.parse::<u64>().unwrap().into()).collect() } } None => HashSet::default(),
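A note on the helpers used throughout this diff: `new_shard_id_tmp` and `shard_id_as_u32` appear to be temporary migration scaffolding. While `ShardId` remains a `u64` alias they are identity conversions; their job is to mark every raw-integer construction or cast site so that the eventual switch to a real newtype is mechanical. A sketch of the assumed semantics (not the exact nearcore definitions):

    // Assumed shape of the helpers while ShardId is still a type alias.
    pub type ShardId = u64;

    // const fn, so it can also build const tables like the SHARDS array above.
    pub const fn new_shard_id_tmp(id: u64) -> ShardId {
        id
    }

    pub const fn shard_id_as_u32(id: ShardId) -> u32 {
        id as u32
    }

    const SHARDS: &[ShardId] = &[new_shard_id_tmp(0), new_shard_id_tmp(1)];

    fn main() {
        assert_eq!(SHARDS[1], 1);
        assert_eq!(shard_id_as_u32(new_shard_id_tmp(7)), 7);
    }

The alias reading is consistent with expressions like `new_shard_id_tmp(shard_id) as u16` appearing later in this diff, which only compile while the helper returns a plain integer.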
diff --git a/genesis-tools/genesis-populate/src/lib.rs b/genesis-tools/genesis-populate/src/lib.rs index 905dc8fa469..ee223645c5b 100644 --- a/genesis-tools/genesis-populate/src/lib.rs +++ b/genesis-tools/genesis-populate/src/lib.rs @@ -18,7 +18,10 @@ use near_primitives::hash::{hash, CryptoHash}; use near_primitives::shard_layout::{account_id_to_shard_id, ShardUId}; use near_primitives::state_record::StateRecord; use near_primitives::types::chunk_extra::ChunkExtra; -use near_primitives::types::{AccountId, Balance, EpochId, ShardId, StateChangeCause, StateRoot}; +use near_primitives::types::{ + new_shard_id_tmp, shard_id_as_u32, AccountId, Balance, EpochId, ShardId, StateChangeCause, + StateRoot, +}; use near_primitives::utils::to_timestamp; use near_primitives::version::ProtocolFeature; use near_store::adapter::StoreUpdateAdapter; @@ -134,15 +137,19 @@ impl GenesisBuilder { let roots = get_genesis_state_roots(self.runtime.store())? .expect("genesis state roots not initialized."); let genesis_shard_version = self.genesis.config.shard_layout.version(); - self.roots = roots.into_iter().enumerate().map(|(k, v)| (k as u64, v)).collect(); + self.roots = + roots.into_iter().enumerate().map(|(k, v)| (new_shard_id_tmp(k as u64), v)).collect(); self.state_updates = self .roots .iter() - .map(|(shard_idx, root)| { + .map(|(&shard_id, root)| { ( - *shard_idx, + shard_id, self.runtime.get_tries().new_trie_update( - ShardUId { version: genesis_shard_version, shard_id: *shard_idx as u32 }, + ShardUId { + version: genesis_shard_version, + shard_id: shard_id_as_u32(shard_id), + }, *root, ), ) @@ -200,7 +207,8 @@ impl GenesisBuilder { state_update.commit(StateChangeCause::InitialState); let (_, trie_changes, state_changes) = state_update.finalize()?; let genesis_shard_version = self.genesis.config.shard_layout.version(); - let shard_uid = ShardUId { version: genesis_shard_version, shard_id: shard_idx as u32 }; + let shard_uid = + ShardUId { version: genesis_shard_version, shard_id: shard_id_as_u32(shard_idx) }; let mut store_update = tries.store_update(); let root = tries.apply_all(&trie_changes, shard_uid, &mut store_update); near_store::flat::FlatStateChanges::from_state_changes(&state_changes) @@ -300,7 +308,7 @@ impl GenesisBuilder { &self, protocol_version: ProtocolVersion, genesis: &Block, - shard_id: u64, + shard_id: ShardId, state_root: CryptoHash, ) -> Result<Option<CongestionInfo>> { if !ProtocolFeature::CongestionControl.enabled(protocol_version) { diff --git a/integration-tests/src/runtime_utils.rs b/integration-tests/src/runtime_utils.rs index 214dc62046a..94bc7c4957b 100644 --- a/integration-tests/src/runtime_utils.rs +++ b/integration-tests/src/runtime_utils.rs @@ -7,8 +7,8 @@ use near_chain_configs::Genesis; use near_parameters::RuntimeConfig; use near_primitives::shard_layout::ShardUId; use near_primitives::state_record::{state_record_to_account_id, StateRecord}; -use near_primitives::types::AccountId; use near_primitives::types::StateRoot; +use near_primitives::types::{new_shard_id_tmp, AccountId}; use near_primitives_core::types::NumShards; use near_store::genesis::GenesisStateApplier; use near_store::test_utils::TestTriesBuilder; @@ -51,7 +51,7 @@ pub fn get_runtime_and_trie_from_genesis(genesis: &Genesis) -> (Runtime, ShardTr let genesis_root = GenesisStateApplier::apply( &writers, tries.clone(), - ShardUId::from_shard_id_and_layout(0, shard_layout), + ShardUId::from_shard_id_and_layout(new_shard_id_tmp(0), shard_layout), &genesis .config .validators diff --git a/integration-tests/src/test_loop/builder.rs b/integration-tests/src/test_loop/builder.rs index 9c9903b1e75..197faf0c31f 100644 --- a/integration-tests/src/test_loop/builder.rs +++ b/integration-tests/src/test_loop/builder.rs @@ -29,7 +29,7 @@ use near_parameters::RuntimeConfigStore; use near_primitives::epoch_manager::EpochConfigStore; use near_primitives::network::PeerId; use near_primitives::test_utils::create_test_signer; -use near_primitives::types::AccountId; +use near_primitives::types::{new_shard_id_tmp, AccountId}; use near_store::adapter::StoreAdapter; use near_store::config::StateSnapshotType; use near_store::genesis::initialize_genesis_state; @@ -285,7 +285,7 @@ impl TestLoopBuilder { if is_validator && !self.track_all_shards { client_config.tracked_shards = Vec::new(); } else { - client_config.tracked_shards = vec![666]; + client_config.tracked_shards = vec![new_shard_id_tmp(666)]; } if let Some(config_modifier) = &self.config_modifier { diff --git
a/integration-tests/src/test_loop/tests/in_memory_tries.rs b/integration-tests/src/test_loop/tests/in_memory_tries.rs index 3d63b5a90f0..6f37d2cf06b 100644 --- a/integration-tests/src/test_loop/tests/in_memory_tries.rs +++ b/integration-tests/src/test_loop/tests/in_memory_tries.rs @@ -3,7 +3,7 @@ use near_async::time::Duration; use near_chain_configs::test_genesis::TestGenesisBuilder; use near_client::test_utils::test_loop::ClientQueries; use near_o11y::testonly::init_test_logger; -use near_primitives::types::AccountId; +use near_primitives::types::{new_shard_id_tmp, AccountId}; use near_store::ShardUId; use crate::test_loop::builder::TestLoopBuilder; @@ -77,7 +77,15 @@ fn test_load_memtrie_after_empty_chunks() { current_tracked_shards .iter() .enumerate() - .find_map(|(idx, shards)| if shards.contains(&0) { Some(idx) } else { None }) + .find_map( + |(idx, shards)| { + if shards.contains(&new_shard_id_tmp(0)) { + Some(idx) + } else { + None + } + }, + ) .expect("Not found any client tracking shard 0") }; @@ -87,11 +95,15 @@ fn test_load_memtrie_after_empty_chunks() { clients[idx] .runtime_adapter .get_tries() - .unload_mem_trie(&ShardUId::from_shard_id_and_layout(0, &shard_layout)); + .unload_mem_trie(&ShardUId::from_shard_id_and_layout(new_shard_id_tmp(0), &shard_layout)); clients[idx] .runtime_adapter .get_tries() - .load_mem_trie(&ShardUId::from_shard_id_and_layout(0, &shard_layout), None, true) + .load_mem_trie( + &ShardUId::from_shard_id_and_layout(new_shard_id_tmp(0), &shard_layout), + None, + true, + ) .expect("Couldn't load memtrie"); // Give the test a chance to finish off remaining events in the event loop, which can diff --git a/integration-tests/src/test_loop/tests/view_requests_to_archival_node.rs b/integration-tests/src/test_loop/tests/view_requests_to_archival_node.rs index fc8531552c9..f2f4a3dc7d1 100644 --- a/integration-tests/src/test_loop/tests/view_requests_to_archival_node.rs +++ b/integration-tests/src/test_loop/tests/view_requests_to_archival_node.rs @@ -15,8 +15,8 @@ use near_network::client::BlockHeadersRequest; use near_o11y::testonly::init_test_logger; use near_primitives::sharding::ChunkHash; use near_primitives::types::{ - AccountId, BlockHeight, BlockId, BlockReference, EpochId, EpochReference, Finality, - SyncCheckpoint, + new_shard_id_tmp, AccountId, BlockHeight, BlockId, BlockReference, EpochId, EpochReference, + Finality, SyncCheckpoint, }; use near_primitives::version::PROTOCOL_VERSION; use near_primitives::views::{ @@ -223,10 +223,10 @@ impl<'a> ViewClientTester<'a> { chunk }; - let chunk_by_height = GetChunk::Height(5, 0); + let chunk_by_height = GetChunk::Height(5, new_shard_id_tmp(0)); get_and_check_chunk(chunk_by_height); - let chunk_by_block_hash = GetChunk::BlockHash(block.header.hash, 0); + let chunk_by_block_hash = GetChunk::BlockHash(block.header.hash, new_shard_id_tmp(0)); get_and_check_chunk(chunk_by_block_hash); let chunk_by_chunk_hash = GetChunk::ChunkHash(ChunkHash(block.chunks[0].chunk_hash)); @@ -242,10 +242,10 @@ impl<'a> ViewClientTester<'a> { assert_eq!(shard_chunk.take_header().gas_limit(), 1_000_000_000_000_000); }; - let chunk_by_height = GetShardChunk::Height(5, 0); + let chunk_by_height = GetShardChunk::Height(5, new_shard_id_tmp(0)); get_and_check_shard_chunk(chunk_by_height); - let chunk_by_block_hash = GetShardChunk::BlockHash(block.header.hash, 0); + let chunk_by_block_hash = GetShardChunk::BlockHash(block.header.hash, new_shard_id_tmp(0)); get_and_check_shard_chunk(chunk_by_block_hash); let chunk_by_chunk_hash = 
GetShardChunk::ChunkHash(ChunkHash(block.chunks[0].chunk_hash)); @@ -376,9 +376,9 @@ impl<'a> ViewClientTester<'a> { let request = GetExecutionOutcomesForBlock { block_hash: block.header.hash }; let outcomes = self.send(request, ARCHIVAL_CLIENT).unwrap(); assert_eq!(outcomes.len(), NUM_SHARDS); - assert_eq!(outcomes[&0].len(), 1); + assert_eq!(outcomes[&new_shard_id_tmp(0)].len(), 1); assert!(matches!( - outcomes[&0][0], + outcomes[&new_shard_id_tmp(0)][0], ExecutionOutcomeWithIdView { outcome: ExecutionOutcomeView { status: ExecutionStatusView::SuccessReceiptId(_), @@ -387,9 +387,9 @@ impl<'a> ViewClientTester<'a> { .. } )); - assert_eq!(outcomes[&1].len(), 1); + assert_eq!(outcomes[&new_shard_id_tmp(1)].len(), 1); assert!(matches!( - outcomes[&1][0], + outcomes[&new_shard_id_tmp(1)][0], ExecutionOutcomeWithIdView { outcome: ExecutionOutcomeView { status: ExecutionStatusView::SuccessReceiptId(_), @@ -398,8 +398,8 @@ impl<'a> ViewClientTester<'a> { .. } )); - assert_eq!(outcomes[&2].len(), 0); - assert_eq!(outcomes[&3].len(), 0); + assert_eq!(outcomes[&new_shard_id_tmp(2)].len(), 0); + assert_eq!(outcomes[&new_shard_id_tmp(3)].len(), 0); } /// Generates variations of the [`GetStateChanges`] request and issues them to the view client of the archival node. diff --git a/integration-tests/src/tests/client/block_corruption.rs b/integration-tests/src/tests/client/block_corruption.rs index d45f9bc579e..e5ffaf651b7 100644 --- a/integration-tests/src/tests/client/block_corruption.rs +++ b/integration-tests/src/tests/client/block_corruption.rs @@ -61,16 +61,17 @@ fn change_shard_id_to_invalid() { let mut block = env.clients[0].produce_block(2).unwrap().unwrap(); // 1. Corrupt chunks + let bad_shard_id = 100; let mut new_chunks = vec![]; for chunk in block.chunks().iter() { let mut new_chunk = chunk.clone(); match &mut new_chunk { - ShardChunkHeader::V1(new_chunk) => new_chunk.inner.shard_id = 100, - ShardChunkHeader::V2(new_chunk) => new_chunk.inner.shard_id = 100, + ShardChunkHeader::V1(new_chunk) => new_chunk.inner.shard_id = bad_shard_id, + ShardChunkHeader::V2(new_chunk) => new_chunk.inner.shard_id = bad_shard_id, ShardChunkHeader::V3(new_chunk) => match &mut new_chunk.inner { - ShardChunkHeaderInner::V1(inner) => inner.shard_id = 100, - ShardChunkHeaderInner::V2(inner) => inner.shard_id = 100, - ShardChunkHeaderInner::V3(inner) => inner.shard_id = 100, + ShardChunkHeaderInner::V1(inner) => inner.shard_id = bad_shard_id, + ShardChunkHeaderInner::V2(inner) => inner.shard_id = bad_shard_id, + ShardChunkHeaderInner::V3(inner) => inner.shard_id = bad_shard_id, }, }; new_chunks.push(new_chunk); @@ -88,7 +89,8 @@ fn change_shard_id_to_invalid() { // Try to process corrupt block and expect code to notice invalid shard_id let res = env.clients[0].process_block_test(block.into(), Provenance::NONE); match res { - Err(Error::InvalidShardId(100)) => { + Err(Error::InvalidShardId(shard_id)) => { + assert_eq!(shard_id, bad_shard_id); tracing::debug!("process failed successfully"); } Err(e) => { diff --git a/integration-tests/src/tests/client/challenges.rs b/integration-tests/src/tests/client/challenges.rs index 7338208fd99..9b1072384ee 100644 --- a/integration-tests/src/tests/client/challenges.rs +++ b/integration-tests/src/tests/client/challenges.rs @@ -22,7 +22,7 @@ use near_primitives::stateless_validation::chunk_endorsement::ChunkEndorsementV1 use near_primitives::test_utils::create_test_signer; use near_primitives::transaction::SignedTransaction; use near_primitives::types::chunk_extra::ChunkExtra; -use 
near_primitives::types::AccountId; +use near_primitives::types::{AccountId, ShardId}; use near_primitives::version::{ProtocolFeature, PROTOCOL_VERSION}; use near_store::Trie; use nearcore::test_utils::TestEnvNightshadeSetupExt; @@ -200,7 +200,7 @@ fn test_verify_chunk_invalid_proofs_challenge() { let shard_id = chunk.shard_id(); let challenge_result = - challenge(env, shard_id as usize, MaybeEncodedShardChunk::Encoded(chunk).into(), &block); + challenge(env, shard_id, MaybeEncodedShardChunk::Encoded(chunk).into(), &block); assert_eq!(challenge_result.unwrap(), (*block.hash(), vec!["test0".parse().unwrap()])); } @@ -215,7 +215,7 @@ fn test_verify_chunk_invalid_proofs_challenge_decoded_chunk() { let shard_id = chunk.shard_id(); let challenge_result = - challenge(env, shard_id as usize, MaybeEncodedShardChunk::Decoded(chunk).into(), &block); + challenge(env, shard_id, MaybeEncodedShardChunk::Decoded(chunk).into(), &block); assert_eq!(challenge_result.unwrap(), (*block.hash(), vec!["test0".parse().unwrap()])); } @@ -228,7 +228,7 @@ fn test_verify_chunk_proofs_malicious_challenge_no_changes() { let shard_id = chunk.shard_id(); let challenge_result = - challenge(env, shard_id as usize, MaybeEncodedShardChunk::Encoded(chunk).into(), &block); + challenge(env, shard_id, MaybeEncodedShardChunk::Encoded(chunk).into(), &block); assert_matches!(challenge_result.unwrap_err(), Error::MaliciousChallenge); } @@ -265,7 +265,7 @@ fn test_verify_chunk_proofs_malicious_challenge_valid_order_transactions() { let shard_id = chunk.shard_id(); let challenge_result = - challenge(env, shard_id as usize, MaybeEncodedShardChunk::Encoded(chunk).into(), &block); + challenge(env, shard_id, MaybeEncodedShardChunk::Encoded(chunk).into(), &block); assert_matches!(challenge_result.unwrap_err(), Error::MaliciousChallenge); } @@ -302,13 +302,13 @@ fn test_verify_chunk_proofs_challenge_transaction_order() { let shard_id = chunk.shard_id(); let challenge_result = - challenge(env, shard_id as usize, MaybeEncodedShardChunk::Encoded(chunk).into(), &block); + challenge(env, shard_id, MaybeEncodedShardChunk::Encoded(chunk).into(), &block); assert_eq!(challenge_result.unwrap(), (*block.hash(), vec!["test0".parse().unwrap()])); } fn challenge( env: TestEnv, - shard_id: usize, + shard_id: ShardId, chunk: Box<MaybeEncodedShardChunk>, block: &Block, ) -> Result<(CryptoHash, Vec<AccountId>), Error> { @@ -317,7 +317,7 @@ fn challenge( ChallengeBody::ChunkProofs(ChunkProofs { block_header: borsh::to_vec(&block.header()).unwrap(), chunk, - merkle_proof: merkle_paths[shard_id].clone(), + merkle_proof: merkle_paths[shard_id as usize].clone(), }), &*env.clients[0].validator_signer.get().unwrap(), );
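The `challenge` helper above still keeps one `shard_id as usize` index into `merkle_paths`, which is only sound while shard ids coincide with positions (a block's merkle paths are ordered by position). A small self-contained illustration of why such casts break once ids are non-contiguous (toy data; nearcore keeps the real mapping inside `ShardLayout`):

    fn main() {
        // Shard ids after a hypothetical resharding: id 8 sits at position 1.
        let shard_ids: Vec<u64> = vec![3, 8, 4, 7];
        let merkle_paths = vec!["path0", "path1", "path2", "path3"]; // one per position

        let shard_id: u64 = 8;
        // `merkle_paths[shard_id as usize]` would panic here (index 8 of 4).
        // The id must first be translated to its index in the layout:
        let shard_index = shard_ids.iter().position(|&id| id == shard_id).unwrap();
        assert_eq!(merkle_paths[shard_index], "path1");
    }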
diff --git a/nearcore/src/config.rs b/nearcore/src/config.rs index 57a2d8d1fac..30a2607adf1 100644 --- a/nearcore/src/config.rs +++ b/nearcore/src/config.rs @@ -40,8 +40,8 @@ use near_primitives::hash::CryptoHash; use near_primitives::shard_layout::ShardLayout; use near_primitives::test_utils::create_test_signer; use near_primitives::types::{ - AccountId, AccountInfo, Balance, BlockHeight, BlockHeightDelta, Gas, NumSeats, NumShards, - ShardId, + new_shard_id_tmp, AccountId, AccountInfo, Balance, BlockHeight, BlockHeightDelta, Gas, + NumSeats, NumShards, ShardId, }; use near_primitives::utils::{from_timestamp, get_num_seats_per_shard}; use near_primitives::validator_signer::{InMemoryValidatorSigner, ValidatorSigner}; @@ -845,7 +845,7 @@ pub fn init_configs( let mut config = Config::default(); // Make sure node tracks all shards, see // https://github.com/near/nearcore/issues/7388 - config.tracked_shards = vec![0]; + config.tracked_shards = vec![new_shard_id_tmp(0)]; // If a config gets generated, block production times may need to be updated. set_block_production_delay(&chain_id, fast, &mut config); @@ -1073,7 +1073,7 @@ pub fn create_localnet_configs_from_seeds( num_non_validators_archival: NumSeats, num_non_validators_rpc: NumSeats, num_non_validators: NumSeats, - tracked_shards: Vec<u64>, + tracked_shards: Vec<ShardId>, ) -> (Vec<Config>, Vec<InMemoryValidatorSigner>, Vec<InMemorySigner>, Genesis) { assert_eq!( seeds.len() as u64, @@ -1163,7 +1163,7 @@ pub fn create_localnet_configs_from_seeds( fn create_localnet_config( num_shards: NumShards, num_validators: NumSeats, - tracked_shards: &Vec<u64>, + tracked_shards: &Vec<ShardId>, network_signers: &Vec<InMemorySigner>, boot_node_addr: &tcp::ListenerAddr, params: LocalnetNodeParams, @@ -1204,7 +1204,7 @@ fn create_localnet_config( // Make non-validator archival and RPC nodes track all shards. // Note that validator nodes may track all or some of the shards. config.tracked_shards = if !params.is_validator && (params.is_archival || params.is_rpc) { - (0..num_shards).collect() + (0..num_shards).map(new_shard_id_tmp).collect() } else { tracked_shards.clone() }; @@ -1232,7 +1232,7 @@ pub fn create_localnet_configs( num_non_validators_rpc: NumSeats, num_non_validators: NumSeats, prefix: &str, - tracked_shards: Vec<u64>, + tracked_shards: Vec<ShardId>, ) -> (Vec<Config>, Vec<InMemoryValidatorSigner>, Vec<InMemorySigner>, Genesis, Vec<InMemorySigner>) { let num_all_nodes = num_validators + num_non_validators_archival + num_non_validators_rpc + num_non_validators; @@ -1272,7 +1272,7 @@ pub fn init_localnet_configs( num_non_validators_rpc: NumSeats, num_non_validators: NumSeats, prefix: &str, - tracked_shards: Vec<u64>, + tracked_shards: Vec<ShardId>, ) { let (configs, validator_signers, network_signers, genesis, shard_keys) = create_localnet_configs( @@ -1522,7 +1522,7 @@ mod tests { use near_chain_configs::{GCConfig, Genesis, GenesisValidationMode}; use near_crypto::InMemorySigner; use near_primitives::shard_layout::account_id_to_shard_id; - use near_primitives::types::{AccountId, NumShards}; + use near_primitives::types::{new_shard_id_tmp, AccountId, NumShards}; use tempfile::tempdir; use crate::config::{ @@ -1562,21 +1562,21 @@ mod tests { &AccountId::from_str("foobar.near").unwrap(), &genesis.config.shard_layout, ), - 0 + new_shard_id_tmp(0) ); assert_eq!( account_id_to_shard_id( &AccountId::from_str("shard1.test.near").unwrap(), &genesis.config.shard_layout, ), - 1 + new_shard_id_tmp(1) ); assert_eq!( account_id_to_shard_id( &AccountId::from_str("shard2.test.near").unwrap(), &genesis.config.shard_layout, ), - 2 + new_shard_id_tmp(2) ); } @@ -1704,7 +1704,7 @@ mod tests { let prefix = "node"; // Validators will track single shard but archival and RPC nodes will track all shards. - let empty_tracked_shards: Vec<u64> = vec![]; + let empty_tracked_shards = vec![]; let (configs, _validator_signers, _network_signers, genesis, _shard_keys) = create_localnet_configs( @@ -1746,7 +1746,10 @@ mod tests { config.split_storage.clone().unwrap().enable_split_storage_view_client, true ); - assert_eq!(config.tracked_shards, (0..num_shards).collect::<Vec<_>>()); + assert_eq!( + config.tracked_shards, + (0..num_shards).map(new_shard_id_tmp).collect::<Vec<_>>() + ); } // Check non-validator RPC nodes.
@@ -1755,7 +1758,10 @@ mod tests { assert_eq!(config.archive, false); assert!(config.cold_store.is_none()); assert!(config.split_storage.is_none()); - assert_eq!(config.tracked_shards, (0..num_shards).collect::<Vec<_>>()); + assert_eq!( + config.tracked_shards, + (0..num_shards).map(new_shard_id_tmp).collect::<Vec<_>>() + ); } // Check other non-validator nodes. @@ -1781,7 +1787,7 @@ mod tests { let prefix = "node"; // Validators will track 2 shards and non-validators will track all shards. - let tracked_shards: Vec<u64> = vec![1, 3]; + let tracked_shards = vec![new_shard_id_tmp(1), new_shard_id_tmp(3)]; let (configs, _validator_signers, _network_signers, genesis, _shard_keys) = create_localnet_configs( @@ -1823,7 +1829,10 @@ mod tests { config.split_storage.clone().unwrap().enable_split_storage_view_client, true ); - assert_eq!(config.tracked_shards, (0..num_shards).collect::<Vec<_>>()); + assert_eq!( + config.tracked_shards, + (0..num_shards).map(new_shard_id_tmp).collect::<Vec<_>>() + ); } // Check non-validator RPC nodes. @@ -1832,7 +1841,10 @@ mod tests { assert_eq!(config.archive, false); assert!(config.cold_store.is_none()); assert!(config.split_storage.is_none()); - assert_eq!(config.tracked_shards, (0..num_shards).collect::<Vec<_>>()); + assert_eq!( + config.tracked_shards, + (0..num_shards).map(new_shard_id_tmp).collect::<Vec<_>>() + ); } // Check other non-validator nodes. diff --git a/nearcore/src/config_validate.rs b/nearcore/src/config_validate.rs index 3db950d9607..a554563680a 100644 --- a/nearcore/src/config_validate.rs +++ b/nearcore/src/config_validate.rs @@ -171,6 +171,8 @@ impl<'a> ConfigValidator<'a> { #[cfg(test)] mod tests { + use near_primitives::types::new_shard_id_tmp; + use super::*; #[test] @@ -179,7 +181,7 @@ mod tests { let mut config = Config::default(); config.gc.gc_blocks_limit = 0; // set tracked_shards to be non-empty - config.tracked_shards.push(20); + config.tracked_shards.push(new_shard_id_tmp(20)); validate_config(&config).unwrap(); } @@ -192,7 +194,7 @@ mod tests { config.archive = false; config.save_trie_changes = Some(false); // set tracked_shards to be non-empty - config.tracked_shards.push(20); + config.tracked_shards.push(new_shard_id_tmp(20)); validate_config(&config).unwrap(); } @@ -206,7 +208,7 @@ mod tests { config.save_trie_changes = Some(false); config.gc.gc_blocks_limit = 0; // set tracked_shards to be non-empty - config.tracked_shards.push(20); + config.tracked_shards.push(new_shard_id_tmp(20)); validate_config(&config).unwrap(); } diff --git a/nearcore/src/entity_debug.rs b/nearcore/src/entity_debug.rs index 6f297abced2..ec633b4f3b7 100644 --- a/nearcore/src/entity_debug.rs +++ b/nearcore/src/entity_debug.rs @@ -261,9 +261,11 @@ impl EntityDebugHandlerImpl { let shard_layout = self .epoch_manager .get_shard_layout_from_prev_block(&chunk.cloned_header().prev_block_hash())?; + let shard_id = chunk.shard_id(); + let shard_index = shard_layout.get_shard_index(shard_id); let shard_uid = shard_layout .shard_uids() - .nth(chunk.shard_id() as usize) + .nth(shard_index) .ok_or_else(|| anyhow!("Shard {} not found", chunk.shard_id()))?; let node = store .get_ser::( @@ -299,9 +301,11 @@ } EntityQuery::ShardUIdByShardId { shard_id, epoch_id } => { let shard_layout = self.epoch_manager.get_shard_layout(&epoch_id)?; + let shard_index = shard_layout.get_shard_index(shard_id); + let shard_uid = shard_layout .shard_uids() - .nth(shard_id as usize) + .nth(shard_index) .ok_or_else(|| anyhow!("Shard {} not found", shard_id))?; Ok(serialize_entity(&shard_uid)) } @@ -380,9
+384,11 @@ impl EntityDebugHandlerImpl { let shard_layout = self .epoch_manager .get_shard_layout_from_prev_block(&chunk.cloned_header().prev_block_hash())?; + let shard_id = chunk.shard_id(); + let shard_index = shard_layout.get_shard_index(shard_id); let shard_uid = shard_layout .shard_uids() - .nth(chunk.shard_id() as usize) + .nth(shard_index) .ok_or_else(|| anyhow!("Shard {} not found", chunk.shard_id()))?; let path = TriePath { path: vec![], shard_uid, state_root: chunk.prev_state_root() }; diff --git a/nearcore/src/metrics.rs b/nearcore/src/metrics.rs index f8cb0d85362..a8db65cc675 100644 --- a/nearcore/src/metrics.rs +++ b/nearcore/src/metrics.rs @@ -8,6 +8,7 @@ use near_o11y::metrics::{ try_create_int_gauge, try_create_int_gauge_vec, HistogramVec, IntCounterVec, IntGauge, IntGaugeVec, }; +use near_primitives::types::ShardId; use near_primitives::{shard_layout::ShardLayout, state_record::StateRecord, trie_key}; use near_store::adapter::StoreAdapter; use near_store::{ShardUId, Store, Trie, TrieDBStorage}; @@ -148,7 +149,7 @@ fn export_postponed_receipt_count(near_config: &NearConfig, store: &Store) -> an } fn get_postponed_receipt_count_for_shard( - shard_id: u64, + shard_id: ShardId, shard_layout: &ShardLayout, chain_store: &ChainStore, block: &Block, diff --git a/nearcore/src/state_sync.rs b/nearcore/src/state_sync.rs index 73c3ed5feff..75ee9e2b11e 100644 --- a/nearcore/src/state_sync.rs +++ b/nearcore/src/state_sync.rs @@ -261,17 +261,17 @@ fn get_current_state( epoch_height: new_epoch_height, sync_hash: new_sync_hash, } = latest_epoch_info.map_err(|err| { - tracing::error!(target: "state_sync_dump", shard_id, ?err, "Failed to get the latest epoch"); + tracing::error!(target: "state_sync_dump", ?shard_id, ?err, "Failed to get the latest epoch"); err })?; if Some(&new_epoch_id) == was_last_epoch_done.as_ref() { - tracing::debug!(target: "state_sync_dump", shard_id, ?was_last_epoch_done, ?new_epoch_id, new_epoch_height, ?new_sync_hash, "latest epoch is done. No new epoch to dump. Idle"); + tracing::debug!(target: "state_sync_dump", ?shard_id, ?was_last_epoch_done, ?new_epoch_id, new_epoch_height, ?new_sync_hash, "latest epoch is done. No new epoch to dump. Idle"); Ok(StateDumpAction::Wait) } else if epoch_manager.get_shard_layout(&prev_epoch_id) != epoch_manager.get_shard_layout(&new_epoch_id) { - tracing::debug!(target: "state_sync_dump", shard_id, ?was_last_epoch_done, ?new_epoch_id, new_epoch_height, ?new_sync_hash, "Shard layout change detected, will skip dumping for this epoch. Idle"); + tracing::debug!(target: "state_sync_dump", ?shard_id, ?was_last_epoch_done, ?new_epoch_id, new_epoch_height, ?new_sync_hash, "Shard layout change detected, will skip dumping for this epoch. Idle"); chain.chain_store().set_state_sync_dump_progress( *shard_id, Some(StateSyncDumpProgress::Skipped { @@ -287,7 +287,7 @@ fn get_current_state( sync_hash: new_sync_hash, }) } else { - tracing::debug!(target: "state_sync_dump", shard_id, ?new_epoch_id, new_epoch_height, ?new_sync_hash, "Doesn't care about the shard in the current epoch. Idle"); + tracing::debug!(target: "state_sync_dump", ?shard_id, ?new_epoch_id, new_epoch_height, ?new_sync_hash, "Doesn't care about the shard in the current epoch. 
Idle"); Ok(StateDumpAction::Wait) } } @@ -313,11 +313,11 @@ async fn upload_state_header( external_storage_location(&chain_id, &epoch_id, epoch_height, shard_id, &file_type); match external.put_file(file_type, &header, shard_id, &location).await { Err(err) => { - tracing::warn!(target: "state_sync_dump", shard_id, epoch_height, ?err, "Failed to put header into external storage. Will retry next iteration."); + tracing::warn!(target: "state_sync_dump", ?shard_id, epoch_height, ?err, "Failed to put header into external storage. Will retry next iteration."); false } Ok(_) => { - tracing::trace!(target: "state_sync_dump", shard_id, epoch_height, "Header saved to external storage."); + tracing::trace!(target: "state_sync_dump", ?shard_id, epoch_height, "Header saved to external storage."); true } } @@ -341,17 +341,17 @@ async fn state_sync_dump( validator: MutableValidatorSigner, keep_running: Arc, ) { - tracing::info!(target: "state_sync_dump", shard_id, "Running StateSyncDump loop"); + tracing::info!(target: "state_sync_dump", ?shard_id, "Running StateSyncDump loop"); if restart_dump_for_shards.contains(&shard_id) { - tracing::debug!(target: "state_sync_dump", shard_id, "Dropped existing progress"); + tracing::debug!(target: "state_sync_dump", ?shard_id, "Dropped existing progress"); chain.chain_store().set_state_sync_dump_progress(shard_id, None).unwrap(); } // Stop if the node is stopped. // Note that without this check the state dumping thread is unstoppable, i.e. non-interruptable. while keep_running.load(std::sync::atomic::Ordering::Relaxed) { - tracing::debug!(target: "state_sync_dump", shard_id, "Running StateSyncDump loop iteration"); + tracing::debug!(target: "state_sync_dump", ?shard_id, "Running StateSyncDump loop iteration"); let account_id = validator.get().map(|v| v.validator_id().clone()); let current_state = get_current_state( &chain, @@ -370,7 +370,7 @@ async fn state_sync_dump( let in_progress_data = get_in_progress_data(shard_id, sync_hash, &chain); match in_progress_data { Err(err) => { - tracing::error!(target: "state_sync_dump", ?err, ? shard_id, "Failed to get in progress data"); + tracing::error!(target: "state_sync_dump", ?err, ?shard_id, "Failed to get in progress data"); None } Ok((state_root, num_parts, sync_prev_prev_hash)) => { @@ -472,7 +472,7 @@ async fn state_sync_dump( let state_part = match state_part { Ok(state_part) => state_part, Err(err) => { - tracing::warn!(target: "state_sync_dump", shard_id, epoch_height, part_id, ?err, "Failed to obtain and store part. Will skip this part."); + tracing::warn!(target: "state_sync_dump", ?shard_id, epoch_height, part_id, ?err, "Failed to obtain and store part. Will skip this part."); failures_cnt += 1; continue; } @@ -492,7 +492,7 @@ async fn state_sync_dump( { // no need to break if there's an error, we should keep dumping other parts. // reason is we are dumping random selected parts, so it's fine if we are not able to finish all of them - tracing::warn!(target: "state_sync_dump", shard_id, epoch_height, part_id, ?err, "Failed to put a store part into external storage. Will skip this part."); + tracing::warn!(target: "state_sync_dump", ?shard_id, epoch_height, part_id, ?err, "Failed to put a store part into external storage. Will skip this part."); failures_cnt += 1; continue; } @@ -540,19 +540,19 @@ async fn state_sync_dump( // Record the next state of the state machine. 
diff --git a/runtime/runtime/src/balance_checker.rs b/runtime/runtime/src/balance_checker.rs index 9bdbdc67f99..2baeabfadff 100644 --- a/runtime/runtime/src/balance_checker.rs +++ b/runtime/runtime/src/balance_checker.rs @@ -12,7 +12,7 @@ use near_primitives::hash::CryptoHash; use near_primitives::receipt::{Receipt, ReceiptEnum, ReceiptOrStateStoredReceipt}; use near_primitives::transaction::SignedTransaction; use near_primitives::trie_key::TrieKey; -use near_primitives::types::{AccountId, Balance}; +use near_primitives::types::{AccountId, Balance, ShardId}; use near_store::trie::receipts_column_helper::{ShardsOutgoingReceiptBuffer, TrieQueue}; use near_store::{ get, get_account, get_postponed_receipt, get_promise_yield_receipt, Trie, TrieAccess, @@ -141,7 +141,7 @@ fn buffered_receipts( let mut forwarded_receipts: Vec<Receipt> = vec![]; let mut new_buffered_receipts: Vec<Receipt> = vec![]; - let mut shards: BTreeSet<u64> = BTreeSet::new(); + let mut shards: BTreeSet<ShardId> = BTreeSet::new(); shards.extend(initial_buffers.shards().iter()); shards.extend(final_buffers.shards().iter()); for shard_id in shards { @@ -400,7 +400,7 @@ mod tests { }; use near_primitives::test_utils::account_new; use near_primitives::transaction::{Action, TransferAction}; - use near_primitives::types::{MerkleHash, StateChangeCause}; + use near_primitives::types::{new_shard_id_tmp, MerkleHash, StateChangeCause}; use near_store::test_utils::TestTriesBuilder; use near_store::{set, set_account, Trie}; use testlib::runtime_utils::{alice_account, bob_account}; @@ -706,14 +706,15 @@ mod tests { // create buffer with already a receipt in it, but a different balance let mut indices = BufferedReceiptIndices::default(); - indices - .shard_buffers - .insert(0, TrieQueueIndices { first_index: 0, next_available_index: 1 }); + indices.shard_buffers.insert( + new_shard_id_tmp(0), + TrieQueueIndices { first_index: 0, next_available_index: 1 }, + ); set(trie_update, TrieKey::BufferedReceiptIndices, &indices); set( trie_update, - TrieKey::BufferedReceipt { receiving_shard: 0, index: 0 }, + TrieKey::BufferedReceipt {
receiving_shard: new_shard_id_tmp(0), index: 0 }, &existing_receipt, ); }, @@ -727,14 +728,15 @@ mod tests { // store receipt with the balance in the receipt buffer let mut indices = BufferedReceiptIndices::default(); - indices - .shard_buffers - .insert(0, TrieQueueIndices { first_index: 0, next_available_index: 2 }); + indices.shard_buffers.insert( + new_shard_id_tmp(0), + TrieQueueIndices { first_index: 0, next_available_index: 2 }, + ); set(trie_update, TrieKey::BufferedReceiptIndices, &indices); set( trie_update, - TrieKey::BufferedReceipt { receiving_shard: 0, index: 1 }, + TrieKey::BufferedReceipt { receiving_shard: new_shard_id_tmp(0), index: 1 }, &new_receipt, ); }, @@ -776,31 +778,36 @@ mod tests { |trie_update| { // store 2 receipts with balance in the receipt buffer let mut indices = BufferedReceiptIndices::default(); - indices - .shard_buffers - .insert(0, TrieQueueIndices { first_index: 0, next_available_index: 2 }); + indices.shard_buffers.insert( + new_shard_id_tmp(0), + TrieQueueIndices { first_index: 0, next_available_index: 2 }, + ); set(trie_update, TrieKey::BufferedReceiptIndices, &indices); set( trie_update, - TrieKey::BufferedReceipt { receiving_shard: 0, index: 0 }, + TrieKey::BufferedReceipt { receiving_shard: new_shard_id_tmp(0), index: 0 }, &receipt0, ); set( trie_update, - TrieKey::BufferedReceipt { receiving_shard: 0, index: 1 }, + TrieKey::BufferedReceipt { receiving_shard: new_shard_id_tmp(0), index: 1 }, &receipt1, ); }, |trie_update| { // remove 1 receipt at index 0 let mut indices = BufferedReceiptIndices::default(); - indices - .shard_buffers - .insert(0, TrieQueueIndices { first_index: 1, next_available_index: 2 }); + indices.shard_buffers.insert( + new_shard_id_tmp(0), + TrieQueueIndices { first_index: 1, next_available_index: 2 }, + ); set(trie_update, TrieKey::BufferedReceiptIndices, &indices); - trie_update.remove(TrieKey::BufferedReceipt { receiving_shard: 0, index: 0 }); + trie_update.remove(TrieKey::BufferedReceipt { + receiving_shard: new_shard_id_tmp(0), + index: 0, + }); }, ); @@ -834,31 +841,36 @@ mod tests { |trie_update| { // store receipt0 with balance in the receipt buffer let mut indices = BufferedReceiptIndices::default(); - indices - .shard_buffers - .insert(0, TrieQueueIndices { first_index: 0, next_available_index: 1 }); + indices.shard_buffers.insert( + new_shard_id_tmp(0), + TrieQueueIndices { first_index: 0, next_available_index: 1 }, + ); set(trie_update, TrieKey::BufferedReceiptIndices, &indices); set( trie_update, - TrieKey::BufferedReceipt { receiving_shard: 0, index: 0 }, + TrieKey::BufferedReceipt { receiving_shard: new_shard_id_tmp(0), index: 0 }, &receipt0, ); }, |trie_update| { // pop receipt0 and push receipt1 with a different balance let mut indices = BufferedReceiptIndices::default(); - indices - .shard_buffers - .insert(0, TrieQueueIndices { first_index: 1, next_available_index: 2 }); + indices.shard_buffers.insert( + new_shard_id_tmp(0), + TrieQueueIndices { first_index: 1, next_available_index: 2 }, + ); set( trie_update, - TrieKey::BufferedReceipt { receiving_shard: 0, index: 1 }, + TrieKey::BufferedReceipt { receiving_shard: new_shard_id_tmp(0), index: 1 }, &receipt1, ); set(trie_update, TrieKey::BufferedReceiptIndices, &indices); - trie_update.remove(TrieKey::BufferedReceipt { receiving_shard: 0, index: 0 }); + trie_update.remove(TrieKey::BufferedReceipt { + receiving_shard: new_shard_id_tmp(0), + index: 0, + }); }, ); diff --git a/runtime/runtime/src/congestion_control.rs 
b/runtime/runtime/src/congestion_control.rs index dbbd1228a20..6f0b356e72b 100644 --- a/runtime/runtime/src/congestion_control.rs +++ b/runtime/runtime/src/congestion_control.rs @@ -9,7 +9,7 @@ use near_primitives::receipt::{ Receipt, ReceiptEnum, ReceiptOrStateStoredReceipt, StateStoredReceipt, StateStoredReceiptMetadata, }; -use near_primitives::types::{EpochInfoProvider, Gas, ShardId}; +use near_primitives::types::{new_shard_id_tmp, EpochInfoProvider, Gas, ShardId}; use near_primitives::version::ProtocolFeature; use near_store::trie::receipts_column_helper::{ DelayedReceiptQueue, ReceiptIterator, ShardsOutgoingReceiptBuffer, TrieQueue, @@ -193,7 +193,7 @@ impl ReceiptSinkV2<'_> { fn forward_from_buffer_to_shard( &mut self, - shard_id: u64, + shard_id: ShardId, state_update: &mut TrieUpdate, apply_state: &ApplyState, ) -> Result<(), RuntimeError> { @@ -317,7 +317,7 @@ impl ReceiptSinkV2<'_> { size: u64, gas: u64, state_update: &mut TrieUpdate, - shard: u64, + shard: ShardId, use_state_stored_receipt: bool, ) -> Result<(), RuntimeError> { let receipt = match use_state_stored_receipt { @@ -460,7 +460,7 @@ pub fn bootstrap_congestion_info( // It is also irrelevant, since the bootstrapped value is only used at // the start of applying a chunk on this shard. Other shards will only // see and act on the first congestion info after that. - allowed_shard: shard_id as u16, + allowed_shard: new_shard_id_tmp(shard_id) as u16, })) } diff --git a/runtime/runtime/src/lib.rs b/runtime/runtime/src/lib.rs index 5b9dde14a90..1dba5601085 100644 --- a/runtime/runtime/src/lib.rs +++ b/runtime/runtime/src/lib.rs @@ -39,6 +39,7 @@ use near_primitives::transaction::{ SignedTransaction, TransferAction, }; use near_primitives::trie_key::TrieKey; +use near_primitives::types::new_shard_id_tmp; use near_primitives::types::{ validator_stake::ValidatorStake, AccountId, Balance, BlockHeight, Compute, EpochHeight, EpochId, EpochInfoProvider, Gas, RawStateChangesWithTrieKey, ShardId, StateChangeCause, @@ -1351,7 +1352,7 @@ impl Runtime { { // Note that receipts are restored only on mainnet so restored_receipts will be empty on // other chains. - migration_data.restored_receipts.get(&0u64).cloned().unwrap_or_default() + migration_data.restored_receipts.get(&new_shard_id_tmp(0)).cloned().unwrap_or_default() } else { vec![] }; @@ -2020,7 +2021,10 @@ impl Runtime { delayed_receipts.apply_congestion_changes(congestion_info)?; let all_shards = apply_state.congestion_info.all_shards(); - let congestion_seed = apply_state.block_height.wrapping_add(apply_state.shard_id); + // TODO(wacban) Using non-contiguous shard id here breaks some + // assumptions. The shard index should be used here instead. + let congestion_seed = + apply_state.block_height.wrapping_add(apply_state.shard_id.into()); congestion_info.finalize_allowed_shard( apply_state.shard_id, all_shards.as_slice(), diff --git a/runtime/runtime/src/metrics.rs b/runtime/runtime/src/metrics.rs index 182e8304bc1..7afd7a1f489 100644 --- a/runtime/runtime/src/metrics.rs +++ b/runtime/runtime/src/metrics.rs @@ -800,7 +800,7 @@ pub fn report_recorded_column_sizes(trie: &Trie, apply_state: &ApplyState) { // Tracing span to measure time spent on reporting column sizes. 
let _span = tracing::debug_span!( target: "runtime", "report_recorded_column_sizes", - shard_id = apply_state.shard_id, + shard_id = ?apply_state.shard_id, block_height = apply_state.block_height) .entered(); diff --git a/tools/state-viewer/src/epoch_info.rs b/tools/state-viewer/src/epoch_info.rs index ac75f749070..cd29783e481 100644 --- a/tools/state-viewer/src/epoch_info.rs +++ b/tools/state-viewer/src/epoch_info.rs @@ -90,13 +90,16 @@ fn display_block_and_chunk_producers( let block_height_range: Range<BlockHeight> = get_block_height_range(epoch_id, chain_store, epoch_manager)?; let shard_ids = epoch_manager.shard_ids(epoch_id).unwrap(); + let shard_layout = epoch_manager.get_shard_layout(epoch_id).unwrap(); for block_height in block_height_range { let bp = epoch_info.sample_block_producer(block_height); let bp = epoch_info.get_validator(bp).account_id().clone(); let cps: Vec<String> = shard_ids .iter() .map(|&shard_id| { - let cp = epoch_info.sample_chunk_producer(block_height, shard_id).unwrap(); + let cp = epoch_info + .sample_chunk_producer(&shard_layout, shard_id, block_height) + .unwrap(); let cp = epoch_info.get_validator(cp).account_id().clone(); cp.as_str().to_string() }) @@ -274,13 +277,14 @@ fn display_validator_info( println!("Block producer for {} blocks: {bp_for_blocks:?}", bp_for_blocks.len()); let shard_ids = epoch_manager.shard_ids(epoch_id).unwrap(); + let shard_layout = epoch_manager.get_shard_layout(epoch_id).unwrap(); let cp_for_chunks: Vec<(BlockHeight, ShardId)> = block_height_range .flat_map(|block_height| { shard_ids .iter() .map(|&shard_id| (block_height, shard_id)) .filter(|&(block_height, shard_id)| { - epoch_info.sample_chunk_producer(block_height, shard_id) + epoch_info.sample_chunk_producer(&shard_layout, shard_id, block_height) == Some(*validator_id) }) .collect::<Vec<_>>() diff --git a/tools/state-viewer/src/replay_headers.rs b/tools/state-viewer/src/replay_headers.rs index 58c713b9dfa..c830b54aa4b 100644 --- a/tools/state-viewer/src/replay_headers.rs +++ b/tools/state-viewer/src/replay_headers.rs @@ -228,6 +228,8 @@ fn get_block_info( { let block = chain_store.get_block(header.hash())?; let chunks = block.chunks(); + let epoch_id = block.header().epoch_id(); + let shard_layout = epoch_manager.get_shard_layout(epoch_id)?; let endorsement_signatures = block.chunk_endorsements().to_vec(); assert_eq!(endorsement_signatures.len(), chunks.len()); @@ -237,12 +239,12 @@ fn get_block_info( let height = header.height(); let prev_block_epoch_id = epoch_manager.get_epoch_id_from_prev_block(header.prev_hash())?; - for chunk_header in chunks.iter() { - let shard_id = chunk_header.shard_id(); + for (shard_index, chunk_header) in chunks.iter().enumerate() { + let shard_id = shard_layout.get_shard_id(shard_index); let endorsements = &endorsement_signatures[shard_id as usize]; if !chunk_header.is_new_chunk(height) { assert_eq!(endorsements.len(), 0); - bitmap.add_endorsements(shard_id, vec![]); + bitmap.add_endorsements(shard_index, vec![]); } else { let assignments = epoch_manager .get_chunk_validator_assignments( @@ -253,7 +255,7 @@ fn get_block_info( .ordered_chunk_validators(); assert_eq!(endorsements.len(), assignments.len()); bitmap.add_endorsements( - shard_id, + shard_index, endorsements.iter().map(|signature| signature.is_some()).collect_vec(), );
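Taken together, these changes all lean on the `id_to_index_map` / `index_to_id_map` pair added to the `ShardLayout` JSON earlier in this diff: wherever an array is ordered by position, the code now goes through an explicit id-to-index translation instead of an `as usize` cast. A compact model of that mapping, mirroring the JSON fixture's values (toy code; the real `ShardLayout` also tracks versions, boundary accounts, and split maps):

    use std::collections::BTreeMap;

    #[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Debug)]
    struct ShardId(u64);

    struct ShardLayoutV2 {
        shard_ids: Vec<ShardId>,               // index -> id
        id_to_index: BTreeMap<ShardId, usize>, // id -> index
    }

    impl ShardLayoutV2 {
        fn new(ids: &[u64]) -> Self {
            let shard_ids: Vec<ShardId> = ids.iter().copied().map(ShardId).collect();
            let id_to_index =
                shard_ids.iter().enumerate().map(|(i, &id)| (id, i)).collect();
            Self { shard_ids, id_to_index }
        }
        fn get_shard_id(&self, shard_index: usize) -> ShardId {
            self.shard_ids[shard_index]
        }
        fn get_shard_index(&self, shard_id: ShardId) -> usize {
            self.id_to_index[&shard_id]
        }
    }

    fn main() {
        // Mirrors the JSON fixture: index 2 holds id 6, and id 3 lives at index 4.
        let layout = ShardLayoutV2::new(&[0, 1, 6, 7, 3, 4, 5]);
        assert_eq!(layout.get_shard_id(2), ShardId(6));
        assert_eq!(layout.get_shard_index(ShardId(3)), 4);
    }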