summaryrefslogtreecommitdiffhomepage
path: root/pkg/tcpip/link/fdbased
diff options
context:
space:
mode:
author: Ting-Yu Wang <anivia@google.com> 2020-08-13 13:07:03 -0700
committer: gVisor bot <gvisor-bot@google.com> 2020-08-13 13:08:57 -0700
commit: 47515f475167ffa23267ca0b9d1b39e7907587d6 (patch)
tree: 595ed3020846d93746778d9ac2ca5121f9e880d1 /pkg/tcpip/link/fdbased
parent: b928d074b461c6f2578c989e48adadc951ed3154 (diff)
Migrate to PacketHeader API for PacketBuffer.
Formerly, when a packet was constructed or parsed, all headers were set by the client code. This almost always involved prepending to the pk.Header buffer or trimming the pk.Data portion. This is known to be prone to bugs, due to the complexity and number of invariants that must be maintained across netstack.

In the new PacketHeader API, clients call the Push()/Consume() methods to construct/parse an outgoing/incoming packet. All invariants, such as slicing and trimming, are maintained by the API itself.

NewPacketBuffer() is introduced to create a new PacketBuffer. The zero value is no longer valid.

PacketBuffer now assumes the packet is a concatenation of the following portions:
* LinkHeader
* NetworkHeader
* TransportHeader
* Data

Any of them could be empty, or zero-length.

PiperOrigin-RevId: 326507688
Diffstat (limited to 'pkg/tcpip/link/fdbased')
-rw-r--r-- pkg/tcpip/link/fdbased/BUILD | 1
-rw-r--r-- pkg/tcpip/link/fdbased/endpoint.go | 14
-rw-r--r-- pkg/tcpip/link/fdbased/endpoint_test.go | 134
-rw-r--r-- pkg/tcpip/link/fdbased/mmap.go | 16
-rw-r--r-- pkg/tcpip/link/fdbased/packet_dispatchers.go | 45
5 files changed, 117 insertions, 93 deletions
diff --git a/pkg/tcpip/link/fdbased/BUILD b/pkg/tcpip/link/fdbased/BUILD
index 507b44abc..10072eac1 100644
--- a/pkg/tcpip/link/fdbased/BUILD
+++ b/pkg/tcpip/link/fdbased/BUILD
@@ -37,5 +37,6 @@ go_test(
"//pkg/tcpip/header",
"//pkg/tcpip/link/rawfile",
"//pkg/tcpip/stack",
+ "@com_github_google_go_cmp//cmp:go_default_library",
],
)
diff --git a/pkg/tcpip/link/fdbased/endpoint.go b/pkg/tcpip/link/fdbased/endpoint.go
index c18bb91fb..975309fc8 100644
--- a/pkg/tcpip/link/fdbased/endpoint.go
+++ b/pkg/tcpip/link/fdbased/endpoint.go
@@ -390,8 +390,7 @@ const (
func (e *endpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
if e.hdrSize > 0 {
// Add ethernet header if needed.
- eth := header.Ethernet(pkt.Header.Prepend(header.EthernetMinimumSize))
- pkt.LinkHeader = buffer.View(eth)
+ eth := header.Ethernet(pkt.LinkHeader().Push(header.EthernetMinimumSize))
ethHdr := &header.EthernetFields{
DstAddr: remote,
Type: protocol,
@@ -420,7 +419,7 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.Ne
if e.Capabilities()&stack.CapabilityHardwareGSO != 0 {
vnetHdr := virtioNetHdr{}
if gso != nil {
- vnetHdr.hdrLen = uint16(pkt.Header.UsedLength())
+ vnetHdr.hdrLen = uint16(pkt.HeaderSize())
if gso.NeedsCsum {
vnetHdr.flags = _VIRTIO_NET_HDR_F_NEEDS_CSUM
vnetHdr.csumStart = header.EthernetMinimumSize + gso.L3HdrLen
@@ -443,11 +442,9 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.Ne
builder.Add(vnetHdrBuf)
}
- builder.Add(pkt.Header.View())
- for _, v := range pkt.Data.Views() {
+ for _, v := range pkt.Views() {
builder.Add(v)
}
-
return rawfile.NonBlockingWriteIovec(fd, builder.Build())
}
@@ -463,7 +460,7 @@ func (e *endpoint) sendBatch(batchFD int, batch []*stack.PacketBuffer) (int, *tc
if e.Capabilities()&stack.CapabilityHardwareGSO != 0 {
vnetHdr := virtioNetHdr{}
if pkt.GSOOptions != nil {
- vnetHdr.hdrLen = uint16(pkt.Header.UsedLength())
+ vnetHdr.hdrLen = uint16(pkt.HeaderSize())
if pkt.GSOOptions.NeedsCsum {
vnetHdr.flags = _VIRTIO_NET_HDR_F_NEEDS_CSUM
vnetHdr.csumStart = header.EthernetMinimumSize + pkt.GSOOptions.L3HdrLen
@@ -486,8 +483,7 @@ func (e *endpoint) sendBatch(batchFD int, batch []*stack.PacketBuffer) (int, *tc
var builder iovec.Builder
builder.Add(vnetHdrBuf)
- builder.Add(pkt.Header.View())
- for _, v := range pkt.Data.Views() {
+ for _, v := range pkt.Views() {
builder.Add(v)
}
iovecs := builder.Build()
diff --git a/pkg/tcpip/link/fdbased/endpoint_test.go b/pkg/tcpip/link/fdbased/endpoint_test.go
index 7b995b85a..709f829c8 100644
--- a/pkg/tcpip/link/fdbased/endpoint_test.go
+++ b/pkg/tcpip/link/fdbased/endpoint_test.go
@@ -26,6 +26,7 @@ import (
"time"
"unsafe"
+ "github.com/google/go-cmp/cmp"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
@@ -43,9 +44,36 @@ const (
)
type packetInfo struct {
- raddr tcpip.LinkAddress
- proto tcpip.NetworkProtocolNumber
- contents *stack.PacketBuffer
+ Raddr tcpip.LinkAddress
+ Proto tcpip.NetworkProtocolNumber
+ Contents *stack.PacketBuffer
+}
+
+type packetContents struct {
+ LinkHeader buffer.View
+ NetworkHeader buffer.View
+ TransportHeader buffer.View
+ Data buffer.View
+}
+
+func checkPacketInfoEqual(t *testing.T, got, want packetInfo) {
+ t.Helper()
+ if diff := cmp.Diff(
+ want, got,
+ cmp.Transformer("ExtractPacketBuffer", func(pk *stack.PacketBuffer) *packetContents {
+ if pk == nil {
+ return nil
+ }
+ return &packetContents{
+ LinkHeader: pk.LinkHeader().View(),
+ NetworkHeader: pk.NetworkHeader().View(),
+ TransportHeader: pk.TransportHeader().View(),
+ Data: pk.Data.ToView(),
+ }
+ }),
+ ); diff != "" {
+ t.Errorf("unexpected packetInfo (-want +got):\n%s", diff)
+ }
}
type context struct {
@@ -159,19 +187,28 @@ func testWritePacket(t *testing.T, plen int, eth bool, gsoMaxSize uint32, hash u
RemoteLinkAddress: raddr,
}
- // Build header.
- hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength()) + 100)
- b := hdr.Prepend(100)
- for i := range b {
- b[i] = uint8(rand.Intn(256))
+ // Build payload.
+ payload := buffer.NewView(plen)
+ if _, err := rand.Read(payload); err != nil {
+ t.Fatalf("rand.Read(payload): %s", err)
}
- // Build payload and write.
- payload := make(buffer.View, plen)
- for i := range payload {
- payload[i] = uint8(rand.Intn(256))
+ // Build packet buffer.
+ const netHdrLen = 100
+ pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
+ ReserveHeaderBytes: int(c.ep.MaxHeaderLength()) + netHdrLen,
+ Data: payload.ToVectorisedView(),
+ })
+ pkt.Hash = hash
+
+ // Build header.
+ b := pkt.NetworkHeader().Push(netHdrLen)
+ if _, err := rand.Read(b); err != nil {
+ t.Fatalf("rand.Read(b): %s", err)
}
- want := append(hdr.View(), payload...)
+
+ // Write.
+ want := append(append(buffer.View(nil), b...), payload...)
var gso *stack.GSO
if gsoMaxSize != 0 {
gso = &stack.GSO{
@@ -183,11 +220,7 @@ func testWritePacket(t *testing.T, plen int, eth bool, gsoMaxSize uint32, hash u
L3HdrLen: header.IPv4MaximumHeaderSize,
}
}
- if err := c.ep.WritePacket(r, gso, proto, &stack.PacketBuffer{
- Header: hdr,
- Data: payload.ToVectorisedView(),
- Hash: hash,
- }); err != nil {
+ if err := c.ep.WritePacket(r, gso, proto, pkt); err != nil {
t.Fatalf("WritePacket failed: %v", err)
}
@@ -296,13 +329,14 @@ func TestPreserveSrcAddress(t *testing.T) {
LocalLinkAddress: baddr,
}
- // WritePacket panics given a prependable with anything less than
- // the minimum size of the ethernet header.
- hdr := buffer.NewPrependable(header.EthernetMinimumSize)
- if err := c.ep.WritePacket(r, nil /* gso */, proto, &stack.PacketBuffer{
- Header: hdr,
- Data: buffer.VectorisedView{},
- }); err != nil {
+ pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
+ // WritePacket panics given a prependable with anything less than
+ // the minimum size of the ethernet header.
+ // TODO(b/153685824): Figure out if this should use c.ep.MaxHeaderLength().
+ ReserveHeaderBytes: header.EthernetMinimumSize,
+ Data: buffer.VectorisedView{},
+ })
+ if err := c.ep.WritePacket(r, nil /* gso */, proto, pkt); err != nil {
t.Fatalf("WritePacket failed: %v", err)
}
@@ -331,24 +365,25 @@ func TestDeliverPacket(t *testing.T) {
defer c.cleanup()
// Build packet.
- b := make([]byte, plen)
- all := b
- for i := range b {
- b[i] = uint8(rand.Intn(256))
+ all := make([]byte, plen)
+ if _, err := rand.Read(all); err != nil {
+ t.Fatalf("rand.Read(all): %s", err)
}
-
- var hdr header.Ethernet
- if !eth {
- // So that it looks like an IPv4 packet.
- b[0] = 0x40
- } else {
- hdr = make(header.Ethernet, header.EthernetMinimumSize)
+ // Make it look like an IPv4 packet.
+ all[0] = 0x40
+
+ wantPkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
+ ReserveHeaderBytes: header.EthernetMinimumSize,
+ Data: buffer.NewViewFromBytes(all).ToVectorisedView(),
+ })
+ if eth {
+ hdr := header.Ethernet(wantPkt.LinkHeader().Push(header.EthernetMinimumSize))
hdr.Encode(&header.EthernetFields{
SrcAddr: raddr,
DstAddr: laddr,
Type: proto,
})
- all = append(hdr, b...)
+ all = append(hdr, all...)
}
// Write packet via the file descriptor.
@@ -360,24 +395,15 @@ func TestDeliverPacket(t *testing.T) {
select {
case pi := <-c.ch:
want := packetInfo{
- raddr: raddr,
- proto: proto,
- contents: &stack.PacketBuffer{
- Data: buffer.View(b).ToVectorisedView(),
- LinkHeader: buffer.View(hdr),
- },
+ Raddr: raddr,
+ Proto: proto,
+ Contents: wantPkt,
}
if !eth {
- want.proto = header.IPv4ProtocolNumber
- want.raddr = ""
- }
- // want.contents.Data will be a single
- // view, so make pi do the same for the
- // DeepEqual check.
- pi.contents.Data = pi.contents.Data.ToView().ToVectorisedView()
- if !reflect.DeepEqual(want, pi) {
- t.Fatalf("Unexpected received packet: %+v, want %+v", pi, want)
+ want.Proto = header.IPv4ProtocolNumber
+ want.Raddr = ""
}
+ checkPacketInfoEqual(t, pi, want)
case <-time.After(10 * time.Second):
t.Fatalf("Timed out waiting for packet")
}
@@ -572,8 +598,8 @@ func TestDispatchPacketFormat(t *testing.T) {
t.Fatalf("len(sink.pkts) = %d, want %d", got, want)
}
pkt := sink.pkts[0]
- if got, want := len(pkt.LinkHeader), header.EthernetMinimumSize; got != want {
- t.Errorf("len(pkt.LinkHeader) = %d, want %d", got, want)
+ if got, want := pkt.LinkHeader().View().Size(), header.EthernetMinimumSize; got != want {
+ t.Errorf("pkt.LinkHeader().View().Size() = %d, want %d", got, want)
}
if got, want := pkt.Data.Size(), 4; got != want {
t.Errorf("pkt.Data.Size() = %d, want %d", got, want)
diff --git a/pkg/tcpip/link/fdbased/mmap.go b/pkg/tcpip/link/fdbased/mmap.go
index 2dfd29aa9..c475dda20 100644
--- a/pkg/tcpip/link/fdbased/mmap.go
+++ b/pkg/tcpip/link/fdbased/mmap.go
@@ -18,6 +18,7 @@ package fdbased
import (
"encoding/binary"
+ "fmt"
"syscall"
"golang.org/x/sys/unix"
@@ -170,10 +171,9 @@ func (d *packetMMapDispatcher) dispatch() (bool, *tcpip.Error) {
var (
p tcpip.NetworkProtocolNumber
remote, local tcpip.LinkAddress
- eth header.Ethernet
)
if d.e.hdrSize > 0 {
- eth = header.Ethernet(pkt)
+ eth := header.Ethernet(pkt)
p = eth.Type()
remote = eth.SourceAddress()
local = eth.DestinationAddress()
@@ -190,10 +190,14 @@ func (d *packetMMapDispatcher) dispatch() (bool, *tcpip.Error) {
}
}
- pkt = pkt[d.e.hdrSize:]
- d.e.dispatcher.DeliverNetworkPacket(remote, local, p, &stack.PacketBuffer{
- Data: buffer.View(pkt).ToVectorisedView(),
- LinkHeader: buffer.View(eth),
+ pbuf := stack.NewPacketBuffer(stack.PacketBufferOptions{
+ Data: buffer.View(pkt).ToVectorisedView(),
})
+ if d.e.hdrSize > 0 {
+ if _, ok := pbuf.LinkHeader().Consume(d.e.hdrSize); !ok {
+ panic(fmt.Sprintf("LinkHeader().Consume(%d) must succeed", d.e.hdrSize))
+ }
+ }
+ d.e.dispatcher.DeliverNetworkPacket(remote, local, p, pbuf)
return true, nil
}
diff --git a/pkg/tcpip/link/fdbased/packet_dispatchers.go b/pkg/tcpip/link/fdbased/packet_dispatchers.go
index d8f2504b3..8c3ca86d6 100644
--- a/pkg/tcpip/link/fdbased/packet_dispatchers.go
+++ b/pkg/tcpip/link/fdbased/packet_dispatchers.go
@@ -103,7 +103,7 @@ func (d *readVDispatcher) dispatch() (bool, *tcpip.Error) {
d.allocateViews(BufConfig)
n, err := rawfile.BlockingReadv(d.fd, d.iovecs)
- if err != nil {
+ if n == 0 || err != nil {
return false, err
}
if d.e.Capabilities()&stack.CapabilityHardwareGSO != 0 {
@@ -111,17 +111,22 @@ func (d *readVDispatcher) dispatch() (bool, *tcpip.Error) {
// isn't used and it isn't in a view.
n -= virtioNetHdrSize
}
- if n <= d.e.hdrSize {
- return false, nil
- }
+
+ used := d.capViews(n, BufConfig)
+ pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
+ Data: buffer.NewVectorisedView(n, append([]buffer.View(nil), d.views[:used]...)),
+ })
var (
p tcpip.NetworkProtocolNumber
remote, local tcpip.LinkAddress
- eth header.Ethernet
)
if d.e.hdrSize > 0 {
- eth = header.Ethernet(d.views[0][:header.EthernetMinimumSize])
+ hdr, ok := pkt.LinkHeader().Consume(d.e.hdrSize)
+ if !ok {
+ return false, nil
+ }
+ eth := header.Ethernet(hdr)
p = eth.Type()
remote = eth.SourceAddress()
local = eth.DestinationAddress()
@@ -138,13 +143,6 @@ func (d *readVDispatcher) dispatch() (bool, *tcpip.Error) {
}
}
- used := d.capViews(n, BufConfig)
- pkt := &stack.PacketBuffer{
- Data: buffer.NewVectorisedView(n, append([]buffer.View(nil), d.views[:used]...)),
- LinkHeader: buffer.View(eth),
- }
- pkt.Data.TrimFront(d.e.hdrSize)
-
d.e.dispatcher.DeliverNetworkPacket(remote, local, p, pkt)
// Prepare e.views for another packet: release used views.
@@ -268,17 +266,22 @@ func (d *recvMMsgDispatcher) dispatch() (bool, *tcpip.Error) {
if d.e.Capabilities()&stack.CapabilityHardwareGSO != 0 {
n -= virtioNetHdrSize
}
- if n <= d.e.hdrSize {
- return false, nil
- }
+
+ used := d.capViews(k, int(n), BufConfig)
+ pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
+ Data: buffer.NewVectorisedView(int(n), append([]buffer.View(nil), d.views[k][:used]...)),
+ })
var (
p tcpip.NetworkProtocolNumber
remote, local tcpip.LinkAddress
- eth header.Ethernet
)
if d.e.hdrSize > 0 {
- eth = header.Ethernet(d.views[k][0][:header.EthernetMinimumSize])
+ hdr, ok := pkt.LinkHeader().Consume(d.e.hdrSize)
+ if !ok {
+ return false, nil
+ }
+ eth := header.Ethernet(hdr)
p = eth.Type()
remote = eth.SourceAddress()
local = eth.DestinationAddress()
@@ -295,12 +298,6 @@ func (d *recvMMsgDispatcher) dispatch() (bool, *tcpip.Error) {
}
}
- used := d.capViews(k, int(n), BufConfig)
- pkt := &stack.PacketBuffer{
- Data: buffer.NewVectorisedView(int(n), append([]buffer.View(nil), d.views[k][:used]...)),
- LinkHeader: buffer.View(eth),
- }
- pkt.Data.TrimFront(d.e.hdrSize)
d.e.dispatcher.DeliverNetworkPacket(remote, local, p, pkt)
// Prepare e.views for another packet: release used views.