From 6a84778f2ca810f5fb5cb078e001494f08d9085f Mon Sep 17 00:00:00 2001 From: Jordan Whited Date: Mon, 2 Oct 2023 13:53:07 -0700 Subject: conn, device: use UDP GSO and GRO on Linux StdNetBind probes for UDP GSO and GRO support at runtime. UDP GSO is dependent on checksum offload support on the egress netdev. UDP GSO will be disabled in the event sendmmsg() returns EIO, which is a strong signal that the egress netdev does not support checksum offload. The iperf3 results below demonstrate the effect of this commit between two Linux computers with i5-12400 CPUs. There is roughly ~13us of round trip latency between them. The first result is from commit 052af4a without UDP GSO or GRO. Starting Test: protocol: TCP, 1 streams, 131072 byte blocks [ ID] Interval Transfer Bitrate Retr Cwnd [ 5] 0.00-10.00 sec 9.85 GBytes 8.46 Gbits/sec 1139 3.01 MBytes - - - - - - - - - - - - - - - - - - - - - - - - - Test Complete. Summary Results: [ ID] Interval Transfer Bitrate Retr [ 5] 0.00-10.00 sec 9.85 GBytes 8.46 Gbits/sec 1139 sender [ 5] 0.00-10.04 sec 9.85 GBytes 8.42 Gbits/sec receiver The second result is with UDP GSO and GRO. Starting Test: protocol: TCP, 1 streams, 131072 byte blocks [ ID] Interval Transfer Bitrate Retr Cwnd [ 5] 0.00-10.00 sec 12.3 GBytes 10.6 Gbits/sec 232 3.15 MBytes - - - - - - - - - - - - - - - - - - - - - - - - - Test Complete. Summary Results: [ ID] Interval Transfer Bitrate Retr [ 5] 0.00-10.00 sec 12.3 GBytes 10.6 Gbits/sec 232 sender [ 5] 0.00-10.04 sec 12.3 GBytes 10.6 Gbits/sec receiver Reviewed-by: Adrian Dewhurst Signed-off-by: Jordan Whited Signed-off-by: Jason A. Donenfeld --- conn/bind_std.go | 399 ++++++++++++++++++++++++++++++--------------- conn/bind_std_test.go | 230 +++++++++++++++++++++++++- conn/control_default.go | 51 ++++++ conn/control_linux.go | 159 ++++++++++++++++++ conn/control_linux_test.go | 266 ++++++++++++++++++++++++++++++ conn/controlfns_linux.go | 8 + conn/errors_default.go | 12 ++ conn/errors_linux.go | 26 +++ conn/features_default.go | 15 ++ conn/features_linux.go | 35 ++++ conn/sticky_default.go | 41 ----- conn/sticky_linux.go | 110 ------------- conn/sticky_linux_test.go | 266 ------------------------------ device/send.go | 8 + go.mod | 2 +- go.sum | 4 +- 16 files changed, 1079 insertions(+), 553 deletions(-) create mode 100644 conn/control_default.go create mode 100644 conn/control_linux.go create mode 100644 conn/control_linux_test.go create mode 100644 conn/errors_default.go create mode 100644 conn/errors_linux.go create mode 100644 conn/features_default.go create mode 100644 conn/features_linux.go delete mode 100644 conn/sticky_default.go delete mode 100644 conn/sticky_linux.go delete mode 100644 conn/sticky_linux_test.go diff --git a/conn/bind_std.go b/conn/bind_std.go index c701ef8..9886c91 100644 --- a/conn/bind_std.go +++ b/conn/bind_std.go @@ -8,6 +8,7 @@ package conn import ( "context" "errors" + "fmt" "net" "net/netip" "runtime" @@ -29,16 +30,19 @@ var ( // methods for sending and receiving multiple datagrams per-syscall. See the // proposal in https://github.com/golang/go/issues/45886#issuecomment-1218301564. type StdNetBind struct { - mu sync.Mutex // protects all fields except as specified - ipv4 *net.UDPConn - ipv6 *net.UDPConn - ipv4PC *ipv4.PacketConn // will be nil on non-Linux - ipv6PC *ipv6.PacketConn // will be nil on non-Linux - - // these three fields are not guarded by mu - udpAddrPool sync.Pool - ipv4MsgsPool sync.Pool - ipv6MsgsPool sync.Pool + mu sync.Mutex // protects all fields except as specified + ipv4 *net.UDPConn + ipv6 *net.UDPConn + ipv4PC *ipv4.PacketConn // will be nil on non-Linux + ipv6PC *ipv6.PacketConn // will be nil on non-Linux + ipv4TxOffload bool + ipv4RxOffload bool + ipv6TxOffload bool + ipv6RxOffload bool + + // these two fields are not guarded by mu + udpAddrPool sync.Pool + msgsPool sync.Pool blackhole4 bool blackhole6 bool @@ -54,23 +58,14 @@ func NewStdNetBind() Bind { }, }, - ipv4MsgsPool: sync.Pool{ - New: func() any { - msgs := make([]ipv4.Message, IdealBatchSize) - for i := range msgs { - msgs[i].Buffers = make(net.Buffers, 1) - msgs[i].OOB = make([]byte, srcControlSize) - } - return &msgs - }, - }, - - ipv6MsgsPool: sync.Pool{ + msgsPool: sync.Pool{ New: func() any { + // ipv6.Message and ipv4.Message are interchangeable as they are + // both aliases for x/net/internal/socket.Message. msgs := make([]ipv6.Message, IdealBatchSize) for i := range msgs { msgs[i].Buffers = make(net.Buffers, 1) - msgs[i].OOB = make([]byte, srcControlSize) + msgs[i].OOB = make([]byte, controlSize) } return &msgs }, @@ -113,7 +108,7 @@ func (e *StdNetEndpoint) DstIP() netip.Addr { return e.AddrPort.Addr() } -// See sticky_default,linux, etc for implementations of SrcIP and SrcIfidx. +// See control_default,linux, etc for implementations of SrcIP and SrcIfidx. func (e *StdNetEndpoint) DstToBytes() []byte { b, _ := e.AddrPort.MarshalBinary() @@ -179,19 +174,21 @@ again: } var fns []ReceiveFunc if v4conn != nil { + s.ipv4TxOffload, s.ipv4RxOffload = supportsUDPOffload(v4conn) if runtime.GOOS == "linux" { v4pc = ipv4.NewPacketConn(v4conn) s.ipv4PC = v4pc } - fns = append(fns, s.makeReceiveIPv4(v4pc, v4conn)) + fns = append(fns, s.makeReceiveIPv4(v4pc, v4conn, s.ipv4RxOffload)) s.ipv4 = v4conn } if v6conn != nil { + s.ipv6TxOffload, s.ipv6RxOffload = supportsUDPOffload(v6conn) if runtime.GOOS == "linux" { v6pc = ipv6.NewPacketConn(v6conn) s.ipv6PC = v6pc } - fns = append(fns, s.makeReceiveIPv6(v6pc, v6conn)) + fns = append(fns, s.makeReceiveIPv6(v6pc, v6conn, s.ipv6RxOffload)) s.ipv6 = v6conn } if len(fns) == 0 { @@ -201,69 +198,93 @@ again: return fns, uint16(port), nil } -func (s *StdNetBind) makeReceiveIPv4(pc *ipv4.PacketConn, conn *net.UDPConn) ReceiveFunc { - return func(bufs [][]byte, sizes []int, eps []Endpoint) (n int, err error) { - msgs := s.ipv4MsgsPool.Get().(*[]ipv4.Message) - defer s.ipv4MsgsPool.Put(msgs) - for i := range bufs { - (*msgs)[i].Buffers[0] = bufs[i] - } - var numMsgs int - if runtime.GOOS == "linux" { - numMsgs, err = pc.ReadBatch(*msgs, 0) +func (s *StdNetBind) putMessages(msgs *[]ipv6.Message) { + for i := range *msgs { + (*msgs)[i] = ipv6.Message{Buffers: (*msgs)[i].Buffers, OOB: (*msgs)[i].OOB} + } + s.msgsPool.Put(msgs) +} + +func (s *StdNetBind) getMessages() *[]ipv6.Message { + return s.msgsPool.Get().(*[]ipv6.Message) +} + +var ( + // If compilation fails here these are no longer the same underlying type. + _ ipv6.Message = ipv4.Message{} +) + +type batchReader interface { + ReadBatch([]ipv6.Message, int) (int, error) +} + +type batchWriter interface { + WriteBatch([]ipv6.Message, int) (int, error) +} + +func (s *StdNetBind) receiveIP( + br batchReader, + conn *net.UDPConn, + rxOffload bool, + bufs [][]byte, + sizes []int, + eps []Endpoint, +) (n int, err error) { + msgs := s.getMessages() + for i := range bufs { + (*msgs)[i].Buffers[0] = bufs[i] + (*msgs)[i].OOB = (*msgs)[i].OOB[:cap((*msgs)[i].OOB)] + } + defer s.putMessages(msgs) + var numMsgs int + if runtime.GOOS == "linux" { + if rxOffload { + readAt := len(*msgs) - (IdealBatchSize / udpSegmentMaxDatagrams) + numMsgs, err = br.ReadBatch((*msgs)[readAt:], 0) + if err != nil { + return 0, err + } + numMsgs, err = splitCoalescedMessages(*msgs, readAt, getGSOSize) if err != nil { return 0, err } } else { - msg := &(*msgs)[0] - msg.N, msg.NN, _, msg.Addr, err = conn.ReadMsgUDP(msg.Buffers[0], msg.OOB) + numMsgs, err = br.ReadBatch(*msgs, 0) if err != nil { return 0, err } - numMsgs = 1 } - for i := 0; i < numMsgs; i++ { - msg := &(*msgs)[i] - sizes[i] = msg.N - addrPort := msg.Addr.(*net.UDPAddr).AddrPort() - ep := &StdNetEndpoint{AddrPort: addrPort} // TODO: remove allocation - getSrcFromControl(msg.OOB[:msg.NN], ep) - eps[i] = ep + } else { + msg := &(*msgs)[0] + msg.N, msg.NN, _, msg.Addr, err = conn.ReadMsgUDP(msg.Buffers[0], msg.OOB) + if err != nil { + return 0, err + } + numMsgs = 1 + } + for i := 0; i < numMsgs; i++ { + msg := &(*msgs)[i] + sizes[i] = msg.N + if sizes[i] == 0 { + continue } - return numMsgs, nil + addrPort := msg.Addr.(*net.UDPAddr).AddrPort() + ep := &StdNetEndpoint{AddrPort: addrPort} // TODO: remove allocation + getSrcFromControl(msg.OOB[:msg.NN], ep) + eps[i] = ep } + return numMsgs, nil } -func (s *StdNetBind) makeReceiveIPv6(pc *ipv6.PacketConn, conn *net.UDPConn) ReceiveFunc { +func (s *StdNetBind) makeReceiveIPv4(pc *ipv4.PacketConn, conn *net.UDPConn, rxOffload bool) ReceiveFunc { return func(bufs [][]byte, sizes []int, eps []Endpoint) (n int, err error) { - msgs := s.ipv6MsgsPool.Get().(*[]ipv6.Message) - defer s.ipv6MsgsPool.Put(msgs) - for i := range bufs { - (*msgs)[i].Buffers[0] = bufs[i] - } - var numMsgs int - if runtime.GOOS == "linux" { - numMsgs, err = pc.ReadBatch(*msgs, 0) - if err != nil { - return 0, err - } - } else { - msg := &(*msgs)[0] - msg.N, msg.NN, _, msg.Addr, err = conn.ReadMsgUDP(msg.Buffers[0], msg.OOB) - if err != nil { - return 0, err - } - numMsgs = 1 - } - for i := 0; i < numMsgs; i++ { - msg := &(*msgs)[i] - sizes[i] = msg.N - addrPort := msg.Addr.(*net.UDPAddr).AddrPort() - ep := &StdNetEndpoint{AddrPort: addrPort} // TODO: remove allocation - getSrcFromControl(msg.OOB[:msg.NN], ep) - eps[i] = ep - } - return numMsgs, nil + return s.receiveIP(pc, conn, rxOffload, bufs, sizes, eps) + } +} + +func (s *StdNetBind) makeReceiveIPv6(pc *ipv6.PacketConn, conn *net.UDPConn, rxOffload bool) ReceiveFunc { + return func(bufs [][]byte, sizes []int, eps []Endpoint) (n int, err error) { + return s.receiveIP(pc, conn, rxOffload, bufs, sizes, eps) } } @@ -293,28 +314,42 @@ func (s *StdNetBind) Close() error { } s.blackhole4 = false s.blackhole6 = false + s.ipv4TxOffload = false + s.ipv4RxOffload = false + s.ipv6TxOffload = false + s.ipv6RxOffload = false if err1 != nil { return err1 } return err2 } +type ErrUDPGSODisabled struct { + onLaddr string + RetryErr error +} + +func (e ErrUDPGSODisabled) Error() string { + return fmt.Sprintf("disabled UDP GSO on %s, NIC(s) may not support checksum offload", e.onLaddr) +} + +func (e ErrUDPGSODisabled) Unwrap() error { + return e.RetryErr +} + func (s *StdNetBind) Send(bufs [][]byte, endpoint Endpoint) error { s.mu.Lock() blackhole := s.blackhole4 conn := s.ipv4 - var ( - pc4 *ipv4.PacketConn - pc6 *ipv6.PacketConn - ) + offload := s.ipv4TxOffload + br := batchWriter(s.ipv4PC) is6 := false if endpoint.DstIP().Is6() { blackhole = s.blackhole6 conn = s.ipv6 - pc6 = s.ipv6PC + br = s.ipv6PC is6 = true - } else { - pc4 = s.ipv4PC + offload = s.ipv6TxOffload } s.mu.Unlock() @@ -324,25 +359,56 @@ func (s *StdNetBind) Send(bufs [][]byte, endpoint Endpoint) error { if conn == nil { return syscall.EAFNOSUPPORT } + + msgs := s.getMessages() + defer s.putMessages(msgs) + ua := s.udpAddrPool.Get().(*net.UDPAddr) + defer s.udpAddrPool.Put(ua) if is6 { - return s.send6(conn, pc6, endpoint, bufs) + as16 := endpoint.DstIP().As16() + copy(ua.IP, as16[:]) + ua.IP = ua.IP[:16] } else { - return s.send4(conn, pc4, endpoint, bufs) + as4 := endpoint.DstIP().As4() + copy(ua.IP, as4[:]) + ua.IP = ua.IP[:4] } + ua.Port = int(endpoint.(*StdNetEndpoint).Port()) + var ( + retried bool + err error + ) +retry: + if offload { + n := coalesceMessages(ua, endpoint.(*StdNetEndpoint), bufs, *msgs, setGSOSize) + err = s.send(conn, br, (*msgs)[:n]) + if err != nil && offload && errShouldDisableUDPGSO(err) { + offload = false + s.mu.Lock() + if is6 { + s.ipv6TxOffload = false + } else { + s.ipv4TxOffload = false + } + s.mu.Unlock() + retried = true + goto retry + } + } else { + for i := range bufs { + (*msgs)[i].Addr = ua + (*msgs)[i].Buffers[0] = bufs[i] + setSrcControl(&(*msgs)[i].OOB, endpoint.(*StdNetEndpoint)) + } + err = s.send(conn, br, (*msgs)[:len(bufs)]) + } + if retried { + return ErrUDPGSODisabled{onLaddr: conn.LocalAddr().String(), RetryErr: err} + } + return err } -func (s *StdNetBind) send4(conn *net.UDPConn, pc *ipv4.PacketConn, ep Endpoint, bufs [][]byte) error { - ua := s.udpAddrPool.Get().(*net.UDPAddr) - as4 := ep.DstIP().As4() - copy(ua.IP, as4[:]) - ua.IP = ua.IP[:4] - ua.Port = int(ep.(*StdNetEndpoint).Port()) - msgs := s.ipv4MsgsPool.Get().(*[]ipv4.Message) - for i, buf := range bufs { - (*msgs)[i].Buffers[0] = buf - (*msgs)[i].Addr = ua - setSrcControl(&(*msgs)[i].OOB, ep.(*StdNetEndpoint)) - } +func (s *StdNetBind) send(conn *net.UDPConn, pc batchWriter, msgs []ipv6.Message) error { var ( n int err error @@ -350,59 +416,128 @@ func (s *StdNetBind) send4(conn *net.UDPConn, pc *ipv4.PacketConn, ep Endpoint, ) if runtime.GOOS == "linux" { for { - n, err = pc.WriteBatch((*msgs)[start:len(bufs)], 0) - if err != nil || n == len((*msgs)[start:len(bufs)]) { + n, err = pc.WriteBatch(msgs[start:], 0) + if err != nil || n == len(msgs[start:]) { break } start += n } } else { - for i, buf := range bufs { - _, _, err = conn.WriteMsgUDP(buf, (*msgs)[i].OOB, ua) + for _, msg := range msgs { + _, _, err = conn.WriteMsgUDP(msg.Buffers[0], msg.OOB, msg.Addr.(*net.UDPAddr)) if err != nil { break } } } - s.udpAddrPool.Put(ua) - s.ipv4MsgsPool.Put(msgs) return err } -func (s *StdNetBind) send6(conn *net.UDPConn, pc *ipv6.PacketConn, ep Endpoint, bufs [][]byte) error { - ua := s.udpAddrPool.Get().(*net.UDPAddr) - as16 := ep.DstIP().As16() - copy(ua.IP, as16[:]) - ua.IP = ua.IP[:16] - ua.Port = int(ep.(*StdNetEndpoint).Port()) - msgs := s.ipv6MsgsPool.Get().(*[]ipv6.Message) - for i, buf := range bufs { - (*msgs)[i].Buffers[0] = buf - (*msgs)[i].Addr = ua - setSrcControl(&(*msgs)[i].OOB, ep.(*StdNetEndpoint)) - } +const ( + // Exceeding these values results in EMSGSIZE. They account for layer3 and + // layer4 headers. IPv6 does not need to account for itself as the payload + // length field is self excluding. + maxIPv4PayloadLen = 1<<16 - 1 - 20 - 8 + maxIPv6PayloadLen = 1<<16 - 1 - 8 + + // This is a hard limit imposed by the kernel. + udpSegmentMaxDatagrams = 64 +) + +type setGSOFunc func(control *[]byte, gsoSize uint16) + +func coalesceMessages(addr *net.UDPAddr, ep *StdNetEndpoint, bufs [][]byte, msgs []ipv6.Message, setGSO setGSOFunc) int { var ( - n int - err error - start int + base = -1 // index of msg we are currently coalescing into + gsoSize int // segmentation size of msgs[base] + dgramCnt int // number of dgrams coalesced into msgs[base] + endBatch bool // tracking flag to start a new batch on next iteration of bufs ) - if runtime.GOOS == "linux" { - for { - n, err = pc.WriteBatch((*msgs)[start:len(bufs)], 0) - if err != nil || n == len((*msgs)[start:len(bufs)]) { - break + maxPayloadLen := maxIPv4PayloadLen + if ep.DstIP().Is6() { + maxPayloadLen = maxIPv6PayloadLen + } + for i, buf := range bufs { + if i > 0 { + msgLen := len(buf) + baseLenBefore := len(msgs[base].Buffers[0]) + freeBaseCap := cap(msgs[base].Buffers[0]) - baseLenBefore + if msgLen+baseLenBefore <= maxPayloadLen && + msgLen <= gsoSize && + msgLen <= freeBaseCap && + dgramCnt < udpSegmentMaxDatagrams && + !endBatch { + msgs[base].Buffers[0] = append(msgs[base].Buffers[0], buf...) + if i == len(bufs)-1 { + setGSO(&msgs[base].OOB, uint16(gsoSize)) + } + dgramCnt++ + if msgLen < gsoSize { + // A smaller than gsoSize packet on the tail is legal, but + // it must end the batch. + endBatch = true + } + continue } - start += n } - } else { - for i, buf := range bufs { - _, _, err = conn.WriteMsgUDP(buf, (*msgs)[i].OOB, ua) - if err != nil { - break + if dgramCnt > 1 { + setGSO(&msgs[base].OOB, uint16(gsoSize)) + } + // Reset prior to incrementing base since we are preparing to start a + // new potential batch. + endBatch = false + base++ + gsoSize = len(buf) + setSrcControl(&msgs[base].OOB, ep) + msgs[base].Buffers[0] = buf + msgs[base].Addr = addr + dgramCnt = 1 + } + return base + 1 +} + +type getGSOFunc func(control []byte) (int, error) + +func splitCoalescedMessages(msgs []ipv6.Message, firstMsgAt int, getGSO getGSOFunc) (n int, err error) { + for i := firstMsgAt; i < len(msgs); i++ { + msg := &msgs[i] + if msg.N == 0 { + return n, err + } + var ( + gsoSize int + start int + end = msg.N + numToSplit = 1 + ) + gsoSize, err = getGSO(msg.OOB[:msg.NN]) + if err != nil { + return n, err + } + if gsoSize > 0 { + numToSplit = (msg.N + gsoSize - 1) / gsoSize + end = gsoSize + } + for j := 0; j < numToSplit; j++ { + if n > i { + return n, errors.New("splitting coalesced packet resulted in overflow") } + copied := copy(msgs[n].Buffers[0], msg.Buffers[0][start:end]) + msgs[n].N = copied + msgs[n].Addr = msg.Addr + start = end + end += gsoSize + if end > msg.N { + end = msg.N + } + n++ + } + if i != n-1 { + // It is legal for bytes to move within msg.Buffers[0] as a result + // of splitting, so we only zero the source msg len when it is not + // the destination of the last split operation above. + msg.N = 0 } } - s.udpAddrPool.Put(ua) - s.ipv6MsgsPool.Put(msgs) - return err + return n, nil } diff --git a/conn/bind_std_test.go b/conn/bind_std_test.go index 1e46776..34a3c9a 100644 --- a/conn/bind_std_test.go +++ b/conn/bind_std_test.go @@ -1,6 +1,12 @@ package conn -import "testing" +import ( + "encoding/binary" + "net" + "testing" + + "golang.org/x/net/ipv6" +) func TestStdNetBindReceiveFuncAfterClose(t *testing.T) { bind := NewStdNetBind().(*StdNetBind) @@ -20,3 +26,225 @@ func TestStdNetBindReceiveFuncAfterClose(t *testing.T) { fn(bufs, sizes, eps) } } + +func mockSetGSOSize(control *[]byte, gsoSize uint16) { + *control = (*control)[:cap(*control)] + binary.LittleEndian.PutUint16(*control, gsoSize) +} + +func Test_coalesceMessages(t *testing.T) { + cases := []struct { + name string + buffs [][]byte + wantLens []int + wantGSO []int + }{ + { + name: "one message no coalesce", + buffs: [][]byte{ + make([]byte, 1, 1), + }, + wantLens: []int{1}, + wantGSO: []int{0}, + }, + { + name: "two messages equal len coalesce", + buffs: [][]byte{ + make([]byte, 1, 2), + make([]byte, 1, 1), + }, + wantLens: []int{2}, + wantGSO: []int{1}, + }, + { + name: "two messages unequal len coalesce", + buffs: [][]byte{ + make([]byte, 2, 3), + make([]byte, 1, 1), + }, + wantLens: []int{3}, + wantGSO: []int{2}, + }, + { + name: "three messages second unequal len coalesce", + buffs: [][]byte{ + make([]byte, 2, 3), + make([]byte, 1, 1), + make([]byte, 2, 2), + }, + wantLens: []int{3, 2}, + wantGSO: []int{2, 0}, + }, + { + name: "three messages limited cap coalesce", + buffs: [][]byte{ + make([]byte, 2, 4), + make([]byte, 2, 2), + make([]byte, 2, 2), + }, + wantLens: []int{4, 2}, + wantGSO: []int{2, 0}, + }, + } + + for _, tt := range cases { + t.Run(tt.name, func(t *testing.T) { + addr := &net.UDPAddr{ + IP: net.ParseIP("127.0.0.1").To4(), + Port: 1, + } + msgs := make([]ipv6.Message, len(tt.buffs)) + for i := range msgs { + msgs[i].Buffers = make([][]byte, 1) + msgs[i].OOB = make([]byte, 0, 2) + } + got := coalesceMessages(addr, &StdNetEndpoint{AddrPort: addr.AddrPort()}, tt.buffs, msgs, mockSetGSOSize) + if got != len(tt.wantLens) { + t.Fatalf("got len %d want: %d", got, len(tt.wantLens)) + } + for i := 0; i < got; i++ { + if msgs[i].Addr != addr { + t.Errorf("msgs[%d].Addr != passed addr", i) + } + gotLen := len(msgs[i].Buffers[0]) + if gotLen != tt.wantLens[i] { + t.Errorf("len(msgs[%d].Buffers[0]) %d != %d", i, gotLen, tt.wantLens[i]) + } + gotGSO, err := mockGetGSOSize(msgs[i].OOB) + if err != nil { + t.Fatalf("msgs[%d] getGSOSize err: %v", i, err) + } + if gotGSO != tt.wantGSO[i] { + t.Errorf("msgs[%d] gsoSize %d != %d", i, gotGSO, tt.wantGSO[i]) + } + } + }) + } +} + +func mockGetGSOSize(control []byte) (int, error) { + if len(control) < 2 { + return 0, nil + } + return int(binary.LittleEndian.Uint16(control)), nil +} + +func Test_splitCoalescedMessages(t *testing.T) { + newMsg := func(n, gso int) ipv6.Message { + msg := ipv6.Message{ + Buffers: [][]byte{make([]byte, 1<<16-1)}, + N: n, + OOB: make([]byte, 2), + } + binary.LittleEndian.PutUint16(msg.OOB, uint16(gso)) + if gso > 0 { + msg.NN = 2 + } + return msg + } + + cases := []struct { + name string + msgs []ipv6.Message + firstMsgAt int + wantNumEval int + wantMsgLens []int + wantErr bool + }{ + { + name: "second last split last empty", + msgs: []ipv6.Message{ + newMsg(0, 0), + newMsg(0, 0), + newMsg(3, 1), + newMsg(0, 0), + }, + firstMsgAt: 2, + wantNumEval: 3, + wantMsgLens: []int{1, 1, 1, 0}, + wantErr: false, + }, + { + name: "second last no split last empty", + msgs: []ipv6.Message{ + newMsg(0, 0), + newMsg(0, 0), + newMsg(1, 0), + newMsg(0, 0), + }, + firstMsgAt: 2, + wantNumEval: 1, + wantMsgLens: []int{1, 0, 0, 0}, + wantErr: false, + }, + { + name: "second last no split last no split", + msgs: []ipv6.Message{ + newMsg(0, 0), + newMsg(0, 0), + newMsg(1, 0), + newMsg(1, 0), + }, + firstMsgAt: 2, + wantNumEval: 2, + wantMsgLens: []int{1, 1, 0, 0}, + wantErr: false, + }, + { + name: "second last no split last split", + msgs: []ipv6.Message{ + newMsg(0, 0), + newMsg(0, 0), + newMsg(1, 0), + newMsg(3, 1), + }, + firstMsgAt: 2, + wantNumEval: 4, + wantMsgLens: []int{1, 1, 1, 1}, + wantErr: false, + }, + { + name: "second last split last split", + msgs: []ipv6.Message{ + newMsg(0, 0), + newMsg(0, 0), + newMsg(2, 1), + newMsg(2, 1), + }, + firstMsgAt: 2, + wantNumEval: 4, + wantMsgLens: []int{1, 1, 1, 1}, + wantErr: false, + }, + { + name: "second last no split last split overflow", + msgs: []ipv6.Message{ + newMsg(0, 0), + newMsg(0, 0), + newMsg(1, 0), + newMsg(4, 1), + }, + firstMsgAt: 2, + wantNumEval: 4, + wantMsgLens: []int{1, 1, 1, 1}, + wantErr: true, + }, + } + + for _, tt := range cases { + t.Run(tt.name, func(t *testing.T) { + got, err := splitCoalescedMessages(tt.msgs, 2, mockGetGSOSize) + if err != nil && !tt.wantErr { + t.Fatalf("err: %v", err) + } + if got != tt.wantNumEval { + t.Fatalf("got to eval: %d want: %d", got, tt.wantNumEval) + } + for i, msg := range tt.msgs { + if msg.N != tt.wantMsgLens[i] { + t.Fatalf("msg[%d].N: %d want: %d", i, msg.N, tt.wantMsgLens[i]) + } + } + }) + } +} diff --git a/conn/control_default.go b/conn/control_default.go new file mode 100644 index 0000000..9459da5 --- /dev/null +++ b/conn/control_default.go @@ -0,0 +1,51 @@ +//go:build !linux || android + +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. + */ + +package conn + +import "net/netip" + +func (e *StdNetEndpoint) SrcIP() netip.Addr { + return netip.Addr{} +} + +func (e *StdNetEndpoint) SrcIfidx() int32 { + return 0 +} + +func (e *StdNetEndpoint) SrcToString() string { + return "" +} + +// TODO: macOS, FreeBSD and other BSDs likely do support the sticky sockets +// {get,set}srcControl feature set, but use alternatively named flags and need +// ports and require testing. + +// getSrcFromControl parses the control for PKTINFO and if found updates ep with +// the source information found. +func getSrcFromControl(control []byte, ep *StdNetEndpoint) { +} + +// setSrcControl parses the control for PKTINFO and if found updates ep with +// the source information found. +func setSrcControl(control *[]byte, ep *StdNetEndpoint) { +} + +// getGSOSize parses control for UDP_GRO and if found returns its GSO size data. +func getGSOSize(control []byte) (int, error) { + return 0, nil +} + +// setGSOSize sets a UDP_SEGMENT in control based on gsoSize. +func setGSOSize(control *[]byte, gsoSize uint16) { +} + +// controlSize returns the recommended buffer size for pooling sticky and UDP +// offloading control data. +const controlSize = 0 + +const StdNetSupportsStickySockets = false diff --git a/conn/control_linux.go b/conn/control_linux.go new file mode 100644 index 0000000..44a94e6 --- /dev/null +++ b/conn/control_linux.go @@ -0,0 +1,159 @@ +//go:build linux && !android + +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. + */ + +package conn + +import ( + "fmt" + "net/netip" + "unsafe" + + "golang.org/x/sys/unix" +) + +func (e *StdNetEndpoint) SrcIP() netip.Addr { + switch len(e.src) { + case unix.CmsgSpace(unix.SizeofInet4Pktinfo): + info := (*unix.Inet4Pktinfo)(unsafe.Pointer(&e.src[unix.CmsgLen(0)])) + return netip.AddrFrom4(info.Spec_dst) + case unix.CmsgSpace(unix.SizeofInet6Pktinfo): + info := (*unix.Inet6Pktinfo)(unsafe.Pointer(&e.src[unix.CmsgLen(0)])) + // TODO: set zone. in order to do so we need to check if the address is + // link local, and if it is perform a syscall to turn the ifindex into a + // zone string because netip uses string zones. + return netip.AddrFrom16(info.Addr) + } + return netip.Addr{} +} + +func (e *StdNetEndpoint) SrcIfidx() int32 { + switch len(e.src) { + case unix.CmsgSpace(unix.SizeofInet4Pktinfo): + info := (*unix.Inet4Pktinfo)(unsafe.Pointer(&e.src[unix.CmsgLen(0)])) + return info.Ifindex + case unix.CmsgSpace(unix.SizeofInet6Pktinfo): + info := (*unix.Inet6Pktinfo)(unsafe.Pointer(&e.src[unix.CmsgLen(0)])) + return int32(info.Ifindex) + } + return 0 +} + +func (e *StdNetEndpoint) SrcToString() string { + return e.SrcIP().String() +} + +// getSrcFromControl parses the control for PKTINFO and if found updates ep with +// the source information found. +func getSrcFromControl(control []byte, ep *StdNetEndpoint) { + ep.ClearSrc() + + var ( + hdr unix.Cmsghdr + data []byte + rem []byte = control + err error + ) + + for len(rem) > unix.SizeofCmsghdr { + hdr, data, rem, err = unix.ParseOneSocketControlMessage(rem) + if err != nil { + return + } + + if hdr.Level == unix.IPPROTO_IP && + hdr.Type == unix.IP_PKTINFO { + + if ep.src == nil || cap(ep.src) < unix.CmsgSpace(unix.SizeofInet4Pktinfo) { + ep.src = make([]byte, 0, unix.CmsgSpace(unix.SizeofInet4Pktinfo)) + } + ep.src = ep.src[:unix.CmsgSpace(unix.SizeofInet4Pktinfo)] + + hdrBuf := unsafe.Slice((*byte)(unsafe.Pointer(&hdr)), unix.SizeofCmsghdr) + copy(ep.src, hdrBuf) + copy(ep.src[unix.CmsgLen(0):], data) + return + } + + if hdr.Level == unix.IPPROTO_IPV6 && + hdr.Type == unix.IPV6_PKTINFO { + + if ep.src == nil || cap(ep.src) < unix.CmsgSpace(unix.SizeofInet6Pktinfo) { + ep.src = make([]byte, 0, unix.CmsgSpace(unix.SizeofInet6Pktinfo)) + } + + ep.src = ep.src[:unix.CmsgSpace(unix.SizeofInet6Pktinfo)] + + hdrBuf := unsafe.Slice((*byte)(unsafe.Pointer(&hdr)), unix.SizeofCmsghdr) + copy(ep.src, hdrBuf) + copy(ep.src[unix.CmsgLen(0):], data) + return + } + } +} + +// setSrcControl sets an IP{V6}_PKTINFO in control based on the source address +// and source ifindex found in ep. control's len will be set to 0 in the event +// that ep is a default value. +func setSrcControl(control *[]byte, ep *StdNetEndpoint) { + if cap(*control) < len(ep.src) { + return + } + *control = (*control)[:0] + *control = append(*control, ep.src...) +} + +const ( + sizeOfGSOData = 2 +) + +// getGSOSize parses control for UDP_GRO and if found returns its GSO size data. +func getGSOSize(control []byte) (int, error) { + var ( + hdr unix.Cmsghdr + data []byte + rem = control + err error + ) + + for len(rem) > unix.SizeofCmsghdr { + hdr, data, rem, err = unix.ParseOneSocketControlMessage(rem) + if err != nil { + return 0, fmt.Errorf("error parsing socket control message: %w", err) + } + if hdr.Level == unix.SOL_UDP && hdr.Type == unix.UDP_GRO && len(data) >= sizeOfGSOData { + var gso uint16 + copy(unsafe.Slice((*byte)(unsafe.Pointer(&gso)), sizeOfGSOData), data[:sizeOfGSOData]) + return int(gso), nil + } + } + return 0, nil +} + +// setGSOSize sets a UDP_SEGMENT in control based on gsoSize. It leaves existing +// data in control untouched. +func setGSOSize(control *[]byte, gsoSize uint16) { + existingLen := len(*control) + avail := cap(*control) - existingLen + space := unix.CmsgSpace(sizeOfGSOData) + if avail < space { + return + } + *control = (*control)[:cap(*control)] + gsoControl := (*control)[existingLen:] + hdr := (*unix.Cmsghdr)(unsafe.Pointer(&(gsoControl)[0])) + hdr.Level = unix.SOL_UDP + hdr.Type = unix.UDP_SEGMENT + hdr.SetLen(unix.CmsgLen(sizeOfGSOData)) + copy((gsoControl)[unix.SizeofCmsghdr:], unsafe.Slice((*byte)(unsafe.Pointer(&gsoSize)), sizeOfGSOData)) + *control = (*control)[:existingLen+space] +} + +// controlSize returns the recommended buffer size for pooling sticky and UDP +// offloading control data. +var controlSize = unix.CmsgSpace(unix.SizeofInet6Pktinfo) + unix.CmsgSpace(sizeOfGSOData) + +const StdNetSupportsStickySockets = true diff --git a/conn/control_linux_test.go b/conn/control_linux_test.go new file mode 100644 index 0000000..96f9da2 --- /dev/null +++ b/conn/control_linux_test.go @@ -0,0 +1,266 @@ +//go:build linux && !android + +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. + */ + +package conn + +import ( + "context" + "net" + "net/netip" + "runtime" + "testing" + "unsafe" + + "golang.org/x/sys/unix" +) + +func setSrc(ep *StdNetEndpoint, addr netip.Addr, ifidx int32) { + var buf []byte + if addr.Is4() { + buf = make([]byte, unix.CmsgSpace(unix.SizeofInet4Pktinfo)) + hdr := unix.Cmsghdr{ + Level: unix.IPPROTO_IP, + Type: unix.IP_PKTINFO, + } + hdr.SetLen(unix.CmsgLen(unix.SizeofInet4Pktinfo)) + copy(buf, unsafe.Slice((*byte)(unsafe.Pointer(&hdr)), int(unsafe.Sizeof(hdr)))) + + info := unix.Inet4Pktinfo{ + Ifindex: ifidx, + Spec_dst: addr.As4(), + } + copy(buf[unix.CmsgLen(0):], unsafe.Slice((*byte)(unsafe.Pointer(&info)), unix.SizeofInet4Pktinfo)) + } else { + buf = make([]byte, unix.CmsgSpace(unix.SizeofInet6Pktinfo)) + hdr := unix.Cmsghdr{ + Level: unix.IPPROTO_IPV6, + Type: unix.IPV6_PKTINFO, + } + hdr.SetLen(unix.CmsgLen(unix.SizeofInet6Pktinfo)) + copy(buf, unsafe.Slice((*byte)(unsafe.Pointer(&hdr)), int(unsafe.Sizeof(hdr)))) + + info := unix.Inet6Pktinfo{ + Ifindex: uint32(ifidx), + Addr: addr.As16(), + } + copy(buf[unix.CmsgLen(0):], unsafe.Slice((*byte)(unsafe.Pointer(&info)), unix.SizeofInet6Pktinfo)) + } + + ep.src = buf +} + +func Test_setSrcControl(t *testing.T) { + t.Run("IPv4", func(t *testing.T) { + ep := &StdNetEndpoint{ + AddrPort: netip.MustParseAddrPort("127.0.0.1:1234"), + } + setSrc(ep, netip.MustParseAddr("127.0.0.1"), 5) + + control := make([]byte, controlSize) + + setSrcControl(&control, ep) + + hdr := (*unix.Cmsghdr)(unsafe.Pointer(&control[0])) + if hdr.Level != unix.IPPROTO_IP { + t.Errorf("unexpected level: %d", hdr.Level) + } + if hdr.Type != unix.IP_PKTINFO { + t.Errorf("unexpected type: %d", hdr.Type) + } + if uint(hdr.Len) != uint(unix.CmsgLen(int(unsafe.Sizeof(unix.Inet4Pktinfo{})))) { + t.Errorf("unexpected length: %d", hdr.Len) + } + info := (*unix.Inet4Pktinfo)(unsafe.Pointer(&control[unix.CmsgLen(0)])) + if info.Spec_dst[0] != 127 || info.Spec_dst[1] != 0 || info.Spec_dst[2] != 0 || info.Spec_dst[3] != 1 { + t.Errorf("unexpected address: %v", info.Spec_dst) + } + if info.Ifindex != 5 { + t.Errorf("unexpected ifindex: %d", info.Ifindex) + } + }) + + t.Run("IPv6", func(t *testing.T) { + ep := &StdNetEndpoint{ + AddrPort: netip.MustParseAddrPort("[::1]:1234"), + } + setSrc(ep, netip.MustParseAddr("::1"), 5) + + control := make([]byte, controlSize) + + setSrcControl(&control, ep) + + hdr := (*unix.Cmsghdr)(unsafe.Pointer(&control[0])) + if hdr.Level != unix.IPPROTO_IPV6 { + t.Errorf("unexpected level: %d", hdr.Level) + } + if hdr.Type != unix.IPV6_PKTINFO { + t.Errorf("unexpected type: %d", hdr.Type) + } + if uint(hdr.Len) != uint(unix.CmsgLen(int(unsafe.Sizeof(unix.Inet6Pktinfo{})))) { + t.Errorf("unexpected length: %d", hdr.Len) + } + info := (*unix.Inet6Pktinfo)(unsafe.Pointer(&control[unix.CmsgLen(0)])) + if info.Addr != ep.SrcIP().As16() { + t.Errorf("unexpected address: %v", info.Addr) + } + if info.Ifindex != 5 { + t.Errorf("unexpected ifindex: %d", info.Ifindex) + } + }) + + t.Run("ClearOnNoSrc", func(t *testing.T) { + control := make([]byte, controlSize) + hdr := (*unix.Cmsghdr)(unsafe.Pointer(&control[0])) + hdr.Level = 1 + hdr.Type = 2 + hdr.Len = 3 + + setSrcControl(&control, &StdNetEndpoint{}) + + if len(control) != 0 { + t.Errorf("unexpected control: %v", control) + } + }) +} + +func Test_getSrcFromControl(t *testing.T) { + t.Run("IPv4", func(t *testing.T) { + control := make([]byte, controlSize) + hdr := (*unix.Cmsghdr)(unsafe.Pointer(&control[0])) + hdr.Level = unix.IPPROTO_IP + hdr.Type = unix.IP_PKTINFO + hdr.SetLen(unix.CmsgLen(int(unsafe.Sizeof(unix.Inet4Pktinfo{})))) + info := (*unix.Inet4Pktinfo)(unsafe.Pointer(&control[unix.CmsgLen(0)])) + info.Spec_dst = [4]byte{127, 0, 0, 1} + info.Ifindex = 5 + + ep := &StdNetEndpoint{} + getSrcFromControl(control, ep) + + if ep.SrcIP() != netip.MustParseAddr("127.0.0.1") { + t.Errorf("unexpected address: %v", ep.SrcIP()) + } + if ep.SrcIfidx() != 5 { + t.Errorf("unexpected ifindex: %d", ep.SrcIfidx()) + } + }) + t.Run("IPv6", func(t *testing.T) { + control := make([]byte, controlSize) + hdr := (*unix.Cmsghdr)(unsafe.Pointer(&control[0])) + hdr.Level = unix.IPPROTO_IPV6 + hdr.Type = unix.IPV6_PKTINFO + hdr.SetLen(unix.CmsgLen(int(unsafe.Sizeof(unix.Inet6Pktinfo{})))) + info := (*unix.Inet6Pktinfo)(unsafe.Pointer(&control[unix.CmsgLen(0)])) + info.Addr = [16]byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1} + info.Ifindex = 5 + + ep := &StdNetEndpoint{} + getSrcFromControl(control, ep) + + if ep.SrcIP() != netip.MustParseAddr("::1") { + t.Errorf("unexpected address: %v", ep.SrcIP()) + } + if ep.SrcIfidx() != 5 { + t.Errorf("unexpected ifindex: %d", ep.SrcIfidx()) + } + }) + t.Run("ClearOnEmpty", func(t *testing.T) { + var control []byte + ep := &StdNetEndpoint{} + setSrc(ep, netip.MustParseAddr("::1"), 5) + + getSrcFromControl(control, ep) + if ep.SrcIP().IsValid() { + t.Errorf("unexpected address: %v", ep.SrcIP()) + } + if ep.SrcIfidx() != 0 { + t.Errorf("unexpected ifindex: %d", ep.SrcIfidx()) + } + }) + t.Run("Multiple", func(t *testing.T) { + zeroControl := make([]byte, unix.CmsgSpace(0)) + zeroHdr := (*unix.Cmsghdr)(unsafe.Pointer(&zeroControl[0])) + zeroHdr.SetLen(unix.CmsgLen(0)) + + control := make([]byte, unix.CmsgSpace(unix.SizeofInet4Pktinfo)) + hdr := (*unix.Cmsghdr)(unsafe.Pointer(&control[0])) + hdr.Level = unix.IPPROTO_IP + hdr.Type = unix.IP_PKTINFO + hdr.SetLen(unix.CmsgLen(int(unsafe.Sizeof(unix.Inet4Pktinfo{})))) + info := (*unix.Inet4Pktinfo)(unsafe.Pointer(&control[unix.CmsgLen(0)])) + info.Spec_dst = [4]byte{127, 0, 0, 1} + info.Ifindex = 5 + + combined := make([]byte, 0) + combined = append(combined, zeroControl...) + combined = append(combined, control...) + + ep := &StdNetEndpoint{} + getSrcFromControl(combined, ep) + + if ep.SrcIP() != netip.MustParseAddr("127.0.0.1") { + t.Errorf("unexpected address: %v", ep.SrcIP()) + } + if ep.SrcIfidx() != 5 { + t.Errorf("unexpected ifindex: %d", ep.SrcIfidx()) + } + }) +} + +func Test_listenConfig(t *testing.T) { + t.Run("IPv4", func(t *testing.T) { + conn, err := listenConfig().ListenPacket(context.Background(), "udp4", ":0") + if err != nil { + t.Fatal(err) + } + defer conn.Close() + sc, err := conn.(*net.UDPConn).SyscallConn() + if err != nil { + t.Fatal(err) + } + + if runtime.GOOS == "linux" { + var i int + sc.Control(func(fd uintptr) { + i, err = unix.GetsockoptInt(int(fd), unix.IPPROTO_IP, unix.IP_PKTINFO) + }) + if err != nil { + t.Fatal(err) + } + if i != 1 { + t.Error("IP_PKTINFO not set!") + } + } else { + t.Logf("listenConfig() does not set IPV6_RECVPKTINFO on %s", runtime.GOOS) + } + }) + t.Run("IPv6", func(t *testing.T) { + conn, err := listenConfig().ListenPacket(context.Background(), "udp6", ":0") + if err != nil { + t.Fatal(err) + } + sc, err := conn.(*net.UDPConn).SyscallConn() + if err != nil { + t.Fatal(err) + } + + if runtime.GOOS == "linux" { + var i int + sc.Control(func(fd uintptr) { + i, err = unix.GetsockoptInt(int(fd), unix.IPPROTO_IPV6, unix.IPV6_RECVPKTINFO) + }) + if err != nil { + t.Fatal(err) + } + if i != 1 { + t.Error("IPV6_PKTINFO not set!") + } + } else { + t.Logf("listenConfig() does not set IPV6_RECVPKTINFO on %s", runtime.GOOS) + } + }) +} diff --git a/conn/controlfns_linux.go b/conn/controlfns_linux.go index a2396fe..f6ab1d2 100644 --- a/conn/controlfns_linux.go +++ b/conn/controlfns_linux.go @@ -57,5 +57,13 @@ func init() { } return err }, + + // Attempt to enable UDP_GRO + func(network, address string, c syscall.RawConn) error { + c.Control(func(fd uintptr) { + _ = unix.SetsockoptInt(int(fd), unix.IPPROTO_UDP, unix.UDP_GRO, 1) + }) + return nil + }, ) } diff --git a/conn/errors_default.go b/conn/errors_default.go new file mode 100644 index 0000000..f1e5b90 --- /dev/null +++ b/conn/errors_default.go @@ -0,0 +1,12 @@ +//go:build !linux + +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. + */ + +package conn + +func errShouldDisableUDPGSO(err error) bool { + return false +} diff --git a/conn/errors_linux.go b/conn/errors_linux.go new file mode 100644 index 0000000..8e61000 --- /dev/null +++ b/conn/errors_linux.go @@ -0,0 +1,26 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. + */ + +package conn + +import ( + "errors" + "os" + + "golang.org/x/sys/unix" +) + +func errShouldDisableUDPGSO(err error) bool { + var serr *os.SyscallError + if errors.As(err, &serr) { + // EIO is returned by udp_send_skb() if the device driver does not have + // tx checksumming enabled, which is a hard requirement of UDP_SEGMENT. + // See: + // https://git.kernel.org/pub/scm/docs/man-pages/man-pages.git/tree/man7/udp.7?id=806eabd74910447f21005160e90957bde4db0183#n228 + // https://git.kernel.org/pub/scm/linux/kernel/git/torvalds/linux.git/tree/net/ipv4/udp.c?h=v6.2&id=c9c3395d5e3dcc6daee66c6908354d47bf98cb0c#n942 + return serr.Err == unix.EIO + } + return false +} diff --git a/conn/features_default.go b/conn/features_default.go new file mode 100644 index 0000000..d53ff5f --- /dev/null +++ b/conn/features_default.go @@ -0,0 +1,15 @@ +//go:build !linux +// +build !linux + +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. + */ + +package conn + +import "net" + +func supportsUDPOffload(conn *net.UDPConn) (txOffload, rxOffload bool) { + return +} diff --git a/conn/features_linux.go b/conn/features_linux.go new file mode 100644 index 0000000..e1fb57f --- /dev/null +++ b/conn/features_linux.go @@ -0,0 +1,35 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. + */ + +package conn + +import ( + "net" + + "golang.org/x/sys/unix" +) + +func supportsUDPOffload(conn *net.UDPConn) (txOffload, rxOffload bool) { + rc, err := conn.SyscallConn() + if err != nil { + return + } + err = rc.Control(func(fd uintptr) { + _, errSyscall := unix.GetsockoptInt(int(fd), unix.IPPROTO_UDP, unix.UDP_SEGMENT) + if errSyscall != nil { + return + } + txOffload = true + opt, errSyscall := unix.GetsockoptInt(int(fd), unix.IPPROTO_UDP, unix.UDP_GRO) + if errSyscall != nil { + return + } + rxOffload = opt == 1 + }) + if err != nil { + return false, false + } + return txOffload, rxOffload +} diff --git a/conn/sticky_default.go b/conn/sticky_default.go deleted file mode 100644 index 1fa8a0c..0000000 --- a/conn/sticky_default.go +++ /dev/null @@ -1,41 +0,0 @@ -//go:build !linux || android - -/* SPDX-License-Identifier: MIT - * - * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. - */ - -package conn - -import "net/netip" - -func (e *StdNetEndpoint) SrcIP() netip.Addr { - return netip.Addr{} -} - -func (e *StdNetEndpoint) SrcIfidx() int32 { - return 0 -} - -func (e *StdNetEndpoint) SrcToString() string { - return "" -} - -// TODO: macOS, FreeBSD and other BSDs likely do support this feature set, but -// use alternatively named flags and need ports and require testing. - -// getSrcFromControl parses the control for PKTINFO and if found updates ep with -// the source information found. -func getSrcFromControl(control []byte, ep *StdNetEndpoint) { -} - -// setSrcControl parses the control for PKTINFO and if found updates ep with -// the source information found. -func setSrcControl(control *[]byte, ep *StdNetEndpoint) { -} - -// srcControlSize returns the recommended buffer size for pooling sticky control -// data. -const srcControlSize = 0 - -const StdNetSupportsStickySockets = false diff --git a/conn/sticky_linux.go b/conn/sticky_linux.go deleted file mode 100644 index a30ccc7..0000000 --- a/conn/sticky_linux.go +++ /dev/null @@ -1,110 +0,0 @@ -//go:build linux && !android - -/* SPDX-License-Identifier: MIT - * - * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. - */ - -package conn - -import ( - "net/netip" - "unsafe" - - "golang.org/x/sys/unix" -) - -func (e *StdNetEndpoint) SrcIP() netip.Addr { - switch len(e.src) { - case unix.CmsgSpace(unix.SizeofInet4Pktinfo): - info := (*unix.Inet4Pktinfo)(unsafe.Pointer(&e.src[unix.CmsgLen(0)])) - return netip.AddrFrom4(info.Spec_dst) - case unix.CmsgSpace(unix.SizeofInet6Pktinfo): - info := (*unix.Inet6Pktinfo)(unsafe.Pointer(&e.src[unix.CmsgLen(0)])) - // TODO: set zone. in order to do so we need to check if the address is - // link local, and if it is perform a syscall to turn the ifindex into a - // zone string because netip uses string zones. - return netip.AddrFrom16(info.Addr) - } - return netip.Addr{} -} - -func (e *StdNetEndpoint) SrcIfidx() int32 { - switch len(e.src) { - case unix.CmsgSpace(unix.SizeofInet4Pktinfo): - info := (*unix.Inet4Pktinfo)(unsafe.Pointer(&e.src[unix.CmsgLen(0)])) - return info.Ifindex - case unix.CmsgSpace(unix.SizeofInet6Pktinfo): - info := (*unix.Inet6Pktinfo)(unsafe.Pointer(&e.src[unix.CmsgLen(0)])) - return int32(info.Ifindex) - } - return 0 -} - -func (e *StdNetEndpoint) SrcToString() string { - return e.SrcIP().String() -} - -// getSrcFromControl parses the control for PKTINFO and if found updates ep with -// the source information found. -func getSrcFromControl(control []byte, ep *StdNetEndpoint) { - ep.ClearSrc() - - var ( - hdr unix.Cmsghdr - data []byte - rem []byte = control - err error - ) - - for len(rem) > unix.SizeofCmsghdr { - hdr, data, rem, err = unix.ParseOneSocketControlMessage(rem) - if err != nil { - return - } - - if hdr.Level == unix.IPPROTO_IP && - hdr.Type == unix.IP_PKTINFO { - - if ep.src == nil || cap(ep.src) < unix.CmsgSpace(unix.SizeofInet4Pktinfo) { - ep.src = make([]byte, 0, unix.CmsgSpace(unix.SizeofInet4Pktinfo)) - } - ep.src = ep.src[:unix.CmsgSpace(unix.SizeofInet4Pktinfo)] - - hdrBuf := unsafe.Slice((*byte)(unsafe.Pointer(&hdr)), unix.SizeofCmsghdr) - copy(ep.src, hdrBuf) - copy(ep.src[unix.CmsgLen(0):], data) - return - } - - if hdr.Level == unix.IPPROTO_IPV6 && - hdr.Type == unix.IPV6_PKTINFO { - - if ep.src == nil || cap(ep.src) < unix.CmsgSpace(unix.SizeofInet6Pktinfo) { - ep.src = make([]byte, 0, unix.CmsgSpace(unix.SizeofInet6Pktinfo)) - } - - ep.src = ep.src[:unix.CmsgSpace(unix.SizeofInet6Pktinfo)] - - hdrBuf := unsafe.Slice((*byte)(unsafe.Pointer(&hdr)), unix.SizeofCmsghdr) - copy(ep.src, hdrBuf) - copy(ep.src[unix.CmsgLen(0):], data) - return - } - } -} - -// setSrcControl sets an IP{V6}_PKTINFO in control based on the source address -// and source ifindex found in ep. control's len will be set to 0 in the event -// that ep is a default value. -func setSrcControl(control *[]byte, ep *StdNetEndpoint) { - if cap(*control) < len(ep.src) { - return - } - *control = (*control)[:0] - *control = append(*control, ep.src...) -} - -var srcControlSize = unix.CmsgSpace(unix.SizeofInet6Pktinfo) - -const StdNetSupportsStickySockets = true diff --git a/conn/sticky_linux_test.go b/conn/sticky_linux_test.go deleted file mode 100644 index 679213a..0000000 --- a/conn/sticky_linux_test.go +++ /dev/null @@ -1,266 +0,0 @@ -//go:build linux && !android - -/* SPDX-License-Identifier: MIT - * - * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. - */ - -package conn - -import ( - "context" - "net" - "net/netip" - "runtime" - "testing" - "unsafe" - - "golang.org/x/sys/unix" -) - -func setSrc(ep *StdNetEndpoint, addr netip.Addr, ifidx int32) { - var buf []byte - if addr.Is4() { - buf = make([]byte, unix.CmsgSpace(unix.SizeofInet4Pktinfo)) - hdr := unix.Cmsghdr{ - Level: unix.IPPROTO_IP, - Type: unix.IP_PKTINFO, - } - hdr.SetLen(unix.CmsgLen(unix.SizeofInet4Pktinfo)) - copy(buf, unsafe.Slice((*byte)(unsafe.Pointer(&hdr)), int(unsafe.Sizeof(hdr)))) - - info := unix.Inet4Pktinfo{ - Ifindex: ifidx, - Spec_dst: addr.As4(), - } - copy(buf[unix.CmsgLen(0):], unsafe.Slice((*byte)(unsafe.Pointer(&info)), unix.SizeofInet4Pktinfo)) - } else { - buf = make([]byte, unix.CmsgSpace(unix.SizeofInet6Pktinfo)) - hdr := unix.Cmsghdr{ - Level: unix.IPPROTO_IPV6, - Type: unix.IPV6_PKTINFO, - } - hdr.SetLen(unix.CmsgLen(unix.SizeofInet6Pktinfo)) - copy(buf, unsafe.Slice((*byte)(unsafe.Pointer(&hdr)), int(unsafe.Sizeof(hdr)))) - - info := unix.Inet6Pktinfo{ - Ifindex: uint32(ifidx), - Addr: addr.As16(), - } - copy(buf[unix.CmsgLen(0):], unsafe.Slice((*byte)(unsafe.Pointer(&info)), unix.SizeofInet6Pktinfo)) - } - - ep.src = buf -} - -func Test_setSrcControl(t *testing.T) { - t.Run("IPv4", func(t *testing.T) { - ep := &StdNetEndpoint{ - AddrPort: netip.MustParseAddrPort("127.0.0.1:1234"), - } - setSrc(ep, netip.MustParseAddr("127.0.0.1"), 5) - - control := make([]byte, srcControlSize) - - setSrcControl(&control, ep) - - hdr := (*unix.Cmsghdr)(unsafe.Pointer(&control[0])) - if hdr.Level != unix.IPPROTO_IP { - t.Errorf("unexpected level: %d", hdr.Level) - } - if hdr.Type != unix.IP_PKTINFO { - t.Errorf("unexpected type: %d", hdr.Type) - } - if uint(hdr.Len) != uint(unix.CmsgLen(int(unsafe.Sizeof(unix.Inet4Pktinfo{})))) { - t.Errorf("unexpected length: %d", hdr.Len) - } - info := (*unix.Inet4Pktinfo)(unsafe.Pointer(&control[unix.CmsgLen(0)])) - if info.Spec_dst[0] != 127 || info.Spec_dst[1] != 0 || info.Spec_dst[2] != 0 || info.Spec_dst[3] != 1 { - t.Errorf("unexpected address: %v", info.Spec_dst) - } - if info.Ifindex != 5 { - t.Errorf("unexpected ifindex: %d", info.Ifindex) - } - }) - - t.Run("IPv6", func(t *testing.T) { - ep := &StdNetEndpoint{ - AddrPort: netip.MustParseAddrPort("[::1]:1234"), - } - setSrc(ep, netip.MustParseAddr("::1"), 5) - - control := make([]byte, srcControlSize) - - setSrcControl(&control, ep) - - hdr := (*unix.Cmsghdr)(unsafe.Pointer(&control[0])) - if hdr.Level != unix.IPPROTO_IPV6 { - t.Errorf("unexpected level: %d", hdr.Level) - } - if hdr.Type != unix.IPV6_PKTINFO { - t.Errorf("unexpected type: %d", hdr.Type) - } - if uint(hdr.Len) != uint(unix.CmsgLen(int(unsafe.Sizeof(unix.Inet6Pktinfo{})))) { - t.Errorf("unexpected length: %d", hdr.Len) - } - info := (*unix.Inet6Pktinfo)(unsafe.Pointer(&control[unix.CmsgLen(0)])) - if info.Addr != ep.SrcIP().As16() { - t.Errorf("unexpected address: %v", info.Addr) - } - if info.Ifindex != 5 { - t.Errorf("unexpected ifindex: %d", info.Ifindex) - } - }) - - t.Run("ClearOnNoSrc", func(t *testing.T) { - control := make([]byte, unix.CmsgLen(0)) - hdr := (*unix.Cmsghdr)(unsafe.Pointer(&control[0])) - hdr.Level = 1 - hdr.Type = 2 - hdr.Len = 3 - - setSrcControl(&control, &StdNetEndpoint{}) - - if len(control) != 0 { - t.Errorf("unexpected control: %v", control) - } - }) -} - -func Test_getSrcFromControl(t *testing.T) { - t.Run("IPv4", func(t *testing.T) { - control := make([]byte, unix.CmsgSpace(unix.SizeofInet4Pktinfo)) - hdr := (*unix.Cmsghdr)(unsafe.Pointer(&control[0])) - hdr.Level = unix.IPPROTO_IP - hdr.Type = unix.IP_PKTINFO - hdr.SetLen(unix.CmsgLen(int(unsafe.Sizeof(unix.Inet4Pktinfo{})))) - info := (*unix.Inet4Pktinfo)(unsafe.Pointer(&control[unix.CmsgLen(0)])) - info.Spec_dst = [4]byte{127, 0, 0, 1} - info.Ifindex = 5 - - ep := &StdNetEndpoint{} - getSrcFromControl(control, ep) - - if ep.SrcIP() != netip.MustParseAddr("127.0.0.1") { - t.Errorf("unexpected address: %v", ep.SrcIP()) - } - if ep.SrcIfidx() != 5 { - t.Errorf("unexpected ifindex: %d", ep.SrcIfidx()) - } - }) - t.Run("IPv6", func(t *testing.T) { - control := make([]byte, unix.CmsgSpace(unix.SizeofInet6Pktinfo)) - hdr := (*unix.Cmsghdr)(unsafe.Pointer(&control[0])) - hdr.Level = unix.IPPROTO_IPV6 - hdr.Type = unix.IPV6_PKTINFO - hdr.SetLen(unix.CmsgLen(int(unsafe.Sizeof(unix.Inet6Pktinfo{})))) - info := (*unix.Inet6Pktinfo)(unsafe.Pointer(&control[unix.CmsgLen(0)])) - info.Addr = [16]byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1} - info.Ifindex = 5 - - ep := &StdNetEndpoint{} - getSrcFromControl(control, ep) - - if ep.SrcIP() != netip.MustParseAddr("::1") { - t.Errorf("unexpected address: %v", ep.SrcIP()) - } - if ep.SrcIfidx() != 5 { - t.Errorf("unexpected ifindex: %d", ep.SrcIfidx()) - } - }) - t.Run("ClearOnEmpty", func(t *testing.T) { - var control []byte - ep := &StdNetEndpoint{} - setSrc(ep, netip.MustParseAddr("::1"), 5) - - getSrcFromControl(control, ep) - if ep.SrcIP().IsValid() { - t.Errorf("unexpected address: %v", ep.SrcIP()) - } - if ep.SrcIfidx() != 0 { - t.Errorf("unexpected ifindex: %d", ep.SrcIfidx()) - } - }) - t.Run("Multiple", func(t *testing.T) { - zeroControl := make([]byte, unix.CmsgSpace(0)) - zeroHdr := (*unix.Cmsghdr)(unsafe.Pointer(&zeroControl[0])) - zeroHdr.SetLen(unix.CmsgLen(0)) - - control := make([]byte, unix.CmsgSpace(unix.SizeofInet4Pktinfo)) - hdr := (*unix.Cmsghdr)(unsafe.Pointer(&control[0])) - hdr.Level = unix.IPPROTO_IP - hdr.Type = unix.IP_PKTINFO - hdr.SetLen(unix.CmsgLen(int(unsafe.Sizeof(unix.Inet4Pktinfo{})))) - info := (*unix.Inet4Pktinfo)(unsafe.Pointer(&control[unix.CmsgLen(0)])) - info.Spec_dst = [4]byte{127, 0, 0, 1} - info.Ifindex = 5 - - combined := make([]byte, 0) - combined = append(combined, zeroControl...) - combined = append(combined, control...) - - ep := &StdNetEndpoint{} - getSrcFromControl(combined, ep) - - if ep.SrcIP() != netip.MustParseAddr("127.0.0.1") { - t.Errorf("unexpected address: %v", ep.SrcIP()) - } - if ep.SrcIfidx() != 5 { - t.Errorf("unexpected ifindex: %d", ep.SrcIfidx()) - } - }) -} - -func Test_listenConfig(t *testing.T) { - t.Run("IPv4", func(t *testing.T) { - conn, err := listenConfig().ListenPacket(context.Background(), "udp4", ":0") - if err != nil { - t.Fatal(err) - } - defer conn.Close() - sc, err := conn.(*net.UDPConn).SyscallConn() - if err != nil { - t.Fatal(err) - } - - if runtime.GOOS == "linux" { - var i int - sc.Control(func(fd uintptr) { - i, err = unix.GetsockoptInt(int(fd), unix.IPPROTO_IP, unix.IP_PKTINFO) - }) - if err != nil { - t.Fatal(err) - } - if i != 1 { - t.Error("IP_PKTINFO not set!") - } - } else { - t.Logf("listenConfig() does not set IPV6_RECVPKTINFO on %s", runtime.GOOS) - } - }) - t.Run("IPv6", func(t *testing.T) { - conn, err := listenConfig().ListenPacket(context.Background(), "udp6", ":0") - if err != nil { - t.Fatal(err) - } - sc, err := conn.(*net.UDPConn).SyscallConn() - if err != nil { - t.Fatal(err) - } - - if runtime.GOOS == "linux" { - var i int - sc.Control(func(fd uintptr) { - i, err = unix.GetsockoptInt(int(fd), unix.IPPROTO_IPV6, unix.IPV6_RECVPKTINFO) - }) - if err != nil { - t.Fatal(err) - } - if i != 1 { - t.Error("IPV6_PKTINFO not set!") - } - } else { - t.Logf("listenConfig() does not set IPV6_RECVPKTINFO on %s", runtime.GOOS) - } - }) -} diff --git a/device/send.go b/device/send.go index d22bf26..cd8a2a0 100644 --- a/device/send.go +++ b/device/send.go @@ -17,6 +17,7 @@ import ( "golang.org/x/crypto/chacha20poly1305" "golang.org/x/net/ipv4" "golang.org/x/net/ipv6" + "golang.zx2c4.com/wireguard/conn" "golang.zx2c4.com/wireguard/tun" ) @@ -525,6 +526,13 @@ func (peer *Peer) RoutineSequentialSender(maxBatchSize int) { device.PutOutboundElement(elem) } device.PutOutboundElementsSlice(elems) + if err != nil { + var errGSO conn.ErrUDPGSODisabled + if errors.As(err, &errGSO) { + device.log.Verbosef(err.Error()) + err = errGSO.RetryErr + } + } if err != nil { device.log.Errorf("%v - Failed to send data packets: %v", peer, err) continue diff --git a/go.mod b/go.mod index c04e1bb..758dcde 100644 --- a/go.mod +++ b/go.mod @@ -5,7 +5,7 @@ go 1.20 require ( golang.org/x/crypto v0.6.0 golang.org/x/net v0.7.0 - golang.org/x/sys v0.5.1-0.20230222185716-a3b23cc77e89 + golang.org/x/sys v0.12.0 golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 gvisor.dev/gvisor v0.0.0-20221203005347-703fd9b7fbc0 ) diff --git a/go.sum b/go.sum index cfeaee6..fe4ca7e 100644 --- a/go.sum +++ b/go.sum @@ -4,8 +4,8 @@ golang.org/x/crypto v0.6.0 h1:qfktjS5LUO+fFKeJXZ+ikTRijMmljikvG68fpMMruSc= golang.org/x/crypto v0.6.0/go.mod h1:OFC/31mSvZgRz0V1QTNCzfAI1aIRzbiufJtkMIlEp58= golang.org/x/net v0.7.0 h1:rJrUqqhjsgNp7KqAIc25s9pZnjU7TUcSY7HcVZjdn1g= golang.org/x/net v0.7.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs= -golang.org/x/sys v0.5.1-0.20230222185716-a3b23cc77e89 h1:260HNjMTPDya+jq5AM1zZLgG9pv9GASPAGiEEJUbRg4= -golang.org/x/sys v0.5.1-0.20230222185716-a3b23cc77e89/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.12.0 h1:CM0HF96J0hcLAwsHPJZjfdNzs0gftsLfgKt57wWHJ0o= +golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/time v0.0.0-20191024005414-555d28b269f0 h1:/5xXl8Y5W96D+TtHSlonuFqGHIWVuyCkGJLwGh9JJFs= golang.org/x/time v0.0.0-20191024005414-555d28b269f0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 h1:B82qJJgjvYKsXS9jeunTOisW56dUokqW/FOteYJJ/yg= -- cgit v1.2.3