summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--tun/netstack/tun.go68
1 files changed, 52 insertions, 16 deletions
diff --git a/tun/netstack/tun.go b/tun/netstack/tun.go
index f61cc2a..8e3ce4b 100644
--- a/tun/netstack/tun.go
+++ b/tun/netstack/tun.go
@@ -38,13 +38,22 @@ type netTun struct {
events chan tun.Event
incomingPacket chan buffer.VectorisedView
mtu int
+}
+type endpoint netTun
+type Net struct {
+ stack *stack.Stack
+ nicID tcpip.NICID
dnsServers []net.IP
hasV4, hasV6 bool
}
-type endpoint netTun
-type Net netTun
func (e *endpoint) Attach(dispatcher stack.NetworkDispatcher) {
+ if dispatcher == nil {
+ (*netTun)(e).events <- tun.EventUp
+ } else {
+ (*netTun)(e).events <- tun.EventDown
+ }
+
e.dispatcher = dispatcher
}
@@ -100,10 +109,10 @@ func (net *Net) addAddress(ip net.IP) tcpip.Error {
Protocol: ipv4.ProtocolNumber,
AddressWithPrefix: tcpip.Address(ip4).WithPrefix(),
}
- tcpipErr := net.stack.AddProtocolAddress(1, protoAddr, stack.AddressProperties{})
+ tcpipErr := net.stack.AddProtocolAddress(net.nicID, protoAddr, stack.AddressProperties{})
if tcpipErr == nil && !net.hasV4 {
net.hasV4 = true
- net.stack.AddRoute(tcpip.Route{Destination: header.IPv4EmptySubnet, NIC: 1})
+ net.stack.AddRoute(tcpip.Route{Destination: header.IPv4EmptySubnet, NIC: net.nicID})
}
return tcpipErr
} else {
@@ -111,10 +120,10 @@ func (net *Net) addAddress(ip net.IP) tcpip.Error {
Protocol: ipv6.ProtocolNumber,
AddressWithPrefix: tcpip.Address(ip).WithPrefix(),
}
- tcpipErr := net.stack.AddProtocolAddress(1, protoAddr, stack.AddressProperties{})
+ tcpipErr := net.stack.AddProtocolAddress(net.nicID, protoAddr, stack.AddressProperties{})
if tcpipErr == nil && !net.hasV6{
net.hasV6 = true
- net.stack.AddRoute(tcpip.Route{Destination: header.IPv6EmptySubnet, NIC: 1})
+ net.stack.AddRoute(tcpip.Route{Destination: header.IPv6EmptySubnet, NIC: net.nicID})
}
return tcpipErr
}
@@ -126,26 +135,52 @@ func CreateNetTUN(localAddresses, dnsServers []net.IP, mtu int) (tun.Device, *Ne
TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol, udp.NewProtocol},
HandleLocal: true,
}
- dev := &netTun{
- stack: stack.New(opts),
- events: make(chan tun.Event, 10),
- incomingPacket: make(chan buffer.VectorisedView),
- dnsServers: dnsServers,
- mtu: mtu,
+ return CreateNetTUNWithStack(stack.New(opts), 1, localAddresses, dnsServers, mtu)
+}
+
+func CreateNetTUNWithStack(stack *stack.Stack, nicID tcpip.NICID, localAddresses []net.IP, dnsServers []net.IP, mtu int) (tun.Device, *Net, error) {
+ dev, ep, err := NewNetTUN(stack, mtu)
+ if err != nil {
+ return nil, nil, err
}
- tcpipErr := dev.stack.CreateNIC(1, (*endpoint)(dev))
+
+ tcpipErr := stack.CreateNIC(nicID, ep)
if tcpipErr != nil {
return nil, nil, fmt.Errorf("CreateNIC: %v", tcpipErr)
}
+
+ tnet, err := NewNetAdapter(stack, nicID, localAddresses, dnsServers)
+ if err != nil {
+ return nil, nil, err
+ }
+
+ return dev, tnet, nil
+}
+
+func NewNetAdapter(stack *stack.Stack, nicID tcpip.NICID, localAddresses []net.IP, dnsServers []net.IP,) (*Net, error) {
+ tnet := &Net{
+ stack: stack,
+ nicID: nicID,
+ dnsServers: dnsServers,
+ }
for _, ip := range localAddresses {
tcpipErr := tnet.addAddress(ip)
if tcpipErr != nil {
return nil, fmt.Errorf("addAddress(%v): %w", ip, tcpipErr)
}
}
+ return tnet, nil
+}
+
+func NewNetTUN(stack *stack.Stack, mtu int) (*netTun, *endpoint, error) {
+ dev := &netTun{
+ stack: stack,
+ events: make(chan tun.Event, 10),
+ incomingPacket: make(chan buffer.VectorisedView),
+ mtu: mtu,
+ }
- dev.events <- tun.EventUp
- return dev, (*Net)(dev), nil
+ return dev, (*endpoint)(dev), nil
}
func (tun *netTun) Name() (string, error) {
@@ -190,7 +225,8 @@ func (tun *netTun) Flush() error {
}
func (tun *netTun) Close() error {
- tun.stack.RemoveNIC(1)
+// FIXME move somewhere else
+// tun.stack.RemoveNIC(tun.nicID)
if tun.events != nil {
close(tun.events)