summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorArthur Sfez <asfez@google.com>2020-10-06 13:55:02 -0700
committergVisor bot <gvisor-bot@google.com>2020-10-06 14:03:39 -0700
commit99bf022c2aeff35e48d9201406f85f501405c083 (patch)
tree45c122790e181764524705a68ca2d0ab01ae7a63
parentb761330ca7bd5e956e3b90b4202ff0c3ea892750 (diff)
Add support for IPv6 fragmentation
Most of the IPv4 fragmentation code was moved in the fragmentation package and it is reused by IPv6 fragmentation. Test: - pkg/tcpip/network/ipv4:ipv4_test - pkg/tcpip/network/ipv6:ipv6_test - pkg/tcpip/network/fragmentation:fragmentation_test Fixes #4389 PiperOrigin-RevId: 335714280
-rw-r--r--pkg/tcpip/network/fragmentation/BUILD4
-rw-r--r--pkg/tcpip/network/fragmentation/fragmentation.go78
-rw-r--r--pkg/tcpip/network/fragmentation/fragmentation_test.go112
-rw-r--r--pkg/tcpip/network/ipv4/ipv4.go168
-rw-r--r--pkg/tcpip/network/ipv4/ipv4_test.go78
-rw-r--r--pkg/tcpip/network/ipv6/BUILD2
-rw-r--r--pkg/tcpip/network/ipv6/icmp_test.go4
-rw-r--r--pkg/tcpip/network/ipv6/ipv6.go160
-rw-r--r--pkg/tcpip/network/ipv6/ipv6_test.go386
-rw-r--r--pkg/tcpip/network/testutil/BUILD1
-rw-r--r--test/packetimpact/tests/BUILD2
11 files changed, 836 insertions, 159 deletions
diff --git a/pkg/tcpip/network/fragmentation/BUILD b/pkg/tcpip/network/fragmentation/BUILD
index e247f06a4..47fb63290 100644
--- a/pkg/tcpip/network/fragmentation/BUILD
+++ b/pkg/tcpip/network/fragmentation/BUILD
@@ -29,6 +29,8 @@ go_library(
"//pkg/sync",
"//pkg/tcpip",
"//pkg/tcpip/buffer",
+ "//pkg/tcpip/header",
+ "//pkg/tcpip/stack",
],
)
@@ -44,5 +46,7 @@ go_test(
deps = [
"//pkg/tcpip/buffer",
"//pkg/tcpip/faketime",
+ "//pkg/tcpip/network/testutil",
+ "@com_github_google_go_cmp//cmp:go_default_library",
],
)
diff --git a/pkg/tcpip/network/fragmentation/fragmentation.go b/pkg/tcpip/network/fragmentation/fragmentation.go
index e1909fab0..888ad62a3 100644
--- a/pkg/tcpip/network/fragmentation/fragmentation.go
+++ b/pkg/tcpip/network/fragmentation/fragmentation.go
@@ -13,7 +13,7 @@
// limitations under the License.
// Package fragmentation contains the implementation of IP fragmentation.
-// It is based on RFC 791 and RFC 815.
+// It is based on RFC 791, RFC 815 and RFC 8200.
package fragmentation
import (
@@ -25,6 +25,7 @@ import (
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
)
const (
@@ -243,3 +244,78 @@ func (f *Fragmentation) releaseReassemblersLocked() {
f.release(r)
}
}
+
+// PacketFragmenter is the book-keeping struct for packet fragmentation.
+type PacketFragmenter struct {
+ transportHeader buffer.View
+ data buffer.VectorisedView
+ reserve int
+ innerMTU int
+ fragmentCount int
+ currentFragment int
+ fragmentOffset int
+}
+
+// MakePacketFragmenter prepares the struct needed for packet fragmentation.
+//
+// pkt is the packet to be fragmented.
+//
+// innerMTU is the maximum number of bytes of fragmentable data a fragment can
+// have.
+//
+// reserve is the number of bytes that should be reserved for the headers in
+// each generated fragment.
+func MakePacketFragmenter(pkt *stack.PacketBuffer, innerMTU int, reserve int) PacketFragmenter {
+ // As per RFC 8200 Section 4.5, some IPv6 extension headers should not be
+ // repeated in each fragment. However we do not currently support any header
+ // of that kind yet, so the following computation is valid for both IPv4 and
+ // IPv6.
+ // TODO(gvisor.dev/issue/3912): Once Authentication or ESP Headers are
+ // supported for outbound packets, the fragmentable data should not include
+ // these headers.
+ var fragmentableData buffer.VectorisedView
+ fragmentableData.AppendView(pkt.TransportHeader().View())
+ fragmentableData.Append(pkt.Data)
+ fragmentCount := (fragmentableData.Size() + innerMTU - 1) / innerMTU
+
+ return PacketFragmenter{
+ data: fragmentableData,
+ reserve: reserve,
+ innerMTU: innerMTU,
+ fragmentCount: fragmentCount,
+ }
+}
+
+// BuildNextFragment returns a packet with the payload of the next fragment,
+// along with the fragment's offset, the number of bytes copied and a boolean
+// indicating if there are more fragments left or not. If this function is
+// called again after it indicated that no more fragments were left, it will
+// panic.
+//
+// Note that the returned packet will not have its network and link headers
+// populated, but space for them will be reserved. The transport header will be
+// stored in the packet's data.
+func (pf *PacketFragmenter) BuildNextFragment() (*stack.PacketBuffer, int, int, bool) {
+ if pf.currentFragment >= pf.fragmentCount {
+ panic("BuildNextFragment should not be called again after the last fragment was returned")
+ }
+
+ fragPkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
+ ReserveHeaderBytes: pf.reserve,
+ })
+
+ // Copy data for the fragment.
+ copied := pf.data.ReadToVV(&fragPkt.Data, pf.innerMTU)
+
+ offset := pf.fragmentOffset
+ pf.fragmentOffset += copied
+ pf.currentFragment++
+ more := pf.currentFragment != pf.fragmentCount
+
+ return fragPkt, offset, copied, more
+}
+
+// RemainingFragmentCount returns the number of fragments left to be built.
+func (pf *PacketFragmenter) RemainingFragmentCount() int {
+ return pf.fragmentCount - pf.currentFragment
+}
diff --git a/pkg/tcpip/network/fragmentation/fragmentation_test.go b/pkg/tcpip/network/fragmentation/fragmentation_test.go
index 189b223c5..31a1eb862 100644
--- a/pkg/tcpip/network/fragmentation/fragmentation_test.go
+++ b/pkg/tcpip/network/fragmentation/fragmentation_test.go
@@ -20,8 +20,10 @@ import (
"testing"
"time"
+ "github.com/google/go-cmp/cmp"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/faketime"
+ "gvisor.dev/gvisor/pkg/tcpip/network/testutil"
)
// vv is a helper to build VectorisedView from different strings.
@@ -381,3 +383,113 @@ func TestErrors(t *testing.T) {
})
}
}
+
+type fragmentInfo struct {
+ remaining int
+ copied int
+ offset int
+ more bool
+}
+
+func TestPacketFragmenter(t *testing.T) {
+ const (
+ reserve = 60
+ proto = 0
+ )
+
+ tests := []struct {
+ name string
+ innerMTU int
+ transportHeaderLen int
+ payloadSize int
+ wantFragments []fragmentInfo
+ }{
+ {
+ name: "Packet exactly fits in MTU",
+ innerMTU: 1280,
+ transportHeaderLen: 0,
+ payloadSize: 1280,
+ wantFragments: []fragmentInfo{
+ {remaining: 0, copied: 1280, offset: 0, more: false},
+ },
+ },
+ {
+ name: "Packet exactly does not fit in MTU",
+ innerMTU: 1000,
+ transportHeaderLen: 0,
+ payloadSize: 1001,
+ wantFragments: []fragmentInfo{
+ {remaining: 1, copied: 1000, offset: 0, more: true},
+ {remaining: 0, copied: 1, offset: 1000, more: false},
+ },
+ },
+ {
+ name: "Packet has a transport header",
+ innerMTU: 560,
+ transportHeaderLen: 40,
+ payloadSize: 560,
+ wantFragments: []fragmentInfo{
+ {remaining: 1, copied: 560, offset: 0, more: true},
+ {remaining: 0, copied: 40, offset: 560, more: false},
+ },
+ },
+ {
+ name: "Packet has a huge transport header",
+ innerMTU: 500,
+ transportHeaderLen: 1300,
+ payloadSize: 500,
+ wantFragments: []fragmentInfo{
+ {remaining: 3, copied: 500, offset: 0, more: true},
+ {remaining: 2, copied: 500, offset: 500, more: true},
+ {remaining: 1, copied: 500, offset: 1000, more: true},
+ {remaining: 0, copied: 300, offset: 1500, more: false},
+ },
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ pkt := testutil.MakeRandPkt(test.transportHeaderLen, reserve, []int{test.payloadSize}, proto)
+ var originalPayload buffer.VectorisedView
+ originalPayload.AppendView(pkt.TransportHeader().View())
+ originalPayload.Append(pkt.Data)
+ var reassembledPayload buffer.VectorisedView
+ pf := MakePacketFragmenter(pkt, test.innerMTU, reserve)
+ for i := 0; ; i++ {
+ fragPkt, offset, copied, more := pf.BuildNextFragment()
+ wantFragment := test.wantFragments[i]
+ if got := pf.RemainingFragmentCount(); got != wantFragment.remaining {
+ t.Errorf("(fragment #%d) got pf.RemainingFragmentCount() = %d, want = %d", i, got, wantFragment.remaining)
+ }
+ if copied != wantFragment.copied {
+ t.Errorf("(fragment #%d) got copied = %d, want = %d", i, copied, wantFragment.copied)
+ }
+ if offset != wantFragment.offset {
+ t.Errorf("(fragment #%d) got offset = %d, want = %d", i, offset, wantFragment.offset)
+ }
+ if more != wantFragment.more {
+ t.Errorf("(fragment #%d) got more = %t, want = %t", i, more, wantFragment.more)
+ }
+ if got := fragPkt.Size(); got > test.innerMTU {
+ t.Errorf("(fragment #%d) got fragPkt.Size() = %d, want <= %d", i, got, test.innerMTU)
+ }
+ if got := fragPkt.AvailableHeaderBytes(); got != reserve {
+ t.Errorf("(fragment #%d) got fragPkt.AvailableHeaderBytes() = %d, want = %d", i, got, reserve)
+ }
+ if got := fragPkt.TransportHeader().View().Size(); got != 0 {
+ t.Errorf("(fragment #%d) got fragPkt.TransportHeader().View().Size() = %d, want = 0", i, got)
+ }
+ reassembledPayload.Append(fragPkt.Data)
+ if !more {
+ if i != len(test.wantFragments)-1 {
+ t.Errorf("got fragment count = %d, want = %d", i, len(test.wantFragments)-1)
+ }
+ break
+ }
+ }
+ if diff := cmp.Diff(reassembledPayload.ToView(), originalPayload.ToView()); diff != "" {
+ t.Errorf("reassembledPayload mismatch (-want +got):\n%s", diff)
+ }
+ })
+ }
+}
diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go
index a2be64fb8..79c939129 100644
--- a/pkg/tcpip/network/ipv4/ipv4.go
+++ b/pkg/tcpip/network/ipv4/ipv4.go
@@ -190,99 +190,26 @@ func (e *endpoint) NetworkProtocolNumber() tcpip.NetworkProtocolNumber {
return e.protocol.Number()
}
-// writePacketFragments calls e.linkEP.WritePacket with each packet fragment to
-// write. It assumes that the IP header is already present in pkt.NetworkHeader.
-// pkt.TransportHeader may be set. mtu includes the IP header and options. This
-// does not support the DontFragment IP flag.
-func (e *endpoint) writePacketFragments(r *stack.Route, gso *stack.GSO, mtu int, pkt *stack.PacketBuffer) *tcpip.Error {
- // This packet is too big, it needs to be fragmented.
- ip := header.IPv4(pkt.NetworkHeader().View())
- flags := ip.Flags()
-
- // Update mtu to take into account the header, which will exist in all
- // fragments anyway.
- innerMTU := mtu - int(ip.HeaderLength())
-
- // Round the MTU down to align to 8 bytes. Then calculate the number of
- // fragments. Calculate fragment sizes as in RFC791.
- innerMTU &^= 7
- n := (int(ip.PayloadLength()) + innerMTU - 1) / innerMTU
-
- outerMTU := innerMTU + int(ip.HeaderLength())
- offset := ip.FragmentOffset()
-
- // Keep the length reserved for link-layer, we need to create fragments with
- // the same reserved length.
- reservedForLink := pkt.AvailableHeaderBytes()
-
- // Destroy the packet, pull all payloads out for fragmentation.
- transHeader, data := pkt.TransportHeader().View(), pkt.Data
-
- // Where possible, the first fragment that is sent has the same
- // number of bytes reserved for header as the input packet. The link-layer
- // endpoint may depend on this for looking at, eg, L4 headers.
- transFitsFirst := len(transHeader) <= innerMTU
-
- for i := 0; i < n; i++ {
- reserve := reservedForLink + int(ip.HeaderLength())
- if i == 0 && transFitsFirst {
- // Reserve for transport header if it's going to be put in the first
- // fragment.
- reserve += len(transHeader)
- }
- fragPkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
- ReserveHeaderBytes: reserve,
- })
- fragPkt.NetworkProtocolNumber = header.IPv4ProtocolNumber
-
- // Copy data for the fragment.
- avail := innerMTU
-
- if n := len(transHeader); n > 0 {
- if n > avail {
- n = avail
- }
- if i == 0 && transFitsFirst {
- copy(fragPkt.TransportHeader().Push(n), transHeader)
- } else {
- fragPkt.Data.AppendView(transHeader[:n:n])
- }
- transHeader = transHeader[n:]
- avail -= n
- }
-
- if avail > 0 {
- n := data.Size()
- if n > avail {
- n = avail
- }
- data.ReadToVV(&fragPkt.Data, n)
- avail -= n
- }
-
- copied := uint16(innerMTU - avail)
-
- // Set lengths in header and calculate checksum.
- h := header.IPv4(fragPkt.NetworkHeader().Push(len(ip)))
- copy(h, ip)
- if i != n-1 {
- h.SetTotalLength(uint16(outerMTU))
- h.SetFlagsFragmentOffset(flags|header.IPv4FlagMoreFragments, offset)
- } else {
- h.SetTotalLength(uint16(h.HeaderLength()) + copied)
- h.SetFlagsFragmentOffset(flags, offset)
- }
- h.SetChecksum(0)
- h.SetChecksum(^h.CalculateChecksum())
- offset += copied
-
- // Send out the fragment.
+// writePacketFragments fragments pkt and writes the results on the link
+// endpoint. The IP header must already present in the original packet. The mtu
+// is the maximum size of the packets.
+func (e *endpoint) writePacketFragments(r *stack.Route, gso *stack.GSO, mtu uint32, pkt *stack.PacketBuffer) *tcpip.Error {
+ networkHeader := header.IPv4(pkt.NetworkHeader().View())
+ fragMTU := int(calculateFragmentInnerMTU(mtu, pkt))
+ pf := fragmentation.MakePacketFragmenter(pkt, fragMTU, pkt.AvailableHeaderBytes()+len(networkHeader))
+
+ for {
+ fragPkt, more := buildNextFragment(&pf, networkHeader)
if err := e.linkEP.WritePacket(r, gso, ProtocolNumber, fragPkt); err != nil {
- r.Stats().IP.OutgoingPacketErrors.IncrementBy(uint64(n - i))
+ r.Stats().IP.OutgoingPacketErrors.IncrementBy(uint64(pf.RemainingFragmentCount() + 1))
return err
}
r.Stats().IP.PacketsSent.Increment()
+ if !more {
+ break
+ }
}
+
return nil
}
@@ -304,7 +231,7 @@ func (e *endpoint) addIPHeader(r *stack.Route, pkt *stack.PacketBuffer, params s
DstAddr: r.RemoteAddress,
})
ip.SetChecksum(^ip.CalculateChecksum())
- pkt.NetworkProtocolNumber = header.IPv4ProtocolNumber
+ pkt.NetworkProtocolNumber = ProtocolNumber
}
// WritePacket writes a packet to the given destination address and protocol.
@@ -330,7 +257,7 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.Netw
// short circuits broadcasts before they are sent out to other hosts.
if pkt.NatDone {
netHeader := header.IPv4(pkt.NetworkHeader().View())
- ep, err := e.protocol.stack.FindNetworkEndpoint(header.IPv4ProtocolNumber, netHeader.DestinationAddress())
+ ep, err := e.protocol.stack.FindNetworkEndpoint(ProtocolNumber, netHeader.DestinationAddress())
if err == nil {
route := r.ReverseRoute(netHeader.SourceAddress(), netHeader.DestinationAddress())
ep.HandlePacket(&route, pkt)
@@ -347,7 +274,7 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.Netw
return nil
}
if pkt.Size() > int(e.linkEP.MTU()) && (gso == nil || gso.Type == stack.GSONone) {
- return e.writePacketFragments(r, gso, int(e.linkEP.MTU()), pkt)
+ return e.writePacketFragments(r, gso, e.linkEP.MTU(), pkt)
}
if err := e.linkEP.WritePacket(r, gso, ProtocolNumber, pkt); err != nil {
r.Stats().IP.OutgoingPacketErrors.Increment()
@@ -397,7 +324,7 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe
}
if _, ok := natPkts[pkt]; ok {
netHeader := header.IPv4(pkt.NetworkHeader().View())
- if ep, err := e.protocol.stack.FindNetworkEndpoint(header.IPv4ProtocolNumber, netHeader.DestinationAddress()); err == nil {
+ if ep, err := e.protocol.stack.FindNetworkEndpoint(ProtocolNumber, netHeader.DestinationAddress()); err == nil {
src := netHeader.SourceAddress()
dst := netHeader.DestinationAddress()
route := r.ReverseRoute(src, dst)
@@ -809,14 +736,36 @@ func calculateMTU(mtu uint32) uint32 {
return mtu - header.IPv4MinimumSize
}
+// calculateFragmentInnerMTU calculates the maximum number of bytes of
+// fragmentable data a fragment can have, based on the link layer mtu and pkt's
+// network header size.
+func calculateFragmentInnerMTU(mtu uint32, pkt *stack.PacketBuffer) uint32 {
+ if mtu > MaxTotalSize {
+ mtu = MaxTotalSize
+ }
+ mtu -= uint32(pkt.NetworkHeader().View().Size())
+ // Round the MTU down to align to 8 bytes.
+ mtu &^= 7
+ return mtu
+}
+
+// addressToUint32 translates an IPv4 address into its little endian uint32
+// representation.
+//
+// This function does the same thing as binary.LittleEndian.Uint32 but operates
+// on a tcpip.Address (a string) without the need to convert it to a byte slice,
+// which would cause an allocation.
+func addressToUint32(addr tcpip.Address) uint32 {
+ _ = addr[3] // bounds check hint to compiler
+ return uint32(addr[0]) | uint32(addr[1])<<8 | uint32(addr[2])<<16 | uint32(addr[3])<<24
+}
+
// hashRoute calculates a hash value for the given route. It uses the source &
-// destination address, the transport protocol number, and a random initial
-// value (generated once on initialization) to generate the hash.
+// destination address, the transport protocol number and a 32-bit number to
+// generate the hash.
func hashRoute(r *stack.Route, protocol tcpip.TransportProtocolNumber, hashIV uint32) uint32 {
- t := r.LocalAddress
- a := uint32(t[0]) | uint32(t[1])<<8 | uint32(t[2])<<16 | uint32(t[3])<<24
- t = r.RemoteAddress
- b := uint32(t[0]) | uint32(t[1])<<8 | uint32(t[2])<<16 | uint32(t[3])<<24
+ a := addressToUint32(r.LocalAddress)
+ b := addressToUint32(r.RemoteAddress)
return hash.Hash3Words(a, b, uint32(protocol), hashIV)
}
@@ -839,3 +788,26 @@ func NewProtocol(s *stack.Stack) stack.NetworkProtocol {
fragmentation: fragmentation.NewFragmentation(fragmentblockSize, fragmentation.HighFragThreshold, fragmentation.LowFragThreshold, fragmentation.DefaultReassembleTimeout, s.Clock()),
}
}
+
+func buildNextFragment(pf *fragmentation.PacketFragmenter, originalIPHeader header.IPv4) (*stack.PacketBuffer, bool) {
+ fragPkt, offset, copied, more := pf.BuildNextFragment()
+ fragPkt.NetworkProtocolNumber = ProtocolNumber
+
+ originalIPHeaderLength := len(originalIPHeader)
+ nextFragIPHeader := header.IPv4(fragPkt.NetworkHeader().Push(originalIPHeaderLength))
+
+ if copied := copy(nextFragIPHeader, originalIPHeader); copied != len(originalIPHeader) {
+ panic(fmt.Sprintf("wrong number of bytes copied into fragmentIPHeaders: got = %d, want = %d", copied, originalIPHeaderLength))
+ }
+
+ flags := originalIPHeader.Flags()
+ if more {
+ flags |= header.IPv4FlagMoreFragments
+ }
+ nextFragIPHeader.SetFlagsFragmentOffset(flags, uint16(offset))
+ nextFragIPHeader.SetTotalLength(uint16(nextFragIPHeader.HeaderLength()) + uint16(copied))
+ nextFragIPHeader.SetChecksum(0)
+ nextFragIPHeader.SetChecksum(^nextFragIPHeader.CalculateChecksum())
+
+ return fragPkt, more
+}
diff --git a/pkg/tcpip/network/ipv4/ipv4_test.go b/pkg/tcpip/network/ipv4/ipv4_test.go
index 712fbb861..f250a3cde 100644
--- a/pkg/tcpip/network/ipv4/ipv4_test.go
+++ b/pkg/tcpip/network/ipv4/ipv4_test.go
@@ -396,16 +396,6 @@ func compareFragments(t *testing.T, packets []*stack.PacketBuffer, sourcePacketI
if got, want := len(ip), int(mtu); got > want {
t.Errorf("fragment is too large, got %d want %d", got, want)
}
- if i == 0 {
- got := packet.NetworkHeader().View().Size() + packet.TransportHeader().View().Size()
- // sourcePacketInfo does not have NetworkHeader added, simulate one.
- want := header.IPv4MinimumSize + sourcePacketInfo.TransportHeader().View().Size()
- // Check that it kept the transport header in packet.TransportHeader if
- // it fits in the first fragment.
- if want < int(mtu) && got != want {
- t.Errorf("first fragment hdr parts should have unmodified length if possible: got %d, want %d", got, want)
- }
- }
if got, want := packet.AvailableHeaderBytes(), sourcePacketInfo.AvailableHeaderBytes()-header.IPv4MinimumSize; got != want {
t.Errorf("fragment #%d should have the same available space for prepending as source: got %d, want %d", i, got, want)
}
@@ -435,6 +425,8 @@ func compareFragments(t *testing.T, packets []*stack.PacketBuffer, sourcePacketI
}
func TestFragmentation(t *testing.T) {
+ const ttl = 42
+
var manyPayloadViewsSizes [1000]int
for i := range manyPayloadViewsSizes {
manyPayloadViewsSizes[i] = 7
@@ -448,15 +440,15 @@ func TestFragmentation(t *testing.T) {
payloadViewsSizes []int
expectedFrags int
}{
- {"NoFragmentation", 2000, &stack.GSO{}, 0, header.IPv4MinimumSize, []int{1000}, 1},
- {"NoFragmentationWithBigHeader", 2000, &stack.GSO{}, 16, header.IPv4MinimumSize, []int{1000}, 1},
+ {"No fragmentation", 2000, &stack.GSO{}, 0, header.IPv4MinimumSize, []int{1000}, 1},
+ {"No fragmentation with big header", 2000, &stack.GSO{}, 16, header.IPv4MinimumSize, []int{1000}, 1},
{"Fragmented", 800, &stack.GSO{}, 0, header.IPv4MinimumSize, []int{1000}, 2},
- {"FragmentedWithGsoNil", 800, nil, 0, header.IPv4MinimumSize, []int{1000}, 2},
- {"FragmentedWithManyViews", 300, &stack.GSO{}, 0, header.IPv4MinimumSize, manyPayloadViewsSizes[:], 25},
- {"FragmentedWithManyViewsAndPrependableBytes", 300, &stack.GSO{}, 0, header.IPv4MinimumSize + 55, manyPayloadViewsSizes[:], 25},
- {"FragmentedWithBigHeader", 800, &stack.GSO{}, 20, header.IPv4MinimumSize, []int{1000}, 2},
- {"FragmentedWithBigHeaderAndPrependableBytes", 800, &stack.GSO{}, 20, header.IPv4MinimumSize + 66, []int{1000}, 2},
- {"FragmentedWithMTUSmallerThanHeaderAndPrependableBytes", 300, &stack.GSO{}, 1000, header.IPv4MinimumSize + 77, []int{500}, 6},
+ {"Fragmented with gso nil", 800, nil, 0, header.IPv4MinimumSize, []int{1000}, 2},
+ {"Fragmented with many views", 300, &stack.GSO{}, 0, header.IPv4MinimumSize, manyPayloadViewsSizes[:], 25},
+ {"Fragmented with many views and prependable bytes", 300, &stack.GSO{}, 0, header.IPv4MinimumSize + 55, manyPayloadViewsSizes[:], 25},
+ {"Fragmented with big header", 800, &stack.GSO{}, 20, header.IPv4MinimumSize, []int{1000}, 2},
+ {"Fragmented with big header and prependable bytes", 800, &stack.GSO{}, 20, header.IPv4MinimumSize + 66, []int{1000}, 2},
+ {"Fragmented with MTU smaller than header and prependable bytes", 300, &stack.GSO{}, 1000, header.IPv4MinimumSize + 77, []int{500}, 6},
}
for _, ft := range fragTests {
@@ -467,11 +459,11 @@ func TestFragmentation(t *testing.T) {
source := pkt.Clone()
err := r.WritePacket(ft.gso, stack.NetworkHeaderParams{
Protocol: tcp.ProtocolNumber,
- TTL: 42,
+ TTL: ttl,
TOS: stack.DefaultTOS,
}, pkt)
if err != nil {
- t.Errorf("got err = %s, want = nil", err)
+ t.Fatalf("r.WritePacket(_, _, _) = %s", err)
}
if got := len(ep.WrittenPackets); got != ft.expectedFrags {
@@ -491,48 +483,46 @@ func TestFragmentation(t *testing.T) {
// TestFragmentationErrors checks that errors are returned from write packet
// correctly.
func TestFragmentationErrors(t *testing.T) {
+ const ttl = 42
+
+ expectedError := tcpip.ErrAborted
fragTests := []struct {
description string
mtu uint32
transportHeaderLength int
- payloadViewsSizes []int
- err *tcpip.Error
+ payloadSize int
allowPackets int
fragmentCount int
}{
{
- description: "NoFrag",
+ description: "No frag",
mtu: 2000,
transportHeaderLength: 0,
- payloadViewsSizes: []int{1000},
- err: tcpip.ErrAborted,
+ payloadSize: 1000,
allowPackets: 0,
fragmentCount: 1,
},
{
- description: "ErrorOnFirstFrag",
+ description: "Error on first frag",
mtu: 500,
transportHeaderLength: 0,
- payloadViewsSizes: []int{1000},
- err: tcpip.ErrAborted,
+ payloadSize: 1000,
allowPackets: 0,
fragmentCount: 3,
},
{
- description: "ErrorOnSecondFrag",
+ description: "Error on second frag",
mtu: 500,
transportHeaderLength: 0,
- payloadViewsSizes: []int{1000},
- err: tcpip.ErrAborted,
+ payloadSize: 1000,
allowPackets: 1,
fragmentCount: 3,
},
{
- description: "ErrorOnFirstFragMTUSmallerThanHeader",
+ description: "Error on first frag MTU smaller than header",
mtu: 500,
transportHeaderLength: 1000,
- payloadViewsSizes: []int{500},
- err: tcpip.ErrAborted,
+ payloadSize: 500,
allowPackets: 0,
fragmentCount: 4,
},
@@ -540,16 +530,16 @@ func TestFragmentationErrors(t *testing.T) {
for _, ft := range fragTests {
t.Run(ft.description, func(t *testing.T) {
- ep := testutil.NewMockLinkEndpoint(ft.mtu, ft.err, ft.allowPackets)
+ ep := testutil.NewMockLinkEndpoint(ft.mtu, expectedError, ft.allowPackets)
r := buildRoute(t, ep)
- pkt := testutil.MakeRandPkt(ft.transportHeaderLength, header.IPv4MinimumSize, ft.payloadViewsSizes, header.IPv4ProtocolNumber)
+ pkt := testutil.MakeRandPkt(ft.transportHeaderLength, header.IPv4MinimumSize, []int{ft.payloadSize}, header.IPv4ProtocolNumber)
err := r.WritePacket(&stack.GSO{}, stack.NetworkHeaderParams{
Protocol: tcp.ProtocolNumber,
- TTL: 42,
+ TTL: ttl,
TOS: stack.DefaultTOS,
}, pkt)
- if err != ft.err {
- t.Errorf("got WritePacket() = %s, want = %s", err, ft.err)
+ if err != expectedError {
+ t.Errorf("got WritePacket() = %s, want = %s", err, expectedError)
}
if got, want := len(ep.WrittenPackets), int(r.Stats().IP.PacketsSent.Value()); err != nil && got != want {
t.Errorf("got len(ep.WrittenPackets) = %d, want = %d", got, want)
@@ -1317,6 +1307,7 @@ func TestReceiveFragments(t *testing.T) {
func TestWriteStats(t *testing.T) {
const nPackets = 3
+
tests := []struct {
name string
setup func(*testing.T, *stack.Stack)
@@ -1462,12 +1453,13 @@ func buildRoute(t *testing.T, ep stack.LinkEndpoint) stack.Route {
dst = "\x10\x00\x00\x02"
)
if err := s.AddAddress(1, ipv4.ProtocolNumber, src); err != nil {
- t.Fatalf("AddAddress(1, %d, _) failed: %s", ipv4.ProtocolNumber, err)
+ t.Fatalf("AddAddress(1, %d, %s) failed: %s", ipv4.ProtocolNumber, src, err)
}
{
- subnet, err := tcpip.NewSubnet(dst, tcpip.AddressMask(header.IPv4Broadcast))
+ mask := tcpip.AddressMask(header.IPv4Broadcast)
+ subnet, err := tcpip.NewSubnet(dst, mask)
if err != nil {
- t.Fatalf("NewSubnet(_, _) failed: %v", err)
+ t.Fatalf("NewSubnet(%s, %s) failed: %v", dst, mask, err)
}
s.SetRouteTable([]tcpip.Route{{
Destination: subnet,
@@ -1476,7 +1468,7 @@ func buildRoute(t *testing.T, ep stack.LinkEndpoint) stack.Route {
}
rt, err := s.FindRoute(1, src, dst, ipv4.ProtocolNumber, false /* multicastLoop */)
if err != nil {
- t.Fatalf("got FindRoute(1, _, _, %d, false) = %s, want = nil", ipv4.ProtocolNumber, err)
+ t.Fatalf("FindRoute(1, %s, %s, %d, false) = %s", src, dst, ipv4.ProtocolNumber, err)
}
return rt
}
diff --git a/pkg/tcpip/network/ipv6/BUILD b/pkg/tcpip/network/ipv6/BUILD
index 97adbcbd4..a30437f02 100644
--- a/pkg/tcpip/network/ipv6/BUILD
+++ b/pkg/tcpip/network/ipv6/BUILD
@@ -18,6 +18,7 @@ go_library(
"//pkg/tcpip/header",
"//pkg/tcpip/header/parse",
"//pkg/tcpip/network/fragmentation",
+ "//pkg/tcpip/network/hash",
"//pkg/tcpip/stack",
],
)
@@ -41,6 +42,7 @@ go_test(
"//pkg/tcpip/network/testutil",
"//pkg/tcpip/stack",
"//pkg/tcpip/transport/icmp",
+ "//pkg/tcpip/transport/tcp",
"//pkg/tcpip/transport/udp",
"//pkg/waiter",
"@com_github_google_go_cmp//cmp:go_default_library",
diff --git a/pkg/tcpip/network/ipv6/icmp_test.go b/pkg/tcpip/network/ipv6/icmp_test.go
index 31370c1d4..3a40b790e 100644
--- a/pkg/tcpip/network/ipv6/icmp_test.go
+++ b/pkg/tcpip/network/ipv6/icmp_test.go
@@ -50,6 +50,10 @@ type stubLinkEndpoint struct {
stack.LinkEndpoint
}
+func (*stubLinkEndpoint) MTU() uint32 {
+ return defaultMTU
+}
+
func (*stubLinkEndpoint) Capabilities() stack.LinkEndpointCapabilities {
// Indicate that resolution for link layer addresses is required to send
// packets over this link. This is needed so the NIC knows to allocate a
diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go
index c8a3e0b34..73e50f8d6 100644
--- a/pkg/tcpip/network/ipv6/ipv6.go
+++ b/pkg/tcpip/network/ipv6/ipv6.go
@@ -16,7 +16,9 @@
package ipv6
import (
+ "encoding/binary"
"fmt"
+ "hash/fnv"
"sort"
"sync/atomic"
@@ -26,6 +28,7 @@ import (
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/header/parse"
"gvisor.dev/gvisor/pkg/tcpip/network/fragmentation"
+ "gvisor.dev/gvisor/pkg/tcpip/network/hash"
"gvisor.dev/gvisor/pkg/tcpip/stack"
)
@@ -40,6 +43,9 @@ const (
// DefaultTTL is the default hop limit for IPv6 Packets egressed by
// Netstack.
DefaultTTL = 64
+
+ // buckets for fragment identifiers
+ buckets = 2048
)
var _ stack.GroupAddressableEndpoint = (*endpoint)(nil)
@@ -376,7 +382,44 @@ func (e *endpoint) addIPHeader(r *stack.Route, pkt *stack.PacketBuffer, params s
SrcAddr: r.LocalAddress,
DstAddr: r.RemoteAddress,
})
- pkt.NetworkProtocolNumber = header.IPv6ProtocolNumber
+ pkt.NetworkProtocolNumber = ProtocolNumber
+}
+
+func (e *endpoint) packetMustBeFragmented(pkt *stack.PacketBuffer, gso *stack.GSO) bool {
+ return pkt.Size() > int(e.linkEP.MTU()) && (gso == nil || gso.Type == stack.GSONone)
+}
+
+// handleFragments fragments pkt and calls the handler function on each
+// fragment. It returns the number of fragments handled and the number of
+// fragments left to be processed. The IP header must already be present in the
+// original packet. The mtu is the maximum size of the packets. The transport
+// header protocol number is required to avoid parsing the IPv6 extension
+// headers.
+func (e *endpoint) handleFragments(r *stack.Route, gso *stack.GSO, mtu uint32, pkt *stack.PacketBuffer, transProto tcpip.TransportProtocolNumber, handler func(*stack.PacketBuffer) *tcpip.Error) (int, int, *tcpip.Error) {
+ fragMTU := int(calculateFragmentInnerMTU(mtu, pkt))
+ if fragMTU < pkt.TransportHeader().View().Size() {
+ // As per RFC 8200 Section 4.5, the Transport Header is expected to be small
+ // enough to fit in the first fragment.
+ return 0, 1, tcpip.ErrMessageTooLong
+ }
+
+ pf := fragmentation.MakePacketFragmenter(pkt, fragMTU, calculateFragmentReserve(pkt))
+ id := atomic.AddUint32(&e.protocol.ids[hashRoute(r, e.protocol.hashIV)%buckets], 1)
+ networkHeader := header.IPv6(pkt.NetworkHeader().View())
+
+ var n int
+ for {
+ fragPkt, more := buildNextFragment(&pf, networkHeader, transProto, id)
+ if err := handler(fragPkt); err != nil {
+ return n, pf.RemainingFragmentCount() + 1, err
+ }
+ n++
+ if !more {
+ break
+ }
+ }
+
+ return n, 0, nil
}
// WritePacket writes a packet to the given destination address and protocol.
@@ -402,7 +445,7 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.Netw
// short circuits broadcasts before they are sent out to other hosts.
if pkt.NatDone {
netHeader := header.IPv6(pkt.NetworkHeader().View())
- if ep, err := e.protocol.stack.FindNetworkEndpoint(header.IPv6ProtocolNumber, netHeader.DestinationAddress()); err == nil {
+ if ep, err := e.protocol.stack.FindNetworkEndpoint(ProtocolNumber, netHeader.DestinationAddress()); err == nil {
route := r.ReverseRoute(netHeader.SourceAddress(), netHeader.DestinationAddress())
ep.HandlePacket(&route, pkt)
return nil
@@ -423,15 +466,29 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.Netw
return nil
}
+ if e.packetMustBeFragmented(pkt, gso) {
+ sent, remain, err := e.handleFragments(r, gso, e.linkEP.MTU(), pkt, params.Protocol, func(fragPkt *stack.PacketBuffer) *tcpip.Error {
+ // TODO(gvisor.dev/issue/3884): Evaluate whether we want to send each
+ // fragment one by one using WritePacket() (current strategy) or if we
+ // want to create a PacketBufferList from the fragments and feed it to
+ // WritePackets(). It'll be faster but cost more memory.
+ return e.linkEP.WritePacket(r, gso, ProtocolNumber, fragPkt)
+ })
+ r.Stats().IP.PacketsSent.IncrementBy(uint64(sent))
+ r.Stats().IP.OutgoingPacketErrors.IncrementBy(uint64(remain))
+ return err
+ }
+
if err := e.linkEP.WritePacket(r, gso, ProtocolNumber, pkt); err != nil {
r.Stats().IP.OutgoingPacketErrors.Increment()
return err
}
+
r.Stats().IP.PacketsSent.Increment()
return nil
}
-// WritePackets implements stack.LinkEndpoint.WritePackets.
+// WritePackets implements stack.NetworkEndpoint.WritePackets.
func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.PacketBufferList, params stack.NetworkHeaderParams) (int, *tcpip.Error) {
if r.Loop&stack.PacketLoop != 0 {
panic("not implemented")
@@ -442,6 +499,23 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe
for pb := pkts.Front(); pb != nil; pb = pb.Next() {
e.addIPHeader(r, pb, params)
+ if e.packetMustBeFragmented(pb, gso) {
+ current := pb
+ _, _, err := e.handleFragments(r, gso, e.linkEP.MTU(), pb, params.Protocol, func(fragPkt *stack.PacketBuffer) *tcpip.Error {
+ // Modify the packet list in place with the new fragments.
+ pkts.InsertAfter(current, fragPkt)
+ current = current.Next()
+ return nil
+ })
+ if err != nil {
+ r.Stats().IP.OutgoingPacketErrors.IncrementBy(uint64(pkts.Len()))
+ return 0, err
+ }
+ // The fragmented packet can be released. The rest of the packets can be
+ // processed.
+ pkts.Remove(pb)
+ pb = current
+ }
}
// iptables filtering. All packets that reach here are locally
@@ -470,7 +544,7 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe
}
if _, ok := natPkts[pkt]; ok {
netHeader := header.IPv6(pkt.NetworkHeader().View())
- if ep, err := e.protocol.stack.FindNetworkEndpoint(header.IPv6ProtocolNumber, netHeader.DestinationAddress()); err == nil {
+ if ep, err := e.protocol.stack.FindNetworkEndpoint(ProtocolNumber, netHeader.DestinationAddress()); err == nil {
src := netHeader.SourceAddress()
dst := netHeader.DestinationAddress()
route := r.ReverseRoute(src, dst)
@@ -1155,6 +1229,9 @@ type protocol struct {
eps map[*endpoint]struct{}
}
+ ids []uint32
+ hashIV uint32
+
// defaultTTL is the current default TTL for the protocol. Only the
// uint8 portion of it is meaningful.
//
@@ -1376,10 +1453,15 @@ type Options struct {
func NewProtocolWithOptions(opts Options) stack.NetworkProtocolFactory {
opts.NDPConfigs.validate()
+ ids := hash.RandN32(buckets)
+ hashIV := hash.RandN32(1)[0]
+
return func(s *stack.Stack) stack.NetworkProtocol {
p := &protocol{
stack: s,
fragmentation: fragmentation.NewFragmentation(header.IPv6FragmentExtHdrFragmentOffsetBytesPerUnit, fragmentation.HighFragThreshold, fragmentation.LowFragThreshold, fragmentation.DefaultReassembleTimeout, s.Clock()),
+ ids: ids,
+ hashIV: hashIV,
ndpDisp: opts.NDPDisp,
ndpConfigs: opts.NDPConfigs,
@@ -1397,3 +1479,73 @@ func NewProtocolWithOptions(opts Options) stack.NetworkProtocolFactory {
func NewProtocol(s *stack.Stack) stack.NetworkProtocol {
return NewProtocolWithOptions(Options{})(s)
}
+
+// calculateFragmentInnerMTU calculates the maximum number of bytes of
+// fragmentable data a fragment can have, based on the link layer mtu and pkt's
+// network header size.
+func calculateFragmentInnerMTU(mtu uint32, pkt *stack.PacketBuffer) uint32 {
+ // TODO(gvisor.dev/issue/3912): Once the Authentication or ESP Headers are
+ // supported for outbound packets, their length should not affect the fragment
+ // MTU because they should only be transmitted once.
+ mtu -= uint32(pkt.NetworkHeader().View().Size())
+ mtu -= header.IPv6FragmentHeaderSize
+ // Round the MTU down to align to 8 bytes.
+ mtu &^= 7
+ if mtu <= maxPayloadSize {
+ return mtu
+ }
+ return maxPayloadSize
+}
+
+func calculateFragmentReserve(pkt *stack.PacketBuffer) int {
+ return pkt.AvailableHeaderBytes() + pkt.NetworkHeader().View().Size() + header.IPv6FragmentHeaderSize
+}
+
+// hashRoute calculates a hash value for the given route. It uses the source &
+// destination address and 32-bit number to generate the hash.
+func hashRoute(r *stack.Route, hashIV uint32) uint32 {
+ // The FNV-1a was chosen because it is a fast hashing algorithm, and
+ // cryptographic properties are not needed here.
+ h := fnv.New32a()
+ if _, err := h.Write([]byte(r.LocalAddress)); err != nil {
+ panic(fmt.Sprintf("Hash.Write: %s, but Hash' implementation of Write is not expected to ever return an error", err))
+ }
+
+ if _, err := h.Write([]byte(r.RemoteAddress)); err != nil {
+ panic(fmt.Sprintf("Hash.Write: %s, but Hash' implementation of Write is not expected to ever return an error", err))
+ }
+
+ s := make([]byte, 4)
+ binary.LittleEndian.PutUint32(s, hashIV)
+ if _, err := h.Write(s); err != nil {
+ panic(fmt.Sprintf("Hash.Write: %s, but Hash' implementation of Write is not expected ever to return an error", err))
+ }
+
+ return h.Sum32()
+}
+
+func buildNextFragment(pf *fragmentation.PacketFragmenter, originalIPHeaders header.IPv6, transportProto tcpip.TransportProtocolNumber, id uint32) (*stack.PacketBuffer, bool) {
+ fragPkt, offset, copied, more := pf.BuildNextFragment()
+ fragPkt.NetworkProtocolNumber = ProtocolNumber
+
+ originalIPHeadersLength := len(originalIPHeaders)
+ fragmentIPHeadersLength := originalIPHeadersLength + header.IPv6FragmentHeaderSize
+ fragmentIPHeaders := header.IPv6(fragPkt.NetworkHeader().Push(fragmentIPHeadersLength))
+
+ // Copy the IPv6 header and any extension headers already populated.
+ if copied := copy(fragmentIPHeaders, originalIPHeaders); copied != originalIPHeadersLength {
+ panic(fmt.Sprintf("wrong number of bytes copied into fragmentIPHeaders: got %d, want %d", copied, originalIPHeadersLength))
+ }
+ fragmentIPHeaders.SetNextHeader(header.IPv6FragmentHeader)
+ fragmentIPHeaders.SetPayloadLength(uint16(copied + fragmentIPHeadersLength - header.IPv6MinimumSize))
+
+ fragmentHeader := header.IPv6Fragment(fragmentIPHeaders[originalIPHeadersLength:])
+ fragmentHeader.Encode(&header.IPv6FragmentFields{
+ M: more,
+ FragmentOffset: uint16(offset / header.IPv6FragmentExtHdrFragmentOffsetBytesPerUnit),
+ Identification: id,
+ NextHeader: uint8(transportProto),
+ })
+
+ return fragPkt, more
+}
diff --git a/pkg/tcpip/network/ipv6/ipv6_test.go b/pkg/tcpip/network/ipv6/ipv6_test.go
index d7f82973b..e792ca9e2 100644
--- a/pkg/tcpip/network/ipv6/ipv6_test.go
+++ b/pkg/tcpip/network/ipv6/ipv6_test.go
@@ -15,6 +15,8 @@
package ipv6
import (
+ "encoding/hex"
+ "fmt"
"math"
"testing"
@@ -27,6 +29,7 @@ import (
"gvisor.dev/gvisor/pkg/tcpip/network/testutil"
"gvisor.dev/gvisor/pkg/tcpip/stack"
"gvisor.dev/gvisor/pkg/tcpip/transport/icmp"
+ "gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
"gvisor.dev/gvisor/pkg/tcpip/transport/udp"
"gvisor.dev/gvisor/pkg/waiter"
)
@@ -137,6 +140,82 @@ func testReceiveUDP(t *testing.T, s *stack.Stack, e *channel.Endpoint, src, dst
}
}
+func compareFragments(packets []*stack.PacketBuffer, sourcePacket *stack.PacketBuffer, mtu uint32, wantFragments []fragmentInfo, proto tcpip.TransportProtocolNumber) error {
+ // sourcePacket does not have its IP Header populated. Let's copy the one
+ // from the first fragment.
+ source := header.IPv6(packets[0].NetworkHeader().View())
+ sourceIPHeadersLen := len(source)
+ vv := buffer.NewVectorisedView(sourcePacket.Size(), sourcePacket.Views())
+ source = append(source, vv.ToView()...)
+
+ var reassembledPayload buffer.VectorisedView
+ for i, fragment := range packets {
+ // Confirm that the packet is valid.
+ allBytes := buffer.NewVectorisedView(fragment.Size(), fragment.Views())
+ fragmentIPHeaders := header.IPv6(allBytes.ToView())
+ if !fragmentIPHeaders.IsValid(len(fragmentIPHeaders)) {
+ return fmt.Errorf("fragment #%d: IP packet is invalid:\n%s", i, hex.Dump(fragmentIPHeaders))
+ }
+
+ fragmentIPHeadersLength := fragment.NetworkHeader().View().Size()
+ if fragmentIPHeadersLength != sourceIPHeadersLen {
+ return fmt.Errorf("fragment #%d: got fragmentIPHeadersLength = %d, want = %d", i, fragmentIPHeadersLength, sourceIPHeadersLen)
+ }
+
+ if got := len(fragmentIPHeaders); got > int(mtu) {
+ return fmt.Errorf("fragment #%d: got len(fragmentIPHeaders) = %d, want <= %d", i, got, mtu)
+ }
+
+ sourceIPHeader := source[:header.IPv6MinimumSize]
+ fragmentIPHeader := fragmentIPHeaders[:header.IPv6MinimumSize]
+
+ if got := fragmentIPHeaders.PayloadLength(); got != wantFragments[i].payloadSize {
+ return fmt.Errorf("fragment #%d: got fragmentIPHeaders.PayloadLength() = %d, want = %d", i, got, wantFragments[i].payloadSize)
+ }
+
+ // We expect the IPv6 Header to be similar across each fragment, besides the
+ // payload length.
+ sourceIPHeader.SetPayloadLength(0)
+ fragmentIPHeader.SetPayloadLength(0)
+ if diff := cmp.Diff(fragmentIPHeader, sourceIPHeader); diff != "" {
+ return fmt.Errorf("fragment #%d: fragmentIPHeader mismatch (-want +got):\n%s", i, diff)
+ }
+
+ if fragment.NetworkProtocolNumber != sourcePacket.NetworkProtocolNumber {
+ return fmt.Errorf("fragment #%d: got fragment.NetworkProtocolNumber = %d, want = %d", i, fragment.NetworkProtocolNumber, sourcePacket.NetworkProtocolNumber)
+ }
+
+ if len(packets) > 1 {
+ // If the source packet was big enough that it needed fragmentation, let's
+ // inspect the fragment header. Because no other extension headers are
+ // supported, it will always be the last extension header.
+ fragmentHeader := header.IPv6Fragment(fragmentIPHeaders[fragmentIPHeadersLength-header.IPv6FragmentHeaderSize : fragmentIPHeadersLength])
+
+ if got := fragmentHeader.More(); got != wantFragments[i].more {
+ return fmt.Errorf("fragment #%d: got fragmentHeader.More() = %t, want = %t", i, got, wantFragments[i].more)
+ }
+ if got := fragmentHeader.FragmentOffset(); got != wantFragments[i].offset {
+ return fmt.Errorf("fragment #%d: got fragmentHeader.FragmentOffset() = %d, want = %d", i, got, wantFragments[i].offset)
+ }
+ if got := fragmentHeader.NextHeader(); got != uint8(proto) {
+ return fmt.Errorf("fragment #%d: got fragmentHeader.NextHeader() = %d, want = %d", i, got, uint8(proto))
+ }
+ }
+
+ // Store the reassembled payload as we parse each fragment. The payload
+ // includes the Transport header and everything after.
+ reassembledPayload.AppendView(fragment.TransportHeader().View())
+ reassembledPayload.Append(fragment.Data)
+ }
+
+ result := reassembledPayload.ToView()
+ if diff := cmp.Diff(result, buffer.View(source[sourceIPHeadersLen:])); diff != "" {
+ return fmt.Errorf("reassembledPayload mismatch (-want +got):\n%s", diff)
+ }
+
+ return nil
+}
+
// TestReceiveOnAllNodesMulticastAddr tests that IPv6 endpoints receive ICMP and
// UDP packets destined to the IPv6 link-local all-nodes multicast address.
func TestReceiveOnAllNodesMulticastAddr(t *testing.T) {
@@ -171,8 +250,6 @@ func TestReceiveOnAllNodesMulticastAddr(t *testing.T) {
// packets destined to the IPv6 solicited-node address of an assigned IPv6
// address.
func TestReceiveOnSolicitedNodeAddr(t *testing.T) {
- const nicID = 1
-
tests := []struct {
name string
protocolFactory stack.TransportProtocolFactory
@@ -196,7 +273,7 @@ func TestReceiveOnSolicitedNodeAddr(t *testing.T) {
}
s.SetRouteTable([]tcpip.Route{
- tcpip.Route{
+ {
Destination: header.IPv6EmptySubnet,
NIC: nicID,
},
@@ -296,8 +373,6 @@ func TestAddIpv6Address(t *testing.T) {
}
func TestReceiveIPv6ExtHdrs(t *testing.T) {
- const nicID = 1
-
tests := []struct {
name string
extHdr func(nextHdr uint8) ([]byte, uint8)
@@ -897,7 +972,6 @@ type fragmentData struct {
func TestReceiveIPv6Fragments(t *testing.T) {
const (
- nicID = 1
udpPayload1Length = 256
udpPayload2Length = 128
// Used to test cases where the fragment blocks are not a multiple of
@@ -2029,7 +2103,6 @@ func TestWriteStats(t *testing.T) {
t.Run(test.name, func(t *testing.T) {
ep := testutil.NewMockLinkEndpoint(header.IPv6MinimumMTU, tcpip.ErrInvalidEndpointState, test.allowPackets)
rt := buildRoute(t, ep)
-
var pkts stack.PacketBufferList
for i := 0; i < nPackets; i++ {
pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
@@ -2071,12 +2144,13 @@ func buildRoute(t *testing.T, ep stack.LinkEndpoint) stack.Route {
dst = "\xfc\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02"
)
if err := s.AddAddress(1, ProtocolNumber, src); err != nil {
- t.Fatalf("AddAddress(1, %d, _) failed: %s", ProtocolNumber, err)
+ t.Fatalf("AddAddress(1, %d, %s) failed: %s", ProtocolNumber, src, err)
}
{
- subnet, err := tcpip.NewSubnet(dst, tcpip.AddressMask("\xfc\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff"))
+ mask := tcpip.AddressMask("\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff")
+ subnet, err := tcpip.NewSubnet(dst, mask)
if err != nil {
- t.Fatalf("NewSubnet(_, _) failed: %v", err)
+ t.Fatalf("NewSubnet(%s, %s) failed: %v", dst, mask, err)
}
s.SetRouteTable([]tcpip.Route{{
Destination: subnet,
@@ -2085,7 +2159,7 @@ func buildRoute(t *testing.T, ep stack.LinkEndpoint) stack.Route {
}
rt, err := s.FindRoute(1, src, dst, ProtocolNumber, false /* multicastLoop */)
if err != nil {
- t.Fatalf("got FindRoute(1, _, _, %d, false) = %s, want = nil", ProtocolNumber, err)
+ t.Fatalf("FindRoute(1, %s, %s, %d, false) = %s, want = nil", src, dst, ProtocolNumber, err)
}
return rt
}
@@ -2136,3 +2210,293 @@ func TestClearEndpointFromProtocolOnClose(t *testing.T) {
}
}
}
+
+type fragmentInfo struct {
+ offset uint16
+ more bool
+ payloadSize uint16
+}
+
+type fragmentationTestCase struct {
+ description string
+ mtu uint32
+ gso *stack.GSO
+ transHdrLen int
+ extraHdrLen int
+ payloadSize int
+ wantFragments []fragmentInfo
+ expectedFrags int
+}
+
+var fragmentationTests = []fragmentationTestCase{
+ {
+ description: "No Fragmentation",
+ mtu: 1280,
+ gso: &stack.GSO{},
+ transHdrLen: 0,
+ extraHdrLen: header.IPv6MinimumSize,
+ payloadSize: 1000,
+ wantFragments: []fragmentInfo{
+ {offset: 0, payloadSize: 1000, more: false},
+ },
+ },
+ {
+ description: "Fragmented",
+ mtu: 1280,
+ gso: &stack.GSO{},
+ transHdrLen: 0,
+ extraHdrLen: header.IPv6MinimumSize,
+ payloadSize: 2000,
+ wantFragments: []fragmentInfo{
+ {offset: 0, payloadSize: 1240, more: true},
+ {offset: 154, payloadSize: 776, more: false},
+ },
+ },
+ {
+ description: "No fragmentation with big header",
+ mtu: 2000,
+ gso: &stack.GSO{},
+ transHdrLen: 100,
+ extraHdrLen: header.IPv6MinimumSize,
+ payloadSize: 1000,
+ wantFragments: []fragmentInfo{
+ {offset: 0, payloadSize: 1100, more: false},
+ },
+ },
+ {
+ description: "Fragmented with gso nil",
+ mtu: 1280,
+ gso: nil,
+ transHdrLen: 0,
+ extraHdrLen: header.IPv6MinimumSize,
+ payloadSize: 1400,
+ wantFragments: []fragmentInfo{
+ {offset: 0, payloadSize: 1240, more: true},
+ {offset: 154, payloadSize: 176, more: false},
+ },
+ },
+ {
+ description: "Fragmented with big header",
+ mtu: 1280,
+ gso: &stack.GSO{},
+ transHdrLen: 100,
+ extraHdrLen: header.IPv6MinimumSize,
+ payloadSize: 1200,
+ wantFragments: []fragmentInfo{
+ {offset: 0, payloadSize: 1240, more: true},
+ {offset: 154, payloadSize: 76, more: false},
+ },
+ },
+ {
+ description: "Fragmented with big header and prependable bytes",
+ mtu: 1280,
+ gso: &stack.GSO{},
+ transHdrLen: 20,
+ extraHdrLen: header.IPv6MinimumSize + 66,
+ payloadSize: 1500,
+ wantFragments: []fragmentInfo{
+ {offset: 0, payloadSize: 1240, more: true},
+ {offset: 154, payloadSize: 296, more: false},
+ },
+ },
+}
+
+func TestFragmentation(t *testing.T) {
+ const (
+ ttl = 42
+ tos = stack.DefaultTOS
+ transportProto = tcp.ProtocolNumber
+ )
+
+ for _, ft := range fragmentationTests {
+ t.Run(ft.description, func(t *testing.T) {
+ pkt := testutil.MakeRandPkt(ft.transHdrLen, ft.extraHdrLen, []int{ft.payloadSize}, header.IPv6ProtocolNumber)
+ source := pkt.Clone()
+ ep := testutil.NewMockLinkEndpoint(ft.mtu, nil, math.MaxInt32)
+ r := buildRoute(t, ep)
+ err := r.WritePacket(ft.gso, stack.NetworkHeaderParams{
+ Protocol: tcp.ProtocolNumber,
+ TTL: ttl,
+ TOS: stack.DefaultTOS,
+ }, pkt)
+ if err != nil {
+ t.Fatalf("WritePacket(_, _, _): = %s", err)
+ }
+ if got := len(ep.WrittenPackets); got != len(ft.wantFragments) {
+ t.Errorf("got len(ep.WrittenPackets) = %d, want = %d", got, len(ft.wantFragments))
+ }
+ if got := int(r.Stats().IP.PacketsSent.Value()); got != len(ft.wantFragments) {
+ t.Errorf("got c.Route.Stats().IP.PacketsSent.Value() = %d, want = %d", got, len(ft.wantFragments))
+ }
+ if got := r.Stats().IP.OutgoingPacketErrors.Value(); got != 0 {
+ t.Errorf("got r.Stats().IP.OutgoingPacketErrors.Value() = %d, want = 0", got)
+ }
+ if len(ep.WrittenPackets) > 0 {
+ if err := compareFragments(ep.WrittenPackets, source, ft.mtu, ft.wantFragments, tcp.ProtocolNumber); err != nil {
+ t.Error(err)
+ }
+ }
+ })
+ }
+}
+
+func TestFragmentationWritePackets(t *testing.T) {
+ const ttl = 42
+ tests := []struct {
+ description string
+ insertBefore int
+ insertAfter int
+ }{
+ {
+ description: "Single packet",
+ insertBefore: 0,
+ insertAfter: 0,
+ },
+ {
+ description: "With packet before",
+ insertBefore: 1,
+ insertAfter: 0,
+ },
+ {
+ description: "With packet after",
+ insertBefore: 0,
+ insertAfter: 1,
+ },
+ {
+ description: "With packet before and after",
+ insertBefore: 1,
+ insertAfter: 1,
+ },
+ }
+ tinyPacket := testutil.MakeRandPkt(header.TCPMinimumSize, header.IPv6MinimumSize, []int{1}, header.IPv6ProtocolNumber)
+
+ for _, test := range tests {
+ t.Run(test.description, func(t *testing.T) {
+ for _, ft := range fragmentationTests {
+ t.Run(ft.description, func(t *testing.T) {
+ var pkts stack.PacketBufferList
+ for i := 0; i < test.insertBefore; i++ {
+ pkts.PushBack(tinyPacket.Clone())
+ }
+ pkt := testutil.MakeRandPkt(ft.transHdrLen, ft.extraHdrLen, []int{ft.payloadSize}, header.IPv6ProtocolNumber)
+ source := pkt
+ pkts.PushBack(pkt.Clone())
+ for i := 0; i < test.insertAfter; i++ {
+ pkts.PushBack(tinyPacket.Clone())
+ }
+
+ ep := testutil.NewMockLinkEndpoint(ft.mtu, nil, math.MaxInt32)
+ r := buildRoute(t, ep)
+
+ wantTotalPackets := len(ft.wantFragments) + test.insertBefore + test.insertAfter
+ n, err := r.WritePackets(ft.gso, pkts, stack.NetworkHeaderParams{
+ Protocol: tcp.ProtocolNumber,
+ TTL: ttl,
+ TOS: stack.DefaultTOS,
+ })
+ if n != wantTotalPackets || err != nil {
+ t.Errorf("got WritePackets(_, _, _) = (%d, %s), want = (%d, nil)", n, err, wantTotalPackets)
+ }
+ if got := len(ep.WrittenPackets); got != wantTotalPackets {
+ t.Errorf("got len(ep.WrittenPackets) = %d, want = %d", got, wantTotalPackets)
+ }
+ if got := int(r.Stats().IP.PacketsSent.Value()); got != wantTotalPackets {
+ t.Errorf("got c.Route.Stats().IP.PacketsSent.Value() = %d, want = %d", got, wantTotalPackets)
+ }
+ if got := r.Stats().IP.OutgoingPacketErrors.Value(); got != 0 {
+ t.Errorf("got r.Stats().IP.OutgoingPacketErrors.Value() = %d, want = 0", got)
+ }
+
+ if wantTotalPackets == 0 {
+ return
+ }
+
+ fragments := ep.WrittenPackets[test.insertBefore : len(ft.wantFragments)+test.insertBefore]
+ if err := compareFragments(fragments, source, ft.mtu, ft.wantFragments, tcp.ProtocolNumber); err != nil {
+ t.Error(err)
+ }
+ })
+ }
+ })
+ }
+}
+
+// TestFragmentationErrors checks that errors are returned from WritePacket
+// correctly.
+func TestFragmentationErrors(t *testing.T) {
+ const ttl = 42
+
+ tests := []struct {
+ description string
+ mtu uint32
+ transHdrLen int
+ payloadSize int
+ allowPackets int
+ outgoingErrors int
+ mockError *tcpip.Error
+ wantError *tcpip.Error
+ }{
+ {
+ description: "No frag",
+ mtu: 2000,
+ payloadSize: 1000,
+ transHdrLen: 0,
+ allowPackets: 0,
+ outgoingErrors: 1,
+ mockError: tcpip.ErrAborted,
+ wantError: tcpip.ErrAborted,
+ },
+ {
+ description: "Error on first frag",
+ mtu: 1300,
+ payloadSize: 3000,
+ transHdrLen: 0,
+ allowPackets: 0,
+ outgoingErrors: 3,
+ mockError: tcpip.ErrAborted,
+ wantError: tcpip.ErrAborted,
+ },
+ {
+ description: "Error on second frag",
+ mtu: 1500,
+ payloadSize: 4000,
+ transHdrLen: 0,
+ allowPackets: 1,
+ outgoingErrors: 2,
+ mockError: tcpip.ErrAborted,
+ wantError: tcpip.ErrAborted,
+ },
+ {
+ description: "Error on packet with MTU smaller than transport header",
+ mtu: 1280,
+ transHdrLen: 1500,
+ payloadSize: 500,
+ allowPackets: 0,
+ outgoingErrors: 1,
+ mockError: nil,
+ wantError: tcpip.ErrMessageTooLong,
+ },
+ }
+
+ for _, ft := range tests {
+ t.Run(ft.description, func(t *testing.T) {
+ pkt := testutil.MakeRandPkt(ft.transHdrLen, header.IPv6MinimumSize, []int{ft.payloadSize}, header.IPv6ProtocolNumber)
+ ep := testutil.NewMockLinkEndpoint(ft.mtu, ft.mockError, ft.allowPackets)
+ r := buildRoute(t, ep)
+ err := r.WritePacket(&stack.GSO{}, stack.NetworkHeaderParams{
+ Protocol: tcp.ProtocolNumber,
+ TTL: ttl,
+ TOS: stack.DefaultTOS,
+ }, pkt)
+ if err != ft.wantError {
+ t.Errorf("got WritePacket(_, _, _) = %s, want = %s", err, ft.wantError)
+ }
+ if got := int(r.Stats().IP.PacketsSent.Value()); got != ft.allowPackets {
+ t.Errorf("got r.Stats().IP.PacketsSent.Value() = %d, want = %d", got, ft.allowPackets)
+ }
+ if got := int(r.Stats().IP.OutgoingPacketErrors.Value()); got != ft.outgoingErrors {
+ t.Errorf("got r.Stats().IP.OutgoingPacketErrors.Value() = %d, want = %d", got, ft.outgoingErrors)
+ }
+ })
+ }
+}
diff --git a/pkg/tcpip/network/testutil/BUILD b/pkg/tcpip/network/testutil/BUILD
index c9e57dc0d..d0ffc299a 100644
--- a/pkg/tcpip/network/testutil/BUILD
+++ b/pkg/tcpip/network/testutil/BUILD
@@ -8,6 +8,7 @@ go_library(
"testutil.go",
],
visibility = [
+ "//pkg/tcpip/network/fragmentation:__pkg__",
"//pkg/tcpip/network/ipv4:__pkg__",
"//pkg/tcpip/network/ipv6:__pkg__",
],
diff --git a/test/packetimpact/tests/BUILD b/test/packetimpact/tests/BUILD
index 94731c64b..11db49e39 100644
--- a/test/packetimpact/tests/BUILD
+++ b/test/packetimpact/tests/BUILD
@@ -310,8 +310,6 @@ packetimpact_go_test(
packetimpact_go_test(
name = "ipv6_fragment_reassembly",
srcs = ["ipv6_fragment_reassembly_test.go"],
- # TODO(b/160919104): Fix netstack then remove the line below.
- expect_netstack_failure = True,
deps = [
"//pkg/tcpip",
"//pkg/tcpip/buffer",