diff options
Diffstat (limited to 'conn')
-rw-r--r-- | conn/bind_linux.go | 587 | ||||
-rw-r--r-- | conn/bind_std.go | 339 | ||||
-rw-r--r-- | conn/boundif_android.go | 8 | ||||
-rw-r--r-- | conn/conn.go | 2 | ||||
-rw-r--r-- | conn/controlfns.go | 36 | ||||
-rw-r--r-- | conn/controlfns_linux.go | 41 | ||||
-rw-r--r-- | conn/controlfns_unix.go | 28 | ||||
-rw-r--r-- | conn/default.go | 2 | ||||
-rw-r--r-- | conn/mark_default.go | 2 | ||||
-rw-r--r-- | conn/mark_unix.go | 10 | ||||
-rw-r--r-- | conn/sticky_default.go | 26 | ||||
-rw-r--r-- | conn/sticky_linux.go | 111 | ||||
-rw-r--r-- | conn/sticky_linux_test.go | 207 |
13 files changed, 700 insertions, 699 deletions
diff --git a/conn/bind_linux.go b/conn/bind_linux.go deleted file mode 100644 index b6bc0dc..0000000 --- a/conn/bind_linux.go +++ /dev/null @@ -1,587 +0,0 @@ -/* SPDX-License-Identifier: MIT - * - * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. - */ - -package conn - -import ( - "errors" - "net" - "net/netip" - "strconv" - "sync" - "syscall" - "unsafe" - - "golang.org/x/sys/unix" -) - -type ipv4Source struct { - Src [4]byte - Ifindex int32 -} - -type ipv6Source struct { - src [16]byte - // ifindex belongs in dst.ZoneId -} - -type LinuxSocketEndpoint struct { - mu sync.Mutex - dst [unsafe.Sizeof(unix.SockaddrInet6{})]byte - src [unsafe.Sizeof(ipv6Source{})]byte - isV6 bool -} - -func (endpoint *LinuxSocketEndpoint) Src4() *ipv4Source { return endpoint.src4() } -func (endpoint *LinuxSocketEndpoint) Dst4() *unix.SockaddrInet4 { return endpoint.dst4() } -func (endpoint *LinuxSocketEndpoint) IsV6() bool { return endpoint.isV6 } - -func (endpoint *LinuxSocketEndpoint) src4() *ipv4Source { - return (*ipv4Source)(unsafe.Pointer(&endpoint.src[0])) -} - -func (endpoint *LinuxSocketEndpoint) src6() *ipv6Source { - return (*ipv6Source)(unsafe.Pointer(&endpoint.src[0])) -} - -func (endpoint *LinuxSocketEndpoint) dst4() *unix.SockaddrInet4 { - return (*unix.SockaddrInet4)(unsafe.Pointer(&endpoint.dst[0])) -} - -func (endpoint *LinuxSocketEndpoint) dst6() *unix.SockaddrInet6 { - return (*unix.SockaddrInet6)(unsafe.Pointer(&endpoint.dst[0])) -} - -// LinuxSocketBind uses sendmsg and recvmsg to implement a full bind with sticky sockets on Linux. -type LinuxSocketBind struct { - // mu guards sock4 and sock6 and the associated fds. - // As long as someone holds mu (read or write), the associated fds are valid. - mu sync.RWMutex - sock4 int - sock6 int -} - -func NewLinuxSocketBind() Bind { return &LinuxSocketBind{sock4: -1, sock6: -1} } -func NewDefaultBind() Bind { return NewLinuxSocketBind() } - -var ( - _ Endpoint = (*LinuxSocketEndpoint)(nil) - _ Bind = (*LinuxSocketBind)(nil) -) - -func (*LinuxSocketBind) ParseEndpoint(s string) (Endpoint, error) { - var end LinuxSocketEndpoint - e, err := netip.ParseAddrPort(s) - if err != nil { - return nil, err - } - - if e.Addr().Is4() { - dst := end.dst4() - end.isV6 = false - dst.Port = int(e.Port()) - dst.Addr = e.Addr().As4() - end.ClearSrc() - return &end, nil - } - - if e.Addr().Is6() { - zone, err := zoneToUint32(e.Addr().Zone()) - if err != nil { - return nil, err - } - dst := end.dst6() - end.isV6 = true - dst.Port = int(e.Port()) - dst.ZoneId = zone - dst.Addr = e.Addr().As16() - end.ClearSrc() - return &end, nil - } - - return nil, errors.New("invalid IP address") -} - -func (bind *LinuxSocketBind) Open(port uint16) ([]ReceiveFunc, uint16, error) { - bind.mu.Lock() - defer bind.mu.Unlock() - - var err error - var newPort uint16 - var tries int - - if bind.sock4 != -1 || bind.sock6 != -1 { - return nil, 0, ErrBindAlreadyOpen - } - - originalPort := port - -again: - port = originalPort - var sock4, sock6 int - // Attempt ipv6 bind, update port if successful. - sock6, newPort, err = create6(port) - if err != nil { - if !errors.Is(err, syscall.EAFNOSUPPORT) { - return nil, 0, err - } - } else { - port = newPort - } - - // Attempt ipv4 bind, update port if successful. - sock4, newPort, err = create4(port) - if err != nil { - if originalPort == 0 && errors.Is(err, syscall.EADDRINUSE) && tries < 100 { - unix.Close(sock6) - tries++ - goto again - } - if !errors.Is(err, syscall.EAFNOSUPPORT) { - unix.Close(sock6) - return nil, 0, err - } - } else { - port = newPort - } - - var fns []ReceiveFunc - if sock4 != -1 { - bind.sock4 = sock4 - fns = append(fns, bind.receiveIPv4) - } - if sock6 != -1 { - bind.sock6 = sock6 - fns = append(fns, bind.receiveIPv6) - } - if len(fns) == 0 { - return nil, 0, syscall.EAFNOSUPPORT - } - return fns, port, nil -} - -func (bind *LinuxSocketBind) SetMark(value uint32) error { - bind.mu.RLock() - defer bind.mu.RUnlock() - - if bind.sock6 != -1 { - err := unix.SetsockoptInt( - bind.sock6, - unix.SOL_SOCKET, - unix.SO_MARK, - int(value), - ) - if err != nil { - return err - } - } - - if bind.sock4 != -1 { - err := unix.SetsockoptInt( - bind.sock4, - unix.SOL_SOCKET, - unix.SO_MARK, - int(value), - ) - if err != nil { - return err - } - } - - return nil -} - -func (bind *LinuxSocketBind) BatchSize() int { - return 1 -} - -func (bind *LinuxSocketBind) Close() error { - // Take a readlock to shut down the sockets... - bind.mu.RLock() - if bind.sock6 != -1 { - unix.Shutdown(bind.sock6, unix.SHUT_RDWR) - } - if bind.sock4 != -1 { - unix.Shutdown(bind.sock4, unix.SHUT_RDWR) - } - bind.mu.RUnlock() - // ...and a write lock to close the fd. - // This ensures that no one else is using the fd. - bind.mu.Lock() - defer bind.mu.Unlock() - var err1, err2 error - if bind.sock6 != -1 { - err1 = unix.Close(bind.sock6) - bind.sock6 = -1 - } - if bind.sock4 != -1 { - err2 = unix.Close(bind.sock4) - bind.sock4 = -1 - } - - if err1 != nil { - return err1 - } - return err2 -} - -func (bind *LinuxSocketBind) receiveIPv4(buffs [][]byte, sizes []int, eps []Endpoint) (int, error) { - bind.mu.RLock() - defer bind.mu.RUnlock() - if bind.sock4 == -1 { - return 0, net.ErrClosed - } - var end LinuxSocketEndpoint - n, err := receive4(bind.sock4, buffs[0], &end) - if err != nil { - return 0, err - } - eps[0] = &end - sizes[0] = n - return 1, nil -} - -func (bind *LinuxSocketBind) receiveIPv6(buffs [][]byte, sizes []int, eps []Endpoint) (int, error) { - bind.mu.RLock() - defer bind.mu.RUnlock() - if bind.sock6 == -1 { - return 0, net.ErrClosed - } - var end LinuxSocketEndpoint - n, err := receive6(bind.sock6, buffs[0], &end) - if err != nil { - return 0, err - } - eps[0] = &end - sizes[0] = n - return 1, nil -} - -func (bind *LinuxSocketBind) Send(buffs [][]byte, end Endpoint) error { - nend, ok := end.(*LinuxSocketEndpoint) - if !ok { - return ErrWrongEndpointType - } - bind.mu.RLock() - defer bind.mu.RUnlock() - if !nend.isV6 { - if bind.sock4 == -1 { - return net.ErrClosed - } - for _, buff := range buffs { - err := send4(bind.sock4, nend, buff) - if err != nil { - return err - } - } - } else { - if bind.sock6 == -1 { - return net.ErrClosed - } - for _, buff := range buffs { - err := send6(bind.sock6, nend, buff) - if err != nil { - return err - } - } - } - return nil -} - -func (end *LinuxSocketEndpoint) SrcIP() netip.Addr { - if !end.isV6 { - return netip.AddrFrom4(end.src4().Src) - } else { - return netip.AddrFrom16(end.src6().src) - } -} - -func (end *LinuxSocketEndpoint) DstIP() netip.Addr { - if !end.isV6 { - return netip.AddrFrom4(end.dst4().Addr) - } else { - return netip.AddrFrom16(end.dst6().Addr) - } -} - -func (end *LinuxSocketEndpoint) DstToBytes() []byte { - if !end.isV6 { - return (*[unsafe.Offsetof(end.dst4().Addr) + unsafe.Sizeof(end.dst4().Addr)]byte)(unsafe.Pointer(end.dst4()))[:] - } else { - return (*[unsafe.Offsetof(end.dst6().Addr) + unsafe.Sizeof(end.dst6().Addr)]byte)(unsafe.Pointer(end.dst6()))[:] - } -} - -func (end *LinuxSocketEndpoint) SrcToString() string { - return end.SrcIP().String() -} - -func (end *LinuxSocketEndpoint) DstToString() string { - var port int - if !end.isV6 { - port = end.dst4().Port - } else { - port = end.dst6().Port - } - return netip.AddrPortFrom(end.DstIP(), uint16(port)).String() -} - -func (end *LinuxSocketEndpoint) ClearDst() { - for i := range end.dst { - end.dst[i] = 0 - } -} - -func (end *LinuxSocketEndpoint) ClearSrc() { - for i := range end.src { - end.src[i] = 0 - } -} - -func zoneToUint32(zone string) (uint32, error) { - if zone == "" { - return 0, nil - } - if intr, err := net.InterfaceByName(zone); err == nil { - return uint32(intr.Index), nil - } - n, err := strconv.ParseUint(zone, 10, 32) - return uint32(n), err -} - -func create4(port uint16) (int, uint16, error) { - // create socket - - fd, err := unix.Socket( - unix.AF_INET, - unix.SOCK_DGRAM|unix.SOCK_CLOEXEC, - 0, - ) - if err != nil { - return -1, 0, err - } - - addr := unix.SockaddrInet4{ - Port: int(port), - } - - // set sockopts and bind - - if err := func() error { - if err := unix.SetsockoptInt( - fd, - unix.IPPROTO_IP, - unix.IP_PKTINFO, - 1, - ); err != nil { - return err - } - - return unix.Bind(fd, &addr) - }(); err != nil { - unix.Close(fd) - return -1, 0, err - } - - sa, err := unix.Getsockname(fd) - if err == nil { - addr.Port = sa.(*unix.SockaddrInet4).Port - } - - return fd, uint16(addr.Port), err -} - -func create6(port uint16) (int, uint16, error) { - // create socket - - fd, err := unix.Socket( - unix.AF_INET6, - unix.SOCK_DGRAM|unix.SOCK_CLOEXEC, - 0, - ) - if err != nil { - return -1, 0, err - } - - // set sockopts and bind - - addr := unix.SockaddrInet6{ - Port: int(port), - } - - if err := func() error { - if err := unix.SetsockoptInt( - fd, - unix.IPPROTO_IPV6, - unix.IPV6_RECVPKTINFO, - 1, - ); err != nil { - return err - } - - if err := unix.SetsockoptInt( - fd, - unix.IPPROTO_IPV6, - unix.IPV6_V6ONLY, - 1, - ); err != nil { - return err - } - - return unix.Bind(fd, &addr) - }(); err != nil { - unix.Close(fd) - return -1, 0, err - } - - sa, err := unix.Getsockname(fd) - if err == nil { - addr.Port = sa.(*unix.SockaddrInet6).Port - } - - return fd, uint16(addr.Port), err -} - -func send4(sock int, end *LinuxSocketEndpoint, buff []byte) error { - // construct message header - - cmsg := struct { - cmsghdr unix.Cmsghdr - pktinfo unix.Inet4Pktinfo - }{ - unix.Cmsghdr{ - Level: unix.IPPROTO_IP, - Type: unix.IP_PKTINFO, - Len: unix.SizeofInet4Pktinfo + unix.SizeofCmsghdr, - }, - unix.Inet4Pktinfo{ - Spec_dst: end.src4().Src, - Ifindex: end.src4().Ifindex, - }, - } - - end.mu.Lock() - _, err := unix.SendmsgN(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], end.dst4(), 0) - end.mu.Unlock() - - if err == nil { - return nil - } - - // clear src and retry - - if err == unix.EINVAL { - end.ClearSrc() - cmsg.pktinfo = unix.Inet4Pktinfo{} - end.mu.Lock() - _, err = unix.SendmsgN(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], end.dst4(), 0) - end.mu.Unlock() - } - - return err -} - -func send6(sock int, end *LinuxSocketEndpoint, buff []byte) error { - // construct message header - - cmsg := struct { - cmsghdr unix.Cmsghdr - pktinfo unix.Inet6Pktinfo - }{ - unix.Cmsghdr{ - Level: unix.IPPROTO_IPV6, - Type: unix.IPV6_PKTINFO, - Len: unix.SizeofInet6Pktinfo + unix.SizeofCmsghdr, - }, - unix.Inet6Pktinfo{ - Addr: end.src6().src, - Ifindex: end.dst6().ZoneId, - }, - } - - if cmsg.pktinfo.Addr == [16]byte{} { - cmsg.pktinfo.Ifindex = 0 - } - - end.mu.Lock() - _, err := unix.SendmsgN(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], end.dst6(), 0) - end.mu.Unlock() - - if err == nil { - return nil - } - - // clear src and retry - - if err == unix.EINVAL { - end.ClearSrc() - cmsg.pktinfo = unix.Inet6Pktinfo{} - end.mu.Lock() - _, err = unix.SendmsgN(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], end.dst6(), 0) - end.mu.Unlock() - } - - return err -} - -func receive4(sock int, buff []byte, end *LinuxSocketEndpoint) (int, error) { - // construct message header - - var cmsg struct { - cmsghdr unix.Cmsghdr - pktinfo unix.Inet4Pktinfo - } - - size, _, _, newDst, err := unix.Recvmsg(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], 0) - if err != nil { - return 0, err - } - end.isV6 = false - - if newDst4, ok := newDst.(*unix.SockaddrInet4); ok { - *end.dst4() = *newDst4 - } - - // update source cache - - if cmsg.cmsghdr.Level == unix.IPPROTO_IP && - cmsg.cmsghdr.Type == unix.IP_PKTINFO && - cmsg.cmsghdr.Len >= unix.SizeofInet4Pktinfo { - end.src4().Src = cmsg.pktinfo.Spec_dst - end.src4().Ifindex = cmsg.pktinfo.Ifindex - } - - return size, nil -} - -func receive6(sock int, buff []byte, end *LinuxSocketEndpoint) (int, error) { - // construct message header - - var cmsg struct { - cmsghdr unix.Cmsghdr - pktinfo unix.Inet6Pktinfo - } - - size, _, _, newDst, err := unix.Recvmsg(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], 0) - if err != nil { - return 0, err - } - end.isV6 = true - - if newDst6, ok := newDst.(*unix.SockaddrInet6); ok { - *end.dst6() = *newDst6 - } - - // update source cache - - if cmsg.cmsghdr.Level == unix.IPPROTO_IPV6 && - cmsg.cmsghdr.Type == unix.IPV6_PKTINFO && - cmsg.cmsghdr.Len >= unix.SizeofInet6Pktinfo { - end.src6().src = cmsg.pktinfo.Addr - end.dst6().ZoneId = cmsg.pktinfo.Ifindex - } - - return size, nil -} diff --git a/conn/bind_std.go b/conn/bind_std.go index 98fe23c..a164f56 100644 --- a/conn/bind_std.go +++ b/conn/bind_std.go @@ -6,32 +6,91 @@ package conn import ( + "context" "errors" "net" "net/netip" + "strconv" "sync" "syscall" + + "golang.org/x/net/ipv4" + "golang.org/x/net/ipv6" +) + +var ( + _ Bind = (*StdNetBind)(nil) ) -// StdNetBind is meant to be a temporary solution on platforms for which -// the sticky socket / source caching behavior has not yet been implemented. -// It uses the Go's net package to implement networking. -// See LinuxSocketBind for a proper implementation on the Linux platform. +// StdNetBind implements Bind for all platforms except Windows. type StdNetBind struct { - mu sync.Mutex // protects following fields - ipv4 *net.UDPConn - ipv6 *net.UDPConn - blackhole4 bool - blackhole6 bool + mu sync.Mutex // protects following fields + ipv4 *net.UDPConn + ipv6 *net.UDPConn + blackhole4 bool + blackhole6 bool + ipv4PC *ipv4.PacketConn + ipv6PC *ipv6.PacketConn + batchSize int + udpAddrPool sync.Pool + ipv4MsgsPool sync.Pool + ipv6MsgsPool sync.Pool } -func NewStdNetBind() Bind { return &StdNetBind{} } +func NewStdNetBind() Bind { return NewStdNetBindBatch(DefaultBatchSize) } + +func NewStdNetBindBatch(maxBatchSize int) Bind { + if maxBatchSize == 0 { + maxBatchSize = DefaultBatchSize + } + return &StdNetBind{ + batchSize: maxBatchSize, + + udpAddrPool: sync.Pool{ + New: func() any { + return &net.UDPAddr{ + IP: make([]byte, 16), + } + }, + }, -type StdNetEndpoint netip.AddrPort + ipv4MsgsPool: sync.Pool{ + New: func() any { + msgs := make([]ipv4.Message, maxBatchSize) + for i := range msgs { + msgs[i].Buffers = make(net.Buffers, 1) + msgs[i].OOB = make([]byte, srcControlSize) + } + return &msgs + }, + }, + + ipv6MsgsPool: sync.Pool{ + New: func() any { + msgs := make([]ipv6.Message, maxBatchSize) + for i := range msgs { + msgs[i].Buffers = make(net.Buffers, 1) + msgs[i].OOB = make([]byte, srcControlSize) + } + return &msgs + }, + }, + } +} + +type StdNetEndpoint struct { + // AddrPort is the endpoint destination. + netip.AddrPort + // src is the current sticky source address and interface index, if supported. + src struct { + netip.Addr + ifidx int32 + } +} var ( _ Bind = (*StdNetBind)(nil) - _ Endpoint = StdNetEndpoint{} + _ Endpoint = &StdNetEndpoint{} ) func (*StdNetBind) ParseEndpoint(s string) (Endpoint, error) { @@ -39,31 +98,38 @@ func (*StdNetBind) ParseEndpoint(s string) (Endpoint, error) { return asEndpoint(e), err } -func (StdNetEndpoint) ClearSrc() {} +func (e *StdNetEndpoint) ClearSrc() { + e.src.ifidx = 0 + e.src.Addr = netip.Addr{} +} + +func (e *StdNetEndpoint) DstIP() netip.Addr { + return e.AddrPort.Addr() +} -func (e StdNetEndpoint) DstIP() netip.Addr { - return (netip.AddrPort)(e).Addr() +func (e *StdNetEndpoint) SrcIP() netip.Addr { + return e.src.Addr } -func (e StdNetEndpoint) SrcIP() netip.Addr { - return netip.Addr{} // not supported +func (e *StdNetEndpoint) SrcIfidx() int32 { + return e.src.ifidx } -func (e StdNetEndpoint) DstToBytes() []byte { - b, _ := (netip.AddrPort)(e).MarshalBinary() +func (e *StdNetEndpoint) DstToBytes() []byte { + b, _ := e.AddrPort.MarshalBinary() return b } -func (e StdNetEndpoint) DstToString() string { - return (netip.AddrPort)(e).String() +func (e *StdNetEndpoint) DstToString() string { + return e.AddrPort.String() } -func (e StdNetEndpoint) SrcToString() string { - return "" +func (e *StdNetEndpoint) SrcToString() string { + return e.src.Addr.String() } func listenNet(network string, port int) (*net.UDPConn, int, error) { - conn, err := net.ListenUDP(network, &net.UDPAddr{Port: port}) + conn, err := listenConfig().ListenPacket(context.Background(), network, ":"+strconv.Itoa(port)) if err != nil { return nil, 0, err } @@ -77,17 +143,17 @@ func listenNet(network string, port int) (*net.UDPConn, int, error) { if err != nil { return nil, 0, err } - return conn, uaddr.Port, nil + return conn.(*net.UDPConn), uaddr.Port, nil } -func (bind *StdNetBind) Open(uport uint16) ([]ReceiveFunc, uint16, error) { - bind.mu.Lock() - defer bind.mu.Unlock() +func (s *StdNetBind) Open(uport uint16) ([]ReceiveFunc, uint16, error) { + s.mu.Lock() + defer s.mu.Unlock() var err error var tries int - if bind.ipv4 != nil || bind.ipv6 != nil { + if s.ipv4 != nil || s.ipv6 != nil { return nil, 0, ErrBindAlreadyOpen } @@ -95,104 +161,121 @@ func (bind *StdNetBind) Open(uport uint16) ([]ReceiveFunc, uint16, error) { // If uport is 0, we can retry on failure. again: port := int(uport) - var ipv4, ipv6 *net.UDPConn + var v4conn, v6conn *net.UDPConn - ipv4, port, err = listenNet("udp4", port) + v4conn, port, err = listenNet("udp4", port) if err != nil && !errors.Is(err, syscall.EAFNOSUPPORT) { return nil, 0, err } // Listen on the same port as we're using for ipv4. - ipv6, port, err = listenNet("udp6", port) + v6conn, port, err = listenNet("udp6", port) if uport == 0 && errors.Is(err, syscall.EADDRINUSE) && tries < 100 { - ipv4.Close() + v4conn.Close() tries++ goto again } if err != nil && !errors.Is(err, syscall.EAFNOSUPPORT) { - ipv4.Close() + v4conn.Close() return nil, 0, err } var fns []ReceiveFunc - if ipv4 != nil { - fns = append(fns, bind.makeReceiveIPv4(ipv4)) - bind.ipv4 = ipv4 + if v4conn != nil { + fns = append(fns, s.receiveIPv4) + s.ipv4 = v4conn } - if ipv6 != nil { - fns = append(fns, bind.makeReceiveIPv6(ipv6)) - bind.ipv6 = ipv6 + if v6conn != nil { + fns = append(fns, s.receiveIPv6) + s.ipv6 = v6conn } if len(fns) == 0 { return nil, 0, syscall.EAFNOSUPPORT } - return fns, uint16(port), nil -} -func (bind *StdNetBind) BatchSize() int { - return 1 -} + s.ipv4PC = ipv4.NewPacketConn(s.ipv4) + s.ipv6PC = ipv6.NewPacketConn(s.ipv6) -func (bind *StdNetBind) Close() error { - bind.mu.Lock() - defer bind.mu.Unlock() + return fns, uint16(port), nil +} - var err1, err2 error - if bind.ipv4 != nil { - err1 = bind.ipv4.Close() - bind.ipv4 = nil +func (s *StdNetBind) receiveIPv4(buffs [][]byte, sizes []int, eps []Endpoint) (n int, err error) { + msgs := s.ipv4MsgsPool.Get().(*[]ipv4.Message) + defer s.ipv4MsgsPool.Put(msgs) + for i := range buffs { + (*msgs)[i].Buffers[0] = buffs[i] } - if bind.ipv6 != nil { - err2 = bind.ipv6.Close() - bind.ipv6 = nil + numMsgs, err := s.ipv4PC.ReadBatch(*msgs, 0) + if err != nil { + return 0, err } - bind.blackhole4 = false - bind.blackhole6 = false - if err1 != nil { - return err1 + for i := 0; i < numMsgs; i++ { + msg := &(*msgs)[i] + sizes[i] = msg.N + addrPort := msg.Addr.(*net.UDPAddr).AddrPort() + ep := asEndpoint(addrPort) + getSrcFromControl(msg.OOB, ep) + eps[i] = ep } - return err2 + return numMsgs, nil } -func (*StdNetBind) makeReceiveIPv4(conn *net.UDPConn) ReceiveFunc { - return func(buffs [][]byte, sizes []int, eps []Endpoint) (n int, err error) { - size, endpoint, err := conn.ReadFromUDPAddrPort(buffs[0]) - if err == nil { - sizes[0] = size - eps[0] = asEndpoint(endpoint) - return 1, nil - } +func (s *StdNetBind) receiveIPv6(buffs [][]byte, sizes []int, eps []Endpoint) (n int, err error) { + msgs := s.ipv6MsgsPool.Get().(*[]ipv6.Message) + defer s.ipv6MsgsPool.Put(msgs) + for i := range buffs { + (*msgs)[i].Buffers[0] = buffs[i] + } + numMsgs, err := s.ipv6PC.ReadBatch(*msgs, 0) + if err != nil { return 0, err } + for i := 0; i < numMsgs; i++ { + msg := &(*msgs)[i] + sizes[i] = msg.N + addrPort := msg.Addr.(*net.UDPAddr).AddrPort() + ep := asEndpoint(addrPort) + getSrcFromControl(msg.OOB, ep) + eps[i] = ep + } + return numMsgs, nil } -func (*StdNetBind) makeReceiveIPv6(conn *net.UDPConn) ReceiveFunc { - return func(buffs [][]byte, sizes []int, eps []Endpoint) (n int, err error) { - size, endpoint, err := conn.ReadFromUDPAddrPort(buffs[0]) - if err == nil { - sizes[0] = size - eps[0] = asEndpoint(endpoint) - return 1, nil - } - return 0, err - } +func (s *StdNetBind) BatchSize() int { + return s.batchSize } -func (bind *StdNetBind) Send(buffs [][]byte, endpoint Endpoint) error { - var err error - nend, ok := endpoint.(StdNetEndpoint) - if !ok { - return ErrWrongEndpointType +func (s *StdNetBind) Close() error { + s.mu.Lock() + defer s.mu.Unlock() + + var err1, err2 error + if s.ipv4 != nil { + err1 = s.ipv4.Close() + s.ipv4 = nil + } + if s.ipv6 != nil { + err2 = s.ipv6.Close() + s.ipv6 = nil + } + s.blackhole4 = false + s.blackhole6 = false + if err1 != nil { + return err1 } - addrPort := netip.AddrPort(nend) + return err2 +} - bind.mu.Lock() - blackhole := bind.blackhole4 - conn := bind.ipv4 - if addrPort.Addr().Is6() { - blackhole = bind.blackhole6 - conn = bind.ipv6 +func (s *StdNetBind) Send(buffs [][]byte, endpoint Endpoint) error { + s.mu.Lock() + blackhole := s.blackhole4 + conn := s.ipv4 + is6 := false + if endpoint.DstIP().Is6() { + blackhole = s.blackhole6 + conn = s.ipv6 + is6 = true } - bind.mu.Unlock() + s.mu.Unlock() if blackhole { return nil @@ -200,13 +283,69 @@ func (bind *StdNetBind) Send(buffs [][]byte, endpoint Endpoint) error { if conn == nil { return syscall.EAFNOSUPPORT } - for _, buff := range buffs { - _, err = conn.WriteToUDPAddrPort(buff, addrPort) - if err != nil { - return err + if is6 { + return s.send6(s.ipv6PC, endpoint, buffs) + } else { + return s.send4(s.ipv4PC, endpoint, buffs) + } +} + +func (s *StdNetBind) send4(conn *ipv4.PacketConn, ep Endpoint, buffs [][]byte) error { + ua := s.udpAddrPool.Get().(*net.UDPAddr) + as4 := ep.DstIP().As4() + copy(ua.IP, as4[:]) + ua.IP = ua.IP[:4] + ua.Port = int(ep.(*StdNetEndpoint).Port()) + msgs := s.ipv4MsgsPool.Get().(*[]ipv4.Message) + for i, buff := range buffs { + (*msgs)[i].Buffers[0] = buff + (*msgs)[i].Addr = ua + setSrcControl(&(*msgs)[i].OOB, ep.(*StdNetEndpoint)) + } + var ( + n int + err error + start int + ) + for { + n, err = conn.WriteBatch((*msgs)[start:len(buffs)], 0) + if err != nil || n == len((*msgs)[start:len(buffs)]) { + break + } + start += n + } + s.udpAddrPool.Put(ua) + s.ipv4MsgsPool.Put(msgs) + return err +} + +func (s *StdNetBind) send6(conn *ipv6.PacketConn, ep Endpoint, buffs [][]byte) error { + ua := s.udpAddrPool.Get().(*net.UDPAddr) + as16 := ep.DstIP().As16() + copy(ua.IP, as16[:]) + ua.IP = ua.IP[:16] + ua.Port = int(ep.(*StdNetEndpoint).Port()) + msgs := s.ipv6MsgsPool.Get().(*[]ipv6.Message) + for i, buff := range buffs { + (*msgs)[i].Buffers[0] = buff + (*msgs)[i].Addr = ua + setSrcControl(&(*msgs)[i].OOB, ep.(*StdNetEndpoint)) + } + var ( + n int + err error + start int + ) + for { + n, err = conn.WriteBatch((*msgs)[start:len(buffs)], 0) + if err != nil || n == len((*msgs)[start:len(buffs)]) { + break } + start += n } - return nil + s.udpAddrPool.Put(ua) + s.ipv6MsgsPool.Put(msgs) + return err } // endpointPool contains a re-usable set of mapping from netip.AddrPort to Endpoint. @@ -214,17 +353,17 @@ func (bind *StdNetBind) Send(buffs [][]byte, endpoint Endpoint) error { // but Endpoints are immutable, so we can re-use them. var endpointPool = sync.Pool{ New: func() any { - return make(map[netip.AddrPort]Endpoint) + return make(map[netip.AddrPort]*StdNetEndpoint) }, } // asEndpoint returns an Endpoint containing ap. -func asEndpoint(ap netip.AddrPort) Endpoint { - m := endpointPool.Get().(map[netip.AddrPort]Endpoint) +func asEndpoint(ap netip.AddrPort) *StdNetEndpoint { + m := endpointPool.Get().(map[netip.AddrPort]*StdNetEndpoint) defer endpointPool.Put(m) e, ok := m[ap] if !ok { - e = Endpoint(StdNetEndpoint(ap)) + e = &StdNetEndpoint{AddrPort: ap} m[ap] = e } return e diff --git a/conn/boundif_android.go b/conn/boundif_android.go index 818e4e6..dd3ca5b 100644 --- a/conn/boundif_android.go +++ b/conn/boundif_android.go @@ -5,8 +5,8 @@ package conn -func (bind *StdNetBind) PeekLookAtSocketFd4() (fd int, err error) { - sysconn, err := bind.ipv4.SyscallConn() +func (s *StdNetBind) PeekLookAtSocketFd4() (fd int, err error) { + sysconn, err := s.ipv4.SyscallConn() if err != nil { return -1, err } @@ -19,8 +19,8 @@ func (bind *StdNetBind) PeekLookAtSocketFd4() (fd int, err error) { return } -func (bind *StdNetBind) PeekLookAtSocketFd6() (fd int, err error) { - sysconn, err := bind.ipv6.SyscallConn() +func (s *StdNetBind) PeekLookAtSocketFd6() (fd int, err error) { + sysconn, err := s.ipv6.SyscallConn() if err != nil { return -1, err } diff --git a/conn/conn.go b/conn/conn.go index 8c0a827..9cbd0af 100644 --- a/conn/conn.go +++ b/conn/conn.go @@ -16,7 +16,7 @@ import ( ) const ( - DefaultBatchSize = 1 // maximum number of packets handled per read and write + DefaultBatchSize = 128 // maximum number of packets handled per read and write ) // A ReceiveFunc receives at least one packet from the network and writes them diff --git a/conn/controlfns.go b/conn/controlfns.go new file mode 100644 index 0000000..fe32871 --- /dev/null +++ b/conn/controlfns.go @@ -0,0 +1,36 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. + */ + +package conn + +import ( + "net" + "syscall" +) + +// controlFn is the callback function signature from net.ListenConfig.Control. +// It is used to apply platform specific configuration to the socket prior to +// bind. +type controlFn func(network, address string, c syscall.RawConn) error + +// controlFns is a list of functions that are called from the listen config +// that can apply socket options. +var controlFns = []controlFn{} + +// listenConfig returns a net.ListenConfig that applies the controlFns to the +// socket prior to bind. This is used to apply socket buffer sizing and packet +// information OOB configuration for sticky sockets. +func listenConfig() *net.ListenConfig { + return &net.ListenConfig{ + Control: func(network, address string, c syscall.RawConn) error { + for _, fn := range controlFns { + if err := fn(network, address, c); err != nil { + return err + } + } + return nil + }, + } +} diff --git a/conn/controlfns_linux.go b/conn/controlfns_linux.go new file mode 100644 index 0000000..9e26d95 --- /dev/null +++ b/conn/controlfns_linux.go @@ -0,0 +1,41 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. + */ + +package conn + +import ( + "fmt" + "syscall" + + "golang.org/x/sys/unix" +) + +func init() { + controlFns = append(controlFns, + + // Enable receiving of the packet information (IP_PKTINFO for IPv4, + // IPV6_PKTINFO for IPv6) that is used to implement sticky socket support. + func(network, address string, c syscall.RawConn) error { + var err error + switch network { + case "udp4": + c.Control(func(fd uintptr) { + err = unix.SetsockoptInt(int(fd), unix.IPPROTO_IP, unix.IP_PKTINFO, 1) + }) + case "udp6": + c.Control(func(fd uintptr) { + err = unix.SetsockoptInt(int(fd), unix.IPPROTO_IPV6, unix.IPV6_RECVPKTINFO, 1) + if err != nil { + return + } + err = unix.SetsockoptInt(int(fd), unix.IPPROTO_IPV6, unix.IPV6_V6ONLY, 1) + }) + default: + err = fmt.Errorf("unhandled network: %s: %w", network, unix.EINVAL) + } + return err + }, + ) +} diff --git a/conn/controlfns_unix.go b/conn/controlfns_unix.go new file mode 100644 index 0000000..9738c73 --- /dev/null +++ b/conn/controlfns_unix.go @@ -0,0 +1,28 @@ +//go:build !windows && !linux && !js + +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. + */ + +package conn + +import ( + "syscall" + + "golang.org/x/sys/unix" +) + +func init() { + controlFns = append(controlFns, + func(network, address string, c syscall.RawConn) error { + var err error + if network == "udp6" { + c.Control(func(fd uintptr) { + err = unix.SetsockoptInt(int(fd), unix.IPPROTO_IPV6, unix.IPV6_V6ONLY, 1) + }) + } + return err + }, + ) +} diff --git a/conn/default.go b/conn/default.go index c7b4a84..b6f761b 100644 --- a/conn/default.go +++ b/conn/default.go @@ -1,4 +1,4 @@ -//go:build !linux && !windows +//go:build !windows /* SPDX-License-Identifier: MIT * diff --git a/conn/mark_default.go b/conn/mark_default.go index 9944c38..3102384 100644 --- a/conn/mark_default.go +++ b/conn/mark_default.go @@ -7,6 +7,6 @@ package conn -func (bind *StdNetBind) SetMark(mark uint32) error { +func (s *StdNetBind) SetMark(mark uint32) error { return nil } diff --git a/conn/mark_unix.go b/conn/mark_unix.go index 5566b28..d9e46ee 100644 --- a/conn/mark_unix.go +++ b/conn/mark_unix.go @@ -26,13 +26,13 @@ func init() { } } -func (bind *StdNetBind) SetMark(mark uint32) error { +func (s *StdNetBind) SetMark(mark uint32) error { var operr error if fwmarkIoctl == 0 { return nil } - if bind.ipv4 != nil { - fd, err := bind.ipv4.SyscallConn() + if s.ipv4 != nil { + fd, err := s.ipv4.SyscallConn() if err != nil { return err } @@ -46,8 +46,8 @@ func (bind *StdNetBind) SetMark(mark uint32) error { return err } } - if bind.ipv6 != nil { - fd, err := bind.ipv6.SyscallConn() + if s.ipv6 != nil { + fd, err := s.ipv6.SyscallConn() if err != nil { return err } diff --git a/conn/sticky_default.go b/conn/sticky_default.go new file mode 100644 index 0000000..3ce9a56 --- /dev/null +++ b/conn/sticky_default.go @@ -0,0 +1,26 @@ +//go:build !linux +// +build !linux + +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. + */ + +package conn + +// TODO: macOS, FreeBSD and other BSDs likely do support this feature set, but +// use alternatively named flags and need ports and require testing. + +// getSrcFromControl parses the control for PKTINFO and if found updates ep with +// the source information found. +func getSrcFromControl(control []byte, ep *StdNetEndpoint) { +} + +// setSrcControl parses the control for PKTINFO and if found updates ep with +// the source information found. +func setSrcControl(control *[]byte, ep *StdNetEndpoint) { +} + +// srcControlSize returns the recommended buffer size for pooling sticky control +// data. +const srcControlSize = 0 diff --git a/conn/sticky_linux.go b/conn/sticky_linux.go new file mode 100644 index 0000000..bf17839 --- /dev/null +++ b/conn/sticky_linux.go @@ -0,0 +1,111 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. + */ + +package conn + +import ( + "net/netip" + "unsafe" + + "golang.org/x/sys/unix" +) + +// getSrcFromControl parses the control for PKTINFO and if found updates ep with +// the source information found. +func getSrcFromControl(control []byte, ep *StdNetEndpoint) { + ep.ClearSrc() + + var ( + hdr unix.Cmsghdr + data []byte + rem []byte = control + err error + ) + + for len(rem) > unix.SizeofCmsghdr { + hdr, data, rem, err = unix.ParseOneSocketControlMessage(control) + if err != nil { + return + } + + if hdr.Level == unix.IPPROTO_IP && + hdr.Type == unix.IP_PKTINFO { + + info := pktInfoFromBuf[unix.Inet4Pktinfo](data) + ep.src.Addr = netip.AddrFrom4(info.Spec_dst) + ep.src.ifidx = info.Ifindex + + return + } + + if hdr.Level == unix.IPPROTO_IPV6 && + hdr.Type == unix.IPV6_PKTINFO { + + info := pktInfoFromBuf[unix.Inet6Pktinfo](data) + ep.src.Addr = netip.AddrFrom16(info.Addr) + ep.src.ifidx = int32(info.Ifindex) + + return + } + } +} + +// pktInfoFromBuf returns type T populated from the provided buf via copy(). It +// panics if buf is of insufficient size. +func pktInfoFromBuf[T unix.Inet4Pktinfo | unix.Inet6Pktinfo](buf []byte) (t T) { + size := int(unsafe.Sizeof(t)) + if len(buf) < size { + panic("pktInfoFromBuf: buffer too small") + } + copy(unsafe.Slice((*byte)(unsafe.Pointer(&t)), size), buf) + return t +} + +// setSrcControl parses the control for PKTINFO and if found updates ep with +// the source information found. +func setSrcControl(control *[]byte, ep *StdNetEndpoint) { + *control = (*control)[:cap(*control)] + if len(*control) < int(unsafe.Sizeof(unix.Cmsghdr{})) { + *control = (*control)[:0] + return + } + + if ep.src.ifidx == 0 && !ep.SrcIP().IsValid() { + *control = (*control)[:0] + return + } + + if len(*control) < srcControlSize { + *control = (*control)[:0] + return + } + + hdr := (*unix.Cmsghdr)(unsafe.Pointer(&(*control)[0])) + if ep.SrcIP().Is4() { + hdr.Level = unix.IPPROTO_IP + hdr.Type = unix.IP_PKTINFO + hdr.SetLen(unix.CmsgLen(unix.SizeofInet4Pktinfo)) + + info := (*unix.Inet4Pktinfo)(unsafe.Pointer(&(*control)[unix.SizeofCmsghdr])) + info.Ifindex = ep.src.ifidx + if ep.SrcIP().IsValid() { + info.Spec_dst = ep.SrcIP().As4() + } + } else { + hdr.Level = unix.IPPROTO_IPV6 + hdr.Type = unix.IPV6_PKTINFO + hdr.Len = unix.SizeofCmsghdr + unix.SizeofInet6Pktinfo + + info := (*unix.Inet6Pktinfo)(unsafe.Pointer(&(*control)[unix.SizeofCmsghdr])) + info.Ifindex = uint32(ep.src.ifidx) + if ep.SrcIP().IsValid() { + info.Addr = ep.SrcIP().As16() + } + } + + *control = (*control)[:hdr.Len] +} + +var srcControlSize = unix.CmsgLen(unix.SizeofInet6Pktinfo) diff --git a/conn/sticky_linux_test.go b/conn/sticky_linux_test.go new file mode 100644 index 0000000..a42c89e --- /dev/null +++ b/conn/sticky_linux_test.go @@ -0,0 +1,207 @@ +//go:build linux +// +build linux + +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. + */ + +package conn + +import ( + "context" + "net" + "net/netip" + "runtime" + "testing" + "unsafe" + + "golang.org/x/sys/unix" +) + +func Test_setSrcControl(t *testing.T) { + t.Run("IPv4", func(t *testing.T) { + ep := &StdNetEndpoint{ + AddrPort: netip.MustParseAddrPort("127.0.0.1:1234"), + } + ep.src.Addr = netip.MustParseAddr("127.0.0.1") + ep.src.ifidx = 5 + + control := make([]byte, srcControlSize) + + setSrcControl(&control, ep) + + hdr := (*unix.Cmsghdr)(unsafe.Pointer(&control[0])) + if hdr.Level != unix.IPPROTO_IP { + t.Errorf("unexpected level: %d", hdr.Level) + } + if hdr.Type != unix.IP_PKTINFO { + t.Errorf("unexpected type: %d", hdr.Type) + } + if hdr.Len != uint64(unix.CmsgLen(int(unsafe.Sizeof(unix.Inet4Pktinfo{})))) { + t.Errorf("unexpected length: %d", hdr.Len) + } + info := (*unix.Inet4Pktinfo)(unsafe.Pointer(&control[unix.CmsgLen(0)])) + if info.Spec_dst[0] != 127 || info.Spec_dst[1] != 0 || info.Spec_dst[2] != 0 || info.Spec_dst[3] != 1 { + t.Errorf("unexpected address: %v", info.Spec_dst) + } + if info.Ifindex != 5 { + t.Errorf("unexpected ifindex: %d", info.Ifindex) + } + }) + + t.Run("IPv6", func(t *testing.T) { + ep := &StdNetEndpoint{ + AddrPort: netip.MustParseAddrPort("[::1]:1234"), + } + ep.src.Addr = netip.MustParseAddr("::1") + ep.src.ifidx = 5 + + control := make([]byte, srcControlSize) + + setSrcControl(&control, ep) + + hdr := (*unix.Cmsghdr)(unsafe.Pointer(&control[0])) + if hdr.Level != unix.IPPROTO_IPV6 { + t.Errorf("unexpected level: %d", hdr.Level) + } + if hdr.Type != unix.IPV6_PKTINFO { + t.Errorf("unexpected type: %d", hdr.Type) + } + if hdr.Len != uint64(unix.CmsgLen(int(unsafe.Sizeof(unix.Inet6Pktinfo{})))) { + t.Errorf("unexpected length: %d", hdr.Len) + } + info := (*unix.Inet6Pktinfo)(unsafe.Pointer(&control[unix.CmsgLen(0)])) + if info.Addr != ep.SrcIP().As16() { + t.Errorf("unexpected address: %v", info.Addr) + } + if info.Ifindex != 5 { + t.Errorf("unexpected ifindex: %d", info.Ifindex) + } + }) + + t.Run("ClearOnNoSrc", func(t *testing.T) { + control := make([]byte, srcControlSize) + hdr := (*unix.Cmsghdr)(unsafe.Pointer(&control[0])) + hdr.Level = 1 + hdr.Type = 2 + hdr.Len = 3 + + setSrcControl(&control, &StdNetEndpoint{}) + + if len(control) != 0 { + t.Errorf("unexpected control: %v", control) + } + }) +} + +func Test_getSrcFromControl(t *testing.T) { + t.Run("IPv4", func(t *testing.T) { + control := make([]byte, srcControlSize) + hdr := (*unix.Cmsghdr)(unsafe.Pointer(&control[0])) + hdr.Level = unix.IPPROTO_IP + hdr.Type = unix.IP_PKTINFO + hdr.Len = uint64(unix.CmsgLen(int(unsafe.Sizeof(unix.Inet4Pktinfo{})))) + info := (*unix.Inet4Pktinfo)(unsafe.Pointer(&control[unix.CmsgLen(0)])) + info.Spec_dst = [4]byte{127, 0, 0, 1} + info.Ifindex = 5 + + ep := &StdNetEndpoint{} + getSrcFromControl(control, ep) + + if ep.src.Addr != netip.MustParseAddr("127.0.0.1") { + t.Errorf("unexpected address: %v", ep.src.Addr) + } + if ep.src.ifidx != 5 { + t.Errorf("unexpected ifindex: %d", ep.src.ifidx) + } + }) + t.Run("IPv6", func(t *testing.T) { + control := make([]byte, srcControlSize) + hdr := (*unix.Cmsghdr)(unsafe.Pointer(&control[0])) + hdr.Level = unix.IPPROTO_IPV6 + hdr.Type = unix.IPV6_PKTINFO + hdr.Len = uint64(unix.CmsgLen(int(unsafe.Sizeof(unix.Inet6Pktinfo{})))) + info := (*unix.Inet6Pktinfo)(unsafe.Pointer(&control[unix.CmsgLen(0)])) + info.Addr = [16]byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1} + info.Ifindex = 5 + + ep := &StdNetEndpoint{} + getSrcFromControl(control, ep) + + if ep.SrcIP() != netip.MustParseAddr("::1") { + t.Errorf("unexpected address: %v", ep.SrcIP()) + } + if ep.src.ifidx != 5 { + t.Errorf("unexpected ifindex: %d", ep.src.ifidx) + } + }) + t.Run("ClearOnEmpty", func(t *testing.T) { + control := make([]byte, srcControlSize) + ep := &StdNetEndpoint{} + ep.src.Addr = netip.MustParseAddr("::1") + ep.src.ifidx = 5 + + getSrcFromControl(control, ep) + if ep.SrcIP().IsValid() { + t.Errorf("unexpected address: %v", ep.src.Addr) + } + if ep.src.ifidx != 0 { + t.Errorf("unexpected ifindex: %d", ep.src.ifidx) + } + }) +} + +func Test_listenConfig(t *testing.T) { + t.Run("IPv4", func(t *testing.T) { + conn, err := listenConfig().ListenPacket(context.Background(), "udp4", ":0") + if err != nil { + t.Fatal(err) + } + defer conn.Close() + sc, err := conn.(*net.UDPConn).SyscallConn() + if err != nil { + t.Fatal(err) + } + + if runtime.GOOS == "linux" { + var i int + sc.Control(func(fd uintptr) { + i, err = unix.GetsockoptInt(int(fd), unix.IPPROTO_IP, unix.IP_PKTINFO) + }) + if err != nil { + t.Fatal(err) + } + if i != 1 { + t.Error("IP_PKTINFO not set!") + } + } else { + t.Logf("listenConfig() does not set IPV6_RECVPKTINFO on %s", runtime.GOOS) + } + }) + t.Run("IPv6", func(t *testing.T) { + conn, err := listenConfig().ListenPacket(context.Background(), "udp6", ":0") + if err != nil { + t.Fatal(err) + } + sc, err := conn.(*net.UDPConn).SyscallConn() + if err != nil { + t.Fatal(err) + } + + if runtime.GOOS == "linux" { + var i int + sc.Control(func(fd uintptr) { + i, err = unix.GetsockoptInt(int(fd), unix.IPPROTO_IPV6, unix.IPV6_RECVPKTINFO) + }) + if err != nil { + t.Fatal(err) + } + if i != 1 { + t.Error("IPV6_PKTINFO not set!") + } + } else { + t.Logf("listenConfig() does not set IPV6_RECVPKTINFO on %s", runtime.GOOS) + } + }) +} |