From d02b74a5dcfed4bfc8f2f8e545bca4d2afabb296 Mon Sep 17 00:00:00 2001 From: Googler Date: Fri, 27 Apr 2018 10:37:02 -0700 Subject: Check in gVisor. PiperOrigin-RevId: 194583126 Change-Id: Ica1d8821a90f74e7e745962d71801c598c652463 --- pkg/tcpip/link/channel/BUILD | 15 + pkg/tcpip/link/channel/channel.go | 110 +++++ pkg/tcpip/link/fdbased/BUILD | 32 ++ pkg/tcpip/link/fdbased/endpoint.go | 261 ++++++++++ pkg/tcpip/link/fdbased/endpoint_test.go | 336 +++++++++++++ pkg/tcpip/link/loopback/BUILD | 15 + pkg/tcpip/link/loopback/loopback.go | 74 +++ pkg/tcpip/link/rawfile/BUILD | 17 + pkg/tcpip/link/rawfile/blockingpoll_amd64.s | 26 + pkg/tcpip/link/rawfile/errors.go | 40 ++ pkg/tcpip/link/rawfile/rawfile_unsafe.go | 161 ++++++ pkg/tcpip/link/sharedmem/BUILD | 42 ++ pkg/tcpip/link/sharedmem/pipe/BUILD | 23 + pkg/tcpip/link/sharedmem/pipe/pipe.go | 68 +++ pkg/tcpip/link/sharedmem/pipe/pipe_test.go | 507 +++++++++++++++++++ pkg/tcpip/link/sharedmem/pipe/pipe_unsafe.go | 25 + pkg/tcpip/link/sharedmem/pipe/rx.go | 83 ++++ pkg/tcpip/link/sharedmem/pipe/tx.go | 151 ++++++ pkg/tcpip/link/sharedmem/queue/BUILD | 28 ++ pkg/tcpip/link/sharedmem/queue/queue_test.go | 507 +++++++++++++++++++ pkg/tcpip/link/sharedmem/queue/rx.go | 211 ++++++++ pkg/tcpip/link/sharedmem/queue/tx.go | 141 ++++++ pkg/tcpip/link/sharedmem/rx.go | 147 ++++++ pkg/tcpip/link/sharedmem/sharedmem.go | 240 +++++++++ pkg/tcpip/link/sharedmem/sharedmem_test.go | 703 +++++++++++++++++++++++++++ pkg/tcpip/link/sharedmem/sharedmem_unsafe.go | 15 + pkg/tcpip/link/sharedmem/tx.go | 262 ++++++++++ pkg/tcpip/link/sniffer/BUILD | 23 + pkg/tcpip/link/sniffer/pcap.go | 52 ++ pkg/tcpip/link/sniffer/sniffer.go | 310 ++++++++++++ pkg/tcpip/link/tun/BUILD | 12 + pkg/tcpip/link/tun/tun_unsafe.go | 50 ++ pkg/tcpip/link/waitable/BUILD | 33 ++ pkg/tcpip/link/waitable/waitable.go | 108 ++++ pkg/tcpip/link/waitable/waitable_test.go | 144 ++++++ 35 files changed, 4972 insertions(+) create mode 100644 pkg/tcpip/link/channel/BUILD create mode 100644 pkg/tcpip/link/channel/channel.go create mode 100644 pkg/tcpip/link/fdbased/BUILD create mode 100644 pkg/tcpip/link/fdbased/endpoint.go create mode 100644 pkg/tcpip/link/fdbased/endpoint_test.go create mode 100644 pkg/tcpip/link/loopback/BUILD create mode 100644 pkg/tcpip/link/loopback/loopback.go create mode 100644 pkg/tcpip/link/rawfile/BUILD create mode 100644 pkg/tcpip/link/rawfile/blockingpoll_amd64.s create mode 100644 pkg/tcpip/link/rawfile/errors.go create mode 100644 pkg/tcpip/link/rawfile/rawfile_unsafe.go create mode 100644 pkg/tcpip/link/sharedmem/BUILD create mode 100644 pkg/tcpip/link/sharedmem/pipe/BUILD create mode 100644 pkg/tcpip/link/sharedmem/pipe/pipe.go create mode 100644 pkg/tcpip/link/sharedmem/pipe/pipe_test.go create mode 100644 pkg/tcpip/link/sharedmem/pipe/pipe_unsafe.go create mode 100644 pkg/tcpip/link/sharedmem/pipe/rx.go create mode 100644 pkg/tcpip/link/sharedmem/pipe/tx.go create mode 100644 pkg/tcpip/link/sharedmem/queue/BUILD create mode 100644 pkg/tcpip/link/sharedmem/queue/queue_test.go create mode 100644 pkg/tcpip/link/sharedmem/queue/rx.go create mode 100644 pkg/tcpip/link/sharedmem/queue/tx.go create mode 100644 pkg/tcpip/link/sharedmem/rx.go create mode 100644 pkg/tcpip/link/sharedmem/sharedmem.go create mode 100644 pkg/tcpip/link/sharedmem/sharedmem_test.go create mode 100644 pkg/tcpip/link/sharedmem/sharedmem_unsafe.go create mode 100644 pkg/tcpip/link/sharedmem/tx.go create mode 100644 pkg/tcpip/link/sniffer/BUILD create mode 100644 pkg/tcpip/link/sniffer/pcap.go create mode 100644 pkg/tcpip/link/sniffer/sniffer.go create mode 100644 pkg/tcpip/link/tun/BUILD create mode 100644 pkg/tcpip/link/tun/tun_unsafe.go create mode 100644 pkg/tcpip/link/waitable/BUILD create mode 100644 pkg/tcpip/link/waitable/waitable.go create mode 100644 pkg/tcpip/link/waitable/waitable_test.go (limited to 'pkg/tcpip/link') diff --git a/pkg/tcpip/link/channel/BUILD b/pkg/tcpip/link/channel/BUILD new file mode 100644 index 000000000..b58a76699 --- /dev/null +++ b/pkg/tcpip/link/channel/BUILD @@ -0,0 +1,15 @@ +package(licenses = ["notice"]) # BSD + +load("@io_bazel_rules_go//go:def.bzl", "go_library") + +go_library( + name = "channel", + srcs = ["channel.go"], + importpath = "gvisor.googlesource.com/gvisor/pkg/tcpip/link/channel", + visibility = ["//:sandbox"], + deps = [ + "//pkg/tcpip", + "//pkg/tcpip/buffer", + "//pkg/tcpip/stack", + ], +) diff --git a/pkg/tcpip/link/channel/channel.go b/pkg/tcpip/link/channel/channel.go new file mode 100644 index 000000000..cebc34553 --- /dev/null +++ b/pkg/tcpip/link/channel/channel.go @@ -0,0 +1,110 @@ +// Copyright 2016 The Netstack Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package channel provides the implemention of channel-based data-link layer +// endpoints. Such endpoints allow injection of inbound packets and store +// outbound packets in a channel. +package channel + +import ( + "gvisor.googlesource.com/gvisor/pkg/tcpip" + "gvisor.googlesource.com/gvisor/pkg/tcpip/buffer" + "gvisor.googlesource.com/gvisor/pkg/tcpip/stack" +) + +// PacketInfo holds all the information about an outbound packet. +type PacketInfo struct { + Header buffer.View + Payload buffer.View + Proto tcpip.NetworkProtocolNumber +} + +// Endpoint is link layer endpoint that stores outbound packets in a channel +// and allows injection of inbound packets. +type Endpoint struct { + dispatcher stack.NetworkDispatcher + mtu uint32 + linkAddr tcpip.LinkAddress + + // C is where outbound packets are queued. + C chan PacketInfo +} + +// New creates a new channel endpoint. +func New(size int, mtu uint32, linkAddr tcpip.LinkAddress) (tcpip.LinkEndpointID, *Endpoint) { + e := &Endpoint{ + C: make(chan PacketInfo, size), + mtu: mtu, + linkAddr: linkAddr, + } + + return stack.RegisterLinkEndpoint(e), e +} + +// Drain removes all outbound packets from the channel and counts them. +func (e *Endpoint) Drain() int { + c := 0 + for { + select { + case <-e.C: + c++ + default: + return c + } + } +} + +// Inject injects an inbound packet. +func (e *Endpoint) Inject(protocol tcpip.NetworkProtocolNumber, vv *buffer.VectorisedView) { + uu := vv.Clone(nil) + e.dispatcher.DeliverNetworkPacket(e, "", protocol, &uu) +} + +// Attach saves the stack network-layer dispatcher for use later when packets +// are injected. +func (e *Endpoint) Attach(dispatcher stack.NetworkDispatcher) { + e.dispatcher = dispatcher +} + +// MTU implements stack.LinkEndpoint.MTU. It returns the value initialized +// during construction. +func (e *Endpoint) MTU() uint32 { + return e.mtu +} + +// Capabilities implements stack.LinkEndpoint.Capabilities. +func (*Endpoint) Capabilities() stack.LinkEndpointCapabilities { + return 0 +} + +// MaxHeaderLength returns the maximum size of the link layer header. Given it +// doesn't have a header, it just returns 0. +func (*Endpoint) MaxHeaderLength() uint16 { + return 0 +} + +// LinkAddress returns the link address of this endpoint. +func (e *Endpoint) LinkAddress() tcpip.LinkAddress { + return e.linkAddr +} + +// WritePacket stores outbound packets into the channel. +func (e *Endpoint) WritePacket(_ *stack.Route, hdr *buffer.Prependable, payload buffer.View, protocol tcpip.NetworkProtocolNumber) *tcpip.Error { + p := PacketInfo{ + Header: hdr.View(), + Proto: protocol, + } + + if payload != nil { + p.Payload = make(buffer.View, len(payload)) + copy(p.Payload, payload) + } + + select { + case e.C <- p: + default: + } + + return nil +} diff --git a/pkg/tcpip/link/fdbased/BUILD b/pkg/tcpip/link/fdbased/BUILD new file mode 100644 index 000000000..b5ab1ea6a --- /dev/null +++ b/pkg/tcpip/link/fdbased/BUILD @@ -0,0 +1,32 @@ +package(licenses = ["notice"]) # BSD + +load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test") + +go_library( + name = "fdbased", + srcs = ["endpoint.go"], + importpath = "gvisor.googlesource.com/gvisor/pkg/tcpip/link/fdbased", + visibility = [ + "//visibility:public", + ], + deps = [ + "//pkg/tcpip", + "//pkg/tcpip/buffer", + "//pkg/tcpip/header", + "//pkg/tcpip/link/rawfile", + "//pkg/tcpip/stack", + ], +) + +go_test( + name = "fdbased_test", + size = "small", + srcs = ["endpoint_test.go"], + embed = [":fdbased"], + deps = [ + "//pkg/tcpip", + "//pkg/tcpip/buffer", + "//pkg/tcpip/header", + "//pkg/tcpip/stack", + ], +) diff --git a/pkg/tcpip/link/fdbased/endpoint.go b/pkg/tcpip/link/fdbased/endpoint.go new file mode 100644 index 000000000..da74cd644 --- /dev/null +++ b/pkg/tcpip/link/fdbased/endpoint.go @@ -0,0 +1,261 @@ +// Copyright 2016 The Netstack Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package fdbased provides the implemention of data-link layer endpoints +// backed by boundary-preserving file descriptors (e.g., TUN devices, +// seqpacket/datagram sockets). +// +// FD based endpoints can be used in the networking stack by calling New() to +// create a new endpoint, and then passing it as an argument to +// Stack.CreateNIC(). +package fdbased + +import ( + "syscall" + + "gvisor.googlesource.com/gvisor/pkg/tcpip" + "gvisor.googlesource.com/gvisor/pkg/tcpip/buffer" + "gvisor.googlesource.com/gvisor/pkg/tcpip/header" + "gvisor.googlesource.com/gvisor/pkg/tcpip/link/rawfile" + "gvisor.googlesource.com/gvisor/pkg/tcpip/stack" +) + +// BufConfig defines the shape of the vectorised view used to read packets from the NIC. +var BufConfig = []int{128, 256, 256, 512, 1024, 2048, 4096, 8192, 16384, 32768} + +type endpoint struct { + // fd is the file descriptor used to send and receive packets. + fd int + + // mtu (maximum transmission unit) is the maximum size of a packet. + mtu uint32 + + // hdrSize specifies the link-layer header size. If set to 0, no header + // is added/removed; otherwise an ethernet header is used. + hdrSize int + + // addr is the address of the endpoint. + addr tcpip.LinkAddress + + // caps holds the endpoint capabilities. + caps stack.LinkEndpointCapabilities + + // closed is a function to be called when the FD's peer (if any) closes + // its end of the communication pipe. + closed func(*tcpip.Error) + + vv *buffer.VectorisedView + iovecs []syscall.Iovec + views []buffer.View +} + +// Options specify the details about the fd-based endpoint to be created. +type Options struct { + FD int + MTU uint32 + EthernetHeader bool + ChecksumOffload bool + ClosedFunc func(*tcpip.Error) + Address tcpip.LinkAddress +} + +// New creates a new fd-based endpoint. +// +// Makes fd non-blocking, but does not take ownership of fd, which must remain +// open for the lifetime of the returned endpoint. +func New(opts *Options) tcpip.LinkEndpointID { + syscall.SetNonblock(opts.FD, true) + + caps := stack.LinkEndpointCapabilities(0) + if opts.ChecksumOffload { + caps |= stack.CapabilityChecksumOffload + } + + hdrSize := 0 + if opts.EthernetHeader { + hdrSize = header.EthernetMinimumSize + caps |= stack.CapabilityResolutionRequired + } + + e := &endpoint{ + fd: opts.FD, + mtu: opts.MTU, + caps: caps, + closed: opts.ClosedFunc, + addr: opts.Address, + hdrSize: hdrSize, + views: make([]buffer.View, len(BufConfig)), + iovecs: make([]syscall.Iovec, len(BufConfig)), + } + vv := buffer.NewVectorisedView(0, e.views) + e.vv = &vv + return stack.RegisterLinkEndpoint(e) +} + +// Attach launches the goroutine that reads packets from the file descriptor and +// dispatches them via the provided dispatcher. +func (e *endpoint) Attach(dispatcher stack.NetworkDispatcher) { + go e.dispatchLoop(dispatcher) // S/R-FIXME +} + +// MTU implements stack.LinkEndpoint.MTU. It returns the value initialized +// during construction. +func (e *endpoint) MTU() uint32 { + return e.mtu +} + +// Capabilities implements stack.LinkEndpoint.Capabilities. +func (e *endpoint) Capabilities() stack.LinkEndpointCapabilities { + return e.caps +} + +// MaxHeaderLength returns the maximum size of the link-layer header. +func (e *endpoint) MaxHeaderLength() uint16 { + return uint16(e.hdrSize) +} + +// LinkAddress returns the link address of this endpoint. +func (e *endpoint) LinkAddress() tcpip.LinkAddress { + return e.addr +} + +// WritePacket writes outbound packets to the file descriptor. If it is not +// currently writable, the packet is dropped. +func (e *endpoint) WritePacket(r *stack.Route, hdr *buffer.Prependable, payload buffer.View, protocol tcpip.NetworkProtocolNumber) *tcpip.Error { + if e.hdrSize > 0 { + // Add ethernet header if needed. + eth := header.Ethernet(hdr.Prepend(header.EthernetMinimumSize)) + eth.Encode(&header.EthernetFields{ + DstAddr: r.RemoteLinkAddress, + SrcAddr: e.addr, + Type: protocol, + }) + } + + if len(payload) == 0 { + return rawfile.NonBlockingWrite(e.fd, hdr.UsedBytes()) + + } + + return rawfile.NonBlockingWrite2(e.fd, hdr.UsedBytes(), payload) +} + +func (e *endpoint) capViews(n int, buffers []int) int { + c := 0 + for i, s := range buffers { + c += s + if c >= n { + e.views[i].CapLength(s - (c - n)) + return i + 1 + } + } + return len(buffers) +} + +func (e *endpoint) allocateViews(bufConfig []int) { + for i, v := range e.views { + if v != nil { + break + } + b := buffer.NewView(bufConfig[i]) + e.views[i] = b + e.iovecs[i] = syscall.Iovec{ + Base: &b[0], + Len: uint64(len(b)), + } + } +} + +// dispatch reads one packet from the file descriptor and dispatches it. +func (e *endpoint) dispatch(d stack.NetworkDispatcher, largeV buffer.View) (bool, *tcpip.Error) { + e.allocateViews(BufConfig) + + n, err := rawfile.BlockingReadv(e.fd, e.iovecs) + if err != nil { + return false, err + } + + if n <= e.hdrSize { + return false, nil + } + + var p tcpip.NetworkProtocolNumber + var addr tcpip.LinkAddress + if e.hdrSize > 0 { + eth := header.Ethernet(e.views[0]) + p = eth.Type() + addr = eth.SourceAddress() + } else { + // We don't get any indication of what the packet is, so try to guess + // if it's an IPv4 or IPv6 packet. + switch header.IPVersion(e.views[0]) { + case header.IPv4Version: + p = header.IPv4ProtocolNumber + case header.IPv6Version: + p = header.IPv6ProtocolNumber + default: + return true, nil + } + } + + used := e.capViews(n, BufConfig) + e.vv.SetViews(e.views[:used]) + e.vv.SetSize(n) + e.vv.TrimFront(e.hdrSize) + + d.DeliverNetworkPacket(e, addr, p, e.vv) + + // Prepare e.views for another packet: release used views. + for i := 0; i < used; i++ { + e.views[i] = nil + } + + return true, nil +} + +// dispatchLoop reads packets from the file descriptor in a loop and dispatches +// them to the network stack. +func (e *endpoint) dispatchLoop(d stack.NetworkDispatcher) *tcpip.Error { + v := buffer.NewView(header.MaxIPPacketSize) + for { + cont, err := e.dispatch(d, v) + if err != nil || !cont { + if e.closed != nil { + e.closed(err) + } + return err + } + } +} + +// InjectableEndpoint is an injectable fd-based endpoint. The endpoint writes +// to the FD, but does not read from it. All reads come from injected packets. +type InjectableEndpoint struct { + endpoint + + dispatcher stack.NetworkDispatcher +} + +// Attach saves the stack network-layer dispatcher for use later when packets +// are injected. +func (e *InjectableEndpoint) Attach(dispatcher stack.NetworkDispatcher) { + e.dispatcher = dispatcher +} + +// Inject injects an inbound packet. +func (e *InjectableEndpoint) Inject(protocol tcpip.NetworkProtocolNumber, vv *buffer.VectorisedView) { + e.dispatcher.DeliverNetworkPacket(e, "", protocol, vv) +} + +// NewInjectable creates a new fd-based InjectableEndpoint. +func NewInjectable(fd int, mtu uint32) (tcpip.LinkEndpointID, *InjectableEndpoint) { + syscall.SetNonblock(fd, true) + + e := &InjectableEndpoint{endpoint: endpoint{ + fd: fd, + mtu: mtu, + }} + + return stack.RegisterLinkEndpoint(e), e +} diff --git a/pkg/tcpip/link/fdbased/endpoint_test.go b/pkg/tcpip/link/fdbased/endpoint_test.go new file mode 100644 index 000000000..f7bbb28e1 --- /dev/null +++ b/pkg/tcpip/link/fdbased/endpoint_test.go @@ -0,0 +1,336 @@ +// Copyright 2016 The Netstack Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package fdbased + +import ( + "fmt" + "math/rand" + "reflect" + "syscall" + "testing" + "time" + + "gvisor.googlesource.com/gvisor/pkg/tcpip" + "gvisor.googlesource.com/gvisor/pkg/tcpip/buffer" + "gvisor.googlesource.com/gvisor/pkg/tcpip/header" + "gvisor.googlesource.com/gvisor/pkg/tcpip/stack" +) + +type packetInfo struct { + raddr tcpip.LinkAddress + proto tcpip.NetworkProtocolNumber + contents buffer.View +} + +type context struct { + t *testing.T + fds [2]int + ep stack.LinkEndpoint + ch chan packetInfo + done chan struct{} +} + +func newContext(t *testing.T, opt *Options) *context { + fds, err := syscall.Socketpair(syscall.AF_UNIX, syscall.SOCK_SEQPACKET, 0) + if err != nil { + t.Fatalf("Socketpair failed: %v", err) + } + + done := make(chan struct{}, 1) + opt.ClosedFunc = func(*tcpip.Error) { + done <- struct{}{} + } + + opt.FD = fds[1] + ep := stack.FindLinkEndpoint(New(opt)).(*endpoint) + + c := &context{ + t: t, + fds: fds, + ep: ep, + ch: make(chan packetInfo, 100), + done: done, + } + + ep.Attach(c) + + return c +} + +func (c *context) cleanup() { + syscall.Close(c.fds[0]) + <-c.done + syscall.Close(c.fds[1]) +} + +func (c *context) DeliverNetworkPacket(linkEP stack.LinkEndpoint, remoteLinkAddr tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, vv *buffer.VectorisedView) { + c.ch <- packetInfo{remoteLinkAddr, protocol, vv.ToView()} +} + +func TestNoEthernetProperties(t *testing.T) { + const mtu = 1500 + c := newContext(t, &Options{MTU: mtu}) + defer c.cleanup() + + if want, v := uint16(0), c.ep.MaxHeaderLength(); want != v { + t.Fatalf("MaxHeaderLength() = %v, want %v", v, want) + } + + if want, v := uint32(mtu), c.ep.MTU(); want != v { + t.Fatalf("MTU() = %v, want %v", v, want) + } +} + +func TestEthernetProperties(t *testing.T) { + const mtu = 1500 + c := newContext(t, &Options{EthernetHeader: true, MTU: mtu}) + defer c.cleanup() + + if want, v := uint16(header.EthernetMinimumSize), c.ep.MaxHeaderLength(); want != v { + t.Fatalf("MaxHeaderLength() = %v, want %v", v, want) + } + + if want, v := uint32(mtu), c.ep.MTU(); want != v { + t.Fatalf("MTU() = %v, want %v", v, want) + } +} + +func TestAddress(t *testing.T) { + const mtu = 1500 + addrs := []tcpip.LinkAddress{"", "abc", "def"} + for _, a := range addrs { + t.Run(fmt.Sprintf("Address: %q", a), func(t *testing.T) { + c := newContext(t, &Options{Address: a, MTU: mtu}) + defer c.cleanup() + + if want, v := a, c.ep.LinkAddress(); want != v { + t.Fatalf("LinkAddress() = %v, want %v", v, want) + } + }) + } +} + +func TestWritePacket(t *testing.T) { + const ( + mtu = 1500 + laddr = tcpip.LinkAddress("\x11\x22\x33\x44\x55\x66") + raddr = tcpip.LinkAddress("\x77\x88\x99\xaa\xbb\xcc") + proto = 10 + ) + + lengths := []int{0, 100, 1000} + eths := []bool{true, false} + + for _, eth := range eths { + for _, plen := range lengths { + t.Run(fmt.Sprintf("Eth=%v,PayloadLen=%v", eth, plen), func(t *testing.T) { + c := newContext(t, &Options{Address: laddr, MTU: mtu, EthernetHeader: eth}) + defer c.cleanup() + + r := &stack.Route{ + RemoteLinkAddress: raddr, + } + + // Build header. + hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength()) + 100) + b := hdr.Prepend(100) + for i := range b { + b[i] = uint8(rand.Intn(256)) + } + + // Buiild payload and write. + payload := make([]byte, plen) + for i := range payload { + payload[i] = uint8(rand.Intn(256)) + } + want := append(hdr.UsedBytes(), payload...) + if err := c.ep.WritePacket(r, &hdr, payload, proto); err != nil { + t.Fatalf("WritePacket failed: %v", err) + } + + // Read from fd, then compare with what we wrote. + b = make([]byte, mtu) + n, err := syscall.Read(c.fds[0], b) + if err != nil { + t.Fatalf("Read failed: %v", err) + } + b = b[:n] + if eth { + h := header.Ethernet(b) + b = b[header.EthernetMinimumSize:] + + if a := h.SourceAddress(); a != laddr { + t.Fatalf("SourceAddress() = %v, want %v", a, laddr) + } + + if a := h.DestinationAddress(); a != raddr { + t.Fatalf("DestinationAddress() = %v, want %v", a, raddr) + } + + if et := h.Type(); et != proto { + t.Fatalf("Type() = %v, want %v", et, proto) + } + } + if len(b) != len(want) { + t.Fatalf("Read returned %v bytes, want %v", len(b), len(want)) + } + if !reflect.DeepEqual(b, want) { + t.Fatalf("Read returned %x, want %x", b, want) + } + }) + } + } +} + +func TestDeliverPacket(t *testing.T) { + const ( + mtu = 1500 + laddr = tcpip.LinkAddress("\x11\x22\x33\x44\x55\x66") + raddr = tcpip.LinkAddress("\x77\x88\x99\xaa\xbb\xcc") + proto = 10 + ) + + lengths := []int{100, 1000} + eths := []bool{true, false} + + for _, eth := range eths { + for _, plen := range lengths { + t.Run(fmt.Sprintf("Eth=%v,PayloadLen=%v", eth, plen), func(t *testing.T) { + c := newContext(t, &Options{Address: laddr, MTU: mtu, EthernetHeader: eth}) + defer c.cleanup() + + // Build packet. + b := make([]byte, plen) + all := b + for i := range b { + b[i] = uint8(rand.Intn(256)) + } + + if !eth { + // So that it looks like an IPv4 packet. + b[0] = 0x40 + } else { + hdr := make(header.Ethernet, header.EthernetMinimumSize) + hdr.Encode(&header.EthernetFields{ + SrcAddr: raddr, + DstAddr: laddr, + Type: proto, + }) + all = append(hdr, b...) + } + + // Write packet via the file descriptor. + if _, err := syscall.Write(c.fds[0], all); err != nil { + t.Fatalf("Write failed: %v", err) + } + + // Receive packet through the endpoint. + select { + case pi := <-c.ch: + want := packetInfo{ + raddr: raddr, + proto: proto, + contents: b, + } + if !eth { + want.proto = header.IPv4ProtocolNumber + want.raddr = "" + } + if !reflect.DeepEqual(want, pi) { + t.Fatalf("Unexpected received packet: %+v, want %+v", pi, want) + } + case <-time.After(10 * time.Second): + t.Fatalf("Timed out waiting for packet") + } + }) + } + } +} + +func TestBufConfigMaxLength(t *testing.T) { + got := 0 + for _, i := range BufConfig { + got += i + } + want := header.MaxIPPacketSize // maximum TCP packet size + if got < want { + t.Errorf("total buffer size is invalid: got %d, want >= %d", got, want) + } +} + +func TestBufConfigFirst(t *testing.T) { + // The stack assumes that the TCP/IP header is enterily contained in the first view. + // Therefore, the first view needs to be large enough to contain the maximum TCP/IP + // header, which is 120 bytes (60 bytes for IP + 60 bytes for TCP). + want := 120 + got := BufConfig[0] + if got < want { + t.Errorf("first view has an invalid size: got %d, want >= %d", got, want) + } +} + +func build(bufConfig []int) *endpoint { + e := &endpoint{ + views: make([]buffer.View, len(bufConfig)), + iovecs: make([]syscall.Iovec, len(bufConfig)), + } + e.allocateViews(bufConfig) + return e +} + +var capLengthTestCases = []struct { + comment string + config []int + n int + wantUsed int + wantLengths []int +}{ + { + comment: "Single slice", + config: []int{2}, + n: 1, + wantUsed: 1, + wantLengths: []int{1}, + }, + { + comment: "Multiple slices", + config: []int{1, 2}, + n: 2, + wantUsed: 2, + wantLengths: []int{1, 1}, + }, + { + comment: "Entire buffer", + config: []int{1, 2}, + n: 3, + wantUsed: 2, + wantLengths: []int{1, 2}, + }, + { + comment: "Entire buffer but not on the last slice", + config: []int{1, 2, 3}, + n: 3, + wantUsed: 2, + wantLengths: []int{1, 2, 3}, + }, +} + +func TestCapLength(t *testing.T) { + for _, c := range capLengthTestCases { + e := build(c.config) + used := e.capViews(c.n, c.config) + if used != c.wantUsed { + t.Errorf("Test \"%s\" failed when calling capViews(%d, %v). Got %d. Want %d", c.comment, c.n, c.config, used, c.wantUsed) + } + lengths := make([]int, len(e.views)) + for i, v := range e.views { + lengths[i] = len(v) + } + if !reflect.DeepEqual(lengths, c.wantLengths) { + t.Errorf("Test \"%s\" failed when calling capViews(%d, %v). Got %v. Want %v", c.comment, c.n, c.config, lengths, c.wantLengths) + } + + } +} diff --git a/pkg/tcpip/link/loopback/BUILD b/pkg/tcpip/link/loopback/BUILD new file mode 100644 index 000000000..b454d0839 --- /dev/null +++ b/pkg/tcpip/link/loopback/BUILD @@ -0,0 +1,15 @@ +package(licenses = ["notice"]) # BSD + +load("@io_bazel_rules_go//go:def.bzl", "go_library") + +go_library( + name = "loopback", + srcs = ["loopback.go"], + importpath = "gvisor.googlesource.com/gvisor/pkg/tcpip/link/loopback", + visibility = ["//:sandbox"], + deps = [ + "//pkg/tcpip", + "//pkg/tcpip/buffer", + "//pkg/tcpip/stack", + ], +) diff --git a/pkg/tcpip/link/loopback/loopback.go b/pkg/tcpip/link/loopback/loopback.go new file mode 100644 index 000000000..1a9cd09d7 --- /dev/null +++ b/pkg/tcpip/link/loopback/loopback.go @@ -0,0 +1,74 @@ +// Copyright 2016 The Netstack Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package loopback provides the implemention of loopback data-link layer +// endpoints. Such endpoints just turn outbound packets into inbound ones. +// +// Loopback endpoints can be used in the networking stack by calling New() to +// create a new endpoint, and then passing it as an argument to +// Stack.CreateNIC(). +package loopback + +import ( + "gvisor.googlesource.com/gvisor/pkg/tcpip" + "gvisor.googlesource.com/gvisor/pkg/tcpip/buffer" + "gvisor.googlesource.com/gvisor/pkg/tcpip/stack" +) + +type endpoint struct { + dispatcher stack.NetworkDispatcher +} + +// New creates a new loopback endpoint. This link-layer endpoint just turns +// outbound packets into inbound packets. +func New() tcpip.LinkEndpointID { + return stack.RegisterLinkEndpoint(&endpoint{}) +} + +// Attach implements stack.LinkEndpoint.Attach. It just saves the stack network- +// layer dispatcher for later use when packets need to be dispatched. +func (e *endpoint) Attach(dispatcher stack.NetworkDispatcher) { + e.dispatcher = dispatcher +} + +// MTU implements stack.LinkEndpoint.MTU. It returns a constant that matches the +// linux loopback interface. +func (*endpoint) MTU() uint32 { + return 65536 +} + +// Capabilities implements stack.LinkEndpoint.Capabilities. Loopback advertises +// itself as supporting checksum offload, but in reality it's just omitted. +func (*endpoint) Capabilities() stack.LinkEndpointCapabilities { + return stack.CapabilityChecksumOffload +} + +// MaxHeaderLength implements stack.LinkEndpoint.MaxHeaderLength. Given that the +// loopback interface doesn't have a header, it just returns 0. +func (*endpoint) MaxHeaderLength() uint16 { + return 0 +} + +// LinkAddress returns the link address of this endpoint. +func (*endpoint) LinkAddress() tcpip.LinkAddress { + return "" +} + +// WritePacket implements stack.LinkEndpoint.WritePacket. It delivers outbound +// packets to the network-layer dispatcher. +func (e *endpoint) WritePacket(_ *stack.Route, hdr *buffer.Prependable, payload buffer.View, protocol tcpip.NetworkProtocolNumber) *tcpip.Error { + if len(payload) == 0 { + // We don't have a payload, so just use the buffer from the + // header as the full packet. + v := hdr.View() + vv := v.ToVectorisedView([1]buffer.View{}) + e.dispatcher.DeliverNetworkPacket(e, "", protocol, &vv) + } else { + views := []buffer.View{hdr.View(), payload} + vv := buffer.NewVectorisedView(len(views[0])+len(views[1]), views) + e.dispatcher.DeliverNetworkPacket(e, "", protocol, &vv) + } + + return nil +} diff --git a/pkg/tcpip/link/rawfile/BUILD b/pkg/tcpip/link/rawfile/BUILD new file mode 100644 index 000000000..4c63af0ea --- /dev/null +++ b/pkg/tcpip/link/rawfile/BUILD @@ -0,0 +1,17 @@ +package(licenses = ["notice"]) # BSD + +load("@io_bazel_rules_go//go:def.bzl", "go_library") + +go_library( + name = "rawfile", + srcs = [ + "blockingpoll_amd64.s", + "errors.go", + "rawfile_unsafe.go", + ], + importpath = "gvisor.googlesource.com/gvisor/pkg/tcpip/link/rawfile", + visibility = [ + "//visibility:public", + ], + deps = ["//pkg/tcpip"], +) diff --git a/pkg/tcpip/link/rawfile/blockingpoll_amd64.s b/pkg/tcpip/link/rawfile/blockingpoll_amd64.s new file mode 100644 index 000000000..88206fc87 --- /dev/null +++ b/pkg/tcpip/link/rawfile/blockingpoll_amd64.s @@ -0,0 +1,26 @@ +#include "textflag.h" + +// blockingPoll makes the poll() syscall while calling the version of +// entersyscall that relinquishes the P so that other Gs can run. This is meant +// to be called in cases when the syscall is expected to block. +// +// func blockingPoll(fds unsafe.Pointer, nfds int, timeout int64) (n int, err syscall.Errno) +TEXT ·blockingPoll(SB),NOSPLIT,$0-40 + CALL runtime·entersyscallblock(SB) + MOVQ fds+0(FP), DI + MOVQ nfds+8(FP), SI + MOVQ timeout+16(FP), DX + MOVQ $0x7, AX // SYS_POLL + SYSCALL + CMPQ AX, $0xfffffffffffff001 + JLS ok + MOVQ $-1, n+24(FP) + NEGQ AX + MOVQ AX, err+32(FP) + CALL runtime·exitsyscall(SB) + RET +ok: + MOVQ AX, n+24(FP) + MOVQ $0, err+32(FP) + CALL runtime·exitsyscall(SB) + RET diff --git a/pkg/tcpip/link/rawfile/errors.go b/pkg/tcpip/link/rawfile/errors.go new file mode 100644 index 000000000..b6e7b3d71 --- /dev/null +++ b/pkg/tcpip/link/rawfile/errors.go @@ -0,0 +1,40 @@ +package rawfile + +import ( + "syscall" + + "gvisor.googlesource.com/gvisor/pkg/tcpip" +) + +var translations = map[syscall.Errno]*tcpip.Error{ + syscall.EEXIST: tcpip.ErrDuplicateAddress, + syscall.ENETUNREACH: tcpip.ErrNoRoute, + syscall.EINVAL: tcpip.ErrInvalidEndpointState, + syscall.EALREADY: tcpip.ErrAlreadyConnecting, + syscall.EISCONN: tcpip.ErrAlreadyConnected, + syscall.EADDRINUSE: tcpip.ErrPortInUse, + syscall.EADDRNOTAVAIL: tcpip.ErrBadLocalAddress, + syscall.EPIPE: tcpip.ErrClosedForSend, + syscall.EWOULDBLOCK: tcpip.ErrWouldBlock, + syscall.ECONNREFUSED: tcpip.ErrConnectionRefused, + syscall.ETIMEDOUT: tcpip.ErrTimeout, + syscall.EINPROGRESS: tcpip.ErrConnectStarted, + syscall.EDESTADDRREQ: tcpip.ErrDestinationRequired, + syscall.ENOTSUP: tcpip.ErrNotSupported, + syscall.ENOTTY: tcpip.ErrQueueSizeNotSupported, + syscall.ENOTCONN: tcpip.ErrNotConnected, + syscall.ECONNRESET: tcpip.ErrConnectionReset, + syscall.ECONNABORTED: tcpip.ErrConnectionAborted, +} + +// TranslateErrno translate an errno from the syscall package into a +// *tcpip.Error. +// +// Not all errnos are supported and this function will panic on unreconized +// errnos. +func TranslateErrno(e syscall.Errno) *tcpip.Error { + if err, ok := translations[e]; ok { + return err + } + return tcpip.ErrInvalidEndpointState +} diff --git a/pkg/tcpip/link/rawfile/rawfile_unsafe.go b/pkg/tcpip/link/rawfile/rawfile_unsafe.go new file mode 100644 index 000000000..d3660e1b4 --- /dev/null +++ b/pkg/tcpip/link/rawfile/rawfile_unsafe.go @@ -0,0 +1,161 @@ +// Copyright 2016 The Netstack Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package rawfile contains utilities for using the netstack with raw host +// files on Linux hosts. +package rawfile + +import ( + "syscall" + "unsafe" + + "gvisor.googlesource.com/gvisor/pkg/tcpip" +) + +//go:noescape +func blockingPoll(fds unsafe.Pointer, nfds int, timeout int64) (n int, err syscall.Errno) + +// GetMTU determines the MTU of a network interface device. +func GetMTU(name string) (uint32, error) { + fd, err := syscall.Socket(syscall.AF_UNIX, syscall.SOCK_DGRAM, 0) + if err != nil { + return 0, err + } + + defer syscall.Close(fd) + + var ifreq struct { + name [16]byte + mtu int32 + _ [20]byte + } + + copy(ifreq.name[:], name) + _, _, errno := syscall.Syscall(syscall.SYS_IOCTL, uintptr(fd), syscall.SIOCGIFMTU, uintptr(unsafe.Pointer(&ifreq))) + if errno != 0 { + return 0, errno + } + + return uint32(ifreq.mtu), nil +} + +// NonBlockingWrite writes the given buffer to a file descriptor. It fails if +// partial data is written. +func NonBlockingWrite(fd int, buf []byte) *tcpip.Error { + var ptr unsafe.Pointer + if len(buf) > 0 { + ptr = unsafe.Pointer(&buf[0]) + } + + _, _, e := syscall.RawSyscall(syscall.SYS_WRITE, uintptr(fd), uintptr(ptr), uintptr(len(buf))) + if e != 0 { + return TranslateErrno(e) + } + + return nil +} + +// NonBlockingWrite2 writes up to two byte slices to a file descriptor in a +// single syscall. It fails if partial data is written. +func NonBlockingWrite2(fd int, b1, b2 []byte) *tcpip.Error { + // If the is no second buffer, issue a regular write. + if len(b2) == 0 { + return NonBlockingWrite(fd, b1) + } + + // We have two buffers. Build the iovec that represents them and issue + // a writev syscall. + iovec := [...]syscall.Iovec{ + { + Base: &b1[0], + Len: uint64(len(b1)), + }, + { + Base: &b2[0], + Len: uint64(len(b2)), + }, + } + + _, _, e := syscall.RawSyscall(syscall.SYS_WRITEV, uintptr(fd), uintptr(unsafe.Pointer(&iovec[0])), uintptr(len(iovec))) + if e != 0 { + return TranslateErrno(e) + } + + return nil +} + +// NonBlockingWriteN writes up to N byte slices to a file descriptor in a +// single syscall. It fails if partial data is written. +func NonBlockingWriteN(fd int, bs ...[]byte) *tcpip.Error { + iovec := make([]syscall.Iovec, 0, len(bs)) + + for _, b := range bs { + if len(b) == 0 { + continue + } + iovec = append(iovec, syscall.Iovec{ + Base: &b[0], + Len: uint64(len(b)), + }) + } + + _, _, e := syscall.RawSyscall(syscall.SYS_WRITEV, uintptr(fd), uintptr(unsafe.Pointer(&iovec[0])), uintptr(len(iovec))) + if e != 0 { + return TranslateErrno(e) + } + + return nil +} + +// BlockingRead reads from a file descriptor that is set up as non-blocking. If +// no data is available, it will block in a poll() syscall until the file +// descirptor becomes readable. +func BlockingRead(fd int, b []byte) (int, *tcpip.Error) { + for { + n, _, e := syscall.RawSyscall(syscall.SYS_READ, uintptr(fd), uintptr(unsafe.Pointer(&b[0])), uintptr(len(b))) + if e == 0 { + return int(n), nil + } + + event := struct { + fd int32 + events int16 + revents int16 + }{ + fd: int32(fd), + events: 1, // POLLIN + } + + _, e = blockingPoll(unsafe.Pointer(&event), 1, -1) + if e != 0 && e != syscall.EINTR { + return 0, TranslateErrno(e) + } + } +} + +// BlockingReadv reads from a file descriptor that is set up as non-blocking and +// stores the data in a list of iovecs buffers. If no data is available, it will +// block in a poll() syscall until the file descirptor becomes readable. +func BlockingReadv(fd int, iovecs []syscall.Iovec) (int, *tcpip.Error) { + for { + n, _, e := syscall.RawSyscall(syscall.SYS_READV, uintptr(fd), uintptr(unsafe.Pointer(&iovecs[0])), uintptr(len(iovecs))) + if e == 0 { + return int(n), nil + } + + event := struct { + fd int32 + events int16 + revents int16 + }{ + fd: int32(fd), + events: 1, // POLLIN + } + + _, e = blockingPoll(unsafe.Pointer(&event), 1, -1) + if e != 0 && e != syscall.EINTR { + return 0, TranslateErrno(e) + } + } +} diff --git a/pkg/tcpip/link/sharedmem/BUILD b/pkg/tcpip/link/sharedmem/BUILD new file mode 100644 index 000000000..a4a965924 --- /dev/null +++ b/pkg/tcpip/link/sharedmem/BUILD @@ -0,0 +1,42 @@ +package(licenses = ["notice"]) # BSD + +load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test") + +go_library( + name = "sharedmem", + srcs = [ + "rx.go", + "sharedmem.go", + "sharedmem_unsafe.go", + "tx.go", + ], + importpath = "gvisor.googlesource.com/gvisor/pkg/tcpip/link/sharedmem", + visibility = [ + "//:sandbox", + ], + deps = [ + "//pkg/log", + "//pkg/tcpip", + "//pkg/tcpip/buffer", + "//pkg/tcpip/header", + "//pkg/tcpip/link/rawfile", + "//pkg/tcpip/link/sharedmem/queue", + "//pkg/tcpip/stack", + ], +) + +go_test( + name = "sharedmem_test", + srcs = [ + "sharedmem_test.go", + ], + embed = [":sharedmem"], + deps = [ + "//pkg/tcpip", + "//pkg/tcpip/buffer", + "//pkg/tcpip/header", + "//pkg/tcpip/link/sharedmem/pipe", + "//pkg/tcpip/link/sharedmem/queue", + "//pkg/tcpip/stack", + ], +) diff --git a/pkg/tcpip/link/sharedmem/pipe/BUILD b/pkg/tcpip/link/sharedmem/pipe/BUILD new file mode 100644 index 000000000..e8d795500 --- /dev/null +++ b/pkg/tcpip/link/sharedmem/pipe/BUILD @@ -0,0 +1,23 @@ +package(licenses = ["notice"]) # BSD + +load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test") + +go_library( + name = "pipe", + srcs = [ + "pipe.go", + "pipe_unsafe.go", + "rx.go", + "tx.go", + ], + importpath = "gvisor.googlesource.com/gvisor/pkg/tcpip/link/sharedmem/pipe", + visibility = ["//:sandbox"], +) + +go_test( + name = "pipe_test", + srcs = [ + "pipe_test.go", + ], + embed = [":pipe"], +) diff --git a/pkg/tcpip/link/sharedmem/pipe/pipe.go b/pkg/tcpip/link/sharedmem/pipe/pipe.go new file mode 100644 index 000000000..1173a60da --- /dev/null +++ b/pkg/tcpip/link/sharedmem/pipe/pipe.go @@ -0,0 +1,68 @@ +// Copyright 2016 The Netstack Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package pipe implements a shared memory ring buffer on which a single reader +// and a single writer can operate (read/write) concurrently. The ring buffer +// allows for data of different sizes to be written, and preserves the boundary +// of the written data. +// +// Example usage is as follows: +// +// wb := t.Push(20) +// // Write data to wb. +// t.Flush() +// +// rb := r.Pull() +// // Do something with data in rb. +// t.Flush() +package pipe + +import ( + "math" +) + +const ( + jump uint64 = math.MaxUint32 + 1 + offsetMask uint64 = math.MaxUint32 + revolutionMask uint64 = ^offsetMask + + sizeOfSlotHeader = 8 // sizeof(uint64) + slotFree uint64 = 1 << 63 + slotSizeMask uint64 = math.MaxUint32 +) + +// payloadToSlotSize calculates the total size of a slot based on its payload +// size. The total size is the header size, plus the payload size, plus padding +// if necessary to make the total size a multiple of sizeOfSlotHeader. +func payloadToSlotSize(payloadSize uint64) uint64 { + s := sizeOfSlotHeader + payloadSize + return (s + sizeOfSlotHeader - 1) &^ (sizeOfSlotHeader - 1) +} + +// slotToPayloadSize calculates the payload size of a slot based on the total +// size of the slot. This is only meant to be used when creating slots that +// don't carry information (e.g., free slots or wrap slots). +func slotToPayloadSize(offset uint64) uint64 { + return offset - sizeOfSlotHeader +} + +// pipe is a basic data structure used by both (transmit & receive) ends of a +// pipe. Indices into this pipe are split into two fields: offset, which counts +// the number of bytes from the beginning of the buffer, and revolution, which +// counts the number of times the index has wrapped around. +type pipe struct { + buffer []byte +} + +// init initializes the pipe buffer such that its size is a multiple of the size +// of the slot header. +func (p *pipe) init(b []byte) { + p.buffer = b[:len(b)&^(sizeOfSlotHeader-1)] +} + +// data returns a section of the buffer starting at the given index (which may +// include revolution information) and with the given size. +func (p *pipe) data(idx uint64, size uint64) []byte { + return p.buffer[(idx&offsetMask)+sizeOfSlotHeader:][:size] +} diff --git a/pkg/tcpip/link/sharedmem/pipe/pipe_test.go b/pkg/tcpip/link/sharedmem/pipe/pipe_test.go new file mode 100644 index 000000000..441ff5b25 --- /dev/null +++ b/pkg/tcpip/link/sharedmem/pipe/pipe_test.go @@ -0,0 +1,507 @@ +// Copyright 2016 The Netstack Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package pipe + +import ( + "math/rand" + "reflect" + "runtime" + "sync" + "testing" +) + +func TestSimpleReadWrite(t *testing.T) { + // Check that a simple write can be properly read from the rx side. + tr := rand.New(rand.NewSource(99)) + rr := rand.New(rand.NewSource(99)) + + b := make([]byte, 100) + var tx Tx + tx.Init(b) + + wb := tx.Push(10) + if wb == nil { + t.Fatalf("Push failed on empty pipe") + } + for i := range wb { + wb[i] = byte(tr.Intn(256)) + } + tx.Flush() + + var rx Rx + rx.Init(b) + rb := rx.Pull() + if len(rb) != 10 { + t.Fatalf("Bad buffer size returned: got %v, want %v", len(rb), 10) + } + + for i := range rb { + if v := byte(rr.Intn(256)); v != rb[i] { + t.Fatalf("Bad read buffer at index %v: got %v, want %v", i, rb[i], v) + } + } + rx.Flush() +} + +func TestEmptyRead(t *testing.T) { + // Check that pulling from an empty pipe fails. + b := make([]byte, 100) + var tx Tx + tx.Init(b) + + var rx Rx + rx.Init(b) + if rb := rx.Pull(); rb != nil { + t.Fatalf("Pull succeeded on empty pipe") + } +} + +func TestTooLargeWrite(t *testing.T) { + // Check that writes that are too large are properly rejected. + b := make([]byte, 96) + var tx Tx + tx.Init(b) + + if wb := tx.Push(96); wb != nil { + t.Fatalf("Write of 96 bytes succeeded on 96-byte pipe") + } + + if wb := tx.Push(88); wb != nil { + t.Fatalf("Write of 88 bytes succeeded on 96-byte pipe") + } + + if wb := tx.Push(80); wb == nil { + t.Fatalf("Write of 80 bytes failed on 96-byte pipe") + } +} + +func TestFullWrite(t *testing.T) { + // Check that writes fail when the pipe is full. + b := make([]byte, 100) + var tx Tx + tx.Init(b) + + if wb := tx.Push(80); wb == nil { + t.Fatalf("Write of 80 bytes failed on 96-byte pipe") + } + + if wb := tx.Push(1); wb != nil { + t.Fatalf("Write succeeded on full pipe") + } +} + +func TestFullAndFlushedWrite(t *testing.T) { + // Check that writes fail when the pipe is full and has already been + // flushed. + b := make([]byte, 100) + var tx Tx + tx.Init(b) + + if wb := tx.Push(80); wb == nil { + t.Fatalf("Write of 80 bytes failed on 96-byte pipe") + } + + tx.Flush() + + if wb := tx.Push(1); wb != nil { + t.Fatalf("Write succeeded on full pipe") + } +} + +func TestTxFlushTwice(t *testing.T) { + // Checks that a second consecutive tx flush is a no-op. + b := make([]byte, 100) + var tx Tx + tx.Init(b) + + if wb := tx.Push(50); wb == nil { + t.Fatalf("Push failed on empty pipe") + } + tx.Flush() + + // Make copy of original tx queue, flush it, then check that it didn't + // change. + orig := tx + tx.Flush() + + if !reflect.DeepEqual(orig, tx) { + t.Fatalf("Flush mutated tx pipe: got %v, want %v", tx, orig) + } +} + +func TestRxFlushTwice(t *testing.T) { + // Checks that a second consecutive rx flush is a no-op. + b := make([]byte, 100) + var tx Tx + tx.Init(b) + + if wb := tx.Push(50); wb == nil { + t.Fatalf("Push failed on empty pipe") + } + tx.Flush() + + var rx Rx + rx.Init(b) + if rb := rx.Pull(); rb == nil { + t.Fatalf("Pull failed on non-empty pipe") + } + rx.Flush() + + // Make copy of original rx queue, flush it, then check that it didn't + // change. + orig := rx + rx.Flush() + + if !reflect.DeepEqual(orig, rx) { + t.Fatalf("Flush mutated rx pipe: got %v, want %v", rx, orig) + } +} + +func TestWrapInMiddleOfTransaction(t *testing.T) { + // Check that writes are not flushed when we need to wrap the buffer + // around. + b := make([]byte, 100) + var tx Tx + tx.Init(b) + + if wb := tx.Push(50); wb == nil { + t.Fatalf("Push failed on empty pipe") + } + tx.Flush() + + var rx Rx + rx.Init(b) + if rb := rx.Pull(); rb == nil { + t.Fatalf("Pull failed on non-empty pipe") + } + rx.Flush() + + // At this point the ring buffer is empty, but the write is at offset + // 64 (50 + sizeOfSlotHeader + padding-for-8-byte-alignment). + if wb := tx.Push(10); wb == nil { + t.Fatalf("Push failed on empty pipe") + } + + if wb := tx.Push(50); wb == nil { + t.Fatalf("Push failed on non-full pipe") + } + + // We haven't flushed yet, so pull must return nil. + if rb := rx.Pull(); rb != nil { + t.Fatalf("Pull succeeded on non-flushed pipe") + } + + tx.Flush() + + // The two buffers must be available now. + if rb := rx.Pull(); rb == nil { + t.Fatalf("Pull failed on non-empty pipe") + } + + if rb := rx.Pull(); rb == nil { + t.Fatalf("Pull failed on non-empty pipe") + } +} + +func TestWriteAbort(t *testing.T) { + // Check that a read fails on a pipe that has had data pushed to it but + // has aborted the push. + b := make([]byte, 100) + var tx Tx + tx.Init(b) + + if wb := tx.Push(10); wb == nil { + t.Fatalf("Write failed on empty pipe") + } + + var rx Rx + rx.Init(b) + if rb := rx.Pull(); rb != nil { + t.Fatalf("Pull succeeded on empty pipe") + } + + tx.Abort() + if rb := rx.Pull(); rb != nil { + t.Fatalf("Pull succeeded on empty pipe") + } +} + +func TestWrappedWriteAbort(t *testing.T) { + // Check that writes are properly aborted even if the writes wrap + // around. + b := make([]byte, 100) + var tx Tx + tx.Init(b) + + if wb := tx.Push(50); wb == nil { + t.Fatalf("Push failed on empty pipe") + } + tx.Flush() + + var rx Rx + rx.Init(b) + if rb := rx.Pull(); rb == nil { + t.Fatalf("Pull failed on non-empty pipe") + } + rx.Flush() + + // At this point the ring buffer is empty, but the write is at offset + // 64 (50 + sizeOfSlotHeader + padding-for-8-byte-alignment). + if wb := tx.Push(10); wb == nil { + t.Fatalf("Push failed on empty pipe") + } + + if wb := tx.Push(50); wb == nil { + t.Fatalf("Push failed on non-full pipe") + } + + // We haven't flushed yet, so pull must return nil. + if rb := rx.Pull(); rb != nil { + t.Fatalf("Pull succeeded on non-flushed pipe") + } + + tx.Abort() + + // The pushes were aborted, so no data should be readable. + if rb := rx.Pull(); rb != nil { + t.Fatalf("Pull succeeded on non-flushed pipe") + } + + // Try the same transactions again, but flush this time. + if wb := tx.Push(10); wb == nil { + t.Fatalf("Push failed on empty pipe") + } + + if wb := tx.Push(50); wb == nil { + t.Fatalf("Push failed on non-full pipe") + } + + tx.Flush() + + // The two buffers must be available now. + if rb := rx.Pull(); rb == nil { + t.Fatalf("Pull failed on non-empty pipe") + } + + if rb := rx.Pull(); rb == nil { + t.Fatalf("Pull failed on non-empty pipe") + } +} + +func TestEmptyReadOnNonFlushedWrite(t *testing.T) { + // Check that a read fails on a pipe that has had data pushed to it + // but not yet flushed. + b := make([]byte, 100) + var tx Tx + tx.Init(b) + + if wb := tx.Push(10); wb == nil { + t.Fatalf("Write failed on empty pipe") + } + + var rx Rx + rx.Init(b) + if rb := rx.Pull(); rb != nil { + t.Fatalf("Pull succeeded on empty pipe") + } + + tx.Flush() + if rb := rx.Pull(); rb == nil { + t.Fatalf("Pull on failed on non-empty pipe") + } +} + +func TestPullAfterPullingEntirePipe(t *testing.T) { + // Check that Pull fails when the pipe is full, but all of it has + // already been pulled but not yet flushed. + b := make([]byte, 100) + var tx Tx + tx.Init(b) + + if wb := tx.Push(50); wb == nil { + t.Fatalf("Push failed on empty pipe") + } + tx.Flush() + + var rx Rx + rx.Init(b) + if rb := rx.Pull(); rb == nil { + t.Fatalf("Pull failed on non-empty pipe") + } + rx.Flush() + + // At this point the ring buffer is empty, but the write is at offset + // 64 (50 + sizeOfSlotHeader + padding-for-8-byte-alignment). Write 3 + // buffers that will fill the pipe. + if wb := tx.Push(10); wb == nil { + t.Fatalf("Push failed on empty pipe") + } + + if wb := tx.Push(20); wb == nil { + t.Fatalf("Push failed on non-full pipe") + } + + if wb := tx.Push(24); wb == nil { + t.Fatalf("Push failed on non-full pipe") + } + + tx.Flush() + + // The three buffers must be available now. + if rb := rx.Pull(); rb == nil { + t.Fatalf("Pull failed on non-empty pipe") + } + + if rb := rx.Pull(); rb == nil { + t.Fatalf("Pull failed on non-empty pipe") + } + + if rb := rx.Pull(); rb == nil { + t.Fatalf("Pull failed on non-empty pipe") + } + + // Fourth pull must fail. + if rb := rx.Pull(); rb != nil { + t.Fatalf("Pull succeeded on empty pipe") + } +} + +func TestNoRoomToWrapOnPush(t *testing.T) { + // Check that Push fails when it tries to allocate room to add a wrap + // message. + b := make([]byte, 100) + var tx Tx + tx.Init(b) + + if wb := tx.Push(50); wb == nil { + t.Fatalf("Push failed on empty pipe") + } + tx.Flush() + + var rx Rx + rx.Init(b) + if rb := rx.Pull(); rb == nil { + t.Fatalf("Pull failed on non-empty pipe") + } + rx.Flush() + + // At this point the ring buffer is empty, but the write is at offset + // 64 (50 + sizeOfSlotHeader + padding-for-8-byte-alignment). Write 20, + // which won't fit (64+20+8+padding = 96, which wouldn't leave room for + // the padding), so it wraps around. + if wb := tx.Push(20); wb == nil { + t.Fatalf("Push failed on empty pipe") + } + + tx.Flush() + + // Buffer offset is at 28. Try to write 70, which would require a wrap + // slot which cannot be created now. + if wb := tx.Push(70); wb != nil { + t.Fatalf("Push succeeded on pipe with no room for wrap message") + } +} + +func TestRxImplicitFlushOfWrapMessage(t *testing.T) { + // Check if the first read is that of a wrapping message, that it gets + // immediately flushed. + b := make([]byte, 100) + var tx Tx + tx.Init(b) + + if wb := tx.Push(50); wb == nil { + t.Fatalf("Push failed on empty pipe") + } + tx.Flush() + + // This will cause a wrapping message to written. + if wb := tx.Push(60); wb != nil { + t.Fatalf("Push succeeded when there is no room in pipe") + } + + var rx Rx + rx.Init(b) + + // Read the first message. + if rb := rx.Pull(); rb == nil { + t.Fatalf("Pull failed on non-empty pipe") + } + rx.Flush() + + // This should fail because of the wrapping message is taking up space. + if wb := tx.Push(60); wb != nil { + t.Fatalf("Push succeeded when there is no room in pipe") + } + + // Try to read the next one. This should consume the wrapping message. + rx.Pull() + + // This must now succeed. + if wb := tx.Push(60); wb == nil { + t.Fatalf("Push failed on empty pipe") + } +} + +func TestConcurrentReaderWriter(t *testing.T) { + // Push a million buffers of random sizes and random contents. Check + // that buffers read match what was written. + tr := rand.New(rand.NewSource(99)) + rr := rand.New(rand.NewSource(99)) + + b := make([]byte, 100) + var tx Tx + tx.Init(b) + + var rx Rx + rx.Init(b) + + const count = 1000000 + var wg sync.WaitGroup + wg.Add(1) + go func() { + defer wg.Done() + runtime.Gosched() + for i := 0; i < count; i++ { + n := 1 + tr.Intn(80) + wb := tx.Push(uint64(n)) + for wb == nil { + wb = tx.Push(uint64(n)) + } + + for j := range wb { + wb[j] = byte(tr.Intn(256)) + } + + tx.Flush() + } + }() + + wg.Add(1) + go func() { + defer wg.Done() + runtime.Gosched() + for i := 0; i < count; i++ { + n := 1 + rr.Intn(80) + rb := rx.Pull() + for rb == nil { + rb = rx.Pull() + } + + if n != len(rb) { + t.Fatalf("Bad %v-th buffer length: got %v, want %v", i, len(rb), n) + } + + for j := range rb { + if v := byte(rr.Intn(256)); v != rb[j] { + t.Fatalf("Bad %v-th read buffer at index %v: got %v, want %v", i, j, rb[j], v) + } + } + + rx.Flush() + } + }() + + wg.Wait() +} diff --git a/pkg/tcpip/link/sharedmem/pipe/pipe_unsafe.go b/pkg/tcpip/link/sharedmem/pipe/pipe_unsafe.go new file mode 100644 index 000000000..d536abedf --- /dev/null +++ b/pkg/tcpip/link/sharedmem/pipe/pipe_unsafe.go @@ -0,0 +1,25 @@ +// Copyright 2016 The Netstack Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package pipe + +import ( + "sync/atomic" + "unsafe" +) + +func (p *pipe) write(idx uint64, v uint64) { + ptr := (*uint64)(unsafe.Pointer(&p.buffer[idx&offsetMask:][:8][0])) + *ptr = v +} + +func (p *pipe) writeAtomic(idx uint64, v uint64) { + ptr := (*uint64)(unsafe.Pointer(&p.buffer[idx&offsetMask:][:8][0])) + atomic.StoreUint64(ptr, v) +} + +func (p *pipe) readAtomic(idx uint64) uint64 { + ptr := (*uint64)(unsafe.Pointer(&p.buffer[idx&offsetMask:][:8][0])) + return atomic.LoadUint64(ptr) +} diff --git a/pkg/tcpip/link/sharedmem/pipe/rx.go b/pkg/tcpip/link/sharedmem/pipe/rx.go new file mode 100644 index 000000000..261e21f9e --- /dev/null +++ b/pkg/tcpip/link/sharedmem/pipe/rx.go @@ -0,0 +1,83 @@ +// Copyright 2016 The Netstack Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package pipe + +// Rx is the receive side of the shared memory ring buffer. +type Rx struct { + p pipe + + tail uint64 + head uint64 +} + +// Init initializes the receive end of the pipe. In the initial state, the next +// slot to be inspected is the very first one. +func (r *Rx) Init(b []byte) { + r.p.init(b) + r.tail = 0xfffffffe * jump + r.head = r.tail +} + +// Pull reads the next buffer from the pipe, returning nil if there isn't one +// currently available. +// +// The returned slice is available until Flush() is next called. After that, it +// must not be touched. +func (r *Rx) Pull() []byte { + if r.head == r.tail+jump { + // We've already pulled the whole pipe. + return nil + } + + header := r.p.readAtomic(r.head) + if header&slotFree != 0 { + // The next slot is free, we can't pull it yet. + return nil + } + + payloadSize := header & slotSizeMask + newHead := r.head + payloadToSlotSize(payloadSize) + headWrap := (r.head & revolutionMask) | uint64(len(r.p.buffer)) + + // Check if this is a wrapping slot. If that's the case, it carries no + // data, so we just skip it and try again from the first slot. + if int64(newHead-headWrap) >= 0 { + if int64(newHead-headWrap) > int64(jump) || newHead&offsetMask != 0 { + return nil + } + + if r.tail == r.head { + // If this is the first pull since the last Flush() + // call, we flush the state so that the sender can use + // this space if it needs to. + r.p.writeAtomic(r.head, slotFree|slotToPayloadSize(newHead-r.head)) + r.tail = newHead + } + + r.head = newHead + return r.Pull() + } + + // Grab the buffer before updating r.head. + b := r.p.data(r.head, payloadSize) + r.head = newHead + return b +} + +// Flush tells the transmitter that all buffers pulled since the last Flush() +// have been used, so the transmitter is free to used their slots for further +// transmission. +func (r *Rx) Flush() { + if r.head == r.tail { + return + } + r.p.writeAtomic(r.tail, slotFree|slotToPayloadSize(r.head-r.tail)) + r.tail = r.head +} + +// Bytes returns the byte slice on which the pipe operates. +func (r *Rx) Bytes() []byte { + return r.p.buffer +} diff --git a/pkg/tcpip/link/sharedmem/pipe/tx.go b/pkg/tcpip/link/sharedmem/pipe/tx.go new file mode 100644 index 000000000..374f515ab --- /dev/null +++ b/pkg/tcpip/link/sharedmem/pipe/tx.go @@ -0,0 +1,151 @@ +// Copyright 2016 The Netstack Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package pipe + +// Tx is the transmit side of the shared memory ring buffer. +type Tx struct { + p pipe + maxPayloadSize uint64 + + head uint64 + tail uint64 + next uint64 + + tailHeader uint64 +} + +// Init initializes the transmit end of the pipe. In the initial state, the next +// slot to be written is the very first one, and the transmitter has the whole +// ring buffer available to it. +func (t *Tx) Init(b []byte) { + t.p.init(b) + // maxPayloadSize excludes the header of the payload, and the header + // of the wrapping message. + t.maxPayloadSize = uint64(len(t.p.buffer)) - 2*sizeOfSlotHeader + t.tail = 0xfffffffe * jump + t.next = t.tail + t.head = t.tail + jump + t.p.write(t.tail, slotFree) +} + +// Capacity determines how many records of the given size can be written to the +// pipe before it fills up. +func (t *Tx) Capacity(recordSize uint64) uint64 { + available := uint64(len(t.p.buffer)) - sizeOfSlotHeader + entryLen := payloadToSlotSize(recordSize) + return available / entryLen +} + +// Push reserves "payloadSize" bytes for transmission in the pipe. The caller +// populates the returned slice with the data to be transferred and enventually +// calls Flush() to make the data visible to the reader, or Abort() to make the +// pipe forget all Push() calls since the last Flush(). +// +// The returned slice is available until Flush() or Abort() is next called. +// After that, it must not be touched. +func (t *Tx) Push(payloadSize uint64) []byte { + // Fail request if we know we will never have enough room. + if payloadSize > t.maxPayloadSize { + return nil + } + + totalLen := payloadToSlotSize(payloadSize) + newNext := t.next + totalLen + nextWrap := (t.next & revolutionMask) | uint64(len(t.p.buffer)) + if int64(newNext-nextWrap) >= 0 { + // The new buffer would overflow the pipe, so we push a wrapping + // slot, then try to add the actual slot to the front of the + // pipe. + newNext = (newNext & revolutionMask) + jump + wrappingPayloadSize := slotToPayloadSize(newNext - t.next) + if !t.reclaim(newNext) { + return nil + } + + oldNext := t.next + t.next = newNext + if oldNext != t.tail { + t.p.write(oldNext, wrappingPayloadSize) + } else { + t.tailHeader = wrappingPayloadSize + t.Flush() + } + + newNext += totalLen + } + + // Check that we have enough room for the buffer. + if !t.reclaim(newNext) { + return nil + } + + if t.next != t.tail { + t.p.write(t.next, payloadSize) + } else { + t.tailHeader = payloadSize + } + + // Grab the buffer before updating t.next. + b := t.p.data(t.next, payloadSize) + t.next = newNext + + return b +} + +// reclaim attempts to advance the head until at least newNext. If the head is +// already at or beyond newNext, nothing happens and true is returned; otherwise +// it tries to reclaim slots that have already been consumed by the receive end +// of the pipe (they will be marked as free) and returns a boolean indicating +// whether it was successful in reclaiming enough slots. +func (t *Tx) reclaim(newNext uint64) bool { + for int64(newNext-t.head) > 0 { + // Can't reclaim if slot is not free. + header := t.p.readAtomic(t.head) + if header&slotFree == 0 { + return false + } + + payloadSize := header & slotSizeMask + newHead := t.head + payloadToSlotSize(payloadSize) + + // Check newHead is within bounds and valid. + if int64(newHead-t.tail) > int64(jump) || newHead&offsetMask >= uint64(len(t.p.buffer)) { + return false + } + + t.head = newHead + } + + return true +} + +// Abort causes all Push() calls since the last Flush() to be forgotten and +// therefore they will not be made visible to the receiver. +func (t *Tx) Abort() { + t.next = t.tail +} + +// Flush causes all buffers pushed since the last Flush() [or Abort(), whichever +// is the most recent] to be made visible to the receiver. +func (t *Tx) Flush() { + if t.next == t.tail { + // Nothing to do if there are no pushed buffers. + return + } + + if t.next != t.head { + // The receiver will spin in t.next, so we must make sure that + // the slotFree bit is set. + t.p.write(t.next, slotFree) + } + + t.p.writeAtomic(t.tail, t.tailHeader) + t.tail = t.next +} + +// Bytes returns the byte slice on which the pipe operates. +func (t *Tx) Bytes() []byte { + return t.p.buffer +} diff --git a/pkg/tcpip/link/sharedmem/queue/BUILD b/pkg/tcpip/link/sharedmem/queue/BUILD new file mode 100644 index 000000000..56ea4641d --- /dev/null +++ b/pkg/tcpip/link/sharedmem/queue/BUILD @@ -0,0 +1,28 @@ +package(licenses = ["notice"]) # BSD + +load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test") + +go_library( + name = "queue", + srcs = [ + "rx.go", + "tx.go", + ], + importpath = "gvisor.googlesource.com/gvisor/pkg/tcpip/link/sharedmem/queue", + visibility = ["//:sandbox"], + deps = [ + "//pkg/log", + "//pkg/tcpip/link/sharedmem/pipe", + ], +) + +go_test( + name = "queue_test", + srcs = [ + "queue_test.go", + ], + embed = [":queue"], + deps = [ + "//pkg/tcpip/link/sharedmem/pipe", + ], +) diff --git a/pkg/tcpip/link/sharedmem/queue/queue_test.go b/pkg/tcpip/link/sharedmem/queue/queue_test.go new file mode 100644 index 000000000..b022c389c --- /dev/null +++ b/pkg/tcpip/link/sharedmem/queue/queue_test.go @@ -0,0 +1,507 @@ +// Copyright 2016 The Netstack Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package queue + +import ( + "encoding/binary" + "reflect" + "testing" + + "gvisor.googlesource.com/gvisor/pkg/tcpip/link/sharedmem/pipe" +) + +func TestBasicTxQueue(t *testing.T) { + // Tests that a basic transmit on a queue works, and that completion + // gets properly reported as well. + pb1 := make([]byte, 100) + pb2 := make([]byte, 100) + + var rxp pipe.Rx + rxp.Init(pb1) + + var txp pipe.Tx + txp.Init(pb2) + + var q Tx + q.Init(pb1, pb2) + + // Enqueue two buffers. + b := []TxBuffer{ + {nil, 100, 60}, + {nil, 200, 40}, + } + + b[0].Next = &b[1] + + const usedID = 1002 + const usedTotalSize = 100 + if !q.Enqueue(usedID, usedTotalSize, 2, &b[0]) { + t.Fatalf("Enqueue failed on empty queue") + } + + // Check the contents of the pipe. + d := rxp.Pull() + if d == nil { + t.Fatalf("Tx pipe is empty after Enqueue") + } + + want := []byte{ + 234, 3, 0, 0, 0, 0, 0, 0, // id + 100, 0, 0, 0, // total size + 0, 0, 0, 0, // reserved + 100, 0, 0, 0, 0, 0, 0, 0, // offset 1 + 60, 0, 0, 0, // size 1 + 200, 0, 0, 0, 0, 0, 0, 0, // offset 2 + 40, 0, 0, 0, // size 2 + } + + if !reflect.DeepEqual(want, d) { + t.Fatalf("Bad posted packet: got %v, want %v", d, want) + } + + rxp.Flush() + + // Check that there are no completions yet. + if _, ok := q.CompletedPacket(); ok { + t.Fatalf("Packet reported as completed too soon") + } + + // Post a completion. + d = txp.Push(8) + if d == nil { + t.Fatalf("Unable to push to rx pipe") + } + binary.LittleEndian.PutUint64(d, usedID) + txp.Flush() + + // Check that completion is properly reported. + id, ok := q.CompletedPacket() + if !ok { + t.Fatalf("Completion not reported") + } + + if id != usedID { + t.Fatalf("Bad completion id: got %v, want %v", id, usedID) + } +} + +func TestBasicRxQueue(t *testing.T) { + // Tests that a basic receive on a queue works. + pb1 := make([]byte, 100) + pb2 := make([]byte, 100) + + var rxp pipe.Rx + rxp.Init(pb1) + + var txp pipe.Tx + txp.Init(pb2) + + var q Rx + q.Init(pb1, pb2, nil) + + // Post two buffers. + b := []RxBuffer{ + {100, 60, 1077, 0}, + {200, 40, 2123, 0}, + } + + if !q.PostBuffers(b) { + t.Fatalf("PostBuffers failed on empty queue") + } + + // Check the contents of the pipe. + want := [][]byte{ + { + 100, 0, 0, 0, 0, 0, 0, 0, // Offset1 + 60, 0, 0, 0, // Size1 + 0, 0, 0, 0, // Remaining in group 1 + 0, 0, 0, 0, 0, 0, 0, 0, // User data 1 + 53, 4, 0, 0, 0, 0, 0, 0, // ID 1 + }, + { + 200, 0, 0, 0, 0, 0, 0, 0, // Offset2 + 40, 0, 0, 0, // Size2 + 0, 0, 0, 0, // Remaining in group 2 + 0, 0, 0, 0, 0, 0, 0, 0, // User data 2 + 75, 8, 0, 0, 0, 0, 0, 0, // ID 2 + }, + } + + for i := range b { + d := rxp.Pull() + if d == nil { + t.Fatalf("Tx pipe is empty after PostBuffers") + } + + if !reflect.DeepEqual(want[i], d) { + t.Fatalf("Bad posted packet: got %v, want %v", d, want[i]) + } + + rxp.Flush() + } + + // Check that there are no completions. + if _, n := q.Dequeue(nil); n != 0 { + t.Fatalf("Packet reported as received too soon") + } + + // Post a completion. + d := txp.Push(sizeOfConsumedPacketHeader + 2*sizeOfConsumedBuffer) + if d == nil { + t.Fatalf("Unable to push to rx pipe") + } + + copy(d, []byte{ + 100, 0, 0, 0, // packet size + 0, 0, 0, 0, // reserved + + 100, 0, 0, 0, 0, 0, 0, 0, // offset 1 + 60, 0, 0, 0, // size 1 + 0, 0, 0, 0, 0, 0, 0, 0, // user data 1 + 53, 4, 0, 0, 0, 0, 0, 0, // ID 1 + + 200, 0, 0, 0, 0, 0, 0, 0, // offset 2 + 40, 0, 0, 0, // size 2 + 0, 0, 0, 0, 0, 0, 0, 0, // user data 2 + 75, 8, 0, 0, 0, 0, 0, 0, // ID 2 + }) + + txp.Flush() + + // Check that completion is properly reported. + bufs, n := q.Dequeue(nil) + if n != 100 { + t.Fatalf("Bad packet size: got %v, want %v", n, 100) + } + + if !reflect.DeepEqual(bufs, b) { + t.Fatalf("Bad returned buffers: got %v, want %v", bufs, b) + } +} + +func TestBadTxCompletion(t *testing.T) { + // Check that tx completions with bad sizes are properly ignored. + pb1 := make([]byte, 100) + pb2 := make([]byte, 100) + + var rxp pipe.Rx + rxp.Init(pb1) + + var txp pipe.Tx + txp.Init(pb2) + + var q Tx + q.Init(pb1, pb2) + + // Post a completion that is too short, and check that it is ignored. + if d := txp.Push(7); d == nil { + t.Fatalf("Unable to push to rx pipe") + } + txp.Flush() + + if _, ok := q.CompletedPacket(); ok { + t.Fatalf("Bad completion not ignored") + } + + // Post a completion that is too long, and check that it is ignored. + if d := txp.Push(10); d == nil { + t.Fatalf("Unable to push to rx pipe") + } + txp.Flush() + + if _, ok := q.CompletedPacket(); ok { + t.Fatalf("Bad completion not ignored") + } +} + +func TestBadRxCompletion(t *testing.T) { + // Check that bad rx completions are properly ignored. + pb1 := make([]byte, 100) + pb2 := make([]byte, 100) + + var rxp pipe.Rx + rxp.Init(pb1) + + var txp pipe.Tx + txp.Init(pb2) + + var q Rx + q.Init(pb1, pb2, nil) + + // Post a completion that is too short, and check that it is ignored. + if d := txp.Push(7); d == nil { + t.Fatalf("Unable to push to rx pipe") + } + txp.Flush() + + if b, _ := q.Dequeue(nil); b != nil { + t.Fatalf("Bad completion not ignored") + } + + // Post a completion whose buffer sizes add up to less than the total + // size. + d := txp.Push(sizeOfConsumedPacketHeader + 2*sizeOfConsumedBuffer) + if d == nil { + t.Fatalf("Unable to push to rx pipe") + } + + copy(d, []byte{ + 100, 0, 0, 0, // packet size + 0, 0, 0, 0, // reserved + + 100, 0, 0, 0, 0, 0, 0, 0, // offset 1 + 10, 0, 0, 0, // size 1 + 0, 0, 0, 0, 0, 0, 0, 0, // user data 1 + 53, 4, 0, 0, 0, 0, 0, 0, // ID 1 + + 200, 0, 0, 0, 0, 0, 0, 0, // offset 2 + 10, 0, 0, 0, // size 2 + 0, 0, 0, 0, 0, 0, 0, 0, // user data 2 + 75, 8, 0, 0, 0, 0, 0, 0, // ID 2 + }) + + txp.Flush() + if b, _ := q.Dequeue(nil); b != nil { + t.Fatalf("Bad completion not ignored") + } + + // Post a completion whose buffer sizes will cause a 32-bit overflow, + // but adds up to the right number. + d = txp.Push(sizeOfConsumedPacketHeader + 2*sizeOfConsumedBuffer) + if d == nil { + t.Fatalf("Unable to push to rx pipe") + } + + copy(d, []byte{ + 100, 0, 0, 0, // packet size + 0, 0, 0, 0, // reserved + + 100, 0, 0, 0, 0, 0, 0, 0, // offset 1 + 255, 255, 255, 255, // size 1 + 0, 0, 0, 0, 0, 0, 0, 0, // user data 1 + 53, 4, 0, 0, 0, 0, 0, 0, // ID 1 + + 200, 0, 0, 0, 0, 0, 0, 0, // offset 2 + 101, 0, 0, 0, // size 2 + 0, 0, 0, 0, 0, 0, 0, 0, // user data 2 + 75, 8, 0, 0, 0, 0, 0, 0, // ID 2 + }) + + txp.Flush() + if b, _ := q.Dequeue(nil); b != nil { + t.Fatalf("Bad completion not ignored") + } +} + +func TestFillTxPipe(t *testing.T) { + // Check that transmitting a new buffer when the buffer pipe is full + // fails gracefully. + pb1 := make([]byte, 104) + pb2 := make([]byte, 104) + + var rxp pipe.Rx + rxp.Init(pb1) + + var txp pipe.Tx + txp.Init(pb2) + + var q Tx + q.Init(pb1, pb2) + + // Transmit twice, which should fill the tx pipe. + b := []TxBuffer{ + {nil, 100, 60}, + {nil, 200, 40}, + } + + b[0].Next = &b[1] + + const usedID = 1002 + const usedTotalSize = 100 + for i := uint64(0); i < 2; i++ { + if !q.Enqueue(usedID+i, usedTotalSize, 2, &b[0]) { + t.Fatalf("Failed to transmit buffer") + } + } + + // Transmit another packet now that the tx pipe is full. + if q.Enqueue(usedID+2, usedTotalSize, 2, &b[0]) { + t.Fatalf("Enqueue succeeded when tx pipe is full") + } +} + +func TestFillRxPipe(t *testing.T) { + // Check that posting a new buffer when the buffer pipe is full fails + // gracefully. + pb1 := make([]byte, 100) + pb2 := make([]byte, 100) + + var rxp pipe.Rx + rxp.Init(pb1) + + var txp pipe.Tx + txp.Init(pb2) + + var q Rx + q.Init(pb1, pb2, nil) + + // Post a buffer twice, it should fill the tx pipe. + b := []RxBuffer{ + {100, 60, 1077, 0}, + } + + for i := 0; i < 2; i++ { + if !q.PostBuffers(b) { + t.Fatalf("PostBuffers failed on non-full queue") + } + } + + // Post another buffer now that the tx pipe is full. + if q.PostBuffers(b) { + t.Fatalf("PostBuffers succeeded on full queue") + } +} + +func TestLotsOfTransmissions(t *testing.T) { + // Make sure pipes are being properly flushed when transmitting packets. + pb1 := make([]byte, 100) + pb2 := make([]byte, 100) + + var rxp pipe.Rx + rxp.Init(pb1) + + var txp pipe.Tx + txp.Init(pb2) + + var q Tx + q.Init(pb1, pb2) + + // Prepare packet with two buffers. + b := []TxBuffer{ + {nil, 100, 60}, + {nil, 200, 40}, + } + + b[0].Next = &b[1] + + const usedID = 1002 + const usedTotalSize = 100 + + // Post 100000 packets and completions. + for i := 100000; i > 0; i-- { + if !q.Enqueue(usedID, usedTotalSize, 2, &b[0]) { + t.Fatalf("Enqueue failed on non-full queue") + } + + if d := rxp.Pull(); d == nil { + t.Fatalf("Tx pipe is empty after Enqueue") + } + rxp.Flush() + + d := txp.Push(8) + if d == nil { + t.Fatalf("Unable to write to rx pipe") + } + binary.LittleEndian.PutUint64(d, usedID) + txp.Flush() + if _, ok := q.CompletedPacket(); !ok { + t.Fatalf("Completion not returned") + } + } +} + +func TestLotsOfReceptions(t *testing.T) { + // Make sure pipes are being properly flushed when receiving packets. + pb1 := make([]byte, 100) + pb2 := make([]byte, 100) + + var rxp pipe.Rx + rxp.Init(pb1) + + var txp pipe.Tx + txp.Init(pb2) + + var q Rx + q.Init(pb1, pb2, nil) + + // Prepare for posting two buffers. + b := []RxBuffer{ + {100, 60, 1077, 0}, + {200, 40, 2123, 0}, + } + + // Post 100000 buffers and completions. + for i := 100000; i > 0; i-- { + if !q.PostBuffers(b) { + t.Fatalf("PostBuffers failed on non-full queue") + } + + if d := rxp.Pull(); d == nil { + t.Fatalf("Tx pipe is empty after PostBuffers") + } + rxp.Flush() + + if d := rxp.Pull(); d == nil { + t.Fatalf("Tx pipe is empty after PostBuffers") + } + rxp.Flush() + + d := txp.Push(sizeOfConsumedPacketHeader + 2*sizeOfConsumedBuffer) + if d == nil { + t.Fatalf("Unable to push to rx pipe") + } + + copy(d, []byte{ + 100, 0, 0, 0, // packet size + 0, 0, 0, 0, // reserved + + 100, 0, 0, 0, 0, 0, 0, 0, // offset 1 + 60, 0, 0, 0, // size 1 + 0, 0, 0, 0, 0, 0, 0, 0, // user data 1 + 53, 4, 0, 0, 0, 0, 0, 0, // ID 1 + + 200, 0, 0, 0, 0, 0, 0, 0, // offset 2 + 40, 0, 0, 0, // size 2 + 0, 0, 0, 0, 0, 0, 0, 0, // user data 2 + 75, 8, 0, 0, 0, 0, 0, 0, // ID 2 + }) + + txp.Flush() + + if _, n := q.Dequeue(nil); n == 0 { + t.Fatalf("Dequeue failed when there is a completion") + } + } +} + +func TestRxEnableNotification(t *testing.T) { + // Check that enabling nofifications results in properly updated state. + pb1 := make([]byte, 100) + pb2 := make([]byte, 100) + + var state uint32 + var q Rx + q.Init(pb1, pb2, &state) + + q.EnableNotification() + if state != eventFDEnabled { + t.Fatalf("Bad value in shared state: got %v, want %v", state, eventFDEnabled) + } +} + +func TestRxDisableNotification(t *testing.T) { + // Check that disabling nofifications results in properly updated state. + pb1 := make([]byte, 100) + pb2 := make([]byte, 100) + + var state uint32 + var q Rx + q.Init(pb1, pb2, &state) + + q.DisableNotification() + if state != eventFDDisabled { + t.Fatalf("Bad value in shared state: got %v, want %v", state, eventFDDisabled) + } +} diff --git a/pkg/tcpip/link/sharedmem/queue/rx.go b/pkg/tcpip/link/sharedmem/queue/rx.go new file mode 100644 index 000000000..91bb57190 --- /dev/null +++ b/pkg/tcpip/link/sharedmem/queue/rx.go @@ -0,0 +1,211 @@ +// Copyright 2016 The Netstack Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package queue provides the implementation of transmit and receive queues +// based on shared memory ring buffers. +package queue + +import ( + "encoding/binary" + "sync/atomic" + + "gvisor.googlesource.com/gvisor/pkg/log" + "gvisor.googlesource.com/gvisor/pkg/tcpip/link/sharedmem/pipe" +) + +const ( + // Offsets within a posted buffer. + postedOffset = 0 + postedSize = 8 + postedRemainingInGroup = 12 + postedUserData = 16 + postedID = 24 + + sizeOfPostedBuffer = 32 + + // Offsets within a received packet header. + consumedPacketSize = 0 + consumedPacketReserved = 4 + + sizeOfConsumedPacketHeader = 8 + + // Offsets within a consumed buffer. + consumedOffset = 0 + consumedSize = 8 + consumedUserData = 12 + consumedID = 20 + + sizeOfConsumedBuffer = 28 + + // The following are the allowed states of the shared data area. + eventFDUninitialized = 0 + eventFDDisabled = 1 + eventFDEnabled = 2 +) + +// RxBuffer is the descriptor of a receive buffer. +type RxBuffer struct { + Offset uint64 + Size uint32 + ID uint64 + UserData uint64 +} + +// Rx is a receive queue. It is implemented with one tx and one rx pipe: the tx +// pipe is used to "post" buffers, while the rx pipe is used to receive packets +// whose contents have been written to previously posted buffers. +// +// This struct is thread-compatible. +type Rx struct { + tx pipe.Tx + rx pipe.Rx + sharedEventFDState *uint32 +} + +// Init initializes the receive queue with the given pipes, and shared state +// pointer -- the latter is used to enable/disable eventfd notifications. +func (r *Rx) Init(tx, rx []byte, sharedEventFDState *uint32) { + r.sharedEventFDState = sharedEventFDState + r.tx.Init(tx) + r.rx.Init(rx) +} + +// EnableNotification updates the shared state such that the peer will notify +// the eventfd when there are packets to be dequeued. +func (r *Rx) EnableNotification() { + atomic.StoreUint32(r.sharedEventFDState, eventFDEnabled) +} + +// DisableNotification updates the shared state such that the peer will not +// notify the eventfd. +func (r *Rx) DisableNotification() { + atomic.StoreUint32(r.sharedEventFDState, eventFDDisabled) +} + +// PostedBuffersLimit returns the maximum number of buffers that can be posted +// before the tx queue fills up. +func (r *Rx) PostedBuffersLimit() uint64 { + return r.tx.Capacity(sizeOfPostedBuffer) +} + +// PostBuffers makes the given buffers available for receiving data from the +// peer. Once they are posted, the peer is free to write to them and will +// eventually post them back for consumption. +func (r *Rx) PostBuffers(buffers []RxBuffer) bool { + for i := range buffers { + b := r.tx.Push(sizeOfPostedBuffer) + if b == nil { + r.tx.Abort() + return false + } + + pb := &buffers[i] + binary.LittleEndian.PutUint64(b[postedOffset:], pb.Offset) + binary.LittleEndian.PutUint32(b[postedSize:], pb.Size) + binary.LittleEndian.PutUint32(b[postedRemainingInGroup:], 0) + binary.LittleEndian.PutUint64(b[postedUserData:], pb.UserData) + binary.LittleEndian.PutUint64(b[postedID:], pb.ID) + } + + r.tx.Flush() + + return true +} + +// Dequeue receives buffers that have been previously posted by PostBuffers() +// and that have been filled by the peer and posted back. +// +// This is similar to append() in that new buffers are appended to "bufs", with +// reallocation only if "bufs" doesn't have enough capacity. +func (r *Rx) Dequeue(bufs []RxBuffer) ([]RxBuffer, uint32) { + for { + outBufs := bufs + + // Pull the next descriptor from the rx pipe. + b := r.rx.Pull() + if b == nil { + return bufs, 0 + } + + if len(b) < sizeOfConsumedPacketHeader { + log.Warningf("Ignoring packet header: size (%v) is less than header size (%v)", len(b), sizeOfConsumedPacketHeader) + r.rx.Flush() + continue + } + + totalDataSize := binary.LittleEndian.Uint32(b[consumedPacketSize:]) + + // Calculate the number of buffer descriptors and copy them + // over to the output. + count := (len(b) - sizeOfConsumedPacketHeader) / sizeOfConsumedBuffer + offset := sizeOfConsumedPacketHeader + buffersSize := uint32(0) + for i := count; i > 0; i-- { + s := binary.LittleEndian.Uint32(b[offset+consumedSize:]) + buffersSize += s + if buffersSize < s { + // The buffer size overflows an unsigned 32-bit + // integer, so break out and force it to be + // ignored. + totalDataSize = 1 + buffersSize = 0 + break + } + + outBufs = append(outBufs, RxBuffer{ + Offset: binary.LittleEndian.Uint64(b[offset+consumedOffset:]), + Size: s, + ID: binary.LittleEndian.Uint64(b[offset+consumedID:]), + }) + + offset += sizeOfConsumedBuffer + } + + r.rx.Flush() + + if buffersSize < totalDataSize { + // The descriptor is corrupted, ignore it. + log.Warningf("Ignoring packet: actual data size (%v) less than expected size (%v)", buffersSize, totalDataSize) + continue + } + + return outBufs, totalDataSize + } +} + +// Bytes returns the byte slices on which the queue operates. +func (r *Rx) Bytes() (tx, rx []byte) { + return r.tx.Bytes(), r.rx.Bytes() +} + +// DecodeRxBufferHeader decodes the header of a buffer posted on an rx queue. +func DecodeRxBufferHeader(b []byte) RxBuffer { + return RxBuffer{ + Offset: binary.LittleEndian.Uint64(b[postedOffset:]), + Size: binary.LittleEndian.Uint32(b[postedSize:]), + ID: binary.LittleEndian.Uint64(b[postedID:]), + UserData: binary.LittleEndian.Uint64(b[postedUserData:]), + } +} + +// RxCompletionSize returns the number of bytes needed to encode an rx +// completion containing "count" buffers. +func RxCompletionSize(count int) uint64 { + return sizeOfConsumedPacketHeader + uint64(count)*sizeOfConsumedBuffer +} + +// EncodeRxCompletion encodes an rx completion header. +func EncodeRxCompletion(b []byte, size, reserved uint32) { + binary.LittleEndian.PutUint32(b[consumedPacketSize:], size) + binary.LittleEndian.PutUint32(b[consumedPacketReserved:], reserved) +} + +// EncodeRxCompletionBuffer encodes the i-th rx completion buffer header. +func EncodeRxCompletionBuffer(b []byte, i int, rxb RxBuffer) { + b = b[RxCompletionSize(i):] + binary.LittleEndian.PutUint64(b[consumedOffset:], rxb.Offset) + binary.LittleEndian.PutUint32(b[consumedSize:], rxb.Size) + binary.LittleEndian.PutUint64(b[consumedUserData:], rxb.UserData) + binary.LittleEndian.PutUint64(b[consumedID:], rxb.ID) +} diff --git a/pkg/tcpip/link/sharedmem/queue/tx.go b/pkg/tcpip/link/sharedmem/queue/tx.go new file mode 100644 index 000000000..b04fb163b --- /dev/null +++ b/pkg/tcpip/link/sharedmem/queue/tx.go @@ -0,0 +1,141 @@ +// Copyright 2016 The Netstack Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package queue + +import ( + "encoding/binary" + + "gvisor.googlesource.com/gvisor/pkg/log" + "gvisor.googlesource.com/gvisor/pkg/tcpip/link/sharedmem/pipe" +) + +const ( + // Offsets within a packet header. + packetID = 0 + packetSize = 8 + packetReserved = 12 + + sizeOfPacketHeader = 16 + + // Offsets with a buffer descriptor + bufferOffset = 0 + bufferSize = 8 + + sizeOfBufferDescriptor = 12 +) + +// TxBuffer is the descriptor of a transmit buffer. +type TxBuffer struct { + Next *TxBuffer + Offset uint64 + Size uint32 +} + +// Tx is a transmit queue. It is implemented with one tx and one rx pipe: the +// tx pipe is used to request the transmission of packets, while the rx pipe +// is used to receive which transmissions have completed. +// +// This struct is thread-compatible. +type Tx struct { + tx pipe.Tx + rx pipe.Rx +} + +// Init initializes the transmit queue with the given pipes. +func (t *Tx) Init(tx, rx []byte) { + t.tx.Init(tx) + t.rx.Init(rx) +} + +// Enqueue queues the given linked list of buffers for transmission as one +// packet. While it is queued, the caller must not modify them. +func (t *Tx) Enqueue(id uint64, totalDataLen, bufferCount uint32, buffer *TxBuffer) bool { + // Reserve room in the tx pipe. + totalLen := sizeOfPacketHeader + uint64(bufferCount)*sizeOfBufferDescriptor + + b := t.tx.Push(totalLen) + if b == nil { + return false + } + + // Initialize the packet and buffer descriptors. + binary.LittleEndian.PutUint64(b[packetID:], id) + binary.LittleEndian.PutUint32(b[packetSize:], totalDataLen) + binary.LittleEndian.PutUint32(b[packetReserved:], 0) + + offset := sizeOfPacketHeader + for i := bufferCount; i != 0; i-- { + binary.LittleEndian.PutUint64(b[offset+bufferOffset:], buffer.Offset) + binary.LittleEndian.PutUint32(b[offset+bufferSize:], buffer.Size) + offset += sizeOfBufferDescriptor + buffer = buffer.Next + } + + t.tx.Flush() + + return true +} + +// CompletedPacket returns the id of the last completed transmission. The +// returned id, if any, refers to a value passed on a previous call to +// Enqueue(). +func (t *Tx) CompletedPacket() (id uint64, ok bool) { + for { + b := t.rx.Pull() + if b == nil { + return 0, false + } + + if len(b) != 8 { + t.rx.Flush() + log.Warningf("Ignoring completed packet: size (%v) is less than expected (%v)", len(b), 8) + continue + } + + v := binary.LittleEndian.Uint64(b) + + t.rx.Flush() + + return v, true + } +} + +// Bytes returns the byte slices on which the queue operates. +func (t *Tx) Bytes() (tx, rx []byte) { + return t.tx.Bytes(), t.rx.Bytes() +} + +// TxPacketInfo holds information about a packet sent on a tx queue. +type TxPacketInfo struct { + ID uint64 + Size uint32 + Reserved uint32 + BufferCount int +} + +// DecodeTxPacketHeader decodes the header of a packet sent over a tx queue. +func DecodeTxPacketHeader(b []byte) TxPacketInfo { + return TxPacketInfo{ + ID: binary.LittleEndian.Uint64(b[packetID:]), + Size: binary.LittleEndian.Uint32(b[packetSize:]), + Reserved: binary.LittleEndian.Uint32(b[packetReserved:]), + BufferCount: (len(b) - sizeOfPacketHeader) / sizeOfBufferDescriptor, + } +} + +// DecodeTxBufferHeader decodes the header of the i-th buffer of a packet sent +// over a tx queue. +func DecodeTxBufferHeader(b []byte, i int) TxBuffer { + b = b[sizeOfPacketHeader+i*sizeOfBufferDescriptor:] + return TxBuffer{ + Offset: binary.LittleEndian.Uint64(b[bufferOffset:]), + Size: binary.LittleEndian.Uint32(b[bufferSize:]), + } +} + +// EncodeTxCompletion encodes a tx completion header. +func EncodeTxCompletion(b []byte, id uint64) { + binary.LittleEndian.PutUint64(b, id) +} diff --git a/pkg/tcpip/link/sharedmem/rx.go b/pkg/tcpip/link/sharedmem/rx.go new file mode 100644 index 000000000..951ed966b --- /dev/null +++ b/pkg/tcpip/link/sharedmem/rx.go @@ -0,0 +1,147 @@ +// Copyright 2016 The Netstack Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package sharedmem + +import ( + "sync/atomic" + "syscall" + + "gvisor.googlesource.com/gvisor/pkg/tcpip/link/rawfile" + "gvisor.googlesource.com/gvisor/pkg/tcpip/link/sharedmem/queue" +) + +// rx holds all state associated with an rx queue. +type rx struct { + data []byte + sharedData []byte + q queue.Rx + eventFD int +} + +// init initializes all state needed by the rx queue based on the information +// provided. +// +// The caller always retains ownership of all file descriptors passed in. The +// queue implementation will duplicate any that it may need in the future. +func (r *rx) init(mtu uint32, c *QueueConfig) error { + // Map in all buffers. + txPipe, err := getBuffer(c.TxPipeFD) + if err != nil { + return err + } + + rxPipe, err := getBuffer(c.RxPipeFD) + if err != nil { + syscall.Munmap(txPipe) + return err + } + + data, err := getBuffer(c.DataFD) + if err != nil { + syscall.Munmap(txPipe) + syscall.Munmap(rxPipe) + return err + } + + sharedData, err := getBuffer(c.SharedDataFD) + if err != nil { + syscall.Munmap(txPipe) + syscall.Munmap(rxPipe) + syscall.Munmap(data) + return err + } + + // Duplicate the eventFD so that caller can close it but we can still + // use it. + efd, err := syscall.Dup(c.EventFD) + if err != nil { + syscall.Munmap(txPipe) + syscall.Munmap(rxPipe) + syscall.Munmap(data) + syscall.Munmap(sharedData) + return err + } + + // Set the eventfd as non-blocking. + if err := syscall.SetNonblock(efd, true); err != nil { + syscall.Munmap(txPipe) + syscall.Munmap(rxPipe) + syscall.Munmap(data) + syscall.Munmap(sharedData) + syscall.Close(efd) + return err + } + + // Initialize state based on buffers. + r.q.Init(txPipe, rxPipe, sharedDataPointer(sharedData)) + r.data = data + r.eventFD = efd + r.sharedData = sharedData + + return nil +} + +// cleanup releases all resources allocated during init(). It must only be +// called if init() has previously succeeded. +func (r *rx) cleanup() { + a, b := r.q.Bytes() + syscall.Munmap(a) + syscall.Munmap(b) + + syscall.Munmap(r.data) + syscall.Munmap(r.sharedData) + syscall.Close(r.eventFD) +} + +// postAndReceive posts the provided buffers (if any), and then tries to read +// from the receive queue. +// +// Capacity permitting, it reuses the posted buffer slice to store the buffers +// that were read as well. +// +// This function will block if there aren't any available packets. +func (r *rx) postAndReceive(b []queue.RxBuffer, stopRequested *uint32) ([]queue.RxBuffer, uint32) { + // Post the buffers first. If we cannot post, sleep until we can. We + // never post more than will fit concurrently, so it's safe to wait + // until enough room is available. + if len(b) != 0 && !r.q.PostBuffers(b) { + r.q.EnableNotification() + for !r.q.PostBuffers(b) { + var tmp [8]byte + rawfile.BlockingRead(r.eventFD, tmp[:]) + if atomic.LoadUint32(stopRequested) != 0 { + r.q.DisableNotification() + return nil, 0 + } + } + r.q.DisableNotification() + } + + // Read the next set of descriptors. + b, n := r.q.Dequeue(b[:0]) + if len(b) != 0 { + return b, n + } + + // Data isn't immediately available. Enable eventfd notifications. + r.q.EnableNotification() + for { + b, n = r.q.Dequeue(b) + if len(b) != 0 { + break + } + + // Wait for notification. + var tmp [8]byte + rawfile.BlockingRead(r.eventFD, tmp[:]) + if atomic.LoadUint32(stopRequested) != 0 { + r.q.DisableNotification() + return nil, 0 + } + } + r.q.DisableNotification() + + return b, n +} diff --git a/pkg/tcpip/link/sharedmem/sharedmem.go b/pkg/tcpip/link/sharedmem/sharedmem.go new file mode 100644 index 000000000..2c0f1b294 --- /dev/null +++ b/pkg/tcpip/link/sharedmem/sharedmem.go @@ -0,0 +1,240 @@ +// Copyright 2016 The Netstack Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package sharedmem provides the implemention of data-link layer endpoints +// backed by shared memory. +// +// Shared memory endpoints can be used in the networking stack by calling New() +// to create a new endpoint, and then passing it as an argument to +// Stack.CreateNIC(). +package sharedmem + +import ( + "sync" + "sync/atomic" + "syscall" + + "gvisor.googlesource.com/gvisor/pkg/log" + "gvisor.googlesource.com/gvisor/pkg/tcpip" + "gvisor.googlesource.com/gvisor/pkg/tcpip/buffer" + "gvisor.googlesource.com/gvisor/pkg/tcpip/header" + "gvisor.googlesource.com/gvisor/pkg/tcpip/link/sharedmem/queue" + "gvisor.googlesource.com/gvisor/pkg/tcpip/stack" +) + +// QueueConfig holds all the file descriptors needed to describe a tx or rx +// queue over shared memory. It is used when creating new shared memory +// endpoints to describe tx and rx queues. +type QueueConfig struct { + // DataFD is a file descriptor for the file that contains the data to + // be transmitted via this queue. Descriptors contain offsets within + // this file. + DataFD int + + // EventFD is a file descriptor for the event that is signaled when + // data is becomes available in this queue. + EventFD int + + // TxPipeFD is a file descriptor for the tx pipe associated with the + // queue. + TxPipeFD int + + // RxPipeFD is a file descriptor for the rx pipe associated with the + // queue. + RxPipeFD int + + // SharedDataFD is a file descriptor for the file that contains shared + // state between the two ends of the queue. This data specifies, for + // example, whether EventFD signaling is enabled or disabled. + SharedDataFD int +} + +type endpoint struct { + // mtu (maximum transmission unit) is the maximum size of a packet. + mtu uint32 + + // bufferSize is the size of each individual buffer. + bufferSize uint32 + + // addr is the local address of this endpoint. + addr tcpip.LinkAddress + + // rx is the receive queue. + rx rx + + // stopRequested is to be accessed atomically only, and determines if + // the worker goroutines should stop. + stopRequested uint32 + + // Wait group used to indicate that all workers have stopped. + completed sync.WaitGroup + + // mu protects the following fields. + mu sync.Mutex + + // tx is the transmit queue. + tx tx + + // workerStarted specifies whether the worker goroutine was started. + workerStarted bool +} + +// New creates a new shared-memory-based endpoint. Buffers will be broken up +// into buffers of "bufferSize" bytes. +func New(mtu, bufferSize uint32, addr tcpip.LinkAddress, tx, rx QueueConfig) (tcpip.LinkEndpointID, error) { + e := &endpoint{ + mtu: mtu, + bufferSize: bufferSize, + addr: addr, + } + + if err := e.tx.init(bufferSize, &tx); err != nil { + return 0, err + } + + if err := e.rx.init(bufferSize, &rx); err != nil { + e.tx.cleanup() + return 0, err + } + + return stack.RegisterLinkEndpoint(e), nil +} + +// Close frees all resources associated with the endpoint. +func (e *endpoint) Close() { + // Tell dispatch goroutine to stop, then write to the eventfd so that + // it wakes up in case it's sleeping. + atomic.StoreUint32(&e.stopRequested, 1) + syscall.Write(e.rx.eventFD, []byte{1, 0, 0, 0, 0, 0, 0, 0}) + + // Cleanup the queues inline if the worker hasn't started yet; we also + // know it won't start from now on because stopRequested is set to 1. + e.mu.Lock() + workerPresent := e.workerStarted + e.mu.Unlock() + + if !workerPresent { + e.tx.cleanup() + e.rx.cleanup() + } +} + +// Wait waits until all workers have stopped after a Close() call. +func (e *endpoint) Wait() { + e.completed.Wait() +} + +// Attach implements stack.LinkEndpoint.Attach. It launches the goroutine that +// reads packets from the rx queue. +func (e *endpoint) Attach(dispatcher stack.NetworkDispatcher) { + e.mu.Lock() + if !e.workerStarted && atomic.LoadUint32(&e.stopRequested) == 0 { + e.workerStarted = true + e.completed.Add(1) + go e.dispatchLoop(dispatcher) // S/R-FIXME + } + e.mu.Unlock() +} + +// MTU implements stack.LinkEndpoint.MTU. It returns the value initialized +// during construction. +func (e *endpoint) MTU() uint32 { + return e.mtu - header.EthernetMinimumSize +} + +// Capabilities implements stack.LinkEndpoint.Capabilities. +func (*endpoint) Capabilities() stack.LinkEndpointCapabilities { + return 0 +} + +// MaxHeaderLength implements stack.LinkEndpoint.MaxHeaderLength. It returns the +// ethernet frame header size. +func (*endpoint) MaxHeaderLength() uint16 { + return header.EthernetMinimumSize +} + +// LinkAddress implements stack.LinkEndpoint.LinkAddress. It returns the local +// link address. +func (e *endpoint) LinkAddress() tcpip.LinkAddress { + return e.addr +} + +// WritePacket writes outbound packets to the file descriptor. If it is not +// currently writable, the packet is dropped. +func (e *endpoint) WritePacket(r *stack.Route, hdr *buffer.Prependable, payload buffer.View, protocol tcpip.NetworkProtocolNumber) *tcpip.Error { + // Add the ethernet header here. + eth := header.Ethernet(hdr.Prepend(header.EthernetMinimumSize)) + eth.Encode(&header.EthernetFields{ + DstAddr: r.RemoteLinkAddress, + SrcAddr: e.addr, + Type: protocol, + }) + + // Transmit the packet. + e.mu.Lock() + ok := e.tx.transmit(hdr.UsedBytes(), payload) + e.mu.Unlock() + + if !ok { + return tcpip.ErrWouldBlock + } + + return nil +} + +// dispatchLoop reads packets from the rx queue in a loop and dispatches them +// to the network stack. +func (e *endpoint) dispatchLoop(d stack.NetworkDispatcher) { + // Post initial set of buffers. + limit := e.rx.q.PostedBuffersLimit() + if l := uint64(len(e.rx.data)) / uint64(e.bufferSize); limit > l { + limit = l + } + for i := uint64(0); i < limit; i++ { + b := queue.RxBuffer{ + Offset: i * uint64(e.bufferSize), + Size: e.bufferSize, + ID: i, + } + if !e.rx.q.PostBuffers([]queue.RxBuffer{b}) { + log.Warningf("Unable to post %v-th buffer", i) + } + } + + // Read in a loop until a stop is requested. + var rxb []queue.RxBuffer + views := []buffer.View{nil} + vv := buffer.NewVectorisedView(0, views) + for atomic.LoadUint32(&e.stopRequested) == 0 { + var n uint32 + rxb, n = e.rx.postAndReceive(rxb, &e.stopRequested) + + // Copy data from the shared area to its own buffer, then + // prepare to repost the buffer. + b := make([]byte, n) + offset := uint32(0) + for i := range rxb { + copy(b[offset:], e.rx.data[rxb[i].Offset:][:rxb[i].Size]) + offset += rxb[i].Size + + rxb[i].Size = e.bufferSize + } + + if n < header.EthernetMinimumSize { + continue + } + + // Send packet up the stack. + eth := header.Ethernet(b) + views[0] = b[header.EthernetMinimumSize:] + vv.SetSize(int(n) - header.EthernetMinimumSize) + d.DeliverNetworkPacket(e, eth.SourceAddress(), eth.Type(), &vv) + } + + // Clean state. + e.tx.cleanup() + e.rx.cleanup() + + e.completed.Done() +} diff --git a/pkg/tcpip/link/sharedmem/sharedmem_test.go b/pkg/tcpip/link/sharedmem/sharedmem_test.go new file mode 100644 index 000000000..f71e4751f --- /dev/null +++ b/pkg/tcpip/link/sharedmem/sharedmem_test.go @@ -0,0 +1,703 @@ +// Copyright 2016 The Netstack Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package sharedmem + +import ( + "io/ioutil" + "math/rand" + "os" + "reflect" + "sync" + "syscall" + "testing" + "time" + + "gvisor.googlesource.com/gvisor/pkg/tcpip" + "gvisor.googlesource.com/gvisor/pkg/tcpip/buffer" + "gvisor.googlesource.com/gvisor/pkg/tcpip/header" + "gvisor.googlesource.com/gvisor/pkg/tcpip/link/sharedmem/pipe" + "gvisor.googlesource.com/gvisor/pkg/tcpip/link/sharedmem/queue" + "gvisor.googlesource.com/gvisor/pkg/tcpip/stack" +) + +const ( + localLinkAddr = "\xde\xad\xbe\xef\x56\x78" + remoteLinkAddr = "\xde\xad\xbe\xef\x12\x34" + + queueDataSize = 1024 * 1024 + queuePipeSize = 4096 +) + +type queueBuffers struct { + data []byte + rx pipe.Tx + tx pipe.Rx +} + +func initQueue(t *testing.T, q *queueBuffers, c *QueueConfig) { + // Prepare tx pipe. + b, err := getBuffer(c.TxPipeFD) + if err != nil { + t.Fatalf("getBuffer failed: %v", err) + } + q.tx.Init(b) + + // Prepare rx pipe. + b, err = getBuffer(c.RxPipeFD) + if err != nil { + t.Fatalf("getBuffer failed: %v", err) + } + q.rx.Init(b) + + // Get data slice. + q.data, err = getBuffer(c.DataFD) + if err != nil { + t.Fatalf("getBuffer failed: %v", err) + } +} + +func (q *queueBuffers) cleanup() { + syscall.Munmap(q.tx.Bytes()) + syscall.Munmap(q.rx.Bytes()) + syscall.Munmap(q.data) +} + +type packetInfo struct { + addr tcpip.LinkAddress + proto tcpip.NetworkProtocolNumber + vv buffer.VectorisedView +} + +type testContext struct { + t *testing.T + ep *endpoint + txCfg QueueConfig + rxCfg QueueConfig + txq queueBuffers + rxq queueBuffers + + packetCh chan struct{} + mu sync.Mutex + packets []packetInfo +} + +func newTestContext(t *testing.T, mtu, bufferSize uint32, addr tcpip.LinkAddress) *testContext { + var err error + c := &testContext{ + t: t, + packetCh: make(chan struct{}, 1000000), + } + c.txCfg = createQueueFDs(t, queueSizes{ + dataSize: queueDataSize, + txPipeSize: queuePipeSize, + rxPipeSize: queuePipeSize, + sharedDataSize: 4096, + }) + + c.rxCfg = createQueueFDs(t, queueSizes{ + dataSize: queueDataSize, + txPipeSize: queuePipeSize, + rxPipeSize: queuePipeSize, + sharedDataSize: 4096, + }) + + initQueue(t, &c.txq, &c.txCfg) + initQueue(t, &c.rxq, &c.rxCfg) + + id, err := New(mtu, bufferSize, addr, c.txCfg, c.rxCfg) + if err != nil { + t.Fatalf("New failed: %v", err) + } + + c.ep = stack.FindLinkEndpoint(id).(*endpoint) + c.ep.Attach(c) + + return c +} + +func (c *testContext) DeliverNetworkPacket(_ stack.LinkEndpoint, remoteAddr tcpip.LinkAddress, proto tcpip.NetworkProtocolNumber, vv *buffer.VectorisedView) { + c.mu.Lock() + c.packets = append(c.packets, packetInfo{ + addr: remoteAddr, + proto: proto, + vv: vv.Clone(nil), + }) + c.mu.Unlock() + + c.packetCh <- struct{}{} +} + +func (c *testContext) cleanup() { + c.ep.Close() + closeFDs(&c.txCfg) + closeFDs(&c.rxCfg) + c.txq.cleanup() + c.rxq.cleanup() +} + +func (c *testContext) waitForPackets(n int, to <-chan time.Time, errorStr string) { + for i := 0; i < n; i++ { + select { + case <-c.packetCh: + case <-to: + c.t.Fatalf(errorStr) + } + } +} + +func (c *testContext) pushRxCompletion(size uint32, bs []queue.RxBuffer) { + b := c.rxq.rx.Push(queue.RxCompletionSize(len(bs))) + queue.EncodeRxCompletion(b, size, 0) + for i := range bs { + queue.EncodeRxCompletionBuffer(b, i, queue.RxBuffer{ + Offset: bs[i].Offset, + Size: bs[i].Size, + ID: bs[i].ID, + }) + } +} + +func randomFill(b []byte) { + for i := range b { + b[i] = byte(rand.Intn(256)) + } +} + +func shuffle(b []int) { + for i := len(b) - 1; i >= 0; i-- { + j := rand.Intn(i + 1) + b[i], b[j] = b[j], b[i] + } +} + +func createFile(t *testing.T, size int64, initQueue bool) int { + tmpDir := os.Getenv("TEST_TMPDIR") + if tmpDir == "" { + tmpDir = os.Getenv("TMPDIR") + } + f, err := ioutil.TempFile(tmpDir, "sharedmem_test") + if err != nil { + t.Fatalf("TempFile failed: %v", err) + } + defer f.Close() + syscall.Unlink(f.Name()) + + if initQueue { + // Write the "slot-free" flag in the initial queue. + _, err := f.WriteAt([]byte{0, 0, 0, 0, 0, 0, 0, 0x80}, 0) + if err != nil { + t.Fatalf("WriteAt failed: %v", err) + } + } + + fd, err := syscall.Dup(int(f.Fd())) + if err != nil { + t.Fatalf("Dup failed: %v", err) + } + + if err := syscall.Ftruncate(fd, size); err != nil { + syscall.Close(fd) + t.Fatalf("Ftruncate failed: %v", err) + } + + return fd +} + +func closeFDs(c *QueueConfig) { + syscall.Close(c.DataFD) + syscall.Close(c.EventFD) + syscall.Close(c.TxPipeFD) + syscall.Close(c.RxPipeFD) + syscall.Close(c.SharedDataFD) +} + +type queueSizes struct { + dataSize int64 + txPipeSize int64 + rxPipeSize int64 + sharedDataSize int64 +} + +func createQueueFDs(t *testing.T, s queueSizes) QueueConfig { + fd, _, err := syscall.RawSyscall(syscall.SYS_EVENTFD2, 0, 0, 0) + if err != 0 { + t.Fatalf("eventfd failed: %v", error(err)) + } + + return QueueConfig{ + EventFD: int(fd), + DataFD: createFile(t, s.dataSize, false), + TxPipeFD: createFile(t, s.txPipeSize, true), + RxPipeFD: createFile(t, s.rxPipeSize, true), + SharedDataFD: createFile(t, s.sharedDataSize, false), + } +} + +// TestSimpleSend sends 1000 packets with random header and payload sizes, +// then checks that the right payload is received on the shared memory queues. +func TestSimpleSend(t *testing.T) { + c := newTestContext(t, 20000, 1500, localLinkAddr) + defer c.cleanup() + + // Prepare route. + r := stack.Route{ + RemoteLinkAddress: remoteLinkAddr, + } + + for iters := 1000; iters > 0; iters-- { + // Prepare and send packet. + n := rand.Intn(10000) + hdr := buffer.NewPrependable(n + int(c.ep.MaxHeaderLength())) + hdrBuf := hdr.Prepend(n) + randomFill(hdrBuf) + + n = rand.Intn(10000) + buf := buffer.NewView(n) + randomFill(buf) + + proto := tcpip.NetworkProtocolNumber(rand.Intn(0x10000)) + err := c.ep.WritePacket(&r, &hdr, buf, proto) + if err != nil { + t.Fatalf("WritePacket failed: %v", err) + } + + // Receive packet. + desc := c.txq.tx.Pull() + pi := queue.DecodeTxPacketHeader(desc) + contents := make([]byte, 0, pi.Size) + for i := 0; i < pi.BufferCount; i++ { + bi := queue.DecodeTxBufferHeader(desc, i) + contents = append(contents, c.txq.data[bi.Offset:][:bi.Size]...) + } + c.txq.tx.Flush() + + if pi.Reserved != 0 { + t.Fatalf("Reserved value is non-zero: 0x%x", pi.Reserved) + } + + // Check the thernet header. + ethTemplate := make(header.Ethernet, header.EthernetMinimumSize) + ethTemplate.Encode(&header.EthernetFields{ + SrcAddr: localLinkAddr, + DstAddr: remoteLinkAddr, + Type: proto, + }) + if got := contents[:header.EthernetMinimumSize]; !reflect.DeepEqual(got, []byte(ethTemplate)) { + t.Fatalf("Bad ethernet header in packet: got %x, want %x", got, ethTemplate) + } + + // Compare contents skipping the ethernet header added by the + // endpoint. + merged := append(hdrBuf, buf...) + if uint32(len(contents)) < pi.Size { + t.Fatalf("Sum of buffers is less than packet size: %v < %v", len(contents), pi.Size) + } + contents = contents[:pi.Size][header.EthernetMinimumSize:] + + if !reflect.DeepEqual(contents, merged) { + t.Fatalf("Buffers are different: got %x (%v bytes), want %x (%v bytes)", contents, len(contents), merged, len(merged)) + } + + // Tell the endpoint about the completion of the write. + b := c.txq.rx.Push(8) + queue.EncodeTxCompletion(b, pi.ID) + c.txq.rx.Flush() + } +} + +// TestFillTxQueue sends packets until the queue is full. +func TestFillTxQueue(t *testing.T) { + c := newTestContext(t, 20000, 1500, localLinkAddr) + defer c.cleanup() + + // Prepare to send a packet. + r := stack.Route{ + RemoteLinkAddress: remoteLinkAddr, + } + + buf := buffer.NewView(100) + + // Each packet is uses no more than 40 bytes, so write that many packets + // until the tx queue if full. + ids := make(map[uint64]struct{}) + for i := queuePipeSize / 40; i > 0; i-- { + hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength())) + if err := c.ep.WritePacket(&r, &hdr, buf, header.IPv4ProtocolNumber); err != nil { + t.Fatalf("WritePacket failed unexpectedly: %v", err) + } + + // Check that they have different IDs. + desc := c.txq.tx.Pull() + pi := queue.DecodeTxPacketHeader(desc) + if _, ok := ids[pi.ID]; ok { + t.Fatalf("ID (%v) reused", pi.ID) + } + ids[pi.ID] = struct{}{} + } + + // Next attempt to write must fail. + hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength())) + err := c.ep.WritePacket(&r, &hdr, buf, header.IPv4ProtocolNumber) + if want := tcpip.ErrWouldBlock; err != want { + t.Fatalf("WritePacket return unexpected result: got %v, want %v", err, want) + } +} + +// TestFillTxQueueAfterBadCompletion sends a bad completion, then sends packets +// until the queue is full. +func TestFillTxQueueAfterBadCompletion(t *testing.T) { + c := newTestContext(t, 20000, 1500, localLinkAddr) + defer c.cleanup() + + // Send a bad completion. + queue.EncodeTxCompletion(c.txq.rx.Push(8), 1) + c.txq.rx.Flush() + + // Prepare to send a packet. + r := stack.Route{ + RemoteLinkAddress: remoteLinkAddr, + } + + buf := buffer.NewView(100) + + // Send two packets so that the id slice has at least two slots. + for i := 2; i > 0; i-- { + hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength())) + if err := c.ep.WritePacket(&r, &hdr, buf, header.IPv4ProtocolNumber); err != nil { + t.Fatalf("WritePacket failed unexpectedly: %v", err) + } + } + + // Complete the two writes twice. + for i := 2; i > 0; i-- { + pi := queue.DecodeTxPacketHeader(c.txq.tx.Pull()) + + queue.EncodeTxCompletion(c.txq.rx.Push(8), pi.ID) + queue.EncodeTxCompletion(c.txq.rx.Push(8), pi.ID) + c.txq.rx.Flush() + } + c.txq.tx.Flush() + + // Each packet is uses no more than 40 bytes, so write that many packets + // until the tx queue if full. + ids := make(map[uint64]struct{}) + for i := queuePipeSize / 40; i > 0; i-- { + hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength())) + if err := c.ep.WritePacket(&r, &hdr, buf, header.IPv4ProtocolNumber); err != nil { + t.Fatalf("WritePacket failed unexpectedly: %v", err) + } + + // Check that they have different IDs. + desc := c.txq.tx.Pull() + pi := queue.DecodeTxPacketHeader(desc) + if _, ok := ids[pi.ID]; ok { + t.Fatalf("ID (%v) reused", pi.ID) + } + ids[pi.ID] = struct{}{} + } + + // Next attempt to write must fail. + hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength())) + err := c.ep.WritePacket(&r, &hdr, buf, header.IPv4ProtocolNumber) + if want := tcpip.ErrWouldBlock; err != want { + t.Fatalf("WritePacket return unexpected result: got %v, want %v", err, want) + } +} + +// TestFillTxMemory sends packets until the we run out of shared memory. +func TestFillTxMemory(t *testing.T) { + const bufferSize = 1500 + c := newTestContext(t, 20000, bufferSize, localLinkAddr) + defer c.cleanup() + + // Prepare to send a packet. + r := stack.Route{ + RemoteLinkAddress: remoteLinkAddr, + } + + buf := buffer.NewView(100) + + // Each packet is uses up one buffer, so write as many as possible until + // we fill the memory. + ids := make(map[uint64]struct{}) + for i := queueDataSize / bufferSize; i > 0; i-- { + hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength())) + if err := c.ep.WritePacket(&r, &hdr, buf, header.IPv4ProtocolNumber); err != nil { + t.Fatalf("WritePacket failed unexpectedly: %v", err) + } + + // Check that they have different IDs. + desc := c.txq.tx.Pull() + pi := queue.DecodeTxPacketHeader(desc) + if _, ok := ids[pi.ID]; ok { + t.Fatalf("ID (%v) reused", pi.ID) + } + ids[pi.ID] = struct{}{} + c.txq.tx.Flush() + } + + // Next attempt to write must fail. + hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength())) + err := c.ep.WritePacket(&r, &hdr, buf, header.IPv4ProtocolNumber) + if want := tcpip.ErrWouldBlock; err != want { + t.Fatalf("WritePacket return unexpected result: got %v, want %v", err, want) + } +} + +// TestFillTxMemoryWithMultiBuffer sends packets until the we run out of +// shared memory for a 2-buffer packet, but still with room for a 1-buffer +// packet. +func TestFillTxMemoryWithMultiBuffer(t *testing.T) { + const bufferSize = 1500 + c := newTestContext(t, 20000, bufferSize, localLinkAddr) + defer c.cleanup() + + // Prepare to send a packet. + r := stack.Route{ + RemoteLinkAddress: remoteLinkAddr, + } + + buf := buffer.NewView(100) + + // Each packet is uses up one buffer, so write as many as possible + // until there is only one buffer left. + for i := queueDataSize/bufferSize - 1; i > 0; i-- { + hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength())) + if err := c.ep.WritePacket(&r, &hdr, buf, header.IPv4ProtocolNumber); err != nil { + t.Fatalf("WritePacket failed unexpectedly: %v", err) + } + + // Pull the posted buffer. + c.txq.tx.Pull() + c.txq.tx.Flush() + } + + // Attempt to write a two-buffer packet. It must fail. + hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength())) + err := c.ep.WritePacket(&r, &hdr, buffer.NewView(bufferSize), header.IPv4ProtocolNumber) + if want := tcpip.ErrWouldBlock; err != want { + t.Fatalf("WritePacket return unexpected result: got %v, want %v", err, want) + } + + // Attempt to write a one-buffer packet. It must succeed. + hdr = buffer.NewPrependable(int(c.ep.MaxHeaderLength())) + if err := c.ep.WritePacket(&r, &hdr, buf, header.IPv4ProtocolNumber); err != nil { + t.Fatalf("WritePacket failed unexpectedly: %v", err) + } +} + +func pollPull(t *testing.T, p *pipe.Rx, to <-chan time.Time, errStr string) []byte { + for { + b := p.Pull() + if b != nil { + return b + } + + select { + case <-time.After(10 * time.Millisecond): + case <-to: + t.Fatalf(errStr) + } + } +} + +// TestSimpleReceive completes 1000 different receives with random payload and +// random number of buffers. It checks that the contents match the expected +// values. +func TestSimpleReceive(t *testing.T) { + const bufferSize = 1500 + c := newTestContext(t, 20000, bufferSize, localLinkAddr) + defer c.cleanup() + + // Check that buffers have been posted. + limit := c.ep.rx.q.PostedBuffersLimit() + timeout := time.After(2 * time.Second) + for i := uint64(0); i < limit; i++ { + bi := queue.DecodeRxBufferHeader(pollPull(t, &c.rxq.tx, timeout, "Timeout waiting for all buffers to be posted")) + + if want := i * bufferSize; want != bi.Offset { + t.Fatalf("Bad posted offset: got %v, want %v", bi.Offset, want) + } + + if want := i; want != bi.ID { + t.Fatalf("Bad posted ID: got %v, want %v", bi.ID, want) + } + + if bufferSize != bi.Size { + t.Fatalf("Bad posted bufferSize: got %v, want %v", bi.Size, bufferSize) + } + } + c.rxq.tx.Flush() + + // Create a slice with the indices 0..limit-1. + idx := make([]int, limit) + for i := range idx { + idx[i] = i + } + + // Complete random packets 1000 times. + for iters := 1000; iters > 0; iters-- { + // Prepare a random packet. + shuffle(idx) + n := 1 + rand.Intn(10) + bufs := make([]queue.RxBuffer, n) + contents := make([]byte, bufferSize*n-rand.Intn(500)) + randomFill(contents) + for i := range bufs { + j := idx[i] + bufs[i].Size = bufferSize + bufs[i].Offset = uint64(bufferSize * j) + bufs[i].ID = uint64(j) + + copy(c.rxq.data[bufs[i].Offset:][:bufferSize], contents[i*bufferSize:]) + } + + // Push completion. + c.pushRxCompletion(uint32(len(contents)), bufs) + c.rxq.rx.Flush() + syscall.Write(c.rxCfg.EventFD, []byte{1, 0, 0, 0, 0, 0, 0, 0}) + + // Wait for packet to be received, then check it. + c.waitForPackets(1, time.After(time.Second), "Error waiting for packet") + c.mu.Lock() + rcvd := []byte(c.packets[0].vv.First()) + c.packets = c.packets[:0] + c.mu.Unlock() + + contents = contents[header.EthernetMinimumSize:] + if !reflect.DeepEqual(contents, rcvd) { + t.Fatalf("Unexpected buffer contents: got %x, want %x", rcvd, contents) + } + + // Check that buffers have been reposted. + for i := range bufs { + bi := queue.DecodeRxBufferHeader(pollPull(t, &c.rxq.tx, timeout, "Timeout waiting for buffers to be reposted")) + if !reflect.DeepEqual(bi, bufs[i]) { + t.Fatalf("Unexpected buffer reposted: got %x, want %x", bi, bufs[i]) + } + } + c.rxq.tx.Flush() + } +} + +// TestRxBuffersReposted tests that rx buffers get reposted after they have been +// completed. +func TestRxBuffersReposted(t *testing.T) { + const bufferSize = 1500 + c := newTestContext(t, 20000, bufferSize, localLinkAddr) + defer c.cleanup() + + // Receive all posted buffers. + limit := c.ep.rx.q.PostedBuffersLimit() + buffers := make([]queue.RxBuffer, 0, limit) + timeout := time.After(2 * time.Second) + for i := limit; i > 0; i-- { + buffers = append(buffers, queue.DecodeRxBufferHeader(pollPull(t, &c.rxq.tx, timeout, "Timeout waiting for all buffers"))) + } + c.rxq.tx.Flush() + + // Check that all buffers are reposted when individually completed. + timeout = time.After(2 * time.Second) + for i := range buffers { + // Complete the buffer. + c.pushRxCompletion(buffers[i].Size, buffers[i:][:1]) + c.rxq.rx.Flush() + syscall.Write(c.rxCfg.EventFD, []byte{1, 0, 0, 0, 0, 0, 0, 0}) + + // Wait for it to be reposted. + bi := queue.DecodeRxBufferHeader(pollPull(t, &c.rxq.tx, timeout, "Timeout waiting for buffer to be reposted")) + if !reflect.DeepEqual(bi, buffers[i]) { + t.Fatalf("Different buffer posted: got %v, want %v", bi, buffers[i]) + } + } + c.rxq.tx.Flush() + + // Check that all buffers are reposted when completed in pairs. + timeout = time.After(2 * time.Second) + for i := 0; i < len(buffers)/2; i++ { + // Complete with two buffers. + c.pushRxCompletion(2*bufferSize, buffers[2*i:][:2]) + c.rxq.rx.Flush() + syscall.Write(c.rxCfg.EventFD, []byte{1, 0, 0, 0, 0, 0, 0, 0}) + + // Wait for them to be reposted. + bi := queue.DecodeRxBufferHeader(pollPull(t, &c.rxq.tx, timeout, "Timeout waiting for buffer to be reposted")) + if !reflect.DeepEqual(bi, buffers[2*i]) { + t.Fatalf("Different buffer posted: got %v, want %v", bi, buffers[2*i]) + } + bi = queue.DecodeRxBufferHeader(pollPull(t, &c.rxq.tx, timeout, "Timeout waiting for buffer to be reposted")) + if !reflect.DeepEqual(bi, buffers[2*i+1]) { + t.Fatalf("Different buffer posted: got %v, want %v", bi, buffers[2*i+1]) + } + } + c.rxq.tx.Flush() +} + +// TestReceivePostingIsFull checks that the endpoint will properly handle the +// case when a received buffer cannot be immediately reposted because it hasn't +// been pulled from the tx pipe yet. +func TestReceivePostingIsFull(t *testing.T) { + const bufferSize = 1500 + c := newTestContext(t, 20000, bufferSize, localLinkAddr) + defer c.cleanup() + + // Complete first posted buffer before flushing it from the tx pipe. + first := queue.DecodeRxBufferHeader(pollPull(t, &c.rxq.tx, time.After(time.Second), "Timeout waiting for first buffer to be posted")) + c.pushRxCompletion(first.Size, []queue.RxBuffer{first}) + c.rxq.rx.Flush() + syscall.Write(c.rxCfg.EventFD, []byte{1, 0, 0, 0, 0, 0, 0, 0}) + + // Check that packet is received. + c.waitForPackets(1, time.After(time.Second), "Timeout waiting for completed packet") + + // Complete another buffer. + second := queue.DecodeRxBufferHeader(pollPull(t, &c.rxq.tx, time.After(time.Second), "Timeout waiting for second buffer to be posted")) + c.pushRxCompletion(second.Size, []queue.RxBuffer{second}) + c.rxq.rx.Flush() + syscall.Write(c.rxCfg.EventFD, []byte{1, 0, 0, 0, 0, 0, 0, 0}) + + // Check that no packet is received yet, as the worker is blocked trying + // to repost. + select { + case <-time.After(500 * time.Millisecond): + case <-c.packetCh: + t.Fatalf("Unexpected packet received") + } + + // Flush tx queue, which will allow the first buffer to be reposted, + // and the second completion to be pulled. + c.rxq.tx.Flush() + syscall.Write(c.rxCfg.EventFD, []byte{1, 0, 0, 0, 0, 0, 0, 0}) + + // Check that second packet completes. + c.waitForPackets(1, time.After(time.Second), "Timeout waiting for second completed packet") +} + +// TestCloseWhileWaitingToPost closes the endpoint while it is waiting to +// repost a buffer. Make sure it backs out. +func TestCloseWhileWaitingToPost(t *testing.T) { + const bufferSize = 1500 + c := newTestContext(t, 20000, bufferSize, localLinkAddr) + cleaned := false + defer func() { + if !cleaned { + c.cleanup() + } + }() + + // Complete first posted buffer before flushing it from the tx pipe. + bi := queue.DecodeRxBufferHeader(pollPull(t, &c.rxq.tx, time.After(time.Second), "Timeout waiting for initial buffer to be posted")) + c.pushRxCompletion(bi.Size, []queue.RxBuffer{bi}) + c.rxq.rx.Flush() + syscall.Write(c.rxCfg.EventFD, []byte{1, 0, 0, 0, 0, 0, 0, 0}) + + // Wait for packet to be indicated. + c.waitForPackets(1, time.After(time.Second), "Timeout waiting for completed packet") + + // Cleanup and wait for worker to complete. + c.cleanup() + cleaned = true + c.ep.Wait() +} diff --git a/pkg/tcpip/link/sharedmem/sharedmem_unsafe.go b/pkg/tcpip/link/sharedmem/sharedmem_unsafe.go new file mode 100644 index 000000000..52f93f480 --- /dev/null +++ b/pkg/tcpip/link/sharedmem/sharedmem_unsafe.go @@ -0,0 +1,15 @@ +// Copyright 2016 The Netstack Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package sharedmem + +import ( + "unsafe" +) + +// sharedDataPointer converts the shared data slice into a pointer so that it +// can be used in atomic operations. +func sharedDataPointer(sharedData []byte) *uint32 { + return (*uint32)(unsafe.Pointer(&sharedData[0:4][0])) +} diff --git a/pkg/tcpip/link/sharedmem/tx.go b/pkg/tcpip/link/sharedmem/tx.go new file mode 100644 index 000000000..bca1d79b4 --- /dev/null +++ b/pkg/tcpip/link/sharedmem/tx.go @@ -0,0 +1,262 @@ +// Copyright 2016 The Netstack Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package sharedmem + +import ( + "math" + "syscall" + + "gvisor.googlesource.com/gvisor/pkg/tcpip/link/sharedmem/queue" +) + +const ( + nilID = math.MaxUint64 +) + +// tx holds all state associated with a tx queue. +type tx struct { + data []byte + q queue.Tx + ids idManager + bufs bufferManager +} + +// init initializes all state needed by the tx queue based on the information +// provided. +// +// The caller always retains ownership of all file descriptors passed in. The +// queue implementation will duplicate any that it may need in the future. +func (t *tx) init(mtu uint32, c *QueueConfig) error { + // Map in all buffers. + txPipe, err := getBuffer(c.TxPipeFD) + if err != nil { + return err + } + + rxPipe, err := getBuffer(c.RxPipeFD) + if err != nil { + syscall.Munmap(txPipe) + return err + } + + data, err := getBuffer(c.DataFD) + if err != nil { + syscall.Munmap(txPipe) + syscall.Munmap(rxPipe) + return err + } + + // Initialize state based on buffers. + t.q.Init(txPipe, rxPipe) + t.ids.init() + t.bufs.init(0, len(data), int(mtu)) + t.data = data + + return nil +} + +// cleanup releases all resources allocated during init(). It must only be +// called if init() has previously succeeded. +func (t *tx) cleanup() { + a, b := t.q.Bytes() + syscall.Munmap(a) + syscall.Munmap(b) + syscall.Munmap(t.data) +} + +// transmit sends a packet made up of up to two buffers. Returns a boolean that +// specifies whether the packet was successfully transmitted. +func (t *tx) transmit(a, b []byte) bool { + // Pull completions from the tx queue and add their buffers back to the + // pool so that we can reuse them. + for { + id, ok := t.q.CompletedPacket() + if !ok { + break + } + + if buf := t.ids.remove(id); buf != nil { + t.bufs.free(buf) + } + } + + bSize := t.bufs.entrySize + total := uint32(len(a) + len(b)) + bufCount := (total + bSize - 1) / bSize + + // Allocate enough buffers to hold all the data. + var buf *queue.TxBuffer + for i := bufCount; i != 0; i-- { + b := t.bufs.alloc() + if b == nil { + // Failed to get all buffers. Return to the pool + // whatever we had managed to get. + if buf != nil { + t.bufs.free(buf) + } + return false + } + b.Next = buf + buf = b + } + + // Copy data into allocated buffers. + nBuf := buf + var dBuf []byte + for _, data := range [][]byte{a, b} { + for len(data) > 0 { + if len(dBuf) == 0 { + dBuf = t.data[nBuf.Offset:][:nBuf.Size] + nBuf = nBuf.Next + } + n := copy(dBuf, data) + data = data[n:] + dBuf = dBuf[n:] + } + } + + // Get an id for this packet and send it out. + id := t.ids.add(buf) + if !t.q.Enqueue(id, total, bufCount, buf) { + t.ids.remove(id) + t.bufs.free(buf) + return false + } + + return true +} + +// getBuffer returns a memory region mapped to the full contents of the given +// file descriptor. +func getBuffer(fd int) ([]byte, error) { + var s syscall.Stat_t + if err := syscall.Fstat(fd, &s); err != nil { + return nil, err + } + + // Check that size doesn't overflow an int. + if s.Size > int64(^uint(0)>>1) { + return nil, syscall.EDOM + } + + return syscall.Mmap(fd, 0, int(s.Size), syscall.PROT_READ|syscall.PROT_WRITE, syscall.MAP_SHARED|syscall.MAP_FILE) +} + +// idDescriptor is used by idManager to either point to a tx buffer (in case +// the ID is assigned) or to the next free element (if the id is not assigned). +type idDescriptor struct { + buf *queue.TxBuffer + nextFree uint64 +} + +// idManager is a manager of tx buffer identifiers. It assigns unique IDs to +// tx buffers that are added to it; the IDs can only be reused after they have +// been removed. +// +// The ID assignments are stored so that the tx buffers can be retrieved from +// the IDs previously assigned to them. +type idManager struct { + // ids is a slice containing all tx buffers. The ID is the index into + // this slice. + ids []idDescriptor + + // freeList a list of free IDs. + freeList uint64 +} + +// init initializes the id manager. +func (m *idManager) init() { + m.freeList = nilID +} + +// add assigns an ID to the given tx buffer. +func (m *idManager) add(b *queue.TxBuffer) uint64 { + if i := m.freeList; i != nilID { + // There is an id available in the free list, just use it. + m.ids[i].buf = b + m.freeList = m.ids[i].nextFree + return i + } + + // We need to expand the id descriptor. + m.ids = append(m.ids, idDescriptor{buf: b}) + return uint64(len(m.ids) - 1) +} + +// remove retrieves the tx buffer associated with the given ID, and removes the +// ID from the assigned table so that it can be reused in the future. +func (m *idManager) remove(i uint64) *queue.TxBuffer { + if i >= uint64(len(m.ids)) { + return nil + } + + desc := &m.ids[i] + b := desc.buf + if b == nil { + // The provided id is not currently assigned. + return nil + } + + desc.buf = nil + desc.nextFree = m.freeList + m.freeList = i + + return b +} + +// bufferManager manages a buffer region broken up into smaller, equally sized +// buffers. Smaller buffers can be allocated and freed. +type bufferManager struct { + freeList *queue.TxBuffer + curOffset uint64 + limit uint64 + entrySize uint32 +} + +// init initializes the buffer manager. +func (b *bufferManager) init(initialOffset, size, entrySize int) { + b.freeList = nil + b.curOffset = uint64(initialOffset) + b.limit = uint64(initialOffset + size/entrySize*entrySize) + b.entrySize = uint32(entrySize) +} + +// alloc allocates a buffer from the manager, if one is available. +func (b *bufferManager) alloc() *queue.TxBuffer { + if b.freeList != nil { + // There is a descriptor ready for reuse in the free list. + d := b.freeList + b.freeList = d.Next + d.Next = nil + return d + } + + if b.curOffset < b.limit { + // There is room available in the never-used range, so create + // a new descriptor for it. + d := &queue.TxBuffer{ + Offset: b.curOffset, + Size: b.entrySize, + } + b.curOffset += uint64(b.entrySize) + return d + } + + return nil +} + +// free returns all buffers in the list to the buffer manager so that they can +// be reused. +func (b *bufferManager) free(d *queue.TxBuffer) { + // Find the last buffer in the list. + last := d + for last.Next != nil { + last = last.Next + } + + // Push list onto free list. + last.Next = b.freeList + b.freeList = d +} diff --git a/pkg/tcpip/link/sniffer/BUILD b/pkg/tcpip/link/sniffer/BUILD new file mode 100644 index 000000000..a912707c2 --- /dev/null +++ b/pkg/tcpip/link/sniffer/BUILD @@ -0,0 +1,23 @@ +package(licenses = ["notice"]) # BSD + +load("@io_bazel_rules_go//go:def.bzl", "go_library") + +go_library( + name = "sniffer", + srcs = [ + "pcap.go", + "sniffer.go", + ], + importpath = "gvisor.googlesource.com/gvisor/pkg/tcpip/link/sniffer", + visibility = [ + "//visibility:public", + ], + deps = [ + "//pkg/log", + "//pkg/tcpip", + "//pkg/tcpip/buffer", + "//pkg/tcpip/header", + "//pkg/tcpip/link/rawfile", + "//pkg/tcpip/stack", + ], +) diff --git a/pkg/tcpip/link/sniffer/pcap.go b/pkg/tcpip/link/sniffer/pcap.go new file mode 100644 index 000000000..6d0dd565a --- /dev/null +++ b/pkg/tcpip/link/sniffer/pcap.go @@ -0,0 +1,52 @@ +package sniffer + +import "time" + +type pcapHeader struct { + // MagicNumber is the file magic number. + MagicNumber uint32 + + // VersionMajor is the major version number. + VersionMajor uint16 + + // VersionMinor is the minor version number. + VersionMinor uint16 + + // Thiszone is the GMT to local correction. + Thiszone int32 + + // Sigfigs is the accuracy of timestamps. + Sigfigs uint32 + + // Snaplen is the max length of captured packets, in octets. + Snaplen uint32 + + // Network is the data link type. + Network uint32 +} + +const pcapPacketHeaderLen = 16 + +type pcapPacketHeader struct { + // Seconds is the timestamp seconds. + Seconds uint32 + + // Microseconds is the timestamp microseconds. + Microseconds uint32 + + // IncludedLength is the number of octets of packet saved in file. + IncludedLength uint32 + + // OriginalLength is the actual length of packet. + OriginalLength uint32 +} + +func newPCAPPacketHeader(incLen, orgLen uint32) pcapPacketHeader { + now := time.Now() + return pcapPacketHeader{ + Seconds: uint32(now.Unix()), + Microseconds: uint32(now.Nanosecond() / 1000), + IncludedLength: incLen, + OriginalLength: orgLen, + } +} diff --git a/pkg/tcpip/link/sniffer/sniffer.go b/pkg/tcpip/link/sniffer/sniffer.go new file mode 100644 index 000000000..da6969e94 --- /dev/null +++ b/pkg/tcpip/link/sniffer/sniffer.go @@ -0,0 +1,310 @@ +// Copyright 2016 The Netstack Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package sniffer provides the implementation of data-link layer endpoints that +// wrap another endpoint and logs inbound and outbound packets. +// +// Sniffer endpoints can be used in the networking stack by calling New(eID) to +// create a new endpoint, where eID is the ID of the endpoint being wrapped, +// and then passing it as an argument to Stack.CreateNIC(). +package sniffer + +import ( + "bytes" + "encoding/binary" + "fmt" + "io" + "os" + "sync/atomic" + "time" + + "gvisor.googlesource.com/gvisor/pkg/log" + "gvisor.googlesource.com/gvisor/pkg/tcpip" + "gvisor.googlesource.com/gvisor/pkg/tcpip/buffer" + "gvisor.googlesource.com/gvisor/pkg/tcpip/header" + "gvisor.googlesource.com/gvisor/pkg/tcpip/link/rawfile" + "gvisor.googlesource.com/gvisor/pkg/tcpip/stack" +) + +// LogPackets is a flag used to enable or disable packet logging via the log +// package. Valid values are 0 or 1. +// +// LogPackets must be accessed atomically. +var LogPackets uint32 = 1 + +// LogPacketsToFile is a flag used to enable or disable logging packets to a +// pcap file. Valid values are 0 or 1. A file must have been specified when the +// sniffer was created for this flag to have effect. +// +// LogPacketsToFile must be accessed atomically. +var LogPacketsToFile uint32 = 1 + +type endpoint struct { + dispatcher stack.NetworkDispatcher + lower stack.LinkEndpoint + file *os.File + maxPCAPLen uint32 +} + +// New creates a new sniffer link-layer endpoint. It wraps around another +// endpoint and logs packets and they traverse the endpoint. +func New(lower tcpip.LinkEndpointID) tcpip.LinkEndpointID { + return stack.RegisterLinkEndpoint(&endpoint{ + lower: stack.FindLinkEndpoint(lower), + }) +} + +func zoneOffset() (int32, error) { + loc, err := time.LoadLocation("Local") + if err != nil { + return 0, err + } + date := time.Date(0, 0, 0, 0, 0, 0, 0, loc) + _, offset := date.Zone() + return int32(offset), nil +} + +func writePCAPHeader(w io.Writer, maxLen uint32) error { + offset, err := zoneOffset() + if err != nil { + return err + } + return binary.Write(w, binary.BigEndian, pcapHeader{ + // From https://wiki.wireshark.org/Development/LibpcapFileFormat + MagicNumber: 0xa1b2c3d4, + + VersionMajor: 2, + VersionMinor: 4, + Thiszone: offset, + Sigfigs: 0, + Snaplen: maxLen, + Network: 101, // LINKTYPE_RAW + }) +} + +// NewWithFile creates a new sniffer link-layer endpoint. It wraps around +// another endpoint and logs packets and they traverse the endpoint. +// +// Packets can be logged to file in the pcap format in addition to the standard +// human-readable logs. +// +// snapLen is the maximum amount of a packet to be saved. Packets with a length +// less than or equal too snapLen will be saved in their entirety. Longer +// packets will be truncated to snapLen. +func NewWithFile(lower tcpip.LinkEndpointID, file *os.File, snapLen uint32) (tcpip.LinkEndpointID, error) { + if err := writePCAPHeader(file, snapLen); err != nil { + return 0, err + } + return stack.RegisterLinkEndpoint(&endpoint{ + lower: stack.FindLinkEndpoint(lower), + file: file, + maxPCAPLen: snapLen, + }), nil +} + +// DeliverNetworkPacket implements the stack.NetworkDispatcher interface. It is +// called by the link-layer endpoint being wrapped when a packet arrives, and +// logs the packet before forwarding to the actual dispatcher. +func (e *endpoint) DeliverNetworkPacket(linkEP stack.LinkEndpoint, remoteLinkAddr tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, vv *buffer.VectorisedView) { + if atomic.LoadUint32(&LogPackets) == 1 { + LogPacket("recv", protocol, vv.First(), nil) + } + if e.file != nil && atomic.LoadUint32(&LogPacketsToFile) == 1 { + vs := vv.Views() + bs := make([][]byte, 1, 1+len(vs)) + var length int + for _, v := range vs { + if length+len(v) > int(e.maxPCAPLen) { + l := int(e.maxPCAPLen) - length + bs = append(bs, []byte(v)[:l]) + length += l + break + } + bs = append(bs, []byte(v)) + length += len(v) + } + buf := bytes.NewBuffer(make([]byte, 0, pcapPacketHeaderLen)) + binary.Write(buf, binary.BigEndian, newPCAPPacketHeader(uint32(length), uint32(vv.Size()))) + bs[0] = buf.Bytes() + if err := rawfile.NonBlockingWriteN(int(e.file.Fd()), bs...); err != nil { + panic(err) + } + } + e.dispatcher.DeliverNetworkPacket(e, remoteLinkAddr, protocol, vv) +} + +// Attach implements the stack.LinkEndpoint interface. It saves the dispatcher +// and registers with the lower endpoint as its dispatcher so that "e" is called +// for inbound packets. +func (e *endpoint) Attach(dispatcher stack.NetworkDispatcher) { + e.dispatcher = dispatcher + e.lower.Attach(e) +} + +// MTU implements stack.LinkEndpoint.MTU. It just forwards the request to the +// lower endpoint. +func (e *endpoint) MTU() uint32 { + return e.lower.MTU() +} + +// Capabilities implements stack.LinkEndpoint.Capabilities. It just forwards the +// request to the lower endpoint. +func (e *endpoint) Capabilities() stack.LinkEndpointCapabilities { + return e.lower.Capabilities() +} + +// MaxHeaderLength implements the stack.LinkEndpoint interface. It just forwards +// the request to the lower endpoint. +func (e *endpoint) MaxHeaderLength() uint16 { + return e.lower.MaxHeaderLength() +} + +func (e *endpoint) LinkAddress() tcpip.LinkAddress { + return e.lower.LinkAddress() +} + +// WritePacket implements the stack.LinkEndpoint interface. It is called by +// higher-level protocols to write packets; it just logs the packet and forwards +// the request to the lower endpoint. +func (e *endpoint) WritePacket(r *stack.Route, hdr *buffer.Prependable, payload buffer.View, protocol tcpip.NetworkProtocolNumber) *tcpip.Error { + if atomic.LoadUint32(&LogPackets) == 1 { + LogPacket("send", protocol, hdr.UsedBytes(), payload) + } + if e.file != nil && atomic.LoadUint32(&LogPacketsToFile) == 1 { + bs := [][]byte{nil, hdr.UsedBytes(), payload} + var length int + + for i, b := range bs[1:] { + if rem := int(e.maxPCAPLen) - length; len(b) > rem { + b = b[:rem] + } + bs[i+1] = b + length += len(b) + } + + buf := bytes.NewBuffer(make([]byte, 0, pcapPacketHeaderLen)) + binary.Write(buf, binary.BigEndian, newPCAPPacketHeader(uint32(length), uint32(hdr.UsedLength()+len(payload)))) + bs[0] = buf.Bytes() + if err := rawfile.NonBlockingWriteN(int(e.file.Fd()), bs...); err != nil { + panic(err) + } + } + return e.lower.WritePacket(r, hdr, payload, protocol) +} + +// LogPacket logs the given packet. +func LogPacket(prefix string, protocol tcpip.NetworkProtocolNumber, b, plb []byte) { + // Figure out the network layer info. + var transProto uint8 + src := tcpip.Address("unknown") + dst := tcpip.Address("unknown") + id := 0 + size := uint16(0) + switch protocol { + case header.IPv4ProtocolNumber: + ipv4 := header.IPv4(b) + src = ipv4.SourceAddress() + dst = ipv4.DestinationAddress() + transProto = ipv4.Protocol() + size = ipv4.TotalLength() - uint16(ipv4.HeaderLength()) + b = b[ipv4.HeaderLength():] + id = int(ipv4.ID()) + + case header.IPv6ProtocolNumber: + ipv6 := header.IPv6(b) + src = ipv6.SourceAddress() + dst = ipv6.DestinationAddress() + transProto = ipv6.NextHeader() + size = ipv6.PayloadLength() + b = b[header.IPv6MinimumSize:] + + case header.ARPProtocolNumber: + arp := header.ARP(b) + log.Infof( + "%s arp %v (%v) -> %v (%v) valid:%v", + prefix, + tcpip.Address(arp.ProtocolAddressSender()), tcpip.LinkAddress(arp.HardwareAddressSender()), + tcpip.Address(arp.ProtocolAddressTarget()), tcpip.LinkAddress(arp.HardwareAddressTarget()), + arp.IsValid(), + ) + return + default: + log.Infof("%s unknown network protocol", prefix) + return + } + + // Figure out the transport layer info. + transName := "unknown" + srcPort := uint16(0) + dstPort := uint16(0) + details := "" + switch tcpip.TransportProtocolNumber(transProto) { + case header.ICMPv4ProtocolNumber: + transName = "icmp" + icmp := header.ICMPv4(b) + icmpType := "unknown" + switch icmp.Type() { + case header.ICMPv4EchoReply: + icmpType = "echo reply" + case header.ICMPv4DstUnreachable: + icmpType = "destination unreachable" + case header.ICMPv4SrcQuench: + icmpType = "source quench" + case header.ICMPv4Redirect: + icmpType = "redirect" + case header.ICMPv4Echo: + icmpType = "echo" + case header.ICMPv4TimeExceeded: + icmpType = "time exceeded" + case header.ICMPv4ParamProblem: + icmpType = "param problem" + case header.ICMPv4Timestamp: + icmpType = "timestamp" + case header.ICMPv4TimestampReply: + icmpType = "timestamp reply" + case header.ICMPv4InfoRequest: + icmpType = "info request" + case header.ICMPv4InfoReply: + icmpType = "info reply" + } + log.Infof("%s %s %v -> %v %s len:%d id:%04x code:%d", prefix, transName, src, dst, icmpType, size, id, icmp.Code()) + return + + case header.UDPProtocolNumber: + transName = "udp" + udp := header.UDP(b) + srcPort = udp.SourcePort() + dstPort = udp.DestinationPort() + size -= header.UDPMinimumSize + + details = fmt.Sprintf("xsum: 0x%x", udp.Checksum()) + + case header.TCPProtocolNumber: + transName = "tcp" + tcp := header.TCP(b) + srcPort = tcp.SourcePort() + dstPort = tcp.DestinationPort() + size -= uint16(tcp.DataOffset()) + + // Initialize the TCP flags. + flags := tcp.Flags() + flagsStr := []byte("FSRPAU") + for i := range flagsStr { + if flags&(1< %v unknown transport protocol: %d", prefix, src, dst, transProto) + return + } + + log.Infof("%s %s %v:%v -> %v:%v len:%d id:%04x %s", prefix, transName, src, srcPort, dst, dstPort, size, id, details) +} diff --git a/pkg/tcpip/link/tun/BUILD b/pkg/tcpip/link/tun/BUILD new file mode 100644 index 000000000..d627f00f1 --- /dev/null +++ b/pkg/tcpip/link/tun/BUILD @@ -0,0 +1,12 @@ +package(licenses = ["notice"]) # BSD + +load("@io_bazel_rules_go//go:def.bzl", "go_library") + +go_library( + name = "tun", + srcs = ["tun_unsafe.go"], + importpath = "gvisor.googlesource.com/gvisor/pkg/tcpip/link/tun", + visibility = [ + "//visibility:public", + ], +) diff --git a/pkg/tcpip/link/tun/tun_unsafe.go b/pkg/tcpip/link/tun/tun_unsafe.go new file mode 100644 index 000000000..5b6c9b4ab --- /dev/null +++ b/pkg/tcpip/link/tun/tun_unsafe.go @@ -0,0 +1,50 @@ +// Copyright 2016 The Netstack Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package tun + +import ( + "syscall" + "unsafe" +) + +// Open opens the specified TUN device, sets it to non-blocking mode, and +// returns its file descriptor. +func Open(name string) (int, error) { + return open(name, syscall.IFF_TUN|syscall.IFF_NO_PI) +} + +// OpenTAP opens the specified TAP device, sets it to non-blocking mode, and +// returns its file descriptor. +func OpenTAP(name string) (int, error) { + return open(name, syscall.IFF_TAP|syscall.IFF_NO_PI) +} + +func open(name string, flags uint16) (int, error) { + fd, err := syscall.Open("/dev/net/tun", syscall.O_RDWR, 0) + if err != nil { + return -1, err + } + + var ifr struct { + name [16]byte + flags uint16 + _ [22]byte + } + + copy(ifr.name[:], name) + ifr.flags = flags + _, _, errno := syscall.Syscall(syscall.SYS_IOCTL, uintptr(fd), syscall.TUNSETIFF, uintptr(unsafe.Pointer(&ifr))) + if errno != 0 { + syscall.Close(fd) + return -1, errno + } + + if err = syscall.SetNonblock(fd, true); err != nil { + syscall.Close(fd) + return -1, err + } + + return fd, nil +} diff --git a/pkg/tcpip/link/waitable/BUILD b/pkg/tcpip/link/waitable/BUILD new file mode 100644 index 000000000..63b648be7 --- /dev/null +++ b/pkg/tcpip/link/waitable/BUILD @@ -0,0 +1,33 @@ +package(licenses = ["notice"]) # BSD + +load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test") + +go_library( + name = "waitable", + srcs = [ + "waitable.go", + ], + importpath = "gvisor.googlesource.com/gvisor/pkg/tcpip/link/waitable", + visibility = [ + "//visibility:public", + ], + deps = [ + "//pkg/gate", + "//pkg/tcpip", + "//pkg/tcpip/buffer", + "//pkg/tcpip/stack", + ], +) + +go_test( + name = "waitable_test", + srcs = [ + "waitable_test.go", + ], + embed = [":waitable"], + deps = [ + "//pkg/tcpip", + "//pkg/tcpip/buffer", + "//pkg/tcpip/stack", + ], +) diff --git a/pkg/tcpip/link/waitable/waitable.go b/pkg/tcpip/link/waitable/waitable.go new file mode 100644 index 000000000..2c6e73f22 --- /dev/null +++ b/pkg/tcpip/link/waitable/waitable.go @@ -0,0 +1,108 @@ +// Copyright 2017 The Netstack Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package waitable provides the implementation of data-link layer endpoints +// that wrap other endpoints, and can wait for inflight calls to WritePacket or +// DeliverNetworkPacket to finish (and new ones to be prevented). +// +// Waitable endpoints can be used in the networking stack by calling New(eID) to +// create a new endpoint, where eID is the ID of the endpoint being wrapped, +// and then passing it as an argument to Stack.CreateNIC(). +package waitable + +import ( + "gvisor.googlesource.com/gvisor/pkg/gate" + "gvisor.googlesource.com/gvisor/pkg/tcpip" + "gvisor.googlesource.com/gvisor/pkg/tcpip/buffer" + "gvisor.googlesource.com/gvisor/pkg/tcpip/stack" +) + +// Endpoint is a waitable link-layer endpoint. +type Endpoint struct { + dispatchGate gate.Gate + dispatcher stack.NetworkDispatcher + + writeGate gate.Gate + lower stack.LinkEndpoint +} + +// New creates a new waitable link-layer endpoint. It wraps around another +// endpoint and allows the caller to block new write/dispatch calls and wait for +// the inflight ones to finish before returning. +func New(lower tcpip.LinkEndpointID) (tcpip.LinkEndpointID, *Endpoint) { + e := &Endpoint{ + lower: stack.FindLinkEndpoint(lower), + } + return stack.RegisterLinkEndpoint(e), e +} + +// DeliverNetworkPacket implements stack.NetworkDispatcher.DeliverNetworkPacket. +// It is called by the link-layer endpoint being wrapped when a packet arrives, +// and only forwards to the actual dispatcher if Wait or WaitDispatch haven't +// been called. +func (e *Endpoint) DeliverNetworkPacket(linkEP stack.LinkEndpoint, remoteLinkAddr tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, vv *buffer.VectorisedView) { + if !e.dispatchGate.Enter() { + return + } + + e.dispatcher.DeliverNetworkPacket(e, remoteLinkAddr, protocol, vv) + e.dispatchGate.Leave() +} + +// Attach implements stack.LinkEndpoint.Attach. It saves the dispatcher and +// registers with the lower endpoint as its dispatcher so that "e" is called +// for inbound packets. +func (e *Endpoint) Attach(dispatcher stack.NetworkDispatcher) { + e.dispatcher = dispatcher + e.lower.Attach(e) +} + +// MTU implements stack.LinkEndpoint.MTU. It just forwards the request to the +// lower endpoint. +func (e *Endpoint) MTU() uint32 { + return e.lower.MTU() +} + +// Capabilities implements stack.LinkEndpoint.Capabilities. It just forwards the +// request to the lower endpoint. +func (e *Endpoint) Capabilities() stack.LinkEndpointCapabilities { + return e.lower.Capabilities() +} + +// MaxHeaderLength implements stack.LinkEndpoint.MaxHeaderLength. It just +// forwards the request to the lower endpoint. +func (e *Endpoint) MaxHeaderLength() uint16 { + return e.lower.MaxHeaderLength() +} + +// LinkAddress implements stack.LinkEndpoint.LinkAddress. It just forwards the +// request to the lower endpoint. +func (e *Endpoint) LinkAddress() tcpip.LinkAddress { + return e.lower.LinkAddress() +} + +// WritePacket implements stack.LinkEndpoint.WritePacket. It is called by +// higher-level protocols to write packets. It only forwards packets to the +// lower endpoint if Wait or WaitWrite haven't been called. +func (e *Endpoint) WritePacket(r *stack.Route, hdr *buffer.Prependable, payload buffer.View, protocol tcpip.NetworkProtocolNumber) *tcpip.Error { + if !e.writeGate.Enter() { + return nil + } + + err := e.lower.WritePacket(r, hdr, payload, protocol) + e.writeGate.Leave() + return err +} + +// WaitWrite prevents new calls to WritePacket from reaching the lower endpoint, +// and waits for inflight ones to finish before returning. +func (e *Endpoint) WaitWrite() { + e.writeGate.Close() +} + +// WaitDispatch prevents new calls to DeliverNetworkPacket from reaching the +// actual dispatcher, and waits for inflight ones to finish before returning. +func (e *Endpoint) WaitDispatch() { + e.dispatchGate.Close() +} diff --git a/pkg/tcpip/link/waitable/waitable_test.go b/pkg/tcpip/link/waitable/waitable_test.go new file mode 100644 index 000000000..cb433dc19 --- /dev/null +++ b/pkg/tcpip/link/waitable/waitable_test.go @@ -0,0 +1,144 @@ +// Copyright 2017 The Netstack Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package waitable + +import ( + "testing" + + "gvisor.googlesource.com/gvisor/pkg/tcpip" + "gvisor.googlesource.com/gvisor/pkg/tcpip/buffer" + "gvisor.googlesource.com/gvisor/pkg/tcpip/stack" +) + +type countedEndpoint struct { + dispatchCount int + writeCount int + attachCount int + + mtu uint32 + capabilities stack.LinkEndpointCapabilities + hdrLen uint16 + linkAddr tcpip.LinkAddress + + dispatcher stack.NetworkDispatcher +} + +func (e *countedEndpoint) DeliverNetworkPacket(linkEP stack.LinkEndpoint, remoteLinkAddr tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, vv *buffer.VectorisedView) { + e.dispatchCount++ +} + +func (e *countedEndpoint) Attach(dispatcher stack.NetworkDispatcher) { + e.attachCount++ + e.dispatcher = dispatcher +} + +func (e *countedEndpoint) MTU() uint32 { + return e.mtu +} + +func (e *countedEndpoint) Capabilities() stack.LinkEndpointCapabilities { + return e.capabilities +} + +func (e *countedEndpoint) MaxHeaderLength() uint16 { + return e.hdrLen +} + +func (e *countedEndpoint) LinkAddress() tcpip.LinkAddress { + return e.linkAddr +} + +func (e *countedEndpoint) WritePacket(r *stack.Route, hdr *buffer.Prependable, payload buffer.View, protocol tcpip.NetworkProtocolNumber) *tcpip.Error { + e.writeCount++ + return nil +} + +func TestWaitWrite(t *testing.T) { + ep := &countedEndpoint{} + _, wep := New(stack.RegisterLinkEndpoint(ep)) + + // Write and check that it goes through. + wep.WritePacket(nil, nil, nil, 0) + if want := 1; ep.writeCount != want { + t.Fatalf("Unexpected writeCount: got=%v, want=%v", ep.writeCount, want) + } + + // Wait on dispatches, then try to write. It must go through. + wep.WaitDispatch() + wep.WritePacket(nil, nil, nil, 0) + if want := 2; ep.writeCount != want { + t.Fatalf("Unexpected writeCount: got=%v, want=%v", ep.writeCount, want) + } + + // Wait on writes, then try to write. It must not go through. + wep.WaitWrite() + wep.WritePacket(nil, nil, nil, 0) + if want := 2; ep.writeCount != want { + t.Fatalf("Unexpected writeCount: got=%v, want=%v", ep.writeCount, want) + } +} + +func TestWaitDispatch(t *testing.T) { + ep := &countedEndpoint{} + _, wep := New(stack.RegisterLinkEndpoint(ep)) + + // Check that attach happens. + wep.Attach(ep) + if want := 1; ep.attachCount != want { + t.Fatalf("Unexpected attachCount: got=%v, want=%v", ep.attachCount, want) + } + + // Dispatch and check that it goes through. + ep.dispatcher.DeliverNetworkPacket(ep, "", 0, nil) + if want := 1; ep.dispatchCount != want { + t.Fatalf("Unexpected dispatchCount: got=%v, want=%v", ep.dispatchCount, want) + } + + // Wait on writes, then try to dispatch. It must go through. + wep.WaitWrite() + ep.dispatcher.DeliverNetworkPacket(ep, "", 0, nil) + if want := 2; ep.dispatchCount != want { + t.Fatalf("Unexpected dispatchCount: got=%v, want=%v", ep.dispatchCount, want) + } + + // Wait on dispatches, then try to dispatch. It must not go through. + wep.WaitDispatch() + ep.dispatcher.DeliverNetworkPacket(ep, "", 0, nil) + if want := 2; ep.dispatchCount != want { + t.Fatalf("Unexpected dispatchCount: got=%v, want=%v", ep.dispatchCount, want) + } +} + +func TestOtherMethods(t *testing.T) { + const ( + mtu = 0xdead + capabilities = 0xbeef + hdrLen = 0x1234 + linkAddr = "test address" + ) + ep := &countedEndpoint{ + mtu: mtu, + capabilities: capabilities, + hdrLen: hdrLen, + linkAddr: linkAddr, + } + _, wep := New(stack.RegisterLinkEndpoint(ep)) + + if v := wep.MTU(); v != mtu { + t.Fatalf("Unexpected mtu: got=%v, want=%v", v, mtu) + } + + if v := wep.Capabilities(); v != capabilities { + t.Fatalf("Unexpected capabilities: got=%v, want=%v", v, capabilities) + } + + if v := wep.MaxHeaderLength(); v != hdrLen { + t.Fatalf("Unexpected MaxHeaderLength: got=%v, want=%v", v, hdrLen) + } + + if v := wep.LinkAddress(); v != linkAddr { + t.Fatalf("Unexpected LinkAddress: got=%q, want=%q", v, linkAddr) + } +} -- cgit v1.2.3