diff options
Diffstat (limited to 'pkg/tcpip/network/ipv4')
-rw-r--r-- | pkg/tcpip/network/ipv4/BUILD | 8 | ||||
-rw-r--r-- | pkg/tcpip/network/ipv4/icmp.go | 51 | ||||
-rw-r--r-- | pkg/tcpip/network/ipv4/ipv4.go | 242 | ||||
-rw-r--r-- | pkg/tcpip/network/ipv4/ipv4_test.go | 49 |
4 files changed, 220 insertions, 130 deletions
diff --git a/pkg/tcpip/network/ipv4/BUILD b/pkg/tcpip/network/ipv4/BUILD index 58e537aad..880ea7de2 100644 --- a/pkg/tcpip/network/ipv4/BUILD +++ b/pkg/tcpip/network/ipv4/BUILD @@ -1,5 +1,4 @@ -load("//tools/go_stateify:defs.bzl", "go_library") -load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools:defs.bzl", "go_library", "go_test") package(licenses = ["notice"]) @@ -9,10 +8,7 @@ go_library( "icmp.go", "ipv4.go", ], - importpath = "gvisor.dev/gvisor/pkg/tcpip/network/ipv4", - visibility = [ - "//visibility:public", - ], + visibility = ["//visibility:public"], deps = [ "//pkg/tcpip", "//pkg/tcpip/buffer", diff --git a/pkg/tcpip/network/ipv4/icmp.go b/pkg/tcpip/network/ipv4/icmp.go index 50b363dc4..4cbefe5ab 100644 --- a/pkg/tcpip/network/ipv4/icmp.go +++ b/pkg/tcpip/network/ipv4/icmp.go @@ -24,8 +24,12 @@ import ( // the original packet that caused the ICMP one to be sent. This information is // used to find out which transport endpoint must be notified about the ICMP // packet. -func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, vv buffer.VectorisedView) { - h := header.IPv4(vv.First()) +func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, pkt stack.PacketBuffer) { + h, ok := pkt.Data.PullUp(header.IPv4MinimumSize) + if !ok { + return + } + hdr := header.IPv4(h) // We don't use IsValid() here because ICMP only requires that the IP // header plus 8 bytes of the transport header be included. So it's @@ -34,12 +38,12 @@ func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, vv buffer. // // Drop packet if it doesn't have the basic IPv4 header or if the // original source address doesn't match the endpoint's address. - if len(h) < header.IPv4MinimumSize || h.SourceAddress() != e.id.LocalAddress { + if hdr.SourceAddress() != e.id.LocalAddress { return } - hlen := int(h.HeaderLength()) - if vv.Size() < hlen || h.FragmentOffset() != 0 { + hlen := int(hdr.HeaderLength()) + if pkt.Data.Size() < hlen || hdr.FragmentOffset() != 0 { // We won't be able to handle this if it doesn't contain the // full IPv4 header, or if it's a fragment not at offset 0 // (because it won't have the transport header). @@ -47,16 +51,16 @@ func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, vv buffer. } // Skip the ip header, then deliver control message. - vv.TrimFront(hlen) - p := h.TransportProtocol() - e.dispatcher.DeliverTransportControlPacket(e.id.LocalAddress, h.DestinationAddress(), ProtocolNumber, p, typ, extra, vv) + pkt.Data.TrimFront(hlen) + p := hdr.TransportProtocol() + e.dispatcher.DeliverTransportControlPacket(e.id.LocalAddress, hdr.DestinationAddress(), ProtocolNumber, p, typ, extra, pkt) } -func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, vv buffer.VectorisedView) { +func (e *endpoint) handleICMP(r *stack.Route, pkt stack.PacketBuffer) { stats := r.Stats() received := stats.ICMP.V4PacketsReceived - v := vv.First() - if len(v) < header.ICMPv4MinimumSize { + v, ok := pkt.Data.PullUp(header.ICMPv4MinimumSize) + if !ok { received.Invalid.Increment() return } @@ -73,20 +77,23 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, vv buffer.V // checksum. We'll have to reset this before we hand the packet // off. h.SetChecksum(0) - gotChecksum := ^header.ChecksumVV(vv, 0 /* initial */) + gotChecksum := ^header.ChecksumVV(pkt.Data, 0 /* initial */) if gotChecksum != wantChecksum { // It's possible that a raw socket expects to receive this. h.SetChecksum(wantChecksum) - e.dispatcher.DeliverTransportPacket(r, header.ICMPv4ProtocolNumber, netHeader, vv) + e.dispatcher.DeliverTransportPacket(r, header.ICMPv4ProtocolNumber, pkt) received.Invalid.Increment() return } // It's possible that a raw socket expects to receive this. h.SetChecksum(wantChecksum) - e.dispatcher.DeliverTransportPacket(r, header.ICMPv4ProtocolNumber, netHeader, vv) + e.dispatcher.DeliverTransportPacket(r, header.ICMPv4ProtocolNumber, stack.PacketBuffer{ + Data: pkt.Data.Clone(nil), + NetworkHeader: append(buffer.View(nil), pkt.NetworkHeader...), + }) - vv := vv.Clone(nil) + vv := pkt.Data.Clone(nil) vv.TrimFront(header.ICMPv4MinimumSize) hdr := buffer.NewPrependable(int(r.MaxHeaderLength()) + header.ICMPv4MinimumSize) pkt := header.ICMPv4(hdr.Prepend(header.ICMPv4MinimumSize)) @@ -95,7 +102,11 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, vv buffer.V pkt.SetChecksum(0) pkt.SetChecksum(^header.Checksum(pkt, header.ChecksumVV(vv, 0))) sent := stats.ICMP.V4PacketsSent - if err := r.WritePacket(nil /* gso */, hdr, vv, stack.NetworkHeaderParams{Protocol: header.ICMPv4ProtocolNumber, TTL: r.DefaultTTL(), TOS: stack.DefaultTOS}); err != nil { + if err := r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv4ProtocolNumber, TTL: r.DefaultTTL(), TOS: stack.DefaultTOS}, stack.PacketBuffer{ + Header: hdr, + Data: vv, + TransportHeader: buffer.View(pkt), + }); err != nil { sent.Dropped.Increment() return } @@ -104,19 +115,19 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, vv buffer.V case header.ICMPv4EchoReply: received.EchoReply.Increment() - e.dispatcher.DeliverTransportPacket(r, header.ICMPv4ProtocolNumber, netHeader, vv) + e.dispatcher.DeliverTransportPacket(r, header.ICMPv4ProtocolNumber, pkt) case header.ICMPv4DstUnreachable: received.DstUnreachable.Increment() - vv.TrimFront(header.ICMPv4MinimumSize) + pkt.Data.TrimFront(header.ICMPv4MinimumSize) switch h.Code() { case header.ICMPv4PortUnreachable: - e.handleControl(stack.ControlPortUnreachable, 0, vv) + e.handleControl(stack.ControlPortUnreachable, 0, pkt) case header.ICMPv4FragmentationNeeded: mtu := uint32(h.MTU()) - e.handleControl(stack.ControlPacketTooBig, calculateMTU(mtu), vv) + e.handleControl(stack.ControlPacketTooBig, calculateMTU(mtu), pkt) } case header.ICMPv4SrcQuench: diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go index 1339f8474..17202cc7a 100644 --- a/pkg/tcpip/network/ipv4/ipv4.go +++ b/pkg/tcpip/network/ipv4/ipv4.go @@ -47,25 +47,27 @@ const ( ) type endpoint struct { - nicid tcpip.NICID + nicID tcpip.NICID id stack.NetworkEndpointID prefixLen int linkEP stack.LinkEndpoint dispatcher stack.TransportDispatcher fragmentation *fragmentation.Fragmentation protocol *protocol + stack *stack.Stack } // NewEndpoint creates a new ipv4 endpoint. -func (p *protocol) NewEndpoint(nicid tcpip.NICID, addrWithPrefix tcpip.AddressWithPrefix, linkAddrCache stack.LinkAddressCache, dispatcher stack.TransportDispatcher, linkEP stack.LinkEndpoint) (stack.NetworkEndpoint, *tcpip.Error) { +func (p *protocol) NewEndpoint(nicID tcpip.NICID, addrWithPrefix tcpip.AddressWithPrefix, linkAddrCache stack.LinkAddressCache, dispatcher stack.TransportDispatcher, linkEP stack.LinkEndpoint, st *stack.Stack) (stack.NetworkEndpoint, *tcpip.Error) { e := &endpoint{ - nicid: nicid, + nicID: nicID, id: stack.NetworkEndpointID{LocalAddress: addrWithPrefix.Address}, prefixLen: addrWithPrefix.PrefixLen, linkEP: linkEP, dispatcher: dispatcher, fragmentation: fragmentation.NewFragmentation(fragmentation.HighFragThreshold, fragmentation.LowFragThreshold, fragmentation.DefaultReassembleTimeout), protocol: p, + stack: st, } return e, nil @@ -89,7 +91,7 @@ func (e *endpoint) Capabilities() stack.LinkEndpointCapabilities { // NICID returns the ID of the NIC this endpoint belongs to. func (e *endpoint) NICID() tcpip.NICID { - return e.nicid + return e.nicID } // ID returns the ipv4 endpoint ID. @@ -117,13 +119,14 @@ func (e *endpoint) GSOMaxSize() uint32 { } // writePacketFragments calls e.linkEP.WritePacket with each packet fragment to -// write. It assumes that the IP header is entirely in hdr but does not assume -// that only the IP header is in hdr. It assumes that the input packet's stated -// length matches the length of the hdr+payload. mtu includes the IP header and -// options. This does not support the DontFragment IP flag. -func (e *endpoint) writePacketFragments(r *stack.Route, gso *stack.GSO, hdr buffer.Prependable, payload buffer.VectorisedView, mtu int) *tcpip.Error { +// write. It assumes that the IP header is entirely in pkt.Header but does not +// assume that only the IP header is in pkt.Header. It assumes that the input +// packet's stated length matches the length of the header+payload. mtu +// includes the IP header and options. This does not support the DontFragment +// IP flag. +func (e *endpoint) writePacketFragments(r *stack.Route, gso *stack.GSO, mtu int, pkt stack.PacketBuffer) *tcpip.Error { // This packet is too big, it needs to be fragmented. - ip := header.IPv4(hdr.View()) + ip := header.IPv4(pkt.Header.View()) flags := ip.Flags() // Update mtu to take into account the header, which will exist in all @@ -137,68 +140,83 @@ func (e *endpoint) writePacketFragments(r *stack.Route, gso *stack.GSO, hdr buff outerMTU := innerMTU + int(ip.HeaderLength()) offset := ip.FragmentOffset() - originalAvailableLength := hdr.AvailableLength() + originalAvailableLength := pkt.Header.AvailableLength() for i := 0; i < n; i++ { // Where possible, the first fragment that is sent has the same - // hdr.UsedLength() as the input packet. The link-layer endpoint may depends - // on this for looking at, eg, L4 headers. + // pkt.Header.UsedLength() as the input packet. The link-layer + // endpoint may depend on this for looking at, eg, L4 headers. h := ip if i > 0 { - hdr = buffer.NewPrependable(int(ip.HeaderLength()) + originalAvailableLength) - h = header.IPv4(hdr.Prepend(int(ip.HeaderLength()))) + pkt.Header = buffer.NewPrependable(int(ip.HeaderLength()) + originalAvailableLength) + h = header.IPv4(pkt.Header.Prepend(int(ip.HeaderLength()))) copy(h, ip[:ip.HeaderLength()]) } if i != n-1 { h.SetTotalLength(uint16(outerMTU)) h.SetFlagsFragmentOffset(flags|header.IPv4FlagMoreFragments, offset) } else { - h.SetTotalLength(uint16(h.HeaderLength()) + uint16(payload.Size())) + h.SetTotalLength(uint16(h.HeaderLength()) + uint16(pkt.Data.Size())) h.SetFlagsFragmentOffset(flags, offset) } h.SetChecksum(0) h.SetChecksum(^h.CalculateChecksum()) offset += uint16(innerMTU) if i > 0 { - newPayload := payload.Clone([]buffer.View{}) + newPayload := pkt.Data.Clone(nil) newPayload.CapLength(innerMTU) - if err := e.linkEP.WritePacket(r, gso, hdr, newPayload, ProtocolNumber); err != nil { + if err := e.linkEP.WritePacket(r, gso, ProtocolNumber, stack.PacketBuffer{ + Header: pkt.Header, + Data: newPayload, + NetworkHeader: buffer.View(h), + }); err != nil { return err } r.Stats().IP.PacketsSent.Increment() - payload.TrimFront(newPayload.Size()) + pkt.Data.TrimFront(newPayload.Size()) continue } - // Special handling for the first fragment because it comes from the hdr. - if outerMTU >= hdr.UsedLength() { - // This fragment can fit all of hdr and possibly some of payload, too. - newPayload := payload.Clone([]buffer.View{}) - newPayloadLength := outerMTU - hdr.UsedLength() + // Special handling for the first fragment because it comes + // from the header. + if outerMTU >= pkt.Header.UsedLength() { + // This fragment can fit all of pkt.Header and possibly + // some of pkt.Data, too. + newPayload := pkt.Data.Clone(nil) + newPayloadLength := outerMTU - pkt.Header.UsedLength() newPayload.CapLength(newPayloadLength) - if err := e.linkEP.WritePacket(r, gso, hdr, newPayload, ProtocolNumber); err != nil { + if err := e.linkEP.WritePacket(r, gso, ProtocolNumber, stack.PacketBuffer{ + Header: pkt.Header, + Data: newPayload, + NetworkHeader: buffer.View(h), + }); err != nil { return err } r.Stats().IP.PacketsSent.Increment() - payload.TrimFront(newPayloadLength) + pkt.Data.TrimFront(newPayloadLength) } else { - // The fragment is too small to fit all of hdr. - startOfHdr := hdr - startOfHdr.TrimBack(hdr.UsedLength() - outerMTU) + // The fragment is too small to fit all of pkt.Header. + startOfHdr := pkt.Header + startOfHdr.TrimBack(pkt.Header.UsedLength() - outerMTU) emptyVV := buffer.NewVectorisedView(0, []buffer.View{}) - if err := e.linkEP.WritePacket(r, gso, startOfHdr, emptyVV, ProtocolNumber); err != nil { + if err := e.linkEP.WritePacket(r, gso, ProtocolNumber, stack.PacketBuffer{ + Header: startOfHdr, + Data: emptyVV, + NetworkHeader: buffer.View(h), + }); err != nil { return err } r.Stats().IP.PacketsSent.Increment() - // Add the unused bytes of hdr into the payload that remains to be sent. - restOfHdr := hdr.View()[outerMTU:] + // Add the unused bytes of pkt.Header into the pkt.Data + // that remains to be sent. + restOfHdr := pkt.Header.View()[outerMTU:] tmp := buffer.NewVectorisedView(len(restOfHdr), []buffer.View{buffer.NewViewFromBytes(restOfHdr)}) - tmp.Append(payload) - payload = tmp + tmp.Append(pkt.Data) + pkt.Data = tmp } } return nil } -func (e *endpoint) addIPHeader(r *stack.Route, hdr *buffer.Prependable, payloadSize int, params stack.NetworkHeaderParams) { +func (e *endpoint) addIPHeader(r *stack.Route, hdr *buffer.Prependable, payloadSize int, params stack.NetworkHeaderParams) header.IPv4 { ip := header.IPv4(hdr.Prepend(header.IPv4MinimumSize)) length := uint16(hdr.UsedLength() + payloadSize) id := uint32(0) @@ -218,28 +236,43 @@ func (e *endpoint) addIPHeader(r *stack.Route, hdr *buffer.Prependable, payloadS DstAddr: r.RemoteAddress, }) ip.SetChecksum(^ip.CalculateChecksum()) + return ip } // WritePacket writes a packet to the given destination address and protocol. -func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, hdr buffer.Prependable, payload buffer.VectorisedView, params stack.NetworkHeaderParams, loop stack.PacketLooping) *tcpip.Error { - e.addIPHeader(r, &hdr, payload.Size(), params) - - if loop&stack.PacketLoop != 0 { - views := make([]buffer.View, 1, 1+len(payload.Views())) - views[0] = hdr.View() - views = append(views, payload.Views()...) - vv := buffer.NewVectorisedView(len(views[0])+payload.Size(), views) +func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.NetworkHeaderParams, pkt stack.PacketBuffer) *tcpip.Error { + ip := e.addIPHeader(r, &pkt.Header, pkt.Data.Size(), params) + pkt.NetworkHeader = buffer.View(ip) + + // iptables filtering. All packets that reach here are locally + // generated. + ipt := e.stack.IPTables() + if ok := ipt.Check(stack.Output, pkt); !ok { + // iptables is telling us to drop the packet. + return nil + } + + if r.Loop&stack.PacketLoop != 0 { + // The inbound path expects the network header to still be in + // the PacketBuffer's Data field. + views := make([]buffer.View, 1, 1+len(pkt.Data.Views())) + views[0] = pkt.Header.View() + views = append(views, pkt.Data.Views()...) loopedR := r.MakeLoopedRoute() - e.HandlePacket(&loopedR, vv) + + e.HandlePacket(&loopedR, stack.PacketBuffer{ + Data: buffer.NewVectorisedView(len(views[0])+pkt.Data.Size(), views), + }) + loopedR.Release() } - if loop&stack.PacketOut == 0 { + if r.Loop&stack.PacketOut == 0 { return nil } - if hdr.UsedLength()+payload.Size() > int(e.linkEP.MTU()) && (gso == nil || gso.Type == stack.GSONone) { - return e.writePacketFragments(r, gso, hdr, payload, int(e.linkEP.MTU())) + if pkt.Header.UsedLength()+pkt.Data.Size() > int(e.linkEP.MTU()) && (gso == nil || gso.Type == stack.GSONone) { + return e.writePacketFragments(r, gso, int(e.linkEP.MTU()), pkt) } - if err := e.linkEP.WritePacket(r, gso, hdr, payload, ProtocolNumber); err != nil { + if err := e.linkEP.WritePacket(r, gso, ProtocolNumber, pkt); err != nil { return err } r.Stats().IP.PacketsSent.Increment() @@ -247,34 +280,65 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, hdr buffer.Prepen } // WritePackets implements stack.NetworkEndpoint.WritePackets. -func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, hdrs []stack.PacketDescriptor, payload buffer.VectorisedView, params stack.NetworkHeaderParams, loop stack.PacketLooping) (int, *tcpip.Error) { - if loop&stack.PacketLoop != 0 { +func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.PacketBufferList, params stack.NetworkHeaderParams) (int, *tcpip.Error) { + if r.Loop&stack.PacketLoop != 0 { panic("multiple packets in local loop") } - if loop&stack.PacketOut == 0 { - return len(hdrs), nil + if r.Loop&stack.PacketOut == 0 { + return pkts.Len(), nil } - for i := range hdrs { - e.addIPHeader(r, &hdrs[i].Hdr, hdrs[i].Size, params) + for pkt := pkts.Front(); pkt != nil; { + ip := e.addIPHeader(r, &pkt.Header, pkt.Data.Size(), params) + pkt.NetworkHeader = buffer.View(ip) + pkt = pkt.Next() + } + + // iptables filtering. All packets that reach here are locally + // generated. + ipt := e.stack.IPTables() + dropped := ipt.CheckPackets(stack.Output, pkts) + if len(dropped) == 0 { + // Fast path: If no packets are to be dropped then we can just invoke the + // faster WritePackets API directly. + n, err := e.linkEP.WritePackets(r, gso, pkts, ProtocolNumber) + r.Stats().IP.PacketsSent.IncrementBy(uint64(n)) + return n, err + } + + // Slow Path as we are dropping some packets in the batch degrade to + // emitting one packet at a time. + n := 0 + for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() { + if _, ok := dropped[pkt]; ok { + continue + } + if err := e.linkEP.WritePacket(r, gso, ProtocolNumber, *pkt); err != nil { + r.Stats().IP.PacketsSent.IncrementBy(uint64(n)) + return n, err + } + n++ } - n, err := e.linkEP.WritePackets(r, gso, hdrs, payload, ProtocolNumber) r.Stats().IP.PacketsSent.IncrementBy(uint64(n)) - return n, err + return n, nil } // WriteHeaderIncludedPacket writes a packet already containing a network // header through the given route. -func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, payload buffer.VectorisedView, loop stack.PacketLooping) *tcpip.Error { +func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt stack.PacketBuffer) *tcpip.Error { // The packet already has an IP header, but there are a few required // checks. - ip := header.IPv4(payload.First()) - if !ip.IsValid(payload.Size()) { + h, ok := pkt.Data.PullUp(header.IPv4MinimumSize) + if !ok { + return tcpip.ErrInvalidOptionValue + } + ip := header.IPv4(h) + if !ip.IsValid(pkt.Data.Size()) { return tcpip.ErrInvalidOptionValue } // Always set the total length. - ip.SetTotalLength(uint16(payload.Size())) + ip.SetTotalLength(uint16(pkt.Data.Size())) // Set the source address when zero. if ip.SourceAddress() == tcpip.Address(([]byte{0, 0, 0, 0})) { @@ -288,7 +352,7 @@ func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, payload buffer.Vect // Set the packet ID when zero. if ip.ID() == 0 { id := uint32(0) - if payload.Size() > header.IPv4MaximumHeaderSize+8 { + if pkt.Data.Size() > header.IPv4MaximumHeaderSize+8 { // Packets of 68 bytes or less are required by RFC 791 to not be // fragmented, so we only assign ids to larger packets. id = atomic.AddUint32(&e.protocol.ids[hashRoute(r, 0 /* protocol */, e.protocol.hashIV)%buckets], 1) @@ -300,36 +364,52 @@ func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, payload buffer.Vect ip.SetChecksum(0) ip.SetChecksum(^ip.CalculateChecksum()) - if loop&stack.PacketLoop != 0 { - e.HandlePacket(r, payload) + if r.Loop&stack.PacketLoop != 0 { + e.HandlePacket(r, pkt.Clone()) } - if loop&stack.PacketOut == 0 { + if r.Loop&stack.PacketOut == 0 { return nil } - hdr := buffer.NewPrependableFromView(payload.ToView()) r.Stats().IP.PacketsSent.Increment() - return e.linkEP.WritePacket(r, nil /* gso */, hdr, buffer.VectorisedView{}, ProtocolNumber) + + ip = ip[:ip.HeaderLength()] + pkt.Header = buffer.NewPrependableFromView(buffer.View(ip)) + pkt.Data.TrimFront(int(ip.HeaderLength())) + return e.linkEP.WritePacket(r, nil /* gso */, ProtocolNumber, pkt) } // HandlePacket is called by the link layer when new ipv4 packets arrive for // this endpoint. -func (e *endpoint) HandlePacket(r *stack.Route, vv buffer.VectorisedView) { - headerView := vv.First() +func (e *endpoint) HandlePacket(r *stack.Route, pkt stack.PacketBuffer) { + headerView, ok := pkt.Data.PullUp(header.IPv4MinimumSize) + if !ok { + r.Stats().IP.MalformedPacketsReceived.Increment() + return + } h := header.IPv4(headerView) - if !h.IsValid(vv.Size()) { + if !h.IsValid(pkt.Data.Size()) { r.Stats().IP.MalformedPacketsReceived.Increment() return } + pkt.NetworkHeader = headerView[:h.HeaderLength()] hlen := int(h.HeaderLength()) tlen := int(h.TotalLength()) - vv.TrimFront(hlen) - vv.CapLength(tlen - hlen) + pkt.Data.TrimFront(hlen) + pkt.Data.CapLength(tlen - hlen) + + // iptables filtering. All packets that reach here are intended for + // this machine and will not be forwarded. + ipt := e.stack.IPTables() + if ok := ipt.Check(stack.Input, pkt); !ok { + // iptables is telling us to drop the packet. + return + } more := (h.Flags() & header.IPv4FlagMoreFragments) != 0 if more || h.FragmentOffset() != 0 { - if vv.Size() == 0 { + if pkt.Data.Size() == 0 { // Drop the packet as it's marked as a fragment but has // no payload. r.Stats().IP.MalformedPacketsReceived.Increment() @@ -337,10 +417,10 @@ func (e *endpoint) HandlePacket(r *stack.Route, vv buffer.VectorisedView) { return } // The packet is a fragment, let's try to reassemble it. - last := h.FragmentOffset() + uint16(vv.Size()) - 1 + last := h.FragmentOffset() + uint16(pkt.Data.Size()) - 1 // Drop the packet if the fragmentOffset is incorrect. i.e the - // combination of fragmentOffset and vv.size() causes a wrap - // around resulting in last being less than the offset. + // combination of fragmentOffset and pkt.Data.size() causes a + // wrap around resulting in last being less than the offset. if last < h.FragmentOffset() { r.Stats().IP.MalformedPacketsReceived.Increment() r.Stats().IP.MalformedFragmentsReceived.Increment() @@ -348,7 +428,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, vv buffer.VectorisedView) { } var ready bool var err error - vv, ready, err = e.fragmentation.Process(hash.IPv4FragmentHash(h), h.FragmentOffset(), last, more, vv) + pkt.Data, ready, err = e.fragmentation.Process(hash.IPv4FragmentHash(h), h.FragmentOffset(), last, more, pkt.Data) if err != nil { r.Stats().IP.MalformedPacketsReceived.Increment() r.Stats().IP.MalformedFragmentsReceived.Increment() @@ -361,11 +441,11 @@ func (e *endpoint) HandlePacket(r *stack.Route, vv buffer.VectorisedView) { p := h.TransportProtocol() if p == header.ICMPv4ProtocolNumber { headerView.CapLength(hlen) - e.handleICMP(r, headerView, vv) + e.handleICMP(r, pkt) return } r.Stats().IP.PacketsDelivered.Increment() - e.dispatcher.DeliverTransportPacket(r, p, headerView, vv) + e.dispatcher.DeliverTransportPacket(r, p, pkt) } // Close cleans up resources associated with the endpoint. @@ -434,6 +514,12 @@ func (p *protocol) DefaultTTL() uint8 { return uint8(atomic.LoadUint32(&p.defaultTTL)) } +// Close implements stack.TransportProtocol.Close. +func (*protocol) Close() {} + +// Wait implements stack.TransportProtocol.Wait. +func (*protocol) Wait() {} + // calculateMTU calculates the network-layer payload MTU based on the link-layer // payload mtu. func calculateMTU(mtu uint32) uint32 { diff --git a/pkg/tcpip/network/ipv4/ipv4_test.go b/pkg/tcpip/network/ipv4/ipv4_test.go index 99f84acd7..5a864d832 100644 --- a/pkg/tcpip/network/ipv4/ipv4_test.go +++ b/pkg/tcpip/network/ipv4/ipv4_test.go @@ -113,12 +113,12 @@ func makeHdrAndPayload(hdrLength int, extraLength int, viewSizes []int) (buffer. // comparePayloads compared the contents of all the packets against the contents // of the source packet. -func compareFragments(t *testing.T, packets []packetInfo, sourcePacketInfo packetInfo, mtu uint32) { +func compareFragments(t *testing.T, packets []stack.PacketBuffer, sourcePacketInfo stack.PacketBuffer, mtu uint32) { t.Helper() // Make a complete array of the sourcePacketInfo packet. source := header.IPv4(packets[0].Header.View()[:header.IPv4MinimumSize]) source = append(source, sourcePacketInfo.Header.View()...) - source = append(source, sourcePacketInfo.Payload.ToView()...) + source = append(source, sourcePacketInfo.Data.ToView()...) // Make a copy of the IP header, which will be modified in some fields to make // an expected header. @@ -132,7 +132,7 @@ func compareFragments(t *testing.T, packets []packetInfo, sourcePacketInfo packe for i, packet := range packets { // Confirm that the packet is valid. allBytes := packet.Header.View().ToVectorisedView() - allBytes.Append(packet.Payload) + allBytes.Append(packet.Data) ip := header.IPv4(allBytes.ToView()) if !ip.IsValid(len(ip)) { t.Errorf("IP packet is invalid:\n%s", hex.Dump(ip)) @@ -173,7 +173,7 @@ func compareFragments(t *testing.T, packets []packetInfo, sourcePacketInfo packe type errorChannel struct { *channel.Endpoint - Ch chan packetInfo + Ch chan stack.PacketBuffer packetCollectorErrors []*tcpip.Error } @@ -183,17 +183,11 @@ type errorChannel struct { func newErrorChannel(size int, mtu uint32, linkAddr tcpip.LinkAddress, packetCollectorErrors []*tcpip.Error) *errorChannel { return &errorChannel{ Endpoint: channel.New(size, mtu, linkAddr), - Ch: make(chan packetInfo, size), + Ch: make(chan stack.PacketBuffer, size), packetCollectorErrors: packetCollectorErrors, } } -// packetInfo holds all the information about an outbound packet. -type packetInfo struct { - Header buffer.Prependable - Payload buffer.VectorisedView -} - // Drain removes all outbound packets from the channel and counts them. func (e *errorChannel) Drain() int { c := 0 @@ -208,14 +202,9 @@ func (e *errorChannel) Drain() int { } // WritePacket stores outbound packets into the channel. -func (e *errorChannel) WritePacket(r *stack.Route, gso *stack.GSO, hdr buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.NetworkProtocolNumber) *tcpip.Error { - p := packetInfo{ - Header: hdr, - Payload: payload, - } - +func (e *errorChannel) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt stack.PacketBuffer) *tcpip.Error { select { - case e.Ch <- p: + case e.Ch <- pkt: default: } @@ -292,18 +281,21 @@ func TestFragmentation(t *testing.T) { for _, ft := range fragTests { t.Run(ft.description, func(t *testing.T) { hdr, payload := makeHdrAndPayload(ft.hdrLength, ft.extraLength, ft.payloadViewsSizes) - source := packetInfo{ + source := stack.PacketBuffer{ Header: hdr, // Save the source payload because WritePacket will modify it. - Payload: payload.Clone([]buffer.View{}), + Data: payload.Clone(nil), } c := buildContext(t, nil, ft.mtu) - err := c.Route.WritePacket(ft.gso, hdr, payload, stack.NetworkHeaderParams{Protocol: tcp.ProtocolNumber, TTL: 42, TOS: stack.DefaultTOS}) + err := c.Route.WritePacket(ft.gso, stack.NetworkHeaderParams{Protocol: tcp.ProtocolNumber, TTL: 42, TOS: stack.DefaultTOS}, stack.PacketBuffer{ + Header: hdr, + Data: payload, + }) if err != nil { t.Errorf("err got %v, want %v", err, nil) } - var results []packetInfo + var results []stack.PacketBuffer L: for { select { @@ -345,7 +337,10 @@ func TestFragmentationErrors(t *testing.T) { t.Run(ft.description, func(t *testing.T) { hdr, payload := makeHdrAndPayload(ft.hdrLength, header.IPv4MinimumSize, ft.payloadViewsSizes) c := buildContext(t, ft.packetCollectorErrors, ft.mtu) - err := c.Route.WritePacket(&stack.GSO{}, hdr, payload, stack.NetworkHeaderParams{Protocol: tcp.ProtocolNumber, TTL: 42, TOS: stack.DefaultTOS}) + err := c.Route.WritePacket(&stack.GSO{}, stack.NetworkHeaderParams{Protocol: tcp.ProtocolNumber, TTL: 42, TOS: stack.DefaultTOS}, stack.PacketBuffer{ + Header: hdr, + Data: payload, + }) for i := 0; i < len(ft.packetCollectorErrors)-1; i++ { if got, want := ft.packetCollectorErrors[i], (*tcpip.Error)(nil); got != want { t.Errorf("ft.packetCollectorErrors[%d] got %v, want %v", i, got, want) @@ -451,7 +446,7 @@ func TestInvalidFragments(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { - const nicid tcpip.NICID = 42 + const nicID tcpip.NICID = 42 s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocol{ ipv4.NewProtocol(), @@ -461,10 +456,12 @@ func TestInvalidFragments(t *testing.T) { var linkAddr = tcpip.LinkAddress([]byte{0x30, 0x30, 0x30, 0x30, 0x30, 0x30}) var remoteLinkAddr = tcpip.LinkAddress([]byte{0x30, 0x30, 0x30, 0x30, 0x30, 0x31}) ep := channel.New(10, 1500, linkAddr) - s.CreateNIC(nicid, sniffer.New(ep)) + s.CreateNIC(nicID, sniffer.New(ep)) for _, pkt := range tc.packets { - ep.InjectLinkAddr(header.IPv4ProtocolNumber, remoteLinkAddr, buffer.NewVectorisedView(len(pkt), []buffer.View{pkt})) + ep.InjectLinkAddr(header.IPv4ProtocolNumber, remoteLinkAddr, stack.PacketBuffer{ + Data: buffer.NewVectorisedView(len(pkt), []buffer.View{pkt}), + }) } if got, want := s.Stats().IP.MalformedPacketsReceived.Value(), tc.wantMalformedIPPackets; got != want { |