diff options
author | Mathias Hall-Andersen <mathias@hall-andersen.dk> | 2018-02-04 16:08:26 +0100 |
---|---|---|
committer | Mathias Hall-Andersen <mathias@hall-andersen.dk> | 2018-02-04 16:08:26 +0100 |
commit | a0f54cbe5ac2cd8b8296c2c57c30029dd349cff0 (patch) | |
tree | 64574090d79ff3899c5c18e5268e450028e4656b /src | |
parent | 5871ec04deb8f4715cab37146940baa35c08cbee (diff) |
Align with go library layout
Diffstat (limited to 'src')
50 files changed, 0 insertions, 8845 deletions
diff --git a/src/Makefile b/src/Makefile deleted file mode 100644 index 5b23ecc..0000000 --- a/src/Makefile +++ /dev/null @@ -1,12 +0,0 @@ -all: wireguard-go - -wireguard-go: $(wildcard *.go) - go build -o $@ - -clean: - rm -f wireguard-go - -cloc: - cloc $(filter-out xchacha20.go $(wildcard *_test.go), $(wildcard *.go)) - -.PHONY: clean cloc diff --git a/src/build.cmd b/src/build.cmd deleted file mode 100755 index 52cb883..0000000 --- a/src/build.cmd +++ /dev/null @@ -1,6 +0,0 @@ -@echo off - -REM builds wireguard for windows - -go get -go build -o wireguard-go.exe diff --git a/src/conn.go b/src/conn.go deleted file mode 100644 index fb30ec2..0000000 --- a/src/conn.go +++ /dev/null @@ -1,128 +0,0 @@ -package main - -import ( - "errors" - "golang.org/x/net/ipv4" - "golang.org/x/net/ipv6" - "net" -) - -/* A Bind handles listening on a port for both IPv6 and IPv4 UDP traffic - */ -type Bind interface { - SetMark(value uint32) error - ReceiveIPv6(buff []byte) (int, Endpoint, error) - ReceiveIPv4(buff []byte) (int, Endpoint, error) - Send(buff []byte, end Endpoint) error - Close() error -} - -/* An Endpoint maintains the source/destination caching for a peer - * - * dst : the remote address of a peer ("endpoint" in uapi terminology) - * src : the local address from which datagrams originate going to the peer - */ -type Endpoint interface { - ClearSrc() // clears the source address - SrcToString() string // returns the local source address (ip:port) - DstToString() string // returns the destination address (ip:port) - DstToBytes() []byte // used for mac2 cookie calculations - DstIP() net.IP - SrcIP() net.IP -} - -func parseEndpoint(s string) (*net.UDPAddr, error) { - - // ensure that the host is an IP address - - host, _, err := net.SplitHostPort(s) - if err != nil { - return nil, err - } - if ip := net.ParseIP(host); ip == nil { - return nil, errors.New("Failed to parse IP address: " + host) - } - - // parse address and port - - addr, err := net.ResolveUDPAddr("udp", s) - if err != nil { - return nil, err - } - return addr, err -} - -/* Must hold device and net lock - */ -func unsafeCloseBind(device *Device) error { - var err error - netc := &device.net - if netc.bind != nil { - err = netc.bind.Close() - netc.bind = nil - } - return err -} - -func (device *Device) BindUpdate() error { - - device.net.mutex.Lock() - defer device.net.mutex.Unlock() - - device.peers.mutex.Lock() - defer device.peers.mutex.Unlock() - - // close existing sockets - - if err := unsafeCloseBind(device); err != nil { - return err - } - - // open new sockets - - if device.isUp.Get() { - - // bind to new port - - var err error - netc := &device.net - netc.bind, netc.port, err = CreateBind(netc.port) - if err != nil { - netc.bind = nil - return err - } - - // set mark - - err = netc.bind.SetMark(netc.fwmark) - if err != nil { - return err - } - - // clear cached source addresses - - for _, peer := range device.peers.keyMap { - peer.mutex.Lock() - defer peer.mutex.Unlock() - if peer.endpoint != nil { - peer.endpoint.ClearSrc() - } - } - - // start receiving routines - - go device.RoutineReceiveIncoming(ipv4.Version, netc.bind) - go device.RoutineReceiveIncoming(ipv6.Version, netc.bind) - - device.log.Debug.Println("UDP bind has been updated") - } - - return nil -} - -func (device *Device) BindClose() error { - device.net.mutex.Lock() - err := unsafeCloseBind(device) - device.net.mutex.Unlock() - return err -} diff --git a/src/conn_default.go b/src/conn_default.go deleted file mode 100644 index 5b73c90..0000000 --- a/src/conn_default.go +++ /dev/null @@ -1,131 +0,0 @@ -// +build !linux - -package main - -import ( - "net" -) - -/* This code is meant to be a temporary solution - * on platforms for which the sticky socket / source caching behavior - * has not yet been implemented. - * - * See conn_linux.go for an implementation on the linux platform. - */ - -type NativeBind struct { - ipv4 *net.UDPConn - ipv6 *net.UDPConn -} - -type NativeEndpoint net.UDPAddr - -var _ Bind = (*NativeBind)(nil) -var _ Endpoint = (*NativeEndpoint)(nil) - -func CreateEndpoint(s string) (Endpoint, error) { - addr, err := parseEndpoint(s) - return (*NativeEndpoint)(addr), err -} - -func (_ *NativeEndpoint) ClearSrc() {} - -func (e *NativeEndpoint) DstIP() net.IP { - return (*net.UDPAddr)(e).IP -} - -func (e *NativeEndpoint) SrcIP() net.IP { - return nil // not supported -} - -func (e *NativeEndpoint) DstToBytes() []byte { - addr := (*net.UDPAddr)(e) - out := addr.IP - out = append(out, byte(addr.Port&0xff)) - out = append(out, byte((addr.Port>>8)&0xff)) - return out -} - -func (e *NativeEndpoint) DstToString() string { - return (*net.UDPAddr)(e).String() -} - -func (e *NativeEndpoint) SrcToString() string { - return "" -} - -func listenNet(network string, port int) (*net.UDPConn, int, error) { - - // listen - - conn, err := net.ListenUDP(network, &net.UDPAddr{Port: port}) - if err != nil { - return nil, 0, err - } - - // retrieve port - - laddr := conn.LocalAddr() - uaddr, err := net.ResolveUDPAddr( - laddr.Network(), - laddr.String(), - ) - if err != nil { - return nil, 0, err - } - return conn, uaddr.Port, nil -} - -func CreateBind(uport uint16) (Bind, uint16, error) { - var err error - var bind NativeBind - - port := int(uport) - - bind.ipv4, port, err = listenNet("udp4", port) - if err != nil { - return nil, 0, err - } - - bind.ipv6, port, err = listenNet("udp6", port) - if err != nil { - bind.ipv4.Close() - return nil, 0, err - } - - return &bind, uint16(port), nil -} - -func (bind *NativeBind) Close() error { - err1 := bind.ipv4.Close() - err2 := bind.ipv6.Close() - if err1 != nil { - return err1 - } - return err2 -} - -func (bind *NativeBind) ReceiveIPv4(buff []byte) (int, Endpoint, error) { - n, endpoint, err := bind.ipv4.ReadFromUDP(buff) - return n, (*NativeEndpoint)(endpoint), err -} - -func (bind *NativeBind) ReceiveIPv6(buff []byte) (int, Endpoint, error) { - n, endpoint, err := bind.ipv6.ReadFromUDP(buff) - return n, (*NativeEndpoint)(endpoint), err -} - -func (bind *NativeBind) Send(buff []byte, endpoint Endpoint) error { - var err error - nend := endpoint.(*NativeEndpoint) - if nend.IP.To16() != nil { - _, err = bind.ipv6.WriteToUDP(buff, (*net.UDPAddr)(nend)) - } else { - _, err = bind.ipv4.WriteToUDP(buff, (*net.UDPAddr)(nend)) - } - return err -} - -func (bind *NativeBind) SetMark(_ uint32) error { - return nil -} diff --git a/src/conn_linux.go b/src/conn_linux.go deleted file mode 100644 index cdba74f..0000000 --- a/src/conn_linux.go +++ /dev/null @@ -1,582 +0,0 @@ -/* Copyright 2017 Jason A. Donenfeld <Jason@zx2c4.com>. All Rights Reserved. - * - * This implements userspace semantics of "sticky sockets", modeled after - * WireGuard's kernelspace implementation. - */ - -package main - -import ( - "encoding/binary" - "errors" - "golang.org/x/sys/unix" - "net" - "strconv" - "unsafe" -) - -/* Supports source address caching - * - * Currently there is no way to achieve this within the net package: - * See e.g. https://github.com/golang/go/issues/17930 - * So this code is remains platform dependent. - */ -type NativeEndpoint struct { - src unix.RawSockaddrInet6 - dst unix.RawSockaddrInet6 -} - -type NativeBind struct { - sock4 int - sock6 int -} - -var _ Endpoint = (*NativeEndpoint)(nil) -var _ Bind = NativeBind{} - -type IPv4Source struct { - src unix.RawSockaddrInet4 - Ifindex int32 -} - -func htons(val uint16) uint16 { - var out [unsafe.Sizeof(val)]byte - binary.BigEndian.PutUint16(out[:], val) - return *((*uint16)(unsafe.Pointer(&out[0]))) -} - -func ntohs(val uint16) uint16 { - tmp := ((*[unsafe.Sizeof(val)]byte)(unsafe.Pointer(&val))) - return binary.BigEndian.Uint16((*tmp)[:]) -} - -func CreateEndpoint(s string) (Endpoint, error) { - var end NativeEndpoint - addr, err := parseEndpoint(s) - if err != nil { - return nil, err - } - - ipv4 := addr.IP.To4() - if ipv4 != nil { - dst := (*unix.RawSockaddrInet4)(unsafe.Pointer(&end.dst)) - dst.Family = unix.AF_INET - dst.Port = htons(uint16(addr.Port)) - dst.Zero = [8]byte{} - copy(dst.Addr[:], ipv4) - end.ClearSrc() - return &end, nil - } - - ipv6 := addr.IP.To16() - if ipv6 != nil { - zone, err := zoneToUint32(addr.Zone) - if err != nil { - return nil, err - } - dst := &end.dst - dst.Family = unix.AF_INET6 - dst.Port = htons(uint16(addr.Port)) - dst.Flowinfo = 0 - dst.Scope_id = zone - copy(dst.Addr[:], ipv6[:]) - end.ClearSrc() - return &end, nil - } - - return nil, errors.New("Failed to recognize IP address format") -} - -func CreateBind(port uint16) (Bind, uint16, error) { - var err error - var bind NativeBind - - bind.sock6, port, err = create6(port) - if err != nil { - return nil, port, err - } - - bind.sock4, port, err = create4(port) - if err != nil { - unix.Close(bind.sock6) - } - return bind, port, err -} - -func (bind NativeBind) SetMark(value uint32) error { - err := unix.SetsockoptInt( - bind.sock6, - unix.SOL_SOCKET, - unix.SO_MARK, - int(value), - ) - - if err != nil { - return err - } - - return unix.SetsockoptInt( - bind.sock4, - unix.SOL_SOCKET, - unix.SO_MARK, - int(value), - ) -} - -func closeUnblock(fd int) error { - // shutdown to unblock readers - unix.Shutdown(fd, unix.SHUT_RD) - return unix.Close(fd) -} - -func (bind NativeBind) Close() error { - err1 := closeUnblock(bind.sock6) - err2 := closeUnblock(bind.sock4) - if err1 != nil { - return err1 - } - return err2 -} - -func (bind NativeBind) ReceiveIPv6(buff []byte) (int, Endpoint, error) { - var end NativeEndpoint - n, err := receive6( - bind.sock6, - buff, - &end, - ) - return n, &end, err -} - -func (bind NativeBind) ReceiveIPv4(buff []byte) (int, Endpoint, error) { - var end NativeEndpoint - n, err := receive4( - bind.sock4, - buff, - &end, - ) - return n, &end, err -} - -func (bind NativeBind) Send(buff []byte, end Endpoint) error { - nend := end.(*NativeEndpoint) - switch nend.dst.Family { - case unix.AF_INET6: - return send6(bind.sock6, nend, buff) - case unix.AF_INET: - return send4(bind.sock4, nend, buff) - default: - return errors.New("Unknown address family of destination") - } -} - -func sockaddrToString(addr unix.RawSockaddrInet6) string { - var udpAddr net.UDPAddr - - switch addr.Family { - case unix.AF_INET6: - udpAddr.Port = int(ntohs(addr.Port)) - udpAddr.IP = addr.Addr[:] - return udpAddr.String() - - case unix.AF_INET: - ptr := (*unix.RawSockaddrInet4)(unsafe.Pointer(&addr)) - udpAddr.Port = int(ntohs(ptr.Port)) - udpAddr.IP = net.IPv4( - ptr.Addr[0], - ptr.Addr[1], - ptr.Addr[2], - ptr.Addr[3], - ) - return udpAddr.String() - - default: - return "<unknown address family>" - } -} - -func rawAddrToIP(addr unix.RawSockaddrInet6) net.IP { - switch addr.Family { - case unix.AF_INET6: - return addr.Addr[:] - case unix.AF_INET: - ptr := (*unix.RawSockaddrInet4)(unsafe.Pointer(&addr)) - return net.IPv4( - ptr.Addr[0], - ptr.Addr[1], - ptr.Addr[2], - ptr.Addr[3], - ) - default: - return nil - } -} - -func (end *NativeEndpoint) SrcIP() net.IP { - return rawAddrToIP(end.src) -} - -func (end *NativeEndpoint) DstIP() net.IP { - return rawAddrToIP(end.dst) -} - -func (end *NativeEndpoint) DstToBytes() []byte { - ptr := unsafe.Pointer(&end.src) - arr := (*[unix.SizeofSockaddrInet6]byte)(ptr) - return arr[:] -} - -func (end *NativeEndpoint) SrcToString() string { - return sockaddrToString(end.src) -} - -func (end *NativeEndpoint) DstToString() string { - return sockaddrToString(end.dst) -} - -func (end *NativeEndpoint) ClearDst() { - end.dst = unix.RawSockaddrInet6{} -} - -func (end *NativeEndpoint) ClearSrc() { - end.src = unix.RawSockaddrInet6{} -} - -func zoneToUint32(zone string) (uint32, error) { - if zone == "" { - return 0, nil - } - if intr, err := net.InterfaceByName(zone); err == nil { - return uint32(intr.Index), nil - } - n, err := strconv.ParseUint(zone, 10, 32) - return uint32(n), err -} - -func create4(port uint16) (int, uint16, error) { - - // create socket - - fd, err := unix.Socket( - unix.AF_INET, - unix.SOCK_DGRAM, - 0, - ) - - if err != nil { - return -1, 0, err - } - - addr := unix.SockaddrInet4{ - Port: int(port), - } - - // set sockopts and bind - - if err := func() error { - if err := unix.SetsockoptInt( - fd, - unix.SOL_SOCKET, - unix.SO_REUSEADDR, - 1, - ); err != nil { - return err - } - - if err := unix.SetsockoptInt( - fd, - unix.IPPROTO_IP, - unix.IP_PKTINFO, - 1, - ); err != nil { - return err - } - - return unix.Bind(fd, &addr) - }(); err != nil { - unix.Close(fd) - } - - return fd, uint16(addr.Port), err -} - -func create6(port uint16) (int, uint16, error) { - - // create socket - - fd, err := unix.Socket( - unix.AF_INET6, - unix.SOCK_DGRAM, - 0, - ) - - if err != nil { - return -1, 0, err - } - - // set sockopts and bind - - addr := unix.SockaddrInet6{ - Port: int(port), - } - - if err := func() error { - - if err := unix.SetsockoptInt( - fd, - unix.SOL_SOCKET, - unix.SO_REUSEADDR, - 1, - ); err != nil { - return err - } - - if err := unix.SetsockoptInt( - fd, - unix.IPPROTO_IPV6, - unix.IPV6_RECVPKTINFO, - 1, - ); err != nil { - return err - } - - if err := unix.SetsockoptInt( - fd, - unix.IPPROTO_IPV6, - unix.IPV6_V6ONLY, - 1, - ); err != nil { - return err - } - - return unix.Bind(fd, &addr) - - }(); err != nil { - unix.Close(fd) - } - - return fd, uint16(addr.Port), err -} - -func send6(sock int, end *NativeEndpoint, buff []byte) error { - - // construct message header - - var iovec unix.Iovec - iovec.Base = (*byte)(unsafe.Pointer(&buff[0])) - iovec.SetLen(len(buff)) - - cmsg := struct { - cmsghdr unix.Cmsghdr - pktinfo unix.Inet6Pktinfo - }{ - unix.Cmsghdr{ - Level: unix.IPPROTO_IPV6, - Type: unix.IPV6_PKTINFO, - Len: unix.SizeofInet6Pktinfo + unix.SizeofCmsghdr, - }, - unix.Inet6Pktinfo{ - Addr: end.src.Addr, - Ifindex: end.src.Scope_id, - }, - } - - msghdr := unix.Msghdr{ - Iov: &iovec, - Iovlen: 1, - Name: (*byte)(unsafe.Pointer(&end.dst)), - Namelen: unix.SizeofSockaddrInet6, - Control: (*byte)(unsafe.Pointer(&cmsg)), - } - - msghdr.SetControllen(int(unsafe.Sizeof(cmsg))) - - // sendmsg(sock, &msghdr, 0) - - _, _, errno := unix.Syscall( - unix.SYS_SENDMSG, - uintptr(sock), - uintptr(unsafe.Pointer(&msghdr)), - 0, - ) - - if errno == 0 { - return nil - } - - // clear src and retry - - if errno == unix.EINVAL { - end.ClearSrc() - cmsg.pktinfo = unix.Inet6Pktinfo{} - _, _, errno = unix.Syscall( - unix.SYS_SENDMSG, - uintptr(sock), - uintptr(unsafe.Pointer(&msghdr)), - 0, - ) - } - - return errno -} - -func send4(sock int, end *NativeEndpoint, buff []byte) error { - - // construct message header - - var iovec unix.Iovec - iovec.Base = (*byte)(unsafe.Pointer(&buff[0])) - iovec.SetLen(len(buff)) - - src4 := (*IPv4Source)(unsafe.Pointer(&end.src)) - - cmsg := struct { - cmsghdr unix.Cmsghdr - pktinfo unix.Inet4Pktinfo - }{ - unix.Cmsghdr{ - Level: unix.IPPROTO_IP, - Type: unix.IP_PKTINFO, - Len: unix.SizeofInet4Pktinfo + unix.SizeofCmsghdr, - }, - unix.Inet4Pktinfo{ - Spec_dst: src4.src.Addr, - Ifindex: src4.Ifindex, - }, - } - - msghdr := unix.Msghdr{ - Iov: &iovec, - Iovlen: 1, - Name: (*byte)(unsafe.Pointer(&end.dst)), - Namelen: unix.SizeofSockaddrInet4, - Control: (*byte)(unsafe.Pointer(&cmsg)), - Flags: 0, - } - msghdr.SetControllen(int(unsafe.Sizeof(cmsg))) - - // sendmsg(sock, &msghdr, 0) - - _, _, errno := unix.Syscall( - unix.SYS_SENDMSG, - uintptr(sock), - uintptr(unsafe.Pointer(&msghdr)), - 0, - ) - - // clear source and try again - - if errno == unix.EINVAL { - end.ClearSrc() - cmsg.pktinfo = unix.Inet4Pktinfo{} - _, _, errno = unix.Syscall( - unix.SYS_SENDMSG, - uintptr(sock), - uintptr(unsafe.Pointer(&msghdr)), - 0, - ) - } - - // errno = 0 is still an error instance - - if errno == 0 { - return nil - } - - return errno -} - -func receive4(sock int, buff []byte, end *NativeEndpoint) (int, error) { - - // contruct message header - - var iovec unix.Iovec - iovec.Base = (*byte)(unsafe.Pointer(&buff[0])) - iovec.SetLen(len(buff)) - - var cmsg struct { - cmsghdr unix.Cmsghdr - pktinfo unix.Inet4Pktinfo - } - - var msghdr unix.Msghdr - msghdr.Iov = &iovec - msghdr.Iovlen = 1 - msghdr.Name = (*byte)(unsafe.Pointer(&end.dst)) - msghdr.Namelen = unix.SizeofSockaddrInet4 - msghdr.Control = (*byte)(unsafe.Pointer(&cmsg)) - msghdr.SetControllen(int(unsafe.Sizeof(cmsg))) - - // recvmsg(sock, &mskhdr, 0) - - size, _, errno := unix.Syscall( - unix.SYS_RECVMSG, - uintptr(sock), - uintptr(unsafe.Pointer(&msghdr)), - 0, - ) - - if errno != 0 { - return 0, errno - } - - // update source cache - - if cmsg.cmsghdr.Level == unix.IPPROTO_IP && - cmsg.cmsghdr.Type == unix.IP_PKTINFO && - cmsg.cmsghdr.Len >= unix.SizeofInet4Pktinfo { - src4 := (*IPv4Source)(unsafe.Pointer(&end.src)) - src4.src.Family = unix.AF_INET - src4.src.Addr = cmsg.pktinfo.Spec_dst - src4.Ifindex = cmsg.pktinfo.Ifindex - } - - return int(size), nil -} - -func receive6(sock int, buff []byte, end *NativeEndpoint) (int, error) { - - // contruct message header - - var iovec unix.Iovec - iovec.Base = (*byte)(unsafe.Pointer(&buff[0])) - iovec.SetLen(len(buff)) - - var cmsg struct { - cmsghdr unix.Cmsghdr - pktinfo unix.Inet6Pktinfo - } - - var msg unix.Msghdr - msg.Iov = &iovec - msg.Iovlen = 1 - msg.Name = (*byte)(unsafe.Pointer(&end.dst)) - msg.Namelen = uint32(unix.SizeofSockaddrInet6) - msg.Control = (*byte)(unsafe.Pointer(&cmsg)) - msg.SetControllen(int(unsafe.Sizeof(cmsg))) - - // recvmsg(sock, &mskhdr, 0) - - size, _, errno := unix.Syscall( - unix.SYS_RECVMSG, - uintptr(sock), - uintptr(unsafe.Pointer(&msg)), - 0, - ) - - if errno != 0 { - return 0, errno - } - - // update source cache - - if cmsg.cmsghdr.Level == unix.IPPROTO_IPV6 && - cmsg.cmsghdr.Type == unix.IPV6_PKTINFO && - cmsg.cmsghdr.Len >= unix.SizeofInet6Pktinfo { - end.src.Family = unix.AF_INET6 - end.src.Addr = cmsg.pktinfo.Addr - end.src.Scope_id = cmsg.pktinfo.Ifindex - } - - return int(size), nil -} diff --git a/src/constants.go b/src/constants.go deleted file mode 100644 index 71dd98e..0000000 --- a/src/constants.go +++ /dev/null @@ -1,43 +0,0 @@ -package main - -import ( - "time" -) - -/* Specification constants */ - -const ( - RekeyAfterMessages = (1 << 64) - (1 << 16) - 1 - RejectAfterMessages = (1 << 64) - (1 << 4) - 1 - RekeyAfterTime = time.Second * 120 - RekeyAttemptTime = time.Second * 90 - RekeyTimeout = time.Second * 5 - RejectAfterTime = time.Second * 180 - KeepaliveTimeout = time.Second * 10 - CookieRefreshTime = time.Second * 120 - HandshakeInitationRate = time.Second / 20 - PaddingMultiple = 16 -) - -const ( - RekeyAfterTimeReceiving = RekeyAfterTime - KeepaliveTimeout - RekeyTimeout - NewHandshakeTime = KeepaliveTimeout + RekeyTimeout // upon failure to acknowledge transport message -) - -/* Implementation specific constants */ - -const ( - QueueOutboundSize = 1024 - QueueInboundSize = 1024 - QueueHandshakeSize = 1024 - MaxSegmentSize = (1 << 16) - 1 // largest possible UDP datagram - MinMessageSize = MessageKeepaliveSize // minimum size of transport message (keepalive) - MaxMessageSize = MaxSegmentSize // maximum size of transport message - MaxContentSize = MaxSegmentSize - MessageTransportSize // maximum size of transport message content -) - -const ( - UnderLoadQueueSize = QueueHandshakeSize / 8 - UnderLoadAfterTime = time.Second // how long does the device remain under load after detected - MaxPeers = 1 << 16 // maximum number of configured peers -) diff --git a/src/cookie.go b/src/cookie.go deleted file mode 100644 index a13ad49..0000000 --- a/src/cookie.go +++ /dev/null @@ -1,252 +0,0 @@ -package main - -import ( - "crypto/hmac" - "crypto/rand" - "golang.org/x/crypto/blake2s" - "golang.org/x/crypto/chacha20poly1305" - "sync" - "time" -) - -type CookieChecker struct { - mutex sync.RWMutex - mac1 struct { - key [blake2s.Size]byte - } - mac2 struct { - secret [blake2s.Size]byte - secretSet time.Time - encryptionKey [chacha20poly1305.KeySize]byte - } -} - -type CookieGenerator struct { - mutex sync.RWMutex - mac1 struct { - key [blake2s.Size]byte - } - mac2 struct { - cookie [blake2s.Size128]byte - cookieSet time.Time - hasLastMAC1 bool - lastMAC1 [blake2s.Size128]byte - encryptionKey [chacha20poly1305.KeySize]byte - } -} - -func (st *CookieChecker) Init(pk NoisePublicKey) { - st.mutex.Lock() - defer st.mutex.Unlock() - - // mac1 state - - func() { - hsh, _ := blake2s.New256(nil) - hsh.Write([]byte(WGLabelMAC1)) - hsh.Write(pk[:]) - hsh.Sum(st.mac1.key[:0]) - }() - - // mac2 state - - func() { - hsh, _ := blake2s.New256(nil) - hsh.Write([]byte(WGLabelCookie)) - hsh.Write(pk[:]) - hsh.Sum(st.mac2.encryptionKey[:0]) - }() - - st.mac2.secretSet = time.Time{} -} - -func (st *CookieChecker) CheckMAC1(msg []byte) bool { - size := len(msg) - smac2 := size - blake2s.Size128 - smac1 := smac2 - blake2s.Size128 - - var mac1 [blake2s.Size128]byte - - mac, _ := blake2s.New128(st.mac1.key[:]) - mac.Write(msg[:smac1]) - mac.Sum(mac1[:0]) - - return hmac.Equal(mac1[:], msg[smac1:smac2]) -} - -func (st *CookieChecker) CheckMAC2(msg []byte, src []byte) bool { - st.mutex.RLock() - defer st.mutex.RUnlock() - - if time.Now().Sub(st.mac2.secretSet) > CookieRefreshTime { - return false - } - - // derive cookie key - - var cookie [blake2s.Size128]byte - func() { - mac, _ := blake2s.New128(st.mac2.secret[:]) - mac.Write(src) - mac.Sum(cookie[:0]) - }() - - // calculate mac of packet (including mac1) - - smac2 := len(msg) - blake2s.Size128 - - var mac2 [blake2s.Size128]byte - func() { - mac, _ := blake2s.New128(cookie[:]) - mac.Write(msg[:smac2]) - mac.Sum(mac2[:0]) - }() - - return hmac.Equal(mac2[:], msg[smac2:]) -} - -func (st *CookieChecker) CreateReply( - msg []byte, - recv uint32, - src []byte, -) (*MessageCookieReply, error) { - - st.mutex.RLock() - - // refresh cookie secret - - if time.Now().Sub(st.mac2.secretSet) > CookieRefreshTime { - st.mutex.RUnlock() - st.mutex.Lock() - _, err := rand.Read(st.mac2.secret[:]) - if err != nil { - st.mutex.Unlock() - return nil, err - } - st.mac2.secretSet = time.Now() - st.mutex.Unlock() - st.mutex.RLock() - } - - // derive cookie - - var cookie [blake2s.Size128]byte - func() { - mac, _ := blake2s.New128(st.mac2.secret[:]) - mac.Write(src) - mac.Sum(cookie[:0]) - }() - - // encrypt cookie - - size := len(msg) - - smac2 := size - blake2s.Size128 - smac1 := smac2 - blake2s.Size128 - - reply := new(MessageCookieReply) - reply.Type = MessageCookieReplyType - reply.Receiver = recv - - _, err := rand.Read(reply.Nonce[:]) - if err != nil { - st.mutex.RUnlock() - return nil, err - } - - XChaCha20Poly1305Encrypt( - reply.Cookie[:0], - &reply.Nonce, - cookie[:], - msg[smac1:smac2], - &st.mac2.encryptionKey, - ) - - st.mutex.RUnlock() - - return reply, nil -} - -func (st *CookieGenerator) Init(pk NoisePublicKey) { - st.mutex.Lock() - defer st.mutex.Unlock() - - func() { - hsh, _ := blake2s.New256(nil) - hsh.Write([]byte(WGLabelMAC1)) - hsh.Write(pk[:]) - hsh.Sum(st.mac1.key[:0]) - }() - - func() { - hsh, _ := blake2s.New256(nil) - hsh.Write([]byte(WGLabelCookie)) - hsh.Write(pk[:]) - hsh.Sum(st.mac2.encryptionKey[:0]) - }() - - st.mac2.cookieSet = time.Time{} -} - -func (st *CookieGenerator) ConsumeReply(msg *MessageCookieReply) bool { - st.mutex.Lock() - defer st.mutex.Unlock() - - if !st.mac2.hasLastMAC1 { - return false - } - - var cookie [blake2s.Size128]byte - - _, err := XChaCha20Poly1305Decrypt( - cookie[:0], - &msg.Nonce, - msg.Cookie[:], - st.mac2.lastMAC1[:], - &st.mac2.encryptionKey, - ) - - if err != nil { - return false - } - - st.mac2.cookieSet = time.Now() - st.mac2.cookie = cookie - return true -} - -func (st *CookieGenerator) AddMacs(msg []byte) { - - size := len(msg) - - smac2 := size - blake2s.Size128 - smac1 := smac2 - blake2s.Size128 - - mac1 := msg[smac1:smac2] - mac2 := msg[smac2:] - - st.mutex.Lock() - defer st.mutex.Unlock() - - // set mac1 - - func() { - mac, _ := blake2s.New128(st.mac1.key[:]) - mac.Write(msg[:smac1]) - mac.Sum(mac1[:0]) - }() - copy(st.mac2.lastMAC1[:], mac1) - st.mac2.hasLastMAC1 = true - - // set mac2 - - if time.Now().Sub(st.mac2.cookieSet) > CookieRefreshTime { - return - } - - func() { - mac, _ := blake2s.New128(st.mac2.cookie[:]) - mac.Write(msg[:smac2]) - mac.Sum(mac2[:0]) - }() -} diff --git a/src/cookie_test.go b/src/cookie_test.go deleted file mode 100644 index d745fe7..0000000 --- a/src/cookie_test.go +++ /dev/null @@ -1,186 +0,0 @@ -package main - -import ( - "testing" -) - -func TestCookieMAC1(t *testing.T) { - - // setup generator / checker - - var ( - generator CookieGenerator - checker CookieChecker - ) - - sk, err := newPrivateKey() - if err != nil { - t.Fatal(err) - } - pk := sk.publicKey() - - generator.Init(pk) - checker.Init(pk) - - // check mac1 - - src := []byte{192, 168, 13, 37, 10, 10, 10} - - checkMAC1 := func(msg []byte) { - generator.AddMacs(msg) - if !checker.CheckMAC1(msg) { - t.Fatal("MAC1 generation/verification failed") - } - if checker.CheckMAC2(msg, src) { - t.Fatal("MAC2 generation/verification failed") - } - } - - checkMAC1([]byte{ - 0x99, 0xbb, 0xa5, 0xfc, 0x99, 0xaa, 0x83, 0xbd, - 0x7b, 0x00, 0xc5, 0x9a, 0x4c, 0xb9, 0xcf, 0x62, - 0x40, 0x23, 0xf3, 0x8e, 0xd8, 0xd0, 0x62, 0x64, - 0x5d, 0xb2, 0x80, 0x13, 0xda, 0xce, 0xc6, 0x91, - 0x61, 0xd6, 0x30, 0xf1, 0x32, 0xb3, 0xa2, 0xf4, - 0x7b, 0x43, 0xb5, 0xa7, 0xe2, 0xb1, 0xf5, 0x6c, - 0x74, 0x6b, 0xb0, 0xcd, 0x1f, 0x94, 0x86, 0x7b, - 0xc8, 0xfb, 0x92, 0xed, 0x54, 0x9b, 0x44, 0xf5, - 0xc8, 0x7d, 0xb7, 0x8e, 0xff, 0x49, 0xc4, 0xe8, - 0x39, 0x7c, 0x19, 0xe0, 0x60, 0x19, 0x51, 0xf8, - 0xe4, 0x8e, 0x02, 0xf1, 0x7f, 0x1d, 0xcc, 0x8e, - 0xb0, 0x07, 0xff, 0xf8, 0xaf, 0x7f, 0x66, 0x82, - 0x83, 0xcc, 0x7c, 0xfa, 0x80, 0xdb, 0x81, 0x53, - 0xad, 0xf7, 0xd8, 0x0c, 0x10, 0xe0, 0x20, 0xfd, - 0xe8, 0x0b, 0x3f, 0x90, 0x15, 0xcd, 0x93, 0xad, - 0x0b, 0xd5, 0x0c, 0xcc, 0x88, 0x56, 0xe4, 0x3f, - }) - - checkMAC1([]byte{ - 0x33, 0xe7, 0x2a, 0x84, 0x9f, 0xff, 0x57, 0x6c, - 0x2d, 0xc3, 0x2d, 0xe1, 0xf5, 0x5c, 0x97, 0x56, - 0xb8, 0x93, 0xc2, 0x7d, 0xd4, 0x41, 0xdd, 0x7a, - 0x4a, 0x59, 0x3b, 0x50, 0xdd, 0x7a, 0x7a, 0x8c, - 0x9b, 0x96, 0xaf, 0x55, 0x3c, 0xeb, 0x6d, 0x0b, - 0x13, 0x0b, 0x97, 0x98, 0xb3, 0x40, 0xc3, 0xcc, - 0xb8, 0x57, 0x33, 0x45, 0x6e, 0x8b, 0x09, 0x2b, - 0x81, 0x2e, 0xd2, 0xb9, 0x66, 0x0b, 0x93, 0x05, - }) - - checkMAC1([]byte{ - 0x9b, 0x96, 0xaf, 0x55, 0x3c, 0xeb, 0x6d, 0x0b, - 0x13, 0x0b, 0x97, 0x98, 0xb3, 0x40, 0xc3, 0xcc, - 0xb8, 0x57, 0x33, 0x45, 0x6e, 0x8b, 0x09, 0x2b, - 0x81, 0x2e, 0xd2, 0xb9, 0x66, 0x0b, 0x93, 0x05, - }) - - // exchange cookie reply - - func() { - msg := []byte{ - 0x6d, 0xd7, 0xc3, 0x2e, 0xb0, 0x76, 0xd8, 0xdf, - 0x30, 0x65, 0x7d, 0x62, 0x3e, 0xf8, 0x9a, 0xe8, - 0xe7, 0x3c, 0x64, 0xa3, 0x78, 0x48, 0xda, 0xf5, - 0x25, 0x61, 0x28, 0x53, 0x79, 0x32, 0x86, 0x9f, - 0xa0, 0x27, 0x95, 0x69, 0xb6, 0xba, 0xd0, 0xa2, - 0xf8, 0x68, 0xea, 0xa8, 0x62, 0xf2, 0xfd, 0x1b, - 0xe0, 0xb4, 0x80, 0xe5, 0x6b, 0x3a, 0x16, 0x9e, - 0x35, 0xf6, 0xa8, 0xf2, 0x4f, 0x9a, 0x7b, 0xe9, - 0x77, 0x0b, 0xc2, 0xb4, 0xed, 0xba, 0xf9, 0x22, - 0xc3, 0x03, 0x97, 0x42, 0x9f, 0x79, 0x74, 0x27, - 0xfe, 0xf9, 0x06, 0x6e, 0x97, 0x3a, 0xa6, 0x8f, - 0xc9, 0x57, 0x0a, 0x54, 0x4c, 0x64, 0x4a, 0xe2, - 0x4f, 0xa1, 0xce, 0x95, 0x9b, 0x23, 0xa9, 0x2b, - 0x85, 0x93, 0x42, 0xb0, 0xa5, 0x53, 0xed, 0xeb, - 0x63, 0x2a, 0xf1, 0x6d, 0x46, 0xcb, 0x2f, 0x61, - 0x8c, 0xe1, 0xe8, 0xfa, 0x67, 0x20, 0x80, 0x6d, - } - generator.AddMacs(msg) - reply, err := checker.CreateReply(msg, 1377, src) - if err != nil { - t.Fatal("Failed to create cookie reply:", err) - } - if !generator.ConsumeReply(reply) { - t.Fatal("Failed to consume cookie reply") - } - }() - - // check mac2 - - checkMAC2 := func(msg []byte) { - generator.AddMacs(msg) - - if !checker.CheckMAC1(msg) { - t.Fatal("MAC1 generation/verification failed") - } - if !checker.CheckMAC2(msg, src) { - t.Fatal("MAC2 generation/verification failed") - } - - msg[5] ^= 0x20 - - if checker.CheckMAC1(msg) { - t.Fatal("MAC1 generation/verification failed") - } - if checker.CheckMAC2(msg, src) { - t.Fatal("MAC2 generation/verification failed") - } - - msg[5] ^= 0x20 - - srcBad1 := []byte{192, 168, 13, 37, 40, 01} - if checker.CheckMAC2(msg, srcBad1) { - t.Fatal("MAC2 generation/verification failed") - } - - srcBad2 := []byte{192, 168, 13, 38, 40, 01} - if checker.CheckMAC2(msg, srcBad2) { - t.Fatal("MAC2 generation/verification failed") - } - } - - checkMAC2([]byte{ - 0x03, 0x31, 0xb9, 0x9e, 0xb0, 0x2a, 0x54, 0xa3, - 0xc1, 0x3f, 0xb4, 0x96, 0x16, 0xb9, 0x25, 0x15, - 0x3d, 0x3a, 0x82, 0xf9, 0x58, 0x36, 0x86, 0x3f, - 0x13, 0x2f, 0xfe, 0xb2, 0x53, 0x20, 0x8c, 0x3f, - 0xba, 0xeb, 0xfb, 0x4b, 0x1b, 0x22, 0x02, 0x69, - 0x2c, 0x90, 0xbc, 0xdc, 0xcf, 0xcf, 0x85, 0xeb, - 0x62, 0x66, 0x6f, 0xe8, 0xe1, 0xa6, 0xa8, 0x4c, - 0xa0, 0x04, 0x23, 0x15, 0x42, 0xac, 0xfa, 0x38, - }) - - checkMAC2([]byte{ - 0x0e, 0x2f, 0x0e, 0xa9, 0x29, 0x03, 0xe1, 0xf3, - 0x24, 0x01, 0x75, 0xad, 0x16, 0xa5, 0x66, 0x85, - 0xca, 0x66, 0xe0, 0xbd, 0xc6, 0x34, 0xd8, 0x84, - 0x09, 0x9a, 0x58, 0x14, 0xfb, 0x05, 0xda, 0xf5, - 0x90, 0xf5, 0x0c, 0x4e, 0x22, 0x10, 0xc9, 0x85, - 0x0f, 0xe3, 0x77, 0x35, 0xe9, 0x6b, 0xc2, 0x55, - 0x32, 0x46, 0xae, 0x25, 0xe0, 0xe3, 0x37, 0x7a, - 0x4b, 0x71, 0xcc, 0xfc, 0x91, 0xdf, 0xd6, 0xca, - 0xfe, 0xee, 0xce, 0x3f, 0x77, 0xa2, 0xfd, 0x59, - 0x8e, 0x73, 0x0a, 0x8d, 0x5c, 0x24, 0x14, 0xca, - 0x38, 0x91, 0xb8, 0x2c, 0x8c, 0xa2, 0x65, 0x7b, - 0xbc, 0x49, 0xbc, 0xb5, 0x58, 0xfc, 0xe3, 0xd7, - 0x02, 0xcf, 0xf7, 0x4c, 0x60, 0x91, 0xed, 0x55, - 0xe9, 0xf9, 0xfe, 0xd1, 0x44, 0x2c, 0x75, 0xf2, - 0xb3, 0x5d, 0x7b, 0x27, 0x56, 0xc0, 0x48, 0x4f, - 0xb0, 0xba, 0xe4, 0x7d, 0xd0, 0xaa, 0xcd, 0x3d, - 0xe3, 0x50, 0xd2, 0xcf, 0xb9, 0xfa, 0x4b, 0x2d, - 0xc6, 0xdf, 0x3b, 0x32, 0x98, 0x45, 0xe6, 0x8f, - 0x1c, 0x5c, 0xa2, 0x20, 0x7d, 0x1c, 0x28, 0xc2, - 0xd4, 0xa1, 0xe0, 0x21, 0x52, 0x8f, 0x1c, 0xd0, - 0x62, 0x97, 0x48, 0xbb, 0xf4, 0xa9, 0xcb, 0x35, - 0xf2, 0x07, 0xd3, 0x50, 0xd8, 0xa9, 0xc5, 0x9a, - 0x0f, 0xbd, 0x37, 0xaf, 0xe1, 0x45, 0x19, 0xee, - 0x41, 0xf3, 0xf7, 0xe5, 0xe0, 0x30, 0x3f, 0xbe, - 0x3d, 0x39, 0x64, 0x00, 0x7a, 0x1a, 0x51, 0x5e, - 0xe1, 0x70, 0x0b, 0xb9, 0x77, 0x5a, 0xf0, 0xc4, - 0x8a, 0xa1, 0x3a, 0x77, 0x1a, 0xe0, 0xc2, 0x06, - 0x91, 0xd5, 0xe9, 0x1c, 0xd3, 0xfe, 0xab, 0x93, - 0x1a, 0x0a, 0x4c, 0xbb, 0xf0, 0xff, 0xdc, 0xaa, - 0x61, 0x73, 0xcb, 0x03, 0x4b, 0x71, 0x68, 0x64, - 0x3d, 0x82, 0x31, 0x41, 0xd7, 0x8b, 0x22, 0x7b, - 0x7d, 0xa1, 0xd5, 0x85, 0x6d, 0xf0, 0x1b, 0xaa, - }) -} diff --git a/src/daemon_darwin.go b/src/daemon_darwin.go deleted file mode 100644 index 913af0e..0000000 --- a/src/daemon_darwin.go +++ /dev/null @@ -1,9 +0,0 @@ -package main - -import ( - "errors" -) - -func Daemonize() error { - return errors.New("Not implemented on OSX") -} diff --git a/src/daemon_linux.go b/src/daemon_linux.go deleted file mode 100644 index e1aaede..0000000 --- a/src/daemon_linux.go +++ /dev/null @@ -1,32 +0,0 @@ -package main - -import ( - "os" - "os/exec" -) - -/* Daemonizes the process on linux - * - * This is done by spawning and releasing a copy with the --foreground flag - */ -func Daemonize(attr *os.ProcAttr) error { - // I would like to use os.Executable, - // however this means dropping support for Go <1.8 - path, err := exec.LookPath(os.Args[0]) - if err != nil { - return err - } - - argv := []string{os.Args[0], "--foreground"} - argv = append(argv, os.Args[1:]...) - process, err := os.StartProcess( - path, - argv, - attr, - ) - if err != nil { - return err - } - process.Release() - return nil -} diff --git a/src/daemon_windows.go b/src/daemon_windows.go deleted file mode 100644 index d5ec1e8..0000000 --- a/src/daemon_windows.go +++ /dev/null @@ -1,34 +0,0 @@ -package main
-
-import (
- "os"
-)
-
-/* Daemonizes the process on windows
- *
- * This is done by spawning and releasing a copy with the --foreground flag
- */
-
-func Daemonize() error {
- argv := []string{os.Args[0], "--foreground"}
- argv = append(argv, os.Args[1:]...)
- attr := &os.ProcAttr{
- Dir: ".",
- Env: os.Environ(),
- Files: []*os.File{
- os.Stdin,
- nil,
- nil,
- },
- }
- process, err := os.StartProcess(
- argv[0],
- argv,
- attr,
- )
- if err != nil {
- return err
- }
- process.Release()
- return nil
-}
diff --git a/src/device.go b/src/device.go deleted file mode 100644 index c041987..0000000 --- a/src/device.go +++ /dev/null @@ -1,372 +0,0 @@ -package main - -import ( - "github.com/sasha-s/go-deadlock" - "runtime" - "sync" - "sync/atomic" - "time" -) - -type Device struct { - isUp AtomicBool // device is (going) up - isClosed AtomicBool // device is closed? (acting as guard) - log *Logger - - // synchronized resources (locks acquired in order) - - state struct { - mutex deadlock.Mutex - changing AtomicBool - current bool - } - - net struct { - mutex deadlock.RWMutex - bind Bind // bind interface - port uint16 // listening port - fwmark uint32 // mark value (0 = disabled) - } - - noise struct { - mutex deadlock.RWMutex - privateKey NoisePrivateKey - publicKey NoisePublicKey - } - - routing struct { - mutex deadlock.RWMutex - table RoutingTable - } - - peers struct { - mutex deadlock.RWMutex - keyMap map[NoisePublicKey]*Peer - } - - // unprotected / "self-synchronising resources" - - indices IndexTable - mac CookieChecker - - rate struct { - underLoadUntil atomic.Value - limiter Ratelimiter - } - - pool struct { - messageBuffers sync.Pool - } - - queue struct { - encryption chan *QueueOutboundElement - decryption chan *QueueInboundElement - handshake chan QueueHandshakeElement - } - - signal struct { - stop Signal - } - - tun struct { - device TUNDevice - mtu int32 - } -} - -/* Converts the peer into a "zombie", which remains in the peer map, - * but processes no packets and does not exists in the routing table. - * - * Must hold: - * device.peers.mutex : exclusive lock - * device.routing : exclusive lock - */ -func unsafeRemovePeer(device *Device, peer *Peer, key NoisePublicKey) { - - // stop routing and processing of packets - - device.routing.table.RemovePeer(peer) - peer.Stop() - - // remove from peer map - - delete(device.peers.keyMap, key) -} - -func deviceUpdateState(device *Device) { - - // check if state already being updated (guard) - - if device.state.changing.Swap(true) { - return - } - - func() { - - // compare to current state of device - - device.state.mutex.Lock() - defer device.state.mutex.Unlock() - - newIsUp := device.isUp.Get() - - if newIsUp == device.state.current { - device.state.changing.Set(false) - return - } - - // change state of device - - switch newIsUp { - case true: - if err := device.BindUpdate(); err != nil { - device.isUp.Set(false) - break - } - - device.peers.mutex.Lock() - defer device.peers.mutex.Unlock() - - for _, peer := range device.peers.keyMap { - peer.Start() - } - - case false: - device.BindClose() - - device.peers.mutex.Lock() - defer device.peers.mutex.Unlock() - - for _, peer := range device.peers.keyMap { - println("stopping peer") - peer.Stop() - } - } - - // update state variables - - device.state.current = newIsUp - device.state.changing.Set(false) - }() - - // check for state change in the mean time - - deviceUpdateState(device) -} - -func (device *Device) Up() { - - // closed device cannot be brought up - - if device.isClosed.Get() { - return - } - - device.state.mutex.Lock() - device.isUp.Set(true) - device.state.mutex.Unlock() - deviceUpdateState(device) -} - -func (device *Device) Down() { - device.state.mutex.Lock() - device.isUp.Set(false) - device.state.mutex.Unlock() - deviceUpdateState(device) -} - -func (device *Device) IsUnderLoad() bool { - - // check if currently under load - - now := time.Now() - underLoad := len(device.queue.handshake) >= UnderLoadQueueSize - if underLoad { - device.rate.underLoadUntil.Store(now.Add(time.Second)) - return true - } - - // check if recently under load - - until := device.rate.underLoadUntil.Load().(time.Time) - return until.After(now) -} - -func (device *Device) SetPrivateKey(sk NoisePrivateKey) error { - - // lock required resources - - device.noise.mutex.Lock() - defer device.noise.mutex.Unlock() - - device.routing.mutex.Lock() - defer device.routing.mutex.Unlock() - - device.peers.mutex.Lock() - defer device.peers.mutex.Unlock() - - for _, peer := range device.peers.keyMap { - peer.handshake.mutex.RLock() - defer peer.handshake.mutex.RUnlock() - } - - // remove peers with matching public keys - - publicKey := sk.publicKey() - for key, peer := range device.peers.keyMap { - if peer.handshake.remoteStatic.Equals(publicKey) { - unsafeRemovePeer(device, peer, key) - } - } - - // update key material - - device.noise.privateKey = sk - device.noise.publicKey = publicKey - device.mac.Init(publicKey) - - // do static-static DH pre-computations - - rmKey := device.noise.privateKey.IsZero() - - for key, peer := range device.peers.keyMap { - - hs := &peer.handshake - - if rmKey { - hs.precomputedStaticStatic = [NoisePublicKeySize]byte{} - } else { - hs.precomputedStaticStatic = device.noise.privateKey.sharedSecret(hs.remoteStatic) - } - - if isZero(hs.precomputedStaticStatic[:]) { - unsafeRemovePeer(device, peer, key) - } - } - - return nil -} - -func (device *Device) GetMessageBuffer() *[MaxMessageSize]byte { - return device.pool.messageBuffers.Get().(*[MaxMessageSize]byte) -} - -func (device *Device) PutMessageBuffer(msg *[MaxMessageSize]byte) { - device.pool.messageBuffers.Put(msg) -} - -func NewDevice(tun TUNDevice, logger *Logger) *Device { - device := new(Device) - - device.isUp.Set(false) - device.isClosed.Set(false) - - device.log = logger - device.tun.device = tun - device.peers.keyMap = make(map[NoisePublicKey]*Peer) - - // initialize anti-DoS / anti-scanning features - - device.rate.limiter.Init() - device.rate.underLoadUntil.Store(time.Time{}) - - // initialize noise & crypt-key routine - - device.indices.Init() - device.routing.table.Reset() - - // setup buffer pool - - device.pool.messageBuffers = sync.Pool{ - New: func() interface{} { - return new([MaxMessageSize]byte) - }, - } - - // create queues - - device.queue.handshake = make(chan QueueHandshakeElement, QueueHandshakeSize) - device.queue.encryption = make(chan *QueueOutboundElement, QueueOutboundSize) - device.queue.decryption = make(chan *QueueInboundElement, QueueInboundSize) - - // prepare signals - - device.signal.stop = NewSignal() - - // prepare net - - device.net.port = 0 - device.net.bind = nil - - // start workers - - for i := 0; i < runtime.NumCPU(); i += 1 { - go device.RoutineEncryption() - go device.RoutineDecryption() - go device.RoutineHandshake() - } - - go device.RoutineReadFromTUN() - go device.RoutineTUNEventReader() - go device.rate.limiter.RoutineGarbageCollector(device.signal.stop) - - return device -} - -func (device *Device) LookupPeer(pk NoisePublicKey) *Peer { - device.peers.mutex.RLock() - defer device.peers.mutex.RUnlock() - - return device.peers.keyMap[pk] -} - -func (device *Device) RemovePeer(key NoisePublicKey) { - device.noise.mutex.Lock() - defer device.noise.mutex.Unlock() - - device.routing.mutex.Lock() - defer device.routing.mutex.Unlock() - - device.peers.mutex.Lock() - defer device.peers.mutex.Unlock() - - // stop peer and remove from routing - - peer, ok := device.peers.keyMap[key] - if ok { - unsafeRemovePeer(device, peer, key) - } -} - -func (device *Device) RemoveAllPeers() { - - device.routing.mutex.Lock() - defer device.routing.mutex.Unlock() - - device.peers.mutex.Lock() - defer device.peers.mutex.Unlock() - - for key, peer := range device.peers.keyMap { - println("rm", peer.String()) - unsafeRemovePeer(device, peer, key) - } - - device.peers.keyMap = make(map[NoisePublicKey]*Peer) -} - -func (device *Device) Close() { - device.log.Info.Println("Device closing") - if device.isClosed.Swap(true) { - return - } - device.signal.stop.Broadcast() - device.tun.device.Close() - device.BindClose() - device.isUp.Set(false) - device.RemoveAllPeers() - device.log.Info.Println("Interface closed") -} - -func (device *Device) Wait() chan struct{} { - return device.signal.stop.Wait() -} diff --git a/src/helper_test.go b/src/helper_test.go deleted file mode 100644 index 41e6b72..0000000 --- a/src/helper_test.go +++ /dev/null @@ -1,79 +0,0 @@ -package main - -import ( - "bytes" - "os" - "testing" -) - -/* Helpers for writing unit tests - */ - -type DummyTUN struct { - name string - mtu int - packets chan []byte - events chan TUNEvent -} - -func (tun *DummyTUN) File() *os.File { - return nil -} - -func (tun *DummyTUN) Name() string { - return tun.name -} - -func (tun *DummyTUN) MTU() (int, error) { - return tun.mtu, nil -} - -func (tun *DummyTUN) Write(d []byte, offset int) (int, error) { - tun.packets <- d[offset:] - return len(d), nil -} - -func (tun *DummyTUN) Close() error { - return nil -} - -func (tun *DummyTUN) Events() chan TUNEvent { - return tun.events -} - -func (tun *DummyTUN) Read(d []byte, offset int) (int, error) { - t := <-tun.packets - copy(d[offset:], t) - return len(t), nil -} - -func CreateDummyTUN(name string) (TUNDevice, error) { - var dummy DummyTUN - dummy.mtu = 0 - dummy.packets = make(chan []byte, 100) - return &dummy, nil -} - -func assertNil(t *testing.T, err error) { - if err != nil { - t.Fatal(err) - } -} - -func assertEqual(t *testing.T, a []byte, b []byte) { - if bytes.Compare(a, b) != 0 { - t.Fatal(a, "!=", b) - } -} - -func randDevice(t *testing.T) *Device { - sk, err := newPrivateKey() - if err != nil { - t.Fatal(err) - } - tun, _ := CreateDummyTUN("dummy") - logger := NewLogger(LogLevelError, "") - device := NewDevice(tun, logger) - device.SetPrivateKey(sk) - return device -} diff --git a/src/index.go b/src/index.go deleted file mode 100644 index 1ba040e..0000000 --- a/src/index.go +++ /dev/null @@ -1,95 +0,0 @@ -package main - -import ( - "crypto/rand" - "encoding/binary" - "sync" -) - -/* Index=0 is reserved for unset indecies - * - */ - -type IndexTableEntry struct { - peer *Peer - handshake *Handshake - keyPair *KeyPair -} - -type IndexTable struct { - mutex sync.RWMutex - table map[uint32]IndexTableEntry -} - -func randUint32() (uint32, error) { - var buff [4]byte - _, err := rand.Read(buff[:]) - value := binary.LittleEndian.Uint32(buff[:]) - return value, err -} - -func (table *IndexTable) Init() { - table.mutex.Lock() - table.table = make(map[uint32]IndexTableEntry) - table.mutex.Unlock() -} - -func (table *IndexTable) Delete(index uint32) { - if index == 0 { - return - } - table.mutex.Lock() - delete(table.table, index) - table.mutex.Unlock() -} - -func (table *IndexTable) Insert(key uint32, value IndexTableEntry) { - table.mutex.Lock() - table.table[key] = value - table.mutex.Unlock() -} - -func (table *IndexTable) NewIndex(peer *Peer) (uint32, error) { - for { - // generate random index - - index, err := randUint32() - if err != nil { - return index, err - } - if index == 0 { - continue - } - - // check if index used - - table.mutex.RLock() - _, ok := table.table[index] - table.mutex.RUnlock() - if ok { - continue - } - - // map index to handshake - - table.mutex.Lock() - _, found := table.table[index] - if found { - table.mutex.Unlock() - continue - } - table.table[index] = IndexTableEntry{ - peer: peer, - handshake: &peer.handshake, - keyPair: nil, - } - table.mutex.Unlock() - return index, nil - } -} - -func (table *IndexTable) Lookup(id uint32) IndexTableEntry { - table.mutex.RLock() - defer table.mutex.RUnlock() - return table.table[id] -} diff --git a/src/ip.go b/src/ip.go deleted file mode 100644 index 752a404..0000000 --- a/src/ip.go +++ /dev/null @@ -1,17 +0,0 @@ -package main - -import ( - "net" -) - -const ( - IPv4offsetTotalLength = 2 - IPv4offsetSrc = 12 - IPv4offsetDst = IPv4offsetSrc + net.IPv4len -) - -const ( - IPv6offsetPayloadLength = 4 - IPv6offsetSrc = 8 - IPv6offsetDst = IPv6offsetSrc + net.IPv6len -) diff --git a/src/kdf_test.go b/src/kdf_test.go deleted file mode 100644 index a89dacc..0000000 --- a/src/kdf_test.go +++ /dev/null @@ -1,79 +0,0 @@ -package main - -import ( - "encoding/hex" - "golang.org/x/crypto/blake2s" - "testing" -) - -type KDFTest struct { - key string - input string - t0 string - t1 string - t2 string -} - -func assertEquals(t *testing.T, a string, b string) { - if a != b { - t.Fatal("expected", a, "=", b) - } -} - -func TestKDF(t *testing.T) { - tests := []KDFTest{ - { - key: "746573742d6b6579", - input: "746573742d696e707574", - t0: "6f0e5ad38daba1bea8a0d213688736f19763239305e0f58aba697f9ffc41c633", - t1: "df1194df20802a4fe594cde27e92991c8cae66c366e8106aaa937a55fa371e8a", - t2: "fac6e2745a325f5dc5d11a5b165aad08b0ada28e7b4e666b7c077934a4d76c24", - }, - { - key: "776972656775617264", - input: "776972656775617264", - t0: "491d43bbfdaa8750aaf535e334ecbfe5129967cd64635101c566d4caefda96e8", - t1: "1e71a379baefd8a79aa4662212fcafe19a23e2b609a3db7d6bcba8f560e3d25f", - t2: "31e1ae48bddfbe5de38f295e5452b1909a1b4e38e183926af3780b0c1e1f0160", - }, - { - key: "", - input: "", - t0: "8387b46bf43eccfcf349552a095d8315c4055beb90208fb1be23b894bc2ed5d0", - t1: "58a0e5f6faefccf4807bff1f05fa8a9217945762040bcec2f4b4a62bdfe0e86e", - t2: "0ce6ea98ec548f8e281e93e32db65621c45eb18dc6f0a7ad94178610a2f7338e", - }, - } - - var t0, t1, t2 [blake2s.Size]byte - - for _, test := range tests { - key, _ := hex.DecodeString(test.key) - input, _ := hex.DecodeString(test.input) - KDF3(&t0, &t1, &t2, key, input) - t0s := hex.EncodeToString(t0[:]) - t1s := hex.EncodeToString(t1[:]) - t2s := hex.EncodeToString(t2[:]) - assertEquals(t, t0s, test.t0) - assertEquals(t, t1s, test.t1) - assertEquals(t, t2s, test.t2) - } - - for _, test := range tests { - key, _ := hex.DecodeString(test.key) - input, _ := hex.DecodeString(test.input) - KDF2(&t0, &t1, key, input) - t0s := hex.EncodeToString(t0[:]) - t1s := hex.EncodeToString(t1[:]) - assertEquals(t, t0s, test.t0) - assertEquals(t, t1s, test.t1) - } - - for _, test := range tests { - key, _ := hex.DecodeString(test.key) - input, _ := hex.DecodeString(test.input) - KDF1(&t0, key, input) - t0s := hex.EncodeToString(t0[:]) - assertEquals(t, t0s, test.t0) - } -} diff --git a/src/keypair.go b/src/keypair.go deleted file mode 100644 index 283cb92..0000000 --- a/src/keypair.go +++ /dev/null @@ -1,44 +0,0 @@ -package main - -import ( - "crypto/cipher" - "sync" - "time" -) - -/* Due to limitations in Go and /x/crypto there is currently - * no way to ensure that key material is securely ereased in memory. - * - * Since this may harm the forward secrecy property, - * we plan to resolve this issue; whenever Go allows us to do so. - */ - -type KeyPair struct { - send cipher.AEAD - receive cipher.AEAD - replayFilter ReplayFilter - sendNonce uint64 - isInitiator bool - created time.Time - localIndex uint32 - remoteIndex uint32 -} - -type KeyPairs struct { - mutex sync.RWMutex - current *KeyPair - previous *KeyPair - next *KeyPair // not yet "confirmed by transport" -} - -func (kp *KeyPairs) Current() *KeyPair { - kp.mutex.RLock() - defer kp.mutex.RUnlock() - return kp.current -} - -func (device *Device) DeleteKeyPair(key *KeyPair) { - if key != nil { - device.indices.Delete(key.localIndex) - } -} diff --git a/src/logger.go b/src/logger.go deleted file mode 100644 index 0872ef9..0000000 --- a/src/logger.go +++ /dev/null @@ -1,50 +0,0 @@ -package main - -import ( - "io" - "io/ioutil" - "log" - "os" -) - -const ( - LogLevelError = iota - LogLevelInfo - LogLevelDebug -) - -type Logger struct { - Debug *log.Logger - Info *log.Logger - Error *log.Logger -} - -func NewLogger(level int, prepend string) *Logger { - output := os.Stdout - logger := new(Logger) - - logErr, logInfo, logDebug := func() (io.Writer, io.Writer, io.Writer) { - if level >= LogLevelDebug { - return output, output, output - } - if level >= LogLevelInfo { - return output, output, ioutil.Discard - } - return output, ioutil.Discard, ioutil.Discard - }() - - logger.Debug = log.New(logDebug, - "DEBUG: "+prepend, - log.Ldate|log.Ltime|log.Lshortfile, - ) - - logger.Info = log.New(logInfo, - "INFO: "+prepend, - log.Ldate|log.Ltime, - ) - logger.Error = log.New(logErr, - "ERROR: "+prepend, - log.Ldate|log.Ltime, - ) - return logger -} diff --git a/src/main.go b/src/main.go deleted file mode 100644 index b12bb09..0000000 --- a/src/main.go +++ /dev/null @@ -1,196 +0,0 @@ -package main - -import ( - "fmt" - "os" - "os/signal" - "runtime" - "strconv" -) - -const ( - ExitSetupSuccess = 0 - ExitSetupFailed = 1 -) - -const ( - ENV_WG_TUN_FD = "WG_TUN_FD" - ENV_WG_UAPI_FD = "WG_UAPI_FD" -) - -func printUsage() { - fmt.Printf("usage:\n") - fmt.Printf("%s [-f/--foreground] INTERFACE-NAME\n", os.Args[0]) -} - -func main() { - - // parse arguments - - var foreground bool - var interfaceName string - if len(os.Args) < 2 || len(os.Args) > 3 { - printUsage() - return - } - - switch os.Args[1] { - - case "-f", "--foreground": - foreground = true - if len(os.Args) != 3 { - printUsage() - return - } - interfaceName = os.Args[2] - - default: - foreground = false - if len(os.Args) != 2 { - printUsage() - return - } - interfaceName = os.Args[1] - } - - // get log level (default: info) - - logLevel := func() int { - switch os.Getenv("LOG_LEVEL") { - case "debug": - return LogLevelDebug - case "info": - return LogLevelInfo - case "error": - return LogLevelError - } - return LogLevelInfo - }() - - logger := NewLogger( - logLevel, - fmt.Sprintf("(%s) ", interfaceName), - ) - - logger.Debug.Println("Debug log enabled") - - // open TUN device (or use supplied fd) - - tun, err := func() (TUNDevice, error) { - tunFdStr := os.Getenv(ENV_WG_TUN_FD) - if tunFdStr == "" { - return CreateTUN(interfaceName) - } - - // construct tun device from supplied fd - - fd, err := strconv.ParseUint(tunFdStr, 10, 32) - if err != nil { - return nil, err - } - - file := os.NewFile(uintptr(fd), "") - return CreateTUNFromFile(interfaceName, file) - }() - - if err != nil { - logger.Error.Println("Failed to create TUN device:", err) - os.Exit(ExitSetupFailed) - } - - // open UAPI file (or use supplied fd) - - fileUAPI, err := func() (*os.File, error) { - uapiFdStr := os.Getenv(ENV_WG_UAPI_FD) - if uapiFdStr == "" { - return UAPIOpen(interfaceName) - } - - // use supplied fd - - fd, err := strconv.ParseUint(uapiFdStr, 10, 32) - if err != nil { - return nil, err - } - - return os.NewFile(uintptr(fd), ""), nil - }() - - if err != nil { - logger.Error.Println("UAPI listen error:", err) - os.Exit(ExitSetupFailed) - return - } - // daemonize the process - - if !foreground { - env := os.Environ() - env = append(env, fmt.Sprintf("%s=3", ENV_WG_TUN_FD)) - env = append(env, fmt.Sprintf("%s=4", ENV_WG_UAPI_FD)) - attr := &os.ProcAttr{ - Files: []*os.File{ - nil, // stdin - nil, // stdout - nil, // stderr - tun.File(), - fileUAPI, - }, - Dir: ".", - Env: env, - } - err = Daemonize(attr) - if err != nil { - logger.Error.Println("Failed to daemonize:", err) - os.Exit(ExitSetupFailed) - } - return - } - - // increase number of go workers (for Go <1.5) - - runtime.GOMAXPROCS(runtime.NumCPU()) - - // create wireguard device - - device := NewDevice(tun, logger) - - logger.Info.Println("Device started") - - // start uapi listener - - errs := make(chan error) - term := make(chan os.Signal) - - uapi, err := UAPIListen(interfaceName, fileUAPI) - - go func() { - for { - conn, err := uapi.Accept() - if err != nil { - errs <- err - return - } - go ipcHandle(device, conn) - } - }() - - logger.Info.Println("UAPI listener started") - - // wait for program to terminate - - signal.Notify(term, os.Kill) - signal.Notify(term, os.Interrupt) - - select { - case <-term: - case <-errs: - case <-device.Wait(): - } - - // clean up - - uapi.Close() - device.Close() - - logger.Info.Println("Shutting down") -} diff --git a/src/misc.go b/src/misc.go deleted file mode 100644 index 80e33f6..0000000 --- a/src/misc.go +++ /dev/null @@ -1,57 +0,0 @@ -package main - -import ( - "sync/atomic" -) - -/* Atomic Boolean */ - -const ( - AtomicFalse = int32(iota) - AtomicTrue -) - -type AtomicBool struct { - flag int32 -} - -func (a *AtomicBool) Get() bool { - return atomic.LoadInt32(&a.flag) == AtomicTrue -} - -func (a *AtomicBool) Swap(val bool) bool { - flag := AtomicFalse - if val { - flag = AtomicTrue - } - return atomic.SwapInt32(&a.flag, flag) == AtomicTrue -} - -func (a *AtomicBool) Set(val bool) { - flag := AtomicFalse - if val { - flag = AtomicTrue - } - atomic.StoreInt32(&a.flag, flag) -} - -/* Integer manipulation */ - -func toInt32(n uint32) int32 { - mask := uint32(1 << 31) - return int32(-(n & mask) + (n & ^mask)) -} - -func min(a uint, b uint) uint { - if a > b { - return b - } - return a -} - -func minUint64(a uint64, b uint64) uint64 { - if a > b { - return b - } - return a -} diff --git a/src/noise_helpers.go b/src/noise_helpers.go deleted file mode 100644 index 1e2de5f..0000000 --- a/src/noise_helpers.go +++ /dev/null @@ -1,98 +0,0 @@ -package main - -import ( - "crypto/hmac" - "crypto/rand" - "crypto/subtle" - "golang.org/x/crypto/blake2s" - "golang.org/x/crypto/curve25519" - "hash" -) - -/* KDF related functions. - * HMAC-based Key Derivation Function (HKDF) - * https://tools.ietf.org/html/rfc5869 - */ - -func HMAC1(sum *[blake2s.Size]byte, key, in0 []byte) { - mac := hmac.New(func() hash.Hash { - h, _ := blake2s.New256(nil) - return h - }, key) - mac.Write(in0) - mac.Sum(sum[:0]) -} - -func HMAC2(sum *[blake2s.Size]byte, key, in0, in1 []byte) { - mac := hmac.New(func() hash.Hash { - h, _ := blake2s.New256(nil) - return h - }, key) - mac.Write(in0) - mac.Write(in1) - mac.Sum(sum[:0]) -} - -func KDF1(t0 *[blake2s.Size]byte, key, input []byte) { - HMAC1(t0, key, input) - HMAC1(t0, t0[:], []byte{0x1}) - return -} - -func KDF2(t0, t1 *[blake2s.Size]byte, key, input []byte) { - var prk [blake2s.Size]byte - HMAC1(&prk, key, input) - HMAC1(t0, prk[:], []byte{0x1}) - HMAC2(t1, prk[:], t0[:], []byte{0x2}) - setZero(prk[:]) - return -} - -func KDF3(t0, t1, t2 *[blake2s.Size]byte, key, input []byte) { - var prk [blake2s.Size]byte - HMAC1(&prk, key, input) - HMAC1(t0, prk[:], []byte{0x1}) - HMAC2(t1, prk[:], t0[:], []byte{0x2}) - HMAC2(t2, prk[:], t1[:], []byte{0x3}) - setZero(prk[:]) - return -} - -func isZero(val []byte) bool { - acc := 1 - for _, b := range val { - acc &= subtle.ConstantTimeByteEq(b, 0) - } - return acc == 1 -} - -func setZero(arr []byte) { - for i := range arr { - arr[i] = 0 - } -} - -/* curve25519 wrappers */ - -func newPrivateKey() (sk NoisePrivateKey, err error) { - // clamping: https://cr.yp.to/ecdh.html - _, err = rand.Read(sk[:]) - sk[0] &= 248 - sk[31] &= 127 - sk[31] |= 64 - return -} - -func (sk *NoisePrivateKey) publicKey() (pk NoisePublicKey) { - apk := (*[NoisePublicKeySize]byte)(&pk) - ask := (*[NoisePrivateKeySize]byte)(sk) - curve25519.ScalarBaseMult(apk, ask) - return -} - -func (sk *NoisePrivateKey) sharedSecret(pk NoisePublicKey) (ss [NoisePublicKeySize]byte) { - apk := (*[NoisePublicKeySize]byte)(&pk) - ask := (*[NoisePrivateKeySize]byte)(sk) - curve25519.ScalarMult(&ss, ask, apk) - return ss -} diff --git a/src/noise_protocol.go b/src/noise_protocol.go deleted file mode 100644 index c9713c0..0000000 --- a/src/noise_protocol.go +++ /dev/null @@ -1,578 +0,0 @@ -package main - -import ( - "errors" - "golang.org/x/crypto/blake2s" - "golang.org/x/crypto/chacha20poly1305" - "golang.org/x/crypto/poly1305" - "sync" - "time" -) - -const ( - HandshakeZeroed = iota - HandshakeInitiationCreated - HandshakeInitiationConsumed - HandshakeResponseCreated - HandshakeResponseConsumed -) - -const ( - NoiseConstruction = "Noise_IKpsk2_25519_ChaChaPoly_BLAKE2s" - WGIdentifier = "WireGuard v1 zx2c4 Jason@zx2c4.com" - WGLabelMAC1 = "mac1----" - WGLabelCookie = "cookie--" -) - -const ( - MessageInitiationType = 1 - MessageResponseType = 2 - MessageCookieReplyType = 3 - MessageTransportType = 4 -) - -const ( - MessageInitiationSize = 148 // size of handshake initation message - MessageResponseSize = 92 // size of response message - MessageCookieReplySize = 64 // size of cookie reply message - MessageTransportHeaderSize = 16 // size of data preceeding content in transport message - MessageTransportSize = MessageTransportHeaderSize + poly1305.TagSize // size of empty transport - MessageKeepaliveSize = MessageTransportSize // size of keepalive - MessageHandshakeSize = MessageInitiationSize // size of largest handshake releated message -) - -const ( - MessageTransportOffsetReceiver = 4 - MessageTransportOffsetCounter = 8 - MessageTransportOffsetContent = 16 -) - -/* Type is an 8-bit field, followed by 3 nul bytes, - * by marshalling the messages in little-endian byteorder - * we can treat these as a 32-bit unsigned int (for now) - * - */ - -type MessageInitiation struct { - Type uint32 - Sender uint32 - Ephemeral NoisePublicKey - Static [NoisePublicKeySize + poly1305.TagSize]byte - Timestamp [TAI64NSize + poly1305.TagSize]byte - MAC1 [blake2s.Size128]byte - MAC2 [blake2s.Size128]byte -} - -type MessageResponse struct { - Type uint32 - Sender uint32 - Receiver uint32 - Ephemeral NoisePublicKey - Empty [poly1305.TagSize]byte - MAC1 [blake2s.Size128]byte - MAC2 [blake2s.Size128]byte -} - -type MessageTransport struct { - Type uint32 - Receiver uint32 - Counter uint64 - Content []byte -} - -type MessageCookieReply struct { - Type uint32 - Receiver uint32 - Nonce [24]byte - Cookie [blake2s.Size128 + poly1305.TagSize]byte -} - -type Handshake struct { - state int - mutex sync.RWMutex - hash [blake2s.Size]byte // hash value - chainKey [blake2s.Size]byte // chain key - presharedKey NoiseSymmetricKey // psk - localEphemeral NoisePrivateKey // ephemeral secret key - localIndex uint32 // used to clear hash-table - remoteIndex uint32 // index for sending - remoteStatic NoisePublicKey // long term key - remoteEphemeral NoisePublicKey // ephemeral public key - precomputedStaticStatic [NoisePublicKeySize]byte // precomputed shared secret - lastTimestamp TAI64N - lastInitiationConsumption time.Time -} - -var ( - InitialChainKey [blake2s.Size]byte - InitialHash [blake2s.Size]byte - ZeroNonce [chacha20poly1305.NonceSize]byte -) - -func mixKey(dst *[blake2s.Size]byte, c *[blake2s.Size]byte, data []byte) { - KDF1(dst, c[:], data) -} - -func mixHash(dst *[blake2s.Size]byte, h *[blake2s.Size]byte, data []byte) { - hsh, _ := blake2s.New256(nil) - hsh.Write(h[:]) - hsh.Write(data) - hsh.Sum(dst[:0]) - hsh.Reset() -} - -func (h *Handshake) Clear() { - setZero(h.localEphemeral[:]) - setZero(h.remoteEphemeral[:]) - setZero(h.chainKey[:]) - setZero(h.hash[:]) - h.localIndex = 0 - h.state = HandshakeZeroed -} - -func (h *Handshake) mixHash(data []byte) { - mixHash(&h.hash, &h.hash, data) -} - -func (h *Handshake) mixKey(data []byte) { - mixKey(&h.chainKey, &h.chainKey, data) -} - -/* Do basic precomputations - */ -func init() { - InitialChainKey = blake2s.Sum256([]byte(NoiseConstruction)) - mixHash(&InitialHash, &InitialChainKey, []byte(WGIdentifier)) -} - -func (device *Device) CreateMessageInitiation(peer *Peer) (*MessageInitiation, error) { - - device.noise.mutex.RLock() - defer device.noise.mutex.RUnlock() - - handshake := &peer.handshake - handshake.mutex.Lock() - defer handshake.mutex.Unlock() - - if isZero(handshake.precomputedStaticStatic[:]) { - return nil, errors.New("Static shared secret is zero") - } - - // create ephemeral key - - var err error - handshake.hash = InitialHash - handshake.chainKey = InitialChainKey - handshake.localEphemeral, err = newPrivateKey() - if err != nil { - return nil, err - } - - // assign index - - device.indices.Delete(handshake.localIndex) - handshake.localIndex, err = device.indices.NewIndex(peer) - - if err != nil { - return nil, err - } - - handshake.mixHash(handshake.remoteStatic[:]) - - msg := MessageInitiation{ - Type: MessageInitiationType, - Ephemeral: handshake.localEphemeral.publicKey(), - Sender: handshake.localIndex, - } - - handshake.mixKey(msg.Ephemeral[:]) - handshake.mixHash(msg.Ephemeral[:]) - - // encrypt static key - - func() { - var key [chacha20poly1305.KeySize]byte - ss := handshake.localEphemeral.sharedSecret(handshake.remoteStatic) - KDF2( - &handshake.chainKey, - &key, - handshake.chainKey[:], - ss[:], - ) - aead, _ := chacha20poly1305.New(key[:]) - aead.Seal(msg.Static[:0], ZeroNonce[:], device.noise.publicKey[:], handshake.hash[:]) - }() - handshake.mixHash(msg.Static[:]) - - // encrypt timestamp - - timestamp := Timestamp() - func() { - var key [chacha20poly1305.KeySize]byte - KDF2( - &handshake.chainKey, - &key, - handshake.chainKey[:], - handshake.precomputedStaticStatic[:], - ) - aead, _ := chacha20poly1305.New(key[:]) - aead.Seal(msg.Timestamp[:0], ZeroNonce[:], timestamp[:], handshake.hash[:]) - }() - - handshake.mixHash(msg.Timestamp[:]) - handshake.state = HandshakeInitiationCreated - return &msg, nil -} - -func (device *Device) ConsumeMessageInitiation(msg *MessageInitiation) *Peer { - var ( - hash [blake2s.Size]byte - chainKey [blake2s.Size]byte - ) - - if msg.Type != MessageInitiationType { - return nil - } - - device.noise.mutex.RLock() - defer device.noise.mutex.RUnlock() - - mixHash(&hash, &InitialHash, device.noise.publicKey[:]) - mixHash(&hash, &hash, msg.Ephemeral[:]) - mixKey(&chainKey, &InitialChainKey, msg.Ephemeral[:]) - - // decrypt static key - - var err error - var peerPK NoisePublicKey - func() { - var key [chacha20poly1305.KeySize]byte - ss := device.noise.privateKey.sharedSecret(msg.Ephemeral) - KDF2(&chainKey, &key, chainKey[:], ss[:]) - aead, _ := chacha20poly1305.New(key[:]) - _, err = aead.Open(peerPK[:0], ZeroNonce[:], msg.Static[:], hash[:]) - }() - if err != nil { - return nil - } - mixHash(&hash, &hash, msg.Static[:]) - - // lookup peer - - peer := device.LookupPeer(peerPK) - if peer == nil { - return nil - } - - handshake := &peer.handshake - if isZero(handshake.precomputedStaticStatic[:]) { - return nil - } - - // verify identity - - var timestamp TAI64N - var key [chacha20poly1305.KeySize]byte - - handshake.mutex.RLock() - KDF2( - &chainKey, - &key, - chainKey[:], - handshake.precomputedStaticStatic[:], - ) - aead, _ := chacha20poly1305.New(key[:]) - _, err = aead.Open(timestamp[:0], ZeroNonce[:], msg.Timestamp[:], hash[:]) - if err != nil { - handshake.mutex.RUnlock() - return nil - } - mixHash(&hash, &hash, msg.Timestamp[:]) - - // protect against replay & flood - - var ok bool - ok = timestamp.After(handshake.lastTimestamp) - ok = ok && time.Now().Sub(handshake.lastInitiationConsumption) > HandshakeInitationRate - handshake.mutex.RUnlock() - if !ok { - return nil - } - - // update handshake state - - handshake.mutex.Lock() - - handshake.hash = hash - handshake.chainKey = chainKey - handshake.remoteIndex = msg.Sender - handshake.remoteEphemeral = msg.Ephemeral - handshake.lastTimestamp = timestamp - handshake.lastInitiationConsumption = time.Now() - handshake.state = HandshakeInitiationConsumed - - handshake.mutex.Unlock() - - return peer -} - -func (device *Device) CreateMessageResponse(peer *Peer) (*MessageResponse, error) { - handshake := &peer.handshake - handshake.mutex.Lock() - defer handshake.mutex.Unlock() - - if handshake.state != HandshakeInitiationConsumed { - return nil, errors.New("handshake initation must be consumed first") - } - - // assign index - - var err error - device.indices.Delete(handshake.localIndex) - handshake.localIndex, err = device.indices.NewIndex(peer) - if err != nil { - return nil, err - } - - var msg MessageResponse - msg.Type = MessageResponseType - msg.Sender = handshake.localIndex - msg.Receiver = handshake.remoteIndex - - // create ephemeral key - - handshake.localEphemeral, err = newPrivateKey() - if err != nil { - return nil, err - } - msg.Ephemeral = handshake.localEphemeral.publicKey() - handshake.mixHash(msg.Ephemeral[:]) - handshake.mixKey(msg.Ephemeral[:]) - - func() { - ss := handshake.localEphemeral.sharedSecret(handshake.remoteEphemeral) - handshake.mixKey(ss[:]) - ss = handshake.localEphemeral.sharedSecret(handshake.remoteStatic) - handshake.mixKey(ss[:]) - }() - - // add preshared key (psk) - - var tau [blake2s.Size]byte - var key [chacha20poly1305.KeySize]byte - - KDF3( - &handshake.chainKey, - &tau, - &key, - handshake.chainKey[:], - handshake.presharedKey[:], - ) - - handshake.mixHash(tau[:]) - - func() { - aead, _ := chacha20poly1305.New(key[:]) - aead.Seal(msg.Empty[:0], ZeroNonce[:], nil, handshake.hash[:]) - handshake.mixHash(msg.Empty[:]) - }() - - handshake.state = HandshakeResponseCreated - - return &msg, nil -} - -func (device *Device) ConsumeMessageResponse(msg *MessageResponse) *Peer { - if msg.Type != MessageResponseType { - return nil - } - - // lookup handshake by reciever - - lookup := device.indices.Lookup(msg.Receiver) - handshake := lookup.handshake - if handshake == nil { - return nil - } - - var ( - hash [blake2s.Size]byte - chainKey [blake2s.Size]byte - ) - - ok := func() bool { - - // lock handshake state - - handshake.mutex.RLock() - defer handshake.mutex.RUnlock() - - if handshake.state != HandshakeInitiationCreated { - return false - } - - // lock private key for reading - - device.noise.mutex.RLock() - defer device.noise.mutex.RUnlock() - - // finish 3-way DH - - mixHash(&hash, &handshake.hash, msg.Ephemeral[:]) - mixKey(&chainKey, &handshake.chainKey, msg.Ephemeral[:]) - - func() { - ss := handshake.localEphemeral.sharedSecret(msg.Ephemeral) - mixKey(&chainKey, &chainKey, ss[:]) - setZero(ss[:]) - }() - - func() { - ss := device.noise.privateKey.sharedSecret(msg.Ephemeral) - mixKey(&chainKey, &chainKey, ss[:]) - setZero(ss[:]) - }() - - // add preshared key (psk) - - var tau [blake2s.Size]byte - var key [chacha20poly1305.KeySize]byte - KDF3( - &chainKey, - &tau, - &key, - chainKey[:], - handshake.presharedKey[:], - ) - mixHash(&hash, &hash, tau[:]) - - // authenticate transcript - - aead, _ := chacha20poly1305.New(key[:]) - _, err := aead.Open(nil, ZeroNonce[:], msg.Empty[:], hash[:]) - if err != nil { - device.log.Debug.Println("failed to open") - return false - } - mixHash(&hash, &hash, msg.Empty[:]) - return true - }() - - if !ok { - return nil - } - - // update handshake state - - handshake.mutex.Lock() - - handshake.hash = hash - handshake.chainKey = chainKey - handshake.remoteIndex = msg.Sender - handshake.state = HandshakeResponseConsumed - - handshake.mutex.Unlock() - - setZero(hash[:]) - setZero(chainKey[:]) - - return lookup.peer -} - -/* Derives a new key-pair from the current handshake state - * - */ -func (peer *Peer) NewKeyPair() *KeyPair { - device := peer.device - handshake := &peer.handshake - handshake.mutex.Lock() - defer handshake.mutex.Unlock() - - // derive keys - - var isInitiator bool - var sendKey [chacha20poly1305.KeySize]byte - var recvKey [chacha20poly1305.KeySize]byte - - if handshake.state == HandshakeResponseConsumed { - KDF2( - &sendKey, - &recvKey, - handshake.chainKey[:], - nil, - ) - isInitiator = true - } else if handshake.state == HandshakeResponseCreated { - KDF2( - &recvKey, - &sendKey, - handshake.chainKey[:], - nil, - ) - isInitiator = false - } else { - return nil - } - - // zero handshake - - setZero(handshake.chainKey[:]) - setZero(handshake.localEphemeral[:]) - peer.handshake.state = HandshakeZeroed - - // create AEAD instances - - keyPair := new(KeyPair) - keyPair.send, _ = chacha20poly1305.New(sendKey[:]) - keyPair.receive, _ = chacha20poly1305.New(recvKey[:]) - - setZero(sendKey[:]) - setZero(recvKey[:]) - - keyPair.created = time.Now() - keyPair.sendNonce = 0 - keyPair.replayFilter.Init() - keyPair.isInitiator = isInitiator - keyPair.localIndex = peer.handshake.localIndex - keyPair.remoteIndex = peer.handshake.remoteIndex - - // remap index - - device.indices.Insert( - handshake.localIndex, - IndexTableEntry{ - peer: peer, - keyPair: keyPair, - handshake: nil, - }, - ) - handshake.localIndex = 0 - - // rotate key pairs - - kp := &peer.keyPairs - kp.mutex.Lock() - - if isInitiator { - if kp.previous != nil { - device.DeleteKeyPair(kp.previous) - kp.previous = nil - } - - if kp.next != nil { - kp.previous = kp.next - kp.next = keyPair - } else { - kp.previous = kp.current - kp.current = keyPair - peer.signal.newKeyPair.Send() - } - - } else { - kp.next = keyPair - kp.previous = nil - } - kp.mutex.Unlock() - - return keyPair -} diff --git a/src/noise_test.go b/src/noise_test.go deleted file mode 100644 index 5e9d44b..0000000 --- a/src/noise_test.go +++ /dev/null @@ -1,136 +0,0 @@ -package main - -import ( - "bytes" - "encoding/binary" - "testing" -) - -func TestCurveWrappers(t *testing.T) { - sk1, err := newPrivateKey() - assertNil(t, err) - - sk2, err := newPrivateKey() - assertNil(t, err) - - pk1 := sk1.publicKey() - pk2 := sk2.publicKey() - - ss1 := sk1.sharedSecret(pk2) - ss2 := sk2.sharedSecret(pk1) - - if ss1 != ss2 { - t.Fatal("Failed to compute shared secet") - } -} - -func TestNoiseHandshake(t *testing.T) { - dev1 := randDevice(t) - dev2 := randDevice(t) - - defer dev1.Close() - defer dev2.Close() - - peer1, _ := dev2.NewPeer(dev1.noise.privateKey.publicKey()) - peer2, _ := dev1.NewPeer(dev2.noise.privateKey.publicKey()) - - assertEqual( - t, - peer1.handshake.precomputedStaticStatic[:], - peer2.handshake.precomputedStaticStatic[:], - ) - - /* simulate handshake */ - - // initiation message - - t.Log("exchange initiation message") - - msg1, err := dev1.CreateMessageInitiation(peer2) - assertNil(t, err) - - packet := make([]byte, 0, 256) - writer := bytes.NewBuffer(packet) - err = binary.Write(writer, binary.LittleEndian, msg1) - peer := dev2.ConsumeMessageInitiation(msg1) - if peer == nil { - t.Fatal("handshake failed at initiation message") - } - - assertEqual( - t, - peer1.handshake.chainKey[:], - peer2.handshake.chainKey[:], - ) - - assertEqual( - t, - peer1.handshake.hash[:], - peer2.handshake.hash[:], - ) - - // response message - - t.Log("exchange response message") - - msg2, err := dev2.CreateMessageResponse(peer1) - assertNil(t, err) - - peer = dev1.ConsumeMessageResponse(msg2) - if peer == nil { - t.Fatal("handshake failed at response message") - } - - assertEqual( - t, - peer1.handshake.chainKey[:], - peer2.handshake.chainKey[:], - ) - - assertEqual( - t, - peer1.handshake.hash[:], - peer2.handshake.hash[:], - ) - - // key pairs - - t.Log("deriving keys") - - key1 := peer1.NewKeyPair() - key2 := peer2.NewKeyPair() - - if key1 == nil { - t.Fatal("failed to dervice key-pair for peer 1") - } - - if key2 == nil { - t.Fatal("failed to dervice key-pair for peer 2") - } - - // encrypting / decryption test - - t.Log("test key pairs") - - func() { - testMsg := []byte("wireguard test message 1") - var err error - var out []byte - var nonce [12]byte - out = key1.send.Seal(out, nonce[:], testMsg, nil) - out, err = key2.receive.Open(out[:0], nonce[:], out, nil) - assertNil(t, err) - assertEqual(t, out, testMsg) - }() - - func() { - testMsg := []byte("wireguard test message 2") - var err error - var out []byte - var nonce [12]byte - out = key2.send.Seal(out, nonce[:], testMsg, nil) - out, err = key1.receive.Open(out[:0], nonce[:], out, nil) - assertNil(t, err) - assertEqual(t, out, testMsg) - }() -} diff --git a/src/noise_types.go b/src/noise_types.go deleted file mode 100644 index 1a944df..0000000 --- a/src/noise_types.go +++ /dev/null @@ -1,74 +0,0 @@ -package main - -import ( - "crypto/subtle" - "encoding/hex" - "errors" - "golang.org/x/crypto/chacha20poly1305" -) - -const ( - NoisePublicKeySize = 32 - NoisePrivateKeySize = 32 -) - -type ( - NoisePublicKey [NoisePublicKeySize]byte - NoisePrivateKey [NoisePrivateKeySize]byte - NoiseSymmetricKey [chacha20poly1305.KeySize]byte - NoiseNonce uint64 // padded to 12-bytes -) - -func loadExactHex(dst []byte, src string) error { - slice, err := hex.DecodeString(src) - if err != nil { - return err - } - if len(slice) != len(dst) { - return errors.New("Hex string does not fit the slice") - } - copy(dst, slice) - return nil -} - -func (key NoisePrivateKey) IsZero() bool { - var zero NoisePrivateKey - return key.Equals(zero) -} - -func (key NoisePrivateKey) Equals(tar NoisePrivateKey) bool { - return subtle.ConstantTimeCompare(key[:], tar[:]) == 1 -} - -func (key *NoisePrivateKey) FromHex(src string) error { - return loadExactHex(key[:], src) -} - -func (key NoisePrivateKey) ToHex() string { - return hex.EncodeToString(key[:]) -} - -func (key *NoisePublicKey) FromHex(src string) error { - return loadExactHex(key[:], src) -} - -func (key NoisePublicKey) ToHex() string { - return hex.EncodeToString(key[:]) -} - -func (key NoisePublicKey) IsZero() bool { - var zero NoisePublicKey - return key.Equals(zero) -} - -func (key NoisePublicKey) Equals(tar NoisePublicKey) bool { - return subtle.ConstantTimeCompare(key[:], tar[:]) == 1 -} - -func (key *NoiseSymmetricKey) FromHex(src string) error { - return loadExactHex(key[:], src) -} - -func (key NoiseSymmetricKey) ToHex() string { - return hex.EncodeToString(key[:]) -} diff --git a/src/peer.go b/src/peer.go deleted file mode 100644 index dc04811..0000000 --- a/src/peer.go +++ /dev/null @@ -1,295 +0,0 @@ -package main - -import ( - "encoding/base64" - "errors" - "fmt" - "github.com/sasha-s/go-deadlock" - "sync" - "time" -) - -const ( - PeerRoutineNumber = 4 -) - -type Peer struct { - isRunning AtomicBool - mutex deadlock.RWMutex - persistentKeepaliveInterval uint64 - keyPairs KeyPairs - handshake Handshake - device *Device - endpoint Endpoint - - stats struct { - txBytes uint64 // bytes send to peer (endpoint) - rxBytes uint64 // bytes received from peer - lastHandshakeNano int64 // nano seconds since epoch - } - - time struct { - mutex deadlock.RWMutex - lastSend time.Time // last send message - lastHandshake time.Time // last completed handshake - nextKeepalive time.Time - } - - signal struct { - newKeyPair Signal // size 1, new key pair was generated - handshakeCompleted Signal // size 1, handshake completed - handshakeBegin Signal // size 1, begin new handshake begin - flushNonceQueue Signal // size 1, empty queued packets - messageSend Signal // size 1, message was send to peer - messageReceived Signal // size 1, authenticated message recv - } - - timer struct { - - // state related to WireGuard timers - - keepalivePersistent Timer // set for persistent keep-alive - keepalivePassive Timer // set upon receiving messages - zeroAllKeys Timer // zero all key material - handshakeNew Timer // begin a new handshake (stale) - handshakeDeadline Timer // complete handshake timeout - handshakeTimeout Timer // current handshake message timeout - - sendLastMinuteHandshake bool - needAnotherKeepalive bool - } - - queue struct { - nonce chan *QueueOutboundElement // nonce / pre-handshake queue - outbound chan *QueueOutboundElement // sequential ordering of work - inbound chan *QueueInboundElement // sequential ordering of work - } - - routines struct { - mutex deadlock.Mutex // held when stopping / starting routines - starting sync.WaitGroup // routines pending start - stopping sync.WaitGroup // routines pending stop - stop Signal // size 0, stop all go-routines in peer - } - - mac CookieGenerator -} - -func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) { - - if device.isClosed.Get() { - return nil, errors.New("Device closed") - } - - // lock resources - - device.state.mutex.Lock() - defer device.state.mutex.Unlock() - - device.noise.mutex.RLock() - defer device.noise.mutex.RUnlock() - - device.peers.mutex.Lock() - defer device.peers.mutex.Unlock() - - // check if over limit - - if len(device.peers.keyMap) >= MaxPeers { - return nil, errors.New("Too many peers") - } - - // create peer - - peer := new(Peer) - peer.mutex.Lock() - defer peer.mutex.Unlock() - - peer.mac.Init(pk) - peer.device = device - peer.isRunning.Set(false) - - peer.timer.zeroAllKeys = NewTimer() - peer.timer.keepalivePersistent = NewTimer() - peer.timer.keepalivePassive = NewTimer() - peer.timer.handshakeNew = NewTimer() - peer.timer.handshakeDeadline = NewTimer() - peer.timer.handshakeTimeout = NewTimer() - - // map public key - - _, ok := device.peers.keyMap[pk] - if ok { - return nil, errors.New("Adding existing peer") - } - device.peers.keyMap[pk] = peer - - // pre-compute DH - - handshake := &peer.handshake - handshake.mutex.Lock() - handshake.remoteStatic = pk - handshake.precomputedStaticStatic = device.noise.privateKey.sharedSecret(pk) - handshake.mutex.Unlock() - - // reset endpoint - - peer.endpoint = nil - - // prepare signaling & routines - - peer.routines.mutex.Lock() - peer.routines.stop = NewSignal() - peer.routines.mutex.Unlock() - - // start peer - - if peer.device.isUp.Get() { - peer.Start() - } - - return peer, nil -} - -func (peer *Peer) SendBuffer(buffer []byte) error { - peer.device.net.mutex.RLock() - defer peer.device.net.mutex.RUnlock() - - if peer.device.net.bind == nil { - return errors.New("No bind") - } - - peer.mutex.RLock() - defer peer.mutex.RUnlock() - - if peer.endpoint == nil { - return errors.New("No known endpoint for peer") - } - - return peer.device.net.bind.Send(buffer, peer.endpoint) -} - -/* Returns a short string identifier for logging - */ -func (peer *Peer) String() string { - if peer.endpoint == nil { - return fmt.Sprintf( - "peer(unknown %s)", - base64.StdEncoding.EncodeToString(peer.handshake.remoteStatic[:]), - ) - } - return fmt.Sprintf( - "peer(%s %s)", - peer.endpoint.DstToString(), - base64.StdEncoding.EncodeToString(peer.handshake.remoteStatic[:]), - ) -} - -func (peer *Peer) Start() { - - // should never start a peer on a closed device - - if peer.device.isClosed.Get() { - return - } - - // prevent simultaneous start/stop operations - - peer.routines.mutex.Lock() - defer peer.routines.mutex.Unlock() - - if peer.isRunning.Get() { - return - } - - peer.device.log.Debug.Println("Starting:", peer.String()) - - // sanity check : these should be 0 - - peer.routines.starting.Wait() - peer.routines.stopping.Wait() - - // prepare queues and signals - - peer.signal.newKeyPair = NewSignal() - peer.signal.handshakeBegin = NewSignal() - peer.signal.handshakeCompleted = NewSignal() - peer.signal.flushNonceQueue = NewSignal() - - peer.queue.nonce = make(chan *QueueOutboundElement, QueueOutboundSize) - peer.queue.outbound = make(chan *QueueOutboundElement, QueueOutboundSize) - peer.queue.inbound = make(chan *QueueInboundElement, QueueInboundSize) - - peer.routines.stop = NewSignal() - peer.isRunning.Set(true) - - // wait for routines to start - - peer.routines.starting.Add(PeerRoutineNumber) - peer.routines.stopping.Add(PeerRoutineNumber) - - go peer.RoutineNonce() - go peer.RoutineTimerHandler() - go peer.RoutineSequentialSender() - go peer.RoutineSequentialReceiver() - - peer.routines.starting.Wait() - peer.isRunning.Set(true) -} - -func (peer *Peer) Stop() { - - // prevent simultaneous start/stop operations - - peer.routines.mutex.Lock() - defer peer.routines.mutex.Unlock() - - if !peer.isRunning.Swap(false) { - return - } - - device := peer.device - device.log.Debug.Println("Stopping:", peer.String()) - - // stop & wait for ongoing peer routines - - peer.routines.stop.Broadcast() - peer.routines.starting.Wait() - peer.routines.stopping.Wait() - - // stop timers - - peer.timer.keepalivePersistent.Stop() - peer.timer.keepalivePassive.Stop() - peer.timer.zeroAllKeys.Stop() - peer.timer.handshakeNew.Stop() - peer.timer.handshakeDeadline.Stop() - peer.timer.handshakeTimeout.Stop() - - // close queues - - close(peer.queue.nonce) - close(peer.queue.outbound) - close(peer.queue.inbound) - - // clear key pairs - - kp := &peer.keyPairs - kp.mutex.Lock() - - device.DeleteKeyPair(kp.previous) - device.DeleteKeyPair(kp.current) - device.DeleteKeyPair(kp.next) - - kp.previous = nil - kp.current = nil - kp.next = nil - kp.mutex.Unlock() - - // clear handshake state - - hs := &peer.handshake - hs.mutex.Lock() - device.indices.Delete(hs.localIndex) - hs.Clear() - hs.mutex.Unlock() -} diff --git a/src/ratelimiter.go b/src/ratelimiter.go deleted file mode 100644 index 6e5f005..0000000 --- a/src/ratelimiter.go +++ /dev/null @@ -1,139 +0,0 @@ -package main - -/* Copyright (C) 2015-2017 Jason A. Donenfeld <Jason@zx2c4.com>. All Rights Reserved. */ - -/* This file contains a port of the ratelimited from the linux kernel version - */ - -import ( - "net" - "sync" - "time" -) - -const ( - RatelimiterPacketsPerSecond = 20 - RatelimiterPacketsBurstable = 5 - RatelimiterGarbageCollectTime = time.Second - RatelimiterPacketCost = 1000000000 / RatelimiterPacketsPerSecond - RatelimiterMaxTokens = RatelimiterPacketCost * RatelimiterPacketsBurstable -) - -type RatelimiterEntry struct { - mutex sync.Mutex - lastTime time.Time - tokens int64 -} - -type Ratelimiter struct { - mutex sync.RWMutex - lastGarbageCollect time.Time - tableIPv4 map[[net.IPv4len]byte]*RatelimiterEntry - tableIPv6 map[[net.IPv6len]byte]*RatelimiterEntry -} - -func (rate *Ratelimiter) Init() { - rate.mutex.Lock() - defer rate.mutex.Unlock() - rate.tableIPv4 = make(map[[net.IPv4len]byte]*RatelimiterEntry) - rate.tableIPv6 = make(map[[net.IPv6len]byte]*RatelimiterEntry) - rate.lastGarbageCollect = time.Now() -} - -func (rate *Ratelimiter) GarbageCollectEntries() { - rate.mutex.Lock() - - // remove unused IPv4 entries - - for key, entry := range rate.tableIPv4 { - entry.mutex.Lock() - if time.Now().Sub(entry.lastTime) > RatelimiterGarbageCollectTime { - delete(rate.tableIPv4, key) - } - entry.mutex.Unlock() - } - - // remove unused IPv6 entries - - for key, entry := range rate.tableIPv6 { - entry.mutex.Lock() - if time.Now().Sub(entry.lastTime) > RatelimiterGarbageCollectTime { - delete(rate.tableIPv6, key) - } - entry.mutex.Unlock() - } - - rate.mutex.Unlock() -} - -func (rate *Ratelimiter) RoutineGarbageCollector(stop Signal) { - timer := time.NewTimer(time.Second) - for { - select { - case <-stop.Wait(): - return - case <-timer.C: - rate.GarbageCollectEntries() - timer.Reset(time.Second) - } - } -} - -func (rate *Ratelimiter) Allow(ip net.IP) bool { - var entry *RatelimiterEntry - var KeyIPv4 [net.IPv4len]byte - var KeyIPv6 [net.IPv6len]byte - - // lookup entry - - IPv4 := ip.To4() - IPv6 := ip.To16() - - rate.mutex.RLock() - - if IPv4 != nil { - copy(KeyIPv4[:], IPv4) - entry = rate.tableIPv4[KeyIPv4] - } else { - copy(KeyIPv6[:], IPv6) - entry = rate.tableIPv6[KeyIPv6] - } - - rate.mutex.RUnlock() - - // make new entry if not found - - if entry == nil { - rate.mutex.Lock() - entry = new(RatelimiterEntry) - entry.tokens = RatelimiterMaxTokens - RatelimiterPacketCost - entry.lastTime = time.Now() - if IPv4 != nil { - rate.tableIPv4[KeyIPv4] = entry - } else { - rate.tableIPv6[KeyIPv6] = entry - } - rate.mutex.Unlock() - return true - } - - // add tokens to entry - - entry.mutex.Lock() - now := time.Now() - entry.tokens += now.Sub(entry.lastTime).Nanoseconds() - entry.lastTime = now - if entry.tokens > RatelimiterMaxTokens { - entry.tokens = RatelimiterMaxTokens - } - - // subtract cost of packet - - if entry.tokens > RatelimiterPacketCost { - entry.tokens -= RatelimiterPacketCost - entry.mutex.Unlock() - return true - } - entry.mutex.Unlock() - return false -} diff --git a/src/ratelimiter_test.go b/src/ratelimiter_test.go deleted file mode 100644 index 13b6a23..0000000 --- a/src/ratelimiter_test.go +++ /dev/null @@ -1,98 +0,0 @@ -package main - -import ( - "net" - "testing" - "time" -) - -type RatelimiterResult struct { - allowed bool - text string - wait time.Duration -} - -func TestRatelimiter(t *testing.T) { - - var ratelimiter Ratelimiter - var expectedResults []RatelimiterResult - - Nano := func(nano int64) time.Duration { - return time.Nanosecond * time.Duration(nano) - } - - Add := func(res RatelimiterResult) { - expectedResults = append( - expectedResults, - res, - ) - } - - for i := 0; i < RatelimiterPacketsBurstable; i++ { - Add(RatelimiterResult{ - allowed: true, - text: "inital burst", - }) - } - - Add(RatelimiterResult{ - allowed: false, - text: "after burst", - }) - - Add(RatelimiterResult{ - allowed: true, - wait: Nano(time.Second.Nanoseconds() / RatelimiterPacketsPerSecond), - text: "filling tokens for single packet", - }) - - Add(RatelimiterResult{ - allowed: false, - text: "not having refilled enough", - }) - - Add(RatelimiterResult{ - allowed: true, - wait: 2 * Nano(time.Second.Nanoseconds()/RatelimiterPacketsPerSecond), - text: "filling tokens for two packet burst", - }) - - Add(RatelimiterResult{ - allowed: true, - text: "second packet in 2 packet burst", - }) - - Add(RatelimiterResult{ - allowed: false, - text: "packet following 2 packet burst", - }) - - ips := []net.IP{ - net.ParseIP("127.0.0.1"), - net.ParseIP("192.168.1.1"), - net.ParseIP("172.167.2.3"), - net.ParseIP("97.231.252.215"), - net.ParseIP("248.97.91.167"), - net.ParseIP("188.208.233.47"), - net.ParseIP("104.2.183.179"), - net.ParseIP("72.129.46.120"), - net.ParseIP("2001:0db8:0a0b:12f0:0000:0000:0000:0001"), - net.ParseIP("f5c2:818f:c052:655a:9860:b136:6894:25f0"), - net.ParseIP("b2d7:15ab:48a7:b07c:a541:f144:a9fe:54fc"), - net.ParseIP("a47b:786e:1671:a22b:d6f9:4ab0:abc7:c918"), - net.ParseIP("ea1e:d155:7f7a:98fb:2bf5:9483:80f6:5445"), - net.ParseIP("3f0e:54a2:f5b4:cd19:a21d:58e1:3746:84c4"), - } - - ratelimiter.Init() - - for i, res := range expectedResults { - time.Sleep(res.wait) - for _, ip := range ips { - allowed := ratelimiter.Allow(ip) - if allowed != res.allowed { - t.Fatal("Test failed for", ip.String(), ", on:", i, "(", res.text, ")", "expected:", res.allowed, "got:", allowed) - } - } - } -} diff --git a/src/receive.go b/src/receive.go deleted file mode 100644 index 1f44df2..0000000 --- a/src/receive.go +++ /dev/null @@ -1,642 +0,0 @@ -package main - -import ( - "bytes" - "encoding/binary" - "golang.org/x/crypto/chacha20poly1305" - "golang.org/x/net/ipv4" - "golang.org/x/net/ipv6" - "net" - "sync" - "sync/atomic" - "time" -) - -type QueueHandshakeElement struct { - msgType uint32 - packet []byte - endpoint Endpoint - buffer *[MaxMessageSize]byte -} - -type QueueInboundElement struct { - dropped int32 - mutex sync.Mutex - buffer *[MaxMessageSize]byte - packet []byte - counter uint64 - keyPair *KeyPair - endpoint Endpoint -} - -func (elem *QueueInboundElement) Drop() { - atomic.StoreInt32(&elem.dropped, AtomicTrue) -} - -func (elem *QueueInboundElement) IsDropped() bool { - return atomic.LoadInt32(&elem.dropped) == AtomicTrue -} - -func (device *Device) addToInboundQueue( - queue chan *QueueInboundElement, - element *QueueInboundElement, -) { - for { - select { - case queue <- element: - return - default: - select { - case old := <-queue: - old.Drop() - default: - } - } - } -} - -func (device *Device) addToDecryptionQueue( - queue chan *QueueInboundElement, - element *QueueInboundElement, -) { - for { - select { - case queue <- element: - return - default: - select { - case old := <-queue: - // drop & release to potential consumer - old.Drop() - old.mutex.Unlock() - default: - } - } - } -} - -func (device *Device) addToHandshakeQueue( - queue chan QueueHandshakeElement, - element QueueHandshakeElement, -) { - for { - select { - case queue <- element: - return - default: - select { - case elem := <-queue: - device.PutMessageBuffer(elem.buffer) - default: - } - } - } -} - -/* Receives incoming datagrams for the device - * - * Every time the bind is updated a new routine is started for - * IPv4 and IPv6 (separately) - */ -func (device *Device) RoutineReceiveIncoming(IP int, bind Bind) { - - logDebug := device.log.Debug - logDebug.Println("Routine, receive incoming, IP version:", IP) - - // receive datagrams until conn is closed - - buffer := device.GetMessageBuffer() - - var ( - err error - size int - endpoint Endpoint - ) - - for { - - // read next datagram - - switch IP { - case ipv4.Version: - size, endpoint, err = bind.ReceiveIPv4(buffer[:]) - case ipv6.Version: - size, endpoint, err = bind.ReceiveIPv6(buffer[:]) - default: - panic("invalid IP version") - } - - if err != nil { - return - } - - if size < MinMessageSize { - continue - } - - // check size of packet - - packet := buffer[:size] - msgType := binary.LittleEndian.Uint32(packet[:4]) - - var okay bool - - switch msgType { - - // check if transport - - case MessageTransportType: - - // check size - - if len(packet) < MessageTransportType { - continue - } - - // lookup key pair - - receiver := binary.LittleEndian.Uint32( - packet[MessageTransportOffsetReceiver:MessageTransportOffsetCounter], - ) - value := device.indices.Lookup(receiver) - keyPair := value.keyPair - if keyPair == nil { - continue - } - - // check key-pair expiry - - if keyPair.created.Add(RejectAfterTime).Before(time.Now()) { - continue - } - - // create work element - - peer := value.peer - elem := &QueueInboundElement{ - packet: packet, - buffer: buffer, - keyPair: keyPair, - dropped: AtomicFalse, - endpoint: endpoint, - } - elem.mutex.Lock() - - // add to decryption queues - - if peer.isRunning.Get() { - device.addToDecryptionQueue(device.queue.decryption, elem) - device.addToInboundQueue(peer.queue.inbound, elem) - buffer = device.GetMessageBuffer() - } - - continue - - // otherwise it is a fixed size & handshake related packet - - case MessageInitiationType: - okay = len(packet) == MessageInitiationSize - - case MessageResponseType: - okay = len(packet) == MessageResponseSize - - case MessageCookieReplyType: - okay = len(packet) == MessageCookieReplySize - } - - if okay { - device.addToHandshakeQueue( - device.queue.handshake, - QueueHandshakeElement{ - msgType: msgType, - buffer: buffer, - packet: packet, - endpoint: endpoint, - }, - ) - buffer = device.GetMessageBuffer() - } - } -} - -func (device *Device) RoutineDecryption() { - - var nonce [chacha20poly1305.NonceSize]byte - - logDebug := device.log.Debug - logDebug.Println("Routine, decryption, started for device") - - for { - select { - case <-device.signal.stop.Wait(): - logDebug.Println("Routine, decryption worker, stopped") - return - - case elem := <-device.queue.decryption: - - // check if dropped - - if elem.IsDropped() { - continue - } - - // split message into fields - - counter := elem.packet[MessageTransportOffsetCounter:MessageTransportOffsetContent] - content := elem.packet[MessageTransportOffsetContent:] - - // expand nonce - - nonce[0x4] = counter[0x0] - nonce[0x5] = counter[0x1] - nonce[0x6] = counter[0x2] - nonce[0x7] = counter[0x3] - - nonce[0x8] = counter[0x4] - nonce[0x9] = counter[0x5] - nonce[0xa] = counter[0x6] - nonce[0xb] = counter[0x7] - - // decrypt and release to consumer - - var err error - elem.counter = binary.LittleEndian.Uint64(counter) - elem.packet, err = elem.keyPair.receive.Open( - content[:0], - nonce[:], - content, - nil, - ) - if err != nil { - elem.Drop() - } - elem.mutex.Unlock() - } - } -} - -/* Handles incoming packets related to handshake - */ -func (device *Device) RoutineHandshake() { - - logInfo := device.log.Info - logError := device.log.Error - logDebug := device.log.Debug - logDebug.Println("Routine, handshake routine, started for device") - - var temp [MessageHandshakeSize]byte - var elem QueueHandshakeElement - - for { - select { - case elem = <-device.queue.handshake: - case <-device.signal.stop.Wait(): - return - } - - // handle cookie fields and ratelimiting - - switch elem.msgType { - - case MessageCookieReplyType: - - // unmarshal packet - - var reply MessageCookieReply - reader := bytes.NewReader(elem.packet) - err := binary.Read(reader, binary.LittleEndian, &reply) - if err != nil { - logDebug.Println("Failed to decode cookie reply") - return - } - - // lookup peer from index - - entry := device.indices.Lookup(reply.Receiver) - - if entry.peer == nil { - continue - } - - // consume reply - - if peer := entry.peer; peer.isRunning.Get() { - peer.mac.ConsumeReply(&reply) - } - - continue - - case MessageInitiationType, MessageResponseType: - - // check mac fields and ratelimit - - if !device.mac.CheckMAC1(elem.packet) { - logDebug.Println("Received packet with invalid mac1") - continue - } - - // endpoints destination address is the source of the datagram - - srcBytes := elem.endpoint.DstToBytes() - - if device.IsUnderLoad() { - - // verify MAC2 field - - if !device.mac.CheckMAC2(elem.packet, srcBytes) { - - // construct cookie reply - - logDebug.Println( - "Sending cookie reply to:", - elem.endpoint.DstToString(), - ) - - sender := binary.LittleEndian.Uint32(elem.packet[4:8]) - reply, err := device.mac.CreateReply(elem.packet, sender, srcBytes) - if err != nil { - logError.Println("Failed to create cookie reply:", err) - continue - } - - // marshal and send reply - - writer := bytes.NewBuffer(temp[:0]) - binary.Write(writer, binary.LittleEndian, reply) - device.net.bind.Send(writer.Bytes(), elem.endpoint) - if err != nil { - logDebug.Println("Failed to send cookie reply:", err) - } - continue - } - - // check ratelimiter - - if !device.rate.limiter.Allow(elem.endpoint.DstIP()) { - continue - } - } - - default: - logError.Println("Invalid packet ended up in the handshake queue") - continue - } - - // handle handshake initiation/response content - - switch elem.msgType { - case MessageInitiationType: - - // unmarshal - - var msg MessageInitiation - reader := bytes.NewReader(elem.packet) - err := binary.Read(reader, binary.LittleEndian, &msg) - if err != nil { - logError.Println("Failed to decode initiation message") - continue - } - - // consume initiation - - peer := device.ConsumeMessageInitiation(&msg) - if peer == nil { - logInfo.Println( - "Received invalid initiation message from", - elem.endpoint.DstToString(), - ) - continue - } - - // update timers - - peer.TimerAnyAuthenticatedPacketTraversal() - peer.TimerAnyAuthenticatedPacketReceived() - - // update endpoint - - peer.mutex.Lock() - peer.endpoint = elem.endpoint - peer.mutex.Unlock() - - // create response - - response, err := device.CreateMessageResponse(peer) - if err != nil { - logError.Println("Failed to create response message:", err) - continue - } - - peer.TimerEphemeralKeyCreated() - peer.NewKeyPair() - - logDebug.Println("Creating response message for", peer.String()) - - writer := bytes.NewBuffer(temp[:0]) - binary.Write(writer, binary.LittleEndian, response) - packet := writer.Bytes() - peer.mac.AddMacs(packet) - - // send response - - err = peer.SendBuffer(packet) - if err == nil { - peer.TimerAnyAuthenticatedPacketTraversal() - } else { - logError.Println("Failed to send response to:", peer.String(), err) - } - - case MessageResponseType: - - // unmarshal - - var msg MessageResponse - reader := bytes.NewReader(elem.packet) - err := binary.Read(reader, binary.LittleEndian, &msg) - if err != nil { - logError.Println("Failed to decode response message") - continue - } - - // consume response - - peer := device.ConsumeMessageResponse(&msg) - if peer == nil { - logInfo.Println( - "Recieved invalid response message from", - elem.endpoint.DstToString(), - ) - continue - } - - // update endpoint - - peer.mutex.Lock() - peer.endpoint = elem.endpoint - peer.mutex.Unlock() - - logDebug.Println("Received handshake initiation from", peer) - - peer.TimerEphemeralKeyCreated() - - // update timers - - peer.TimerAnyAuthenticatedPacketTraversal() - peer.TimerAnyAuthenticatedPacketReceived() - peer.TimerHandshakeComplete() - - // derive key-pair - - peer.NewKeyPair() - peer.SendKeepAlive() - } - } -} - -func (peer *Peer) RoutineSequentialReceiver() { - - defer peer.routines.stopping.Done() - - device := peer.device - - logInfo := device.log.Info - logError := device.log.Error - logDebug := device.log.Debug - logDebug.Println("Routine, sequential receiver, started for peer", peer.String()) - - peer.routines.starting.Done() - - for { - - select { - - case <-peer.routines.stop.Wait(): - logDebug.Println("Routine, sequential receiver, stopped for peer", peer.String()) - return - - case elem := <-peer.queue.inbound: - - // wait for decryption - - elem.mutex.Lock() - - if elem.IsDropped() { - continue - } - - // check for replay - - if !elem.keyPair.replayFilter.ValidateCounter(elem.counter) { - continue - } - - peer.TimerAnyAuthenticatedPacketTraversal() - peer.TimerAnyAuthenticatedPacketReceived() - peer.KeepKeyFreshReceiving() - - // check if using new key-pair - - kp := &peer.keyPairs - kp.mutex.Lock() - if kp.next == elem.keyPair { - peer.TimerHandshakeComplete() - if kp.previous != nil { - device.DeleteKeyPair(kp.previous) - } - kp.previous = kp.current - kp.current = kp.next - kp.next = nil - } - kp.mutex.Unlock() - - // update endpoint - - peer.mutex.Lock() - peer.endpoint = elem.endpoint - peer.mutex.Unlock() - - // check for keep-alive - - if len(elem.packet) == 0 { - logDebug.Println("Received keep-alive from", peer.String()) - continue - } - peer.TimerDataReceived() - - // verify source and strip padding - - switch elem.packet[0] >> 4 { - case ipv4.Version: - - // strip padding - - if len(elem.packet) < ipv4.HeaderLen { - continue - } - - field := elem.packet[IPv4offsetTotalLength : IPv4offsetTotalLength+2] - length := binary.BigEndian.Uint16(field) - if int(length) > len(elem.packet) || int(length) < ipv4.HeaderLen { - continue - } - - elem.packet = elem.packet[:length] - - // verify IPv4 source - - src := elem.packet[IPv4offsetSrc : IPv4offsetSrc+net.IPv4len] - if device.routing.table.LookupIPv4(src) != peer { - logInfo.Println( - "IPv4 packet with disallowed source address from", - peer.String(), - ) - continue - } - - case ipv6.Version: - - // strip padding - - if len(elem.packet) < ipv6.HeaderLen { - continue - } - - field := elem.packet[IPv6offsetPayloadLength : IPv6offsetPayloadLength+2] - length := binary.BigEndian.Uint16(field) - length += ipv6.HeaderLen - if int(length) > len(elem.packet) { - continue - } - - elem.packet = elem.packet[:length] - - // verify IPv6 source - - src := elem.packet[IPv6offsetSrc : IPv6offsetSrc+net.IPv6len] - if device.routing.table.LookupIPv6(src) != peer { - logInfo.Println( - "IPv6 packet with disallowed source address from", - peer.String(), - ) - continue - } - - default: - logInfo.Println("Packet with invalid IP version from", peer.String()) - continue - } - - // write to tun device - - offset := MessageTransportOffsetContent - atomic.AddUint64(&peer.stats.rxBytes, uint64(len(elem.packet))) - _, err := device.tun.device.Write( - elem.buffer[:offset+len(elem.packet)], - offset) - device.PutMessageBuffer(elem.buffer) - if err != nil { - logError.Println("Failed to write packet to TUN device:", err) - } - } - } -} diff --git a/src/replay.go b/src/replay.go deleted file mode 100644 index 5d42860..0000000 --- a/src/replay.go +++ /dev/null @@ -1,73 +0,0 @@ -package main - -/* Copyright (C) 2015-2017 Jason A. Donenfeld <Jason@zx2c4.com>. All Rights Reserved. */ - -/* Implementation of RFC6479 - * https://tools.ietf.org/html/rfc6479 - * - * The implementation is not safe for concurrent use! - */ - -const ( - // See: https://golang.org/src/math/big/arith.go - _Wordm = ^uintptr(0) - _WordLogSize = _Wordm>>8&1 + _Wordm>>16&1 + _Wordm>>32&1 - _WordSize = 1 << _WordLogSize -) - -const ( - CounterRedundantBitsLog = _WordLogSize + 3 - CounterRedundantBits = _WordSize * 8 - CounterBitsTotal = 2048 - CounterWindowSize = uint64(CounterBitsTotal - CounterRedundantBits) -) - -const ( - BacktrackWords = CounterBitsTotal / _WordSize -) - -type ReplayFilter struct { - counter uint64 - backtrack [BacktrackWords]uintptr -} - -func (filter *ReplayFilter) Init() { - filter.counter = 0 - filter.backtrack[0] = 0 -} - -func (filter *ReplayFilter) ValidateCounter(counter uint64) bool { - if counter >= RejectAfterMessages { - return false - } - - indexWord := counter >> CounterRedundantBitsLog - - if counter > filter.counter { - - // move window forward - - current := filter.counter >> CounterRedundantBitsLog - diff := minUint64(indexWord-current, BacktrackWords) - for i := uint64(1); i <= diff; i++ { - filter.backtrack[(current+i)%BacktrackWords] = 0 - } - filter.counter = counter - - } else if filter.counter-counter > CounterWindowSize { - - // behind current window - - return false - } - - indexWord %= BacktrackWords - indexBit := counter & uint64(CounterRedundantBits-1) - - // check and set bit - - oldValue := filter.backtrack[indexWord] - newValue := oldValue | (1 << indexBit) - filter.backtrack[indexWord] = newValue - return oldValue != newValue -} diff --git a/src/replay_test.go b/src/replay_test.go deleted file mode 100644 index 228fce6..0000000 --- a/src/replay_test.go +++ /dev/null @@ -1,112 +0,0 @@ -package main - -import ( - "testing" -) - -/* Ported from the linux kernel implementation - * - * - */ - -func TestReplay(t *testing.T) { - var filter ReplayFilter - - T_LIM := CounterWindowSize + 1 - - testNumber := 0 - T := func(n uint64, v bool) { - testNumber++ - if filter.ValidateCounter(n) != v { - t.Fatal("Test", testNumber, "failed", n, v) - } - } - - filter.Init() - - /* 1 */ T(0, true) - /* 2 */ T(1, true) - /* 3 */ T(1, false) - /* 4 */ T(9, true) - /* 5 */ T(8, true) - /* 6 */ T(7, true) - /* 7 */ T(7, false) - /* 8 */ T(T_LIM, true) - /* 9 */ T(T_LIM-1, true) - /* 10 */ T(T_LIM-1, false) - /* 11 */ T(T_LIM-2, true) - /* 12 */ T(2, true) - /* 13 */ T(2, false) - /* 14 */ T(T_LIM+16, true) - /* 15 */ T(3, false) - /* 16 */ T(T_LIM+16, false) - /* 17 */ T(T_LIM*4, true) - /* 18 */ T(T_LIM*4-(T_LIM-1), true) - /* 19 */ T(10, false) - /* 20 */ T(T_LIM*4-T_LIM, false) - /* 21 */ T(T_LIM*4-(T_LIM+1), false) - /* 22 */ T(T_LIM*4-(T_LIM-2), true) - /* 23 */ T(T_LIM*4+1-T_LIM, false) - /* 24 */ T(0, false) - /* 25 */ T(RejectAfterMessages, false) - /* 26 */ T(RejectAfterMessages-1, true) - /* 27 */ T(RejectAfterMessages, false) - /* 28 */ T(RejectAfterMessages-1, false) - /* 29 */ T(RejectAfterMessages-2, true) - /* 30 */ T(RejectAfterMessages+1, false) - /* 31 */ T(RejectAfterMessages+2, false) - /* 32 */ T(RejectAfterMessages-2, false) - /* 33 */ T(RejectAfterMessages-3, true) - /* 34 */ T(0, false) - - t.Log("Bulk test 1") - filter.Init() - testNumber = 0 - for i := uint64(1); i <= CounterWindowSize; i++ { - T(i, true) - } - T(0, true) - T(0, false) - - t.Log("Bulk test 2") - filter.Init() - testNumber = 0 - for i := uint64(2); i <= CounterWindowSize+1; i++ { - T(i, true) - } - T(1, true) - T(0, false) - - t.Log("Bulk test 3") - filter.Init() - testNumber = 0 - for i := CounterWindowSize + 1; i > 0; i-- { - T(i, true) - } - - t.Log("Bulk test 4") - filter.Init() - testNumber = 0 - for i := CounterWindowSize + 2; i > 1; i-- { - T(i, true) - } - T(0, false) - - t.Log("Bulk test 5") - filter.Init() - testNumber = 0 - for i := CounterWindowSize; i > 0; i-- { - T(i, true) - } - T(CounterWindowSize+1, true) - T(0, false) - - t.Log("Bulk test 6") - filter.Init() - testNumber = 0 - for i := CounterWindowSize; i > 0; i-- { - T(i, true) - } - T(0, true) - T(CounterWindowSize+1, true) -} diff --git a/src/routing.go b/src/routing.go deleted file mode 100644 index 2a2e237..0000000 --- a/src/routing.go +++ /dev/null @@ -1,65 +0,0 @@ -package main - -import ( - "errors" - "net" - "sync" -) - -type RoutingTable struct { - IPv4 *Trie - IPv6 *Trie - mutex sync.RWMutex -} - -func (table *RoutingTable) AllowedIPs(peer *Peer) []net.IPNet { - table.mutex.RLock() - defer table.mutex.RUnlock() - - allowed := make([]net.IPNet, 0, 10) - allowed = table.IPv4.AllowedIPs(peer, allowed) - allowed = table.IPv6.AllowedIPs(peer, allowed) - return allowed -} - -func (table *RoutingTable) Reset() { - table.mutex.Lock() - defer table.mutex.Unlock() - - table.IPv4 = nil - table.IPv6 = nil -} - -func (table *RoutingTable) RemovePeer(peer *Peer) { - table.mutex.Lock() - defer table.mutex.Unlock() - - table.IPv4 = table.IPv4.RemovePeer(peer) - table.IPv6 = table.IPv6.RemovePeer(peer) -} - -func (table *RoutingTable) Insert(ip net.IP, cidr uint, peer *Peer) { - table.mutex.Lock() - defer table.mutex.Unlock() - - switch len(ip) { - case net.IPv6len: - table.IPv6 = table.IPv6.Insert(ip, cidr, peer) - case net.IPv4len: - table.IPv4 = table.IPv4.Insert(ip, cidr, peer) - default: - panic(errors.New("Inserting unknown address type")) - } -} - -func (table *RoutingTable) LookupIPv4(address []byte) *Peer { - table.mutex.RLock() - defer table.mutex.RUnlock() - return table.IPv4.Lookup(address) -} - -func (table *RoutingTable) LookupIPv6(address []byte) *Peer { - table.mutex.RLock() - defer table.mutex.RUnlock() - return table.IPv6.Lookup(address) -} diff --git a/src/send.go b/src/send.go deleted file mode 100644 index 7488d3a..0000000 --- a/src/send.go +++ /dev/null @@ -1,362 +0,0 @@ -package main - -import ( - "encoding/binary" - "golang.org/x/crypto/chacha20poly1305" - "golang.org/x/net/ipv4" - "golang.org/x/net/ipv6" - "net" - "sync" - "sync/atomic" - "time" -) - -/* Outbound flow - * - * 1. TUN queue - * 2. Routing (sequential) - * 3. Nonce assignment (sequential) - * 4. Encryption (parallel) - * 5. Transmission (sequential) - * - * The functions in this file occur (roughly) in the order in - * which the packets are processed. - * - * Locking, Producers and Consumers - * - * The order of packets (per peer) must be maintained, - * but encryption of packets happen out-of-order: - * - * The sequential consumers will attempt to take the lock, - * workers release lock when they have completed work (encryption) on the packet. - * - * If the element is inserted into the "encryption queue", - * the content is preceded by enough "junk" to contain the transport header - * (to allow the construction of transport messages in-place) - */ - -type QueueOutboundElement struct { - dropped int32 - mutex sync.Mutex - buffer *[MaxMessageSize]byte // slice holding the packet data - packet []byte // slice of "buffer" (always!) - nonce uint64 // nonce for encryption - keyPair *KeyPair // key-pair for encryption - peer *Peer // related peer -} - -func (peer *Peer) FlushNonceQueue() { - elems := len(peer.queue.nonce) - for i := 0; i < elems; i++ { - select { - case <-peer.queue.nonce: - default: - return - } - } -} - -func (device *Device) NewOutboundElement() *QueueOutboundElement { - return &QueueOutboundElement{ - dropped: AtomicFalse, - buffer: device.pool.messageBuffers.Get().(*[MaxMessageSize]byte), - } -} - -func (elem *QueueOutboundElement) Drop() { - atomic.StoreInt32(&elem.dropped, AtomicTrue) -} - -func (elem *QueueOutboundElement) IsDropped() bool { - return atomic.LoadInt32(&elem.dropped) == AtomicTrue -} - -func addToOutboundQueue( - queue chan *QueueOutboundElement, - element *QueueOutboundElement, -) { - for { - select { - case queue <- element: - return - default: - select { - case old := <-queue: - old.Drop() - default: - } - } - } -} - -func addToEncryptionQueue( - queue chan *QueueOutboundElement, - element *QueueOutboundElement, -) { - for { - select { - case queue <- element: - return - default: - select { - case old := <-queue: - // drop & release to potential consumer - old.Drop() - old.mutex.Unlock() - default: - } - } - } -} - -/* Reads packets from the TUN and inserts - * into nonce queue for peer - * - * Obs. Single instance per TUN device - */ -func (device *Device) RoutineReadFromTUN() { - - elem := device.NewOutboundElement() - - logDebug := device.log.Debug - logError := device.log.Error - - logDebug.Println("Routine, TUN Reader started") - - for { - - // read packet - - offset := MessageTransportHeaderSize - size, err := device.tun.device.Read(elem.buffer[:], offset) - - if err != nil { - logError.Println("Failed to read packet from TUN device:", err) - device.Close() - return - } - - if size == 0 || size > MaxContentSize { - continue - } - - elem.packet = elem.buffer[offset : offset+size] - - // lookup peer - - var peer *Peer - switch elem.packet[0] >> 4 { - case ipv4.Version: - if len(elem.packet) < ipv4.HeaderLen { - continue - } - dst := elem.packet[IPv4offsetDst : IPv4offsetDst+net.IPv4len] - peer = device.routing.table.LookupIPv4(dst) - - case ipv6.Version: - if len(elem.packet) < ipv6.HeaderLen { - continue - } - dst := elem.packet[IPv6offsetDst : IPv6offsetDst+net.IPv6len] - peer = device.routing.table.LookupIPv6(dst) - - default: - logDebug.Println("Received packet with unknown IP version") - } - - if peer == nil { - continue - } - - // insert into nonce/pre-handshake queue - - if peer.isRunning.Get() { - peer.timer.handshakeDeadline.Reset(RekeyAttemptTime) - addToOutboundQueue(peer.queue.nonce, elem) - elem = device.NewOutboundElement() - } - } -} - -/* Queues packets when there is no handshake. - * Then assigns nonces to packets sequentially - * and creates "work" structs for workers - * - * Obs. A single instance per peer - */ -func (peer *Peer) RoutineNonce() { - var keyPair *KeyPair - - defer peer.routines.stopping.Done() - - device := peer.device - logDebug := device.log.Debug - logDebug.Println("Routine, nonce worker, started for peer", peer.String()) - - peer.routines.starting.Done() - - for { - NextPacket: - select { - case <-peer.routines.stop.Wait(): - return - - case elem := <-peer.queue.nonce: - - // wait for key pair - - for { - keyPair = peer.keyPairs.Current() - if keyPair != nil && keyPair.sendNonce < RejectAfterMessages { - if time.Now().Sub(keyPair.created) < RejectAfterTime { - break - } - } - - peer.signal.handshakeBegin.Send() - - logDebug.Println("Awaiting key-pair for", peer.String()) - - select { - case <-peer.signal.newKeyPair.Wait(): - case <-peer.signal.flushNonceQueue.Wait(): - logDebug.Println("Clearing queue for", peer.String()) - peer.FlushNonceQueue() - goto NextPacket - case <-peer.routines.stop.Wait(): - return - } - } - - // populate work element - - elem.peer = peer - elem.nonce = atomic.AddUint64(&keyPair.sendNonce, 1) - 1 - elem.keyPair = keyPair - elem.dropped = AtomicFalse - elem.mutex.Lock() - - // add to parallel and sequential queue - - addToEncryptionQueue(device.queue.encryption, elem) - addToOutboundQueue(peer.queue.outbound, elem) - } - } -} - -/* Encrypts the elements in the queue - * and marks them for sequential consumption (by releasing the mutex) - * - * Obs. One instance per core - */ -func (device *Device) RoutineEncryption() { - - var nonce [chacha20poly1305.NonceSize]byte - - logDebug := device.log.Debug - logDebug.Println("Routine, encryption worker, started") - - for { - - // fetch next element - - select { - case <-device.signal.stop.Wait(): - logDebug.Println("Routine, encryption worker, stopped") - return - - case elem := <-device.queue.encryption: - - // check if dropped - - if elem.IsDropped() { - continue - } - - // populate header fields - - header := elem.buffer[:MessageTransportHeaderSize] - - fieldType := header[0:4] - fieldReceiver := header[4:8] - fieldNonce := header[8:16] - - binary.LittleEndian.PutUint32(fieldType, MessageTransportType) - binary.LittleEndian.PutUint32(fieldReceiver, elem.keyPair.remoteIndex) - binary.LittleEndian.PutUint64(fieldNonce, elem.nonce) - - // pad content to multiple of 16 - - mtu := int(atomic.LoadInt32(&device.tun.mtu)) - rem := len(elem.packet) % PaddingMultiple - if rem > 0 { - for i := 0; i < PaddingMultiple-rem && len(elem.packet) < mtu; i++ { - elem.packet = append(elem.packet, 0) - } - } - - // encrypt content and release to consumer - - binary.LittleEndian.PutUint64(nonce[4:], elem.nonce) - elem.packet = elem.keyPair.send.Seal( - header, - nonce[:], - elem.packet, - nil, - ) - elem.mutex.Unlock() - } - } -} - -/* Sequentially reads packets from queue and sends to endpoint - * - * Obs. Single instance per peer. - * The routine terminates then the outbound queue is closed. - */ -func (peer *Peer) RoutineSequentialSender() { - - defer peer.routines.stopping.Done() - - device := peer.device - - logDebug := device.log.Debug - logDebug.Println("Routine, sequential sender, started for", peer.String()) - - peer.routines.starting.Done() - - for { - select { - - case <-peer.routines.stop.Wait(): - logDebug.Println( - "Routine, sequential sender, stopped for", peer.String()) - return - - case elem := <-peer.queue.outbound: - elem.mutex.Lock() - if elem.IsDropped() { - continue - } - - // send message and return buffer to pool - - length := uint64(len(elem.packet)) - err := peer.SendBuffer(elem.packet) - device.PutMessageBuffer(elem.buffer) - if err != nil { - logDebug.Println("Failed to send authenticated packet to peer", peer.String()) - continue - } - atomic.AddUint64(&peer.stats.txBytes, length) - - // update timers - - peer.TimerAnyAuthenticatedPacketTraversal() - if len(elem.packet) != MessageKeepaliveSize { - peer.TimerDataSent() - } - peer.KeepKeyFreshSending() - } - } -} diff --git a/src/signal.go b/src/signal.go deleted file mode 100644 index 2cefad4..0000000 --- a/src/signal.go +++ /dev/null @@ -1,53 +0,0 @@ -package main - -type Signal struct { - enabled AtomicBool - C chan struct{} -} - -func NewSignal() (s Signal) { - s.C = make(chan struct{}, 1) - s.Enable() - return -} - -func (s *Signal) Disable() { - s.enabled.Set(false) - s.Clear() -} - -func (s *Signal) Enable() { - s.enabled.Set(true) -} - -/* Unblock exactly one listener - */ -func (s *Signal) Send() { - if s.enabled.Get() { - select { - case s.C <- struct{}{}: - default: - } - } -} - -/* Clear the signal if already fired - */ -func (s Signal) Clear() { - select { - case <-s.C: - default: - } -} - -/* Unblocks all listeners (forever) - */ -func (s Signal) Broadcast() { - close(s.C) -} - -/* Wait for the signal - */ -func (s Signal) Wait() chan struct{} { - return s.C -} diff --git a/src/tai64.go b/src/tai64.go deleted file mode 100644 index 2299a37..0000000 --- a/src/tai64.go +++ /dev/null @@ -1,28 +0,0 @@ -package main - -import ( - "bytes" - "encoding/binary" - "time" -) - -const ( - TAI64NBase = uint64(4611686018427387914) - TAI64NSize = 12 -) - -type TAI64N [TAI64NSize]byte - -func Timestamp() TAI64N { - var tai64n TAI64N - now := time.Now() - secs := TAI64NBase + uint64(now.Unix()) - nano := uint32(now.UnixNano()) - binary.BigEndian.PutUint64(tai64n[:], secs) - binary.BigEndian.PutUint32(tai64n[8:], nano) - return tai64n -} - -func (t1 *TAI64N) After(t2 TAI64N) bool { - return bytes.Compare(t1[:], t2[:]) > 0 -} diff --git a/src/tests/netns.sh b/src/tests/netns.sh deleted file mode 100755 index 02d428b..0000000 --- a/src/tests/netns.sh +++ /dev/null @@ -1,425 +0,0 @@ -#!/bin/bash - -# Copyright (C) 2015-2017 Jason A. Donenfeld <Jason@zx2c4.com>. All Rights Reserved. - -# This script tests the below topology: -# -# ┌─────────────────────┐ ┌──────────────────────────────────┐ ┌─────────────────────┐ -# │ $ns1 namespace │ │ $ns0 namespace │ │ $ns2 namespace │ -# │ │ │ │ │ │ -# │┌────────┐ │ │ ┌────────┐ │ │ ┌────────┐│ -# ││ wg1 │───────────┼───┼────────────│ lo │────────────┼───┼───────────│ wg2 ││ -# │├────────┴──────────┐│ │ ┌───────┴────────┴────────┐ │ │┌──────────┴────────┤│ -# ││192.168.241.1/24 ││ │ │(ns1) (ns2) │ │ ││192.168.241.2/24 ││ -# ││fd00::1/24 ││ │ │127.0.0.1:1 127.0.0.1:2│ │ ││fd00::2/24 ││ -# │└───────────────────┘│ │ │[::]:1 [::]:2 │ │ │└───────────────────┘│ -# └─────────────────────┘ │ └─────────────────────────┘ │ └─────────────────────┘ -# └──────────────────────────────────┘ -# -# After the topology is prepared we run a series of TCP/UDP iperf3 tests between the -# wireguard peers in $ns1 and $ns2. Note that $ns0 is the endpoint for the wg1 -# interfaces in $ns1 and $ns2. See https://www.wireguard.com/netns/ for further -# details on how this is accomplished. - -# This code is ported to the WireGuard-Go directly from the kernel project. -# -# Please ensure that you have installed the newest version of the WireGuard -# tools from the WireGuard project and before running these tests as: -# -# ./netns.sh <path to wireguard-go> - -set -e - -exec 3>&1 -export WG_HIDE_KEYS=never -netns0="wg-test-$$-0" -netns1="wg-test-$$-1" -netns2="wg-test-$$-2" -program=$1 -export LOG_LEVEL="info" - -pretty() { echo -e "\x1b[32m\x1b[1m[+] ${1:+NS$1: }${2}\x1b[0m" >&3; } -pp() { pretty "" "$*"; "$@"; } -maybe_exec() { if [[ $BASHPID -eq $$ ]]; then "$@"; else exec "$@"; fi; } -n0() { pretty 0 "$*"; maybe_exec ip netns exec $netns0 "$@"; } -n1() { pretty 1 "$*"; maybe_exec ip netns exec $netns1 "$@"; } -n2() { pretty 2 "$*"; maybe_exec ip netns exec $netns2 "$@"; } -ip0() { pretty 0 "ip $*"; ip -n $netns0 "$@"; } -ip1() { pretty 1 "ip $*"; ip -n $netns1 "$@"; } -ip2() { pretty 2 "ip $*"; ip -n $netns2 "$@"; } -sleep() { read -t "$1" -N 0 || true; } -waitiperf() { pretty "${1//*-}" "wait for iperf:5201"; while [[ $(ss -N "$1" -tlp 'sport = 5201') != *iperf3* ]]; do sleep 0.1; done; } -waitncatudp() { pretty "${1//*-}" "wait for udp:1111"; while [[ $(ss -N "$1" -ulp 'sport = 1111') != *ncat* ]]; do sleep 0.1; done; } -waitiface() { pretty "${1//*-}" "wait for $2 to come up"; ip netns exec "$1" bash -c "while [[ \$(< \"/sys/class/net/$2/operstate\") != up ]]; do read -t .1 -N 0 || true; done;"; } - -cleanup() { - set +e - exec 2>/dev/null - printf "$orig_message_cost" > /proc/sys/net/core/message_cost - ip0 link del dev wg1 - ip1 link del dev wg1 - ip2 link del dev wg1 - local to_kill="$(ip netns pids $netns0) $(ip netns pids $netns1) $(ip netns pids $netns2)" - [[ -n $to_kill ]] && kill $to_kill - pp ip netns del $netns1 - pp ip netns del $netns2 - pp ip netns del $netns0 - exit -} - -orig_message_cost="$(< /proc/sys/net/core/message_cost)" -trap cleanup EXIT -printf 0 > /proc/sys/net/core/message_cost - -ip netns del $netns0 2>/dev/null || true -ip netns del $netns1 2>/dev/null || true -ip netns del $netns2 2>/dev/null || true -pp ip netns add $netns0 -pp ip netns add $netns1 -pp ip netns add $netns2 -ip0 link set up dev lo - -# ip0 link add dev wg1 type wireguard -n0 $program wg1 -ip0 link set wg1 netns $netns1 - -# ip0 link add dev wg1 type wireguard -n0 $program wg2 -ip0 link set wg2 netns $netns2 - -key1="$(pp wg genkey)" -key2="$(pp wg genkey)" -pub1="$(pp wg pubkey <<<"$key1")" -pub2="$(pp wg pubkey <<<"$key2")" -psk="$(pp wg genpsk)" -[[ -n $key1 && -n $key2 && -n $psk ]] - -configure_peers() { - - ip1 addr add 192.168.241.1/24 dev wg1 - ip1 addr add fd00::1/24 dev wg1 - - ip2 addr add 192.168.241.2/24 dev wg2 - ip2 addr add fd00::2/24 dev wg2 - - n0 wg set wg1 \ - private-key <(echo "$key1") \ - listen-port 10000 \ - peer "$pub2" \ - preshared-key <(echo "$psk") \ - allowed-ips 192.168.241.2/32,fd00::2/128 - n0 wg set wg2 \ - private-key <(echo "$key2") \ - listen-port 20000 \ - peer "$pub1" \ - preshared-key <(echo "$psk") \ - allowed-ips 192.168.241.1/32,fd00::1/128 - - n0 wg showconf wg1 - n0 wg showconf wg2 - - ip1 link set up dev wg1 - ip2 link set up dev wg2 - sleep 1 -} -configure_peers - -tests() { - # Ping over IPv4 - n2 ping -c 10 -f -W 1 192.168.241.1 - n1 ping -c 10 -f -W 1 192.168.241.2 - - # Ping over IPv6 - n2 ping6 -c 10 -f -W 1 fd00::1 - n1 ping6 -c 10 -f -W 1 fd00::2 - - # TCP over IPv4 - n2 iperf3 -s -1 -B 192.168.241.2 & - waitiperf $netns2 - n1 iperf3 -Z -n 1G -c 192.168.241.2 - - # TCP over IPv6 - n1 iperf3 -s -1 -B fd00::1 & - waitiperf $netns1 - n2 iperf3 -Z -n 1G -c fd00::1 - - # UDP over IPv4 - n1 iperf3 -s -1 -B 192.168.241.1 & - waitiperf $netns1 - n2 iperf3 -Z -n 1G -b 0 -u -c 192.168.241.1 - - # UDP over IPv6 - n2 iperf3 -s -1 -B fd00::2 & - waitiperf $netns2 - n1 iperf3 -Z -n 1G -b 0 -u -c fd00::2 -} - -[[ $(ip1 link show dev wg1) =~ mtu\ ([0-9]+) ]] && orig_mtu="${BASH_REMATCH[1]}" -big_mtu=$(( 34816 - 1500 + $orig_mtu )) - -# Test using IPv4 as outer transport -n0 wg set wg1 peer "$pub2" endpoint 127.0.0.1:20000 -n0 wg set wg2 peer "$pub1" endpoint 127.0.0.1:10000 - -# Before calling tests, we first make sure that the stats counters are working -n2 ping -c 10 -f -W 1 192.168.241.1 -{ read _; read _; read _; read rx_bytes _; read _; read tx_bytes _; } < <(ip2 -stats link show dev wg2) -ip2 -stats link show dev wg2 -n0 wg show -[[ $rx_bytes -ge 840 && $tx_bytes -ge 880 && $rx_bytes -lt 2500 && $rx_bytes -lt 2500 ]] -echo "counters working" -tests -ip1 link set wg1 mtu $big_mtu -ip2 link set wg2 mtu $big_mtu -tests - -ip1 link set wg1 mtu $orig_mtu -ip2 link set wg2 mtu $orig_mtu - -# Test using IPv6 as outer transport -n0 wg set wg1 peer "$pub2" endpoint [::1]:20000 -n0 wg set wg2 peer "$pub1" endpoint [::1]:10000 -tests -ip1 link set wg1 mtu $big_mtu -ip2 link set wg2 mtu $big_mtu -tests - -ip1 link set wg1 mtu $orig_mtu -ip2 link set wg2 mtu $orig_mtu - -# Test using IPv4 that roaming works -ip0 -4 addr del 127.0.0.1/8 dev lo -ip0 -4 addr add 127.212.121.99/8 dev lo -n0 wg set wg1 listen-port 9999 -n0 wg set wg1 peer "$pub2" endpoint 127.0.0.1:20000 -n1 ping6 -W 1 -c 1 fd00::2 -[[ $(n2 wg show wg2 endpoints) == "$pub1 127.212.121.99:9999" ]] - -# Test using IPv6 that roaming works -n1 wg set wg1 listen-port 9998 -n1 wg set wg1 peer "$pub2" endpoint [::1]:20000 -n1 ping -W 1 -c 1 192.168.241.2 -[[ $(n2 wg show wg2 endpoints) == "$pub1 [::1]:9998" ]] - -# Test that crypto-RP filter works -n1 wg set wg1 peer "$pub2" allowed-ips 192.168.241.0/24 -exec 4< <(n1 ncat -l -u -p 1111) -nmap_pid=$! -waitncatudp $netns1 -n2 ncat -u 192.168.241.1 1111 <<<"X" -read -r -N 1 -t 1 out <&4 && [[ $out == "X" ]] -kill $nmap_pid -more_specific_key="$(pp wg genkey | pp wg pubkey)" -n0 wg set wg1 peer "$more_specific_key" allowed-ips 192.168.241.2/32 -n0 wg set wg2 listen-port 9997 -exec 4< <(n1 ncat -l -u -p 1111) -nmap_pid=$! -waitncatudp $netns1 -n2 ncat -u 192.168.241.1 1111 <<<"X" -! read -r -N 1 -t 1 out <&4 -kill $nmap_pid -n0 wg set wg1 peer "$more_specific_key" remove -[[ $(n1 wg show wg1 endpoints) == "$pub2 [::1]:9997" ]] - -ip1 link del wg1 -ip2 link del wg2 - -# Test using NAT. We now change the topology to this: -# ┌────────────────────────────────────────┐ ┌────────────────────────────────────────────────┐ ┌────────────────────────────────────────┐ -# │ $ns1 namespace │ │ $ns0 namespace │ │ $ns2 namespace │ -# │ │ │ │ │ │ -# │ ┌─────┐ ┌─────┐ │ │ ┌──────┐ ┌──────┐ │ │ ┌─────┐ ┌─────┐ │ -# │ │ wg1 │─────────────│vethc│───────────┼────┼────│vethrc│ │vethrs│──────────────┼─────┼──│veths│────────────│ wg2 │ │ -# │ ├─────┴──────────┐ ├─────┴──────────┐│ │ ├──────┴─────────┐ ├──────┴────────────┐ │ │ ├─────┴──────────┐ ├─────┴──────────┐ │ -# │ │192.168.241.1/24│ │192.168.1.100/24││ │ │192.168.1.100/24│ │10.0.0.1/24 │ │ │ │10.0.0.100/24 │ │192.168.241.2/24│ │ -# │ │fd00::1/24 │ │ ││ │ │ │ │SNAT:192.168.1.0/24│ │ │ │ │ │fd00::2/24 │ │ -# │ └────────────────┘ └────────────────┘│ │ └────────────────┘ └───────────────────┘ │ │ └────────────────┘ └────────────────┘ │ -# └────────────────────────────────────────┘ └────────────────────────────────────────────────┘ └────────────────────────────────────────┘ - -# ip1 link add dev wg1 type wireguard -# ip2 link add dev wg1 type wireguard - -n1 $program wg1 -n2 $program wg2 - -configure_peers - -ip0 link add vethrc type veth peer name vethc -ip0 link add vethrs type veth peer name veths -ip0 link set vethc netns $netns1 -ip0 link set veths netns $netns2 -ip0 link set vethrc up -ip0 link set vethrs up -ip0 addr add 192.168.1.1/24 dev vethrc -ip0 addr add 10.0.0.1/24 dev vethrs -ip1 addr add 192.168.1.100/24 dev vethc -ip1 link set vethc up -ip1 route add default via 192.168.1.1 -ip2 addr add 10.0.0.100/24 dev veths -ip2 link set veths up -waitiface $netns0 vethrc -waitiface $netns0 vethrs -waitiface $netns1 vethc -waitiface $netns2 veths - -n0 bash -c 'printf 1 > /proc/sys/net/ipv4/ip_forward' -n0 bash -c 'printf 2 > /proc/sys/net/netfilter/nf_conntrack_udp_timeout' -n0 bash -c 'printf 2 > /proc/sys/net/netfilter/nf_conntrack_udp_timeout_stream' -n0 iptables -t nat -A POSTROUTING -s 192.168.1.0/24 -d 10.0.0.0/24 -j SNAT --to 10.0.0.1 - -n0 wg set wg1 peer "$pub2" endpoint 10.0.0.100:20000 persistent-keepalive 1 -n1 ping -W 1 -c 1 192.168.241.2 -n2 ping -W 1 -c 1 192.168.241.1 -[[ $(n2 wg show wg2 endpoints) == "$pub1 10.0.0.1:10000" ]] -# Demonstrate n2 can still send packets to n1, since persistent-keepalive will prevent connection tracking entry from expiring (to see entries: `n0 conntrack -L`). -pp sleep 3 -n2 ping -W 1 -c 1 192.168.241.1 - -n0 iptables -t nat -F -ip0 link del vethrc -ip0 link del vethrs -ip1 link del wg1 -ip2 link del wg2 - -# Test that saddr routing is sticky but not too sticky, changing to this topology: -# ┌────────────────────────────────────────┐ ┌────────────────────────────────────────┐ -# │ $ns1 namespace │ │ $ns2 namespace │ -# │ │ │ │ -# │ ┌─────┐ ┌─────┐ │ │ ┌─────┐ ┌─────┐ │ -# │ │ wg1 │─────────────│veth1│───────────┼────┼──│veth2│────────────│ wg2 │ │ -# │ ├─────┴──────────┐ ├─────┴──────────┐│ │ ├─────┴──────────┐ ├─────┴──────────┐ │ -# │ │192.168.241.1/24│ │10.0.0.1/24 ││ │ │10.0.0.2/24 │ │192.168.241.2/24│ │ -# │ │fd00::1/24 │ │fd00:aa::1/96 ││ │ │fd00:aa::2/96 │ │fd00::2/24 │ │ -# │ └────────────────┘ └────────────────┘│ │ └────────────────┘ └────────────────┘ │ -# └────────────────────────────────────────┘ └────────────────────────────────────────┘ - -# ip1 link add dev wg1 type wireguard -# ip2 link add dev wg1 type wireguard -n1 $program wg1 -n2 $program wg2 - -configure_peers - -ip1 link add veth1 type veth peer name veth2 -ip1 link set veth2 netns $netns2 -n1 bash -c 'printf 0 > /proc/sys/net/ipv6/conf/veth1/accept_dad' -n2 bash -c 'printf 0 > /proc/sys/net/ipv6/conf/veth2/accept_dad' -n1 bash -c 'printf 1 > /proc/sys/net/ipv4/conf/veth1/promote_secondaries' - -# First we check that we aren't overly sticky and can fall over to new IPs when old ones are removed -ip1 addr add 10.0.0.1/24 dev veth1 -ip1 addr add fd00:aa::1/96 dev veth1 -ip2 addr add 10.0.0.2/24 dev veth2 -ip2 addr add fd00:aa::2/96 dev veth2 -ip1 link set veth1 up -ip2 link set veth2 up -waitiface $netns1 veth1 -waitiface $netns2 veth2 -n0 wg set wg1 peer "$pub2" endpoint 10.0.0.2:20000 -n1 ping -W 1 -c 1 192.168.241.2 -ip1 addr add 10.0.0.10/24 dev veth1 -ip1 addr del 10.0.0.1/24 dev veth1 -n1 ping -W 1 -c 1 192.168.241.2 -n0 wg set wg1 peer "$pub2" endpoint [fd00:aa::2]:20000 -n1 ping -W 1 -c 1 192.168.241.2 -ip1 addr add fd00:aa::10/96 dev veth1 -ip1 addr del fd00:aa::1/96 dev veth1 -n1 ping -W 1 -c 1 192.168.241.2 - -# Now we show that we can successfully do reply to sender routing -ip1 link set veth1 down -ip2 link set veth2 down -ip1 addr flush dev veth1 -ip2 addr flush dev veth2 -ip1 addr add 10.0.0.1/24 dev veth1 -ip1 addr add 10.0.0.2/24 dev veth1 -ip1 addr add fd00:aa::1/96 dev veth1 -ip1 addr add fd00:aa::2/96 dev veth1 -ip2 addr add 10.0.0.3/24 dev veth2 -ip2 addr add fd00:aa::3/96 dev veth2 -ip1 link set veth1 up -ip2 link set veth2 up -waitiface $netns1 veth1 -waitiface $netns2 veth2 -n0 wg set wg2 peer "$pub1" endpoint 10.0.0.1:10000 -n2 ping -W 1 -c 1 192.168.241.1 -[[ $(n0 wg show wg2 endpoints) == "$pub1 10.0.0.1:10000" ]] -n0 wg set wg2 peer "$pub1" endpoint [fd00:aa::1]:10000 -n2 ping -W 1 -c 1 192.168.241.1 -[[ $(n0 wg show wg2 endpoints) == "$pub1 [fd00:aa::1]:10000" ]] -n0 wg set wg2 peer "$pub1" endpoint 10.0.0.2:10000 -n2 ping -W 1 -c 1 192.168.241.1 -[[ $(n0 wg show wg2 endpoints) == "$pub1 10.0.0.2:10000" ]] -n0 wg set wg2 peer "$pub1" endpoint [fd00:aa::2]:10000 -n2 ping -W 1 -c 1 192.168.241.1 -[[ $(n0 wg show wg2 endpoints) == "$pub1 [fd00:aa::2]:10000" ]] - -ip1 link del veth1 -ip1 link del wg1 -ip2 link del wg2 - -# Test that Netlink/IPC is working properly by doing things that usually cause split responses - -n0 $program wg0 -sleep 5 -config=( "[Interface]" "PrivateKey=$(wg genkey)" "[Peer]" "PublicKey=$(wg genkey)" ) -for a in {1..255}; do - for b in {0..255}; do - config+=( "AllowedIPs=$a.$b.0.0/16,$a::$b/128" ) - done -done -n0 wg setconf wg0 <(printf '%s\n' "${config[@]}") -i=0 -for ip in $(n0 wg show wg0 allowed-ips); do - ((++i)) -done -((i == 255*256*2+1)) -ip0 link del wg0 - -n0 $program wg0 -config=( "[Interface]" "PrivateKey=$(wg genkey)" ) -for a in {1..40}; do - config+=( "[Peer]" "PublicKey=$(wg genkey)" ) - for b in {1..52}; do - config+=( "AllowedIPs=$a.$b.0.0/16" ) - done -done -n0 wg setconf wg0 <(printf '%s\n' "${config[@]}") -i=0 -while read -r line; do - j=0 - for ip in $line; do - ((++j)) - done - ((j == 53)) - ((++i)) -done < <(n0 wg show wg0 allowed-ips) -((i == 40)) -ip0 link del wg0 - -n0 $program wg0 -config=( ) -for i in {1..29}; do - config+=( "[Peer]" "PublicKey=$(wg genkey)" ) -done -config+=( "[Peer]" "PublicKey=$(wg genkey)" "AllowedIPs=255.2.3.4/32,abcd::255/128" ) -n0 wg setconf wg0 <(printf '%s\n' "${config[@]}") -n0 wg showconf wg0 > /dev/null -ip0 link del wg0 - -! n0 wg show doesnotexist || false - -declare -A objects -while read -t 0.1 -r line 2>/dev/null || [[ $? -ne 142 ]]; do - [[ $line =~ .*(wg[0-9]+:\ [A-Z][a-z]+\ [0-9]+)\ .*(created|destroyed).* ]] || continue - objects["${BASH_REMATCH[1]}"]+="${BASH_REMATCH[2]}" -done < /dev/kmsg -alldeleted=1 -for object in "${!objects[@]}"; do - if [[ ${objects["$object"]} != *createddestroyed ]]; then - echo "Error: $object: merely ${objects["$object"]}" >&3 - alldeleted=0 - fi -done -[[ $alldeleted -eq 1 ]] -pretty "" "Objects that were created were also destroyed." diff --git a/src/timer.go b/src/timer.go deleted file mode 100644 index f00ca49..0000000 --- a/src/timer.go +++ /dev/null @@ -1,59 +0,0 @@ -package main - -import ( - "time" -) - -type Timer struct { - pending AtomicBool - timer *time.Timer -} - -/* Starts the timer if not already pending - */ -func (t *Timer) Start(dur time.Duration) bool { - set := t.pending.Swap(true) - if !set { - t.timer.Reset(dur) - return true - } - return false -} - -/* Stops the timer - */ -func (t *Timer) Stop() { - set := t.pending.Swap(true) - if set { - t.timer.Stop() - select { - case <-t.timer.C: - default: - } - } - t.pending.Set(false) -} - -func (t *Timer) Pending() bool { - return t.pending.Get() -} - -func (t *Timer) Reset(dur time.Duration) { - t.pending.Set(false) - t.Start(dur) -} - -func (t *Timer) Wait() <-chan time.Time { - return t.timer.C -} - -func NewTimer() (t Timer) { - t.pending.Set(false) - t.timer = time.NewTimer(0) - t.timer.Stop() - select { - case <-t.timer.C: - default: - } - return -} diff --git a/src/timers.go b/src/timers.go deleted file mode 100644 index 7092688..0000000 --- a/src/timers.go +++ /dev/null @@ -1,346 +0,0 @@ -package main
-
-import (
- "bytes"
- "encoding/binary"
- "math/rand"
- "sync/atomic"
- "time"
-)
-
-/* NOTE:
- * Notion of validity
- *
- *
- */
-
-/* Called when a new authenticated message has been send
- *
- */
-func (peer *Peer) KeepKeyFreshSending() {
- kp := peer.keyPairs.Current()
- if kp == nil {
- return
- }
- nonce := atomic.LoadUint64(&kp.sendNonce)
- if nonce > RekeyAfterMessages {
- peer.signal.handshakeBegin.Send()
- }
- if kp.isInitiator && time.Now().Sub(kp.created) > RekeyAfterTime {
- peer.signal.handshakeBegin.Send()
- }
-}
-
-/* Called when a new authenticated message has been received
- *
- * NOTE: Not thread safe, but called by sequential receiver!
- */
-func (peer *Peer) KeepKeyFreshReceiving() {
- if peer.timer.sendLastMinuteHandshake {
- return
- }
- kp := peer.keyPairs.Current()
- if kp == nil {
- return
- }
- if !kp.isInitiator {
- return
- }
- nonce := atomic.LoadUint64(&kp.sendNonce)
- send := nonce > RekeyAfterMessages || time.Now().Sub(kp.created) > RekeyAfterTimeReceiving
- if send {
- // do a last minute attempt at initiating a new handshake
- peer.timer.sendLastMinuteHandshake = true
- peer.signal.handshakeBegin.Send()
- }
-}
-
-/* Queues a keep-alive if no packets are queued for peer
- */
-func (peer *Peer) SendKeepAlive() bool {
- if len(peer.queue.nonce) != 0 {
- return false
- }
- elem := peer.device.NewOutboundElement()
- elem.packet = nil
- select {
- case peer.queue.nonce <- elem:
- return true
- default:
- return false
- }
-}
-
-/* Event:
- * Sent non-empty (authenticated) transport message
- */
-func (peer *Peer) TimerDataSent() {
- peer.timer.keepalivePassive.Stop()
- peer.timer.handshakeNew.Start(NewHandshakeTime)
-}
-
-/* Event:
- * Received non-empty (authenticated) transport message
- *
- * Action:
- * Set a timer to confirm the message using a keep-alive (if not already set)
- */
-func (peer *Peer) TimerDataReceived() {
- if !peer.timer.keepalivePassive.Start(KeepaliveTimeout) {
- peer.timer.needAnotherKeepalive = true
- }
-}
-
-/* Event:
- * Any (authenticated) packet received
- */
-func (peer *Peer) TimerAnyAuthenticatedPacketReceived() {
- peer.timer.handshakeNew.Stop()
-}
-
-/* Event:
- * Any authenticated packet send / received.
- *
- * Action:
- * Push persistent keep-alive into the future
- */
-func (peer *Peer) TimerAnyAuthenticatedPacketTraversal() {
- interval := atomic.LoadUint64(&peer.persistentKeepaliveInterval)
- if interval > 0 {
- duration := time.Duration(interval) * time.Second
- peer.timer.keepalivePersistent.Reset(duration)
- }
-}
-
-/* Called after successfully completing a handshake.
- * i.e. after:
- *
- * - Valid handshake response
- * - First transport message under the "next" key
- */
-func (peer *Peer) TimerHandshakeComplete() {
- peer.signal.handshakeCompleted.Send()
- peer.device.log.Info.Println("Negotiated new handshake for", peer.String())
-}
-
-/* Event:
- * An ephemeral key is generated
- *
- * i.e. after:
- *
- * CreateMessageInitiation
- * CreateMessageResponse
- *
- * Action:
- * Schedule the deletion of all key material
- * upon failure to complete a handshake
- */
-func (peer *Peer) TimerEphemeralKeyCreated() {
- peer.timer.zeroAllKeys.Reset(RejectAfterTime * 3)
-}
-
-/* Sends a new handshake initiation message to the peer (endpoint)
- */
-func (peer *Peer) sendNewHandshake() error {
-
- // temporarily disable the handshake complete signal
-
- peer.signal.handshakeCompleted.Disable()
-
- // create initiation message
-
- msg, err := peer.device.CreateMessageInitiation(peer)
- if err != nil {
- return err
- }
-
- // marshal handshake message
-
- var buff [MessageInitiationSize]byte
- writer := bytes.NewBuffer(buff[:0])
- binary.Write(writer, binary.LittleEndian, msg)
- packet := writer.Bytes()
- peer.mac.AddMacs(packet)
-
- // send to endpoint
-
- peer.TimerAnyAuthenticatedPacketTraversal()
-
- err = peer.SendBuffer(packet)
- if err == nil {
- peer.signal.handshakeCompleted.Enable()
- }
-
- // set timeout
-
- jitter := time.Millisecond * time.Duration(rand.Uint32()%334)
-
- peer.timer.keepalivePassive.Stop()
- peer.timer.handshakeTimeout.Reset(RekeyTimeout + jitter)
-
- return err
-}
-
-func (peer *Peer) RoutineTimerHandler() {
-
- defer peer.routines.stopping.Done()
-
- device := peer.device
-
- logInfo := device.log.Info
- logDebug := device.log.Debug
- logDebug.Println("Routine, timer handler, started for peer", peer.String())
-
- // reset all timers
-
- peer.timer.keepalivePassive.Stop()
- peer.timer.handshakeDeadline.Stop()
- peer.timer.handshakeTimeout.Stop()
- peer.timer.handshakeNew.Stop()
- peer.timer.zeroAllKeys.Stop()
-
- interval := atomic.LoadUint64(&peer.persistentKeepaliveInterval)
- if interval > 0 {
- duration := time.Duration(interval) * time.Second
- peer.timer.keepalivePersistent.Reset(duration)
- }
-
- // signal synchronised setup complete
-
- peer.routines.starting.Done()
-
- // handle timer events
-
- for {
- select {
-
- /* stopping */
-
- case <-peer.routines.stop.Wait():
- return
-
- /* timers */
-
- // keep-alive
-
- case <-peer.timer.keepalivePersistent.Wait():
-
- interval := atomic.LoadUint64(&peer.persistentKeepaliveInterval)
- if interval > 0 {
- logDebug.Println("Sending keep-alive to", peer.String())
- peer.timer.keepalivePassive.Stop()
- peer.SendKeepAlive()
- }
-
- case <-peer.timer.keepalivePassive.Wait():
-
- logDebug.Println("Sending keep-alive to", peer.String())
-
- peer.SendKeepAlive()
-
- if peer.timer.needAnotherKeepalive {
- peer.timer.needAnotherKeepalive = false
- peer.timer.keepalivePassive.Reset(KeepaliveTimeout)
- }
-
- // clear key material timer
-
- case <-peer.timer.zeroAllKeys.Wait():
-
- logDebug.Println("Clearing all key material for", peer.String())
-
- hs := &peer.handshake
- hs.mutex.Lock()
-
- kp := &peer.keyPairs
- kp.mutex.Lock()
-
- // remove key-pairs
-
- if kp.previous != nil {
- device.DeleteKeyPair(kp.previous)
- kp.previous = nil
- }
- if kp.current != nil {
- device.DeleteKeyPair(kp.current)
- kp.current = nil
- }
- if kp.next != nil {
- device.DeleteKeyPair(kp.next)
- kp.next = nil
- }
- kp.mutex.Unlock()
-
- // zero out handshake
-
- device.indices.Delete(hs.localIndex)
- hs.Clear()
- hs.mutex.Unlock()
-
- // handshake timers
-
- case <-peer.timer.handshakeNew.Wait():
- logInfo.Println("Retrying handshake with", peer.String())
- peer.signal.handshakeBegin.Send()
-
- case <-peer.timer.handshakeTimeout.Wait():
-
- // clear source (in case this is causing problems)
-
- peer.mutex.Lock()
- if peer.endpoint != nil {
- peer.endpoint.ClearSrc()
- }
- peer.mutex.Unlock()
-
- // send new handshake
-
- err := peer.sendNewHandshake()
- if err != nil {
- logInfo.Println(
- "Failed to send handshake to peer:", peer.String(), "(", err, ")")
- }
-
- case <-peer.timer.handshakeDeadline.Wait():
-
- // clear all queued packets and stop keep-alive
-
- logInfo.Println(
- "Handshake negotiation timed out for:", peer.String())
-
- peer.signal.flushNonceQueue.Send()
- peer.timer.keepalivePersistent.Stop()
- peer.signal.handshakeBegin.Enable()
-
- /* signals */
-
- case <-peer.signal.handshakeBegin.Wait():
-
- peer.signal.handshakeBegin.Disable()
-
- err := peer.sendNewHandshake()
- if err != nil {
- logInfo.Println(
- "Failed to send handshake to peer:", peer.String(), "(", err, ")")
- }
-
- peer.timer.handshakeDeadline.Reset(RekeyAttemptTime)
-
- case <-peer.signal.handshakeCompleted.Wait():
-
- logInfo.Println(
- "Handshake completed for:", peer.String())
-
- atomic.StoreInt64(
- &peer.stats.lastHandshakeNano,
- time.Now().UnixNano(),
- )
-
- peer.timer.handshakeTimeout.Stop()
- peer.timer.handshakeDeadline.Stop()
- peer.signal.handshakeBegin.Enable()
-
- peer.timer.sendLastMinuteHandshake = false
- }
- }
-}
diff --git a/src/trie.go b/src/trie.go deleted file mode 100644 index 405ffc3..0000000 --- a/src/trie.go +++ /dev/null @@ -1,228 +0,0 @@ -package main - -import ( - "errors" - "net" -) - -/* Binary trie - * - * The net.IPs used here are not formatted the - * same way as those created by the "net" functions. - * Here the IPs are slices of either 4 or 16 byte (not always 16) - * - * Synchronization done separately - * See: routing.go - */ - -type Trie struct { - cidr uint - child [2]*Trie - bits []byte - peer *Peer - - // index of "branching" bit - - bit_at_byte uint - bit_at_shift uint -} - -/* Finds length of matching prefix - * - * TODO: Only use during insertion (xor + prefix mask for lookup) - * Check out - * prefix_matches(struct allowedips_node *node, const u8 *key, u8 bits) - * https://git.zx2c4.com/WireGuard/commit/?h=jd/precomputed-prefix-match - * - * Assumption: - * len(ip1) == len(ip2) - * len(ip1) mod 4 = 0 - */ -func commonBits(ip1 []byte, ip2 []byte) uint { - var i uint - size := uint(len(ip1)) - - for i = 0; i < size; i++ { - v := ip1[i] ^ ip2[i] - if v != 0 { - v >>= 1 - if v == 0 { - return i*8 + 7 - } - - v >>= 1 - if v == 0 { - return i*8 + 6 - } - - v >>= 1 - if v == 0 { - return i*8 + 5 - } - - v >>= 1 - if v == 0 { - return i*8 + 4 - } - - v >>= 1 - if v == 0 { - return i*8 + 3 - } - - v >>= 1 - if v == 0 { - return i*8 + 2 - } - - v >>= 1 - if v == 0 { - return i*8 + 1 - } - return i * 8 - } - } - return i * 8 -} - -func (node *Trie) RemovePeer(p *Peer) *Trie { - if node == nil { - return node - } - - // walk recursively - - node.child[0] = node.child[0].RemovePeer(p) - node.child[1] = node.child[1].RemovePeer(p) - - if node.peer != p { - return node - } - - // remove peer & merge - - node.peer = nil - if node.child[0] == nil { - return node.child[1] - } - return node.child[0] -} - -func (node *Trie) choose(ip net.IP) byte { - return (ip[node.bit_at_byte] >> node.bit_at_shift) & 1 -} - -func (node *Trie) Insert(ip net.IP, cidr uint, peer *Peer) *Trie { - - // at leaf - - if node == nil { - return &Trie{ - bits: ip, - peer: peer, - cidr: cidr, - bit_at_byte: cidr / 8, - bit_at_shift: 7 - (cidr % 8), - } - } - - // traverse deeper - - common := commonBits(node.bits, ip) - if node.cidr <= cidr && common >= node.cidr { - if node.cidr == cidr { - node.peer = peer - return node - } - bit := node.choose(ip) - node.child[bit] = node.child[bit].Insert(ip, cidr, peer) - return node - } - - // split node - - newNode := &Trie{ - bits: ip, - peer: peer, - cidr: cidr, - bit_at_byte: cidr / 8, - bit_at_shift: 7 - (cidr % 8), - } - - cidr = min(cidr, common) - - // check for shorter prefix - - if newNode.cidr == cidr { - bit := newNode.choose(node.bits) - newNode.child[bit] = node - return newNode - } - - // create new parent for node & newNode - - parent := &Trie{ - bits: ip, - peer: nil, - cidr: cidr, - bit_at_byte: cidr / 8, - bit_at_shift: 7 - (cidr % 8), - } - - bit := parent.choose(ip) - parent.child[bit] = newNode - parent.child[bit^1] = node - - return parent -} - -func (node *Trie) Lookup(ip net.IP) *Peer { - var found *Peer - size := uint(len(ip)) - for node != nil && commonBits(node.bits, ip) >= node.cidr { - if node.peer != nil { - found = node.peer - } - if node.bit_at_byte == size { - break - } - bit := node.choose(ip) - node = node.child[bit] - } - return found -} - -func (node *Trie) Count() uint { - if node == nil { - return 0 - } - l := node.child[0].Count() - r := node.child[1].Count() - return l + r -} - -func (node *Trie) AllowedIPs(p *Peer, results []net.IPNet) []net.IPNet { - if node == nil { - return results - } - if node.peer == p { - var mask net.IPNet - mask.Mask = net.CIDRMask(int(node.cidr), len(node.bits)*8) - if len(node.bits) == net.IPv4len { - mask.IP = net.IPv4( - node.bits[0], - node.bits[1], - node.bits[2], - node.bits[3], - ) - } else if len(node.bits) == net.IPv6len { - mask.IP = node.bits - } else { - panic(errors.New("bug: unexpected address length")) - } - results = append(results, mask) - } - results = node.child[0].AllowedIPs(p, results) - results = node.child[1].AllowedIPs(p, results) - return results -} diff --git a/src/trie_rand_test.go b/src/trie_rand_test.go deleted file mode 100644 index 840d269..0000000 --- a/src/trie_rand_test.go +++ /dev/null @@ -1,126 +0,0 @@ -package main - -import ( - "math/rand" - "sort" - "testing" -) - -const ( - NumberOfPeers = 100 - NumberOfAddresses = 250 - NumberOfTests = 10000 -) - -type SlowNode struct { - peer *Peer - cidr uint - bits []byte -} - -type SlowRouter []*SlowNode - -func (r SlowRouter) Len() int { - return len(r) -} - -func (r SlowRouter) Less(i, j int) bool { - return r[i].cidr > r[j].cidr -} - -func (r SlowRouter) Swap(i, j int) { - r[i], r[j] = r[j], r[i] -} - -func (r SlowRouter) Insert(addr []byte, cidr uint, peer *Peer) SlowRouter { - for _, t := range r { - if t.cidr == cidr && commonBits(t.bits, addr) >= cidr { - t.peer = peer - t.bits = addr - return r - } - } - r = append(r, &SlowNode{ - cidr: cidr, - bits: addr, - peer: peer, - }) - sort.Sort(r) - return r -} - -func (r SlowRouter) Lookup(addr []byte) *Peer { - for _, t := range r { - common := commonBits(t.bits, addr) - if common >= t.cidr { - return t.peer - } - } - return nil -} - -func TestTrieRandomIPv4(t *testing.T) { - var trie *Trie - var slow SlowRouter - var peers []*Peer - - rand.Seed(1) - - const AddressLength = 4 - - for n := 0; n < NumberOfPeers; n += 1 { - peers = append(peers, &Peer{}) - } - - for n := 0; n < NumberOfAddresses; n += 1 { - var addr [AddressLength]byte - rand.Read(addr[:]) - cidr := uint(rand.Uint32() % (AddressLength * 8)) - index := rand.Int() % NumberOfPeers - trie = trie.Insert(addr[:], cidr, peers[index]) - slow = slow.Insert(addr[:], cidr, peers[index]) - } - - for n := 0; n < NumberOfTests; n += 1 { - var addr [AddressLength]byte - rand.Read(addr[:]) - peer1 := slow.Lookup(addr[:]) - peer2 := trie.Lookup(addr[:]) - if peer1 != peer2 { - t.Error("Trie did not match naive implementation, for:", addr) - } - } -} - -func TestTrieRandomIPv6(t *testing.T) { - var trie *Trie - var slow SlowRouter - var peers []*Peer - - rand.Seed(1) - - const AddressLength = 16 - - for n := 0; n < NumberOfPeers; n += 1 { - peers = append(peers, &Peer{}) - } - - for n := 0; n < NumberOfAddresses; n += 1 { - var addr [AddressLength]byte - rand.Read(addr[:]) - cidr := uint(rand.Uint32() % (AddressLength * 8)) - index := rand.Int() % NumberOfPeers - trie = trie.Insert(addr[:], cidr, peers[index]) - slow = slow.Insert(addr[:], cidr, peers[index]) - } - - for n := 0; n < NumberOfTests; n += 1 { - var addr [AddressLength]byte - rand.Read(addr[:]) - peer1 := slow.Lookup(addr[:]) - peer2 := trie.Lookup(addr[:]) - if peer1 != peer2 { - t.Error("Trie did not match naive implementation, for:", addr) - } - } -} diff --git a/src/trie_test.go b/src/trie_test.go deleted file mode 100644 index 9d53df3..0000000 --- a/src/trie_test.go +++ /dev/null @@ -1,255 +0,0 @@ -package main - -import ( - "math/rand" - "net" - "testing" -) - -/* Todo: More comprehensive - */ - -type testPairCommonBits struct { - s1 []byte - s2 []byte - match uint -} - -type testPairTrieInsert struct { - key []byte - cidr uint - peer *Peer -} - -type testPairTrieLookup struct { - key []byte - peer *Peer -} - -func printTrie(t *testing.T, p *Trie) { - if p == nil { - return - } - t.Log(p) - printTrie(t, p.child[0]) - printTrie(t, p.child[1]) -} - -func TestCommonBits(t *testing.T) { - - tests := []testPairCommonBits{ - {s1: []byte{1, 4, 53, 128}, s2: []byte{0, 0, 0, 0}, match: 7}, - {s1: []byte{0, 4, 53, 128}, s2: []byte{0, 0, 0, 0}, match: 13}, - {s1: []byte{0, 4, 53, 253}, s2: []byte{0, 4, 53, 252}, match: 31}, - {s1: []byte{192, 168, 1, 1}, s2: []byte{192, 169, 1, 1}, match: 15}, - {s1: []byte{65, 168, 1, 1}, s2: []byte{192, 169, 1, 1}, match: 0}, - } - - for _, p := range tests { - v := commonBits(p.s1, p.s2) - if v != p.match { - t.Error( - "For slice", p.s1, p.s2, - "expected match", p.match, - ",but got", v, - ) - } - } -} - -func benchmarkTrie(peerNumber int, addressNumber int, addressLength int, b *testing.B) { - var trie *Trie - var peers []*Peer - - rand.Seed(1) - - const AddressLength = 4 - - for n := 0; n < peerNumber; n += 1 { - peers = append(peers, &Peer{}) - } - - for n := 0; n < addressNumber; n += 1 { - var addr [AddressLength]byte - rand.Read(addr[:]) - cidr := uint(rand.Uint32() % (AddressLength * 8)) - index := rand.Int() % peerNumber - trie = trie.Insert(addr[:], cidr, peers[index]) - } - - for n := 0; n < b.N; n += 1 { - var addr [AddressLength]byte - rand.Read(addr[:]) - trie.Lookup(addr[:]) - } -} - -func BenchmarkTrieIPv4Peers100Addresses1000(b *testing.B) { - benchmarkTrie(100, 1000, net.IPv4len, b) -} - -func BenchmarkTrieIPv4Peers10Addresses10(b *testing.B) { - benchmarkTrie(10, 10, net.IPv4len, b) -} - -func BenchmarkTrieIPv6Peers100Addresses1000(b *testing.B) { - benchmarkTrie(100, 1000, net.IPv6len, b) -} - -func BenchmarkTrieIPv6Peers10Addresses10(b *testing.B) { - benchmarkTrie(10, 10, net.IPv6len, b) -} - -/* Test ported from kernel implementation: - * selftest/routingtable.h - */ -func TestTrieIPv4(t *testing.T) { - a := &Peer{} - b := &Peer{} - c := &Peer{} - d := &Peer{} - e := &Peer{} - g := &Peer{} - h := &Peer{} - - var trie *Trie - - insert := func(peer *Peer, a, b, c, d byte, cidr uint) { - trie = trie.Insert([]byte{a, b, c, d}, cidr, peer) - } - - assertEQ := func(peer *Peer, a, b, c, d byte) { - p := trie.Lookup([]byte{a, b, c, d}) - if p != peer { - t.Error("Assert EQ failed") - } - } - - assertNEQ := func(peer *Peer, a, b, c, d byte) { - p := trie.Lookup([]byte{a, b, c, d}) - if p == peer { - t.Error("Assert NEQ failed") - } - } - - insert(a, 192, 168, 4, 0, 24) - insert(b, 192, 168, 4, 4, 32) - insert(c, 192, 168, 0, 0, 16) - insert(d, 192, 95, 5, 64, 27) - insert(c, 192, 95, 5, 65, 27) - insert(e, 0, 0, 0, 0, 0) - insert(g, 64, 15, 112, 0, 20) - insert(h, 64, 15, 123, 211, 25) - insert(a, 10, 0, 0, 0, 25) - insert(b, 10, 0, 0, 128, 25) - insert(a, 10, 1, 0, 0, 30) - insert(b, 10, 1, 0, 4, 30) - insert(c, 10, 1, 0, 8, 29) - insert(d, 10, 1, 0, 16, 29) - - assertEQ(a, 192, 168, 4, 20) - assertEQ(a, 192, 168, 4, 0) - assertEQ(b, 192, 168, 4, 4) - assertEQ(c, 192, 168, 200, 182) - assertEQ(c, 192, 95, 5, 68) - assertEQ(e, 192, 95, 5, 96) - assertEQ(g, 64, 15, 116, 26) - assertEQ(g, 64, 15, 127, 3) - - insert(a, 1, 0, 0, 0, 32) - insert(a, 64, 0, 0, 0, 32) - insert(a, 128, 0, 0, 0, 32) - insert(a, 192, 0, 0, 0, 32) - insert(a, 255, 0, 0, 0, 32) - - assertEQ(a, 1, 0, 0, 0) - assertEQ(a, 64, 0, 0, 0) - assertEQ(a, 128, 0, 0, 0) - assertEQ(a, 192, 0, 0, 0) - assertEQ(a, 255, 0, 0, 0) - - trie = trie.RemovePeer(a) - - assertNEQ(a, 1, 0, 0, 0) - assertNEQ(a, 64, 0, 0, 0) - assertNEQ(a, 128, 0, 0, 0) - assertNEQ(a, 192, 0, 0, 0) - assertNEQ(a, 255, 0, 0, 0) - - trie = nil - - insert(a, 192, 168, 0, 0, 16) - insert(a, 192, 168, 0, 0, 24) - - trie = trie.RemovePeer(a) - - assertNEQ(a, 192, 168, 0, 1) -} - -/* Test ported from kernel implementation: - * selftest/routingtable.h - */ -func TestTrieIPv6(t *testing.T) { - a := &Peer{} - b := &Peer{} - c := &Peer{} - d := &Peer{} - e := &Peer{} - f := &Peer{} - g := &Peer{} - h := &Peer{} - - var trie *Trie - - expand := func(a uint32) []byte { - var out [4]byte - out[0] = byte(a >> 24 & 0xff) - out[1] = byte(a >> 16 & 0xff) - out[2] = byte(a >> 8 & 0xff) - out[3] = byte(a & 0xff) - return out[:] - } - - insert := func(peer *Peer, a, b, c, d uint32, cidr uint) { - var addr []byte - addr = append(addr, expand(a)...) - addr = append(addr, expand(b)...) - addr = append(addr, expand(c)...) - addr = append(addr, expand(d)...) - trie = trie.Insert(addr, cidr, peer) - } - - assertEQ := func(peer *Peer, a, b, c, d uint32) { - var addr []byte - addr = append(addr, expand(a)...) - addr = append(addr, expand(b)...) - addr = append(addr, expand(c)...) - addr = append(addr, expand(d)...) - p := trie.Lookup(addr) - if p != peer { - t.Error("Assert EQ failed") - } - } - - insert(d, 0x26075300, 0x60006b00, 0, 0xc05f0543, 128) - insert(c, 0x26075300, 0x60006b00, 0, 0, 64) - insert(e, 0, 0, 0, 0, 0) - insert(f, 0, 0, 0, 0, 0) - insert(g, 0x24046800, 0, 0, 0, 32) - insert(h, 0x24046800, 0x40040800, 0xdeadbeef, 0xdeadbeef, 64) - insert(a, 0x24046800, 0x40040800, 0xdeadbeef, 0xdeadbeef, 128) - insert(c, 0x24446800, 0x40e40800, 0xdeaebeef, 0xdefbeef, 128) - insert(b, 0x24446800, 0xf0e40800, 0xeeaebeef, 0, 98) - - assertEQ(d, 0x26075300, 0x60006b00, 0, 0xc05f0543) - assertEQ(c, 0x26075300, 0x60006b00, 0, 0xc02e01ee) - assertEQ(f, 0x26075300, 0x60006b01, 0, 0) - assertEQ(g, 0x24046800, 0x40040806, 0, 0x1006) - assertEQ(g, 0x24046800, 0x40040806, 0x1234, 0x5678) - assertEQ(f, 0x240467ff, 0x40040806, 0x1234, 0x5678) - assertEQ(f, 0x24046801, 0x40040806, 0x1234, 0x5678) - assertEQ(h, 0x24046800, 0x40040800, 0x1234, 0x5678) - assertEQ(h, 0x24046800, 0x40040800, 0, 0) - assertEQ(h, 0x24046800, 0x40040800, 0x10101010, 0x10101010) - assertEQ(a, 0x24046800, 0x40040800, 0xdeadbeef, 0xdeadbeef) -} diff --git a/src/tun.go b/src/tun.go deleted file mode 100644 index 6259f33..0000000 --- a/src/tun.go +++ /dev/null @@ -1,58 +0,0 @@ -package main - -import ( - "os" - "sync/atomic" -) - -const DefaultMTU = 1420 - -type TUNEvent int - -const ( - TUNEventUp = 1 << iota - TUNEventDown - TUNEventMTUUpdate -) - -type TUNDevice interface { - File() *os.File // returns the file descriptor of the device - Read([]byte, int) (int, error) // read a packet from the device (without any additional headers) - Write([]byte, int) (int, error) // writes a packet to the device (without any additional headers) - MTU() (int, error) // returns the MTU of the device - Name() string // returns the current name - Events() chan TUNEvent // returns a constant channel of events related to the device - Close() error // stops the device and closes the event channel -} - -func (device *Device) RoutineTUNEventReader() { - logInfo := device.log.Info - logError := device.log.Error - - for event := range device.tun.device.Events() { - if event&TUNEventMTUUpdate != 0 { - mtu, err := device.tun.device.MTU() - old := atomic.LoadInt32(&device.tun.mtu) - if err != nil { - logError.Println("Failed to load updated MTU of device:", err) - } else if int(old) != mtu { - if mtu+MessageTransportSize > MaxMessageSize { - logInfo.Println("MTU updated:", mtu, "(too large)") - } else { - logInfo.Println("MTU updated:", mtu) - } - atomic.StoreInt32(&device.tun.mtu, int32(mtu)) - } - } - - if event&TUNEventUp != 0 && !device.isUp.Get() { - logInfo.Println("Interface set up") - device.Up() - } - - if event&TUNEventDown != 0 && device.isUp.Get() { - logInfo.Println("Interface set down") - device.Down() - } - } -} diff --git a/src/tun_darwin.go b/src/tun_darwin.go deleted file mode 100644 index 87f6af6..0000000 --- a/src/tun_darwin.go +++ /dev/null @@ -1,323 +0,0 @@ -/* Copyright (c) 2016, Song Gao <song@gao.io> - * All rights reserved. - * - * Code from https://github.com/songgao/water - */ - -package main - -import ( - "encoding/binary" - "errors" - "fmt" - "golang.org/x/net/ipv4" - "golang.org/x/net/ipv6" - "golang.org/x/sys/unix" - "io" - "net" - "os" - "sync" - "time" - "unsafe" -) - -const utunControlName = "com.apple.net.utun_control" - -// _CTLIOCGINFO value derived from /usr/include/sys/{kern_control,ioccom}.h -const _CTLIOCGINFO = (0x40000000 | 0x80000000) | ((100 & 0x1fff) << 16) | uint32(byte('N'))<<8 | 3 - -// sockaddr_ctl specifeid in /usr/include/sys/kern_control.h -type sockaddrCtl struct { - scLen uint8 - scFamily uint8 - ssSysaddr uint16 - scID uint32 - scUnit uint32 - scReserved [5]uint32 -} - -// NativeTUN is a hack to work around the first 4 bytes "packet -// information" because there doesn't seem to be an IFF_NO_PI for darwin. -type NativeTUN struct { - name string - f io.ReadWriteCloser - mtu int - - rMu sync.Mutex - rBuf []byte - - wMu sync.Mutex - wBuf []byte - - events chan TUNEvent - errors chan error -} - -var sockaddrCtlSize uintptr = 32 - -func CreateTUN(name string) (ifce TUNDevice, err error) { - ifIndex := -1 - fmt.Sscanf(name, "utun%d", &ifIndex) - if ifIndex < 0 { - return nil, fmt.Errorf("error parsing interface name %s, must be utun[0-9]+", name) - } - - fd, err := unix.Socket(unix.AF_SYSTEM, unix.SOCK_DGRAM, 2) - - if err != nil { - return nil, fmt.Errorf("error in unix.Socket: %v", err) - } - - var ctlInfo = &struct { - ctlID uint32 - ctlName [96]byte - }{} - - copy(ctlInfo.ctlName[:], []byte(utunControlName)) - - _, _, errno := unix.Syscall( - unix.SYS_IOCTL, - uintptr(fd), - uintptr(_CTLIOCGINFO), - uintptr(unsafe.Pointer(ctlInfo)), - ) - - if errno != 0 { - err = errno - return nil, fmt.Errorf("error in unix.Syscall(unix.SYS_IOTL, ...): %v", err) - } - - sc := sockaddrCtl{ - scLen: uint8(sockaddrCtlSize), - scFamily: unix.AF_SYSTEM, - ssSysaddr: 2, - scID: ctlInfo.ctlID, - scUnit: uint32(ifIndex) + 1, - } - - scPointer := unsafe.Pointer(&sc) - - _, _, errno = unix.RawSyscall( - unix.SYS_CONNECT, - uintptr(fd), - uintptr(scPointer), - uintptr(sockaddrCtlSize), - ) - - if errno != 0 { - err = errno - return nil, fmt.Errorf("error in unix.RawSyscall(unix.SYS_CONNECT, ...): %v", err) - } - - // read (new) name of interface - - var ifName struct { - name [16]byte - } - ifNameSize := uintptr(16) - - _, _, errno = unix.Syscall6( - unix.SYS_GETSOCKOPT, - uintptr(fd), - 2, /* #define SYSPROTO_CONTROL 2 */ - 2, /* #define UTUN_OPT_IFNAME 2 */ - uintptr(unsafe.Pointer(&ifName)), - uintptr(unsafe.Pointer(&ifNameSize)), 0) - - if errno != 0 { - err = errno - return nil, fmt.Errorf("error in unix.Syscall6(unix.SYS_GETSOCKOPT, ...): %v", err) - } - - device := &NativeTUN{ - name: string(ifName.name[:ifNameSize-1 /* -1 is for \0 */]), - f: os.NewFile(uintptr(fd), string(ifName.name[:])), - mtu: 1500, - events: make(chan TUNEvent, 10), - errors: make(chan error, 1), - } - - // start listener - - go func(native *NativeTUN) { - // TODO: Fix this very niave implementation - var ( - statusUp bool - statusMTU int - ) - - for ; ; time.Sleep(time.Second) { - intr, err := net.InterfaceByName(device.name) - if err != nil { - native.errors <- err - return - } - - // Up / Down event - up := (intr.Flags & net.FlagUp) != 0 - if up != statusUp && up { - native.events <- TUNEventUp - } - if up != statusUp && !up { - native.events <- TUNEventDown - } - statusUp = up - - // MTU changes - if intr.MTU != statusMTU { - native.events <- TUNEventMTUUpdate - } - statusMTU = intr.MTU - } - }(device) - - // set default MTU - - err = device.setMTU(DefaultMTU) - - return device, err -} - -var _ io.ReadWriteCloser = (*NativeTUN)(nil) - -func (t *NativeTUN) Events() chan TUNEvent { - return t.events -} - -func (t *NativeTUN) Read(to []byte) (int, error) { - t.rMu.Lock() - defer t.rMu.Unlock() - - if cap(t.rBuf) < len(to)+4 { - t.rBuf = make([]byte, len(to)+4) - } - t.rBuf = t.rBuf[:len(to)+4] - - n, err := t.f.Read(t.rBuf) - copy(to, t.rBuf[4:]) - return n - 4, err -} - -func (t *NativeTUN) Write(from []byte) (int, error) { - - if len(from) == 0 { - return 0, unix.EIO - } - - t.wMu.Lock() - defer t.wMu.Unlock() - - if cap(t.wBuf) < len(from)+4 { - t.wBuf = make([]byte, len(from)+4) - } - t.wBuf = t.wBuf[:len(from)+4] - - // determine the IP Family for the NULL L2 Header - - ipVer := from[0] >> 4 - if ipVer == ipv4.Version { - t.wBuf[3] = unix.AF_INET - } else if ipVer == ipv6.Version { - t.wBuf[3] = unix.AF_INET6 - } else { - return 0, errors.New("Unable to determine IP version from packet.") - } - - copy(t.wBuf[4:], from) - - n, err := t.f.Write(t.wBuf) - return n - 4, err -} - -func (t *NativeTUN) Close() error { - - // lock to make sure no read/write is in process. - - t.rMu.Lock() - defer t.rMu.Unlock() - - t.wMu.Lock() - defer t.wMu.Unlock() - - return t.f.Close() -} - -func (t *NativeTUN) Name() string { - return t.name -} - -func (t *NativeTUN) setMTU(n int) error { - - // open datagram socket - - var fd int - - fd, err := unix.Socket( - unix.AF_INET, - unix.SOCK_DGRAM, - 0, - ) - - if err != nil { - return err - } - - defer unix.Close(fd) - - // do ioctl call - - var ifr [32]byte - copy(ifr[:], t.name) - binary.LittleEndian.PutUint32(ifr[16:20], uint32(n)) - _, _, errno := unix.Syscall( - unix.SYS_IOCTL, - uintptr(fd), - uintptr(unix.SIOCSIFMTU), - uintptr(unsafe.Pointer(&ifr[0])), - ) - - if errno != 0 { - return fmt.Errorf("Failed to set MTU on %s", t.name) - } - - return nil -} - -func (t *NativeTUN) MTU() (int, error) { - - // open datagram socket - - fd, err := unix.Socket( - unix.AF_INET, - unix.SOCK_DGRAM, - 0, - ) - - if err != nil { - return 0, err - } - - defer unix.Close(fd) - - // do ioctl call - - var ifr [64]byte - copy(ifr[:], t.name) - _, _, errno := unix.Syscall( - unix.SYS_IOCTL, - uintptr(fd), - uintptr(unix.SIOCGIFMTU), - uintptr(unsafe.Pointer(&ifr[0])), - ) - if errno != 0 { - return 0, fmt.Errorf("Failed to get MTU on %s", t.name) - } - - // convert result to signed 32-bit int - - val := binary.LittleEndian.Uint32(ifr[16:20]) - if val >= (1 << 31) { - return int(val-(1<<31)) - (1 << 31), nil - } - return int(val), nil -} diff --git a/src/tun_linux.go b/src/tun_linux.go deleted file mode 100644 index 9756169..0000000 --- a/src/tun_linux.go +++ /dev/null @@ -1,377 +0,0 @@ -package main - -/* Implementation of the TUN device interface for linux - */ - -import ( - "encoding/binary" - "errors" - "fmt" - "golang.org/x/net/ipv6" - "golang.org/x/sys/unix" - "net" - "os" - "strings" - "time" - "unsafe" -) - -// #include <string.h> -// #include <unistd.h> -// #include <net/if.h> -// #include <netinet/in.h> -// #include <linux/netlink.h> -// #include <linux/rtnetlink.h> -// -// /* Creates a netlink socket -// * listening to the RTMGRP_LINK multicast group -// */ -// -// int bind_rtmgrp() { -// int nl_sock = socket(AF_NETLINK, SOCK_RAW, NETLINK_ROUTE); -// if (nl_sock < 0) -// return -1; -// -// struct sockaddr_nl addr; -// memset ((void *) &addr, 0, sizeof (addr)); -// addr.nl_family = AF_NETLINK; -// addr.nl_pid = getpid (); -// addr.nl_groups = RTMGRP_LINK | RTMGRP_IPV4_IFADDR | RTMGRP_IPV6_IFADDR; -// -// if (bind(nl_sock, (struct sockaddr *) &addr, sizeof (addr)) < 0) -// return -1; -// -// return nl_sock; -// } -import "C" - -const ( - CloneDevicePath = "/dev/net/tun" - IFReqSize = unix.IFNAMSIZ + 64 -) - -type NativeTun struct { - fd *os.File - index int32 // if index - name string // name of interface - errors chan error // async error handling - events chan TUNEvent // device related events -} - -func (tun *NativeTun) File() *os.File { - return tun.fd -} - -func (tun *NativeTun) RoutineHackListener() { - /* This is needed for the detection to work across network namespaces - * If you are reading this and know a better method, please get in touch. - */ - fd := int(tun.fd.Fd()) - for { - _, err := unix.Write(fd, nil) - switch err { - case unix.EINVAL: - tun.events <- TUNEventUp - case unix.EIO: - tun.events <- TUNEventDown - default: - } - time.Sleep(time.Second / 10) - } -} - -func (tun *NativeTun) RoutineNetlinkListener() { - - sock := int(C.bind_rtmgrp()) - if sock < 0 { - tun.errors <- errors.New("Failed to create netlink event listener") - return - } - - for msg := make([]byte, 1<<16); ; { - - msgn, _, _, _, err := unix.Recvmsg(sock, msg[:], nil, 0) - if err != nil { - tun.errors <- fmt.Errorf("Failed to receive netlink message: %s", err.Error()) - return - } - - for remain := msg[:msgn]; len(remain) >= unix.SizeofNlMsghdr; { - - hdr := *(*unix.NlMsghdr)(unsafe.Pointer(&remain[0])) - - if int(hdr.Len) > len(remain) { - break - } - - switch hdr.Type { - case unix.NLMSG_DONE: - remain = []byte{} - - case unix.RTM_NEWLINK: - info := *(*unix.IfInfomsg)(unsafe.Pointer(&remain[unix.SizeofNlMsghdr])) - remain = remain[hdr.Len:] - - if info.Index != tun.index { - // not our interface - continue - } - - if info.Flags&unix.IFF_RUNNING != 0 { - tun.events <- TUNEventUp - } - - if info.Flags&unix.IFF_RUNNING == 0 { - tun.events <- TUNEventDown - } - - tun.events <- TUNEventMTUUpdate - - default: - remain = remain[hdr.Len:] - } - } - } -} - -func (tun *NativeTun) isUp() (bool, error) { - inter, err := net.InterfaceByName(tun.name) - return inter.Flags&net.FlagUp != 0, err -} - -func (tun *NativeTun) Name() string { - return tun.name -} - -func getDummySock() (int, error) { - return unix.Socket( - unix.AF_INET, - unix.SOCK_DGRAM, - 0, - ) -} - -func getIFIndex(name string) (int32, error) { - fd, err := getDummySock() - if err != nil { - return 0, err - } - - defer unix.Close(fd) - - var ifr [IFReqSize]byte - copy(ifr[:], name) - _, _, errno := unix.Syscall( - unix.SYS_IOCTL, - uintptr(fd), - uintptr(unix.SIOCGIFINDEX), - uintptr(unsafe.Pointer(&ifr[0])), - ) - - if errno != 0 { - return 0, errno - } - - index := binary.LittleEndian.Uint32(ifr[unix.IFNAMSIZ:]) - return toInt32(index), nil -} - -func (tun *NativeTun) setMTU(n int) error { - - // open datagram socket - - fd, err := unix.Socket( - unix.AF_INET, - unix.SOCK_DGRAM, - 0, - ) - - if err != nil { - return err - } - - defer unix.Close(fd) - - // do ioctl call - - var ifr [IFReqSize]byte - copy(ifr[:], tun.name) - binary.LittleEndian.PutUint32(ifr[16:20], uint32(n)) - _, _, errno := unix.Syscall( - unix.SYS_IOCTL, - uintptr(fd), - uintptr(unix.SIOCSIFMTU), - uintptr(unsafe.Pointer(&ifr[0])), - ) - - if errno != 0 { - return errors.New("Failed to set MTU of TUN device") - } - - return nil -} - -func (tun *NativeTun) MTU() (int, error) { - - // open datagram socket - - fd, err := unix.Socket( - unix.AF_INET, - unix.SOCK_DGRAM, - 0, - ) - - if err != nil { - return 0, err - } - - defer unix.Close(fd) - - // do ioctl call - - var ifr [IFReqSize]byte - copy(ifr[:], tun.name) - _, _, errno := unix.Syscall( - unix.SYS_IOCTL, - uintptr(fd), - uintptr(unix.SIOCGIFMTU), - uintptr(unsafe.Pointer(&ifr[0])), - ) - if errno != 0 { - return 0, errors.New("Failed to get MTU of TUN device") - } - - // convert result to signed 32-bit int - - val := binary.LittleEndian.Uint32(ifr[16:20]) - if val >= (1 << 31) { - return int(toInt32(val)), nil - } - return int(val), nil -} - -func (tun *NativeTun) Write(buff []byte, offset int) (int, error) { - - // reserve space for header - - buff = buff[offset-4:] - - // add packet information header - - buff[0] = 0x00 - buff[1] = 0x00 - - if buff[4] == ipv6.Version<<4 { - buff[2] = 0x86 - buff[3] = 0xdd - } else { - buff[2] = 0x08 - buff[3] = 0x00 - } - - // write - - return tun.fd.Write(buff) -} - -func (tun *NativeTun) Read(buff []byte, offset int) (int, error) { - select { - case err := <-tun.errors: - return 0, err - default: - buff := buff[offset-4:] - n, err := tun.fd.Read(buff[:]) - if n < 4 { - return 0, err - } - return n - 4, err - } -} - -func (tun *NativeTun) Events() chan TUNEvent { - return tun.events -} - -func (tun *NativeTun) Close() error { - return nil -} - -func CreateTUNFromFile(name string, fd *os.File) (TUNDevice, error) { - device := &NativeTun{ - fd: fd, - name: name, - events: make(chan TUNEvent, 5), - errors: make(chan error, 5), - } - - // start event listener - - var err error - device.index, err = getIFIndex(device.name) - if err != nil { - return nil, err - } - - go device.RoutineNetlinkListener() - // go device.RoutineHackListener() // cross namespace - - // set default MTU - - return device, device.setMTU(DefaultMTU) -} - -func CreateTUN(name string) (TUNDevice, error) { - - // open clone device - - fd, err := os.OpenFile(CloneDevicePath, os.O_RDWR, 0) - if err != nil { - return nil, err - } - - // create new device - - var ifr [IFReqSize]byte - var flags uint16 = unix.IFF_TUN // | unix.IFF_NO_PI - nameBytes := []byte(name) - if len(nameBytes) >= unix.IFNAMSIZ { - return nil, errors.New("Interface name too long") - } - copy(ifr[:], nameBytes) - binary.LittleEndian.PutUint16(ifr[16:], flags) - - _, _, errno := unix.Syscall( - unix.SYS_IOCTL, - uintptr(fd.Fd()), - uintptr(unix.TUNSETIFF), - uintptr(unsafe.Pointer(&ifr[0])), - ) - if errno != 0 { - return nil, errno - } - - // read (new) name of interface - - newName := string(ifr[:]) - newName = newName[:strings.Index(newName, "\000")] - device := &NativeTun{ - fd: fd, - name: newName, - events: make(chan TUNEvent, 5), - errors: make(chan error, 5), - } - - // start event listener - - device.index, err = getIFIndex(device.name) - if err != nil { - return nil, err - } - - go device.RoutineNetlinkListener() - // go device.RoutineHackListener() // cross namespace - - // set default MTU - - return device, device.setMTU(DefaultMTU) -} diff --git a/src/tun_windows.go b/src/tun_windows.go deleted file mode 100644 index 0711032..0000000 --- a/src/tun_windows.go +++ /dev/null @@ -1,475 +0,0 @@ -package main - -import ( - "encoding/binary" - "errors" - "fmt" - "golang.org/x/sys/windows" - "golang.org/x/sys/windows/registry" - "net" - "sync" - "syscall" - "time" - "unsafe" -) - -/* Relies on the OpenVPN TAP-Windows driver (NDIS 6 version) - * - * https://github.com/OpenVPN/tap-windows - */ - -type NativeTUN struct { - fd windows.Handle - rl sync.Mutex - wl sync.Mutex - ro *windows.Overlapped - wo *windows.Overlapped - events chan TUNEvent - name string -} - -const ( - METHOD_BUFFERED = 0 - ComponentID = "tap0901" // tap0801 -) - -func ctl_code(device_type, function, method, access uint32) uint32 { - return (device_type << 16) | (access << 14) | (function << 2) | method -} - -func TAP_CONTROL_CODE(request, method uint32) uint32 { - return ctl_code(file_device_unknown, request, method, 0) -} - -var ( - errIfceNameNotFound = errors.New("Failed to find the name of interface") - - TAP_IOCTL_GET_MAC = TAP_CONTROL_CODE(1, METHOD_BUFFERED) - TAP_IOCTL_GET_VERSION = TAP_CONTROL_CODE(2, METHOD_BUFFERED) - TAP_IOCTL_GET_MTU = TAP_CONTROL_CODE(3, METHOD_BUFFERED) - TAP_IOCTL_GET_INFO = TAP_CONTROL_CODE(4, METHOD_BUFFERED) - TAP_IOCTL_CONFIG_POINT_TO_POINT = TAP_CONTROL_CODE(5, METHOD_BUFFERED) - TAP_IOCTL_SET_MEDIA_STATUS = TAP_CONTROL_CODE(6, METHOD_BUFFERED) - TAP_IOCTL_CONFIG_DHCP_MASQ = TAP_CONTROL_CODE(7, METHOD_BUFFERED) - TAP_IOCTL_GET_LOG_LINE = TAP_CONTROL_CODE(8, METHOD_BUFFERED) - TAP_IOCTL_CONFIG_DHCP_SET_OPT = TAP_CONTROL_CODE(9, METHOD_BUFFERED) - TAP_IOCTL_CONFIG_TUN = TAP_CONTROL_CODE(10, METHOD_BUFFERED) - - file_device_unknown = uint32(0x00000022) - nCreateEvent, - nResetEvent, - nGetOverlappedResult uintptr -) - -func init() { - k32, err := windows.LoadLibrary("kernel32.dll") - if err != nil { - panic("LoadLibrary " + err.Error()) - } - defer windows.FreeLibrary(k32) - nCreateEvent = getProcAddr(k32, "CreateEventW") - nResetEvent = getProcAddr(k32, "ResetEvent") - nGetOverlappedResult = getProcAddr(k32, "GetOverlappedResult") -} - -/* implementation of the read/write/closer interface */ - -func getProcAddr(lib windows.Handle, name string) uintptr { - addr, err := windows.GetProcAddress(lib, name) - if err != nil { - panic(name + " " + err.Error()) - } - return addr -} - -func resetEvent(h windows.Handle) error { - r, _, err := syscall.Syscall(nResetEvent, 1, uintptr(h), 0, 0) - if r == 0 { - return err - } - return nil -} - -func getOverlappedResult(h windows.Handle, overlapped *windows.Overlapped) (int, error) { - var n int - r, _, err := syscall.Syscall6( - nGetOverlappedResult, - 4, - uintptr(h), - uintptr(unsafe.Pointer(overlapped)), - uintptr(unsafe.Pointer(&n)), 1, 0, 0) - - if r == 0 { - return n, err - } - return n, nil -} - -func newOverlapped() (*windows.Overlapped, error) { - var overlapped windows.Overlapped - r, _, err := syscall.Syscall6(nCreateEvent, 4, 0, 1, 0, 0, 0, 0) - if r == 0 { - return nil, err - } - overlapped.HEvent = windows.Handle(r) - return &overlapped, nil -} - -func (f *NativeTUN) Events() chan TUNEvent { - return f.events -} - -func (f *NativeTUN) Close() error { - return windows.Close(f.fd) -} - -func (f *NativeTUN) Write(b []byte) (int, error) { - f.wl.Lock() - defer f.wl.Unlock() - - if err := resetEvent(f.wo.HEvent); err != nil { - return 0, err - } - var n uint32 - err := windows.WriteFile(f.fd, b, &n, f.wo) - if err != nil && err != windows.ERROR_IO_PENDING { - return int(n), err - } - return getOverlappedResult(f.fd, f.wo) -} - -func (f *NativeTUN) Read(b []byte) (int, error) { - f.rl.Lock() - defer f.rl.Unlock() - - if err := resetEvent(f.ro.HEvent); err != nil { - return 0, err - } - var done uint32 - err := windows.ReadFile(f.fd, b, &done, f.ro) - if err != nil && err != windows.ERROR_IO_PENDING { - return int(done), err - } - return getOverlappedResult(f.fd, f.ro) -} - -func getdeviceid( - targetComponentId string, - targetDeviceName string, -) (deviceid string, err error) { - - getName := func(instanceId string) (string, error) { - path := fmt.Sprintf( - `SYSTEM\CurrentControlSet\Control\Network\{4D36E972-E325-11CE-BFC1-08002BE10318}\%s\Connection`, - instanceId, - ) - - key, err := registry.OpenKey( - registry.LOCAL_MACHINE, - path, - registry.READ, - ) - - if err != nil { - return "", err - } - defer key.Close() - - val, _, err := key.GetStringValue("Name") - key.Close() - return val, err - } - - getInstanceId := func(keyName string) (string, string, error) { - path := fmt.Sprintf( - `SYSTEM\CurrentControlSet\Control\Class\{4D36E972-E325-11CE-BFC1-08002BE10318}\%s`, - keyName, - ) - - key, err := registry.OpenKey( - registry.LOCAL_MACHINE, - path, - registry.READ, - ) - - if err != nil { - return "", "", err - } - defer key.Close() - - componentId, _, err := key.GetStringValue("ComponentId") - if err != nil { - return "", "", err - } - - instanceId, _, err := key.GetStringValue("NetCfgInstanceId") - - return componentId, instanceId, err - } - - // find list of all network devices - - k, err := registry.OpenKey( - registry.LOCAL_MACHINE, - `SYSTEM\CurrentControlSet\Control\Class\{4D36E972-E325-11CE-BFC1-08002BE10318}`, - registry.READ, - ) - - if err != nil { - return "", fmt.Errorf("Failed to open the adapter registry, TAP driver may be not installed, %v", err) - } - - defer k.Close() - - keys, err := k.ReadSubKeyNames(-1) - - if err != nil { - return "", err - } - - // look for matching component id and name - - var componentFound bool - - for _, v := range keys { - - componentId, instanceId, err := getInstanceId(v) - if err != nil || componentId != targetComponentId { - continue - } - - componentFound = true - - deviceName, err := getName(instanceId) - if err != nil || deviceName != targetDeviceName { - continue - } - - return instanceId, nil - } - - // provide a descriptive error message - - if componentFound { - return "", fmt.Errorf("Unable to find tun/tap device with name = %s", targetDeviceName) - } - - return "", fmt.Errorf( - "Unable to find device in registry with ComponentId = %s, is tap-windows installed?", - targetComponentId, - ) -} - -// setStatus is used to bring up or bring down the interface -func setStatus(fd windows.Handle, status bool) error { - var code [4]byte - if status { - binary.LittleEndian.PutUint32(code[:], 1) - } - - var bytesReturned uint32 - rdbbuf := make([]byte, windows.MAXIMUM_REPARSE_DATA_BUFFER_SIZE) - return windows.DeviceIoControl( - fd, - TAP_IOCTL_SET_MEDIA_STATUS, - &code[0], - uint32(4), - &rdbbuf[0], - uint32(len(rdbbuf)), - &bytesReturned, - nil, - ) -} - -/* When operating in TUN mode we must assign an ip address & subnet to the device. - * - */ -func setTUN(fd windows.Handle, network string) error { - var bytesReturned uint32 - rdbbuf := make([]byte, windows.MAXIMUM_REPARSE_DATA_BUFFER_SIZE) - localIP, remoteNet, err := net.ParseCIDR(network) - - if err != nil { - return fmt.Errorf("Failed to parse network CIDR in config, %v", err) - } - - if localIP.To4() == nil { - return fmt.Errorf("Provided network(%s) is not a valid IPv4 address", network) - } - - var param [12]byte - - copy(param[0:4], localIP.To4()) - copy(param[4:8], remoteNet.IP.To4()) - copy(param[8:12], remoteNet.Mask) - - return windows.DeviceIoControl( - fd, - TAP_IOCTL_CONFIG_TUN, - ¶m[0], - uint32(12), - &rdbbuf[0], - uint32(len(rdbbuf)), - &bytesReturned, - nil, - ) -} - -func (tun *NativeTUN) MTU() (int, error) { - var mtu [4]byte - var bytesReturned uint32 - err := windows.DeviceIoControl( - tun.fd, - TAP_IOCTL_GET_MTU, - &mtu[0], - uint32(len(mtu)), - &mtu[0], - uint32(len(mtu)), - &bytesReturned, - nil, - ) - val := binary.LittleEndian.Uint32(mtu[:]) - return int(val), err -} - -func (tun *NativeTUN) Name() string { - return tun.name -} - -func CreateTUN(name string) (TUNDevice, error) { - - // find the device in registry. - - deviceid, err := getdeviceid(ComponentID, name) - if err != nil { - return nil, err - } - path := "\\\\.\\Global\\" + deviceid + ".tap" - pathp, err := windows.UTF16PtrFromString(path) - if err != nil { - return nil, err - } - - // create TUN device - - handle, err := windows.CreateFile( - pathp, - windows.GENERIC_READ|windows.GENERIC_WRITE, - 0, - nil, - windows.OPEN_EXISTING, - windows.FILE_ATTRIBUTE_SYSTEM|windows.FILE_FLAG_OVERLAPPED, - 0, - ) - - if err != nil { - return nil, err - } - - ro, err := newOverlapped() - if err != nil { - windows.Close(handle) - return nil, err - } - - wo, err := newOverlapped() - if err != nil { - windows.Close(handle) - return nil, err - } - - tun := &NativeTUN{ - fd: handle, - name: name, - ro: ro, - wo: wo, - events: make(chan TUNEvent, 5), - } - - // find addresses of interface - // TODO: fix this hack, the question is how - - inter, err := net.InterfaceByName(name) - if err != nil { - windows.Close(handle) - return nil, err - } - - addrs, err := inter.Addrs() - if err != nil { - windows.Close(handle) - return nil, err - } - - var ip net.IP - for _, addr := range addrs { - ip = func() net.IP { - switch v := addr.(type) { - case *net.IPNet: - return v.IP.To4() - case *net.IPAddr: - return v.IP.To4() - } - return nil - }() - if ip != nil { - break - } - } - - if ip == nil { - windows.Close(handle) - return nil, errors.New("No IPv4 address found for interface") - } - - // bring up device. - - if err := setStatus(handle, true); err != nil { - windows.Close(handle) - return nil, err - } - - // set tun mode - - mask := ip.String() + "/0" - if err := setTUN(handle, mask); err != nil { - windows.Close(handle) - return nil, err - } - - // start listener - - go func(native *NativeTUN, ifname string) { - // TODO: Fix this very niave implementation - var ( - statusUp bool - statusMTU int - ) - - for ; ; time.Sleep(time.Second) { - intr, err := net.InterfaceByName(name) - if err != nil { - // TODO: handle - return - } - - // Up / Down event - up := (intr.Flags & net.FlagUp) != 0 - if up != statusUp && up { - native.events <- TUNEventUp - } - if up != statusUp && !up { - native.events <- TUNEventDown - } - statusUp = up - - // MTU changes - if intr.MTU != statusMTU { - native.events <- TUNEventMTUUpdate - } - statusMTU = intr.MTU - } - }(tun, name) - - return tun, nil -} diff --git a/src/uapi.go b/src/uapi.go deleted file mode 100644 index caaa498..0000000 --- a/src/uapi.go +++ /dev/null @@ -1,437 +0,0 @@ -package main - -import ( - "bufio" - "fmt" - "io" - "net" - "strconv" - "strings" - "sync/atomic" - "time" -) - -type IPCError struct { - Code int64 -} - -func (s *IPCError) Error() string { - return fmt.Sprintf("IPC error: %d", s.Code) -} - -func (s *IPCError) ErrorCode() int64 { - return s.Code -} - -func ipcGetOperation(device *Device, socket *bufio.ReadWriter) *IPCError { - - device.log.Debug.Println("UAPI: Processing get operation") - - // create lines - - lines := make([]string, 0, 100) - send := func(line string) { - lines = append(lines, line) - } - - func() { - - // lock required resources - - device.net.mutex.RLock() - defer device.net.mutex.RUnlock() - - device.noise.mutex.RLock() - defer device.noise.mutex.RUnlock() - - device.routing.mutex.RLock() - defer device.routing.mutex.RUnlock() - - device.peers.mutex.Lock() - defer device.peers.mutex.Unlock() - - // serialize device related values - - if !device.noise.privateKey.IsZero() { - send("private_key=" + device.noise.privateKey.ToHex()) - } - - if device.net.port != 0 { - send(fmt.Sprintf("listen_port=%d", device.net.port)) - } - - if device.net.fwmark != 0 { - send(fmt.Sprintf("fwmark=%d", device.net.fwmark)) - } - - // serialize each peer state - - for _, peer := range device.peers.keyMap { - peer.mutex.RLock() - defer peer.mutex.RUnlock() - - send("public_key=" + peer.handshake.remoteStatic.ToHex()) - send("preshared_key=" + peer.handshake.presharedKey.ToHex()) - if peer.endpoint != nil { - send("endpoint=" + peer.endpoint.DstToString()) - } - - nano := atomic.LoadInt64(&peer.stats.lastHandshakeNano) - secs := nano / time.Second.Nanoseconds() - nano %= time.Second.Nanoseconds() - - send(fmt.Sprintf("last_handshake_time_sec=%d", secs)) - send(fmt.Sprintf("last_handshake_time_nsec=%d", nano)) - send(fmt.Sprintf("tx_bytes=%d", peer.stats.txBytes)) - send(fmt.Sprintf("rx_bytes=%d", peer.stats.rxBytes)) - send(fmt.Sprintf("persistent_keepalive_interval=%d", - atomic.LoadUint64(&peer.persistentKeepaliveInterval), - )) - - for _, ip := range device.routing.table.AllowedIPs(peer) { - send("allowed_ip=" + ip.String()) - } - - } - }() - - // send lines (does not require resource locks) - - for _, line := range lines { - _, err := socket.WriteString(line + "\n") - if err != nil { - return &IPCError{ - Code: ipcErrorIO, - } - } - } - - return nil -} - -func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError { - scanner := bufio.NewScanner(socket) - logError := device.log.Error - logDebug := device.log.Debug - - var peer *Peer - - dummy := false - deviceConfig := true - - for scanner.Scan() { - - // parse line - - line := scanner.Text() - if line == "" { - return nil - } - parts := strings.Split(line, "=") - if len(parts) != 2 { - return &IPCError{Code: ipcErrorProtocol} - } - key := parts[0] - value := parts[1] - - /* device configuration */ - - if deviceConfig { - - switch key { - case "private_key": - var sk NoisePrivateKey - err := sk.FromHex(value) - if err != nil { - logError.Println("Failed to set private_key:", err) - return &IPCError{Code: ipcErrorInvalid} - } - logDebug.Println("UAPI: Updating device private key") - device.SetPrivateKey(sk) - - case "listen_port": - - // parse port number - - port, err := strconv.ParseUint(value, 10, 16) - if err != nil { - logError.Println("Failed to parse listen_port:", err) - return &IPCError{Code: ipcErrorInvalid} - } - - // update port and rebind - - logDebug.Println("UAPI: Updating listen port") - - device.net.mutex.Lock() - device.net.port = uint16(port) - device.net.mutex.Unlock() - - if err := device.BindUpdate(); err != nil { - logError.Println("Failed to set listen_port:", err) - return &IPCError{Code: ipcErrorPortInUse} - } - - case "fwmark": - - // parse fwmark field - - fwmark, err := func() (uint32, error) { - if value == "" { - return 0, nil - } - mark, err := strconv.ParseUint(value, 10, 32) - return uint32(mark), err - }() - - if err != nil { - logError.Println("Invalid fwmark", err) - return &IPCError{Code: ipcErrorInvalid} - } - - logDebug.Println("UAPI: Updating fwmark") - - device.net.mutex.Lock() - device.net.fwmark = uint32(fwmark) - device.net.mutex.Unlock() - - if err := device.BindUpdate(); err != nil { - logError.Println("Failed to update fwmark:", err) - return &IPCError{Code: ipcErrorPortInUse} - } - - case "public_key": - // switch to peer configuration - logDebug.Println("UAPI: Transition to peer configuration") - deviceConfig = false - - case "replace_peers": - if value != "true" { - logError.Println("Failed to set replace_peers, invalid value:", value) - return &IPCError{Code: ipcErrorInvalid} - } - logDebug.Println("UAPI: Removing all peers") - device.RemoveAllPeers() - - default: - logError.Println("Invalid UAPI key (device configuration):", key) - return &IPCError{Code: ipcErrorInvalid} - } - } - - /* peer configuration */ - - if !deviceConfig { - - switch key { - - case "public_key": - var publicKey NoisePublicKey - err := publicKey.FromHex(value) - if err != nil { - logError.Println("Failed to get peer by public_key:", err) - return &IPCError{Code: ipcErrorInvalid} - } - - // ignore peer with public key of device - - device.noise.mutex.RLock() - equals := device.noise.publicKey.Equals(publicKey) - device.noise.mutex.RUnlock() - - if equals { - peer = &Peer{} - dummy = true - } - - // find peer referenced - - peer = device.LookupPeer(publicKey) - - if peer == nil { - peer, err = device.NewPeer(publicKey) - if err != nil { - logError.Println("Failed to create new peer:", err) - return &IPCError{Code: ipcErrorInvalid} - } - logDebug.Println("UAPI: Created new peer:", peer.String()) - } - - peer.mutex.Lock() - peer.timer.handshakeDeadline.Reset(RekeyAttemptTime) - peer.mutex.Unlock() - - case "remove": - - // remove currently selected peer from device - - if value != "true" { - logError.Println("Failed to set remove, invalid value:", value) - return &IPCError{Code: ipcErrorInvalid} - } - if !dummy { - logDebug.Println("UAPI: Removing peer:", peer.String()) - device.RemovePeer(peer.handshake.remoteStatic) - } - peer = &Peer{} - dummy = true - - case "preshared_key": - - // update PSK - - logDebug.Println("UAPI: Updating pre-shared key for peer:", peer.String()) - - peer.handshake.mutex.Lock() - err := peer.handshake.presharedKey.FromHex(value) - peer.handshake.mutex.Unlock() - - if err != nil { - logError.Println("Failed to set preshared_key:", err) - return &IPCError{Code: ipcErrorInvalid} - } - - case "endpoint": - - // set endpoint destination - - logDebug.Println("UAPI: Updating endpoint for peer:", peer.String()) - - err := func() error { - peer.mutex.Lock() - defer peer.mutex.Unlock() - endpoint, err := CreateEndpoint(value) - if err != nil { - return err - } - peer.endpoint = endpoint - peer.timer.handshakeDeadline.Reset(RekeyAttemptTime) - return nil - }() - - if err != nil { - logError.Println("Failed to set endpoint:", value) - return &IPCError{Code: ipcErrorInvalid} - } - - case "persistent_keepalive_interval": - - // update keep-alive interval - - logDebug.Println("UAPI: Updating persistent_keepalive_interval for peer:", peer.String()) - - secs, err := strconv.ParseUint(value, 10, 16) - if err != nil { - logError.Println("Failed to set persistent_keepalive_interval:", err) - return &IPCError{Code: ipcErrorInvalid} - } - - old := atomic.SwapUint64( - &peer.persistentKeepaliveInterval, - secs, - ) - - // send immediate keep-alive - - if old == 0 && secs != 0 { - if err != nil { - logError.Println("Failed to get tun device status:", err) - return &IPCError{Code: ipcErrorIO} - } - if device.isUp.Get() && !dummy { - peer.SendKeepAlive() - } - } - - case "replace_allowed_ips": - - logDebug.Println("UAPI: Removing all allowed IPs for peer:", peer.String()) - - if value != "true" { - logError.Println("Failed to set replace_allowed_ips, invalid value:", value) - return &IPCError{Code: ipcErrorInvalid} - } - - if dummy { - continue - } - - device.routing.mutex.Lock() - device.routing.table.RemovePeer(peer) - device.routing.mutex.Unlock() - - case "allowed_ip": - - logDebug.Println("UAPI: Adding allowed_ip to peer:", peer.String()) - - _, network, err := net.ParseCIDR(value) - if err != nil { - logError.Println("Failed to set allowed_ip:", err) - return &IPCError{Code: ipcErrorInvalid} - } - - if dummy { - continue - } - - ones, _ := network.Mask.Size() - device.routing.mutex.Lock() - device.routing.table.Insert(network.IP, uint(ones), peer) - device.routing.mutex.Unlock() - - default: - logError.Println("Invalid UAPI key (peer configuration):", key) - return &IPCError{Code: ipcErrorInvalid} - } - } - } - - return nil -} - -func ipcHandle(device *Device, socket net.Conn) { - - // create buffered read/writer - - defer socket.Close() - - buffered := func(s io.ReadWriter) *bufio.ReadWriter { - reader := bufio.NewReader(s) - writer := bufio.NewWriter(s) - return bufio.NewReadWriter(reader, writer) - }(socket) - - defer buffered.Flush() - - op, err := buffered.ReadString('\n') - if err != nil { - return - } - - // handle operation - - var status *IPCError - - switch op { - case "set=1\n": - device.log.Debug.Println("Config, set operation") - status = ipcSetOperation(device, buffered) - - case "get=1\n": - device.log.Debug.Println("Config, get operation") - status = ipcGetOperation(device, buffered) - - default: - device.log.Error.Println("Invalid UAPI operation:", op) - return - } - - // write status - - if status != nil { - device.log.Error.Println(status) - fmt.Fprintf(buffered, "errno=%d\n\n", status.ErrorCode()) - } else { - fmt.Fprintf(buffered, "errno=0\n\n") - } -} diff --git a/src/uapi_darwin.go b/src/uapi_darwin.go deleted file mode 100644 index 63d4d8d..0000000 --- a/src/uapi_darwin.go +++ /dev/null @@ -1,99 +0,0 @@ -package main - -import ( - "fmt" - "golang.org/x/sys/unix" - "net" - "os" - "path" - "time" -) - -const ( - ipcErrorIO = -int64(unix.EIO) - ipcErrorProtocol = -int64(unix.EPROTO) - ipcErrorInvalid = -int64(unix.EINVAL) - ipcErrorPortInUse = -int64(unix.EADDRINUSE) - socketDirectory = "/var/run/wireguard" - socketName = "%s.sock" -) - -type UAPIListener struct { - listener net.Listener // unix socket listener - connNew chan net.Conn - connErr chan error -} - -func (l *UAPIListener) Accept() (net.Conn, error) { - for { - select { - case conn := <-l.connNew: - return conn, nil - - case err := <-l.connErr: - return nil, err - } - } -} - -func (l *UAPIListener) Close() error { - return l.listener.Close() -} - -func (l *UAPIListener) Addr() net.Addr { - return nil -} - -func NewUAPIListener(name string) (net.Listener, error) { - - // check if path exist - - err := os.MkdirAll(socketDirectory, 077) - if err != nil && !os.IsExist(err) { - return nil, err - } - - // open UNIX socket - - socketPath := path.Join( - socketDirectory, - fmt.Sprintf(socketName, name), - ) - - listener, err := net.Listen("unix", socketPath) - if err != nil { - return nil, err - } - - uapi := &UAPIListener{ - listener: listener, - connNew: make(chan net.Conn, 1), - connErr: make(chan error, 1), - } - - // watch for deletion of socket - - go func(l *UAPIListener) { - for ; ; time.Sleep(time.Second) { - if _, err := os.Stat(socketPath); os.IsNotExist(err) { - l.connErr <- err - return - } - } - }(uapi) - - // watch for new connections - - go func(l *UAPIListener) { - for { - conn, err := l.listener.Accept() - if err != nil { - l.connErr <- err - break - } - l.connNew <- conn - } - }(uapi) - - return uapi, nil -} diff --git a/src/uapi_linux.go b/src/uapi_linux.go deleted file mode 100644 index f97a18a..0000000 --- a/src/uapi_linux.go +++ /dev/null @@ -1,171 +0,0 @@ -package main - -import ( - "errors" - "fmt" - "golang.org/x/sys/unix" - "net" - "os" - "path" -) - -const ( - ipcErrorIO = -int64(unix.EIO) - ipcErrorProtocol = -int64(unix.EPROTO) - ipcErrorInvalid = -int64(unix.EINVAL) - ipcErrorPortInUse = -int64(unix.EADDRINUSE) - socketDirectory = "/var/run/wireguard" - socketName = "%s.sock" -) - -type UAPIListener struct { - listener net.Listener // unix socket listener - connNew chan net.Conn - connErr chan error - inotifyFd int -} - -func (l *UAPIListener) Accept() (net.Conn, error) { - for { - select { - case conn := <-l.connNew: - return conn, nil - - case err := <-l.connErr: - return nil, err - } - } -} - -func (l *UAPIListener) Close() error { - err1 := unix.Close(l.inotifyFd) - err2 := l.listener.Close() - if err1 != nil { - return err1 - } - return err2 -} - -func (l *UAPIListener) Addr() net.Addr { - return nil -} - -func UAPIListen(name string, file *os.File) (net.Listener, error) { - - // wrap file in listener - - listener, err := net.FileListener(file) - if err != nil { - return nil, err - } - - uapi := &UAPIListener{ - listener: listener, - connNew: make(chan net.Conn, 1), - connErr: make(chan error, 1), - } - - // watch for deletion of socket - - socketPath := path.Join( - socketDirectory, - fmt.Sprintf(socketName, name), - ) - - uapi.inotifyFd, err = unix.InotifyInit() - if err != nil { - return nil, err - } - - _, err = unix.InotifyAddWatch( - uapi.inotifyFd, - socketPath, - unix.IN_ATTRIB| - unix.IN_DELETE| - unix.IN_DELETE_SELF, - ) - - if err != nil { - return nil, err - } - - go func(l *UAPIListener) { - var buff [4096]byte - for { - // start with lstat to avoid race condition - if _, err := os.Lstat(socketPath); os.IsNotExist(err) { - l.connErr <- err - return - } - unix.Read(uapi.inotifyFd, buff[:]) - } - }(uapi) - - // watch for new connections - - go func(l *UAPIListener) { - for { - conn, err := l.listener.Accept() - if err != nil { - l.connErr <- err - break - } - l.connNew <- conn - } - }(uapi) - - return uapi, nil -} - -func UAPIOpen(name string) (*os.File, error) { - - // check if path exist - - err := os.MkdirAll(socketDirectory, 0600) - if err != nil && !os.IsExist(err) { - return nil, err - } - - // open UNIX socket - - socketPath := path.Join( - socketDirectory, - fmt.Sprintf(socketName, name), - ) - - addr, err := net.ResolveUnixAddr("unix", socketPath) - if err != nil { - return nil, err - } - - listener, err := func() (*net.UnixListener, error) { - - // initial connection attempt - - listener, err := net.ListenUnix("unix", addr) - if err == nil { - return listener, nil - } - - // check if socket already active - - _, err = net.Dial("unix", socketPath) - if err == nil { - return nil, errors.New("unix socket in use") - } - - // cleanup & attempt again - - err = os.Remove(socketPath) - if err != nil { - return nil, err - } - return net.ListenUnix("unix", addr) - }() - - if err != nil { - return nil, err - } - - return listener.File() -} diff --git a/src/uapi_windows.go b/src/uapi_windows.go deleted file mode 100644 index a4599a5..0000000 --- a/src/uapi_windows.go +++ /dev/null @@ -1,44 +0,0 @@ -package main - -/* UAPI on windows uses a bidirectional named pipe - */ - -import ( - "fmt" - "github.com/Microsoft/go-winio" - "golang.org/x/sys/windows" - "net" -) - -const ( - ipcErrorIO = -int64(windows.ERROR_BROKEN_PIPE) - ipcErrorProtocol = -int64(windows.ERROR_INVALID_NAME) - ipcErrorInvalid = -int64(windows.ERROR_INVALID_PARAMETER) - ipcErrorPortInUse = -int64(windows.ERROR_ALREADY_EXISTS) -) - -const PipeNameFmt = "\\\\.\\pipe\\wireguard-ipc-%s" - -type UAPIListener struct { - listener net.Listener -} - -func (uapi *UAPIListener) Accept() (net.Conn, error) { - return nil, nil -} - -func (uapi *UAPIListener) Close() error { - return uapi.listener.Close() -} - -func (uapi *UAPIListener) Addr() net.Addr { - return nil -} - -func NewUAPIListener(name string) (net.Listener, error) { - path := fmt.Sprintf(PipeNameFmt, name) - return winio.ListenPipe(path, &winio.PipeConfig{ - InputBufferSize: 2048, - OutputBufferSize: 2048, - }) -} diff --git a/src/xchacha20.go b/src/xchacha20.go deleted file mode 100644 index 5d963e0..0000000 --- a/src/xchacha20.go +++ /dev/null @@ -1,169 +0,0 @@ -// Copyright (c) 2016 Andreas Auernhammer. All rights reserved. -// Use of this source code is governed by a license that can be -// found in the LICENSE file. - -package main - -import ( - "encoding/binary" - "golang.org/x/crypto/chacha20poly1305" -) - -func HChaCha20(out *[32]byte, nonce []byte, key *[32]byte) { - - v00 := uint32(0x61707865) - v01 := uint32(0x3320646e) - v02 := uint32(0x79622d32) - v03 := uint32(0x6b206574) - - v04 := binary.LittleEndian.Uint32(key[0:]) - v05 := binary.LittleEndian.Uint32(key[4:]) - v06 := binary.LittleEndian.Uint32(key[8:]) - v07 := binary.LittleEndian.Uint32(key[12:]) - v08 := binary.LittleEndian.Uint32(key[16:]) - v09 := binary.LittleEndian.Uint32(key[20:]) - v10 := binary.LittleEndian.Uint32(key[24:]) - v11 := binary.LittleEndian.Uint32(key[28:]) - v12 := binary.LittleEndian.Uint32(nonce[0:]) - v13 := binary.LittleEndian.Uint32(nonce[4:]) - v14 := binary.LittleEndian.Uint32(nonce[8:]) - v15 := binary.LittleEndian.Uint32(nonce[12:]) - - for i := 0; i < 20; i += 2 { - v00 += v04 - v12 ^= v00 - v12 = (v12 << 16) | (v12 >> 16) - v08 += v12 - v04 ^= v08 - v04 = (v04 << 12) | (v04 >> 20) - v00 += v04 - v12 ^= v00 - v12 = (v12 << 8) | (v12 >> 24) - v08 += v12 - v04 ^= v08 - v04 = (v04 << 7) | (v04 >> 25) - v01 += v05 - v13 ^= v01 - v13 = (v13 << 16) | (v13 >> 16) - v09 += v13 - v05 ^= v09 - v05 = (v05 << 12) | (v05 >> 20) - v01 += v05 - v13 ^= v01 - v13 = (v13 << 8) | (v13 >> 24) - v09 += v13 - v05 ^= v09 - v05 = (v05 << 7) | (v05 >> 25) - v02 += v06 - v14 ^= v02 - v14 = (v14 << 16) | (v14 >> 16) - v10 += v14 - v06 ^= v10 - v06 = (v06 << 12) | (v06 >> 20) - v02 += v06 - v14 ^= v02 - v14 = (v14 << 8) | (v14 >> 24) - v10 += v14 - v06 ^= v10 - v06 = (v06 << 7) | (v06 >> 25) - v03 += v07 - v15 ^= v03 - v15 = (v15 << 16) | (v15 >> 16) - v11 += v15 - v07 ^= v11 - v07 = (v07 << 12) | (v07 >> 20) - v03 += v07 - v15 ^= v03 - v15 = (v15 << 8) | (v15 >> 24) - v11 += v15 - v07 ^= v11 - v07 = (v07 << 7) | (v07 >> 25) - v00 += v05 - v15 ^= v00 - v15 = (v15 << 16) | (v15 >> 16) - v10 += v15 - v05 ^= v10 - v05 = (v05 << 12) | (v05 >> 20) - v00 += v05 - v15 ^= v00 - v15 = (v15 << 8) | (v15 >> 24) - v10 += v15 - v05 ^= v10 - v05 = (v05 << 7) | (v05 >> 25) - v01 += v06 - v12 ^= v01 - v12 = (v12 << 16) | (v12 >> 16) - v11 += v12 - v06 ^= v11 - v06 = (v06 << 12) | (v06 >> 20) - v01 += v06 - v12 ^= v01 - v12 = (v12 << 8) | (v12 >> 24) - v11 += v12 - v06 ^= v11 - v06 = (v06 << 7) | (v06 >> 25) - v02 += v07 - v13 ^= v02 - v13 = (v13 << 16) | (v13 >> 16) - v08 += v13 - v07 ^= v08 - v07 = (v07 << 12) | (v07 >> 20) - v02 += v07 - v13 ^= v02 - v13 = (v13 << 8) | (v13 >> 24) - v08 += v13 - v07 ^= v08 - v07 = (v07 << 7) | (v07 >> 25) - v03 += v04 - v14 ^= v03 - v14 = (v14 << 16) | (v14 >> 16) - v09 += v14 - v04 ^= v09 - v04 = (v04 << 12) | (v04 >> 20) - v03 += v04 - v14 ^= v03 - v14 = (v14 << 8) | (v14 >> 24) - v09 += v14 - v04 ^= v09 - v04 = (v04 << 7) | (v04 >> 25) - } - - binary.LittleEndian.PutUint32(out[0:], v00) - binary.LittleEndian.PutUint32(out[4:], v01) - binary.LittleEndian.PutUint32(out[8:], v02) - binary.LittleEndian.PutUint32(out[12:], v03) - binary.LittleEndian.PutUint32(out[16:], v12) - binary.LittleEndian.PutUint32(out[20:], v13) - binary.LittleEndian.PutUint32(out[24:], v14) - binary.LittleEndian.PutUint32(out[28:], v15) -} - -func XChaCha20Poly1305Encrypt( - dst []byte, - nonceFull *[24]byte, - plaintext []byte, - additionalData []byte, - key *[chacha20poly1305.KeySize]byte, -) []byte { - var nonce [chacha20poly1305.NonceSize]byte - var derivedKey [chacha20poly1305.KeySize]byte - HChaCha20(&derivedKey, nonceFull[:16], key) - aead, _ := chacha20poly1305.New(derivedKey[:]) - copy(nonce[4:], nonceFull[16:]) - return aead.Seal(dst, nonce[:], plaintext, additionalData) -} - -func XChaCha20Poly1305Decrypt( - dst []byte, - nonceFull *[24]byte, - plaintext []byte, - additionalData []byte, - key *[chacha20poly1305.KeySize]byte, -) ([]byte, error) { - var nonce [chacha20poly1305.NonceSize]byte - var derivedKey [chacha20poly1305.KeySize]byte - HChaCha20(&derivedKey, nonceFull[:16], key) - aead, _ := chacha20poly1305.New(derivedKey[:]) - copy(nonce[4:], nonceFull[16:]) - return aead.Open(dst, nonce[:], plaintext, additionalData) -} diff --git a/src/xchacha20_test.go b/src/xchacha20_test.go deleted file mode 100644 index 0f41cf8..0000000 --- a/src/xchacha20_test.go +++ /dev/null @@ -1,96 +0,0 @@ -package main - -import ( - "encoding/hex" - "testing" -) - -type XChaCha20Test struct { - Nonce string - Key string - PT string - CT string -} - -func TestXChaCha20(t *testing.T) { - - tests := []XChaCha20Test{ - { - Nonce: "000000000000000000000000000000000000000000000000", - Key: "0000000000000000000000000000000000000000000000000000000000000000", - PT: "00000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000", - CT: "789e9689e5208d7fd9e1f3c5b5341f48ef18a13e418998addadd97a3693a987f8e82ecd5c1433bfed1af49750c0f1ff29c4174a05b119aa3a9e8333812e0c0feb1299c5949d895ee01dbf50f8395dd84", - }, - { - Nonce: "0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f", - Key: "0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f", - PT: "0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f", - CT: "e1a046aa7f71e2af8b80b6408b2fd8d3a350278cde79c94d9efaa475e1339b3dd490127b", - }, - { - Nonce: "d9a8213e8a697508805c2c171ad54487ead9e3e02d82d5bc", - Key: "979196dbd78526f2f584f7534db3f5824d8ccfa858ca7e09bdd3656ecd36033c", - PT: "43cc6d624e451bbed952c3e071dc6c03392ce11eb14316a94b2fdc98b22fedea", - CT: "53c1e8bef2dbb8f2505ec010a7afe21d5a8e6dd8f987e4ea1a2ed5dfbc844ea400db34496fd2153526c6e87c36694200", - }, - } - - for _, test := range tests { - - nonce, err := hex.DecodeString(test.Nonce) - if err != nil { - panic(err) - } - - key, err := hex.DecodeString(test.Key) - if err != nil { - panic(err) - } - - pt, err := hex.DecodeString(test.PT) - if err != nil { - panic(err) - } - - func() { - var nonceArray [24]byte - var keyArray [32]byte - copy(nonceArray[:], nonce) - copy(keyArray[:], key) - - // test encryption - - ct := XChaCha20Poly1305Encrypt( - nil, - &nonceArray, - pt, - nil, - &keyArray, - ) - ctHex := hex.EncodeToString(ct) - if ctHex != test.CT { - t.Fatal("encryption failed, expected:", test.CT, "got", ctHex) - } - - // test decryption - - ptp, err := XChaCha20Poly1305Decrypt( - nil, - &nonceArray, - ct, - nil, - &keyArray, - ) - if err != nil { - t.Fatal(err) - } - - ptHex := hex.EncodeToString(ptp) - if ptHex != test.PT { - t.Fatal("decryption failed, expected:", test.PT, "got", ptHex) - } - }() - - } - -} |