summaryrefslogtreecommitdiffhomepage
path: root/tun/netstack/tun.go
diff options
context:
space:
mode:
Diffstat (limited to 'tun/netstack/tun.go')
-rw-r--r--tun/netstack/tun.go105
1 files changed, 34 insertions, 71 deletions
diff --git a/tun/netstack/tun.go b/tun/netstack/tun.go
index c26d8ed..b0e7b70 100644
--- a/tun/netstack/tun.go
+++ b/tun/netstack/tun.go
@@ -6,6 +6,7 @@
package netstack
import (
+ "bytes"
"context"
"crypto/rand"
"encoding/binary"
@@ -23,10 +24,11 @@ import (
"golang.zx2c4.com/wireguard/tun"
"golang.org/x/net/dns/dnsmessage"
+ "gvisor.dev/gvisor/pkg/bufferv2"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
- "gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/link/channel"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
"gvisor.dev/gvisor/pkg/tcpip/stack"
@@ -37,69 +39,16 @@ import (
)
type netTun struct {
+ ep *channel.Endpoint
stack *stack.Stack
- dispatcher stack.NetworkDispatcher
events chan tun.Event
- incomingPacket chan buffer.VectorisedView
+ incomingPacket chan *bufferv2.View
mtu int
dnsServers []netip.Addr
hasV4, hasV6 bool
}
-type (
- endpoint netTun
- Net netTun
-)
-
-func (e *endpoint) Attach(dispatcher stack.NetworkDispatcher) {
- e.dispatcher = dispatcher
-}
-
-func (e *endpoint) IsAttached() bool {
- return e.dispatcher != nil
-}
-
-func (e *endpoint) MTU() uint32 {
- mtu, err := (*netTun)(e).MTU()
- if err != nil {
- panic(err)
- }
- return uint32(mtu)
-}
-
-func (*endpoint) Capabilities() stack.LinkEndpointCapabilities {
- return stack.CapabilityNone
-}
-
-func (*endpoint) MaxHeaderLength() uint16 {
- return 0
-}
-
-func (*endpoint) LinkAddress() tcpip.LinkAddress {
- return ""
-}
-
-func (*endpoint) Wait() {}
-
-func (e *endpoint) WritePacket(_ stack.RouteInfo, _ tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error {
- e.incomingPacket <- buffer.NewVectorisedView(pkt.Size(), pkt.Views())
- return nil
-}
-
-func (e *endpoint) WritePackets(stack.RouteInfo, stack.PacketBufferList, tcpip.NetworkProtocolNumber) (int, tcpip.Error) {
- panic("not implemented")
-}
-
-func (e *endpoint) WriteRawPacket(*stack.PacketBuffer) tcpip.Error {
- panic("not implemented")
-}
-
-func (*endpoint) ARPHardwareType() header.ARPHardwareType {
- return header.ARPHardwareNone
-}
-
-func (e *endpoint) AddHeader(tcpip.LinkAddress, tcpip.LinkAddress, tcpip.NetworkProtocolNumber, *stack.PacketBuffer) {
-}
+type Net netTun
func CreateNetTUN(localAddresses, dnsServers []netip.Addr, mtu int) (tun.Device, *Net, error) {
opts := stack.Options{
@@ -108,13 +57,15 @@ func CreateNetTUN(localAddresses, dnsServers []netip.Addr, mtu int) (tun.Device,
HandleLocal: true,
}
dev := &netTun{
+ ep: channel.New(1024, uint32(mtu), ""),
stack: stack.New(opts),
events: make(chan tun.Event, 10),
- incomingPacket: make(chan buffer.VectorisedView),
+ incomingPacket: make(chan *bufferv2.View),
dnsServers: dnsServers,
mtu: mtu,
}
- tcpipErr := dev.stack.CreateNIC(1, (*endpoint)(dev))
+ dev.ep.AddNotify(dev)
+ tcpipErr := dev.stack.CreateNIC(1, dev.ep)
if tcpipErr != nil {
return nil, nil, fmt.Errorf("CreateNIC: %v", tcpipErr)
}
@@ -167,6 +118,7 @@ func (tun *netTun) Read(buf []byte, offset int) (int, error) {
if !ok {
return 0, os.ErrClosed
}
+
return view.Read(buf[offset:])
}
@@ -176,17 +128,29 @@ func (tun *netTun) Write(buf []byte, offset int) (int, error) {
return 0, nil
}
- pkb := stack.NewPacketBuffer(stack.PacketBufferOptions{Data: buffer.NewVectorisedView(len(packet), []buffer.View{buffer.NewViewFromBytes(packet)})})
+ pkb := stack.NewPacketBuffer(stack.PacketBufferOptions{Payload: bufferv2.MakeWithData(packet)})
switch packet[0] >> 4 {
case 4:
- tun.dispatcher.DeliverNetworkPacket("", "", ipv4.ProtocolNumber, pkb)
+ tun.ep.InjectInbound(header.IPv4ProtocolNumber, pkb)
case 6:
- tun.dispatcher.DeliverNetworkPacket("", "", ipv6.ProtocolNumber, pkb)
+ tun.ep.InjectInbound(header.IPv6ProtocolNumber, pkb)
}
return len(buf), nil
}
+func (tun *netTun) WriteNotify() {
+ pkt := tun.ep.Read()
+ if pkt == nil {
+ return
+ }
+
+ view := pkt.ToView()
+ pkt.DecRef()
+
+ tun.incomingPacket <- view
+}
+
func (tun *netTun) Flush() error {
return nil
}
@@ -197,9 +161,9 @@ func (tun *netTun) Close() error {
if tun.events != nil {
close(tun.events)
}
- if tun.incomingPacket != nil {
- close(tun.incomingPacket)
- }
+
+ tun.ep.Close()
+
return nil
}
@@ -434,11 +398,10 @@ func (pc *PingConn) WriteTo(p []byte, addr net.Addr) (n int, err error) {
return 0, fmt.Errorf("ping write: mismatched protocols")
}
- buf := buffer.NewViewFromBytes(p)
- rdr := buf.Reader()
+ buf := bytes.NewReader(p)
rfa, _ := convertToFullAddr(netip.AddrPortFrom(na, 0))
// won't block, no deadlines
- n64, tcpipErr := pc.ep.Write(&rdr, tcpip.WriteOptions{
+ n64, tcpipErr := pc.ep.Write(buf, tcpip.WriteOptions{
To: &rfa,
})
if tcpipErr != nil {
@@ -453,8 +416,8 @@ func (pc *PingConn) Write(p []byte) (n int, err error) {
}
func (pc *PingConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
- e, notifyCh := waiter.NewChannelEntry(nil)
- pc.wq.EventRegister(&e, waiter.EventIn)
+ e, notifyCh := waiter.NewChannelEntry(waiter.EventIn)
+ pc.wq.EventRegister(&e)
defer pc.wq.EventUnregister(&e)
select {
@@ -488,7 +451,7 @@ func (pc *PingConn) SetDeadline(t time.Time) error {
}
func (pc *PingConn) SetReadDeadline(t time.Time) error {
- pc.deadline.Reset(t.Sub(time.Now()))
+ pc.deadline.Reset(time.Until(t))
return nil
}