From 6440f010eec82abb9c999771a8f493af44c6b937 Mon Sep 17 00:00:00 2001 From: "Jason A. Donenfeld" Date: Thu, 21 Mar 2019 14:43:04 -0600 Subject: receive: implement flush semantics --- device/boundif_darwin.go | 2 +- device/boundif_windows.go | 2 +- device/conn.go | 2 +- device/queueconstants_android.go | 2 +- device/receive.go | 204 ++++++++++++++++++++++----------------- tun/operateonfd.go | 24 +++++ tun/tun.go | 1 + tun/tun_darwin.go | 5 + tun/tun_default.go | 24 ----- tun/tun_freebsd.go | 5 + tun/tun_linux.go | 5 + tun/tun_openbsd.go | 5 + tun/tun_windows.go | 12 ++- 13 files changed, 171 insertions(+), 122 deletions(-) create mode 100644 tun/operateonfd.go delete mode 100644 tun/tun_default.go diff --git a/device/boundif_darwin.go b/device/boundif_darwin.go index b3d10ba..a93441c 100644 --- a/device/boundif_darwin.go +++ b/device/boundif_darwin.go @@ -41,4 +41,4 @@ func (device *Device) BindSocketToInterface6(interfaceIndex uint32) error { return err } return nil -} \ No newline at end of file +} diff --git a/device/boundif_windows.go b/device/boundif_windows.go index 00631cb..97381ad 100644 --- a/device/boundif_windows.go +++ b/device/boundif_windows.go @@ -53,4 +53,4 @@ func (device *Device) BindSocketToInterface6(interfaceIndex uint32) error { return err } return nil -} \ No newline at end of file +} diff --git a/device/conn.go b/device/conn.go index 2594680..3c2aa04 100644 --- a/device/conn.go +++ b/device/conn.go @@ -177,4 +177,4 @@ func (device *Device) BindClose() error { err := unsafeCloseBind(device) device.net.Unlock() return err -} \ No newline at end of file +} diff --git a/device/queueconstants_android.go b/device/queueconstants_android.go index 8d051ad..f5c042d 100644 --- a/device/queueconstants_android.go +++ b/device/queueconstants_android.go @@ -13,4 +13,4 @@ const ( QueueHandshakeSize = 1024 MaxSegmentSize = 2200 PreallocatedBuffersPerPool = 4096 -) \ No newline at end of file +) diff --git a/device/receive.go b/device/receive.go index 09fae59..747a188 100644 --- a/device/receive.go +++ b/device/receive.go @@ -482,6 +482,33 @@ func (device *Device) RoutineHandshake() { } } +func (peer *Peer) elementStopOrFlush(shouldFlush *bool) (stop bool, elemOk bool, elem *QueueInboundElement) { + if !*shouldFlush { + select { + case <-peer.routines.stop: + stop = true + return + case elem, elemOk = <-peer.queue.inbound: + return + } + } else { + select { + case <-peer.routines.stop: + stop = true + return + case elem, elemOk = <-peer.queue.inbound: + return + default: + *shouldFlush = false + err := peer.device.tun.device.Flush() + if err != nil { + peer.device.log.Error.Printf("Unable to flush packets: %v", err) + } + return peer.elementStopOrFlush(shouldFlush) + } + } +} + func (peer *Peer) RoutineSequentialReceiver() { device := peer.device @@ -491,6 +518,9 @@ func (peer *Peer) RoutineSequentialReceiver() { var elem *QueueInboundElement var ok bool + var stop bool + + shouldFlush := false defer func() { logDebug.Println(peer, "- Routine: sequential receiver - stopped") @@ -516,126 +546,122 @@ func (peer *Peer) RoutineSequentialReceiver() { elem = nil } - select { - - case <-peer.routines.stop: + stop, ok, elem = peer.elementStopOrFlush(&shouldFlush) + if stop || !ok { return + } - case elem, ok = <-peer.queue.inbound: - - if !ok { - return - } - - // wait for decryption + // wait for decryption - elem.Lock() + elem.Lock() - if elem.IsDropped() { - continue - } + if elem.IsDropped() { + continue + } - // check for replay + // check for replay - if !elem.keypair.replayFilter.ValidateCounter(elem.counter, RejectAfterMessages) { - continue - } + if !elem.keypair.replayFilter.ValidateCounter(elem.counter, RejectAfterMessages) { + continue + } - // update endpoint - peer.SetEndpointFromPacket(elem.endpoint) + // update endpoint + peer.SetEndpointFromPacket(elem.endpoint) - // check if using new keypair - if peer.ReceivedWithKeypair(elem.keypair) { - peer.timersHandshakeComplete() - select { - case peer.signals.newKeypairArrived <- struct{}{}: - default: - } + // check if using new keypair + if peer.ReceivedWithKeypair(elem.keypair) { + peer.timersHandshakeComplete() + select { + case peer.signals.newKeypairArrived <- struct{}{}: + default: } + } - peer.keepKeyFreshReceiving() - peer.timersAnyAuthenticatedPacketTraversal() - peer.timersAnyAuthenticatedPacketReceived() - - // check for keepalive + peer.keepKeyFreshReceiving() + peer.timersAnyAuthenticatedPacketTraversal() + peer.timersAnyAuthenticatedPacketReceived() - if len(elem.packet) == 0 { - logDebug.Println(peer, "- Receiving keepalive packet") - continue - } - peer.timersDataReceived() + // check for keepalive - // verify source and strip padding + if len(elem.packet) == 0 { + logDebug.Println(peer, "- Receiving keepalive packet") + continue + } + peer.timersDataReceived() - switch elem.packet[0] >> 4 { - case ipv4.Version: + // verify source and strip padding - // strip padding + switch elem.packet[0] >> 4 { + case ipv4.Version: - if len(elem.packet) < ipv4.HeaderLen { - continue - } + // strip padding - field := elem.packet[IPv4offsetTotalLength : IPv4offsetTotalLength+2] - length := binary.BigEndian.Uint16(field) - if int(length) > len(elem.packet) || int(length) < ipv4.HeaderLen { - continue - } + if len(elem.packet) < ipv4.HeaderLen { + continue + } - elem.packet = elem.packet[:length] + field := elem.packet[IPv4offsetTotalLength : IPv4offsetTotalLength+2] + length := binary.BigEndian.Uint16(field) + if int(length) > len(elem.packet) || int(length) < ipv4.HeaderLen { + continue + } - // verify IPv4 source + elem.packet = elem.packet[:length] - src := elem.packet[IPv4offsetSrc : IPv4offsetSrc+net.IPv4len] - if device.allowedips.LookupIPv4(src) != peer { - logInfo.Println( - "IPv4 packet with disallowed source address from", - peer, - ) - continue - } + // verify IPv4 source - case ipv6.Version: + src := elem.packet[IPv4offsetSrc : IPv4offsetSrc+net.IPv4len] + if device.allowedips.LookupIPv4(src) != peer { + logInfo.Println( + "IPv4 packet with disallowed source address from", + peer, + ) + continue + } - // strip padding + case ipv6.Version: - if len(elem.packet) < ipv6.HeaderLen { - continue - } + // strip padding - field := elem.packet[IPv6offsetPayloadLength : IPv6offsetPayloadLength+2] - length := binary.BigEndian.Uint16(field) - length += ipv6.HeaderLen - if int(length) > len(elem.packet) { - continue - } + if len(elem.packet) < ipv6.HeaderLen { + continue + } - elem.packet = elem.packet[:length] + field := elem.packet[IPv6offsetPayloadLength : IPv6offsetPayloadLength+2] + length := binary.BigEndian.Uint16(field) + length += ipv6.HeaderLen + if int(length) > len(elem.packet) { + continue + } - // verify IPv6 source + elem.packet = elem.packet[:length] - src := elem.packet[IPv6offsetSrc : IPv6offsetSrc+net.IPv6len] - if device.allowedips.LookupIPv6(src) != peer { - logInfo.Println( - peer, - "sent packet with disallowed IPv6 source", - ) - continue - } + // verify IPv6 source - default: - logInfo.Println("Packet with invalid IP version from", peer) + src := elem.packet[IPv6offsetSrc : IPv6offsetSrc+net.IPv6len] + if device.allowedips.LookupIPv6(src) != peer { + logInfo.Println( + peer, + "sent packet with disallowed IPv6 source", + ) continue } - // write to tun device + default: + logInfo.Println("Packet with invalid IP version from", peer) + continue + } + + // write to tun device - offset := MessageTransportOffsetContent - atomic.AddUint64(&peer.stats.rxBytes, uint64(len(elem.packet))) - _, err := device.tun.device.Write(elem.buffer[:offset+len(elem.packet)], offset) - if err != nil && !device.isClosed.Get() { - logError.Println("Failed to write packet to TUN device:", err) - } + offset := MessageTransportOffsetContent + atomic.AddUint64(&peer.stats.rxBytes, uint64(len(elem.packet))) + _, err := device.tun.device.Write(elem.buffer[:offset+len(elem.packet)], offset) + if err == nil { + shouldFlush = true + } + if err != nil && !device.isClosed.Get() { + logError.Println("Failed to write packet to TUN device:", err) } } } diff --git a/tun/operateonfd.go b/tun/operateonfd.go new file mode 100644 index 0000000..31747a2 --- /dev/null +++ b/tun/operateonfd.go @@ -0,0 +1,24 @@ +// +build !windows + +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved. + */ + +package tun + +import ( + "fmt" +) + +func (tun *NativeTun) operateOnFd(fn func(fd uintptr)) { + sysconn, err := tun.tunFile.SyscallConn() + if err != nil { + tun.errors <- fmt.Errorf("unable to find sysconn for tunfile: %s", err.Error()) + return + } + err = sysconn.Control(fn) + if err != nil { + tun.errors <- fmt.Errorf("unable to control sysconn for tunfile: %s", err.Error()) + } +} diff --git a/tun/tun.go b/tun/tun.go index c4b6cac..12febb8 100644 --- a/tun/tun.go +++ b/tun/tun.go @@ -21,6 +21,7 @@ type TUNDevice interface { File() *os.File // returns the file descriptor of the device Read([]byte, int) (int, error) // read a packet from the device (without any additional headers) Write([]byte, int) (int, error) // writes a packet to the device (without any additional headers) + Flush() error // flush all previous writes to the device MTU() (int, error) // returns the MTU of the device Name() (string, error) // fetches and returns the current name Events() chan TUNEvent // returns a constant channel of events related to the device diff --git a/tun/tun_darwin.go b/tun/tun_darwin.go index 3b39982..2077de3 100644 --- a/tun/tun_darwin.go +++ b/tun/tun_darwin.go @@ -281,6 +281,11 @@ func (tun *NativeTun) Write(buff []byte, offset int) (int, error) { return tun.tunFile.Write(buff) } +func (tun *NativeTun) Flush() error { + //TODO: can flushing be implemented by buffering and using sendmmsg? + return nil +} + func (tun *NativeTun) Close() error { var err2 error err1 := tun.tunFile.Close() diff --git a/tun/tun_default.go b/tun/tun_default.go deleted file mode 100644 index 31747a2..0000000 --- a/tun/tun_default.go +++ /dev/null @@ -1,24 +0,0 @@ -// +build !windows - -/* SPDX-License-Identifier: MIT - * - * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved. - */ - -package tun - -import ( - "fmt" -) - -func (tun *NativeTun) operateOnFd(fn func(fd uintptr)) { - sysconn, err := tun.tunFile.SyscallConn() - if err != nil { - tun.errors <- fmt.Errorf("unable to find sysconn for tunfile: %s", err.Error()) - return - } - err = sysconn.Control(fn) - if err != nil { - tun.errors <- fmt.Errorf("unable to control sysconn for tunfile: %s", err.Error()) - } -} diff --git a/tun/tun_freebsd.go b/tun/tun_freebsd.go index 3a60725..01a4348 100644 --- a/tun/tun_freebsd.go +++ b/tun/tun_freebsd.go @@ -406,6 +406,11 @@ func (tun *NativeTun) Write(buff []byte, offset int) (int, error) { return tun.tunFile.Write(buff) } +func (tun *NativeTun) Flush() error { + //TODO: can flushing be implemented by buffering and using sendmmsg? + return nil +} + func (tun *NativeTun) Close() error { var err3 error err1 := tun.tunFile.Close() diff --git a/tun/tun_linux.go b/tun/tun_linux.go index b7c429c..784cb9f 100644 --- a/tun/tun_linux.go +++ b/tun/tun_linux.go @@ -318,6 +318,11 @@ func (tun *NativeTun) Write(buff []byte, offset int) (int, error) { return tun.tunFile.Write(buff) } +func (tun *NativeTun) Flush() error { + //TODO: can flushing be implemented by buffering and using sendmmsg? + return nil +} + func (tun *NativeTun) Read(buff []byte, offset int) (int, error) { select { case err := <-tun.errors: diff --git a/tun/tun_openbsd.go b/tun/tun_openbsd.go index 57edcb4..645bcca 100644 --- a/tun/tun_openbsd.go +++ b/tun/tun_openbsd.go @@ -237,6 +237,11 @@ func (tun *NativeTun) Write(buff []byte, offset int) (int, error) { return tun.tunFile.Write(buff) } +func (tun *NativeTun) Flush() error { + //TODO: can flushing be implemented by buffering and using sendmmsg? + return nil +} + func (tun *NativeTun) Close() error { var err2 error err1 := tun.tunFile.Close() diff --git a/tun/tun_windows.go b/tun/tun_windows.go index dcb414a..fffd802 100644 --- a/tun/tun_windows.go +++ b/tun/tun_windows.go @@ -281,7 +281,11 @@ func (tun *NativeTun) Read(buff []byte, offset int) (int, error) { // Note: flush() and putTunPacket() assume the caller comes only from a single thread; there's no locking. -func (tun *NativeTun) flush() error { +func (tun *NativeTun) Flush() error { + if tun.wrBuff.offset == 0 { + return nil + } + // Get TUN data pipe. file, err := tun.getTUN() if err != nil { @@ -322,7 +326,7 @@ func (tun *NativeTun) putTunPacket(buff []byte) error { if tun.wrBuff.packetNum >= packetExchangeMax || tun.wrBuff.offset+pSize >= packetExchangeSize { // Exchange buffer is full -> flush first. - err := tun.flush() + err := tun.Flush() if err != nil { return err } @@ -345,9 +349,7 @@ func (tun *NativeTun) Write(buff []byte, offset int) (int, error) { if err != nil { return 0, err } - - // Flush write buffer. - return len(buff) - offset, tun.flush() + return len(buff) - offset, nil } // -- cgit v1.2.3