summary refs log tree commit diff homepage
path: root/pkg/tcpip
diff options
context:
space:
mode:
Diffstat (limited to 'pkg/tcpip')
-rw-r--r-- pkg/tcpip/checker/checker.go 18
-rw-r--r-- pkg/tcpip/link/ethernet/BUILD 15
-rw-r--r-- pkg/tcpip/link/ethernet/ethernet.go 99
-rw-r--r-- pkg/tcpip/link/pipe/pipe.go 39
-rw-r--r-- pkg/tcpip/network/BUILD 1
-rw-r--r-- pkg/tcpip/network/ip_test.go 414
-rw-r--r-- pkg/tcpip/network/ipv4/ipv4.go 147
-rw-r--r-- pkg/tcpip/network/ipv4/ipv4_test.go 302
-rw-r--r-- pkg/tcpip/network/ipv6/icmp.go 18
-rw-r--r-- pkg/tcpip/network/ipv6/icmp_test.go 254
-rw-r--r-- pkg/tcpip/network/ipv6/ipv6.go 69
-rw-r--r-- pkg/tcpip/network/ipv6/ipv6_test.go 61
-rw-r--r-- pkg/tcpip/stack/neighbor_entry.go 6
-rw-r--r-- pkg/tcpip/stack/neighbor_entry_test.go 277
-rw-r--r-- pkg/tcpip/stack/nic.go 5
-rw-r--r-- pkg/tcpip/stack/packet_buffer.go 20
-rw-r--r-- pkg/tcpip/tcpip.go 2
-rw-r--r-- pkg/tcpip/tests/integration/BUILD 1
-rw-r--r-- pkg/tcpip/tests/integration/forward_test.go 13
-rw-r--r-- pkg/tcpip/tests/integration/link_resolution_test.go 7
-rw-r--r-- pkg/tcpip/tests/integration/multicast_broadcast_test.go 2
21 files changed, 1541 insertions, 229 deletions
diff --git a/pkg/tcpip/checker/checker.go b/pkg/tcpip/checker/checker.go
index d4d785cca..6f81b0164 100644
--- a/pkg/tcpip/checker/checker.go
+++ b/pkg/tcpip/checker/checker.go
@@ -178,6 +178,24 @@ func PayloadLen(payloadLength int) NetworkChecker {
}
}
+// IPPayload returns a NetworkChecker that compares the payload of the first
+// network header it is given against payload. A nil payload and an empty
+// payload are treated as equal.
+func IPPayload(payload []byte) NetworkChecker {
+	return func(t *testing.T, h []header.Network) {
+		t.Helper()
+
+		actual := h[0].Payload()
+
+		// Two empty payloads always match; cmp.Diff on its own would report a
+		// difference between a nil slice and an empty non-nil slice.
+		bothEmpty := len(actual) == 0 && len(payload) == 0
+		if bothEmpty {
+			return
+		}
+
+		delta := cmp.Diff(payload, actual)
+		if delta == "" {
+			return
+		}
+		t.Errorf("payload mismatch (-want +got):\n%s", delta)
+	}
+}
+
// IPv4Options returns a checker that checks the options in an IPv4 packet.
func IPv4Options(want []byte) NetworkChecker {
return func(t *testing.T, h []header.Network) {
diff --git a/pkg/tcpip/link/ethernet/BUILD b/pkg/tcpip/link/ethernet/BUILD
new file mode 100644
index 000000000..ec92ed623
--- /dev/null
+++ b/pkg/tcpip/link/ethernet/BUILD
@@ -0,0 +1,15 @@
+load("//tools:defs.bzl", "go_library")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "ethernet",
+ srcs = ["ethernet.go"],
+ visibility = ["//visibility:public"],
+ deps = [  # One entry per gVisor package imported by ethernet.go.
+ "//pkg/tcpip",
+ "//pkg/tcpip/header",
+ "//pkg/tcpip/link/nested",
+ "//pkg/tcpip/stack",
+ ],
+)
diff --git a/pkg/tcpip/link/ethernet/ethernet.go b/pkg/tcpip/link/ethernet/ethernet.go
new file mode 100644
index 000000000..3eef7cd56
--- /dev/null
+++ b/pkg/tcpip/link/ethernet/ethernet.go
@@ -0,0 +1,99 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package ethernet provides an implementation of an ethernet link endpoint that
+// wraps an inner link endpoint.
+package ethernet
+
+import (
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/link/nested"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
+)
+
+// Compile-time checks that *Endpoint satisfies the interfaces it is used as.
+var (
+	_ stack.NetworkDispatcher = (*Endpoint)(nil)
+	_ stack.LinkEndpoint      = (*Endpoint)(nil)
+)
+
+// New wraps ep in an ethernet link endpoint and returns the wrapper.
+//
+// The embedded nested endpoint is initialized with ep and with the new
+// Endpoint itself as the second argument to Init.
+func New(ep stack.LinkEndpoint) *Endpoint {
+	e := &Endpoint{}
+	e.Endpoint.Init(ep, e)
+	return e
+}
+
+// Endpoint is an ethernet endpoint.
+//
+// It adds an ethernet header to packets before sending them out through its
+// inner link endpoint and consumes an ethernet header before sending the
+// packet to the stack.
+type Endpoint struct {
+ nested.Endpoint // Embedded nested endpoint; New initializes it with the inner link endpoint.
+}
+
+// DeliverNetworkPacket implements stack.NetworkDispatcher.
+//
+// The link addresses and protocol number passed in are ignored; they are read
+// from the ethernet header consumed from pkt instead. The packet is handed to
+// the embedded endpoint only when its destination is this endpoint's link
+// address, the ethernet broadcast address or a multicast ethernet address.
+// Anything else, including a packet too short to hold an ethernet header, is
+// silently dropped.
+func (e *Endpoint) DeliverNetworkPacket(_, _ tcpip.LinkAddress, _ tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
+	rawHdr, ok := pkt.LinkHeader().Consume(header.EthernetMinimumSize)
+	if !ok {
+		return
+	}
+
+	eth := header.Ethernet(rawHdr)
+	dst := eth.DestinationAddress()
+	switch {
+	case dst == e.Endpoint.LinkAddress():
+	case dst == header.EthernetBroadcastAddress:
+	case header.IsMulticastEthernetAddress(dst):
+	default:
+		// Not addressed to this endpoint.
+		return
+	}
+
+	e.Endpoint.DeliverNetworkPacket(eth.SourceAddress() /* remote */, dst /* local */, eth.Type() /* protocol */, pkt)
+}
+
+// Capabilities implements stack.LinkEndpoint.
+//
+// It reports the embedded endpoint's capabilities with
+// stack.CapabilityResolutionRequired added.
+func (e *Endpoint) Capabilities() stack.LinkEndpointCapabilities {
+	innerCaps := e.Endpoint.Capabilities()
+	return innerCaps | stack.CapabilityResolutionRequired
+}
+
+// WritePacket implements stack.LinkEndpoint.
+//
+// It pushes an ethernet header onto pkt, using this endpoint's link address as
+// the source and the route's remote link address as the destination, and then
+// writes the packet through the embedded endpoint.
+func (e *Endpoint) WritePacket(r *stack.Route, gso *stack.GSO, proto tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error {
+	src := e.Endpoint.LinkAddress()
+	dst := r.RemoteLinkAddress
+	e.AddHeader(src, dst, proto, pkt)
+
+	return e.Endpoint.WritePacket(r, gso, proto, pkt)
+}
+
+// WritePackets implements stack.LinkEndpoint.
+//
+// It pushes an ethernet header onto every packet in pkts and then writes the
+// whole list through the embedded endpoint.
+func (e *Endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.PacketBufferList, proto tcpip.NetworkProtocolNumber) (int, *tcpip.Error) {
+	// The local link address is looked up once and reused for every packet.
+	src := e.Endpoint.LinkAddress()
+
+	pkt := pkts.Front()
+	for pkt != nil {
+		e.AddHeader(src, r.RemoteLinkAddress, proto, pkt)
+		pkt = pkt.Next()
+	}
+
+	return e.Endpoint.WritePackets(r, gso, pkts, proto)
+}
+
+// MaxHeaderLength implements stack.LinkEndpoint.
+//
+// The result is the embedded endpoint's maximum header length plus the size of
+// the ethernet header this endpoint adds.
+func (e *Endpoint) MaxHeaderLength() uint16 {
+	innerLen := e.Endpoint.MaxHeaderLength()
+	return innerLen + header.EthernetMinimumSize
+}
+
+// ARPHardwareType implements stack.LinkEndpoint; it always reports ethernet.
+func (*Endpoint) ARPHardwareType() header.ARPHardwareType {
+ return header.ARPHardwareEther // Fixed value; the embedded endpoint is not consulted.
+}
+
+// AddHeader implements stack.LinkEndpoint.
+//
+// It pushes header.EthernetMinimumSize bytes onto pkt's link header and encodes
+// an ethernet header into them with local as the source address, remote as the
+// destination address and proto as the type.
+func (*Endpoint) AddHeader(local, remote tcpip.LinkAddress, proto tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
+	buf := pkt.LinkHeader().Push(header.EthernetMinimumSize)
+	header.Ethernet(buf).Encode(&header.EthernetFields{
+		SrcAddr: local,
+		DstAddr: remote,
+		Type:    proto,
+	})
+}
diff --git a/pkg/tcpip/link/pipe/pipe.go b/pkg/tcpip/link/pipe/pipe.go
index 76f563811..523b0d24b 100644
--- a/pkg/tcpip/link/pipe/pipe.go
+++ b/pkg/tcpip/link/pipe/pipe.go
@@ -26,27 +26,23 @@ import (
var _ stack.LinkEndpoint = (*Endpoint)(nil)
// New returns both ends of a new pipe.
-func New(linkAddr1, linkAddr2 tcpip.LinkAddress, capabilities stack.LinkEndpointCapabilities) (*Endpoint, *Endpoint) {
+func New(linkAddr1, linkAddr2 tcpip.LinkAddress) (*Endpoint, *Endpoint) {
ep1 := &Endpoint{
- linkAddr: linkAddr1,
- capabilities: capabilities,
+ linkAddr: linkAddr1,
}
ep2 := &Endpoint{
- linkAddr: linkAddr2,
- linked: ep1,
- capabilities: capabilities,
+ linkAddr: linkAddr2,
}
ep1.linked = ep2
+ ep2.linked = ep1
return ep1, ep2
}
// Endpoint is one end of a pipe.
type Endpoint struct {
- capabilities stack.LinkEndpointCapabilities
- linkAddr tcpip.LinkAddress
- dispatcher stack.NetworkDispatcher
- linked *Endpoint
- onWritePacket func(*stack.PacketBuffer)
+ dispatcher stack.NetworkDispatcher
+ linked *Endpoint
+ linkAddr tcpip.LinkAddress
}
// WritePacket implements stack.LinkEndpoint.
@@ -55,16 +51,11 @@ func (e *Endpoint) WritePacket(r *stack.Route, _ *stack.GSO, proto tcpip.Network
return nil
}
- // The pipe endpoint will accept all multicast/broadcast link traffic and only
- // unicast traffic destined to itself.
- if len(e.linked.linkAddr) != 0 &&
- r.RemoteLinkAddress != e.linked.linkAddr &&
- r.RemoteLinkAddress != header.EthernetBroadcastAddress &&
- !header.IsMulticastEthernetAddress(r.RemoteLinkAddress) {
- return nil
- }
-
- e.linked.dispatcher.DeliverNetworkPacket(e.linkAddr, r.RemoteLinkAddress, proto, stack.NewPacketBuffer(stack.PacketBufferOptions{
+ // Note that the local address from the perspective of this endpoint is the
+ // remote address from the perspective of the other end of the pipe
+ // (e.linked). Similarly, the remote address from the perspective of this
+ // endpoint is the local address on the other end.
+ e.linked.dispatcher.DeliverNetworkPacket(r.LocalLinkAddress /* remote */, r.RemoteLinkAddress /* local */, proto, stack.NewPacketBuffer(stack.PacketBufferOptions{
Data: buffer.NewVectorisedView(pkt.Size(), pkt.Views()),
}))
@@ -100,8 +91,8 @@ func (*Endpoint) MTU() uint32 {
}
// Capabilities implements stack.LinkEndpoint.
-func (e *Endpoint) Capabilities() stack.LinkEndpointCapabilities {
- return e.capabilities
+func (*Endpoint) Capabilities() stack.LinkEndpointCapabilities {
+ return 0
}
// MaxHeaderLength implements stack.LinkEndpoint.
@@ -116,7 +107,7 @@ func (e *Endpoint) LinkAddress() tcpip.LinkAddress {
// ARPHardwareType implements stack.LinkEndpoint.
func (*Endpoint) ARPHardwareType() header.ARPHardwareType {
- return header.ARPHardwareEther
+ return header.ARPHardwareNone
}
// AddHeader implements stack.LinkEndpoint.
diff --git a/pkg/tcpip/network/BUILD b/pkg/tcpip/network/BUILD
index 59710352b..c118a2929 100644
--- a/pkg/tcpip/network/BUILD
+++ b/pkg/tcpip/network/BUILD
@@ -12,6 +12,7 @@ go_test(
"//pkg/sync",
"//pkg/tcpip",
"//pkg/tcpip/buffer",
+ "//pkg/tcpip/checker",
"//pkg/tcpip/header",
"//pkg/tcpip/link/channel",
"//pkg/tcpip/link/loopback",
diff --git a/pkg/tcpip/network/ip_test.go b/pkg/tcpip/network/ip_test.go
index d436873b6..f20b94d97 100644
--- a/pkg/tcpip/network/ip_test.go
+++ b/pkg/tcpip/network/ip_test.go
@@ -15,11 +15,13 @@
package ip_test
import (
+ "strings"
"testing"
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/checker"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/link/channel"
"gvisor.dev/gvisor/pkg/tcpip/link/loopback"
@@ -320,6 +322,7 @@ func TestSourceAddressValidation(t *testing.T) {
SrcAddr: src,
DstAddr: localIPv4Addr,
})
+ ip.SetChecksum(^ip.CalculateChecksum())
e.InjectInbound(header.IPv4ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
Data: hdr.View().ToVectorisedView(),
@@ -342,7 +345,6 @@ func TestSourceAddressValidation(t *testing.T) {
SrcAddr: src,
DstAddr: localIPv6Addr,
})
-
e.InjectInbound(header.IPv6ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
Data: hdr.View().ToVectorisedView(),
}))
@@ -579,6 +581,7 @@ func TestIPv4Receive(t *testing.T) {
SrcAddr: remoteIPv4Addr,
DstAddr: localIPv4Addr,
})
+ ip.SetChecksum(^ip.CalculateChecksum())
// Make payload be non-zero.
for i := header.IPv4MinimumSize; i < totalLen; i++ {
@@ -660,6 +663,7 @@ func TestIPv4ReceiveControl(t *testing.T) {
SrcAddr: "\x0a\x00\x00\xbb",
DstAddr: localIPv4Addr,
})
+ ip.SetChecksum(^ip.CalculateChecksum())
// Create the ICMP header.
icmp := header.ICMPv4(view[header.IPv4MinimumSize:])
@@ -679,6 +683,7 @@ func TestIPv4ReceiveControl(t *testing.T) {
SrcAddr: localIPv4Addr,
DstAddr: remoteIPv4Addr,
})
+ ip.SetChecksum(^ip.CalculateChecksum())
// Make payload be non-zero.
for i := dataOffset; i < len(view); i++ {
@@ -732,6 +737,8 @@ func TestIPv4FragmentationReceive(t *testing.T) {
SrcAddr: remoteIPv4Addr,
DstAddr: localIPv4Addr,
})
+ ip1.SetChecksum(^ip1.CalculateChecksum())
+
// Make payload be non-zero.
for i := header.IPv4MinimumSize; i < totalLen; i++ {
frag1[i] = uint8(i)
@@ -748,6 +755,8 @@ func TestIPv4FragmentationReceive(t *testing.T) {
SrcAddr: remoteIPv4Addr,
DstAddr: localIPv4Addr,
})
+ ip2.SetChecksum(^ip2.CalculateChecksum())
+
// Make payload be non-zero.
for i := header.IPv4MinimumSize; i < totalLen; i++ {
frag2[i] = uint8(i)
@@ -1020,3 +1029,406 @@ func truncatedPacket(view buffer.View, trunc, netHdrLen int) *stack.PacketBuffer
_, _ = pkt.NetworkHeader().Consume(netHdrLen)
return pkt
}
+
+// TestWriteHeaderIncludedPacket exercises Route.WriteHeaderIncludedPacket with
+// IPv4 and IPv6 packets whose network header is supplied by the caller.
+//
+// Every case is run twice: once with an all-zero (unspecified) source address
+// and once with an arbitrary non-zero source address. A case either expects a
+// specific error or inspects the packet read back from the channel endpoint.
+// The checkers expect an unspecified source to be replaced with the NIC's
+// address and the destination to be the route's remote address.
+func TestWriteHeaderIncludedPacket(t *testing.T) {
+	const (
+		nicID          = 1
+		transportProto = 5
+
+		dataLen    = 4
+		optionsLen = 4
+	)
+
+	dataBuf := [dataLen]byte{1, 2, 3, 4}
+	data := dataBuf[:]
+
+	ipv4OptionsBuf := [optionsLen]byte{0, 1, 0, 1}
+	ipv4Options := ipv4OptionsBuf[:]
+
+	ipv6FragmentExtHdrBuf := [header.IPv6FragmentExtHdrLength]byte{transportProto, 0, 62, 4, 1, 2, 3, 4}
+	ipv6FragmentExtHdr := ipv6FragmentExtHdrBuf[:]
+
+	// The IPv6 payload as seen by the IP layer: extension header, then data.
+	var ipv6PayloadWithExtHdrBuf [dataLen + header.IPv6FragmentExtHdrLength]byte
+	ipv6PayloadWithExtHdr := ipv6PayloadWithExtHdrBuf[:]
+	if n := copy(ipv6PayloadWithExtHdr, ipv6FragmentExtHdr); n != len(ipv6FragmentExtHdr) {
+		t.Fatalf("copied %d bytes, expected %d bytes", n, len(ipv6FragmentExtHdr))
+	}
+	if n := copy(ipv6PayloadWithExtHdr[header.IPv6FragmentExtHdrLength:], data); n != len(data) {
+		t.Fatalf("copied %d bytes, expected %d bytes", n, len(data))
+	}
+
+	tests := []struct {
+		name         string
+		protoFactory stack.NetworkProtocolFactory
+		protoNum     tcpip.NetworkProtocolNumber
+		nicAddr      tcpip.Address
+		remoteAddr   tcpip.Address
+		pktGen       func(*testing.T, tcpip.Address) buffer.View
+		checker      func(*testing.T, *stack.PacketBuffer, tcpip.Address)
+		expectedErr  *tcpip.Error
+	}{
+		{
+			name:         "IPv4",
+			protoFactory: ipv4.NewProtocol,
+			protoNum:     ipv4.ProtocolNumber,
+			nicAddr:      localIPv4Addr,
+			remoteAddr:   remoteIPv4Addr,
+			pktGen: func(t *testing.T, src tcpip.Address) buffer.View {
+				totalLen := header.IPv4MinimumSize + len(data)
+				hdr := buffer.NewPrependable(totalLen)
+				if n := copy(hdr.Prepend(len(data)), data); n != len(data) {
+					t.Fatalf("copied %d bytes, expected %d bytes", n, len(data))
+				}
+				ip := header.IPv4(hdr.Prepend(header.IPv4MinimumSize))
+				// The destination is left unspecified; the checker below expects
+				// the written packet to carry the route's remote address.
+				ip.Encode(&header.IPv4Fields{
+					IHL:      header.IPv4MinimumSize,
+					Protocol: transportProto,
+					TTL:      ipv4.DefaultTTL,
+					SrcAddr:  src,
+					DstAddr:  header.IPv4Any,
+				})
+				return hdr.View()
+			},
+			checker: func(t *testing.T, pkt *stack.PacketBuffer, src tcpip.Address) {
+				if src == header.IPv4Any {
+					src = localIPv4Addr
+				}
+
+				netHdr := pkt.NetworkHeader()
+
+				if len(netHdr.View()) != header.IPv4MinimumSize {
+					t.Errorf("got len(netHdr.View()) = %d, want = %d", len(netHdr.View()), header.IPv4MinimumSize)
+				}
+
+				checker.IPv4(t, stack.PayloadSince(netHdr),
+					checker.SrcAddr(src),
+					checker.DstAddr(remoteIPv4Addr),
+					checker.IPv4HeaderLength(header.IPv4MinimumSize),
+					checker.IPFullLength(uint16(header.IPv4MinimumSize+len(data))),
+					checker.IPPayload(data),
+				)
+			},
+		},
+		{
+			name:         "IPv4 with IHL too small",
+			protoFactory: ipv4.NewProtocol,
+			protoNum:     ipv4.ProtocolNumber,
+			nicAddr:      localIPv4Addr,
+			remoteAddr:   remoteIPv4Addr,
+			pktGen: func(t *testing.T, src tcpip.Address) buffer.View {
+				totalLen := header.IPv4MinimumSize + len(data)
+				hdr := buffer.NewPrependable(totalLen)
+				if n := copy(hdr.Prepend(len(data)), data); n != len(data) {
+					t.Fatalf("copied %d bytes, expected %d bytes", n, len(data))
+				}
+				ip := header.IPv4(hdr.Prepend(header.IPv4MinimumSize))
+				ip.Encode(&header.IPv4Fields{
+					IHL:      header.IPv4MinimumSize - 1,
+					Protocol: transportProto,
+					TTL:      ipv4.DefaultTTL,
+					SrcAddr:  src,
+					DstAddr:  header.IPv4Any,
+				})
+				return hdr.View()
+			},
+			expectedErr: tcpip.ErrMalformedHeader,
+		},
+		{
+			name:         "IPv4 too small",
+			protoFactory: ipv4.NewProtocol,
+			protoNum:     ipv4.ProtocolNumber,
+			nicAddr:      localIPv4Addr,
+			remoteAddr:   remoteIPv4Addr,
+			pktGen: func(t *testing.T, src tcpip.Address) buffer.View {
+				ip := header.IPv4(make([]byte, header.IPv4MinimumSize))
+				ip.Encode(&header.IPv4Fields{
+					IHL:      header.IPv4MinimumSize,
+					Protocol: transportProto,
+					TTL:      ipv4.DefaultTTL,
+					SrcAddr:  src,
+					DstAddr:  header.IPv4Any,
+				})
+				// Drop the last byte so the header is one byte short.
+				return buffer.View(ip[:len(ip)-1])
+			},
+			expectedErr: tcpip.ErrMalformedHeader,
+		},
+		{
+			name:         "IPv4 minimum size",
+			protoFactory: ipv4.NewProtocol,
+			protoNum:     ipv4.ProtocolNumber,
+			nicAddr:      localIPv4Addr,
+			remoteAddr:   remoteIPv4Addr,
+			pktGen: func(t *testing.T, src tcpip.Address) buffer.View {
+				ip := header.IPv4(make([]byte, header.IPv4MinimumSize))
+				ip.Encode(&header.IPv4Fields{
+					IHL:      header.IPv4MinimumSize,
+					Protocol: transportProto,
+					TTL:      ipv4.DefaultTTL,
+					SrcAddr:  src,
+					DstAddr:  header.IPv4Any,
+				})
+				return buffer.View(ip)
+			},
+			checker: func(t *testing.T, pkt *stack.PacketBuffer, src tcpip.Address) {
+				if src == header.IPv4Any {
+					src = localIPv4Addr
+				}
+
+				netHdr := pkt.NetworkHeader()
+
+				if len(netHdr.View()) != header.IPv4MinimumSize {
+					t.Errorf("got len(netHdr.View()) = %d, want = %d", len(netHdr.View()), header.IPv4MinimumSize)
+				}
+
+				checker.IPv4(t, stack.PayloadSince(netHdr),
+					checker.SrcAddr(src),
+					checker.DstAddr(remoteIPv4Addr),
+					checker.IPv4HeaderLength(header.IPv4MinimumSize),
+					checker.IPFullLength(header.IPv4MinimumSize),
+					checker.IPPayload(nil),
+				)
+			},
+		},
+		{
+			name:         "IPv4 with options",
+			protoFactory: ipv4.NewProtocol,
+			protoNum:     ipv4.ProtocolNumber,
+			nicAddr:      localIPv4Addr,
+			remoteAddr:   remoteIPv4Addr,
+			pktGen: func(t *testing.T, src tcpip.Address) buffer.View {
+				ipHdrLen := header.IPv4MinimumSize + len(ipv4Options)
+				totalLen := ipHdrLen + len(data)
+				hdr := buffer.NewPrependable(totalLen)
+				if n := copy(hdr.Prepend(len(data)), data); n != len(data) {
+					t.Fatalf("copied %d bytes, expected %d bytes", n, len(data))
+				}
+				ip := header.IPv4(hdr.Prepend(ipHdrLen))
+				ip.Encode(&header.IPv4Fields{
+					IHL:      uint8(ipHdrLen),
+					Protocol: transportProto,
+					TTL:      ipv4.DefaultTTL,
+					SrcAddr:  src,
+					DstAddr:  header.IPv4Any,
+				})
+				if n := copy(ip.Options(), ipv4Options); n != len(ipv4Options) {
+					t.Fatalf("copied %d bytes, expected %d bytes", n, len(ipv4Options))
+				}
+				return hdr.View()
+			},
+			checker: func(t *testing.T, pkt *stack.PacketBuffer, src tcpip.Address) {
+				if src == header.IPv4Any {
+					src = localIPv4Addr
+				}
+
+				netHdr := pkt.NetworkHeader()
+
+				hdrLen := header.IPv4MinimumSize + len(ipv4Options)
+				if len(netHdr.View()) != hdrLen {
+					t.Errorf("got len(netHdr.View()) = %d, want = %d", len(netHdr.View()), hdrLen)
+				}
+
+				checker.IPv4(t, stack.PayloadSince(netHdr),
+					checker.SrcAddr(src),
+					checker.DstAddr(remoteIPv4Addr),
+					checker.IPv4HeaderLength(hdrLen),
+					checker.IPFullLength(uint16(hdrLen+len(data))),
+					checker.IPv4Options(ipv4Options),
+					checker.IPPayload(data),
+				)
+			},
+		},
+		{
+			name:         "IPv6",
+			protoFactory: ipv6.NewProtocol,
+			protoNum:     ipv6.ProtocolNumber,
+			nicAddr:      localIPv6Addr,
+			remoteAddr:   remoteIPv6Addr,
+			pktGen: func(t *testing.T, src tcpip.Address) buffer.View {
+				totalLen := header.IPv6MinimumSize + len(data)
+				hdr := buffer.NewPrependable(totalLen)
+				if n := copy(hdr.Prepend(len(data)), data); n != len(data) {
+					t.Fatalf("copied %d bytes, expected %d bytes", n, len(data))
+				}
+				ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
+				// Use the IPv6 unspecified address so the encoded destination has
+				// the address length an IPv6 header expects.
+				ip.Encode(&header.IPv6Fields{
+					NextHeader: transportProto,
+					HopLimit:   ipv6.DefaultTTL,
+					SrcAddr:    src,
+					DstAddr:    header.IPv6Any,
+				})
+				return hdr.View()
+			},
+			checker: func(t *testing.T, pkt *stack.PacketBuffer, src tcpip.Address) {
+				if src == header.IPv6Any {
+					src = localIPv6Addr
+				}
+
+				netHdr := pkt.NetworkHeader()
+
+				if len(netHdr.View()) != header.IPv6MinimumSize {
+					t.Errorf("got len(netHdr.View()) = %d, want = %d", len(netHdr.View()), header.IPv6MinimumSize)
+				}
+
+				checker.IPv6(t, stack.PayloadSince(netHdr),
+					checker.SrcAddr(src),
+					checker.DstAddr(remoteIPv6Addr),
+					checker.IPFullLength(uint16(header.IPv6MinimumSize+len(data))),
+					checker.IPPayload(data),
+				)
+			},
+		},
+		{
+			name:         "IPv6 with extension header",
+			protoFactory: ipv6.NewProtocol,
+			protoNum:     ipv6.ProtocolNumber,
+			nicAddr:      localIPv6Addr,
+			remoteAddr:   remoteIPv6Addr,
+			pktGen: func(t *testing.T, src tcpip.Address) buffer.View {
+				totalLen := header.IPv6MinimumSize + len(ipv6FragmentExtHdr) + len(data)
+				hdr := buffer.NewPrependable(totalLen)
+				if n := copy(hdr.Prepend(len(data)), data); n != len(data) {
+					t.Fatalf("copied %d bytes, expected %d bytes", n, len(data))
+				}
+				if n := copy(hdr.Prepend(len(ipv6FragmentExtHdr)), ipv6FragmentExtHdr); n != len(ipv6FragmentExtHdr) {
+					t.Fatalf("copied %d bytes, expected %d bytes", n, len(ipv6FragmentExtHdr))
+				}
+				ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
+				ip.Encode(&header.IPv6Fields{
+					NextHeader: uint8(header.IPv6FragmentExtHdrIdentifier),
+					HopLimit:   ipv6.DefaultTTL,
+					SrcAddr:    src,
+					DstAddr:    header.IPv6Any,
+				})
+				return hdr.View()
+			},
+			checker: func(t *testing.T, pkt *stack.PacketBuffer, src tcpip.Address) {
+				if src == header.IPv6Any {
+					src = localIPv6Addr
+				}
+
+				netHdr := pkt.NetworkHeader()
+
+				if want := header.IPv6MinimumSize + len(ipv6FragmentExtHdr); len(netHdr.View()) != want {
+					t.Errorf("got len(netHdr.View()) = %d, want = %d", len(netHdr.View()), want)
+				}
+
+				checker.IPv6(t, stack.PayloadSince(netHdr),
+					checker.SrcAddr(src),
+					checker.DstAddr(remoteIPv6Addr),
+					checker.IPFullLength(uint16(header.IPv6MinimumSize+len(ipv6PayloadWithExtHdr))),
+					checker.IPPayload(ipv6PayloadWithExtHdr),
+				)
+			},
+		},
+		{
+			name:         "IPv6 minimum size",
+			protoFactory: ipv6.NewProtocol,
+			protoNum:     ipv6.ProtocolNumber,
+			nicAddr:      localIPv6Addr,
+			remoteAddr:   remoteIPv6Addr,
+			pktGen: func(t *testing.T, src tcpip.Address) buffer.View {
+				ip := header.IPv6(make([]byte, header.IPv6MinimumSize))
+				ip.Encode(&header.IPv6Fields{
+					NextHeader: transportProto,
+					HopLimit:   ipv6.DefaultTTL,
+					SrcAddr:    src,
+					DstAddr:    header.IPv6Any,
+				})
+				return buffer.View(ip)
+			},
+			checker: func(t *testing.T, pkt *stack.PacketBuffer, src tcpip.Address) {
+				if src == header.IPv6Any {
+					src = localIPv6Addr
+				}
+
+				netHdr := pkt.NetworkHeader()
+
+				if len(netHdr.View()) != header.IPv6MinimumSize {
+					t.Errorf("got len(netHdr.View()) = %d, want = %d", len(netHdr.View()), header.IPv6MinimumSize)
+				}
+
+				checker.IPv6(t, stack.PayloadSince(netHdr),
+					checker.SrcAddr(src),
+					checker.DstAddr(remoteIPv6Addr),
+					checker.IPFullLength(header.IPv6MinimumSize),
+					checker.IPPayload(nil),
+				)
+			},
+		},
+		{
+			name:         "IPv6 too small",
+			protoFactory: ipv6.NewProtocol,
+			protoNum:     ipv6.ProtocolNumber,
+			nicAddr:      localIPv6Addr,
+			remoteAddr:   remoteIPv6Addr,
+			pktGen: func(t *testing.T, src tcpip.Address) buffer.View {
+				ip := header.IPv6(make([]byte, header.IPv6MinimumSize))
+				ip.Encode(&header.IPv6Fields{
+					NextHeader: transportProto,
+					HopLimit:   ipv6.DefaultTTL,
+					SrcAddr:    src,
+					DstAddr:    header.IPv6Any,
+				})
+				// Drop the last byte so the header is one byte short.
+				return buffer.View(ip[:len(ip)-1])
+			},
+			expectedErr: tcpip.ErrMalformedHeader,
+		},
+	}
+
+	for _, test := range tests {
+		t.Run(test.name, func(t *testing.T) {
+			// Source addresses are sized to match the NIC address so the same
+			// sub-tests serve both IPv4 and IPv6 cases.
+			subTests := []struct {
+				name    string
+				srcAddr tcpip.Address
+			}{
+				{
+					name:    "unspecified source",
+					srcAddr: tcpip.Address(strings.Repeat("\x00", len(test.nicAddr))),
+				},
+				{
+					name:    "random source",
+					srcAddr: tcpip.Address(strings.Repeat("\xab", len(test.nicAddr))),
+				},
+			}
+
+			for _, subTest := range subTests {
+				t.Run(subTest.name, func(t *testing.T) {
+					s := stack.New(stack.Options{
+						NetworkProtocols: []stack.NetworkProtocolFactory{test.protoFactory},
+					})
+					e := channel.New(1, 1280, "")
+					if err := s.CreateNIC(nicID, e); err != nil {
+						t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err)
+					}
+					if err := s.AddAddress(nicID, test.protoNum, test.nicAddr); err != nil {
+						t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID, test.protoNum, test.nicAddr, err)
+					}
+
+					s.SetRouteTable([]tcpip.Route{{Destination: test.remoteAddr.WithPrefix().Subnet(), NIC: nicID}})
+
+					r, err := s.FindRoute(nicID, test.nicAddr, test.remoteAddr, test.protoNum, false /* multicastLoop */)
+					if err != nil {
+						// Arguments are printed in the same order they were passed.
+						t.Fatalf("s.FindRoute(%d, %s, %s, %d, false): %s", nicID, test.nicAddr, test.remoteAddr, test.protoNum, err)
+					}
+					defer r.Release()
+
+					if err := r.WriteHeaderIncludedPacket(stack.NewPacketBuffer(stack.PacketBufferOptions{
+						Data: test.pktGen(t, subTest.srcAddr).ToVectorisedView(),
+					})); err != test.expectedErr {
+						t.Fatalf("got r.WriteHeaderIncludedPacket(_) = %s, want = %s", err, test.expectedErr)
+					}
+
+					// Nothing is expected on the wire when the write was rejected.
+					if test.expectedErr != nil {
+						return
+					}
+
+					pkt, ok := e.Read()
+					if !ok {
+						t.Fatal("expected a packet to be written")
+					}
+					test.checker(t, pkt.Pkt, subTest.srcAddr)
+				})
+			}
+		})
+	}
+}
diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go
index c5ac7b8b5..e7c58ae0a 100644
--- a/pkg/tcpip/network/ipv4/ipv4.go
+++ b/pkg/tcpip/network/ipv4/ipv4.go
@@ -190,29 +190,6 @@ func (e *endpoint) NetworkProtocolNumber() tcpip.NetworkProtocolNumber {
return e.protocol.Number()
}
-// writePacketFragments fragments pkt and writes the results on the link
-// endpoint. The IP header must already present in the original packet. The mtu
-// is the maximum size of the packets.
-func (e *endpoint) writePacketFragments(r *stack.Route, gso *stack.GSO, mtu uint32, pkt *stack.PacketBuffer) *tcpip.Error {
- networkHeader := header.IPv4(pkt.NetworkHeader().View())
- fragMTU := int(calculateFragmentInnerMTU(mtu, pkt))
- pf := fragmentation.MakePacketFragmenter(pkt, fragMTU, pkt.AvailableHeaderBytes()+len(networkHeader))
-
- for {
- fragPkt, more := buildNextFragment(&pf, networkHeader)
- if err := e.nic.WritePacket(r, gso, ProtocolNumber, fragPkt); err != nil {
- r.Stats().IP.OutgoingPacketErrors.IncrementBy(uint64(pf.RemainingFragmentCount() + 1))
- return err
- }
- r.Stats().IP.PacketsSent.Increment()
- if !more {
- break
- }
- }
-
- return nil
-}
-
func (e *endpoint) addIPHeader(r *stack.Route, pkt *stack.PacketBuffer, params stack.NetworkHeaderParams) {
ip := header.IPv4(pkt.NetworkHeader().Push(header.IPv4MinimumSize))
length := uint16(pkt.Size())
@@ -234,10 +211,39 @@ func (e *endpoint) addIPHeader(r *stack.Route, pkt *stack.PacketBuffer, params s
pkt.NetworkProtocolNumber = ProtocolNumber
}
+// packetMustBeFragmented reports whether pkt has to be fragmented before being
+// written: no GSO is in use (gso is nil or of type stack.GSONone) and the
+// packet is larger than the NIC's MTU.
+func (e *endpoint) packetMustBeFragmented(pkt *stack.PacketBuffer, gso *stack.GSO) bool {
+	if gso != nil && gso.Type != stack.GSONone {
+		return false
+	}
+	return pkt.Size() > int(e.nic.MTU())
+}
+
+// handleFragments splits pkt into fragments and passes each one to handler in
+// order. The IP header must already be present in pkt, and mtu is the maximum
+// size of the resulting packets.
+//
+// It returns the number of fragments handler accepted and the number left
+// unprocessed. If handler returns an error, fragmentation stops, the failed
+// fragment is counted among those left unprocessed, and the error is returned.
+//
+// The r and gso parameters are not used by this function.
+func (e *endpoint) handleFragments(r *stack.Route, gso *stack.GSO, mtu uint32, pkt *stack.PacketBuffer, handler func(*stack.PacketBuffer) *tcpip.Error) (int, int, *tcpip.Error) {
+	fragMTU := int(calculateFragmentInnerMTU(mtu, pkt))
+	networkHeader := header.IPv4(pkt.NetworkHeader().View())
+	pf := fragmentation.MakePacketFragmenter(pkt, fragMTU, pkt.AvailableHeaderBytes()+len(networkHeader))
+
+	handled := 0
+	for {
+		frag, more := buildNextFragment(&pf, networkHeader)
+		if err := handler(frag); err != nil {
+			// The fragment that just failed still counts as unprocessed.
+			return handled, pf.RemainingFragmentCount() + 1, err
+		}
+		handled++
+		if more {
+			continue
+		}
+		return handled, pf.RemainingFragmentCount(), nil
+	}
+}
+
// WritePacket writes a packet to the given destination address and protocol.
func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.NetworkHeaderParams, pkt *stack.PacketBuffer) *tcpip.Error {
e.addIPHeader(r, pkt, params)
+ return e.writePacket(r, gso, pkt)
+}
+func (e *endpoint) writePacket(r *stack.Route, gso *stack.GSO, pkt *stack.PacketBuffer) *tcpip.Error {
// iptables filtering. All packets that reach here are locally
// generated.
nicName := e.protocol.stack.FindNICNameFromID(e.nic.ID())
@@ -273,8 +279,18 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.Netw
if r.Loop&stack.PacketOut == 0 {
return nil
}
- if pkt.Size() > int(e.nic.MTU()) && (gso == nil || gso.Type == stack.GSONone) {
- return e.writePacketFragments(r, gso, e.nic.MTU(), pkt)
+
+ if e.packetMustBeFragmented(pkt, gso) {
+ sent, remain, err := e.handleFragments(r, gso, e.nic.MTU(), pkt, func(fragPkt *stack.PacketBuffer) *tcpip.Error {
+ // TODO(gvisor.dev/issue/3884): Evaluate whether we want to send each
+ // fragment one by one using WritePacket() (current strategy) or if we
+ // want to create a PacketBufferList from the fragments and feed it to
+ // WritePackets(). It'll be faster but cost more memory.
+ return e.nic.WritePacket(r, gso, ProtocolNumber, fragPkt)
+ })
+ r.Stats().IP.PacketsSent.IncrementBy(uint64(sent))
+ r.Stats().IP.OutgoingPacketErrors.IncrementBy(uint64(remain))
+ return err
}
if err := e.nic.WritePacket(r, gso, ProtocolNumber, pkt); err != nil {
r.Stats().IP.OutgoingPacketErrors.Increment()
@@ -293,9 +309,23 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe
return pkts.Len(), nil
}
- for pkt := pkts.Front(); pkt != nil; {
+ for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() {
e.addIPHeader(r, pkt, params)
- pkt = pkt.Next()
+ if e.packetMustBeFragmented(pkt, gso) {
+ // Keep track of the packet that is about to be fragmented so it can be
+ // removed once the fragmentation is done.
+ originalPkt := pkt
+ if _, _, err := e.handleFragments(r, gso, e.nic.MTU(), pkt, func(fragPkt *stack.PacketBuffer) *tcpip.Error {
+ // Modify the packet list in place with the new fragments.
+ pkts.InsertAfter(pkt, fragPkt)
+ pkt = fragPkt
+ return nil
+ }); err != nil {
+ panic(fmt.Sprintf("e.handleFragments(_, _, %d, _, _) = %s", e.nic.MTU(), err))
+ }
+ // Remove the packet that was just fragmented and process the rest.
+ pkts.Remove(originalPkt)
+ }
}
nicName := e.protocol.stack.FindNICNameFromID(e.nic.ID())
@@ -347,30 +377,27 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe
return n + len(dropped), nil
}
-// WriteHeaderIncludedPacket writes a packet already containing a network
-// header through the given route.
+// WriteHeaderIncludedPacket implements stack.NetworkEndpoint.
func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack.PacketBuffer) *tcpip.Error {
// The packet already has an IP header, but there are a few required
// checks.
h, ok := pkt.Data.PullUp(header.IPv4MinimumSize)
if !ok {
- return tcpip.ErrInvalidOptionValue
+ return tcpip.ErrMalformedHeader
}
ip := header.IPv4(h)
- if !ip.IsValid(pkt.Data.Size()) {
- return tcpip.ErrInvalidOptionValue
- }
// Always set the total length.
- ip.SetTotalLength(uint16(pkt.Data.Size()))
+ pktSize := pkt.Data.Size()
+ ip.SetTotalLength(uint16(pktSize))
// Set the source address when zero.
- if ip.SourceAddress() == tcpip.Address(([]byte{0, 0, 0, 0})) {
+ if ip.SourceAddress() == header.IPv4Any {
ip.SetSourceAddress(r.LocalAddress)
}
- // Set the destination. If the packet already included a destination,
- // it will be part of the route.
+ // Set the destination. If the packet already included a destination, it will
+ // be part of the route anyways.
ip.SetDestinationAddress(r.RemoteAddress)
// Set the packet ID when zero.
@@ -387,19 +414,17 @@ func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack.PacketBu
ip.SetChecksum(0)
ip.SetChecksum(^ip.CalculateChecksum())
- if r.Loop&stack.PacketLoop != 0 {
- e.HandlePacket(r, pkt.Clone())
- }
- if r.Loop&stack.PacketOut == 0 {
- return nil
+ // Populate the packet buffer's network header and don't allow an invalid
+ // packet to be sent.
+ //
+ // Note that parsing only makes sure that the packet is well formed as per the
+ // wire format. We also want to check if the header's fields are valid before
+ // sending the packet.
+ if !parse.IPv4(pkt) || !header.IPv4(pkt.NetworkHeader().View()).IsValid(pktSize) {
+ return tcpip.ErrMalformedHeader
}
- if err := e.nic.WritePacket(r, nil /* gso */, ProtocolNumber, pkt); err != nil {
- r.Stats().IP.OutgoingPacketErrors.Increment()
- return err
- }
- r.Stats().IP.PacketsSent.Increment()
- return nil
+ return e.writePacket(r, nil /* gso */, pkt)
}
// HandlePacket is called by the link layer when new ipv4 packets arrive for
@@ -415,6 +440,32 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) {
return
}
+ // There has been some confusion regarding verifying checksums. We only need
+ // to look for negative 0 (0xffff) as the checksum, as it's not possible to
+ // get positive 0 (0) for the checksum. Some bad implementations could get it
+ // when doing entry replacement in the early days of the Internet,
+ // however the lore that one needs to check for both persists.
+ //
+ // RFC 1624 section 1 describes the source of this confusion as:
+ // [the partial recalculation method described in RFC 1071] computes a
+ // result for certain cases that differs from the one obtained from
+ // scratch (one's complement of one's complement sum of the original
+ // fields).
+ //
+ // However RFC 1624 section 5 clarifies that if using the verification method
+ // "recommended by RFC 1071, it does not matter if an intermediate system
+ // generated a -0 instead of +0".
+ //
+ // RFC1071 page 1 specifies the verification method as:
+ // (3) To check a checksum, the 1's complement sum is computed over the
+ // same set of octets, including the checksum field. If the result
+ // is all 1 bits (-0 in 1's complement arithmetic), the check
+ // succeeds.
+ if h.CalculateChecksum() != 0xffff {
+ r.Stats().IP.MalformedPacketsReceived.Increment()
+ return
+ }
+
// As per RFC 1122 section 3.2.1.3:
// When a host sends any datagram, the IP source address MUST
// be one of its own IP addresses (but not a broadcast or
diff --git a/pkg/tcpip/network/ipv4/ipv4_test.go b/pkg/tcpip/network/ipv4/ipv4_test.go
index 9916d783f..fee11bb38 100644
--- a/pkg/tcpip/network/ipv4/ipv4_test.go
+++ b/pkg/tcpip/network/ipv4/ipv4_test.go
@@ -15,9 +15,9 @@
package ipv4_test
import (
- "bytes"
"context"
"encoding/hex"
+ "fmt"
"math"
"net"
"testing"
@@ -39,6 +39,8 @@ import (
"gvisor.dev/gvisor/pkg/waiter"
)
+const extraHeaderReserve = 50
+
func TestExcludeBroadcast(t *testing.T) {
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol},
@@ -118,6 +120,7 @@ func TestIPv4Sanity(t *testing.T) {
tests := []struct {
name string
headerLength uint8 // value of 0 means "use correct size"
+ badHeaderChecksum bool
maxTotalLength uint16
transportProtocol uint8
TTL uint8
@@ -133,6 +136,14 @@ func TestIPv4Sanity(t *testing.T) {
transportProtocol: uint8(header.ICMPv4ProtocolNumber),
TTL: ttl,
},
+ {
+ name: "bad header checksum",
+ maxTotalLength: defaultMTU,
+ transportProtocol: uint8(header.ICMPv4ProtocolNumber),
+ TTL: ttl,
+ badHeaderChecksum: true,
+ shouldFail: true,
+ },
// The TTL tests check that we are not rejecting an incoming packet
// with a zero or one TTL, which has been a point of confusion in the
// past as RFC 791 says: "If this field contains the value zero, then the
@@ -243,7 +254,7 @@ func TestIPv4Sanity(t *testing.T) {
// Default routes for IPv4 so ICMP can find a route to the remote
// node when attempting to send the ICMP Echo Reply.
s.SetRouteTable([]tcpip.Route{
- tcpip.Route{
+ {
Destination: header.IPv4EmptySubnet,
NIC: nicID,
},
@@ -288,6 +299,12 @@ func TestIPv4Sanity(t *testing.T) {
if test.headerLength != 0 {
ip.SetHeaderLength(test.headerLength)
}
+ ip.SetChecksum(0)
+ ipHeaderChecksum := ip.CalculateChecksum()
+ if test.badHeaderChecksum {
+ ipHeaderChecksum += 42
+ }
+ ip.SetChecksum(^ipHeaderChecksum)
requestPkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
Data: hdr.View().ToVectorisedView(),
})
@@ -369,11 +386,10 @@ func TestIPv4Sanity(t *testing.T) {
// compareFragments compares the contents of all the packets against the contents
// of the source packet.
-func compareFragments(t *testing.T, packets []*stack.PacketBuffer, sourcePacketInfo *stack.PacketBuffer, mtu uint32) {
- t.Helper()
- // Make a complete array of the sourcePacketInfo packet.
- source := header.IPv4(packets[0].NetworkHeader().View()[:header.IPv4MinimumSize])
- vv := buffer.NewVectorisedView(sourcePacketInfo.Size(), sourcePacketInfo.Views())
+func compareFragments(packets []*stack.PacketBuffer, sourcePacket *stack.PacketBuffer, mtu uint32, wantFragments []fragmentInfo, proto tcpip.TransportProtocolNumber) error {
+ // Make a complete array of the source packet.
+ source := header.IPv4(packets[0].NetworkHeader().View())
+ vv := buffer.NewVectorisedView(sourcePacket.Size(), sourcePacket.Views())
source = append(source, vv.ToView()...)
// Make a copy of the IP header, which will be modified in some fields to make
@@ -382,82 +398,147 @@ func compareFragments(t *testing.T, packets []*stack.PacketBuffer, sourcePacketI
sourceCopy.SetChecksum(0)
sourceCopy.SetFlagsFragmentOffset(0, 0)
sourceCopy.SetTotalLength(0)
- var offset uint16
// Build up an array of the bytes sent.
- var reassembledPayload []byte
+ var reassembledPayload buffer.VectorisedView
for i, packet := range packets {
// Confirm that the packet is valid.
allBytes := buffer.NewVectorisedView(packet.Size(), packet.Views())
- ip := header.IPv4(allBytes.ToView())
- if !ip.IsValid(len(ip)) {
- t.Errorf("IP packet is invalid:\n%s", hex.Dump(ip))
+ fragmentIPHeader := header.IPv4(allBytes.ToView())
+ if !fragmentIPHeader.IsValid(len(fragmentIPHeader)) {
+ return fmt.Errorf("fragment #%d: IP packet is invalid:\n%s", i, hex.Dump(fragmentIPHeader))
}
- if got, want := ip.CalculateChecksum(), uint16(0xffff); got != want {
- t.Errorf("ip.CalculateChecksum() got %#x, want %#x", got, want)
+ if got := len(fragmentIPHeader); got > int(mtu) {
+ return fmt.Errorf("fragment #%d: got len(fragmentIPHeader) = %d, want <= %d", i, got, mtu)
}
- if got, want := len(ip), int(mtu); got > want {
- t.Errorf("fragment is too large, got %d want %d", got, want)
+ if got := fragmentIPHeader.TransportProtocol(); got != proto {
+ return fmt.Errorf("fragment #%d: got fragmentIPHeader.TransportProtocol() = %d, want = %d", i, got, uint8(proto))
}
- if got, want := packet.AvailableHeaderBytes(), sourcePacketInfo.AvailableHeaderBytes()-header.IPv4MinimumSize; got != want {
- t.Errorf("fragment #%d should have the same available space for prepending as source: got %d, want %d", i, got, want)
+ if got := packet.AvailableHeaderBytes(); got != extraHeaderReserve {
+ return fmt.Errorf("fragment #%d: got packet.AvailableHeaderBytes() = %d, want = %d", i, got, extraHeaderReserve)
}
- if got, want := packet.NetworkProtocolNumber, sourcePacketInfo.NetworkProtocolNumber; got != want {
- t.Errorf("fragment #%d has wrong network protocol number: got %d, want %d", i, got, want)
+ if got, want := packet.NetworkProtocolNumber, sourcePacket.NetworkProtocolNumber; got != want {
+ return fmt.Errorf("fragment #%d: got fragment.NetworkProtocolNumber = %d, want = %d", i, got, want)
}
- if i < len(packets)-1 {
- sourceCopy.SetFlagsFragmentOffset(sourceCopy.Flags()|header.IPv4FlagMoreFragments, offset)
+ if got, want := fragmentIPHeader.CalculateChecksum(), uint16(0xffff); got != want {
+ return fmt.Errorf("fragment #%d: got ip.CalculateChecksum() = %#x, want = %#x", i, got, want)
+ }
+ if wantFragments[i].more {
+ sourceCopy.SetFlagsFragmentOffset(sourceCopy.Flags()|header.IPv4FlagMoreFragments, wantFragments[i].offset)
} else {
- sourceCopy.SetFlagsFragmentOffset(sourceCopy.Flags()&^header.IPv4FlagMoreFragments, offset)
+ sourceCopy.SetFlagsFragmentOffset(sourceCopy.Flags()&^header.IPv4FlagMoreFragments, wantFragments[i].offset)
}
- reassembledPayload = append(reassembledPayload, ip.Payload()...)
- offset += ip.TotalLength() - uint16(ip.HeaderLength())
+ reassembledPayload.AppendView(packet.TransportHeader().View())
+ reassembledPayload.Append(packet.Data)
// Clear out the checksum and length from the ip because we can't compare
// it.
- sourceCopy.SetTotalLength(uint16(len(ip)))
+ sourceCopy.SetTotalLength(wantFragments[i].payloadSize + header.IPv4MinimumSize)
sourceCopy.SetChecksum(0)
sourceCopy.SetChecksum(^sourceCopy.CalculateChecksum())
- if !bytes.Equal(ip[:ip.HeaderLength()], sourceCopy[:sourceCopy.HeaderLength()]) {
- t.Errorf("ip[:ip.HeaderLength()] got:\n%s\nwant:\n%s", hex.Dump(ip[:ip.HeaderLength()]), hex.Dump(sourceCopy[:sourceCopy.HeaderLength()]))
+ if diff := cmp.Diff(fragmentIPHeader[:fragmentIPHeader.HeaderLength()], sourceCopy[:sourceCopy.HeaderLength()]); diff != "" {
+ return fmt.Errorf("fragment #%d: fragmentIPHeader mismatch (-want +got):\n%s", i, diff)
}
}
- expected := source[source.HeaderLength():]
- if !bytes.Equal(reassembledPayload, expected) {
- t.Errorf("reassembledPayload got:\n%s\nwant:\n%s", hex.Dump(reassembledPayload), hex.Dump(expected))
+
+ expected := buffer.View(source[source.HeaderLength():])
+ if diff := cmp.Diff(expected, reassembledPayload.ToView()); diff != "" {
+ return fmt.Errorf("reassembledPayload mismatch (-want +got):\n%s", diff)
}
+
+ return nil
}
-func TestFragmentation(t *testing.T) {
- const ttl = 42
+type fragmentInfo struct {
+ offset uint16
+ more bool
+ payloadSize uint16
+}
- var manyPayloadViewsSizes [1000]int
- for i := range manyPayloadViewsSizes {
- manyPayloadViewsSizes[i] = 7
- }
- fragTests := []struct {
- description string
- mtu uint32
- gso *stack.GSO
- transportHeaderLength int
- extraHeaderReserveLength int
- payloadViewsSizes []int
- expectedFrags int
- }{
- {"No fragmentation", 2000, &stack.GSO{}, 0, header.IPv4MinimumSize, []int{1000}, 1},
- {"No fragmentation with big header", 2000, &stack.GSO{}, 16, header.IPv4MinimumSize, []int{1000}, 1},
- {"Fragmented", 800, &stack.GSO{}, 0, header.IPv4MinimumSize, []int{1000}, 2},
- {"Fragmented with gso nil", 800, nil, 0, header.IPv4MinimumSize, []int{1000}, 2},
- {"Fragmented with many views", 300, &stack.GSO{}, 0, header.IPv4MinimumSize, manyPayloadViewsSizes[:], 25},
- {"Fragmented with many views and prependable bytes", 300, &stack.GSO{}, 0, header.IPv4MinimumSize + 55, manyPayloadViewsSizes[:], 25},
- {"Fragmented with big header", 800, &stack.GSO{}, 20, header.IPv4MinimumSize, []int{1000}, 2},
- {"Fragmented with big header and prependable bytes", 800, &stack.GSO{}, 20, header.IPv4MinimumSize + 66, []int{1000}, 2},
- {"Fragmented with MTU smaller than header and prependable bytes", 300, &stack.GSO{}, 1000, header.IPv4MinimumSize + 77, []int{500}, 6},
- }
+var fragmentationTests = []struct {
+ description string
+ mtu uint32
+ gso *stack.GSO
+ transportHeaderLength int
+ payloadSize int
+ wantFragments []fragmentInfo
+}{
+ {
+ description: "No Fragmentation",
+ mtu: 1280,
+ gso: nil,
+ transportHeaderLength: 0,
+ payloadSize: 1000,
+ wantFragments: []fragmentInfo{
+ {offset: 0, payloadSize: 1000, more: false},
+ },
+ },
+ {
+ description: "Fragmented",
+ mtu: 1280,
+ gso: nil,
+ transportHeaderLength: 0,
+ payloadSize: 2000,
+ wantFragments: []fragmentInfo{
+ {offset: 0, payloadSize: 1256, more: true},
+ {offset: 1256, payloadSize: 744, more: false},
+ },
+ },
+ {
+ description: "No fragmentation with big header",
+ mtu: 2000,
+ gso: nil,
+ transportHeaderLength: 100,
+ payloadSize: 1000,
+ wantFragments: []fragmentInfo{
+ {offset: 0, payloadSize: 1100, more: false},
+ },
+ },
+ {
+ description: "Fragmented with gso none",
+ mtu: 1280,
+ gso: &stack.GSO{Type: stack.GSONone},
+ transportHeaderLength: 0,
+ payloadSize: 1400,
+ wantFragments: []fragmentInfo{
+ {offset: 0, payloadSize: 1256, more: true},
+ {offset: 1256, payloadSize: 144, more: false},
+ },
+ },
+ {
+ description: "Fragmented with big header",
+ mtu: 1280,
+ gso: nil,
+ transportHeaderLength: 100,
+ payloadSize: 1200,
+ wantFragments: []fragmentInfo{
+ {offset: 0, payloadSize: 1256, more: true},
+ {offset: 1256, payloadSize: 44, more: false},
+ },
+ },
+ {
+ description: "Fragmented with MTU smaller than header",
+ mtu: 300,
+ gso: nil,
+ transportHeaderLength: 1000,
+ payloadSize: 500,
+ wantFragments: []fragmentInfo{
+ {offset: 0, payloadSize: 280, more: true},
+ {offset: 280, payloadSize: 280, more: true},
+ {offset: 560, payloadSize: 280, more: true},
+ {offset: 840, payloadSize: 280, more: true},
+ {offset: 1120, payloadSize: 280, more: true},
+ {offset: 1400, payloadSize: 100, more: false},
+ },
+ },
+}
- for _, ft := range fragTests {
+func TestFragmentationWritePacket(t *testing.T) {
+ const ttl = 42
+
+ for _, ft := range fragmentationTests {
t.Run(ft.description, func(t *testing.T) {
ep := testutil.NewMockLinkEndpoint(ft.mtu, nil, math.MaxInt32)
r := buildRoute(t, ep)
- pkt := testutil.MakeRandPkt(ft.transportHeaderLength, ft.extraHeaderReserveLength, ft.payloadViewsSizes, header.IPv4ProtocolNumber)
+ pkt := testutil.MakeRandPkt(ft.transportHeaderLength, extraHeaderReserve+header.IPv4MinimumSize, []int{ft.payloadSize}, header.IPv4ProtocolNumber)
source := pkt.Clone()
err := r.WritePacket(ft.gso, stack.NetworkHeaderParams{
Protocol: tcp.ProtocolNumber,
@@ -467,17 +548,101 @@ func TestFragmentation(t *testing.T) {
if err != nil {
t.Fatalf("r.WritePacket(_, _, _) = %s", err)
}
-
- if got := len(ep.WrittenPackets); got != ft.expectedFrags {
- t.Errorf("got len(ep.WrittenPackets) = %d, want = %d", got, ft.expectedFrags)
+ if got := len(ep.WrittenPackets); got != len(ft.wantFragments) {
+ t.Errorf("got len(ep.WrittenPackets) = %d, want = %d", got, len(ft.wantFragments))
}
- if got, want := len(ep.WrittenPackets), int(r.Stats().IP.PacketsSent.Value()); got != want {
- t.Errorf("no errors yet got len(ep.WrittenPackets) = %d, want = %d", got, want)
+ if got := int(r.Stats().IP.PacketsSent.Value()); got != len(ft.wantFragments) {
+ t.Errorf("got c.Route.Stats().IP.PacketsSent.Value() = %d, want = %d", got, len(ft.wantFragments))
}
if got := r.Stats().IP.OutgoingPacketErrors.Value(); got != 0 {
t.Errorf("got r.Stats().IP.OutgoingPacketErrors.Value() = %d, want = 0", got)
}
- compareFragments(t, ep.WrittenPackets, source, ft.mtu)
+ if err := compareFragments(ep.WrittenPackets, source, ft.mtu, ft.wantFragments, tcp.ProtocolNumber); err != nil {
+ t.Error(err)
+ }
+ })
+ }
+}
+
+func TestFragmentationWritePackets(t *testing.T) {
+ const ttl = 42
+ writePacketsTests := []struct {
+ description string
+ insertBefore int
+ insertAfter int
+ }{
+ {
+ description: "Single packet",
+ insertBefore: 0,
+ insertAfter: 0,
+ },
+ {
+ description: "With packet before",
+ insertBefore: 1,
+ insertAfter: 0,
+ },
+ {
+ description: "With packet after",
+ insertBefore: 0,
+ insertAfter: 1,
+ },
+ {
+ description: "With packet before and after",
+ insertBefore: 1,
+ insertAfter: 1,
+ },
+ }
+ tinyPacket := testutil.MakeRandPkt(header.TCPMinimumSize, extraHeaderReserve+header.IPv4MinimumSize, []int{1}, header.IPv4ProtocolNumber)
+
+ for _, test := range writePacketsTests {
+ t.Run(test.description, func(t *testing.T) {
+ for _, ft := range fragmentationTests {
+ t.Run(ft.description, func(t *testing.T) {
+ var pkts stack.PacketBufferList
+ for i := 0; i < test.insertBefore; i++ {
+ pkts.PushBack(tinyPacket.Clone())
+ }
+ pkt := testutil.MakeRandPkt(ft.transportHeaderLength, extraHeaderReserve+header.IPv4MinimumSize, []int{ft.payloadSize}, header.IPv4ProtocolNumber)
+ pkts.PushBack(pkt.Clone())
+ for i := 0; i < test.insertAfter; i++ {
+ pkts.PushBack(tinyPacket.Clone())
+ }
+
+ ep := testutil.NewMockLinkEndpoint(ft.mtu, nil, math.MaxInt32)
+ r := buildRoute(t, ep)
+
+ wantTotalPackets := len(ft.wantFragments) + test.insertBefore + test.insertAfter
+ n, err := r.WritePackets(ft.gso, pkts, stack.NetworkHeaderParams{
+ Protocol: tcp.ProtocolNumber,
+ TTL: ttl,
+ TOS: stack.DefaultTOS,
+ })
+ if err != nil {
+ t.Errorf("got WritePackets(_, _, _) = (_, %s), want = (_, nil)", err)
+ }
+ if n != wantTotalPackets {
+ t.Errorf("got WritePackets(_, _, _) = (%d, _), want = (%d, _)", n, wantTotalPackets)
+ }
+ if got := len(ep.WrittenPackets); got != wantTotalPackets {
+ t.Errorf("got len(ep.WrittenPackets) = %d, want = %d", got, wantTotalPackets)
+ }
+ if got := int(r.Stats().IP.PacketsSent.Value()); got != wantTotalPackets {
+ t.Errorf("got c.Route.Stats().IP.PacketsSent.Value() = %d, want = %d", got, wantTotalPackets)
+ }
+ if got := int(r.Stats().IP.OutgoingPacketErrors.Value()); got != 0 {
+ t.Errorf("got r.Stats().IP.OutgoingPacketErrors.Value() = %d, want = 0", got)
+ }
+
+ if wantTotalPackets == 0 {
+ return
+ }
+
+ fragments := ep.WrittenPackets[test.insertBefore : len(ft.wantFragments)+test.insertBefore]
+ if err := compareFragments(fragments, pkt, ft.mtu, ft.wantFragments, tcp.ProtocolNumber); err != nil {
+ t.Error(err)
+ }
+ })
+ }
})
}
}
@@ -534,14 +699,14 @@ func TestFragmentationErrors(t *testing.T) {
t.Run(ft.description, func(t *testing.T) {
ep := testutil.NewMockLinkEndpoint(ft.mtu, expectedError, ft.allowPackets)
r := buildRoute(t, ep)
- pkt := testutil.MakeRandPkt(ft.transportHeaderLength, header.IPv4MinimumSize, []int{ft.payloadSize}, header.IPv4ProtocolNumber)
+ pkt := testutil.MakeRandPkt(ft.transportHeaderLength, extraHeaderReserve+header.IPv4MinimumSize, []int{ft.payloadSize}, header.IPv4ProtocolNumber)
err := r.WritePacket(&stack.GSO{}, stack.NetworkHeaderParams{
Protocol: tcp.ProtocolNumber,
TTL: ttl,
TOS: stack.DefaultTOS,
}, pkt)
if err != expectedError {
- t.Errorf("got WritePacket() = %s, want = %s", err, expectedError)
+ t.Errorf("got WritePacket(_, _, _) = %s, want = %s", err, expectedError)
}
if got, want := len(ep.WrittenPackets), int(r.Stats().IP.PacketsSent.Value()); err != nil && got != want {
t.Errorf("got len(ep.WrittenPackets) = %d, want = %d", got, want)
@@ -1277,6 +1442,7 @@ func TestReceiveFragments(t *testing.T) {
SrcAddr: frag.srcAddr,
DstAddr: frag.dstAddr,
})
+ ip.SetChecksum(^ip.CalculateChecksum())
vv := hdr.View().ToVectorisedView()
vv.AppendView(frag.payload)
@@ -1545,6 +1711,7 @@ func TestPacketQueing(t *testing.T) {
SrcAddr: host2IPv4Addr.AddressWithPrefix.Address,
DstAddr: host1IPv4Addr.AddressWithPrefix.Address,
})
+ ip.SetChecksum(^ip.CalculateChecksum())
e.InjectInbound(ipv4.ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
Data: hdr.View().ToVectorisedView(),
}))
@@ -1588,6 +1755,7 @@ func TestPacketQueing(t *testing.T) {
SrcAddr: host2IPv4Addr.AddressWithPrefix.Address,
DstAddr: host1IPv4Addr.AddressWithPrefix.Address,
})
+ ip.SetChecksum(^ip.CalculateChecksum())
e.InjectInbound(header.IPv4ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
Data: hdr.View().ToVectorisedView(),
}))
@@ -1633,7 +1801,7 @@ func TestPacketQueing(t *testing.T) {
}
s.SetRouteTable([]tcpip.Route{
- tcpip.Route{
+ {
Destination: host1IPv4Addr.AddressWithPrefix.Subnet(),
NIC: nicID,
},
diff --git a/pkg/tcpip/network/ipv6/icmp.go b/pkg/tcpip/network/ipv6/icmp.go
index 913b2140c..ead6bedcb 100644
--- a/pkg/tcpip/network/ipv6/icmp.go
+++ b/pkg/tcpip/network/ipv6/icmp.go
@@ -439,19 +439,19 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme
// If the NA message has the target link layer option, update the link
// address cache with the link address for the target of the message.
- if len(targetLinkAddr) != 0 {
- if e.nud == nil {
+ if e.nud == nil {
+ if len(targetLinkAddr) != 0 {
e.linkAddrCache.AddLinkAddress(e.nic.ID(), targetAddr, targetLinkAddr)
- return
}
-
- e.nud.HandleConfirmation(targetAddr, targetLinkAddr, stack.ReachabilityConfirmationFlags{
- Solicited: na.SolicitedFlag(),
- Override: na.OverrideFlag(),
- IsRouter: na.RouterFlag(),
- })
+ return
}
+ e.nud.HandleConfirmation(targetAddr, targetLinkAddr, stack.ReachabilityConfirmationFlags{
+ Solicited: na.SolicitedFlag(),
+ Override: na.OverrideFlag(),
+ IsRouter: na.RouterFlag(),
+ })
+
case header.ICMPv6EchoRequest:
received.EchoRequest.Increment()
icmpHdr, ok := pkt.TransportHeader().Consume(header.ICMPv6EchoMinimumSize)
diff --git a/pkg/tcpip/network/ipv6/icmp_test.go b/pkg/tcpip/network/ipv6/icmp_test.go
index 3affcc4e4..8dc33c560 100644
--- a/pkg/tcpip/network/ipv6/icmp_test.go
+++ b/pkg/tcpip/network/ipv6/icmp_test.go
@@ -101,14 +101,19 @@ func (*stubLinkAddressCache) CheckLocalAddress(tcpip.NICID, tcpip.NetworkProtoco
func (*stubLinkAddressCache) AddLinkAddress(tcpip.NICID, tcpip.Address, tcpip.LinkAddress) {
}
-type stubNUDHandler struct{}
+type stubNUDHandler struct {
+ probeCount int
+ confirmationCount int
+}
var _ stack.NUDHandler = (*stubNUDHandler)(nil)
-func (*stubNUDHandler) HandleProbe(remoteAddr, localAddr tcpip.Address, protocol tcpip.NetworkProtocolNumber, remoteLinkAddr tcpip.LinkAddress, linkRes stack.LinkAddressResolver) {
+func (s *stubNUDHandler) HandleProbe(remoteAddr, localAddr tcpip.Address, protocol tcpip.NetworkProtocolNumber, remoteLinkAddr tcpip.LinkAddress, linkRes stack.LinkAddressResolver) {
+ s.probeCount++
}
-func (*stubNUDHandler) HandleConfirmation(addr tcpip.Address, linkAddr tcpip.LinkAddress, flags stack.ReachabilityConfirmationFlags) {
+func (s *stubNUDHandler) HandleConfirmation(addr tcpip.Address, linkAddr tcpip.LinkAddress, flags stack.ReachabilityConfirmationFlags) {
+ s.confirmationCount++
}
func (*stubNUDHandler) HandleUpperLevelConfirmation(addr tcpip.Address) {
@@ -118,6 +123,12 @@ var _ stack.NetworkInterface = (*testInterface)(nil)
type testInterface struct {
stack.NetworkLinkEndpoint
+
+ linkAddr tcpip.LinkAddress
+}
+
+func (i *testInterface) LinkAddress() tcpip.LinkAddress {
+ return i.linkAddr
}
func (*testInterface) ID() tcpip.NICID {
@@ -1492,3 +1503,240 @@ func TestPacketQueing(t *testing.T) {
})
}
}
+
+func TestCallsToNeighborCache(t *testing.T) {
+ tests := []struct {
+ name string
+ createPacket func() header.ICMPv6
+ multicast bool
+ source tcpip.Address
+ destination tcpip.Address
+ wantProbeCount int
+ wantConfirmationCount int
+ }{
+ {
+ name: "Unicast Neighbor Solicitation without source link-layer address option",
+ createPacket: func() header.ICMPv6 {
+ nsSize := header.ICMPv6NeighborSolicitMinimumSize + header.NDPLinkLayerAddressSize
+ icmp := header.ICMPv6(buffer.NewView(nsSize))
+ icmp.SetType(header.ICMPv6NeighborSolicit)
+ ns := header.NDPNeighborSolicit(icmp.NDPPayload())
+ ns.SetTargetAddress(lladdr0)
+ return icmp
+ },
+ source: lladdr1,
+ destination: lladdr0,
+ // "The source link-layer address option SHOULD be included in unicast
+ // solicitations." - RFC 4861 section 4.3
+ //
+ // A Neighbor Advertisement needs to be sent in response, but the
+ // Neighbor Cache shouldn't be updated since we have no useful
+ // information about the sender.
+ wantProbeCount: 0,
+ },
+ {
+ name: "Unicast Neighbor Solicitation with source link-layer address option",
+ createPacket: func() header.ICMPv6 {
+ nsSize := header.ICMPv6NeighborSolicitMinimumSize + header.NDPLinkLayerAddressSize
+ icmp := header.ICMPv6(buffer.NewView(nsSize))
+ icmp.SetType(header.ICMPv6NeighborSolicit)
+ ns := header.NDPNeighborSolicit(icmp.NDPPayload())
+ ns.SetTargetAddress(lladdr0)
+ ns.Options().Serialize(header.NDPOptionsSerializer{
+ header.NDPSourceLinkLayerAddressOption(linkAddr1),
+ })
+ return icmp
+ },
+ source: lladdr1,
+ destination: lladdr0,
+ wantProbeCount: 1,
+ },
+ {
+ name: "Multicast Neighbor Solicitation without source link-layer address option",
+ createPacket: func() header.ICMPv6 {
+ nsSize := header.ICMPv6NeighborSolicitMinimumSize + header.NDPLinkLayerAddressSize
+ icmp := header.ICMPv6(buffer.NewView(nsSize))
+ icmp.SetType(header.ICMPv6NeighborSolicit)
+ ns := header.NDPNeighborSolicit(icmp.NDPPayload())
+ ns.SetTargetAddress(lladdr0)
+ return icmp
+ },
+ source: lladdr1,
+ destination: header.SolicitedNodeAddr(lladdr0),
+ // "The source link-layer address option MUST be included in multicast
+ // solicitations." - RFC 4861 section 4.3
+ wantProbeCount: 0,
+ },
+ {
+ name: "Multicast Neighbor Solicitation with source link-layer address option",
+ createPacket: func() header.ICMPv6 {
+ nsSize := header.ICMPv6NeighborSolicitMinimumSize + header.NDPLinkLayerAddressSize
+ icmp := header.ICMPv6(buffer.NewView(nsSize))
+ icmp.SetType(header.ICMPv6NeighborSolicit)
+ ns := header.NDPNeighborSolicit(icmp.NDPPayload())
+ ns.SetTargetAddress(lladdr0)
+ ns.Options().Serialize(header.NDPOptionsSerializer{
+ header.NDPSourceLinkLayerAddressOption(linkAddr1),
+ })
+ return icmp
+ },
+ source: lladdr1,
+ destination: header.SolicitedNodeAddr(lladdr0),
+ wantProbeCount: 1,
+ },
+ {
+ name: "Unicast Neighbor Advertisement without target link-layer address option",
+ createPacket: func() header.ICMPv6 {
+ naSize := header.ICMPv6NeighborAdvertMinimumSize
+ icmp := header.ICMPv6(buffer.NewView(naSize))
+ icmp.SetType(header.ICMPv6NeighborAdvert)
+ na := header.NDPNeighborAdvert(icmp.NDPPayload())
+ na.SetSolicitedFlag(true)
+ na.SetOverrideFlag(false)
+ na.SetTargetAddress(lladdr1)
+ return icmp
+ },
+ source: lladdr1,
+ destination: lladdr0,
+ // "When responding to unicast solicitations, the target link-layer
+ // address option can be omitted since the sender of the solicitation has
+ // the correct link-layer address; otherwise, it would not be able to
+ // send the unicast solicitation in the first place."
+ // - RFC 4861 section 4.4
+ wantConfirmationCount: 1,
+ },
+ {
+ name: "Unicast Neighbor Advertisement with target link-layer address option",
+ createPacket: func() header.ICMPv6 {
+ naSize := header.ICMPv6NeighborAdvertMinimumSize + header.NDPLinkLayerAddressSize
+ icmp := header.ICMPv6(buffer.NewView(naSize))
+ icmp.SetType(header.ICMPv6NeighborAdvert)
+ na := header.NDPNeighborAdvert(icmp.NDPPayload())
+ na.SetSolicitedFlag(true)
+ na.SetOverrideFlag(false)
+ na.SetTargetAddress(lladdr1)
+ na.Options().Serialize(header.NDPOptionsSerializer{
+ header.NDPTargetLinkLayerAddressOption(linkAddr1),
+ })
+ return icmp
+ },
+ source: lladdr1,
+ destination: lladdr0,
+ wantConfirmationCount: 1,
+ },
+ {
+ name: "Multicast Neighbor Advertisement without target link-layer address option",
+ createPacket: func() header.ICMPv6 {
+ naSize := header.ICMPv6NeighborAdvertMinimumSize + header.NDPLinkLayerAddressSize
+ icmp := header.ICMPv6(buffer.NewView(naSize))
+ icmp.SetType(header.ICMPv6NeighborAdvert)
+ na := header.NDPNeighborAdvert(icmp.NDPPayload())
+ na.SetSolicitedFlag(false)
+ na.SetOverrideFlag(false)
+ na.SetTargetAddress(lladdr1)
+ return icmp
+ },
+ source: lladdr1,
+ destination: header.IPv6AllNodesMulticastAddress,
+ // "Target link-layer address MUST be included for multicast solicitations
+ // in order to avoid infinite Neighbor Solicitation "recursion" when the
+ // peer node does not have a cache entry to return a Neighbor
+ // Advertisements message." - RFC 4861 section 4.4
+ wantConfirmationCount: 0,
+ },
+ {
+ name: "Multicast Neighbor Advertisement with target link-layer address option",
+ createPacket: func() header.ICMPv6 {
+ naSize := header.ICMPv6NeighborAdvertMinimumSize + header.NDPLinkLayerAddressSize
+ icmp := header.ICMPv6(buffer.NewView(naSize))
+ icmp.SetType(header.ICMPv6NeighborAdvert)
+ na := header.NDPNeighborAdvert(icmp.NDPPayload())
+ na.SetSolicitedFlag(false)
+ na.SetOverrideFlag(false)
+ na.SetTargetAddress(lladdr1)
+ na.Options().Serialize(header.NDPOptionsSerializer{
+ header.NDPTargetLinkLayerAddressOption(linkAddr1),
+ })
+ return icmp
+ },
+ source: lladdr1,
+ destination: header.IPv6AllNodesMulticastAddress,
+ wantConfirmationCount: 1,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol},
+ TransportProtocols: []stack.TransportProtocolFactory{icmp.NewProtocol6},
+ UseNeighborCache: true,
+ })
+ {
+ if err := s.CreateNIC(nicID, &stubLinkEndpoint{}); err != nil {
+ t.Fatalf("CreateNIC(_, _) = %s", err)
+ }
+ if err := s.AddAddress(nicID, ProtocolNumber, lladdr0); err != nil {
+ t.Fatalf("AddAddress(_, %d, %s) = %s", ProtocolNumber, lladdr0, err)
+ }
+ }
+ {
+ subnet, err := tcpip.NewSubnet(lladdr1, tcpip.AddressMask(strings.Repeat("\xff", len(lladdr1))))
+ if err != nil {
+ t.Fatal(err)
+ }
+ s.SetRouteTable(
+ []tcpip.Route{{
+ Destination: subnet,
+ NIC: nicID,
+ }},
+ )
+ }
+
+ netProto := s.NetworkProtocolInstance(ProtocolNumber)
+ if netProto == nil {
+ t.Fatalf("cannot find protocol instance for network protocol %d", ProtocolNumber)
+ }
+ nudHandler := &stubNUDHandler{}
+ ep := netProto.NewEndpoint(&testInterface{linkAddr: linkAddr0}, &stubLinkAddressCache{}, nudHandler, &stubDispatcher{})
+ defer ep.Close()
+
+ if err := ep.Enable(); err != nil {
+ t.Fatalf("ep.Enable(): %s", err)
+ }
+
+ r, err := s.FindRoute(nicID, lladdr0, test.source, ProtocolNumber, false /* multicastLoop */)
+ if err != nil {
+ t.Fatalf("FindRoute(%d, %s, %s, _, false) = (_, %s), want = (_, nil)", nicID, lladdr0, lladdr1, err)
+ }
+ defer r.Release()
+
+ // TODO(gvisor.dev/issue/4517): Remove the need for this manual patch.
+ r.LocalAddress = test.destination
+
+ icmp := test.createPacket()
+ icmp.SetChecksum(header.ICMPv6Checksum(icmp, r.RemoteAddress, r.LocalAddress, buffer.VectorisedView{}))
+ pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
+ ReserveHeaderBytes: header.IPv6MinimumSize,
+ Data: buffer.View(icmp).ToVectorisedView(),
+ })
+ ip := header.IPv6(pkt.NetworkHeader().Push(header.IPv6MinimumSize))
+ ip.Encode(&header.IPv6Fields{
+ PayloadLength: uint16(len(icmp)),
+ NextHeader: uint8(header.ICMPv6ProtocolNumber),
+ HopLimit: header.NDPHopLimit,
+ SrcAddr: r.RemoteAddress,
+ DstAddr: r.LocalAddress,
+ })
+ ep.HandlePacket(&r, pkt)
+
+ // Confirm the endpoint calls the correct NUDHandler method.
+ if nudHandler.probeCount != test.wantProbeCount {
+ t.Errorf("got nudHandler.probeCount = %d, want = %d", nudHandler.probeCount, test.wantProbeCount)
+ }
+ if nudHandler.confirmationCount != test.wantConfirmationCount {
+ t.Errorf("got nudHandler.confirmationCount = %d, want = %d", nudHandler.confirmationCount, test.wantConfirmationCount)
+ }
+ })
+ }
+}
diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go
index 2bd8f4ece..9670696c7 100644
--- a/pkg/tcpip/network/ipv6/ipv6.go
+++ b/pkg/tcpip/network/ipv6/ipv6.go
@@ -387,7 +387,7 @@ func (e *endpoint) addIPHeader(r *stack.Route, pkt *stack.PacketBuffer, params s
}
func (e *endpoint) packetMustBeFragmented(pkt *stack.PacketBuffer, gso *stack.GSO) bool {
- return pkt.Size() > int(e.nic.MTU()) && (gso == nil || gso.Type == stack.GSONone)
+ return (gso == nil || gso.Type == stack.GSONone) && pkt.Size() > int(e.nic.MTU())
}
// handleFragments fragments pkt and calls the handler function on each
@@ -416,17 +416,18 @@ func (e *endpoint) handleFragments(r *stack.Route, gso *stack.GSO, mtu uint32, p
}
n++
if !more {
- break
+ return n, pf.RemainingFragmentCount(), nil
}
}
-
- return n, 0, nil
}
// WritePacket writes a packet to the given destination address and protocol.
func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.NetworkHeaderParams, pkt *stack.PacketBuffer) *tcpip.Error {
e.addIPHeader(r, pkt, params)
+ return e.writePacket(r, gso, pkt, params.Protocol)
+}
+func (e *endpoint) writePacket(r *stack.Route, gso *stack.GSO, pkt *stack.PacketBuffer, protocol tcpip.TransportProtocolNumber) *tcpip.Error {
// iptables filtering. All packets that reach here are locally
// generated.
nicName := e.protocol.stack.FindNICNameFromID(e.nic.ID())
@@ -468,7 +469,7 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.Netw
}
if e.packetMustBeFragmented(pkt, gso) {
- sent, remain, err := e.handleFragments(r, gso, e.nic.MTU(), pkt, params.Protocol, func(fragPkt *stack.PacketBuffer) *tcpip.Error {
+ sent, remain, err := e.handleFragments(r, gso, e.nic.MTU(), pkt, protocol, func(fragPkt *stack.PacketBuffer) *tcpip.Error {
// TODO(gvisor.dev/issue/3884): Evaluate whether we want to send each
// fragment one by one using WritePacket() (current strategy) or if we
// want to create a PacketBufferList from the fragments and feed it to
@@ -501,21 +502,20 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe
for pb := pkts.Front(); pb != nil; pb = pb.Next() {
e.addIPHeader(r, pb, params)
if e.packetMustBeFragmented(pb, gso) {
- current := pb
- _, _, err := e.handleFragments(r, gso, e.nic.MTU(), pb, params.Protocol, func(fragPkt *stack.PacketBuffer) *tcpip.Error {
+ // Keep track of the packet that is about to be fragmented so it can be
+ // removed once the fragmentation is done.
+ originalPkt := pb
+ if _, _, err := e.handleFragments(r, gso, e.nic.MTU(), pb, params.Protocol, func(fragPkt *stack.PacketBuffer) *tcpip.Error {
// Modify the packet list in place with the new fragments.
- pkts.InsertAfter(current, fragPkt)
- current = current.Next()
+ pkts.InsertAfter(pb, fragPkt)
+ pb = fragPkt
return nil
- })
- if err != nil {
+ }); err != nil {
r.Stats().IP.OutgoingPacketErrors.IncrementBy(uint64(pkts.Len()))
return 0, err
}
- // The fragmented packet can be released. The rest of the packets can be
- // processed.
- pkts.Remove(pb)
- pb = current
+ // Remove the packet that was just fragmented and process the rest.
+ pkts.Remove(originalPkt)
}
}
@@ -569,11 +569,40 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe
return n + len(dropped), nil
}
-// WriteHeaderIncludedPacker implements stack.NetworkEndpoint. It is not yet
-// supported by IPv6.
-func (*endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack.PacketBuffer) *tcpip.Error {
- // TODO(b/146666412): Support IPv6 header-included packets.
- return tcpip.ErrNotSupported
+// WriteHeaderIncludedPacket implements stack.NetworkEndpoint.
+func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack.PacketBuffer) *tcpip.Error {
+ // The packet already has an IP header, but there are a few required checks.
+ h, ok := pkt.Data.PullUp(header.IPv6MinimumSize)
+ if !ok {
+ return tcpip.ErrMalformedHeader
+ }
+ ip := header.IPv6(h)
+
+ // Always set the payload length.
+ pktSize := pkt.Data.Size()
+ ip.SetPayloadLength(uint16(pktSize - header.IPv6MinimumSize))
+
+ // Set the source address when zero.
+ if ip.SourceAddress() == header.IPv6Any {
+ ip.SetSourceAddress(r.LocalAddress)
+ }
+
+ // Set the destination. If the packet already included a destination, it will
+ // be part of the route anyway.
+ ip.SetDestinationAddress(r.RemoteAddress)
+
+ // Populate the packet buffer's network header and don't allow an invalid
+ // packet to be sent.
+ //
+ // Note that parsing only makes sure that the packet is well formed as per the
+ // wire format. We also want to check if the header's fields are valid before
+ // sending the packet.
+ proto, _, _, _, ok := parse.IPv6(pkt)
+ if !ok || !header.IPv6(pkt.NetworkHeader().View()).IsValid(pktSize) {
+ return tcpip.ErrMalformedHeader
+ }
+
+ return e.writePacket(r, nil /* gso */, pkt, proto)
}
// HandlePacket is called by the link layer when new ipv6 packets arrive for
diff --git a/pkg/tcpip/network/ipv6/ipv6_test.go b/pkg/tcpip/network/ipv6/ipv6_test.go
index bee18d1a8..297868f24 100644
--- a/pkg/tcpip/network/ipv6/ipv6_test.go
+++ b/pkg/tcpip/network/ipv6/ipv6_test.go
@@ -49,6 +49,8 @@ const (
fragmentExtHdrID = uint8(header.IPv6FragmentExtHdrIdentifier)
destinationExtHdrID = uint8(header.IPv6DestinationOptionsExtHdrIdentifier)
noNextHdrID = uint8(header.IPv6NoNextHeaderIdentifier)
+
+ extraHeaderReserve = 50
)
// testReceiveICMP tests receiving an ICMP packet from src to dst. want is the
@@ -181,6 +183,9 @@ func compareFragments(packets []*stack.PacketBuffer, sourcePacket *stack.PacketB
return fmt.Errorf("fragment #%d: fragmentIPHeader mismatch (-want +got):\n%s", i, diff)
}
+ if got := fragment.AvailableHeaderBytes(); got != extraHeaderReserve {
+ return fmt.Errorf("fragment #%d: got packet.AvailableHeaderBytes() = %d, want = %d", i, got, extraHeaderReserve)
+ }
if fragment.NetworkProtocolNumber != sourcePacket.NetworkProtocolNumber {
return fmt.Errorf("fragment #%d: got fragment.NetworkProtocolNumber = %d, want = %d", i, fragment.NetworkProtocolNumber, sourcePacket.NetworkProtocolNumber)
}
@@ -208,8 +213,7 @@ func compareFragments(packets []*stack.PacketBuffer, sourcePacket *stack.PacketB
reassembledPayload.Append(fragment.Data)
}
- result := reassembledPayload.ToView()
- if diff := cmp.Diff(result, buffer.View(source[sourceIPHeadersLen:])); diff != "" {
+ if diff := cmp.Diff(buffer.View(source[sourceIPHeadersLen:]), reassembledPayload.ToView()); diff != "" {
return fmt.Errorf("reassembledPayload mismatch (-want +got):\n%s", diff)
}
@@ -2217,24 +2221,19 @@ type fragmentInfo struct {
payloadSize uint16
}
-type fragmentationTestCase struct {
+var fragmentationTests = []struct {
description string
mtu uint32
gso *stack.GSO
transHdrLen int
- extraHdrLen int
payloadSize int
wantFragments []fragmentInfo
- expectedFrags int
-}
-
-var fragmentationTests = []fragmentationTestCase{
+}{
{
description: "No Fragmentation",
mtu: 1280,
- gso: &stack.GSO{},
+ gso: nil,
transHdrLen: 0,
- extraHdrLen: header.IPv6MinimumSize,
payloadSize: 1000,
wantFragments: []fragmentInfo{
{offset: 0, payloadSize: 1000, more: false},
@@ -2243,9 +2242,8 @@ var fragmentationTests = []fragmentationTestCase{
{
description: "Fragmented",
mtu: 1280,
- gso: &stack.GSO{},
+ gso: nil,
transHdrLen: 0,
- extraHdrLen: header.IPv6MinimumSize,
payloadSize: 2000,
wantFragments: []fragmentInfo{
{offset: 0, payloadSize: 1240, more: true},
@@ -2255,20 +2253,18 @@ var fragmentationTests = []fragmentationTestCase{
{
description: "No fragmentation with big header",
mtu: 2000,
- gso: &stack.GSO{},
+ gso: nil,
transHdrLen: 100,
- extraHdrLen: header.IPv6MinimumSize,
payloadSize: 1000,
wantFragments: []fragmentInfo{
{offset: 0, payloadSize: 1100, more: false},
},
},
{
- description: "Fragmented with gso nil",
+ description: "Fragmented with gso none",
mtu: 1280,
- gso: nil,
+ gso: &stack.GSO{Type: stack.GSONone},
transHdrLen: 0,
- extraHdrLen: header.IPv6MinimumSize,
payloadSize: 1400,
wantFragments: []fragmentInfo{
{offset: 0, payloadSize: 1240, more: true},
@@ -2278,30 +2274,17 @@ var fragmentationTests = []fragmentationTestCase{
{
description: "Fragmented with big header",
mtu: 1280,
- gso: &stack.GSO{},
+ gso: nil,
transHdrLen: 100,
- extraHdrLen: header.IPv6MinimumSize,
payloadSize: 1200,
wantFragments: []fragmentInfo{
{offset: 0, payloadSize: 1240, more: true},
{offset: 154, payloadSize: 76, more: false},
},
},
- {
- description: "Fragmented with big header and prependable bytes",
- mtu: 1280,
- gso: &stack.GSO{},
- transHdrLen: 20,
- extraHdrLen: header.IPv6MinimumSize + 66,
- payloadSize: 1500,
- wantFragments: []fragmentInfo{
- {offset: 0, payloadSize: 1240, more: true},
- {offset: 154, payloadSize: 296, more: false},
- },
- },
}
-func TestFragmentation(t *testing.T) {
+func TestFragmentationWritePacket(t *testing.T) {
const (
ttl = 42
tos = stack.DefaultTOS
@@ -2310,7 +2293,7 @@ func TestFragmentation(t *testing.T) {
for _, ft := range fragmentationTests {
t.Run(ft.description, func(t *testing.T) {
- pkt := testutil.MakeRandPkt(ft.transHdrLen, ft.extraHdrLen, []int{ft.payloadSize}, header.IPv6ProtocolNumber)
+ pkt := testutil.MakeRandPkt(ft.transHdrLen, extraHeaderReserve+header.IPv6MinimumSize, []int{ft.payloadSize}, header.IPv6ProtocolNumber)
source := pkt.Clone()
ep := testutil.NewMockLinkEndpoint(ft.mtu, nil, math.MaxInt32)
r := buildRoute(t, ep)
@@ -2331,10 +2314,8 @@ func TestFragmentation(t *testing.T) {
if got := r.Stats().IP.OutgoingPacketErrors.Value(); got != 0 {
t.Errorf("got r.Stats().IP.OutgoingPacketErrors.Value() = %d, want = 0", got)
}
- if len(ep.WrittenPackets) > 0 {
- if err := compareFragments(ep.WrittenPackets, source, ft.mtu, ft.wantFragments, tcp.ProtocolNumber); err != nil {
- t.Error(err)
- }
+ if err := compareFragments(ep.WrittenPackets, source, ft.mtu, ft.wantFragments, tcp.ProtocolNumber); err != nil {
+ t.Error(err)
}
})
}
@@ -2368,7 +2349,7 @@ func TestFragmentationWritePackets(t *testing.T) {
insertAfter: 1,
},
}
- tinyPacket := testutil.MakeRandPkt(header.TCPMinimumSize, header.IPv6MinimumSize, []int{1}, header.IPv6ProtocolNumber)
+ tinyPacket := testutil.MakeRandPkt(header.TCPMinimumSize, extraHeaderReserve+header.IPv6MinimumSize, []int{1}, header.IPv6ProtocolNumber)
for _, test := range tests {
t.Run(test.description, func(t *testing.T) {
@@ -2378,7 +2359,7 @@ func TestFragmentationWritePackets(t *testing.T) {
for i := 0; i < test.insertBefore; i++ {
pkts.PushBack(tinyPacket.Clone())
}
- pkt := testutil.MakeRandPkt(ft.transHdrLen, ft.extraHdrLen, []int{ft.payloadSize}, header.IPv6ProtocolNumber)
+ pkt := testutil.MakeRandPkt(ft.transHdrLen, extraHeaderReserve+header.IPv6MinimumSize, []int{ft.payloadSize}, header.IPv6ProtocolNumber)
source := pkt
pkts.PushBack(pkt.Clone())
for i := 0; i < test.insertAfter; i++ {
@@ -2480,7 +2461,7 @@ func TestFragmentationErrors(t *testing.T) {
for _, ft := range tests {
t.Run(ft.description, func(t *testing.T) {
- pkt := testutil.MakeRandPkt(ft.transHdrLen, header.IPv6MinimumSize, []int{ft.payloadSize}, header.IPv6ProtocolNumber)
+ pkt := testutil.MakeRandPkt(ft.transHdrLen, extraHeaderReserve+header.IPv6MinimumSize, []int{ft.payloadSize}, header.IPv6ProtocolNumber)
ep := testutil.NewMockLinkEndpoint(ft.mtu, ft.mockError, ft.allowPackets)
r := buildRoute(t, ep)
err := r.WritePacket(&stack.GSO{}, stack.NetworkHeaderParams{
diff --git a/pkg/tcpip/stack/neighbor_entry.go b/pkg/tcpip/stack/neighbor_entry.go
index 4d69a4de1..be61a21af 100644
--- a/pkg/tcpip/stack/neighbor_entry.go
+++ b/pkg/tcpip/stack/neighbor_entry.go
@@ -406,9 +406,9 @@ func (e *neighborEntry) handleConfirmationLocked(linkAddr tcpip.LinkAddress, fla
// INCOMPLETE state." - RFC 4861 section 7.2.5
case Reachable, Stale, Delay, Probe:
- sameLinkAddr := e.neigh.LinkAddr == linkAddr
+ isLinkAddrDifferent := len(linkAddr) != 0 && e.neigh.LinkAddr != linkAddr
- if !sameLinkAddr {
+ if isLinkAddrDifferent {
if !flags.Override {
if e.neigh.State == Reachable {
e.dispatchChangeEventLocked(Stale)
@@ -431,7 +431,7 @@ func (e *neighborEntry) handleConfirmationLocked(linkAddr tcpip.LinkAddress, fla
}
}
- if flags.Solicited && (flags.Override || sameLinkAddr) {
+ if flags.Solicited && (flags.Override || !isLinkAddrDifferent) {
if e.neigh.State != Reachable {
e.dispatchChangeEventLocked(Reachable)
}
diff --git a/pkg/tcpip/stack/neighbor_entry_test.go b/pkg/tcpip/stack/neighbor_entry_test.go
index e79abebca..3ee2a3b31 100644
--- a/pkg/tcpip/stack/neighbor_entry_test.go
+++ b/pkg/tcpip/stack/neighbor_entry_test.go
@@ -83,15 +83,18 @@ func eventDiffOptsWithSort() []cmp.Option {
// | Reachable | Stale | Reachable timer expired | | Changed |
// | Reachable | Stale | Probe or confirmation w/ different address | | Changed |
// | Stale | Reachable | Solicited override confirmation | Update LinkAddr | Changed |
+// | Stale | Reachable | Solicited confirmation w/o address | Notify wakers | Changed |
// | Stale | Stale | Override confirmation | Update LinkAddr | Changed |
// | Stale | Stale | Probe w/ different address | Update LinkAddr | Changed |
// | Stale | Delay | Packet sent | | Changed |
// | Delay | Reachable | Upper-layer confirmation | | Changed |
// | Delay | Reachable | Solicited override confirmation | Update LinkAddr | Changed |
+// | Delay | Reachable | Solicited confirmation w/o address | Notify wakers | Changed |
// | Delay | Stale | Probe or confirmation w/ different address | | Changed |
// | Delay | Probe | Delay timer expired | Send probe | Changed |
// | Probe | Reachable | Solicited override confirmation | Update LinkAddr | Changed |
// | Probe | Reachable | Solicited confirmation w/ same address | Notify wakers | Changed |
+// | Probe | Reachable | Solicited confirmation w/o address | Notify wakers | Changed |
// | Probe | Stale | Probe or confirmation w/ different address | | Changed |
// | Probe | Probe | Retransmit timer expired | Send probe | Changed |
// | Probe | Failed | Max probes sent without reply | Notify wakers | Removed |
@@ -1370,6 +1373,77 @@ func TestEntryStaleToReachableWhenSolicitedOverrideConfirmation(t *testing.T) {
nudDisp.mu.Unlock()
}
+func TestEntryStaleToReachableWhenSolicitedConfirmationWithoutAddress(t *testing.T) {
+ c := DefaultNUDConfigurations()
+ e, nudDisp, linkRes, _ := entryTestSetup(c)
+
+ e.mu.Lock()
+ e.handlePacketQueuedLocked()
+ e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
+ Solicited: false,
+ Override: false,
+ IsRouter: false,
+ })
+ if e.neigh.State != Stale {
+ t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Stale)
+ }
+ e.handleConfirmationLocked("" /* linkAddr */, ReachabilityConfirmationFlags{
+ Solicited: true,
+ Override: false,
+ IsRouter: false,
+ })
+ if e.neigh.State != Reachable {
+ t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Reachable)
+ }
+ if e.neigh.LinkAddr != entryTestLinkAddr1 {
+ t.Errorf("got e.neigh.LinkAddr = %q, want = %q", e.neigh.LinkAddr, entryTestLinkAddr1)
+ }
+ e.mu.Unlock()
+
+ wantProbes := []entryTestProbeInfo{
+ {
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: tcpip.LinkAddress(""),
+ LocalAddress: entryTestAddr2,
+ },
+ }
+ linkRes.mu.Lock()
+ diff := cmp.Diff(linkRes.probes, wantProbes)
+ linkRes.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ }
+
+ wantEvents := []testEntryEventInfo{
+ {
+ EventType: entryTestAdded,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: tcpip.LinkAddress(""),
+ State: Incomplete,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Stale,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Reachable,
+ },
+ }
+ nudDisp.mu.Lock()
+ if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" {
+ t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ }
+ nudDisp.mu.Unlock()
+}
+
func TestEntryStaleToStaleWhenOverrideConfirmation(t *testing.T) {
c := DefaultNUDConfigurations()
e, nudDisp, linkRes, _ := entryTestSetup(c)
@@ -1752,6 +1826,100 @@ func TestEntryDelayToReachableWhenSolicitedOverrideConfirmation(t *testing.T) {
nudDisp.mu.Unlock()
}
+func TestEntryDelayToReachableWhenSolicitedConfirmationWithoutAddress(t *testing.T) {
+ c := DefaultNUDConfigurations()
+ c.MaxMulticastProbes = 1
+ // Eliminate random factors from ReachableTime computation so the transition
+ // from Reachable to Stale will only take BaseReachableTime duration.
+ c.MinRandomFactor = 1
+ c.MaxRandomFactor = 1
+
+ e, nudDisp, linkRes, clock := entryTestSetup(c)
+
+ e.mu.Lock()
+ e.handlePacketQueuedLocked()
+ e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
+ Solicited: false,
+ Override: false,
+ IsRouter: false,
+ })
+ e.handlePacketQueuedLocked()
+ if e.neigh.State != Delay {
+ t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Delay)
+ }
+ e.handleConfirmationLocked("" /* linkAddr */, ReachabilityConfirmationFlags{
+ Solicited: true,
+ Override: false,
+ IsRouter: false,
+ })
+ if e.neigh.State != Reachable {
+ t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Reachable)
+ }
+ if e.neigh.LinkAddr != entryTestLinkAddr1 {
+ t.Errorf("got e.neigh.LinkAddr = %q, want = %q", e.neigh.LinkAddr, entryTestLinkAddr1)
+ }
+ e.mu.Unlock()
+
+ wantProbes := []entryTestProbeInfo{
+ {
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: tcpip.LinkAddress(""),
+ LocalAddress: entryTestAddr2,
+ },
+ }
+ linkRes.mu.Lock()
+ diff := cmp.Diff(linkRes.probes, wantProbes)
+ linkRes.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ }
+
+ clock.Advance(c.BaseReachableTime)
+
+ wantEvents := []testEntryEventInfo{
+ {
+ EventType: entryTestAdded,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: tcpip.LinkAddress(""),
+ State: Incomplete,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Stale,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Delay,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Reachable,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Stale,
+ },
+ }
+ nudDisp.mu.Lock()
+ if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" {
+ t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ }
+ nudDisp.mu.Unlock()
+}
+
func TestEntryStaysDelayWhenOverrideConfirmationWithSameAddress(t *testing.T) {
c := DefaultNUDConfigurations()
e, nudDisp, linkRes, _ := entryTestSetup(c)
@@ -2665,6 +2833,115 @@ func TestEntryProbeToReachableWhenSolicitedConfirmationWithSameAddress(t *testin
nudDisp.mu.Unlock()
}
+func TestEntryProbeToReachableWhenSolicitedConfirmationWithoutAddress(t *testing.T) {
+ c := DefaultNUDConfigurations()
+ // Eliminate random factors from ReachableTime computation so the transition
+ // from Reachable to Stale will only take BaseReachableTime duration.
+ c.MinRandomFactor = 1
+ c.MaxRandomFactor = 1
+
+ e, nudDisp, linkRes, clock := entryTestSetup(c)
+
+ e.mu.Lock()
+ e.handlePacketQueuedLocked()
+ e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
+ Solicited: false,
+ Override: false,
+ IsRouter: false,
+ })
+ e.handlePacketQueuedLocked()
+ e.mu.Unlock()
+
+ clock.Advance(c.DelayFirstProbeTime)
+
+ wantProbes := []entryTestProbeInfo{
+ // The first probe is caused by the Unknown-to-Incomplete transition.
+ {
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: tcpip.LinkAddress(""),
+ LocalAddress: entryTestAddr2,
+ },
+ // The second probe is caused by the Delay-to-Probe transition.
+ {
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: entryTestLinkAddr1,
+ LocalAddress: entryTestAddr2,
+ },
+ }
+ linkRes.mu.Lock()
+ diff := cmp.Diff(linkRes.probes, wantProbes)
+ linkRes.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ }
+
+ e.mu.Lock()
+ if e.neigh.State != Probe {
+ t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Probe)
+ }
+ e.handleConfirmationLocked("" /* linkAddr */, ReachabilityConfirmationFlags{
+ Solicited: true,
+ Override: false,
+ IsRouter: false,
+ })
+ if e.neigh.State != Reachable {
+ t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Reachable)
+ }
+ e.mu.Unlock()
+
+ clock.Advance(c.BaseReachableTime)
+
+ wantEvents := []testEntryEventInfo{
+ {
+ EventType: entryTestAdded,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: tcpip.LinkAddress(""),
+ State: Incomplete,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Stale,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Delay,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Probe,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Reachable,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Stale,
+ },
+ }
+ nudDisp.mu.Lock()
+ if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" {
+ t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ }
+ nudDisp.mu.Unlock()
+}
+
func TestEntryProbeToFailed(t *testing.T) {
c := DefaultNUDConfigurations()
c.MaxMulticastProbes = 3
diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go
index 8828cc5fe..dcd4319bf 100644
--- a/pkg/tcpip/stack/nic.go
+++ b/pkg/tcpip/stack/nic.go
@@ -23,7 +23,6 @@ import (
"gvisor.dev/gvisor/pkg/sleep"
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
- "gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
)
@@ -686,7 +685,9 @@ func (n *NIC) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcp
// packet to forward.
fwdPkt := NewPacketBuffer(PacketBufferOptions{
ReserveHeaderBytes: int(n.LinkEndpoint.MaxHeaderLength()),
- Data: buffer.NewVectorisedView(pkt.Size(), pkt.Views()),
+ // We need to do a deep copy of the IP packet because WritePacket (and
+ // friends) take ownership of the packet buffer, but we do not own it.
+ Data: PayloadSince(pkt.NetworkHeader()).ToVectorisedView(),
})
// TODO(b/143425874) Decrease the TTL field in forwarded packets.
diff --git a/pkg/tcpip/stack/packet_buffer.go b/pkg/tcpip/stack/packet_buffer.go
index 105583c49..7f54a6de8 100644
--- a/pkg/tcpip/stack/packet_buffer.go
+++ b/pkg/tcpip/stack/packet_buffer.go
@@ -311,11 +311,25 @@ func (h PacketHeader) Consume(size int) (v buffer.View, consumed bool) {
}
// PayloadSince returns packet payload starting from and including a particular
-// header. This method isn't optimized and should be used in test only.
+// header.
+//
+// The returned View is owned by the caller - its backing buffer is separate
+// from the packet header's underlying packet buffer.
func PayloadSince(h PacketHeader) buffer.View {
- var v buffer.View
+ size := h.pk.Data.Size()
+ for _, hinfo := range h.pk.headers[h.typ:] {
+ size += len(hinfo.buf)
+ }
+
+ v := make(buffer.View, 0, size)
+
for _, hinfo := range h.pk.headers[h.typ:] {
v = append(v, hinfo.buf...)
}
- return append(v, h.pk.Data.ToView()...)
+
+ for _, view := range h.pk.Data.Views() {
+ v = append(v, view...)
+ }
+
+ return v
}
diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go
index c42bb0991..d77848d61 100644
--- a/pkg/tcpip/tcpip.go
+++ b/pkg/tcpip/tcpip.go
@@ -111,6 +111,7 @@ var (
ErrBroadcastDisabled = &Error{msg: "broadcast socket option disabled"}
ErrNotPermitted = &Error{msg: "operation not permitted"}
ErrAddressFamilyNotSupported = &Error{msg: "address family not supported by protocol"}
+ ErrMalformedHeader = &Error{msg: "header is malformed"}
)
var messageToError map[string]*Error
@@ -159,6 +160,7 @@ func StringToError(s string) *Error {
ErrBroadcastDisabled,
ErrNotPermitted,
ErrAddressFamilyNotSupported,
+ ErrMalformedHeader,
}
messageToError = make(map[string]*Error)
diff --git a/pkg/tcpip/tests/integration/BUILD b/pkg/tcpip/tests/integration/BUILD
index a4f141253..34aab32d0 100644
--- a/pkg/tcpip/tests/integration/BUILD
+++ b/pkg/tcpip/tests/integration/BUILD
@@ -16,6 +16,7 @@ go_test(
"//pkg/tcpip/buffer",
"//pkg/tcpip/header",
"//pkg/tcpip/link/channel",
+ "//pkg/tcpip/link/ethernet",
"//pkg/tcpip/link/loopback",
"//pkg/tcpip/link/pipe",
"//pkg/tcpip/network/arp",
diff --git a/pkg/tcpip/tests/integration/forward_test.go b/pkg/tcpip/tests/integration/forward_test.go
index ffd38ee1a..0dcef7b04 100644
--- a/pkg/tcpip/tests/integration/forward_test.go
+++ b/pkg/tcpip/tests/integration/forward_test.go
@@ -21,6 +21,7 @@ import (
"github.com/google/go-cmp/cmp"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/link/ethernet"
"gvisor.dev/gvisor/pkg/tcpip/link/pipe"
"gvisor.dev/gvisor/pkg/tcpip/network/arp"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
@@ -178,19 +179,19 @@ func TestForwarding(t *testing.T) {
routerStack := stack.New(stackOpts)
host2Stack := stack.New(stackOpts)
- host1NIC, routerNIC1 := pipe.New(host1NICLinkAddr, routerNIC1LinkAddr, stack.CapabilityResolutionRequired)
- routerNIC2, host2NIC := pipe.New(routerNIC2LinkAddr, host2NICLinkAddr, stack.CapabilityResolutionRequired)
+ host1NIC, routerNIC1 := pipe.New(host1NICLinkAddr, routerNIC1LinkAddr)
+ routerNIC2, host2NIC := pipe.New(routerNIC2LinkAddr, host2NICLinkAddr)
- if err := host1Stack.CreateNIC(host1NICID, host1NIC); err != nil {
+ if err := host1Stack.CreateNIC(host1NICID, ethernet.New(host1NIC)); err != nil {
t.Fatalf("host1Stack.CreateNIC(%d, _): %s", host1NICID, err)
}
- if err := routerStack.CreateNIC(routerNICID1, routerNIC1); err != nil {
+ if err := routerStack.CreateNIC(routerNICID1, ethernet.New(routerNIC1)); err != nil {
t.Fatalf("routerStack.CreateNIC(%d, _): %s", routerNICID1, err)
}
- if err := routerStack.CreateNIC(routerNICID2, routerNIC2); err != nil {
+ if err := routerStack.CreateNIC(routerNICID2, ethernet.New(routerNIC2)); err != nil {
t.Fatalf("routerStack.CreateNIC(%d, _): %s", routerNICID2, err)
}
- if err := host2Stack.CreateNIC(host2NICID, host2NIC); err != nil {
+ if err := host2Stack.CreateNIC(host2NICID, ethernet.New(host2NIC)); err != nil {
t.Fatalf("host2Stack.CreateNIC(%d, _): %s", host2NICID, err)
}
diff --git a/pkg/tcpip/tests/integration/link_resolution_test.go b/pkg/tcpip/tests/integration/link_resolution_test.go
index bf3a6f6ee..6ddcda70c 100644
--- a/pkg/tcpip/tests/integration/link_resolution_test.go
+++ b/pkg/tcpip/tests/integration/link_resolution_test.go
@@ -22,6 +22,7 @@ import (
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/link/ethernet"
"gvisor.dev/gvisor/pkg/tcpip/link/pipe"
"gvisor.dev/gvisor/pkg/tcpip/network/arp"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
@@ -126,12 +127,12 @@ func TestPing(t *testing.T) {
host1Stack := stack.New(stackOpts)
host2Stack := stack.New(stackOpts)
- host1NIC, host2NIC := pipe.New(host1NICLinkAddr, host2NICLinkAddr, stack.CapabilityResolutionRequired)
+ host1NIC, host2NIC := pipe.New(host1NICLinkAddr, host2NICLinkAddr)
- if err := host1Stack.CreateNIC(host1NICID, host1NIC); err != nil {
+ if err := host1Stack.CreateNIC(host1NICID, ethernet.New(host1NIC)); err != nil {
t.Fatalf("host1Stack.CreateNIC(%d, _): %s", host1NICID, err)
}
- if err := host2Stack.CreateNIC(host2NICID, host2NIC); err != nil {
+ if err := host2Stack.CreateNIC(host2NICID, ethernet.New(host2NIC)); err != nil {
t.Fatalf("host2Stack.CreateNIC(%d, _): %s", host2NICID, err)
}
diff --git a/pkg/tcpip/tests/integration/multicast_broadcast_test.go b/pkg/tcpip/tests/integration/multicast_broadcast_test.go
index 4f2ca7f54..f1028823b 100644
--- a/pkg/tcpip/tests/integration/multicast_broadcast_test.go
+++ b/pkg/tcpip/tests/integration/multicast_broadcast_test.go
@@ -80,6 +80,7 @@ func TestPingMulticastBroadcast(t *testing.T) {
SrcAddr: remoteIPv4Addr,
DstAddr: dst,
})
+ ip.SetChecksum(^ip.CalculateChecksum())
e.InjectInbound(header.IPv4ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
Data: hdr.View().ToVectorisedView(),
@@ -250,6 +251,7 @@ func TestIncomingMulticastAndBroadcast(t *testing.T) {
SrcAddr: remoteIPv4Addr,
DstAddr: dst,
})
+ ip.SetChecksum(^ip.CalculateChecksum())
e.InjectInbound(header.IPv4ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
Data: hdr.View().ToVectorisedView(),