summaryrefslogtreecommitdiffhomepage
path: root/pkg/tcpip/stack
diff options
context:
space:
mode:
Diffstat (limited to 'pkg/tcpip/stack')
-rw-r--r--pkg/tcpip/stack/BUILD2
-rw-r--r--pkg/tcpip/stack/conntrack.go31
-rw-r--r--pkg/tcpip/stack/forwarder_test.go65
-rw-r--r--pkg/tcpip/stack/headertype_string.go39
-rw-r--r--pkg/tcpip/stack/iptables.go2
-rw-r--r--pkg/tcpip/stack/iptables_targets.go9
-rw-r--r--pkg/tcpip/stack/ndp.go32
-rw-r--r--pkg/tcpip/stack/ndp_test.go22
-rw-r--r--pkg/tcpip/stack/nic.go40
-rw-r--r--pkg/tcpip/stack/nic_test.go4
-rw-r--r--pkg/tcpip/stack/packet_buffer.go270
-rw-r--r--pkg/tcpip/stack/packet_buffer_test.go397
-rw-r--r--pkg/tcpip/stack/registration.go4
-rw-r--r--pkg/tcpip/stack/route.go5
-rw-r--r--pkg/tcpip/stack/stack_test.go61
-rw-r--r--pkg/tcpip/stack/transport_demuxer_test.go14
-rw-r--r--pkg/tcpip/stack/transport_test.go54
17 files changed, 835 insertions, 216 deletions
diff --git a/pkg/tcpip/stack/BUILD b/pkg/tcpip/stack/BUILD
index bfc7a0c7c..900938dd1 100644
--- a/pkg/tcpip/stack/BUILD
+++ b/pkg/tcpip/stack/BUILD
@@ -57,6 +57,7 @@ go_library(
"conntrack.go",
"dhcpv6configurationfromndpra_string.go",
"forwarder.go",
+ "headertype_string.go",
"icmp_rate_limit.go",
"iptables.go",
"iptables_state.go",
@@ -143,6 +144,7 @@ go_test(
"neighbor_cache_test.go",
"neighbor_entry_test.go",
"nic_test.go",
+ "packet_buffer_test.go",
],
library = ":stack",
deps = [
diff --git a/pkg/tcpip/stack/conntrack.go b/pkg/tcpip/stack/conntrack.go
index 470c265aa..7dd344b4f 100644
--- a/pkg/tcpip/stack/conntrack.go
+++ b/pkg/tcpip/stack/conntrack.go
@@ -199,12 +199,12 @@ type bucket struct {
func packetToTupleID(pkt *PacketBuffer) (tupleID, *tcpip.Error) {
// TODO(gvisor.dev/issue/170): Need to support for other
// protocols as well.
- netHeader := header.IPv4(pkt.NetworkHeader)
- if netHeader == nil || netHeader.TransportProtocol() != header.TCPProtocolNumber {
+ netHeader := header.IPv4(pkt.NetworkHeader().View())
+ if len(netHeader) < header.IPv4MinimumSize || netHeader.TransportProtocol() != header.TCPProtocolNumber {
return tupleID{}, tcpip.ErrUnknownProtocol
}
- tcpHeader := header.TCP(pkt.TransportHeader)
- if tcpHeader == nil {
+ tcpHeader := header.TCP(pkt.TransportHeader().View())
+ if len(tcpHeader) < header.TCPMinimumSize {
return tupleID{}, tcpip.ErrUnknownProtocol
}
@@ -344,8 +344,8 @@ func handlePacketPrerouting(pkt *PacketBuffer, conn *conn, dir direction) {
return
}
- netHeader := header.IPv4(pkt.NetworkHeader)
- tcpHeader := header.TCP(pkt.TransportHeader)
+ netHeader := header.IPv4(pkt.NetworkHeader().View())
+ tcpHeader := header.TCP(pkt.TransportHeader().View())
// For prerouting redirection, packets going in the original direction
// have their destinations modified and replies have their sources
@@ -377,8 +377,8 @@ func handlePacketOutput(pkt *PacketBuffer, conn *conn, gso *GSO, r *Route, dir d
return
}
- netHeader := header.IPv4(pkt.NetworkHeader)
- tcpHeader := header.TCP(pkt.TransportHeader)
+ netHeader := header.IPv4(pkt.NetworkHeader().View())
+ tcpHeader := header.TCP(pkt.TransportHeader().View())
// For output redirection, packets going in the original direction
// have their destinations modified and replies have their sources
@@ -396,8 +396,7 @@ func handlePacketOutput(pkt *PacketBuffer, conn *conn, gso *GSO, r *Route, dir d
// Calculate the TCP checksum and set it.
tcpHeader.SetChecksum(0)
- hdr := &pkt.Header
- length := uint16(pkt.Data.Size()+hdr.UsedLength()) - uint16(netHeader.HeaderLength())
+ length := uint16(pkt.Size()) - uint16(netHeader.HeaderLength())
xsum := r.PseudoHeaderChecksum(header.TCPProtocolNumber, length)
if gso != nil && gso.NeedsCsum {
tcpHeader.SetChecksum(xsum)
@@ -423,7 +422,7 @@ func (ct *ConnTrack) handlePacket(pkt *PacketBuffer, hook Hook, gso *GSO, r *Rou
}
// TODO(gvisor.dev/issue/170): Support other transport protocols.
- if pkt.NetworkHeader == nil || header.IPv4(pkt.NetworkHeader).TransportProtocol() != header.TCPProtocolNumber {
+ if nh := pkt.NetworkHeader().View(); nh.IsEmpty() || header.IPv4(nh).TransportProtocol() != header.TCPProtocolNumber {
return false
}
@@ -433,8 +432,8 @@ func (ct *ConnTrack) handlePacket(pkt *PacketBuffer, hook Hook, gso *GSO, r *Rou
return true
}
- tcpHeader := header.TCP(pkt.TransportHeader)
- if tcpHeader == nil {
+ tcpHeader := header.TCP(pkt.TransportHeader().View())
+ if len(tcpHeader) < header.TCPMinimumSize {
return false
}
@@ -455,7 +454,7 @@ func (ct *ConnTrack) handlePacket(pkt *PacketBuffer, hook Hook, gso *GSO, r *Rou
// Mark the connection as having been used recently so it isn't reaped.
conn.lastUsed = time.Now()
// Update connection state.
- conn.updateLocked(header.TCP(pkt.TransportHeader), hook)
+ conn.updateLocked(header.TCP(pkt.TransportHeader().View()), hook)
return false
}
@@ -474,7 +473,7 @@ func (ct *ConnTrack) maybeInsertNoop(pkt *PacketBuffer, hook Hook) {
}
// We only track TCP connections.
- if pkt.NetworkHeader == nil || header.IPv4(pkt.NetworkHeader).TransportProtocol() != header.TCPProtocolNumber {
+ if nh := pkt.NetworkHeader().View(); nh.IsEmpty() || header.IPv4(nh).TransportProtocol() != header.TCPProtocolNumber {
return
}
@@ -486,7 +485,7 @@ func (ct *ConnTrack) maybeInsertNoop(pkt *PacketBuffer, hook Hook) {
return
}
conn := newConn(tid, tid.reply(), manipNone, hook)
- conn.updateLocked(header.TCP(pkt.TransportHeader), hook)
+ conn.updateLocked(header.TCP(pkt.TransportHeader().View()), hook)
ct.insertConn(conn)
}
diff --git a/pkg/tcpip/stack/forwarder_test.go b/pkg/tcpip/stack/forwarder_test.go
index c962693f5..944f622fd 100644
--- a/pkg/tcpip/stack/forwarder_test.go
+++ b/pkg/tcpip/stack/forwarder_test.go
@@ -75,7 +75,7 @@ func (f *fwdTestNetworkEndpoint) ID() *NetworkEndpointID {
func (f *fwdTestNetworkEndpoint) HandlePacket(r *Route, pkt *PacketBuffer) {
// Dispatch the packet to the transport protocol.
- f.dispatcher.DeliverTransportPacket(r, tcpip.TransportProtocolNumber(pkt.NetworkHeader[protocolNumberOffset]), pkt)
+ f.dispatcher.DeliverTransportPacket(r, tcpip.TransportProtocolNumber(pkt.NetworkHeader().View()[protocolNumberOffset]), pkt)
}
func (f *fwdTestNetworkEndpoint) MaxHeaderLength() uint16 {
@@ -97,7 +97,7 @@ func (f *fwdTestNetworkEndpoint) NetworkProtocolNumber() tcpip.NetworkProtocolNu
func (f *fwdTestNetworkEndpoint) WritePacket(r *Route, gso *GSO, params NetworkHeaderParams, pkt *PacketBuffer) *tcpip.Error {
// Add the protocol's header to the packet and send it to the link
// endpoint.
- b := pkt.Header.Prepend(fwdTestNetHeaderLen)
+ b := pkt.NetworkHeader().Push(fwdTestNetHeaderLen)
b[dstAddrOffset] = r.RemoteAddress[0]
b[srcAddrOffset] = f.id.LocalAddress[0]
b[protocolNumberOffset] = byte(params.Protocol)
@@ -144,13 +144,11 @@ func (*fwdTestNetworkProtocol) ParseAddresses(v buffer.View) (src, dst tcpip.Add
}
func (*fwdTestNetworkProtocol) Parse(pkt *PacketBuffer) (tcpip.TransportProtocolNumber, bool, bool) {
- netHeader, ok := pkt.Data.PullUp(fwdTestNetHeaderLen)
+ netHeader, ok := pkt.NetworkHeader().Consume(fwdTestNetHeaderLen)
if !ok {
return 0, false, false
}
- pkt.NetworkHeader = netHeader
- pkt.Data.TrimFront(fwdTestNetHeaderLen)
- return tcpip.TransportProtocolNumber(pkt.NetworkHeader[protocolNumberOffset]), true, true
+ return tcpip.TransportProtocolNumber(netHeader[protocolNumberOffset]), true, true
}
func (f *fwdTestNetworkProtocol) NewEndpoint(nicID tcpip.NICID, addrWithPrefix tcpip.AddressWithPrefix, linkAddrCache LinkAddressCache, dispatcher TransportDispatcher, ep LinkEndpoint, _ *Stack) (NetworkEndpoint, *tcpip.Error) {
@@ -290,7 +288,7 @@ func (e *fwdTestLinkEndpoint) WritePackets(r *Route, gso *GSO, pkts PacketBuffer
// WriteRawPacket implements stack.LinkEndpoint.WriteRawPacket.
func (e *fwdTestLinkEndpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error {
p := fwdTestPacketInfo{
- Pkt: &PacketBuffer{Data: vv},
+ Pkt: NewPacketBuffer(PacketBufferOptions{Data: vv}),
}
select {
@@ -382,9 +380,9 @@ func TestForwardingWithStaticResolver(t *testing.T) {
// forwarded to NIC 2.
buf := buffer.NewView(30)
buf[dstAddrOffset] = 3
- ep1.InjectInbound(fwdTestNetNumber, &PacketBuffer{
+ ep1.InjectInbound(fwdTestNetNumber, NewPacketBuffer(PacketBufferOptions{
Data: buf.ToVectorisedView(),
- })
+ }))
var p fwdTestPacketInfo
@@ -419,9 +417,9 @@ func TestForwardingWithFakeResolver(t *testing.T) {
// forwarded to NIC 2.
buf := buffer.NewView(30)
buf[dstAddrOffset] = 3
- ep1.InjectInbound(fwdTestNetNumber, &PacketBuffer{
+ ep1.InjectInbound(fwdTestNetNumber, NewPacketBuffer(PacketBufferOptions{
Data: buf.ToVectorisedView(),
- })
+ }))
var p fwdTestPacketInfo
@@ -450,9 +448,9 @@ func TestForwardingWithNoResolver(t *testing.T) {
// forwarded to NIC 2.
buf := buffer.NewView(30)
buf[dstAddrOffset] = 3
- ep1.InjectInbound(fwdTestNetNumber, &PacketBuffer{
+ ep1.InjectInbound(fwdTestNetNumber, NewPacketBuffer(PacketBufferOptions{
Data: buf.ToVectorisedView(),
- })
+ }))
select {
case <-ep2.C:
@@ -480,17 +478,17 @@ func TestForwardingWithFakeResolverPartialTimeout(t *testing.T) {
// not be forwarded.
buf := buffer.NewView(30)
buf[dstAddrOffset] = 4
- ep1.InjectInbound(fwdTestNetNumber, &PacketBuffer{
+ ep1.InjectInbound(fwdTestNetNumber, NewPacketBuffer(PacketBufferOptions{
Data: buf.ToVectorisedView(),
- })
+ }))
// Inject an inbound packet to address 3 on NIC 1, and see if it is
// forwarded to NIC 2.
buf = buffer.NewView(30)
buf[dstAddrOffset] = 3
- ep1.InjectInbound(fwdTestNetNumber, &PacketBuffer{
+ ep1.InjectInbound(fwdTestNetNumber, NewPacketBuffer(PacketBufferOptions{
Data: buf.ToVectorisedView(),
- })
+ }))
var p fwdTestPacketInfo
@@ -500,8 +498,8 @@ func TestForwardingWithFakeResolverPartialTimeout(t *testing.T) {
t.Fatal("packet not forwarded")
}
- if p.Pkt.NetworkHeader[dstAddrOffset] != 3 {
- t.Fatalf("got p.Pkt.NetworkHeader[dstAddrOffset] = %d, want = 3", p.Pkt.NetworkHeader[dstAddrOffset])
+ if nh := PayloadSince(p.Pkt.NetworkHeader()); nh[dstAddrOffset] != 3 {
+ t.Fatalf("got p.Pkt.NetworkHeader[dstAddrOffset] = %d, want = 3", nh[dstAddrOffset])
}
// Test that the address resolution happened correctly.
@@ -529,9 +527,9 @@ func TestForwardingWithFakeResolverTwoPackets(t *testing.T) {
for i := 0; i < 2; i++ {
buf := buffer.NewView(30)
buf[dstAddrOffset] = 3
- ep1.InjectInbound(fwdTestNetNumber, &PacketBuffer{
+ ep1.InjectInbound(fwdTestNetNumber, NewPacketBuffer(PacketBufferOptions{
Data: buf.ToVectorisedView(),
- })
+ }))
}
for i := 0; i < 2; i++ {
@@ -543,8 +541,8 @@ func TestForwardingWithFakeResolverTwoPackets(t *testing.T) {
t.Fatal("packet not forwarded")
}
- if p.Pkt.NetworkHeader[dstAddrOffset] != 3 {
- t.Fatalf("got p.Pkt.NetworkHeader[dstAddrOffset] = %d, want = 3", p.Pkt.NetworkHeader[dstAddrOffset])
+ if nh := PayloadSince(p.Pkt.NetworkHeader()); nh[dstAddrOffset] != 3 {
+ t.Fatalf("got p.Pkt.NetworkHeader[dstAddrOffset] = %d, want = 3", nh[dstAddrOffset])
}
// Test that the address resolution happened correctly.
@@ -575,9 +573,9 @@ func TestForwardingWithFakeResolverManyPackets(t *testing.T) {
buf[dstAddrOffset] = 3
// Set the packet sequence number.
binary.BigEndian.PutUint16(buf[fwdTestNetHeaderLen:], uint16(i))
- ep1.InjectInbound(fwdTestNetNumber, &PacketBuffer{
+ ep1.InjectInbound(fwdTestNetNumber, NewPacketBuffer(PacketBufferOptions{
Data: buf.ToVectorisedView(),
- })
+ }))
}
for i := 0; i < maxPendingPacketsPerResolution; i++ {
@@ -589,13 +587,14 @@ func TestForwardingWithFakeResolverManyPackets(t *testing.T) {
t.Fatal("packet not forwarded")
}
- if b := p.Pkt.Header.View(); b[dstAddrOffset] != 3 {
+ b := PayloadSince(p.Pkt.NetworkHeader())
+ if b[dstAddrOffset] != 3 {
t.Fatalf("got b[dstAddrOffset] = %d, want = 3", b[dstAddrOffset])
}
- seqNumBuf, ok := p.Pkt.Data.PullUp(2) // The sequence number is a uint16 (2 bytes).
- if !ok {
- t.Fatalf("p.Pkt.Data is too short to hold a sequence number: %d", p.Pkt.Data.Size())
+ if len(b) < fwdTestNetHeaderLen+2 {
+ t.Fatalf("packet is too short to hold a sequence number: len(b) = %d", b)
}
+ seqNumBuf := b[fwdTestNetHeaderLen:]
// The first 5 packets should not be forwarded so the sequence number should
// start with 5.
@@ -632,9 +631,9 @@ func TestForwardingWithFakeResolverManyResolutions(t *testing.T) {
// maxPendingResolutions + 7).
buf := buffer.NewView(30)
buf[dstAddrOffset] = byte(3 + i)
- ep1.InjectInbound(fwdTestNetNumber, &PacketBuffer{
+ ep1.InjectInbound(fwdTestNetNumber, NewPacketBuffer(PacketBufferOptions{
Data: buf.ToVectorisedView(),
- })
+ }))
}
for i := 0; i < maxPendingResolutions; i++ {
@@ -648,8 +647,8 @@ func TestForwardingWithFakeResolverManyResolutions(t *testing.T) {
// The first 5 packets (address 3 to 7) should not be forwarded
// because their address resolutions are interrupted.
- if p.Pkt.NetworkHeader[dstAddrOffset] < 8 {
- t.Fatalf("got p.Pkt.NetworkHeader[dstAddrOffset] = %d, want p.Pkt.NetworkHeader[dstAddrOffset] >= 8", p.Pkt.NetworkHeader[dstAddrOffset])
+ if nh := PayloadSince(p.Pkt.NetworkHeader()); nh[dstAddrOffset] < 8 {
+ t.Fatalf("got p.Pkt.NetworkHeader[dstAddrOffset] = %d, want p.Pkt.NetworkHeader[dstAddrOffset] >= 8", nh[dstAddrOffset])
}
// Test that the address resolution happened correctly.
diff --git a/pkg/tcpip/stack/headertype_string.go b/pkg/tcpip/stack/headertype_string.go
new file mode 100644
index 000000000..5efddfaaf
--- /dev/null
+++ b/pkg/tcpip/stack/headertype_string.go
@@ -0,0 +1,39 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at //
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Code generated by "stringer -type headerType ."; DO NOT EDIT.
+
+package stack
+
+import "strconv"
+
+func _() {
+ // An "invalid array index" compiler error signifies that the constant values have changed.
+ // Re-run the stringer command to generate them again.
+ var x [1]struct{}
+ _ = x[linkHeader-0]
+ _ = x[networkHeader-1]
+ _ = x[transportHeader-2]
+ _ = x[numHeaderType-3]
+}
+
+const _headerType_name = "linkHeadernetworkHeadertransportHeadernumHeaderType"
+
+var _headerType_index = [...]uint8{0, 10, 23, 38, 51}
+
+func (i headerType) String() string {
+ if i < 0 || i >= headerType(len(_headerType_index)-1) {
+ return "headerType(" + strconv.FormatInt(int64(i), 10) + ")"
+ }
+ return _headerType_name[_headerType_index[i]:_headerType_index[i+1]]
+}
diff --git a/pkg/tcpip/stack/iptables.go b/pkg/tcpip/stack/iptables.go
index 110ba073d..c37da814f 100644
--- a/pkg/tcpip/stack/iptables.go
+++ b/pkg/tcpip/stack/iptables.go
@@ -394,7 +394,7 @@ func (it *IPTables) checkRule(hook Hook, pkt *PacketBuffer, table Table, ruleIdx
rule := table.Rules[ruleIdx]
// Check whether the packet matches the IP header filter.
- if !rule.Filter.match(header.IPv4(pkt.NetworkHeader), hook, nicName) {
+ if !rule.Filter.match(header.IPv4(pkt.NetworkHeader().View()), hook, nicName) {
// Continue on to the next rule.
return RuleJump, ruleIdx + 1
}
diff --git a/pkg/tcpip/stack/iptables_targets.go b/pkg/tcpip/stack/iptables_targets.go
index dc88033c7..5f1b2af64 100644
--- a/pkg/tcpip/stack/iptables_targets.go
+++ b/pkg/tcpip/stack/iptables_targets.go
@@ -99,7 +99,7 @@ func (rt RedirectTarget) Action(pkt *PacketBuffer, ct *ConnTrack, hook Hook, gso
}
// Drop the packet if network and transport header are not set.
- if pkt.NetworkHeader == nil || pkt.TransportHeader == nil {
+ if pkt.NetworkHeader().View().IsEmpty() || pkt.TransportHeader().View().IsEmpty() {
return RuleDrop, 0
}
@@ -118,17 +118,16 @@ func (rt RedirectTarget) Action(pkt *PacketBuffer, ct *ConnTrack, hook Hook, gso
// TODO(gvisor.dev/issue/170): Check Flags in RedirectTarget if
// we need to change dest address (for OUTPUT chain) or ports.
- netHeader := header.IPv4(pkt.NetworkHeader)
+ netHeader := header.IPv4(pkt.NetworkHeader().View())
switch protocol := netHeader.TransportProtocol(); protocol {
case header.UDPProtocolNumber:
- udpHeader := header.UDP(pkt.TransportHeader)
+ udpHeader := header.UDP(pkt.TransportHeader().View())
udpHeader.SetDestinationPort(rt.MinPort)
// Calculate UDP checksum and set it.
if hook == Output {
udpHeader.SetChecksum(0)
- hdr := &pkt.Header
- length := uint16(pkt.Data.Size()+hdr.UsedLength()) - uint16(netHeader.HeaderLength())
+ length := uint16(pkt.Size()) - uint16(netHeader.HeaderLength())
// Only calculate the checksum if offloading isn't supported.
if r.Capabilities()&CapabilityTXChecksumOffload == 0 {
diff --git a/pkg/tcpip/stack/ndp.go b/pkg/tcpip/stack/ndp.go
index 5174e639c..93567806b 100644
--- a/pkg/tcpip/stack/ndp.go
+++ b/pkg/tcpip/stack/ndp.go
@@ -746,12 +746,16 @@ func (ndp *ndpState) sendDADPacket(addr tcpip.Address, ref *referencedNetworkEnd
panic(fmt.Sprintf("ndp: route resolution not immediate for route to send NDP NS for DAD (%s -> %s on NIC(%d))", header.IPv6Any, snmc, ndp.nic.ID()))
}
- hdr := buffer.NewPrependable(int(r.MaxHeaderLength()) + header.ICMPv6NeighborSolicitMinimumSize)
- pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6NeighborSolicitMinimumSize))
- pkt.SetType(header.ICMPv6NeighborSolicit)
- ns := header.NDPNeighborSolicit(pkt.NDPPayload())
+ icmpData := header.ICMPv6(buffer.NewView(header.ICMPv6NeighborSolicitMinimumSize))
+ icmpData.SetType(header.ICMPv6NeighborSolicit)
+ ns := header.NDPNeighborSolicit(icmpData.NDPPayload())
ns.SetTargetAddress(addr)
- pkt.SetChecksum(header.ICMPv6Checksum(pkt, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{}))
+ icmpData.SetChecksum(header.ICMPv6Checksum(icmpData, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{}))
+
+ pkt := NewPacketBuffer(PacketBufferOptions{
+ ReserveHeaderBytes: int(r.MaxHeaderLength()),
+ Data: buffer.View(icmpData).ToVectorisedView(),
+ })
sent := r.Stats().ICMP.V6PacketsSent
if err := r.WritePacket(nil,
@@ -759,7 +763,7 @@ func (ndp *ndpState) sendDADPacket(addr tcpip.Address, ref *referencedNetworkEnd
Protocol: header.ICMPv6ProtocolNumber,
TTL: header.NDPHopLimit,
TOS: DefaultTOS,
- }, &PacketBuffer{Header: hdr},
+ }, pkt,
); err != nil {
sent.Dropped.Increment()
return err
@@ -1897,12 +1901,16 @@ func (ndp *ndpState) startSolicitingRouters() {
}
}
payloadSize := header.ICMPv6HeaderSize + header.NDPRSMinimumSize + int(optsSerializer.Length())
- hdr := buffer.NewPrependable(int(r.MaxHeaderLength()) + payloadSize)
- pkt := header.ICMPv6(hdr.Prepend(payloadSize))
- pkt.SetType(header.ICMPv6RouterSolicit)
- rs := header.NDPRouterSolicit(pkt.NDPPayload())
+ icmpData := header.ICMPv6(buffer.NewView(payloadSize))
+ icmpData.SetType(header.ICMPv6RouterSolicit)
+ rs := header.NDPRouterSolicit(icmpData.NDPPayload())
rs.Options().Serialize(optsSerializer)
- pkt.SetChecksum(header.ICMPv6Checksum(pkt, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{}))
+ icmpData.SetChecksum(header.ICMPv6Checksum(icmpData, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{}))
+
+ pkt := NewPacketBuffer(PacketBufferOptions{
+ ReserveHeaderBytes: int(r.MaxHeaderLength()),
+ Data: buffer.View(icmpData).ToVectorisedView(),
+ })
sent := r.Stats().ICMP.V6PacketsSent
if err := r.WritePacket(nil,
@@ -1910,7 +1918,7 @@ func (ndp *ndpState) startSolicitingRouters() {
Protocol: header.ICMPv6ProtocolNumber,
TTL: header.NDPHopLimit,
TOS: DefaultTOS,
- }, &PacketBuffer{Header: hdr},
+ }, pkt,
); err != nil {
sent.Dropped.Increment()
log.Printf("startSolicitingRouters: error writing NDP router solicit message on NIC(%d); err = %s", ndp.nic.ID(), err)
diff --git a/pkg/tcpip/stack/ndp_test.go b/pkg/tcpip/stack/ndp_test.go
index 5d286ccbc..21bf53010 100644
--- a/pkg/tcpip/stack/ndp_test.go
+++ b/pkg/tcpip/stack/ndp_test.go
@@ -541,7 +541,7 @@ func TestDADResolve(t *testing.T) {
// As per RFC 4861 section 4.3, a possible option is the Source Link
// Layer option, but this option MUST NOT be included when the source
// address of the packet is the unspecified address.
- checker.IPv6(t, p.Pkt.Header.View(),
+ checker.IPv6(t, stack.PayloadSince(p.Pkt.NetworkHeader()),
checker.SrcAddr(header.IPv6Any),
checker.DstAddr(snmc),
checker.TTL(header.NDPHopLimit),
@@ -550,8 +550,8 @@ func TestDADResolve(t *testing.T) {
checker.NDPNSOptions(nil),
))
- if l, want := p.Pkt.Header.AvailableLength(), int(test.linkHeaderLen); l != want {
- t.Errorf("got p.Pkt.Header.AvailableLength() = %d; want = %d", l, want)
+ if l, want := p.Pkt.AvailableHeaderBytes(), int(test.linkHeaderLen); l != want {
+ t.Errorf("got p.Pkt.AvailableHeaderBytes() = %d; want = %d", l, want)
}
}
})
@@ -667,9 +667,10 @@ func TestDADFail(t *testing.T) {
// Receive a packet to simulate multiple nodes owning or
// attempting to own the same address.
hdr := test.makeBuf(addr1)
- e.InjectInbound(header.IPv6ProtocolNumber, &stack.PacketBuffer{
+ pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
Data: hdr.View().ToVectorisedView(),
})
+ e.InjectInbound(header.IPv6ProtocolNumber, pkt)
stat := test.getStat(s.Stats().ICMP.V6PacketsReceived)
if got := stat.Value(); got != 1 {
@@ -1024,7 +1025,9 @@ func raBufWithOptsAndDHCPv6(ip tcpip.Address, rl uint16, managedAddress, otherCo
DstAddr: header.IPv6AllNodesMulticastAddress,
})
- return &stack.PacketBuffer{Data: hdr.View().ToVectorisedView()}
+ return stack.NewPacketBuffer(stack.PacketBufferOptions{
+ Data: hdr.View().ToVectorisedView(),
+ })
}
// raBufWithOpts returns a valid NDP Router Advertisement with options.
@@ -5134,16 +5137,15 @@ func TestRouterSolicitation(t *testing.T) {
t.Errorf("got remote link address = %s, want = %s", p.Route.RemoteLinkAddress, want)
}
- checker.IPv6(t,
- p.Pkt.Header.View(),
+ checker.IPv6(t, stack.PayloadSince(p.Pkt.NetworkHeader()),
checker.SrcAddr(test.expectedSrcAddr),
checker.DstAddr(header.IPv6AllRoutersMulticastAddress),
checker.TTL(header.NDPHopLimit),
checker.NDPRS(checker.NDPRSOptions(test.expectedNDPOpts)),
)
- if l, want := p.Pkt.Header.AvailableLength(), int(test.linkHeaderLen); l != want {
- t.Errorf("got p.Pkt.Header.AvailableLength() = %d; want = %d", l, want)
+ if l, want := p.Pkt.AvailableHeaderBytes(), int(test.linkHeaderLen); l != want {
+ t.Errorf("got p.Pkt.AvailableHeaderBytes() = %d; want = %d", l, want)
}
}
waitForNothing := func(timeout time.Duration) {
@@ -5288,7 +5290,7 @@ func TestStopStartSolicitingRouters(t *testing.T) {
if p.Proto != header.IPv6ProtocolNumber {
t.Fatalf("got Proto = %d, want = %d", p.Proto, header.IPv6ProtocolNumber)
}
- checker.IPv6(t, p.Pkt.Header.View(),
+ checker.IPv6(t, stack.PayloadSince(p.Pkt.NetworkHeader()),
checker.SrcAddr(header.IPv6Any),
checker.DstAddr(header.IPv6AllRoutersMulticastAddress),
checker.TTL(header.NDPHopLimit),
diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go
index eaaf756cd..2315ea5b9 100644
--- a/pkg/tcpip/stack/nic.go
+++ b/pkg/tcpip/stack/nic.go
@@ -1299,7 +1299,7 @@ func (n *NIC) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcp
}
}
- src, dst := netProto.ParseAddresses(pkt.NetworkHeader)
+ src, dst := netProto.ParseAddresses(pkt.NetworkHeader().View())
if n.stack.handleLocal && !n.isLoopback() && n.getRef(protocol, src) != nil {
// The source address is one of our own, so we never should have gotten a
@@ -1401,24 +1401,19 @@ func (n *NIC) DeliverOutboundPacket(remote, local tcpip.LinkAddress, protocol tc
func (n *NIC) forwardPacket(r *Route, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) {
// TODO(b/143425874) Decrease the TTL field in forwarded packets.
- // TODO(b/151227689): Avoid copying the packet when forwarding. We can do this
- // by having lower layers explicity write each header instead of just
- // pkt.Header.
- // pkt may have set its NetworkHeader and TransportHeader. If we're
- // forwarding, we'll have to copy them into pkt.Header.
- pkt.Header = buffer.NewPrependable(int(n.linkEP.MaxHeaderLength()) + len(pkt.NetworkHeader) + len(pkt.TransportHeader))
- if n := copy(pkt.Header.Prepend(len(pkt.TransportHeader)), pkt.TransportHeader); n != len(pkt.TransportHeader) {
- panic(fmt.Sprintf("copied %d bytes, expected %d", n, len(pkt.TransportHeader)))
- }
- if n := copy(pkt.Header.Prepend(len(pkt.NetworkHeader)), pkt.NetworkHeader); n != len(pkt.NetworkHeader) {
- panic(fmt.Sprintf("copied %d bytes, expected %d", n, len(pkt.NetworkHeader)))
- }
+ // pkt may have set its header and may not have enough headroom for link-layer
+ // header for the other link to prepend. Here we create a new packet to
+ // forward.
+ fwdPkt := NewPacketBuffer(PacketBufferOptions{
+ ReserveHeaderBytes: int(n.linkEP.MaxHeaderLength()),
+ Data: buffer.NewVectorisedView(pkt.Size(), pkt.Views()),
+ })
- // WritePacket takes ownership of pkt, calculate numBytes first.
- numBytes := pkt.Header.UsedLength() + pkt.Data.Size()
+ // WritePacket takes ownership of fwdPkt, calculate numBytes first.
+ numBytes := fwdPkt.Size()
- if err := n.linkEP.WritePacket(r, nil /* gso */, protocol, pkt); err != nil {
+ if err := n.linkEP.WritePacket(r, nil /* gso */, protocol, fwdPkt); err != nil {
r.Stats().IP.OutgoingPacketErrors.Increment()
return
}
@@ -1443,34 +1438,31 @@ func (n *NIC) DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolN
// validly formed.
n.stack.demux.deliverRawPacket(r, protocol, pkt)
- // TransportHeader is nil only when pkt is an ICMP packet or was reassembled
+ // TransportHeader is empty only when pkt is an ICMP packet or was reassembled
// from fragments.
- if pkt.TransportHeader == nil {
+ if pkt.TransportHeader().View().IsEmpty() {
// TODO(gvisor.dev/issue/170): ICMP packets don't have their TransportHeader
// fields set yet, parse it here. See icmp/protocol.go:protocol.Parse for a
// full explanation.
if protocol == header.ICMPv4ProtocolNumber || protocol == header.ICMPv6ProtocolNumber {
// ICMP packets may be longer, but until icmp.Parse is implemented, here
// we parse it using the minimum size.
- transHeader, ok := pkt.Data.PullUp(transProto.MinimumPacketSize())
- if !ok {
+ if _, ok := pkt.TransportHeader().Consume(transProto.MinimumPacketSize()); !ok {
n.stack.stats.MalformedRcvdPackets.Increment()
return
}
- pkt.TransportHeader = transHeader
- pkt.Data.TrimFront(len(pkt.TransportHeader))
} else {
// This is either a bad packet or was re-assembled from fragments.
transProto.Parse(pkt)
}
}
- if len(pkt.TransportHeader) < transProto.MinimumPacketSize() {
+ if pkt.TransportHeader().View().Size() < transProto.MinimumPacketSize() {
n.stack.stats.MalformedRcvdPackets.Increment()
return
}
- srcPort, dstPort, err := transProto.ParsePorts(pkt.TransportHeader)
+ srcPort, dstPort, err := transProto.ParsePorts(pkt.TransportHeader().View())
if err != nil {
n.stack.stats.MalformedRcvdPackets.Increment()
return
diff --git a/pkg/tcpip/stack/nic_test.go b/pkg/tcpip/stack/nic_test.go
index a70792b50..0870c8d9c 100644
--- a/pkg/tcpip/stack/nic_test.go
+++ b/pkg/tcpip/stack/nic_test.go
@@ -311,7 +311,9 @@ func TestDisabledRxStatsWhenNICDisabled(t *testing.T) {
t.FailNow()
}
- nic.DeliverNetworkPacket("", "", 0, &PacketBuffer{Data: buffer.View([]byte{1, 2, 3, 4}).ToVectorisedView()})
+ nic.DeliverNetworkPacket("", "", 0, NewPacketBuffer(PacketBufferOptions{
+ Data: buffer.View([]byte{1, 2, 3, 4}).ToVectorisedView(),
+ }))
if got := nic.stats.DisabledRx.Packets.Value(); got != 1 {
t.Errorf("got DisabledRx.Packets = %d, want = 1", got)
diff --git a/pkg/tcpip/stack/packet_buffer.go b/pkg/tcpip/stack/packet_buffer.go
index 5d6865e35..17b8beebb 100644
--- a/pkg/tcpip/stack/packet_buffer.go
+++ b/pkg/tcpip/stack/packet_buffer.go
@@ -14,16 +14,43 @@
package stack
import (
+ "fmt"
+
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
)
+type headerType int
+
+const (
+ linkHeader headerType = iota
+ networkHeader
+ transportHeader
+ numHeaderType
+)
+
+// PacketBufferOptions specifies options for PacketBuffer creation.
+type PacketBufferOptions struct {
+ // ReserveHeaderBytes is the number of bytes to reserve for headers. Total
+ // number of bytes pushed onto the headers must not exceed this value.
+ ReserveHeaderBytes int
+
+ // Data is the initial unparsed data for the new packet. If set, it will be
+ // owned by the new packet.
+ Data buffer.VectorisedView
+}
+
// A PacketBuffer contains all the data of a network packet.
//
// As a PacketBuffer traverses up the stack, it may be necessary to pass it to
-// multiple endpoints. Clone() should be called in such cases so that
-// modifications to the Data field do not affect other copies.
+// multiple endpoints.
+//
+// The whole packet is expected to be a series of bytes in the following order:
+// LinkHeader, NetworkHeader, TransportHeader, and Data. Any of them can be
+// empty. Use of PacketBuffer in any other order is unsupported.
+//
+// PacketBuffer must be created with NewPacketBuffer.
type PacketBuffer struct {
_ sync.NoCopy
@@ -31,36 +58,32 @@ type PacketBuffer struct {
// PacketBuffers.
PacketBufferEntry
- // Data holds the payload of the packet. For inbound packets, it also
- // holds the headers, which are consumed as the packet moves up the
- // stack. Headers are guaranteed not to be split across views.
+ // Data holds the payload of the packet.
+ //
+ // For inbound packets, Data is initially the whole packet. Then gets moved to
+ // headers via PacketHeader.Consume, when the packet is being parsed.
+ //
+ // For outbound packets, Data is the innermost layer, defined by the protocol.
+ // Headers are pushed in front of it via PacketHeader.Push.
//
- // The bytes backing Data are immutable, but Data itself may be trimmed
- // or otherwise modified.
+ // The bytes backing Data are immutable, a.k.a. users shouldn't write to its
+ // backing storage.
Data buffer.VectorisedView
- // Header holds the headers of outbound packets. As a packet is passed
- // down the stack, each layer adds to Header. Note that forwarded
- // packets don't populate Headers on their way out -- their headers and
- // payload are never parsed out and remain in Data.
- //
- // TODO(gvisor.dev/issue/170): Forwarded packets don't currently
- // populate Header, but should. This will be doable once early parsing
- // (https://github.com/google/gvisor/pull/1995) is supported.
- Header buffer.Prependable
+ // headers stores metadata about each header.
+ headers [numHeaderType]headerInfo
- // These fields are used by both inbound and outbound packets. They
- // typically overlap with the Data and Header fields.
- //
- // The bytes backing these views are immutable. Each field may be nil
- // if either it has not been set yet or no such header exists (e.g.
- // packets sent via loopback may not have a link header).
+ // header is the internal storage for outbound packets. Headers will be pushed
+ // (prepended) on this storage as the packet is being constructed.
//
- // These fields may be Views into other slices (either Data or Header).
- // SR dosen't support this, so deep copies are necessary in some cases.
- LinkHeader buffer.View
- NetworkHeader buffer.View
- TransportHeader buffer.View
+ // TODO(gvisor.dev/issue/2404): Switch to an implementation that header and
+ // data are held in the same underlying buffer storage.
+ header buffer.Prependable
+
+ // NetworkProtocol is only valid when NetworkHeader is set.
+ // TODO(gvisor.dev/issue/3574): Remove the separately passed protocol
+ // numbers in registration APIs that take a PacketBuffer.
+ NetworkProtocolNumber tcpip.NetworkProtocolNumber
// Hash is the transport layer hash of this packet. A value of zero
// indicates no valid hash has been set.
@@ -72,9 +95,8 @@ type PacketBuffer struct {
// The following fields are only set by the qdisc layer when the packet
// is added to a queue.
- EgressRoute *Route
- GSOOptions *GSO
- NetworkProtocolNumber tcpip.NetworkProtocolNumber
+ EgressRoute *Route
+ GSOOptions *GSO
// NatDone indicates if the packet has been manipulated as per NAT
// iptables rule.
@@ -85,20 +107,137 @@ type PacketBuffer struct {
PktType tcpip.PacketType
}
-// Clone makes a copy of pk. It clones the Data field, which creates a new
-// VectorisedView but does not deep copy the underlying bytes.
-//
-// Clone also does not deep copy any of its other fields.
+// NewPacketBuffer creates a new PacketBuffer with opts.
+func NewPacketBuffer(opts PacketBufferOptions) *PacketBuffer {
+ pk := &PacketBuffer{
+ Data: opts.Data,
+ }
+ if opts.ReserveHeaderBytes != 0 {
+ pk.header = buffer.NewPrependable(opts.ReserveHeaderBytes)
+ }
+ return pk
+}
+
+// ReservedHeaderBytes returns the number of bytes initially reserved for
+// headers.
+func (pk *PacketBuffer) ReservedHeaderBytes() int {
+ return pk.header.UsedLength() + pk.header.AvailableLength()
+}
+
+// AvailableHeaderBytes returns the number of bytes currently available for
+// headers. This is relevant to PacketHeader.Push method only.
+func (pk *PacketBuffer) AvailableHeaderBytes() int {
+ return pk.header.AvailableLength()
+}
+
+// LinkHeader returns the handle to link-layer header.
+func (pk *PacketBuffer) LinkHeader() PacketHeader {
+ return PacketHeader{
+ pk: pk,
+ typ: linkHeader,
+ }
+}
+
+// NetworkHeader returns the handle to network-layer header.
+func (pk *PacketBuffer) NetworkHeader() PacketHeader {
+ return PacketHeader{
+ pk: pk,
+ typ: networkHeader,
+ }
+}
+
+// TransportHeader returns the handle to transport-layer header.
+func (pk *PacketBuffer) TransportHeader() PacketHeader {
+ return PacketHeader{
+ pk: pk,
+ typ: transportHeader,
+ }
+}
+
+// HeaderSize returns the total size of all headers in bytes.
+func (pk *PacketBuffer) HeaderSize() int {
+ // Note for inbound packets (Consume called), headers are not stored in
+ // pk.header. Thus, calculation of size of each header is needed.
+ var size int
+ for i := range pk.headers {
+ size += len(pk.headers[i].buf)
+ }
+ return size
+}
+
+// Size returns the size of packet in bytes.
+func (pk *PacketBuffer) Size() int {
+ return pk.HeaderSize() + pk.Data.Size()
+}
+
+// Views returns the underlying storage of the whole packet.
+func (pk *PacketBuffer) Views() []buffer.View {
+ // Optimization for outbound packets that headers are in pk.header.
+ useHeader := true
+ for i := range pk.headers {
+ if !canUseHeader(&pk.headers[i]) {
+ useHeader = false
+ break
+ }
+ }
+
+ dataViews := pk.Data.Views()
+
+ var vs []buffer.View
+ if useHeader {
+ vs = make([]buffer.View, 0, 1+len(dataViews))
+ vs = append(vs, pk.header.View())
+ } else {
+ vs = make([]buffer.View, 0, len(pk.headers)+len(dataViews))
+ for i := range pk.headers {
+ if v := pk.headers[i].buf; len(v) > 0 {
+ vs = append(vs, v)
+ }
+ }
+ }
+ return append(vs, dataViews...)
+}
+
+func canUseHeader(h *headerInfo) bool {
+ // h.offset will be negative if the header was pushed in to prependable
+ // portion, or doesn't matter when it's empty.
+ return len(h.buf) == 0 || h.offset < 0
+}
+
+func (pk *PacketBuffer) push(typ headerType, size int) buffer.View {
+ h := &pk.headers[typ]
+ if h.buf != nil {
+ panic(fmt.Sprintf("push must not be called twice: type %s", typ))
+ }
+ h.buf = buffer.View(pk.header.Prepend(size))
+ h.offset = -pk.header.UsedLength()
+ return h.buf
+}
+
+func (pk *PacketBuffer) consume(typ headerType, size int) (v buffer.View, consumed bool) {
+ h := &pk.headers[typ]
+ if h.buf != nil {
+ panic(fmt.Sprintf("consume must not be called twice: type %s", typ))
+ }
+ v, ok := pk.Data.PullUp(size)
+ if !ok {
+ return
+ }
+ pk.Data.TrimFront(size)
+ h.buf = v
+ return h.buf, true
+}
+
+// Clone makes a shallow copy of pk.
//
-// FIXME(b/153685824): Data gets copied but not other header references.
+// Clone should be called in such cases so that no modifications is done to
+// underlying packet payload.
func (pk *PacketBuffer) Clone() *PacketBuffer {
- return &PacketBuffer{
+ newPk := &PacketBuffer{
PacketBufferEntry: pk.PacketBufferEntry,
Data: pk.Data.Clone(nil),
- Header: pk.Header,
- LinkHeader: pk.LinkHeader,
- NetworkHeader: pk.NetworkHeader,
- TransportHeader: pk.TransportHeader,
+ headers: pk.headers,
+ header: pk.header,
Hash: pk.Hash,
Owner: pk.Owner,
EgressRoute: pk.EgressRoute,
@@ -106,4 +245,55 @@ func (pk *PacketBuffer) Clone() *PacketBuffer {
NetworkProtocolNumber: pk.NetworkProtocolNumber,
NatDone: pk.NatDone,
}
+ return newPk
+}
+
+// headerInfo stores metadata about a header in a packet.
+type headerInfo struct {
+ // buf is the memorized slice for both prepended and consumed header.
+ // When header is prepended, buf serves as memorized value, which is a slice
+ // of pk.header. When header is consumed, buf is the slice pulled out from
+ // pk.Data, which is the only place to hold this header.
+ buf buffer.View
+
+ // offset will be a negative number denoting the offset where this header is
+ // from the end of pk.header, if it is prepended. Otherwise, zero.
+ offset int
+}
+
+// PacketHeader is a handle object to a header in the underlying packet.
+type PacketHeader struct {
+ pk *PacketBuffer
+ typ headerType
+}
+
+// View returns the underlying storage of h.
+func (h PacketHeader) View() buffer.View {
+ return h.pk.headers[h.typ].buf
+}
+
+// Push pushes size bytes in the front of its residing packet, and returns the
+// backing storage. Callers may only call one of Push or Consume once on each
+// header in the lifetime of the underlying packet.
+func (h PacketHeader) Push(size int) buffer.View {
+ return h.pk.push(h.typ, size)
+}
+
+// Consume moves the first size bytes of the unparsed data portion in the packet
+// to h, and returns the backing storage. In the case of data is shorter than
+// size, consumed will be false, and the state of h will not be affected.
+// Callers may only call one of Push or Consume once on each header in the
+// lifetime of the underlying packet.
+func (h PacketHeader) Consume(size int) (v buffer.View, consumed bool) {
+ return h.pk.consume(h.typ, size)
+}
+
+// PayloadSince returns packet payload starting from and including a particular
+// header. This method isn't optimized and should be used in test only.
+func PayloadSince(h PacketHeader) buffer.View {
+ var v buffer.View
+ for _, hinfo := range h.pk.headers[h.typ:] {
+ v = append(v, hinfo.buf...)
+ }
+ return append(v, h.pk.Data.ToView()...)
}
diff --git a/pkg/tcpip/stack/packet_buffer_test.go b/pkg/tcpip/stack/packet_buffer_test.go
new file mode 100644
index 000000000..c6fa8da5f
--- /dev/null
+++ b/pkg/tcpip/stack/packet_buffer_test.go
@@ -0,0 +1,397 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at //
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package stack
+
+import (
+ "bytes"
+ "testing"
+
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
+)
+
+func TestPacketHeaderPush(t *testing.T) {
+ for _, test := range []struct {
+ name string
+ reserved int
+ link []byte
+ network []byte
+ transport []byte
+ data []byte
+ }{
+ {
+ name: "construct empty packet",
+ },
+ {
+ name: "construct link header only packet",
+ reserved: 60,
+ link: makeView(10),
+ },
+ {
+ name: "construct link and network header only packet",
+ reserved: 60,
+ link: makeView(10),
+ network: makeView(20),
+ },
+ {
+ name: "construct header only packet",
+ reserved: 60,
+ link: makeView(10),
+ network: makeView(20),
+ transport: makeView(30),
+ },
+ {
+ name: "construct data only packet",
+ data: makeView(40),
+ },
+ {
+ name: "construct L3 packet",
+ reserved: 60,
+ network: makeView(20),
+ transport: makeView(30),
+ data: makeView(40),
+ },
+ {
+ name: "construct L2 packet",
+ reserved: 60,
+ link: makeView(10),
+ network: makeView(20),
+ transport: makeView(30),
+ data: makeView(40),
+ },
+ } {
+ t.Run(test.name, func(t *testing.T) {
+ pk := NewPacketBuffer(PacketBufferOptions{
+ ReserveHeaderBytes: test.reserved,
+ // Make a copy of data to make sure our truth data won't be taint by
+ // PacketBuffer.
+ Data: buffer.NewViewFromBytes(test.data).ToVectorisedView(),
+ })
+
+ allHdrSize := len(test.link) + len(test.network) + len(test.transport)
+
+ // Check the initial values for packet.
+ checkInitialPacketBuffer(t, pk, PacketBufferOptions{
+ ReserveHeaderBytes: test.reserved,
+ Data: buffer.View(test.data).ToVectorisedView(),
+ })
+
+ // Push headers.
+ if v := test.transport; len(v) > 0 {
+ copy(pk.TransportHeader().Push(len(v)), v)
+ }
+ if v := test.network; len(v) > 0 {
+ copy(pk.NetworkHeader().Push(len(v)), v)
+ }
+ if v := test.link; len(v) > 0 {
+ copy(pk.LinkHeader().Push(len(v)), v)
+ }
+
+ // Check the after values for packet.
+ if got, want := pk.ReservedHeaderBytes(), test.reserved; got != want {
+ t.Errorf("After pk.ReservedHeaderBytes() = %d, want %d", got, want)
+ }
+ if got, want := pk.AvailableHeaderBytes(), test.reserved-allHdrSize; got != want {
+ t.Errorf("After pk.AvailableHeaderBytes() = %d, want %d", got, want)
+ }
+ if got, want := pk.HeaderSize(), allHdrSize; got != want {
+ t.Errorf("After pk.HeaderSize() = %d, want %d", got, want)
+ }
+ if got, want := pk.Size(), allHdrSize+len(test.data); got != want {
+ t.Errorf("After pk.Size() = %d, want %d", got, want)
+ }
+ checkViewEqual(t, "After pk.Data.Views()", concatViews(pk.Data.Views()...), test.data)
+ checkViewEqual(t, "After pk.Views()", concatViews(pk.Views()...),
+ concatViews(test.link, test.network, test.transport, test.data))
+ // Check the after values for each header.
+ checkPacketHeader(t, "After pk.LinkHeader", pk.LinkHeader(), test.link)
+ checkPacketHeader(t, "After pk.NetworkHeader", pk.NetworkHeader(), test.network)
+ checkPacketHeader(t, "After pk.TransportHeader", pk.TransportHeader(), test.transport)
+ // Check the after values for PayloadSince.
+ checkViewEqual(t, "After PayloadSince(LinkHeader)",
+ PayloadSince(pk.LinkHeader()),
+ concatViews(test.link, test.network, test.transport, test.data))
+ checkViewEqual(t, "After PayloadSince(NetworkHeader)",
+ PayloadSince(pk.NetworkHeader()),
+ concatViews(test.network, test.transport, test.data))
+ checkViewEqual(t, "After PayloadSince(TransportHeader)",
+ PayloadSince(pk.TransportHeader()),
+ concatViews(test.transport, test.data))
+ })
+ }
+}
+
+func TestPacketHeaderConsume(t *testing.T) {
+ for _, test := range []struct {
+ name string
+ data []byte
+ link int
+ network int
+ transport int
+ }{
+ {
+ name: "parse L2 packet",
+ data: concatViews(makeView(10), makeView(20), makeView(30), makeView(40)),
+ link: 10,
+ network: 20,
+ transport: 30,
+ },
+ {
+ name: "parse L3 packet",
+ data: concatViews(makeView(20), makeView(30), makeView(40)),
+ network: 20,
+ transport: 30,
+ },
+ } {
+ t.Run(test.name, func(t *testing.T) {
+ pk := NewPacketBuffer(PacketBufferOptions{
+ // Make a copy of data to make sure our truth data won't be taint by
+ // PacketBuffer.
+ Data: buffer.NewViewFromBytes(test.data).ToVectorisedView(),
+ })
+
+ // Check the initial values for packet.
+ checkInitialPacketBuffer(t, pk, PacketBufferOptions{
+ Data: buffer.View(test.data).ToVectorisedView(),
+ })
+
+ // Consume headers.
+ if size := test.link; size > 0 {
+ if _, ok := pk.LinkHeader().Consume(size); !ok {
+ t.Fatalf("pk.LinkHeader().Consume() = false, want true")
+ }
+ }
+ if size := test.network; size > 0 {
+ if _, ok := pk.NetworkHeader().Consume(size); !ok {
+ t.Fatalf("pk.NetworkHeader().Consume() = false, want true")
+ }
+ }
+ if size := test.transport; size > 0 {
+ if _, ok := pk.TransportHeader().Consume(size); !ok {
+ t.Fatalf("pk.TransportHeader().Consume() = false, want true")
+ }
+ }
+
+ allHdrSize := test.link + test.network + test.transport
+
+ // Check the after values for packet.
+ if got, want := pk.ReservedHeaderBytes(), 0; got != want {
+ t.Errorf("After pk.ReservedHeaderBytes() = %d, want %d", got, want)
+ }
+ if got, want := pk.AvailableHeaderBytes(), 0; got != want {
+ t.Errorf("After pk.AvailableHeaderBytes() = %d, want %d", got, want)
+ }
+ if got, want := pk.HeaderSize(), allHdrSize; got != want {
+ t.Errorf("After pk.HeaderSize() = %d, want %d", got, want)
+ }
+ if got, want := pk.Size(), len(test.data); got != want {
+ t.Errorf("After pk.Size() = %d, want %d", got, want)
+ }
+ // After state of pk.
+ var (
+ link = test.data[:test.link]
+ network = test.data[test.link:][:test.network]
+ transport = test.data[test.link+test.network:][:test.transport]
+ payload = test.data[allHdrSize:]
+ )
+ checkViewEqual(t, "After pk.Data.Views()", concatViews(pk.Data.Views()...), payload)
+ checkViewEqual(t, "After pk.Views()", concatViews(pk.Views()...), test.data)
+ // Check the after values for each header.
+ checkPacketHeader(t, "After pk.LinkHeader", pk.LinkHeader(), link)
+ checkPacketHeader(t, "After pk.NetworkHeader", pk.NetworkHeader(), network)
+ checkPacketHeader(t, "After pk.TransportHeader", pk.TransportHeader(), transport)
+ // Check the after values for PayloadSince.
+ checkViewEqual(t, "After PayloadSince(LinkHeader)",
+ PayloadSince(pk.LinkHeader()),
+ concatViews(link, network, transport, payload))
+ checkViewEqual(t, "After PayloadSince(NetworkHeader)",
+ PayloadSince(pk.NetworkHeader()),
+ concatViews(network, transport, payload))
+ checkViewEqual(t, "After PayloadSince(TransportHeader)",
+ PayloadSince(pk.TransportHeader()),
+ concatViews(transport, payload))
+ })
+ }
+}
+
+func TestPacketHeaderConsumeDataTooShort(t *testing.T) {
+ data := makeView(10)
+
+ pk := NewPacketBuffer(PacketBufferOptions{
+ // Make a copy of data to make sure our truth data won't be taint by
+ // PacketBuffer.
+ Data: buffer.NewViewFromBytes(data).ToVectorisedView(),
+ })
+
+ // Consume should fail if pkt.Data is too short.
+ if _, ok := pk.LinkHeader().Consume(11); ok {
+ t.Fatalf("pk.LinkHeader().Consume() = _, true; want _, false")
+ }
+ if _, ok := pk.NetworkHeader().Consume(11); ok {
+ t.Fatalf("pk.NetworkHeader().Consume() = _, true; want _, false")
+ }
+ if _, ok := pk.TransportHeader().Consume(11); ok {
+ t.Fatalf("pk.TransportHeader().Consume() = _, true; want _, false")
+ }
+
+ // Check packet should look the same as initial packet.
+ checkInitialPacketBuffer(t, pk, PacketBufferOptions{
+ Data: buffer.View(data).ToVectorisedView(),
+ })
+}
+
+func TestPacketHeaderPushCalledAtMostOnce(t *testing.T) {
+ const headerSize = 10
+
+ pk := NewPacketBuffer(PacketBufferOptions{
+ ReserveHeaderBytes: headerSize * int(numHeaderType),
+ })
+
+ for _, h := range []PacketHeader{
+ pk.TransportHeader(),
+ pk.NetworkHeader(),
+ pk.LinkHeader(),
+ } {
+ t.Run("PushedTwice/"+h.typ.String(), func(t *testing.T) {
+ h.Push(headerSize)
+
+ defer func() { recover() }()
+ h.Push(headerSize)
+ t.Fatal("Second push should have panicked")
+ })
+ }
+}
+
+func TestPacketHeaderConsumeCalledAtMostOnce(t *testing.T) {
+ const headerSize = 10
+
+ pk := NewPacketBuffer(PacketBufferOptions{
+ Data: makeView(headerSize * int(numHeaderType)).ToVectorisedView(),
+ })
+
+ for _, h := range []PacketHeader{
+ pk.LinkHeader(),
+ pk.NetworkHeader(),
+ pk.TransportHeader(),
+ } {
+ t.Run("ConsumedTwice/"+h.typ.String(), func(t *testing.T) {
+ if _, ok := h.Consume(headerSize); !ok {
+ t.Fatal("First consume should succeed")
+ }
+
+ defer func() { recover() }()
+ h.Consume(headerSize)
+ t.Fatal("Second consume should have panicked")
+ })
+ }
+}
+
+func TestPacketHeaderPushThenConsumePanics(t *testing.T) {
+ const headerSize = 10
+
+ pk := NewPacketBuffer(PacketBufferOptions{
+ ReserveHeaderBytes: headerSize * int(numHeaderType),
+ })
+
+ for _, h := range []PacketHeader{
+ pk.TransportHeader(),
+ pk.NetworkHeader(),
+ pk.LinkHeader(),
+ } {
+ t.Run(h.typ.String(), func(t *testing.T) {
+ h.Push(headerSize)
+
+ defer func() { recover() }()
+ h.Consume(headerSize)
+ t.Fatal("Consume should have panicked")
+ })
+ }
+}
+
+func TestPacketHeaderConsumeThenPushPanics(t *testing.T) {
+ const headerSize = 10
+
+ pk := NewPacketBuffer(PacketBufferOptions{
+ Data: makeView(headerSize * int(numHeaderType)).ToVectorisedView(),
+ })
+
+ for _, h := range []PacketHeader{
+ pk.LinkHeader(),
+ pk.NetworkHeader(),
+ pk.TransportHeader(),
+ } {
+ t.Run(h.typ.String(), func(t *testing.T) {
+ h.Consume(headerSize)
+
+ defer func() { recover() }()
+ h.Push(headerSize)
+ t.Fatal("Push should have panicked")
+ })
+ }
+}
+
+func checkInitialPacketBuffer(t *testing.T, pk *PacketBuffer, opts PacketBufferOptions) {
+ t.Helper()
+ reserved := opts.ReserveHeaderBytes
+ if got, want := pk.ReservedHeaderBytes(), reserved; got != want {
+ t.Errorf("Initial pk.ReservedHeaderBytes() = %d, want %d", got, want)
+ }
+ if got, want := pk.AvailableHeaderBytes(), reserved; got != want {
+ t.Errorf("Initial pk.AvailableHeaderBytes() = %d, want %d", got, want)
+ }
+ if got, want := pk.HeaderSize(), 0; got != want {
+ t.Errorf("Initial pk.HeaderSize() = %d, want %d", got, want)
+ }
+ data := opts.Data.ToView()
+ if got, want := pk.Size(), len(data); got != want {
+ t.Errorf("Initial pk.Size() = %d, want %d", got, want)
+ }
+ checkViewEqual(t, "Initial pk.Data.Views()", concatViews(pk.Data.Views()...), data)
+ checkViewEqual(t, "Initial pk.Views()", concatViews(pk.Views()...), data)
+ // Check the initial values for each header.
+ checkPacketHeader(t, "Initial pk.LinkHeader", pk.LinkHeader(), nil)
+ checkPacketHeader(t, "Initial pk.NetworkHeader", pk.NetworkHeader(), nil)
+ checkPacketHeader(t, "Initial pk.TransportHeader", pk.TransportHeader(), nil)
+ // Check the initial valies for PayloadSince.
+ checkViewEqual(t, "Initial PayloadSince(LinkHeader)",
+ PayloadSince(pk.LinkHeader()), data)
+ checkViewEqual(t, "Initial PayloadSince(NetworkHeader)",
+ PayloadSince(pk.NetworkHeader()), data)
+ checkViewEqual(t, "Initial PayloadSince(TransportHeader)",
+ PayloadSince(pk.TransportHeader()), data)
+}
+
+func checkPacketHeader(t *testing.T, name string, h PacketHeader, want []byte) {
+ t.Helper()
+ checkViewEqual(t, name+".View()", h.View(), want)
+}
+
+func checkViewEqual(t *testing.T, what string, got, want buffer.View) {
+ t.Helper()
+ if !bytes.Equal(got, want) {
+ t.Errorf("%s = %x, want %x", what, got, want)
+ }
+}
+
+func makeView(size int) buffer.View {
+ b := byte(size)
+ return bytes.Repeat([]byte{b}, size)
+}
+
+func concatViews(views ...buffer.View) buffer.View {
+ var all buffer.View
+ for _, v := range views {
+ all = append(all, v...)
+ }
+ return all
+}
diff --git a/pkg/tcpip/stack/registration.go b/pkg/tcpip/stack/registration.go
index 8604c4259..4570e8969 100644
--- a/pkg/tcpip/stack/registration.go
+++ b/pkg/tcpip/stack/registration.go
@@ -249,8 +249,8 @@ type NetworkEndpoint interface {
MaxHeaderLength() uint16
// WritePacket writes a packet to the given destination address and
- // protocol. It takes ownership of pkt. pkt.TransportHeader must have already
- // been set.
+ // protocol. It takes ownership of pkt. pkt.TransportHeader must have
+ // already been set.
WritePacket(r *Route, gso *GSO, params NetworkHeaderParams, pkt *PacketBuffer) *tcpip.Error
// WritePackets writes packets to the given destination address and
diff --git a/pkg/tcpip/stack/route.go b/pkg/tcpip/stack/route.go
index 9ce0a2c22..e267bebb0 100644
--- a/pkg/tcpip/stack/route.go
+++ b/pkg/tcpip/stack/route.go
@@ -173,7 +173,7 @@ func (r *Route) WritePacket(gso *GSO, params NetworkHeaderParams, pkt *PacketBuf
}
// WritePacket takes ownership of pkt, calculate numBytes first.
- numBytes := pkt.Header.UsedLength() + pkt.Data.Size()
+ numBytes := pkt.Size()
err := r.ref.ep.WritePacket(r, gso, params, pkt)
if err != nil {
@@ -203,8 +203,7 @@ func (r *Route) WritePackets(gso *GSO, pkts PacketBufferList, params NetworkHead
writtenBytes := 0
for i, pb := 0, pkts.Front(); i < n && pb != nil; i, pb = i+1, pb.Next() {
- writtenBytes += pb.Header.UsedLength()
- writtenBytes += pb.Data.Size()
+ writtenBytes += pb.Size()
}
r.ref.nic.stats.Tx.Bytes.IncrementBy(uint64(writtenBytes))
diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go
index fe1c1b8a4..0273b3c63 100644
--- a/pkg/tcpip/stack/stack_test.go
+++ b/pkg/tcpip/stack/stack_test.go
@@ -102,7 +102,7 @@ func (f *fakeNetworkEndpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuff
f.proto.packetCount[int(f.id.LocalAddress[0])%len(f.proto.packetCount)]++
// Handle control packets.
- if pkt.NetworkHeader[protocolNumberOffset] == uint8(fakeControlProtocol) {
+ if pkt.NetworkHeader().View()[protocolNumberOffset] == uint8(fakeControlProtocol) {
nb, ok := pkt.Data.PullUp(fakeNetHeaderLen)
if !ok {
return
@@ -118,7 +118,7 @@ func (f *fakeNetworkEndpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuff
}
// Dispatch the packet to the transport protocol.
- f.dispatcher.DeliverTransportPacket(r, tcpip.TransportProtocolNumber(pkt.NetworkHeader[protocolNumberOffset]), pkt)
+ f.dispatcher.DeliverTransportPacket(r, tcpip.TransportProtocolNumber(pkt.NetworkHeader().View()[protocolNumberOffset]), pkt)
}
func (f *fakeNetworkEndpoint) MaxHeaderLength() uint16 {
@@ -143,10 +143,10 @@ func (f *fakeNetworkEndpoint) WritePacket(r *stack.Route, gso *stack.GSO, params
// Add the protocol's header to the packet and send it to the link
// endpoint.
- pkt.NetworkHeader = pkt.Header.Prepend(fakeNetHeaderLen)
- pkt.NetworkHeader[dstAddrOffset] = r.RemoteAddress[0]
- pkt.NetworkHeader[srcAddrOffset] = f.id.LocalAddress[0]
- pkt.NetworkHeader[protocolNumberOffset] = byte(params.Protocol)
+ hdr := pkt.NetworkHeader().Push(fakeNetHeaderLen)
+ hdr[dstAddrOffset] = r.RemoteAddress[0]
+ hdr[srcAddrOffset] = f.id.LocalAddress[0]
+ hdr[protocolNumberOffset] = byte(params.Protocol)
if r.Loop&stack.PacketLoop != 0 {
f.HandlePacket(r, pkt)
@@ -249,12 +249,10 @@ func (*fakeNetworkProtocol) Wait() {}
// Parse implements TransportProtocol.Parse.
func (*fakeNetworkProtocol) Parse(pkt *stack.PacketBuffer) (tcpip.TransportProtocolNumber, bool, bool) {
- hdr, ok := pkt.Data.PullUp(fakeNetHeaderLen)
+ hdr, ok := pkt.NetworkHeader().Consume(fakeNetHeaderLen)
if !ok {
return 0, false, false
}
- pkt.NetworkHeader = hdr
- pkt.Data.TrimFront(fakeNetHeaderLen)
return tcpip.TransportProtocolNumber(hdr[protocolNumberOffset]), true, true
}
@@ -315,9 +313,9 @@ func TestNetworkReceive(t *testing.T) {
// Make sure packet with wrong address is not delivered.
buf[dstAddrOffset] = 3
- ep.InjectInbound(fakeNetNumber, &stack.PacketBuffer{
+ ep.InjectInbound(fakeNetNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
Data: buf.ToVectorisedView(),
- })
+ }))
if fakeNet.packetCount[1] != 0 {
t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 0)
}
@@ -327,9 +325,9 @@ func TestNetworkReceive(t *testing.T) {
// Make sure packet is delivered to first endpoint.
buf[dstAddrOffset] = 1
- ep.InjectInbound(fakeNetNumber, &stack.PacketBuffer{
+ ep.InjectInbound(fakeNetNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
Data: buf.ToVectorisedView(),
- })
+ }))
if fakeNet.packetCount[1] != 1 {
t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 1)
}
@@ -339,9 +337,9 @@ func TestNetworkReceive(t *testing.T) {
// Make sure packet is delivered to second endpoint.
buf[dstAddrOffset] = 2
- ep.InjectInbound(fakeNetNumber, &stack.PacketBuffer{
+ ep.InjectInbound(fakeNetNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
Data: buf.ToVectorisedView(),
- })
+ }))
if fakeNet.packetCount[1] != 1 {
t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 1)
}
@@ -350,9 +348,9 @@ func TestNetworkReceive(t *testing.T) {
}
// Make sure packet is not delivered if protocol number is wrong.
- ep.InjectInbound(fakeNetNumber-1, &stack.PacketBuffer{
+ ep.InjectInbound(fakeNetNumber-1, stack.NewPacketBuffer(stack.PacketBufferOptions{
Data: buf.ToVectorisedView(),
- })
+ }))
if fakeNet.packetCount[1] != 1 {
t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 1)
}
@@ -362,9 +360,9 @@ func TestNetworkReceive(t *testing.T) {
// Make sure packet that is too small is dropped.
buf.CapLength(2)
- ep.InjectInbound(fakeNetNumber, &stack.PacketBuffer{
+ ep.InjectInbound(fakeNetNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
Data: buf.ToVectorisedView(),
- })
+ }))
if fakeNet.packetCount[1] != 1 {
t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 1)
}
@@ -383,11 +381,10 @@ func sendTo(s *stack.Stack, addr tcpip.Address, payload buffer.View) *tcpip.Erro
}
func send(r stack.Route, payload buffer.View) *tcpip.Error {
- hdr := buffer.NewPrependable(int(r.MaxHeaderLength()))
- return r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: fakeTransNumber, TTL: 123, TOS: stack.DefaultTOS}, &stack.PacketBuffer{
- Header: hdr,
- Data: payload.ToVectorisedView(),
- })
+ return r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: fakeTransNumber, TTL: 123, TOS: stack.DefaultTOS}, stack.NewPacketBuffer(stack.PacketBufferOptions{
+ ReserveHeaderBytes: int(r.MaxHeaderLength()),
+ Data: payload.ToVectorisedView(),
+ }))
}
func testSendTo(t *testing.T, s *stack.Stack, addr tcpip.Address, ep *channel.Endpoint, payload buffer.View) {
@@ -442,9 +439,9 @@ func testFailingRecv(t *testing.T, fakeNet *fakeNetworkProtocol, localAddrByte b
func testRecvInternal(t *testing.T, fakeNet *fakeNetworkProtocol, localAddrByte byte, ep *channel.Endpoint, buf buffer.View, want int) {
t.Helper()
- ep.InjectInbound(fakeNetNumber, &stack.PacketBuffer{
+ ep.InjectInbound(fakeNetNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
Data: buf.ToVectorisedView(),
- })
+ }))
if got := fakeNet.PacketCount(localAddrByte); got != want {
t.Errorf("receive packet count: got = %d, want %d", got, want)
}
@@ -2285,9 +2282,9 @@ func TestNICStats(t *testing.T) {
// Send a packet to address 1.
buf := buffer.NewView(30)
- ep1.InjectInbound(fakeNetNumber, &stack.PacketBuffer{
+ ep1.InjectInbound(fakeNetNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
Data: buf.ToVectorisedView(),
- })
+ }))
if got, want := s.NICInfo()[1].Stats.Rx.Packets.Value(), uint64(1); got != want {
t.Errorf("got Rx.Packets.Value() = %d, want = %d", got, want)
}
@@ -2367,9 +2364,9 @@ func TestNICForwarding(t *testing.T) {
// Send a packet to dstAddr.
buf := buffer.NewView(30)
buf[dstAddrOffset] = dstAddr[0]
- ep1.InjectInbound(fakeNetNumber, &stack.PacketBuffer{
+ ep1.InjectInbound(fakeNetNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
Data: buf.ToVectorisedView(),
- })
+ }))
pkt, ok := ep2.Read()
if !ok {
@@ -2377,8 +2374,8 @@ func TestNICForwarding(t *testing.T) {
}
// Test that the link's MaxHeaderLength is honoured.
- if capacity, want := pkt.Pkt.Header.AvailableLength(), int(test.headerLen); capacity != want {
- t.Errorf("got Header.AvailableLength() = %d, want = %d", capacity, want)
+ if capacity, want := pkt.Pkt.AvailableHeaderBytes(), int(test.headerLen); capacity != want {
+ t.Errorf("got LinkHeader.AvailableLength() = %d, want = %d", capacity, want)
}
// Test that forwarding increments Tx stats correctly.
diff --git a/pkg/tcpip/stack/transport_demuxer_test.go b/pkg/tcpip/stack/transport_demuxer_test.go
index 73dada928..1339edc2d 100644
--- a/pkg/tcpip/stack/transport_demuxer_test.go
+++ b/pkg/tcpip/stack/transport_demuxer_test.go
@@ -128,11 +128,10 @@ func (c *testContext) sendV4Packet(payload []byte, h *headers, linkEpID tcpip.NI
u.SetChecksum(^u.CalculateChecksum(xsum))
// Inject packet.
- c.linkEps[linkEpID].InjectInbound(ipv4.ProtocolNumber, &stack.PacketBuffer{
- Data: buf.ToVectorisedView(),
- NetworkHeader: buffer.View(ip),
- TransportHeader: buffer.View(u),
+ pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
+ Data: buf.ToVectorisedView(),
})
+ c.linkEps[linkEpID].InjectInbound(ipv4.ProtocolNumber, pkt)
}
func (c *testContext) sendV6Packet(payload []byte, h *headers, linkEpID tcpip.NICID) {
@@ -166,11 +165,10 @@ func (c *testContext) sendV6Packet(payload []byte, h *headers, linkEpID tcpip.NI
u.SetChecksum(^u.CalculateChecksum(xsum))
// Inject packet.
- c.linkEps[linkEpID].InjectInbound(ipv6.ProtocolNumber, &stack.PacketBuffer{
- Data: buf.ToVectorisedView(),
- NetworkHeader: buffer.View(ip),
- TransportHeader: buffer.View(u),
+ pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
+ Data: buf.ToVectorisedView(),
})
+ c.linkEps[linkEpID].InjectInbound(ipv6.ProtocolNumber, pkt)
}
func TestTransportDemuxerRegister(t *testing.T) {
diff --git a/pkg/tcpip/stack/transport_test.go b/pkg/tcpip/stack/transport_test.go
index 7e8b84867..6c6e44468 100644
--- a/pkg/tcpip/stack/transport_test.go
+++ b/pkg/tcpip/stack/transport_test.go
@@ -84,16 +84,16 @@ func (f *fakeTransportEndpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions
return 0, nil, tcpip.ErrNoRoute
}
- hdr := buffer.NewPrependable(int(f.route.MaxHeaderLength()) + fakeTransHeaderLen)
- hdr.Prepend(fakeTransHeaderLen)
v, err := p.FullPayload()
if err != nil {
return 0, nil, err
}
- if err := f.route.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: fakeTransNumber, TTL: 123, TOS: stack.DefaultTOS}, &stack.PacketBuffer{
- Header: hdr,
- Data: buffer.View(v).ToVectorisedView(),
- }); err != nil {
+ pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
+ ReserveHeaderBytes: int(f.route.MaxHeaderLength()) + fakeTransHeaderLen,
+ Data: buffer.View(v).ToVectorisedView(),
+ })
+ _ = pkt.TransportHeader().Push(fakeTransHeaderLen)
+ if err := f.route.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: fakeTransNumber, TTL: 123, TOS: stack.DefaultTOS}, pkt); err != nil {
return 0, nil, err
}
@@ -328,13 +328,8 @@ func (*fakeTransportProtocol) Wait() {}
// Parse implements TransportProtocol.Parse.
func (*fakeTransportProtocol) Parse(pkt *stack.PacketBuffer) bool {
- hdr, ok := pkt.Data.PullUp(fakeTransHeaderLen)
- if !ok {
- return false
- }
- pkt.TransportHeader = hdr
- pkt.Data.TrimFront(fakeTransHeaderLen)
- return true
+ _, ok := pkt.TransportHeader().Consume(fakeTransHeaderLen)
+ return ok
}
func fakeTransFactory() stack.TransportProtocol {
@@ -382,9 +377,9 @@ func TestTransportReceive(t *testing.T) {
// Make sure packet with wrong protocol is not delivered.
buf[0] = 1
buf[2] = 0
- linkEP.InjectInbound(fakeNetNumber, &stack.PacketBuffer{
+ linkEP.InjectInbound(fakeNetNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
Data: buf.ToVectorisedView(),
- })
+ }))
if fakeTrans.packetCount != 0 {
t.Errorf("packetCount = %d, want %d", fakeTrans.packetCount, 0)
}
@@ -393,9 +388,9 @@ func TestTransportReceive(t *testing.T) {
buf[0] = 1
buf[1] = 3
buf[2] = byte(fakeTransNumber)
- linkEP.InjectInbound(fakeNetNumber, &stack.PacketBuffer{
+ linkEP.InjectInbound(fakeNetNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
Data: buf.ToVectorisedView(),
- })
+ }))
if fakeTrans.packetCount != 0 {
t.Errorf("packetCount = %d, want %d", fakeTrans.packetCount, 0)
}
@@ -404,9 +399,9 @@ func TestTransportReceive(t *testing.T) {
buf[0] = 1
buf[1] = 2
buf[2] = byte(fakeTransNumber)
- linkEP.InjectInbound(fakeNetNumber, &stack.PacketBuffer{
+ linkEP.InjectInbound(fakeNetNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
Data: buf.ToVectorisedView(),
- })
+ }))
if fakeTrans.packetCount != 1 {
t.Errorf("packetCount = %d, want %d", fakeTrans.packetCount, 1)
}
@@ -459,9 +454,9 @@ func TestTransportControlReceive(t *testing.T) {
buf[fakeNetHeaderLen+0] = 0
buf[fakeNetHeaderLen+1] = 1
buf[fakeNetHeaderLen+2] = 0
- linkEP.InjectInbound(fakeNetNumber, &stack.PacketBuffer{
+ linkEP.InjectInbound(fakeNetNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
Data: buf.ToVectorisedView(),
- })
+ }))
if fakeTrans.controlCount != 0 {
t.Errorf("controlCount = %d, want %d", fakeTrans.controlCount, 0)
}
@@ -470,9 +465,9 @@ func TestTransportControlReceive(t *testing.T) {
buf[fakeNetHeaderLen+0] = 3
buf[fakeNetHeaderLen+1] = 1
buf[fakeNetHeaderLen+2] = byte(fakeTransNumber)
- linkEP.InjectInbound(fakeNetNumber, &stack.PacketBuffer{
+ linkEP.InjectInbound(fakeNetNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
Data: buf.ToVectorisedView(),
- })
+ }))
if fakeTrans.controlCount != 0 {
t.Errorf("controlCount = %d, want %d", fakeTrans.controlCount, 0)
}
@@ -481,9 +476,9 @@ func TestTransportControlReceive(t *testing.T) {
buf[fakeNetHeaderLen+0] = 2
buf[fakeNetHeaderLen+1] = 1
buf[fakeNetHeaderLen+2] = byte(fakeTransNumber)
- linkEP.InjectInbound(fakeNetNumber, &stack.PacketBuffer{
+ linkEP.InjectInbound(fakeNetNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
Data: buf.ToVectorisedView(),
- })
+ }))
if fakeTrans.controlCount != 1 {
t.Errorf("controlCount = %d, want %d", fakeTrans.controlCount, 1)
}
@@ -636,9 +631,9 @@ func TestTransportForwarding(t *testing.T) {
req[0] = 1
req[1] = 3
req[2] = byte(fakeTransNumber)
- ep2.InjectInbound(fakeNetNumber, &stack.PacketBuffer{
+ ep2.InjectInbound(fakeNetNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
Data: req.ToVectorisedView(),
- })
+ }))
aep, _, err := ep.Accept()
if err != nil || aep == nil {
@@ -655,10 +650,11 @@ func TestTransportForwarding(t *testing.T) {
t.Fatal("Response packet not forwarded")
}
- if dst := p.Pkt.NetworkHeader[0]; dst != 3 {
+ nh := stack.PayloadSince(p.Pkt.NetworkHeader())
+ if dst := nh[0]; dst != 3 {
t.Errorf("Response packet has incorrect destination addresss: got = %d, want = 3", dst)
}
- if src := p.Pkt.NetworkHeader[1]; src != 1 {
+ if src := nh[1]; src != 1 {
t.Errorf("Response packet has incorrect source addresss: got = %d, want = 3", src)
}
}