Skip to content

Commit

Permalink
feat(resharding) - Make shard ids non-contiguous (near#12181)
Browse files Browse the repository at this point in the history
This is part 1 of adding support for non-contiguous shard ids. The
principle idea is to make ShardId into a newtype so that it's not
possible to use it to index arrays with chunk data. In addition I'm
adding ShardIndex type and a mapping between shard indices and shard ids
so that it's possible to covert one to another as necessary.

The TLDR of this approach is to make the types right, fix compiler
errors and pray to the software gods that things work out.

I am now giving up on trying to make the migration in a single PR.
Instead I am introducing some temporary structures and methods that are
compatible with both approaches. My current plan for the migration is as
follows:
1) Switch to the new ShardId definition. 
2) Fix some number of compilation errors (using the temporary objects) 
3) Switch back to the old definition 
4) PR, review, merge
5) Repeat 1-4 until there are no more errors. 
6) Cleanup the temporary objects
7) Adjust some tests to use the new ShardLayout with non-contiguous
shard ids.
8) Try to get rid of the mapping wherever possible


There are a few common themes in this PR:
* read the shard layout and convert shard id to shard index in order to
use it to index some array or chunk data
* replace enumerate with reading the shard id directly from the chunk
header / other chunk data
* replace using shard id by adding enumerate to get the shard index
* add `?` to shard id in tracing logs because the newtype ShardId
doesn't work without it

must-review files:
* shard_layout.rs
* primitives-core/src/types.rs
* shard_assignment.rs

good-to-review files:
* state_transition_data.rs
  • Loading branch information
wacban authored Oct 10, 2024
1 parent 2d5dd96 commit 99ecfa4
Show file tree
Hide file tree
Showing 105 changed files with 1,593 additions and 893 deletions.
2 changes: 1 addition & 1 deletion chain/chain-primitives/src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,7 @@ pub enum Error {
InvalidBlockMerkleRoot,
/// Invalid split shard ids.
#[error("Invalid Split Shard Ids when resharding. shard_id: {0}, parent_shard_id: {1}")]
InvalidSplitShardsIds(u64, u64),
InvalidSplitShardsIds(ShardId, ShardId),
/// Someone is not a validator. Usually happens in signature verification
#[error("Not A Validator: {0}")]
NotAValidator(String),
Expand Down
13 changes: 10 additions & 3 deletions chain/chain/src/blocks_delay_tracker.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ use near_async::time::{Clock, Instant, Utc};
use near_epoch_manager::EpochManagerAdapter;
use near_primitives::block::{Block, Tip};
use near_primitives::hash::CryptoHash;
use near_primitives::shard_layout::ShardLayout;
use near_primitives::sharding::{ChunkHash, ShardChunkHeader};
use near_primitives::types::{BlockHeight, ShardId};
use near_primitives::views::{
Expand Down Expand Up @@ -289,18 +290,24 @@ impl BlocksDelayTracker {
}
}

pub fn finish_block_processing(&mut self, block_hash: &CryptoHash, new_head: Option<Tip>) {
pub fn finish_block_processing(
&mut self,
shard_layout: &ShardLayout,
block_hash: &CryptoHash,
new_head: Option<Tip>,
) {
if let Some(processed_block) = self.blocks.get_mut(&block_hash) {
processed_block.processed_timestamp = Some(self.clock.now());
}
// To get around the rust reference scope check
if let Some(processed_block) = self.blocks.get(&block_hash) {
let chunks = processed_block.chunks.clone();
self.update_block_metrics(processed_block);
for (shard_id, chunk_hash) in chunks.into_iter().enumerate() {
for (shard_index, chunk_hash) in chunks.into_iter().enumerate() {
if let Some(chunk_hash) = chunk_hash {
if let Some(processed_chunk) = self.chunks.get(&chunk_hash) {
self.update_chunk_metrics(processed_chunk, shard_id as ShardId);
let shard_id = shard_layout.get_shard_id(shard_index);
self.update_chunk_metrics(processed_chunk, shard_id);
}
}
}
Expand Down
157 changes: 101 additions & 56 deletions chain/chain/src/chain.rs

Large diffs are not rendered by default.

6 changes: 3 additions & 3 deletions chain/chain/src/chain_update.rs
Original file line number Diff line number Diff line change
Expand Up @@ -209,7 +209,7 @@ impl<'a> ChainUpdate<'a> {
let prev_hash = block.header().prev_hash();
let results = apply_chunks_results.into_iter().map(|(shard_id, x)| {
if let Err(err) = &x {
warn!(target: "chain", shard_id, hash = %block.hash(), %err, "Error in applying chunk for block");
warn!(target: "chain", ?shard_id, hash = %block.hash(), %err, "Error in applying chunk for block");
}
x
}).collect::<Result<Vec<_>, Error>>()?;
Expand Down Expand Up @@ -465,7 +465,7 @@ impl<'a> ChainUpdate<'a> {
shard_state_header: ShardStateSyncResponseHeader,
) -> Result<ShardUId, Error> {
let _span =
tracing::debug_span!(target: "sync", "chain_update_set_state_finalize", shard_id, ?sync_hash).entered();
tracing::debug_span!(target: "sync", "chain_update_set_state_finalize", ?shard_id, ?sync_hash).entered();
let (chunk, incoming_receipts_proofs) = match shard_state_header {
ShardStateSyncResponseHeader::V1(shard_state_header) => (
ShardChunk::V1(shard_state_header.chunk),
Expand Down Expand Up @@ -596,7 +596,7 @@ impl<'a> ChainUpdate<'a> {
sync_hash: CryptoHash,
) -> Result<bool, Error> {
let _span =
tracing::debug_span!(target: "sync", "set_state_finalize_on_height", height, shard_id)
tracing::debug_span!(target: "sync", "set_state_finalize_on_height", height, ?shard_id)
.entered();
let block_header_result =
self.chain_store_update.get_block_header_on_chain_by_height(&sync_hash, height);
Expand Down
10 changes: 7 additions & 3 deletions chain/chain/src/garbage_collection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -585,9 +585,11 @@ impl<'a> ChainStoreUpdate<'a> {
let block =
self.get_block(&block_hash).expect("block data is not expected to be already cleaned");
let height = block.header().height();
let epoch_id = block.header().epoch_id();
let shard_layout = epoch_manager.get_shard_layout(epoch_id).expect("epoch id must exist");

// 2. Delete shard_id-indexed data (Receipts, State Headers and Parts, etc.)
for shard_id in 0..block.header().chunk_mask().len() as ShardId {
for shard_id in shard_layout.shard_ids() {
let block_shard_id = get_block_shard_id(&block_hash, shard_id);
self.gc_outgoing_receipts(&block_hash, shard_id);
self.gc_col(DBCol::IncomingReceipts, &block_shard_id);
Expand Down Expand Up @@ -678,11 +680,11 @@ impl<'a> ChainStoreUpdate<'a> {
self.get_block(&block_hash).expect("block data is not expected to be already cleaned");

let epoch_id = block.header().epoch_id();

let head_height = block.header().height();
let shard_layout = epoch_manager.get_shard_layout(epoch_id).expect("epoch id must exist");

// 1. Delete shard_id-indexed data (TrieChanges, Receipts, ChunkExtra, State Headers and Parts, FlatStorage data)
for shard_id in 0..block.header().chunk_mask().len() as ShardId {
for shard_id in shard_layout.shard_ids() {
let shard_uid = epoch_manager.shard_id_to_uid(shard_id, epoch_id).unwrap();
let block_shard_id = get_block_shard_uid(&block_hash, &shard_uid);

Expand Down Expand Up @@ -833,6 +835,8 @@ impl<'a> ChainStoreUpdate<'a> {
for chunk_header in
block.chunks().iter().filter(|h| h.height_included() == block.header().height())
{
// It is ok to use the shard id from the header because it is a new
// chunk. An old chunk may have the shard id from the parent shard.
let shard_id = chunk_header.shard_id();
let outcome_ids =
self.chain_store().get_outcomes_by_block_hash_and_shard_id(block_hash, shard_id)?;
Expand Down
2 changes: 1 addition & 1 deletion chain/chain/src/migrations.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ pub fn check_if_block_is_first_with_chunk_of_version(
if is_first_epoch_with_protocol_version(epoch_manager, prev_block_hash)? {
// Compare only epochs because we already know that current epoch is the first one with current protocol version
// convert shard id to shard id of previous epoch because number of shards may change
let shard_id = epoch_manager.get_prev_shard_ids(prev_block_hash, vec![shard_id])?[0];
let (shard_id, _) = epoch_manager.get_prev_shard_ids(prev_block_hash, vec![shard_id])?[0];
let prev_epoch_id = chain_store.get_epoch_id_of_last_block_with_chunk(
epoch_manager,
prev_block_hash,
Expand Down
6 changes: 5 additions & 1 deletion chain/chain/src/runtime/migrations.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ mod tests {
use near_mainnet_res::mainnet_restored_receipts;
use near_mainnet_res::mainnet_storage_usage_delta;
use near_primitives::hash::hash;
use near_primitives::types::new_shard_id_tmp;

#[test]
fn test_migration_data() {
Expand All @@ -55,7 +56,10 @@ mod tests {
"48ZMJukN7RzvyJSW9MJ5XmyQkQFfjy2ZxPRaDMMHqUcT"
);
let mainnet_migration_data = load_migration_data(near_primitives::chains::MAINNET);
assert_eq!(mainnet_migration_data.restored_receipts.get(&0u64).unwrap().len(), 383);
assert_eq!(
mainnet_migration_data.restored_receipts.get(&new_shard_id_tmp(0)).unwrap().len(),
383
);
let testnet_migration_data = load_migration_data(near_primitives::chains::TESTNET);
assert!(testnet_migration_data.restored_receipts.is_empty());
}
Expand Down
16 changes: 8 additions & 8 deletions chain/chain/src/runtime/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,8 @@ use near_primitives::state_part::PartId;
use near_primitives::transaction::SignedTransaction;
use near_primitives::trie_key::TrieKey;
use near_primitives::types::{
AccountId, Balance, BlockHeight, EpochHeight, EpochId, EpochInfoProvider, Gas, MerkleHash,
ShardId, StateChangeCause, StateRoot, StateRootNode,
new_shard_id_tmp, AccountId, Balance, BlockHeight, EpochHeight, EpochId, EpochInfoProvider,
Gas, MerkleHash, ShardId, StateChangeCause, StateRoot, StateRootNode,
};
use near_primitives::version::{ProtocolFeature, ProtocolVersion};
use near_primitives::views::{
Expand Down Expand Up @@ -223,7 +223,7 @@ impl NightshadeRuntime {
epoch_manager.get_epoch_id_from_prev_block(prev_hash).map_err(Error::from)?;
let shard_version =
epoch_manager.get_shard_layout(&epoch_id).map_err(Error::from)?.version();
Ok(ShardUId { version: shard_version, shard_id: shard_id as u32 })
Ok(ShardUId { version: shard_version, shard_id: new_shard_id_tmp(shard_id) as u32 })
}

fn get_shard_uid_from_epoch_id(
Expand All @@ -234,7 +234,7 @@ impl NightshadeRuntime {
let epoch_manager = self.epoch_manager.read();
let shard_version =
epoch_manager.get_shard_layout(epoch_id).map_err(Error::from)?.version();
Ok(ShardUId { version: shard_version, shard_id: shard_id as u32 })
Ok(ShardUId { version: shard_version, shard_id: new_shard_id_tmp(shard_id) as u32 })
}

fn account_id_to_shard_uid(
Expand Down Expand Up @@ -566,7 +566,7 @@ impl NightshadeRuntime {
target: "runtime",
"obtain_state_part",
part_id = part_id.idx,
shard_id,
?shard_id,
%prev_hash,
num_parts = part_id.total)
.entered();
Expand Down Expand Up @@ -953,7 +953,7 @@ impl RuntimeAdapter for NightshadeRuntime {
}
}

#[instrument(target = "runtime", level = "info", skip_all, fields(shard_id = chunk.shard_id))]
#[instrument(target = "runtime", level = "info", skip_all, fields(shard_id = ?chunk.shard_id))]
fn apply_chunk(
&self,
storage_config: RuntimeStorageConfig,
Expand Down Expand Up @@ -1191,7 +1191,7 @@ impl RuntimeAdapter for NightshadeRuntime {
target: "runtime",
"obtain_state_part",
part_id = part_id.idx,
shard_id,
?shard_id,
%prev_hash,
?state_root,
num_parts = part_id.total)
Expand Down Expand Up @@ -1369,7 +1369,7 @@ fn chunk_tx_gas_limit(
protocol_version: u32,
runtime_config: &RuntimeConfig,
prev_block: &PrepareTransactionsBlockContext,
shard_id: u64,
shard_id: ShardId,
gas_limit: u64,
) -> u64 {
if !ProtocolFeature::CongestionControl.enabled(protocol_version) {
Expand Down
65 changes: 42 additions & 23 deletions chain/chain/src/runtime/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -214,12 +214,13 @@ impl TestEnv {
// TODO(congestion_control): pass down prev block info and read congestion info from there
// For now, just use default.
let prev_block_hash = self.head.last_block_hash;
let state_root = self.state_roots[shard_id as usize];
let epoch_id = self.epoch_manager.get_epoch_id_from_prev_block(&prev_block_hash).unwrap();
let shard_layout = self.epoch_manager.get_shard_layout(&epoch_id).unwrap();
let shard_index = shard_layout.get_shard_index(shard_id);
let state_root = self.state_roots[shard_index];
let gas_limit = u64::MAX;
let height = self.head.height + 1;
let block_timestamp = 0;
let epoch_id =
self.epoch_manager.get_epoch_id_from_prev_block(&prev_block_hash).unwrap_or_default();
let protocol_version = self.epoch_manager.get_epoch_protocol_version(&epoch_id).unwrap();
let gas_price = self.runtime.genesis_config.min_gas_price;
let congestion_info = if !ProtocolFeature::CongestionControl.enabled(protocol_version) {
Expand Down Expand Up @@ -316,19 +317,21 @@ impl TestEnv {
) {
let new_hash = hash(&[(self.head.height + 1) as u8]);
let shard_ids = self.epoch_manager.shard_ids(&self.head.epoch_id).unwrap();
let shard_layout = self.epoch_manager.get_shard_layout(&self.head.epoch_id).unwrap();
assert_eq!(transactions.len(), shard_ids.len());
assert_eq!(chunk_mask.len(), shard_ids.len());
let mut all_proposals = vec![];
let mut all_receipts = vec![];
for shard_id in shard_ids {
let shard_index = shard_layout.get_shard_index(shard_id);
let (state_root, proposals, receipts) = self.update_runtime(
shard_id,
new_hash,
&transactions[shard_id as usize],
&transactions[shard_index],
self.last_receipts.get(&shard_id).map_or(&[], |v| v.as_slice()),
challenges_result.clone(),
);
self.state_roots[shard_id as usize] = state_root;
self.state_roots[shard_index] = state_root;
all_receipts.extend(receipts);
all_proposals.append(&mut proposals.clone());
self.last_shard_proposals.insert(shard_id, proposals);
Expand Down Expand Up @@ -391,9 +394,11 @@ impl TestEnv {
&self.head.epoch_id,
)
.unwrap();
let shard_layout = self.epoch_manager.get_shard_layout(&self.head.epoch_id).unwrap();
let shard_index = shard_layout.get_shard_index(shard_id);
let shard_uid = self.epoch_manager.shard_id_to_uid(shard_id, &self.head.epoch_id).unwrap();
self.runtime
.view_account(&shard_uid, self.state_roots[shard_id as usize], account_id)
.view_account(&shard_uid, self.state_roots[shard_index], account_id)
.unwrap()
.into()
}
Expand Down Expand Up @@ -727,9 +732,12 @@ fn test_state_sync() {
let block_hash = hash(&[env.head.height as u8]);
let state_part = env
.runtime
.obtain_state_part(0, &block_hash, &env.state_roots[0], PartId::new(0, 1))
.obtain_state_part(new_shard_id_tmp(0), &block_hash, &env.state_roots[0], PartId::new(0, 1))
.unwrap();
let root_node = env
.runtime
.get_state_root_node(new_shard_id_tmp(0), &block_hash, &env.state_roots[0])
.unwrap();
let root_node = env.runtime.get_state_root_node(0, &block_hash, &env.state_roots[0]).unwrap();
let mut new_env = TestEnv::new(vec![validators], 2, false);
for i in 1..=2 {
let prev_hash = hash(&[new_env.head.height as u8]);
Expand Down Expand Up @@ -786,7 +794,13 @@ fn test_state_sync() {
let epoch_id = &new_env.head.epoch_id;
new_env
.runtime
.apply_state_part(0, &env.state_roots[0], PartId::new(0, 1), &state_part, epoch_id)
.apply_state_part(
new_shard_id_tmp(0),
&env.state_roots[0],
PartId::new(0, 1),
&state_part,
epoch_id,
)
.unwrap();
new_env.state_roots[0] = env.state_roots[0];
for _ in 3..=5 {
Expand Down Expand Up @@ -827,9 +841,9 @@ fn test_get_validator_info() {
let height = env.head.height;
let em = env.runtime.epoch_manager.read();
let bp = em.get_block_producer_info(&epoch_id, height).unwrap();
let cp = em.get_chunk_producer_info(&epoch_id, height, 0).unwrap();
let cp = em.get_chunk_producer_info(&epoch_id, height, new_shard_id_tmp(0)).unwrap();
let stateless_validators =
em.get_chunk_validator_assignments(&epoch_id, 0, height).ok();
em.get_chunk_validator_assignments(&epoch_id, new_shard_id_tmp(0), height).ok();

if let Some(vs) = stateless_validators {
if vs.contains(&validators[0]) {
Expand Down Expand Up @@ -876,7 +890,7 @@ fn test_get_validator_info() {
public_key: block_producers[0].public_key(),
is_slashed: false,
stake: TESTING_INIT_STAKE,
shards: vec![0],
shards: vec![new_shard_id_tmp(0)],
num_produced_blocks: expected_blocks[0],
num_expected_blocks: expected_blocks[0],
num_produced_chunks: expected_chunks[0],
Expand All @@ -893,7 +907,7 @@ fn test_get_validator_info() {
public_key: block_producers[1].public_key(),
is_slashed: false,
stake: TESTING_INIT_STAKE,
shards: vec![0],
shards: vec![new_shard_id_tmp(0)],
num_produced_blocks: expected_blocks[1],
num_expected_blocks: expected_blocks[1],
num_produced_chunks: expected_chunks[1],
Expand All @@ -911,13 +925,13 @@ fn test_get_validator_info() {
account_id: "test1".parse().unwrap(),
public_key: block_producers[0].public_key(),
stake: TESTING_INIT_STAKE,
shards: vec![0],
shards: vec![new_shard_id_tmp(0)],
},
NextEpochValidatorInfo {
account_id: "test2".parse().unwrap(),
public_key: block_producers[1].public_key(),
stake: TESTING_INIT_STAKE,
shards: vec![0],
shards: vec![new_shard_id_tmp(0)],
},
];
let response = env
Expand Down Expand Up @@ -988,7 +1002,7 @@ fn test_get_validator_info() {
account_id: "test2".parse().unwrap(),
public_key: block_producers[1].public_key(),
stake: TESTING_INIT_STAKE,
shards: vec![0],
shards: vec![new_shard_id_tmp(0)],
}]
);
assert!(response.current_proposals.is_empty());
Expand Down Expand Up @@ -1461,13 +1475,13 @@ fn test_flat_state_usage() {
let env = TestEnv::new(vec![vec!["test1".parse().unwrap()]], 4, false);
let trie = env
.runtime
.get_trie_for_shard(0, &env.head.prev_block_hash, Trie::EMPTY_ROOT, true)
.get_trie_for_shard(new_shard_id_tmp(0), &env.head.prev_block_hash, Trie::EMPTY_ROOT, true)
.unwrap();
assert!(trie.has_flat_storage_chunk_view());

let trie = env
.runtime
.get_view_trie_for_shard(0, &env.head.prev_block_hash, Trie::EMPTY_ROOT)
.get_view_trie_for_shard(new_shard_id_tmp(0), &env.head.prev_block_hash, Trie::EMPTY_ROOT)
.unwrap();
assert!(!trie.has_flat_storage_chunk_view());
}
Expand Down Expand Up @@ -1505,9 +1519,14 @@ fn test_trie_and_flat_state_equality() {
// - using view state, which should never use flat state
let head_prev_block_hash = env.head.prev_block_hash;
let state_root = env.state_roots[0];
let state = env.runtime.get_trie_for_shard(0, &head_prev_block_hash, state_root, true).unwrap();
let view_state =
env.runtime.get_view_trie_for_shard(0, &head_prev_block_hash, state_root).unwrap();
let state = env
.runtime
.get_trie_for_shard(new_shard_id_tmp(0), &head_prev_block_hash, state_root, true)
.unwrap();
let view_state = env
.runtime
.get_view_trie_for_shard(new_shard_id_tmp(0), &head_prev_block_hash, state_root)
.unwrap();
let trie_key = TrieKey::Account { account_id: validators[1].clone() };
let key = trie_key.to_vec();

Expand Down Expand Up @@ -1651,7 +1670,7 @@ fn prepare_transactions(
transaction_groups: &mut dyn TransactionGroupIterator,
storage_config: RuntimeStorageConfig,
) -> Result<PreparedTransactions, Error> {
let shard_id = 0;
let shard_id = new_shard_id_tmp(0);
let block = chain.get_block(&env.head.prev_block_hash).unwrap();
let congestion_info = block.block_congestion_info();

Expand Down Expand Up @@ -1773,7 +1792,7 @@ fn test_prepare_transactions_empty_storage_proof() {
#[test]
#[cfg_attr(not(feature = "test_features"), ignore)]
fn test_storage_proof_garbage() {
let shard_id = 0;
let shard_id = new_shard_id_tmp(0);
let signer = create_test_signer("test1");
let env = TestEnv::new(vec![vec![signer.validator_id().clone()]], 100, false);
let garbage_size_mb = 50usize;
Expand Down
Loading

0 comments on commit 99ecfa4

Please sign in to comment.