Skip to content

Commit

Permalink
Ensure verification tag is zero for INIT packets
Browse files Browse the repository at this point in the history
This test verifies that the verification tag is correctly set to 0 for all INIT packets,
including retransmissions.
  • Loading branch information
JoeTurki committed Jan 21, 2025
1 parent 2600de3 commit 3aa79cd
Show file tree
Hide file tree
Showing 2 changed files with 140 additions and 1 deletion.
2 changes: 1 addition & 1 deletion association.go
Original file line number Diff line number Diff line change
Expand Up @@ -452,7 +452,7 @@ func (a *Association) sendInit() error {
}

outbound := &packet{}
outbound.verificationTag = a.peerVerificationTag
outbound.verificationTag = 0
a.sourcePort = defaultSCTPSrcDstPort
a.destinationPort = defaultSCTPSrcDstPort
outbound.sourcePort = a.sourcePort
Expand Down
139 changes: 139 additions & 0 deletions association_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1181,6 +1181,145 @@ func TestAssocUnreliable(t *testing.T) {
})
}

// This test ensures that verification tag is set to 0 for all INIT packets.
// A test for this PR https://github.com/pion/sctp/pull/341
// We drop the first INIT ACK, and we expect the verification tag to be 0 on
// retransmission.
func TestInitVerificationTagIsZero(t *testing.T) {
lim := test.TimeOut(time.Second * 10)
defer lim.Stop()

const si uint16 = 1
const msg = "ABC"
br := test.NewBridge()
ackCount := 0
recvBufSize := uint32(0)

var a0, a1 *Association
var err0, err1 error
loggerFactory := logging.NewDefaultLoggerFactory()

handshake0Ch := make(chan bool)
handshake1Ch := make(chan bool)
fatalChannel := make(chan error)

fitlerFunc := func(pkt []byte) bool {
t.Helper()

packetData := packet{}

assert.NoError(t, packetData.unmarshal(true, pkt))

// Init chunk and Init Ack chunk are never bundled.
if len(packetData.chunks) != 1 {
return true
}

switch packetData.chunks[0].(type) {
case *chunkInit:
if packetData.verificationTag != 0 {
// Even without this we will get WARNING:
// failed validating packet init chunk expects a verification tag of 0 on the packet when out-of-the-blue
// And the connection will fail silently.
go func() {
fatalChannel <- errors.New("verification tag should be 0 for Init chunk") //nolint:err113
}()

return false
}
// Drop the first two Init Ack chunk.
case *chunkInitAck:
ackCount++
return ackCount > 2
}

return true
}

br.Filter(0, fitlerFunc)

br.Filter(1, fitlerFunc)

go func() {
a0, err0 = Client(Config{
Name: "a0",
NetConn: br.GetConn0(),
MaxReceiveBufferSize: recvBufSize,
LoggerFactory: loggerFactory,
})

handshake0Ch <- true
}()
go func() {
a1, err1 = Client(Config{
Name: "a1",
NetConn: br.GetConn1(),
MaxReceiveBufferSize: recvBufSize,
LoggerFactory: loggerFactory,
})
handshake1Ch <- true
}()

a0handshakeDone := false
a1handshakeDone := false

loop1:
for i := 0; i < 1e3; i++ {
time.Sleep(10 * time.Millisecond)
br.Tick()

select {
case a0handshakeDone = <-handshake0Ch:
if a1handshakeDone {
break loop1
}
case a1handshakeDone = <-handshake1Ch:
if a0handshakeDone {
break loop1
}
case err := <-fatalChannel:
t.Fatal(err)
default:
}
}

assert.Equal(t, a0handshakeDone, true, "handshake failed e0")
assert.Equal(t, a1handshakeDone, true, "handshake failed e1")

assert.NoError(t, err0, "failed to create association a0")
assert.NoError(t, err1, "failed to create association a1")

a0.ackMode = ackModeNoDelay
a1.ackMode = ackModeNoDelay

s0, s1, err := establishSessionPair(br, a0, a1, si)
assert.Nil(t, err, "failed to establish session pair")

assert.Equal(t, 0, a0.bufferedAmount(), "incorrect bufferedAmount")

n, err := s0.WriteSCTP([]byte(msg), PayloadTypeWebRTCBinary)
if err != nil {
assert.FailNow(t, "failed due to earlier error")
}
assert.Equal(t, len(msg), n, "unexpected length of received data")
assert.Equal(t, len(msg), a0.bufferedAmount(), "incorrect bufferedAmount")

flushBuffers(br, a0, a1)

buf := make([]byte, 32)
n, ppi, err := s1.ReadSCTP(buf)
if !assert.Nil(t, err, "ReadSCTP failed") {
assert.FailNow(t, "failed due to earlier error")
}
assert.Equal(t, n, len(msg), "unexpected length of received data")
assert.Equal(t, ppi, PayloadTypeWebRTCBinary, "unexpected ppi")

assert.False(t, s0.reassemblyQueue.isReadable(), "should no longer be readable")
assert.Equal(t, 0, a0.bufferedAmount(), "incorrect bufferedAmount")

closeAssociationPair(br, a0, a1)
}

func TestCreateForwardTSN(t *testing.T) {
loggerFactory := logging.NewDefaultLoggerFactory()

Expand Down

0 comments on commit 3aa79cd

Please sign in to comment.