Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add support for sending error codes on stream reset #122

Merged
merged 2 commits into from
Sep 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 22 additions & 1 deletion const.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,27 @@
return false
}

// A StreamError is used for errors returned from Read and Write calls after the stream is Reset
type StreamError struct {
ErrorCode uint32
Remote bool
}

func (s *StreamError) Error() string {
if s.Remote {
return fmt.Sprintf("stream reset by remote, error code: %d", s.ErrorCode)

Check warning on line 68 in const.go

View check run for this annotation

Codecov / codecov/patch

const.go#L66-L68

Added lines #L66 - L68 were not covered by tests
}
return fmt.Sprintf("stream reset, error code: %d", s.ErrorCode)

Check warning on line 70 in const.go

View check run for this annotation

Codecov / codecov/patch

const.go#L70

Added line #L70 was not covered by tests
}

func (s *StreamError) Is(target error) bool {
if target == ErrStreamReset {
return true
}
e, ok := target.(*StreamError)
return ok && *e == *s

Check warning on line 78 in const.go

View check run for this annotation

Codecov / codecov/patch

const.go#L77-L78

Added lines #L77 - L78 were not covered by tests
}

var (
// ErrInvalidVersion means we received a frame with an
// invalid version
Expand Down Expand Up @@ -152,7 +173,7 @@
// It's not an implementation choice, the value defined in the specification.
initialStreamWindow = 256 * 1024
maxStreamWindow = 16 * 1024 * 1024
goAwayWaitTime = 5 * time.Second
goAwayWaitTime = 50 * time.Millisecond
)

const (
Expand Down
2 changes: 1 addition & 1 deletion session.go
Original file line number Diff line number Diff line change
Expand Up @@ -334,7 +334,7 @@ func (s *Session) close(shutdownErr error, sendGoAway bool, errCode uint32) erro
s.streamLock.Lock()
defer s.streamLock.Unlock()
for id, stream := range s.streams {
stream.forceClose()
stream.forceClose(fmt.Errorf("%w: connection closed: %w", ErrStreamReset, s.shutdownErr))
delete(s.streams, id)
stream.memorySpan.Done()
}
Expand Down
53 changes: 52 additions & 1 deletion session_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ import (
"testing"
"time"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

Expand Down Expand Up @@ -1571,6 +1572,56 @@ func TestStreamResetRead(t *testing.T) {
wc.Wait()
}

func TestStreamResetWithError(t *testing.T) {
client, server := testClientServer()
defer client.Close()
defer server.Close()

wc := new(sync.WaitGroup)
wc.Add(2)
go func() {
defer wc.Done()
stream, err := server.AcceptStream()
if err != nil {
t.Error(err)
}

se := &StreamError{}
_, err = io.ReadAll(stream)
if !errors.As(err, &se) {
t.Errorf("exptected StreamError, got type:%T, err: %s", err, err)
return
}
expected := &StreamError{Remote: true, ErrorCode: 42}
assert.Equal(t, se, expected)
}()

stream, err := client.OpenStream(context.Background())
if err != nil {
t.Error(err)
}

go func() {
defer wc.Done()

se := &StreamError{}
_, err := io.ReadAll(stream)
if !errors.As(err, &se) {
t.Errorf("exptected StreamError, got type:%T, err: %s", err, err)
return
}
expected := &StreamError{Remote: false, ErrorCode: 42}
assert.Equal(t, se, expected)
}()

time.Sleep(1 * time.Second)
err = stream.ResetWithError(42)
if err != nil {
t.Fatal(err)
}
wc.Wait()
}

func TestLotsOfWritesWithStreamDeadline(t *testing.T) {
config := testConf()
config.EnableKeepAlive = false
Expand Down Expand Up @@ -1809,7 +1860,7 @@ func TestMaxIncomingStreams(t *testing.T) {
require.NoError(t, err)
str.SetDeadline(time.Now().Add(time.Second))
_, err = str.Read([]byte{0})
require.EqualError(t, err, "stream reset")
require.ErrorIs(t, err, ErrStreamReset)

// Now close one of the streams.
// This should then allow the client to open a new stream.
Expand Down
37 changes: 28 additions & 9 deletions stream.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ type Stream struct {

state streamState
writeState, readState halfStreamState
writeErr, readErr error
stateLock sync.Mutex

recvBuf segmentedBuffer
Expand Down Expand Up @@ -89,6 +90,7 @@ func (s *Stream) Read(b []byte) (n int, err error) {
START:
s.stateLock.Lock()
state := s.readState
resetErr := s.readErr
s.stateLock.Unlock()

switch state {
Expand All @@ -101,7 +103,7 @@ START:
}
// Closed, but we have data pending -> read.
case halfReset:
return 0, ErrStreamReset
return 0, resetErr
default:
panic("unknown state")
}
Expand Down Expand Up @@ -147,6 +149,7 @@ func (s *Stream) write(b []byte) (n int, err error) {
START:
s.stateLock.Lock()
state := s.writeState
resetErr := s.writeErr
s.stateLock.Unlock()

switch state {
Expand All @@ -155,7 +158,7 @@ START:
case halfClosed:
return 0, ErrStreamClosed
case halfReset:
return 0, ErrStreamReset
return 0, resetErr
default:
panic("unknown state")
}
Expand Down Expand Up @@ -250,13 +253,17 @@ func (s *Stream) sendClose() error {
}

// sendReset is used to send a RST
func (s *Stream) sendReset() error {
hdr := encode(typeWindowUpdate, flagRST, s.id, 0)
func (s *Stream) sendReset(errCode uint32) error {
hdr := encode(typeWindowUpdate, flagRST, s.id, errCode)
return s.session.sendMsg(hdr, nil, nil)
}

// Reset resets the stream (forcibly closes the stream)
func (s *Stream) Reset() error {
return s.ResetWithError(0)
}

func (s *Stream) ResetWithError(errCode uint32) error {
sendReset := false
s.stateLock.Lock()
switch s.state {
Expand All @@ -276,15 +283,17 @@ func (s *Stream) Reset() error {
// If we've already sent/received an EOF, no need to reset that side.
if s.writeState == halfOpen {
s.writeState = halfReset
s.writeErr = &StreamError{Remote: false, ErrorCode: errCode}
}
if s.readState == halfOpen {
s.readState = halfReset
s.readErr = &StreamError{Remote: false, ErrorCode: errCode}
}
s.state = streamFinished
s.notifyWaiting()
s.stateLock.Unlock()
if sendReset {
_ = s.sendReset()
_ = s.sendReset(errCode)
}
s.cleanup()
return nil
Expand Down Expand Up @@ -336,6 +345,7 @@ func (s *Stream) CloseRead() error {
panic("invalid state")
}
s.readState = halfReset
s.readErr = ErrStreamReset
cleanup = s.writeState != halfOpen
if cleanup {
s.state = streamFinished
Expand All @@ -357,13 +367,15 @@ func (s *Stream) Close() error {
}

// forceClose is used for when the session is exiting
func (s *Stream) forceClose() {
func (s *Stream) forceClose(err error) {
s.stateLock.Lock()
if s.readState == halfOpen {
s.readState = halfReset
s.readErr = err
}
if s.writeState == halfOpen {
s.writeState = halfReset
s.writeErr = err
}
s.state = streamFinished
s.notifyWaiting()
Expand All @@ -382,7 +394,7 @@ func (s *Stream) cleanup() {

// processFlags is used to update the state of the stream
// based on set flags, if any. Lock must be held
func (s *Stream) processFlags(flags uint16) {
func (s *Stream) processFlags(flags uint16, hdr header) {
// Close the stream without holding the state lock
var closeStream bool
defer func() {
Expand Down Expand Up @@ -418,11 +430,18 @@ func (s *Stream) processFlags(flags uint16) {
}
if flags&flagRST == flagRST {
s.stateLock.Lock()
var resetErr error = ErrStreamReset
// Length in a window update frame with RST flag encodes an error code.
if hdr.MsgType() == typeWindowUpdate {
resetErr = &StreamError{Remote: true, ErrorCode: hdr.Length()}
}
if s.readState == halfOpen {
s.readState = halfReset
s.readErr = resetErr
}
if s.writeState == halfOpen {
s.writeState = halfReset
s.writeErr = resetErr
}
s.state = streamFinished
s.stateLock.Unlock()
Expand All @@ -439,15 +458,15 @@ func (s *Stream) notifyWaiting() {

// incrSendWindow updates the size of our send window
func (s *Stream) incrSendWindow(hdr header, flags uint16) {
s.processFlags(flags)
s.processFlags(flags, hdr)
// Increase window, unblock a sender
atomic.AddUint32(&s.sendWindow, hdr.Length())
asyncNotify(s.sendNotifyCh)
}

// readData is used to handle a data frame
func (s *Stream) readData(hdr header, flags uint16, conn io.Reader) error {
s.processFlags(flags)
s.processFlags(flags, hdr)

// Check that our recv window is not exceeded
length := hdr.Length()
Expand Down