diff options
Diffstat (limited to 'pkg/tcpip/config')
-rw-r--r-- | pkg/tcpip/config/config.go | 526 |
1 files changed, 526 insertions, 0 deletions
diff --git a/pkg/tcpip/config/config.go b/pkg/tcpip/config/config.go new file mode 100644 index 000000000..3c1f91ecc --- /dev/null +++ b/pkg/tcpip/config/config.go @@ -0,0 +1,526 @@ +package config + +import ( + "bufio" + "encoding/base64" + "encoding/hex" + "fmt" + "log" + "net" + "os" + "runtime" + "strconv" + "strings" + + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/link/channel" + "gvisor.dev/gvisor/pkg/tcpip/link/fdbased" + "gvisor.dev/gvisor/pkg/tcpip/link/loopback" + "gvisor.dev/gvisor/pkg/tcpip/link/rawfile" + "gvisor.dev/gvisor/pkg/tcpip/link/tun" + "gvisor.dev/gvisor/pkg/tcpip/link/wireguard" + "gvisor.dev/gvisor/pkg/tcpip/network/arp" + "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" + "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" + "gvisor.dev/gvisor/pkg/tcpip/stack" + + "golang.zx2c4.com/wireguard/device" + "golang.zx2c4.com/wireguard/ipc" +// wg_tun "golang.zx2c4.com/wireguard/tun" + + "gopkg.in/yaml.v3" +) + +type Route struct { + To string `yaml:"to"` + Via string `yaml:"via"` + Metric int `yaml:"metric"` + Mark uint32 `yaml:"mask"` + Markmask uint32 `yaml:"markmask"` +} + +type Common struct { + Match struct { + Name string `yaml:"name"` + } `yaml:"match"` + Addresses []string `yaml:"addresses"` + Nameservers struct { + Addresses []string `yaml:"addresses"` + } `yaml:"nameservers"` + Macaddress string `yaml:"macaddress"` + Routes []Route `yaml:"routes"` + Mtu uint32 `yaml:"mtu"` +} + +type Ethernet struct { + Common `yaml:",inline"` +} + +type Tuntap struct { + Common `yaml:",inline"` + Mode string `yaml:"mode"` + Name string `yaml:"name"` +} + +type WireguardKey []byte + +type Routes []tcpip.Route + +func (wgKey *WireguardKey) UnmarshalYAML(value *yaml.Node) error{ + key, err := base64.StdEncoding.DecodeString(value.Value) + fmt.Println("UnmarshalYAML", key, err) + *wgKey = key + return err +} + +type WireguardPeer struct { + PublicKey WireguardKey `yaml:"public_key"` + Endpoint string `yaml:"endpoint"` + AllowedIPs []string `yaml:"allowed_ips"` + PersistentKeepalive int `yaml:"persistent_keepalive"` +} + +func (peer WireguardPeer) String() string{ + return fmt.Sprintf("{PublicKey=%v, Endpoint=%v, AllowedIPs=%v, PersistentKeepalive=%v}", peer.PublicKey, peer.Endpoint, peer.AllowedIPs, peer.PersistentKeepalive) +} + +type Wireguard struct { + Common `yaml:",inline"` + Name string `yaml:"name"` + ListenPort uint16 `yaml:"listen_port"` + PrivateKey WireguardKey `yaml:"private_key"` + Peers []*WireguardPeer `yaml:"peers"` +} + +type Tunnel struct { + Common `yaml:",inline"` + Mode string `yaml:"mode"` + Local string `yaml:"local"` + Remote string `yaml:"remote"` + + Conn *net.UDPConn + Sd *os.File +} + +type Netplan struct { + Network struct { + Version int `yaml:"version"` + Renderer string `yaml:"renderer"` + Ethernets map[string] *Ethernet `yaml:"ethernets"` + Tuntaps map[string] *Tuntap `yaml:"tuntaps"` + Wireguards map[string] *Wireguard `yaml:"wireguards"` + Tunnels map[string] *Tunnel `yaml:"tunnels"` + } `yaml:"network"` +} + +func CheckError(err error) { + if err != nil { + log.Fatal("Error: " , err) + } +} + +func addTunLink(s *stack.Stack, nic tcpip.NICID, id string, tap bool, addr tcpip.LinkAddress, tuntap *Tuntap) { + var err error + + mtu := tuntap.Mtu + tunName := tuntap.Name + + if mtu == 0 { + mtu, err = rawfile.GetMTU(tunName) + if err != nil { + log.Fatal("GetMTU", err) + } + } + + var fd int + if tap { + fd, err = tun.OpenTAP(tunName) + } else { + fd, err = tun.Open(tunName) + } + if err != nil { + log.Fatalf("Open %s %b %s", err, tap, tunName) + } + + linkEP, err := fdbased.New(&fdbased.Options{ + FDs: []int{fd}, + MTU: mtu, + EthernetHeader: tap, + Address: addr, + }) + if err := s.CreateNICWithOptions(nic, linkEP, stack.NICOptions{Name: id, Disabled: true}); err != nil { + log.Fatal("CreateNIC", err) + } + + if tap { + if err := s.AddAddress(nic, arp.ProtocolNumber, arp.ProtocolAddress); err != nil { + log.Fatal("AddAddress", err) + } + } +} + +func addRouterLink(s *stack.Stack, nic tcpip.NICID, id string, addr tcpip.LinkAddress, + tun *Tunnel) { + + ServerAddr,err := net.ResolveUDPAddr("udp", tun.Remote) + CheckError(err) + + LocalAddr, err := net.ResolveUDPAddr("udp", tun.Local) + CheckError(err) + + conn, err := net.DialUDP("udp", LocalAddr, ServerAddr) + CheckError(err) + + tun.Conn = conn + + fmt.Println("Tunnel ", conn) + + sd, err := conn.File() + CheckError(err) + + tun.Sd = sd + + conn.Close() + + var fd int + fd = int(sd.Fd()) + //TestFinialize(&tun.Conn) + runtime.GC() + linkEP, err := fdbased.New(&fdbased.Options{ + FDs: []int{fd}, + MTU: tun.Mtu, + EthernetHeader: true, +// EthernetHeader: false, + Address: addr, + }) + CheckError(err) + + fmt.Println("addRouterLink MTU ", tun.Mtu, tun.Conn) + //channelSize := 128 + // linkEP := channel.New(channelSize, mtu, addr) + if err := s.CreateNICWithOptions(nic, linkEP, stack.NICOptions{Name: id, Disabled: true}); err != nil { + log.Fatal("CreateNIC ", id, err) + } + + if err := s.AddAddress(nic, arp.ProtocolNumber, arp.ProtocolAddress); err != nil { + log.Fatal("AddAddress", err) + } + +// client(linkEP.C) + fmt.Println("Tunnel ", tun.Conn) +} + +func addWgLink(s *stack.Stack, nic tcpip.NICID, interfaceName string, addr tcpip.LinkAddress) *device.Device { + loglevel := device.LogLevelDebug + + chanSize := 1024 + var chanMtu uint32 = 1420 + ep := channel.New(chanSize, chanMtu, addr) + + logger := device.NewLogger(loglevel, "(wg_tunnel) ") + + //mtu := 1500 + // tun, err := wg_tun.CreateTUN(interfaceName, mtu) + tun, err := wireguard.CreateWgTun(s, ep) + if err != nil { + log.Fatal("CreateWgTun", err) + } + + + fileUAPI, err := func() (*os.File, error) { + ENV_WG_UAPI_FD := "WG_UAPI_FD" + uapiFdStr := os.Getenv(ENV_WG_UAPI_FD) + if uapiFdStr == "" { + return ipc.UAPIOpen(interfaceName) + } + + // use supplied fd + + fd, err := strconv.ParseUint(uapiFdStr, 10, 32) + if err != nil { + return nil, err + } + + return os.NewFile(uintptr(fd), ""), nil + }() + + if err != nil { + logger.Error.Println("UAPI listen error:", err) + log.Fatal("Setup failed") + } + // daemonize the process + + device := device.NewDevice(tun, logger) + + errs := make(chan error) + + uapi, err := ipc.UAPIListen(interfaceName, fileUAPI) + if err != nil { + logger.Error.Println("Failed to listen on uapi socket:", err) + log.Fatal("Setup failed") + } + + go func() { + for { + conn, err := uapi.Accept() + if err != nil { + errs <- err + return + } + go device.IpcHandle(conn) + } + }() + + logger.Info.Println("UAPI listener started") + fmt.Println("Device ", device) + + if err := s.CreateNICWithOptions(nic, ep, stack.NICOptions{Name: interfaceName, Disabled: true}); err != nil { + log.Fatal("CreateNIC", err) + } + + return device +} + +func parseAddress(addrName string) (tcpip.Address, tcpip.NetworkProtocolNumber) { + ip := net.ParseIP(addrName) + + if ip.To4() != nil { + return tcpip.Address(ip.To4()), ipv4.ProtocolNumber + } else { + return tcpip.Address(ip.To16()), ipv6.ProtocolNumber + } +} + +func ParseSubnet(subnetName string) (tcpip.Address, tcpip.Subnet, tcpip.NetworkProtocolNumber) { + parsedAddr, parsedNet, err := net.ParseCIDR(subnetName) + if err != nil { + log.Fatalf("Bad IP/CIDR address: %v", subnetName) + } + + var addr tcpip.Address + var net tcpip.Address + var proto tcpip.NetworkProtocolNumber + + if parsedAddr.To4() != nil { + addr = tcpip.Address(parsedAddr.To4()) + net = tcpip.Address(parsedNet.IP.To4()) + proto = ipv4.ProtocolNumber + } else { + addr = tcpip.Address(parsedAddr.To16()) + net = tcpip.Address(parsedNet.IP.To16()) + proto = ipv6.ProtocolNumber + } + + // ones, zeros := parsedNet.Mask.Size() + + mask, err := hex.DecodeString(parsedNet.Mask.String()) + if err != nil { + log.Fatalf("Bad mask", err) + } + + subnet, err := tcpip.NewSubnet(net, tcpip.AddressMask(mask)) + if err != nil { + log.Fatalf("Bad subnet", err, net, parsedNet.Mask.String()) + } + + return addr, subnet, proto +} + +func (routes *Routes) AddAddress(s *stack.Stack, nic tcpip.NICID, addrName string) tcpip.NetworkProtocolNumber { + // // Parse the IP address. Support both ipv4 and ipv6. + addr, subnet, proto := ParseSubnet(addrName) + + if false { + if err := s.AddAddress(nic, proto, addr); err != nil { + log.Fatal("AddAddress", err) + } + } else { + addr := tcpip.ProtocolAddress{ + Protocol: proto, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: addr, + PrefixLen: subnet.Prefix(), + }} + + fmt.Println("Added address ", addr) + + if err := s.AddProtocolAddress(nic, addr); err != nil { + log.Fatalf("AddProtocolAddress", err, nic) + } + + route := tcpip.Route{ + Destination: addr.AddressWithPrefix.Subnet(), + NIC: nic, + } + + *routes = append(*routes, route) + } + + // subnet, err := tcpip.NewSubnet(tcpip.Address(parsedNet.IP), + // tcpip.AddressMask(parsedNet.Mask)) + // if err != nil { + // log.Fatal(err) + // } + + return proto +} + +func (routes *Routes) addRoute(nic tcpip.NICID, routeCfg Route){ + _, dest, _ := ParseSubnet(routeCfg.To) + via, _ := parseAddress(routeCfg.Via) + + route := tcpip.Route{ + Destination: dest, + Gateway: via, + NIC: nic, + Mark: routeCfg.Mark, + Markmask: routeCfg.Markmask, + } + + *routes = append(*routes, route) +} + +func (routes *Routes) setupCommon(s *stack.Stack, nic tcpip.NICID, id string, cfg Common) { + for _, addr := range cfg.Addresses { + routes.AddAddress(s, nic, addr) + } + + for _, route := range cfg.Routes { + fmt.Println("Add Route:", route) + routes.addRoute(nic, route) + } + + for _, route := range *routes { + fmt.Println("Added Route:", route) + } + + // TODO check after enabling the NICs + // if !s.CheckNIC(nic) { + // log.Fatal("not usable ", id) + // } +} + +func (routes *Routes) setupLoopback(s *stack.Stack, nic tcpip.NICID, id string, eth *Ethernet) { + fmt.Println("Ethernet", id, nic, eth) + + linkEP := loopback.New() + if err := s.CreateNICWithOptions(nic, linkEP, stack.NICOptions{Name: id, Disabled: true}); err != nil { + log.Fatal("CreateNIC", err) + } + + routes.setupCommon(s, nic, id, eth.Common) +} + +func (routes *Routes) setupTunnel(s *stack.Stack, nic tcpip.NICID, id string, tun *Tunnel) { + fmt.Println("TUN", id, nic, tun) + + maddr, err := net.ParseMAC(tun.Macaddress) + if err != nil { + log.Fatalf("Bad MAC address: %v", tun.Macaddress) + } + + addRouterLink(s, nic, id, tcpip.LinkAddress(maddr), tun) + fmt.Println("Tunnel 20", tun.Conn) + routes.setupCommon(s, nic, id, tun.Common) + fmt.Println("Tunnel 21", tun.Conn) +} + +func (routes *Routes) setupTuntap(s *stack.Stack, nic tcpip.NICID, id string, tun *Tuntap) { + fmt.Println("Tuntap", id, nic, tun) + + maddr, err := net.ParseMAC(tun.Macaddress) + if err != nil { + log.Fatalf("Bad MAC address: %v", tun.Macaddress) + } + + var tap bool + switch tun.Mode { + case "tun": + tap = false + case "tap": + tap = true + default: + log.Fatalf("Bad mode: %v", tun.Mode) + } + + addTunLink(s, nic, id, tap, tcpip.LinkAddress(maddr), tun) + routes.setupCommon(s, nic, id, tun.Common) +} + +func (routes *Routes) setupWG(s *stack.Stack, nic tcpip.NICID, id string, wg *Wireguard) { + fmt.Println("WG", id, nic, wg.ListenPort, wg) + fmt.Printf("Peers %v\n", wg.Peers) + + maddr, err := net.ParseMAC(wg.Macaddress) + if err != nil { + log.Fatalf("Bad MAC address: %v", wg.Macaddress) + } + + //addTunLink(s, tunNic, tunName, tcpip.LinkAddress(tapMaddr)) + device := addWgLink(s, nic, wg.Name, tcpip.LinkAddress(maddr)) + + var wgCmd strings.Builder + fmt.Fprintf(&wgCmd, "private_key=%s\nlisten_port=%d\nreplace_peers=true\n", + hex.EncodeToString(wg.PrivateKey), wg.ListenPort) + for _, peer := range wg.Peers { + fmt.Fprintf(&wgCmd, "public_key=%s\nendpoint=%s\npersistent_keepalive_interval=%d\nreplace_allowed_ips=true\n", + hex.EncodeToString(peer.PublicKey), peer.Endpoint, peer.PersistentKeepalive) + for _, allowedIp := range peer.AllowedIPs { + fmt.Fprintf(&wgCmd, "allowed_ip=%s\n", allowedIp) + } + wgCmd.WriteString("\n") + } + str := wgCmd.String() + fmt.Println("IpcSetOperation", str) + device.IpcSetOperation(bufio.NewReader(strings.NewReader(str))) + + routes.setupCommon(s, nic, id, wg.Common) + + go func() { + fmt.Println("Starting ", nic) + + select { + // case <-term: + // case <-errs: + case <-device.Wait(): + } + + fmt.Println("Finnished ", nic) + }() +} + +func (routes *Routes) Setup(s *stack.Stack, np *Netplan) { + s.SetForwarding(true) + + var nic tcpip.NICID = -1 +// var wg2Nic tcpip.NICID = -1 + + for id, tun := range np.Network.Ethernets { + nic = nic + 1 + routes.setupLoopback(s, nic, id, tun) + } + + for id, tun := range np.Network.Tuntaps { + nic = nic + 1 + routes.setupTuntap(s, nic, id, tun) + } + + for id, wg := range np.Network.Wireguards { + nic = nic + 1 + // if id == "wg2" { + // wg2Nic = nic + // } + routes.setupWG(s, nic, id, wg) + } + + for id, tun := range np.Network.Tunnels { + nic = nic + 1 + routes.setupTunnel(s, nic, id, tun) + } + + nicCount := nic + + for nic = 0; nic < nicCount; nic++ { + s.EnableNIC(nic) + } +} |