diff options
Diffstat (limited to 'pkg/tcpip/link/tunnel/gre.go')
-rw-r--r-- | pkg/tcpip/link/tunnel/gre.go | 153 |
1 files changed, 153 insertions, 0 deletions
diff --git a/pkg/tcpip/link/tunnel/gre.go b/pkg/tcpip/link/tunnel/gre.go new file mode 100644 index 000000000..8f6f1649a --- /dev/null +++ b/pkg/tcpip/link/tunnel/gre.go @@ -0,0 +1,153 @@ +package tunnel + +import ( + "context" + "log" + + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/link/channel" + "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" + "gvisor.dev/gvisor/pkg/tcpip/stack" + "gvisor.dev/gvisor/pkg/tcpip/transport/gre" + "gvisor.dev/gvisor/pkg/tcpip/transport/raw" + "gvisor.dev/gvisor/pkg/waiter" +) + +var SIZE = 16 + +type Endpoint struct { + channel.Endpoint +} + +type writer Endpoint + +func New(mtu uint32) *Endpoint { + var linkAddress tcpip.LinkAddress + + ch := channel.New(SIZE, mtu, linkAddress) + return &Endpoint{ + Endpoint: *ch, + } +} + +func (e *Endpoint) GetChannel() *channel.Endpoint { + return &e.Endpoint +} + +// Attach saves the stack network-layer dispatcher for use later when packets +// are injected. +func (e *Endpoint) Attach(dispatcher stack.NetworkDispatcher) { + log.Println("GRE: Attach") + + e.Endpoint.Attach(dispatcher) +} + +type GrePacketInfo struct { + pi *channel.PacketInfo +} + +func (g GrePacketInfo) FullPayload() ([]byte, *tcpip.Error){ + log.Println("FullPayload") + pkt := g.pi.Pkt + size := pkt.Size() + vv := buffer.NewVectorisedView(size, pkt.Views()) + var buf = make([]byte, size, size) + pos := 0 + for { + copied, err := vv.Read(buf[pos:]) + log.Printf("VectorisedView Read: %d %d %d %v", size, pos, copied, err) + if err != nil { + return nil, tcpip.ErrBadBuffer + } + pos = pos + copied + if pos == size { + break + } + } + + log.Printf("FullPayload return: %d %v", len(buf), buf) + return buf, nil +} + +func (g GrePacketInfo) Payload(size int) ([]byte, *tcpip.Error){ + log.Println("Payload") + return nil, tcpip.ErrNotSupported +} + +type GreHandlerInfo struct { + ChEp *channel.Endpoint + Raw tcpip.Endpoint + Raddr tcpip.FullAddress +} + +func (info *GreHandlerInfo) greHandler(req *gre.ForwarderRequest) { + pkt := req.Pkt + log.Println("greHandler: ", req.Pkt.Size(), req.Pkt.Views()) + greHdr := header.GRE(pkt.TransportHeader().View()) + proto := greHdr.ProtocolType() + views := pkt.Data.Views() + size := pkt.Data.Size() + data := buffer.NewVectorisedView(size, views) + newPkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: data, + }) + log.Printf("greHandler proto: %d cloned: %v", proto, newPkt.Views()) + + info.ChEp.InjectInbound(proto, newPkt) +} + +func (info *GreHandlerInfo) greRead(ep *channel.Endpoint) { + for { + pi, err := ep.ReadContext(context.Background()); + linkHdr := pi.Pkt.LinkHeader() + greHdr := header.GRE(linkHdr.Push(header.GREHeaderSize)) + greHdr.SetProtocolType(pi.Proto) + log.Printf("greRead %d %v %v %v", pi.Proto, pi, err, greHdr) + opts := tcpip.WriteOptions{ + //To: &info.Raddr + } + info.Raw.Write(GrePacketInfo{&pi}, opts) + } +} + +func (e *Endpoint) Start(s *stack.Stack, laddr, raddr *tcpip.Address) { + // Create TCP endpoint. + var rawWq waiter.Queue + rawEp, tcperr := raw.NewEndpoint(s, ipv4.ProtocolNumber, header.GREProtocolNumber, &rawWq) + if tcperr != nil { + log.Fatal(tcperr) + } + log.Println("EP: %s", rawEp) + + fraddr := tcpip.FullAddress{NIC: 1, Addr: *raddr} + flAddr := tcpip.FullAddress{NIC: 1, Addr: *laddr} + tcperr = rawEp.Bind(flAddr) + if tcperr != nil { + log.Fatal(tcperr) + } + log.Printf("Remote: %v %v", raddr, fraddr) + + // Create GRE + // greEP := grelink.New(mtu - 24) + chEP := e.GetChannel() + // TODO detect IPv4/IPv6 + e.NMaxHeaderLength = header.IPv6FixedHeaderSize + header.GREHeaderSize + + greInfo := GreHandlerInfo{ +// Ep: loEP, + ChEp: chEP, + Raw: rawEp, + Raddr: fraddr, + } + greFwd := gre.NewForwarder(s, greInfo.greHandler) + s.SetTransportProtocolHandler(header.GREProtocolNumber, greFwd.HandlePacket) + + go greInfo.greRead(chEP) + + tcperr = rawEp.Connect(fraddr) + if tcperr != nil { + log.Fatal(tcperr) + } +} |