summaryrefslogtreecommitdiffhomepage
path: root/pkg/tcpip/network/ipv4
diff options
context:
space:
mode:
authorGoogler <noreply@google.com>2019-05-03 13:29:20 -0700
committerShentubot <shentubot@google.com>2019-05-03 13:30:35 -0700
commitf2699b76c89a5be1ef6411f29a57b4cccc59fa17 (patch)
tree6e5ec5a4520b98fee3551d0baa16f59db69bc42e /pkg/tcpip/network/ipv4
parent264d012d81d210c6d949554667c6fbf8e330587a (diff)
Support IPv4 fragmentation in netstack
Testing: Unit tests and also large ping in Fuchsia OS PiperOrigin-RevId: 246563592 Change-Id: Ia12ab619f64f4be2c8d346ce81341a91724aef95
Diffstat (limited to 'pkg/tcpip/network/ipv4')
-rw-r--r--pkg/tcpip/network/ipv4/BUILD2
-rw-r--r--pkg/tcpip/network/ipv4/ipv4.go91
-rw-r--r--pkg/tcpip/network/ipv4/ipv4_test.go270
3 files changed, 361 insertions, 2 deletions
diff --git a/pkg/tcpip/network/ipv4/BUILD b/pkg/tcpip/network/ipv4/BUILD
index 7a5341def..1b4f29e0c 100644
--- a/pkg/tcpip/network/ipv4/BUILD
+++ b/pkg/tcpip/network/ipv4/BUILD
@@ -28,11 +28,13 @@ go_test(
srcs = ["ipv4_test.go"],
deps = [
"//pkg/tcpip",
+ "//pkg/tcpip/buffer",
"//pkg/tcpip/header",
"//pkg/tcpip/link/channel",
"//pkg/tcpip/link/sniffer",
"//pkg/tcpip/network/ipv4",
"//pkg/tcpip/stack",
+ "//pkg/tcpip/transport/tcp",
"//pkg/tcpip/transport/udp",
"//pkg/waiter",
],
diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go
index c6af0db79..4edc52f19 100644
--- a/pkg/tcpip/network/ipv4/ipv4.go
+++ b/pkg/tcpip/network/ipv4/ipv4.go
@@ -107,6 +107,88 @@ func (e *endpoint) GSOMaxSize() uint32 {
return 0
}
+// writePacketFragments calls e.linkEP.WritePacket with each packet fragment to
+// write. It assumes that the IP header is entirely in hdr but does not assume
+// that only the IP header is in hdr. It assumes that the input packet's stated
+// length matches the length of the hdr+payload. mtu includes the IP header and
+// options. This does not support the DontFragment IP flag.
+func (e *endpoint) writePacketFragments(r *stack.Route, gso *stack.GSO, hdr buffer.Prependable, payload buffer.VectorisedView, mtu int) *tcpip.Error {
+ // This packet is too big, it needs to be fragmented.
+ ip := header.IPv4(hdr.View())
+ flags := ip.Flags()
+
+ // Update mtu to take into account the header, which will exist in all
+ // fragments anyway.
+ innerMTU := mtu - int(ip.HeaderLength())
+
+ // Round the MTU down to align to 8 bytes. Then calculate the number of
+ // fragments. Calculate fragment sizes as in RFC791.
+ innerMTU &^= 7
+ n := (int(ip.PayloadLength()) + innerMTU - 1) / innerMTU
+
+ outerMTU := innerMTU + int(ip.HeaderLength())
+ offset := ip.FragmentOffset()
+ originalAvailableLength := hdr.AvailableLength()
+ for i := 0; i < n; i++ {
+ // Where possible, the first fragment that is sent has the same
+ // hdr.UsedLength() as the input packet. The link-layer endpoint may depends
+ // on this for looking at, eg, L4 headers.
+ h := ip
+ if i > 0 {
+ hdr = buffer.NewPrependable(int(ip.HeaderLength()) + originalAvailableLength)
+ h = header.IPv4(hdr.Prepend(int(ip.HeaderLength())))
+ copy(h, ip[:ip.HeaderLength()])
+ }
+ if i != n-1 {
+ h.SetTotalLength(uint16(outerMTU))
+ h.SetFlagsFragmentOffset(flags|header.IPv4FlagMoreFragments, offset)
+ } else {
+ h.SetTotalLength(uint16(h.HeaderLength()) + uint16(payload.Size()))
+ h.SetFlagsFragmentOffset(flags, offset)
+ }
+ h.SetChecksum(0)
+ h.SetChecksum(^h.CalculateChecksum())
+ offset += uint16(innerMTU)
+ if i > 0 {
+ newPayload := payload.Clone([]buffer.View{})
+ newPayload.CapLength(innerMTU)
+ if err := e.linkEP.WritePacket(r, gso, hdr, newPayload, ProtocolNumber); err != nil {
+ return err
+ }
+ r.Stats().IP.PacketsSent.Increment()
+ payload.TrimFront(newPayload.Size())
+ continue
+ }
+ // Special handling for the first fragment because it comes from the hdr.
+ if outerMTU >= hdr.UsedLength() {
+ // This fragment can fit all of hdr and possibly some of payload, too.
+ newPayload := payload.Clone([]buffer.View{})
+ newPayloadLength := outerMTU - hdr.UsedLength()
+ newPayload.CapLength(newPayloadLength)
+ if err := e.linkEP.WritePacket(r, gso, hdr, newPayload, ProtocolNumber); err != nil {
+ return err
+ }
+ r.Stats().IP.PacketsSent.Increment()
+ payload.TrimFront(newPayloadLength)
+ } else {
+ // The fragment is too small to fit all of hdr.
+ startOfHdr := hdr
+ startOfHdr.TrimBack(hdr.UsedLength() - outerMTU)
+ emptyVV := buffer.NewVectorisedView(0, []buffer.View{})
+ if err := e.linkEP.WritePacket(r, gso, startOfHdr, emptyVV, ProtocolNumber); err != nil {
+ return err
+ }
+ r.Stats().IP.PacketsSent.Increment()
+ // Add the unused bytes of hdr into the payload that remains to be sent.
+ restOfHdr := hdr.View()[outerMTU:]
+ tmp := buffer.NewVectorisedView(len(restOfHdr), []buffer.View{buffer.NewViewFromBytes(restOfHdr)})
+ tmp.Append(payload)
+ payload = tmp
+ }
+ }
+ return nil
+}
+
// WritePacket writes a packet to the given destination address and protocol.
func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, hdr buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.TransportProtocolNumber, ttl uint8, loop stack.PacketLooping) *tcpip.Error {
ip := header.IPv4(hdr.Prepend(header.IPv4MinimumSize))
@@ -138,9 +220,14 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, hdr buffer.Prepen
if loop&stack.PacketOut == 0 {
return nil
}
-
+ if hdr.UsedLength()+payload.Size() > int(e.linkEP.MTU()) && gso.Type == stack.GSONone {
+ return e.writePacketFragments(r, gso, hdr, payload, int(e.linkEP.MTU()))
+ }
+ if err := e.linkEP.WritePacket(r, gso, hdr, payload, ProtocolNumber); err != nil {
+ return err
+ }
r.Stats().IP.PacketsSent.Increment()
- return e.linkEP.WritePacket(r, gso, hdr, payload, ProtocolNumber)
+ return nil
}
// HandlePacket is called by the link layer when new ipv4 packets arrive for
diff --git a/pkg/tcpip/network/ipv4/ipv4_test.go b/pkg/tcpip/network/ipv4/ipv4_test.go
index 146143ab3..7a09ef6de 100644
--- a/pkg/tcpip/network/ipv4/ipv4_test.go
+++ b/pkg/tcpip/network/ipv4/ipv4_test.go
@@ -15,14 +15,19 @@
package ipv4_test
import (
+ "bytes"
+ "encoding/hex"
+ "math/rand"
"testing"
"gvisor.googlesource.com/gvisor/pkg/tcpip"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/buffer"
"gvisor.googlesource.com/gvisor/pkg/tcpip/header"
"gvisor.googlesource.com/gvisor/pkg/tcpip/link/channel"
"gvisor.googlesource.com/gvisor/pkg/tcpip/link/sniffer"
"gvisor.googlesource.com/gvisor/pkg/tcpip/network/ipv4"
"gvisor.googlesource.com/gvisor/pkg/tcpip/stack"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/transport/tcp"
"gvisor.googlesource.com/gvisor/pkg/tcpip/transport/udp"
"gvisor.googlesource.com/gvisor/pkg/waiter"
)
@@ -90,3 +95,268 @@ func TestExcludeBroadcast(t *testing.T) {
}
})
}
+
+// makeHdrAndPayload generates a randomize packet. hdrLength indicates how much
+// data should already be in the header before WritePacket. extraLength
+// indicates how much extra space should be in the header. The payload is made
+// from many Views of the sizes listed in viewSizes.
+func makeHdrAndPayload(hdrLength int, extraLength int, viewSizes []int) (buffer.Prependable, buffer.VectorisedView) {
+ hdr := buffer.NewPrependable(hdrLength + extraLength)
+ hdr.Prepend(hdrLength)
+ rand.Read(hdr.View())
+
+ var views []buffer.View
+ totalLength := 0
+ for _, s := range viewSizes {
+ newView := buffer.NewView(s)
+ rand.Read(newView)
+ views = append(views, newView)
+ totalLength += s
+ }
+ payload := buffer.NewVectorisedView(totalLength, views)
+ return hdr, payload
+}
+
+// comparePayloads compared the contents of all the packets against the contents
+// of the source packet.
+func compareFragments(t *testing.T, packets []packetInfo, sourcePacketInfo packetInfo, mtu uint32) {
+ t.Helper()
+ // Make a complete array of the sourcePacketInfo packet.
+ source := header.IPv4(packets[0].Header.View()[:header.IPv4MinimumSize])
+ source = append(source, sourcePacketInfo.Header.View()...)
+ source = append(source, sourcePacketInfo.Payload.ToView()...)
+
+ // Make a copy of the IP header, which will be modified in some fields to make
+ // an expected header.
+ sourceCopy := header.IPv4(append(buffer.View(nil), source[:source.HeaderLength()]...))
+ sourceCopy.SetChecksum(0)
+ sourceCopy.SetFlagsFragmentOffset(0, 0)
+ sourceCopy.SetTotalLength(0)
+ var offset uint16
+ // Build up an array of the bytes sent.
+ var reassembledPayload []byte
+ for i, packet := range packets {
+ // Confirm that the packet is valid.
+ allBytes := packet.Header.View().ToVectorisedView()
+ allBytes.Append(packet.Payload)
+ ip := header.IPv4(allBytes.ToView())
+ if !ip.IsValid(len(ip)) {
+ t.Errorf("IP packet is invalid:\n%s", hex.Dump(ip))
+ }
+ if got, want := ip.CalculateChecksum(), uint16(0xffff); got != want {
+ t.Errorf("ip.CalculateChecksum() got %#x, want %#x", got, want)
+ }
+ if got, want := len(ip), int(mtu); got > want {
+ t.Errorf("fragment is too large, got %d want %d", got, want)
+ }
+ if got, want := packet.Header.UsedLength(), sourcePacketInfo.Header.UsedLength()+header.IPv4MinimumSize; i == 0 && want < int(mtu) && got != want {
+ t.Errorf("first fragment hdr parts should have unmodified length if possible: got %d, want %d", got, want)
+ }
+ if got, want := packet.Header.AvailableLength(), sourcePacketInfo.Header.AvailableLength()-header.IPv4MinimumSize; got != want {
+ t.Errorf("fragment #%d should have the same available space for prepending as source: got %d, want %d", i, got, want)
+ }
+ if i < len(packets)-1 {
+ sourceCopy.SetFlagsFragmentOffset(sourceCopy.Flags()|header.IPv4FlagMoreFragments, offset)
+ } else {
+ sourceCopy.SetFlagsFragmentOffset(sourceCopy.Flags()&^header.IPv4FlagMoreFragments, offset)
+ }
+ reassembledPayload = append(reassembledPayload, ip.Payload()...)
+ offset += ip.TotalLength() - uint16(ip.HeaderLength())
+ // Clear out the checksum and length from the ip because we can't compare
+ // it.
+ sourceCopy.SetTotalLength(uint16(len(ip)))
+ sourceCopy.SetChecksum(0)
+ sourceCopy.SetChecksum(^sourceCopy.CalculateChecksum())
+ if !bytes.Equal(ip[:ip.HeaderLength()], sourceCopy[:sourceCopy.HeaderLength()]) {
+ t.Errorf("ip[:ip.HeaderLength()] got:\n%s\nwant:\n%s", hex.Dump(ip[:ip.HeaderLength()]), hex.Dump(sourceCopy[:sourceCopy.HeaderLength()]))
+ }
+ }
+ expected := source[source.HeaderLength():]
+ if !bytes.Equal(reassembledPayload, expected) {
+ t.Errorf("reassembledPayload got:\n%s\nwant:\n%s", hex.Dump(reassembledPayload), hex.Dump(expected))
+ }
+}
+
+type errorChannel struct {
+ *channel.Endpoint
+ Ch chan packetInfo
+ packetCollectorErrors []*tcpip.Error
+}
+
+// newErrorChannel creates a new errorChannel endpoint. Each call to WritePacket
+// will return successive errors from packetCollectorErrors until the list is
+// empty and then return nil each time.
+func newErrorChannel(size int, mtu uint32, linkAddr tcpip.LinkAddress, packetCollectorErrors []*tcpip.Error) (tcpip.LinkEndpointID, *errorChannel) {
+ _, e := channel.New(size, mtu, linkAddr)
+ ec := errorChannel{
+ Endpoint: e,
+ Ch: make(chan packetInfo, size),
+ packetCollectorErrors: packetCollectorErrors,
+ }
+
+ return stack.RegisterLinkEndpoint(e), &ec
+}
+
+// packetInfo holds all the information about an outbound packet.
+type packetInfo struct {
+ Header buffer.Prependable
+ Payload buffer.VectorisedView
+}
+
+// Drain removes all outbound packets from the channel and counts them.
+func (e *errorChannel) Drain() int {
+ c := 0
+ for {
+ select {
+ case <-e.Ch:
+ c++
+ default:
+ return c
+ }
+ }
+}
+
+// WritePacket stores outbound packets into the channel.
+func (e *errorChannel) WritePacket(r *stack.Route, gso *stack.GSO, hdr buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.NetworkProtocolNumber) *tcpip.Error {
+ p := packetInfo{
+ Header: hdr,
+ Payload: payload,
+ }
+
+ select {
+ case e.Ch <- p:
+ default:
+ }
+
+ nextError := (*tcpip.Error)(nil)
+ if len(e.packetCollectorErrors) > 0 {
+ nextError = e.packetCollectorErrors[0]
+ e.packetCollectorErrors = e.packetCollectorErrors[1:]
+ }
+ return nextError
+}
+
+type context struct {
+ stack.Route
+ linkEP *errorChannel
+}
+
+func buildContext(t *testing.T, packetCollectorErrors []*tcpip.Error, mtu uint32) context {
+ // Make the packet and write it.
+ s := stack.New([]string{ipv4.ProtocolName}, []string{}, stack.Options{})
+ _, linkEP := newErrorChannel(100 /* Enough for all tests. */, mtu, "", packetCollectorErrors)
+ linkEPId := stack.RegisterLinkEndpoint(linkEP)
+ s.CreateNIC(1, linkEPId)
+ s.AddAddress(1, ipv4.ProtocolNumber, "\x10\x00\x00\x01")
+ s.SetRouteTable([]tcpip.Route{{
+ Destination: "\x10\x00\x00\x02",
+ Mask: "\xff\xff\xff\xff",
+ Gateway: "",
+ NIC: 1,
+ }})
+ r, err := s.FindRoute(0, "\x10\x00\x00\x01", "\x10\x00\x00\x02", ipv4.ProtocolNumber, false /* multicastLoop */)
+ if err != nil {
+ t.Fatalf("s.FindRoute got %v, want %v", err, nil)
+ }
+ return context{
+ Route: r,
+ linkEP: linkEP,
+ }
+}
+
+func TestFragmentation(t *testing.T) {
+ var manyPayloadViewsSizes [1000]int
+ for i := range manyPayloadViewsSizes {
+ manyPayloadViewsSizes[i] = 7
+ }
+ fragTests := []struct {
+ description string
+ mtu uint32
+ hdrLength int
+ extraLength int
+ payloadViewsSizes []int
+ expectedFrags int
+ }{
+ {"NoFragmentation", 2000, 0, header.IPv4MinimumSize, []int{1000}, 1},
+ {"NoFragmentationWithBigHeader", 2000, 16, header.IPv4MinimumSize, []int{1000}, 1},
+ {"Fragmented", 800, 0, header.IPv4MinimumSize, []int{1000}, 2},
+ {"FragmentedWithManyViews", 300, 0, header.IPv4MinimumSize, manyPayloadViewsSizes[:], 25},
+ {"FragmentedWithManyViewsAndPrependableBytes", 300, 0, header.IPv4MinimumSize + 55, manyPayloadViewsSizes[:], 25},
+ {"FragmentedWithBigHeader", 800, 20, header.IPv4MinimumSize, []int{1000}, 2},
+ {"FragmentedWithBigHeaderAndPrependableBytes", 800, 20, header.IPv4MinimumSize + 66, []int{1000}, 2},
+ {"FragmentedWithMTUSmallerThanHeaderAndPrependableBytes", 300, 1000, header.IPv4MinimumSize + 77, []int{500}, 6},
+ }
+
+ for _, ft := range fragTests {
+ t.Run(ft.description, func(t *testing.T) {
+ hdr, payload := makeHdrAndPayload(ft.hdrLength, ft.extraLength, ft.payloadViewsSizes)
+ source := packetInfo{
+ Header: hdr,
+ // Save the source payload because WritePacket will modify it.
+ Payload: payload.Clone([]buffer.View{}),
+ }
+ c := buildContext(t, nil, ft.mtu)
+ err := c.Route.WritePacket(&stack.GSO{}, hdr, payload, tcp.ProtocolNumber, 42)
+ if err != nil {
+ t.Errorf("err got %v, want %v", err, nil)
+ }
+
+ var results []packetInfo
+ L:
+ for {
+ select {
+ case pi := <-c.linkEP.Ch:
+ results = append(results, pi)
+ default:
+ break L
+ }
+ }
+
+ if got, want := len(results), ft.expectedFrags; got != want {
+ t.Errorf("len(result) got %d, want %d", got, want)
+ }
+ if got, want := len(results), int(c.Route.Stats().IP.PacketsSent.Value()); got != want {
+ t.Errorf("no errors yet len(result) got %d, want %d", got, want)
+ }
+ compareFragments(t, results, source, ft.mtu)
+ })
+ }
+}
+
+// TestFragmentationErrors checks that errors are returned from write packet
+// correctly.
+func TestFragmentationErrors(t *testing.T) {
+ fragTests := []struct {
+ description string
+ mtu uint32
+ hdrLength int
+ payloadViewsSizes []int
+ packetCollectorErrors []*tcpip.Error
+ }{
+ {"NoFrag", 2000, 0, []int{1000}, []*tcpip.Error{tcpip.ErrAborted}},
+ {"ErrorOnFirstFrag", 500, 0, []int{1000}, []*tcpip.Error{tcpip.ErrAborted}},
+ {"ErrorOnSecondFrag", 500, 0, []int{1000}, []*tcpip.Error{nil, tcpip.ErrAborted}},
+ {"ErrorOnFirstFragMTUSmallerThanHdr", 500, 1000, []int{500}, []*tcpip.Error{tcpip.ErrAborted}},
+ }
+
+ for _, ft := range fragTests {
+ t.Run(ft.description, func(t *testing.T) {
+ hdr, payload := makeHdrAndPayload(ft.hdrLength, header.IPv4MinimumSize, ft.payloadViewsSizes)
+ c := buildContext(t, ft.packetCollectorErrors, ft.mtu)
+ err := c.Route.WritePacket(&stack.GSO{}, hdr, payload, tcp.ProtocolNumber, 42)
+ for i := 0; i < len(ft.packetCollectorErrors)-1; i++ {
+ if got, want := ft.packetCollectorErrors[i], (*tcpip.Error)(nil); got != want {
+ t.Errorf("ft.packetCollectorErrors[%d] got %v, want %v", i, got, want)
+ }
+ }
+ // We only need to check that last error because all the ones before are
+ // nil.
+ if got, want := err, ft.packetCollectorErrors[len(ft.packetCollectorErrors)-1]; got != want {
+ t.Errorf("err got %v, want %v", got, want)
+ }
+ if got, want := c.linkEP.Drain(), int(c.Route.Stats().IP.PacketsSent.Value())+1; err != nil && got != want {
+ t.Errorf("after linkEP error len(result) got %d, want %d", got, want)
+ }
+ })
+ }
+}