diff options
author | Googler <noreply@google.com> | 2019-05-03 13:29:20 -0700 |
---|---|---|
committer | Shentubot <shentubot@google.com> | 2019-05-03 13:30:35 -0700 |
commit | f2699b76c89a5be1ef6411f29a57b4cccc59fa17 (patch) | |
tree | 6e5ec5a4520b98fee3551d0baa16f59db69bc42e /pkg | |
parent | 264d012d81d210c6d949554667c6fbf8e330587a (diff) |
Support IPv4 fragmentation in netstack
Testing:
Unit tests and also large ping in Fuchsia OS
PiperOrigin-RevId: 246563592
Change-Id: Ia12ab619f64f4be2c8d346ce81341a91724aef95
Diffstat (limited to 'pkg')
-rw-r--r-- | pkg/tcpip/buffer/prependable.go | 10 | ||||
-rw-r--r-- | pkg/tcpip/link/sniffer/sniffer.go | 108 | ||||
-rw-r--r-- | pkg/tcpip/network/ipv4/BUILD | 2 | ||||
-rw-r--r-- | pkg/tcpip/network/ipv4/ipv4.go | 91 | ||||
-rw-r--r-- | pkg/tcpip/network/ipv4/ipv4_test.go | 270 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/tcp_test.go | 10 |
6 files changed, 435 insertions, 56 deletions
diff --git a/pkg/tcpip/buffer/prependable.go b/pkg/tcpip/buffer/prependable.go index 43cbbc74c..4287464f3 100644 --- a/pkg/tcpip/buffer/prependable.go +++ b/pkg/tcpip/buffer/prependable.go @@ -52,6 +52,16 @@ func (p Prependable) UsedLength() int { return len(p.buf) - p.usedIdx } +// AvailableLength returns the number of bytes used so far. +func (p Prependable) AvailableLength() int { + return p.usedIdx +} + +// TrimBack removes size bytes from the end. +func (p *Prependable) TrimBack(size int) { + p.buf = p.buf[:len(p.buf)-size] +} + // Prepend reserves the requested space in front of the buffer, returning a // slice that represents the reserved space. func (p *Prependable) Prepend(size int) []byte { diff --git a/pkg/tcpip/link/sniffer/sniffer.go b/pkg/tcpip/link/sniffer/sniffer.go index e87ae07d7..fccabd554 100644 --- a/pkg/tcpip/link/sniffer/sniffer.go +++ b/pkg/tcpip/link/sniffer/sniffer.go @@ -247,9 +247,13 @@ func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, b buffer.Vie dst := tcpip.Address("unknown") id := 0 size := uint16(0) + var fragmentOffset uint16 + var moreFragments bool switch protocol { case header.IPv4ProtocolNumber: ipv4 := header.IPv4(b) + fragmentOffset = ipv4.FragmentOffset() + moreFragments = ipv4.Flags()&header.IPv4FlagMoreFragments == header.IPv4FlagMoreFragments src = ipv4.SourceAddress() dst = ipv4.DestinationAddress() transProto = ipv4.Protocol() @@ -290,29 +294,31 @@ func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, b buffer.Vie transName = "icmp" icmp := header.ICMPv4(b) icmpType := "unknown" - switch icmp.Type() { - case header.ICMPv4EchoReply: - icmpType = "echo reply" - case header.ICMPv4DstUnreachable: - icmpType = "destination unreachable" - case header.ICMPv4SrcQuench: - icmpType = "source quench" - case header.ICMPv4Redirect: - icmpType = "redirect" - case header.ICMPv4Echo: - icmpType = "echo" - case header.ICMPv4TimeExceeded: - icmpType = "time exceeded" - case header.ICMPv4ParamProblem: - icmpType = "param problem" - case header.ICMPv4Timestamp: - icmpType = "timestamp" - case header.ICMPv4TimestampReply: - icmpType = "timestamp reply" - case header.ICMPv4InfoRequest: - icmpType = "info request" - case header.ICMPv4InfoReply: - icmpType = "info reply" + if fragmentOffset == 0 { + switch icmp.Type() { + case header.ICMPv4EchoReply: + icmpType = "echo reply" + case header.ICMPv4DstUnreachable: + icmpType = "destination unreachable" + case header.ICMPv4SrcQuench: + icmpType = "source quench" + case header.ICMPv4Redirect: + icmpType = "redirect" + case header.ICMPv4Echo: + icmpType = "echo" + case header.ICMPv4TimeExceeded: + icmpType = "time exceeded" + case header.ICMPv4ParamProblem: + icmpType = "param problem" + case header.ICMPv4Timestamp: + icmpType = "timestamp" + case header.ICMPv4TimestampReply: + icmpType = "timestamp reply" + case header.ICMPv4InfoRequest: + icmpType = "info request" + case header.ICMPv4InfoReply: + icmpType = "info reply" + } } log.Infof("%s %s %v -> %v %s len:%d id:%04x code:%d", prefix, transName, src, dst, icmpType, size, id, icmp.Code()) return @@ -351,8 +357,10 @@ func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, b buffer.Vie case header.UDPProtocolNumber: transName = "udp" udp := header.UDP(b) - srcPort = udp.SourcePort() - dstPort = udp.DestinationPort() + if fragmentOffset == 0 && len(udp) >= header.UDPMinimumSize { + srcPort = udp.SourcePort() + dstPort = udp.DestinationPort() + } size -= header.UDPMinimumSize details = fmt.Sprintf("xsum: 0x%x", udp.Checksum()) @@ -360,33 +368,35 @@ func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, b buffer.Vie case header.TCPProtocolNumber: transName = "tcp" tcp := header.TCP(b) - offset := int(tcp.DataOffset()) - if offset < header.TCPMinimumSize { - details += fmt.Sprintf("invalid packet: tcp data offset too small %d", offset) - break - } - if offset > len(tcp) { - details += fmt.Sprintf("invalid packet: tcp data offset %d larger than packet buffer length %d", offset, len(tcp)) - break - } + if fragmentOffset == 0 && len(tcp) >= header.TCPMinimumSize { + offset := int(tcp.DataOffset()) + if offset < header.TCPMinimumSize { + details += fmt.Sprintf("invalid packet: tcp data offset too small %d", offset) + break + } + if offset > len(tcp) && !moreFragments { + details += fmt.Sprintf("invalid packet: tcp data offset %d larger than packet buffer length %d", offset, len(tcp)) + break + } - srcPort = tcp.SourcePort() - dstPort = tcp.DestinationPort() - size -= uint16(offset) + srcPort = tcp.SourcePort() + dstPort = tcp.DestinationPort() + size -= uint16(offset) - // Initialize the TCP flags. - flags := tcp.Flags() - flagsStr := []byte("FSRPAU") - for i := range flagsStr { - if flags&(1<<uint(i)) == 0 { - flagsStr[i] = ' ' + // Initialize the TCP flags. + flags := tcp.Flags() + flagsStr := []byte("FSRPAU") + for i := range flagsStr { + if flags&(1<<uint(i)) == 0 { + flagsStr[i] = ' ' + } + } + details = fmt.Sprintf("flags:0x%02x (%v) seqnum: %v ack: %v win: %v xsum:0x%x", flags, string(flagsStr), tcp.SequenceNumber(), tcp.AckNumber(), tcp.WindowSize(), tcp.Checksum()) + if flags&header.TCPFlagSyn != 0 { + details += fmt.Sprintf(" options: %+v", header.ParseSynOptions(tcp.Options(), flags&header.TCPFlagAck != 0)) + } else { + details += fmt.Sprintf(" options: %+v", tcp.ParsedOptions()) } - } - details = fmt.Sprintf("flags:0x%02x (%v) seqnum: %v ack: %v win: %v xsum:0x%x", flags, string(flagsStr), tcp.SequenceNumber(), tcp.AckNumber(), tcp.WindowSize(), tcp.Checksum()) - if flags&header.TCPFlagSyn != 0 { - details += fmt.Sprintf(" options: %+v", header.ParseSynOptions(tcp.Options(), flags&header.TCPFlagAck != 0)) - } else { - details += fmt.Sprintf(" options: %+v", tcp.ParsedOptions()) } default: diff --git a/pkg/tcpip/network/ipv4/BUILD b/pkg/tcpip/network/ipv4/BUILD index 7a5341def..1b4f29e0c 100644 --- a/pkg/tcpip/network/ipv4/BUILD +++ b/pkg/tcpip/network/ipv4/BUILD @@ -28,11 +28,13 @@ go_test( srcs = ["ipv4_test.go"], deps = [ "//pkg/tcpip", + "//pkg/tcpip/buffer", "//pkg/tcpip/header", "//pkg/tcpip/link/channel", "//pkg/tcpip/link/sniffer", "//pkg/tcpip/network/ipv4", "//pkg/tcpip/stack", + "//pkg/tcpip/transport/tcp", "//pkg/tcpip/transport/udp", "//pkg/waiter", ], diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go index c6af0db79..4edc52f19 100644 --- a/pkg/tcpip/network/ipv4/ipv4.go +++ b/pkg/tcpip/network/ipv4/ipv4.go @@ -107,6 +107,88 @@ func (e *endpoint) GSOMaxSize() uint32 { return 0 } +// writePacketFragments calls e.linkEP.WritePacket with each packet fragment to +// write. It assumes that the IP header is entirely in hdr but does not assume +// that only the IP header is in hdr. It assumes that the input packet's stated +// length matches the length of the hdr+payload. mtu includes the IP header and +// options. This does not support the DontFragment IP flag. +func (e *endpoint) writePacketFragments(r *stack.Route, gso *stack.GSO, hdr buffer.Prependable, payload buffer.VectorisedView, mtu int) *tcpip.Error { + // This packet is too big, it needs to be fragmented. + ip := header.IPv4(hdr.View()) + flags := ip.Flags() + + // Update mtu to take into account the header, which will exist in all + // fragments anyway. + innerMTU := mtu - int(ip.HeaderLength()) + + // Round the MTU down to align to 8 bytes. Then calculate the number of + // fragments. Calculate fragment sizes as in RFC791. + innerMTU &^= 7 + n := (int(ip.PayloadLength()) + innerMTU - 1) / innerMTU + + outerMTU := innerMTU + int(ip.HeaderLength()) + offset := ip.FragmentOffset() + originalAvailableLength := hdr.AvailableLength() + for i := 0; i < n; i++ { + // Where possible, the first fragment that is sent has the same + // hdr.UsedLength() as the input packet. The link-layer endpoint may depends + // on this for looking at, eg, L4 headers. + h := ip + if i > 0 { + hdr = buffer.NewPrependable(int(ip.HeaderLength()) + originalAvailableLength) + h = header.IPv4(hdr.Prepend(int(ip.HeaderLength()))) + copy(h, ip[:ip.HeaderLength()]) + } + if i != n-1 { + h.SetTotalLength(uint16(outerMTU)) + h.SetFlagsFragmentOffset(flags|header.IPv4FlagMoreFragments, offset) + } else { + h.SetTotalLength(uint16(h.HeaderLength()) + uint16(payload.Size())) + h.SetFlagsFragmentOffset(flags, offset) + } + h.SetChecksum(0) + h.SetChecksum(^h.CalculateChecksum()) + offset += uint16(innerMTU) + if i > 0 { + newPayload := payload.Clone([]buffer.View{}) + newPayload.CapLength(innerMTU) + if err := e.linkEP.WritePacket(r, gso, hdr, newPayload, ProtocolNumber); err != nil { + return err + } + r.Stats().IP.PacketsSent.Increment() + payload.TrimFront(newPayload.Size()) + continue + } + // Special handling for the first fragment because it comes from the hdr. + if outerMTU >= hdr.UsedLength() { + // This fragment can fit all of hdr and possibly some of payload, too. + newPayload := payload.Clone([]buffer.View{}) + newPayloadLength := outerMTU - hdr.UsedLength() + newPayload.CapLength(newPayloadLength) + if err := e.linkEP.WritePacket(r, gso, hdr, newPayload, ProtocolNumber); err != nil { + return err + } + r.Stats().IP.PacketsSent.Increment() + payload.TrimFront(newPayloadLength) + } else { + // The fragment is too small to fit all of hdr. + startOfHdr := hdr + startOfHdr.TrimBack(hdr.UsedLength() - outerMTU) + emptyVV := buffer.NewVectorisedView(0, []buffer.View{}) + if err := e.linkEP.WritePacket(r, gso, startOfHdr, emptyVV, ProtocolNumber); err != nil { + return err + } + r.Stats().IP.PacketsSent.Increment() + // Add the unused bytes of hdr into the payload that remains to be sent. + restOfHdr := hdr.View()[outerMTU:] + tmp := buffer.NewVectorisedView(len(restOfHdr), []buffer.View{buffer.NewViewFromBytes(restOfHdr)}) + tmp.Append(payload) + payload = tmp + } + } + return nil +} + // WritePacket writes a packet to the given destination address and protocol. func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, hdr buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.TransportProtocolNumber, ttl uint8, loop stack.PacketLooping) *tcpip.Error { ip := header.IPv4(hdr.Prepend(header.IPv4MinimumSize)) @@ -138,9 +220,14 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, hdr buffer.Prepen if loop&stack.PacketOut == 0 { return nil } - + if hdr.UsedLength()+payload.Size() > int(e.linkEP.MTU()) && gso.Type == stack.GSONone { + return e.writePacketFragments(r, gso, hdr, payload, int(e.linkEP.MTU())) + } + if err := e.linkEP.WritePacket(r, gso, hdr, payload, ProtocolNumber); err != nil { + return err + } r.Stats().IP.PacketsSent.Increment() - return e.linkEP.WritePacket(r, gso, hdr, payload, ProtocolNumber) + return nil } // HandlePacket is called by the link layer when new ipv4 packets arrive for diff --git a/pkg/tcpip/network/ipv4/ipv4_test.go b/pkg/tcpip/network/ipv4/ipv4_test.go index 146143ab3..7a09ef6de 100644 --- a/pkg/tcpip/network/ipv4/ipv4_test.go +++ b/pkg/tcpip/network/ipv4/ipv4_test.go @@ -15,14 +15,19 @@ package ipv4_test import ( + "bytes" + "encoding/hex" + "math/rand" "testing" "gvisor.googlesource.com/gvisor/pkg/tcpip" + "gvisor.googlesource.com/gvisor/pkg/tcpip/buffer" "gvisor.googlesource.com/gvisor/pkg/tcpip/header" "gvisor.googlesource.com/gvisor/pkg/tcpip/link/channel" "gvisor.googlesource.com/gvisor/pkg/tcpip/link/sniffer" "gvisor.googlesource.com/gvisor/pkg/tcpip/network/ipv4" "gvisor.googlesource.com/gvisor/pkg/tcpip/stack" + "gvisor.googlesource.com/gvisor/pkg/tcpip/transport/tcp" "gvisor.googlesource.com/gvisor/pkg/tcpip/transport/udp" "gvisor.googlesource.com/gvisor/pkg/waiter" ) @@ -90,3 +95,268 @@ func TestExcludeBroadcast(t *testing.T) { } }) } + +// makeHdrAndPayload generates a randomize packet. hdrLength indicates how much +// data should already be in the header before WritePacket. extraLength +// indicates how much extra space should be in the header. The payload is made +// from many Views of the sizes listed in viewSizes. +func makeHdrAndPayload(hdrLength int, extraLength int, viewSizes []int) (buffer.Prependable, buffer.VectorisedView) { + hdr := buffer.NewPrependable(hdrLength + extraLength) + hdr.Prepend(hdrLength) + rand.Read(hdr.View()) + + var views []buffer.View + totalLength := 0 + for _, s := range viewSizes { + newView := buffer.NewView(s) + rand.Read(newView) + views = append(views, newView) + totalLength += s + } + payload := buffer.NewVectorisedView(totalLength, views) + return hdr, payload +} + +// comparePayloads compared the contents of all the packets against the contents +// of the source packet. +func compareFragments(t *testing.T, packets []packetInfo, sourcePacketInfo packetInfo, mtu uint32) { + t.Helper() + // Make a complete array of the sourcePacketInfo packet. + source := header.IPv4(packets[0].Header.View()[:header.IPv4MinimumSize]) + source = append(source, sourcePacketInfo.Header.View()...) + source = append(source, sourcePacketInfo.Payload.ToView()...) + + // Make a copy of the IP header, which will be modified in some fields to make + // an expected header. + sourceCopy := header.IPv4(append(buffer.View(nil), source[:source.HeaderLength()]...)) + sourceCopy.SetChecksum(0) + sourceCopy.SetFlagsFragmentOffset(0, 0) + sourceCopy.SetTotalLength(0) + var offset uint16 + // Build up an array of the bytes sent. + var reassembledPayload []byte + for i, packet := range packets { + // Confirm that the packet is valid. + allBytes := packet.Header.View().ToVectorisedView() + allBytes.Append(packet.Payload) + ip := header.IPv4(allBytes.ToView()) + if !ip.IsValid(len(ip)) { + t.Errorf("IP packet is invalid:\n%s", hex.Dump(ip)) + } + if got, want := ip.CalculateChecksum(), uint16(0xffff); got != want { + t.Errorf("ip.CalculateChecksum() got %#x, want %#x", got, want) + } + if got, want := len(ip), int(mtu); got > want { + t.Errorf("fragment is too large, got %d want %d", got, want) + } + if got, want := packet.Header.UsedLength(), sourcePacketInfo.Header.UsedLength()+header.IPv4MinimumSize; i == 0 && want < int(mtu) && got != want { + t.Errorf("first fragment hdr parts should have unmodified length if possible: got %d, want %d", got, want) + } + if got, want := packet.Header.AvailableLength(), sourcePacketInfo.Header.AvailableLength()-header.IPv4MinimumSize; got != want { + t.Errorf("fragment #%d should have the same available space for prepending as source: got %d, want %d", i, got, want) + } + if i < len(packets)-1 { + sourceCopy.SetFlagsFragmentOffset(sourceCopy.Flags()|header.IPv4FlagMoreFragments, offset) + } else { + sourceCopy.SetFlagsFragmentOffset(sourceCopy.Flags()&^header.IPv4FlagMoreFragments, offset) + } + reassembledPayload = append(reassembledPayload, ip.Payload()...) + offset += ip.TotalLength() - uint16(ip.HeaderLength()) + // Clear out the checksum and length from the ip because we can't compare + // it. + sourceCopy.SetTotalLength(uint16(len(ip))) + sourceCopy.SetChecksum(0) + sourceCopy.SetChecksum(^sourceCopy.CalculateChecksum()) + if !bytes.Equal(ip[:ip.HeaderLength()], sourceCopy[:sourceCopy.HeaderLength()]) { + t.Errorf("ip[:ip.HeaderLength()] got:\n%s\nwant:\n%s", hex.Dump(ip[:ip.HeaderLength()]), hex.Dump(sourceCopy[:sourceCopy.HeaderLength()])) + } + } + expected := source[source.HeaderLength():] + if !bytes.Equal(reassembledPayload, expected) { + t.Errorf("reassembledPayload got:\n%s\nwant:\n%s", hex.Dump(reassembledPayload), hex.Dump(expected)) + } +} + +type errorChannel struct { + *channel.Endpoint + Ch chan packetInfo + packetCollectorErrors []*tcpip.Error +} + +// newErrorChannel creates a new errorChannel endpoint. Each call to WritePacket +// will return successive errors from packetCollectorErrors until the list is +// empty and then return nil each time. +func newErrorChannel(size int, mtu uint32, linkAddr tcpip.LinkAddress, packetCollectorErrors []*tcpip.Error) (tcpip.LinkEndpointID, *errorChannel) { + _, e := channel.New(size, mtu, linkAddr) + ec := errorChannel{ + Endpoint: e, + Ch: make(chan packetInfo, size), + packetCollectorErrors: packetCollectorErrors, + } + + return stack.RegisterLinkEndpoint(e), &ec +} + +// packetInfo holds all the information about an outbound packet. +type packetInfo struct { + Header buffer.Prependable + Payload buffer.VectorisedView +} + +// Drain removes all outbound packets from the channel and counts them. +func (e *errorChannel) Drain() int { + c := 0 + for { + select { + case <-e.Ch: + c++ + default: + return c + } + } +} + +// WritePacket stores outbound packets into the channel. +func (e *errorChannel) WritePacket(r *stack.Route, gso *stack.GSO, hdr buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.NetworkProtocolNumber) *tcpip.Error { + p := packetInfo{ + Header: hdr, + Payload: payload, + } + + select { + case e.Ch <- p: + default: + } + + nextError := (*tcpip.Error)(nil) + if len(e.packetCollectorErrors) > 0 { + nextError = e.packetCollectorErrors[0] + e.packetCollectorErrors = e.packetCollectorErrors[1:] + } + return nextError +} + +type context struct { + stack.Route + linkEP *errorChannel +} + +func buildContext(t *testing.T, packetCollectorErrors []*tcpip.Error, mtu uint32) context { + // Make the packet and write it. + s := stack.New([]string{ipv4.ProtocolName}, []string{}, stack.Options{}) + _, linkEP := newErrorChannel(100 /* Enough for all tests. */, mtu, "", packetCollectorErrors) + linkEPId := stack.RegisterLinkEndpoint(linkEP) + s.CreateNIC(1, linkEPId) + s.AddAddress(1, ipv4.ProtocolNumber, "\x10\x00\x00\x01") + s.SetRouteTable([]tcpip.Route{{ + Destination: "\x10\x00\x00\x02", + Mask: "\xff\xff\xff\xff", + Gateway: "", + NIC: 1, + }}) + r, err := s.FindRoute(0, "\x10\x00\x00\x01", "\x10\x00\x00\x02", ipv4.ProtocolNumber, false /* multicastLoop */) + if err != nil { + t.Fatalf("s.FindRoute got %v, want %v", err, nil) + } + return context{ + Route: r, + linkEP: linkEP, + } +} + +func TestFragmentation(t *testing.T) { + var manyPayloadViewsSizes [1000]int + for i := range manyPayloadViewsSizes { + manyPayloadViewsSizes[i] = 7 + } + fragTests := []struct { + description string + mtu uint32 + hdrLength int + extraLength int + payloadViewsSizes []int + expectedFrags int + }{ + {"NoFragmentation", 2000, 0, header.IPv4MinimumSize, []int{1000}, 1}, + {"NoFragmentationWithBigHeader", 2000, 16, header.IPv4MinimumSize, []int{1000}, 1}, + {"Fragmented", 800, 0, header.IPv4MinimumSize, []int{1000}, 2}, + {"FragmentedWithManyViews", 300, 0, header.IPv4MinimumSize, manyPayloadViewsSizes[:], 25}, + {"FragmentedWithManyViewsAndPrependableBytes", 300, 0, header.IPv4MinimumSize + 55, manyPayloadViewsSizes[:], 25}, + {"FragmentedWithBigHeader", 800, 20, header.IPv4MinimumSize, []int{1000}, 2}, + {"FragmentedWithBigHeaderAndPrependableBytes", 800, 20, header.IPv4MinimumSize + 66, []int{1000}, 2}, + {"FragmentedWithMTUSmallerThanHeaderAndPrependableBytes", 300, 1000, header.IPv4MinimumSize + 77, []int{500}, 6}, + } + + for _, ft := range fragTests { + t.Run(ft.description, func(t *testing.T) { + hdr, payload := makeHdrAndPayload(ft.hdrLength, ft.extraLength, ft.payloadViewsSizes) + source := packetInfo{ + Header: hdr, + // Save the source payload because WritePacket will modify it. + Payload: payload.Clone([]buffer.View{}), + } + c := buildContext(t, nil, ft.mtu) + err := c.Route.WritePacket(&stack.GSO{}, hdr, payload, tcp.ProtocolNumber, 42) + if err != nil { + t.Errorf("err got %v, want %v", err, nil) + } + + var results []packetInfo + L: + for { + select { + case pi := <-c.linkEP.Ch: + results = append(results, pi) + default: + break L + } + } + + if got, want := len(results), ft.expectedFrags; got != want { + t.Errorf("len(result) got %d, want %d", got, want) + } + if got, want := len(results), int(c.Route.Stats().IP.PacketsSent.Value()); got != want { + t.Errorf("no errors yet len(result) got %d, want %d", got, want) + } + compareFragments(t, results, source, ft.mtu) + }) + } +} + +// TestFragmentationErrors checks that errors are returned from write packet +// correctly. +func TestFragmentationErrors(t *testing.T) { + fragTests := []struct { + description string + mtu uint32 + hdrLength int + payloadViewsSizes []int + packetCollectorErrors []*tcpip.Error + }{ + {"NoFrag", 2000, 0, []int{1000}, []*tcpip.Error{tcpip.ErrAborted}}, + {"ErrorOnFirstFrag", 500, 0, []int{1000}, []*tcpip.Error{tcpip.ErrAborted}}, + {"ErrorOnSecondFrag", 500, 0, []int{1000}, []*tcpip.Error{nil, tcpip.ErrAborted}}, + {"ErrorOnFirstFragMTUSmallerThanHdr", 500, 1000, []int{500}, []*tcpip.Error{tcpip.ErrAborted}}, + } + + for _, ft := range fragTests { + t.Run(ft.description, func(t *testing.T) { + hdr, payload := makeHdrAndPayload(ft.hdrLength, header.IPv4MinimumSize, ft.payloadViewsSizes) + c := buildContext(t, ft.packetCollectorErrors, ft.mtu) + err := c.Route.WritePacket(&stack.GSO{}, hdr, payload, tcp.ProtocolNumber, 42) + for i := 0; i < len(ft.packetCollectorErrors)-1; i++ { + if got, want := ft.packetCollectorErrors[i], (*tcpip.Error)(nil); got != want { + t.Errorf("ft.packetCollectorErrors[%d] got %v, want %v", i, got, want) + } + } + // We only need to check that last error because all the ones before are + // nil. + if got, want := err, ft.packetCollectorErrors[len(ft.packetCollectorErrors)-1]; got != want { + t.Errorf("err got %v, want %v", got, want) + } + if got, want := c.linkEP.Drain(), int(c.Route.Stats().IP.PacketsSent.Value())+1; err != nil && got != want { + t.Errorf("after linkEP error len(result) got %d, want %d", got, want) + } + }) + } +} diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go index 6e3ba5922..e341bb4aa 100644 --- a/pkg/tcpip/transport/tcp/tcp_test.go +++ b/pkg/tcpip/transport/tcp/tcp_test.go @@ -2381,7 +2381,7 @@ func TestFinWithPartialAck(t *testing.T) { } func TestExponentialIncreaseDuringSlowStart(t *testing.T) { - maxPayload := 10 + maxPayload := 32 c := context.New(t, uint32(header.TCPMinimumSize+header.IPv4MinimumSize+maxPayload)) defer c.Cleanup() @@ -2423,7 +2423,7 @@ func TestExponentialIncreaseDuringSlowStart(t *testing.T) { } func TestCongestionAvoidance(t *testing.T) { - maxPayload := 10 + maxPayload := 32 c := context.New(t, uint32(header.TCPMinimumSize+header.IPv4MinimumSize+maxPayload)) defer c.Cleanup() @@ -2525,7 +2525,7 @@ func cubicCwnd(origCwnd int, wMax int, congEventTime time.Time, sRTT time.Durati } func TestCubicCongestionAvoidance(t *testing.T) { - maxPayload := 10 + maxPayload := 32 c := context.New(t, uint32(header.TCPMinimumSize+header.IPv4MinimumSize+maxPayload)) defer c.Cleanup() @@ -2636,7 +2636,7 @@ func TestCubicCongestionAvoidance(t *testing.T) { } func TestFastRecovery(t *testing.T) { - maxPayload := 10 + maxPayload := 32 c := context.New(t, uint32(header.TCPMinimumSize+header.IPv4MinimumSize+maxPayload)) defer c.Cleanup() @@ -2788,7 +2788,7 @@ func TestFastRecovery(t *testing.T) { } func TestRetransmit(t *testing.T) { - maxPayload := 10 + maxPayload := 32 c := context.New(t, uint32(header.TCPMinimumSize+header.IPv4MinimumSize+maxPayload)) defer c.Cleanup() |