diff --git a/nbhttp/client_conn.go b/nbhttp/client_conn.go index 173b0190..8b552f68 100644 --- a/nbhttp/client_conn.go +++ b/nbhttp/client_conn.go @@ -318,10 +318,11 @@ func (c *ClientConn) Do(req *http.Request, handler func(res *http.Response, conn isNonblock := true tlsConn.ResetConn(nbc, isNonblock) - c.conn = tlsConn + nbhttpConn := &Conn{Conn: tlsConn} + c.conn = nbhttpConn processor := NewClientProcessor(c, c.onResponse) parser := NewParser(processor, true, engine.ReadLimit, nbc.Execute) - parser.Conn = tlsConn + parser.Conn = nbhttpConn parser.Engine = engine parser.OnClose(func(p *Parser, err error) { c.CloseWithError(err) diff --git a/nbhttp/engine.go b/nbhttp/engine.go index 2060b304..c0a2fb76 100644 --- a/nbhttp/engine.go +++ b/nbhttp/engine.go @@ -279,7 +279,13 @@ func (e *Engine) closeAllConns() { } } -func (e *Engine) listen(ln net.Listener, tlsConfig *tls.Config, addConn func(net.Conn, *tls.Config, func()), decrease func()) { +type Conn struct { + net.Conn + Parser *Parser + Trasfered bool +} + +func (e *Engine) listen(ln net.Listener, tlsConfig *tls.Config, addConn func(*Conn, *tls.Config, func()), decrease func()) { e.WaitGroup.Add(1) go func() { defer func() { @@ -289,7 +295,7 @@ func (e *Engine) listen(ln net.Listener, tlsConfig *tls.Config, addConn func(net for !e.shutdown { conn, err := ln.Accept() if err == nil && !e.shutdown { - addConn(conn, tlsConfig, decrease) + addConn(&Conn{Conn: conn}, tlsConfig, decrease) } else { var ne net.Error if ok := errors.As(err, &ne); ok && ne.Temporary() { @@ -520,31 +526,34 @@ func (e *Engine) TLSDataHandler(c *nbio.Conn, data []byte) { c.Close() return } - if tlsConn, ok := parser.Processor.Conn().(*tls.Conn); ok { - defer tlsConn.ResetOrFreeBuffer() - - readed := data - buffer := data - for { - _, nread, err := tlsConn.AppendAndRead(readed, buffer) - readed = nil - if err != nil { - c.CloseWithError(err) - return - } - if nread > 0 { - err := parser.Read(buffer[:nread]) + nbhttpConn, ok := parser.Processor.Conn().(*Conn) + if ok { + if tlsConn, ok := nbhttpConn.Conn.(*tls.Conn); ok { + defer tlsConn.ResetOrFreeBuffer() + + readed := data + buffer := data + for { + _, nread, err := tlsConn.AppendAndRead(readed, buffer) + readed = nil if err != nil { - logging.Debug("parser.Read failed: %v", err) c.CloseWithError(err) return } + if nread > 0 { + err := parser.Read(buffer[:nread]) + if err != nil { + logging.Debug("parser.Read failed: %v", err) + c.CloseWithError(err) + return + } + } + if nread == 0 { + return + } } - if nread == 0 { - return - } + // c.SetReadDeadline(time.Now().Add(conf.KeepaliveTime)) } - // c.SetReadDeadline(time.Now().Add(conf.KeepaliveTime)) } } @@ -572,13 +581,14 @@ func (engine *Engine) AddTransferredConn(nbc *nbio.Conn) error { } // AddConnNonTLSNonBlocking . -func (engine *Engine) AddConnNonTLSNonBlocking(c net.Conn, tlsConfig *tls.Config, decrease func()) { - nbc, err := nbio.NBConn(c) +func (engine *Engine) AddConnNonTLSNonBlocking(conn *Conn, tlsConfig *tls.Config, decrease func()) { + nbc, err := nbio.NBConn(conn.Conn) if err != nil { - c.Close() + conn.Close() logging.Error("AddConnNonTLSNonBlocking failed: %v", err) return } + conn.Conn = nbc if nbc.Session() != nil { nbc.Close() return @@ -599,13 +609,14 @@ func (engine *Engine) AddConnNonTLSNonBlocking(c net.Conn, tlsConfig *tls.Config } engine.conns[key] = struct{}{} engine.mux.Unlock() - engine._onOpen(nbc) - processor := NewServerProcessor(nbc, engine.Handler, engine.KeepaliveTime, !engine.DisableSendfile) + engine._onOpen(conn.Conn) + processor := NewServerProcessor(conn, engine.Handler, engine.KeepaliveTime, !engine.DisableSendfile) parser := NewParser(processor, false, engine.ReadLimit, nbc.Execute) if engine.isOneshot { parser.Execute = SyncExecutor } parser.Engine = engine + conn.Parser = parser processor.(*ServerProcessor).parser = parser nbc.SetSession(parser) nbc.OnData(engine.DataHandler) @@ -614,7 +625,7 @@ func (engine *Engine) AddConnNonTLSNonBlocking(c net.Conn, tlsConfig *tls.Config } // AddConnNonTLSBlocking . -func (engine *Engine) AddConnNonTLSBlocking(conn net.Conn, tlsConfig *tls.Config, decrease func()) { +func (engine *Engine) AddConnNonTLSBlocking(conn *Conn, tlsConfig *tls.Config, decrease func()) { engine.mux.Lock() if len(engine.conns) >= engine.MaxLoad { engine.mux.Unlock() @@ -623,7 +634,7 @@ func (engine *Engine) AddConnNonTLSBlocking(conn net.Conn, tlsConfig *tls.Config decrease() return } - switch vt := conn.(type) { + switch vt := conn.Conn.(type) { case *net.TCPConn, *net.UnixConn: key, err := conn2Array(vt) if err != nil { @@ -646,19 +657,21 @@ func (engine *Engine) AddConnNonTLSBlocking(conn net.Conn, tlsConfig *tls.Config processor := NewServerProcessor(conn, engine.Handler, engine.KeepaliveTime, !engine.DisableSendfile) parser := NewParser(processor, false, engine.ReadLimit, SyncExecutor) parser.Engine = engine + conn.Parser = parser processor.(*ServerProcessor).parser = parser conn.SetReadDeadline(time.Now().Add(engine.KeepaliveTime)) go engine.readConnBlocking(conn, parser, decrease) } // AddConnTLSNonBlocking . -func (engine *Engine) AddConnTLSNonBlocking(conn net.Conn, tlsConfig *tls.Config, decrease func()) { - nbc, err := nbio.NBConn(conn) +func (engine *Engine) AddConnTLSNonBlocking(conn *Conn, tlsConfig *tls.Config, decrease func()) { + nbc, err := nbio.NBConn(conn.Conn) if err != nil { conn.Close() logging.Error("AddConnTLSNonBlocking failed: %v", err) return } + conn.Conn = nbc if nbc.Session() != nil { nbc.Close() logging.Error("AddConnTLSNonBlocking failed: session should not be nil") @@ -681,18 +694,20 @@ func (engine *Engine) AddConnTLSNonBlocking(conn net.Conn, tlsConfig *tls.Config engine.conns[key] = struct{}{} engine.mux.Unlock() - engine._onOpen(nbc) + engine._onOpen(conn.Conn) isClient := false isNonBlock := true tlsConn := tls.NewConn(nbc, tlsConfig, isClient, isNonBlock, engine.TLSAllocator) - processor := NewServerProcessor(tlsConn, engine.Handler, engine.KeepaliveTime, !engine.DisableSendfile) + conn = &Conn{Conn: tlsConn} + processor := NewServerProcessor(conn, engine.Handler, engine.KeepaliveTime, !engine.DisableSendfile) parser := NewParser(processor, false, engine.ReadLimit, nbc.Execute) if engine.isOneshot { parser.Execute = SyncExecutor } - parser.Conn = tlsConn + parser.Conn = conn parser.Engine = engine + conn.Parser = parser processor.(*ServerProcessor).parser = parser nbc.SetSession(parser) @@ -702,7 +717,7 @@ func (engine *Engine) AddConnTLSNonBlocking(conn net.Conn, tlsConfig *tls.Config } // AddConnTLSBlocking . -func (engine *Engine) AddConnTLSBlocking(conn net.Conn, tlsConfig *tls.Config, decrease func()) { +func (engine *Engine) AddConnTLSBlocking(conn *Conn, tlsConfig *tls.Config, decrease func()) { engine.mux.Lock() if len(engine.conns) >= engine.MaxLoad { engine.mux.Unlock() @@ -712,7 +727,8 @@ func (engine *Engine) AddConnTLSBlocking(conn net.Conn, tlsConfig *tls.Config, d return } - switch vt := conn.(type) { + underLayerConn := conn.Conn + switch vt := underLayerConn.(type) { case *net.TCPConn, *net.UnixConn: key, err := conn2Array(vt) if err != nil { @@ -735,18 +751,20 @@ func (engine *Engine) AddConnTLSBlocking(conn net.Conn, tlsConfig *tls.Config, d isClient := false isNonBlock := true - tlsConn := tls.NewConn(conn, tlsConfig, isClient, isNonBlock, engine.TLSAllocator) - processor := NewServerProcessor(tlsConn, engine.Handler, engine.KeepaliveTime, !engine.DisableSendfile) + tlsConn := tls.NewConn(underLayerConn, tlsConfig, isClient, isNonBlock, engine.TLSAllocator) + conn = &Conn{Conn: tlsConn} + processor := NewServerProcessor(conn, engine.Handler, engine.KeepaliveTime, !engine.DisableSendfile) parser := NewParser(processor, false, engine.ReadLimit, SyncExecutor) - parser.Conn = tlsConn + parser.Conn = conn parser.Engine = engine + conn.Parser = parser processor.(*ServerProcessor).parser = parser conn.SetReadDeadline(time.Now().Add(engine.KeepaliveTime)) tlsConn.SetSession(parser) - go engine.readTLSConnBlocking(conn, tlsConn, parser, decrease) + go engine.readTLSConnBlocking(conn, underLayerConn, tlsConn, parser, decrease) } -func (engine *Engine) readConnBlocking(conn net.Conn, parser *Parser, decrease func()) { +func (engine *Engine) readConnBlocking(conn *Conn, parser *Parser, decrease func()) { var ( n int err error @@ -764,7 +782,7 @@ func (engine *Engine) readConnBlocking(conn net.Conn, parser *Parser, decrease f // go func() { parser.Close(err) engine.mux.Lock() - switch vt := conn.(type) { + switch vt := conn.Conn.(type) { case *net.TCPConn, *net.UnixConn: key, _ := conn2Array(vt) delete(engine.conns, key) @@ -781,13 +799,10 @@ func (engine *Engine) readConnBlocking(conn net.Conn, parser *Parser, decrease f return } parser.Read(buf[:n]) - if parser.hijacked { - return - } } } -func (engine *Engine) readTLSConnBlocking(conn net.Conn, tlsConn *tls.Conn, parser *Parser, decrease func()) { +func (engine *Engine) readTLSConnBlocking(conn *Conn, rconn net.Conn, tlsConn *tls.Conn, parser *Parser, decrease func()) { var ( err error nread int @@ -801,10 +816,13 @@ func (engine *Engine) readTLSConnBlocking(conn net.Conn, tlsConn *tls.Conn, pars buffer := readBufferPool.Malloc(engine.BlockingReadBufferSize) defer func() { readBufferPool.Free(buffer) - parser.Close(err) - tlsConn.Close() + if !conn.Trasfered { + parser.Close(err) + tlsConn.Close() + } + engine.mux.Lock() - switch vt := conn.(type) { + switch vt := rconn.(type) { case *net.TCPConn, *net.UnixConn: key, _ := conn2Array(vt) delete(engine.conns, key) @@ -815,7 +833,7 @@ func (engine *Engine) readTLSConnBlocking(conn net.Conn, tlsConn *tls.Conn, pars }() for { - nread, err = conn.Read(buffer) + nread, err = rconn.Read(buffer) if err != nil { return } @@ -833,9 +851,6 @@ func (engine *Engine) readTLSConnBlocking(conn net.Conn, tlsConn *tls.Conn, pars logging.Debug("parser.Read failed: %v", err) return } - // if parser.hijacked { - // return - // } } if nread == 0 { break @@ -1011,6 +1026,7 @@ func NewEngine(conf Config) *Engine { engine.mux.Lock() key, _ := conn2Array(c) delete(engine.conns, key) + delete(engine.dialerConns, key) engine.mux.Unlock() }) }) diff --git a/nbhttp/parser.go b/nbhttp/parser.go index 9050635c..92c45722 100644 --- a/nbhttp/parser.go +++ b/nbhttp/parser.go @@ -33,10 +33,6 @@ type Parser struct { cache []byte - state int8 - isClient bool - hijacked bool - readLimit int errClose error @@ -66,8 +62,11 @@ type Parser struct { trailer http.Header contentLength int chunkSize int - chunked bool - headerExists bool + + state int8 + chunked bool + isClient bool + headerExists bool } func (p *Parser) nextState(state int8) { diff --git a/nbhttp/processor.go b/nbhttp/processor.go index be0c2a85..1f271ecb 100644 --- a/nbhttp/processor.go +++ b/nbhttp/processor.go @@ -273,11 +273,9 @@ func (p *ServerProcessor) OnComplete(parser *Parser) { } func (p *ServerProcessor) flushResponse(res *Response) { - hijacked := res.hijacked - p.parser.hijacked = hijacked if p.conn != nil { req := res.request - if !hijacked { + if !res.hijacked { res.eoncodeHead() if err := res.flushTrailer(p.conn); err != nil { p.conn.Close() diff --git a/nbhttp/websocket/conn.go b/nbhttp/websocket/conn.go index 2ffbf6f0..d9c03fb3 100644 --- a/nbhttp/websocket/conn.go +++ b/nbhttp/websocket/conn.go @@ -831,7 +831,7 @@ func NewConn(u *Upgrader, c net.Conn, subprotocol string, remoteCompressionEnabl // HandleRead . func (c *Conn) HandleRead(bufSize int) { - if !c.isReadingByParser { + if c.isReadingByParser { return } c.mux.Lock() diff --git a/nbhttp/websocket/dialer.go b/nbhttp/websocket/dialer.go index d1f9ab2f..7dbf0860 100644 --- a/nbhttp/websocket/dialer.go +++ b/nbhttp/websocket/dialer.go @@ -185,7 +185,13 @@ func (d *Dialer) DialContext(ctx context.Context, urlStr string, requestHeader h nbc, ok := conn.(*nbio.Conn) if !ok { - tlsConn, tlsOk := conn.(*tls.Conn) + nbhttpConn, ok2 := conn.(*nbhttp.Conn) + if !ok2 { + err = ErrBadHandshake + notifyResult(err) + return + } + tlsConn, tlsOk := nbhttpConn.Conn.(*tls.Conn) if !tlsOk { err = ErrBadHandshake notifyResult(err) diff --git a/nbhttp/websocket/upgrader.go b/nbhttp/websocket/upgrader.go index 0e7f01cc..f30c7b6e 100644 --- a/nbhttp/websocket/upgrader.go +++ b/nbhttp/websocket/upgrader.go @@ -235,7 +235,16 @@ func (u *Upgrader) Upgrade(w http.ResponseWriter, r *http.Request, responseHeade } } - switch vt := conn.(type) { + var underLayerConn net.Conn + nbhttpConn, isReadingByParser := conn.(*nbhttp.Conn) + if isReadingByParser { + underLayerConn = nbhttpConn.Conn + parser = nbhttpConn.Parser + } else { + underLayerConn = conn + } + + switch vt := underLayerConn.(type) { case *nbio.Conn: // Scenario 1: *nbio.Conn, handled by nbhttp.Engine. parser, ok = vt.Session().(*nbhttp.Parser) @@ -257,6 +266,9 @@ func (u *Upgrader) Upgrade(w http.ResponseWriter, r *http.Request, responseHeade if err != nil { return nil, u.returnError(w, r, http.StatusInternalServerError, err) } + if nbhttpConn != nil { + nbhttpConn.Trasfered = true + } vt.ResetRawInput() parser = &nbhttp.Parser{Execute: nbc.Execute} if engine.EpollMod == nbio.EPOLLET && engine.EPOLLONESHOT == nbio.EPOLLONESHOT { @@ -329,6 +341,9 @@ func (u *Upgrader) Upgrade(w http.ResponseWriter, r *http.Request, responseHeade if err != nil { return nil, u.returnError(w, r, http.StatusInternalServerError, err) } + if nbhttpConn != nil { + nbhttpConn.Trasfered = true + } parser = &nbhttp.Parser{Execute: nbc.Execute} if engine.EpollMod == nbio.EPOLLET && engine.EPOLLONESHOT == nbio.EPOLLONESHOT { parser.Execute = nbhttp.SyncExecutor @@ -374,13 +389,15 @@ func (u *Upgrader) Upgrade(w http.ResponseWriter, r *http.Request, responseHeade return nil, err } - wsc.isReadingByParser = (parser == nil) - if wsc.openHandler != nil { wsc.openHandler(wsc) } - if wsc.isBlockingMod && wsc.isReadingByParser { + if parser != nil { + parser.Reader = wsc + } + wsc.isReadingByParser = isReadingByParser + if wsc.isBlockingMod && (!wsc.isReadingByParser) { var handleRead = true if len(args) > 1 { var b bool diff --git a/poller_kqueue.go b/poller_kqueue.go index 3cee88d0..22a6c1b6 100644 --- a/poller_kqueue.go +++ b/poller_kqueue.go @@ -103,28 +103,30 @@ func (p *poller) trigger() { func (p *poller) addRead(fd int) { p.mux.Lock() - p.eventList = append(p.eventList, syscall.Kevent_t{Ident: uint64(fd), Flags: syscall.EV_ADD | syscall.EV_CLEAR, Filter: syscall.EVFILT_READ}) + p.eventList = append(p.eventList, syscall.Kevent_t{Ident: uint64(fd), Flags: syscall.EV_ADD, Filter: syscall.EVFILT_READ}) p.mux.Unlock() p.trigger() } func (p *poller) resetRead(fd int) { p.mux.Lock() - p.eventList = append(p.eventList, syscall.Kevent_t{Ident: uint64(fd), Flags: syscall.EV_DISABLE, Filter: syscall.EVFILT_WRITE}) + p.eventList = append(p.eventList, syscall.Kevent_t{Ident: uint64(fd), Flags: syscall.EV_DELETE, Filter: syscall.EVFILT_WRITE}) p.mux.Unlock() p.trigger() } func (p *poller) modWrite(fd int) { p.mux.Lock() - p.eventList = append(p.eventList, syscall.Kevent_t{Ident: uint64(fd), Flags: syscall.EV_ADD | syscall.EV_CLEAR, Filter: syscall.EVFILT_WRITE}) + p.eventList = append(p.eventList, syscall.Kevent_t{Ident: uint64(fd), Flags: syscall.EV_ADD, Filter: syscall.EVFILT_WRITE}) p.mux.Unlock() p.trigger() } func (p *poller) deleteEvent(fd int) { p.mux.Lock() - p.eventList = append(p.eventList, syscall.Kevent_t{Ident: uint64(fd), Flags: syscall.EV_DELETE, Filter: syscall.EVFILT_READ}) + p.eventList = append(p.eventList, + syscall.Kevent_t{Ident: uint64(fd), Flags: syscall.EV_DELETE, Filter: syscall.EVFILT_READ}, + syscall.Kevent_t{Ident: uint64(fd), Flags: syscall.EV_DELETE, Filter: syscall.EVFILT_WRITE}) p.mux.Unlock() p.trigger() }