From 5ba84696e29c6109e84b1f48247ae02a2bcb106e Mon Sep 17 00:00:00 2001 From: "Jason A. Donenfeld" Date: Fri, 20 Apr 2018 04:05:11 +0200 Subject: Rework sticky sockets --- conn_linux.go | 336 +++++++++++++++++++++++---------------------------- syscall_linux.go | 30 ----- syscall_linux_386.go | 53 -------- 3 files changed, 150 insertions(+), 269 deletions(-) delete mode 100644 syscall_linux.go delete mode 100644 syscall_linux_386.go diff --git a/conn_linux.go b/conn_linux.go index 8b60d65..88b9ef4 100644 --- a/conn_linux.go +++ b/conn_linux.go @@ -1,13 +1,18 @@ -/* Copyright 2017 Jason A. Donenfeld . All Rights Reserved. +/* Copyright 2017-2018 Jason A. Donenfeld . All Rights Reserved. * * This implements userspace semantics of "sticky sockets", modeled after - * WireGuard's kernelspace implementation. + * WireGuard's kernelspace implementation. This is more or less a straight port + * of the sticky-sockets.c example code: + * https://git.zx2c4.com/WireGuard/tree/contrib/examples/sticky-sockets/sticky-sockets.c + * + * Currently there is no way to achieve this within the net package: + * See e.g. https://github.com/golang/go/issues/17930 + * So this code is remains platform dependent. */ package main import ( - "encoding/binary" "errors" "golang.org/x/sys/unix" "net" @@ -15,41 +20,46 @@ import ( "unsafe" ) -/* Supports source address caching - * - * Currently there is no way to achieve this within the net package: - * See e.g. https://github.com/golang/go/issues/17930 - * So this code is remains platform dependent. - */ +type IPv4Source struct { + src [4]byte + ifindex int32 +} + +type IPv6Source struct { + src [16]byte + //ifindex belongs in dst.ZoneId +} + type NativeEndpoint struct { - src unix.RawSockaddrInet6 - dst unix.RawSockaddrInet6 + dst [unsafe.Sizeof(unix.SockaddrInet6{})]byte + src [unsafe.Sizeof(IPv6Source{})]byte + isV6 bool } -type NativeBind struct { - sock4 int - sock6 int +func (endpoint *NativeEndpoint) src4() *IPv4Source { + return (*IPv4Source)(unsafe.Pointer(&endpoint.src[0])) } -var _ Endpoint = (*NativeEndpoint)(nil) -var _ Bind = NativeBind{} +func (endpoint *NativeEndpoint) src6() *IPv6Source { + return (*IPv6Source)(unsafe.Pointer(&endpoint.src[0])) +} -type IPv4Source struct { - src unix.RawSockaddrInet4 - Ifindex int32 +func (endpoint *NativeEndpoint) dst4() *unix.SockaddrInet4 { + return (*unix.SockaddrInet4)(unsafe.Pointer(&endpoint.dst[0])) } -func htons(val uint16) uint16 { - var out [unsafe.Sizeof(val)]byte - binary.BigEndian.PutUint16(out[:], val) - return *((*uint16)(unsafe.Pointer(&out[0]))) +func (endpoint *NativeEndpoint) dst6() *unix.SockaddrInet6 { + return (*unix.SockaddrInet6)(unsafe.Pointer(&endpoint.dst[0])) } -func ntohs(val uint16) uint16 { - tmp := ((*[unsafe.Sizeof(val)]byte)(unsafe.Pointer(&val))) - return binary.BigEndian.Uint16((*tmp)[:]) +type NativeBind struct { + sock4 int + sock6 int } +var _ Endpoint = (*NativeEndpoint)(nil) +var _ Bind = NativeBind{} + func CreateEndpoint(s string) (Endpoint, error) { var end NativeEndpoint addr, err := parseEndpoint(s) @@ -59,10 +69,9 @@ func CreateEndpoint(s string) (Endpoint, error) { ipv4 := addr.IP.To4() if ipv4 != nil { - dst := (*unix.RawSockaddrInet4)(unsafe.Pointer(&end.dst)) - dst.Family = unix.AF_INET - dst.Port = htons(uint16(addr.Port)) - dst.Zero = [8]byte{} + dst := end.dst4() + end.isV6 = false + dst.Port = addr.Port copy(dst.Addr[:], ipv4) end.ClearSrc() return &end, nil @@ -74,17 +83,16 @@ func CreateEndpoint(s string) (Endpoint, error) { if err != nil { return nil, err } - dst := &end.dst - dst.Family = unix.AF_INET6 - dst.Port = htons(uint16(addr.Port)) - dst.Flowinfo = 0 - dst.Scope_id = zone + dst := end.dst6() + end.isV6 = true + dst.Port = addr.Port + dst.ZoneId = zone copy(dst.Addr[:], ipv6[:]) end.ClearSrc() return &end, nil } - return nil, errors.New("Failed to recognize IP address format") + return nil, errors.New("Invalid IP address") } func CreateBind(port uint16) (Bind, uint16, error) { @@ -160,86 +168,85 @@ func (bind NativeBind) ReceiveIPv4(buff []byte) (int, Endpoint, error) { func (bind NativeBind) Send(buff []byte, end Endpoint) error { nend := end.(*NativeEndpoint) - switch nend.dst.Family { - case unix.AF_INET6: - return send6(bind.sock6, nend, buff) - case unix.AF_INET: + if !nend.isV6 { return send4(bind.sock4, nend, buff) - default: - return errors.New("Unknown address family of destination") + } else { + return send6(bind.sock6, nend, buff) } } -func sockaddrToString(addr unix.RawSockaddrInet6) string { - var udpAddr net.UDPAddr - - switch addr.Family { - case unix.AF_INET6: - udpAddr.Port = int(ntohs(addr.Port)) - udpAddr.IP = addr.Addr[:] - return udpAddr.String() - - case unix.AF_INET: - ptr := (*unix.RawSockaddrInet4)(unsafe.Pointer(&addr)) - udpAddr.Port = int(ntohs(ptr.Port)) - udpAddr.IP = net.IPv4( - ptr.Addr[0], - ptr.Addr[1], - ptr.Addr[2], - ptr.Addr[3], - ) - return udpAddr.String() +func rawAddrToIP4(addr *unix.SockaddrInet4) net.IP { + return net.IPv4( + addr.Addr[0], + addr.Addr[1], + addr.Addr[2], + addr.Addr[3], + ) +} - default: - return "" - } +func rawAddrToIP6(addr *unix.SockaddrInet6) net.IP { + return addr.Addr[:] } -func rawAddrToIP(addr unix.RawSockaddrInet6) net.IP { - switch addr.Family { - case unix.AF_INET6: - return addr.Addr[:] - case unix.AF_INET: - ptr := (*unix.RawSockaddrInet4)(unsafe.Pointer(&addr)) +func (end *NativeEndpoint) SrcIP() net.IP { + if !end.isV6 { return net.IPv4( - ptr.Addr[0], - ptr.Addr[1], - ptr.Addr[2], - ptr.Addr[3], + end.src4().src[0], + end.src4().src[1], + end.src4().src[2], + end.src4().src[3], ) - default: - return nil + } else { + return end.src6().src[:] } } -func (end *NativeEndpoint) SrcIP() net.IP { - return rawAddrToIP(end.src) -} - func (end *NativeEndpoint) DstIP() net.IP { - return rawAddrToIP(end.dst) + if !end.isV6 { + return net.IPv4( + end.dst4().Addr[0], + end.dst4().Addr[1], + end.dst4().Addr[2], + end.dst4().Addr[3], + ) + } else { + return end.dst6().Addr[:] + } } func (end *NativeEndpoint) DstToBytes() []byte { - ptr := unsafe.Pointer(&end.src) - arr := (*[unix.SizeofSockaddrInet6]byte)(ptr) - return arr[:] + if !end.isV6 { + return (*[unsafe.Offsetof(end.dst4().Addr) + unsafe.Sizeof(end.dst4().Addr)]byte)(unsafe.Pointer(end.dst4()))[:] + } else { + return (*[unsafe.Offsetof(end.dst6().Addr) + unsafe.Sizeof(end.dst6().Addr)]byte)(unsafe.Pointer(end.dst6()))[:] + } } func (end *NativeEndpoint) SrcToString() string { - return sockaddrToString(end.src) + return end.SrcIP().String() } func (end *NativeEndpoint) DstToString() string { - return sockaddrToString(end.dst) + var udpAddr net.UDPAddr + udpAddr.IP = end.DstIP() + if !end.isV6 { + udpAddr.Port = end.dst4().Port + } else { + udpAddr.Port = end.dst6().Port + } + return udpAddr.String() } func (end *NativeEndpoint) ClearDst() { - end.dst = unix.RawSockaddrInet6{} + for i := range end.dst { + end.dst[i] = 0 + } } func (end *NativeEndpoint) ClearSrc() { - end.src = unix.RawSockaddrInet6{} + for i := range end.src { + end.src[i] = 0 + } } func zoneToUint32(zone string) (uint32, error) { @@ -295,6 +302,7 @@ func create4(port uint16) (int, uint16, error) { return unix.Bind(fd, &addr) }(); err != nil { unix.Close(fd) + return -1, 0, err } return fd, uint16(addr.Port), err @@ -353,140 +361,106 @@ func create6(port uint16) (int, uint16, error) { }(); err != nil { unix.Close(fd) + return -1, 0, err } return fd, uint16(addr.Port), err } -func send6(sock int, end *NativeEndpoint, buff []byte) error { +func send4(sock int, end *NativeEndpoint, buff []byte) error { // construct message header - var iovec unix.Iovec - iovec.Base = (*byte)(unsafe.Pointer(&buff[0])) - iovec.SetLen(len(buff)) - cmsg := struct { cmsghdr unix.Cmsghdr - pktinfo unix.Inet6Pktinfo + pktinfo unix.Inet4Pktinfo }{ unix.Cmsghdr{ - Level: unix.IPPROTO_IPV6, - Type: unix.IPV6_PKTINFO, - Len: unix.SizeofInet6Pktinfo + unix.SizeofCmsghdr, + Level: unix.IPPROTO_IP, + Type: unix.IP_PKTINFO, + Len: unix.SizeofInet4Pktinfo + unix.SizeofCmsghdr, }, - unix.Inet6Pktinfo{ - Addr: end.src.Addr, - Ifindex: end.src.Scope_id, + unix.Inet4Pktinfo{ + Spec_dst: end.src4().src, + Ifindex: end.src4().ifindex, }, } - msghdr := unix.Msghdr{ - Iov: &iovec, - Iovlen: 1, - Name: (*byte)(unsafe.Pointer(&end.dst)), - Namelen: unix.SizeofSockaddrInet6, - Control: (*byte)(unsafe.Pointer(&cmsg)), - } - - msghdr.SetControllen(int(unsafe.Sizeof(cmsg))) - - _, _, errno := sendmsg(sock, &msghdr, 0) + _, err := unix.SendmsgN(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], end.dst4(), 0) - if errno == 0 { + if err == nil { return nil } // clear src and retry - if errno == unix.EINVAL { + if err == unix.EINVAL { end.ClearSrc() - cmsg.pktinfo = unix.Inet6Pktinfo{} - _, _, errno = sendmsg(sock, &msghdr, 0) + cmsg.pktinfo = unix.Inet4Pktinfo{} + _, err = unix.SendmsgN(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], end.dst4(), 0) } - return errno + return err } -func send4(sock int, end *NativeEndpoint, buff []byte) error { +func send6(sock int, end *NativeEndpoint, buff []byte) error { // construct message header - var iovec unix.Iovec - iovec.Base = (*byte)(unsafe.Pointer(&buff[0])) - iovec.SetLen(len(buff)) - - src4 := (*IPv4Source)(unsafe.Pointer(&end.src)) - cmsg := struct { cmsghdr unix.Cmsghdr - pktinfo unix.Inet4Pktinfo + pktinfo unix.Inet6Pktinfo }{ unix.Cmsghdr{ - Level: unix.IPPROTO_IP, - Type: unix.IP_PKTINFO, - Len: unix.SizeofInet4Pktinfo + unix.SizeofCmsghdr, + Level: unix.IPPROTO_IPV6, + Type: unix.IPV6_PKTINFO, + Len: unix.SizeofInet6Pktinfo + unix.SizeofCmsghdr, }, - unix.Inet4Pktinfo{ - Spec_dst: src4.src.Addr, - Ifindex: src4.Ifindex, + unix.Inet6Pktinfo{ + Addr: end.src6().src, + Ifindex: end.dst6().ZoneId, }, } - msghdr := unix.Msghdr{ - Iov: &iovec, - Iovlen: 1, - Name: (*byte)(unsafe.Pointer(&end.dst)), - Namelen: unix.SizeofSockaddrInet4, - Control: (*byte)(unsafe.Pointer(&cmsg)), - Flags: 0, + if cmsg.pktinfo.Addr == [16]byte{} { + cmsg.pktinfo.Ifindex = 0 } - msghdr.SetControllen(int(unsafe.Sizeof(cmsg))) - - _, _, errno := sendmsg(sock, &msghdr, 0) - // clear source and try again + _, err := unix.SendmsgN(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], end.dst6(), 0) - if errno == unix.EINVAL { - end.ClearSrc() - cmsg.pktinfo = unix.Inet4Pktinfo{} - _, _, errno = sendmsg(sock, &msghdr, 0) + if err == nil { + return nil } - // errno = 0 is still an error instance + // clear src and retry - if errno == 0 { - return nil + if err == unix.EINVAL { + end.ClearSrc() + cmsg.pktinfo = unix.Inet6Pktinfo{} + _, err = unix.SendmsgN(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], end.dst6(), 0) } - return errno + return err } func receive4(sock int, buff []byte, end *NativeEndpoint) (int, error) { // contruct message header - var iovec unix.Iovec - iovec.Base = (*byte)(unsafe.Pointer(&buff[0])) - iovec.SetLen(len(buff)) - var cmsg struct { cmsghdr unix.Cmsghdr pktinfo unix.Inet4Pktinfo } - var msghdr unix.Msghdr - msghdr.Iov = &iovec - msghdr.Iovlen = 1 - msghdr.Name = (*byte)(unsafe.Pointer(&end.dst)) - msghdr.Namelen = unix.SizeofSockaddrInet4 - msghdr.Control = (*byte)(unsafe.Pointer(&cmsg)) - msghdr.SetControllen(int(unsafe.Sizeof(cmsg))) + size, _, _, newDst, err := unix.Recvmsg(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], 0) - size, _, errno := recvmsg(sock, &msghdr, 0) + if err != nil { + return 0, err + } + end.isV6 = false - if errno != 0 { - return 0, errno + if newDst4, ok := newDst.(*unix.SockaddrInet4); ok { + *end.dst4() = *newDst4 } // update source cache @@ -494,40 +468,31 @@ func receive4(sock int, buff []byte, end *NativeEndpoint) (int, error) { if cmsg.cmsghdr.Level == unix.IPPROTO_IP && cmsg.cmsghdr.Type == unix.IP_PKTINFO && cmsg.cmsghdr.Len >= unix.SizeofInet4Pktinfo { - src4 := (*IPv4Source)(unsafe.Pointer(&end.src)) - src4.src.Family = unix.AF_INET - src4.src.Addr = cmsg.pktinfo.Spec_dst - src4.Ifindex = cmsg.pktinfo.Ifindex + end.src4().src = cmsg.pktinfo.Spec_dst + end.src4().ifindex = cmsg.pktinfo.Ifindex } - return int(size), nil + return size, nil } func receive6(sock int, buff []byte, end *NativeEndpoint) (int, error) { // contruct message header - var iovec unix.Iovec - iovec.Base = (*byte)(unsafe.Pointer(&buff[0])) - iovec.SetLen(len(buff)) - var cmsg struct { cmsghdr unix.Cmsghdr pktinfo unix.Inet6Pktinfo } - var msg unix.Msghdr - msg.Iov = &iovec - msg.Iovlen = 1 - msg.Name = (*byte)(unsafe.Pointer(&end.dst)) - msg.Namelen = uint32(unix.SizeofSockaddrInet6) - msg.Control = (*byte)(unsafe.Pointer(&cmsg)) - msg.SetControllen(int(unsafe.Sizeof(cmsg))) + size, _, _, newDst, err := unix.Recvmsg(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], 0) - size, _, errno := recvmsg(sock, &msg, 0) + if err != nil { + return 0, err + } + end.isV6 = true - if errno != 0 { - return 0, errno + if newDst6, ok := newDst.(*unix.SockaddrInet6); ok { + *end.dst6() = *newDst6 } // update source cache @@ -535,10 +500,9 @@ func receive6(sock int, buff []byte, end *NativeEndpoint) (int, error) { if cmsg.cmsghdr.Level == unix.IPPROTO_IPV6 && cmsg.cmsghdr.Type == unix.IPV6_PKTINFO && cmsg.cmsghdr.Len >= unix.SizeofInet6Pktinfo { - end.src.Family = unix.AF_INET6 - end.src.Addr = cmsg.pktinfo.Addr - end.src.Scope_id = cmsg.pktinfo.Ifindex + end.src6().src = cmsg.pktinfo.Addr + end.dst6().ZoneId = cmsg.pktinfo.Ifindex } - return int(size), nil + return size, nil } diff --git a/syscall_linux.go b/syscall_linux.go deleted file mode 100644 index 3403544..0000000 --- a/syscall_linux.go +++ /dev/null @@ -1,30 +0,0 @@ -// +build linux,!386 - -/* Copyright 2018 Jason A. Donenfeld . All Rights Reserved. - */ - -package main - -import ( - "golang.org/x/sys/unix" - "syscall" - "unsafe" -) - -func sendmsg(fd int, msghdr *unix.Msghdr, flags int) (uintptr, uintptr, syscall.Errno) { - return unix.Syscall( - unix.SYS_SENDMSG, - uintptr(fd), - uintptr(unsafe.Pointer(msghdr)), - uintptr(flags), - ) -} - -func recvmsg(fd int, msghdr *unix.Msghdr, flags int) (uintptr, uintptr, syscall.Errno) { - return unix.Syscall( - unix.SYS_RECVMSG, - uintptr(fd), - uintptr(unsafe.Pointer(msghdr)), - uintptr(flags), - ) -} diff --git a/syscall_linux_386.go b/syscall_linux_386.go deleted file mode 100644 index 76d7c7e..0000000 --- a/syscall_linux_386.go +++ /dev/null @@ -1,53 +0,0 @@ -// +build linux,386 - -/* Copyright 2018 Jason A. Donenfeld . All Rights Reserved. - */ - -package main - -import ( - "golang.org/x/sys/unix" - "syscall" - "unsafe" -) - -const ( - _SENDMSG = 16 - _RECVMSG = 17 -) - -func sendmsg(fd int, msghdr *unix.Msghdr, flags int) (uintptr, uintptr, syscall.Errno) { - args := struct { - fd uintptr - msghdr uintptr - flags uintptr - }{ - uintptr(fd), - uintptr(unsafe.Pointer(msghdr)), - uintptr(flags), - } - return unix.Syscall( - unix.SYS_SOCKETCALL, - _SENDMSG, - uintptr(unsafe.Pointer(&args)), - 0, - ) -} - -func recvmsg(fd int, msghdr *unix.Msghdr, flags int) (uintptr, uintptr, syscall.Errno) { - args := struct { - fd uintptr - msghdr uintptr - flags uintptr - }{ - uintptr(fd), - uintptr(unsafe.Pointer(msghdr)), - uintptr(flags), - } - return unix.Syscall( - unix.SYS_SOCKETCALL, - _RECVMSG, - uintptr(unsafe.Pointer(&args)), - 0, - ) -} -- cgit v1.2.3