diff --git a/node_test.go b/node_test.go index 068002b8..9da10b49 100644 --- a/node_test.go +++ b/node_test.go @@ -206,6 +206,47 @@ func TestDisableProposalForwarding(t *testing.T) { } } +func TestDisableProposalForwardingCallback(t *testing.T) { + r1 := newTestRaft(1, 10, 1, newTestMemoryStorage(withPeers(1, 2, 3))) + r2 := newTestRaft(2, 10, 1, newTestMemoryStorage(withPeers(1, 2, 3))) + cfg3 := newTestConfig(3, 10, 1, newTestMemoryStorage(withPeers(1, 2, 3))) + dropMessageData := []byte("testdata_should_be_dropped") + cfg3.DisableProposalForwardingCallback = func(m raftpb.Message) bool { + for _, entry := range m.Entries { + if bytes.Equal(entry.Data, dropMessageData) { + return true + } + } + return false + } + r3 := newRaft(cfg3) + nt := newNetwork(r1, r2, r3) + + // elect r1 as leader + nt.send(raftpb.Message{From: 1, To: 1, Type: raftpb.MsgHup}) + + var testEntriesShouldBeForwarded = []raftpb.Entry{{Data: []byte("testdata")}} + var testEntriesShouldBeDropped = []raftpb.Entry{{Data: dropMessageData}, {Data: []byte("testdata")}} + + // send proposal to r2(follower) where DisableProposalForwardingCallback is nil + r2.Step(raftpb.Message{From: 2, To: 2, Type: raftpb.MsgProp, Entries: testEntriesShouldBeForwarded}) + r2.Step(raftpb.Message{From: 2, To: 2, Type: raftpb.MsgProp, Entries: testEntriesShouldBeDropped}) + + // verify r2(follower) does forward the proposal when DisableProposalForwardingCallback is nil + if len(r2.msgs) != 2 { + t.Fatalf("len(r2.msgs) expected 2, got %d", len(r2.msgs)) + } + + // send proposal to r3(follower) where DisableProposalForwardingCallback checks the entries + r3.Step(raftpb.Message{From: 3, To: 3, Type: raftpb.MsgProp, Entries: testEntriesShouldBeForwarded}) + r3.Step(raftpb.Message{From: 3, To: 3, Type: raftpb.MsgProp, Entries: testEntriesShouldBeDropped}) + + // verify r3(follower) does not forward the proposal when the callback returns true + if len(r3.msgs) != 1 { + t.Fatalf("len(r3.msgs) expected 1, got %d", len(r3.msgs)) + } +} + // TestNodeReadIndexToOldLeader ensures that raftpb.MsgReadIndex to old leader // gets forwarded to the new leader and 'send' method does not attach its term. func TestNodeReadIndexToOldLeader(t *testing.T) { diff --git a/raft.go b/raft.go index 86916626..088406f2 100644 --- a/raft.go +++ b/raft.go @@ -283,6 +283,13 @@ type Config struct { // This behavior will become unconditional in the future. See: // https://github.com/etcd-io/raft/issues/83 StepDownOnRemoval bool + + // DisableProposalForwardingCallback will be called for each MsgProp message + // on nodes which are in follower state. + // If this callback returns true, the message will be discarded. + // This callback function is used for implementing a mechanism like + // DisableProposalForwarding for each message instead of global configuration. + DisableProposalForwardingCallback func(m pb.Message) bool } func (c *Config) validate() error { @@ -413,9 +420,10 @@ type raft struct { // randomizedElectionTimeout is a random number between // [electiontimeout, 2 * electiontimeout - 1]. It gets reset // when raft changes its state to follower or candidate. - randomizedElectionTimeout int - disableProposalForwarding bool - stepDownOnRemoval bool + randomizedElectionTimeout int + disableProposalForwarding bool + stepDownOnRemoval bool + disableProposalForwardingCallback func(m pb.Message) bool tick func() step stepFunc @@ -440,22 +448,23 @@ func newRaft(c *Config) *raft { } r := &raft{ - id: c.ID, - lead: None, - isLearner: false, - raftLog: raftlog, - maxMsgSize: entryEncodingSize(c.MaxSizePerMsg), - maxUncommittedSize: entryPayloadSize(c.MaxUncommittedEntriesSize), - prs: tracker.MakeProgressTracker(c.MaxInflightMsgs, c.MaxInflightBytes), - electionTimeout: c.ElectionTick, - heartbeatTimeout: c.HeartbeatTick, - logger: c.Logger, - checkQuorum: c.CheckQuorum, - preVote: c.PreVote, - readOnly: newReadOnly(c.ReadOnlyOption), - disableProposalForwarding: c.DisableProposalForwarding, - disableConfChangeValidation: c.DisableConfChangeValidation, - stepDownOnRemoval: c.StepDownOnRemoval, + id: c.ID, + lead: None, + isLearner: false, + raftLog: raftlog, + maxMsgSize: entryEncodingSize(c.MaxSizePerMsg), + maxUncommittedSize: entryPayloadSize(c.MaxUncommittedEntriesSize), + prs: tracker.MakeProgressTracker(c.MaxInflightMsgs, c.MaxInflightBytes), + electionTimeout: c.ElectionTick, + heartbeatTimeout: c.HeartbeatTick, + logger: c.Logger, + checkQuorum: c.CheckQuorum, + preVote: c.PreVote, + readOnly: newReadOnly(c.ReadOnlyOption), + disableProposalForwarding: c.DisableProposalForwarding, + disableConfChangeValidation: c.DisableConfChangeValidation, + stepDownOnRemoval: c.StepDownOnRemoval, + disableProposalForwardingCallback: c.DisableProposalForwardingCallback, } cfg, prs, err := confchange.Restore(confchange.Changer{ @@ -1676,6 +1685,11 @@ func stepFollower(r *raft, m pb.Message) error { } else if r.disableProposalForwarding { r.logger.Infof("%x not forwarding to leader %x at term %d; dropping proposal", r.id, r.lead, r.Term) return ErrProposalDropped + } else if r.disableProposalForwardingCallback != nil && r.disableProposalForwardingCallback(m) { + r.logger.Infof("%x not forwarding to leader %x at term %d"+ + " because disableProposalForwardingCallback() returned true for the message; dropping proposal", + r.id, r.lead, r.Term) + return ErrProposalDropped } m.To = r.lead r.send(m)