summaryrefslogtreecommitdiffhomepage
path: root/conn/bind_std.go
diff options
context:
space:
mode:
authorJason A. Donenfeld <Jason@zx2c4.com>2021-02-22 02:01:50 +0100
committerJason A. Donenfeld <Jason@zx2c4.com>2021-02-23 20:00:57 +0100
commita4f8e83d5d9f477554971e90e9ab85922f506ea9 (patch)
tree5249ac2dbdc8cbb6a7d2d40814b07d7d1f38ad4d /conn/bind_std.go
parentc69481f1b3b4b37b9c16f997a5d8d91367d9bfee (diff)
conn: make binds replacable
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
Diffstat (limited to 'conn/bind_std.go')
-rw-r--r--conn/bind_std.go180
1 files changed, 180 insertions, 0 deletions
diff --git a/conn/bind_std.go b/conn/bind_std.go
new file mode 100644
index 0000000..193c4fe
--- /dev/null
+++ b/conn/bind_std.go
@@ -0,0 +1,180 @@
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
+ */
+
+package conn
+
+import (
+ "errors"
+ "net"
+ "syscall"
+)
+
+// StdNetBind is meant to be a temporary solution on platforms for which
+// the sticky socket / source caching behavior has not yet been implemented.
+// It uses the Go's net package to implement networking.
+// See LinuxSocketBind for a proper implementation on the Linux platform.
+type StdNetBind struct {
+ ipv4 *net.UDPConn
+ ipv6 *net.UDPConn
+ blackhole4 bool
+ blackhole6 bool
+}
+
+func NewStdNetBind() Bind { return &StdNetBind{} }
+
+type StdNetEndpoint net.UDPAddr
+
+var _ Bind = (*StdNetBind)(nil)
+var _ Endpoint = (*StdNetEndpoint)(nil)
+
+func (*StdNetBind) ParseEndpoint(s string) (Endpoint, error) {
+ addr, err := parseEndpoint(s)
+ return (*StdNetEndpoint)(addr), err
+}
+
+func (*StdNetEndpoint) ClearSrc() {}
+
+func (e *StdNetEndpoint) DstIP() net.IP {
+ return (*net.UDPAddr)(e).IP
+}
+
+func (e *StdNetEndpoint) SrcIP() net.IP {
+ return nil // not supported
+}
+
+func (e *StdNetEndpoint) DstToBytes() []byte {
+ addr := (*net.UDPAddr)(e)
+ out := addr.IP.To4()
+ if out == nil {
+ out = addr.IP
+ }
+ out = append(out, byte(addr.Port&0xff))
+ out = append(out, byte((addr.Port>>8)&0xff))
+ return out
+}
+
+func (e *StdNetEndpoint) DstToString() string {
+ return (*net.UDPAddr)(e).String()
+}
+
+func (e *StdNetEndpoint) SrcToString() string {
+ return ""
+}
+
+func listenNet(network string, port int) (*net.UDPConn, int, error) {
+ conn, err := net.ListenUDP(network, &net.UDPAddr{Port: port})
+ if err != nil {
+ return nil, 0, err
+ }
+
+ // Retrieve port.
+ laddr := conn.LocalAddr()
+ uaddr, err := net.ResolveUDPAddr(
+ laddr.Network(),
+ laddr.String(),
+ )
+ if err != nil {
+ return nil, 0, err
+ }
+ return conn, uaddr.Port, nil
+}
+
+func (bind *StdNetBind) Open(uport uint16) (uint16, error) {
+ var err error
+ var tries int
+
+ if bind.ipv4 != nil || bind.ipv6 != nil {
+ return 0, ErrBindAlreadyOpen
+ }
+
+again:
+ port := int(uport)
+
+ bind.ipv4, port, err = listenNet("udp4", port)
+ if err != nil && !errors.Is(err, syscall.EAFNOSUPPORT) {
+ bind.ipv4 = nil
+ return 0, err
+ }
+
+ bind.ipv6, port, err = listenNet("udp6", port)
+ if uport == 0 && err != nil && errors.Is(err, syscall.EADDRINUSE) && tries < 100 {
+ bind.ipv4.Close()
+ bind.ipv4 = nil
+ bind.ipv6 = nil
+ tries++
+ goto again
+ }
+ if err != nil && !errors.Is(err, syscall.EAFNOSUPPORT) {
+ bind.ipv4.Close()
+ bind.ipv4 = nil
+ bind.ipv6 = nil
+ return 0, err
+ }
+ if bind.ipv4 == nil && bind.ipv6 == nil {
+ return 0, syscall.EAFNOSUPPORT
+ }
+ return uint16(port), nil
+}
+
+func (bind *StdNetBind) Close() error {
+ var err1, err2 error
+ if bind.ipv4 != nil {
+ err1 = bind.ipv4.Close()
+ bind.ipv4 = nil
+ }
+ if bind.ipv6 != nil {
+ err2 = bind.ipv6.Close()
+ bind.ipv6 = nil
+ }
+ if err1 != nil {
+ return err1
+ }
+ return err2
+}
+
+func (bind *StdNetBind) ReceiveIPv4(buff []byte) (int, Endpoint, error) {
+ if bind.ipv4 == nil {
+ return 0, nil, syscall.EAFNOSUPPORT
+ }
+ n, endpoint, err := bind.ipv4.ReadFromUDP(buff)
+ if endpoint != nil {
+ endpoint.IP = endpoint.IP.To4()
+ }
+ return n, (*StdNetEndpoint)(endpoint), err
+}
+
+func (bind *StdNetBind) ReceiveIPv6(buff []byte) (int, Endpoint, error) {
+ if bind.ipv6 == nil {
+ return 0, nil, syscall.EAFNOSUPPORT
+ }
+ n, endpoint, err := bind.ipv6.ReadFromUDP(buff)
+ return n, (*StdNetEndpoint)(endpoint), err
+}
+
+func (bind *StdNetBind) Send(buff []byte, endpoint Endpoint) error {
+ var err error
+ nend, ok := endpoint.(*StdNetEndpoint)
+ if !ok {
+ return ErrWrongEndpointType
+ }
+ if nend.IP.To4() != nil {
+ if bind.ipv4 == nil {
+ return syscall.EAFNOSUPPORT
+ }
+ if bind.blackhole4 {
+ return nil
+ }
+ _, err = bind.ipv4.WriteToUDP(buff, (*net.UDPAddr)(nend))
+ } else {
+ if bind.ipv6 == nil {
+ return syscall.EAFNOSUPPORT
+ }
+ if bind.blackhole6 {
+ return nil
+ }
+ _, err = bind.ipv6.WriteToUDP(buff, (*net.UDPAddr)(nend))
+ }
+ return err
+}