diff options
Diffstat (limited to 'pkg/tcpip/network')
24 files changed, 7471 insertions, 0 deletions
diff --git a/pkg/tcpip/network/BUILD b/pkg/tcpip/network/BUILD new file mode 100644 index 000000000..6a4839fb8 --- /dev/null +++ b/pkg/tcpip/network/BUILD @@ -0,0 +1,22 @@ +load("//tools:defs.bzl", "go_test") + +package(licenses = ["notice"]) + +go_test( + name = "ip_test", + size = "small", + srcs = [ + "ip_test.go", + ], + deps = [ + "//pkg/tcpip", + "//pkg/tcpip/buffer", + "//pkg/tcpip/header", + "//pkg/tcpip/link/loopback", + "//pkg/tcpip/network/ipv4", + "//pkg/tcpip/network/ipv6", + "//pkg/tcpip/stack", + "//pkg/tcpip/transport/tcp", + "//pkg/tcpip/transport/udp", + ], +) diff --git a/pkg/tcpip/network/arp/BUILD b/pkg/tcpip/network/arp/BUILD new file mode 100644 index 000000000..eddf7b725 --- /dev/null +++ b/pkg/tcpip/network/arp/BUILD @@ -0,0 +1,32 @@ +load("//tools:defs.bzl", "go_library", "go_test") + +package(licenses = ["notice"]) + +go_library( + name = "arp", + srcs = ["arp.go"], + visibility = ["//visibility:public"], + deps = [ + "//pkg/tcpip", + "//pkg/tcpip/buffer", + "//pkg/tcpip/header", + "//pkg/tcpip/stack", + ], +) + +go_test( + name = "arp_test", + size = "small", + srcs = ["arp_test.go"], + deps = [ + ":arp", + "//pkg/tcpip", + "//pkg/tcpip/buffer", + "//pkg/tcpip/header", + "//pkg/tcpip/link/channel", + "//pkg/tcpip/link/sniffer", + "//pkg/tcpip/network/ipv4", + "//pkg/tcpip/stack", + "//pkg/tcpip/transport/icmp", + ], +) diff --git a/pkg/tcpip/network/arp/arp.go b/pkg/tcpip/network/arp/arp.go new file mode 100644 index 000000000..9d0797af7 --- /dev/null +++ b/pkg/tcpip/network/arp/arp.go @@ -0,0 +1,217 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package arp implements the ARP network protocol. It is used to resolve +// IPv4 addresses into link-local MAC addresses, and advertises IPv4 +// addresses of its stack with the local network. +// +// To use it in the networking stack, pass arp.NewProtocol() as one of the +// network protocols when calling stack.New. Then add an "arp" address to every +// NIC on the stack that should respond to ARP requests. That is: +// +// if err := s.AddAddress(1, arp.ProtocolNumber, "arp"); err != nil { +// // handle err +// } +package arp + +import ( + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/stack" +) + +const ( + // ProtocolNumber is the ARP protocol number. + ProtocolNumber = header.ARPProtocolNumber + + // ProtocolAddress is the address expected by the ARP endpoint. + ProtocolAddress = tcpip.Address("arp") +) + +// endpoint implements stack.NetworkEndpoint. +type endpoint struct { + protocol *protocol + nicID tcpip.NICID + linkEP stack.LinkEndpoint + linkAddrCache stack.LinkAddressCache +} + +// DefaultTTL is unused for ARP. It implements stack.NetworkEndpoint. +func (e *endpoint) DefaultTTL() uint8 { + return 0 +} + +func (e *endpoint) MTU() uint32 { + lmtu := e.linkEP.MTU() + return lmtu - uint32(e.MaxHeaderLength()) +} + +func (e *endpoint) NICID() tcpip.NICID { + return e.nicID +} + +func (e *endpoint) Capabilities() stack.LinkEndpointCapabilities { + return e.linkEP.Capabilities() +} + +func (e *endpoint) ID() *stack.NetworkEndpointID { + return &stack.NetworkEndpointID{ProtocolAddress} +} + +func (e *endpoint) PrefixLen() int { + return 0 +} + +func (e *endpoint) MaxHeaderLength() uint16 { + return e.linkEP.MaxHeaderLength() + header.ARPSize +} + +func (e *endpoint) Close() {} + +func (e *endpoint) WritePacket(*stack.Route, *stack.GSO, stack.NetworkHeaderParams, stack.PacketBuffer) *tcpip.Error { + return tcpip.ErrNotSupported +} + +// NetworkProtocolNumber implements stack.NetworkEndpoint.NetworkProtocolNumber. +func (e *endpoint) NetworkProtocolNumber() tcpip.NetworkProtocolNumber { + return e.protocol.Number() +} + +// WritePackets implements stack.NetworkEndpoint.WritePackets. +func (e *endpoint) WritePackets(*stack.Route, *stack.GSO, stack.PacketBufferList, stack.NetworkHeaderParams) (int, *tcpip.Error) { + return 0, tcpip.ErrNotSupported +} + +func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt stack.PacketBuffer) *tcpip.Error { + return tcpip.ErrNotSupported +} + +func (e *endpoint) HandlePacket(r *stack.Route, pkt stack.PacketBuffer) { + v, ok := pkt.Data.PullUp(header.ARPSize) + if !ok { + return + } + h := header.ARP(v) + if !h.IsValid() { + return + } + + switch h.Op() { + case header.ARPRequest: + localAddr := tcpip.Address(h.ProtocolAddressTarget()) + if e.linkAddrCache.CheckLocalAddress(e.nicID, header.IPv4ProtocolNumber, localAddr) == 0 { + return // we have no useful answer, ignore the request + } + hdr := buffer.NewPrependable(int(e.linkEP.MaxHeaderLength()) + header.ARPSize) + packet := header.ARP(hdr.Prepend(header.ARPSize)) + packet.SetIPv4OverEthernet() + packet.SetOp(header.ARPReply) + copy(packet.HardwareAddressSender(), r.LocalLinkAddress[:]) + copy(packet.ProtocolAddressSender(), h.ProtocolAddressTarget()) + copy(packet.HardwareAddressTarget(), h.HardwareAddressSender()) + copy(packet.ProtocolAddressTarget(), h.ProtocolAddressSender()) + e.linkEP.WritePacket(r, nil /* gso */, ProtocolNumber, stack.PacketBuffer{ + Header: hdr, + }) + fallthrough // also fill the cache from requests + case header.ARPReply: + addr := tcpip.Address(h.ProtocolAddressSender()) + linkAddr := tcpip.LinkAddress(h.HardwareAddressSender()) + e.linkAddrCache.AddLinkAddress(e.nicID, addr, linkAddr) + } +} + +// protocol implements stack.NetworkProtocol and stack.LinkAddressResolver. +type protocol struct { +} + +func (p *protocol) Number() tcpip.NetworkProtocolNumber { return ProtocolNumber } +func (p *protocol) MinimumPacketSize() int { return header.ARPSize } +func (p *protocol) DefaultPrefixLen() int { return 0 } + +func (*protocol) ParseAddresses(v buffer.View) (src, dst tcpip.Address) { + h := header.ARP(v) + return tcpip.Address(h.ProtocolAddressSender()), ProtocolAddress +} + +func (p *protocol) NewEndpoint(nicID tcpip.NICID, addrWithPrefix tcpip.AddressWithPrefix, linkAddrCache stack.LinkAddressCache, dispatcher stack.TransportDispatcher, sender stack.LinkEndpoint, st *stack.Stack) (stack.NetworkEndpoint, *tcpip.Error) { + if addrWithPrefix.Address != ProtocolAddress { + return nil, tcpip.ErrBadLocalAddress + } + return &endpoint{ + protocol: p, + nicID: nicID, + linkEP: sender, + linkAddrCache: linkAddrCache, + }, nil +} + +// LinkAddressProtocol implements stack.LinkAddressResolver.LinkAddressProtocol. +func (*protocol) LinkAddressProtocol() tcpip.NetworkProtocolNumber { + return header.IPv4ProtocolNumber +} + +// LinkAddressRequest implements stack.LinkAddressResolver.LinkAddressRequest. +func (*protocol) LinkAddressRequest(addr, localAddr tcpip.Address, linkEP stack.LinkEndpoint) *tcpip.Error { + r := &stack.Route{ + RemoteLinkAddress: broadcastMAC, + } + + hdr := buffer.NewPrependable(int(linkEP.MaxHeaderLength()) + header.ARPSize) + h := header.ARP(hdr.Prepend(header.ARPSize)) + h.SetIPv4OverEthernet() + h.SetOp(header.ARPRequest) + copy(h.HardwareAddressSender(), linkEP.LinkAddress()) + copy(h.ProtocolAddressSender(), localAddr) + copy(h.ProtocolAddressTarget(), addr) + + return linkEP.WritePacket(r, nil /* gso */, ProtocolNumber, stack.PacketBuffer{ + Header: hdr, + }) +} + +// ResolveStaticAddress implements stack.LinkAddressResolver.ResolveStaticAddress. +func (*protocol) ResolveStaticAddress(addr tcpip.Address) (tcpip.LinkAddress, bool) { + if addr == header.IPv4Broadcast { + return broadcastMAC, true + } + if header.IsV4MulticastAddress(addr) { + return header.EthernetAddressFromMulticastIPv4Address(addr), true + } + return tcpip.LinkAddress([]byte(nil)), false +} + +// SetOption implements stack.NetworkProtocol.SetOption. +func (*protocol) SetOption(option interface{}) *tcpip.Error { + return tcpip.ErrUnknownProtocolOption +} + +// Option implements stack.NetworkProtocol.Option. +func (*protocol) Option(option interface{}) *tcpip.Error { + return tcpip.ErrUnknownProtocolOption +} + +// Close implements stack.TransportProtocol.Close. +func (*protocol) Close() {} + +// Wait implements stack.TransportProtocol.Wait. +func (*protocol) Wait() {} + +var broadcastMAC = tcpip.LinkAddress([]byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff}) + +// NewProtocol returns an ARP network protocol. +func NewProtocol() stack.NetworkProtocol { + return &protocol{} +} diff --git a/pkg/tcpip/network/arp/arp_test.go b/pkg/tcpip/network/arp/arp_test.go new file mode 100644 index 000000000..1646d9cde --- /dev/null +++ b/pkg/tcpip/network/arp/arp_test.go @@ -0,0 +1,146 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package arp_test + +import ( + "context" + "strconv" + "testing" + "time" + + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/link/channel" + "gvisor.dev/gvisor/pkg/tcpip/link/sniffer" + "gvisor.dev/gvisor/pkg/tcpip/network/arp" + "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" + "gvisor.dev/gvisor/pkg/tcpip/stack" + "gvisor.dev/gvisor/pkg/tcpip/transport/icmp" +) + +const ( + stackLinkAddr = tcpip.LinkAddress("\x0a\x0a\x0b\x0b\x0c\x0c") + stackAddr1 = tcpip.Address("\x0a\x00\x00\x01") + stackAddr2 = tcpip.Address("\x0a\x00\x00\x02") + stackAddrBad = tcpip.Address("\x0a\x00\x00\x03") +) + +type testContext struct { + t *testing.T + linkEP *channel.Endpoint + s *stack.Stack +} + +func newTestContext(t *testing.T) *testContext { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol(), arp.NewProtocol()}, + TransportProtocols: []stack.TransportProtocol{icmp.NewProtocol4()}, + }) + + const defaultMTU = 65536 + ep := channel.New(256, defaultMTU, stackLinkAddr) + wep := stack.LinkEndpoint(ep) + + if testing.Verbose() { + wep = sniffer.New(ep) + } + if err := s.CreateNIC(1, wep); err != nil { + t.Fatalf("CreateNIC failed: %v", err) + } + + if err := s.AddAddress(1, ipv4.ProtocolNumber, stackAddr1); err != nil { + t.Fatalf("AddAddress for ipv4 failed: %v", err) + } + if err := s.AddAddress(1, ipv4.ProtocolNumber, stackAddr2); err != nil { + t.Fatalf("AddAddress for ipv4 failed: %v", err) + } + if err := s.AddAddress(1, arp.ProtocolNumber, arp.ProtocolAddress); err != nil { + t.Fatalf("AddAddress for arp failed: %v", err) + } + + s.SetRouteTable([]tcpip.Route{{ + Destination: header.IPv4EmptySubnet, + NIC: 1, + }}) + + return &testContext{ + t: t, + s: s, + linkEP: ep, + } +} + +func (c *testContext) cleanup() { + c.linkEP.Close() +} + +func TestDirectRequest(t *testing.T) { + c := newTestContext(t) + defer c.cleanup() + + const senderMAC = "\x01\x02\x03\x04\x05\x06" + const senderIPv4 = "\x0a\x00\x00\x02" + + v := make(buffer.View, header.ARPSize) + h := header.ARP(v) + h.SetIPv4OverEthernet() + h.SetOp(header.ARPRequest) + copy(h.HardwareAddressSender(), senderMAC) + copy(h.ProtocolAddressSender(), senderIPv4) + + inject := func(addr tcpip.Address) { + copy(h.ProtocolAddressTarget(), addr) + c.linkEP.InjectInbound(arp.ProtocolNumber, stack.PacketBuffer{ + Data: v.ToVectorisedView(), + }) + } + + for i, address := range []tcpip.Address{stackAddr1, stackAddr2} { + t.Run(strconv.Itoa(i), func(t *testing.T) { + inject(address) + pi, _ := c.linkEP.ReadContext(context.Background()) + if pi.Proto != arp.ProtocolNumber { + t.Fatalf("expected ARP response, got network protocol number %d", pi.Proto) + } + rep := header.ARP(pi.Pkt.Header.View()) + if !rep.IsValid() { + t.Fatalf("invalid ARP response pi.Pkt.Header.UsedLength()=%d", pi.Pkt.Header.UsedLength()) + } + if got, want := tcpip.LinkAddress(rep.HardwareAddressSender()), stackLinkAddr; got != want { + t.Errorf("got HardwareAddressSender = %s, want = %s", got, want) + } + if got, want := tcpip.Address(rep.ProtocolAddressSender()), tcpip.Address(h.ProtocolAddressTarget()); got != want { + t.Errorf("got ProtocolAddressSender = %s, want = %s", got, want) + } + if got, want := tcpip.LinkAddress(rep.HardwareAddressTarget()), tcpip.LinkAddress(h.HardwareAddressSender()); got != want { + t.Errorf("got HardwareAddressTarget = %s, want = %s", got, want) + } + if got, want := tcpip.Address(rep.ProtocolAddressTarget()), tcpip.Address(h.ProtocolAddressSender()); got != want { + t.Errorf("got ProtocolAddressTarget = %s, want = %s", got, want) + } + }) + } + + inject(stackAddrBad) + // Sleep tests are gross, but this will only potentially flake + // if there's a bug. If there is no bug this will reliably + // succeed. + ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) + defer cancel() + if pkt, ok := c.linkEP.ReadContext(ctx); ok { + t.Errorf("stackAddrBad: unexpected packet sent, Proto=%v", pkt.Proto) + } +} diff --git a/pkg/tcpip/network/fragmentation/BUILD b/pkg/tcpip/network/fragmentation/BUILD new file mode 100644 index 000000000..d1c728ccf --- /dev/null +++ b/pkg/tcpip/network/fragmentation/BUILD @@ -0,0 +1,45 @@ +load("//tools:defs.bzl", "go_library", "go_test") +load("//tools/go_generics:defs.bzl", "go_template_instance") + +package(licenses = ["notice"]) + +go_template_instance( + name = "reassembler_list", + out = "reassembler_list.go", + package = "fragmentation", + prefix = "reassembler", + template = "//pkg/ilist:generic_list", + types = { + "Element": "*reassembler", + "Linker": "*reassembler", + }, +) + +go_library( + name = "fragmentation", + srcs = [ + "frag_heap.go", + "fragmentation.go", + "reassembler.go", + "reassembler_list.go", + ], + visibility = ["//visibility:public"], + deps = [ + "//pkg/log", + "//pkg/sync", + "//pkg/tcpip", + "//pkg/tcpip/buffer", + ], +) + +go_test( + name = "fragmentation_test", + size = "small", + srcs = [ + "frag_heap_test.go", + "fragmentation_test.go", + "reassembler_test.go", + ], + library = ":fragmentation", + deps = ["//pkg/tcpip/buffer"], +) diff --git a/pkg/tcpip/network/fragmentation/frag_heap.go b/pkg/tcpip/network/fragmentation/frag_heap.go new file mode 100644 index 000000000..0b570d25a --- /dev/null +++ b/pkg/tcpip/network/fragmentation/frag_heap.go @@ -0,0 +1,77 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package fragmentation + +import ( + "container/heap" + "fmt" + + "gvisor.dev/gvisor/pkg/tcpip/buffer" +) + +type fragment struct { + offset uint16 + vv buffer.VectorisedView +} + +type fragHeap []fragment + +func (h *fragHeap) Len() int { + return len(*h) +} + +func (h *fragHeap) Less(i, j int) bool { + return (*h)[i].offset < (*h)[j].offset +} + +func (h *fragHeap) Swap(i, j int) { + (*h)[i], (*h)[j] = (*h)[j], (*h)[i] +} + +func (h *fragHeap) Push(x interface{}) { + *h = append(*h, x.(fragment)) +} + +func (h *fragHeap) Pop() interface{} { + old := *h + n := len(old) + x := old[n-1] + *h = old[:n-1] + return x +} + +// reassamble empties the heap and returns a VectorisedView +// containing a reassambled version of the fragments inside the heap. +func (h *fragHeap) reassemble() (buffer.VectorisedView, error) { + curr := heap.Pop(h).(fragment) + views := curr.vv.Views() + size := curr.vv.Size() + + if curr.offset != 0 { + return buffer.VectorisedView{}, fmt.Errorf("offset of the first packet is != 0 (%d)", curr.offset) + } + + for h.Len() > 0 { + curr := heap.Pop(h).(fragment) + if int(curr.offset) < size { + curr.vv.TrimFront(size - int(curr.offset)) + } else if int(curr.offset) > size { + return buffer.VectorisedView{}, fmt.Errorf("packet has a hole, expected offset %d, got %d", size, curr.offset) + } + size += curr.vv.Size() + views = append(views, curr.vv.Views()...) + } + return buffer.NewVectorisedView(size, views), nil +} diff --git a/pkg/tcpip/network/fragmentation/frag_heap_test.go b/pkg/tcpip/network/fragmentation/frag_heap_test.go new file mode 100644 index 000000000..9ececcb9f --- /dev/null +++ b/pkg/tcpip/network/fragmentation/frag_heap_test.go @@ -0,0 +1,126 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package fragmentation + +import ( + "container/heap" + "reflect" + "testing" + + "gvisor.dev/gvisor/pkg/tcpip/buffer" +) + +var reassambleTestCases = []struct { + comment string + in []fragment + want buffer.VectorisedView +}{ + { + comment: "Non-overlapping in-order", + in: []fragment{ + {offset: 0, vv: vv(1, "0")}, + {offset: 1, vv: vv(1, "1")}, + }, + want: vv(2, "0", "1"), + }, + { + comment: "Non-overlapping out-of-order", + in: []fragment{ + {offset: 1, vv: vv(1, "1")}, + {offset: 0, vv: vv(1, "0")}, + }, + want: vv(2, "0", "1"), + }, + { + comment: "Duplicated packets", + in: []fragment{ + {offset: 0, vv: vv(1, "0")}, + {offset: 0, vv: vv(1, "0")}, + }, + want: vv(1, "0"), + }, + { + comment: "Overlapping in-order", + in: []fragment{ + {offset: 0, vv: vv(2, "01")}, + {offset: 1, vv: vv(2, "12")}, + }, + want: vv(3, "01", "2"), + }, + { + comment: "Overlapping out-of-order", + in: []fragment{ + {offset: 1, vv: vv(2, "12")}, + {offset: 0, vv: vv(2, "01")}, + }, + want: vv(3, "01", "2"), + }, + { + comment: "Overlapping subset in-order", + in: []fragment{ + {offset: 0, vv: vv(3, "012")}, + {offset: 1, vv: vv(1, "1")}, + }, + want: vv(3, "012"), + }, + { + comment: "Overlapping subset out-of-order", + in: []fragment{ + {offset: 1, vv: vv(1, "1")}, + {offset: 0, vv: vv(3, "012")}, + }, + want: vv(3, "012"), + }, +} + +func TestReassamble(t *testing.T) { + for _, c := range reassambleTestCases { + t.Run(c.comment, func(t *testing.T) { + h := make(fragHeap, 0, 8) + heap.Init(&h) + for _, f := range c.in { + heap.Push(&h, f) + } + got, err := h.reassemble() + if err != nil { + t.Fatal(err) + } + if !reflect.DeepEqual(got, c.want) { + t.Errorf("got reassemble(%+v) = %v, want = %v", c.in, got, c.want) + } + }) + } +} + +func TestReassambleFailsForNonZeroOffset(t *testing.T) { + h := make(fragHeap, 0, 8) + heap.Init(&h) + heap.Push(&h, fragment{offset: 1, vv: vv(1, "0")}) + _, err := h.reassemble() + if err == nil { + t.Errorf("reassemble() did not fail when the first packet had offset != 0") + } +} + +func TestReassambleFailsForHoles(t *testing.T) { + h := make(fragHeap, 0, 8) + heap.Init(&h) + heap.Push(&h, fragment{offset: 0, vv: vv(1, "0")}) + heap.Push(&h, fragment{offset: 2, vv: vv(1, "1")}) + _, err := h.reassemble() + if err == nil { + t.Errorf("reassemble() did not fail when there was a hole in the packet") + } +} diff --git a/pkg/tcpip/network/fragmentation/fragmentation.go b/pkg/tcpip/network/fragmentation/fragmentation.go new file mode 100644 index 000000000..f42abc4bb --- /dev/null +++ b/pkg/tcpip/network/fragmentation/fragmentation.go @@ -0,0 +1,144 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package fragmentation contains the implementation of IP fragmentation. +// It is based on RFC 791 and RFC 815. +package fragmentation + +import ( + "fmt" + "log" + "time" + + "gvisor.dev/gvisor/pkg/sync" + "gvisor.dev/gvisor/pkg/tcpip/buffer" +) + +// DefaultReassembleTimeout is based on the linux stack: net.ipv4.ipfrag_time. +const DefaultReassembleTimeout = 30 * time.Second + +// HighFragThreshold is the threshold at which we start trimming old +// fragmented packets. Linux uses a default value of 4 MB. See +// net.ipv4.ipfrag_high_thresh for more information. +const HighFragThreshold = 4 << 20 // 4MB + +// LowFragThreshold is the threshold we reach to when we start dropping +// older fragmented packets. It's important that we keep enough room for newer +// packets to be re-assembled. Hence, this needs to be lower than +// HighFragThreshold enough. Linux uses a default value of 3 MB. See +// net.ipv4.ipfrag_low_thresh for more information. +const LowFragThreshold = 3 << 20 // 3MB + +// Fragmentation is the main structure that other modules +// of the stack should use to implement IP Fragmentation. +type Fragmentation struct { + mu sync.Mutex + highLimit int + lowLimit int + reassemblers map[uint32]*reassembler + rList reassemblerList + size int + timeout time.Duration +} + +// NewFragmentation creates a new Fragmentation. +// +// highMemoryLimit specifies the limit on the memory consumed +// by the fragments stored by Fragmentation (overhead of internal data-structures +// is not accounted). Fragments are dropped when the limit is reached. +// +// lowMemoryLimit specifies the limit on which we will reach by dropping +// fragments after reaching highMemoryLimit. +// +// reassemblingTimeout specifies the maximum time allowed to reassemble a packet. +// Fragments are lazily evicted only when a new a packet with an +// already existing fragmentation-id arrives after the timeout. +func NewFragmentation(highMemoryLimit, lowMemoryLimit int, reassemblingTimeout time.Duration) *Fragmentation { + if lowMemoryLimit >= highMemoryLimit { + lowMemoryLimit = highMemoryLimit + } + + if lowMemoryLimit < 0 { + lowMemoryLimit = 0 + } + + return &Fragmentation{ + reassemblers: make(map[uint32]*reassembler), + highLimit: highMemoryLimit, + lowLimit: lowMemoryLimit, + timeout: reassemblingTimeout, + } +} + +// Process processes an incoming fragment belonging to an ID +// and returns a complete packet when all the packets belonging to that ID have been received. +func (f *Fragmentation) Process(id uint32, first, last uint16, more bool, vv buffer.VectorisedView) (buffer.VectorisedView, bool, error) { + f.mu.Lock() + r, ok := f.reassemblers[id] + if ok && r.tooOld(f.timeout) { + // This is very likely to be an id-collision or someone performing a slow-rate attack. + f.release(r) + ok = false + } + if !ok { + r = newReassembler(id) + f.reassemblers[id] = r + f.rList.PushFront(r) + } + f.mu.Unlock() + + res, done, consumed, err := r.process(first, last, more, vv) + if err != nil { + // We probably got an invalid sequence of fragments. Just + // discard the reassembler and move on. + f.mu.Lock() + f.release(r) + f.mu.Unlock() + return buffer.VectorisedView{}, false, fmt.Errorf("fragmentation processing error: %v", err) + } + f.mu.Lock() + f.size += consumed + if done { + f.release(r) + } + // Evict reassemblers if we are consuming more memory than highLimit until + // we reach lowLimit. + if f.size > f.highLimit { + for f.size > f.lowLimit { + tail := f.rList.Back() + if tail == nil { + break + } + f.release(tail) + } + } + f.mu.Unlock() + return res, done, nil +} + +func (f *Fragmentation) release(r *reassembler) { + // Before releasing a fragment we need to check if r is already marked as done. + // Otherwise, we would delete it twice. + if r.checkDoneOrMark() { + return + } + + delete(f.reassemblers, r.id) + f.rList.Remove(r) + f.size -= r.size + if f.size < 0 { + log.Printf("memory counter < 0 (%d), this is an accounting bug that requires investigation", f.size) + f.size = 0 + } +} diff --git a/pkg/tcpip/network/fragmentation/fragmentation_test.go b/pkg/tcpip/network/fragmentation/fragmentation_test.go new file mode 100644 index 000000000..72c0f53be --- /dev/null +++ b/pkg/tcpip/network/fragmentation/fragmentation_test.go @@ -0,0 +1,165 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package fragmentation + +import ( + "reflect" + "testing" + "time" + + "gvisor.dev/gvisor/pkg/tcpip/buffer" +) + +// vv is a helper to build VectorisedView from different strings. +func vv(size int, pieces ...string) buffer.VectorisedView { + views := make([]buffer.View, len(pieces)) + for i, p := range pieces { + views[i] = []byte(p) + } + + return buffer.NewVectorisedView(size, views) +} + +type processInput struct { + id uint32 + first uint16 + last uint16 + more bool + vv buffer.VectorisedView +} + +type processOutput struct { + vv buffer.VectorisedView + done bool +} + +var processTestCases = []struct { + comment string + in []processInput + out []processOutput +}{ + { + comment: "One ID", + in: []processInput{ + {id: 0, first: 0, last: 1, more: true, vv: vv(2, "01")}, + {id: 0, first: 2, last: 3, more: false, vv: vv(2, "23")}, + }, + out: []processOutput{ + {vv: buffer.VectorisedView{}, done: false}, + {vv: vv(4, "01", "23"), done: true}, + }, + }, + { + comment: "Two IDs", + in: []processInput{ + {id: 0, first: 0, last: 1, more: true, vv: vv(2, "01")}, + {id: 1, first: 0, last: 1, more: true, vv: vv(2, "ab")}, + {id: 1, first: 2, last: 3, more: false, vv: vv(2, "cd")}, + {id: 0, first: 2, last: 3, more: false, vv: vv(2, "23")}, + }, + out: []processOutput{ + {vv: buffer.VectorisedView{}, done: false}, + {vv: buffer.VectorisedView{}, done: false}, + {vv: vv(4, "ab", "cd"), done: true}, + {vv: vv(4, "01", "23"), done: true}, + }, + }, +} + +func TestFragmentationProcess(t *testing.T) { + for _, c := range processTestCases { + t.Run(c.comment, func(t *testing.T) { + f := NewFragmentation(1024, 512, DefaultReassembleTimeout) + for i, in := range c.in { + vv, done, err := f.Process(in.id, in.first, in.last, in.more, in.vv) + if err != nil { + t.Fatalf("f.Process(%+v, %+d, %+d, %t, %+v) failed: %v", in.id, in.first, in.last, in.more, in.vv, err) + } + if !reflect.DeepEqual(vv, c.out[i].vv) { + t.Errorf("got Process(%d) = %+v, want = %+v", i, vv, c.out[i].vv) + } + if done != c.out[i].done { + t.Errorf("got Process(%d) = %+v, want = %+v", i, done, c.out[i].done) + } + if c.out[i].done { + if _, ok := f.reassemblers[in.id]; ok { + t.Errorf("Process(%d) did not remove buffer from reassemblers", i) + } + for n := f.rList.Front(); n != nil; n = n.Next() { + if n.id == in.id { + t.Errorf("Process(%d) did not remove buffer from rList", i) + } + } + } + } + }) + } +} + +func TestReassemblingTimeout(t *testing.T) { + timeout := time.Millisecond + f := NewFragmentation(1024, 512, timeout) + // Send first fragment with id = 0, first = 0, last = 0, and more = true. + f.Process(0, 0, 0, true, vv(1, "0")) + // Sleep more than the timeout. + time.Sleep(2 * timeout) + // Send another fragment that completes a packet. + // However, no packet should be reassembled because the fragment arrived after the timeout. + _, done, err := f.Process(0, 1, 1, false, vv(1, "1")) + if err != nil { + t.Fatalf("f.Process(0, 1, 1, false, vv(1, \"1\")) failed: %v", err) + } + if done { + t.Errorf("Fragmentation does not respect the reassembling timeout.") + } +} + +func TestMemoryLimits(t *testing.T) { + f := NewFragmentation(3, 1, DefaultReassembleTimeout) + // Send first fragment with id = 0. + f.Process(0, 0, 0, true, vv(1, "0")) + // Send first fragment with id = 1. + f.Process(1, 0, 0, true, vv(1, "1")) + // Send first fragment with id = 2. + f.Process(2, 0, 0, true, vv(1, "2")) + + // Send first fragment with id = 3. This should caused id = 0 and id = 1 to be + // evicted. + f.Process(3, 0, 0, true, vv(1, "3")) + + if _, ok := f.reassemblers[0]; ok { + t.Errorf("Memory limits are not respected: id=0 has not been evicted.") + } + if _, ok := f.reassemblers[1]; ok { + t.Errorf("Memory limits are not respected: id=1 has not been evicted.") + } + if _, ok := f.reassemblers[3]; !ok { + t.Errorf("Implementation of memory limits is wrong: id=3 is not present.") + } +} + +func TestMemoryLimitsIgnoresDuplicates(t *testing.T) { + f := NewFragmentation(1, 0, DefaultReassembleTimeout) + // Send first fragment with id = 0. + f.Process(0, 0, 0, true, vv(1, "0")) + // Send the same packet again. + f.Process(0, 0, 0, true, vv(1, "0")) + + got := f.size + want := 1 + if got != want { + t.Errorf("Wrong size, duplicates are not handled correctly: got=%d, want=%d.", got, want) + } +} diff --git a/pkg/tcpip/network/fragmentation/reassembler.go b/pkg/tcpip/network/fragmentation/reassembler.go new file mode 100644 index 000000000..0a83d81f2 --- /dev/null +++ b/pkg/tcpip/network/fragmentation/reassembler.go @@ -0,0 +1,118 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package fragmentation + +import ( + "container/heap" + "fmt" + "math" + "time" + + "gvisor.dev/gvisor/pkg/sync" + "gvisor.dev/gvisor/pkg/tcpip/buffer" +) + +type hole struct { + first uint16 + last uint16 + deleted bool +} + +type reassembler struct { + reassemblerEntry + id uint32 + size int + mu sync.Mutex + holes []hole + deleted int + heap fragHeap + done bool + creationTime time.Time +} + +func newReassembler(id uint32) *reassembler { + r := &reassembler{ + id: id, + holes: make([]hole, 0, 16), + deleted: 0, + heap: make(fragHeap, 0, 8), + creationTime: time.Now(), + } + r.holes = append(r.holes, hole{ + first: 0, + last: math.MaxUint16, + deleted: false}) + return r +} + +// updateHoles updates the list of holes for an incoming fragment and +// returns true iff the fragment filled at least part of an existing hole. +func (r *reassembler) updateHoles(first, last uint16, more bool) bool { + used := false + for i := range r.holes { + if r.holes[i].deleted || first > r.holes[i].last || last < r.holes[i].first { + continue + } + used = true + r.deleted++ + r.holes[i].deleted = true + if first > r.holes[i].first { + r.holes = append(r.holes, hole{r.holes[i].first, first - 1, false}) + } + if last < r.holes[i].last && more { + r.holes = append(r.holes, hole{last + 1, r.holes[i].last, false}) + } + } + return used +} + +func (r *reassembler) process(first, last uint16, more bool, vv buffer.VectorisedView) (buffer.VectorisedView, bool, int, error) { + r.mu.Lock() + defer r.mu.Unlock() + consumed := 0 + if r.done { + // A concurrent goroutine might have already reassembled + // the packet and emptied the heap while this goroutine + // was waiting on the mutex. We don't have to do anything in this case. + return buffer.VectorisedView{}, false, consumed, nil + } + if r.updateHoles(first, last, more) { + // We store the incoming packet only if it filled some holes. + heap.Push(&r.heap, fragment{offset: first, vv: vv.Clone(nil)}) + consumed = vv.Size() + r.size += consumed + } + // Check if all the holes have been deleted and we are ready to reassamble. + if r.deleted < len(r.holes) { + return buffer.VectorisedView{}, false, consumed, nil + } + res, err := r.heap.reassemble() + if err != nil { + return buffer.VectorisedView{}, false, consumed, fmt.Errorf("fragment reassembly failed: %v", err) + } + return res, true, consumed, nil +} + +func (r *reassembler) tooOld(timeout time.Duration) bool { + return time.Now().Sub(r.creationTime) > timeout +} + +func (r *reassembler) checkDoneOrMark() bool { + r.mu.Lock() + prev := r.done + r.done = true + r.mu.Unlock() + return prev +} diff --git a/pkg/tcpip/network/fragmentation/reassembler_test.go b/pkg/tcpip/network/fragmentation/reassembler_test.go new file mode 100644 index 000000000..7eee0710d --- /dev/null +++ b/pkg/tcpip/network/fragmentation/reassembler_test.go @@ -0,0 +1,105 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package fragmentation + +import ( + "math" + "reflect" + "testing" +) + +type updateHolesInput struct { + first uint16 + last uint16 + more bool +} + +var holesTestCases = []struct { + comment string + in []updateHolesInput + want []hole +}{ + { + comment: "No fragments. Expected holes: {[0 -> inf]}.", + in: []updateHolesInput{}, + want: []hole{{first: 0, last: math.MaxUint16, deleted: false}}, + }, + { + comment: "One fragment at beginning. Expected holes: {[2, inf]}.", + in: []updateHolesInput{{first: 0, last: 1, more: true}}, + want: []hole{ + {first: 0, last: math.MaxUint16, deleted: true}, + {first: 2, last: math.MaxUint16, deleted: false}, + }, + }, + { + comment: "One fragment in the middle. Expected holes: {[0, 0], [3, inf]}.", + in: []updateHolesInput{{first: 1, last: 2, more: true}}, + want: []hole{ + {first: 0, last: math.MaxUint16, deleted: true}, + {first: 0, last: 0, deleted: false}, + {first: 3, last: math.MaxUint16, deleted: false}, + }, + }, + { + comment: "One fragment at the end. Expected holes: {[0, 0]}.", + in: []updateHolesInput{{first: 1, last: 2, more: false}}, + want: []hole{ + {first: 0, last: math.MaxUint16, deleted: true}, + {first: 0, last: 0, deleted: false}, + }, + }, + { + comment: "One fragment completing a packet. Expected holes: {}.", + in: []updateHolesInput{{first: 0, last: 1, more: false}}, + want: []hole{ + {first: 0, last: math.MaxUint16, deleted: true}, + }, + }, + { + comment: "Two non-overlapping fragments completing a packet. Expected holes: {}.", + in: []updateHolesInput{ + {first: 0, last: 1, more: true}, + {first: 2, last: 3, more: false}, + }, + want: []hole{ + {first: 0, last: math.MaxUint16, deleted: true}, + {first: 2, last: math.MaxUint16, deleted: true}, + }, + }, + { + comment: "Two overlapping fragments completing a packet. Expected holes: {}.", + in: []updateHolesInput{ + {first: 0, last: 2, more: true}, + {first: 2, last: 3, more: false}, + }, + want: []hole{ + {first: 0, last: math.MaxUint16, deleted: true}, + {first: 3, last: math.MaxUint16, deleted: true}, + }, + }, +} + +func TestUpdateHoles(t *testing.T) { + for _, c := range holesTestCases { + r := newReassembler(0) + for _, i := range c.in { + r.updateHoles(i.first, i.last, i.more) + } + if !reflect.DeepEqual(r.holes, c.want) { + t.Errorf("Test \"%s\" produced unexepetced holes. Got %v. Want %v", c.comment, r.holes, c.want) + } + } +} diff --git a/pkg/tcpip/network/hash/BUILD b/pkg/tcpip/network/hash/BUILD new file mode 100644 index 000000000..872165866 --- /dev/null +++ b/pkg/tcpip/network/hash/BUILD @@ -0,0 +1,13 @@ +load("//tools:defs.bzl", "go_library") + +package(licenses = ["notice"]) + +go_library( + name = "hash", + srcs = ["hash.go"], + visibility = ["//visibility:public"], + deps = [ + "//pkg/rand", + "//pkg/tcpip/header", + ], +) diff --git a/pkg/tcpip/network/hash/hash.go b/pkg/tcpip/network/hash/hash.go new file mode 100644 index 000000000..8f65713c5 --- /dev/null +++ b/pkg/tcpip/network/hash/hash.go @@ -0,0 +1,93 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package hash contains utility functions for hashing. +package hash + +import ( + "encoding/binary" + + "gvisor.dev/gvisor/pkg/rand" + "gvisor.dev/gvisor/pkg/tcpip/header" +) + +var hashIV = RandN32(1)[0] + +// RandN32 generates a slice of n cryptographic random 32-bit numbers. +func RandN32(n int) []uint32 { + b := make([]byte, 4*n) + if _, err := rand.Read(b); err != nil { + panic("unable to get random numbers: " + err.Error()) + } + r := make([]uint32, n) + for i := range r { + r[i] = binary.LittleEndian.Uint32(b[4*i : (4*i + 4)]) + } + return r +} + +// Hash3Words calculates the Jenkins hash of 3 32-bit words. This is adapted +// from linux. +func Hash3Words(a, b, c, initval uint32) uint32 { + const iv = 0xdeadbeef + (3 << 2) + initval += iv + + a += initval + b += initval + c += initval + + c ^= b + c -= rol32(b, 14) + a ^= c + a -= rol32(c, 11) + b ^= a + b -= rol32(a, 25) + c ^= b + c -= rol32(b, 16) + a ^= c + a -= rol32(c, 4) + b ^= a + b -= rol32(a, 14) + c ^= b + c -= rol32(b, 24) + + return c +} + +// IPv4FragmentHash computes the hash of the IPv4 fragment as suggested in RFC 791. +func IPv4FragmentHash(h header.IPv4) uint32 { + x := uint32(h.ID())<<16 | uint32(h.Protocol()) + t := h.SourceAddress() + y := uint32(t[0]) | uint32(t[1])<<8 | uint32(t[2])<<16 | uint32(t[3])<<24 + t = h.DestinationAddress() + z := uint32(t[0]) | uint32(t[1])<<8 | uint32(t[2])<<16 | uint32(t[3])<<24 + return Hash3Words(x, y, z, hashIV) +} + +// IPv6FragmentHash computes the hash of the ipv6 fragment. +// Unlike IPv4, the protocol is not used to compute the hash. +// RFC 2640 (sec 4.5) is not very sharp on this aspect. +// As a reference, also Linux ignores the protocol to compute +// the hash (inet6_hash_frag). +func IPv6FragmentHash(h header.IPv6, id uint32) uint32 { + t := h.SourceAddress() + y := uint32(t[0]) | uint32(t[1])<<8 | uint32(t[2])<<16 | uint32(t[3])<<24 + t = h.DestinationAddress() + z := uint32(t[0]) | uint32(t[1])<<8 | uint32(t[2])<<16 | uint32(t[3])<<24 + return Hash3Words(id, y, z, hashIV) +} + +func rol32(v, shift uint32) uint32 { + return (v << shift) | (v >> ((-shift) & 31)) +} diff --git a/pkg/tcpip/network/ip_test.go b/pkg/tcpip/network/ip_test.go new file mode 100644 index 000000000..4c20301c6 --- /dev/null +++ b/pkg/tcpip/network/ip_test.go @@ -0,0 +1,655 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package ip_test + +import ( + "testing" + + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/link/loopback" + "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" + "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" + "gvisor.dev/gvisor/pkg/tcpip/stack" + "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" + "gvisor.dev/gvisor/pkg/tcpip/transport/udp" +) + +const ( + localIpv4Addr = "\x0a\x00\x00\x01" + localIpv4PrefixLen = 24 + remoteIpv4Addr = "\x0a\x00\x00\x02" + ipv4SubnetAddr = "\x0a\x00\x00\x00" + ipv4SubnetMask = "\xff\xff\xff\x00" + ipv4Gateway = "\x0a\x00\x00\x03" + localIpv6Addr = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01" + localIpv6PrefixLen = 120 + remoteIpv6Addr = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02" + ipv6SubnetAddr = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00" + ipv6SubnetMask = "\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\x00" + ipv6Gateway = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x03" +) + +// testObject implements two interfaces: LinkEndpoint and TransportDispatcher. +// The former is used to pretend that it's a link endpoint so that we can +// inspect packets written by the network endpoints. The latter is used to +// pretend that it's the network stack so that it can inspect incoming packets +// that have been handled by the network endpoints. +// +// Packets are checked by comparing their fields/values against the expected +// values stored in the test object itself. +type testObject struct { + t *testing.T + protocol tcpip.TransportProtocolNumber + contents []byte + srcAddr tcpip.Address + dstAddr tcpip.Address + v4 bool + typ stack.ControlType + extra uint32 + + dataCalls int + controlCalls int +} + +// checkValues verifies that the transport protocol, data contents, src & dst +// addresses of a packet match what's expected. If any field doesn't match, the +// test fails. +func (t *testObject) checkValues(protocol tcpip.TransportProtocolNumber, vv buffer.VectorisedView, srcAddr, dstAddr tcpip.Address) { + v := vv.ToView() + if protocol != t.protocol { + t.t.Errorf("protocol = %v, want %v", protocol, t.protocol) + } + + if srcAddr != t.srcAddr { + t.t.Errorf("srcAddr = %v, want %v", srcAddr, t.srcAddr) + } + + if dstAddr != t.dstAddr { + t.t.Errorf("dstAddr = %v, want %v", dstAddr, t.dstAddr) + } + + if len(v) != len(t.contents) { + t.t.Fatalf("len(payload) = %v, want %v", len(v), len(t.contents)) + } + + for i := range t.contents { + if t.contents[i] != v[i] { + t.t.Fatalf("payload[%v] = %v, want %v", i, v[i], t.contents[i]) + } + } +} + +// DeliverTransportPacket is called by network endpoints after parsing incoming +// packets. This is used by the test object to verify that the results of the +// parsing are expected. +func (t *testObject) DeliverTransportPacket(r *stack.Route, protocol tcpip.TransportProtocolNumber, pkt stack.PacketBuffer) { + t.checkValues(protocol, pkt.Data, r.RemoteAddress, r.LocalAddress) + t.dataCalls++ +} + +// DeliverTransportControlPacket is called by network endpoints after parsing +// incoming control (ICMP) packets. This is used by the test object to verify +// that the results of the parsing are expected. +func (t *testObject) DeliverTransportControlPacket(local, remote tcpip.Address, net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, typ stack.ControlType, extra uint32, pkt stack.PacketBuffer) { + t.checkValues(trans, pkt.Data, remote, local) + if typ != t.typ { + t.t.Errorf("typ = %v, want %v", typ, t.typ) + } + if extra != t.extra { + t.t.Errorf("extra = %v, want %v", extra, t.extra) + } + t.controlCalls++ +} + +// Attach is only implemented to satisfy the LinkEndpoint interface. +func (*testObject) Attach(stack.NetworkDispatcher) {} + +// IsAttached implements stack.LinkEndpoint.IsAttached. +func (*testObject) IsAttached() bool { + return true +} + +// MTU implements stack.LinkEndpoint.MTU. It just returns a constant that +// matches the linux loopback MTU. +func (*testObject) MTU() uint32 { + return 65536 +} + +// Capabilities implements stack.LinkEndpoint.Capabilities. +func (*testObject) Capabilities() stack.LinkEndpointCapabilities { + return 0 +} + +// MaxHeaderLength is only implemented to satisfy the LinkEndpoint interface. +func (*testObject) MaxHeaderLength() uint16 { + return 0 +} + +// LinkAddress returns the link address of this endpoint. +func (*testObject) LinkAddress() tcpip.LinkAddress { + return "" +} + +// Wait implements stack.LinkEndpoint.Wait. +func (*testObject) Wait() {} + +// WritePacket is called by network endpoints after producing a packet and +// writing it to the link endpoint. This is used by the test object to verify +// that the produced packet is as expected. +func (t *testObject) WritePacket(_ *stack.Route, _ *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt stack.PacketBuffer) *tcpip.Error { + var prot tcpip.TransportProtocolNumber + var srcAddr tcpip.Address + var dstAddr tcpip.Address + + if t.v4 { + h := header.IPv4(pkt.Header.View()) + prot = tcpip.TransportProtocolNumber(h.Protocol()) + srcAddr = h.SourceAddress() + dstAddr = h.DestinationAddress() + + } else { + h := header.IPv6(pkt.Header.View()) + prot = tcpip.TransportProtocolNumber(h.NextHeader()) + srcAddr = h.SourceAddress() + dstAddr = h.DestinationAddress() + } + t.checkValues(prot, pkt.Data, srcAddr, dstAddr) + return nil +} + +// WritePackets implements stack.LinkEndpoint.WritePackets. +func (t *testObject) WritePackets(_ *stack.Route, _ *stack.GSO, pkt stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { + panic("not implemented") +} + +func (t *testObject) WriteRawPacket(_ buffer.VectorisedView) *tcpip.Error { + return tcpip.ErrNotSupported +} + +func buildIPv4Route(local, remote tcpip.Address) (stack.Route, *tcpip.Error) { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol()}, + TransportProtocols: []stack.TransportProtocol{udp.NewProtocol(), tcp.NewProtocol()}, + }) + s.CreateNIC(1, loopback.New()) + s.AddAddress(1, ipv4.ProtocolNumber, local) + s.SetRouteTable([]tcpip.Route{{ + Destination: header.IPv4EmptySubnet, + Gateway: ipv4Gateway, + NIC: 1, + }}) + + return s.FindRoute(1, local, remote, ipv4.ProtocolNumber, false /* multicastLoop */) +} + +func buildIPv6Route(local, remote tcpip.Address) (stack.Route, *tcpip.Error) { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + TransportProtocols: []stack.TransportProtocol{udp.NewProtocol(), tcp.NewProtocol()}, + }) + s.CreateNIC(1, loopback.New()) + s.AddAddress(1, ipv6.ProtocolNumber, local) + s.SetRouteTable([]tcpip.Route{{ + Destination: header.IPv6EmptySubnet, + Gateway: ipv6Gateway, + NIC: 1, + }}) + + return s.FindRoute(1, local, remote, ipv6.ProtocolNumber, false /* multicastLoop */) +} + +func buildDummyStack() *stack.Stack { + return stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + TransportProtocols: []stack.TransportProtocol{udp.NewProtocol(), tcp.NewProtocol()}, + }) +} + +func TestIPv4Send(t *testing.T) { + o := testObject{t: t, v4: true} + proto := ipv4.NewProtocol() + ep, err := proto.NewEndpoint(1, tcpip.AddressWithPrefix{localIpv4Addr, localIpv4PrefixLen}, nil, nil, &o, buildDummyStack()) + if err != nil { + t.Fatalf("NewEndpoint failed: %v", err) + } + + // Allocate and initialize the payload view. + payload := buffer.NewView(100) + for i := 0; i < len(payload); i++ { + payload[i] = uint8(i) + } + + // Allocate the header buffer. + hdr := buffer.NewPrependable(int(ep.MaxHeaderLength())) + + // Issue the write. + o.protocol = 123 + o.srcAddr = localIpv4Addr + o.dstAddr = remoteIpv4Addr + o.contents = payload + + r, err := buildIPv4Route(localIpv4Addr, remoteIpv4Addr) + if err != nil { + t.Fatalf("could not find route: %v", err) + } + if err := ep.WritePacket(&r, nil /* gso */, stack.NetworkHeaderParams{Protocol: 123, TTL: 123, TOS: stack.DefaultTOS}, stack.PacketBuffer{ + Header: hdr, + Data: payload.ToVectorisedView(), + }); err != nil { + t.Fatalf("WritePacket failed: %v", err) + } +} + +func TestIPv4Receive(t *testing.T) { + o := testObject{t: t, v4: true} + proto := ipv4.NewProtocol() + ep, err := proto.NewEndpoint(1, tcpip.AddressWithPrefix{localIpv4Addr, localIpv4PrefixLen}, nil, &o, nil, buildDummyStack()) + if err != nil { + t.Fatalf("NewEndpoint failed: %v", err) + } + + totalLen := header.IPv4MinimumSize + 30 + view := buffer.NewView(totalLen) + ip := header.IPv4(view) + ip.Encode(&header.IPv4Fields{ + IHL: header.IPv4MinimumSize, + TotalLength: uint16(totalLen), + TTL: 20, + Protocol: 10, + SrcAddr: remoteIpv4Addr, + DstAddr: localIpv4Addr, + }) + + // Make payload be non-zero. + for i := header.IPv4MinimumSize; i < totalLen; i++ { + view[i] = uint8(i) + } + + // Give packet to ipv4 endpoint, dispatcher will validate that it's ok. + o.protocol = 10 + o.srcAddr = remoteIpv4Addr + o.dstAddr = localIpv4Addr + o.contents = view[header.IPv4MinimumSize:totalLen] + + r, err := buildIPv4Route(localIpv4Addr, remoteIpv4Addr) + if err != nil { + t.Fatalf("could not find route: %v", err) + } + ep.HandlePacket(&r, stack.PacketBuffer{ + Data: view.ToVectorisedView(), + }) + if o.dataCalls != 1 { + t.Fatalf("Bad number of data calls: got %x, want 1", o.dataCalls) + } +} + +func TestIPv4ReceiveControl(t *testing.T) { + const mtu = 0xbeef - header.IPv4MinimumSize + cases := []struct { + name string + expectedCount int + fragmentOffset uint16 + code uint8 + expectedTyp stack.ControlType + expectedExtra uint32 + trunc int + }{ + {"FragmentationNeeded", 1, 0, header.ICMPv4FragmentationNeeded, stack.ControlPacketTooBig, mtu, 0}, + {"Truncated (10 bytes missing)", 0, 0, header.ICMPv4FragmentationNeeded, stack.ControlPacketTooBig, mtu, 10}, + {"Truncated (missing IPv4 header)", 0, 0, header.ICMPv4FragmentationNeeded, stack.ControlPacketTooBig, mtu, header.IPv4MinimumSize + 8}, + {"Truncated (missing 'extra info')", 0, 0, header.ICMPv4FragmentationNeeded, stack.ControlPacketTooBig, mtu, 4 + header.IPv4MinimumSize + 8}, + {"Truncated (missing ICMP header)", 0, 0, header.ICMPv4FragmentationNeeded, stack.ControlPacketTooBig, mtu, header.ICMPv4MinimumSize + header.IPv4MinimumSize + 8}, + {"Port unreachable", 1, 0, header.ICMPv4PortUnreachable, stack.ControlPortUnreachable, 0, 0}, + {"Non-zero fragment offset", 0, 100, header.ICMPv4PortUnreachable, stack.ControlPortUnreachable, 0, 0}, + {"Zero-length packet", 0, 0, header.ICMPv4PortUnreachable, stack.ControlPortUnreachable, 0, 2*header.IPv4MinimumSize + header.ICMPv4MinimumSize + 8}, + } + r, err := buildIPv4Route(localIpv4Addr, "\x0a\x00\x00\xbb") + if err != nil { + t.Fatal(err) + } + for _, c := range cases { + t.Run(c.name, func(t *testing.T) { + o := testObject{t: t} + proto := ipv4.NewProtocol() + ep, err := proto.NewEndpoint(1, tcpip.AddressWithPrefix{localIpv4Addr, localIpv4PrefixLen}, nil, &o, nil, buildDummyStack()) + if err != nil { + t.Fatalf("NewEndpoint failed: %v", err) + } + defer ep.Close() + + const dataOffset = header.IPv4MinimumSize*2 + header.ICMPv4MinimumSize + view := buffer.NewView(dataOffset + 8) + + // Create the outer IPv4 header. + ip := header.IPv4(view) + ip.Encode(&header.IPv4Fields{ + IHL: header.IPv4MinimumSize, + TotalLength: uint16(len(view) - c.trunc), + TTL: 20, + Protocol: uint8(header.ICMPv4ProtocolNumber), + SrcAddr: "\x0a\x00\x00\xbb", + DstAddr: localIpv4Addr, + }) + + // Create the ICMP header. + icmp := header.ICMPv4(view[header.IPv4MinimumSize:]) + icmp.SetType(header.ICMPv4DstUnreachable) + icmp.SetCode(c.code) + icmp.SetIdent(0xdead) + icmp.SetSequence(0xbeef) + + // Create the inner IPv4 header. + ip = header.IPv4(view[header.IPv4MinimumSize+header.ICMPv4MinimumSize:]) + ip.Encode(&header.IPv4Fields{ + IHL: header.IPv4MinimumSize, + TotalLength: 100, + TTL: 20, + Protocol: 10, + FragmentOffset: c.fragmentOffset, + SrcAddr: localIpv4Addr, + DstAddr: remoteIpv4Addr, + }) + + // Make payload be non-zero. + for i := dataOffset; i < len(view); i++ { + view[i] = uint8(i) + } + + // Give packet to IPv4 endpoint, dispatcher will validate that + // it's ok. + o.protocol = 10 + o.srcAddr = remoteIpv4Addr + o.dstAddr = localIpv4Addr + o.contents = view[dataOffset:] + o.typ = c.expectedTyp + o.extra = c.expectedExtra + + vv := view[:len(view)-c.trunc].ToVectorisedView() + ep.HandlePacket(&r, stack.PacketBuffer{ + Data: vv, + }) + if want := c.expectedCount; o.controlCalls != want { + t.Fatalf("Bad number of control calls for %q case: got %v, want %v", c.name, o.controlCalls, want) + } + }) + } +} + +func TestIPv4FragmentationReceive(t *testing.T) { + o := testObject{t: t, v4: true} + proto := ipv4.NewProtocol() + ep, err := proto.NewEndpoint(1, tcpip.AddressWithPrefix{localIpv4Addr, localIpv4PrefixLen}, nil, &o, nil, buildDummyStack()) + if err != nil { + t.Fatalf("NewEndpoint failed: %v", err) + } + + totalLen := header.IPv4MinimumSize + 24 + + frag1 := buffer.NewView(totalLen) + ip1 := header.IPv4(frag1) + ip1.Encode(&header.IPv4Fields{ + IHL: header.IPv4MinimumSize, + TotalLength: uint16(totalLen), + TTL: 20, + Protocol: 10, + FragmentOffset: 0, + Flags: header.IPv4FlagMoreFragments, + SrcAddr: remoteIpv4Addr, + DstAddr: localIpv4Addr, + }) + // Make payload be non-zero. + for i := header.IPv4MinimumSize; i < totalLen; i++ { + frag1[i] = uint8(i) + } + + frag2 := buffer.NewView(totalLen) + ip2 := header.IPv4(frag2) + ip2.Encode(&header.IPv4Fields{ + IHL: header.IPv4MinimumSize, + TotalLength: uint16(totalLen), + TTL: 20, + Protocol: 10, + FragmentOffset: 24, + SrcAddr: remoteIpv4Addr, + DstAddr: localIpv4Addr, + }) + // Make payload be non-zero. + for i := header.IPv4MinimumSize; i < totalLen; i++ { + frag2[i] = uint8(i) + } + + // Give packet to ipv4 endpoint, dispatcher will validate that it's ok. + o.protocol = 10 + o.srcAddr = remoteIpv4Addr + o.dstAddr = localIpv4Addr + o.contents = append(frag1[header.IPv4MinimumSize:totalLen], frag2[header.IPv4MinimumSize:totalLen]...) + + r, err := buildIPv4Route(localIpv4Addr, remoteIpv4Addr) + if err != nil { + t.Fatalf("could not find route: %v", err) + } + + // Send first segment. + ep.HandlePacket(&r, stack.PacketBuffer{ + Data: frag1.ToVectorisedView(), + }) + if o.dataCalls != 0 { + t.Fatalf("Bad number of data calls: got %x, want 0", o.dataCalls) + } + + // Send second segment. + ep.HandlePacket(&r, stack.PacketBuffer{ + Data: frag2.ToVectorisedView(), + }) + if o.dataCalls != 1 { + t.Fatalf("Bad number of data calls: got %x, want 1", o.dataCalls) + } +} + +func TestIPv6Send(t *testing.T) { + o := testObject{t: t} + proto := ipv6.NewProtocol() + ep, err := proto.NewEndpoint(1, tcpip.AddressWithPrefix{localIpv6Addr, localIpv6PrefixLen}, nil, nil, &o, buildDummyStack()) + if err != nil { + t.Fatalf("NewEndpoint failed: %v", err) + } + + // Allocate and initialize the payload view. + payload := buffer.NewView(100) + for i := 0; i < len(payload); i++ { + payload[i] = uint8(i) + } + + // Allocate the header buffer. + hdr := buffer.NewPrependable(int(ep.MaxHeaderLength())) + + // Issue the write. + o.protocol = 123 + o.srcAddr = localIpv6Addr + o.dstAddr = remoteIpv6Addr + o.contents = payload + + r, err := buildIPv6Route(localIpv6Addr, remoteIpv6Addr) + if err != nil { + t.Fatalf("could not find route: %v", err) + } + if err := ep.WritePacket(&r, nil /* gso */, stack.NetworkHeaderParams{Protocol: 123, TTL: 123, TOS: stack.DefaultTOS}, stack.PacketBuffer{ + Header: hdr, + Data: payload.ToVectorisedView(), + }); err != nil { + t.Fatalf("WritePacket failed: %v", err) + } +} + +func TestIPv6Receive(t *testing.T) { + o := testObject{t: t} + proto := ipv6.NewProtocol() + ep, err := proto.NewEndpoint(1, tcpip.AddressWithPrefix{localIpv6Addr, localIpv6PrefixLen}, nil, &o, nil, buildDummyStack()) + if err != nil { + t.Fatalf("NewEndpoint failed: %v", err) + } + + totalLen := header.IPv6MinimumSize + 30 + view := buffer.NewView(totalLen) + ip := header.IPv6(view) + ip.Encode(&header.IPv6Fields{ + PayloadLength: uint16(totalLen - header.IPv6MinimumSize), + NextHeader: 10, + HopLimit: 20, + SrcAddr: remoteIpv6Addr, + DstAddr: localIpv6Addr, + }) + + // Make payload be non-zero. + for i := header.IPv6MinimumSize; i < totalLen; i++ { + view[i] = uint8(i) + } + + // Give packet to ipv6 endpoint, dispatcher will validate that it's ok. + o.protocol = 10 + o.srcAddr = remoteIpv6Addr + o.dstAddr = localIpv6Addr + o.contents = view[header.IPv6MinimumSize:totalLen] + + r, err := buildIPv6Route(localIpv6Addr, remoteIpv6Addr) + if err != nil { + t.Fatalf("could not find route: %v", err) + } + + ep.HandlePacket(&r, stack.PacketBuffer{ + Data: view.ToVectorisedView(), + }) + if o.dataCalls != 1 { + t.Fatalf("Bad number of data calls: got %x, want 1", o.dataCalls) + } +} + +func TestIPv6ReceiveControl(t *testing.T) { + newUint16 := func(v uint16) *uint16 { return &v } + + const mtu = 0xffff + const outerSrcAddr = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xaa" + cases := []struct { + name string + expectedCount int + fragmentOffset *uint16 + typ header.ICMPv6Type + code uint8 + expectedTyp stack.ControlType + expectedExtra uint32 + trunc int + }{ + {"PacketTooBig", 1, nil, header.ICMPv6PacketTooBig, 0, stack.ControlPacketTooBig, mtu, 0}, + {"Truncated (10 bytes missing)", 0, nil, header.ICMPv6PacketTooBig, 0, stack.ControlPacketTooBig, mtu, 10}, + {"Truncated (missing IPv6 header)", 0, nil, header.ICMPv6PacketTooBig, 0, stack.ControlPacketTooBig, mtu, header.IPv6MinimumSize + 8}, + {"Truncated PacketTooBig (missing 'extra info')", 0, nil, header.ICMPv6PacketTooBig, 0, stack.ControlPacketTooBig, mtu, 4 + header.IPv6MinimumSize + 8}, + {"Truncated (missing ICMP header)", 0, nil, header.ICMPv6PacketTooBig, 0, stack.ControlPacketTooBig, mtu, header.ICMPv6PacketTooBigMinimumSize + header.IPv6MinimumSize + 8}, + {"Port unreachable", 1, nil, header.ICMPv6DstUnreachable, header.ICMPv6PortUnreachable, stack.ControlPortUnreachable, 0, 0}, + {"Truncated DstUnreachable (missing 'extra info')", 0, nil, header.ICMPv6DstUnreachable, header.ICMPv6PortUnreachable, stack.ControlPortUnreachable, 0, 4 + header.IPv6MinimumSize + 8}, + {"Fragmented, zero offset", 1, newUint16(0), header.ICMPv6DstUnreachable, header.ICMPv6PortUnreachable, stack.ControlPortUnreachable, 0, 0}, + {"Non-zero fragment offset", 0, newUint16(100), header.ICMPv6DstUnreachable, header.ICMPv6PortUnreachable, stack.ControlPortUnreachable, 0, 0}, + {"Zero-length packet", 0, nil, header.ICMPv6DstUnreachable, header.ICMPv6PortUnreachable, stack.ControlPortUnreachable, 0, 2*header.IPv6MinimumSize + header.ICMPv6DstUnreachableMinimumSize + 8}, + } + r, err := buildIPv6Route( + localIpv6Addr, + "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xaa", + ) + if err != nil { + t.Fatal(err) + } + for _, c := range cases { + t.Run(c.name, func(t *testing.T) { + o := testObject{t: t} + proto := ipv6.NewProtocol() + ep, err := proto.NewEndpoint(1, tcpip.AddressWithPrefix{localIpv6Addr, localIpv6PrefixLen}, nil, &o, nil, buildDummyStack()) + if err != nil { + t.Fatalf("NewEndpoint failed: %v", err) + } + + defer ep.Close() + + dataOffset := header.IPv6MinimumSize*2 + header.ICMPv6MinimumSize + if c.fragmentOffset != nil { + dataOffset += header.IPv6FragmentHeaderSize + } + view := buffer.NewView(dataOffset + 8) + + // Create the outer IPv6 header. + ip := header.IPv6(view) + ip.Encode(&header.IPv6Fields{ + PayloadLength: uint16(len(view) - header.IPv6MinimumSize - c.trunc), + NextHeader: uint8(header.ICMPv6ProtocolNumber), + HopLimit: 20, + SrcAddr: outerSrcAddr, + DstAddr: localIpv6Addr, + }) + + // Create the ICMP header. + icmp := header.ICMPv6(view[header.IPv6MinimumSize:]) + icmp.SetType(c.typ) + icmp.SetCode(c.code) + icmp.SetIdent(0xdead) + icmp.SetSequence(0xbeef) + + // Create the inner IPv6 header. + ip = header.IPv6(view[header.IPv6MinimumSize+header.ICMPv6PayloadOffset:]) + ip.Encode(&header.IPv6Fields{ + PayloadLength: 100, + NextHeader: 10, + HopLimit: 20, + SrcAddr: localIpv6Addr, + DstAddr: remoteIpv6Addr, + }) + + // Build the fragmentation header if needed. + if c.fragmentOffset != nil { + ip.SetNextHeader(header.IPv6FragmentHeader) + frag := header.IPv6Fragment(view[2*header.IPv6MinimumSize+header.ICMPv6MinimumSize:]) + frag.Encode(&header.IPv6FragmentFields{ + NextHeader: 10, + FragmentOffset: *c.fragmentOffset, + M: true, + Identification: 0x12345678, + }) + } + + // Make payload be non-zero. + for i := dataOffset; i < len(view); i++ { + view[i] = uint8(i) + } + + // Give packet to IPv6 endpoint, dispatcher will validate that + // it's ok. + o.protocol = 10 + o.srcAddr = remoteIpv6Addr + o.dstAddr = localIpv6Addr + o.contents = view[dataOffset:] + o.typ = c.expectedTyp + o.extra = c.expectedExtra + + // Set ICMPv6 checksum. + icmp.SetChecksum(header.ICMPv6Checksum(icmp, outerSrcAddr, localIpv6Addr, buffer.VectorisedView{})) + + ep.HandlePacket(&r, stack.PacketBuffer{ + Data: view[:len(view)-c.trunc].ToVectorisedView(), + }) + if want := c.expectedCount; o.controlCalls != want { + t.Fatalf("Bad number of control calls for %q case: got %v, want %v", c.name, o.controlCalls, want) + } + }) + } +} diff --git a/pkg/tcpip/network/ipv4/BUILD b/pkg/tcpip/network/ipv4/BUILD new file mode 100644 index 000000000..880ea7de2 --- /dev/null +++ b/pkg/tcpip/network/ipv4/BUILD @@ -0,0 +1,38 @@ +load("//tools:defs.bzl", "go_library", "go_test") + +package(licenses = ["notice"]) + +go_library( + name = "ipv4", + srcs = [ + "icmp.go", + "ipv4.go", + ], + visibility = ["//visibility:public"], + deps = [ + "//pkg/tcpip", + "//pkg/tcpip/buffer", + "//pkg/tcpip/header", + "//pkg/tcpip/network/fragmentation", + "//pkg/tcpip/network/hash", + "//pkg/tcpip/stack", + ], +) + +go_test( + name = "ipv4_test", + size = "small", + srcs = ["ipv4_test.go"], + deps = [ + "//pkg/tcpip", + "//pkg/tcpip/buffer", + "//pkg/tcpip/header", + "//pkg/tcpip/link/channel", + "//pkg/tcpip/link/sniffer", + "//pkg/tcpip/network/ipv4", + "//pkg/tcpip/stack", + "//pkg/tcpip/transport/tcp", + "//pkg/tcpip/transport/udp", + "//pkg/waiter", + ], +) diff --git a/pkg/tcpip/network/ipv4/icmp.go b/pkg/tcpip/network/ipv4/icmp.go new file mode 100644 index 000000000..4cbefe5ab --- /dev/null +++ b/pkg/tcpip/network/ipv4/icmp.go @@ -0,0 +1,160 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package ipv4 + +import ( + "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/stack" +) + +// handleControl handles the case when an ICMP packet contains the headers of +// the original packet that caused the ICMP one to be sent. This information is +// used to find out which transport endpoint must be notified about the ICMP +// packet. +func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, pkt stack.PacketBuffer) { + h, ok := pkt.Data.PullUp(header.IPv4MinimumSize) + if !ok { + return + } + hdr := header.IPv4(h) + + // We don't use IsValid() here because ICMP only requires that the IP + // header plus 8 bytes of the transport header be included. So it's + // likely that it is truncated, which would cause IsValid to return + // false. + // + // Drop packet if it doesn't have the basic IPv4 header or if the + // original source address doesn't match the endpoint's address. + if hdr.SourceAddress() != e.id.LocalAddress { + return + } + + hlen := int(hdr.HeaderLength()) + if pkt.Data.Size() < hlen || hdr.FragmentOffset() != 0 { + // We won't be able to handle this if it doesn't contain the + // full IPv4 header, or if it's a fragment not at offset 0 + // (because it won't have the transport header). + return + } + + // Skip the ip header, then deliver control message. + pkt.Data.TrimFront(hlen) + p := hdr.TransportProtocol() + e.dispatcher.DeliverTransportControlPacket(e.id.LocalAddress, hdr.DestinationAddress(), ProtocolNumber, p, typ, extra, pkt) +} + +func (e *endpoint) handleICMP(r *stack.Route, pkt stack.PacketBuffer) { + stats := r.Stats() + received := stats.ICMP.V4PacketsReceived + v, ok := pkt.Data.PullUp(header.ICMPv4MinimumSize) + if !ok { + received.Invalid.Increment() + return + } + h := header.ICMPv4(v) + + // TODO(b/112892170): Meaningfully handle all ICMP types. + switch h.Type() { + case header.ICMPv4Echo: + received.Echo.Increment() + + // Only send a reply if the checksum is valid. + wantChecksum := h.Checksum() + // Reset the checksum field to 0 to can calculate the proper + // checksum. We'll have to reset this before we hand the packet + // off. + h.SetChecksum(0) + gotChecksum := ^header.ChecksumVV(pkt.Data, 0 /* initial */) + if gotChecksum != wantChecksum { + // It's possible that a raw socket expects to receive this. + h.SetChecksum(wantChecksum) + e.dispatcher.DeliverTransportPacket(r, header.ICMPv4ProtocolNumber, pkt) + received.Invalid.Increment() + return + } + + // It's possible that a raw socket expects to receive this. + h.SetChecksum(wantChecksum) + e.dispatcher.DeliverTransportPacket(r, header.ICMPv4ProtocolNumber, stack.PacketBuffer{ + Data: pkt.Data.Clone(nil), + NetworkHeader: append(buffer.View(nil), pkt.NetworkHeader...), + }) + + vv := pkt.Data.Clone(nil) + vv.TrimFront(header.ICMPv4MinimumSize) + hdr := buffer.NewPrependable(int(r.MaxHeaderLength()) + header.ICMPv4MinimumSize) + pkt := header.ICMPv4(hdr.Prepend(header.ICMPv4MinimumSize)) + copy(pkt, h) + pkt.SetType(header.ICMPv4EchoReply) + pkt.SetChecksum(0) + pkt.SetChecksum(^header.Checksum(pkt, header.ChecksumVV(vv, 0))) + sent := stats.ICMP.V4PacketsSent + if err := r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv4ProtocolNumber, TTL: r.DefaultTTL(), TOS: stack.DefaultTOS}, stack.PacketBuffer{ + Header: hdr, + Data: vv, + TransportHeader: buffer.View(pkt), + }); err != nil { + sent.Dropped.Increment() + return + } + sent.EchoReply.Increment() + + case header.ICMPv4EchoReply: + received.EchoReply.Increment() + + e.dispatcher.DeliverTransportPacket(r, header.ICMPv4ProtocolNumber, pkt) + + case header.ICMPv4DstUnreachable: + received.DstUnreachable.Increment() + + pkt.Data.TrimFront(header.ICMPv4MinimumSize) + switch h.Code() { + case header.ICMPv4PortUnreachable: + e.handleControl(stack.ControlPortUnreachable, 0, pkt) + + case header.ICMPv4FragmentationNeeded: + mtu := uint32(h.MTU()) + e.handleControl(stack.ControlPacketTooBig, calculateMTU(mtu), pkt) + } + + case header.ICMPv4SrcQuench: + received.SrcQuench.Increment() + + case header.ICMPv4Redirect: + received.Redirect.Increment() + + case header.ICMPv4TimeExceeded: + received.TimeExceeded.Increment() + + case header.ICMPv4ParamProblem: + received.ParamProblem.Increment() + + case header.ICMPv4Timestamp: + received.Timestamp.Increment() + + case header.ICMPv4TimestampReply: + received.TimestampReply.Increment() + + case header.ICMPv4InfoRequest: + received.InfoRequest.Increment() + + case header.ICMPv4InfoReply: + received.InfoReply.Increment() + + default: + received.Invalid.Increment() + } +} diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go new file mode 100644 index 000000000..9db42b2a4 --- /dev/null +++ b/pkg/tcpip/network/ipv4/ipv4.go @@ -0,0 +1,598 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package ipv4 contains the implementation of the ipv4 network protocol. To use +// it in the networking stack, this package must be added to the project, and +// activated on the stack by passing ipv4.NewProtocol() as one of the network +// protocols when calling stack.New(). Then endpoints can be created by passing +// ipv4.ProtocolNumber as the network protocol number when calling +// Stack.NewEndpoint(). +package ipv4 + +import ( + "sync/atomic" + + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/network/fragmentation" + "gvisor.dev/gvisor/pkg/tcpip/network/hash" + "gvisor.dev/gvisor/pkg/tcpip/stack" +) + +const ( + // ProtocolNumber is the ipv4 protocol number. + ProtocolNumber = header.IPv4ProtocolNumber + + // MaxTotalSize is maximum size that can be encoded in the 16-bit + // TotalLength field of the ipv4 header. + MaxTotalSize = 0xffff + + // DefaultTTL is the default time-to-live value for this endpoint. + DefaultTTL = 64 + + // buckets is the number of identifier buckets. + buckets = 2048 +) + +type endpoint struct { + nicID tcpip.NICID + id stack.NetworkEndpointID + prefixLen int + linkEP stack.LinkEndpoint + dispatcher stack.TransportDispatcher + fragmentation *fragmentation.Fragmentation + protocol *protocol + stack *stack.Stack +} + +// NewEndpoint creates a new ipv4 endpoint. +func (p *protocol) NewEndpoint(nicID tcpip.NICID, addrWithPrefix tcpip.AddressWithPrefix, linkAddrCache stack.LinkAddressCache, dispatcher stack.TransportDispatcher, linkEP stack.LinkEndpoint, st *stack.Stack) (stack.NetworkEndpoint, *tcpip.Error) { + e := &endpoint{ + nicID: nicID, + id: stack.NetworkEndpointID{LocalAddress: addrWithPrefix.Address}, + prefixLen: addrWithPrefix.PrefixLen, + linkEP: linkEP, + dispatcher: dispatcher, + fragmentation: fragmentation.NewFragmentation(fragmentation.HighFragThreshold, fragmentation.LowFragThreshold, fragmentation.DefaultReassembleTimeout), + protocol: p, + stack: st, + } + + return e, nil +} + +// DefaultTTL is the default time-to-live value for this endpoint. +func (e *endpoint) DefaultTTL() uint8 { + return e.protocol.DefaultTTL() +} + +// MTU implements stack.NetworkEndpoint.MTU. It returns the link-layer MTU minus +// the network layer max header length. +func (e *endpoint) MTU() uint32 { + return calculateMTU(e.linkEP.MTU()) +} + +// Capabilities implements stack.NetworkEndpoint.Capabilities. +func (e *endpoint) Capabilities() stack.LinkEndpointCapabilities { + return e.linkEP.Capabilities() +} + +// NICID returns the ID of the NIC this endpoint belongs to. +func (e *endpoint) NICID() tcpip.NICID { + return e.nicID +} + +// ID returns the ipv4 endpoint ID. +func (e *endpoint) ID() *stack.NetworkEndpointID { + return &e.id +} + +// PrefixLen returns the ipv4 endpoint subnet prefix length in bits. +func (e *endpoint) PrefixLen() int { + return e.prefixLen +} + +// MaxHeaderLength returns the maximum length needed by ipv4 headers (and +// underlying protocols). +func (e *endpoint) MaxHeaderLength() uint16 { + return e.linkEP.MaxHeaderLength() + header.IPv4MinimumSize +} + +// GSOMaxSize returns the maximum GSO packet size. +func (e *endpoint) GSOMaxSize() uint32 { + if gso, ok := e.linkEP.(stack.GSOEndpoint); ok { + return gso.GSOMaxSize() + } + return 0 +} + +// NetworkProtocolNumber implements stack.NetworkEndpoint.NetworkProtocolNumber. +func (e *endpoint) NetworkProtocolNumber() tcpip.NetworkProtocolNumber { + return e.protocol.Number() +} + +// writePacketFragments calls e.linkEP.WritePacket with each packet fragment to +// write. It assumes that the IP header is entirely in pkt.Header but does not +// assume that only the IP header is in pkt.Header. It assumes that the input +// packet's stated length matches the length of the header+payload. mtu +// includes the IP header and options. This does not support the DontFragment +// IP flag. +func (e *endpoint) writePacketFragments(r *stack.Route, gso *stack.GSO, mtu int, pkt stack.PacketBuffer) *tcpip.Error { + // This packet is too big, it needs to be fragmented. + ip := header.IPv4(pkt.Header.View()) + flags := ip.Flags() + + // Update mtu to take into account the header, which will exist in all + // fragments anyway. + innerMTU := mtu - int(ip.HeaderLength()) + + // Round the MTU down to align to 8 bytes. Then calculate the number of + // fragments. Calculate fragment sizes as in RFC791. + innerMTU &^= 7 + n := (int(ip.PayloadLength()) + innerMTU - 1) / innerMTU + + outerMTU := innerMTU + int(ip.HeaderLength()) + offset := ip.FragmentOffset() + originalAvailableLength := pkt.Header.AvailableLength() + for i := 0; i < n; i++ { + // Where possible, the first fragment that is sent has the same + // pkt.Header.UsedLength() as the input packet. The link-layer + // endpoint may depend on this for looking at, eg, L4 headers. + h := ip + if i > 0 { + pkt.Header = buffer.NewPrependable(int(ip.HeaderLength()) + originalAvailableLength) + h = header.IPv4(pkt.Header.Prepend(int(ip.HeaderLength()))) + copy(h, ip[:ip.HeaderLength()]) + } + if i != n-1 { + h.SetTotalLength(uint16(outerMTU)) + h.SetFlagsFragmentOffset(flags|header.IPv4FlagMoreFragments, offset) + } else { + h.SetTotalLength(uint16(h.HeaderLength()) + uint16(pkt.Data.Size())) + h.SetFlagsFragmentOffset(flags, offset) + } + h.SetChecksum(0) + h.SetChecksum(^h.CalculateChecksum()) + offset += uint16(innerMTU) + if i > 0 { + newPayload := pkt.Data.Clone(nil) + newPayload.CapLength(innerMTU) + if err := e.linkEP.WritePacket(r, gso, ProtocolNumber, stack.PacketBuffer{ + Header: pkt.Header, + Data: newPayload, + NetworkHeader: buffer.View(h), + }); err != nil { + return err + } + r.Stats().IP.PacketsSent.Increment() + pkt.Data.TrimFront(newPayload.Size()) + continue + } + // Special handling for the first fragment because it comes + // from the header. + if outerMTU >= pkt.Header.UsedLength() { + // This fragment can fit all of pkt.Header and possibly + // some of pkt.Data, too. + newPayload := pkt.Data.Clone(nil) + newPayloadLength := outerMTU - pkt.Header.UsedLength() + newPayload.CapLength(newPayloadLength) + if err := e.linkEP.WritePacket(r, gso, ProtocolNumber, stack.PacketBuffer{ + Header: pkt.Header, + Data: newPayload, + NetworkHeader: buffer.View(h), + }); err != nil { + return err + } + r.Stats().IP.PacketsSent.Increment() + pkt.Data.TrimFront(newPayloadLength) + } else { + // The fragment is too small to fit all of pkt.Header. + startOfHdr := pkt.Header + startOfHdr.TrimBack(pkt.Header.UsedLength() - outerMTU) + emptyVV := buffer.NewVectorisedView(0, []buffer.View{}) + if err := e.linkEP.WritePacket(r, gso, ProtocolNumber, stack.PacketBuffer{ + Header: startOfHdr, + Data: emptyVV, + NetworkHeader: buffer.View(h), + }); err != nil { + return err + } + r.Stats().IP.PacketsSent.Increment() + // Add the unused bytes of pkt.Header into the pkt.Data + // that remains to be sent. + restOfHdr := pkt.Header.View()[outerMTU:] + tmp := buffer.NewVectorisedView(len(restOfHdr), []buffer.View{buffer.NewViewFromBytes(restOfHdr)}) + tmp.Append(pkt.Data) + pkt.Data = tmp + } + } + return nil +} + +func (e *endpoint) addIPHeader(r *stack.Route, hdr *buffer.Prependable, payloadSize int, params stack.NetworkHeaderParams) header.IPv4 { + ip := header.IPv4(hdr.Prepend(header.IPv4MinimumSize)) + length := uint16(hdr.UsedLength() + payloadSize) + id := uint32(0) + if length > header.IPv4MaximumHeaderSize+8 { + // Packets of 68 bytes or less are required by RFC 791 to not be + // fragmented, so we only assign ids to larger packets. + id = atomic.AddUint32(&e.protocol.ids[hashRoute(r, params.Protocol, e.protocol.hashIV)%buckets], 1) + } + ip.Encode(&header.IPv4Fields{ + IHL: header.IPv4MinimumSize, + TotalLength: length, + ID: uint16(id), + TTL: params.TTL, + TOS: params.TOS, + Protocol: uint8(params.Protocol), + SrcAddr: r.LocalAddress, + DstAddr: r.RemoteAddress, + }) + ip.SetChecksum(^ip.CalculateChecksum()) + return ip +} + +// WritePacket writes a packet to the given destination address and protocol. +func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.NetworkHeaderParams, pkt stack.PacketBuffer) *tcpip.Error { + ip := e.addIPHeader(r, &pkt.Header, pkt.Data.Size(), params) + pkt.NetworkHeader = buffer.View(ip) + + // iptables filtering. All packets that reach here are locally + // generated. + ipt := e.stack.IPTables() + if ok := ipt.Check(stack.Output, &pkt, gso, r, ""); !ok { + // iptables is telling us to drop the packet. + return nil + } + + if pkt.NatDone { + // If the packet is manipulated as per NAT Ouput rules, handle packet + // based on destination address and do not send the packet to link layer. + netHeader := header.IPv4(pkt.NetworkHeader) + ep, err := e.stack.FindNetworkEndpoint(header.IPv4ProtocolNumber, netHeader.DestinationAddress()) + if err == nil { + src := netHeader.SourceAddress() + dst := netHeader.DestinationAddress() + route := r.ReverseRoute(src, dst) + + views := make([]buffer.View, 1, 1+len(pkt.Data.Views())) + views[0] = pkt.Header.View() + views = append(views, pkt.Data.Views()...) + packet := stack.PacketBuffer{ + Data: buffer.NewVectorisedView(len(views[0])+pkt.Data.Size(), views)} + ep.HandlePacket(&route, packet) + return nil + } + } + + if r.Loop&stack.PacketLoop != 0 { + // The inbound path expects the network header to still be in + // the PacketBuffer's Data field. + views := make([]buffer.View, 1, 1+len(pkt.Data.Views())) + views[0] = pkt.Header.View() + views = append(views, pkt.Data.Views()...) + loopedR := r.MakeLoopedRoute() + + e.HandlePacket(&loopedR, stack.PacketBuffer{ + Data: buffer.NewVectorisedView(len(views[0])+pkt.Data.Size(), views), + }) + + loopedR.Release() + } + if r.Loop&stack.PacketOut == 0 { + return nil + } + if pkt.Header.UsedLength()+pkt.Data.Size() > int(e.linkEP.MTU()) && (gso == nil || gso.Type == stack.GSONone) { + return e.writePacketFragments(r, gso, int(e.linkEP.MTU()), pkt) + } + if err := e.linkEP.WritePacket(r, gso, ProtocolNumber, pkt); err != nil { + return err + } + r.Stats().IP.PacketsSent.Increment() + return nil +} + +// WritePackets implements stack.NetworkEndpoint.WritePackets. +func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.PacketBufferList, params stack.NetworkHeaderParams) (int, *tcpip.Error) { + if r.Loop&stack.PacketLoop != 0 { + panic("multiple packets in local loop") + } + if r.Loop&stack.PacketOut == 0 { + return pkts.Len(), nil + } + + for pkt := pkts.Front(); pkt != nil; { + ip := e.addIPHeader(r, &pkt.Header, pkt.Data.Size(), params) + pkt.NetworkHeader = buffer.View(ip) + pkt = pkt.Next() + } + + // iptables filtering. All packets that reach here are locally + // generated. + ipt := e.stack.IPTables() + dropped, natPkts := ipt.CheckPackets(stack.Output, pkts, gso, r) + if len(dropped) == 0 && len(natPkts) == 0 { + // Fast path: If no packets are to be dropped then we can just invoke the + // faster WritePackets API directly. + n, err := e.linkEP.WritePackets(r, gso, pkts, ProtocolNumber) + r.Stats().IP.PacketsSent.IncrementBy(uint64(n)) + return n, err + } + + // Slow Path as we are dropping some packets in the batch degrade to + // emitting one packet at a time. + n := 0 + for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() { + if _, ok := dropped[pkt]; ok { + continue + } + if _, ok := natPkts[pkt]; ok { + netHeader := header.IPv4(pkt.NetworkHeader) + ep, err := e.stack.FindNetworkEndpoint(header.IPv4ProtocolNumber, netHeader.DestinationAddress()) + if err == nil { + src := netHeader.SourceAddress() + dst := netHeader.DestinationAddress() + route := r.ReverseRoute(src, dst) + + views := make([]buffer.View, 1, 1+len(pkt.Data.Views())) + views[0] = pkt.Header.View() + views = append(views, pkt.Data.Views()...) + packet := stack.PacketBuffer{ + Data: buffer.NewVectorisedView(len(views[0])+pkt.Data.Size(), views)} + ep.HandlePacket(&route, packet) + n++ + continue + } + } + if err := e.linkEP.WritePacket(r, gso, ProtocolNumber, *pkt); err != nil { + r.Stats().IP.PacketsSent.IncrementBy(uint64(n)) + return n, err + } + n++ + } + r.Stats().IP.PacketsSent.IncrementBy(uint64(n)) + return n, nil +} + +// WriteHeaderIncludedPacket writes a packet already containing a network +// header through the given route. +func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt stack.PacketBuffer) *tcpip.Error { + // The packet already has an IP header, but there are a few required + // checks. + h, ok := pkt.Data.PullUp(header.IPv4MinimumSize) + if !ok { + return tcpip.ErrInvalidOptionValue + } + ip := header.IPv4(h) + if !ip.IsValid(pkt.Data.Size()) { + return tcpip.ErrInvalidOptionValue + } + + // Always set the total length. + ip.SetTotalLength(uint16(pkt.Data.Size())) + + // Set the source address when zero. + if ip.SourceAddress() == tcpip.Address(([]byte{0, 0, 0, 0})) { + ip.SetSourceAddress(r.LocalAddress) + } + + // Set the destination. If the packet already included a destination, + // it will be part of the route. + ip.SetDestinationAddress(r.RemoteAddress) + + // Set the packet ID when zero. + if ip.ID() == 0 { + id := uint32(0) + if pkt.Data.Size() > header.IPv4MaximumHeaderSize+8 { + // Packets of 68 bytes or less are required by RFC 791 to not be + // fragmented, so we only assign ids to larger packets. + id = atomic.AddUint32(&e.protocol.ids[hashRoute(r, 0 /* protocol */, e.protocol.hashIV)%buckets], 1) + } + ip.SetID(uint16(id)) + } + + // Always set the checksum. + ip.SetChecksum(0) + ip.SetChecksum(^ip.CalculateChecksum()) + + if r.Loop&stack.PacketLoop != 0 { + e.HandlePacket(r, pkt.Clone()) + } + if r.Loop&stack.PacketOut == 0 { + return nil + } + + r.Stats().IP.PacketsSent.Increment() + + ip = ip[:ip.HeaderLength()] + pkt.Header = buffer.NewPrependableFromView(buffer.View(ip)) + pkt.Data.TrimFront(int(ip.HeaderLength())) + return e.linkEP.WritePacket(r, nil /* gso */, ProtocolNumber, pkt) +} + +// HandlePacket is called by the link layer when new ipv4 packets arrive for +// this endpoint. +func (e *endpoint) HandlePacket(r *stack.Route, pkt stack.PacketBuffer) { + headerView, ok := pkt.Data.PullUp(header.IPv4MinimumSize) + if !ok { + r.Stats().IP.MalformedPacketsReceived.Increment() + return + } + h := header.IPv4(headerView) + if !h.IsValid(pkt.Data.Size()) { + r.Stats().IP.MalformedPacketsReceived.Increment() + return + } + pkt.NetworkHeader = headerView[:h.HeaderLength()] + + hlen := int(h.HeaderLength()) + tlen := int(h.TotalLength()) + pkt.Data.TrimFront(hlen) + pkt.Data.CapLength(tlen - hlen) + + // iptables filtering. All packets that reach here are intended for + // this machine and will not be forwarded. + ipt := e.stack.IPTables() + if ok := ipt.Check(stack.Input, &pkt, nil, nil, ""); !ok { + // iptables is telling us to drop the packet. + return + } + + more := (h.Flags() & header.IPv4FlagMoreFragments) != 0 + if more || h.FragmentOffset() != 0 { + if pkt.Data.Size() == 0 { + // Drop the packet as it's marked as a fragment but has + // no payload. + r.Stats().IP.MalformedPacketsReceived.Increment() + r.Stats().IP.MalformedFragmentsReceived.Increment() + return + } + // The packet is a fragment, let's try to reassemble it. + last := h.FragmentOffset() + uint16(pkt.Data.Size()) - 1 + // Drop the packet if the fragmentOffset is incorrect. i.e the + // combination of fragmentOffset and pkt.Data.size() causes a + // wrap around resulting in last being less than the offset. + if last < h.FragmentOffset() { + r.Stats().IP.MalformedPacketsReceived.Increment() + r.Stats().IP.MalformedFragmentsReceived.Increment() + return + } + var ready bool + var err error + pkt.Data, ready, err = e.fragmentation.Process(hash.IPv4FragmentHash(h), h.FragmentOffset(), last, more, pkt.Data) + if err != nil { + r.Stats().IP.MalformedPacketsReceived.Increment() + r.Stats().IP.MalformedFragmentsReceived.Increment() + return + } + if !ready { + return + } + } + p := h.TransportProtocol() + if p == header.ICMPv4ProtocolNumber { + headerView.CapLength(hlen) + e.handleICMP(r, pkt) + return + } + r.Stats().IP.PacketsDelivered.Increment() + e.dispatcher.DeliverTransportPacket(r, p, pkt) +} + +// Close cleans up resources associated with the endpoint. +func (e *endpoint) Close() {} + +type protocol struct { + ids []uint32 + hashIV uint32 + + // defaultTTL is the current default TTL for the protocol. Only the + // uint8 portion of it is meaningful and it must be accessed + // atomically. + defaultTTL uint32 +} + +// Number returns the ipv4 protocol number. +func (p *protocol) Number() tcpip.NetworkProtocolNumber { + return ProtocolNumber +} + +// MinimumPacketSize returns the minimum valid ipv4 packet size. +func (p *protocol) MinimumPacketSize() int { + return header.IPv4MinimumSize +} + +// DefaultPrefixLen returns the IPv4 default prefix length. +func (p *protocol) DefaultPrefixLen() int { + return header.IPv4AddressSize * 8 +} + +// ParseAddresses implements NetworkProtocol.ParseAddresses. +func (*protocol) ParseAddresses(v buffer.View) (src, dst tcpip.Address) { + h := header.IPv4(v) + return h.SourceAddress(), h.DestinationAddress() +} + +// SetOption implements NetworkProtocol.SetOption. +func (p *protocol) SetOption(option interface{}) *tcpip.Error { + switch v := option.(type) { + case tcpip.DefaultTTLOption: + p.SetDefaultTTL(uint8(v)) + return nil + default: + return tcpip.ErrUnknownProtocolOption + } +} + +// Option implements NetworkProtocol.Option. +func (p *protocol) Option(option interface{}) *tcpip.Error { + switch v := option.(type) { + case *tcpip.DefaultTTLOption: + *v = tcpip.DefaultTTLOption(p.DefaultTTL()) + return nil + default: + return tcpip.ErrUnknownProtocolOption + } +} + +// SetDefaultTTL sets the default TTL for endpoints created with this protocol. +func (p *protocol) SetDefaultTTL(ttl uint8) { + atomic.StoreUint32(&p.defaultTTL, uint32(ttl)) +} + +// DefaultTTL returns the default TTL for endpoints created with this protocol. +func (p *protocol) DefaultTTL() uint8 { + return uint8(atomic.LoadUint32(&p.defaultTTL)) +} + +// Close implements stack.TransportProtocol.Close. +func (*protocol) Close() {} + +// Wait implements stack.TransportProtocol.Wait. +func (*protocol) Wait() {} + +// calculateMTU calculates the network-layer payload MTU based on the link-layer +// payload mtu. +func calculateMTU(mtu uint32) uint32 { + if mtu > MaxTotalSize { + mtu = MaxTotalSize + } + return mtu - header.IPv4MinimumSize +} + +// hashRoute calculates a hash value for the given route. It uses the source & +// destination address, the transport protocol number, and a random initial +// value (generated once on initialization) to generate the hash. +func hashRoute(r *stack.Route, protocol tcpip.TransportProtocolNumber, hashIV uint32) uint32 { + t := r.LocalAddress + a := uint32(t[0]) | uint32(t[1])<<8 | uint32(t[2])<<16 | uint32(t[3])<<24 + t = r.RemoteAddress + b := uint32(t[0]) | uint32(t[1])<<8 | uint32(t[2])<<16 | uint32(t[3])<<24 + return hash.Hash3Words(a, b, uint32(protocol), hashIV) +} + +// NewProtocol returns an IPv4 network protocol. +func NewProtocol() stack.NetworkProtocol { + ids := make([]uint32, buckets) + + // Randomly initialize hashIV and the ids. + r := hash.RandN32(1 + buckets) + for i := range ids { + ids[i] = r[i] + } + hashIV := r[buckets] + + return &protocol{ids: ids, hashIV: hashIV, defaultTTL: DefaultTTL} +} diff --git a/pkg/tcpip/network/ipv4/ipv4_test.go b/pkg/tcpip/network/ipv4/ipv4_test.go new file mode 100644 index 000000000..5a864d832 --- /dev/null +++ b/pkg/tcpip/network/ipv4/ipv4_test.go @@ -0,0 +1,475 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package ipv4_test + +import ( + "bytes" + "encoding/hex" + "math/rand" + "testing" + + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/link/channel" + "gvisor.dev/gvisor/pkg/tcpip/link/sniffer" + "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" + "gvisor.dev/gvisor/pkg/tcpip/stack" + "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" + "gvisor.dev/gvisor/pkg/tcpip/transport/udp" + "gvisor.dev/gvisor/pkg/waiter" +) + +func TestExcludeBroadcast(t *testing.T) { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol()}, + TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()}, + }) + + const defaultMTU = 65536 + ep := stack.LinkEndpoint(channel.New(256, defaultMTU, "")) + if testing.Verbose() { + ep = sniffer.New(ep) + } + if err := s.CreateNIC(1, ep); err != nil { + t.Fatalf("CreateNIC failed: %v", err) + } + + s.SetRouteTable([]tcpip.Route{{ + Destination: header.IPv4EmptySubnet, + NIC: 1, + }}) + + randomAddr := tcpip.FullAddress{NIC: 1, Addr: "\x0a\x00\x00\x01", Port: 53} + + var wq waiter.Queue + t.Run("WithoutPrimaryAddress", func(t *testing.T) { + ep, err := s.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, &wq) + if err != nil { + t.Fatal(err) + } + defer ep.Close() + + // Cannot connect using a broadcast address as the source. + if err := ep.Connect(randomAddr); err != tcpip.ErrNoRoute { + t.Errorf("got ep.Connect(...) = %v, want = %v", err, tcpip.ErrNoRoute) + } + + // However, we can bind to a broadcast address to listen. + if err := ep.Bind(tcpip.FullAddress{Addr: header.IPv4Broadcast, Port: 53, NIC: 1}); err != nil { + t.Errorf("Bind failed: %v", err) + } + }) + + t.Run("WithPrimaryAddress", func(t *testing.T) { + ep, err := s.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, &wq) + if err != nil { + t.Fatal(err) + } + defer ep.Close() + + // Add a valid primary endpoint address, now we can connect. + if err := s.AddAddress(1, ipv4.ProtocolNumber, "\x0a\x00\x00\x02"); err != nil { + t.Fatalf("AddAddress failed: %v", err) + } + if err := ep.Connect(randomAddr); err != nil { + t.Errorf("Connect failed: %v", err) + } + }) +} + +// makeHdrAndPayload generates a randomize packet. hdrLength indicates how much +// data should already be in the header before WritePacket. extraLength +// indicates how much extra space should be in the header. The payload is made +// from many Views of the sizes listed in viewSizes. +func makeHdrAndPayload(hdrLength int, extraLength int, viewSizes []int) (buffer.Prependable, buffer.VectorisedView) { + hdr := buffer.NewPrependable(hdrLength + extraLength) + hdr.Prepend(hdrLength) + rand.Read(hdr.View()) + + var views []buffer.View + totalLength := 0 + for _, s := range viewSizes { + newView := buffer.NewView(s) + rand.Read(newView) + views = append(views, newView) + totalLength += s + } + payload := buffer.NewVectorisedView(totalLength, views) + return hdr, payload +} + +// comparePayloads compared the contents of all the packets against the contents +// of the source packet. +func compareFragments(t *testing.T, packets []stack.PacketBuffer, sourcePacketInfo stack.PacketBuffer, mtu uint32) { + t.Helper() + // Make a complete array of the sourcePacketInfo packet. + source := header.IPv4(packets[0].Header.View()[:header.IPv4MinimumSize]) + source = append(source, sourcePacketInfo.Header.View()...) + source = append(source, sourcePacketInfo.Data.ToView()...) + + // Make a copy of the IP header, which will be modified in some fields to make + // an expected header. + sourceCopy := header.IPv4(append(buffer.View(nil), source[:source.HeaderLength()]...)) + sourceCopy.SetChecksum(0) + sourceCopy.SetFlagsFragmentOffset(0, 0) + sourceCopy.SetTotalLength(0) + var offset uint16 + // Build up an array of the bytes sent. + var reassembledPayload []byte + for i, packet := range packets { + // Confirm that the packet is valid. + allBytes := packet.Header.View().ToVectorisedView() + allBytes.Append(packet.Data) + ip := header.IPv4(allBytes.ToView()) + if !ip.IsValid(len(ip)) { + t.Errorf("IP packet is invalid:\n%s", hex.Dump(ip)) + } + if got, want := ip.CalculateChecksum(), uint16(0xffff); got != want { + t.Errorf("ip.CalculateChecksum() got %#x, want %#x", got, want) + } + if got, want := len(ip), int(mtu); got > want { + t.Errorf("fragment is too large, got %d want %d", got, want) + } + if got, want := packet.Header.UsedLength(), sourcePacketInfo.Header.UsedLength()+header.IPv4MinimumSize; i == 0 && want < int(mtu) && got != want { + t.Errorf("first fragment hdr parts should have unmodified length if possible: got %d, want %d", got, want) + } + if got, want := packet.Header.AvailableLength(), sourcePacketInfo.Header.AvailableLength()-header.IPv4MinimumSize; got != want { + t.Errorf("fragment #%d should have the same available space for prepending as source: got %d, want %d", i, got, want) + } + if i < len(packets)-1 { + sourceCopy.SetFlagsFragmentOffset(sourceCopy.Flags()|header.IPv4FlagMoreFragments, offset) + } else { + sourceCopy.SetFlagsFragmentOffset(sourceCopy.Flags()&^header.IPv4FlagMoreFragments, offset) + } + reassembledPayload = append(reassembledPayload, ip.Payload()...) + offset += ip.TotalLength() - uint16(ip.HeaderLength()) + // Clear out the checksum and length from the ip because we can't compare + // it. + sourceCopy.SetTotalLength(uint16(len(ip))) + sourceCopy.SetChecksum(0) + sourceCopy.SetChecksum(^sourceCopy.CalculateChecksum()) + if !bytes.Equal(ip[:ip.HeaderLength()], sourceCopy[:sourceCopy.HeaderLength()]) { + t.Errorf("ip[:ip.HeaderLength()] got:\n%s\nwant:\n%s", hex.Dump(ip[:ip.HeaderLength()]), hex.Dump(sourceCopy[:sourceCopy.HeaderLength()])) + } + } + expected := source[source.HeaderLength():] + if !bytes.Equal(reassembledPayload, expected) { + t.Errorf("reassembledPayload got:\n%s\nwant:\n%s", hex.Dump(reassembledPayload), hex.Dump(expected)) + } +} + +type errorChannel struct { + *channel.Endpoint + Ch chan stack.PacketBuffer + packetCollectorErrors []*tcpip.Error +} + +// newErrorChannel creates a new errorChannel endpoint. Each call to WritePacket +// will return successive errors from packetCollectorErrors until the list is +// empty and then return nil each time. +func newErrorChannel(size int, mtu uint32, linkAddr tcpip.LinkAddress, packetCollectorErrors []*tcpip.Error) *errorChannel { + return &errorChannel{ + Endpoint: channel.New(size, mtu, linkAddr), + Ch: make(chan stack.PacketBuffer, size), + packetCollectorErrors: packetCollectorErrors, + } +} + +// Drain removes all outbound packets from the channel and counts them. +func (e *errorChannel) Drain() int { + c := 0 + for { + select { + case <-e.Ch: + c++ + default: + return c + } + } +} + +// WritePacket stores outbound packets into the channel. +func (e *errorChannel) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt stack.PacketBuffer) *tcpip.Error { + select { + case e.Ch <- pkt: + default: + } + + nextError := (*tcpip.Error)(nil) + if len(e.packetCollectorErrors) > 0 { + nextError = e.packetCollectorErrors[0] + e.packetCollectorErrors = e.packetCollectorErrors[1:] + } + return nextError +} + +type context struct { + stack.Route + linkEP *errorChannel +} + +func buildContext(t *testing.T, packetCollectorErrors []*tcpip.Error, mtu uint32) context { + // Make the packet and write it. + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol()}, + }) + ep := newErrorChannel(100 /* Enough for all tests. */, mtu, "", packetCollectorErrors) + s.CreateNIC(1, ep) + const ( + src = "\x10\x00\x00\x01" + dst = "\x10\x00\x00\x02" + ) + s.AddAddress(1, ipv4.ProtocolNumber, src) + { + subnet, err := tcpip.NewSubnet(dst, tcpip.AddressMask(header.IPv4Broadcast)) + if err != nil { + t.Fatal(err) + } + s.SetRouteTable([]tcpip.Route{{ + Destination: subnet, + NIC: 1, + }}) + } + r, err := s.FindRoute(0, src, dst, ipv4.ProtocolNumber, false /* multicastLoop */) + if err != nil { + t.Fatalf("s.FindRoute got %v, want %v", err, nil) + } + return context{ + Route: r, + linkEP: ep, + } +} + +func TestFragmentation(t *testing.T) { + var manyPayloadViewsSizes [1000]int + for i := range manyPayloadViewsSizes { + manyPayloadViewsSizes[i] = 7 + } + fragTests := []struct { + description string + mtu uint32 + gso *stack.GSO + hdrLength int + extraLength int + payloadViewsSizes []int + expectedFrags int + }{ + {"NoFragmentation", 2000, &stack.GSO{}, 0, header.IPv4MinimumSize, []int{1000}, 1}, + {"NoFragmentationWithBigHeader", 2000, &stack.GSO{}, 16, header.IPv4MinimumSize, []int{1000}, 1}, + {"Fragmented", 800, &stack.GSO{}, 0, header.IPv4MinimumSize, []int{1000}, 2}, + {"FragmentedWithGsoNil", 800, nil, 0, header.IPv4MinimumSize, []int{1000}, 2}, + {"FragmentedWithManyViews", 300, &stack.GSO{}, 0, header.IPv4MinimumSize, manyPayloadViewsSizes[:], 25}, + {"FragmentedWithManyViewsAndPrependableBytes", 300, &stack.GSO{}, 0, header.IPv4MinimumSize + 55, manyPayloadViewsSizes[:], 25}, + {"FragmentedWithBigHeader", 800, &stack.GSO{}, 20, header.IPv4MinimumSize, []int{1000}, 2}, + {"FragmentedWithBigHeaderAndPrependableBytes", 800, &stack.GSO{}, 20, header.IPv4MinimumSize + 66, []int{1000}, 2}, + {"FragmentedWithMTUSmallerThanHeaderAndPrependableBytes", 300, &stack.GSO{}, 1000, header.IPv4MinimumSize + 77, []int{500}, 6}, + } + + for _, ft := range fragTests { + t.Run(ft.description, func(t *testing.T) { + hdr, payload := makeHdrAndPayload(ft.hdrLength, ft.extraLength, ft.payloadViewsSizes) + source := stack.PacketBuffer{ + Header: hdr, + // Save the source payload because WritePacket will modify it. + Data: payload.Clone(nil), + } + c := buildContext(t, nil, ft.mtu) + err := c.Route.WritePacket(ft.gso, stack.NetworkHeaderParams{Protocol: tcp.ProtocolNumber, TTL: 42, TOS: stack.DefaultTOS}, stack.PacketBuffer{ + Header: hdr, + Data: payload, + }) + if err != nil { + t.Errorf("err got %v, want %v", err, nil) + } + + var results []stack.PacketBuffer + L: + for { + select { + case pi := <-c.linkEP.Ch: + results = append(results, pi) + default: + break L + } + } + + if got, want := len(results), ft.expectedFrags; got != want { + t.Errorf("len(result) got %d, want %d", got, want) + } + if got, want := len(results), int(c.Route.Stats().IP.PacketsSent.Value()); got != want { + t.Errorf("no errors yet len(result) got %d, want %d", got, want) + } + compareFragments(t, results, source, ft.mtu) + }) + } +} + +// TestFragmentationErrors checks that errors are returned from write packet +// correctly. +func TestFragmentationErrors(t *testing.T) { + fragTests := []struct { + description string + mtu uint32 + hdrLength int + payloadViewsSizes []int + packetCollectorErrors []*tcpip.Error + }{ + {"NoFrag", 2000, 0, []int{1000}, []*tcpip.Error{tcpip.ErrAborted}}, + {"ErrorOnFirstFrag", 500, 0, []int{1000}, []*tcpip.Error{tcpip.ErrAborted}}, + {"ErrorOnSecondFrag", 500, 0, []int{1000}, []*tcpip.Error{nil, tcpip.ErrAborted}}, + {"ErrorOnFirstFragMTUSmallerThanHdr", 500, 1000, []int{500}, []*tcpip.Error{tcpip.ErrAborted}}, + } + + for _, ft := range fragTests { + t.Run(ft.description, func(t *testing.T) { + hdr, payload := makeHdrAndPayload(ft.hdrLength, header.IPv4MinimumSize, ft.payloadViewsSizes) + c := buildContext(t, ft.packetCollectorErrors, ft.mtu) + err := c.Route.WritePacket(&stack.GSO{}, stack.NetworkHeaderParams{Protocol: tcp.ProtocolNumber, TTL: 42, TOS: stack.DefaultTOS}, stack.PacketBuffer{ + Header: hdr, + Data: payload, + }) + for i := 0; i < len(ft.packetCollectorErrors)-1; i++ { + if got, want := ft.packetCollectorErrors[i], (*tcpip.Error)(nil); got != want { + t.Errorf("ft.packetCollectorErrors[%d] got %v, want %v", i, got, want) + } + } + // We only need to check that last error because all the ones before are + // nil. + if got, want := err, ft.packetCollectorErrors[len(ft.packetCollectorErrors)-1]; got != want { + t.Errorf("err got %v, want %v", got, want) + } + if got, want := c.linkEP.Drain(), int(c.Route.Stats().IP.PacketsSent.Value())+1; err != nil && got != want { + t.Errorf("after linkEP error len(result) got %d, want %d", got, want) + } + }) + } +} + +func TestInvalidFragments(t *testing.T) { + // These packets have both IHL and TotalLength set to 0. + testCases := []struct { + name string + packets [][]byte + wantMalformedIPPackets uint64 + wantMalformedFragments uint64 + }{ + { + "ihl_totallen_zero_valid_frag_offset", + [][]byte{ + {0x40, 0x30, 0x00, 0x00, 0x6c, 0x74, 0x7d, 0x30, 0x30, 0x30, 0x30, 0x30, 0x39, 0x32, 0x39, 0x33, 0xff, 0xff, 0xff, 0xff, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30}, + }, + 1, + 0, + }, + { + "ihl_totallen_zero_invalid_frag_offset", + [][]byte{ + {0x40, 0x30, 0x00, 0x00, 0x6c, 0x74, 0x20, 0x00, 0x30, 0x30, 0x30, 0x30, 0x39, 0x32, 0x39, 0x33, 0xff, 0xff, 0xff, 0xff, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30}, + }, + 1, + 0, + }, + { + // Total Length of 37(20 bytes IP header + 17 bytes of + // payload) + // Frag Offset of 0x1ffe = 8190*8 = 65520 + // Leading to the fragment end to be past 65535. + "ihl_totallen_valid_invalid_frag_offset_1", + [][]byte{ + {0x45, 0x30, 0x00, 0x25, 0x6c, 0x74, 0x1f, 0xfe, 0x30, 0x30, 0x30, 0x30, 0x39, 0x32, 0x39, 0x33, 0xff, 0xff, 0xff, 0xff, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30}, + }, + 1, + 1, + }, + // The following 3 tests were found by running a fuzzer and were + // triggering a panic in the IPv4 reassembler code. + { + "ihl_less_than_ipv4_minimum_size_1", + [][]byte{ + {0x42, 0x30, 0x0, 0x30, 0x30, 0x40, 0x0, 0xf3, 0x30, 0x1, 0x30, 0x30, 0x73, 0x73, 0x69, 0x6e, 0xff, 0xff, 0xff, 0xff, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30}, + {0x42, 0x30, 0x0, 0x8, 0x30, 0x40, 0x20, 0x0, 0x30, 0x1, 0x30, 0x30, 0x73, 0x73, 0x69, 0x6e, 0xff, 0xff, 0xff, 0xff, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30}, + }, + 2, + 0, + }, + { + "ihl_less_than_ipv4_minimum_size_2", + [][]byte{ + {0x42, 0x30, 0x0, 0x30, 0x30, 0x40, 0xb3, 0x12, 0x30, 0x6, 0x30, 0x30, 0x73, 0x73, 0x69, 0x6e, 0xff, 0xff, 0xff, 0xff, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30}, + {0x42, 0x30, 0x0, 0x8, 0x30, 0x40, 0x20, 0x0, 0x30, 0x6, 0x30, 0x30, 0x73, 0x73, 0x69, 0x6e, 0xff, 0xff, 0xff, 0xff, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30}, + }, + 2, + 0, + }, + { + "ihl_less_than_ipv4_minimum_size_3", + [][]byte{ + {0x42, 0x30, 0x0, 0x30, 0x30, 0x40, 0xb3, 0x30, 0x30, 0x6, 0x30, 0x30, 0x73, 0x73, 0x69, 0x6e, 0xff, 0xff, 0xff, 0xff, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30}, + {0x42, 0x30, 0x0, 0x8, 0x30, 0x40, 0x20, 0x0, 0x30, 0x6, 0x30, 0x30, 0x73, 0x73, 0x69, 0x6e, 0xff, 0xff, 0xff, 0xff, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30}, + }, + 2, + 0, + }, + { + "fragment_with_short_total_len_extra_payload", + [][]byte{ + {0x46, 0x30, 0x00, 0x30, 0x30, 0x40, 0x0e, 0x12, 0x30, 0x06, 0x30, 0x30, 0x73, 0x73, 0x69, 0x6e, 0xff, 0xff, 0xff, 0xff, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30}, + {0x46, 0x30, 0x00, 0x18, 0x30, 0x40, 0x20, 0x00, 0x30, 0x06, 0x30, 0x30, 0x73, 0x73, 0x69, 0x6e, 0xff, 0xff, 0xff, 0xff, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30}, + }, + 1, + 1, + }, + { + "multiple_fragments_with_more_fragments_set_to_false", + [][]byte{ + {0x45, 0x00, 0x00, 0x1c, 0x30, 0x40, 0x00, 0x10, 0x00, 0x06, 0x34, 0x69, 0x73, 0x73, 0x69, 0x6e, 0xff, 0xff, 0xff, 0xff, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}, + {0x45, 0x00, 0x00, 0x1c, 0x30, 0x40, 0x00, 0x01, 0x61, 0x06, 0x34, 0x69, 0x73, 0x73, 0x69, 0x6e, 0xff, 0xff, 0xff, 0xff, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}, + {0x45, 0x00, 0x00, 0x1c, 0x30, 0x40, 0x20, 0x00, 0x00, 0x06, 0x34, 0x1e, 0x73, 0x73, 0x69, 0x6e, 0xff, 0xff, 0xff, 0xff, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}, + }, + 1, + 1, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + const nicID tcpip.NICID = 42 + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ + ipv4.NewProtocol(), + }, + }) + + var linkAddr = tcpip.LinkAddress([]byte{0x30, 0x30, 0x30, 0x30, 0x30, 0x30}) + var remoteLinkAddr = tcpip.LinkAddress([]byte{0x30, 0x30, 0x30, 0x30, 0x30, 0x31}) + ep := channel.New(10, 1500, linkAddr) + s.CreateNIC(nicID, sniffer.New(ep)) + + for _, pkt := range tc.packets { + ep.InjectLinkAddr(header.IPv4ProtocolNumber, remoteLinkAddr, stack.PacketBuffer{ + Data: buffer.NewVectorisedView(len(pkt), []buffer.View{pkt}), + }) + } + + if got, want := s.Stats().IP.MalformedPacketsReceived.Value(), tc.wantMalformedIPPackets; got != want { + t.Errorf("incorrect Stats.IP.MalformedPacketsReceived, got: %d, want: %d", got, want) + } + if got, want := s.Stats().IP.MalformedFragmentsReceived.Value(), tc.wantMalformedFragments; got != want { + t.Errorf("incorrect Stats.IP.MalformedFragmentsReceived, got: %d, want: %d", got, want) + } + }) + } +} diff --git a/pkg/tcpip/network/ipv6/BUILD b/pkg/tcpip/network/ipv6/BUILD new file mode 100644 index 000000000..3f71fc520 --- /dev/null +++ b/pkg/tcpip/network/ipv6/BUILD @@ -0,0 +1,44 @@ +load("//tools:defs.bzl", "go_library", "go_test") + +package(licenses = ["notice"]) + +go_library( + name = "ipv6", + srcs = [ + "icmp.go", + "ipv6.go", + ], + visibility = ["//visibility:public"], + deps = [ + "//pkg/tcpip", + "//pkg/tcpip/buffer", + "//pkg/tcpip/header", + "//pkg/tcpip/network/fragmentation", + "//pkg/tcpip/network/hash", + "//pkg/tcpip/stack", + ], +) + +go_test( + name = "ipv6_test", + size = "small", + srcs = [ + "icmp_test.go", + "ipv6_test.go", + "ndp_test.go", + ], + library = ":ipv6", + deps = [ + "//pkg/tcpip", + "//pkg/tcpip/buffer", + "//pkg/tcpip/checker", + "//pkg/tcpip/header", + "//pkg/tcpip/link/channel", + "//pkg/tcpip/link/sniffer", + "//pkg/tcpip/stack", + "//pkg/tcpip/transport/icmp", + "//pkg/tcpip/transport/udp", + "//pkg/waiter", + "@com_github_google_go-cmp//cmp:go_default_library", + ], +) diff --git a/pkg/tcpip/network/ipv6/icmp.go b/pkg/tcpip/network/ipv6/icmp.go new file mode 100644 index 000000000..bdf3a0d25 --- /dev/null +++ b/pkg/tcpip/network/ipv6/icmp.go @@ -0,0 +1,546 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package ipv6 + +import ( + "fmt" + + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/stack" +) + +// handleControl handles the case when an ICMP packet contains the headers of +// the original packet that caused the ICMP one to be sent. This information is +// used to find out which transport endpoint must be notified about the ICMP +// packet. +func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, pkt stack.PacketBuffer) { + h, ok := pkt.Data.PullUp(header.IPv6MinimumSize) + if !ok { + return + } + hdr := header.IPv6(h) + + // We don't use IsValid() here because ICMP only requires that up to + // 1280 bytes of the original packet be included. So it's likely that it + // is truncated, which would cause IsValid to return false. + // + // Drop packet if it doesn't have the basic IPv6 header or if the + // original source address doesn't match the endpoint's address. + if hdr.SourceAddress() != e.id.LocalAddress { + return + } + + // Skip the IP header, then handle the fragmentation header if there + // is one. + pkt.Data.TrimFront(header.IPv6MinimumSize) + p := hdr.TransportProtocol() + if p == header.IPv6FragmentHeader { + f, ok := pkt.Data.PullUp(header.IPv6FragmentHeaderSize) + if !ok { + return + } + fragHdr := header.IPv6Fragment(f) + if !fragHdr.IsValid() || fragHdr.FragmentOffset() != 0 { + // We can't handle fragments that aren't at offset 0 + // because they don't have the transport headers. + return + } + + // Skip fragmentation header and find out the actual protocol + // number. + pkt.Data.TrimFront(header.IPv6FragmentHeaderSize) + p = fragHdr.TransportProtocol() + } + + // Deliver the control packet to the transport endpoint. + e.dispatcher.DeliverTransportControlPacket(e.id.LocalAddress, hdr.DestinationAddress(), ProtocolNumber, p, typ, extra, pkt) +} + +func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, pkt stack.PacketBuffer, hasFragmentHeader bool) { + stats := r.Stats().ICMP + sent := stats.V6PacketsSent + received := stats.V6PacketsReceived + v, ok := pkt.Data.PullUp(header.ICMPv6HeaderSize) + if !ok { + received.Invalid.Increment() + return + } + h := header.ICMPv6(v) + iph := header.IPv6(netHeader) + + // Validate ICMPv6 checksum before processing the packet. + // + // This copy is used as extra payload during the checksum calculation. + payload := pkt.Data.Clone(nil) + payload.TrimFront(len(h)) + if got, want := h.Checksum(), header.ICMPv6Checksum(h, iph.SourceAddress(), iph.DestinationAddress(), payload); got != want { + received.Invalid.Increment() + return + } + + isNDPValid := func() bool { + // As per RFC 4861 sections 4.1 - 4.5, 6.1.1, 6.1.2, 7.1.1, 7.1.2 and + // 8.1, nodes MUST silently drop NDP packets where the Hop Limit field + // in the IPv6 header is not set to 255, or the ICMPv6 Code field is not + // set to 0. + // + // As per RFC 6980 section 5, nodes MUST silently drop NDP messages if the + // packet includes a fragmentation header. + return !hasFragmentHeader && iph.HopLimit() == header.NDPHopLimit && h.Code() == 0 + } + + // TODO(b/112892170): Meaningfully handle all ICMP types. + switch h.Type() { + case header.ICMPv6PacketTooBig: + received.PacketTooBig.Increment() + hdr, ok := pkt.Data.PullUp(header.ICMPv6PacketTooBigMinimumSize) + if !ok { + received.Invalid.Increment() + return + } + pkt.Data.TrimFront(header.ICMPv6PacketTooBigMinimumSize) + mtu := header.ICMPv6(hdr).MTU() + e.handleControl(stack.ControlPacketTooBig, calculateMTU(mtu), pkt) + + case header.ICMPv6DstUnreachable: + received.DstUnreachable.Increment() + hdr, ok := pkt.Data.PullUp(header.ICMPv6DstUnreachableMinimumSize) + if !ok { + received.Invalid.Increment() + return + } + pkt.Data.TrimFront(header.ICMPv6DstUnreachableMinimumSize) + switch header.ICMPv6(hdr).Code() { + case header.ICMPv6PortUnreachable: + e.handleControl(stack.ControlPortUnreachable, 0, pkt) + } + + case header.ICMPv6NeighborSolicit: + received.NeighborSolicit.Increment() + if pkt.Data.Size() < header.ICMPv6NeighborSolicitMinimumSize || !isNDPValid() { + received.Invalid.Increment() + return + } + + // The remainder of payload must be only the neighbor solicitation, so + // payload.ToView() always returns the solicitation. Per RFC 6980 section 5, + // NDP messages cannot be fragmented. Also note that in the common case NDP + // datagrams are very small and ToView() will not incur allocations. + ns := header.NDPNeighborSolicit(payload.ToView()) + it, err := ns.Options().Iter(true) + if err != nil { + // If we have a malformed NDP NS option, drop the packet. + received.Invalid.Increment() + return + } + + targetAddr := ns.TargetAddress() + s := r.Stack() + if isTentative, err := s.IsAddrTentative(e.nicID, targetAddr); err != nil { + // We will only get an error if the NIC is unrecognized, which should not + // happen. For now, drop this packet. + // + // TODO(b/141002840): Handle this better? + return + } else if isTentative { + // If the target address is tentative and the source of the packet is a + // unicast (specified) address, then the source of the packet is + // attempting to perform address resolution on the target. In this case, + // the solicitation is silently ignored, as per RFC 4862 section 5.4.3. + // + // If the target address is tentative and the source of the packet is the + // unspecified address (::), then we know another node is also performing + // DAD for the same address (since the target address is tentative for us, + // we know we are also performing DAD on it). In this case we let the + // stack know so it can handle such a scenario and do nothing further with + // the NS. + if r.RemoteAddress == header.IPv6Any { + s.DupTentativeAddrDetected(e.nicID, targetAddr) + } + + // Do not handle neighbor solicitations targeted to an address that is + // tentative on the NIC any further. + return + } + + // At this point we know that the target address is not tentative on the NIC + // so the packet is processed as defined in RFC 4861, as per RFC 4862 + // section 5.4.3. + + // Is the NS targetting us? + if e.linkAddrCache.CheckLocalAddress(e.nicID, ProtocolNumber, targetAddr) == 0 { + return + } + + // If the NS message contains the Source Link-Layer Address option, update + // the link address cache with the value of the option. + // + // TODO(b/148429853): Properly process the NS message and do Neighbor + // Unreachability Detection. + var sourceLinkAddr tcpip.LinkAddress + for { + opt, done, err := it.Next() + if err != nil { + // This should never happen as Iter(true) above did not return an error. + panic(fmt.Sprintf("unexpected error when iterating over NDP options: %s", err)) + } + if done { + break + } + + switch opt := opt.(type) { + case header.NDPSourceLinkLayerAddressOption: + // No RFCs define what to do when an NS message has multiple Source + // Link-Layer Address options. Since no interface can have multiple + // link-layer addresses, we consider such messages invalid. + if len(sourceLinkAddr) != 0 { + received.Invalid.Increment() + return + } + + sourceLinkAddr = opt.EthernetAddress() + } + } + + unspecifiedSource := r.RemoteAddress == header.IPv6Any + + // As per RFC 4861 section 4.3, the Source Link-Layer Address Option MUST + // NOT be included when the source IP address is the unspecified address. + // Otherwise, on link layers that have addresses this option MUST be + // included in multicast solicitations and SHOULD be included in unicast + // solicitations. + if len(sourceLinkAddr) == 0 { + if header.IsV6MulticastAddress(r.LocalAddress) && !unspecifiedSource { + received.Invalid.Increment() + return + } + } else if unspecifiedSource { + received.Invalid.Increment() + return + } else { + e.linkAddrCache.AddLinkAddress(e.nicID, r.RemoteAddress, sourceLinkAddr) + } + + // ICMPv6 Neighbor Solicit messages are always sent to + // specially crafted IPv6 multicast addresses. As a result, the + // route we end up with here has as its LocalAddress such a + // multicast address. It would be nonsense to claim that our + // source address is a multicast address, so we manually set + // the source address to the target address requested in the + // solicit message. Since that requires mutating the route, we + // must first clone it. + r := r.Clone() + defer r.Release() + r.LocalAddress = targetAddr + + // As per RFC 4861 section 7.2.4, if the the source of the solicitation is + // the unspecified address, the node MUST set the Solicited flag to zero and + // multicast the advertisement to the all-nodes address. + solicited := true + if unspecifiedSource { + solicited = false + r.RemoteAddress = header.IPv6AllNodesMulticastAddress + } + + // If the NS has a source link-layer option, use the link address it + // specifies as the remote link address for the response instead of the + // source link address of the packet. + // + // TODO(#2401): As per RFC 4861 section 7.2.4 we should consult our link + // address cache for the right destination link address instead of manually + // patching the route with the remote link address if one is specified in a + // Source Link-Layer Address option. + if len(sourceLinkAddr) != 0 { + r.RemoteLinkAddress = sourceLinkAddr + } + + optsSerializer := header.NDPOptionsSerializer{ + header.NDPTargetLinkLayerAddressOption(r.LocalLinkAddress), + } + hdr := buffer.NewPrependable(int(r.MaxHeaderLength()) + header.ICMPv6NeighborAdvertMinimumSize + int(optsSerializer.Length())) + packet := header.ICMPv6(hdr.Prepend(header.ICMPv6NeighborAdvertSize)) + packet.SetType(header.ICMPv6NeighborAdvert) + na := header.NDPNeighborAdvert(packet.NDPPayload()) + na.SetSolicitedFlag(solicited) + na.SetOverrideFlag(true) + na.SetTargetAddress(targetAddr) + opts := na.Options() + opts.Serialize(optsSerializer) + packet.SetChecksum(header.ICMPv6Checksum(packet, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{})) + + // RFC 4861 Neighbor Discovery for IP version 6 (IPv6) + // + // 7.1.2. Validation of Neighbor Advertisements + // + // The IP Hop Limit field has a value of 255, i.e., the packet + // could not possibly have been forwarded by a router. + if err := r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: header.NDPHopLimit, TOS: stack.DefaultTOS}, stack.PacketBuffer{ + Header: hdr, + }); err != nil { + sent.Dropped.Increment() + return + } + sent.NeighborAdvert.Increment() + + case header.ICMPv6NeighborAdvert: + received.NeighborAdvert.Increment() + if pkt.Data.Size() < header.ICMPv6NeighborAdvertSize || !isNDPValid() { + received.Invalid.Increment() + return + } + + // The remainder of payload must be only the neighbor advertisement, so + // payload.ToView() always returns the advertisement. Per RFC 6980 section + // 5, NDP messages cannot be fragmented. Also note that in the common case + // NDP datagrams are very small and ToView() will not incur allocations. + na := header.NDPNeighborAdvert(payload.ToView()) + it, err := na.Options().Iter(true) + if err != nil { + // If we have a malformed NDP NA option, drop the packet. + received.Invalid.Increment() + return + } + + targetAddr := na.TargetAddress() + stack := r.Stack() + + if isTentative, err := stack.IsAddrTentative(e.nicID, targetAddr); err != nil { + // We will only get an error if the NIC is unrecognized, which should not + // happen. For now short-circuit this packet. + // + // TODO(b/141002840): Handle this better? + return + } else if isTentative { + // We just got an NA from a node that owns an address we are performing + // DAD on, implying the address is not unique. In this case we let the + // stack know so it can handle such a scenario and do nothing furthur with + // the NDP NA. + stack.DupTentativeAddrDetected(e.nicID, targetAddr) + return + } + + // At this point we know that the target address is not tentative on the + // NIC. However, the target address may still be assigned to the NIC but not + // tentative (it could be permanent). Such a scenario is beyond the scope of + // RFC 4862. As such, we simply ignore such a scenario for now and proceed + // as normal. + // + // TODO(b/143147598): Handle the scenario described above. Also inform the + // netstack integration that a duplicate address was detected outside of + // DAD. + + // If the NA message has the target link layer option, update the link + // address cache with the link address for the target of the message. + // + // TODO(b/148429853): Properly process the NA message and do Neighbor + // Unreachability Detection. + var targetLinkAddr tcpip.LinkAddress + for { + opt, done, err := it.Next() + if err != nil { + // This should never happen as Iter(true) above did not return an error. + panic(fmt.Sprintf("unexpected error when iterating over NDP options: %s", err)) + } + if done { + break + } + + switch opt := opt.(type) { + case header.NDPTargetLinkLayerAddressOption: + // No RFCs define what to do when an NA message has multiple Target + // Link-Layer Address options. Since no interface can have multiple + // link-layer addresses, we consider such messages invalid. + if len(targetLinkAddr) != 0 { + received.Invalid.Increment() + return + } + + targetLinkAddr = opt.EthernetAddress() + } + } + + if len(targetLinkAddr) != 0 { + e.linkAddrCache.AddLinkAddress(e.nicID, targetAddr, targetLinkAddr) + } + + case header.ICMPv6EchoRequest: + received.EchoRequest.Increment() + icmpHdr, ok := pkt.Data.PullUp(header.ICMPv6EchoMinimumSize) + if !ok { + received.Invalid.Increment() + return + } + pkt.Data.TrimFront(header.ICMPv6EchoMinimumSize) + hdr := buffer.NewPrependable(int(r.MaxHeaderLength()) + header.ICMPv6EchoMinimumSize) + packet := header.ICMPv6(hdr.Prepend(header.ICMPv6EchoMinimumSize)) + copy(packet, icmpHdr) + packet.SetType(header.ICMPv6EchoReply) + packet.SetChecksum(header.ICMPv6Checksum(packet, r.LocalAddress, r.RemoteAddress, pkt.Data)) + if err := r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: r.DefaultTTL(), TOS: stack.DefaultTOS}, stack.PacketBuffer{ + Header: hdr, + Data: pkt.Data, + }); err != nil { + sent.Dropped.Increment() + return + } + sent.EchoReply.Increment() + + case header.ICMPv6EchoReply: + received.EchoReply.Increment() + if pkt.Data.Size() < header.ICMPv6EchoMinimumSize { + received.Invalid.Increment() + return + } + e.dispatcher.DeliverTransportPacket(r, header.ICMPv6ProtocolNumber, pkt) + + case header.ICMPv6TimeExceeded: + received.TimeExceeded.Increment() + + case header.ICMPv6ParamProblem: + received.ParamProblem.Increment() + + case header.ICMPv6RouterSolicit: + received.RouterSolicit.Increment() + if !isNDPValid() { + received.Invalid.Increment() + return + } + + case header.ICMPv6RouterAdvert: + received.RouterAdvert.Increment() + + // Is the NDP payload of sufficient size to hold a Router + // Advertisement? + if pkt.Data.Size()-header.ICMPv6HeaderSize < header.NDPRAMinimumSize || !isNDPValid() { + received.Invalid.Increment() + return + } + + routerAddr := iph.SourceAddress() + + // + // Validate the RA as per RFC 4861 section 6.1.2. + // + + // Is the IP Source Address a link-local address? + if !header.IsV6LinkLocalAddress(routerAddr) { + // ...No, silently drop the packet. + received.Invalid.Increment() + return + } + + // The remainder of payload must be only the router advertisement, so + // payload.ToView() always returns the advertisement. Per RFC 6980 section + // 5, NDP messages cannot be fragmented. Also note that in the common case + // NDP datagrams are very small and ToView() will not incur allocations. + ra := header.NDPRouterAdvert(payload.ToView()) + opts := ra.Options() + + // Are options valid as per the wire format? + if _, err := opts.Iter(true); err != nil { + // ...No, silently drop the packet. + received.Invalid.Increment() + return + } + + // + // At this point, we have a valid Router Advertisement, as far + // as RFC 4861 section 6.1.2 is concerned. + // + + // Tell the NIC to handle the RA. + stack := r.Stack() + rxNICID := r.NICID() + stack.HandleNDPRA(rxNICID, routerAddr, ra) + + case header.ICMPv6RedirectMsg: + received.RedirectMsg.Increment() + if !isNDPValid() { + received.Invalid.Increment() + return + } + + default: + received.Invalid.Increment() + } +} + +const ( + ndpSolicitedFlag = 1 << 6 + ndpOverrideFlag = 1 << 5 + + ndpOptSrcLinkAddr = 1 + ndpOptDstLinkAddr = 2 + + icmpV6FlagOffset = 4 + icmpV6OptOffset = 24 + icmpV6LengthOffset = 25 +) + +var broadcastMAC = tcpip.LinkAddress([]byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff}) + +var _ stack.LinkAddressResolver = (*protocol)(nil) + +// LinkAddressProtocol implements stack.LinkAddressResolver. +func (*protocol) LinkAddressProtocol() tcpip.NetworkProtocolNumber { + return header.IPv6ProtocolNumber +} + +// LinkAddressRequest implements stack.LinkAddressResolver. +func (*protocol) LinkAddressRequest(addr, localAddr tcpip.Address, linkEP stack.LinkEndpoint) *tcpip.Error { + snaddr := header.SolicitedNodeAddr(addr) + + // TODO(b/148672031): Use stack.FindRoute instead of manually creating the + // route here. Note, we would need the nicID to do this properly so the right + // NIC (associated to linkEP) is used to send the NDP NS message. + r := &stack.Route{ + LocalAddress: localAddr, + RemoteAddress: snaddr, + RemoteLinkAddress: header.EthernetAddressFromMulticastIPv6Address(snaddr), + } + hdr := buffer.NewPrependable(int(linkEP.MaxHeaderLength()) + header.IPv6MinimumSize + header.ICMPv6NeighborAdvertSize) + pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6NeighborAdvertSize)) + pkt.SetType(header.ICMPv6NeighborSolicit) + copy(pkt[icmpV6OptOffset-len(addr):], addr) + pkt[icmpV6OptOffset] = ndpOptSrcLinkAddr + pkt[icmpV6LengthOffset] = 1 + copy(pkt[icmpV6LengthOffset+1:], linkEP.LinkAddress()) + pkt.SetChecksum(header.ICMPv6Checksum(pkt, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{})) + + length := uint16(hdr.UsedLength()) + ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) + ip.Encode(&header.IPv6Fields{ + PayloadLength: length, + NextHeader: uint8(header.ICMPv6ProtocolNumber), + HopLimit: header.NDPHopLimit, + SrcAddr: r.LocalAddress, + DstAddr: r.RemoteAddress, + }) + + // TODO(stijlist): count this in ICMP stats. + return linkEP.WritePacket(r, nil /* gso */, ProtocolNumber, stack.PacketBuffer{ + Header: hdr, + }) +} + +// ResolveStaticAddress implements stack.LinkAddressResolver. +func (*protocol) ResolveStaticAddress(addr tcpip.Address) (tcpip.LinkAddress, bool) { + if header.IsV6MulticastAddress(addr) { + return header.EthernetAddressFromMulticastIPv6Address(addr), true + } + return tcpip.LinkAddress([]byte(nil)), false +} diff --git a/pkg/tcpip/network/ipv6/icmp_test.go b/pkg/tcpip/network/ipv6/icmp_test.go new file mode 100644 index 000000000..d412ff688 --- /dev/null +++ b/pkg/tcpip/network/ipv6/icmp_test.go @@ -0,0 +1,960 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package ipv6 + +import ( + "context" + "reflect" + "strings" + "testing" + + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/link/channel" + "gvisor.dev/gvisor/pkg/tcpip/link/sniffer" + "gvisor.dev/gvisor/pkg/tcpip/stack" + "gvisor.dev/gvisor/pkg/tcpip/transport/icmp" + "gvisor.dev/gvisor/pkg/waiter" +) + +const ( + linkAddr0 = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x06") + linkAddr1 = tcpip.LinkAddress("\x0a\x0b\x0c\x0d\x0e\x0e") + linkAddr2 = tcpip.LinkAddress("\x0a\x0b\x0c\x0d\x0e\x0f") +) + +var ( + lladdr0 = header.LinkLocalAddr(linkAddr0) + lladdr1 = header.LinkLocalAddr(linkAddr1) +) + +type stubLinkEndpoint struct { + stack.LinkEndpoint +} + +func (*stubLinkEndpoint) Capabilities() stack.LinkEndpointCapabilities { + return 0 +} + +func (*stubLinkEndpoint) MaxHeaderLength() uint16 { + return 0 +} + +func (*stubLinkEndpoint) LinkAddress() tcpip.LinkAddress { + return "" +} + +func (*stubLinkEndpoint) WritePacket(*stack.Route, *stack.GSO, tcpip.NetworkProtocolNumber, stack.PacketBuffer) *tcpip.Error { + return nil +} + +func (*stubLinkEndpoint) Attach(stack.NetworkDispatcher) {} + +type stubDispatcher struct { + stack.TransportDispatcher +} + +func (*stubDispatcher) DeliverTransportPacket(*stack.Route, tcpip.TransportProtocolNumber, stack.PacketBuffer) { +} + +type stubLinkAddressCache struct { + stack.LinkAddressCache +} + +func (*stubLinkAddressCache) CheckLocalAddress(tcpip.NICID, tcpip.NetworkProtocolNumber, tcpip.Address) tcpip.NICID { + return 0 +} + +func (*stubLinkAddressCache) AddLinkAddress(tcpip.NICID, tcpip.Address, tcpip.LinkAddress) { +} + +func TestICMPCounts(t *testing.T) { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{NewProtocol()}, + TransportProtocols: []stack.TransportProtocol{icmp.NewProtocol6()}, + }) + { + if err := s.CreateNIC(1, &stubLinkEndpoint{}); err != nil { + t.Fatalf("CreateNIC(_) = %s", err) + } + if err := s.AddAddress(1, ProtocolNumber, lladdr0); err != nil { + t.Fatalf("AddAddress(_, %d, %s) = %s", ProtocolNumber, lladdr0, err) + } + } + { + subnet, err := tcpip.NewSubnet(lladdr1, tcpip.AddressMask(strings.Repeat("\xff", len(lladdr1)))) + if err != nil { + t.Fatal(err) + } + s.SetRouteTable( + []tcpip.Route{{ + Destination: subnet, + NIC: 1, + }}, + ) + } + + netProto := s.NetworkProtocolInstance(ProtocolNumber) + if netProto == nil { + t.Fatalf("cannot find protocol instance for network protocol %d", ProtocolNumber) + } + ep, err := netProto.NewEndpoint(0, tcpip.AddressWithPrefix{lladdr1, netProto.DefaultPrefixLen()}, &stubLinkAddressCache{}, &stubDispatcher{}, nil, s) + if err != nil { + t.Fatalf("NewEndpoint(_) = _, %s, want = _, nil", err) + } + + r, err := s.FindRoute(1, lladdr0, lladdr1, ProtocolNumber, false /* multicastLoop */) + if err != nil { + t.Fatalf("FindRoute(_) = _, %s, want = _, nil", err) + } + defer r.Release() + + var tllData [header.NDPLinkLayerAddressSize]byte + header.NDPOptions(tllData[:]).Serialize(header.NDPOptionsSerializer{ + header.NDPTargetLinkLayerAddressOption(linkAddr1), + }) + + types := []struct { + typ header.ICMPv6Type + size int + extraData []byte + }{ + { + typ: header.ICMPv6DstUnreachable, + size: header.ICMPv6DstUnreachableMinimumSize, + }, + { + typ: header.ICMPv6PacketTooBig, + size: header.ICMPv6PacketTooBigMinimumSize, + }, + { + typ: header.ICMPv6TimeExceeded, + size: header.ICMPv6MinimumSize, + }, + { + typ: header.ICMPv6ParamProblem, + size: header.ICMPv6MinimumSize, + }, + { + typ: header.ICMPv6EchoRequest, + size: header.ICMPv6EchoMinimumSize, + }, + { + typ: header.ICMPv6EchoReply, + size: header.ICMPv6EchoMinimumSize, + }, + { + typ: header.ICMPv6RouterSolicit, + size: header.ICMPv6MinimumSize, + }, + { + typ: header.ICMPv6RouterAdvert, + size: header.ICMPv6HeaderSize + header.NDPRAMinimumSize, + }, + { + typ: header.ICMPv6NeighborSolicit, + size: header.ICMPv6NeighborSolicitMinimumSize, + }, + { + typ: header.ICMPv6NeighborAdvert, + size: header.ICMPv6NeighborAdvertMinimumSize, + extraData: tllData[:], + }, + { + typ: header.ICMPv6RedirectMsg, + size: header.ICMPv6MinimumSize, + }, + } + + handleIPv6Payload := func(hdr buffer.Prependable) { + payloadLength := hdr.UsedLength() + ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) + ip.Encode(&header.IPv6Fields{ + PayloadLength: uint16(payloadLength), + NextHeader: uint8(header.ICMPv6ProtocolNumber), + HopLimit: header.NDPHopLimit, + SrcAddr: r.LocalAddress, + DstAddr: r.RemoteAddress, + }) + ep.HandlePacket(&r, stack.PacketBuffer{ + Data: hdr.View().ToVectorisedView(), + }) + } + + for _, typ := range types { + extraDataLen := len(typ.extraData) + hdr := buffer.NewPrependable(header.IPv6MinimumSize + typ.size + extraDataLen) + extraData := buffer.View(hdr.Prepend(extraDataLen)) + copy(extraData, typ.extraData) + pkt := header.ICMPv6(hdr.Prepend(typ.size)) + pkt.SetType(typ.typ) + pkt.SetChecksum(header.ICMPv6Checksum(pkt, r.LocalAddress, r.RemoteAddress, extraData.ToVectorisedView())) + + handleIPv6Payload(hdr) + } + + // Construct an empty ICMP packet so that + // Stats().ICMP.ICMPv6ReceivedPacketStats.Invalid is incremented. + handleIPv6Payload(buffer.NewPrependable(header.IPv6MinimumSize)) + + icmpv6Stats := s.Stats().ICMP.V6PacketsReceived + visitStats(reflect.ValueOf(&icmpv6Stats).Elem(), func(name string, s *tcpip.StatCounter) { + if got, want := s.Value(), uint64(1); got != want { + t.Errorf("got %s = %d, want = %d", name, got, want) + } + }) + if t.Failed() { + t.Logf("stats:\n%+v", s.Stats()) + } +} + +func visitStats(v reflect.Value, f func(string, *tcpip.StatCounter)) { + t := v.Type() + for i := 0; i < v.NumField(); i++ { + v := v.Field(i) + if s, ok := v.Interface().(*tcpip.StatCounter); ok { + f(t.Field(i).Name, s) + } else { + visitStats(v, f) + } + } +} + +type testContext struct { + s0 *stack.Stack + s1 *stack.Stack + + linkEP0 *channel.Endpoint + linkEP1 *channel.Endpoint +} + +type endpointWithResolutionCapability struct { + stack.LinkEndpoint +} + +func (e endpointWithResolutionCapability) Capabilities() stack.LinkEndpointCapabilities { + return e.LinkEndpoint.Capabilities() | stack.CapabilityResolutionRequired +} + +func newTestContext(t *testing.T) *testContext { + c := &testContext{ + s0: stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{NewProtocol()}, + TransportProtocols: []stack.TransportProtocol{icmp.NewProtocol6()}, + }), + s1: stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{NewProtocol()}, + TransportProtocols: []stack.TransportProtocol{icmp.NewProtocol6()}, + }), + } + + const defaultMTU = 65536 + c.linkEP0 = channel.New(256, defaultMTU, linkAddr0) + + wrappedEP0 := stack.LinkEndpoint(endpointWithResolutionCapability{LinkEndpoint: c.linkEP0}) + if testing.Verbose() { + wrappedEP0 = sniffer.New(wrappedEP0) + } + if err := c.s0.CreateNIC(1, wrappedEP0); err != nil { + t.Fatalf("CreateNIC s0: %v", err) + } + if err := c.s0.AddAddress(1, ProtocolNumber, lladdr0); err != nil { + t.Fatalf("AddAddress lladdr0: %v", err) + } + + c.linkEP1 = channel.New(256, defaultMTU, linkAddr1) + wrappedEP1 := stack.LinkEndpoint(endpointWithResolutionCapability{LinkEndpoint: c.linkEP1}) + if err := c.s1.CreateNIC(1, wrappedEP1); err != nil { + t.Fatalf("CreateNIC failed: %v", err) + } + if err := c.s1.AddAddress(1, ProtocolNumber, lladdr1); err != nil { + t.Fatalf("AddAddress lladdr1: %v", err) + } + + subnet0, err := tcpip.NewSubnet(lladdr1, tcpip.AddressMask(strings.Repeat("\xff", len(lladdr1)))) + if err != nil { + t.Fatal(err) + } + c.s0.SetRouteTable( + []tcpip.Route{{ + Destination: subnet0, + NIC: 1, + }}, + ) + subnet1, err := tcpip.NewSubnet(lladdr0, tcpip.AddressMask(strings.Repeat("\xff", len(lladdr0)))) + if err != nil { + t.Fatal(err) + } + c.s1.SetRouteTable( + []tcpip.Route{{ + Destination: subnet1, + NIC: 1, + }}, + ) + + return c +} + +func (c *testContext) cleanup() { + c.linkEP0.Close() + c.linkEP1.Close() +} + +type routeArgs struct { + src, dst *channel.Endpoint + typ header.ICMPv6Type + remoteLinkAddr tcpip.LinkAddress +} + +func routeICMPv6Packet(t *testing.T, args routeArgs, fn func(*testing.T, header.ICMPv6)) { + t.Helper() + + pi, _ := args.src.ReadContext(context.Background()) + + { + views := []buffer.View{pi.Pkt.Header.View(), pi.Pkt.Data.ToView()} + size := pi.Pkt.Header.UsedLength() + pi.Pkt.Data.Size() + vv := buffer.NewVectorisedView(size, views) + args.dst.InjectLinkAddr(pi.Proto, args.dst.LinkAddress(), stack.PacketBuffer{ + Data: vv, + }) + } + + if pi.Proto != ProtocolNumber { + t.Errorf("unexpected protocol number %d", pi.Proto) + return + } + + if len(args.remoteLinkAddr) != 0 && args.remoteLinkAddr != pi.Route.RemoteLinkAddress { + t.Errorf("got remote link address = %s, want = %s", pi.Route.RemoteLinkAddress, args.remoteLinkAddr) + } + + ipv6 := header.IPv6(pi.Pkt.Header.View()) + transProto := tcpip.TransportProtocolNumber(ipv6.NextHeader()) + if transProto != header.ICMPv6ProtocolNumber { + t.Errorf("unexpected transport protocol number %d", transProto) + return + } + icmpv6 := header.ICMPv6(ipv6.Payload()) + if got, want := icmpv6.Type(), args.typ; got != want { + t.Errorf("got ICMPv6 type = %d, want = %d", got, want) + return + } + if fn != nil { + fn(t, icmpv6) + } +} + +func TestLinkResolution(t *testing.T) { + c := newTestContext(t) + defer c.cleanup() + + r, err := c.s0.FindRoute(1, lladdr0, lladdr1, ProtocolNumber, false /* multicastLoop */) + if err != nil { + t.Fatalf("FindRoute(_) = _, %s, want = _, nil", err) + } + defer r.Release() + + hdr := buffer.NewPrependable(int(r.MaxHeaderLength()) + header.IPv6MinimumSize + header.ICMPv6EchoMinimumSize) + pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6EchoMinimumSize)) + pkt.SetType(header.ICMPv6EchoRequest) + pkt.SetChecksum(header.ICMPv6Checksum(pkt, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{})) + payload := tcpip.SlicePayload(hdr.View()) + + // We can't send our payload directly over the route because that + // doesn't provoke NDP discovery. + var wq waiter.Queue + ep, err := c.s0.NewEndpoint(header.ICMPv6ProtocolNumber, ProtocolNumber, &wq) + if err != nil { + t.Fatalf("NewEndpoint(_) = _, %s, want = _, nil", err) + } + + for { + _, resCh, err := ep.Write(payload, tcpip.WriteOptions{To: &tcpip.FullAddress{NIC: 1, Addr: lladdr1}}) + if resCh != nil { + if err != tcpip.ErrNoLinkAddress { + t.Fatalf("ep.Write(_) = _, <non-nil>, %s, want = _, <non-nil>, tcpip.ErrNoLinkAddress", err) + } + for _, args := range []routeArgs{ + {src: c.linkEP0, dst: c.linkEP1, typ: header.ICMPv6NeighborSolicit, remoteLinkAddr: header.EthernetAddressFromMulticastIPv6Address(header.SolicitedNodeAddr(lladdr1))}, + {src: c.linkEP1, dst: c.linkEP0, typ: header.ICMPv6NeighborAdvert}, + } { + routeICMPv6Packet(t, args, func(t *testing.T, icmpv6 header.ICMPv6) { + if got, want := tcpip.Address(icmpv6[8:][:16]), lladdr1; got != want { + t.Errorf("%d: got target = %s, want = %s", icmpv6.Type(), got, want) + } + }) + } + <-resCh + continue + } + if err != nil { + t.Fatalf("ep.Write(_) = _, _, %s", err) + } + break + } + + for _, args := range []routeArgs{ + {src: c.linkEP0, dst: c.linkEP1, typ: header.ICMPv6EchoRequest}, + {src: c.linkEP1, dst: c.linkEP0, typ: header.ICMPv6EchoReply}, + } { + routeICMPv6Packet(t, args, nil) + } +} + +func TestICMPChecksumValidationSimple(t *testing.T) { + var tllData [header.NDPLinkLayerAddressSize]byte + header.NDPOptions(tllData[:]).Serialize(header.NDPOptionsSerializer{ + header.NDPTargetLinkLayerAddressOption(linkAddr1), + }) + + types := []struct { + name string + typ header.ICMPv6Type + size int + extraData []byte + statCounter func(tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter + }{ + { + name: "DstUnreachable", + typ: header.ICMPv6DstUnreachable, + size: header.ICMPv6DstUnreachableMinimumSize, + statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { + return stats.DstUnreachable + }, + }, + { + name: "PacketTooBig", + typ: header.ICMPv6PacketTooBig, + size: header.ICMPv6PacketTooBigMinimumSize, + statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { + return stats.PacketTooBig + }, + }, + { + name: "TimeExceeded", + typ: header.ICMPv6TimeExceeded, + size: header.ICMPv6MinimumSize, + statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { + return stats.TimeExceeded + }, + }, + { + name: "ParamProblem", + typ: header.ICMPv6ParamProblem, + size: header.ICMPv6MinimumSize, + statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { + return stats.ParamProblem + }, + }, + { + name: "EchoRequest", + typ: header.ICMPv6EchoRequest, + size: header.ICMPv6EchoMinimumSize, + statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { + return stats.EchoRequest + }, + }, + { + name: "EchoReply", + typ: header.ICMPv6EchoReply, + size: header.ICMPv6EchoMinimumSize, + statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { + return stats.EchoReply + }, + }, + { + name: "RouterSolicit", + typ: header.ICMPv6RouterSolicit, + size: header.ICMPv6MinimumSize, + statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { + return stats.RouterSolicit + }, + }, + { + name: "RouterAdvert", + typ: header.ICMPv6RouterAdvert, + size: header.ICMPv6HeaderSize + header.NDPRAMinimumSize, + statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { + return stats.RouterAdvert + }, + }, + { + name: "NeighborSolicit", + typ: header.ICMPv6NeighborSolicit, + size: header.ICMPv6NeighborSolicitMinimumSize, + statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { + return stats.NeighborSolicit + }, + }, + { + name: "NeighborAdvert", + typ: header.ICMPv6NeighborAdvert, + size: header.ICMPv6NeighborAdvertMinimumSize, + extraData: tllData[:], + statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { + return stats.NeighborAdvert + }, + }, + { + name: "RedirectMsg", + typ: header.ICMPv6RedirectMsg, + size: header.ICMPv6MinimumSize, + statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { + return stats.RedirectMsg + }, + }, + } + + for _, typ := range types { + t.Run(typ.name, func(t *testing.T) { + e := channel.New(10, 1280, linkAddr0) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{NewProtocol()}, + }) + if err := s.CreateNIC(1, e); err != nil { + t.Fatalf("CreateNIC(_) = %s", err) + } + + if err := s.AddAddress(1, ProtocolNumber, lladdr0); err != nil { + t.Fatalf("AddAddress(_, %d, %s) = %s", ProtocolNumber, lladdr0, err) + } + { + subnet, err := tcpip.NewSubnet(lladdr1, tcpip.AddressMask(strings.Repeat("\xff", len(lladdr1)))) + if err != nil { + t.Fatal(err) + } + s.SetRouteTable( + []tcpip.Route{{ + Destination: subnet, + NIC: 1, + }}, + ) + } + + handleIPv6Payload := func(checksum bool) { + extraDataLen := len(typ.extraData) + hdr := buffer.NewPrependable(header.IPv6MinimumSize + typ.size + extraDataLen) + extraData := buffer.View(hdr.Prepend(extraDataLen)) + copy(extraData, typ.extraData) + pkt := header.ICMPv6(hdr.Prepend(typ.size)) + pkt.SetType(typ.typ) + if checksum { + pkt.SetChecksum(header.ICMPv6Checksum(pkt, lladdr1, lladdr0, extraData.ToVectorisedView())) + } + ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) + ip.Encode(&header.IPv6Fields{ + PayloadLength: uint16(typ.size + extraDataLen), + NextHeader: uint8(header.ICMPv6ProtocolNumber), + HopLimit: header.NDPHopLimit, + SrcAddr: lladdr1, + DstAddr: lladdr0, + }) + e.InjectInbound(ProtocolNumber, stack.PacketBuffer{ + Data: hdr.View().ToVectorisedView(), + }) + } + + stats := s.Stats().ICMP.V6PacketsReceived + invalid := stats.Invalid + typStat := typ.statCounter(stats) + + // Initial stat counts should be 0. + if got := invalid.Value(); got != 0 { + t.Fatalf("got invalid = %d, want = 0", got) + } + if got := typStat.Value(); got != 0 { + t.Fatalf("got %s = %d, want = 0", typ.name, got) + } + + // Without setting checksum, the incoming packet should + // be invalid. + handleIPv6Payload(false) + if got := invalid.Value(); got != 1 { + t.Fatalf("got invalid = %d, want = 1", got) + } + // Rx count of type typ.typ should not have increased. + if got := typStat.Value(); got != 0 { + t.Fatalf("got %s = %d, want = 0", typ.name, got) + } + + // When checksum is set, it should be received. + handleIPv6Payload(true) + if got := typStat.Value(); got != 1 { + t.Fatalf("got %s = %d, want = 1", typ.name, got) + } + // Invalid count should not have increased again. + if got := invalid.Value(); got != 1 { + t.Fatalf("got invalid = %d, want = 1", got) + } + }) + } +} + +func TestICMPChecksumValidationWithPayload(t *testing.T) { + const simpleBodySize = 64 + simpleBody := func(view buffer.View) { + for i := 0; i < simpleBodySize; i++ { + view[i] = uint8(i) + } + } + + const errorICMPBodySize = header.IPv6MinimumSize + simpleBodySize + errorICMPBody := func(view buffer.View) { + ip := header.IPv6(view) + ip.Encode(&header.IPv6Fields{ + PayloadLength: simpleBodySize, + NextHeader: 10, + HopLimit: 20, + SrcAddr: lladdr0, + DstAddr: lladdr1, + }) + simpleBody(view[header.IPv6MinimumSize:]) + } + + types := []struct { + name string + typ header.ICMPv6Type + size int + statCounter func(tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter + payloadSize int + payload func(buffer.View) + }{ + { + "DstUnreachable", + header.ICMPv6DstUnreachable, + header.ICMPv6DstUnreachableMinimumSize, + func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { + return stats.DstUnreachable + }, + errorICMPBodySize, + errorICMPBody, + }, + { + "PacketTooBig", + header.ICMPv6PacketTooBig, + header.ICMPv6PacketTooBigMinimumSize, + func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { + return stats.PacketTooBig + }, + errorICMPBodySize, + errorICMPBody, + }, + { + "TimeExceeded", + header.ICMPv6TimeExceeded, + header.ICMPv6MinimumSize, + func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { + return stats.TimeExceeded + }, + errorICMPBodySize, + errorICMPBody, + }, + { + "ParamProblem", + header.ICMPv6ParamProblem, + header.ICMPv6MinimumSize, + func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { + return stats.ParamProblem + }, + errorICMPBodySize, + errorICMPBody, + }, + { + "EchoRequest", + header.ICMPv6EchoRequest, + header.ICMPv6EchoMinimumSize, + func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { + return stats.EchoRequest + }, + simpleBodySize, + simpleBody, + }, + { + "EchoReply", + header.ICMPv6EchoReply, + header.ICMPv6EchoMinimumSize, + func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { + return stats.EchoReply + }, + simpleBodySize, + simpleBody, + }, + } + + for _, typ := range types { + t.Run(typ.name, func(t *testing.T) { + e := channel.New(10, 1280, linkAddr0) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{NewProtocol()}, + }) + if err := s.CreateNIC(1, e); err != nil { + t.Fatalf("CreateNIC(_) = %s", err) + } + + if err := s.AddAddress(1, ProtocolNumber, lladdr0); err != nil { + t.Fatalf("AddAddress(_, %d, %s) = %s", ProtocolNumber, lladdr0, err) + } + { + subnet, err := tcpip.NewSubnet(lladdr1, tcpip.AddressMask(strings.Repeat("\xff", len(lladdr1)))) + if err != nil { + t.Fatal(err) + } + s.SetRouteTable( + []tcpip.Route{{ + Destination: subnet, + NIC: 1, + }}, + ) + } + + handleIPv6Payload := func(typ header.ICMPv6Type, size, payloadSize int, payloadFn func(buffer.View), checksum bool) { + icmpSize := size + payloadSize + hdr := buffer.NewPrependable(header.IPv6MinimumSize + icmpSize) + pkt := header.ICMPv6(hdr.Prepend(icmpSize)) + pkt.SetType(typ) + payloadFn(pkt.Payload()) + + if checksum { + pkt.SetChecksum(header.ICMPv6Checksum(pkt, lladdr1, lladdr0, buffer.VectorisedView{})) + } + + ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) + ip.Encode(&header.IPv6Fields{ + PayloadLength: uint16(icmpSize), + NextHeader: uint8(header.ICMPv6ProtocolNumber), + HopLimit: header.NDPHopLimit, + SrcAddr: lladdr1, + DstAddr: lladdr0, + }) + e.InjectInbound(ProtocolNumber, stack.PacketBuffer{ + Data: hdr.View().ToVectorisedView(), + }) + } + + stats := s.Stats().ICMP.V6PacketsReceived + invalid := stats.Invalid + typStat := typ.statCounter(stats) + + // Initial stat counts should be 0. + if got := invalid.Value(); got != 0 { + t.Fatalf("got invalid = %d, want = 0", got) + } + if got := typStat.Value(); got != 0 { + t.Fatalf("got %s = %d, want = 0", typ.name, got) + } + + // Without setting checksum, the incoming packet should + // be invalid. + handleIPv6Payload(typ.typ, typ.size, typ.payloadSize, typ.payload, false) + if got := invalid.Value(); got != 1 { + t.Fatalf("got invalid = %d, want = 1", got) + } + // Rx count of type typ.typ should not have increased. + if got := typStat.Value(); got != 0 { + t.Fatalf("got %s = %d, want = 0", typ.name, got) + } + + // When checksum is set, it should be received. + handleIPv6Payload(typ.typ, typ.size, typ.payloadSize, typ.payload, true) + if got := typStat.Value(); got != 1 { + t.Fatalf("got %s = %d, want = 1", typ.name, got) + } + // Invalid count should not have increased again. + if got := invalid.Value(); got != 1 { + t.Fatalf("got invalid = %d, want = 1", got) + } + }) + } +} + +func TestICMPChecksumValidationWithPayloadMultipleViews(t *testing.T) { + const simpleBodySize = 64 + simpleBody := func(view buffer.View) { + for i := 0; i < simpleBodySize; i++ { + view[i] = uint8(i) + } + } + + const errorICMPBodySize = header.IPv6MinimumSize + simpleBodySize + errorICMPBody := func(view buffer.View) { + ip := header.IPv6(view) + ip.Encode(&header.IPv6Fields{ + PayloadLength: simpleBodySize, + NextHeader: 10, + HopLimit: 20, + SrcAddr: lladdr0, + DstAddr: lladdr1, + }) + simpleBody(view[header.IPv6MinimumSize:]) + } + + types := []struct { + name string + typ header.ICMPv6Type + size int + statCounter func(tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter + payloadSize int + payload func(buffer.View) + }{ + { + "DstUnreachable", + header.ICMPv6DstUnreachable, + header.ICMPv6DstUnreachableMinimumSize, + func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { + return stats.DstUnreachable + }, + errorICMPBodySize, + errorICMPBody, + }, + { + "PacketTooBig", + header.ICMPv6PacketTooBig, + header.ICMPv6PacketTooBigMinimumSize, + func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { + return stats.PacketTooBig + }, + errorICMPBodySize, + errorICMPBody, + }, + { + "TimeExceeded", + header.ICMPv6TimeExceeded, + header.ICMPv6MinimumSize, + func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { + return stats.TimeExceeded + }, + errorICMPBodySize, + errorICMPBody, + }, + { + "ParamProblem", + header.ICMPv6ParamProblem, + header.ICMPv6MinimumSize, + func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { + return stats.ParamProblem + }, + errorICMPBodySize, + errorICMPBody, + }, + { + "EchoRequest", + header.ICMPv6EchoRequest, + header.ICMPv6EchoMinimumSize, + func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { + return stats.EchoRequest + }, + simpleBodySize, + simpleBody, + }, + { + "EchoReply", + header.ICMPv6EchoReply, + header.ICMPv6EchoMinimumSize, + func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { + return stats.EchoReply + }, + simpleBodySize, + simpleBody, + }, + } + + for _, typ := range types { + t.Run(typ.name, func(t *testing.T) { + e := channel.New(10, 1280, linkAddr0) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{NewProtocol()}, + }) + if err := s.CreateNIC(1, e); err != nil { + t.Fatalf("CreateNIC(_) = %s", err) + } + + if err := s.AddAddress(1, ProtocolNumber, lladdr0); err != nil { + t.Fatalf("AddAddress(_, %d, %s) = %s", ProtocolNumber, lladdr0, err) + } + { + subnet, err := tcpip.NewSubnet(lladdr1, tcpip.AddressMask(strings.Repeat("\xff", len(lladdr1)))) + if err != nil { + t.Fatal(err) + } + s.SetRouteTable( + []tcpip.Route{{ + Destination: subnet, + NIC: 1, + }}, + ) + } + + handleIPv6Payload := func(typ header.ICMPv6Type, size, payloadSize int, payloadFn func(buffer.View), checksum bool) { + hdr := buffer.NewPrependable(header.IPv6MinimumSize + size) + pkt := header.ICMPv6(hdr.Prepend(size)) + pkt.SetType(typ) + + payload := buffer.NewView(payloadSize) + payloadFn(payload) + + if checksum { + pkt.SetChecksum(header.ICMPv6Checksum(pkt, lladdr1, lladdr0, payload.ToVectorisedView())) + } + + ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) + ip.Encode(&header.IPv6Fields{ + PayloadLength: uint16(size + payloadSize), + NextHeader: uint8(header.ICMPv6ProtocolNumber), + HopLimit: header.NDPHopLimit, + SrcAddr: lladdr1, + DstAddr: lladdr0, + }) + e.InjectInbound(ProtocolNumber, stack.PacketBuffer{ + Data: buffer.NewVectorisedView(header.IPv6MinimumSize+size+payloadSize, []buffer.View{hdr.View(), payload}), + }) + } + + stats := s.Stats().ICMP.V6PacketsReceived + invalid := stats.Invalid + typStat := typ.statCounter(stats) + + // Initial stat counts should be 0. + if got := invalid.Value(); got != 0 { + t.Fatalf("got invalid = %d, want = 0", got) + } + if got := typStat.Value(); got != 0 { + t.Fatalf("got %s = %d, want = 0", typ.name, got) + } + + // Without setting checksum, the incoming packet should + // be invalid. + handleIPv6Payload(typ.typ, typ.size, typ.payloadSize, typ.payload, false) + if got := invalid.Value(); got != 1 { + t.Fatalf("got invalid = %d, want = 1", got) + } + // Rx count of type typ.typ should not have increased. + if got := typStat.Value(); got != 0 { + t.Fatalf("got %s = %d, want = 0", typ.name, got) + } + + // When checksum is set, it should be received. + handleIPv6Payload(typ.typ, typ.size, typ.payloadSize, typ.payload, true) + if got := typStat.Value(); got != 1 { + t.Fatalf("got %s = %d, want = 1", typ.name, got) + } + // Invalid count should not have increased again. + if got := invalid.Value(); got != 1 { + t.Fatalf("got invalid = %d, want = 1", got) + } + }) + } +} diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go new file mode 100644 index 000000000..daf1fcbc6 --- /dev/null +++ b/pkg/tcpip/network/ipv6/ipv6.go @@ -0,0 +1,521 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package ipv6 contains the implementation of the ipv6 network protocol. To use +// it in the networking stack, this package must be added to the project, and +// activated on the stack by passing ipv6.NewProtocol() as one of the network +// protocols when calling stack.New(). Then endpoints can be created by passing +// ipv6.ProtocolNumber as the network protocol number when calling +// Stack.NewEndpoint(). +package ipv6 + +import ( + "fmt" + "sync/atomic" + + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/network/fragmentation" + "gvisor.dev/gvisor/pkg/tcpip/network/hash" + "gvisor.dev/gvisor/pkg/tcpip/stack" +) + +const ( + // ProtocolNumber is the ipv6 protocol number. + ProtocolNumber = header.IPv6ProtocolNumber + + // maxTotalSize is maximum size that can be encoded in the 16-bit + // PayloadLength field of the ipv6 header. + maxPayloadSize = 0xffff + + // DefaultTTL is the default hop limit for IPv6 Packets egressed by + // Netstack. + DefaultTTL = 64 +) + +type endpoint struct { + nicID tcpip.NICID + id stack.NetworkEndpointID + prefixLen int + linkEP stack.LinkEndpoint + linkAddrCache stack.LinkAddressCache + dispatcher stack.TransportDispatcher + fragmentation *fragmentation.Fragmentation + protocol *protocol +} + +// DefaultTTL is the default hop limit for this endpoint. +func (e *endpoint) DefaultTTL() uint8 { + return e.protocol.DefaultTTL() +} + +// MTU implements stack.NetworkEndpoint.MTU. It returns the link-layer MTU minus +// the network layer max header length. +func (e *endpoint) MTU() uint32 { + return calculateMTU(e.linkEP.MTU()) +} + +// NICID returns the ID of the NIC this endpoint belongs to. +func (e *endpoint) NICID() tcpip.NICID { + return e.nicID +} + +// ID returns the ipv6 endpoint ID. +func (e *endpoint) ID() *stack.NetworkEndpointID { + return &e.id +} + +// PrefixLen returns the ipv6 endpoint subnet prefix length in bits. +func (e *endpoint) PrefixLen() int { + return e.prefixLen +} + +// Capabilities implements stack.NetworkEndpoint.Capabilities. +func (e *endpoint) Capabilities() stack.LinkEndpointCapabilities { + return e.linkEP.Capabilities() +} + +// MaxHeaderLength returns the maximum length needed by ipv6 headers (and +// underlying protocols). +func (e *endpoint) MaxHeaderLength() uint16 { + return e.linkEP.MaxHeaderLength() + header.IPv6MinimumSize +} + +// GSOMaxSize returns the maximum GSO packet size. +func (e *endpoint) GSOMaxSize() uint32 { + if gso, ok := e.linkEP.(stack.GSOEndpoint); ok { + return gso.GSOMaxSize() + } + return 0 +} + +func (e *endpoint) addIPHeader(r *stack.Route, hdr *buffer.Prependable, payloadSize int, params stack.NetworkHeaderParams) header.IPv6 { + length := uint16(hdr.UsedLength() + payloadSize) + ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) + ip.Encode(&header.IPv6Fields{ + PayloadLength: length, + NextHeader: uint8(params.Protocol), + HopLimit: params.TTL, + TrafficClass: params.TOS, + SrcAddr: r.LocalAddress, + DstAddr: r.RemoteAddress, + }) + return ip +} + +// WritePacket writes a packet to the given destination address and protocol. +func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.NetworkHeaderParams, pkt stack.PacketBuffer) *tcpip.Error { + ip := e.addIPHeader(r, &pkt.Header, pkt.Data.Size(), params) + pkt.NetworkHeader = buffer.View(ip) + + if r.Loop&stack.PacketLoop != 0 { + // The inbound path expects the network header to still be in + // the PacketBuffer's Data field. + views := make([]buffer.View, 1, 1+len(pkt.Data.Views())) + views[0] = pkt.Header.View() + views = append(views, pkt.Data.Views()...) + loopedR := r.MakeLoopedRoute() + + e.HandlePacket(&loopedR, stack.PacketBuffer{ + Data: buffer.NewVectorisedView(len(views[0])+pkt.Data.Size(), views), + }) + + loopedR.Release() + } + if r.Loop&stack.PacketOut == 0 { + return nil + } + + r.Stats().IP.PacketsSent.Increment() + return e.linkEP.WritePacket(r, gso, ProtocolNumber, pkt) +} + +// WritePackets implements stack.LinkEndpoint.WritePackets. +func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.PacketBufferList, params stack.NetworkHeaderParams) (int, *tcpip.Error) { + if r.Loop&stack.PacketLoop != 0 { + panic("not implemented") + } + if r.Loop&stack.PacketOut == 0 { + return pkts.Len(), nil + } + + for pb := pkts.Front(); pb != nil; pb = pb.Next() { + ip := e.addIPHeader(r, &pb.Header, pb.Data.Size(), params) + pb.NetworkHeader = buffer.View(ip) + } + + n, err := e.linkEP.WritePackets(r, gso, pkts, ProtocolNumber) + r.Stats().IP.PacketsSent.IncrementBy(uint64(n)) + return n, err +} + +// WriteHeaderIncludedPacker implements stack.NetworkEndpoint. It is not yet +// supported by IPv6. +func (*endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt stack.PacketBuffer) *tcpip.Error { + // TODO(b/146666412): Support IPv6 header-included packets. + return tcpip.ErrNotSupported +} + +// HandlePacket is called by the link layer when new ipv6 packets arrive for +// this endpoint. +func (e *endpoint) HandlePacket(r *stack.Route, pkt stack.PacketBuffer) { + headerView, ok := pkt.Data.PullUp(header.IPv6MinimumSize) + if !ok { + r.Stats().IP.MalformedPacketsReceived.Increment() + return + } + h := header.IPv6(headerView) + if !h.IsValid(pkt.Data.Size()) { + r.Stats().IP.MalformedPacketsReceived.Increment() + return + } + + pkt.NetworkHeader = headerView[:header.IPv6MinimumSize] + pkt.Data.TrimFront(header.IPv6MinimumSize) + pkt.Data.CapLength(int(h.PayloadLength())) + + it := header.MakeIPv6PayloadIterator(header.IPv6ExtensionHeaderIdentifier(h.NextHeader()), pkt.Data) + hasFragmentHeader := false + + for firstHeader := true; ; firstHeader = false { + extHdr, done, err := it.Next() + if err != nil { + r.Stats().IP.MalformedPacketsReceived.Increment() + return + } + if done { + break + } + + switch extHdr := extHdr.(type) { + case header.IPv6HopByHopOptionsExtHdr: + // As per RFC 8200 section 4.1, the Hop By Hop extension header is + // restricted to appear immediately after an IPv6 fixed header. + // + // TODO(b/152019344): Send an ICMPv6 Parameter Problem, Code 1 + // (unrecognized next header) error in response to an extension header's + // Next Header field with the Hop By Hop extension header identifier. + if !firstHeader { + return + } + + optsIt := extHdr.Iter() + + for { + opt, done, err := optsIt.Next() + if err != nil { + r.Stats().IP.MalformedPacketsReceived.Increment() + return + } + if done { + break + } + + // We currently do not support any IPv6 Hop By Hop extension header + // options. + switch opt.UnknownAction() { + case header.IPv6OptionUnknownActionSkip: + case header.IPv6OptionUnknownActionDiscard: + return + case header.IPv6OptionUnknownActionDiscardSendICMP: + // TODO(b/152019344): Send an ICMPv6 Parameter Problem Code 2 for + // unrecognized IPv6 extension header options. + return + case header.IPv6OptionUnknownActionDiscardSendICMPNoMulticastDest: + // TODO(b/152019344): Send an ICMPv6 Parameter Problem Code 2 for + // unrecognized IPv6 extension header options. + return + default: + panic(fmt.Sprintf("unrecognized action for an unrecognized Hop By Hop extension header option = %d", opt)) + } + } + + case header.IPv6RoutingExtHdr: + // As per RFC 8200 section 4.4, if a node encounters a routing header with + // an unrecognized routing type value, with a non-zero Segments Left + // value, the node must discard the packet and send an ICMP Parameter + // Problem, Code 0. If the Segments Left is 0, the node must ignore the + // Routing extension header and process the next header in the packet. + // + // Note, the stack does not yet handle any type of routing extension + // header, so we just make sure Segments Left is zero before processing + // the next extension header. + // + // TODO(b/152019344): Send an ICMPv6 Parameter Problem Code 0 for + // unrecognized routing types with a non-zero Segments Left value. + if extHdr.SegmentsLeft() != 0 { + return + } + + case header.IPv6FragmentExtHdr: + hasFragmentHeader = true + + fragmentOffset := extHdr.FragmentOffset() + more := extHdr.More() + if !more && fragmentOffset == 0 { + // This fragment extension header indicates that this packet is an + // atomic fragment. An atomic fragment is a fragment that contains + // all the data required to reassemble a full packet. As per RFC 6946, + // atomic fragments must not interfere with "normal" fragmented traffic + // so we skip processing the fragment instead of feeding it through the + // reassembly process below. + continue + } + + // Don't consume the iterator if we have the first fragment because we + // will use it to validate that the first fragment holds the upper layer + // header. + rawPayload := it.AsRawHeader(fragmentOffset != 0 /* consume */) + + if fragmentOffset == 0 { + // Check that the iterator ends with a raw payload as the first fragment + // should include all headers up to and including any upper layer + // headers, as per RFC 8200 section 4.5; only upper layer data + // (non-headers) should follow the fragment extension header. + var lastHdr header.IPv6PayloadHeader + + for { + it, done, err := it.Next() + if err != nil { + r.Stats().IP.MalformedPacketsReceived.Increment() + r.Stats().IP.MalformedPacketsReceived.Increment() + return + } + if done { + break + } + + lastHdr = it + } + + // If the last header is a raw header, then the last portion of the IPv6 + // payload is not a known IPv6 extension header. Note, this does not + // mean that the last portion is an upper layer header or not an + // extension header because: + // 1) we do not yet support all extension headers + // 2) we do not validate the upper layer header before reassembling. + // + // This check makes sure that a known IPv6 extension header is not + // present after the Fragment extension header in a non-initial + // fragment. + // + // TODO(#2196): Support IPv6 Authentication and Encapsulated + // Security Payload extension headers. + // TODO(#2333): Validate that the upper layer header is valid. + switch lastHdr.(type) { + case header.IPv6RawPayloadHeader: + default: + r.Stats().IP.MalformedPacketsReceived.Increment() + r.Stats().IP.MalformedFragmentsReceived.Increment() + return + } + } + + fragmentPayloadLen := rawPayload.Buf.Size() + if fragmentPayloadLen == 0 { + // Drop the packet as it's marked as a fragment but has no payload. + r.Stats().IP.MalformedPacketsReceived.Increment() + r.Stats().IP.MalformedFragmentsReceived.Increment() + return + } + + // The packet is a fragment, let's try to reassemble it. + start := fragmentOffset * header.IPv6FragmentExtHdrFragmentOffsetBytesPerUnit + last := start + uint16(fragmentPayloadLen) - 1 + + // Drop the packet if the fragmentOffset is incorrect. i.e the + // combination of fragmentOffset and pkt.Data.size() causes a + // wrap around resulting in last being less than the offset. + if last < start { + r.Stats().IP.MalformedPacketsReceived.Increment() + r.Stats().IP.MalformedFragmentsReceived.Increment() + return + } + + var ready bool + pkt.Data, ready, err = e.fragmentation.Process(hash.IPv6FragmentHash(h, extHdr.ID()), start, last, more, rawPayload.Buf) + if err != nil { + r.Stats().IP.MalformedPacketsReceived.Increment() + r.Stats().IP.MalformedFragmentsReceived.Increment() + return + } + + if ready { + // We create a new iterator with the reassembled packet because we could + // have more extension headers in the reassembled payload, as per RFC + // 8200 section 4.5. + it = header.MakeIPv6PayloadIterator(rawPayload.Identifier, pkt.Data) + } + + case header.IPv6DestinationOptionsExtHdr: + optsIt := extHdr.Iter() + + for { + opt, done, err := optsIt.Next() + if err != nil { + r.Stats().IP.MalformedPacketsReceived.Increment() + return + } + if done { + break + } + + // We currently do not support any IPv6 Destination extension header + // options. + switch opt.UnknownAction() { + case header.IPv6OptionUnknownActionSkip: + case header.IPv6OptionUnknownActionDiscard: + return + case header.IPv6OptionUnknownActionDiscardSendICMP: + // TODO(b/152019344): Send an ICMPv6 Parameter Problem Code 2 for + // unrecognized IPv6 extension header options. + return + case header.IPv6OptionUnknownActionDiscardSendICMPNoMulticastDest: + // TODO(b/152019344): Send an ICMPv6 Parameter Problem Code 2 for + // unrecognized IPv6 extension header options. + return + default: + panic(fmt.Sprintf("unrecognized action for an unrecognized Destination extension header option = %d", opt)) + } + } + + case header.IPv6RawPayloadHeader: + // If the last header in the payload isn't a known IPv6 extension header, + // handle it as if it is transport layer data. + pkt.Data = extHdr.Buf + + if p := tcpip.TransportProtocolNumber(extHdr.Identifier); p == header.ICMPv6ProtocolNumber { + e.handleICMP(r, headerView, pkt, hasFragmentHeader) + } else { + r.Stats().IP.PacketsDelivered.Increment() + // TODO(b/152019344): Send an ICMPv6 Parameter Problem, Code 1 error + // in response to unrecognized next header values. + e.dispatcher.DeliverTransportPacket(r, p, pkt) + } + + default: + // If we receive a packet for an extension header we do not yet handle, + // drop the packet for now. + // + // TODO(b/152019344): Send an ICMPv6 Parameter Problem, Code 1 error + // in response to unrecognized next header values. + r.Stats().UnknownProtocolRcvdPackets.Increment() + return + } + } +} + +// Close cleans up resources associated with the endpoint. +func (*endpoint) Close() {} + +// NetworkProtocolNumber implements stack.NetworkEndpoint.NetworkProtocolNumber. +func (e *endpoint) NetworkProtocolNumber() tcpip.NetworkProtocolNumber { + return e.protocol.Number() +} + +type protocol struct { + // defaultTTL is the current default TTL for the protocol. Only the + // uint8 portion of it is meaningful and it must be accessed + // atomically. + defaultTTL uint32 +} + +// Number returns the ipv6 protocol number. +func (p *protocol) Number() tcpip.NetworkProtocolNumber { + return ProtocolNumber +} + +// MinimumPacketSize returns the minimum valid ipv6 packet size. +func (p *protocol) MinimumPacketSize() int { + return header.IPv6MinimumSize +} + +// DefaultPrefixLen returns the IPv6 default prefix length. +func (p *protocol) DefaultPrefixLen() int { + return header.IPv6AddressSize * 8 +} + +// ParseAddresses implements NetworkProtocol.ParseAddresses. +func (*protocol) ParseAddresses(v buffer.View) (src, dst tcpip.Address) { + h := header.IPv6(v) + return h.SourceAddress(), h.DestinationAddress() +} + +// NewEndpoint creates a new ipv6 endpoint. +func (p *protocol) NewEndpoint(nicID tcpip.NICID, addrWithPrefix tcpip.AddressWithPrefix, linkAddrCache stack.LinkAddressCache, dispatcher stack.TransportDispatcher, linkEP stack.LinkEndpoint, st *stack.Stack) (stack.NetworkEndpoint, *tcpip.Error) { + return &endpoint{ + nicID: nicID, + id: stack.NetworkEndpointID{LocalAddress: addrWithPrefix.Address}, + prefixLen: addrWithPrefix.PrefixLen, + linkEP: linkEP, + linkAddrCache: linkAddrCache, + dispatcher: dispatcher, + fragmentation: fragmentation.NewFragmentation(fragmentation.HighFragThreshold, fragmentation.LowFragThreshold, fragmentation.DefaultReassembleTimeout), + protocol: p, + }, nil +} + +// SetOption implements NetworkProtocol.SetOption. +func (p *protocol) SetOption(option interface{}) *tcpip.Error { + switch v := option.(type) { + case tcpip.DefaultTTLOption: + p.SetDefaultTTL(uint8(v)) + return nil + default: + return tcpip.ErrUnknownProtocolOption + } +} + +// Option implements NetworkProtocol.Option. +func (p *protocol) Option(option interface{}) *tcpip.Error { + switch v := option.(type) { + case *tcpip.DefaultTTLOption: + *v = tcpip.DefaultTTLOption(p.DefaultTTL()) + return nil + default: + return tcpip.ErrUnknownProtocolOption + } +} + +// SetDefaultTTL sets the default TTL for endpoints created with this protocol. +func (p *protocol) SetDefaultTTL(ttl uint8) { + atomic.StoreUint32(&p.defaultTTL, uint32(ttl)) +} + +// DefaultTTL returns the default TTL for endpoints created with this protocol. +func (p *protocol) DefaultTTL() uint8 { + return uint8(atomic.LoadUint32(&p.defaultTTL)) +} + +// Close implements stack.TransportProtocol.Close. +func (*protocol) Close() {} + +// Wait implements stack.TransportProtocol.Wait. +func (*protocol) Wait() {} + +// calculateMTU calculates the network-layer payload MTU based on the link-layer +// payload mtu. +func calculateMTU(mtu uint32) uint32 { + mtu -= header.IPv6MinimumSize + if mtu <= maxPayloadSize { + return mtu + } + return maxPayloadSize +} + +// NewProtocol returns an IPv6 network protocol. +func NewProtocol() stack.NetworkProtocol { + return &protocol{defaultTTL: DefaultTTL} +} diff --git a/pkg/tcpip/network/ipv6/ipv6_test.go b/pkg/tcpip/network/ipv6/ipv6_test.go new file mode 100644 index 000000000..841a0cb7a --- /dev/null +++ b/pkg/tcpip/network/ipv6/ipv6_test.go @@ -0,0 +1,1265 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package ipv6 + +import ( + "testing" + + "github.com/google/go-cmp/cmp" + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/link/channel" + "gvisor.dev/gvisor/pkg/tcpip/stack" + "gvisor.dev/gvisor/pkg/tcpip/transport/icmp" + "gvisor.dev/gvisor/pkg/tcpip/transport/udp" + "gvisor.dev/gvisor/pkg/waiter" +) + +const ( + addr1 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01" + addr2 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02" + // The least significant 3 bytes are the same as addr2 so both addr2 and + // addr3 will have the same solicited-node address. + addr3 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x02" + addr4 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x03" + + // Tests use the extension header identifier values as uint8 instead of + // header.IPv6ExtensionHeaderIdentifier. + hopByHopExtHdrID = uint8(header.IPv6HopByHopOptionsExtHdrIdentifier) + routingExtHdrID = uint8(header.IPv6RoutingExtHdrIdentifier) + fragmentExtHdrID = uint8(header.IPv6FragmentExtHdrIdentifier) + destinationExtHdrID = uint8(header.IPv6DestinationOptionsExtHdrIdentifier) + noNextHdrID = uint8(header.IPv6NoNextHeaderIdentifier) +) + +// testReceiveICMP tests receiving an ICMP packet from src to dst. want is the +// expected Neighbor Advertisement received count after receiving the packet. +func testReceiveICMP(t *testing.T, s *stack.Stack, e *channel.Endpoint, src, dst tcpip.Address, want uint64) { + t.Helper() + + // Receive ICMP packet. + hdr := buffer.NewPrependable(header.IPv6MinimumSize + header.ICMPv6NeighborAdvertSize) + pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6NeighborAdvertSize)) + pkt.SetType(header.ICMPv6NeighborAdvert) + pkt.SetChecksum(header.ICMPv6Checksum(pkt, src, dst, buffer.VectorisedView{})) + payloadLength := hdr.UsedLength() + ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) + ip.Encode(&header.IPv6Fields{ + PayloadLength: uint16(payloadLength), + NextHeader: uint8(header.ICMPv6ProtocolNumber), + HopLimit: 255, + SrcAddr: src, + DstAddr: dst, + }) + + e.InjectInbound(ProtocolNumber, stack.PacketBuffer{ + Data: hdr.View().ToVectorisedView(), + }) + + stats := s.Stats().ICMP.V6PacketsReceived + + if got := stats.NeighborAdvert.Value(); got != want { + t.Fatalf("got NeighborAdvert = %d, want = %d", got, want) + } +} + +// testReceiveUDP tests receiving a UDP packet from src to dst. want is the +// expected UDP received count after receiving the packet. +func testReceiveUDP(t *testing.T, s *stack.Stack, e *channel.Endpoint, src, dst tcpip.Address, want uint64) { + t.Helper() + + wq := waiter.Queue{} + we, ch := waiter.NewChannelEntry(nil) + wq.EventRegister(&we, waiter.EventIn) + defer wq.EventUnregister(&we) + defer close(ch) + + ep, err := s.NewEndpoint(udp.ProtocolNumber, ProtocolNumber, &wq) + if err != nil { + t.Fatalf("NewEndpoint failed: %v", err) + } + defer ep.Close() + + if err := ep.Bind(tcpip.FullAddress{Addr: dst, Port: 80}); err != nil { + t.Fatalf("ep.Bind(...) failed: %v", err) + } + + // Receive UDP Packet. + hdr := buffer.NewPrependable(header.IPv6MinimumSize + header.UDPMinimumSize) + u := header.UDP(hdr.Prepend(header.UDPMinimumSize)) + u.Encode(&header.UDPFields{ + SrcPort: 5555, + DstPort: 80, + Length: header.UDPMinimumSize, + }) + + // UDP pseudo-header checksum. + sum := header.PseudoHeaderChecksum(udp.ProtocolNumber, src, dst, header.UDPMinimumSize) + + // UDP checksum + sum = header.Checksum(header.UDP([]byte{}), sum) + u.SetChecksum(^u.CalculateChecksum(sum)) + + payloadLength := hdr.UsedLength() + ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) + ip.Encode(&header.IPv6Fields{ + PayloadLength: uint16(payloadLength), + NextHeader: uint8(udp.ProtocolNumber), + HopLimit: 255, + SrcAddr: src, + DstAddr: dst, + }) + + e.InjectInbound(ProtocolNumber, stack.PacketBuffer{ + Data: hdr.View().ToVectorisedView(), + }) + + stat := s.Stats().UDP.PacketsReceived + + if got := stat.Value(); got != want { + t.Fatalf("got UDPPacketsReceived = %d, want = %d", got, want) + } +} + +// TestReceiveOnAllNodesMulticastAddr tests that IPv6 endpoints receive ICMP and +// UDP packets destined to the IPv6 link-local all-nodes multicast address. +func TestReceiveOnAllNodesMulticastAddr(t *testing.T) { + tests := []struct { + name string + protocolFactory stack.TransportProtocol + rxf func(t *testing.T, s *stack.Stack, e *channel.Endpoint, src, dst tcpip.Address, want uint64) + }{ + {"ICMP", icmp.NewProtocol6(), testReceiveICMP}, + {"UDP", udp.NewProtocol(), testReceiveUDP}, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{NewProtocol()}, + TransportProtocols: []stack.TransportProtocol{test.protocolFactory}, + }) + e := channel.New(10, 1280, linkAddr1) + if err := s.CreateNIC(1, e); err != nil { + t.Fatalf("CreateNIC(_) = %s", err) + } + + // Should receive a packet destined to the all-nodes + // multicast address. + test.rxf(t, s, e, addr1, header.IPv6AllNodesMulticastAddress, 1) + }) + } +} + +// TestReceiveOnSolicitedNodeAddr tests that IPv6 endpoints receive ICMP and UDP +// packets destined to the IPv6 solicited-node address of an assigned IPv6 +// address. +func TestReceiveOnSolicitedNodeAddr(t *testing.T) { + const nicID = 1 + + tests := []struct { + name string + protocolFactory stack.TransportProtocol + rxf func(t *testing.T, s *stack.Stack, e *channel.Endpoint, src, dst tcpip.Address, want uint64) + }{ + {"ICMP", icmp.NewProtocol6(), testReceiveICMP}, + {"UDP", udp.NewProtocol(), testReceiveUDP}, + } + + snmc := header.SolicitedNodeAddr(addr2) + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{NewProtocol()}, + TransportProtocols: []stack.TransportProtocol{test.protocolFactory}, + }) + e := channel.New(1, 1280, linkAddr1) + if err := s.CreateNIC(nicID, e); err != nil { + t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) + } + + s.SetRouteTable([]tcpip.Route{ + tcpip.Route{ + Destination: header.IPv6EmptySubnet, + NIC: nicID, + }, + }) + + // Should not receive a packet destined to the solicited node address of + // addr2/addr3 yet as we haven't added those addresses. + test.rxf(t, s, e, addr1, snmc, 0) + + if err := s.AddAddress(nicID, ProtocolNumber, addr2); err != nil { + t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ProtocolNumber, addr2, err) + } + + // Should receive a packet destined to the solicited node address of + // addr2/addr3 now that we have added added addr2. + test.rxf(t, s, e, addr1, snmc, 1) + + if err := s.AddAddress(nicID, ProtocolNumber, addr3); err != nil { + t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ProtocolNumber, addr3, err) + } + + // Should still receive a packet destined to the solicited node address of + // addr2/addr3 now that we have added addr3. + test.rxf(t, s, e, addr1, snmc, 2) + + if err := s.RemoveAddress(nicID, addr2); err != nil { + t.Fatalf("RemoveAddress(%d, %s) = %s", nicID, addr2, err) + } + + // Should still receive a packet destined to the solicited node address of + // addr2/addr3 now that we have removed addr2. + test.rxf(t, s, e, addr1, snmc, 3) + + // Make sure addr3's endpoint does not get removed from the NIC by + // incrementing its reference count with a route. + r, err := s.FindRoute(nicID, addr3, addr4, ProtocolNumber, false) + if err != nil { + t.Fatalf("FindRoute(%d, %s, %s, %d, false): %s", nicID, addr3, addr4, ProtocolNumber, err) + } + defer r.Release() + + if err := s.RemoveAddress(nicID, addr3); err != nil { + t.Fatalf("RemoveAddress(%d, %s) = %s", nicID, addr3, err) + } + + // Should not receive a packet destined to the solicited node address of + // addr2/addr3 yet as both of them got removed, even though a route using + // addr3 exists. + test.rxf(t, s, e, addr1, snmc, 3) + }) + } +} + +// TestAddIpv6Address tests adding IPv6 addresses. +func TestAddIpv6Address(t *testing.T) { + tests := []struct { + name string + addr tcpip.Address + }{ + // This test is in response to b/140943433. + { + "Nil", + tcpip.Address([]byte(nil)), + }, + { + "ValidUnicast", + addr1, + }, + { + "ValidLinkLocalUnicast", + lladdr0, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{NewProtocol()}, + }) + if err := s.CreateNIC(1, &stubLinkEndpoint{}); err != nil { + t.Fatalf("CreateNIC(_) = %s", err) + } + + if err := s.AddAddress(1, ProtocolNumber, test.addr); err != nil { + t.Fatalf("AddAddress(_, %d, nil) = %s", ProtocolNumber, err) + } + + addr, err := s.GetMainNICAddress(1, header.IPv6ProtocolNumber) + if err != nil { + t.Fatalf("stack.GetMainNICAddress(_, _) err = %s", err) + } + if addr.Address != test.addr { + t.Fatalf("got stack.GetMainNICAddress(_, _) = %s, want = %s", addr.Address, test.addr) + } + }) + } +} + +func TestReceiveIPv6ExtHdrs(t *testing.T) { + const nicID = 1 + + tests := []struct { + name string + extHdr func(nextHdr uint8) ([]byte, uint8) + shouldAccept bool + }{ + { + name: "None", + extHdr: func(nextHdr uint8) ([]byte, uint8) { return []byte{}, nextHdr }, + shouldAccept: true, + }, + { + name: "hopbyhop with unknown option skippable action", + extHdr: func(nextHdr uint8) ([]byte, uint8) { + return []byte{ + nextHdr, 1, + + // Skippable unknown. + 63, 4, 1, 2, 3, 4, + + // Skippable unknown. + 62, 6, 1, 2, 3, 4, 5, 6, + }, hopByHopExtHdrID + }, + shouldAccept: true, + }, + { + name: "hopbyhop with unknown option discard action", + extHdr: func(nextHdr uint8) ([]byte, uint8) { + return []byte{ + nextHdr, 1, + + // Skippable unknown. + 63, 4, 1, 2, 3, 4, + + // Discard unknown. + 127, 6, 1, 2, 3, 4, 5, 6, + }, hopByHopExtHdrID + }, + shouldAccept: false, + }, + { + name: "hopbyhop with unknown option discard and send icmp action", + extHdr: func(nextHdr uint8) ([]byte, uint8) { + return []byte{ + nextHdr, 1, + + // Skippable unknown. + 63, 4, 1, 2, 3, 4, + + // Discard & send ICMP if option is unknown. + 191, 6, 1, 2, 3, 4, 5, 6, + }, hopByHopExtHdrID + }, + shouldAccept: false, + }, + { + name: "hopbyhop with unknown option discard and send icmp action unless multicast dest", + extHdr: func(nextHdr uint8) ([]byte, uint8) { + return []byte{ + nextHdr, 1, + + // Skippable unknown. + 63, 4, 1, 2, 3, 4, + + // Discard & send ICMP unless packet is for multicast destination if + // option is unknown. + 255, 6, 1, 2, 3, 4, 5, 6, + }, hopByHopExtHdrID + }, + shouldAccept: false, + }, + { + name: "routing with zero segments left", + extHdr: func(nextHdr uint8) ([]byte, uint8) { return []byte{nextHdr, 0, 1, 0, 2, 3, 4, 5}, routingExtHdrID }, + shouldAccept: true, + }, + { + name: "routing with non-zero segments left", + extHdr: func(nextHdr uint8) ([]byte, uint8) { return []byte{nextHdr, 0, 1, 1, 2, 3, 4, 5}, routingExtHdrID }, + shouldAccept: false, + }, + { + name: "atomic fragment with zero ID", + extHdr: func(nextHdr uint8) ([]byte, uint8) { return []byte{nextHdr, 0, 0, 0, 0, 0, 0, 0}, fragmentExtHdrID }, + shouldAccept: true, + }, + { + name: "atomic fragment with non-zero ID", + extHdr: func(nextHdr uint8) ([]byte, uint8) { return []byte{nextHdr, 0, 0, 0, 1, 2, 3, 4}, fragmentExtHdrID }, + shouldAccept: true, + }, + { + name: "fragment", + extHdr: func(nextHdr uint8) ([]byte, uint8) { return []byte{nextHdr, 0, 1, 0, 1, 2, 3, 4}, fragmentExtHdrID }, + shouldAccept: false, + }, + { + name: "No next header", + extHdr: func(nextHdr uint8) ([]byte, uint8) { return []byte{}, noNextHdrID }, + shouldAccept: false, + }, + { + name: "destination with unknown option skippable action", + extHdr: func(nextHdr uint8) ([]byte, uint8) { + return []byte{ + nextHdr, 1, + + // Skippable unknown. + 63, 4, 1, 2, 3, 4, + + // Skippable unknown. + 62, 6, 1, 2, 3, 4, 5, 6, + }, destinationExtHdrID + }, + shouldAccept: true, + }, + { + name: "destination with unknown option discard action", + extHdr: func(nextHdr uint8) ([]byte, uint8) { + return []byte{ + nextHdr, 1, + + // Skippable unknown. + 63, 4, 1, 2, 3, 4, + + // Discard unknown. + 127, 6, 1, 2, 3, 4, 5, 6, + }, destinationExtHdrID + }, + shouldAccept: false, + }, + { + name: "destination with unknown option discard and send icmp action", + extHdr: func(nextHdr uint8) ([]byte, uint8) { + return []byte{ + nextHdr, 1, + + // Skippable unknown. + 63, 4, 1, 2, 3, 4, + + // Discard & send ICMP if option is unknown. + 191, 6, 1, 2, 3, 4, 5, 6, + }, destinationExtHdrID + }, + shouldAccept: false, + }, + { + name: "destination with unknown option discard and send icmp action unless multicast dest", + extHdr: func(nextHdr uint8) ([]byte, uint8) { + return []byte{ + nextHdr, 1, + + // Skippable unknown. + 63, 4, 1, 2, 3, 4, + + // Discard & send ICMP unless packet is for multicast destination if + // option is unknown. + 255, 6, 1, 2, 3, 4, 5, 6, + }, destinationExtHdrID + }, + shouldAccept: false, + }, + { + name: "routing - atomic fragment", + extHdr: func(nextHdr uint8) ([]byte, uint8) { + return []byte{ + // Routing extension header. + fragmentExtHdrID, 0, 1, 0, 2, 3, 4, 5, + + // Fragment extension header. + nextHdr, 0, 0, 0, 1, 2, 3, 4, + }, routingExtHdrID + }, + shouldAccept: true, + }, + { + name: "atomic fragment - routing", + extHdr: func(nextHdr uint8) ([]byte, uint8) { + return []byte{ + // Fragment extension header. + routingExtHdrID, 0, 0, 0, 1, 2, 3, 4, + + // Routing extension header. + nextHdr, 0, 1, 0, 2, 3, 4, 5, + }, fragmentExtHdrID + }, + shouldAccept: true, + }, + { + name: "hop by hop (with skippable unknown) - routing", + extHdr: func(nextHdr uint8) ([]byte, uint8) { + return []byte{ + // Hop By Hop extension header with skippable unknown option. + routingExtHdrID, 0, 62, 4, 1, 2, 3, 4, + + // Routing extension header. + nextHdr, 0, 1, 0, 2, 3, 4, 5, + }, hopByHopExtHdrID + }, + shouldAccept: true, + }, + { + name: "routing - hop by hop (with skippable unknown)", + extHdr: func(nextHdr uint8) ([]byte, uint8) { + return []byte{ + // Routing extension header. + hopByHopExtHdrID, 0, 1, 0, 2, 3, 4, 5, + + // Hop By Hop extension header with skippable unknown option. + nextHdr, 0, 62, 4, 1, 2, 3, 4, + }, routingExtHdrID + }, + shouldAccept: false, + }, + { + name: "No next header", + extHdr: func(nextHdr uint8) ([]byte, uint8) { return []byte{}, noNextHdrID }, + shouldAccept: false, + }, + { + name: "hopbyhop (with skippable unknown) - routing - atomic fragment - destination (with skippable unknown)", + extHdr: func(nextHdr uint8) ([]byte, uint8) { + return []byte{ + // Hop By Hop extension header with skippable unknown option. + routingExtHdrID, 0, 62, 4, 1, 2, 3, 4, + + // Routing extension header. + fragmentExtHdrID, 0, 1, 0, 2, 3, 4, 5, + + // Fragment extension header. + destinationExtHdrID, 0, 0, 0, 1, 2, 3, 4, + + // Destination extension header with skippable unknown option. + nextHdr, 0, 63, 4, 1, 2, 3, 4, + }, hopByHopExtHdrID + }, + shouldAccept: true, + }, + { + name: "hopbyhop (with discard unknown) - routing - atomic fragment - destination (with skippable unknown)", + extHdr: func(nextHdr uint8) ([]byte, uint8) { + return []byte{ + // Hop By Hop extension header with discard action for unknown option. + routingExtHdrID, 0, 65, 4, 1, 2, 3, 4, + + // Routing extension header. + fragmentExtHdrID, 0, 1, 0, 2, 3, 4, 5, + + // Fragment extension header. + destinationExtHdrID, 0, 0, 0, 1, 2, 3, 4, + + // Destination extension header with skippable unknown option. + nextHdr, 0, 63, 4, 1, 2, 3, 4, + }, hopByHopExtHdrID + }, + shouldAccept: false, + }, + { + name: "hopbyhop (with skippable unknown) - routing - atomic fragment - destination (with discard unknown)", + extHdr: func(nextHdr uint8) ([]byte, uint8) { + return []byte{ + // Hop By Hop extension header with skippable unknown option. + routingExtHdrID, 0, 62, 4, 1, 2, 3, 4, + + // Routing extension header. + fragmentExtHdrID, 0, 1, 0, 2, 3, 4, 5, + + // Fragment extension header. + destinationExtHdrID, 0, 0, 0, 1, 2, 3, 4, + + // Destination extension header with discard action for unknown + // option. + nextHdr, 0, 65, 4, 1, 2, 3, 4, + }, hopByHopExtHdrID + }, + shouldAccept: false, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{NewProtocol()}, + TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()}, + }) + e := channel.New(0, 1280, linkAddr1) + if err := s.CreateNIC(nicID, e); err != nil { + t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) + } + if err := s.AddAddress(nicID, ProtocolNumber, addr2); err != nil { + t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ProtocolNumber, addr2, err) + } + + wq := waiter.Queue{} + we, ch := waiter.NewChannelEntry(nil) + wq.EventRegister(&we, waiter.EventIn) + defer wq.EventUnregister(&we) + defer close(ch) + ep, err := s.NewEndpoint(udp.ProtocolNumber, ProtocolNumber, &wq) + if err != nil { + t.Fatalf("NewEndpoint(%d, %d, _): %s", udp.ProtocolNumber, ProtocolNumber, err) + } + defer ep.Close() + + bindAddr := tcpip.FullAddress{Addr: addr2, Port: 80} + if err := ep.Bind(bindAddr); err != nil { + t.Fatalf("Bind(%+v): %s", bindAddr, err) + } + + udpPayload := []byte{1, 2, 3, 4, 5, 6, 7, 8} + udpLength := header.UDPMinimumSize + len(udpPayload) + extHdrBytes, ipv6NextHdr := test.extHdr(uint8(header.UDPProtocolNumber)) + extHdrLen := len(extHdrBytes) + hdr := buffer.NewPrependable(header.IPv6MinimumSize + extHdrLen + udpLength) + + // Serialize UDP message. + u := header.UDP(hdr.Prepend(udpLength)) + u.Encode(&header.UDPFields{ + SrcPort: 5555, + DstPort: 80, + Length: uint16(udpLength), + }) + copy(u.Payload(), udpPayload) + sum := header.PseudoHeaderChecksum(udp.ProtocolNumber, addr1, addr2, uint16(udpLength)) + sum = header.Checksum(udpPayload, sum) + u.SetChecksum(^u.CalculateChecksum(sum)) + + // Copy extension header bytes between the UDP message and the IPv6 + // fixed header. + copy(hdr.Prepend(extHdrLen), extHdrBytes) + + // Serialize IPv6 fixed header. + payloadLength := hdr.UsedLength() + ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) + ip.Encode(&header.IPv6Fields{ + PayloadLength: uint16(payloadLength), + NextHeader: ipv6NextHdr, + HopLimit: 255, + SrcAddr: addr1, + DstAddr: addr2, + }) + + e.InjectInbound(ProtocolNumber, stack.PacketBuffer{ + Data: hdr.View().ToVectorisedView(), + }) + + stats := s.Stats().UDP.PacketsReceived + + if !test.shouldAccept { + if got := stats.Value(); got != 0 { + t.Errorf("got UDP Rx Packets = %d, want = 0", got) + } + + return + } + + // Expect a UDP packet. + if got := stats.Value(); got != 1 { + t.Errorf("got UDP Rx Packets = %d, want = 1", got) + } + gotPayload, _, err := ep.Read(nil) + if err != nil { + t.Fatalf("Read(nil): %s", err) + } + if diff := cmp.Diff(buffer.View(udpPayload), gotPayload); diff != "" { + t.Errorf("got UDP payload mismatch (-want +got):\n%s", diff) + } + + // Should not have any more UDP packets. + if gotPayload, _, err := ep.Read(nil); err != tcpip.ErrWouldBlock { + t.Fatalf("got Read(nil) = (%x, _, %v), want = (_, _, %s)", gotPayload, err, tcpip.ErrWouldBlock) + } + }) + } +} + +// fragmentData holds the IPv6 payload for a fragmented IPv6 packet. +type fragmentData struct { + nextHdr uint8 + data buffer.VectorisedView +} + +func TestReceiveIPv6Fragments(t *testing.T) { + const nicID = 1 + const udpPayload1Length = 256 + const udpPayload2Length = 128 + const fragmentExtHdrLen = 8 + // Note, not all routing extension headers will be 8 bytes but this test + // uses 8 byte routing extension headers for most sub tests. + const routingExtHdrLen = 8 + + udpGen := func(payload []byte, multiplier uint8) buffer.View { + payloadLen := len(payload) + for i := 0; i < payloadLen; i++ { + payload[i] = uint8(i) * multiplier + } + + udpLength := header.UDPMinimumSize + payloadLen + + hdr := buffer.NewPrependable(udpLength) + u := header.UDP(hdr.Prepend(udpLength)) + u.Encode(&header.UDPFields{ + SrcPort: 5555, + DstPort: 80, + Length: uint16(udpLength), + }) + copy(u.Payload(), payload) + sum := header.PseudoHeaderChecksum(udp.ProtocolNumber, addr1, addr2, uint16(udpLength)) + sum = header.Checksum(payload, sum) + u.SetChecksum(^u.CalculateChecksum(sum)) + return hdr.View() + } + + var udpPayload1Buf [udpPayload1Length]byte + udpPayload1 := udpPayload1Buf[:] + ipv6Payload1 := udpGen(udpPayload1, 1) + + var udpPayload2Buf [udpPayload2Length]byte + udpPayload2 := udpPayload2Buf[:] + ipv6Payload2 := udpGen(udpPayload2, 2) + + tests := []struct { + name string + expectedPayload []byte + fragments []fragmentData + expectedPayloads [][]byte + }{ + { + name: "No fragmentation", + fragments: []fragmentData{ + { + nextHdr: uint8(header.UDPProtocolNumber), + data: ipv6Payload1.ToVectorisedView(), + }, + }, + expectedPayloads: [][]byte{udpPayload1}, + }, + { + name: "Atomic fragment", + fragments: []fragmentData{ + { + nextHdr: fragmentExtHdrID, + data: buffer.NewVectorisedView( + fragmentExtHdrLen+len(ipv6Payload1), + []buffer.View{ + // Fragment extension header. + buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 0, 0, 0, 0, 0}), + + ipv6Payload1, + }, + ), + }, + }, + expectedPayloads: [][]byte{udpPayload1}, + }, + { + name: "Two fragments", + fragments: []fragmentData{ + { + nextHdr: fragmentExtHdrID, + data: buffer.NewVectorisedView( + fragmentExtHdrLen+64, + []buffer.View{ + // Fragment extension header. + // + // Fragment offset = 0, More = true, ID = 1 + buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1}), + + ipv6Payload1[:64], + }, + ), + }, + { + nextHdr: fragmentExtHdrID, + data: buffer.NewVectorisedView( + fragmentExtHdrLen+len(ipv6Payload1)-64, + []buffer.View{ + // Fragment extension header. + // + // Fragment offset = 8, More = false, ID = 1 + buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 64, 0, 0, 0, 1}), + + ipv6Payload1[64:], + }, + ), + }, + }, + expectedPayloads: [][]byte{udpPayload1}, + }, + { + name: "Two fragments with different IDs", + fragments: []fragmentData{ + { + nextHdr: fragmentExtHdrID, + data: buffer.NewVectorisedView( + fragmentExtHdrLen+64, + []buffer.View{ + // Fragment extension header. + // + // Fragment offset = 0, More = true, ID = 1 + buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1}), + + ipv6Payload1[:64], + }, + ), + }, + { + nextHdr: fragmentExtHdrID, + data: buffer.NewVectorisedView( + fragmentExtHdrLen+len(ipv6Payload1)-64, + []buffer.View{ + // Fragment extension header. + // + // Fragment offset = 8, More = false, ID = 2 + buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 64, 0, 0, 0, 2}), + + ipv6Payload1[64:], + }, + ), + }, + }, + expectedPayloads: nil, + }, + { + name: "Two fragments with per-fragment routing header with zero segments left", + fragments: []fragmentData{ + { + nextHdr: routingExtHdrID, + data: buffer.NewVectorisedView( + routingExtHdrLen+fragmentExtHdrLen+64, + []buffer.View{ + // Routing extension header. + // + // Segments left = 0. + buffer.View([]byte{fragmentExtHdrID, 0, 1, 0, 2, 3, 4, 5}), + + // Fragment extension header. + // + // Fragment offset = 0, More = true, ID = 1 + buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1}), + + ipv6Payload1[:64], + }, + ), + }, + { + nextHdr: routingExtHdrID, + data: buffer.NewVectorisedView( + routingExtHdrLen+fragmentExtHdrLen+len(ipv6Payload1)-64, + []buffer.View{ + // Routing extension header. + // + // Segments left = 0. + buffer.View([]byte{fragmentExtHdrID, 0, 1, 0, 2, 3, 4, 5}), + + // Fragment extension header. + // + // Fragment offset = 8, More = false, ID = 1 + buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 64, 0, 0, 0, 1}), + + ipv6Payload1[64:], + }, + ), + }, + }, + expectedPayloads: [][]byte{udpPayload1}, + }, + { + name: "Two fragments with per-fragment routing header with non-zero segments left", + fragments: []fragmentData{ + { + nextHdr: routingExtHdrID, + data: buffer.NewVectorisedView( + routingExtHdrLen+fragmentExtHdrLen+64, + []buffer.View{ + // Routing extension header. + // + // Segments left = 1. + buffer.View([]byte{fragmentExtHdrID, 0, 1, 1, 2, 3, 4, 5}), + + // Fragment extension header. + // + // Fragment offset = 0, More = true, ID = 1 + buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1}), + + ipv6Payload1[:64], + }, + ), + }, + { + nextHdr: routingExtHdrID, + data: buffer.NewVectorisedView( + routingExtHdrLen+fragmentExtHdrLen+len(ipv6Payload1)-64, + []buffer.View{ + // Routing extension header. + // + // Segments left = 1. + buffer.View([]byte{fragmentExtHdrID, 0, 1, 1, 2, 3, 4, 5}), + + // Fragment extension header. + // + // Fragment offset = 9, More = false, ID = 1 + buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 72, 0, 0, 0, 1}), + + ipv6Payload1[64:], + }, + ), + }, + }, + expectedPayloads: nil, + }, + { + name: "Two fragments with routing header with zero segments left", + fragments: []fragmentData{ + { + nextHdr: fragmentExtHdrID, + data: buffer.NewVectorisedView( + routingExtHdrLen+fragmentExtHdrLen+64, + []buffer.View{ + // Fragment extension header. + // + // Fragment offset = 0, More = true, ID = 1 + buffer.View([]byte{routingExtHdrID, 0, 0, 1, 0, 0, 0, 1}), + + // Routing extension header. + // + // Segments left = 0. + buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 1, 0, 2, 3, 4, 5}), + + ipv6Payload1[:64], + }, + ), + }, + { + nextHdr: fragmentExtHdrID, + data: buffer.NewVectorisedView( + fragmentExtHdrLen+len(ipv6Payload1)-64, + []buffer.View{ + // Fragment extension header. + // + // Fragment offset = 9, More = false, ID = 1 + buffer.View([]byte{routingExtHdrID, 0, 0, 72, 0, 0, 0, 1}), + + ipv6Payload1[64:], + }, + ), + }, + }, + expectedPayloads: [][]byte{udpPayload1}, + }, + { + name: "Two fragments with routing header with non-zero segments left", + fragments: []fragmentData{ + { + nextHdr: fragmentExtHdrID, + data: buffer.NewVectorisedView( + routingExtHdrLen+fragmentExtHdrLen+64, + []buffer.View{ + // Fragment extension header. + // + // Fragment offset = 0, More = true, ID = 1 + buffer.View([]byte{routingExtHdrID, 0, 0, 1, 0, 0, 0, 1}), + + // Routing extension header. + // + // Segments left = 1. + buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 1, 1, 2, 3, 4, 5}), + + ipv6Payload1[:64], + }, + ), + }, + { + nextHdr: fragmentExtHdrID, + data: buffer.NewVectorisedView( + fragmentExtHdrLen+len(ipv6Payload1)-64, + []buffer.View{ + // Fragment extension header. + // + // Fragment offset = 9, More = false, ID = 1 + buffer.View([]byte{routingExtHdrID, 0, 0, 72, 0, 0, 0, 1}), + + ipv6Payload1[64:], + }, + ), + }, + }, + expectedPayloads: nil, + }, + { + name: "Two fragments with routing header with zero segments left across fragments", + fragments: []fragmentData{ + { + nextHdr: fragmentExtHdrID, + data: buffer.NewVectorisedView( + // The length of this payload is fragmentExtHdrLen+8 because the + // first 8 bytes of the 16 byte routing extension header is in + // this fragment. + fragmentExtHdrLen+8, + []buffer.View{ + // Fragment extension header. + // + // Fragment offset = 0, More = true, ID = 1 + buffer.View([]byte{routingExtHdrID, 0, 0, 1, 0, 0, 0, 1}), + + // Routing extension header (part 1) + // + // Segments left = 0. + buffer.View([]byte{uint8(header.UDPProtocolNumber), 1, 1, 0, 2, 3, 4, 5}), + }, + ), + }, + { + nextHdr: fragmentExtHdrID, + data: buffer.NewVectorisedView( + // The length of this payload is + // fragmentExtHdrLen+8+len(ipv6Payload1) because the last 8 bytes of + // the 16 byte routing extension header is in this fagment. + fragmentExtHdrLen+8+len(ipv6Payload1), + []buffer.View{ + // Fragment extension header. + // + // Fragment offset = 1, More = false, ID = 1 + buffer.View([]byte{routingExtHdrID, 0, 0, 8, 0, 0, 0, 1}), + + // Routing extension header (part 2) + buffer.View([]byte{6, 7, 8, 9, 10, 11, 12, 13}), + + ipv6Payload1, + }, + ), + }, + }, + expectedPayloads: nil, + }, + { + name: "Two fragments with routing header with non-zero segments left across fragments", + fragments: []fragmentData{ + { + nextHdr: fragmentExtHdrID, + data: buffer.NewVectorisedView( + // The length of this payload is fragmentExtHdrLen+8 because the + // first 8 bytes of the 16 byte routing extension header is in + // this fragment. + fragmentExtHdrLen+8, + []buffer.View{ + // Fragment extension header. + // + // Fragment offset = 0, More = true, ID = 1 + buffer.View([]byte{routingExtHdrID, 0, 0, 1, 0, 0, 0, 1}), + + // Routing extension header (part 1) + // + // Segments left = 1. + buffer.View([]byte{uint8(header.UDPProtocolNumber), 1, 1, 1, 2, 3, 4, 5}), + }, + ), + }, + { + nextHdr: fragmentExtHdrID, + data: buffer.NewVectorisedView( + // The length of this payload is + // fragmentExtHdrLen+8+len(ipv6Payload1) because the last 8 bytes of + // the 16 byte routing extension header is in this fagment. + fragmentExtHdrLen+8+len(ipv6Payload1), + []buffer.View{ + // Fragment extension header. + // + // Fragment offset = 1, More = false, ID = 1 + buffer.View([]byte{routingExtHdrID, 0, 0, 8, 0, 0, 0, 1}), + + // Routing extension header (part 2) + buffer.View([]byte{6, 7, 8, 9, 10, 11, 12, 13}), + + ipv6Payload1, + }, + ), + }, + }, + expectedPayloads: nil, + }, + // As per RFC 6946, IPv6 atomic fragments MUST NOT interfere with "normal" + // fragmented traffic. + { + name: "Two fragments with atomic", + fragments: []fragmentData{ + { + nextHdr: fragmentExtHdrID, + data: buffer.NewVectorisedView( + fragmentExtHdrLen+64, + []buffer.View{ + // Fragment extension header. + // + // Fragment offset = 0, More = true, ID = 1 + buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1}), + + ipv6Payload1[:64], + }, + ), + }, + // This fragment has the same ID as the other fragments but is an atomic + // fragment. It should not interfere with the other fragments. + { + nextHdr: fragmentExtHdrID, + data: buffer.NewVectorisedView( + fragmentExtHdrLen+len(ipv6Payload2), + []buffer.View{ + // Fragment extension header. + // + // Fragment offset = 0, More = false, ID = 1 + buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 0, 0, 0, 0, 1}), + + ipv6Payload2, + }, + ), + }, + { + nextHdr: fragmentExtHdrID, + data: buffer.NewVectorisedView( + fragmentExtHdrLen+len(ipv6Payload1)-64, + []buffer.View{ + // Fragment extension header. + // + // Fragment offset = 8, More = false, ID = 1 + buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 64, 0, 0, 0, 1}), + + ipv6Payload1[64:], + }, + ), + }, + }, + expectedPayloads: [][]byte{udpPayload2, udpPayload1}, + }, + { + name: "Two interleaved fragmented packets", + fragments: []fragmentData{ + { + nextHdr: fragmentExtHdrID, + data: buffer.NewVectorisedView( + fragmentExtHdrLen+64, + []buffer.View{ + // Fragment extension header. + // + // Fragment offset = 0, More = true, ID = 1 + buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1}), + + ipv6Payload1[:64], + }, + ), + }, + { + nextHdr: fragmentExtHdrID, + data: buffer.NewVectorisedView( + fragmentExtHdrLen+32, + []buffer.View{ + // Fragment extension header. + // + // Fragment offset = 0, More = true, ID = 2 + buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 2}), + + ipv6Payload2[:32], + }, + ), + }, + { + nextHdr: fragmentExtHdrID, + data: buffer.NewVectorisedView( + fragmentExtHdrLen+len(ipv6Payload1)-64, + []buffer.View{ + // Fragment extension header. + // + // Fragment offset = 8, More = false, ID = 1 + buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 64, 0, 0, 0, 1}), + + ipv6Payload1[64:], + }, + ), + }, + { + nextHdr: fragmentExtHdrID, + data: buffer.NewVectorisedView( + fragmentExtHdrLen+len(ipv6Payload2)-32, + []buffer.View{ + // Fragment extension header. + // + // Fragment offset = 4, More = false, ID = 2 + buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 32, 0, 0, 0, 2}), + + ipv6Payload2[32:], + }, + ), + }, + }, + expectedPayloads: [][]byte{udpPayload1, udpPayload2}, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{NewProtocol()}, + TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()}, + }) + e := channel.New(0, 1280, linkAddr1) + if err := s.CreateNIC(nicID, e); err != nil { + t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) + } + if err := s.AddAddress(nicID, ProtocolNumber, addr2); err != nil { + t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ProtocolNumber, addr2, err) + } + + wq := waiter.Queue{} + we, ch := waiter.NewChannelEntry(nil) + wq.EventRegister(&we, waiter.EventIn) + defer wq.EventUnregister(&we) + defer close(ch) + ep, err := s.NewEndpoint(udp.ProtocolNumber, ProtocolNumber, &wq) + if err != nil { + t.Fatalf("NewEndpoint(%d, %d, _): %s", udp.ProtocolNumber, ProtocolNumber, err) + } + defer ep.Close() + + bindAddr := tcpip.FullAddress{Addr: addr2, Port: 80} + if err := ep.Bind(bindAddr); err != nil { + t.Fatalf("Bind(%+v): %s", bindAddr, err) + } + + for _, f := range test.fragments { + hdr := buffer.NewPrependable(header.IPv6MinimumSize) + + // Serialize IPv6 fixed header. + ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) + ip.Encode(&header.IPv6Fields{ + PayloadLength: uint16(f.data.Size()), + NextHeader: f.nextHdr, + HopLimit: 255, + SrcAddr: addr1, + DstAddr: addr2, + }) + + vv := hdr.View().ToVectorisedView() + vv.Append(f.data) + + e.InjectInbound(ProtocolNumber, stack.PacketBuffer{ + Data: vv, + }) + } + + if got, want := s.Stats().UDP.PacketsReceived.Value(), uint64(len(test.expectedPayloads)); got != want { + t.Errorf("got UDP Rx Packets = %d, want = %d", got, want) + } + + for i, p := range test.expectedPayloads { + gotPayload, _, err := ep.Read(nil) + if err != nil { + t.Fatalf("(i=%d) Read(nil): %s", i, err) + } + if diff := cmp.Diff(buffer.View(p), gotPayload); diff != "" { + t.Errorf("(i=%d) got UDP payload mismatch (-want +got):\n%s", i, diff) + } + } + + if gotPayload, _, err := ep.Read(nil); err != tcpip.ErrWouldBlock { + t.Fatalf("(last) got Read(nil) = (%x, _, %v), want = (_, _, %s)", gotPayload, err, tcpip.ErrWouldBlock) + } + }) + } +} diff --git a/pkg/tcpip/network/ipv6/ndp_test.go b/pkg/tcpip/network/ipv6/ndp_test.go new file mode 100644 index 000000000..12b70f7e9 --- /dev/null +++ b/pkg/tcpip/network/ipv6/ndp_test.go @@ -0,0 +1,906 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package ipv6 + +import ( + "strings" + "testing" + + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/checker" + "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/link/channel" + "gvisor.dev/gvisor/pkg/tcpip/stack" + "gvisor.dev/gvisor/pkg/tcpip/transport/icmp" +) + +// setupStackAndEndpoint creates a stack with a single NIC with a link-local +// address llladdr and an IPv6 endpoint to a remote with link-local address +// rlladdr +func setupStackAndEndpoint(t *testing.T, llladdr, rlladdr tcpip.Address) (*stack.Stack, stack.NetworkEndpoint) { + t.Helper() + + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{NewProtocol()}, + TransportProtocols: []stack.TransportProtocol{icmp.NewProtocol6()}, + }) + + if err := s.CreateNIC(1, &stubLinkEndpoint{}); err != nil { + t.Fatalf("CreateNIC(_) = %s", err) + } + if err := s.AddAddress(1, ProtocolNumber, llladdr); err != nil { + t.Fatalf("AddAddress(_, %d, %s) = %s", ProtocolNumber, llladdr, err) + } + + { + subnet, err := tcpip.NewSubnet(rlladdr, tcpip.AddressMask(strings.Repeat("\xff", len(rlladdr)))) + if err != nil { + t.Fatal(err) + } + s.SetRouteTable( + []tcpip.Route{{ + Destination: subnet, + NIC: 1, + }}, + ) + } + + netProto := s.NetworkProtocolInstance(ProtocolNumber) + if netProto == nil { + t.Fatalf("cannot find protocol instance for network protocol %d", ProtocolNumber) + } + + ep, err := netProto.NewEndpoint(0, tcpip.AddressWithPrefix{rlladdr, netProto.DefaultPrefixLen()}, &stubLinkAddressCache{}, &stubDispatcher{}, nil, s) + if err != nil { + t.Fatalf("NewEndpoint(_) = _, %s, want = _, nil", err) + } + + return s, ep +} + +// TestNeighorSolicitationWithSourceLinkLayerOption tests that receiving a +// valid NDP NS message with the Source Link Layer Address option results in a +// new entry in the link address cache for the sender of the message. +func TestNeighorSolicitationWithSourceLinkLayerOption(t *testing.T) { + const nicID = 1 + + tests := []struct { + name string + optsBuf []byte + expectedLinkAddr tcpip.LinkAddress + }{ + { + name: "Valid", + optsBuf: []byte{1, 1, 2, 3, 4, 5, 6, 7}, + expectedLinkAddr: "\x02\x03\x04\x05\x06\x07", + }, + { + name: "Too Small", + optsBuf: []byte{1, 1, 2, 3, 4, 5, 6}, + }, + { + name: "Invalid Length", + optsBuf: []byte{1, 2, 2, 3, 4, 5, 6, 7}, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{NewProtocol()}, + }) + e := channel.New(0, 1280, linkAddr0) + if err := s.CreateNIC(nicID, e); err != nil { + t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) + } + if err := s.AddAddress(nicID, ProtocolNumber, lladdr0); err != nil { + t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ProtocolNumber, lladdr0, err) + } + + ndpNSSize := header.ICMPv6NeighborSolicitMinimumSize + len(test.optsBuf) + hdr := buffer.NewPrependable(header.IPv6MinimumSize + ndpNSSize) + pkt := header.ICMPv6(hdr.Prepend(ndpNSSize)) + pkt.SetType(header.ICMPv6NeighborSolicit) + ns := header.NDPNeighborSolicit(pkt.NDPPayload()) + ns.SetTargetAddress(lladdr0) + opts := ns.Options() + copy(opts, test.optsBuf) + pkt.SetChecksum(header.ICMPv6Checksum(pkt, lladdr1, lladdr0, buffer.VectorisedView{})) + payloadLength := hdr.UsedLength() + ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) + ip.Encode(&header.IPv6Fields{ + PayloadLength: uint16(payloadLength), + NextHeader: uint8(header.ICMPv6ProtocolNumber), + HopLimit: 255, + SrcAddr: lladdr1, + DstAddr: lladdr0, + }) + + invalid := s.Stats().ICMP.V6PacketsReceived.Invalid + + // Invalid count should initially be 0. + if got := invalid.Value(); got != 0 { + t.Fatalf("got invalid = %d, want = 0", got) + } + + e.InjectInbound(ProtocolNumber, stack.PacketBuffer{ + Data: hdr.View().ToVectorisedView(), + }) + + linkAddr, c, err := s.GetLinkAddress(nicID, lladdr1, lladdr0, ProtocolNumber, nil) + if linkAddr != test.expectedLinkAddr { + t.Errorf("got link address = %s, want = %s", linkAddr, test.expectedLinkAddr) + } + + if test.expectedLinkAddr != "" { + if err != nil { + t.Errorf("s.GetLinkAddress(%d, %s, %s, %d, nil): %s", nicID, lladdr1, lladdr0, ProtocolNumber, err) + } + if c != nil { + t.Errorf("got unexpected channel") + } + + // Invalid count should not have increased. + if got := invalid.Value(); got != 0 { + t.Errorf("got invalid = %d, want = 0", got) + } + } else { + if err != tcpip.ErrWouldBlock { + t.Errorf("got s.GetLinkAddress(%d, %s, %s, %d, nil) = (_, _, %v), want = (_, _, %s)", nicID, lladdr1, lladdr0, ProtocolNumber, err, tcpip.ErrWouldBlock) + } + if c == nil { + t.Errorf("expected channel from call to s.GetLinkAddress(%d, %s, %s, %d, nil)", nicID, lladdr1, lladdr0, ProtocolNumber) + } + + // Invalid count should have increased. + if got := invalid.Value(); got != 1 { + t.Errorf("got invalid = %d, want = 1", got) + } + } + }) + } +} + +func TestNeighorSolicitationResponse(t *testing.T) { + const nicID = 1 + nicAddr := lladdr0 + remoteAddr := lladdr1 + nicAddrSNMC := header.SolicitedNodeAddr(nicAddr) + nicLinkAddr := linkAddr0 + remoteLinkAddr0 := linkAddr1 + remoteLinkAddr1 := linkAddr2 + + tests := []struct { + name string + nsOpts header.NDPOptionsSerializer + nsSrcLinkAddr tcpip.LinkAddress + nsSrc tcpip.Address + nsDst tcpip.Address + nsInvalid bool + naDstLinkAddr tcpip.LinkAddress + naSolicited bool + naSrc tcpip.Address + naDst tcpip.Address + }{ + { + name: "Unspecified source to multicast destination", + nsOpts: nil, + nsSrcLinkAddr: remoteLinkAddr0, + nsSrc: header.IPv6Any, + nsDst: nicAddrSNMC, + nsInvalid: false, + naDstLinkAddr: remoteLinkAddr0, + naSolicited: false, + naSrc: nicAddr, + naDst: header.IPv6AllNodesMulticastAddress, + }, + { + name: "Unspecified source with source ll option to multicast destination", + nsOpts: header.NDPOptionsSerializer{ + header.NDPSourceLinkLayerAddressOption(remoteLinkAddr0[:]), + }, + nsSrcLinkAddr: remoteLinkAddr0, + nsSrc: header.IPv6Any, + nsDst: nicAddrSNMC, + nsInvalid: true, + }, + { + name: "Unspecified source to unicast destination", + nsOpts: nil, + nsSrcLinkAddr: remoteLinkAddr0, + nsSrc: header.IPv6Any, + nsDst: nicAddr, + nsInvalid: false, + naDstLinkAddr: remoteLinkAddr0, + naSolicited: false, + naSrc: nicAddr, + naDst: header.IPv6AllNodesMulticastAddress, + }, + { + name: "Unspecified source with source ll option to unicast destination", + nsOpts: header.NDPOptionsSerializer{ + header.NDPSourceLinkLayerAddressOption(remoteLinkAddr0[:]), + }, + nsSrcLinkAddr: remoteLinkAddr0, + nsSrc: header.IPv6Any, + nsDst: nicAddr, + nsInvalid: true, + }, + + { + name: "Specified source with 1 source ll to multicast destination", + nsOpts: header.NDPOptionsSerializer{ + header.NDPSourceLinkLayerAddressOption(remoteLinkAddr0[:]), + }, + nsSrcLinkAddr: remoteLinkAddr0, + nsSrc: remoteAddr, + nsDst: nicAddrSNMC, + nsInvalid: false, + naDstLinkAddr: remoteLinkAddr0, + naSolicited: true, + naSrc: nicAddr, + naDst: remoteAddr, + }, + { + name: "Specified source with 1 source ll different from route to multicast destination", + nsOpts: header.NDPOptionsSerializer{ + header.NDPSourceLinkLayerAddressOption(remoteLinkAddr1[:]), + }, + nsSrcLinkAddr: remoteLinkAddr0, + nsSrc: remoteAddr, + nsDst: nicAddrSNMC, + nsInvalid: false, + naDstLinkAddr: remoteLinkAddr1, + naSolicited: true, + naSrc: nicAddr, + naDst: remoteAddr, + }, + { + name: "Specified source to multicast destination", + nsOpts: nil, + nsSrcLinkAddr: remoteLinkAddr0, + nsSrc: remoteAddr, + nsDst: nicAddrSNMC, + nsInvalid: true, + }, + { + name: "Specified source with 2 source ll to multicast destination", + nsOpts: header.NDPOptionsSerializer{ + header.NDPSourceLinkLayerAddressOption(remoteLinkAddr0[:]), + header.NDPSourceLinkLayerAddressOption(remoteLinkAddr1[:]), + }, + nsSrcLinkAddr: remoteLinkAddr0, + nsSrc: remoteAddr, + nsDst: nicAddrSNMC, + nsInvalid: true, + }, + + { + name: "Specified source to unicast destination", + nsOpts: nil, + nsSrcLinkAddr: remoteLinkAddr0, + nsSrc: remoteAddr, + nsDst: nicAddr, + nsInvalid: false, + naDstLinkAddr: remoteLinkAddr0, + naSolicited: true, + naSrc: nicAddr, + naDst: remoteAddr, + }, + { + name: "Specified source with 1 source ll to unicast destination", + nsOpts: header.NDPOptionsSerializer{ + header.NDPSourceLinkLayerAddressOption(remoteLinkAddr0[:]), + }, + nsSrcLinkAddr: remoteLinkAddr0, + nsSrc: remoteAddr, + nsDst: nicAddr, + nsInvalid: false, + naDstLinkAddr: remoteLinkAddr0, + naSolicited: true, + naSrc: nicAddr, + naDst: remoteAddr, + }, + { + name: "Specified source with 1 source ll different from route to unicast destination", + nsOpts: header.NDPOptionsSerializer{ + header.NDPSourceLinkLayerAddressOption(remoteLinkAddr1[:]), + }, + nsSrcLinkAddr: remoteLinkAddr0, + nsSrc: remoteAddr, + nsDst: nicAddr, + nsInvalid: false, + naDstLinkAddr: remoteLinkAddr1, + naSolicited: true, + naSrc: nicAddr, + naDst: remoteAddr, + }, + { + name: "Specified source with 2 source ll to unicast destination", + nsOpts: header.NDPOptionsSerializer{ + header.NDPSourceLinkLayerAddressOption(remoteLinkAddr0[:]), + header.NDPSourceLinkLayerAddressOption(remoteLinkAddr1[:]), + }, + nsSrcLinkAddr: remoteLinkAddr0, + nsSrc: remoteAddr, + nsDst: nicAddr, + nsInvalid: true, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{NewProtocol()}, + }) + e := channel.New(1, 1280, nicLinkAddr) + if err := s.CreateNIC(nicID, e); err != nil { + t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) + } + if err := s.AddAddress(nicID, ProtocolNumber, nicAddr); err != nil { + t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ProtocolNumber, nicAddr, err) + } + + ndpNSSize := header.ICMPv6NeighborSolicitMinimumSize + test.nsOpts.Length() + hdr := buffer.NewPrependable(header.IPv6MinimumSize + ndpNSSize) + pkt := header.ICMPv6(hdr.Prepend(ndpNSSize)) + pkt.SetType(header.ICMPv6NeighborSolicit) + ns := header.NDPNeighborSolicit(pkt.NDPPayload()) + ns.SetTargetAddress(nicAddr) + opts := ns.Options() + opts.Serialize(test.nsOpts) + pkt.SetChecksum(header.ICMPv6Checksum(pkt, test.nsSrc, test.nsDst, buffer.VectorisedView{})) + payloadLength := hdr.UsedLength() + ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) + ip.Encode(&header.IPv6Fields{ + PayloadLength: uint16(payloadLength), + NextHeader: uint8(header.ICMPv6ProtocolNumber), + HopLimit: 255, + SrcAddr: test.nsSrc, + DstAddr: test.nsDst, + }) + + invalid := s.Stats().ICMP.V6PacketsReceived.Invalid + + // Invalid count should initially be 0. + if got := invalid.Value(); got != 0 { + t.Fatalf("got invalid = %d, want = 0", got) + } + + e.InjectLinkAddr(ProtocolNumber, test.nsSrcLinkAddr, stack.PacketBuffer{ + Data: hdr.View().ToVectorisedView(), + }) + + if test.nsInvalid { + if got := invalid.Value(); got != 1 { + t.Fatalf("got invalid = %d, want = 1", got) + } + + if p, got := e.Read(); got { + t.Fatalf("unexpected response to an invalid NS = %+v", p.Pkt) + } + + // If we expected the NS to be invalid, we have nothing else to check. + return + } + + if got := invalid.Value(); got != 0 { + t.Fatalf("got invalid = %d, want = 0", got) + } + + p, got := e.Read() + if !got { + t.Fatal("expected an NDP NA response") + } + + if p.Route.RemoteLinkAddress != test.naDstLinkAddr { + t.Errorf("got p.Route.RemoteLinkAddress = %s, want = %s", p.Route.RemoteLinkAddress, test.naDstLinkAddr) + } + + checker.IPv6(t, p.Pkt.Header.View(), + checker.SrcAddr(test.naSrc), + checker.DstAddr(test.naDst), + checker.TTL(header.NDPHopLimit), + checker.NDPNA( + checker.NDPNASolicitedFlag(test.naSolicited), + checker.NDPNATargetAddress(nicAddr), + checker.NDPNAOptions([]header.NDPOption{ + header.NDPTargetLinkLayerAddressOption(nicLinkAddr[:]), + }), + )) + }) + } +} + +// TestNeighorAdvertisementWithTargetLinkLayerOption tests that receiving a +// valid NDP NA message with the Target Link Layer Address option results in a +// new entry in the link address cache for the target of the message. +func TestNeighorAdvertisementWithTargetLinkLayerOption(t *testing.T) { + const nicID = 1 + + tests := []struct { + name string + optsBuf []byte + expectedLinkAddr tcpip.LinkAddress + }{ + { + name: "Valid", + optsBuf: []byte{2, 1, 2, 3, 4, 5, 6, 7}, + expectedLinkAddr: "\x02\x03\x04\x05\x06\x07", + }, + { + name: "Too Small", + optsBuf: []byte{2, 1, 2, 3, 4, 5, 6}, + }, + { + name: "Invalid Length", + optsBuf: []byte{2, 2, 2, 3, 4, 5, 6, 7}, + }, + { + name: "Multiple", + optsBuf: []byte{ + 2, 1, 2, 3, 4, 5, 6, 7, + 2, 1, 2, 3, 4, 5, 6, 8, + }, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{NewProtocol()}, + }) + e := channel.New(0, 1280, linkAddr0) + if err := s.CreateNIC(nicID, e); err != nil { + t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) + } + if err := s.AddAddress(nicID, ProtocolNumber, lladdr0); err != nil { + t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ProtocolNumber, lladdr0, err) + } + + ndpNASize := header.ICMPv6NeighborAdvertMinimumSize + len(test.optsBuf) + hdr := buffer.NewPrependable(header.IPv6MinimumSize + ndpNASize) + pkt := header.ICMPv6(hdr.Prepend(ndpNASize)) + pkt.SetType(header.ICMPv6NeighborAdvert) + ns := header.NDPNeighborAdvert(pkt.NDPPayload()) + ns.SetTargetAddress(lladdr1) + opts := ns.Options() + copy(opts, test.optsBuf) + pkt.SetChecksum(header.ICMPv6Checksum(pkt, lladdr1, lladdr0, buffer.VectorisedView{})) + payloadLength := hdr.UsedLength() + ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) + ip.Encode(&header.IPv6Fields{ + PayloadLength: uint16(payloadLength), + NextHeader: uint8(header.ICMPv6ProtocolNumber), + HopLimit: 255, + SrcAddr: lladdr1, + DstAddr: lladdr0, + }) + + invalid := s.Stats().ICMP.V6PacketsReceived.Invalid + + // Invalid count should initially be 0. + if got := invalid.Value(); got != 0 { + t.Fatalf("got invalid = %d, want = 0", got) + } + + e.InjectInbound(ProtocolNumber, stack.PacketBuffer{ + Data: hdr.View().ToVectorisedView(), + }) + + linkAddr, c, err := s.GetLinkAddress(nicID, lladdr1, lladdr0, ProtocolNumber, nil) + if linkAddr != test.expectedLinkAddr { + t.Errorf("got link address = %s, want = %s", linkAddr, test.expectedLinkAddr) + } + + if test.expectedLinkAddr != "" { + if err != nil { + t.Errorf("s.GetLinkAddress(%d, %s, %s, %d, nil): %s", nicID, lladdr1, lladdr0, ProtocolNumber, err) + } + if c != nil { + t.Errorf("got unexpected channel") + } + + // Invalid count should not have increased. + if got := invalid.Value(); got != 0 { + t.Errorf("got invalid = %d, want = 0", got) + } + } else { + if err != tcpip.ErrWouldBlock { + t.Errorf("got s.GetLinkAddress(%d, %s, %s, %d, nil) = (_, _, %v), want = (_, _, %s)", nicID, lladdr1, lladdr0, ProtocolNumber, err, tcpip.ErrWouldBlock) + } + if c == nil { + t.Errorf("expected channel from call to s.GetLinkAddress(%d, %s, %s, %d, nil)", nicID, lladdr1, lladdr0, ProtocolNumber) + } + + // Invalid count should have increased. + if got := invalid.Value(); got != 1 { + t.Errorf("got invalid = %d, want = 1", got) + } + } + }) + } +} + +func TestNDPValidation(t *testing.T) { + setup := func(t *testing.T) (*stack.Stack, stack.NetworkEndpoint, stack.Route) { + t.Helper() + + // Create a stack with the assigned link-local address lladdr0 + // and an endpoint to lladdr1. + s, ep := setupStackAndEndpoint(t, lladdr0, lladdr1) + + r, err := s.FindRoute(1, lladdr0, lladdr1, ProtocolNumber, false /* multicastLoop */) + if err != nil { + t.Fatalf("FindRoute(_) = _, %s, want = _, nil", err) + } + + return s, ep, r + } + + handleIPv6Payload := func(hdr buffer.Prependable, hopLimit uint8, atomicFragment bool, ep stack.NetworkEndpoint, r *stack.Route) { + nextHdr := uint8(header.ICMPv6ProtocolNumber) + if atomicFragment { + bytes := hdr.Prepend(header.IPv6FragmentExtHdrLength) + bytes[0] = nextHdr + nextHdr = uint8(header.IPv6FragmentExtHdrIdentifier) + } + + payloadLength := hdr.UsedLength() + ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) + ip.Encode(&header.IPv6Fields{ + PayloadLength: uint16(payloadLength), + NextHeader: nextHdr, + HopLimit: hopLimit, + SrcAddr: r.LocalAddress, + DstAddr: r.RemoteAddress, + }) + ep.HandlePacket(r, stack.PacketBuffer{ + Data: hdr.View().ToVectorisedView(), + }) + } + + var tllData [header.NDPLinkLayerAddressSize]byte + header.NDPOptions(tllData[:]).Serialize(header.NDPOptionsSerializer{ + header.NDPTargetLinkLayerAddressOption(linkAddr1), + }) + + types := []struct { + name string + typ header.ICMPv6Type + size int + extraData []byte + statCounter func(tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter + }{ + { + name: "RouterSolicit", + typ: header.ICMPv6RouterSolicit, + size: header.ICMPv6MinimumSize, + statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { + return stats.RouterSolicit + }, + }, + { + name: "RouterAdvert", + typ: header.ICMPv6RouterAdvert, + size: header.ICMPv6HeaderSize + header.NDPRAMinimumSize, + statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { + return stats.RouterAdvert + }, + }, + { + name: "NeighborSolicit", + typ: header.ICMPv6NeighborSolicit, + size: header.ICMPv6NeighborSolicitMinimumSize, + statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { + return stats.NeighborSolicit + }, + }, + { + name: "NeighborAdvert", + typ: header.ICMPv6NeighborAdvert, + size: header.ICMPv6NeighborAdvertMinimumSize, + extraData: tllData[:], + statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { + return stats.NeighborAdvert + }, + }, + { + name: "RedirectMsg", + typ: header.ICMPv6RedirectMsg, + size: header.ICMPv6MinimumSize, + statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { + return stats.RedirectMsg + }, + }, + } + + subTests := []struct { + name string + atomicFragment bool + hopLimit uint8 + code uint8 + valid bool + }{ + { + name: "Valid", + atomicFragment: false, + hopLimit: header.NDPHopLimit, + code: 0, + valid: true, + }, + { + name: "Fragmented", + atomicFragment: true, + hopLimit: header.NDPHopLimit, + code: 0, + valid: false, + }, + { + name: "Invalid hop limit", + atomicFragment: false, + hopLimit: header.NDPHopLimit - 1, + code: 0, + valid: false, + }, + { + name: "Invalid ICMPv6 code", + atomicFragment: false, + hopLimit: header.NDPHopLimit, + code: 1, + valid: false, + }, + } + + for _, typ := range types { + t.Run(typ.name, func(t *testing.T) { + for _, test := range subTests { + t.Run(test.name, func(t *testing.T) { + s, ep, r := setup(t) + defer r.Release() + + stats := s.Stats().ICMP.V6PacketsReceived + invalid := stats.Invalid + typStat := typ.statCounter(stats) + + extraDataLen := len(typ.extraData) + hdr := buffer.NewPrependable(header.IPv6MinimumSize + typ.size + extraDataLen + header.IPv6FragmentExtHdrLength) + extraData := buffer.View(hdr.Prepend(extraDataLen)) + copy(extraData, typ.extraData) + pkt := header.ICMPv6(hdr.Prepend(typ.size)) + pkt.SetType(typ.typ) + pkt.SetCode(test.code) + pkt.SetChecksum(header.ICMPv6Checksum(pkt, r.LocalAddress, r.RemoteAddress, extraData.ToVectorisedView())) + + // Rx count of the NDP message should initially be 0. + if got := typStat.Value(); got != 0 { + t.Errorf("got %s = %d, want = 0", typ.name, got) + } + + // Invalid count should initially be 0. + if got := invalid.Value(); got != 0 { + t.Errorf("got invalid = %d, want = 0", got) + } + + if t.Failed() { + t.FailNow() + } + + handleIPv6Payload(hdr, test.hopLimit, test.atomicFragment, ep, &r) + + // Rx count of the NDP packet should have increased. + if got := typStat.Value(); got != 1 { + t.Errorf("got %s = %d, want = 1", typ.name, got) + } + + want := uint64(0) + if !test.valid { + // Invalid count should have increased. + want = 1 + } + if got := invalid.Value(); got != want { + t.Errorf("got invalid = %d, want = %d", got, want) + } + }) + } + }) + } +} + +// TestRouterAdvertValidation tests that when the NIC is configured to handle +// NDP Router Advertisement packets, it validates the Router Advertisement +// properly before handling them. +func TestRouterAdvertValidation(t *testing.T) { + tests := []struct { + name string + src tcpip.Address + hopLimit uint8 + code uint8 + ndpPayload []byte + expectedSuccess bool + }{ + { + "OK", + lladdr0, + 255, + 0, + []byte{ + 0, 0, 0, 0, + 0, 0, 0, 0, + 0, 0, 0, 0, + }, + true, + }, + { + "NonLinkLocalSourceAddr", + addr1, + 255, + 0, + []byte{ + 0, 0, 0, 0, + 0, 0, 0, 0, + 0, 0, 0, 0, + }, + false, + }, + { + "HopLimitNot255", + lladdr0, + 254, + 0, + []byte{ + 0, 0, 0, 0, + 0, 0, 0, 0, + 0, 0, 0, 0, + }, + false, + }, + { + "NonZeroCode", + lladdr0, + 255, + 1, + []byte{ + 0, 0, 0, 0, + 0, 0, 0, 0, + 0, 0, 0, 0, + }, + false, + }, + { + "NDPPayloadTooSmall", + lladdr0, + 255, + 0, + []byte{ + 0, 0, 0, 0, + 0, 0, 0, 0, + 0, 0, 0, + }, + false, + }, + { + "OKWithOptions", + lladdr0, + 255, + 0, + []byte{ + // RA payload + 0, 0, 0, 0, + 0, 0, 0, 0, + 0, 0, 0, 0, + + // Option #1 (TargetLinkLayerAddress) + 2, 1, 0, 0, 0, 0, 0, 0, + + // Option #2 (unrecognized) + 255, 1, 0, 0, 0, 0, 0, 0, + + // Option #3 (PrefixInformation) + 3, 4, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + }, + true, + }, + { + "OptionWithZeroLength", + lladdr0, + 255, + 0, + []byte{ + // RA payload + 0, 0, 0, 0, + 0, 0, 0, 0, + 0, 0, 0, 0, + + // Option #1 (TargetLinkLayerAddress) + // Invalid as it has 0 length. + 2, 0, 0, 0, 0, 0, 0, 0, + + // Option #2 (unrecognized) + 255, 1, 0, 0, 0, 0, 0, 0, + + // Option #3 (PrefixInformation) + 3, 4, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + }, + false, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + e := channel.New(10, 1280, linkAddr1) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{NewProtocol()}, + }) + + if err := s.CreateNIC(1, e); err != nil { + t.Fatalf("CreateNIC(_) = %s", err) + } + + icmpSize := header.ICMPv6HeaderSize + len(test.ndpPayload) + hdr := buffer.NewPrependable(header.IPv6MinimumSize + icmpSize) + pkt := header.ICMPv6(hdr.Prepend(icmpSize)) + pkt.SetType(header.ICMPv6RouterAdvert) + pkt.SetCode(test.code) + copy(pkt.NDPPayload(), test.ndpPayload) + payloadLength := hdr.UsedLength() + pkt.SetChecksum(header.ICMPv6Checksum(pkt, test.src, header.IPv6AllNodesMulticastAddress, buffer.VectorisedView{})) + ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) + ip.Encode(&header.IPv6Fields{ + PayloadLength: uint16(payloadLength), + NextHeader: uint8(icmp.ProtocolNumber6), + HopLimit: test.hopLimit, + SrcAddr: test.src, + DstAddr: header.IPv6AllNodesMulticastAddress, + }) + + stats := s.Stats().ICMP.V6PacketsReceived + invalid := stats.Invalid + rxRA := stats.RouterAdvert + + if got := invalid.Value(); got != 0 { + t.Fatalf("got invalid = %d, want = 0", got) + } + if got := rxRA.Value(); got != 0 { + t.Fatalf("got rxRA = %d, want = 0", got) + } + + e.InjectInbound(header.IPv6ProtocolNumber, stack.PacketBuffer{ + Data: hdr.View().ToVectorisedView(), + }) + + if got := rxRA.Value(); got != 1 { + t.Fatalf("got rxRA = %d, want = 1", got) + } + + if test.expectedSuccess { + if got := invalid.Value(); got != 0 { + t.Fatalf("got invalid = %d, want = 0", got) + } + } else { + if got := invalid.Value(); got != 1 { + t.Fatalf("got invalid = %d, want = 1", got) + } + } + }) + } +} |