diff options
author | Mathias Hall-Andersen <mathias@hall-andersen.dk> | 2017-09-24 21:35:25 +0200 |
---|---|---|
committer | Mathias Hall-Andersen <mathias@hall-andersen.dk> | 2017-09-24 21:35:25 +0200 |
commit | eefa47b0f91416a3435102f89339b3ec4fcdd672 (patch) | |
tree | 704e1f38f5e0cf8bba15607c16f2e2d1c23baedb | |
parent | c545d63bb93b8192dfdc7037952fc2661dd1222b (diff) |
Begin work on source address caching (linux)
-rw-r--r-- | src/conn.go | 22 | ||||
-rw-r--r-- | src/conn_linux.go | 243 | ||||
-rw-r--r-- | src/misc.go | 5 | ||||
-rw-r--r-- | src/tun_linux.go | 11 | ||||
-rw-r--r-- | src/uapi.go | 3 |
5 files changed, 273 insertions, 11 deletions
diff --git a/src/conn.go b/src/conn.go index 7b35829..41a5b85 100644 --- a/src/conn.go +++ b/src/conn.go @@ -1,9 +1,31 @@ package main import ( + "errors" "net" ) +func parseEndpoint(s string) (*net.UDPAddr, error) { + + // ensure that the host is an IP address + + host, _, err := net.SplitHostPort(s) + if err != nil { + return nil, err + } + if ip := net.ParseIP(host); ip == nil { + return nil, errors.New("Failed to parse IP address: " + host) + } + + // parse address and port + + addr, err := net.ResolveUDPAddr("udp", s) + if err != nil { + return nil, err + } + return addr, err +} + func updateUDPConn(device *Device) error { netc := &device.net netc.mutex.Lock() diff --git a/src/conn_linux.go b/src/conn_linux.go index e973b25..a349a9e 100644 --- a/src/conn_linux.go +++ b/src/conn_linux.go @@ -1,10 +1,253 @@ +/* Copyright 2017 Jason A. Donenfeld <Jason@zx2c4.com>. All Rights Reserved. + * + * This implements userspace semantics of "sticky sockets", modeled after + * WireGuard's kernelspace implementation. + */ + package main import ( + "errors" "golang.org/x/sys/unix" "net" + "strconv" + "unsafe" ) +/* Supports source address caching + * + * It is important that the endpoint is only updated after the packet content has been authenticated. + * + * Currently there is no way to achieve this within the net package: + * See e.g. https://github.com/golang/go/issues/17930 + */ +type Endpoint struct { + // source (selected based on dst type) + // (could use RawSockaddrAny and unsafe) + srcIPv6 unix.RawSockaddrInet6 + srcIPv4 unix.RawSockaddrInet4 + srcIf4 int32 + + dst unix.RawSockaddrAny +} + +func zoneToUint32(zone string) (uint32, error) { + if zone == "" { + return 0, nil + } + if intr, err := net.InterfaceByName(zone); err == nil { + return uint32(intr.Index), nil + } + n, err := strconv.ParseUint(zone, 10, 32) + return uint32(n), err +} + +func (end *Endpoint) ClearSrc() { + end.srcIf4 = 0 + end.srcIPv4 = unix.RawSockaddrInet4{} + end.srcIPv6 = unix.RawSockaddrInet6{} +} + +func (end *Endpoint) Set(s string) error { + addr, err := parseEndpoint(s) + if err != nil { + return err + } + + ipv6 := addr.IP.To16() + if ipv6 != nil { + zone, err := zoneToUint32(addr.Zone) + if err != nil { + return err + } + ptr := (*unix.RawSockaddrInet6)(unsafe.Pointer(&end.dst)) + ptr.Family = unix.AF_INET6 + ptr.Port = uint16(addr.Port) + ptr.Flowinfo = 0 + ptr.Scope_id = zone + copy(ptr.Addr[:], ipv6[:]) + end.ClearSrc() + return nil + } + + ipv4 := addr.IP.To4() + if ipv4 != nil { + ptr := (*unix.RawSockaddrInet4)(unsafe.Pointer(&end.dst)) + ptr.Family = unix.AF_INET + ptr.Port = uint16(addr.Port) + ptr.Zero = [8]byte{} + copy(ptr.Addr[:], ipv4) + end.ClearSrc() + return nil + } + + return errors.New("Failed to recognize IP address format") +} + +func send6(sock uintptr, end *Endpoint, buff []byte) error { + var iovec unix.Iovec + + iovec.Base = (*byte)(unsafe.Pointer(&buff[0])) + iovec.SetLen(len(buff)) + + cmsg := struct { + cmsghdr unix.Cmsghdr + pktinfo unix.Inet6Pktinfo + }{ + unix.Cmsghdr{ + Level: unix.IPPROTO_IPV6, + Type: unix.IPV6_PKTINFO, + Len: unix.SizeofInet6Pktinfo, + }, + unix.Inet6Pktinfo{ + Addr: end.srcIPv6.Addr, + Ifindex: end.srcIPv6.Scope_id, + }, + } + + msghdr := unix.Msghdr{ + Iov: &iovec, + Iovlen: 1, + Name: (*byte)(unsafe.Pointer(&end.dst)), + Namelen: unix.SizeofSockaddrInet6, + Control: (*byte)(unsafe.Pointer(&cmsg)), + } + + msghdr.SetControllen(int(unsafe.Sizeof(cmsg))) + + // sendmsg(sock, &msghdr, 0) + + _, _, errno := unix.Syscall( + unix.SYS_SENDMSG, + sock, + uintptr(unsafe.Pointer(&msghdr)), + 0, + ) + if errno == unix.EINVAL { + end.ClearSrc() + } + return errno +} + +func send4(sock uintptr, end *Endpoint, buff []byte) error { + var iovec unix.Iovec + + iovec.Base = (*byte)(unsafe.Pointer(&buff[0])) + iovec.SetLen(len(buff)) + + cmsg := struct { + cmsghdr unix.Cmsghdr + pktinfo unix.Inet4Pktinfo + }{ + unix.Cmsghdr{ + Level: unix.IPPROTO_IP, + Type: unix.IP_PKTINFO, + Len: unix.SizeofInet6Pktinfo, + }, + unix.Inet4Pktinfo{ + Spec_dst: end.srcIPv4.Addr, + Ifindex: end.srcIf4, + }, + } + + msghdr := unix.Msghdr{ + Iov: &iovec, + Iovlen: 1, + Name: (*byte)(unsafe.Pointer(&end.dst)), + Namelen: unix.SizeofSockaddrInet4, + Control: (*byte)(unsafe.Pointer(&cmsg)), + } + + msghdr.SetControllen(int(unsafe.Sizeof(cmsg))) + + // sendmsg(sock, &msghdr, 0) + + _, _, errno := unix.Syscall( + unix.SYS_SENDMSG, + sock, + uintptr(unsafe.Pointer(&msghdr)), + 0, + ) + if errno == unix.EINVAL { + end.ClearSrc() + } + return errno +} + +func send(c *net.UDPConn, end *Endpoint, buff []byte) error { + + // extract underlying file descriptor + + file, err := c.File() + if err != nil { + return err + } + sock := file.Fd() + + // send depending on address family of dst + + family := *((*uint16)(unsafe.Pointer(&end.dst))) + if family == unix.AF_INET { + return send4(sock, end, buff) + } else if family == unix.AF_INET6 { + return send6(sock, end, buff) + } + return errors.New("Unknown address family of source") +} + +func receiveIPv4(end *Endpoint, c *net.UDPConn, buff []byte) (error, *net.UDPAddr, *net.UDPAddr) { + + file, err := c.File() + if err != nil { + return err, nil, nil + } + + var iovec unix.Iovec + iovec.Base = (*byte)(unsafe.Pointer(&buff[0])) + iovec.SetLen(len(buff)) + + var cmsg struct { + cmsghdr unix.Cmsghdr + pktinfo unix.Inet6Pktinfo // big enough + } + + var msg unix.Msghdr + msg.Iov = &iovec + msg.Iovlen = 1 + msg.Name = (*byte)(unsafe.Pointer(&end.dst)) + msg.Namelen = uint32(unix.SizeofSockaddrAny) + msg.Control = (*byte)(unsafe.Pointer(&cmsg)) + msg.SetControllen(int(unsafe.Sizeof(cmsg))) + + _, _, errno := unix.Syscall( + unix.SYS_RECVMSG, + file.Fd(), + uintptr(unsafe.Pointer(&msg)), + 0, + ) + + if errno != 0 { + return errno, nil, nil + } + + if cmsg.cmsghdr.Level == unix.IPPROTO_IPV6 && + cmsg.cmsghdr.Type == unix.IPV6_PKTINFO && + cmsg.cmsghdr.Len >= unix.SizeofInet6Pktinfo { + + } + + if cmsg.cmsghdr.Level == unix.IPPROTO_IP && + cmsg.cmsghdr.Type == unix.IP_PKTINFO && + cmsg.cmsghdr.Len >= unix.SizeofInet4Pktinfo { + + info := (*unix.Inet4Pktinfo)(unsafe.Pointer(&cmsg.pktinfo)) + println(info) + + } + + return nil, nil, nil +} + func setMark(conn *net.UDPConn, value uint32) error { if conn == nil { return nil diff --git a/src/misc.go b/src/misc.go index d93849e..bbe0d68 100644 --- a/src/misc.go +++ b/src/misc.go @@ -29,6 +29,11 @@ func (a *AtomicBool) Set(val bool) { atomic.StoreInt32(&a.flag, flag) } +func toInt32(n uint32) int32 { + mask := uint32(1 << 31) + return int32(-(n & mask) + (n & ^mask)) +} + func min(a uint, b uint) uint { if a > b { return b diff --git a/src/tun_linux.go b/src/tun_linux.go index 58a762a..accc6c6 100644 --- a/src/tun_linux.go +++ b/src/tun_linux.go @@ -120,14 +120,6 @@ func (tun *NativeTun) Name() string { return tun.name } -func toInt32(val []byte) int32 { - n := binary.LittleEndian.Uint32(val[:4]) - if n >= (1 << 31) { - return -int32(^n) - 1 - } - return int32(n) -} - func getDummySock() (int, error) { return unix.Socket( unix.AF_INET, @@ -157,7 +149,8 @@ func getIFIndex(name string) (int32, error) { return 0, errno } - return toInt32(ifr[unix.IFNAMSIZ:]), nil + index := binary.LittleEndian.Uint32(ifr[unix.IFNAMSIZ:]) + return toInt32(index), nil } func (tun *NativeTun) setMTU(n int) error { diff --git a/src/uapi.go b/src/uapi.go index 428b173..3a2f3f9 100644 --- a/src/uapi.go +++ b/src/uapi.go @@ -273,8 +273,7 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError { } case "endpoint": - // TODO: Only IP and port - addr, err := net.ResolveUDPAddr("udp", value) + addr, err := parseEndpoint(value) if err != nil { logError.Println("Failed to set endpoint:", value) return &IPCError{Code: ipcErrorInvalid} |