summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--conn/bind_linux.go50
-rw-r--r--conn/bind_std.go19
-rw-r--r--conn/bind_windows.go19
-rw-r--r--conn/bindtest/bindtest.go14
-rw-r--r--conn/conn.go37
-rw-r--r--device/allowedips.go42
-rw-r--r--device/allowedips_rand_test.go6
-rw-r--r--device/allowedips_test.go6
-rw-r--r--device/device_test.go6
-rw-r--r--device/endpoint_test.go39
-rw-r--r--device/receive.go1
-rw-r--r--device/uapi.go11
-rw-r--r--go.mod7
-rw-r--r--go.sum13
-rw-r--r--ratelimiter/ratelimiter.go58
-rw-r--r--ratelimiter/ratelimiter_test.go33
-rw-r--r--tun/netstack/examples/http_client.go7
-rw-r--r--tun/netstack/examples/http_server.go6
-rw-r--r--tun/netstack/go.mod1
-rw-r--r--tun/netstack/go.sum4
-rw-r--r--tun/netstack/tun.go143
-rw-r--r--tun/tuntest/tuntest.go10
22 files changed, 247 insertions, 285 deletions
diff --git a/conn/bind_linux.go b/conn/bind_linux.go
index 7b970e6..da0670a 100644
--- a/conn/bind_linux.go
+++ b/conn/bind_linux.go
@@ -14,6 +14,7 @@ import (
"unsafe"
"golang.org/x/sys/unix"
+ "golang.zx2c4.com/go118/netip"
)
type ipv4Source struct {
@@ -70,32 +71,30 @@ var _ Bind = (*LinuxSocketBind)(nil)
func (*LinuxSocketBind) ParseEndpoint(s string) (Endpoint, error) {
var end LinuxSocketEndpoint
- addr, err := parseEndpoint(s)
+ e, err := netip.ParseAddrPort(s)
if err != nil {
return nil, err
}
- ipv4 := addr.IP.To4()
- if ipv4 != nil {
+ if e.Addr().Is4() {
dst := end.dst4()
end.isV6 = false
- dst.Port = addr.Port
- copy(dst.Addr[:], ipv4)
+ dst.Port = int(e.Port())
+ dst.Addr = e.Addr().As4()
end.ClearSrc()
return &end, nil
}
- ipv6 := addr.IP.To16()
- if ipv6 != nil {
- zone, err := zoneToUint32(addr.Zone)
+ if e.Addr().Is6() {
+ zone, err := zoneToUint32(e.Addr().Zone())
if err != nil {
return nil, err
}
dst := end.dst6()
end.isV6 = true
- dst.Port = addr.Port
+ dst.Port = int(e.Port())
dst.ZoneId = zone
- copy(dst.Addr[:], ipv6[:])
+ dst.Addr = e.Addr().As16()
end.ClearSrc()
return &end, nil
}
@@ -266,29 +265,19 @@ func (bind *LinuxSocketBind) Send(buff []byte, end Endpoint) error {
}
}
-func (end *LinuxSocketEndpoint) SrcIP() net.IP {
+func (end *LinuxSocketEndpoint) SrcIP() netip.Addr {
if !end.isV6 {
- return net.IPv4(
- end.src4().Src[0],
- end.src4().Src[1],
- end.src4().Src[2],
- end.src4().Src[3],
- )
+ return netip.AddrFrom4(end.src4().Src)
} else {
- return end.src6().src[:]
+ return netip.AddrFrom16(end.src6().src)
}
}
-func (end *LinuxSocketEndpoint) DstIP() net.IP {
+func (end *LinuxSocketEndpoint) DstIP() netip.Addr {
if !end.isV6 {
- return net.IPv4(
- end.dst4().Addr[0],
- end.dst4().Addr[1],
- end.dst4().Addr[2],
- end.dst4().Addr[3],
- )
+ return netip.AddrFrom4(end.dst4().Addr)
} else {
- return end.dst6().Addr[:]
+ return netip.AddrFrom16(end.dst6().Addr)
}
}
@@ -305,14 +294,13 @@ func (end *LinuxSocketEndpoint) SrcToString() string {
}
func (end *LinuxSocketEndpoint) DstToString() string {
- var udpAddr net.UDPAddr
- udpAddr.IP = end.DstIP()
+ var port int
if !end.isV6 {
- udpAddr.Port = end.dst4().Port
+ port = end.dst4().Port
} else {
- udpAddr.Port = end.dst6().Port
+ port = end.dst6().Port
}
- return udpAddr.String()
+ return netip.AddrPortFrom(end.DstIP(), uint16(port)).String()
}
func (end *LinuxSocketEndpoint) ClearDst() {
diff --git a/conn/bind_std.go b/conn/bind_std.go
index cb85cfd..a3cbb15 100644
--- a/conn/bind_std.go
+++ b/conn/bind_std.go
@@ -10,6 +10,8 @@ import (
"net"
"sync"
"syscall"
+
+ "golang.zx2c4.com/go118/netip"
)
// StdNetBind is meant to be a temporary solution on platforms for which
@@ -32,18 +34,23 @@ var _ Bind = (*StdNetBind)(nil)
var _ Endpoint = (*StdNetEndpoint)(nil)
func (*StdNetBind) ParseEndpoint(s string) (Endpoint, error) {
- addr, err := parseEndpoint(s)
- return (*StdNetEndpoint)(addr), err
+ e, err := netip.ParseAddrPort(s)
+ return (*StdNetEndpoint)(&net.UDPAddr{
+ IP: e.Addr().AsSlice(),
+ Port: int(e.Port()),
+ Zone: e.Addr().Zone(),
+ }), err
}
func (*StdNetEndpoint) ClearSrc() {}
-func (e *StdNetEndpoint) DstIP() net.IP {
- return (*net.UDPAddr)(e).IP
+func (e *StdNetEndpoint) DstIP() netip.Addr {
+ a, _ := netip.AddrFromSlice((*net.UDPAddr)(e).IP)
+ return a
}
-func (e *StdNetEndpoint) SrcIP() net.IP {
- return nil // not supported
+func (e *StdNetEndpoint) SrcIP() netip.Addr {
+ return netip.Addr{} // not supported
}
func (e *StdNetEndpoint) DstToBytes() []byte {
diff --git a/conn/bind_windows.go b/conn/bind_windows.go
index 42e06ad..26a3af8 100644
--- a/conn/bind_windows.go
+++ b/conn/bind_windows.go
@@ -15,6 +15,7 @@ import (
"unsafe"
"golang.org/x/sys/windows"
+ "golang.zx2c4.com/go118/netip"
"golang.zx2c4.com/wireguard/conn/winrio"
)
@@ -128,18 +129,18 @@ func (*WinRingBind) ParseEndpoint(s string) (Endpoint, error) {
func (*WinRingEndpoint) ClearSrc() {}
-func (e *WinRingEndpoint) DstIP() net.IP {
+func (e *WinRingEndpoint) DstIP() netip.Addr {
switch e.family {
case windows.AF_INET:
- return append([]byte{}, e.data[2:6]...)
+ return netip.AddrFrom4(*(*[4]byte)(e.data[2:6]))
case windows.AF_INET6:
- return append([]byte{}, e.data[6:22]...)
+ return netip.AddrFrom16(*(*[16]byte)(e.data[6:22]))
}
- return nil
+ return netip.Addr{}
}
-func (e *WinRingEndpoint) SrcIP() net.IP {
- return nil // not supported
+func (e *WinRingEndpoint) SrcIP() netip.Addr {
+ return netip.Addr{} // not supported
}
func (e *WinRingEndpoint) DstToBytes() []byte {
@@ -161,15 +162,13 @@ func (e *WinRingEndpoint) DstToBytes() []byte {
func (e *WinRingEndpoint) DstToString() string {
switch e.family {
case windows.AF_INET:
- addr := net.UDPAddr{IP: e.data[2:6], Port: int(binary.BigEndian.Uint16(e.data[0:2]))}
- return addr.String()
+ netip.AddrPortFrom(netip.AddrFrom4(*(*[4]byte)(e.data[2:6])), binary.BigEndian.Uint16(e.data[0:2])).String()
case windows.AF_INET6:
var zone string
if scope := *(*uint32)(unsafe.Pointer(&e.data[22])); scope > 0 {
zone = strconv.FormatUint(uint64(scope), 10)
}
- addr := net.UDPAddr{IP: e.data[6:22], Zone: zone, Port: int(binary.BigEndian.Uint16(e.data[0:2]))}
- return addr.String()
+ return netip.AddrPortFrom(netip.AddrFrom16(*(*[16]byte)(e.data[6:22])).WithZone(zone), binary.BigEndian.Uint16(e.data[0:2])).String()
}
return ""
}
diff --git a/conn/bindtest/bindtest.go b/conn/bindtest/bindtest.go
index 7d43fb3..6a45896 100644
--- a/conn/bindtest/bindtest.go
+++ b/conn/bindtest/bindtest.go
@@ -10,8 +10,8 @@ import (
"math/rand"
"net"
"os"
- "strconv"
+ "golang.zx2c4.com/go118/netip"
"golang.zx2c4.com/wireguard/conn"
)
@@ -61,9 +61,9 @@ func (c ChannelEndpoint) DstToString() string { return fmt.Sprintf("127.0.0.1:%d
func (c ChannelEndpoint) DstToBytes() []byte { return []byte{byte(c)} }
-func (c ChannelEndpoint) DstIP() net.IP { return net.IPv4(127, 0, 0, 1) }
+func (c ChannelEndpoint) DstIP() netip.Addr { return netip.AddrFrom4([4]byte{127, 0, 0, 1}) }
-func (c ChannelEndpoint) SrcIP() net.IP { return nil }
+func (c ChannelEndpoint) SrcIP() netip.Addr { return netip.Addr{} }
func (c *ChannelBind) Open(port uint16) (fns []conn.ReceiveFunc, actualPort uint16, err error) {
c.closeSignal = make(chan bool)
@@ -119,13 +119,9 @@ func (c *ChannelBind) Send(b []byte, ep conn.Endpoint) error {
}
func (c *ChannelBind) ParseEndpoint(s string) (conn.Endpoint, error) {
- _, port, err := net.SplitHostPort(s)
+ addr, err := netip.ParseAddrPort(s)
if err != nil {
return nil, err
}
- i, err := strconv.ParseUint(port, 10, 16)
- if err != nil {
- return nil, err
- }
- return ChannelEndpoint(i), nil
+ return ChannelEndpoint(addr.Port()), nil
}
diff --git a/conn/conn.go b/conn/conn.go
index 9cce9ad..35fb6b1 100644
--- a/conn/conn.go
+++ b/conn/conn.go
@@ -9,10 +9,11 @@ package conn
import (
"errors"
"fmt"
- "net"
"reflect"
"runtime"
"strings"
+
+ "golang.zx2c4.com/go118/netip"
)
// A ReceiveFunc receives a single inbound packet from the network.
@@ -68,8 +69,8 @@ type Endpoint interface {
SrcToString() string // returns the local source address (ip:port)
DstToString() string // returns the destination address (ip:port)
DstToBytes() []byte // used for mac2 cookie calculations
- DstIP() net.IP
- SrcIP() net.IP
+ DstIP() netip.Addr
+ SrcIP() netip.Addr
}
var (
@@ -119,33 +120,3 @@ func (fn ReceiveFunc) PrettyName() string {
}
return name
}
-
-func parseEndpoint(s string) (*net.UDPAddr, error) {
- // ensure that the host is an IP address
-
- host, _, err := net.SplitHostPort(s)
- if err != nil {
- return nil, err
- }
- if i := strings.LastIndexByte(host, '%'); i > 0 && strings.IndexByte(host, ':') >= 0 {
- // Remove the scope, if any. ResolveUDPAddr below will use it, but here we're just
- // trying to make sure with a small sanity test that this is a real IP address and
- // not something that's likely to incur DNS lookups.
- host = host[:i]
- }
- if ip := net.ParseIP(host); ip == nil {
- return nil, errors.New("Failed to parse IP address: " + host)
- }
-
- // parse address and port
-
- addr, err := net.ResolveUDPAddr("udp", s)
- if err != nil {
- return nil, err
- }
- ip4 := addr.IP.To4()
- if ip4 != nil {
- addr.IP = ip4
- }
- return addr, err
-}
diff --git a/device/allowedips.go b/device/allowedips.go
index c08399b..7a0b275 100644
--- a/device/allowedips.go
+++ b/device/allowedips.go
@@ -12,6 +12,8 @@ import (
"net"
"sync"
"unsafe"
+
+ "golang.zx2c4.com/go118/netip"
)
type parentIndirection struct {
@@ -26,7 +28,7 @@ type trieEntry struct {
cidr uint8
bitAtByte uint8
bitAtShift uint8
- bits net.IP
+ bits []byte
perPeerElem *list.Element
}
@@ -51,7 +53,7 @@ func swapU64(i uint64) uint64 {
return bits.ReverseBytes64(i)
}
-func commonBits(ip1 net.IP, ip2 net.IP) uint8 {
+func commonBits(ip1, ip2 []byte) uint8 {
size := len(ip1)
if size == net.IPv4len {
a := (*uint32)(unsafe.Pointer(&ip1[0]))
@@ -85,7 +87,7 @@ func (node *trieEntry) removeFromPeerEntries() {
}
}
-func (node *trieEntry) choose(ip net.IP) byte {
+func (node *trieEntry) choose(ip []byte) byte {
return (ip[node.bitAtByte] >> node.bitAtShift) & 1
}
@@ -104,7 +106,7 @@ func (node *trieEntry) zeroizePointers() {
node.parent.parentBit = nil
}
-func (node *trieEntry) nodePlacement(ip net.IP, cidr uint8) (parent *trieEntry, exact bool) {
+func (node *trieEntry) nodePlacement(ip []byte, cidr uint8) (parent *trieEntry, exact bool) {
for node != nil && node.cidr <= cidr && commonBits(node.bits, ip) >= node.cidr {
parent = node
if parent.cidr == cidr {
@@ -117,7 +119,7 @@ func (node *trieEntry) nodePlacement(ip net.IP, cidr uint8) (parent *trieEntry,
return
}
-func (trie parentIndirection) insert(ip net.IP, cidr uint8, peer *Peer) {
+func (trie parentIndirection) insert(ip []byte, cidr uint8, peer *Peer) {
if *trie.parentBit == nil {
node := &trieEntry{
peer: peer,
@@ -207,7 +209,7 @@ func (trie parentIndirection) insert(ip net.IP, cidr uint8, peer *Peer) {
}
}
-func (node *trieEntry) lookup(ip net.IP) *Peer {
+func (node *trieEntry) lookup(ip []byte) *Peer {
var found *Peer
size := uint8(len(ip))
for node != nil && commonBits(node.bits, ip) >= node.cidr {
@@ -229,13 +231,14 @@ type AllowedIPs struct {
mutex sync.RWMutex
}
-func (table *AllowedIPs) EntriesForPeer(peer *Peer, cb func(ip net.IP, cidr uint8) bool) {
+func (table *AllowedIPs) EntriesForPeer(peer *Peer, cb func(prefix netip.Prefix) bool) {
table.mutex.RLock()
defer table.mutex.RUnlock()
for elem := peer.trieEntries.Front(); elem != nil; elem = elem.Next() {
node := elem.Value.(*trieEntry)
- if !cb(node.bits, node.cidr) {
+ a, _ := netip.AddrFromSlice(node.bits)
+ if !cb(netip.PrefixFrom(a, int(node.cidr))) {
return
}
}
@@ -283,28 +286,29 @@ func (table *AllowedIPs) RemoveByPeer(peer *Peer) {
}
}
-func (table *AllowedIPs) Insert(ip net.IP, cidr uint8, peer *Peer) {
+func (table *AllowedIPs) Insert(prefix netip.Prefix, peer *Peer) {
table.mutex.Lock()
defer table.mutex.Unlock()
- switch len(ip) {
- case net.IPv6len:
- parentIndirection{&table.IPv6, 2}.insert(ip, cidr, peer)
- case net.IPv4len:
- parentIndirection{&table.IPv4, 2}.insert(ip, cidr, peer)
- default:
+ if prefix.Addr().Is6() {
+ ip := prefix.Addr().As16()
+ parentIndirection{&table.IPv6, 2}.insert(ip[:], uint8(prefix.Bits()), peer)
+ } else if prefix.Addr().Is4() {
+ ip := prefix.Addr().As4()
+ parentIndirection{&table.IPv4, 2}.insert(ip[:], uint8(prefix.Bits()), peer)
+ } else {
panic(errors.New("inserting unknown address type"))
}
}
-func (table *AllowedIPs) Lookup(address []byte) *Peer {
+func (table *AllowedIPs) Lookup(ip []byte) *Peer {
table.mutex.RLock()
defer table.mutex.RUnlock()
- switch len(address) {
+ switch len(ip) {
case net.IPv6len:
- return table.IPv6.lookup(address)
+ return table.IPv6.lookup(ip)
case net.IPv4len:
- return table.IPv4.lookup(address)
+ return table.IPv4.lookup(ip)
default:
panic(errors.New("looking up unknown address type"))
}
diff --git a/device/allowedips_rand_test.go b/device/allowedips_rand_test.go
index 16de170..ff56fe6 100644
--- a/device/allowedips_rand_test.go
+++ b/device/allowedips_rand_test.go
@@ -10,6 +10,8 @@ import (
"net"
"sort"
"testing"
+
+ "golang.zx2c4.com/go118/netip"
)
const (
@@ -93,14 +95,14 @@ func TestTrieRandom(t *testing.T) {
rand.Read(addr4[:])
cidr := uint8(rand.Intn(32) + 1)
index := rand.Intn(NumberOfPeers)
- allowedIPs.Insert(addr4[:], cidr, peers[index])
+ allowedIPs.Insert(netip.PrefixFrom(netip.AddrFrom4(addr4), int(cidr)), peers[index])
slow4 = slow4.Insert(addr4[:], cidr, peers[index])
var addr6 [16]byte
rand.Read(addr6[:])
cidr = uint8(rand.Intn(128) + 1)
index = rand.Intn(NumberOfPeers)
- allowedIPs.Insert(addr6[:], cidr, peers[index])
+ allowedIPs.Insert(netip.PrefixFrom(netip.AddrFrom16(addr6), int(cidr)), peers[index])
slow6 = slow6.Insert(addr6[:], cidr, peers[index])
}
diff --git a/device/allowedips_test.go b/device/allowedips_test.go
index 2059a88..a274997 100644
--- a/device/allowedips_test.go
+++ b/device/allowedips_test.go
@@ -9,6 +9,8 @@ import (
"math/rand"
"net"
"testing"
+
+ "golang.zx2c4.com/go118/netip"
)
type testPairCommonBits struct {
@@ -98,7 +100,7 @@ func TestTrieIPv4(t *testing.T) {
var allowedIPs AllowedIPs
insert := func(peer *Peer, a, b, c, d byte, cidr uint8) {
- allowedIPs.Insert([]byte{a, b, c, d}, cidr, peer)
+ allowedIPs.Insert(netip.PrefixFrom(netip.AddrFrom4([4]byte{a, b, c, d}), int(cidr)), peer)
}
assertEQ := func(peer *Peer, a, b, c, d byte) {
@@ -208,7 +210,7 @@ func TestTrieIPv6(t *testing.T) {
addr = append(addr, expand(b)...)
addr = append(addr, expand(c)...)
addr = append(addr, expand(d)...)
- allowedIPs.Insert(addr, cidr, peer)
+ allowedIPs.Insert(netip.PrefixFrom(netip.AddrFrom16(*(*[16]byte)(addr)), int(cidr)), peer)
}
assertEQ := func(peer *Peer, a, b, c, d uint32) {
diff --git a/device/device_test.go b/device/device_test.go
index 29daeb9..84221be 100644
--- a/device/device_test.go
+++ b/device/device_test.go
@@ -11,7 +11,6 @@ import (
"fmt"
"io"
"math/rand"
- "net"
"runtime"
"runtime/pprof"
"sync"
@@ -19,6 +18,7 @@ import (
"testing"
"time"
+ "golang.zx2c4.com/go118/netip"
"golang.zx2c4.com/wireguard/conn"
"golang.zx2c4.com/wireguard/conn/bindtest"
"golang.zx2c4.com/wireguard/tun/tuntest"
@@ -96,7 +96,7 @@ type testPair [2]testPeer
type testPeer struct {
tun *tuntest.ChannelTUN
dev *Device
- ip net.IP
+ ip netip.Addr
}
type SendDirection bool
@@ -159,7 +159,7 @@ func genTestPair(tb testing.TB, realSocket bool) (pair testPair) {
for i := range pair {
p := &pair[i]
p.tun = tuntest.NewChannelTUN()
- p.ip = net.IPv4(1, 0, 0, byte(i+1))
+ p.ip = netip.AddrFrom4([4]byte{1, 0, 0, byte(i + 1)})
level := LogLevelVerbose
if _, ok := tb.(*testing.B); ok && !testing.Verbose() {
level = LogLevelError
diff --git a/device/endpoint_test.go b/device/endpoint_test.go
index 57c361c..f1ae47e 100644
--- a/device/endpoint_test.go
+++ b/device/endpoint_test.go
@@ -7,47 +7,44 @@ package device
import (
"math/rand"
- "net"
+
+ "golang.zx2c4.com/go118/netip"
)
type DummyEndpoint struct {
- src [16]byte
- dst [16]byte
+ src, dst netip.Addr
}
func CreateDummyEndpoint() (*DummyEndpoint, error) {
- var end DummyEndpoint
- if _, err := rand.Read(end.src[:]); err != nil {
+ var src, dst [16]byte
+ if _, err := rand.Read(src[:]); err != nil {
return nil, err
}
- _, err := rand.Read(end.dst[:])
- return &end, err
+ _, err := rand.Read(dst[:])
+ return &DummyEndpoint{netip.AddrFrom16(src), netip.AddrFrom16(dst)}, err
}
func (e *DummyEndpoint) ClearSrc() {}
func (e *DummyEndpoint) SrcToString() string {
- var addr net.UDPAddr
- addr.IP = e.SrcIP()
- addr.Port = 1000
- return addr.String()
+ return netip.AddrPortFrom(e.SrcIP(), 1000).String()
}
func (e *DummyEndpoint) DstToString() string {
- var addr net.UDPAddr
- addr.IP = e.DstIP()
- addr.Port = 1000
- return addr.String()
+ return netip.AddrPortFrom(e.DstIP(), 1000).String()
}
-func (e *DummyEndpoint) SrcToBytes() []byte {
- return e.src[:]
+func (e *DummyEndpoint) DstToBytes() []byte {
+ out := e.DstIP().AsSlice()
+ out = append(out, byte(1000&0xff))
+ out = append(out, byte((1000>>8)&0xff))
+ return out
}
-func (e *DummyEndpoint) DstIP() net.IP {
- return e.dst[:]
+func (e *DummyEndpoint) DstIP() netip.Addr {
+ return e.dst
}
-func (e *DummyEndpoint) SrcIP() net.IP {
- return e.src[:]
+func (e *DummyEndpoint) SrcIP() netip.Addr {
+ return e.src
}
diff --git a/device/receive.go b/device/receive.go
index 5857481..cc34498 100644
--- a/device/receive.go
+++ b/device/receive.go
@@ -17,7 +17,6 @@ import (
"golang.org/x/crypto/chacha20poly1305"
"golang.org/x/net/ipv4"
"golang.org/x/net/ipv6"
-
"golang.zx2c4.com/wireguard/conn"
)
diff --git a/device/uapi.go b/device/uapi.go
index 2306183..98e8311 100644
--- a/device/uapi.go
+++ b/device/uapi.go
@@ -18,6 +18,7 @@ import (
"sync/atomic"
"time"
+ "golang.zx2c4.com/go118/netip"
"golang.zx2c4.com/wireguard/ipc"
)
@@ -121,8 +122,8 @@ func (device *Device) IpcGetOperation(w io.Writer) error {
sendf("rx_bytes=%d", atomic.LoadUint64(&peer.stats.rxBytes))
sendf("persistent_keepalive_interval=%d", atomic.LoadUint32(&peer.persistentKeepaliveInterval))
- device.allowedips.EntriesForPeer(peer, func(ip net.IP, cidr uint8) bool {
- sendf("allowed_ip=%s/%d", ip.String(), cidr)
+ device.allowedips.EntriesForPeer(peer, func(prefix netip.Prefix) bool {
+ sendf("allowed_ip=%s", prefix.String())
return true
})
}
@@ -374,16 +375,14 @@ func (device *Device) handlePeerLine(peer *ipcSetPeer, key, value string) error
case "allowed_ip":
device.log.Verbosef("%v - UAPI: Adding allowedip", peer.Peer)
-
- _, network, err := net.ParseCIDR(value)
+ prefix, err := netip.ParsePrefix(value)
if err != nil {
return ipcErrorf(ipc.IpcErrorInvalid, "failed to set allowed ip: %w", err)
}
if peer.dummy {
return nil
}
- ones, _ := network.Mask.Size()
- device.allowedips.Insert(network.IP, uint8(ones), peer.Peer)
+ device.allowedips.Insert(prefix, peer.Peer)
case "protocol_version":
if value != "1" {
diff --git a/go.mod b/go.mod
index 856bb6c..b510960 100644
--- a/go.mod
+++ b/go.mod
@@ -3,8 +3,9 @@ module golang.zx2c4.com/wireguard
go 1.17
require (
- golang.org/x/crypto v0.0.0-20210921155107-089bfa567519
- golang.org/x/net v0.0.0-20211101193420-4a448f8816b3
- golang.org/x/sys v0.0.0-20211103235746-7861aae1554b
+ golang.org/x/crypto v0.0.0-20211108221036-ceb1ce70b4fa
+ golang.org/x/net v0.0.0-20211111083644-e5c967477495
+ golang.org/x/sys v0.0.0-20211110154304-99a53858aa08
+ golang.zx2c4.com/go118/netip v0.0.0-20211111135330-a4a02eeacf9d
golang.zx2c4.com/wintun v0.0.0-20211104114900-415007cec224
)
diff --git a/go.sum b/go.sum
index 37fe067..78f7367 100644
--- a/go.sum
+++ b/go.sum
@@ -1,16 +1,19 @@
-golang.org/x/crypto v0.0.0-20210921155107-089bfa567519 h1:7I4JAnoQBe7ZtJcBaYHi5UtiO8tQHbUSXxL+pnGRANg=
-golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
+golang.org/x/crypto v0.0.0-20211108221036-ceb1ce70b4fa h1:idItI2DDfCokpg0N51B2VtiLdJ4vAuXC9fnCb2gACo4=
+golang.org/x/crypto v0.0.0-20211108221036-ceb1ce70b4fa/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
-golang.org/x/net v0.0.0-20211101193420-4a448f8816b3 h1:VrJZAjbekhoRn7n5FBujY31gboH+iB3pdLxn3gE9FjU=
-golang.org/x/net v0.0.0-20211101193420-4a448f8816b3/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
+golang.org/x/net v0.0.0-20211111083644-e5c967477495 h1:cjxxlQm6d4kYbhpZ2ghvmI8xnq0AG+jXmzrhzfkyu5A=
+golang.org/x/net v0.0.0-20211111083644-e5c967477495/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
-golang.org/x/sys v0.0.0-20211103235746-7861aae1554b h1:1VkfZQv42XQlA/jchYumAnv1UPo6RgF9rJFkTgZIxO4=
golang.org/x/sys v0.0.0-20211103235746-7861aae1554b/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
+golang.org/x/sys v0.0.0-20211110154304-99a53858aa08 h1:WecRHqgE09JBkh/584XIE6PMz5KKE/vER4izNUi30AQ=
+golang.org/x/sys v0.0.0-20211110154304-99a53858aa08/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
+golang.zx2c4.com/go118/netip v0.0.0-20211111135330-a4a02eeacf9d h1:9+v0G0naRhLPOJEeJOL6NuXTtAHHwmkyZlgQJ0XcQ8I=
+golang.zx2c4.com/go118/netip v0.0.0-20211111135330-a4a02eeacf9d/go.mod h1:5yyfuiqVIJ7t+3MqrpTQ+QqRkMWiESiyDvPNvKYCecg=
golang.zx2c4.com/wintun v0.0.0-20211104114900-415007cec224 h1:Ug9qvr1myri/zFN6xL17LSCBGFDnphBBhzmILHsM5TY=
golang.zx2c4.com/wintun v0.0.0-20211104114900-415007cec224/go.mod h1:deeaetjYA+DHMHg+sMSMI58GrEteJUUzzw7en6TJQcI=
diff --git a/ratelimiter/ratelimiter.go b/ratelimiter/ratelimiter.go
index 2f7aa2a..8e78d5e 100644
--- a/ratelimiter/ratelimiter.go
+++ b/ratelimiter/ratelimiter.go
@@ -6,9 +6,10 @@
package ratelimiter
import (
- "net"
"sync"
"time"
+
+ "golang.zx2c4.com/go118/netip"
)
const (
@@ -30,8 +31,7 @@ type Ratelimiter struct {
timeNow func() time.Time
stopReset chan struct{} // send to reset, close to stop
- tableIPv4 map[[net.IPv4len]byte]*RatelimiterEntry
- tableIPv6 map[[net.IPv6len]byte]*RatelimiterEntry
+ table map[netip.Addr]*RatelimiterEntry
}
func (rate *Ratelimiter) Close() {
@@ -57,8 +57,7 @@ func (rate *Ratelimiter) Init() {
}
rate.stopReset = make(chan struct{})
- rate.tableIPv4 = make(map[[net.IPv4len]byte]*RatelimiterEntry)
- rate.tableIPv6 = make(map[[net.IPv6len]byte]*RatelimiterEntry)
+ rate.table = make(map[netip.Addr]*RatelimiterEntry)
stopReset := rate.stopReset // store in case Init is called again.
@@ -87,71 +86,39 @@ func (rate *Ratelimiter) cleanup() (empty bool) {
rate.mu.Lock()
defer rate.mu.Unlock()
- for key, entry := range rate.tableIPv4 {
+ for key, entry := range rate.table {
entry.mu.Lock()
if rate.timeNow().Sub(entry.lastTime) > garbageCollectTime {
- delete(rate.tableIPv4, key)
+ delete(rate.table, key)
}
entry.mu.Unlock()
}
- for key, entry := range rate.tableIPv6 {
- entry.mu.Lock()
- if rate.timeNow().Sub(entry.lastTime) > garbageCollectTime {
- delete(rate.tableIPv6, key)
- }
- entry.mu.Unlock()
- }
-
- return len(rate.tableIPv4) == 0 && len(rate.tableIPv6) == 0
+ return len(rate.table) == 0
}
-func (rate *Ratelimiter) Allow(ip net.IP) bool {
+func (rate *Ratelimiter) Allow(ip netip.Addr) bool {
var entry *RatelimiterEntry
- var keyIPv4 [net.IPv4len]byte
- var keyIPv6 [net.IPv6len]byte
-
// lookup entry
-
- IPv4 := ip.To4()
- IPv6 := ip.To16()
-
rate.mu.RLock()
-
- if IPv4 != nil {
- copy(keyIPv4[:], IPv4)
- entry = rate.tableIPv4[keyIPv4]
- } else {
- copy(keyIPv6[:], IPv6)
- entry = rate.tableIPv6[keyIPv6]
- }
-
+ entry = rate.table[ip]
rate.mu.RUnlock()
// make new entry if not found
-
if entry == nil {
entry = new(RatelimiterEntry)
entry.tokens = maxTokens - packetCost
entry.lastTime = rate.timeNow()
rate.mu.Lock()
- if IPv4 != nil {
- rate.tableIPv4[keyIPv4] = entry
- if len(rate.tableIPv4) == 1 && len(rate.tableIPv6) == 0 {
- rate.stopReset <- struct{}{}
- }
- } else {
- rate.tableIPv6[keyIPv6] = entry
- if len(rate.tableIPv6) == 1 && len(rate.tableIPv4) == 0 {
- rate.stopReset <- struct{}{}
- }
+ rate.table[ip] = entry
+ if len(rate.table) == 1 {
+ rate.stopReset <- struct{}{}
}
rate.mu.Unlock()
return true
}
// add tokens to entry
-
entry.mu.Lock()
now := rate.timeNow()
entry.tokens += now.Sub(entry.lastTime).Nanoseconds()
@@ -161,7 +128,6 @@ func (rate *Ratelimiter) Allow(ip net.IP) bool {
}
// subtract cost of packet
-
if entry.tokens > packetCost {
entry.tokens -= packetCost
entry.mu.Unlock()
diff --git a/ratelimiter/ratelimiter_test.go b/ratelimiter/ratelimiter_test.go
index f231fe5..3e06ff7 100644
--- a/ratelimiter/ratelimiter_test.go
+++ b/ratelimiter/ratelimiter_test.go
@@ -6,9 +6,10 @@
package ratelimiter
import (
- "net"
"testing"
"time"
+
+ "golang.zx2c4.com/go118/netip"
)
type result struct {
@@ -71,21 +72,21 @@ func TestRatelimiter(t *testing.T) {
text: "packet following 2 packet burst",
})
- ips := []net.IP{
- net.ParseIP("127.0.0.1"),
- net.ParseIP("192.168.1.1"),
- net.ParseIP("172.167.2.3"),
- net.ParseIP("97.231.252.215"),
- net.ParseIP("248.97.91.167"),
- net.ParseIP("188.208.233.47"),
- net.ParseIP("104.2.183.179"),
- net.ParseIP("72.129.46.120"),
- net.ParseIP("2001:0db8:0a0b:12f0:0000:0000:0000:0001"),
- net.ParseIP("f5c2:818f:c052:655a:9860:b136:6894:25f0"),
- net.ParseIP("b2d7:15ab:48a7:b07c:a541:f144:a9fe:54fc"),
- net.ParseIP("a47b:786e:1671:a22b:d6f9:4ab0:abc7:c918"),
- net.ParseIP("ea1e:d155:7f7a:98fb:2bf5:9483:80f6:5445"),
- net.ParseIP("3f0e:54a2:f5b4:cd19:a21d:58e1:3746:84c4"),
+ ips := []netip.Addr{
+ netip.MustParseAddr("127.0.0.1"),
+ netip.MustParseAddr("192.168.1.1"),
+ netip.MustParseAddr("172.167.2.3"),
+ netip.MustParseAddr("97.231.252.215"),
+ netip.MustParseAddr("248.97.91.167"),
+ netip.MustParseAddr("188.208.233.47"),
+ netip.MustParseAddr("104.2.183.179"),
+ netip.MustParseAddr("72.129.46.120"),
+ netip.MustParseAddr("2001:0db8:0a0b:12f0:0000:0000:0000:0001"),
+ netip.MustParseAddr("f5c2:818f:c052:655a:9860:b136:6894:25f0"),
+ netip.MustParseAddr("b2d7:15ab:48a7:b07c:a541:f144:a9fe:54fc"),
+ netip.MustParseAddr("a47b:786e:1671:a22b:d6f9:4ab0:abc7:c918"),
+ netip.MustParseAddr("ea1e:d155:7f7a:98fb:2bf5:9483:80f6:5445"),
+ netip.MustParseAddr("3f0e:54a2:f5b4:cd19:a21d:58e1:3746:84c4"),
}
now := time.Now()
diff --git a/tun/netstack/examples/http_client.go b/tun/netstack/examples/http_client.go
index 6ac2859..b39b453 100644
--- a/tun/netstack/examples/http_client.go
+++ b/tun/netstack/examples/http_client.go
@@ -1,4 +1,5 @@
//go:build ignore
+// +build ignore
/* SPDX-License-Identifier: MIT
*
@@ -10,9 +11,9 @@ package main
import (
"io"
"log"
- "net"
"net/http"
+ "golang.zx2c4.com/go118/netip"
"golang.zx2c4.com/wireguard/conn"
"golang.zx2c4.com/wireguard/device"
"golang.zx2c4.com/wireguard/tun/netstack"
@@ -20,8 +21,8 @@ import (
func main() {
tun, tnet, err := netstack.CreateNetTUN(
- []net.IP{net.ParseIP("192.168.4.29")},
- []net.IP{net.ParseIP("8.8.8.8")},
+ []netip.Addr{netip.MustParseAddr("192.168.4.29")},
+ []netip.Addr{netip.MustParseAddr("8.8.8.8")},
1420)
if err != nil {
log.Panic(err)
diff --git a/tun/netstack/examples/http_server.go b/tun/netstack/examples/http_server.go
index 577c6ea..40f7804 100644
--- a/tun/netstack/examples/http_server.go
+++ b/tun/netstack/examples/http_server.go
@@ -1,4 +1,5 @@
//go:build ignore
+// +build ignore
/* SPDX-License-Identifier: MIT
*
@@ -13,6 +14,7 @@ import (
"net"
"net/http"
+ "golang.zx2c4.com/go118/netip"
"golang.zx2c4.com/wireguard/conn"
"golang.zx2c4.com/wireguard/device"
"golang.zx2c4.com/wireguard/tun/netstack"
@@ -20,8 +22,8 @@ import (
func main() {
tun, tnet, err := netstack.CreateNetTUN(
- []net.IP{net.ParseIP("192.168.4.29")},
- []net.IP{net.ParseIP("8.8.8.8"), net.ParseIP("8.8.4.4")},
+ []netip.Addr{netip.MustParseAddr("192.168.4.29")},
+ []netip.Addr{netip.MustParseAddr("8.8.8.8"), netip.MustParseAddr("8.8.4.4")},
1420,
)
if err != nil {
diff --git a/tun/netstack/go.mod b/tun/netstack/go.mod
index 8db9f4b..46b57ba 100644
--- a/tun/netstack/go.mod
+++ b/tun/netstack/go.mod
@@ -6,6 +6,7 @@ require (
golang.org/x/net v0.0.0-20210423184538-5f58ad60dda6
golang.org/x/sys v0.0.0-20210423185535-09eb48e85fd7 // indirect
golang.org/x/time v0.0.0-20210220033141-f8bda1e9f3ba // indirect
+ golang.zx2c4.com/go118/netip v0.0.0-20211105124833-002a02cb0e53
golang.zx2c4.com/wireguard v0.0.0-20210424170727-c9db4b7aaa22
gvisor.dev/gvisor v0.0.0-20211020211948-f76a604701b6
)
diff --git a/tun/netstack/go.sum b/tun/netstack/go.sum
index 78c025c..01bfbc7 100644
--- a/tun/netstack/go.sum
+++ b/tun/netstack/go.sum
@@ -805,6 +805,10 @@ golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8T
golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
+golang.zx2c4.com/go118/netip v0.0.0-20211104120624-f0ae7a6e37c5 h1:mV4w4F7AtWXoDNkko9odoTdWpNwyDh8jx+S1fOZKDLg=
+golang.zx2c4.com/go118/netip v0.0.0-20211104120624-f0ae7a6e37c5/go.mod h1:5yyfuiqVIJ7t+3MqrpTQ+QqRkMWiESiyDvPNvKYCecg=
+golang.zx2c4.com/go118/netip v0.0.0-20211105124833-002a02cb0e53 h1:nFvpdzrHF9IPo9xPgayHWObCATpQYKky8VSSdt9lf9E=
+golang.zx2c4.com/go118/netip v0.0.0-20211105124833-002a02cb0e53/go.mod h1:5yyfuiqVIJ7t+3MqrpTQ+QqRkMWiESiyDvPNvKYCecg=
golang.zx2c4.com/wireguard v0.0.0-20210424170727-c9db4b7aaa22 h1:ytS28bw9HtZVDRMDxviC6ryCJuccw+zXhh04u2IRWJw=
golang.zx2c4.com/wireguard v0.0.0-20210424170727-c9db4b7aaa22/go.mod h1:a057zjmoc00UN7gVkaJt2sXVK523kMJcogDTEvPIasg=
google.golang.org/api v0.4.0/go.mod h1:8k5glujaEP+g9n7WNsDg8QP6cUVNI86fCNMcbazEtwE=
diff --git a/tun/netstack/tun.go b/tun/netstack/tun.go
index 24d0835..f1c03f4 100644
--- a/tun/netstack/tun.go
+++ b/tun/netstack/tun.go
@@ -18,6 +18,7 @@ import (
"strings"
"time"
+ "golang.zx2c4.com/go118/netip"
"golang.zx2c4.com/wireguard/tun"
"golang.org/x/net/dns/dnsmessage"
@@ -38,7 +39,7 @@ type netTun struct {
events chan tun.Event
incomingPacket chan buffer.VectorisedView
mtu int
- dnsServers []net.IP
+ dnsServers []netip.Addr
hasV4, hasV6 bool
}
type endpoint netTun
@@ -94,7 +95,7 @@ func (*endpoint) ARPHardwareType() header.ARPHardwareType {
func (e *endpoint) AddHeader(tcpip.LinkAddress, tcpip.LinkAddress, tcpip.NetworkProtocolNumber, *stack.PacketBuffer) {
}
-func CreateNetTUN(localAddresses, dnsServers []net.IP, mtu int) (tun.Device, *Net, error) {
+func CreateNetTUN(localAddresses, dnsServers []netip.Addr, mtu int) (tun.Device, *Net, error) {
opts := stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol},
TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol, udp.NewProtocol},
@@ -112,25 +113,23 @@ func CreateNetTUN(localAddresses, dnsServers []net.IP, mtu int) (tun.Device, *Ne
return nil, nil, fmt.Errorf("CreateNIC: %v", tcpipErr)
}
for _, ip := range localAddresses {
- if ip4 := ip.To4(); ip4 != nil {
- protoAddr := tcpip.ProtocolAddress{
- Protocol: ipv4.ProtocolNumber,
- AddressWithPrefix: tcpip.Address(ip4).WithPrefix(),
- }
- tcpipErr := dev.stack.AddProtocolAddress(1, protoAddr, stack.AddressProperties{})
- if tcpipErr != nil {
- return nil, nil, fmt.Errorf("AddProtocolAddress(%v): %v", ip4, tcpipErr)
- }
+ var protoNumber tcpip.NetworkProtocolNumber
+ if ip.Is4() {
+ protoNumber = ipv4.ProtocolNumber
+ } else if ip.Is6() {
+ protoNumber = ipv6.ProtocolNumber
+ }
+ protoAddr := tcpip.ProtocolAddress{
+ Protocol: protoNumber,
+ AddressWithPrefix: tcpip.Address(ip.AsSlice()).WithPrefix(),
+ }
+ tcpipErr := dev.stack.AddProtocolAddress(1, protoAddr, stack.AddressProperties{})
+ if tcpipErr != nil {
+ return nil, nil, fmt.Errorf("AddProtocolAddress(%v): %v", ip, tcpipErr)
+ }
+ if ip.Is4() {
dev.hasV4 = true
- } else {
- protoAddr := tcpip.ProtocolAddress{
- Protocol: ipv6.ProtocolNumber,
- AddressWithPrefix: tcpip.Address(ip).WithPrefix(),
- }
- tcpipErr := dev.stack.AddProtocolAddress(1, protoAddr, stack.AddressProperties{})
- if tcpipErr != nil {
- return nil, nil, fmt.Errorf("AddProtocolAddress(%v): %v", ip, tcpipErr)
- }
+ } else if ip.Is6() {
dev.hasV6 = true
}
}
@@ -202,62 +201,83 @@ func (tun *netTun) MTU() (int, error) {
return tun.mtu, nil
}
-func convertToFullAddr(ip net.IP, port int) (tcpip.FullAddress, tcpip.NetworkProtocolNumber) {
- if ip4 := ip.To4(); ip4 != nil {
- return tcpip.FullAddress{
- NIC: 1,
- Addr: tcpip.Address(ip4),
- Port: uint16(port),
- }, ipv4.ProtocolNumber
+func convertToFullAddr(endpoint netip.AddrPort) (tcpip.FullAddress, tcpip.NetworkProtocolNumber) {
+ var protoNumber tcpip.NetworkProtocolNumber
+ if endpoint.Addr().Is4() {
+ protoNumber = ipv4.ProtocolNumber
} else {
- return tcpip.FullAddress{
- NIC: 1,
- Addr: tcpip.Address(ip),
- Port: uint16(port),
- }, ipv6.ProtocolNumber
+ protoNumber = ipv6.ProtocolNumber
}
+ return tcpip.FullAddress{
+ NIC: 1,
+ Addr: tcpip.Address(endpoint.Addr().AsSlice()),
+ Port: endpoint.Port(),
+ }, protoNumber
+}
+
+func (net *Net) DialContextTCPAddrPort(ctx context.Context, addr netip.AddrPort) (*gonet.TCPConn, error) {
+ fa, pn := convertToFullAddr(addr)
+ return gonet.DialContextTCP(ctx, net.stack, fa, pn)
}
func (net *Net) DialContextTCP(ctx context.Context, addr *net.TCPAddr) (*gonet.TCPConn, error) {
if addr == nil {
- panic("todo: deal with auto addr semantics for nil addr")
+ return net.DialContextTCPAddrPort(ctx, netip.AddrPort{})
}
- fa, pn := convertToFullAddr(addr.IP, addr.Port)
- return gonet.DialContextTCP(ctx, net.stack, fa, pn)
+ return net.DialContextTCPAddrPort(ctx, netip.AddrPortFrom(netip.AddrFromSlice(addr.IP), uint16(addr.Port)))
+}
+
+func (net *Net) DialTCPAddrPort(addr netip.AddrPort) (*gonet.TCPConn, error) {
+ fa, pn := convertToFullAddr(addr)
+ return gonet.DialTCP(net.stack, fa, pn)
}
func (net *Net) DialTCP(addr *net.TCPAddr) (*gonet.TCPConn, error) {
if addr == nil {
- panic("todo: deal with auto addr semantics for nil addr")
+ return net.DialTCPAddrPort(netip.AddrPort{})
}
- fa, pn := convertToFullAddr(addr.IP, addr.Port)
- return gonet.DialTCP(net.stack, fa, pn)
+ return net.DialTCPAddrPort(netip.AddrPortFrom(netip.AddrFromSlice(addr.IP), uint16(addr.Port)))
+}
+
+func (net *Net) ListenTCPAddrPort(addr netip.AddrPort) (*gonet.TCPListener, error) {
+ fa, pn := convertToFullAddr(addr)
+ return gonet.ListenTCP(net.stack, fa, pn)
}
func (net *Net) ListenTCP(addr *net.TCPAddr) (*gonet.TCPListener, error) {
if addr == nil {
- panic("todo: deal with auto addr semantics for nil addr")
+ return net.ListenTCPAddrPort(netip.AddrPort{})
}
- fa, pn := convertToFullAddr(addr.IP, addr.Port)
- return gonet.ListenTCP(net.stack, fa, pn)
+ return net.ListenTCPAddrPort(netip.AddrPortFrom(netip.AddrFromSlice(addr.IP), uint16(addr.Port)))
}
-func (net *Net) DialUDP(laddr, raddr *net.UDPAddr) (*gonet.UDPConn, error) {
+func (net *Net) DialUDPAddrPort(laddr, raddr netip.AddrPort) (*gonet.UDPConn, error) {
var lfa, rfa *tcpip.FullAddress
var pn tcpip.NetworkProtocolNumber
- if laddr != nil {
+ if laddr.IsValid() || laddr.Port() > 0 {
var addr tcpip.FullAddress
- addr, pn = convertToFullAddr(laddr.IP, laddr.Port)
+ addr, pn = convertToFullAddr(laddr)
lfa = &addr
}
- if raddr != nil {
+ if raddr.IsValid() || raddr.Port() > 0 {
var addr tcpip.FullAddress
- addr, pn = convertToFullAddr(raddr.IP, raddr.Port)
+ addr, pn = convertToFullAddr(raddr)
rfa = &addr
}
return gonet.DialUDP(net.stack, lfa, rfa, pn)
}
+func (net *Net) DialUDP(laddr, raddr *net.UDPAddr) (*gonet.UDPConn, error) {
+ var la, ra netip.AddrPort
+ if laddr != nil {
+ la = netip.AddrPortFrom(netip.AddrFromSlice(laddr.IP), uint16(laddr.Port))
+ }
+ if raddr != nil {
+ ra = netip.AddrPortFrom(netip.AddrFromSlice(raddr.IP), uint16(raddr.Port))
+ }
+ return net.DialUDPAddrPort(la, ra)
+}
+
var (
errNoSuchHost = errors.New("no such host")
errLameReferral = errors.New("lame referral")
@@ -433,7 +453,7 @@ func dnsStreamRoundTrip(c net.Conn, id uint16, query dnsmessage.Question, b []by
return p, h, nil
}
-func (tnet *Net) exchange(ctx context.Context, server net.IP, q dnsmessage.Question, timeout time.Duration) (dnsmessage.Parser, dnsmessage.Header, error) {
+func (tnet *Net) exchange(ctx context.Context, server netip.Addr, q dnsmessage.Question, timeout time.Duration) (dnsmessage.Parser, dnsmessage.Header, error) {
q.Class = dnsmessage.ClassINET
id, udpReq, tcpReq, err := newRequest(q)
if err != nil {
@@ -447,9 +467,9 @@ func (tnet *Net) exchange(ctx context.Context, server net.IP, q dnsmessage.Quest
var c net.Conn
var err error
if useUDP {
- c, err = tnet.DialUDP(nil, &net.UDPAddr{IP: server, Port: 53})
+ c, err = tnet.DialUDPAddrPort(netip.AddrPort{}, netip.AddrPortFrom(server, 53))
} else {
- c, err = tnet.DialContextTCP(ctx, &net.TCPAddr{IP: server, Port: 53})
+ c, err = tnet.DialContextTCPAddrPort(ctx, netip.AddrPortFrom(server, 53))
}
if err != nil {
@@ -600,8 +620,8 @@ func (tnet *Net) LookupContextHost(ctx context.Context, host string) ([]string,
zlen = zidx
}
}
- if ip := net.ParseIP(host[:zlen]); ip != nil {
- return []string{host[:zlen]}, nil
+ if ip, err := netip.ParseAddr(host[:zlen]); err == nil {
+ return []string{ip.String()}, nil
}
if !isDomainName(host) {
@@ -612,7 +632,7 @@ func (tnet *Net) LookupContextHost(ctx context.Context, host string) ([]string,
server string
error
}
- var addrsV4, addrsV6 []net.IP
+ var addrsV4, addrsV6 []netip.Addr
lanes := 0
if tnet.hasV4 {
lanes++
@@ -667,7 +687,7 @@ func (tnet *Net) LookupContextHost(ctx context.Context, host string) ([]string,
}
break loop
}
- addrsV4 = append(addrsV4, net.IP(a.A[:]))
+ addrsV4 = append(addrsV4, netip.AddrFrom4(a.A))
case dnsmessage.TypeAAAA:
aaaa, err := result.p.AAAAResource()
@@ -679,7 +699,7 @@ func (tnet *Net) LookupContextHost(ctx context.Context, host string) ([]string,
}
break loop
}
- addrsV6 = append(addrsV6, net.IP(aaaa.AAAA[:]))
+ addrsV6 = append(addrsV6, netip.AddrFrom16(aaaa.AAAA))
default:
if err := result.p.SkipAnswer(); err != nil {
@@ -695,7 +715,7 @@ func (tnet *Net) LookupContextHost(ctx context.Context, host string) ([]string,
}
}
// We don't do RFC6724. Instead just put V6 addresess first if an IPv6 address is enabled
- var addrs []net.IP
+ var addrs []netip.Addr
if tnet.hasV6 {
addrs = append(addrsV6, addrsV4...)
} else {
@@ -764,12 +784,11 @@ func (tnet *Net) DialContext(ctx context.Context, network, address string) (net.
if err != nil {
return nil, &net.OpError{Op: "dial", Err: err}
}
- var addrs []net.IP
+ var addrs []netip.AddrPort
for _, addr := range allAddr {
- if strings.IndexByte(addr, ':') != -1 && acceptV6 {
- addrs = append(addrs, net.ParseIP(addr))
- } else if strings.IndexByte(addr, '.') != -1 && acceptV4 {
- addrs = append(addrs, net.ParseIP(addr))
+ ip, err := netip.ParseAddr(addr)
+ if err == nil && ((ip.Is4() && acceptV4) || (ip.Is6() && acceptV6)) {
+ addrs = append(addrs, netip.AddrPortFrom(ip, uint16(port)))
}
}
if len(addrs) == 0 && len(allAddr) != 0 {
@@ -808,9 +827,9 @@ func (tnet *Net) DialContext(ctx context.Context, network, address string) (net.
var c net.Conn
if useUDP {
- c, err = tnet.DialUDP(nil, &net.UDPAddr{IP: addr, Port: port})
+ c, err = tnet.DialUDPAddrPort(netip.AddrPort{}, addr)
} else {
- c, err = tnet.DialContextTCP(dialCtx, &net.TCPAddr{IP: addr, Port: port})
+ c, err = tnet.DialContextTCPAddrPort(dialCtx, addr)
}
if err == nil {
return c, nil
diff --git a/tun/tuntest/tuntest.go b/tun/tuntest/tuntest.go
index d89db71..bdf0467 100644
--- a/tun/tuntest/tuntest.go
+++ b/tun/tuntest/tuntest.go
@@ -8,13 +8,13 @@ package tuntest
import (
"encoding/binary"
"io"
- "net"
"os"
+ "golang.zx2c4.com/go118/netip"
"golang.zx2c4.com/wireguard/tun"
)
-func Ping(dst, src net.IP) []byte {
+func Ping(dst, src netip.Addr) []byte {
localPort := uint16(1337)
seq := uint16(0)
@@ -40,7 +40,7 @@ func checksum(buf []byte, initial uint16) uint16 {
return ^uint16(v)
}
-func genICMPv4(payload []byte, dst, src net.IP) []byte {
+func genICMPv4(payload []byte, dst, src netip.Addr) []byte {
const (
icmpv4ProtocolNumber = 1
icmpv4Echo = 8
@@ -70,8 +70,8 @@ func genICMPv4(payload []byte, dst, src net.IP) []byte {
binary.BigEndian.PutUint16(ip[ipv4TotalLenOffset:], length)
ip[8] = ttl
ip[9] = icmpv4ProtocolNumber
- copy(ip[12:], src.To4())
- copy(ip[16:], dst.To4())
+ copy(ip[12:], src.AsSlice())
+ copy(ip[16:], dst.AsSlice())
chksum = ^checksum(ip[:], 0)
binary.BigEndian.PutUint16(ip[ipv4ChecksumOffset:], chksum)