summaryrefslogtreecommitdiffhomepage
path: root/tun/netstack/tun.go
diff options
context:
space:
mode:
authorMikael Magnusson <mikma@users.sourceforge.net>2021-01-13 23:31:35 +0100
committerMikael Magnusson <mikma@users.sourceforge.net>2021-11-09 23:11:19 +0100
commitc9f6cf925916b1f2553e24b2f629b43b969505e7 (patch)
tree6d50978cf81bce08e8533d06c434a4411ced2e2d /tun/netstack/tun.go
parentebef2c3fe37b723843fb84695a25df2d4db40038 (diff)
netstack: add CreateNetTUNWithStack
Allow the caller to specify the stack to make it possible for more complex scenarios with multiple network interfaces. Signed-off-by: Mikael Magnusson <mikma@users.sourceforge.net>
Diffstat (limited to 'tun/netstack/tun.go')
-rw-r--r--tun/netstack/tun.go68
1 files changed, 52 insertions, 16 deletions
diff --git a/tun/netstack/tun.go b/tun/netstack/tun.go
index f61cc2a..8e3ce4b 100644
--- a/tun/netstack/tun.go
+++ b/tun/netstack/tun.go
@@ -38,13 +38,22 @@ type netTun struct {
events chan tun.Event
incomingPacket chan buffer.VectorisedView
mtu int
+}
+type endpoint netTun
+type Net struct {
+ stack *stack.Stack
+ nicID tcpip.NICID
dnsServers []net.IP
hasV4, hasV6 bool
}
-type endpoint netTun
-type Net netTun
func (e *endpoint) Attach(dispatcher stack.NetworkDispatcher) {
+ if dispatcher == nil {
+ (*netTun)(e).events <- tun.EventUp
+ } else {
+ (*netTun)(e).events <- tun.EventDown
+ }
+
e.dispatcher = dispatcher
}
@@ -100,10 +109,10 @@ func (net *Net) addAddress(ip net.IP) tcpip.Error {
Protocol: ipv4.ProtocolNumber,
AddressWithPrefix: tcpip.Address(ip4).WithPrefix(),
}
- tcpipErr := net.stack.AddProtocolAddress(1, protoAddr, stack.AddressProperties{})
+ tcpipErr := net.stack.AddProtocolAddress(net.nicID, protoAddr, stack.AddressProperties{})
if tcpipErr == nil && !net.hasV4 {
net.hasV4 = true
- net.stack.AddRoute(tcpip.Route{Destination: header.IPv4EmptySubnet, NIC: 1})
+ net.stack.AddRoute(tcpip.Route{Destination: header.IPv4EmptySubnet, NIC: net.nicID})
}
return tcpipErr
} else {
@@ -111,10 +120,10 @@ func (net *Net) addAddress(ip net.IP) tcpip.Error {
Protocol: ipv6.ProtocolNumber,
AddressWithPrefix: tcpip.Address(ip).WithPrefix(),
}
- tcpipErr := net.stack.AddProtocolAddress(1, protoAddr, stack.AddressProperties{})
+ tcpipErr := net.stack.AddProtocolAddress(net.nicID, protoAddr, stack.AddressProperties{})
if tcpipErr == nil && !net.hasV6{
net.hasV6 = true
- net.stack.AddRoute(tcpip.Route{Destination: header.IPv6EmptySubnet, NIC: 1})
+ net.stack.AddRoute(tcpip.Route{Destination: header.IPv6EmptySubnet, NIC: net.nicID})
}
return tcpipErr
}
@@ -126,26 +135,52 @@ func CreateNetTUN(localAddresses, dnsServers []net.IP, mtu int) (tun.Device, *Ne
TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol, udp.NewProtocol},
HandleLocal: true,
}
- dev := &netTun{
- stack: stack.New(opts),
- events: make(chan tun.Event, 10),
- incomingPacket: make(chan buffer.VectorisedView),
- dnsServers: dnsServers,
- mtu: mtu,
+ return CreateNetTUNWithStack(stack.New(opts), 1, localAddresses, dnsServers, mtu)
+}
+
+func CreateNetTUNWithStack(stack *stack.Stack, nicID tcpip.NICID, localAddresses []net.IP, dnsServers []net.IP, mtu int) (tun.Device, *Net, error) {
+ dev, ep, err := NewNetTUN(stack, mtu)
+ if err != nil {
+ return nil, nil, err
}
- tcpipErr := dev.stack.CreateNIC(1, (*endpoint)(dev))
+
+ tcpipErr := stack.CreateNIC(nicID, ep)
if tcpipErr != nil {
return nil, nil, fmt.Errorf("CreateNIC: %v", tcpipErr)
}
+
+ tnet, err := NewNetAdapter(stack, nicID, localAddresses, dnsServers)
+ if err != nil {
+ return nil, nil, err
+ }
+
+ return dev, tnet, nil
+}
+
+func NewNetAdapter(stack *stack.Stack, nicID tcpip.NICID, localAddresses []net.IP, dnsServers []net.IP,) (*Net, error) {
+ tnet := &Net{
+ stack: stack,
+ nicID: nicID,
+ dnsServers: dnsServers,
+ }
for _, ip := range localAddresses {
tcpipErr := tnet.addAddress(ip)
if tcpipErr != nil {
return nil, fmt.Errorf("addAddress(%v): %w", ip, tcpipErr)
}
}
+ return tnet, nil
+}
+
+func NewNetTUN(stack *stack.Stack, mtu int) (*netTun, *endpoint, error) {
+ dev := &netTun{
+ stack: stack,
+ events: make(chan tun.Event, 10),
+ incomingPacket: make(chan buffer.VectorisedView),
+ mtu: mtu,
+ }
- dev.events <- tun.EventUp
- return dev, (*Net)(dev), nil
+ return dev, (*endpoint)(dev), nil
}
func (tun *netTun) Name() (string, error) {
@@ -190,7 +225,8 @@ func (tun *netTun) Flush() error {
}
func (tun *netTun) Close() error {
- tun.stack.RemoveNIC(1)
+// FIXME move somewhere else
+// tun.stack.RemoveNIC(tun.nicID)
if tun.events != nil {
close(tun.events)