summaryrefslogtreecommitdiffhomepage
path: root/pkg
diff options
context:
space:
mode:
Diffstat (limited to 'pkg')
-rw-r--r--pkg/tcpip/buffer/prependable.go10
-rw-r--r--pkg/tcpip/link/sniffer/sniffer.go108
-rw-r--r--pkg/tcpip/network/ipv4/BUILD2
-rw-r--r--pkg/tcpip/network/ipv4/ipv4.go91
-rw-r--r--pkg/tcpip/network/ipv4/ipv4_test.go270
-rw-r--r--pkg/tcpip/transport/tcp/tcp_test.go10
6 files changed, 435 insertions, 56 deletions
diff --git a/pkg/tcpip/buffer/prependable.go b/pkg/tcpip/buffer/prependable.go
index 43cbbc74c..4287464f3 100644
--- a/pkg/tcpip/buffer/prependable.go
+++ b/pkg/tcpip/buffer/prependable.go
@@ -52,6 +52,16 @@ func (p Prependable) UsedLength() int {
return len(p.buf) - p.usedIdx
}
+// AvailableLength returns the number of bytes used so far.
+func (p Prependable) AvailableLength() int {
+ return p.usedIdx
+}
+
+// TrimBack removes size bytes from the end.
+func (p *Prependable) TrimBack(size int) {
+ p.buf = p.buf[:len(p.buf)-size]
+}
+
// Prepend reserves the requested space in front of the buffer, returning a
// slice that represents the reserved space.
func (p *Prependable) Prepend(size int) []byte {
diff --git a/pkg/tcpip/link/sniffer/sniffer.go b/pkg/tcpip/link/sniffer/sniffer.go
index e87ae07d7..fccabd554 100644
--- a/pkg/tcpip/link/sniffer/sniffer.go
+++ b/pkg/tcpip/link/sniffer/sniffer.go
@@ -247,9 +247,13 @@ func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, b buffer.Vie
dst := tcpip.Address("unknown")
id := 0
size := uint16(0)
+ var fragmentOffset uint16
+ var moreFragments bool
switch protocol {
case header.IPv4ProtocolNumber:
ipv4 := header.IPv4(b)
+ fragmentOffset = ipv4.FragmentOffset()
+ moreFragments = ipv4.Flags()&header.IPv4FlagMoreFragments == header.IPv4FlagMoreFragments
src = ipv4.SourceAddress()
dst = ipv4.DestinationAddress()
transProto = ipv4.Protocol()
@@ -290,29 +294,31 @@ func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, b buffer.Vie
transName = "icmp"
icmp := header.ICMPv4(b)
icmpType := "unknown"
- switch icmp.Type() {
- case header.ICMPv4EchoReply:
- icmpType = "echo reply"
- case header.ICMPv4DstUnreachable:
- icmpType = "destination unreachable"
- case header.ICMPv4SrcQuench:
- icmpType = "source quench"
- case header.ICMPv4Redirect:
- icmpType = "redirect"
- case header.ICMPv4Echo:
- icmpType = "echo"
- case header.ICMPv4TimeExceeded:
- icmpType = "time exceeded"
- case header.ICMPv4ParamProblem:
- icmpType = "param problem"
- case header.ICMPv4Timestamp:
- icmpType = "timestamp"
- case header.ICMPv4TimestampReply:
- icmpType = "timestamp reply"
- case header.ICMPv4InfoRequest:
- icmpType = "info request"
- case header.ICMPv4InfoReply:
- icmpType = "info reply"
+ if fragmentOffset == 0 {
+ switch icmp.Type() {
+ case header.ICMPv4EchoReply:
+ icmpType = "echo reply"
+ case header.ICMPv4DstUnreachable:
+ icmpType = "destination unreachable"
+ case header.ICMPv4SrcQuench:
+ icmpType = "source quench"
+ case header.ICMPv4Redirect:
+ icmpType = "redirect"
+ case header.ICMPv4Echo:
+ icmpType = "echo"
+ case header.ICMPv4TimeExceeded:
+ icmpType = "time exceeded"
+ case header.ICMPv4ParamProblem:
+ icmpType = "param problem"
+ case header.ICMPv4Timestamp:
+ icmpType = "timestamp"
+ case header.ICMPv4TimestampReply:
+ icmpType = "timestamp reply"
+ case header.ICMPv4InfoRequest:
+ icmpType = "info request"
+ case header.ICMPv4InfoReply:
+ icmpType = "info reply"
+ }
}
log.Infof("%s %s %v -> %v %s len:%d id:%04x code:%d", prefix, transName, src, dst, icmpType, size, id, icmp.Code())
return
@@ -351,8 +357,10 @@ func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, b buffer.Vie
case header.UDPProtocolNumber:
transName = "udp"
udp := header.UDP(b)
- srcPort = udp.SourcePort()
- dstPort = udp.DestinationPort()
+ if fragmentOffset == 0 && len(udp) >= header.UDPMinimumSize {
+ srcPort = udp.SourcePort()
+ dstPort = udp.DestinationPort()
+ }
size -= header.UDPMinimumSize
details = fmt.Sprintf("xsum: 0x%x", udp.Checksum())
@@ -360,33 +368,35 @@ func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, b buffer.Vie
case header.TCPProtocolNumber:
transName = "tcp"
tcp := header.TCP(b)
- offset := int(tcp.DataOffset())
- if offset < header.TCPMinimumSize {
- details += fmt.Sprintf("invalid packet: tcp data offset too small %d", offset)
- break
- }
- if offset > len(tcp) {
- details += fmt.Sprintf("invalid packet: tcp data offset %d larger than packet buffer length %d", offset, len(tcp))
- break
- }
+ if fragmentOffset == 0 && len(tcp) >= header.TCPMinimumSize {
+ offset := int(tcp.DataOffset())
+ if offset < header.TCPMinimumSize {
+ details += fmt.Sprintf("invalid packet: tcp data offset too small %d", offset)
+ break
+ }
+ if offset > len(tcp) && !moreFragments {
+ details += fmt.Sprintf("invalid packet: tcp data offset %d larger than packet buffer length %d", offset, len(tcp))
+ break
+ }
- srcPort = tcp.SourcePort()
- dstPort = tcp.DestinationPort()
- size -= uint16(offset)
+ srcPort = tcp.SourcePort()
+ dstPort = tcp.DestinationPort()
+ size -= uint16(offset)
- // Initialize the TCP flags.
- flags := tcp.Flags()
- flagsStr := []byte("FSRPAU")
- for i := range flagsStr {
- if flags&(1<<uint(i)) == 0 {
- flagsStr[i] = ' '
+ // Initialize the TCP flags.
+ flags := tcp.Flags()
+ flagsStr := []byte("FSRPAU")
+ for i := range flagsStr {
+ if flags&(1<<uint(i)) == 0 {
+ flagsStr[i] = ' '
+ }
+ }
+ details = fmt.Sprintf("flags:0x%02x (%v) seqnum: %v ack: %v win: %v xsum:0x%x", flags, string(flagsStr), tcp.SequenceNumber(), tcp.AckNumber(), tcp.WindowSize(), tcp.Checksum())
+ if flags&header.TCPFlagSyn != 0 {
+ details += fmt.Sprintf(" options: %+v", header.ParseSynOptions(tcp.Options(), flags&header.TCPFlagAck != 0))
+ } else {
+ details += fmt.Sprintf(" options: %+v", tcp.ParsedOptions())
}
- }
- details = fmt.Sprintf("flags:0x%02x (%v) seqnum: %v ack: %v win: %v xsum:0x%x", flags, string(flagsStr), tcp.SequenceNumber(), tcp.AckNumber(), tcp.WindowSize(), tcp.Checksum())
- if flags&header.TCPFlagSyn != 0 {
- details += fmt.Sprintf(" options: %+v", header.ParseSynOptions(tcp.Options(), flags&header.TCPFlagAck != 0))
- } else {
- details += fmt.Sprintf(" options: %+v", tcp.ParsedOptions())
}
default:
diff --git a/pkg/tcpip/network/ipv4/BUILD b/pkg/tcpip/network/ipv4/BUILD
index 7a5341def..1b4f29e0c 100644
--- a/pkg/tcpip/network/ipv4/BUILD
+++ b/pkg/tcpip/network/ipv4/BUILD
@@ -28,11 +28,13 @@ go_test(
srcs = ["ipv4_test.go"],
deps = [
"//pkg/tcpip",
+ "//pkg/tcpip/buffer",
"//pkg/tcpip/header",
"//pkg/tcpip/link/channel",
"//pkg/tcpip/link/sniffer",
"//pkg/tcpip/network/ipv4",
"//pkg/tcpip/stack",
+ "//pkg/tcpip/transport/tcp",
"//pkg/tcpip/transport/udp",
"//pkg/waiter",
],
diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go
index c6af0db79..4edc52f19 100644
--- a/pkg/tcpip/network/ipv4/ipv4.go
+++ b/pkg/tcpip/network/ipv4/ipv4.go
@@ -107,6 +107,88 @@ func (e *endpoint) GSOMaxSize() uint32 {
return 0
}
+// writePacketFragments calls e.linkEP.WritePacket with each packet fragment to
+// write. It assumes that the IP header is entirely in hdr but does not assume
+// that only the IP header is in hdr. It assumes that the input packet's stated
+// length matches the length of the hdr+payload. mtu includes the IP header and
+// options. This does not support the DontFragment IP flag.
+func (e *endpoint) writePacketFragments(r *stack.Route, gso *stack.GSO, hdr buffer.Prependable, payload buffer.VectorisedView, mtu int) *tcpip.Error {
+ // This packet is too big, it needs to be fragmented.
+ ip := header.IPv4(hdr.View())
+ flags := ip.Flags()
+
+ // Update mtu to take into account the header, which will exist in all
+ // fragments anyway.
+ innerMTU := mtu - int(ip.HeaderLength())
+
+ // Round the MTU down to align to 8 bytes. Then calculate the number of
+ // fragments. Calculate fragment sizes as in RFC791.
+ innerMTU &^= 7
+ n := (int(ip.PayloadLength()) + innerMTU - 1) / innerMTU
+
+ outerMTU := innerMTU + int(ip.HeaderLength())
+ offset := ip.FragmentOffset()
+ originalAvailableLength := hdr.AvailableLength()
+ for i := 0; i < n; i++ {
+ // Where possible, the first fragment that is sent has the same
+ // hdr.UsedLength() as the input packet. The link-layer endpoint may depends
+ // on this for looking at, eg, L4 headers.
+ h := ip
+ if i > 0 {
+ hdr = buffer.NewPrependable(int(ip.HeaderLength()) + originalAvailableLength)
+ h = header.IPv4(hdr.Prepend(int(ip.HeaderLength())))
+ copy(h, ip[:ip.HeaderLength()])
+ }
+ if i != n-1 {
+ h.SetTotalLength(uint16(outerMTU))
+ h.SetFlagsFragmentOffset(flags|header.IPv4FlagMoreFragments, offset)
+ } else {
+ h.SetTotalLength(uint16(h.HeaderLength()) + uint16(payload.Size()))
+ h.SetFlagsFragmentOffset(flags, offset)
+ }
+ h.SetChecksum(0)
+ h.SetChecksum(^h.CalculateChecksum())
+ offset += uint16(innerMTU)
+ if i > 0 {
+ newPayload := payload.Clone([]buffer.View{})
+ newPayload.CapLength(innerMTU)
+ if err := e.linkEP.WritePacket(r, gso, hdr, newPayload, ProtocolNumber); err != nil {
+ return err
+ }
+ r.Stats().IP.PacketsSent.Increment()
+ payload.TrimFront(newPayload.Size())
+ continue
+ }
+ // Special handling for the first fragment because it comes from the hdr.
+ if outerMTU >= hdr.UsedLength() {
+ // This fragment can fit all of hdr and possibly some of payload, too.
+ newPayload := payload.Clone([]buffer.View{})
+ newPayloadLength := outerMTU - hdr.UsedLength()
+ newPayload.CapLength(newPayloadLength)
+ if err := e.linkEP.WritePacket(r, gso, hdr, newPayload, ProtocolNumber); err != nil {
+ return err
+ }
+ r.Stats().IP.PacketsSent.Increment()
+ payload.TrimFront(newPayloadLength)
+ } else {
+ // The fragment is too small to fit all of hdr.
+ startOfHdr := hdr
+ startOfHdr.TrimBack(hdr.UsedLength() - outerMTU)
+ emptyVV := buffer.NewVectorisedView(0, []buffer.View{})
+ if err := e.linkEP.WritePacket(r, gso, startOfHdr, emptyVV, ProtocolNumber); err != nil {
+ return err
+ }
+ r.Stats().IP.PacketsSent.Increment()
+ // Add the unused bytes of hdr into the payload that remains to be sent.
+ restOfHdr := hdr.View()[outerMTU:]
+ tmp := buffer.NewVectorisedView(len(restOfHdr), []buffer.View{buffer.NewViewFromBytes(restOfHdr)})
+ tmp.Append(payload)
+ payload = tmp
+ }
+ }
+ return nil
+}
+
// WritePacket writes a packet to the given destination address and protocol.
func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, hdr buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.TransportProtocolNumber, ttl uint8, loop stack.PacketLooping) *tcpip.Error {
ip := header.IPv4(hdr.Prepend(header.IPv4MinimumSize))
@@ -138,9 +220,14 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, hdr buffer.Prepen
if loop&stack.PacketOut == 0 {
return nil
}
-
+ if hdr.UsedLength()+payload.Size() > int(e.linkEP.MTU()) && gso.Type == stack.GSONone {
+ return e.writePacketFragments(r, gso, hdr, payload, int(e.linkEP.MTU()))
+ }
+ if err := e.linkEP.WritePacket(r, gso, hdr, payload, ProtocolNumber); err != nil {
+ return err
+ }
r.Stats().IP.PacketsSent.Increment()
- return e.linkEP.WritePacket(r, gso, hdr, payload, ProtocolNumber)
+ return nil
}
// HandlePacket is called by the link layer when new ipv4 packets arrive for
diff --git a/pkg/tcpip/network/ipv4/ipv4_test.go b/pkg/tcpip/network/ipv4/ipv4_test.go
index 146143ab3..7a09ef6de 100644
--- a/pkg/tcpip/network/ipv4/ipv4_test.go
+++ b/pkg/tcpip/network/ipv4/ipv4_test.go
@@ -15,14 +15,19 @@
package ipv4_test
import (
+ "bytes"
+ "encoding/hex"
+ "math/rand"
"testing"
"gvisor.googlesource.com/gvisor/pkg/tcpip"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/buffer"
"gvisor.googlesource.com/gvisor/pkg/tcpip/header"
"gvisor.googlesource.com/gvisor/pkg/tcpip/link/channel"
"gvisor.googlesource.com/gvisor/pkg/tcpip/link/sniffer"
"gvisor.googlesource.com/gvisor/pkg/tcpip/network/ipv4"
"gvisor.googlesource.com/gvisor/pkg/tcpip/stack"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/transport/tcp"
"gvisor.googlesource.com/gvisor/pkg/tcpip/transport/udp"
"gvisor.googlesource.com/gvisor/pkg/waiter"
)
@@ -90,3 +95,268 @@ func TestExcludeBroadcast(t *testing.T) {
}
})
}
+
+// makeHdrAndPayload generates a randomize packet. hdrLength indicates how much
+// data should already be in the header before WritePacket. extraLength
+// indicates how much extra space should be in the header. The payload is made
+// from many Views of the sizes listed in viewSizes.
+func makeHdrAndPayload(hdrLength int, extraLength int, viewSizes []int) (buffer.Prependable, buffer.VectorisedView) {
+ hdr := buffer.NewPrependable(hdrLength + extraLength)
+ hdr.Prepend(hdrLength)
+ rand.Read(hdr.View())
+
+ var views []buffer.View
+ totalLength := 0
+ for _, s := range viewSizes {
+ newView := buffer.NewView(s)
+ rand.Read(newView)
+ views = append(views, newView)
+ totalLength += s
+ }
+ payload := buffer.NewVectorisedView(totalLength, views)
+ return hdr, payload
+}
+
+// comparePayloads compared the contents of all the packets against the contents
+// of the source packet.
+func compareFragments(t *testing.T, packets []packetInfo, sourcePacketInfo packetInfo, mtu uint32) {
+ t.Helper()
+ // Make a complete array of the sourcePacketInfo packet.
+ source := header.IPv4(packets[0].Header.View()[:header.IPv4MinimumSize])
+ source = append(source, sourcePacketInfo.Header.View()...)
+ source = append(source, sourcePacketInfo.Payload.ToView()...)
+
+ // Make a copy of the IP header, which will be modified in some fields to make
+ // an expected header.
+ sourceCopy := header.IPv4(append(buffer.View(nil), source[:source.HeaderLength()]...))
+ sourceCopy.SetChecksum(0)
+ sourceCopy.SetFlagsFragmentOffset(0, 0)
+ sourceCopy.SetTotalLength(0)
+ var offset uint16
+ // Build up an array of the bytes sent.
+ var reassembledPayload []byte
+ for i, packet := range packets {
+ // Confirm that the packet is valid.
+ allBytes := packet.Header.View().ToVectorisedView()
+ allBytes.Append(packet.Payload)
+ ip := header.IPv4(allBytes.ToView())
+ if !ip.IsValid(len(ip)) {
+ t.Errorf("IP packet is invalid:\n%s", hex.Dump(ip))
+ }
+ if got, want := ip.CalculateChecksum(), uint16(0xffff); got != want {
+ t.Errorf("ip.CalculateChecksum() got %#x, want %#x", got, want)
+ }
+ if got, want := len(ip), int(mtu); got > want {
+ t.Errorf("fragment is too large, got %d want %d", got, want)
+ }
+ if got, want := packet.Header.UsedLength(), sourcePacketInfo.Header.UsedLength()+header.IPv4MinimumSize; i == 0 && want < int(mtu) && got != want {
+ t.Errorf("first fragment hdr parts should have unmodified length if possible: got %d, want %d", got, want)
+ }
+ if got, want := packet.Header.AvailableLength(), sourcePacketInfo.Header.AvailableLength()-header.IPv4MinimumSize; got != want {
+ t.Errorf("fragment #%d should have the same available space for prepending as source: got %d, want %d", i, got, want)
+ }
+ if i < len(packets)-1 {
+ sourceCopy.SetFlagsFragmentOffset(sourceCopy.Flags()|header.IPv4FlagMoreFragments, offset)
+ } else {
+ sourceCopy.SetFlagsFragmentOffset(sourceCopy.Flags()&^header.IPv4FlagMoreFragments, offset)
+ }
+ reassembledPayload = append(reassembledPayload, ip.Payload()...)
+ offset += ip.TotalLength() - uint16(ip.HeaderLength())
+ // Clear out the checksum and length from the ip because we can't compare
+ // it.
+ sourceCopy.SetTotalLength(uint16(len(ip)))
+ sourceCopy.SetChecksum(0)
+ sourceCopy.SetChecksum(^sourceCopy.CalculateChecksum())
+ if !bytes.Equal(ip[:ip.HeaderLength()], sourceCopy[:sourceCopy.HeaderLength()]) {
+ t.Errorf("ip[:ip.HeaderLength()] got:\n%s\nwant:\n%s", hex.Dump(ip[:ip.HeaderLength()]), hex.Dump(sourceCopy[:sourceCopy.HeaderLength()]))
+ }
+ }
+ expected := source[source.HeaderLength():]
+ if !bytes.Equal(reassembledPayload, expected) {
+ t.Errorf("reassembledPayload got:\n%s\nwant:\n%s", hex.Dump(reassembledPayload), hex.Dump(expected))
+ }
+}
+
+type errorChannel struct {
+ *channel.Endpoint
+ Ch chan packetInfo
+ packetCollectorErrors []*tcpip.Error
+}
+
+// newErrorChannel creates a new errorChannel endpoint. Each call to WritePacket
+// will return successive errors from packetCollectorErrors until the list is
+// empty and then return nil each time.
+func newErrorChannel(size int, mtu uint32, linkAddr tcpip.LinkAddress, packetCollectorErrors []*tcpip.Error) (tcpip.LinkEndpointID, *errorChannel) {
+ _, e := channel.New(size, mtu, linkAddr)
+ ec := errorChannel{
+ Endpoint: e,
+ Ch: make(chan packetInfo, size),
+ packetCollectorErrors: packetCollectorErrors,
+ }
+
+ return stack.RegisterLinkEndpoint(e), &ec
+}
+
+// packetInfo holds all the information about an outbound packet.
+type packetInfo struct {
+ Header buffer.Prependable
+ Payload buffer.VectorisedView
+}
+
+// Drain removes all outbound packets from the channel and counts them.
+func (e *errorChannel) Drain() int {
+ c := 0
+ for {
+ select {
+ case <-e.Ch:
+ c++
+ default:
+ return c
+ }
+ }
+}
+
+// WritePacket stores outbound packets into the channel.
+func (e *errorChannel) WritePacket(r *stack.Route, gso *stack.GSO, hdr buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.NetworkProtocolNumber) *tcpip.Error {
+ p := packetInfo{
+ Header: hdr,
+ Payload: payload,
+ }
+
+ select {
+ case e.Ch <- p:
+ default:
+ }
+
+ nextError := (*tcpip.Error)(nil)
+ if len(e.packetCollectorErrors) > 0 {
+ nextError = e.packetCollectorErrors[0]
+ e.packetCollectorErrors = e.packetCollectorErrors[1:]
+ }
+ return nextError
+}
+
+type context struct {
+ stack.Route
+ linkEP *errorChannel
+}
+
+func buildContext(t *testing.T, packetCollectorErrors []*tcpip.Error, mtu uint32) context {
+ // Make the packet and write it.
+ s := stack.New([]string{ipv4.ProtocolName}, []string{}, stack.Options{})
+ _, linkEP := newErrorChannel(100 /* Enough for all tests. */, mtu, "", packetCollectorErrors)
+ linkEPId := stack.RegisterLinkEndpoint(linkEP)
+ s.CreateNIC(1, linkEPId)
+ s.AddAddress(1, ipv4.ProtocolNumber, "\x10\x00\x00\x01")
+ s.SetRouteTable([]tcpip.Route{{
+ Destination: "\x10\x00\x00\x02",
+ Mask: "\xff\xff\xff\xff",
+ Gateway: "",
+ NIC: 1,
+ }})
+ r, err := s.FindRoute(0, "\x10\x00\x00\x01", "\x10\x00\x00\x02", ipv4.ProtocolNumber, false /* multicastLoop */)
+ if err != nil {
+ t.Fatalf("s.FindRoute got %v, want %v", err, nil)
+ }
+ return context{
+ Route: r,
+ linkEP: linkEP,
+ }
+}
+
+func TestFragmentation(t *testing.T) {
+ var manyPayloadViewsSizes [1000]int
+ for i := range manyPayloadViewsSizes {
+ manyPayloadViewsSizes[i] = 7
+ }
+ fragTests := []struct {
+ description string
+ mtu uint32
+ hdrLength int
+ extraLength int
+ payloadViewsSizes []int
+ expectedFrags int
+ }{
+ {"NoFragmentation", 2000, 0, header.IPv4MinimumSize, []int{1000}, 1},
+ {"NoFragmentationWithBigHeader", 2000, 16, header.IPv4MinimumSize, []int{1000}, 1},
+ {"Fragmented", 800, 0, header.IPv4MinimumSize, []int{1000}, 2},
+ {"FragmentedWithManyViews", 300, 0, header.IPv4MinimumSize, manyPayloadViewsSizes[:], 25},
+ {"FragmentedWithManyViewsAndPrependableBytes", 300, 0, header.IPv4MinimumSize + 55, manyPayloadViewsSizes[:], 25},
+ {"FragmentedWithBigHeader", 800, 20, header.IPv4MinimumSize, []int{1000}, 2},
+ {"FragmentedWithBigHeaderAndPrependableBytes", 800, 20, header.IPv4MinimumSize + 66, []int{1000}, 2},
+ {"FragmentedWithMTUSmallerThanHeaderAndPrependableBytes", 300, 1000, header.IPv4MinimumSize + 77, []int{500}, 6},
+ }
+
+ for _, ft := range fragTests {
+ t.Run(ft.description, func(t *testing.T) {
+ hdr, payload := makeHdrAndPayload(ft.hdrLength, ft.extraLength, ft.payloadViewsSizes)
+ source := packetInfo{
+ Header: hdr,
+ // Save the source payload because WritePacket will modify it.
+ Payload: payload.Clone([]buffer.View{}),
+ }
+ c := buildContext(t, nil, ft.mtu)
+ err := c.Route.WritePacket(&stack.GSO{}, hdr, payload, tcp.ProtocolNumber, 42)
+ if err != nil {
+ t.Errorf("err got %v, want %v", err, nil)
+ }
+
+ var results []packetInfo
+ L:
+ for {
+ select {
+ case pi := <-c.linkEP.Ch:
+ results = append(results, pi)
+ default:
+ break L
+ }
+ }
+
+ if got, want := len(results), ft.expectedFrags; got != want {
+ t.Errorf("len(result) got %d, want %d", got, want)
+ }
+ if got, want := len(results), int(c.Route.Stats().IP.PacketsSent.Value()); got != want {
+ t.Errorf("no errors yet len(result) got %d, want %d", got, want)
+ }
+ compareFragments(t, results, source, ft.mtu)
+ })
+ }
+}
+
+// TestFragmentationErrors checks that errors are returned from write packet
+// correctly.
+func TestFragmentationErrors(t *testing.T) {
+ fragTests := []struct {
+ description string
+ mtu uint32
+ hdrLength int
+ payloadViewsSizes []int
+ packetCollectorErrors []*tcpip.Error
+ }{
+ {"NoFrag", 2000, 0, []int{1000}, []*tcpip.Error{tcpip.ErrAborted}},
+ {"ErrorOnFirstFrag", 500, 0, []int{1000}, []*tcpip.Error{tcpip.ErrAborted}},
+ {"ErrorOnSecondFrag", 500, 0, []int{1000}, []*tcpip.Error{nil, tcpip.ErrAborted}},
+ {"ErrorOnFirstFragMTUSmallerThanHdr", 500, 1000, []int{500}, []*tcpip.Error{tcpip.ErrAborted}},
+ }
+
+ for _, ft := range fragTests {
+ t.Run(ft.description, func(t *testing.T) {
+ hdr, payload := makeHdrAndPayload(ft.hdrLength, header.IPv4MinimumSize, ft.payloadViewsSizes)
+ c := buildContext(t, ft.packetCollectorErrors, ft.mtu)
+ err := c.Route.WritePacket(&stack.GSO{}, hdr, payload, tcp.ProtocolNumber, 42)
+ for i := 0; i < len(ft.packetCollectorErrors)-1; i++ {
+ if got, want := ft.packetCollectorErrors[i], (*tcpip.Error)(nil); got != want {
+ t.Errorf("ft.packetCollectorErrors[%d] got %v, want %v", i, got, want)
+ }
+ }
+ // We only need to check that last error because all the ones before are
+ // nil.
+ if got, want := err, ft.packetCollectorErrors[len(ft.packetCollectorErrors)-1]; got != want {
+ t.Errorf("err got %v, want %v", got, want)
+ }
+ if got, want := c.linkEP.Drain(), int(c.Route.Stats().IP.PacketsSent.Value())+1; err != nil && got != want {
+ t.Errorf("after linkEP error len(result) got %d, want %d", got, want)
+ }
+ })
+ }
+}
diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go
index 6e3ba5922..e341bb4aa 100644
--- a/pkg/tcpip/transport/tcp/tcp_test.go
+++ b/pkg/tcpip/transport/tcp/tcp_test.go
@@ -2381,7 +2381,7 @@ func TestFinWithPartialAck(t *testing.T) {
}
func TestExponentialIncreaseDuringSlowStart(t *testing.T) {
- maxPayload := 10
+ maxPayload := 32
c := context.New(t, uint32(header.TCPMinimumSize+header.IPv4MinimumSize+maxPayload))
defer c.Cleanup()
@@ -2423,7 +2423,7 @@ func TestExponentialIncreaseDuringSlowStart(t *testing.T) {
}
func TestCongestionAvoidance(t *testing.T) {
- maxPayload := 10
+ maxPayload := 32
c := context.New(t, uint32(header.TCPMinimumSize+header.IPv4MinimumSize+maxPayload))
defer c.Cleanup()
@@ -2525,7 +2525,7 @@ func cubicCwnd(origCwnd int, wMax int, congEventTime time.Time, sRTT time.Durati
}
func TestCubicCongestionAvoidance(t *testing.T) {
- maxPayload := 10
+ maxPayload := 32
c := context.New(t, uint32(header.TCPMinimumSize+header.IPv4MinimumSize+maxPayload))
defer c.Cleanup()
@@ -2636,7 +2636,7 @@ func TestCubicCongestionAvoidance(t *testing.T) {
}
func TestFastRecovery(t *testing.T) {
- maxPayload := 10
+ maxPayload := 32
c := context.New(t, uint32(header.TCPMinimumSize+header.IPv4MinimumSize+maxPayload))
defer c.Cleanup()
@@ -2788,7 +2788,7 @@ func TestFastRecovery(t *testing.T) {
}
func TestRetransmit(t *testing.T) {
- maxPayload := 10
+ maxPayload := 32
c := context.New(t, uint32(header.TCPMinimumSize+header.IPv4MinimumSize+maxPayload))
defer c.Cleanup()