summaryrefslogtreecommitdiffhomepage
path: root/tun
diff options
context:
space:
mode:
Diffstat (limited to 'tun')
-rw-r--r--tun/netstack/examples/http_client.go7
-rw-r--r--tun/netstack/examples/http_server.go6
-rw-r--r--tun/netstack/go.mod1
-rw-r--r--tun/netstack/go.sum4
-rw-r--r--tun/netstack/tun.go143
-rw-r--r--tun/tuntest/tuntest.go10
6 files changed, 99 insertions, 72 deletions
diff --git a/tun/netstack/examples/http_client.go b/tun/netstack/examples/http_client.go
index 6ac2859..b39b453 100644
--- a/tun/netstack/examples/http_client.go
+++ b/tun/netstack/examples/http_client.go
@@ -1,4 +1,5 @@
//go:build ignore
+// +build ignore
/* SPDX-License-Identifier: MIT
*
@@ -10,9 +11,9 @@ package main
import (
"io"
"log"
- "net"
"net/http"
+ "golang.zx2c4.com/go118/netip"
"golang.zx2c4.com/wireguard/conn"
"golang.zx2c4.com/wireguard/device"
"golang.zx2c4.com/wireguard/tun/netstack"
@@ -20,8 +21,8 @@ import (
func main() {
tun, tnet, err := netstack.CreateNetTUN(
- []net.IP{net.ParseIP("192.168.4.29")},
- []net.IP{net.ParseIP("8.8.8.8")},
+ []netip.Addr{netip.MustParseAddr("192.168.4.29")},
+ []netip.Addr{netip.MustParseAddr("8.8.8.8")},
1420)
if err != nil {
log.Panic(err)
diff --git a/tun/netstack/examples/http_server.go b/tun/netstack/examples/http_server.go
index 577c6ea..40f7804 100644
--- a/tun/netstack/examples/http_server.go
+++ b/tun/netstack/examples/http_server.go
@@ -1,4 +1,5 @@
//go:build ignore
+// +build ignore
/* SPDX-License-Identifier: MIT
*
@@ -13,6 +14,7 @@ import (
"net"
"net/http"
+ "golang.zx2c4.com/go118/netip"
"golang.zx2c4.com/wireguard/conn"
"golang.zx2c4.com/wireguard/device"
"golang.zx2c4.com/wireguard/tun/netstack"
@@ -20,8 +22,8 @@ import (
func main() {
tun, tnet, err := netstack.CreateNetTUN(
- []net.IP{net.ParseIP("192.168.4.29")},
- []net.IP{net.ParseIP("8.8.8.8"), net.ParseIP("8.8.4.4")},
+ []netip.Addr{netip.MustParseAddr("192.168.4.29")},
+ []netip.Addr{netip.MustParseAddr("8.8.8.8"), netip.MustParseAddr("8.8.4.4")},
1420,
)
if err != nil {
diff --git a/tun/netstack/go.mod b/tun/netstack/go.mod
index 8db9f4b..46b57ba 100644
--- a/tun/netstack/go.mod
+++ b/tun/netstack/go.mod
@@ -6,6 +6,7 @@ require (
golang.org/x/net v0.0.0-20210423184538-5f58ad60dda6
golang.org/x/sys v0.0.0-20210423185535-09eb48e85fd7 // indirect
golang.org/x/time v0.0.0-20210220033141-f8bda1e9f3ba // indirect
+ golang.zx2c4.com/go118/netip v0.0.0-20211105124833-002a02cb0e53
golang.zx2c4.com/wireguard v0.0.0-20210424170727-c9db4b7aaa22
gvisor.dev/gvisor v0.0.0-20211020211948-f76a604701b6
)
diff --git a/tun/netstack/go.sum b/tun/netstack/go.sum
index 78c025c..01bfbc7 100644
--- a/tun/netstack/go.sum
+++ b/tun/netstack/go.sum
@@ -805,6 +805,10 @@ golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8T
golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
+golang.zx2c4.com/go118/netip v0.0.0-20211104120624-f0ae7a6e37c5 h1:mV4w4F7AtWXoDNkko9odoTdWpNwyDh8jx+S1fOZKDLg=
+golang.zx2c4.com/go118/netip v0.0.0-20211104120624-f0ae7a6e37c5/go.mod h1:5yyfuiqVIJ7t+3MqrpTQ+QqRkMWiESiyDvPNvKYCecg=
+golang.zx2c4.com/go118/netip v0.0.0-20211105124833-002a02cb0e53 h1:nFvpdzrHF9IPo9xPgayHWObCATpQYKky8VSSdt9lf9E=
+golang.zx2c4.com/go118/netip v0.0.0-20211105124833-002a02cb0e53/go.mod h1:5yyfuiqVIJ7t+3MqrpTQ+QqRkMWiESiyDvPNvKYCecg=
golang.zx2c4.com/wireguard v0.0.0-20210424170727-c9db4b7aaa22 h1:ytS28bw9HtZVDRMDxviC6ryCJuccw+zXhh04u2IRWJw=
golang.zx2c4.com/wireguard v0.0.0-20210424170727-c9db4b7aaa22/go.mod h1:a057zjmoc00UN7gVkaJt2sXVK523kMJcogDTEvPIasg=
google.golang.org/api v0.4.0/go.mod h1:8k5glujaEP+g9n7WNsDg8QP6cUVNI86fCNMcbazEtwE=
diff --git a/tun/netstack/tun.go b/tun/netstack/tun.go
index 24d0835..f1c03f4 100644
--- a/tun/netstack/tun.go
+++ b/tun/netstack/tun.go
@@ -18,6 +18,7 @@ import (
"strings"
"time"
+ "golang.zx2c4.com/go118/netip"
"golang.zx2c4.com/wireguard/tun"
"golang.org/x/net/dns/dnsmessage"
@@ -38,7 +39,7 @@ type netTun struct {
events chan tun.Event
incomingPacket chan buffer.VectorisedView
mtu int
- dnsServers []net.IP
+ dnsServers []netip.Addr
hasV4, hasV6 bool
}
type endpoint netTun
@@ -94,7 +95,7 @@ func (*endpoint) ARPHardwareType() header.ARPHardwareType {
func (e *endpoint) AddHeader(tcpip.LinkAddress, tcpip.LinkAddress, tcpip.NetworkProtocolNumber, *stack.PacketBuffer) {
}
-func CreateNetTUN(localAddresses, dnsServers []net.IP, mtu int) (tun.Device, *Net, error) {
+func CreateNetTUN(localAddresses, dnsServers []netip.Addr, mtu int) (tun.Device, *Net, error) {
opts := stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol},
TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol, udp.NewProtocol},
@@ -112,25 +113,23 @@ func CreateNetTUN(localAddresses, dnsServers []net.IP, mtu int) (tun.Device, *Ne
return nil, nil, fmt.Errorf("CreateNIC: %v", tcpipErr)
}
for _, ip := range localAddresses {
- if ip4 := ip.To4(); ip4 != nil {
- protoAddr := tcpip.ProtocolAddress{
- Protocol: ipv4.ProtocolNumber,
- AddressWithPrefix: tcpip.Address(ip4).WithPrefix(),
- }
- tcpipErr := dev.stack.AddProtocolAddress(1, protoAddr, stack.AddressProperties{})
- if tcpipErr != nil {
- return nil, nil, fmt.Errorf("AddProtocolAddress(%v): %v", ip4, tcpipErr)
- }
+ var protoNumber tcpip.NetworkProtocolNumber
+ if ip.Is4() {
+ protoNumber = ipv4.ProtocolNumber
+ } else if ip.Is6() {
+ protoNumber = ipv6.ProtocolNumber
+ }
+ protoAddr := tcpip.ProtocolAddress{
+ Protocol: protoNumber,
+ AddressWithPrefix: tcpip.Address(ip.AsSlice()).WithPrefix(),
+ }
+ tcpipErr := dev.stack.AddProtocolAddress(1, protoAddr, stack.AddressProperties{})
+ if tcpipErr != nil {
+ return nil, nil, fmt.Errorf("AddProtocolAddress(%v): %v", ip, tcpipErr)
+ }
+ if ip.Is4() {
dev.hasV4 = true
- } else {
- protoAddr := tcpip.ProtocolAddress{
- Protocol: ipv6.ProtocolNumber,
- AddressWithPrefix: tcpip.Address(ip).WithPrefix(),
- }
- tcpipErr := dev.stack.AddProtocolAddress(1, protoAddr, stack.AddressProperties{})
- if tcpipErr != nil {
- return nil, nil, fmt.Errorf("AddProtocolAddress(%v): %v", ip, tcpipErr)
- }
+ } else if ip.Is6() {
dev.hasV6 = true
}
}
@@ -202,62 +201,83 @@ func (tun *netTun) MTU() (int, error) {
return tun.mtu, nil
}
-func convertToFullAddr(ip net.IP, port int) (tcpip.FullAddress, tcpip.NetworkProtocolNumber) {
- if ip4 := ip.To4(); ip4 != nil {
- return tcpip.FullAddress{
- NIC: 1,
- Addr: tcpip.Address(ip4),
- Port: uint16(port),
- }, ipv4.ProtocolNumber
+func convertToFullAddr(endpoint netip.AddrPort) (tcpip.FullAddress, tcpip.NetworkProtocolNumber) {
+ var protoNumber tcpip.NetworkProtocolNumber
+ if endpoint.Addr().Is4() {
+ protoNumber = ipv4.ProtocolNumber
} else {
- return tcpip.FullAddress{
- NIC: 1,
- Addr: tcpip.Address(ip),
- Port: uint16(port),
- }, ipv6.ProtocolNumber
+ protoNumber = ipv6.ProtocolNumber
}
+ return tcpip.FullAddress{
+ NIC: 1,
+ Addr: tcpip.Address(endpoint.Addr().AsSlice()),
+ Port: endpoint.Port(),
+ }, protoNumber
+}
+
+func (net *Net) DialContextTCPAddrPort(ctx context.Context, addr netip.AddrPort) (*gonet.TCPConn, error) {
+ fa, pn := convertToFullAddr(addr)
+ return gonet.DialContextTCP(ctx, net.stack, fa, pn)
}
func (net *Net) DialContextTCP(ctx context.Context, addr *net.TCPAddr) (*gonet.TCPConn, error) {
if addr == nil {
- panic("todo: deal with auto addr semantics for nil addr")
+ return net.DialContextTCPAddrPort(ctx, netip.AddrPort{})
}
- fa, pn := convertToFullAddr(addr.IP, addr.Port)
- return gonet.DialContextTCP(ctx, net.stack, fa, pn)
+ return net.DialContextTCPAddrPort(ctx, netip.AddrPortFrom(netip.AddrFromSlice(addr.IP), uint16(addr.Port)))
+}
+
+func (net *Net) DialTCPAddrPort(addr netip.AddrPort) (*gonet.TCPConn, error) {
+ fa, pn := convertToFullAddr(addr)
+ return gonet.DialTCP(net.stack, fa, pn)
}
func (net *Net) DialTCP(addr *net.TCPAddr) (*gonet.TCPConn, error) {
if addr == nil {
- panic("todo: deal with auto addr semantics for nil addr")
+ return net.DialTCPAddrPort(netip.AddrPort{})
}
- fa, pn := convertToFullAddr(addr.IP, addr.Port)
- return gonet.DialTCP(net.stack, fa, pn)
+ return net.DialTCPAddrPort(netip.AddrPortFrom(netip.AddrFromSlice(addr.IP), uint16(addr.Port)))
+}
+
+func (net *Net) ListenTCPAddrPort(addr netip.AddrPort) (*gonet.TCPListener, error) {
+ fa, pn := convertToFullAddr(addr)
+ return gonet.ListenTCP(net.stack, fa, pn)
}
func (net *Net) ListenTCP(addr *net.TCPAddr) (*gonet.TCPListener, error) {
if addr == nil {
- panic("todo: deal with auto addr semantics for nil addr")
+ return net.ListenTCPAddrPort(netip.AddrPort{})
}
- fa, pn := convertToFullAddr(addr.IP, addr.Port)
- return gonet.ListenTCP(net.stack, fa, pn)
+ return net.ListenTCPAddrPort(netip.AddrPortFrom(netip.AddrFromSlice(addr.IP), uint16(addr.Port)))
}
-func (net *Net) DialUDP(laddr, raddr *net.UDPAddr) (*gonet.UDPConn, error) {
+func (net *Net) DialUDPAddrPort(laddr, raddr netip.AddrPort) (*gonet.UDPConn, error) {
var lfa, rfa *tcpip.FullAddress
var pn tcpip.NetworkProtocolNumber
- if laddr != nil {
+ if laddr.IsValid() || laddr.Port() > 0 {
var addr tcpip.FullAddress
- addr, pn = convertToFullAddr(laddr.IP, laddr.Port)
+ addr, pn = convertToFullAddr(laddr)
lfa = &addr
}
- if raddr != nil {
+ if raddr.IsValid() || raddr.Port() > 0 {
var addr tcpip.FullAddress
- addr, pn = convertToFullAddr(raddr.IP, raddr.Port)
+ addr, pn = convertToFullAddr(raddr)
rfa = &addr
}
return gonet.DialUDP(net.stack, lfa, rfa, pn)
}
+func (net *Net) DialUDP(laddr, raddr *net.UDPAddr) (*gonet.UDPConn, error) {
+ var la, ra netip.AddrPort
+ if laddr != nil {
+ la = netip.AddrPortFrom(netip.AddrFromSlice(laddr.IP), uint16(laddr.Port))
+ }
+ if raddr != nil {
+ ra = netip.AddrPortFrom(netip.AddrFromSlice(raddr.IP), uint16(raddr.Port))
+ }
+ return net.DialUDPAddrPort(la, ra)
+}
+
var (
errNoSuchHost = errors.New("no such host")
errLameReferral = errors.New("lame referral")
@@ -433,7 +453,7 @@ func dnsStreamRoundTrip(c net.Conn, id uint16, query dnsmessage.Question, b []by
return p, h, nil
}
-func (tnet *Net) exchange(ctx context.Context, server net.IP, q dnsmessage.Question, timeout time.Duration) (dnsmessage.Parser, dnsmessage.Header, error) {
+func (tnet *Net) exchange(ctx context.Context, server netip.Addr, q dnsmessage.Question, timeout time.Duration) (dnsmessage.Parser, dnsmessage.Header, error) {
q.Class = dnsmessage.ClassINET
id, udpReq, tcpReq, err := newRequest(q)
if err != nil {
@@ -447,9 +467,9 @@ func (tnet *Net) exchange(ctx context.Context, server net.IP, q dnsmessage.Quest
var c net.Conn
var err error
if useUDP {
- c, err = tnet.DialUDP(nil, &net.UDPAddr{IP: server, Port: 53})
+ c, err = tnet.DialUDPAddrPort(netip.AddrPort{}, netip.AddrPortFrom(server, 53))
} else {
- c, err = tnet.DialContextTCP(ctx, &net.TCPAddr{IP: server, Port: 53})
+ c, err = tnet.DialContextTCPAddrPort(ctx, netip.AddrPortFrom(server, 53))
}
if err != nil {
@@ -600,8 +620,8 @@ func (tnet *Net) LookupContextHost(ctx context.Context, host string) ([]string,
zlen = zidx
}
}
- if ip := net.ParseIP(host[:zlen]); ip != nil {
- return []string{host[:zlen]}, nil
+ if ip, err := netip.ParseAddr(host[:zlen]); err == nil {
+ return []string{ip.String()}, nil
}
if !isDomainName(host) {
@@ -612,7 +632,7 @@ func (tnet *Net) LookupContextHost(ctx context.Context, host string) ([]string,
server string
error
}
- var addrsV4, addrsV6 []net.IP
+ var addrsV4, addrsV6 []netip.Addr
lanes := 0
if tnet.hasV4 {
lanes++
@@ -667,7 +687,7 @@ func (tnet *Net) LookupContextHost(ctx context.Context, host string) ([]string,
}
break loop
}
- addrsV4 = append(addrsV4, net.IP(a.A[:]))
+ addrsV4 = append(addrsV4, netip.AddrFrom4(a.A))
case dnsmessage.TypeAAAA:
aaaa, err := result.p.AAAAResource()
@@ -679,7 +699,7 @@ func (tnet *Net) LookupContextHost(ctx context.Context, host string) ([]string,
}
break loop
}
- addrsV6 = append(addrsV6, net.IP(aaaa.AAAA[:]))
+ addrsV6 = append(addrsV6, netip.AddrFrom16(aaaa.AAAA))
default:
if err := result.p.SkipAnswer(); err != nil {
@@ -695,7 +715,7 @@ func (tnet *Net) LookupContextHost(ctx context.Context, host string) ([]string,
}
}
// We don't do RFC6724. Instead just put V6 addresess first if an IPv6 address is enabled
- var addrs []net.IP
+ var addrs []netip.Addr
if tnet.hasV6 {
addrs = append(addrsV6, addrsV4...)
} else {
@@ -764,12 +784,11 @@ func (tnet *Net) DialContext(ctx context.Context, network, address string) (net.
if err != nil {
return nil, &net.OpError{Op: "dial", Err: err}
}
- var addrs []net.IP
+ var addrs []netip.AddrPort
for _, addr := range allAddr {
- if strings.IndexByte(addr, ':') != -1 && acceptV6 {
- addrs = append(addrs, net.ParseIP(addr))
- } else if strings.IndexByte(addr, '.') != -1 && acceptV4 {
- addrs = append(addrs, net.ParseIP(addr))
+ ip, err := netip.ParseAddr(addr)
+ if err == nil && ((ip.Is4() && acceptV4) || (ip.Is6() && acceptV6)) {
+ addrs = append(addrs, netip.AddrPortFrom(ip, uint16(port)))
}
}
if len(addrs) == 0 && len(allAddr) != 0 {
@@ -808,9 +827,9 @@ func (tnet *Net) DialContext(ctx context.Context, network, address string) (net.
var c net.Conn
if useUDP {
- c, err = tnet.DialUDP(nil, &net.UDPAddr{IP: addr, Port: port})
+ c, err = tnet.DialUDPAddrPort(netip.AddrPort{}, addr)
} else {
- c, err = tnet.DialContextTCP(dialCtx, &net.TCPAddr{IP: addr, Port: port})
+ c, err = tnet.DialContextTCPAddrPort(dialCtx, addr)
}
if err == nil {
return c, nil
diff --git a/tun/tuntest/tuntest.go b/tun/tuntest/tuntest.go
index d89db71..bdf0467 100644
--- a/tun/tuntest/tuntest.go
+++ b/tun/tuntest/tuntest.go
@@ -8,13 +8,13 @@ package tuntest
import (
"encoding/binary"
"io"
- "net"
"os"
+ "golang.zx2c4.com/go118/netip"
"golang.zx2c4.com/wireguard/tun"
)
-func Ping(dst, src net.IP) []byte {
+func Ping(dst, src netip.Addr) []byte {
localPort := uint16(1337)
seq := uint16(0)
@@ -40,7 +40,7 @@ func checksum(buf []byte, initial uint16) uint16 {
return ^uint16(v)
}
-func genICMPv4(payload []byte, dst, src net.IP) []byte {
+func genICMPv4(payload []byte, dst, src netip.Addr) []byte {
const (
icmpv4ProtocolNumber = 1
icmpv4Echo = 8
@@ -70,8 +70,8 @@ func genICMPv4(payload []byte, dst, src net.IP) []byte {
binary.BigEndian.PutUint16(ip[ipv4TotalLenOffset:], length)
ip[8] = ttl
ip[9] = icmpv4ProtocolNumber
- copy(ip[12:], src.To4())
- copy(ip[16:], dst.To4())
+ copy(ip[12:], src.AsSlice())
+ copy(ip[16:], dst.AsSlice())
chksum = ^checksum(ip[:], 0)
binary.BigEndian.PutUint16(ip[ipv4ChecksumOffset:], chksum)