diff options
-rw-r--r-- | conn/boundif_windows.go (renamed from device/boundif_windows.go) | 19 | ||||
-rw-r--r-- | conn/conn.go | 101 | ||||
-rw-r--r-- | conn/conn_default.go (renamed from device/conn_default.go) | 13 | ||||
-rw-r--r-- | conn/conn_linux.go (renamed from device/conn_linux.go) | 249 | ||||
-rw-r--r-- | conn/mark_default.go (renamed from device/mark_default.go) | 2 | ||||
-rw-r--r-- | conn/mark_unix.go (renamed from device/mark_unix.go) | 2 | ||||
-rw-r--r-- | device/bind_test.go | 14 | ||||
-rw-r--r-- | device/bindsocketshim.go | 36 | ||||
-rw-r--r-- | device/conn.go | 187 | ||||
-rw-r--r-- | device/device.go | 146 | ||||
-rw-r--r-- | device/peer.go | 6 | ||||
-rw-r--r-- | device/receive.go | 9 | ||||
-rw-r--r-- | device/sticky_default.go | 12 | ||||
-rw-r--r-- | device/sticky_linux.go | 215 | ||||
-rw-r--r-- | device/uapi.go | 3 |
15 files changed, 562 insertions, 452 deletions
diff --git a/device/boundif_windows.go b/conn/boundif_windows.go index 6908415..fe38d05 100644 --- a/device/boundif_windows.go +++ b/conn/boundif_windows.go @@ -3,11 +3,10 @@ * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved. */ -package device +package conn import ( "encoding/binary" - "errors" "unsafe" "golang.org/x/sys/windows" @@ -18,17 +17,13 @@ const ( sockoptIPV6_UNICAST_IF = 31 ) -func (device *Device) BindSocketToInterface4(interfaceIndex uint32, blackhole bool) error { +func (bind *nativeBind) BindSocketToInterface4(interfaceIndex uint32, blackhole bool) error { /* MSDN says for IPv4 this needs to be in net byte order, so that it's like an IP address with leading zeros. */ bytes := make([]byte, 4) binary.BigEndian.PutUint32(bytes, interfaceIndex) interfaceIndex = *(*uint32)(unsafe.Pointer(&bytes[0])) - if device.net.bind == nil { - return errors.New("Bind is not yet initialized") - } - - sysconn, err := device.net.bind.(*nativeBind).ipv4.SyscallConn() + sysconn, err := bind.ipv4.SyscallConn() if err != nil { return err } @@ -41,12 +36,12 @@ func (device *Device) BindSocketToInterface4(interfaceIndex uint32, blackhole bo if err != nil { return err } - device.net.bind.(*nativeBind).blackhole4 = blackhole + bind.blackhole4 = blackhole return nil } -func (device *Device) BindSocketToInterface6(interfaceIndex uint32, blackhole bool) error { - sysconn, err := device.net.bind.(*nativeBind).ipv6.SyscallConn() +func (bind *nativeBind) BindSocketToInterface6(interfaceIndex uint32, blackhole bool) error { + sysconn, err := bind.ipv6.SyscallConn() if err != nil { return err } @@ -59,6 +54,6 @@ func (device *Device) BindSocketToInterface6(interfaceIndex uint32, blackhole bo if err != nil { return err } - device.net.bind.(*nativeBind).blackhole6 = blackhole + bind.blackhole6 = blackhole return nil } diff --git a/conn/conn.go b/conn/conn.go new file mode 100644 index 0000000..6b7db12 --- /dev/null +++ b/conn/conn.go @@ -0,0 +1,101 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved. + */ + +// Package conn implements WireGuard's network connections. +package conn + +import ( + "errors" + "net" + "strings" +) + +// A Bind listens on a port for both IPv6 and IPv4 UDP traffic. +type Bind interface { + // LastMark reports the last mark set for this Bind. + LastMark() uint32 + + // SetMark sets the mark for each packet sent through this Bind. + // This mark is passed to the kernel as the socket option SO_MARK. + SetMark(mark uint32) error + + // ReceiveIPv6 reads an IPv6 UDP packet into b. + // + // It reports the number of bytes read, n, + // the packet source address ep, + // and any error. + ReceiveIPv6(buff []byte) (n int, ep Endpoint, err error) + + // ReceiveIPv4 reads an IPv4 UDP packet into b. + // + // It reports the number of bytes read, n, + // the packet source address ep, + // and any error. + ReceiveIPv4(b []byte) (n int, ep Endpoint, err error) + + // Send writes a packet b to address ep. + Send(b []byte, ep Endpoint) error + + // Close closes the Bind connection. + Close() error +} + +// CreateBind creates a Bind bound to a port. +// +// The value actualPort reports the actual port number the Bind +// object gets bound to. +func CreateBind(port uint16) (b Bind, actualPort uint16, err error) { + return createBind(port) +} + +// BindToInterface is implemented by Bind objects that support being +// tied to a single network interface. +type BindToInterface interface { + BindToInterface4(interfaceIndex uint32, blackhole bool) error + BindToInterface6(interfaceIndex uint32, blackhole bool) error +} + +// An Endpoint maintains the source/destination caching for a peer. +// +// dst : the remote address of a peer ("endpoint" in uapi terminology) +// src : the local address from which datagrams originate going to the peer +type Endpoint interface { + ClearSrc() // clears the source address + SrcToString() string // returns the local source address (ip:port) + DstToString() string // returns the destination address (ip:port) + DstToBytes() []byte // used for mac2 cookie calculations + DstIP() net.IP + SrcIP() net.IP +} + +func parseEndpoint(s string) (*net.UDPAddr, error) { + // ensure that the host is an IP address + + host, _, err := net.SplitHostPort(s) + if err != nil { + return nil, err + } + if i := strings.LastIndexByte(host, '%'); i > 0 && strings.IndexByte(host, ':') >= 0 { + // Remove the scope, if any. ResolveUDPAddr below will use it, but here we're just + // trying to make sure with a small sanity test that this is a real IP address and + // not something that's likely to incur DNS lookups. + host = host[:i] + } + if ip := net.ParseIP(host); ip == nil { + return nil, errors.New("Failed to parse IP address: " + host) + } + + // parse address and port + + addr, err := net.ResolveUDPAddr("udp", s) + if err != nil { + return nil, err + } + ip4 := addr.IP.To4() + if ip4 != nil { + addr.IP = ip4 + } + return addr, err +} diff --git a/device/conn_default.go b/conn/conn_default.go index 661f57d..bad9d4d 100644 --- a/device/conn_default.go +++ b/conn/conn_default.go @@ -5,7 +5,7 @@ * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved. */ -package device +package conn import ( "net" @@ -67,16 +67,13 @@ func (e *NativeEndpoint) SrcToString() string { } func listenNet(network string, port int) (*net.UDPConn, int, error) { - - // listen - conn, err := net.ListenUDP(network, &net.UDPAddr{Port: port}) if err != nil { return nil, 0, err } - // retrieve port - + // Retrieve port. + // TODO(crawshaw): under what circumstances is this necessary? laddr := conn.LocalAddr() uaddr, err := net.ResolveUDPAddr( laddr.Network(), @@ -100,7 +97,7 @@ func extractErrno(err error) error { return syscallErr.Err } -func CreateBind(uport uint16, device *Device) (Bind, uint16, error) { +func createBind(uport uint16) (Bind, uint16, error) { var err error var bind nativeBind @@ -135,6 +132,8 @@ func (bind *nativeBind) Close() error { return err2 } +func (bind *nativeBind) LastMark() uint32 { return 0 } + func (bind *nativeBind) ReceiveIPv4(buff []byte) (int, Endpoint, error) { if bind.ipv4 == nil { return 0, nil, syscall.EAFNOSUPPORT diff --git a/device/conn_linux.go b/conn/conn_linux.go index e90b0e3..523da4a 100644 --- a/device/conn_linux.go +++ b/conn/conn_linux.go @@ -3,18 +3,9 @@ /* SPDX-License-Identifier: MIT * * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved. - * - * This implements userspace semantics of "sticky sockets", modeled after - * WireGuard's kernelspace implementation. This is more or less a straight port - * of the sticky-sockets.c example code: - * https://git.zx2c4.com/wireguard-tools/tree/contrib/sticky-sockets/sticky-sockets.c - * - * Currently there is no way to achieve this within the net package: - * See e.g. https://github.com/golang/go/issues/17930 - * So this code is remains platform dependent. */ -package device +package conn import ( "errors" @@ -25,7 +16,6 @@ import ( "unsafe" "golang.org/x/sys/unix" - "golang.zx2c4.com/wireguard/rwcancel" ) const ( @@ -33,8 +23,8 @@ const ( ) type IPv4Source struct { - src [4]byte - ifindex int32 + Src [4]byte + Ifindex int32 } type IPv6Source struct { @@ -49,6 +39,10 @@ type NativeEndpoint struct { isV6 bool } +func (endpoint *NativeEndpoint) Src4() *IPv4Source { return endpoint.src4() } +func (endpoint *NativeEndpoint) Dst4() *unix.SockaddrInet4 { return endpoint.dst4() } +func (endpoint *NativeEndpoint) IsV6() bool { return endpoint.isV6 } + func (endpoint *NativeEndpoint) src4() *IPv4Source { return (*IPv4Source)(unsafe.Pointer(&endpoint.src[0])) } @@ -66,11 +60,9 @@ func (endpoint *NativeEndpoint) dst6() *unix.SockaddrInet6 { } type nativeBind struct { - sock4 int - sock6 int - netlinkSock int - netlinkCancel *rwcancel.RWCancel - lastMark uint32 + sock4 int + sock6 int + lastMark uint32 } var _ Endpoint = (*NativeEndpoint)(nil) @@ -111,59 +103,25 @@ func CreateEndpoint(s string) (Endpoint, error) { return nil, errors.New("Invalid IP address") } -func createNetlinkRouteSocket() (int, error) { - sock, err := unix.Socket(unix.AF_NETLINK, unix.SOCK_RAW, unix.NETLINK_ROUTE) - if err != nil { - return -1, err - } - saddr := &unix.SockaddrNetlink{ - Family: unix.AF_NETLINK, - Groups: unix.RTMGRP_IPV4_ROUTE, - } - err = unix.Bind(sock, saddr) - if err != nil { - unix.Close(sock) - return -1, err - } - return sock, nil - -} - -func CreateBind(port uint16, device *Device) (*nativeBind, uint16, error) { +func createBind(port uint16) (Bind, uint16, error) { var err error var bind nativeBind var newPort uint16 - bind.netlinkSock, err = createNetlinkRouteSocket() - if err != nil { - return nil, 0, err - } - bind.netlinkCancel, err = rwcancel.NewRWCancel(bind.netlinkSock) - if err != nil { - unix.Close(bind.netlinkSock) - return nil, 0, err - } - - go bind.routineRouteListener(device) - - // attempt ipv6 bind, update port if successful - + // Attempt ipv6 bind, update port if successful. bind.sock6, newPort, err = create6(port) if err != nil { if err != syscall.EAFNOSUPPORT { - bind.netlinkCancel.Cancel() return nil, 0, err } } else { port = newPort } - // attempt ipv4 bind, update port if successful - + // Attempt ipv4 bind, update port if successful. bind.sock4, newPort, err = create4(port) if err != nil { if err != syscall.EAFNOSUPPORT { - bind.netlinkCancel.Cancel() unix.Close(bind.sock6) return nil, 0, err } @@ -178,6 +136,10 @@ func CreateBind(port uint16, device *Device) (*nativeBind, uint16, error) { return &bind, port, nil } +func (bind *nativeBind) LastMark() uint32 { + return bind.lastMark +} + func (bind *nativeBind) SetMark(value uint32) error { if bind.sock6 != -1 { err := unix.SetsockoptInt( @@ -216,22 +178,18 @@ func closeUnblock(fd int) error { } func (bind *nativeBind) Close() error { - var err1, err2, err3 error + var err1, err2 error if bind.sock6 != -1 { err1 = closeUnblock(bind.sock6) } if bind.sock4 != -1 { err2 = closeUnblock(bind.sock4) } - err3 = bind.netlinkCancel.Cancel() if err1 != nil { return err1 } - if err2 != nil { - return err2 - } - return err3 + return err2 } func (bind *nativeBind) ReceiveIPv6(buff []byte) (int, Endpoint, error) { @@ -278,10 +236,10 @@ func (bind *nativeBind) Send(buff []byte, end Endpoint) error { func (end *NativeEndpoint) SrcIP() net.IP { if !end.isV6 { return net.IPv4( - end.src4().src[0], - end.src4().src[1], - end.src4().src[2], - end.src4().src[3], + end.src4().Src[0], + end.src4().Src[1], + end.src4().Src[2], + end.src4().Src[3], ) } else { return end.src6().src[:] @@ -478,8 +436,8 @@ func send4(sock int, end *NativeEndpoint, buff []byte) error { Len: unix.SizeofInet4Pktinfo + unix.SizeofCmsghdr, }, unix.Inet4Pktinfo{ - Spec_dst: end.src4().src, - Ifindex: end.src4().ifindex, + Spec_dst: end.src4().Src, + Ifindex: end.src4().Ifindex, }, } @@ -573,8 +531,8 @@ func receive4(sock int, buff []byte, end *NativeEndpoint) (int, error) { if cmsg.cmsghdr.Level == unix.IPPROTO_IP && cmsg.cmsghdr.Type == unix.IP_PKTINFO && cmsg.cmsghdr.Len >= unix.SizeofInet4Pktinfo { - end.src4().src = cmsg.pktinfo.Spec_dst - end.src4().ifindex = cmsg.pktinfo.Ifindex + end.src4().Src = cmsg.pktinfo.Spec_dst + end.src4().Ifindex = cmsg.pktinfo.Ifindex } return size, nil @@ -611,156 +569,3 @@ func receive6(sock int, buff []byte, end *NativeEndpoint) (int, error) { return size, nil } - -func (bind *nativeBind) routineRouteListener(device *Device) { - type peerEndpointPtr struct { - peer *Peer - endpoint *Endpoint - } - var reqPeer map[uint32]peerEndpointPtr - var reqPeerLock sync.Mutex - - defer unix.Close(bind.netlinkSock) - - for msg := make([]byte, 1<<16); ; { - var err error - var msgn int - for { - msgn, _, _, _, err = unix.Recvmsg(bind.netlinkSock, msg[:], nil, 0) - if err == nil || !rwcancel.RetryAfterError(err) { - break - } - if !bind.netlinkCancel.ReadyRead() { - return - } - } - if err != nil { - return - } - - for remain := msg[:msgn]; len(remain) >= unix.SizeofNlMsghdr; { - - hdr := *(*unix.NlMsghdr)(unsafe.Pointer(&remain[0])) - - if uint(hdr.Len) > uint(len(remain)) { - break - } - - switch hdr.Type { - case unix.RTM_NEWROUTE, unix.RTM_DELROUTE: - if hdr.Seq <= MaxPeers && hdr.Seq > 0 { - if uint(len(remain)) < uint(hdr.Len) { - break - } - if hdr.Len > unix.SizeofNlMsghdr+unix.SizeofRtMsg { - attr := remain[unix.SizeofNlMsghdr+unix.SizeofRtMsg:] - for { - if uint(len(attr)) < uint(unix.SizeofRtAttr) { - break - } - attrhdr := *(*unix.RtAttr)(unsafe.Pointer(&attr[0])) - if attrhdr.Len < unix.SizeofRtAttr || uint(len(attr)) < uint(attrhdr.Len) { - break - } - if attrhdr.Type == unix.RTA_OIF && attrhdr.Len == unix.SizeofRtAttr+4 { - ifidx := *(*uint32)(unsafe.Pointer(&attr[unix.SizeofRtAttr])) - reqPeerLock.Lock() - if reqPeer == nil { - reqPeerLock.Unlock() - break - } - pePtr, ok := reqPeer[hdr.Seq] - reqPeerLock.Unlock() - if !ok { - break - } - pePtr.peer.Lock() - if &pePtr.peer.endpoint != pePtr.endpoint { - pePtr.peer.Unlock() - break - } - if uint32(pePtr.peer.endpoint.(*NativeEndpoint).src4().ifindex) == ifidx { - pePtr.peer.Unlock() - break - } - pePtr.peer.endpoint.(*NativeEndpoint).ClearSrc() - pePtr.peer.Unlock() - } - attr = attr[attrhdr.Len:] - } - } - break - } - reqPeerLock.Lock() - reqPeer = make(map[uint32]peerEndpointPtr) - reqPeerLock.Unlock() - go func() { - device.peers.RLock() - i := uint32(1) - for _, peer := range device.peers.keyMap { - peer.RLock() - if peer.endpoint == nil || peer.endpoint.(*NativeEndpoint) == nil { - peer.RUnlock() - continue - } - if peer.endpoint.(*NativeEndpoint).isV6 || peer.endpoint.(*NativeEndpoint).src4().ifindex == 0 { - peer.RUnlock() - break - } - nlmsg := struct { - hdr unix.NlMsghdr - msg unix.RtMsg - dsthdr unix.RtAttr - dst [4]byte - srchdr unix.RtAttr - src [4]byte - markhdr unix.RtAttr - mark uint32 - }{ - unix.NlMsghdr{ - Type: uint16(unix.RTM_GETROUTE), - Flags: unix.NLM_F_REQUEST, - Seq: i, - }, - unix.RtMsg{ - Family: unix.AF_INET, - Dst_len: 32, - Src_len: 32, - }, - unix.RtAttr{ - Len: 8, - Type: unix.RTA_DST, - }, - peer.endpoint.(*NativeEndpoint).dst4().Addr, - unix.RtAttr{ - Len: 8, - Type: unix.RTA_SRC, - }, - peer.endpoint.(*NativeEndpoint).src4().src, - unix.RtAttr{ - Len: 8, - Type: unix.RTA_MARK, - }, - uint32(bind.lastMark), - } - nlmsg.hdr.Len = uint32(unsafe.Sizeof(nlmsg)) - reqPeerLock.Lock() - reqPeer[i] = peerEndpointPtr{ - peer: peer, - endpoint: &peer.endpoint, - } - reqPeerLock.Unlock() - peer.RUnlock() - i++ - _, err := bind.netlinkCancel.Write((*[unsafe.Sizeof(nlmsg)]byte)(unsafe.Pointer(&nlmsg))[:]) - if err != nil { - break - } - } - device.peers.RUnlock() - }() - } - remain = remain[hdr.Len:] - } - } -} diff --git a/device/mark_default.go b/conn/mark_default.go index 7de2524..fc41ba9 100644 --- a/device/mark_default.go +++ b/conn/mark_default.go @@ -5,7 +5,7 @@ * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved. */ -package device +package conn func (bind *nativeBind) SetMark(mark uint32) error { return nil diff --git a/device/mark_unix.go b/conn/mark_unix.go index 669b328..5334582 100644 --- a/device/mark_unix.go +++ b/conn/mark_unix.go @@ -5,7 +5,7 @@ * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved. */ -package device +package conn import ( "runtime" diff --git a/device/bind_test.go b/device/bind_test.go index 0c2e2cf..c5f7f68 100644 --- a/device/bind_test.go +++ b/device/bind_test.go @@ -5,11 +5,15 @@ package device -import "errors" +import ( + "errors" + + "golang.zx2c4.com/wireguard/conn" +) type DummyDatagram struct { msg []byte - endpoint Endpoint + endpoint conn.Endpoint world bool // better type } @@ -25,7 +29,7 @@ func (b *DummyBind) SetMark(v uint32) error { return nil } -func (b *DummyBind) ReceiveIPv6(buff []byte) (int, Endpoint, error) { +func (b *DummyBind) ReceiveIPv6(buff []byte) (int, conn.Endpoint, error) { datagram, ok := <-b.in6 if !ok { return 0, nil, errors.New("closed") @@ -34,7 +38,7 @@ func (b *DummyBind) ReceiveIPv6(buff []byte) (int, Endpoint, error) { return len(datagram.msg), datagram.endpoint, nil } -func (b *DummyBind) ReceiveIPv4(buff []byte) (int, Endpoint, error) { +func (b *DummyBind) ReceiveIPv4(buff []byte) (int, conn.Endpoint, error) { datagram, ok := <-b.in4 if !ok { return 0, nil, errors.New("closed") @@ -50,6 +54,6 @@ func (b *DummyBind) Close() error { return nil } -func (b *DummyBind) Send(buff []byte, end Endpoint) error { +func (b *DummyBind) Send(buff []byte, end conn.Endpoint) error { return nil } diff --git a/device/bindsocketshim.go b/device/bindsocketshim.go new file mode 100644 index 0000000..c4dd4ef --- /dev/null +++ b/device/bindsocketshim.go @@ -0,0 +1,36 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved. + */ + +package device + +import ( + "errors" + + "golang.zx2c4.com/wireguard/conn" +) + +// TODO(crawshaw): this method is a compatibility shim. Replace with direct use of conn. +func (device *Device) BindSocketToInterface4(interfaceIndex uint32, blackhole bool) error { + if device.net.bind == nil { + return errors.New("Bind is not yet initialized") + } + + if iface, ok := device.net.bind.(conn.BindToInterface); ok { + return iface.BindToInterface4(interfaceIndex, blackhole) + } + return nil +} + +// TODO(crawshaw): this method is a compatibility shim. Replace with direct use of conn. +func (device *Device) BindSocketToInterface6(interfaceIndex uint32, blackhole bool) error { + if device.net.bind == nil { + return errors.New("Bind is not yet initialized") + } + + if iface, ok := device.net.bind.(conn.BindToInterface); ok { + return iface.BindToInterface6(interfaceIndex, blackhole) + } + return nil +} diff --git a/device/conn.go b/device/conn.go deleted file mode 100644 index 7b341f6..0000000 --- a/device/conn.go +++ /dev/null @@ -1,187 +0,0 @@ -/* SPDX-License-Identifier: MIT - * - * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved. - */ - -package device - -import ( - "errors" - "net" - "strings" - - "golang.org/x/net/ipv4" - "golang.org/x/net/ipv6" -) - -const ( - ConnRoutineNumber = 2 -) - -/* A Bind handles listening on a port for both IPv6 and IPv4 UDP traffic - */ -type Bind interface { - SetMark(value uint32) error - ReceiveIPv6(buff []byte) (int, Endpoint, error) - ReceiveIPv4(buff []byte) (int, Endpoint, error) - Send(buff []byte, end Endpoint) error - Close() error -} - -/* An Endpoint maintains the source/destination caching for a peer - * - * dst : the remote address of a peer ("endpoint" in uapi terminology) - * src : the local address from which datagrams originate going to the peer - */ -type Endpoint interface { - ClearSrc() // clears the source address - SrcToString() string // returns the local source address (ip:port) - DstToString() string // returns the destination address (ip:port) - DstToBytes() []byte // used for mac2 cookie calculations - DstIP() net.IP - SrcIP() net.IP -} - -func parseEndpoint(s string) (*net.UDPAddr, error) { - // ensure that the host is an IP address - - host, _, err := net.SplitHostPort(s) - if err != nil { - return nil, err - } - if i := strings.LastIndexByte(host, '%'); i > 0 && strings.IndexByte(host, ':') >= 0 { - // Remove the scope, if any. ResolveUDPAddr below will use it, but here we're just - // trying to make sure with a small sanity test that this is a real IP address and - // not something that's likely to incur DNS lookups. - host = host[:i] - } - if ip := net.ParseIP(host); ip == nil { - return nil, errors.New("Failed to parse IP address: " + host) - } - - // parse address and port - - addr, err := net.ResolveUDPAddr("udp", s) - if err != nil { - return nil, err - } - ip4 := addr.IP.To4() - if ip4 != nil { - addr.IP = ip4 - } - return addr, err -} - -func unsafeCloseBind(device *Device) error { - var err error - netc := &device.net - if netc.bind != nil { - err = netc.bind.Close() - netc.bind = nil - } - netc.stopping.Wait() - return err -} - -func (device *Device) BindSetMark(mark uint32) error { - - device.net.Lock() - defer device.net.Unlock() - - // check if modified - - if device.net.fwmark == mark { - return nil - } - - // update fwmark on existing bind - - device.net.fwmark = mark - if device.isUp.Get() && device.net.bind != nil { - if err := device.net.bind.SetMark(mark); err != nil { - return err - } - } - - // clear cached source addresses - - device.peers.RLock() - for _, peer := range device.peers.keyMap { - peer.Lock() - defer peer.Unlock() - if peer.endpoint != nil { - peer.endpoint.ClearSrc() - } - } - device.peers.RUnlock() - - return nil -} - -func (device *Device) BindUpdate() error { - - device.net.Lock() - defer device.net.Unlock() - - // close existing sockets - - if err := unsafeCloseBind(device); err != nil { - return err - } - - // open new sockets - - if device.isUp.Get() { - - // bind to new port - - var err error - netc := &device.net - netc.bind, netc.port, err = CreateBind(netc.port, device) - if err != nil { - netc.bind = nil - netc.port = 0 - return err - } - - // set fwmark - - if netc.fwmark != 0 { - err = netc.bind.SetMark(netc.fwmark) - if err != nil { - return err - } - } - - // clear cached source addresses - - device.peers.RLock() - for _, peer := range device.peers.keyMap { - peer.Lock() - defer peer.Unlock() - if peer.endpoint != nil { - peer.endpoint.ClearSrc() - } - } - device.peers.RUnlock() - - // start receiving routines - - device.net.starting.Add(ConnRoutineNumber) - device.net.stopping.Add(ConnRoutineNumber) - go device.RoutineReceiveIncoming(ipv4.Version, netc.bind) - go device.RoutineReceiveIncoming(ipv6.Version, netc.bind) - device.net.starting.Wait() - - device.log.Debug.Println("UDP bind has been updated") - } - - return nil -} - -func (device *Device) BindClose() error { - device.net.Lock() - err := unsafeCloseBind(device) - device.net.Unlock() - return err -} diff --git a/device/device.go b/device/device.go index 8c08f1c..a9fedea 100644 --- a/device/device.go +++ b/device/device.go @@ -11,15 +11,14 @@ import ( "sync/atomic" "time" + "golang.org/x/net/ipv4" + "golang.org/x/net/ipv6" + "golang.zx2c4.com/wireguard/conn" "golang.zx2c4.com/wireguard/ratelimiter" + "golang.zx2c4.com/wireguard/rwcancel" "golang.zx2c4.com/wireguard/tun" ) -const ( - DeviceRoutineNumberPerCPU = 3 - DeviceRoutineNumberAdditional = 2 -) - type Device struct { isUp AtomicBool // device is (going) up isClosed AtomicBool // device is closed? (acting as guard) @@ -39,9 +38,10 @@ type Device struct { starting sync.WaitGroup stopping sync.WaitGroup sync.RWMutex - bind Bind // bind interface - port uint16 // listening port - fwmark uint32 // mark value (0 = disabled) + bind conn.Bind // bind interface + netlinkCancel *rwcancel.RWCancel + port uint16 // listening port + fwmark uint32 // mark value (0 = disabled) } staticIdentity struct { @@ -299,14 +299,16 @@ func NewDevice(tunDevice tun.Device, logger *Logger) *Device { cpus := runtime.NumCPU() device.state.starting.Wait() device.state.stopping.Wait() - device.state.stopping.Add(DeviceRoutineNumberPerCPU*cpus + DeviceRoutineNumberAdditional) - device.state.starting.Add(DeviceRoutineNumberPerCPU*cpus + DeviceRoutineNumberAdditional) for i := 0; i < cpus; i += 1 { + device.state.starting.Add(3) + device.state.stopping.Add(3) go device.RoutineEncryption() go device.RoutineDecryption() go device.RoutineHandshake() } + device.state.starting.Add(2) + device.state.stopping.Add(2) go device.RoutineReadFromTUN() go device.RoutineTUNEventReader() @@ -413,3 +415,127 @@ func (device *Device) SendKeepalivesToPeersWithCurrentKeypair() { } device.peers.RUnlock() } + +func unsafeCloseBind(device *Device) error { + var err error + netc := &device.net + if netc.netlinkCancel != nil { + netc.netlinkCancel.Cancel() + } + if netc.bind != nil { + err = netc.bind.Close() + netc.bind = nil + } + netc.stopping.Wait() + return err +} + +func (device *Device) BindSetMark(mark uint32) error { + + device.net.Lock() + defer device.net.Unlock() + + // check if modified + + if device.net.fwmark == mark { + return nil + } + + // update fwmark on existing bind + + device.net.fwmark = mark + if device.isUp.Get() && device.net.bind != nil { + if err := device.net.bind.SetMark(mark); err != nil { + return err + } + } + + // clear cached source addresses + + device.peers.RLock() + for _, peer := range device.peers.keyMap { + peer.Lock() + defer peer.Unlock() + if peer.endpoint != nil { + peer.endpoint.ClearSrc() + } + } + device.peers.RUnlock() + + return nil +} + +func (device *Device) BindUpdate() error { + + device.net.Lock() + defer device.net.Unlock() + + // close existing sockets + + if err := unsafeCloseBind(device); err != nil { + return err + } + + // open new sockets + + if device.isUp.Get() { + + // bind to new port + + var err error + netc := &device.net + netc.bind, netc.port, err = conn.CreateBind(netc.port) + if err != nil { + netc.bind = nil + netc.port = 0 + return err + } + netc.netlinkCancel, err = device.startRouteListener(netc.bind) + if err != nil { + netc.bind.Close() + netc.bind = nil + netc.port = 0 + return err + } + + // set fwmark + + if netc.fwmark != 0 { + err = netc.bind.SetMark(netc.fwmark) + if err != nil { + return err + } + } + + // clear cached source addresses + + device.peers.RLock() + for _, peer := range device.peers.keyMap { + peer.Lock() + defer peer.Unlock() + if peer.endpoint != nil { + peer.endpoint.ClearSrc() + } + } + device.peers.RUnlock() + + // start receiving routines + + device.net.starting.Add(2) + device.net.stopping.Add(2) + go device.RoutineReceiveIncoming(ipv4.Version, netc.bind) + go device.RoutineReceiveIncoming(ipv6.Version, netc.bind) + device.net.starting.Wait() + + device.log.Debug.Println("UDP bind has been updated") + } + + return nil +} + +func (device *Device) BindClose() error { + device.net.Lock() + err := unsafeCloseBind(device) + device.net.Unlock() + return err +} diff --git a/device/peer.go b/device/peer.go index 19434cd..79d4981 100644 --- a/device/peer.go +++ b/device/peer.go @@ -12,6 +12,8 @@ import ( "sync" "sync/atomic" "time" + + "golang.zx2c4.com/wireguard/conn" ) const ( @@ -24,7 +26,7 @@ type Peer struct { keypairs Keypairs handshake Handshake device *Device - endpoint Endpoint + endpoint conn.Endpoint persistentKeepaliveInterval uint16 // These fields are accessed with atomic operations, which must be @@ -290,7 +292,7 @@ func (peer *Peer) Stop() { var RoamingDisabled bool -func (peer *Peer) SetEndpointFromPacket(endpoint Endpoint) { +func (peer *Peer) SetEndpointFromPacket(endpoint conn.Endpoint) { if RoamingDisabled { return } diff --git a/device/receive.go b/device/receive.go index 7d0693e..4818d64 100644 --- a/device/receive.go +++ b/device/receive.go @@ -17,12 +17,13 @@ import ( "golang.org/x/crypto/chacha20poly1305" "golang.org/x/net/ipv4" "golang.org/x/net/ipv6" + "golang.zx2c4.com/wireguard/conn" ) type QueueHandshakeElement struct { msgType uint32 packet []byte - endpoint Endpoint + endpoint conn.Endpoint buffer *[MaxMessageSize]byte } @@ -33,7 +34,7 @@ type QueueInboundElement struct { packet []byte counter uint64 keypair *Keypair - endpoint Endpoint + endpoint conn.Endpoint } func (elem *QueueInboundElement) Drop() { @@ -90,7 +91,7 @@ func (peer *Peer) keepKeyFreshReceiving() { * Every time the bind is updated a new routine is started for * IPv4 and IPv6 (separately) */ -func (device *Device) RoutineReceiveIncoming(IP int, bind Bind) { +func (device *Device) RoutineReceiveIncoming(IP int, bind conn.Bind) { logDebug := device.log.Debug defer func() { @@ -108,7 +109,7 @@ func (device *Device) RoutineReceiveIncoming(IP int, bind Bind) { var ( err error size int - endpoint Endpoint + endpoint conn.Endpoint ) for { diff --git a/device/sticky_default.go b/device/sticky_default.go new file mode 100644 index 0000000..1cc52f6 --- /dev/null +++ b/device/sticky_default.go @@ -0,0 +1,12 @@ +// +build !linux + +package device + +import ( + "golang.zx2c4.com/wireguard/conn" + "golang.zx2c4.com/wireguard/rwcancel" +) + +func (device *Device) startRouteListener(bind conn.Bind) (*rwcancel.RWCancel, error) { + return nil, nil +} diff --git a/device/sticky_linux.go b/device/sticky_linux.go new file mode 100644 index 0000000..f9522c2 --- /dev/null +++ b/device/sticky_linux.go @@ -0,0 +1,215 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved. + * + * This implements userspace semantics of "sticky sockets", modeled after + * WireGuard's kernelspace implementation. This is more or less a straight port + * of the sticky-sockets.c example code: + * https://git.zx2c4.com/WireGuard/tree/contrib/examples/sticky-sockets/sticky-sockets.c + * + * Currently there is no way to achieve this within the net package: + * See e.g. https://github.com/golang/go/issues/17930 + * So this code is remains platform dependent. + */ + +package device + +import ( + "sync" + "unsafe" + + "golang.org/x/sys/unix" + "golang.zx2c4.com/wireguard/conn" + "golang.zx2c4.com/wireguard/rwcancel" +) + +func (device *Device) startRouteListener(bind conn.Bind) (*rwcancel.RWCancel, error) { + netlinkSock, err := createNetlinkRouteSocket() + if err != nil { + return nil, err + } + netlinkCancel, err := rwcancel.NewRWCancel(netlinkSock) + if err != nil { + unix.Close(netlinkSock) + return nil, err + } + + go device.routineRouteListener(bind, netlinkSock, netlinkCancel) + + return netlinkCancel, nil +} + +func (device *Device) routineRouteListener(bind conn.Bind, netlinkSock int, netlinkCancel *rwcancel.RWCancel) { + type peerEndpointPtr struct { + peer *Peer + endpoint *conn.Endpoint + } + var reqPeer map[uint32]peerEndpointPtr + var reqPeerLock sync.Mutex + + defer unix.Close(netlinkSock) + + for msg := make([]byte, 1<<16); ; { + var err error + var msgn int + for { + msgn, _, _, _, err = unix.Recvmsg(netlinkSock, msg[:], nil, 0) + if err == nil || !rwcancel.RetryAfterError(err) { + break + } + if !netlinkCancel.ReadyRead() { + return + } + } + if err != nil { + return + } + + for remain := msg[:msgn]; len(remain) >= unix.SizeofNlMsghdr; { + + hdr := *(*unix.NlMsghdr)(unsafe.Pointer(&remain[0])) + + if uint(hdr.Len) > uint(len(remain)) { + break + } + + switch hdr.Type { + case unix.RTM_NEWROUTE, unix.RTM_DELROUTE: + if hdr.Seq <= MaxPeers && hdr.Seq > 0 { + if uint(len(remain)) < uint(hdr.Len) { + break + } + if hdr.Len > unix.SizeofNlMsghdr+unix.SizeofRtMsg { + attr := remain[unix.SizeofNlMsghdr+unix.SizeofRtMsg:] + for { + if uint(len(attr)) < uint(unix.SizeofRtAttr) { + break + } + attrhdr := *(*unix.RtAttr)(unsafe.Pointer(&attr[0])) + if attrhdr.Len < unix.SizeofRtAttr || uint(len(attr)) < uint(attrhdr.Len) { + break + } + if attrhdr.Type == unix.RTA_OIF && attrhdr.Len == unix.SizeofRtAttr+4 { + ifidx := *(*uint32)(unsafe.Pointer(&attr[unix.SizeofRtAttr])) + reqPeerLock.Lock() + if reqPeer == nil { + reqPeerLock.Unlock() + break + } + pePtr, ok := reqPeer[hdr.Seq] + reqPeerLock.Unlock() + if !ok { + break + } + pePtr.peer.Lock() + if &pePtr.peer.endpoint != pePtr.endpoint { + pePtr.peer.Unlock() + break + } + if uint32(pePtr.peer.endpoint.(*conn.NativeEndpoint).Src4().Ifindex) == ifidx { + pePtr.peer.Unlock() + break + } + pePtr.peer.endpoint.(*conn.NativeEndpoint).ClearSrc() + pePtr.peer.Unlock() + } + attr = attr[attrhdr.Len:] + } + } + break + } + reqPeerLock.Lock() + reqPeer = make(map[uint32]peerEndpointPtr) + reqPeerLock.Unlock() + go func() { + device.peers.RLock() + i := uint32(1) + for _, peer := range device.peers.keyMap { + peer.RLock() + if peer.endpoint == nil { + peer.RUnlock() + continue + } + nativeEP, _ := peer.endpoint.(*conn.NativeEndpoint) + if nativeEP == nil { + peer.RUnlock() + continue + } + if nativeEP.IsV6() || nativeEP.Src4().Ifindex == 0 { + peer.RUnlock() + break + } + nlmsg := struct { + hdr unix.NlMsghdr + msg unix.RtMsg + dsthdr unix.RtAttr + dst [4]byte + srchdr unix.RtAttr + src [4]byte + markhdr unix.RtAttr + mark uint32 + }{ + unix.NlMsghdr{ + Type: uint16(unix.RTM_GETROUTE), + Flags: unix.NLM_F_REQUEST, + Seq: i, + }, + unix.RtMsg{ + Family: unix.AF_INET, + Dst_len: 32, + Src_len: 32, + }, + unix.RtAttr{ + Len: 8, + Type: unix.RTA_DST, + }, + nativeEP.Dst4().Addr, + unix.RtAttr{ + Len: 8, + Type: unix.RTA_SRC, + }, + nativeEP.Src4().Src, + unix.RtAttr{ + Len: 8, + Type: unix.RTA_MARK, + }, + uint32(bind.LastMark()), + } + nlmsg.hdr.Len = uint32(unsafe.Sizeof(nlmsg)) + reqPeerLock.Lock() + reqPeer[i] = peerEndpointPtr{ + peer: peer, + endpoint: &peer.endpoint, + } + reqPeerLock.Unlock() + peer.RUnlock() + i++ + _, err := netlinkCancel.Write((*[unsafe.Sizeof(nlmsg)]byte)(unsafe.Pointer(&nlmsg))[:]) + if err != nil { + break + } + } + device.peers.RUnlock() + }() + } + remain = remain[hdr.Len:] + } + } +} + +func createNetlinkRouteSocket() (int, error) { + sock, err := unix.Socket(unix.AF_NETLINK, unix.SOCK_RAW, unix.NETLINK_ROUTE) + if err != nil { + return -1, err + } + saddr := &unix.SockaddrNetlink{ + Family: unix.AF_NETLINK, + Groups: uint32(1 << (unix.RTNLGRP_IPV4_ROUTE - 1)), + } + err = unix.Bind(sock, saddr) + if err != nil { + unix.Close(sock) + return -1, err + } + return sock, nil +} diff --git a/device/uapi.go b/device/uapi.go index 72611ab..6cdccd6 100644 --- a/device/uapi.go +++ b/device/uapi.go @@ -15,6 +15,7 @@ import ( "sync/atomic" "time" + "golang.zx2c4.com/wireguard/conn" "golang.zx2c4.com/wireguard/ipc" ) @@ -306,7 +307,7 @@ func (device *Device) IpcSetOperation(socket *bufio.Reader) *IPCError { err := func() error { peer.Lock() defer peer.Unlock() - endpoint, err := CreateEndpoint(value) + endpoint, err := conn.CreateEndpoint(value) if err != nil { return err } |