diff options
Diffstat (limited to 'device')
-rw-r--r-- | device/channels.go | 30 | ||||
-rw-r--r-- | device/device.go | 27 | ||||
-rw-r--r-- | device/device_test.go | 69 | ||||
-rw-r--r-- | device/peer.go | 26 | ||||
-rw-r--r-- | device/pools.go | 32 | ||||
-rw-r--r-- | device/pools_test.go | 48 | ||||
-rw-r--r-- | device/receive.go | 320 | ||||
-rw-r--r-- | device/send.go | 270 |
8 files changed, 562 insertions, 260 deletions
diff --git a/device/channels.go b/device/channels.go index 1bfeeaf..039d8df 100644 --- a/device/channels.go +++ b/device/channels.go @@ -72,7 +72,7 @@ func newHandshakeQueue() *handshakeQueue { } type autodrainingInboundQueue struct { - c chan *QueueInboundElement + c chan *[]*QueueInboundElement } // newAutodrainingInboundQueue returns a channel that will be drained when it gets GC'd. @@ -81,7 +81,7 @@ type autodrainingInboundQueue struct { // some other means, such as sending a sentinel nil values. func newAutodrainingInboundQueue(device *Device) *autodrainingInboundQueue { q := &autodrainingInboundQueue{ - c: make(chan *QueueInboundElement, QueueInboundSize), + c: make(chan *[]*QueueInboundElement, QueueInboundSize), } runtime.SetFinalizer(q, device.flushInboundQueue) return q @@ -90,10 +90,13 @@ func newAutodrainingInboundQueue(device *Device) *autodrainingInboundQueue { func (device *Device) flushInboundQueue(q *autodrainingInboundQueue) { for { select { - case elem := <-q.c: - elem.Lock() - device.PutMessageBuffer(elem.buffer) - device.PutInboundElement(elem) + case elems := <-q.c: + for _, elem := range *elems { + elem.Lock() + device.PutMessageBuffer(elem.buffer) + device.PutInboundElement(elem) + } + device.PutInboundElementsSlice(elems) default: return } @@ -101,7 +104,7 @@ func (device *Device) flushInboundQueue(q *autodrainingInboundQueue) { } type autodrainingOutboundQueue struct { - c chan *QueueOutboundElement + c chan *[]*QueueOutboundElement } // newAutodrainingOutboundQueue returns a channel that will be drained when it gets GC'd. @@ -111,7 +114,7 @@ type autodrainingOutboundQueue struct { // All sends to the channel must be best-effort, because there may be no receivers. func newAutodrainingOutboundQueue(device *Device) *autodrainingOutboundQueue { q := &autodrainingOutboundQueue{ - c: make(chan *QueueOutboundElement, QueueOutboundSize), + c: make(chan *[]*QueueOutboundElement, QueueOutboundSize), } runtime.SetFinalizer(q, device.flushOutboundQueue) return q @@ -120,10 +123,13 @@ func newAutodrainingOutboundQueue(device *Device) *autodrainingOutboundQueue { func (device *Device) flushOutboundQueue(q *autodrainingOutboundQueue) { for { select { - case elem := <-q.c: - elem.Lock() - device.PutMessageBuffer(elem.buffer) - device.PutOutboundElement(elem) + case elems := <-q.c: + for _, elem := range *elems { + elem.Lock() + device.PutMessageBuffer(elem.buffer) + device.PutOutboundElement(elem) + } + device.PutOutboundElementsSlice(elems) default: return } diff --git a/device/device.go b/device/device.go index 3368a93..091c8d4 100644 --- a/device/device.go +++ b/device/device.go @@ -68,9 +68,11 @@ type Device struct { cookieChecker CookieChecker pool struct { - messageBuffers *WaitPool - inboundElements *WaitPool - outboundElements *WaitPool + outboundElementsSlice *WaitPool + inboundElementsSlice *WaitPool + messageBuffers *WaitPool + inboundElements *WaitPool + outboundElements *WaitPool } queue struct { @@ -295,6 +297,7 @@ func NewDevice(tunDevice tun.Device, bind conn.Bind, logger *Logger) *Device { device.peers.keyMap = make(map[NoisePublicKey]*Peer) device.rate.limiter.Init() device.indexTable.Init() + device.PopulatePools() // create queues @@ -322,6 +325,19 @@ func NewDevice(tunDevice tun.Device, bind conn.Bind, logger *Logger) *Device { return device } +// BatchSize returns the BatchSize for the device as a whole which is the max of +// the bind batch size and the tun batch size. The batch size reported by device +// is the size used to construct memory pools, and is the allowed batch size for +// the lifetime of the device. +func (device *Device) BatchSize() int { + size := device.net.bind.BatchSize() + dSize := device.tun.device.BatchSize() + if size < dSize { + size = dSize + } + return size +} + func (device *Device) LookupPeer(pk NoisePublicKey) *Peer { device.peers.RLock() defer device.peers.RUnlock() @@ -472,11 +488,13 @@ func (device *Device) BindUpdate() error { var err error var recvFns []conn.ReceiveFunc netc := &device.net + recvFns, netc.port, err = netc.bind.Open(netc.port) if err != nil { netc.port = 0 return err } + netc.netlinkCancel, err = device.startRouteListener(netc.bind) if err != nil { netc.bind.Close() @@ -507,8 +525,9 @@ func (device *Device) BindUpdate() error { device.net.stopping.Add(len(recvFns)) device.queue.decryption.wg.Add(len(recvFns)) // each RoutineReceiveIncoming goroutine writes to device.queue.decryption device.queue.handshake.wg.Add(len(recvFns)) // each RoutineReceiveIncoming goroutine writes to device.queue.handshake + batchSize := netc.bind.BatchSize() for _, fn := range recvFns { - go device.RoutineReceiveIncoming(fn) + go device.RoutineReceiveIncoming(batchSize, fn) } device.log.Verbosef("UDP bind has been updated") diff --git a/device/device_test.go b/device/device_test.go index 975da64..73891bf 100644 --- a/device/device_test.go +++ b/device/device_test.go @@ -12,6 +12,7 @@ import ( "io" "math/rand" "net/netip" + "os" "runtime" "runtime/pprof" "sync" @@ -21,6 +22,7 @@ import ( "golang.zx2c4.com/wireguard/conn" "golang.zx2c4.com/wireguard/conn/bindtest" + "golang.zx2c4.com/wireguard/tun" "golang.zx2c4.com/wireguard/tun/tuntest" ) @@ -307,6 +309,17 @@ func TestConcurrencySafety(t *testing.T) { } }) + // Perform bind updates and keepalive sends concurrently with tunnel use. + t.Run("bindUpdate and keepalive", func(t *testing.T) { + const iters = 10 + for i := 0; i < iters; i++ { + for _, peer := range pair { + peer.dev.BindUpdate() + peer.dev.SendKeepalivesToPeersWithCurrentKeypair() + } + } + }) + close(done) } @@ -405,3 +418,59 @@ func goroutineLeakCheck(t *testing.T) { t.Fatalf("expected %d goroutines, got %d, leak?", startGoroutines, endGoroutines) }) } + +type fakeBindSized struct { + size int +} + +func (b *fakeBindSized) Open(port uint16) (fns []conn.ReceiveFunc, actualPort uint16, err error) { + return nil, 0, nil +} +func (b *fakeBindSized) Close() error { return nil } +func (b *fakeBindSized) SetMark(mark uint32) error { return nil } +func (b *fakeBindSized) Send(buffs [][]byte, ep conn.Endpoint) error { return nil } +func (b *fakeBindSized) ParseEndpoint(s string) (conn.Endpoint, error) { return nil, nil } +func (b *fakeBindSized) BatchSize() int { return b.size } + +type fakeTUNDeviceSized struct { + size int +} + +func (t *fakeTUNDeviceSized) File() *os.File { return nil } +func (t *fakeTUNDeviceSized) Read(buffs [][]byte, sizes []int, offset int) (n int, err error) { + return 0, nil +} +func (t *fakeTUNDeviceSized) Write(buffs [][]byte, offset int) (int, error) { return 0, nil } +func (t *fakeTUNDeviceSized) MTU() (int, error) { return 0, nil } +func (t *fakeTUNDeviceSized) Name() (string, error) { return "", nil } +func (t *fakeTUNDeviceSized) Events() <-chan tun.Event { return nil } +func (t *fakeTUNDeviceSized) Close() error { return nil } +func (t *fakeTUNDeviceSized) BatchSize() int { return t.size } + +func TestBatchSize(t *testing.T) { + d := Device{} + + d.net.bind = &fakeBindSized{1} + d.tun.device = &fakeTUNDeviceSized{1} + if want, got := 1, d.BatchSize(); got != want { + t.Errorf("expected batch size %d, got %d", want, got) + } + + d.net.bind = &fakeBindSized{1} + d.tun.device = &fakeTUNDeviceSized{128} + if want, got := 128, d.BatchSize(); got != want { + t.Errorf("expected batch size %d, got %d", want, got) + } + + d.net.bind = &fakeBindSized{128} + d.tun.device = &fakeTUNDeviceSized{1} + if want, got := 128, d.BatchSize(); got != want { + t.Errorf("expected batch size %d, got %d", want, got) + } + + d.net.bind = &fakeBindSized{128} + d.tun.device = &fakeTUNDeviceSized{128} + if want, got := 128, d.BatchSize(); got != want { + t.Errorf("expected batch size %d, got %d", want, got) + } +} diff --git a/device/peer.go b/device/peer.go index 0e7b669..0ac4896 100644 --- a/device/peer.go +++ b/device/peer.go @@ -45,9 +45,9 @@ type Peer struct { } queue struct { - staged chan *QueueOutboundElement // staged packets before a handshake is available - outbound *autodrainingOutboundQueue // sequential ordering of udp transmission - inbound *autodrainingInboundQueue // sequential ordering of tun writing + staged chan *[]*QueueOutboundElement // staged packets before a handshake is available + outbound *autodrainingOutboundQueue // sequential ordering of udp transmission + inbound *autodrainingInboundQueue // sequential ordering of tun writing } cookieGenerator CookieGenerator @@ -81,7 +81,7 @@ func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) { peer.device = device peer.queue.outbound = newAutodrainingOutboundQueue(device) peer.queue.inbound = newAutodrainingInboundQueue(device) - peer.queue.staged = make(chan *QueueOutboundElement, QueueStagedSize) + peer.queue.staged = make(chan *[]*QueueOutboundElement, QueueStagedSize) // map public key _, ok := device.peers.keyMap[pk] @@ -108,7 +108,7 @@ func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) { return peer, nil } -func (peer *Peer) SendBuffer(buffer []byte) error { +func (peer *Peer) SendBuffers(buffers [][]byte) error { peer.device.net.RLock() defer peer.device.net.RUnlock() @@ -123,9 +123,13 @@ func (peer *Peer) SendBuffer(buffer []byte) error { return errors.New("no known endpoint for peer") } - err := peer.device.net.bind.Send(buffer, peer.endpoint) + err := peer.device.net.bind.Send(buffers, peer.endpoint) if err == nil { - peer.txBytes.Add(uint64(len(buffer))) + var totalLen uint64 + for _, b := range buffers { + totalLen += uint64(len(b)) + } + peer.txBytes.Add(totalLen) } return err } @@ -187,8 +191,12 @@ func (peer *Peer) Start() { device.flushInboundQueue(peer.queue.inbound) device.flushOutboundQueue(peer.queue.outbound) - go peer.RoutineSequentialSender() - go peer.RoutineSequentialReceiver() + + // Use the device batch size, not the bind batch size, as the device size is + // the size of the batch pools. + batchSize := peer.device.BatchSize() + go peer.RoutineSequentialSender(batchSize) + go peer.RoutineSequentialReceiver(batchSize) peer.isRunning.Store(true) } diff --git a/device/pools.go b/device/pools.go index 239757f..02a5d6a 100644 --- a/device/pools.go +++ b/device/pools.go @@ -46,6 +46,14 @@ func (p *WaitPool) Put(x any) { } func (device *Device) PopulatePools() { + device.pool.outboundElementsSlice = NewWaitPool(PreallocatedBuffersPerPool, func() any { + s := make([]*QueueOutboundElement, 0, device.BatchSize()) + return &s + }) + device.pool.inboundElementsSlice = NewWaitPool(PreallocatedBuffersPerPool, func() any { + s := make([]*QueueInboundElement, 0, device.BatchSize()) + return &s + }) device.pool.messageBuffers = NewWaitPool(PreallocatedBuffersPerPool, func() any { return new([MaxMessageSize]byte) }) @@ -57,6 +65,30 @@ func (device *Device) PopulatePools() { }) } +func (device *Device) GetOutboundElementsSlice() *[]*QueueOutboundElement { + return device.pool.outboundElementsSlice.Get().(*[]*QueueOutboundElement) +} + +func (device *Device) PutOutboundElementsSlice(s *[]*QueueOutboundElement) { + for i := range *s { + (*s)[i] = nil + } + *s = (*s)[:0] + device.pool.outboundElementsSlice.Put(s) +} + +func (device *Device) GetInboundElementsSlice() *[]*QueueInboundElement { + return device.pool.inboundElementsSlice.Get().(*[]*QueueInboundElement) +} + +func (device *Device) PutInboundElementsSlice(s *[]*QueueInboundElement) { + for i := range *s { + (*s)[i] = nil + } + *s = (*s)[:0] + device.pool.inboundElementsSlice.Put(s) +} + func (device *Device) GetMessageBuffer() *[MaxMessageSize]byte { return device.pool.messageBuffers.Get().(*[MaxMessageSize]byte) } diff --git a/device/pools_test.go b/device/pools_test.go index 1502a29..82d7493 100644 --- a/device/pools_test.go +++ b/device/pools_test.go @@ -89,3 +89,51 @@ func BenchmarkWaitPool(b *testing.B) { } wg.Wait() } + +func BenchmarkWaitPoolEmpty(b *testing.B) { + var wg sync.WaitGroup + var trials atomic.Int32 + trials.Store(int32(b.N)) + workers := runtime.NumCPU() + 2 + if workers-4 <= 0 { + b.Skip("Not enough cores") + } + p := NewWaitPool(0, func() any { return make([]byte, 16) }) + wg.Add(workers) + b.ResetTimer() + for i := 0; i < workers; i++ { + go func() { + defer wg.Done() + for trials.Add(-1) > 0 { + x := p.Get() + time.Sleep(time.Duration(rand.Intn(100)) * time.Microsecond) + p.Put(x) + } + }() + } + wg.Wait() +} + +func BenchmarkSyncPool(b *testing.B) { + var wg sync.WaitGroup + var trials atomic.Int32 + trials.Store(int32(b.N)) + workers := runtime.NumCPU() + 2 + if workers-4 <= 0 { + b.Skip("Not enough cores") + } + p := sync.Pool{New: func() any { return make([]byte, 16) }} + wg.Add(workers) + b.ResetTimer() + for i := 0; i < workers; i++ { + go func() { + defer wg.Done() + for trials.Add(-1) > 0 { + x := p.Get() + time.Sleep(time.Duration(rand.Intn(100)) * time.Microsecond) + p.Put(x) + } + }() + } + wg.Wait() +} diff --git a/device/receive.go b/device/receive.go index 03fcf00..aee7864 100644 --- a/device/receive.go +++ b/device/receive.go @@ -66,7 +66,7 @@ func (peer *Peer) keepKeyFreshReceiving() { * Every time the bind is updated a new routine is started for * IPv4 and IPv6 (separately) */ -func (device *Device) RoutineReceiveIncoming(recv conn.ReceiveFunc) { +func (device *Device) RoutineReceiveIncoming(maxBatchSize int, recv conn.ReceiveFunc) { recvName := recv.PrettyName() defer func() { device.log.Verbosef("Routine: receive incoming %s - stopped", recvName) @@ -79,20 +79,33 @@ func (device *Device) RoutineReceiveIncoming(recv conn.ReceiveFunc) { // receive datagrams until conn is closed - buffer := device.GetMessageBuffer() - var ( + buffsArrs = make([]*[MaxMessageSize]byte, maxBatchSize) + buffs = make([][]byte, maxBatchSize) err error - size int - endpoint conn.Endpoint + sizes = make([]int, maxBatchSize) + count int + endpoints = make([]conn.Endpoint, maxBatchSize) deathSpiral int + elemsByPeer = make(map[*Peer]*[]*QueueInboundElement, maxBatchSize) ) - for { - size, endpoint, err = recv(buffer[:]) + for i := range buffsArrs { + buffsArrs[i] = device.GetMessageBuffer() + buffs[i] = buffsArrs[i][:] + } + + defer func() { + for i := 0; i < maxBatchSize; i++ { + if buffsArrs[i] != nil { + device.PutMessageBuffer(buffsArrs[i]) + } + } + }() + for { + count, err = recv(buffs, sizes, endpoints) if err != nil { - device.PutMessageBuffer(buffer) if errors.Is(err, net.ErrClosed) { return } @@ -103,101 +116,122 @@ func (device *Device) RoutineReceiveIncoming(recv conn.ReceiveFunc) { if deathSpiral < 10 { deathSpiral++ time.Sleep(time.Second / 3) - buffer = device.GetMessageBuffer() continue } return } deathSpiral = 0 - if size < MinMessageSize { - continue - } + // handle each packet in the batch + for i, size := range sizes[:count] { + if size < MinMessageSize { + continue + } - // check size of packet + // check size of packet - packet := buffer[:size] - msgType := binary.LittleEndian.Uint32(packet[:4]) + packet := buffsArrs[i][:size] + msgType := binary.LittleEndian.Uint32(packet[:4]) - var okay bool + switch msgType { - switch msgType { + // check if transport - // check if transport + case MessageTransportType: - case MessageTransportType: + // check size - // check size + if len(packet) < MessageTransportSize { + continue + } - if len(packet) < MessageTransportSize { - continue - } + // lookup key pair - // lookup key pair + receiver := binary.LittleEndian.Uint32( + packet[MessageTransportOffsetReceiver:MessageTransportOffsetCounter], + ) + value := device.indexTable.Lookup(receiver) + keypair := value.keypair + if keypair == nil { + continue + } - receiver := binary.LittleEndian.Uint32( - packet[MessageTransportOffsetReceiver:MessageTransportOffsetCounter], - ) - value := device.indexTable.Lookup(receiver) - keypair := value.keypair - if keypair == nil { - continue - } + // check keypair expiry - // check keypair expiry + if keypair.created.Add(RejectAfterTime).Before(time.Now()) { + continue + } - if keypair.created.Add(RejectAfterTime).Before(time.Now()) { + // create work element + peer := value.peer + elem := device.GetInboundElement() + elem.packet = packet + elem.buffer = buffsArrs[i] + elem.keypair = keypair + elem.endpoint = endpoints[i] + elem.counter = 0 + elem.Mutex = sync.Mutex{} + elem.Lock() + + elemsForPeer, ok := elemsByPeer[peer] + if !ok { + elemsForPeer = device.GetInboundElementsSlice() + elemsByPeer[peer] = elemsForPeer + } + *elemsForPeer = append(*elemsForPeer, elem) + buffsArrs[i] = device.GetMessageBuffer() + buffs[i] = buffsArrs[i][:] continue - } - - // create work element - peer := value.peer - elem := device.GetInboundElement() - elem.packet = packet - elem.buffer = buffer - elem.keypair = keypair - elem.endpoint = endpoint - elem.counter = 0 - elem.Mutex = sync.Mutex{} - elem.Lock() - // add to decryption queues - if peer.isRunning.Load() { - peer.queue.inbound.c <- elem - device.queue.decryption.c <- elem - buffer = device.GetMessageBuffer() - } else { - device.PutInboundElement(elem) - } - continue + // otherwise it is a fixed size & handshake related packet - // otherwise it is a fixed size & handshake related packet - - case MessageInitiationType: - okay = len(packet) == MessageInitiationSize + case MessageInitiationType: + if len(packet) != MessageInitiationSize { + continue + } - case MessageResponseType: - okay = len(packet) == MessageResponseSize + case MessageResponseType: + if len(packet) != MessageResponseSize { + continue + } - case MessageCookieReplyType: - okay = len(packet) == MessageCookieReplySize + case MessageCookieReplyType: + if len(packet) != MessageCookieReplySize { + continue + } - default: - device.log.Verbosef("Received message with unknown type") - } + default: + device.log.Verbosef("Received message with unknown type") + continue + } - if okay { select { case device.queue.handshake.c <- QueueHandshakeElement{ msgType: msgType, - buffer: buffer, + buffer: buffsArrs[i], packet: packet, - endpoint: endpoint, + endpoint: endpoints[i], }: - buffer = device.GetMessageBuffer() + buffsArrs[i] = device.GetMessageBuffer() + buffs[i] = buffsArrs[i][:] default: } } + for peer, elems := range elemsByPeer { + if peer.isRunning.Load() { + peer.queue.inbound.c <- elems + for _, elem := range *elems { + device.queue.decryption.c <- elem + } + } else { + for _, elem := range *elems { + device.PutMessageBuffer(elem.buffer) + device.PutInboundElement(elem) + } + device.PutInboundElementsSlice(elems) + } + delete(elemsByPeer, peer) + } } } @@ -393,7 +427,7 @@ func (device *Device) RoutineHandshake(id int) { } } -func (peer *Peer) RoutineSequentialReceiver() { +func (peer *Peer) RoutineSequentialReceiver(maxBatchSize int) { device := peer.device defer func() { device.log.Verbosef("%v - Routine: sequential receiver - stopped", peer) @@ -401,89 +435,91 @@ func (peer *Peer) RoutineSequentialReceiver() { }() device.log.Verbosef("%v - Routine: sequential receiver - started", peer) - for elem := range peer.queue.inbound.c { - if elem == nil { + buffs := make([][]byte, 0, maxBatchSize) + + for elems := range peer.queue.inbound.c { + if elems == nil { return } - var err error - elem.Lock() - if elem.packet == nil { - // decryption failed - goto skip - } + for _, elem := range *elems { + elem.Lock() + if elem.packet == nil { + // decryption failed + continue + } - if !elem.keypair.replayFilter.ValidateCounter(elem.counter, RejectAfterMessages) { - goto skip - } + if !elem.keypair.replayFilter.ValidateCounter(elem.counter, RejectAfterMessages) { + continue + } - peer.SetEndpointFromPacket(elem.endpoint) - if peer.ReceivedWithKeypair(elem.keypair) { - peer.timersHandshakeComplete() - peer.SendStagedPackets() - } + peer.SetEndpointFromPacket(elem.endpoint) + if peer.ReceivedWithKeypair(elem.keypair) { + peer.timersHandshakeComplete() + peer.SendStagedPackets() + } + peer.keepKeyFreshReceiving() + peer.timersAnyAuthenticatedPacketTraversal() + peer.timersAnyAuthenticatedPacketReceived() + peer.rxBytes.Add(uint64(len(elem.packet) + MinMessageSize)) - peer.keepKeyFreshReceiving() - peer.timersAnyAuthenticatedPacketTraversal() - peer.timersAnyAuthenticatedPacketReceived() - peer.rxBytes.Add(uint64(len(elem.packet) + MinMessageSize)) + if len(elem.packet) == 0 { + device.log.Verbosef("%v - Receiving keepalive packet", peer) + continue + } + peer.timersDataReceived() - if len(elem.packet) == 0 { - device.log.Verbosef("%v - Receiving keepalive packet", peer) - goto skip - } - peer.timersDataReceived() + switch elem.packet[0] >> 4 { + case 4: + if len(elem.packet) < ipv4.HeaderLen { + continue + } + field := elem.packet[IPv4offsetTotalLength : IPv4offsetTotalLength+2] + length := binary.BigEndian.Uint16(field) + if int(length) > len(elem.packet) || int(length) < ipv4.HeaderLen { + continue + } + elem.packet = elem.packet[:length] + src := elem.packet[IPv4offsetSrc : IPv4offsetSrc+net.IPv4len] + if device.allowedips.Lookup(src) != peer { + device.log.Verbosef("IPv4 packet with disallowed source address from %v", peer) + continue + } - switch elem.packet[0] >> 4 { - case ipv4.Version: - if len(elem.packet) < ipv4.HeaderLen { - goto skip - } - field := elem.packet[IPv4offsetTotalLength : IPv4offsetTotalLength+2] - length := binary.BigEndian.Uint16(field) - if int(length) > len(elem.packet) || int(length) < ipv4.HeaderLen { - goto skip - } - elem.packet = elem.packet[:length] - src := elem.packet[IPv4offsetSrc : IPv4offsetSrc+net.IPv4len] - if device.allowedips.Lookup(src) != peer { - device.log.Verbosef("IPv4 packet with disallowed source address from %v", peer) - goto skip - } + case 6: + if len(elem.packet) < ipv6.HeaderLen { + continue + } + field := elem.packet[IPv6offsetPayloadLength : IPv6offsetPayloadLength+2] + length := binary.BigEndian.Uint16(field) + length += ipv6.HeaderLen + if int(length) > len(elem.packet) { + continue + } + elem.packet = elem.packet[:length] + src := elem.packet[IPv6offsetSrc : IPv6offsetSrc+net.IPv6len] + if device.allowedips.Lookup(src) != peer { + device.log.Verbosef("IPv6 packet with disallowed source address from %v", peer) + continue + } - case ipv6.Version: - if len(elem.packet) < ipv6.HeaderLen { - goto skip - } - field := elem.packet[IPv6offsetPayloadLength : IPv6offsetPayloadLength+2] - length := binary.BigEndian.Uint16(field) - length += ipv6.HeaderLen - if int(length) > len(elem.packet) { - goto skip - } - elem.packet = elem.packet[:length] - src := elem.packet[IPv6offsetSrc : IPv6offsetSrc+net.IPv6len] - if device.allowedips.Lookup(src) != peer { - device.log.Verbosef("IPv6 packet with disallowed source address from %v", peer) - goto skip + default: + device.log.Verbosef("Packet with invalid IP version from %v", peer) + continue } - default: - device.log.Verbosef("Packet with invalid IP version from %v", peer) - goto skip + buffs = append(buffs, elem.buffer[:MessageTransportOffsetContent+len(elem.packet)]) } - - _, err = device.tun.device.Write(elem.buffer[:MessageTransportOffsetContent+len(elem.packet)], MessageTransportOffsetContent) - if err != nil && !device.isClosed() { - device.log.Errorf("Failed to write packet to TUN device: %v", err) - } - if len(peer.queue.inbound.c) == 0 { - err = device.tun.device.Flush() - if err != nil { - peer.device.log.Errorf("Unable to flush packets: %v", err) + if len(buffs) > 0 { + _, err := device.tun.device.Write(buffs, MessageTransportOffsetContent) + if err != nil && !device.isClosed() { + device.log.Errorf("Failed to write packets to TUN device: %v", err) } } - skip: - device.PutMessageBuffer(elem.buffer) - device.PutInboundElement(elem) + for _, elem := range *elems { + device.PutMessageBuffer(elem.buffer) + device.PutInboundElement(elem) + } + buffs = buffs[:0] + device.PutInboundElementsSlice(elems) } } diff --git a/device/send.go b/device/send.go index 854d172..b33b9f4 100644 --- a/device/send.go +++ b/device/send.go @@ -17,6 +17,7 @@ import ( "golang.org/x/crypto/chacha20poly1305" "golang.org/x/net/ipv4" "golang.org/x/net/ipv6" + "golang.zx2c4.com/wireguard/tun" ) /* Outbound flow @@ -77,12 +78,15 @@ func (elem *QueueOutboundElement) clearPointers() { func (peer *Peer) SendKeepalive() { if len(peer.queue.staged) == 0 && peer.isRunning.Load() { elem := peer.device.NewOutboundElement() + elems := peer.device.GetOutboundElementsSlice() + *elems = append(*elems, elem) select { - case peer.queue.staged <- elem: + case peer.queue.staged <- elems: peer.device.log.Verbosef("%v - Sending keepalive packet", peer) default: peer.device.PutMessageBuffer(elem.buffer) peer.device.PutOutboundElement(elem) + peer.device.PutOutboundElementsSlice(elems) } } peer.SendStagedPackets() @@ -125,7 +129,7 @@ func (peer *Peer) SendHandshakeInitiation(isRetry bool) error { peer.timersAnyAuthenticatedPacketTraversal() peer.timersAnyAuthenticatedPacketSent() - err = peer.SendBuffer(packet) + err = peer.SendBuffers([][]byte{packet}) if err != nil { peer.device.log.Errorf("%v - Failed to send handshake initiation: %v", peer, err) } @@ -163,7 +167,8 @@ func (peer *Peer) SendHandshakeResponse() error { peer.timersAnyAuthenticatedPacketTraversal() peer.timersAnyAuthenticatedPacketSent() - err = peer.SendBuffer(packet) + // TODO: allocation could be avoided + err = peer.SendBuffers([][]byte{packet}) if err != nil { peer.device.log.Errorf("%v - Failed to send handshake response: %v", peer, err) } @@ -183,7 +188,8 @@ func (device *Device) SendHandshakeCookie(initiatingElem *QueueHandshakeElement) var buff [MessageCookieReplySize]byte writer := bytes.NewBuffer(buff[:0]) binary.Write(writer, binary.LittleEndian, reply) - device.net.bind.Send(writer.Bytes(), initiatingElem.endpoint) + // TODO: allocation could be avoided + device.net.bind.Send([][]byte{writer.Bytes()}, initiatingElem.endpoint) return nil } @@ -198,11 +204,6 @@ func (peer *Peer) keepKeyFreshSending() { } } -/* Reads packets from the TUN and inserts - * into staged queue for peer - * - * Obs. Single instance per TUN device - */ func (device *Device) RoutineReadFromTUN() { defer func() { device.log.Verbosef("Routine: TUN reader - stopped") @@ -212,81 +213,123 @@ func (device *Device) RoutineReadFromTUN() { device.log.Verbosef("Routine: TUN reader - started") - var elem *QueueOutboundElement + var ( + batchSize = device.BatchSize() + readErr error + elems = make([]*QueueOutboundElement, batchSize) + buffs = make([][]byte, batchSize) + elemsByPeer = make(map[*Peer]*[]*QueueOutboundElement, batchSize) + count = 0 + sizes = make([]int, batchSize) + offset = MessageTransportHeaderSize + ) + + for i := range elems { + elems[i] = device.NewOutboundElement() + buffs[i] = elems[i].buffer[:] + } - for { - if elem != nil { - device.PutMessageBuffer(elem.buffer) - device.PutOutboundElement(elem) + defer func() { + for _, elem := range elems { + if elem != nil { + device.PutMessageBuffer(elem.buffer) + device.PutOutboundElement(elem) + } } - elem = device.NewOutboundElement() - - // read packet + }() - offset := MessageTransportHeaderSize - size, err := device.tun.device.Read(elem.buffer[:], offset) - if err != nil { - if !device.isClosed() { - if !errors.Is(err, os.ErrClosed) { - device.log.Errorf("Failed to read packet from TUN device: %v", err) - } - go device.Close() + for { + // read packets + count, readErr = device.tun.device.Read(buffs, sizes, offset) + for i := 0; i < count; i++ { + if sizes[i] < 1 { + continue } - device.PutMessageBuffer(elem.buffer) - device.PutOutboundElement(elem) - return - } - if size == 0 || size > MaxContentSize { - continue - } + elem := elems[i] + elem.packet = buffs[i][offset : offset+sizes[i]] - elem.packet = elem.buffer[offset : offset+size] + // lookup peer + var peer *Peer + switch elem.packet[0] >> 4 { + case 4: + if len(elem.packet) < ipv4.HeaderLen { + continue + } + dst := elem.packet[IPv4offsetDst : IPv4offsetDst+net.IPv4len] + peer = device.allowedips.Lookup(dst) - // lookup peer + case 6: + if len(elem.packet) < ipv6.HeaderLen { + continue + } + dst := elem.packet[IPv6offsetDst : IPv6offsetDst+net.IPv6len] + peer = device.allowedips.Lookup(dst) - var peer *Peer - switch elem.packet[0] >> 4 { - case ipv4.Version: - if len(elem.packet) < ipv4.HeaderLen { - continue + default: + device.log.Verbosef("Received packet with unknown IP version") } - dst := elem.packet[IPv4offsetDst : IPv4offsetDst+net.IPv4len] - peer = device.allowedips.Lookup(dst) - case ipv6.Version: - if len(elem.packet) < ipv6.HeaderLen { + if peer == nil { continue } - dst := elem.packet[IPv6offsetDst : IPv6offsetDst+net.IPv6len] - peer = device.allowedips.Lookup(dst) - - default: - device.log.Verbosef("Received packet with unknown IP version") + elemsForPeer, ok := elemsByPeer[peer] + if !ok { + elemsForPeer = device.GetOutboundElementsSlice() + elemsByPeer[peer] = elemsForPeer + } + *elemsForPeer = append(*elemsForPeer, elem) + elems[i] = device.NewOutboundElement() + buffs[i] = elems[i].buffer[:] } - if peer == nil { - continue + for peer, elemsForPeer := range elemsByPeer { + if peer.isRunning.Load() { + peer.StagePackets(elemsForPeer) + peer.SendStagedPackets() + } else { + for _, elem := range *elemsForPeer { + device.PutMessageBuffer(elem.buffer) + device.PutOutboundElement(elem) + } + device.PutOutboundElementsSlice(elemsForPeer) + } + delete(elemsByPeer, peer) } - if peer.isRunning.Load() { - peer.StagePacket(elem) - elem = nil - peer.SendStagedPackets() + + if readErr != nil { + if errors.Is(readErr, tun.ErrTooManySegments) { + // TODO: record stat for this + // This will happen if MSS is surprisingly small (< 576) + // coincident with reasonably high throughput. + device.log.Verbosef("Dropped some packets from multi-segment read: %v", readErr) + continue + } + if !device.isClosed() { + if !errors.Is(readErr, os.ErrClosed) { + device.log.Errorf("Failed to read packet from TUN device: %v", readErr) + } + go device.Close() + } + return } } } -func (peer *Peer) StagePacket(elem *QueueOutboundElement) { +func (peer *Peer) StagePackets(elems *[]*QueueOutboundElement) { for { select { - case peer.queue.staged <- elem: + case peer.queue.staged <- elems: return default: } select { case tooOld := <-peer.queue.staged: - peer.device.PutMessageBuffer(tooOld.buffer) - peer.device.PutOutboundElement(tooOld) + for _, elem := range *tooOld { + peer.device.PutMessageBuffer(elem.buffer) + peer.device.PutOutboundElement(elem) + } + peer.device.PutOutboundElementsSlice(tooOld) default: } } @@ -305,26 +348,55 @@ top: } for { + var elemsOOO *[]*QueueOutboundElement select { - case elem := <-peer.queue.staged: - elem.peer = peer - elem.nonce = keypair.sendNonce.Add(1) - 1 - if elem.nonce >= RejectAfterMessages { - keypair.sendNonce.Store(RejectAfterMessages) - peer.StagePacket(elem) // XXX: Out of order, but we can't front-load go chans - goto top + case elems := <-peer.queue.staged: + i := 0 + for _, elem := range *elems { + elem.peer = peer + elem.nonce = keypair.sendNonce.Add(1) - 1 + if elem.nonce >= RejectAfterMessages { + keypair.sendNonce.Store(RejectAfterMessages) + if elemsOOO == nil { + elemsOOO = peer.device.GetOutboundElementsSlice() + } + *elemsOOO = append(*elemsOOO, elem) + continue + } else { + (*elems)[i] = elem + i++ + } + + elem.keypair = keypair + elem.Lock() } + *elems = (*elems)[:i] - elem.keypair = keypair - elem.Lock() + if elemsOOO != nil { + peer.StagePackets(elemsOOO) // XXX: Out of order, but we can't front-load go chans + } + + if len(*elems) == 0 { + peer.device.PutOutboundElementsSlice(elems) + goto top + } // add to parallel and sequential queue if peer.isRunning.Load() { - peer.queue.outbound.c <- elem - peer.device.queue.encryption.c <- elem + peer.queue.outbound.c <- elems + for _, elem := range *elems { + peer.device.queue.encryption.c <- elem + } } else { - peer.device.PutMessageBuffer(elem.buffer) - peer.device.PutOutboundElement(elem) + for _, elem := range *elems { + peer.device.PutMessageBuffer(elem.buffer) + peer.device.PutOutboundElement(elem) + } + peer.device.PutOutboundElementsSlice(elems) + } + + if elemsOOO != nil { + goto top } default: return @@ -335,9 +407,12 @@ top: func (peer *Peer) FlushStagedPackets() { for { select { - case elem := <-peer.queue.staged: - peer.device.PutMessageBuffer(elem.buffer) - peer.device.PutOutboundElement(elem) + case elems := <-peer.queue.staged: + for _, elem := range *elems { + peer.device.PutMessageBuffer(elem.buffer) + peer.device.PutOutboundElement(elem) + } + peer.device.PutOutboundElementsSlice(elems) default: return } @@ -400,12 +475,7 @@ func (device *Device) RoutineEncryption(id int) { } } -/* Sequentially reads packets from queue and sends to endpoint - * - * Obs. Single instance per peer. - * The routine terminates then the outbound queue is closed. - */ -func (peer *Peer) RoutineSequentialSender() { +func (peer *Peer) RoutineSequentialSender(maxBatchSize int) { device := peer.device defer func() { defer device.log.Verbosef("%v - Routine: sequential sender - stopped", peer) @@ -413,36 +483,50 @@ func (peer *Peer) RoutineSequentialSender() { }() device.log.Verbosef("%v - Routine: sequential sender - started", peer) - for elem := range peer.queue.outbound.c { - if elem == nil { + buffs := make([][]byte, 0, maxBatchSize) + + for elems := range peer.queue.outbound.c { + buffs = buffs[:0] + if elems == nil { return } - elem.Lock() if !peer.isRunning.Load() { // peer has been stopped; return re-usable elems to the shared pool. // This is an optimization only. It is possible for the peer to be stopped // immediately after this check, in which case, elem will get processed. - // The timers and SendBuffer code are resilient to a few stragglers. + // The timers and SendBuffers code are resilient to a few stragglers. // TODO: rework peer shutdown order to ensure // that we never accidentally keep timers alive longer than necessary. - device.PutMessageBuffer(elem.buffer) - device.PutOutboundElement(elem) + for _, elem := range *elems { + elem.Lock() + device.PutMessageBuffer(elem.buffer) + device.PutOutboundElement(elem) + } continue } + dataSent := false + for _, elem := range *elems { + elem.Lock() + if len(elem.packet) != MessageKeepaliveSize { + dataSent = true + } + buffs = append(buffs, elem.packet) + } peer.timersAnyAuthenticatedPacketTraversal() peer.timersAnyAuthenticatedPacketSent() - // send message and return buffer to pool - - err := peer.SendBuffer(elem.packet) - if len(elem.packet) != MessageKeepaliveSize { + err := peer.SendBuffers(buffs) + if dataSent { peer.timersDataSent() } - device.PutMessageBuffer(elem.buffer) - device.PutOutboundElement(elem) + for _, elem := range *elems { + device.PutMessageBuffer(elem.buffer) + device.PutOutboundElement(elem) + } + device.PutOutboundElementsSlice(elems) if err != nil { - device.log.Errorf("%v - Failed to send data packet: %v", peer, err) + device.log.Errorf("%v - Failed to send data packets: %v", peer, err) continue } |