diff --git a/data_reader.go b/data_reader.go index 6367df1..7ea5618 100644 --- a/data_reader.go +++ b/data_reader.go @@ -1,17 +1,27 @@ package telnet - import ( - "bufio" + "bytes" "errors" + "fmt" "io" ) - var ( errCorrupted = errors.New("Corrupted") ) +const ( + IAC = 255 + + SB = 250 + SE = 240 + + WILL = 251 + WONT = 252 + DO = 253 + DONT = 254 +) // An internalDataReader deals with "un-escaping" according to the TELNET protocol. // @@ -56,118 +66,180 @@ var ( // // []byte{1, 55, 2, 155, 3, 255, 4, 40, 255, 30, 20} type internalDataReader struct { - wrapped io.Reader - buffered *bufio.Reader + wrapped io.Reader + state state } - // newDataReader creates a new DataReader reading from 'r'. func newDataReader(r io.Reader) *internalDataReader { - buffered := bufio.NewReader(r) - reader := internalDataReader{ - wrapped:r, - buffered:buffered, + wrapped: r, + state: copyData, } return &reader } +// Read reads the TELNET escaped data from the wrapped io.Reader, and "un-escapes" it into 'data'. +// It executes exactly one Read on the underlying reader every time it is called. +// Callers should be careful to truncate data to the number of bytes read, +// since this reader is expected to drop bytes from the underlying reader +// as described above when required to by the TELNET protocol. +func (r *internalDataReader) Read(data []byte) (int, error) { + mach := &machine{ + from: make([]byte, len(data)), + to: data, + } + n, err := r.wrapped.Read(mach.from) + if err != nil { + return 0, err + } + mach.from = mach.from[:n] + for err == nil && mach.InputRemaining() { + r.state, err = r.state(mach) + } + return mach.written, err +} + +// Unescaping of data read from the underlying reader is done using +// a state machine so it is resumable across reads. +type machine struct { + from, to []byte + read, written int +} + +// Index returns the offset from the read pointer of the first occurrence +// of the byte b, or -1 if that byte is not found in the remainder of the +// data read from the underlying reader. +func (m *machine) Index(b byte) int { + return bytes.Index(m.from[m.read:], []byte{b}) +} + +// Copy copies up to n bytes from the underlying reader to the destination +// buffer, advancing both the read and write pointers by this amount. +func (m *machine) Copy(n int) { + // Deliberately no bounds check here, because asking this + // code to read past the end of m.from should never happen. + copied := copy(m.to[m.written:], m.from[m.read:m.read+n]) + m.written += copied + m.read += copied +} + +// WriteByte writes the provided byte to the destnation buffer and advances +// the write pointer. +func (m *machine) WriteByte(b byte) { + m.to[m.written] = b + m.written++ +} + +// ConsumeByte reads and returns the next byte from the underlying reader, +// advancing the read pointer. +func (m *machine) ConsumeByte() byte { + b := m.from[m.read] + m.read++ + return b +} + +// InputRemaining returns true as long as there is still data available +// to read from the input buffer. +func (m *machine) InputRemaining() bool { + if m.read >= len(m.from) { + return false + } + return true +} + +// State machine states are functions that take the machine and +// return new states and optionally errors. +type state func(*machine) (state, error) + +// The copyData state copies data from machine.from to machine.to +// until it encounters an IAC byte or the end of from. +func copyData(mach *machine) (state, error) { + idx := mach.Index(IAC) + if idx < 0 { + // No escape bytes, so just copy remaining data and return. + mach.Copy(len(mach.from) - mach.read) + return copyData, nil + } + // Copy data up to IAC. + mach.Copy(idx) + return consumeIAC, nil +} + +// The consumeIAC state eats an IAC byte and returns consumeCmd. +func consumeIAC(mach *machine) (state, error) { + if b := mach.ConsumeByte(); b != IAC { + return copyData, fmt.Errorf("expected IAC byte, got %c", b) + } + return consumeCmd, nil +} -// Read reads the TELNET escaped data from the wrapped io.Reader, and "un-escapes" it into 'data'. -func (r *internalDataReader) Read(data []byte) (n int, err error) { - - const IAC = 255 - - const SB = 250 - const SE = 240 - - const WILL = 251 - const WONT = 252 - const DO = 253 - const DONT = 254 - - p := data - - for len(p) > 0 { - var b byte - - b, err = r.buffered.ReadByte() - if nil != err { - return n, err - } - - if IAC == b { - var peeked []byte - - peeked, err = r.buffered.Peek(1) - if nil != err { - return n, err - } - - switch peeked[0] { - case WILL, WONT, DO, DONT: - _, err = r.buffered.Discard(2) - if nil != err { - return n, err - } - case IAC: - p[0] = IAC - n++ - p = p[1:] - - _, err = r.buffered.Discard(1) - if nil != err { - return n, err - } - case SB: - for { - var b2 byte - b2, err = r.buffered.ReadByte() - if nil != err { - return n, err - } - - if IAC == b2 { - peeked, err = r.buffered.Peek(1) - if nil != err { - return n, err - } - - if IAC == peeked[0] { - _, err = r.buffered.Discard(1) - if nil != err { - return n, err - } - } - - if SE == peeked[0] { - _, err = r.buffered.Discard(1) - if nil != err { - return n, err - } - break - } - } - } - case SE: - _, err = r.buffered.Discard(1) - if nil != err { - return n, err - } - default: - // If we get in here, this is not following the TELNET protocol. -//@TODO: Make a better error. - err = errCorrupted - return n, err - } - } else { - - p[0] = b - n++ - p = p[1:] - } +// The consumeCmd state eats one of the known telnet command bytes. +func consumeCmd(mach *machine) (state, error) { + switch b := mach.ConsumeByte(); b { + case WILL, WONT, DO, DONT: + // WILL, WONT, DO and DONT have an extra command byte + // that shouldn't make it to the output slice. + // We need to consume it before going back to copying data. + return consumeWWDD, nil + case IAC: + // IAC IAC => un-escape; write IAC to output + // and go back to copying data. + mach.WriteByte(IAC) + return copyData, nil + case SB: + // IAC SB => switch to consuming status. + return consumeStatus, nil + case SE: + // IAC SE => go back to copying data. + return copyData, nil + default: + // IAC is a protocol error. + return copyData, fmt.Errorf("expected command byte, got %c", b) } +} + +// The consumeWWDD state eats one byte then resumes copying data. +func consumeWWDD(mach *machine) (state, error) { + mach.ConsumeByte() + return copyData, nil +} + +// The consumeStatus state eats data until it encounters an IAC +// byte or the end of from. +func consumeStatus(mach *machine) (state, error) { + // We don't try to understand the status commands, + // we just strip them from the output, which means + // dropping input data until we read IAC SE. + idx := mach.Index(IAC) + if idx < 0 { + // No escape bytes, so just skip remaining data and return. + mach.read = len(mach.from) + return consumeStatus, nil + } + // Skip up to IAC. + mach.read += idx + return consumeStatusIAC, nil +} - return n, nil +// The consumeStatusIAC state eats an IAC byte and returns consumeStatusCmd. +func consumeStatusIAC(mach *machine) (state, error) { + if b := mach.ConsumeByte(); b != IAC { + return consumeStatus, fmt.Errorf("expected IAC byte, got %c", b) + } + return consumeStatusCmd, nil +} + +// The consumeStatusCmd state eats a byte. If that byte is SE the machine +// goes back to copying data, otherwise it goes back to consuming status. +func consumeStatusCmd(mach *machine) (state, error) { + switch b := mach.ConsumeByte(); b { + case SE: + // IAC SE => go back to copying data normally + return copyData, nil + default: + // IAC => continue eating SB + return consumeStatus, nil + } } diff --git a/echo_handler.go b/echo_handler.go index 0401009..57c26cf 100644 --- a/echo_handler.go +++ b/echo_handler.go @@ -1,33 +1,13 @@ package telnet - -import ( - "github.com/reiver/go-oi" -) - +import "io" // EchoHandler is a simple TELNET server which "echos" back to the client any (non-command) // data back to the TELNET client, it received from the TELNET client. var EchoHandler Handler = internalEchoHandler{} - type internalEchoHandler struct{} - func (handler internalEchoHandler) ServeTELNET(ctx Context, w Writer, r Reader) { - - var buffer [1]byte // Seems like the length of the buffer needs to be small, otherwise will have to wait for buffer to fill up. - p := buffer[:] - - for { - n, err := r.Read(p) - - if n > 0 { - oi.LongWrite(w, p[:n]) - } - - if nil != err { - break - } - } + io.Copy(w, r) } diff --git a/standard_caller.go b/standard_caller.go index 17a9408..5db4e49 100644 --- a/standard_caller.go +++ b/standard_caller.go @@ -1,6 +1,5 @@ package telnet - import ( "github.com/reiver/go-oi" @@ -13,47 +12,27 @@ import ( "time" ) - // StandardCaller is a simple TELNET client which sends to the server any data it gets from os.Stdin // as TELNET (and TELNETS) data, and writes any TELNET (or TELNETS) data it receives from // the server to os.Stdout, and writes any error it has to os.Stderr. var StandardCaller Caller = internalStandardCaller{} - type internalStandardCaller struct{} - func (caller internalStandardCaller) CallTELNET(ctx Context, w Writer, r Reader) { standardCallerCallTELNET(os.Stdin, os.Stdout, os.Stderr, ctx, w, r) } - func standardCallerCallTELNET(stdin io.ReadCloser, stdout io.WriteCloser, stderr io.WriteCloser, ctx Context, w Writer, r Reader) { go func(writer io.Writer, reader io.Reader) { - - var buffer [1]byte // Seems like the length of the buffer needs to be small, otherwise will have to wait for buffer to fill up. - p := buffer[:] - - for { - // Read 1 byte. - n, err := reader.Read(p) - if n <= 0 && nil == err { - continue - } else if n <= 0 && nil != err { - break - } - - oi.LongWrite(writer, p) - } + io.Copy(writer, reader) }(stdout, r) - - var buffer bytes.Buffer var p []byte - var crlfBuffer [2]byte = [2]byte{'\r','\n'} + var crlfBuffer [2]byte = [2]byte{'\r', '\n'} crlf := crlfBuffer[:] scanner := bufio.NewScanner(stdin) @@ -75,7 +54,6 @@ func standardCallerCallTELNET(stdin io.ReadCloser, stdout io.WriteCloser, stderr return } - buffer.Reset() } @@ -83,7 +61,6 @@ func standardCallerCallTELNET(stdin io.ReadCloser, stdout io.WriteCloser, stderr time.Sleep(3 * time.Millisecond) } - func scannerSplitFunc(data []byte, atEOF bool) (advance int, token []byte, err error) { if atEOF { return 0, nil, nil diff --git a/telsh/telnet_handler.go b/telsh/telnet_handler.go index ca08c7b..9ff89d4 100644 --- a/telsh/telnet_handler.go +++ b/telsh/telnet_handler.go @@ -1,7 +1,8 @@ package telsh - import ( + "bufio" + "github.com/reiver/go-oi" "github.com/reiver/go-telnet" @@ -11,7 +12,6 @@ import ( "sync" ) - const ( defaultExitCommandName = "exit" defaultPrompt = "ยง " @@ -19,10 +19,9 @@ const ( defaultExitMessage = "\r\nGoodbye!\r\n" ) - type ShellHandler struct { - muxtex sync.RWMutex - producers map[string]Producer + muxtex sync.RWMutex + producers map[string]Producer elseProducer Producer ExitCommandName string @@ -31,12 +30,11 @@ type ShellHandler struct { ExitMessage string } - func NewShellHandler() *ShellHandler { producers := map[string]Producer{} telnetHandler := ShellHandler{ - producers:producers, + producers: producers, Prompt: defaultPrompt, ExitCommandName: defaultExitCommandName, @@ -47,7 +45,6 @@ func NewShellHandler() *ShellHandler { return &telnetHandler } - func (telnetHandler *ShellHandler) Register(name string, producer Producer) error { telnetHandler.muxtex.Lock() @@ -65,7 +62,6 @@ func (telnetHandler *ShellHandler) MustRegister(name string, producer Producer) return telnetHandler } - func (telnetHandler *ShellHandler) RegisterHandlerFunc(name string, handlerFunc HandlerFunc) error { produce := func(ctx telnet.Context, name string, args ...string) Handler { @@ -85,7 +81,6 @@ func (telnetHandler *ShellHandler) MustRegisterHandlerFunc(name string, handlerF return telnetHandler } - func (telnetHandler *ShellHandler) RegisterElse(producer Producer) error { telnetHandler.muxtex.Lock() @@ -97,13 +92,12 @@ func (telnetHandler *ShellHandler) RegisterElse(producer Producer) error { func (telnetHandler *ShellHandler) MustRegisterElse(producer Producer) *ShellHandler { if err := telnetHandler.RegisterElse(producer); nil != err { - panic(err) + panic(err) } return telnetHandler } - func (telnetHandler *ShellHandler) ServeTELNET(ctx telnet.Context, writer telnet.Writer, reader telnet.Reader) { logger := ctx.Logger() @@ -111,23 +105,20 @@ func (telnetHandler *ShellHandler) ServeTELNET(ctx telnet.Context, writer telnet logger = internalDiscardLogger{} } - colonSpaceCommandNotFoundEL := []byte(": command not found\r\n") - - var prompt bytes.Buffer + var prompt bytes.Buffer var exitCommandName string - var welcomeMessage string - var exitMessage string + var welcomeMessage string + var exitMessage string prompt.WriteString(telnetHandler.Prompt) - promptBytes := prompt.Bytes() + promptBytes := prompt.Bytes() exitCommandName = telnetHandler.ExitCommandName - welcomeMessage = telnetHandler.WelcomeMessage - exitMessage = telnetHandler.ExitMessage - + welcomeMessage = telnetHandler.WelcomeMessage + exitMessage = telnetHandler.ExitMessage if _, err := oi.LongWriteString(writer, welcomeMessage); nil != err { logger.Errorf("Problem long writing welcome message: %v", err) @@ -140,158 +131,100 @@ func (telnetHandler *ShellHandler) ServeTELNET(ctx telnet.Context, writer telnet } logger.Debugf("Wrote prompt: %q.", promptBytes) + buffered := bufio.NewReader(reader) - var buffer [1]byte // Seems like the length of the buffer needs to be small, otherwise will have to wait for buffer to fill up. - p := buffer[:] - - var line bytes.Buffer + var err error + var line string - for { - // Read 1 byte. - n, err := reader.Read(p) - if n <= 0 && nil == err { - continue - } else if n <= 0 && nil != err { + for err == nil { + line, err = buffered.ReadString('\n') + if err != nil { break } + if "\r\n" == line { + _, err = oi.LongWrite(writer, promptBytes) + continue + } - line.WriteByte(p[0]) - //logger.Tracef("Received: %q (%d).", p[0], p[0]) - - - if '\n' == p[0] { - lineString := line.String() - - if "\r\n" == lineString { - line.Reset() - if _, err := oi.LongWrite(writer, promptBytes); nil != err { - return - } - continue - } + //@TODO: support piping. + fields := strings.Fields(line) + logger.Debugf("Have %d tokens.", len(fields)) + logger.Tracef("Tokens: %v", fields) + if len(fields) <= 0 { + _, err = oi.LongWrite(writer, promptBytes) + continue + } + field0 := fields[0] -//@TODO: support piping. - fields := strings.Fields(lineString) - logger.Debugf("Have %d tokens.", len(fields)) - logger.Tracef("Tokens: %v", fields) - if len(fields) <= 0 { - line.Reset() - if _, err := oi.LongWrite(writer, promptBytes); nil != err { - return - } - continue - } + if exitCommandName == field0 { + break + } + var producer Producer - field0 := fields[0] + telnetHandler.muxtex.RLock() + var ok bool + producer, ok = telnetHandler.producers[field0] + telnetHandler.muxtex.RUnlock() - if exitCommandName == field0 { - oi.LongWriteString(writer, exitMessage) - return - } + if !ok { + telnetHandler.muxtex.RLock() + producer = telnetHandler.elseProducer + telnetHandler.muxtex.RUnlock() + } + if nil == producer { + //@TODO: Don't convert that to []byte! think this creates "garbage" (for collector). + oi.LongWrite(writer, []byte(field0)) + oi.LongWrite(writer, colonSpaceCommandNotFoundEL) + _, err = oi.LongWrite(writer, promptBytes) + continue + } - var producer Producer + handler := producer.Produce(ctx, field0, fields[1:]...) + if nil == handler { + oi.LongWrite(writer, []byte(field0)) + //@TODO: Need to use a different error message. + oi.LongWrite(writer, colonSpaceCommandNotFoundEL) + _, err = oi.LongWrite(writer, promptBytes) + continue + } - telnetHandler.muxtex.RLock() - var ok bool - producer, ok = telnetHandler.producers[field0] - telnetHandler.muxtex.RUnlock() + //@TODO: Wire up the stdin, stdout, stderr of the handler. - if !ok { - telnetHandler.muxtex.RLock() - producer = telnetHandler.elseProducer - telnetHandler.muxtex.RUnlock() - } - - if nil == producer { -//@TODO: Don't convert that to []byte! think this creates "garbage" (for collector). - oi.LongWrite(writer, []byte(field0)) - oi.LongWrite(writer, colonSpaceCommandNotFoundEL) - line.Reset() - if _, err := oi.LongWrite(writer, promptBytes); nil != err { - return - } - continue - } - - handler := producer.Produce(ctx, field0, fields[1:]...) - if nil == handler { - oi.LongWrite(writer, []byte(field0)) -//@TODO: Need to use a different error message. - oi.LongWrite(writer, colonSpaceCommandNotFoundEL) - line.Reset() - oi.LongWrite(writer, promptBytes) - continue - } - -//@TODO: Wire up the stdin, stdout, stderr of the handler. - - if stdoutPipe, err := handler.StdoutPipe(); nil != err { -//@TODO: - } else if nil == stdoutPipe { -//@TODO: - } else { - connect(ctx, writer, stdoutPipe) - } - - - if stderrPipe, err := handler.StderrPipe(); nil != err { -//@TODO: - } else if nil == stderrPipe { -//@TODO: - } else { - connect(ctx, writer, stderrPipe) - } - - - if err := handler.Run(); nil != err { -//@TODO: - } - line.Reset() - if _, err := oi.LongWrite(writer, promptBytes); nil != err { - return - } + if stdoutPipe, err := handler.StdoutPipe(); nil != err { + //@TODO: + } else if nil == stdoutPipe { + //@TODO: + } else { + connect(ctx, writer, stdoutPipe) } + if stderrPipe, err := handler.StderrPipe(); nil != err { + //@TODO: + } else if nil == stderrPipe { + //@TODO: + } else { + connect(ctx, writer, stderrPipe) + } -//@TODO: Are there any special errors we should be dealing with separately? - if nil != err { - break + if err = handler.Run(); nil != err { + //@TODO: } + _, err = oi.LongWrite(writer, promptBytes) } - oi.LongWriteString(writer, exitMessage) return } - - func connect(ctx telnet.Context, writer io.Writer, reader io.Reader) { logger := ctx.Logger() - go func(logger telnet.Logger){ - - var buffer [1]byte // Seems like the length of the buffer needs to be small, otherwise will have to wait for buffer to fill up. - p := buffer[:] - - for { - // Read 1 byte. - n, err := reader.Read(p) - if n <= 0 && nil == err { - continue - } else if n <= 0 && nil != err { - break - } - - //logger.Tracef("Sending: %q.", p) -//@TODO: Should we be checking for errors? - oi.LongWrite(writer, p) - //logger.Tracef("Sent: %q.", p) - } + go func(logger telnet.Logger) { + io.Copy(writer, reader) }(logger) }