diff options
Diffstat (limited to 'pkg/tcpip')
35 files changed, 964 insertions, 391 deletions
diff --git a/pkg/tcpip/adapters/gonet/BUILD b/pkg/tcpip/adapters/gonet/BUILD index e57d45f2a..a984f1712 100644 --- a/pkg/tcpip/adapters/gonet/BUILD +++ b/pkg/tcpip/adapters/gonet/BUILD @@ -22,7 +22,6 @@ go_test( size = "small", srcs = ["gonet_test.go"], library = ":gonet", - tags = ["flaky"], deps = [ "//pkg/tcpip", "//pkg/tcpip/header", diff --git a/pkg/tcpip/header/ipv4.go b/pkg/tcpip/header/ipv4.go index 76839eb92..62ac932bb 100644 --- a/pkg/tcpip/header/ipv4.go +++ b/pkg/tcpip/header/ipv4.go @@ -159,6 +159,11 @@ func (b IPv4) Flags() uint8 { return uint8(binary.BigEndian.Uint16(b[flagsFO:]) >> 13) } +// More returns whether the more fragments flag is set. +func (b IPv4) More() bool { + return b.Flags()&IPv4FlagMoreFragments != 0 +} + // TTL returns the "TTL" field of the ipv4 header. func (b IPv4) TTL() uint8 { return b[ttl] diff --git a/pkg/tcpip/header/ipv6_extension_headers.go b/pkg/tcpip/header/ipv6_extension_headers.go index 2c4591409..3499d8399 100644 --- a/pkg/tcpip/header/ipv6_extension_headers.go +++ b/pkg/tcpip/header/ipv6_extension_headers.go @@ -354,6 +354,13 @@ func (b IPv6FragmentExtHdr) ID() uint32 { return binary.BigEndian.Uint32(b[ipv6FragmentExtHdrIdentificationOffset:]) } +// IsAtomic returns whether the fragment header indicates an atomic fragment. An +// atomic fragment is a fragment that contains all the data required to +// reassemble a full packet. +func (b IPv6FragmentExtHdr) IsAtomic() bool { + return !b.More() && b.FragmentOffset() == 0 +} + // IPv6PayloadIterator is an iterator over the contents of an IPv6 payload. // // The IPv6 payload may contain IPv6 extension headers before any upper layer diff --git a/pkg/tcpip/network/arp/arp.go b/pkg/tcpip/network/arp/arp.go index ea1acba83..7f27a840d 100644 --- a/pkg/tcpip/network/arp/arp.go +++ b/pkg/tcpip/network/arp/arp.go @@ -99,11 +99,7 @@ func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack.PacketBu } func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { - v, ok := pkt.Data.PullUp(header.ARPSize) - if !ok { - return - } - h := header.ARP(v) + h := header.ARP(pkt.NetworkHeader) if !h.IsValid() { return } @@ -209,6 +205,17 @@ func (*protocol) Close() {} // Wait implements stack.TransportProtocol.Wait. func (*protocol) Wait() {} +// Parse implements stack.NetworkProtocol.Parse. +func (*protocol) Parse(pkt *stack.PacketBuffer) (proto tcpip.TransportProtocolNumber, hasTransportHdr bool, ok bool) { + hdr, ok := pkt.Data.PullUp(header.ARPSize) + if !ok { + return 0, false, false + } + pkt.NetworkHeader = hdr + pkt.Data.TrimFront(header.ARPSize) + return 0, false, true +} + var broadcastMAC = tcpip.LinkAddress([]byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff}) // NewProtocol returns an ARP network protocol. diff --git a/pkg/tcpip/network/fragmentation/fragmentation.go b/pkg/tcpip/network/fragmentation/fragmentation.go index f42abc4bb..2982450f8 100644 --- a/pkg/tcpip/network/fragmentation/fragmentation.go +++ b/pkg/tcpip/network/fragmentation/fragmentation.go @@ -81,8 +81,8 @@ func NewFragmentation(highMemoryLimit, lowMemoryLimit int, reassemblingTimeout t } } -// Process processes an incoming fragment belonging to an ID -// and returns a complete packet when all the packets belonging to that ID have been received. +// Process processes an incoming fragment belonging to an ID and returns a +// complete packet when all the packets belonging to that ID have been received. func (f *Fragmentation) Process(id uint32, first, last uint16, more bool, vv buffer.VectorisedView) (buffer.VectorisedView, bool, error) { f.mu.Lock() r, ok := f.reassemblers[id] diff --git a/pkg/tcpip/network/ip_test.go b/pkg/tcpip/network/ip_test.go index d9b62f2db..7c8fb3e0a 100644 --- a/pkg/tcpip/network/ip_test.go +++ b/pkg/tcpip/network/ip_test.go @@ -293,9 +293,9 @@ func TestIPv4Receive(t *testing.T) { if err != nil { t.Fatalf("could not find route: %v", err) } - ep.HandlePacket(&r, &stack.PacketBuffer{ - Data: view.ToVectorisedView(), - }) + pkt := stack.PacketBuffer{Data: view.ToVectorisedView()} + proto.Parse(&pkt) + ep.HandlePacket(&r, &pkt) if o.dataCalls != 1 { t.Fatalf("Bad number of data calls: got %x, want 1", o.dataCalls) } @@ -382,10 +382,7 @@ func TestIPv4ReceiveControl(t *testing.T) { o.typ = c.expectedTyp o.extra = c.expectedExtra - vv := view[:len(view)-c.trunc].ToVectorisedView() - ep.HandlePacket(&r, &stack.PacketBuffer{ - Data: vv, - }) + ep.HandlePacket(&r, truncatedPacket(view, c.trunc, header.IPv4MinimumSize)) if want := c.expectedCount; o.controlCalls != want { t.Fatalf("Bad number of control calls for %q case: got %v, want %v", c.name, o.controlCalls, want) } @@ -448,17 +445,17 @@ func TestIPv4FragmentationReceive(t *testing.T) { } // Send first segment. - ep.HandlePacket(&r, &stack.PacketBuffer{ - Data: frag1.ToVectorisedView(), - }) + pkt := stack.PacketBuffer{Data: frag1.ToVectorisedView()} + proto.Parse(&pkt) + ep.HandlePacket(&r, &pkt) if o.dataCalls != 0 { t.Fatalf("Bad number of data calls: got %x, want 0", o.dataCalls) } // Send second segment. - ep.HandlePacket(&r, &stack.PacketBuffer{ - Data: frag2.ToVectorisedView(), - }) + pkt = stack.PacketBuffer{Data: frag2.ToVectorisedView()} + proto.Parse(&pkt) + ep.HandlePacket(&r, &pkt) if o.dataCalls != 1 { t.Fatalf("Bad number of data calls: got %x, want 1", o.dataCalls) } @@ -538,9 +535,9 @@ func TestIPv6Receive(t *testing.T) { t.Fatalf("could not find route: %v", err) } - ep.HandlePacket(&r, &stack.PacketBuffer{ - Data: view.ToVectorisedView(), - }) + pkt := stack.PacketBuffer{Data: view.ToVectorisedView()} + proto.Parse(&pkt) + ep.HandlePacket(&r, &pkt) if o.dataCalls != 1 { t.Fatalf("Bad number of data calls: got %x, want 1", o.dataCalls) } @@ -652,12 +649,25 @@ func TestIPv6ReceiveControl(t *testing.T) { // Set ICMPv6 checksum. icmp.SetChecksum(header.ICMPv6Checksum(icmp, outerSrcAddr, localIpv6Addr, buffer.VectorisedView{})) - ep.HandlePacket(&r, &stack.PacketBuffer{ - Data: view[:len(view)-c.trunc].ToVectorisedView(), - }) + ep.HandlePacket(&r, truncatedPacket(view, c.trunc, header.IPv6MinimumSize)) if want := c.expectedCount; o.controlCalls != want { t.Fatalf("Bad number of control calls for %q case: got %v, want %v", c.name, o.controlCalls, want) } }) } } + +// truncatedPacket returns a PacketBuffer based on a truncated view. If view, +// after truncation, is large enough to hold a network header, it makes part of +// view the packet's NetworkHeader and the rest its Data. Otherwise all of view +// becomes Data. +func truncatedPacket(view buffer.View, trunc, netHdrLen int) *stack.PacketBuffer { + v := view[:len(view)-trunc] + if len(v) < netHdrLen { + return &stack.PacketBuffer{Data: v.ToVectorisedView()} + } + return &stack.PacketBuffer{ + NetworkHeader: v[:netHdrLen], + Data: v[netHdrLen:].ToVectorisedView(), + } +} diff --git a/pkg/tcpip/network/ipv4/icmp.go b/pkg/tcpip/network/ipv4/icmp.go index d1c3ae835..1b67aa066 100644 --- a/pkg/tcpip/network/ipv4/icmp.go +++ b/pkg/tcpip/network/ipv4/icmp.go @@ -59,6 +59,9 @@ func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, pkt *stack func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer) { stats := r.Stats() received := stats.ICMP.V4PacketsReceived + // TODO(gvisor.dev/issue/170): ICMP packets don't have their + // TransportHeader fields set. See icmp/protocol.go:protocol.Parse for a + // full explanation. v, ok := pkt.Data.PullUp(header.ICMPv4MinimumSize) if !ok { received.Invalid.Increment() diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go index 9cd7592f4..7e9f16c90 100644 --- a/pkg/tcpip/network/ipv4/ipv4.go +++ b/pkg/tcpip/network/ipv4/ipv4.go @@ -21,6 +21,7 @@ package ipv4 import ( + "fmt" "sync/atomic" "gvisor.dev/gvisor/pkg/tcpip" @@ -258,38 +259,24 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.Netw return nil } + // If the packet is manipulated as per NAT Ouput rules, handle packet + // based on destination address and do not send the packet to link layer. + // TODO(gvisor.dev/issue/170): We should do this for every packet, rather than + // only NATted packets, but removing this check short circuits broadcasts + // before they are sent out to other hosts. if pkt.NatDone { - // If the packet is manipulated as per NAT Ouput rules, handle packet - // based on destination address and do not send the packet to link layer. netHeader := header.IPv4(pkt.NetworkHeader) ep, err := e.stack.FindNetworkEndpoint(header.IPv4ProtocolNumber, netHeader.DestinationAddress()) if err == nil { - src := netHeader.SourceAddress() - dst := netHeader.DestinationAddress() - route := r.ReverseRoute(src, dst) - - views := make([]buffer.View, 1, 1+len(pkt.Data.Views())) - views[0] = pkt.Header.View() - views = append(views, pkt.Data.Views()...) - ep.HandlePacket(&route, &stack.PacketBuffer{ - Data: buffer.NewVectorisedView(len(views[0])+pkt.Data.Size(), views), - }) + route := r.ReverseRoute(netHeader.SourceAddress(), netHeader.DestinationAddress()) + ep.HandlePacket(&route, pkt) return nil } } if r.Loop&stack.PacketLoop != 0 { - // The inbound path expects the network header to still be in - // the PacketBuffer's Data field. - views := make([]buffer.View, 1, 1+len(pkt.Data.Views())) - views[0] = pkt.Header.View() - views = append(views, pkt.Data.Views()...) loopedR := r.MakeLoopedRoute() - - e.HandlePacket(&loopedR, &stack.PacketBuffer{ - Data: buffer.NewVectorisedView(len(views[0])+pkt.Data.Size(), views), - }) - + e.HandlePacket(&loopedR, pkt) loopedR.Release() } if r.Loop&stack.PacketOut == 0 { @@ -342,18 +329,11 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe } if _, ok := natPkts[pkt]; ok { netHeader := header.IPv4(pkt.NetworkHeader) - ep, err := e.stack.FindNetworkEndpoint(header.IPv4ProtocolNumber, netHeader.DestinationAddress()) - if err == nil { + if ep, err := e.stack.FindNetworkEndpoint(header.IPv4ProtocolNumber, netHeader.DestinationAddress()); err == nil { src := netHeader.SourceAddress() dst := netHeader.DestinationAddress() route := r.ReverseRoute(src, dst) - - views := make([]buffer.View, 1, 1+len(pkt.Data.Views())) - views[0] = pkt.Header.View() - views = append(views, pkt.Data.Views()...) - ep.HandlePacket(&route, &stack.PacketBuffer{ - Data: buffer.NewVectorisedView(len(views[0])+pkt.Data.Size(), views), - }) + ep.HandlePacket(&route, pkt) n++ continue } @@ -427,22 +407,11 @@ func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack.PacketBu // HandlePacket is called by the link layer when new ipv4 packets arrive for // this endpoint. func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { - headerView, ok := pkt.Data.PullUp(header.IPv4MinimumSize) - if !ok { + h := header.IPv4(pkt.NetworkHeader) + if !h.IsValid(pkt.Data.Size() + len(pkt.NetworkHeader) + len(pkt.TransportHeader)) { r.Stats().IP.MalformedPacketsReceived.Increment() return } - h := header.IPv4(headerView) - if !h.IsValid(pkt.Data.Size()) { - r.Stats().IP.MalformedPacketsReceived.Increment() - return - } - pkt.NetworkHeader = headerView[:h.HeaderLength()] - - hlen := int(h.HeaderLength()) - tlen := int(h.TotalLength()) - pkt.Data.TrimFront(hlen) - pkt.Data.CapLength(tlen - hlen) // iptables filtering. All packets that reach here are intended for // this machine and will not be forwarded. @@ -452,9 +421,8 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { return } - more := (h.Flags() & header.IPv4FlagMoreFragments) != 0 - if more || h.FragmentOffset() != 0 { - if pkt.Data.Size() == 0 { + if h.More() || h.FragmentOffset() != 0 { + if pkt.Data.Size()+len(pkt.TransportHeader) == 0 { // Drop the packet as it's marked as a fragment but has // no payload. r.Stats().IP.MalformedPacketsReceived.Increment() @@ -473,7 +441,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { } var ready bool var err error - pkt.Data, ready, err = e.fragmentation.Process(hash.IPv4FragmentHash(h), h.FragmentOffset(), last, more, pkt.Data) + pkt.Data, ready, err = e.fragmentation.Process(hash.IPv4FragmentHash(h), h.FragmentOffset(), last, h.More(), pkt.Data) if err != nil { r.Stats().IP.MalformedPacketsReceived.Increment() r.Stats().IP.MalformedFragmentsReceived.Increment() @@ -485,7 +453,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { } p := h.TransportProtocol() if p == header.ICMPv4ProtocolNumber { - headerView.CapLength(hlen) + pkt.NetworkHeader.CapLength(int(h.HeaderLength())) e.handleICMP(r, pkt) return } @@ -565,6 +533,35 @@ func (*protocol) Close() {} // Wait implements stack.TransportProtocol.Wait. func (*protocol) Wait() {} +// Parse implements stack.TransportProtocol.Parse. +func (*protocol) Parse(pkt *stack.PacketBuffer) (proto tcpip.TransportProtocolNumber, hasTransportHdr bool, ok bool) { + hdr, ok := pkt.Data.PullUp(header.IPv4MinimumSize) + if !ok { + return 0, false, false + } + ipHdr := header.IPv4(hdr) + + // If there are options, pull those into hdr as well. + if headerLen := int(ipHdr.HeaderLength()); headerLen > header.IPv4MinimumSize && headerLen <= pkt.Data.Size() { + hdr, ok = pkt.Data.PullUp(headerLen) + if !ok { + panic(fmt.Sprintf("There are only %d bytes in pkt.Data, but there should be at least %d", pkt.Data.Size(), headerLen)) + } + ipHdr = header.IPv4(hdr) + } + + // If this is a fragment, don't bother parsing the transport header. + parseTransportHeader := true + if ipHdr.More() || ipHdr.FragmentOffset() != 0 { + parseTransportHeader = false + } + + pkt.NetworkHeader = hdr + pkt.Data.TrimFront(len(hdr)) + pkt.Data.CapLength(int(ipHdr.TotalLength()) - len(hdr)) + return ipHdr.TransportProtocol(), parseTransportHeader, true +} + // calculateMTU calculates the network-layer payload MTU based on the link-layer // payload mtu. func calculateMTU(mtu uint32) uint32 { diff --git a/pkg/tcpip/network/ipv4/ipv4_test.go b/pkg/tcpip/network/ipv4/ipv4_test.go index c208ebd99..11e579c4b 100644 --- a/pkg/tcpip/network/ipv4/ipv4_test.go +++ b/pkg/tcpip/network/ipv4/ipv4_test.go @@ -652,6 +652,18 @@ func TestReceiveFragments(t *testing.T) { }, expectedPayloads: [][]byte{udpPayload1, udpPayload2}, }, + { + name: "Fragment without followup", + fragments: []fragmentData{ + { + id: 1, + flags: header.IPv4FlagMoreFragments, + fragmentOffset: 0, + payload: ipv4Payload1[:64], + }, + }, + expectedPayloads: nil, + }, } for _, test := range tests { diff --git a/pkg/tcpip/network/ipv6/icmp.go b/pkg/tcpip/network/ipv6/icmp.go index b62fb1de6..2ff7eedf4 100644 --- a/pkg/tcpip/network/ipv6/icmp.go +++ b/pkg/tcpip/network/ipv6/icmp.go @@ -70,17 +70,20 @@ func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, pkt *stack e.dispatcher.DeliverTransportControlPacket(e.id.LocalAddress, hdr.DestinationAddress(), ProtocolNumber, p, typ, extra, pkt) } -func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, pkt *stack.PacketBuffer, hasFragmentHeader bool) { +func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragmentHeader bool) { stats := r.Stats().ICMP sent := stats.V6PacketsSent received := stats.V6PacketsReceived + // TODO(gvisor.dev/issue/170): ICMP packets don't have their + // TransportHeader fields set. See icmp/protocol.go:protocol.Parse for a + // full explanation. v, ok := pkt.Data.PullUp(header.ICMPv6HeaderSize) if !ok { received.Invalid.Increment() return } h := header.ICMPv6(v) - iph := header.IPv6(netHeader) + iph := header.IPv6(pkt.NetworkHeader) // Validate ICMPv6 checksum before processing the packet. // diff --git a/pkg/tcpip/network/ipv6/icmp_test.go b/pkg/tcpip/network/ipv6/icmp_test.go index a720f626f..52a01b44e 100644 --- a/pkg/tcpip/network/ipv6/icmp_test.go +++ b/pkg/tcpip/network/ipv6/icmp_test.go @@ -179,36 +179,32 @@ func TestICMPCounts(t *testing.T) { }, } - handleIPv6Payload := func(hdr buffer.Prependable) { - payloadLength := hdr.UsedLength() - ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) + handleIPv6Payload := func(icmp header.ICMPv6) { + ip := header.IPv6(buffer.NewView(header.IPv6MinimumSize)) ip.Encode(&header.IPv6Fields{ - PayloadLength: uint16(payloadLength), + PayloadLength: uint16(len(icmp)), NextHeader: uint8(header.ICMPv6ProtocolNumber), HopLimit: header.NDPHopLimit, SrcAddr: r.LocalAddress, DstAddr: r.RemoteAddress, }) ep.HandlePacket(&r, &stack.PacketBuffer{ - Data: hdr.View().ToVectorisedView(), + NetworkHeader: buffer.View(ip), + Data: buffer.View(icmp).ToVectorisedView(), }) } for _, typ := range types { - extraDataLen := len(typ.extraData) - hdr := buffer.NewPrependable(header.IPv6MinimumSize + typ.size + extraDataLen) - extraData := buffer.View(hdr.Prepend(extraDataLen)) - copy(extraData, typ.extraData) - pkt := header.ICMPv6(hdr.Prepend(typ.size)) - pkt.SetType(typ.typ) - pkt.SetChecksum(header.ICMPv6Checksum(pkt, r.LocalAddress, r.RemoteAddress, extraData.ToVectorisedView())) - - handleIPv6Payload(hdr) + icmp := header.ICMPv6(buffer.NewView(typ.size + len(typ.extraData))) + copy(icmp[typ.size:], typ.extraData) + icmp.SetType(typ.typ) + icmp.SetChecksum(header.ICMPv6Checksum(icmp[:typ.size], r.LocalAddress, r.RemoteAddress, buffer.View(typ.extraData).ToVectorisedView())) + handleIPv6Payload(icmp) } // Construct an empty ICMP packet so that // Stats().ICMP.ICMPv6ReceivedPacketStats.Invalid is incremented. - handleIPv6Payload(buffer.NewPrependable(header.IPv6MinimumSize)) + handleIPv6Payload(header.ICMPv6(buffer.NewView(header.IPv6MinimumSize))) icmpv6Stats := s.Stats().ICMP.V6PacketsReceived visitStats(reflect.ValueOf(&icmpv6Stats).Elem(), func(name string, s *tcpip.StatCounter) { @@ -546,25 +542,22 @@ func TestICMPChecksumValidationSimple(t *testing.T) { } handleIPv6Payload := func(checksum bool) { - extraDataLen := len(typ.extraData) - hdr := buffer.NewPrependable(header.IPv6MinimumSize + typ.size + extraDataLen) - extraData := buffer.View(hdr.Prepend(extraDataLen)) - copy(extraData, typ.extraData) - pkt := header.ICMPv6(hdr.Prepend(typ.size)) - pkt.SetType(typ.typ) + icmp := header.ICMPv6(buffer.NewView(typ.size + len(typ.extraData))) + copy(icmp[typ.size:], typ.extraData) + icmp.SetType(typ.typ) if checksum { - pkt.SetChecksum(header.ICMPv6Checksum(pkt, lladdr1, lladdr0, extraData.ToVectorisedView())) + icmp.SetChecksum(header.ICMPv6Checksum(icmp, lladdr1, lladdr0, buffer.View{}.ToVectorisedView())) } - ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) + ip := header.IPv6(buffer.NewView(header.IPv6MinimumSize)) ip.Encode(&header.IPv6Fields{ - PayloadLength: uint16(typ.size + extraDataLen), + PayloadLength: uint16(len(icmp)), NextHeader: uint8(header.ICMPv6ProtocolNumber), HopLimit: header.NDPHopLimit, SrcAddr: lladdr1, DstAddr: lladdr0, }) e.InjectInbound(ProtocolNumber, &stack.PacketBuffer{ - Data: hdr.View().ToVectorisedView(), + Data: buffer.NewVectorisedView(len(ip)+len(icmp), []buffer.View{buffer.View(ip), buffer.View(icmp)}), }) } diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go index 0d94ad122..95fbcf2d1 100644 --- a/pkg/tcpip/network/ipv6/ipv6.go +++ b/pkg/tcpip/network/ipv6/ipv6.go @@ -171,22 +171,20 @@ func (*endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack.PacketBuff // HandlePacket is called by the link layer when new ipv6 packets arrive for // this endpoint. func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { - headerView, ok := pkt.Data.PullUp(header.IPv6MinimumSize) - if !ok { + h := header.IPv6(pkt.NetworkHeader) + if !h.IsValid(pkt.Data.Size() + len(pkt.NetworkHeader) + len(pkt.TransportHeader)) { r.Stats().IP.MalformedPacketsReceived.Increment() return } - h := header.IPv6(headerView) - if !h.IsValid(pkt.Data.Size()) { - r.Stats().IP.MalformedPacketsReceived.Increment() - return - } - - pkt.NetworkHeader = headerView[:header.IPv6MinimumSize] - pkt.Data.TrimFront(header.IPv6MinimumSize) - pkt.Data.CapLength(int(h.PayloadLength())) - it := header.MakeIPv6PayloadIterator(header.IPv6ExtensionHeaderIdentifier(h.NextHeader()), pkt.Data) + // vv consists of: + // - Any IPv6 header bytes after the first 40 (i.e. extensions). + // - The transport header, if present. + // - Any other payload data. + vv := pkt.NetworkHeader[header.IPv6MinimumSize:].ToVectorisedView() + vv.AppendView(pkt.TransportHeader) + vv.Append(pkt.Data) + it := header.MakeIPv6PayloadIterator(header.IPv6ExtensionHeaderIdentifier(h.NextHeader()), vv) hasFragmentHeader := false for firstHeader := true; ; firstHeader = false { @@ -262,9 +260,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { case header.IPv6FragmentExtHdr: hasFragmentHeader = true - fragmentOffset := extHdr.FragmentOffset() - more := extHdr.More() - if !more && fragmentOffset == 0 { + if extHdr.IsAtomic() { // This fragment extension header indicates that this packet is an // atomic fragment. An atomic fragment is a fragment that contains // all the data required to reassemble a full packet. As per RFC 6946, @@ -277,9 +273,9 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { // Don't consume the iterator if we have the first fragment because we // will use it to validate that the first fragment holds the upper layer // header. - rawPayload := it.AsRawHeader(fragmentOffset != 0 /* consume */) + rawPayload := it.AsRawHeader(extHdr.FragmentOffset() != 0 /* consume */) - if fragmentOffset == 0 { + if extHdr.FragmentOffset() == 0 { // Check that the iterator ends with a raw payload as the first fragment // should include all headers up to and including any upper layer // headers, as per RFC 8200 section 4.5; only upper layer data @@ -332,7 +328,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { } // The packet is a fragment, let's try to reassemble it. - start := fragmentOffset * header.IPv6FragmentExtHdrFragmentOffsetBytesPerUnit + start := extHdr.FragmentOffset() * header.IPv6FragmentExtHdrFragmentOffsetBytesPerUnit last := start + uint16(fragmentPayloadLen) - 1 // Drop the packet if the fragmentOffset is incorrect. i.e the @@ -345,7 +341,9 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { } var ready bool - pkt.Data, ready, err = e.fragmentation.Process(hash.IPv6FragmentHash(h, extHdr.ID()), start, last, more, rawPayload.Buf) + // Note that pkt doesn't have its transport header set after reassembly, + // and won't until DeliverNetworkPacket sets it. + pkt.Data, ready, err = e.fragmentation.Process(hash.IPv6FragmentHash(h, extHdr.ID()), start, last, extHdr.More(), rawPayload.Buf) if err != nil { r.Stats().IP.MalformedPacketsReceived.Increment() r.Stats().IP.MalformedFragmentsReceived.Increment() @@ -394,10 +392,17 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { case header.IPv6RawPayloadHeader: // If the last header in the payload isn't a known IPv6 extension header, // handle it as if it is transport layer data. + + // For unfragmented packets, extHdr still contains the transport header. + // Get rid of it. + // + // For reassembled fragments, pkt.TransportHeader is unset, so this is a + // no-op and pkt.Data begins with the transport header. + extHdr.Buf.TrimFront(len(pkt.TransportHeader)) pkt.Data = extHdr.Buf if p := tcpip.TransportProtocolNumber(extHdr.Identifier); p == header.ICMPv6ProtocolNumber { - e.handleICMP(r, headerView, pkt, hasFragmentHeader) + e.handleICMP(r, pkt, hasFragmentHeader) } else { r.Stats().IP.PacketsDelivered.Increment() // TODO(b/152019344): Send an ICMPv6 Parameter Problem, Code 1 error @@ -505,6 +510,79 @@ func (*protocol) Close() {} // Wait implements stack.TransportProtocol.Wait. func (*protocol) Wait() {} +// Parse implements stack.TransportProtocol.Parse. +func (*protocol) Parse(pkt *stack.PacketBuffer) (proto tcpip.TransportProtocolNumber, hasTransportHdr bool, ok bool) { + hdr, ok := pkt.Data.PullUp(header.IPv6MinimumSize) + if !ok { + return 0, false, false + } + ipHdr := header.IPv6(hdr) + + // dataClone consists of: + // - Any IPv6 header bytes after the first 40 (i.e. extensions). + // - The transport header, if present. + // - Any other payload data. + views := [8]buffer.View{} + dataClone := pkt.Data.Clone(views[:]) + dataClone.TrimFront(header.IPv6MinimumSize) + it := header.MakeIPv6PayloadIterator(header.IPv6ExtensionHeaderIdentifier(ipHdr.NextHeader()), dataClone) + + // Iterate over the IPv6 extensions to find their length. + // + // Parsing occurs again in HandlePacket because we don't track the + // extensions in PacketBuffer. Unfortunately, that means HandlePacket + // has to do the parsing work again. + var nextHdr tcpip.TransportProtocolNumber + foundNext := true + extensionsSize := 0 +traverseExtensions: + for extHdr, done, err := it.Next(); ; extHdr, done, err = it.Next() { + if err != nil { + break + } + // If we exhaust the extension list, the entire packet is the IPv6 header + // and (possibly) extensions. + if done { + extensionsSize = dataClone.Size() + foundNext = false + break + } + + switch extHdr := extHdr.(type) { + case header.IPv6FragmentExtHdr: + // If this is an atomic fragment, we don't have to treat it specially. + if !extHdr.More() && extHdr.FragmentOffset() == 0 { + continue + } + // This is a non-atomic fragment and has to be re-assembled before we can + // examine the payload for a transport header. + foundNext = false + + case header.IPv6RawPayloadHeader: + // We've found the payload after any extensions. + extensionsSize = dataClone.Size() - extHdr.Buf.Size() + nextHdr = tcpip.TransportProtocolNumber(extHdr.Identifier) + break traverseExtensions + + default: + // Any other extension is a no-op, keep looping until we find the payload. + } + } + + // Put the IPv6 header with extensions in pkt.NetworkHeader. + hdr, ok = pkt.Data.PullUp(header.IPv6MinimumSize + extensionsSize) + if !ok { + panic(fmt.Sprintf("pkt.Data should have at least %d bytes, but only has %d.", header.IPv6MinimumSize+extensionsSize, pkt.Data.Size())) + } + ipHdr = header.IPv6(hdr) + + pkt.NetworkHeader = hdr + pkt.Data.TrimFront(len(hdr)) + pkt.Data.CapLength(int(ipHdr.PayloadLength())) + + return nextHdr, foundNext, true +} + // calculateMTU calculates the network-layer payload MTU based on the link-layer // payload mtu. func calculateMTU(mtu uint32) uint32 { diff --git a/pkg/tcpip/network/ipv6/ndp_test.go b/pkg/tcpip/network/ipv6/ndp_test.go index 3c141b91b..64239ce9a 100644 --- a/pkg/tcpip/network/ipv6/ndp_test.go +++ b/pkg/tcpip/network/ipv6/ndp_test.go @@ -551,25 +551,29 @@ func TestNDPValidation(t *testing.T) { return s, ep, r } - handleIPv6Payload := func(hdr buffer.Prependable, hopLimit uint8, atomicFragment bool, ep stack.NetworkEndpoint, r *stack.Route) { + handleIPv6Payload := func(payload buffer.View, hopLimit uint8, atomicFragment bool, ep stack.NetworkEndpoint, r *stack.Route) { nextHdr := uint8(header.ICMPv6ProtocolNumber) + var extensions buffer.View if atomicFragment { - bytes := hdr.Prepend(header.IPv6FragmentExtHdrLength) - bytes[0] = nextHdr + extensions = buffer.NewView(header.IPv6FragmentExtHdrLength) + extensions[0] = nextHdr nextHdr = uint8(header.IPv6FragmentExtHdrIdentifier) } - payloadLength := hdr.UsedLength() - ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) + ip := header.IPv6(buffer.NewView(header.IPv6MinimumSize + len(extensions))) ip.Encode(&header.IPv6Fields{ - PayloadLength: uint16(payloadLength), + PayloadLength: uint16(len(payload) + len(extensions)), NextHeader: nextHdr, HopLimit: hopLimit, SrcAddr: r.LocalAddress, DstAddr: r.RemoteAddress, }) + if n := copy(ip[header.IPv6MinimumSize:], extensions); n != len(extensions) { + t.Fatalf("expected to write %d bytes of extensions, but wrote %d", len(extensions), n) + } ep.HandlePacket(r, &stack.PacketBuffer{ - Data: hdr.View().ToVectorisedView(), + NetworkHeader: buffer.View(ip), + Data: payload.ToVectorisedView(), }) } @@ -676,14 +680,11 @@ func TestNDPValidation(t *testing.T) { invalid := stats.Invalid typStat := typ.statCounter(stats) - extraDataLen := len(typ.extraData) - hdr := buffer.NewPrependable(header.IPv6MinimumSize + typ.size + extraDataLen + header.IPv6FragmentExtHdrLength) - extraData := buffer.View(hdr.Prepend(extraDataLen)) - copy(extraData, typ.extraData) - pkt := header.ICMPv6(hdr.Prepend(typ.size)) - pkt.SetType(typ.typ) - pkt.SetCode(test.code) - pkt.SetChecksum(header.ICMPv6Checksum(pkt, r.LocalAddress, r.RemoteAddress, extraData.ToVectorisedView())) + icmp := header.ICMPv6(buffer.NewView(typ.size + len(typ.extraData))) + copy(icmp[typ.size:], typ.extraData) + icmp.SetType(typ.typ) + icmp.SetCode(test.code) + icmp.SetChecksum(header.ICMPv6Checksum(icmp[:typ.size], r.LocalAddress, r.RemoteAddress, buffer.View(typ.extraData).ToVectorisedView())) // Rx count of the NDP message should initially be 0. if got := typStat.Value(); got != 0 { @@ -699,7 +700,7 @@ func TestNDPValidation(t *testing.T) { t.FailNow() } - handleIPv6Payload(hdr, test.hopLimit, test.atomicFragment, ep, &r) + handleIPv6Payload(buffer.View(icmp), test.hopLimit, test.atomicFragment, ep, &r) // Rx count of the NDP packet should have increased. if got := typStat.Value(); got != 1 { diff --git a/pkg/tcpip/stack/BUILD b/pkg/tcpip/stack/BUILD index f71073207..afca925ad 100644 --- a/pkg/tcpip/stack/BUILD +++ b/pkg/tcpip/stack/BUILD @@ -110,5 +110,6 @@ go_test( "//pkg/sync", "//pkg/tcpip", "//pkg/tcpip/buffer", + "//pkg/tcpip/header", ], ) diff --git a/pkg/tcpip/stack/conntrack.go b/pkg/tcpip/stack/conntrack.go index d4053be08..05bf62788 100644 --- a/pkg/tcpip/stack/conntrack.go +++ b/pkg/tcpip/stack/conntrack.go @@ -20,7 +20,6 @@ import ( "time" "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/hash/jenkins" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/transport/tcpconntrack" @@ -147,44 +146,6 @@ type ConnTrackTable struct { Seed uint32 } -// parseHeaders sets headers in the packet. -func parseHeaders(pkt *PacketBuffer) { - newPkt := pkt.Clone() - - // Set network header. - hdr, ok := newPkt.Data.PullUp(header.IPv4MinimumSize) - if !ok { - return - } - netHeader := header.IPv4(hdr) - newPkt.NetworkHeader = hdr - length := int(netHeader.HeaderLength()) - - // TODO(gvisor.dev/issue/170): Need to support for other - // protocols as well. - // Set transport header. - switch protocol := netHeader.TransportProtocol(); protocol { - case header.UDPProtocolNumber: - if newPkt.TransportHeader == nil { - h, ok := newPkt.Data.PullUp(length + header.UDPMinimumSize) - if !ok { - return - } - newPkt.TransportHeader = buffer.View(header.UDP(h[length:])) - } - case header.TCPProtocolNumber: - if newPkt.TransportHeader == nil { - h, ok := newPkt.Data.PullUp(length + header.TCPMinimumSize) - if !ok { - return - } - newPkt.TransportHeader = buffer.View(header.TCP(h[length:])) - } - } - pkt.NetworkHeader = newPkt.NetworkHeader - pkt.TransportHeader = newPkt.TransportHeader -} - // packetToTuple converts packet to a tuple in original direction. func packetToTuple(pkt *PacketBuffer, hook Hook) (connTrackTuple, *tcpip.Error) { var tuple connTrackTuple @@ -257,13 +218,6 @@ func (ct *ConnTrackTable) getTupleHash(tuple connTrackTuple) uint32 { // TODO(gvisor.dev/issue/170): Only TCP packets are supported. Need to support other // transport protocols. func (ct *ConnTrackTable) connTrackForPacket(pkt *PacketBuffer, hook Hook, createConn bool) (*connTrack, ctDirection) { - if hook == Prerouting { - // Headers will not be set in Prerouting. - // TODO(gvisor.dev/issue/170): Change this after parsing headers - // code is added. - parseHeaders(pkt) - } - var dir ctDirection tuple, err := packetToTuple(pkt, hook) if err != nil { diff --git a/pkg/tcpip/stack/forwarder_test.go b/pkg/tcpip/stack/forwarder_test.go index 63537aaad..a6546cef0 100644 --- a/pkg/tcpip/stack/forwarder_test.go +++ b/pkg/tcpip/stack/forwarder_test.go @@ -33,6 +33,10 @@ const ( // except where another value is explicitly used. It is chosen to match // the MTU of loopback interfaces on linux systems. fwdTestNetDefaultMTU = 65536 + + dstAddrOffset = 0 + srcAddrOffset = 1 + protocolNumberOffset = 2 ) // fwdTestNetworkEndpoint is a network-layer protocol endpoint. @@ -69,15 +73,8 @@ func (f *fwdTestNetworkEndpoint) ID() *NetworkEndpointID { } func (f *fwdTestNetworkEndpoint) HandlePacket(r *Route, pkt *PacketBuffer) { - // Consume the network header. - b, ok := pkt.Data.PullUp(fwdTestNetHeaderLen) - if !ok { - return - } - pkt.Data.TrimFront(fwdTestNetHeaderLen) - // Dispatch the packet to the transport protocol. - f.dispatcher.DeliverTransportPacket(r, tcpip.TransportProtocolNumber(b[2]), pkt) + f.dispatcher.DeliverTransportPacket(r, tcpip.TransportProtocolNumber(pkt.NetworkHeader[protocolNumberOffset]), pkt) } func (f *fwdTestNetworkEndpoint) MaxHeaderLength() uint16 { @@ -100,9 +97,9 @@ func (f *fwdTestNetworkEndpoint) WritePacket(r *Route, gso *GSO, params NetworkH // Add the protocol's header to the packet and send it to the link // endpoint. b := pkt.Header.Prepend(fwdTestNetHeaderLen) - b[0] = r.RemoteAddress[0] - b[1] = f.id.LocalAddress[0] - b[2] = byte(params.Protocol) + b[dstAddrOffset] = r.RemoteAddress[0] + b[srcAddrOffset] = f.id.LocalAddress[0] + b[protocolNumberOffset] = byte(params.Protocol) return f.ep.WritePacket(r, gso, fwdTestNetNumber, pkt) } @@ -140,7 +137,17 @@ func (f *fwdTestNetworkProtocol) DefaultPrefixLen() int { } func (*fwdTestNetworkProtocol) ParseAddresses(v buffer.View) (src, dst tcpip.Address) { - return tcpip.Address(v[1:2]), tcpip.Address(v[0:1]) + return tcpip.Address(v[srcAddrOffset : srcAddrOffset+1]), tcpip.Address(v[dstAddrOffset : dstAddrOffset+1]) +} + +func (*fwdTestNetworkProtocol) Parse(pkt *PacketBuffer) (tcpip.TransportProtocolNumber, bool, bool) { + netHeader, ok := pkt.Data.PullUp(fwdTestNetHeaderLen) + if !ok { + return 0, false, false + } + pkt.NetworkHeader = netHeader + pkt.Data.TrimFront(fwdTestNetHeaderLen) + return tcpip.TransportProtocolNumber(pkt.NetworkHeader[protocolNumberOffset]), true, true } func (f *fwdTestNetworkProtocol) NewEndpoint(nicID tcpip.NICID, addrWithPrefix tcpip.AddressWithPrefix, linkAddrCache LinkAddressCache, dispatcher TransportDispatcher, ep LinkEndpoint, _ *Stack) (NetworkEndpoint, *tcpip.Error) { @@ -361,7 +368,7 @@ func TestForwardingWithStaticResolver(t *testing.T) { // Inject an inbound packet to address 3 on NIC 1, and see if it is // forwarded to NIC 2. buf := buffer.NewView(30) - buf[0] = 3 + buf[dstAddrOffset] = 3 ep1.InjectInbound(fwdTestNetNumber, &PacketBuffer{ Data: buf.ToVectorisedView(), }) @@ -398,7 +405,7 @@ func TestForwardingWithFakeResolver(t *testing.T) { // Inject an inbound packet to address 3 on NIC 1, and see if it is // forwarded to NIC 2. buf := buffer.NewView(30) - buf[0] = 3 + buf[dstAddrOffset] = 3 ep1.InjectInbound(fwdTestNetNumber, &PacketBuffer{ Data: buf.ToVectorisedView(), }) @@ -429,7 +436,7 @@ func TestForwardingWithNoResolver(t *testing.T) { // inject an inbound packet to address 3 on NIC 1, and see if it is // forwarded to NIC 2. buf := buffer.NewView(30) - buf[0] = 3 + buf[dstAddrOffset] = 3 ep1.InjectInbound(fwdTestNetNumber, &PacketBuffer{ Data: buf.ToVectorisedView(), }) @@ -459,7 +466,7 @@ func TestForwardingWithFakeResolverPartialTimeout(t *testing.T) { // Inject an inbound packet to address 4 on NIC 1. This packet should // not be forwarded. buf := buffer.NewView(30) - buf[0] = 4 + buf[dstAddrOffset] = 4 ep1.InjectInbound(fwdTestNetNumber, &PacketBuffer{ Data: buf.ToVectorisedView(), }) @@ -467,7 +474,7 @@ func TestForwardingWithFakeResolverPartialTimeout(t *testing.T) { // Inject an inbound packet to address 3 on NIC 1, and see if it is // forwarded to NIC 2. buf = buffer.NewView(30) - buf[0] = 3 + buf[dstAddrOffset] = 3 ep1.InjectInbound(fwdTestNetNumber, &PacketBuffer{ Data: buf.ToVectorisedView(), }) @@ -480,9 +487,8 @@ func TestForwardingWithFakeResolverPartialTimeout(t *testing.T) { t.Fatal("packet not forwarded") } - b := p.Pkt.Data.ToView() - if b[0] != 3 { - t.Fatalf("got b[0] = %d, want = 3", b[0]) + if p.Pkt.NetworkHeader[dstAddrOffset] != 3 { + t.Fatalf("got p.Pkt.NetworkHeader[dstAddrOffset] = %d, want = 3", p.Pkt.NetworkHeader[dstAddrOffset]) } // Test that the address resolution happened correctly. @@ -509,7 +515,7 @@ func TestForwardingWithFakeResolverTwoPackets(t *testing.T) { // Inject two inbound packets to address 3 on NIC 1. for i := 0; i < 2; i++ { buf := buffer.NewView(30) - buf[0] = 3 + buf[dstAddrOffset] = 3 ep1.InjectInbound(fwdTestNetNumber, &PacketBuffer{ Data: buf.ToVectorisedView(), }) @@ -524,9 +530,8 @@ func TestForwardingWithFakeResolverTwoPackets(t *testing.T) { t.Fatal("packet not forwarded") } - b := p.Pkt.Data.ToView() - if b[0] != 3 { - t.Fatalf("got b[0] = %d, want = 3", b[0]) + if p.Pkt.NetworkHeader[dstAddrOffset] != 3 { + t.Fatalf("got p.Pkt.NetworkHeader[dstAddrOffset] = %d, want = 3", p.Pkt.NetworkHeader[dstAddrOffset]) } // Test that the address resolution happened correctly. @@ -554,7 +559,7 @@ func TestForwardingWithFakeResolverManyPackets(t *testing.T) { for i := 0; i < maxPendingPacketsPerResolution+5; i++ { // Inject inbound 'maxPendingPacketsPerResolution + 5' packets on NIC 1. buf := buffer.NewView(30) - buf[0] = 3 + buf[dstAddrOffset] = 3 // Set the packet sequence number. binary.BigEndian.PutUint16(buf[fwdTestNetHeaderLen:], uint16(i)) ep1.InjectInbound(fwdTestNetNumber, &PacketBuffer{ @@ -571,14 +576,18 @@ func TestForwardingWithFakeResolverManyPackets(t *testing.T) { t.Fatal("packet not forwarded") } - b := p.Pkt.Data.ToView() - if b[0] != 3 { - t.Fatalf("got b[0] = %d, want = 3", b[0]) + if b := p.Pkt.Header.View(); b[dstAddrOffset] != 3 { + t.Fatalf("got b[dstAddrOffset] = %d, want = 3", b[dstAddrOffset]) + } + seqNumBuf, ok := p.Pkt.Data.PullUp(2) // The sequence number is a uint16 (2 bytes). + if !ok { + t.Fatalf("p.Pkt.Data is too short to hold a sequence number: %d", p.Pkt.Data.Size()) } - // The first 5 packets should not be forwarded so the the - // sequemnce number should start with 5. + + // The first 5 packets should not be forwarded so the sequence number should + // start with 5. want := uint16(i + 5) - if n := binary.BigEndian.Uint16(b[fwdTestNetHeaderLen:]); n != want { + if n := binary.BigEndian.Uint16(seqNumBuf); n != want { t.Fatalf("got the packet #%d, want = #%d", n, want) } @@ -609,7 +618,7 @@ func TestForwardingWithFakeResolverManyResolutions(t *testing.T) { // Each packet has a different destination address (3 to // maxPendingResolutions + 7). buf := buffer.NewView(30) - buf[0] = byte(3 + i) + buf[dstAddrOffset] = byte(3 + i) ep1.InjectInbound(fwdTestNetNumber, &PacketBuffer{ Data: buf.ToVectorisedView(), }) @@ -626,9 +635,8 @@ func TestForwardingWithFakeResolverManyResolutions(t *testing.T) { // The first 5 packets (address 3 to 7) should not be forwarded // because their address resolutions are interrupted. - b := p.Pkt.Data.ToView() - if b[0] < 8 { - t.Fatalf("got b[0] = %d, want b[0] >= 8", b[0]) + if p.Pkt.NetworkHeader[dstAddrOffset] < 8 { + t.Fatalf("got p.Pkt.NetworkHeader[dstAddrOffset] = %d, want p.Pkt.NetworkHeader[dstAddrOffset] >= 8", p.Pkt.NetworkHeader[dstAddrOffset]) } // Test that the address resolution happened correctly. diff --git a/pkg/tcpip/stack/iptables_targets.go b/pkg/tcpip/stack/iptables_targets.go index 36cc6275d..92e31643e 100644 --- a/pkg/tcpip/stack/iptables_targets.go +++ b/pkg/tcpip/stack/iptables_targets.go @@ -98,11 +98,6 @@ func (rt RedirectTarget) Action(pkt *PacketBuffer, ct *ConnTrackTable, hook Hook return RuleAccept, 0 } - // Set network header. - if hook == Prerouting { - parseHeaders(pkt) - } - // Drop the packet if network and transport header are not set. if pkt.NetworkHeader == nil || pkt.TransportHeader == nil { return RuleDrop, 0 diff --git a/pkg/tcpip/stack/ndp.go b/pkg/tcpip/stack/ndp.go index ae7a8f740..e28c23d66 100644 --- a/pkg/tcpip/stack/ndp.go +++ b/pkg/tcpip/stack/ndp.go @@ -467,8 +467,17 @@ type ndpState struct { // The default routers discovered through Router Advertisements. defaultRouters map[tcpip.Address]defaultRouterState - // The timer used to send the next router solicitation message. - rtrSolicitTimer *time.Timer + rtrSolicit struct { + // The timer used to send the next router solicitation message. + timer *time.Timer + + // Used to let the Router Solicitation timer know that it has been stopped. + // + // Must only be read from or written to while protected by the lock of + // the NIC this ndpState is associated with. MUST be set when the timer is + // set. + done *bool + } // The on-link prefixes discovered through Router Advertisements' Prefix // Information option. @@ -648,13 +657,14 @@ func (ndp *ndpState) startDuplicateAddressDetection(addr tcpip.Address, ref *ref // as starting a goroutine but we use a timer that fires immediately so we can // reset it for the next DAD iteration. timer = time.AfterFunc(0, func() { - ndp.nic.mu.RLock() + ndp.nic.mu.Lock() + defer ndp.nic.mu.Unlock() + if done { // If we reach this point, it means that the DAD timer fired after // another goroutine already obtained the NIC lock and stopped DAD // before this function obtained the NIC lock. Simply return here and do // nothing further. - ndp.nic.mu.RUnlock() return } @@ -665,15 +675,23 @@ func (ndp *ndpState) startDuplicateAddressDetection(addr tcpip.Address, ref *ref } dadDone := remaining == 0 - ndp.nic.mu.RUnlock() var err *tcpip.Error if !dadDone { - err = ndp.sendDADPacket(addr) + // Use the unspecified address as the source address when performing DAD. + ref := ndp.nic.getRefOrCreateTempLocked(header.IPv6ProtocolNumber, header.IPv6Any, NeverPrimaryEndpoint) + + // Do not hold the lock when sending packets which may be a long running + // task or may block link address resolution. We know this is safe + // because immediately after obtaining the lock again, we check if DAD + // has been stopped before doing any work with the NIC. Note, DAD would be + // stopped if the NIC was disabled or removed, or if the address was + // removed. + ndp.nic.mu.Unlock() + err = ndp.sendDADPacket(addr, ref) + ndp.nic.mu.Lock() } - ndp.nic.mu.Lock() - defer ndp.nic.mu.Unlock() if done { // If we reach this point, it means that DAD was stopped after we released // the NIC's read lock and before we obtained the write lock. @@ -721,17 +739,24 @@ func (ndp *ndpState) startDuplicateAddressDetection(addr tcpip.Address, ref *ref // addr. // // addr must be a tentative IPv6 address on ndp's NIC. -func (ndp *ndpState) sendDADPacket(addr tcpip.Address) *tcpip.Error { +// +// The NIC ndp belongs to MUST NOT be locked. +func (ndp *ndpState) sendDADPacket(addr tcpip.Address, ref *referencedNetworkEndpoint) *tcpip.Error { snmc := header.SolicitedNodeAddr(addr) - // Use the unspecified address as the source address when performing DAD. - ref := ndp.nic.getRefOrCreateTemp(header.IPv6ProtocolNumber, header.IPv6Any, NeverPrimaryEndpoint, forceSpoofing) - r := makeRoute(header.IPv6ProtocolNumber, header.IPv6Any, snmc, ndp.nic.linkEP.LinkAddress(), ref, false, false) + r := makeRoute(header.IPv6ProtocolNumber, ref.ep.ID().LocalAddress, snmc, ndp.nic.linkEP.LinkAddress(), ref, false, false) defer r.Release() // Route should resolve immediately since snmc is a multicast address so a // remote link address can be calculated without a resolution process. if c, err := r.Resolve(nil); err != nil { + // Do not consider the NIC being unknown or disabled as a fatal error. + // Since this method is required to be called when the NIC is not locked, + // the NIC could have been disabled or removed by another goroutine. + if err == tcpip.ErrUnknownNICID || err != tcpip.ErrInvalidEndpointState { + return err + } + panic(fmt.Sprintf("ndp: error when resolving route to send NDP NS for DAD (%s -> %s on NIC(%d)): %s", header.IPv6Any, snmc, ndp.nic.ID(), err)) } else if c != nil { panic(fmt.Sprintf("ndp: route resolution not immediate for route to send NDP NS for DAD (%s -> %s on NIC(%d))", header.IPv6Any, snmc, ndp.nic.ID())) @@ -1816,7 +1841,7 @@ func (ndp *ndpState) cleanupState(hostOnly bool) { // // The NIC ndp belongs to MUST be locked. func (ndp *ndpState) startSolicitingRouters() { - if ndp.rtrSolicitTimer != nil { + if ndp.rtrSolicit.timer != nil { // We are already soliciting routers. return } @@ -1833,14 +1858,27 @@ func (ndp *ndpState) startSolicitingRouters() { delay = time.Duration(rand.Int63n(int64(ndp.configs.MaxRtrSolicitationDelay))) } - ndp.rtrSolicitTimer = time.AfterFunc(delay, func() { + var done bool + ndp.rtrSolicit.done = &done + ndp.rtrSolicit.timer = time.AfterFunc(delay, func() { + ndp.nic.mu.Lock() + if done { + // If we reach this point, it means that the RS timer fired after another + // goroutine already obtained the NIC lock and stopped solicitations. + // Simply return here and do nothing further. + ndp.nic.mu.Unlock() + return + } + // As per RFC 4861 section 4.1, the source of the RS is an address assigned // to the sending interface, or the unspecified address if no address is // assigned to the sending interface. - ref := ndp.nic.primaryIPv6Endpoint(header.IPv6AllRoutersMulticastAddress) + ref := ndp.nic.primaryIPv6EndpointRLocked(header.IPv6AllRoutersMulticastAddress) if ref == nil { - ref = ndp.nic.getRefOrCreateTemp(header.IPv6ProtocolNumber, header.IPv6Any, NeverPrimaryEndpoint, forceSpoofing) + ref = ndp.nic.getRefOrCreateTempLocked(header.IPv6ProtocolNumber, header.IPv6Any, NeverPrimaryEndpoint) } + ndp.nic.mu.Unlock() + localAddr := ref.ep.ID().LocalAddress r := makeRoute(header.IPv6ProtocolNumber, localAddr, header.IPv6AllRoutersMulticastAddress, ndp.nic.linkEP.LinkAddress(), ref, false, false) defer r.Release() @@ -1849,6 +1887,13 @@ func (ndp *ndpState) startSolicitingRouters() { // header.IPv6AllRoutersMulticastAddress is a multicast address so a // remote link address can be calculated without a resolution process. if c, err := r.Resolve(nil); err != nil { + // Do not consider the NIC being unknown or disabled as a fatal error. + // Since this method is required to be called when the NIC is not locked, + // the NIC could have been disabled or removed by another goroutine. + if err == tcpip.ErrUnknownNICID || err == tcpip.ErrInvalidEndpointState { + return + } + panic(fmt.Sprintf("ndp: error when resolving route to send NDP RS (%s -> %s on NIC(%d)): %s", header.IPv6Any, header.IPv6AllRoutersMulticastAddress, ndp.nic.ID(), err)) } else if c != nil { panic(fmt.Sprintf("ndp: route resolution not immediate for route to send NDP RS (%s -> %s on NIC(%d))", header.IPv6Any, header.IPv6AllRoutersMulticastAddress, ndp.nic.ID())) @@ -1893,17 +1938,18 @@ func (ndp *ndpState) startSolicitingRouters() { } ndp.nic.mu.Lock() - defer ndp.nic.mu.Unlock() - if remaining == 0 { - ndp.rtrSolicitTimer = nil - } else if ndp.rtrSolicitTimer != nil { + if done || remaining == 0 { + ndp.rtrSolicit.timer = nil + ndp.rtrSolicit.done = nil + } else if ndp.rtrSolicit.timer != nil { // Note, we need to explicitly check to make sure that // the timer field is not nil because if it was nil but // we still reached this point, then we know the NIC // was requested to stop soliciting routers so we don't // need to send the next Router Solicitation message. - ndp.rtrSolicitTimer.Reset(ndp.configs.RtrSolicitationInterval) + ndp.rtrSolicit.timer.Reset(ndp.configs.RtrSolicitationInterval) } + ndp.nic.mu.Unlock() }) } @@ -1913,13 +1959,15 @@ func (ndp *ndpState) startSolicitingRouters() { // // The NIC ndp belongs to MUST be locked. func (ndp *ndpState) stopSolicitingRouters() { - if ndp.rtrSolicitTimer == nil { + if ndp.rtrSolicit.timer == nil { // Nothing to do. return } - ndp.rtrSolicitTimer.Stop() - ndp.rtrSolicitTimer = nil + *ndp.rtrSolicit.done = true + ndp.rtrSolicit.timer.Stop() + ndp.rtrSolicit.timer = nil + ndp.rtrSolicit.done = nil } // initializeTempAddrState initializes state related to temporary SLAAC diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go index ec8e3cb85..644c0d437 100644 --- a/pkg/tcpip/stack/nic.go +++ b/pkg/tcpip/stack/nic.go @@ -457,8 +457,20 @@ type ipv6AddrCandidate struct { // remoteAddr must be a valid IPv6 address. func (n *NIC) primaryIPv6Endpoint(remoteAddr tcpip.Address) *referencedNetworkEndpoint { n.mu.RLock() - defer n.mu.RUnlock() + ref := n.primaryIPv6EndpointRLocked(remoteAddr) + n.mu.RUnlock() + return ref +} +// primaryIPv6EndpointLocked returns an IPv6 endpoint following Source Address +// Selection (RFC 6724 section 5). +// +// Note, only rules 1-3 and 7 are followed. +// +// remoteAddr must be a valid IPv6 address. +// +// n.mu MUST be read locked. +func (n *NIC) primaryIPv6EndpointRLocked(remoteAddr tcpip.Address) *referencedNetworkEndpoint { primaryAddrs := n.mu.primary[header.IPv6ProtocolNumber] if len(primaryAddrs) == 0 { @@ -568,11 +580,6 @@ const ( // promiscuous indicates that the NIC's promiscuous flag should be observed // when getting a NIC's referenced network endpoint. promiscuous - - // forceSpoofing indicates that the NIC should be assumed to be spoofing, - // regardless of what the NIC's spoofing flag is when getting a NIC's - // referenced network endpoint. - forceSpoofing ) func (n *NIC) getRef(protocol tcpip.NetworkProtocolNumber, dst tcpip.Address) *referencedNetworkEndpoint { @@ -591,8 +598,6 @@ func (n *NIC) findEndpoint(protocol tcpip.NetworkProtocolNumber, address tcpip.A // or spoofing. Promiscuous mode will only be checked if promiscuous is true. // Similarly, spoofing will only be checked if spoofing is true. func (n *NIC) getRefOrCreateTemp(protocol tcpip.NetworkProtocolNumber, address tcpip.Address, peb PrimaryEndpointBehavior, tempRef getRefBehaviour) *referencedNetworkEndpoint { - id := NetworkEndpointID{address} - n.mu.RLock() var spoofingOrPromiscuous bool @@ -601,11 +606,9 @@ func (n *NIC) getRefOrCreateTemp(protocol tcpip.NetworkProtocolNumber, address t spoofingOrPromiscuous = n.mu.spoofing case promiscuous: spoofingOrPromiscuous = n.mu.promiscuous - case forceSpoofing: - spoofingOrPromiscuous = true } - if ref, ok := n.mu.endpoints[id]; ok { + if ref, ok := n.mu.endpoints[NetworkEndpointID{address}]; ok { // An endpoint with this id exists, check if it can be used and return it. switch ref.getKind() { case permanentExpired: @@ -654,11 +657,18 @@ func (n *NIC) getRefOrCreateTemp(protocol tcpip.NetworkProtocolNumber, address t // endpoint, create a new "temporary" endpoint. It will only exist while // there's a route through it. n.mu.Lock() - if ref, ok := n.mu.endpoints[id]; ok { + ref := n.getRefOrCreateTempLocked(protocol, address, peb) + n.mu.Unlock() + return ref +} + +/// getRefOrCreateTempLocked returns an existing endpoint for address or creates +/// and returns a temporary endpoint. +func (n *NIC) getRefOrCreateTempLocked(protocol tcpip.NetworkProtocolNumber, address tcpip.Address, peb PrimaryEndpointBehavior) *referencedNetworkEndpoint { + if ref, ok := n.mu.endpoints[NetworkEndpointID{address}]; ok { // No need to check the type as we are ok with expired endpoints at this // point. if ref.tryIncRef() { - n.mu.Unlock() return ref } // tryIncRef failing means the endpoint is scheduled to be removed once the @@ -670,7 +680,6 @@ func (n *NIC) getRefOrCreateTemp(protocol tcpip.NetworkProtocolNumber, address t // Add a new temporary endpoint. netProto, ok := n.stack.networkProtocols[protocol] if !ok { - n.mu.Unlock() return nil } ref, _ := n.addAddressLocked(tcpip.ProtocolAddress{ @@ -681,7 +690,6 @@ func (n *NIC) getRefOrCreateTemp(protocol tcpip.NetworkProtocolNumber, address t }, }, peb, temporary, static, false) - n.mu.Unlock() return ref } @@ -1212,12 +1220,21 @@ func (n *NIC) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcp n.stack.stats.IP.PacketsReceived.Increment() } - netHeader, ok := pkt.Data.PullUp(netProto.MinimumPacketSize()) + // Parse headers. + transProtoNum, hasTransportHdr, ok := netProto.Parse(pkt) if !ok { + // The packet is too small to contain a network header. n.stack.stats.MalformedRcvdPackets.Increment() return } - src, dst := netProto.ParseAddresses(netHeader) + if hasTransportHdr { + // Parse the transport header if present. + if state, ok := n.stack.transportProtocols[transProtoNum]; ok { + state.proto.Parse(pkt) + } + } + + src, dst := netProto.ParseAddresses(pkt.NetworkHeader) if n.stack.handleLocal && !n.isLoopback() && n.getRef(protocol, src) != nil { // The source address is one of our own, so we never should have gotten a @@ -1229,7 +1246,8 @@ func (n *NIC) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcp } // TODO(gvisor.dev/issue/170): Not supporting iptables for IPv6 yet. - if protocol == header.IPv4ProtocolNumber { + // Loopback traffic skips the prerouting chain. + if protocol == header.IPv4ProtocolNumber && !n.isLoopback() { // iptables filtering. ipt := n.stack.IPTables() address := n.primaryAddress(protocol) @@ -1300,8 +1318,18 @@ func (n *NIC) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcp func (n *NIC) forwardPacket(r *Route, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) { // TODO(b/143425874) Decrease the TTL field in forwarded packets. - if linkHeaderLen := int(n.linkEP.MaxHeaderLength()); linkHeaderLen != 0 { - pkt.Header = buffer.NewPrependable(linkHeaderLen) + // TODO(b/151227689): Avoid copying the packet when forwarding. We can do this + // by having lower layers explicity write each header instead of just + // pkt.Header. + + // pkt may have set its NetworkHeader and TransportHeader. If we're + // forwarding, we'll have to copy them into pkt.Header. + pkt.Header = buffer.NewPrependable(int(n.linkEP.MaxHeaderLength()) + len(pkt.NetworkHeader) + len(pkt.TransportHeader)) + if n := copy(pkt.Header.Prepend(len(pkt.TransportHeader)), pkt.TransportHeader); n != len(pkt.TransportHeader) { + panic(fmt.Sprintf("copied %d bytes, expected %d", n, len(pkt.TransportHeader))) + } + if n := copy(pkt.Header.Prepend(len(pkt.NetworkHeader)), pkt.NetworkHeader); n != len(pkt.NetworkHeader) { + panic(fmt.Sprintf("copied %d bytes, expected %d", n, len(pkt.NetworkHeader))) } // WritePacket takes ownership of pkt, calculate numBytes first. @@ -1332,13 +1360,31 @@ func (n *NIC) DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolN // validly formed. n.stack.demux.deliverRawPacket(r, protocol, pkt) - transHeader, ok := pkt.Data.PullUp(transProto.MinimumPacketSize()) - if !ok { + // TransportHeader is nil only when pkt is an ICMP packet or was reassembled + // from fragments. + if pkt.TransportHeader == nil { + // TODO(gvisor.dev/issue/170): ICMP packets don't have their + // TransportHeader fields set. See icmp/protocol.go:protocol.Parse for a + // full explanation. + if protocol == header.ICMPv4ProtocolNumber || protocol == header.ICMPv6ProtocolNumber { + transHeader, ok := pkt.Data.PullUp(transProto.MinimumPacketSize()) + if !ok { + n.stack.stats.MalformedRcvdPackets.Increment() + return + } + pkt.TransportHeader = transHeader + } else { + // This is either a bad packet or was re-assembled from fragments. + transProto.Parse(pkt) + } + } + + if len(pkt.TransportHeader) < transProto.MinimumPacketSize() { n.stack.stats.MalformedRcvdPackets.Increment() return } - srcPort, dstPort, err := transProto.ParsePorts(transHeader) + srcPort, dstPort, err := transProto.ParsePorts(pkt.TransportHeader) if err != nil { n.stack.stats.MalformedRcvdPackets.Increment() return diff --git a/pkg/tcpip/stack/nic_test.go b/pkg/tcpip/stack/nic_test.go index fea46158c..31f865260 100644 --- a/pkg/tcpip/stack/nic_test.go +++ b/pkg/tcpip/stack/nic_test.go @@ -15,11 +15,268 @@ package stack import ( + "math" "testing" + "time" + "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/header" ) +var _ LinkEndpoint = (*testLinkEndpoint)(nil) + +// A LinkEndpoint that throws away outgoing packets. +// +// We use this instead of the channel endpoint as the channel package depends on +// the stack package which this test lives in, causing a cyclic dependency. +type testLinkEndpoint struct { + dispatcher NetworkDispatcher +} + +// Attach implements LinkEndpoint.Attach. +func (e *testLinkEndpoint) Attach(dispatcher NetworkDispatcher) { + e.dispatcher = dispatcher +} + +// IsAttached implements LinkEndpoint.IsAttached. +func (e *testLinkEndpoint) IsAttached() bool { + return e.dispatcher != nil +} + +// MTU implements LinkEndpoint.MTU. +func (*testLinkEndpoint) MTU() uint32 { + return math.MaxUint16 +} + +// Capabilities implements LinkEndpoint.Capabilities. +func (*testLinkEndpoint) Capabilities() LinkEndpointCapabilities { + return CapabilityResolutionRequired +} + +// MaxHeaderLength implements LinkEndpoint.MaxHeaderLength. +func (*testLinkEndpoint) MaxHeaderLength() uint16 { + return 0 +} + +// LinkAddress returns the link address of this endpoint. +func (*testLinkEndpoint) LinkAddress() tcpip.LinkAddress { + return "" +} + +// Wait implements LinkEndpoint.Wait. +func (*testLinkEndpoint) Wait() {} + +// WritePacket implements LinkEndpoint.WritePacket. +func (e *testLinkEndpoint) WritePacket(*Route, *GSO, tcpip.NetworkProtocolNumber, *PacketBuffer) *tcpip.Error { + return nil +} + +// WritePackets implements LinkEndpoint.WritePackets. +func (e *testLinkEndpoint) WritePackets(*Route, *GSO, PacketBufferList, tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { + // Our tests don't use this so we don't support it. + return 0, tcpip.ErrNotSupported +} + +// WriteRawPacket implements LinkEndpoint.WriteRawPacket. +func (e *testLinkEndpoint) WriteRawPacket(buffer.VectorisedView) *tcpip.Error { + // Our tests don't use this so we don't support it. + return tcpip.ErrNotSupported +} + +var _ NetworkEndpoint = (*testIPv6Endpoint)(nil) + +// An IPv6 NetworkEndpoint that throws away outgoing packets. +// +// We use this instead of ipv6.endpoint because the ipv6 package depends on +// the stack package which this test lives in, causing a cyclic dependency. +type testIPv6Endpoint struct { + nicID tcpip.NICID + id NetworkEndpointID + prefixLen int + linkEP LinkEndpoint + protocol *testIPv6Protocol +} + +// DefaultTTL implements NetworkEndpoint.DefaultTTL. +func (*testIPv6Endpoint) DefaultTTL() uint8 { + return 0 +} + +// MTU implements NetworkEndpoint.MTU. +func (e *testIPv6Endpoint) MTU() uint32 { + return e.linkEP.MTU() - header.IPv6MinimumSize +} + +// Capabilities implements NetworkEndpoint.Capabilities. +func (e *testIPv6Endpoint) Capabilities() LinkEndpointCapabilities { + return e.linkEP.Capabilities() +} + +// MaxHeaderLength implements NetworkEndpoint.MaxHeaderLength. +func (e *testIPv6Endpoint) MaxHeaderLength() uint16 { + return e.linkEP.MaxHeaderLength() + header.IPv6MinimumSize +} + +// WritePacket implements NetworkEndpoint.WritePacket. +func (*testIPv6Endpoint) WritePacket(*Route, *GSO, NetworkHeaderParams, *PacketBuffer) *tcpip.Error { + return nil +} + +// WritePackets implements NetworkEndpoint.WritePackets. +func (*testIPv6Endpoint) WritePackets(*Route, *GSO, PacketBufferList, NetworkHeaderParams) (int, *tcpip.Error) { + // Our tests don't use this so we don't support it. + return 0, tcpip.ErrNotSupported +} + +// WriteHeaderIncludedPacket implements +// NetworkEndpoint.WriteHeaderIncludedPacket. +func (*testIPv6Endpoint) WriteHeaderIncludedPacket(*Route, *PacketBuffer) *tcpip.Error { + // Our tests don't use this so we don't support it. + return tcpip.ErrNotSupported +} + +// ID implements NetworkEndpoint.ID. +func (e *testIPv6Endpoint) ID() *NetworkEndpointID { + return &e.id +} + +// PrefixLen implements NetworkEndpoint.PrefixLen. +func (e *testIPv6Endpoint) PrefixLen() int { + return e.prefixLen +} + +// NICID implements NetworkEndpoint.NICID. +func (e *testIPv6Endpoint) NICID() tcpip.NICID { + return e.nicID +} + +// HandlePacket implements NetworkEndpoint.HandlePacket. +func (*testIPv6Endpoint) HandlePacket(*Route, *PacketBuffer) { +} + +// Close implements NetworkEndpoint.Close. +func (*testIPv6Endpoint) Close() {} + +// NetworkProtocolNumber implements NetworkEndpoint.NetworkProtocolNumber. +func (*testIPv6Endpoint) NetworkProtocolNumber() tcpip.NetworkProtocolNumber { + return header.IPv6ProtocolNumber +} + +var _ NetworkProtocol = (*testIPv6Protocol)(nil) + +// An IPv6 NetworkProtocol that supports the bare minimum to make a stack +// believe it supports IPv6. +// +// We use this instead of ipv6.protocol because the ipv6 package depends on +// the stack package which this test lives in, causing a cyclic dependency. +type testIPv6Protocol struct{} + +// Number implements NetworkProtocol.Number. +func (*testIPv6Protocol) Number() tcpip.NetworkProtocolNumber { + return header.IPv6ProtocolNumber +} + +// MinimumPacketSize implements NetworkProtocol.MinimumPacketSize. +func (*testIPv6Protocol) MinimumPacketSize() int { + return header.IPv6MinimumSize +} + +// DefaultPrefixLen implements NetworkProtocol.DefaultPrefixLen. +func (*testIPv6Protocol) DefaultPrefixLen() int { + return header.IPv6AddressSize * 8 +} + +// ParseAddresses implements NetworkProtocol.ParseAddresses. +func (*testIPv6Protocol) ParseAddresses(v buffer.View) (src, dst tcpip.Address) { + h := header.IPv6(v) + return h.SourceAddress(), h.DestinationAddress() +} + +// NewEndpoint implements NetworkProtocol.NewEndpoint. +func (p *testIPv6Protocol) NewEndpoint(nicID tcpip.NICID, addrWithPrefix tcpip.AddressWithPrefix, _ LinkAddressCache, _ TransportDispatcher, linkEP LinkEndpoint, _ *Stack) (NetworkEndpoint, *tcpip.Error) { + return &testIPv6Endpoint{ + nicID: nicID, + id: NetworkEndpointID{LocalAddress: addrWithPrefix.Address}, + prefixLen: addrWithPrefix.PrefixLen, + linkEP: linkEP, + protocol: p, + }, nil +} + +// SetOption implements NetworkProtocol.SetOption. +func (*testIPv6Protocol) SetOption(interface{}) *tcpip.Error { + return nil +} + +// Option implements NetworkProtocol.Option. +func (*testIPv6Protocol) Option(interface{}) *tcpip.Error { + return nil +} + +// Close implements NetworkProtocol.Close. +func (*testIPv6Protocol) Close() {} + +// Wait implements NetworkProtocol.Wait. +func (*testIPv6Protocol) Wait() {} + +// Parse implements NetworkProtocol.Parse. +func (*testIPv6Protocol) Parse(*PacketBuffer) (tcpip.TransportProtocolNumber, bool, bool) { + return 0, false, false +} + +var _ LinkAddressResolver = (*testIPv6Protocol)(nil) + +// LinkAddressProtocol implements LinkAddressResolver. +func (*testIPv6Protocol) LinkAddressProtocol() tcpip.NetworkProtocolNumber { + return header.IPv6ProtocolNumber +} + +// LinkAddressRequest implements LinkAddressResolver. +func (*testIPv6Protocol) LinkAddressRequest(_, _ tcpip.Address, _ LinkEndpoint) *tcpip.Error { + return nil +} + +// ResolveStaticAddress implements LinkAddressResolver. +func (*testIPv6Protocol) ResolveStaticAddress(addr tcpip.Address) (tcpip.LinkAddress, bool) { + if header.IsV6MulticastAddress(addr) { + return header.EthernetAddressFromMulticastIPv6Address(addr), true + } + return "", false +} + +// Test the race condition where a NIC is removed and an RS timer fires at the +// same time. +func TestRemoveNICWhileHandlingRSTimer(t *testing.T) { + const ( + nicID = 1 + + maxRtrSolicitations = 5 + ) + + e := testLinkEndpoint{} + s := New(Options{ + NetworkProtocols: []NetworkProtocol{&testIPv6Protocol{}}, + NDPConfigs: NDPConfigurations{ + MaxRtrSolicitations: maxRtrSolicitations, + RtrSolicitationInterval: minimumRtrSolicitationInterval, + }, + }) + + if err := s.CreateNIC(nicID, &e); err != nil { + t.Fatalf("s.CreateNIC(%d, _) = %s", nicID, err) + } + + s.mu.Lock() + // Wait for the router solicitation timer to fire and block trying to obtain + // the stack lock when doing link address resolution. + time.Sleep(minimumRtrSolicitationInterval * 2) + if err := s.removeNICLocked(nicID); err != nil { + t.Fatalf("s.removeNICLocked(%d) = %s", nicID, err) + } + s.mu.Unlock() +} + func TestDisabledRxStatsWhenNICDisabled(t *testing.T) { // When the NIC is disabled, the only field that matters is the stats field. // This test is limited to stats counter checks. diff --git a/pkg/tcpip/stack/registration.go b/pkg/tcpip/stack/registration.go index 94f177841..5cbc946b6 100644 --- a/pkg/tcpip/stack/registration.go +++ b/pkg/tcpip/stack/registration.go @@ -168,6 +168,11 @@ type TransportProtocol interface { // Wait waits for any worker goroutines owned by the protocol to stop. Wait() + + // Parse sets pkt.TransportHeader and trims pkt.Data appropriately. It does + // neither and returns false if pkt.Data is too small, i.e. pkt.Data.Size() < + // MinimumPacketSize() + Parse(pkt *PacketBuffer) (ok bool) } // TransportDispatcher contains the methods used by the network stack to deliver @@ -313,6 +318,14 @@ type NetworkProtocol interface { // Wait waits for any worker goroutines owned by the protocol to stop. Wait() + + // Parse sets pkt.NetworkHeader and trims pkt.Data appropriately. It + // returns: + // - The encapsulated protocol, if present. + // - Whether there is an encapsulated transport protocol payload (e.g. ARP + // does not encapsulate anything). + // - Whether pkt.Data was large enough to parse and set pkt.NetworkHeader. + Parse(pkt *PacketBuffer) (proto tcpip.TransportProtocolNumber, hasTransportHdr bool, ok bool) } // NetworkDispatcher contains the methods used by the network stack to deliver diff --git a/pkg/tcpip/stack/route.go b/pkg/tcpip/stack/route.go index f5b6ca0b9..d65f8049e 100644 --- a/pkg/tcpip/stack/route.go +++ b/pkg/tcpip/stack/route.go @@ -113,6 +113,8 @@ func (r *Route) GSOMaxSize() uint32 { // If address resolution is required, ErrNoLinkAddress and a notification channel is // returned for the top level caller to block. Channel is closed once address resolution // is complete (success or not). +// +// The NIC r uses must not be locked. func (r *Route) Resolve(waker *sleep.Waker) (<-chan struct{}, *tcpip.Error) { if !r.IsResolutionRequired() { // Nothing to do if there is no cache (which does the resolution on cache miss) or @@ -148,6 +150,8 @@ func (r *Route) RemoveWaker(waker *sleep.Waker) { // IsResolutionRequired returns true if Resolve() must be called to resolve // the link address before the this route can be written to. +// +// The NIC r uses must not be locked. func (r *Route) IsResolutionRequired() bool { return r.ref.isValidForOutgoing() && r.ref.linkCache != nil && r.RemoteLinkAddress == "" } diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go index 294ce8775..648791302 100644 --- a/pkg/tcpip/stack/stack.go +++ b/pkg/tcpip/stack/stack.go @@ -1017,6 +1017,13 @@ func (s *Stack) RemoveNIC(id tcpip.NICID) *tcpip.Error { s.mu.Lock() defer s.mu.Unlock() + return s.removeNICLocked(id) +} + +// removeNICLocked removes NIC and all related routes from the network stack. +// +// s.mu must be locked. +func (s *Stack) removeNICLocked(id tcpip.NICID) *tcpip.Error { nic, ok := s.nics[id] if !ok { return tcpip.ErrUnknownNICID diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go index f6ddc3ced..ffef9bc2c 100644 --- a/pkg/tcpip/stack/stack_test.go +++ b/pkg/tcpip/stack/stack_test.go @@ -52,6 +52,10 @@ const ( // where another value is explicitly used. It is chosen to match the MTU // of loopback interfaces on linux systems. defaultMTU = 65536 + + dstAddrOffset = 0 + srcAddrOffset = 1 + protocolNumberOffset = 2 ) // fakeNetworkEndpoint is a network-layer protocol endpoint. It counts sent and @@ -94,26 +98,24 @@ func (f *fakeNetworkEndpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuff // Increment the received packet count in the protocol descriptor. f.proto.packetCount[int(f.id.LocalAddress[0])%len(f.proto.packetCount)]++ - // Consume the network header. - b, ok := pkt.Data.PullUp(fakeNetHeaderLen) - if !ok { - return - } - pkt.Data.TrimFront(fakeNetHeaderLen) - // Handle control packets. - if b[2] == uint8(fakeControlProtocol) { + if pkt.NetworkHeader[protocolNumberOffset] == uint8(fakeControlProtocol) { nb, ok := pkt.Data.PullUp(fakeNetHeaderLen) if !ok { return } pkt.Data.TrimFront(fakeNetHeaderLen) - f.dispatcher.DeliverTransportControlPacket(tcpip.Address(nb[1:2]), tcpip.Address(nb[0:1]), fakeNetNumber, tcpip.TransportProtocolNumber(nb[2]), stack.ControlPortUnreachable, 0, pkt) + f.dispatcher.DeliverTransportControlPacket( + tcpip.Address(nb[srcAddrOffset:srcAddrOffset+1]), + tcpip.Address(nb[dstAddrOffset:dstAddrOffset+1]), + fakeNetNumber, + tcpip.TransportProtocolNumber(nb[protocolNumberOffset]), + stack.ControlPortUnreachable, 0, pkt) return } // Dispatch the packet to the transport protocol. - f.dispatcher.DeliverTransportPacket(r, tcpip.TransportProtocolNumber(b[2]), pkt) + f.dispatcher.DeliverTransportPacket(r, tcpip.TransportProtocolNumber(pkt.NetworkHeader[protocolNumberOffset]), pkt) } func (f *fakeNetworkEndpoint) MaxHeaderLength() uint16 { @@ -138,18 +140,13 @@ func (f *fakeNetworkEndpoint) WritePacket(r *stack.Route, gso *stack.GSO, params // Add the protocol's header to the packet and send it to the link // endpoint. - b := pkt.Header.Prepend(fakeNetHeaderLen) - b[0] = r.RemoteAddress[0] - b[1] = f.id.LocalAddress[0] - b[2] = byte(params.Protocol) + pkt.NetworkHeader = pkt.Header.Prepend(fakeNetHeaderLen) + pkt.NetworkHeader[dstAddrOffset] = r.RemoteAddress[0] + pkt.NetworkHeader[srcAddrOffset] = f.id.LocalAddress[0] + pkt.NetworkHeader[protocolNumberOffset] = byte(params.Protocol) if r.Loop&stack.PacketLoop != 0 { - views := make([]buffer.View, 1, 1+len(pkt.Data.Views())) - views[0] = pkt.Header.View() - views = append(views, pkt.Data.Views()...) - f.HandlePacket(r, &stack.PacketBuffer{ - Data: buffer.NewVectorisedView(len(views[0])+pkt.Data.Size(), views), - }) + f.HandlePacket(r, pkt) } if r.Loop&stack.PacketOut == 0 { return nil @@ -205,7 +202,7 @@ func (f *fakeNetworkProtocol) PacketCount(intfAddr byte) int { } func (*fakeNetworkProtocol) ParseAddresses(v buffer.View) (src, dst tcpip.Address) { - return tcpip.Address(v[1:2]), tcpip.Address(v[0:1]) + return tcpip.Address(v[srcAddrOffset : srcAddrOffset+1]), tcpip.Address(v[dstAddrOffset : dstAddrOffset+1]) } func (f *fakeNetworkProtocol) NewEndpoint(nicID tcpip.NICID, addrWithPrefix tcpip.AddressWithPrefix, linkAddrCache stack.LinkAddressCache, dispatcher stack.TransportDispatcher, ep stack.LinkEndpoint, _ *stack.Stack) (stack.NetworkEndpoint, *tcpip.Error) { @@ -247,6 +244,17 @@ func (*fakeNetworkProtocol) Close() {} // Wait implements TransportProtocol.Wait. func (*fakeNetworkProtocol) Wait() {} +// Parse implements TransportProtocol.Parse. +func (*fakeNetworkProtocol) Parse(pkt *stack.PacketBuffer) (tcpip.TransportProtocolNumber, bool, bool) { + hdr, ok := pkt.Data.PullUp(fakeNetHeaderLen) + if !ok { + return 0, false, false + } + pkt.NetworkHeader = hdr + pkt.Data.TrimFront(fakeNetHeaderLen) + return tcpip.TransportProtocolNumber(hdr[protocolNumberOffset]), true, true +} + func fakeNetFactory() stack.NetworkProtocol { return &fakeNetworkProtocol{} } @@ -292,7 +300,7 @@ func TestNetworkReceive(t *testing.T) { buf := buffer.NewView(30) // Make sure packet with wrong address is not delivered. - buf[0] = 3 + buf[dstAddrOffset] = 3 ep.InjectInbound(fakeNetNumber, &stack.PacketBuffer{ Data: buf.ToVectorisedView(), }) @@ -304,7 +312,7 @@ func TestNetworkReceive(t *testing.T) { } // Make sure packet is delivered to first endpoint. - buf[0] = 1 + buf[dstAddrOffset] = 1 ep.InjectInbound(fakeNetNumber, &stack.PacketBuffer{ Data: buf.ToVectorisedView(), }) @@ -316,7 +324,7 @@ func TestNetworkReceive(t *testing.T) { } // Make sure packet is delivered to second endpoint. - buf[0] = 2 + buf[dstAddrOffset] = 2 ep.InjectInbound(fakeNetNumber, &stack.PacketBuffer{ Data: buf.ToVectorisedView(), }) @@ -982,7 +990,7 @@ func TestAddressRemoval(t *testing.T) { buf := buffer.NewView(30) // Send and receive packets, and verify they are received. - buf[0] = localAddrByte + buf[dstAddrOffset] = localAddrByte testRecv(t, fakeNet, localAddrByte, ep, buf) testSendTo(t, s, remoteAddr, ep, nil) @@ -1032,7 +1040,7 @@ func TestAddressRemovalWithRouteHeld(t *testing.T) { } // Send and receive packets, and verify they are received. - buf[0] = localAddrByte + buf[dstAddrOffset] = localAddrByte testRecv(t, fakeNet, localAddrByte, ep, buf) testSend(t, r, ep, nil) testSendTo(t, s, remoteAddr, ep, nil) @@ -1114,7 +1122,7 @@ func TestEndpointExpiration(t *testing.T) { fakeNet := s.NetworkProtocolInstance(fakeNetNumber).(*fakeNetworkProtocol) buf := buffer.NewView(30) - buf[0] = localAddrByte + buf[dstAddrOffset] = localAddrByte if promiscuous { if err := s.SetPromiscuousMode(nicID, true); err != nil { @@ -1277,7 +1285,7 @@ func TestPromiscuousMode(t *testing.T) { // Write a packet, and check that it doesn't get delivered as we don't // have a matching endpoint. const localAddrByte byte = 0x01 - buf[0] = localAddrByte + buf[dstAddrOffset] = localAddrByte testFailingRecv(t, fakeNet, localAddrByte, ep, buf) // Set promiscuous mode, then check that packet is delivered. @@ -1658,7 +1666,7 @@ func TestAddressRangeAcceptsMatchingPacket(t *testing.T) { buf := buffer.NewView(30) const localAddrByte byte = 0x01 - buf[0] = localAddrByte + buf[dstAddrOffset] = localAddrByte subnet, err := tcpip.NewSubnet(tcpip.Address("\x00"), tcpip.AddressMask("\xF0")) if err != nil { t.Fatal("NewSubnet failed:", err) @@ -1766,7 +1774,7 @@ func TestAddressRangeRejectsNonmatchingPacket(t *testing.T) { buf := buffer.NewView(30) const localAddrByte byte = 0x01 - buf[0] = localAddrByte + buf[dstAddrOffset] = localAddrByte subnet, err := tcpip.NewSubnet(tcpip.Address("\x10"), tcpip.AddressMask("\xF0")) if err != nil { t.Fatal("NewSubnet failed:", err) @@ -2344,7 +2352,7 @@ func TestNICForwarding(t *testing.T) { // Send a packet to dstAddr. buf := buffer.NewView(30) - buf[0] = dstAddr[0] + buf[dstAddrOffset] = dstAddr[0] ep1.InjectInbound(fakeNetNumber, &stack.PacketBuffer{ Data: buf.ToVectorisedView(), }) diff --git a/pkg/tcpip/stack/transport_test.go b/pkg/tcpip/stack/transport_test.go index cb350ead3..ad61c09d6 100644 --- a/pkg/tcpip/stack/transport_test.go +++ b/pkg/tcpip/stack/transport_test.go @@ -83,7 +83,8 @@ func (f *fakeTransportEndpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions return 0, nil, tcpip.ErrNoRoute } - hdr := buffer.NewPrependable(int(f.route.MaxHeaderLength())) + hdr := buffer.NewPrependable(int(f.route.MaxHeaderLength()) + fakeTransHeaderLen) + hdr.Prepend(fakeTransHeaderLen) v, err := p.FullPayload() if err != nil { return 0, nil, err @@ -324,6 +325,17 @@ func (*fakeTransportProtocol) Close() {} // Wait implements TransportProtocol.Wait. func (*fakeTransportProtocol) Wait() {} +// Parse implements TransportProtocol.Parse. +func (*fakeTransportProtocol) Parse(pkt *stack.PacketBuffer) bool { + hdr, ok := pkt.Data.PullUp(fakeTransHeaderLen) + if !ok { + return false + } + pkt.TransportHeader = hdr + pkt.Data.TrimFront(fakeTransHeaderLen) + return true +} + func fakeTransFactory() stack.TransportProtocol { return &fakeTransportProtocol{} } @@ -642,11 +654,10 @@ func TestTransportForwarding(t *testing.T) { t.Fatal("Response packet not forwarded") } - hdrs := p.Pkt.Data.ToView() - if dst := hdrs[0]; dst != 3 { + if dst := p.Pkt.NetworkHeader[0]; dst != 3 { t.Errorf("Response packet has incorrect destination addresss: got = %d, want = 3", dst) } - if src := hdrs[1]; src != 1 { + if src := p.Pkt.NetworkHeader[1]; src != 1 { t.Errorf("Response packet has incorrect source addresss: got = %d, want = 3", src) } } diff --git a/pkg/tcpip/transport/icmp/protocol.go b/pkg/tcpip/transport/icmp/protocol.go index 2ec6749c7..74ef6541e 100644 --- a/pkg/tcpip/transport/icmp/protocol.go +++ b/pkg/tcpip/transport/icmp/protocol.go @@ -124,6 +124,16 @@ func (*protocol) Close() {} // Wait implements stack.TransportProtocol.Wait. func (*protocol) Wait() {} +// Parse implements stack.TransportProtocol.Parse. +func (*protocol) Parse(pkt *stack.PacketBuffer) bool { + // TODO(gvisor.dev/issue/170): Implement parsing of ICMP. + // + // Right now, the Parse() method is tied to enabled protocols passed into + // stack.New. This works for UDP and TCP, but we handle ICMP traffic even + // when netstack users don't pass ICMP as a supported protocol. + return false +} + // NewProtocol4 returns an ICMPv4 transport protocol. func NewProtocol4() stack.TransportProtocol { return &protocol{ProtocolNumber4} diff --git a/pkg/tcpip/transport/raw/endpoint.go b/pkg/tcpip/transport/raw/endpoint.go index 21c34fac2..a406d815e 100644 --- a/pkg/tcpip/transport/raw/endpoint.go +++ b/pkg/tcpip/transport/raw/endpoint.go @@ -627,8 +627,9 @@ func (e *endpoint) HandlePacket(route *stack.Route, pkt *stack.PacketBuffer) { }, } - networkHeader := append(buffer.View(nil), pkt.NetworkHeader...) - combinedVV := networkHeader.ToVectorisedView() + headers := append(buffer.View(nil), pkt.NetworkHeader...) + headers = append(headers, pkt.TransportHeader...) + combinedVV := headers.ToVectorisedView() combinedVV.Append(pkt.Data) packet.data = combinedVV packet.timestampNS = e.stack.NowNanoseconds() diff --git a/pkg/tcpip/transport/tcp/BUILD b/pkg/tcpip/transport/tcp/BUILD index f38eb6833..e26f01fae 100644 --- a/pkg/tcpip/transport/tcp/BUILD +++ b/pkg/tcpip/transport/tcp/BUILD @@ -86,10 +86,6 @@ go_test( "tcp_test.go", "tcp_timestamp_test.go", ], - # FIXME(b/68809571) - tags = [ - "flaky", - ], deps = [ ":tcp", "//pkg/sync", diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go index edca98160..19f7bf449 100644 --- a/pkg/tcpip/transport/tcp/endpoint.go +++ b/pkg/tcpip/transport/tcp/endpoint.go @@ -63,7 +63,8 @@ const ( StateClosing ) -// connected is the set of states where an endpoint is connected to a peer. +// connected returns true when s is one of the states representing an +// endpoint connected to a peer. func (s EndpointState) connected() bool { switch s { case StateEstablished, StateFinWait1, StateFinWait2, StateTimeWait, StateCloseWait, StateLastAck, StateClosing: @@ -73,6 +74,40 @@ func (s EndpointState) connected() bool { } } +// connecting returns true when s is one of the states representing a +// connection in progress, but not yet fully established. +func (s EndpointState) connecting() bool { + switch s { + case StateConnecting, StateSynSent, StateSynRecv: + return true + default: + return false + } +} + +// handshake returns true when s is one of the states representing an endpoint +// in the middle of a TCP handshake. +func (s EndpointState) handshake() bool { + switch s { + case StateSynSent, StateSynRecv: + return true + default: + return false + } +} + +// closed returns true when s is one of the states an endpoint transitions to +// when closed or when it encounters an error. This is distinct from a newly +// initialized endpoint that was never connected. +func (s EndpointState) closed() bool { + switch s { + case StateClose, StateError: + return true + default: + return false + } +} + // String implements fmt.Stringer.String. func (s EndpointState) String() string { switch s { diff --git a/pkg/tcpip/transport/tcp/endpoint_state.go b/pkg/tcpip/transport/tcp/endpoint_state.go index fc43c11e2..cbb779666 100644 --- a/pkg/tcpip/transport/tcp/endpoint_state.go +++ b/pkg/tcpip/transport/tcp/endpoint_state.go @@ -49,11 +49,10 @@ func (e *endpoint) beforeSave() { e.mu.Lock() defer e.mu.Unlock() - switch e.EndpointState() { - case StateInitial, StateBound: - // TODO(b/138137272): this enumeration duplicates - // EndpointState.connected. remove it. - case StateEstablished, StateSynSent, StateSynRecv, StateFinWait1, StateFinWait2, StateTimeWait, StateCloseWait, StateLastAck, StateClosing: + epState := e.EndpointState() + switch { + case epState == StateInitial || epState == StateBound: + case epState.connected() || epState.handshake(): if e.route.Capabilities()&stack.CapabilitySaveRestore == 0 { if e.route.Capabilities()&stack.CapabilityDisconnectOk == 0 { panic(tcpip.ErrSaveRejection{fmt.Errorf("endpoint cannot be saved in connected state: local %v:%d, remote %v:%d", e.ID.LocalAddress, e.ID.LocalPort, e.ID.RemoteAddress, e.ID.RemotePort)}) @@ -69,15 +68,16 @@ func (e *endpoint) beforeSave() { break } fallthrough - case StateListen, StateConnecting: + case epState == StateListen || epState == StateConnecting: e.drainSegmentLocked() - if e.EndpointState() != StateClose && e.EndpointState() != StateError { + // Refresh epState, since drainSegmentLocked may have changed it. + epState = e.EndpointState() + if !epState.closed() { if !e.workerRunning { panic("endpoint has no worker running in listen, connecting, or connected state") } - break } - case StateError, StateClose: + case epState.closed(): for e.workerRunning { e.mu.Unlock() time.Sleep(100 * time.Millisecond) @@ -148,23 +148,23 @@ var connectingLoading sync.WaitGroup // Bound endpoint loading happens last. // loadState is invoked by stateify. -func (e *endpoint) loadState(state EndpointState) { +func (e *endpoint) loadState(epState EndpointState) { // This is to ensure that the loading wait groups include all applicable // endpoints before any asynchronous calls to the Wait() methods. // For restore purposes we treat TimeWait like a connected endpoint. - if state.connected() || state == StateTimeWait { + if epState.connected() || epState == StateTimeWait { connectedLoading.Add(1) } - switch state { - case StateListen: + switch { + case epState == StateListen: listenLoading.Add(1) - case StateConnecting, StateSynSent, StateSynRecv: + case epState.connecting(): connectingLoading.Add(1) } // Directly update the state here rather than using e.setEndpointState // as the endpoint is still being loaded and the stack reference is not // yet initialized. - atomic.StoreUint32((*uint32)(&e.state), uint32(state)) + atomic.StoreUint32((*uint32)(&e.state), uint32(epState)) } // afterLoad is invoked by stateify. @@ -183,8 +183,8 @@ func (e *endpoint) afterLoad() { func (e *endpoint) Resume(s *stack.Stack) { e.stack = s e.segmentQueue.setLimit(MaxUnprocessedSegments) - state := e.origEndpointState - switch state { + epState := e.origEndpointState + switch epState { case StateInitial, StateBound, StateListen, StateConnecting, StateEstablished: var ss SendBufferSizeOption if err := e.stack.TransportProtocolOption(ProtocolNumber, &ss); err == nil { @@ -208,8 +208,8 @@ func (e *endpoint) Resume(s *stack.Stack) { } } - switch state { - case StateEstablished, StateFinWait1, StateFinWait2, StateTimeWait, StateCloseWait, StateLastAck, StateClosing: + switch { + case epState.connected(): bind() if len(e.connectingAddress) == 0 { e.connectingAddress = e.ID.RemoteAddress @@ -232,13 +232,13 @@ func (e *endpoint) Resume(s *stack.Stack) { closed := e.closed e.mu.Unlock() e.notifyProtocolGoroutine(notifyTickleWorker) - if state == StateFinWait2 && closed { + if epState == StateFinWait2 && closed { // If the endpoint has been closed then make sure we notify so // that the FIN_WAIT2 timer is started after a restore. e.notifyProtocolGoroutine(notifyClose) } connectedLoading.Done() - case StateListen: + case epState == StateListen: tcpip.AsyncLoading.Add(1) go func() { connectedLoading.Wait() @@ -255,7 +255,7 @@ func (e *endpoint) Resume(s *stack.Stack) { listenLoading.Done() tcpip.AsyncLoading.Done() }() - case StateConnecting, StateSynSent, StateSynRecv: + case epState.connecting(): tcpip.AsyncLoading.Add(1) go func() { connectedLoading.Wait() @@ -267,7 +267,7 @@ func (e *endpoint) Resume(s *stack.Stack) { connectingLoading.Done() tcpip.AsyncLoading.Done() }() - case StateBound: + case epState == StateBound: tcpip.AsyncLoading.Add(1) go func() { connectedLoading.Wait() @@ -276,7 +276,7 @@ func (e *endpoint) Resume(s *stack.Stack) { bind() tcpip.AsyncLoading.Done() }() - case StateClose: + case epState == StateClose: if e.isPortReserved { tcpip.AsyncLoading.Add(1) go func() { @@ -291,12 +291,11 @@ func (e *endpoint) Resume(s *stack.Stack) { e.state = StateClose e.stack.CompleteTransportEndpointCleanup(e) tcpip.DeleteDanglingEndpoint(e) - case StateError: + case epState == StateError: e.state = StateError e.stack.CompleteTransportEndpointCleanup(e) tcpip.DeleteDanglingEndpoint(e) } - } // saveLastError is invoked by stateify. diff --git a/pkg/tcpip/transport/tcp/protocol.go b/pkg/tcpip/transport/tcp/protocol.go index c827d0277..73b8a6782 100644 --- a/pkg/tcpip/transport/tcp/protocol.go +++ b/pkg/tcpip/transport/tcp/protocol.go @@ -21,6 +21,7 @@ package tcp import ( + "fmt" "runtime" "strings" "time" @@ -490,6 +491,26 @@ func (p *protocol) SynRcvdCounter() *synRcvdCounter { return &p.synRcvdCount } +// Parse implements stack.TransportProtocol.Parse. +func (*protocol) Parse(pkt *stack.PacketBuffer) bool { + hdr, ok := pkt.Data.PullUp(header.TCPMinimumSize) + if !ok { + return false + } + + // If the header has options, pull those up as well. + if offset := int(header.TCP(hdr).DataOffset()); offset > header.TCPMinimumSize && offset <= pkt.Data.Size() { + hdr, ok = pkt.Data.PullUp(offset) + if !ok { + panic(fmt.Sprintf("There should be at least %d bytes in pkt.Data.", offset)) + } + } + + pkt.TransportHeader = hdr + pkt.Data.TrimFront(len(hdr)) + return true +} + // NewProtocol returns a TCP transport protocol. func NewProtocol() stack.TransportProtocol { return &protocol{ diff --git a/pkg/tcpip/transport/tcp/segment.go b/pkg/tcpip/transport/tcp/segment.go index 0c099e2fd..0280892a8 100644 --- a/pkg/tcpip/transport/tcp/segment.go +++ b/pkg/tcpip/transport/tcp/segment.go @@ -35,6 +35,7 @@ type segment struct { id stack.TransportEndpointID `state:"manual"` route stack.Route `state:"manual"` data buffer.VectorisedView `state:".(buffer.VectorisedView)"` + hdr header.TCP // views is used as buffer for data when its length is large // enough to store a VectorisedView. views [8]buffer.View `state:"nosave"` @@ -67,6 +68,7 @@ func newSegment(r *stack.Route, id stack.TransportEndpointID, pkt *stack.PacketB route: r.Clone(), } s.data = pkt.Data.Clone(s.views[:]) + s.hdr = header.TCP(pkt.TransportHeader) s.rcvdTime = time.Now() return s } @@ -146,12 +148,6 @@ func (s *segment) logicalLen() seqnum.Size { // TCP checksum and stores the checksum and result of checksum verification in // the csum and csumValid fields of the segment. func (s *segment) parse() bool { - h, ok := s.data.PullUp(header.TCPMinimumSize) - if !ok { - return false - } - hdr := header.TCP(h) - // h is the header followed by the payload. We check that the offset to // the data respects the following constraints: // 1. That it's at least the minimum header size; if we don't do this @@ -162,16 +158,12 @@ func (s *segment) parse() bool { // N.B. The segment has already been validated as having at least the // minimum TCP size before reaching here, so it's safe to read the // fields. - offset := int(hdr.DataOffset()) - if offset < header.TCPMinimumSize { - return false - } - hdrWithOpts, ok := s.data.PullUp(offset) - if !ok { + offset := int(s.hdr.DataOffset()) + if offset < header.TCPMinimumSize || offset > len(s.hdr) { return false } - s.options = []byte(hdrWithOpts[header.TCPMinimumSize:]) + s.options = []byte(s.hdr[header.TCPMinimumSize:]) s.parsedOptions = header.ParseTCPOptions(s.options) // Query the link capabilities to decide if checksum validation is @@ -180,22 +172,19 @@ func (s *segment) parse() bool { if s.route.Capabilities()&stack.CapabilityRXChecksumOffload != 0 { s.csumValid = true verifyChecksum = false - s.data.TrimFront(offset) } if verifyChecksum { - hdr = header.TCP(hdrWithOpts) - s.csum = hdr.Checksum() - xsum := s.route.PseudoHeaderChecksum(ProtocolNumber, uint16(s.data.Size())) - xsum = hdr.CalculateChecksum(xsum) - s.data.TrimFront(offset) + s.csum = s.hdr.Checksum() + xsum := s.route.PseudoHeaderChecksum(ProtocolNumber, uint16(s.data.Size()+len(s.hdr))) + xsum = s.hdr.CalculateChecksum(xsum) xsum = header.ChecksumVV(s.data, xsum) s.csumValid = xsum == 0xffff } - s.sequenceNumber = seqnum.Value(hdr.SequenceNumber()) - s.ackNumber = seqnum.Value(hdr.AckNumber()) - s.flags = hdr.Flags() - s.window = seqnum.Size(hdr.WindowSize()) + s.sequenceNumber = seqnum.Value(s.hdr.SequenceNumber()) + s.ackNumber = seqnum.Value(s.hdr.AckNumber()) + s.flags = s.hdr.Flags() + s.window = seqnum.Size(s.hdr.WindowSize()) return true } diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go index 663af8fec..8c7895713 100644 --- a/pkg/tcpip/transport/udp/endpoint.go +++ b/pkg/tcpip/transport/udp/endpoint.go @@ -1270,16 +1270,14 @@ func (e *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask { // endpoint. func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pkt *stack.PacketBuffer) { // Get the header then trim it from the view. - hdr, ok := pkt.Data.PullUp(header.UDPMinimumSize) - if !ok || int(header.UDP(hdr).Length()) > pkt.Data.Size() { + hdr := header.UDP(pkt.TransportHeader) + if int(hdr.Length()) > pkt.Data.Size()+header.UDPMinimumSize { // Malformed packet. e.stack.Stats().UDP.MalformedPacketsReceived.Increment() e.stats.ReceiveErrors.MalformedPacketsReceived.Increment() return } - pkt.Data.TrimFront(header.UDPMinimumSize) - e.rcvMu.Lock() e.stack.Stats().UDP.PacketsReceived.Increment() e.stats.PacketsReceived.Increment() diff --git a/pkg/tcpip/transport/udp/protocol.go b/pkg/tcpip/transport/udp/protocol.go index e320c5758..4218e7d03 100644 --- a/pkg/tcpip/transport/udp/protocol.go +++ b/pkg/tcpip/transport/udp/protocol.go @@ -67,14 +67,8 @@ func (*protocol) ParsePorts(v buffer.View) (src, dst uint16, err *tcpip.Error) { // HandleUnknownDestinationPacket handles packets targeted at this protocol but // that don't match any existing endpoint. func (p *protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.TransportEndpointID, pkt *stack.PacketBuffer) bool { - // Get the header then trim it from the view. - h, ok := pkt.Data.PullUp(header.UDPMinimumSize) - if !ok { - // Malformed packet. - r.Stack().Stats().UDP.MalformedPacketsReceived.Increment() - return true - } - if int(header.UDP(h).Length()) > pkt.Data.Size() { + hdr := header.UDP(pkt.TransportHeader) + if int(hdr.Length()) > pkt.Data.Size()+header.UDPMinimumSize { // Malformed packet. r.Stack().Stats().UDP.MalformedPacketsReceived.Increment() return true @@ -121,7 +115,7 @@ func (p *protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.Trans } headerLen := int(r.MaxHeaderLength()) + header.ICMPv4MinimumSize available := int(mtu) - headerLen - payloadLen := len(pkt.NetworkHeader) + pkt.Data.Size() + payloadLen := len(pkt.NetworkHeader) + len(pkt.TransportHeader) + pkt.Data.Size() if payloadLen > available { payloadLen = available } @@ -130,9 +124,10 @@ func (p *protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.Trans // For example, a raw or packet socket may use what UDP // considers an unreachable destination. Thus we deep copy pkt // to prevent multiple ownership and SR errors. - newNetHeader := append(buffer.View(nil), pkt.NetworkHeader...) - payload := newNetHeader.ToVectorisedView() - payload.Append(pkt.Data.ToView().ToVectorisedView()) + newHeader := append(buffer.View(nil), pkt.NetworkHeader...) + newHeader = append(newHeader, pkt.TransportHeader...) + payload := newHeader.ToVectorisedView() + payload.AppendView(pkt.Data.ToView()) payload.CapLength(payloadLen) hdr := buffer.NewPrependable(headerLen) @@ -141,8 +136,9 @@ func (p *protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.Trans pkt.SetCode(header.ICMPv4PortUnreachable) pkt.SetChecksum(header.ICMPv4Checksum(pkt, payload)) r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv4ProtocolNumber, TTL: r.DefaultTTL(), TOS: stack.DefaultTOS}, &stack.PacketBuffer{ - Header: hdr, - Data: payload, + Header: hdr, + TransportHeader: buffer.View(pkt), + Data: payload, }) case header.IPv6AddressSize: @@ -164,11 +160,11 @@ func (p *protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.Trans } headerLen := int(r.MaxHeaderLength()) + header.ICMPv6DstUnreachableMinimumSize available := int(mtu) - headerLen - payloadLen := len(pkt.NetworkHeader) + pkt.Data.Size() + payloadLen := len(pkt.NetworkHeader) + len(pkt.TransportHeader) + pkt.Data.Size() if payloadLen > available { payloadLen = available } - payload := buffer.NewVectorisedView(len(pkt.NetworkHeader), []buffer.View{pkt.NetworkHeader}) + payload := buffer.NewVectorisedView(len(pkt.NetworkHeader)+len(pkt.TransportHeader), []buffer.View{pkt.NetworkHeader, pkt.TransportHeader}) payload.Append(pkt.Data) payload.CapLength(payloadLen) @@ -178,8 +174,9 @@ func (p *protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.Trans pkt.SetCode(header.ICMPv6PortUnreachable) pkt.SetChecksum(header.ICMPv6Checksum(pkt, r.LocalAddress, r.RemoteAddress, payload)) r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: r.DefaultTTL(), TOS: stack.DefaultTOS}, &stack.PacketBuffer{ - Header: hdr, - Data: payload, + Header: hdr, + TransportHeader: buffer.View(pkt), + Data: payload, }) } return true @@ -201,6 +198,18 @@ func (*protocol) Close() {} // Wait implements stack.TransportProtocol.Wait. func (*protocol) Wait() {} +// Parse implements stack.TransportProtocol.Parse. +func (*protocol) Parse(pkt *stack.PacketBuffer) bool { + h, ok := pkt.Data.PullUp(header.UDPMinimumSize) + if !ok { + // Packet is too small + return false + } + pkt.TransportHeader = h + pkt.Data.TrimFront(header.UDPMinimumSize) + return true +} + // NewProtocol returns a UDP transport protocol. func NewProtocol() stack.TransportProtocol { return &protocol{} diff --git a/pkg/tcpip/transport/udp/udp_test.go b/pkg/tcpip/transport/udp/udp_test.go index e8ade882b..313a3f117 100644 --- a/pkg/tcpip/transport/udp/udp_test.go +++ b/pkg/tcpip/transport/udp/udp_test.go @@ -441,9 +441,7 @@ func (c *testContext) injectV6Packet(payload []byte, h *header4Tuple, valid bool // Inject packet. c.linkEP.InjectInbound(ipv6.ProtocolNumber, &stack.PacketBuffer{ - Data: buf.ToVectorisedView(), - NetworkHeader: buffer.View(ip), - TransportHeader: buffer.View(u), + Data: buf.ToVectorisedView(), }) } @@ -488,9 +486,7 @@ func (c *testContext) injectV4Packet(payload []byte, h *header4Tuple, valid bool // Inject packet. c.linkEP.InjectInbound(ipv4.ProtocolNumber, &stack.PacketBuffer{ - Data: buf.ToVectorisedView(), - NetworkHeader: buffer.View(ip), - TransportHeader: buffer.View(u), + Data: buf.ToVectorisedView(), }) } @@ -1720,6 +1716,58 @@ func TestIncrementMalformedPacketsReceived(t *testing.T) { } } +// TestShortHeader verifies that when a packet with a too-short UDP header is +// received, the malformed received global stat gets incremented. +func TestShortHeader(t *testing.T) { + c := newDualTestContext(t, defaultMTU) + defer c.cleanup() + + c.createEndpoint(ipv6.ProtocolNumber) + // Bind to wildcard. + if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil { + c.t.Fatalf("Bind failed: %s", err) + } + + c.t.Helper() + h := unicastV6.header4Tuple(incoming) + + // Allocate a buffer for an IPv6 and too-short UDP header. + const udpSize = header.UDPMinimumSize - 1 + buf := buffer.NewView(header.IPv6MinimumSize + udpSize) + // Initialize the IP header. + ip := header.IPv6(buf) + ip.Encode(&header.IPv6Fields{ + TrafficClass: testTOS, + PayloadLength: uint16(udpSize), + NextHeader: uint8(udp.ProtocolNumber), + HopLimit: 65, + SrcAddr: h.srcAddr.Addr, + DstAddr: h.dstAddr.Addr, + }) + + // Initialize the UDP header. + udpHdr := header.UDP(buffer.NewView(header.UDPMinimumSize)) + udpHdr.Encode(&header.UDPFields{ + SrcPort: h.srcAddr.Port, + DstPort: h.dstAddr.Port, + Length: header.UDPMinimumSize, + }) + // Calculate the UDP pseudo-header checksum. + xsum := header.PseudoHeaderChecksum(udp.ProtocolNumber, h.srcAddr.Addr, h.dstAddr.Addr, uint16(len(udpHdr))) + udpHdr.SetChecksum(^udpHdr.CalculateChecksum(xsum)) + // Copy all but the last byte of the UDP header into the packet. + copy(buf[header.IPv6MinimumSize:], udpHdr) + + // Inject packet. + c.linkEP.InjectInbound(ipv6.ProtocolNumber, &stack.PacketBuffer{ + Data: buf.ToVectorisedView(), + }) + + if got, want := c.s.Stats().MalformedRcvdPackets.Value(), uint64(1); got != want { + t.Errorf("got c.s.Stats().MalformedRcvdPackets.Value() = %d, want = %d", got, want) + } +} + // TestShutdownRead verifies endpoint read shutdown and error // stats increment on packet receive. func TestShutdownRead(t *testing.T) { |