diff options
Diffstat (limited to 'tun/netstack/tun.go')
-rw-r--r-- | tun/netstack/tun.go | 291 |
1 files changed, 267 insertions, 24 deletions
diff --git a/tun/netstack/tun.go b/tun/netstack/tun.go index fb7f07d..97983b4 100644 --- a/tun/netstack/tun.go +++ b/tun/netstack/tun.go @@ -14,8 +14,10 @@ import ( "io" "net" "os" + "regexp" "strconv" "strings" + "sync" "time" "golang.zx2c4.com/go118/netip" @@ -29,8 +31,10 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" "gvisor.dev/gvisor/pkg/tcpip/stack" + "gvisor.dev/gvisor/pkg/tcpip/transport/icmp" "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" "gvisor.dev/gvisor/pkg/tcpip/transport/udp" + "gvisor.dev/gvisor/pkg/waiter" ) type netTun struct { @@ -101,7 +105,7 @@ func (e *endpoint) AddHeader(tcpip.LinkAddress, tcpip.LinkAddress, tcpip.Network func CreateNetTUN(localAddresses, dnsServers []netip.Addr, mtu int) (tun.Device, *Net, error) { opts := stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol}, - TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol, udp.NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol, udp.NewProtocol, icmp.NewProtocol6, icmp.NewProtocol4}, HandleLocal: true, } dev := &netTun{ @@ -270,6 +274,10 @@ func (net *Net) DialUDPAddrPort(laddr, raddr netip.AddrPort) (*gonet.UDPConn, er return gonet.DialUDP(net.stack, lfa, rfa, pn) } +func (net *Net) ListenUDPAddrPort(laddr netip.AddrPort) (*gonet.UDPConn, error) { + return net.DialUDPAddrPort(laddr, netip.AddrPort{}) +} + func (net *Net) DialUDP(laddr, raddr *net.UDPAddr) (*gonet.UDPConn, error) { var la, ra netip.AddrPort if laddr != nil { @@ -281,6 +289,233 @@ func (net *Net) DialUDP(laddr, raddr *net.UDPAddr) (*gonet.UDPConn, error) { return net.DialUDPAddrPort(la, ra) } +func (net *Net) ListenUDP(laddr *net.UDPAddr) (*gonet.UDPConn, error) { + return net.DialUDP(laddr, nil) +} + +type PingConn struct { + laddr PingAddr + raddr PingAddr + wq waiter.Queue + ep tcpip.Endpoint + mu sync.RWMutex + deadline time.Time + deadlineBreaker chan struct{} +} + +type PingAddr struct{ addr netip.Addr } + +func (ia PingAddr) String() string { + return ia.addr.String() +} + +func (ia PingAddr) Network() string { + if ia.addr.Is4() { + return "ping4" + } else if ia.addr.Is6() { + return "ping6" + } + return "ping" +} + +func (ia PingAddr) Addr() netip.Addr { + return ia.addr +} + +func PingAddrFromAddr(addr netip.Addr) *PingAddr { + return &PingAddr{addr} +} + +func (net *Net) DialPingAddr(laddr, raddr netip.Addr) (*PingConn, error) { + if !laddr.IsValid() && !raddr.IsValid() { + return nil, errors.New("ping dial: invalid address") + } + v6 := laddr.Is6() || raddr.Is6() + bind := laddr.IsValid() + if !bind { + if v6 { + laddr = netip.IPv6Unspecified() + } else { + laddr = netip.IPv4Unspecified() + } + } + + tn := icmp.ProtocolNumber4 + pn := ipv4.ProtocolNumber + if v6 { + tn = icmp.ProtocolNumber6 + pn = ipv6.ProtocolNumber + } + + pc := &PingConn{ + laddr: PingAddr{laddr}, + deadlineBreaker: make(chan struct{}, 1), + } + + ep, tcpipErr := net.stack.NewEndpoint(tn, pn, &pc.wq) + if tcpipErr != nil { + return nil, fmt.Errorf("ping socket: endpoint: %s", tcpipErr) + } + pc.ep = ep + + if bind { + fa, _ := convertToFullAddr(netip.AddrPortFrom(laddr, 0)) + if tcpipErr = pc.ep.Bind(fa); tcpipErr != nil { + return nil, fmt.Errorf("ping bind: %s", tcpipErr) + } + } + + if raddr.IsValid() { + pc.raddr = PingAddr{raddr} + fa, _ := convertToFullAddr(netip.AddrPortFrom(raddr, 0)) + if tcpipErr = pc.ep.Connect(fa); tcpipErr != nil { + return nil, fmt.Errorf("ping connect: %s", tcpipErr) + } + } + + return pc, nil +} + +func (net *Net) ListenPingAddr(laddr netip.Addr) (*PingConn, error) { + return net.DialPingAddr(laddr, netip.Addr{}) +} + +func (net *Net) DialPing(laddr, raddr *PingAddr) (*PingConn, error) { + var la, ra netip.Addr + if laddr != nil { + la = laddr.addr + } + if raddr != nil { + ra = raddr.addr + } + return net.DialPingAddr(la, ra) +} + +func (net *Net) ListenPing(laddr *PingAddr) (*PingConn, error) { + var la netip.Addr + if laddr != nil { + la = laddr.addr + } + return net.ListenPingAddr(la) +} + +func (pc *PingConn) LocalAddr() net.Addr { + return pc.laddr +} + +func (pc *PingConn) RemoteAddr() net.Addr { + return pc.raddr +} + +func (pc *PingConn) Close() error { + close(pc.deadlineBreaker) + pc.ep.Close() + return nil +} + +func (pc *PingConn) SetWriteDeadline(t time.Time) error { + return errors.New("not implemented") +} + +func (pc *PingConn) WriteTo(p []byte, addr net.Addr) (n int, err error) { + var na netip.Addr + switch v := addr.(type) { + case *PingAddr: + na = v.addr + case *net.IPAddr: + na = netip.AddrFromSlice(v.IP) + default: + return 0, fmt.Errorf("ping write: wrong net.Addr type") + } + if !((na.Is4() && pc.laddr.addr.Is4()) || (na.Is6() && pc.laddr.addr.Is6())) { + return 0, fmt.Errorf("ping write: mismatched protocols") + } + + buf := buffer.NewViewFromBytes(p) + rdr := buf.Reader() + rfa, _ := convertToFullAddr(netip.AddrPortFrom(na, 0)) + // won't block, no deadlines + n64, tcpipErr := pc.ep.Write(&rdr, tcpip.WriteOptions{ + To: &rfa, + }) + if tcpipErr != nil { + return int(n64), fmt.Errorf("ping write: %s", tcpipErr) + } + + return int(n64), nil +} + +func (pc *PingConn) Write(p []byte) (n int, err error) { + return pc.WriteTo(p, &pc.raddr) +} + +func (pc *PingConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) { + e, notifyCh := waiter.NewChannelEntry(nil) + pc.wq.EventRegister(&e, waiter.EventIn) + defer pc.wq.EventUnregister(&e) + + ready := false + + for !ready { + pc.mu.RLock() + deadlineBreaker := pc.deadlineBreaker + deadline := pc.deadline + pc.mu.RUnlock() + + if deadline.IsZero() { + select { + case <-deadlineBreaker: + case <-notifyCh: + ready = true + } + } else { + t := time.NewTimer(deadline.Sub(time.Now())) + defer t.Stop() + + select { + case <-t.C: + return 0, nil, os.ErrDeadlineExceeded + + case <-deadlineBreaker: + case <-notifyCh: + ready = true + } + } + } + + w := tcpip.SliceWriter(p) + + res, tcpipErr := pc.ep.Read(&w, tcpip.ReadOptions{ + NeedRemoteAddr: true, + }) + if tcpipErr != nil { + return 0, nil, fmt.Errorf("ping read: %s", tcpipErr) + } + + addr = &PingAddr{netip.AddrFromSlice([]byte(res.RemoteAddr.Addr))} + return res.Count, addr, nil +} + +func (pc *PingConn) Read(p []byte) (n int, err error) { + n, _, err = pc.ReadFrom(p) + return +} + +func (pc *PingConn) SetDeadline(t time.Time) error { + // pc.SetWriteDeadline is unimplemented + + return pc.SetReadDeadline(t) +} + +func (pc *PingConn) SetReadDeadline(t time.Time) error { + pc.mu.Lock() + defer pc.mu.Unlock() + close(pc.deadlineBreaker) + pc.deadlineBreaker = make(chan struct{}, 1) + pc.deadline = t + return nil +} + var ( errNoSuchHost = errors.New("no such host") errLameReferral = errors.New("lame referral") @@ -755,33 +990,38 @@ func partialDeadline(now, deadline time.Time, addrsRemaining int) (time.Time, er return now.Add(timeout), nil } +var protoSplitter = regexp.MustCompile(`^(tcp|udp|ping)(4|6)?$`) + func (tnet *Net) DialContext(ctx context.Context, network, address string) (net.Conn, error) { if ctx == nil { panic("nil context") } - var acceptV4, acceptV6, useUDP bool - if len(network) == 3 { + var acceptV4, acceptV6 bool + matches := protoSplitter.FindStringSubmatch(network) + if matches == nil { + return nil, &net.OpError{Op: "dial", Err: net.UnknownNetworkError(network)} + } else if len(matches[2]) == 0 { acceptV4 = true acceptV6 = true - } else if len(network) == 4 { - acceptV4 = network[3] == '4' - acceptV6 = network[3] == '6' - } - if !acceptV4 && !acceptV6 { - return nil, &net.OpError{Op: "dial", Err: net.UnknownNetworkError(network)} - } - if network[:3] == "udp" { - useUDP = true - } else if network[:3] != "tcp" { - return nil, &net.OpError{Op: "dial", Err: net.UnknownNetworkError(network)} - } - host, sport, err := net.SplitHostPort(address) - if err != nil { - return nil, &net.OpError{Op: "dial", Err: err} + } else { + acceptV4 = matches[2][0] == '4' + acceptV6 = !acceptV4 } - port, err := strconv.Atoi(sport) - if err != nil || port < 0 || port > 65535 { - return nil, &net.OpError{Op: "dial", Err: errNumericPort} + var host string + var port int + if matches[1] == "ping" { + host = address + } else { + var sport string + var err error + host, sport, err = net.SplitHostPort(address) + if err != nil { + return nil, &net.OpError{Op: "dial", Err: err} + } + port, err = strconv.Atoi(sport) + if err != nil || port < 0 || port > 65535 { + return nil, &net.OpError{Op: "dial", Err: errNumericPort} + } } allAddr, err := tnet.LookupContextHost(ctx, host) if err != nil { @@ -829,10 +1069,13 @@ func (tnet *Net) DialContext(ctx context.Context, network, address string) (net. } var c net.Conn - if useUDP { - c, err = tnet.DialUDPAddrPort(netip.AddrPort{}, addr) - } else { + switch matches[1] { + case "tcp": c, err = tnet.DialContextTCPAddrPort(dialCtx, addr) + case "udp": + c, err = tnet.DialUDPAddrPort(netip.AddrPort{}, addr) + case "ping": + c, err = tnet.DialPingAddr(netip.Addr{}, addr.Addr()) } if err == nil { return c, nil |