diff options
Diffstat (limited to 'tun')
-rw-r--r-- | tun/netstack/examples/http_client.go | 7 | ||||
-rw-r--r-- | tun/netstack/examples/http_server.go | 6 | ||||
-rw-r--r-- | tun/netstack/go.mod | 1 | ||||
-rw-r--r-- | tun/netstack/go.sum | 4 | ||||
-rw-r--r-- | tun/netstack/tun.go | 143 | ||||
-rw-r--r-- | tun/tuntest/tuntest.go | 10 |
6 files changed, 99 insertions, 72 deletions
diff --git a/tun/netstack/examples/http_client.go b/tun/netstack/examples/http_client.go index 6ac2859..b39b453 100644 --- a/tun/netstack/examples/http_client.go +++ b/tun/netstack/examples/http_client.go @@ -1,4 +1,5 @@ //go:build ignore +// +build ignore /* SPDX-License-Identifier: MIT * @@ -10,9 +11,9 @@ package main import ( "io" "log" - "net" "net/http" + "golang.zx2c4.com/go118/netip" "golang.zx2c4.com/wireguard/conn" "golang.zx2c4.com/wireguard/device" "golang.zx2c4.com/wireguard/tun/netstack" @@ -20,8 +21,8 @@ import ( func main() { tun, tnet, err := netstack.CreateNetTUN( - []net.IP{net.ParseIP("192.168.4.29")}, - []net.IP{net.ParseIP("8.8.8.8")}, + []netip.Addr{netip.MustParseAddr("192.168.4.29")}, + []netip.Addr{netip.MustParseAddr("8.8.8.8")}, 1420) if err != nil { log.Panic(err) diff --git a/tun/netstack/examples/http_server.go b/tun/netstack/examples/http_server.go index 577c6ea..40f7804 100644 --- a/tun/netstack/examples/http_server.go +++ b/tun/netstack/examples/http_server.go @@ -1,4 +1,5 @@ //go:build ignore +// +build ignore /* SPDX-License-Identifier: MIT * @@ -13,6 +14,7 @@ import ( "net" "net/http" + "golang.zx2c4.com/go118/netip" "golang.zx2c4.com/wireguard/conn" "golang.zx2c4.com/wireguard/device" "golang.zx2c4.com/wireguard/tun/netstack" @@ -20,8 +22,8 @@ import ( func main() { tun, tnet, err := netstack.CreateNetTUN( - []net.IP{net.ParseIP("192.168.4.29")}, - []net.IP{net.ParseIP("8.8.8.8"), net.ParseIP("8.8.4.4")}, + []netip.Addr{netip.MustParseAddr("192.168.4.29")}, + []netip.Addr{netip.MustParseAddr("8.8.8.8"), netip.MustParseAddr("8.8.4.4")}, 1420, ) if err != nil { diff --git a/tun/netstack/go.mod b/tun/netstack/go.mod index 8db9f4b..46b57ba 100644 --- a/tun/netstack/go.mod +++ b/tun/netstack/go.mod @@ -6,6 +6,7 @@ require ( golang.org/x/net v0.0.0-20210423184538-5f58ad60dda6 golang.org/x/sys v0.0.0-20210423185535-09eb48e85fd7 // indirect golang.org/x/time v0.0.0-20210220033141-f8bda1e9f3ba // indirect + golang.zx2c4.com/go118/netip v0.0.0-20211105124833-002a02cb0e53 golang.zx2c4.com/wireguard v0.0.0-20210424170727-c9db4b7aaa22 gvisor.dev/gvisor v0.0.0-20211020211948-f76a604701b6 ) diff --git a/tun/netstack/go.sum b/tun/netstack/go.sum index 78c025c..01bfbc7 100644 --- a/tun/netstack/go.sum +++ b/tun/netstack/go.sum @@ -805,6 +805,10 @@ golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8T golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.zx2c4.com/go118/netip v0.0.0-20211104120624-f0ae7a6e37c5 h1:mV4w4F7AtWXoDNkko9odoTdWpNwyDh8jx+S1fOZKDLg= +golang.zx2c4.com/go118/netip v0.0.0-20211104120624-f0ae7a6e37c5/go.mod h1:5yyfuiqVIJ7t+3MqrpTQ+QqRkMWiESiyDvPNvKYCecg= +golang.zx2c4.com/go118/netip v0.0.0-20211105124833-002a02cb0e53 h1:nFvpdzrHF9IPo9xPgayHWObCATpQYKky8VSSdt9lf9E= +golang.zx2c4.com/go118/netip v0.0.0-20211105124833-002a02cb0e53/go.mod h1:5yyfuiqVIJ7t+3MqrpTQ+QqRkMWiESiyDvPNvKYCecg= golang.zx2c4.com/wireguard v0.0.0-20210424170727-c9db4b7aaa22 h1:ytS28bw9HtZVDRMDxviC6ryCJuccw+zXhh04u2IRWJw= golang.zx2c4.com/wireguard v0.0.0-20210424170727-c9db4b7aaa22/go.mod h1:a057zjmoc00UN7gVkaJt2sXVK523kMJcogDTEvPIasg= google.golang.org/api v0.4.0/go.mod h1:8k5glujaEP+g9n7WNsDg8QP6cUVNI86fCNMcbazEtwE= diff --git a/tun/netstack/tun.go b/tun/netstack/tun.go index 24d0835..f1c03f4 100644 --- a/tun/netstack/tun.go +++ b/tun/netstack/tun.go @@ -18,6 +18,7 @@ import ( "strings" "time" + "golang.zx2c4.com/go118/netip" "golang.zx2c4.com/wireguard/tun" "golang.org/x/net/dns/dnsmessage" @@ -38,7 +39,7 @@ type netTun struct { events chan tun.Event incomingPacket chan buffer.VectorisedView mtu int - dnsServers []net.IP + dnsServers []netip.Addr hasV4, hasV6 bool } type endpoint netTun @@ -94,7 +95,7 @@ func (*endpoint) ARPHardwareType() header.ARPHardwareType { func (e *endpoint) AddHeader(tcpip.LinkAddress, tcpip.LinkAddress, tcpip.NetworkProtocolNumber, *stack.PacketBuffer) { } -func CreateNetTUN(localAddresses, dnsServers []net.IP, mtu int) (tun.Device, *Net, error) { +func CreateNetTUN(localAddresses, dnsServers []netip.Addr, mtu int) (tun.Device, *Net, error) { opts := stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol}, TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol, udp.NewProtocol}, @@ -112,25 +113,23 @@ func CreateNetTUN(localAddresses, dnsServers []net.IP, mtu int) (tun.Device, *Ne return nil, nil, fmt.Errorf("CreateNIC: %v", tcpipErr) } for _, ip := range localAddresses { - if ip4 := ip.To4(); ip4 != nil { - protoAddr := tcpip.ProtocolAddress{ - Protocol: ipv4.ProtocolNumber, - AddressWithPrefix: tcpip.Address(ip4).WithPrefix(), - } - tcpipErr := dev.stack.AddProtocolAddress(1, protoAddr, stack.AddressProperties{}) - if tcpipErr != nil { - return nil, nil, fmt.Errorf("AddProtocolAddress(%v): %v", ip4, tcpipErr) - } + var protoNumber tcpip.NetworkProtocolNumber + if ip.Is4() { + protoNumber = ipv4.ProtocolNumber + } else if ip.Is6() { + protoNumber = ipv6.ProtocolNumber + } + protoAddr := tcpip.ProtocolAddress{ + Protocol: protoNumber, + AddressWithPrefix: tcpip.Address(ip.AsSlice()).WithPrefix(), + } + tcpipErr := dev.stack.AddProtocolAddress(1, protoAddr, stack.AddressProperties{}) + if tcpipErr != nil { + return nil, nil, fmt.Errorf("AddProtocolAddress(%v): %v", ip, tcpipErr) + } + if ip.Is4() { dev.hasV4 = true - } else { - protoAddr := tcpip.ProtocolAddress{ - Protocol: ipv6.ProtocolNumber, - AddressWithPrefix: tcpip.Address(ip).WithPrefix(), - } - tcpipErr := dev.stack.AddProtocolAddress(1, protoAddr, stack.AddressProperties{}) - if tcpipErr != nil { - return nil, nil, fmt.Errorf("AddProtocolAddress(%v): %v", ip, tcpipErr) - } + } else if ip.Is6() { dev.hasV6 = true } } @@ -202,62 +201,83 @@ func (tun *netTun) MTU() (int, error) { return tun.mtu, nil } -func convertToFullAddr(ip net.IP, port int) (tcpip.FullAddress, tcpip.NetworkProtocolNumber) { - if ip4 := ip.To4(); ip4 != nil { - return tcpip.FullAddress{ - NIC: 1, - Addr: tcpip.Address(ip4), - Port: uint16(port), - }, ipv4.ProtocolNumber +func convertToFullAddr(endpoint netip.AddrPort) (tcpip.FullAddress, tcpip.NetworkProtocolNumber) { + var protoNumber tcpip.NetworkProtocolNumber + if endpoint.Addr().Is4() { + protoNumber = ipv4.ProtocolNumber } else { - return tcpip.FullAddress{ - NIC: 1, - Addr: tcpip.Address(ip), - Port: uint16(port), - }, ipv6.ProtocolNumber + protoNumber = ipv6.ProtocolNumber } + return tcpip.FullAddress{ + NIC: 1, + Addr: tcpip.Address(endpoint.Addr().AsSlice()), + Port: endpoint.Port(), + }, protoNumber +} + +func (net *Net) DialContextTCPAddrPort(ctx context.Context, addr netip.AddrPort) (*gonet.TCPConn, error) { + fa, pn := convertToFullAddr(addr) + return gonet.DialContextTCP(ctx, net.stack, fa, pn) } func (net *Net) DialContextTCP(ctx context.Context, addr *net.TCPAddr) (*gonet.TCPConn, error) { if addr == nil { - panic("todo: deal with auto addr semantics for nil addr") + return net.DialContextTCPAddrPort(ctx, netip.AddrPort{}) } - fa, pn := convertToFullAddr(addr.IP, addr.Port) - return gonet.DialContextTCP(ctx, net.stack, fa, pn) + return net.DialContextTCPAddrPort(ctx, netip.AddrPortFrom(netip.AddrFromSlice(addr.IP), uint16(addr.Port))) +} + +func (net *Net) DialTCPAddrPort(addr netip.AddrPort) (*gonet.TCPConn, error) { + fa, pn := convertToFullAddr(addr) + return gonet.DialTCP(net.stack, fa, pn) } func (net *Net) DialTCP(addr *net.TCPAddr) (*gonet.TCPConn, error) { if addr == nil { - panic("todo: deal with auto addr semantics for nil addr") + return net.DialTCPAddrPort(netip.AddrPort{}) } - fa, pn := convertToFullAddr(addr.IP, addr.Port) - return gonet.DialTCP(net.stack, fa, pn) + return net.DialTCPAddrPort(netip.AddrPortFrom(netip.AddrFromSlice(addr.IP), uint16(addr.Port))) +} + +func (net *Net) ListenTCPAddrPort(addr netip.AddrPort) (*gonet.TCPListener, error) { + fa, pn := convertToFullAddr(addr) + return gonet.ListenTCP(net.stack, fa, pn) } func (net *Net) ListenTCP(addr *net.TCPAddr) (*gonet.TCPListener, error) { if addr == nil { - panic("todo: deal with auto addr semantics for nil addr") + return net.ListenTCPAddrPort(netip.AddrPort{}) } - fa, pn := convertToFullAddr(addr.IP, addr.Port) - return gonet.ListenTCP(net.stack, fa, pn) + return net.ListenTCPAddrPort(netip.AddrPortFrom(netip.AddrFromSlice(addr.IP), uint16(addr.Port))) } -func (net *Net) DialUDP(laddr, raddr *net.UDPAddr) (*gonet.UDPConn, error) { +func (net *Net) DialUDPAddrPort(laddr, raddr netip.AddrPort) (*gonet.UDPConn, error) { var lfa, rfa *tcpip.FullAddress var pn tcpip.NetworkProtocolNumber - if laddr != nil { + if laddr.IsValid() || laddr.Port() > 0 { var addr tcpip.FullAddress - addr, pn = convertToFullAddr(laddr.IP, laddr.Port) + addr, pn = convertToFullAddr(laddr) lfa = &addr } - if raddr != nil { + if raddr.IsValid() || raddr.Port() > 0 { var addr tcpip.FullAddress - addr, pn = convertToFullAddr(raddr.IP, raddr.Port) + addr, pn = convertToFullAddr(raddr) rfa = &addr } return gonet.DialUDP(net.stack, lfa, rfa, pn) } +func (net *Net) DialUDP(laddr, raddr *net.UDPAddr) (*gonet.UDPConn, error) { + var la, ra netip.AddrPort + if laddr != nil { + la = netip.AddrPortFrom(netip.AddrFromSlice(laddr.IP), uint16(laddr.Port)) + } + if raddr != nil { + ra = netip.AddrPortFrom(netip.AddrFromSlice(raddr.IP), uint16(raddr.Port)) + } + return net.DialUDPAddrPort(la, ra) +} + var ( errNoSuchHost = errors.New("no such host") errLameReferral = errors.New("lame referral") @@ -433,7 +453,7 @@ func dnsStreamRoundTrip(c net.Conn, id uint16, query dnsmessage.Question, b []by return p, h, nil } -func (tnet *Net) exchange(ctx context.Context, server net.IP, q dnsmessage.Question, timeout time.Duration) (dnsmessage.Parser, dnsmessage.Header, error) { +func (tnet *Net) exchange(ctx context.Context, server netip.Addr, q dnsmessage.Question, timeout time.Duration) (dnsmessage.Parser, dnsmessage.Header, error) { q.Class = dnsmessage.ClassINET id, udpReq, tcpReq, err := newRequest(q) if err != nil { @@ -447,9 +467,9 @@ func (tnet *Net) exchange(ctx context.Context, server net.IP, q dnsmessage.Quest var c net.Conn var err error if useUDP { - c, err = tnet.DialUDP(nil, &net.UDPAddr{IP: server, Port: 53}) + c, err = tnet.DialUDPAddrPort(netip.AddrPort{}, netip.AddrPortFrom(server, 53)) } else { - c, err = tnet.DialContextTCP(ctx, &net.TCPAddr{IP: server, Port: 53}) + c, err = tnet.DialContextTCPAddrPort(ctx, netip.AddrPortFrom(server, 53)) } if err != nil { @@ -600,8 +620,8 @@ func (tnet *Net) LookupContextHost(ctx context.Context, host string) ([]string, zlen = zidx } } - if ip := net.ParseIP(host[:zlen]); ip != nil { - return []string{host[:zlen]}, nil + if ip, err := netip.ParseAddr(host[:zlen]); err == nil { + return []string{ip.String()}, nil } if !isDomainName(host) { @@ -612,7 +632,7 @@ func (tnet *Net) LookupContextHost(ctx context.Context, host string) ([]string, server string error } - var addrsV4, addrsV6 []net.IP + var addrsV4, addrsV6 []netip.Addr lanes := 0 if tnet.hasV4 { lanes++ @@ -667,7 +687,7 @@ func (tnet *Net) LookupContextHost(ctx context.Context, host string) ([]string, } break loop } - addrsV4 = append(addrsV4, net.IP(a.A[:])) + addrsV4 = append(addrsV4, netip.AddrFrom4(a.A)) case dnsmessage.TypeAAAA: aaaa, err := result.p.AAAAResource() @@ -679,7 +699,7 @@ func (tnet *Net) LookupContextHost(ctx context.Context, host string) ([]string, } break loop } - addrsV6 = append(addrsV6, net.IP(aaaa.AAAA[:])) + addrsV6 = append(addrsV6, netip.AddrFrom16(aaaa.AAAA)) default: if err := result.p.SkipAnswer(); err != nil { @@ -695,7 +715,7 @@ func (tnet *Net) LookupContextHost(ctx context.Context, host string) ([]string, } } // We don't do RFC6724. Instead just put V6 addresess first if an IPv6 address is enabled - var addrs []net.IP + var addrs []netip.Addr if tnet.hasV6 { addrs = append(addrsV6, addrsV4...) } else { @@ -764,12 +784,11 @@ func (tnet *Net) DialContext(ctx context.Context, network, address string) (net. if err != nil { return nil, &net.OpError{Op: "dial", Err: err} } - var addrs []net.IP + var addrs []netip.AddrPort for _, addr := range allAddr { - if strings.IndexByte(addr, ':') != -1 && acceptV6 { - addrs = append(addrs, net.ParseIP(addr)) - } else if strings.IndexByte(addr, '.') != -1 && acceptV4 { - addrs = append(addrs, net.ParseIP(addr)) + ip, err := netip.ParseAddr(addr) + if err == nil && ((ip.Is4() && acceptV4) || (ip.Is6() && acceptV6)) { + addrs = append(addrs, netip.AddrPortFrom(ip, uint16(port))) } } if len(addrs) == 0 && len(allAddr) != 0 { @@ -808,9 +827,9 @@ func (tnet *Net) DialContext(ctx context.Context, network, address string) (net. var c net.Conn if useUDP { - c, err = tnet.DialUDP(nil, &net.UDPAddr{IP: addr, Port: port}) + c, err = tnet.DialUDPAddrPort(netip.AddrPort{}, addr) } else { - c, err = tnet.DialContextTCP(dialCtx, &net.TCPAddr{IP: addr, Port: port}) + c, err = tnet.DialContextTCPAddrPort(dialCtx, addr) } if err == nil { return c, nil diff --git a/tun/tuntest/tuntest.go b/tun/tuntest/tuntest.go index d89db71..bdf0467 100644 --- a/tun/tuntest/tuntest.go +++ b/tun/tuntest/tuntest.go @@ -8,13 +8,13 @@ package tuntest import ( "encoding/binary" "io" - "net" "os" + "golang.zx2c4.com/go118/netip" "golang.zx2c4.com/wireguard/tun" ) -func Ping(dst, src net.IP) []byte { +func Ping(dst, src netip.Addr) []byte { localPort := uint16(1337) seq := uint16(0) @@ -40,7 +40,7 @@ func checksum(buf []byte, initial uint16) uint16 { return ^uint16(v) } -func genICMPv4(payload []byte, dst, src net.IP) []byte { +func genICMPv4(payload []byte, dst, src netip.Addr) []byte { const ( icmpv4ProtocolNumber = 1 icmpv4Echo = 8 @@ -70,8 +70,8 @@ func genICMPv4(payload []byte, dst, src net.IP) []byte { binary.BigEndian.PutUint16(ip[ipv4TotalLenOffset:], length) ip[8] = ttl ip[9] = icmpv4ProtocolNumber - copy(ip[12:], src.To4()) - copy(ip[16:], dst.To4()) + copy(ip[12:], src.AsSlice()) + copy(ip[16:], dst.AsSlice()) chksum = ^checksum(ip[:], 0) binary.BigEndian.PutUint16(ip[ipv4ChecksumOffset:], chksum) |