diff --git a/.deepsource.toml b/.deepsource.toml new file mode 100644 index 0000000..99adc62 --- /dev/null +++ b/.deepsource.toml @@ -0,0 +1,9 @@ +version = 1 + +[[analyzers]] +name = "go" +enabled = true + + [analyzers.meta] + import_root = "github.com/Sandertv/go-raknet" + dependencies_vendored = true diff --git a/.github/workflows/go.yml b/.github/workflows/go.yml new file mode 100644 index 0000000..a8caa38 --- /dev/null +++ b/.github/workflows/go.yml @@ -0,0 +1,34 @@ +name: Go +on: [push] +jobs: + + build: + name: Build + runs-on: ubuntu-latest + steps: + + - name: Set up Go 1.19 + uses: actions/setup-go@v1 + with: + go-version: 1.19 + id: go + + - name: Check out code into the Go module directory + uses: actions/checkout@v1 + + - name: Get dependencies + run: | + mkdir -p $GOPATH/bin + export PATH=$PATH:$GOPATH/bin + + - name: Vet + run: go vet ./... + + - name: Staticcheck + run: | + go get honnef.co/go/tools/cmd/staticcheck + GOBIN=$PWD/bin go install honnef.co/go/tools/cmd/staticcheck + ./bin/staticcheck ./... + + - name: Build + run: go build -o raknet_exe -v . \ No newline at end of file diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..703d8e4 --- /dev/null +++ b/.gitignore @@ -0,0 +1,71 @@ +# Created by .ignore support plugin (hsz.mobi) +### Go template +# Binaries for programs and plugins +*.exe +*.exe~ +*.dll +*.so +*.dylib + +# Test binary, build with `go test -c` +*.test + +# Output of the go coverage tool, specifically when used with LiteIDE +*.out +### JetBrains template +# Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio and WebStorm +# Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839 + +# User-specific stuff +idea/ + +.idea/**/workspace.xml +.idea/**/tasks.xml +.idea/**/dictionaries +.idea/**/shelf + +# Sensitive or high-churn files +.idea/**/dataSources/ +.idea/**/dataSources.ids +.idea/**/dataSources.local.xml +.idea/**/sqlDataSources.xml +.idea/**/dynamic.xml +.idea/**/uiDesigner.xml +.idea/**/dbnavigator.xml + +# Gradle +.idea/**/gradle.xml +.idea/**/libraries + +# CMake +cmake-build-debug/ +cmake-build-release/ + +# Mongo Explorer plugin +.idea/**/mongoSettings.xml + +# File-based project format +*.iws + +# IntelliJ +out/ + +# mpeltonen/sbt-idea plugin +.idea_modules/ + +# JIRA plugin +atlassian-ide-plugin.xml + +# Cursive Clojure plugin +.idea/replstate.xml + +# Crashlytics plugin (for Android Studio and IntelliJ) +com_crashlytics_export_strings.xml +crashlytics.properties +crashlytics-build.properties +fabric.properties + +# Editor-based Rest Client +.idea/httpRequests + +.vscode \ No newline at end of file diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..e7a10c9 --- /dev/null +++ b/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2019 Sander ten Veldhuis + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/README.md b/README.md new file mode 100644 index 0000000..a0c22df --- /dev/null +++ b/README.md @@ -0,0 +1,65 @@ +# go-raknet + +go-raknet is a library that implements a basic version of the RakNet protocol, which is used for +Minecraft (Bedrock Edition). It implements Unreliable, Reliable and +ReliableOrdered packets and sends user packets as ReliableOrdered. + +go-raknet attempts to abstract away direct interaction with RakNet, and provides simple to use, idiomatic Go +API used to listen for connections or connect to servers. + +## Getting started + +### Prerequisites +To use this library, at least **Go 1.18** must be installed. + +### Usage +go-raknet can be used for both clients and servers, (and proxies, when combined) in a way very similar to the +standard net.TCP* functions. + +Basic RakNet server: +```go +package main + +import ( + "github.com/sandertv/go-raknet" +) + +func main() { + listener, _ := raknet.Listen("0.0.0.0:19132") + defer listener.Close() + for { + conn, _ := listener.Accept() + + b := make([]byte, 1024*1024*4) + _, _ = conn.Read(b) + _, _ = conn.Write([]byte{1, 2, 3}) + + conn.Close() + } +} +``` + +Basic RakNet client: + +```go +package main + +import ( + "github.com/sandertv/go-raknet" +) + +func main() { + conn, _ := raknet.Dial("mco.mineplex.com:19132") + defer conn.Close() + + b := make([]byte, 1024*1024*4) + _, _ = conn.Write([]byte{1, 2, 3}) + _, _ = conn.Read(b) +} +``` + +### Documentation +[![PkgGoDev](https://pkg.go.dev/badge/github.com/sandertv/go-raknet)](https://pkg.go.dev/github.com/sandertv/go-raknet) + +## Contact +[![Discord Banner 2](https://discordapp.com/api/guilds/623638955262345216/widget.png?style=banner2)](https://discord.gg/U4kFWHhTNR) diff --git a/binary.go b/binary.go new file mode 100644 index 0000000..9b60f01 --- /dev/null +++ b/binary.go @@ -0,0 +1,29 @@ +package raknet + +import ( + "bytes" + "fmt" +) + +// uint24 represents an integer existing out of 3 bytes. It is actually a uint32, but is an alias for the +// sake of clarity. +type uint24 uint32 + +// readUint24 reads 3 bytes from the buffer passed and combines it into a uint24. If there were no 3 bytes to +// read, an error is returned. +func readUint24(b *bytes.Buffer) (uint24, error) { + ba, _ := b.ReadByte() + bb, _ := b.ReadByte() + bc, err := b.ReadByte() + if err != nil { + return 0, fmt.Errorf("error reading uint24: %v", err) + } + return uint24(ba) | (uint24(bb) << 8) | (uint24(bc) << 16), nil +} + +// writeUint24 writes a uint24 to the buffer passed as 3 bytes. If not successful, an error is returned. +func writeUint24(b *bytes.Buffer, value uint24) { + b.WriteByte(byte(value)) + b.WriteByte(byte(value >> 8)) + b.WriteByte(byte(value >> 16)) +} diff --git a/binary_test.go b/binary_test.go new file mode 100644 index 0000000..87d7a4a --- /dev/null +++ b/binary_test.go @@ -0,0 +1,18 @@ +package raknet + +import ( + "bytes" + "testing" +) + +func Test_uint24(t *testing.T) { + b := bytes.NewBuffer(nil) + writeUint24(b, 123456) + val, err := readUint24(b) + if err != nil { + t.Fatalf("error reading uint24: %v", err) + } + if val != 123456 { + t.Fatal("read uint24 was not equal to 123456") + } +} diff --git a/conn.go b/conn.go new file mode 100644 index 0000000..43e9aaf --- /dev/null +++ b/conn.go @@ -0,0 +1,776 @@ +package raknet + +import ( + "bytes" + "context" + "fmt" + "github.com/df-mc/atomic" + "github.com/sandertv/go-raknet/internal/message" + "net" + "sync" + "time" +) + +const ( + // currentProtocol is the current RakNet protocol version. This is Minecraft specific. + currentProtocol byte = 11 + + maxMTUSize = 1400 + maxWindowSize = 2048 +) + +// Conn represents a connection to a specific client. It is not a real connection, as UDP is connectionless, +// but rather a connection emulated using RakNet. +// Methods may be called on Conn from multiple goroutines simultaneously. +type Conn struct { + // rtt is the last measured round-trip time between both ends of the connection. The rtt is measured in nanoseconds. + rtt atomic.Int64 + + closing atomic.Int64 + + conn net.PacketConn + addr net.Addr + limits bool + + once sync.Once + closed, connected chan struct{} + close func() + + mu sync.Mutex + buf *bytes.Buffer + + ackBuf, nackBuf *bytes.Buffer + + pk *packet + + seq, orderIndex, messageIndex uint24 + splitID uint32 + + // mtuSize is the MTU size of the connection. Packets longer than this size must be split into fragments + // for them to arrive at the client without losing bytes. + mtuSize uint16 + + // splits is a map of slices indexed by split IDs. The length of each of the slices is equal to the split + // count, and packets are positioned in that slice indexed by the split index. + splits map[uint16][][]byte + + // win is an ordered queue used to track which datagrams were received and which datagrams + // were missing, so that we can send NACKs to request missing datagrams. + win *datagramWindow + + ackMu sync.Mutex + // ackSlice is a slice containing sequence numbers of datagrams that were received over the last + // second. When ticked, all of these packets are sent in an ACK and the slice is cleared. + ackSlice []uint24 + + // packetQueue is an ordered queue containing packets indexed by their order index. + packetQueue *packetQueue + // packets is a channel containing content of packets that were fully processed. Calling Conn.Read() + // consumes a value from this channel. + packets chan *bytes.Buffer + + // retransmission is a queue filled with packets that were sent with a given datagram sequence number. + retransmission *resendMap + + // readDeadline is a channel that receives a time.Time after a specific time. It is used to listen for + // timeouts in Read after calling SetReadDeadline. + readDeadline <-chan time.Time + + lastActivity atomic.Value[time.Time] +} + +// newConn constructs a new connection specifically dedicated to the address passed. +func newConn(conn net.PacketConn, addr net.Addr, mtuSize uint16) *Conn { + return newConnWithLimits(conn, addr, mtuSize, true) +} + +// newConnWithLimits returns a Conn for the net.Addr passed with a specific mtu size. The limits bool passed specifies +// if the connection should limit the bounds of things such as the size of packets. This is generally recommended for +// connections coming from a client. +func newConnWithLimits(conn net.PacketConn, addr net.Addr, mtuSize uint16, limits bool) *Conn { + if mtuSize < 500 || mtuSize > 1500 { + mtuSize = maxMTUSize + } + c := &Conn{ + addr: addr, + conn: conn, + limits: limits, + mtuSize: mtuSize, + pk: new(packet), + closed: make(chan struct{}), + connected: make(chan struct{}), + packets: make(chan *bytes.Buffer, 512), + splits: make(map[uint16][][]byte), + win: newDatagramWindow(), + packetQueue: newPacketQueue(), + retransmission: newRecoveryQueue(), + buf: bytes.NewBuffer(make([]byte, 0, mtuSize)), + ackBuf: bytes.NewBuffer(make([]byte, 0, 256)), + nackBuf: bytes.NewBuffer(make([]byte, 0, 256)), + lastActivity: *atomic.NewValue(time.Now()), + } + go c.startTicking() + return c +} + +// startTicking makes the connection start ticking, sending ACKs and pings to the other end where necessary +// and checking if the connection should be timed out. +func (conn *Conn) startTicking() { + var ( + interval = time.Second / 10 + ticker = time.NewTicker(interval) + i int64 + acksLeft int + ) + defer ticker.Stop() + for { + select { + case t := <-ticker.C: + i++ + conn.flushACKs() + if i%2 == 0 { + // We send a connected ping to calculate the rtt and let the other side know we haven't + // timed out. + conn.sendPing() + } + if i%3 == 0 { + conn.checkResend(t) + } + if i%5 == 0 { + conn.mu.Lock() + if t.Sub(conn.lastActivity.Load()) > time.Second*5+conn.retransmission.rtt()*2 { + // No activity for too long: Start timeout. + _ = conn.Close() + } + conn.mu.Unlock() + } + if unix := conn.closing.Load(); unix != 0 { + before := acksLeft + conn.mu.Lock() + acksLeft = len(conn.retransmission.unacknowledged) + conn.mu.Unlock() + + if before != 0 && acksLeft == 0 { + _ = conn.Close() + } + + since := time.Since(time.Unix(unix, 0)) + if (acksLeft == 0 && since > time.Second) || since > time.Second*8 { + conn.closeImmediately() + } + } + case <-conn.closed: + return + } + } +} + +// flushACKs flushes all pending datagram acknowledgements. +func (conn *Conn) flushACKs() { + conn.ackMu.Lock() + defer conn.ackMu.Unlock() + + if len(conn.ackSlice) > 0 { + // Write an ACK packet to the connection containing all datagram sequence numbers that we + // received since the last tick. + if err := conn.sendACK(conn.ackSlice...); err != nil { + return + } + conn.ackSlice = conn.ackSlice[:0] + } +} + +// checkResend checks if the connection needs to resend any packets. It sends an ACK for packets it has +// received and sends any packets that have been pending for too long. +func (conn *Conn) checkResend(now time.Time) { + conn.mu.Lock() + defer conn.mu.Unlock() + + var ( + resend []uint24 + rtt = conn.retransmission.rtt() + delay = rtt + rtt/2 + ) + conn.rtt.Store(int64(rtt)) + + for seq, t := range conn.retransmission.unacknowledged { + // These packets have not been acknowledged for too long: We resend them by ourselves, even though no + // NACK has been issued yet. + if now.Sub(t.timestamp) > delay { + resend = append(resend, seq) + } + } + _ = conn.resend(resend) +} + +// Write writes a buffer b over the RakNet connection. The amount of bytes written n is always equal to the +// length of the bytes written if the write was successful. If not, an error is returned and n is 0. +// Write may be called simultaneously from multiple goroutines, but will write one by one. +func (conn *Conn) Write(b []byte) (n int, err error) { + select { + case <-conn.closed: + return 0, conn.wrap(net.ErrClosed, "write") + default: + conn.mu.Lock() + defer conn.mu.Unlock() + n, err := conn.write(b) + return n, conn.wrap(err, "write") + } +} + +// write writes a buffer b over the RakNet connection. The amount of bytes written n is always equal to the +// length of the bytes written if the write was successful. If not, an error is returned and n is 0. +// Write may be called simultaneously from multiple goroutines, but will write one by one. +// Unlike Write, write will not lock. +func (conn *Conn) write(b []byte) (n int, err error) { + fragments := conn.split(b) + orderIndex := conn.orderIndex + conn.orderIndex++ + + splitID := uint16(conn.splitID) + split := len(fragments) > 1 + if split { + conn.splitID++ + } + for splitIndex, content := range fragments { + sequenceNumber := conn.seq + conn.seq++ + messageIndex := conn.messageIndex + conn.messageIndex++ + + conn.buf.WriteByte(bitFlagDatagram | bitFlagNeedsBAndAS) + writeUint24(conn.buf, sequenceNumber) + pk := packetPool.Get().(*packet) + if cap(pk.content) < len(content) { + pk.content = make([]byte, len(content)) + } + // We set the actual slice size to the same size as the content. It might be bigger than the previous + // size, in which case it will grow, which is fine as the underlying array will always be big enough. + pk.content = pk.content[:len(content)] + copy(pk.content, content) + + pk.orderIndex = orderIndex + pk.messageIndex = messageIndex + + pk.split = split + if split { + // If there were more than one fragment, the pk was split, so we need to make sure we set the + // appropriate fields. + pk.splitCount = uint32(len(fragments)) + pk.splitIndex = uint32(splitIndex) + pk.splitID = splitID + } + pk.write(conn.buf) + // We then send the pk to the connection. + if _, err := conn.conn.WriteTo(conn.buf.Bytes(), conn.addr); err != nil { + return 0, net.ErrClosed + } + + // We reset the buffer so that we can re-use it for each fragment created when splitting the pk. + conn.buf.Reset() + + // Finally we add the pk to the recovery queue. + conn.retransmission.add(sequenceNumber, pk) + n += len(content) + } + return +} + +// Read reads from the connection into the byte slice passed. If successful, the amount of bytes read n is +// returned, and the error returned will be nil. +// Read blocks until a packet is received over the connection, or until the session is closed or the read +// times out, in which case an error is returned. +func (conn *Conn) Read(b []byte) (n int, err error) { + select { + case pk := <-conn.packets: + if len(b) < pk.Len() { + err = conn.wrap(errBufferTooSmall, "read") + } + return copy(b, pk.Bytes()), err + case <-conn.closed: + return 0, conn.wrap(net.ErrClosed, "read") + case <-conn.readDeadline: + return 0, conn.wrap(context.DeadlineExceeded, "read") + } +} + +// ReadPacket attempts to read the next packet as a byte slice. +// ReadPacket blocks until a packet is received over the connection, or until the session is closed or the +// read times out, in which case an error is returned. +func (conn *Conn) ReadPacket() (b []byte, err error) { + select { + case packet := <-conn.packets: + return packet.Bytes(), err + case <-conn.closed: + return nil, conn.wrap(net.ErrClosed, "read") + case <-conn.readDeadline: + return nil, conn.wrap(context.DeadlineExceeded, "read") + } +} + +// Close closes the connection. All blocking Read or Write actions are cancelled and will return an error, as +// soon as the closing of the connection is acknowledged by the client. +func (conn *Conn) Close() error { + conn.closing.CAS(0, time.Now().Unix()) + return nil +} + +// closeImmediately sends a Disconnect notification to the other end of the connection and +// closes the underlying UDP connection immediately. +func (conn *Conn) closeImmediately() { + conn.once.Do(func() { + _, _ = conn.Write([]byte{message.IDDisconnectNotification}) + close(conn.closed) + if conn.close != nil { + conn.close() + conn.close = nil + } + }) +} + +// RemoteAddr returns the remote address of the connection, meaning the address this connection leads to. +func (conn *Conn) RemoteAddr() net.Addr { + return conn.addr +} + +// LocalAddr returns the local address of the connection, which is always the same as the listener's. +func (conn *Conn) LocalAddr() net.Addr { + return conn.conn.LocalAddr() +} + +// SetReadDeadline sets the read deadline of the connection. An error is returned only if the time passed is +// before time.Now(). +// Calling SetReadDeadline means the next Read call that exceeds the deadline will fail and return an error. +// Setting the read deadline to the default value of time.Time removes the deadline. +func (conn *Conn) SetReadDeadline(t time.Time) error { + if t.IsZero() { + conn.readDeadline = make(chan time.Time) + return nil + } + if t.Before(time.Now()) { + panic(fmt.Errorf("read deadline cannot be before now")) + } + conn.readDeadline = time.After(time.Until(t)) + return nil +} + +// SetWriteDeadline has no behaviour. It is merely there to satisfy the net.Conn interface. +func (conn *Conn) SetWriteDeadline(time.Time) error { + return nil +} + +// SetDeadline sets the deadline of the connection for both Read and Write. SetDeadline is equivalent to +// calling both SetReadDeadline and SetWriteDeadline. +func (conn *Conn) SetDeadline(t time.Time) error { + return conn.SetReadDeadline(t) +} + +// Latency returns a rolling average of rtt between the sending and the receiving end of the connection. +// The rtt returned is updated continuously and is half the average round trip time (RTT). +func (conn *Conn) Latency() time.Duration { + return time.Duration(conn.rtt.Load() / 2) +} + +// sendPing pings the connection, updating the rtt of the Conn if successful. +func (conn *Conn) sendPing() { + b := bytes.NewBuffer(nil) + (&message.ConnectedPing{ClientTimestamp: timestamp()}).Write(b) + _, _ = conn.Write(b.Bytes()) +} + +// packetPool is a sync.Pool used to pool packets that encapsulate their content. +var packetPool = sync.Pool{ + New: func() interface{} { + return &packet{reliability: reliabilityReliableOrdered} + }, +} + +const ( + // Datagram header + + // Datagram sequence number + + // Packet header + + // Packet content length + + // Packet message index + + // Packet order index + + // Packet order channel + packetAdditionalSize = 1 + 3 + 1 + 2 + 3 + 3 + 1 + // Packet split count + + // Packet split ID + + // Packet split index + splitAdditionalSize = 4 + 2 + 4 +) + +// split splits a content buffer in smaller buffers so that they do not exceed the MTU size that the +// connection holds. +func (conn *Conn) split(b []byte) [][]byte { + maxSize := int(conn.mtuSize-packetAdditionalSize) - 28 + contentLength := len(b) + if contentLength > maxSize { + // If the content size is bigger than the maximum size here, it means the packet will get split. This + // means that the packet will get even bigger because a split packet uses 4 + 2 + 4 more bytes. + maxSize -= splitAdditionalSize + } + fragmentCount := contentLength / maxSize + if contentLength%maxSize != 0 { + // If the content length can't be divided by maxSize perfectly, we need to reserve another fragment + // for the last bit of the packet. + fragmentCount++ + } + fragments := make([][]byte, fragmentCount) + + buf := bytes.NewBuffer(b) + for i := 0; i < fragmentCount; i++ { + // Take a piece out of the content with the size of maxSize. + fragments[i] = buf.Next(maxSize) + } + return fragments +} + +// receive receives a packet from the connection, handling it as appropriate. If not successful, an error is +// returned. +func (conn *Conn) receive(b *bytes.Buffer) error { + headerFlags, err := b.ReadByte() + if err != nil { + return fmt.Errorf("error reading datagram header flags: %v", err) + } + if headerFlags&bitFlagDatagram == 0 { + // Ignore packets that do not have the datagram bitflag. + return nil + } + conn.lastActivity.Store(time.Now()) + switch { + case headerFlags&bitFlagACK != 0: + return conn.handleACK(b) + case headerFlags&bitFlagNACK != 0: + return conn.handleNACK(b) + default: + return conn.receiveDatagram(b) + } +} + +// receiveDatagram handles the receiving of a datagram found in buffer b. If successful, all packets inside +// the datagram are handled. if not, an error is returned. +func (conn *Conn) receiveDatagram(b *bytes.Buffer) error { + seq, err := readUint24(b) + if err != nil { + return fmt.Errorf("error reading datagram sequence number: %v", err) + } + conn.ackMu.Lock() + // Add this sequence number to the received datagrams, so that it is included in an ACK. + conn.ackSlice = append(conn.ackSlice, seq) + conn.ackMu.Unlock() + + if !conn.win.new(seq) { + // Datagram was already received, this might happen if a packet took a long time to arrive, and we already sent + // a NACK for it. This is expected to happen sometimes under normal circumstances, so no reason to return an + // error. + return nil + } + conn.win.add(seq) + if conn.win.shift() == 0 { + // Datagram window couldn't be shifted up, so we're still missing packets. + rtt := time.Duration(conn.rtt.Load()) + if missing := conn.win.missing(rtt + rtt/2); len(missing) > 0 { + if err = conn.sendNACK(missing); err != nil { + return fmt.Errorf("error sending NACK to request datagrams: %v", err) + } + } + } + if conn.win.size() > maxWindowSize && conn.limits { + return fmt.Errorf("datagram receive queue window size is too big (%v-%v)", conn.win.lowest, conn.win.highest) + } + return conn.handleDatagram(b) +} + +// handleDatagram handles the contents of a datagram encoded in a bytes.Buffer. +func (conn *Conn) handleDatagram(b *bytes.Buffer) error { + for b.Len() > 0 { + if err := conn.pk.read(b); err != nil { + return fmt.Errorf("error decoding datagram packet: %v", err) + } + handle := conn.receivePacket + if conn.pk.split { + handle = conn.receiveSplitPacket + } + if err := handle(conn.pk); err != nil { + return fmt.Errorf("error handling packet in datagram: %v", err) + } + } + return nil +} + +// receivePacket handles the receiving of a packet. It puts the packet in the queue and takes out all packets +// that were obtainable after that, and handles them. +func (conn *Conn) receivePacket(packet *packet) error { + if packet.reliability != reliabilityReliableOrdered { + // If it isn't a reliable ordered packet, handle it immediately. + return conn.handlePacket(packet.content) + } + if !conn.packetQueue.put(packet.orderIndex, packet.content) { + // An ordered packet arrived twice. + return nil + } + if conn.packetQueue.WindowSize() > maxWindowSize && conn.limits { + return fmt.Errorf("packet queue window size is too big (%v-%v)", conn.packetQueue.lowest, conn.packetQueue.highest) + } + for _, content := range conn.packetQueue.fetch() { + if err := conn.handlePacket(content); err != nil { + return fmt.Errorf("error handling packet: %v", err) + } + } + return nil +} + +// handlePacket handles a packet serialised in byte slice b. If not successful, an error is returned. If the +// packet was not handled by RakNet, it is sent to the packet channel. +func (conn *Conn) handlePacket(b []byte) error { + buffer := bytes.NewBuffer(b) + id, err := buffer.ReadByte() + if err != nil { + return fmt.Errorf("error reading packet ID: %v", err) + } + + switch id { + case message.IDConnectionRequest: + return conn.handleConnectionRequest(buffer) + case message.IDConnectionRequestAccepted: + return conn.handleConnectionRequestAccepted(buffer) + case message.IDNewIncomingConnection: + select { + case <-conn.connected: + default: + close(conn.connected) + } + case message.IDConnectedPing: + return conn.handleConnectedPing(buffer) + case message.IDConnectedPong: + return conn.handleConnectedPong(buffer) + case message.IDDisconnectNotification: + conn.closeImmediately() + case message.IDDetectLostConnections: + // Let the other end know the connection is still alive. + conn.sendPing() + default: + _ = buffer.UnreadByte() + // Insert the packet contents the packet queue could release in the channel so that Conn.Read() can + // get a hold of them, but always first try to escape if the connection was closed. + select { + case <-conn.closed: + case conn.packets <- buffer: + } + } + return nil +} + +// handleConnectedPing handles a connected ping packet inside of buffer b. An error is returned if the packet +// was invalid. +func (conn *Conn) handleConnectedPing(b *bytes.Buffer) error { + packet := &message.ConnectedPing{} + if err := packet.Read(b); err != nil { + return fmt.Errorf("error reading connected ping: %v", err) + } + b.Reset() + + // Respond with a connected pong that has the ping timestamp found in the connected ping, and our own + // timestamp for the pong timestamp. + (&message.ConnectedPong{ClientTimestamp: packet.ClientTimestamp, ServerTimestamp: timestamp()}).Write(b) + _, err := conn.Write(b.Bytes()) + return err +} + +// handleConnectedPong handles a connected pong packet inside of buffer b. An error is returned if the packet +// was invalid. +func (conn *Conn) handleConnectedPong(b *bytes.Buffer) error { + packet := &message.ConnectedPong{} + if err := packet.Read(b); err != nil { + return fmt.Errorf("error reading connected pong: %v", err) + } + if packet.ClientTimestamp > timestamp() { + return fmt.Errorf("error measuring rtt: ping timestamp is in the future") + } + // We don't actually use the ConnectedPong to measure rtt. It is too unreliable and doesn't give a + // good idea of the connection quality. + return nil +} + +// handleConnectionRequest handles a connection request packet inside of buffer b. An error is returned if the +// packet was invalid. +func (conn *Conn) handleConnectionRequest(b *bytes.Buffer) error { + packet := &message.ConnectionRequest{} + if err := packet.Read(b); err != nil { + return fmt.Errorf("error reading connection request: %v", err) + } + b.Reset() + (&message.ConnectionRequestAccepted{ClientAddress: *conn.addr.(*net.UDPAddr), RequestTimestamp: packet.RequestTimestamp, AcceptedTimestamp: timestamp()}).Write(b) + _, err := conn.Write(b.Bytes()) + return err +} + +// handleConnectionRequestAccepted handles a serialised connection request accepted packet in b, and returns +// an error if not successful. +func (conn *Conn) handleConnectionRequestAccepted(b *bytes.Buffer) error { + packet := &message.ConnectionRequestAccepted{} + _ = packet.Read(b) + b.Reset() + + (&message.NewIncomingConnection{ServerAddress: *conn.addr.(*net.UDPAddr), RequestTimestamp: packet.RequestTimestamp, AcceptedTimestamp: packet.AcceptedTimestamp, SystemAddresses: packet.SystemAddresses}).Write(b) + _, err := conn.Write(b.Bytes()) + + select { + case <-conn.connected: + default: + close(conn.connected) + } + return err +} + +// receiveSplitPacket handles a passed split packet. If it is the last split packet of its sequence, it will +// continue handling the full packet as it otherwise would. +// An error is returned if the packet was not valid. +func (conn *Conn) receiveSplitPacket(p *packet) error { + const maxSplitCount = 256 + if (p.splitCount > maxSplitCount || len(conn.splits) > maxSplitCount) && conn.limits { + return fmt.Errorf("split count %v (%v active) exceeds the maximum %v", p.splitCount, len(conn.splits), maxSplitCount) + } + m, ok := conn.splits[p.splitID] + if !ok { + m = make([][]byte, p.splitCount) + conn.splits[p.splitID] = m + } + if p.splitIndex > uint32(len(m)-1) { + // The split index was either negative or was bigger than the slice size, meaning the packet is + // invalid. + return fmt.Errorf("error handing split packet: split index %v is out of range (0 - %v)", p.splitIndex, len(m)-1) + } + m[p.splitIndex] = p.content + + size := 0 + for _, fragment := range m { + if len(fragment) == 0 { + // We haven't yet received all split fragments, so we cannot add the packets together yet. + return nil + } + // First we calculate the total size required to hold the content of the combined content. + size += len(fragment) + } + + content := make([]byte, 0, size) + for _, fragment := range m { + content = append(content, fragment...) + } + + delete(conn.splits, p.splitID) + + p.content = content + return conn.receivePacket(p) +} + +// sendACK sends an acknowledgement packet containing the packet sequence numbers passed. If not successful, +// an error is returned. +func (conn *Conn) sendACK(packets ...uint24) error { + defer conn.ackBuf.Reset() + return conn.sendAcknowledgement(packets, bitFlagACK, conn.ackBuf) +} + +// sendNACK sends an acknowledgement packet containing the packet sequence numbers passed. If not successful, +// an error is returned. +func (conn *Conn) sendNACK(packets []uint24) error { + defer conn.nackBuf.Reset() + return conn.sendAcknowledgement(packets, bitFlagNACK, conn.nackBuf) +} + +// sendAcknowledgement sends an acknowledgement packet with the packets passed, potentially sending multiple +// if too many packets are passed. The bitflag is added to the header byte. +func (conn *Conn) sendAcknowledgement(packets []uint24, bitflag byte, buf *bytes.Buffer) error { + ack := &acknowledgement{packets: packets} + + for len(ack.packets) != 0 { + buf.WriteByte(bitflag | bitFlagDatagram) + n, err := ack.write(buf, conn.mtuSize) + if err != nil { + panic(fmt.Sprintf("error encoding ACK packet: %v", err)) + } + // We managed to write n packets in the ACK with this MTU size, write the next of the packets in a new ACK. + ack.packets = ack.packets[n:] + if _, err := conn.conn.WriteTo(buf.Bytes(), conn.addr); err != nil { + return fmt.Errorf("error sending ACK packet: %v", err) + } + buf.Reset() + } + return nil +} + +// handleACK handles an acknowledgement packet from the other end of the connection. These mean that a +// datagram was successfully received by the other end. +func (conn *Conn) handleACK(b *bytes.Buffer) error { + conn.mu.Lock() + defer conn.mu.Unlock() + + ack := &acknowledgement{} + if err := ack.read(b); err != nil { + return fmt.Errorf("error reading ACK: %v", err) + } + for _, sequenceNumber := range ack.packets { + // Take out all stored packets from the recovery queue. + p, ok := conn.retransmission.acknowledge(sequenceNumber) + if ok { + // Clear the packet and return it to the pool so that it may be re-used. + p.content = nil + packetPool.Put(p) + } + } + return nil +} + +// handleNACK handles a negative acknowledgment packet from the other end of the connection. These mean that a +// datagram was found missing. +func (conn *Conn) handleNACK(b *bytes.Buffer) error { + conn.mu.Lock() + defer conn.mu.Unlock() + + nack := &acknowledgement{} + if err := nack.read(b); err != nil { + return fmt.Errorf("error reading NACK: %v", err) + } + return conn.resend(nack.packets) +} + +// resend sends all datagrams currently in the recovery queue with the sequence numbers passed. +func (conn *Conn) resend(sequenceNumbers []uint24) (err error) { + for _, sequenceNumber := range sequenceNumbers { + pk, ok := conn.retransmission.retransmit(sequenceNumber) + if !ok { + // We could not resend this datagram. Maybe it was already resent before at the request of the + // client. This is generally expected so we just continue. + continue + } + + // We first write a new datagram header using a new send sequence number that we find. + if err := conn.buf.WriteByte(bitFlagDatagram | bitFlagNeedsBAndAS); err != nil { + return fmt.Errorf("error writing recovered datagram header: %v", err) + } + newSeqNum := conn.seq + conn.seq++ + writeUint24(conn.buf, newSeqNum) + pk.write(conn.buf) + + // We then send the pk to the connection. + if _, err := conn.conn.WriteTo(conn.buf.Bytes(), conn.addr); err != nil { + return fmt.Errorf("error sending pk to addr %v: %v", conn.addr, err) + } + // We then re-add the pk to the recovery queue in case the new one gets lost too, in which case + // we need to resend it again. + conn.retransmission.add(newSeqNum, pk) + conn.buf.Reset() + } + return nil +} + +// requestConnection requests the connection from the server, provided this connection operates as a client. +// An error occurs if the request was not successful. +func (conn *Conn) requestConnection(id int64) error { + b := bytes.NewBuffer(nil) + (&message.ConnectionRequest{ClientGUID: id, RequestTimestamp: timestamp()}).Write(b) + _, err := conn.Write(b.Bytes()) + return err +} diff --git a/datagram_window.go b/datagram_window.go new file mode 100644 index 0000000..dcc0d67 --- /dev/null +++ b/datagram_window.go @@ -0,0 +1,78 @@ +package raknet + +import ( + "time" +) + +// datagramWindow is a queue for incoming datagrams. +type datagramWindow struct { + lowest, highest uint24 + queue map[uint24]time.Time +} + +// newDatagramWindow returns a new initialised datagram window. +func newDatagramWindow() *datagramWindow { + return &datagramWindow{queue: make(map[uint24]time.Time)} +} + +// new checks if the index passed is new to the datagramWindow. +func (win *datagramWindow) new(index uint24) bool { + if index < win.lowest { + return true + } + _, ok := win.queue[index] + return !ok +} + +// add puts an index in the window. +func (win *datagramWindow) add(index uint24) { + if index >= win.highest { + win.highest = index + 1 + } + win.queue[index] = time.Now() +} + +// shift attempts to delete as many indices from the queue as possible, increasing the lowest index if and when +// possible. +func (win *datagramWindow) shift() (n int) { + var index uint24 + for index = win.lowest; index < win.highest; index++ { + if _, ok := win.queue[index]; !ok { + break + } + delete(win.queue, index) + n++ + } + win.lowest = index + return n +} + +// missing returns a slice of all indices in the datagram queue that weren't set using add while within the +// window of lowest and highest index. The queue is shifted after this call. +func (win *datagramWindow) missing(since time.Duration) (indices []uint24) { + var ( + missing = false + ) + for index := int(win.highest) - 1; index >= int(win.lowest); index-- { + i := uint24(index) + t, ok := win.queue[i] + if ok { + if time.Since(t) >= since { + // All packets before this one took too long to arrive, so we mark them as missing. + missing = true + } + continue + } + if missing { + indices = append(indices, i) + win.queue[i] = time.Time{} + } + } + win.shift() + return indices +} + +// size returns the size of the datagramWindow. +func (win *datagramWindow) size() uint24 { + return win.highest - win.lowest +} diff --git a/dial.go b/dial.go new file mode 100644 index 0000000..1575be3 --- /dev/null +++ b/dial.go @@ -0,0 +1,469 @@ +package raknet + +import ( + "bytes" + "context" + "fmt" + "log" + "math/rand" + "net" + "os" + "sync/atomic" + "time" + + "github.com/sandertv/go-raknet/internal/message" +) + +// UpstreamDialer is an interface for anything compatible with net.Dialer. +type UpstreamDialer interface { + Dial(network, address string) (net.Conn, error) +} + +// Ping sends a ping to an address and returns the response obtained. If successful, a non-nil response byte +// slice containing the data is returned. If the ping failed, an error is returned describing the failure. +// Note that the packet sent to the server may be lost due to the nature of UDP. If this is the case, an error +// is returned which implies a timeout occurred. +// Ping will timeout after 5 seconds. +func Ping(address string) (response []byte, err error) { + var d Dialer + return d.Ping(address) +} + +// PingTimeout sends a ping to an address and returns the response obtained. If successful, a non-nil response +// byte slice containing the data is returned. If the ping failed, an error is returned describing the +// failure. +// Note that the packet sent to the server may be lost due to the nature of UDP. If this is the case, an error +// is returned which implies a timeout occurred. +// PingTimeout will time out after the duration passed. +func PingTimeout(address string, timeout time.Duration) ([]byte, error) { + var d Dialer + return d.PingTimeout(address, timeout) +} + +// PingContext sends a ping to an address and returns the response obtained. If successful, a non-nil response +// byte slice containing the data is returned. If the ping failed, an error is returned describing the +// failure. +// Note that the packet sent to the server may be lost due to the nature of UDP. If this is the case, +// PingContext could last indefinitely, hence a timeout should always be attached to the context passed. +// PingContext cancels as soon as the deadline expires. +func PingContext(ctx context.Context, address string) (response []byte, err error) { + var d Dialer + return d.PingContext(ctx, address) +} + +// Dial attempts to dial a RakNet connection to the address passed. The address may be either an IP address +// or a hostname, combined with a port that is separated with ':'. +// Dial will attempt to dial a connection within 10 seconds. If not all packets are received after that, the +// connection will timeout and an error will be returned. +// Dial fills out a Dialer struct with a default error logger. +func Dial(address string) (*Conn, error) { + var d Dialer + return d.Dial(address) +} + +// DialTimeout attempts to dial a RakNet connection to the address passed. The address may be either an IP +// address or a hostname, combined with a port that is separated with ':'. +// DialTimeout will attempt to dial a connection within the timeout duration passed. If not all packets are +// received after that, the connection will timeout and an error will be returned. +func DialTimeout(address string, timeout time.Duration) (*Conn, error) { + var d Dialer + return d.DialTimeout(address, timeout) +} + +// DialContext attempts to dial a RakNet connection to the address passed. The address may be either an IP +// address or a hostname, combined with a port that is separated with ':'. +// DialContext will use the deadline (ctx.Deadline) of the context.Context passed for the maximum amount of +// time that the dialing can take. DialContext will terminate as soon as possible when the context.Context is +// closed. +func DialContext(ctx context.Context, address string) (*Conn, error) { + var d Dialer + return d.DialContext(ctx, address) +} + +// Dialer allows dialing a RakNet connection with specific configuration, such as the protocol version of the +// connection and the logger used. +type Dialer struct { + // ErrorLog is a logger that errors from packet decoding are logged to. It may be set to a logger that + // simply discards the messages. + ErrorLog *log.Logger + + // UpstreamDialer is a dialer that will override the default dialer for opening outgoing connections. + UpstreamDialer UpstreamDialer +} + +// Ping sends a ping to an address and returns the response obtained. If successful, a non-nil response byte +// slice containing the data is returned. If the ping failed, an error is returned describing the failure. +// Note that the packet sent to the server may be lost due to the nature of UDP. If this is the case, an error +// is returned which implies a timeout occurred. +// Ping will timeout after 5 seconds. +func (dialer Dialer) Ping(address string) ([]byte, error) { + ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) + defer cancel() + return dialer.PingContext(ctx, address) +} + +// PingTimeout sends a ping to an address and returns the response obtained. If successful, a non-nil response +// byte slice containing the data is returned. If the ping failed, an error is returned describing the +// failure. +// Note that the packet sent to the server may be lost due to the nature of UDP. If this is the case, an error +// is returned which implies a timeout occurred. +// PingTimeout will time out after the duration passed. +func (dialer Dialer) PingTimeout(address string, timeout time.Duration) ([]byte, error) { + ctx, cancel := context.WithTimeout(context.Background(), timeout) + defer cancel() + return dialer.PingContext(ctx, address) +} + +// PingContext sends a ping to an address and returns the response obtained. If successful, a non-nil response +// byte slice containing the data is returned. If the ping failed, an error is returned describing the +// failure. +// Note that the packet sent to the server may be lost due to the nature of UDP. If this is the case, +// PingContext could last indefinitely, hence a timeout should always be attached to the context passed. +// PingContext cancels as soon as the deadline expires. +func (dialer Dialer) PingContext(ctx context.Context, address string) (response []byte, err error) { + var conn net.Conn + + if dialer.UpstreamDialer == nil { + conn, err = net.Dial("udp", address) + } else { + conn, err = dialer.UpstreamDialer.Dial("udp", address) + } + if err != nil { + return nil, &net.OpError{Op: "ping", Net: "raknet", Source: nil, Addr: nil, Err: err} + } + done := make(chan struct{}) + defer close(done) + go func() { + select { + case <-done: + case <-ctx.Done(): + _ = conn.Close() + } + }() + actual := func(e error) error { + if err := ctx.Err(); err != nil { + return err + } + return e + } + + buffer := bytes.NewBuffer(nil) + (&message.UnconnectedPing{SendTimestamp: timestamp(), ClientGUID: atomic.AddInt64(&dialerID, 1)}).Write(buffer) + if _, err := conn.Write(buffer.Bytes()); err != nil { + return nil, &net.OpError{Op: "ping", Net: "raknet", Source: nil, Addr: nil, Err: actual(err)} + } + buffer.Reset() + + data := make([]byte, 1492) + n, err := conn.Read(data) + if err != nil { + return nil, &net.OpError{Op: "ping", Net: "raknet", Source: nil, Addr: nil, Err: actual(err)} + } + data = data[:n] + + _, _ = buffer.Write(data) + if b, err := buffer.ReadByte(); err != nil || b != message.IDUnconnectedPong { + return nil, &net.OpError{Op: "ping", Net: "raknet", Source: nil, Addr: nil, Err: fmt.Errorf("non-pong packet found: %w", err)} + } + pong := &message.UnconnectedPong{} + if err := pong.Read(buffer); err != nil { + return nil, &net.OpError{Op: "ping", Net: "raknet", Source: nil, Addr: nil, Err: fmt.Errorf("invalid unconnected pong: %w", err)} + } + _ = conn.Close() + return pong.Data, nil +} + +// dialerID is a counter used to produce an ID for the client. +var dialerID = rand.New(rand.NewSource(time.Now().Unix())).Int63() + +// Dial attempts to dial a RakNet connection to the address passed. The address may be either an IP address +// or a hostname, combined with a port that is separated with ':'. +// Dial will attempt to dial a connection within 10 seconds. If not all packets are received after that, the +// connection will timeout and an error will be returned. +func (dialer Dialer) Dial(address string) (*Conn, error) { + ctx, cancel := context.WithTimeout(context.Background(), time.Second*10) + defer cancel() + return dialer.DialContext(ctx, address) +} + +// DialTimeout attempts to dial a RakNet connection to the address passed. The address may be either an IP +// address or a hostname, combined with a port that is separated with ':'. +// DialTimeout will attempt to dial a connection within the timeout duration passed. If not all packets are +// received after that, the connection will timeout and an error will be returned. +func (dialer Dialer) DialTimeout(address string, timeout time.Duration) (*Conn, error) { + ctx, cancel := context.WithTimeout(context.Background(), timeout) + defer cancel() + return dialer.DialContext(ctx, address) +} + +// DialContext attempts to dial a RakNet connection to the address passed. The address may be either an IP +// address or a hostname, combined with a port that is separated with ':'. +// DialContext will use the deadline (ctx.Deadline) of the context.Context passed for the maximum amount of +// time that the dialing can take. DialContext will terminate as soon as possible when the context.Context is +// closed. +func (dialer Dialer) DialContext(ctx context.Context, address string) (*Conn, error) { + var udpConn net.Conn + var err error + + if dialer.UpstreamDialer == nil { + udpConn, err = net.Dial("udp", address) + } else { + udpConn, err = dialer.UpstreamDialer.Dial("udp", address) + } + if err != nil { + return nil, &net.OpError{Op: "dial", Net: "raknet", Source: nil, Addr: nil, Err: err} + } + packetConn := udpConn.(net.PacketConn) + + if deadline, ok := ctx.Deadline(); ok { + _ = packetConn.SetDeadline(deadline) + } + + id := atomic.AddInt64(&dialerID, 1) + if dialer.ErrorLog == nil { + dialer.ErrorLog = log.New(os.Stderr, "", log.LstdFlags) + } + state := &connState{ + conn: udpConn, + remoteAddr: udpConn.RemoteAddr(), + discoveringMTUSize: 1492, + id: id, + } + wrap := func(ctx context.Context, err error) error { + return &net.OpError{Op: "dial", Net: "raknet", Source: nil, Addr: nil, Err: err} + } + + if err := state.discoverMTUSize(ctx); err != nil { + return nil, wrap(ctx, err) + } else if err := state.openConnectionRequest(ctx); err != nil { + return nil, wrap(ctx, err) + } + + conn := newConnWithLimits(&wrappedConn{PacketConn: packetConn}, udpConn.RemoteAddr(), uint16(atomic.LoadUint32(&state.mtuSize)), false) + conn.close = func() { + // We have to make the Conn call this method explicitly because it must not close the connection + // established by the Listener. (This would close the entire listener.) + _ = udpConn.Close() + } + if err := conn.requestConnection(id); err != nil { + return nil, wrap(ctx, err) + } + + go clientListen(conn, udpConn, dialer.ErrorLog) + select { + case <-conn.connected: + _ = packetConn.SetDeadline(time.Time{}) + return conn, nil + case <-ctx.Done(): + _ = conn.Close() + return nil, wrap(ctx, ctx.Err()) + } +} + +// wrappedCon wraps around a 'pre-connected' UDP connection. Its only purpose is to wrap around WriteTo and +// make it call Write instead. +type wrappedConn struct { + net.PacketConn +} + +// WriteTo wraps around net.PacketConn to replace functionality of WriteTo with Write. It is used to be able +// to re-use the functionality in raknet.Conn. +func (conn *wrappedConn) WriteTo(b []byte, _ net.Addr) (n int, err error) { + return conn.PacketConn.(net.Conn).Write(b) +} + +// clientListen makes the RakNet connection passed listen as a client for packets received in the connection +// passed. +func clientListen(rakConn *Conn, conn net.Conn, errorLog *log.Logger) { + // Create a buffer with the maximum size a UDP packet sent over RakNet is allowed to have. We can re-use + // this buffer for each packet. + b := make([]byte, 1500) + buf := bytes.NewBuffer(b[:0]) + for { + n, err := conn.Read(b) + if err != nil { + if ErrConnectionClosed(err) { + // The connection was closed, so we can return from the function without logging the error. + return + } + errorLog.Printf("client: error reading from Conn: %v", err) + return + } + buf.Write(b[:n]) + if err := rakConn.receive(buf); err != nil { + errorLog.Printf("error handling packet: %v\n", err) + } + buf.Reset() + } +} + +// connState represents a state of a connection before the connection is finalised. It holds some data +// collected during the connection. +type connState struct { + conn net.Conn + remoteAddr net.Addr + id int64 + + // mtuSize is the final MTU size found by sending open connection request 1 packets. It is the MTU size + // sent by the server. + mtuSize uint32 + + // discoveringMTUSize is the current MTU size 'discovered'. This MTU size decreases the more the open + // connection request 1 is sent, so that the max packet size can be discovered. + discoveringMTUSize uint16 +} + +// openConnectionRequest sends open connection request 2 packets continuously until it receives an open +// connection reply 2 packet from the server. +func (state *connState) openConnectionRequest(ctx context.Context) (e error) { + ticker := time.NewTicker(time.Second / 2) + defer ticker.Stop() + + stop := make(chan bool) + defer func() { + close(stop) + }() + // Use an intermediate channel to start the ticker immediately. + c := make(chan struct{}, 1) + c <- struct{}{} + go func() { + for { + select { + case <-c: + if err := state.sendOpenConnectionRequest2(uint16(atomic.LoadUint32(&state.mtuSize))); err != nil { + e = err + return + } + case <-ticker.C: + c <- struct{}{} + case <-stop: + return + case <-ctx.Done(): + _ = state.conn.Close() + return + } + } + }() + + b := make([]byte, 1492) + for { + // Start reading in a loop so that we can find open connection reply 2 packets. + n, err := state.conn.Read(b) + if err != nil { + return err + } + buffer := bytes.NewBuffer(b[:n]) + id, err := buffer.ReadByte() + if err != nil { + return fmt.Errorf("error reading packet ID: %v", err) + } + if id != message.IDOpenConnectionReply2 { + // We got a packet, but the packet was not an open connection reply 2 packet. We simply discard it + // and continue reading. + continue + } + reply := &message.OpenConnectionReply2{} + if err := reply.Read(buffer); err != nil { + return fmt.Errorf("error reading open connection reply 2: %v", err) + } + atomic.StoreUint32(&state.mtuSize, uint32(reply.MTUSize)) + return + } +} + +// discoverMTUSize starts discovering an MTU size, the maximum packet size we can send, by sending multiple +// open connection request 1 packets to the server with a decreasing MTU size padding. +func (state *connState) discoverMTUSize(ctx context.Context) (e error) { + ticker := time.NewTicker(time.Second / 2) + defer ticker.Stop() + var staticMTU uint16 + + stop := make(chan struct{}) + defer func() { + close(stop) + }() + // Use an intermediate channel to start the ticker immediately. + c := make(chan struct{}, 1) + c <- struct{}{} + go func() { + for { + select { + case <-c: + mtu := state.discoveringMTUSize + if staticMTU != 0 { + mtu = staticMTU + } + if err := state.sendOpenConnectionRequest1(mtu); err != nil { + e = err + return + } + if staticMTU == 0 { + // Each half second we decrease the MTU size by 40. This means that in 10 seconds, we have an MTU + // size of 692. This is a little above the actual RakNet minimum, but that should not be an issue. + state.discoveringMTUSize -= 40 + } + case <-ticker.C: + c <- struct{}{} + case <-stop: + return + case <-ctx.Done(): + _ = state.conn.Close() + return + } + } + }() + + b := make([]byte, 1492) + for { + // Start reading in a loop so that we can find open connection reply 1 packets. + n, err := state.conn.Read(b) + if err != nil { + return err + } + buffer := bytes.NewBuffer(b[:n]) + id, err := buffer.ReadByte() + if err != nil { + return fmt.Errorf("error reading packet ID: %v", err) + } + switch id { + case message.IDOpenConnectionReply1: + response := &message.OpenConnectionReply1{} + if err := response.Read(buffer); err != nil { + return fmt.Errorf("error reading open connection reply 1: %v", err) + } + if response.ServerPreferredMTUSize < 400 || response.ServerPreferredMTUSize > 1500 { + // This is an awful hack we cooked up to deal with OVH 'DDoS' protection. For some reason they + // send a broken MTU size first. Sending a Request2 followed by a Request1 deals with this. + _ = state.sendOpenConnectionRequest2(response.ServerPreferredMTUSize) + staticMTU = state.discoveringMTUSize + 40 + continue + } + atomic.StoreUint32(&state.mtuSize, uint32(response.ServerPreferredMTUSize)) + return + case message.IDIncompatibleProtocolVersion: + response := &message.IncompatibleProtocolVersion{} + if err := response.Read(buffer); err != nil { + return fmt.Errorf("error reading incompatible protocol version: %v", err) + } + return fmt.Errorf("mismatched protocol: client protocol = %v, server protocol = %v", currentProtocol, response.ServerProtocol) + } + } +} + +// sendOpenConnectionRequest2 sends an open connection request 2 packet to the server. If not successful, an +// error is returned. +func (state *connState) sendOpenConnectionRequest2(mtu uint16) error { + b := bytes.NewBuffer(nil) + (&message.OpenConnectionRequest2{ServerAddress: *state.remoteAddr.(*net.UDPAddr), ClientPreferredMTUSize: mtu, ClientGUID: state.id}).Write(b) + _, err := state.conn.Write(b.Bytes()) + return err +} + +// sendOpenConnectionRequest1 sends an open connection request 1 packet to the server. If not successful, an +// error is returned. +func (state *connState) sendOpenConnectionRequest1(mtu uint16) error { + b := bytes.NewBuffer(nil) + (&message.OpenConnectionRequest1{Protocol: currentProtocol, MaximumSizeNotDropped: mtu}).Write(b) + _, err := state.conn.Write(b.Bytes()) + return err +} diff --git a/dial_test.go b/dial_test.go new file mode 100644 index 0000000..dddadcb --- /dev/null +++ b/dial_test.go @@ -0,0 +1,94 @@ +package raknet_test + +import ( + "net" + "strings" + "testing" + + "github.com/sandertv/go-raknet" +) + +func TestPing(t *testing.T) { + //noinspection SpellCheckingInspection + const ( + addr = "mco.mineplex.com:19132" + prefix = "MCPE" + ) + + data, err := raknet.Ping(addr) + if err != nil { + t.Fatalf("error pinging %v: %v", addr, err) + } + str := string(data) + if !strings.HasPrefix(str, prefix) { + t.Fatalf("ping data should have prefix %v, but got %v", prefix, str) + } +} + +func TestPingWithCustomDialer(t *testing.T) { + //noinspection SpellCheckingInspection + const ( + addr = "mco.mineplex.com:19132" + prefix = "MCPE" + ) + + localDialAddr, err := net.ResolveUDPAddr("udp", "0.0.0.0:55556") + if err != nil { + t.Fatalf("error resolving local dial address: %v", err) + } + + dialer := raknet.Dialer{ + UpstreamDialer: &net.Dialer{ + LocalAddr: localDialAddr, + }, + } + + data, err := dialer.Ping(addr) + if err != nil { + t.Fatalf("error pinging %v: %v", addr, err) + } + str := string(data) + if !strings.HasPrefix(str, prefix) { + t.Fatalf("ping data should have prefix %v, but got %v", prefix, str) + } +} + +func TestDial(t *testing.T) { + //noinspection SpellCheckingInspection + const ( + addr = "mco.mineplex.com:19132" + ) + + conn, err := raknet.Dial(addr) + if err != nil { + t.Fatalf("error connecting to %v: %v", addr, err) + } + if err := conn.Close(); err != nil { + t.Fatalf("error closing connection: %v", err) + } +} + +func TestDialWithCustomDialer(t *testing.T) { + //noinspection SpellCheckingInspection + const ( + addr = "mco.mineplex.com:19132" + ) + + localDialAddr, err := net.ResolveUDPAddr("udp", "0.0.0.0:55555") + if err != nil { + t.Fatalf("error resolving local dial address: %v", err) + } + + dialer := raknet.Dialer{ + UpstreamDialer: &net.Dialer{ + LocalAddr: localDialAddr, + }, + } + conn, err := dialer.Dial(addr) + if err != nil { + t.Fatalf("error connecting to %v: %v", addr, err) + } + if err := conn.Close(); err != nil { + t.Fatalf("error closing connection: %v", err) + } +} diff --git a/err.go b/err.go new file mode 100644 index 0000000..d8e35bf --- /dev/null +++ b/err.go @@ -0,0 +1,36 @@ +package raknet + +import ( + "errors" + "net" + "strings" +) + +var ( + errBufferTooSmall = errors.New("a message sent was larger than the buffer used to receive the message into") + errListenerClosed = errors.New("use of closed listener") +) + +// ErrConnectionClosed checks if the error passed was an error caused by reading from a Conn of which the +// connection was closed. +func ErrConnectionClosed(err error) bool { + if err == nil { + return false + } + return strings.Contains(err.Error(), net.ErrClosed.Error()) +} + +// wrap wraps the error passed into a net.OpError with the op as operation and returns it, or nil if the error +// passed is nil. +func (conn *Conn) wrap(err error, op string) error { + if err == nil { + return nil + } + return &net.OpError{ + Op: op, + Net: "raknet", + Source: conn.LocalAddr(), + Addr: conn.RemoteAddr(), + Err: err, + } +} diff --git a/example_dial_test.go b/example_dial_test.go new file mode 100644 index 0000000..d7bf879 --- /dev/null +++ b/example_dial_test.go @@ -0,0 +1,51 @@ +package raknet_test + +import ( + "fmt" + "github.com/sandertv/go-raknet" +) + +func ExamplePing() { + const address = "mco.mineplex.com:19132" + + // Ping the target address. This will ping with a timeout of 5 seconds. raknet.PingContext and + // raknet.PingTimeout may be used to cancel at any other time. + data, err := raknet.Ping(address) + if err != nil { + panic("error pinging " + address + ": " + err.Error()) + } + str := string(data) + + fmt.Println(str[:4]) + // Output: MCPE +} + +func ExampleDial() { + const address = "mco.mineplex.com:19132" + + // Dial a connection to the target address. This will time out after up to 10 seconds. raknet.DialTimeout + // and raknet.DialContext may be used to cancel at any other time. + conn, err := raknet.Dial(address) + if err != nil { + panic("error connecting to " + address + ": " + err.Error()) + } + + // Read a packet from the connection dialed. + p := make([]byte, 1500) + n, err := conn.Read(p) + if err != nil { + panic("error reading packet from " + address + ": " + err.Error()) + } + p = p[:n] + + // Write a packet to the connection. + data := []byte("Hello World!") + if _, err := conn.Write(data); err != nil { + panic("error writing packet to " + address + ": " + err.Error()) + } + + // Close the connection after you're done with it. + if err := conn.Close(); err != nil { + panic("error closing connection: " + err.Error()) + } +} diff --git a/example_listen_test.go b/example_listen_test.go new file mode 100644 index 0000000..e8384a2 --- /dev/null +++ b/example_listen_test.go @@ -0,0 +1,43 @@ +package raknet_test + +import ( + "github.com/sandertv/go-raknet" +) + +func ExampleListen() { + const address = ":19132" + + // Start listening on an address. + l, err := raknet.Listen(address) + if err != nil { + panic(err) + } + + for { + // Accept a new connection from the Listener. Accept will only return an error if the Listener is + // closed. (So only after a call to Listener.Close.) + conn, err := l.Accept() + if err != nil { + return + } + + // Read a packet from the connection accepted. + p := make([]byte, 1500) + n, err := conn.Read(p) + if err != nil { + panic("error reading packet from " + conn.RemoteAddr().String() + ": " + err.Error()) + } + p = p[:n] + + // Write a packet to the connection. + data := []byte("Hello World!") + if _, err := conn.Write(data); err != nil { + panic("error writing packet to " + conn.RemoteAddr().String() + ": " + err.Error()) + } + + // Close the connection after you're done with it. + if err := conn.Close(); err != nil { + panic("error closing connection: " + err.Error()) + } + } +} diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..eddcaed --- /dev/null +++ b/go.mod @@ -0,0 +1,8 @@ +module github.com/sandertv/go-raknet + +go 1.18 + +require ( + github.com/df-mc/atomic v1.10.0 + golang.org/x/net v0.0.0-20210410081132-afb366fc7cd1 +) diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..143bed7 --- /dev/null +++ b/go.sum @@ -0,0 +1,13 @@ +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/df-mc/atomic v1.10.0 h1:0ZuxBKwR/hxcFGorKiHIp+hY7hgY+XBTzhCYD2NqSEg= +github.com/df-mc/atomic v1.10.0/go.mod h1:Gw9rf+rPIbydMjA329Jn4yjd/O2c/qusw3iNp4tFGSc= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/stretchr/testify v1.3.0 h1:TivCn/peBQ7UY8ooIcPgZFpTNSz0Q2U6UrFlUfqbe0Q= +go.uber.org/atomic v1.9.0 h1:ECmE8Bn/WFTYwEW/bpKD3M8VtR/zQVbavAoalC1PYyE= +golang.org/x/net v0.0.0-20210410081132-afb366fc7cd1 h1:4qWs8cYYH6PoEFy4dfhDFgoMGkwAcETd+MmPdCPMzUc= +golang.org/x/net v0.0.0-20210410081132-afb366fc7cd1/go.mod h1:9tjilg8BloeKEkVJvy7fQ90B1CfIiPueXVOjqfkSzI8= +golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210330210617-4fbd30eecc44/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= +golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= +golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= diff --git a/internal/message/connected_ping.go b/internal/message/connected_ping.go new file mode 100644 index 0000000..a975e40 --- /dev/null +++ b/internal/message/connected_ping.go @@ -0,0 +1,19 @@ +package message + +import ( + "bytes" + "encoding/binary" +) + +type ConnectedPing struct { + ClientTimestamp int64 +} + +func (pk *ConnectedPing) Write(buf *bytes.Buffer) { + _ = binary.Write(buf, binary.BigEndian, IDConnectedPing) + _ = binary.Write(buf, binary.BigEndian, pk.ClientTimestamp) +} + +func (pk *ConnectedPing) Read(buf *bytes.Buffer) error { + return binary.Read(buf, binary.BigEndian, &pk.ClientTimestamp) +} diff --git a/internal/message/connected_pong.go b/internal/message/connected_pong.go new file mode 100644 index 0000000..4ee4fdd --- /dev/null +++ b/internal/message/connected_pong.go @@ -0,0 +1,22 @@ +package message + +import ( + "bytes" + "encoding/binary" +) + +type ConnectedPong struct { + ClientTimestamp int64 + ServerTimestamp int64 +} + +func (pk *ConnectedPong) Write(buf *bytes.Buffer) { + _ = binary.Write(buf, binary.BigEndian, IDConnectedPong) + _ = binary.Write(buf, binary.BigEndian, pk.ClientTimestamp) + _ = binary.Write(buf, binary.BigEndian, pk.ServerTimestamp) +} + +func (pk *ConnectedPong) Read(buf *bytes.Buffer) error { + _ = binary.Read(buf, binary.BigEndian, &pk.ClientTimestamp) + return binary.Read(buf, binary.BigEndian, &pk.ServerTimestamp) +} diff --git a/internal/message/connection_request.go b/internal/message/connection_request.go new file mode 100644 index 0000000..d98b120 --- /dev/null +++ b/internal/message/connection_request.go @@ -0,0 +1,25 @@ +package message + +import ( + "bytes" + "encoding/binary" +) + +type ConnectionRequest struct { + ClientGUID int64 + RequestTimestamp int64 + Secure bool +} + +func (pk *ConnectionRequest) Write(buf *bytes.Buffer) { + _ = binary.Write(buf, binary.BigEndian, IDConnectionRequest) + _ = binary.Write(buf, binary.BigEndian, pk.ClientGUID) + _ = binary.Write(buf, binary.BigEndian, pk.RequestTimestamp) + _ = binary.Write(buf, binary.BigEndian, pk.Secure) +} + +func (pk *ConnectionRequest) Read(buf *bytes.Buffer) error { + _ = binary.Read(buf, binary.BigEndian, &pk.ClientGUID) + _ = binary.Read(buf, binary.BigEndian, &pk.RequestTimestamp) + return binary.Read(buf, binary.BigEndian, &pk.Secure) +} diff --git a/internal/message/connection_request_accepted.go b/internal/message/connection_request_accepted.go new file mode 100644 index 0000000..7ce1569 --- /dev/null +++ b/internal/message/connection_request_accepted.go @@ -0,0 +1,38 @@ +package message + +import ( + "bytes" + "encoding/binary" + "net" +) + +type ConnectionRequestAccepted struct { + ClientAddress net.UDPAddr + SystemAddresses [20]net.UDPAddr + RequestTimestamp int64 + AcceptedTimestamp int64 +} + +func (pk *ConnectionRequestAccepted) Write(buf *bytes.Buffer) { + _ = binary.Write(buf, binary.BigEndian, IDConnectionRequestAccepted) + writeAddr(buf, pk.ClientAddress) + _ = binary.Write(buf, binary.BigEndian, int16(0)) + for _, addr := range pk.SystemAddresses { + writeAddr(buf, addr) + } + _ = binary.Write(buf, binary.BigEndian, pk.RequestTimestamp) + _ = binary.Write(buf, binary.BigEndian, pk.AcceptedTimestamp) +} + +func (pk *ConnectionRequestAccepted) Read(buf *bytes.Buffer) error { + _ = readAddr(buf, &pk.ClientAddress) + buf.Next(2) + for i := 0; i < 20; i++ { + _ = readAddr(buf, &pk.SystemAddresses[i]) + if buf.Len() == 16 { + break + } + } + _ = binary.Read(buf, binary.BigEndian, &pk.RequestTimestamp) + return binary.Read(buf, binary.BigEndian, &pk.AcceptedTimestamp) +} diff --git a/internal/message/incompatible_protocol_version.go b/internal/message/incompatible_protocol_version.go new file mode 100644 index 0000000..8d819fb --- /dev/null +++ b/internal/message/incompatible_protocol_version.go @@ -0,0 +1,25 @@ +package message + +import ( + "bytes" + "encoding/binary" +) + +type IncompatibleProtocolVersion struct { + Magic [16]byte + ServerProtocol byte + ServerGUID int64 +} + +func (pk *IncompatibleProtocolVersion) Write(buf *bytes.Buffer) { + _ = binary.Write(buf, binary.BigEndian, IDIncompatibleProtocolVersion) + _ = binary.Write(buf, binary.BigEndian, pk.ServerProtocol) + _ = binary.Write(buf, binary.BigEndian, unconnectedMessageSequence) + _ = binary.Write(buf, binary.BigEndian, pk.ServerGUID) +} + +func (pk *IncompatibleProtocolVersion) Read(buf *bytes.Buffer) error { + _ = binary.Read(buf, binary.BigEndian, &pk.ServerProtocol) + _ = binary.Read(buf, binary.BigEndian, &pk.Magic) + return binary.Read(buf, binary.BigEndian, &pk.ServerGUID) +} diff --git a/internal/message/new_incoming_connection.go b/internal/message/new_incoming_connection.go new file mode 100644 index 0000000..c8135b7 --- /dev/null +++ b/internal/message/new_incoming_connection.go @@ -0,0 +1,36 @@ +package message + +import ( + "bytes" + "encoding/binary" + "net" +) + +type NewIncomingConnection struct { + ServerAddress net.UDPAddr + SystemAddresses [20]net.UDPAddr + RequestTimestamp int64 + AcceptedTimestamp int64 +} + +func (pk *NewIncomingConnection) Write(buf *bytes.Buffer) { + _ = binary.Write(buf, binary.BigEndian, IDNewIncomingConnection) + writeAddr(buf, pk.ServerAddress) + for _, addr := range pk.SystemAddresses { + writeAddr(buf, addr) + } + _ = binary.Write(buf, binary.BigEndian, pk.RequestTimestamp) + _ = binary.Write(buf, binary.BigEndian, pk.AcceptedTimestamp) +} + +func (pk *NewIncomingConnection) Read(buf *bytes.Buffer) error { + _ = readAddr(buf, &pk.ServerAddress) + for i := 0; i < 20; i++ { + _ = readAddr(buf, &pk.SystemAddresses[i]) + if buf.Len() == 16 { + break + } + } + _ = binary.Read(buf, binary.BigEndian, &pk.RequestTimestamp) + return binary.Read(buf, binary.BigEndian, &pk.AcceptedTimestamp) +} diff --git a/internal/message/open_connection_reply_1.go b/internal/message/open_connection_reply_1.go new file mode 100644 index 0000000..f489251 --- /dev/null +++ b/internal/message/open_connection_reply_1.go @@ -0,0 +1,28 @@ +package message + +import ( + "bytes" + "encoding/binary" +) + +type OpenConnectionReply1 struct { + Magic [16]byte + ServerGUID int64 + Secure bool + ServerPreferredMTUSize uint16 +} + +func (pk *OpenConnectionReply1) Write(buf *bytes.Buffer) { + _ = binary.Write(buf, binary.BigEndian, IDOpenConnectionReply1) + _ = binary.Write(buf, binary.BigEndian, unconnectedMessageSequence) + _ = binary.Write(buf, binary.BigEndian, pk.ServerGUID) + _ = binary.Write(buf, binary.BigEndian, pk.Secure) + _ = binary.Write(buf, binary.BigEndian, pk.ServerPreferredMTUSize) +} + +func (pk *OpenConnectionReply1) Read(buf *bytes.Buffer) error { + _ = binary.Read(buf, binary.BigEndian, &pk.Magic) + _ = binary.Read(buf, binary.BigEndian, &pk.ServerGUID) + _ = binary.Read(buf, binary.BigEndian, &pk.Secure) + return binary.Read(buf, binary.BigEndian, &pk.ServerPreferredMTUSize) +} diff --git a/internal/message/open_connection_reply_2.go b/internal/message/open_connection_reply_2.go new file mode 100644 index 0000000..83b0db5 --- /dev/null +++ b/internal/message/open_connection_reply_2.go @@ -0,0 +1,32 @@ +package message + +import ( + "bytes" + "encoding/binary" + "net" +) + +type OpenConnectionReply2 struct { + Magic [16]byte + ServerGUID int64 + ClientAddress net.UDPAddr + MTUSize uint16 + Secure bool +} + +func (pk *OpenConnectionReply2) Write(buf *bytes.Buffer) { + _ = binary.Write(buf, binary.BigEndian, IDOpenConnectionReply2) + _ = binary.Write(buf, binary.BigEndian, unconnectedMessageSequence) + _ = binary.Write(buf, binary.BigEndian, pk.ServerGUID) + writeAddr(buf, pk.ClientAddress) + _ = binary.Write(buf, binary.BigEndian, pk.MTUSize) + _ = binary.Write(buf, binary.BigEndian, pk.Secure) +} + +func (pk *OpenConnectionReply2) Read(buf *bytes.Buffer) error { + _ = binary.Read(buf, binary.BigEndian, &pk.Magic) + _ = binary.Read(buf, binary.BigEndian, &pk.ServerGUID) + _ = readAddr(buf, &pk.ClientAddress) + _ = binary.Read(buf, binary.BigEndian, &pk.MTUSize) + return binary.Read(buf, binary.BigEndian, &pk.Secure) +} diff --git a/internal/message/open_connection_request_1.go b/internal/message/open_connection_request_1.go new file mode 100644 index 0000000..77d6dae --- /dev/null +++ b/internal/message/open_connection_request_1.go @@ -0,0 +1,25 @@ +package message + +import ( + "bytes" + "encoding/binary" +) + +type OpenConnectionRequest1 struct { + Magic [16]byte + Protocol byte + MaximumSizeNotDropped uint16 +} + +func (pk *OpenConnectionRequest1) Write(buf *bytes.Buffer) { + _ = binary.Write(buf, binary.BigEndian, IDOpenConnectionRequest1) + _ = binary.Write(buf, binary.BigEndian, unconnectedMessageSequence) + _ = binary.Write(buf, binary.BigEndian, pk.Protocol) + _, _ = buf.Write(make([]byte, pk.MaximumSizeNotDropped-uint16(buf.Len()+28))) +} + +func (pk *OpenConnectionRequest1) Read(buf *bytes.Buffer) error { + pk.MaximumSizeNotDropped = uint16(buf.Len()+1) + 28 + _ = binary.Read(buf, binary.BigEndian, &pk.Magic) + return binary.Read(buf, binary.BigEndian, &pk.Protocol) +} diff --git a/internal/message/open_connection_request_2.go b/internal/message/open_connection_request_2.go new file mode 100644 index 0000000..0402caa --- /dev/null +++ b/internal/message/open_connection_request_2.go @@ -0,0 +1,29 @@ +package message + +import ( + "bytes" + "encoding/binary" + "net" +) + +type OpenConnectionRequest2 struct { + Magic [16]byte + ServerAddress net.UDPAddr + ClientPreferredMTUSize uint16 + ClientGUID int64 +} + +func (pk *OpenConnectionRequest2) Write(buf *bytes.Buffer) { + _ = binary.Write(buf, binary.BigEndian, IDOpenConnectionRequest2) + _ = binary.Write(buf, binary.BigEndian, unconnectedMessageSequence) + writeAddr(buf, pk.ServerAddress) + _ = binary.Write(buf, binary.BigEndian, pk.ClientPreferredMTUSize) + _ = binary.Write(buf, binary.BigEndian, pk.ClientGUID) +} + +func (pk *OpenConnectionRequest2) Read(buf *bytes.Buffer) error { + _ = binary.Read(buf, binary.BigEndian, &pk.Magic) + _ = readAddr(buf, &pk.ServerAddress) + _ = binary.Read(buf, binary.BigEndian, &pk.ClientPreferredMTUSize) + return binary.Read(buf, binary.BigEndian, &pk.ClientGUID) +} diff --git a/internal/message/packet.go b/internal/message/packet.go new file mode 100644 index 0000000..1520516 --- /dev/null +++ b/internal/message/packet.go @@ -0,0 +1,95 @@ +package message + +import ( + "bytes" + "encoding/binary" + "fmt" + "net" +) + +const ( + IDConnectedPing byte = 0x00 + IDUnconnectedPing byte = 0x01 + IDUnconnectedPingOpenConnections byte = 0x02 + IDConnectedPong byte = 0x03 + IDDetectLostConnections byte = 0x04 + IDOpenConnectionRequest1 byte = 0x05 + IDOpenConnectionReply1 byte = 0x06 + IDOpenConnectionRequest2 byte = 0x07 + IDOpenConnectionReply2 byte = 0x08 + IDConnectionRequest byte = 0x09 + IDConnectionRequestAccepted byte = 0x10 + IDNewIncomingConnection byte = 0x13 + IDDisconnectNotification byte = 0x15 + + IDIncompatibleProtocolVersion byte = 0x19 + + IDUnconnectedPong byte = 0x1c +) + +// unconnectedMessageSequence is a sequence of bytes which is found in every unconnected message sent in +// RakNet. +var unconnectedMessageSequence = [16]byte{0x00, 0xff, 0xff, 0x00, 0xfe, 0xfe, 0xfe, 0xfe, 0xfd, 0xfd, 0xfd, 0xfd, 0x12, 0x34, 0x56, 0x78} + +// writeAddr writes a UDP address to the buffer passed. +func writeAddr(buffer *bytes.Buffer, addr net.UDPAddr) { + var ver byte = 6 + if addr.IP.To4() != nil { + ver = 4 + } + if addr.IP == nil { + addr.IP = make([]byte, 16) + } + _ = buffer.WriteByte(ver) + if ver == 4 { + ipBytes := addr.IP.To4() + + _ = buffer.WriteByte(^ipBytes[0]) + _ = buffer.WriteByte(^ipBytes[1]) + _ = buffer.WriteByte(^ipBytes[2]) + _ = buffer.WriteByte(^ipBytes[3]) + _ = binary.Write(buffer, binary.BigEndian, uint16(addr.Port)) + } else { + _ = binary.Write(buffer, binary.LittleEndian, int16(23)) // syscall.AF_INET6 on Windows. + _ = binary.Write(buffer, binary.BigEndian, uint16(addr.Port)) + // The IPv6 address is enclosed in two 0 integers. + _ = binary.Write(buffer, binary.BigEndian, int32(0)) + _, _ = buffer.Write(addr.IP.To16()) + _ = binary.Write(buffer, binary.BigEndian, int32(0)) + } +} + +// readAddr decodes a RakNet address from the buffer passed. If not successful, an error is returned. +func readAddr(buffer *bytes.Buffer, addr *net.UDPAddr) error { + ver, err := buffer.ReadByte() + if err != nil { + return err + } + if ver == 4 { + ipBytes := make([]byte, 4) + if _, err := buffer.Read(ipBytes); err != nil { + return fmt.Errorf("error reading raknet address ipv4 bytes: %v", err) + } + // Construct an IPv4 out of the 4 bytes we just read. + addr.IP = net.IPv4((-ipBytes[0]-1)&0xff, (-ipBytes[1]-1)&0xff, (-ipBytes[2]-1)&0xff, (-ipBytes[3]-1)&0xff) + var port uint16 + if err := binary.Read(buffer, binary.BigEndian, &port); err != nil { + return fmt.Errorf("error reading raknet address port: %v", err) + } + addr.Port = int(port) + } else { + buffer.Next(2) + var port uint16 + if err := binary.Read(buffer, binary.LittleEndian, &port); err != nil { + return fmt.Errorf("error reading raknet address port: %v", err) + } + addr.Port = int(port) + buffer.Next(4) + addr.IP = make([]byte, 16) + if _, err := buffer.Read(addr.IP); err != nil { + return fmt.Errorf("error reading raknet address ipv6 bytes: %v", err) + } + buffer.Next(4) + } + return nil +} diff --git a/internal/message/unconnected_ping.go b/internal/message/unconnected_ping.go new file mode 100644 index 0000000..aab80c2 --- /dev/null +++ b/internal/message/unconnected_ping.go @@ -0,0 +1,25 @@ +package message + +import ( + "bytes" + "encoding/binary" +) + +type UnconnectedPing struct { + Magic [16]byte + SendTimestamp int64 + ClientGUID int64 +} + +func (pk *UnconnectedPing) Write(buf *bytes.Buffer) { + _ = binary.Write(buf, binary.BigEndian, IDUnconnectedPing) + _ = binary.Write(buf, binary.BigEndian, pk.SendTimestamp) + _ = binary.Write(buf, binary.BigEndian, unconnectedMessageSequence) + _ = binary.Write(buf, binary.BigEndian, pk.ClientGUID) +} + +func (pk *UnconnectedPing) Read(buf *bytes.Buffer) error { + _ = binary.Read(buf, binary.BigEndian, &pk.SendTimestamp) + _ = binary.Read(buf, binary.BigEndian, &pk.Magic) + return binary.Read(buf, binary.BigEndian, &pk.ClientGUID) +} diff --git a/internal/message/unconnected_pong.go b/internal/message/unconnected_pong.go new file mode 100644 index 0000000..90e79b8 --- /dev/null +++ b/internal/message/unconnected_pong.go @@ -0,0 +1,33 @@ +package message + +import ( + "bytes" + "encoding/binary" +) + +type UnconnectedPong struct { + Magic [16]byte + SendTimestamp int64 + ServerGUID int64 + Data []byte +} + +func (pk *UnconnectedPong) Write(buf *bytes.Buffer) { + _ = binary.Write(buf, binary.BigEndian, IDUnconnectedPong) + _ = binary.Write(buf, binary.BigEndian, pk.SendTimestamp) + _ = binary.Write(buf, binary.BigEndian, pk.ServerGUID) + _ = binary.Write(buf, binary.BigEndian, unconnectedMessageSequence) + _ = binary.Write(buf, binary.BigEndian, int16(len(pk.Data))) + _ = binary.Write(buf, binary.BigEndian, pk.Data) +} + +func (pk *UnconnectedPong) Read(buf *bytes.Buffer) error { + var l int16 + _ = binary.Read(buf, binary.BigEndian, &pk.SendTimestamp) + _ = binary.Read(buf, binary.BigEndian, &pk.ServerGUID) + _ = binary.Read(buf, binary.BigEndian, &pk.Magic) + _ = binary.Read(buf, binary.BigEndian, &l) + pk.Data = make([]byte, l) + _, err := buf.Read(pk.Data) + return err +} diff --git a/listener.go b/listener.go new file mode 100644 index 0000000..16f8e89 --- /dev/null +++ b/listener.go @@ -0,0 +1,296 @@ +package raknet + +import ( + "bytes" + "fmt" + "github.com/df-mc/atomic" + "github.com/sandertv/go-raknet/internal/message" + "log" + "math" + "math/rand" + "net" + "os" + "sync" + "time" +) + +// UpstreamPacketListener allows for a custom PacketListener implementation. +type UpstreamPacketListener interface { + ListenPacket(network, address string) (net.PacketConn, error) +} + +// ListenConfig may be used to pass additional configuration to a Listener. +type ListenConfig struct { + // ErrorLog is a logger that errors from packet decoding are logged to. It may be set to a logger that + // simply discards the messages. + ErrorLog *log.Logger + + // UpstreamPacketListener adds an abstraction for net.ListenPacket. + UpstreamPacketListener UpstreamPacketListener +} + +// Listener implements a RakNet connection listener. It follows the same methods as those implemented by the +// TCPListener in the net package. +// Listener implements the net.Listener interface. +type Listener struct { + once sync.Once + closed chan struct{} + + // log is a logger that errors from packet decoding are logged to. It may be set to a logger that + // simply discards the messages. + log *log.Logger + + conn net.PacketConn + // incoming is a channel of incoming connections. Connections that end up in here will also end up in + // the connections map. + incoming chan *Conn + + // connections is a map of currently active connections, indexed by their address. + connections sync.Map + + // id is a random server ID generated upon starting listening. It is used several times throughout the + // connection sequence of RakNet. + id int64 + + // pongData is a byte slice of data that is sent in an unconnected pong packet each time the client sends + // and unconnected ping to the server. + pongData atomic.Value[[]byte] +} + +// listenerID holds the next ID to use for a Listener. +var listenerID = atomic.NewInt64(rand.New(rand.NewSource(time.Now().Unix())).Int63()) + +// Listen listens on the address passed and returns a listener that may be used to accept connections. If not +// successful, an error is returned. +// The address follows the same rules as those defined in the net.TCPListen() function. +// Specific features of the listener may be modified once it is returned, such as the used log and/or the +// accepted protocol. +func (l ListenConfig) Listen(address string) (*Listener, error) { + var conn net.PacketConn + var err error + + if l.UpstreamPacketListener == nil { + conn, err = net.ListenPacket("udp", address) + } else { + conn, err = l.UpstreamPacketListener.ListenPacket("udp", address) + } + if err != nil { + return nil, &net.OpError{Op: "listen", Net: "raknet", Source: nil, Addr: nil, Err: err} + } + listener := &Listener{ + conn: conn, + incoming: make(chan *Conn), + closed: make(chan struct{}), + log: log.New(os.Stderr, "", log.LstdFlags), + id: listenerID.Inc(), + } + if l.ErrorLog != nil { + listener.log = l.ErrorLog + } + + go listener.listen() + return listener, nil +} + +// Listen listens on the address passed and returns a listener that may be used to accept connections. If not +// successful, an error is returned. +// The address follows the same rules as those defined in the net.TCPListen() function. +// Specific features of the listener may be modified once it is returned, such as the used log and/or the +// accepted protocol. +func Listen(address string) (*Listener, error) { + var lc ListenConfig + return lc.Listen(address) +} + +// Accept blocks until a connection can be accepted by the listener. If successful, Accept returns a +// connection that is ready to send and receive data. If not successful, a nil listener is returned and an +// error describing the problem. +func (listener *Listener) Accept() (net.Conn, error) { + conn, ok := <-listener.incoming + if !ok { + return nil, &net.OpError{Op: "accept", Net: "raknet", Source: nil, Addr: nil, Err: errListenerClosed} + } + return conn, nil +} + +// Addr returns the address the Listener is bound to and listening for connections on. +func (listener *Listener) Addr() net.Addr { + return listener.conn.LocalAddr() +} + +// Close closes the listener so that it may be cleaned up. It makes sure the goroutine handling incoming +// packets is able to be freed. +func (listener *Listener) Close() error { + var err error + listener.once.Do(func() { + close(listener.closed) + err = listener.conn.Close() + }) + return err +} + +// PongData sets the pong data that is used to respond with when a client sends a ping. It usually holds game +// specific data that is used to display in a server list. +// If a data slice is set with a size bigger than math.MaxInt16, the function panics. +func (listener *Listener) PongData(data []byte) { + if len(data) > math.MaxInt16 { + panic(fmt.Sprintf("error setting pong data: pong data must not be longer than %v", math.MaxInt16)) + } + listener.pongData.Store(data) +} + +// ID returns the unique ID of the listener. This ID is usually used by a client to identify a specific +// server during a single session. +func (listener *Listener) ID() int64 { + return listener.id +} + +// listen continuously reads from the listener's UDP connection, until closed has a value in it. +func (listener *Listener) listen() { + // Create a buffer with the maximum size a UDP packet sent over RakNet is allowed to have. We can re-use + // this buffer for each packet. + b := make([]byte, 1500) + buf := bytes.NewBuffer(b[:0]) + for { + n, addr, err := listener.conn.ReadFrom(b) + if err != nil { + close(listener.incoming) + return + } + _, _ = buf.Write(b[:n]) + + // Technically we should not re-use the same byte slice after its ownership has been taken by the + // buffer, but we can do this anyway because we copy the data later. + if err := listener.handle(buf, addr); err != nil { + listener.log.Printf("listener: error handling packet (addr = %v): %v\n", addr, err) + } + buf.Reset() + } +} + +// handle handles an incoming packet in buffer b from the address passed. If not successful, an error is +// returned describing the issue. +func (listener *Listener) handle(b *bytes.Buffer, addr net.Addr) error { + value, found := listener.connections.Load(addr.String()) + if !found { + // If there was no session yet, it means the packet is an offline message. It is not contained in a + // datagram. + packetID, err := b.ReadByte() + if err != nil { + return fmt.Errorf("error reading packet ID byte: %v", err) + } + switch packetID { + case message.IDUnconnectedPing, message.IDUnconnectedPingOpenConnections: + return listener.handleUnconnectedPing(b, addr) + case message.IDOpenConnectionRequest1: + return listener.handleOpenConnectionRequest1(b, addr) + case message.IDOpenConnectionRequest2: + return listener.handleOpenConnectionRequest2(b, addr) + default: + // In some cases, the client will keep trying to send datagrams while it has already timed out. In + // this case, we should not print an error. + if packetID&bitFlagDatagram == 0 { + return fmt.Errorf("unknown packet received (%x): %x", packetID, b.Bytes()) + } + } + return nil + } + conn := value.(*Conn) + select { + case <-conn.closed: + // Connection was closed already. + return nil + default: + err := conn.receive(b) + if err != nil { + conn.closeImmediately() + } + return err + } +} + +// handleOpenConnectionRequest2 handles an open connection request 2 packet stored in buffer b, coming from +// an address addr. +func (listener *Listener) handleOpenConnectionRequest2(b *bytes.Buffer, addr net.Addr) error { + packet := &message.OpenConnectionRequest2{} + if err := packet.Read(b); err != nil { + return fmt.Errorf("error reading open connection request 2: %v", err) + } + b.Reset() + + mtuSize := packet.ClientPreferredMTUSize + if mtuSize > maxMTUSize { + mtuSize = maxMTUSize + } + + (&message.OpenConnectionReply2{ServerGUID: listener.id, ClientAddress: *addr.(*net.UDPAddr), MTUSize: mtuSize}).Write(b) + if _, err := listener.conn.WriteTo(b.Bytes(), addr); err != nil { + return fmt.Errorf("error sending open connection reply 2: %v", err) + } + + conn := newConn(listener.conn, addr, packet.ClientPreferredMTUSize) + conn.close = func() { + // Make sure to remove the connection from the Listener once the Conn is closed. + listener.connections.Delete(addr.String()) + } + listener.connections.Store(addr.String(), conn) + + go func() { + t := time.NewTimer(time.Second * 10) + defer t.Stop() + select { + case <-conn.connected: + // Add the connection to the incoming channel so that a caller of Accept() can receive it. + listener.incoming <- conn + case <-listener.closed: + _ = conn.Close() + case <-t.C: + // It took too long to complete this connection. We closed it and go back to accepting. + _ = conn.Close() + } + }() + + return nil +} + +// handleOpenConnectionRequest1 handles an open connection request 1 packet stored in buffer b, coming from +// an address addr. +func (listener *Listener) handleOpenConnectionRequest1(b *bytes.Buffer, addr net.Addr) error { + packet := &message.OpenConnectionRequest1{} + if err := packet.Read(b); err != nil { + return fmt.Errorf("error reading open connection request 1: %v", err) + } + b.Reset() + mtuSize := packet.MaximumSizeNotDropped + if mtuSize > maxMTUSize { + mtuSize = maxMTUSize + } + + if packet.Protocol != currentProtocol { + (&message.IncompatibleProtocolVersion{ServerGUID: listener.id, ServerProtocol: currentProtocol}).Write(b) + _, _ = listener.conn.WriteTo(b.Bytes(), addr) + return fmt.Errorf("error handling open connection request 1: incompatible protocol version %v (listener protocol = %v)", packet.Protocol, currentProtocol) + } + + (&message.OpenConnectionReply1{ServerGUID: listener.id, Secure: false, ServerPreferredMTUSize: mtuSize}).Write(b) + _, err := listener.conn.WriteTo(b.Bytes(), addr) + return err +} + +// handleUnconnectedPing handles an unconnected ping packet stored in buffer b, coming from an address addr. +func (listener *Listener) handleUnconnectedPing(b *bytes.Buffer, addr net.Addr) error { + pk := &message.UnconnectedPing{} + if err := pk.Read(b); err != nil { + return fmt.Errorf("error reading unconnected ping: %v", err) + } + b.Reset() + + (&message.UnconnectedPong{ServerGUID: listener.id, SendTimestamp: pk.SendTimestamp, Data: listener.pongData.Load()}).Write(b) + _, err := listener.conn.WriteTo(b.Bytes(), addr) + return err +} + +// timestamp returns a timestamp in milliseconds. +func timestamp() int64 { + return time.Now().UnixNano() / int64(time.Second) +} diff --git a/listener_test.go b/listener_test.go new file mode 100644 index 0000000..4e2447f --- /dev/null +++ b/listener_test.go @@ -0,0 +1,36 @@ +package raknet_test + +import ( + "fmt" + "github.com/sandertv/go-raknet" + "testing" + "time" +) + +func TestListen(t *testing.T) { + l, err := raknet.Listen(":19132") + if err != nil { + panic(err) + } + go func() { + _, _ = raknet.Dial("127.0.0.1:19132") + }() + c := make(chan error) + go accept(l, c) + + select { + case err := <-c: + if err != nil { + t.Error(err) + } + case <-time.After(time.Second * 3): + t.Errorf("accepting connection took longer than 3 seconds") + } +} + +func accept(l *raknet.Listener, c chan error) { + if _, err := l.Accept(); err != nil { + c <- fmt.Errorf("error accepting connection: %v", err) + } + c <- nil +} diff --git a/packet.go b/packet.go new file mode 100644 index 0000000..215b01c --- /dev/null +++ b/packet.go @@ -0,0 +1,315 @@ +package raknet + +import ( + "bytes" + "encoding/binary" + "fmt" + "sort" +) + +const ( + // bitFlagDatagram is set for every valid datagram. It is used to identify packets that are datagrams. + bitFlagDatagram = 0x80 + // bitFlagACK is set for every ACK packet. + bitFlagACK = 0x40 + // bitFlagNACK is set for every NACK packet. + bitFlagNACK = 0x20 + // bitFlagNeedsBAndAS is set for every datagram with packet data, but is not + // actually used. + bitFlagNeedsBAndAS = 0x04 +) + +//noinspection GoUnusedConst +const ( + // reliabilityUnreliable means that the packet sent could arrive out of order, be duplicated, or just not + // arrive at all. It is usually used for high frequency packets of which the order does not matter. + //lint:ignore U1000 While this constant is unused, it is here for the sake of having all reliabilities. + reliabilityUnreliable byte = iota + // reliabilityUnreliableSequenced means that the packet sent could be duplicated or not arrive at all, but + // ensures that it is always handled in the right order. + reliabilityUnreliableSequenced + // reliabilityReliable means that the packet sent could not arrive, or arrive out of order, but ensures + // that the packet is not duplicated. + reliabilityReliable + // reliabilityReliableOrdered means that every packet sent arrives, arrives in the right order and is not + // duplicated. + reliabilityReliableOrdered + // reliabilityReliableSequenced means that the packet sent could not arrive, but ensures that the packet + // will be in the right order and not be duplicated. + reliabilityReliableSequenced + + // splitFlag is set in the header if the packet was split. If so, the encapsulation contains additional + // data about the fragment. + splitFlag = 0x10 +) + +// packet is an encapsulation around every packet sent after the connection is established. It is +type packet struct { + reliability byte + + content []byte + messageIndex uint24 + sequenceIndex uint24 + orderIndex uint24 + + split bool + splitCount uint32 + splitIndex uint32 + splitID uint16 +} + +// write writes the packet and its content to the buffer passed. +func (packet *packet) write(b *bytes.Buffer) { + header := packet.reliability << 5 + if packet.split { + header |= splitFlag + } + b.WriteByte(header) + _ = binary.Write(b, binary.BigEndian, uint16(len(packet.content))<<3) + if packet.reliable() { + writeUint24(b, packet.messageIndex) + } + if packet.sequenced() { + writeUint24(b, packet.sequenceIndex) + } + if packet.sequencedOrOrdered() { + writeUint24(b, packet.orderIndex) + // Order channel, we don't care about this. + b.WriteByte(0) + } + if packet.split { + _ = binary.Write(b, binary.BigEndian, packet.splitCount) + _ = binary.Write(b, binary.BigEndian, packet.splitID) + _ = binary.Write(b, binary.BigEndian, packet.splitIndex) + } + b.Write(packet.content) +} + +// read reads a packet and its content from the buffer passed. +func (packet *packet) read(b *bytes.Buffer) error { + header, err := b.ReadByte() + if err != nil { + return fmt.Errorf("error reading packet header: %v", err) + } + packet.split = (header & splitFlag) != 0 + packet.reliability = (header & 224) >> 5 + var packetLength uint16 + if err := binary.Read(b, binary.BigEndian, &packetLength); err != nil { + return fmt.Errorf("error reading packet length: %v", err) + } + packetLength >>= 3 + if packetLength == 0 { + return fmt.Errorf("invalid packet length: cannot be 0") + } + + if packet.reliable() { + packet.messageIndex, err = readUint24(b) + if err != nil { + return fmt.Errorf("error reading packet message index: %v", err) + } + } + + if packet.sequenced() { + packet.sequenceIndex, err = readUint24(b) + if err != nil { + return fmt.Errorf("error reading packet sequence index: %v", err) + } + } + + if packet.sequencedOrOrdered() { + packet.orderIndex, err = readUint24(b) + if err != nil { + return fmt.Errorf("error reading packet order index: %v", err) + } + // Order channel (byte), we don't care about this. + b.Next(1) + } + + if packet.split { + if err := binary.Read(b, binary.BigEndian, &packet.splitCount); err != nil { + return fmt.Errorf("error reading packet split count: %v", err) + } + if err := binary.Read(b, binary.BigEndian, &packet.splitID); err != nil { + return fmt.Errorf("error reading packet split ID: %v", err) + } + if err := binary.Read(b, binary.BigEndian, &packet.splitIndex); err != nil { + return fmt.Errorf("error reading packet split index: %v", err) + } + } + + packet.content = make([]byte, packetLength) + if n, err := b.Read(packet.content); err != nil || n != int(packetLength) { + return fmt.Errorf("not enough data in packet: %v bytes read but need %v", n, packetLength) + } + return nil +} + +func (packet *packet) reliable() bool { + switch packet.reliability { + case reliabilityReliable, + reliabilityReliableOrdered, + reliabilityReliableSequenced: + return true + } + return false +} + +func (packet *packet) sequencedOrOrdered() bool { + switch packet.reliability { + case reliabilityUnreliableSequenced, + reliabilityReliableOrdered, + reliabilityReliableSequenced: + return true + } + return false +} + +func (packet *packet) sequenced() bool { + switch packet.reliability { + case reliabilityUnreliableSequenced, + reliabilityReliableSequenced: + return true + } + return false +} + +const ( + // packetRange indicates a range of packets, followed by the first and the last packet in the range. + packetRange = iota + // packetSingle indicates a single packet, followed by its sequence number. + packetSingle +) + +// acknowledgement is an acknowledgement packet that may either be an ACK or a NACK, depending on the purpose +// that it is sent with. +type acknowledgement struct { + packets []uint24 +} + +// write encodes an acknowledgement packet and returns an error if not successful. +func (ack *acknowledgement) write(b *bytes.Buffer, mtu uint16) (n int, err error) { + packets := ack.packets + if len(packets) == 0 { + return 0, binary.Write(b, binary.BigEndian, int16(0)) + } + buffer := bytes.NewBuffer(nil) + // Sort packets before encoding to ensure packets are encoded correctly. + sort.Slice(packets, func(i, j int) bool { + return packets[i] < packets[j] + }) + + var firstPacketInRange uint24 + var lastPacketInRange uint24 + var recordCount int16 + + for index, packet := range packets { + if buffer.Len() >= int(mtu-10) { + // We must make sure the final packet length doesn't exceed the MTU size. + break + } + n++ + if index == 0 { + // The first packet, set the first and last packet to it. + firstPacketInRange = packet + lastPacketInRange = packet + continue + } + if packet == lastPacketInRange+1 { + // Packet is still part of the current range, as it's sequenced properly with the last packet. + // Set the last packet in range to the packet and continue to the next packet. + lastPacketInRange = packet + continue + } else { + // We got to the end of a range/single packet. We need to write those down now. + if firstPacketInRange == lastPacketInRange { + // First packet equals last packet, so we have a single packet record. Write down the packet, + // and set the first and last packet to the current packet. + if err := buffer.WriteByte(packetSingle); err != nil { + return 0, err + } + writeUint24(buffer, firstPacketInRange) + + firstPacketInRange = packet + lastPacketInRange = packet + } else { + // There's a gap between the first and last packet, so we have a range of packets. Write the + // first and last packet of the range and set both to the current packet. + if err := buffer.WriteByte(packetRange); err != nil { + return 0, err + } + writeUint24(buffer, firstPacketInRange) + writeUint24(buffer, lastPacketInRange) + + firstPacketInRange = packet + lastPacketInRange = packet + } + // Keep track of the amount of records as we need to write that first. + recordCount++ + } + } + + // Make sure the last single packet/range is written, as we always need to know one packet ahead to know + // how we should write the current. + if firstPacketInRange == lastPacketInRange { + if err := buffer.WriteByte(packetSingle); err != nil { + return 0, err + } + writeUint24(buffer, firstPacketInRange) + } else { + if err := buffer.WriteByte(packetRange); err != nil { + return 0, err + } + writeUint24(buffer, firstPacketInRange) + writeUint24(buffer, lastPacketInRange) + } + recordCount++ + if err := binary.Write(b, binary.BigEndian, recordCount); err != nil { + return 0, err + } + if _, err := b.Write(buffer.Bytes()); err != nil { + return 0, err + } + return n, nil +} + +// read decodes an acknowledgement packet and returns an error if not successful. +func (ack *acknowledgement) read(b *bytes.Buffer) error { + const maxAcknowledgementPackets = 8192 + var recordCount int16 + if err := binary.Read(b, binary.BigEndian, &recordCount); err != nil { + return err + } + for i := int16(0); i < recordCount; i++ { + recordType, err := b.ReadByte() + if err != nil { + return err + } + switch recordType { + case packetRange: + start, err := readUint24(b) + if err != nil { + return err + } + end, err := readUint24(b) + if err != nil { + return err + } + for pack := start; pack <= end; pack++ { + ack.packets = append(ack.packets, pack) + if len(ack.packets) > maxAcknowledgementPackets { + return fmt.Errorf("maximum amount of packets in acknowledgement exceeded") + } + } + case packetSingle: + packet, err := readUint24(b) + if err != nil { + return err + } + ack.packets = append(ack.packets, packet) + if len(ack.packets) > maxAcknowledgementPackets { + return fmt.Errorf("maximum amount of packets in acknowledgement exceeded") + } + } + } + return nil +} diff --git a/packet_queue.go b/packet_queue.go new file mode 100644 index 0000000..7496ba0 --- /dev/null +++ b/packet_queue.go @@ -0,0 +1,50 @@ +package raknet + +// packetQueue is an ordered queue for reliable ordered packets. +type packetQueue struct { + lowest uint24 + highest uint24 + queue map[uint24][]byte +} + +// newPacketQueue returns a new initialised ordered queue. +func newPacketQueue() *packetQueue { + return &packetQueue{queue: make(map[uint24][]byte)} +} + +// put puts a value at the index passed. If the index was already occupied once, false is returned. +func (queue *packetQueue) put(index uint24, packet []byte) bool { + if index < queue.lowest { + return false + } + if _, ok := queue.queue[index]; ok { + return false + } + if index >= queue.highest { + queue.highest = index + 1 + } + queue.queue[index] = packet + return true +} + +// fetch attempts to take out as many values from the ordered queue as possible. Upon encountering an index +// that has no value yet, the function returns all values that it did find and takes them out. +func (queue *packetQueue) fetch() (packets [][]byte) { + index := queue.lowest + for index < queue.highest { + packet, ok := queue.queue[index] + if !ok { + break + } + delete(queue.queue, index) + packets = append(packets, packet) + index++ + } + queue.lowest = index + return +} + +// WindowSize returns the size of the window held by the packet queue. +func (queue *packetQueue) WindowSize() uint24 { + return queue.highest - queue.lowest +} diff --git a/resend_map.go b/resend_map.go new file mode 100644 index 0000000..3153560 --- /dev/null +++ b/resend_map.go @@ -0,0 +1,80 @@ +package raknet + +import ( + "time" +) + +// resendMap is a map of packets, used to recover datagrams if the other end of the connection ended up +// not having them. +type resendMap struct { + unacknowledged map[uint24]resendRecord + delays map[time.Time]time.Duration +} + +// resendRecord represents a single packet with a timestamp from when it was initially sent. It may be either +// acknowledged or NACKed by the other end. +type resendRecord struct { + pk *packet + timestamp time.Time +} + +// newRecoveryQueue returns a new initialised recovery queue. +func newRecoveryQueue() *resendMap { + return &resendMap{ + delays: make(map[time.Time]time.Duration), + unacknowledged: make(map[uint24]resendRecord), + } +} + +// add puts a packet at the index passed and records the current time. +func (m *resendMap) add(index uint24, pk *packet) { + m.unacknowledged[index] = resendRecord{pk: pk, timestamp: time.Now()} +} + +// acknowledge marks a packet with the index passed as acknowledged. The packet is removed from the resendMap and +// returned if found. +func (m *resendMap) acknowledge(index uint24) (*packet, bool) { + return m.remove(index, 1) +} + +// retransmit looks up a packet with an index from the resendMap so that it may be resent. +func (m *resendMap) retransmit(index uint24) (*packet, bool) { + return m.remove(index, 2) +} + +// remove deletes an index from the resendMap and adds the time since the packet was originally sent multiplied by mul +// to the delays slice. +func (m *resendMap) remove(index uint24, mul int) (*packet, bool) { + record, ok := m.unacknowledged[index] + if !ok { + return nil, false + } + delete(m.unacknowledged, index) + + now := time.Now() + m.delays[now] = now.Sub(record.timestamp) * time.Duration(mul) + return record.pk, true +} + +// rtt returns the average round trip time between the putting of the value into the recovery queue and the taking +// out of it again. It is measured over the last delayRecordCount values add in. +func (m *resendMap) rtt() time.Duration { + const averageDuration = time.Second * 5 + var ( + total, records time.Duration + now = time.Now() + ) + for t, rtt := range m.delays { + if now.Sub(t) > averageDuration { + delete(m.delays, t) + continue + } + total += rtt + records++ + } + if records == 0 { + // No records yet, generally should not happen. Just return a reasonable amount of time. + return time.Millisecond * 50 + } + return total / records +}