diff options
Diffstat (limited to 'pkg/tcpip/network')
-rw-r--r-- | pkg/tcpip/network/BUILD | 2 | ||||
-rw-r--r-- | pkg/tcpip/network/arp/BUILD | 7 | ||||
-rw-r--r-- | pkg/tcpip/network/arp/arp.go | 65 | ||||
-rw-r--r-- | pkg/tcpip/network/arp/arp_test.go | 29 | ||||
-rw-r--r-- | pkg/tcpip/network/fragmentation/BUILD | 18 | ||||
-rw-r--r-- | pkg/tcpip/network/fragmentation/fragmentation.go | 16 | ||||
-rw-r--r-- | pkg/tcpip/network/fragmentation/fragmentation_test.go | 10 | ||||
-rw-r--r-- | pkg/tcpip/network/fragmentation/reassembler.go | 10 | ||||
-rw-r--r-- | pkg/tcpip/network/ip_test.go | 79 | ||||
-rw-r--r-- | pkg/tcpip/network/ipv4/BUILD | 7 | ||||
-rw-r--r-- | pkg/tcpip/network/ipv4/icmp.go | 40 | ||||
-rw-r--r-- | pkg/tcpip/network/ipv4/ipv4.go | 278 | ||||
-rw-r--r-- | pkg/tcpip/network/ipv4/ipv4_test.go | 195 | ||||
-rw-r--r-- | pkg/tcpip/network/ipv6/BUILD | 9 | ||||
-rw-r--r-- | pkg/tcpip/network/ipv6/icmp.go | 254 | ||||
-rw-r--r-- | pkg/tcpip/network/ipv6/icmp_test.go | 617 | ||||
-rw-r--r-- | pkg/tcpip/network/ipv6/ipv6.go | 147 | ||||
-rw-r--r-- | pkg/tcpip/network/ipv6/ipv6_test.go | 270 | ||||
-rw-r--r-- | pkg/tcpip/network/ipv6/ndp_test.go | 219 |
19 files changed, 1871 insertions, 401 deletions
diff --git a/pkg/tcpip/network/BUILD b/pkg/tcpip/network/BUILD index f36f49453..9d16ff8c9 100644 --- a/pkg/tcpip/network/BUILD +++ b/pkg/tcpip/network/BUILD @@ -1,4 +1,4 @@ -load("//tools/go_stateify:defs.bzl", "go_test") +load("@io_bazel_rules_go//go:def.bzl", "go_test") package(licenses = ["notice"]) diff --git a/pkg/tcpip/network/arp/BUILD b/pkg/tcpip/network/arp/BUILD index d95d44f56..e7617229b 100644 --- a/pkg/tcpip/network/arp/BUILD +++ b/pkg/tcpip/network/arp/BUILD @@ -1,4 +1,5 @@ -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") +load("@io_bazel_rules_go//go:def.bzl", "go_test") package(licenses = ["notice"]) @@ -6,9 +7,7 @@ go_library( name = "arp", srcs = ["arp.go"], importpath = "gvisor.dev/gvisor/pkg/tcpip/network/arp", - visibility = [ - "//visibility:public", - ], + visibility = ["//visibility:public"], deps = [ "//pkg/tcpip", "//pkg/tcpip/buffer", diff --git a/pkg/tcpip/network/arp/arp.go b/pkg/tcpip/network/arp/arp.go index fd6395fc1..da8482509 100644 --- a/pkg/tcpip/network/arp/arp.go +++ b/pkg/tcpip/network/arp/arp.go @@ -16,9 +16,9 @@ // IPv4 addresses into link-local MAC addresses, and advertises IPv4 // addresses of its stack with the local network. // -// To use it in the networking stack, pass arp.ProtocolName as one of the -// network protocols when calling stack.New. Then add an "arp" address to -// every NIC on the stack that should respond to ARP requests. That is: +// To use it in the networking stack, pass arp.NewProtocol() as one of the +// network protocols when calling stack.New. Then add an "arp" address to every +// NIC on the stack that should respond to ARP requests. That is: // // if err := s.AddAddress(1, arp.ProtocolNumber, "arp"); err != nil { // // handle err @@ -33,9 +33,6 @@ import ( ) const ( - // ProtocolName is the string representation of the ARP protocol name. - ProtocolName = "arp" - // ProtocolNumber is the ARP protocol number. ProtocolNumber = header.ARPProtocolNumber @@ -45,7 +42,7 @@ const ( // endpoint implements stack.NetworkEndpoint. type endpoint struct { - nicid tcpip.NICID + nicID tcpip.NICID linkEP stack.LinkEndpoint linkAddrCache stack.LinkAddressCache } @@ -61,7 +58,7 @@ func (e *endpoint) MTU() uint32 { } func (e *endpoint) NICID() tcpip.NICID { - return e.nicid + return e.nicID } func (e *endpoint) Capabilities() stack.LinkEndpointCapabilities { @@ -82,16 +79,21 @@ func (e *endpoint) MaxHeaderLength() uint16 { func (e *endpoint) Close() {} -func (e *endpoint) WritePacket(*stack.Route, *stack.GSO, buffer.Prependable, buffer.VectorisedView, tcpip.TransportProtocolNumber, uint8, stack.PacketLooping) *tcpip.Error { +func (e *endpoint) WritePacket(*stack.Route, *stack.GSO, stack.NetworkHeaderParams, stack.PacketLooping, tcpip.PacketBuffer) *tcpip.Error { return tcpip.ErrNotSupported } -func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, payload buffer.VectorisedView, loop stack.PacketLooping) *tcpip.Error { +// WritePackets implements stack.NetworkEndpoint.WritePackets. +func (e *endpoint) WritePackets(*stack.Route, *stack.GSO, []tcpip.PacketBuffer, stack.NetworkHeaderParams, stack.PacketLooping) (int, *tcpip.Error) { + return 0, tcpip.ErrNotSupported +} + +func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, loop stack.PacketLooping, pkt tcpip.PacketBuffer) *tcpip.Error { return tcpip.ErrNotSupported } -func (e *endpoint) HandlePacket(r *stack.Route, vv buffer.VectorisedView) { - v := vv.First() +func (e *endpoint) HandlePacket(r *stack.Route, pkt tcpip.PacketBuffer) { + v := pkt.Data.First() h := header.ARP(v) if !h.IsValid() { return @@ -100,19 +102,25 @@ func (e *endpoint) HandlePacket(r *stack.Route, vv buffer.VectorisedView) { switch h.Op() { case header.ARPRequest: localAddr := tcpip.Address(h.ProtocolAddressTarget()) - if e.linkAddrCache.CheckLocalAddress(e.nicid, header.IPv4ProtocolNumber, localAddr) == 0 { + if e.linkAddrCache.CheckLocalAddress(e.nicID, header.IPv4ProtocolNumber, localAddr) == 0 { return // we have no useful answer, ignore the request } hdr := buffer.NewPrependable(int(e.linkEP.MaxHeaderLength()) + header.ARPSize) - pkt := header.ARP(hdr.Prepend(header.ARPSize)) - pkt.SetIPv4OverEthernet() - pkt.SetOp(header.ARPReply) - copy(pkt.HardwareAddressSender(), r.LocalLinkAddress[:]) - copy(pkt.ProtocolAddressSender(), h.ProtocolAddressTarget()) - copy(pkt.HardwareAddressTarget(), h.HardwareAddressSender()) - copy(pkt.ProtocolAddressTarget(), h.ProtocolAddressSender()) - e.linkEP.WritePacket(r, nil /* gso */, hdr, buffer.VectorisedView{}, ProtocolNumber) + packet := header.ARP(hdr.Prepend(header.ARPSize)) + packet.SetIPv4OverEthernet() + packet.SetOp(header.ARPReply) + copy(packet.HardwareAddressSender(), r.LocalLinkAddress[:]) + copy(packet.ProtocolAddressSender(), h.ProtocolAddressTarget()) + copy(packet.HardwareAddressTarget(), h.HardwareAddressSender()) + copy(packet.ProtocolAddressTarget(), h.ProtocolAddressSender()) + e.linkEP.WritePacket(r, nil /* gso */, ProtocolNumber, tcpip.PacketBuffer{ + Header: hdr, + }) + fallthrough // also fill the cache from requests case header.ARPReply: + addr := tcpip.Address(h.ProtocolAddressSender()) + linkAddr := tcpip.LinkAddress(h.HardwareAddressSender()) + e.linkAddrCache.AddLinkAddress(e.nicID, addr, linkAddr) } } @@ -129,12 +137,12 @@ func (*protocol) ParseAddresses(v buffer.View) (src, dst tcpip.Address) { return tcpip.Address(h.ProtocolAddressSender()), ProtocolAddress } -func (p *protocol) NewEndpoint(nicid tcpip.NICID, addrWithPrefix tcpip.AddressWithPrefix, linkAddrCache stack.LinkAddressCache, dispatcher stack.TransportDispatcher, sender stack.LinkEndpoint) (stack.NetworkEndpoint, *tcpip.Error) { +func (p *protocol) NewEndpoint(nicID tcpip.NICID, addrWithPrefix tcpip.AddressWithPrefix, linkAddrCache stack.LinkAddressCache, dispatcher stack.TransportDispatcher, sender stack.LinkEndpoint) (stack.NetworkEndpoint, *tcpip.Error) { if addrWithPrefix.Address != ProtocolAddress { return nil, tcpip.ErrBadLocalAddress } return &endpoint{ - nicid: nicid, + nicID: nicID, linkEP: sender, linkAddrCache: linkAddrCache, }, nil @@ -159,7 +167,9 @@ func (*protocol) LinkAddressRequest(addr, localAddr tcpip.Address, linkEP stack. copy(h.ProtocolAddressSender(), localAddr) copy(h.ProtocolAddressTarget(), addr) - return linkEP.WritePacket(r, nil /* gso */, hdr, buffer.VectorisedView{}, ProtocolNumber) + return linkEP.WritePacket(r, nil /* gso */, ProtocolNumber, tcpip.PacketBuffer{ + Header: hdr, + }) } // ResolveStaticAddress implements stack.LinkAddressResolver. @@ -200,8 +210,7 @@ func (p *protocol) Option(option interface{}) *tcpip.Error { var broadcastMAC = tcpip.LinkAddress([]byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff}) -func init() { - stack.RegisterNetworkProtocolFactory(ProtocolName, func() stack.NetworkProtocol { - return &protocol{} - }) +// NewProtocol returns an ARP network protocol. +func NewProtocol() stack.NetworkProtocol { + return &protocol{} } diff --git a/pkg/tcpip/network/arp/arp_test.go b/pkg/tcpip/network/arp/arp_test.go index 4c4b54469..8e6048a21 100644 --- a/pkg/tcpip/network/arp/arp_test.go +++ b/pkg/tcpip/network/arp/arp_test.go @@ -44,14 +44,19 @@ type testContext struct { } func newTestContext(t *testing.T) *testContext { - s := stack.New([]string{ipv4.ProtocolName, arp.ProtocolName}, []string{icmp.ProtocolName4}, stack.Options{}) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol(), arp.NewProtocol()}, + TransportProtocols: []stack.TransportProtocol{icmp.NewProtocol4()}, + }) const defaultMTU = 65536 - id, linkEP := channel.New(256, defaultMTU, stackLinkAddr) + ep := channel.New(256, defaultMTU, stackLinkAddr) + wep := stack.LinkEndpoint(ep) + if testing.Verbose() { - id = sniffer.New(id) + wep = sniffer.New(ep) } - if err := s.CreateNIC(1, id); err != nil { + if err := s.CreateNIC(1, wep); err != nil { t.Fatalf("CreateNIC failed: %v", err) } @@ -73,7 +78,7 @@ func newTestContext(t *testing.T) *testContext { return &testContext{ t: t, s: s, - linkEP: linkEP, + linkEP: ep, } } @@ -97,19 +102,21 @@ func TestDirectRequest(t *testing.T) { inject := func(addr tcpip.Address) { copy(h.ProtocolAddressTarget(), addr) - c.linkEP.Inject(arp.ProtocolNumber, v.ToVectorisedView()) + c.linkEP.InjectInbound(arp.ProtocolNumber, tcpip.PacketBuffer{ + Data: v.ToVectorisedView(), + }) } for i, address := range []tcpip.Address{stackAddr1, stackAddr2} { t.Run(strconv.Itoa(i), func(t *testing.T) { inject(address) - pkt := <-c.linkEP.C - if pkt.Proto != arp.ProtocolNumber { - t.Fatalf("expected ARP response, got network protocol number %d", pkt.Proto) + pi := <-c.linkEP.C + if pi.Proto != arp.ProtocolNumber { + t.Fatalf("expected ARP response, got network protocol number %d", pi.Proto) } - rep := header.ARP(pkt.Header) + rep := header.ARP(pi.Pkt.Header.View()) if !rep.IsValid() { - t.Fatalf("invalid ARP response len(pkt.Header)=%d", len(pkt.Header)) + t.Fatalf("invalid ARP response pi.Pkt.Header.UsedLength()=%d", pi.Pkt.Header.UsedLength()) } if got, want := tcpip.LinkAddress(rep.HardwareAddressSender()), stackLinkAddr; got != want { t.Errorf("got HardwareAddressSender = %s, want = %s", got, want) diff --git a/pkg/tcpip/network/fragmentation/BUILD b/pkg/tcpip/network/fragmentation/BUILD index 118bfc763..acf1e022c 100644 --- a/pkg/tcpip/network/fragmentation/BUILD +++ b/pkg/tcpip/network/fragmentation/BUILD @@ -1,7 +1,8 @@ -package(licenses = ["notice"]) - +load("@io_bazel_rules_go//go:def.bzl", "go_test") load("//tools/go_generics:defs.bzl", "go_template_instance") -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") + +package(licenses = ["notice"]) go_template_instance( name = "reassembler_list", @@ -24,9 +25,10 @@ go_library( "reassembler_list.go", ], importpath = "gvisor.dev/gvisor/pkg/tcpip/network/fragmentation", - visibility = ["//:sandbox"], + visibility = ["//visibility:public"], deps = [ "//pkg/log", + "//pkg/tcpip", "//pkg/tcpip/buffer", ], ) @@ -42,11 +44,3 @@ go_test( embed = [":fragmentation"], deps = ["//pkg/tcpip/buffer"], ) - -filegroup( - name = "autogen", - srcs = [ - "reassembler_list.go", - ], - visibility = ["//:sandbox"], -) diff --git a/pkg/tcpip/network/fragmentation/fragmentation.go b/pkg/tcpip/network/fragmentation/fragmentation.go index 1628a82be..6da5238ec 100644 --- a/pkg/tcpip/network/fragmentation/fragmentation.go +++ b/pkg/tcpip/network/fragmentation/fragmentation.go @@ -17,6 +17,7 @@ package fragmentation import ( + "fmt" "log" "sync" "time" @@ -82,7 +83,7 @@ func NewFragmentation(highMemoryLimit, lowMemoryLimit int, reassemblingTimeout t // Process processes an incoming fragment belonging to an ID // and returns a complete packet when all the packets belonging to that ID have been received. -func (f *Fragmentation) Process(id uint32, first, last uint16, more bool, vv buffer.VectorisedView) (buffer.VectorisedView, bool) { +func (f *Fragmentation) Process(id uint32, first, last uint16, more bool, vv buffer.VectorisedView) (buffer.VectorisedView, bool, error) { f.mu.Lock() r, ok := f.reassemblers[id] if ok && r.tooOld(f.timeout) { @@ -97,8 +98,15 @@ func (f *Fragmentation) Process(id uint32, first, last uint16, more bool, vv buf } f.mu.Unlock() - res, done, consumed := r.process(first, last, more, vv) - + res, done, consumed, err := r.process(first, last, more, vv) + if err != nil { + // We probably got an invalid sequence of fragments. Just + // discard the reassembler and move on. + f.mu.Lock() + f.release(r) + f.mu.Unlock() + return buffer.VectorisedView{}, false, fmt.Errorf("fragmentation processing error: %v", err) + } f.mu.Lock() f.size += consumed if done { @@ -114,7 +122,7 @@ func (f *Fragmentation) Process(id uint32, first, last uint16, more bool, vv buf } } f.mu.Unlock() - return res, done + return res, done, nil } func (f *Fragmentation) release(r *reassembler) { diff --git a/pkg/tcpip/network/fragmentation/fragmentation_test.go b/pkg/tcpip/network/fragmentation/fragmentation_test.go index 799798544..72c0f53be 100644 --- a/pkg/tcpip/network/fragmentation/fragmentation_test.go +++ b/pkg/tcpip/network/fragmentation/fragmentation_test.go @@ -83,7 +83,10 @@ func TestFragmentationProcess(t *testing.T) { t.Run(c.comment, func(t *testing.T) { f := NewFragmentation(1024, 512, DefaultReassembleTimeout) for i, in := range c.in { - vv, done := f.Process(in.id, in.first, in.last, in.more, in.vv) + vv, done, err := f.Process(in.id, in.first, in.last, in.more, in.vv) + if err != nil { + t.Fatalf("f.Process(%+v, %+d, %+d, %t, %+v) failed: %v", in.id, in.first, in.last, in.more, in.vv, err) + } if !reflect.DeepEqual(vv, c.out[i].vv) { t.Errorf("got Process(%d) = %+v, want = %+v", i, vv, c.out[i].vv) } @@ -114,7 +117,10 @@ func TestReassemblingTimeout(t *testing.T) { time.Sleep(2 * timeout) // Send another fragment that completes a packet. // However, no packet should be reassembled because the fragment arrived after the timeout. - _, done := f.Process(0, 1, 1, false, vv(1, "1")) + _, done, err := f.Process(0, 1, 1, false, vv(1, "1")) + if err != nil { + t.Fatalf("f.Process(0, 1, 1, false, vv(1, \"1\")) failed: %v", err) + } if done { t.Errorf("Fragmentation does not respect the reassembling timeout.") } diff --git a/pkg/tcpip/network/fragmentation/reassembler.go b/pkg/tcpip/network/fragmentation/reassembler.go index 8037f734b..9e002e396 100644 --- a/pkg/tcpip/network/fragmentation/reassembler.go +++ b/pkg/tcpip/network/fragmentation/reassembler.go @@ -78,7 +78,7 @@ func (r *reassembler) updateHoles(first, last uint16, more bool) bool { return used } -func (r *reassembler) process(first, last uint16, more bool, vv buffer.VectorisedView) (buffer.VectorisedView, bool, int) { +func (r *reassembler) process(first, last uint16, more bool, vv buffer.VectorisedView) (buffer.VectorisedView, bool, int, error) { r.mu.Lock() defer r.mu.Unlock() consumed := 0 @@ -86,7 +86,7 @@ func (r *reassembler) process(first, last uint16, more bool, vv buffer.Vectorise // A concurrent goroutine might have already reassembled // the packet and emptied the heap while this goroutine // was waiting on the mutex. We don't have to do anything in this case. - return buffer.VectorisedView{}, false, consumed + return buffer.VectorisedView{}, false, consumed, nil } if r.updateHoles(first, last, more) { // We store the incoming packet only if it filled some holes. @@ -96,13 +96,13 @@ func (r *reassembler) process(first, last uint16, more bool, vv buffer.Vectorise } // Check if all the holes have been deleted and we are ready to reassamble. if r.deleted < len(r.holes) { - return buffer.VectorisedView{}, false, consumed + return buffer.VectorisedView{}, false, consumed, nil } res, err := r.heap.reassemble() if err != nil { - panic(fmt.Sprintf("reassemble failed with: %v. There is probably a bug in the code handling the holes.", err)) + return buffer.VectorisedView{}, false, consumed, fmt.Errorf("fragment reassembly failed: %v", err) } - return res, true, consumed + return res, true, consumed, nil } func (r *reassembler) tooOld(timeout time.Duration) bool { diff --git a/pkg/tcpip/network/ip_test.go b/pkg/tcpip/network/ip_test.go index 4b3bd74fa..4144a7837 100644 --- a/pkg/tcpip/network/ip_test.go +++ b/pkg/tcpip/network/ip_test.go @@ -96,16 +96,16 @@ func (t *testObject) checkValues(protocol tcpip.TransportProtocolNumber, vv buff // DeliverTransportPacket is called by network endpoints after parsing incoming // packets. This is used by the test object to verify that the results of the // parsing are expected. -func (t *testObject) DeliverTransportPacket(r *stack.Route, protocol tcpip.TransportProtocolNumber, netHeader buffer.View, vv buffer.VectorisedView) { - t.checkValues(protocol, vv, r.RemoteAddress, r.LocalAddress) +func (t *testObject) DeliverTransportPacket(r *stack.Route, protocol tcpip.TransportProtocolNumber, pkt tcpip.PacketBuffer) { + t.checkValues(protocol, pkt.Data, r.RemoteAddress, r.LocalAddress) t.dataCalls++ } // DeliverTransportControlPacket is called by network endpoints after parsing // incoming control (ICMP) packets. This is used by the test object to verify // that the results of the parsing are expected. -func (t *testObject) DeliverTransportControlPacket(local, remote tcpip.Address, net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, typ stack.ControlType, extra uint32, vv buffer.VectorisedView) { - t.checkValues(trans, vv, remote, local) +func (t *testObject) DeliverTransportControlPacket(local, remote tcpip.Address, net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, typ stack.ControlType, extra uint32, pkt tcpip.PacketBuffer) { + t.checkValues(trans, pkt.Data, remote, local) if typ != t.typ { t.t.Errorf("typ = %v, want %v", typ, t.typ) } @@ -144,32 +144,47 @@ func (*testObject) LinkAddress() tcpip.LinkAddress { return "" } +// Wait implements stack.LinkEndpoint.Wait. +func (*testObject) Wait() {} + // WritePacket is called by network endpoints after producing a packet and // writing it to the link endpoint. This is used by the test object to verify // that the produced packet is as expected. -func (t *testObject) WritePacket(_ *stack.Route, _ *stack.GSO, hdr buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.NetworkProtocolNumber) *tcpip.Error { +func (t *testObject) WritePacket(_ *stack.Route, _ *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) *tcpip.Error { var prot tcpip.TransportProtocolNumber var srcAddr tcpip.Address var dstAddr tcpip.Address if t.v4 { - h := header.IPv4(hdr.View()) + h := header.IPv4(pkt.Header.View()) prot = tcpip.TransportProtocolNumber(h.Protocol()) srcAddr = h.SourceAddress() dstAddr = h.DestinationAddress() } else { - h := header.IPv6(hdr.View()) + h := header.IPv6(pkt.Header.View()) prot = tcpip.TransportProtocolNumber(h.NextHeader()) srcAddr = h.SourceAddress() dstAddr = h.DestinationAddress() } - t.checkValues(prot, payload, srcAddr, dstAddr) + t.checkValues(prot, pkt.Data, srcAddr, dstAddr) return nil } +// WritePackets implements stack.LinkEndpoint.WritePackets. +func (t *testObject) WritePackets(_ *stack.Route, _ *stack.GSO, pkt []tcpip.PacketBuffer, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { + panic("not implemented") +} + +func (t *testObject) WriteRawPacket(_ buffer.VectorisedView) *tcpip.Error { + return tcpip.ErrNotSupported +} + func buildIPv4Route(local, remote tcpip.Address) (stack.Route, *tcpip.Error) { - s := stack.New([]string{ipv4.ProtocolName}, []string{udp.ProtocolName, tcp.ProtocolName}, stack.Options{}) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol()}, + TransportProtocols: []stack.TransportProtocol{udp.NewProtocol(), tcp.NewProtocol()}, + }) s.CreateNIC(1, loopback.New()) s.AddAddress(1, ipv4.ProtocolNumber, local) s.SetRouteTable([]tcpip.Route{{ @@ -182,7 +197,10 @@ func buildIPv4Route(local, remote tcpip.Address) (stack.Route, *tcpip.Error) { } func buildIPv6Route(local, remote tcpip.Address) (stack.Route, *tcpip.Error) { - s := stack.New([]string{ipv6.ProtocolName}, []string{udp.ProtocolName, tcp.ProtocolName}, stack.Options{}) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + TransportProtocols: []stack.TransportProtocol{udp.NewProtocol(), tcp.NewProtocol()}, + }) s.CreateNIC(1, loopback.New()) s.AddAddress(1, ipv6.ProtocolNumber, local) s.SetRouteTable([]tcpip.Route{{ @@ -221,7 +239,10 @@ func TestIPv4Send(t *testing.T) { if err != nil { t.Fatalf("could not find route: %v", err) } - if err := ep.WritePacket(&r, nil /* gso */, hdr, payload.ToVectorisedView(), 123, 123, stack.PacketOut); err != nil { + if err := ep.WritePacket(&r, nil /* gso */, stack.NetworkHeaderParams{Protocol: 123, TTL: 123, TOS: stack.DefaultTOS}, stack.PacketOut, tcpip.PacketBuffer{ + Header: hdr, + Data: payload.ToVectorisedView(), + }); err != nil { t.Fatalf("WritePacket failed: %v", err) } } @@ -261,7 +282,9 @@ func TestIPv4Receive(t *testing.T) { if err != nil { t.Fatalf("could not find route: %v", err) } - ep.HandlePacket(&r, view.ToVectorisedView()) + ep.HandlePacket(&r, tcpip.PacketBuffer{ + Data: view.ToVectorisedView(), + }) if o.dataCalls != 1 { t.Fatalf("Bad number of data calls: got %x, want 1", o.dataCalls) } @@ -349,7 +372,9 @@ func TestIPv4ReceiveControl(t *testing.T) { o.extra = c.expectedExtra vv := view[:len(view)-c.trunc].ToVectorisedView() - ep.HandlePacket(&r, vv) + ep.HandlePacket(&r, tcpip.PacketBuffer{ + Data: vv, + }) if want := c.expectedCount; o.controlCalls != want { t.Fatalf("Bad number of control calls for %q case: got %v, want %v", c.name, o.controlCalls, want) } @@ -412,13 +437,17 @@ func TestIPv4FragmentationReceive(t *testing.T) { } // Send first segment. - ep.HandlePacket(&r, frag1.ToVectorisedView()) + ep.HandlePacket(&r, tcpip.PacketBuffer{ + Data: frag1.ToVectorisedView(), + }) if o.dataCalls != 0 { t.Fatalf("Bad number of data calls: got %x, want 0", o.dataCalls) } // Send second segment. - ep.HandlePacket(&r, frag2.ToVectorisedView()) + ep.HandlePacket(&r, tcpip.PacketBuffer{ + Data: frag2.ToVectorisedView(), + }) if o.dataCalls != 1 { t.Fatalf("Bad number of data calls: got %x, want 1", o.dataCalls) } @@ -451,7 +480,10 @@ func TestIPv6Send(t *testing.T) { if err != nil { t.Fatalf("could not find route: %v", err) } - if err := ep.WritePacket(&r, nil /* gso */, hdr, payload.ToVectorisedView(), 123, 123, stack.PacketOut); err != nil { + if err := ep.WritePacket(&r, nil /* gso */, stack.NetworkHeaderParams{Protocol: 123, TTL: 123, TOS: stack.DefaultTOS}, stack.PacketOut, tcpip.PacketBuffer{ + Header: hdr, + Data: payload.ToVectorisedView(), + }); err != nil { t.Fatalf("WritePacket failed: %v", err) } } @@ -491,7 +523,9 @@ func TestIPv6Receive(t *testing.T) { t.Fatalf("could not find route: %v", err) } - ep.HandlePacket(&r, view.ToVectorisedView()) + ep.HandlePacket(&r, tcpip.PacketBuffer{ + Data: view.ToVectorisedView(), + }) if o.dataCalls != 1 { t.Fatalf("Bad number of data calls: got %x, want 1", o.dataCalls) } @@ -501,6 +535,7 @@ func TestIPv6ReceiveControl(t *testing.T) { newUint16 := func(v uint16) *uint16 { return &v } const mtu = 0xffff + const outerSrcAddr = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xaa" cases := []struct { name string expectedCount int @@ -552,7 +587,7 @@ func TestIPv6ReceiveControl(t *testing.T) { PayloadLength: uint16(len(view) - header.IPv6MinimumSize - c.trunc), NextHeader: uint8(header.ICMPv6ProtocolNumber), HopLimit: 20, - SrcAddr: "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xaa", + SrcAddr: outerSrcAddr, DstAddr: localIpv6Addr, }) @@ -599,8 +634,12 @@ func TestIPv6ReceiveControl(t *testing.T) { o.typ = c.expectedTyp o.extra = c.expectedExtra - vv := view[:len(view)-c.trunc].ToVectorisedView() - ep.HandlePacket(&r, vv) + // Set ICMPv6 checksum. + icmp.SetChecksum(header.ICMPv6Checksum(icmp, outerSrcAddr, localIpv6Addr, buffer.VectorisedView{})) + + ep.HandlePacket(&r, tcpip.PacketBuffer{ + Data: view[:len(view)-c.trunc].ToVectorisedView(), + }) if want := c.expectedCount; o.controlCalls != want { t.Fatalf("Bad number of control calls for %q case: got %v, want %v", c.name, o.controlCalls, want) } diff --git a/pkg/tcpip/network/ipv4/BUILD b/pkg/tcpip/network/ipv4/BUILD index be84fa63d..aeddfcdd4 100644 --- a/pkg/tcpip/network/ipv4/BUILD +++ b/pkg/tcpip/network/ipv4/BUILD @@ -1,4 +1,5 @@ -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") +load("@io_bazel_rules_go//go:def.bzl", "go_test") package(licenses = ["notice"]) @@ -9,9 +10,7 @@ go_library( "ipv4.go", ], importpath = "gvisor.dev/gvisor/pkg/tcpip/network/ipv4", - visibility = [ - "//visibility:public", - ], + visibility = ["//visibility:public"], deps = [ "//pkg/tcpip", "//pkg/tcpip/buffer", diff --git a/pkg/tcpip/network/ipv4/icmp.go b/pkg/tcpip/network/ipv4/icmp.go index a25756443..32bf39e43 100644 --- a/pkg/tcpip/network/ipv4/icmp.go +++ b/pkg/tcpip/network/ipv4/icmp.go @@ -15,6 +15,7 @@ package ipv4 import ( + "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/stack" @@ -24,8 +25,8 @@ import ( // the original packet that caused the ICMP one to be sent. This information is // used to find out which transport endpoint must be notified about the ICMP // packet. -func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, vv buffer.VectorisedView) { - h := header.IPv4(vv.First()) +func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, pkt tcpip.PacketBuffer) { + h := header.IPv4(pkt.Data.First()) // We don't use IsValid() here because ICMP only requires that the IP // header plus 8 bytes of the transport header be included. So it's @@ -39,7 +40,7 @@ func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, vv buffer. } hlen := int(h.HeaderLength()) - if vv.Size() < hlen || h.FragmentOffset() != 0 { + if pkt.Data.Size() < hlen || h.FragmentOffset() != 0 { // We won't be able to handle this if it doesn't contain the // full IPv4 header, or if it's a fragment not at offset 0 // (because it won't have the transport header). @@ -47,15 +48,15 @@ func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, vv buffer. } // Skip the ip header, then deliver control message. - vv.TrimFront(hlen) + pkt.Data.TrimFront(hlen) p := h.TransportProtocol() - e.dispatcher.DeliverTransportControlPacket(e.id.LocalAddress, h.DestinationAddress(), ProtocolNumber, p, typ, extra, vv) + e.dispatcher.DeliverTransportControlPacket(e.id.LocalAddress, h.DestinationAddress(), ProtocolNumber, p, typ, extra, pkt) } -func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, vv buffer.VectorisedView) { +func (e *endpoint) handleICMP(r *stack.Route, pkt tcpip.PacketBuffer) { stats := r.Stats() received := stats.ICMP.V4PacketsReceived - v := vv.First() + v := pkt.Data.First() if len(v) < header.ICMPv4MinimumSize { received.Invalid.Increment() return @@ -73,20 +74,23 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, vv buffer.V // checksum. We'll have to reset this before we hand the packet // off. h.SetChecksum(0) - gotChecksum := ^header.ChecksumVV(vv, 0 /* initial */) + gotChecksum := ^header.ChecksumVV(pkt.Data, 0 /* initial */) if gotChecksum != wantChecksum { // It's possible that a raw socket expects to receive this. h.SetChecksum(wantChecksum) - e.dispatcher.DeliverTransportPacket(r, header.ICMPv4ProtocolNumber, netHeader, vv) + e.dispatcher.DeliverTransportPacket(r, header.ICMPv4ProtocolNumber, pkt) received.Invalid.Increment() return } // It's possible that a raw socket expects to receive this. h.SetChecksum(wantChecksum) - e.dispatcher.DeliverTransportPacket(r, header.ICMPv4ProtocolNumber, netHeader, vv) + e.dispatcher.DeliverTransportPacket(r, header.ICMPv4ProtocolNumber, tcpip.PacketBuffer{ + Data: pkt.Data.Clone(nil), + NetworkHeader: append(buffer.View(nil), pkt.NetworkHeader...), + }) - vv := vv.Clone(nil) + vv := pkt.Data.Clone(nil) vv.TrimFront(header.ICMPv4MinimumSize) hdr := buffer.NewPrependable(int(r.MaxHeaderLength()) + header.ICMPv4MinimumSize) pkt := header.ICMPv4(hdr.Prepend(header.ICMPv4MinimumSize)) @@ -95,7 +99,11 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, vv buffer.V pkt.SetChecksum(0) pkt.SetChecksum(^header.Checksum(pkt, header.ChecksumVV(vv, 0))) sent := stats.ICMP.V4PacketsSent - if err := r.WritePacket(nil /* gso */, hdr, vv, header.ICMPv4ProtocolNumber, r.DefaultTTL()); err != nil { + if err := r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv4ProtocolNumber, TTL: r.DefaultTTL(), TOS: stack.DefaultTOS}, tcpip.PacketBuffer{ + Header: hdr, + Data: vv, + TransportHeader: buffer.View(pkt), + }); err != nil { sent.Dropped.Increment() return } @@ -104,19 +112,19 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, vv buffer.V case header.ICMPv4EchoReply: received.EchoReply.Increment() - e.dispatcher.DeliverTransportPacket(r, header.ICMPv4ProtocolNumber, netHeader, vv) + e.dispatcher.DeliverTransportPacket(r, header.ICMPv4ProtocolNumber, pkt) case header.ICMPv4DstUnreachable: received.DstUnreachable.Increment() - vv.TrimFront(header.ICMPv4MinimumSize) + pkt.Data.TrimFront(header.ICMPv4MinimumSize) switch h.Code() { case header.ICMPv4PortUnreachable: - e.handleControl(stack.ControlPortUnreachable, 0, vv) + e.handleControl(stack.ControlPortUnreachable, 0, pkt) case header.ICMPv4FragmentationNeeded: mtu := uint32(h.MTU()) - e.handleControl(stack.ControlPacketTooBig, calculateMTU(mtu), vv) + e.handleControl(stack.ControlPacketTooBig, calculateMTU(mtu), pkt) } case header.ICMPv4SrcQuench: diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go index b7a06f525..e645cf62c 100644 --- a/pkg/tcpip/network/ipv4/ipv4.go +++ b/pkg/tcpip/network/ipv4/ipv4.go @@ -14,9 +14,9 @@ // Package ipv4 contains the implementation of the ipv4 network protocol. To use // it in the networking stack, this package must be added to the project, and -// activated on the stack by passing ipv4.ProtocolName (or "ipv4") as one of the -// network protocols when calling stack.New(). Then endpoints can be created -// by passing ipv4.ProtocolNumber as the network protocol number when calling +// activated on the stack by passing ipv4.NewProtocol() as one of the network +// protocols when calling stack.New(). Then endpoints can be created by passing +// ipv4.ProtocolNumber as the network protocol number when calling // Stack.NewEndpoint(). package ipv4 @@ -32,9 +32,6 @@ import ( ) const ( - // ProtocolName is the string representation of the ipv4 protocol name. - ProtocolName = "ipv4" - // ProtocolNumber is the ipv4 protocol number. ProtocolNumber = header.IPv4ProtocolNumber @@ -42,28 +39,33 @@ const ( // TotalLength field of the ipv4 header. MaxTotalSize = 0xffff + // DefaultTTL is the default time-to-live value for this endpoint. + DefaultTTL = 64 + // buckets is the number of identifier buckets. buckets = 2048 ) type endpoint struct { - nicid tcpip.NICID + nicID tcpip.NICID id stack.NetworkEndpointID prefixLen int linkEP stack.LinkEndpoint dispatcher stack.TransportDispatcher fragmentation *fragmentation.Fragmentation + protocol *protocol } // NewEndpoint creates a new ipv4 endpoint. -func (p *protocol) NewEndpoint(nicid tcpip.NICID, addrWithPrefix tcpip.AddressWithPrefix, linkAddrCache stack.LinkAddressCache, dispatcher stack.TransportDispatcher, linkEP stack.LinkEndpoint) (stack.NetworkEndpoint, *tcpip.Error) { +func (p *protocol) NewEndpoint(nicID tcpip.NICID, addrWithPrefix tcpip.AddressWithPrefix, linkAddrCache stack.LinkAddressCache, dispatcher stack.TransportDispatcher, linkEP stack.LinkEndpoint) (stack.NetworkEndpoint, *tcpip.Error) { e := &endpoint{ - nicid: nicid, + nicID: nicID, id: stack.NetworkEndpointID{LocalAddress: addrWithPrefix.Address}, prefixLen: addrWithPrefix.PrefixLen, linkEP: linkEP, dispatcher: dispatcher, fragmentation: fragmentation.NewFragmentation(fragmentation.HighFragThreshold, fragmentation.LowFragThreshold, fragmentation.DefaultReassembleTimeout), + protocol: p, } return e, nil @@ -71,7 +73,7 @@ func (p *protocol) NewEndpoint(nicid tcpip.NICID, addrWithPrefix tcpip.AddressWi // DefaultTTL is the default time-to-live value for this endpoint. func (e *endpoint) DefaultTTL() uint8 { - return 255 + return e.protocol.DefaultTTL() } // MTU implements stack.NetworkEndpoint.MTU. It returns the link-layer MTU minus @@ -87,7 +89,7 @@ func (e *endpoint) Capabilities() stack.LinkEndpointCapabilities { // NICID returns the ID of the NIC this endpoint belongs to. func (e *endpoint) NICID() tcpip.NICID { - return e.nicid + return e.nicID } // ID returns the ipv4 endpoint ID. @@ -115,13 +117,14 @@ func (e *endpoint) GSOMaxSize() uint32 { } // writePacketFragments calls e.linkEP.WritePacket with each packet fragment to -// write. It assumes that the IP header is entirely in hdr but does not assume -// that only the IP header is in hdr. It assumes that the input packet's stated -// length matches the length of the hdr+payload. mtu includes the IP header and -// options. This does not support the DontFragment IP flag. -func (e *endpoint) writePacketFragments(r *stack.Route, gso *stack.GSO, hdr buffer.Prependable, payload buffer.VectorisedView, mtu int) *tcpip.Error { +// write. It assumes that the IP header is entirely in pkt.Header but does not +// assume that only the IP header is in pkt.Header. It assumes that the input +// packet's stated length matches the length of the header+payload. mtu +// includes the IP header and options. This does not support the DontFragment +// IP flag. +func (e *endpoint) writePacketFragments(r *stack.Route, gso *stack.GSO, mtu int, pkt tcpip.PacketBuffer) *tcpip.Error { // This packet is too big, it needs to be fragmented. - ip := header.IPv4(hdr.View()) + ip := header.IPv4(pkt.Header.View()) flags := ip.Flags() // Update mtu to take into account the header, which will exist in all @@ -135,122 +138,167 @@ func (e *endpoint) writePacketFragments(r *stack.Route, gso *stack.GSO, hdr buff outerMTU := innerMTU + int(ip.HeaderLength()) offset := ip.FragmentOffset() - originalAvailableLength := hdr.AvailableLength() + originalAvailableLength := pkt.Header.AvailableLength() for i := 0; i < n; i++ { // Where possible, the first fragment that is sent has the same - // hdr.UsedLength() as the input packet. The link-layer endpoint may depends - // on this for looking at, eg, L4 headers. + // pkt.Header.UsedLength() as the input packet. The link-layer + // endpoint may depend on this for looking at, eg, L4 headers. h := ip if i > 0 { - hdr = buffer.NewPrependable(int(ip.HeaderLength()) + originalAvailableLength) - h = header.IPv4(hdr.Prepend(int(ip.HeaderLength()))) + pkt.Header = buffer.NewPrependable(int(ip.HeaderLength()) + originalAvailableLength) + h = header.IPv4(pkt.Header.Prepend(int(ip.HeaderLength()))) copy(h, ip[:ip.HeaderLength()]) } if i != n-1 { h.SetTotalLength(uint16(outerMTU)) h.SetFlagsFragmentOffset(flags|header.IPv4FlagMoreFragments, offset) } else { - h.SetTotalLength(uint16(h.HeaderLength()) + uint16(payload.Size())) + h.SetTotalLength(uint16(h.HeaderLength()) + uint16(pkt.Data.Size())) h.SetFlagsFragmentOffset(flags, offset) } h.SetChecksum(0) h.SetChecksum(^h.CalculateChecksum()) offset += uint16(innerMTU) if i > 0 { - newPayload := payload.Clone([]buffer.View{}) + newPayload := pkt.Data.Clone(nil) newPayload.CapLength(innerMTU) - if err := e.linkEP.WritePacket(r, gso, hdr, newPayload, ProtocolNumber); err != nil { + if err := e.linkEP.WritePacket(r, gso, ProtocolNumber, tcpip.PacketBuffer{ + Header: pkt.Header, + Data: newPayload, + NetworkHeader: buffer.View(h), + }); err != nil { return err } r.Stats().IP.PacketsSent.Increment() - payload.TrimFront(newPayload.Size()) + pkt.Data.TrimFront(newPayload.Size()) continue } - // Special handling for the first fragment because it comes from the hdr. - if outerMTU >= hdr.UsedLength() { - // This fragment can fit all of hdr and possibly some of payload, too. - newPayload := payload.Clone([]buffer.View{}) - newPayloadLength := outerMTU - hdr.UsedLength() + // Special handling for the first fragment because it comes + // from the header. + if outerMTU >= pkt.Header.UsedLength() { + // This fragment can fit all of pkt.Header and possibly + // some of pkt.Data, too. + newPayload := pkt.Data.Clone(nil) + newPayloadLength := outerMTU - pkt.Header.UsedLength() newPayload.CapLength(newPayloadLength) - if err := e.linkEP.WritePacket(r, gso, hdr, newPayload, ProtocolNumber); err != nil { + if err := e.linkEP.WritePacket(r, gso, ProtocolNumber, tcpip.PacketBuffer{ + Header: pkt.Header, + Data: newPayload, + NetworkHeader: buffer.View(h), + }); err != nil { return err } r.Stats().IP.PacketsSent.Increment() - payload.TrimFront(newPayloadLength) + pkt.Data.TrimFront(newPayloadLength) } else { - // The fragment is too small to fit all of hdr. - startOfHdr := hdr - startOfHdr.TrimBack(hdr.UsedLength() - outerMTU) + // The fragment is too small to fit all of pkt.Header. + startOfHdr := pkt.Header + startOfHdr.TrimBack(pkt.Header.UsedLength() - outerMTU) emptyVV := buffer.NewVectorisedView(0, []buffer.View{}) - if err := e.linkEP.WritePacket(r, gso, startOfHdr, emptyVV, ProtocolNumber); err != nil { + if err := e.linkEP.WritePacket(r, gso, ProtocolNumber, tcpip.PacketBuffer{ + Header: startOfHdr, + Data: emptyVV, + NetworkHeader: buffer.View(h), + }); err != nil { return err } r.Stats().IP.PacketsSent.Increment() - // Add the unused bytes of hdr into the payload that remains to be sent. - restOfHdr := hdr.View()[outerMTU:] + // Add the unused bytes of pkt.Header into the pkt.Data + // that remains to be sent. + restOfHdr := pkt.Header.View()[outerMTU:] tmp := buffer.NewVectorisedView(len(restOfHdr), []buffer.View{buffer.NewViewFromBytes(restOfHdr)}) - tmp.Append(payload) - payload = tmp + tmp.Append(pkt.Data) + pkt.Data = tmp } } return nil } -// WritePacket writes a packet to the given destination address and protocol. -func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, hdr buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.TransportProtocolNumber, ttl uint8, loop stack.PacketLooping) *tcpip.Error { +func (e *endpoint) addIPHeader(r *stack.Route, hdr *buffer.Prependable, payloadSize int, params stack.NetworkHeaderParams) header.IPv4 { ip := header.IPv4(hdr.Prepend(header.IPv4MinimumSize)) - length := uint16(hdr.UsedLength() + payload.Size()) + length := uint16(hdr.UsedLength() + payloadSize) id := uint32(0) if length > header.IPv4MaximumHeaderSize+8 { // Packets of 68 bytes or less are required by RFC 791 to not be // fragmented, so we only assign ids to larger packets. - id = atomic.AddUint32(&ids[hashRoute(r, protocol)%buckets], 1) + id = atomic.AddUint32(&e.protocol.ids[hashRoute(r, params.Protocol, e.protocol.hashIV)%buckets], 1) } ip.Encode(&header.IPv4Fields{ IHL: header.IPv4MinimumSize, TotalLength: length, ID: uint16(id), - TTL: ttl, - Protocol: uint8(protocol), + TTL: params.TTL, + TOS: params.TOS, + Protocol: uint8(params.Protocol), SrcAddr: r.LocalAddress, DstAddr: r.RemoteAddress, }) ip.SetChecksum(^ip.CalculateChecksum()) + return ip +} + +// WritePacket writes a packet to the given destination address and protocol. +func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.NetworkHeaderParams, loop stack.PacketLooping, pkt tcpip.PacketBuffer) *tcpip.Error { + ip := e.addIPHeader(r, &pkt.Header, pkt.Data.Size(), params) + pkt.NetworkHeader = buffer.View(ip) if loop&stack.PacketLoop != 0 { - views := make([]buffer.View, 1, 1+len(payload.Views())) - views[0] = hdr.View() - views = append(views, payload.Views()...) - vv := buffer.NewVectorisedView(len(views[0])+payload.Size(), views) + // The inbound path expects the network header to still be in + // the PacketBuffer's Data field. + views := make([]buffer.View, 1, 1+len(pkt.Data.Views())) + views[0] = pkt.Header.View() + views = append(views, pkt.Data.Views()...) loopedR := r.MakeLoopedRoute() - e.HandlePacket(&loopedR, vv) + + e.HandlePacket(&loopedR, tcpip.PacketBuffer{ + Data: buffer.NewVectorisedView(len(views[0])+pkt.Data.Size(), views), + }) + loopedR.Release() } if loop&stack.PacketOut == 0 { return nil } - if hdr.UsedLength()+payload.Size() > int(e.linkEP.MTU()) && (gso == nil || gso.Type == stack.GSONone) { - return e.writePacketFragments(r, gso, hdr, payload, int(e.linkEP.MTU())) + if pkt.Header.UsedLength()+pkt.Data.Size() > int(e.linkEP.MTU()) && (gso == nil || gso.Type == stack.GSONone) { + return e.writePacketFragments(r, gso, int(e.linkEP.MTU()), pkt) } - if err := e.linkEP.WritePacket(r, gso, hdr, payload, ProtocolNumber); err != nil { + if err := e.linkEP.WritePacket(r, gso, ProtocolNumber, pkt); err != nil { return err } r.Stats().IP.PacketsSent.Increment() return nil } +// WritePackets implements stack.NetworkEndpoint.WritePackets. +func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts []tcpip.PacketBuffer, params stack.NetworkHeaderParams, loop stack.PacketLooping) (int, *tcpip.Error) { + if loop&stack.PacketLoop != 0 { + panic("multiple packets in local loop") + } + if loop&stack.PacketOut == 0 { + return len(pkts), nil + } + + for i := range pkts { + ip := e.addIPHeader(r, &pkts[i].Header, pkts[i].DataSize, params) + pkts[i].NetworkHeader = buffer.View(ip) + } + n, err := e.linkEP.WritePackets(r, gso, pkts, ProtocolNumber) + r.Stats().IP.PacketsSent.IncrementBy(uint64(n)) + return n, err +} + // WriteHeaderIncludedPacket writes a packet already containing a network // header through the given route. -func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, payload buffer.VectorisedView, loop stack.PacketLooping) *tcpip.Error { +func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, loop stack.PacketLooping, pkt tcpip.PacketBuffer) *tcpip.Error { // The packet already has an IP header, but there are a few required // checks. - ip := header.IPv4(payload.First()) - if !ip.IsValid(payload.Size()) { + ip := header.IPv4(pkt.Data.First()) + if !ip.IsValid(pkt.Data.Size()) { return tcpip.ErrInvalidOptionValue } // Always set the total length. - ip.SetTotalLength(uint16(payload.Size())) + ip.SetTotalLength(uint16(pkt.Data.Size())) // Set the source address when zero. if ip.SourceAddress() == tcpip.Address(([]byte{0, 0, 0, 0})) { @@ -264,10 +312,10 @@ func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, payload buffer.Vect // Set the packet ID when zero. if ip.ID() == 0 { id := uint32(0) - if payload.Size() > header.IPv4MaximumHeaderSize+8 { + if pkt.Data.Size() > header.IPv4MaximumHeaderSize+8 { // Packets of 68 bytes or less are required by RFC 791 to not be // fragmented, so we only assign ids to larger packets. - id = atomic.AddUint32(&ids[hashRoute(r, 0 /* protocol */)%buckets], 1) + id = atomic.AddUint32(&e.protocol.ids[hashRoute(r, 0 /* protocol */, e.protocol.hashIV)%buckets], 1) } ip.SetID(uint16(id)) } @@ -277,37 +325,63 @@ func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, payload buffer.Vect ip.SetChecksum(^ip.CalculateChecksum()) if loop&stack.PacketLoop != 0 { - e.HandlePacket(r, payload) + e.HandlePacket(r, pkt.Clone()) } if loop&stack.PacketOut == 0 { return nil } - hdr := buffer.NewPrependableFromView(payload.ToView()) r.Stats().IP.PacketsSent.Increment() - return e.linkEP.WritePacket(r, nil /* gso */, hdr, buffer.VectorisedView{}, ProtocolNumber) + + ip = ip[:ip.HeaderLength()] + pkt.Header = buffer.NewPrependableFromView(buffer.View(ip)) + pkt.Data.TrimFront(int(ip.HeaderLength())) + return e.linkEP.WritePacket(r, nil /* gso */, ProtocolNumber, pkt) } // HandlePacket is called by the link layer when new ipv4 packets arrive for // this endpoint. -func (e *endpoint) HandlePacket(r *stack.Route, vv buffer.VectorisedView) { - headerView := vv.First() +func (e *endpoint) HandlePacket(r *stack.Route, pkt tcpip.PacketBuffer) { + headerView := pkt.Data.First() h := header.IPv4(headerView) - if !h.IsValid(vv.Size()) { + if !h.IsValid(pkt.Data.Size()) { + r.Stats().IP.MalformedPacketsReceived.Increment() return } + pkt.NetworkHeader = headerView[:h.HeaderLength()] hlen := int(h.HeaderLength()) tlen := int(h.TotalLength()) - vv.TrimFront(hlen) - vv.CapLength(tlen - hlen) + pkt.Data.TrimFront(hlen) + pkt.Data.CapLength(tlen - hlen) more := (h.Flags() & header.IPv4FlagMoreFragments) != 0 if more || h.FragmentOffset() != 0 { + if pkt.Data.Size() == 0 { + // Drop the packet as it's marked as a fragment but has + // no payload. + r.Stats().IP.MalformedPacketsReceived.Increment() + r.Stats().IP.MalformedFragmentsReceived.Increment() + return + } // The packet is a fragment, let's try to reassemble it. - last := h.FragmentOffset() + uint16(vv.Size()) - 1 + last := h.FragmentOffset() + uint16(pkt.Data.Size()) - 1 + // Drop the packet if the fragmentOffset is incorrect. i.e the + // combination of fragmentOffset and pkt.Data.size() causes a + // wrap around resulting in last being less than the offset. + if last < h.FragmentOffset() { + r.Stats().IP.MalformedPacketsReceived.Increment() + r.Stats().IP.MalformedFragmentsReceived.Increment() + return + } var ready bool - vv, ready = e.fragmentation.Process(hash.IPv4FragmentHash(h), h.FragmentOffset(), last, more, vv) + var err error + pkt.Data, ready, err = e.fragmentation.Process(hash.IPv4FragmentHash(h), h.FragmentOffset(), last, more, pkt.Data) + if err != nil { + r.Stats().IP.MalformedPacketsReceived.Increment() + r.Stats().IP.MalformedFragmentsReceived.Increment() + return + } if !ready { return } @@ -315,24 +389,24 @@ func (e *endpoint) HandlePacket(r *stack.Route, vv buffer.VectorisedView) { p := h.TransportProtocol() if p == header.ICMPv4ProtocolNumber { headerView.CapLength(hlen) - e.handleICMP(r, headerView, vv) + e.handleICMP(r, pkt) return } r.Stats().IP.PacketsDelivered.Increment() - e.dispatcher.DeliverTransportPacket(r, p, headerView, vv) + e.dispatcher.DeliverTransportPacket(r, p, pkt) } // Close cleans up resources associated with the endpoint. func (e *endpoint) Close() {} -type protocol struct{} +type protocol struct { + ids []uint32 + hashIV uint32 -// NewProtocol creates a new protocol ipv4 protocol descriptor. This is exported -// only for tests that short-circuit the stack. Regular use of the protocol is -// done via the stack, which gets a protocol descriptor from the init() function -// below. -func NewProtocol() stack.NetworkProtocol { - return &protocol{} + // defaultTTL is the current default TTL for the protocol. Only the + // uint8 portion of it is meaningful and it must be accessed + // atomically. + defaultTTL uint32 } // Number returns the ipv4 protocol number. @@ -358,12 +432,34 @@ func (*protocol) ParseAddresses(v buffer.View) (src, dst tcpip.Address) { // SetOption implements NetworkProtocol.SetOption. func (p *protocol) SetOption(option interface{}) *tcpip.Error { - return tcpip.ErrUnknownProtocolOption + switch v := option.(type) { + case tcpip.DefaultTTLOption: + p.SetDefaultTTL(uint8(v)) + return nil + default: + return tcpip.ErrUnknownProtocolOption + } } // Option implements NetworkProtocol.Option. func (p *protocol) Option(option interface{}) *tcpip.Error { - return tcpip.ErrUnknownProtocolOption + switch v := option.(type) { + case *tcpip.DefaultTTLOption: + *v = tcpip.DefaultTTLOption(p.DefaultTTL()) + return nil + default: + return tcpip.ErrUnknownProtocolOption + } +} + +// SetDefaultTTL sets the default TTL for endpoints created with this protocol. +func (p *protocol) SetDefaultTTL(ttl uint8) { + atomic.StoreUint32(&p.defaultTTL, uint32(ttl)) +} + +// DefaultTTL returns the default TTL for endpoints created with this protocol. +func (p *protocol) DefaultTTL() uint8 { + return uint8(atomic.LoadUint32(&p.defaultTTL)) } // calculateMTU calculates the network-layer payload MTU based on the link-layer @@ -378,7 +474,7 @@ func calculateMTU(mtu uint32) uint32 { // hashRoute calculates a hash value for the given route. It uses the source & // destination address, the transport protocol number, and a random initial // value (generated once on initialization) to generate the hash. -func hashRoute(r *stack.Route, protocol tcpip.TransportProtocolNumber) uint32 { +func hashRoute(r *stack.Route, protocol tcpip.TransportProtocolNumber, hashIV uint32) uint32 { t := r.LocalAddress a := uint32(t[0]) | uint32(t[1])<<8 | uint32(t[2])<<16 | uint32(t[3])<<24 t = r.RemoteAddress @@ -386,22 +482,16 @@ func hashRoute(r *stack.Route, protocol tcpip.TransportProtocolNumber) uint32 { return hash.Hash3Words(a, b, uint32(protocol), hashIV) } -var ( - ids []uint32 - hashIV uint32 -) - -func init() { - ids = make([]uint32, buckets) +// NewProtocol returns an IPv4 network protocol. +func NewProtocol() stack.NetworkProtocol { + ids := make([]uint32, buckets) // Randomly initialize hashIV and the ids. r := hash.RandN32(1 + buckets) for i := range ids { ids[i] = r[i] } - hashIV = r[buckets] + hashIV := r[buckets] - stack.RegisterNetworkProtocolFactory(ProtocolName, func() stack.NetworkProtocol { - return &protocol{} - }) + return &protocol{ids: ids, hashIV: hashIV, defaultTTL: DefaultTTL} } diff --git a/pkg/tcpip/network/ipv4/ipv4_test.go b/pkg/tcpip/network/ipv4/ipv4_test.go index 1b5a55bea..e900f1b45 100644 --- a/pkg/tcpip/network/ipv4/ipv4_test.go +++ b/pkg/tcpip/network/ipv4/ipv4_test.go @@ -33,24 +33,20 @@ import ( ) func TestExcludeBroadcast(t *testing.T) { - s := stack.New([]string{ipv4.ProtocolName}, []string{udp.ProtocolName}, stack.Options{}) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol()}, + TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()}, + }) const defaultMTU = 65536 - id, _ := channel.New(256, defaultMTU, "") + ep := stack.LinkEndpoint(channel.New(256, defaultMTU, "")) if testing.Verbose() { - id = sniffer.New(id) + ep = sniffer.New(ep) } - if err := s.CreateNIC(1, id); err != nil { + if err := s.CreateNIC(1, ep); err != nil { t.Fatalf("CreateNIC failed: %v", err) } - if err := s.AddAddress(1, ipv4.ProtocolNumber, header.IPv4Broadcast); err != nil { - t.Fatalf("AddAddress failed: %v", err) - } - if err := s.AddAddress(1, ipv4.ProtocolNumber, header.IPv4Any); err != nil { - t.Fatalf("AddAddress failed: %v", err) - } - s.SetRouteTable([]tcpip.Route{{ Destination: header.IPv4EmptySubnet, NIC: 1, @@ -117,12 +113,12 @@ func makeHdrAndPayload(hdrLength int, extraLength int, viewSizes []int) (buffer. // comparePayloads compared the contents of all the packets against the contents // of the source packet. -func compareFragments(t *testing.T, packets []packetInfo, sourcePacketInfo packetInfo, mtu uint32) { +func compareFragments(t *testing.T, packets []tcpip.PacketBuffer, sourcePacketInfo tcpip.PacketBuffer, mtu uint32) { t.Helper() // Make a complete array of the sourcePacketInfo packet. source := header.IPv4(packets[0].Header.View()[:header.IPv4MinimumSize]) source = append(source, sourcePacketInfo.Header.View()...) - source = append(source, sourcePacketInfo.Payload.ToView()...) + source = append(source, sourcePacketInfo.Data.ToView()...) // Make a copy of the IP header, which will be modified in some fields to make // an expected header. @@ -136,7 +132,7 @@ func compareFragments(t *testing.T, packets []packetInfo, sourcePacketInfo packe for i, packet := range packets { // Confirm that the packet is valid. allBytes := packet.Header.View().ToVectorisedView() - allBytes.Append(packet.Payload) + allBytes.Append(packet.Data) ip := header.IPv4(allBytes.ToView()) if !ip.IsValid(len(ip)) { t.Errorf("IP packet is invalid:\n%s", hex.Dump(ip)) @@ -177,28 +173,19 @@ func compareFragments(t *testing.T, packets []packetInfo, sourcePacketInfo packe type errorChannel struct { *channel.Endpoint - Ch chan packetInfo + Ch chan tcpip.PacketBuffer packetCollectorErrors []*tcpip.Error } // newErrorChannel creates a new errorChannel endpoint. Each call to WritePacket // will return successive errors from packetCollectorErrors until the list is // empty and then return nil each time. -func newErrorChannel(size int, mtu uint32, linkAddr tcpip.LinkAddress, packetCollectorErrors []*tcpip.Error) (tcpip.LinkEndpointID, *errorChannel) { - _, e := channel.New(size, mtu, linkAddr) - ec := errorChannel{ - Endpoint: e, - Ch: make(chan packetInfo, size), +func newErrorChannel(size int, mtu uint32, linkAddr tcpip.LinkAddress, packetCollectorErrors []*tcpip.Error) *errorChannel { + return &errorChannel{ + Endpoint: channel.New(size, mtu, linkAddr), + Ch: make(chan tcpip.PacketBuffer, size), packetCollectorErrors: packetCollectorErrors, } - - return stack.RegisterLinkEndpoint(e), &ec -} - -// packetInfo holds all the information about an outbound packet. -type packetInfo struct { - Header buffer.Prependable - Payload buffer.VectorisedView } // Drain removes all outbound packets from the channel and counts them. @@ -215,14 +202,9 @@ func (e *errorChannel) Drain() int { } // WritePacket stores outbound packets into the channel. -func (e *errorChannel) WritePacket(r *stack.Route, gso *stack.GSO, hdr buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.NetworkProtocolNumber) *tcpip.Error { - p := packetInfo{ - Header: hdr, - Payload: payload, - } - +func (e *errorChannel) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) *tcpip.Error { select { - case e.Ch <- p: + case e.Ch <- pkt: default: } @@ -241,10 +223,11 @@ type context struct { func buildContext(t *testing.T, packetCollectorErrors []*tcpip.Error, mtu uint32) context { // Make the packet and write it. - s := stack.New([]string{ipv4.ProtocolName}, []string{}, stack.Options{}) - _, linkEP := newErrorChannel(100 /* Enough for all tests. */, mtu, "", packetCollectorErrors) - linkEPId := stack.RegisterLinkEndpoint(linkEP) - s.CreateNIC(1, linkEPId) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol()}, + }) + ep := newErrorChannel(100 /* Enough for all tests. */, mtu, "", packetCollectorErrors) + s.CreateNIC(1, ep) const ( src = "\x10\x00\x00\x01" dst = "\x10\x00\x00\x02" @@ -266,7 +249,7 @@ func buildContext(t *testing.T, packetCollectorErrors []*tcpip.Error, mtu uint32 } return context{ Route: r, - linkEP: linkEP, + linkEP: ep, } } @@ -298,18 +281,21 @@ func TestFragmentation(t *testing.T) { for _, ft := range fragTests { t.Run(ft.description, func(t *testing.T) { hdr, payload := makeHdrAndPayload(ft.hdrLength, ft.extraLength, ft.payloadViewsSizes) - source := packetInfo{ + source := tcpip.PacketBuffer{ Header: hdr, // Save the source payload because WritePacket will modify it. - Payload: payload.Clone([]buffer.View{}), + Data: payload.Clone(nil), } c := buildContext(t, nil, ft.mtu) - err := c.Route.WritePacket(ft.gso, hdr, payload, tcp.ProtocolNumber, 42) + err := c.Route.WritePacket(ft.gso, stack.NetworkHeaderParams{Protocol: tcp.ProtocolNumber, TTL: 42, TOS: stack.DefaultTOS}, tcpip.PacketBuffer{ + Header: hdr, + Data: payload, + }) if err != nil { t.Errorf("err got %v, want %v", err, nil) } - var results []packetInfo + var results []tcpip.PacketBuffer L: for { select { @@ -351,7 +337,10 @@ func TestFragmentationErrors(t *testing.T) { t.Run(ft.description, func(t *testing.T) { hdr, payload := makeHdrAndPayload(ft.hdrLength, header.IPv4MinimumSize, ft.payloadViewsSizes) c := buildContext(t, ft.packetCollectorErrors, ft.mtu) - err := c.Route.WritePacket(&stack.GSO{}, hdr, payload, tcp.ProtocolNumber, 42) + err := c.Route.WritePacket(&stack.GSO{}, stack.NetworkHeaderParams{Protocol: tcp.ProtocolNumber, TTL: 42, TOS: stack.DefaultTOS}, tcpip.PacketBuffer{ + Header: hdr, + Data: payload, + }) for i := 0; i < len(ft.packetCollectorErrors)-1; i++ { if got, want := ft.packetCollectorErrors[i], (*tcpip.Error)(nil); got != want { t.Errorf("ft.packetCollectorErrors[%d] got %v, want %v", i, got, want) @@ -368,3 +357,119 @@ func TestFragmentationErrors(t *testing.T) { }) } } + +func TestInvalidFragments(t *testing.T) { + // These packets have both IHL and TotalLength set to 0. + testCases := []struct { + name string + packets [][]byte + wantMalformedIPPackets uint64 + wantMalformedFragments uint64 + }{ + { + "ihl_totallen_zero_valid_frag_offset", + [][]byte{ + {0x40, 0x30, 0x00, 0x00, 0x6c, 0x74, 0x7d, 0x30, 0x30, 0x30, 0x30, 0x30, 0x39, 0x32, 0x39, 0x33, 0xff, 0xff, 0xff, 0xff, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30}, + }, + 1, + 0, + }, + { + "ihl_totallen_zero_invalid_frag_offset", + [][]byte{ + {0x40, 0x30, 0x00, 0x00, 0x6c, 0x74, 0x20, 0x00, 0x30, 0x30, 0x30, 0x30, 0x39, 0x32, 0x39, 0x33, 0xff, 0xff, 0xff, 0xff, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30}, + }, + 1, + 0, + }, + { + // Total Length of 37(20 bytes IP header + 17 bytes of + // payload) + // Frag Offset of 0x1ffe = 8190*8 = 65520 + // Leading to the fragment end to be past 65535. + "ihl_totallen_valid_invalid_frag_offset_1", + [][]byte{ + {0x45, 0x30, 0x00, 0x25, 0x6c, 0x74, 0x1f, 0xfe, 0x30, 0x30, 0x30, 0x30, 0x39, 0x32, 0x39, 0x33, 0xff, 0xff, 0xff, 0xff, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30}, + }, + 1, + 1, + }, + // The following 3 tests were found by running a fuzzer and were + // triggering a panic in the IPv4 reassembler code. + { + "ihl_less_than_ipv4_minimum_size_1", + [][]byte{ + {0x42, 0x30, 0x0, 0x30, 0x30, 0x40, 0x0, 0xf3, 0x30, 0x1, 0x30, 0x30, 0x73, 0x73, 0x69, 0x6e, 0xff, 0xff, 0xff, 0xff, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30}, + {0x42, 0x30, 0x0, 0x8, 0x30, 0x40, 0x20, 0x0, 0x30, 0x1, 0x30, 0x30, 0x73, 0x73, 0x69, 0x6e, 0xff, 0xff, 0xff, 0xff, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30}, + }, + 2, + 0, + }, + { + "ihl_less_than_ipv4_minimum_size_2", + [][]byte{ + {0x42, 0x30, 0x0, 0x30, 0x30, 0x40, 0xb3, 0x12, 0x30, 0x6, 0x30, 0x30, 0x73, 0x73, 0x69, 0x6e, 0xff, 0xff, 0xff, 0xff, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30}, + {0x42, 0x30, 0x0, 0x8, 0x30, 0x40, 0x20, 0x0, 0x30, 0x6, 0x30, 0x30, 0x73, 0x73, 0x69, 0x6e, 0xff, 0xff, 0xff, 0xff, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30}, + }, + 2, + 0, + }, + { + "ihl_less_than_ipv4_minimum_size_3", + [][]byte{ + {0x42, 0x30, 0x0, 0x30, 0x30, 0x40, 0xb3, 0x30, 0x30, 0x6, 0x30, 0x30, 0x73, 0x73, 0x69, 0x6e, 0xff, 0xff, 0xff, 0xff, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30}, + {0x42, 0x30, 0x0, 0x8, 0x30, 0x40, 0x20, 0x0, 0x30, 0x6, 0x30, 0x30, 0x73, 0x73, 0x69, 0x6e, 0xff, 0xff, 0xff, 0xff, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30}, + }, + 2, + 0, + }, + { + "fragment_with_short_total_len_extra_payload", + [][]byte{ + {0x46, 0x30, 0x00, 0x30, 0x30, 0x40, 0x0e, 0x12, 0x30, 0x06, 0x30, 0x30, 0x73, 0x73, 0x69, 0x6e, 0xff, 0xff, 0xff, 0xff, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30}, + {0x46, 0x30, 0x00, 0x18, 0x30, 0x40, 0x20, 0x00, 0x30, 0x06, 0x30, 0x30, 0x73, 0x73, 0x69, 0x6e, 0xff, 0xff, 0xff, 0xff, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30}, + }, + 1, + 1, + }, + { + "multiple_fragments_with_more_fragments_set_to_false", + [][]byte{ + {0x45, 0x00, 0x00, 0x1c, 0x30, 0x40, 0x00, 0x10, 0x00, 0x06, 0x34, 0x69, 0x73, 0x73, 0x69, 0x6e, 0xff, 0xff, 0xff, 0xff, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}, + {0x45, 0x00, 0x00, 0x1c, 0x30, 0x40, 0x00, 0x01, 0x61, 0x06, 0x34, 0x69, 0x73, 0x73, 0x69, 0x6e, 0xff, 0xff, 0xff, 0xff, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}, + {0x45, 0x00, 0x00, 0x1c, 0x30, 0x40, 0x20, 0x00, 0x00, 0x06, 0x34, 0x1e, 0x73, 0x73, 0x69, 0x6e, 0xff, 0xff, 0xff, 0xff, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}, + }, + 1, + 1, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + const nicID tcpip.NICID = 42 + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ + ipv4.NewProtocol(), + }, + }) + + var linkAddr = tcpip.LinkAddress([]byte{0x30, 0x30, 0x30, 0x30, 0x30, 0x30}) + var remoteLinkAddr = tcpip.LinkAddress([]byte{0x30, 0x30, 0x30, 0x30, 0x30, 0x31}) + ep := channel.New(10, 1500, linkAddr) + s.CreateNIC(nicID, sniffer.New(ep)) + + for _, pkt := range tc.packets { + ep.InjectLinkAddr(header.IPv4ProtocolNumber, remoteLinkAddr, tcpip.PacketBuffer{ + Data: buffer.NewVectorisedView(len(pkt), []buffer.View{pkt}), + }) + } + + if got, want := s.Stats().IP.MalformedPacketsReceived.Value(), tc.wantMalformedIPPackets; got != want { + t.Errorf("incorrect Stats.IP.MalformedPacketsReceived, got: %d, want: %d", got, want) + } + if got, want := s.Stats().IP.MalformedFragmentsReceived.Value(), tc.wantMalformedFragments; got != want { + t.Errorf("incorrect Stats.IP.MalformedFragmentsReceived, got: %d, want: %d", got, want) + } + }) + } +} diff --git a/pkg/tcpip/network/ipv6/BUILD b/pkg/tcpip/network/ipv6/BUILD index c71b69123..e4e273460 100644 --- a/pkg/tcpip/network/ipv6/BUILD +++ b/pkg/tcpip/network/ipv6/BUILD @@ -1,4 +1,5 @@ -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") +load("@io_bazel_rules_go//go:def.bzl", "go_test") package(licenses = ["notice"]) @@ -9,9 +10,7 @@ go_library( "ipv6.go", ], importpath = "gvisor.dev/gvisor/pkg/tcpip/network/ipv6", - visibility = [ - "//visibility:public", - ], + visibility = ["//visibility:public"], deps = [ "//pkg/tcpip", "//pkg/tcpip/buffer", @@ -25,6 +24,7 @@ go_test( size = "small", srcs = [ "icmp_test.go", + "ipv6_test.go", "ndp_test.go", ], embed = [":ipv6"], @@ -36,6 +36,7 @@ go_test( "//pkg/tcpip/link/sniffer", "//pkg/tcpip/stack", "//pkg/tcpip/transport/icmp", + "//pkg/tcpip/transport/udp", "//pkg/waiter", ], ) diff --git a/pkg/tcpip/network/ipv6/icmp.go b/pkg/tcpip/network/ipv6/icmp.go index b4d0295bf..1c3410618 100644 --- a/pkg/tcpip/network/ipv6/icmp.go +++ b/pkg/tcpip/network/ipv6/icmp.go @@ -21,21 +21,12 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/stack" ) -const ( - // ndpHopLimit is the expected IP hop limit value of 255 for received - // NDP packets, as per RFC 4861 sections 4.1 - 4.5, 6.1.1, 6.1.2, 7.1.1, - // 7.1.2 and 8.1. If the hop limit value is not 255, nodes MUST silently - // drop the NDP packet. All outgoing NDP packets must use this value for - // its IP hop limit field. - ndpHopLimit = 255 -) - // handleControl handles the case when an ICMP packet contains the headers of // the original packet that caused the ICMP one to be sent. This information is // used to find out which transport endpoint must be notified about the ICMP // packet. -func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, vv buffer.VectorisedView) { - h := header.IPv6(vv.First()) +func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, pkt tcpip.PacketBuffer) { + h := header.IPv6(pkt.Data.First()) // We don't use IsValid() here because ICMP only requires that up to // 1280 bytes of the original packet be included. So it's likely that it @@ -49,10 +40,10 @@ func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, vv buffer. // Skip the IP header, then handle the fragmentation header if there // is one. - vv.TrimFront(header.IPv6MinimumSize) + pkt.Data.TrimFront(header.IPv6MinimumSize) p := h.TransportProtocol() if p == header.IPv6FragmentHeader { - f := header.IPv6Fragment(vv.First()) + f := header.IPv6Fragment(pkt.Data.First()) if !f.IsValid() || f.FragmentOffset() != 0 { // We can't handle fragments that aren't at offset 0 // because they don't have the transport headers. @@ -61,35 +52,54 @@ func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, vv buffer. // Skip fragmentation header and find out the actual protocol // number. - vv.TrimFront(header.IPv6FragmentHeaderSize) + pkt.Data.TrimFront(header.IPv6FragmentHeaderSize) p = f.TransportProtocol() } // Deliver the control packet to the transport endpoint. - e.dispatcher.DeliverTransportControlPacket(e.id.LocalAddress, h.DestinationAddress(), ProtocolNumber, p, typ, extra, vv) + e.dispatcher.DeliverTransportControlPacket(e.id.LocalAddress, h.DestinationAddress(), ProtocolNumber, p, typ, extra, pkt) } -func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, vv buffer.VectorisedView) { +func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, pkt tcpip.PacketBuffer) { stats := r.Stats().ICMP sent := stats.V6PacketsSent received := stats.V6PacketsReceived - v := vv.First() + v := pkt.Data.First() if len(v) < header.ICMPv6MinimumSize { received.Invalid.Increment() return } h := header.ICMPv6(v) + iph := header.IPv6(netHeader) + + // Validate ICMPv6 checksum before processing the packet. + // + // Only the first view in vv is accounted for by h. To account for the + // rest of vv, a shallow copy is made and the first view is removed. + // This copy is used as extra payload during the checksum calculation. + payload := pkt.Data + payload.RemoveFirst() + if got, want := h.Checksum(), header.ICMPv6Checksum(h, iph.SourceAddress(), iph.DestinationAddress(), payload); got != want { + received.Invalid.Increment() + return + } // As per RFC 4861 sections 4.1 - 4.5, 6.1.1, 6.1.2, 7.1.1, 7.1.2 and // 8.1, nodes MUST silently drop NDP packets where the Hop Limit field - // in the IPv6 header is not set to 255. + // in the IPv6 header is not set to 255, or the ICMPv6 Code field is not + // set to 0. switch h.Type() { case header.ICMPv6NeighborSolicit, header.ICMPv6NeighborAdvert, header.ICMPv6RouterSolicit, header.ICMPv6RouterAdvert, header.ICMPv6RedirectMsg: - if header.IPv6(netHeader).HopLimit() != ndpHopLimit { + if iph.HopLimit() != header.NDPHopLimit { + received.Invalid.Increment() + return + } + + if h.Code() != 0 { received.Invalid.Increment() return } @@ -103,9 +113,9 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, vv buffer.V received.Invalid.Increment() return } - vv.TrimFront(header.ICMPv6PacketTooBigMinimumSize) + pkt.Data.TrimFront(header.ICMPv6PacketTooBigMinimumSize) mtu := h.MTU() - e.handleControl(stack.ControlPacketTooBig, calculateMTU(mtu), vv) + e.handleControl(stack.ControlPacketTooBig, calculateMTU(mtu), pkt) case header.ICMPv6DstUnreachable: received.DstUnreachable.Increment() @@ -113,33 +123,80 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, vv buffer.V received.Invalid.Increment() return } - vv.TrimFront(header.ICMPv6DstUnreachableMinimumSize) + pkt.Data.TrimFront(header.ICMPv6DstUnreachableMinimumSize) switch h.Code() { case header.ICMPv6PortUnreachable: - e.handleControl(stack.ControlPortUnreachable, 0, vv) + e.handleControl(stack.ControlPortUnreachable, 0, pkt) } case header.ICMPv6NeighborSolicit: received.NeighborSolicit.Increment() - if len(v) < header.ICMPv6NeighborSolicitMinimumSize { received.Invalid.Increment() return } - targetAddr := tcpip.Address(v[8:][:header.IPv6AddressSize]) - if e.linkAddrCache.CheckLocalAddress(e.nicid, ProtocolNumber, targetAddr) == 0 { + + ns := header.NDPNeighborSolicit(h.NDPPayload()) + targetAddr := ns.TargetAddress() + s := r.Stack() + rxNICID := r.NICID() + + isTentative, err := s.IsAddrTentative(rxNICID, targetAddr) + if err != nil { + // We will only get an error if rxNICID is unrecognized, + // which should not happen. For now short-circuit this + // packet. + // + // TODO(b/141002840): Handle this better? + return + } + + if isTentative { + // If the target address is tentative and the source + // of the packet is a unicast (specified) address, then + // the source of the packet is attempting to perform + // address resolution on the target. In this case, the + // solicitation is silently ignored, as per RFC 4862 + // section 5.4.3. + // + // If the target address is tentative and the source of + // the packet is the unspecified address (::), then we + // know another node is also performing DAD for the + // same address (since targetAddr is tentative for us, + // we know we are also performing DAD on it). In this + // case we let the stack know so it can handle such a + // scenario and do nothing further with the NDP NS. + if iph.SourceAddress() == header.IPv6Any { + s.DupTentativeAddrDetected(rxNICID, targetAddr) + } + + // Do not handle neighbor solicitations targeted + // to an address that is tentative on the received + // NIC any further. + return + } + + // At this point we know that targetAddr is not tentative on + // rxNICID so the packet is processed as defined in RFC 4861, + // as per RFC 4862 section 5.4.3. + + if e.linkAddrCache.CheckLocalAddress(e.nicID, ProtocolNumber, targetAddr) == 0 { // We don't have a useful answer; the best we can do is ignore the request. return } - hdr := buffer.NewPrependable(int(r.MaxHeaderLength()) + header.ICMPv6NeighborAdvertSize) - pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6NeighborAdvertSize)) - pkt.SetType(header.ICMPv6NeighborAdvert) - pkt[icmpV6FlagOffset] = ndpSolicitedFlag | ndpOverrideFlag - copy(pkt[icmpV6OptOffset-len(targetAddr):], targetAddr) - pkt[icmpV6OptOffset] = ndpOptDstLinkAddr - pkt[icmpV6LengthOffset] = 1 - copy(pkt[icmpV6LengthOffset+1:], r.LocalLinkAddress[:]) + optsSerializer := header.NDPOptionsSerializer{ + header.NDPTargetLinkLayerAddressOption(r.LocalLinkAddress[:]), + } + hdr := buffer.NewPrependable(int(r.MaxHeaderLength()) + header.ICMPv6NeighborAdvertMinimumSize + int(optsSerializer.Length())) + packet := header.ICMPv6(hdr.Prepend(header.ICMPv6NeighborAdvertSize)) + packet.SetType(header.ICMPv6NeighborAdvert) + na := header.NDPNeighborAdvert(packet.NDPPayload()) + na.SetSolicitedFlag(true) + na.SetOverrideFlag(true) + na.SetTargetAddress(targetAddr) + opts := na.Options() + opts.Serialize(optsSerializer) // ICMPv6 Neighbor Solicit messages are always sent to // specially crafted IPv6 multicast addresses. As a result, the @@ -152,9 +209,26 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, vv buffer.V r := r.Clone() defer r.Release() r.LocalAddress = targetAddr - pkt.SetChecksum(header.ICMPv6Checksum(pkt, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{})) + packet.SetChecksum(header.ICMPv6Checksum(packet, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{})) - if err := r.WritePacket(nil /* gso */, hdr, buffer.VectorisedView{}, header.ICMPv6ProtocolNumber, r.DefaultTTL()); err != nil { + // TODO(tamird/ghanan): there exists an explicit NDP option that is + // used to update the neighbor table with link addresses for a + // neighbor from an NS (see the Source Link Layer option RFC + // 4861 section 4.6.1 and section 7.2.3). + // + // Furthermore, the entirety of NDP handling here seems to be + // contradicted by RFC 4861. + e.linkAddrCache.AddLinkAddress(e.nicID, r.RemoteAddress, r.RemoteLinkAddress) + + // RFC 4861 Neighbor Discovery for IP version 6 (IPv6) + // + // 7.1.2. Validation of Neighbor Advertisements + // + // The IP Hop Limit field has a value of 255, i.e., the packet + // could not possibly have been forwarded by a router. + if err := r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: header.NDPHopLimit, TOS: stack.DefaultTOS}, tcpip.PacketBuffer{ + Header: hdr, + }); err != nil { sent.Dropped.Increment() return } @@ -166,10 +240,45 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, vv buffer.V received.Invalid.Increment() return } - targetAddr := tcpip.Address(v[8:][:header.IPv6AddressSize]) - e.linkAddrCache.AddLinkAddress(e.nicid, targetAddr, r.RemoteLinkAddress) + + na := header.NDPNeighborAdvert(h.NDPPayload()) + targetAddr := na.TargetAddress() + stack := r.Stack() + rxNICID := r.NICID() + + isTentative, err := stack.IsAddrTentative(rxNICID, targetAddr) + if err != nil { + // We will only get an error if rxNICID is unrecognized, + // which should not happen. For now short-circuit this + // packet. + // + // TODO(b/141002840): Handle this better? + return + } + + if isTentative { + // We just got an NA from a node that owns an address we + // are performing DAD on, implying the address is not + // unique. In this case we let the stack know so it can + // handle such a scenario and do nothing furthur with + // the NDP NA. + stack.DupTentativeAddrDetected(rxNICID, targetAddr) + return + } + + // At this point we know that the targetAddress is not tentative + // on rxNICID. However, targetAddr may still be assigned to + // rxNICID but not tentative (it could be permanent). Such a + // scenario is beyond the scope of RFC 4862. As such, we simply + // ignore such a scenario for now and proceed as normal. + // + // TODO(b/143147598): Handle the scenario described above. Also + // inform the netstack integration that a duplicate address was + // detected outside of DAD. + + e.linkAddrCache.AddLinkAddress(e.nicID, targetAddr, r.RemoteLinkAddress) if targetAddr != r.RemoteAddress { - e.linkAddrCache.AddLinkAddress(e.nicid, r.RemoteAddress, r.RemoteLinkAddress) + e.linkAddrCache.AddLinkAddress(e.nicID, r.RemoteAddress, r.RemoteLinkAddress) } case header.ICMPv6EchoRequest: @@ -178,14 +287,16 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, vv buffer.V received.Invalid.Increment() return } - - vv.TrimFront(header.ICMPv6EchoMinimumSize) + pkt.Data.TrimFront(header.ICMPv6EchoMinimumSize) hdr := buffer.NewPrependable(int(r.MaxHeaderLength()) + header.ICMPv6EchoMinimumSize) - pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6EchoMinimumSize)) - copy(pkt, h) - pkt.SetType(header.ICMPv6EchoReply) - pkt.SetChecksum(header.ICMPv6Checksum(pkt, r.LocalAddress, r.RemoteAddress, vv)) - if err := r.WritePacket(nil /* gso */, hdr, vv, header.ICMPv6ProtocolNumber, r.DefaultTTL()); err != nil { + packet := header.ICMPv6(hdr.Prepend(header.ICMPv6EchoMinimumSize)) + copy(packet, h) + packet.SetType(header.ICMPv6EchoReply) + packet.SetChecksum(header.ICMPv6Checksum(packet, r.LocalAddress, r.RemoteAddress, pkt.Data)) + if err := r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: r.DefaultTTL(), TOS: stack.DefaultTOS}, tcpip.PacketBuffer{ + Header: hdr, + Data: pkt.Data, + }); err != nil { sent.Dropped.Increment() return } @@ -197,7 +308,7 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, vv buffer.V received.Invalid.Increment() return } - e.dispatcher.DeliverTransportPacket(r, header.ICMPv6ProtocolNumber, netHeader, vv) + e.dispatcher.DeliverTransportPacket(r, header.ICMPv6ProtocolNumber, pkt) case header.ICMPv6TimeExceeded: received.TimeExceeded.Increment() @@ -209,8 +320,51 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, vv buffer.V received.RouterSolicit.Increment() case header.ICMPv6RouterAdvert: + routerAddr := iph.SourceAddress() + + // + // Validate the RA as per RFC 4861 section 6.1.2. + // + + // Is the IP Source Address a link-local address? + if !header.IsV6LinkLocalAddress(routerAddr) { + // ...No, silently drop the packet. + received.Invalid.Increment() + return + } + + p := h.NDPPayload() + + // Is the NDP payload of sufficient size to hold a Router + // Advertisement? + if len(p) < header.NDPRAMinimumSize { + // ...No, silently drop the packet. + received.Invalid.Increment() + return + } + + ra := header.NDPRouterAdvert(p) + opts := ra.Options() + + // Are options valid as per the wire format? + if _, err := opts.Iter(true); err != nil { + // ...No, silently drop the packet. + received.Invalid.Increment() + return + } + + // + // At this point, we have a valid Router Advertisement, as far + // as RFC 4861 section 6.1.2 is concerned. + // + received.RouterAdvert.Increment() + // Tell the NIC to handle the RA. + stack := r.Stack() + rxNICID := r.NICID() + stack.HandleNDPRA(rxNICID, routerAddr, ra) + case header.ICMPv6RedirectMsg: received.RedirectMsg.Increment() @@ -262,13 +416,15 @@ func (*protocol) LinkAddressRequest(addr, localAddr tcpip.Address, linkEP stack. ip.Encode(&header.IPv6Fields{ PayloadLength: length, NextHeader: uint8(header.ICMPv6ProtocolNumber), - HopLimit: defaultIPv6HopLimit, + HopLimit: header.NDPHopLimit, SrcAddr: r.LocalAddress, DstAddr: r.RemoteAddress, }) // TODO(stijlist): count this in ICMP stats. - return linkEP.WritePacket(r, nil /* gso */, hdr, buffer.VectorisedView{}, ProtocolNumber) + return linkEP.WritePacket(r, nil /* gso */, ProtocolNumber, tcpip.PacketBuffer{ + Header: hdr, + }) } // ResolveStaticAddress implements stack.LinkAddressResolver. diff --git a/pkg/tcpip/network/ipv6/icmp_test.go b/pkg/tcpip/network/ipv6/icmp_test.go index 227a65cf2..335f634d5 100644 --- a/pkg/tcpip/network/ipv6/icmp_test.go +++ b/pkg/tcpip/network/ipv6/icmp_test.go @@ -15,7 +15,6 @@ package ipv6 import ( - "fmt" "reflect" "strings" "testing" @@ -31,7 +30,7 @@ import ( ) const ( - linkAddr0 = tcpip.LinkAddress("\x01\x02\x03\x04\x05\x06") + linkAddr0 = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x06") linkAddr1 = tcpip.LinkAddress("\x0a\x0b\x0c\x0d\x0e\x0f") ) @@ -56,7 +55,7 @@ func (*stubLinkEndpoint) LinkAddress() tcpip.LinkAddress { return "" } -func (*stubLinkEndpoint) WritePacket(*stack.Route, *stack.GSO, buffer.Prependable, buffer.VectorisedView, tcpip.NetworkProtocolNumber) *tcpip.Error { +func (*stubLinkEndpoint) WritePacket(*stack.Route, *stack.GSO, tcpip.NetworkProtocolNumber, tcpip.PacketBuffer) *tcpip.Error { return nil } @@ -66,7 +65,7 @@ type stubDispatcher struct { stack.TransportDispatcher } -func (*stubDispatcher) DeliverTransportPacket(*stack.Route, tcpip.TransportProtocolNumber, buffer.View, buffer.VectorisedView) { +func (*stubDispatcher) DeliverTransportPacket(*stack.Route, tcpip.TransportProtocolNumber, tcpip.PacketBuffer) { } type stubLinkAddressCache struct { @@ -81,10 +80,12 @@ func (*stubLinkAddressCache) AddLinkAddress(tcpip.NICID, tcpip.Address, tcpip.Li } func TestICMPCounts(t *testing.T) { - s := stack.New([]string{ProtocolName}, []string{icmp.ProtocolName6}, stack.Options{}) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{NewProtocol()}, + TransportProtocols: []stack.TransportProtocol{icmp.NewProtocol6()}, + }) { - id := stack.RegisterLinkEndpoint(&stubLinkEndpoint{}) - if err := s.CreateNIC(1, id); err != nil { + if err := s.CreateNIC(1, &stubLinkEndpoint{}); err != nil { t.Fatalf("CreateNIC(_) = %s", err) } if err := s.AddAddress(1, ProtocolNumber, lladdr0); err != nil { @@ -130,7 +131,7 @@ func TestICMPCounts(t *testing.T) { {header.ICMPv6EchoRequest, header.ICMPv6EchoMinimumSize}, {header.ICMPv6EchoReply, header.ICMPv6EchoMinimumSize}, {header.ICMPv6RouterSolicit, header.ICMPv6MinimumSize}, - {header.ICMPv6RouterAdvert, header.ICMPv6MinimumSize}, + {header.ICMPv6RouterAdvert, header.ICMPv6HeaderSize + header.NDPRAMinimumSize}, {header.ICMPv6NeighborSolicit, header.ICMPv6NeighborSolicitMinimumSize}, {header.ICMPv6NeighborAdvert, header.ICMPv6NeighborAdvertSize}, {header.ICMPv6RedirectMsg, header.ICMPv6MinimumSize}, @@ -142,11 +143,13 @@ func TestICMPCounts(t *testing.T) { ip.Encode(&header.IPv6Fields{ PayloadLength: uint16(payloadLength), NextHeader: uint8(header.ICMPv6ProtocolNumber), - HopLimit: r.DefaultTTL(), + HopLimit: header.NDPHopLimit, SrcAddr: r.LocalAddress, DstAddr: r.RemoteAddress, }) - ep.HandlePacket(&r, hdr.View().ToVectorisedView()) + ep.HandlePacket(&r, tcpip.PacketBuffer{ + Data: hdr.View().ToVectorisedView(), + }) } for _, typ := range types { @@ -177,13 +180,10 @@ func visitStats(v reflect.Value, f func(string, *tcpip.StatCounter)) { t := v.Type() for i := 0; i < v.NumField(); i++ { v := v.Field(i) - switch v.Kind() { - case reflect.Ptr: - f(t.Field(i).Name, v.Interface().(*tcpip.StatCounter)) - case reflect.Struct: + if s, ok := v.Interface().(*tcpip.StatCounter); ok { + f(t.Field(i).Name, s) + } else { visitStats(v, f) - default: - panic(fmt.Sprintf("unexpected type %s", v.Type())) } } } @@ -206,41 +206,38 @@ func (e endpointWithResolutionCapability) Capabilities() stack.LinkEndpointCapab func newTestContext(t *testing.T) *testContext { c := &testContext{ - s0: stack.New([]string{ProtocolName}, []string{icmp.ProtocolName6}, stack.Options{}), - s1: stack.New([]string{ProtocolName}, []string{icmp.ProtocolName6}, stack.Options{}), + s0: stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{NewProtocol()}, + TransportProtocols: []stack.TransportProtocol{icmp.NewProtocol6()}, + }), + s1: stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{NewProtocol()}, + TransportProtocols: []stack.TransportProtocol{icmp.NewProtocol6()}, + }), } const defaultMTU = 65536 - _, linkEP0 := channel.New(256, defaultMTU, linkAddr0) - c.linkEP0 = linkEP0 - wrappedEP0 := endpointWithResolutionCapability{LinkEndpoint: linkEP0} - id0 := stack.RegisterLinkEndpoint(wrappedEP0) + c.linkEP0 = channel.New(256, defaultMTU, linkAddr0) + + wrappedEP0 := stack.LinkEndpoint(endpointWithResolutionCapability{LinkEndpoint: c.linkEP0}) if testing.Verbose() { - id0 = sniffer.New(id0) + wrappedEP0 = sniffer.New(wrappedEP0) } - if err := c.s0.CreateNIC(1, id0); err != nil { + if err := c.s0.CreateNIC(1, wrappedEP0); err != nil { t.Fatalf("CreateNIC s0: %v", err) } if err := c.s0.AddAddress(1, ProtocolNumber, lladdr0); err != nil { t.Fatalf("AddAddress lladdr0: %v", err) } - if err := c.s0.AddAddress(1, ProtocolNumber, header.SolicitedNodeAddr(lladdr0)); err != nil { - t.Fatalf("AddAddress sn lladdr0: %v", err) - } - _, linkEP1 := channel.New(256, defaultMTU, linkAddr1) - c.linkEP1 = linkEP1 - wrappedEP1 := endpointWithResolutionCapability{LinkEndpoint: linkEP1} - id1 := stack.RegisterLinkEndpoint(wrappedEP1) - if err := c.s1.CreateNIC(1, id1); err != nil { + c.linkEP1 = channel.New(256, defaultMTU, linkAddr1) + wrappedEP1 := stack.LinkEndpoint(endpointWithResolutionCapability{LinkEndpoint: c.linkEP1}) + if err := c.s1.CreateNIC(1, wrappedEP1); err != nil { t.Fatalf("CreateNIC failed: %v", err) } if err := c.s1.AddAddress(1, ProtocolNumber, lladdr1); err != nil { t.Fatalf("AddAddress lladdr1: %v", err) } - if err := c.s1.AddAddress(1, ProtocolNumber, header.SolicitedNodeAddr(lladdr1)); err != nil { - t.Fatalf("AddAddress sn lladdr1: %v", err) - } subnet0, err := tcpip.NewSubnet(lladdr1, tcpip.AddressMask(strings.Repeat("\xff", len(lladdr1)))) if err != nil { @@ -279,20 +276,22 @@ type routeArgs struct { func routeICMPv6Packet(t *testing.T, args routeArgs, fn func(*testing.T, header.ICMPv6)) { t.Helper() - pkt := <-args.src.C + pi := <-args.src.C { - views := []buffer.View{pkt.Header, pkt.Payload} - size := len(pkt.Header) + len(pkt.Payload) + views := []buffer.View{pi.Pkt.Header.View(), pi.Pkt.Data.ToView()} + size := pi.Pkt.Header.UsedLength() + pi.Pkt.Data.Size() vv := buffer.NewVectorisedView(size, views) - args.dst.InjectLinkAddr(pkt.Proto, args.dst.LinkAddress(), vv) + args.dst.InjectLinkAddr(pi.Proto, args.dst.LinkAddress(), tcpip.PacketBuffer{ + Data: vv, + }) } - if pkt.Proto != ProtocolNumber { - t.Errorf("unexpected protocol number %d", pkt.Proto) + if pi.Proto != ProtocolNumber { + t.Errorf("unexpected protocol number %d", pi.Proto) return } - ipv6 := header.IPv6(pkt.Header) + ipv6 := header.IPv6(pi.Pkt.Header.View()) transProto := tcpip.TransportProtocolNumber(ipv6.NextHeader()) if transProto != header.ICMPv6ProtocolNumber { t.Errorf("unexpected transport protocol number %d", transProto) @@ -364,3 +363,537 @@ func TestLinkResolution(t *testing.T) { routeICMPv6Packet(t, args, nil) } } + +func TestICMPChecksumValidationSimple(t *testing.T) { + types := []struct { + name string + typ header.ICMPv6Type + size int + statCounter func(tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter + }{ + { + "DstUnreachable", + header.ICMPv6DstUnreachable, + header.ICMPv6DstUnreachableMinimumSize, + func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { + return stats.DstUnreachable + }, + }, + { + "PacketTooBig", + header.ICMPv6PacketTooBig, + header.ICMPv6PacketTooBigMinimumSize, + func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { + return stats.PacketTooBig + }, + }, + { + "TimeExceeded", + header.ICMPv6TimeExceeded, + header.ICMPv6MinimumSize, + func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { + return stats.TimeExceeded + }, + }, + { + "ParamProblem", + header.ICMPv6ParamProblem, + header.ICMPv6MinimumSize, + func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { + return stats.ParamProblem + }, + }, + { + "EchoRequest", + header.ICMPv6EchoRequest, + header.ICMPv6EchoMinimumSize, + func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { + return stats.EchoRequest + }, + }, + { + "EchoReply", + header.ICMPv6EchoReply, + header.ICMPv6EchoMinimumSize, + func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { + return stats.EchoReply + }, + }, + { + "RouterSolicit", + header.ICMPv6RouterSolicit, + header.ICMPv6MinimumSize, + func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { + return stats.RouterSolicit + }, + }, + { + "RouterAdvert", + header.ICMPv6RouterAdvert, + header.ICMPv6HeaderSize + header.NDPRAMinimumSize, + func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { + return stats.RouterAdvert + }, + }, + { + "NeighborSolicit", + header.ICMPv6NeighborSolicit, + header.ICMPv6NeighborSolicitMinimumSize, + func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { + return stats.NeighborSolicit + }, + }, + { + "NeighborAdvert", + header.ICMPv6NeighborAdvert, + header.ICMPv6NeighborAdvertSize, + func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { + return stats.NeighborAdvert + }, + }, + { + "RedirectMsg", + header.ICMPv6RedirectMsg, + header.ICMPv6MinimumSize, + func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { + return stats.RedirectMsg + }, + }, + } + + for _, typ := range types { + t.Run(typ.name, func(t *testing.T) { + e := channel.New(10, 1280, linkAddr0) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{NewProtocol()}, + }) + if err := s.CreateNIC(1, e); err != nil { + t.Fatalf("CreateNIC(_) = %s", err) + } + + if err := s.AddAddress(1, ProtocolNumber, lladdr0); err != nil { + t.Fatalf("AddAddress(_, %d, %s) = %s", ProtocolNumber, lladdr0, err) + } + { + subnet, err := tcpip.NewSubnet(lladdr1, tcpip.AddressMask(strings.Repeat("\xff", len(lladdr1)))) + if err != nil { + t.Fatal(err) + } + s.SetRouteTable( + []tcpip.Route{{ + Destination: subnet, + NIC: 1, + }}, + ) + } + + handleIPv6Payload := func(typ header.ICMPv6Type, size int, checksum bool) { + hdr := buffer.NewPrependable(header.IPv6MinimumSize + size) + pkt := header.ICMPv6(hdr.Prepend(size)) + pkt.SetType(typ) + if checksum { + pkt.SetChecksum(header.ICMPv6Checksum(pkt, lladdr1, lladdr0, buffer.VectorisedView{})) + } + ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) + ip.Encode(&header.IPv6Fields{ + PayloadLength: uint16(size), + NextHeader: uint8(header.ICMPv6ProtocolNumber), + HopLimit: header.NDPHopLimit, + SrcAddr: lladdr1, + DstAddr: lladdr0, + }) + e.InjectInbound(ProtocolNumber, tcpip.PacketBuffer{ + Data: hdr.View().ToVectorisedView(), + }) + } + + stats := s.Stats().ICMP.V6PacketsReceived + invalid := stats.Invalid + typStat := typ.statCounter(stats) + + // Initial stat counts should be 0. + if got := invalid.Value(); got != 0 { + t.Fatalf("got invalid = %d, want = 0", got) + } + if got := typStat.Value(); got != 0 { + t.Fatalf("got %s = %d, want = 0", typ.name, got) + } + + // Without setting checksum, the incoming packet should + // be invalid. + handleIPv6Payload(typ.typ, typ.size, false) + if got := invalid.Value(); got != 1 { + t.Fatalf("got invalid = %d, want = 1", got) + } + // Rx count of type typ.typ should not have increased. + if got := typStat.Value(); got != 0 { + t.Fatalf("got %s = %d, want = 0", typ.name, got) + } + + // When checksum is set, it should be received. + handleIPv6Payload(typ.typ, typ.size, true) + if got := typStat.Value(); got != 1 { + t.Fatalf("got %s = %d, want = 1", typ.name, got) + } + // Invalid count should not have increased again. + if got := invalid.Value(); got != 1 { + t.Fatalf("got invalid = %d, want = 1", got) + } + }) + } +} + +func TestICMPChecksumValidationWithPayload(t *testing.T) { + const simpleBodySize = 64 + simpleBody := func(view buffer.View) { + for i := 0; i < simpleBodySize; i++ { + view[i] = uint8(i) + } + } + + const errorICMPBodySize = header.IPv6MinimumSize + simpleBodySize + errorICMPBody := func(view buffer.View) { + ip := header.IPv6(view) + ip.Encode(&header.IPv6Fields{ + PayloadLength: simpleBodySize, + NextHeader: 10, + HopLimit: 20, + SrcAddr: lladdr0, + DstAddr: lladdr1, + }) + simpleBody(view[header.IPv6MinimumSize:]) + } + + types := []struct { + name string + typ header.ICMPv6Type + size int + statCounter func(tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter + payloadSize int + payload func(buffer.View) + }{ + { + "DstUnreachable", + header.ICMPv6DstUnreachable, + header.ICMPv6DstUnreachableMinimumSize, + func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { + return stats.DstUnreachable + }, + errorICMPBodySize, + errorICMPBody, + }, + { + "PacketTooBig", + header.ICMPv6PacketTooBig, + header.ICMPv6PacketTooBigMinimumSize, + func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { + return stats.PacketTooBig + }, + errorICMPBodySize, + errorICMPBody, + }, + { + "TimeExceeded", + header.ICMPv6TimeExceeded, + header.ICMPv6MinimumSize, + func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { + return stats.TimeExceeded + }, + errorICMPBodySize, + errorICMPBody, + }, + { + "ParamProblem", + header.ICMPv6ParamProblem, + header.ICMPv6MinimumSize, + func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { + return stats.ParamProblem + }, + errorICMPBodySize, + errorICMPBody, + }, + { + "EchoRequest", + header.ICMPv6EchoRequest, + header.ICMPv6EchoMinimumSize, + func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { + return stats.EchoRequest + }, + simpleBodySize, + simpleBody, + }, + { + "EchoReply", + header.ICMPv6EchoReply, + header.ICMPv6EchoMinimumSize, + func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { + return stats.EchoReply + }, + simpleBodySize, + simpleBody, + }, + } + + for _, typ := range types { + t.Run(typ.name, func(t *testing.T) { + e := channel.New(10, 1280, linkAddr0) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{NewProtocol()}, + }) + if err := s.CreateNIC(1, e); err != nil { + t.Fatalf("CreateNIC(_) = %s", err) + } + + if err := s.AddAddress(1, ProtocolNumber, lladdr0); err != nil { + t.Fatalf("AddAddress(_, %d, %s) = %s", ProtocolNumber, lladdr0, err) + } + { + subnet, err := tcpip.NewSubnet(lladdr1, tcpip.AddressMask(strings.Repeat("\xff", len(lladdr1)))) + if err != nil { + t.Fatal(err) + } + s.SetRouteTable( + []tcpip.Route{{ + Destination: subnet, + NIC: 1, + }}, + ) + } + + handleIPv6Payload := func(typ header.ICMPv6Type, size, payloadSize int, payloadFn func(buffer.View), checksum bool) { + icmpSize := size + payloadSize + hdr := buffer.NewPrependable(header.IPv6MinimumSize + icmpSize) + pkt := header.ICMPv6(hdr.Prepend(icmpSize)) + pkt.SetType(typ) + payloadFn(pkt.Payload()) + + if checksum { + pkt.SetChecksum(header.ICMPv6Checksum(pkt, lladdr1, lladdr0, buffer.VectorisedView{})) + } + + ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) + ip.Encode(&header.IPv6Fields{ + PayloadLength: uint16(icmpSize), + NextHeader: uint8(header.ICMPv6ProtocolNumber), + HopLimit: header.NDPHopLimit, + SrcAddr: lladdr1, + DstAddr: lladdr0, + }) + e.InjectInbound(ProtocolNumber, tcpip.PacketBuffer{ + Data: hdr.View().ToVectorisedView(), + }) + } + + stats := s.Stats().ICMP.V6PacketsReceived + invalid := stats.Invalid + typStat := typ.statCounter(stats) + + // Initial stat counts should be 0. + if got := invalid.Value(); got != 0 { + t.Fatalf("got invalid = %d, want = 0", got) + } + if got := typStat.Value(); got != 0 { + t.Fatalf("got %s = %d, want = 0", typ.name, got) + } + + // Without setting checksum, the incoming packet should + // be invalid. + handleIPv6Payload(typ.typ, typ.size, typ.payloadSize, typ.payload, false) + if got := invalid.Value(); got != 1 { + t.Fatalf("got invalid = %d, want = 1", got) + } + // Rx count of type typ.typ should not have increased. + if got := typStat.Value(); got != 0 { + t.Fatalf("got %s = %d, want = 0", typ.name, got) + } + + // When checksum is set, it should be received. + handleIPv6Payload(typ.typ, typ.size, typ.payloadSize, typ.payload, true) + if got := typStat.Value(); got != 1 { + t.Fatalf("got %s = %d, want = 1", typ.name, got) + } + // Invalid count should not have increased again. + if got := invalid.Value(); got != 1 { + t.Fatalf("got invalid = %d, want = 1", got) + } + }) + } +} + +func TestICMPChecksumValidationWithPayloadMultipleViews(t *testing.T) { + const simpleBodySize = 64 + simpleBody := func(view buffer.View) { + for i := 0; i < simpleBodySize; i++ { + view[i] = uint8(i) + } + } + + const errorICMPBodySize = header.IPv6MinimumSize + simpleBodySize + errorICMPBody := func(view buffer.View) { + ip := header.IPv6(view) + ip.Encode(&header.IPv6Fields{ + PayloadLength: simpleBodySize, + NextHeader: 10, + HopLimit: 20, + SrcAddr: lladdr0, + DstAddr: lladdr1, + }) + simpleBody(view[header.IPv6MinimumSize:]) + } + + types := []struct { + name string + typ header.ICMPv6Type + size int + statCounter func(tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter + payloadSize int + payload func(buffer.View) + }{ + { + "DstUnreachable", + header.ICMPv6DstUnreachable, + header.ICMPv6DstUnreachableMinimumSize, + func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { + return stats.DstUnreachable + }, + errorICMPBodySize, + errorICMPBody, + }, + { + "PacketTooBig", + header.ICMPv6PacketTooBig, + header.ICMPv6PacketTooBigMinimumSize, + func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { + return stats.PacketTooBig + }, + errorICMPBodySize, + errorICMPBody, + }, + { + "TimeExceeded", + header.ICMPv6TimeExceeded, + header.ICMPv6MinimumSize, + func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { + return stats.TimeExceeded + }, + errorICMPBodySize, + errorICMPBody, + }, + { + "ParamProblem", + header.ICMPv6ParamProblem, + header.ICMPv6MinimumSize, + func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { + return stats.ParamProblem + }, + errorICMPBodySize, + errorICMPBody, + }, + { + "EchoRequest", + header.ICMPv6EchoRequest, + header.ICMPv6EchoMinimumSize, + func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { + return stats.EchoRequest + }, + simpleBodySize, + simpleBody, + }, + { + "EchoReply", + header.ICMPv6EchoReply, + header.ICMPv6EchoMinimumSize, + func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { + return stats.EchoReply + }, + simpleBodySize, + simpleBody, + }, + } + + for _, typ := range types { + t.Run(typ.name, func(t *testing.T) { + e := channel.New(10, 1280, linkAddr0) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{NewProtocol()}, + }) + if err := s.CreateNIC(1, e); err != nil { + t.Fatalf("CreateNIC(_) = %s", err) + } + + if err := s.AddAddress(1, ProtocolNumber, lladdr0); err != nil { + t.Fatalf("AddAddress(_, %d, %s) = %s", ProtocolNumber, lladdr0, err) + } + { + subnet, err := tcpip.NewSubnet(lladdr1, tcpip.AddressMask(strings.Repeat("\xff", len(lladdr1)))) + if err != nil { + t.Fatal(err) + } + s.SetRouteTable( + []tcpip.Route{{ + Destination: subnet, + NIC: 1, + }}, + ) + } + + handleIPv6Payload := func(typ header.ICMPv6Type, size, payloadSize int, payloadFn func(buffer.View), checksum bool) { + hdr := buffer.NewPrependable(header.IPv6MinimumSize + size) + pkt := header.ICMPv6(hdr.Prepend(size)) + pkt.SetType(typ) + + payload := buffer.NewView(payloadSize) + payloadFn(payload) + + if checksum { + pkt.SetChecksum(header.ICMPv6Checksum(pkt, lladdr1, lladdr0, payload.ToVectorisedView())) + } + + ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) + ip.Encode(&header.IPv6Fields{ + PayloadLength: uint16(size + payloadSize), + NextHeader: uint8(header.ICMPv6ProtocolNumber), + HopLimit: header.NDPHopLimit, + SrcAddr: lladdr1, + DstAddr: lladdr0, + }) + e.InjectInbound(ProtocolNumber, tcpip.PacketBuffer{ + Data: buffer.NewVectorisedView(header.IPv6MinimumSize+size+payloadSize, []buffer.View{hdr.View(), payload}), + }) + } + + stats := s.Stats().ICMP.V6PacketsReceived + invalid := stats.Invalid + typStat := typ.statCounter(stats) + + // Initial stat counts should be 0. + if got := invalid.Value(); got != 0 { + t.Fatalf("got invalid = %d, want = 0", got) + } + if got := typStat.Value(); got != 0 { + t.Fatalf("got %s = %d, want = 0", typ.name, got) + } + + // Without setting checksum, the incoming packet should + // be invalid. + handleIPv6Payload(typ.typ, typ.size, typ.payloadSize, typ.payload, false) + if got := invalid.Value(); got != 1 { + t.Fatalf("got invalid = %d, want = 1", got) + } + // Rx count of type typ.typ should not have increased. + if got := typStat.Value(); got != 0 { + t.Fatalf("got %s = %d, want = 0", typ.name, got) + } + + // When checksum is set, it should be received. + handleIPv6Payload(typ.typ, typ.size, typ.payloadSize, typ.payload, true) + if got := typStat.Value(); got != 1 { + t.Fatalf("got %s = %d, want = 1", typ.name, got) + } + // Invalid count should not have increased again. + if got := invalid.Value(); got != 1 { + t.Fatalf("got invalid = %d, want = 1", got) + } + }) + } +} diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go index 331a8bdaa..dd31f0fb7 100644 --- a/pkg/tcpip/network/ipv6/ipv6.go +++ b/pkg/tcpip/network/ipv6/ipv6.go @@ -14,13 +14,15 @@ // Package ipv6 contains the implementation of the ipv6 network protocol. To use // it in the networking stack, this package must be added to the project, and -// activated on the stack by passing ipv6.ProtocolName (or "ipv6") as one of the -// network protocols when calling stack.New(). Then endpoints can be created -// by passing ipv6.ProtocolNumber as the network protocol number when calling +// activated on the stack by passing ipv6.NewProtocol() as one of the network +// protocols when calling stack.New(). Then endpoints can be created by passing +// ipv6.ProtocolNumber as the network protocol number when calling // Stack.NewEndpoint(). package ipv6 import ( + "sync/atomic" + "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" @@ -28,9 +30,6 @@ import ( ) const ( - // ProtocolName is the string representation of the ipv6 protocol name. - ProtocolName = "ipv6" - // ProtocolNumber is the ipv6 protocol number. ProtocolNumber = header.IPv6ProtocolNumber @@ -38,23 +37,24 @@ const ( // PayloadLength field of the ipv6 header. maxPayloadSize = 0xffff - // defaultIPv6HopLimit is the default hop limit for IPv6 Packets - // egressed by Netstack. - defaultIPv6HopLimit = 255 + // DefaultTTL is the default hop limit for IPv6 Packets egressed by + // Netstack. + DefaultTTL = 64 ) type endpoint struct { - nicid tcpip.NICID + nicID tcpip.NICID id stack.NetworkEndpointID prefixLen int linkEP stack.LinkEndpoint linkAddrCache stack.LinkAddressCache dispatcher stack.TransportDispatcher + protocol *protocol } // DefaultTTL is the default hop limit for this endpoint. func (e *endpoint) DefaultTTL() uint8 { - return 255 + return e.protocol.DefaultTTL() } // MTU implements stack.NetworkEndpoint.MTU. It returns the link-layer MTU minus @@ -65,7 +65,7 @@ func (e *endpoint) MTU() uint32 { // NICID returns the ID of the NIC this endpoint belongs to. func (e *endpoint) NICID() tcpip.NICID { - return e.nicid + return e.nicID } // ID returns the ipv6 endpoint ID. @@ -97,25 +97,37 @@ func (e *endpoint) GSOMaxSize() uint32 { return 0 } -// WritePacket writes a packet to the given destination address and protocol. -func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, hdr buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.TransportProtocolNumber, ttl uint8, loop stack.PacketLooping) *tcpip.Error { - length := uint16(hdr.UsedLength() + payload.Size()) +func (e *endpoint) addIPHeader(r *stack.Route, hdr *buffer.Prependable, payloadSize int, params stack.NetworkHeaderParams) header.IPv6 { + length := uint16(hdr.UsedLength() + payloadSize) ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) ip.Encode(&header.IPv6Fields{ PayloadLength: length, - NextHeader: uint8(protocol), - HopLimit: ttl, + NextHeader: uint8(params.Protocol), + HopLimit: params.TTL, + TrafficClass: params.TOS, SrcAddr: r.LocalAddress, DstAddr: r.RemoteAddress, }) + return ip +} + +// WritePacket writes a packet to the given destination address and protocol. +func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.NetworkHeaderParams, loop stack.PacketLooping, pkt tcpip.PacketBuffer) *tcpip.Error { + ip := e.addIPHeader(r, &pkt.Header, pkt.Data.Size(), params) + pkt.NetworkHeader = buffer.View(ip) if loop&stack.PacketLoop != 0 { - views := make([]buffer.View, 1, 1+len(payload.Views())) - views[0] = hdr.View() - views = append(views, payload.Views()...) - vv := buffer.NewVectorisedView(len(views[0])+payload.Size(), views) + // The inbound path expects the network header to still be in + // the PacketBuffer's Data field. + views := make([]buffer.View, 1, 1+len(pkt.Data.Views())) + views[0] = pkt.Header.View() + views = append(views, pkt.Data.Views()...) loopedR := r.MakeLoopedRoute() - e.HandlePacket(&loopedR, vv) + + e.HandlePacket(&loopedR, tcpip.PacketBuffer{ + Data: buffer.NewVectorisedView(len(views[0])+pkt.Data.Size(), views), + }) + loopedR.Release() } if loop&stack.PacketOut == 0 { @@ -123,49 +135,68 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, hdr buffer.Prepen } r.Stats().IP.PacketsSent.Increment() - return e.linkEP.WritePacket(r, gso, hdr, payload, ProtocolNumber) + return e.linkEP.WritePacket(r, gso, ProtocolNumber, pkt) +} + +// WritePackets implements stack.LinkEndpoint.WritePackets. +func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts []tcpip.PacketBuffer, params stack.NetworkHeaderParams, loop stack.PacketLooping) (int, *tcpip.Error) { + if loop&stack.PacketLoop != 0 { + panic("not implemented") + } + if loop&stack.PacketOut == 0 { + return len(pkts), nil + } + + for i := range pkts { + hdr := &pkts[i].Header + size := pkts[i].DataSize + ip := e.addIPHeader(r, hdr, size, params) + pkts[i].NetworkHeader = buffer.View(ip) + } + + n, err := e.linkEP.WritePackets(r, gso, pkts, ProtocolNumber) + r.Stats().IP.PacketsSent.IncrementBy(uint64(n)) + return n, err } // WriteHeaderIncludedPacker implements stack.NetworkEndpoint. It is not yet // supported by IPv6. -func (*endpoint) WriteHeaderIncludedPacket(r *stack.Route, payload buffer.VectorisedView, loop stack.PacketLooping) *tcpip.Error { +func (*endpoint) WriteHeaderIncludedPacket(r *stack.Route, loop stack.PacketLooping, pkt tcpip.PacketBuffer) *tcpip.Error { // TODO(b/119580726): Support IPv6 header-included packets. return tcpip.ErrNotSupported } // HandlePacket is called by the link layer when new ipv6 packets arrive for // this endpoint. -func (e *endpoint) HandlePacket(r *stack.Route, vv buffer.VectorisedView) { - headerView := vv.First() +func (e *endpoint) HandlePacket(r *stack.Route, pkt tcpip.PacketBuffer) { + headerView := pkt.Data.First() h := header.IPv6(headerView) - if !h.IsValid(vv.Size()) { + if !h.IsValid(pkt.Data.Size()) { return } - vv.TrimFront(header.IPv6MinimumSize) - vv.CapLength(int(h.PayloadLength())) + pkt.NetworkHeader = headerView[:header.IPv6MinimumSize] + pkt.Data.TrimFront(header.IPv6MinimumSize) + pkt.Data.CapLength(int(h.PayloadLength())) p := h.TransportProtocol() if p == header.ICMPv6ProtocolNumber { - e.handleICMP(r, headerView, vv) + e.handleICMP(r, headerView, pkt) return } r.Stats().IP.PacketsDelivered.Increment() - e.dispatcher.DeliverTransportPacket(r, p, headerView, vv) + e.dispatcher.DeliverTransportPacket(r, p, pkt) } // Close cleans up resources associated with the endpoint. func (*endpoint) Close() {} -type protocol struct{} - -// NewProtocol creates a new protocol ipv6 protocol descriptor. This is exported -// only for tests that short-circuit the stack. Regular use of the protocol is -// done via the stack, which gets a protocol descriptor from the init() function -// below. -func NewProtocol() stack.NetworkProtocol { - return &protocol{} +type protocol struct { + // defaultTTL is the current default TTL for the protocol. Only the + // uint8 portion of it is meaningful and it must be accessed + // atomically. + defaultTTL uint32 } // Number returns the ipv6 protocol number. @@ -190,25 +221,48 @@ func (*protocol) ParseAddresses(v buffer.View) (src, dst tcpip.Address) { } // NewEndpoint creates a new ipv6 endpoint. -func (p *protocol) NewEndpoint(nicid tcpip.NICID, addrWithPrefix tcpip.AddressWithPrefix, linkAddrCache stack.LinkAddressCache, dispatcher stack.TransportDispatcher, linkEP stack.LinkEndpoint) (stack.NetworkEndpoint, *tcpip.Error) { +func (p *protocol) NewEndpoint(nicID tcpip.NICID, addrWithPrefix tcpip.AddressWithPrefix, linkAddrCache stack.LinkAddressCache, dispatcher stack.TransportDispatcher, linkEP stack.LinkEndpoint) (stack.NetworkEndpoint, *tcpip.Error) { return &endpoint{ - nicid: nicid, + nicID: nicID, id: stack.NetworkEndpointID{LocalAddress: addrWithPrefix.Address}, prefixLen: addrWithPrefix.PrefixLen, linkEP: linkEP, linkAddrCache: linkAddrCache, dispatcher: dispatcher, + protocol: p, }, nil } // SetOption implements NetworkProtocol.SetOption. func (p *protocol) SetOption(option interface{}) *tcpip.Error { - return tcpip.ErrUnknownProtocolOption + switch v := option.(type) { + case tcpip.DefaultTTLOption: + p.SetDefaultTTL(uint8(v)) + return nil + default: + return tcpip.ErrUnknownProtocolOption + } } // Option implements NetworkProtocol.Option. func (p *protocol) Option(option interface{}) *tcpip.Error { - return tcpip.ErrUnknownProtocolOption + switch v := option.(type) { + case *tcpip.DefaultTTLOption: + *v = tcpip.DefaultTTLOption(p.DefaultTTL()) + return nil + default: + return tcpip.ErrUnknownProtocolOption + } +} + +// SetDefaultTTL sets the default TTL for endpoints created with this protocol. +func (p *protocol) SetDefaultTTL(ttl uint8) { + atomic.StoreUint32(&p.defaultTTL, uint32(ttl)) +} + +// DefaultTTL returns the default TTL for endpoints created with this protocol. +func (p *protocol) DefaultTTL() uint8 { + return uint8(atomic.LoadUint32(&p.defaultTTL)) } // calculateMTU calculates the network-layer payload MTU based on the link-layer @@ -221,8 +275,7 @@ func calculateMTU(mtu uint32) uint32 { return maxPayloadSize } -func init() { - stack.RegisterNetworkProtocolFactory(ProtocolName, func() stack.NetworkProtocol { - return &protocol{} - }) +// NewProtocol returns an IPv6 network protocol. +func NewProtocol() stack.NetworkProtocol { + return &protocol{defaultTTL: DefaultTTL} } diff --git a/pkg/tcpip/network/ipv6/ipv6_test.go b/pkg/tcpip/network/ipv6/ipv6_test.go new file mode 100644 index 000000000..1cbfa7278 --- /dev/null +++ b/pkg/tcpip/network/ipv6/ipv6_test.go @@ -0,0 +1,270 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package ipv6 + +import ( + "testing" + + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/link/channel" + "gvisor.dev/gvisor/pkg/tcpip/stack" + "gvisor.dev/gvisor/pkg/tcpip/transport/icmp" + "gvisor.dev/gvisor/pkg/tcpip/transport/udp" + "gvisor.dev/gvisor/pkg/waiter" +) + +const ( + addr1 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01" + addr2 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02" + // The least significant 3 bytes are the same as addr2 so both addr2 and + // addr3 will have the same solicited-node address. + addr3 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x02" +) + +// testReceiveICMP tests receiving an ICMP packet from src to dst. want is the +// expected Neighbor Advertisement received count after receiving the packet. +func testReceiveICMP(t *testing.T, s *stack.Stack, e *channel.Endpoint, src, dst tcpip.Address, want uint64) { + t.Helper() + + // Receive ICMP packet. + hdr := buffer.NewPrependable(header.IPv6MinimumSize + header.ICMPv6NeighborAdvertSize) + pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6NeighborAdvertSize)) + pkt.SetType(header.ICMPv6NeighborAdvert) + pkt.SetChecksum(header.ICMPv6Checksum(pkt, src, dst, buffer.VectorisedView{})) + payloadLength := hdr.UsedLength() + ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) + ip.Encode(&header.IPv6Fields{ + PayloadLength: uint16(payloadLength), + NextHeader: uint8(header.ICMPv6ProtocolNumber), + HopLimit: 255, + SrcAddr: src, + DstAddr: dst, + }) + + e.InjectInbound(ProtocolNumber, tcpip.PacketBuffer{ + Data: hdr.View().ToVectorisedView(), + }) + + stats := s.Stats().ICMP.V6PacketsReceived + + if got := stats.NeighborAdvert.Value(); got != want { + t.Fatalf("got NeighborAdvert = %d, want = %d", got, want) + } +} + +// testReceiveUDP tests receiving a UDP packet from src to dst. want is the +// expected UDP received count after receiving the packet. +func testReceiveUDP(t *testing.T, s *stack.Stack, e *channel.Endpoint, src, dst tcpip.Address, want uint64) { + t.Helper() + + wq := waiter.Queue{} + we, ch := waiter.NewChannelEntry(nil) + wq.EventRegister(&we, waiter.EventIn) + defer wq.EventUnregister(&we) + defer close(ch) + + ep, err := s.NewEndpoint(udp.ProtocolNumber, ProtocolNumber, &wq) + if err != nil { + t.Fatalf("NewEndpoint failed: %v", err) + } + defer ep.Close() + + if err := ep.Bind(tcpip.FullAddress{Addr: dst, Port: 80}); err != nil { + t.Fatalf("ep.Bind(...) failed: %v", err) + } + + // Receive UDP Packet. + hdr := buffer.NewPrependable(header.IPv6MinimumSize + header.UDPMinimumSize) + u := header.UDP(hdr.Prepend(header.UDPMinimumSize)) + u.Encode(&header.UDPFields{ + SrcPort: 5555, + DstPort: 80, + Length: header.UDPMinimumSize, + }) + + // UDP pseudo-header checksum. + sum := header.PseudoHeaderChecksum(udp.ProtocolNumber, src, dst, header.UDPMinimumSize) + + // UDP checksum + sum = header.Checksum(header.UDP([]byte{}), sum) + u.SetChecksum(^u.CalculateChecksum(sum)) + + payloadLength := hdr.UsedLength() + ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) + ip.Encode(&header.IPv6Fields{ + PayloadLength: uint16(payloadLength), + NextHeader: uint8(udp.ProtocolNumber), + HopLimit: 255, + SrcAddr: src, + DstAddr: dst, + }) + + e.InjectInbound(ProtocolNumber, tcpip.PacketBuffer{ + Data: hdr.View().ToVectorisedView(), + }) + + stat := s.Stats().UDP.PacketsReceived + + if got := stat.Value(); got != want { + t.Fatalf("got UDPPacketsReceived = %d, want = %d", got, want) + } +} + +// TestReceiveOnAllNodesMulticastAddr tests that IPv6 endpoints receive ICMP and +// UDP packets destined to the IPv6 link-local all-nodes multicast address. +func TestReceiveOnAllNodesMulticastAddr(t *testing.T) { + tests := []struct { + name string + protocolFactory stack.TransportProtocol + rxf func(t *testing.T, s *stack.Stack, e *channel.Endpoint, src, dst tcpip.Address, want uint64) + }{ + {"ICMP", icmp.NewProtocol6(), testReceiveICMP}, + {"UDP", udp.NewProtocol(), testReceiveUDP}, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{NewProtocol()}, + TransportProtocols: []stack.TransportProtocol{test.protocolFactory}, + }) + e := channel.New(10, 1280, linkAddr1) + if err := s.CreateNIC(1, e); err != nil { + t.Fatalf("CreateNIC(_) = %s", err) + } + + // Should receive a packet destined to the all-nodes + // multicast address. + test.rxf(t, s, e, addr1, header.IPv6AllNodesMulticastAddress, 1) + }) + } +} + +// TestReceiveOnSolicitedNodeAddr tests that IPv6 endpoints receive ICMP and UDP +// packets destined to the IPv6 solicited-node address of an assigned IPv6 +// address. +func TestReceiveOnSolicitedNodeAddr(t *testing.T) { + tests := []struct { + name string + protocolFactory stack.TransportProtocol + rxf func(t *testing.T, s *stack.Stack, e *channel.Endpoint, src, dst tcpip.Address, want uint64) + }{ + {"ICMP", icmp.NewProtocol6(), testReceiveICMP}, + {"UDP", udp.NewProtocol(), testReceiveUDP}, + } + + snmc := header.SolicitedNodeAddr(addr2) + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{NewProtocol()}, + TransportProtocols: []stack.TransportProtocol{test.protocolFactory}, + }) + e := channel.New(10, 1280, linkAddr1) + if err := s.CreateNIC(1, e); err != nil { + t.Fatalf("CreateNIC(_) = %s", err) + } + + // Should not receive a packet destined to the solicited + // node address of addr2/addr3 yet as we haven't added + // those addresses. + test.rxf(t, s, e, addr1, snmc, 0) + + if err := s.AddAddress(1, ProtocolNumber, addr2); err != nil { + t.Fatalf("AddAddress(_, %d, %s) = %s", ProtocolNumber, addr2, err) + } + + // Should receive a packet destined to the solicited + // node address of addr2/addr3 now that we have added + // added addr2. + test.rxf(t, s, e, addr1, snmc, 1) + + if err := s.AddAddress(1, ProtocolNumber, addr3); err != nil { + t.Fatalf("AddAddress(_, %d, %s) = %s", ProtocolNumber, addr3, err) + } + + // Should still receive a packet destined to the + // solicited node address of addr2/addr3 now that we + // have added addr3. + test.rxf(t, s, e, addr1, snmc, 2) + + if err := s.RemoveAddress(1, addr2); err != nil { + t.Fatalf("RemoveAddress(_, %s) = %s", addr2, err) + } + + // Should still receive a packet destined to the + // solicited node address of addr2/addr3 now that we + // have removed addr2. + test.rxf(t, s, e, addr1, snmc, 3) + + if err := s.RemoveAddress(1, addr3); err != nil { + t.Fatalf("RemoveAddress(_, %s) = %s", addr3, err) + } + + // Should not receive a packet destined to the solicited + // node address of addr2/addr3 yet as both of them got + // removed. + test.rxf(t, s, e, addr1, snmc, 3) + }) + } +} + +// TestAddIpv6Address tests adding IPv6 addresses. +func TestAddIpv6Address(t *testing.T) { + tests := []struct { + name string + addr tcpip.Address + }{ + // This test is in response to b/140943433. + { + "Nil", + tcpip.Address([]byte(nil)), + }, + { + "ValidUnicast", + addr1, + }, + { + "ValidLinkLocalUnicast", + lladdr0, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{NewProtocol()}, + }) + if err := s.CreateNIC(1, &stubLinkEndpoint{}); err != nil { + t.Fatalf("CreateNIC(_) = %s", err) + } + + if err := s.AddAddress(1, ProtocolNumber, test.addr); err != nil { + t.Fatalf("AddAddress(_, %d, nil) = %s", ProtocolNumber, err) + } + + addr, err := s.GetMainNICAddress(1, header.IPv6ProtocolNumber) + if err != nil { + t.Fatalf("stack.GetMainNICAddress(_, _) err = %s", err) + } + if addr.Address != test.addr { + t.Fatalf("got stack.GetMainNICAddress(_, _) = %s, want = %s", addr.Address, test.addr) + } + }) + } +} diff --git a/pkg/tcpip/network/ipv6/ndp_test.go b/pkg/tcpip/network/ipv6/ndp_test.go index 8e4cf0e74..0dbce14a0 100644 --- a/pkg/tcpip/network/ipv6/ndp_test.go +++ b/pkg/tcpip/network/ipv6/ndp_test.go @@ -21,6 +21,7 @@ import ( "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/link/channel" "gvisor.dev/gvisor/pkg/tcpip/stack" "gvisor.dev/gvisor/pkg/tcpip/transport/icmp" ) @@ -31,16 +32,18 @@ import ( func setupStackAndEndpoint(t *testing.T, llladdr, rlladdr tcpip.Address) (*stack.Stack, stack.NetworkEndpoint) { t.Helper() - s := stack.New([]string{ProtocolName}, []string{icmp.ProtocolName6}, stack.Options{}) - { - id := stack.RegisterLinkEndpoint(&stubLinkEndpoint{}) - if err := s.CreateNIC(1, id); err != nil { - t.Fatalf("CreateNIC(_) = %s", err) - } - if err := s.AddAddress(1, ProtocolNumber, llladdr); err != nil { - t.Fatalf("AddAddress(_, %d, %s) = %s", ProtocolNumber, llladdr, err) - } + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{NewProtocol()}, + TransportProtocols: []stack.TransportProtocol{icmp.NewProtocol6()}, + }) + + if err := s.CreateNIC(1, &stubLinkEndpoint{}); err != nil { + t.Fatalf("CreateNIC(_) = %s", err) + } + if err := s.AddAddress(1, ProtocolNumber, llladdr); err != nil { + t.Fatalf("AddAddress(_, %d, %s) = %s", ProtocolNumber, llladdr, err) } + { subnet, err := tcpip.NewSubnet(rlladdr, tcpip.AddressMask(strings.Repeat("\xff", len(rlladdr)))) if err != nil { @@ -95,7 +98,9 @@ func TestHopLimitValidation(t *testing.T) { SrcAddr: r.LocalAddress, DstAddr: r.RemoteAddress, }) - ep.HandlePacket(r, hdr.View().ToVectorisedView()) + ep.HandlePacket(r, tcpip.PacketBuffer{ + Data: hdr.View().ToVectorisedView(), + }) } types := []struct { @@ -107,7 +112,7 @@ func TestHopLimitValidation(t *testing.T) { {"RouterSolicit", header.ICMPv6RouterSolicit, header.ICMPv6MinimumSize, func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { return stats.RouterSolicit }}, - {"RouterAdvert", header.ICMPv6RouterAdvert, header.ICMPv6MinimumSize, func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { + {"RouterAdvert", header.ICMPv6RouterAdvert, header.ICMPv6HeaderSize + header.NDPRAMinimumSize, func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { return stats.RouterAdvert }}, {"NeighborSolicit", header.ICMPv6NeighborSolicit, header.ICMPv6NeighborSolicitMinimumSize, func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { @@ -148,7 +153,7 @@ func TestHopLimitValidation(t *testing.T) { // Receive the NDP packet with an invalid hop limit // value. - handleIPv6Payload(hdr, ndpHopLimit-1, ep, &r) + handleIPv6Payload(hdr, header.NDPHopLimit-1, ep, &r) // Invalid count should have increased. if got := invalid.Value(); got != 1 { @@ -162,7 +167,7 @@ func TestHopLimitValidation(t *testing.T) { } // Receive the NDP packet with a valid hop limit value. - handleIPv6Payload(hdr, ndpHopLimit, ep, &r) + handleIPv6Payload(hdr, header.NDPHopLimit, ep, &r) // Rx count of NDP packet of type typ.typ should have // increased. @@ -177,3 +182,191 @@ func TestHopLimitValidation(t *testing.T) { }) } } + +// TestRouterAdvertValidation tests that when the NIC is configured to handle +// NDP Router Advertisement packets, it validates the Router Advertisement +// properly before handling them. +func TestRouterAdvertValidation(t *testing.T) { + tests := []struct { + name string + src tcpip.Address + hopLimit uint8 + code uint8 + ndpPayload []byte + expectedSuccess bool + }{ + { + "OK", + lladdr0, + 255, + 0, + []byte{ + 0, 0, 0, 0, + 0, 0, 0, 0, + 0, 0, 0, 0, + }, + true, + }, + { + "NonLinkLocalSourceAddr", + addr1, + 255, + 0, + []byte{ + 0, 0, 0, 0, + 0, 0, 0, 0, + 0, 0, 0, 0, + }, + false, + }, + { + "HopLimitNot255", + lladdr0, + 254, + 0, + []byte{ + 0, 0, 0, 0, + 0, 0, 0, 0, + 0, 0, 0, 0, + }, + false, + }, + { + "NonZeroCode", + lladdr0, + 255, + 1, + []byte{ + 0, 0, 0, 0, + 0, 0, 0, 0, + 0, 0, 0, 0, + }, + false, + }, + { + "NDPPayloadTooSmall", + lladdr0, + 255, + 0, + []byte{ + 0, 0, 0, 0, + 0, 0, 0, 0, + 0, 0, 0, + }, + false, + }, + { + "OKWithOptions", + lladdr0, + 255, + 0, + []byte{ + // RA payload + 0, 0, 0, 0, + 0, 0, 0, 0, + 0, 0, 0, 0, + + // Option #1 (TargetLinkLayerAddress) + 2, 1, 0, 0, 0, 0, 0, 0, + + // Option #2 (unrecognized) + 255, 1, 0, 0, 0, 0, 0, 0, + + // Option #3 (PrefixInformation) + 3, 4, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + }, + true, + }, + { + "OptionWithZeroLength", + lladdr0, + 255, + 0, + []byte{ + // RA payload + 0, 0, 0, 0, + 0, 0, 0, 0, + 0, 0, 0, 0, + + // Option #1 (TargetLinkLayerAddress) + // Invalid as it has 0 length. + 2, 0, 0, 0, 0, 0, 0, 0, + + // Option #2 (unrecognized) + 255, 1, 0, 0, 0, 0, 0, 0, + + // Option #3 (PrefixInformation) + 3, 4, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + }, + false, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + e := channel.New(10, 1280, linkAddr1) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{NewProtocol()}, + }) + + if err := s.CreateNIC(1, e); err != nil { + t.Fatalf("CreateNIC(_) = %s", err) + } + + icmpSize := header.ICMPv6HeaderSize + len(test.ndpPayload) + hdr := buffer.NewPrependable(header.IPv6MinimumSize + icmpSize) + pkt := header.ICMPv6(hdr.Prepend(icmpSize)) + pkt.SetType(header.ICMPv6RouterAdvert) + pkt.SetCode(test.code) + copy(pkt.NDPPayload(), test.ndpPayload) + payloadLength := hdr.UsedLength() + pkt.SetChecksum(header.ICMPv6Checksum(pkt, test.src, header.IPv6AllNodesMulticastAddress, buffer.VectorisedView{})) + ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) + ip.Encode(&header.IPv6Fields{ + PayloadLength: uint16(payloadLength), + NextHeader: uint8(icmp.ProtocolNumber6), + HopLimit: test.hopLimit, + SrcAddr: test.src, + DstAddr: header.IPv6AllNodesMulticastAddress, + }) + + stats := s.Stats().ICMP.V6PacketsReceived + invalid := stats.Invalid + rxRA := stats.RouterAdvert + + if got := invalid.Value(); got != 0 { + t.Fatalf("got invalid = %d, want = 0", got) + } + if got := rxRA.Value(); got != 0 { + t.Fatalf("got rxRA = %d, want = 0", got) + } + + e.InjectInbound(header.IPv6ProtocolNumber, tcpip.PacketBuffer{ + Data: hdr.View().ToVectorisedView(), + }) + + if test.expectedSuccess { + if got := invalid.Value(); got != 0 { + t.Fatalf("got invalid = %d, want = 0", got) + } + if got := rxRA.Value(); got != 1 { + t.Fatalf("got rxRA = %d, want = 1", got) + } + + } else { + if got := invalid.Value(); got != 1 { + t.Fatalf("got invalid = %d, want = 1", got) + } + if got := rxRA.Value(); got != 0 { + t.Fatalf("got rxRA = %d, want = 0", got) + } + } + }) + } +} |