Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Dev #367

Merged
merged 11 commits into from
Nov 16, 2023
5 changes: 3 additions & 2 deletions nbhttp/client_conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -318,10 +318,11 @@ func (c *ClientConn) Do(req *http.Request, handler func(res *http.Response, conn
isNonblock := true
tlsConn.ResetConn(nbc, isNonblock)

c.conn = tlsConn
nbhttpConn := &Conn{Conn: tlsConn}
c.conn = nbhttpConn
processor := NewClientProcessor(c, c.onResponse)
parser := NewParser(processor, true, engine.ReadLimit, nbc.Execute)
parser.Conn = tlsConn
parser.Conn = nbhttpConn
parser.Engine = engine
parser.OnClose(func(p *Parser, err error) {
c.CloseWithError(err)
Expand Down
120 changes: 68 additions & 52 deletions nbhttp/engine.go
Original file line number Diff line number Diff line change
Expand Up @@ -279,7 +279,13 @@ func (e *Engine) closeAllConns() {
}
}

func (e *Engine) listen(ln net.Listener, tlsConfig *tls.Config, addConn func(net.Conn, *tls.Config, func()), decrease func()) {
type Conn struct {
net.Conn
Parser *Parser
Trasfered bool
}

func (e *Engine) listen(ln net.Listener, tlsConfig *tls.Config, addConn func(*Conn, *tls.Config, func()), decrease func()) {
e.WaitGroup.Add(1)
go func() {
defer func() {
Expand All @@ -289,7 +295,7 @@ func (e *Engine) listen(ln net.Listener, tlsConfig *tls.Config, addConn func(net
for !e.shutdown {
conn, err := ln.Accept()
if err == nil && !e.shutdown {
addConn(conn, tlsConfig, decrease)
addConn(&Conn{Conn: conn}, tlsConfig, decrease)
} else {
var ne net.Error
if ok := errors.As(err, &ne); ok && ne.Temporary() {
Expand Down Expand Up @@ -520,31 +526,34 @@ func (e *Engine) TLSDataHandler(c *nbio.Conn, data []byte) {
c.Close()
return
}
if tlsConn, ok := parser.Processor.Conn().(*tls.Conn); ok {
defer tlsConn.ResetOrFreeBuffer()

readed := data
buffer := data
for {
_, nread, err := tlsConn.AppendAndRead(readed, buffer)
readed = nil
if err != nil {
c.CloseWithError(err)
return
}
if nread > 0 {
err := parser.Read(buffer[:nread])
nbhttpConn, ok := parser.Processor.Conn().(*Conn)
if ok {
if tlsConn, ok := nbhttpConn.Conn.(*tls.Conn); ok {
defer tlsConn.ResetOrFreeBuffer()

readed := data
buffer := data
for {
_, nread, err := tlsConn.AppendAndRead(readed, buffer)
readed = nil
if err != nil {
logging.Debug("parser.Read failed: %v", err)
c.CloseWithError(err)
return
}
if nread > 0 {
err := parser.Read(buffer[:nread])
if err != nil {
logging.Debug("parser.Read failed: %v", err)
c.CloseWithError(err)
return
}
}
if nread == 0 {
return
}
}
if nread == 0 {
return
}
// c.SetReadDeadline(time.Now().Add(conf.KeepaliveTime))
}
// c.SetReadDeadline(time.Now().Add(conf.KeepaliveTime))
}
}

Expand Down Expand Up @@ -572,13 +581,14 @@ func (engine *Engine) AddTransferredConn(nbc *nbio.Conn) error {
}

// AddConnNonTLSNonBlocking .
func (engine *Engine) AddConnNonTLSNonBlocking(c net.Conn, tlsConfig *tls.Config, decrease func()) {
nbc, err := nbio.NBConn(c)
func (engine *Engine) AddConnNonTLSNonBlocking(conn *Conn, tlsConfig *tls.Config, decrease func()) {
nbc, err := nbio.NBConn(conn.Conn)
if err != nil {
c.Close()
conn.Close()
logging.Error("AddConnNonTLSNonBlocking failed: %v", err)
return
}
conn.Conn = nbc
if nbc.Session() != nil {
nbc.Close()
return
Expand All @@ -599,13 +609,14 @@ func (engine *Engine) AddConnNonTLSNonBlocking(c net.Conn, tlsConfig *tls.Config
}
engine.conns[key] = struct{}{}
engine.mux.Unlock()
engine._onOpen(nbc)
processor := NewServerProcessor(nbc, engine.Handler, engine.KeepaliveTime, !engine.DisableSendfile)
engine._onOpen(conn.Conn)
processor := NewServerProcessor(conn, engine.Handler, engine.KeepaliveTime, !engine.DisableSendfile)
parser := NewParser(processor, false, engine.ReadLimit, nbc.Execute)
if engine.isOneshot {
parser.Execute = SyncExecutor
}
parser.Engine = engine
conn.Parser = parser
processor.(*ServerProcessor).parser = parser
nbc.SetSession(parser)
nbc.OnData(engine.DataHandler)
Expand All @@ -614,7 +625,7 @@ func (engine *Engine) AddConnNonTLSNonBlocking(c net.Conn, tlsConfig *tls.Config
}

// AddConnNonTLSBlocking .
func (engine *Engine) AddConnNonTLSBlocking(conn net.Conn, tlsConfig *tls.Config, decrease func()) {
func (engine *Engine) AddConnNonTLSBlocking(conn *Conn, tlsConfig *tls.Config, decrease func()) {
engine.mux.Lock()
if len(engine.conns) >= engine.MaxLoad {
engine.mux.Unlock()
Expand All @@ -623,7 +634,7 @@ func (engine *Engine) AddConnNonTLSBlocking(conn net.Conn, tlsConfig *tls.Config
decrease()
return
}
switch vt := conn.(type) {
switch vt := conn.Conn.(type) {
case *net.TCPConn, *net.UnixConn:
key, err := conn2Array(vt)
if err != nil {
Expand All @@ -646,19 +657,21 @@ func (engine *Engine) AddConnNonTLSBlocking(conn net.Conn, tlsConfig *tls.Config
processor := NewServerProcessor(conn, engine.Handler, engine.KeepaliveTime, !engine.DisableSendfile)
parser := NewParser(processor, false, engine.ReadLimit, SyncExecutor)
parser.Engine = engine
conn.Parser = parser
processor.(*ServerProcessor).parser = parser
conn.SetReadDeadline(time.Now().Add(engine.KeepaliveTime))
go engine.readConnBlocking(conn, parser, decrease)
}

// AddConnTLSNonBlocking .
func (engine *Engine) AddConnTLSNonBlocking(conn net.Conn, tlsConfig *tls.Config, decrease func()) {
nbc, err := nbio.NBConn(conn)
func (engine *Engine) AddConnTLSNonBlocking(conn *Conn, tlsConfig *tls.Config, decrease func()) {
nbc, err := nbio.NBConn(conn.Conn)
if err != nil {
conn.Close()
logging.Error("AddConnTLSNonBlocking failed: %v", err)
return
}
conn.Conn = nbc
if nbc.Session() != nil {
nbc.Close()
logging.Error("AddConnTLSNonBlocking failed: session should not be nil")
Expand All @@ -681,18 +694,20 @@ func (engine *Engine) AddConnTLSNonBlocking(conn net.Conn, tlsConfig *tls.Config

engine.conns[key] = struct{}{}
engine.mux.Unlock()
engine._onOpen(nbc)
engine._onOpen(conn.Conn)

isClient := false
isNonBlock := true
tlsConn := tls.NewConn(nbc, tlsConfig, isClient, isNonBlock, engine.TLSAllocator)
processor := NewServerProcessor(tlsConn, engine.Handler, engine.KeepaliveTime, !engine.DisableSendfile)
conn = &Conn{Conn: tlsConn}
processor := NewServerProcessor(conn, engine.Handler, engine.KeepaliveTime, !engine.DisableSendfile)
parser := NewParser(processor, false, engine.ReadLimit, nbc.Execute)
if engine.isOneshot {
parser.Execute = SyncExecutor
}
parser.Conn = tlsConn
parser.Conn = conn
parser.Engine = engine
conn.Parser = parser
processor.(*ServerProcessor).parser = parser
nbc.SetSession(parser)

Expand All @@ -702,7 +717,7 @@ func (engine *Engine) AddConnTLSNonBlocking(conn net.Conn, tlsConfig *tls.Config
}

// AddConnTLSBlocking .
func (engine *Engine) AddConnTLSBlocking(conn net.Conn, tlsConfig *tls.Config, decrease func()) {
func (engine *Engine) AddConnTLSBlocking(conn *Conn, tlsConfig *tls.Config, decrease func()) {
engine.mux.Lock()
if len(engine.conns) >= engine.MaxLoad {
engine.mux.Unlock()
Expand All @@ -712,7 +727,8 @@ func (engine *Engine) AddConnTLSBlocking(conn net.Conn, tlsConfig *tls.Config, d
return
}

switch vt := conn.(type) {
underLayerConn := conn.Conn
switch vt := underLayerConn.(type) {
case *net.TCPConn, *net.UnixConn:
key, err := conn2Array(vt)
if err != nil {
Expand All @@ -735,18 +751,20 @@ func (engine *Engine) AddConnTLSBlocking(conn net.Conn, tlsConfig *tls.Config, d

isClient := false
isNonBlock := true
tlsConn := tls.NewConn(conn, tlsConfig, isClient, isNonBlock, engine.TLSAllocator)
processor := NewServerProcessor(tlsConn, engine.Handler, engine.KeepaliveTime, !engine.DisableSendfile)
tlsConn := tls.NewConn(underLayerConn, tlsConfig, isClient, isNonBlock, engine.TLSAllocator)
conn = &Conn{Conn: tlsConn}
processor := NewServerProcessor(conn, engine.Handler, engine.KeepaliveTime, !engine.DisableSendfile)
parser := NewParser(processor, false, engine.ReadLimit, SyncExecutor)
parser.Conn = tlsConn
parser.Conn = conn
parser.Engine = engine
conn.Parser = parser
processor.(*ServerProcessor).parser = parser
conn.SetReadDeadline(time.Now().Add(engine.KeepaliveTime))
tlsConn.SetSession(parser)
go engine.readTLSConnBlocking(conn, tlsConn, parser, decrease)
go engine.readTLSConnBlocking(conn, underLayerConn, tlsConn, parser, decrease)
}

func (engine *Engine) readConnBlocking(conn net.Conn, parser *Parser, decrease func()) {
func (engine *Engine) readConnBlocking(conn *Conn, parser *Parser, decrease func()) {
var (
n int
err error
Expand All @@ -764,7 +782,7 @@ func (engine *Engine) readConnBlocking(conn net.Conn, parser *Parser, decrease f
// go func() {
parser.Close(err)
engine.mux.Lock()
switch vt := conn.(type) {
switch vt := conn.Conn.(type) {
case *net.TCPConn, *net.UnixConn:
key, _ := conn2Array(vt)
delete(engine.conns, key)
Expand All @@ -781,13 +799,10 @@ func (engine *Engine) readConnBlocking(conn net.Conn, parser *Parser, decrease f
return
}
parser.Read(buf[:n])
if parser.hijacked {
return
}
}
}

func (engine *Engine) readTLSConnBlocking(conn net.Conn, tlsConn *tls.Conn, parser *Parser, decrease func()) {
func (engine *Engine) readTLSConnBlocking(conn *Conn, rconn net.Conn, tlsConn *tls.Conn, parser *Parser, decrease func()) {
var (
err error
nread int
Expand All @@ -801,10 +816,13 @@ func (engine *Engine) readTLSConnBlocking(conn net.Conn, tlsConn *tls.Conn, pars
buffer := readBufferPool.Malloc(engine.BlockingReadBufferSize)
defer func() {
readBufferPool.Free(buffer)
parser.Close(err)
tlsConn.Close()
if !conn.Trasfered {
parser.Close(err)
tlsConn.Close()
}

engine.mux.Lock()
switch vt := conn.(type) {
switch vt := rconn.(type) {
case *net.TCPConn, *net.UnixConn:
key, _ := conn2Array(vt)
delete(engine.conns, key)
Expand All @@ -815,7 +833,7 @@ func (engine *Engine) readTLSConnBlocking(conn net.Conn, tlsConn *tls.Conn, pars
}()

for {
nread, err = conn.Read(buffer)
nread, err = rconn.Read(buffer)
if err != nil {
return
}
Expand All @@ -833,9 +851,6 @@ func (engine *Engine) readTLSConnBlocking(conn net.Conn, tlsConn *tls.Conn, pars
logging.Debug("parser.Read failed: %v", err)
return
}
// if parser.hijacked {
// return
// }
}
if nread == 0 {
break
Expand Down Expand Up @@ -1011,6 +1026,7 @@ func NewEngine(conf Config) *Engine {
engine.mux.Lock()
key, _ := conn2Array(c)
delete(engine.conns, key)
delete(engine.dialerConns, key)
engine.mux.Unlock()
})
})
Expand Down
11 changes: 5 additions & 6 deletions nbhttp/parser.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,10 +33,6 @@ type Parser struct {

cache []byte

state int8
isClient bool
hijacked bool

readLimit int

errClose error
Expand Down Expand Up @@ -66,8 +62,11 @@ type Parser struct {
trailer http.Header
contentLength int
chunkSize int
chunked bool
headerExists bool

state int8
chunked bool
isClient bool
headerExists bool
}

func (p *Parser) nextState(state int8) {
Expand Down
4 changes: 1 addition & 3 deletions nbhttp/processor.go
Original file line number Diff line number Diff line change
Expand Up @@ -273,11 +273,9 @@ func (p *ServerProcessor) OnComplete(parser *Parser) {
}

func (p *ServerProcessor) flushResponse(res *Response) {
hijacked := res.hijacked
p.parser.hijacked = hijacked
if p.conn != nil {
req := res.request
if !hijacked {
if !res.hijacked {
res.eoncodeHead()
if err := res.flushTrailer(p.conn); err != nil {
p.conn.Close()
Expand Down
2 changes: 1 addition & 1 deletion nbhttp/websocket/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -831,7 +831,7 @@ func NewConn(u *Upgrader, c net.Conn, subprotocol string, remoteCompressionEnabl

// HandleRead .
func (c *Conn) HandleRead(bufSize int) {
if !c.isReadingByParser {
if c.isReadingByParser {
return
}
c.mux.Lock()
Expand Down
8 changes: 7 additions & 1 deletion nbhttp/websocket/dialer.go
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,13 @@ func (d *Dialer) DialContext(ctx context.Context, urlStr string, requestHeader h

nbc, ok := conn.(*nbio.Conn)
if !ok {
tlsConn, tlsOk := conn.(*tls.Conn)
nbhttpConn, ok2 := conn.(*nbhttp.Conn)
if !ok2 {
err = ErrBadHandshake
notifyResult(err)
return
}
tlsConn, tlsOk := nbhttpConn.Conn.(*tls.Conn)
if !tlsOk {
err = ErrBadHandshake
notifyResult(err)
Expand Down
Loading
Loading