diff options
Diffstat (limited to 'pkg/tcpip')
-rw-r--r-- | pkg/tcpip/checker/checker.go | 18 | ||||
-rw-r--r-- | pkg/tcpip/link/ethernet/BUILD | 15 | ||||
-rw-r--r-- | pkg/tcpip/link/ethernet/ethernet.go | 99 | ||||
-rw-r--r-- | pkg/tcpip/link/pipe/pipe.go | 39 | ||||
-rw-r--r-- | pkg/tcpip/network/BUILD | 1 | ||||
-rw-r--r-- | pkg/tcpip/network/ip_test.go | 414 | ||||
-rw-r--r-- | pkg/tcpip/network/ipv4/ipv4.go | 147 | ||||
-rw-r--r-- | pkg/tcpip/network/ipv4/ipv4_test.go | 302 | ||||
-rw-r--r-- | pkg/tcpip/network/ipv6/icmp.go | 18 | ||||
-rw-r--r-- | pkg/tcpip/network/ipv6/icmp_test.go | 254 | ||||
-rw-r--r-- | pkg/tcpip/network/ipv6/ipv6.go | 69 | ||||
-rw-r--r-- | pkg/tcpip/network/ipv6/ipv6_test.go | 61 | ||||
-rw-r--r-- | pkg/tcpip/stack/neighbor_entry.go | 6 | ||||
-rw-r--r-- | pkg/tcpip/stack/neighbor_entry_test.go | 277 | ||||
-rw-r--r-- | pkg/tcpip/stack/nic.go | 5 | ||||
-rw-r--r-- | pkg/tcpip/stack/packet_buffer.go | 20 | ||||
-rw-r--r-- | pkg/tcpip/tcpip.go | 2 | ||||
-rw-r--r-- | pkg/tcpip/tests/integration/BUILD | 1 | ||||
-rw-r--r-- | pkg/tcpip/tests/integration/forward_test.go | 13 | ||||
-rw-r--r-- | pkg/tcpip/tests/integration/link_resolution_test.go | 7 | ||||
-rw-r--r-- | pkg/tcpip/tests/integration/multicast_broadcast_test.go | 2 |
21 files changed, 1541 insertions, 229 deletions
diff --git a/pkg/tcpip/checker/checker.go b/pkg/tcpip/checker/checker.go index d4d785cca..6f81b0164 100644 --- a/pkg/tcpip/checker/checker.go +++ b/pkg/tcpip/checker/checker.go @@ -178,6 +178,24 @@ func PayloadLen(payloadLength int) NetworkChecker { } } +// IPPayload creates a checker that checks the payload. +func IPPayload(payload []byte) NetworkChecker { + return func(t *testing.T, h []header.Network) { + t.Helper() + + got := h[0].Payload() + + // cmp.Diff does not consider nil slices equal to empty slices, but we do. + if len(got) == 0 && len(payload) == 0 { + return + } + + if diff := cmp.Diff(payload, got); diff != "" { + t.Errorf("payload mismatch (-want +got):\n%s", diff) + } + } +} + // IPv4Options returns a checker that checks the options in an IPv4 packet. func IPv4Options(want []byte) NetworkChecker { return func(t *testing.T, h []header.Network) { diff --git a/pkg/tcpip/link/ethernet/BUILD b/pkg/tcpip/link/ethernet/BUILD new file mode 100644 index 000000000..ec92ed623 --- /dev/null +++ b/pkg/tcpip/link/ethernet/BUILD @@ -0,0 +1,15 @@ +load("//tools:defs.bzl", "go_library") + +package(licenses = ["notice"]) + +go_library( + name = "ethernet", + srcs = ["ethernet.go"], + visibility = ["//visibility:public"], + deps = [ + "//pkg/tcpip", + "//pkg/tcpip/header", + "//pkg/tcpip/link/nested", + "//pkg/tcpip/stack", + ], +) diff --git a/pkg/tcpip/link/ethernet/ethernet.go b/pkg/tcpip/link/ethernet/ethernet.go new file mode 100644 index 000000000..3eef7cd56 --- /dev/null +++ b/pkg/tcpip/link/ethernet/ethernet.go @@ -0,0 +1,99 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package ethernet provides an implementation of an ethernet link endpoint that +// wraps an inner link endpoint. +package ethernet + +import ( + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/link/nested" + "gvisor.dev/gvisor/pkg/tcpip/stack" +) + +var _ stack.NetworkDispatcher = (*Endpoint)(nil) +var _ stack.LinkEndpoint = (*Endpoint)(nil) + +// New returns an ethernet link endpoint that wraps an inner link endpoint. +func New(ep stack.LinkEndpoint) *Endpoint { + var e Endpoint + e.Endpoint.Init(ep, &e) + return &e +} + +// Endpoint is an ethernet endpoint. +// +// It adds an ethernet header to packets before sending them out through its +// inner link endpoint and consumes an ethernet header before sending the +// packet to the stack. +type Endpoint struct { + nested.Endpoint +} + +// DeliverNetworkPacket implements stack.NetworkDispatcher. +func (e *Endpoint) DeliverNetworkPacket(_, _ tcpip.LinkAddress, _ tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { + hdr, ok := pkt.LinkHeader().Consume(header.EthernetMinimumSize) + if !ok { + return + } + + eth := header.Ethernet(hdr) + if dst := eth.DestinationAddress(); dst == e.Endpoint.LinkAddress() || dst == header.EthernetBroadcastAddress || header.IsMulticastEthernetAddress(dst) { + e.Endpoint.DeliverNetworkPacket(eth.SourceAddress() /* remote */, dst /* local */, eth.Type() /* protocol */, pkt) + } +} + +// Capabilities implements stack.LinkEndpoint. +func (e *Endpoint) Capabilities() stack.LinkEndpointCapabilities { + return stack.CapabilityResolutionRequired | e.Endpoint.Capabilities() +} + +// WritePacket implements stack.LinkEndpoint. +func (e *Endpoint) WritePacket(r *stack.Route, gso *stack.GSO, proto tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error { + e.AddHeader(e.Endpoint.LinkAddress(), r.RemoteLinkAddress, proto, pkt) + return e.Endpoint.WritePacket(r, gso, proto, pkt) +} + +// WritePackets implements stack.LinkEndpoint. +func (e *Endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.PacketBufferList, proto tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { + linkAddr := e.Endpoint.LinkAddress() + + for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() { + e.AddHeader(linkAddr, r.RemoteLinkAddress, proto, pkt) + } + + return e.Endpoint.WritePackets(r, gso, pkts, proto) +} + +// MaxHeaderLength implements stack.LinkEndpoint. +func (e *Endpoint) MaxHeaderLength() uint16 { + return header.EthernetMinimumSize + e.Endpoint.MaxHeaderLength() +} + +// ARPHardwareType implements stack.LinkEndpoint. +func (*Endpoint) ARPHardwareType() header.ARPHardwareType { + return header.ARPHardwareEther +} + +// AddHeader implements stack.LinkEndpoint. +func (*Endpoint) AddHeader(local, remote tcpip.LinkAddress, proto tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { + eth := header.Ethernet(pkt.LinkHeader().Push(header.EthernetMinimumSize)) + fields := header.EthernetFields{ + SrcAddr: local, + DstAddr: remote, + Type: proto, + } + eth.Encode(&fields) +} diff --git a/pkg/tcpip/link/pipe/pipe.go b/pkg/tcpip/link/pipe/pipe.go index 76f563811..523b0d24b 100644 --- a/pkg/tcpip/link/pipe/pipe.go +++ b/pkg/tcpip/link/pipe/pipe.go @@ -26,27 +26,23 @@ import ( var _ stack.LinkEndpoint = (*Endpoint)(nil) // New returns both ends of a new pipe. -func New(linkAddr1, linkAddr2 tcpip.LinkAddress, capabilities stack.LinkEndpointCapabilities) (*Endpoint, *Endpoint) { +func New(linkAddr1, linkAddr2 tcpip.LinkAddress) (*Endpoint, *Endpoint) { ep1 := &Endpoint{ - linkAddr: linkAddr1, - capabilities: capabilities, + linkAddr: linkAddr1, } ep2 := &Endpoint{ - linkAddr: linkAddr2, - linked: ep1, - capabilities: capabilities, + linkAddr: linkAddr2, } ep1.linked = ep2 + ep2.linked = ep1 return ep1, ep2 } // Endpoint is one end of a pipe. type Endpoint struct { - capabilities stack.LinkEndpointCapabilities - linkAddr tcpip.LinkAddress - dispatcher stack.NetworkDispatcher - linked *Endpoint - onWritePacket func(*stack.PacketBuffer) + dispatcher stack.NetworkDispatcher + linked *Endpoint + linkAddr tcpip.LinkAddress } // WritePacket implements stack.LinkEndpoint. @@ -55,16 +51,11 @@ func (e *Endpoint) WritePacket(r *stack.Route, _ *stack.GSO, proto tcpip.Network return nil } - // The pipe endpoint will accept all multicast/broadcast link traffic and only - // unicast traffic destined to itself. - if len(e.linked.linkAddr) != 0 && - r.RemoteLinkAddress != e.linked.linkAddr && - r.RemoteLinkAddress != header.EthernetBroadcastAddress && - !header.IsMulticastEthernetAddress(r.RemoteLinkAddress) { - return nil - } - - e.linked.dispatcher.DeliverNetworkPacket(e.linkAddr, r.RemoteLinkAddress, proto, stack.NewPacketBuffer(stack.PacketBufferOptions{ + // Note that the local address from the perspective of this endpoint is the + // remote address from the perspective of the other end of the pipe + // (e.linked). Similarly, the remote address from the perspective of this + // endpoint is the local address on the other end. + e.linked.dispatcher.DeliverNetworkPacket(r.LocalLinkAddress /* remote */, r.RemoteLinkAddress /* local */, proto, stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: buffer.NewVectorisedView(pkt.Size(), pkt.Views()), })) @@ -100,8 +91,8 @@ func (*Endpoint) MTU() uint32 { } // Capabilities implements stack.LinkEndpoint. -func (e *Endpoint) Capabilities() stack.LinkEndpointCapabilities { - return e.capabilities +func (*Endpoint) Capabilities() stack.LinkEndpointCapabilities { + return 0 } // MaxHeaderLength implements stack.LinkEndpoint. @@ -116,7 +107,7 @@ func (e *Endpoint) LinkAddress() tcpip.LinkAddress { // ARPHardwareType implements stack.LinkEndpoint. func (*Endpoint) ARPHardwareType() header.ARPHardwareType { - return header.ARPHardwareEther + return header.ARPHardwareNone } // AddHeader implements stack.LinkEndpoint. diff --git a/pkg/tcpip/network/BUILD b/pkg/tcpip/network/BUILD index 59710352b..c118a2929 100644 --- a/pkg/tcpip/network/BUILD +++ b/pkg/tcpip/network/BUILD @@ -12,6 +12,7 @@ go_test( "//pkg/sync", "//pkg/tcpip", "//pkg/tcpip/buffer", + "//pkg/tcpip/checker", "//pkg/tcpip/header", "//pkg/tcpip/link/channel", "//pkg/tcpip/link/loopback", diff --git a/pkg/tcpip/network/ip_test.go b/pkg/tcpip/network/ip_test.go index d436873b6..f20b94d97 100644 --- a/pkg/tcpip/network/ip_test.go +++ b/pkg/tcpip/network/ip_test.go @@ -15,11 +15,13 @@ package ip_test import ( + "strings" "testing" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/checker" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/link/channel" "gvisor.dev/gvisor/pkg/tcpip/link/loopback" @@ -320,6 +322,7 @@ func TestSourceAddressValidation(t *testing.T) { SrcAddr: src, DstAddr: localIPv4Addr, }) + ip.SetChecksum(^ip.CalculateChecksum()) e.InjectInbound(header.IPv4ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: hdr.View().ToVectorisedView(), @@ -342,7 +345,6 @@ func TestSourceAddressValidation(t *testing.T) { SrcAddr: src, DstAddr: localIPv6Addr, }) - e.InjectInbound(header.IPv6ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: hdr.View().ToVectorisedView(), })) @@ -579,6 +581,7 @@ func TestIPv4Receive(t *testing.T) { SrcAddr: remoteIPv4Addr, DstAddr: localIPv4Addr, }) + ip.SetChecksum(^ip.CalculateChecksum()) // Make payload be non-zero. for i := header.IPv4MinimumSize; i < totalLen; i++ { @@ -660,6 +663,7 @@ func TestIPv4ReceiveControl(t *testing.T) { SrcAddr: "\x0a\x00\x00\xbb", DstAddr: localIPv4Addr, }) + ip.SetChecksum(^ip.CalculateChecksum()) // Create the ICMP header. icmp := header.ICMPv4(view[header.IPv4MinimumSize:]) @@ -679,6 +683,7 @@ func TestIPv4ReceiveControl(t *testing.T) { SrcAddr: localIPv4Addr, DstAddr: remoteIPv4Addr, }) + ip.SetChecksum(^ip.CalculateChecksum()) // Make payload be non-zero. for i := dataOffset; i < len(view); i++ { @@ -732,6 +737,8 @@ func TestIPv4FragmentationReceive(t *testing.T) { SrcAddr: remoteIPv4Addr, DstAddr: localIPv4Addr, }) + ip1.SetChecksum(^ip1.CalculateChecksum()) + // Make payload be non-zero. for i := header.IPv4MinimumSize; i < totalLen; i++ { frag1[i] = uint8(i) @@ -748,6 +755,8 @@ func TestIPv4FragmentationReceive(t *testing.T) { SrcAddr: remoteIPv4Addr, DstAddr: localIPv4Addr, }) + ip2.SetChecksum(^ip2.CalculateChecksum()) + // Make payload be non-zero. for i := header.IPv4MinimumSize; i < totalLen; i++ { frag2[i] = uint8(i) @@ -1020,3 +1029,406 @@ func truncatedPacket(view buffer.View, trunc, netHdrLen int) *stack.PacketBuffer _, _ = pkt.NetworkHeader().Consume(netHdrLen) return pkt } + +func TestWriteHeaderIncludedPacket(t *testing.T) { + const ( + nicID = 1 + transportProto = 5 + + dataLen = 4 + optionsLen = 4 + ) + + dataBuf := [dataLen]byte{1, 2, 3, 4} + data := dataBuf[:] + + ipv4OptionsBuf := [optionsLen]byte{0, 1, 0, 1} + ipv4Options := ipv4OptionsBuf[:] + + ipv6FragmentExtHdrBuf := [header.IPv6FragmentExtHdrLength]byte{transportProto, 0, 62, 4, 1, 2, 3, 4} + ipv6FragmentExtHdr := ipv6FragmentExtHdrBuf[:] + + var ipv6PayloadWithExtHdrBuf [dataLen + header.IPv6FragmentExtHdrLength]byte + ipv6PayloadWithExtHdr := ipv6PayloadWithExtHdrBuf[:] + if n := copy(ipv6PayloadWithExtHdr, ipv6FragmentExtHdr); n != len(ipv6FragmentExtHdr) { + t.Fatalf("copied %d bytes, expected %d bytes", n, len(ipv6FragmentExtHdr)) + } + if n := copy(ipv6PayloadWithExtHdr[header.IPv6FragmentExtHdrLength:], data); n != len(data) { + t.Fatalf("copied %d bytes, expected %d bytes", n, len(data)) + } + + tests := []struct { + name string + protoFactory stack.NetworkProtocolFactory + protoNum tcpip.NetworkProtocolNumber + nicAddr tcpip.Address + remoteAddr tcpip.Address + pktGen func(*testing.T, tcpip.Address) buffer.View + checker func(*testing.T, *stack.PacketBuffer, tcpip.Address) + expectedErr *tcpip.Error + }{ + { + name: "IPv4", + protoFactory: ipv4.NewProtocol, + protoNum: ipv4.ProtocolNumber, + nicAddr: localIPv4Addr, + remoteAddr: remoteIPv4Addr, + pktGen: func(t *testing.T, src tcpip.Address) buffer.View { + totalLen := header.IPv4MinimumSize + len(data) + hdr := buffer.NewPrependable(totalLen) + if n := copy(hdr.Prepend(len(data)), data); n != len(data) { + t.Fatalf("copied %d bytes, expected %d bytes", n, len(data)) + } + ip := header.IPv4(hdr.Prepend(header.IPv4MinimumSize)) + ip.Encode(&header.IPv4Fields{ + IHL: header.IPv4MinimumSize, + Protocol: transportProto, + TTL: ipv4.DefaultTTL, + SrcAddr: src, + DstAddr: header.IPv4Any, + }) + return hdr.View() + }, + checker: func(t *testing.T, pkt *stack.PacketBuffer, src tcpip.Address) { + if src == header.IPv4Any { + src = localIPv4Addr + } + + netHdr := pkt.NetworkHeader() + + if len(netHdr.View()) != header.IPv4MinimumSize { + t.Errorf("got len(netHdr.View()) = %d, want = %d", len(netHdr.View()), header.IPv4MinimumSize) + } + + checker.IPv4(t, stack.PayloadSince(netHdr), + checker.SrcAddr(src), + checker.DstAddr(remoteIPv4Addr), + checker.IPv4HeaderLength(header.IPv4MinimumSize), + checker.IPFullLength(uint16(header.IPv4MinimumSize+len(data))), + checker.IPPayload(data), + ) + }, + }, + { + name: "IPv4 with IHL too small", + protoFactory: ipv4.NewProtocol, + protoNum: ipv4.ProtocolNumber, + nicAddr: localIPv4Addr, + remoteAddr: remoteIPv4Addr, + pktGen: func(t *testing.T, src tcpip.Address) buffer.View { + totalLen := header.IPv4MinimumSize + len(data) + hdr := buffer.NewPrependable(totalLen) + if n := copy(hdr.Prepend(len(data)), data); n != len(data) { + t.Fatalf("copied %d bytes, expected %d bytes", n, len(data)) + } + ip := header.IPv4(hdr.Prepend(header.IPv4MinimumSize)) + ip.Encode(&header.IPv4Fields{ + IHL: header.IPv4MinimumSize - 1, + Protocol: transportProto, + TTL: ipv4.DefaultTTL, + SrcAddr: src, + DstAddr: header.IPv4Any, + }) + return hdr.View() + }, + expectedErr: tcpip.ErrMalformedHeader, + }, + { + name: "IPv4 too small", + protoFactory: ipv4.NewProtocol, + protoNum: ipv4.ProtocolNumber, + nicAddr: localIPv4Addr, + remoteAddr: remoteIPv4Addr, + pktGen: func(t *testing.T, src tcpip.Address) buffer.View { + ip := header.IPv4(make([]byte, header.IPv4MinimumSize)) + ip.Encode(&header.IPv4Fields{ + IHL: header.IPv4MinimumSize, + Protocol: transportProto, + TTL: ipv4.DefaultTTL, + SrcAddr: src, + DstAddr: header.IPv4Any, + }) + return buffer.View(ip[:len(ip)-1]) + }, + expectedErr: tcpip.ErrMalformedHeader, + }, + { + name: "IPv4 minimum size", + protoFactory: ipv4.NewProtocol, + protoNum: ipv4.ProtocolNumber, + nicAddr: localIPv4Addr, + remoteAddr: remoteIPv4Addr, + pktGen: func(t *testing.T, src tcpip.Address) buffer.View { + ip := header.IPv4(make([]byte, header.IPv4MinimumSize)) + ip.Encode(&header.IPv4Fields{ + IHL: header.IPv4MinimumSize, + Protocol: transportProto, + TTL: ipv4.DefaultTTL, + SrcAddr: src, + DstAddr: header.IPv4Any, + }) + return buffer.View(ip) + }, + checker: func(t *testing.T, pkt *stack.PacketBuffer, src tcpip.Address) { + if src == header.IPv4Any { + src = localIPv4Addr + } + + netHdr := pkt.NetworkHeader() + + if len(netHdr.View()) != header.IPv4MinimumSize { + t.Errorf("got len(netHdr.View()) = %d, want = %d", len(netHdr.View()), header.IPv4MinimumSize) + } + + checker.IPv4(t, stack.PayloadSince(netHdr), + checker.SrcAddr(src), + checker.DstAddr(remoteIPv4Addr), + checker.IPv4HeaderLength(header.IPv4MinimumSize), + checker.IPFullLength(header.IPv4MinimumSize), + checker.IPPayload(nil), + ) + }, + }, + { + name: "IPv4 with options", + protoFactory: ipv4.NewProtocol, + protoNum: ipv4.ProtocolNumber, + nicAddr: localIPv4Addr, + remoteAddr: remoteIPv4Addr, + pktGen: func(t *testing.T, src tcpip.Address) buffer.View { + ipHdrLen := header.IPv4MinimumSize + len(ipv4Options) + totalLen := ipHdrLen + len(data) + hdr := buffer.NewPrependable(totalLen) + if n := copy(hdr.Prepend(len(data)), data); n != len(data) { + t.Fatalf("copied %d bytes, expected %d bytes", n, len(data)) + } + ip := header.IPv4(hdr.Prepend(ipHdrLen)) + ip.Encode(&header.IPv4Fields{ + IHL: uint8(ipHdrLen), + Protocol: transportProto, + TTL: ipv4.DefaultTTL, + SrcAddr: src, + DstAddr: header.IPv4Any, + }) + if n := copy(ip.Options(), ipv4Options); n != len(ipv4Options) { + t.Fatalf("copied %d bytes, expected %d bytes", n, len(ipv4Options)) + } + return hdr.View() + }, + checker: func(t *testing.T, pkt *stack.PacketBuffer, src tcpip.Address) { + if src == header.IPv4Any { + src = localIPv4Addr + } + + netHdr := pkt.NetworkHeader() + + hdrLen := header.IPv4MinimumSize + len(ipv4Options) + if len(netHdr.View()) != hdrLen { + t.Errorf("got len(netHdr.View()) = %d, want = %d", len(netHdr.View()), hdrLen) + } + + checker.IPv4(t, stack.PayloadSince(netHdr), + checker.SrcAddr(src), + checker.DstAddr(remoteIPv4Addr), + checker.IPv4HeaderLength(hdrLen), + checker.IPFullLength(uint16(hdrLen+len(data))), + checker.IPv4Options(ipv4Options), + checker.IPPayload(data), + ) + }, + }, + { + name: "IPv6", + protoFactory: ipv6.NewProtocol, + protoNum: ipv6.ProtocolNumber, + nicAddr: localIPv6Addr, + remoteAddr: remoteIPv6Addr, + pktGen: func(t *testing.T, src tcpip.Address) buffer.View { + totalLen := header.IPv6MinimumSize + len(data) + hdr := buffer.NewPrependable(totalLen) + if n := copy(hdr.Prepend(len(data)), data); n != len(data) { + t.Fatalf("copied %d bytes, expected %d bytes", n, len(data)) + } + ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) + ip.Encode(&header.IPv6Fields{ + NextHeader: transportProto, + HopLimit: ipv6.DefaultTTL, + SrcAddr: src, + DstAddr: header.IPv4Any, + }) + return hdr.View() + }, + checker: func(t *testing.T, pkt *stack.PacketBuffer, src tcpip.Address) { + if src == header.IPv6Any { + src = localIPv6Addr + } + + netHdr := pkt.NetworkHeader() + + if len(netHdr.View()) != header.IPv6MinimumSize { + t.Errorf("got len(netHdr.View()) = %d, want = %d", len(netHdr.View()), header.IPv6MinimumSize) + } + + checker.IPv6(t, stack.PayloadSince(netHdr), + checker.SrcAddr(src), + checker.DstAddr(remoteIPv6Addr), + checker.IPFullLength(uint16(header.IPv6MinimumSize+len(data))), + checker.IPPayload(data), + ) + }, + }, + { + name: "IPv6 with extension header", + protoFactory: ipv6.NewProtocol, + protoNum: ipv6.ProtocolNumber, + nicAddr: localIPv6Addr, + remoteAddr: remoteIPv6Addr, + pktGen: func(t *testing.T, src tcpip.Address) buffer.View { + totalLen := header.IPv6MinimumSize + len(ipv6FragmentExtHdr) + len(data) + hdr := buffer.NewPrependable(totalLen) + if n := copy(hdr.Prepend(len(data)), data); n != len(data) { + t.Fatalf("copied %d bytes, expected %d bytes", n, len(data)) + } + if n := copy(hdr.Prepend(len(ipv6FragmentExtHdr)), ipv6FragmentExtHdr); n != len(ipv6FragmentExtHdr) { + t.Fatalf("copied %d bytes, expected %d bytes", n, len(ipv6FragmentExtHdr)) + } + ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) + ip.Encode(&header.IPv6Fields{ + NextHeader: uint8(header.IPv6FragmentExtHdrIdentifier), + HopLimit: ipv6.DefaultTTL, + SrcAddr: src, + DstAddr: header.IPv4Any, + }) + return hdr.View() + }, + checker: func(t *testing.T, pkt *stack.PacketBuffer, src tcpip.Address) { + if src == header.IPv6Any { + src = localIPv6Addr + } + + netHdr := pkt.NetworkHeader() + + if want := header.IPv6MinimumSize + len(ipv6FragmentExtHdr); len(netHdr.View()) != want { + t.Errorf("got len(netHdr.View()) = %d, want = %d", len(netHdr.View()), want) + } + + checker.IPv6(t, stack.PayloadSince(netHdr), + checker.SrcAddr(src), + checker.DstAddr(remoteIPv6Addr), + checker.IPFullLength(uint16(header.IPv6MinimumSize+len(ipv6PayloadWithExtHdr))), + checker.IPPayload(ipv6PayloadWithExtHdr), + ) + }, + }, + { + name: "IPv6 minimum size", + protoFactory: ipv6.NewProtocol, + protoNum: ipv6.ProtocolNumber, + nicAddr: localIPv6Addr, + remoteAddr: remoteIPv6Addr, + pktGen: func(t *testing.T, src tcpip.Address) buffer.View { + ip := header.IPv6(make([]byte, header.IPv6MinimumSize)) + ip.Encode(&header.IPv6Fields{ + NextHeader: transportProto, + HopLimit: ipv6.DefaultTTL, + SrcAddr: src, + DstAddr: header.IPv4Any, + }) + return buffer.View(ip) + }, + checker: func(t *testing.T, pkt *stack.PacketBuffer, src tcpip.Address) { + if src == header.IPv6Any { + src = localIPv6Addr + } + + netHdr := pkt.NetworkHeader() + + if len(netHdr.View()) != header.IPv6MinimumSize { + t.Errorf("got len(netHdr.View()) = %d, want = %d", len(netHdr.View()), header.IPv6MinimumSize) + } + + checker.IPv6(t, stack.PayloadSince(netHdr), + checker.SrcAddr(src), + checker.DstAddr(remoteIPv6Addr), + checker.IPFullLength(header.IPv6MinimumSize), + checker.IPPayload(nil), + ) + }, + }, + { + name: "IPv6 too small", + protoFactory: ipv6.NewProtocol, + protoNum: ipv6.ProtocolNumber, + nicAddr: localIPv6Addr, + remoteAddr: remoteIPv6Addr, + pktGen: func(t *testing.T, src tcpip.Address) buffer.View { + ip := header.IPv6(make([]byte, header.IPv6MinimumSize)) + ip.Encode(&header.IPv6Fields{ + NextHeader: transportProto, + HopLimit: ipv6.DefaultTTL, + SrcAddr: src, + DstAddr: header.IPv4Any, + }) + return buffer.View(ip[:len(ip)-1]) + }, + expectedErr: tcpip.ErrMalformedHeader, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + subTests := []struct { + name string + srcAddr tcpip.Address + }{ + { + name: "unspecified source", + srcAddr: tcpip.Address(strings.Repeat("\x00", len(test.nicAddr))), + }, + { + name: "random source", + srcAddr: tcpip.Address(strings.Repeat("\xab", len(test.nicAddr))), + }, + } + + for _, subTest := range subTests { + t.Run(subTest.name, func(t *testing.T) { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{test.protoFactory}, + }) + e := channel.New(1, 1280, "") + if err := s.CreateNIC(nicID, e); err != nil { + t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err) + } + if err := s.AddAddress(nicID, test.protoNum, test.nicAddr); err != nil { + t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID, test.protoNum, test.nicAddr, err) + } + + s.SetRouteTable([]tcpip.Route{{Destination: test.remoteAddr.WithPrefix().Subnet(), NIC: nicID}}) + + r, err := s.FindRoute(nicID, test.nicAddr, test.remoteAddr, test.protoNum, false /* multicastLoop */) + if err != nil { + t.Fatalf("s.FindRoute(%d, %s, %s, %d, false): %s", nicID, test.remoteAddr, test.nicAddr, test.protoNum, err) + } + defer r.Release() + + if err := r.WriteHeaderIncludedPacket(stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: test.pktGen(t, subTest.srcAddr).ToVectorisedView(), + })); err != test.expectedErr { + t.Fatalf("got r.WriteHeaderIncludedPacket(_) = %s, want = %s", err, test.expectedErr) + } + + if test.expectedErr != nil { + return + } + + pkt, ok := e.Read() + if !ok { + t.Fatal("expected a packet to be written") + } + test.checker(t, pkt.Pkt, subTest.srcAddr) + }) + } + }) + } +} diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go index c5ac7b8b5..e7c58ae0a 100644 --- a/pkg/tcpip/network/ipv4/ipv4.go +++ b/pkg/tcpip/network/ipv4/ipv4.go @@ -190,29 +190,6 @@ func (e *endpoint) NetworkProtocolNumber() tcpip.NetworkProtocolNumber { return e.protocol.Number() } -// writePacketFragments fragments pkt and writes the results on the link -// endpoint. The IP header must already present in the original packet. The mtu -// is the maximum size of the packets. -func (e *endpoint) writePacketFragments(r *stack.Route, gso *stack.GSO, mtu uint32, pkt *stack.PacketBuffer) *tcpip.Error { - networkHeader := header.IPv4(pkt.NetworkHeader().View()) - fragMTU := int(calculateFragmentInnerMTU(mtu, pkt)) - pf := fragmentation.MakePacketFragmenter(pkt, fragMTU, pkt.AvailableHeaderBytes()+len(networkHeader)) - - for { - fragPkt, more := buildNextFragment(&pf, networkHeader) - if err := e.nic.WritePacket(r, gso, ProtocolNumber, fragPkt); err != nil { - r.Stats().IP.OutgoingPacketErrors.IncrementBy(uint64(pf.RemainingFragmentCount() + 1)) - return err - } - r.Stats().IP.PacketsSent.Increment() - if !more { - break - } - } - - return nil -} - func (e *endpoint) addIPHeader(r *stack.Route, pkt *stack.PacketBuffer, params stack.NetworkHeaderParams) { ip := header.IPv4(pkt.NetworkHeader().Push(header.IPv4MinimumSize)) length := uint16(pkt.Size()) @@ -234,10 +211,39 @@ func (e *endpoint) addIPHeader(r *stack.Route, pkt *stack.PacketBuffer, params s pkt.NetworkProtocolNumber = ProtocolNumber } +func (e *endpoint) packetMustBeFragmented(pkt *stack.PacketBuffer, gso *stack.GSO) bool { + return (gso == nil || gso.Type == stack.GSONone) && pkt.Size() > int(e.nic.MTU()) +} + +// handleFragments fragments pkt and calls the handler function on each +// fragment. It returns the number of fragments handled and the number of +// fragments left to be processed. The IP header must already be present in the +// original packet. The mtu is the maximum size of the packets. +func (e *endpoint) handleFragments(r *stack.Route, gso *stack.GSO, mtu uint32, pkt *stack.PacketBuffer, handler func(*stack.PacketBuffer) *tcpip.Error) (int, int, *tcpip.Error) { + fragMTU := int(calculateFragmentInnerMTU(mtu, pkt)) + networkHeader := header.IPv4(pkt.NetworkHeader().View()) + pf := fragmentation.MakePacketFragmenter(pkt, fragMTU, pkt.AvailableHeaderBytes()+len(networkHeader)) + + var n int + for { + fragPkt, more := buildNextFragment(&pf, networkHeader) + if err := handler(fragPkt); err != nil { + return n, pf.RemainingFragmentCount() + 1, err + } + n++ + if !more { + return n, pf.RemainingFragmentCount(), nil + } + } +} + // WritePacket writes a packet to the given destination address and protocol. func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.NetworkHeaderParams, pkt *stack.PacketBuffer) *tcpip.Error { e.addIPHeader(r, pkt, params) + return e.writePacket(r, gso, pkt) +} +func (e *endpoint) writePacket(r *stack.Route, gso *stack.GSO, pkt *stack.PacketBuffer) *tcpip.Error { // iptables filtering. All packets that reach here are locally // generated. nicName := e.protocol.stack.FindNICNameFromID(e.nic.ID()) @@ -273,8 +279,18 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.Netw if r.Loop&stack.PacketOut == 0 { return nil } - if pkt.Size() > int(e.nic.MTU()) && (gso == nil || gso.Type == stack.GSONone) { - return e.writePacketFragments(r, gso, e.nic.MTU(), pkt) + + if e.packetMustBeFragmented(pkt, gso) { + sent, remain, err := e.handleFragments(r, gso, e.nic.MTU(), pkt, func(fragPkt *stack.PacketBuffer) *tcpip.Error { + // TODO(gvisor.dev/issue/3884): Evaluate whether we want to send each + // fragment one by one using WritePacket() (current strategy) or if we + // want to create a PacketBufferList from the fragments and feed it to + // WritePackets(). It'll be faster but cost more memory. + return e.nic.WritePacket(r, gso, ProtocolNumber, fragPkt) + }) + r.Stats().IP.PacketsSent.IncrementBy(uint64(sent)) + r.Stats().IP.OutgoingPacketErrors.IncrementBy(uint64(remain)) + return err } if err := e.nic.WritePacket(r, gso, ProtocolNumber, pkt); err != nil { r.Stats().IP.OutgoingPacketErrors.Increment() @@ -293,9 +309,23 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe return pkts.Len(), nil } - for pkt := pkts.Front(); pkt != nil; { + for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() { e.addIPHeader(r, pkt, params) - pkt = pkt.Next() + if e.packetMustBeFragmented(pkt, gso) { + // Keep track of the packet that is about to be fragmented so it can be + // removed once the fragmentation is done. + originalPkt := pkt + if _, _, err := e.handleFragments(r, gso, e.nic.MTU(), pkt, func(fragPkt *stack.PacketBuffer) *tcpip.Error { + // Modify the packet list in place with the new fragments. + pkts.InsertAfter(pkt, fragPkt) + pkt = fragPkt + return nil + }); err != nil { + panic(fmt.Sprintf("e.handleFragments(_, _, %d, _, _) = %s", e.nic.MTU(), err)) + } + // Remove the packet that was just fragmented and process the rest. + pkts.Remove(originalPkt) + } } nicName := e.protocol.stack.FindNICNameFromID(e.nic.ID()) @@ -347,30 +377,27 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe return n + len(dropped), nil } -// WriteHeaderIncludedPacket writes a packet already containing a network -// header through the given route. +// WriteHeaderIncludedPacket implements stack.NetworkEndpoint. func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack.PacketBuffer) *tcpip.Error { // The packet already has an IP header, but there are a few required // checks. h, ok := pkt.Data.PullUp(header.IPv4MinimumSize) if !ok { - return tcpip.ErrInvalidOptionValue + return tcpip.ErrMalformedHeader } ip := header.IPv4(h) - if !ip.IsValid(pkt.Data.Size()) { - return tcpip.ErrInvalidOptionValue - } // Always set the total length. - ip.SetTotalLength(uint16(pkt.Data.Size())) + pktSize := pkt.Data.Size() + ip.SetTotalLength(uint16(pktSize)) // Set the source address when zero. - if ip.SourceAddress() == tcpip.Address(([]byte{0, 0, 0, 0})) { + if ip.SourceAddress() == header.IPv4Any { ip.SetSourceAddress(r.LocalAddress) } - // Set the destination. If the packet already included a destination, - // it will be part of the route. + // Set the destination. If the packet already included a destination, it will + // be part of the route anyways. ip.SetDestinationAddress(r.RemoteAddress) // Set the packet ID when zero. @@ -387,19 +414,17 @@ func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack.PacketBu ip.SetChecksum(0) ip.SetChecksum(^ip.CalculateChecksum()) - if r.Loop&stack.PacketLoop != 0 { - e.HandlePacket(r, pkt.Clone()) - } - if r.Loop&stack.PacketOut == 0 { - return nil + // Populate the packet buffer's network header and don't allow an invalid + // packet to be sent. + // + // Note that parsing only makes sure that the packet is well formed as per the + // wire format. We also want to check if the header's fields are valid before + // sending the packet. + if !parse.IPv4(pkt) || !header.IPv4(pkt.NetworkHeader().View()).IsValid(pktSize) { + return tcpip.ErrMalformedHeader } - if err := e.nic.WritePacket(r, nil /* gso */, ProtocolNumber, pkt); err != nil { - r.Stats().IP.OutgoingPacketErrors.Increment() - return err - } - r.Stats().IP.PacketsSent.Increment() - return nil + return e.writePacket(r, nil /* gso */, pkt) } // HandlePacket is called by the link layer when new ipv4 packets arrive for @@ -415,6 +440,32 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { return } + // There has been some confusion regarding verifying checksums. We need + // just look for negative 0 (0xffff) as the checksum, as it's not possible to + // get positive 0 (0) for the checksum. Some bad implementations could get it + // when doing entry replacement in the early days of the Internet, + // however the lore that one needs to check for both persists. + // + // RFC 1624 section 1 describes the source of this confusion as: + // [the partial recalculation method described in RFC 1071] computes a + // result for certain cases that differs from the one obtained from + // scratch (one's complement of one's complement sum of the original + // fields). + // + // However RFC 1624 section 5 clarifies that if using the verification method + // "recommended by RFC 1071, it does not matter if an intermediate system + // generated a -0 instead of +0". + // + // RFC1071 page 1 specifies the verification method as: + // (3) To check a checksum, the 1's complement sum is computed over the + // same set of octets, including the checksum field. If the result + // is all 1 bits (-0 in 1's complement arithmetic), the check + // succeeds. + if h.CalculateChecksum() != 0xffff { + r.Stats().IP.MalformedPacketsReceived.Increment() + return + } + // As per RFC 1122 section 3.2.1.3: // When a host sends any datagram, the IP source address MUST // be one of its own IP addresses (but not a broadcast or diff --git a/pkg/tcpip/network/ipv4/ipv4_test.go b/pkg/tcpip/network/ipv4/ipv4_test.go index 9916d783f..fee11bb38 100644 --- a/pkg/tcpip/network/ipv4/ipv4_test.go +++ b/pkg/tcpip/network/ipv4/ipv4_test.go @@ -15,9 +15,9 @@ package ipv4_test import ( - "bytes" "context" "encoding/hex" + "fmt" "math" "net" "testing" @@ -39,6 +39,8 @@ import ( "gvisor.dev/gvisor/pkg/waiter" ) +const extraHeaderReserve = 50 + func TestExcludeBroadcast(t *testing.T) { s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol}, @@ -118,6 +120,7 @@ func TestIPv4Sanity(t *testing.T) { tests := []struct { name string headerLength uint8 // value of 0 means "use correct size" + badHeaderChecksum bool maxTotalLength uint16 transportProtocol uint8 TTL uint8 @@ -133,6 +136,14 @@ func TestIPv4Sanity(t *testing.T) { transportProtocol: uint8(header.ICMPv4ProtocolNumber), TTL: ttl, }, + { + name: "bad header checksum", + maxTotalLength: defaultMTU, + transportProtocol: uint8(header.ICMPv4ProtocolNumber), + TTL: ttl, + badHeaderChecksum: true, + shouldFail: true, + }, // The TTL tests check that we are not rejecting an incoming packet // with a zero or one TTL, which has been a point of confusion in the // past as RFC 791 says: "If this field contains the value zero, then the @@ -243,7 +254,7 @@ func TestIPv4Sanity(t *testing.T) { // Default routes for IPv4 so ICMP can find a route to the remote // node when attempting to send the ICMP Echo Reply. s.SetRouteTable([]tcpip.Route{ - tcpip.Route{ + { Destination: header.IPv4EmptySubnet, NIC: nicID, }, @@ -288,6 +299,12 @@ func TestIPv4Sanity(t *testing.T) { if test.headerLength != 0 { ip.SetHeaderLength(test.headerLength) } + ip.SetChecksum(0) + ipHeaderChecksum := ip.CalculateChecksum() + if test.badHeaderChecksum { + ipHeaderChecksum += 42 + } + ip.SetChecksum(^ipHeaderChecksum) requestPkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: hdr.View().ToVectorisedView(), }) @@ -369,11 +386,10 @@ func TestIPv4Sanity(t *testing.T) { // comparePayloads compared the contents of all the packets against the contents // of the source packet. -func compareFragments(t *testing.T, packets []*stack.PacketBuffer, sourcePacketInfo *stack.PacketBuffer, mtu uint32) { - t.Helper() - // Make a complete array of the sourcePacketInfo packet. - source := header.IPv4(packets[0].NetworkHeader().View()[:header.IPv4MinimumSize]) - vv := buffer.NewVectorisedView(sourcePacketInfo.Size(), sourcePacketInfo.Views()) +func compareFragments(packets []*stack.PacketBuffer, sourcePacket *stack.PacketBuffer, mtu uint32, wantFragments []fragmentInfo, proto tcpip.TransportProtocolNumber) error { + // Make a complete array of the sourcePacket packet. + source := header.IPv4(packets[0].NetworkHeader().View()) + vv := buffer.NewVectorisedView(sourcePacket.Size(), sourcePacket.Views()) source = append(source, vv.ToView()...) // Make a copy of the IP header, which will be modified in some fields to make @@ -382,82 +398,147 @@ func compareFragments(t *testing.T, packets []*stack.PacketBuffer, sourcePacketI sourceCopy.SetChecksum(0) sourceCopy.SetFlagsFragmentOffset(0, 0) sourceCopy.SetTotalLength(0) - var offset uint16 // Build up an array of the bytes sent. - var reassembledPayload []byte + var reassembledPayload buffer.VectorisedView for i, packet := range packets { // Confirm that the packet is valid. allBytes := buffer.NewVectorisedView(packet.Size(), packet.Views()) - ip := header.IPv4(allBytes.ToView()) - if !ip.IsValid(len(ip)) { - t.Errorf("IP packet is invalid:\n%s", hex.Dump(ip)) + fragmentIPHeader := header.IPv4(allBytes.ToView()) + if !fragmentIPHeader.IsValid(len(fragmentIPHeader)) { + return fmt.Errorf("fragment #%d: IP packet is invalid:\n%s", i, hex.Dump(fragmentIPHeader)) } - if got, want := ip.CalculateChecksum(), uint16(0xffff); got != want { - t.Errorf("ip.CalculateChecksum() got %#x, want %#x", got, want) + if got := len(fragmentIPHeader); got > int(mtu) { + return fmt.Errorf("fragment #%d: got len(fragmentIPHeader) = %d, want <= %d", i, got, mtu) } - if got, want := len(ip), int(mtu); got > want { - t.Errorf("fragment is too large, got %d want %d", got, want) + if got := fragmentIPHeader.TransportProtocol(); got != proto { + return fmt.Errorf("fragment #%d: got fragmentIPHeader.TransportProtocol() = %d, want = %d", i, got, uint8(proto)) } - if got, want := packet.AvailableHeaderBytes(), sourcePacketInfo.AvailableHeaderBytes()-header.IPv4MinimumSize; got != want { - t.Errorf("fragment #%d should have the same available space for prepending as source: got %d, want %d", i, got, want) + if got := packet.AvailableHeaderBytes(); got != extraHeaderReserve { + return fmt.Errorf("fragment #%d: got packet.AvailableHeaderBytes() = %d, want = %d", i, got, extraHeaderReserve) } - if got, want := packet.NetworkProtocolNumber, sourcePacketInfo.NetworkProtocolNumber; got != want { - t.Errorf("fragment #%d has wrong network protocol number: got %d, want %d", i, got, want) + if got, want := packet.NetworkProtocolNumber, sourcePacket.NetworkProtocolNumber; got != want { + return fmt.Errorf("fragment #%d: got fragment.NetworkProtocolNumber = %d, want = %d", i, got, want) } - if i < len(packets)-1 { - sourceCopy.SetFlagsFragmentOffset(sourceCopy.Flags()|header.IPv4FlagMoreFragments, offset) + if got, want := fragmentIPHeader.CalculateChecksum(), uint16(0xffff); got != want { + return fmt.Errorf("fragment #%d: got ip.CalculateChecksum() = %#x, want = %#x", i, got, want) + } + if wantFragments[i].more { + sourceCopy.SetFlagsFragmentOffset(sourceCopy.Flags()|header.IPv4FlagMoreFragments, wantFragments[i].offset) } else { - sourceCopy.SetFlagsFragmentOffset(sourceCopy.Flags()&^header.IPv4FlagMoreFragments, offset) + sourceCopy.SetFlagsFragmentOffset(sourceCopy.Flags()&^header.IPv4FlagMoreFragments, wantFragments[i].offset) } - reassembledPayload = append(reassembledPayload, ip.Payload()...) - offset += ip.TotalLength() - uint16(ip.HeaderLength()) + reassembledPayload.AppendView(packet.TransportHeader().View()) + reassembledPayload.Append(packet.Data) // Clear out the checksum and length from the ip because we can't compare // it. - sourceCopy.SetTotalLength(uint16(len(ip))) + sourceCopy.SetTotalLength(wantFragments[i].payloadSize + header.IPv4MinimumSize) sourceCopy.SetChecksum(0) sourceCopy.SetChecksum(^sourceCopy.CalculateChecksum()) - if !bytes.Equal(ip[:ip.HeaderLength()], sourceCopy[:sourceCopy.HeaderLength()]) { - t.Errorf("ip[:ip.HeaderLength()] got:\n%s\nwant:\n%s", hex.Dump(ip[:ip.HeaderLength()]), hex.Dump(sourceCopy[:sourceCopy.HeaderLength()])) + if diff := cmp.Diff(fragmentIPHeader[:fragmentIPHeader.HeaderLength()], sourceCopy[:sourceCopy.HeaderLength()]); diff != "" { + return fmt.Errorf("fragment #%d: fragmentIPHeader mismatch (-want +got):\n%s", i, diff) } } - expected := source[source.HeaderLength():] - if !bytes.Equal(reassembledPayload, expected) { - t.Errorf("reassembledPayload got:\n%s\nwant:\n%s", hex.Dump(reassembledPayload), hex.Dump(expected)) + + expected := buffer.View(source[source.HeaderLength():]) + if diff := cmp.Diff(expected, reassembledPayload.ToView()); diff != "" { + return fmt.Errorf("reassembledPayload mismatch (-want +got):\n%s", diff) } + + return nil } -func TestFragmentation(t *testing.T) { - const ttl = 42 +type fragmentInfo struct { + offset uint16 + more bool + payloadSize uint16 +} - var manyPayloadViewsSizes [1000]int - for i := range manyPayloadViewsSizes { - manyPayloadViewsSizes[i] = 7 - } - fragTests := []struct { - description string - mtu uint32 - gso *stack.GSO - transportHeaderLength int - extraHeaderReserveLength int - payloadViewsSizes []int - expectedFrags int - }{ - {"No fragmentation", 2000, &stack.GSO{}, 0, header.IPv4MinimumSize, []int{1000}, 1}, - {"No fragmentation with big header", 2000, &stack.GSO{}, 16, header.IPv4MinimumSize, []int{1000}, 1}, - {"Fragmented", 800, &stack.GSO{}, 0, header.IPv4MinimumSize, []int{1000}, 2}, - {"Fragmented with gso nil", 800, nil, 0, header.IPv4MinimumSize, []int{1000}, 2}, - {"Fragmented with many views", 300, &stack.GSO{}, 0, header.IPv4MinimumSize, manyPayloadViewsSizes[:], 25}, - {"Fragmented with many views and prependable bytes", 300, &stack.GSO{}, 0, header.IPv4MinimumSize + 55, manyPayloadViewsSizes[:], 25}, - {"Fragmented with big header", 800, &stack.GSO{}, 20, header.IPv4MinimumSize, []int{1000}, 2}, - {"Fragmented with big header and prependable bytes", 800, &stack.GSO{}, 20, header.IPv4MinimumSize + 66, []int{1000}, 2}, - {"Fragmented with MTU smaller than header and prependable bytes", 300, &stack.GSO{}, 1000, header.IPv4MinimumSize + 77, []int{500}, 6}, - } +var fragmentationTests = []struct { + description string + mtu uint32 + gso *stack.GSO + transportHeaderLength int + payloadSize int + wantFragments []fragmentInfo +}{ + { + description: "No Fragmentation", + mtu: 1280, + gso: nil, + transportHeaderLength: 0, + payloadSize: 1000, + wantFragments: []fragmentInfo{ + {offset: 0, payloadSize: 1000, more: false}, + }, + }, + { + description: "Fragmented", + mtu: 1280, + gso: nil, + transportHeaderLength: 0, + payloadSize: 2000, + wantFragments: []fragmentInfo{ + {offset: 0, payloadSize: 1256, more: true}, + {offset: 1256, payloadSize: 744, more: false}, + }, + }, + { + description: "No fragmentation with big header", + mtu: 2000, + gso: nil, + transportHeaderLength: 100, + payloadSize: 1000, + wantFragments: []fragmentInfo{ + {offset: 0, payloadSize: 1100, more: false}, + }, + }, + { + description: "Fragmented with gso none", + mtu: 1280, + gso: &stack.GSO{Type: stack.GSONone}, + transportHeaderLength: 0, + payloadSize: 1400, + wantFragments: []fragmentInfo{ + {offset: 0, payloadSize: 1256, more: true}, + {offset: 1256, payloadSize: 144, more: false}, + }, + }, + { + description: "Fragmented with big header", + mtu: 1280, + gso: nil, + transportHeaderLength: 100, + payloadSize: 1200, + wantFragments: []fragmentInfo{ + {offset: 0, payloadSize: 1256, more: true}, + {offset: 1256, payloadSize: 44, more: false}, + }, + }, + { + description: "Fragmented with MTU smaller than header", + mtu: 300, + gso: nil, + transportHeaderLength: 1000, + payloadSize: 500, + wantFragments: []fragmentInfo{ + {offset: 0, payloadSize: 280, more: true}, + {offset: 280, payloadSize: 280, more: true}, + {offset: 560, payloadSize: 280, more: true}, + {offset: 840, payloadSize: 280, more: true}, + {offset: 1120, payloadSize: 280, more: true}, + {offset: 1400, payloadSize: 100, more: false}, + }, + }, +} - for _, ft := range fragTests { +func TestFragmentationWritePacket(t *testing.T) { + const ttl = 42 + + for _, ft := range fragmentationTests { t.Run(ft.description, func(t *testing.T) { ep := testutil.NewMockLinkEndpoint(ft.mtu, nil, math.MaxInt32) r := buildRoute(t, ep) - pkt := testutil.MakeRandPkt(ft.transportHeaderLength, ft.extraHeaderReserveLength, ft.payloadViewsSizes, header.IPv4ProtocolNumber) + pkt := testutil.MakeRandPkt(ft.transportHeaderLength, extraHeaderReserve+header.IPv4MinimumSize, []int{ft.payloadSize}, header.IPv4ProtocolNumber) source := pkt.Clone() err := r.WritePacket(ft.gso, stack.NetworkHeaderParams{ Protocol: tcp.ProtocolNumber, @@ -467,17 +548,101 @@ func TestFragmentation(t *testing.T) { if err != nil { t.Fatalf("r.WritePacket(_, _, _) = %s", err) } - - if got := len(ep.WrittenPackets); got != ft.expectedFrags { - t.Errorf("got len(ep.WrittenPackets) = %d, want = %d", got, ft.expectedFrags) + if got := len(ep.WrittenPackets); got != len(ft.wantFragments) { + t.Errorf("got len(ep.WrittenPackets) = %d, want = %d", got, len(ft.wantFragments)) } - if got, want := len(ep.WrittenPackets), int(r.Stats().IP.PacketsSent.Value()); got != want { - t.Errorf("no errors yet got len(ep.WrittenPackets) = %d, want = %d", got, want) + if got := int(r.Stats().IP.PacketsSent.Value()); got != len(ft.wantFragments) { + t.Errorf("got c.Route.Stats().IP.PacketsSent.Value() = %d, want = %d", got, len(ft.wantFragments)) } if got := r.Stats().IP.OutgoingPacketErrors.Value(); got != 0 { t.Errorf("got r.Stats().IP.OutgoingPacketErrors.Value() = %d, want = 0", got) } - compareFragments(t, ep.WrittenPackets, source, ft.mtu) + if err := compareFragments(ep.WrittenPackets, source, ft.mtu, ft.wantFragments, tcp.ProtocolNumber); err != nil { + t.Error(err) + } + }) + } +} + +func TestFragmentationWritePackets(t *testing.T) { + const ttl = 42 + writePacketsTests := []struct { + description string + insertBefore int + insertAfter int + }{ + { + description: "Single packet", + insertBefore: 0, + insertAfter: 0, + }, + { + description: "With packet before", + insertBefore: 1, + insertAfter: 0, + }, + { + description: "With packet after", + insertBefore: 0, + insertAfter: 1, + }, + { + description: "With packet before and after", + insertBefore: 1, + insertAfter: 1, + }, + } + tinyPacket := testutil.MakeRandPkt(header.TCPMinimumSize, extraHeaderReserve+header.IPv4MinimumSize, []int{1}, header.IPv4ProtocolNumber) + + for _, test := range writePacketsTests { + t.Run(test.description, func(t *testing.T) { + for _, ft := range fragmentationTests { + t.Run(ft.description, func(t *testing.T) { + var pkts stack.PacketBufferList + for i := 0; i < test.insertBefore; i++ { + pkts.PushBack(tinyPacket.Clone()) + } + pkt := testutil.MakeRandPkt(ft.transportHeaderLength, extraHeaderReserve+header.IPv4MinimumSize, []int{ft.payloadSize}, header.IPv4ProtocolNumber) + pkts.PushBack(pkt.Clone()) + for i := 0; i < test.insertAfter; i++ { + pkts.PushBack(tinyPacket.Clone()) + } + + ep := testutil.NewMockLinkEndpoint(ft.mtu, nil, math.MaxInt32) + r := buildRoute(t, ep) + + wantTotalPackets := len(ft.wantFragments) + test.insertBefore + test.insertAfter + n, err := r.WritePackets(ft.gso, pkts, stack.NetworkHeaderParams{ + Protocol: tcp.ProtocolNumber, + TTL: ttl, + TOS: stack.DefaultTOS, + }) + if err != nil { + t.Errorf("got WritePackets(_, _, _) = (_, %s), want = (_, nil)", err) + } + if n != wantTotalPackets { + t.Errorf("got WritePackets(_, _, _) = (%d, _), want = (%d, _)", n, wantTotalPackets) + } + if got := len(ep.WrittenPackets); got != wantTotalPackets { + t.Errorf("got len(ep.WrittenPackets) = %d, want = %d", got, wantTotalPackets) + } + if got := int(r.Stats().IP.PacketsSent.Value()); got != wantTotalPackets { + t.Errorf("got c.Route.Stats().IP.PacketsSent.Value() = %d, want = %d", got, wantTotalPackets) + } + if got := int(r.Stats().IP.OutgoingPacketErrors.Value()); got != 0 { + t.Errorf("got r.Stats().IP.OutgoingPacketErrors.Value() = %d, want = 0", got) + } + + if wantTotalPackets == 0 { + return + } + + fragments := ep.WrittenPackets[test.insertBefore : len(ft.wantFragments)+test.insertBefore] + if err := compareFragments(fragments, pkt, ft.mtu, ft.wantFragments, tcp.ProtocolNumber); err != nil { + t.Error(err) + } + }) + } }) } } @@ -534,14 +699,14 @@ func TestFragmentationErrors(t *testing.T) { t.Run(ft.description, func(t *testing.T) { ep := testutil.NewMockLinkEndpoint(ft.mtu, expectedError, ft.allowPackets) r := buildRoute(t, ep) - pkt := testutil.MakeRandPkt(ft.transportHeaderLength, header.IPv4MinimumSize, []int{ft.payloadSize}, header.IPv4ProtocolNumber) + pkt := testutil.MakeRandPkt(ft.transportHeaderLength, extraHeaderReserve+header.IPv4MinimumSize, []int{ft.payloadSize}, header.IPv4ProtocolNumber) err := r.WritePacket(&stack.GSO{}, stack.NetworkHeaderParams{ Protocol: tcp.ProtocolNumber, TTL: ttl, TOS: stack.DefaultTOS, }, pkt) if err != expectedError { - t.Errorf("got WritePacket() = %s, want = %s", err, expectedError) + t.Errorf("got WritePacket(_, _, _) = %s, want = %s", err, expectedError) } if got, want := len(ep.WrittenPackets), int(r.Stats().IP.PacketsSent.Value()); err != nil && got != want { t.Errorf("got len(ep.WrittenPackets) = %d, want = %d", got, want) @@ -1277,6 +1442,7 @@ func TestReceiveFragments(t *testing.T) { SrcAddr: frag.srcAddr, DstAddr: frag.dstAddr, }) + ip.SetChecksum(^ip.CalculateChecksum()) vv := hdr.View().ToVectorisedView() vv.AppendView(frag.payload) @@ -1545,6 +1711,7 @@ func TestPacketQueing(t *testing.T) { SrcAddr: host2IPv4Addr.AddressWithPrefix.Address, DstAddr: host1IPv4Addr.AddressWithPrefix.Address, }) + ip.SetChecksum(^ip.CalculateChecksum()) e.InjectInbound(ipv4.ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: hdr.View().ToVectorisedView(), })) @@ -1588,6 +1755,7 @@ func TestPacketQueing(t *testing.T) { SrcAddr: host2IPv4Addr.AddressWithPrefix.Address, DstAddr: host1IPv4Addr.AddressWithPrefix.Address, }) + ip.SetChecksum(^ip.CalculateChecksum()) e.InjectInbound(header.IPv4ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: hdr.View().ToVectorisedView(), })) @@ -1633,7 +1801,7 @@ func TestPacketQueing(t *testing.T) { } s.SetRouteTable([]tcpip.Route{ - tcpip.Route{ + { Destination: host1IPv4Addr.AddressWithPrefix.Subnet(), NIC: nicID, }, diff --git a/pkg/tcpip/network/ipv6/icmp.go b/pkg/tcpip/network/ipv6/icmp.go index 913b2140c..ead6bedcb 100644 --- a/pkg/tcpip/network/ipv6/icmp.go +++ b/pkg/tcpip/network/ipv6/icmp.go @@ -439,19 +439,19 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme // If the NA message has the target link layer option, update the link // address cache with the link address for the target of the message. - if len(targetLinkAddr) != 0 { - if e.nud == nil { + if e.nud == nil { + if len(targetLinkAddr) != 0 { e.linkAddrCache.AddLinkAddress(e.nic.ID(), targetAddr, targetLinkAddr) - return } - - e.nud.HandleConfirmation(targetAddr, targetLinkAddr, stack.ReachabilityConfirmationFlags{ - Solicited: na.SolicitedFlag(), - Override: na.OverrideFlag(), - IsRouter: na.RouterFlag(), - }) + return } + e.nud.HandleConfirmation(targetAddr, targetLinkAddr, stack.ReachabilityConfirmationFlags{ + Solicited: na.SolicitedFlag(), + Override: na.OverrideFlag(), + IsRouter: na.RouterFlag(), + }) + case header.ICMPv6EchoRequest: received.EchoRequest.Increment() icmpHdr, ok := pkt.TransportHeader().Consume(header.ICMPv6EchoMinimumSize) diff --git a/pkg/tcpip/network/ipv6/icmp_test.go b/pkg/tcpip/network/ipv6/icmp_test.go index 3affcc4e4..8dc33c560 100644 --- a/pkg/tcpip/network/ipv6/icmp_test.go +++ b/pkg/tcpip/network/ipv6/icmp_test.go @@ -101,14 +101,19 @@ func (*stubLinkAddressCache) CheckLocalAddress(tcpip.NICID, tcpip.NetworkProtoco func (*stubLinkAddressCache) AddLinkAddress(tcpip.NICID, tcpip.Address, tcpip.LinkAddress) { } -type stubNUDHandler struct{} +type stubNUDHandler struct { + probeCount int + confirmationCount int +} var _ stack.NUDHandler = (*stubNUDHandler)(nil) -func (*stubNUDHandler) HandleProbe(remoteAddr, localAddr tcpip.Address, protocol tcpip.NetworkProtocolNumber, remoteLinkAddr tcpip.LinkAddress, linkRes stack.LinkAddressResolver) { +func (s *stubNUDHandler) HandleProbe(remoteAddr, localAddr tcpip.Address, protocol tcpip.NetworkProtocolNumber, remoteLinkAddr tcpip.LinkAddress, linkRes stack.LinkAddressResolver) { + s.probeCount++ } -func (*stubNUDHandler) HandleConfirmation(addr tcpip.Address, linkAddr tcpip.LinkAddress, flags stack.ReachabilityConfirmationFlags) { +func (s *stubNUDHandler) HandleConfirmation(addr tcpip.Address, linkAddr tcpip.LinkAddress, flags stack.ReachabilityConfirmationFlags) { + s.confirmationCount++ } func (*stubNUDHandler) HandleUpperLevelConfirmation(addr tcpip.Address) { @@ -118,6 +123,12 @@ var _ stack.NetworkInterface = (*testInterface)(nil) type testInterface struct { stack.NetworkLinkEndpoint + + linkAddr tcpip.LinkAddress +} + +func (i *testInterface) LinkAddress() tcpip.LinkAddress { + return i.linkAddr } func (*testInterface) ID() tcpip.NICID { @@ -1492,3 +1503,240 @@ func TestPacketQueing(t *testing.T) { }) } } + +func TestCallsToNeighborCache(t *testing.T) { + tests := []struct { + name string + createPacket func() header.ICMPv6 + multicast bool + source tcpip.Address + destination tcpip.Address + wantProbeCount int + wantConfirmationCount int + }{ + { + name: "Unicast Neighbor Solicitation without source link-layer address option", + createPacket: func() header.ICMPv6 { + nsSize := header.ICMPv6NeighborSolicitMinimumSize + header.NDPLinkLayerAddressSize + icmp := header.ICMPv6(buffer.NewView(nsSize)) + icmp.SetType(header.ICMPv6NeighborSolicit) + ns := header.NDPNeighborSolicit(icmp.NDPPayload()) + ns.SetTargetAddress(lladdr0) + return icmp + }, + source: lladdr1, + destination: lladdr0, + // "The source link-layer address option SHOULD be included in unicast + // solicitations." - RFC 4861 section 4.3 + // + // A Neighbor Advertisement needs to be sent in response, but the + // Neighbor Cache shouldn't be updated since we have no useful + // information about the sender. + wantProbeCount: 0, + }, + { + name: "Unicast Neighbor Solicitation with source link-layer address option", + createPacket: func() header.ICMPv6 { + nsSize := header.ICMPv6NeighborSolicitMinimumSize + header.NDPLinkLayerAddressSize + icmp := header.ICMPv6(buffer.NewView(nsSize)) + icmp.SetType(header.ICMPv6NeighborSolicit) + ns := header.NDPNeighborSolicit(icmp.NDPPayload()) + ns.SetTargetAddress(lladdr0) + ns.Options().Serialize(header.NDPOptionsSerializer{ + header.NDPSourceLinkLayerAddressOption(linkAddr1), + }) + return icmp + }, + source: lladdr1, + destination: lladdr0, + wantProbeCount: 1, + }, + { + name: "Multicast Neighbor Solicitation without source link-layer address option", + createPacket: func() header.ICMPv6 { + nsSize := header.ICMPv6NeighborSolicitMinimumSize + header.NDPLinkLayerAddressSize + icmp := header.ICMPv6(buffer.NewView(nsSize)) + icmp.SetType(header.ICMPv6NeighborSolicit) + ns := header.NDPNeighborSolicit(icmp.NDPPayload()) + ns.SetTargetAddress(lladdr0) + return icmp + }, + source: lladdr1, + destination: header.SolicitedNodeAddr(lladdr0), + // "The source link-layer address option MUST be included in multicast + // solicitations." - RFC 4861 section 4.3 + wantProbeCount: 0, + }, + { + name: "Multicast Neighbor Solicitation with source link-layer address option", + createPacket: func() header.ICMPv6 { + nsSize := header.ICMPv6NeighborSolicitMinimumSize + header.NDPLinkLayerAddressSize + icmp := header.ICMPv6(buffer.NewView(nsSize)) + icmp.SetType(header.ICMPv6NeighborSolicit) + ns := header.NDPNeighborSolicit(icmp.NDPPayload()) + ns.SetTargetAddress(lladdr0) + ns.Options().Serialize(header.NDPOptionsSerializer{ + header.NDPSourceLinkLayerAddressOption(linkAddr1), + }) + return icmp + }, + source: lladdr1, + destination: header.SolicitedNodeAddr(lladdr0), + wantProbeCount: 1, + }, + { + name: "Unicast Neighbor Advertisement without target link-layer address option", + createPacket: func() header.ICMPv6 { + naSize := header.ICMPv6NeighborAdvertMinimumSize + icmp := header.ICMPv6(buffer.NewView(naSize)) + icmp.SetType(header.ICMPv6NeighborAdvert) + na := header.NDPNeighborAdvert(icmp.NDPPayload()) + na.SetSolicitedFlag(true) + na.SetOverrideFlag(false) + na.SetTargetAddress(lladdr1) + return icmp + }, + source: lladdr1, + destination: lladdr0, + // "When responding to unicast solicitations, the target link-layer + // address option can be omitted since the sender of the solicitation has + // the correct link-layer address; otherwise, it would not be able to + // send the unicast solicitation in the first place." + // - RFC 4861 section 4.4 + wantConfirmationCount: 1, + }, + { + name: "Unicast Neighbor Advertisement with target link-layer address option", + createPacket: func() header.ICMPv6 { + naSize := header.ICMPv6NeighborAdvertMinimumSize + header.NDPLinkLayerAddressSize + icmp := header.ICMPv6(buffer.NewView(naSize)) + icmp.SetType(header.ICMPv6NeighborAdvert) + na := header.NDPNeighborAdvert(icmp.NDPPayload()) + na.SetSolicitedFlag(true) + na.SetOverrideFlag(false) + na.SetTargetAddress(lladdr1) + na.Options().Serialize(header.NDPOptionsSerializer{ + header.NDPTargetLinkLayerAddressOption(linkAddr1), + }) + return icmp + }, + source: lladdr1, + destination: lladdr0, + wantConfirmationCount: 1, + }, + { + name: "Multicast Neighbor Advertisement without target link-layer address option", + createPacket: func() header.ICMPv6 { + naSize := header.ICMPv6NeighborAdvertMinimumSize + header.NDPLinkLayerAddressSize + icmp := header.ICMPv6(buffer.NewView(naSize)) + icmp.SetType(header.ICMPv6NeighborAdvert) + na := header.NDPNeighborAdvert(icmp.NDPPayload()) + na.SetSolicitedFlag(false) + na.SetOverrideFlag(false) + na.SetTargetAddress(lladdr1) + return icmp + }, + source: lladdr1, + destination: header.IPv6AllNodesMulticastAddress, + // "Target link-layer address MUST be included for multicast solicitations + // in order to avoid infinite Neighbor Solicitation "recursion" when the + // peer node does not have a cache entry to return a Neighbor + // Advertisements message." - RFC 4861 section 4.4 + wantConfirmationCount: 0, + }, + { + name: "Multicast Neighbor Advertisement with target link-layer address option", + createPacket: func() header.ICMPv6 { + naSize := header.ICMPv6NeighborAdvertMinimumSize + header.NDPLinkLayerAddressSize + icmp := header.ICMPv6(buffer.NewView(naSize)) + icmp.SetType(header.ICMPv6NeighborAdvert) + na := header.NDPNeighborAdvert(icmp.NDPPayload()) + na.SetSolicitedFlag(false) + na.SetOverrideFlag(false) + na.SetTargetAddress(lladdr1) + na.Options().Serialize(header.NDPOptionsSerializer{ + header.NDPTargetLinkLayerAddressOption(linkAddr1), + }) + return icmp + }, + source: lladdr1, + destination: header.IPv6AllNodesMulticastAddress, + wantConfirmationCount: 1, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{icmp.NewProtocol6}, + UseNeighborCache: true, + }) + { + if err := s.CreateNIC(nicID, &stubLinkEndpoint{}); err != nil { + t.Fatalf("CreateNIC(_, _) = %s", err) + } + if err := s.AddAddress(nicID, ProtocolNumber, lladdr0); err != nil { + t.Fatalf("AddAddress(_, %d, %s) = %s", ProtocolNumber, lladdr0, err) + } + } + { + subnet, err := tcpip.NewSubnet(lladdr1, tcpip.AddressMask(strings.Repeat("\xff", len(lladdr1)))) + if err != nil { + t.Fatal(err) + } + s.SetRouteTable( + []tcpip.Route{{ + Destination: subnet, + NIC: nicID, + }}, + ) + } + + netProto := s.NetworkProtocolInstance(ProtocolNumber) + if netProto == nil { + t.Fatalf("cannot find protocol instance for network protocol %d", ProtocolNumber) + } + nudHandler := &stubNUDHandler{} + ep := netProto.NewEndpoint(&testInterface{linkAddr: linkAddr0}, &stubLinkAddressCache{}, nudHandler, &stubDispatcher{}) + defer ep.Close() + + if err := ep.Enable(); err != nil { + t.Fatalf("ep.Enable(): %s", err) + } + + r, err := s.FindRoute(nicID, lladdr0, test.source, ProtocolNumber, false /* multicastLoop */) + if err != nil { + t.Fatalf("FindRoute(%d, %s, %s, _, false) = (_, %s), want = (_, nil)", nicID, lladdr0, lladdr1, err) + } + defer r.Release() + + // TODO(gvisor.dev/issue/4517): Remove the need for this manual patch. + r.LocalAddress = test.destination + + icmp := test.createPacket() + icmp.SetChecksum(header.ICMPv6Checksum(icmp, r.RemoteAddress, r.LocalAddress, buffer.VectorisedView{})) + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: header.IPv6MinimumSize, + Data: buffer.View(icmp).ToVectorisedView(), + }) + ip := header.IPv6(pkt.NetworkHeader().Push(header.IPv6MinimumSize)) + ip.Encode(&header.IPv6Fields{ + PayloadLength: uint16(len(icmp)), + NextHeader: uint8(header.ICMPv6ProtocolNumber), + HopLimit: header.NDPHopLimit, + SrcAddr: r.RemoteAddress, + DstAddr: r.LocalAddress, + }) + ep.HandlePacket(&r, pkt) + + // Confirm the endpoint calls the correct NUDHandler method. + if nudHandler.probeCount != test.wantProbeCount { + t.Errorf("got nudHandler.probeCount = %d, want = %d", nudHandler.probeCount, test.wantProbeCount) + } + if nudHandler.confirmationCount != test.wantConfirmationCount { + t.Errorf("got nudHandler.confirmationCount = %d, want = %d", nudHandler.confirmationCount, test.wantConfirmationCount) + } + }) + } +} diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go index 2bd8f4ece..9670696c7 100644 --- a/pkg/tcpip/network/ipv6/ipv6.go +++ b/pkg/tcpip/network/ipv6/ipv6.go @@ -387,7 +387,7 @@ func (e *endpoint) addIPHeader(r *stack.Route, pkt *stack.PacketBuffer, params s } func (e *endpoint) packetMustBeFragmented(pkt *stack.PacketBuffer, gso *stack.GSO) bool { - return pkt.Size() > int(e.nic.MTU()) && (gso == nil || gso.Type == stack.GSONone) + return (gso == nil || gso.Type == stack.GSONone) && pkt.Size() > int(e.nic.MTU()) } // handleFragments fragments pkt and calls the handler function on each @@ -416,17 +416,18 @@ func (e *endpoint) handleFragments(r *stack.Route, gso *stack.GSO, mtu uint32, p } n++ if !more { - break + return n, pf.RemainingFragmentCount(), nil } } - - return n, 0, nil } // WritePacket writes a packet to the given destination address and protocol. func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.NetworkHeaderParams, pkt *stack.PacketBuffer) *tcpip.Error { e.addIPHeader(r, pkt, params) + return e.writePacket(r, gso, pkt, params.Protocol) +} +func (e *endpoint) writePacket(r *stack.Route, gso *stack.GSO, pkt *stack.PacketBuffer, protocol tcpip.TransportProtocolNumber) *tcpip.Error { // iptables filtering. All packets that reach here are locally // generated. nicName := e.protocol.stack.FindNICNameFromID(e.nic.ID()) @@ -468,7 +469,7 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.Netw } if e.packetMustBeFragmented(pkt, gso) { - sent, remain, err := e.handleFragments(r, gso, e.nic.MTU(), pkt, params.Protocol, func(fragPkt *stack.PacketBuffer) *tcpip.Error { + sent, remain, err := e.handleFragments(r, gso, e.nic.MTU(), pkt, protocol, func(fragPkt *stack.PacketBuffer) *tcpip.Error { // TODO(gvisor.dev/issue/3884): Evaluate whether we want to send each // fragment one by one using WritePacket() (current strategy) or if we // want to create a PacketBufferList from the fragments and feed it to @@ -501,21 +502,20 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe for pb := pkts.Front(); pb != nil; pb = pb.Next() { e.addIPHeader(r, pb, params) if e.packetMustBeFragmented(pb, gso) { - current := pb - _, _, err := e.handleFragments(r, gso, e.nic.MTU(), pb, params.Protocol, func(fragPkt *stack.PacketBuffer) *tcpip.Error { + // Keep track of the packet that is about to be fragmented so it can be + // removed once the fragmentation is done. + originalPkt := pb + if _, _, err := e.handleFragments(r, gso, e.nic.MTU(), pb, params.Protocol, func(fragPkt *stack.PacketBuffer) *tcpip.Error { // Modify the packet list in place with the new fragments. - pkts.InsertAfter(current, fragPkt) - current = current.Next() + pkts.InsertAfter(pb, fragPkt) + pb = fragPkt return nil - }) - if err != nil { + }); err != nil { r.Stats().IP.OutgoingPacketErrors.IncrementBy(uint64(pkts.Len())) return 0, err } - // The fragmented packet can be released. The rest of the packets can be - // processed. - pkts.Remove(pb) - pb = current + // Remove the packet that was just fragmented and process the rest. + pkts.Remove(originalPkt) } } @@ -569,11 +569,40 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe return n + len(dropped), nil } -// WriteHeaderIncludedPacker implements stack.NetworkEndpoint. It is not yet -// supported by IPv6. -func (*endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack.PacketBuffer) *tcpip.Error { - // TODO(b/146666412): Support IPv6 header-included packets. - return tcpip.ErrNotSupported +// WriteHeaderIncludedPacker implements stack.NetworkEndpoint. +func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack.PacketBuffer) *tcpip.Error { + // The packet already has an IP header, but there are a few required checks. + h, ok := pkt.Data.PullUp(header.IPv6MinimumSize) + if !ok { + return tcpip.ErrMalformedHeader + } + ip := header.IPv6(h) + + // Always set the payload length. + pktSize := pkt.Data.Size() + ip.SetPayloadLength(uint16(pktSize - header.IPv6MinimumSize)) + + // Set the source address when zero. + if ip.SourceAddress() == header.IPv6Any { + ip.SetSourceAddress(r.LocalAddress) + } + + // Set the destination. If the packet already included a destination, it will + // be part of the route anyways. + ip.SetDestinationAddress(r.RemoteAddress) + + // Populate the packet buffer's network header and don't allow an invalid + // packet to be sent. + // + // Note that parsing only makes sure that the packet is well formed as per the + // wire format. We also want to check if the header's fields are valid before + // sending the packet. + proto, _, _, _, ok := parse.IPv6(pkt) + if !ok || !header.IPv6(pkt.NetworkHeader().View()).IsValid(pktSize) { + return tcpip.ErrMalformedHeader + } + + return e.writePacket(r, nil /* gso */, pkt, proto) } // HandlePacket is called by the link layer when new ipv6 packets arrive for diff --git a/pkg/tcpip/network/ipv6/ipv6_test.go b/pkg/tcpip/network/ipv6/ipv6_test.go index bee18d1a8..297868f24 100644 --- a/pkg/tcpip/network/ipv6/ipv6_test.go +++ b/pkg/tcpip/network/ipv6/ipv6_test.go @@ -49,6 +49,8 @@ const ( fragmentExtHdrID = uint8(header.IPv6FragmentExtHdrIdentifier) destinationExtHdrID = uint8(header.IPv6DestinationOptionsExtHdrIdentifier) noNextHdrID = uint8(header.IPv6NoNextHeaderIdentifier) + + extraHeaderReserve = 50 ) // testReceiveICMP tests receiving an ICMP packet from src to dst. want is the @@ -181,6 +183,9 @@ func compareFragments(packets []*stack.PacketBuffer, sourcePacket *stack.PacketB return fmt.Errorf("fragment #%d: fragmentIPHeader mismatch (-want +got):\n%s", i, diff) } + if got := fragment.AvailableHeaderBytes(); got != extraHeaderReserve { + return fmt.Errorf("fragment #%d: got packet.AvailableHeaderBytes() = %d, want = %d", i, got, extraHeaderReserve) + } if fragment.NetworkProtocolNumber != sourcePacket.NetworkProtocolNumber { return fmt.Errorf("fragment #%d: got fragment.NetworkProtocolNumber = %d, want = %d", i, fragment.NetworkProtocolNumber, sourcePacket.NetworkProtocolNumber) } @@ -208,8 +213,7 @@ func compareFragments(packets []*stack.PacketBuffer, sourcePacket *stack.PacketB reassembledPayload.Append(fragment.Data) } - result := reassembledPayload.ToView() - if diff := cmp.Diff(result, buffer.View(source[sourceIPHeadersLen:])); diff != "" { + if diff := cmp.Diff(buffer.View(source[sourceIPHeadersLen:]), reassembledPayload.ToView()); diff != "" { return fmt.Errorf("reassembledPayload mismatch (-want +got):\n%s", diff) } @@ -2217,24 +2221,19 @@ type fragmentInfo struct { payloadSize uint16 } -type fragmentationTestCase struct { +var fragmentationTests = []struct { description string mtu uint32 gso *stack.GSO transHdrLen int - extraHdrLen int payloadSize int wantFragments []fragmentInfo - expectedFrags int -} - -var fragmentationTests = []fragmentationTestCase{ +}{ { description: "No Fragmentation", mtu: 1280, - gso: &stack.GSO{}, + gso: nil, transHdrLen: 0, - extraHdrLen: header.IPv6MinimumSize, payloadSize: 1000, wantFragments: []fragmentInfo{ {offset: 0, payloadSize: 1000, more: false}, @@ -2243,9 +2242,8 @@ var fragmentationTests = []fragmentationTestCase{ { description: "Fragmented", mtu: 1280, - gso: &stack.GSO{}, + gso: nil, transHdrLen: 0, - extraHdrLen: header.IPv6MinimumSize, payloadSize: 2000, wantFragments: []fragmentInfo{ {offset: 0, payloadSize: 1240, more: true}, @@ -2255,20 +2253,18 @@ var fragmentationTests = []fragmentationTestCase{ { description: "No fragmentation with big header", mtu: 2000, - gso: &stack.GSO{}, + gso: nil, transHdrLen: 100, - extraHdrLen: header.IPv6MinimumSize, payloadSize: 1000, wantFragments: []fragmentInfo{ {offset: 0, payloadSize: 1100, more: false}, }, }, { - description: "Fragmented with gso nil", + description: "Fragmented with gso none", mtu: 1280, - gso: nil, + gso: &stack.GSO{Type: stack.GSONone}, transHdrLen: 0, - extraHdrLen: header.IPv6MinimumSize, payloadSize: 1400, wantFragments: []fragmentInfo{ {offset: 0, payloadSize: 1240, more: true}, @@ -2278,30 +2274,17 @@ var fragmentationTests = []fragmentationTestCase{ { description: "Fragmented with big header", mtu: 1280, - gso: &stack.GSO{}, + gso: nil, transHdrLen: 100, - extraHdrLen: header.IPv6MinimumSize, payloadSize: 1200, wantFragments: []fragmentInfo{ {offset: 0, payloadSize: 1240, more: true}, {offset: 154, payloadSize: 76, more: false}, }, }, - { - description: "Fragmented with big header and prependable bytes", - mtu: 1280, - gso: &stack.GSO{}, - transHdrLen: 20, - extraHdrLen: header.IPv6MinimumSize + 66, - payloadSize: 1500, - wantFragments: []fragmentInfo{ - {offset: 0, payloadSize: 1240, more: true}, - {offset: 154, payloadSize: 296, more: false}, - }, - }, } -func TestFragmentation(t *testing.T) { +func TestFragmentationWritePacket(t *testing.T) { const ( ttl = 42 tos = stack.DefaultTOS @@ -2310,7 +2293,7 @@ func TestFragmentation(t *testing.T) { for _, ft := range fragmentationTests { t.Run(ft.description, func(t *testing.T) { - pkt := testutil.MakeRandPkt(ft.transHdrLen, ft.extraHdrLen, []int{ft.payloadSize}, header.IPv6ProtocolNumber) + pkt := testutil.MakeRandPkt(ft.transHdrLen, extraHeaderReserve+header.IPv6MinimumSize, []int{ft.payloadSize}, header.IPv6ProtocolNumber) source := pkt.Clone() ep := testutil.NewMockLinkEndpoint(ft.mtu, nil, math.MaxInt32) r := buildRoute(t, ep) @@ -2331,10 +2314,8 @@ func TestFragmentation(t *testing.T) { if got := r.Stats().IP.OutgoingPacketErrors.Value(); got != 0 { t.Errorf("got r.Stats().IP.OutgoingPacketErrors.Value() = %d, want = 0", got) } - if len(ep.WrittenPackets) > 0 { - if err := compareFragments(ep.WrittenPackets, source, ft.mtu, ft.wantFragments, tcp.ProtocolNumber); err != nil { - t.Error(err) - } + if err := compareFragments(ep.WrittenPackets, source, ft.mtu, ft.wantFragments, tcp.ProtocolNumber); err != nil { + t.Error(err) } }) } @@ -2368,7 +2349,7 @@ func TestFragmentationWritePackets(t *testing.T) { insertAfter: 1, }, } - tinyPacket := testutil.MakeRandPkt(header.TCPMinimumSize, header.IPv6MinimumSize, []int{1}, header.IPv6ProtocolNumber) + tinyPacket := testutil.MakeRandPkt(header.TCPMinimumSize, extraHeaderReserve+header.IPv6MinimumSize, []int{1}, header.IPv6ProtocolNumber) for _, test := range tests { t.Run(test.description, func(t *testing.T) { @@ -2378,7 +2359,7 @@ func TestFragmentationWritePackets(t *testing.T) { for i := 0; i < test.insertBefore; i++ { pkts.PushBack(tinyPacket.Clone()) } - pkt := testutil.MakeRandPkt(ft.transHdrLen, ft.extraHdrLen, []int{ft.payloadSize}, header.IPv6ProtocolNumber) + pkt := testutil.MakeRandPkt(ft.transHdrLen, extraHeaderReserve+header.IPv6MinimumSize, []int{ft.payloadSize}, header.IPv6ProtocolNumber) source := pkt pkts.PushBack(pkt.Clone()) for i := 0; i < test.insertAfter; i++ { @@ -2480,7 +2461,7 @@ func TestFragmentationErrors(t *testing.T) { for _, ft := range tests { t.Run(ft.description, func(t *testing.T) { - pkt := testutil.MakeRandPkt(ft.transHdrLen, header.IPv6MinimumSize, []int{ft.payloadSize}, header.IPv6ProtocolNumber) + pkt := testutil.MakeRandPkt(ft.transHdrLen, extraHeaderReserve+header.IPv6MinimumSize, []int{ft.payloadSize}, header.IPv6ProtocolNumber) ep := testutil.NewMockLinkEndpoint(ft.mtu, ft.mockError, ft.allowPackets) r := buildRoute(t, ep) err := r.WritePacket(&stack.GSO{}, stack.NetworkHeaderParams{ diff --git a/pkg/tcpip/stack/neighbor_entry.go b/pkg/tcpip/stack/neighbor_entry.go index 4d69a4de1..be61a21af 100644 --- a/pkg/tcpip/stack/neighbor_entry.go +++ b/pkg/tcpip/stack/neighbor_entry.go @@ -406,9 +406,9 @@ func (e *neighborEntry) handleConfirmationLocked(linkAddr tcpip.LinkAddress, fla // INCOMPLETE state." - RFC 4861 section 7.2.5 case Reachable, Stale, Delay, Probe: - sameLinkAddr := e.neigh.LinkAddr == linkAddr + isLinkAddrDifferent := len(linkAddr) != 0 && e.neigh.LinkAddr != linkAddr - if !sameLinkAddr { + if isLinkAddrDifferent { if !flags.Override { if e.neigh.State == Reachable { e.dispatchChangeEventLocked(Stale) @@ -431,7 +431,7 @@ func (e *neighborEntry) handleConfirmationLocked(linkAddr tcpip.LinkAddress, fla } } - if flags.Solicited && (flags.Override || sameLinkAddr) { + if flags.Solicited && (flags.Override || !isLinkAddrDifferent) { if e.neigh.State != Reachable { e.dispatchChangeEventLocked(Reachable) } diff --git a/pkg/tcpip/stack/neighbor_entry_test.go b/pkg/tcpip/stack/neighbor_entry_test.go index e79abebca..3ee2a3b31 100644 --- a/pkg/tcpip/stack/neighbor_entry_test.go +++ b/pkg/tcpip/stack/neighbor_entry_test.go @@ -83,15 +83,18 @@ func eventDiffOptsWithSort() []cmp.Option { // | Reachable | Stale | Reachable timer expired | | Changed | // | Reachable | Stale | Probe or confirmation w/ different address | | Changed | // | Stale | Reachable | Solicited override confirmation | Update LinkAddr | Changed | +// | Stale | Reachable | Solicited confirmation w/o address | Notify wakers | Changed | // | Stale | Stale | Override confirmation | Update LinkAddr | Changed | // | Stale | Stale | Probe w/ different address | Update LinkAddr | Changed | // | Stale | Delay | Packet sent | | Changed | // | Delay | Reachable | Upper-layer confirmation | | Changed | // | Delay | Reachable | Solicited override confirmation | Update LinkAddr | Changed | +// | Delay | Reachable | Solicited confirmation w/o address | Notify wakers | Changed | // | Delay | Stale | Probe or confirmation w/ different address | | Changed | // | Delay | Probe | Delay timer expired | Send probe | Changed | // | Probe | Reachable | Solicited override confirmation | Update LinkAddr | Changed | // | Probe | Reachable | Solicited confirmation w/ same address | Notify wakers | Changed | +// | Probe | Reachable | Solicited confirmation w/o address | Notify wakers | Changed | // | Probe | Stale | Probe or confirmation w/ different address | | Changed | // | Probe | Probe | Retransmit timer expired | Send probe | Changed | // | Probe | Failed | Max probes sent without reply | Notify wakers | Removed | @@ -1370,6 +1373,77 @@ func TestEntryStaleToReachableWhenSolicitedOverrideConfirmation(t *testing.T) { nudDisp.mu.Unlock() } +func TestEntryStaleToReachableWhenSolicitedConfirmationWithoutAddress(t *testing.T) { + c := DefaultNUDConfigurations() + e, nudDisp, linkRes, _ := entryTestSetup(c) + + e.mu.Lock() + e.handlePacketQueuedLocked() + e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ + Solicited: false, + Override: false, + IsRouter: false, + }) + if e.neigh.State != Stale { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Stale) + } + e.handleConfirmationLocked("" /* linkAddr */, ReachabilityConfirmationFlags{ + Solicited: true, + Override: false, + IsRouter: false, + }) + if e.neigh.State != Reachable { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Reachable) + } + if e.neigh.LinkAddr != entryTestLinkAddr1 { + t.Errorf("got e.neigh.LinkAddr = %q, want = %q", e.neigh.LinkAddr, entryTestLinkAddr1) + } + e.mu.Unlock() + + wantProbes := []entryTestProbeInfo{ + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: tcpip.LinkAddress(""), + LocalAddress: entryTestAddr2, + }, + } + linkRes.mu.Lock() + diff := cmp.Diff(linkRes.probes, wantProbes) + linkRes.mu.Unlock() + if diff != "" { + t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + } + + wantEvents := []testEntryEventInfo{ + { + EventType: entryTestAdded, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Reachable, + }, + } + nudDisp.mu.Lock() + if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + } + nudDisp.mu.Unlock() +} + func TestEntryStaleToStaleWhenOverrideConfirmation(t *testing.T) { c := DefaultNUDConfigurations() e, nudDisp, linkRes, _ := entryTestSetup(c) @@ -1752,6 +1826,100 @@ func TestEntryDelayToReachableWhenSolicitedOverrideConfirmation(t *testing.T) { nudDisp.mu.Unlock() } +func TestEntryDelayToReachableWhenSolicitedConfirmationWithoutAddress(t *testing.T) { + c := DefaultNUDConfigurations() + c.MaxMulticastProbes = 1 + // Eliminate random factors from ReachableTime computation so the transition + // from Stale to Reachable will only take BaseReachableTime duration. + c.MinRandomFactor = 1 + c.MaxRandomFactor = 1 + + e, nudDisp, linkRes, clock := entryTestSetup(c) + + e.mu.Lock() + e.handlePacketQueuedLocked() + e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ + Solicited: false, + Override: false, + IsRouter: false, + }) + e.handlePacketQueuedLocked() + if e.neigh.State != Delay { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Delay) + } + e.handleConfirmationLocked("" /* linkAddr */, ReachabilityConfirmationFlags{ + Solicited: true, + Override: false, + IsRouter: false, + }) + if e.neigh.State != Reachable { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Reachable) + } + if e.neigh.LinkAddr != entryTestLinkAddr1 { + t.Errorf("got e.neigh.LinkAddr = %q, want = %q", e.neigh.LinkAddr, entryTestLinkAddr1) + } + e.mu.Unlock() + + wantProbes := []entryTestProbeInfo{ + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: tcpip.LinkAddress(""), + LocalAddress: entryTestAddr2, + }, + } + linkRes.mu.Lock() + diff := cmp.Diff(linkRes.probes, wantProbes) + linkRes.mu.Unlock() + if diff != "" { + t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + } + + clock.Advance(c.BaseReachableTime) + + wantEvents := []testEntryEventInfo{ + { + EventType: entryTestAdded, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Delay, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Reachable, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + }, + } + nudDisp.mu.Lock() + if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + } + nudDisp.mu.Unlock() +} + func TestEntryStaysDelayWhenOverrideConfirmationWithSameAddress(t *testing.T) { c := DefaultNUDConfigurations() e, nudDisp, linkRes, _ := entryTestSetup(c) @@ -2665,6 +2833,115 @@ func TestEntryProbeToReachableWhenSolicitedConfirmationWithSameAddress(t *testin nudDisp.mu.Unlock() } +func TestEntryProbeToReachableWhenSolicitedConfirmationWithoutAddress(t *testing.T) { + c := DefaultNUDConfigurations() + // Eliminate random factors from ReachableTime computation so the transition + // from Stale to Reachable will only take BaseReachableTime duration. + c.MinRandomFactor = 1 + c.MaxRandomFactor = 1 + + e, nudDisp, linkRes, clock := entryTestSetup(c) + + e.mu.Lock() + e.handlePacketQueuedLocked() + e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ + Solicited: false, + Override: false, + IsRouter: false, + }) + e.handlePacketQueuedLocked() + e.mu.Unlock() + + clock.Advance(c.DelayFirstProbeTime) + + wantProbes := []entryTestProbeInfo{ + // The first probe is caused by the Unknown-to-Incomplete transition. + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: tcpip.LinkAddress(""), + LocalAddress: entryTestAddr2, + }, + // The second probe is caused by the Delay-to-Probe transition. + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: entryTestLinkAddr1, + LocalAddress: entryTestAddr2, + }, + } + linkRes.mu.Lock() + diff := cmp.Diff(linkRes.probes, wantProbes) + linkRes.mu.Unlock() + if diff != "" { + t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + } + + e.mu.Lock() + if e.neigh.State != Probe { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Probe) + } + e.handleConfirmationLocked("" /* linkAddr */, ReachabilityConfirmationFlags{ + Solicited: true, + Override: false, + IsRouter: false, + }) + if e.neigh.State != Reachable { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Reachable) + } + e.mu.Unlock() + + clock.Advance(c.BaseReachableTime) + + wantEvents := []testEntryEventInfo{ + { + EventType: entryTestAdded, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Delay, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Probe, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Reachable, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + }, + } + nudDisp.mu.Lock() + if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + } + nudDisp.mu.Unlock() +} + func TestEntryProbeToFailed(t *testing.T) { c := DefaultNUDConfigurations() c.MaxMulticastProbes = 3 diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go index 8828cc5fe..dcd4319bf 100644 --- a/pkg/tcpip/stack/nic.go +++ b/pkg/tcpip/stack/nic.go @@ -23,7 +23,6 @@ import ( "gvisor.dev/gvisor/pkg/sleep" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" ) @@ -686,7 +685,9 @@ func (n *NIC) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcp // packet to forward. fwdPkt := NewPacketBuffer(PacketBufferOptions{ ReserveHeaderBytes: int(n.LinkEndpoint.MaxHeaderLength()), - Data: buffer.NewVectorisedView(pkt.Size(), pkt.Views()), + // We need to do a deep copy of the IP packet because WritePacket (and + // friends) take ownership of the packet buffer, but we do not own it. + Data: PayloadSince(pkt.NetworkHeader()).ToVectorisedView(), }) // TODO(b/143425874) Decrease the TTL field in forwarded packets. diff --git a/pkg/tcpip/stack/packet_buffer.go b/pkg/tcpip/stack/packet_buffer.go index 105583c49..7f54a6de8 100644 --- a/pkg/tcpip/stack/packet_buffer.go +++ b/pkg/tcpip/stack/packet_buffer.go @@ -311,11 +311,25 @@ func (h PacketHeader) Consume(size int) (v buffer.View, consumed bool) { } // PayloadSince returns packet payload starting from and including a particular -// header. This method isn't optimized and should be used in test only. +// header. +// +// The returned View is owned by the caller - its backing buffer is separate +// from the packet header's underlying packet buffer. func PayloadSince(h PacketHeader) buffer.View { - var v buffer.View + size := h.pk.Data.Size() + for _, hinfo := range h.pk.headers[h.typ:] { + size += len(hinfo.buf) + } + + v := make(buffer.View, 0, size) + for _, hinfo := range h.pk.headers[h.typ:] { v = append(v, hinfo.buf...) } - return append(v, h.pk.Data.ToView()...) + + for _, view := range h.pk.Data.Views() { + v = append(v, view...) + } + + return v } diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go index c42bb0991..d77848d61 100644 --- a/pkg/tcpip/tcpip.go +++ b/pkg/tcpip/tcpip.go @@ -111,6 +111,7 @@ var ( ErrBroadcastDisabled = &Error{msg: "broadcast socket option disabled"} ErrNotPermitted = &Error{msg: "operation not permitted"} ErrAddressFamilyNotSupported = &Error{msg: "address family not supported by protocol"} + ErrMalformedHeader = &Error{msg: "header is malformed"} ) var messageToError map[string]*Error @@ -159,6 +160,7 @@ func StringToError(s string) *Error { ErrBroadcastDisabled, ErrNotPermitted, ErrAddressFamilyNotSupported, + ErrMalformedHeader, } messageToError = make(map[string]*Error) diff --git a/pkg/tcpip/tests/integration/BUILD b/pkg/tcpip/tests/integration/BUILD index a4f141253..34aab32d0 100644 --- a/pkg/tcpip/tests/integration/BUILD +++ b/pkg/tcpip/tests/integration/BUILD @@ -16,6 +16,7 @@ go_test( "//pkg/tcpip/buffer", "//pkg/tcpip/header", "//pkg/tcpip/link/channel", + "//pkg/tcpip/link/ethernet", "//pkg/tcpip/link/loopback", "//pkg/tcpip/link/pipe", "//pkg/tcpip/network/arp", diff --git a/pkg/tcpip/tests/integration/forward_test.go b/pkg/tcpip/tests/integration/forward_test.go index ffd38ee1a..0dcef7b04 100644 --- a/pkg/tcpip/tests/integration/forward_test.go +++ b/pkg/tcpip/tests/integration/forward_test.go @@ -21,6 +21,7 @@ import ( "github.com/google/go-cmp/cmp" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/link/ethernet" "gvisor.dev/gvisor/pkg/tcpip/link/pipe" "gvisor.dev/gvisor/pkg/tcpip/network/arp" "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" @@ -178,19 +179,19 @@ func TestForwarding(t *testing.T) { routerStack := stack.New(stackOpts) host2Stack := stack.New(stackOpts) - host1NIC, routerNIC1 := pipe.New(host1NICLinkAddr, routerNIC1LinkAddr, stack.CapabilityResolutionRequired) - routerNIC2, host2NIC := pipe.New(routerNIC2LinkAddr, host2NICLinkAddr, stack.CapabilityResolutionRequired) + host1NIC, routerNIC1 := pipe.New(host1NICLinkAddr, routerNIC1LinkAddr) + routerNIC2, host2NIC := pipe.New(routerNIC2LinkAddr, host2NICLinkAddr) - if err := host1Stack.CreateNIC(host1NICID, host1NIC); err != nil { + if err := host1Stack.CreateNIC(host1NICID, ethernet.New(host1NIC)); err != nil { t.Fatalf("host1Stack.CreateNIC(%d, _): %s", host1NICID, err) } - if err := routerStack.CreateNIC(routerNICID1, routerNIC1); err != nil { + if err := routerStack.CreateNIC(routerNICID1, ethernet.New(routerNIC1)); err != nil { t.Fatalf("routerStack.CreateNIC(%d, _): %s", routerNICID1, err) } - if err := routerStack.CreateNIC(routerNICID2, routerNIC2); err != nil { + if err := routerStack.CreateNIC(routerNICID2, ethernet.New(routerNIC2)); err != nil { t.Fatalf("routerStack.CreateNIC(%d, _): %s", routerNICID2, err) } - if err := host2Stack.CreateNIC(host2NICID, host2NIC); err != nil { + if err := host2Stack.CreateNIC(host2NICID, ethernet.New(host2NIC)); err != nil { t.Fatalf("host2Stack.CreateNIC(%d, _): %s", host2NICID, err) } diff --git a/pkg/tcpip/tests/integration/link_resolution_test.go b/pkg/tcpip/tests/integration/link_resolution_test.go index bf3a6f6ee..6ddcda70c 100644 --- a/pkg/tcpip/tests/integration/link_resolution_test.go +++ b/pkg/tcpip/tests/integration/link_resolution_test.go @@ -22,6 +22,7 @@ import ( "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/link/ethernet" "gvisor.dev/gvisor/pkg/tcpip/link/pipe" "gvisor.dev/gvisor/pkg/tcpip/network/arp" "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" @@ -126,12 +127,12 @@ func TestPing(t *testing.T) { host1Stack := stack.New(stackOpts) host2Stack := stack.New(stackOpts) - host1NIC, host2NIC := pipe.New(host1NICLinkAddr, host2NICLinkAddr, stack.CapabilityResolutionRequired) + host1NIC, host2NIC := pipe.New(host1NICLinkAddr, host2NICLinkAddr) - if err := host1Stack.CreateNIC(host1NICID, host1NIC); err != nil { + if err := host1Stack.CreateNIC(host1NICID, ethernet.New(host1NIC)); err != nil { t.Fatalf("host1Stack.CreateNIC(%d, _): %s", host1NICID, err) } - if err := host2Stack.CreateNIC(host2NICID, host2NIC); err != nil { + if err := host2Stack.CreateNIC(host2NICID, ethernet.New(host2NIC)); err != nil { t.Fatalf("host2Stack.CreateNIC(%d, _): %s", host2NICID, err) } diff --git a/pkg/tcpip/tests/integration/multicast_broadcast_test.go b/pkg/tcpip/tests/integration/multicast_broadcast_test.go index 4f2ca7f54..f1028823b 100644 --- a/pkg/tcpip/tests/integration/multicast_broadcast_test.go +++ b/pkg/tcpip/tests/integration/multicast_broadcast_test.go @@ -80,6 +80,7 @@ func TestPingMulticastBroadcast(t *testing.T) { SrcAddr: remoteIPv4Addr, DstAddr: dst, }) + ip.SetChecksum(^ip.CalculateChecksum()) e.InjectInbound(header.IPv4ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: hdr.View().ToVectorisedView(), @@ -250,6 +251,7 @@ func TestIncomingMulticastAndBroadcast(t *testing.T) { SrcAddr: remoteIPv4Addr, DstAddr: dst, }) + ip.SetChecksum(^ip.CalculateChecksum()) e.InjectInbound(header.IPv4ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: hdr.View().ToVectorisedView(), |