Skip to content

Commit

Permalink
Revert "Revert "Add currentDirection to RTPTransceiver""
Browse files Browse the repository at this point in the history
This reverts commit a92c400.
  • Loading branch information
jeremija committed Jan 28, 2023
1 parent a92c400 commit 5b41ed6
Show file tree
Hide file tree
Showing 3 changed files with 125 additions and 10 deletions.
56 changes: 55 additions & 1 deletion peerconnection.go
Original file line number Diff line number Diff line change
Expand Up @@ -991,6 +991,7 @@ func (pc *PeerConnection) SetLocalDescription(desc SessionDescription) error {
weAnswer := desc.Type == SDPTypeAnswer
remoteDesc := pc.RemoteDescription()
if weAnswer && remoteDesc != nil {
_ = setRTPTransceiverCurrentDirection(&desc, currentTransceivers, false)
if err := pc.startRTPSenders(currentTransceivers); err != nil {
return err
}
Expand Down Expand Up @@ -1150,6 +1151,7 @@ func (pc *PeerConnection) SetRemoteDescription(desc SessionDescription) error {

if isRenegotation {
if weOffer {
_ = setRTPTransceiverCurrentDirection(&desc, currentTransceivers, true)
if err = pc.startRTPSenders(currentTransceivers); err != nil {
return err
}
Expand Down Expand Up @@ -1179,6 +1181,7 @@ func (pc *PeerConnection) SetRemoteDescription(desc SessionDescription) error {
// Start the networking in a new routine since it will block until
// the connection is actually established.
if weOffer {
_ = setRTPTransceiverCurrentDirection(&desc, currentTransceivers, true)
if err := pc.startRTPSenders(currentTransceivers); err != nil {
return err
}
Expand Down Expand Up @@ -1237,6 +1240,51 @@ func (pc *PeerConnection) startReceiver(incoming trackDetails, receiver *RTPRece
}
}

func setRTPTransceiverCurrentDirection(answer *SessionDescription, currentTransceivers []*RTPTransceiver, weOffer bool) error {
currentTransceivers = append([]*RTPTransceiver{}, currentTransceivers...)
for _, media := range answer.parsed.MediaDescriptions {
midValue := getMidValue(media)
if midValue == "" {
return errPeerConnRemoteDescriptionWithoutMidValue
}

if media.MediaName.Media == mediaSectionApplication {
continue
}

var t *RTPTransceiver
t, currentTransceivers = findByMid(midValue, currentTransceivers)

if t == nil {
return fmt.Errorf("%w: %q", errPeerConnTranscieverMidNil, midValue)
}

direction := getPeerDirection(media)
if direction == RTPTransceiverDirection(Unknown) {
continue
}

// reverse direction if it was a remote answer
if weOffer {
switch direction {
case RTPTransceiverDirectionSendonly:
direction = RTPTransceiverDirectionRecvonly
case RTPTransceiverDirectionRecvonly:
// Pion will answer recvonly with a offer recvonly transceiver, so we should
// not change the direction to sendonly if we are the offerer, otherwise this
// tranceiver can't be reuse for AddTrack
if t.Direction() != RTPTransceiverDirectionRecvonly {
direction = RTPTransceiverDirectionSendonly
}
default:
}
}

t.setCurrentDirection(direction)
}
return nil
}

func runIfNewReceiver(
incomingTrack trackDetails,
transceivers []*RTPTransceiver,
Expand Down Expand Up @@ -1723,7 +1771,13 @@ func (pc *PeerConnection) AddTrack(track TrackLocal) (*RTPSender, error) {
pc.mu.Lock()
defer pc.mu.Unlock()
for _, t := range pc.rtpTransceivers {
if !t.stopped && t.kind == track.Kind() && t.Sender() == nil {
currentDirection := t.getCurrentDirection()
// According to https://www.w3.org/TR/webrtc/#dom-rtcpeerconnection-addtrack, if the
// transceiver can be reused only if it's currentDirection never be sendrecv or sendonly.
// But that will cause sdp inflate. So we only check currentDirection's current value,
// that's worked for all browsers.
if !t.stopped && t.kind == track.Kind() && t.Sender() == nil &&
!(currentDirection == RTPTransceiverDirectionSendrecv || currentDirection == RTPTransceiverDirectionSendonly) {
sender, err := pc.api.NewRTPSender(track, pc.dtlsTransport)
if err == nil {
err = t.SetSender(sender, track)
Expand Down
57 changes: 52 additions & 5 deletions peerconnection_renegotiation_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -128,14 +128,14 @@ func TestPeerConnection_Renegotiation_AddRecvonlyTransceiver(t *testing.T) {
pcOffer.OnTrack(func(track *TrackRemote, r *RTPReceiver) {
onTrackFiredFunc()
})
assert.NoError(t, signalPair(pcAnswer, pcOffer))
} else {
pcAnswer.OnTrack(func(track *TrackRemote, r *RTPReceiver) {
onTrackFiredFunc()
})
assert.NoError(t, signalPair(pcOffer, pcAnswer))
}

assert.NoError(t, signalPair(pcOffer, pcAnswer))

sendVideoUntilDone(onTrackFired.Done(), t, []*TrackLocalStaticSample{localTrack})

closePairNow(t, pcOffer, pcAnswer)
Expand Down Expand Up @@ -380,6 +380,7 @@ func TestPeerConnection_Transceiver_Mid(t *testing.T) {

offer, err = pcOffer.CreateOffer(nil)
assert.NoError(t, err)
assert.NoError(t, pcOffer.SetLocalDescription(offer))

assert.Equal(t, len(offer.parsed.MediaDescriptions), 2)

Expand All @@ -391,6 +392,11 @@ func TestPeerConnection_Transceiver_Mid(t *testing.T) {
pcOffer.ops.Done()
pcAnswer.ops.Done()

assert.NoError(t, pcAnswer.SetRemoteDescription(offer))
answer, err = pcAnswer.CreateAnswer(nil)
assert.NoError(t, err)
assert.NoError(t, pcOffer.SetRemoteDescription(answer))

track3, err := NewTrackLocalStaticSample(RTPCodecCapability{MimeType: MimeTypeVP8}, "video", "pion3")
require.NoError(t, err)

Expand Down Expand Up @@ -468,12 +474,12 @@ func TestPeerConnection_Renegotiation_CodecChange(t *testing.T) {

require.NoError(t, pcOffer.RemoveTrack(sender1))

sender2, err := pcOffer.AddTrack(track2)
require.NoError(t, err)

require.NoError(t, signalPair(pcOffer, pcAnswer))
<-tracksClosed

sender2, err := pcOffer.AddTrack(track2)
require.NoError(t, err)
require.NoError(t, signalPair(pcOffer, pcAnswer))
transceivers = pcOffer.GetTransceivers()
require.Equal(t, 1, len(transceivers))
require.Equal(t, "0", transceivers[0].Mid())
Expand Down Expand Up @@ -1146,6 +1152,47 @@ func TestPeerConnection_Renegotiation_Simulcast(t *testing.T) {
})
}

func TestPeerConnection_Regegotiation_ReuseTransceiver(t *testing.T) {
lim := test.TimeOut(time.Second * 30)
defer lim.Stop()

report := test.CheckRoutines(t)
defer report()

pcOffer, pcAnswer, err := newPair()
if err != nil {
t.Fatal(err)
}

vp8Track, err := NewTrackLocalStaticSample(RTPCodecCapability{MimeType: MimeTypeVP8}, "foo", "bar")
assert.NoError(t, err)
sender, err := pcOffer.AddTrack(vp8Track)
assert.NoError(t, err)
assert.NoError(t, signalPair(pcOffer, pcAnswer))

assert.Equal(t, len(pcOffer.GetTransceivers()), 1)
assert.Equal(t, pcOffer.GetTransceivers()[0].getCurrentDirection(), RTPTransceiverDirectionSendonly)
assert.NoError(t, pcOffer.RemoveTrack(sender))
assert.Equal(t, pcOffer.GetTransceivers()[0].getCurrentDirection(), RTPTransceiverDirectionSendonly)

// should not reuse tranceiver
vp8Track2, err := NewTrackLocalStaticSample(RTPCodecCapability{MimeType: MimeTypeVP8}, "foo", "bar")
assert.NoError(t, err)
sender2, err := pcOffer.AddTrack(vp8Track2)
assert.NoError(t, err)
assert.Equal(t, len(pcOffer.GetTransceivers()), 2)
assert.NoError(t, signalPair(pcOffer, pcAnswer))
assert.True(t, sender2.rtpTransceiver == pcOffer.GetTransceivers()[1])

// should reuse first transceiver
sender, err = pcOffer.AddTrack(vp8Track)
assert.NoError(t, err)
assert.Equal(t, len(pcOffer.GetTransceivers()), 2)
assert.True(t, sender.rtpTransceiver == pcOffer.GetTransceivers()[0])

closePairNow(t, pcOffer, pcAnswer)
}

func TestPeerConnection_Renegotiation_MidConflict(t *testing.T) {
lim := test.TimeOut(time.Second * 30)
defer lim.Stop()
Expand Down
22 changes: 18 additions & 4 deletions rtptransceiver.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,11 @@ import (

// RTPTransceiver represents a combination of an RTPSender and an RTPReceiver that share a common mid.
type RTPTransceiver struct {
mid atomic.Value // string
sender atomic.Value // *RTPSender
receiver atomic.Value // *RTPReceiver
direction atomic.Value // RTPTransceiverDirection
mid atomic.Value // string
sender atomic.Value // *RTPSender
receiver atomic.Value // *RTPReceiver
direction atomic.Value // RTPTransceiverDirection
currentDirection atomic.Value // RTPTransceiverDirection

codecs []RTPCodecParameters // User provided codecs via SetCodecPreferences

Expand All @@ -38,6 +39,7 @@ func newRTPTransceiver(
t.setReceiver(receiver)
t.setSender(sender)
t.setDirection(direction)
t.setCurrentDirection(RTPTransceiverDirection(Unknown))
return t
}

Expand Down Expand Up @@ -160,6 +162,7 @@ func (t *RTPTransceiver) Stop() error {
}

t.setDirection(RTPTransceiverDirectionInactive)
t.setCurrentDirection(RTPTransceiverDirectionInactive)
return nil
}

Expand All @@ -179,6 +182,17 @@ func (t *RTPTransceiver) setDirection(d RTPTransceiverDirection) {
t.direction.Store(d)
}

func (t *RTPTransceiver) setCurrentDirection(d RTPTransceiverDirection) {
t.currentDirection.Store(d)
}

func (t *RTPTransceiver) getCurrentDirection() RTPTransceiverDirection {
if v, ok := t.currentDirection.Load().(RTPTransceiverDirection); ok {
return v
}
return RTPTransceiverDirection(Unknown)
}

func (t *RTPTransceiver) setSendingTrack(track TrackLocal) error {
if err := t.Sender().ReplaceTrack(track); err != nil {
return err
Expand Down

0 comments on commit 5b41ed6

Please sign in to comment.