diff options
-rw-r--r-- | conn/bind_linux.go | 47 | ||||
-rw-r--r-- | conn/bind_std.go | 37 | ||||
-rw-r--r-- | conn/bind_windows.go | 67 | ||||
-rw-r--r-- | conn/bindtest/bindtest.go | 39 | ||||
-rw-r--r-- | conn/conn.go | 24 | ||||
-rw-r--r-- | conn/conn_test.go | 24 | ||||
-rw-r--r-- | device/channels.go | 30 | ||||
-rw-r--r-- | device/device.go | 27 | ||||
-rw-r--r-- | device/device_test.go | 69 | ||||
-rw-r--r-- | device/peer.go | 26 | ||||
-rw-r--r-- | device/pools.go | 32 | ||||
-rw-r--r-- | device/pools_test.go | 48 | ||||
-rw-r--r-- | device/receive.go | 320 | ||||
-rw-r--r-- | device/send.go | 270 | ||||
-rw-r--r-- | main.go | 14 | ||||
-rw-r--r-- | main_windows.go | 5 | ||||
-rw-r--r-- | tun/errors.go | 60 | ||||
-rw-r--r-- | tun/netstack/tun.go | 45 | ||||
-rw-r--r-- | tun/tun.go | 40 | ||||
-rw-r--r-- | tun/tun_darwin.go | 65 | ||||
-rw-r--r-- | tun/tun_freebsd.go | 53 | ||||
-rw-r--r-- | tun/tun_linux.go | 39 | ||||
-rw-r--r-- | tun/tun_openbsd.go | 58 | ||||
-rw-r--r-- | tun/tun_windows.go | 52 | ||||
-rw-r--r-- | tun/tuntest/tuntest.go | 29 |
25 files changed, 1026 insertions, 494 deletions
diff --git a/conn/bind_linux.go b/conn/bind_linux.go index bd710ae..b6bc0dc 100644 --- a/conn/bind_linux.go +++ b/conn/bind_linux.go @@ -193,6 +193,10 @@ func (bind *LinuxSocketBind) SetMark(value uint32) error { return nil } +func (bind *LinuxSocketBind) BatchSize() int { + return 1 +} + func (bind *LinuxSocketBind) Close() error { // Take a readlock to shut down the sockets... bind.mu.RLock() @@ -223,29 +227,39 @@ func (bind *LinuxSocketBind) Close() error { return err2 } -func (bind *LinuxSocketBind) receiveIPv4(buf []byte) (int, Endpoint, error) { +func (bind *LinuxSocketBind) receiveIPv4(buffs [][]byte, sizes []int, eps []Endpoint) (int, error) { bind.mu.RLock() defer bind.mu.RUnlock() if bind.sock4 == -1 { - return 0, nil, net.ErrClosed + return 0, net.ErrClosed } var end LinuxSocketEndpoint - n, err := receive4(bind.sock4, buf, &end) - return n, &end, err + n, err := receive4(bind.sock4, buffs[0], &end) + if err != nil { + return 0, err + } + eps[0] = &end + sizes[0] = n + return 1, nil } -func (bind *LinuxSocketBind) receiveIPv6(buf []byte) (int, Endpoint, error) { +func (bind *LinuxSocketBind) receiveIPv6(buffs [][]byte, sizes []int, eps []Endpoint) (int, error) { bind.mu.RLock() defer bind.mu.RUnlock() if bind.sock6 == -1 { - return 0, nil, net.ErrClosed + return 0, net.ErrClosed } var end LinuxSocketEndpoint - n, err := receive6(bind.sock6, buf, &end) - return n, &end, err + n, err := receive6(bind.sock6, buffs[0], &end) + if err != nil { + return 0, err + } + eps[0] = &end + sizes[0] = n + return 1, nil } -func (bind *LinuxSocketBind) Send(buff []byte, end Endpoint) error { +func (bind *LinuxSocketBind) Send(buffs [][]byte, end Endpoint) error { nend, ok := end.(*LinuxSocketEndpoint) if !ok { return ErrWrongEndpointType @@ -256,13 +270,24 @@ func (bind *LinuxSocketBind) Send(buff []byte, end Endpoint) error { if bind.sock4 == -1 { return net.ErrClosed } - return send4(bind.sock4, nend, buff) + for _, buff := range buffs { + err := send4(bind.sock4, nend, buff) + if err != nil { + return err + } + } } else { if bind.sock6 == -1 { return net.ErrClosed } - return send6(bind.sock6, nend, buff) + for _, buff := range buffs { + err := send6(bind.sock6, nend, buff) + if err != nil { + return err + } + } } + return nil } func (end *LinuxSocketEndpoint) SrcIP() netip.Addr { diff --git a/conn/bind_std.go b/conn/bind_std.go index ae07aac..98fe23c 100644 --- a/conn/bind_std.go +++ b/conn/bind_std.go @@ -128,6 +128,10 @@ again: return fns, uint16(port), nil } +func (bind *StdNetBind) BatchSize() int { + return 1 +} + func (bind *StdNetBind) Close() error { bind.mu.Lock() defer bind.mu.Unlock() @@ -150,20 +154,30 @@ func (bind *StdNetBind) Close() error { } func (*StdNetBind) makeReceiveIPv4(conn *net.UDPConn) ReceiveFunc { - return func(buff []byte) (int, Endpoint, error) { - n, endpoint, err := conn.ReadFromUDPAddrPort(buff) - return n, asEndpoint(endpoint), err + return func(buffs [][]byte, sizes []int, eps []Endpoint) (n int, err error) { + size, endpoint, err := conn.ReadFromUDPAddrPort(buffs[0]) + if err == nil { + sizes[0] = size + eps[0] = asEndpoint(endpoint) + return 1, nil + } + return 0, err } } func (*StdNetBind) makeReceiveIPv6(conn *net.UDPConn) ReceiveFunc { - return func(buff []byte) (int, Endpoint, error) { - n, endpoint, err := conn.ReadFromUDPAddrPort(buff) - return n, asEndpoint(endpoint), err + return func(buffs [][]byte, sizes []int, eps []Endpoint) (n int, err error) { + size, endpoint, err := conn.ReadFromUDPAddrPort(buffs[0]) + if err == nil { + sizes[0] = size + eps[0] = asEndpoint(endpoint) + return 1, nil + } + return 0, err } } -func (bind *StdNetBind) Send(buff []byte, endpoint Endpoint) error { +func (bind *StdNetBind) Send(buffs [][]byte, endpoint Endpoint) error { var err error nend, ok := endpoint.(StdNetEndpoint) if !ok { @@ -186,8 +200,13 @@ func (bind *StdNetBind) Send(buff []byte, endpoint Endpoint) error { if conn == nil { return syscall.EAFNOSUPPORT } - _, err = conn.WriteToUDPAddrPort(buff, addrPort) - return err + for _, buff := range buffs { + _, err = conn.WriteToUDPAddrPort(buff, addrPort) + if err != nil { + return err + } + } + return nil } // endpointPool contains a re-usable set of mapping from netip.AddrPort to Endpoint. diff --git a/conn/bind_windows.go b/conn/bind_windows.go index f8b187b..5a0b8c2 100644 --- a/conn/bind_windows.go +++ b/conn/bind_windows.go @@ -321,6 +321,11 @@ func (bind *WinRingBind) Close() error { return nil } +func (bind *WinRingBind) BatchSize() int { + // TODO: implement batching in and out of the ring + return 1 +} + func (bind *WinRingBind) SetMark(mark uint32) error { return nil } @@ -409,16 +414,22 @@ retry: return n, &ep, nil } -func (bind *WinRingBind) receiveIPv4(buf []byte) (int, Endpoint, error) { +func (bind *WinRingBind) receiveIPv4(buffs [][]byte, sizes []int, eps []Endpoint) (int, error) { bind.mu.RLock() defer bind.mu.RUnlock() - return bind.v4.Receive(buf, &bind.isOpen) + n, ep, err := bind.v4.Receive(buffs[0], &bind.isOpen) + sizes[0] = n + eps[0] = ep + return 1, err } -func (bind *WinRingBind) receiveIPv6(buf []byte) (int, Endpoint, error) { +func (bind *WinRingBind) receiveIPv6(buffs [][]byte, sizes []int, eps []Endpoint) (int, error) { bind.mu.RLock() defer bind.mu.RUnlock() - return bind.v6.Receive(buf, &bind.isOpen) + n, ep, err := bind.v6.Receive(buffs[0], &bind.isOpen) + sizes[0] = n + eps[0] = ep + return 1, err } func (bind *afWinRingBind) Send(buf []byte, nend *WinRingEndpoint, isOpen *atomic.Uint32) error { @@ -473,32 +484,38 @@ func (bind *afWinRingBind) Send(buf []byte, nend *WinRingEndpoint, isOpen *atomi return winrio.SendEx(bind.rq, dataBuffer, 1, nil, addressBuffer, nil, nil, 0, 0) } -func (bind *WinRingBind) Send(buf []byte, endpoint Endpoint) error { +func (bind *WinRingBind) Send(buffs [][]byte, endpoint Endpoint) error { nend, ok := endpoint.(*WinRingEndpoint) if !ok { return ErrWrongEndpointType } bind.mu.RLock() defer bind.mu.RUnlock() - switch nend.family { - case windows.AF_INET: - if bind.v4.blackhole { - return nil - } - return bind.v4.Send(buf, nend, &bind.isOpen) - case windows.AF_INET6: - if bind.v6.blackhole { - return nil + for _, buf := range buffs { + switch nend.family { + case windows.AF_INET: + if bind.v4.blackhole { + continue + } + if err := bind.v4.Send(buf, nend, &bind.isOpen); err != nil { + return err + } + case windows.AF_INET6: + if bind.v6.blackhole { + continue + } + if err := bind.v6.Send(buf, nend, &bind.isOpen); err != nil { + return err + } } - return bind.v6.Send(buf, nend, &bind.isOpen) } return nil } -func (bind *StdNetBind) BindSocketToInterface4(interfaceIndex uint32, blackhole bool) error { - bind.mu.Lock() - defer bind.mu.Unlock() - sysconn, err := bind.ipv4.SyscallConn() +func (s *StdNetBind) BindSocketToInterface4(interfaceIndex uint32, blackhole bool) error { + s.mu.Lock() + defer s.mu.Unlock() + sysconn, err := s.ipv4.SyscallConn() if err != nil { return err } @@ -511,14 +528,14 @@ func (bind *StdNetBind) BindSocketToInterface4(interfaceIndex uint32, blackhole if err != nil { return err } - bind.blackhole4 = blackhole + s.blackhole4 = blackhole return nil } -func (bind *StdNetBind) BindSocketToInterface6(interfaceIndex uint32, blackhole bool) error { - bind.mu.Lock() - defer bind.mu.Unlock() - sysconn, err := bind.ipv6.SyscallConn() +func (s *StdNetBind) BindSocketToInterface6(interfaceIndex uint32, blackhole bool) error { + s.mu.Lock() + defer s.mu.Unlock() + sysconn, err := s.ipv6.SyscallConn() if err != nil { return err } @@ -531,7 +548,7 @@ func (bind *StdNetBind) BindSocketToInterface6(interfaceIndex uint32, blackhole if err != nil { return err } - bind.blackhole6 = blackhole + s.blackhole6 = blackhole return nil } diff --git a/conn/bindtest/bindtest.go b/conn/bindtest/bindtest.go index 9605a2a..b33c53d 100644 --- a/conn/bindtest/bindtest.go +++ b/conn/bindtest/bindtest.go @@ -89,32 +89,39 @@ func (c *ChannelBind) Close() error { return nil } +func (c *ChannelBind) BatchSize() int { return 1 } + func (c *ChannelBind) SetMark(mark uint32) error { return nil } func (c *ChannelBind) makeReceiveFunc(ch chan []byte) conn.ReceiveFunc { - return func(b []byte) (n int, ep conn.Endpoint, err error) { + return func(buffs [][]byte, sizes []int, eps []conn.Endpoint) (n int, err error) { select { case <-c.closeSignal: - return 0, nil, net.ErrClosed + return 0, net.ErrClosed case rx := <-ch: - return copy(b, rx), c.target6, nil + copied := copy(buffs[0], rx) + sizes[0] = copied + eps[0] = c.target6 + return 1, nil } } } -func (c *ChannelBind) Send(b []byte, ep conn.Endpoint) error { - select { - case <-c.closeSignal: - return net.ErrClosed - default: - bc := make([]byte, len(b)) - copy(bc, b) - if ep.(ChannelEndpoint) == c.target4 { - *c.tx4 <- bc - } else if ep.(ChannelEndpoint) == c.target6 { - *c.tx6 <- bc - } else { - return os.ErrInvalid +func (c *ChannelBind) Send(buffs [][]byte, ep conn.Endpoint) error { + for _, b := range buffs { + select { + case <-c.closeSignal: + return net.ErrClosed + default: + bc := make([]byte, len(b)) + copy(bc, b) + if ep.(ChannelEndpoint) == c.target4 { + *c.tx4 <- bc + } else if ep.(ChannelEndpoint) == c.target6 { + *c.tx6 <- bc + } else { + return os.ErrInvalid + } } } return nil diff --git a/conn/conn.go b/conn/conn.go index 497b92a..8c0a827 100644 --- a/conn/conn.go +++ b/conn/conn.go @@ -15,10 +15,17 @@ import ( "strings" ) -// A ReceiveFunc receives a single inbound packet from the network. -// It writes the data into b. n is the length of the packet. -// ep is the remote endpoint. -type ReceiveFunc func(b []byte) (n int, ep Endpoint, err error) +const ( + DefaultBatchSize = 1 // maximum number of packets handled per read and write +) + +// A ReceiveFunc receives at least one packet from the network and writes them +// into packets. On a successful read it returns the number of elements of +// sizes, packets, and endpoints that should be evaluated. Some elements of +// sizes may be zero, and callers should ignore them. Callers must pass a sizes +// and eps slice with a length greater than or equal to the length of packets. +// These lengths must not exceed the length of the associated Bind.BatchSize(). +type ReceiveFunc func(packets [][]byte, sizes []int, eps []Endpoint) (n int, err error) // A Bind listens on a port for both IPv6 and IPv4 UDP traffic. // @@ -38,11 +45,16 @@ type Bind interface { // This mark is passed to the kernel as the socket option SO_MARK. SetMark(mark uint32) error - // Send writes a packet b to address ep. - Send(b []byte, ep Endpoint) error + // Send writes one or more packets in buffs to address ep. The length of + // buffs must not exceed BatchSize(). + Send(buffs [][]byte, ep Endpoint) error // ParseEndpoint creates a new endpoint from a string. ParseEndpoint(s string) (Endpoint, error) + + // BatchSize is the number of buffers expected to be passed to + // the ReceiveFuncs, and the maximum expected to be passed to SendBatch. + BatchSize() int } // BindSocketToInterface is implemented by Bind objects that support being diff --git a/conn/conn_test.go b/conn/conn_test.go new file mode 100644 index 0000000..7a6231d --- /dev/null +++ b/conn/conn_test.go @@ -0,0 +1,24 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. + */ + +package conn + +import ( + "testing" +) + +func TestPrettyName(t *testing.T) { + var ( + recvFunc ReceiveFunc = func(buffs [][]byte, sizes []int, eps []Endpoint) (n int, err error) { return } + ) + + const want = "TestPrettyName" + + t.Run("ReceiveFunc.PrettyName", func(t *testing.T) { + if got := recvFunc.PrettyName(); got != want { + t.Errorf("PrettyName() = %v, want %v", got, want) + } + }) +} diff --git a/device/channels.go b/device/channels.go index 1bfeeaf..039d8df 100644 --- a/device/channels.go +++ b/device/channels.go @@ -72,7 +72,7 @@ func newHandshakeQueue() *handshakeQueue { } type autodrainingInboundQueue struct { - c chan *QueueInboundElement + c chan *[]*QueueInboundElement } // newAutodrainingInboundQueue returns a channel that will be drained when it gets GC'd. @@ -81,7 +81,7 @@ type autodrainingInboundQueue struct { // some other means, such as sending a sentinel nil values. func newAutodrainingInboundQueue(device *Device) *autodrainingInboundQueue { q := &autodrainingInboundQueue{ - c: make(chan *QueueInboundElement, QueueInboundSize), + c: make(chan *[]*QueueInboundElement, QueueInboundSize), } runtime.SetFinalizer(q, device.flushInboundQueue) return q @@ -90,10 +90,13 @@ func newAutodrainingInboundQueue(device *Device) *autodrainingInboundQueue { func (device *Device) flushInboundQueue(q *autodrainingInboundQueue) { for { select { - case elem := <-q.c: - elem.Lock() - device.PutMessageBuffer(elem.buffer) - device.PutInboundElement(elem) + case elems := <-q.c: + for _, elem := range *elems { + elem.Lock() + device.PutMessageBuffer(elem.buffer) + device.PutInboundElement(elem) + } + device.PutInboundElementsSlice(elems) default: return } @@ -101,7 +104,7 @@ func (device *Device) flushInboundQueue(q *autodrainingInboundQueue) { } type autodrainingOutboundQueue struct { - c chan *QueueOutboundElement + c chan *[]*QueueOutboundElement } // newAutodrainingOutboundQueue returns a channel that will be drained when it gets GC'd. @@ -111,7 +114,7 @@ type autodrainingOutboundQueue struct { // All sends to the channel must be best-effort, because there may be no receivers. func newAutodrainingOutboundQueue(device *Device) *autodrainingOutboundQueue { q := &autodrainingOutboundQueue{ - c: make(chan *QueueOutboundElement, QueueOutboundSize), + c: make(chan *[]*QueueOutboundElement, QueueOutboundSize), } runtime.SetFinalizer(q, device.flushOutboundQueue) return q @@ -120,10 +123,13 @@ func newAutodrainingOutboundQueue(device *Device) *autodrainingOutboundQueue { func (device *Device) flushOutboundQueue(q *autodrainingOutboundQueue) { for { select { - case elem := <-q.c: - elem.Lock() - device.PutMessageBuffer(elem.buffer) - device.PutOutboundElement(elem) + case elems := <-q.c: + for _, elem := range *elems { + elem.Lock() + device.PutMessageBuffer(elem.buffer) + device.PutOutboundElement(elem) + } + device.PutOutboundElementsSlice(elems) default: return } diff --git a/device/device.go b/device/device.go index 3368a93..091c8d4 100644 --- a/device/device.go +++ b/device/device.go @@ -68,9 +68,11 @@ type Device struct { cookieChecker CookieChecker pool struct { - messageBuffers *WaitPool - inboundElements *WaitPool - outboundElements *WaitPool + outboundElementsSlice *WaitPool + inboundElementsSlice *WaitPool + messageBuffers *WaitPool + inboundElements *WaitPool + outboundElements *WaitPool } queue struct { @@ -295,6 +297,7 @@ func NewDevice(tunDevice tun.Device, bind conn.Bind, logger *Logger) *Device { device.peers.keyMap = make(map[NoisePublicKey]*Peer) device.rate.limiter.Init() device.indexTable.Init() + device.PopulatePools() // create queues @@ -322,6 +325,19 @@ func NewDevice(tunDevice tun.Device, bind conn.Bind, logger *Logger) *Device { return device } +// BatchSize returns the BatchSize for the device as a whole which is the max of +// the bind batch size and the tun batch size. The batch size reported by device +// is the size used to construct memory pools, and is the allowed batch size for +// the lifetime of the device. +func (device *Device) BatchSize() int { + size := device.net.bind.BatchSize() + dSize := device.tun.device.BatchSize() + if size < dSize { + size = dSize + } + return size +} + func (device *Device) LookupPeer(pk NoisePublicKey) *Peer { device.peers.RLock() defer device.peers.RUnlock() @@ -472,11 +488,13 @@ func (device *Device) BindUpdate() error { var err error var recvFns []conn.ReceiveFunc netc := &device.net + recvFns, netc.port, err = netc.bind.Open(netc.port) if err != nil { netc.port = 0 return err } + netc.netlinkCancel, err = device.startRouteListener(netc.bind) if err != nil { netc.bind.Close() @@ -507,8 +525,9 @@ func (device *Device) BindUpdate() error { device.net.stopping.Add(len(recvFns)) device.queue.decryption.wg.Add(len(recvFns)) // each RoutineReceiveIncoming goroutine writes to device.queue.decryption device.queue.handshake.wg.Add(len(recvFns)) // each RoutineReceiveIncoming goroutine writes to device.queue.handshake + batchSize := netc.bind.BatchSize() for _, fn := range recvFns { - go device.RoutineReceiveIncoming(fn) + go device.RoutineReceiveIncoming(batchSize, fn) } device.log.Verbosef("UDP bind has been updated") diff --git a/device/device_test.go b/device/device_test.go index 975da64..73891bf 100644 --- a/device/device_test.go +++ b/device/device_test.go @@ -12,6 +12,7 @@ import ( "io" "math/rand" "net/netip" + "os" "runtime" "runtime/pprof" "sync" @@ -21,6 +22,7 @@ import ( "golang.zx2c4.com/wireguard/conn" "golang.zx2c4.com/wireguard/conn/bindtest" + "golang.zx2c4.com/wireguard/tun" "golang.zx2c4.com/wireguard/tun/tuntest" ) @@ -307,6 +309,17 @@ func TestConcurrencySafety(t *testing.T) { } }) + // Perform bind updates and keepalive sends concurrently with tunnel use. + t.Run("bindUpdate and keepalive", func(t *testing.T) { + const iters = 10 + for i := 0; i < iters; i++ { + for _, peer := range pair { + peer.dev.BindUpdate() + peer.dev.SendKeepalivesToPeersWithCurrentKeypair() + } + } + }) + close(done) } @@ -405,3 +418,59 @@ func goroutineLeakCheck(t *testing.T) { t.Fatalf("expected %d goroutines, got %d, leak?", startGoroutines, endGoroutines) }) } + +type fakeBindSized struct { + size int +} + +func (b *fakeBindSized) Open(port uint16) (fns []conn.ReceiveFunc, actualPort uint16, err error) { + return nil, 0, nil +} +func (b *fakeBindSized) Close() error { return nil } +func (b *fakeBindSized) SetMark(mark uint32) error { return nil } +func (b *fakeBindSized) Send(buffs [][]byte, ep conn.Endpoint) error { return nil } +func (b *fakeBindSized) ParseEndpoint(s string) (conn.Endpoint, error) { return nil, nil } +func (b *fakeBindSized) BatchSize() int { return b.size } + +type fakeTUNDeviceSized struct { + size int +} + +func (t *fakeTUNDeviceSized) File() *os.File { return nil } +func (t *fakeTUNDeviceSized) Read(buffs [][]byte, sizes []int, offset int) (n int, err error) { + return 0, nil +} +func (t *fakeTUNDeviceSized) Write(buffs [][]byte, offset int) (int, error) { return 0, nil } +func (t *fakeTUNDeviceSized) MTU() (int, error) { return 0, nil } +func (t *fakeTUNDeviceSized) Name() (string, error) { return "", nil } +func (t *fakeTUNDeviceSized) Events() <-chan tun.Event { return nil } +func (t *fakeTUNDeviceSized) Close() error { return nil } +func (t *fakeTUNDeviceSized) BatchSize() int { return t.size } + +func TestBatchSize(t *testing.T) { + d := Device{} + + d.net.bind = &fakeBindSized{1} + d.tun.device = &fakeTUNDeviceSized{1} + if want, got := 1, d.BatchSize(); got != want { + t.Errorf("expected batch size %d, got %d", want, got) + } + + d.net.bind = &fakeBindSized{1} + d.tun.device = &fakeTUNDeviceSized{128} + if want, got := 128, d.BatchSize(); got != want { + t.Errorf("expected batch size %d, got %d", want, got) + } + + d.net.bind = &fakeBindSized{128} + d.tun.device = &fakeTUNDeviceSized{1} + if want, got := 128, d.BatchSize(); got != want { + t.Errorf("expected batch size %d, got %d", want, got) + } + + d.net.bind = &fakeBindSized{128} + d.tun.device = &fakeTUNDeviceSized{128} + if want, got := 128, d.BatchSize(); got != want { + t.Errorf("expected batch size %d, got %d", want, got) + } +} diff --git a/device/peer.go b/device/peer.go index 0e7b669..0ac4896 100644 --- a/device/peer.go +++ b/device/peer.go @@ -45,9 +45,9 @@ type Peer struct { } queue struct { - staged chan *QueueOutboundElement // staged packets before a handshake is available - outbound *autodrainingOutboundQueue // sequential ordering of udp transmission - inbound *autodrainingInboundQueue // sequential ordering of tun writing + staged chan *[]*QueueOutboundElement // staged packets before a handshake is available + outbound *autodrainingOutboundQueue // sequential ordering of udp transmission + inbound *autodrainingInboundQueue // sequential ordering of tun writing } cookieGenerator CookieGenerator @@ -81,7 +81,7 @@ func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) { peer.device = device peer.queue.outbound = newAutodrainingOutboundQueue(device) peer.queue.inbound = newAutodrainingInboundQueue(device) - peer.queue.staged = make(chan *QueueOutboundElement, QueueStagedSize) + peer.queue.staged = make(chan *[]*QueueOutboundElement, QueueStagedSize) // map public key _, ok := device.peers.keyMap[pk] @@ -108,7 +108,7 @@ func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) { return peer, nil } -func (peer *Peer) SendBuffer(buffer []byte) error { +func (peer *Peer) SendBuffers(buffers [][]byte) error { peer.device.net.RLock() defer peer.device.net.RUnlock() @@ -123,9 +123,13 @@ func (peer *Peer) SendBuffer(buffer []byte) error { return errors.New("no known endpoint for peer") } - err := peer.device.net.bind.Send(buffer, peer.endpoint) + err := peer.device.net.bind.Send(buffers, peer.endpoint) if err == nil { - peer.txBytes.Add(uint64(len(buffer))) + var totalLen uint64 + for _, b := range buffers { + totalLen += uint64(len(b)) + } + peer.txBytes.Add(totalLen) } return err } @@ -187,8 +191,12 @@ func (peer *Peer) Start() { device.flushInboundQueue(peer.queue.inbound) device.flushOutboundQueue(peer.queue.outbound) - go peer.RoutineSequentialSender() - go peer.RoutineSequentialReceiver() + + // Use the device batch size, not the bind batch size, as the device size is + // the size of the batch pools. + batchSize := peer.device.BatchSize() + go peer.RoutineSequentialSender(batchSize) + go peer.RoutineSequentialReceiver(batchSize) peer.isRunning.Store(true) } diff --git a/device/pools.go b/device/pools.go index 239757f..02a5d6a 100644 --- a/device/pools.go +++ b/device/pools.go @@ -46,6 +46,14 @@ func (p *WaitPool) Put(x any) { } func (device *Device) PopulatePools() { + device.pool.outboundElementsSlice = NewWaitPool(PreallocatedBuffersPerPool, func() any { + s := make([]*QueueOutboundElement, 0, device.BatchSize()) + return &s + }) + device.pool.inboundElementsSlice = NewWaitPool(PreallocatedBuffersPerPool, func() any { + s := make([]*QueueInboundElement, 0, device.BatchSize()) + return &s + }) device.pool.messageBuffers = NewWaitPool(PreallocatedBuffersPerPool, func() any { return new([MaxMessageSize]byte) }) @@ -57,6 +65,30 @@ func (device *Device) PopulatePools() { }) } +func (device *Device) GetOutboundElementsSlice() *[]*QueueOutboundElement { + return device.pool.outboundElementsSlice.Get().(*[]*QueueOutboundElement) +} + +func (device *Device) PutOutboundElementsSlice(s *[]*QueueOutboundElement) { + for i := range *s { + (*s)[i] = nil + } + *s = (*s)[:0] + device.pool.outboundElementsSlice.Put(s) +} + +func (device *Device) GetInboundElementsSlice() *[]*QueueInboundElement { + return device.pool.inboundElementsSlice.Get().(*[]*QueueInboundElement) +} + +func (device *Device) PutInboundElementsSlice(s *[]*QueueInboundElement) { + for i := range *s { + (*s)[i] = nil + } + *s = (*s)[:0] + device.pool.inboundElementsSlice.Put(s) +} + func (device *Device) GetMessageBuffer() *[MaxMessageSize]byte { return device.pool.messageBuffers.Get().(*[MaxMessageSize]byte) } diff --git a/device/pools_test.go b/device/pools_test.go index 1502a29..82d7493 100644 --- a/device/pools_test.go +++ b/device/pools_test.go @@ -89,3 +89,51 @@ func BenchmarkWaitPool(b *testing.B) { } wg.Wait() } + +func BenchmarkWaitPoolEmpty(b *testing.B) { + var wg sync.WaitGroup + var trials atomic.Int32 + trials.Store(int32(b.N)) + workers := runtime.NumCPU() + 2 + if workers-4 <= 0 { + b.Skip("Not enough cores") + } + p := NewWaitPool(0, func() any { return make([]byte, 16) }) + wg.Add(workers) + b.ResetTimer() + for i := 0; i < workers; i++ { + go func() { + defer wg.Done() + for trials.Add(-1) > 0 { + x := p.Get() + time.Sleep(time.Duration(rand.Intn(100)) * time.Microsecond) + p.Put(x) + } + }() + } + wg.Wait() +} + +func BenchmarkSyncPool(b *testing.B) { + var wg sync.WaitGroup + var trials atomic.Int32 + trials.Store(int32(b.N)) + workers := runtime.NumCPU() + 2 + if workers-4 <= 0 { + b.Skip("Not enough cores") + } + p := sync.Pool{New: func() any { return make([]byte, 16) }} + wg.Add(workers) + b.ResetTimer() + for i := 0; i < workers; i++ { + go func() { + defer wg.Done() + for trials.Add(-1) > 0 { + x := p.Get() + time.Sleep(time.Duration(rand.Intn(100)) * time.Microsecond) + p.Put(x) + } + }() + } + wg.Wait() +} diff --git a/device/receive.go b/device/receive.go index 03fcf00..aee7864 100644 --- a/device/receive.go +++ b/device/receive.go @@ -66,7 +66,7 @@ func (peer *Peer) keepKeyFreshReceiving() { * Every time the bind is updated a new routine is started for * IPv4 and IPv6 (separately) */ -func (device *Device) RoutineReceiveIncoming(recv conn.ReceiveFunc) { +func (device *Device) RoutineReceiveIncoming(maxBatchSize int, recv conn.ReceiveFunc) { recvName := recv.PrettyName() defer func() { device.log.Verbosef("Routine: receive incoming %s - stopped", recvName) @@ -79,20 +79,33 @@ func (device *Device) RoutineReceiveIncoming(recv conn.ReceiveFunc) { // receive datagrams until conn is closed - buffer := device.GetMessageBuffer() - var ( + buffsArrs = make([]*[MaxMessageSize]byte, maxBatchSize) + buffs = make([][]byte, maxBatchSize) err error - size int - endpoint conn.Endpoint + sizes = make([]int, maxBatchSize) + count int + endpoints = make([]conn.Endpoint, maxBatchSize) deathSpiral int + elemsByPeer = make(map[*Peer]*[]*QueueInboundElement, maxBatchSize) ) - for { - size, endpoint, err = recv(buffer[:]) + for i := range buffsArrs { + buffsArrs[i] = device.GetMessageBuffer() + buffs[i] = buffsArrs[i][:] + } + + defer func() { + for i := 0; i < maxBatchSize; i++ { + if buffsArrs[i] != nil { + device.PutMessageBuffer(buffsArrs[i]) + } + } + }() + for { + count, err = recv(buffs, sizes, endpoints) if err != nil { - device.PutMessageBuffer(buffer) if errors.Is(err, net.ErrClosed) { return } @@ -103,101 +116,122 @@ func (device *Device) RoutineReceiveIncoming(recv conn.ReceiveFunc) { if deathSpiral < 10 { deathSpiral++ time.Sleep(time.Second / 3) - buffer = device.GetMessageBuffer() continue } return } deathSpiral = 0 - if size < MinMessageSize { - continue - } + // handle each packet in the batch + for i, size := range sizes[:count] { + if size < MinMessageSize { + continue + } - // check size of packet + // check size of packet - packet := buffer[:size] - msgType := binary.LittleEndian.Uint32(packet[:4]) + packet := buffsArrs[i][:size] + msgType := binary.LittleEndian.Uint32(packet[:4]) - var okay bool + switch msgType { - switch msgType { + // check if transport - // check if transport + case MessageTransportType: - case MessageTransportType: + // check size - // check size + if len(packet) < MessageTransportSize { + continue + } - if len(packet) < MessageTransportSize { - continue - } + // lookup key pair - // lookup key pair + receiver := binary.LittleEndian.Uint32( + packet[MessageTransportOffsetReceiver:MessageTransportOffsetCounter], + ) + value := device.indexTable.Lookup(receiver) + keypair := value.keypair + if keypair == nil { + continue + } - receiver := binary.LittleEndian.Uint32( - packet[MessageTransportOffsetReceiver:MessageTransportOffsetCounter], - ) - value := device.indexTable.Lookup(receiver) - keypair := value.keypair - if keypair == nil { - continue - } + // check keypair expiry - // check keypair expiry + if keypair.created.Add(RejectAfterTime).Before(time.Now()) { + continue + } - if keypair.created.Add(RejectAfterTime).Before(time.Now()) { + // create work element + peer := value.peer + elem := device.GetInboundElement() + elem.packet = packet + elem.buffer = buffsArrs[i] + elem.keypair = keypair + elem.endpoint = endpoints[i] + elem.counter = 0 + elem.Mutex = sync.Mutex{} + elem.Lock() + + elemsForPeer, ok := elemsByPeer[peer] + if !ok { + elemsForPeer = device.GetInboundElementsSlice() + elemsByPeer[peer] = elemsForPeer + } + *elemsForPeer = append(*elemsForPeer, elem) + buffsArrs[i] = device.GetMessageBuffer() + buffs[i] = buffsArrs[i][:] continue - } - - // create work element - peer := value.peer - elem := device.GetInboundElement() - elem.packet = packet - elem.buffer = buffer - elem.keypair = keypair - elem.endpoint = endpoint - elem.counter = 0 - elem.Mutex = sync.Mutex{} - elem.Lock() - // add to decryption queues - if peer.isRunning.Load() { - peer.queue.inbound.c <- elem - device.queue.decryption.c <- elem - buffer = device.GetMessageBuffer() - } else { - device.PutInboundElement(elem) - } - continue + // otherwise it is a fixed size & handshake related packet - // otherwise it is a fixed size & handshake related packet - - case MessageInitiationType: - okay = len(packet) == MessageInitiationSize + case MessageInitiationType: + if len(packet) != MessageInitiationSize { + continue + } - case MessageResponseType: - okay = len(packet) == MessageResponseSize + case MessageResponseType: + if len(packet) != MessageResponseSize { + continue + } - case MessageCookieReplyType: - okay = len(packet) == MessageCookieReplySize + case MessageCookieReplyType: + if len(packet) != MessageCookieReplySize { + continue + } - default: - device.log.Verbosef("Received message with unknown type") - } + default: + device.log.Verbosef("Received message with unknown type") + continue + } - if okay { select { case device.queue.handshake.c <- QueueHandshakeElement{ msgType: msgType, - buffer: buffer, + buffer: buffsArrs[i], packet: packet, - endpoint: endpoint, + endpoint: endpoints[i], }: - buffer = device.GetMessageBuffer() + buffsArrs[i] = device.GetMessageBuffer() + buffs[i] = buffsArrs[i][:] default: } } + for peer, elems := range elemsByPeer { + if peer.isRunning.Load() { + peer.queue.inbound.c <- elems + for _, elem := range *elems { + device.queue.decryption.c <- elem + } + } else { + for _, elem := range *elems { + device.PutMessageBuffer(elem.buffer) + device.PutInboundElement(elem) + } + device.PutInboundElementsSlice(elems) + } + delete(elemsByPeer, peer) + } } } @@ -393,7 +427,7 @@ func (device *Device) RoutineHandshake(id int) { } } -func (peer *Peer) RoutineSequentialReceiver() { +func (peer *Peer) RoutineSequentialReceiver(maxBatchSize int) { device := peer.device defer func() { device.log.Verbosef("%v - Routine: sequential receiver - stopped", peer) @@ -401,89 +435,91 @@ func (peer *Peer) RoutineSequentialReceiver() { }() device.log.Verbosef("%v - Routine: sequential receiver - started", peer) - for elem := range peer.queue.inbound.c { - if elem == nil { + buffs := make([][]byte, 0, maxBatchSize) + + for elems := range peer.queue.inbound.c { + if elems == nil { return } - var err error - elem.Lock() - if elem.packet == nil { - // decryption failed - goto skip - } + for _, elem := range *elems { + elem.Lock() + if elem.packet == nil { + // decryption failed + continue + } - if !elem.keypair.replayFilter.ValidateCounter(elem.counter, RejectAfterMessages) { - goto skip - } + if !elem.keypair.replayFilter.ValidateCounter(elem.counter, RejectAfterMessages) { + continue + } - peer.SetEndpointFromPacket(elem.endpoint) - if peer.ReceivedWithKeypair(elem.keypair) { - peer.timersHandshakeComplete() - peer.SendStagedPackets() - } + peer.SetEndpointFromPacket(elem.endpoint) + if peer.ReceivedWithKeypair(elem.keypair) { + peer.timersHandshakeComplete() + peer.SendStagedPackets() + } + peer.keepKeyFreshReceiving() + peer.timersAnyAuthenticatedPacketTraversal() + peer.timersAnyAuthenticatedPacketReceived() + peer.rxBytes.Add(uint64(len(elem.packet) + MinMessageSize)) - peer.keepKeyFreshReceiving() - peer.timersAnyAuthenticatedPacketTraversal() - peer.timersAnyAuthenticatedPacketReceived() - peer.rxBytes.Add(uint64(len(elem.packet) + MinMessageSize)) + if len(elem.packet) == 0 { + device.log.Verbosef("%v - Receiving keepalive packet", peer) + continue + } + peer.timersDataReceived() - if len(elem.packet) == 0 { - device.log.Verbosef("%v - Receiving keepalive packet", peer) - goto skip - } - peer.timersDataReceived() + switch elem.packet[0] >> 4 { + case 4: + if len(elem.packet) < ipv4.HeaderLen { + continue + } + field := elem.packet[IPv4offsetTotalLength : IPv4offsetTotalLength+2] + length := binary.BigEndian.Uint16(field) + if int(length) > len(elem.packet) || int(length) < ipv4.HeaderLen { + continue + } + elem.packet = elem.packet[:length] + src := elem.packet[IPv4offsetSrc : IPv4offsetSrc+net.IPv4len] + if device.allowedips.Lookup(src) != peer { + device.log.Verbosef("IPv4 packet with disallowed source address from %v", peer) + continue + } - switch elem.packet[0] >> 4 { - case ipv4.Version: - if len(elem.packet) < ipv4.HeaderLen { - goto skip - } - field := elem.packet[IPv4offsetTotalLength : IPv4offsetTotalLength+2] - length := binary.BigEndian.Uint16(field) - if int(length) > len(elem.packet) || int(length) < ipv4.HeaderLen { - goto skip - } - elem.packet = elem.packet[:length] - src := elem.packet[IPv4offsetSrc : IPv4offsetSrc+net.IPv4len] - if device.allowedips.Lookup(src) != peer { - device.log.Verbosef("IPv4 packet with disallowed source address from %v", peer) - goto skip - } + case 6: + if len(elem.packet) < ipv6.HeaderLen { + continue + } + field := elem.packet[IPv6offsetPayloadLength : IPv6offsetPayloadLength+2] + length := binary.BigEndian.Uint16(field) + length += ipv6.HeaderLen + if int(length) > len(elem.packet) { + continue + } + elem.packet = elem.packet[:length] + src := elem.packet[IPv6offsetSrc : IPv6offsetSrc+net.IPv6len] + if device.allowedips.Lookup(src) != peer { + device.log.Verbosef("IPv6 packet with disallowed source address from %v", peer) + continue + } - case ipv6.Version: - if len(elem.packet) < ipv6.HeaderLen { - goto skip - } - field := elem.packet[IPv6offsetPayloadLength : IPv6offsetPayloadLength+2] - length := binary.BigEndian.Uint16(field) - length += ipv6.HeaderLen - if int(length) > len(elem.packet) { - goto skip - } - elem.packet = elem.packet[:length] - src := elem.packet[IPv6offsetSrc : IPv6offsetSrc+net.IPv6len] - if device.allowedips.Lookup(src) != peer { - device.log.Verbosef("IPv6 packet with disallowed source address from %v", peer) - goto skip + default: + device.log.Verbosef("Packet with invalid IP version from %v", peer) + continue } - default: - device.log.Verbosef("Packet with invalid IP version from %v", peer) - goto skip + buffs = append(buffs, elem.buffer[:MessageTransportOffsetContent+len(elem.packet)]) } - - _, err = device.tun.device.Write(elem.buffer[:MessageTransportOffsetContent+len(elem.packet)], MessageTransportOffsetContent) - if err != nil && !device.isClosed() { - device.log.Errorf("Failed to write packet to TUN device: %v", err) - } - if len(peer.queue.inbound.c) == 0 { - err = device.tun.device.Flush() - if err != nil { - peer.device.log.Errorf("Unable to flush packets: %v", err) + if len(buffs) > 0 { + _, err := device.tun.device.Write(buffs, MessageTransportOffsetContent) + if err != nil && !device.isClosed() { + device.log.Errorf("Failed to write packets to TUN device: %v", err) } } - skip: - device.PutMessageBuffer(elem.buffer) - device.PutInboundElement(elem) + for _, elem := range *elems { + device.PutMessageBuffer(elem.buffer) + device.PutInboundElement(elem) + } + buffs = buffs[:0] + device.PutInboundElementsSlice(elems) } } diff --git a/device/send.go b/device/send.go index 854d172..b33b9f4 100644 --- a/device/send.go +++ b/device/send.go @@ -17,6 +17,7 @@ import ( "golang.org/x/crypto/chacha20poly1305" "golang.org/x/net/ipv4" "golang.org/x/net/ipv6" + "golang.zx2c4.com/wireguard/tun" ) /* Outbound flow @@ -77,12 +78,15 @@ func (elem *QueueOutboundElement) clearPointers() { func (peer *Peer) SendKeepalive() { if len(peer.queue.staged) == 0 && peer.isRunning.Load() { elem := peer.device.NewOutboundElement() + elems := peer.device.GetOutboundElementsSlice() + *elems = append(*elems, elem) select { - case peer.queue.staged <- elem: + case peer.queue.staged <- elems: peer.device.log.Verbosef("%v - Sending keepalive packet", peer) default: peer.device.PutMessageBuffer(elem.buffer) peer.device.PutOutboundElement(elem) + peer.device.PutOutboundElementsSlice(elems) } } peer.SendStagedPackets() @@ -125,7 +129,7 @@ func (peer *Peer) SendHandshakeInitiation(isRetry bool) error { peer.timersAnyAuthenticatedPacketTraversal() peer.timersAnyAuthenticatedPacketSent() - err = peer.SendBuffer(packet) + err = peer.SendBuffers([][]byte{packet}) if err != nil { peer.device.log.Errorf("%v - Failed to send handshake initiation: %v", peer, err) } @@ -163,7 +167,8 @@ func (peer *Peer) SendHandshakeResponse() error { peer.timersAnyAuthenticatedPacketTraversal() peer.timersAnyAuthenticatedPacketSent() - err = peer.SendBuffer(packet) + // TODO: allocation could be avoided + err = peer.SendBuffers([][]byte{packet}) if err != nil { peer.device.log.Errorf("%v - Failed to send handshake response: %v", peer, err) } @@ -183,7 +188,8 @@ func (device *Device) SendHandshakeCookie(initiatingElem *QueueHandshakeElement) var buff [MessageCookieReplySize]byte writer := bytes.NewBuffer(buff[:0]) binary.Write(writer, binary.LittleEndian, reply) - device.net.bind.Send(writer.Bytes(), initiatingElem.endpoint) + // TODO: allocation could be avoided + device.net.bind.Send([][]byte{writer.Bytes()}, initiatingElem.endpoint) return nil } @@ -198,11 +204,6 @@ func (peer *Peer) keepKeyFreshSending() { } } -/* Reads packets from the TUN and inserts - * into staged queue for peer - * - * Obs. Single instance per TUN device - */ func (device *Device) RoutineReadFromTUN() { defer func() { device.log.Verbosef("Routine: TUN reader - stopped") @@ -212,81 +213,123 @@ func (device *Device) RoutineReadFromTUN() { device.log.Verbosef("Routine: TUN reader - started") - var elem *QueueOutboundElement + var ( + batchSize = device.BatchSize() + readErr error + elems = make([]*QueueOutboundElement, batchSize) + buffs = make([][]byte, batchSize) + elemsByPeer = make(map[*Peer]*[]*QueueOutboundElement, batchSize) + count = 0 + sizes = make([]int, batchSize) + offset = MessageTransportHeaderSize + ) + + for i := range elems { + elems[i] = device.NewOutboundElement() + buffs[i] = elems[i].buffer[:] + } - for { - if elem != nil { - device.PutMessageBuffer(elem.buffer) - device.PutOutboundElement(elem) + defer func() { + for _, elem := range elems { + if elem != nil { + device.PutMessageBuffer(elem.buffer) + device.PutOutboundElement(elem) + } } - elem = device.NewOutboundElement() - - // read packet + }() - offset := MessageTransportHeaderSize - size, err := device.tun.device.Read(elem.buffer[:], offset) - if err != nil { - if !device.isClosed() { - if !errors.Is(err, os.ErrClosed) { - device.log.Errorf("Failed to read packet from TUN device: %v", err) - } - go device.Close() + for { + // read packets + count, readErr = device.tun.device.Read(buffs, sizes, offset) + for i := 0; i < count; i++ { + if sizes[i] < 1 { + continue } - device.PutMessageBuffer(elem.buffer) - device.PutOutboundElement(elem) - return - } - if size == 0 || size > MaxContentSize { - continue - } + elem := elems[i] + elem.packet = buffs[i][offset : offset+sizes[i]] - elem.packet = elem.buffer[offset : offset+size] + // lookup peer + var peer *Peer + switch elem.packet[0] >> 4 { + case 4: + if len(elem.packet) < ipv4.HeaderLen { + continue + } + dst := elem.packet[IPv4offsetDst : IPv4offsetDst+net.IPv4len] + peer = device.allowedips.Lookup(dst) - // lookup peer + case 6: + if len(elem.packet) < ipv6.HeaderLen { + continue + } + dst := elem.packet[IPv6offsetDst : IPv6offsetDst+net.IPv6len] + peer = device.allowedips.Lookup(dst) - var peer *Peer - switch elem.packet[0] >> 4 { - case ipv4.Version: - if len(elem.packet) < ipv4.HeaderLen { - continue + default: + device.log.Verbosef("Received packet with unknown IP version") } - dst := elem.packet[IPv4offsetDst : IPv4offsetDst+net.IPv4len] - peer = device.allowedips.Lookup(dst) - case ipv6.Version: - if len(elem.packet) < ipv6.HeaderLen { + if peer == nil { continue } - dst := elem.packet[IPv6offsetDst : IPv6offsetDst+net.IPv6len] - peer = device.allowedips.Lookup(dst) - - default: - device.log.Verbosef("Received packet with unknown IP version") + elemsForPeer, ok := elemsByPeer[peer] + if !ok { + elemsForPeer = device.GetOutboundElementsSlice() + elemsByPeer[peer] = elemsForPeer + } + *elemsForPeer = append(*elemsForPeer, elem) + elems[i] = device.NewOutboundElement() + buffs[i] = elems[i].buffer[:] } - if peer == nil { - continue + for peer, elemsForPeer := range elemsByPeer { + if peer.isRunning.Load() { + peer.StagePackets(elemsForPeer) + peer.SendStagedPackets() + } else { + for _, elem := range *elemsForPeer { + device.PutMessageBuffer(elem.buffer) + device.PutOutboundElement(elem) + } + device.PutOutboundElementsSlice(elemsForPeer) + } + delete(elemsByPeer, peer) } - if peer.isRunning.Load() { - peer.StagePacket(elem) - elem = nil - peer.SendStagedPackets() + + if readErr != nil { + if errors.Is(readErr, tun.ErrTooManySegments) { + // TODO: record stat for this + // This will happen if MSS is surprisingly small (< 576) + // coincident with reasonably high throughput. + device.log.Verbosef("Dropped some packets from multi-segment read: %v", readErr) + continue + } + if !device.isClosed() { + if !errors.Is(readErr, os.ErrClosed) { + device.log.Errorf("Failed to read packet from TUN device: %v", readErr) + } + go device.Close() + } + return } } } -func (peer *Peer) StagePacket(elem *QueueOutboundElement) { +func (peer *Peer) StagePackets(elems *[]*QueueOutboundElement) { for { select { - case peer.queue.staged <- elem: + case peer.queue.staged <- elems: return default: } select { case tooOld := <-peer.queue.staged: - peer.device.PutMessageBuffer(tooOld.buffer) - peer.device.PutOutboundElement(tooOld) + for _, elem := range *tooOld { + peer.device.PutMessageBuffer(elem.buffer) + peer.device.PutOutboundElement(elem) + } + peer.device.PutOutboundElementsSlice(tooOld) default: } } @@ -305,26 +348,55 @@ top: } for { + var elemsOOO *[]*QueueOutboundElement select { - case elem := <-peer.queue.staged: - elem.peer = peer - elem.nonce = keypair.sendNonce.Add(1) - 1 - if elem.nonce >= RejectAfterMessages { - keypair.sendNonce.Store(RejectAfterMessages) - peer.StagePacket(elem) // XXX: Out of order, but we can't front-load go chans - goto top + case elems := <-peer.queue.staged: + i := 0 + for _, elem := range *elems { + elem.peer = peer + elem.nonce = keypair.sendNonce.Add(1) - 1 + if elem.nonce >= RejectAfterMessages { + keypair.sendNonce.Store(RejectAfterMessages) + if elemsOOO == nil { + elemsOOO = peer.device.GetOutboundElementsSlice() + } + *elemsOOO = append(*elemsOOO, elem) + continue + } else { + (*elems)[i] = elem + i++ + } + + elem.keypair = keypair + elem.Lock() } + *elems = (*elems)[:i] - elem.keypair = keypair - elem.Lock() + if elemsOOO != nil { + peer.StagePackets(elemsOOO) // XXX: Out of order, but we can't front-load go chans + } + + if len(*elems) == 0 { + peer.device.PutOutboundElementsSlice(elems) + goto top + } // add to parallel and sequential queue if peer.isRunning.Load() { - peer.queue.outbound.c <- elem - peer.device.queue.encryption.c <- elem + peer.queue.outbound.c <- elems + for _, elem := range *elems { + peer.device.queue.encryption.c <- elem + } } else { - peer.device.PutMessageBuffer(elem.buffer) - peer.device.PutOutboundElement(elem) + for _, elem := range *elems { + peer.device.PutMessageBuffer(elem.buffer) + peer.device.PutOutboundElement(elem) + } + peer.device.PutOutboundElementsSlice(elems) + } + + if elemsOOO != nil { + goto top } default: return @@ -335,9 +407,12 @@ top: func (peer *Peer) FlushStagedPackets() { for { select { - case elem := <-peer.queue.staged: - peer.device.PutMessageBuffer(elem.buffer) - peer.device.PutOutboundElement(elem) + case elems := <-peer.queue.staged: + for _, elem := range *elems { + peer.device.PutMessageBuffer(elem.buffer) + peer.device.PutOutboundElement(elem) + } + peer.device.PutOutboundElementsSlice(elems) default: return } @@ -400,12 +475,7 @@ func (device *Device) RoutineEncryption(id int) { } } -/* Sequentially reads packets from queue and sends to endpoint - * - * Obs. Single instance per peer. - * The routine terminates then the outbound queue is closed. - */ -func (peer *Peer) RoutineSequentialSender() { +func (peer *Peer) RoutineSequentialSender(maxBatchSize int) { device := peer.device defer func() { defer device.log.Verbosef("%v - Routine: sequential sender - stopped", peer) @@ -413,36 +483,50 @@ func (peer *Peer) RoutineSequentialSender() { }() device.log.Verbosef("%v - Routine: sequential sender - started", peer) - for elem := range peer.queue.outbound.c { - if elem == nil { + buffs := make([][]byte, 0, maxBatchSize) + + for elems := range peer.queue.outbound.c { + buffs = buffs[:0] + if elems == nil { return } - elem.Lock() if !peer.isRunning.Load() { // peer has been stopped; return re-usable elems to the shared pool. // This is an optimization only. It is possible for the peer to be stopped // immediately after this check, in which case, elem will get processed. - // The timers and SendBuffer code are resilient to a few stragglers. + // The timers and SendBuffers code are resilient to a few stragglers. // TODO: rework peer shutdown order to ensure // that we never accidentally keep timers alive longer than necessary. - device.PutMessageBuffer(elem.buffer) - device.PutOutboundElement(elem) + for _, elem := range *elems { + elem.Lock() + device.PutMessageBuffer(elem.buffer) + device.PutOutboundElement(elem) + } continue } + dataSent := false + for _, elem := range *elems { + elem.Lock() + if len(elem.packet) != MessageKeepaliveSize { + dataSent = true + } + buffs = append(buffs, elem.packet) + } peer.timersAnyAuthenticatedPacketTraversal() peer.timersAnyAuthenticatedPacketSent() - // send message and return buffer to pool - - err := peer.SendBuffer(elem.packet) - if len(elem.packet) != MessageKeepaliveSize { + err := peer.SendBuffers(buffs) + if dataSent { peer.timersDataSent() } - device.PutMessageBuffer(elem.buffer) - device.PutOutboundElement(elem) + for _, elem := range *elems { + device.PutMessageBuffer(elem.buffer) + device.PutOutboundElement(elem) + } + device.PutOutboundElementsSlice(elems) if err != nil { - device.log.Errorf("%v - Failed to send data packet: %v", peer, err) + device.log.Errorf("%v - Failed to send data packets: %v", peer, err) continue } @@ -13,8 +13,8 @@ import ( "os/signal" "runtime" "strconv" - "syscall" + "golang.org/x/sys/unix" "golang.zx2c4.com/wireguard/conn" "golang.zx2c4.com/wireguard/device" "golang.zx2c4.com/wireguard/ipc" @@ -111,7 +111,7 @@ func main() { // open TUN device (or use supplied fd) - tun, err := func() (tun.Device, error) { + tdev, err := func() (tun.Device, error) { tunFdStr := os.Getenv(ENV_WG_TUN_FD) if tunFdStr == "" { return tun.CreateTUN(interfaceName, device.DefaultMTU) @@ -124,7 +124,7 @@ func main() { return nil, err } - err = syscall.SetNonblock(int(fd), true) + err = unix.SetNonblock(int(fd), true) if err != nil { return nil, err } @@ -134,7 +134,7 @@ func main() { }() if err == nil { - realInterfaceName, err2 := tun.Name() + realInterfaceName, err2 := tdev.Name() if err2 == nil { interfaceName = realInterfaceName } @@ -196,7 +196,7 @@ func main() { files[0], // stdin files[1], // stdout files[2], // stderr - tun.File(), + tdev.File(), fileUAPI, }, Dir: ".", @@ -222,7 +222,7 @@ func main() { return } - device := device.NewDevice(tun, conn.NewDefaultBind(), logger) + device := device.NewDevice(tdev, conn.NewDefaultBind(), logger) logger.Verbosef("Device started") @@ -250,7 +250,7 @@ func main() { // wait for program to terminate - signal.Notify(term, syscall.SIGTERM) + signal.Notify(term, unix.SIGTERM) signal.Notify(term, os.Interrupt) select { diff --git a/main_windows.go b/main_windows.go index d075a60..a4dc46f 100644 --- a/main_windows.go +++ b/main_windows.go @@ -9,7 +9,8 @@ import ( "fmt" "os" "os/signal" - "syscall" + + "golang.org/x/sys/windows" "golang.zx2c4.com/wireguard/conn" "golang.zx2c4.com/wireguard/device" @@ -81,7 +82,7 @@ func main() { signal.Notify(term, os.Interrupt) signal.Notify(term, os.Kill) - signal.Notify(term, syscall.SIGTERM) + signal.Notify(term, windows.SIGTERM) select { case <-term: diff --git a/tun/errors.go b/tun/errors.go new file mode 100644 index 0000000..e70b13c --- /dev/null +++ b/tun/errors.go @@ -0,0 +1,60 @@ +package tun + +import ( + "errors" + "fmt" +) + +var ( + // ErrTooManySegments is returned by Device.Read() when segmentation + // overflows the length of supplied buffers. This error should not cause + // reads to cease. + ErrTooManySegments = errors.New("too many segments") +) + +type errorBatch []error + +// ErrorBatch takes a possibly nil or empty list of errors, and if the list is +// non-nil returns an error type that wraps all of the errors. Expected usage is +// to append to an []errors and coerce the set to an error using this method. +func ErrorBatch(errs []error) error { + if len(errs) == 0 { + return nil + } + return errorBatch(errs) +} + +func (e errorBatch) Error() string { + if len(e) == 0 { + return "" + } + if len(e) == 1 { + return e[0].Error() + } + return fmt.Sprintf("batch operation: %v (and %d more errors)", e[0], len(e)-1) +} + +func (e errorBatch) Is(target error) bool { + for _, err := range e { + if errors.Is(err, target) { + return true + } + } + return false +} + +func (e errorBatch) As(target interface{}) bool { + for _, err := range e { + if errors.As(err, target) { + return true + } + } + return false +} + +func (e errorBatch) Unwrap() error { + if len(e) == 0 { + return nil + } + return e[0] +} diff --git a/tun/netstack/tun.go b/tun/netstack/tun.go index 37c879d..a0b212a 100644 --- a/tun/netstack/tun.go +++ b/tun/netstack/tun.go @@ -19,6 +19,7 @@ import ( "regexp" "strconv" "strings" + "syscall" "time" "golang.zx2c4.com/wireguard/tun" @@ -113,29 +114,37 @@ func (tun *netTun) Events() <-chan tun.Event { return tun.events } -func (tun *netTun) Read(buf []byte, offset int) (int, error) { +func (tun *netTun) Read(buf [][]byte, sizes []int, offset int) (int, error) { view, ok := <-tun.incomingPacket if !ok { return 0, os.ErrClosed } - return view.Read(buf[offset:]) + n, err := view.Read(buf[0][offset:]) + if err != nil { + return 0, err + } + sizes[0] = n + return 1, nil } -func (tun *netTun) Write(buf []byte, offset int) (int, error) { - packet := buf[offset:] - if len(packet) == 0 { - return 0, nil - } +func (tun *netTun) Write(buf [][]byte, offset int) (int, error) { + for _, buf := range buf { + packet := buf[offset:] + if len(packet) == 0 { + continue + } - pkb := stack.NewPacketBuffer(stack.PacketBufferOptions{Payload: bufferv2.MakeWithData(packet)}) - switch packet[0] >> 4 { - case 4: - tun.ep.InjectInbound(header.IPv4ProtocolNumber, pkb) - case 6: - tun.ep.InjectInbound(header.IPv6ProtocolNumber, pkb) + pkb := stack.NewPacketBuffer(stack.PacketBufferOptions{Payload: bufferv2.MakeWithData(packet)}) + switch packet[0] >> 4 { + case 4: + tun.ep.InjectInbound(header.IPv4ProtocolNumber, pkb) + case 6: + tun.ep.InjectInbound(header.IPv6ProtocolNumber, pkb) + default: + return 0, syscall.EAFNOSUPPORT + } } - return len(buf), nil } @@ -151,10 +160,6 @@ func (tun *netTun) WriteNotify() { tun.incomingPacket <- view } -func (tun *netTun) Flush() error { - return nil -} - func (tun *netTun) Close() error { tun.stack.RemoveNIC(1) @@ -175,6 +180,10 @@ func (tun *netTun) MTU() (int, error) { return tun.mtu, nil } +func (tun *netTun) BatchSize() int { + return 1 +} + func convertToFullAddr(endpoint netip.AddrPort) (tcpip.FullAddress, tcpip.NetworkProtocolNumber) { var protoNumber tcpip.NetworkProtocolNumber if endpoint.Addr().Is4() { @@ -18,12 +18,36 @@ const ( ) type Device interface { - File() *os.File // returns the file descriptor of the device - Read([]byte, int) (int, error) // read a packet from the device (without any additional headers) - Write([]byte, int) (int, error) // writes a packet to the device (without any additional headers) - Flush() error // flush all previous writes to the device - MTU() (int, error) // returns the MTU of the device - Name() (string, error) // fetches and returns the current name - Events() <-chan Event // returns a constant channel of events related to the device - Close() error // stops the device and closes the event channel + // File returns the file descriptor of the device. + File() *os.File + + // Read one or more packets from the Device (without any additional headers). + // On a successful read it returns the number of packets read, and sets + // packet lengths within the sizes slice. len(sizes) must be >= len(buffs). + // A nonzero offset can be used to instruct the Device on where to begin + // reading into each element of the buffs slice. + Read(buffs [][]byte, sizes []int, offset int) (n int, err error) + + // Write one or more packets to the device (without any additional headers). + // On a successful write it returns the number of packets written. A nonzero + // offset can be used to instruct the Device on where to begin writing from + // each packet contained within the buffs slice. + Write(buffs [][]byte, offset int) (int, error) + + // MTU returns the MTU of the Device. + MTU() (int, error) + + // Name returns the current name of the Device. + Name() (string, error) + + // Events returns a channel of type Event, which is fed Device events. + Events() <-chan Event + + // Close stops the Device and closes the Event channel. + Close() error + + // BatchSize returns the preferred/max number of packets that can be read or + // written in a single read/write call. BatchSize must not change over the + // lifetime of a Device. + BatchSize() int } diff --git a/tun/tun_darwin.go b/tun/tun_darwin.go index 7411a69..b927e6f 100644 --- a/tun/tun_darwin.go +++ b/tun/tun_darwin.go @@ -8,6 +8,7 @@ package tun import ( "errors" "fmt" + "io" "net" "os" "sync" @@ -15,7 +16,6 @@ import ( "time" "unsafe" - "golang.org/x/net/ipv6" "golang.org/x/sys/unix" ) @@ -33,7 +33,7 @@ type NativeTun struct { func retryInterfaceByIndex(index int) (iface *net.Interface, err error) { for i := 0; i < 20; i++ { iface, err = net.InterfaceByIndex(index) - if err != nil && errors.Is(err, syscall.ENOMEM) { + if err != nil && errors.Is(err, unix.ENOMEM) { time.Sleep(time.Duration(i) * time.Second / 3) continue } @@ -55,7 +55,7 @@ func (tun *NativeTun) routineRouteListener(tunIfindex int) { retry: n, err := unix.Read(tun.routeSocket, data) if err != nil { - if errno, ok := err.(syscall.Errno); ok && errno == syscall.EINTR { + if errno, ok := err.(unix.Errno); ok && errno == unix.EINTR { goto retry } tun.errors <- err @@ -217,45 +217,46 @@ func (tun *NativeTun) Events() <-chan Event { return tun.events } -func (tun *NativeTun) Read(buff []byte, offset int) (int, error) { +func (tun *NativeTun) Read(buffs [][]byte, sizes []int, offset int) (int, error) { + // TODO: the BSDs look very similar in Read() and Write(). They should be + // collapsed, with platform-specific files containing the varying parts of + // their implementations. select { case err := <-tun.errors: return 0, err default: - buff := buff[offset-4:] + buff := buffs[0][offset-4:] n, err := tun.tunFile.Read(buff[:]) if n < 4 { return 0, err } - return n - 4, err + sizes[0] = n - 4 + return 1, err } } -func (tun *NativeTun) Write(buff []byte, offset int) (int, error) { - // reserve space for header - - buff = buff[offset-4:] - - // add packet information header - - buff[0] = 0x00 - buff[1] = 0x00 - buff[2] = 0x00 - - if buff[4]>>4 == ipv6.Version { - buff[3] = unix.AF_INET6 - } else { - buff[3] = unix.AF_INET +func (tun *NativeTun) Write(buffs [][]byte, offset int) (int, error) { + if offset < 4 { + return 0, io.ErrShortBuffer } - - // write - - return tun.tunFile.Write(buff) -} - -func (tun *NativeTun) Flush() error { - // TODO: can flushing be implemented by buffering and using sendmmsg? - return nil + for i, buf := range buffs { + buf = buf[offset-4:] + buf[0] = 0x00 + buf[1] = 0x00 + buf[2] = 0x00 + switch buf[4] >> 4 { + case 4: + buf[3] = unix.AF_INET + case 6: + buf[3] = unix.AF_INET6 + default: + return i, unix.EAFNOSUPPORT + } + if _, err := tun.tunFile.Write(buf); err != nil { + return i, err + } + } + return len(buffs), nil } func (tun *NativeTun) Close() error { @@ -318,6 +319,10 @@ func (tun *NativeTun) MTU() (int, error) { return int(ifr.MTU), nil } +func (tun *NativeTun) BatchSize() int { + return 1 +} + func socketCloexec(family, sotype, proto int) (fd int, err error) { // See go/src/net/sys_cloexec.go for background. syscall.ForkLock.RLock() diff --git a/tun/tun_freebsd.go b/tun/tun_freebsd.go index 42431aa..0783f74 100644 --- a/tun/tun_freebsd.go +++ b/tun/tun_freebsd.go @@ -333,45 +333,46 @@ func (tun *NativeTun) Events() <-chan Event { return tun.events } -func (tun *NativeTun) Read(buff []byte, offset int) (int, error) { +func (tun *NativeTun) Read(buffs [][]byte, sizes []int, offset int) (int, error) { select { case err := <-tun.errors: return 0, err default: - buff := buff[offset-4:] + buff := buffs[0][offset-4:] n, err := tun.tunFile.Read(buff[:]) if n < 4 { return 0, err } - return n - 4, err + sizes[0] = n - 4 + return 1, err } } -func (tun *NativeTun) Write(buf []byte, offset int) (int, error) { +func (tun *NativeTun) Write(buffs [][]byte, offset int) (int, error) { if offset < 4 { return 0, io.ErrShortBuffer } - buf = buf[offset-4:] - if len(buf) < 5 { - return 0, io.ErrShortBuffer - } - buf[0] = 0x00 - buf[1] = 0x00 - buf[2] = 0x00 - switch buf[4] >> 4 { - case 4: - buf[3] = unix.AF_INET - case 6: - buf[3] = unix.AF_INET6 - default: - return 0, unix.EAFNOSUPPORT + for i, buf := range buffs { + buf = buf[offset-4:] + if len(buf) < 5 { + return i, io.ErrShortBuffer + } + buf[0] = 0x00 + buf[1] = 0x00 + buf[2] = 0x00 + switch buf[4] >> 4 { + case 4: + buf[3] = unix.AF_INET + case 6: + buf[3] = unix.AF_INET6 + default: + return i, unix.EAFNOSUPPORT + } + if _, err := tun.tunFile.Write(buf); err != nil { + return i, err + } } - return tun.tunFile.Write(buf) -} - -func (tun *NativeTun) Flush() error { - // TODO: can flushing be implemented by buffering and using sendmmsg? - return nil + return len(buffs), nil } func (tun *NativeTun) Close() error { @@ -428,3 +429,7 @@ func (tun *NativeTun) MTU() (int, error) { } return int(*(*int32)(unsafe.Pointer(&ifr.MTU))), nil } + +func (tun *NativeTun) BatchSize() int { + return 1 +} diff --git a/tun/tun_linux.go b/tun/tun_linux.go index 25dbc07..21984ca 100644 --- a/tun/tun_linux.go +++ b/tun/tun_linux.go @@ -323,12 +323,13 @@ func (tun *NativeTun) nameSlow() (string, error) { return unix.ByteSliceToString(ifr[:]), nil } -func (tun *NativeTun) Write(buf []byte, offset int) (int, error) { +func (tun *NativeTun) Write(buffs [][]byte, offset int) (n int, err error) { + var buf []byte if tun.nopi { - buf = buf[offset:] + buf = buffs[0][offset:] } else { // reserve space for header - buf = buf[offset-4:] + buf = buffs[0][offset-4:] // add packet information header buf[0] = 0x00 @@ -342,34 +343,36 @@ func (tun *NativeTun) Write(buf []byte, offset int) (int, error) { } } - n, err := tun.tunFile.Write(buf) + _, err = tun.tunFile.Write(buf) if errors.Is(err, syscall.EBADFD) { err = os.ErrClosed + } else if err == nil { + n = 1 } return n, err } -func (tun *NativeTun) Flush() error { - // TODO: can flushing be implemented by buffering and using sendmmsg? - return nil -} - -func (tun *NativeTun) Read(buf []byte, offset int) (n int, err error) { +func (tun *NativeTun) Read(buffs [][]byte, sizes []int, offset int) (n int, err error) { select { case err = <-tun.errors: default: if tun.nopi { - n, err = tun.tunFile.Read(buf[offset:]) + sizes[0], err = tun.tunFile.Read(buffs[0][offset:]) + if err == nil { + n = 1 + } } else { - buff := buf[offset-4:] - n, err = tun.tunFile.Read(buff[:]) + buff := buffs[0][offset-4:] + sizes[0], err = tun.tunFile.Read(buff[:]) if errors.Is(err, syscall.EBADFD) { err = os.ErrClosed + } else if err == nil { + n = 1 } - if n < 4 { - n = 0 + if sizes[0] < 4 { + sizes[0] = 0 } else { - n -= 4 + sizes[0] -= 4 } } } @@ -399,6 +402,10 @@ func (tun *NativeTun) Close() error { return err2 } +func (tun *NativeTun) BatchSize() int { + return 1 +} + func CreateTUN(name string, mtu int) (Device, error) { nfd, err := unix.Open(cloneDevicePath, unix.O_RDWR|unix.O_CLOEXEC, 0) if err != nil { diff --git a/tun/tun_openbsd.go b/tun/tun_openbsd.go index e7fd79c..210830c 100644 --- a/tun/tun_openbsd.go +++ b/tun/tun_openbsd.go @@ -8,13 +8,13 @@ package tun import ( "errors" "fmt" + "io" "net" "os" "sync" "syscall" "unsafe" - "golang.org/x/net/ipv6" "golang.org/x/sys/unix" ) @@ -204,45 +204,43 @@ func (tun *NativeTun) Events() <-chan Event { return tun.events } -func (tun *NativeTun) Read(buff []byte, offset int) (int, error) { +func (tun *NativeTun) Read(buffs [][]byte, sizes []int, offset int) (int, error) { select { case err := <-tun.errors: return 0, err default: - buff := buff[offset-4:] + buff := buffs[0][offset-4:] n, err := tun.tunFile.Read(buff[:]) if n < 4 { return 0, err } - return n - 4, err + sizes[0] = n - 4 + return 1, err } } -func (tun *NativeTun) Write(buff []byte, offset int) (int, error) { - // reserve space for header - - buff = buff[offset-4:] - - // add packet information header - - buff[0] = 0x00 - buff[1] = 0x00 - buff[2] = 0x00 - - if buff[4]>>4 == ipv6.Version { - buff[3] = unix.AF_INET6 - } else { - buff[3] = unix.AF_INET +func (tun *NativeTun) Write(buffs [][]byte, offset int) (int, error) { + if offset < 4 { + return 0, io.ErrShortBuffer } - - // write - - return tun.tunFile.Write(buff) -} - -func (tun *NativeTun) Flush() error { - // TODO: can flushing be implemented by buffering and using sendmmsg? - return nil + for i, buf := range buffs { + buf = buf[offset-4:] + buf[0] = 0x00 + buf[1] = 0x00 + buf[2] = 0x00 + switch buf[4] >> 4 { + case 4: + buf[3] = unix.AF_INET + case 6: + buf[3] = unix.AF_INET6 + default: + return i, unix.EAFNOSUPPORT + } + if _, err := tun.tunFile.Write(buf); err != nil { + return i, err + } + } + return len(buffs), nil } func (tun *NativeTun) Close() error { @@ -329,3 +327,7 @@ func (tun *NativeTun) MTU() (int, error) { return int(*(*int32)(unsafe.Pointer(&ifr.MTU))), nil } + +func (tun *NativeTun) BatchSize() int { + return 1 +} diff --git a/tun/tun_windows.go b/tun/tun_windows.go index d5abb14..320dd59 100644 --- a/tun/tun_windows.go +++ b/tun/tun_windows.go @@ -15,7 +15,6 @@ import ( _ "unsafe" "golang.org/x/sys/windows" - "golang.zx2c4.com/wintun" ) @@ -44,6 +43,7 @@ type NativeTun struct { closeOnce sync.Once close atomic.Bool forcedMTU int + outSizes []int } var ( @@ -134,9 +134,14 @@ func (tun *NativeTun) ForceMTU(mtu int) { } } +func (tun *NativeTun) BatchSize() int { + // TODO: implement batching with wintun + return 1 +} + // Note: Read() and Write() assume the caller comes only from a single thread; there's no locking. -func (tun *NativeTun) Read(buff []byte, offset int) (int, error) { +func (tun *NativeTun) Read(buffs [][]byte, sizes []int, offset int) (int, error) { tun.running.Add(1) defer tun.running.Done() retry: @@ -153,10 +158,11 @@ retry: switch err { case nil: packetSize := len(packet) - copy(buff[offset:], packet) + copy(buffs[0][offset:], packet) + sizes[0] = packetSize tun.session.ReleaseReceivePacket(packet) tun.rate.update(uint64(packetSize)) - return packetSize, nil + return 1, nil case windows.ERROR_NO_MORE_ITEMS: if !shouldSpin || uint64(nanotime()-start) >= spinloopDuration { windows.WaitForSingleObject(tun.readWait, windows.INFINITE) @@ -173,33 +179,33 @@ retry: } } -func (tun *NativeTun) Flush() error { - return nil -} - -func (tun *NativeTun) Write(buff []byte, offset int) (int, error) { +func (tun *NativeTun) Write(buffs [][]byte, offset int) (int, error) { tun.running.Add(1) defer tun.running.Done() if tun.close.Load() { return 0, os.ErrClosed } - packetSize := len(buff) - offset - tun.rate.update(uint64(packetSize)) + for i, buff := range buffs { + packetSize := len(buff) - offset + tun.rate.update(uint64(packetSize)) - packet, err := tun.session.AllocateSendPacket(packetSize) - if err == nil { - copy(packet, buff[offset:]) - tun.session.SendPacket(packet) - return packetSize, nil - } - switch err { - case windows.ERROR_HANDLE_EOF: - return 0, os.ErrClosed - case windows.ERROR_BUFFER_OVERFLOW: - return 0, nil // Dropping when ring is full. + packet, err := tun.session.AllocateSendPacket(packetSize) + switch err { + case nil: + // TODO: Explore options to eliminate this copy. + copy(packet, buff[offset:]) + tun.session.SendPacket(packet) + continue + case windows.ERROR_HANDLE_EOF: + return i, os.ErrClosed + case windows.ERROR_BUFFER_OVERFLOW: + continue // Dropping when ring is full. + default: + return i, fmt.Errorf("Write failed: %w", err) + } } - return 0, fmt.Errorf("Write failed: %w", err) + return len(buffs), nil } // LUID returns Windows interface instance ID. diff --git a/tun/tuntest/tuntest.go b/tun/tuntest/tuntest.go index b143c76..d07e860 100644 --- a/tun/tuntest/tuntest.go +++ b/tun/tuntest/tuntest.go @@ -110,35 +110,42 @@ type chTun struct { func (t *chTun) File() *os.File { return nil } -func (t *chTun) Read(data []byte, offset int) (int, error) { +func (t *chTun) Read(packets [][]byte, sizes []int, offset int) (int, error) { select { case <-t.c.closed: return 0, os.ErrClosed case msg := <-t.c.Outbound: - return copy(data[offset:], msg), nil + n := copy(packets[0][offset:], msg) + sizes[0] = n + return 1, nil } } // Write is called by the wireguard device to deliver a packet for routing. -func (t *chTun) Write(data []byte, offset int) (int, error) { +func (t *chTun) Write(packets [][]byte, offset int) (int, error) { if offset == -1 { close(t.c.closed) close(t.c.events) return 0, io.EOF } - msg := make([]byte, len(data)-offset) - copy(msg, data[offset:]) - select { - case <-t.c.closed: - return 0, os.ErrClosed - case t.c.Inbound <- msg: - return len(data) - offset, nil + for i, data := range packets { + msg := make([]byte, len(data)-offset) + copy(msg, data[offset:]) + select { + case <-t.c.closed: + return i, os.ErrClosed + case t.c.Inbound <- msg: + } } + return len(packets), nil +} + +func (t *chTun) BatchSize() int { + return 1 } const DefaultMTU = 1420 -func (t *chTun) Flush() error { return nil } func (t *chTun) MTU() (int, error) { return DefaultMTU, nil } func (t *chTun) Name() (string, error) { return "loopbackTun1", nil } func (t *chTun) Events() <-chan tun.Event { return t.c.events } |