feat(network): overhaul state part request (near#12110)
In this PR we implement the ability to request state parts from
arbitrary peers in the network via routed messages. Previously, state
parts could only be requested via a PeerMessage sent to the node's
directly connected peers.

Because the responses to these requests are large and not
time-sensitive, it is undesirable to send them over the tier1/tier2
connections used for the rest of the protocol. Hence we also
introduce a new connection pool, tier3, used for the sole purpose of
transmitting large one-time payloads.

A separate PR will follow which overhauls the state sync actor in
accordance with these changes. The end-to-end behavior has been built
and tested in near#12095.
saketh-are authored Sep 20, 2024
1 parent c0c172f commit a18810a
Showing 18 changed files with 362 additions and 59 deletions.
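
At a glance, the new request path works as follows: the requester sends a routed StatePartRequest carrying its own public address, and the peer hosting the snapshot opens a short-lived TIER3 connection back to that address to deliver the large VersionedStateResponse. The sketch below only assembles the request body using the types added in this diff; the sending and handling sides are not shown, since the state sync actor overhaul lands in the follow-up PR, and the helper itself is illustrative rather than code from this commit.

// Sketch (not code from this diff): building the routed request body.
// `StatePartRequest` and the `RoutedMessageBody::StatePartRequest` variant
// are introduced by the changes below; everything else is illustrative.
fn build_state_part_request(
    shard_id: ShardId,
    sync_hash: CryptoHash,
    part_id: u64,
    my_public_addr: std::net::SocketAddr,
) -> RoutedMessageBody {
    // The requester advertises its own public address so that the snapshot
    // host can dial back over TIER3 and return the (large) response there,
    // keeping the payload off the latency-sensitive TIER1/TIER2 connections.
    RoutedMessageBody::StatePartRequest(StatePartRequest {
        shard_id,
        sync_hash,
        part_id,
        addr: my_public_addr,
    })
}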
64 changes: 44 additions & 20 deletions chain/client/src/sync/state.rs
@@ -675,24 +675,42 @@ impl StateSync {
for ((part_id, download), target) in
parts_to_fetch(new_shard_sync_download).zip(possible_targets_sampler)
{
sent_request_part(
self.clock.clone(),
target.clone(),
part_id,
shard_id,
sync_hash,
last_part_id_requested,
requested_target,
self.timeout,
);
request_part_from_peers(
part_id,
target,
download,
shard_id,
sync_hash,
&self.network_adapter,
);
// The request sent to the network adapter needs to include the sync_prev_prev_hash
// so that a peer hosting the correct snapshot can be selected.
let prev_header = chain
.get_block_header(&sync_hash)
.map(|header| chain.get_block_header(&header.prev_hash()));

match prev_header {
Ok(Ok(prev_header)) => {
let sync_prev_prev_hash = prev_header.prev_hash();
sent_request_part(
self.clock.clone(),
target.clone(),
part_id,
shard_id,
sync_hash,
last_part_id_requested,
requested_target,
self.timeout,
);
request_part_from_peers(
part_id,
target,
download,
shard_id,
sync_hash,
*sync_prev_prev_hash,
&self.network_adapter,
);
}
Ok(Err(err)) => {
tracing::error!(target: "sync", %shard_id, %sync_hash, ?err, "could not get prev header");
}
Err(err) => {
tracing::error!(target: "sync", %shard_id, %sync_hash, ?err, "could not get header");
}
}
}
}
StateSyncInner::External { chain_id, semaphore, external } => {
@@ -1304,18 +1322,24 @@ fn request_part_from_peers(
download: &mut DownloadStatus,
shard_id: ShardId,
sync_hash: CryptoHash,
sync_prev_prev_hash: CryptoHash,
network_adapter: &PeerManagerAdapter,
) {
download.run_me.store(false, Ordering::SeqCst);
download.state_requests_count += 1;
download.last_target = Some(peer_id.clone());
download.last_target = Some(peer_id);
let run_me = download.run_me.clone();

near_performance_metrics::actix::spawn(
"StateSync",
network_adapter
.send_async(PeerManagerMessageRequest::NetworkRequests(
NetworkRequests::StateRequestPart { shard_id, sync_hash, part_id, peer_id },
NetworkRequests::StateRequestPart {
shard_id,
sync_hash,
sync_prev_prev_hash,
part_id,
},
))
.then(move |result| {
// TODO: possible optimization - in the current code, even if one of the targets is not present in the network graph
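The first hunk above resolves two headers to obtain sync_prev_prev_hash, which is what lets a peer hosting the correct snapshot be selected. The relationship can be summarized as a small helper; this is a sketch distilled from the calls used in the hunk, not a function added by this diff:

// Sketch: the hash two blocks behind the sync block, derived exactly as in
// the hunk above (sync_hash -> prev_hash -> prev_prev_hash).
fn sync_prev_prev_hash(
    chain: &Chain,
    sync_hash: &CryptoHash,
) -> Result<CryptoHash, near_chain::Error> {
    let sync_header = chain.get_block_header(sync_hash)?;
    let prev_header = chain.get_block_header(sync_header.prev_hash())?;
    Ok(*prev_header.prev_hash())
}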
11 changes: 7 additions & 4 deletions chain/client/src/tests/catching_up.rs
@@ -101,8 +101,9 @@ enum ReceiptsSyncPhases {
pub struct StateRequestStruct {
pub shard_id: u64,
pub sync_hash: CryptoHash,
pub sync_prev_prev_hash: Option<CryptoHash>,
pub part_id: Option<u64>,
pub peer_id: PeerId,
pub peer_id: Option<PeerId>,
}

/// Sanity checks that the incoming and outgoing receipts are properly sent and received
@@ -268,8 +269,9 @@ fn test_catchup_receipts_sync_common(wait_till: u64, send: u64, sync_hold: bool)
let srs = StateRequestStruct {
shard_id: *shard_id,
sync_hash: *sync_hash,
sync_prev_prev_hash: None,
part_id: None,
peer_id: peer_id.clone(),
peer_id: Some(peer_id.clone()),
};
if !seen_hashes_with_state
.contains(&hash_func(&borsh::to_vec(&srs).unwrap()))
@@ -283,16 +285,17 @@ fn test_catchup_receipts_sync_common(wait_till: u64, send: u64, sync_hold: bool)
if let NetworkRequests::StateRequestPart {
shard_id,
sync_hash,
sync_prev_prev_hash,
part_id,
peer_id,
} = msg
{
if sync_hold {
let srs = StateRequestStruct {
shard_id: *shard_id,
sync_hash: *sync_hash,
sync_prev_prev_hash: Some(*sync_prev_prev_hash),
part_id: Some(*part_id),
peer_id: peer_id.clone(),
peer_id: None,
};
if !seen_hashes_with_state
.contains(&hash_func(&borsh::to_vec(&srs).unwrap()))
3 changes: 3 additions & 0 deletions chain/network/src/network_protocol/borsh_conv.rs
@@ -216,6 +216,9 @@ impl From<&mem::PeerMessage> for net::PeerMessage {
panic!("Tier1Handshake is not supported in Borsh encoding")
}
mem::PeerMessage::Tier2Handshake(h) => net::PeerMessage::Handshake((&h).into()),
mem::PeerMessage::Tier3Handshake(_) => {
panic!("Tier3Handshake is not supported in Borsh encoding")
}
mem::PeerMessage::HandshakeFailure(pi, hfr) => {
net::PeerMessage::HandshakeFailure(pi, (&hfr).into())
}
3 changes: 3 additions & 0 deletions chain/network/src/network_protocol/mod.rs
@@ -411,6 +411,7 @@ pub struct Disconnect {
pub enum PeerMessage {
Tier1Handshake(Handshake),
Tier2Handshake(Handshake),
Tier3Handshake(Handshake),
HandshakeFailure(PeerInfo, HandshakeFailureReason),
/// When a failed nonce is used by some peer, this message is sent back as evidence.
LastEdge(Edge),
@@ -552,6 +553,7 @@ pub enum RoutedMessageBody {
VersionedChunkEndorsement(ChunkEndorsement),
EpochSyncRequest,
EpochSyncResponse(CompressedEpochSyncProof),
StatePartRequest(StatePartRequest),
}

impl RoutedMessageBody {
@@ -645,6 +647,7 @@ impl fmt::Debug for RoutedMessageBody {
RoutedMessageBody::EpochSyncResponse(_) => {
write!(f, "EpochSyncResponse")
}
RoutedMessageBody::StatePartRequest(_) => write!(f, "StatePartRequest"),
}
}
}
16 changes: 7 additions & 9 deletions chain/network/src/network_protocol/network.proto
@@ -458,17 +458,15 @@ message PeerMessage {
TraceContext trace_context = 26;

oneof message_type {
// Handshakes for TIER1 and TIER2 networks are considered separate,
// so that a node binary which doesn't support TIER1 connection won't
// be even able to PARSE the handshake. This way we avoid accidental
// connections, such that one end thinks it is a TIER2 connection and the
// other thinks it is a TIER1 connection. As currently both TIER1 and TIER2
// connections are handled by the same PeerActor, both fields use the same
// underlying message type. If we ever decide to separate the handshake
// implementations, we can copy the Handshake message type definition and
// make it evolve differently for TIER1 and TIER2.
// Handshakes for different network tiers explicitly use different PeerMessage variants.
// This way we avoid accidental connections, such that one end thinks it is a TIER2 connection
// and the other thinks it is a TIER1 connection. Currently the same PeerActor handles
// all types of connections, hence the contents are identical for all types of connections.
// If we ever decide to separate the handshake implementations, we can copy the Handshake message
// type definition and make it evolve differently for different tiers.
Handshake tier1_handshake = 27;
Handshake tier2_handshake = 4;
Handshake tier3_handshake = 33;

HandshakeFailure handshake_failure = 5;
LastEdge last_edge = 6;
4 changes: 4 additions & 0 deletions chain/network/src/network_protocol/proto_conv/peer_message.rs
@@ -234,6 +234,7 @@ impl From<&PeerMessage> for proto::PeerMessage {
message_type: Some(match x {
PeerMessage::Tier1Handshake(h) => ProtoMT::Tier1Handshake(h.into()),
PeerMessage::Tier2Handshake(h) => ProtoMT::Tier2Handshake(h.into()),
PeerMessage::Tier3Handshake(h) => ProtoMT::Tier3Handshake(h.into()),
PeerMessage::HandshakeFailure(pi, hfr) => {
ProtoMT::HandshakeFailure((pi, hfr).into())
}
@@ -398,6 +399,9 @@ impl TryFrom<&proto::PeerMessage> for PeerMessage {
ProtoMT::Tier2Handshake(h) => {
PeerMessage::Tier2Handshake(h.try_into().map_err(Self::Error::Handshake)?)
}
ProtoMT::Tier3Handshake(h) => {
PeerMessage::Tier3Handshake(h.try_into().map_err(Self::Error::Handshake)?)
}
ProtoMT::HandshakeFailure(hf) => {
let (pi, hfr) = hf.try_into().map_err(Self::Error::HandshakeFailure)?;
PeerMessage::HandshakeFailure(pi, hfr)
23 changes: 23 additions & 0 deletions chain/network/src/network_protocol/state_sync.rs
@@ -107,3 +107,26 @@ pub enum SnapshotHostInfoVerificationError {
)]
TooManyShards(usize),
}

/// Message used to request a state part.
///
#[derive(
Clone,
Debug,
Eq,
PartialEq,
Hash,
borsh::BorshSerialize,
borsh::BorshDeserialize,
ProtocolSchema,
)]
pub struct StatePartRequest {
/// Requested shard id
pub shard_id: ShardId,
/// Hash of the requested snapshot's state root
pub sync_hash: CryptoHash,
/// Requested part id
pub part_id: u64,
/// Public address of the node making the request
pub addr: std::net::SocketAddr,
}
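
Because the request travels as a routed message, StatePartRequest is Borsh-encoded on the wire. A minimal round-trip check, assuming ShardId is the plain u64 alias used elsewhere in this diff; the snippet is illustrative and not a test from this PR:

// Illustrative Borsh round-trip relying on the derives above.
let req = StatePartRequest {
    shard_id: 0,
    sync_hash: CryptoHash::default(),
    part_id: 7,
    addr: "127.0.0.1:24567".parse().unwrap(),
};
let bytes = borsh::to_vec(&req).unwrap();
let decoded: StatePartRequest = borsh::from_slice(&bytes).unwrap();
assert_eq!(req, decoded);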
42 changes: 37 additions & 5 deletions chain/network/src/peer/peer_actor.rs
@@ -284,6 +284,18 @@ impl PeerActor {
.start_outbound(peer_id.clone())
.map_err(ClosingReason::OutboundNotAllowed)?
}
tcp::Tier::T3 => {
// Loop connections are allowed only on T1; see comment above
if peer_id == &network_state.config.node_id() {
return Err(ClosingReason::OutboundNotAllowed(
connection::PoolError::UnexpectedLoopConnection,
));
}
network_state
.tier3
.start_outbound(peer_id.clone())
.map_err(ClosingReason::OutboundNotAllowed)?
}
},
handshake_spec: HandshakeSpec {
partial_edge_info: network_state.propose_edge(&clock, peer_id, None),
@@ -293,10 +305,12 @@
},
},
};
// Override force_encoding for outbound Tier1 connections,
// since Tier1Handshake is supported only with proto encoding.
// Override force_encoding for outbound Tier1 and Tier3 connections;
// Tier1Handshake and Tier3Handshake are supported only with proto encoding.
let force_encoding = match &stream.type_ {
tcp::StreamType::Outbound { tier, .. } if tier == &tcp::Tier::T1 => {
tcp::StreamType::Outbound { tier, .. }
if tier == &tcp::Tier::T1 || tier == &tcp::Tier::T3 =>
{
Some(Encoding::Proto)
}
_ => force_encoding,
@@ -480,6 +494,7 @@ impl PeerActor {
let msg = match spec.tier {
tcp::Tier::T1 => PeerMessage::Tier1Handshake(handshake),
tcp::Tier::T2 => PeerMessage::Tier2Handshake(handshake),
tcp::Tier::T3 => PeerMessage::Tier3Handshake(handshake),
};
self.send_message_or_log(&msg);
}
@@ -939,6 +954,9 @@ impl PeerActor {
(PeerStatus::Connecting { .. }, PeerMessage::Tier2Handshake(msg)) => {
self.process_handshake(ctx, tcp::Tier::T2, msg)
}
(PeerStatus::Connecting { .. }, PeerMessage::Tier3Handshake(msg)) => {
self.process_handshake(ctx, tcp::Tier::T3, msg)
}
(_, msg) => {
tracing::warn!(target:"network","unexpected message during handshake: {}",msg)
}
@@ -1140,7 +1158,9 @@ impl PeerActor {

self.stop(ctx, ClosingReason::DisconnectMessage);
}
PeerMessage::Tier1Handshake(_) | PeerMessage::Tier2Handshake(_) => {
PeerMessage::Tier1Handshake(_)
| PeerMessage::Tier2Handshake(_)
| PeerMessage::Tier3Handshake(_) => {
// Received a handshake after having already seen a handshake from this peer.
tracing::debug!(target: "network", "Duplicate handshake from {}", self.peer_info);
}
@@ -1182,8 +1202,20 @@ impl PeerActor {
self.stop(ctx, ClosingReason::Ban(ReasonForBan::Abusive));
}

// Add received peers to the peer store
let node_id = self.network_state.config.node_id();

// Record our own IP address as observed by the peer.
if self.network_state.my_public_addr.read().is_none() {
if let Some(my_peer_info) =
direct_peers.iter().find(|peer_info| peer_info.id == node_id)
{
if let Some(addr) = my_peer_info.addr {
let mut my_public_addr = self.network_state.my_public_addr.write();
*my_public_addr = Some(addr);
}
}
}
// Add received indirect peers to the peer store
self.network_state.peer_store.add_indirect_peers(
&self.clock,
peers.into_iter().filter(|peer_info| peer_info.id != node_id),
4 changes: 4 additions & 0 deletions chain/network/src/peer_manager/connection/mod.rs
@@ -36,8 +36,12 @@ impl tcp::Tier {
match msg {
PeerMessage::Tier1Handshake(_) => self == tcp::Tier::T1,
PeerMessage::Tier2Handshake(_) => self == tcp::Tier::T2,
PeerMessage::Tier3Handshake(_) => self == tcp::Tier::T3,
PeerMessage::HandshakeFailure(_, _) => true,
PeerMessage::LastEdge(_) => true,
PeerMessage::VersionedStateResponse(_) => {
self == tcp::Tier::T2 || self == tcp::Tier::T3
}
PeerMessage::Routed(msg) => self.is_allowed_routed(&msg.body),
_ => self == tcp::Tier::T2,
}
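
The match above gates which PeerMessage variants each tier accepts: TIER3 carries its own handshake and the large VersionedStateResponse (which may now arrive on either TIER2 or TIER3), while everything else stays on TIER2 unless is_allowed_routed permits it. A few illustrative assertions follow, assuming the enclosing method is the tier's is_allowed check (matching the existing T1/T2 arms) and given some suitable handshake and state-response values in scope:

// Illustrative only; `handshake` and `response` are placeholder values, and
// `is_allowed` names the method containing the match above (an assumption).
assert!(tcp::Tier::T3.is_allowed(&PeerMessage::Tier3Handshake(handshake.clone())));
assert!(!tcp::Tier::T2.is_allowed(&PeerMessage::Tier3Handshake(handshake.clone())));
assert!(tcp::Tier::T2.is_allowed(&PeerMessage::VersionedStateResponse(response.clone())));
assert!(tcp::Tier::T3.is_allowed(&PeerMessage::VersionedStateResponse(response.clone())));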