Skip to content

Commit

Permalink
Update Interceptors to use []byte based API
Browse files Browse the repository at this point in the history
Also update test to assert Attributes get passed all the way through

Resolves pion/interceptor#14
  • Loading branch information
Sean-Der committed Dec 14, 2020
1 parent ff1bc32 commit 67826b1
Show file tree
Hide file tree
Showing 18 changed files with 210 additions and 306 deletions.
2 changes: 1 addition & 1 deletion examples/broadcast/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ func main() { // nolint:gocognit

rtpBuf := make([]byte, 1400)
for {
i, readErr := remoteTrack.Read(rtpBuf)
i, _, readErr := remoteTrack.Read(rtpBuf)
if readErr != nil {
panic(readErr)
}
Expand Down
2 changes: 1 addition & 1 deletion examples/reflect/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ func main() {
fmt.Printf("Track has started, of type %d: %s \n", track.PayloadType(), track.Codec().MimeType)
for {
// Read RTP packets being sent to Pion
rtp, readErr := track.ReadRTP()
rtp, _, readErr := track.ReadRTP()
if readErr != nil {
panic(readErr)
}
Expand Down
2 changes: 1 addition & 1 deletion examples/rtp-forwarder/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ func main() {
b := make([]byte, 1500)
for {
// Read
n, readErr := track.Read(b)
n, _, readErr := track.Read(b)
if readErr != nil {
panic(readErr)
}
Expand Down
2 changes: 1 addition & 1 deletion examples/save-to-disk/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ func saveToDisk(i media.Writer, track *webrtc.TrackRemote) {
}()

for {
rtpPacket, err := track.ReadRTP()
rtpPacket, _, err := track.ReadRTP()
if err != nil {
panic(err)
}
Expand Down
2 changes: 1 addition & 1 deletion examples/simulcast/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ func main() {
}()
for {
// Read RTP packets being sent to Pion
packet, readErr := track.ReadRTP()
packet, _, readErr := track.ReadRTP()
if readErr != nil {
panic(readErr)
}
Expand Down
2 changes: 1 addition & 1 deletion examples/swap-tracks/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ func main() { // nolint:gocognit
var isCurrTrack bool
for {
// Read RTP packets being sent to Pion
rtp, readErr := track.ReadRTP()
rtp, _, readErr := track.ReadRTP()
if readErr != nil {
panic(readErr)
}
Expand Down
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ require (
github.com/pion/datachannel v1.4.21
github.com/pion/dtls/v2 v2.0.4
github.com/pion/ice/v2 v2.0.14
github.com/pion/interceptor v0.0.5
github.com/pion/interceptor v0.0.6
github.com/pion/logging v0.2.2
github.com/pion/randutil v0.1.0
github.com/pion/rtcp v1.2.6
Expand Down
4 changes: 2 additions & 2 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,8 @@ github.com/pion/dtls/v2 v2.0.4 h1:WuUcqi6oYMu/noNTz92QrF1DaFj4eXbhQ6dzaaAwOiI=
github.com/pion/dtls/v2 v2.0.4/go.mod h1:qAkFscX0ZHoI1E07RfYPoRw3manThveu+mlTDdOxoGI=
github.com/pion/ice/v2 v2.0.14 h1:FxXxauyykf89SWAtkQCfnHkno6G8+bhRkNguSh9zU+4=
github.com/pion/ice/v2 v2.0.14/go.mod h1:wqaUbOq5ObDNU5ox1hRsEst0rWfsKuH1zXjQFEWiZwM=
github.com/pion/interceptor v0.0.5 h1:BOwlubM1lntji3eNaVrhW1Qk3u1UoemrhM4mbv24XGM=
github.com/pion/interceptor v0.0.5/go.mod h1:lPVrf5xfosI989ZcmgPS4WwwRhd+XAyTFaYI2wHf7nU=
github.com/pion/interceptor v0.0.6 h1:530EdZi757pZEx510kvO25FkEuKm2mrb0p9NA+Xfj8E=
github.com/pion/interceptor v0.0.6/go.mod h1:QHkPVN5uyuw54wHqqL1KS9fxf3M3RzOlVKg/YrtK1so=
github.com/pion/logging v0.2.2 h1:M9+AIj/+pxNsDfAT64+MAVgJO0rsyLnoJKCqf//DoeY=
github.com/pion/logging v0.2.2/go.mod h1:k0/tDVsRCX2Mb2ZEmTqNa7CWsQPc+YYCB7Q+5pahoms=
github.com/pion/mdns v0.0.4 h1:O4vvVqr4DGX63vzmO6Fw9vpy3lfztVWHGCQfyw0ZLSY=
Expand Down
50 changes: 48 additions & 2 deletions interceptor.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,16 @@
package webrtc

import (
"sync/atomic"

"github.com/pion/interceptor"
"github.com/pion/rtp"
)

// RegisterDefaultInterceptors will register some useful interceptors. If you want to customize which interceptors are loaded,
// you should copy the code from this method and remove unwanted interceptors.
func RegisterDefaultInterceptors(mediaEngine *MediaEngine, interceptorRegistry *interceptor.Registry) error {
err := ConfigureNack(mediaEngine, interceptorRegistry)
if err != nil {
if err := ConfigureNack(mediaEngine, interceptorRegistry); err != nil {
return err
}

Expand All @@ -24,3 +26,47 @@ func ConfigureNack(mediaEngine *MediaEngine, interceptorRegistry *interceptor.Re
interceptorRegistry.Add(&interceptor.NACK{})
return nil
}

type interceptorToTrackLocalWriter struct{ interceptor atomic.Value } // interceptor.RTPWriter }

func (i *interceptorToTrackLocalWriter) WriteRTP(header *rtp.Header, payload []byte) (int, error) {
if writer, ok := i.interceptor.Load().(interceptor.RTPWriter); ok && writer != nil {
return writer.Write(header, payload, interceptor.Attributes{})
}

return 0, nil
}

func (i *interceptorToTrackLocalWriter) Write(b []byte) (int, error) {
packet := &rtp.Packet{}
if err := packet.Unmarshal(b); err != nil {
return 0, err
}

return i.WriteRTP(&packet.Header, packet.Payload)
}

func createStreamInfo(id string, ssrc SSRC, payloadType PayloadType, codec RTPCodecCapability, webrtcHeaderExtensions []RTPHeaderExtensionParameter) interceptor.StreamInfo {
headerExtensions := make([]interceptor.RTPHeaderExtension, 0, len(webrtcHeaderExtensions))
for _, h := range webrtcHeaderExtensions {
headerExtensions = append(headerExtensions, interceptor.RTPHeaderExtension{ID: h.ID, URI: h.URI})
}

feedbacks := make([]interceptor.RTCPFeedback, 0, len(codec.RTCPFeedback))
for _, f := range codec.RTCPFeedback {
feedbacks = append(feedbacks, interceptor.RTCPFeedback{Type: f.Type, Parameter: f.Parameter})
}

return interceptor.StreamInfo{
ID: id,
Attributes: interceptor.Attributes{},
SSRC: uint32(ssrc),
PayloadType: uint8(payloadType),
RTPHeaderExtensions: headerExtensions,
MimeType: codec.MimeType,
ClockRate: codec.ClockRate,
Channels: codec.Channels,
SDPFmtpLine: codec.SDPFmtpLine,
RTCPFeedback: feedbacks,
}
}
144 changes: 39 additions & 105 deletions interceptor_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,172 +2,106 @@

package webrtc

//
import (
"sync"
"sync/atomic"
"context"
"testing"
"time"

"github.com/pion/interceptor"
"github.com/pion/rtcp"
"github.com/pion/rtp"
"github.com/pion/transport/test"
"github.com/pion/webrtc/v3/pkg/media"
"github.com/stretchr/testify/assert"
)

type testInterceptor struct {
t *testing.T
extensionID uint8
rtcpWriter atomic.Value
lastRTCP atomic.Value
interceptor.NoOp

t *testing.T
}

func (t *testInterceptor) BindLocalStream(_ *interceptor.StreamInfo, writer interceptor.RTPWriter) interceptor.RTPWriter {
return interceptor.RTPWriterFunc(func(p *rtp.Packet, attributes interceptor.Attributes) (int, error) {
return interceptor.RTPWriterFunc(func(header *rtp.Header, payload []byte, attributes interceptor.Attributes) (int, error) {
// set extension on outgoing packet
p.Header.Extension = true
p.Header.ExtensionProfile = 0xBEDE
assert.NoError(t.t, p.Header.SetExtension(t.extensionID, []byte("write")))
header.Extension = true
header.ExtensionProfile = 0xBEDE
assert.NoError(t.t, header.SetExtension(2, []byte("foo")))

return writer.Write(p, attributes)
return writer.Write(header, payload, attributes)
})
}

func (t *testInterceptor) BindRemoteStream(info *interceptor.StreamInfo, reader interceptor.RTPReader) interceptor.RTPReader {
return interceptor.RTPReaderFunc(func() (*rtp.Packet, interceptor.Attributes, error) {
p, attributes, err := reader.Read()
if err != nil {
return nil, nil, err
func (t *testInterceptor) BindRemoteStream(_ *interceptor.StreamInfo, reader interceptor.RTPReader) interceptor.RTPReader {
return interceptor.RTPReaderFunc(func(b []byte, a interceptor.Attributes) (int, interceptor.Attributes, error) {
if a == nil {
a = interceptor.Attributes{}
}
// set extension on incoming packet
p.Header.Extension = true
p.Header.ExtensionProfile = 0xBEDE
assert.NoError(t.t, p.Header.SetExtension(t.extensionID, []byte("read")))

// write back a pli
rtcpWriter := t.rtcpWriter.Load().(interceptor.RTCPWriter)
pli := &rtcp.PictureLossIndication{SenderSSRC: info.SSRC, MediaSSRC: info.SSRC}
_, err = rtcpWriter.Write([]rtcp.Packet{pli}, make(interceptor.Attributes))
assert.NoError(t.t, err)

return p, attributes, nil
})
}

func (t *testInterceptor) BindRTCPReader(reader interceptor.RTCPReader) interceptor.RTCPReader {
return interceptor.RTCPReaderFunc(func() ([]rtcp.Packet, interceptor.Attributes, error) {
pkts, attributes, err := reader.Read()
if err != nil {
return nil, nil, err
}

t.lastRTCP.Store(pkts[0])

return pkts, attributes, nil
a.Set("attribute", "value")
return reader.Read(b, a)
})
}

func (t *testInterceptor) lastReadRTCP() rtcp.Packet {
p, _ := t.lastRTCP.Load().(rtcp.Packet)
return p
}

func (t *testInterceptor) BindRTCPWriter(writer interceptor.RTCPWriter) interceptor.RTCPWriter {
t.rtcpWriter.Store(writer)
return writer
}

// E2E test of the features of Interceptors
// * Assert an extension can be set on an outbound packet
// * Assert an extension can be read on an outbound packet
// * Assert that attributes set by an interceptor are returned to the Reader
func TestPeerConnection_Interceptor(t *testing.T) {
to := test.TimeOut(time.Second * 20)
defer to.Stop()

report := test.CheckRoutines(t)
defer report()

createPC := func(i interceptor.Interceptor) *PeerConnection {
createPC := func() *PeerConnection {
m := &MediaEngine{}
assert.NoError(t, m.RegisterDefaultCodecs())

ir := &interceptor.Registry{}
ir.Add(i)
ir.Add(&testInterceptor{t: t})

pc, err := NewAPI(WithMediaEngine(m), WithInterceptorRegistry(ir)).NewPeerConnection(Configuration{})
assert.NoError(t, err)

return pc
}

sendInterceptor := &testInterceptor{t: t, extensionID: 1}
senderPC := createPC(sendInterceptor)
receiverPC := createPC(&testInterceptor{t: t, extensionID: 2})
offerer := createPC()
answerer := createPC()

track, err := NewTrackLocalStaticSample(RTPCodecCapability{MimeType: "video/vp8"}, "video", "pion")
assert.NoError(t, err)

sender, err := senderPC.AddTrack(track)
_, err = offerer.AddTrack(track)
assert.NoError(t, err)

pending := new(int32)
wg := &sync.WaitGroup{}

wg.Add(1)
*pending++
receiverPC.OnTrack(func(track *TrackRemote, receiver *RTPReceiver) {
p, readErr := track.ReadRTP()
seenRTP, seenRTPCancel := context.WithCancel(context.Background())
answerer.OnTrack(func(track *TrackRemote, receiver *RTPReceiver) {
p, attributes, readErr := track.ReadRTP()
assert.NoError(t, readErr)

assert.Equal(t, p.Extension, true)
assert.Equal(t, "write", string(p.GetExtension(1)))
assert.Equal(t, "read", string(p.GetExtension(2)))
atomic.AddInt32(pending, -1)
wg.Done()
assert.Equal(t, "foo", string(p.GetExtension(2)))
assert.Equal(t, "value", attributes.Get("attribute"))

for {
if _, readErr = track.ReadRTP(); readErr != nil {
return
}
}
seenRTPCancel()
})

wg.Add(1)
*pending++
go func() {
_, readErr := sender.ReadRTCP()
assert.NoError(t, readErr)
atomic.AddInt32(pending, -1)
wg.Done()
assert.NoError(t, signalPair(offerer, answerer))

func() {
ticker := time.NewTicker(time.Millisecond * 20)
for {
if _, readErr = sender.ReadRTCP(); readErr != nil {
select {
case <-seenRTP.Done():
return
case <-ticker.C:
assert.NoError(t, track.WriteSample(media.Sample{Data: []byte{0x00}, Duration: time.Second}))
}
}
}()

assert.NoError(t, signalPair(senderPC, receiverPC))

wg.Add(1)
go func() {
defer wg.Done()
for {
time.Sleep(time.Millisecond * 100)

assert.NoError(t, track.WriteSample(media.Sample{Data: []byte{0x00}, Duration: time.Second}))

if atomic.LoadInt32(pending) == 0 {
return
}
}
}()

wg.Wait()
assert.NoError(t, senderPC.Close())
assert.NoError(t, receiverPC.Close())

pli, _ := sendInterceptor.lastReadRTCP().(*rtcp.PictureLossIndication)
if pli == nil || pli.SenderSSRC == 0 {
t.Errorf("pli not found by send interceptor")
}
assert.NoError(t, offerer.Close())
assert.NoError(t, answerer.Close())
}
27 changes: 0 additions & 27 deletions interceptor_track_local.go

This file was deleted.

1 change: 0 additions & 1 deletion peerconnection.go
Original file line number Diff line number Diff line change
Expand Up @@ -1152,7 +1152,6 @@ func (pc *PeerConnection) startReceiver(incoming trackDetails, receiver *RTPRece
receiver.Track().kind = receiver.kind
receiver.Track().codec = params.Codecs[0]
receiver.Track().params = params
receiver.Track().bindInterceptor()
receiver.Track().mu.Unlock()

pc.onTrack(receiver.Track(), receiver)
Expand Down
Loading

0 comments on commit 67826b1

Please sign in to comment.