diff --git a/cmd/nbnslistener/nbnslistener.go b/cmd/nbnslistener/nbnslistener.go index fcae81c..7b731d4 100644 --- a/cmd/nbnslistener/nbnslistener.go +++ b/cmd/nbnslistener/nbnslistener.go @@ -7,17 +7,27 @@ import ( "net" "os" "strings" + "time" "github.com/irai/nbns" log "github.com/sirupsen/logrus" ) +var ( + netFlag = flag.String("cidr", "192.168.1.1/24", "network to broadcast nbns to") +) + func main() { flag.Parse() setLogLevel("info") - handler, err := nbns.NewHandler() + _, network, err := net.ParseCIDR(*netFlag) + if err != nil { + log.Fatal("invalid CIDR ", err) + } + + handler, err := nbns.NewHandler(*network) if err != nil { log.Fatal("error in nbns", err) } @@ -35,7 +45,7 @@ func main() { handler.AddNotificationChannel(notify) - go handler.ListenAndServe() + go handler.ListenAndServe(time.Minute * 1) cmd(handler) diff --git a/goroutine.go b/goroutine.go index d7f841c..d5a31ca 100644 --- a/goroutine.go +++ b/goroutine.go @@ -66,6 +66,13 @@ func (h *goroutinePool) Begin(name string) *goroutine { return &g } +func (h *goroutinePool) Stopping() bool { + if atomic.LoadInt32(&h.stopping) != 0 { + return true + } + return false +} + func (g *goroutine) End() { atomic.AddInt32(&g.pool.n, -1) stopping := atomic.LoadInt32(&g.pool.stopping) diff --git a/nbns.go b/nbns.go index c093c28..b9e097b 100644 --- a/nbns.go +++ b/nbns.go @@ -1,20 +1,19 @@ package nbns import ( - "bytes" - "encoding/binary" "fmt" "net" - "strings" "syscall" + "time" log "github.com/sirupsen/logrus" ) // Handler create a new NBNS handler type Handler struct { - conn *net.UDPConn - notification chan<- Entry + conn *net.UDPConn + broadcastAddr net.IP + notification chan<- Entry } // Entry holds a NBNS name entry @@ -24,7 +23,7 @@ type Entry struct { } // NewHandler create a NBNS handler -func NewHandler() (handler *Handler, err error) { +func NewHandler(network net.IPNet) (handler *Handler, err error) { // srcAddr, err := net.ResolveUDPAddr("udp4", "127.0.0.1:0") handler = &Handler{} @@ -32,33 +31,21 @@ func NewHandler() (handler *Handler, err error) { log.Error("NBNS failed to bind UDP port 137 ", err) return nil, err } + + // calculate broadcast addr + handler.broadcastAddr = net.IP(make([]byte, 4)) + for i := range network.IP { + handler.broadcastAddr[i] = network.IP[i] | ^network.Mask[i] + } return handler, nil } -// SendQuery send a NBNS query -// 4.2.12. NAME QUERY REQUEST -// -// 1 1 1 1 1 1 1 1 1 1 2 2 2 2 2 2 2 2 2 2 3 3 -// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 -// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ -// | NAME_TRN_ID |0| 0x0 |0|0|1|0|0 0|B| 0x0 | -// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ -// | 0x0001 | 0x0000 | -// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ -// | 0x0000 | 0x0000 | -// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ -// | | -// / QUESTION_NAME / -// / / -// | | -// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ -// | NB (0x0020) | IN (0x0001) | -// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// SendQuery send NBNS node status request query func (h *Handler) SendQuery(ip net.IP) (err error) { packet := nodeStatusRequestWireFormat(`* `) - packet.printHeader() + // packet.printHeader() if ip == nil || ip.Equal(net.IPv4zero) { return fmt.Errorf("invalid IP nil %v", ip) @@ -68,7 +55,9 @@ func (h *Handler) SendQuery(ip net.IP) (err error) { // To broadcast, use network broadcast i.e 192.168.0.255 for example. targetAddr := &net.UDPAddr{IP: ip, Port: 137} if _, err = h.conn.WriteToUDP(packet, targetAddr); err != nil { - log.Error("NPNS failed to send nbns packet ", err) + if !GoroutinePool.Stopping() { + log.Error("NPNS failed to send nbns packet ", err) + } return err } return nil @@ -85,85 +74,31 @@ func (h *Handler) Stop() { GoroutinePool.Stop() // will stop all goroutines } -// processNodeStatusResponse -// 4.2.18. NODE STATUS RESPONSE -// 1 1 1 1 1 1 1 1 1 1 2 2 2 2 2 2 2 2 2 2 3 3 -// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 -// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ -// | Header | -// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ -// / RR_NAME (variable len) / -// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ -// | NBSTAT (0x0021) | IN (0x0001) | -// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ -// | 0x00000000 | -// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ -// | RDLENGTH | NUM_NAMES | | -// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + -// / NODE_NAME ARRAY (variable len) / -// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ -// / STATISTICS (variable len) / -// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ -func (h *Handler) processNameQueryResponse(packet packet, ip net.IP) error { - - log.Debug("nbns received name query packet") - packet.printHeader() - // Assume no questions - if packet.qdCount() > 0 { - log.Printf("unexpected qdcount %v ", packet.qdCount()) - } +func (h *Handler) broadcastLoop(interval time.Duration) { + g := GoroutinePool.Begin("nbns broadcastLoop") + defer g.End() - // Assume a single Response name - if packet.anCount() <= 0 { - return fmt.Errorf("unexpected ancount %v ", packet.anCount()) - } + for { + h.SendQuery(h.broadcastAddr) + select { + case <-GoroutinePool.StopChannel: + return - // variable len reading - buf := bytes.NewBuffer(packet.payload()) - name := decodeNBNSName(buf) - log.Printf("name: |%s|\n", name) - - var tmp16 uint16 - var numNames uint8 - binary.Read(buf, binary.BigEndian, &tmp16) // type - binary.Read(buf, binary.BigEndian, &tmp16) // internet - binary.Read(buf, binary.BigEndian, &tmp16) // TTL is 32 bits - binary.Read(buf, binary.BigEndian, &tmp16) // TTL - binary.Read(buf, binary.BigEndian, &tmp16) // RDLength - binary.Read(buf, binary.BigEndian, &numNames) // numNames - - tmpName := make([]byte, 16) - table := []string{} - for i := 0; i < int(numNames); i++ { - binary.Read(buf, binary.BigEndian, &tmpName) - binary.Read(buf, binary.BigEndian, &tmp16) // nameFlags - // log.Infof("names %q nFlags %02x", tmpName, tmp16) - if (tmp16 & 0x8000) == 0x00 { // don't add to the table if this is group name - t := strings.TrimRight(string(tmpName), " \x00") - t = strings.TrimRight(t, " \x03") // not sure why some have 03 at the end - t = strings.TrimRight(t, " \x1d") // not sure why some have 1d at the end - table = append(table, t) + case <-time.After(interval): } } - entry := Entry{IP: ip, Name: table[0]} // first entry - log.Info("nodes ", entry.Name, entry.IP, table) - if h.notification != nil { - log.Debugf("nbns send notification name %s ip %s", entry.Name, entry.IP) - h.notification <- entry - } - - return nil } // ListenAndServe main listening loop -func (h *Handler) ListenAndServe() error { +func (h *Handler) ListenAndServe(interval time.Duration) error { g := GoroutinePool.Begin("nbns ListenAndServe") defer g.End() + go h.broadcastLoop(interval) + readBuffer := make([]byte, 1024) - // h.conn.SetDeadline(time.Now().Add(200 * time.Millisecond)) for !g.Stopping() { _, udpAddr, err := h.conn.ReadFromUDP(readBuffer) if g.Stopping() { @@ -197,16 +132,29 @@ func (h *Handler) ListenAndServe() error { return err } - log.Info("nbns received nbns packet from IP ", *udpAddr) packet := packet(readBuffer) switch { case packet.opcode() == opcodeQuery && packet.response() == 1: - if err := h.processNameQueryResponse(packet, udpAddr.IP); err != nil { - log.Error(err) + log.Info("nbns received nbns nodeStatusResponse from IP ", *udpAddr) + entry, err := parseNodeStatusResponsePacket(packet, udpAddr.IP) + if err != nil { + log.Error("error processing nodeStatusResponse ", err) + return err + } + + if h.notification != nil { + log.Debugf("nbns send notification name %s ip %s", entry.Name, entry.IP) + h.notification <- entry + } + + case packet.response() == 0: + if packet.trnID() != sequence { // ignore our own request + log.Info("nbns not implemented - recvd nbns request from IP ", *udpAddr) + packet.printHeader() } default: - log.Infof("nbns packet opcode=%v not implemented ", packet.opcode()) + log.Infof("nbns not implemented opcode=%v ", packet.opcode()) packet.printHeader() } diff --git a/packet.go b/packet.go index 2e99b8c..13b6722 100644 --- a/packet.go +++ b/packet.go @@ -4,6 +4,7 @@ import ( "bytes" "encoding/binary" "fmt" + "net" "strings" log "github.com/sirupsen/logrus" @@ -57,17 +58,15 @@ const ( questionClassInternet = 0x0001 ) -var sequence uint16 = 1 +var sequence uint16 = 1 // incremented for every packet sent // encodeNBNSName creates a 34 byte string = 1 char length + 32 char netbios name + 1 final length (0x00). -// Netbios names are 16 bytes long = 15 characters plus space(0x20) -// -// +// Netbios names are 16 bytes long = 16 characters (two bytes per character) func encodeNBNSName(name string) string { - // Netbios name is limited to 15 characters long + // Netbios name is limited to 16 characters long if len(name) > netbiosMaxNameLen { - name = name[:netbiosMaxNameLen] // truncate if name too long + name = name[:netbiosMaxNameLen-1] // truncate if name too long } if len(name) < netbiosMaxNameLen { @@ -102,8 +101,6 @@ func decodeNBNSName(buf *bytes.Buffer) (name string) { return "" } - // fmt.Printf("RR name len %v name % x \n", length, name) - // A label length count is actually a 6-bit field in the label length // field. The most significant 2 bits of the field, bits 7 and 6, are // flags allowing an escape from the above compressed representation. @@ -120,8 +117,7 @@ func decodeNBNSName(buf *bytes.Buffer) (name string) { return "" } - // 0 is len; name starts at 1 - tmp = tmp[1:] + tmp = tmp[1:] // 0 is len; name starts at 1 for i := 0; i < 32; i = i + 2 { character := ((tmp[i] - 'A') << 4) | (tmp[i+1] - 'A') name = name + string(character) @@ -135,7 +131,6 @@ func decodeNBNSName(buf *bytes.Buffer) (name string) { // +---+---+---+---+---+---+---+---+---+---+---+---+---+---+---+---+ // | R | Opcode |AA |TC |RD |RA | 0 | 0 | B | Rcode | // +---+---+---+---+---+---+---+---+---+---+---+---+---+---+---+---+ -// Work in progress - TBD type packet []byte func (p packet) response() byte { return p[2] >> 7 } @@ -174,7 +169,20 @@ func (p packet) printHeader() { fmt.Printf("NSCount 0x%04x | ARCount 0x%04x\n", p.nsCount(), p.arCount()) } -// Node Status Request - Packet layout must be packed with bigendian +// nameQueryWireFormat Name Query Request +func nameQueryWireFormat(name string) (packet packet) { + return query(name, questionTypeGeneral) +} + +// same as name query but type 21 +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | NBSTAT (0x0021) | IN (0x0001) | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +func nodeStatusRequestWireFormat(name string) (packet packet) { + return query(name, questionTypeNodeStatus) +} + +// query Packet layout must be packed with bigendian // // 1 1 1 1 1 1 1 1 1 1 2 2 2 2 2 2 2 2 2 2 3 3 // 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 @@ -191,18 +199,6 @@ func (p packet) printHeader() { // +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ // | NBSTAT (0x0020) | IN (0x0001) | // +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ -func nameQueryWireFormat(name string) (packet packet) { - return query(name, questionTypeGeneral) -} - -// same as name query but type 21 -// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ -// | NBSTAT (0x0021) | IN (0x0001) | -// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ -func nodeStatusRequestWireFormat(name string) (packet packet) { - return query(name, questionTypeNodeStatus) -} - func query(name string, questionType uint16) (packet packet) { const word = uint16(responseRequest | opcodeQuery | nmflagsBroadcast | rcodeOK) @@ -221,3 +217,70 @@ func query(name string, questionType uint16) (packet packet) { binary.Write(buf, binary.BigEndian, uint16(questionClassInternet)) // QClass return buf.Bytes() } + +// parseNodeStatusResponse +// 4.2.18. NODE STATUS RESPONSE +// 1 1 1 1 1 1 1 1 1 1 2 2 2 2 2 2 2 2 2 2 3 3 +// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | Header | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// / RR_NAME (variable len) / +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | NBSTAT (0x0021) | IN (0x0001) | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | 0x00000000 | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | RDLENGTH | NUM_NAMES | | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + +// / NODE_NAME ARRAY (variable len) / +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// / STATISTICS (variable len) / +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +func parseNodeStatusResponsePacket(packet packet, ip net.IP) (entry Entry, err error) { + log.Debug("nbns parsing node status response packet") + + // Assume no questions + if packet.qdCount() > 0 { + log.Errorf("nbns unexpected qdcount %v ", packet.qdCount()) + } + + // Assume a single Response name + if packet.anCount() <= 0 { + return Entry{}, fmt.Errorf("unexpected ancount %v ", packet.anCount()) + } + + // variable len reading + buf := bytes.NewBuffer(packet.payload()) + name := decodeNBNSName(buf) + log.Debugf("nbns name: |%s|\n", name) + + var tmp16 uint16 + var numNames uint8 + binary.Read(buf, binary.BigEndian, &tmp16) // type + binary.Read(buf, binary.BigEndian, &tmp16) // internet + binary.Read(buf, binary.BigEndian, &tmp16) // TTL is 32 bits + binary.Read(buf, binary.BigEndian, &tmp16) // TTL + binary.Read(buf, binary.BigEndian, &tmp16) // RDLength + binary.Read(buf, binary.BigEndian, &numNames) // numNames + + tmpName := make([]byte, 16) + table := []string{} + for i := 0; i < int(numNames); i++ { + binary.Read(buf, binary.BigEndian, &tmpName) + binary.Read(buf, binary.BigEndian, &tmp16) // nameFlags + // log.Infof("names %q nFlags %02x", tmpName, tmp16) + if (tmp16 & 0x8000) == 0x00 { // don't add to the table if this is group name + t := strings.TrimRight(string(tmpName), " \x00") + t = strings.TrimRight(t, " \x03") // not sure why some have 03 at the end + t = strings.TrimRight(t, " \x1d") // not sure why some have 1d at the end + table = append(table, t) + } + } + + entry = Entry{IP: ip, Name: table[0]} // first entry + log.Debugf("nbns new entry name %s ip %s", entry.Name, entry.IP) + log.Debug("nbns node names ", table) + + return entry, nil +}