diff options
Diffstat (limited to 'tun/netstack/tun.go')
-rw-r--r-- | tun/netstack/tun.go | 105 |
1 files changed, 34 insertions, 71 deletions
diff --git a/tun/netstack/tun.go b/tun/netstack/tun.go index c26d8ed..b0e7b70 100644 --- a/tun/netstack/tun.go +++ b/tun/netstack/tun.go @@ -6,6 +6,7 @@ package netstack import ( + "bytes" "context" "crypto/rand" "encoding/binary" @@ -23,10 +24,11 @@ import ( "golang.zx2c4.com/wireguard/tun" "golang.org/x/net/dns/dnsmessage" + "gvisor.dev/gvisor/pkg/bufferv2" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/adapters/gonet" - "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/link/channel" "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" "gvisor.dev/gvisor/pkg/tcpip/stack" @@ -37,69 +39,16 @@ import ( ) type netTun struct { + ep *channel.Endpoint stack *stack.Stack - dispatcher stack.NetworkDispatcher events chan tun.Event - incomingPacket chan buffer.VectorisedView + incomingPacket chan *bufferv2.View mtu int dnsServers []netip.Addr hasV4, hasV6 bool } -type ( - endpoint netTun - Net netTun -) - -func (e *endpoint) Attach(dispatcher stack.NetworkDispatcher) { - e.dispatcher = dispatcher -} - -func (e *endpoint) IsAttached() bool { - return e.dispatcher != nil -} - -func (e *endpoint) MTU() uint32 { - mtu, err := (*netTun)(e).MTU() - if err != nil { - panic(err) - } - return uint32(mtu) -} - -func (*endpoint) Capabilities() stack.LinkEndpointCapabilities { - return stack.CapabilityNone -} - -func (*endpoint) MaxHeaderLength() uint16 { - return 0 -} - -func (*endpoint) LinkAddress() tcpip.LinkAddress { - return "" -} - -func (*endpoint) Wait() {} - -func (e *endpoint) WritePacket(_ stack.RouteInfo, _ tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error { - e.incomingPacket <- buffer.NewVectorisedView(pkt.Size(), pkt.Views()) - return nil -} - -func (e *endpoint) WritePackets(stack.RouteInfo, stack.PacketBufferList, tcpip.NetworkProtocolNumber) (int, tcpip.Error) { - panic("not implemented") -} - -func (e *endpoint) WriteRawPacket(*stack.PacketBuffer) tcpip.Error { - panic("not implemented") -} - -func (*endpoint) ARPHardwareType() header.ARPHardwareType { - return header.ARPHardwareNone -} - -func (e *endpoint) AddHeader(tcpip.LinkAddress, tcpip.LinkAddress, tcpip.NetworkProtocolNumber, *stack.PacketBuffer) { -} +type Net netTun func CreateNetTUN(localAddresses, dnsServers []netip.Addr, mtu int) (tun.Device, *Net, error) { opts := stack.Options{ @@ -108,13 +57,15 @@ func CreateNetTUN(localAddresses, dnsServers []netip.Addr, mtu int) (tun.Device, HandleLocal: true, } dev := &netTun{ + ep: channel.New(1024, uint32(mtu), ""), stack: stack.New(opts), events: make(chan tun.Event, 10), - incomingPacket: make(chan buffer.VectorisedView), + incomingPacket: make(chan *bufferv2.View), dnsServers: dnsServers, mtu: mtu, } - tcpipErr := dev.stack.CreateNIC(1, (*endpoint)(dev)) + dev.ep.AddNotify(dev) + tcpipErr := dev.stack.CreateNIC(1, dev.ep) if tcpipErr != nil { return nil, nil, fmt.Errorf("CreateNIC: %v", tcpipErr) } @@ -167,6 +118,7 @@ func (tun *netTun) Read(buf []byte, offset int) (int, error) { if !ok { return 0, os.ErrClosed } + return view.Read(buf[offset:]) } @@ -176,17 +128,29 @@ func (tun *netTun) Write(buf []byte, offset int) (int, error) { return 0, nil } - pkb := stack.NewPacketBuffer(stack.PacketBufferOptions{Data: buffer.NewVectorisedView(len(packet), []buffer.View{buffer.NewViewFromBytes(packet)})}) + pkb := stack.NewPacketBuffer(stack.PacketBufferOptions{Payload: bufferv2.MakeWithData(packet)}) switch packet[0] >> 4 { case 4: - tun.dispatcher.DeliverNetworkPacket("", "", ipv4.ProtocolNumber, pkb) + tun.ep.InjectInbound(header.IPv4ProtocolNumber, pkb) case 6: - tun.dispatcher.DeliverNetworkPacket("", "", ipv6.ProtocolNumber, pkb) + tun.ep.InjectInbound(header.IPv6ProtocolNumber, pkb) } return len(buf), nil } +func (tun *netTun) WriteNotify() { + pkt := tun.ep.Read() + if pkt == nil { + return + } + + view := pkt.ToView() + pkt.DecRef() + + tun.incomingPacket <- view +} + func (tun *netTun) Flush() error { return nil } @@ -197,9 +161,9 @@ func (tun *netTun) Close() error { if tun.events != nil { close(tun.events) } - if tun.incomingPacket != nil { - close(tun.incomingPacket) - } + + tun.ep.Close() + return nil } @@ -434,11 +398,10 @@ func (pc *PingConn) WriteTo(p []byte, addr net.Addr) (n int, err error) { return 0, fmt.Errorf("ping write: mismatched protocols") } - buf := buffer.NewViewFromBytes(p) - rdr := buf.Reader() + buf := bytes.NewReader(p) rfa, _ := convertToFullAddr(netip.AddrPortFrom(na, 0)) // won't block, no deadlines - n64, tcpipErr := pc.ep.Write(&rdr, tcpip.WriteOptions{ + n64, tcpipErr := pc.ep.Write(buf, tcpip.WriteOptions{ To: &rfa, }) if tcpipErr != nil { @@ -453,8 +416,8 @@ func (pc *PingConn) Write(p []byte) (n int, err error) { } func (pc *PingConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) { - e, notifyCh := waiter.NewChannelEntry(nil) - pc.wq.EventRegister(&e, waiter.EventIn) + e, notifyCh := waiter.NewChannelEntry(waiter.EventIn) + pc.wq.EventRegister(&e) defer pc.wq.EventUnregister(&e) select { @@ -488,7 +451,7 @@ func (pc *PingConn) SetDeadline(t time.Time) error { } func (pc *PingConn) SetReadDeadline(t time.Time) error { - pc.deadline.Reset(t.Sub(time.Now())) + pc.deadline.Reset(time.Until(t)) return nil } |