Support customized raft message rejection logic (tikv#18114)
close tikv#18113

Support customized raft message rejection logic

Signed-off-by: Calvin Neo <CalvinNeo@users.noreply.github.com>
Signed-off-by: Calvin Neo <calvinneo1995@gmail.com>

Co-authored-by: ti-chi-bot[bot] <108142056+ti-chi-bot[bot]@users.noreply.github.com>
Co-authored-by: glorv <glorvs@163.com>
3 people authored Jan 20, 2025
1 parent d43fea7 commit c88260c
Showing 9 changed files with 128 additions and 15 deletions.
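
At a glance, the commit replaces the hard-coded memory-ratio check on incoming Raft append messages with a pluggable `RaftGrpcMessageFilter` trait (defined in `src/server/service/kv.rs` below). `Server::new` and `Service::new` gain a trailing `Arc<dyn RaftGrpcMessageFilter>` argument, the exported `DefaultGrpcMessageFilter` preserves the previous memory-ratio behaviour, and a new hook allows rejecting incoming snapshot streams (counted by the new `tikv_server_raft_snapshot_rejects` metric). A minimal sketch of the new call-site shape, assuming only the paths added by this commit; the `build_filter` helper and its parameter name are illustrative, not part of the change:

```rust
use std::sync::Arc;

use tikv::server::service::{DefaultGrpcMessageFilter, RaftGrpcMessageFilter};

// Callers build an Arc<dyn RaftGrpcMessageFilter> and pass it as the new
// trailing argument of Server::new / Service::new, as the diffs below show.
// The default filter keeps the old reject_messages_on_memory_ratio behaviour.
fn build_filter(reject_messages_on_memory_ratio: f64) -> Arc<dyn RaftGrpcMessageFilter> {
    Arc::new(DefaultGrpcMessageFilter::new(reject_messages_on_memory_ratio))
}
```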
5 changes: 4 additions & 1 deletion components/server/src/server.rs
@@ -102,7 +102,7 @@ use tikv::{
lock_manager::LockManager,
raftkv::ReplicaReadLockChecker,
resolve,
service::{DebugService, DiagnosticsService},
service::{DebugService, DefaultGrpcMessageFilter, DiagnosticsService},
status_server::StatusServer,
tablet_snap::NoSnapshotCache,
ttl::TtlChecker,
@@ -891,6 +891,9 @@ where
debug_thread_pool,
health_controller,
self.resource_manager.clone(),
Arc::new(DefaultGrpcMessageFilter::new(
server_config.value().reject_messages_on_memory_ratio,
)),
)
.unwrap_or_else(|e| fatal!("failed to create server: {}", e));
cfg_controller.register(
5 changes: 4 additions & 1 deletion components/server/src/server2.rs
@@ -91,7 +91,7 @@ use tikv::{
lock_manager::LockManager,
raftkv::ReplicaReadLockChecker,
resolve,
service::{DebugService, DiagnosticsService},
service::{DebugService, DefaultGrpcMessageFilter, DiagnosticsService},
status_server::StatusServer,
KvEngineFactoryBuilder, NodeV2, RaftKv2, Server, CPU_CORES_QUOTA_GAUGE, GRPC_THREAD_PREFIX,
MEMORY_LIMIT_GAUGE,
@@ -829,6 +829,9 @@ where
debug_thread_pool,
health_controller,
self.resource_manager.clone(),
Arc::new(DefaultGrpcMessageFilter::new(
server_config.value().reject_messages_on_memory_ratio,
)),
)
.unwrap_or_else(|e| fatal!("failed to create server: {}", e));
cfg_controller.register(
5 changes: 4 additions & 1 deletion components/test_raftstore-v2/src/server.rs
@@ -61,7 +61,7 @@ use tikv::{
lock_manager::LockManager,
raftkv::ReplicaReadLockChecker,
resolve,
service::{DebugService, DiagnosticsService},
service::{DebugService, DefaultGrpcMessageFilter, DiagnosticsService},
ConnectionBuilder, Error, Extension, NodeV2, PdStoreAddrResolver, RaftClient, RaftKv2,
Result as ServerResult, Server, ServerTransport,
},
@@ -644,6 +644,9 @@ impl<EK: KvEngine> ServerCluster<EK> {
debug_thread_pool.clone(),
health_controller.clone(),
resource_manager.clone(),
Arc::new(DefaultGrpcMessageFilter::new(
server_cfg.value().reject_messages_on_memory_ratio,
)),
)
.unwrap();
svr.register_service(create_diagnostics(diag_service.clone()));
5 changes: 4 additions & 1 deletion components/test_raftstore/src/server.rs
@@ -67,7 +67,7 @@ use tikv::{
lock_manager::LockManager,
raftkv::ReplicaReadLockChecker,
resolve::{self, StoreAddrResolver},
service::DebugService,
service::{DebugService, DefaultGrpcMessageFilter},
tablet_snap::NoSnapshotCache,
ConnectionBuilder, Error, MultiRaftServer, PdStoreAddrResolver, RaftClient, RaftKv,
Result as ServerResult, Server, ServerTransport,
@@ -617,6 +617,9 @@ impl ServerCluster {
debug_thread_pool.clone(),
health_controller.clone(),
resource_manager.clone(),
Arc::new(DefaultGrpcMessageFilter::new(
server_cfg.value().reject_messages_on_memory_ratio,
)),
)
.unwrap();
svr.register_service(create_import_sst(import_service.clone()));
5 changes: 5 additions & 0 deletions src/server/metrics.rs
@@ -486,6 +486,11 @@ lazy_static! {
"Count for rejected Raft append messages"
)
.unwrap();
pub static ref RAFT_SNAPSHOT_REJECTS: IntCounter = register_int_counter!(
"tikv_server_raft_snapshot_rejects",
"Count for rejected Raft snapshot messages"
)
.unwrap();
pub static ref SNAP_LIMIT_TRANSPORT_BYTES_COUNTER: IntCounterVec = register_int_counter_vec!(
"tikv_snapshot_limit_transport_bytes",
"Total snapshot limit transport used",
3 changes: 3 additions & 0 deletions src/server/server.rs
@@ -166,6 +166,7 @@ where
debug_thread_pool: Arc<Runtime>,
health_controller: HealthController,
resource_manager: Option<Arc<ResourceGroupManager>>,
raft_message_filter: Arc<dyn RaftGrpcMessageFilter>,
) -> Result<Self> {
// A helper thread (or pool) for transport layer.
let stats_pool = if cfg.value().stats_concurrency > 0 {
@@ -211,6 +212,7 @@
resource_manager,
health_controller.clone(),
health_feedback_interval,
raft_message_filter,
);
let builder_factory = Box::new(BuilderFactory::new(
kv_service,
@@ -683,6 +685,7 @@ mod tests {
debug_thread_pool,
HealthController::new(),
None,
Arc::new(DefaultGrpcMessageFilter::new(0.2)),
)
.unwrap();

67 changes: 58 additions & 9 deletions src/server/service/kv.rs
@@ -71,6 +71,41 @@ use crate::{
const GRPC_MSG_MAX_BATCH_SIZE: usize = 128;
const GRPC_MSG_NOTIFY_SIZE: usize = 8;

pub trait RaftGrpcMessageFilter: Send + Sync {
fn should_reject_raft_message(&self, _: &RaftMessage) -> bool;
fn should_reject_snapshot(&self) -> bool;
}

// The default filter is exported for other engines as reference.
#[derive(Clone)]
pub struct DefaultGrpcMessageFilter {
reject_messages_on_memory_ratio: f64,
}

impl DefaultGrpcMessageFilter {
pub fn new(reject_messages_on_memory_ratio: f64) -> Self {
Self {
reject_messages_on_memory_ratio,
}
}
}

impl RaftGrpcMessageFilter for DefaultGrpcMessageFilter {
fn should_reject_raft_message(&self, msg: &RaftMessage) -> bool {
fail::fail_point!("force_reject_raft_append_message", |_| true);
if msg.get_message().get_msg_type() == MessageType::MsgAppend {
needs_reject_raft_append(self.reject_messages_on_memory_ratio)
} else {
false
}
}

fn should_reject_snapshot(&self) -> bool {
fail::fail_point!("force_reject_raft_snapshot_message", |_| true);
false
}
}

/// Service handles the RPC messages for the `Tikv` service.
pub struct Service<E: Engine, L: LockManager, F: KvFormat> {
cluster_id: u64,
@@ -103,6 +138,8 @@ pub struct Service<E: Engine, L: LockManager, F: KvFormat> {
health_controller: HealthController,
health_feedback_interval: Option<Duration>,
health_feedback_seq: Arc<AtomicU64>,

raft_message_filter: Arc<dyn RaftGrpcMessageFilter>,
}

impl<E: Engine, L: LockManager, F: KvFormat> Drop for Service<E, L, F> {
@@ -130,6 +167,7 @@ impl<E: Engine + Clone, L: LockManager + Clone, F: KvFormat> Clone for Service<E
health_controller: self.health_controller.clone(),
health_feedback_seq: self.health_feedback_seq.clone(),
health_feedback_interval: self.health_feedback_interval,
raft_message_filter: self.raft_message_filter.clone(),
}
}
}
@@ -152,6 +190,7 @@ impl<E: Engine, L: LockManager, F: KvFormat> Service<E, L, F> {
resource_manager: Option<Arc<ResourceGroupManager>>,
health_controller: HealthController,
health_feedback_interval: Option<Duration>,
raft_message_filter: Arc<dyn RaftGrpcMessageFilter>,
) -> Self {
let now_unix = std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
@@ -174,14 +213,15 @@ impl<E: Engine, L: LockManager, F: KvFormat> Service<E, L, F> {
health_controller,
health_feedback_interval,
health_feedback_seq: Arc::new(AtomicU64::new(now_unix)),
raft_message_filter,
}
}

fn handle_raft_message(
store_id: u64,
ch: &E::RaftExtension,
msg: RaftMessage,
reject: bool,
raft_msg_filter: &Arc<dyn RaftGrpcMessageFilter>,
) -> RaftStoreResult<()> {
let to_store_id = msg.get_to_peer().get_store_id();
if to_store_id != store_id {
@@ -190,8 +230,11 @@ impl<E: Engine, L: LockManager, F: KvFormat> Service<E, L, F> {
my_store_id: store_id,
});
}
if reject && msg.get_message().get_msg_type() == MessageType::MsgAppend {
RAFT_APPEND_REJECTS.inc();

if raft_msg_filter.should_reject_raft_message(&msg) {
if msg.get_message().get_msg_type() == MessageType::MsgAppend {
RAFT_APPEND_REJECTS.inc();
}
let id = msg.get_region_id();
let peer_id = msg.get_message().get_from();
ch.report_reject_message(id, peer_id);
@@ -753,16 +796,15 @@ impl<E: Engine, L: LockManager, F: KvFormat> Tikv for Service<E, L, F> {

let store_id = self.store_id;
let ch = self.storage.get_engine().raft_extension();
let reject_messages_on_memory_ratio = self.reject_messages_on_memory_ratio;
let ob = self.raft_message_filter.clone();

let res = async move {
let mut stream = stream.map_err(Error::from);
while let Some(msg) = stream.try_next().await? {
RAFT_MESSAGE_RECV_COUNTER.inc();

let reject = needs_reject_raft_append(reject_messages_on_memory_ratio);
if let Err(err @ RaftStoreError::StoreNotMatch { .. }) =
Self::handle_raft_message(store_id, &ch, msg, reject)
Self::handle_raft_message(store_id, &ch, msg, &ob)
{
// Return an error here will break the connection, only do that for
// `StoreNotMatch` to let tikv to resolve a correct address from PD
@@ -807,7 +849,7 @@ impl<E: Engine, L: LockManager, F: KvFormat> Tikv for Service<E, L, F> {

let store_id = self.store_id;
let ch = self.storage.get_engine().raft_extension();
let reject_messages_on_memory_ratio = self.reject_messages_on_memory_ratio;
let ob = self.raft_message_filter.clone();

let res = async move {
let mut stream = stream.map_err(Error::from);
@@ -822,10 +864,10 @@ impl<E: Engine, L: LockManager, F: KvFormat> Tikv for Service<E, L, F> {
let len = batch_msg.get_msgs().len();
RAFT_MESSAGE_RECV_COUNTER.inc_by(len as u64);
RAFT_MESSAGE_BATCH_SIZE.observe(len as f64);
let reject = needs_reject_raft_append(reject_messages_on_memory_ratio);

for msg in batch_msg.take_msgs().into_iter() {
if let Err(err @ RaftStoreError::StoreNotMatch { .. }) =
Self::handle_raft_message(store_id, &ch, msg, reject)
Self::handle_raft_message(store_id, &ch, msg, &ob)
{
// Return an error here will break the connection, only do that for
// `StoreNotMatch` to let tikv to resolve a correct address from PD
@@ -862,6 +904,13 @@ impl<E: Engine, L: LockManager, F: KvFormat> Tikv for Service<E, L, F> {
stream: RequestStream<SnapshotChunk>,
sink: ClientStreamingSink<Done>,
) {
if self.raft_message_filter.should_reject_snapshot() {
RAFT_SNAPSHOT_REJECTS.inc();
let status =
RpcStatus::with_message(RpcStatusCode::UNAVAILABLE, "rejected by peer".to_string());
ctx.spawn(sink.fail(status).map(|_| ()));
return;
};
let task = SnapTask::Recv { stream, sink };
if let Err(e) = self.snap_scheduler.schedule(task) {
let err_msg = format!("{}", e);
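
Since `handle_raft_message` and `recv_snapshot` now consult the injected filter instead of a fixed check, an embedding engine can supply its own rejection policy. Below is a hedged sketch of a custom implementation, assuming the `RaftGrpcMessageFilter` re-export added in `src/server/service/mod.rs` (next file) and the `kvproto`/`raft` proto types already used by `kv.rs`; the `MaintenanceFilter` type and its `paused` flag are illustrative, not part of this commit:

```rust
use std::sync::atomic::{AtomicBool, Ordering};

use kvproto::raft_serverpb::RaftMessage;
use raft::eraftpb::MessageType;
use tikv::server::service::RaftGrpcMessageFilter;

/// Illustrative policy: while `paused` is set, drop incoming MsgAppend traffic
/// and refuse snapshot streams; everything else passes through.
pub struct MaintenanceFilter {
    paused: AtomicBool,
}

impl RaftGrpcMessageFilter for MaintenanceFilter {
    fn should_reject_raft_message(&self, msg: &RaftMessage) -> bool {
        self.paused.load(Ordering::Relaxed)
            && msg.get_message().get_msg_type() == MessageType::MsgAppend
    }

    fn should_reject_snapshot(&self) -> bool {
        self.paused.load(Ordering::Relaxed)
    }
}
```

An `Arc` of such a filter would be passed to `Server::new` in place of the `DefaultGrpcMessageFilter` shown in `components/server/src/server.rs` above. Rejected messages are reported back via `report_reject_message`, and rejected snapshot streams fail with `UNAVAILABLE`, matching the default path.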
4 changes: 2 additions & 2 deletions src/server/service/mod.rs
@@ -10,8 +10,8 @@ pub use self::{
diagnostics::Service as DiagnosticsService,
kv::{
batch_commands_request, batch_commands_response, future_flashback_to_version,
future_prepare_flashback_to_version, GrpcRequestDuration, MeasuredBatchResponse,
MeasuredSingleResponse, Service as KvService,
future_prepare_flashback_to_version, DefaultGrpcMessageFilter, GrpcRequestDuration,
MeasuredBatchResponse, MeasuredSingleResponse, RaftGrpcMessageFilter, Service as KvService,
},
};

44 changes: 44 additions & 0 deletions tests/failpoints/cases/test_server.rs
@@ -156,3 +156,47 @@ fn test_serving_status() {
thread::sleep(Duration::from_millis(200));
assert_eq!(check(), ServingStatus::Serving);
}

#[test]
fn test_raft_message_observer() {
let mut cluster = new_server_cluster(0, 3);
cluster.pd_client.disable_default_operator();
let r1 = cluster.run_conf_change();

cluster.must_put(b"k1", b"v1");

fail::cfg("force_reject_raft_append_message", "return").unwrap();
fail::cfg("force_reject_raft_snapshot_message", "return").unwrap();

cluster.pd_client.add_peer(r1, new_peer(2, 2));

std::thread::sleep(std::time::Duration::from_millis(500));

must_get_none(&cluster.get_engine(2), b"k1");

fail::remove("force_reject_raft_append_message");
fail::remove("force_reject_raft_snapshot_message");

cluster.pd_client.must_have_peer(r1, new_peer(2, 2));
cluster.pd_client.must_add_peer(r1, new_peer(3, 3));

must_get_equal(&cluster.get_engine(2), b"k1", b"v1");
must_get_equal(&cluster.get_engine(3), b"k1", b"v1");

fail::cfg("force_reject_raft_append_message", "return").unwrap();

let _ = cluster.async_put(b"k2", b"v2").unwrap();

std::thread::sleep(std::time::Duration::from_millis(500));

must_get_none(&cluster.get_engine(2), b"k2");
must_get_none(&cluster.get_engine(3), b"k2");

fail::remove("force_reject_raft_append_message");

cluster.must_put(b"k3", b"v3");
for id in 1..=3 {
must_get_equal(&cluster.get_engine(id), b"k3", b"v3");
}
cluster.shutdown();
}
