From a0f54cbe5ac2cd8b8296c2c57c30029dd349cff0 Mon Sep 17 00:00:00 2001 From: Mathias Hall-Andersen Date: Sun, 4 Feb 2018 16:08:26 +0100 Subject: Align with go library layout --- Makefile | 12 + build.cmd | 6 + conn.go | 128 ++++++++++ conn_default.go | 131 ++++++++++ conn_linux.go | 582 +++++++++++++++++++++++++++++++++++++++++++ constants.go | 43 ++++ cookie.go | 252 +++++++++++++++++++ cookie_test.go | 186 ++++++++++++++ daemon_darwin.go | 9 + daemon_linux.go | 32 +++ daemon_windows.go | 34 +++ device.go | 372 ++++++++++++++++++++++++++++ helper_test.go | 79 ++++++ index.go | 95 +++++++ ip.go | 17 ++ kdf_test.go | 79 ++++++ keypair.go | 44 ++++ logger.go | 50 ++++ main.go | 196 +++++++++++++++ misc.go | 57 +++++ noise_helpers.go | 98 ++++++++ noise_protocol.go | 578 +++++++++++++++++++++++++++++++++++++++++++ noise_test.go | 136 ++++++++++ noise_types.go | 74 ++++++ peer.go | 295 ++++++++++++++++++++++ ratelimiter.go | 139 +++++++++++ ratelimiter_test.go | 98 ++++++++ receive.go | 642 ++++++++++++++++++++++++++++++++++++++++++++++++ replay.go | 73 ++++++ replay_test.go | 112 +++++++++ routing.go | 65 +++++ send.go | 362 +++++++++++++++++++++++++++ signal.go | 53 ++++ src/Makefile | 12 - src/build.cmd | 6 - src/conn.go | 128 ---------- src/conn_default.go | 131 ---------- src/conn_linux.go | 582 ------------------------------------------- src/constants.go | 43 ---- src/cookie.go | 252 ------------------- src/cookie_test.go | 186 -------------- src/daemon_darwin.go | 9 - src/daemon_linux.go | 32 --- src/daemon_windows.go | 34 --- src/device.go | 372 ---------------------------- src/helper_test.go | 79 ------ src/index.go | 95 ------- src/ip.go | 17 -- src/kdf_test.go | 79 ------ src/keypair.go | 44 ---- src/logger.go | 50 ---- src/main.go | 196 --------------- src/misc.go | 57 ----- src/noise_helpers.go | 98 -------- src/noise_protocol.go | 578 ------------------------------------------- src/noise_test.go | 136 ---------- src/noise_types.go | 74 ------ src/peer.go | 295 ---------------------- src/ratelimiter.go | 139 ----------- src/ratelimiter_test.go | 98 -------- src/receive.go | 642 ------------------------------------------------ src/replay.go | 73 ------ src/replay_test.go | 112 --------- src/routing.go | 65 ----- src/send.go | 362 --------------------------- src/signal.go | 53 ---- src/tai64.go | 28 --- src/tests/netns.sh | 425 -------------------------------- src/timer.go | 59 ----- src/timers.go | 346 -------------------------- src/trie.go | 228 ----------------- src/trie_rand_test.go | 126 ---------- src/trie_test.go | 255 ------------------- src/tun.go | 58 ----- src/tun_darwin.go | 323 ------------------------ src/tun_linux.go | 377 ---------------------------- src/tun_windows.go | 475 ----------------------------------- src/uapi.go | 437 -------------------------------- src/uapi_darwin.go | 99 -------- src/uapi_linux.go | 171 ------------- src/uapi_windows.go | 44 ---- src/xchacha20.go | 169 ------------- src/xchacha20_test.go | 96 -------- tai64.go | 28 +++ tests/netns.sh | 425 ++++++++++++++++++++++++++++++++ timer.go | 59 +++++ timers.go | 346 ++++++++++++++++++++++++++ trie.go | 228 +++++++++++++++++ trie_rand_test.go | 126 ++++++++++ trie_test.go | 255 +++++++++++++++++++ tun.go | 58 +++++ tun_darwin.go | 323 ++++++++++++++++++++++++ tun_linux.go | 377 ++++++++++++++++++++++++++++ tun_windows.go | 475 +++++++++++++++++++++++++++++++++++ uapi.go | 437 ++++++++++++++++++++++++++++++++ uapi_darwin.go | 99 ++++++++ uapi_linux.go | 171 +++++++++++++ uapi_windows.go | 44 ++++ xchacha20.go | 169 +++++++++++++ xchacha20_test.go | 96 ++++++++ 100 files changed, 8845 insertions(+), 8845 deletions(-) create mode 100644 Makefile create mode 100755 build.cmd create mode 100644 conn.go create mode 100644 conn_default.go create mode 100644 conn_linux.go create mode 100644 constants.go create mode 100644 cookie.go create mode 100644 cookie_test.go create mode 100644 daemon_darwin.go create mode 100644 daemon_linux.go create mode 100644 daemon_windows.go create mode 100644 device.go create mode 100644 helper_test.go create mode 100644 index.go create mode 100644 ip.go create mode 100644 kdf_test.go create mode 100644 keypair.go create mode 100644 logger.go create mode 100644 main.go create mode 100644 misc.go create mode 100644 noise_helpers.go create mode 100644 noise_protocol.go create mode 100644 noise_test.go create mode 100644 noise_types.go create mode 100644 peer.go create mode 100644 ratelimiter.go create mode 100644 ratelimiter_test.go create mode 100644 receive.go create mode 100644 replay.go create mode 100644 replay_test.go create mode 100644 routing.go create mode 100644 send.go create mode 100644 signal.go delete mode 100644 src/Makefile delete mode 100755 src/build.cmd delete mode 100644 src/conn.go delete mode 100644 src/conn_default.go delete mode 100644 src/conn_linux.go delete mode 100644 src/constants.go delete mode 100644 src/cookie.go delete mode 100644 src/cookie_test.go delete mode 100644 src/daemon_darwin.go delete mode 100644 src/daemon_linux.go delete mode 100644 src/daemon_windows.go delete mode 100644 src/device.go delete mode 100644 src/helper_test.go delete mode 100644 src/index.go delete mode 100644 src/ip.go delete mode 100644 src/kdf_test.go delete mode 100644 src/keypair.go delete mode 100644 src/logger.go delete mode 100644 src/main.go delete mode 100644 src/misc.go delete mode 100644 src/noise_helpers.go delete mode 100644 src/noise_protocol.go delete mode 100644 src/noise_test.go delete mode 100644 src/noise_types.go delete mode 100644 src/peer.go delete mode 100644 src/ratelimiter.go delete mode 100644 src/ratelimiter_test.go delete mode 100644 src/receive.go delete mode 100644 src/replay.go delete mode 100644 src/replay_test.go delete mode 100644 src/routing.go delete mode 100644 src/send.go delete mode 100644 src/signal.go delete mode 100644 src/tai64.go delete mode 100755 src/tests/netns.sh delete mode 100644 src/timer.go delete mode 100644 src/timers.go delete mode 100644 src/trie.go delete mode 100644 src/trie_rand_test.go delete mode 100644 src/trie_test.go delete mode 100644 src/tun.go delete mode 100644 src/tun_darwin.go delete mode 100644 src/tun_linux.go delete mode 100644 src/tun_windows.go delete mode 100644 src/uapi.go delete mode 100644 src/uapi_darwin.go delete mode 100644 src/uapi_linux.go delete mode 100644 src/uapi_windows.go delete mode 100644 src/xchacha20.go delete mode 100644 src/xchacha20_test.go create mode 100644 tai64.go create mode 100755 tests/netns.sh create mode 100644 timer.go create mode 100644 timers.go create mode 100644 trie.go create mode 100644 trie_rand_test.go create mode 100644 trie_test.go create mode 100644 tun.go create mode 100644 tun_darwin.go create mode 100644 tun_linux.go create mode 100644 tun_windows.go create mode 100644 uapi.go create mode 100644 uapi_darwin.go create mode 100644 uapi_linux.go create mode 100644 uapi_windows.go create mode 100644 xchacha20.go create mode 100644 xchacha20_test.go diff --git a/Makefile b/Makefile new file mode 100644 index 0000000..5b23ecc --- /dev/null +++ b/Makefile @@ -0,0 +1,12 @@ +all: wireguard-go + +wireguard-go: $(wildcard *.go) + go build -o $@ + +clean: + rm -f wireguard-go + +cloc: + cloc $(filter-out xchacha20.go $(wildcard *_test.go), $(wildcard *.go)) + +.PHONY: clean cloc diff --git a/build.cmd b/build.cmd new file mode 100755 index 0000000..52cb883 --- /dev/null +++ b/build.cmd @@ -0,0 +1,6 @@ +@echo off + +REM builds wireguard for windows + +go get +go build -o wireguard-go.exe diff --git a/conn.go b/conn.go new file mode 100644 index 0000000..fb30ec2 --- /dev/null +++ b/conn.go @@ -0,0 +1,128 @@ +package main + +import ( + "errors" + "golang.org/x/net/ipv4" + "golang.org/x/net/ipv6" + "net" +) + +/* A Bind handles listening on a port for both IPv6 and IPv4 UDP traffic + */ +type Bind interface { + SetMark(value uint32) error + ReceiveIPv6(buff []byte) (int, Endpoint, error) + ReceiveIPv4(buff []byte) (int, Endpoint, error) + Send(buff []byte, end Endpoint) error + Close() error +} + +/* An Endpoint maintains the source/destination caching for a peer + * + * dst : the remote address of a peer ("endpoint" in uapi terminology) + * src : the local address from which datagrams originate going to the peer + */ +type Endpoint interface { + ClearSrc() // clears the source address + SrcToString() string // returns the local source address (ip:port) + DstToString() string // returns the destination address (ip:port) + DstToBytes() []byte // used for mac2 cookie calculations + DstIP() net.IP + SrcIP() net.IP +} + +func parseEndpoint(s string) (*net.UDPAddr, error) { + + // ensure that the host is an IP address + + host, _, err := net.SplitHostPort(s) + if err != nil { + return nil, err + } + if ip := net.ParseIP(host); ip == nil { + return nil, errors.New("Failed to parse IP address: " + host) + } + + // parse address and port + + addr, err := net.ResolveUDPAddr("udp", s) + if err != nil { + return nil, err + } + return addr, err +} + +/* Must hold device and net lock + */ +func unsafeCloseBind(device *Device) error { + var err error + netc := &device.net + if netc.bind != nil { + err = netc.bind.Close() + netc.bind = nil + } + return err +} + +func (device *Device) BindUpdate() error { + + device.net.mutex.Lock() + defer device.net.mutex.Unlock() + + device.peers.mutex.Lock() + defer device.peers.mutex.Unlock() + + // close existing sockets + + if err := unsafeCloseBind(device); err != nil { + return err + } + + // open new sockets + + if device.isUp.Get() { + + // bind to new port + + var err error + netc := &device.net + netc.bind, netc.port, err = CreateBind(netc.port) + if err != nil { + netc.bind = nil + return err + } + + // set mark + + err = netc.bind.SetMark(netc.fwmark) + if err != nil { + return err + } + + // clear cached source addresses + + for _, peer := range device.peers.keyMap { + peer.mutex.Lock() + defer peer.mutex.Unlock() + if peer.endpoint != nil { + peer.endpoint.ClearSrc() + } + } + + // start receiving routines + + go device.RoutineReceiveIncoming(ipv4.Version, netc.bind) + go device.RoutineReceiveIncoming(ipv6.Version, netc.bind) + + device.log.Debug.Println("UDP bind has been updated") + } + + return nil +} + +func (device *Device) BindClose() error { + device.net.mutex.Lock() + err := unsafeCloseBind(device) + device.net.mutex.Unlock() + return err +} diff --git a/conn_default.go b/conn_default.go new file mode 100644 index 0000000..5b73c90 --- /dev/null +++ b/conn_default.go @@ -0,0 +1,131 @@ +// +build !linux + +package main + +import ( + "net" +) + +/* This code is meant to be a temporary solution + * on platforms for which the sticky socket / source caching behavior + * has not yet been implemented. + * + * See conn_linux.go for an implementation on the linux platform. + */ + +type NativeBind struct { + ipv4 *net.UDPConn + ipv6 *net.UDPConn +} + +type NativeEndpoint net.UDPAddr + +var _ Bind = (*NativeBind)(nil) +var _ Endpoint = (*NativeEndpoint)(nil) + +func CreateEndpoint(s string) (Endpoint, error) { + addr, err := parseEndpoint(s) + return (*NativeEndpoint)(addr), err +} + +func (_ *NativeEndpoint) ClearSrc() {} + +func (e *NativeEndpoint) DstIP() net.IP { + return (*net.UDPAddr)(e).IP +} + +func (e *NativeEndpoint) SrcIP() net.IP { + return nil // not supported +} + +func (e *NativeEndpoint) DstToBytes() []byte { + addr := (*net.UDPAddr)(e) + out := addr.IP + out = append(out, byte(addr.Port&0xff)) + out = append(out, byte((addr.Port>>8)&0xff)) + return out +} + +func (e *NativeEndpoint) DstToString() string { + return (*net.UDPAddr)(e).String() +} + +func (e *NativeEndpoint) SrcToString() string { + return "" +} + +func listenNet(network string, port int) (*net.UDPConn, int, error) { + + // listen + + conn, err := net.ListenUDP(network, &net.UDPAddr{Port: port}) + if err != nil { + return nil, 0, err + } + + // retrieve port + + laddr := conn.LocalAddr() + uaddr, err := net.ResolveUDPAddr( + laddr.Network(), + laddr.String(), + ) + if err != nil { + return nil, 0, err + } + return conn, uaddr.Port, nil +} + +func CreateBind(uport uint16) (Bind, uint16, error) { + var err error + var bind NativeBind + + port := int(uport) + + bind.ipv4, port, err = listenNet("udp4", port) + if err != nil { + return nil, 0, err + } + + bind.ipv6, port, err = listenNet("udp6", port) + if err != nil { + bind.ipv4.Close() + return nil, 0, err + } + + return &bind, uint16(port), nil +} + +func (bind *NativeBind) Close() error { + err1 := bind.ipv4.Close() + err2 := bind.ipv6.Close() + if err1 != nil { + return err1 + } + return err2 +} + +func (bind *NativeBind) ReceiveIPv4(buff []byte) (int, Endpoint, error) { + n, endpoint, err := bind.ipv4.ReadFromUDP(buff) + return n, (*NativeEndpoint)(endpoint), err +} + +func (bind *NativeBind) ReceiveIPv6(buff []byte) (int, Endpoint, error) { + n, endpoint, err := bind.ipv6.ReadFromUDP(buff) + return n, (*NativeEndpoint)(endpoint), err +} + +func (bind *NativeBind) Send(buff []byte, endpoint Endpoint) error { + var err error + nend := endpoint.(*NativeEndpoint) + if nend.IP.To16() != nil { + _, err = bind.ipv6.WriteToUDP(buff, (*net.UDPAddr)(nend)) + } else { + _, err = bind.ipv4.WriteToUDP(buff, (*net.UDPAddr)(nend)) + } + return err +} + +func (bind *NativeBind) SetMark(_ uint32) error { + return nil +} diff --git a/conn_linux.go b/conn_linux.go new file mode 100644 index 0000000..cdba74f --- /dev/null +++ b/conn_linux.go @@ -0,0 +1,582 @@ +/* Copyright 2017 Jason A. Donenfeld . All Rights Reserved. + * + * This implements userspace semantics of "sticky sockets", modeled after + * WireGuard's kernelspace implementation. + */ + +package main + +import ( + "encoding/binary" + "errors" + "golang.org/x/sys/unix" + "net" + "strconv" + "unsafe" +) + +/* Supports source address caching + * + * Currently there is no way to achieve this within the net package: + * See e.g. https://github.com/golang/go/issues/17930 + * So this code is remains platform dependent. + */ +type NativeEndpoint struct { + src unix.RawSockaddrInet6 + dst unix.RawSockaddrInet6 +} + +type NativeBind struct { + sock4 int + sock6 int +} + +var _ Endpoint = (*NativeEndpoint)(nil) +var _ Bind = NativeBind{} + +type IPv4Source struct { + src unix.RawSockaddrInet4 + Ifindex int32 +} + +func htons(val uint16) uint16 { + var out [unsafe.Sizeof(val)]byte + binary.BigEndian.PutUint16(out[:], val) + return *((*uint16)(unsafe.Pointer(&out[0]))) +} + +func ntohs(val uint16) uint16 { + tmp := ((*[unsafe.Sizeof(val)]byte)(unsafe.Pointer(&val))) + return binary.BigEndian.Uint16((*tmp)[:]) +} + +func CreateEndpoint(s string) (Endpoint, error) { + var end NativeEndpoint + addr, err := parseEndpoint(s) + if err != nil { + return nil, err + } + + ipv4 := addr.IP.To4() + if ipv4 != nil { + dst := (*unix.RawSockaddrInet4)(unsafe.Pointer(&end.dst)) + dst.Family = unix.AF_INET + dst.Port = htons(uint16(addr.Port)) + dst.Zero = [8]byte{} + copy(dst.Addr[:], ipv4) + end.ClearSrc() + return &end, nil + } + + ipv6 := addr.IP.To16() + if ipv6 != nil { + zone, err := zoneToUint32(addr.Zone) + if err != nil { + return nil, err + } + dst := &end.dst + dst.Family = unix.AF_INET6 + dst.Port = htons(uint16(addr.Port)) + dst.Flowinfo = 0 + dst.Scope_id = zone + copy(dst.Addr[:], ipv6[:]) + end.ClearSrc() + return &end, nil + } + + return nil, errors.New("Failed to recognize IP address format") +} + +func CreateBind(port uint16) (Bind, uint16, error) { + var err error + var bind NativeBind + + bind.sock6, port, err = create6(port) + if err != nil { + return nil, port, err + } + + bind.sock4, port, err = create4(port) + if err != nil { + unix.Close(bind.sock6) + } + return bind, port, err +} + +func (bind NativeBind) SetMark(value uint32) error { + err := unix.SetsockoptInt( + bind.sock6, + unix.SOL_SOCKET, + unix.SO_MARK, + int(value), + ) + + if err != nil { + return err + } + + return unix.SetsockoptInt( + bind.sock4, + unix.SOL_SOCKET, + unix.SO_MARK, + int(value), + ) +} + +func closeUnblock(fd int) error { + // shutdown to unblock readers + unix.Shutdown(fd, unix.SHUT_RD) + return unix.Close(fd) +} + +func (bind NativeBind) Close() error { + err1 := closeUnblock(bind.sock6) + err2 := closeUnblock(bind.sock4) + if err1 != nil { + return err1 + } + return err2 +} + +func (bind NativeBind) ReceiveIPv6(buff []byte) (int, Endpoint, error) { + var end NativeEndpoint + n, err := receive6( + bind.sock6, + buff, + &end, + ) + return n, &end, err +} + +func (bind NativeBind) ReceiveIPv4(buff []byte) (int, Endpoint, error) { + var end NativeEndpoint + n, err := receive4( + bind.sock4, + buff, + &end, + ) + return n, &end, err +} + +func (bind NativeBind) Send(buff []byte, end Endpoint) error { + nend := end.(*NativeEndpoint) + switch nend.dst.Family { + case unix.AF_INET6: + return send6(bind.sock6, nend, buff) + case unix.AF_INET: + return send4(bind.sock4, nend, buff) + default: + return errors.New("Unknown address family of destination") + } +} + +func sockaddrToString(addr unix.RawSockaddrInet6) string { + var udpAddr net.UDPAddr + + switch addr.Family { + case unix.AF_INET6: + udpAddr.Port = int(ntohs(addr.Port)) + udpAddr.IP = addr.Addr[:] + return udpAddr.String() + + case unix.AF_INET: + ptr := (*unix.RawSockaddrInet4)(unsafe.Pointer(&addr)) + udpAddr.Port = int(ntohs(ptr.Port)) + udpAddr.IP = net.IPv4( + ptr.Addr[0], + ptr.Addr[1], + ptr.Addr[2], + ptr.Addr[3], + ) + return udpAddr.String() + + default: + return "" + } +} + +func rawAddrToIP(addr unix.RawSockaddrInet6) net.IP { + switch addr.Family { + case unix.AF_INET6: + return addr.Addr[:] + case unix.AF_INET: + ptr := (*unix.RawSockaddrInet4)(unsafe.Pointer(&addr)) + return net.IPv4( + ptr.Addr[0], + ptr.Addr[1], + ptr.Addr[2], + ptr.Addr[3], + ) + default: + return nil + } +} + +func (end *NativeEndpoint) SrcIP() net.IP { + return rawAddrToIP(end.src) +} + +func (end *NativeEndpoint) DstIP() net.IP { + return rawAddrToIP(end.dst) +} + +func (end *NativeEndpoint) DstToBytes() []byte { + ptr := unsafe.Pointer(&end.src) + arr := (*[unix.SizeofSockaddrInet6]byte)(ptr) + return arr[:] +} + +func (end *NativeEndpoint) SrcToString() string { + return sockaddrToString(end.src) +} + +func (end *NativeEndpoint) DstToString() string { + return sockaddrToString(end.dst) +} + +func (end *NativeEndpoint) ClearDst() { + end.dst = unix.RawSockaddrInet6{} +} + +func (end *NativeEndpoint) ClearSrc() { + end.src = unix.RawSockaddrInet6{} +} + +func zoneToUint32(zone string) (uint32, error) { + if zone == "" { + return 0, nil + } + if intr, err := net.InterfaceByName(zone); err == nil { + return uint32(intr.Index), nil + } + n, err := strconv.ParseUint(zone, 10, 32) + return uint32(n), err +} + +func create4(port uint16) (int, uint16, error) { + + // create socket + + fd, err := unix.Socket( + unix.AF_INET, + unix.SOCK_DGRAM, + 0, + ) + + if err != nil { + return -1, 0, err + } + + addr := unix.SockaddrInet4{ + Port: int(port), + } + + // set sockopts and bind + + if err := func() error { + if err := unix.SetsockoptInt( + fd, + unix.SOL_SOCKET, + unix.SO_REUSEADDR, + 1, + ); err != nil { + return err + } + + if err := unix.SetsockoptInt( + fd, + unix.IPPROTO_IP, + unix.IP_PKTINFO, + 1, + ); err != nil { + return err + } + + return unix.Bind(fd, &addr) + }(); err != nil { + unix.Close(fd) + } + + return fd, uint16(addr.Port), err +} + +func create6(port uint16) (int, uint16, error) { + + // create socket + + fd, err := unix.Socket( + unix.AF_INET6, + unix.SOCK_DGRAM, + 0, + ) + + if err != nil { + return -1, 0, err + } + + // set sockopts and bind + + addr := unix.SockaddrInet6{ + Port: int(port), + } + + if err := func() error { + + if err := unix.SetsockoptInt( + fd, + unix.SOL_SOCKET, + unix.SO_REUSEADDR, + 1, + ); err != nil { + return err + } + + if err := unix.SetsockoptInt( + fd, + unix.IPPROTO_IPV6, + unix.IPV6_RECVPKTINFO, + 1, + ); err != nil { + return err + } + + if err := unix.SetsockoptInt( + fd, + unix.IPPROTO_IPV6, + unix.IPV6_V6ONLY, + 1, + ); err != nil { + return err + } + + return unix.Bind(fd, &addr) + + }(); err != nil { + unix.Close(fd) + } + + return fd, uint16(addr.Port), err +} + +func send6(sock int, end *NativeEndpoint, buff []byte) error { + + // construct message header + + var iovec unix.Iovec + iovec.Base = (*byte)(unsafe.Pointer(&buff[0])) + iovec.SetLen(len(buff)) + + cmsg := struct { + cmsghdr unix.Cmsghdr + pktinfo unix.Inet6Pktinfo + }{ + unix.Cmsghdr{ + Level: unix.IPPROTO_IPV6, + Type: unix.IPV6_PKTINFO, + Len: unix.SizeofInet6Pktinfo + unix.SizeofCmsghdr, + }, + unix.Inet6Pktinfo{ + Addr: end.src.Addr, + Ifindex: end.src.Scope_id, + }, + } + + msghdr := unix.Msghdr{ + Iov: &iovec, + Iovlen: 1, + Name: (*byte)(unsafe.Pointer(&end.dst)), + Namelen: unix.SizeofSockaddrInet6, + Control: (*byte)(unsafe.Pointer(&cmsg)), + } + + msghdr.SetControllen(int(unsafe.Sizeof(cmsg))) + + // sendmsg(sock, &msghdr, 0) + + _, _, errno := unix.Syscall( + unix.SYS_SENDMSG, + uintptr(sock), + uintptr(unsafe.Pointer(&msghdr)), + 0, + ) + + if errno == 0 { + return nil + } + + // clear src and retry + + if errno == unix.EINVAL { + end.ClearSrc() + cmsg.pktinfo = unix.Inet6Pktinfo{} + _, _, errno = unix.Syscall( + unix.SYS_SENDMSG, + uintptr(sock), + uintptr(unsafe.Pointer(&msghdr)), + 0, + ) + } + + return errno +} + +func send4(sock int, end *NativeEndpoint, buff []byte) error { + + // construct message header + + var iovec unix.Iovec + iovec.Base = (*byte)(unsafe.Pointer(&buff[0])) + iovec.SetLen(len(buff)) + + src4 := (*IPv4Source)(unsafe.Pointer(&end.src)) + + cmsg := struct { + cmsghdr unix.Cmsghdr + pktinfo unix.Inet4Pktinfo + }{ + unix.Cmsghdr{ + Level: unix.IPPROTO_IP, + Type: unix.IP_PKTINFO, + Len: unix.SizeofInet4Pktinfo + unix.SizeofCmsghdr, + }, + unix.Inet4Pktinfo{ + Spec_dst: src4.src.Addr, + Ifindex: src4.Ifindex, + }, + } + + msghdr := unix.Msghdr{ + Iov: &iovec, + Iovlen: 1, + Name: (*byte)(unsafe.Pointer(&end.dst)), + Namelen: unix.SizeofSockaddrInet4, + Control: (*byte)(unsafe.Pointer(&cmsg)), + Flags: 0, + } + msghdr.SetControllen(int(unsafe.Sizeof(cmsg))) + + // sendmsg(sock, &msghdr, 0) + + _, _, errno := unix.Syscall( + unix.SYS_SENDMSG, + uintptr(sock), + uintptr(unsafe.Pointer(&msghdr)), + 0, + ) + + // clear source and try again + + if errno == unix.EINVAL { + end.ClearSrc() + cmsg.pktinfo = unix.Inet4Pktinfo{} + _, _, errno = unix.Syscall( + unix.SYS_SENDMSG, + uintptr(sock), + uintptr(unsafe.Pointer(&msghdr)), + 0, + ) + } + + // errno = 0 is still an error instance + + if errno == 0 { + return nil + } + + return errno +} + +func receive4(sock int, buff []byte, end *NativeEndpoint) (int, error) { + + // contruct message header + + var iovec unix.Iovec + iovec.Base = (*byte)(unsafe.Pointer(&buff[0])) + iovec.SetLen(len(buff)) + + var cmsg struct { + cmsghdr unix.Cmsghdr + pktinfo unix.Inet4Pktinfo + } + + var msghdr unix.Msghdr + msghdr.Iov = &iovec + msghdr.Iovlen = 1 + msghdr.Name = (*byte)(unsafe.Pointer(&end.dst)) + msghdr.Namelen = unix.SizeofSockaddrInet4 + msghdr.Control = (*byte)(unsafe.Pointer(&cmsg)) + msghdr.SetControllen(int(unsafe.Sizeof(cmsg))) + + // recvmsg(sock, &mskhdr, 0) + + size, _, errno := unix.Syscall( + unix.SYS_RECVMSG, + uintptr(sock), + uintptr(unsafe.Pointer(&msghdr)), + 0, + ) + + if errno != 0 { + return 0, errno + } + + // update source cache + + if cmsg.cmsghdr.Level == unix.IPPROTO_IP && + cmsg.cmsghdr.Type == unix.IP_PKTINFO && + cmsg.cmsghdr.Len >= unix.SizeofInet4Pktinfo { + src4 := (*IPv4Source)(unsafe.Pointer(&end.src)) + src4.src.Family = unix.AF_INET + src4.src.Addr = cmsg.pktinfo.Spec_dst + src4.Ifindex = cmsg.pktinfo.Ifindex + } + + return int(size), nil +} + +func receive6(sock int, buff []byte, end *NativeEndpoint) (int, error) { + + // contruct message header + + var iovec unix.Iovec + iovec.Base = (*byte)(unsafe.Pointer(&buff[0])) + iovec.SetLen(len(buff)) + + var cmsg struct { + cmsghdr unix.Cmsghdr + pktinfo unix.Inet6Pktinfo + } + + var msg unix.Msghdr + msg.Iov = &iovec + msg.Iovlen = 1 + msg.Name = (*byte)(unsafe.Pointer(&end.dst)) + msg.Namelen = uint32(unix.SizeofSockaddrInet6) + msg.Control = (*byte)(unsafe.Pointer(&cmsg)) + msg.SetControllen(int(unsafe.Sizeof(cmsg))) + + // recvmsg(sock, &mskhdr, 0) + + size, _, errno := unix.Syscall( + unix.SYS_RECVMSG, + uintptr(sock), + uintptr(unsafe.Pointer(&msg)), + 0, + ) + + if errno != 0 { + return 0, errno + } + + // update source cache + + if cmsg.cmsghdr.Level == unix.IPPROTO_IPV6 && + cmsg.cmsghdr.Type == unix.IPV6_PKTINFO && + cmsg.cmsghdr.Len >= unix.SizeofInet6Pktinfo { + end.src.Family = unix.AF_INET6 + end.src.Addr = cmsg.pktinfo.Addr + end.src.Scope_id = cmsg.pktinfo.Ifindex + } + + return int(size), nil +} diff --git a/constants.go b/constants.go new file mode 100644 index 0000000..71dd98e --- /dev/null +++ b/constants.go @@ -0,0 +1,43 @@ +package main + +import ( + "time" +) + +/* Specification constants */ + +const ( + RekeyAfterMessages = (1 << 64) - (1 << 16) - 1 + RejectAfterMessages = (1 << 64) - (1 << 4) - 1 + RekeyAfterTime = time.Second * 120 + RekeyAttemptTime = time.Second * 90 + RekeyTimeout = time.Second * 5 + RejectAfterTime = time.Second * 180 + KeepaliveTimeout = time.Second * 10 + CookieRefreshTime = time.Second * 120 + HandshakeInitationRate = time.Second / 20 + PaddingMultiple = 16 +) + +const ( + RekeyAfterTimeReceiving = RekeyAfterTime - KeepaliveTimeout - RekeyTimeout + NewHandshakeTime = KeepaliveTimeout + RekeyTimeout // upon failure to acknowledge transport message +) + +/* Implementation specific constants */ + +const ( + QueueOutboundSize = 1024 + QueueInboundSize = 1024 + QueueHandshakeSize = 1024 + MaxSegmentSize = (1 << 16) - 1 // largest possible UDP datagram + MinMessageSize = MessageKeepaliveSize // minimum size of transport message (keepalive) + MaxMessageSize = MaxSegmentSize // maximum size of transport message + MaxContentSize = MaxSegmentSize - MessageTransportSize // maximum size of transport message content +) + +const ( + UnderLoadQueueSize = QueueHandshakeSize / 8 + UnderLoadAfterTime = time.Second // how long does the device remain under load after detected + MaxPeers = 1 << 16 // maximum number of configured peers +) diff --git a/cookie.go b/cookie.go new file mode 100644 index 0000000..a13ad49 --- /dev/null +++ b/cookie.go @@ -0,0 +1,252 @@ +package main + +import ( + "crypto/hmac" + "crypto/rand" + "golang.org/x/crypto/blake2s" + "golang.org/x/crypto/chacha20poly1305" + "sync" + "time" +) + +type CookieChecker struct { + mutex sync.RWMutex + mac1 struct { + key [blake2s.Size]byte + } + mac2 struct { + secret [blake2s.Size]byte + secretSet time.Time + encryptionKey [chacha20poly1305.KeySize]byte + } +} + +type CookieGenerator struct { + mutex sync.RWMutex + mac1 struct { + key [blake2s.Size]byte + } + mac2 struct { + cookie [blake2s.Size128]byte + cookieSet time.Time + hasLastMAC1 bool + lastMAC1 [blake2s.Size128]byte + encryptionKey [chacha20poly1305.KeySize]byte + } +} + +func (st *CookieChecker) Init(pk NoisePublicKey) { + st.mutex.Lock() + defer st.mutex.Unlock() + + // mac1 state + + func() { + hsh, _ := blake2s.New256(nil) + hsh.Write([]byte(WGLabelMAC1)) + hsh.Write(pk[:]) + hsh.Sum(st.mac1.key[:0]) + }() + + // mac2 state + + func() { + hsh, _ := blake2s.New256(nil) + hsh.Write([]byte(WGLabelCookie)) + hsh.Write(pk[:]) + hsh.Sum(st.mac2.encryptionKey[:0]) + }() + + st.mac2.secretSet = time.Time{} +} + +func (st *CookieChecker) CheckMAC1(msg []byte) bool { + size := len(msg) + smac2 := size - blake2s.Size128 + smac1 := smac2 - blake2s.Size128 + + var mac1 [blake2s.Size128]byte + + mac, _ := blake2s.New128(st.mac1.key[:]) + mac.Write(msg[:smac1]) + mac.Sum(mac1[:0]) + + return hmac.Equal(mac1[:], msg[smac1:smac2]) +} + +func (st *CookieChecker) CheckMAC2(msg []byte, src []byte) bool { + st.mutex.RLock() + defer st.mutex.RUnlock() + + if time.Now().Sub(st.mac2.secretSet) > CookieRefreshTime { + return false + } + + // derive cookie key + + var cookie [blake2s.Size128]byte + func() { + mac, _ := blake2s.New128(st.mac2.secret[:]) + mac.Write(src) + mac.Sum(cookie[:0]) + }() + + // calculate mac of packet (including mac1) + + smac2 := len(msg) - blake2s.Size128 + + var mac2 [blake2s.Size128]byte + func() { + mac, _ := blake2s.New128(cookie[:]) + mac.Write(msg[:smac2]) + mac.Sum(mac2[:0]) + }() + + return hmac.Equal(mac2[:], msg[smac2:]) +} + +func (st *CookieChecker) CreateReply( + msg []byte, + recv uint32, + src []byte, +) (*MessageCookieReply, error) { + + st.mutex.RLock() + + // refresh cookie secret + + if time.Now().Sub(st.mac2.secretSet) > CookieRefreshTime { + st.mutex.RUnlock() + st.mutex.Lock() + _, err := rand.Read(st.mac2.secret[:]) + if err != nil { + st.mutex.Unlock() + return nil, err + } + st.mac2.secretSet = time.Now() + st.mutex.Unlock() + st.mutex.RLock() + } + + // derive cookie + + var cookie [blake2s.Size128]byte + func() { + mac, _ := blake2s.New128(st.mac2.secret[:]) + mac.Write(src) + mac.Sum(cookie[:0]) + }() + + // encrypt cookie + + size := len(msg) + + smac2 := size - blake2s.Size128 + smac1 := smac2 - blake2s.Size128 + + reply := new(MessageCookieReply) + reply.Type = MessageCookieReplyType + reply.Receiver = recv + + _, err := rand.Read(reply.Nonce[:]) + if err != nil { + st.mutex.RUnlock() + return nil, err + } + + XChaCha20Poly1305Encrypt( + reply.Cookie[:0], + &reply.Nonce, + cookie[:], + msg[smac1:smac2], + &st.mac2.encryptionKey, + ) + + st.mutex.RUnlock() + + return reply, nil +} + +func (st *CookieGenerator) Init(pk NoisePublicKey) { + st.mutex.Lock() + defer st.mutex.Unlock() + + func() { + hsh, _ := blake2s.New256(nil) + hsh.Write([]byte(WGLabelMAC1)) + hsh.Write(pk[:]) + hsh.Sum(st.mac1.key[:0]) + }() + + func() { + hsh, _ := blake2s.New256(nil) + hsh.Write([]byte(WGLabelCookie)) + hsh.Write(pk[:]) + hsh.Sum(st.mac2.encryptionKey[:0]) + }() + + st.mac2.cookieSet = time.Time{} +} + +func (st *CookieGenerator) ConsumeReply(msg *MessageCookieReply) bool { + st.mutex.Lock() + defer st.mutex.Unlock() + + if !st.mac2.hasLastMAC1 { + return false + } + + var cookie [blake2s.Size128]byte + + _, err := XChaCha20Poly1305Decrypt( + cookie[:0], + &msg.Nonce, + msg.Cookie[:], + st.mac2.lastMAC1[:], + &st.mac2.encryptionKey, + ) + + if err != nil { + return false + } + + st.mac2.cookieSet = time.Now() + st.mac2.cookie = cookie + return true +} + +func (st *CookieGenerator) AddMacs(msg []byte) { + + size := len(msg) + + smac2 := size - blake2s.Size128 + smac1 := smac2 - blake2s.Size128 + + mac1 := msg[smac1:smac2] + mac2 := msg[smac2:] + + st.mutex.Lock() + defer st.mutex.Unlock() + + // set mac1 + + func() { + mac, _ := blake2s.New128(st.mac1.key[:]) + mac.Write(msg[:smac1]) + mac.Sum(mac1[:0]) + }() + copy(st.mac2.lastMAC1[:], mac1) + st.mac2.hasLastMAC1 = true + + // set mac2 + + if time.Now().Sub(st.mac2.cookieSet) > CookieRefreshTime { + return + } + + func() { + mac, _ := blake2s.New128(st.mac2.cookie[:]) + mac.Write(msg[:smac2]) + mac.Sum(mac2[:0]) + }() +} diff --git a/cookie_test.go b/cookie_test.go new file mode 100644 index 0000000..d745fe7 --- /dev/null +++ b/cookie_test.go @@ -0,0 +1,186 @@ +package main + +import ( + "testing" +) + +func TestCookieMAC1(t *testing.T) { + + // setup generator / checker + + var ( + generator CookieGenerator + checker CookieChecker + ) + + sk, err := newPrivateKey() + if err != nil { + t.Fatal(err) + } + pk := sk.publicKey() + + generator.Init(pk) + checker.Init(pk) + + // check mac1 + + src := []byte{192, 168, 13, 37, 10, 10, 10} + + checkMAC1 := func(msg []byte) { + generator.AddMacs(msg) + if !checker.CheckMAC1(msg) { + t.Fatal("MAC1 generation/verification failed") + } + if checker.CheckMAC2(msg, src) { + t.Fatal("MAC2 generation/verification failed") + } + } + + checkMAC1([]byte{ + 0x99, 0xbb, 0xa5, 0xfc, 0x99, 0xaa, 0x83, 0xbd, + 0x7b, 0x00, 0xc5, 0x9a, 0x4c, 0xb9, 0xcf, 0x62, + 0x40, 0x23, 0xf3, 0x8e, 0xd8, 0xd0, 0x62, 0x64, + 0x5d, 0xb2, 0x80, 0x13, 0xda, 0xce, 0xc6, 0x91, + 0x61, 0xd6, 0x30, 0xf1, 0x32, 0xb3, 0xa2, 0xf4, + 0x7b, 0x43, 0xb5, 0xa7, 0xe2, 0xb1, 0xf5, 0x6c, + 0x74, 0x6b, 0xb0, 0xcd, 0x1f, 0x94, 0x86, 0x7b, + 0xc8, 0xfb, 0x92, 0xed, 0x54, 0x9b, 0x44, 0xf5, + 0xc8, 0x7d, 0xb7, 0x8e, 0xff, 0x49, 0xc4, 0xe8, + 0x39, 0x7c, 0x19, 0xe0, 0x60, 0x19, 0x51, 0xf8, + 0xe4, 0x8e, 0x02, 0xf1, 0x7f, 0x1d, 0xcc, 0x8e, + 0xb0, 0x07, 0xff, 0xf8, 0xaf, 0x7f, 0x66, 0x82, + 0x83, 0xcc, 0x7c, 0xfa, 0x80, 0xdb, 0x81, 0x53, + 0xad, 0xf7, 0xd8, 0x0c, 0x10, 0xe0, 0x20, 0xfd, + 0xe8, 0x0b, 0x3f, 0x90, 0x15, 0xcd, 0x93, 0xad, + 0x0b, 0xd5, 0x0c, 0xcc, 0x88, 0x56, 0xe4, 0x3f, + }) + + checkMAC1([]byte{ + 0x33, 0xe7, 0x2a, 0x84, 0x9f, 0xff, 0x57, 0x6c, + 0x2d, 0xc3, 0x2d, 0xe1, 0xf5, 0x5c, 0x97, 0x56, + 0xb8, 0x93, 0xc2, 0x7d, 0xd4, 0x41, 0xdd, 0x7a, + 0x4a, 0x59, 0x3b, 0x50, 0xdd, 0x7a, 0x7a, 0x8c, + 0x9b, 0x96, 0xaf, 0x55, 0x3c, 0xeb, 0x6d, 0x0b, + 0x13, 0x0b, 0x97, 0x98, 0xb3, 0x40, 0xc3, 0xcc, + 0xb8, 0x57, 0x33, 0x45, 0x6e, 0x8b, 0x09, 0x2b, + 0x81, 0x2e, 0xd2, 0xb9, 0x66, 0x0b, 0x93, 0x05, + }) + + checkMAC1([]byte{ + 0x9b, 0x96, 0xaf, 0x55, 0x3c, 0xeb, 0x6d, 0x0b, + 0x13, 0x0b, 0x97, 0x98, 0xb3, 0x40, 0xc3, 0xcc, + 0xb8, 0x57, 0x33, 0x45, 0x6e, 0x8b, 0x09, 0x2b, + 0x81, 0x2e, 0xd2, 0xb9, 0x66, 0x0b, 0x93, 0x05, + }) + + // exchange cookie reply + + func() { + msg := []byte{ + 0x6d, 0xd7, 0xc3, 0x2e, 0xb0, 0x76, 0xd8, 0xdf, + 0x30, 0x65, 0x7d, 0x62, 0x3e, 0xf8, 0x9a, 0xe8, + 0xe7, 0x3c, 0x64, 0xa3, 0x78, 0x48, 0xda, 0xf5, + 0x25, 0x61, 0x28, 0x53, 0x79, 0x32, 0x86, 0x9f, + 0xa0, 0x27, 0x95, 0x69, 0xb6, 0xba, 0xd0, 0xa2, + 0xf8, 0x68, 0xea, 0xa8, 0x62, 0xf2, 0xfd, 0x1b, + 0xe0, 0xb4, 0x80, 0xe5, 0x6b, 0x3a, 0x16, 0x9e, + 0x35, 0xf6, 0xa8, 0xf2, 0x4f, 0x9a, 0x7b, 0xe9, + 0x77, 0x0b, 0xc2, 0xb4, 0xed, 0xba, 0xf9, 0x22, + 0xc3, 0x03, 0x97, 0x42, 0x9f, 0x79, 0x74, 0x27, + 0xfe, 0xf9, 0x06, 0x6e, 0x97, 0x3a, 0xa6, 0x8f, + 0xc9, 0x57, 0x0a, 0x54, 0x4c, 0x64, 0x4a, 0xe2, + 0x4f, 0xa1, 0xce, 0x95, 0x9b, 0x23, 0xa9, 0x2b, + 0x85, 0x93, 0x42, 0xb0, 0xa5, 0x53, 0xed, 0xeb, + 0x63, 0x2a, 0xf1, 0x6d, 0x46, 0xcb, 0x2f, 0x61, + 0x8c, 0xe1, 0xe8, 0xfa, 0x67, 0x20, 0x80, 0x6d, + } + generator.AddMacs(msg) + reply, err := checker.CreateReply(msg, 1377, src) + if err != nil { + t.Fatal("Failed to create cookie reply:", err) + } + if !generator.ConsumeReply(reply) { + t.Fatal("Failed to consume cookie reply") + } + }() + + // check mac2 + + checkMAC2 := func(msg []byte) { + generator.AddMacs(msg) + + if !checker.CheckMAC1(msg) { + t.Fatal("MAC1 generation/verification failed") + } + if !checker.CheckMAC2(msg, src) { + t.Fatal("MAC2 generation/verification failed") + } + + msg[5] ^= 0x20 + + if checker.CheckMAC1(msg) { + t.Fatal("MAC1 generation/verification failed") + } + if checker.CheckMAC2(msg, src) { + t.Fatal("MAC2 generation/verification failed") + } + + msg[5] ^= 0x20 + + srcBad1 := []byte{192, 168, 13, 37, 40, 01} + if checker.CheckMAC2(msg, srcBad1) { + t.Fatal("MAC2 generation/verification failed") + } + + srcBad2 := []byte{192, 168, 13, 38, 40, 01} + if checker.CheckMAC2(msg, srcBad2) { + t.Fatal("MAC2 generation/verification failed") + } + } + + checkMAC2([]byte{ + 0x03, 0x31, 0xb9, 0x9e, 0xb0, 0x2a, 0x54, 0xa3, + 0xc1, 0x3f, 0xb4, 0x96, 0x16, 0xb9, 0x25, 0x15, + 0x3d, 0x3a, 0x82, 0xf9, 0x58, 0x36, 0x86, 0x3f, + 0x13, 0x2f, 0xfe, 0xb2, 0x53, 0x20, 0x8c, 0x3f, + 0xba, 0xeb, 0xfb, 0x4b, 0x1b, 0x22, 0x02, 0x69, + 0x2c, 0x90, 0xbc, 0xdc, 0xcf, 0xcf, 0x85, 0xeb, + 0x62, 0x66, 0x6f, 0xe8, 0xe1, 0xa6, 0xa8, 0x4c, + 0xa0, 0x04, 0x23, 0x15, 0x42, 0xac, 0xfa, 0x38, + }) + + checkMAC2([]byte{ + 0x0e, 0x2f, 0x0e, 0xa9, 0x29, 0x03, 0xe1, 0xf3, + 0x24, 0x01, 0x75, 0xad, 0x16, 0xa5, 0x66, 0x85, + 0xca, 0x66, 0xe0, 0xbd, 0xc6, 0x34, 0xd8, 0x84, + 0x09, 0x9a, 0x58, 0x14, 0xfb, 0x05, 0xda, 0xf5, + 0x90, 0xf5, 0x0c, 0x4e, 0x22, 0x10, 0xc9, 0x85, + 0x0f, 0xe3, 0x77, 0x35, 0xe9, 0x6b, 0xc2, 0x55, + 0x32, 0x46, 0xae, 0x25, 0xe0, 0xe3, 0x37, 0x7a, + 0x4b, 0x71, 0xcc, 0xfc, 0x91, 0xdf, 0xd6, 0xca, + 0xfe, 0xee, 0xce, 0x3f, 0x77, 0xa2, 0xfd, 0x59, + 0x8e, 0x73, 0x0a, 0x8d, 0x5c, 0x24, 0x14, 0xca, + 0x38, 0x91, 0xb8, 0x2c, 0x8c, 0xa2, 0x65, 0x7b, + 0xbc, 0x49, 0xbc, 0xb5, 0x58, 0xfc, 0xe3, 0xd7, + 0x02, 0xcf, 0xf7, 0x4c, 0x60, 0x91, 0xed, 0x55, + 0xe9, 0xf9, 0xfe, 0xd1, 0x44, 0x2c, 0x75, 0xf2, + 0xb3, 0x5d, 0x7b, 0x27, 0x56, 0xc0, 0x48, 0x4f, + 0xb0, 0xba, 0xe4, 0x7d, 0xd0, 0xaa, 0xcd, 0x3d, + 0xe3, 0x50, 0xd2, 0xcf, 0xb9, 0xfa, 0x4b, 0x2d, + 0xc6, 0xdf, 0x3b, 0x32, 0x98, 0x45, 0xe6, 0x8f, + 0x1c, 0x5c, 0xa2, 0x20, 0x7d, 0x1c, 0x28, 0xc2, + 0xd4, 0xa1, 0xe0, 0x21, 0x52, 0x8f, 0x1c, 0xd0, + 0x62, 0x97, 0x48, 0xbb, 0xf4, 0xa9, 0xcb, 0x35, + 0xf2, 0x07, 0xd3, 0x50, 0xd8, 0xa9, 0xc5, 0x9a, + 0x0f, 0xbd, 0x37, 0xaf, 0xe1, 0x45, 0x19, 0xee, + 0x41, 0xf3, 0xf7, 0xe5, 0xe0, 0x30, 0x3f, 0xbe, + 0x3d, 0x39, 0x64, 0x00, 0x7a, 0x1a, 0x51, 0x5e, + 0xe1, 0x70, 0x0b, 0xb9, 0x77, 0x5a, 0xf0, 0xc4, + 0x8a, 0xa1, 0x3a, 0x77, 0x1a, 0xe0, 0xc2, 0x06, + 0x91, 0xd5, 0xe9, 0x1c, 0xd3, 0xfe, 0xab, 0x93, + 0x1a, 0x0a, 0x4c, 0xbb, 0xf0, 0xff, 0xdc, 0xaa, + 0x61, 0x73, 0xcb, 0x03, 0x4b, 0x71, 0x68, 0x64, + 0x3d, 0x82, 0x31, 0x41, 0xd7, 0x8b, 0x22, 0x7b, + 0x7d, 0xa1, 0xd5, 0x85, 0x6d, 0xf0, 0x1b, 0xaa, + }) +} diff --git a/daemon_darwin.go b/daemon_darwin.go new file mode 100644 index 0000000..913af0e --- /dev/null +++ b/daemon_darwin.go @@ -0,0 +1,9 @@ +package main + +import ( + "errors" +) + +func Daemonize() error { + return errors.New("Not implemented on OSX") +} diff --git a/daemon_linux.go b/daemon_linux.go new file mode 100644 index 0000000..e1aaede --- /dev/null +++ b/daemon_linux.go @@ -0,0 +1,32 @@ +package main + +import ( + "os" + "os/exec" +) + +/* Daemonizes the process on linux + * + * This is done by spawning and releasing a copy with the --foreground flag + */ +func Daemonize(attr *os.ProcAttr) error { + // I would like to use os.Executable, + // however this means dropping support for Go <1.8 + path, err := exec.LookPath(os.Args[0]) + if err != nil { + return err + } + + argv := []string{os.Args[0], "--foreground"} + argv = append(argv, os.Args[1:]...) + process, err := os.StartProcess( + path, + argv, + attr, + ) + if err != nil { + return err + } + process.Release() + return nil +} diff --git a/daemon_windows.go b/daemon_windows.go new file mode 100644 index 0000000..d5ec1e8 --- /dev/null +++ b/daemon_windows.go @@ -0,0 +1,34 @@ +package main + +import ( + "os" +) + +/* Daemonizes the process on windows + * + * This is done by spawning and releasing a copy with the --foreground flag + */ + +func Daemonize() error { + argv := []string{os.Args[0], "--foreground"} + argv = append(argv, os.Args[1:]...) + attr := &os.ProcAttr{ + Dir: ".", + Env: os.Environ(), + Files: []*os.File{ + os.Stdin, + nil, + nil, + }, + } + process, err := os.StartProcess( + argv[0], + argv, + attr, + ) + if err != nil { + return err + } + process.Release() + return nil +} diff --git a/device.go b/device.go new file mode 100644 index 0000000..c041987 --- /dev/null +++ b/device.go @@ -0,0 +1,372 @@ +package main + +import ( + "github.com/sasha-s/go-deadlock" + "runtime" + "sync" + "sync/atomic" + "time" +) + +type Device struct { + isUp AtomicBool // device is (going) up + isClosed AtomicBool // device is closed? (acting as guard) + log *Logger + + // synchronized resources (locks acquired in order) + + state struct { + mutex deadlock.Mutex + changing AtomicBool + current bool + } + + net struct { + mutex deadlock.RWMutex + bind Bind // bind interface + port uint16 // listening port + fwmark uint32 // mark value (0 = disabled) + } + + noise struct { + mutex deadlock.RWMutex + privateKey NoisePrivateKey + publicKey NoisePublicKey + } + + routing struct { + mutex deadlock.RWMutex + table RoutingTable + } + + peers struct { + mutex deadlock.RWMutex + keyMap map[NoisePublicKey]*Peer + } + + // unprotected / "self-synchronising resources" + + indices IndexTable + mac CookieChecker + + rate struct { + underLoadUntil atomic.Value + limiter Ratelimiter + } + + pool struct { + messageBuffers sync.Pool + } + + queue struct { + encryption chan *QueueOutboundElement + decryption chan *QueueInboundElement + handshake chan QueueHandshakeElement + } + + signal struct { + stop Signal + } + + tun struct { + device TUNDevice + mtu int32 + } +} + +/* Converts the peer into a "zombie", which remains in the peer map, + * but processes no packets and does not exists in the routing table. + * + * Must hold: + * device.peers.mutex : exclusive lock + * device.routing : exclusive lock + */ +func unsafeRemovePeer(device *Device, peer *Peer, key NoisePublicKey) { + + // stop routing and processing of packets + + device.routing.table.RemovePeer(peer) + peer.Stop() + + // remove from peer map + + delete(device.peers.keyMap, key) +} + +func deviceUpdateState(device *Device) { + + // check if state already being updated (guard) + + if device.state.changing.Swap(true) { + return + } + + func() { + + // compare to current state of device + + device.state.mutex.Lock() + defer device.state.mutex.Unlock() + + newIsUp := device.isUp.Get() + + if newIsUp == device.state.current { + device.state.changing.Set(false) + return + } + + // change state of device + + switch newIsUp { + case true: + if err := device.BindUpdate(); err != nil { + device.isUp.Set(false) + break + } + + device.peers.mutex.Lock() + defer device.peers.mutex.Unlock() + + for _, peer := range device.peers.keyMap { + peer.Start() + } + + case false: + device.BindClose() + + device.peers.mutex.Lock() + defer device.peers.mutex.Unlock() + + for _, peer := range device.peers.keyMap { + println("stopping peer") + peer.Stop() + } + } + + // update state variables + + device.state.current = newIsUp + device.state.changing.Set(false) + }() + + // check for state change in the mean time + + deviceUpdateState(device) +} + +func (device *Device) Up() { + + // closed device cannot be brought up + + if device.isClosed.Get() { + return + } + + device.state.mutex.Lock() + device.isUp.Set(true) + device.state.mutex.Unlock() + deviceUpdateState(device) +} + +func (device *Device) Down() { + device.state.mutex.Lock() + device.isUp.Set(false) + device.state.mutex.Unlock() + deviceUpdateState(device) +} + +func (device *Device) IsUnderLoad() bool { + + // check if currently under load + + now := time.Now() + underLoad := len(device.queue.handshake) >= UnderLoadQueueSize + if underLoad { + device.rate.underLoadUntil.Store(now.Add(time.Second)) + return true + } + + // check if recently under load + + until := device.rate.underLoadUntil.Load().(time.Time) + return until.After(now) +} + +func (device *Device) SetPrivateKey(sk NoisePrivateKey) error { + + // lock required resources + + device.noise.mutex.Lock() + defer device.noise.mutex.Unlock() + + device.routing.mutex.Lock() + defer device.routing.mutex.Unlock() + + device.peers.mutex.Lock() + defer device.peers.mutex.Unlock() + + for _, peer := range device.peers.keyMap { + peer.handshake.mutex.RLock() + defer peer.handshake.mutex.RUnlock() + } + + // remove peers with matching public keys + + publicKey := sk.publicKey() + for key, peer := range device.peers.keyMap { + if peer.handshake.remoteStatic.Equals(publicKey) { + unsafeRemovePeer(device, peer, key) + } + } + + // update key material + + device.noise.privateKey = sk + device.noise.publicKey = publicKey + device.mac.Init(publicKey) + + // do static-static DH pre-computations + + rmKey := device.noise.privateKey.IsZero() + + for key, peer := range device.peers.keyMap { + + hs := &peer.handshake + + if rmKey { + hs.precomputedStaticStatic = [NoisePublicKeySize]byte{} + } else { + hs.precomputedStaticStatic = device.noise.privateKey.sharedSecret(hs.remoteStatic) + } + + if isZero(hs.precomputedStaticStatic[:]) { + unsafeRemovePeer(device, peer, key) + } + } + + return nil +} + +func (device *Device) GetMessageBuffer() *[MaxMessageSize]byte { + return device.pool.messageBuffers.Get().(*[MaxMessageSize]byte) +} + +func (device *Device) PutMessageBuffer(msg *[MaxMessageSize]byte) { + device.pool.messageBuffers.Put(msg) +} + +func NewDevice(tun TUNDevice, logger *Logger) *Device { + device := new(Device) + + device.isUp.Set(false) + device.isClosed.Set(false) + + device.log = logger + device.tun.device = tun + device.peers.keyMap = make(map[NoisePublicKey]*Peer) + + // initialize anti-DoS / anti-scanning features + + device.rate.limiter.Init() + device.rate.underLoadUntil.Store(time.Time{}) + + // initialize noise & crypt-key routine + + device.indices.Init() + device.routing.table.Reset() + + // setup buffer pool + + device.pool.messageBuffers = sync.Pool{ + New: func() interface{} { + return new([MaxMessageSize]byte) + }, + } + + // create queues + + device.queue.handshake = make(chan QueueHandshakeElement, QueueHandshakeSize) + device.queue.encryption = make(chan *QueueOutboundElement, QueueOutboundSize) + device.queue.decryption = make(chan *QueueInboundElement, QueueInboundSize) + + // prepare signals + + device.signal.stop = NewSignal() + + // prepare net + + device.net.port = 0 + device.net.bind = nil + + // start workers + + for i := 0; i < runtime.NumCPU(); i += 1 { + go device.RoutineEncryption() + go device.RoutineDecryption() + go device.RoutineHandshake() + } + + go device.RoutineReadFromTUN() + go device.RoutineTUNEventReader() + go device.rate.limiter.RoutineGarbageCollector(device.signal.stop) + + return device +} + +func (device *Device) LookupPeer(pk NoisePublicKey) *Peer { + device.peers.mutex.RLock() + defer device.peers.mutex.RUnlock() + + return device.peers.keyMap[pk] +} + +func (device *Device) RemovePeer(key NoisePublicKey) { + device.noise.mutex.Lock() + defer device.noise.mutex.Unlock() + + device.routing.mutex.Lock() + defer device.routing.mutex.Unlock() + + device.peers.mutex.Lock() + defer device.peers.mutex.Unlock() + + // stop peer and remove from routing + + peer, ok := device.peers.keyMap[key] + if ok { + unsafeRemovePeer(device, peer, key) + } +} + +func (device *Device) RemoveAllPeers() { + + device.routing.mutex.Lock() + defer device.routing.mutex.Unlock() + + device.peers.mutex.Lock() + defer device.peers.mutex.Unlock() + + for key, peer := range device.peers.keyMap { + println("rm", peer.String()) + unsafeRemovePeer(device, peer, key) + } + + device.peers.keyMap = make(map[NoisePublicKey]*Peer) +} + +func (device *Device) Close() { + device.log.Info.Println("Device closing") + if device.isClosed.Swap(true) { + return + } + device.signal.stop.Broadcast() + device.tun.device.Close() + device.BindClose() + device.isUp.Set(false) + device.RemoveAllPeers() + device.log.Info.Println("Interface closed") +} + +func (device *Device) Wait() chan struct{} { + return device.signal.stop.Wait() +} diff --git a/helper_test.go b/helper_test.go new file mode 100644 index 0000000..41e6b72 --- /dev/null +++ b/helper_test.go @@ -0,0 +1,79 @@ +package main + +import ( + "bytes" + "os" + "testing" +) + +/* Helpers for writing unit tests + */ + +type DummyTUN struct { + name string + mtu int + packets chan []byte + events chan TUNEvent +} + +func (tun *DummyTUN) File() *os.File { + return nil +} + +func (tun *DummyTUN) Name() string { + return tun.name +} + +func (tun *DummyTUN) MTU() (int, error) { + return tun.mtu, nil +} + +func (tun *DummyTUN) Write(d []byte, offset int) (int, error) { + tun.packets <- d[offset:] + return len(d), nil +} + +func (tun *DummyTUN) Close() error { + return nil +} + +func (tun *DummyTUN) Events() chan TUNEvent { + return tun.events +} + +func (tun *DummyTUN) Read(d []byte, offset int) (int, error) { + t := <-tun.packets + copy(d[offset:], t) + return len(t), nil +} + +func CreateDummyTUN(name string) (TUNDevice, error) { + var dummy DummyTUN + dummy.mtu = 0 + dummy.packets = make(chan []byte, 100) + return &dummy, nil +} + +func assertNil(t *testing.T, err error) { + if err != nil { + t.Fatal(err) + } +} + +func assertEqual(t *testing.T, a []byte, b []byte) { + if bytes.Compare(a, b) != 0 { + t.Fatal(a, "!=", b) + } +} + +func randDevice(t *testing.T) *Device { + sk, err := newPrivateKey() + if err != nil { + t.Fatal(err) + } + tun, _ := CreateDummyTUN("dummy") + logger := NewLogger(LogLevelError, "") + device := NewDevice(tun, logger) + device.SetPrivateKey(sk) + return device +} diff --git a/index.go b/index.go new file mode 100644 index 0000000..1ba040e --- /dev/null +++ b/index.go @@ -0,0 +1,95 @@ +package main + +import ( + "crypto/rand" + "encoding/binary" + "sync" +) + +/* Index=0 is reserved for unset indecies + * + */ + +type IndexTableEntry struct { + peer *Peer + handshake *Handshake + keyPair *KeyPair +} + +type IndexTable struct { + mutex sync.RWMutex + table map[uint32]IndexTableEntry +} + +func randUint32() (uint32, error) { + var buff [4]byte + _, err := rand.Read(buff[:]) + value := binary.LittleEndian.Uint32(buff[:]) + return value, err +} + +func (table *IndexTable) Init() { + table.mutex.Lock() + table.table = make(map[uint32]IndexTableEntry) + table.mutex.Unlock() +} + +func (table *IndexTable) Delete(index uint32) { + if index == 0 { + return + } + table.mutex.Lock() + delete(table.table, index) + table.mutex.Unlock() +} + +func (table *IndexTable) Insert(key uint32, value IndexTableEntry) { + table.mutex.Lock() + table.table[key] = value + table.mutex.Unlock() +} + +func (table *IndexTable) NewIndex(peer *Peer) (uint32, error) { + for { + // generate random index + + index, err := randUint32() + if err != nil { + return index, err + } + if index == 0 { + continue + } + + // check if index used + + table.mutex.RLock() + _, ok := table.table[index] + table.mutex.RUnlock() + if ok { + continue + } + + // map index to handshake + + table.mutex.Lock() + _, found := table.table[index] + if found { + table.mutex.Unlock() + continue + } + table.table[index] = IndexTableEntry{ + peer: peer, + handshake: &peer.handshake, + keyPair: nil, + } + table.mutex.Unlock() + return index, nil + } +} + +func (table *IndexTable) Lookup(id uint32) IndexTableEntry { + table.mutex.RLock() + defer table.mutex.RUnlock() + return table.table[id] +} diff --git a/ip.go b/ip.go new file mode 100644 index 0000000..752a404 --- /dev/null +++ b/ip.go @@ -0,0 +1,17 @@ +package main + +import ( + "net" +) + +const ( + IPv4offsetTotalLength = 2 + IPv4offsetSrc = 12 + IPv4offsetDst = IPv4offsetSrc + net.IPv4len +) + +const ( + IPv6offsetPayloadLength = 4 + IPv6offsetSrc = 8 + IPv6offsetDst = IPv6offsetSrc + net.IPv6len +) diff --git a/kdf_test.go b/kdf_test.go new file mode 100644 index 0000000..a89dacc --- /dev/null +++ b/kdf_test.go @@ -0,0 +1,79 @@ +package main + +import ( + "encoding/hex" + "golang.org/x/crypto/blake2s" + "testing" +) + +type KDFTest struct { + key string + input string + t0 string + t1 string + t2 string +} + +func assertEquals(t *testing.T, a string, b string) { + if a != b { + t.Fatal("expected", a, "=", b) + } +} + +func TestKDF(t *testing.T) { + tests := []KDFTest{ + { + key: "746573742d6b6579", + input: "746573742d696e707574", + t0: "6f0e5ad38daba1bea8a0d213688736f19763239305e0f58aba697f9ffc41c633", + t1: "df1194df20802a4fe594cde27e92991c8cae66c366e8106aaa937a55fa371e8a", + t2: "fac6e2745a325f5dc5d11a5b165aad08b0ada28e7b4e666b7c077934a4d76c24", + }, + { + key: "776972656775617264", + input: "776972656775617264", + t0: "491d43bbfdaa8750aaf535e334ecbfe5129967cd64635101c566d4caefda96e8", + t1: "1e71a379baefd8a79aa4662212fcafe19a23e2b609a3db7d6bcba8f560e3d25f", + t2: "31e1ae48bddfbe5de38f295e5452b1909a1b4e38e183926af3780b0c1e1f0160", + }, + { + key: "", + input: "", + t0: "8387b46bf43eccfcf349552a095d8315c4055beb90208fb1be23b894bc2ed5d0", + t1: "58a0e5f6faefccf4807bff1f05fa8a9217945762040bcec2f4b4a62bdfe0e86e", + t2: "0ce6ea98ec548f8e281e93e32db65621c45eb18dc6f0a7ad94178610a2f7338e", + }, + } + + var t0, t1, t2 [blake2s.Size]byte + + for _, test := range tests { + key, _ := hex.DecodeString(test.key) + input, _ := hex.DecodeString(test.input) + KDF3(&t0, &t1, &t2, key, input) + t0s := hex.EncodeToString(t0[:]) + t1s := hex.EncodeToString(t1[:]) + t2s := hex.EncodeToString(t2[:]) + assertEquals(t, t0s, test.t0) + assertEquals(t, t1s, test.t1) + assertEquals(t, t2s, test.t2) + } + + for _, test := range tests { + key, _ := hex.DecodeString(test.key) + input, _ := hex.DecodeString(test.input) + KDF2(&t0, &t1, key, input) + t0s := hex.EncodeToString(t0[:]) + t1s := hex.EncodeToString(t1[:]) + assertEquals(t, t0s, test.t0) + assertEquals(t, t1s, test.t1) + } + + for _, test := range tests { + key, _ := hex.DecodeString(test.key) + input, _ := hex.DecodeString(test.input) + KDF1(&t0, key, input) + t0s := hex.EncodeToString(t0[:]) + assertEquals(t, t0s, test.t0) + } +} diff --git a/keypair.go b/keypair.go new file mode 100644 index 0000000..283cb92 --- /dev/null +++ b/keypair.go @@ -0,0 +1,44 @@ +package main + +import ( + "crypto/cipher" + "sync" + "time" +) + +/* Due to limitations in Go and /x/crypto there is currently + * no way to ensure that key material is securely ereased in memory. + * + * Since this may harm the forward secrecy property, + * we plan to resolve this issue; whenever Go allows us to do so. + */ + +type KeyPair struct { + send cipher.AEAD + receive cipher.AEAD + replayFilter ReplayFilter + sendNonce uint64 + isInitiator bool + created time.Time + localIndex uint32 + remoteIndex uint32 +} + +type KeyPairs struct { + mutex sync.RWMutex + current *KeyPair + previous *KeyPair + next *KeyPair // not yet "confirmed by transport" +} + +func (kp *KeyPairs) Current() *KeyPair { + kp.mutex.RLock() + defer kp.mutex.RUnlock() + return kp.current +} + +func (device *Device) DeleteKeyPair(key *KeyPair) { + if key != nil { + device.indices.Delete(key.localIndex) + } +} diff --git a/logger.go b/logger.go new file mode 100644 index 0000000..0872ef9 --- /dev/null +++ b/logger.go @@ -0,0 +1,50 @@ +package main + +import ( + "io" + "io/ioutil" + "log" + "os" +) + +const ( + LogLevelError = iota + LogLevelInfo + LogLevelDebug +) + +type Logger struct { + Debug *log.Logger + Info *log.Logger + Error *log.Logger +} + +func NewLogger(level int, prepend string) *Logger { + output := os.Stdout + logger := new(Logger) + + logErr, logInfo, logDebug := func() (io.Writer, io.Writer, io.Writer) { + if level >= LogLevelDebug { + return output, output, output + } + if level >= LogLevelInfo { + return output, output, ioutil.Discard + } + return output, ioutil.Discard, ioutil.Discard + }() + + logger.Debug = log.New(logDebug, + "DEBUG: "+prepend, + log.Ldate|log.Ltime|log.Lshortfile, + ) + + logger.Info = log.New(logInfo, + "INFO: "+prepend, + log.Ldate|log.Ltime, + ) + logger.Error = log.New(logErr, + "ERROR: "+prepend, + log.Ldate|log.Ltime, + ) + return logger +} diff --git a/main.go b/main.go new file mode 100644 index 0000000..b12bb09 --- /dev/null +++ b/main.go @@ -0,0 +1,196 @@ +package main + +import ( + "fmt" + "os" + "os/signal" + "runtime" + "strconv" +) + +const ( + ExitSetupSuccess = 0 + ExitSetupFailed = 1 +) + +const ( + ENV_WG_TUN_FD = "WG_TUN_FD" + ENV_WG_UAPI_FD = "WG_UAPI_FD" +) + +func printUsage() { + fmt.Printf("usage:\n") + fmt.Printf("%s [-f/--foreground] INTERFACE-NAME\n", os.Args[0]) +} + +func main() { + + // parse arguments + + var foreground bool + var interfaceName string + if len(os.Args) < 2 || len(os.Args) > 3 { + printUsage() + return + } + + switch os.Args[1] { + + case "-f", "--foreground": + foreground = true + if len(os.Args) != 3 { + printUsage() + return + } + interfaceName = os.Args[2] + + default: + foreground = false + if len(os.Args) != 2 { + printUsage() + return + } + interfaceName = os.Args[1] + } + + // get log level (default: info) + + logLevel := func() int { + switch os.Getenv("LOG_LEVEL") { + case "debug": + return LogLevelDebug + case "info": + return LogLevelInfo + case "error": + return LogLevelError + } + return LogLevelInfo + }() + + logger := NewLogger( + logLevel, + fmt.Sprintf("(%s) ", interfaceName), + ) + + logger.Debug.Println("Debug log enabled") + + // open TUN device (or use supplied fd) + + tun, err := func() (TUNDevice, error) { + tunFdStr := os.Getenv(ENV_WG_TUN_FD) + if tunFdStr == "" { + return CreateTUN(interfaceName) + } + + // construct tun device from supplied fd + + fd, err := strconv.ParseUint(tunFdStr, 10, 32) + if err != nil { + return nil, err + } + + file := os.NewFile(uintptr(fd), "") + return CreateTUNFromFile(interfaceName, file) + }() + + if err != nil { + logger.Error.Println("Failed to create TUN device:", err) + os.Exit(ExitSetupFailed) + } + + // open UAPI file (or use supplied fd) + + fileUAPI, err := func() (*os.File, error) { + uapiFdStr := os.Getenv(ENV_WG_UAPI_FD) + if uapiFdStr == "" { + return UAPIOpen(interfaceName) + } + + // use supplied fd + + fd, err := strconv.ParseUint(uapiFdStr, 10, 32) + if err != nil { + return nil, err + } + + return os.NewFile(uintptr(fd), ""), nil + }() + + if err != nil { + logger.Error.Println("UAPI listen error:", err) + os.Exit(ExitSetupFailed) + return + } + // daemonize the process + + if !foreground { + env := os.Environ() + env = append(env, fmt.Sprintf("%s=3", ENV_WG_TUN_FD)) + env = append(env, fmt.Sprintf("%s=4", ENV_WG_UAPI_FD)) + attr := &os.ProcAttr{ + Files: []*os.File{ + nil, // stdin + nil, // stdout + nil, // stderr + tun.File(), + fileUAPI, + }, + Dir: ".", + Env: env, + } + err = Daemonize(attr) + if err != nil { + logger.Error.Println("Failed to daemonize:", err) + os.Exit(ExitSetupFailed) + } + return + } + + // increase number of go workers (for Go <1.5) + + runtime.GOMAXPROCS(runtime.NumCPU()) + + // create wireguard device + + device := NewDevice(tun, logger) + + logger.Info.Println("Device started") + + // start uapi listener + + errs := make(chan error) + term := make(chan os.Signal) + + uapi, err := UAPIListen(interfaceName, fileUAPI) + + go func() { + for { + conn, err := uapi.Accept() + if err != nil { + errs <- err + return + } + go ipcHandle(device, conn) + } + }() + + logger.Info.Println("UAPI listener started") + + // wait for program to terminate + + signal.Notify(term, os.Kill) + signal.Notify(term, os.Interrupt) + + select { + case <-term: + case <-errs: + case <-device.Wait(): + } + + // clean up + + uapi.Close() + device.Close() + + logger.Info.Println("Shutting down") +} diff --git a/misc.go b/misc.go new file mode 100644 index 0000000..80e33f6 --- /dev/null +++ b/misc.go @@ -0,0 +1,57 @@ +package main + +import ( + "sync/atomic" +) + +/* Atomic Boolean */ + +const ( + AtomicFalse = int32(iota) + AtomicTrue +) + +type AtomicBool struct { + flag int32 +} + +func (a *AtomicBool) Get() bool { + return atomic.LoadInt32(&a.flag) == AtomicTrue +} + +func (a *AtomicBool) Swap(val bool) bool { + flag := AtomicFalse + if val { + flag = AtomicTrue + } + return atomic.SwapInt32(&a.flag, flag) == AtomicTrue +} + +func (a *AtomicBool) Set(val bool) { + flag := AtomicFalse + if val { + flag = AtomicTrue + } + atomic.StoreInt32(&a.flag, flag) +} + +/* Integer manipulation */ + +func toInt32(n uint32) int32 { + mask := uint32(1 << 31) + return int32(-(n & mask) + (n & ^mask)) +} + +func min(a uint, b uint) uint { + if a > b { + return b + } + return a +} + +func minUint64(a uint64, b uint64) uint64 { + if a > b { + return b + } + return a +} diff --git a/noise_helpers.go b/noise_helpers.go new file mode 100644 index 0000000..1e2de5f --- /dev/null +++ b/noise_helpers.go @@ -0,0 +1,98 @@ +package main + +import ( + "crypto/hmac" + "crypto/rand" + "crypto/subtle" + "golang.org/x/crypto/blake2s" + "golang.org/x/crypto/curve25519" + "hash" +) + +/* KDF related functions. + * HMAC-based Key Derivation Function (HKDF) + * https://tools.ietf.org/html/rfc5869 + */ + +func HMAC1(sum *[blake2s.Size]byte, key, in0 []byte) { + mac := hmac.New(func() hash.Hash { + h, _ := blake2s.New256(nil) + return h + }, key) + mac.Write(in0) + mac.Sum(sum[:0]) +} + +func HMAC2(sum *[blake2s.Size]byte, key, in0, in1 []byte) { + mac := hmac.New(func() hash.Hash { + h, _ := blake2s.New256(nil) + return h + }, key) + mac.Write(in0) + mac.Write(in1) + mac.Sum(sum[:0]) +} + +func KDF1(t0 *[blake2s.Size]byte, key, input []byte) { + HMAC1(t0, key, input) + HMAC1(t0, t0[:], []byte{0x1}) + return +} + +func KDF2(t0, t1 *[blake2s.Size]byte, key, input []byte) { + var prk [blake2s.Size]byte + HMAC1(&prk, key, input) + HMAC1(t0, prk[:], []byte{0x1}) + HMAC2(t1, prk[:], t0[:], []byte{0x2}) + setZero(prk[:]) + return +} + +func KDF3(t0, t1, t2 *[blake2s.Size]byte, key, input []byte) { + var prk [blake2s.Size]byte + HMAC1(&prk, key, input) + HMAC1(t0, prk[:], []byte{0x1}) + HMAC2(t1, prk[:], t0[:], []byte{0x2}) + HMAC2(t2, prk[:], t1[:], []byte{0x3}) + setZero(prk[:]) + return +} + +func isZero(val []byte) bool { + acc := 1 + for _, b := range val { + acc &= subtle.ConstantTimeByteEq(b, 0) + } + return acc == 1 +} + +func setZero(arr []byte) { + for i := range arr { + arr[i] = 0 + } +} + +/* curve25519 wrappers */ + +func newPrivateKey() (sk NoisePrivateKey, err error) { + // clamping: https://cr.yp.to/ecdh.html + _, err = rand.Read(sk[:]) + sk[0] &= 248 + sk[31] &= 127 + sk[31] |= 64 + return +} + +func (sk *NoisePrivateKey) publicKey() (pk NoisePublicKey) { + apk := (*[NoisePublicKeySize]byte)(&pk) + ask := (*[NoisePrivateKeySize]byte)(sk) + curve25519.ScalarBaseMult(apk, ask) + return +} + +func (sk *NoisePrivateKey) sharedSecret(pk NoisePublicKey) (ss [NoisePublicKeySize]byte) { + apk := (*[NoisePublicKeySize]byte)(&pk) + ask := (*[NoisePrivateKeySize]byte)(sk) + curve25519.ScalarMult(&ss, ask, apk) + return ss +} diff --git a/noise_protocol.go b/noise_protocol.go new file mode 100644 index 0000000..c9713c0 --- /dev/null +++ b/noise_protocol.go @@ -0,0 +1,578 @@ +package main + +import ( + "errors" + "golang.org/x/crypto/blake2s" + "golang.org/x/crypto/chacha20poly1305" + "golang.org/x/crypto/poly1305" + "sync" + "time" +) + +const ( + HandshakeZeroed = iota + HandshakeInitiationCreated + HandshakeInitiationConsumed + HandshakeResponseCreated + HandshakeResponseConsumed +) + +const ( + NoiseConstruction = "Noise_IKpsk2_25519_ChaChaPoly_BLAKE2s" + WGIdentifier = "WireGuard v1 zx2c4 Jason@zx2c4.com" + WGLabelMAC1 = "mac1----" + WGLabelCookie = "cookie--" +) + +const ( + MessageInitiationType = 1 + MessageResponseType = 2 + MessageCookieReplyType = 3 + MessageTransportType = 4 +) + +const ( + MessageInitiationSize = 148 // size of handshake initation message + MessageResponseSize = 92 // size of response message + MessageCookieReplySize = 64 // size of cookie reply message + MessageTransportHeaderSize = 16 // size of data preceeding content in transport message + MessageTransportSize = MessageTransportHeaderSize + poly1305.TagSize // size of empty transport + MessageKeepaliveSize = MessageTransportSize // size of keepalive + MessageHandshakeSize = MessageInitiationSize // size of largest handshake releated message +) + +const ( + MessageTransportOffsetReceiver = 4 + MessageTransportOffsetCounter = 8 + MessageTransportOffsetContent = 16 +) + +/* Type is an 8-bit field, followed by 3 nul bytes, + * by marshalling the messages in little-endian byteorder + * we can treat these as a 32-bit unsigned int (for now) + * + */ + +type MessageInitiation struct { + Type uint32 + Sender uint32 + Ephemeral NoisePublicKey + Static [NoisePublicKeySize + poly1305.TagSize]byte + Timestamp [TAI64NSize + poly1305.TagSize]byte + MAC1 [blake2s.Size128]byte + MAC2 [blake2s.Size128]byte +} + +type MessageResponse struct { + Type uint32 + Sender uint32 + Receiver uint32 + Ephemeral NoisePublicKey + Empty [poly1305.TagSize]byte + MAC1 [blake2s.Size128]byte + MAC2 [blake2s.Size128]byte +} + +type MessageTransport struct { + Type uint32 + Receiver uint32 + Counter uint64 + Content []byte +} + +type MessageCookieReply struct { + Type uint32 + Receiver uint32 + Nonce [24]byte + Cookie [blake2s.Size128 + poly1305.TagSize]byte +} + +type Handshake struct { + state int + mutex sync.RWMutex + hash [blake2s.Size]byte // hash value + chainKey [blake2s.Size]byte // chain key + presharedKey NoiseSymmetricKey // psk + localEphemeral NoisePrivateKey // ephemeral secret key + localIndex uint32 // used to clear hash-table + remoteIndex uint32 // index for sending + remoteStatic NoisePublicKey // long term key + remoteEphemeral NoisePublicKey // ephemeral public key + precomputedStaticStatic [NoisePublicKeySize]byte // precomputed shared secret + lastTimestamp TAI64N + lastInitiationConsumption time.Time +} + +var ( + InitialChainKey [blake2s.Size]byte + InitialHash [blake2s.Size]byte + ZeroNonce [chacha20poly1305.NonceSize]byte +) + +func mixKey(dst *[blake2s.Size]byte, c *[blake2s.Size]byte, data []byte) { + KDF1(dst, c[:], data) +} + +func mixHash(dst *[blake2s.Size]byte, h *[blake2s.Size]byte, data []byte) { + hsh, _ := blake2s.New256(nil) + hsh.Write(h[:]) + hsh.Write(data) + hsh.Sum(dst[:0]) + hsh.Reset() +} + +func (h *Handshake) Clear() { + setZero(h.localEphemeral[:]) + setZero(h.remoteEphemeral[:]) + setZero(h.chainKey[:]) + setZero(h.hash[:]) + h.localIndex = 0 + h.state = HandshakeZeroed +} + +func (h *Handshake) mixHash(data []byte) { + mixHash(&h.hash, &h.hash, data) +} + +func (h *Handshake) mixKey(data []byte) { + mixKey(&h.chainKey, &h.chainKey, data) +} + +/* Do basic precomputations + */ +func init() { + InitialChainKey = blake2s.Sum256([]byte(NoiseConstruction)) + mixHash(&InitialHash, &InitialChainKey, []byte(WGIdentifier)) +} + +func (device *Device) CreateMessageInitiation(peer *Peer) (*MessageInitiation, error) { + + device.noise.mutex.RLock() + defer device.noise.mutex.RUnlock() + + handshake := &peer.handshake + handshake.mutex.Lock() + defer handshake.mutex.Unlock() + + if isZero(handshake.precomputedStaticStatic[:]) { + return nil, errors.New("Static shared secret is zero") + } + + // create ephemeral key + + var err error + handshake.hash = InitialHash + handshake.chainKey = InitialChainKey + handshake.localEphemeral, err = newPrivateKey() + if err != nil { + return nil, err + } + + // assign index + + device.indices.Delete(handshake.localIndex) + handshake.localIndex, err = device.indices.NewIndex(peer) + + if err != nil { + return nil, err + } + + handshake.mixHash(handshake.remoteStatic[:]) + + msg := MessageInitiation{ + Type: MessageInitiationType, + Ephemeral: handshake.localEphemeral.publicKey(), + Sender: handshake.localIndex, + } + + handshake.mixKey(msg.Ephemeral[:]) + handshake.mixHash(msg.Ephemeral[:]) + + // encrypt static key + + func() { + var key [chacha20poly1305.KeySize]byte + ss := handshake.localEphemeral.sharedSecret(handshake.remoteStatic) + KDF2( + &handshake.chainKey, + &key, + handshake.chainKey[:], + ss[:], + ) + aead, _ := chacha20poly1305.New(key[:]) + aead.Seal(msg.Static[:0], ZeroNonce[:], device.noise.publicKey[:], handshake.hash[:]) + }() + handshake.mixHash(msg.Static[:]) + + // encrypt timestamp + + timestamp := Timestamp() + func() { + var key [chacha20poly1305.KeySize]byte + KDF2( + &handshake.chainKey, + &key, + handshake.chainKey[:], + handshake.precomputedStaticStatic[:], + ) + aead, _ := chacha20poly1305.New(key[:]) + aead.Seal(msg.Timestamp[:0], ZeroNonce[:], timestamp[:], handshake.hash[:]) + }() + + handshake.mixHash(msg.Timestamp[:]) + handshake.state = HandshakeInitiationCreated + return &msg, nil +} + +func (device *Device) ConsumeMessageInitiation(msg *MessageInitiation) *Peer { + var ( + hash [blake2s.Size]byte + chainKey [blake2s.Size]byte + ) + + if msg.Type != MessageInitiationType { + return nil + } + + device.noise.mutex.RLock() + defer device.noise.mutex.RUnlock() + + mixHash(&hash, &InitialHash, device.noise.publicKey[:]) + mixHash(&hash, &hash, msg.Ephemeral[:]) + mixKey(&chainKey, &InitialChainKey, msg.Ephemeral[:]) + + // decrypt static key + + var err error + var peerPK NoisePublicKey + func() { + var key [chacha20poly1305.KeySize]byte + ss := device.noise.privateKey.sharedSecret(msg.Ephemeral) + KDF2(&chainKey, &key, chainKey[:], ss[:]) + aead, _ := chacha20poly1305.New(key[:]) + _, err = aead.Open(peerPK[:0], ZeroNonce[:], msg.Static[:], hash[:]) + }() + if err != nil { + return nil + } + mixHash(&hash, &hash, msg.Static[:]) + + // lookup peer + + peer := device.LookupPeer(peerPK) + if peer == nil { + return nil + } + + handshake := &peer.handshake + if isZero(handshake.precomputedStaticStatic[:]) { + return nil + } + + // verify identity + + var timestamp TAI64N + var key [chacha20poly1305.KeySize]byte + + handshake.mutex.RLock() + KDF2( + &chainKey, + &key, + chainKey[:], + handshake.precomputedStaticStatic[:], + ) + aead, _ := chacha20poly1305.New(key[:]) + _, err = aead.Open(timestamp[:0], ZeroNonce[:], msg.Timestamp[:], hash[:]) + if err != nil { + handshake.mutex.RUnlock() + return nil + } + mixHash(&hash, &hash, msg.Timestamp[:]) + + // protect against replay & flood + + var ok bool + ok = timestamp.After(handshake.lastTimestamp) + ok = ok && time.Now().Sub(handshake.lastInitiationConsumption) > HandshakeInitationRate + handshake.mutex.RUnlock() + if !ok { + return nil + } + + // update handshake state + + handshake.mutex.Lock() + + handshake.hash = hash + handshake.chainKey = chainKey + handshake.remoteIndex = msg.Sender + handshake.remoteEphemeral = msg.Ephemeral + handshake.lastTimestamp = timestamp + handshake.lastInitiationConsumption = time.Now() + handshake.state = HandshakeInitiationConsumed + + handshake.mutex.Unlock() + + return peer +} + +func (device *Device) CreateMessageResponse(peer *Peer) (*MessageResponse, error) { + handshake := &peer.handshake + handshake.mutex.Lock() + defer handshake.mutex.Unlock() + + if handshake.state != HandshakeInitiationConsumed { + return nil, errors.New("handshake initation must be consumed first") + } + + // assign index + + var err error + device.indices.Delete(handshake.localIndex) + handshake.localIndex, err = device.indices.NewIndex(peer) + if err != nil { + return nil, err + } + + var msg MessageResponse + msg.Type = MessageResponseType + msg.Sender = handshake.localIndex + msg.Receiver = handshake.remoteIndex + + // create ephemeral key + + handshake.localEphemeral, err = newPrivateKey() + if err != nil { + return nil, err + } + msg.Ephemeral = handshake.localEphemeral.publicKey() + handshake.mixHash(msg.Ephemeral[:]) + handshake.mixKey(msg.Ephemeral[:]) + + func() { + ss := handshake.localEphemeral.sharedSecret(handshake.remoteEphemeral) + handshake.mixKey(ss[:]) + ss = handshake.localEphemeral.sharedSecret(handshake.remoteStatic) + handshake.mixKey(ss[:]) + }() + + // add preshared key (psk) + + var tau [blake2s.Size]byte + var key [chacha20poly1305.KeySize]byte + + KDF3( + &handshake.chainKey, + &tau, + &key, + handshake.chainKey[:], + handshake.presharedKey[:], + ) + + handshake.mixHash(tau[:]) + + func() { + aead, _ := chacha20poly1305.New(key[:]) + aead.Seal(msg.Empty[:0], ZeroNonce[:], nil, handshake.hash[:]) + handshake.mixHash(msg.Empty[:]) + }() + + handshake.state = HandshakeResponseCreated + + return &msg, nil +} + +func (device *Device) ConsumeMessageResponse(msg *MessageResponse) *Peer { + if msg.Type != MessageResponseType { + return nil + } + + // lookup handshake by reciever + + lookup := device.indices.Lookup(msg.Receiver) + handshake := lookup.handshake + if handshake == nil { + return nil + } + + var ( + hash [blake2s.Size]byte + chainKey [blake2s.Size]byte + ) + + ok := func() bool { + + // lock handshake state + + handshake.mutex.RLock() + defer handshake.mutex.RUnlock() + + if handshake.state != HandshakeInitiationCreated { + return false + } + + // lock private key for reading + + device.noise.mutex.RLock() + defer device.noise.mutex.RUnlock() + + // finish 3-way DH + + mixHash(&hash, &handshake.hash, msg.Ephemeral[:]) + mixKey(&chainKey, &handshake.chainKey, msg.Ephemeral[:]) + + func() { + ss := handshake.localEphemeral.sharedSecret(msg.Ephemeral) + mixKey(&chainKey, &chainKey, ss[:]) + setZero(ss[:]) + }() + + func() { + ss := device.noise.privateKey.sharedSecret(msg.Ephemeral) + mixKey(&chainKey, &chainKey, ss[:]) + setZero(ss[:]) + }() + + // add preshared key (psk) + + var tau [blake2s.Size]byte + var key [chacha20poly1305.KeySize]byte + KDF3( + &chainKey, + &tau, + &key, + chainKey[:], + handshake.presharedKey[:], + ) + mixHash(&hash, &hash, tau[:]) + + // authenticate transcript + + aead, _ := chacha20poly1305.New(key[:]) + _, err := aead.Open(nil, ZeroNonce[:], msg.Empty[:], hash[:]) + if err != nil { + device.log.Debug.Println("failed to open") + return false + } + mixHash(&hash, &hash, msg.Empty[:]) + return true + }() + + if !ok { + return nil + } + + // update handshake state + + handshake.mutex.Lock() + + handshake.hash = hash + handshake.chainKey = chainKey + handshake.remoteIndex = msg.Sender + handshake.state = HandshakeResponseConsumed + + handshake.mutex.Unlock() + + setZero(hash[:]) + setZero(chainKey[:]) + + return lookup.peer +} + +/* Derives a new key-pair from the current handshake state + * + */ +func (peer *Peer) NewKeyPair() *KeyPair { + device := peer.device + handshake := &peer.handshake + handshake.mutex.Lock() + defer handshake.mutex.Unlock() + + // derive keys + + var isInitiator bool + var sendKey [chacha20poly1305.KeySize]byte + var recvKey [chacha20poly1305.KeySize]byte + + if handshake.state == HandshakeResponseConsumed { + KDF2( + &sendKey, + &recvKey, + handshake.chainKey[:], + nil, + ) + isInitiator = true + } else if handshake.state == HandshakeResponseCreated { + KDF2( + &recvKey, + &sendKey, + handshake.chainKey[:], + nil, + ) + isInitiator = false + } else { + return nil + } + + // zero handshake + + setZero(handshake.chainKey[:]) + setZero(handshake.localEphemeral[:]) + peer.handshake.state = HandshakeZeroed + + // create AEAD instances + + keyPair := new(KeyPair) + keyPair.send, _ = chacha20poly1305.New(sendKey[:]) + keyPair.receive, _ = chacha20poly1305.New(recvKey[:]) + + setZero(sendKey[:]) + setZero(recvKey[:]) + + keyPair.created = time.Now() + keyPair.sendNonce = 0 + keyPair.replayFilter.Init() + keyPair.isInitiator = isInitiator + keyPair.localIndex = peer.handshake.localIndex + keyPair.remoteIndex = peer.handshake.remoteIndex + + // remap index + + device.indices.Insert( + handshake.localIndex, + IndexTableEntry{ + peer: peer, + keyPair: keyPair, + handshake: nil, + }, + ) + handshake.localIndex = 0 + + // rotate key pairs + + kp := &peer.keyPairs + kp.mutex.Lock() + + if isInitiator { + if kp.previous != nil { + device.DeleteKeyPair(kp.previous) + kp.previous = nil + } + + if kp.next != nil { + kp.previous = kp.next + kp.next = keyPair + } else { + kp.previous = kp.current + kp.current = keyPair + peer.signal.newKeyPair.Send() + } + + } else { + kp.next = keyPair + kp.previous = nil + } + kp.mutex.Unlock() + + return keyPair +} diff --git a/noise_test.go b/noise_test.go new file mode 100644 index 0000000..5e9d44b --- /dev/null +++ b/noise_test.go @@ -0,0 +1,136 @@ +package main + +import ( + "bytes" + "encoding/binary" + "testing" +) + +func TestCurveWrappers(t *testing.T) { + sk1, err := newPrivateKey() + assertNil(t, err) + + sk2, err := newPrivateKey() + assertNil(t, err) + + pk1 := sk1.publicKey() + pk2 := sk2.publicKey() + + ss1 := sk1.sharedSecret(pk2) + ss2 := sk2.sharedSecret(pk1) + + if ss1 != ss2 { + t.Fatal("Failed to compute shared secet") + } +} + +func TestNoiseHandshake(t *testing.T) { + dev1 := randDevice(t) + dev2 := randDevice(t) + + defer dev1.Close() + defer dev2.Close() + + peer1, _ := dev2.NewPeer(dev1.noise.privateKey.publicKey()) + peer2, _ := dev1.NewPeer(dev2.noise.privateKey.publicKey()) + + assertEqual( + t, + peer1.handshake.precomputedStaticStatic[:], + peer2.handshake.precomputedStaticStatic[:], + ) + + /* simulate handshake */ + + // initiation message + + t.Log("exchange initiation message") + + msg1, err := dev1.CreateMessageInitiation(peer2) + assertNil(t, err) + + packet := make([]byte, 0, 256) + writer := bytes.NewBuffer(packet) + err = binary.Write(writer, binary.LittleEndian, msg1) + peer := dev2.ConsumeMessageInitiation(msg1) + if peer == nil { + t.Fatal("handshake failed at initiation message") + } + + assertEqual( + t, + peer1.handshake.chainKey[:], + peer2.handshake.chainKey[:], + ) + + assertEqual( + t, + peer1.handshake.hash[:], + peer2.handshake.hash[:], + ) + + // response message + + t.Log("exchange response message") + + msg2, err := dev2.CreateMessageResponse(peer1) + assertNil(t, err) + + peer = dev1.ConsumeMessageResponse(msg2) + if peer == nil { + t.Fatal("handshake failed at response message") + } + + assertEqual( + t, + peer1.handshake.chainKey[:], + peer2.handshake.chainKey[:], + ) + + assertEqual( + t, + peer1.handshake.hash[:], + peer2.handshake.hash[:], + ) + + // key pairs + + t.Log("deriving keys") + + key1 := peer1.NewKeyPair() + key2 := peer2.NewKeyPair() + + if key1 == nil { + t.Fatal("failed to dervice key-pair for peer 1") + } + + if key2 == nil { + t.Fatal("failed to dervice key-pair for peer 2") + } + + // encrypting / decryption test + + t.Log("test key pairs") + + func() { + testMsg := []byte("wireguard test message 1") + var err error + var out []byte + var nonce [12]byte + out = key1.send.Seal(out, nonce[:], testMsg, nil) + out, err = key2.receive.Open(out[:0], nonce[:], out, nil) + assertNil(t, err) + assertEqual(t, out, testMsg) + }() + + func() { + testMsg := []byte("wireguard test message 2") + var err error + var out []byte + var nonce [12]byte + out = key2.send.Seal(out, nonce[:], testMsg, nil) + out, err = key1.receive.Open(out[:0], nonce[:], out, nil) + assertNil(t, err) + assertEqual(t, out, testMsg) + }() +} diff --git a/noise_types.go b/noise_types.go new file mode 100644 index 0000000..1a944df --- /dev/null +++ b/noise_types.go @@ -0,0 +1,74 @@ +package main + +import ( + "crypto/subtle" + "encoding/hex" + "errors" + "golang.org/x/crypto/chacha20poly1305" +) + +const ( + NoisePublicKeySize = 32 + NoisePrivateKeySize = 32 +) + +type ( + NoisePublicKey [NoisePublicKeySize]byte + NoisePrivateKey [NoisePrivateKeySize]byte + NoiseSymmetricKey [chacha20poly1305.KeySize]byte + NoiseNonce uint64 // padded to 12-bytes +) + +func loadExactHex(dst []byte, src string) error { + slice, err := hex.DecodeString(src) + if err != nil { + return err + } + if len(slice) != len(dst) { + return errors.New("Hex string does not fit the slice") + } + copy(dst, slice) + return nil +} + +func (key NoisePrivateKey) IsZero() bool { + var zero NoisePrivateKey + return key.Equals(zero) +} + +func (key NoisePrivateKey) Equals(tar NoisePrivateKey) bool { + return subtle.ConstantTimeCompare(key[:], tar[:]) == 1 +} + +func (key *NoisePrivateKey) FromHex(src string) error { + return loadExactHex(key[:], src) +} + +func (key NoisePrivateKey) ToHex() string { + return hex.EncodeToString(key[:]) +} + +func (key *NoisePublicKey) FromHex(src string) error { + return loadExactHex(key[:], src) +} + +func (key NoisePublicKey) ToHex() string { + return hex.EncodeToString(key[:]) +} + +func (key NoisePublicKey) IsZero() bool { + var zero NoisePublicKey + return key.Equals(zero) +} + +func (key NoisePublicKey) Equals(tar NoisePublicKey) bool { + return subtle.ConstantTimeCompare(key[:], tar[:]) == 1 +} + +func (key *NoiseSymmetricKey) FromHex(src string) error { + return loadExactHex(key[:], src) +} + +func (key NoiseSymmetricKey) ToHex() string { + return hex.EncodeToString(key[:]) +} diff --git a/peer.go b/peer.go new file mode 100644 index 0000000..dc04811 --- /dev/null +++ b/peer.go @@ -0,0 +1,295 @@ +package main + +import ( + "encoding/base64" + "errors" + "fmt" + "github.com/sasha-s/go-deadlock" + "sync" + "time" +) + +const ( + PeerRoutineNumber = 4 +) + +type Peer struct { + isRunning AtomicBool + mutex deadlock.RWMutex + persistentKeepaliveInterval uint64 + keyPairs KeyPairs + handshake Handshake + device *Device + endpoint Endpoint + + stats struct { + txBytes uint64 // bytes send to peer (endpoint) + rxBytes uint64 // bytes received from peer + lastHandshakeNano int64 // nano seconds since epoch + } + + time struct { + mutex deadlock.RWMutex + lastSend time.Time // last send message + lastHandshake time.Time // last completed handshake + nextKeepalive time.Time + } + + signal struct { + newKeyPair Signal // size 1, new key pair was generated + handshakeCompleted Signal // size 1, handshake completed + handshakeBegin Signal // size 1, begin new handshake begin + flushNonceQueue Signal // size 1, empty queued packets + messageSend Signal // size 1, message was send to peer + messageReceived Signal // size 1, authenticated message recv + } + + timer struct { + + // state related to WireGuard timers + + keepalivePersistent Timer // set for persistent keep-alive + keepalivePassive Timer // set upon receiving messages + zeroAllKeys Timer // zero all key material + handshakeNew Timer // begin a new handshake (stale) + handshakeDeadline Timer // complete handshake timeout + handshakeTimeout Timer // current handshake message timeout + + sendLastMinuteHandshake bool + needAnotherKeepalive bool + } + + queue struct { + nonce chan *QueueOutboundElement // nonce / pre-handshake queue + outbound chan *QueueOutboundElement // sequential ordering of work + inbound chan *QueueInboundElement // sequential ordering of work + } + + routines struct { + mutex deadlock.Mutex // held when stopping / starting routines + starting sync.WaitGroup // routines pending start + stopping sync.WaitGroup // routines pending stop + stop Signal // size 0, stop all go-routines in peer + } + + mac CookieGenerator +} + +func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) { + + if device.isClosed.Get() { + return nil, errors.New("Device closed") + } + + // lock resources + + device.state.mutex.Lock() + defer device.state.mutex.Unlock() + + device.noise.mutex.RLock() + defer device.noise.mutex.RUnlock() + + device.peers.mutex.Lock() + defer device.peers.mutex.Unlock() + + // check if over limit + + if len(device.peers.keyMap) >= MaxPeers { + return nil, errors.New("Too many peers") + } + + // create peer + + peer := new(Peer) + peer.mutex.Lock() + defer peer.mutex.Unlock() + + peer.mac.Init(pk) + peer.device = device + peer.isRunning.Set(false) + + peer.timer.zeroAllKeys = NewTimer() + peer.timer.keepalivePersistent = NewTimer() + peer.timer.keepalivePassive = NewTimer() + peer.timer.handshakeNew = NewTimer() + peer.timer.handshakeDeadline = NewTimer() + peer.timer.handshakeTimeout = NewTimer() + + // map public key + + _, ok := device.peers.keyMap[pk] + if ok { + return nil, errors.New("Adding existing peer") + } + device.peers.keyMap[pk] = peer + + // pre-compute DH + + handshake := &peer.handshake + handshake.mutex.Lock() + handshake.remoteStatic = pk + handshake.precomputedStaticStatic = device.noise.privateKey.sharedSecret(pk) + handshake.mutex.Unlock() + + // reset endpoint + + peer.endpoint = nil + + // prepare signaling & routines + + peer.routines.mutex.Lock() + peer.routines.stop = NewSignal() + peer.routines.mutex.Unlock() + + // start peer + + if peer.device.isUp.Get() { + peer.Start() + } + + return peer, nil +} + +func (peer *Peer) SendBuffer(buffer []byte) error { + peer.device.net.mutex.RLock() + defer peer.device.net.mutex.RUnlock() + + if peer.device.net.bind == nil { + return errors.New("No bind") + } + + peer.mutex.RLock() + defer peer.mutex.RUnlock() + + if peer.endpoint == nil { + return errors.New("No known endpoint for peer") + } + + return peer.device.net.bind.Send(buffer, peer.endpoint) +} + +/* Returns a short string identifier for logging + */ +func (peer *Peer) String() string { + if peer.endpoint == nil { + return fmt.Sprintf( + "peer(unknown %s)", + base64.StdEncoding.EncodeToString(peer.handshake.remoteStatic[:]), + ) + } + return fmt.Sprintf( + "peer(%s %s)", + peer.endpoint.DstToString(), + base64.StdEncoding.EncodeToString(peer.handshake.remoteStatic[:]), + ) +} + +func (peer *Peer) Start() { + + // should never start a peer on a closed device + + if peer.device.isClosed.Get() { + return + } + + // prevent simultaneous start/stop operations + + peer.routines.mutex.Lock() + defer peer.routines.mutex.Unlock() + + if peer.isRunning.Get() { + return + } + + peer.device.log.Debug.Println("Starting:", peer.String()) + + // sanity check : these should be 0 + + peer.routines.starting.Wait() + peer.routines.stopping.Wait() + + // prepare queues and signals + + peer.signal.newKeyPair = NewSignal() + peer.signal.handshakeBegin = NewSignal() + peer.signal.handshakeCompleted = NewSignal() + peer.signal.flushNonceQueue = NewSignal() + + peer.queue.nonce = make(chan *QueueOutboundElement, QueueOutboundSize) + peer.queue.outbound = make(chan *QueueOutboundElement, QueueOutboundSize) + peer.queue.inbound = make(chan *QueueInboundElement, QueueInboundSize) + + peer.routines.stop = NewSignal() + peer.isRunning.Set(true) + + // wait for routines to start + + peer.routines.starting.Add(PeerRoutineNumber) + peer.routines.stopping.Add(PeerRoutineNumber) + + go peer.RoutineNonce() + go peer.RoutineTimerHandler() + go peer.RoutineSequentialSender() + go peer.RoutineSequentialReceiver() + + peer.routines.starting.Wait() + peer.isRunning.Set(true) +} + +func (peer *Peer) Stop() { + + // prevent simultaneous start/stop operations + + peer.routines.mutex.Lock() + defer peer.routines.mutex.Unlock() + + if !peer.isRunning.Swap(false) { + return + } + + device := peer.device + device.log.Debug.Println("Stopping:", peer.String()) + + // stop & wait for ongoing peer routines + + peer.routines.stop.Broadcast() + peer.routines.starting.Wait() + peer.routines.stopping.Wait() + + // stop timers + + peer.timer.keepalivePersistent.Stop() + peer.timer.keepalivePassive.Stop() + peer.timer.zeroAllKeys.Stop() + peer.timer.handshakeNew.Stop() + peer.timer.handshakeDeadline.Stop() + peer.timer.handshakeTimeout.Stop() + + // close queues + + close(peer.queue.nonce) + close(peer.queue.outbound) + close(peer.queue.inbound) + + // clear key pairs + + kp := &peer.keyPairs + kp.mutex.Lock() + + device.DeleteKeyPair(kp.previous) + device.DeleteKeyPair(kp.current) + device.DeleteKeyPair(kp.next) + + kp.previous = nil + kp.current = nil + kp.next = nil + kp.mutex.Unlock() + + // clear handshake state + + hs := &peer.handshake + hs.mutex.Lock() + device.indices.Delete(hs.localIndex) + hs.Clear() + hs.mutex.Unlock() +} diff --git a/ratelimiter.go b/ratelimiter.go new file mode 100644 index 0000000..6e5f005 --- /dev/null +++ b/ratelimiter.go @@ -0,0 +1,139 @@ +package main + +/* Copyright (C) 2015-2017 Jason A. Donenfeld . All Rights Reserved. */ + +/* This file contains a port of the ratelimited from the linux kernel version + */ + +import ( + "net" + "sync" + "time" +) + +const ( + RatelimiterPacketsPerSecond = 20 + RatelimiterPacketsBurstable = 5 + RatelimiterGarbageCollectTime = time.Second + RatelimiterPacketCost = 1000000000 / RatelimiterPacketsPerSecond + RatelimiterMaxTokens = RatelimiterPacketCost * RatelimiterPacketsBurstable +) + +type RatelimiterEntry struct { + mutex sync.Mutex + lastTime time.Time + tokens int64 +} + +type Ratelimiter struct { + mutex sync.RWMutex + lastGarbageCollect time.Time + tableIPv4 map[[net.IPv4len]byte]*RatelimiterEntry + tableIPv6 map[[net.IPv6len]byte]*RatelimiterEntry +} + +func (rate *Ratelimiter) Init() { + rate.mutex.Lock() + defer rate.mutex.Unlock() + rate.tableIPv4 = make(map[[net.IPv4len]byte]*RatelimiterEntry) + rate.tableIPv6 = make(map[[net.IPv6len]byte]*RatelimiterEntry) + rate.lastGarbageCollect = time.Now() +} + +func (rate *Ratelimiter) GarbageCollectEntries() { + rate.mutex.Lock() + + // remove unused IPv4 entries + + for key, entry := range rate.tableIPv4 { + entry.mutex.Lock() + if time.Now().Sub(entry.lastTime) > RatelimiterGarbageCollectTime { + delete(rate.tableIPv4, key) + } + entry.mutex.Unlock() + } + + // remove unused IPv6 entries + + for key, entry := range rate.tableIPv6 { + entry.mutex.Lock() + if time.Now().Sub(entry.lastTime) > RatelimiterGarbageCollectTime { + delete(rate.tableIPv6, key) + } + entry.mutex.Unlock() + } + + rate.mutex.Unlock() +} + +func (rate *Ratelimiter) RoutineGarbageCollector(stop Signal) { + timer := time.NewTimer(time.Second) + for { + select { + case <-stop.Wait(): + return + case <-timer.C: + rate.GarbageCollectEntries() + timer.Reset(time.Second) + } + } +} + +func (rate *Ratelimiter) Allow(ip net.IP) bool { + var entry *RatelimiterEntry + var KeyIPv4 [net.IPv4len]byte + var KeyIPv6 [net.IPv6len]byte + + // lookup entry + + IPv4 := ip.To4() + IPv6 := ip.To16() + + rate.mutex.RLock() + + if IPv4 != nil { + copy(KeyIPv4[:], IPv4) + entry = rate.tableIPv4[KeyIPv4] + } else { + copy(KeyIPv6[:], IPv6) + entry = rate.tableIPv6[KeyIPv6] + } + + rate.mutex.RUnlock() + + // make new entry if not found + + if entry == nil { + rate.mutex.Lock() + entry = new(RatelimiterEntry) + entry.tokens = RatelimiterMaxTokens - RatelimiterPacketCost + entry.lastTime = time.Now() + if IPv4 != nil { + rate.tableIPv4[KeyIPv4] = entry + } else { + rate.tableIPv6[KeyIPv6] = entry + } + rate.mutex.Unlock() + return true + } + + // add tokens to entry + + entry.mutex.Lock() + now := time.Now() + entry.tokens += now.Sub(entry.lastTime).Nanoseconds() + entry.lastTime = now + if entry.tokens > RatelimiterMaxTokens { + entry.tokens = RatelimiterMaxTokens + } + + // subtract cost of packet + + if entry.tokens > RatelimiterPacketCost { + entry.tokens -= RatelimiterPacketCost + entry.mutex.Unlock() + return true + } + entry.mutex.Unlock() + return false +} diff --git a/ratelimiter_test.go b/ratelimiter_test.go new file mode 100644 index 0000000..13b6a23 --- /dev/null +++ b/ratelimiter_test.go @@ -0,0 +1,98 @@ +package main + +import ( + "net" + "testing" + "time" +) + +type RatelimiterResult struct { + allowed bool + text string + wait time.Duration +} + +func TestRatelimiter(t *testing.T) { + + var ratelimiter Ratelimiter + var expectedResults []RatelimiterResult + + Nano := func(nano int64) time.Duration { + return time.Nanosecond * time.Duration(nano) + } + + Add := func(res RatelimiterResult) { + expectedResults = append( + expectedResults, + res, + ) + } + + for i := 0; i < RatelimiterPacketsBurstable; i++ { + Add(RatelimiterResult{ + allowed: true, + text: "inital burst", + }) + } + + Add(RatelimiterResult{ + allowed: false, + text: "after burst", + }) + + Add(RatelimiterResult{ + allowed: true, + wait: Nano(time.Second.Nanoseconds() / RatelimiterPacketsPerSecond), + text: "filling tokens for single packet", + }) + + Add(RatelimiterResult{ + allowed: false, + text: "not having refilled enough", + }) + + Add(RatelimiterResult{ + allowed: true, + wait: 2 * Nano(time.Second.Nanoseconds()/RatelimiterPacketsPerSecond), + text: "filling tokens for two packet burst", + }) + + Add(RatelimiterResult{ + allowed: true, + text: "second packet in 2 packet burst", + }) + + Add(RatelimiterResult{ + allowed: false, + text: "packet following 2 packet burst", + }) + + ips := []net.IP{ + net.ParseIP("127.0.0.1"), + net.ParseIP("192.168.1.1"), + net.ParseIP("172.167.2.3"), + net.ParseIP("97.231.252.215"), + net.ParseIP("248.97.91.167"), + net.ParseIP("188.208.233.47"), + net.ParseIP("104.2.183.179"), + net.ParseIP("72.129.46.120"), + net.ParseIP("2001:0db8:0a0b:12f0:0000:0000:0000:0001"), + net.ParseIP("f5c2:818f:c052:655a:9860:b136:6894:25f0"), + net.ParseIP("b2d7:15ab:48a7:b07c:a541:f144:a9fe:54fc"), + net.ParseIP("a47b:786e:1671:a22b:d6f9:4ab0:abc7:c918"), + net.ParseIP("ea1e:d155:7f7a:98fb:2bf5:9483:80f6:5445"), + net.ParseIP("3f0e:54a2:f5b4:cd19:a21d:58e1:3746:84c4"), + } + + ratelimiter.Init() + + for i, res := range expectedResults { + time.Sleep(res.wait) + for _, ip := range ips { + allowed := ratelimiter.Allow(ip) + if allowed != res.allowed { + t.Fatal("Test failed for", ip.String(), ", on:", i, "(", res.text, ")", "expected:", res.allowed, "got:", allowed) + } + } + } +} diff --git a/receive.go b/receive.go new file mode 100644 index 0000000..1f44df2 --- /dev/null +++ b/receive.go @@ -0,0 +1,642 @@ +package main + +import ( + "bytes" + "encoding/binary" + "golang.org/x/crypto/chacha20poly1305" + "golang.org/x/net/ipv4" + "golang.org/x/net/ipv6" + "net" + "sync" + "sync/atomic" + "time" +) + +type QueueHandshakeElement struct { + msgType uint32 + packet []byte + endpoint Endpoint + buffer *[MaxMessageSize]byte +} + +type QueueInboundElement struct { + dropped int32 + mutex sync.Mutex + buffer *[MaxMessageSize]byte + packet []byte + counter uint64 + keyPair *KeyPair + endpoint Endpoint +} + +func (elem *QueueInboundElement) Drop() { + atomic.StoreInt32(&elem.dropped, AtomicTrue) +} + +func (elem *QueueInboundElement) IsDropped() bool { + return atomic.LoadInt32(&elem.dropped) == AtomicTrue +} + +func (device *Device) addToInboundQueue( + queue chan *QueueInboundElement, + element *QueueInboundElement, +) { + for { + select { + case queue <- element: + return + default: + select { + case old := <-queue: + old.Drop() + default: + } + } + } +} + +func (device *Device) addToDecryptionQueue( + queue chan *QueueInboundElement, + element *QueueInboundElement, +) { + for { + select { + case queue <- element: + return + default: + select { + case old := <-queue: + // drop & release to potential consumer + old.Drop() + old.mutex.Unlock() + default: + } + } + } +} + +func (device *Device) addToHandshakeQueue( + queue chan QueueHandshakeElement, + element QueueHandshakeElement, +) { + for { + select { + case queue <- element: + return + default: + select { + case elem := <-queue: + device.PutMessageBuffer(elem.buffer) + default: + } + } + } +} + +/* Receives incoming datagrams for the device + * + * Every time the bind is updated a new routine is started for + * IPv4 and IPv6 (separately) + */ +func (device *Device) RoutineReceiveIncoming(IP int, bind Bind) { + + logDebug := device.log.Debug + logDebug.Println("Routine, receive incoming, IP version:", IP) + + // receive datagrams until conn is closed + + buffer := device.GetMessageBuffer() + + var ( + err error + size int + endpoint Endpoint + ) + + for { + + // read next datagram + + switch IP { + case ipv4.Version: + size, endpoint, err = bind.ReceiveIPv4(buffer[:]) + case ipv6.Version: + size, endpoint, err = bind.ReceiveIPv6(buffer[:]) + default: + panic("invalid IP version") + } + + if err != nil { + return + } + + if size < MinMessageSize { + continue + } + + // check size of packet + + packet := buffer[:size] + msgType := binary.LittleEndian.Uint32(packet[:4]) + + var okay bool + + switch msgType { + + // check if transport + + case MessageTransportType: + + // check size + + if len(packet) < MessageTransportType { + continue + } + + // lookup key pair + + receiver := binary.LittleEndian.Uint32( + packet[MessageTransportOffsetReceiver:MessageTransportOffsetCounter], + ) + value := device.indices.Lookup(receiver) + keyPair := value.keyPair + if keyPair == nil { + continue + } + + // check key-pair expiry + + if keyPair.created.Add(RejectAfterTime).Before(time.Now()) { + continue + } + + // create work element + + peer := value.peer + elem := &QueueInboundElement{ + packet: packet, + buffer: buffer, + keyPair: keyPair, + dropped: AtomicFalse, + endpoint: endpoint, + } + elem.mutex.Lock() + + // add to decryption queues + + if peer.isRunning.Get() { + device.addToDecryptionQueue(device.queue.decryption, elem) + device.addToInboundQueue(peer.queue.inbound, elem) + buffer = device.GetMessageBuffer() + } + + continue + + // otherwise it is a fixed size & handshake related packet + + case MessageInitiationType: + okay = len(packet) == MessageInitiationSize + + case MessageResponseType: + okay = len(packet) == MessageResponseSize + + case MessageCookieReplyType: + okay = len(packet) == MessageCookieReplySize + } + + if okay { + device.addToHandshakeQueue( + device.queue.handshake, + QueueHandshakeElement{ + msgType: msgType, + buffer: buffer, + packet: packet, + endpoint: endpoint, + }, + ) + buffer = device.GetMessageBuffer() + } + } +} + +func (device *Device) RoutineDecryption() { + + var nonce [chacha20poly1305.NonceSize]byte + + logDebug := device.log.Debug + logDebug.Println("Routine, decryption, started for device") + + for { + select { + case <-device.signal.stop.Wait(): + logDebug.Println("Routine, decryption worker, stopped") + return + + case elem := <-device.queue.decryption: + + // check if dropped + + if elem.IsDropped() { + continue + } + + // split message into fields + + counter := elem.packet[MessageTransportOffsetCounter:MessageTransportOffsetContent] + content := elem.packet[MessageTransportOffsetContent:] + + // expand nonce + + nonce[0x4] = counter[0x0] + nonce[0x5] = counter[0x1] + nonce[0x6] = counter[0x2] + nonce[0x7] = counter[0x3] + + nonce[0x8] = counter[0x4] + nonce[0x9] = counter[0x5] + nonce[0xa] = counter[0x6] + nonce[0xb] = counter[0x7] + + // decrypt and release to consumer + + var err error + elem.counter = binary.LittleEndian.Uint64(counter) + elem.packet, err = elem.keyPair.receive.Open( + content[:0], + nonce[:], + content, + nil, + ) + if err != nil { + elem.Drop() + } + elem.mutex.Unlock() + } + } +} + +/* Handles incoming packets related to handshake + */ +func (device *Device) RoutineHandshake() { + + logInfo := device.log.Info + logError := device.log.Error + logDebug := device.log.Debug + logDebug.Println("Routine, handshake routine, started for device") + + var temp [MessageHandshakeSize]byte + var elem QueueHandshakeElement + + for { + select { + case elem = <-device.queue.handshake: + case <-device.signal.stop.Wait(): + return + } + + // handle cookie fields and ratelimiting + + switch elem.msgType { + + case MessageCookieReplyType: + + // unmarshal packet + + var reply MessageCookieReply + reader := bytes.NewReader(elem.packet) + err := binary.Read(reader, binary.LittleEndian, &reply) + if err != nil { + logDebug.Println("Failed to decode cookie reply") + return + } + + // lookup peer from index + + entry := device.indices.Lookup(reply.Receiver) + + if entry.peer == nil { + continue + } + + // consume reply + + if peer := entry.peer; peer.isRunning.Get() { + peer.mac.ConsumeReply(&reply) + } + + continue + + case MessageInitiationType, MessageResponseType: + + // check mac fields and ratelimit + + if !device.mac.CheckMAC1(elem.packet) { + logDebug.Println("Received packet with invalid mac1") + continue + } + + // endpoints destination address is the source of the datagram + + srcBytes := elem.endpoint.DstToBytes() + + if device.IsUnderLoad() { + + // verify MAC2 field + + if !device.mac.CheckMAC2(elem.packet, srcBytes) { + + // construct cookie reply + + logDebug.Println( + "Sending cookie reply to:", + elem.endpoint.DstToString(), + ) + + sender := binary.LittleEndian.Uint32(elem.packet[4:8]) + reply, err := device.mac.CreateReply(elem.packet, sender, srcBytes) + if err != nil { + logError.Println("Failed to create cookie reply:", err) + continue + } + + // marshal and send reply + + writer := bytes.NewBuffer(temp[:0]) + binary.Write(writer, binary.LittleEndian, reply) + device.net.bind.Send(writer.Bytes(), elem.endpoint) + if err != nil { + logDebug.Println("Failed to send cookie reply:", err) + } + continue + } + + // check ratelimiter + + if !device.rate.limiter.Allow(elem.endpoint.DstIP()) { + continue + } + } + + default: + logError.Println("Invalid packet ended up in the handshake queue") + continue + } + + // handle handshake initiation/response content + + switch elem.msgType { + case MessageInitiationType: + + // unmarshal + + var msg MessageInitiation + reader := bytes.NewReader(elem.packet) + err := binary.Read(reader, binary.LittleEndian, &msg) + if err != nil { + logError.Println("Failed to decode initiation message") + continue + } + + // consume initiation + + peer := device.ConsumeMessageInitiation(&msg) + if peer == nil { + logInfo.Println( + "Received invalid initiation message from", + elem.endpoint.DstToString(), + ) + continue + } + + // update timers + + peer.TimerAnyAuthenticatedPacketTraversal() + peer.TimerAnyAuthenticatedPacketReceived() + + // update endpoint + + peer.mutex.Lock() + peer.endpoint = elem.endpoint + peer.mutex.Unlock() + + // create response + + response, err := device.CreateMessageResponse(peer) + if err != nil { + logError.Println("Failed to create response message:", err) + continue + } + + peer.TimerEphemeralKeyCreated() + peer.NewKeyPair() + + logDebug.Println("Creating response message for", peer.String()) + + writer := bytes.NewBuffer(temp[:0]) + binary.Write(writer, binary.LittleEndian, response) + packet := writer.Bytes() + peer.mac.AddMacs(packet) + + // send response + + err = peer.SendBuffer(packet) + if err == nil { + peer.TimerAnyAuthenticatedPacketTraversal() + } else { + logError.Println("Failed to send response to:", peer.String(), err) + } + + case MessageResponseType: + + // unmarshal + + var msg MessageResponse + reader := bytes.NewReader(elem.packet) + err := binary.Read(reader, binary.LittleEndian, &msg) + if err != nil { + logError.Println("Failed to decode response message") + continue + } + + // consume response + + peer := device.ConsumeMessageResponse(&msg) + if peer == nil { + logInfo.Println( + "Recieved invalid response message from", + elem.endpoint.DstToString(), + ) + continue + } + + // update endpoint + + peer.mutex.Lock() + peer.endpoint = elem.endpoint + peer.mutex.Unlock() + + logDebug.Println("Received handshake initiation from", peer) + + peer.TimerEphemeralKeyCreated() + + // update timers + + peer.TimerAnyAuthenticatedPacketTraversal() + peer.TimerAnyAuthenticatedPacketReceived() + peer.TimerHandshakeComplete() + + // derive key-pair + + peer.NewKeyPair() + peer.SendKeepAlive() + } + } +} + +func (peer *Peer) RoutineSequentialReceiver() { + + defer peer.routines.stopping.Done() + + device := peer.device + + logInfo := device.log.Info + logError := device.log.Error + logDebug := device.log.Debug + logDebug.Println("Routine, sequential receiver, started for peer", peer.String()) + + peer.routines.starting.Done() + + for { + + select { + + case <-peer.routines.stop.Wait(): + logDebug.Println("Routine, sequential receiver, stopped for peer", peer.String()) + return + + case elem := <-peer.queue.inbound: + + // wait for decryption + + elem.mutex.Lock() + + if elem.IsDropped() { + continue + } + + // check for replay + + if !elem.keyPair.replayFilter.ValidateCounter(elem.counter) { + continue + } + + peer.TimerAnyAuthenticatedPacketTraversal() + peer.TimerAnyAuthenticatedPacketReceived() + peer.KeepKeyFreshReceiving() + + // check if using new key-pair + + kp := &peer.keyPairs + kp.mutex.Lock() + if kp.next == elem.keyPair { + peer.TimerHandshakeComplete() + if kp.previous != nil { + device.DeleteKeyPair(kp.previous) + } + kp.previous = kp.current + kp.current = kp.next + kp.next = nil + } + kp.mutex.Unlock() + + // update endpoint + + peer.mutex.Lock() + peer.endpoint = elem.endpoint + peer.mutex.Unlock() + + // check for keep-alive + + if len(elem.packet) == 0 { + logDebug.Println("Received keep-alive from", peer.String()) + continue + } + peer.TimerDataReceived() + + // verify source and strip padding + + switch elem.packet[0] >> 4 { + case ipv4.Version: + + // strip padding + + if len(elem.packet) < ipv4.HeaderLen { + continue + } + + field := elem.packet[IPv4offsetTotalLength : IPv4offsetTotalLength+2] + length := binary.BigEndian.Uint16(field) + if int(length) > len(elem.packet) || int(length) < ipv4.HeaderLen { + continue + } + + elem.packet = elem.packet[:length] + + // verify IPv4 source + + src := elem.packet[IPv4offsetSrc : IPv4offsetSrc+net.IPv4len] + if device.routing.table.LookupIPv4(src) != peer { + logInfo.Println( + "IPv4 packet with disallowed source address from", + peer.String(), + ) + continue + } + + case ipv6.Version: + + // strip padding + + if len(elem.packet) < ipv6.HeaderLen { + continue + } + + field := elem.packet[IPv6offsetPayloadLength : IPv6offsetPayloadLength+2] + length := binary.BigEndian.Uint16(field) + length += ipv6.HeaderLen + if int(length) > len(elem.packet) { + continue + } + + elem.packet = elem.packet[:length] + + // verify IPv6 source + + src := elem.packet[IPv6offsetSrc : IPv6offsetSrc+net.IPv6len] + if device.routing.table.LookupIPv6(src) != peer { + logInfo.Println( + "IPv6 packet with disallowed source address from", + peer.String(), + ) + continue + } + + default: + logInfo.Println("Packet with invalid IP version from", peer.String()) + continue + } + + // write to tun device + + offset := MessageTransportOffsetContent + atomic.AddUint64(&peer.stats.rxBytes, uint64(len(elem.packet))) + _, err := device.tun.device.Write( + elem.buffer[:offset+len(elem.packet)], + offset) + device.PutMessageBuffer(elem.buffer) + if err != nil { + logError.Println("Failed to write packet to TUN device:", err) + } + } + } +} diff --git a/replay.go b/replay.go new file mode 100644 index 0000000..5d42860 --- /dev/null +++ b/replay.go @@ -0,0 +1,73 @@ +package main + +/* Copyright (C) 2015-2017 Jason A. Donenfeld . All Rights Reserved. */ + +/* Implementation of RFC6479 + * https://tools.ietf.org/html/rfc6479 + * + * The implementation is not safe for concurrent use! + */ + +const ( + // See: https://golang.org/src/math/big/arith.go + _Wordm = ^uintptr(0) + _WordLogSize = _Wordm>>8&1 + _Wordm>>16&1 + _Wordm>>32&1 + _WordSize = 1 << _WordLogSize +) + +const ( + CounterRedundantBitsLog = _WordLogSize + 3 + CounterRedundantBits = _WordSize * 8 + CounterBitsTotal = 2048 + CounterWindowSize = uint64(CounterBitsTotal - CounterRedundantBits) +) + +const ( + BacktrackWords = CounterBitsTotal / _WordSize +) + +type ReplayFilter struct { + counter uint64 + backtrack [BacktrackWords]uintptr +} + +func (filter *ReplayFilter) Init() { + filter.counter = 0 + filter.backtrack[0] = 0 +} + +func (filter *ReplayFilter) ValidateCounter(counter uint64) bool { + if counter >= RejectAfterMessages { + return false + } + + indexWord := counter >> CounterRedundantBitsLog + + if counter > filter.counter { + + // move window forward + + current := filter.counter >> CounterRedundantBitsLog + diff := minUint64(indexWord-current, BacktrackWords) + for i := uint64(1); i <= diff; i++ { + filter.backtrack[(current+i)%BacktrackWords] = 0 + } + filter.counter = counter + + } else if filter.counter-counter > CounterWindowSize { + + // behind current window + + return false + } + + indexWord %= BacktrackWords + indexBit := counter & uint64(CounterRedundantBits-1) + + // check and set bit + + oldValue := filter.backtrack[indexWord] + newValue := oldValue | (1 << indexBit) + filter.backtrack[indexWord] = newValue + return oldValue != newValue +} diff --git a/replay_test.go b/replay_test.go new file mode 100644 index 0000000..228fce6 --- /dev/null +++ b/replay_test.go @@ -0,0 +1,112 @@ +package main + +import ( + "testing" +) + +/* Ported from the linux kernel implementation + * + * + */ + +func TestReplay(t *testing.T) { + var filter ReplayFilter + + T_LIM := CounterWindowSize + 1 + + testNumber := 0 + T := func(n uint64, v bool) { + testNumber++ + if filter.ValidateCounter(n) != v { + t.Fatal("Test", testNumber, "failed", n, v) + } + } + + filter.Init() + + /* 1 */ T(0, true) + /* 2 */ T(1, true) + /* 3 */ T(1, false) + /* 4 */ T(9, true) + /* 5 */ T(8, true) + /* 6 */ T(7, true) + /* 7 */ T(7, false) + /* 8 */ T(T_LIM, true) + /* 9 */ T(T_LIM-1, true) + /* 10 */ T(T_LIM-1, false) + /* 11 */ T(T_LIM-2, true) + /* 12 */ T(2, true) + /* 13 */ T(2, false) + /* 14 */ T(T_LIM+16, true) + /* 15 */ T(3, false) + /* 16 */ T(T_LIM+16, false) + /* 17 */ T(T_LIM*4, true) + /* 18 */ T(T_LIM*4-(T_LIM-1), true) + /* 19 */ T(10, false) + /* 20 */ T(T_LIM*4-T_LIM, false) + /* 21 */ T(T_LIM*4-(T_LIM+1), false) + /* 22 */ T(T_LIM*4-(T_LIM-2), true) + /* 23 */ T(T_LIM*4+1-T_LIM, false) + /* 24 */ T(0, false) + /* 25 */ T(RejectAfterMessages, false) + /* 26 */ T(RejectAfterMessages-1, true) + /* 27 */ T(RejectAfterMessages, false) + /* 28 */ T(RejectAfterMessages-1, false) + /* 29 */ T(RejectAfterMessages-2, true) + /* 30 */ T(RejectAfterMessages+1, false) + /* 31 */ T(RejectAfterMessages+2, false) + /* 32 */ T(RejectAfterMessages-2, false) + /* 33 */ T(RejectAfterMessages-3, true) + /* 34 */ T(0, false) + + t.Log("Bulk test 1") + filter.Init() + testNumber = 0 + for i := uint64(1); i <= CounterWindowSize; i++ { + T(i, true) + } + T(0, true) + T(0, false) + + t.Log("Bulk test 2") + filter.Init() + testNumber = 0 + for i := uint64(2); i <= CounterWindowSize+1; i++ { + T(i, true) + } + T(1, true) + T(0, false) + + t.Log("Bulk test 3") + filter.Init() + testNumber = 0 + for i := CounterWindowSize + 1; i > 0; i-- { + T(i, true) + } + + t.Log("Bulk test 4") + filter.Init() + testNumber = 0 + for i := CounterWindowSize + 2; i > 1; i-- { + T(i, true) + } + T(0, false) + + t.Log("Bulk test 5") + filter.Init() + testNumber = 0 + for i := CounterWindowSize; i > 0; i-- { + T(i, true) + } + T(CounterWindowSize+1, true) + T(0, false) + + t.Log("Bulk test 6") + filter.Init() + testNumber = 0 + for i := CounterWindowSize; i > 0; i-- { + T(i, true) + } + T(0, true) + T(CounterWindowSize+1, true) +} diff --git a/routing.go b/routing.go new file mode 100644 index 0000000..2a2e237 --- /dev/null +++ b/routing.go @@ -0,0 +1,65 @@ +package main + +import ( + "errors" + "net" + "sync" +) + +type RoutingTable struct { + IPv4 *Trie + IPv6 *Trie + mutex sync.RWMutex +} + +func (table *RoutingTable) AllowedIPs(peer *Peer) []net.IPNet { + table.mutex.RLock() + defer table.mutex.RUnlock() + + allowed := make([]net.IPNet, 0, 10) + allowed = table.IPv4.AllowedIPs(peer, allowed) + allowed = table.IPv6.AllowedIPs(peer, allowed) + return allowed +} + +func (table *RoutingTable) Reset() { + table.mutex.Lock() + defer table.mutex.Unlock() + + table.IPv4 = nil + table.IPv6 = nil +} + +func (table *RoutingTable) RemovePeer(peer *Peer) { + table.mutex.Lock() + defer table.mutex.Unlock() + + table.IPv4 = table.IPv4.RemovePeer(peer) + table.IPv6 = table.IPv6.RemovePeer(peer) +} + +func (table *RoutingTable) Insert(ip net.IP, cidr uint, peer *Peer) { + table.mutex.Lock() + defer table.mutex.Unlock() + + switch len(ip) { + case net.IPv6len: + table.IPv6 = table.IPv6.Insert(ip, cidr, peer) + case net.IPv4len: + table.IPv4 = table.IPv4.Insert(ip, cidr, peer) + default: + panic(errors.New("Inserting unknown address type")) + } +} + +func (table *RoutingTable) LookupIPv4(address []byte) *Peer { + table.mutex.RLock() + defer table.mutex.RUnlock() + return table.IPv4.Lookup(address) +} + +func (table *RoutingTable) LookupIPv6(address []byte) *Peer { + table.mutex.RLock() + defer table.mutex.RUnlock() + return table.IPv6.Lookup(address) +} diff --git a/send.go b/send.go new file mode 100644 index 0000000..7488d3a --- /dev/null +++ b/send.go @@ -0,0 +1,362 @@ +package main + +import ( + "encoding/binary" + "golang.org/x/crypto/chacha20poly1305" + "golang.org/x/net/ipv4" + "golang.org/x/net/ipv6" + "net" + "sync" + "sync/atomic" + "time" +) + +/* Outbound flow + * + * 1. TUN queue + * 2. Routing (sequential) + * 3. Nonce assignment (sequential) + * 4. Encryption (parallel) + * 5. Transmission (sequential) + * + * The functions in this file occur (roughly) in the order in + * which the packets are processed. + * + * Locking, Producers and Consumers + * + * The order of packets (per peer) must be maintained, + * but encryption of packets happen out-of-order: + * + * The sequential consumers will attempt to take the lock, + * workers release lock when they have completed work (encryption) on the packet. + * + * If the element is inserted into the "encryption queue", + * the content is preceded by enough "junk" to contain the transport header + * (to allow the construction of transport messages in-place) + */ + +type QueueOutboundElement struct { + dropped int32 + mutex sync.Mutex + buffer *[MaxMessageSize]byte // slice holding the packet data + packet []byte // slice of "buffer" (always!) + nonce uint64 // nonce for encryption + keyPair *KeyPair // key-pair for encryption + peer *Peer // related peer +} + +func (peer *Peer) FlushNonceQueue() { + elems := len(peer.queue.nonce) + for i := 0; i < elems; i++ { + select { + case <-peer.queue.nonce: + default: + return + } + } +} + +func (device *Device) NewOutboundElement() *QueueOutboundElement { + return &QueueOutboundElement{ + dropped: AtomicFalse, + buffer: device.pool.messageBuffers.Get().(*[MaxMessageSize]byte), + } +} + +func (elem *QueueOutboundElement) Drop() { + atomic.StoreInt32(&elem.dropped, AtomicTrue) +} + +func (elem *QueueOutboundElement) IsDropped() bool { + return atomic.LoadInt32(&elem.dropped) == AtomicTrue +} + +func addToOutboundQueue( + queue chan *QueueOutboundElement, + element *QueueOutboundElement, +) { + for { + select { + case queue <- element: + return + default: + select { + case old := <-queue: + old.Drop() + default: + } + } + } +} + +func addToEncryptionQueue( + queue chan *QueueOutboundElement, + element *QueueOutboundElement, +) { + for { + select { + case queue <- element: + return + default: + select { + case old := <-queue: + // drop & release to potential consumer + old.Drop() + old.mutex.Unlock() + default: + } + } + } +} + +/* Reads packets from the TUN and inserts + * into nonce queue for peer + * + * Obs. Single instance per TUN device + */ +func (device *Device) RoutineReadFromTUN() { + + elem := device.NewOutboundElement() + + logDebug := device.log.Debug + logError := device.log.Error + + logDebug.Println("Routine, TUN Reader started") + + for { + + // read packet + + offset := MessageTransportHeaderSize + size, err := device.tun.device.Read(elem.buffer[:], offset) + + if err != nil { + logError.Println("Failed to read packet from TUN device:", err) + device.Close() + return + } + + if size == 0 || size > MaxContentSize { + continue + } + + elem.packet = elem.buffer[offset : offset+size] + + // lookup peer + + var peer *Peer + switch elem.packet[0] >> 4 { + case ipv4.Version: + if len(elem.packet) < ipv4.HeaderLen { + continue + } + dst := elem.packet[IPv4offsetDst : IPv4offsetDst+net.IPv4len] + peer = device.routing.table.LookupIPv4(dst) + + case ipv6.Version: + if len(elem.packet) < ipv6.HeaderLen { + continue + } + dst := elem.packet[IPv6offsetDst : IPv6offsetDst+net.IPv6len] + peer = device.routing.table.LookupIPv6(dst) + + default: + logDebug.Println("Received packet with unknown IP version") + } + + if peer == nil { + continue + } + + // insert into nonce/pre-handshake queue + + if peer.isRunning.Get() { + peer.timer.handshakeDeadline.Reset(RekeyAttemptTime) + addToOutboundQueue(peer.queue.nonce, elem) + elem = device.NewOutboundElement() + } + } +} + +/* Queues packets when there is no handshake. + * Then assigns nonces to packets sequentially + * and creates "work" structs for workers + * + * Obs. A single instance per peer + */ +func (peer *Peer) RoutineNonce() { + var keyPair *KeyPair + + defer peer.routines.stopping.Done() + + device := peer.device + logDebug := device.log.Debug + logDebug.Println("Routine, nonce worker, started for peer", peer.String()) + + peer.routines.starting.Done() + + for { + NextPacket: + select { + case <-peer.routines.stop.Wait(): + return + + case elem := <-peer.queue.nonce: + + // wait for key pair + + for { + keyPair = peer.keyPairs.Current() + if keyPair != nil && keyPair.sendNonce < RejectAfterMessages { + if time.Now().Sub(keyPair.created) < RejectAfterTime { + break + } + } + + peer.signal.handshakeBegin.Send() + + logDebug.Println("Awaiting key-pair for", peer.String()) + + select { + case <-peer.signal.newKeyPair.Wait(): + case <-peer.signal.flushNonceQueue.Wait(): + logDebug.Println("Clearing queue for", peer.String()) + peer.FlushNonceQueue() + goto NextPacket + case <-peer.routines.stop.Wait(): + return + } + } + + // populate work element + + elem.peer = peer + elem.nonce = atomic.AddUint64(&keyPair.sendNonce, 1) - 1 + elem.keyPair = keyPair + elem.dropped = AtomicFalse + elem.mutex.Lock() + + // add to parallel and sequential queue + + addToEncryptionQueue(device.queue.encryption, elem) + addToOutboundQueue(peer.queue.outbound, elem) + } + } +} + +/* Encrypts the elements in the queue + * and marks them for sequential consumption (by releasing the mutex) + * + * Obs. One instance per core + */ +func (device *Device) RoutineEncryption() { + + var nonce [chacha20poly1305.NonceSize]byte + + logDebug := device.log.Debug + logDebug.Println("Routine, encryption worker, started") + + for { + + // fetch next element + + select { + case <-device.signal.stop.Wait(): + logDebug.Println("Routine, encryption worker, stopped") + return + + case elem := <-device.queue.encryption: + + // check if dropped + + if elem.IsDropped() { + continue + } + + // populate header fields + + header := elem.buffer[:MessageTransportHeaderSize] + + fieldType := header[0:4] + fieldReceiver := header[4:8] + fieldNonce := header[8:16] + + binary.LittleEndian.PutUint32(fieldType, MessageTransportType) + binary.LittleEndian.PutUint32(fieldReceiver, elem.keyPair.remoteIndex) + binary.LittleEndian.PutUint64(fieldNonce, elem.nonce) + + // pad content to multiple of 16 + + mtu := int(atomic.LoadInt32(&device.tun.mtu)) + rem := len(elem.packet) % PaddingMultiple + if rem > 0 { + for i := 0; i < PaddingMultiple-rem && len(elem.packet) < mtu; i++ { + elem.packet = append(elem.packet, 0) + } + } + + // encrypt content and release to consumer + + binary.LittleEndian.PutUint64(nonce[4:], elem.nonce) + elem.packet = elem.keyPair.send.Seal( + header, + nonce[:], + elem.packet, + nil, + ) + elem.mutex.Unlock() + } + } +} + +/* Sequentially reads packets from queue and sends to endpoint + * + * Obs. Single instance per peer. + * The routine terminates then the outbound queue is closed. + */ +func (peer *Peer) RoutineSequentialSender() { + + defer peer.routines.stopping.Done() + + device := peer.device + + logDebug := device.log.Debug + logDebug.Println("Routine, sequential sender, started for", peer.String()) + + peer.routines.starting.Done() + + for { + select { + + case <-peer.routines.stop.Wait(): + logDebug.Println( + "Routine, sequential sender, stopped for", peer.String()) + return + + case elem := <-peer.queue.outbound: + elem.mutex.Lock() + if elem.IsDropped() { + continue + } + + // send message and return buffer to pool + + length := uint64(len(elem.packet)) + err := peer.SendBuffer(elem.packet) + device.PutMessageBuffer(elem.buffer) + if err != nil { + logDebug.Println("Failed to send authenticated packet to peer", peer.String()) + continue + } + atomic.AddUint64(&peer.stats.txBytes, length) + + // update timers + + peer.TimerAnyAuthenticatedPacketTraversal() + if len(elem.packet) != MessageKeepaliveSize { + peer.TimerDataSent() + } + peer.KeepKeyFreshSending() + } + } +} diff --git a/signal.go b/signal.go new file mode 100644 index 0000000..2cefad4 --- /dev/null +++ b/signal.go @@ -0,0 +1,53 @@ +package main + +type Signal struct { + enabled AtomicBool + C chan struct{} +} + +func NewSignal() (s Signal) { + s.C = make(chan struct{}, 1) + s.Enable() + return +} + +func (s *Signal) Disable() { + s.enabled.Set(false) + s.Clear() +} + +func (s *Signal) Enable() { + s.enabled.Set(true) +} + +/* Unblock exactly one listener + */ +func (s *Signal) Send() { + if s.enabled.Get() { + select { + case s.C <- struct{}{}: + default: + } + } +} + +/* Clear the signal if already fired + */ +func (s Signal) Clear() { + select { + case <-s.C: + default: + } +} + +/* Unblocks all listeners (forever) + */ +func (s Signal) Broadcast() { + close(s.C) +} + +/* Wait for the signal + */ +func (s Signal) Wait() chan struct{} { + return s.C +} diff --git a/src/Makefile b/src/Makefile deleted file mode 100644 index 5b23ecc..0000000 --- a/src/Makefile +++ /dev/null @@ -1,12 +0,0 @@ -all: wireguard-go - -wireguard-go: $(wildcard *.go) - go build -o $@ - -clean: - rm -f wireguard-go - -cloc: - cloc $(filter-out xchacha20.go $(wildcard *_test.go), $(wildcard *.go)) - -.PHONY: clean cloc diff --git a/src/build.cmd b/src/build.cmd deleted file mode 100755 index 52cb883..0000000 --- a/src/build.cmd +++ /dev/null @@ -1,6 +0,0 @@ -@echo off - -REM builds wireguard for windows - -go get -go build -o wireguard-go.exe diff --git a/src/conn.go b/src/conn.go deleted file mode 100644 index fb30ec2..0000000 --- a/src/conn.go +++ /dev/null @@ -1,128 +0,0 @@ -package main - -import ( - "errors" - "golang.org/x/net/ipv4" - "golang.org/x/net/ipv6" - "net" -) - -/* A Bind handles listening on a port for both IPv6 and IPv4 UDP traffic - */ -type Bind interface { - SetMark(value uint32) error - ReceiveIPv6(buff []byte) (int, Endpoint, error) - ReceiveIPv4(buff []byte) (int, Endpoint, error) - Send(buff []byte, end Endpoint) error - Close() error -} - -/* An Endpoint maintains the source/destination caching for a peer - * - * dst : the remote address of a peer ("endpoint" in uapi terminology) - * src : the local address from which datagrams originate going to the peer - */ -type Endpoint interface { - ClearSrc() // clears the source address - SrcToString() string // returns the local source address (ip:port) - DstToString() string // returns the destination address (ip:port) - DstToBytes() []byte // used for mac2 cookie calculations - DstIP() net.IP - SrcIP() net.IP -} - -func parseEndpoint(s string) (*net.UDPAddr, error) { - - // ensure that the host is an IP address - - host, _, err := net.SplitHostPort(s) - if err != nil { - return nil, err - } - if ip := net.ParseIP(host); ip == nil { - return nil, errors.New("Failed to parse IP address: " + host) - } - - // parse address and port - - addr, err := net.ResolveUDPAddr("udp", s) - if err != nil { - return nil, err - } - return addr, err -} - -/* Must hold device and net lock - */ -func unsafeCloseBind(device *Device) error { - var err error - netc := &device.net - if netc.bind != nil { - err = netc.bind.Close() - netc.bind = nil - } - return err -} - -func (device *Device) BindUpdate() error { - - device.net.mutex.Lock() - defer device.net.mutex.Unlock() - - device.peers.mutex.Lock() - defer device.peers.mutex.Unlock() - - // close existing sockets - - if err := unsafeCloseBind(device); err != nil { - return err - } - - // open new sockets - - if device.isUp.Get() { - - // bind to new port - - var err error - netc := &device.net - netc.bind, netc.port, err = CreateBind(netc.port) - if err != nil { - netc.bind = nil - return err - } - - // set mark - - err = netc.bind.SetMark(netc.fwmark) - if err != nil { - return err - } - - // clear cached source addresses - - for _, peer := range device.peers.keyMap { - peer.mutex.Lock() - defer peer.mutex.Unlock() - if peer.endpoint != nil { - peer.endpoint.ClearSrc() - } - } - - // start receiving routines - - go device.RoutineReceiveIncoming(ipv4.Version, netc.bind) - go device.RoutineReceiveIncoming(ipv6.Version, netc.bind) - - device.log.Debug.Println("UDP bind has been updated") - } - - return nil -} - -func (device *Device) BindClose() error { - device.net.mutex.Lock() - err := unsafeCloseBind(device) - device.net.mutex.Unlock() - return err -} diff --git a/src/conn_default.go b/src/conn_default.go deleted file mode 100644 index 5b73c90..0000000 --- a/src/conn_default.go +++ /dev/null @@ -1,131 +0,0 @@ -// +build !linux - -package main - -import ( - "net" -) - -/* This code is meant to be a temporary solution - * on platforms for which the sticky socket / source caching behavior - * has not yet been implemented. - * - * See conn_linux.go for an implementation on the linux platform. - */ - -type NativeBind struct { - ipv4 *net.UDPConn - ipv6 *net.UDPConn -} - -type NativeEndpoint net.UDPAddr - -var _ Bind = (*NativeBind)(nil) -var _ Endpoint = (*NativeEndpoint)(nil) - -func CreateEndpoint(s string) (Endpoint, error) { - addr, err := parseEndpoint(s) - return (*NativeEndpoint)(addr), err -} - -func (_ *NativeEndpoint) ClearSrc() {} - -func (e *NativeEndpoint) DstIP() net.IP { - return (*net.UDPAddr)(e).IP -} - -func (e *NativeEndpoint) SrcIP() net.IP { - return nil // not supported -} - -func (e *NativeEndpoint) DstToBytes() []byte { - addr := (*net.UDPAddr)(e) - out := addr.IP - out = append(out, byte(addr.Port&0xff)) - out = append(out, byte((addr.Port>>8)&0xff)) - return out -} - -func (e *NativeEndpoint) DstToString() string { - return (*net.UDPAddr)(e).String() -} - -func (e *NativeEndpoint) SrcToString() string { - return "" -} - -func listenNet(network string, port int) (*net.UDPConn, int, error) { - - // listen - - conn, err := net.ListenUDP(network, &net.UDPAddr{Port: port}) - if err != nil { - return nil, 0, err - } - - // retrieve port - - laddr := conn.LocalAddr() - uaddr, err := net.ResolveUDPAddr( - laddr.Network(), - laddr.String(), - ) - if err != nil { - return nil, 0, err - } - return conn, uaddr.Port, nil -} - -func CreateBind(uport uint16) (Bind, uint16, error) { - var err error - var bind NativeBind - - port := int(uport) - - bind.ipv4, port, err = listenNet("udp4", port) - if err != nil { - return nil, 0, err - } - - bind.ipv6, port, err = listenNet("udp6", port) - if err != nil { - bind.ipv4.Close() - return nil, 0, err - } - - return &bind, uint16(port), nil -} - -func (bind *NativeBind) Close() error { - err1 := bind.ipv4.Close() - err2 := bind.ipv6.Close() - if err1 != nil { - return err1 - } - return err2 -} - -func (bind *NativeBind) ReceiveIPv4(buff []byte) (int, Endpoint, error) { - n, endpoint, err := bind.ipv4.ReadFromUDP(buff) - return n, (*NativeEndpoint)(endpoint), err -} - -func (bind *NativeBind) ReceiveIPv6(buff []byte) (int, Endpoint, error) { - n, endpoint, err := bind.ipv6.ReadFromUDP(buff) - return n, (*NativeEndpoint)(endpoint), err -} - -func (bind *NativeBind) Send(buff []byte, endpoint Endpoint) error { - var err error - nend := endpoint.(*NativeEndpoint) - if nend.IP.To16() != nil { - _, err = bind.ipv6.WriteToUDP(buff, (*net.UDPAddr)(nend)) - } else { - _, err = bind.ipv4.WriteToUDP(buff, (*net.UDPAddr)(nend)) - } - return err -} - -func (bind *NativeBind) SetMark(_ uint32) error { - return nil -} diff --git a/src/conn_linux.go b/src/conn_linux.go deleted file mode 100644 index cdba74f..0000000 --- a/src/conn_linux.go +++ /dev/null @@ -1,582 +0,0 @@ -/* Copyright 2017 Jason A. Donenfeld . All Rights Reserved. - * - * This implements userspace semantics of "sticky sockets", modeled after - * WireGuard's kernelspace implementation. - */ - -package main - -import ( - "encoding/binary" - "errors" - "golang.org/x/sys/unix" - "net" - "strconv" - "unsafe" -) - -/* Supports source address caching - * - * Currently there is no way to achieve this within the net package: - * See e.g. https://github.com/golang/go/issues/17930 - * So this code is remains platform dependent. - */ -type NativeEndpoint struct { - src unix.RawSockaddrInet6 - dst unix.RawSockaddrInet6 -} - -type NativeBind struct { - sock4 int - sock6 int -} - -var _ Endpoint = (*NativeEndpoint)(nil) -var _ Bind = NativeBind{} - -type IPv4Source struct { - src unix.RawSockaddrInet4 - Ifindex int32 -} - -func htons(val uint16) uint16 { - var out [unsafe.Sizeof(val)]byte - binary.BigEndian.PutUint16(out[:], val) - return *((*uint16)(unsafe.Pointer(&out[0]))) -} - -func ntohs(val uint16) uint16 { - tmp := ((*[unsafe.Sizeof(val)]byte)(unsafe.Pointer(&val))) - return binary.BigEndian.Uint16((*tmp)[:]) -} - -func CreateEndpoint(s string) (Endpoint, error) { - var end NativeEndpoint - addr, err := parseEndpoint(s) - if err != nil { - return nil, err - } - - ipv4 := addr.IP.To4() - if ipv4 != nil { - dst := (*unix.RawSockaddrInet4)(unsafe.Pointer(&end.dst)) - dst.Family = unix.AF_INET - dst.Port = htons(uint16(addr.Port)) - dst.Zero = [8]byte{} - copy(dst.Addr[:], ipv4) - end.ClearSrc() - return &end, nil - } - - ipv6 := addr.IP.To16() - if ipv6 != nil { - zone, err := zoneToUint32(addr.Zone) - if err != nil { - return nil, err - } - dst := &end.dst - dst.Family = unix.AF_INET6 - dst.Port = htons(uint16(addr.Port)) - dst.Flowinfo = 0 - dst.Scope_id = zone - copy(dst.Addr[:], ipv6[:]) - end.ClearSrc() - return &end, nil - } - - return nil, errors.New("Failed to recognize IP address format") -} - -func CreateBind(port uint16) (Bind, uint16, error) { - var err error - var bind NativeBind - - bind.sock6, port, err = create6(port) - if err != nil { - return nil, port, err - } - - bind.sock4, port, err = create4(port) - if err != nil { - unix.Close(bind.sock6) - } - return bind, port, err -} - -func (bind NativeBind) SetMark(value uint32) error { - err := unix.SetsockoptInt( - bind.sock6, - unix.SOL_SOCKET, - unix.SO_MARK, - int(value), - ) - - if err != nil { - return err - } - - return unix.SetsockoptInt( - bind.sock4, - unix.SOL_SOCKET, - unix.SO_MARK, - int(value), - ) -} - -func closeUnblock(fd int) error { - // shutdown to unblock readers - unix.Shutdown(fd, unix.SHUT_RD) - return unix.Close(fd) -} - -func (bind NativeBind) Close() error { - err1 := closeUnblock(bind.sock6) - err2 := closeUnblock(bind.sock4) - if err1 != nil { - return err1 - } - return err2 -} - -func (bind NativeBind) ReceiveIPv6(buff []byte) (int, Endpoint, error) { - var end NativeEndpoint - n, err := receive6( - bind.sock6, - buff, - &end, - ) - return n, &end, err -} - -func (bind NativeBind) ReceiveIPv4(buff []byte) (int, Endpoint, error) { - var end NativeEndpoint - n, err := receive4( - bind.sock4, - buff, - &end, - ) - return n, &end, err -} - -func (bind NativeBind) Send(buff []byte, end Endpoint) error { - nend := end.(*NativeEndpoint) - switch nend.dst.Family { - case unix.AF_INET6: - return send6(bind.sock6, nend, buff) - case unix.AF_INET: - return send4(bind.sock4, nend, buff) - default: - return errors.New("Unknown address family of destination") - } -} - -func sockaddrToString(addr unix.RawSockaddrInet6) string { - var udpAddr net.UDPAddr - - switch addr.Family { - case unix.AF_INET6: - udpAddr.Port = int(ntohs(addr.Port)) - udpAddr.IP = addr.Addr[:] - return udpAddr.String() - - case unix.AF_INET: - ptr := (*unix.RawSockaddrInet4)(unsafe.Pointer(&addr)) - udpAddr.Port = int(ntohs(ptr.Port)) - udpAddr.IP = net.IPv4( - ptr.Addr[0], - ptr.Addr[1], - ptr.Addr[2], - ptr.Addr[3], - ) - return udpAddr.String() - - default: - return "" - } -} - -func rawAddrToIP(addr unix.RawSockaddrInet6) net.IP { - switch addr.Family { - case unix.AF_INET6: - return addr.Addr[:] - case unix.AF_INET: - ptr := (*unix.RawSockaddrInet4)(unsafe.Pointer(&addr)) - return net.IPv4( - ptr.Addr[0], - ptr.Addr[1], - ptr.Addr[2], - ptr.Addr[3], - ) - default: - return nil - } -} - -func (end *NativeEndpoint) SrcIP() net.IP { - return rawAddrToIP(end.src) -} - -func (end *NativeEndpoint) DstIP() net.IP { - return rawAddrToIP(end.dst) -} - -func (end *NativeEndpoint) DstToBytes() []byte { - ptr := unsafe.Pointer(&end.src) - arr := (*[unix.SizeofSockaddrInet6]byte)(ptr) - return arr[:] -} - -func (end *NativeEndpoint) SrcToString() string { - return sockaddrToString(end.src) -} - -func (end *NativeEndpoint) DstToString() string { - return sockaddrToString(end.dst) -} - -func (end *NativeEndpoint) ClearDst() { - end.dst = unix.RawSockaddrInet6{} -} - -func (end *NativeEndpoint) ClearSrc() { - end.src = unix.RawSockaddrInet6{} -} - -func zoneToUint32(zone string) (uint32, error) { - if zone == "" { - return 0, nil - } - if intr, err := net.InterfaceByName(zone); err == nil { - return uint32(intr.Index), nil - } - n, err := strconv.ParseUint(zone, 10, 32) - return uint32(n), err -} - -func create4(port uint16) (int, uint16, error) { - - // create socket - - fd, err := unix.Socket( - unix.AF_INET, - unix.SOCK_DGRAM, - 0, - ) - - if err != nil { - return -1, 0, err - } - - addr := unix.SockaddrInet4{ - Port: int(port), - } - - // set sockopts and bind - - if err := func() error { - if err := unix.SetsockoptInt( - fd, - unix.SOL_SOCKET, - unix.SO_REUSEADDR, - 1, - ); err != nil { - return err - } - - if err := unix.SetsockoptInt( - fd, - unix.IPPROTO_IP, - unix.IP_PKTINFO, - 1, - ); err != nil { - return err - } - - return unix.Bind(fd, &addr) - }(); err != nil { - unix.Close(fd) - } - - return fd, uint16(addr.Port), err -} - -func create6(port uint16) (int, uint16, error) { - - // create socket - - fd, err := unix.Socket( - unix.AF_INET6, - unix.SOCK_DGRAM, - 0, - ) - - if err != nil { - return -1, 0, err - } - - // set sockopts and bind - - addr := unix.SockaddrInet6{ - Port: int(port), - } - - if err := func() error { - - if err := unix.SetsockoptInt( - fd, - unix.SOL_SOCKET, - unix.SO_REUSEADDR, - 1, - ); err != nil { - return err - } - - if err := unix.SetsockoptInt( - fd, - unix.IPPROTO_IPV6, - unix.IPV6_RECVPKTINFO, - 1, - ); err != nil { - return err - } - - if err := unix.SetsockoptInt( - fd, - unix.IPPROTO_IPV6, - unix.IPV6_V6ONLY, - 1, - ); err != nil { - return err - } - - return unix.Bind(fd, &addr) - - }(); err != nil { - unix.Close(fd) - } - - return fd, uint16(addr.Port), err -} - -func send6(sock int, end *NativeEndpoint, buff []byte) error { - - // construct message header - - var iovec unix.Iovec - iovec.Base = (*byte)(unsafe.Pointer(&buff[0])) - iovec.SetLen(len(buff)) - - cmsg := struct { - cmsghdr unix.Cmsghdr - pktinfo unix.Inet6Pktinfo - }{ - unix.Cmsghdr{ - Level: unix.IPPROTO_IPV6, - Type: unix.IPV6_PKTINFO, - Len: unix.SizeofInet6Pktinfo + unix.SizeofCmsghdr, - }, - unix.Inet6Pktinfo{ - Addr: end.src.Addr, - Ifindex: end.src.Scope_id, - }, - } - - msghdr := unix.Msghdr{ - Iov: &iovec, - Iovlen: 1, - Name: (*byte)(unsafe.Pointer(&end.dst)), - Namelen: unix.SizeofSockaddrInet6, - Control: (*byte)(unsafe.Pointer(&cmsg)), - } - - msghdr.SetControllen(int(unsafe.Sizeof(cmsg))) - - // sendmsg(sock, &msghdr, 0) - - _, _, errno := unix.Syscall( - unix.SYS_SENDMSG, - uintptr(sock), - uintptr(unsafe.Pointer(&msghdr)), - 0, - ) - - if errno == 0 { - return nil - } - - // clear src and retry - - if errno == unix.EINVAL { - end.ClearSrc() - cmsg.pktinfo = unix.Inet6Pktinfo{} - _, _, errno = unix.Syscall( - unix.SYS_SENDMSG, - uintptr(sock), - uintptr(unsafe.Pointer(&msghdr)), - 0, - ) - } - - return errno -} - -func send4(sock int, end *NativeEndpoint, buff []byte) error { - - // construct message header - - var iovec unix.Iovec - iovec.Base = (*byte)(unsafe.Pointer(&buff[0])) - iovec.SetLen(len(buff)) - - src4 := (*IPv4Source)(unsafe.Pointer(&end.src)) - - cmsg := struct { - cmsghdr unix.Cmsghdr - pktinfo unix.Inet4Pktinfo - }{ - unix.Cmsghdr{ - Level: unix.IPPROTO_IP, - Type: unix.IP_PKTINFO, - Len: unix.SizeofInet4Pktinfo + unix.SizeofCmsghdr, - }, - unix.Inet4Pktinfo{ - Spec_dst: src4.src.Addr, - Ifindex: src4.Ifindex, - }, - } - - msghdr := unix.Msghdr{ - Iov: &iovec, - Iovlen: 1, - Name: (*byte)(unsafe.Pointer(&end.dst)), - Namelen: unix.SizeofSockaddrInet4, - Control: (*byte)(unsafe.Pointer(&cmsg)), - Flags: 0, - } - msghdr.SetControllen(int(unsafe.Sizeof(cmsg))) - - // sendmsg(sock, &msghdr, 0) - - _, _, errno := unix.Syscall( - unix.SYS_SENDMSG, - uintptr(sock), - uintptr(unsafe.Pointer(&msghdr)), - 0, - ) - - // clear source and try again - - if errno == unix.EINVAL { - end.ClearSrc() - cmsg.pktinfo = unix.Inet4Pktinfo{} - _, _, errno = unix.Syscall( - unix.SYS_SENDMSG, - uintptr(sock), - uintptr(unsafe.Pointer(&msghdr)), - 0, - ) - } - - // errno = 0 is still an error instance - - if errno == 0 { - return nil - } - - return errno -} - -func receive4(sock int, buff []byte, end *NativeEndpoint) (int, error) { - - // contruct message header - - var iovec unix.Iovec - iovec.Base = (*byte)(unsafe.Pointer(&buff[0])) - iovec.SetLen(len(buff)) - - var cmsg struct { - cmsghdr unix.Cmsghdr - pktinfo unix.Inet4Pktinfo - } - - var msghdr unix.Msghdr - msghdr.Iov = &iovec - msghdr.Iovlen = 1 - msghdr.Name = (*byte)(unsafe.Pointer(&end.dst)) - msghdr.Namelen = unix.SizeofSockaddrInet4 - msghdr.Control = (*byte)(unsafe.Pointer(&cmsg)) - msghdr.SetControllen(int(unsafe.Sizeof(cmsg))) - - // recvmsg(sock, &mskhdr, 0) - - size, _, errno := unix.Syscall( - unix.SYS_RECVMSG, - uintptr(sock), - uintptr(unsafe.Pointer(&msghdr)), - 0, - ) - - if errno != 0 { - return 0, errno - } - - // update source cache - - if cmsg.cmsghdr.Level == unix.IPPROTO_IP && - cmsg.cmsghdr.Type == unix.IP_PKTINFO && - cmsg.cmsghdr.Len >= unix.SizeofInet4Pktinfo { - src4 := (*IPv4Source)(unsafe.Pointer(&end.src)) - src4.src.Family = unix.AF_INET - src4.src.Addr = cmsg.pktinfo.Spec_dst - src4.Ifindex = cmsg.pktinfo.Ifindex - } - - return int(size), nil -} - -func receive6(sock int, buff []byte, end *NativeEndpoint) (int, error) { - - // contruct message header - - var iovec unix.Iovec - iovec.Base = (*byte)(unsafe.Pointer(&buff[0])) - iovec.SetLen(len(buff)) - - var cmsg struct { - cmsghdr unix.Cmsghdr - pktinfo unix.Inet6Pktinfo - } - - var msg unix.Msghdr - msg.Iov = &iovec - msg.Iovlen = 1 - msg.Name = (*byte)(unsafe.Pointer(&end.dst)) - msg.Namelen = uint32(unix.SizeofSockaddrInet6) - msg.Control = (*byte)(unsafe.Pointer(&cmsg)) - msg.SetControllen(int(unsafe.Sizeof(cmsg))) - - // recvmsg(sock, &mskhdr, 0) - - size, _, errno := unix.Syscall( - unix.SYS_RECVMSG, - uintptr(sock), - uintptr(unsafe.Pointer(&msg)), - 0, - ) - - if errno != 0 { - return 0, errno - } - - // update source cache - - if cmsg.cmsghdr.Level == unix.IPPROTO_IPV6 && - cmsg.cmsghdr.Type == unix.IPV6_PKTINFO && - cmsg.cmsghdr.Len >= unix.SizeofInet6Pktinfo { - end.src.Family = unix.AF_INET6 - end.src.Addr = cmsg.pktinfo.Addr - end.src.Scope_id = cmsg.pktinfo.Ifindex - } - - return int(size), nil -} diff --git a/src/constants.go b/src/constants.go deleted file mode 100644 index 71dd98e..0000000 --- a/src/constants.go +++ /dev/null @@ -1,43 +0,0 @@ -package main - -import ( - "time" -) - -/* Specification constants */ - -const ( - RekeyAfterMessages = (1 << 64) - (1 << 16) - 1 - RejectAfterMessages = (1 << 64) - (1 << 4) - 1 - RekeyAfterTime = time.Second * 120 - RekeyAttemptTime = time.Second * 90 - RekeyTimeout = time.Second * 5 - RejectAfterTime = time.Second * 180 - KeepaliveTimeout = time.Second * 10 - CookieRefreshTime = time.Second * 120 - HandshakeInitationRate = time.Second / 20 - PaddingMultiple = 16 -) - -const ( - RekeyAfterTimeReceiving = RekeyAfterTime - KeepaliveTimeout - RekeyTimeout - NewHandshakeTime = KeepaliveTimeout + RekeyTimeout // upon failure to acknowledge transport message -) - -/* Implementation specific constants */ - -const ( - QueueOutboundSize = 1024 - QueueInboundSize = 1024 - QueueHandshakeSize = 1024 - MaxSegmentSize = (1 << 16) - 1 // largest possible UDP datagram - MinMessageSize = MessageKeepaliveSize // minimum size of transport message (keepalive) - MaxMessageSize = MaxSegmentSize // maximum size of transport message - MaxContentSize = MaxSegmentSize - MessageTransportSize // maximum size of transport message content -) - -const ( - UnderLoadQueueSize = QueueHandshakeSize / 8 - UnderLoadAfterTime = time.Second // how long does the device remain under load after detected - MaxPeers = 1 << 16 // maximum number of configured peers -) diff --git a/src/cookie.go b/src/cookie.go deleted file mode 100644 index a13ad49..0000000 --- a/src/cookie.go +++ /dev/null @@ -1,252 +0,0 @@ -package main - -import ( - "crypto/hmac" - "crypto/rand" - "golang.org/x/crypto/blake2s" - "golang.org/x/crypto/chacha20poly1305" - "sync" - "time" -) - -type CookieChecker struct { - mutex sync.RWMutex - mac1 struct { - key [blake2s.Size]byte - } - mac2 struct { - secret [blake2s.Size]byte - secretSet time.Time - encryptionKey [chacha20poly1305.KeySize]byte - } -} - -type CookieGenerator struct { - mutex sync.RWMutex - mac1 struct { - key [blake2s.Size]byte - } - mac2 struct { - cookie [blake2s.Size128]byte - cookieSet time.Time - hasLastMAC1 bool - lastMAC1 [blake2s.Size128]byte - encryptionKey [chacha20poly1305.KeySize]byte - } -} - -func (st *CookieChecker) Init(pk NoisePublicKey) { - st.mutex.Lock() - defer st.mutex.Unlock() - - // mac1 state - - func() { - hsh, _ := blake2s.New256(nil) - hsh.Write([]byte(WGLabelMAC1)) - hsh.Write(pk[:]) - hsh.Sum(st.mac1.key[:0]) - }() - - // mac2 state - - func() { - hsh, _ := blake2s.New256(nil) - hsh.Write([]byte(WGLabelCookie)) - hsh.Write(pk[:]) - hsh.Sum(st.mac2.encryptionKey[:0]) - }() - - st.mac2.secretSet = time.Time{} -} - -func (st *CookieChecker) CheckMAC1(msg []byte) bool { - size := len(msg) - smac2 := size - blake2s.Size128 - smac1 := smac2 - blake2s.Size128 - - var mac1 [blake2s.Size128]byte - - mac, _ := blake2s.New128(st.mac1.key[:]) - mac.Write(msg[:smac1]) - mac.Sum(mac1[:0]) - - return hmac.Equal(mac1[:], msg[smac1:smac2]) -} - -func (st *CookieChecker) CheckMAC2(msg []byte, src []byte) bool { - st.mutex.RLock() - defer st.mutex.RUnlock() - - if time.Now().Sub(st.mac2.secretSet) > CookieRefreshTime { - return false - } - - // derive cookie key - - var cookie [blake2s.Size128]byte - func() { - mac, _ := blake2s.New128(st.mac2.secret[:]) - mac.Write(src) - mac.Sum(cookie[:0]) - }() - - // calculate mac of packet (including mac1) - - smac2 := len(msg) - blake2s.Size128 - - var mac2 [blake2s.Size128]byte - func() { - mac, _ := blake2s.New128(cookie[:]) - mac.Write(msg[:smac2]) - mac.Sum(mac2[:0]) - }() - - return hmac.Equal(mac2[:], msg[smac2:]) -} - -func (st *CookieChecker) CreateReply( - msg []byte, - recv uint32, - src []byte, -) (*MessageCookieReply, error) { - - st.mutex.RLock() - - // refresh cookie secret - - if time.Now().Sub(st.mac2.secretSet) > CookieRefreshTime { - st.mutex.RUnlock() - st.mutex.Lock() - _, err := rand.Read(st.mac2.secret[:]) - if err != nil { - st.mutex.Unlock() - return nil, err - } - st.mac2.secretSet = time.Now() - st.mutex.Unlock() - st.mutex.RLock() - } - - // derive cookie - - var cookie [blake2s.Size128]byte - func() { - mac, _ := blake2s.New128(st.mac2.secret[:]) - mac.Write(src) - mac.Sum(cookie[:0]) - }() - - // encrypt cookie - - size := len(msg) - - smac2 := size - blake2s.Size128 - smac1 := smac2 - blake2s.Size128 - - reply := new(MessageCookieReply) - reply.Type = MessageCookieReplyType - reply.Receiver = recv - - _, err := rand.Read(reply.Nonce[:]) - if err != nil { - st.mutex.RUnlock() - return nil, err - } - - XChaCha20Poly1305Encrypt( - reply.Cookie[:0], - &reply.Nonce, - cookie[:], - msg[smac1:smac2], - &st.mac2.encryptionKey, - ) - - st.mutex.RUnlock() - - return reply, nil -} - -func (st *CookieGenerator) Init(pk NoisePublicKey) { - st.mutex.Lock() - defer st.mutex.Unlock() - - func() { - hsh, _ := blake2s.New256(nil) - hsh.Write([]byte(WGLabelMAC1)) - hsh.Write(pk[:]) - hsh.Sum(st.mac1.key[:0]) - }() - - func() { - hsh, _ := blake2s.New256(nil) - hsh.Write([]byte(WGLabelCookie)) - hsh.Write(pk[:]) - hsh.Sum(st.mac2.encryptionKey[:0]) - }() - - st.mac2.cookieSet = time.Time{} -} - -func (st *CookieGenerator) ConsumeReply(msg *MessageCookieReply) bool { - st.mutex.Lock() - defer st.mutex.Unlock() - - if !st.mac2.hasLastMAC1 { - return false - } - - var cookie [blake2s.Size128]byte - - _, err := XChaCha20Poly1305Decrypt( - cookie[:0], - &msg.Nonce, - msg.Cookie[:], - st.mac2.lastMAC1[:], - &st.mac2.encryptionKey, - ) - - if err != nil { - return false - } - - st.mac2.cookieSet = time.Now() - st.mac2.cookie = cookie - return true -} - -func (st *CookieGenerator) AddMacs(msg []byte) { - - size := len(msg) - - smac2 := size - blake2s.Size128 - smac1 := smac2 - blake2s.Size128 - - mac1 := msg[smac1:smac2] - mac2 := msg[smac2:] - - st.mutex.Lock() - defer st.mutex.Unlock() - - // set mac1 - - func() { - mac, _ := blake2s.New128(st.mac1.key[:]) - mac.Write(msg[:smac1]) - mac.Sum(mac1[:0]) - }() - copy(st.mac2.lastMAC1[:], mac1) - st.mac2.hasLastMAC1 = true - - // set mac2 - - if time.Now().Sub(st.mac2.cookieSet) > CookieRefreshTime { - return - } - - func() { - mac, _ := blake2s.New128(st.mac2.cookie[:]) - mac.Write(msg[:smac2]) - mac.Sum(mac2[:0]) - }() -} diff --git a/src/cookie_test.go b/src/cookie_test.go deleted file mode 100644 index d745fe7..0000000 --- a/src/cookie_test.go +++ /dev/null @@ -1,186 +0,0 @@ -package main - -import ( - "testing" -) - -func TestCookieMAC1(t *testing.T) { - - // setup generator / checker - - var ( - generator CookieGenerator - checker CookieChecker - ) - - sk, err := newPrivateKey() - if err != nil { - t.Fatal(err) - } - pk := sk.publicKey() - - generator.Init(pk) - checker.Init(pk) - - // check mac1 - - src := []byte{192, 168, 13, 37, 10, 10, 10} - - checkMAC1 := func(msg []byte) { - generator.AddMacs(msg) - if !checker.CheckMAC1(msg) { - t.Fatal("MAC1 generation/verification failed") - } - if checker.CheckMAC2(msg, src) { - t.Fatal("MAC2 generation/verification failed") - } - } - - checkMAC1([]byte{ - 0x99, 0xbb, 0xa5, 0xfc, 0x99, 0xaa, 0x83, 0xbd, - 0x7b, 0x00, 0xc5, 0x9a, 0x4c, 0xb9, 0xcf, 0x62, - 0x40, 0x23, 0xf3, 0x8e, 0xd8, 0xd0, 0x62, 0x64, - 0x5d, 0xb2, 0x80, 0x13, 0xda, 0xce, 0xc6, 0x91, - 0x61, 0xd6, 0x30, 0xf1, 0x32, 0xb3, 0xa2, 0xf4, - 0x7b, 0x43, 0xb5, 0xa7, 0xe2, 0xb1, 0xf5, 0x6c, - 0x74, 0x6b, 0xb0, 0xcd, 0x1f, 0x94, 0x86, 0x7b, - 0xc8, 0xfb, 0x92, 0xed, 0x54, 0x9b, 0x44, 0xf5, - 0xc8, 0x7d, 0xb7, 0x8e, 0xff, 0x49, 0xc4, 0xe8, - 0x39, 0x7c, 0x19, 0xe0, 0x60, 0x19, 0x51, 0xf8, - 0xe4, 0x8e, 0x02, 0xf1, 0x7f, 0x1d, 0xcc, 0x8e, - 0xb0, 0x07, 0xff, 0xf8, 0xaf, 0x7f, 0x66, 0x82, - 0x83, 0xcc, 0x7c, 0xfa, 0x80, 0xdb, 0x81, 0x53, - 0xad, 0xf7, 0xd8, 0x0c, 0x10, 0xe0, 0x20, 0xfd, - 0xe8, 0x0b, 0x3f, 0x90, 0x15, 0xcd, 0x93, 0xad, - 0x0b, 0xd5, 0x0c, 0xcc, 0x88, 0x56, 0xe4, 0x3f, - }) - - checkMAC1([]byte{ - 0x33, 0xe7, 0x2a, 0x84, 0x9f, 0xff, 0x57, 0x6c, - 0x2d, 0xc3, 0x2d, 0xe1, 0xf5, 0x5c, 0x97, 0x56, - 0xb8, 0x93, 0xc2, 0x7d, 0xd4, 0x41, 0xdd, 0x7a, - 0x4a, 0x59, 0x3b, 0x50, 0xdd, 0x7a, 0x7a, 0x8c, - 0x9b, 0x96, 0xaf, 0x55, 0x3c, 0xeb, 0x6d, 0x0b, - 0x13, 0x0b, 0x97, 0x98, 0xb3, 0x40, 0xc3, 0xcc, - 0xb8, 0x57, 0x33, 0x45, 0x6e, 0x8b, 0x09, 0x2b, - 0x81, 0x2e, 0xd2, 0xb9, 0x66, 0x0b, 0x93, 0x05, - }) - - checkMAC1([]byte{ - 0x9b, 0x96, 0xaf, 0x55, 0x3c, 0xeb, 0x6d, 0x0b, - 0x13, 0x0b, 0x97, 0x98, 0xb3, 0x40, 0xc3, 0xcc, - 0xb8, 0x57, 0x33, 0x45, 0x6e, 0x8b, 0x09, 0x2b, - 0x81, 0x2e, 0xd2, 0xb9, 0x66, 0x0b, 0x93, 0x05, - }) - - // exchange cookie reply - - func() { - msg := []byte{ - 0x6d, 0xd7, 0xc3, 0x2e, 0xb0, 0x76, 0xd8, 0xdf, - 0x30, 0x65, 0x7d, 0x62, 0x3e, 0xf8, 0x9a, 0xe8, - 0xe7, 0x3c, 0x64, 0xa3, 0x78, 0x48, 0xda, 0xf5, - 0x25, 0x61, 0x28, 0x53, 0x79, 0x32, 0x86, 0x9f, - 0xa0, 0x27, 0x95, 0x69, 0xb6, 0xba, 0xd0, 0xa2, - 0xf8, 0x68, 0xea, 0xa8, 0x62, 0xf2, 0xfd, 0x1b, - 0xe0, 0xb4, 0x80, 0xe5, 0x6b, 0x3a, 0x16, 0x9e, - 0x35, 0xf6, 0xa8, 0xf2, 0x4f, 0x9a, 0x7b, 0xe9, - 0x77, 0x0b, 0xc2, 0xb4, 0xed, 0xba, 0xf9, 0x22, - 0xc3, 0x03, 0x97, 0x42, 0x9f, 0x79, 0x74, 0x27, - 0xfe, 0xf9, 0x06, 0x6e, 0x97, 0x3a, 0xa6, 0x8f, - 0xc9, 0x57, 0x0a, 0x54, 0x4c, 0x64, 0x4a, 0xe2, - 0x4f, 0xa1, 0xce, 0x95, 0x9b, 0x23, 0xa9, 0x2b, - 0x85, 0x93, 0x42, 0xb0, 0xa5, 0x53, 0xed, 0xeb, - 0x63, 0x2a, 0xf1, 0x6d, 0x46, 0xcb, 0x2f, 0x61, - 0x8c, 0xe1, 0xe8, 0xfa, 0x67, 0x20, 0x80, 0x6d, - } - generator.AddMacs(msg) - reply, err := checker.CreateReply(msg, 1377, src) - if err != nil { - t.Fatal("Failed to create cookie reply:", err) - } - if !generator.ConsumeReply(reply) { - t.Fatal("Failed to consume cookie reply") - } - }() - - // check mac2 - - checkMAC2 := func(msg []byte) { - generator.AddMacs(msg) - - if !checker.CheckMAC1(msg) { - t.Fatal("MAC1 generation/verification failed") - } - if !checker.CheckMAC2(msg, src) { - t.Fatal("MAC2 generation/verification failed") - } - - msg[5] ^= 0x20 - - if checker.CheckMAC1(msg) { - t.Fatal("MAC1 generation/verification failed") - } - if checker.CheckMAC2(msg, src) { - t.Fatal("MAC2 generation/verification failed") - } - - msg[5] ^= 0x20 - - srcBad1 := []byte{192, 168, 13, 37, 40, 01} - if checker.CheckMAC2(msg, srcBad1) { - t.Fatal("MAC2 generation/verification failed") - } - - srcBad2 := []byte{192, 168, 13, 38, 40, 01} - if checker.CheckMAC2(msg, srcBad2) { - t.Fatal("MAC2 generation/verification failed") - } - } - - checkMAC2([]byte{ - 0x03, 0x31, 0xb9, 0x9e, 0xb0, 0x2a, 0x54, 0xa3, - 0xc1, 0x3f, 0xb4, 0x96, 0x16, 0xb9, 0x25, 0x15, - 0x3d, 0x3a, 0x82, 0xf9, 0x58, 0x36, 0x86, 0x3f, - 0x13, 0x2f, 0xfe, 0xb2, 0x53, 0x20, 0x8c, 0x3f, - 0xba, 0xeb, 0xfb, 0x4b, 0x1b, 0x22, 0x02, 0x69, - 0x2c, 0x90, 0xbc, 0xdc, 0xcf, 0xcf, 0x85, 0xeb, - 0x62, 0x66, 0x6f, 0xe8, 0xe1, 0xa6, 0xa8, 0x4c, - 0xa0, 0x04, 0x23, 0x15, 0x42, 0xac, 0xfa, 0x38, - }) - - checkMAC2([]byte{ - 0x0e, 0x2f, 0x0e, 0xa9, 0x29, 0x03, 0xe1, 0xf3, - 0x24, 0x01, 0x75, 0xad, 0x16, 0xa5, 0x66, 0x85, - 0xca, 0x66, 0xe0, 0xbd, 0xc6, 0x34, 0xd8, 0x84, - 0x09, 0x9a, 0x58, 0x14, 0xfb, 0x05, 0xda, 0xf5, - 0x90, 0xf5, 0x0c, 0x4e, 0x22, 0x10, 0xc9, 0x85, - 0x0f, 0xe3, 0x77, 0x35, 0xe9, 0x6b, 0xc2, 0x55, - 0x32, 0x46, 0xae, 0x25, 0xe0, 0xe3, 0x37, 0x7a, - 0x4b, 0x71, 0xcc, 0xfc, 0x91, 0xdf, 0xd6, 0xca, - 0xfe, 0xee, 0xce, 0x3f, 0x77, 0xa2, 0xfd, 0x59, - 0x8e, 0x73, 0x0a, 0x8d, 0x5c, 0x24, 0x14, 0xca, - 0x38, 0x91, 0xb8, 0x2c, 0x8c, 0xa2, 0x65, 0x7b, - 0xbc, 0x49, 0xbc, 0xb5, 0x58, 0xfc, 0xe3, 0xd7, - 0x02, 0xcf, 0xf7, 0x4c, 0x60, 0x91, 0xed, 0x55, - 0xe9, 0xf9, 0xfe, 0xd1, 0x44, 0x2c, 0x75, 0xf2, - 0xb3, 0x5d, 0x7b, 0x27, 0x56, 0xc0, 0x48, 0x4f, - 0xb0, 0xba, 0xe4, 0x7d, 0xd0, 0xaa, 0xcd, 0x3d, - 0xe3, 0x50, 0xd2, 0xcf, 0xb9, 0xfa, 0x4b, 0x2d, - 0xc6, 0xdf, 0x3b, 0x32, 0x98, 0x45, 0xe6, 0x8f, - 0x1c, 0x5c, 0xa2, 0x20, 0x7d, 0x1c, 0x28, 0xc2, - 0xd4, 0xa1, 0xe0, 0x21, 0x52, 0x8f, 0x1c, 0xd0, - 0x62, 0x97, 0x48, 0xbb, 0xf4, 0xa9, 0xcb, 0x35, - 0xf2, 0x07, 0xd3, 0x50, 0xd8, 0xa9, 0xc5, 0x9a, - 0x0f, 0xbd, 0x37, 0xaf, 0xe1, 0x45, 0x19, 0xee, - 0x41, 0xf3, 0xf7, 0xe5, 0xe0, 0x30, 0x3f, 0xbe, - 0x3d, 0x39, 0x64, 0x00, 0x7a, 0x1a, 0x51, 0x5e, - 0xe1, 0x70, 0x0b, 0xb9, 0x77, 0x5a, 0xf0, 0xc4, - 0x8a, 0xa1, 0x3a, 0x77, 0x1a, 0xe0, 0xc2, 0x06, - 0x91, 0xd5, 0xe9, 0x1c, 0xd3, 0xfe, 0xab, 0x93, - 0x1a, 0x0a, 0x4c, 0xbb, 0xf0, 0xff, 0xdc, 0xaa, - 0x61, 0x73, 0xcb, 0x03, 0x4b, 0x71, 0x68, 0x64, - 0x3d, 0x82, 0x31, 0x41, 0xd7, 0x8b, 0x22, 0x7b, - 0x7d, 0xa1, 0xd5, 0x85, 0x6d, 0xf0, 0x1b, 0xaa, - }) -} diff --git a/src/daemon_darwin.go b/src/daemon_darwin.go deleted file mode 100644 index 913af0e..0000000 --- a/src/daemon_darwin.go +++ /dev/null @@ -1,9 +0,0 @@ -package main - -import ( - "errors" -) - -func Daemonize() error { - return errors.New("Not implemented on OSX") -} diff --git a/src/daemon_linux.go b/src/daemon_linux.go deleted file mode 100644 index e1aaede..0000000 --- a/src/daemon_linux.go +++ /dev/null @@ -1,32 +0,0 @@ -package main - -import ( - "os" - "os/exec" -) - -/* Daemonizes the process on linux - * - * This is done by spawning and releasing a copy with the --foreground flag - */ -func Daemonize(attr *os.ProcAttr) error { - // I would like to use os.Executable, - // however this means dropping support for Go <1.8 - path, err := exec.LookPath(os.Args[0]) - if err != nil { - return err - } - - argv := []string{os.Args[0], "--foreground"} - argv = append(argv, os.Args[1:]...) - process, err := os.StartProcess( - path, - argv, - attr, - ) - if err != nil { - return err - } - process.Release() - return nil -} diff --git a/src/daemon_windows.go b/src/daemon_windows.go deleted file mode 100644 index d5ec1e8..0000000 --- a/src/daemon_windows.go +++ /dev/null @@ -1,34 +0,0 @@ -package main - -import ( - "os" -) - -/* Daemonizes the process on windows - * - * This is done by spawning and releasing a copy with the --foreground flag - */ - -func Daemonize() error { - argv := []string{os.Args[0], "--foreground"} - argv = append(argv, os.Args[1:]...) - attr := &os.ProcAttr{ - Dir: ".", - Env: os.Environ(), - Files: []*os.File{ - os.Stdin, - nil, - nil, - }, - } - process, err := os.StartProcess( - argv[0], - argv, - attr, - ) - if err != nil { - return err - } - process.Release() - return nil -} diff --git a/src/device.go b/src/device.go deleted file mode 100644 index c041987..0000000 --- a/src/device.go +++ /dev/null @@ -1,372 +0,0 @@ -package main - -import ( - "github.com/sasha-s/go-deadlock" - "runtime" - "sync" - "sync/atomic" - "time" -) - -type Device struct { - isUp AtomicBool // device is (going) up - isClosed AtomicBool // device is closed? (acting as guard) - log *Logger - - // synchronized resources (locks acquired in order) - - state struct { - mutex deadlock.Mutex - changing AtomicBool - current bool - } - - net struct { - mutex deadlock.RWMutex - bind Bind // bind interface - port uint16 // listening port - fwmark uint32 // mark value (0 = disabled) - } - - noise struct { - mutex deadlock.RWMutex - privateKey NoisePrivateKey - publicKey NoisePublicKey - } - - routing struct { - mutex deadlock.RWMutex - table RoutingTable - } - - peers struct { - mutex deadlock.RWMutex - keyMap map[NoisePublicKey]*Peer - } - - // unprotected / "self-synchronising resources" - - indices IndexTable - mac CookieChecker - - rate struct { - underLoadUntil atomic.Value - limiter Ratelimiter - } - - pool struct { - messageBuffers sync.Pool - } - - queue struct { - encryption chan *QueueOutboundElement - decryption chan *QueueInboundElement - handshake chan QueueHandshakeElement - } - - signal struct { - stop Signal - } - - tun struct { - device TUNDevice - mtu int32 - } -} - -/* Converts the peer into a "zombie", which remains in the peer map, - * but processes no packets and does not exists in the routing table. - * - * Must hold: - * device.peers.mutex : exclusive lock - * device.routing : exclusive lock - */ -func unsafeRemovePeer(device *Device, peer *Peer, key NoisePublicKey) { - - // stop routing and processing of packets - - device.routing.table.RemovePeer(peer) - peer.Stop() - - // remove from peer map - - delete(device.peers.keyMap, key) -} - -func deviceUpdateState(device *Device) { - - // check if state already being updated (guard) - - if device.state.changing.Swap(true) { - return - } - - func() { - - // compare to current state of device - - device.state.mutex.Lock() - defer device.state.mutex.Unlock() - - newIsUp := device.isUp.Get() - - if newIsUp == device.state.current { - device.state.changing.Set(false) - return - } - - // change state of device - - switch newIsUp { - case true: - if err := device.BindUpdate(); err != nil { - device.isUp.Set(false) - break - } - - device.peers.mutex.Lock() - defer device.peers.mutex.Unlock() - - for _, peer := range device.peers.keyMap { - peer.Start() - } - - case false: - device.BindClose() - - device.peers.mutex.Lock() - defer device.peers.mutex.Unlock() - - for _, peer := range device.peers.keyMap { - println("stopping peer") - peer.Stop() - } - } - - // update state variables - - device.state.current = newIsUp - device.state.changing.Set(false) - }() - - // check for state change in the mean time - - deviceUpdateState(device) -} - -func (device *Device) Up() { - - // closed device cannot be brought up - - if device.isClosed.Get() { - return - } - - device.state.mutex.Lock() - device.isUp.Set(true) - device.state.mutex.Unlock() - deviceUpdateState(device) -} - -func (device *Device) Down() { - device.state.mutex.Lock() - device.isUp.Set(false) - device.state.mutex.Unlock() - deviceUpdateState(device) -} - -func (device *Device) IsUnderLoad() bool { - - // check if currently under load - - now := time.Now() - underLoad := len(device.queue.handshake) >= UnderLoadQueueSize - if underLoad { - device.rate.underLoadUntil.Store(now.Add(time.Second)) - return true - } - - // check if recently under load - - until := device.rate.underLoadUntil.Load().(time.Time) - return until.After(now) -} - -func (device *Device) SetPrivateKey(sk NoisePrivateKey) error { - - // lock required resources - - device.noise.mutex.Lock() - defer device.noise.mutex.Unlock() - - device.routing.mutex.Lock() - defer device.routing.mutex.Unlock() - - device.peers.mutex.Lock() - defer device.peers.mutex.Unlock() - - for _, peer := range device.peers.keyMap { - peer.handshake.mutex.RLock() - defer peer.handshake.mutex.RUnlock() - } - - // remove peers with matching public keys - - publicKey := sk.publicKey() - for key, peer := range device.peers.keyMap { - if peer.handshake.remoteStatic.Equals(publicKey) { - unsafeRemovePeer(device, peer, key) - } - } - - // update key material - - device.noise.privateKey = sk - device.noise.publicKey = publicKey - device.mac.Init(publicKey) - - // do static-static DH pre-computations - - rmKey := device.noise.privateKey.IsZero() - - for key, peer := range device.peers.keyMap { - - hs := &peer.handshake - - if rmKey { - hs.precomputedStaticStatic = [NoisePublicKeySize]byte{} - } else { - hs.precomputedStaticStatic = device.noise.privateKey.sharedSecret(hs.remoteStatic) - } - - if isZero(hs.precomputedStaticStatic[:]) { - unsafeRemovePeer(device, peer, key) - } - } - - return nil -} - -func (device *Device) GetMessageBuffer() *[MaxMessageSize]byte { - return device.pool.messageBuffers.Get().(*[MaxMessageSize]byte) -} - -func (device *Device) PutMessageBuffer(msg *[MaxMessageSize]byte) { - device.pool.messageBuffers.Put(msg) -} - -func NewDevice(tun TUNDevice, logger *Logger) *Device { - device := new(Device) - - device.isUp.Set(false) - device.isClosed.Set(false) - - device.log = logger - device.tun.device = tun - device.peers.keyMap = make(map[NoisePublicKey]*Peer) - - // initialize anti-DoS / anti-scanning features - - device.rate.limiter.Init() - device.rate.underLoadUntil.Store(time.Time{}) - - // initialize noise & crypt-key routine - - device.indices.Init() - device.routing.table.Reset() - - // setup buffer pool - - device.pool.messageBuffers = sync.Pool{ - New: func() interface{} { - return new([MaxMessageSize]byte) - }, - } - - // create queues - - device.queue.handshake = make(chan QueueHandshakeElement, QueueHandshakeSize) - device.queue.encryption = make(chan *QueueOutboundElement, QueueOutboundSize) - device.queue.decryption = make(chan *QueueInboundElement, QueueInboundSize) - - // prepare signals - - device.signal.stop = NewSignal() - - // prepare net - - device.net.port = 0 - device.net.bind = nil - - // start workers - - for i := 0; i < runtime.NumCPU(); i += 1 { - go device.RoutineEncryption() - go device.RoutineDecryption() - go device.RoutineHandshake() - } - - go device.RoutineReadFromTUN() - go device.RoutineTUNEventReader() - go device.rate.limiter.RoutineGarbageCollector(device.signal.stop) - - return device -} - -func (device *Device) LookupPeer(pk NoisePublicKey) *Peer { - device.peers.mutex.RLock() - defer device.peers.mutex.RUnlock() - - return device.peers.keyMap[pk] -} - -func (device *Device) RemovePeer(key NoisePublicKey) { - device.noise.mutex.Lock() - defer device.noise.mutex.Unlock() - - device.routing.mutex.Lock() - defer device.routing.mutex.Unlock() - - device.peers.mutex.Lock() - defer device.peers.mutex.Unlock() - - // stop peer and remove from routing - - peer, ok := device.peers.keyMap[key] - if ok { - unsafeRemovePeer(device, peer, key) - } -} - -func (device *Device) RemoveAllPeers() { - - device.routing.mutex.Lock() - defer device.routing.mutex.Unlock() - - device.peers.mutex.Lock() - defer device.peers.mutex.Unlock() - - for key, peer := range device.peers.keyMap { - println("rm", peer.String()) - unsafeRemovePeer(device, peer, key) - } - - device.peers.keyMap = make(map[NoisePublicKey]*Peer) -} - -func (device *Device) Close() { - device.log.Info.Println("Device closing") - if device.isClosed.Swap(true) { - return - } - device.signal.stop.Broadcast() - device.tun.device.Close() - device.BindClose() - device.isUp.Set(false) - device.RemoveAllPeers() - device.log.Info.Println("Interface closed") -} - -func (device *Device) Wait() chan struct{} { - return device.signal.stop.Wait() -} diff --git a/src/helper_test.go b/src/helper_test.go deleted file mode 100644 index 41e6b72..0000000 --- a/src/helper_test.go +++ /dev/null @@ -1,79 +0,0 @@ -package main - -import ( - "bytes" - "os" - "testing" -) - -/* Helpers for writing unit tests - */ - -type DummyTUN struct { - name string - mtu int - packets chan []byte - events chan TUNEvent -} - -func (tun *DummyTUN) File() *os.File { - return nil -} - -func (tun *DummyTUN) Name() string { - return tun.name -} - -func (tun *DummyTUN) MTU() (int, error) { - return tun.mtu, nil -} - -func (tun *DummyTUN) Write(d []byte, offset int) (int, error) { - tun.packets <- d[offset:] - return len(d), nil -} - -func (tun *DummyTUN) Close() error { - return nil -} - -func (tun *DummyTUN) Events() chan TUNEvent { - return tun.events -} - -func (tun *DummyTUN) Read(d []byte, offset int) (int, error) { - t := <-tun.packets - copy(d[offset:], t) - return len(t), nil -} - -func CreateDummyTUN(name string) (TUNDevice, error) { - var dummy DummyTUN - dummy.mtu = 0 - dummy.packets = make(chan []byte, 100) - return &dummy, nil -} - -func assertNil(t *testing.T, err error) { - if err != nil { - t.Fatal(err) - } -} - -func assertEqual(t *testing.T, a []byte, b []byte) { - if bytes.Compare(a, b) != 0 { - t.Fatal(a, "!=", b) - } -} - -func randDevice(t *testing.T) *Device { - sk, err := newPrivateKey() - if err != nil { - t.Fatal(err) - } - tun, _ := CreateDummyTUN("dummy") - logger := NewLogger(LogLevelError, "") - device := NewDevice(tun, logger) - device.SetPrivateKey(sk) - return device -} diff --git a/src/index.go b/src/index.go deleted file mode 100644 index 1ba040e..0000000 --- a/src/index.go +++ /dev/null @@ -1,95 +0,0 @@ -package main - -import ( - "crypto/rand" - "encoding/binary" - "sync" -) - -/* Index=0 is reserved for unset indecies - * - */ - -type IndexTableEntry struct { - peer *Peer - handshake *Handshake - keyPair *KeyPair -} - -type IndexTable struct { - mutex sync.RWMutex - table map[uint32]IndexTableEntry -} - -func randUint32() (uint32, error) { - var buff [4]byte - _, err := rand.Read(buff[:]) - value := binary.LittleEndian.Uint32(buff[:]) - return value, err -} - -func (table *IndexTable) Init() { - table.mutex.Lock() - table.table = make(map[uint32]IndexTableEntry) - table.mutex.Unlock() -} - -func (table *IndexTable) Delete(index uint32) { - if index == 0 { - return - } - table.mutex.Lock() - delete(table.table, index) - table.mutex.Unlock() -} - -func (table *IndexTable) Insert(key uint32, value IndexTableEntry) { - table.mutex.Lock() - table.table[key] = value - table.mutex.Unlock() -} - -func (table *IndexTable) NewIndex(peer *Peer) (uint32, error) { - for { - // generate random index - - index, err := randUint32() - if err != nil { - return index, err - } - if index == 0 { - continue - } - - // check if index used - - table.mutex.RLock() - _, ok := table.table[index] - table.mutex.RUnlock() - if ok { - continue - } - - // map index to handshake - - table.mutex.Lock() - _, found := table.table[index] - if found { - table.mutex.Unlock() - continue - } - table.table[index] = IndexTableEntry{ - peer: peer, - handshake: &peer.handshake, - keyPair: nil, - } - table.mutex.Unlock() - return index, nil - } -} - -func (table *IndexTable) Lookup(id uint32) IndexTableEntry { - table.mutex.RLock() - defer table.mutex.RUnlock() - return table.table[id] -} diff --git a/src/ip.go b/src/ip.go deleted file mode 100644 index 752a404..0000000 --- a/src/ip.go +++ /dev/null @@ -1,17 +0,0 @@ -package main - -import ( - "net" -) - -const ( - IPv4offsetTotalLength = 2 - IPv4offsetSrc = 12 - IPv4offsetDst = IPv4offsetSrc + net.IPv4len -) - -const ( - IPv6offsetPayloadLength = 4 - IPv6offsetSrc = 8 - IPv6offsetDst = IPv6offsetSrc + net.IPv6len -) diff --git a/src/kdf_test.go b/src/kdf_test.go deleted file mode 100644 index a89dacc..0000000 --- a/src/kdf_test.go +++ /dev/null @@ -1,79 +0,0 @@ -package main - -import ( - "encoding/hex" - "golang.org/x/crypto/blake2s" - "testing" -) - -type KDFTest struct { - key string - input string - t0 string - t1 string - t2 string -} - -func assertEquals(t *testing.T, a string, b string) { - if a != b { - t.Fatal("expected", a, "=", b) - } -} - -func TestKDF(t *testing.T) { - tests := []KDFTest{ - { - key: "746573742d6b6579", - input: "746573742d696e707574", - t0: "6f0e5ad38daba1bea8a0d213688736f19763239305e0f58aba697f9ffc41c633", - t1: "df1194df20802a4fe594cde27e92991c8cae66c366e8106aaa937a55fa371e8a", - t2: "fac6e2745a325f5dc5d11a5b165aad08b0ada28e7b4e666b7c077934a4d76c24", - }, - { - key: "776972656775617264", - input: "776972656775617264", - t0: "491d43bbfdaa8750aaf535e334ecbfe5129967cd64635101c566d4caefda96e8", - t1: "1e71a379baefd8a79aa4662212fcafe19a23e2b609a3db7d6bcba8f560e3d25f", - t2: "31e1ae48bddfbe5de38f295e5452b1909a1b4e38e183926af3780b0c1e1f0160", - }, - { - key: "", - input: "", - t0: "8387b46bf43eccfcf349552a095d8315c4055beb90208fb1be23b894bc2ed5d0", - t1: "58a0e5f6faefccf4807bff1f05fa8a9217945762040bcec2f4b4a62bdfe0e86e", - t2: "0ce6ea98ec548f8e281e93e32db65621c45eb18dc6f0a7ad94178610a2f7338e", - }, - } - - var t0, t1, t2 [blake2s.Size]byte - - for _, test := range tests { - key, _ := hex.DecodeString(test.key) - input, _ := hex.DecodeString(test.input) - KDF3(&t0, &t1, &t2, key, input) - t0s := hex.EncodeToString(t0[:]) - t1s := hex.EncodeToString(t1[:]) - t2s := hex.EncodeToString(t2[:]) - assertEquals(t, t0s, test.t0) - assertEquals(t, t1s, test.t1) - assertEquals(t, t2s, test.t2) - } - - for _, test := range tests { - key, _ := hex.DecodeString(test.key) - input, _ := hex.DecodeString(test.input) - KDF2(&t0, &t1, key, input) - t0s := hex.EncodeToString(t0[:]) - t1s := hex.EncodeToString(t1[:]) - assertEquals(t, t0s, test.t0) - assertEquals(t, t1s, test.t1) - } - - for _, test := range tests { - key, _ := hex.DecodeString(test.key) - input, _ := hex.DecodeString(test.input) - KDF1(&t0, key, input) - t0s := hex.EncodeToString(t0[:]) - assertEquals(t, t0s, test.t0) - } -} diff --git a/src/keypair.go b/src/keypair.go deleted file mode 100644 index 283cb92..0000000 --- a/src/keypair.go +++ /dev/null @@ -1,44 +0,0 @@ -package main - -import ( - "crypto/cipher" - "sync" - "time" -) - -/* Due to limitations in Go and /x/crypto there is currently - * no way to ensure that key material is securely ereased in memory. - * - * Since this may harm the forward secrecy property, - * we plan to resolve this issue; whenever Go allows us to do so. - */ - -type KeyPair struct { - send cipher.AEAD - receive cipher.AEAD - replayFilter ReplayFilter - sendNonce uint64 - isInitiator bool - created time.Time - localIndex uint32 - remoteIndex uint32 -} - -type KeyPairs struct { - mutex sync.RWMutex - current *KeyPair - previous *KeyPair - next *KeyPair // not yet "confirmed by transport" -} - -func (kp *KeyPairs) Current() *KeyPair { - kp.mutex.RLock() - defer kp.mutex.RUnlock() - return kp.current -} - -func (device *Device) DeleteKeyPair(key *KeyPair) { - if key != nil { - device.indices.Delete(key.localIndex) - } -} diff --git a/src/logger.go b/src/logger.go deleted file mode 100644 index 0872ef9..0000000 --- a/src/logger.go +++ /dev/null @@ -1,50 +0,0 @@ -package main - -import ( - "io" - "io/ioutil" - "log" - "os" -) - -const ( - LogLevelError = iota - LogLevelInfo - LogLevelDebug -) - -type Logger struct { - Debug *log.Logger - Info *log.Logger - Error *log.Logger -} - -func NewLogger(level int, prepend string) *Logger { - output := os.Stdout - logger := new(Logger) - - logErr, logInfo, logDebug := func() (io.Writer, io.Writer, io.Writer) { - if level >= LogLevelDebug { - return output, output, output - } - if level >= LogLevelInfo { - return output, output, ioutil.Discard - } - return output, ioutil.Discard, ioutil.Discard - }() - - logger.Debug = log.New(logDebug, - "DEBUG: "+prepend, - log.Ldate|log.Ltime|log.Lshortfile, - ) - - logger.Info = log.New(logInfo, - "INFO: "+prepend, - log.Ldate|log.Ltime, - ) - logger.Error = log.New(logErr, - "ERROR: "+prepend, - log.Ldate|log.Ltime, - ) - return logger -} diff --git a/src/main.go b/src/main.go deleted file mode 100644 index b12bb09..0000000 --- a/src/main.go +++ /dev/null @@ -1,196 +0,0 @@ -package main - -import ( - "fmt" - "os" - "os/signal" - "runtime" - "strconv" -) - -const ( - ExitSetupSuccess = 0 - ExitSetupFailed = 1 -) - -const ( - ENV_WG_TUN_FD = "WG_TUN_FD" - ENV_WG_UAPI_FD = "WG_UAPI_FD" -) - -func printUsage() { - fmt.Printf("usage:\n") - fmt.Printf("%s [-f/--foreground] INTERFACE-NAME\n", os.Args[0]) -} - -func main() { - - // parse arguments - - var foreground bool - var interfaceName string - if len(os.Args) < 2 || len(os.Args) > 3 { - printUsage() - return - } - - switch os.Args[1] { - - case "-f", "--foreground": - foreground = true - if len(os.Args) != 3 { - printUsage() - return - } - interfaceName = os.Args[2] - - default: - foreground = false - if len(os.Args) != 2 { - printUsage() - return - } - interfaceName = os.Args[1] - } - - // get log level (default: info) - - logLevel := func() int { - switch os.Getenv("LOG_LEVEL") { - case "debug": - return LogLevelDebug - case "info": - return LogLevelInfo - case "error": - return LogLevelError - } - return LogLevelInfo - }() - - logger := NewLogger( - logLevel, - fmt.Sprintf("(%s) ", interfaceName), - ) - - logger.Debug.Println("Debug log enabled") - - // open TUN device (or use supplied fd) - - tun, err := func() (TUNDevice, error) { - tunFdStr := os.Getenv(ENV_WG_TUN_FD) - if tunFdStr == "" { - return CreateTUN(interfaceName) - } - - // construct tun device from supplied fd - - fd, err := strconv.ParseUint(tunFdStr, 10, 32) - if err != nil { - return nil, err - } - - file := os.NewFile(uintptr(fd), "") - return CreateTUNFromFile(interfaceName, file) - }() - - if err != nil { - logger.Error.Println("Failed to create TUN device:", err) - os.Exit(ExitSetupFailed) - } - - // open UAPI file (or use supplied fd) - - fileUAPI, err := func() (*os.File, error) { - uapiFdStr := os.Getenv(ENV_WG_UAPI_FD) - if uapiFdStr == "" { - return UAPIOpen(interfaceName) - } - - // use supplied fd - - fd, err := strconv.ParseUint(uapiFdStr, 10, 32) - if err != nil { - return nil, err - } - - return os.NewFile(uintptr(fd), ""), nil - }() - - if err != nil { - logger.Error.Println("UAPI listen error:", err) - os.Exit(ExitSetupFailed) - return - } - // daemonize the process - - if !foreground { - env := os.Environ() - env = append(env, fmt.Sprintf("%s=3", ENV_WG_TUN_FD)) - env = append(env, fmt.Sprintf("%s=4", ENV_WG_UAPI_FD)) - attr := &os.ProcAttr{ - Files: []*os.File{ - nil, // stdin - nil, // stdout - nil, // stderr - tun.File(), - fileUAPI, - }, - Dir: ".", - Env: env, - } - err = Daemonize(attr) - if err != nil { - logger.Error.Println("Failed to daemonize:", err) - os.Exit(ExitSetupFailed) - } - return - } - - // increase number of go workers (for Go <1.5) - - runtime.GOMAXPROCS(runtime.NumCPU()) - - // create wireguard device - - device := NewDevice(tun, logger) - - logger.Info.Println("Device started") - - // start uapi listener - - errs := make(chan error) - term := make(chan os.Signal) - - uapi, err := UAPIListen(interfaceName, fileUAPI) - - go func() { - for { - conn, err := uapi.Accept() - if err != nil { - errs <- err - return - } - go ipcHandle(device, conn) - } - }() - - logger.Info.Println("UAPI listener started") - - // wait for program to terminate - - signal.Notify(term, os.Kill) - signal.Notify(term, os.Interrupt) - - select { - case <-term: - case <-errs: - case <-device.Wait(): - } - - // clean up - - uapi.Close() - device.Close() - - logger.Info.Println("Shutting down") -} diff --git a/src/misc.go b/src/misc.go deleted file mode 100644 index 80e33f6..0000000 --- a/src/misc.go +++ /dev/null @@ -1,57 +0,0 @@ -package main - -import ( - "sync/atomic" -) - -/* Atomic Boolean */ - -const ( - AtomicFalse = int32(iota) - AtomicTrue -) - -type AtomicBool struct { - flag int32 -} - -func (a *AtomicBool) Get() bool { - return atomic.LoadInt32(&a.flag) == AtomicTrue -} - -func (a *AtomicBool) Swap(val bool) bool { - flag := AtomicFalse - if val { - flag = AtomicTrue - } - return atomic.SwapInt32(&a.flag, flag) == AtomicTrue -} - -func (a *AtomicBool) Set(val bool) { - flag := AtomicFalse - if val { - flag = AtomicTrue - } - atomic.StoreInt32(&a.flag, flag) -} - -/* Integer manipulation */ - -func toInt32(n uint32) int32 { - mask := uint32(1 << 31) - return int32(-(n & mask) + (n & ^mask)) -} - -func min(a uint, b uint) uint { - if a > b { - return b - } - return a -} - -func minUint64(a uint64, b uint64) uint64 { - if a > b { - return b - } - return a -} diff --git a/src/noise_helpers.go b/src/noise_helpers.go deleted file mode 100644 index 1e2de5f..0000000 --- a/src/noise_helpers.go +++ /dev/null @@ -1,98 +0,0 @@ -package main - -import ( - "crypto/hmac" - "crypto/rand" - "crypto/subtle" - "golang.org/x/crypto/blake2s" - "golang.org/x/crypto/curve25519" - "hash" -) - -/* KDF related functions. - * HMAC-based Key Derivation Function (HKDF) - * https://tools.ietf.org/html/rfc5869 - */ - -func HMAC1(sum *[blake2s.Size]byte, key, in0 []byte) { - mac := hmac.New(func() hash.Hash { - h, _ := blake2s.New256(nil) - return h - }, key) - mac.Write(in0) - mac.Sum(sum[:0]) -} - -func HMAC2(sum *[blake2s.Size]byte, key, in0, in1 []byte) { - mac := hmac.New(func() hash.Hash { - h, _ := blake2s.New256(nil) - return h - }, key) - mac.Write(in0) - mac.Write(in1) - mac.Sum(sum[:0]) -} - -func KDF1(t0 *[blake2s.Size]byte, key, input []byte) { - HMAC1(t0, key, input) - HMAC1(t0, t0[:], []byte{0x1}) - return -} - -func KDF2(t0, t1 *[blake2s.Size]byte, key, input []byte) { - var prk [blake2s.Size]byte - HMAC1(&prk, key, input) - HMAC1(t0, prk[:], []byte{0x1}) - HMAC2(t1, prk[:], t0[:], []byte{0x2}) - setZero(prk[:]) - return -} - -func KDF3(t0, t1, t2 *[blake2s.Size]byte, key, input []byte) { - var prk [blake2s.Size]byte - HMAC1(&prk, key, input) - HMAC1(t0, prk[:], []byte{0x1}) - HMAC2(t1, prk[:], t0[:], []byte{0x2}) - HMAC2(t2, prk[:], t1[:], []byte{0x3}) - setZero(prk[:]) - return -} - -func isZero(val []byte) bool { - acc := 1 - for _, b := range val { - acc &= subtle.ConstantTimeByteEq(b, 0) - } - return acc == 1 -} - -func setZero(arr []byte) { - for i := range arr { - arr[i] = 0 - } -} - -/* curve25519 wrappers */ - -func newPrivateKey() (sk NoisePrivateKey, err error) { - // clamping: https://cr.yp.to/ecdh.html - _, err = rand.Read(sk[:]) - sk[0] &= 248 - sk[31] &= 127 - sk[31] |= 64 - return -} - -func (sk *NoisePrivateKey) publicKey() (pk NoisePublicKey) { - apk := (*[NoisePublicKeySize]byte)(&pk) - ask := (*[NoisePrivateKeySize]byte)(sk) - curve25519.ScalarBaseMult(apk, ask) - return -} - -func (sk *NoisePrivateKey) sharedSecret(pk NoisePublicKey) (ss [NoisePublicKeySize]byte) { - apk := (*[NoisePublicKeySize]byte)(&pk) - ask := (*[NoisePrivateKeySize]byte)(sk) - curve25519.ScalarMult(&ss, ask, apk) - return ss -} diff --git a/src/noise_protocol.go b/src/noise_protocol.go deleted file mode 100644 index c9713c0..0000000 --- a/src/noise_protocol.go +++ /dev/null @@ -1,578 +0,0 @@ -package main - -import ( - "errors" - "golang.org/x/crypto/blake2s" - "golang.org/x/crypto/chacha20poly1305" - "golang.org/x/crypto/poly1305" - "sync" - "time" -) - -const ( - HandshakeZeroed = iota - HandshakeInitiationCreated - HandshakeInitiationConsumed - HandshakeResponseCreated - HandshakeResponseConsumed -) - -const ( - NoiseConstruction = "Noise_IKpsk2_25519_ChaChaPoly_BLAKE2s" - WGIdentifier = "WireGuard v1 zx2c4 Jason@zx2c4.com" - WGLabelMAC1 = "mac1----" - WGLabelCookie = "cookie--" -) - -const ( - MessageInitiationType = 1 - MessageResponseType = 2 - MessageCookieReplyType = 3 - MessageTransportType = 4 -) - -const ( - MessageInitiationSize = 148 // size of handshake initation message - MessageResponseSize = 92 // size of response message - MessageCookieReplySize = 64 // size of cookie reply message - MessageTransportHeaderSize = 16 // size of data preceeding content in transport message - MessageTransportSize = MessageTransportHeaderSize + poly1305.TagSize // size of empty transport - MessageKeepaliveSize = MessageTransportSize // size of keepalive - MessageHandshakeSize = MessageInitiationSize // size of largest handshake releated message -) - -const ( - MessageTransportOffsetReceiver = 4 - MessageTransportOffsetCounter = 8 - MessageTransportOffsetContent = 16 -) - -/* Type is an 8-bit field, followed by 3 nul bytes, - * by marshalling the messages in little-endian byteorder - * we can treat these as a 32-bit unsigned int (for now) - * - */ - -type MessageInitiation struct { - Type uint32 - Sender uint32 - Ephemeral NoisePublicKey - Static [NoisePublicKeySize + poly1305.TagSize]byte - Timestamp [TAI64NSize + poly1305.TagSize]byte - MAC1 [blake2s.Size128]byte - MAC2 [blake2s.Size128]byte -} - -type MessageResponse struct { - Type uint32 - Sender uint32 - Receiver uint32 - Ephemeral NoisePublicKey - Empty [poly1305.TagSize]byte - MAC1 [blake2s.Size128]byte - MAC2 [blake2s.Size128]byte -} - -type MessageTransport struct { - Type uint32 - Receiver uint32 - Counter uint64 - Content []byte -} - -type MessageCookieReply struct { - Type uint32 - Receiver uint32 - Nonce [24]byte - Cookie [blake2s.Size128 + poly1305.TagSize]byte -} - -type Handshake struct { - state int - mutex sync.RWMutex - hash [blake2s.Size]byte // hash value - chainKey [blake2s.Size]byte // chain key - presharedKey NoiseSymmetricKey // psk - localEphemeral NoisePrivateKey // ephemeral secret key - localIndex uint32 // used to clear hash-table - remoteIndex uint32 // index for sending - remoteStatic NoisePublicKey // long term key - remoteEphemeral NoisePublicKey // ephemeral public key - precomputedStaticStatic [NoisePublicKeySize]byte // precomputed shared secret - lastTimestamp TAI64N - lastInitiationConsumption time.Time -} - -var ( - InitialChainKey [blake2s.Size]byte - InitialHash [blake2s.Size]byte - ZeroNonce [chacha20poly1305.NonceSize]byte -) - -func mixKey(dst *[blake2s.Size]byte, c *[blake2s.Size]byte, data []byte) { - KDF1(dst, c[:], data) -} - -func mixHash(dst *[blake2s.Size]byte, h *[blake2s.Size]byte, data []byte) { - hsh, _ := blake2s.New256(nil) - hsh.Write(h[:]) - hsh.Write(data) - hsh.Sum(dst[:0]) - hsh.Reset() -} - -func (h *Handshake) Clear() { - setZero(h.localEphemeral[:]) - setZero(h.remoteEphemeral[:]) - setZero(h.chainKey[:]) - setZero(h.hash[:]) - h.localIndex = 0 - h.state = HandshakeZeroed -} - -func (h *Handshake) mixHash(data []byte) { - mixHash(&h.hash, &h.hash, data) -} - -func (h *Handshake) mixKey(data []byte) { - mixKey(&h.chainKey, &h.chainKey, data) -} - -/* Do basic precomputations - */ -func init() { - InitialChainKey = blake2s.Sum256([]byte(NoiseConstruction)) - mixHash(&InitialHash, &InitialChainKey, []byte(WGIdentifier)) -} - -func (device *Device) CreateMessageInitiation(peer *Peer) (*MessageInitiation, error) { - - device.noise.mutex.RLock() - defer device.noise.mutex.RUnlock() - - handshake := &peer.handshake - handshake.mutex.Lock() - defer handshake.mutex.Unlock() - - if isZero(handshake.precomputedStaticStatic[:]) { - return nil, errors.New("Static shared secret is zero") - } - - // create ephemeral key - - var err error - handshake.hash = InitialHash - handshake.chainKey = InitialChainKey - handshake.localEphemeral, err = newPrivateKey() - if err != nil { - return nil, err - } - - // assign index - - device.indices.Delete(handshake.localIndex) - handshake.localIndex, err = device.indices.NewIndex(peer) - - if err != nil { - return nil, err - } - - handshake.mixHash(handshake.remoteStatic[:]) - - msg := MessageInitiation{ - Type: MessageInitiationType, - Ephemeral: handshake.localEphemeral.publicKey(), - Sender: handshake.localIndex, - } - - handshake.mixKey(msg.Ephemeral[:]) - handshake.mixHash(msg.Ephemeral[:]) - - // encrypt static key - - func() { - var key [chacha20poly1305.KeySize]byte - ss := handshake.localEphemeral.sharedSecret(handshake.remoteStatic) - KDF2( - &handshake.chainKey, - &key, - handshake.chainKey[:], - ss[:], - ) - aead, _ := chacha20poly1305.New(key[:]) - aead.Seal(msg.Static[:0], ZeroNonce[:], device.noise.publicKey[:], handshake.hash[:]) - }() - handshake.mixHash(msg.Static[:]) - - // encrypt timestamp - - timestamp := Timestamp() - func() { - var key [chacha20poly1305.KeySize]byte - KDF2( - &handshake.chainKey, - &key, - handshake.chainKey[:], - handshake.precomputedStaticStatic[:], - ) - aead, _ := chacha20poly1305.New(key[:]) - aead.Seal(msg.Timestamp[:0], ZeroNonce[:], timestamp[:], handshake.hash[:]) - }() - - handshake.mixHash(msg.Timestamp[:]) - handshake.state = HandshakeInitiationCreated - return &msg, nil -} - -func (device *Device) ConsumeMessageInitiation(msg *MessageInitiation) *Peer { - var ( - hash [blake2s.Size]byte - chainKey [blake2s.Size]byte - ) - - if msg.Type != MessageInitiationType { - return nil - } - - device.noise.mutex.RLock() - defer device.noise.mutex.RUnlock() - - mixHash(&hash, &InitialHash, device.noise.publicKey[:]) - mixHash(&hash, &hash, msg.Ephemeral[:]) - mixKey(&chainKey, &InitialChainKey, msg.Ephemeral[:]) - - // decrypt static key - - var err error - var peerPK NoisePublicKey - func() { - var key [chacha20poly1305.KeySize]byte - ss := device.noise.privateKey.sharedSecret(msg.Ephemeral) - KDF2(&chainKey, &key, chainKey[:], ss[:]) - aead, _ := chacha20poly1305.New(key[:]) - _, err = aead.Open(peerPK[:0], ZeroNonce[:], msg.Static[:], hash[:]) - }() - if err != nil { - return nil - } - mixHash(&hash, &hash, msg.Static[:]) - - // lookup peer - - peer := device.LookupPeer(peerPK) - if peer == nil { - return nil - } - - handshake := &peer.handshake - if isZero(handshake.precomputedStaticStatic[:]) { - return nil - } - - // verify identity - - var timestamp TAI64N - var key [chacha20poly1305.KeySize]byte - - handshake.mutex.RLock() - KDF2( - &chainKey, - &key, - chainKey[:], - handshake.precomputedStaticStatic[:], - ) - aead, _ := chacha20poly1305.New(key[:]) - _, err = aead.Open(timestamp[:0], ZeroNonce[:], msg.Timestamp[:], hash[:]) - if err != nil { - handshake.mutex.RUnlock() - return nil - } - mixHash(&hash, &hash, msg.Timestamp[:]) - - // protect against replay & flood - - var ok bool - ok = timestamp.After(handshake.lastTimestamp) - ok = ok && time.Now().Sub(handshake.lastInitiationConsumption) > HandshakeInitationRate - handshake.mutex.RUnlock() - if !ok { - return nil - } - - // update handshake state - - handshake.mutex.Lock() - - handshake.hash = hash - handshake.chainKey = chainKey - handshake.remoteIndex = msg.Sender - handshake.remoteEphemeral = msg.Ephemeral - handshake.lastTimestamp = timestamp - handshake.lastInitiationConsumption = time.Now() - handshake.state = HandshakeInitiationConsumed - - handshake.mutex.Unlock() - - return peer -} - -func (device *Device) CreateMessageResponse(peer *Peer) (*MessageResponse, error) { - handshake := &peer.handshake - handshake.mutex.Lock() - defer handshake.mutex.Unlock() - - if handshake.state != HandshakeInitiationConsumed { - return nil, errors.New("handshake initation must be consumed first") - } - - // assign index - - var err error - device.indices.Delete(handshake.localIndex) - handshake.localIndex, err = device.indices.NewIndex(peer) - if err != nil { - return nil, err - } - - var msg MessageResponse - msg.Type = MessageResponseType - msg.Sender = handshake.localIndex - msg.Receiver = handshake.remoteIndex - - // create ephemeral key - - handshake.localEphemeral, err = newPrivateKey() - if err != nil { - return nil, err - } - msg.Ephemeral = handshake.localEphemeral.publicKey() - handshake.mixHash(msg.Ephemeral[:]) - handshake.mixKey(msg.Ephemeral[:]) - - func() { - ss := handshake.localEphemeral.sharedSecret(handshake.remoteEphemeral) - handshake.mixKey(ss[:]) - ss = handshake.localEphemeral.sharedSecret(handshake.remoteStatic) - handshake.mixKey(ss[:]) - }() - - // add preshared key (psk) - - var tau [blake2s.Size]byte - var key [chacha20poly1305.KeySize]byte - - KDF3( - &handshake.chainKey, - &tau, - &key, - handshake.chainKey[:], - handshake.presharedKey[:], - ) - - handshake.mixHash(tau[:]) - - func() { - aead, _ := chacha20poly1305.New(key[:]) - aead.Seal(msg.Empty[:0], ZeroNonce[:], nil, handshake.hash[:]) - handshake.mixHash(msg.Empty[:]) - }() - - handshake.state = HandshakeResponseCreated - - return &msg, nil -} - -func (device *Device) ConsumeMessageResponse(msg *MessageResponse) *Peer { - if msg.Type != MessageResponseType { - return nil - } - - // lookup handshake by reciever - - lookup := device.indices.Lookup(msg.Receiver) - handshake := lookup.handshake - if handshake == nil { - return nil - } - - var ( - hash [blake2s.Size]byte - chainKey [blake2s.Size]byte - ) - - ok := func() bool { - - // lock handshake state - - handshake.mutex.RLock() - defer handshake.mutex.RUnlock() - - if handshake.state != HandshakeInitiationCreated { - return false - } - - // lock private key for reading - - device.noise.mutex.RLock() - defer device.noise.mutex.RUnlock() - - // finish 3-way DH - - mixHash(&hash, &handshake.hash, msg.Ephemeral[:]) - mixKey(&chainKey, &handshake.chainKey, msg.Ephemeral[:]) - - func() { - ss := handshake.localEphemeral.sharedSecret(msg.Ephemeral) - mixKey(&chainKey, &chainKey, ss[:]) - setZero(ss[:]) - }() - - func() { - ss := device.noise.privateKey.sharedSecret(msg.Ephemeral) - mixKey(&chainKey, &chainKey, ss[:]) - setZero(ss[:]) - }() - - // add preshared key (psk) - - var tau [blake2s.Size]byte - var key [chacha20poly1305.KeySize]byte - KDF3( - &chainKey, - &tau, - &key, - chainKey[:], - handshake.presharedKey[:], - ) - mixHash(&hash, &hash, tau[:]) - - // authenticate transcript - - aead, _ := chacha20poly1305.New(key[:]) - _, err := aead.Open(nil, ZeroNonce[:], msg.Empty[:], hash[:]) - if err != nil { - device.log.Debug.Println("failed to open") - return false - } - mixHash(&hash, &hash, msg.Empty[:]) - return true - }() - - if !ok { - return nil - } - - // update handshake state - - handshake.mutex.Lock() - - handshake.hash = hash - handshake.chainKey = chainKey - handshake.remoteIndex = msg.Sender - handshake.state = HandshakeResponseConsumed - - handshake.mutex.Unlock() - - setZero(hash[:]) - setZero(chainKey[:]) - - return lookup.peer -} - -/* Derives a new key-pair from the current handshake state - * - */ -func (peer *Peer) NewKeyPair() *KeyPair { - device := peer.device - handshake := &peer.handshake - handshake.mutex.Lock() - defer handshake.mutex.Unlock() - - // derive keys - - var isInitiator bool - var sendKey [chacha20poly1305.KeySize]byte - var recvKey [chacha20poly1305.KeySize]byte - - if handshake.state == HandshakeResponseConsumed { - KDF2( - &sendKey, - &recvKey, - handshake.chainKey[:], - nil, - ) - isInitiator = true - } else if handshake.state == HandshakeResponseCreated { - KDF2( - &recvKey, - &sendKey, - handshake.chainKey[:], - nil, - ) - isInitiator = false - } else { - return nil - } - - // zero handshake - - setZero(handshake.chainKey[:]) - setZero(handshake.localEphemeral[:]) - peer.handshake.state = HandshakeZeroed - - // create AEAD instances - - keyPair := new(KeyPair) - keyPair.send, _ = chacha20poly1305.New(sendKey[:]) - keyPair.receive, _ = chacha20poly1305.New(recvKey[:]) - - setZero(sendKey[:]) - setZero(recvKey[:]) - - keyPair.created = time.Now() - keyPair.sendNonce = 0 - keyPair.replayFilter.Init() - keyPair.isInitiator = isInitiator - keyPair.localIndex = peer.handshake.localIndex - keyPair.remoteIndex = peer.handshake.remoteIndex - - // remap index - - device.indices.Insert( - handshake.localIndex, - IndexTableEntry{ - peer: peer, - keyPair: keyPair, - handshake: nil, - }, - ) - handshake.localIndex = 0 - - // rotate key pairs - - kp := &peer.keyPairs - kp.mutex.Lock() - - if isInitiator { - if kp.previous != nil { - device.DeleteKeyPair(kp.previous) - kp.previous = nil - } - - if kp.next != nil { - kp.previous = kp.next - kp.next = keyPair - } else { - kp.previous = kp.current - kp.current = keyPair - peer.signal.newKeyPair.Send() - } - - } else { - kp.next = keyPair - kp.previous = nil - } - kp.mutex.Unlock() - - return keyPair -} diff --git a/src/noise_test.go b/src/noise_test.go deleted file mode 100644 index 5e9d44b..0000000 --- a/src/noise_test.go +++ /dev/null @@ -1,136 +0,0 @@ -package main - -import ( - "bytes" - "encoding/binary" - "testing" -) - -func TestCurveWrappers(t *testing.T) { - sk1, err := newPrivateKey() - assertNil(t, err) - - sk2, err := newPrivateKey() - assertNil(t, err) - - pk1 := sk1.publicKey() - pk2 := sk2.publicKey() - - ss1 := sk1.sharedSecret(pk2) - ss2 := sk2.sharedSecret(pk1) - - if ss1 != ss2 { - t.Fatal("Failed to compute shared secet") - } -} - -func TestNoiseHandshake(t *testing.T) { - dev1 := randDevice(t) - dev2 := randDevice(t) - - defer dev1.Close() - defer dev2.Close() - - peer1, _ := dev2.NewPeer(dev1.noise.privateKey.publicKey()) - peer2, _ := dev1.NewPeer(dev2.noise.privateKey.publicKey()) - - assertEqual( - t, - peer1.handshake.precomputedStaticStatic[:], - peer2.handshake.precomputedStaticStatic[:], - ) - - /* simulate handshake */ - - // initiation message - - t.Log("exchange initiation message") - - msg1, err := dev1.CreateMessageInitiation(peer2) - assertNil(t, err) - - packet := make([]byte, 0, 256) - writer := bytes.NewBuffer(packet) - err = binary.Write(writer, binary.LittleEndian, msg1) - peer := dev2.ConsumeMessageInitiation(msg1) - if peer == nil { - t.Fatal("handshake failed at initiation message") - } - - assertEqual( - t, - peer1.handshake.chainKey[:], - peer2.handshake.chainKey[:], - ) - - assertEqual( - t, - peer1.handshake.hash[:], - peer2.handshake.hash[:], - ) - - // response message - - t.Log("exchange response message") - - msg2, err := dev2.CreateMessageResponse(peer1) - assertNil(t, err) - - peer = dev1.ConsumeMessageResponse(msg2) - if peer == nil { - t.Fatal("handshake failed at response message") - } - - assertEqual( - t, - peer1.handshake.chainKey[:], - peer2.handshake.chainKey[:], - ) - - assertEqual( - t, - peer1.handshake.hash[:], - peer2.handshake.hash[:], - ) - - // key pairs - - t.Log("deriving keys") - - key1 := peer1.NewKeyPair() - key2 := peer2.NewKeyPair() - - if key1 == nil { - t.Fatal("failed to dervice key-pair for peer 1") - } - - if key2 == nil { - t.Fatal("failed to dervice key-pair for peer 2") - } - - // encrypting / decryption test - - t.Log("test key pairs") - - func() { - testMsg := []byte("wireguard test message 1") - var err error - var out []byte - var nonce [12]byte - out = key1.send.Seal(out, nonce[:], testMsg, nil) - out, err = key2.receive.Open(out[:0], nonce[:], out, nil) - assertNil(t, err) - assertEqual(t, out, testMsg) - }() - - func() { - testMsg := []byte("wireguard test message 2") - var err error - var out []byte - var nonce [12]byte - out = key2.send.Seal(out, nonce[:], testMsg, nil) - out, err = key1.receive.Open(out[:0], nonce[:], out, nil) - assertNil(t, err) - assertEqual(t, out, testMsg) - }() -} diff --git a/src/noise_types.go b/src/noise_types.go deleted file mode 100644 index 1a944df..0000000 --- a/src/noise_types.go +++ /dev/null @@ -1,74 +0,0 @@ -package main - -import ( - "crypto/subtle" - "encoding/hex" - "errors" - "golang.org/x/crypto/chacha20poly1305" -) - -const ( - NoisePublicKeySize = 32 - NoisePrivateKeySize = 32 -) - -type ( - NoisePublicKey [NoisePublicKeySize]byte - NoisePrivateKey [NoisePrivateKeySize]byte - NoiseSymmetricKey [chacha20poly1305.KeySize]byte - NoiseNonce uint64 // padded to 12-bytes -) - -func loadExactHex(dst []byte, src string) error { - slice, err := hex.DecodeString(src) - if err != nil { - return err - } - if len(slice) != len(dst) { - return errors.New("Hex string does not fit the slice") - } - copy(dst, slice) - return nil -} - -func (key NoisePrivateKey) IsZero() bool { - var zero NoisePrivateKey - return key.Equals(zero) -} - -func (key NoisePrivateKey) Equals(tar NoisePrivateKey) bool { - return subtle.ConstantTimeCompare(key[:], tar[:]) == 1 -} - -func (key *NoisePrivateKey) FromHex(src string) error { - return loadExactHex(key[:], src) -} - -func (key NoisePrivateKey) ToHex() string { - return hex.EncodeToString(key[:]) -} - -func (key *NoisePublicKey) FromHex(src string) error { - return loadExactHex(key[:], src) -} - -func (key NoisePublicKey) ToHex() string { - return hex.EncodeToString(key[:]) -} - -func (key NoisePublicKey) IsZero() bool { - var zero NoisePublicKey - return key.Equals(zero) -} - -func (key NoisePublicKey) Equals(tar NoisePublicKey) bool { - return subtle.ConstantTimeCompare(key[:], tar[:]) == 1 -} - -func (key *NoiseSymmetricKey) FromHex(src string) error { - return loadExactHex(key[:], src) -} - -func (key NoiseSymmetricKey) ToHex() string { - return hex.EncodeToString(key[:]) -} diff --git a/src/peer.go b/src/peer.go deleted file mode 100644 index dc04811..0000000 --- a/src/peer.go +++ /dev/null @@ -1,295 +0,0 @@ -package main - -import ( - "encoding/base64" - "errors" - "fmt" - "github.com/sasha-s/go-deadlock" - "sync" - "time" -) - -const ( - PeerRoutineNumber = 4 -) - -type Peer struct { - isRunning AtomicBool - mutex deadlock.RWMutex - persistentKeepaliveInterval uint64 - keyPairs KeyPairs - handshake Handshake - device *Device - endpoint Endpoint - - stats struct { - txBytes uint64 // bytes send to peer (endpoint) - rxBytes uint64 // bytes received from peer - lastHandshakeNano int64 // nano seconds since epoch - } - - time struct { - mutex deadlock.RWMutex - lastSend time.Time // last send message - lastHandshake time.Time // last completed handshake - nextKeepalive time.Time - } - - signal struct { - newKeyPair Signal // size 1, new key pair was generated - handshakeCompleted Signal // size 1, handshake completed - handshakeBegin Signal // size 1, begin new handshake begin - flushNonceQueue Signal // size 1, empty queued packets - messageSend Signal // size 1, message was send to peer - messageReceived Signal // size 1, authenticated message recv - } - - timer struct { - - // state related to WireGuard timers - - keepalivePersistent Timer // set for persistent keep-alive - keepalivePassive Timer // set upon receiving messages - zeroAllKeys Timer // zero all key material - handshakeNew Timer // begin a new handshake (stale) - handshakeDeadline Timer // complete handshake timeout - handshakeTimeout Timer // current handshake message timeout - - sendLastMinuteHandshake bool - needAnotherKeepalive bool - } - - queue struct { - nonce chan *QueueOutboundElement // nonce / pre-handshake queue - outbound chan *QueueOutboundElement // sequential ordering of work - inbound chan *QueueInboundElement // sequential ordering of work - } - - routines struct { - mutex deadlock.Mutex // held when stopping / starting routines - starting sync.WaitGroup // routines pending start - stopping sync.WaitGroup // routines pending stop - stop Signal // size 0, stop all go-routines in peer - } - - mac CookieGenerator -} - -func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) { - - if device.isClosed.Get() { - return nil, errors.New("Device closed") - } - - // lock resources - - device.state.mutex.Lock() - defer device.state.mutex.Unlock() - - device.noise.mutex.RLock() - defer device.noise.mutex.RUnlock() - - device.peers.mutex.Lock() - defer device.peers.mutex.Unlock() - - // check if over limit - - if len(device.peers.keyMap) >= MaxPeers { - return nil, errors.New("Too many peers") - } - - // create peer - - peer := new(Peer) - peer.mutex.Lock() - defer peer.mutex.Unlock() - - peer.mac.Init(pk) - peer.device = device - peer.isRunning.Set(false) - - peer.timer.zeroAllKeys = NewTimer() - peer.timer.keepalivePersistent = NewTimer() - peer.timer.keepalivePassive = NewTimer() - peer.timer.handshakeNew = NewTimer() - peer.timer.handshakeDeadline = NewTimer() - peer.timer.handshakeTimeout = NewTimer() - - // map public key - - _, ok := device.peers.keyMap[pk] - if ok { - return nil, errors.New("Adding existing peer") - } - device.peers.keyMap[pk] = peer - - // pre-compute DH - - handshake := &peer.handshake - handshake.mutex.Lock() - handshake.remoteStatic = pk - handshake.precomputedStaticStatic = device.noise.privateKey.sharedSecret(pk) - handshake.mutex.Unlock() - - // reset endpoint - - peer.endpoint = nil - - // prepare signaling & routines - - peer.routines.mutex.Lock() - peer.routines.stop = NewSignal() - peer.routines.mutex.Unlock() - - // start peer - - if peer.device.isUp.Get() { - peer.Start() - } - - return peer, nil -} - -func (peer *Peer) SendBuffer(buffer []byte) error { - peer.device.net.mutex.RLock() - defer peer.device.net.mutex.RUnlock() - - if peer.device.net.bind == nil { - return errors.New("No bind") - } - - peer.mutex.RLock() - defer peer.mutex.RUnlock() - - if peer.endpoint == nil { - return errors.New("No known endpoint for peer") - } - - return peer.device.net.bind.Send(buffer, peer.endpoint) -} - -/* Returns a short string identifier for logging - */ -func (peer *Peer) String() string { - if peer.endpoint == nil { - return fmt.Sprintf( - "peer(unknown %s)", - base64.StdEncoding.EncodeToString(peer.handshake.remoteStatic[:]), - ) - } - return fmt.Sprintf( - "peer(%s %s)", - peer.endpoint.DstToString(), - base64.StdEncoding.EncodeToString(peer.handshake.remoteStatic[:]), - ) -} - -func (peer *Peer) Start() { - - // should never start a peer on a closed device - - if peer.device.isClosed.Get() { - return - } - - // prevent simultaneous start/stop operations - - peer.routines.mutex.Lock() - defer peer.routines.mutex.Unlock() - - if peer.isRunning.Get() { - return - } - - peer.device.log.Debug.Println("Starting:", peer.String()) - - // sanity check : these should be 0 - - peer.routines.starting.Wait() - peer.routines.stopping.Wait() - - // prepare queues and signals - - peer.signal.newKeyPair = NewSignal() - peer.signal.handshakeBegin = NewSignal() - peer.signal.handshakeCompleted = NewSignal() - peer.signal.flushNonceQueue = NewSignal() - - peer.queue.nonce = make(chan *QueueOutboundElement, QueueOutboundSize) - peer.queue.outbound = make(chan *QueueOutboundElement, QueueOutboundSize) - peer.queue.inbound = make(chan *QueueInboundElement, QueueInboundSize) - - peer.routines.stop = NewSignal() - peer.isRunning.Set(true) - - // wait for routines to start - - peer.routines.starting.Add(PeerRoutineNumber) - peer.routines.stopping.Add(PeerRoutineNumber) - - go peer.RoutineNonce() - go peer.RoutineTimerHandler() - go peer.RoutineSequentialSender() - go peer.RoutineSequentialReceiver() - - peer.routines.starting.Wait() - peer.isRunning.Set(true) -} - -func (peer *Peer) Stop() { - - // prevent simultaneous start/stop operations - - peer.routines.mutex.Lock() - defer peer.routines.mutex.Unlock() - - if !peer.isRunning.Swap(false) { - return - } - - device := peer.device - device.log.Debug.Println("Stopping:", peer.String()) - - // stop & wait for ongoing peer routines - - peer.routines.stop.Broadcast() - peer.routines.starting.Wait() - peer.routines.stopping.Wait() - - // stop timers - - peer.timer.keepalivePersistent.Stop() - peer.timer.keepalivePassive.Stop() - peer.timer.zeroAllKeys.Stop() - peer.timer.handshakeNew.Stop() - peer.timer.handshakeDeadline.Stop() - peer.timer.handshakeTimeout.Stop() - - // close queues - - close(peer.queue.nonce) - close(peer.queue.outbound) - close(peer.queue.inbound) - - // clear key pairs - - kp := &peer.keyPairs - kp.mutex.Lock() - - device.DeleteKeyPair(kp.previous) - device.DeleteKeyPair(kp.current) - device.DeleteKeyPair(kp.next) - - kp.previous = nil - kp.current = nil - kp.next = nil - kp.mutex.Unlock() - - // clear handshake state - - hs := &peer.handshake - hs.mutex.Lock() - device.indices.Delete(hs.localIndex) - hs.Clear() - hs.mutex.Unlock() -} diff --git a/src/ratelimiter.go b/src/ratelimiter.go deleted file mode 100644 index 6e5f005..0000000 --- a/src/ratelimiter.go +++ /dev/null @@ -1,139 +0,0 @@ -package main - -/* Copyright (C) 2015-2017 Jason A. Donenfeld . All Rights Reserved. */ - -/* This file contains a port of the ratelimited from the linux kernel version - */ - -import ( - "net" - "sync" - "time" -) - -const ( - RatelimiterPacketsPerSecond = 20 - RatelimiterPacketsBurstable = 5 - RatelimiterGarbageCollectTime = time.Second - RatelimiterPacketCost = 1000000000 / RatelimiterPacketsPerSecond - RatelimiterMaxTokens = RatelimiterPacketCost * RatelimiterPacketsBurstable -) - -type RatelimiterEntry struct { - mutex sync.Mutex - lastTime time.Time - tokens int64 -} - -type Ratelimiter struct { - mutex sync.RWMutex - lastGarbageCollect time.Time - tableIPv4 map[[net.IPv4len]byte]*RatelimiterEntry - tableIPv6 map[[net.IPv6len]byte]*RatelimiterEntry -} - -func (rate *Ratelimiter) Init() { - rate.mutex.Lock() - defer rate.mutex.Unlock() - rate.tableIPv4 = make(map[[net.IPv4len]byte]*RatelimiterEntry) - rate.tableIPv6 = make(map[[net.IPv6len]byte]*RatelimiterEntry) - rate.lastGarbageCollect = time.Now() -} - -func (rate *Ratelimiter) GarbageCollectEntries() { - rate.mutex.Lock() - - // remove unused IPv4 entries - - for key, entry := range rate.tableIPv4 { - entry.mutex.Lock() - if time.Now().Sub(entry.lastTime) > RatelimiterGarbageCollectTime { - delete(rate.tableIPv4, key) - } - entry.mutex.Unlock() - } - - // remove unused IPv6 entries - - for key, entry := range rate.tableIPv6 { - entry.mutex.Lock() - if time.Now().Sub(entry.lastTime) > RatelimiterGarbageCollectTime { - delete(rate.tableIPv6, key) - } - entry.mutex.Unlock() - } - - rate.mutex.Unlock() -} - -func (rate *Ratelimiter) RoutineGarbageCollector(stop Signal) { - timer := time.NewTimer(time.Second) - for { - select { - case <-stop.Wait(): - return - case <-timer.C: - rate.GarbageCollectEntries() - timer.Reset(time.Second) - } - } -} - -func (rate *Ratelimiter) Allow(ip net.IP) bool { - var entry *RatelimiterEntry - var KeyIPv4 [net.IPv4len]byte - var KeyIPv6 [net.IPv6len]byte - - // lookup entry - - IPv4 := ip.To4() - IPv6 := ip.To16() - - rate.mutex.RLock() - - if IPv4 != nil { - copy(KeyIPv4[:], IPv4) - entry = rate.tableIPv4[KeyIPv4] - } else { - copy(KeyIPv6[:], IPv6) - entry = rate.tableIPv6[KeyIPv6] - } - - rate.mutex.RUnlock() - - // make new entry if not found - - if entry == nil { - rate.mutex.Lock() - entry = new(RatelimiterEntry) - entry.tokens = RatelimiterMaxTokens - RatelimiterPacketCost - entry.lastTime = time.Now() - if IPv4 != nil { - rate.tableIPv4[KeyIPv4] = entry - } else { - rate.tableIPv6[KeyIPv6] = entry - } - rate.mutex.Unlock() - return true - } - - // add tokens to entry - - entry.mutex.Lock() - now := time.Now() - entry.tokens += now.Sub(entry.lastTime).Nanoseconds() - entry.lastTime = now - if entry.tokens > RatelimiterMaxTokens { - entry.tokens = RatelimiterMaxTokens - } - - // subtract cost of packet - - if entry.tokens > RatelimiterPacketCost { - entry.tokens -= RatelimiterPacketCost - entry.mutex.Unlock() - return true - } - entry.mutex.Unlock() - return false -} diff --git a/src/ratelimiter_test.go b/src/ratelimiter_test.go deleted file mode 100644 index 13b6a23..0000000 --- a/src/ratelimiter_test.go +++ /dev/null @@ -1,98 +0,0 @@ -package main - -import ( - "net" - "testing" - "time" -) - -type RatelimiterResult struct { - allowed bool - text string - wait time.Duration -} - -func TestRatelimiter(t *testing.T) { - - var ratelimiter Ratelimiter - var expectedResults []RatelimiterResult - - Nano := func(nano int64) time.Duration { - return time.Nanosecond * time.Duration(nano) - } - - Add := func(res RatelimiterResult) { - expectedResults = append( - expectedResults, - res, - ) - } - - for i := 0; i < RatelimiterPacketsBurstable; i++ { - Add(RatelimiterResult{ - allowed: true, - text: "inital burst", - }) - } - - Add(RatelimiterResult{ - allowed: false, - text: "after burst", - }) - - Add(RatelimiterResult{ - allowed: true, - wait: Nano(time.Second.Nanoseconds() / RatelimiterPacketsPerSecond), - text: "filling tokens for single packet", - }) - - Add(RatelimiterResult{ - allowed: false, - text: "not having refilled enough", - }) - - Add(RatelimiterResult{ - allowed: true, - wait: 2 * Nano(time.Second.Nanoseconds()/RatelimiterPacketsPerSecond), - text: "filling tokens for two packet burst", - }) - - Add(RatelimiterResult{ - allowed: true, - text: "second packet in 2 packet burst", - }) - - Add(RatelimiterResult{ - allowed: false, - text: "packet following 2 packet burst", - }) - - ips := []net.IP{ - net.ParseIP("127.0.0.1"), - net.ParseIP("192.168.1.1"), - net.ParseIP("172.167.2.3"), - net.ParseIP("97.231.252.215"), - net.ParseIP("248.97.91.167"), - net.ParseIP("188.208.233.47"), - net.ParseIP("104.2.183.179"), - net.ParseIP("72.129.46.120"), - net.ParseIP("2001:0db8:0a0b:12f0:0000:0000:0000:0001"), - net.ParseIP("f5c2:818f:c052:655a:9860:b136:6894:25f0"), - net.ParseIP("b2d7:15ab:48a7:b07c:a541:f144:a9fe:54fc"), - net.ParseIP("a47b:786e:1671:a22b:d6f9:4ab0:abc7:c918"), - net.ParseIP("ea1e:d155:7f7a:98fb:2bf5:9483:80f6:5445"), - net.ParseIP("3f0e:54a2:f5b4:cd19:a21d:58e1:3746:84c4"), - } - - ratelimiter.Init() - - for i, res := range expectedResults { - time.Sleep(res.wait) - for _, ip := range ips { - allowed := ratelimiter.Allow(ip) - if allowed != res.allowed { - t.Fatal("Test failed for", ip.String(), ", on:", i, "(", res.text, ")", "expected:", res.allowed, "got:", allowed) - } - } - } -} diff --git a/src/receive.go b/src/receive.go deleted file mode 100644 index 1f44df2..0000000 --- a/src/receive.go +++ /dev/null @@ -1,642 +0,0 @@ -package main - -import ( - "bytes" - "encoding/binary" - "golang.org/x/crypto/chacha20poly1305" - "golang.org/x/net/ipv4" - "golang.org/x/net/ipv6" - "net" - "sync" - "sync/atomic" - "time" -) - -type QueueHandshakeElement struct { - msgType uint32 - packet []byte - endpoint Endpoint - buffer *[MaxMessageSize]byte -} - -type QueueInboundElement struct { - dropped int32 - mutex sync.Mutex - buffer *[MaxMessageSize]byte - packet []byte - counter uint64 - keyPair *KeyPair - endpoint Endpoint -} - -func (elem *QueueInboundElement) Drop() { - atomic.StoreInt32(&elem.dropped, AtomicTrue) -} - -func (elem *QueueInboundElement) IsDropped() bool { - return atomic.LoadInt32(&elem.dropped) == AtomicTrue -} - -func (device *Device) addToInboundQueue( - queue chan *QueueInboundElement, - element *QueueInboundElement, -) { - for { - select { - case queue <- element: - return - default: - select { - case old := <-queue: - old.Drop() - default: - } - } - } -} - -func (device *Device) addToDecryptionQueue( - queue chan *QueueInboundElement, - element *QueueInboundElement, -) { - for { - select { - case queue <- element: - return - default: - select { - case old := <-queue: - // drop & release to potential consumer - old.Drop() - old.mutex.Unlock() - default: - } - } - } -} - -func (device *Device) addToHandshakeQueue( - queue chan QueueHandshakeElement, - element QueueHandshakeElement, -) { - for { - select { - case queue <- element: - return - default: - select { - case elem := <-queue: - device.PutMessageBuffer(elem.buffer) - default: - } - } - } -} - -/* Receives incoming datagrams for the device - * - * Every time the bind is updated a new routine is started for - * IPv4 and IPv6 (separately) - */ -func (device *Device) RoutineReceiveIncoming(IP int, bind Bind) { - - logDebug := device.log.Debug - logDebug.Println("Routine, receive incoming, IP version:", IP) - - // receive datagrams until conn is closed - - buffer := device.GetMessageBuffer() - - var ( - err error - size int - endpoint Endpoint - ) - - for { - - // read next datagram - - switch IP { - case ipv4.Version: - size, endpoint, err = bind.ReceiveIPv4(buffer[:]) - case ipv6.Version: - size, endpoint, err = bind.ReceiveIPv6(buffer[:]) - default: - panic("invalid IP version") - } - - if err != nil { - return - } - - if size < MinMessageSize { - continue - } - - // check size of packet - - packet := buffer[:size] - msgType := binary.LittleEndian.Uint32(packet[:4]) - - var okay bool - - switch msgType { - - // check if transport - - case MessageTransportType: - - // check size - - if len(packet) < MessageTransportType { - continue - } - - // lookup key pair - - receiver := binary.LittleEndian.Uint32( - packet[MessageTransportOffsetReceiver:MessageTransportOffsetCounter], - ) - value := device.indices.Lookup(receiver) - keyPair := value.keyPair - if keyPair == nil { - continue - } - - // check key-pair expiry - - if keyPair.created.Add(RejectAfterTime).Before(time.Now()) { - continue - } - - // create work element - - peer := value.peer - elem := &QueueInboundElement{ - packet: packet, - buffer: buffer, - keyPair: keyPair, - dropped: AtomicFalse, - endpoint: endpoint, - } - elem.mutex.Lock() - - // add to decryption queues - - if peer.isRunning.Get() { - device.addToDecryptionQueue(device.queue.decryption, elem) - device.addToInboundQueue(peer.queue.inbound, elem) - buffer = device.GetMessageBuffer() - } - - continue - - // otherwise it is a fixed size & handshake related packet - - case MessageInitiationType: - okay = len(packet) == MessageInitiationSize - - case MessageResponseType: - okay = len(packet) == MessageResponseSize - - case MessageCookieReplyType: - okay = len(packet) == MessageCookieReplySize - } - - if okay { - device.addToHandshakeQueue( - device.queue.handshake, - QueueHandshakeElement{ - msgType: msgType, - buffer: buffer, - packet: packet, - endpoint: endpoint, - }, - ) - buffer = device.GetMessageBuffer() - } - } -} - -func (device *Device) RoutineDecryption() { - - var nonce [chacha20poly1305.NonceSize]byte - - logDebug := device.log.Debug - logDebug.Println("Routine, decryption, started for device") - - for { - select { - case <-device.signal.stop.Wait(): - logDebug.Println("Routine, decryption worker, stopped") - return - - case elem := <-device.queue.decryption: - - // check if dropped - - if elem.IsDropped() { - continue - } - - // split message into fields - - counter := elem.packet[MessageTransportOffsetCounter:MessageTransportOffsetContent] - content := elem.packet[MessageTransportOffsetContent:] - - // expand nonce - - nonce[0x4] = counter[0x0] - nonce[0x5] = counter[0x1] - nonce[0x6] = counter[0x2] - nonce[0x7] = counter[0x3] - - nonce[0x8] = counter[0x4] - nonce[0x9] = counter[0x5] - nonce[0xa] = counter[0x6] - nonce[0xb] = counter[0x7] - - // decrypt and release to consumer - - var err error - elem.counter = binary.LittleEndian.Uint64(counter) - elem.packet, err = elem.keyPair.receive.Open( - content[:0], - nonce[:], - content, - nil, - ) - if err != nil { - elem.Drop() - } - elem.mutex.Unlock() - } - } -} - -/* Handles incoming packets related to handshake - */ -func (device *Device) RoutineHandshake() { - - logInfo := device.log.Info - logError := device.log.Error - logDebug := device.log.Debug - logDebug.Println("Routine, handshake routine, started for device") - - var temp [MessageHandshakeSize]byte - var elem QueueHandshakeElement - - for { - select { - case elem = <-device.queue.handshake: - case <-device.signal.stop.Wait(): - return - } - - // handle cookie fields and ratelimiting - - switch elem.msgType { - - case MessageCookieReplyType: - - // unmarshal packet - - var reply MessageCookieReply - reader := bytes.NewReader(elem.packet) - err := binary.Read(reader, binary.LittleEndian, &reply) - if err != nil { - logDebug.Println("Failed to decode cookie reply") - return - } - - // lookup peer from index - - entry := device.indices.Lookup(reply.Receiver) - - if entry.peer == nil { - continue - } - - // consume reply - - if peer := entry.peer; peer.isRunning.Get() { - peer.mac.ConsumeReply(&reply) - } - - continue - - case MessageInitiationType, MessageResponseType: - - // check mac fields and ratelimit - - if !device.mac.CheckMAC1(elem.packet) { - logDebug.Println("Received packet with invalid mac1") - continue - } - - // endpoints destination address is the source of the datagram - - srcBytes := elem.endpoint.DstToBytes() - - if device.IsUnderLoad() { - - // verify MAC2 field - - if !device.mac.CheckMAC2(elem.packet, srcBytes) { - - // construct cookie reply - - logDebug.Println( - "Sending cookie reply to:", - elem.endpoint.DstToString(), - ) - - sender := binary.LittleEndian.Uint32(elem.packet[4:8]) - reply, err := device.mac.CreateReply(elem.packet, sender, srcBytes) - if err != nil { - logError.Println("Failed to create cookie reply:", err) - continue - } - - // marshal and send reply - - writer := bytes.NewBuffer(temp[:0]) - binary.Write(writer, binary.LittleEndian, reply) - device.net.bind.Send(writer.Bytes(), elem.endpoint) - if err != nil { - logDebug.Println("Failed to send cookie reply:", err) - } - continue - } - - // check ratelimiter - - if !device.rate.limiter.Allow(elem.endpoint.DstIP()) { - continue - } - } - - default: - logError.Println("Invalid packet ended up in the handshake queue") - continue - } - - // handle handshake initiation/response content - - switch elem.msgType { - case MessageInitiationType: - - // unmarshal - - var msg MessageInitiation - reader := bytes.NewReader(elem.packet) - err := binary.Read(reader, binary.LittleEndian, &msg) - if err != nil { - logError.Println("Failed to decode initiation message") - continue - } - - // consume initiation - - peer := device.ConsumeMessageInitiation(&msg) - if peer == nil { - logInfo.Println( - "Received invalid initiation message from", - elem.endpoint.DstToString(), - ) - continue - } - - // update timers - - peer.TimerAnyAuthenticatedPacketTraversal() - peer.TimerAnyAuthenticatedPacketReceived() - - // update endpoint - - peer.mutex.Lock() - peer.endpoint = elem.endpoint - peer.mutex.Unlock() - - // create response - - response, err := device.CreateMessageResponse(peer) - if err != nil { - logError.Println("Failed to create response message:", err) - continue - } - - peer.TimerEphemeralKeyCreated() - peer.NewKeyPair() - - logDebug.Println("Creating response message for", peer.String()) - - writer := bytes.NewBuffer(temp[:0]) - binary.Write(writer, binary.LittleEndian, response) - packet := writer.Bytes() - peer.mac.AddMacs(packet) - - // send response - - err = peer.SendBuffer(packet) - if err == nil { - peer.TimerAnyAuthenticatedPacketTraversal() - } else { - logError.Println("Failed to send response to:", peer.String(), err) - } - - case MessageResponseType: - - // unmarshal - - var msg MessageResponse - reader := bytes.NewReader(elem.packet) - err := binary.Read(reader, binary.LittleEndian, &msg) - if err != nil { - logError.Println("Failed to decode response message") - continue - } - - // consume response - - peer := device.ConsumeMessageResponse(&msg) - if peer == nil { - logInfo.Println( - "Recieved invalid response message from", - elem.endpoint.DstToString(), - ) - continue - } - - // update endpoint - - peer.mutex.Lock() - peer.endpoint = elem.endpoint - peer.mutex.Unlock() - - logDebug.Println("Received handshake initiation from", peer) - - peer.TimerEphemeralKeyCreated() - - // update timers - - peer.TimerAnyAuthenticatedPacketTraversal() - peer.TimerAnyAuthenticatedPacketReceived() - peer.TimerHandshakeComplete() - - // derive key-pair - - peer.NewKeyPair() - peer.SendKeepAlive() - } - } -} - -func (peer *Peer) RoutineSequentialReceiver() { - - defer peer.routines.stopping.Done() - - device := peer.device - - logInfo := device.log.Info - logError := device.log.Error - logDebug := device.log.Debug - logDebug.Println("Routine, sequential receiver, started for peer", peer.String()) - - peer.routines.starting.Done() - - for { - - select { - - case <-peer.routines.stop.Wait(): - logDebug.Println("Routine, sequential receiver, stopped for peer", peer.String()) - return - - case elem := <-peer.queue.inbound: - - // wait for decryption - - elem.mutex.Lock() - - if elem.IsDropped() { - continue - } - - // check for replay - - if !elem.keyPair.replayFilter.ValidateCounter(elem.counter) { - continue - } - - peer.TimerAnyAuthenticatedPacketTraversal() - peer.TimerAnyAuthenticatedPacketReceived() - peer.KeepKeyFreshReceiving() - - // check if using new key-pair - - kp := &peer.keyPairs - kp.mutex.Lock() - if kp.next == elem.keyPair { - peer.TimerHandshakeComplete() - if kp.previous != nil { - device.DeleteKeyPair(kp.previous) - } - kp.previous = kp.current - kp.current = kp.next - kp.next = nil - } - kp.mutex.Unlock() - - // update endpoint - - peer.mutex.Lock() - peer.endpoint = elem.endpoint - peer.mutex.Unlock() - - // check for keep-alive - - if len(elem.packet) == 0 { - logDebug.Println("Received keep-alive from", peer.String()) - continue - } - peer.TimerDataReceived() - - // verify source and strip padding - - switch elem.packet[0] >> 4 { - case ipv4.Version: - - // strip padding - - if len(elem.packet) < ipv4.HeaderLen { - continue - } - - field := elem.packet[IPv4offsetTotalLength : IPv4offsetTotalLength+2] - length := binary.BigEndian.Uint16(field) - if int(length) > len(elem.packet) || int(length) < ipv4.HeaderLen { - continue - } - - elem.packet = elem.packet[:length] - - // verify IPv4 source - - src := elem.packet[IPv4offsetSrc : IPv4offsetSrc+net.IPv4len] - if device.routing.table.LookupIPv4(src) != peer { - logInfo.Println( - "IPv4 packet with disallowed source address from", - peer.String(), - ) - continue - } - - case ipv6.Version: - - // strip padding - - if len(elem.packet) < ipv6.HeaderLen { - continue - } - - field := elem.packet[IPv6offsetPayloadLength : IPv6offsetPayloadLength+2] - length := binary.BigEndian.Uint16(field) - length += ipv6.HeaderLen - if int(length) > len(elem.packet) { - continue - } - - elem.packet = elem.packet[:length] - - // verify IPv6 source - - src := elem.packet[IPv6offsetSrc : IPv6offsetSrc+net.IPv6len] - if device.routing.table.LookupIPv6(src) != peer { - logInfo.Println( - "IPv6 packet with disallowed source address from", - peer.String(), - ) - continue - } - - default: - logInfo.Println("Packet with invalid IP version from", peer.String()) - continue - } - - // write to tun device - - offset := MessageTransportOffsetContent - atomic.AddUint64(&peer.stats.rxBytes, uint64(len(elem.packet))) - _, err := device.tun.device.Write( - elem.buffer[:offset+len(elem.packet)], - offset) - device.PutMessageBuffer(elem.buffer) - if err != nil { - logError.Println("Failed to write packet to TUN device:", err) - } - } - } -} diff --git a/src/replay.go b/src/replay.go deleted file mode 100644 index 5d42860..0000000 --- a/src/replay.go +++ /dev/null @@ -1,73 +0,0 @@ -package main - -/* Copyright (C) 2015-2017 Jason A. Donenfeld . All Rights Reserved. */ - -/* Implementation of RFC6479 - * https://tools.ietf.org/html/rfc6479 - * - * The implementation is not safe for concurrent use! - */ - -const ( - // See: https://golang.org/src/math/big/arith.go - _Wordm = ^uintptr(0) - _WordLogSize = _Wordm>>8&1 + _Wordm>>16&1 + _Wordm>>32&1 - _WordSize = 1 << _WordLogSize -) - -const ( - CounterRedundantBitsLog = _WordLogSize + 3 - CounterRedundantBits = _WordSize * 8 - CounterBitsTotal = 2048 - CounterWindowSize = uint64(CounterBitsTotal - CounterRedundantBits) -) - -const ( - BacktrackWords = CounterBitsTotal / _WordSize -) - -type ReplayFilter struct { - counter uint64 - backtrack [BacktrackWords]uintptr -} - -func (filter *ReplayFilter) Init() { - filter.counter = 0 - filter.backtrack[0] = 0 -} - -func (filter *ReplayFilter) ValidateCounter(counter uint64) bool { - if counter >= RejectAfterMessages { - return false - } - - indexWord := counter >> CounterRedundantBitsLog - - if counter > filter.counter { - - // move window forward - - current := filter.counter >> CounterRedundantBitsLog - diff := minUint64(indexWord-current, BacktrackWords) - for i := uint64(1); i <= diff; i++ { - filter.backtrack[(current+i)%BacktrackWords] = 0 - } - filter.counter = counter - - } else if filter.counter-counter > CounterWindowSize { - - // behind current window - - return false - } - - indexWord %= BacktrackWords - indexBit := counter & uint64(CounterRedundantBits-1) - - // check and set bit - - oldValue := filter.backtrack[indexWord] - newValue := oldValue | (1 << indexBit) - filter.backtrack[indexWord] = newValue - return oldValue != newValue -} diff --git a/src/replay_test.go b/src/replay_test.go deleted file mode 100644 index 228fce6..0000000 --- a/src/replay_test.go +++ /dev/null @@ -1,112 +0,0 @@ -package main - -import ( - "testing" -) - -/* Ported from the linux kernel implementation - * - * - */ - -func TestReplay(t *testing.T) { - var filter ReplayFilter - - T_LIM := CounterWindowSize + 1 - - testNumber := 0 - T := func(n uint64, v bool) { - testNumber++ - if filter.ValidateCounter(n) != v { - t.Fatal("Test", testNumber, "failed", n, v) - } - } - - filter.Init() - - /* 1 */ T(0, true) - /* 2 */ T(1, true) - /* 3 */ T(1, false) - /* 4 */ T(9, true) - /* 5 */ T(8, true) - /* 6 */ T(7, true) - /* 7 */ T(7, false) - /* 8 */ T(T_LIM, true) - /* 9 */ T(T_LIM-1, true) - /* 10 */ T(T_LIM-1, false) - /* 11 */ T(T_LIM-2, true) - /* 12 */ T(2, true) - /* 13 */ T(2, false) - /* 14 */ T(T_LIM+16, true) - /* 15 */ T(3, false) - /* 16 */ T(T_LIM+16, false) - /* 17 */ T(T_LIM*4, true) - /* 18 */ T(T_LIM*4-(T_LIM-1), true) - /* 19 */ T(10, false) - /* 20 */ T(T_LIM*4-T_LIM, false) - /* 21 */ T(T_LIM*4-(T_LIM+1), false) - /* 22 */ T(T_LIM*4-(T_LIM-2), true) - /* 23 */ T(T_LIM*4+1-T_LIM, false) - /* 24 */ T(0, false) - /* 25 */ T(RejectAfterMessages, false) - /* 26 */ T(RejectAfterMessages-1, true) - /* 27 */ T(RejectAfterMessages, false) - /* 28 */ T(RejectAfterMessages-1, false) - /* 29 */ T(RejectAfterMessages-2, true) - /* 30 */ T(RejectAfterMessages+1, false) - /* 31 */ T(RejectAfterMessages+2, false) - /* 32 */ T(RejectAfterMessages-2, false) - /* 33 */ T(RejectAfterMessages-3, true) - /* 34 */ T(0, false) - - t.Log("Bulk test 1") - filter.Init() - testNumber = 0 - for i := uint64(1); i <= CounterWindowSize; i++ { - T(i, true) - } - T(0, true) - T(0, false) - - t.Log("Bulk test 2") - filter.Init() - testNumber = 0 - for i := uint64(2); i <= CounterWindowSize+1; i++ { - T(i, true) - } - T(1, true) - T(0, false) - - t.Log("Bulk test 3") - filter.Init() - testNumber = 0 - for i := CounterWindowSize + 1; i > 0; i-- { - T(i, true) - } - - t.Log("Bulk test 4") - filter.Init() - testNumber = 0 - for i := CounterWindowSize + 2; i > 1; i-- { - T(i, true) - } - T(0, false) - - t.Log("Bulk test 5") - filter.Init() - testNumber = 0 - for i := CounterWindowSize; i > 0; i-- { - T(i, true) - } - T(CounterWindowSize+1, true) - T(0, false) - - t.Log("Bulk test 6") - filter.Init() - testNumber = 0 - for i := CounterWindowSize; i > 0; i-- { - T(i, true) - } - T(0, true) - T(CounterWindowSize+1, true) -} diff --git a/src/routing.go b/src/routing.go deleted file mode 100644 index 2a2e237..0000000 --- a/src/routing.go +++ /dev/null @@ -1,65 +0,0 @@ -package main - -import ( - "errors" - "net" - "sync" -) - -type RoutingTable struct { - IPv4 *Trie - IPv6 *Trie - mutex sync.RWMutex -} - -func (table *RoutingTable) AllowedIPs(peer *Peer) []net.IPNet { - table.mutex.RLock() - defer table.mutex.RUnlock() - - allowed := make([]net.IPNet, 0, 10) - allowed = table.IPv4.AllowedIPs(peer, allowed) - allowed = table.IPv6.AllowedIPs(peer, allowed) - return allowed -} - -func (table *RoutingTable) Reset() { - table.mutex.Lock() - defer table.mutex.Unlock() - - table.IPv4 = nil - table.IPv6 = nil -} - -func (table *RoutingTable) RemovePeer(peer *Peer) { - table.mutex.Lock() - defer table.mutex.Unlock() - - table.IPv4 = table.IPv4.RemovePeer(peer) - table.IPv6 = table.IPv6.RemovePeer(peer) -} - -func (table *RoutingTable) Insert(ip net.IP, cidr uint, peer *Peer) { - table.mutex.Lock() - defer table.mutex.Unlock() - - switch len(ip) { - case net.IPv6len: - table.IPv6 = table.IPv6.Insert(ip, cidr, peer) - case net.IPv4len: - table.IPv4 = table.IPv4.Insert(ip, cidr, peer) - default: - panic(errors.New("Inserting unknown address type")) - } -} - -func (table *RoutingTable) LookupIPv4(address []byte) *Peer { - table.mutex.RLock() - defer table.mutex.RUnlock() - return table.IPv4.Lookup(address) -} - -func (table *RoutingTable) LookupIPv6(address []byte) *Peer { - table.mutex.RLock() - defer table.mutex.RUnlock() - return table.IPv6.Lookup(address) -} diff --git a/src/send.go b/src/send.go deleted file mode 100644 index 7488d3a..0000000 --- a/src/send.go +++ /dev/null @@ -1,362 +0,0 @@ -package main - -import ( - "encoding/binary" - "golang.org/x/crypto/chacha20poly1305" - "golang.org/x/net/ipv4" - "golang.org/x/net/ipv6" - "net" - "sync" - "sync/atomic" - "time" -) - -/* Outbound flow - * - * 1. TUN queue - * 2. Routing (sequential) - * 3. Nonce assignment (sequential) - * 4. Encryption (parallel) - * 5. Transmission (sequential) - * - * The functions in this file occur (roughly) in the order in - * which the packets are processed. - * - * Locking, Producers and Consumers - * - * The order of packets (per peer) must be maintained, - * but encryption of packets happen out-of-order: - * - * The sequential consumers will attempt to take the lock, - * workers release lock when they have completed work (encryption) on the packet. - * - * If the element is inserted into the "encryption queue", - * the content is preceded by enough "junk" to contain the transport header - * (to allow the construction of transport messages in-place) - */ - -type QueueOutboundElement struct { - dropped int32 - mutex sync.Mutex - buffer *[MaxMessageSize]byte // slice holding the packet data - packet []byte // slice of "buffer" (always!) - nonce uint64 // nonce for encryption - keyPair *KeyPair // key-pair for encryption - peer *Peer // related peer -} - -func (peer *Peer) FlushNonceQueue() { - elems := len(peer.queue.nonce) - for i := 0; i < elems; i++ { - select { - case <-peer.queue.nonce: - default: - return - } - } -} - -func (device *Device) NewOutboundElement() *QueueOutboundElement { - return &QueueOutboundElement{ - dropped: AtomicFalse, - buffer: device.pool.messageBuffers.Get().(*[MaxMessageSize]byte), - } -} - -func (elem *QueueOutboundElement) Drop() { - atomic.StoreInt32(&elem.dropped, AtomicTrue) -} - -func (elem *QueueOutboundElement) IsDropped() bool { - return atomic.LoadInt32(&elem.dropped) == AtomicTrue -} - -func addToOutboundQueue( - queue chan *QueueOutboundElement, - element *QueueOutboundElement, -) { - for { - select { - case queue <- element: - return - default: - select { - case old := <-queue: - old.Drop() - default: - } - } - } -} - -func addToEncryptionQueue( - queue chan *QueueOutboundElement, - element *QueueOutboundElement, -) { - for { - select { - case queue <- element: - return - default: - select { - case old := <-queue: - // drop & release to potential consumer - old.Drop() - old.mutex.Unlock() - default: - } - } - } -} - -/* Reads packets from the TUN and inserts - * into nonce queue for peer - * - * Obs. Single instance per TUN device - */ -func (device *Device) RoutineReadFromTUN() { - - elem := device.NewOutboundElement() - - logDebug := device.log.Debug - logError := device.log.Error - - logDebug.Println("Routine, TUN Reader started") - - for { - - // read packet - - offset := MessageTransportHeaderSize - size, err := device.tun.device.Read(elem.buffer[:], offset) - - if err != nil { - logError.Println("Failed to read packet from TUN device:", err) - device.Close() - return - } - - if size == 0 || size > MaxContentSize { - continue - } - - elem.packet = elem.buffer[offset : offset+size] - - // lookup peer - - var peer *Peer - switch elem.packet[0] >> 4 { - case ipv4.Version: - if len(elem.packet) < ipv4.HeaderLen { - continue - } - dst := elem.packet[IPv4offsetDst : IPv4offsetDst+net.IPv4len] - peer = device.routing.table.LookupIPv4(dst) - - case ipv6.Version: - if len(elem.packet) < ipv6.HeaderLen { - continue - } - dst := elem.packet[IPv6offsetDst : IPv6offsetDst+net.IPv6len] - peer = device.routing.table.LookupIPv6(dst) - - default: - logDebug.Println("Received packet with unknown IP version") - } - - if peer == nil { - continue - } - - // insert into nonce/pre-handshake queue - - if peer.isRunning.Get() { - peer.timer.handshakeDeadline.Reset(RekeyAttemptTime) - addToOutboundQueue(peer.queue.nonce, elem) - elem = device.NewOutboundElement() - } - } -} - -/* Queues packets when there is no handshake. - * Then assigns nonces to packets sequentially - * and creates "work" structs for workers - * - * Obs. A single instance per peer - */ -func (peer *Peer) RoutineNonce() { - var keyPair *KeyPair - - defer peer.routines.stopping.Done() - - device := peer.device - logDebug := device.log.Debug - logDebug.Println("Routine, nonce worker, started for peer", peer.String()) - - peer.routines.starting.Done() - - for { - NextPacket: - select { - case <-peer.routines.stop.Wait(): - return - - case elem := <-peer.queue.nonce: - - // wait for key pair - - for { - keyPair = peer.keyPairs.Current() - if keyPair != nil && keyPair.sendNonce < RejectAfterMessages { - if time.Now().Sub(keyPair.created) < RejectAfterTime { - break - } - } - - peer.signal.handshakeBegin.Send() - - logDebug.Println("Awaiting key-pair for", peer.String()) - - select { - case <-peer.signal.newKeyPair.Wait(): - case <-peer.signal.flushNonceQueue.Wait(): - logDebug.Println("Clearing queue for", peer.String()) - peer.FlushNonceQueue() - goto NextPacket - case <-peer.routines.stop.Wait(): - return - } - } - - // populate work element - - elem.peer = peer - elem.nonce = atomic.AddUint64(&keyPair.sendNonce, 1) - 1 - elem.keyPair = keyPair - elem.dropped = AtomicFalse - elem.mutex.Lock() - - // add to parallel and sequential queue - - addToEncryptionQueue(device.queue.encryption, elem) - addToOutboundQueue(peer.queue.outbound, elem) - } - } -} - -/* Encrypts the elements in the queue - * and marks them for sequential consumption (by releasing the mutex) - * - * Obs. One instance per core - */ -func (device *Device) RoutineEncryption() { - - var nonce [chacha20poly1305.NonceSize]byte - - logDebug := device.log.Debug - logDebug.Println("Routine, encryption worker, started") - - for { - - // fetch next element - - select { - case <-device.signal.stop.Wait(): - logDebug.Println("Routine, encryption worker, stopped") - return - - case elem := <-device.queue.encryption: - - // check if dropped - - if elem.IsDropped() { - continue - } - - // populate header fields - - header := elem.buffer[:MessageTransportHeaderSize] - - fieldType := header[0:4] - fieldReceiver := header[4:8] - fieldNonce := header[8:16] - - binary.LittleEndian.PutUint32(fieldType, MessageTransportType) - binary.LittleEndian.PutUint32(fieldReceiver, elem.keyPair.remoteIndex) - binary.LittleEndian.PutUint64(fieldNonce, elem.nonce) - - // pad content to multiple of 16 - - mtu := int(atomic.LoadInt32(&device.tun.mtu)) - rem := len(elem.packet) % PaddingMultiple - if rem > 0 { - for i := 0; i < PaddingMultiple-rem && len(elem.packet) < mtu; i++ { - elem.packet = append(elem.packet, 0) - } - } - - // encrypt content and release to consumer - - binary.LittleEndian.PutUint64(nonce[4:], elem.nonce) - elem.packet = elem.keyPair.send.Seal( - header, - nonce[:], - elem.packet, - nil, - ) - elem.mutex.Unlock() - } - } -} - -/* Sequentially reads packets from queue and sends to endpoint - * - * Obs. Single instance per peer. - * The routine terminates then the outbound queue is closed. - */ -func (peer *Peer) RoutineSequentialSender() { - - defer peer.routines.stopping.Done() - - device := peer.device - - logDebug := device.log.Debug - logDebug.Println("Routine, sequential sender, started for", peer.String()) - - peer.routines.starting.Done() - - for { - select { - - case <-peer.routines.stop.Wait(): - logDebug.Println( - "Routine, sequential sender, stopped for", peer.String()) - return - - case elem := <-peer.queue.outbound: - elem.mutex.Lock() - if elem.IsDropped() { - continue - } - - // send message and return buffer to pool - - length := uint64(len(elem.packet)) - err := peer.SendBuffer(elem.packet) - device.PutMessageBuffer(elem.buffer) - if err != nil { - logDebug.Println("Failed to send authenticated packet to peer", peer.String()) - continue - } - atomic.AddUint64(&peer.stats.txBytes, length) - - // update timers - - peer.TimerAnyAuthenticatedPacketTraversal() - if len(elem.packet) != MessageKeepaliveSize { - peer.TimerDataSent() - } - peer.KeepKeyFreshSending() - } - } -} diff --git a/src/signal.go b/src/signal.go deleted file mode 100644 index 2cefad4..0000000 --- a/src/signal.go +++ /dev/null @@ -1,53 +0,0 @@ -package main - -type Signal struct { - enabled AtomicBool - C chan struct{} -} - -func NewSignal() (s Signal) { - s.C = make(chan struct{}, 1) - s.Enable() - return -} - -func (s *Signal) Disable() { - s.enabled.Set(false) - s.Clear() -} - -func (s *Signal) Enable() { - s.enabled.Set(true) -} - -/* Unblock exactly one listener - */ -func (s *Signal) Send() { - if s.enabled.Get() { - select { - case s.C <- struct{}{}: - default: - } - } -} - -/* Clear the signal if already fired - */ -func (s Signal) Clear() { - select { - case <-s.C: - default: - } -} - -/* Unblocks all listeners (forever) - */ -func (s Signal) Broadcast() { - close(s.C) -} - -/* Wait for the signal - */ -func (s Signal) Wait() chan struct{} { - return s.C -} diff --git a/src/tai64.go b/src/tai64.go deleted file mode 100644 index 2299a37..0000000 --- a/src/tai64.go +++ /dev/null @@ -1,28 +0,0 @@ -package main - -import ( - "bytes" - "encoding/binary" - "time" -) - -const ( - TAI64NBase = uint64(4611686018427387914) - TAI64NSize = 12 -) - -type TAI64N [TAI64NSize]byte - -func Timestamp() TAI64N { - var tai64n TAI64N - now := time.Now() - secs := TAI64NBase + uint64(now.Unix()) - nano := uint32(now.UnixNano()) - binary.BigEndian.PutUint64(tai64n[:], secs) - binary.BigEndian.PutUint32(tai64n[8:], nano) - return tai64n -} - -func (t1 *TAI64N) After(t2 TAI64N) bool { - return bytes.Compare(t1[:], t2[:]) > 0 -} diff --git a/src/tests/netns.sh b/src/tests/netns.sh deleted file mode 100755 index 02d428b..0000000 --- a/src/tests/netns.sh +++ /dev/null @@ -1,425 +0,0 @@ -#!/bin/bash - -# Copyright (C) 2015-2017 Jason A. Donenfeld . All Rights Reserved. - -# This script tests the below topology: -# -# ┌─────────────────────┐ ┌──────────────────────────────────┐ ┌─────────────────────┐ -# │ $ns1 namespace │ │ $ns0 namespace │ │ $ns2 namespace │ -# │ │ │ │ │ │ -# │┌────────┐ │ │ ┌────────┐ │ │ ┌────────┐│ -# ││ wg1 │───────────┼───┼────────────│ lo │────────────┼───┼───────────│ wg2 ││ -# │├────────┴──────────┐│ │ ┌───────┴────────┴────────┐ │ │┌──────────┴────────┤│ -# ││192.168.241.1/24 ││ │ │(ns1) (ns2) │ │ ││192.168.241.2/24 ││ -# ││fd00::1/24 ││ │ │127.0.0.1:1 127.0.0.1:2│ │ ││fd00::2/24 ││ -# │└───────────────────┘│ │ │[::]:1 [::]:2 │ │ │└───────────────────┘│ -# └─────────────────────┘ │ └─────────────────────────┘ │ └─────────────────────┘ -# └──────────────────────────────────┘ -# -# After the topology is prepared we run a series of TCP/UDP iperf3 tests between the -# wireguard peers in $ns1 and $ns2. Note that $ns0 is the endpoint for the wg1 -# interfaces in $ns1 and $ns2. See https://www.wireguard.com/netns/ for further -# details on how this is accomplished. - -# This code is ported to the WireGuard-Go directly from the kernel project. -# -# Please ensure that you have installed the newest version of the WireGuard -# tools from the WireGuard project and before running these tests as: -# -# ./netns.sh - -set -e - -exec 3>&1 -export WG_HIDE_KEYS=never -netns0="wg-test-$$-0" -netns1="wg-test-$$-1" -netns2="wg-test-$$-2" -program=$1 -export LOG_LEVEL="info" - -pretty() { echo -e "\x1b[32m\x1b[1m[+] ${1:+NS$1: }${2}\x1b[0m" >&3; } -pp() { pretty "" "$*"; "$@"; } -maybe_exec() { if [[ $BASHPID -eq $$ ]]; then "$@"; else exec "$@"; fi; } -n0() { pretty 0 "$*"; maybe_exec ip netns exec $netns0 "$@"; } -n1() { pretty 1 "$*"; maybe_exec ip netns exec $netns1 "$@"; } -n2() { pretty 2 "$*"; maybe_exec ip netns exec $netns2 "$@"; } -ip0() { pretty 0 "ip $*"; ip -n $netns0 "$@"; } -ip1() { pretty 1 "ip $*"; ip -n $netns1 "$@"; } -ip2() { pretty 2 "ip $*"; ip -n $netns2 "$@"; } -sleep() { read -t "$1" -N 0 || true; } -waitiperf() { pretty "${1//*-}" "wait for iperf:5201"; while [[ $(ss -N "$1" -tlp 'sport = 5201') != *iperf3* ]]; do sleep 0.1; done; } -waitncatudp() { pretty "${1//*-}" "wait for udp:1111"; while [[ $(ss -N "$1" -ulp 'sport = 1111') != *ncat* ]]; do sleep 0.1; done; } -waitiface() { pretty "${1//*-}" "wait for $2 to come up"; ip netns exec "$1" bash -c "while [[ \$(< \"/sys/class/net/$2/operstate\") != up ]]; do read -t .1 -N 0 || true; done;"; } - -cleanup() { - set +e - exec 2>/dev/null - printf "$orig_message_cost" > /proc/sys/net/core/message_cost - ip0 link del dev wg1 - ip1 link del dev wg1 - ip2 link del dev wg1 - local to_kill="$(ip netns pids $netns0) $(ip netns pids $netns1) $(ip netns pids $netns2)" - [[ -n $to_kill ]] && kill $to_kill - pp ip netns del $netns1 - pp ip netns del $netns2 - pp ip netns del $netns0 - exit -} - -orig_message_cost="$(< /proc/sys/net/core/message_cost)" -trap cleanup EXIT -printf 0 > /proc/sys/net/core/message_cost - -ip netns del $netns0 2>/dev/null || true -ip netns del $netns1 2>/dev/null || true -ip netns del $netns2 2>/dev/null || true -pp ip netns add $netns0 -pp ip netns add $netns1 -pp ip netns add $netns2 -ip0 link set up dev lo - -# ip0 link add dev wg1 type wireguard -n0 $program wg1 -ip0 link set wg1 netns $netns1 - -# ip0 link add dev wg1 type wireguard -n0 $program wg2 -ip0 link set wg2 netns $netns2 - -key1="$(pp wg genkey)" -key2="$(pp wg genkey)" -pub1="$(pp wg pubkey <<<"$key1")" -pub2="$(pp wg pubkey <<<"$key2")" -psk="$(pp wg genpsk)" -[[ -n $key1 && -n $key2 && -n $psk ]] - -configure_peers() { - - ip1 addr add 192.168.241.1/24 dev wg1 - ip1 addr add fd00::1/24 dev wg1 - - ip2 addr add 192.168.241.2/24 dev wg2 - ip2 addr add fd00::2/24 dev wg2 - - n0 wg set wg1 \ - private-key <(echo "$key1") \ - listen-port 10000 \ - peer "$pub2" \ - preshared-key <(echo "$psk") \ - allowed-ips 192.168.241.2/32,fd00::2/128 - n0 wg set wg2 \ - private-key <(echo "$key2") \ - listen-port 20000 \ - peer "$pub1" \ - preshared-key <(echo "$psk") \ - allowed-ips 192.168.241.1/32,fd00::1/128 - - n0 wg showconf wg1 - n0 wg showconf wg2 - - ip1 link set up dev wg1 - ip2 link set up dev wg2 - sleep 1 -} -configure_peers - -tests() { - # Ping over IPv4 - n2 ping -c 10 -f -W 1 192.168.241.1 - n1 ping -c 10 -f -W 1 192.168.241.2 - - # Ping over IPv6 - n2 ping6 -c 10 -f -W 1 fd00::1 - n1 ping6 -c 10 -f -W 1 fd00::2 - - # TCP over IPv4 - n2 iperf3 -s -1 -B 192.168.241.2 & - waitiperf $netns2 - n1 iperf3 -Z -n 1G -c 192.168.241.2 - - # TCP over IPv6 - n1 iperf3 -s -1 -B fd00::1 & - waitiperf $netns1 - n2 iperf3 -Z -n 1G -c fd00::1 - - # UDP over IPv4 - n1 iperf3 -s -1 -B 192.168.241.1 & - waitiperf $netns1 - n2 iperf3 -Z -n 1G -b 0 -u -c 192.168.241.1 - - # UDP over IPv6 - n2 iperf3 -s -1 -B fd00::2 & - waitiperf $netns2 - n1 iperf3 -Z -n 1G -b 0 -u -c fd00::2 -} - -[[ $(ip1 link show dev wg1) =~ mtu\ ([0-9]+) ]] && orig_mtu="${BASH_REMATCH[1]}" -big_mtu=$(( 34816 - 1500 + $orig_mtu )) - -# Test using IPv4 as outer transport -n0 wg set wg1 peer "$pub2" endpoint 127.0.0.1:20000 -n0 wg set wg2 peer "$pub1" endpoint 127.0.0.1:10000 - -# Before calling tests, we first make sure that the stats counters are working -n2 ping -c 10 -f -W 1 192.168.241.1 -{ read _; read _; read _; read rx_bytes _; read _; read tx_bytes _; } < <(ip2 -stats link show dev wg2) -ip2 -stats link show dev wg2 -n0 wg show -[[ $rx_bytes -ge 840 && $tx_bytes -ge 880 && $rx_bytes -lt 2500 && $rx_bytes -lt 2500 ]] -echo "counters working" -tests -ip1 link set wg1 mtu $big_mtu -ip2 link set wg2 mtu $big_mtu -tests - -ip1 link set wg1 mtu $orig_mtu -ip2 link set wg2 mtu $orig_mtu - -# Test using IPv6 as outer transport -n0 wg set wg1 peer "$pub2" endpoint [::1]:20000 -n0 wg set wg2 peer "$pub1" endpoint [::1]:10000 -tests -ip1 link set wg1 mtu $big_mtu -ip2 link set wg2 mtu $big_mtu -tests - -ip1 link set wg1 mtu $orig_mtu -ip2 link set wg2 mtu $orig_mtu - -# Test using IPv4 that roaming works -ip0 -4 addr del 127.0.0.1/8 dev lo -ip0 -4 addr add 127.212.121.99/8 dev lo -n0 wg set wg1 listen-port 9999 -n0 wg set wg1 peer "$pub2" endpoint 127.0.0.1:20000 -n1 ping6 -W 1 -c 1 fd00::2 -[[ $(n2 wg show wg2 endpoints) == "$pub1 127.212.121.99:9999" ]] - -# Test using IPv6 that roaming works -n1 wg set wg1 listen-port 9998 -n1 wg set wg1 peer "$pub2" endpoint [::1]:20000 -n1 ping -W 1 -c 1 192.168.241.2 -[[ $(n2 wg show wg2 endpoints) == "$pub1 [::1]:9998" ]] - -# Test that crypto-RP filter works -n1 wg set wg1 peer "$pub2" allowed-ips 192.168.241.0/24 -exec 4< <(n1 ncat -l -u -p 1111) -nmap_pid=$! -waitncatudp $netns1 -n2 ncat -u 192.168.241.1 1111 <<<"X" -read -r -N 1 -t 1 out <&4 && [[ $out == "X" ]] -kill $nmap_pid -more_specific_key="$(pp wg genkey | pp wg pubkey)" -n0 wg set wg1 peer "$more_specific_key" allowed-ips 192.168.241.2/32 -n0 wg set wg2 listen-port 9997 -exec 4< <(n1 ncat -l -u -p 1111) -nmap_pid=$! -waitncatudp $netns1 -n2 ncat -u 192.168.241.1 1111 <<<"X" -! read -r -N 1 -t 1 out <&4 -kill $nmap_pid -n0 wg set wg1 peer "$more_specific_key" remove -[[ $(n1 wg show wg1 endpoints) == "$pub2 [::1]:9997" ]] - -ip1 link del wg1 -ip2 link del wg2 - -# Test using NAT. We now change the topology to this: -# ┌────────────────────────────────────────┐ ┌────────────────────────────────────────────────┐ ┌────────────────────────────────────────┐ -# │ $ns1 namespace │ │ $ns0 namespace │ │ $ns2 namespace │ -# │ │ │ │ │ │ -# │ ┌─────┐ ┌─────┐ │ │ ┌──────┐ ┌──────┐ │ │ ┌─────┐ ┌─────┐ │ -# │ │ wg1 │─────────────│vethc│───────────┼────┼────│vethrc│ │vethrs│──────────────┼─────┼──│veths│────────────│ wg2 │ │ -# │ ├─────┴──────────┐ ├─────┴──────────┐│ │ ├──────┴─────────┐ ├──────┴────────────┐ │ │ ├─────┴──────────┐ ├─────┴──────────┐ │ -# │ │192.168.241.1/24│ │192.168.1.100/24││ │ │192.168.1.100/24│ │10.0.0.1/24 │ │ │ │10.0.0.100/24 │ │192.168.241.2/24│ │ -# │ │fd00::1/24 │ │ ││ │ │ │ │SNAT:192.168.1.0/24│ │ │ │ │ │fd00::2/24 │ │ -# │ └────────────────┘ └────────────────┘│ │ └────────────────┘ └───────────────────┘ │ │ └────────────────┘ └────────────────┘ │ -# └────────────────────────────────────────┘ └────────────────────────────────────────────────┘ └────────────────────────────────────────┘ - -# ip1 link add dev wg1 type wireguard -# ip2 link add dev wg1 type wireguard - -n1 $program wg1 -n2 $program wg2 - -configure_peers - -ip0 link add vethrc type veth peer name vethc -ip0 link add vethrs type veth peer name veths -ip0 link set vethc netns $netns1 -ip0 link set veths netns $netns2 -ip0 link set vethrc up -ip0 link set vethrs up -ip0 addr add 192.168.1.1/24 dev vethrc -ip0 addr add 10.0.0.1/24 dev vethrs -ip1 addr add 192.168.1.100/24 dev vethc -ip1 link set vethc up -ip1 route add default via 192.168.1.1 -ip2 addr add 10.0.0.100/24 dev veths -ip2 link set veths up -waitiface $netns0 vethrc -waitiface $netns0 vethrs -waitiface $netns1 vethc -waitiface $netns2 veths - -n0 bash -c 'printf 1 > /proc/sys/net/ipv4/ip_forward' -n0 bash -c 'printf 2 > /proc/sys/net/netfilter/nf_conntrack_udp_timeout' -n0 bash -c 'printf 2 > /proc/sys/net/netfilter/nf_conntrack_udp_timeout_stream' -n0 iptables -t nat -A POSTROUTING -s 192.168.1.0/24 -d 10.0.0.0/24 -j SNAT --to 10.0.0.1 - -n0 wg set wg1 peer "$pub2" endpoint 10.0.0.100:20000 persistent-keepalive 1 -n1 ping -W 1 -c 1 192.168.241.2 -n2 ping -W 1 -c 1 192.168.241.1 -[[ $(n2 wg show wg2 endpoints) == "$pub1 10.0.0.1:10000" ]] -# Demonstrate n2 can still send packets to n1, since persistent-keepalive will prevent connection tracking entry from expiring (to see entries: `n0 conntrack -L`). -pp sleep 3 -n2 ping -W 1 -c 1 192.168.241.1 - -n0 iptables -t nat -F -ip0 link del vethrc -ip0 link del vethrs -ip1 link del wg1 -ip2 link del wg2 - -# Test that saddr routing is sticky but not too sticky, changing to this topology: -# ┌────────────────────────────────────────┐ ┌────────────────────────────────────────┐ -# │ $ns1 namespace │ │ $ns2 namespace │ -# │ │ │ │ -# │ ┌─────┐ ┌─────┐ │ │ ┌─────┐ ┌─────┐ │ -# │ │ wg1 │─────────────│veth1│───────────┼────┼──│veth2│────────────│ wg2 │ │ -# │ ├─────┴──────────┐ ├─────┴──────────┐│ │ ├─────┴──────────┐ ├─────┴──────────┐ │ -# │ │192.168.241.1/24│ │10.0.0.1/24 ││ │ │10.0.0.2/24 │ │192.168.241.2/24│ │ -# │ │fd00::1/24 │ │fd00:aa::1/96 ││ │ │fd00:aa::2/96 │ │fd00::2/24 │ │ -# │ └────────────────┘ └────────────────┘│ │ └────────────────┘ └────────────────┘ │ -# └────────────────────────────────────────┘ └────────────────────────────────────────┘ - -# ip1 link add dev wg1 type wireguard -# ip2 link add dev wg1 type wireguard -n1 $program wg1 -n2 $program wg2 - -configure_peers - -ip1 link add veth1 type veth peer name veth2 -ip1 link set veth2 netns $netns2 -n1 bash -c 'printf 0 > /proc/sys/net/ipv6/conf/veth1/accept_dad' -n2 bash -c 'printf 0 > /proc/sys/net/ipv6/conf/veth2/accept_dad' -n1 bash -c 'printf 1 > /proc/sys/net/ipv4/conf/veth1/promote_secondaries' - -# First we check that we aren't overly sticky and can fall over to new IPs when old ones are removed -ip1 addr add 10.0.0.1/24 dev veth1 -ip1 addr add fd00:aa::1/96 dev veth1 -ip2 addr add 10.0.0.2/24 dev veth2 -ip2 addr add fd00:aa::2/96 dev veth2 -ip1 link set veth1 up -ip2 link set veth2 up -waitiface $netns1 veth1 -waitiface $netns2 veth2 -n0 wg set wg1 peer "$pub2" endpoint 10.0.0.2:20000 -n1 ping -W 1 -c 1 192.168.241.2 -ip1 addr add 10.0.0.10/24 dev veth1 -ip1 addr del 10.0.0.1/24 dev veth1 -n1 ping -W 1 -c 1 192.168.241.2 -n0 wg set wg1 peer "$pub2" endpoint [fd00:aa::2]:20000 -n1 ping -W 1 -c 1 192.168.241.2 -ip1 addr add fd00:aa::10/96 dev veth1 -ip1 addr del fd00:aa::1/96 dev veth1 -n1 ping -W 1 -c 1 192.168.241.2 - -# Now we show that we can successfully do reply to sender routing -ip1 link set veth1 down -ip2 link set veth2 down -ip1 addr flush dev veth1 -ip2 addr flush dev veth2 -ip1 addr add 10.0.0.1/24 dev veth1 -ip1 addr add 10.0.0.2/24 dev veth1 -ip1 addr add fd00:aa::1/96 dev veth1 -ip1 addr add fd00:aa::2/96 dev veth1 -ip2 addr add 10.0.0.3/24 dev veth2 -ip2 addr add fd00:aa::3/96 dev veth2 -ip1 link set veth1 up -ip2 link set veth2 up -waitiface $netns1 veth1 -waitiface $netns2 veth2 -n0 wg set wg2 peer "$pub1" endpoint 10.0.0.1:10000 -n2 ping -W 1 -c 1 192.168.241.1 -[[ $(n0 wg show wg2 endpoints) == "$pub1 10.0.0.1:10000" ]] -n0 wg set wg2 peer "$pub1" endpoint [fd00:aa::1]:10000 -n2 ping -W 1 -c 1 192.168.241.1 -[[ $(n0 wg show wg2 endpoints) == "$pub1 [fd00:aa::1]:10000" ]] -n0 wg set wg2 peer "$pub1" endpoint 10.0.0.2:10000 -n2 ping -W 1 -c 1 192.168.241.1 -[[ $(n0 wg show wg2 endpoints) == "$pub1 10.0.0.2:10000" ]] -n0 wg set wg2 peer "$pub1" endpoint [fd00:aa::2]:10000 -n2 ping -W 1 -c 1 192.168.241.1 -[[ $(n0 wg show wg2 endpoints) == "$pub1 [fd00:aa::2]:10000" ]] - -ip1 link del veth1 -ip1 link del wg1 -ip2 link del wg2 - -# Test that Netlink/IPC is working properly by doing things that usually cause split responses - -n0 $program wg0 -sleep 5 -config=( "[Interface]" "PrivateKey=$(wg genkey)" "[Peer]" "PublicKey=$(wg genkey)" ) -for a in {1..255}; do - for b in {0..255}; do - config+=( "AllowedIPs=$a.$b.0.0/16,$a::$b/128" ) - done -done -n0 wg setconf wg0 <(printf '%s\n' "${config[@]}") -i=0 -for ip in $(n0 wg show wg0 allowed-ips); do - ((++i)) -done -((i == 255*256*2+1)) -ip0 link del wg0 - -n0 $program wg0 -config=( "[Interface]" "PrivateKey=$(wg genkey)" ) -for a in {1..40}; do - config+=( "[Peer]" "PublicKey=$(wg genkey)" ) - for b in {1..52}; do - config+=( "AllowedIPs=$a.$b.0.0/16" ) - done -done -n0 wg setconf wg0 <(printf '%s\n' "${config[@]}") -i=0 -while read -r line; do - j=0 - for ip in $line; do - ((++j)) - done - ((j == 53)) - ((++i)) -done < <(n0 wg show wg0 allowed-ips) -((i == 40)) -ip0 link del wg0 - -n0 $program wg0 -config=( ) -for i in {1..29}; do - config+=( "[Peer]" "PublicKey=$(wg genkey)" ) -done -config+=( "[Peer]" "PublicKey=$(wg genkey)" "AllowedIPs=255.2.3.4/32,abcd::255/128" ) -n0 wg setconf wg0 <(printf '%s\n' "${config[@]}") -n0 wg showconf wg0 > /dev/null -ip0 link del wg0 - -! n0 wg show doesnotexist || false - -declare -A objects -while read -t 0.1 -r line 2>/dev/null || [[ $? -ne 142 ]]; do - [[ $line =~ .*(wg[0-9]+:\ [A-Z][a-z]+\ [0-9]+)\ .*(created|destroyed).* ]] || continue - objects["${BASH_REMATCH[1]}"]+="${BASH_REMATCH[2]}" -done < /dev/kmsg -alldeleted=1 -for object in "${!objects[@]}"; do - if [[ ${objects["$object"]} != *createddestroyed ]]; then - echo "Error: $object: merely ${objects["$object"]}" >&3 - alldeleted=0 - fi -done -[[ $alldeleted -eq 1 ]] -pretty "" "Objects that were created were also destroyed." diff --git a/src/timer.go b/src/timer.go deleted file mode 100644 index f00ca49..0000000 --- a/src/timer.go +++ /dev/null @@ -1,59 +0,0 @@ -package main - -import ( - "time" -) - -type Timer struct { - pending AtomicBool - timer *time.Timer -} - -/* Starts the timer if not already pending - */ -func (t *Timer) Start(dur time.Duration) bool { - set := t.pending.Swap(true) - if !set { - t.timer.Reset(dur) - return true - } - return false -} - -/* Stops the timer - */ -func (t *Timer) Stop() { - set := t.pending.Swap(true) - if set { - t.timer.Stop() - select { - case <-t.timer.C: - default: - } - } - t.pending.Set(false) -} - -func (t *Timer) Pending() bool { - return t.pending.Get() -} - -func (t *Timer) Reset(dur time.Duration) { - t.pending.Set(false) - t.Start(dur) -} - -func (t *Timer) Wait() <-chan time.Time { - return t.timer.C -} - -func NewTimer() (t Timer) { - t.pending.Set(false) - t.timer = time.NewTimer(0) - t.timer.Stop() - select { - case <-t.timer.C: - default: - } - return -} diff --git a/src/timers.go b/src/timers.go deleted file mode 100644 index 7092688..0000000 --- a/src/timers.go +++ /dev/null @@ -1,346 +0,0 @@ -package main - -import ( - "bytes" - "encoding/binary" - "math/rand" - "sync/atomic" - "time" -) - -/* NOTE: - * Notion of validity - * - * - */ - -/* Called when a new authenticated message has been send - * - */ -func (peer *Peer) KeepKeyFreshSending() { - kp := peer.keyPairs.Current() - if kp == nil { - return - } - nonce := atomic.LoadUint64(&kp.sendNonce) - if nonce > RekeyAfterMessages { - peer.signal.handshakeBegin.Send() - } - if kp.isInitiator && time.Now().Sub(kp.created) > RekeyAfterTime { - peer.signal.handshakeBegin.Send() - } -} - -/* Called when a new authenticated message has been received - * - * NOTE: Not thread safe, but called by sequential receiver! - */ -func (peer *Peer) KeepKeyFreshReceiving() { - if peer.timer.sendLastMinuteHandshake { - return - } - kp := peer.keyPairs.Current() - if kp == nil { - return - } - if !kp.isInitiator { - return - } - nonce := atomic.LoadUint64(&kp.sendNonce) - send := nonce > RekeyAfterMessages || time.Now().Sub(kp.created) > RekeyAfterTimeReceiving - if send { - // do a last minute attempt at initiating a new handshake - peer.timer.sendLastMinuteHandshake = true - peer.signal.handshakeBegin.Send() - } -} - -/* Queues a keep-alive if no packets are queued for peer - */ -func (peer *Peer) SendKeepAlive() bool { - if len(peer.queue.nonce) != 0 { - return false - } - elem := peer.device.NewOutboundElement() - elem.packet = nil - select { - case peer.queue.nonce <- elem: - return true - default: - return false - } -} - -/* Event: - * Sent non-empty (authenticated) transport message - */ -func (peer *Peer) TimerDataSent() { - peer.timer.keepalivePassive.Stop() - peer.timer.handshakeNew.Start(NewHandshakeTime) -} - -/* Event: - * Received non-empty (authenticated) transport message - * - * Action: - * Set a timer to confirm the message using a keep-alive (if not already set) - */ -func (peer *Peer) TimerDataReceived() { - if !peer.timer.keepalivePassive.Start(KeepaliveTimeout) { - peer.timer.needAnotherKeepalive = true - } -} - -/* Event: - * Any (authenticated) packet received - */ -func (peer *Peer) TimerAnyAuthenticatedPacketReceived() { - peer.timer.handshakeNew.Stop() -} - -/* Event: - * Any authenticated packet send / received. - * - * Action: - * Push persistent keep-alive into the future - */ -func (peer *Peer) TimerAnyAuthenticatedPacketTraversal() { - interval := atomic.LoadUint64(&peer.persistentKeepaliveInterval) - if interval > 0 { - duration := time.Duration(interval) * time.Second - peer.timer.keepalivePersistent.Reset(duration) - } -} - -/* Called after successfully completing a handshake. - * i.e. after: - * - * - Valid handshake response - * - First transport message under the "next" key - */ -func (peer *Peer) TimerHandshakeComplete() { - peer.signal.handshakeCompleted.Send() - peer.device.log.Info.Println("Negotiated new handshake for", peer.String()) -} - -/* Event: - * An ephemeral key is generated - * - * i.e. after: - * - * CreateMessageInitiation - * CreateMessageResponse - * - * Action: - * Schedule the deletion of all key material - * upon failure to complete a handshake - */ -func (peer *Peer) TimerEphemeralKeyCreated() { - peer.timer.zeroAllKeys.Reset(RejectAfterTime * 3) -} - -/* Sends a new handshake initiation message to the peer (endpoint) - */ -func (peer *Peer) sendNewHandshake() error { - - // temporarily disable the handshake complete signal - - peer.signal.handshakeCompleted.Disable() - - // create initiation message - - msg, err := peer.device.CreateMessageInitiation(peer) - if err != nil { - return err - } - - // marshal handshake message - - var buff [MessageInitiationSize]byte - writer := bytes.NewBuffer(buff[:0]) - binary.Write(writer, binary.LittleEndian, msg) - packet := writer.Bytes() - peer.mac.AddMacs(packet) - - // send to endpoint - - peer.TimerAnyAuthenticatedPacketTraversal() - - err = peer.SendBuffer(packet) - if err == nil { - peer.signal.handshakeCompleted.Enable() - } - - // set timeout - - jitter := time.Millisecond * time.Duration(rand.Uint32()%334) - - peer.timer.keepalivePassive.Stop() - peer.timer.handshakeTimeout.Reset(RekeyTimeout + jitter) - - return err -} - -func (peer *Peer) RoutineTimerHandler() { - - defer peer.routines.stopping.Done() - - device := peer.device - - logInfo := device.log.Info - logDebug := device.log.Debug - logDebug.Println("Routine, timer handler, started for peer", peer.String()) - - // reset all timers - - peer.timer.keepalivePassive.Stop() - peer.timer.handshakeDeadline.Stop() - peer.timer.handshakeTimeout.Stop() - peer.timer.handshakeNew.Stop() - peer.timer.zeroAllKeys.Stop() - - interval := atomic.LoadUint64(&peer.persistentKeepaliveInterval) - if interval > 0 { - duration := time.Duration(interval) * time.Second - peer.timer.keepalivePersistent.Reset(duration) - } - - // signal synchronised setup complete - - peer.routines.starting.Done() - - // handle timer events - - for { - select { - - /* stopping */ - - case <-peer.routines.stop.Wait(): - return - - /* timers */ - - // keep-alive - - case <-peer.timer.keepalivePersistent.Wait(): - - interval := atomic.LoadUint64(&peer.persistentKeepaliveInterval) - if interval > 0 { - logDebug.Println("Sending keep-alive to", peer.String()) - peer.timer.keepalivePassive.Stop() - peer.SendKeepAlive() - } - - case <-peer.timer.keepalivePassive.Wait(): - - logDebug.Println("Sending keep-alive to", peer.String()) - - peer.SendKeepAlive() - - if peer.timer.needAnotherKeepalive { - peer.timer.needAnotherKeepalive = false - peer.timer.keepalivePassive.Reset(KeepaliveTimeout) - } - - // clear key material timer - - case <-peer.timer.zeroAllKeys.Wait(): - - logDebug.Println("Clearing all key material for", peer.String()) - - hs := &peer.handshake - hs.mutex.Lock() - - kp := &peer.keyPairs - kp.mutex.Lock() - - // remove key-pairs - - if kp.previous != nil { - device.DeleteKeyPair(kp.previous) - kp.previous = nil - } - if kp.current != nil { - device.DeleteKeyPair(kp.current) - kp.current = nil - } - if kp.next != nil { - device.DeleteKeyPair(kp.next) - kp.next = nil - } - kp.mutex.Unlock() - - // zero out handshake - - device.indices.Delete(hs.localIndex) - hs.Clear() - hs.mutex.Unlock() - - // handshake timers - - case <-peer.timer.handshakeNew.Wait(): - logInfo.Println("Retrying handshake with", peer.String()) - peer.signal.handshakeBegin.Send() - - case <-peer.timer.handshakeTimeout.Wait(): - - // clear source (in case this is causing problems) - - peer.mutex.Lock() - if peer.endpoint != nil { - peer.endpoint.ClearSrc() - } - peer.mutex.Unlock() - - // send new handshake - - err := peer.sendNewHandshake() - if err != nil { - logInfo.Println( - "Failed to send handshake to peer:", peer.String(), "(", err, ")") - } - - case <-peer.timer.handshakeDeadline.Wait(): - - // clear all queued packets and stop keep-alive - - logInfo.Println( - "Handshake negotiation timed out for:", peer.String()) - - peer.signal.flushNonceQueue.Send() - peer.timer.keepalivePersistent.Stop() - peer.signal.handshakeBegin.Enable() - - /* signals */ - - case <-peer.signal.handshakeBegin.Wait(): - - peer.signal.handshakeBegin.Disable() - - err := peer.sendNewHandshake() - if err != nil { - logInfo.Println( - "Failed to send handshake to peer:", peer.String(), "(", err, ")") - } - - peer.timer.handshakeDeadline.Reset(RekeyAttemptTime) - - case <-peer.signal.handshakeCompleted.Wait(): - - logInfo.Println( - "Handshake completed for:", peer.String()) - - atomic.StoreInt64( - &peer.stats.lastHandshakeNano, - time.Now().UnixNano(), - ) - - peer.timer.handshakeTimeout.Stop() - peer.timer.handshakeDeadline.Stop() - peer.signal.handshakeBegin.Enable() - - peer.timer.sendLastMinuteHandshake = false - } - } -} diff --git a/src/trie.go b/src/trie.go deleted file mode 100644 index 405ffc3..0000000 --- a/src/trie.go +++ /dev/null @@ -1,228 +0,0 @@ -package main - -import ( - "errors" - "net" -) - -/* Binary trie - * - * The net.IPs used here are not formatted the - * same way as those created by the "net" functions. - * Here the IPs are slices of either 4 or 16 byte (not always 16) - * - * Synchronization done separately - * See: routing.go - */ - -type Trie struct { - cidr uint - child [2]*Trie - bits []byte - peer *Peer - - // index of "branching" bit - - bit_at_byte uint - bit_at_shift uint -} - -/* Finds length of matching prefix - * - * TODO: Only use during insertion (xor + prefix mask for lookup) - * Check out - * prefix_matches(struct allowedips_node *node, const u8 *key, u8 bits) - * https://git.zx2c4.com/WireGuard/commit/?h=jd/precomputed-prefix-match - * - * Assumption: - * len(ip1) == len(ip2) - * len(ip1) mod 4 = 0 - */ -func commonBits(ip1 []byte, ip2 []byte) uint { - var i uint - size := uint(len(ip1)) - - for i = 0; i < size; i++ { - v := ip1[i] ^ ip2[i] - if v != 0 { - v >>= 1 - if v == 0 { - return i*8 + 7 - } - - v >>= 1 - if v == 0 { - return i*8 + 6 - } - - v >>= 1 - if v == 0 { - return i*8 + 5 - } - - v >>= 1 - if v == 0 { - return i*8 + 4 - } - - v >>= 1 - if v == 0 { - return i*8 + 3 - } - - v >>= 1 - if v == 0 { - return i*8 + 2 - } - - v >>= 1 - if v == 0 { - return i*8 + 1 - } - return i * 8 - } - } - return i * 8 -} - -func (node *Trie) RemovePeer(p *Peer) *Trie { - if node == nil { - return node - } - - // walk recursively - - node.child[0] = node.child[0].RemovePeer(p) - node.child[1] = node.child[1].RemovePeer(p) - - if node.peer != p { - return node - } - - // remove peer & merge - - node.peer = nil - if node.child[0] == nil { - return node.child[1] - } - return node.child[0] -} - -func (node *Trie) choose(ip net.IP) byte { - return (ip[node.bit_at_byte] >> node.bit_at_shift) & 1 -} - -func (node *Trie) Insert(ip net.IP, cidr uint, peer *Peer) *Trie { - - // at leaf - - if node == nil { - return &Trie{ - bits: ip, - peer: peer, - cidr: cidr, - bit_at_byte: cidr / 8, - bit_at_shift: 7 - (cidr % 8), - } - } - - // traverse deeper - - common := commonBits(node.bits, ip) - if node.cidr <= cidr && common >= node.cidr { - if node.cidr == cidr { - node.peer = peer - return node - } - bit := node.choose(ip) - node.child[bit] = node.child[bit].Insert(ip, cidr, peer) - return node - } - - // split node - - newNode := &Trie{ - bits: ip, - peer: peer, - cidr: cidr, - bit_at_byte: cidr / 8, - bit_at_shift: 7 - (cidr % 8), - } - - cidr = min(cidr, common) - - // check for shorter prefix - - if newNode.cidr == cidr { - bit := newNode.choose(node.bits) - newNode.child[bit] = node - return newNode - } - - // create new parent for node & newNode - - parent := &Trie{ - bits: ip, - peer: nil, - cidr: cidr, - bit_at_byte: cidr / 8, - bit_at_shift: 7 - (cidr % 8), - } - - bit := parent.choose(ip) - parent.child[bit] = newNode - parent.child[bit^1] = node - - return parent -} - -func (node *Trie) Lookup(ip net.IP) *Peer { - var found *Peer - size := uint(len(ip)) - for node != nil && commonBits(node.bits, ip) >= node.cidr { - if node.peer != nil { - found = node.peer - } - if node.bit_at_byte == size { - break - } - bit := node.choose(ip) - node = node.child[bit] - } - return found -} - -func (node *Trie) Count() uint { - if node == nil { - return 0 - } - l := node.child[0].Count() - r := node.child[1].Count() - return l + r -} - -func (node *Trie) AllowedIPs(p *Peer, results []net.IPNet) []net.IPNet { - if node == nil { - return results - } - if node.peer == p { - var mask net.IPNet - mask.Mask = net.CIDRMask(int(node.cidr), len(node.bits)*8) - if len(node.bits) == net.IPv4len { - mask.IP = net.IPv4( - node.bits[0], - node.bits[1], - node.bits[2], - node.bits[3], - ) - } else if len(node.bits) == net.IPv6len { - mask.IP = node.bits - } else { - panic(errors.New("bug: unexpected address length")) - } - results = append(results, mask) - } - results = node.child[0].AllowedIPs(p, results) - results = node.child[1].AllowedIPs(p, results) - return results -} diff --git a/src/trie_rand_test.go b/src/trie_rand_test.go deleted file mode 100644 index 840d269..0000000 --- a/src/trie_rand_test.go +++ /dev/null @@ -1,126 +0,0 @@ -package main - -import ( - "math/rand" - "sort" - "testing" -) - -const ( - NumberOfPeers = 100 - NumberOfAddresses = 250 - NumberOfTests = 10000 -) - -type SlowNode struct { - peer *Peer - cidr uint - bits []byte -} - -type SlowRouter []*SlowNode - -func (r SlowRouter) Len() int { - return len(r) -} - -func (r SlowRouter) Less(i, j int) bool { - return r[i].cidr > r[j].cidr -} - -func (r SlowRouter) Swap(i, j int) { - r[i], r[j] = r[j], r[i] -} - -func (r SlowRouter) Insert(addr []byte, cidr uint, peer *Peer) SlowRouter { - for _, t := range r { - if t.cidr == cidr && commonBits(t.bits, addr) >= cidr { - t.peer = peer - t.bits = addr - return r - } - } - r = append(r, &SlowNode{ - cidr: cidr, - bits: addr, - peer: peer, - }) - sort.Sort(r) - return r -} - -func (r SlowRouter) Lookup(addr []byte) *Peer { - for _, t := range r { - common := commonBits(t.bits, addr) - if common >= t.cidr { - return t.peer - } - } - return nil -} - -func TestTrieRandomIPv4(t *testing.T) { - var trie *Trie - var slow SlowRouter - var peers []*Peer - - rand.Seed(1) - - const AddressLength = 4 - - for n := 0; n < NumberOfPeers; n += 1 { - peers = append(peers, &Peer{}) - } - - for n := 0; n < NumberOfAddresses; n += 1 { - var addr [AddressLength]byte - rand.Read(addr[:]) - cidr := uint(rand.Uint32() % (AddressLength * 8)) - index := rand.Int() % NumberOfPeers - trie = trie.Insert(addr[:], cidr, peers[index]) - slow = slow.Insert(addr[:], cidr, peers[index]) - } - - for n := 0; n < NumberOfTests; n += 1 { - var addr [AddressLength]byte - rand.Read(addr[:]) - peer1 := slow.Lookup(addr[:]) - peer2 := trie.Lookup(addr[:]) - if peer1 != peer2 { - t.Error("Trie did not match naive implementation, for:", addr) - } - } -} - -func TestTrieRandomIPv6(t *testing.T) { - var trie *Trie - var slow SlowRouter - var peers []*Peer - - rand.Seed(1) - - const AddressLength = 16 - - for n := 0; n < NumberOfPeers; n += 1 { - peers = append(peers, &Peer{}) - } - - for n := 0; n < NumberOfAddresses; n += 1 { - var addr [AddressLength]byte - rand.Read(addr[:]) - cidr := uint(rand.Uint32() % (AddressLength * 8)) - index := rand.Int() % NumberOfPeers - trie = trie.Insert(addr[:], cidr, peers[index]) - slow = slow.Insert(addr[:], cidr, peers[index]) - } - - for n := 0; n < NumberOfTests; n += 1 { - var addr [AddressLength]byte - rand.Read(addr[:]) - peer1 := slow.Lookup(addr[:]) - peer2 := trie.Lookup(addr[:]) - if peer1 != peer2 { - t.Error("Trie did not match naive implementation, for:", addr) - } - } -} diff --git a/src/trie_test.go b/src/trie_test.go deleted file mode 100644 index 9d53df3..0000000 --- a/src/trie_test.go +++ /dev/null @@ -1,255 +0,0 @@ -package main - -import ( - "math/rand" - "net" - "testing" -) - -/* Todo: More comprehensive - */ - -type testPairCommonBits struct { - s1 []byte - s2 []byte - match uint -} - -type testPairTrieInsert struct { - key []byte - cidr uint - peer *Peer -} - -type testPairTrieLookup struct { - key []byte - peer *Peer -} - -func printTrie(t *testing.T, p *Trie) { - if p == nil { - return - } - t.Log(p) - printTrie(t, p.child[0]) - printTrie(t, p.child[1]) -} - -func TestCommonBits(t *testing.T) { - - tests := []testPairCommonBits{ - {s1: []byte{1, 4, 53, 128}, s2: []byte{0, 0, 0, 0}, match: 7}, - {s1: []byte{0, 4, 53, 128}, s2: []byte{0, 0, 0, 0}, match: 13}, - {s1: []byte{0, 4, 53, 253}, s2: []byte{0, 4, 53, 252}, match: 31}, - {s1: []byte{192, 168, 1, 1}, s2: []byte{192, 169, 1, 1}, match: 15}, - {s1: []byte{65, 168, 1, 1}, s2: []byte{192, 169, 1, 1}, match: 0}, - } - - for _, p := range tests { - v := commonBits(p.s1, p.s2) - if v != p.match { - t.Error( - "For slice", p.s1, p.s2, - "expected match", p.match, - ",but got", v, - ) - } - } -} - -func benchmarkTrie(peerNumber int, addressNumber int, addressLength int, b *testing.B) { - var trie *Trie - var peers []*Peer - - rand.Seed(1) - - const AddressLength = 4 - - for n := 0; n < peerNumber; n += 1 { - peers = append(peers, &Peer{}) - } - - for n := 0; n < addressNumber; n += 1 { - var addr [AddressLength]byte - rand.Read(addr[:]) - cidr := uint(rand.Uint32() % (AddressLength * 8)) - index := rand.Int() % peerNumber - trie = trie.Insert(addr[:], cidr, peers[index]) - } - - for n := 0; n < b.N; n += 1 { - var addr [AddressLength]byte - rand.Read(addr[:]) - trie.Lookup(addr[:]) - } -} - -func BenchmarkTrieIPv4Peers100Addresses1000(b *testing.B) { - benchmarkTrie(100, 1000, net.IPv4len, b) -} - -func BenchmarkTrieIPv4Peers10Addresses10(b *testing.B) { - benchmarkTrie(10, 10, net.IPv4len, b) -} - -func BenchmarkTrieIPv6Peers100Addresses1000(b *testing.B) { - benchmarkTrie(100, 1000, net.IPv6len, b) -} - -func BenchmarkTrieIPv6Peers10Addresses10(b *testing.B) { - benchmarkTrie(10, 10, net.IPv6len, b) -} - -/* Test ported from kernel implementation: - * selftest/routingtable.h - */ -func TestTrieIPv4(t *testing.T) { - a := &Peer{} - b := &Peer{} - c := &Peer{} - d := &Peer{} - e := &Peer{} - g := &Peer{} - h := &Peer{} - - var trie *Trie - - insert := func(peer *Peer, a, b, c, d byte, cidr uint) { - trie = trie.Insert([]byte{a, b, c, d}, cidr, peer) - } - - assertEQ := func(peer *Peer, a, b, c, d byte) { - p := trie.Lookup([]byte{a, b, c, d}) - if p != peer { - t.Error("Assert EQ failed") - } - } - - assertNEQ := func(peer *Peer, a, b, c, d byte) { - p := trie.Lookup([]byte{a, b, c, d}) - if p == peer { - t.Error("Assert NEQ failed") - } - } - - insert(a, 192, 168, 4, 0, 24) - insert(b, 192, 168, 4, 4, 32) - insert(c, 192, 168, 0, 0, 16) - insert(d, 192, 95, 5, 64, 27) - insert(c, 192, 95, 5, 65, 27) - insert(e, 0, 0, 0, 0, 0) - insert(g, 64, 15, 112, 0, 20) - insert(h, 64, 15, 123, 211, 25) - insert(a, 10, 0, 0, 0, 25) - insert(b, 10, 0, 0, 128, 25) - insert(a, 10, 1, 0, 0, 30) - insert(b, 10, 1, 0, 4, 30) - insert(c, 10, 1, 0, 8, 29) - insert(d, 10, 1, 0, 16, 29) - - assertEQ(a, 192, 168, 4, 20) - assertEQ(a, 192, 168, 4, 0) - assertEQ(b, 192, 168, 4, 4) - assertEQ(c, 192, 168, 200, 182) - assertEQ(c, 192, 95, 5, 68) - assertEQ(e, 192, 95, 5, 96) - assertEQ(g, 64, 15, 116, 26) - assertEQ(g, 64, 15, 127, 3) - - insert(a, 1, 0, 0, 0, 32) - insert(a, 64, 0, 0, 0, 32) - insert(a, 128, 0, 0, 0, 32) - insert(a, 192, 0, 0, 0, 32) - insert(a, 255, 0, 0, 0, 32) - - assertEQ(a, 1, 0, 0, 0) - assertEQ(a, 64, 0, 0, 0) - assertEQ(a, 128, 0, 0, 0) - assertEQ(a, 192, 0, 0, 0) - assertEQ(a, 255, 0, 0, 0) - - trie = trie.RemovePeer(a) - - assertNEQ(a, 1, 0, 0, 0) - assertNEQ(a, 64, 0, 0, 0) - assertNEQ(a, 128, 0, 0, 0) - assertNEQ(a, 192, 0, 0, 0) - assertNEQ(a, 255, 0, 0, 0) - - trie = nil - - insert(a, 192, 168, 0, 0, 16) - insert(a, 192, 168, 0, 0, 24) - - trie = trie.RemovePeer(a) - - assertNEQ(a, 192, 168, 0, 1) -} - -/* Test ported from kernel implementation: - * selftest/routingtable.h - */ -func TestTrieIPv6(t *testing.T) { - a := &Peer{} - b := &Peer{} - c := &Peer{} - d := &Peer{} - e := &Peer{} - f := &Peer{} - g := &Peer{} - h := &Peer{} - - var trie *Trie - - expand := func(a uint32) []byte { - var out [4]byte - out[0] = byte(a >> 24 & 0xff) - out[1] = byte(a >> 16 & 0xff) - out[2] = byte(a >> 8 & 0xff) - out[3] = byte(a & 0xff) - return out[:] - } - - insert := func(peer *Peer, a, b, c, d uint32, cidr uint) { - var addr []byte - addr = append(addr, expand(a)...) - addr = append(addr, expand(b)...) - addr = append(addr, expand(c)...) - addr = append(addr, expand(d)...) - trie = trie.Insert(addr, cidr, peer) - } - - assertEQ := func(peer *Peer, a, b, c, d uint32) { - var addr []byte - addr = append(addr, expand(a)...) - addr = append(addr, expand(b)...) - addr = append(addr, expand(c)...) - addr = append(addr, expand(d)...) - p := trie.Lookup(addr) - if p != peer { - t.Error("Assert EQ failed") - } - } - - insert(d, 0x26075300, 0x60006b00, 0, 0xc05f0543, 128) - insert(c, 0x26075300, 0x60006b00, 0, 0, 64) - insert(e, 0, 0, 0, 0, 0) - insert(f, 0, 0, 0, 0, 0) - insert(g, 0x24046800, 0, 0, 0, 32) - insert(h, 0x24046800, 0x40040800, 0xdeadbeef, 0xdeadbeef, 64) - insert(a, 0x24046800, 0x40040800, 0xdeadbeef, 0xdeadbeef, 128) - insert(c, 0x24446800, 0x40e40800, 0xdeaebeef, 0xdefbeef, 128) - insert(b, 0x24446800, 0xf0e40800, 0xeeaebeef, 0, 98) - - assertEQ(d, 0x26075300, 0x60006b00, 0, 0xc05f0543) - assertEQ(c, 0x26075300, 0x60006b00, 0, 0xc02e01ee) - assertEQ(f, 0x26075300, 0x60006b01, 0, 0) - assertEQ(g, 0x24046800, 0x40040806, 0, 0x1006) - assertEQ(g, 0x24046800, 0x40040806, 0x1234, 0x5678) - assertEQ(f, 0x240467ff, 0x40040806, 0x1234, 0x5678) - assertEQ(f, 0x24046801, 0x40040806, 0x1234, 0x5678) - assertEQ(h, 0x24046800, 0x40040800, 0x1234, 0x5678) - assertEQ(h, 0x24046800, 0x40040800, 0, 0) - assertEQ(h, 0x24046800, 0x40040800, 0x10101010, 0x10101010) - assertEQ(a, 0x24046800, 0x40040800, 0xdeadbeef, 0xdeadbeef) -} diff --git a/src/tun.go b/src/tun.go deleted file mode 100644 index 6259f33..0000000 --- a/src/tun.go +++ /dev/null @@ -1,58 +0,0 @@ -package main - -import ( - "os" - "sync/atomic" -) - -const DefaultMTU = 1420 - -type TUNEvent int - -const ( - TUNEventUp = 1 << iota - TUNEventDown - TUNEventMTUUpdate -) - -type TUNDevice interface { - File() *os.File // returns the file descriptor of the device - Read([]byte, int) (int, error) // read a packet from the device (without any additional headers) - Write([]byte, int) (int, error) // writes a packet to the device (without any additional headers) - MTU() (int, error) // returns the MTU of the device - Name() string // returns the current name - Events() chan TUNEvent // returns a constant channel of events related to the device - Close() error // stops the device and closes the event channel -} - -func (device *Device) RoutineTUNEventReader() { - logInfo := device.log.Info - logError := device.log.Error - - for event := range device.tun.device.Events() { - if event&TUNEventMTUUpdate != 0 { - mtu, err := device.tun.device.MTU() - old := atomic.LoadInt32(&device.tun.mtu) - if err != nil { - logError.Println("Failed to load updated MTU of device:", err) - } else if int(old) != mtu { - if mtu+MessageTransportSize > MaxMessageSize { - logInfo.Println("MTU updated:", mtu, "(too large)") - } else { - logInfo.Println("MTU updated:", mtu) - } - atomic.StoreInt32(&device.tun.mtu, int32(mtu)) - } - } - - if event&TUNEventUp != 0 && !device.isUp.Get() { - logInfo.Println("Interface set up") - device.Up() - } - - if event&TUNEventDown != 0 && device.isUp.Get() { - logInfo.Println("Interface set down") - device.Down() - } - } -} diff --git a/src/tun_darwin.go b/src/tun_darwin.go deleted file mode 100644 index 87f6af6..0000000 --- a/src/tun_darwin.go +++ /dev/null @@ -1,323 +0,0 @@ -/* Copyright (c) 2016, Song Gao - * All rights reserved. - * - * Code from https://github.com/songgao/water - */ - -package main - -import ( - "encoding/binary" - "errors" - "fmt" - "golang.org/x/net/ipv4" - "golang.org/x/net/ipv6" - "golang.org/x/sys/unix" - "io" - "net" - "os" - "sync" - "time" - "unsafe" -) - -const utunControlName = "com.apple.net.utun_control" - -// _CTLIOCGINFO value derived from /usr/include/sys/{kern_control,ioccom}.h -const _CTLIOCGINFO = (0x40000000 | 0x80000000) | ((100 & 0x1fff) << 16) | uint32(byte('N'))<<8 | 3 - -// sockaddr_ctl specifeid in /usr/include/sys/kern_control.h -type sockaddrCtl struct { - scLen uint8 - scFamily uint8 - ssSysaddr uint16 - scID uint32 - scUnit uint32 - scReserved [5]uint32 -} - -// NativeTUN is a hack to work around the first 4 bytes "packet -// information" because there doesn't seem to be an IFF_NO_PI for darwin. -type NativeTUN struct { - name string - f io.ReadWriteCloser - mtu int - - rMu sync.Mutex - rBuf []byte - - wMu sync.Mutex - wBuf []byte - - events chan TUNEvent - errors chan error -} - -var sockaddrCtlSize uintptr = 32 - -func CreateTUN(name string) (ifce TUNDevice, err error) { - ifIndex := -1 - fmt.Sscanf(name, "utun%d", &ifIndex) - if ifIndex < 0 { - return nil, fmt.Errorf("error parsing interface name %s, must be utun[0-9]+", name) - } - - fd, err := unix.Socket(unix.AF_SYSTEM, unix.SOCK_DGRAM, 2) - - if err != nil { - return nil, fmt.Errorf("error in unix.Socket: %v", err) - } - - var ctlInfo = &struct { - ctlID uint32 - ctlName [96]byte - }{} - - copy(ctlInfo.ctlName[:], []byte(utunControlName)) - - _, _, errno := unix.Syscall( - unix.SYS_IOCTL, - uintptr(fd), - uintptr(_CTLIOCGINFO), - uintptr(unsafe.Pointer(ctlInfo)), - ) - - if errno != 0 { - err = errno - return nil, fmt.Errorf("error in unix.Syscall(unix.SYS_IOTL, ...): %v", err) - } - - sc := sockaddrCtl{ - scLen: uint8(sockaddrCtlSize), - scFamily: unix.AF_SYSTEM, - ssSysaddr: 2, - scID: ctlInfo.ctlID, - scUnit: uint32(ifIndex) + 1, - } - - scPointer := unsafe.Pointer(&sc) - - _, _, errno = unix.RawSyscall( - unix.SYS_CONNECT, - uintptr(fd), - uintptr(scPointer), - uintptr(sockaddrCtlSize), - ) - - if errno != 0 { - err = errno - return nil, fmt.Errorf("error in unix.RawSyscall(unix.SYS_CONNECT, ...): %v", err) - } - - // read (new) name of interface - - var ifName struct { - name [16]byte - } - ifNameSize := uintptr(16) - - _, _, errno = unix.Syscall6( - unix.SYS_GETSOCKOPT, - uintptr(fd), - 2, /* #define SYSPROTO_CONTROL 2 */ - 2, /* #define UTUN_OPT_IFNAME 2 */ - uintptr(unsafe.Pointer(&ifName)), - uintptr(unsafe.Pointer(&ifNameSize)), 0) - - if errno != 0 { - err = errno - return nil, fmt.Errorf("error in unix.Syscall6(unix.SYS_GETSOCKOPT, ...): %v", err) - } - - device := &NativeTUN{ - name: string(ifName.name[:ifNameSize-1 /* -1 is for \0 */]), - f: os.NewFile(uintptr(fd), string(ifName.name[:])), - mtu: 1500, - events: make(chan TUNEvent, 10), - errors: make(chan error, 1), - } - - // start listener - - go func(native *NativeTUN) { - // TODO: Fix this very niave implementation - var ( - statusUp bool - statusMTU int - ) - - for ; ; time.Sleep(time.Second) { - intr, err := net.InterfaceByName(device.name) - if err != nil { - native.errors <- err - return - } - - // Up / Down event - up := (intr.Flags & net.FlagUp) != 0 - if up != statusUp && up { - native.events <- TUNEventUp - } - if up != statusUp && !up { - native.events <- TUNEventDown - } - statusUp = up - - // MTU changes - if intr.MTU != statusMTU { - native.events <- TUNEventMTUUpdate - } - statusMTU = intr.MTU - } - }(device) - - // set default MTU - - err = device.setMTU(DefaultMTU) - - return device, err -} - -var _ io.ReadWriteCloser = (*NativeTUN)(nil) - -func (t *NativeTUN) Events() chan TUNEvent { - return t.events -} - -func (t *NativeTUN) Read(to []byte) (int, error) { - t.rMu.Lock() - defer t.rMu.Unlock() - - if cap(t.rBuf) < len(to)+4 { - t.rBuf = make([]byte, len(to)+4) - } - t.rBuf = t.rBuf[:len(to)+4] - - n, err := t.f.Read(t.rBuf) - copy(to, t.rBuf[4:]) - return n - 4, err -} - -func (t *NativeTUN) Write(from []byte) (int, error) { - - if len(from) == 0 { - return 0, unix.EIO - } - - t.wMu.Lock() - defer t.wMu.Unlock() - - if cap(t.wBuf) < len(from)+4 { - t.wBuf = make([]byte, len(from)+4) - } - t.wBuf = t.wBuf[:len(from)+4] - - // determine the IP Family for the NULL L2 Header - - ipVer := from[0] >> 4 - if ipVer == ipv4.Version { - t.wBuf[3] = unix.AF_INET - } else if ipVer == ipv6.Version { - t.wBuf[3] = unix.AF_INET6 - } else { - return 0, errors.New("Unable to determine IP version from packet.") - } - - copy(t.wBuf[4:], from) - - n, err := t.f.Write(t.wBuf) - return n - 4, err -} - -func (t *NativeTUN) Close() error { - - // lock to make sure no read/write is in process. - - t.rMu.Lock() - defer t.rMu.Unlock() - - t.wMu.Lock() - defer t.wMu.Unlock() - - return t.f.Close() -} - -func (t *NativeTUN) Name() string { - return t.name -} - -func (t *NativeTUN) setMTU(n int) error { - - // open datagram socket - - var fd int - - fd, err := unix.Socket( - unix.AF_INET, - unix.SOCK_DGRAM, - 0, - ) - - if err != nil { - return err - } - - defer unix.Close(fd) - - // do ioctl call - - var ifr [32]byte - copy(ifr[:], t.name) - binary.LittleEndian.PutUint32(ifr[16:20], uint32(n)) - _, _, errno := unix.Syscall( - unix.SYS_IOCTL, - uintptr(fd), - uintptr(unix.SIOCSIFMTU), - uintptr(unsafe.Pointer(&ifr[0])), - ) - - if errno != 0 { - return fmt.Errorf("Failed to set MTU on %s", t.name) - } - - return nil -} - -func (t *NativeTUN) MTU() (int, error) { - - // open datagram socket - - fd, err := unix.Socket( - unix.AF_INET, - unix.SOCK_DGRAM, - 0, - ) - - if err != nil { - return 0, err - } - - defer unix.Close(fd) - - // do ioctl call - - var ifr [64]byte - copy(ifr[:], t.name) - _, _, errno := unix.Syscall( - unix.SYS_IOCTL, - uintptr(fd), - uintptr(unix.SIOCGIFMTU), - uintptr(unsafe.Pointer(&ifr[0])), - ) - if errno != 0 { - return 0, fmt.Errorf("Failed to get MTU on %s", t.name) - } - - // convert result to signed 32-bit int - - val := binary.LittleEndian.Uint32(ifr[16:20]) - if val >= (1 << 31) { - return int(val-(1<<31)) - (1 << 31), nil - } - return int(val), nil -} diff --git a/src/tun_linux.go b/src/tun_linux.go deleted file mode 100644 index 9756169..0000000 --- a/src/tun_linux.go +++ /dev/null @@ -1,377 +0,0 @@ -package main - -/* Implementation of the TUN device interface for linux - */ - -import ( - "encoding/binary" - "errors" - "fmt" - "golang.org/x/net/ipv6" - "golang.org/x/sys/unix" - "net" - "os" - "strings" - "time" - "unsafe" -) - -// #include -// #include -// #include -// #include -// #include -// #include -// -// /* Creates a netlink socket -// * listening to the RTMGRP_LINK multicast group -// */ -// -// int bind_rtmgrp() { -// int nl_sock = socket(AF_NETLINK, SOCK_RAW, NETLINK_ROUTE); -// if (nl_sock < 0) -// return -1; -// -// struct sockaddr_nl addr; -// memset ((void *) &addr, 0, sizeof (addr)); -// addr.nl_family = AF_NETLINK; -// addr.nl_pid = getpid (); -// addr.nl_groups = RTMGRP_LINK | RTMGRP_IPV4_IFADDR | RTMGRP_IPV6_IFADDR; -// -// if (bind(nl_sock, (struct sockaddr *) &addr, sizeof (addr)) < 0) -// return -1; -// -// return nl_sock; -// } -import "C" - -const ( - CloneDevicePath = "/dev/net/tun" - IFReqSize = unix.IFNAMSIZ + 64 -) - -type NativeTun struct { - fd *os.File - index int32 // if index - name string // name of interface - errors chan error // async error handling - events chan TUNEvent // device related events -} - -func (tun *NativeTun) File() *os.File { - return tun.fd -} - -func (tun *NativeTun) RoutineHackListener() { - /* This is needed for the detection to work across network namespaces - * If you are reading this and know a better method, please get in touch. - */ - fd := int(tun.fd.Fd()) - for { - _, err := unix.Write(fd, nil) - switch err { - case unix.EINVAL: - tun.events <- TUNEventUp - case unix.EIO: - tun.events <- TUNEventDown - default: - } - time.Sleep(time.Second / 10) - } -} - -func (tun *NativeTun) RoutineNetlinkListener() { - - sock := int(C.bind_rtmgrp()) - if sock < 0 { - tun.errors <- errors.New("Failed to create netlink event listener") - return - } - - for msg := make([]byte, 1<<16); ; { - - msgn, _, _, _, err := unix.Recvmsg(sock, msg[:], nil, 0) - if err != nil { - tun.errors <- fmt.Errorf("Failed to receive netlink message: %s", err.Error()) - return - } - - for remain := msg[:msgn]; len(remain) >= unix.SizeofNlMsghdr; { - - hdr := *(*unix.NlMsghdr)(unsafe.Pointer(&remain[0])) - - if int(hdr.Len) > len(remain) { - break - } - - switch hdr.Type { - case unix.NLMSG_DONE: - remain = []byte{} - - case unix.RTM_NEWLINK: - info := *(*unix.IfInfomsg)(unsafe.Pointer(&remain[unix.SizeofNlMsghdr])) - remain = remain[hdr.Len:] - - if info.Index != tun.index { - // not our interface - continue - } - - if info.Flags&unix.IFF_RUNNING != 0 { - tun.events <- TUNEventUp - } - - if info.Flags&unix.IFF_RUNNING == 0 { - tun.events <- TUNEventDown - } - - tun.events <- TUNEventMTUUpdate - - default: - remain = remain[hdr.Len:] - } - } - } -} - -func (tun *NativeTun) isUp() (bool, error) { - inter, err := net.InterfaceByName(tun.name) - return inter.Flags&net.FlagUp != 0, err -} - -func (tun *NativeTun) Name() string { - return tun.name -} - -func getDummySock() (int, error) { - return unix.Socket( - unix.AF_INET, - unix.SOCK_DGRAM, - 0, - ) -} - -func getIFIndex(name string) (int32, error) { - fd, err := getDummySock() - if err != nil { - return 0, err - } - - defer unix.Close(fd) - - var ifr [IFReqSize]byte - copy(ifr[:], name) - _, _, errno := unix.Syscall( - unix.SYS_IOCTL, - uintptr(fd), - uintptr(unix.SIOCGIFINDEX), - uintptr(unsafe.Pointer(&ifr[0])), - ) - - if errno != 0 { - return 0, errno - } - - index := binary.LittleEndian.Uint32(ifr[unix.IFNAMSIZ:]) - return toInt32(index), nil -} - -func (tun *NativeTun) setMTU(n int) error { - - // open datagram socket - - fd, err := unix.Socket( - unix.AF_INET, - unix.SOCK_DGRAM, - 0, - ) - - if err != nil { - return err - } - - defer unix.Close(fd) - - // do ioctl call - - var ifr [IFReqSize]byte - copy(ifr[:], tun.name) - binary.LittleEndian.PutUint32(ifr[16:20], uint32(n)) - _, _, errno := unix.Syscall( - unix.SYS_IOCTL, - uintptr(fd), - uintptr(unix.SIOCSIFMTU), - uintptr(unsafe.Pointer(&ifr[0])), - ) - - if errno != 0 { - return errors.New("Failed to set MTU of TUN device") - } - - return nil -} - -func (tun *NativeTun) MTU() (int, error) { - - // open datagram socket - - fd, err := unix.Socket( - unix.AF_INET, - unix.SOCK_DGRAM, - 0, - ) - - if err != nil { - return 0, err - } - - defer unix.Close(fd) - - // do ioctl call - - var ifr [IFReqSize]byte - copy(ifr[:], tun.name) - _, _, errno := unix.Syscall( - unix.SYS_IOCTL, - uintptr(fd), - uintptr(unix.SIOCGIFMTU), - uintptr(unsafe.Pointer(&ifr[0])), - ) - if errno != 0 { - return 0, errors.New("Failed to get MTU of TUN device") - } - - // convert result to signed 32-bit int - - val := binary.LittleEndian.Uint32(ifr[16:20]) - if val >= (1 << 31) { - return int(toInt32(val)), nil - } - return int(val), nil -} - -func (tun *NativeTun) Write(buff []byte, offset int) (int, error) { - - // reserve space for header - - buff = buff[offset-4:] - - // add packet information header - - buff[0] = 0x00 - buff[1] = 0x00 - - if buff[4] == ipv6.Version<<4 { - buff[2] = 0x86 - buff[3] = 0xdd - } else { - buff[2] = 0x08 - buff[3] = 0x00 - } - - // write - - return tun.fd.Write(buff) -} - -func (tun *NativeTun) Read(buff []byte, offset int) (int, error) { - select { - case err := <-tun.errors: - return 0, err - default: - buff := buff[offset-4:] - n, err := tun.fd.Read(buff[:]) - if n < 4 { - return 0, err - } - return n - 4, err - } -} - -func (tun *NativeTun) Events() chan TUNEvent { - return tun.events -} - -func (tun *NativeTun) Close() error { - return nil -} - -func CreateTUNFromFile(name string, fd *os.File) (TUNDevice, error) { - device := &NativeTun{ - fd: fd, - name: name, - events: make(chan TUNEvent, 5), - errors: make(chan error, 5), - } - - // start event listener - - var err error - device.index, err = getIFIndex(device.name) - if err != nil { - return nil, err - } - - go device.RoutineNetlinkListener() - // go device.RoutineHackListener() // cross namespace - - // set default MTU - - return device, device.setMTU(DefaultMTU) -} - -func CreateTUN(name string) (TUNDevice, error) { - - // open clone device - - fd, err := os.OpenFile(CloneDevicePath, os.O_RDWR, 0) - if err != nil { - return nil, err - } - - // create new device - - var ifr [IFReqSize]byte - var flags uint16 = unix.IFF_TUN // | unix.IFF_NO_PI - nameBytes := []byte(name) - if len(nameBytes) >= unix.IFNAMSIZ { - return nil, errors.New("Interface name too long") - } - copy(ifr[:], nameBytes) - binary.LittleEndian.PutUint16(ifr[16:], flags) - - _, _, errno := unix.Syscall( - unix.SYS_IOCTL, - uintptr(fd.Fd()), - uintptr(unix.TUNSETIFF), - uintptr(unsafe.Pointer(&ifr[0])), - ) - if errno != 0 { - return nil, errno - } - - // read (new) name of interface - - newName := string(ifr[:]) - newName = newName[:strings.Index(newName, "\000")] - device := &NativeTun{ - fd: fd, - name: newName, - events: make(chan TUNEvent, 5), - errors: make(chan error, 5), - } - - // start event listener - - device.index, err = getIFIndex(device.name) - if err != nil { - return nil, err - } - - go device.RoutineNetlinkListener() - // go device.RoutineHackListener() // cross namespace - - // set default MTU - - return device, device.setMTU(DefaultMTU) -} diff --git a/src/tun_windows.go b/src/tun_windows.go deleted file mode 100644 index 0711032..0000000 --- a/src/tun_windows.go +++ /dev/null @@ -1,475 +0,0 @@ -package main - -import ( - "encoding/binary" - "errors" - "fmt" - "golang.org/x/sys/windows" - "golang.org/x/sys/windows/registry" - "net" - "sync" - "syscall" - "time" - "unsafe" -) - -/* Relies on the OpenVPN TAP-Windows driver (NDIS 6 version) - * - * https://github.com/OpenVPN/tap-windows - */ - -type NativeTUN struct { - fd windows.Handle - rl sync.Mutex - wl sync.Mutex - ro *windows.Overlapped - wo *windows.Overlapped - events chan TUNEvent - name string -} - -const ( - METHOD_BUFFERED = 0 - ComponentID = "tap0901" // tap0801 -) - -func ctl_code(device_type, function, method, access uint32) uint32 { - return (device_type << 16) | (access << 14) | (function << 2) | method -} - -func TAP_CONTROL_CODE(request, method uint32) uint32 { - return ctl_code(file_device_unknown, request, method, 0) -} - -var ( - errIfceNameNotFound = errors.New("Failed to find the name of interface") - - TAP_IOCTL_GET_MAC = TAP_CONTROL_CODE(1, METHOD_BUFFERED) - TAP_IOCTL_GET_VERSION = TAP_CONTROL_CODE(2, METHOD_BUFFERED) - TAP_IOCTL_GET_MTU = TAP_CONTROL_CODE(3, METHOD_BUFFERED) - TAP_IOCTL_GET_INFO = TAP_CONTROL_CODE(4, METHOD_BUFFERED) - TAP_IOCTL_CONFIG_POINT_TO_POINT = TAP_CONTROL_CODE(5, METHOD_BUFFERED) - TAP_IOCTL_SET_MEDIA_STATUS = TAP_CONTROL_CODE(6, METHOD_BUFFERED) - TAP_IOCTL_CONFIG_DHCP_MASQ = TAP_CONTROL_CODE(7, METHOD_BUFFERED) - TAP_IOCTL_GET_LOG_LINE = TAP_CONTROL_CODE(8, METHOD_BUFFERED) - TAP_IOCTL_CONFIG_DHCP_SET_OPT = TAP_CONTROL_CODE(9, METHOD_BUFFERED) - TAP_IOCTL_CONFIG_TUN = TAP_CONTROL_CODE(10, METHOD_BUFFERED) - - file_device_unknown = uint32(0x00000022) - nCreateEvent, - nResetEvent, - nGetOverlappedResult uintptr -) - -func init() { - k32, err := windows.LoadLibrary("kernel32.dll") - if err != nil { - panic("LoadLibrary " + err.Error()) - } - defer windows.FreeLibrary(k32) - nCreateEvent = getProcAddr(k32, "CreateEventW") - nResetEvent = getProcAddr(k32, "ResetEvent") - nGetOverlappedResult = getProcAddr(k32, "GetOverlappedResult") -} - -/* implementation of the read/write/closer interface */ - -func getProcAddr(lib windows.Handle, name string) uintptr { - addr, err := windows.GetProcAddress(lib, name) - if err != nil { - panic(name + " " + err.Error()) - } - return addr -} - -func resetEvent(h windows.Handle) error { - r, _, err := syscall.Syscall(nResetEvent, 1, uintptr(h), 0, 0) - if r == 0 { - return err - } - return nil -} - -func getOverlappedResult(h windows.Handle, overlapped *windows.Overlapped) (int, error) { - var n int - r, _, err := syscall.Syscall6( - nGetOverlappedResult, - 4, - uintptr(h), - uintptr(unsafe.Pointer(overlapped)), - uintptr(unsafe.Pointer(&n)), 1, 0, 0) - - if r == 0 { - return n, err - } - return n, nil -} - -func newOverlapped() (*windows.Overlapped, error) { - var overlapped windows.Overlapped - r, _, err := syscall.Syscall6(nCreateEvent, 4, 0, 1, 0, 0, 0, 0) - if r == 0 { - return nil, err - } - overlapped.HEvent = windows.Handle(r) - return &overlapped, nil -} - -func (f *NativeTUN) Events() chan TUNEvent { - return f.events -} - -func (f *NativeTUN) Close() error { - return windows.Close(f.fd) -} - -func (f *NativeTUN) Write(b []byte) (int, error) { - f.wl.Lock() - defer f.wl.Unlock() - - if err := resetEvent(f.wo.HEvent); err != nil { - return 0, err - } - var n uint32 - err := windows.WriteFile(f.fd, b, &n, f.wo) - if err != nil && err != windows.ERROR_IO_PENDING { - return int(n), err - } - return getOverlappedResult(f.fd, f.wo) -} - -func (f *NativeTUN) Read(b []byte) (int, error) { - f.rl.Lock() - defer f.rl.Unlock() - - if err := resetEvent(f.ro.HEvent); err != nil { - return 0, err - } - var done uint32 - err := windows.ReadFile(f.fd, b, &done, f.ro) - if err != nil && err != windows.ERROR_IO_PENDING { - return int(done), err - } - return getOverlappedResult(f.fd, f.ro) -} - -func getdeviceid( - targetComponentId string, - targetDeviceName string, -) (deviceid string, err error) { - - getName := func(instanceId string) (string, error) { - path := fmt.Sprintf( - `SYSTEM\CurrentControlSet\Control\Network\{4D36E972-E325-11CE-BFC1-08002BE10318}\%s\Connection`, - instanceId, - ) - - key, err := registry.OpenKey( - registry.LOCAL_MACHINE, - path, - registry.READ, - ) - - if err != nil { - return "", err - } - defer key.Close() - - val, _, err := key.GetStringValue("Name") - key.Close() - return val, err - } - - getInstanceId := func(keyName string) (string, string, error) { - path := fmt.Sprintf( - `SYSTEM\CurrentControlSet\Control\Class\{4D36E972-E325-11CE-BFC1-08002BE10318}\%s`, - keyName, - ) - - key, err := registry.OpenKey( - registry.LOCAL_MACHINE, - path, - registry.READ, - ) - - if err != nil { - return "", "", err - } - defer key.Close() - - componentId, _, err := key.GetStringValue("ComponentId") - if err != nil { - return "", "", err - } - - instanceId, _, err := key.GetStringValue("NetCfgInstanceId") - - return componentId, instanceId, err - } - - // find list of all network devices - - k, err := registry.OpenKey( - registry.LOCAL_MACHINE, - `SYSTEM\CurrentControlSet\Control\Class\{4D36E972-E325-11CE-BFC1-08002BE10318}`, - registry.READ, - ) - - if err != nil { - return "", fmt.Errorf("Failed to open the adapter registry, TAP driver may be not installed, %v", err) - } - - defer k.Close() - - keys, err := k.ReadSubKeyNames(-1) - - if err != nil { - return "", err - } - - // look for matching component id and name - - var componentFound bool - - for _, v := range keys { - - componentId, instanceId, err := getInstanceId(v) - if err != nil || componentId != targetComponentId { - continue - } - - componentFound = true - - deviceName, err := getName(instanceId) - if err != nil || deviceName != targetDeviceName { - continue - } - - return instanceId, nil - } - - // provide a descriptive error message - - if componentFound { - return "", fmt.Errorf("Unable to find tun/tap device with name = %s", targetDeviceName) - } - - return "", fmt.Errorf( - "Unable to find device in registry with ComponentId = %s, is tap-windows installed?", - targetComponentId, - ) -} - -// setStatus is used to bring up or bring down the interface -func setStatus(fd windows.Handle, status bool) error { - var code [4]byte - if status { - binary.LittleEndian.PutUint32(code[:], 1) - } - - var bytesReturned uint32 - rdbbuf := make([]byte, windows.MAXIMUM_REPARSE_DATA_BUFFER_SIZE) - return windows.DeviceIoControl( - fd, - TAP_IOCTL_SET_MEDIA_STATUS, - &code[0], - uint32(4), - &rdbbuf[0], - uint32(len(rdbbuf)), - &bytesReturned, - nil, - ) -} - -/* When operating in TUN mode we must assign an ip address & subnet to the device. - * - */ -func setTUN(fd windows.Handle, network string) error { - var bytesReturned uint32 - rdbbuf := make([]byte, windows.MAXIMUM_REPARSE_DATA_BUFFER_SIZE) - localIP, remoteNet, err := net.ParseCIDR(network) - - if err != nil { - return fmt.Errorf("Failed to parse network CIDR in config, %v", err) - } - - if localIP.To4() == nil { - return fmt.Errorf("Provided network(%s) is not a valid IPv4 address", network) - } - - var param [12]byte - - copy(param[0:4], localIP.To4()) - copy(param[4:8], remoteNet.IP.To4()) - copy(param[8:12], remoteNet.Mask) - - return windows.DeviceIoControl( - fd, - TAP_IOCTL_CONFIG_TUN, - ¶m[0], - uint32(12), - &rdbbuf[0], - uint32(len(rdbbuf)), - &bytesReturned, - nil, - ) -} - -func (tun *NativeTUN) MTU() (int, error) { - var mtu [4]byte - var bytesReturned uint32 - err := windows.DeviceIoControl( - tun.fd, - TAP_IOCTL_GET_MTU, - &mtu[0], - uint32(len(mtu)), - &mtu[0], - uint32(len(mtu)), - &bytesReturned, - nil, - ) - val := binary.LittleEndian.Uint32(mtu[:]) - return int(val), err -} - -func (tun *NativeTUN) Name() string { - return tun.name -} - -func CreateTUN(name string) (TUNDevice, error) { - - // find the device in registry. - - deviceid, err := getdeviceid(ComponentID, name) - if err != nil { - return nil, err - } - path := "\\\\.\\Global\\" + deviceid + ".tap" - pathp, err := windows.UTF16PtrFromString(path) - if err != nil { - return nil, err - } - - // create TUN device - - handle, err := windows.CreateFile( - pathp, - windows.GENERIC_READ|windows.GENERIC_WRITE, - 0, - nil, - windows.OPEN_EXISTING, - windows.FILE_ATTRIBUTE_SYSTEM|windows.FILE_FLAG_OVERLAPPED, - 0, - ) - - if err != nil { - return nil, err - } - - ro, err := newOverlapped() - if err != nil { - windows.Close(handle) - return nil, err - } - - wo, err := newOverlapped() - if err != nil { - windows.Close(handle) - return nil, err - } - - tun := &NativeTUN{ - fd: handle, - name: name, - ro: ro, - wo: wo, - events: make(chan TUNEvent, 5), - } - - // find addresses of interface - // TODO: fix this hack, the question is how - - inter, err := net.InterfaceByName(name) - if err != nil { - windows.Close(handle) - return nil, err - } - - addrs, err := inter.Addrs() - if err != nil { - windows.Close(handle) - return nil, err - } - - var ip net.IP - for _, addr := range addrs { - ip = func() net.IP { - switch v := addr.(type) { - case *net.IPNet: - return v.IP.To4() - case *net.IPAddr: - return v.IP.To4() - } - return nil - }() - if ip != nil { - break - } - } - - if ip == nil { - windows.Close(handle) - return nil, errors.New("No IPv4 address found for interface") - } - - // bring up device. - - if err := setStatus(handle, true); err != nil { - windows.Close(handle) - return nil, err - } - - // set tun mode - - mask := ip.String() + "/0" - if err := setTUN(handle, mask); err != nil { - windows.Close(handle) - return nil, err - } - - // start listener - - go func(native *NativeTUN, ifname string) { - // TODO: Fix this very niave implementation - var ( - statusUp bool - statusMTU int - ) - - for ; ; time.Sleep(time.Second) { - intr, err := net.InterfaceByName(name) - if err != nil { - // TODO: handle - return - } - - // Up / Down event - up := (intr.Flags & net.FlagUp) != 0 - if up != statusUp && up { - native.events <- TUNEventUp - } - if up != statusUp && !up { - native.events <- TUNEventDown - } - statusUp = up - - // MTU changes - if intr.MTU != statusMTU { - native.events <- TUNEventMTUUpdate - } - statusMTU = intr.MTU - } - }(tun, name) - - return tun, nil -} diff --git a/src/uapi.go b/src/uapi.go deleted file mode 100644 index caaa498..0000000 --- a/src/uapi.go +++ /dev/null @@ -1,437 +0,0 @@ -package main - -import ( - "bufio" - "fmt" - "io" - "net" - "strconv" - "strings" - "sync/atomic" - "time" -) - -type IPCError struct { - Code int64 -} - -func (s *IPCError) Error() string { - return fmt.Sprintf("IPC error: %d", s.Code) -} - -func (s *IPCError) ErrorCode() int64 { - return s.Code -} - -func ipcGetOperation(device *Device, socket *bufio.ReadWriter) *IPCError { - - device.log.Debug.Println("UAPI: Processing get operation") - - // create lines - - lines := make([]string, 0, 100) - send := func(line string) { - lines = append(lines, line) - } - - func() { - - // lock required resources - - device.net.mutex.RLock() - defer device.net.mutex.RUnlock() - - device.noise.mutex.RLock() - defer device.noise.mutex.RUnlock() - - device.routing.mutex.RLock() - defer device.routing.mutex.RUnlock() - - device.peers.mutex.Lock() - defer device.peers.mutex.Unlock() - - // serialize device related values - - if !device.noise.privateKey.IsZero() { - send("private_key=" + device.noise.privateKey.ToHex()) - } - - if device.net.port != 0 { - send(fmt.Sprintf("listen_port=%d", device.net.port)) - } - - if device.net.fwmark != 0 { - send(fmt.Sprintf("fwmark=%d", device.net.fwmark)) - } - - // serialize each peer state - - for _, peer := range device.peers.keyMap { - peer.mutex.RLock() - defer peer.mutex.RUnlock() - - send("public_key=" + peer.handshake.remoteStatic.ToHex()) - send("preshared_key=" + peer.handshake.presharedKey.ToHex()) - if peer.endpoint != nil { - send("endpoint=" + peer.endpoint.DstToString()) - } - - nano := atomic.LoadInt64(&peer.stats.lastHandshakeNano) - secs := nano / time.Second.Nanoseconds() - nano %= time.Second.Nanoseconds() - - send(fmt.Sprintf("last_handshake_time_sec=%d", secs)) - send(fmt.Sprintf("last_handshake_time_nsec=%d", nano)) - send(fmt.Sprintf("tx_bytes=%d", peer.stats.txBytes)) - send(fmt.Sprintf("rx_bytes=%d", peer.stats.rxBytes)) - send(fmt.Sprintf("persistent_keepalive_interval=%d", - atomic.LoadUint64(&peer.persistentKeepaliveInterval), - )) - - for _, ip := range device.routing.table.AllowedIPs(peer) { - send("allowed_ip=" + ip.String()) - } - - } - }() - - // send lines (does not require resource locks) - - for _, line := range lines { - _, err := socket.WriteString(line + "\n") - if err != nil { - return &IPCError{ - Code: ipcErrorIO, - } - } - } - - return nil -} - -func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError { - scanner := bufio.NewScanner(socket) - logError := device.log.Error - logDebug := device.log.Debug - - var peer *Peer - - dummy := false - deviceConfig := true - - for scanner.Scan() { - - // parse line - - line := scanner.Text() - if line == "" { - return nil - } - parts := strings.Split(line, "=") - if len(parts) != 2 { - return &IPCError{Code: ipcErrorProtocol} - } - key := parts[0] - value := parts[1] - - /* device configuration */ - - if deviceConfig { - - switch key { - case "private_key": - var sk NoisePrivateKey - err := sk.FromHex(value) - if err != nil { - logError.Println("Failed to set private_key:", err) - return &IPCError{Code: ipcErrorInvalid} - } - logDebug.Println("UAPI: Updating device private key") - device.SetPrivateKey(sk) - - case "listen_port": - - // parse port number - - port, err := strconv.ParseUint(value, 10, 16) - if err != nil { - logError.Println("Failed to parse listen_port:", err) - return &IPCError{Code: ipcErrorInvalid} - } - - // update port and rebind - - logDebug.Println("UAPI: Updating listen port") - - device.net.mutex.Lock() - device.net.port = uint16(port) - device.net.mutex.Unlock() - - if err := device.BindUpdate(); err != nil { - logError.Println("Failed to set listen_port:", err) - return &IPCError{Code: ipcErrorPortInUse} - } - - case "fwmark": - - // parse fwmark field - - fwmark, err := func() (uint32, error) { - if value == "" { - return 0, nil - } - mark, err := strconv.ParseUint(value, 10, 32) - return uint32(mark), err - }() - - if err != nil { - logError.Println("Invalid fwmark", err) - return &IPCError{Code: ipcErrorInvalid} - } - - logDebug.Println("UAPI: Updating fwmark") - - device.net.mutex.Lock() - device.net.fwmark = uint32(fwmark) - device.net.mutex.Unlock() - - if err := device.BindUpdate(); err != nil { - logError.Println("Failed to update fwmark:", err) - return &IPCError{Code: ipcErrorPortInUse} - } - - case "public_key": - // switch to peer configuration - logDebug.Println("UAPI: Transition to peer configuration") - deviceConfig = false - - case "replace_peers": - if value != "true" { - logError.Println("Failed to set replace_peers, invalid value:", value) - return &IPCError{Code: ipcErrorInvalid} - } - logDebug.Println("UAPI: Removing all peers") - device.RemoveAllPeers() - - default: - logError.Println("Invalid UAPI key (device configuration):", key) - return &IPCError{Code: ipcErrorInvalid} - } - } - - /* peer configuration */ - - if !deviceConfig { - - switch key { - - case "public_key": - var publicKey NoisePublicKey - err := publicKey.FromHex(value) - if err != nil { - logError.Println("Failed to get peer by public_key:", err) - return &IPCError{Code: ipcErrorInvalid} - } - - // ignore peer with public key of device - - device.noise.mutex.RLock() - equals := device.noise.publicKey.Equals(publicKey) - device.noise.mutex.RUnlock() - - if equals { - peer = &Peer{} - dummy = true - } - - // find peer referenced - - peer = device.LookupPeer(publicKey) - - if peer == nil { - peer, err = device.NewPeer(publicKey) - if err != nil { - logError.Println("Failed to create new peer:", err) - return &IPCError{Code: ipcErrorInvalid} - } - logDebug.Println("UAPI: Created new peer:", peer.String()) - } - - peer.mutex.Lock() - peer.timer.handshakeDeadline.Reset(RekeyAttemptTime) - peer.mutex.Unlock() - - case "remove": - - // remove currently selected peer from device - - if value != "true" { - logError.Println("Failed to set remove, invalid value:", value) - return &IPCError{Code: ipcErrorInvalid} - } - if !dummy { - logDebug.Println("UAPI: Removing peer:", peer.String()) - device.RemovePeer(peer.handshake.remoteStatic) - } - peer = &Peer{} - dummy = true - - case "preshared_key": - - // update PSK - - logDebug.Println("UAPI: Updating pre-shared key for peer:", peer.String()) - - peer.handshake.mutex.Lock() - err := peer.handshake.presharedKey.FromHex(value) - peer.handshake.mutex.Unlock() - - if err != nil { - logError.Println("Failed to set preshared_key:", err) - return &IPCError{Code: ipcErrorInvalid} - } - - case "endpoint": - - // set endpoint destination - - logDebug.Println("UAPI: Updating endpoint for peer:", peer.String()) - - err := func() error { - peer.mutex.Lock() - defer peer.mutex.Unlock() - endpoint, err := CreateEndpoint(value) - if err != nil { - return err - } - peer.endpoint = endpoint - peer.timer.handshakeDeadline.Reset(RekeyAttemptTime) - return nil - }() - - if err != nil { - logError.Println("Failed to set endpoint:", value) - return &IPCError{Code: ipcErrorInvalid} - } - - case "persistent_keepalive_interval": - - // update keep-alive interval - - logDebug.Println("UAPI: Updating persistent_keepalive_interval for peer:", peer.String()) - - secs, err := strconv.ParseUint(value, 10, 16) - if err != nil { - logError.Println("Failed to set persistent_keepalive_interval:", err) - return &IPCError{Code: ipcErrorInvalid} - } - - old := atomic.SwapUint64( - &peer.persistentKeepaliveInterval, - secs, - ) - - // send immediate keep-alive - - if old == 0 && secs != 0 { - if err != nil { - logError.Println("Failed to get tun device status:", err) - return &IPCError{Code: ipcErrorIO} - } - if device.isUp.Get() && !dummy { - peer.SendKeepAlive() - } - } - - case "replace_allowed_ips": - - logDebug.Println("UAPI: Removing all allowed IPs for peer:", peer.String()) - - if value != "true" { - logError.Println("Failed to set replace_allowed_ips, invalid value:", value) - return &IPCError{Code: ipcErrorInvalid} - } - - if dummy { - continue - } - - device.routing.mutex.Lock() - device.routing.table.RemovePeer(peer) - device.routing.mutex.Unlock() - - case "allowed_ip": - - logDebug.Println("UAPI: Adding allowed_ip to peer:", peer.String()) - - _, network, err := net.ParseCIDR(value) - if err != nil { - logError.Println("Failed to set allowed_ip:", err) - return &IPCError{Code: ipcErrorInvalid} - } - - if dummy { - continue - } - - ones, _ := network.Mask.Size() - device.routing.mutex.Lock() - device.routing.table.Insert(network.IP, uint(ones), peer) - device.routing.mutex.Unlock() - - default: - logError.Println("Invalid UAPI key (peer configuration):", key) - return &IPCError{Code: ipcErrorInvalid} - } - } - } - - return nil -} - -func ipcHandle(device *Device, socket net.Conn) { - - // create buffered read/writer - - defer socket.Close() - - buffered := func(s io.ReadWriter) *bufio.ReadWriter { - reader := bufio.NewReader(s) - writer := bufio.NewWriter(s) - return bufio.NewReadWriter(reader, writer) - }(socket) - - defer buffered.Flush() - - op, err := buffered.ReadString('\n') - if err != nil { - return - } - - // handle operation - - var status *IPCError - - switch op { - case "set=1\n": - device.log.Debug.Println("Config, set operation") - status = ipcSetOperation(device, buffered) - - case "get=1\n": - device.log.Debug.Println("Config, get operation") - status = ipcGetOperation(device, buffered) - - default: - device.log.Error.Println("Invalid UAPI operation:", op) - return - } - - // write status - - if status != nil { - device.log.Error.Println(status) - fmt.Fprintf(buffered, "errno=%d\n\n", status.ErrorCode()) - } else { - fmt.Fprintf(buffered, "errno=0\n\n") - } -} diff --git a/src/uapi_darwin.go b/src/uapi_darwin.go deleted file mode 100644 index 63d4d8d..0000000 --- a/src/uapi_darwin.go +++ /dev/null @@ -1,99 +0,0 @@ -package main - -import ( - "fmt" - "golang.org/x/sys/unix" - "net" - "os" - "path" - "time" -) - -const ( - ipcErrorIO = -int64(unix.EIO) - ipcErrorProtocol = -int64(unix.EPROTO) - ipcErrorInvalid = -int64(unix.EINVAL) - ipcErrorPortInUse = -int64(unix.EADDRINUSE) - socketDirectory = "/var/run/wireguard" - socketName = "%s.sock" -) - -type UAPIListener struct { - listener net.Listener // unix socket listener - connNew chan net.Conn - connErr chan error -} - -func (l *UAPIListener) Accept() (net.Conn, error) { - for { - select { - case conn := <-l.connNew: - return conn, nil - - case err := <-l.connErr: - return nil, err - } - } -} - -func (l *UAPIListener) Close() error { - return l.listener.Close() -} - -func (l *UAPIListener) Addr() net.Addr { - return nil -} - -func NewUAPIListener(name string) (net.Listener, error) { - - // check if path exist - - err := os.MkdirAll(socketDirectory, 077) - if err != nil && !os.IsExist(err) { - return nil, err - } - - // open UNIX socket - - socketPath := path.Join( - socketDirectory, - fmt.Sprintf(socketName, name), - ) - - listener, err := net.Listen("unix", socketPath) - if err != nil { - return nil, err - } - - uapi := &UAPIListener{ - listener: listener, - connNew: make(chan net.Conn, 1), - connErr: make(chan error, 1), - } - - // watch for deletion of socket - - go func(l *UAPIListener) { - for ; ; time.Sleep(time.Second) { - if _, err := os.Stat(socketPath); os.IsNotExist(err) { - l.connErr <- err - return - } - } - }(uapi) - - // watch for new connections - - go func(l *UAPIListener) { - for { - conn, err := l.listener.Accept() - if err != nil { - l.connErr <- err - break - } - l.connNew <- conn - } - }(uapi) - - return uapi, nil -} diff --git a/src/uapi_linux.go b/src/uapi_linux.go deleted file mode 100644 index f97a18a..0000000 --- a/src/uapi_linux.go +++ /dev/null @@ -1,171 +0,0 @@ -package main - -import ( - "errors" - "fmt" - "golang.org/x/sys/unix" - "net" - "os" - "path" -) - -const ( - ipcErrorIO = -int64(unix.EIO) - ipcErrorProtocol = -int64(unix.EPROTO) - ipcErrorInvalid = -int64(unix.EINVAL) - ipcErrorPortInUse = -int64(unix.EADDRINUSE) - socketDirectory = "/var/run/wireguard" - socketName = "%s.sock" -) - -type UAPIListener struct { - listener net.Listener // unix socket listener - connNew chan net.Conn - connErr chan error - inotifyFd int -} - -func (l *UAPIListener) Accept() (net.Conn, error) { - for { - select { - case conn := <-l.connNew: - return conn, nil - - case err := <-l.connErr: - return nil, err - } - } -} - -func (l *UAPIListener) Close() error { - err1 := unix.Close(l.inotifyFd) - err2 := l.listener.Close() - if err1 != nil { - return err1 - } - return err2 -} - -func (l *UAPIListener) Addr() net.Addr { - return nil -} - -func UAPIListen(name string, file *os.File) (net.Listener, error) { - - // wrap file in listener - - listener, err := net.FileListener(file) - if err != nil { - return nil, err - } - - uapi := &UAPIListener{ - listener: listener, - connNew: make(chan net.Conn, 1), - connErr: make(chan error, 1), - } - - // watch for deletion of socket - - socketPath := path.Join( - socketDirectory, - fmt.Sprintf(socketName, name), - ) - - uapi.inotifyFd, err = unix.InotifyInit() - if err != nil { - return nil, err - } - - _, err = unix.InotifyAddWatch( - uapi.inotifyFd, - socketPath, - unix.IN_ATTRIB| - unix.IN_DELETE| - unix.IN_DELETE_SELF, - ) - - if err != nil { - return nil, err - } - - go func(l *UAPIListener) { - var buff [4096]byte - for { - // start with lstat to avoid race condition - if _, err := os.Lstat(socketPath); os.IsNotExist(err) { - l.connErr <- err - return - } - unix.Read(uapi.inotifyFd, buff[:]) - } - }(uapi) - - // watch for new connections - - go func(l *UAPIListener) { - for { - conn, err := l.listener.Accept() - if err != nil { - l.connErr <- err - break - } - l.connNew <- conn - } - }(uapi) - - return uapi, nil -} - -func UAPIOpen(name string) (*os.File, error) { - - // check if path exist - - err := os.MkdirAll(socketDirectory, 0600) - if err != nil && !os.IsExist(err) { - return nil, err - } - - // open UNIX socket - - socketPath := path.Join( - socketDirectory, - fmt.Sprintf(socketName, name), - ) - - addr, err := net.ResolveUnixAddr("unix", socketPath) - if err != nil { - return nil, err - } - - listener, err := func() (*net.UnixListener, error) { - - // initial connection attempt - - listener, err := net.ListenUnix("unix", addr) - if err == nil { - return listener, nil - } - - // check if socket already active - - _, err = net.Dial("unix", socketPath) - if err == nil { - return nil, errors.New("unix socket in use") - } - - // cleanup & attempt again - - err = os.Remove(socketPath) - if err != nil { - return nil, err - } - return net.ListenUnix("unix", addr) - }() - - if err != nil { - return nil, err - } - - return listener.File() -} diff --git a/src/uapi_windows.go b/src/uapi_windows.go deleted file mode 100644 index a4599a5..0000000 --- a/src/uapi_windows.go +++ /dev/null @@ -1,44 +0,0 @@ -package main - -/* UAPI on windows uses a bidirectional named pipe - */ - -import ( - "fmt" - "github.com/Microsoft/go-winio" - "golang.org/x/sys/windows" - "net" -) - -const ( - ipcErrorIO = -int64(windows.ERROR_BROKEN_PIPE) - ipcErrorProtocol = -int64(windows.ERROR_INVALID_NAME) - ipcErrorInvalid = -int64(windows.ERROR_INVALID_PARAMETER) - ipcErrorPortInUse = -int64(windows.ERROR_ALREADY_EXISTS) -) - -const PipeNameFmt = "\\\\.\\pipe\\wireguard-ipc-%s" - -type UAPIListener struct { - listener net.Listener -} - -func (uapi *UAPIListener) Accept() (net.Conn, error) { - return nil, nil -} - -func (uapi *UAPIListener) Close() error { - return uapi.listener.Close() -} - -func (uapi *UAPIListener) Addr() net.Addr { - return nil -} - -func NewUAPIListener(name string) (net.Listener, error) { - path := fmt.Sprintf(PipeNameFmt, name) - return winio.ListenPipe(path, &winio.PipeConfig{ - InputBufferSize: 2048, - OutputBufferSize: 2048, - }) -} diff --git a/src/xchacha20.go b/src/xchacha20.go deleted file mode 100644 index 5d963e0..0000000 --- a/src/xchacha20.go +++ /dev/null @@ -1,169 +0,0 @@ -// Copyright (c) 2016 Andreas Auernhammer. All rights reserved. -// Use of this source code is governed by a license that can be -// found in the LICENSE file. - -package main - -import ( - "encoding/binary" - "golang.org/x/crypto/chacha20poly1305" -) - -func HChaCha20(out *[32]byte, nonce []byte, key *[32]byte) { - - v00 := uint32(0x61707865) - v01 := uint32(0x3320646e) - v02 := uint32(0x79622d32) - v03 := uint32(0x6b206574) - - v04 := binary.LittleEndian.Uint32(key[0:]) - v05 := binary.LittleEndian.Uint32(key[4:]) - v06 := binary.LittleEndian.Uint32(key[8:]) - v07 := binary.LittleEndian.Uint32(key[12:]) - v08 := binary.LittleEndian.Uint32(key[16:]) - v09 := binary.LittleEndian.Uint32(key[20:]) - v10 := binary.LittleEndian.Uint32(key[24:]) - v11 := binary.LittleEndian.Uint32(key[28:]) - v12 := binary.LittleEndian.Uint32(nonce[0:]) - v13 := binary.LittleEndian.Uint32(nonce[4:]) - v14 := binary.LittleEndian.Uint32(nonce[8:]) - v15 := binary.LittleEndian.Uint32(nonce[12:]) - - for i := 0; i < 20; i += 2 { - v00 += v04 - v12 ^= v00 - v12 = (v12 << 16) | (v12 >> 16) - v08 += v12 - v04 ^= v08 - v04 = (v04 << 12) | (v04 >> 20) - v00 += v04 - v12 ^= v00 - v12 = (v12 << 8) | (v12 >> 24) - v08 += v12 - v04 ^= v08 - v04 = (v04 << 7) | (v04 >> 25) - v01 += v05 - v13 ^= v01 - v13 = (v13 << 16) | (v13 >> 16) - v09 += v13 - v05 ^= v09 - v05 = (v05 << 12) | (v05 >> 20) - v01 += v05 - v13 ^= v01 - v13 = (v13 << 8) | (v13 >> 24) - v09 += v13 - v05 ^= v09 - v05 = (v05 << 7) | (v05 >> 25) - v02 += v06 - v14 ^= v02 - v14 = (v14 << 16) | (v14 >> 16) - v10 += v14 - v06 ^= v10 - v06 = (v06 << 12) | (v06 >> 20) - v02 += v06 - v14 ^= v02 - v14 = (v14 << 8) | (v14 >> 24) - v10 += v14 - v06 ^= v10 - v06 = (v06 << 7) | (v06 >> 25) - v03 += v07 - v15 ^= v03 - v15 = (v15 << 16) | (v15 >> 16) - v11 += v15 - v07 ^= v11 - v07 = (v07 << 12) | (v07 >> 20) - v03 += v07 - v15 ^= v03 - v15 = (v15 << 8) | (v15 >> 24) - v11 += v15 - v07 ^= v11 - v07 = (v07 << 7) | (v07 >> 25) - v00 += v05 - v15 ^= v00 - v15 = (v15 << 16) | (v15 >> 16) - v10 += v15 - v05 ^= v10 - v05 = (v05 << 12) | (v05 >> 20) - v00 += v05 - v15 ^= v00 - v15 = (v15 << 8) | (v15 >> 24) - v10 += v15 - v05 ^= v10 - v05 = (v05 << 7) | (v05 >> 25) - v01 += v06 - v12 ^= v01 - v12 = (v12 << 16) | (v12 >> 16) - v11 += v12 - v06 ^= v11 - v06 = (v06 << 12) | (v06 >> 20) - v01 += v06 - v12 ^= v01 - v12 = (v12 << 8) | (v12 >> 24) - v11 += v12 - v06 ^= v11 - v06 = (v06 << 7) | (v06 >> 25) - v02 += v07 - v13 ^= v02 - v13 = (v13 << 16) | (v13 >> 16) - v08 += v13 - v07 ^= v08 - v07 = (v07 << 12) | (v07 >> 20) - v02 += v07 - v13 ^= v02 - v13 = (v13 << 8) | (v13 >> 24) - v08 += v13 - v07 ^= v08 - v07 = (v07 << 7) | (v07 >> 25) - v03 += v04 - v14 ^= v03 - v14 = (v14 << 16) | (v14 >> 16) - v09 += v14 - v04 ^= v09 - v04 = (v04 << 12) | (v04 >> 20) - v03 += v04 - v14 ^= v03 - v14 = (v14 << 8) | (v14 >> 24) - v09 += v14 - v04 ^= v09 - v04 = (v04 << 7) | (v04 >> 25) - } - - binary.LittleEndian.PutUint32(out[0:], v00) - binary.LittleEndian.PutUint32(out[4:], v01) - binary.LittleEndian.PutUint32(out[8:], v02) - binary.LittleEndian.PutUint32(out[12:], v03) - binary.LittleEndian.PutUint32(out[16:], v12) - binary.LittleEndian.PutUint32(out[20:], v13) - binary.LittleEndian.PutUint32(out[24:], v14) - binary.LittleEndian.PutUint32(out[28:], v15) -} - -func XChaCha20Poly1305Encrypt( - dst []byte, - nonceFull *[24]byte, - plaintext []byte, - additionalData []byte, - key *[chacha20poly1305.KeySize]byte, -) []byte { - var nonce [chacha20poly1305.NonceSize]byte - var derivedKey [chacha20poly1305.KeySize]byte - HChaCha20(&derivedKey, nonceFull[:16], key) - aead, _ := chacha20poly1305.New(derivedKey[:]) - copy(nonce[4:], nonceFull[16:]) - return aead.Seal(dst, nonce[:], plaintext, additionalData) -} - -func XChaCha20Poly1305Decrypt( - dst []byte, - nonceFull *[24]byte, - plaintext []byte, - additionalData []byte, - key *[chacha20poly1305.KeySize]byte, -) ([]byte, error) { - var nonce [chacha20poly1305.NonceSize]byte - var derivedKey [chacha20poly1305.KeySize]byte - HChaCha20(&derivedKey, nonceFull[:16], key) - aead, _ := chacha20poly1305.New(derivedKey[:]) - copy(nonce[4:], nonceFull[16:]) - return aead.Open(dst, nonce[:], plaintext, additionalData) -} diff --git a/src/xchacha20_test.go b/src/xchacha20_test.go deleted file mode 100644 index 0f41cf8..0000000 --- a/src/xchacha20_test.go +++ /dev/null @@ -1,96 +0,0 @@ -package main - -import ( - "encoding/hex" - "testing" -) - -type XChaCha20Test struct { - Nonce string - Key string - PT string - CT string -} - -func TestXChaCha20(t *testing.T) { - - tests := []XChaCha20Test{ - { - Nonce: "000000000000000000000000000000000000000000000000", - Key: "0000000000000000000000000000000000000000000000000000000000000000", - PT: "00000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000", - CT: "789e9689e5208d7fd9e1f3c5b5341f48ef18a13e418998addadd97a3693a987f8e82ecd5c1433bfed1af49750c0f1ff29c4174a05b119aa3a9e8333812e0c0feb1299c5949d895ee01dbf50f8395dd84", - }, - { - Nonce: "0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f", - Key: "0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f", - PT: "0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f", - CT: "e1a046aa7f71e2af8b80b6408b2fd8d3a350278cde79c94d9efaa475e1339b3dd490127b", - }, - { - Nonce: "d9a8213e8a697508805c2c171ad54487ead9e3e02d82d5bc", - Key: "979196dbd78526f2f584f7534db3f5824d8ccfa858ca7e09bdd3656ecd36033c", - PT: "43cc6d624e451bbed952c3e071dc6c03392ce11eb14316a94b2fdc98b22fedea", - CT: "53c1e8bef2dbb8f2505ec010a7afe21d5a8e6dd8f987e4ea1a2ed5dfbc844ea400db34496fd2153526c6e87c36694200", - }, - } - - for _, test := range tests { - - nonce, err := hex.DecodeString(test.Nonce) - if err != nil { - panic(err) - } - - key, err := hex.DecodeString(test.Key) - if err != nil { - panic(err) - } - - pt, err := hex.DecodeString(test.PT) - if err != nil { - panic(err) - } - - func() { - var nonceArray [24]byte - var keyArray [32]byte - copy(nonceArray[:], nonce) - copy(keyArray[:], key) - - // test encryption - - ct := XChaCha20Poly1305Encrypt( - nil, - &nonceArray, - pt, - nil, - &keyArray, - ) - ctHex := hex.EncodeToString(ct) - if ctHex != test.CT { - t.Fatal("encryption failed, expected:", test.CT, "got", ctHex) - } - - // test decryption - - ptp, err := XChaCha20Poly1305Decrypt( - nil, - &nonceArray, - ct, - nil, - &keyArray, - ) - if err != nil { - t.Fatal(err) - } - - ptHex := hex.EncodeToString(ptp) - if ptHex != test.PT { - t.Fatal("decryption failed, expected:", test.PT, "got", ptHex) - } - }() - - } - -} diff --git a/tai64.go b/tai64.go new file mode 100644 index 0000000..2299a37 --- /dev/null +++ b/tai64.go @@ -0,0 +1,28 @@ +package main + +import ( + "bytes" + "encoding/binary" + "time" +) + +const ( + TAI64NBase = uint64(4611686018427387914) + TAI64NSize = 12 +) + +type TAI64N [TAI64NSize]byte + +func Timestamp() TAI64N { + var tai64n TAI64N + now := time.Now() + secs := TAI64NBase + uint64(now.Unix()) + nano := uint32(now.UnixNano()) + binary.BigEndian.PutUint64(tai64n[:], secs) + binary.BigEndian.PutUint32(tai64n[8:], nano) + return tai64n +} + +func (t1 *TAI64N) After(t2 TAI64N) bool { + return bytes.Compare(t1[:], t2[:]) > 0 +} diff --git a/tests/netns.sh b/tests/netns.sh new file mode 100755 index 0000000..6c47a44 --- /dev/null +++ b/tests/netns.sh @@ -0,0 +1,425 @@ +#!/bin/bash + +# Copyright (C) 2015-2017 Jason A. Donenfeld . All Rights Reserved. + +# This script tests the below topology: +# +# ┌─────────────────────┐ ┌──────────────────────────────────┐ ┌─────────────────────┐ +# │ $ns1 namespace │ │ $ns0 namespace │ │ $ns2 namespace │ +# │ │ │ │ │ │ +# │┌────────┐ │ │ ┌────────┐ │ │ ┌────────┐│ +# ││ wg1 │───────────┼───┼────────────│ lo │────────────┼───┼───────────│ wg2 ││ +# │├────────┴──────────┐│ │ ┌───────┴────────┴────────┐ │ │┌──────────┴────────┤│ +# ││192.168.241.1/24 ││ │ │(ns1) (ns2) │ │ ││192.168.241.2/24 ││ +# ││fd00::1/24 ││ │ │127.0.0.1:1 127.0.0.1:2│ │ ││fd00::2/24 ││ +# │└───────────────────┘│ │ │[::]:1 [::]:2 │ │ │└───────────────────┘│ +# └─────────────────────┘ │ └─────────────────────────┘ │ └─────────────────────┘ +# └──────────────────────────────────┘ +# +# After the topology is prepared we run a series of TCP/UDP iperf3 tests between the +# wireguard peers in $ns1 and $ns2. Note that $ns0 is the endpoint for the wg1 +# interfaces in $ns1 and $ns2. See https://www.wireguard.com/netns/ for further +# details on how this is accomplished. + +# This code is ported to the WireGuard-Go directly from the kernel project. +# +# Please ensure that you have installed the newest version of the WireGuard +# tools from the WireGuard project and before running these tests as: +# +# ./netns.sh + +set -e + +exec 3>&1 +export WG_HIDE_KEYS=never +netns0="wg-test-$$-0" +netns1="wg-test-$$-1" +netns2="wg-test-$$-2" +program=$1 +export LOG_LEVEL="info" + +pretty() { echo -e "\x1b[32m\x1b[1m[+] ${1:+NS$1: }${2}\x1b[0m" >&3; } +pp() { pretty "" "$*"; "$@"; } +maybe_exec() { if [[ $BASHPID -eq $$ ]]; then "$@"; else exec "$@"; fi; } +n0() { pretty 0 "$*"; maybe_exec ip netns exec $netns0 "$@"; } +n1() { pretty 1 "$*"; maybe_exec ip netns exec $netns1 "$@"; } +n2() { pretty 2 "$*"; maybe_exec ip netns exec $netns2 "$@"; } +ip0() { pretty 0 "ip $*"; ip -n $netns0 "$@"; } +ip1() { pretty 1 "ip $*"; ip -n $netns1 "$@"; } +ip2() { pretty 2 "ip $*"; ip -n $netns2 "$@"; } +sleep() { read -t "$1" -N 0 || true; } +waitiperf() { pretty "${1//*-}" "wait for iperf:5201"; while [[ $(ss -N "$1" -tlp 'sport = 5201') != *iperf3* ]]; do sleep 0.1; done; } +waitncatudp() { pretty "${1//*-}" "wait for udp:1111"; while [[ $(ss -N "$1" -ulp 'sport = 1111') != *ncat* ]]; do sleep 0.1; done; } +waitiface() { pretty "${1//*-}" "wait for $2 to come up"; ip netns exec "$1" bash -c "while [[ \$(< \"/sys/class/net/$2/operstate\") != up ]]; do read -t .1 -N 0 || true; done;"; } + +cleanup() { + set +e + exec 2>/dev/null + printf "$orig_message_cost" > /proc/sys/net/core/message_cost + ip0 link del dev wg1 + ip1 link del dev wg1 + ip2 link del dev wg1 + local to_kill="$(ip netns pids $netns0) $(ip netns pids $netns1) $(ip netns pids $netns2)" + [[ -n $to_kill ]] && kill $to_kill + pp ip netns del $netns1 + pp ip netns del $netns2 + pp ip netns del $netns0 + exit +} + +orig_message_cost="$(< /proc/sys/net/core/message_cost)" +trap cleanup EXIT +printf 0 > /proc/sys/net/core/message_cost + +ip netns del $netns0 2>/dev/null || true +ip netns del $netns1 2>/dev/null || true +ip netns del $netns2 2>/dev/null || true +pp ip netns add $netns0 +pp ip netns add $netns1 +pp ip netns add $netns2 +ip0 link set up dev lo + +# ip0 link add dev wg1 type wireguard +n0 $program -f wg1 & +ip0 link set wg1 netns $netns1 + +# ip0 link add dev wg1 type wireguard +n0 $program -f wg2 & +ip0 link set wg2 netns $netns2 + +key1="$(pp wg genkey)" +key2="$(pp wg genkey)" +pub1="$(pp wg pubkey <<<"$key1")" +pub2="$(pp wg pubkey <<<"$key2")" +psk="$(pp wg genpsk)" +[[ -n $key1 && -n $key2 && -n $psk ]] + +configure_peers() { + + ip1 addr add 192.168.241.1/24 dev wg1 + ip1 addr add fd00::1/24 dev wg1 + + ip2 addr add 192.168.241.2/24 dev wg2 + ip2 addr add fd00::2/24 dev wg2 + + n0 wg set wg1 \ + private-key <(echo "$key1") \ + listen-port 10000 \ + peer "$pub2" \ + preshared-key <(echo "$psk") \ + allowed-ips 192.168.241.2/32,fd00::2/128 + n0 wg set wg2 \ + private-key <(echo "$key2") \ + listen-port 20000 \ + peer "$pub1" \ + preshared-key <(echo "$psk") \ + allowed-ips 192.168.241.1/32,fd00::1/128 + + n0 wg showconf wg1 + n0 wg showconf wg2 + + ip1 link set up dev wg1 + ip2 link set up dev wg2 + sleep 1 +} +configure_peers + +tests() { + # Ping over IPv4 + n2 ping -c 10 -f -W 1 192.168.241.1 + n1 ping -c 10 -f -W 1 192.168.241.2 + + # Ping over IPv6 + n2 ping6 -c 10 -f -W 1 fd00::1 + n1 ping6 -c 10 -f -W 1 fd00::2 + + # TCP over IPv4 + n2 iperf3 -s -1 -B 192.168.241.2 & + waitiperf $netns2 + n1 iperf3 -Z -n 1G -c 192.168.241.2 + + # TCP over IPv6 + n1 iperf3 -s -1 -B fd00::1 & + waitiperf $netns1 + n2 iperf3 -Z -n 1G -c fd00::1 + + # UDP over IPv4 + n1 iperf3 -s -1 -B 192.168.241.1 & + waitiperf $netns1 + n2 iperf3 -Z -n 1G -b 0 -u -c 192.168.241.1 + + # UDP over IPv6 + n2 iperf3 -s -1 -B fd00::2 & + waitiperf $netns2 + n1 iperf3 -Z -n 1G -b 0 -u -c fd00::2 +} + +[[ $(ip1 link show dev wg1) =~ mtu\ ([0-9]+) ]] && orig_mtu="${BASH_REMATCH[1]}" +big_mtu=$(( 34816 - 1500 + $orig_mtu )) + +# Test using IPv4 as outer transport +n0 wg set wg1 peer "$pub2" endpoint 127.0.0.1:20000 +n0 wg set wg2 peer "$pub1" endpoint 127.0.0.1:10000 + +# Before calling tests, we first make sure that the stats counters are working +n2 ping -c 10 -f -W 1 192.168.241.1 +{ read _; read _; read _; read rx_bytes _; read _; read tx_bytes _; } < <(ip2 -stats link show dev wg2) +ip2 -stats link show dev wg2 +n0 wg show +[[ $rx_bytes -ge 840 && $tx_bytes -ge 880 && $rx_bytes -lt 2500 && $rx_bytes -lt 2500 ]] +echo "counters working" +tests +ip1 link set wg1 mtu $big_mtu +ip2 link set wg2 mtu $big_mtu +tests + +ip1 link set wg1 mtu $orig_mtu +ip2 link set wg2 mtu $orig_mtu + +# Test using IPv6 as outer transport +n0 wg set wg1 peer "$pub2" endpoint [::1]:20000 +n0 wg set wg2 peer "$pub1" endpoint [::1]:10000 +tests +ip1 link set wg1 mtu $big_mtu +ip2 link set wg2 mtu $big_mtu +tests + +ip1 link set wg1 mtu $orig_mtu +ip2 link set wg2 mtu $orig_mtu + +# Test using IPv4 that roaming works +ip0 -4 addr del 127.0.0.1/8 dev lo +ip0 -4 addr add 127.212.121.99/8 dev lo +n0 wg set wg1 listen-port 9999 +n0 wg set wg1 peer "$pub2" endpoint 127.0.0.1:20000 +n1 ping6 -W 1 -c 1 fd00::2 +[[ $(n2 wg show wg2 endpoints) == "$pub1 127.212.121.99:9999" ]] + +# Test using IPv6 that roaming works +n1 wg set wg1 listen-port 9998 +n1 wg set wg1 peer "$pub2" endpoint [::1]:20000 +n1 ping -W 1 -c 1 192.168.241.2 +[[ $(n2 wg show wg2 endpoints) == "$pub1 [::1]:9998" ]] + +# Test that crypto-RP filter works +n1 wg set wg1 peer "$pub2" allowed-ips 192.168.241.0/24 +exec 4< <(n1 ncat -l -u -p 1111) +nmap_pid=$! +waitncatudp $netns1 +n2 ncat -u 192.168.241.1 1111 <<<"X" +read -r -N 1 -t 1 out <&4 && [[ $out == "X" ]] +kill $nmap_pid +more_specific_key="$(pp wg genkey | pp wg pubkey)" +n0 wg set wg1 peer "$more_specific_key" allowed-ips 192.168.241.2/32 +n0 wg set wg2 listen-port 9997 +exec 4< <(n1 ncat -l -u -p 1111) +nmap_pid=$! +waitncatudp $netns1 +n2 ncat -u 192.168.241.1 1111 <<<"X" +! read -r -N 1 -t 1 out <&4 +kill $nmap_pid +n0 wg set wg1 peer "$more_specific_key" remove +[[ $(n1 wg show wg1 endpoints) == "$pub2 [::1]:9997" ]] + +ip1 link del wg1 +ip2 link del wg2 + +# Test using NAT. We now change the topology to this: +# ┌────────────────────────────────────────┐ ┌────────────────────────────────────────────────┐ ┌────────────────────────────────────────┐ +# │ $ns1 namespace │ │ $ns0 namespace │ │ $ns2 namespace │ +# │ │ │ │ │ │ +# │ ┌─────┐ ┌─────┐ │ │ ┌──────┐ ┌──────┐ │ │ ┌─────┐ ┌─────┐ │ +# │ │ wg1 │─────────────│vethc│───────────┼────┼────│vethrc│ │vethrs│──────────────┼─────┼──│veths│────────────│ wg2 │ │ +# │ ├─────┴──────────┐ ├─────┴──────────┐│ │ ├──────┴─────────┐ ├──────┴────────────┐ │ │ ├─────┴──────────┐ ├─────┴──────────┐ │ +# │ │192.168.241.1/24│ │192.168.1.100/24││ │ │192.168.1.100/24│ │10.0.0.1/24 │ │ │ │10.0.0.100/24 │ │192.168.241.2/24│ │ +# │ │fd00::1/24 │ │ ││ │ │ │ │SNAT:192.168.1.0/24│ │ │ │ │ │fd00::2/24 │ │ +# │ └────────────────┘ └────────────────┘│ │ └────────────────┘ └───────────────────┘ │ │ └────────────────┘ └────────────────┘ │ +# └────────────────────────────────────────┘ └────────────────────────────────────────────────┘ └────────────────────────────────────────┘ + +# ip1 link add dev wg1 type wireguard +# ip2 link add dev wg1 type wireguard + +n1 $program wg1 +n2 $program wg2 + +configure_peers + +ip0 link add vethrc type veth peer name vethc +ip0 link add vethrs type veth peer name veths +ip0 link set vethc netns $netns1 +ip0 link set veths netns $netns2 +ip0 link set vethrc up +ip0 link set vethrs up +ip0 addr add 192.168.1.1/24 dev vethrc +ip0 addr add 10.0.0.1/24 dev vethrs +ip1 addr add 192.168.1.100/24 dev vethc +ip1 link set vethc up +ip1 route add default via 192.168.1.1 +ip2 addr add 10.0.0.100/24 dev veths +ip2 link set veths up +waitiface $netns0 vethrc +waitiface $netns0 vethrs +waitiface $netns1 vethc +waitiface $netns2 veths + +n0 bash -c 'printf 1 > /proc/sys/net/ipv4/ip_forward' +n0 bash -c 'printf 2 > /proc/sys/net/netfilter/nf_conntrack_udp_timeout' +n0 bash -c 'printf 2 > /proc/sys/net/netfilter/nf_conntrack_udp_timeout_stream' +n0 iptables -t nat -A POSTROUTING -s 192.168.1.0/24 -d 10.0.0.0/24 -j SNAT --to 10.0.0.1 + +n0 wg set wg1 peer "$pub2" endpoint 10.0.0.100:20000 persistent-keepalive 1 +n1 ping -W 1 -c 1 192.168.241.2 +n2 ping -W 1 -c 1 192.168.241.1 +[[ $(n2 wg show wg2 endpoints) == "$pub1 10.0.0.1:10000" ]] +# Demonstrate n2 can still send packets to n1, since persistent-keepalive will prevent connection tracking entry from expiring (to see entries: `n0 conntrack -L`). +pp sleep 3 +n2 ping -W 1 -c 1 192.168.241.1 + +n0 iptables -t nat -F +ip0 link del vethrc +ip0 link del vethrs +ip1 link del wg1 +ip2 link del wg2 + +# Test that saddr routing is sticky but not too sticky, changing to this topology: +# ┌────────────────────────────────────────┐ ┌────────────────────────────────────────┐ +# │ $ns1 namespace │ │ $ns2 namespace │ +# │ │ │ │ +# │ ┌─────┐ ┌─────┐ │ │ ┌─────┐ ┌─────┐ │ +# │ │ wg1 │─────────────│veth1│───────────┼────┼──│veth2│────────────│ wg2 │ │ +# │ ├─────┴──────────┐ ├─────┴──────────┐│ │ ├─────┴──────────┐ ├─────┴──────────┐ │ +# │ │192.168.241.1/24│ │10.0.0.1/24 ││ │ │10.0.0.2/24 │ │192.168.241.2/24│ │ +# │ │fd00::1/24 │ │fd00:aa::1/96 ││ │ │fd00:aa::2/96 │ │fd00::2/24 │ │ +# │ └────────────────┘ └────────────────┘│ │ └────────────────┘ └────────────────┘ │ +# └────────────────────────────────────────┘ └────────────────────────────────────────┘ + +# ip1 link add dev wg1 type wireguard +# ip2 link add dev wg1 type wireguard +n1 $program wg1 +n2 $program wg2 + +configure_peers + +ip1 link add veth1 type veth peer name veth2 +ip1 link set veth2 netns $netns2 +n1 bash -c 'printf 0 > /proc/sys/net/ipv6/conf/veth1/accept_dad' +n2 bash -c 'printf 0 > /proc/sys/net/ipv6/conf/veth2/accept_dad' +n1 bash -c 'printf 1 > /proc/sys/net/ipv4/conf/veth1/promote_secondaries' + +# First we check that we aren't overly sticky and can fall over to new IPs when old ones are removed +ip1 addr add 10.0.0.1/24 dev veth1 +ip1 addr add fd00:aa::1/96 dev veth1 +ip2 addr add 10.0.0.2/24 dev veth2 +ip2 addr add fd00:aa::2/96 dev veth2 +ip1 link set veth1 up +ip2 link set veth2 up +waitiface $netns1 veth1 +waitiface $netns2 veth2 +n0 wg set wg1 peer "$pub2" endpoint 10.0.0.2:20000 +n1 ping -W 1 -c 1 192.168.241.2 +ip1 addr add 10.0.0.10/24 dev veth1 +ip1 addr del 10.0.0.1/24 dev veth1 +n1 ping -W 1 -c 1 192.168.241.2 +n0 wg set wg1 peer "$pub2" endpoint [fd00:aa::2]:20000 +n1 ping -W 1 -c 1 192.168.241.2 +ip1 addr add fd00:aa::10/96 dev veth1 +ip1 addr del fd00:aa::1/96 dev veth1 +n1 ping -W 1 -c 1 192.168.241.2 + +# Now we show that we can successfully do reply to sender routing +ip1 link set veth1 down +ip2 link set veth2 down +ip1 addr flush dev veth1 +ip2 addr flush dev veth2 +ip1 addr add 10.0.0.1/24 dev veth1 +ip1 addr add 10.0.0.2/24 dev veth1 +ip1 addr add fd00:aa::1/96 dev veth1 +ip1 addr add fd00:aa::2/96 dev veth1 +ip2 addr add 10.0.0.3/24 dev veth2 +ip2 addr add fd00:aa::3/96 dev veth2 +ip1 link set veth1 up +ip2 link set veth2 up +waitiface $netns1 veth1 +waitiface $netns2 veth2 +n0 wg set wg2 peer "$pub1" endpoint 10.0.0.1:10000 +n2 ping -W 1 -c 1 192.168.241.1 +[[ $(n0 wg show wg2 endpoints) == "$pub1 10.0.0.1:10000" ]] +n0 wg set wg2 peer "$pub1" endpoint [fd00:aa::1]:10000 +n2 ping -W 1 -c 1 192.168.241.1 +[[ $(n0 wg show wg2 endpoints) == "$pub1 [fd00:aa::1]:10000" ]] +n0 wg set wg2 peer "$pub1" endpoint 10.0.0.2:10000 +n2 ping -W 1 -c 1 192.168.241.1 +[[ $(n0 wg show wg2 endpoints) == "$pub1 10.0.0.2:10000" ]] +n0 wg set wg2 peer "$pub1" endpoint [fd00:aa::2]:10000 +n2 ping -W 1 -c 1 192.168.241.1 +[[ $(n0 wg show wg2 endpoints) == "$pub1 [fd00:aa::2]:10000" ]] + +ip1 link del veth1 +ip1 link del wg1 +ip2 link del wg2 + +# Test that Netlink/IPC is working properly by doing things that usually cause split responses + +n0 $program wg0 +sleep 5 +config=( "[Interface]" "PrivateKey=$(wg genkey)" "[Peer]" "PublicKey=$(wg genkey)" ) +for a in {1..255}; do + for b in {0..255}; do + config+=( "AllowedIPs=$a.$b.0.0/16,$a::$b/128" ) + done +done +n0 wg setconf wg0 <(printf '%s\n' "${config[@]}") +i=0 +for ip in $(n0 wg show wg0 allowed-ips); do + ((++i)) +done +((i == 255*256*2+1)) +ip0 link del wg0 + +n0 $program wg0 +config=( "[Interface]" "PrivateKey=$(wg genkey)" ) +for a in {1..40}; do + config+=( "[Peer]" "PublicKey=$(wg genkey)" ) + for b in {1..52}; do + config+=( "AllowedIPs=$a.$b.0.0/16" ) + done +done +n0 wg setconf wg0 <(printf '%s\n' "${config[@]}") +i=0 +while read -r line; do + j=0 + for ip in $line; do + ((++j)) + done + ((j == 53)) + ((++i)) +done < <(n0 wg show wg0 allowed-ips) +((i == 40)) +ip0 link del wg0 + +n0 $program wg0 +config=( ) +for i in {1..29}; do + config+=( "[Peer]" "PublicKey=$(wg genkey)" ) +done +config+=( "[Peer]" "PublicKey=$(wg genkey)" "AllowedIPs=255.2.3.4/32,abcd::255/128" ) +n0 wg setconf wg0 <(printf '%s\n' "${config[@]}") +n0 wg showconf wg0 > /dev/null +ip0 link del wg0 + +! n0 wg show doesnotexist || false + +declare -A objects +while read -t 0.1 -r line 2>/dev/null || [[ $? -ne 142 ]]; do + [[ $line =~ .*(wg[0-9]+:\ [A-Z][a-z]+\ [0-9]+)\ .*(created|destroyed).* ]] || continue + objects["${BASH_REMATCH[1]}"]+="${BASH_REMATCH[2]}" +done < /dev/kmsg +alldeleted=1 +for object in "${!objects[@]}"; do + if [[ ${objects["$object"]} != *createddestroyed ]]; then + echo "Error: $object: merely ${objects["$object"]}" >&3 + alldeleted=0 + fi +done +[[ $alldeleted -eq 1 ]] +pretty "" "Objects that were created were also destroyed." diff --git a/timer.go b/timer.go new file mode 100644 index 0000000..f00ca49 --- /dev/null +++ b/timer.go @@ -0,0 +1,59 @@ +package main + +import ( + "time" +) + +type Timer struct { + pending AtomicBool + timer *time.Timer +} + +/* Starts the timer if not already pending + */ +func (t *Timer) Start(dur time.Duration) bool { + set := t.pending.Swap(true) + if !set { + t.timer.Reset(dur) + return true + } + return false +} + +/* Stops the timer + */ +func (t *Timer) Stop() { + set := t.pending.Swap(true) + if set { + t.timer.Stop() + select { + case <-t.timer.C: + default: + } + } + t.pending.Set(false) +} + +func (t *Timer) Pending() bool { + return t.pending.Get() +} + +func (t *Timer) Reset(dur time.Duration) { + t.pending.Set(false) + t.Start(dur) +} + +func (t *Timer) Wait() <-chan time.Time { + return t.timer.C +} + +func NewTimer() (t Timer) { + t.pending.Set(false) + t.timer = time.NewTimer(0) + t.timer.Stop() + select { + case <-t.timer.C: + default: + } + return +} diff --git a/timers.go b/timers.go new file mode 100644 index 0000000..7092688 --- /dev/null +++ b/timers.go @@ -0,0 +1,346 @@ +package main + +import ( + "bytes" + "encoding/binary" + "math/rand" + "sync/atomic" + "time" +) + +/* NOTE: + * Notion of validity + * + * + */ + +/* Called when a new authenticated message has been send + * + */ +func (peer *Peer) KeepKeyFreshSending() { + kp := peer.keyPairs.Current() + if kp == nil { + return + } + nonce := atomic.LoadUint64(&kp.sendNonce) + if nonce > RekeyAfterMessages { + peer.signal.handshakeBegin.Send() + } + if kp.isInitiator && time.Now().Sub(kp.created) > RekeyAfterTime { + peer.signal.handshakeBegin.Send() + } +} + +/* Called when a new authenticated message has been received + * + * NOTE: Not thread safe, but called by sequential receiver! + */ +func (peer *Peer) KeepKeyFreshReceiving() { + if peer.timer.sendLastMinuteHandshake { + return + } + kp := peer.keyPairs.Current() + if kp == nil { + return + } + if !kp.isInitiator { + return + } + nonce := atomic.LoadUint64(&kp.sendNonce) + send := nonce > RekeyAfterMessages || time.Now().Sub(kp.created) > RekeyAfterTimeReceiving + if send { + // do a last minute attempt at initiating a new handshake + peer.timer.sendLastMinuteHandshake = true + peer.signal.handshakeBegin.Send() + } +} + +/* Queues a keep-alive if no packets are queued for peer + */ +func (peer *Peer) SendKeepAlive() bool { + if len(peer.queue.nonce) != 0 { + return false + } + elem := peer.device.NewOutboundElement() + elem.packet = nil + select { + case peer.queue.nonce <- elem: + return true + default: + return false + } +} + +/* Event: + * Sent non-empty (authenticated) transport message + */ +func (peer *Peer) TimerDataSent() { + peer.timer.keepalivePassive.Stop() + peer.timer.handshakeNew.Start(NewHandshakeTime) +} + +/* Event: + * Received non-empty (authenticated) transport message + * + * Action: + * Set a timer to confirm the message using a keep-alive (if not already set) + */ +func (peer *Peer) TimerDataReceived() { + if !peer.timer.keepalivePassive.Start(KeepaliveTimeout) { + peer.timer.needAnotherKeepalive = true + } +} + +/* Event: + * Any (authenticated) packet received + */ +func (peer *Peer) TimerAnyAuthenticatedPacketReceived() { + peer.timer.handshakeNew.Stop() +} + +/* Event: + * Any authenticated packet send / received. + * + * Action: + * Push persistent keep-alive into the future + */ +func (peer *Peer) TimerAnyAuthenticatedPacketTraversal() { + interval := atomic.LoadUint64(&peer.persistentKeepaliveInterval) + if interval > 0 { + duration := time.Duration(interval) * time.Second + peer.timer.keepalivePersistent.Reset(duration) + } +} + +/* Called after successfully completing a handshake. + * i.e. after: + * + * - Valid handshake response + * - First transport message under the "next" key + */ +func (peer *Peer) TimerHandshakeComplete() { + peer.signal.handshakeCompleted.Send() + peer.device.log.Info.Println("Negotiated new handshake for", peer.String()) +} + +/* Event: + * An ephemeral key is generated + * + * i.e. after: + * + * CreateMessageInitiation + * CreateMessageResponse + * + * Action: + * Schedule the deletion of all key material + * upon failure to complete a handshake + */ +func (peer *Peer) TimerEphemeralKeyCreated() { + peer.timer.zeroAllKeys.Reset(RejectAfterTime * 3) +} + +/* Sends a new handshake initiation message to the peer (endpoint) + */ +func (peer *Peer) sendNewHandshake() error { + + // temporarily disable the handshake complete signal + + peer.signal.handshakeCompleted.Disable() + + // create initiation message + + msg, err := peer.device.CreateMessageInitiation(peer) + if err != nil { + return err + } + + // marshal handshake message + + var buff [MessageInitiationSize]byte + writer := bytes.NewBuffer(buff[:0]) + binary.Write(writer, binary.LittleEndian, msg) + packet := writer.Bytes() + peer.mac.AddMacs(packet) + + // send to endpoint + + peer.TimerAnyAuthenticatedPacketTraversal() + + err = peer.SendBuffer(packet) + if err == nil { + peer.signal.handshakeCompleted.Enable() + } + + // set timeout + + jitter := time.Millisecond * time.Duration(rand.Uint32()%334) + + peer.timer.keepalivePassive.Stop() + peer.timer.handshakeTimeout.Reset(RekeyTimeout + jitter) + + return err +} + +func (peer *Peer) RoutineTimerHandler() { + + defer peer.routines.stopping.Done() + + device := peer.device + + logInfo := device.log.Info + logDebug := device.log.Debug + logDebug.Println("Routine, timer handler, started for peer", peer.String()) + + // reset all timers + + peer.timer.keepalivePassive.Stop() + peer.timer.handshakeDeadline.Stop() + peer.timer.handshakeTimeout.Stop() + peer.timer.handshakeNew.Stop() + peer.timer.zeroAllKeys.Stop() + + interval := atomic.LoadUint64(&peer.persistentKeepaliveInterval) + if interval > 0 { + duration := time.Duration(interval) * time.Second + peer.timer.keepalivePersistent.Reset(duration) + } + + // signal synchronised setup complete + + peer.routines.starting.Done() + + // handle timer events + + for { + select { + + /* stopping */ + + case <-peer.routines.stop.Wait(): + return + + /* timers */ + + // keep-alive + + case <-peer.timer.keepalivePersistent.Wait(): + + interval := atomic.LoadUint64(&peer.persistentKeepaliveInterval) + if interval > 0 { + logDebug.Println("Sending keep-alive to", peer.String()) + peer.timer.keepalivePassive.Stop() + peer.SendKeepAlive() + } + + case <-peer.timer.keepalivePassive.Wait(): + + logDebug.Println("Sending keep-alive to", peer.String()) + + peer.SendKeepAlive() + + if peer.timer.needAnotherKeepalive { + peer.timer.needAnotherKeepalive = false + peer.timer.keepalivePassive.Reset(KeepaliveTimeout) + } + + // clear key material timer + + case <-peer.timer.zeroAllKeys.Wait(): + + logDebug.Println("Clearing all key material for", peer.String()) + + hs := &peer.handshake + hs.mutex.Lock() + + kp := &peer.keyPairs + kp.mutex.Lock() + + // remove key-pairs + + if kp.previous != nil { + device.DeleteKeyPair(kp.previous) + kp.previous = nil + } + if kp.current != nil { + device.DeleteKeyPair(kp.current) + kp.current = nil + } + if kp.next != nil { + device.DeleteKeyPair(kp.next) + kp.next = nil + } + kp.mutex.Unlock() + + // zero out handshake + + device.indices.Delete(hs.localIndex) + hs.Clear() + hs.mutex.Unlock() + + // handshake timers + + case <-peer.timer.handshakeNew.Wait(): + logInfo.Println("Retrying handshake with", peer.String()) + peer.signal.handshakeBegin.Send() + + case <-peer.timer.handshakeTimeout.Wait(): + + // clear source (in case this is causing problems) + + peer.mutex.Lock() + if peer.endpoint != nil { + peer.endpoint.ClearSrc() + } + peer.mutex.Unlock() + + // send new handshake + + err := peer.sendNewHandshake() + if err != nil { + logInfo.Println( + "Failed to send handshake to peer:", peer.String(), "(", err, ")") + } + + case <-peer.timer.handshakeDeadline.Wait(): + + // clear all queued packets and stop keep-alive + + logInfo.Println( + "Handshake negotiation timed out for:", peer.String()) + + peer.signal.flushNonceQueue.Send() + peer.timer.keepalivePersistent.Stop() + peer.signal.handshakeBegin.Enable() + + /* signals */ + + case <-peer.signal.handshakeBegin.Wait(): + + peer.signal.handshakeBegin.Disable() + + err := peer.sendNewHandshake() + if err != nil { + logInfo.Println( + "Failed to send handshake to peer:", peer.String(), "(", err, ")") + } + + peer.timer.handshakeDeadline.Reset(RekeyAttemptTime) + + case <-peer.signal.handshakeCompleted.Wait(): + + logInfo.Println( + "Handshake completed for:", peer.String()) + + atomic.StoreInt64( + &peer.stats.lastHandshakeNano, + time.Now().UnixNano(), + ) + + peer.timer.handshakeTimeout.Stop() + peer.timer.handshakeDeadline.Stop() + peer.signal.handshakeBegin.Enable() + + peer.timer.sendLastMinuteHandshake = false + } + } +} diff --git a/trie.go b/trie.go new file mode 100644 index 0000000..405ffc3 --- /dev/null +++ b/trie.go @@ -0,0 +1,228 @@ +package main + +import ( + "errors" + "net" +) + +/* Binary trie + * + * The net.IPs used here are not formatted the + * same way as those created by the "net" functions. + * Here the IPs are slices of either 4 or 16 byte (not always 16) + * + * Synchronization done separately + * See: routing.go + */ + +type Trie struct { + cidr uint + child [2]*Trie + bits []byte + peer *Peer + + // index of "branching" bit + + bit_at_byte uint + bit_at_shift uint +} + +/* Finds length of matching prefix + * + * TODO: Only use during insertion (xor + prefix mask for lookup) + * Check out + * prefix_matches(struct allowedips_node *node, const u8 *key, u8 bits) + * https://git.zx2c4.com/WireGuard/commit/?h=jd/precomputed-prefix-match + * + * Assumption: + * len(ip1) == len(ip2) + * len(ip1) mod 4 = 0 + */ +func commonBits(ip1 []byte, ip2 []byte) uint { + var i uint + size := uint(len(ip1)) + + for i = 0; i < size; i++ { + v := ip1[i] ^ ip2[i] + if v != 0 { + v >>= 1 + if v == 0 { + return i*8 + 7 + } + + v >>= 1 + if v == 0 { + return i*8 + 6 + } + + v >>= 1 + if v == 0 { + return i*8 + 5 + } + + v >>= 1 + if v == 0 { + return i*8 + 4 + } + + v >>= 1 + if v == 0 { + return i*8 + 3 + } + + v >>= 1 + if v == 0 { + return i*8 + 2 + } + + v >>= 1 + if v == 0 { + return i*8 + 1 + } + return i * 8 + } + } + return i * 8 +} + +func (node *Trie) RemovePeer(p *Peer) *Trie { + if node == nil { + return node + } + + // walk recursively + + node.child[0] = node.child[0].RemovePeer(p) + node.child[1] = node.child[1].RemovePeer(p) + + if node.peer != p { + return node + } + + // remove peer & merge + + node.peer = nil + if node.child[0] == nil { + return node.child[1] + } + return node.child[0] +} + +func (node *Trie) choose(ip net.IP) byte { + return (ip[node.bit_at_byte] >> node.bit_at_shift) & 1 +} + +func (node *Trie) Insert(ip net.IP, cidr uint, peer *Peer) *Trie { + + // at leaf + + if node == nil { + return &Trie{ + bits: ip, + peer: peer, + cidr: cidr, + bit_at_byte: cidr / 8, + bit_at_shift: 7 - (cidr % 8), + } + } + + // traverse deeper + + common := commonBits(node.bits, ip) + if node.cidr <= cidr && common >= node.cidr { + if node.cidr == cidr { + node.peer = peer + return node + } + bit := node.choose(ip) + node.child[bit] = node.child[bit].Insert(ip, cidr, peer) + return node + } + + // split node + + newNode := &Trie{ + bits: ip, + peer: peer, + cidr: cidr, + bit_at_byte: cidr / 8, + bit_at_shift: 7 - (cidr % 8), + } + + cidr = min(cidr, common) + + // check for shorter prefix + + if newNode.cidr == cidr { + bit := newNode.choose(node.bits) + newNode.child[bit] = node + return newNode + } + + // create new parent for node & newNode + + parent := &Trie{ + bits: ip, + peer: nil, + cidr: cidr, + bit_at_byte: cidr / 8, + bit_at_shift: 7 - (cidr % 8), + } + + bit := parent.choose(ip) + parent.child[bit] = newNode + parent.child[bit^1] = node + + return parent +} + +func (node *Trie) Lookup(ip net.IP) *Peer { + var found *Peer + size := uint(len(ip)) + for node != nil && commonBits(node.bits, ip) >= node.cidr { + if node.peer != nil { + found = node.peer + } + if node.bit_at_byte == size { + break + } + bit := node.choose(ip) + node = node.child[bit] + } + return found +} + +func (node *Trie) Count() uint { + if node == nil { + return 0 + } + l := node.child[0].Count() + r := node.child[1].Count() + return l + r +} + +func (node *Trie) AllowedIPs(p *Peer, results []net.IPNet) []net.IPNet { + if node == nil { + return results + } + if node.peer == p { + var mask net.IPNet + mask.Mask = net.CIDRMask(int(node.cidr), len(node.bits)*8) + if len(node.bits) == net.IPv4len { + mask.IP = net.IPv4( + node.bits[0], + node.bits[1], + node.bits[2], + node.bits[3], + ) + } else if len(node.bits) == net.IPv6len { + mask.IP = node.bits + } else { + panic(errors.New("bug: unexpected address length")) + } + results = append(results, mask) + } + results = node.child[0].AllowedIPs(p, results) + results = node.child[1].AllowedIPs(p, results) + return results +} diff --git a/trie_rand_test.go b/trie_rand_test.go new file mode 100644 index 0000000..840d269 --- /dev/null +++ b/trie_rand_test.go @@ -0,0 +1,126 @@ +package main + +import ( + "math/rand" + "sort" + "testing" +) + +const ( + NumberOfPeers = 100 + NumberOfAddresses = 250 + NumberOfTests = 10000 +) + +type SlowNode struct { + peer *Peer + cidr uint + bits []byte +} + +type SlowRouter []*SlowNode + +func (r SlowRouter) Len() int { + return len(r) +} + +func (r SlowRouter) Less(i, j int) bool { + return r[i].cidr > r[j].cidr +} + +func (r SlowRouter) Swap(i, j int) { + r[i], r[j] = r[j], r[i] +} + +func (r SlowRouter) Insert(addr []byte, cidr uint, peer *Peer) SlowRouter { + for _, t := range r { + if t.cidr == cidr && commonBits(t.bits, addr) >= cidr { + t.peer = peer + t.bits = addr + return r + } + } + r = append(r, &SlowNode{ + cidr: cidr, + bits: addr, + peer: peer, + }) + sort.Sort(r) + return r +} + +func (r SlowRouter) Lookup(addr []byte) *Peer { + for _, t := range r { + common := commonBits(t.bits, addr) + if common >= t.cidr { + return t.peer + } + } + return nil +} + +func TestTrieRandomIPv4(t *testing.T) { + var trie *Trie + var slow SlowRouter + var peers []*Peer + + rand.Seed(1) + + const AddressLength = 4 + + for n := 0; n < NumberOfPeers; n += 1 { + peers = append(peers, &Peer{}) + } + + for n := 0; n < NumberOfAddresses; n += 1 { + var addr [AddressLength]byte + rand.Read(addr[:]) + cidr := uint(rand.Uint32() % (AddressLength * 8)) + index := rand.Int() % NumberOfPeers + trie = trie.Insert(addr[:], cidr, peers[index]) + slow = slow.Insert(addr[:], cidr, peers[index]) + } + + for n := 0; n < NumberOfTests; n += 1 { + var addr [AddressLength]byte + rand.Read(addr[:]) + peer1 := slow.Lookup(addr[:]) + peer2 := trie.Lookup(addr[:]) + if peer1 != peer2 { + t.Error("Trie did not match naive implementation, for:", addr) + } + } +} + +func TestTrieRandomIPv6(t *testing.T) { + var trie *Trie + var slow SlowRouter + var peers []*Peer + + rand.Seed(1) + + const AddressLength = 16 + + for n := 0; n < NumberOfPeers; n += 1 { + peers = append(peers, &Peer{}) + } + + for n := 0; n < NumberOfAddresses; n += 1 { + var addr [AddressLength]byte + rand.Read(addr[:]) + cidr := uint(rand.Uint32() % (AddressLength * 8)) + index := rand.Int() % NumberOfPeers + trie = trie.Insert(addr[:], cidr, peers[index]) + slow = slow.Insert(addr[:], cidr, peers[index]) + } + + for n := 0; n < NumberOfTests; n += 1 { + var addr [AddressLength]byte + rand.Read(addr[:]) + peer1 := slow.Lookup(addr[:]) + peer2 := trie.Lookup(addr[:]) + if peer1 != peer2 { + t.Error("Trie did not match naive implementation, for:", addr) + } + } +} diff --git a/trie_test.go b/trie_test.go new file mode 100644 index 0000000..9d53df3 --- /dev/null +++ b/trie_test.go @@ -0,0 +1,255 @@ +package main + +import ( + "math/rand" + "net" + "testing" +) + +/* Todo: More comprehensive + */ + +type testPairCommonBits struct { + s1 []byte + s2 []byte + match uint +} + +type testPairTrieInsert struct { + key []byte + cidr uint + peer *Peer +} + +type testPairTrieLookup struct { + key []byte + peer *Peer +} + +func printTrie(t *testing.T, p *Trie) { + if p == nil { + return + } + t.Log(p) + printTrie(t, p.child[0]) + printTrie(t, p.child[1]) +} + +func TestCommonBits(t *testing.T) { + + tests := []testPairCommonBits{ + {s1: []byte{1, 4, 53, 128}, s2: []byte{0, 0, 0, 0}, match: 7}, + {s1: []byte{0, 4, 53, 128}, s2: []byte{0, 0, 0, 0}, match: 13}, + {s1: []byte{0, 4, 53, 253}, s2: []byte{0, 4, 53, 252}, match: 31}, + {s1: []byte{192, 168, 1, 1}, s2: []byte{192, 169, 1, 1}, match: 15}, + {s1: []byte{65, 168, 1, 1}, s2: []byte{192, 169, 1, 1}, match: 0}, + } + + for _, p := range tests { + v := commonBits(p.s1, p.s2) + if v != p.match { + t.Error( + "For slice", p.s1, p.s2, + "expected match", p.match, + ",but got", v, + ) + } + } +} + +func benchmarkTrie(peerNumber int, addressNumber int, addressLength int, b *testing.B) { + var trie *Trie + var peers []*Peer + + rand.Seed(1) + + const AddressLength = 4 + + for n := 0; n < peerNumber; n += 1 { + peers = append(peers, &Peer{}) + } + + for n := 0; n < addressNumber; n += 1 { + var addr [AddressLength]byte + rand.Read(addr[:]) + cidr := uint(rand.Uint32() % (AddressLength * 8)) + index := rand.Int() % peerNumber + trie = trie.Insert(addr[:], cidr, peers[index]) + } + + for n := 0; n < b.N; n += 1 { + var addr [AddressLength]byte + rand.Read(addr[:]) + trie.Lookup(addr[:]) + } +} + +func BenchmarkTrieIPv4Peers100Addresses1000(b *testing.B) { + benchmarkTrie(100, 1000, net.IPv4len, b) +} + +func BenchmarkTrieIPv4Peers10Addresses10(b *testing.B) { + benchmarkTrie(10, 10, net.IPv4len, b) +} + +func BenchmarkTrieIPv6Peers100Addresses1000(b *testing.B) { + benchmarkTrie(100, 1000, net.IPv6len, b) +} + +func BenchmarkTrieIPv6Peers10Addresses10(b *testing.B) { + benchmarkTrie(10, 10, net.IPv6len, b) +} + +/* Test ported from kernel implementation: + * selftest/routingtable.h + */ +func TestTrieIPv4(t *testing.T) { + a := &Peer{} + b := &Peer{} + c := &Peer{} + d := &Peer{} + e := &Peer{} + g := &Peer{} + h := &Peer{} + + var trie *Trie + + insert := func(peer *Peer, a, b, c, d byte, cidr uint) { + trie = trie.Insert([]byte{a, b, c, d}, cidr, peer) + } + + assertEQ := func(peer *Peer, a, b, c, d byte) { + p := trie.Lookup([]byte{a, b, c, d}) + if p != peer { + t.Error("Assert EQ failed") + } + } + + assertNEQ := func(peer *Peer, a, b, c, d byte) { + p := trie.Lookup([]byte{a, b, c, d}) + if p == peer { + t.Error("Assert NEQ failed") + } + } + + insert(a, 192, 168, 4, 0, 24) + insert(b, 192, 168, 4, 4, 32) + insert(c, 192, 168, 0, 0, 16) + insert(d, 192, 95, 5, 64, 27) + insert(c, 192, 95, 5, 65, 27) + insert(e, 0, 0, 0, 0, 0) + insert(g, 64, 15, 112, 0, 20) + insert(h, 64, 15, 123, 211, 25) + insert(a, 10, 0, 0, 0, 25) + insert(b, 10, 0, 0, 128, 25) + insert(a, 10, 1, 0, 0, 30) + insert(b, 10, 1, 0, 4, 30) + insert(c, 10, 1, 0, 8, 29) + insert(d, 10, 1, 0, 16, 29) + + assertEQ(a, 192, 168, 4, 20) + assertEQ(a, 192, 168, 4, 0) + assertEQ(b, 192, 168, 4, 4) + assertEQ(c, 192, 168, 200, 182) + assertEQ(c, 192, 95, 5, 68) + assertEQ(e, 192, 95, 5, 96) + assertEQ(g, 64, 15, 116, 26) + assertEQ(g, 64, 15, 127, 3) + + insert(a, 1, 0, 0, 0, 32) + insert(a, 64, 0, 0, 0, 32) + insert(a, 128, 0, 0, 0, 32) + insert(a, 192, 0, 0, 0, 32) + insert(a, 255, 0, 0, 0, 32) + + assertEQ(a, 1, 0, 0, 0) + assertEQ(a, 64, 0, 0, 0) + assertEQ(a, 128, 0, 0, 0) + assertEQ(a, 192, 0, 0, 0) + assertEQ(a, 255, 0, 0, 0) + + trie = trie.RemovePeer(a) + + assertNEQ(a, 1, 0, 0, 0) + assertNEQ(a, 64, 0, 0, 0) + assertNEQ(a, 128, 0, 0, 0) + assertNEQ(a, 192, 0, 0, 0) + assertNEQ(a, 255, 0, 0, 0) + + trie = nil + + insert(a, 192, 168, 0, 0, 16) + insert(a, 192, 168, 0, 0, 24) + + trie = trie.RemovePeer(a) + + assertNEQ(a, 192, 168, 0, 1) +} + +/* Test ported from kernel implementation: + * selftest/routingtable.h + */ +func TestTrieIPv6(t *testing.T) { + a := &Peer{} + b := &Peer{} + c := &Peer{} + d := &Peer{} + e := &Peer{} + f := &Peer{} + g := &Peer{} + h := &Peer{} + + var trie *Trie + + expand := func(a uint32) []byte { + var out [4]byte + out[0] = byte(a >> 24 & 0xff) + out[1] = byte(a >> 16 & 0xff) + out[2] = byte(a >> 8 & 0xff) + out[3] = byte(a & 0xff) + return out[:] + } + + insert := func(peer *Peer, a, b, c, d uint32, cidr uint) { + var addr []byte + addr = append(addr, expand(a)...) + addr = append(addr, expand(b)...) + addr = append(addr, expand(c)...) + addr = append(addr, expand(d)...) + trie = trie.Insert(addr, cidr, peer) + } + + assertEQ := func(peer *Peer, a, b, c, d uint32) { + var addr []byte + addr = append(addr, expand(a)...) + addr = append(addr, expand(b)...) + addr = append(addr, expand(c)...) + addr = append(addr, expand(d)...) + p := trie.Lookup(addr) + if p != peer { + t.Error("Assert EQ failed") + } + } + + insert(d, 0x26075300, 0x60006b00, 0, 0xc05f0543, 128) + insert(c, 0x26075300, 0x60006b00, 0, 0, 64) + insert(e, 0, 0, 0, 0, 0) + insert(f, 0, 0, 0, 0, 0) + insert(g, 0x24046800, 0, 0, 0, 32) + insert(h, 0x24046800, 0x40040800, 0xdeadbeef, 0xdeadbeef, 64) + insert(a, 0x24046800, 0x40040800, 0xdeadbeef, 0xdeadbeef, 128) + insert(c, 0x24446800, 0x40e40800, 0xdeaebeef, 0xdefbeef, 128) + insert(b, 0x24446800, 0xf0e40800, 0xeeaebeef, 0, 98) + + assertEQ(d, 0x26075300, 0x60006b00, 0, 0xc05f0543) + assertEQ(c, 0x26075300, 0x60006b00, 0, 0xc02e01ee) + assertEQ(f, 0x26075300, 0x60006b01, 0, 0) + assertEQ(g, 0x24046800, 0x40040806, 0, 0x1006) + assertEQ(g, 0x24046800, 0x40040806, 0x1234, 0x5678) + assertEQ(f, 0x240467ff, 0x40040806, 0x1234, 0x5678) + assertEQ(f, 0x24046801, 0x40040806, 0x1234, 0x5678) + assertEQ(h, 0x24046800, 0x40040800, 0x1234, 0x5678) + assertEQ(h, 0x24046800, 0x40040800, 0, 0) + assertEQ(h, 0x24046800, 0x40040800, 0x10101010, 0x10101010) + assertEQ(a, 0x24046800, 0x40040800, 0xdeadbeef, 0xdeadbeef) +} diff --git a/tun.go b/tun.go new file mode 100644 index 0000000..6259f33 --- /dev/null +++ b/tun.go @@ -0,0 +1,58 @@ +package main + +import ( + "os" + "sync/atomic" +) + +const DefaultMTU = 1420 + +type TUNEvent int + +const ( + TUNEventUp = 1 << iota + TUNEventDown + TUNEventMTUUpdate +) + +type TUNDevice interface { + File() *os.File // returns the file descriptor of the device + Read([]byte, int) (int, error) // read a packet from the device (without any additional headers) + Write([]byte, int) (int, error) // writes a packet to the device (without any additional headers) + MTU() (int, error) // returns the MTU of the device + Name() string // returns the current name + Events() chan TUNEvent // returns a constant channel of events related to the device + Close() error // stops the device and closes the event channel +} + +func (device *Device) RoutineTUNEventReader() { + logInfo := device.log.Info + logError := device.log.Error + + for event := range device.tun.device.Events() { + if event&TUNEventMTUUpdate != 0 { + mtu, err := device.tun.device.MTU() + old := atomic.LoadInt32(&device.tun.mtu) + if err != nil { + logError.Println("Failed to load updated MTU of device:", err) + } else if int(old) != mtu { + if mtu+MessageTransportSize > MaxMessageSize { + logInfo.Println("MTU updated:", mtu, "(too large)") + } else { + logInfo.Println("MTU updated:", mtu) + } + atomic.StoreInt32(&device.tun.mtu, int32(mtu)) + } + } + + if event&TUNEventUp != 0 && !device.isUp.Get() { + logInfo.Println("Interface set up") + device.Up() + } + + if event&TUNEventDown != 0 && device.isUp.Get() { + logInfo.Println("Interface set down") + device.Down() + } + } +} diff --git a/tun_darwin.go b/tun_darwin.go new file mode 100644 index 0000000..87f6af6 --- /dev/null +++ b/tun_darwin.go @@ -0,0 +1,323 @@ +/* Copyright (c) 2016, Song Gao + * All rights reserved. + * + * Code from https://github.com/songgao/water + */ + +package main + +import ( + "encoding/binary" + "errors" + "fmt" + "golang.org/x/net/ipv4" + "golang.org/x/net/ipv6" + "golang.org/x/sys/unix" + "io" + "net" + "os" + "sync" + "time" + "unsafe" +) + +const utunControlName = "com.apple.net.utun_control" + +// _CTLIOCGINFO value derived from /usr/include/sys/{kern_control,ioccom}.h +const _CTLIOCGINFO = (0x40000000 | 0x80000000) | ((100 & 0x1fff) << 16) | uint32(byte('N'))<<8 | 3 + +// sockaddr_ctl specifeid in /usr/include/sys/kern_control.h +type sockaddrCtl struct { + scLen uint8 + scFamily uint8 + ssSysaddr uint16 + scID uint32 + scUnit uint32 + scReserved [5]uint32 +} + +// NativeTUN is a hack to work around the first 4 bytes "packet +// information" because there doesn't seem to be an IFF_NO_PI for darwin. +type NativeTUN struct { + name string + f io.ReadWriteCloser + mtu int + + rMu sync.Mutex + rBuf []byte + + wMu sync.Mutex + wBuf []byte + + events chan TUNEvent + errors chan error +} + +var sockaddrCtlSize uintptr = 32 + +func CreateTUN(name string) (ifce TUNDevice, err error) { + ifIndex := -1 + fmt.Sscanf(name, "utun%d", &ifIndex) + if ifIndex < 0 { + return nil, fmt.Errorf("error parsing interface name %s, must be utun[0-9]+", name) + } + + fd, err := unix.Socket(unix.AF_SYSTEM, unix.SOCK_DGRAM, 2) + + if err != nil { + return nil, fmt.Errorf("error in unix.Socket: %v", err) + } + + var ctlInfo = &struct { + ctlID uint32 + ctlName [96]byte + }{} + + copy(ctlInfo.ctlName[:], []byte(utunControlName)) + + _, _, errno := unix.Syscall( + unix.SYS_IOCTL, + uintptr(fd), + uintptr(_CTLIOCGINFO), + uintptr(unsafe.Pointer(ctlInfo)), + ) + + if errno != 0 { + err = errno + return nil, fmt.Errorf("error in unix.Syscall(unix.SYS_IOTL, ...): %v", err) + } + + sc := sockaddrCtl{ + scLen: uint8(sockaddrCtlSize), + scFamily: unix.AF_SYSTEM, + ssSysaddr: 2, + scID: ctlInfo.ctlID, + scUnit: uint32(ifIndex) + 1, + } + + scPointer := unsafe.Pointer(&sc) + + _, _, errno = unix.RawSyscall( + unix.SYS_CONNECT, + uintptr(fd), + uintptr(scPointer), + uintptr(sockaddrCtlSize), + ) + + if errno != 0 { + err = errno + return nil, fmt.Errorf("error in unix.RawSyscall(unix.SYS_CONNECT, ...): %v", err) + } + + // read (new) name of interface + + var ifName struct { + name [16]byte + } + ifNameSize := uintptr(16) + + _, _, errno = unix.Syscall6( + unix.SYS_GETSOCKOPT, + uintptr(fd), + 2, /* #define SYSPROTO_CONTROL 2 */ + 2, /* #define UTUN_OPT_IFNAME 2 */ + uintptr(unsafe.Pointer(&ifName)), + uintptr(unsafe.Pointer(&ifNameSize)), 0) + + if errno != 0 { + err = errno + return nil, fmt.Errorf("error in unix.Syscall6(unix.SYS_GETSOCKOPT, ...): %v", err) + } + + device := &NativeTUN{ + name: string(ifName.name[:ifNameSize-1 /* -1 is for \0 */]), + f: os.NewFile(uintptr(fd), string(ifName.name[:])), + mtu: 1500, + events: make(chan TUNEvent, 10), + errors: make(chan error, 1), + } + + // start listener + + go func(native *NativeTUN) { + // TODO: Fix this very niave implementation + var ( + statusUp bool + statusMTU int + ) + + for ; ; time.Sleep(time.Second) { + intr, err := net.InterfaceByName(device.name) + if err != nil { + native.errors <- err + return + } + + // Up / Down event + up := (intr.Flags & net.FlagUp) != 0 + if up != statusUp && up { + native.events <- TUNEventUp + } + if up != statusUp && !up { + native.events <- TUNEventDown + } + statusUp = up + + // MTU changes + if intr.MTU != statusMTU { + native.events <- TUNEventMTUUpdate + } + statusMTU = intr.MTU + } + }(device) + + // set default MTU + + err = device.setMTU(DefaultMTU) + + return device, err +} + +var _ io.ReadWriteCloser = (*NativeTUN)(nil) + +func (t *NativeTUN) Events() chan TUNEvent { + return t.events +} + +func (t *NativeTUN) Read(to []byte) (int, error) { + t.rMu.Lock() + defer t.rMu.Unlock() + + if cap(t.rBuf) < len(to)+4 { + t.rBuf = make([]byte, len(to)+4) + } + t.rBuf = t.rBuf[:len(to)+4] + + n, err := t.f.Read(t.rBuf) + copy(to, t.rBuf[4:]) + return n - 4, err +} + +func (t *NativeTUN) Write(from []byte) (int, error) { + + if len(from) == 0 { + return 0, unix.EIO + } + + t.wMu.Lock() + defer t.wMu.Unlock() + + if cap(t.wBuf) < len(from)+4 { + t.wBuf = make([]byte, len(from)+4) + } + t.wBuf = t.wBuf[:len(from)+4] + + // determine the IP Family for the NULL L2 Header + + ipVer := from[0] >> 4 + if ipVer == ipv4.Version { + t.wBuf[3] = unix.AF_INET + } else if ipVer == ipv6.Version { + t.wBuf[3] = unix.AF_INET6 + } else { + return 0, errors.New("Unable to determine IP version from packet.") + } + + copy(t.wBuf[4:], from) + + n, err := t.f.Write(t.wBuf) + return n - 4, err +} + +func (t *NativeTUN) Close() error { + + // lock to make sure no read/write is in process. + + t.rMu.Lock() + defer t.rMu.Unlock() + + t.wMu.Lock() + defer t.wMu.Unlock() + + return t.f.Close() +} + +func (t *NativeTUN) Name() string { + return t.name +} + +func (t *NativeTUN) setMTU(n int) error { + + // open datagram socket + + var fd int + + fd, err := unix.Socket( + unix.AF_INET, + unix.SOCK_DGRAM, + 0, + ) + + if err != nil { + return err + } + + defer unix.Close(fd) + + // do ioctl call + + var ifr [32]byte + copy(ifr[:], t.name) + binary.LittleEndian.PutUint32(ifr[16:20], uint32(n)) + _, _, errno := unix.Syscall( + unix.SYS_IOCTL, + uintptr(fd), + uintptr(unix.SIOCSIFMTU), + uintptr(unsafe.Pointer(&ifr[0])), + ) + + if errno != 0 { + return fmt.Errorf("Failed to set MTU on %s", t.name) + } + + return nil +} + +func (t *NativeTUN) MTU() (int, error) { + + // open datagram socket + + fd, err := unix.Socket( + unix.AF_INET, + unix.SOCK_DGRAM, + 0, + ) + + if err != nil { + return 0, err + } + + defer unix.Close(fd) + + // do ioctl call + + var ifr [64]byte + copy(ifr[:], t.name) + _, _, errno := unix.Syscall( + unix.SYS_IOCTL, + uintptr(fd), + uintptr(unix.SIOCGIFMTU), + uintptr(unsafe.Pointer(&ifr[0])), + ) + if errno != 0 { + return 0, fmt.Errorf("Failed to get MTU on %s", t.name) + } + + // convert result to signed 32-bit int + + val := binary.LittleEndian.Uint32(ifr[16:20]) + if val >= (1 << 31) { + return int(val-(1<<31)) - (1 << 31), nil + } + return int(val), nil +} diff --git a/tun_linux.go b/tun_linux.go new file mode 100644 index 0000000..daa2462 --- /dev/null +++ b/tun_linux.go @@ -0,0 +1,377 @@ +package main + +/* Implementation of the TUN device interface for linux + */ + +import ( + "encoding/binary" + "errors" + "fmt" + "golang.org/x/net/ipv6" + "golang.org/x/sys/unix" + "net" + "os" + "strings" + "time" + "unsafe" +) + +// #include +// #include +// #include +// #include +// #include +// #include +// +// /* Creates a netlink socket +// * listening to the RTMGRP_LINK multicast group +// */ +// +// int bind_rtmgrp() { +// int nl_sock = socket(AF_NETLINK, SOCK_RAW, NETLINK_ROUTE); +// if (nl_sock < 0) +// return -1; +// +// struct sockaddr_nl addr; +// memset ((void *) &addr, 0, sizeof (addr)); +// addr.nl_family = AF_NETLINK; +// addr.nl_pid = getpid (); +// addr.nl_groups = RTMGRP_LINK | RTMGRP_IPV4_IFADDR | RTMGRP_IPV6_IFADDR; +// +// if (bind(nl_sock, (struct sockaddr *) &addr, sizeof (addr)) < 0) +// return -1; +// +// return nl_sock; +// } +import "C" + +const ( + CloneDevicePath = "/dev/net/tun" + IFReqSize = unix.IFNAMSIZ + 64 +) + +type NativeTun struct { + fd *os.File + index int32 // if index + name string // name of interface + errors chan error // async error handling + events chan TUNEvent // device related events +} + +func (tun *NativeTun) File() *os.File { + return tun.fd +} + +func (tun *NativeTun) RoutineHackListener() { + /* This is needed for the detection to work across network namespaces + * If you are reading this and know a better method, please get in touch. + */ + fd := int(tun.fd.Fd()) + for { + _, err := unix.Write(fd, nil) + switch err { + case unix.EINVAL: + tun.events <- TUNEventUp + case unix.EIO: + tun.events <- TUNEventDown + default: + } + time.Sleep(time.Second / 10) + } +} + +func (tun *NativeTun) RoutineNetlinkListener() { + + sock := int(C.bind_rtmgrp()) + if sock < 0 { + tun.errors <- errors.New("Failed to create netlink event listener") + return + } + + for msg := make([]byte, 1<<16); ; { + + msgn, _, _, _, err := unix.Recvmsg(sock, msg[:], nil, 0) + if err != nil { + tun.errors <- fmt.Errorf("Failed to receive netlink message: %s", err.Error()) + return + } + + for remain := msg[:msgn]; len(remain) >= unix.SizeofNlMsghdr; { + + hdr := *(*unix.NlMsghdr)(unsafe.Pointer(&remain[0])) + + if int(hdr.Len) > len(remain) { + break + } + + switch hdr.Type { + case unix.NLMSG_DONE: + remain = []byte{} + + case unix.RTM_NEWLINK: + info := *(*unix.IfInfomsg)(unsafe.Pointer(&remain[unix.SizeofNlMsghdr])) + remain = remain[hdr.Len:] + + if info.Index != tun.index { + // not our interface + continue + } + + if info.Flags&unix.IFF_RUNNING != 0 { + tun.events <- TUNEventUp + } + + if info.Flags&unix.IFF_RUNNING == 0 { + tun.events <- TUNEventDown + } + + tun.events <- TUNEventMTUUpdate + + default: + remain = remain[hdr.Len:] + } + } + } +} + +func (tun *NativeTun) isUp() (bool, error) { + inter, err := net.InterfaceByName(tun.name) + return inter.Flags&net.FlagUp != 0, err +} + +func (tun *NativeTun) Name() string { + return tun.name +} + +func getDummySock() (int, error) { + return unix.Socket( + unix.AF_INET, + unix.SOCK_DGRAM, + 0, + ) +} + +func getIFIndex(name string) (int32, error) { + fd, err := getDummySock() + if err != nil { + return 0, err + } + + defer unix.Close(fd) + + var ifr [IFReqSize]byte + copy(ifr[:], name) + _, _, errno := unix.Syscall( + unix.SYS_IOCTL, + uintptr(fd), + uintptr(unix.SIOCGIFINDEX), + uintptr(unsafe.Pointer(&ifr[0])), + ) + + if errno != 0 { + return 0, errno + } + + index := binary.LittleEndian.Uint32(ifr[unix.IFNAMSIZ:]) + return toInt32(index), nil +} + +func (tun *NativeTun) setMTU(n int) error { + + // open datagram socket + + fd, err := unix.Socket( + unix.AF_INET, + unix.SOCK_DGRAM, + 0, + ) + + if err != nil { + return err + } + + defer unix.Close(fd) + + // do ioctl call + + var ifr [IFReqSize]byte + copy(ifr[:], tun.name) + binary.LittleEndian.PutUint32(ifr[16:20], uint32(n)) + _, _, errno := unix.Syscall( + unix.SYS_IOCTL, + uintptr(fd), + uintptr(unix.SIOCSIFMTU), + uintptr(unsafe.Pointer(&ifr[0])), + ) + + if errno != 0 { + return errors.New("Failed to set MTU of TUN device") + } + + return nil +} + +func (tun *NativeTun) MTU() (int, error) { + + // open datagram socket + + fd, err := unix.Socket( + unix.AF_INET, + unix.SOCK_DGRAM, + 0, + ) + + if err != nil { + return 0, err + } + + defer unix.Close(fd) + + // do ioctl call + + var ifr [IFReqSize]byte + copy(ifr[:], tun.name) + _, _, errno := unix.Syscall( + unix.SYS_IOCTL, + uintptr(fd), + uintptr(unix.SIOCGIFMTU), + uintptr(unsafe.Pointer(&ifr[0])), + ) + if errno != 0 { + return 0, errors.New("Failed to get MTU of TUN device") + } + + // convert result to signed 32-bit int + + val := binary.LittleEndian.Uint32(ifr[16:20]) + if val >= (1 << 31) { + return int(toInt32(val)), nil + } + return int(val), nil +} + +func (tun *NativeTun) Write(buff []byte, offset int) (int, error) { + + // reserve space for header + + buff = buff[offset-4:] + + // add packet information header + + buff[0] = 0x00 + buff[1] = 0x00 + + if buff[4] == ipv6.Version<<4 { + buff[2] = 0x86 + buff[3] = 0xdd + } else { + buff[2] = 0x08 + buff[3] = 0x00 + } + + // write + + return tun.fd.Write(buff) +} + +func (tun *NativeTun) Read(buff []byte, offset int) (int, error) { + select { + case err := <-tun.errors: + return 0, err + default: + buff := buff[offset-4:] + n, err := tun.fd.Read(buff[:]) + if n < 4 { + return 0, err + } + return n - 4, err + } +} + +func (tun *NativeTun) Events() chan TUNEvent { + return tun.events +} + +func (tun *NativeTun) Close() error { + return nil +} + +func CreateTUNFromFile(name string, fd *os.File) (TUNDevice, error) { + device := &NativeTun{ + fd: fd, + name: name, + events: make(chan TUNEvent, 5), + errors: make(chan error, 5), + } + + // start event listener + + var err error + device.index, err = getIFIndex(device.name) + if err != nil { + return nil, err + } + + go device.RoutineNetlinkListener() + go device.RoutineHackListener() // cross namespace + + // set default MTU + + return device, device.setMTU(DefaultMTU) +} + +func CreateTUN(name string) (TUNDevice, error) { + + // open clone device + + fd, err := os.OpenFile(CloneDevicePath, os.O_RDWR, 0) + if err != nil { + return nil, err + } + + // create new device + + var ifr [IFReqSize]byte + var flags uint16 = unix.IFF_TUN // | unix.IFF_NO_PI + nameBytes := []byte(name) + if len(nameBytes) >= unix.IFNAMSIZ { + return nil, errors.New("Interface name too long") + } + copy(ifr[:], nameBytes) + binary.LittleEndian.PutUint16(ifr[16:], flags) + + _, _, errno := unix.Syscall( + unix.SYS_IOCTL, + uintptr(fd.Fd()), + uintptr(unix.TUNSETIFF), + uintptr(unsafe.Pointer(&ifr[0])), + ) + if errno != 0 { + return nil, errno + } + + // read (new) name of interface + + newName := string(ifr[:]) + newName = newName[:strings.Index(newName, "\000")] + device := &NativeTun{ + fd: fd, + name: newName, + events: make(chan TUNEvent, 5), + errors: make(chan error, 5), + } + + // start event listener + + device.index, err = getIFIndex(device.name) + if err != nil { + return nil, err + } + + go device.RoutineNetlinkListener() + go device.RoutineHackListener() // cross namespace + + // set default MTU + + return device, device.setMTU(DefaultMTU) +} diff --git a/tun_windows.go b/tun_windows.go new file mode 100644 index 0000000..0711032 --- /dev/null +++ b/tun_windows.go @@ -0,0 +1,475 @@ +package main + +import ( + "encoding/binary" + "errors" + "fmt" + "golang.org/x/sys/windows" + "golang.org/x/sys/windows/registry" + "net" + "sync" + "syscall" + "time" + "unsafe" +) + +/* Relies on the OpenVPN TAP-Windows driver (NDIS 6 version) + * + * https://github.com/OpenVPN/tap-windows + */ + +type NativeTUN struct { + fd windows.Handle + rl sync.Mutex + wl sync.Mutex + ro *windows.Overlapped + wo *windows.Overlapped + events chan TUNEvent + name string +} + +const ( + METHOD_BUFFERED = 0 + ComponentID = "tap0901" // tap0801 +) + +func ctl_code(device_type, function, method, access uint32) uint32 { + return (device_type << 16) | (access << 14) | (function << 2) | method +} + +func TAP_CONTROL_CODE(request, method uint32) uint32 { + return ctl_code(file_device_unknown, request, method, 0) +} + +var ( + errIfceNameNotFound = errors.New("Failed to find the name of interface") + + TAP_IOCTL_GET_MAC = TAP_CONTROL_CODE(1, METHOD_BUFFERED) + TAP_IOCTL_GET_VERSION = TAP_CONTROL_CODE(2, METHOD_BUFFERED) + TAP_IOCTL_GET_MTU = TAP_CONTROL_CODE(3, METHOD_BUFFERED) + TAP_IOCTL_GET_INFO = TAP_CONTROL_CODE(4, METHOD_BUFFERED) + TAP_IOCTL_CONFIG_POINT_TO_POINT = TAP_CONTROL_CODE(5, METHOD_BUFFERED) + TAP_IOCTL_SET_MEDIA_STATUS = TAP_CONTROL_CODE(6, METHOD_BUFFERED) + TAP_IOCTL_CONFIG_DHCP_MASQ = TAP_CONTROL_CODE(7, METHOD_BUFFERED) + TAP_IOCTL_GET_LOG_LINE = TAP_CONTROL_CODE(8, METHOD_BUFFERED) + TAP_IOCTL_CONFIG_DHCP_SET_OPT = TAP_CONTROL_CODE(9, METHOD_BUFFERED) + TAP_IOCTL_CONFIG_TUN = TAP_CONTROL_CODE(10, METHOD_BUFFERED) + + file_device_unknown = uint32(0x00000022) + nCreateEvent, + nResetEvent, + nGetOverlappedResult uintptr +) + +func init() { + k32, err := windows.LoadLibrary("kernel32.dll") + if err != nil { + panic("LoadLibrary " + err.Error()) + } + defer windows.FreeLibrary(k32) + nCreateEvent = getProcAddr(k32, "CreateEventW") + nResetEvent = getProcAddr(k32, "ResetEvent") + nGetOverlappedResult = getProcAddr(k32, "GetOverlappedResult") +} + +/* implementation of the read/write/closer interface */ + +func getProcAddr(lib windows.Handle, name string) uintptr { + addr, err := windows.GetProcAddress(lib, name) + if err != nil { + panic(name + " " + err.Error()) + } + return addr +} + +func resetEvent(h windows.Handle) error { + r, _, err := syscall.Syscall(nResetEvent, 1, uintptr(h), 0, 0) + if r == 0 { + return err + } + return nil +} + +func getOverlappedResult(h windows.Handle, overlapped *windows.Overlapped) (int, error) { + var n int + r, _, err := syscall.Syscall6( + nGetOverlappedResult, + 4, + uintptr(h), + uintptr(unsafe.Pointer(overlapped)), + uintptr(unsafe.Pointer(&n)), 1, 0, 0) + + if r == 0 { + return n, err + } + return n, nil +} + +func newOverlapped() (*windows.Overlapped, error) { + var overlapped windows.Overlapped + r, _, err := syscall.Syscall6(nCreateEvent, 4, 0, 1, 0, 0, 0, 0) + if r == 0 { + return nil, err + } + overlapped.HEvent = windows.Handle(r) + return &overlapped, nil +} + +func (f *NativeTUN) Events() chan TUNEvent { + return f.events +} + +func (f *NativeTUN) Close() error { + return windows.Close(f.fd) +} + +func (f *NativeTUN) Write(b []byte) (int, error) { + f.wl.Lock() + defer f.wl.Unlock() + + if err := resetEvent(f.wo.HEvent); err != nil { + return 0, err + } + var n uint32 + err := windows.WriteFile(f.fd, b, &n, f.wo) + if err != nil && err != windows.ERROR_IO_PENDING { + return int(n), err + } + return getOverlappedResult(f.fd, f.wo) +} + +func (f *NativeTUN) Read(b []byte) (int, error) { + f.rl.Lock() + defer f.rl.Unlock() + + if err := resetEvent(f.ro.HEvent); err != nil { + return 0, err + } + var done uint32 + err := windows.ReadFile(f.fd, b, &done, f.ro) + if err != nil && err != windows.ERROR_IO_PENDING { + return int(done), err + } + return getOverlappedResult(f.fd, f.ro) +} + +func getdeviceid( + targetComponentId string, + targetDeviceName string, +) (deviceid string, err error) { + + getName := func(instanceId string) (string, error) { + path := fmt.Sprintf( + `SYSTEM\CurrentControlSet\Control\Network\{4D36E972-E325-11CE-BFC1-08002BE10318}\%s\Connection`, + instanceId, + ) + + key, err := registry.OpenKey( + registry.LOCAL_MACHINE, + path, + registry.READ, + ) + + if err != nil { + return "", err + } + defer key.Close() + + val, _, err := key.GetStringValue("Name") + key.Close() + return val, err + } + + getInstanceId := func(keyName string) (string, string, error) { + path := fmt.Sprintf( + `SYSTEM\CurrentControlSet\Control\Class\{4D36E972-E325-11CE-BFC1-08002BE10318}\%s`, + keyName, + ) + + key, err := registry.OpenKey( + registry.LOCAL_MACHINE, + path, + registry.READ, + ) + + if err != nil { + return "", "", err + } + defer key.Close() + + componentId, _, err := key.GetStringValue("ComponentId") + if err != nil { + return "", "", err + } + + instanceId, _, err := key.GetStringValue("NetCfgInstanceId") + + return componentId, instanceId, err + } + + // find list of all network devices + + k, err := registry.OpenKey( + registry.LOCAL_MACHINE, + `SYSTEM\CurrentControlSet\Control\Class\{4D36E972-E325-11CE-BFC1-08002BE10318}`, + registry.READ, + ) + + if err != nil { + return "", fmt.Errorf("Failed to open the adapter registry, TAP driver may be not installed, %v", err) + } + + defer k.Close() + + keys, err := k.ReadSubKeyNames(-1) + + if err != nil { + return "", err + } + + // look for matching component id and name + + var componentFound bool + + for _, v := range keys { + + componentId, instanceId, err := getInstanceId(v) + if err != nil || componentId != targetComponentId { + continue + } + + componentFound = true + + deviceName, err := getName(instanceId) + if err != nil || deviceName != targetDeviceName { + continue + } + + return instanceId, nil + } + + // provide a descriptive error message + + if componentFound { + return "", fmt.Errorf("Unable to find tun/tap device with name = %s", targetDeviceName) + } + + return "", fmt.Errorf( + "Unable to find device in registry with ComponentId = %s, is tap-windows installed?", + targetComponentId, + ) +} + +// setStatus is used to bring up or bring down the interface +func setStatus(fd windows.Handle, status bool) error { + var code [4]byte + if status { + binary.LittleEndian.PutUint32(code[:], 1) + } + + var bytesReturned uint32 + rdbbuf := make([]byte, windows.MAXIMUM_REPARSE_DATA_BUFFER_SIZE) + return windows.DeviceIoControl( + fd, + TAP_IOCTL_SET_MEDIA_STATUS, + &code[0], + uint32(4), + &rdbbuf[0], + uint32(len(rdbbuf)), + &bytesReturned, + nil, + ) +} + +/* When operating in TUN mode we must assign an ip address & subnet to the device. + * + */ +func setTUN(fd windows.Handle, network string) error { + var bytesReturned uint32 + rdbbuf := make([]byte, windows.MAXIMUM_REPARSE_DATA_BUFFER_SIZE) + localIP, remoteNet, err := net.ParseCIDR(network) + + if err != nil { + return fmt.Errorf("Failed to parse network CIDR in config, %v", err) + } + + if localIP.To4() == nil { + return fmt.Errorf("Provided network(%s) is not a valid IPv4 address", network) + } + + var param [12]byte + + copy(param[0:4], localIP.To4()) + copy(param[4:8], remoteNet.IP.To4()) + copy(param[8:12], remoteNet.Mask) + + return windows.DeviceIoControl( + fd, + TAP_IOCTL_CONFIG_TUN, + ¶m[0], + uint32(12), + &rdbbuf[0], + uint32(len(rdbbuf)), + &bytesReturned, + nil, + ) +} + +func (tun *NativeTUN) MTU() (int, error) { + var mtu [4]byte + var bytesReturned uint32 + err := windows.DeviceIoControl( + tun.fd, + TAP_IOCTL_GET_MTU, + &mtu[0], + uint32(len(mtu)), + &mtu[0], + uint32(len(mtu)), + &bytesReturned, + nil, + ) + val := binary.LittleEndian.Uint32(mtu[:]) + return int(val), err +} + +func (tun *NativeTUN) Name() string { + return tun.name +} + +func CreateTUN(name string) (TUNDevice, error) { + + // find the device in registry. + + deviceid, err := getdeviceid(ComponentID, name) + if err != nil { + return nil, err + } + path := "\\\\.\\Global\\" + deviceid + ".tap" + pathp, err := windows.UTF16PtrFromString(path) + if err != nil { + return nil, err + } + + // create TUN device + + handle, err := windows.CreateFile( + pathp, + windows.GENERIC_READ|windows.GENERIC_WRITE, + 0, + nil, + windows.OPEN_EXISTING, + windows.FILE_ATTRIBUTE_SYSTEM|windows.FILE_FLAG_OVERLAPPED, + 0, + ) + + if err != nil { + return nil, err + } + + ro, err := newOverlapped() + if err != nil { + windows.Close(handle) + return nil, err + } + + wo, err := newOverlapped() + if err != nil { + windows.Close(handle) + return nil, err + } + + tun := &NativeTUN{ + fd: handle, + name: name, + ro: ro, + wo: wo, + events: make(chan TUNEvent, 5), + } + + // find addresses of interface + // TODO: fix this hack, the question is how + + inter, err := net.InterfaceByName(name) + if err != nil { + windows.Close(handle) + return nil, err + } + + addrs, err := inter.Addrs() + if err != nil { + windows.Close(handle) + return nil, err + } + + var ip net.IP + for _, addr := range addrs { + ip = func() net.IP { + switch v := addr.(type) { + case *net.IPNet: + return v.IP.To4() + case *net.IPAddr: + return v.IP.To4() + } + return nil + }() + if ip != nil { + break + } + } + + if ip == nil { + windows.Close(handle) + return nil, errors.New("No IPv4 address found for interface") + } + + // bring up device. + + if err := setStatus(handle, true); err != nil { + windows.Close(handle) + return nil, err + } + + // set tun mode + + mask := ip.String() + "/0" + if err := setTUN(handle, mask); err != nil { + windows.Close(handle) + return nil, err + } + + // start listener + + go func(native *NativeTUN, ifname string) { + // TODO: Fix this very niave implementation + var ( + statusUp bool + statusMTU int + ) + + for ; ; time.Sleep(time.Second) { + intr, err := net.InterfaceByName(name) + if err != nil { + // TODO: handle + return + } + + // Up / Down event + up := (intr.Flags & net.FlagUp) != 0 + if up != statusUp && up { + native.events <- TUNEventUp + } + if up != statusUp && !up { + native.events <- TUNEventDown + } + statusUp = up + + // MTU changes + if intr.MTU != statusMTU { + native.events <- TUNEventMTUUpdate + } + statusMTU = intr.MTU + } + }(tun, name) + + return tun, nil +} diff --git a/uapi.go b/uapi.go new file mode 100644 index 0000000..caaa498 --- /dev/null +++ b/uapi.go @@ -0,0 +1,437 @@ +package main + +import ( + "bufio" + "fmt" + "io" + "net" + "strconv" + "strings" + "sync/atomic" + "time" +) + +type IPCError struct { + Code int64 +} + +func (s *IPCError) Error() string { + return fmt.Sprintf("IPC error: %d", s.Code) +} + +func (s *IPCError) ErrorCode() int64 { + return s.Code +} + +func ipcGetOperation(device *Device, socket *bufio.ReadWriter) *IPCError { + + device.log.Debug.Println("UAPI: Processing get operation") + + // create lines + + lines := make([]string, 0, 100) + send := func(line string) { + lines = append(lines, line) + } + + func() { + + // lock required resources + + device.net.mutex.RLock() + defer device.net.mutex.RUnlock() + + device.noise.mutex.RLock() + defer device.noise.mutex.RUnlock() + + device.routing.mutex.RLock() + defer device.routing.mutex.RUnlock() + + device.peers.mutex.Lock() + defer device.peers.mutex.Unlock() + + // serialize device related values + + if !device.noise.privateKey.IsZero() { + send("private_key=" + device.noise.privateKey.ToHex()) + } + + if device.net.port != 0 { + send(fmt.Sprintf("listen_port=%d", device.net.port)) + } + + if device.net.fwmark != 0 { + send(fmt.Sprintf("fwmark=%d", device.net.fwmark)) + } + + // serialize each peer state + + for _, peer := range device.peers.keyMap { + peer.mutex.RLock() + defer peer.mutex.RUnlock() + + send("public_key=" + peer.handshake.remoteStatic.ToHex()) + send("preshared_key=" + peer.handshake.presharedKey.ToHex()) + if peer.endpoint != nil { + send("endpoint=" + peer.endpoint.DstToString()) + } + + nano := atomic.LoadInt64(&peer.stats.lastHandshakeNano) + secs := nano / time.Second.Nanoseconds() + nano %= time.Second.Nanoseconds() + + send(fmt.Sprintf("last_handshake_time_sec=%d", secs)) + send(fmt.Sprintf("last_handshake_time_nsec=%d", nano)) + send(fmt.Sprintf("tx_bytes=%d", peer.stats.txBytes)) + send(fmt.Sprintf("rx_bytes=%d", peer.stats.rxBytes)) + send(fmt.Sprintf("persistent_keepalive_interval=%d", + atomic.LoadUint64(&peer.persistentKeepaliveInterval), + )) + + for _, ip := range device.routing.table.AllowedIPs(peer) { + send("allowed_ip=" + ip.String()) + } + + } + }() + + // send lines (does not require resource locks) + + for _, line := range lines { + _, err := socket.WriteString(line + "\n") + if err != nil { + return &IPCError{ + Code: ipcErrorIO, + } + } + } + + return nil +} + +func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError { + scanner := bufio.NewScanner(socket) + logError := device.log.Error + logDebug := device.log.Debug + + var peer *Peer + + dummy := false + deviceConfig := true + + for scanner.Scan() { + + // parse line + + line := scanner.Text() + if line == "" { + return nil + } + parts := strings.Split(line, "=") + if len(parts) != 2 { + return &IPCError{Code: ipcErrorProtocol} + } + key := parts[0] + value := parts[1] + + /* device configuration */ + + if deviceConfig { + + switch key { + case "private_key": + var sk NoisePrivateKey + err := sk.FromHex(value) + if err != nil { + logError.Println("Failed to set private_key:", err) + return &IPCError{Code: ipcErrorInvalid} + } + logDebug.Println("UAPI: Updating device private key") + device.SetPrivateKey(sk) + + case "listen_port": + + // parse port number + + port, err := strconv.ParseUint(value, 10, 16) + if err != nil { + logError.Println("Failed to parse listen_port:", err) + return &IPCError{Code: ipcErrorInvalid} + } + + // update port and rebind + + logDebug.Println("UAPI: Updating listen port") + + device.net.mutex.Lock() + device.net.port = uint16(port) + device.net.mutex.Unlock() + + if err := device.BindUpdate(); err != nil { + logError.Println("Failed to set listen_port:", err) + return &IPCError{Code: ipcErrorPortInUse} + } + + case "fwmark": + + // parse fwmark field + + fwmark, err := func() (uint32, error) { + if value == "" { + return 0, nil + } + mark, err := strconv.ParseUint(value, 10, 32) + return uint32(mark), err + }() + + if err != nil { + logError.Println("Invalid fwmark", err) + return &IPCError{Code: ipcErrorInvalid} + } + + logDebug.Println("UAPI: Updating fwmark") + + device.net.mutex.Lock() + device.net.fwmark = uint32(fwmark) + device.net.mutex.Unlock() + + if err := device.BindUpdate(); err != nil { + logError.Println("Failed to update fwmark:", err) + return &IPCError{Code: ipcErrorPortInUse} + } + + case "public_key": + // switch to peer configuration + logDebug.Println("UAPI: Transition to peer configuration") + deviceConfig = false + + case "replace_peers": + if value != "true" { + logError.Println("Failed to set replace_peers, invalid value:", value) + return &IPCError{Code: ipcErrorInvalid} + } + logDebug.Println("UAPI: Removing all peers") + device.RemoveAllPeers() + + default: + logError.Println("Invalid UAPI key (device configuration):", key) + return &IPCError{Code: ipcErrorInvalid} + } + } + + /* peer configuration */ + + if !deviceConfig { + + switch key { + + case "public_key": + var publicKey NoisePublicKey + err := publicKey.FromHex(value) + if err != nil { + logError.Println("Failed to get peer by public_key:", err) + return &IPCError{Code: ipcErrorInvalid} + } + + // ignore peer with public key of device + + device.noise.mutex.RLock() + equals := device.noise.publicKey.Equals(publicKey) + device.noise.mutex.RUnlock() + + if equals { + peer = &Peer{} + dummy = true + } + + // find peer referenced + + peer = device.LookupPeer(publicKey) + + if peer == nil { + peer, err = device.NewPeer(publicKey) + if err != nil { + logError.Println("Failed to create new peer:", err) + return &IPCError{Code: ipcErrorInvalid} + } + logDebug.Println("UAPI: Created new peer:", peer.String()) + } + + peer.mutex.Lock() + peer.timer.handshakeDeadline.Reset(RekeyAttemptTime) + peer.mutex.Unlock() + + case "remove": + + // remove currently selected peer from device + + if value != "true" { + logError.Println("Failed to set remove, invalid value:", value) + return &IPCError{Code: ipcErrorInvalid} + } + if !dummy { + logDebug.Println("UAPI: Removing peer:", peer.String()) + device.RemovePeer(peer.handshake.remoteStatic) + } + peer = &Peer{} + dummy = true + + case "preshared_key": + + // update PSK + + logDebug.Println("UAPI: Updating pre-shared key for peer:", peer.String()) + + peer.handshake.mutex.Lock() + err := peer.handshake.presharedKey.FromHex(value) + peer.handshake.mutex.Unlock() + + if err != nil { + logError.Println("Failed to set preshared_key:", err) + return &IPCError{Code: ipcErrorInvalid} + } + + case "endpoint": + + // set endpoint destination + + logDebug.Println("UAPI: Updating endpoint for peer:", peer.String()) + + err := func() error { + peer.mutex.Lock() + defer peer.mutex.Unlock() + endpoint, err := CreateEndpoint(value) + if err != nil { + return err + } + peer.endpoint = endpoint + peer.timer.handshakeDeadline.Reset(RekeyAttemptTime) + return nil + }() + + if err != nil { + logError.Println("Failed to set endpoint:", value) + return &IPCError{Code: ipcErrorInvalid} + } + + case "persistent_keepalive_interval": + + // update keep-alive interval + + logDebug.Println("UAPI: Updating persistent_keepalive_interval for peer:", peer.String()) + + secs, err := strconv.ParseUint(value, 10, 16) + if err != nil { + logError.Println("Failed to set persistent_keepalive_interval:", err) + return &IPCError{Code: ipcErrorInvalid} + } + + old := atomic.SwapUint64( + &peer.persistentKeepaliveInterval, + secs, + ) + + // send immediate keep-alive + + if old == 0 && secs != 0 { + if err != nil { + logError.Println("Failed to get tun device status:", err) + return &IPCError{Code: ipcErrorIO} + } + if device.isUp.Get() && !dummy { + peer.SendKeepAlive() + } + } + + case "replace_allowed_ips": + + logDebug.Println("UAPI: Removing all allowed IPs for peer:", peer.String()) + + if value != "true" { + logError.Println("Failed to set replace_allowed_ips, invalid value:", value) + return &IPCError{Code: ipcErrorInvalid} + } + + if dummy { + continue + } + + device.routing.mutex.Lock() + device.routing.table.RemovePeer(peer) + device.routing.mutex.Unlock() + + case "allowed_ip": + + logDebug.Println("UAPI: Adding allowed_ip to peer:", peer.String()) + + _, network, err := net.ParseCIDR(value) + if err != nil { + logError.Println("Failed to set allowed_ip:", err) + return &IPCError{Code: ipcErrorInvalid} + } + + if dummy { + continue + } + + ones, _ := network.Mask.Size() + device.routing.mutex.Lock() + device.routing.table.Insert(network.IP, uint(ones), peer) + device.routing.mutex.Unlock() + + default: + logError.Println("Invalid UAPI key (peer configuration):", key) + return &IPCError{Code: ipcErrorInvalid} + } + } + } + + return nil +} + +func ipcHandle(device *Device, socket net.Conn) { + + // create buffered read/writer + + defer socket.Close() + + buffered := func(s io.ReadWriter) *bufio.ReadWriter { + reader := bufio.NewReader(s) + writer := bufio.NewWriter(s) + return bufio.NewReadWriter(reader, writer) + }(socket) + + defer buffered.Flush() + + op, err := buffered.ReadString('\n') + if err != nil { + return + } + + // handle operation + + var status *IPCError + + switch op { + case "set=1\n": + device.log.Debug.Println("Config, set operation") + status = ipcSetOperation(device, buffered) + + case "get=1\n": + device.log.Debug.Println("Config, get operation") + status = ipcGetOperation(device, buffered) + + default: + device.log.Error.Println("Invalid UAPI operation:", op) + return + } + + // write status + + if status != nil { + device.log.Error.Println(status) + fmt.Fprintf(buffered, "errno=%d\n\n", status.ErrorCode()) + } else { + fmt.Fprintf(buffered, "errno=0\n\n") + } +} diff --git a/uapi_darwin.go b/uapi_darwin.go new file mode 100644 index 0000000..63d4d8d --- /dev/null +++ b/uapi_darwin.go @@ -0,0 +1,99 @@ +package main + +import ( + "fmt" + "golang.org/x/sys/unix" + "net" + "os" + "path" + "time" +) + +const ( + ipcErrorIO = -int64(unix.EIO) + ipcErrorProtocol = -int64(unix.EPROTO) + ipcErrorInvalid = -int64(unix.EINVAL) + ipcErrorPortInUse = -int64(unix.EADDRINUSE) + socketDirectory = "/var/run/wireguard" + socketName = "%s.sock" +) + +type UAPIListener struct { + listener net.Listener // unix socket listener + connNew chan net.Conn + connErr chan error +} + +func (l *UAPIListener) Accept() (net.Conn, error) { + for { + select { + case conn := <-l.connNew: + return conn, nil + + case err := <-l.connErr: + return nil, err + } + } +} + +func (l *UAPIListener) Close() error { + return l.listener.Close() +} + +func (l *UAPIListener) Addr() net.Addr { + return nil +} + +func NewUAPIListener(name string) (net.Listener, error) { + + // check if path exist + + err := os.MkdirAll(socketDirectory, 077) + if err != nil && !os.IsExist(err) { + return nil, err + } + + // open UNIX socket + + socketPath := path.Join( + socketDirectory, + fmt.Sprintf(socketName, name), + ) + + listener, err := net.Listen("unix", socketPath) + if err != nil { + return nil, err + } + + uapi := &UAPIListener{ + listener: listener, + connNew: make(chan net.Conn, 1), + connErr: make(chan error, 1), + } + + // watch for deletion of socket + + go func(l *UAPIListener) { + for ; ; time.Sleep(time.Second) { + if _, err := os.Stat(socketPath); os.IsNotExist(err) { + l.connErr <- err + return + } + } + }(uapi) + + // watch for new connections + + go func(l *UAPIListener) { + for { + conn, err := l.listener.Accept() + if err != nil { + l.connErr <- err + break + } + l.connNew <- conn + } + }(uapi) + + return uapi, nil +} diff --git a/uapi_linux.go b/uapi_linux.go new file mode 100644 index 0000000..f97a18a --- /dev/null +++ b/uapi_linux.go @@ -0,0 +1,171 @@ +package main + +import ( + "errors" + "fmt" + "golang.org/x/sys/unix" + "net" + "os" + "path" +) + +const ( + ipcErrorIO = -int64(unix.EIO) + ipcErrorProtocol = -int64(unix.EPROTO) + ipcErrorInvalid = -int64(unix.EINVAL) + ipcErrorPortInUse = -int64(unix.EADDRINUSE) + socketDirectory = "/var/run/wireguard" + socketName = "%s.sock" +) + +type UAPIListener struct { + listener net.Listener // unix socket listener + connNew chan net.Conn + connErr chan error + inotifyFd int +} + +func (l *UAPIListener) Accept() (net.Conn, error) { + for { + select { + case conn := <-l.connNew: + return conn, nil + + case err := <-l.connErr: + return nil, err + } + } +} + +func (l *UAPIListener) Close() error { + err1 := unix.Close(l.inotifyFd) + err2 := l.listener.Close() + if err1 != nil { + return err1 + } + return err2 +} + +func (l *UAPIListener) Addr() net.Addr { + return nil +} + +func UAPIListen(name string, file *os.File) (net.Listener, error) { + + // wrap file in listener + + listener, err := net.FileListener(file) + if err != nil { + return nil, err + } + + uapi := &UAPIListener{ + listener: listener, + connNew: make(chan net.Conn, 1), + connErr: make(chan error, 1), + } + + // watch for deletion of socket + + socketPath := path.Join( + socketDirectory, + fmt.Sprintf(socketName, name), + ) + + uapi.inotifyFd, err = unix.InotifyInit() + if err != nil { + return nil, err + } + + _, err = unix.InotifyAddWatch( + uapi.inotifyFd, + socketPath, + unix.IN_ATTRIB| + unix.IN_DELETE| + unix.IN_DELETE_SELF, + ) + + if err != nil { + return nil, err + } + + go func(l *UAPIListener) { + var buff [4096]byte + for { + // start with lstat to avoid race condition + if _, err := os.Lstat(socketPath); os.IsNotExist(err) { + l.connErr <- err + return + } + unix.Read(uapi.inotifyFd, buff[:]) + } + }(uapi) + + // watch for new connections + + go func(l *UAPIListener) { + for { + conn, err := l.listener.Accept() + if err != nil { + l.connErr <- err + break + } + l.connNew <- conn + } + }(uapi) + + return uapi, nil +} + +func UAPIOpen(name string) (*os.File, error) { + + // check if path exist + + err := os.MkdirAll(socketDirectory, 0600) + if err != nil && !os.IsExist(err) { + return nil, err + } + + // open UNIX socket + + socketPath := path.Join( + socketDirectory, + fmt.Sprintf(socketName, name), + ) + + addr, err := net.ResolveUnixAddr("unix", socketPath) + if err != nil { + return nil, err + } + + listener, err := func() (*net.UnixListener, error) { + + // initial connection attempt + + listener, err := net.ListenUnix("unix", addr) + if err == nil { + return listener, nil + } + + // check if socket already active + + _, err = net.Dial("unix", socketPath) + if err == nil { + return nil, errors.New("unix socket in use") + } + + // cleanup & attempt again + + err = os.Remove(socketPath) + if err != nil { + return nil, err + } + return net.ListenUnix("unix", addr) + }() + + if err != nil { + return nil, err + } + + return listener.File() +} diff --git a/uapi_windows.go b/uapi_windows.go new file mode 100644 index 0000000..a4599a5 --- /dev/null +++ b/uapi_windows.go @@ -0,0 +1,44 @@ +package main + +/* UAPI on windows uses a bidirectional named pipe + */ + +import ( + "fmt" + "github.com/Microsoft/go-winio" + "golang.org/x/sys/windows" + "net" +) + +const ( + ipcErrorIO = -int64(windows.ERROR_BROKEN_PIPE) + ipcErrorProtocol = -int64(windows.ERROR_INVALID_NAME) + ipcErrorInvalid = -int64(windows.ERROR_INVALID_PARAMETER) + ipcErrorPortInUse = -int64(windows.ERROR_ALREADY_EXISTS) +) + +const PipeNameFmt = "\\\\.\\pipe\\wireguard-ipc-%s" + +type UAPIListener struct { + listener net.Listener +} + +func (uapi *UAPIListener) Accept() (net.Conn, error) { + return nil, nil +} + +func (uapi *UAPIListener) Close() error { + return uapi.listener.Close() +} + +func (uapi *UAPIListener) Addr() net.Addr { + return nil +} + +func NewUAPIListener(name string) (net.Listener, error) { + path := fmt.Sprintf(PipeNameFmt, name) + return winio.ListenPipe(path, &winio.PipeConfig{ + InputBufferSize: 2048, + OutputBufferSize: 2048, + }) +} diff --git a/xchacha20.go b/xchacha20.go new file mode 100644 index 0000000..5d963e0 --- /dev/null +++ b/xchacha20.go @@ -0,0 +1,169 @@ +// Copyright (c) 2016 Andreas Auernhammer. All rights reserved. +// Use of this source code is governed by a license that can be +// found in the LICENSE file. + +package main + +import ( + "encoding/binary" + "golang.org/x/crypto/chacha20poly1305" +) + +func HChaCha20(out *[32]byte, nonce []byte, key *[32]byte) { + + v00 := uint32(0x61707865) + v01 := uint32(0x3320646e) + v02 := uint32(0x79622d32) + v03 := uint32(0x6b206574) + + v04 := binary.LittleEndian.Uint32(key[0:]) + v05 := binary.LittleEndian.Uint32(key[4:]) + v06 := binary.LittleEndian.Uint32(key[8:]) + v07 := binary.LittleEndian.Uint32(key[12:]) + v08 := binary.LittleEndian.Uint32(key[16:]) + v09 := binary.LittleEndian.Uint32(key[20:]) + v10 := binary.LittleEndian.Uint32(key[24:]) + v11 := binary.LittleEndian.Uint32(key[28:]) + v12 := binary.LittleEndian.Uint32(nonce[0:]) + v13 := binary.LittleEndian.Uint32(nonce[4:]) + v14 := binary.LittleEndian.Uint32(nonce[8:]) + v15 := binary.LittleEndian.Uint32(nonce[12:]) + + for i := 0; i < 20; i += 2 { + v00 += v04 + v12 ^= v00 + v12 = (v12 << 16) | (v12 >> 16) + v08 += v12 + v04 ^= v08 + v04 = (v04 << 12) | (v04 >> 20) + v00 += v04 + v12 ^= v00 + v12 = (v12 << 8) | (v12 >> 24) + v08 += v12 + v04 ^= v08 + v04 = (v04 << 7) | (v04 >> 25) + v01 += v05 + v13 ^= v01 + v13 = (v13 << 16) | (v13 >> 16) + v09 += v13 + v05 ^= v09 + v05 = (v05 << 12) | (v05 >> 20) + v01 += v05 + v13 ^= v01 + v13 = (v13 << 8) | (v13 >> 24) + v09 += v13 + v05 ^= v09 + v05 = (v05 << 7) | (v05 >> 25) + v02 += v06 + v14 ^= v02 + v14 = (v14 << 16) | (v14 >> 16) + v10 += v14 + v06 ^= v10 + v06 = (v06 << 12) | (v06 >> 20) + v02 += v06 + v14 ^= v02 + v14 = (v14 << 8) | (v14 >> 24) + v10 += v14 + v06 ^= v10 + v06 = (v06 << 7) | (v06 >> 25) + v03 += v07 + v15 ^= v03 + v15 = (v15 << 16) | (v15 >> 16) + v11 += v15 + v07 ^= v11 + v07 = (v07 << 12) | (v07 >> 20) + v03 += v07 + v15 ^= v03 + v15 = (v15 << 8) | (v15 >> 24) + v11 += v15 + v07 ^= v11 + v07 = (v07 << 7) | (v07 >> 25) + v00 += v05 + v15 ^= v00 + v15 = (v15 << 16) | (v15 >> 16) + v10 += v15 + v05 ^= v10 + v05 = (v05 << 12) | (v05 >> 20) + v00 += v05 + v15 ^= v00 + v15 = (v15 << 8) | (v15 >> 24) + v10 += v15 + v05 ^= v10 + v05 = (v05 << 7) | (v05 >> 25) + v01 += v06 + v12 ^= v01 + v12 = (v12 << 16) | (v12 >> 16) + v11 += v12 + v06 ^= v11 + v06 = (v06 << 12) | (v06 >> 20) + v01 += v06 + v12 ^= v01 + v12 = (v12 << 8) | (v12 >> 24) + v11 += v12 + v06 ^= v11 + v06 = (v06 << 7) | (v06 >> 25) + v02 += v07 + v13 ^= v02 + v13 = (v13 << 16) | (v13 >> 16) + v08 += v13 + v07 ^= v08 + v07 = (v07 << 12) | (v07 >> 20) + v02 += v07 + v13 ^= v02 + v13 = (v13 << 8) | (v13 >> 24) + v08 += v13 + v07 ^= v08 + v07 = (v07 << 7) | (v07 >> 25) + v03 += v04 + v14 ^= v03 + v14 = (v14 << 16) | (v14 >> 16) + v09 += v14 + v04 ^= v09 + v04 = (v04 << 12) | (v04 >> 20) + v03 += v04 + v14 ^= v03 + v14 = (v14 << 8) | (v14 >> 24) + v09 += v14 + v04 ^= v09 + v04 = (v04 << 7) | (v04 >> 25) + } + + binary.LittleEndian.PutUint32(out[0:], v00) + binary.LittleEndian.PutUint32(out[4:], v01) + binary.LittleEndian.PutUint32(out[8:], v02) + binary.LittleEndian.PutUint32(out[12:], v03) + binary.LittleEndian.PutUint32(out[16:], v12) + binary.LittleEndian.PutUint32(out[20:], v13) + binary.LittleEndian.PutUint32(out[24:], v14) + binary.LittleEndian.PutUint32(out[28:], v15) +} + +func XChaCha20Poly1305Encrypt( + dst []byte, + nonceFull *[24]byte, + plaintext []byte, + additionalData []byte, + key *[chacha20poly1305.KeySize]byte, +) []byte { + var nonce [chacha20poly1305.NonceSize]byte + var derivedKey [chacha20poly1305.KeySize]byte + HChaCha20(&derivedKey, nonceFull[:16], key) + aead, _ := chacha20poly1305.New(derivedKey[:]) + copy(nonce[4:], nonceFull[16:]) + return aead.Seal(dst, nonce[:], plaintext, additionalData) +} + +func XChaCha20Poly1305Decrypt( + dst []byte, + nonceFull *[24]byte, + plaintext []byte, + additionalData []byte, + key *[chacha20poly1305.KeySize]byte, +) ([]byte, error) { + var nonce [chacha20poly1305.NonceSize]byte + var derivedKey [chacha20poly1305.KeySize]byte + HChaCha20(&derivedKey, nonceFull[:16], key) + aead, _ := chacha20poly1305.New(derivedKey[:]) + copy(nonce[4:], nonceFull[16:]) + return aead.Open(dst, nonce[:], plaintext, additionalData) +} diff --git a/xchacha20_test.go b/xchacha20_test.go new file mode 100644 index 0000000..0f41cf8 --- /dev/null +++ b/xchacha20_test.go @@ -0,0 +1,96 @@ +package main + +import ( + "encoding/hex" + "testing" +) + +type XChaCha20Test struct { + Nonce string + Key string + PT string + CT string +} + +func TestXChaCha20(t *testing.T) { + + tests := []XChaCha20Test{ + { + Nonce: "000000000000000000000000000000000000000000000000", + Key: "0000000000000000000000000000000000000000000000000000000000000000", + PT: "00000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000", + CT: "789e9689e5208d7fd9e1f3c5b5341f48ef18a13e418998addadd97a3693a987f8e82ecd5c1433bfed1af49750c0f1ff29c4174a05b119aa3a9e8333812e0c0feb1299c5949d895ee01dbf50f8395dd84", + }, + { + Nonce: "0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f", + Key: "0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f", + PT: "0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f", + CT: "e1a046aa7f71e2af8b80b6408b2fd8d3a350278cde79c94d9efaa475e1339b3dd490127b", + }, + { + Nonce: "d9a8213e8a697508805c2c171ad54487ead9e3e02d82d5bc", + Key: "979196dbd78526f2f584f7534db3f5824d8ccfa858ca7e09bdd3656ecd36033c", + PT: "43cc6d624e451bbed952c3e071dc6c03392ce11eb14316a94b2fdc98b22fedea", + CT: "53c1e8bef2dbb8f2505ec010a7afe21d5a8e6dd8f987e4ea1a2ed5dfbc844ea400db34496fd2153526c6e87c36694200", + }, + } + + for _, test := range tests { + + nonce, err := hex.DecodeString(test.Nonce) + if err != nil { + panic(err) + } + + key, err := hex.DecodeString(test.Key) + if err != nil { + panic(err) + } + + pt, err := hex.DecodeString(test.PT) + if err != nil { + panic(err) + } + + func() { + var nonceArray [24]byte + var keyArray [32]byte + copy(nonceArray[:], nonce) + copy(keyArray[:], key) + + // test encryption + + ct := XChaCha20Poly1305Encrypt( + nil, + &nonceArray, + pt, + nil, + &keyArray, + ) + ctHex := hex.EncodeToString(ct) + if ctHex != test.CT { + t.Fatal("encryption failed, expected:", test.CT, "got", ctHex) + } + + // test decryption + + ptp, err := XChaCha20Poly1305Decrypt( + nil, + &nonceArray, + ct, + nil, + &keyArray, + ) + if err != nil { + t.Fatal(err) + } + + ptHex := hex.EncodeToString(ptp) + if ptHex != test.PT { + t.Fatal("decryption failed, expected:", test.PT, "got", ptHex) + } + }() + + } + +} -- cgit v1.2.3