summaryrefslogtreecommitdiffhomepage
path: root/pkg/tcpip/network
diff options
context:
space:
mode:
Diffstat (limited to 'pkg/tcpip/network')
-rw-r--r--pkg/tcpip/network/BUILD2
-rw-r--r--pkg/tcpip/network/arp/arp.go107
-rw-r--r--pkg/tcpip/network/fragmentation/BUILD5
-rw-r--r--pkg/tcpip/network/fragmentation/fragmentation.go129
-rw-r--r--pkg/tcpip/network/fragmentation/fragmentation_test.go252
-rw-r--r--pkg/tcpip/network/fragmentation/reassembler.go12
-rw-r--r--pkg/tcpip/network/fragmentation/reassembler_test.go4
-rw-r--r--pkg/tcpip/network/ip_test.go512
-rw-r--r--pkg/tcpip/network/ipv4/BUILD4
-rw-r--r--pkg/tcpip/network/ipv4/icmp.go150
-rw-r--r--pkg/tcpip/network/ipv4/ipv4.go507
-rw-r--r--pkg/tcpip/network/ipv4/ipv4_test.go583
-rw-r--r--pkg/tcpip/network/ipv6/BUILD5
-rw-r--r--pkg/tcpip/network/ipv6/dhcpv6configurationfromndpra_string.go40
-rw-r--r--pkg/tcpip/network/ipv6/icmp.go242
-rw-r--r--pkg/tcpip/network/ipv6/icmp_test.go295
-rw-r--r--pkg/tcpip/network/ipv6/ipv6.go1093
-rw-r--r--pkg/tcpip/network/ipv6/ipv6_test.go687
-rw-r--r--pkg/tcpip/network/ipv6/ndp.go2013
-rw-r--r--pkg/tcpip/network/ipv6/ndp_test.go95
-rw-r--r--pkg/tcpip/network/testutil/BUILD1
21 files changed, 6112 insertions, 626 deletions
diff --git a/pkg/tcpip/network/BUILD b/pkg/tcpip/network/BUILD
index 46083925c..59710352b 100644
--- a/pkg/tcpip/network/BUILD
+++ b/pkg/tcpip/network/BUILD
@@ -9,6 +9,7 @@ go_test(
"ip_test.go",
],
deps = [
+ "//pkg/sync",
"//pkg/tcpip",
"//pkg/tcpip/buffer",
"//pkg/tcpip/header",
@@ -17,6 +18,7 @@ go_test(
"//pkg/tcpip/network/ipv4",
"//pkg/tcpip/network/ipv6",
"//pkg/tcpip/stack",
+ "//pkg/tcpip/transport/icmp",
"//pkg/tcpip/transport/tcp",
"//pkg/tcpip/transport/udp",
],
diff --git a/pkg/tcpip/network/arp/arp.go b/pkg/tcpip/network/arp/arp.go
index b025bb087..7df77c66e 100644
--- a/pkg/tcpip/network/arp/arp.go
+++ b/pkg/tcpip/network/arp/arp.go
@@ -18,6 +18,8 @@
package arp
import (
+ "sync/atomic"
+
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
@@ -33,38 +35,73 @@ const (
ProtocolAddress = tcpip.Address("arp")
)
-// endpoint implements stack.NetworkEndpoint.
+var _ stack.AddressableEndpoint = (*endpoint)(nil)
+var _ stack.NetworkEndpoint = (*endpoint)(nil)
+
type endpoint struct {
- protocol *protocol
- nicID tcpip.NICID
- linkEP stack.LinkEndpoint
+ stack.AddressableEndpointState
+
+ protocol *protocol
+
+ // enabled is set to 1 when the NIC is enabled and 0 when it is disabled.
+ //
+ // Must be accessed using atomic operations.
+ enabled uint32
+
+ nic stack.NetworkInterface
linkAddrCache stack.LinkAddressCache
nud stack.NUDHandler
}
+func (e *endpoint) Enable() *tcpip.Error {
+ if !e.nic.Enabled() {
+ return tcpip.ErrNotPermitted
+ }
+
+ e.setEnabled(true)
+ return nil
+}
+
+func (e *endpoint) Enabled() bool {
+ return e.nic.Enabled() && e.isEnabled()
+}
+
+// isEnabled returns true if the endpoint is enabled, regardless of the
+// enabled status of the NIC.
+func (e *endpoint) isEnabled() bool {
+ return atomic.LoadUint32(&e.enabled) == 1
+}
+
+// setEnabled sets the enabled status for the endpoint.
+func (e *endpoint) setEnabled(v bool) {
+ if v {
+ atomic.StoreUint32(&e.enabled, 1)
+ } else {
+ atomic.StoreUint32(&e.enabled, 0)
+ }
+}
+
+func (e *endpoint) Disable() {
+ e.setEnabled(false)
+}
+
// DefaultTTL is unused for ARP. It implements stack.NetworkEndpoint.
func (e *endpoint) DefaultTTL() uint8 {
return 0
}
func (e *endpoint) MTU() uint32 {
- lmtu := e.linkEP.MTU()
+ lmtu := e.nic.MTU()
return lmtu - uint32(e.MaxHeaderLength())
}
-func (e *endpoint) NICID() tcpip.NICID {
- return e.nicID
-}
-
-func (e *endpoint) Capabilities() stack.LinkEndpointCapabilities {
- return e.linkEP.Capabilities()
-}
-
func (e *endpoint) MaxHeaderLength() uint16 {
- return e.linkEP.MaxHeaderLength() + header.ARPSize
+ return e.nic.MaxHeaderLength() + header.ARPSize
}
-func (e *endpoint) Close() {}
+func (e *endpoint) Close() {
+ e.AddressableEndpointState.Cleanup()
+}
func (e *endpoint) WritePacket(*stack.Route, *stack.GSO, stack.NetworkHeaderParams, *stack.PacketBuffer) *tcpip.Error {
return tcpip.ErrNotSupported
@@ -85,6 +122,10 @@ func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack.PacketBu
}
func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) {
+ if !e.isEnabled() {
+ return
+ }
+
h := header.ARP(pkt.NetworkHeader().View())
if !h.IsValid() {
return
@@ -95,15 +136,15 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) {
localAddr := tcpip.Address(h.ProtocolAddressTarget())
if e.nud == nil {
- if e.linkAddrCache.CheckLocalAddress(e.nicID, header.IPv4ProtocolNumber, localAddr) == 0 {
+ if e.linkAddrCache.CheckLocalAddress(e.nic.ID(), header.IPv4ProtocolNumber, localAddr) == 0 {
return // we have no useful answer, ignore the request
}
addr := tcpip.Address(h.ProtocolAddressSender())
linkAddr := tcpip.LinkAddress(h.HardwareAddressSender())
- e.linkAddrCache.AddLinkAddress(e.nicID, addr, linkAddr)
+ e.linkAddrCache.AddLinkAddress(e.nic.ID(), addr, linkAddr)
} else {
- if r.Stack().CheckLocalAddress(e.nicID, header.IPv4ProtocolNumber, localAddr) == 0 {
+ if r.Stack().CheckLocalAddress(e.nic.ID(), header.IPv4ProtocolNumber, localAddr) == 0 {
return // we have no useful answer, ignore the request
}
@@ -112,24 +153,32 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) {
e.nud.HandleProbe(remoteAddr, localAddr, ProtocolNumber, remoteLinkAddr, e.protocol)
}
- pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
- ReserveHeaderBytes: int(e.linkEP.MaxHeaderLength()) + header.ARPSize,
+ // As per RFC 826, under Packet Reception:
+ // Swap hardware and protocol fields, putting the local hardware and
+ // protocol addresses in the sender fields.
+ //
+ // Send the packet to the (new) target hardware address on the same
+ // hardware on which the request was received.
+ origSender := h.HardwareAddressSender()
+ r.RemoteLinkAddress = tcpip.LinkAddress(origSender)
+ respPkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
+ ReserveHeaderBytes: int(e.nic.MaxHeaderLength()) + header.ARPSize,
})
- packet := header.ARP(pkt.NetworkHeader().Push(header.ARPSize))
+ packet := header.ARP(respPkt.NetworkHeader().Push(header.ARPSize))
packet.SetIPv4OverEthernet()
packet.SetOp(header.ARPReply)
copy(packet.HardwareAddressSender(), r.LocalLinkAddress[:])
copy(packet.ProtocolAddressSender(), h.ProtocolAddressTarget())
- copy(packet.HardwareAddressTarget(), h.HardwareAddressSender())
+ copy(packet.HardwareAddressTarget(), origSender)
copy(packet.ProtocolAddressTarget(), h.ProtocolAddressSender())
- _ = e.linkEP.WritePacket(r, nil /* gso */, ProtocolNumber, pkt)
+ _ = e.nic.WritePacket(r, nil /* gso */, ProtocolNumber, respPkt)
case header.ARPReply:
addr := tcpip.Address(h.ProtocolAddressSender())
linkAddr := tcpip.LinkAddress(h.HardwareAddressSender())
if e.nud == nil {
- e.linkAddrCache.AddLinkAddress(e.nicID, addr, linkAddr)
+ e.linkAddrCache.AddLinkAddress(e.nic.ID(), addr, linkAddr)
return
}
@@ -161,14 +210,15 @@ func (*protocol) ParseAddresses(v buffer.View) (src, dst tcpip.Address) {
return tcpip.Address(h.ProtocolAddressSender()), ProtocolAddress
}
-func (p *protocol) NewEndpoint(nicID tcpip.NICID, linkAddrCache stack.LinkAddressCache, nud stack.NUDHandler, dispatcher stack.TransportDispatcher, sender stack.LinkEndpoint, st *stack.Stack) stack.NetworkEndpoint {
- return &endpoint{
+func (p *protocol) NewEndpoint(nic stack.NetworkInterface, linkAddrCache stack.LinkAddressCache, nud stack.NUDHandler, dispatcher stack.TransportDispatcher) stack.NetworkEndpoint {
+ e := &endpoint{
protocol: p,
- nicID: nicID,
- linkEP: sender,
+ nic: nic,
linkAddrCache: linkAddrCache,
nud: nud,
}
+ e.AddressableEndpointState.Init(e)
+ return e
}
// LinkAddressProtocol implements stack.LinkAddressResolver.LinkAddressProtocol.
@@ -179,6 +229,7 @@ func (*protocol) LinkAddressProtocol() tcpip.NetworkProtocolNumber {
// LinkAddressRequest implements stack.LinkAddressResolver.LinkAddressRequest.
func (*protocol) LinkAddressRequest(addr, localAddr tcpip.Address, remoteLinkAddr tcpip.LinkAddress, linkEP stack.LinkEndpoint) *tcpip.Error {
r := &stack.Route{
+ NetProto: ProtocolNumber,
RemoteLinkAddress: remoteLinkAddr,
}
if len(r.RemoteLinkAddress) == 0 {
diff --git a/pkg/tcpip/network/fragmentation/BUILD b/pkg/tcpip/network/fragmentation/BUILD
index 96c5f42f8..47fb63290 100644
--- a/pkg/tcpip/network/fragmentation/BUILD
+++ b/pkg/tcpip/network/fragmentation/BUILD
@@ -29,6 +29,8 @@ go_library(
"//pkg/sync",
"//pkg/tcpip",
"//pkg/tcpip/buffer",
+ "//pkg/tcpip/header",
+ "//pkg/tcpip/stack",
],
)
@@ -43,5 +45,8 @@ go_test(
library = ":fragmentation",
deps = [
"//pkg/tcpip/buffer",
+ "//pkg/tcpip/faketime",
+ "//pkg/tcpip/network/testutil",
+ "@com_github_google_go_cmp//cmp:go_default_library",
],
)
diff --git a/pkg/tcpip/network/fragmentation/fragmentation.go b/pkg/tcpip/network/fragmentation/fragmentation.go
index 6a4843f92..ed502a473 100644
--- a/pkg/tcpip/network/fragmentation/fragmentation.go
+++ b/pkg/tcpip/network/fragmentation/fragmentation.go
@@ -13,7 +13,7 @@
// limitations under the License.
// Package fragmentation contains the implementation of IP fragmentation.
-// It is based on RFC 791 and RFC 815.
+// It is based on RFC 791, RFC 815 and RFC 8200.
package fragmentation
import (
@@ -25,12 +25,10 @@ import (
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
)
const (
- // DefaultReassembleTimeout is based on the linux stack: net.ipv4.ipfrag_time.
- DefaultReassembleTimeout = 30 * time.Second
-
// HighFragThreshold is the threshold at which we start trimming old
// fragmented packets. Linux uses a default value of 4 MB. See
// net.ipv4.ipfrag_high_thresh for more information.
@@ -81,6 +79,8 @@ type Fragmentation struct {
size int
timeout time.Duration
blockSize uint16
+ clock tcpip.Clock
+ releaseJob *tcpip.Job
}
// NewFragmentation creates a new Fragmentation.
@@ -97,7 +97,7 @@ type Fragmentation struct {
// reassemblingTimeout specifies the maximum time allowed to reassemble a packet.
// Fragments are lazily evicted only when a new a packet with an
// already existing fragmentation-id arrives after the timeout.
-func NewFragmentation(blockSize uint16, highMemoryLimit, lowMemoryLimit int, reassemblingTimeout time.Duration) *Fragmentation {
+func NewFragmentation(blockSize uint16, highMemoryLimit, lowMemoryLimit int, reassemblingTimeout time.Duration, clock tcpip.Clock) *Fragmentation {
if lowMemoryLimit >= highMemoryLimit {
lowMemoryLimit = highMemoryLimit
}
@@ -110,13 +110,17 @@ func NewFragmentation(blockSize uint16, highMemoryLimit, lowMemoryLimit int, rea
blockSize = minBlockSize
}
- return &Fragmentation{
+ f := &Fragmentation{
reassemblers: make(map[FragmentID]*reassembler),
highLimit: highMemoryLimit,
lowLimit: lowMemoryLimit,
timeout: reassemblingTimeout,
blockSize: blockSize,
+ clock: clock,
}
+ f.releaseJob = tcpip.NewJob(f.clock, &f.mu, f.releaseReassemblersLocked)
+
+ return f
}
// Process processes an incoming fragment belonging to an ID and returns a
@@ -155,15 +159,17 @@ func (f *Fragmentation) Process(
f.mu.Lock()
r, ok := f.reassemblers[id]
- if ok && r.tooOld(f.timeout) {
- // This is very likely to be an id-collision or someone performing a slow-rate attack.
- f.release(r)
- ok = false
- }
if !ok {
- r = newReassembler(id)
+ r = newReassembler(id, f.clock)
f.reassemblers[id] = r
+ wasEmpty := f.rList.Empty()
f.rList.PushFront(r)
+ if wasEmpty {
+ // If we have just pushed a first reassembler into an empty list, we
+ // should kickstart the release job. The release job will keep
+ // rescheduling itself until the list becomes empty.
+ f.releaseReassemblersLocked()
+ }
}
f.mu.Unlock()
@@ -211,3 +217,102 @@ func (f *Fragmentation) release(r *reassembler) {
f.size = 0
}
}
+
+// releaseReassemblersLocked releases already-expired reassemblers, then
+// schedules the job to call back itself for the remaining reassemblers if
+// any. This function must be called with f.mu locked.
+func (f *Fragmentation) releaseReassemblersLocked() {
+ now := f.clock.NowMonotonic()
+ for {
+ // The reassembler at the end of the list is the oldest.
+ r := f.rList.Back()
+ if r == nil {
+ // The list is empty.
+ break
+ }
+ elapsed := time.Duration(now-r.creationTime) * time.Nanosecond
+ if f.timeout > elapsed {
+ // If the oldest reassembler has not expired, schedule the release
+ // job so that this function is called back when it has expired.
+ f.releaseJob.Schedule(f.timeout - elapsed)
+ break
+ }
+ // If the oldest reassembler has already expired, release it.
+ f.release(r)
+ }
+}
+
+// PacketFragmenter is the book-keeping struct for packet fragmentation.
+type PacketFragmenter struct {
+ transportHeader buffer.View
+ data buffer.VectorisedView
+ reserve int
+ innerMTU int
+ fragmentCount int
+ currentFragment int
+ fragmentOffset int
+}
+
+// MakePacketFragmenter prepares the struct needed for packet fragmentation.
+//
+// pkt is the packet to be fragmented.
+//
+// innerMTU is the maximum number of bytes of fragmentable data a fragment can
+// have.
+//
+// reserve is the number of bytes that should be reserved for the headers in
+// each generated fragment.
+func MakePacketFragmenter(pkt *stack.PacketBuffer, innerMTU int, reserve int) PacketFragmenter {
+ // As per RFC 8200 Section 4.5, some IPv6 extension headers should not be
+ // repeated in each fragment. However we do not currently support any header
+ // of that kind yet, so the following computation is valid for both IPv4 and
+ // IPv6.
+ // TODO(gvisor.dev/issue/3912): Once Authentication or ESP Headers are
+ // supported for outbound packets, the fragmentable data should not include
+ // these headers.
+ var fragmentableData buffer.VectorisedView
+ fragmentableData.AppendView(pkt.TransportHeader().View())
+ fragmentableData.Append(pkt.Data)
+ fragmentCount := (fragmentableData.Size() + innerMTU - 1) / innerMTU
+
+ return PacketFragmenter{
+ data: fragmentableData,
+ reserve: reserve,
+ innerMTU: innerMTU,
+ fragmentCount: fragmentCount,
+ }
+}
+
+// BuildNextFragment returns a packet with the payload of the next fragment,
+// along with the fragment's offset, the number of bytes copied and a boolean
+// indicating if there are more fragments left or not. If this function is
+// called again after it indicated that no more fragments were left, it will
+// panic.
+//
+// Note that the returned packet will not have its network and link headers
+// populated, but space for them will be reserved. The transport header will be
+// stored in the packet's data.
+func (pf *PacketFragmenter) BuildNextFragment() (*stack.PacketBuffer, int, int, bool) {
+ if pf.currentFragment >= pf.fragmentCount {
+ panic("BuildNextFragment should not be called again after the last fragment was returned")
+ }
+
+ fragPkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
+ ReserveHeaderBytes: pf.reserve,
+ })
+
+ // Copy data for the fragment.
+ copied := pf.data.ReadToVV(&fragPkt.Data, pf.innerMTU)
+
+ offset := pf.fragmentOffset
+ pf.fragmentOffset += copied
+ pf.currentFragment++
+ more := pf.currentFragment != pf.fragmentCount
+
+ return fragPkt, offset, copied, more
+}
+
+// RemainingFragmentCount returns the number of fragments left to be built.
+func (pf *PacketFragmenter) RemainingFragmentCount() int {
+ return pf.fragmentCount - pf.currentFragment
+}
diff --git a/pkg/tcpip/network/fragmentation/fragmentation_test.go b/pkg/tcpip/network/fragmentation/fragmentation_test.go
index 416604659..d3c7d7f92 100644
--- a/pkg/tcpip/network/fragmentation/fragmentation_test.go
+++ b/pkg/tcpip/network/fragmentation/fragmentation_test.go
@@ -20,9 +20,16 @@ import (
"testing"
"time"
+ "github.com/google/go-cmp/cmp"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/faketime"
+ "gvisor.dev/gvisor/pkg/tcpip/network/testutil"
)
+// reassembleTimeout is dummy timeout used for testing, where the clock never
+// advances.
+const reassembleTimeout = 1
+
// vv is a helper to build VectorisedView from different strings.
func vv(size int, pieces ...string) buffer.VectorisedView {
views := make([]buffer.View, len(pieces))
@@ -95,7 +102,7 @@ var processTestCases = []struct {
func TestFragmentationProcess(t *testing.T) {
for _, c := range processTestCases {
t.Run(c.comment, func(t *testing.T) {
- f := NewFragmentation(minBlockSize, 1024, 512, DefaultReassembleTimeout)
+ f := NewFragmentation(minBlockSize, 1024, 512, reassembleTimeout, &faketime.NullClock{})
firstFragmentProto := c.in[0].proto
for i, in := range c.in {
vv, proto, done, err := f.Process(in.id, in.first, in.last, in.more, in.proto, in.vv)
@@ -131,25 +138,126 @@ func TestFragmentationProcess(t *testing.T) {
}
func TestReassemblingTimeout(t *testing.T) {
- timeout := time.Millisecond
- f := NewFragmentation(minBlockSize, 1024, 512, timeout)
- // Send first fragment with id = 0, first = 0, last = 0, and more = true.
- f.Process(FragmentID{}, 0, 0, true, 0xFF, vv(1, "0"))
- // Sleep more than the timeout.
- time.Sleep(2 * timeout)
- // Send another fragment that completes a packet.
- // However, no packet should be reassembled because the fragment arrived after the timeout.
- _, _, done, err := f.Process(FragmentID{}, 1, 1, false, 0xFF, vv(1, "1"))
- if err != nil {
- t.Fatalf("f.Process(0, 1, 1, false, 0xFF, vv(1, \"1\")) failed: %v", err)
+ const (
+ reassemblyTimeout = time.Millisecond
+ protocol = 0xff
+ )
+
+ type fragment struct {
+ first uint16
+ last uint16
+ more bool
+ data string
}
- if done {
- t.Errorf("Fragmentation does not respect the reassembling timeout.")
+
+ type event struct {
+ // name is a nickname of this event.
+ name string
+
+ // clockAdvance is a duration to advance the clock. The clock advances
+ // before a fragment specified in the fragment field is processed.
+ clockAdvance time.Duration
+
+ // fragment is a fragment to process. This can be nil if there is no
+ // fragment to process.
+ fragment *fragment
+
+ // expectDone is true if the fragmentation instance should report the
+ // reassembly is done after the fragment is processd.
+ expectDone bool
+
+ // sizeAfterEvent is the expected size of the fragmentation instance after
+ // the event.
+ sizeAfterEvent int
+ }
+
+ half1 := &fragment{first: 0, last: 0, more: true, data: "0"}
+ half2 := &fragment{first: 1, last: 1, more: false, data: "1"}
+
+ tests := []struct {
+ name string
+ events []event
+ }{
+ {
+ name: "half1 and half2 are reassembled successfully",
+ events: []event{
+ {
+ name: "half1",
+ fragment: half1,
+ expectDone: false,
+ sizeAfterEvent: 1,
+ },
+ {
+ name: "half2",
+ fragment: half2,
+ expectDone: true,
+ sizeAfterEvent: 0,
+ },
+ },
+ },
+ {
+ name: "half1 timeout, half2 timeout",
+ events: []event{
+ {
+ name: "half1",
+ fragment: half1,
+ expectDone: false,
+ sizeAfterEvent: 1,
+ },
+ {
+ name: "half1 just before reassembly timeout",
+ clockAdvance: reassemblyTimeout - 1,
+ sizeAfterEvent: 1,
+ },
+ {
+ name: "half1 reassembly timeout",
+ clockAdvance: 1,
+ sizeAfterEvent: 0,
+ },
+ {
+ name: "half2",
+ fragment: half2,
+ expectDone: false,
+ sizeAfterEvent: 1,
+ },
+ {
+ name: "half2 just before reassembly timeout",
+ clockAdvance: reassemblyTimeout - 1,
+ sizeAfterEvent: 1,
+ },
+ {
+ name: "half2 reassembly timeout",
+ clockAdvance: 1,
+ sizeAfterEvent: 0,
+ },
+ },
+ },
+ }
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ clock := faketime.NewManualClock()
+ f := NewFragmentation(minBlockSize, HighFragThreshold, LowFragThreshold, reassemblyTimeout, clock)
+ for _, event := range test.events {
+ clock.Advance(event.clockAdvance)
+ if frag := event.fragment; frag != nil {
+ _, _, done, err := f.Process(FragmentID{}, frag.first, frag.last, frag.more, protocol, vv(len(frag.data), frag.data))
+ if err != nil {
+ t.Fatalf("%s: f.Process failed: %s", event.name, err)
+ }
+ if done != event.expectDone {
+ t.Fatalf("%s: got done = %t, want = %t", event.name, done, event.expectDone)
+ }
+ }
+ if got, want := f.size, event.sizeAfterEvent; got != want {
+ t.Errorf("%s: got f.size = %d, want = %d", event.name, got, want)
+ }
+ }
+ })
}
}
func TestMemoryLimits(t *testing.T) {
- f := NewFragmentation(minBlockSize, 3, 1, DefaultReassembleTimeout)
+ f := NewFragmentation(minBlockSize, 3, 1, reassembleTimeout, &faketime.NullClock{})
// Send first fragment with id = 0.
f.Process(FragmentID{ID: 0}, 0, 0, true, 0xFF, vv(1, "0"))
// Send first fragment with id = 1.
@@ -173,7 +281,7 @@ func TestMemoryLimits(t *testing.T) {
}
func TestMemoryLimitsIgnoresDuplicates(t *testing.T) {
- f := NewFragmentation(minBlockSize, 1, 0, DefaultReassembleTimeout)
+ f := NewFragmentation(minBlockSize, 1, 0, reassembleTimeout, &faketime.NullClock{})
// Send first fragment with id = 0.
f.Process(FragmentID{}, 0, 0, true, 0xFF, vv(1, "0"))
// Send the same packet again.
@@ -268,7 +376,7 @@ func TestErrors(t *testing.T) {
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
- f := NewFragmentation(test.blockSize, HighFragThreshold, LowFragThreshold, DefaultReassembleTimeout)
+ f := NewFragmentation(test.blockSize, HighFragThreshold, LowFragThreshold, reassembleTimeout, &faketime.NullClock{})
_, _, done, err := f.Process(FragmentID{}, test.first, test.last, test.more, 0, vv(len(test.data), test.data))
if !errors.Is(err, test.err) {
t.Errorf("got Process(_, %d, %d, %t, _, %q) = (_, _, _, %v), want = (_, _, _, %v)", test.first, test.last, test.more, test.data, err, test.err)
@@ -279,3 +387,113 @@ func TestErrors(t *testing.T) {
})
}
}
+
+type fragmentInfo struct {
+ remaining int
+ copied int
+ offset int
+ more bool
+}
+
+func TestPacketFragmenter(t *testing.T) {
+ const (
+ reserve = 60
+ proto = 0
+ )
+
+ tests := []struct {
+ name string
+ innerMTU int
+ transportHeaderLen int
+ payloadSize int
+ wantFragments []fragmentInfo
+ }{
+ {
+ name: "Packet exactly fits in MTU",
+ innerMTU: 1280,
+ transportHeaderLen: 0,
+ payloadSize: 1280,
+ wantFragments: []fragmentInfo{
+ {remaining: 0, copied: 1280, offset: 0, more: false},
+ },
+ },
+ {
+ name: "Packet exactly does not fit in MTU",
+ innerMTU: 1000,
+ transportHeaderLen: 0,
+ payloadSize: 1001,
+ wantFragments: []fragmentInfo{
+ {remaining: 1, copied: 1000, offset: 0, more: true},
+ {remaining: 0, copied: 1, offset: 1000, more: false},
+ },
+ },
+ {
+ name: "Packet has a transport header",
+ innerMTU: 560,
+ transportHeaderLen: 40,
+ payloadSize: 560,
+ wantFragments: []fragmentInfo{
+ {remaining: 1, copied: 560, offset: 0, more: true},
+ {remaining: 0, copied: 40, offset: 560, more: false},
+ },
+ },
+ {
+ name: "Packet has a huge transport header",
+ innerMTU: 500,
+ transportHeaderLen: 1300,
+ payloadSize: 500,
+ wantFragments: []fragmentInfo{
+ {remaining: 3, copied: 500, offset: 0, more: true},
+ {remaining: 2, copied: 500, offset: 500, more: true},
+ {remaining: 1, copied: 500, offset: 1000, more: true},
+ {remaining: 0, copied: 300, offset: 1500, more: false},
+ },
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ pkt := testutil.MakeRandPkt(test.transportHeaderLen, reserve, []int{test.payloadSize}, proto)
+ var originalPayload buffer.VectorisedView
+ originalPayload.AppendView(pkt.TransportHeader().View())
+ originalPayload.Append(pkt.Data)
+ var reassembledPayload buffer.VectorisedView
+ pf := MakePacketFragmenter(pkt, test.innerMTU, reserve)
+ for i := 0; ; i++ {
+ fragPkt, offset, copied, more := pf.BuildNextFragment()
+ wantFragment := test.wantFragments[i]
+ if got := pf.RemainingFragmentCount(); got != wantFragment.remaining {
+ t.Errorf("(fragment #%d) got pf.RemainingFragmentCount() = %d, want = %d", i, got, wantFragment.remaining)
+ }
+ if copied != wantFragment.copied {
+ t.Errorf("(fragment #%d) got copied = %d, want = %d", i, copied, wantFragment.copied)
+ }
+ if offset != wantFragment.offset {
+ t.Errorf("(fragment #%d) got offset = %d, want = %d", i, offset, wantFragment.offset)
+ }
+ if more != wantFragment.more {
+ t.Errorf("(fragment #%d) got more = %t, want = %t", i, more, wantFragment.more)
+ }
+ if got := fragPkt.Size(); got > test.innerMTU {
+ t.Errorf("(fragment #%d) got fragPkt.Size() = %d, want <= %d", i, got, test.innerMTU)
+ }
+ if got := fragPkt.AvailableHeaderBytes(); got != reserve {
+ t.Errorf("(fragment #%d) got fragPkt.AvailableHeaderBytes() = %d, want = %d", i, got, reserve)
+ }
+ if got := fragPkt.TransportHeader().View().Size(); got != 0 {
+ t.Errorf("(fragment #%d) got fragPkt.TransportHeader().View().Size() = %d, want = 0", i, got)
+ }
+ reassembledPayload.Append(fragPkt.Data)
+ if !more {
+ if i != len(test.wantFragments)-1 {
+ t.Errorf("got fragment count = %d, want = %d", i, len(test.wantFragments)-1)
+ }
+ break
+ }
+ }
+ if diff := cmp.Diff(reassembledPayload.ToView(), originalPayload.ToView()); diff != "" {
+ t.Errorf("reassembledPayload mismatch (-want +got):\n%s", diff)
+ }
+ })
+ }
+}
diff --git a/pkg/tcpip/network/fragmentation/reassembler.go b/pkg/tcpip/network/fragmentation/reassembler.go
index f044867dc..9bb051a30 100644
--- a/pkg/tcpip/network/fragmentation/reassembler.go
+++ b/pkg/tcpip/network/fragmentation/reassembler.go
@@ -18,9 +18,9 @@ import (
"container/heap"
"fmt"
"math"
- "time"
"gvisor.dev/gvisor/pkg/sync"
+ "gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
)
@@ -40,15 +40,15 @@ type reassembler struct {
deleted int
heap fragHeap
done bool
- creationTime time.Time
+ creationTime int64
}
-func newReassembler(id FragmentID) *reassembler {
+func newReassembler(id FragmentID, clock tcpip.Clock) *reassembler {
r := &reassembler{
id: id,
holes: make([]hole, 0, 16),
heap: make(fragHeap, 0, 8),
- creationTime: time.Now(),
+ creationTime: clock.NowMonotonic(),
}
r.holes = append(r.holes, hole{
first: 0,
@@ -116,10 +116,6 @@ func (r *reassembler) process(first, last uint16, more bool, proto uint8, vv buf
return res, r.proto, true, consumed, nil
}
-func (r *reassembler) tooOld(timeout time.Duration) bool {
- return time.Now().Sub(r.creationTime) > timeout
-}
-
func (r *reassembler) checkDoneOrMark() bool {
r.mu.Lock()
prev := r.done
diff --git a/pkg/tcpip/network/fragmentation/reassembler_test.go b/pkg/tcpip/network/fragmentation/reassembler_test.go
index dff7c9dcb..a0a04a027 100644
--- a/pkg/tcpip/network/fragmentation/reassembler_test.go
+++ b/pkg/tcpip/network/fragmentation/reassembler_test.go
@@ -18,6 +18,8 @@ import (
"math"
"reflect"
"testing"
+
+ "gvisor.dev/gvisor/pkg/tcpip/faketime"
)
type updateHolesInput struct {
@@ -94,7 +96,7 @@ var holesTestCases = []struct {
func TestUpdateHoles(t *testing.T) {
for _, c := range holesTestCases {
- r := newReassembler(FragmentID{})
+ r := newReassembler(FragmentID{}, &faketime.NullClock{})
for _, i := range c.in {
r.updateHoles(i.first, i.last, i.more)
}
diff --git a/pkg/tcpip/network/ip_test.go b/pkg/tcpip/network/ip_test.go
index 4640ca95c..d436873b6 100644
--- a/pkg/tcpip/network/ip_test.go
+++ b/pkg/tcpip/network/ip_test.go
@@ -17,6 +17,7 @@ package ip_test
import (
"testing"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
@@ -25,26 +26,35 @@ import (
"gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
"gvisor.dev/gvisor/pkg/tcpip/stack"
+ "gvisor.dev/gvisor/pkg/tcpip/transport/icmp"
"gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
"gvisor.dev/gvisor/pkg/tcpip/transport/udp"
)
const (
- localIpv4Addr = "\x0a\x00\x00\x01"
- localIpv4PrefixLen = 24
- remoteIpv4Addr = "\x0a\x00\x00\x02"
- ipv4SubnetAddr = "\x0a\x00\x00\x00"
- ipv4SubnetMask = "\xff\xff\xff\x00"
- ipv4Gateway = "\x0a\x00\x00\x03"
- localIpv6Addr = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01"
- localIpv6PrefixLen = 120
- remoteIpv6Addr = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02"
- ipv6SubnetAddr = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"
- ipv6SubnetMask = "\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\x00"
- ipv6Gateway = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x03"
- nicID = 1
+ localIPv4Addr = "\x0a\x00\x00\x01"
+ remoteIPv4Addr = "\x0a\x00\x00\x02"
+ ipv4SubnetAddr = "\x0a\x00\x00\x00"
+ ipv4SubnetMask = "\xff\xff\xff\x00"
+ ipv4Gateway = "\x0a\x00\x00\x03"
+ localIPv6Addr = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01"
+ remoteIPv6Addr = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02"
+ ipv6SubnetAddr = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"
+ ipv6SubnetMask = "\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\x00"
+ ipv6Gateway = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x03"
+ nicID = 1
)
+var localIPv4AddrWithPrefix = tcpip.AddressWithPrefix{
+ Address: localIPv4Addr,
+ PrefixLen: 24,
+}
+
+var localIPv6AddrWithPrefix = tcpip.AddressWithPrefix{
+ Address: localIPv6Addr,
+ PrefixLen: 120,
+}
+
// testObject implements two interfaces: LinkEndpoint and TransportDispatcher.
// The former is used to pretend that it's a link endpoint so that we can
// inspect packets written by the network endpoints. The latter is used to
@@ -225,7 +235,7 @@ func buildIPv6Route(local, remote tcpip.Address) (stack.Route, *tcpip.Error) {
return s.FindRoute(nicID, local, remote, ipv6.ProtocolNumber, false /* multicastLoop */)
}
-func buildDummyStack(t *testing.T) *stack.Stack {
+func buildDummyStackWithLinkEndpoint(t *testing.T) (*stack.Stack, *channel.Endpoint) {
t.Helper()
s := stack.New(stack.Options{
@@ -237,22 +247,278 @@ func buildDummyStack(t *testing.T) *stack.Stack {
t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
}
- if err := s.AddAddress(nicID, header.IPv4ProtocolNumber, localIpv4Addr); err != nil {
- t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv4ProtocolNumber, localIpv4Addr, err)
+ v4Addr := tcpip.ProtocolAddress{Protocol: header.IPv4ProtocolNumber, AddressWithPrefix: localIPv4AddrWithPrefix}
+ if err := s.AddProtocolAddress(nicID, v4Addr); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %#v) = %s", nicID, v4Addr, err)
}
- if err := s.AddAddress(nicID, header.IPv6ProtocolNumber, localIpv6Addr); err != nil {
- t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, localIpv6Addr, err)
+ v6Addr := tcpip.ProtocolAddress{Protocol: header.IPv6ProtocolNumber, AddressWithPrefix: localIPv6AddrWithPrefix}
+ if err := s.AddProtocolAddress(nicID, v6Addr); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %#v) = %s", nicID, v6Addr, err)
}
+ return s, e
+}
+
+func buildDummyStack(t *testing.T) *stack.Stack {
+ t.Helper()
+
+ s, _ := buildDummyStackWithLinkEndpoint(t)
return s
}
+var _ stack.NetworkInterface = (*testInterface)(nil)
+
+type testInterface struct {
+ testObject
+
+ mu struct {
+ sync.RWMutex
+ disabled bool
+ }
+}
+
+func (*testInterface) ID() tcpip.NICID {
+ return nicID
+}
+
+func (*testInterface) IsLoopback() bool {
+ return false
+}
+
+func (*testInterface) Name() string {
+ return ""
+}
+
+func (t *testInterface) Enabled() bool {
+ t.mu.RLock()
+ defer t.mu.RUnlock()
+ return !t.mu.disabled
+}
+
+func (t *testInterface) setEnabled(v bool) {
+ t.mu.Lock()
+ defer t.mu.Unlock()
+ t.mu.disabled = !v
+}
+
+func TestSourceAddressValidation(t *testing.T) {
+ rxIPv4ICMP := func(e *channel.Endpoint, src tcpip.Address) {
+ totalLen := header.IPv4MinimumSize + header.ICMPv4MinimumSize
+ hdr := buffer.NewPrependable(totalLen)
+ pkt := header.ICMPv4(hdr.Prepend(header.ICMPv4MinimumSize))
+ pkt.SetType(header.ICMPv4Echo)
+ pkt.SetCode(0)
+ pkt.SetChecksum(0)
+ pkt.SetChecksum(^header.Checksum(pkt, 0))
+ ip := header.IPv4(hdr.Prepend(header.IPv4MinimumSize))
+ ip.Encode(&header.IPv4Fields{
+ IHL: header.IPv4MinimumSize,
+ TotalLength: uint16(totalLen),
+ Protocol: uint8(icmp.ProtocolNumber4),
+ TTL: ipv4.DefaultTTL,
+ SrcAddr: src,
+ DstAddr: localIPv4Addr,
+ })
+
+ e.InjectInbound(header.IPv4ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
+ Data: hdr.View().ToVectorisedView(),
+ }))
+ }
+
+ rxIPv6ICMP := func(e *channel.Endpoint, src tcpip.Address) {
+ totalLen := header.IPv6MinimumSize + header.ICMPv6MinimumSize
+ hdr := buffer.NewPrependable(totalLen)
+ pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6MinimumSize))
+ pkt.SetType(header.ICMPv6EchoRequest)
+ pkt.SetCode(0)
+ pkt.SetChecksum(0)
+ pkt.SetChecksum(header.ICMPv6Checksum(pkt, src, localIPv6Addr, buffer.VectorisedView{}))
+ ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
+ ip.Encode(&header.IPv6Fields{
+ PayloadLength: header.ICMPv6MinimumSize,
+ NextHeader: uint8(icmp.ProtocolNumber6),
+ HopLimit: ipv6.DefaultTTL,
+ SrcAddr: src,
+ DstAddr: localIPv6Addr,
+ })
+
+ e.InjectInbound(header.IPv6ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
+ Data: hdr.View().ToVectorisedView(),
+ }))
+ }
+
+ tests := []struct {
+ name string
+ srcAddress tcpip.Address
+ rxICMP func(*channel.Endpoint, tcpip.Address)
+ valid bool
+ }{
+ {
+ name: "IPv4 valid",
+ srcAddress: "\x01\x02\x03\x04",
+ rxICMP: rxIPv4ICMP,
+ valid: true,
+ },
+ {
+ name: "IPv6 valid",
+ srcAddress: "\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e\x0f\x10",
+ rxICMP: rxIPv6ICMP,
+ valid: true,
+ },
+ {
+ name: "IPv4 unspecified",
+ srcAddress: header.IPv4Any,
+ rxICMP: rxIPv4ICMP,
+ valid: true,
+ },
+ {
+ name: "IPv6 unspecified",
+ srcAddress: header.IPv4Any,
+ rxICMP: rxIPv6ICMP,
+ valid: true,
+ },
+ {
+ name: "IPv4 multicast",
+ srcAddress: "\xe0\x00\x00\x01",
+ rxICMP: rxIPv4ICMP,
+ valid: false,
+ },
+ {
+ name: "IPv6 multicast",
+ srcAddress: "\xff\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01",
+ rxICMP: rxIPv6ICMP,
+ valid: false,
+ },
+ {
+ name: "IPv4 broadcast",
+ srcAddress: header.IPv4Broadcast,
+ rxICMP: rxIPv4ICMP,
+ valid: false,
+ },
+ {
+ name: "IPv4 subnet broadcast",
+ srcAddress: func() tcpip.Address {
+ subnet := localIPv4AddrWithPrefix.Subnet()
+ return subnet.Broadcast()
+ }(),
+ rxICMP: rxIPv4ICMP,
+ valid: false,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ s, e := buildDummyStackWithLinkEndpoint(t)
+ test.rxICMP(e, test.srcAddress)
+
+ var wantValid uint64
+ if test.valid {
+ wantValid = 1
+ }
+
+ if got, want := s.Stats().IP.InvalidSourceAddressesReceived.Value(), 1-wantValid; got != want {
+ t.Errorf("got s.Stats().IP.InvalidSourceAddressesReceived.Value() = %d, want = %d", got, want)
+ }
+ if got := s.Stats().IP.PacketsDelivered.Value(); got != wantValid {
+ t.Errorf("got s.Stats().IP.PacketsDelivered.Value() = %d, want = %d", got, wantValid)
+ }
+ })
+ }
+}
+
+func TestEnableWhenNICDisabled(t *testing.T) {
+ tests := []struct {
+ name string
+ protocolFactory stack.NetworkProtocolFactory
+ protoNum tcpip.NetworkProtocolNumber
+ }{
+ {
+ name: "IPv4",
+ protocolFactory: ipv4.NewProtocol,
+ protoNum: ipv4.ProtocolNumber,
+ },
+ {
+ name: "IPv6",
+ protocolFactory: ipv6.NewProtocol,
+ protoNum: ipv6.ProtocolNumber,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ var nic testInterface
+ nic.setEnabled(false)
+
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{test.protocolFactory},
+ })
+ p := s.NetworkProtocolInstance(test.protoNum)
+
+ // We pass nil for all parameters except the NetworkInterface and Stack
+ // since Enable only depends on these.
+ ep := p.NewEndpoint(&nic, nil, nil, nil)
+
+ // The endpoint should initially be disabled, regardless the NIC's enabled
+ // status.
+ if ep.Enabled() {
+ t.Fatal("got ep.Enabled() = true, want = false")
+ }
+ nic.setEnabled(true)
+ if ep.Enabled() {
+ t.Fatal("got ep.Enabled() = true, want = false")
+ }
+
+ // Attempting to enable the endpoint while the NIC is disabled should
+ // fail.
+ nic.setEnabled(false)
+ if err := ep.Enable(); err != tcpip.ErrNotPermitted {
+ t.Fatalf("got ep.Enable() = %s, want = %s", err, tcpip.ErrNotPermitted)
+ }
+ // ep should consider the NIC's enabled status when determining its own
+ // enabled status so we "enable" the NIC to read just the endpoint's
+ // enabled status.
+ nic.setEnabled(true)
+ if ep.Enabled() {
+ t.Fatal("got ep.Enabled() = true, want = false")
+ }
+
+ // Enabling the interface after the NIC has been enabled should succeed.
+ if err := ep.Enable(); err != nil {
+ t.Fatalf("ep.Enable(): %s", err)
+ }
+ if !ep.Enabled() {
+ t.Fatal("got ep.Enabled() = false, want = true")
+ }
+
+ // ep should consider the NIC's enabled status when determining its own
+ // enabled status.
+ nic.setEnabled(false)
+ if ep.Enabled() {
+ t.Fatal("got ep.Enabled() = true, want = false")
+ }
+
+ // Disabling the endpoint when the NIC is enabled should make the endpoint
+ // disabled.
+ nic.setEnabled(true)
+ ep.Disable()
+ if ep.Enabled() {
+ t.Fatal("got ep.Enabled() = true, want = false")
+ }
+ })
+ }
+}
+
func TestIPv4Send(t *testing.T) {
- o := testObject{t: t, v4: true}
s := buildDummyStack(t)
proto := s.NetworkProtocolInstance(ipv4.ProtocolNumber)
- ep := proto.NewEndpoint(nicID, nil, nil, nil, &o, s)
+ nic := testInterface{
+ testObject: testObject{
+ t: t,
+ v4: true,
+ },
+ }
+ ep := proto.NewEndpoint(&nic, nil, nil, nil)
defer ep.Close()
// Allocate and initialize the payload view.
@@ -268,12 +534,12 @@ func TestIPv4Send(t *testing.T) {
})
// Issue the write.
- o.protocol = 123
- o.srcAddr = localIpv4Addr
- o.dstAddr = remoteIpv4Addr
- o.contents = payload
+ nic.testObject.protocol = 123
+ nic.testObject.srcAddr = localIPv4Addr
+ nic.testObject.dstAddr = remoteIPv4Addr
+ nic.testObject.contents = payload
- r, err := buildIPv4Route(localIpv4Addr, remoteIpv4Addr)
+ r, err := buildIPv4Route(localIPv4Addr, remoteIPv4Addr)
if err != nil {
t.Fatalf("could not find route: %v", err)
}
@@ -287,12 +553,21 @@ func TestIPv4Send(t *testing.T) {
}
func TestIPv4Receive(t *testing.T) {
- o := testObject{t: t, v4: true}
s := buildDummyStack(t)
proto := s.NetworkProtocolInstance(ipv4.ProtocolNumber)
- ep := proto.NewEndpoint(nicID, nil, nil, &o, nil, s)
+ nic := testInterface{
+ testObject: testObject{
+ t: t,
+ v4: true,
+ },
+ }
+ ep := proto.NewEndpoint(&nic, nil, nil, &nic.testObject)
defer ep.Close()
+ if err := ep.Enable(); err != nil {
+ t.Fatalf("ep.Enable(): %s", err)
+ }
+
totalLen := header.IPv4MinimumSize + 30
view := buffer.NewView(totalLen)
ip := header.IPv4(view)
@@ -301,8 +576,8 @@ func TestIPv4Receive(t *testing.T) {
TotalLength: uint16(totalLen),
TTL: 20,
Protocol: 10,
- SrcAddr: remoteIpv4Addr,
- DstAddr: localIpv4Addr,
+ SrcAddr: remoteIPv4Addr,
+ DstAddr: localIPv4Addr,
})
// Make payload be non-zero.
@@ -311,12 +586,12 @@ func TestIPv4Receive(t *testing.T) {
}
// Give packet to ipv4 endpoint, dispatcher will validate that it's ok.
- o.protocol = 10
- o.srcAddr = remoteIpv4Addr
- o.dstAddr = localIpv4Addr
- o.contents = view[header.IPv4MinimumSize:totalLen]
+ nic.testObject.protocol = 10
+ nic.testObject.srcAddr = remoteIPv4Addr
+ nic.testObject.dstAddr = localIPv4Addr
+ nic.testObject.contents = view[header.IPv4MinimumSize:totalLen]
- r, err := buildIPv4Route(localIpv4Addr, remoteIpv4Addr)
+ r, err := buildIPv4Route(localIPv4Addr, remoteIPv4Addr)
if err != nil {
t.Fatalf("could not find route: %v", err)
}
@@ -327,8 +602,8 @@ func TestIPv4Receive(t *testing.T) {
t.Fatalf("failed to parse packet: %x", pkt.Data.ToView())
}
ep.HandlePacket(&r, pkt)
- if o.dataCalls != 1 {
- t.Fatalf("Bad number of data calls: got %x, want 1", o.dataCalls)
+ if nic.testObject.dataCalls != 1 {
+ t.Fatalf("Bad number of data calls: got %x, want 1", nic.testObject.dataCalls)
}
}
@@ -352,18 +627,26 @@ func TestIPv4ReceiveControl(t *testing.T) {
{"Non-zero fragment offset", 0, 100, header.ICMPv4PortUnreachable, stack.ControlPortUnreachable, 0, 0},
{"Zero-length packet", 0, 0, header.ICMPv4PortUnreachable, stack.ControlPortUnreachable, 0, 2*header.IPv4MinimumSize + header.ICMPv4MinimumSize + 8},
}
- r, err := buildIPv4Route(localIpv4Addr, "\x0a\x00\x00\xbb")
+ r, err := buildIPv4Route(localIPv4Addr, "\x0a\x00\x00\xbb")
if err != nil {
t.Fatal(err)
}
for _, c := range cases {
t.Run(c.name, func(t *testing.T) {
- o := testObject{t: t}
s := buildDummyStack(t)
proto := s.NetworkProtocolInstance(ipv4.ProtocolNumber)
- ep := proto.NewEndpoint(nicID, nil, nil, &o, nil, s)
+ nic := testInterface{
+ testObject: testObject{
+ t: t,
+ },
+ }
+ ep := proto.NewEndpoint(&nic, nil, nil, &nic.testObject)
defer ep.Close()
+ if err := ep.Enable(); err != nil {
+ t.Fatalf("ep.Enable(): %s", err)
+ }
+
const dataOffset = header.IPv4MinimumSize*2 + header.ICMPv4MinimumSize
view := buffer.NewView(dataOffset + 8)
@@ -375,7 +658,7 @@ func TestIPv4ReceiveControl(t *testing.T) {
TTL: 20,
Protocol: uint8(header.ICMPv4ProtocolNumber),
SrcAddr: "\x0a\x00\x00\xbb",
- DstAddr: localIpv4Addr,
+ DstAddr: localIPv4Addr,
})
// Create the ICMP header.
@@ -393,8 +676,8 @@ func TestIPv4ReceiveControl(t *testing.T) {
TTL: 20,
Protocol: 10,
FragmentOffset: c.fragmentOffset,
- SrcAddr: localIpv4Addr,
- DstAddr: remoteIpv4Addr,
+ SrcAddr: localIPv4Addr,
+ DstAddr: remoteIPv4Addr,
})
// Make payload be non-zero.
@@ -404,28 +687,37 @@ func TestIPv4ReceiveControl(t *testing.T) {
// Give packet to IPv4 endpoint, dispatcher will validate that
// it's ok.
- o.protocol = 10
- o.srcAddr = remoteIpv4Addr
- o.dstAddr = localIpv4Addr
- o.contents = view[dataOffset:]
- o.typ = c.expectedTyp
- o.extra = c.expectedExtra
+ nic.testObject.protocol = 10
+ nic.testObject.srcAddr = remoteIPv4Addr
+ nic.testObject.dstAddr = localIPv4Addr
+ nic.testObject.contents = view[dataOffset:]
+ nic.testObject.typ = c.expectedTyp
+ nic.testObject.extra = c.expectedExtra
ep.HandlePacket(&r, truncatedPacket(view, c.trunc, header.IPv4MinimumSize))
- if want := c.expectedCount; o.controlCalls != want {
- t.Fatalf("Bad number of control calls for %q case: got %v, want %v", c.name, o.controlCalls, want)
+ if want := c.expectedCount; nic.testObject.controlCalls != want {
+ t.Fatalf("Bad number of control calls for %q case: got %v, want %v", c.name, nic.testObject.controlCalls, want)
}
})
}
}
func TestIPv4FragmentationReceive(t *testing.T) {
- o := testObject{t: t, v4: true}
s := buildDummyStack(t)
proto := s.NetworkProtocolInstance(ipv4.ProtocolNumber)
- ep := proto.NewEndpoint(nicID, nil, nil, &o, nil, s)
+ nic := testInterface{
+ testObject: testObject{
+ t: t,
+ v4: true,
+ },
+ }
+ ep := proto.NewEndpoint(&nic, nil, nil, &nic.testObject)
defer ep.Close()
+ if err := ep.Enable(); err != nil {
+ t.Fatalf("ep.Enable(): %s", err)
+ }
+
totalLen := header.IPv4MinimumSize + 24
frag1 := buffer.NewView(totalLen)
@@ -437,8 +729,8 @@ func TestIPv4FragmentationReceive(t *testing.T) {
Protocol: 10,
FragmentOffset: 0,
Flags: header.IPv4FlagMoreFragments,
- SrcAddr: remoteIpv4Addr,
- DstAddr: localIpv4Addr,
+ SrcAddr: remoteIPv4Addr,
+ DstAddr: localIPv4Addr,
})
// Make payload be non-zero.
for i := header.IPv4MinimumSize; i < totalLen; i++ {
@@ -453,8 +745,8 @@ func TestIPv4FragmentationReceive(t *testing.T) {
TTL: 20,
Protocol: 10,
FragmentOffset: 24,
- SrcAddr: remoteIpv4Addr,
- DstAddr: localIpv4Addr,
+ SrcAddr: remoteIPv4Addr,
+ DstAddr: localIPv4Addr,
})
// Make payload be non-zero.
for i := header.IPv4MinimumSize; i < totalLen; i++ {
@@ -462,12 +754,12 @@ func TestIPv4FragmentationReceive(t *testing.T) {
}
// Give packet to ipv4 endpoint, dispatcher will validate that it's ok.
- o.protocol = 10
- o.srcAddr = remoteIpv4Addr
- o.dstAddr = localIpv4Addr
- o.contents = append(frag1[header.IPv4MinimumSize:totalLen], frag2[header.IPv4MinimumSize:totalLen]...)
+ nic.testObject.protocol = 10
+ nic.testObject.srcAddr = remoteIPv4Addr
+ nic.testObject.dstAddr = localIPv4Addr
+ nic.testObject.contents = append(frag1[header.IPv4MinimumSize:totalLen], frag2[header.IPv4MinimumSize:totalLen]...)
- r, err := buildIPv4Route(localIpv4Addr, remoteIpv4Addr)
+ r, err := buildIPv4Route(localIPv4Addr, remoteIPv4Addr)
if err != nil {
t.Fatalf("could not find route: %v", err)
}
@@ -480,8 +772,8 @@ func TestIPv4FragmentationReceive(t *testing.T) {
t.Fatalf("failed to parse packet: %x", pkt.Data.ToView())
}
ep.HandlePacket(&r, pkt)
- if o.dataCalls != 0 {
- t.Fatalf("Bad number of data calls: got %x, want 0", o.dataCalls)
+ if nic.testObject.dataCalls != 0 {
+ t.Fatalf("Bad number of data calls: got %x, want 0", nic.testObject.dataCalls)
}
// Send second segment.
@@ -492,18 +784,26 @@ func TestIPv4FragmentationReceive(t *testing.T) {
t.Fatalf("failed to parse packet: %x", pkt.Data.ToView())
}
ep.HandlePacket(&r, pkt)
- if o.dataCalls != 1 {
- t.Fatalf("Bad number of data calls: got %x, want 1", o.dataCalls)
+ if nic.testObject.dataCalls != 1 {
+ t.Fatalf("Bad number of data calls: got %x, want 1", nic.testObject.dataCalls)
}
}
func TestIPv6Send(t *testing.T) {
- o := testObject{t: t}
s := buildDummyStack(t)
proto := s.NetworkProtocolInstance(ipv6.ProtocolNumber)
- ep := proto.NewEndpoint(nicID, nil, nil, &o, channel.New(0, 1280, ""), s)
+ nic := testInterface{
+ testObject: testObject{
+ t: t,
+ },
+ }
+ ep := proto.NewEndpoint(&nic, nil, nil, nil)
defer ep.Close()
+ if err := ep.Enable(); err != nil {
+ t.Fatalf("ep.Enable(): %s", err)
+ }
+
// Allocate and initialize the payload view.
payload := buffer.NewView(100)
for i := 0; i < len(payload); i++ {
@@ -517,12 +817,12 @@ func TestIPv6Send(t *testing.T) {
})
// Issue the write.
- o.protocol = 123
- o.srcAddr = localIpv6Addr
- o.dstAddr = remoteIpv6Addr
- o.contents = payload
+ nic.testObject.protocol = 123
+ nic.testObject.srcAddr = localIPv6Addr
+ nic.testObject.dstAddr = remoteIPv6Addr
+ nic.testObject.contents = payload
- r, err := buildIPv6Route(localIpv6Addr, remoteIpv6Addr)
+ r, err := buildIPv6Route(localIPv6Addr, remoteIPv6Addr)
if err != nil {
t.Fatalf("could not find route: %v", err)
}
@@ -536,12 +836,20 @@ func TestIPv6Send(t *testing.T) {
}
func TestIPv6Receive(t *testing.T) {
- o := testObject{t: t}
s := buildDummyStack(t)
proto := s.NetworkProtocolInstance(ipv6.ProtocolNumber)
- ep := proto.NewEndpoint(nicID, nil, nil, &o, nil, s)
+ nic := testInterface{
+ testObject: testObject{
+ t: t,
+ },
+ }
+ ep := proto.NewEndpoint(&nic, nil, nil, &nic.testObject)
defer ep.Close()
+ if err := ep.Enable(); err != nil {
+ t.Fatalf("ep.Enable(): %s", err)
+ }
+
totalLen := header.IPv6MinimumSize + 30
view := buffer.NewView(totalLen)
ip := header.IPv6(view)
@@ -549,8 +857,8 @@ func TestIPv6Receive(t *testing.T) {
PayloadLength: uint16(totalLen - header.IPv6MinimumSize),
NextHeader: 10,
HopLimit: 20,
- SrcAddr: remoteIpv6Addr,
- DstAddr: localIpv6Addr,
+ SrcAddr: remoteIPv6Addr,
+ DstAddr: localIPv6Addr,
})
// Make payload be non-zero.
@@ -559,12 +867,12 @@ func TestIPv6Receive(t *testing.T) {
}
// Give packet to ipv6 endpoint, dispatcher will validate that it's ok.
- o.protocol = 10
- o.srcAddr = remoteIpv6Addr
- o.dstAddr = localIpv6Addr
- o.contents = view[header.IPv6MinimumSize:totalLen]
+ nic.testObject.protocol = 10
+ nic.testObject.srcAddr = remoteIPv6Addr
+ nic.testObject.dstAddr = localIPv6Addr
+ nic.testObject.contents = view[header.IPv6MinimumSize:totalLen]
- r, err := buildIPv6Route(localIpv6Addr, remoteIpv6Addr)
+ r, err := buildIPv6Route(localIPv6Addr, remoteIPv6Addr)
if err != nil {
t.Fatalf("could not find route: %v", err)
}
@@ -576,8 +884,8 @@ func TestIPv6Receive(t *testing.T) {
t.Fatalf("failed to parse packet: %x", pkt.Data.ToView())
}
ep.HandlePacket(&r, pkt)
- if o.dataCalls != 1 {
- t.Fatalf("Bad number of data calls: got %x, want 1", o.dataCalls)
+ if nic.testObject.dataCalls != 1 {
+ t.Fatalf("Bad number of data calls: got %x, want 1", nic.testObject.dataCalls)
}
}
@@ -608,7 +916,7 @@ func TestIPv6ReceiveControl(t *testing.T) {
{"Zero-length packet", 0, nil, header.ICMPv6DstUnreachable, header.ICMPv6PortUnreachable, stack.ControlPortUnreachable, 0, 2*header.IPv6MinimumSize + header.ICMPv6DstUnreachableMinimumSize + 8},
}
r, err := buildIPv6Route(
- localIpv6Addr,
+ localIPv6Addr,
"\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xaa",
)
if err != nil {
@@ -616,12 +924,20 @@ func TestIPv6ReceiveControl(t *testing.T) {
}
for _, c := range cases {
t.Run(c.name, func(t *testing.T) {
- o := testObject{t: t}
s := buildDummyStack(t)
proto := s.NetworkProtocolInstance(ipv6.ProtocolNumber)
- ep := proto.NewEndpoint(nicID, nil, nil, &o, nil, s)
+ nic := testInterface{
+ testObject: testObject{
+ t: t,
+ },
+ }
+ ep := proto.NewEndpoint(&nic, nil, nil, &nic.testObject)
defer ep.Close()
+ if err := ep.Enable(); err != nil {
+ t.Fatalf("ep.Enable(): %s", err)
+ }
+
dataOffset := header.IPv6MinimumSize*2 + header.ICMPv6MinimumSize
if c.fragmentOffset != nil {
dataOffset += header.IPv6FragmentHeaderSize
@@ -635,7 +951,7 @@ func TestIPv6ReceiveControl(t *testing.T) {
NextHeader: uint8(header.ICMPv6ProtocolNumber),
HopLimit: 20,
SrcAddr: outerSrcAddr,
- DstAddr: localIpv6Addr,
+ DstAddr: localIPv6Addr,
})
// Create the ICMP header.
@@ -651,8 +967,8 @@ func TestIPv6ReceiveControl(t *testing.T) {
PayloadLength: 100,
NextHeader: 10,
HopLimit: 20,
- SrcAddr: localIpv6Addr,
- DstAddr: remoteIpv6Addr,
+ SrcAddr: localIPv6Addr,
+ DstAddr: remoteIPv6Addr,
})
// Build the fragmentation header if needed.
@@ -674,19 +990,19 @@ func TestIPv6ReceiveControl(t *testing.T) {
// Give packet to IPv6 endpoint, dispatcher will validate that
// it's ok.
- o.protocol = 10
- o.srcAddr = remoteIpv6Addr
- o.dstAddr = localIpv6Addr
- o.contents = view[dataOffset:]
- o.typ = c.expectedTyp
- o.extra = c.expectedExtra
+ nic.testObject.protocol = 10
+ nic.testObject.srcAddr = remoteIPv6Addr
+ nic.testObject.dstAddr = localIPv6Addr
+ nic.testObject.contents = view[dataOffset:]
+ nic.testObject.typ = c.expectedTyp
+ nic.testObject.extra = c.expectedExtra
// Set ICMPv6 checksum.
- icmp.SetChecksum(header.ICMPv6Checksum(icmp, outerSrcAddr, localIpv6Addr, buffer.VectorisedView{}))
+ icmp.SetChecksum(header.ICMPv6Checksum(icmp, outerSrcAddr, localIPv6Addr, buffer.VectorisedView{}))
ep.HandlePacket(&r, truncatedPacket(view, c.trunc, header.IPv6MinimumSize))
- if want := c.expectedCount; o.controlCalls != want {
- t.Fatalf("Bad number of control calls for %q case: got %v, want %v", c.name, o.controlCalls, want)
+ if want := c.expectedCount; nic.testObject.controlCalls != want {
+ t.Fatalf("Bad number of control calls for %q case: got %v, want %v", c.name, nic.testObject.controlCalls, want)
}
})
}
diff --git a/pkg/tcpip/network/ipv4/BUILD b/pkg/tcpip/network/ipv4/BUILD
index f9c2aa980..7fc12e229 100644
--- a/pkg/tcpip/network/ipv4/BUILD
+++ b/pkg/tcpip/network/ipv4/BUILD
@@ -10,6 +10,7 @@ go_library(
],
visibility = ["//visibility:public"],
deps = [
+ "//pkg/sync",
"//pkg/tcpip",
"//pkg/tcpip/buffer",
"//pkg/tcpip/header",
@@ -27,12 +28,15 @@ go_test(
deps = [
"//pkg/tcpip",
"//pkg/tcpip/buffer",
+ "//pkg/tcpip/checker",
"//pkg/tcpip/header",
"//pkg/tcpip/link/channel",
"//pkg/tcpip/link/sniffer",
+ "//pkg/tcpip/network/arp",
"//pkg/tcpip/network/ipv4",
"//pkg/tcpip/network/testutil",
"//pkg/tcpip/stack",
+ "//pkg/tcpip/transport/icmp",
"//pkg/tcpip/transport/tcp",
"//pkg/tcpip/transport/udp",
"//pkg/waiter",
diff --git a/pkg/tcpip/network/ipv4/icmp.go b/pkg/tcpip/network/ipv4/icmp.go
index 3e5cf2ad9..3407755ed 100644
--- a/pkg/tcpip/network/ipv4/icmp.go
+++ b/pkg/tcpip/network/ipv4/icmp.go
@@ -15,6 +15,8 @@
package ipv4
import (
+ "fmt"
+
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
@@ -40,7 +42,7 @@ func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, pkt *stack
// Drop packet if it doesn't have the basic IPv4 header or if the
// original source address doesn't match an address we own.
src := hdr.SourceAddress()
- if e.stack.CheckLocalAddress(e.NICID(), ProtocolNumber, src) == 0 {
+ if e.protocol.stack.CheckLocalAddress(e.nic.ID(), ProtocolNumber, src) == 0 {
return
}
@@ -77,31 +79,29 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer) {
received.Echo.Increment()
// Only send a reply if the checksum is valid.
- wantChecksum := h.Checksum()
- // Reset the checksum field to 0 to can calculate the proper
- // checksum. We'll have to reset this before we hand the packet
- // off.
+ headerChecksum := h.Checksum()
h.SetChecksum(0)
- gotChecksum := ^header.ChecksumVV(pkt.Data, 0 /* initial */)
- if gotChecksum != wantChecksum {
- // It's possible that a raw socket expects to receive this.
- h.SetChecksum(wantChecksum)
+ calculatedChecksum := ^header.ChecksumVV(pkt.Data, 0 /* initial */)
+ h.SetChecksum(headerChecksum)
+ if calculatedChecksum != headerChecksum {
+ // It's possible that a raw socket still expects to receive this.
e.dispatcher.DeliverTransportPacket(r, header.ICMPv4ProtocolNumber, pkt)
received.Invalid.Increment()
return
}
- // Make a copy of data before pkt gets sent to raw socket.
- // DeliverTransportPacket will take ownership of pkt.
- replyData := pkt.Data.Clone(nil)
- replyData.TrimFront(header.ICMPv4MinimumSize)
+ // DeliverTransportPacket will take ownership of pkt so don't use it beyond
+ // this point. Make a deep copy of the data before pkt gets sent as we will
+ // be modifying fields.
+ //
+ // TODO(gvisor.dev/issue/4399): The copy may not be needed if there are no
+ // waiting endpoints. Consider moving responsibility for doing the copy to
+ // DeliverTransportPacket so that is is only done when needed.
+ replyData := pkt.Data.ToOwnedView()
+ replyIPHdr := header.IPv4(append(buffer.View(nil), pkt.NetworkHeader().View()...))
- // It's possible that a raw socket expects to receive this.
- h.SetChecksum(wantChecksum)
e.dispatcher.DeliverTransportPacket(r, header.ICMPv4ProtocolNumber, pkt)
- remoteLinkAddr := r.RemoteLinkAddress
-
// As per RFC 1122 section 3.2.1.3, when a host sends any datagram, the IP
// source address MUST be one of its own IP addresses (but not a broadcast
// or multicast address).
@@ -110,39 +110,56 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer) {
localAddr = ""
}
- r, err := r.Stack().FindRoute(e.NICID(), localAddr, r.RemoteAddress, ProtocolNumber, false /* multicastLoop */)
+ r, err := r.Stack().FindRoute(e.nic.ID(), localAddr, r.RemoteAddress, ProtocolNumber, false /* multicastLoop */)
if err != nil {
// If we cannot find a route to the destination, silently drop the packet.
return
}
defer r.Release()
- // Use the remote link address from the incoming packet.
- r.ResolveWith(remoteLinkAddr)
-
- // Prepare a reply packet.
- icmpHdr := make(header.ICMPv4, header.ICMPv4MinimumSize)
- copy(icmpHdr, h)
- icmpHdr.SetType(header.ICMPv4EchoReply)
- icmpHdr.SetChecksum(0)
- icmpHdr.SetChecksum(^header.Checksum(icmpHdr, header.ChecksumVV(replyData, 0)))
- dataVV := buffer.View(icmpHdr).ToVectorisedView()
- dataVV.Append(replyData)
+ // TODO(gvisor.dev/issue/3810:) When adding protocol numbers into the
+ // header information, we may have to change this code to handle the
+ // ICMP header no longer being in the data buffer.
+
+ // Because IP and ICMP are so closely intertwined, we need to handcraft our
+ // IP header to be able to follow RFC 792. The wording on page 13 is as
+ // follows:
+ // IP Fields:
+ // Addresses
+ // The address of the source in an echo message will be the
+ // destination of the echo reply message. To form an echo reply
+ // message, the source and destination addresses are simply reversed,
+ // the type code changed to 0, and the checksum recomputed.
+ //
+ // This was interpreted by early implementors to mean that all options must
+ // be copied from the echo request IP header to the echo reply IP header
+ // and this behaviour is still relied upon by some applications.
+ //
+ // Create a copy of the IP header we received, options and all, and change
+ // The fields we need to alter.
+ //
+ // We need to produce the entire packet in the data segment in order to
+ // use WriteHeaderIncludedPacket().
+ replyIPHdr.SetSourceAddress(r.LocalAddress)
+ replyIPHdr.SetDestinationAddress(r.RemoteAddress)
+ replyIPHdr.SetTTL(r.DefaultTTL())
+
+ replyICMPHdr := header.ICMPv4(replyData)
+ replyICMPHdr.SetType(header.ICMPv4EchoReply)
+ replyICMPHdr.SetChecksum(0)
+ replyICMPHdr.SetChecksum(^header.Checksum(replyData, 0))
+
+ replyVV := buffer.View(replyIPHdr).ToVectorisedView()
+ replyVV.AppendView(replyData)
replyPkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
ReserveHeaderBytes: int(r.MaxHeaderLength()),
- Data: dataVV,
+ Data: replyVV,
})
- // TODO(gvisor.dev/issue/3810): When adding protocol numbers into the header
- // information we will have to change this code to handle the ICMP header
- // no longer being in the data buffer.
replyPkt.TransportProtocolNumber = header.ICMPv4ProtocolNumber
- // Send out the reply packet.
+
+ // The checksum will be calculated so we don't need to do it here.
sent := stats.ICMP.V4PacketsSent
- if err := r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{
- Protocol: header.ICMPv4ProtocolNumber,
- TTL: r.DefaultTTL(),
- TOS: stack.DefaultTOS,
- }, replyPkt); err != nil {
+ if err := r.WriteHeaderIncludedPacket(replyPkt); err != nil {
sent.Dropped.Increment()
return
}
@@ -211,18 +228,18 @@ type icmpReasonPortUnreachable struct{}
func (*icmpReasonPortUnreachable) isICMPReason() {}
+// icmpReasonProtoUnreachable is an error where the transport protocol is
+// not supported.
+type icmpReasonProtoUnreachable struct{}
+
+func (*icmpReasonProtoUnreachable) isICMPReason() {}
+
// returnError takes an error descriptor and generates the appropriate ICMP
// error packet for IPv4 and sends it back to the remote device that sent
// the problematic packet. It incorporates as much of that packet as
// possible as well as any error metadata as is available. returnError
// expects pkt to hold a valid IPv4 packet as per the wire format.
-func returnError(r *stack.Route, reason icmpReason, pkt *stack.PacketBuffer) *tcpip.Error {
- sent := r.Stats().ICMP.V4PacketsSent
- if !r.Stack().AllowICMPMessage() {
- sent.RateLimited.Increment()
- return nil
- }
-
+func (p *protocol) returnError(r *stack.Route, reason icmpReason, pkt *stack.PacketBuffer) *tcpip.Error {
// We check we are responding only when we are allowed to.
// See RFC 1812 section 4.3.2.7 (shown below).
//
@@ -251,6 +268,25 @@ func returnError(r *stack.Route, reason icmpReason, pkt *stack.PacketBuffer) *tc
return nil
}
+ // Even if we were able to receive a packet from some remote, we may not have
+ // a route to it - the remote may be blocked via routing rules. We must always
+ // consult our routing table and find a route to the remote before sending any
+ // packet.
+ route, err := p.stack.FindRoute(r.NICID(), r.LocalAddress, r.RemoteAddress, ProtocolNumber, false /* multicastLoop */)
+ if err != nil {
+ return err
+ }
+ defer route.Release()
+ // From this point on, the incoming route should no longer be used; route
+ // must be used to send the ICMP error.
+ r = nil
+
+ sent := p.stack.Stats().ICMP.V4PacketsSent
+ if !p.stack.AllowICMPMessage() {
+ sent.RateLimited.Increment()
+ return nil
+ }
+
networkHeader := pkt.NetworkHeader().View()
transportHeader := pkt.TransportHeader().View()
@@ -287,8 +323,6 @@ func returnError(r *stack.Route, reason icmpReason, pkt *stack.PacketBuffer) *tc
// Assume any type we don't know about may be an error type.
return nil
}
- } else if transportHeader.IsEmpty() {
- return nil
}
// Now work out how much of the triggering packet we should return.
@@ -303,11 +337,11 @@ func returnError(r *stack.Route, reason icmpReason, pkt *stack.PacketBuffer) *tc
// least 8 bytes of the payload must be included. Today linux and other
// systems implement the RFC 1812 definition and not the original
// requirement. We treat 8 bytes as the minimum but will try send more.
- mtu := int(r.MTU())
+ mtu := int(route.MTU())
if mtu > header.IPv4MinimumProcessableDatagramSize {
mtu = header.IPv4MinimumProcessableDatagramSize
}
- headerLen := int(r.MaxHeaderLength()) + header.ICMPv4MinimumSize
+ headerLen := int(route.MaxHeaderLength()) + header.ICMPv4MinimumSize
available := int(mtu) - headerLen
if available < header.IPv4MinimumSize+header.ICMPv4MinimumErrorPayloadSize {
@@ -336,19 +370,27 @@ func returnError(r *stack.Route, reason icmpReason, pkt *stack.PacketBuffer) *tc
ReserveHeaderBytes: headerLen,
Data: payload,
})
+
icmpPkt.TransportProtocolNumber = header.ICMPv4ProtocolNumber
icmpHdr := header.ICMPv4(icmpPkt.TransportHeader().Push(header.ICMPv4MinimumSize))
+ switch reason.(type) {
+ case *icmpReasonPortUnreachable:
+ icmpHdr.SetCode(header.ICMPv4PortUnreachable)
+ case *icmpReasonProtoUnreachable:
+ icmpHdr.SetCode(header.ICMPv4ProtoUnreachable)
+ default:
+ panic(fmt.Sprintf("unsupported ICMP type %T", reason))
+ }
icmpHdr.SetType(header.ICMPv4DstUnreachable)
- icmpHdr.SetCode(header.ICMPv4PortUnreachable)
- counter := sent.DstUnreachable
icmpHdr.SetChecksum(header.ICMPv4Checksum(icmpHdr, icmpPkt.Data))
+ counter := sent.DstUnreachable
- if err := r.WritePacket(
+ if err := route.WritePacket(
nil, /* gso */
stack.NetworkHeaderParams{
Protocol: header.ICMPv4ProtocolNumber,
- TTL: r.DefaultTTL(),
+ TTL: route.DefaultTTL(),
TOS: stack.DefaultTOS,
},
icmpPkt,
diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go
index 254d66147..c5ac7b8b5 100644
--- a/pkg/tcpip/network/ipv4/ipv4.go
+++ b/pkg/tcpip/network/ipv4/ipv4.go
@@ -18,7 +18,9 @@ package ipv4
import (
"fmt"
"sync/atomic"
+ "time"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
@@ -29,6 +31,15 @@ import (
)
const (
+ // As per RFC 791 section 3.2:
+ // The current recommendation for the initial timer setting is 15 seconds.
+ // This may be changed as experience with this protocol accumulates.
+ //
+ // Considering that it is an old recommendation, we use the same reassembly
+ // timeout that linux defines, which is 30 seconds:
+ // https://github.com/torvalds/linux/blob/47ec5303d73ea344e84f46660fff693c57641386/include/net/ip.h#L138
+ reassembleTimeout = 30 * time.Second
+
// ProtocolNumber is the ipv4 protocol number.
ProtocolNumber = header.IPv4ProtocolNumber
@@ -47,22 +58,113 @@ const (
fragmentblockSize = 8
)
+var ipv4BroadcastAddr = header.IPv4Broadcast.WithPrefix()
+
+var _ stack.GroupAddressableEndpoint = (*endpoint)(nil)
+var _ stack.AddressableEndpoint = (*endpoint)(nil)
+var _ stack.NetworkEndpoint = (*endpoint)(nil)
+
type endpoint struct {
- nicID tcpip.NICID
- linkEP stack.LinkEndpoint
+ nic stack.NetworkInterface
dispatcher stack.TransportDispatcher
protocol *protocol
- stack *stack.Stack
+
+ // enabled is set to 1 when the enpoint is enabled and 0 when it is
+ // disabled.
+ //
+ // Must be accessed using atomic operations.
+ enabled uint32
+
+ mu struct {
+ sync.RWMutex
+
+ addressableEndpointState stack.AddressableEndpointState
+ }
}
// NewEndpoint creates a new ipv4 endpoint.
-func (p *protocol) NewEndpoint(nicID tcpip.NICID, _ stack.LinkAddressCache, _ stack.NUDHandler, dispatcher stack.TransportDispatcher, linkEP stack.LinkEndpoint, st *stack.Stack) stack.NetworkEndpoint {
- return &endpoint{
- nicID: nicID,
- linkEP: linkEP,
+func (p *protocol) NewEndpoint(nic stack.NetworkInterface, _ stack.LinkAddressCache, _ stack.NUDHandler, dispatcher stack.TransportDispatcher) stack.NetworkEndpoint {
+ e := &endpoint{
+ nic: nic,
dispatcher: dispatcher,
protocol: p,
- stack: st,
+ }
+ e.mu.addressableEndpointState.Init(e)
+ return e
+}
+
+// Enable implements stack.NetworkEndpoint.
+func (e *endpoint) Enable() *tcpip.Error {
+ e.mu.Lock()
+ defer e.mu.Unlock()
+
+ // If the NIC is not enabled, the endpoint can't do anything meaningful so
+ // don't enable the endpoint.
+ if !e.nic.Enabled() {
+ return tcpip.ErrNotPermitted
+ }
+
+ // If the endpoint is already enabled, there is nothing for it to do.
+ if !e.setEnabled(true) {
+ return nil
+ }
+
+ // Create an endpoint to receive broadcast packets on this interface.
+ ep, err := e.mu.addressableEndpointState.AddAndAcquirePermanentAddress(ipv4BroadcastAddr, stack.NeverPrimaryEndpoint, stack.AddressConfigStatic, false /* deprecated */)
+ if err != nil {
+ return err
+ }
+ // We have no need for the address endpoint.
+ ep.DecRef()
+
+ // As per RFC 1122 section 3.3.7, all hosts should join the all-hosts
+ // multicast group. Note, the IANA calls the all-hosts multicast group the
+ // all-systems multicast group.
+ _, err = e.mu.addressableEndpointState.JoinGroup(header.IPv4AllSystems)
+ return err
+}
+
+// Enabled implements stack.NetworkEndpoint.
+func (e *endpoint) Enabled() bool {
+ return e.nic.Enabled() && e.isEnabled()
+}
+
+// isEnabled returns true if the endpoint is enabled, regardless of the
+// enabled status of the NIC.
+func (e *endpoint) isEnabled() bool {
+ return atomic.LoadUint32(&e.enabled) == 1
+}
+
+// setEnabled sets the enabled status for the endpoint.
+//
+// Returns true if the enabled status was updated.
+func (e *endpoint) setEnabled(v bool) bool {
+ if v {
+ return atomic.SwapUint32(&e.enabled, 1) == 0
+ }
+ return atomic.SwapUint32(&e.enabled, 0) == 1
+}
+
+// Disable implements stack.NetworkEndpoint.
+func (e *endpoint) Disable() {
+ e.mu.Lock()
+ defer e.mu.Unlock()
+ e.disableLocked()
+}
+
+func (e *endpoint) disableLocked() {
+ if !e.setEnabled(false) {
+ return
+ }
+
+ // The endpoint may have already left the multicast group.
+ if _, err := e.mu.addressableEndpointState.LeaveGroup(header.IPv4AllSystems); err != nil && err != tcpip.ErrBadLocalAddress {
+ panic(fmt.Sprintf("unexpected error when leaving group = %s: %s", header.IPv4AllSystems, err))
+ }
+
+ // The address may have already been removed.
+ if err := e.mu.addressableEndpointState.RemovePermanentAddress(ipv4BroadcastAddr.Address); err != nil && err != tcpip.ErrBadLocalAddress {
+ panic(fmt.Sprintf("unexpected error when removing address = %s: %s", ipv4BroadcastAddr.Address, err))
}
}
@@ -74,31 +176,13 @@ func (e *endpoint) DefaultTTL() uint8 {
// MTU implements stack.NetworkEndpoint.MTU. It returns the link-layer MTU minus
// the network layer max header length.
func (e *endpoint) MTU() uint32 {
- return calculateMTU(e.linkEP.MTU())
-}
-
-// Capabilities implements stack.NetworkEndpoint.Capabilities.
-func (e *endpoint) Capabilities() stack.LinkEndpointCapabilities {
- return e.linkEP.Capabilities()
-}
-
-// NICID returns the ID of the NIC this endpoint belongs to.
-func (e *endpoint) NICID() tcpip.NICID {
- return e.nicID
+ return calculateMTU(e.nic.MTU())
}
// MaxHeaderLength returns the maximum length needed by ipv4 headers (and
// underlying protocols).
func (e *endpoint) MaxHeaderLength() uint16 {
- return e.linkEP.MaxHeaderLength() + header.IPv4MinimumSize
-}
-
-// GSOMaxSize returns the maximum GSO packet size.
-func (e *endpoint) GSOMaxSize() uint32 {
- if gso, ok := e.linkEP.(stack.GSOEndpoint); ok {
- return gso.GSOMaxSize()
- }
- return 0
+ return e.nic.MaxHeaderLength() + header.IPv4MaximumHeaderSize
}
// NetworkProtocolNumber implements stack.NetworkEndpoint.NetworkProtocolNumber.
@@ -106,98 +190,26 @@ func (e *endpoint) NetworkProtocolNumber() tcpip.NetworkProtocolNumber {
return e.protocol.Number()
}
-// writePacketFragments calls e.linkEP.WritePacket with each packet fragment to
-// write. It assumes that the IP header is already present in pkt.NetworkHeader.
-// pkt.TransportHeader may be set. mtu includes the IP header and options. This
-// does not support the DontFragment IP flag.
-func (e *endpoint) writePacketFragments(r *stack.Route, gso *stack.GSO, mtu int, pkt *stack.PacketBuffer) *tcpip.Error {
- // This packet is too big, it needs to be fragmented.
- ip := header.IPv4(pkt.NetworkHeader().View())
- flags := ip.Flags()
-
- // Update mtu to take into account the header, which will exist in all
- // fragments anyway.
- innerMTU := mtu - int(ip.HeaderLength())
-
- // Round the MTU down to align to 8 bytes. Then calculate the number of
- // fragments. Calculate fragment sizes as in RFC791.
- innerMTU &^= 7
- n := (int(ip.PayloadLength()) + innerMTU - 1) / innerMTU
-
- outerMTU := innerMTU + int(ip.HeaderLength())
- offset := ip.FragmentOffset()
-
- // Keep the length reserved for link-layer, we need to create fragments with
- // the same reserved length.
- reservedForLink := pkt.AvailableHeaderBytes()
-
- // Destroy the packet, pull all payloads out for fragmentation.
- transHeader, data := pkt.TransportHeader().View(), pkt.Data
-
- // Where possible, the first fragment that is sent has the same
- // number of bytes reserved for header as the input packet. The link-layer
- // endpoint may depend on this for looking at, eg, L4 headers.
- transFitsFirst := len(transHeader) <= innerMTU
-
- for i := 0; i < n; i++ {
- reserve := reservedForLink + int(ip.HeaderLength())
- if i == 0 && transFitsFirst {
- // Reserve for transport header if it's going to be put in the first
- // fragment.
- reserve += len(transHeader)
- }
- fragPkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
- ReserveHeaderBytes: reserve,
- })
- fragPkt.NetworkProtocolNumber = header.IPv4ProtocolNumber
-
- // Copy data for the fragment.
- avail := innerMTU
-
- if n := len(transHeader); n > 0 {
- if n > avail {
- n = avail
- }
- if i == 0 && transFitsFirst {
- copy(fragPkt.TransportHeader().Push(n), transHeader)
- } else {
- fragPkt.Data.AppendView(transHeader[:n:n])
- }
- transHeader = transHeader[n:]
- avail -= n
- }
-
- if avail > 0 {
- n := data.Size()
- if n > avail {
- n = avail
- }
- data.ReadToVV(&fragPkt.Data, n)
- avail -= n
- }
-
- copied := uint16(innerMTU - avail)
-
- // Set lengths in header and calculate checksum.
- h := header.IPv4(fragPkt.NetworkHeader().Push(len(ip)))
- copy(h, ip)
- if i != n-1 {
- h.SetTotalLength(uint16(outerMTU))
- h.SetFlagsFragmentOffset(flags|header.IPv4FlagMoreFragments, offset)
- } else {
- h.SetTotalLength(uint16(h.HeaderLength()) + copied)
- h.SetFlagsFragmentOffset(flags, offset)
- }
- h.SetChecksum(0)
- h.SetChecksum(^h.CalculateChecksum())
- offset += copied
-
- // Send out the fragment.
- if err := e.linkEP.WritePacket(r, gso, ProtocolNumber, fragPkt); err != nil {
+// writePacketFragments fragments pkt and writes the results on the link
+// endpoint. The IP header must already present in the original packet. The mtu
+// is the maximum size of the packets.
+func (e *endpoint) writePacketFragments(r *stack.Route, gso *stack.GSO, mtu uint32, pkt *stack.PacketBuffer) *tcpip.Error {
+ networkHeader := header.IPv4(pkt.NetworkHeader().View())
+ fragMTU := int(calculateFragmentInnerMTU(mtu, pkt))
+ pf := fragmentation.MakePacketFragmenter(pkt, fragMTU, pkt.AvailableHeaderBytes()+len(networkHeader))
+
+ for {
+ fragPkt, more := buildNextFragment(&pf, networkHeader)
+ if err := e.nic.WritePacket(r, gso, ProtocolNumber, fragPkt); err != nil {
+ r.Stats().IP.OutgoingPacketErrors.IncrementBy(uint64(pf.RemainingFragmentCount() + 1))
return err
}
r.Stats().IP.PacketsSent.Increment()
+ if !more {
+ break
+ }
}
+
return nil
}
@@ -219,7 +231,7 @@ func (e *endpoint) addIPHeader(r *stack.Route, pkt *stack.PacketBuffer, params s
DstAddr: r.RemoteAddress,
})
ip.SetChecksum(^ip.CalculateChecksum())
- pkt.NetworkProtocolNumber = header.IPv4ProtocolNumber
+ pkt.NetworkProtocolNumber = ProtocolNumber
}
// WritePacket writes a packet to the given destination address and protocol.
@@ -228,8 +240,8 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.Netw
// iptables filtering. All packets that reach here are locally
// generated.
- nicName := e.stack.FindNICNameFromID(e.NICID())
- ipt := e.stack.IPTables()
+ nicName := e.protocol.stack.FindNICNameFromID(e.nic.ID())
+ ipt := e.protocol.stack.IPTables()
if ok := ipt.Check(stack.Output, pkt, gso, r, "", nicName); !ok {
// iptables is telling us to drop the packet.
r.Stats().IP.IPTablesOutputDropped.Increment()
@@ -245,7 +257,7 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.Netw
// short circuits broadcasts before they are sent out to other hosts.
if pkt.NatDone {
netHeader := header.IPv4(pkt.NetworkHeader().View())
- ep, err := e.stack.FindNetworkEndpoint(header.IPv4ProtocolNumber, netHeader.DestinationAddress())
+ ep, err := e.protocol.stack.FindNetworkEndpoint(ProtocolNumber, netHeader.DestinationAddress())
if err == nil {
route := r.ReverseRoute(netHeader.SourceAddress(), netHeader.DestinationAddress())
ep.HandlePacket(&route, pkt)
@@ -261,10 +273,11 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.Netw
if r.Loop&stack.PacketOut == 0 {
return nil
}
- if pkt.Size() > int(e.linkEP.MTU()) && (gso == nil || gso.Type == stack.GSONone) {
- return e.writePacketFragments(r, gso, int(e.linkEP.MTU()), pkt)
+ if pkt.Size() > int(e.nic.MTU()) && (gso == nil || gso.Type == stack.GSONone) {
+ return e.writePacketFragments(r, gso, e.nic.MTU(), pkt)
}
- if err := e.linkEP.WritePacket(r, gso, ProtocolNumber, pkt); err != nil {
+ if err := e.nic.WritePacket(r, gso, ProtocolNumber, pkt); err != nil {
+ r.Stats().IP.OutgoingPacketErrors.Increment()
return err
}
r.Stats().IP.PacketsSent.Increment()
@@ -285,16 +298,19 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe
pkt = pkt.Next()
}
- nicName := e.stack.FindNICNameFromID(e.NICID())
+ nicName := e.protocol.stack.FindNICNameFromID(e.nic.ID())
// iptables filtering. All packets that reach here are locally
// generated.
- ipt := e.stack.IPTables()
+ ipt := e.protocol.stack.IPTables()
dropped, natPkts := ipt.CheckPackets(stack.Output, pkts, gso, r, nicName)
if len(dropped) == 0 && len(natPkts) == 0 {
// Fast path: If no packets are to be dropped then we can just invoke the
// faster WritePackets API directly.
- n, err := e.linkEP.WritePackets(r, gso, pkts, ProtocolNumber)
+ n, err := e.nic.WritePackets(r, gso, pkts, ProtocolNumber)
r.Stats().IP.PacketsSent.IncrementBy(uint64(n))
+ if err != nil {
+ r.Stats().IP.OutgoingPacketErrors.IncrementBy(uint64(pkts.Len() - n))
+ }
return n, err
}
r.Stats().IP.IPTablesOutputDropped.IncrementBy(uint64(len(dropped)))
@@ -308,7 +324,7 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe
}
if _, ok := natPkts[pkt]; ok {
netHeader := header.IPv4(pkt.NetworkHeader().View())
- if ep, err := e.stack.FindNetworkEndpoint(header.IPv4ProtocolNumber, netHeader.DestinationAddress()); err == nil {
+ if ep, err := e.protocol.stack.FindNetworkEndpoint(ProtocolNumber, netHeader.DestinationAddress()); err == nil {
src := netHeader.SourceAddress()
dst := netHeader.DestinationAddress()
route := r.ReverseRoute(src, dst)
@@ -317,8 +333,9 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe
continue
}
}
- if err := e.linkEP.WritePacket(r, gso, ProtocolNumber, pkt); err != nil {
+ if err := e.nic.WritePacket(r, gso, ProtocolNumber, pkt); err != nil {
r.Stats().IP.PacketsSent.IncrementBy(uint64(n))
+ r.Stats().IP.OutgoingPacketErrors.IncrementBy(uint64(pkts.Len() - n - len(dropped)))
// Dropped packets aren't errors, so include them in
// the return value.
return n + len(dropped), err
@@ -377,23 +394,39 @@ func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack.PacketBu
return nil
}
+ if err := e.nic.WritePacket(r, nil /* gso */, ProtocolNumber, pkt); err != nil {
+ r.Stats().IP.OutgoingPacketErrors.Increment()
+ return err
+ }
r.Stats().IP.PacketsSent.Increment()
-
- return e.linkEP.WritePacket(r, nil /* gso */, ProtocolNumber, pkt)
+ return nil
}
// HandlePacket is called by the link layer when new ipv4 packets arrive for
// this endpoint.
func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) {
+ if !e.isEnabled() {
+ return
+ }
+
h := header.IPv4(pkt.NetworkHeader().View())
if !h.IsValid(pkt.Data.Size() + pkt.NetworkHeader().View().Size() + pkt.TransportHeader().View().Size()) {
r.Stats().IP.MalformedPacketsReceived.Increment()
return
}
+ // As per RFC 1122 section 3.2.1.3:
+ // When a host sends any datagram, the IP source address MUST
+ // be one of its own IP addresses (but not a broadcast or
+ // multicast address).
+ if r.IsOutboundBroadcast() || header.IsV4MulticastAddress(r.RemoteAddress) {
+ r.Stats().IP.InvalidSourceAddressesReceived.Increment()
+ return
+ }
+
// iptables filtering. All packets that reach here are intended for
// this machine and will not be forwarded.
- ipt := e.stack.IPTables()
+ ipt := e.protocol.stack.IPTables()
if ok := ipt.Check(stack.Input, pkt, nil, nil, "", ""); !ok {
// iptables is telling us to drop the packet.
r.Stats().IP.IPTablesInputDropped.Increment()
@@ -449,6 +482,8 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) {
return
}
}
+
+ r.Stats().IP.PacketsDelivered.Increment()
p := h.TransportProtocol()
if p == header.ICMPv4ProtocolNumber {
// TODO(gvisor.dev/issues/3810): when we sort out ICMP and transport
@@ -458,7 +493,6 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) {
e.handleICMP(r, pkt)
return
}
- r.Stats().IP.PacketsDelivered.Increment()
switch res := e.dispatcher.DeliverTransportPacket(r, p, pkt); res {
case stack.TransportPacketHandled:
@@ -468,24 +502,145 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) {
// 3 (Port Unreachable), when the designated transport protocol
// (e.g., UDP) is unable to demultiplex the datagram but has no
// protocol mechanism to inform the sender.
- _ = returnError(r, &icmpReasonPortUnreachable{}, pkt)
+ _ = e.protocol.returnError(r, &icmpReasonPortUnreachable{}, pkt)
+ case stack.TransportPacketProtocolUnreachable:
+ // As per RFC: 1122 Section 3.2.2.1
+ // A host SHOULD generate Destination Unreachable messages with code:
+ // 2 (Protocol Unreachable), when the designated transport protocol
+ // is not supported
+ _ = e.protocol.returnError(r, &icmpReasonProtoUnreachable{}, pkt)
default:
panic(fmt.Sprintf("unrecognized result from DeliverTransportPacket = %d", res))
}
}
// Close cleans up resources associated with the endpoint.
-func (e *endpoint) Close() {}
+func (e *endpoint) Close() {
+ e.mu.Lock()
+ defer e.mu.Unlock()
+
+ e.disableLocked()
+ e.mu.addressableEndpointState.Cleanup()
+}
+
+// AddAndAcquirePermanentAddress implements stack.AddressableEndpoint.
+func (e *endpoint) AddAndAcquirePermanentAddress(addr tcpip.AddressWithPrefix, peb stack.PrimaryEndpointBehavior, configType stack.AddressConfigType, deprecated bool) (stack.AddressEndpoint, *tcpip.Error) {
+ e.mu.Lock()
+ defer e.mu.Unlock()
+ return e.mu.addressableEndpointState.AddAndAcquirePermanentAddress(addr, peb, configType, deprecated)
+}
+
+// RemovePermanentAddress implements stack.AddressableEndpoint.
+func (e *endpoint) RemovePermanentAddress(addr tcpip.Address) *tcpip.Error {
+ e.mu.Lock()
+ defer e.mu.Unlock()
+ return e.mu.addressableEndpointState.RemovePermanentAddress(addr)
+}
+
+// MainAddress implements stack.AddressableEndpoint.
+func (e *endpoint) MainAddress() tcpip.AddressWithPrefix {
+ e.mu.RLock()
+ defer e.mu.RUnlock()
+ return e.mu.addressableEndpointState.MainAddress()
+}
+
+// AcquireAssignedAddress implements stack.AddressableEndpoint.
+func (e *endpoint) AcquireAssignedAddress(localAddr tcpip.Address, allowTemp bool, tempPEB stack.PrimaryEndpointBehavior) stack.AddressEndpoint {
+ e.mu.Lock()
+ defer e.mu.Unlock()
+
+ loopback := e.nic.IsLoopback()
+ addressEndpoint := e.mu.addressableEndpointState.ReadOnly().AddrOrMatching(localAddr, allowTemp, func(addressEndpoint stack.AddressEndpoint) bool {
+ subnet := addressEndpoint.AddressWithPrefix().Subnet()
+ // IPv4 has a notion of a subnet broadcast address and considers the
+ // loopback interface bound to an address's whole subnet (on linux).
+ return subnet.IsBroadcast(localAddr) || (loopback && subnet.Contains(localAddr))
+ })
+ if addressEndpoint != nil {
+ return addressEndpoint
+ }
+
+ if !allowTemp {
+ return nil
+ }
+
+ addr := localAddr.WithPrefix()
+ addressEndpoint, err := e.mu.addressableEndpointState.AddAndAcquireTemporaryAddress(addr, tempPEB)
+ if err != nil {
+ // AddAddress only returns an error if the address is already assigned,
+ // but we just checked above if the address exists so we expect no error.
+ panic(fmt.Sprintf("e.mu.addressableEndpointState.AddAndAcquireTemporaryAddress(%s, %d): %s", addr, tempPEB, err))
+ }
+ return addressEndpoint
+}
+
+// AcquireOutgoingPrimaryAddress implements stack.AddressableEndpoint.
+func (e *endpoint) AcquireOutgoingPrimaryAddress(remoteAddr tcpip.Address, allowExpired bool) stack.AddressEndpoint {
+ e.mu.RLock()
+ defer e.mu.RUnlock()
+ return e.mu.addressableEndpointState.AcquireOutgoingPrimaryAddress(remoteAddr, allowExpired)
+}
+
+// PrimaryAddresses implements stack.AddressableEndpoint.
+func (e *endpoint) PrimaryAddresses() []tcpip.AddressWithPrefix {
+ e.mu.RLock()
+ defer e.mu.RUnlock()
+ return e.mu.addressableEndpointState.PrimaryAddresses()
+}
+
+// PermanentAddresses implements stack.AddressableEndpoint.
+func (e *endpoint) PermanentAddresses() []tcpip.AddressWithPrefix {
+ e.mu.RLock()
+ defer e.mu.RUnlock()
+ return e.mu.addressableEndpointState.PermanentAddresses()
+}
+
+// JoinGroup implements stack.GroupAddressableEndpoint.
+func (e *endpoint) JoinGroup(addr tcpip.Address) (bool, *tcpip.Error) {
+ if !header.IsV4MulticastAddress(addr) {
+ return false, tcpip.ErrBadAddress
+ }
+
+ e.mu.Lock()
+ defer e.mu.Unlock()
+ return e.mu.addressableEndpointState.JoinGroup(addr)
+}
+
+// LeaveGroup implements stack.GroupAddressableEndpoint.
+func (e *endpoint) LeaveGroup(addr tcpip.Address) (bool, *tcpip.Error) {
+ e.mu.Lock()
+ defer e.mu.Unlock()
+ return e.mu.addressableEndpointState.LeaveGroup(addr)
+}
+
+// IsInGroup implements stack.GroupAddressableEndpoint.
+func (e *endpoint) IsInGroup(addr tcpip.Address) bool {
+ e.mu.RLock()
+ defer e.mu.RUnlock()
+ return e.mu.addressableEndpointState.IsInGroup(addr)
+}
+
+var _ stack.ForwardingNetworkProtocol = (*protocol)(nil)
+var _ stack.NetworkProtocol = (*protocol)(nil)
type protocol struct {
- ids []uint32
- hashIV uint32
+ stack *stack.Stack
// defaultTTL is the current default TTL for the protocol. Only the
- // uint8 portion of it is meaningful and it must be accessed
- // atomically.
+ // uint8 portion of it is meaningful.
+ //
+ // Must be accessed using atomic operations.
defaultTTL uint32
+ // forwarding is set to 1 when the protocol has forwarding enabled and 0
+ // when it is disabled.
+ //
+ // Must be accessed using atomic operations.
+ forwarding uint32
+
+ ids []uint32
+ hashIV uint32
+
fragmentation *fragmentation.Fragmentation
}
@@ -558,6 +713,20 @@ func (*protocol) Parse(pkt *stack.PacketBuffer) (proto tcpip.TransportProtocolNu
return ipHdr.TransportProtocol(), !ipHdr.More() && ipHdr.FragmentOffset() == 0, true
}
+// Forwarding implements stack.ForwardingNetworkProtocol.
+func (p *protocol) Forwarding() bool {
+ return uint8(atomic.LoadUint32(&p.forwarding)) == 1
+}
+
+// SetForwarding implements stack.ForwardingNetworkProtocol.
+func (p *protocol) SetForwarding(v bool) {
+ if v {
+ atomic.StoreUint32(&p.forwarding, 1)
+ } else {
+ atomic.StoreUint32(&p.forwarding, 0)
+ }
+}
+
// calculateMTU calculates the network-layer payload MTU based on the link-layer
// payload mtu.
func calculateMTU(mtu uint32) uint32 {
@@ -567,19 +736,41 @@ func calculateMTU(mtu uint32) uint32 {
return mtu - header.IPv4MinimumSize
}
+// calculateFragmentInnerMTU calculates the maximum number of bytes of
+// fragmentable data a fragment can have, based on the link layer mtu and pkt's
+// network header size.
+func calculateFragmentInnerMTU(mtu uint32, pkt *stack.PacketBuffer) uint32 {
+ if mtu > MaxTotalSize {
+ mtu = MaxTotalSize
+ }
+ mtu -= uint32(pkt.NetworkHeader().View().Size())
+ // Round the MTU down to align to 8 bytes.
+ mtu &^= 7
+ return mtu
+}
+
+// addressToUint32 translates an IPv4 address into its little endian uint32
+// representation.
+//
+// This function does the same thing as binary.LittleEndian.Uint32 but operates
+// on a tcpip.Address (a string) without the need to convert it to a byte slice,
+// which would cause an allocation.
+func addressToUint32(addr tcpip.Address) uint32 {
+ _ = addr[3] // bounds check hint to compiler
+ return uint32(addr[0]) | uint32(addr[1])<<8 | uint32(addr[2])<<16 | uint32(addr[3])<<24
+}
+
// hashRoute calculates a hash value for the given route. It uses the source &
-// destination address, the transport protocol number, and a random initial
-// value (generated once on initialization) to generate the hash.
+// destination address, the transport protocol number and a 32-bit number to
+// generate the hash.
func hashRoute(r *stack.Route, protocol tcpip.TransportProtocolNumber, hashIV uint32) uint32 {
- t := r.LocalAddress
- a := uint32(t[0]) | uint32(t[1])<<8 | uint32(t[2])<<16 | uint32(t[3])<<24
- t = r.RemoteAddress
- b := uint32(t[0]) | uint32(t[1])<<8 | uint32(t[2])<<16 | uint32(t[3])<<24
+ a := addressToUint32(r.LocalAddress)
+ b := addressToUint32(r.RemoteAddress)
return hash.Hash3Words(a, b, uint32(protocol), hashIV)
}
// NewProtocol returns an IPv4 network protocol.
-func NewProtocol(*stack.Stack) stack.NetworkProtocol {
+func NewProtocol(s *stack.Stack) stack.NetworkProtocol {
ids := make([]uint32, buckets)
// Randomly initialize hashIV and the ids.
@@ -590,9 +781,33 @@ func NewProtocol(*stack.Stack) stack.NetworkProtocol {
hashIV := r[buckets]
return &protocol{
+ stack: s,
ids: ids,
hashIV: hashIV,
defaultTTL: DefaultTTL,
- fragmentation: fragmentation.NewFragmentation(fragmentblockSize, fragmentation.HighFragThreshold, fragmentation.LowFragThreshold, fragmentation.DefaultReassembleTimeout),
+ fragmentation: fragmentation.NewFragmentation(fragmentblockSize, fragmentation.HighFragThreshold, fragmentation.LowFragThreshold, reassembleTimeout, s.Clock()),
}
}
+
+func buildNextFragment(pf *fragmentation.PacketFragmenter, originalIPHeader header.IPv4) (*stack.PacketBuffer, bool) {
+ fragPkt, offset, copied, more := pf.BuildNextFragment()
+ fragPkt.NetworkProtocolNumber = ProtocolNumber
+
+ originalIPHeaderLength := len(originalIPHeader)
+ nextFragIPHeader := header.IPv4(fragPkt.NetworkHeader().Push(originalIPHeaderLength))
+
+ if copied := copy(nextFragIPHeader, originalIPHeader); copied != len(originalIPHeader) {
+ panic(fmt.Sprintf("wrong number of bytes copied into fragmentIPHeaders: got = %d, want = %d", copied, originalIPHeaderLength))
+ }
+
+ flags := originalIPHeader.Flags()
+ if more {
+ flags |= header.IPv4FlagMoreFragments
+ }
+ nextFragIPHeader.SetFlagsFragmentOffset(flags, uint16(offset))
+ nextFragIPHeader.SetTotalLength(uint16(nextFragIPHeader.HeaderLength()) + uint16(copied))
+ nextFragIPHeader.SetChecksum(0)
+ nextFragIPHeader.SetChecksum(^nextFragIPHeader.CalculateChecksum())
+
+ return fragPkt, more
+}
diff --git a/pkg/tcpip/network/ipv4/ipv4_test.go b/pkg/tcpip/network/ipv4/ipv4_test.go
index 0b3ed9483..9916d783f 100644
--- a/pkg/tcpip/network/ipv4/ipv4_test.go
+++ b/pkg/tcpip/network/ipv4/ipv4_test.go
@@ -16,19 +16,24 @@ package ipv4_test
import (
"bytes"
+ "context"
"encoding/hex"
"math"
+ "net"
"testing"
"github.com/google/go-cmp/cmp"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/checker"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/link/channel"
"gvisor.dev/gvisor/pkg/tcpip/link/sniffer"
+ "gvisor.dev/gvisor/pkg/tcpip/network/arp"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
"gvisor.dev/gvisor/pkg/tcpip/network/testutil"
"gvisor.dev/gvisor/pkg/tcpip/stack"
+ "gvisor.dev/gvisor/pkg/tcpip/transport/icmp"
"gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
"gvisor.dev/gvisor/pkg/tcpip/transport/udp"
"gvisor.dev/gvisor/pkg/waiter"
@@ -92,6 +97,276 @@ func TestExcludeBroadcast(t *testing.T) {
})
}
+// TestIPv4Sanity sends IP/ICMP packets with various problems to the stack and
+// checks the response.
+func TestIPv4Sanity(t *testing.T) {
+ const (
+ defaultMTU = header.IPv6MinimumMTU
+ ttl = 255
+ nicID = 1
+ randomSequence = 123
+ randomIdent = 42
+ )
+ var (
+ ipv4Addr = tcpip.AddressWithPrefix{
+ Address: tcpip.Address(net.ParseIP("192.168.1.58").To4()),
+ PrefixLen: 24,
+ }
+ remoteIPv4Addr = tcpip.Address(net.ParseIP("10.0.0.1").To4())
+ )
+
+ tests := []struct {
+ name string
+ headerLength uint8 // value of 0 means "use correct size"
+ maxTotalLength uint16
+ transportProtocol uint8
+ TTL uint8
+ shouldFail bool
+ expectICMP bool
+ ICMPType header.ICMPv4Type
+ ICMPCode header.ICMPv4Code
+ options []byte
+ }{
+ {
+ name: "valid",
+ maxTotalLength: defaultMTU,
+ transportProtocol: uint8(header.ICMPv4ProtocolNumber),
+ TTL: ttl,
+ },
+ // The TTL tests check that we are not rejecting an incoming packet
+ // with a zero or one TTL, which has been a point of confusion in the
+ // past as RFC 791 says: "If this field contains the value zero, then the
+ // datagram must be destroyed". However RFC 1122 section 3.2.1.7 clarifies
+ // for the case of the destination host, stating as follows.
+ //
+ // A host MUST NOT send a datagram with a Time-to-Live (TTL)
+ // value of zero.
+ //
+ // A host MUST NOT discard a datagram just because it was
+ // received with TTL less than 2.
+ {
+ name: "zero TTL",
+ maxTotalLength: defaultMTU,
+ transportProtocol: uint8(header.ICMPv4ProtocolNumber),
+ TTL: 0,
+ shouldFail: false,
+ },
+ {
+ name: "one TTL",
+ maxTotalLength: defaultMTU,
+ transportProtocol: uint8(header.ICMPv4ProtocolNumber),
+ TTL: 1,
+ shouldFail: false,
+ },
+ {
+ name: "End options",
+ maxTotalLength: defaultMTU,
+ transportProtocol: uint8(header.ICMPv4ProtocolNumber),
+ TTL: ttl,
+ options: []byte{0, 0, 0, 0},
+ },
+ {
+ name: "NOP options",
+ maxTotalLength: defaultMTU,
+ transportProtocol: uint8(header.ICMPv4ProtocolNumber),
+ TTL: ttl,
+ options: []byte{1, 1, 1, 1},
+ },
+ {
+ name: "NOP and End options",
+ maxTotalLength: defaultMTU,
+ transportProtocol: uint8(header.ICMPv4ProtocolNumber),
+ TTL: ttl,
+ options: []byte{1, 1, 0, 0},
+ },
+ {
+ name: "bad header length",
+ headerLength: header.IPv4MinimumSize - 1,
+ maxTotalLength: defaultMTU,
+ transportProtocol: uint8(header.ICMPv4ProtocolNumber),
+ TTL: ttl,
+ shouldFail: true,
+ expectICMP: false,
+ },
+ {
+ name: "bad total length (0)",
+ maxTotalLength: 0,
+ transportProtocol: uint8(header.ICMPv4ProtocolNumber),
+ TTL: ttl,
+ shouldFail: true,
+ expectICMP: false,
+ },
+ {
+ name: "bad total length (ip - 1)",
+ maxTotalLength: uint16(header.IPv4MinimumSize - 1),
+ transportProtocol: uint8(header.ICMPv4ProtocolNumber),
+ TTL: ttl,
+ shouldFail: true,
+ expectICMP: false,
+ },
+ {
+ name: "bad total length (ip + icmp - 1)",
+ maxTotalLength: uint16(header.IPv4MinimumSize + header.ICMPv4MinimumSize - 1),
+ transportProtocol: uint8(header.ICMPv4ProtocolNumber),
+ TTL: ttl,
+ shouldFail: true,
+ expectICMP: false,
+ },
+ {
+ name: "bad protocol",
+ maxTotalLength: defaultMTU,
+ transportProtocol: 99,
+ TTL: ttl,
+ shouldFail: true,
+ expectICMP: true,
+ ICMPType: header.ICMPv4DstUnreachable,
+ ICMPCode: header.ICMPv4ProtoUnreachable,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol},
+ TransportProtocols: []stack.TransportProtocolFactory{icmp.NewProtocol4},
+ })
+ // We expect at most a single packet in response to our ICMP Echo Request.
+ e := channel.New(1, defaultMTU, "")
+ if err := s.CreateNIC(nicID, e); err != nil {
+ t.Fatalf("CreateNIC(%d, _): %s", nicID, err)
+ }
+ ipv4ProtoAddr := tcpip.ProtocolAddress{Protocol: header.IPv4ProtocolNumber, AddressWithPrefix: ipv4Addr}
+ if err := s.AddProtocolAddress(nicID, ipv4ProtoAddr); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %#v): %s", nicID, ipv4ProtoAddr, err)
+ }
+
+ // Default routes for IPv4 so ICMP can find a route to the remote
+ // node when attempting to send the ICMP Echo Reply.
+ s.SetRouteTable([]tcpip.Route{
+ tcpip.Route{
+ Destination: header.IPv4EmptySubnet,
+ NIC: nicID,
+ },
+ })
+
+ // Round up the header size to the next multiple of 4 as RFC 791, page 11
+ // says: "Internet Header Length is the length of the internet header
+ // in 32 bit words..." and on page 23: "The internet header padding is
+ // used to ensure that the internet header ends on a 32 bit boundary."
+ ipHeaderLength := ((header.IPv4MinimumSize + len(test.options)) + header.IPv4IHLStride - 1) & ^(header.IPv4IHLStride - 1)
+
+ if ipHeaderLength > header.IPv4MaximumHeaderSize {
+ t.Fatalf("too many bytes in options: got = %d, want <= %d ", ipHeaderLength, header.IPv4MaximumHeaderSize)
+ }
+ totalLen := uint16(ipHeaderLength + header.ICMPv4MinimumSize)
+ hdr := buffer.NewPrependable(int(totalLen))
+ icmp := header.ICMPv4(hdr.Prepend(header.ICMPv4MinimumSize))
+
+ // Specify ident/seq to make sure we get the same in the response.
+ icmp.SetIdent(randomIdent)
+ icmp.SetSequence(randomSequence)
+ icmp.SetType(header.ICMPv4Echo)
+ icmp.SetCode(header.ICMPv4UnusedCode)
+ icmp.SetChecksum(0)
+ icmp.SetChecksum(^header.Checksum(icmp, 0))
+ ip := header.IPv4(hdr.Prepend(ipHeaderLength))
+ if test.maxTotalLength < totalLen {
+ totalLen = test.maxTotalLength
+ }
+ ip.Encode(&header.IPv4Fields{
+ IHL: uint8(ipHeaderLength),
+ TotalLength: totalLen,
+ Protocol: test.transportProtocol,
+ TTL: test.TTL,
+ SrcAddr: remoteIPv4Addr,
+ DstAddr: ipv4Addr.Address,
+ })
+ if n := copy(ip.Options(), test.options); n != len(test.options) {
+ t.Fatalf("options larger than available space: copied %d/%d bytes", n, len(test.options))
+ }
+ // Override the correct value if the test case specified one.
+ if test.headerLength != 0 {
+ ip.SetHeaderLength(test.headerLength)
+ }
+ requestPkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
+ Data: hdr.View().ToVectorisedView(),
+ })
+ e.InjectInbound(header.IPv4ProtocolNumber, requestPkt)
+ reply, ok := e.Read()
+ if !ok {
+ if test.shouldFail {
+ if test.expectICMP {
+ t.Fatal("expected ICMP error response missing")
+ }
+ return // Expected silent failure.
+ }
+ t.Fatal("expected ICMP echo reply missing")
+ }
+
+ // Check the route that brought the packet to us.
+ if reply.Route.LocalAddress != ipv4Addr.Address {
+ t.Errorf("got pkt.Route.LocalAddress = %s, want = %s", reply.Route.LocalAddress, ipv4Addr.Address)
+ }
+ if reply.Route.RemoteAddress != remoteIPv4Addr {
+ t.Errorf("got pkt.Route.RemoteAddress = %s, want = %s", reply.Route.RemoteAddress, remoteIPv4Addr)
+ }
+
+ // Make sure it's all in one buffer.
+ vv := buffer.NewVectorisedView(reply.Pkt.Size(), reply.Pkt.Views())
+ replyIPHeader := header.IPv4(vv.ToView())
+
+ // At this stage we only know it's an IP header so verify that much.
+ checker.IPv4(t, replyIPHeader,
+ checker.SrcAddr(ipv4Addr.Address),
+ checker.DstAddr(remoteIPv4Addr),
+ )
+
+ // All expected responses are ICMP packets.
+ if got, want := replyIPHeader.Protocol(), uint8(header.ICMPv4ProtocolNumber); got != want {
+ t.Fatalf("not ICMP response, got protocol %d, want = %d", got, want)
+ }
+ replyICMPHeader := header.ICMPv4(replyIPHeader.Payload())
+
+ // Sanity check the response.
+ switch replyICMPHeader.Type() {
+ case header.ICMPv4DstUnreachable:
+ checker.IPv4(t, replyIPHeader,
+ checker.IPFullLength(uint16(header.IPv4MinimumSize+header.ICMPv4MinimumSize+requestPkt.Size())),
+ checker.IPv4HeaderLength(header.IPv4MinimumSize),
+ checker.ICMPv4(
+ checker.ICMPv4Code(test.ICMPCode),
+ checker.ICMPv4Checksum(),
+ checker.ICMPv4Payload([]byte(hdr.View())),
+ ),
+ )
+ if !test.shouldFail || !test.expectICMP {
+ t.Fatalf("unexpected packet rejection, got ICMP error packet type %d, code %d",
+ header.ICMPv4DstUnreachable, replyICMPHeader.Code())
+ }
+ return
+ case header.ICMPv4EchoReply:
+ checker.IPv4(t, replyIPHeader,
+ checker.IPv4HeaderLength(ipHeaderLength),
+ checker.IPv4Options(test.options),
+ checker.IPFullLength(uint16(requestPkt.Size())),
+ checker.ICMPv4(
+ checker.ICMPv4Code(header.ICMPv4UnusedCode),
+ checker.ICMPv4Seq(randomSequence),
+ checker.ICMPv4Ident(randomIdent),
+ checker.ICMPv4Checksum(),
+ ),
+ )
+ if test.shouldFail {
+ t.Fatalf("unexpected Echo Reply packet\n")
+ }
+ default:
+ t.Fatalf("unexpected ICMP response, got type %d, want = %d or %d",
+ replyICMPHeader.Type(), header.ICMPv4EchoReply, header.ICMPv4DstUnreachable)
+ }
+ })
+ }
+}
+
// comparePayloads compared the contents of all the packets against the contents
// of the source packet.
func compareFragments(t *testing.T, packets []*stack.PacketBuffer, sourcePacketInfo *stack.PacketBuffer, mtu uint32) {
@@ -123,16 +398,6 @@ func compareFragments(t *testing.T, packets []*stack.PacketBuffer, sourcePacketI
if got, want := len(ip), int(mtu); got > want {
t.Errorf("fragment is too large, got %d want %d", got, want)
}
- if i == 0 {
- got := packet.NetworkHeader().View().Size() + packet.TransportHeader().View().Size()
- // sourcePacketInfo does not have NetworkHeader added, simulate one.
- want := header.IPv4MinimumSize + sourcePacketInfo.TransportHeader().View().Size()
- // Check that it kept the transport header in packet.TransportHeader if
- // it fits in the first fragment.
- if want < int(mtu) && got != want {
- t.Errorf("first fragment hdr parts should have unmodified length if possible: got %d, want %d", got, want)
- }
- }
if got, want := packet.AvailableHeaderBytes(), sourcePacketInfo.AvailableHeaderBytes()-header.IPv4MinimumSize; got != want {
t.Errorf("fragment #%d should have the same available space for prepending as source: got %d, want %d", i, got, want)
}
@@ -162,6 +427,8 @@ func compareFragments(t *testing.T, packets []*stack.PacketBuffer, sourcePacketI
}
func TestFragmentation(t *testing.T) {
+ const ttl = 42
+
var manyPayloadViewsSizes [1000]int
for i := range manyPayloadViewsSizes {
manyPayloadViewsSizes[i] = 7
@@ -175,15 +442,15 @@ func TestFragmentation(t *testing.T) {
payloadViewsSizes []int
expectedFrags int
}{
- {"NoFragmentation", 2000, &stack.GSO{}, 0, header.IPv4MinimumSize, []int{1000}, 1},
- {"NoFragmentationWithBigHeader", 2000, &stack.GSO{}, 16, header.IPv4MinimumSize, []int{1000}, 1},
+ {"No fragmentation", 2000, &stack.GSO{}, 0, header.IPv4MinimumSize, []int{1000}, 1},
+ {"No fragmentation with big header", 2000, &stack.GSO{}, 16, header.IPv4MinimumSize, []int{1000}, 1},
{"Fragmented", 800, &stack.GSO{}, 0, header.IPv4MinimumSize, []int{1000}, 2},
- {"FragmentedWithGsoNil", 800, nil, 0, header.IPv4MinimumSize, []int{1000}, 2},
- {"FragmentedWithManyViews", 300, &stack.GSO{}, 0, header.IPv4MinimumSize, manyPayloadViewsSizes[:], 25},
- {"FragmentedWithManyViewsAndPrependableBytes", 300, &stack.GSO{}, 0, header.IPv4MinimumSize + 55, manyPayloadViewsSizes[:], 25},
- {"FragmentedWithBigHeader", 800, &stack.GSO{}, 20, header.IPv4MinimumSize, []int{1000}, 2},
- {"FragmentedWithBigHeaderAndPrependableBytes", 800, &stack.GSO{}, 20, header.IPv4MinimumSize + 66, []int{1000}, 2},
- {"FragmentedWithMTUSmallerThanHeaderAndPrependableBytes", 300, &stack.GSO{}, 1000, header.IPv4MinimumSize + 77, []int{500}, 6},
+ {"Fragmented with gso nil", 800, nil, 0, header.IPv4MinimumSize, []int{1000}, 2},
+ {"Fragmented with many views", 300, &stack.GSO{}, 0, header.IPv4MinimumSize, manyPayloadViewsSizes[:], 25},
+ {"Fragmented with many views and prependable bytes", 300, &stack.GSO{}, 0, header.IPv4MinimumSize + 55, manyPayloadViewsSizes[:], 25},
+ {"Fragmented with big header", 800, &stack.GSO{}, 20, header.IPv4MinimumSize, []int{1000}, 2},
+ {"Fragmented with big header and prependable bytes", 800, &stack.GSO{}, 20, header.IPv4MinimumSize + 66, []int{1000}, 2},
+ {"Fragmented with MTU smaller than header and prependable bytes", 300, &stack.GSO{}, 1000, header.IPv4MinimumSize + 77, []int{500}, 6},
}
for _, ft := range fragTests {
@@ -194,11 +461,11 @@ func TestFragmentation(t *testing.T) {
source := pkt.Clone()
err := r.WritePacket(ft.gso, stack.NetworkHeaderParams{
Protocol: tcp.ProtocolNumber,
- TTL: 42,
+ TTL: ttl,
TOS: stack.DefaultTOS,
}, pkt)
if err != nil {
- t.Errorf("got err = %s, want = nil", err)
+ t.Fatalf("r.WritePacket(_, _, _) = %s", err)
}
if got := len(ep.WrittenPackets); got != ft.expectedFrags {
@@ -207,6 +474,9 @@ func TestFragmentation(t *testing.T) {
if got, want := len(ep.WrittenPackets), int(r.Stats().IP.PacketsSent.Value()); got != want {
t.Errorf("no errors yet got len(ep.WrittenPackets) = %d, want = %d", got, want)
}
+ if got := r.Stats().IP.OutgoingPacketErrors.Value(); got != 0 {
+ t.Errorf("got r.Stats().IP.OutgoingPacketErrors.Value() = %d, want = 0", got)
+ }
compareFragments(t, ep.WrittenPackets, source, ft.mtu)
})
}
@@ -215,36 +485,70 @@ func TestFragmentation(t *testing.T) {
// TestFragmentationErrors checks that errors are returned from write packet
// correctly.
func TestFragmentationErrors(t *testing.T) {
+ const ttl = 42
+
+ expectedError := tcpip.ErrAborted
fragTests := []struct {
description string
mtu uint32
transportHeaderLength int
- payloadViewsSizes []int
- err *tcpip.Error
+ payloadSize int
allowPackets int
+ fragmentCount int
}{
- {"NoFrag", 2000, 0, []int{1000}, tcpip.ErrAborted, 0},
- {"ErrorOnFirstFrag", 500, 0, []int{1000}, tcpip.ErrAborted, 0},
- {"ErrorOnSecondFrag", 500, 0, []int{1000}, tcpip.ErrAborted, 1},
- {"ErrorOnFirstFragMTUSmallerThanHeader", 500, 1000, []int{500}, tcpip.ErrAborted, 0},
+ {
+ description: "No frag",
+ mtu: 2000,
+ transportHeaderLength: 0,
+ payloadSize: 1000,
+ allowPackets: 0,
+ fragmentCount: 1,
+ },
+ {
+ description: "Error on first frag",
+ mtu: 500,
+ transportHeaderLength: 0,
+ payloadSize: 1000,
+ allowPackets: 0,
+ fragmentCount: 3,
+ },
+ {
+ description: "Error on second frag",
+ mtu: 500,
+ transportHeaderLength: 0,
+ payloadSize: 1000,
+ allowPackets: 1,
+ fragmentCount: 3,
+ },
+ {
+ description: "Error on first frag MTU smaller than header",
+ mtu: 500,
+ transportHeaderLength: 1000,
+ payloadSize: 500,
+ allowPackets: 0,
+ fragmentCount: 4,
+ },
}
for _, ft := range fragTests {
t.Run(ft.description, func(t *testing.T) {
- ep := testutil.NewMockLinkEndpoint(ft.mtu, ft.err, ft.allowPackets)
+ ep := testutil.NewMockLinkEndpoint(ft.mtu, expectedError, ft.allowPackets)
r := buildRoute(t, ep)
- pkt := testutil.MakeRandPkt(ft.transportHeaderLength, header.IPv4MinimumSize, ft.payloadViewsSizes, header.IPv4ProtocolNumber)
+ pkt := testutil.MakeRandPkt(ft.transportHeaderLength, header.IPv4MinimumSize, []int{ft.payloadSize}, header.IPv4ProtocolNumber)
err := r.WritePacket(&stack.GSO{}, stack.NetworkHeaderParams{
Protocol: tcp.ProtocolNumber,
- TTL: 42,
+ TTL: ttl,
TOS: stack.DefaultTOS,
}, pkt)
- if err != ft.err {
- t.Errorf("got WritePacket() = %s, want = %s", err, ft.err)
+ if err != expectedError {
+ t.Errorf("got WritePacket() = %s, want = %s", err, expectedError)
}
if got, want := len(ep.WrittenPackets), int(r.Stats().IP.PacketsSent.Value()); err != nil && got != want {
t.Errorf("got len(ep.WrittenPackets) = %d, want = %d", got, want)
}
+ if got, want := int(r.Stats().IP.OutgoingPacketErrors.Value()), ft.fragmentCount-ft.allowPackets; got != want {
+ t.Errorf("got r.Stats().IP.OutgoingPacketErrors.Value() = %d, want = %d", got, want)
+ }
})
}
}
@@ -1005,6 +1309,7 @@ func TestReceiveFragments(t *testing.T) {
func TestWriteStats(t *testing.T) {
const nPackets = 3
+
tests := []struct {
name string
setup func(*testing.T, *stack.Stack)
@@ -1040,7 +1345,7 @@ func TestWriteStats(t *testing.T) {
t.Fatalf("failed to find filter table")
}
ruleIdx := filter.BuiltinChains[stack.Output]
- filter.Rules[ruleIdx].Target = stack.DropTarget{}
+ filter.Rules[ruleIdx].Target = &stack.DropTarget{}
if err := ipt.ReplaceTable(stack.FilterTable, filter, false /* ipv6 */); err != nil {
t.Fatalf("failed to replace table: %s", err)
}
@@ -1062,10 +1367,10 @@ func TestWriteStats(t *testing.T) {
}
// We'll match and DROP the last packet.
ruleIdx := filter.BuiltinChains[stack.Output]
- filter.Rules[ruleIdx].Target = stack.DropTarget{}
+ filter.Rules[ruleIdx].Target = &stack.DropTarget{}
filter.Rules[ruleIdx].Matchers = []stack.Matcher{&limitedMatcher{nPackets - 1}}
// Make sure the next rule is ACCEPT.
- filter.Rules[ruleIdx+1].Target = stack.AcceptTarget{}
+ filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{}
if err := ipt.ReplaceTable(stack.FilterTable, filter, false /* ipv6 */); err != nil {
t.Fatalf("failed to replace table: %s", err)
}
@@ -1150,12 +1455,13 @@ func buildRoute(t *testing.T, ep stack.LinkEndpoint) stack.Route {
dst = "\x10\x00\x00\x02"
)
if err := s.AddAddress(1, ipv4.ProtocolNumber, src); err != nil {
- t.Fatalf("AddAddress(1, %d, _) failed: %s", ipv4.ProtocolNumber, err)
+ t.Fatalf("AddAddress(1, %d, %s) failed: %s", ipv4.ProtocolNumber, src, err)
}
{
- subnet, err := tcpip.NewSubnet(dst, tcpip.AddressMask(header.IPv4Broadcast))
+ mask := tcpip.AddressMask(header.IPv4Broadcast)
+ subnet, err := tcpip.NewSubnet(dst, mask)
if err != nil {
- t.Fatalf("NewSubnet(_, _) failed: %v", err)
+ t.Fatalf("NewSubnet(%s, %s) failed: %v", dst, mask, err)
}
s.SetRouteTable([]tcpip.Route{{
Destination: subnet,
@@ -1164,7 +1470,7 @@ func buildRoute(t *testing.T, ep stack.LinkEndpoint) stack.Route {
}
rt, err := s.FindRoute(1, src, dst, ipv4.ProtocolNumber, false /* multicastLoop */)
if err != nil {
- t.Fatalf("got FindRoute(1, _, _, %d, false) = %s, want = nil", ipv4.ProtocolNumber, err)
+ t.Fatalf("FindRoute(1, %s, %s, %d, false) = %s", src, dst, ipv4.ProtocolNumber, err)
}
return rt
}
@@ -1188,3 +1494,204 @@ func (lm *limitedMatcher) Match(stack.Hook, *stack.PacketBuffer, string) (bool,
lm.limit--
return false, false
}
+
+func TestPacketQueing(t *testing.T) {
+ const nicID = 1
+
+ var (
+ host1NICLinkAddr = tcpip.LinkAddress("\x02\x03\x03\x04\x05\x06")
+ host2NICLinkAddr = tcpip.LinkAddress("\x02\x03\x03\x04\x05\x09")
+
+ host1IPv4Addr = tcpip.ProtocolAddress{
+ Protocol: ipv4.ProtocolNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: tcpip.Address(net.ParseIP("192.168.0.1").To4()),
+ PrefixLen: 24,
+ },
+ }
+ host2IPv4Addr = tcpip.ProtocolAddress{
+ Protocol: ipv4.ProtocolNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: tcpip.Address(net.ParseIP("192.168.0.2").To4()),
+ PrefixLen: 8,
+ },
+ }
+ )
+
+ tests := []struct {
+ name string
+ rxPkt func(*channel.Endpoint)
+ checkResp func(*testing.T, *channel.Endpoint)
+ }{
+ {
+ name: "ICMP Error",
+ rxPkt: func(e *channel.Endpoint) {
+ hdr := buffer.NewPrependable(header.IPv4MinimumSize + header.UDPMinimumSize)
+ u := header.UDP(hdr.Prepend(header.UDPMinimumSize))
+ u.Encode(&header.UDPFields{
+ SrcPort: 5555,
+ DstPort: 80,
+ Length: header.UDPMinimumSize,
+ })
+ sum := header.PseudoHeaderChecksum(udp.ProtocolNumber, host2IPv4Addr.AddressWithPrefix.Address, host1IPv4Addr.AddressWithPrefix.Address, header.UDPMinimumSize)
+ sum = header.Checksum(header.UDP([]byte{}), sum)
+ u.SetChecksum(^u.CalculateChecksum(sum))
+ ip := header.IPv4(hdr.Prepend(header.IPv4MinimumSize))
+ ip.Encode(&header.IPv4Fields{
+ IHL: header.IPv4MinimumSize,
+ TotalLength: header.IPv4MinimumSize + header.UDPMinimumSize,
+ TTL: ipv4.DefaultTTL,
+ Protocol: uint8(udp.ProtocolNumber),
+ SrcAddr: host2IPv4Addr.AddressWithPrefix.Address,
+ DstAddr: host1IPv4Addr.AddressWithPrefix.Address,
+ })
+ e.InjectInbound(ipv4.ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
+ Data: hdr.View().ToVectorisedView(),
+ }))
+ },
+ checkResp: func(t *testing.T, e *channel.Endpoint) {
+ p, ok := e.ReadContext(context.Background())
+ if !ok {
+ t.Fatalf("timed out waiting for packet")
+ }
+ if p.Proto != header.IPv4ProtocolNumber {
+ t.Errorf("got p.Proto = %d, want = %d", p.Proto, header.IPv4ProtocolNumber)
+ }
+ if p.Route.RemoteLinkAddress != host2NICLinkAddr {
+ t.Errorf("got p.Route.RemoteLinkAddress = %s, want = %s", p.Route.RemoteLinkAddress, host2NICLinkAddr)
+ }
+ checker.IPv4(t, stack.PayloadSince(p.Pkt.NetworkHeader()),
+ checker.SrcAddr(host1IPv4Addr.AddressWithPrefix.Address),
+ checker.DstAddr(host2IPv4Addr.AddressWithPrefix.Address),
+ checker.ICMPv4(
+ checker.ICMPv4Type(header.ICMPv4DstUnreachable),
+ checker.ICMPv4Code(header.ICMPv4PortUnreachable)))
+ },
+ },
+
+ {
+ name: "Ping",
+ rxPkt: func(e *channel.Endpoint) {
+ totalLen := header.IPv4MinimumSize + header.ICMPv4MinimumSize
+ hdr := buffer.NewPrependable(totalLen)
+ pkt := header.ICMPv4(hdr.Prepend(header.ICMPv4MinimumSize))
+ pkt.SetType(header.ICMPv4Echo)
+ pkt.SetCode(0)
+ pkt.SetChecksum(0)
+ pkt.SetChecksum(^header.Checksum(pkt, 0))
+ ip := header.IPv4(hdr.Prepend(header.IPv4MinimumSize))
+ ip.Encode(&header.IPv4Fields{
+ IHL: header.IPv4MinimumSize,
+ TotalLength: uint16(totalLen),
+ Protocol: uint8(icmp.ProtocolNumber4),
+ TTL: ipv4.DefaultTTL,
+ SrcAddr: host2IPv4Addr.AddressWithPrefix.Address,
+ DstAddr: host1IPv4Addr.AddressWithPrefix.Address,
+ })
+ e.InjectInbound(header.IPv4ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
+ Data: hdr.View().ToVectorisedView(),
+ }))
+ },
+ checkResp: func(t *testing.T, e *channel.Endpoint) {
+ p, ok := e.ReadContext(context.Background())
+ if !ok {
+ t.Fatalf("timed out waiting for packet")
+ }
+ if p.Proto != header.IPv4ProtocolNumber {
+ t.Errorf("got p.Proto = %d, want = %d", p.Proto, header.IPv4ProtocolNumber)
+ }
+ if p.Route.RemoteLinkAddress != host2NICLinkAddr {
+ t.Errorf("got p.Route.RemoteLinkAddress = %s, want = %s", p.Route.RemoteLinkAddress, host2NICLinkAddr)
+ }
+ checker.IPv4(t, stack.PayloadSince(p.Pkt.NetworkHeader()),
+ checker.SrcAddr(host1IPv4Addr.AddressWithPrefix.Address),
+ checker.DstAddr(host2IPv4Addr.AddressWithPrefix.Address),
+ checker.ICMPv4(
+ checker.ICMPv4Type(header.ICMPv4EchoReply),
+ checker.ICMPv4Code(header.ICMPv4UnusedCode)))
+ },
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ e := channel.New(1, header.IPv6MinimumMTU, host1NICLinkAddr)
+ e.LinkEPCapabilities |= stack.CapabilityResolutionRequired
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{arp.NewProtocol, ipv4.NewProtocol},
+ TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol},
+ })
+
+ if err := s.CreateNIC(nicID, e); err != nil {
+ t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err)
+ }
+ if err := s.AddAddress(nicID, arp.ProtocolNumber, arp.ProtocolAddress); err != nil {
+ t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID, arp.ProtocolNumber, arp.ProtocolAddress, err)
+ }
+ if err := s.AddProtocolAddress(nicID, host1IPv4Addr); err != nil {
+ t.Fatalf("s.AddProtocolAddress(%d, %#v): %s", nicID, host1IPv4Addr, err)
+ }
+
+ s.SetRouteTable([]tcpip.Route{
+ tcpip.Route{
+ Destination: host1IPv4Addr.AddressWithPrefix.Subnet(),
+ NIC: nicID,
+ },
+ })
+
+ // Receive a packet to trigger link resolution before a response is sent.
+ test.rxPkt(e)
+
+ // Wait for a ARP request since link address resolution should be
+ // performed.
+ {
+ p, ok := e.ReadContext(context.Background())
+ if !ok {
+ t.Fatalf("timed out waiting for packet")
+ }
+ if p.Proto != arp.ProtocolNumber {
+ t.Errorf("got p.Proto = %d, want = %d", p.Proto, arp.ProtocolNumber)
+ }
+ if p.Route.RemoteLinkAddress != header.EthernetBroadcastAddress {
+ t.Errorf("got p.Route.RemoteLinkAddress = %s, want = %s", p.Route.RemoteLinkAddress, header.EthernetBroadcastAddress)
+ }
+ rep := header.ARP(p.Pkt.NetworkHeader().View())
+ if got := rep.Op(); got != header.ARPRequest {
+ t.Errorf("got Op() = %d, want = %d", got, header.ARPRequest)
+ }
+ if got := tcpip.LinkAddress(rep.HardwareAddressSender()); got != host1NICLinkAddr {
+ t.Errorf("got HardwareAddressSender = %s, want = %s", got, host1NICLinkAddr)
+ }
+ if got := tcpip.Address(rep.ProtocolAddressSender()); got != host1IPv4Addr.AddressWithPrefix.Address {
+ t.Errorf("got ProtocolAddressSender = %s, want = %s", got, host1IPv4Addr.AddressWithPrefix.Address)
+ }
+ if got := tcpip.Address(rep.ProtocolAddressTarget()); got != host2IPv4Addr.AddressWithPrefix.Address {
+ t.Errorf("got ProtocolAddressTarget = %s, want = %s", got, host2IPv4Addr.AddressWithPrefix.Address)
+ }
+ }
+
+ // Send an ARP reply to complete link address resolution.
+ {
+ hdr := buffer.View(make([]byte, header.ARPSize))
+ packet := header.ARP(hdr)
+ packet.SetIPv4OverEthernet()
+ packet.SetOp(header.ARPReply)
+ copy(packet.HardwareAddressSender(), host2NICLinkAddr)
+ copy(packet.ProtocolAddressSender(), host2IPv4Addr.AddressWithPrefix.Address)
+ copy(packet.HardwareAddressTarget(), host1NICLinkAddr)
+ copy(packet.ProtocolAddressTarget(), host1IPv4Addr.AddressWithPrefix.Address)
+ e.InjectInbound(arp.ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
+ Data: hdr.ToVectorisedView(),
+ }))
+ }
+
+ // Expect the response now that the link address has resolved.
+ test.checkResp(t, e)
+
+ // Since link resolution was already performed, it shouldn't be performed
+ // again.
+ test.rxPkt(e)
+ test.checkResp(t, e)
+ })
+ }
+}
diff --git a/pkg/tcpip/network/ipv6/BUILD b/pkg/tcpip/network/ipv6/BUILD
index 8bd8f5c52..a30437f02 100644
--- a/pkg/tcpip/network/ipv6/BUILD
+++ b/pkg/tcpip/network/ipv6/BUILD
@@ -5,16 +5,20 @@ package(licenses = ["notice"])
go_library(
name = "ipv6",
srcs = [
+ "dhcpv6configurationfromndpra_string.go",
"icmp.go",
"ipv6.go",
+ "ndp.go",
],
visibility = ["//visibility:public"],
deps = [
+ "//pkg/sync",
"//pkg/tcpip",
"//pkg/tcpip/buffer",
"//pkg/tcpip/header",
"//pkg/tcpip/header/parse",
"//pkg/tcpip/network/fragmentation",
+ "//pkg/tcpip/network/hash",
"//pkg/tcpip/stack",
],
)
@@ -38,6 +42,7 @@ go_test(
"//pkg/tcpip/network/testutil",
"//pkg/tcpip/stack",
"//pkg/tcpip/transport/icmp",
+ "//pkg/tcpip/transport/tcp",
"//pkg/tcpip/transport/udp",
"//pkg/waiter",
"@com_github_google_go_cmp//cmp:go_default_library",
diff --git a/pkg/tcpip/network/ipv6/dhcpv6configurationfromndpra_string.go b/pkg/tcpip/network/ipv6/dhcpv6configurationfromndpra_string.go
new file mode 100644
index 000000000..09ba133b1
--- /dev/null
+++ b/pkg/tcpip/network/ipv6/dhcpv6configurationfromndpra_string.go
@@ -0,0 +1,40 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Code generated by "stringer -type DHCPv6ConfigurationFromNDPRA"; DO NOT EDIT.
+
+package ipv6
+
+import "strconv"
+
+func _() {
+ // An "invalid array index" compiler error signifies that the constant values have changed.
+ // Re-run the stringer command to generate them again.
+ var x [1]struct{}
+ _ = x[DHCPv6NoConfiguration-1]
+ _ = x[DHCPv6ManagedAddress-2]
+ _ = x[DHCPv6OtherConfigurations-3]
+}
+
+const _DHCPv6ConfigurationFromNDPRA_name = "DHCPv6NoConfigurationDHCPv6ManagedAddressDHCPv6OtherConfigurations"
+
+var _DHCPv6ConfigurationFromNDPRA_index = [...]uint8{0, 21, 41, 66}
+
+func (i DHCPv6ConfigurationFromNDPRA) String() string {
+ i -= 1
+ if i < 0 || i >= DHCPv6ConfigurationFromNDPRA(len(_DHCPv6ConfigurationFromNDPRA_index)-1) {
+ return "DHCPv6ConfigurationFromNDPRA(" + strconv.FormatInt(int64(i+1), 10) + ")"
+ }
+ return _DHCPv6ConfigurationFromNDPRA_name[_DHCPv6ConfigurationFromNDPRA_index[i]:_DHCPv6ConfigurationFromNDPRA_index[i+1]]
+}
diff --git a/pkg/tcpip/network/ipv6/icmp.go b/pkg/tcpip/network/ipv6/icmp.go
index dd3295b31..a454f6c34 100644
--- a/pkg/tcpip/network/ipv6/icmp.go
+++ b/pkg/tcpip/network/ipv6/icmp.go
@@ -15,6 +15,8 @@
package ipv6
import (
+ "fmt"
+
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
@@ -39,7 +41,7 @@ func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, pkt *stack
// Drop packet if it doesn't have the basic IPv6 header or if the
// original source address doesn't match an address we own.
src := hdr.SourceAddress()
- if e.stack.CheckLocalAddress(e.NICID(), ProtocolNumber, src) == 0 {
+ if e.protocol.stack.CheckLocalAddress(e.nic.ID(), ProtocolNumber, src) == 0 {
return
}
@@ -207,14 +209,7 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme
return
}
- s := r.Stack()
- if isTentative, err := s.IsAddrTentative(e.nicID, targetAddr); err != nil {
- // We will only get an error if the NIC is unrecognized, which should not
- // happen. For now, drop this packet.
- //
- // TODO(b/141002840): Handle this better?
- return
- } else if isTentative {
+ if e.hasTentativeAddr(targetAddr) {
// If the target address is tentative and the source of the packet is a
// unicast (specified) address, then the source of the packet is
// attempting to perform address resolution on the target. In this case,
@@ -227,7 +222,20 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme
// stack know so it can handle such a scenario and do nothing further with
// the NS.
if r.RemoteAddress == header.IPv6Any {
- s.DupTentativeAddrDetected(e.nicID, targetAddr)
+ // We would get an error if the address no longer exists or the address
+ // is no longer tentative (DAD resolved between the call to
+ // hasTentativeAddr and this point). Both of these are valid scenarios:
+ // 1) An address may be removed at any time.
+ // 2) As per RFC 4862 section 5.4, DAD is not a perfect:
+ // "Note that the method for detecting duplicates
+ // is not completely reliable, and it is possible that duplicate
+ // addresses will still exist"
+ //
+ // TODO(gvisor.dev/issue/4046): Handle the scenario when a duplicate
+ // address is detected for an assigned address.
+ if err := e.dupTentativeAddrDetected(targetAddr); err != nil && err != tcpip.ErrBadAddress && err != tcpip.ErrInvalidEndpointState {
+ panic(fmt.Sprintf("unexpected error handling duplicate tentative address: %s", err))
+ }
}
// Do not handle neighbor solicitations targeted to an address that is
@@ -240,7 +248,7 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme
// section 5.4.3.
// Is the NS targeting us?
- if s.CheckLocalAddress(e.nicID, ProtocolNumber, targetAddr) == 0 {
+ if r.Stack().CheckLocalAddress(e.nic.ID(), ProtocolNumber, targetAddr) == 0 {
return
}
@@ -275,7 +283,18 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme
} else if e.nud != nil {
e.nud.HandleProbe(r.RemoteAddress, r.LocalAddress, header.IPv6ProtocolNumber, sourceLinkAddr, e.protocol)
} else {
- e.linkAddrCache.AddLinkAddress(e.nicID, r.RemoteAddress, sourceLinkAddr)
+ e.linkAddrCache.AddLinkAddress(e.nic.ID(), r.RemoteAddress, sourceLinkAddr)
+ }
+
+ // As per RFC 4861 section 7.1.1:
+ // A node MUST silently discard any received Neighbor Solicitation
+ // messages that do not satisfy all of the following validity checks:
+ // ...
+ // - If the IP source address is the unspecified address, the IP
+ // destination address is a solicited-node multicast address.
+ if unspecifiedSource && !header.IsSolicitedNodeAddr(r.LocalAddress) {
+ received.Invalid.Increment()
+ return
}
// ICMPv6 Neighbor Solicit messages are always sent to
@@ -314,18 +333,18 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme
optsSerializer := header.NDPOptionsSerializer{
header.NDPTargetLinkLayerAddressOption(r.LocalLinkAddress),
}
+ neighborAdvertSize := header.ICMPv6NeighborAdvertMinimumSize + optsSerializer.Length()
pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
- ReserveHeaderBytes: int(r.MaxHeaderLength()) + header.ICMPv6NeighborAdvertMinimumSize + int(optsSerializer.Length()),
+ ReserveHeaderBytes: int(r.MaxHeaderLength()) + neighborAdvertSize,
})
- packet := header.ICMPv6(pkt.TransportHeader().Push(header.ICMPv6NeighborAdvertSize))
pkt.TransportProtocolNumber = header.ICMPv6ProtocolNumber
+ packet := header.ICMPv6(pkt.TransportHeader().Push(neighborAdvertSize))
packet.SetType(header.ICMPv6NeighborAdvert)
na := header.NDPNeighborAdvert(packet.NDPPayload())
na.SetSolicitedFlag(solicited)
na.SetOverrideFlag(true)
na.SetTargetAddress(targetAddr)
- opts := na.Options()
- opts.Serialize(optsSerializer)
+ na.Options().Serialize(optsSerializer)
packet.SetChecksum(header.ICMPv6Checksum(packet, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{}))
// RFC 4861 Neighbor Discovery for IP version 6 (IPv6)
@@ -342,7 +361,7 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme
case header.ICMPv6NeighborAdvert:
received.NeighborAdvert.Increment()
- if !isNDPValid() || pkt.Data.Size() < header.ICMPv6NeighborAdvertSize {
+ if !isNDPValid() || pkt.Data.Size() < header.ICMPv6NeighborAdvertMinimumSize {
received.Invalid.Increment()
return
}
@@ -353,20 +372,26 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme
// NDP datagrams are very small and ToView() will not incur allocations.
na := header.NDPNeighborAdvert(payload.ToView())
targetAddr := na.TargetAddress()
- s := r.Stack()
-
- if isTentative, err := s.IsAddrTentative(e.nicID, targetAddr); err != nil {
- // We will only get an error if the NIC is unrecognized, which should not
- // happen. For now short-circuit this packet.
- //
- // TODO(b/141002840): Handle this better?
- return
- } else if isTentative {
+ if e.hasTentativeAddr(targetAddr) {
// We just got an NA from a node that owns an address we are performing
// DAD on, implying the address is not unique. In this case we let the
// stack know so it can handle such a scenario and do nothing furthur with
// the NDP NA.
- s.DupTentativeAddrDetected(e.nicID, targetAddr)
+ //
+ // We would get an error if the address no longer exists or the address
+ // is no longer tentative (DAD resolved between the call to
+ // hasTentativeAddr and this point). Both of these are valid scenarios:
+ // 1) An address may be removed at any time.
+ // 2) As per RFC 4862 section 5.4, DAD is not a perfect:
+ // "Note that the method for detecting duplicates
+ // is not completely reliable, and it is possible that duplicate
+ // addresses will still exist"
+ //
+ // TODO(gvisor.dev/issue/4046): Handle the scenario when a duplicate
+ // address is detected for an assigned address.
+ if err := e.dupTentativeAddrDetected(targetAddr); err != nil && err != tcpip.ErrBadAddress && err != tcpip.ErrInvalidEndpointState {
+ panic(fmt.Sprintf("unexpected error handling duplicate tentative address: %s", err))
+ }
return
}
@@ -396,7 +421,7 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme
// address cache with the link address for the target of the message.
if len(targetLinkAddr) != 0 {
if e.nud == nil {
- e.linkAddrCache.AddLinkAddress(e.nicID, targetAddr, targetLinkAddr)
+ e.linkAddrCache.AddLinkAddress(e.nic.ID(), targetAddr, targetLinkAddr)
return
}
@@ -415,8 +440,6 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme
return
}
- remoteLinkAddr := r.RemoteLinkAddress
-
// As per RFC 4291 section 2.7, multicast addresses must not be used as
// source addresses in IPv6 packets.
localAddr := r.LocalAddress
@@ -424,16 +447,13 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme
localAddr = ""
}
- r, err := r.Stack().FindRoute(e.NICID(), localAddr, r.RemoteAddress, ProtocolNumber, false /* multicastLoop */)
+ r, err := r.Stack().FindRoute(e.nic.ID(), localAddr, r.RemoteAddress, ProtocolNumber, false /* multicastLoop */)
if err != nil {
// If we cannot find a route to the destination, silently drop the packet.
return
}
defer r.Release()
- // Use the link address from the source of the original packet.
- r.ResolveWith(remoteLinkAddr)
-
replyPkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
ReserveHeaderBytes: int(r.MaxHeaderLength()) + header.ICMPv6EchoMinimumSize,
Data: pkt.Data,
@@ -568,9 +588,9 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme
e.nud.HandleProbe(routerAddr, r.LocalAddress, header.IPv6ProtocolNumber, sourceLinkAddr, e.protocol)
}
- // Tell the NIC to handle the RA.
- stack := r.Stack()
- stack.HandleNDPRA(e.nicID, routerAddr, ra)
+ e.mu.Lock()
+ e.mu.ndp.handleRA(routerAddr, ra)
+ e.mu.Unlock()
case header.ICMPv6RedirectMsg:
// TODO(gvisor.dev/issue/2285): Call `e.nud.HandleProbe` after validating
@@ -600,18 +620,6 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme
}
}
-const (
- ndpSolicitedFlag = 1 << 6
- ndpOverrideFlag = 1 << 5
-
- ndpOptSrcLinkAddr = 1
- ndpOptDstLinkAddr = 2
-
- icmpV6FlagOffset = 4
- icmpV6OptOffset = 24
- icmpV6LengthOffset = 25
-)
-
var _ stack.LinkAddressResolver = (*protocol)(nil)
// LinkAddressProtocol implements stack.LinkAddressResolver.
@@ -621,31 +629,37 @@ func (*protocol) LinkAddressProtocol() tcpip.NetworkProtocolNumber {
// LinkAddressRequest implements stack.LinkAddressResolver.
func (*protocol) LinkAddressRequest(addr, localAddr tcpip.Address, remoteLinkAddr tcpip.LinkAddress, linkEP stack.LinkEndpoint) *tcpip.Error {
- snaddr := header.SolicitedNodeAddr(addr)
-
// TODO(b/148672031): Use stack.FindRoute instead of manually creating the
// route here. Note, we would need the nicID to do this properly so the right
// NIC (associated to linkEP) is used to send the NDP NS message.
- r := &stack.Route{
+ r := stack.Route{
LocalAddress: localAddr,
- RemoteAddress: snaddr,
+ RemoteAddress: addr,
RemoteLinkAddress: remoteLinkAddr,
}
+
+ // If a remote address is not already known, then send a multicast
+ // solicitation since multicast addresses have a static mapping to link
+ // addresses.
if len(r.RemoteLinkAddress) == 0 {
- r.RemoteLinkAddress = header.EthernetAddressFromMulticastIPv6Address(snaddr)
+ r.RemoteAddress = header.SolicitedNodeAddr(addr)
+ r.RemoteLinkAddress = header.EthernetAddressFromMulticastIPv6Address(r.RemoteAddress)
}
+ optsSerializer := header.NDPOptionsSerializer{
+ header.NDPSourceLinkLayerAddressOption(linkEP.LinkAddress()),
+ }
+ neighborSolicitSize := header.ICMPv6NeighborSolicitMinimumSize + optsSerializer.Length()
pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
- ReserveHeaderBytes: int(linkEP.MaxHeaderLength()) + header.IPv6MinimumSize + header.ICMPv6NeighborAdvertSize,
+ ReserveHeaderBytes: int(linkEP.MaxHeaderLength()) + header.IPv6MinimumSize + neighborSolicitSize,
})
- icmpHdr := header.ICMPv6(pkt.TransportHeader().Push(header.ICMPv6NeighborAdvertSize))
pkt.TransportProtocolNumber = header.ICMPv6ProtocolNumber
- icmpHdr.SetType(header.ICMPv6NeighborSolicit)
- copy(icmpHdr[icmpV6OptOffset-len(addr):], addr)
- icmpHdr[icmpV6OptOffset] = ndpOptSrcLinkAddr
- icmpHdr[icmpV6LengthOffset] = 1
- copy(icmpHdr[icmpV6LengthOffset+1:], linkEP.LinkAddress())
- icmpHdr.SetChecksum(header.ICMPv6Checksum(icmpHdr, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{}))
+ packet := header.ICMPv6(pkt.TransportHeader().Push(neighborSolicitSize))
+ packet.SetType(header.ICMPv6NeighborSolicit)
+ ns := header.NDPNeighborSolicit(packet.NDPPayload())
+ ns.SetTargetAddress(addr)
+ ns.Options().Serialize(optsSerializer)
+ packet.SetChecksum(header.ICMPv6Checksum(packet, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{}))
length := uint16(pkt.Size())
ip := header.IPv6(pkt.NetworkHeader().Push(header.IPv6MinimumSize))
@@ -658,7 +672,7 @@ func (*protocol) LinkAddressRequest(addr, localAddr tcpip.Address, remoteLinkAdd
})
// TODO(stijlist): count this in ICMP stats.
- return linkEP.WritePacket(r, nil /* gso */, ProtocolNumber, pkt)
+ return linkEP.WritePacket(&r, nil /* gso */, ProtocolNumber, pkt)
}
// ResolveStaticAddress implements stack.LinkAddressResolver.
@@ -676,6 +690,36 @@ type icmpReason interface {
isICMPReason()
}
+// icmpReasonParameterProblem is an error during processing of extension headers
+// or the fixed header defined in RFC 4443 section 3.4.
+type icmpReasonParameterProblem struct {
+ code header.ICMPv6Code
+
+ // respondToMulticast indicates that we are sending a packet that falls under
+ // the exception outlined by RFC 4443 section 2.4 point e.3 exception 2:
+ //
+ // (e.3) A packet destined to an IPv6 multicast address. (There are
+ // two exceptions to this rule: (1) the Packet Too Big Message
+ // (Section 3.2) to allow Path MTU discovery to work for IPv6
+ // multicast, and (2) the Parameter Problem Message, Code 2
+ // (Section 3.4) reporting an unrecognized IPv6 option (see
+ // Section 4.2 of [IPv6]) that has the Option Type highest-
+ // order two bits set to 10).
+ respondToMulticast bool
+
+ // pointer is defined in the RFC 4443 setion 3.4 which reads:
+ //
+ // Pointer Identifies the octet offset within the invoking packet
+ // where the error was detected.
+ //
+ // The pointer will point beyond the end of the ICMPv6
+ // packet if the field in error is beyond what can fit
+ // in the maximum size of an ICMPv6 error message.
+ pointer uint32
+}
+
+func (*icmpReasonParameterProblem) isICMPReason() {}
+
// icmpReasonPortUnreachable is an error where the transport protocol has no
// listener and no alternative means to inform the sender.
type icmpReasonPortUnreachable struct{}
@@ -684,18 +728,11 @@ func (*icmpReasonPortUnreachable) isICMPReason() {}
// returnError takes an error descriptor and generates the appropriate ICMP
// error packet for IPv6 and sends it.
-func returnError(r *stack.Route, reason icmpReason, pkt *stack.PacketBuffer) *tcpip.Error {
- stats := r.Stats().ICMP
- sent := stats.V6PacketsSent
- if !r.Stack().AllowICMPMessage() {
- sent.RateLimited.Increment()
- return nil
- }
-
+func (p *protocol) returnError(r *stack.Route, reason icmpReason, pkt *stack.PacketBuffer) *tcpip.Error {
// Only send ICMP error if the address is not a multicast v6
// address and the source is not the unspecified address.
//
- // TODO(b/164522993) There are exceptions to this rule.
+ // There are exceptions to this rule.
// See: point e.3) RFC 4443 section-2.4
//
// (e) An ICMPv6 error message MUST NOT be originated as a result of
@@ -713,7 +750,32 @@ func returnError(r *stack.Route, reason icmpReason, pkt *stack.PacketBuffer) *tc
// Section 4.2 of [IPv6]) that has the Option Type highest-
// order two bits set to 10).
//
- if header.IsV6MulticastAddress(r.LocalAddress) || r.RemoteAddress == header.IPv6Any {
+ var allowResponseToMulticast bool
+ if reason, ok := reason.(*icmpReasonParameterProblem); ok {
+ allowResponseToMulticast = reason.respondToMulticast
+ }
+
+ if (!allowResponseToMulticast && header.IsV6MulticastAddress(r.LocalAddress)) || r.RemoteAddress == header.IPv6Any {
+ return nil
+ }
+
+ // Even if we were able to receive a packet from some remote, we may not have
+ // a route to it - the remote may be blocked via routing rules. We must always
+ // consult our routing table and find a route to the remote before sending any
+ // packet.
+ route, err := p.stack.FindRoute(r.NICID(), r.LocalAddress, r.RemoteAddress, ProtocolNumber, false /* multicastLoop */)
+ if err != nil {
+ return err
+ }
+ defer route.Release()
+ // From this point on, the incoming route should no longer be used; route
+ // must be used to send the ICMP error.
+ r = nil
+
+ stats := p.stack.Stats().ICMP
+ sent := stats.V6PacketsSent
+ if !p.stack.AllowICMPMessage() {
+ sent.RateLimited.Increment()
return nil
}
@@ -743,11 +805,11 @@ func returnError(r *stack.Route, reason icmpReason, pkt *stack.PacketBuffer) *tc
// packet that caused the error) as possible without making
// the error message packet exceed the minimum IPv6 MTU
// [IPv6].
- mtu := int(r.MTU())
+ mtu := int(route.MTU())
if mtu > header.IPv6MinimumMTU {
mtu = header.IPv6MinimumMTU
}
- headerLen := int(r.MaxHeaderLength()) + header.ICMPv6ErrorHeaderSize
+ headerLen := int(route.MaxHeaderLength()) + header.ICMPv6ErrorHeaderSize
available := int(mtu) - headerLen
if available < header.IPv6MinimumSize {
return nil
@@ -766,12 +828,30 @@ func returnError(r *stack.Route, reason icmpReason, pkt *stack.PacketBuffer) *tc
newPkt.TransportProtocolNumber = header.ICMPv6ProtocolNumber
icmpHdr := header.ICMPv6(newPkt.TransportHeader().Push(header.ICMPv6DstUnreachableMinimumSize))
- icmpHdr.SetCode(header.ICMPv6PortUnreachable)
- icmpHdr.SetType(header.ICMPv6DstUnreachable)
- icmpHdr.SetChecksum(header.ICMPv6Checksum(icmpHdr, r.LocalAddress, r.RemoteAddress, newPkt.Data))
- counter := sent.DstUnreachable
- err := r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: r.DefaultTTL(), TOS: stack.DefaultTOS}, newPkt)
- if err != nil {
+ var counter *tcpip.StatCounter
+ switch reason := reason.(type) {
+ case *icmpReasonParameterProblem:
+ icmpHdr.SetType(header.ICMPv6ParamProblem)
+ icmpHdr.SetCode(reason.code)
+ icmpHdr.SetTypeSpecific(reason.pointer)
+ counter = sent.ParamProblem
+ case *icmpReasonPortUnreachable:
+ icmpHdr.SetType(header.ICMPv6DstUnreachable)
+ icmpHdr.SetCode(header.ICMPv6PortUnreachable)
+ counter = sent.DstUnreachable
+ default:
+ panic(fmt.Sprintf("unsupported ICMP type %T", reason))
+ }
+ icmpHdr.SetChecksum(header.ICMPv6Checksum(icmpHdr, route.LocalAddress, route.RemoteAddress, newPkt.Data))
+ if err := route.WritePacket(
+ nil, /* gso */
+ stack.NetworkHeaderParams{
+ Protocol: header.ICMPv6ProtocolNumber,
+ TTL: route.DefaultTTL(),
+ TOS: stack.DefaultTOS,
+ },
+ newPkt,
+ ); err != nil {
sent.Dropped.Increment()
return err
}
diff --git a/pkg/tcpip/network/ipv6/icmp_test.go b/pkg/tcpip/network/ipv6/icmp_test.go
index dd58022d6..3affcc4e4 100644
--- a/pkg/tcpip/network/ipv6/icmp_test.go
+++ b/pkg/tcpip/network/ipv6/icmp_test.go
@@ -16,17 +16,21 @@ package ipv6
import (
"context"
+ "net"
"reflect"
"strings"
"testing"
+ "time"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/checker"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/link/channel"
"gvisor.dev/gvisor/pkg/tcpip/link/sniffer"
"gvisor.dev/gvisor/pkg/tcpip/stack"
"gvisor.dev/gvisor/pkg/tcpip/transport/icmp"
+ "gvisor.dev/gvisor/pkg/tcpip/transport/udp"
"gvisor.dev/gvisor/pkg/waiter"
)
@@ -39,6 +43,9 @@ const (
defaultChannelSize = 1
defaultMTU = 65536
+
+ // Extra time to use when waiting for an async event to occur.
+ defaultAsyncPositiveEventTimeout = 30 * time.Second
)
var (
@@ -50,6 +57,10 @@ type stubLinkEndpoint struct {
stack.LinkEndpoint
}
+func (*stubLinkEndpoint) MTU() uint32 {
+ return defaultMTU
+}
+
func (*stubLinkEndpoint) Capabilities() stack.LinkEndpointCapabilities {
// Indicate that resolution for link layer addresses is required to send
// packets over this link. This is needed so the NIC knows to allocate a
@@ -103,6 +114,28 @@ func (*stubNUDHandler) HandleConfirmation(addr tcpip.Address, linkAddr tcpip.Lin
func (*stubNUDHandler) HandleUpperLevelConfirmation(addr tcpip.Address) {
}
+var _ stack.NetworkInterface = (*testInterface)(nil)
+
+type testInterface struct {
+ stack.NetworkLinkEndpoint
+}
+
+func (*testInterface) ID() tcpip.NICID {
+ return 0
+}
+
+func (*testInterface) IsLoopback() bool {
+ return false
+}
+
+func (*testInterface) Name() string {
+ return ""
+}
+
+func (*testInterface) Enabled() bool {
+ return true
+}
+
func TestICMPCounts(t *testing.T) {
tests := []struct {
name string
@@ -150,9 +183,13 @@ func TestICMPCounts(t *testing.T) {
if netProto == nil {
t.Fatalf("cannot find protocol instance for network protocol %d", ProtocolNumber)
}
- ep := netProto.NewEndpoint(0, &stubLinkAddressCache{}, &stubNUDHandler{}, &stubDispatcher{}, nil, s)
+ ep := netProto.NewEndpoint(&testInterface{}, &stubLinkAddressCache{}, &stubNUDHandler{}, &stubDispatcher{})
defer ep.Close()
+ if err := ep.Enable(); err != nil {
+ t.Fatalf("ep.Enable(): %s", err)
+ }
+
r, err := s.FindRoute(nicID, lladdr0, lladdr1, ProtocolNumber, false /* multicastLoop */)
if err != nil {
t.Fatalf("FindRoute(%d, %s, %s, _, false) = (_, %s), want = (_, nil)", nicID, lladdr0, lladdr1, err)
@@ -288,9 +325,13 @@ func TestICMPCountsWithNeighborCache(t *testing.T) {
if netProto == nil {
t.Fatalf("cannot find protocol instance for network protocol %d", ProtocolNumber)
}
- ep := netProto.NewEndpoint(0, nil, &stubNUDHandler{}, &stubDispatcher{}, nil, s)
+ ep := netProto.NewEndpoint(&testInterface{}, nil, &stubNUDHandler{}, &stubDispatcher{})
defer ep.Close()
+ if err := ep.Enable(); err != nil {
+ t.Fatalf("ep.Enable(): %s", err)
+ }
+
r, err := s.FindRoute(nicID, lladdr0, lladdr1, ProtocolNumber, false /* multicastLoop */)
if err != nil {
t.Fatalf("FindRoute(%d, %s, %s, _, false) = (_, %s), want = (_, nil)", nicID, lladdr0, lladdr1, err)
@@ -1187,24 +1228,30 @@ func TestLinkAddressRequest(t *testing.T) {
mcaddr := header.EthernetAddressFromMulticastIPv6Address(snaddr)
tests := []struct {
- name string
- remoteLinkAddr tcpip.LinkAddress
- expectLinkAddr tcpip.LinkAddress
+ name string
+ remoteLinkAddr tcpip.LinkAddress
+ expectedLinkAddr tcpip.LinkAddress
+ expectedAddr tcpip.Address
}{
{
- name: "Unicast",
- remoteLinkAddr: linkAddr1,
- expectLinkAddr: linkAddr1,
+ name: "Unicast",
+ remoteLinkAddr: linkAddr1,
+ expectedLinkAddr: linkAddr1,
+ expectedAddr: lladdr0,
},
{
- name: "Multicast",
- remoteLinkAddr: "",
- expectLinkAddr: mcaddr,
+ name: "Multicast",
+ remoteLinkAddr: "",
+ expectedLinkAddr: mcaddr,
+ expectedAddr: snaddr,
},
}
for _, test := range tests {
- p := NewProtocol(nil)
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol},
+ })
+ p := s.NetworkProtocolInstance(ProtocolNumber)
linkRes, ok := p.(stack.LinkAddressResolver)
if !ok {
t.Fatalf("expected IPv6 protocol to implement stack.LinkAddressResolver")
@@ -1219,9 +1266,229 @@ func TestLinkAddressRequest(t *testing.T) {
if !ok {
t.Fatal("expected to send a link address request")
}
+ if pkt.Route.RemoteLinkAddress != test.expectedLinkAddr {
+ t.Errorf("got pkt.Route.RemoteLinkAddress = %s, want = %s", pkt.Route.RemoteLinkAddress, test.expectedLinkAddr)
+ }
+ if pkt.Route.RemoteAddress != test.expectedAddr {
+ t.Errorf("got pkt.Route.RemoteAddress = %s, want = %s", pkt.Route.RemoteAddress, test.expectedAddr)
+ }
+ if pkt.Route.LocalAddress != lladdr1 {
+ t.Errorf("got pkt.Route.LocalAddress = %s, want = %s", pkt.Route.LocalAddress, lladdr1)
+ }
+ checker.IPv6(t, stack.PayloadSince(pkt.Pkt.NetworkHeader()),
+ checker.SrcAddr(lladdr1),
+ checker.DstAddr(test.expectedAddr),
+ checker.TTL(header.NDPHopLimit),
+ checker.NDPNS(
+ checker.NDPNSTargetAddress(lladdr0),
+ checker.NDPNSOptions([]header.NDPOption{header.NDPSourceLinkLayerAddressOption(linkAddr0)}),
+ ))
+ }
+}
- if got, want := pkt.Route.RemoteLinkAddress, test.expectLinkAddr; got != want {
- t.Errorf("got pkt.Route.RemoteLinkAddress = %s, want = %s", got, want)
+func TestPacketQueing(t *testing.T) {
+ const nicID = 1
+
+ var (
+ host1NICLinkAddr = tcpip.LinkAddress("\x02\x03\x03\x04\x05\x06")
+ host2NICLinkAddr = tcpip.LinkAddress("\x02\x03\x03\x04\x05\x09")
+
+ host1IPv6Addr = tcpip.ProtocolAddress{
+ Protocol: ProtocolNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: tcpip.Address(net.ParseIP("a::1").To16()),
+ PrefixLen: 64,
+ },
}
+ host2IPv6Addr = tcpip.ProtocolAddress{
+ Protocol: ProtocolNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: tcpip.Address(net.ParseIP("a::2").To16()),
+ PrefixLen: 64,
+ },
+ }
+ )
+
+ tests := []struct {
+ name string
+ rxPkt func(*channel.Endpoint)
+ checkResp func(*testing.T, *channel.Endpoint)
+ }{
+ {
+ name: "ICMP Error",
+ rxPkt: func(e *channel.Endpoint) {
+ hdr := buffer.NewPrependable(header.IPv6MinimumSize + header.UDPMinimumSize)
+ u := header.UDP(hdr.Prepend(header.UDPMinimumSize))
+ u.Encode(&header.UDPFields{
+ SrcPort: 5555,
+ DstPort: 80,
+ Length: header.UDPMinimumSize,
+ })
+ sum := header.PseudoHeaderChecksum(udp.ProtocolNumber, host2IPv6Addr.AddressWithPrefix.Address, host1IPv6Addr.AddressWithPrefix.Address, header.UDPMinimumSize)
+ sum = header.Checksum(header.UDP([]byte{}), sum)
+ u.SetChecksum(^u.CalculateChecksum(sum))
+ payloadLength := hdr.UsedLength()
+ ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
+ ip.Encode(&header.IPv6Fields{
+ PayloadLength: uint16(payloadLength),
+ NextHeader: uint8(udp.ProtocolNumber),
+ HopLimit: DefaultTTL,
+ SrcAddr: host2IPv6Addr.AddressWithPrefix.Address,
+ DstAddr: host1IPv6Addr.AddressWithPrefix.Address,
+ })
+ e.InjectInbound(ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
+ Data: hdr.View().ToVectorisedView(),
+ }))
+ },
+ checkResp: func(t *testing.T, e *channel.Endpoint) {
+ p, ok := e.ReadContext(context.Background())
+ if !ok {
+ t.Fatalf("timed out waiting for packet")
+ }
+ if p.Proto != ProtocolNumber {
+ t.Errorf("got p.Proto = %d, want = %d", p.Proto, ProtocolNumber)
+ }
+ if p.Route.RemoteLinkAddress != host2NICLinkAddr {
+ t.Errorf("got p.Route.RemoteLinkAddress = %s, want = %s", p.Route.RemoteLinkAddress, host2NICLinkAddr)
+ }
+ checker.IPv6(t, stack.PayloadSince(p.Pkt.NetworkHeader()),
+ checker.SrcAddr(host1IPv6Addr.AddressWithPrefix.Address),
+ checker.DstAddr(host2IPv6Addr.AddressWithPrefix.Address),
+ checker.ICMPv6(
+ checker.ICMPv6Type(header.ICMPv6DstUnreachable),
+ checker.ICMPv6Code(header.ICMPv6PortUnreachable)))
+ },
+ },
+
+ {
+ name: "Ping",
+ rxPkt: func(e *channel.Endpoint) {
+ totalLen := header.IPv6MinimumSize + header.ICMPv6MinimumSize
+ hdr := buffer.NewPrependable(totalLen)
+ pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6MinimumSize))
+ pkt.SetType(header.ICMPv6EchoRequest)
+ pkt.SetCode(0)
+ pkt.SetChecksum(0)
+ pkt.SetChecksum(header.ICMPv6Checksum(pkt, host2IPv6Addr.AddressWithPrefix.Address, host1IPv6Addr.AddressWithPrefix.Address, buffer.VectorisedView{}))
+ ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
+ ip.Encode(&header.IPv6Fields{
+ PayloadLength: header.ICMPv6MinimumSize,
+ NextHeader: uint8(icmp.ProtocolNumber6),
+ HopLimit: DefaultTTL,
+ SrcAddr: host2IPv6Addr.AddressWithPrefix.Address,
+ DstAddr: host1IPv6Addr.AddressWithPrefix.Address,
+ })
+ e.InjectInbound(header.IPv6ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
+ Data: hdr.View().ToVectorisedView(),
+ }))
+ },
+ checkResp: func(t *testing.T, e *channel.Endpoint) {
+ p, ok := e.ReadContext(context.Background())
+ if !ok {
+ t.Fatalf("timed out waiting for packet")
+ }
+ if p.Proto != ProtocolNumber {
+ t.Errorf("got p.Proto = %d, want = %d", p.Proto, ProtocolNumber)
+ }
+ if p.Route.RemoteLinkAddress != host2NICLinkAddr {
+ t.Errorf("got p.Route.RemoteLinkAddress = %s, want = %s", p.Route.RemoteLinkAddress, host2NICLinkAddr)
+ }
+ checker.IPv6(t, stack.PayloadSince(p.Pkt.NetworkHeader()),
+ checker.SrcAddr(host1IPv6Addr.AddressWithPrefix.Address),
+ checker.DstAddr(host2IPv6Addr.AddressWithPrefix.Address),
+ checker.ICMPv6(
+ checker.ICMPv6Type(header.ICMPv6EchoReply),
+ checker.ICMPv6Code(header.ICMPv6UnusedCode)))
+ },
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+
+ e := channel.New(1, header.IPv6MinimumMTU, host1NICLinkAddr)
+ e.LinkEPCapabilities |= stack.CapabilityResolutionRequired
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol},
+ TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol},
+ })
+
+ if err := s.CreateNIC(nicID, e); err != nil {
+ t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err)
+ }
+ if err := s.AddProtocolAddress(nicID, host1IPv6Addr); err != nil {
+ t.Fatalf("s.AddProtocolAddress(%d, %#v): %s", nicID, host1IPv6Addr, err)
+ }
+
+ s.SetRouteTable([]tcpip.Route{
+ tcpip.Route{
+ Destination: host1IPv6Addr.AddressWithPrefix.Subnet(),
+ NIC: nicID,
+ },
+ })
+
+ // Receive a packet to trigger link resolution before a response is sent.
+ test.rxPkt(e)
+
+ // Wait for a neighbor solicitation since link address resolution should
+ // be performed.
+ {
+ p, ok := e.ReadContext(context.Background())
+ if !ok {
+ t.Fatalf("timed out waiting for packet")
+ }
+ if p.Proto != ProtocolNumber {
+ t.Errorf("got Proto = %d, want = %d", p.Proto, ProtocolNumber)
+ }
+ snmc := header.SolicitedNodeAddr(host2IPv6Addr.AddressWithPrefix.Address)
+ if want := header.EthernetAddressFromMulticastIPv6Address(snmc); p.Route.RemoteLinkAddress != want {
+ t.Errorf("got p.Route.RemoteLinkAddress = %s, want = %s", p.Route.RemoteLinkAddress, want)
+ }
+ checker.IPv6(t, stack.PayloadSince(p.Pkt.NetworkHeader()),
+ checker.SrcAddr(host1IPv6Addr.AddressWithPrefix.Address),
+ checker.DstAddr(snmc),
+ checker.TTL(header.NDPHopLimit),
+ checker.NDPNS(
+ checker.NDPNSTargetAddress(host2IPv6Addr.AddressWithPrefix.Address),
+ checker.NDPNSOptions([]header.NDPOption{header.NDPSourceLinkLayerAddressOption(host1NICLinkAddr)}),
+ ))
+ }
+
+ // Send a neighbor advertisement to complete link address resolution.
+ {
+ naSize := header.ICMPv6NeighborAdvertMinimumSize + header.NDPLinkLayerAddressSize
+ hdr := buffer.NewPrependable(header.IPv6MinimumSize + naSize)
+ pkt := header.ICMPv6(hdr.Prepend(naSize))
+ pkt.SetType(header.ICMPv6NeighborAdvert)
+ na := header.NDPNeighborAdvert(pkt.NDPPayload())
+ na.SetSolicitedFlag(true)
+ na.SetOverrideFlag(true)
+ na.SetTargetAddress(host2IPv6Addr.AddressWithPrefix.Address)
+ na.Options().Serialize(header.NDPOptionsSerializer{
+ header.NDPTargetLinkLayerAddressOption(host2NICLinkAddr),
+ })
+ pkt.SetChecksum(header.ICMPv6Checksum(pkt, host2IPv6Addr.AddressWithPrefix.Address, host1IPv6Addr.AddressWithPrefix.Address, buffer.VectorisedView{}))
+ payloadLength := hdr.UsedLength()
+ ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
+ ip.Encode(&header.IPv6Fields{
+ PayloadLength: uint16(payloadLength),
+ NextHeader: uint8(icmp.ProtocolNumber6),
+ HopLimit: header.NDPHopLimit,
+ SrcAddr: host2IPv6Addr.AddressWithPrefix.Address,
+ DstAddr: host1IPv6Addr.AddressWithPrefix.Address,
+ })
+ e.InjectInbound(ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
+ Data: hdr.View().ToVectorisedView(),
+ }))
+ }
+
+ // Expect the response now that the link address has resolved.
+ test.checkResp(t, e)
+
+ // Since link resolution was already performed, it shouldn't be performed
+ // again.
+ test.rxPkt(e)
+ test.checkResp(t, e)
+ })
}
}
diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go
index e436c6a9e..2bd8f4ece 100644
--- a/pkg/tcpip/network/ipv6/ipv6.go
+++ b/pkg/tcpip/network/ipv6/ipv6.go
@@ -1,4 +1,4 @@
-// Copyright 2018 The gVisor Authors.
+// Copyright 2020 The gVisor Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -16,18 +16,33 @@
package ipv6
import (
+ "encoding/binary"
"fmt"
+ "hash/fnv"
+ "sort"
"sync/atomic"
+ "time"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/header/parse"
"gvisor.dev/gvisor/pkg/tcpip/network/fragmentation"
+ "gvisor.dev/gvisor/pkg/tcpip/network/hash"
"gvisor.dev/gvisor/pkg/tcpip/stack"
)
const (
+ // As per RFC 8200 section 4.5:
+ // If insufficient fragments are received to complete reassembly of a packet
+ // within 60 seconds of the reception of the first-arriving fragment of that
+ // packet, reassembly of that packet must be abandoned.
+ //
+ // Linux also uses 60 seconds for reassembly timeout:
+ // https://github.com/torvalds/linux/blob/47ec5303d73ea344e84f46660fff693c57641386/include/net/ipv6.h#L456
+ reassembleTimeout = 60 * time.Second
+
// ProtocolNumber is the ipv6 protocol number.
ProtocolNumber = header.IPv6ProtocolNumber
@@ -38,16 +53,306 @@ const (
// DefaultTTL is the default hop limit for IPv6 Packets egressed by
// Netstack.
DefaultTTL = 64
+
+ // buckets for fragment identifiers
+ buckets = 2048
)
+var _ stack.GroupAddressableEndpoint = (*endpoint)(nil)
+var _ stack.AddressableEndpoint = (*endpoint)(nil)
+var _ stack.NetworkEndpoint = (*endpoint)(nil)
+var _ stack.NDPEndpoint = (*endpoint)(nil)
+var _ NDPEndpoint = (*endpoint)(nil)
+
type endpoint struct {
- nicID tcpip.NICID
- linkEP stack.LinkEndpoint
+ nic stack.NetworkInterface
linkAddrCache stack.LinkAddressCache
nud stack.NUDHandler
dispatcher stack.TransportDispatcher
protocol *protocol
stack *stack.Stack
+
+ // enabled is set to 1 when the endpoint is enabled and 0 when it is
+ // disabled.
+ //
+ // Must be accessed using atomic operations.
+ enabled uint32
+
+ mu struct {
+ sync.RWMutex
+
+ addressableEndpointState stack.AddressableEndpointState
+ ndp ndpState
+ }
+}
+
+// NICNameFromID is a function that returns a stable name for the specified NIC,
+// even if different NIC IDs are used to refer to the same NIC in different
+// program runs. It is used when generating opaque interface identifiers (IIDs).
+// If the NIC was created with a name, it is passed to NICNameFromID.
+//
+// NICNameFromID SHOULD return unique NIC names so unique opaque IIDs are
+// generated for the same prefix on differnt NICs.
+type NICNameFromID func(tcpip.NICID, string) string
+
+// OpaqueInterfaceIdentifierOptions holds the options related to the generation
+// of opaque interface indentifiers (IIDs) as defined by RFC 7217.
+type OpaqueInterfaceIdentifierOptions struct {
+ // NICNameFromID is a function that returns a stable name for a specified NIC,
+ // even if the NIC ID changes over time.
+ //
+ // Must be specified to generate the opaque IID.
+ NICNameFromID NICNameFromID
+
+ // SecretKey is a pseudo-random number used as the secret key when generating
+ // opaque IIDs as defined by RFC 7217. The key SHOULD be at least
+ // header.OpaqueIIDSecretKeyMinBytes bytes and MUST follow minimum randomness
+ // requirements for security as outlined by RFC 4086. SecretKey MUST NOT
+ // change between program runs, unless explicitly changed.
+ //
+ // OpaqueInterfaceIdentifierOptions takes ownership of SecretKey. SecretKey
+ // MUST NOT be modified after Stack is created.
+ //
+ // May be nil, but a nil value is highly discouraged to maintain
+ // some level of randomness between nodes.
+ SecretKey []byte
+}
+
+// InvalidateDefaultRouter implements stack.NDPEndpoint.
+func (e *endpoint) InvalidateDefaultRouter(rtr tcpip.Address) {
+ e.mu.Lock()
+ defer e.mu.Unlock()
+ e.mu.ndp.invalidateDefaultRouter(rtr)
+}
+
+// SetNDPConfigurations implements NDPEndpoint.
+func (e *endpoint) SetNDPConfigurations(c NDPConfigurations) {
+ c.validate()
+ e.mu.Lock()
+ defer e.mu.Unlock()
+ e.mu.ndp.configs = c
+}
+
+// hasTentativeAddr returns true if addr is tentative on e.
+func (e *endpoint) hasTentativeAddr(addr tcpip.Address) bool {
+ e.mu.RLock()
+ addressEndpoint := e.getAddressRLocked(addr)
+ e.mu.RUnlock()
+ return addressEndpoint != nil && addressEndpoint.GetKind() == stack.PermanentTentative
+}
+
+// dupTentativeAddrDetected attempts to inform e that a tentative addr is a
+// duplicate on a link.
+//
+// dupTentativeAddrDetected removes the tentative address if it exists. If the
+// address was generated via SLAAC, an attempt is made to generate a new
+// address.
+func (e *endpoint) dupTentativeAddrDetected(addr tcpip.Address) *tcpip.Error {
+ e.mu.Lock()
+ defer e.mu.Unlock()
+
+ addressEndpoint := e.getAddressRLocked(addr)
+ if addressEndpoint == nil {
+ return tcpip.ErrBadAddress
+ }
+
+ if addressEndpoint.GetKind() != stack.PermanentTentative {
+ return tcpip.ErrInvalidEndpointState
+ }
+
+ // If the address is a SLAAC address, do not invalidate its SLAAC prefix as an
+ // attempt will be made to generate a new address for it.
+ if err := e.removePermanentEndpointLocked(addressEndpoint, false /* allowSLAACInvalidation */); err != nil {
+ return err
+ }
+
+ prefix := addressEndpoint.AddressWithPrefix().Subnet()
+
+ switch t := addressEndpoint.ConfigType(); t {
+ case stack.AddressConfigStatic:
+ case stack.AddressConfigSlaac:
+ e.mu.ndp.regenerateSLAACAddr(prefix)
+ case stack.AddressConfigSlaacTemp:
+ // Do not reset the generation attempts counter for the prefix as the
+ // temporary address is being regenerated in response to a DAD conflict.
+ e.mu.ndp.regenerateTempSLAACAddr(prefix, false /* resetGenAttempts */)
+ default:
+ panic(fmt.Sprintf("unrecognized address config type = %d", t))
+ }
+
+ return nil
+}
+
+// transitionForwarding transitions the endpoint's forwarding status to
+// forwarding.
+//
+// Must only be called when the forwarding status changes.
+func (e *endpoint) transitionForwarding(forwarding bool) {
+ e.mu.Lock()
+ defer e.mu.Unlock()
+
+ if !e.Enabled() {
+ return
+ }
+
+ if forwarding {
+ // When transitioning into an IPv6 router, host-only state (NDP discovered
+ // routers, discovered on-link prefixes, and auto-generated addresses) is
+ // cleaned up/invalidated and NDP router solicitations are stopped.
+ e.mu.ndp.stopSolicitingRouters()
+ e.mu.ndp.cleanupState(true /* hostOnly */)
+ } else {
+ // When transitioning into an IPv6 host, NDP router solicitations are
+ // started.
+ e.mu.ndp.startSolicitingRouters()
+ }
+}
+
+// Enable implements stack.NetworkEndpoint.
+func (e *endpoint) Enable() *tcpip.Error {
+ e.mu.Lock()
+ defer e.mu.Unlock()
+
+ // If the NIC is not enabled, the endpoint can't do anything meaningful so
+ // don't enable the endpoint.
+ if !e.nic.Enabled() {
+ return tcpip.ErrNotPermitted
+ }
+
+ // If the endpoint is already enabled, there is nothing for it to do.
+ if !e.setEnabled(true) {
+ return nil
+ }
+
+ // Join the IPv6 All-Nodes Multicast group if the stack is configured to
+ // use IPv6. This is required to ensure that this node properly receives
+ // and responds to the various NDP messages that are destined to the
+ // all-nodes multicast address. An example is the Neighbor Advertisement
+ // when we perform Duplicate Address Detection, or Router Advertisement
+ // when we do Router Discovery. See RFC 4862, section 5.4.2 and RFC 4861
+ // section 4.2 for more information.
+ //
+ // Also auto-generate an IPv6 link-local address based on the endpoint's
+ // link address if it is configured to do so. Note, each interface is
+ // required to have IPv6 link-local unicast address, as per RFC 4291
+ // section 2.1.
+
+ // Join the All-Nodes multicast group before starting DAD as responses to DAD
+ // (NDP NS) messages may be sent to the All-Nodes multicast group if the
+ // source address of the NDP NS is the unspecified address, as per RFC 4861
+ // section 7.2.4.
+ if _, err := e.mu.addressableEndpointState.JoinGroup(header.IPv6AllNodesMulticastAddress); err != nil {
+ return err
+ }
+
+ // Perform DAD on the all the unicast IPv6 endpoints that are in the permanent
+ // state.
+ //
+ // Addresses may have aleady completed DAD but in the time since the endpoint
+ // was last enabled, other devices may have acquired the same addresses.
+ var err *tcpip.Error
+ e.mu.addressableEndpointState.ReadOnly().ForEach(func(addressEndpoint stack.AddressEndpoint) bool {
+ addr := addressEndpoint.AddressWithPrefix().Address
+ if !header.IsV6UnicastAddress(addr) {
+ return true
+ }
+
+ switch addressEndpoint.GetKind() {
+ case stack.Permanent:
+ addressEndpoint.SetKind(stack.PermanentTentative)
+ fallthrough
+ case stack.PermanentTentative:
+ err = e.mu.ndp.startDuplicateAddressDetection(addr, addressEndpoint)
+ return err == nil
+ default:
+ return true
+ }
+ })
+ if err != nil {
+ return err
+ }
+
+ // Do not auto-generate an IPv6 link-local address for loopback devices.
+ if e.protocol.autoGenIPv6LinkLocal && !e.nic.IsLoopback() {
+ // The valid and preferred lifetime is infinite for the auto-generated
+ // link-local address.
+ e.mu.ndp.doSLAAC(header.IPv6LinkLocalPrefix.Subnet(), header.NDPInfiniteLifetime, header.NDPInfiniteLifetime)
+ }
+
+ // If we are operating as a router, then do not solicit routers since we
+ // won't process the RAs anyway.
+ //
+ // Routers do not process Router Advertisements (RA) the same way a host
+ // does. That is, routers do not learn from RAs (e.g. on-link prefixes
+ // and default routers). Therefore, soliciting RAs from other routers on
+ // a link is unnecessary for routers.
+ if !e.protocol.Forwarding() {
+ e.mu.ndp.startSolicitingRouters()
+ }
+
+ return nil
+}
+
+// Enabled implements stack.NetworkEndpoint.
+func (e *endpoint) Enabled() bool {
+ return e.nic.Enabled() && e.isEnabled()
+}
+
+// isEnabled returns true if the endpoint is enabled, regardless of the
+// enabled status of the NIC.
+func (e *endpoint) isEnabled() bool {
+ return atomic.LoadUint32(&e.enabled) == 1
+}
+
+// setEnabled sets the enabled status for the endpoint.
+//
+// Returns true if the enabled status was updated.
+func (e *endpoint) setEnabled(v bool) bool {
+ if v {
+ return atomic.SwapUint32(&e.enabled, 1) == 0
+ }
+ return atomic.SwapUint32(&e.enabled, 0) == 1
+}
+
+// Disable implements stack.NetworkEndpoint.
+func (e *endpoint) Disable() {
+ e.mu.Lock()
+ defer e.mu.Unlock()
+ e.disableLocked()
+}
+
+func (e *endpoint) disableLocked() {
+ if !e.setEnabled(false) {
+ return
+ }
+
+ e.mu.ndp.stopSolicitingRouters()
+ e.mu.ndp.cleanupState(false /* hostOnly */)
+ e.stopDADForPermanentAddressesLocked()
+
+ // The endpoint may have already left the multicast group.
+ if _, err := e.mu.addressableEndpointState.LeaveGroup(header.IPv6AllNodesMulticastAddress); err != nil && err != tcpip.ErrBadLocalAddress {
+ panic(fmt.Sprintf("unexpected error when leaving group = %s: %s", header.IPv6AllNodesMulticastAddress, err))
+ }
+}
+
+// stopDADForPermanentAddressesLocked stops DAD for all permaneent addresses.
+//
+// Precondition: e.mu must be write locked.
+func (e *endpoint) stopDADForPermanentAddressesLocked() {
+ // Stop DAD for all the tentative unicast addresses.
+ e.mu.addressableEndpointState.ReadOnly().ForEach(func(addressEndpoint stack.AddressEndpoint) bool {
+ if addressEndpoint.GetKind() != stack.PermanentTentative {
+ return true
+ }
+
+ addr := addressEndpoint.AddressWithPrefix().Address
+ if header.IsV6UnicastAddress(addr) {
+ e.mu.ndp.stopDuplicateAddressDetection(addr)
+ }
+
+ return true
+ })
}
// DefaultTTL is the default hop limit for this endpoint.
@@ -58,31 +363,13 @@ func (e *endpoint) DefaultTTL() uint8 {
// MTU implements stack.NetworkEndpoint.MTU. It returns the link-layer MTU minus
// the network layer max header length.
func (e *endpoint) MTU() uint32 {
- return calculateMTU(e.linkEP.MTU())
-}
-
-// NICID returns the ID of the NIC this endpoint belongs to.
-func (e *endpoint) NICID() tcpip.NICID {
- return e.nicID
-}
-
-// Capabilities implements stack.NetworkEndpoint.Capabilities.
-func (e *endpoint) Capabilities() stack.LinkEndpointCapabilities {
- return e.linkEP.Capabilities()
+ return calculateMTU(e.nic.MTU())
}
// MaxHeaderLength returns the maximum length needed by ipv6 headers (and
// underlying protocols).
func (e *endpoint) MaxHeaderLength() uint16 {
- return e.linkEP.MaxHeaderLength() + header.IPv6MinimumSize
-}
-
-// GSOMaxSize returns the maximum GSO packet size.
-func (e *endpoint) GSOMaxSize() uint32 {
- if gso, ok := e.linkEP.(stack.GSOEndpoint); ok {
- return gso.GSOMaxSize()
- }
- return 0
+ return e.nic.MaxHeaderLength() + header.IPv6MinimumSize
}
func (e *endpoint) addIPHeader(r *stack.Route, pkt *stack.PacketBuffer, params stack.NetworkHeaderParams) {
@@ -96,7 +383,44 @@ func (e *endpoint) addIPHeader(r *stack.Route, pkt *stack.PacketBuffer, params s
SrcAddr: r.LocalAddress,
DstAddr: r.RemoteAddress,
})
- pkt.NetworkProtocolNumber = header.IPv6ProtocolNumber
+ pkt.NetworkProtocolNumber = ProtocolNumber
+}
+
+func (e *endpoint) packetMustBeFragmented(pkt *stack.PacketBuffer, gso *stack.GSO) bool {
+ return pkt.Size() > int(e.nic.MTU()) && (gso == nil || gso.Type == stack.GSONone)
+}
+
+// handleFragments fragments pkt and calls the handler function on each
+// fragment. It returns the number of fragments handled and the number of
+// fragments left to be processed. The IP header must already be present in the
+// original packet. The mtu is the maximum size of the packets. The transport
+// header protocol number is required to avoid parsing the IPv6 extension
+// headers.
+func (e *endpoint) handleFragments(r *stack.Route, gso *stack.GSO, mtu uint32, pkt *stack.PacketBuffer, transProto tcpip.TransportProtocolNumber, handler func(*stack.PacketBuffer) *tcpip.Error) (int, int, *tcpip.Error) {
+ fragMTU := int(calculateFragmentInnerMTU(mtu, pkt))
+ if fragMTU < pkt.TransportHeader().View().Size() {
+ // As per RFC 8200 Section 4.5, the Transport Header is expected to be small
+ // enough to fit in the first fragment.
+ return 0, 1, tcpip.ErrMessageTooLong
+ }
+
+ pf := fragmentation.MakePacketFragmenter(pkt, fragMTU, calculateFragmentReserve(pkt))
+ id := atomic.AddUint32(&e.protocol.ids[hashRoute(r, e.protocol.hashIV)%buckets], 1)
+ networkHeader := header.IPv6(pkt.NetworkHeader().View())
+
+ var n int
+ for {
+ fragPkt, more := buildNextFragment(&pf, networkHeader, transProto, id)
+ if err := handler(fragPkt); err != nil {
+ return n, pf.RemainingFragmentCount() + 1, err
+ }
+ n++
+ if !more {
+ break
+ }
+ }
+
+ return n, 0, nil
}
// WritePacket writes a packet to the given destination address and protocol.
@@ -105,8 +429,8 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.Netw
// iptables filtering. All packets that reach here are locally
// generated.
- nicName := e.stack.FindNICNameFromID(e.NICID())
- ipt := e.stack.IPTables()
+ nicName := e.protocol.stack.FindNICNameFromID(e.nic.ID())
+ ipt := e.protocol.stack.IPTables()
if ok := ipt.Check(stack.Output, pkt, gso, r, "", nicName); !ok {
// iptables is telling us to drop the packet.
r.Stats().IP.IPTablesOutputDropped.Increment()
@@ -122,7 +446,7 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.Netw
// short circuits broadcasts before they are sent out to other hosts.
if pkt.NatDone {
netHeader := header.IPv6(pkt.NetworkHeader().View())
- if ep, err := e.stack.FindNetworkEndpoint(header.IPv6ProtocolNumber, netHeader.DestinationAddress()); err == nil {
+ if ep, err := e.protocol.stack.FindNetworkEndpoint(ProtocolNumber, netHeader.DestinationAddress()); err == nil {
route := r.ReverseRoute(netHeader.SourceAddress(), netHeader.DestinationAddress())
ep.HandlePacket(&route, pkt)
return nil
@@ -143,14 +467,29 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.Netw
return nil
}
- if err := e.linkEP.WritePacket(r, gso, ProtocolNumber, pkt); err != nil {
+ if e.packetMustBeFragmented(pkt, gso) {
+ sent, remain, err := e.handleFragments(r, gso, e.nic.MTU(), pkt, params.Protocol, func(fragPkt *stack.PacketBuffer) *tcpip.Error {
+ // TODO(gvisor.dev/issue/3884): Evaluate whether we want to send each
+ // fragment one by one using WritePacket() (current strategy) or if we
+ // want to create a PacketBufferList from the fragments and feed it to
+ // WritePackets(). It'll be faster but cost more memory.
+ return e.nic.WritePacket(r, gso, ProtocolNumber, fragPkt)
+ })
+ r.Stats().IP.PacketsSent.IncrementBy(uint64(sent))
+ r.Stats().IP.OutgoingPacketErrors.IncrementBy(uint64(remain))
+ return err
+ }
+
+ if err := e.nic.WritePacket(r, gso, ProtocolNumber, pkt); err != nil {
+ r.Stats().IP.OutgoingPacketErrors.Increment()
return err
}
+
r.Stats().IP.PacketsSent.Increment()
return nil
}
-// WritePackets implements stack.LinkEndpoint.WritePackets.
+// WritePackets implements stack.NetworkEndpoint.WritePackets.
func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.PacketBufferList, params stack.NetworkHeaderParams) (int, *tcpip.Error) {
if r.Loop&stack.PacketLoop != 0 {
panic("not implemented")
@@ -161,18 +500,38 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe
for pb := pkts.Front(); pb != nil; pb = pb.Next() {
e.addIPHeader(r, pb, params)
+ if e.packetMustBeFragmented(pb, gso) {
+ current := pb
+ _, _, err := e.handleFragments(r, gso, e.nic.MTU(), pb, params.Protocol, func(fragPkt *stack.PacketBuffer) *tcpip.Error {
+ // Modify the packet list in place with the new fragments.
+ pkts.InsertAfter(current, fragPkt)
+ current = current.Next()
+ return nil
+ })
+ if err != nil {
+ r.Stats().IP.OutgoingPacketErrors.IncrementBy(uint64(pkts.Len()))
+ return 0, err
+ }
+ // The fragmented packet can be released. The rest of the packets can be
+ // processed.
+ pkts.Remove(pb)
+ pb = current
+ }
}
// iptables filtering. All packets that reach here are locally
// generated.
- nicName := e.stack.FindNICNameFromID(e.NICID())
- ipt := e.stack.IPTables()
+ nicName := e.protocol.stack.FindNICNameFromID(e.nic.ID())
+ ipt := e.protocol.stack.IPTables()
dropped, natPkts := ipt.CheckPackets(stack.Output, pkts, gso, r, nicName)
if len(dropped) == 0 && len(natPkts) == 0 {
// Fast path: If no packets are to be dropped then we can just invoke the
// faster WritePackets API directly.
- n, err := e.linkEP.WritePackets(r, gso, pkts, ProtocolNumber)
+ n, err := e.nic.WritePackets(r, gso, pkts, ProtocolNumber)
r.Stats().IP.PacketsSent.IncrementBy(uint64(n))
+ if err != nil {
+ r.Stats().IP.OutgoingPacketErrors.IncrementBy(uint64(pkts.Len() - n))
+ }
return n, err
}
r.Stats().IP.IPTablesOutputDropped.IncrementBy(uint64(len(dropped)))
@@ -186,7 +545,7 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe
}
if _, ok := natPkts[pkt]; ok {
netHeader := header.IPv6(pkt.NetworkHeader().View())
- if ep, err := e.stack.FindNetworkEndpoint(header.IPv6ProtocolNumber, netHeader.DestinationAddress()); err == nil {
+ if ep, err := e.protocol.stack.FindNetworkEndpoint(ProtocolNumber, netHeader.DestinationAddress()); err == nil {
src := netHeader.SourceAddress()
dst := netHeader.DestinationAddress()
route := r.ReverseRoute(src, dst)
@@ -195,8 +554,9 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe
continue
}
}
- if err := e.linkEP.WritePacket(r, gso, ProtocolNumber, pkt); err != nil {
+ if err := e.nic.WritePacket(r, gso, ProtocolNumber, pkt); err != nil {
r.Stats().IP.PacketsSent.IncrementBy(uint64(n))
+ r.Stats().IP.OutgoingPacketErrors.IncrementBy(uint64(pkts.Len() - n + len(dropped)))
// Dropped packets aren't errors, so include them in
// the return value.
return n + len(dropped), err
@@ -219,12 +579,24 @@ func (*endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack.PacketBuff
// HandlePacket is called by the link layer when new ipv6 packets arrive for
// this endpoint.
func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) {
+ if !e.isEnabled() {
+ return
+ }
+
h := header.IPv6(pkt.NetworkHeader().View())
if !h.IsValid(pkt.Data.Size() + pkt.NetworkHeader().View().Size() + pkt.TransportHeader().View().Size()) {
r.Stats().IP.MalformedPacketsReceived.Increment()
return
}
+ // As per RFC 4291 section 2.7:
+ // Multicast addresses must not be used as source addresses in IPv6
+ // packets or appear in any Routing header.
+ if header.IsV6MulticastAddress(r.RemoteAddress) {
+ r.Stats().IP.InvalidSourceAddressesReceived.Increment()
+ return
+ }
+
// vv consists of:
// - Any IPv6 header bytes after the first 40 (i.e. extensions).
// - The transport header, if present.
@@ -236,15 +608,18 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) {
hasFragmentHeader := false
// iptables filtering. All packets that reach here are intended for
- // this machine and will not be forwarded.
- ipt := e.stack.IPTables()
+ // this machine and need not be forwarded.
+ ipt := e.protocol.stack.IPTables()
if ok := ipt.Check(stack.Input, pkt, nil, nil, "", ""); !ok {
// iptables is telling us to drop the packet.
r.Stats().IP.IPTablesInputDropped.Increment()
return
}
- for firstHeader := true; ; firstHeader = false {
+ for {
+ // Keep track of the start of the previous header so we can report the
+ // special case of a Hop by Hop at a location other than at the start.
+ previousHeaderStart := it.HeaderOffset()
extHdr, done, err := it.Next()
if err != nil {
r.Stats().IP.MalformedPacketsReceived.Increment()
@@ -258,11 +633,11 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) {
case header.IPv6HopByHopOptionsExtHdr:
// As per RFC 8200 section 4.1, the Hop By Hop extension header is
// restricted to appear immediately after an IPv6 fixed header.
- //
- // TODO(b/152019344): Send an ICMPv6 Parameter Problem, Code 1
- // (unrecognized next header) error in response to an extension header's
- // Next Header field with the Hop By Hop extension header identifier.
- if !firstHeader {
+ if previousHeaderStart != 0 {
+ _ = e.protocol.returnError(r, &icmpReasonParameterProblem{
+ code: header.ICMPv6UnknownHeader,
+ pointer: previousHeaderStart,
+ }, pkt)
return
}
@@ -284,13 +659,25 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) {
case header.IPv6OptionUnknownActionSkip:
case header.IPv6OptionUnknownActionDiscard:
return
- case header.IPv6OptionUnknownActionDiscardSendICMP:
- // TODO(b/152019344): Send an ICMPv6 Parameter Problem Code 2 for
- // unrecognized IPv6 extension header options.
- return
case header.IPv6OptionUnknownActionDiscardSendICMPNoMulticastDest:
- // TODO(b/152019344): Send an ICMPv6 Parameter Problem Code 2 for
- // unrecognized IPv6 extension header options.
+ if header.IsV6MulticastAddress(r.LocalAddress) {
+ return
+ }
+ fallthrough
+ case header.IPv6OptionUnknownActionDiscardSendICMP:
+ // This case satisfies a requirement of RFC 8200 section 4.2
+ // which states that an unknown option starting with bits [10] should:
+ //
+ // discard the packet and, regardless of whether or not the
+ // packet's Destination Address was a multicast address, send an
+ // ICMP Parameter Problem, Code 2, message to the packet's
+ // Source Address, pointing to the unrecognized Option Type.
+ //
+ _ = e.protocol.returnError(r, &icmpReasonParameterProblem{
+ code: header.ICMPv6UnknownOption,
+ pointer: it.ParseOffset() + optsIt.OptionOffset(),
+ respondToMulticast: true,
+ }, pkt)
return
default:
panic(fmt.Sprintf("unrecognized action for an unrecognized Hop By Hop extension header option = %d", opt))
@@ -301,16 +688,20 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) {
// As per RFC 8200 section 4.4, if a node encounters a routing header with
// an unrecognized routing type value, with a non-zero Segments Left
// value, the node must discard the packet and send an ICMP Parameter
- // Problem, Code 0. If the Segments Left is 0, the node must ignore the
- // Routing extension header and process the next header in the packet.
+ // Problem, Code 0 to the packet's Source Address, pointing to the
+ // unrecognized Routing Type.
+ //
+ // If the Segments Left is 0, the node must ignore the Routing extension
+ // header and process the next header in the packet.
//
// Note, the stack does not yet handle any type of routing extension
// header, so we just make sure Segments Left is zero before processing
// the next extension header.
- //
- // TODO(b/152019344): Send an ICMPv6 Parameter Problem Code 0 for
- // unrecognized routing types with a non-zero Segments Left value.
if extHdr.SegmentsLeft() != 0 {
+ _ = e.protocol.returnError(r, &icmpReasonParameterProblem{
+ code: header.ICMPv6ErroneousHeader,
+ pointer: it.ParseOffset(),
+ }, pkt)
return
}
@@ -445,13 +836,25 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) {
case header.IPv6OptionUnknownActionSkip:
case header.IPv6OptionUnknownActionDiscard:
return
- case header.IPv6OptionUnknownActionDiscardSendICMP:
- // TODO(b/152019344): Send an ICMPv6 Parameter Problem Code 2 for
- // unrecognized IPv6 extension header options.
- return
case header.IPv6OptionUnknownActionDiscardSendICMPNoMulticastDest:
- // TODO(b/152019344): Send an ICMPv6 Parameter Problem Code 2 for
- // unrecognized IPv6 extension header options.
+ if header.IsV6MulticastAddress(r.LocalAddress) {
+ return
+ }
+ fallthrough
+ case header.IPv6OptionUnknownActionDiscardSendICMP:
+ // This case satisfies a requirement of RFC 8200 section 4.2
+ // which states that an unknown option starting with bits [10] should:
+ //
+ // discard the packet and, regardless of whether or not the
+ // packet's Destination Address was a multicast address, send an
+ // ICMP Parameter Problem, Code 2, message to the packet's
+ // Source Address, pointing to the unrecognized Option Type.
+ //
+ _ = e.protocol.returnError(r, &icmpReasonParameterProblem{
+ code: header.ICMPv6UnknownOption,
+ pointer: it.ParseOffset() + optsIt.OptionOffset(),
+ respondToMulticast: true,
+ }, pkt)
return
default:
panic(fmt.Sprintf("unrecognized action for an unrecognized Destination extension header option = %d", opt))
@@ -470,13 +873,12 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) {
extHdr.Buf.TrimFront(pkt.TransportHeader().View().Size())
pkt.Data = extHdr.Buf
+ r.Stats().IP.PacketsDelivered.Increment()
if p := tcpip.TransportProtocolNumber(extHdr.Identifier); p == header.ICMPv6ProtocolNumber {
pkt.TransportProtocolNumber = p
e.handleICMP(r, pkt, hasFragmentHeader)
} else {
r.Stats().IP.PacketsDelivered.Increment()
- // TODO(b/152019344): Send an ICMPv6 Parameter Problem, Code 1 error
- // in response to unrecognized next header values.
switch res := e.dispatcher.DeliverTransportPacket(r, p, pkt); res {
case stack.TransportPacketHandled:
case stack.TransportPacketDestinationPortUnreachable:
@@ -485,18 +887,41 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) {
// message with Code 4 in response to a packet for which the
// transport protocol (e.g., UDP) has no listener, if that transport
// protocol has no alternative means to inform the sender.
- _ = returnError(r, &icmpReasonPortUnreachable{}, pkt)
+ _ = e.protocol.returnError(r, &icmpReasonPortUnreachable{}, pkt)
+ case stack.TransportPacketProtocolUnreachable:
+ // As per RFC 8200 section 4. (page 7):
+ // Extension headers are numbered from IANA IP Protocol Numbers
+ // [IANA-PN], the same values used for IPv4 and IPv6. When
+ // processing a sequence of Next Header values in a packet, the
+ // first one that is not an extension header [IANA-EH] indicates
+ // that the next item in the packet is the corresponding upper-layer
+ // header.
+ // With more related information on page 8:
+ // If, as a result of processing a header, the destination node is
+ // required to proceed to the next header but the Next Header value
+ // in the current header is unrecognized by the node, it should
+ // discard the packet and send an ICMP Parameter Problem message to
+ // the source of the packet, with an ICMP Code value of 1
+ // ("unrecognized Next Header type encountered") and the ICMP
+ // Pointer field containing the offset of the unrecognized value
+ // within the original packet.
+ //
+ // Which when taken together indicate that an unknown protocol should
+ // be treated as an unrecognized next header value.
+ _ = e.protocol.returnError(r, &icmpReasonParameterProblem{
+ code: header.ICMPv6UnknownHeader,
+ pointer: it.ParseOffset(),
+ }, pkt)
default:
panic(fmt.Sprintf("unrecognized result from DeliverTransportPacket = %d", res))
}
}
default:
- // If we receive a packet for an extension header we do not yet handle,
- // drop the packet for now.
- //
- // TODO(b/152019344): Send an ICMPv6 Parameter Problem, Code 1 error
- // in response to unrecognized next header values.
+ _ = e.protocol.returnError(r, &icmpReasonParameterProblem{
+ code: header.ICMPv6UnknownHeader,
+ pointer: it.ParseOffset(),
+ }, pkt)
r.Stats().UnknownProtocolRcvdPackets.Increment()
return
}
@@ -504,19 +929,343 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) {
}
// Close cleans up resources associated with the endpoint.
-func (*endpoint) Close() {}
+func (e *endpoint) Close() {
+ e.mu.Lock()
+ e.disableLocked()
+ e.mu.ndp.removeSLAACAddresses(false /* keepLinkLocal */)
+ e.stopDADForPermanentAddressesLocked()
+ e.mu.addressableEndpointState.Cleanup()
+ e.mu.Unlock()
+
+ e.protocol.forgetEndpoint(e)
+}
// NetworkProtocolNumber implements stack.NetworkEndpoint.NetworkProtocolNumber.
func (e *endpoint) NetworkProtocolNumber() tcpip.NetworkProtocolNumber {
return e.protocol.Number()
}
+// AddAndAcquirePermanentAddress implements stack.AddressableEndpoint.
+func (e *endpoint) AddAndAcquirePermanentAddress(addr tcpip.AddressWithPrefix, peb stack.PrimaryEndpointBehavior, configType stack.AddressConfigType, deprecated bool) (stack.AddressEndpoint, *tcpip.Error) {
+ // TODO(b/169350103): add checks here after making sure we no longer receive
+ // an empty address.
+ e.mu.Lock()
+ defer e.mu.Unlock()
+ return e.addAndAcquirePermanentAddressLocked(addr, peb, configType, deprecated)
+}
+
+// addAndAcquirePermanentAddressLocked is like AddAndAcquirePermanentAddress but
+// with locking requirements.
+//
+// addAndAcquirePermanentAddressLocked also joins the passed address's
+// solicited-node multicast group and start duplicate address detection.
+//
+// Precondition: e.mu must be write locked.
+func (e *endpoint) addAndAcquirePermanentAddressLocked(addr tcpip.AddressWithPrefix, peb stack.PrimaryEndpointBehavior, configType stack.AddressConfigType, deprecated bool) (stack.AddressEndpoint, *tcpip.Error) {
+ addressEndpoint, err := e.mu.addressableEndpointState.AddAndAcquirePermanentAddress(addr, peb, configType, deprecated)
+ if err != nil {
+ return nil, err
+ }
+
+ if !header.IsV6UnicastAddress(addr.Address) {
+ return addressEndpoint, nil
+ }
+
+ snmc := header.SolicitedNodeAddr(addr.Address)
+ if _, err := e.mu.addressableEndpointState.JoinGroup(snmc); err != nil {
+ return nil, err
+ }
+
+ addressEndpoint.SetKind(stack.PermanentTentative)
+
+ if e.Enabled() {
+ if err := e.mu.ndp.startDuplicateAddressDetection(addr.Address, addressEndpoint); err != nil {
+ return nil, err
+ }
+ }
+
+ return addressEndpoint, nil
+}
+
+// RemovePermanentAddress implements stack.AddressableEndpoint.
+func (e *endpoint) RemovePermanentAddress(addr tcpip.Address) *tcpip.Error {
+ e.mu.Lock()
+ defer e.mu.Unlock()
+
+ addressEndpoint := e.getAddressRLocked(addr)
+ if addressEndpoint == nil || !addressEndpoint.GetKind().IsPermanent() {
+ return tcpip.ErrBadLocalAddress
+ }
+
+ return e.removePermanentEndpointLocked(addressEndpoint, true)
+}
+
+// removePermanentEndpointLocked is like removePermanentAddressLocked except
+// it works with a stack.AddressEndpoint.
+//
+// Precondition: e.mu must be write locked.
+func (e *endpoint) removePermanentEndpointLocked(addressEndpoint stack.AddressEndpoint, allowSLAACInvalidation bool) *tcpip.Error {
+ addr := addressEndpoint.AddressWithPrefix()
+ unicast := header.IsV6UnicastAddress(addr.Address)
+ if unicast {
+ e.mu.ndp.stopDuplicateAddressDetection(addr.Address)
+
+ // If we are removing an address generated via SLAAC, cleanup
+ // its SLAAC resources and notify the integrator.
+ switch addressEndpoint.ConfigType() {
+ case stack.AddressConfigSlaac:
+ e.mu.ndp.cleanupSLAACAddrResourcesAndNotify(addr, allowSLAACInvalidation)
+ case stack.AddressConfigSlaacTemp:
+ e.mu.ndp.cleanupTempSLAACAddrResourcesAndNotify(addr, allowSLAACInvalidation)
+ }
+ }
+
+ if err := e.mu.addressableEndpointState.RemovePermanentEndpoint(addressEndpoint); err != nil {
+ return err
+ }
+
+ if !unicast {
+ return nil
+ }
+
+ snmc := header.SolicitedNodeAddr(addr.Address)
+ if _, err := e.mu.addressableEndpointState.LeaveGroup(snmc); err != nil && err != tcpip.ErrBadLocalAddress {
+ return err
+ }
+
+ return nil
+}
+
+// hasPermanentAddressLocked returns true if the endpoint has a permanent
+// address equal to the passed address.
+//
+// Precondition: e.mu must be read or write locked.
+func (e *endpoint) hasPermanentAddressRLocked(addr tcpip.Address) bool {
+ addressEndpoint := e.getAddressRLocked(addr)
+ if addressEndpoint == nil {
+ return false
+ }
+ return addressEndpoint.GetKind().IsPermanent()
+}
+
+// getAddressRLocked returns the endpoint for the passed address.
+//
+// Precondition: e.mu must be read or write locked.
+func (e *endpoint) getAddressRLocked(localAddr tcpip.Address) stack.AddressEndpoint {
+ return e.mu.addressableEndpointState.ReadOnly().Lookup(localAddr)
+}
+
+// MainAddress implements stack.AddressableEndpoint.
+func (e *endpoint) MainAddress() tcpip.AddressWithPrefix {
+ e.mu.RLock()
+ defer e.mu.RUnlock()
+ return e.mu.addressableEndpointState.MainAddress()
+}
+
+// AcquireAssignedAddress implements stack.AddressableEndpoint.
+func (e *endpoint) AcquireAssignedAddress(localAddr tcpip.Address, allowTemp bool, tempPEB stack.PrimaryEndpointBehavior) stack.AddressEndpoint {
+ e.mu.Lock()
+ defer e.mu.Unlock()
+ return e.acquireAddressOrCreateTempLocked(localAddr, allowTemp, tempPEB)
+}
+
+// acquireAddressOrCreateTempLocked is like AcquireAssignedAddress but with
+// locking requirements.
+//
+// Precondition: e.mu must be write locked.
+func (e *endpoint) acquireAddressOrCreateTempLocked(localAddr tcpip.Address, allowTemp bool, tempPEB stack.PrimaryEndpointBehavior) stack.AddressEndpoint {
+ return e.mu.addressableEndpointState.AcquireAssignedAddress(localAddr, allowTemp, tempPEB)
+}
+
+// AcquireOutgoingPrimaryAddress implements stack.AddressableEndpoint.
+func (e *endpoint) AcquireOutgoingPrimaryAddress(remoteAddr tcpip.Address, allowExpired bool) stack.AddressEndpoint {
+ e.mu.RLock()
+ defer e.mu.RUnlock()
+ return e.acquireOutgoingPrimaryAddressRLocked(remoteAddr, allowExpired)
+}
+
+// acquireOutgoingPrimaryAddressRLocked is like AcquireOutgoingPrimaryAddress
+// but with locking requirements.
+//
+// Precondition: e.mu must be read locked.
+func (e *endpoint) acquireOutgoingPrimaryAddressRLocked(remoteAddr tcpip.Address, allowExpired bool) stack.AddressEndpoint {
+ // addrCandidate is a candidate for Source Address Selection, as per
+ // RFC 6724 section 5.
+ type addrCandidate struct {
+ addressEndpoint stack.AddressEndpoint
+ scope header.IPv6AddressScope
+ }
+
+ if len(remoteAddr) == 0 {
+ return e.mu.addressableEndpointState.AcquireOutgoingPrimaryAddress(remoteAddr, allowExpired)
+ }
+
+ // Create a candidate set of available addresses we can potentially use as a
+ // source address.
+ var cs []addrCandidate
+ e.mu.addressableEndpointState.ReadOnly().ForEachPrimaryEndpoint(func(addressEndpoint stack.AddressEndpoint) {
+ // If r is not valid for outgoing connections, it is not a valid endpoint.
+ if !addressEndpoint.IsAssigned(allowExpired) {
+ return
+ }
+
+ addr := addressEndpoint.AddressWithPrefix().Address
+ scope, err := header.ScopeForIPv6Address(addr)
+ if err != nil {
+ // Should never happen as we got r from the primary IPv6 endpoint list and
+ // ScopeForIPv6Address only returns an error if addr is not an IPv6
+ // address.
+ panic(fmt.Sprintf("header.ScopeForIPv6Address(%s): %s", addr, err))
+ }
+
+ cs = append(cs, addrCandidate{
+ addressEndpoint: addressEndpoint,
+ scope: scope,
+ })
+ })
+
+ remoteScope, err := header.ScopeForIPv6Address(remoteAddr)
+ if err != nil {
+ // primaryIPv6Endpoint should never be called with an invalid IPv6 address.
+ panic(fmt.Sprintf("header.ScopeForIPv6Address(%s): %s", remoteAddr, err))
+ }
+
+ // Sort the addresses as per RFC 6724 section 5 rules 1-3.
+ //
+ // TODO(b/146021396): Implement rules 4-8 of RFC 6724 section 5.
+ sort.Slice(cs, func(i, j int) bool {
+ sa := cs[i]
+ sb := cs[j]
+
+ // Prefer same address as per RFC 6724 section 5 rule 1.
+ if sa.addressEndpoint.AddressWithPrefix().Address == remoteAddr {
+ return true
+ }
+ if sb.addressEndpoint.AddressWithPrefix().Address == remoteAddr {
+ return false
+ }
+
+ // Prefer appropriate scope as per RFC 6724 section 5 rule 2.
+ if sa.scope < sb.scope {
+ return sa.scope >= remoteScope
+ } else if sb.scope < sa.scope {
+ return sb.scope < remoteScope
+ }
+
+ // Avoid deprecated addresses as per RFC 6724 section 5 rule 3.
+ if saDep, sbDep := sa.addressEndpoint.Deprecated(), sb.addressEndpoint.Deprecated(); saDep != sbDep {
+ // If sa is not deprecated, it is preferred over sb.
+ return sbDep
+ }
+
+ // Prefer temporary addresses as per RFC 6724 section 5 rule 7.
+ if saTemp, sbTemp := sa.addressEndpoint.ConfigType() == stack.AddressConfigSlaacTemp, sb.addressEndpoint.ConfigType() == stack.AddressConfigSlaacTemp; saTemp != sbTemp {
+ return saTemp
+ }
+
+ // sa and sb are equal, return the endpoint that is closest to the front of
+ // the primary endpoint list.
+ return i < j
+ })
+
+ // Return the most preferred address that can have its reference count
+ // incremented.
+ for _, c := range cs {
+ if c.addressEndpoint.IncRef() {
+ return c.addressEndpoint
+ }
+ }
+
+ return nil
+}
+
+// PrimaryAddresses implements stack.AddressableEndpoint.
+func (e *endpoint) PrimaryAddresses() []tcpip.AddressWithPrefix {
+ e.mu.RLock()
+ defer e.mu.RUnlock()
+ return e.mu.addressableEndpointState.PrimaryAddresses()
+}
+
+// PermanentAddresses implements stack.AddressableEndpoint.
+func (e *endpoint) PermanentAddresses() []tcpip.AddressWithPrefix {
+ e.mu.RLock()
+ defer e.mu.RUnlock()
+ return e.mu.addressableEndpointState.PermanentAddresses()
+}
+
+// JoinGroup implements stack.GroupAddressableEndpoint.
+func (e *endpoint) JoinGroup(addr tcpip.Address) (bool, *tcpip.Error) {
+ if !header.IsV6MulticastAddress(addr) {
+ return false, tcpip.ErrBadAddress
+ }
+
+ e.mu.Lock()
+ defer e.mu.Unlock()
+ return e.mu.addressableEndpointState.JoinGroup(addr)
+}
+
+// LeaveGroup implements stack.GroupAddressableEndpoint.
+func (e *endpoint) LeaveGroup(addr tcpip.Address) (bool, *tcpip.Error) {
+ e.mu.Lock()
+ defer e.mu.Unlock()
+ return e.mu.addressableEndpointState.LeaveGroup(addr)
+}
+
+// IsInGroup implements stack.GroupAddressableEndpoint.
+func (e *endpoint) IsInGroup(addr tcpip.Address) bool {
+ e.mu.RLock()
+ defer e.mu.RUnlock()
+ return e.mu.addressableEndpointState.IsInGroup(addr)
+}
+
+var _ stack.ForwardingNetworkProtocol = (*protocol)(nil)
+var _ stack.NetworkProtocol = (*protocol)(nil)
+
type protocol struct {
+ stack *stack.Stack
+
+ mu struct {
+ sync.RWMutex
+
+ eps map[*endpoint]struct{}
+ }
+
+ ids []uint32
+ hashIV uint32
+
// defaultTTL is the current default TTL for the protocol. Only the
- // uint8 portion of it is meaningful and it must be accessed
- // atomically.
- defaultTTL uint32
+ // uint8 portion of it is meaningful.
+ //
+ // Must be accessed using atomic operations.
+ defaultTTL uint32
+
+ // forwarding is set to 1 when the protocol has forwarding enabled and 0
+ // when it is disabled.
+ //
+ // Must be accessed using atomic operations.
+ forwarding uint32
+
fragmentation *fragmentation.Fragmentation
+
+ // ndpDisp is the NDP event dispatcher that is used to send the netstack
+ // integrator NDP related events.
+ ndpDisp NDPDispatcher
+
+ // ndpConfigs is the default NDP configurations used by an IPv6 endpoint.
+ ndpConfigs NDPConfigurations
+
+ // opaqueIIDOpts hold the options for generating opaque interface identifiers
+ // (IIDs) as outlined by RFC 7217.
+ opaqueIIDOpts OpaqueInterfaceIdentifierOptions
+
+ // tempIIDSeed is used to seed the initial temporary interface identifier
+ // history value used to generate IIDs for temporary SLAAC addresses.
+ tempIIDSeed []byte
+
+ // autoGenIPv6LinkLocal determines whether or not the stack attempts to
+ // auto-generate an IPv6 link-local address for newly enabled non-loopback
+ // NICs. See the AutoGenIPv6LinkLocal field of Options for more details.
+ autoGenIPv6LinkLocal bool
}
// Number returns the ipv6 protocol number.
@@ -541,16 +1290,35 @@ func (*protocol) ParseAddresses(v buffer.View) (src, dst tcpip.Address) {
}
// NewEndpoint creates a new ipv6 endpoint.
-func (p *protocol) NewEndpoint(nicID tcpip.NICID, linkAddrCache stack.LinkAddressCache, nud stack.NUDHandler, dispatcher stack.TransportDispatcher, linkEP stack.LinkEndpoint, st *stack.Stack) stack.NetworkEndpoint {
- return &endpoint{
- nicID: nicID,
- linkEP: linkEP,
+func (p *protocol) NewEndpoint(nic stack.NetworkInterface, linkAddrCache stack.LinkAddressCache, nud stack.NUDHandler, dispatcher stack.TransportDispatcher) stack.NetworkEndpoint {
+ e := &endpoint{
+ nic: nic,
linkAddrCache: linkAddrCache,
nud: nud,
dispatcher: dispatcher,
protocol: p,
- stack: st,
}
+ e.mu.addressableEndpointState.Init(e)
+ e.mu.ndp = ndpState{
+ ep: e,
+ configs: p.ndpConfigs,
+ dad: make(map[tcpip.Address]dadState),
+ defaultRouters: make(map[tcpip.Address]defaultRouterState),
+ onLinkPrefixes: make(map[tcpip.Subnet]onLinkPrefixState),
+ slaacPrefixes: make(map[tcpip.Subnet]slaacPrefixState),
+ }
+ e.mu.ndp.initializeTempAddrState()
+
+ p.mu.Lock()
+ defer p.mu.Unlock()
+ p.mu.eps[e] = struct{}{}
+ return e
+}
+
+func (p *protocol) forgetEndpoint(e *endpoint) {
+ p.mu.Lock()
+ defer p.mu.Unlock()
+ delete(p.mu.eps, e)
}
// SetOption implements NetworkProtocol.SetOption.
@@ -601,6 +1369,35 @@ func (*protocol) Parse(pkt *stack.PacketBuffer) (proto tcpip.TransportProtocolNu
return proto, !fragMore && fragOffset == 0, true
}
+// Forwarding implements stack.ForwardingNetworkProtocol.
+func (p *protocol) Forwarding() bool {
+ return uint8(atomic.LoadUint32(&p.forwarding)) == 1
+}
+
+// setForwarding sets the forwarding status for the protocol.
+//
+// Returns true if the forwarding status was updated.
+func (p *protocol) setForwarding(v bool) bool {
+ if v {
+ return atomic.SwapUint32(&p.forwarding, 1) == 0
+ }
+ return atomic.SwapUint32(&p.forwarding, 0) == 1
+}
+
+// SetForwarding implements stack.ForwardingNetworkProtocol.
+func (p *protocol) SetForwarding(v bool) {
+ p.mu.Lock()
+ defer p.mu.Unlock()
+
+ if !p.setForwarding(v) {
+ return
+ }
+
+ for ep := range p.mu.eps {
+ ep.transitionForwarding(v)
+ }
+}
+
// calculateMTU calculates the network-layer payload MTU based on the link-layer
// payload mtu.
func calculateMTU(mtu uint32) uint32 {
@@ -611,10 +1408,144 @@ func calculateMTU(mtu uint32) uint32 {
return maxPayloadSize
}
-// NewProtocol returns an IPv6 network protocol.
-func NewProtocol(*stack.Stack) stack.NetworkProtocol {
- return &protocol{
- defaultTTL: DefaultTTL,
- fragmentation: fragmentation.NewFragmentation(header.IPv6FragmentExtHdrFragmentOffsetBytesPerUnit, fragmentation.HighFragThreshold, fragmentation.LowFragThreshold, fragmentation.DefaultReassembleTimeout),
+// Options holds options to configure a new protocol.
+type Options struct {
+ // NDPConfigs is the default NDP configurations used by interfaces.
+ NDPConfigs NDPConfigurations
+
+ // AutoGenIPv6LinkLocal determines whether or not the stack attempts to
+ // auto-generate an IPv6 link-local address for newly enabled non-loopback
+ // NICs.
+ //
+ // Note, setting this to true does not mean that a link-local address is
+ // assigned right away, or at all. If Duplicate Address Detection is enabled,
+ // an address is only assigned if it successfully resolves. If it fails, no
+ // further attempts are made to auto-generate an IPv6 link-local adddress.
+ //
+ // The generated link-local address follows RFC 4291 Appendix A guidelines.
+ AutoGenIPv6LinkLocal bool
+
+ // NDPDisp is the NDP event dispatcher that an integrator can provide to
+ // receive NDP related events.
+ NDPDisp NDPDispatcher
+
+ // OpaqueIIDOpts hold the options for generating opaque interface
+ // identifiers (IIDs) as outlined by RFC 7217.
+ OpaqueIIDOpts OpaqueInterfaceIdentifierOptions
+
+ // TempIIDSeed is used to seed the initial temporary interface identifier
+ // history value used to generate IIDs for temporary SLAAC addresses.
+ //
+ // Temporary SLAAC adresses are short-lived addresses which are unpredictable
+ // and random from the perspective of other nodes on the network. It is
+ // recommended that the seed be a random byte buffer of at least
+ // header.IIDSize bytes to make sure that temporary SLAAC addresses are
+ // sufficiently random. It should follow minimum randomness requirements for
+ // security as outlined by RFC 4086.
+ //
+ // Note: using a nil value, the same seed across netstack program runs, or a
+ // seed that is too small would reduce randomness and increase predictability,
+ // defeating the purpose of temporary SLAAC addresses.
+ TempIIDSeed []byte
+}
+
+// NewProtocolWithOptions returns an IPv6 network protocol.
+func NewProtocolWithOptions(opts Options) stack.NetworkProtocolFactory {
+ opts.NDPConfigs.validate()
+
+ ids := hash.RandN32(buckets)
+ hashIV := hash.RandN32(1)[0]
+
+ return func(s *stack.Stack) stack.NetworkProtocol {
+ p := &protocol{
+ stack: s,
+ fragmentation: fragmentation.NewFragmentation(header.IPv6FragmentExtHdrFragmentOffsetBytesPerUnit, fragmentation.HighFragThreshold, fragmentation.LowFragThreshold, reassembleTimeout, s.Clock()),
+ ids: ids,
+ hashIV: hashIV,
+
+ ndpDisp: opts.NDPDisp,
+ ndpConfigs: opts.NDPConfigs,
+ opaqueIIDOpts: opts.OpaqueIIDOpts,
+ tempIIDSeed: opts.TempIIDSeed,
+ autoGenIPv6LinkLocal: opts.AutoGenIPv6LinkLocal,
+ }
+ p.mu.eps = make(map[*endpoint]struct{})
+ p.SetDefaultTTL(DefaultTTL)
+ return p
}
}
+
+// NewProtocol is equivalent to NewProtocolWithOptions with an empty Options.
+func NewProtocol(s *stack.Stack) stack.NetworkProtocol {
+ return NewProtocolWithOptions(Options{})(s)
+}
+
+// calculateFragmentInnerMTU calculates the maximum number of bytes of
+// fragmentable data a fragment can have, based on the link layer mtu and pkt's
+// network header size.
+func calculateFragmentInnerMTU(mtu uint32, pkt *stack.PacketBuffer) uint32 {
+ // TODO(gvisor.dev/issue/3912): Once the Authentication or ESP Headers are
+ // supported for outbound packets, their length should not affect the fragment
+ // MTU because they should only be transmitted once.
+ mtu -= uint32(pkt.NetworkHeader().View().Size())
+ mtu -= header.IPv6FragmentHeaderSize
+ // Round the MTU down to align to 8 bytes.
+ mtu &^= 7
+ if mtu <= maxPayloadSize {
+ return mtu
+ }
+ return maxPayloadSize
+}
+
+func calculateFragmentReserve(pkt *stack.PacketBuffer) int {
+ return pkt.AvailableHeaderBytes() + pkt.NetworkHeader().View().Size() + header.IPv6FragmentHeaderSize
+}
+
+// hashRoute calculates a hash value for the given route. It uses the source &
+// destination address and 32-bit number to generate the hash.
+func hashRoute(r *stack.Route, hashIV uint32) uint32 {
+ // The FNV-1a was chosen because it is a fast hashing algorithm, and
+ // cryptographic properties are not needed here.
+ h := fnv.New32a()
+ if _, err := h.Write([]byte(r.LocalAddress)); err != nil {
+ panic(fmt.Sprintf("Hash.Write: %s, but Hash' implementation of Write is not expected to ever return an error", err))
+ }
+
+ if _, err := h.Write([]byte(r.RemoteAddress)); err != nil {
+ panic(fmt.Sprintf("Hash.Write: %s, but Hash' implementation of Write is not expected to ever return an error", err))
+ }
+
+ s := make([]byte, 4)
+ binary.LittleEndian.PutUint32(s, hashIV)
+ if _, err := h.Write(s); err != nil {
+ panic(fmt.Sprintf("Hash.Write: %s, but Hash' implementation of Write is not expected ever to return an error", err))
+ }
+
+ return h.Sum32()
+}
+
+func buildNextFragment(pf *fragmentation.PacketFragmenter, originalIPHeaders header.IPv6, transportProto tcpip.TransportProtocolNumber, id uint32) (*stack.PacketBuffer, bool) {
+ fragPkt, offset, copied, more := pf.BuildNextFragment()
+ fragPkt.NetworkProtocolNumber = ProtocolNumber
+
+ originalIPHeadersLength := len(originalIPHeaders)
+ fragmentIPHeadersLength := originalIPHeadersLength + header.IPv6FragmentHeaderSize
+ fragmentIPHeaders := header.IPv6(fragPkt.NetworkHeader().Push(fragmentIPHeadersLength))
+
+ // Copy the IPv6 header and any extension headers already populated.
+ if copied := copy(fragmentIPHeaders, originalIPHeaders); copied != originalIPHeadersLength {
+ panic(fmt.Sprintf("wrong number of bytes copied into fragmentIPHeaders: got %d, want %d", copied, originalIPHeadersLength))
+ }
+ fragmentIPHeaders.SetNextHeader(header.IPv6FragmentHeader)
+ fragmentIPHeaders.SetPayloadLength(uint16(copied + fragmentIPHeadersLength - header.IPv6MinimumSize))
+
+ fragmentHeader := header.IPv6Fragment(fragmentIPHeaders[originalIPHeadersLength:])
+ fragmentHeader.Encode(&header.IPv6FragmentFields{
+ M: more,
+ FragmentOffset: uint16(offset / header.IPv6FragmentExtHdrFragmentOffsetBytesPerUnit),
+ Identification: id,
+ NextHeader: uint8(transportProto),
+ })
+
+ return fragPkt, more
+}
diff --git a/pkg/tcpip/network/ipv6/ipv6_test.go b/pkg/tcpip/network/ipv6/ipv6_test.go
index 8ae146c5e..bee18d1a8 100644
--- a/pkg/tcpip/network/ipv6/ipv6_test.go
+++ b/pkg/tcpip/network/ipv6/ipv6_test.go
@@ -15,17 +15,21 @@
package ipv6
import (
+ "encoding/hex"
+ "fmt"
"math"
"testing"
"github.com/google/go-cmp/cmp"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/checker"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/link/channel"
"gvisor.dev/gvisor/pkg/tcpip/network/testutil"
"gvisor.dev/gvisor/pkg/tcpip/stack"
"gvisor.dev/gvisor/pkg/tcpip/transport/icmp"
+ "gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
"gvisor.dev/gvisor/pkg/tcpip/transport/udp"
"gvisor.dev/gvisor/pkg/waiter"
)
@@ -53,8 +57,8 @@ func testReceiveICMP(t *testing.T, s *stack.Stack, e *channel.Endpoint, src, dst
t.Helper()
// Receive ICMP packet.
- hdr := buffer.NewPrependable(header.IPv6MinimumSize + header.ICMPv6NeighborAdvertSize)
- pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6NeighborAdvertSize))
+ hdr := buffer.NewPrependable(header.IPv6MinimumSize + header.ICMPv6NeighborAdvertMinimumSize)
+ pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6NeighborAdvertMinimumSize))
pkt.SetType(header.ICMPv6NeighborAdvert)
pkt.SetChecksum(header.ICMPv6Checksum(pkt, src, dst, buffer.VectorisedView{}))
payloadLength := hdr.UsedLength()
@@ -136,6 +140,82 @@ func testReceiveUDP(t *testing.T, s *stack.Stack, e *channel.Endpoint, src, dst
}
}
+func compareFragments(packets []*stack.PacketBuffer, sourcePacket *stack.PacketBuffer, mtu uint32, wantFragments []fragmentInfo, proto tcpip.TransportProtocolNumber) error {
+ // sourcePacket does not have its IP Header populated. Let's copy the one
+ // from the first fragment.
+ source := header.IPv6(packets[0].NetworkHeader().View())
+ sourceIPHeadersLen := len(source)
+ vv := buffer.NewVectorisedView(sourcePacket.Size(), sourcePacket.Views())
+ source = append(source, vv.ToView()...)
+
+ var reassembledPayload buffer.VectorisedView
+ for i, fragment := range packets {
+ // Confirm that the packet is valid.
+ allBytes := buffer.NewVectorisedView(fragment.Size(), fragment.Views())
+ fragmentIPHeaders := header.IPv6(allBytes.ToView())
+ if !fragmentIPHeaders.IsValid(len(fragmentIPHeaders)) {
+ return fmt.Errorf("fragment #%d: IP packet is invalid:\n%s", i, hex.Dump(fragmentIPHeaders))
+ }
+
+ fragmentIPHeadersLength := fragment.NetworkHeader().View().Size()
+ if fragmentIPHeadersLength != sourceIPHeadersLen {
+ return fmt.Errorf("fragment #%d: got fragmentIPHeadersLength = %d, want = %d", i, fragmentIPHeadersLength, sourceIPHeadersLen)
+ }
+
+ if got := len(fragmentIPHeaders); got > int(mtu) {
+ return fmt.Errorf("fragment #%d: got len(fragmentIPHeaders) = %d, want <= %d", i, got, mtu)
+ }
+
+ sourceIPHeader := source[:header.IPv6MinimumSize]
+ fragmentIPHeader := fragmentIPHeaders[:header.IPv6MinimumSize]
+
+ if got := fragmentIPHeaders.PayloadLength(); got != wantFragments[i].payloadSize {
+ return fmt.Errorf("fragment #%d: got fragmentIPHeaders.PayloadLength() = %d, want = %d", i, got, wantFragments[i].payloadSize)
+ }
+
+ // We expect the IPv6 Header to be similar across each fragment, besides the
+ // payload length.
+ sourceIPHeader.SetPayloadLength(0)
+ fragmentIPHeader.SetPayloadLength(0)
+ if diff := cmp.Diff(fragmentIPHeader, sourceIPHeader); diff != "" {
+ return fmt.Errorf("fragment #%d: fragmentIPHeader mismatch (-want +got):\n%s", i, diff)
+ }
+
+ if fragment.NetworkProtocolNumber != sourcePacket.NetworkProtocolNumber {
+ return fmt.Errorf("fragment #%d: got fragment.NetworkProtocolNumber = %d, want = %d", i, fragment.NetworkProtocolNumber, sourcePacket.NetworkProtocolNumber)
+ }
+
+ if len(packets) > 1 {
+ // If the source packet was big enough that it needed fragmentation, let's
+ // inspect the fragment header. Because no other extension headers are
+ // supported, it will always be the last extension header.
+ fragmentHeader := header.IPv6Fragment(fragmentIPHeaders[fragmentIPHeadersLength-header.IPv6FragmentHeaderSize : fragmentIPHeadersLength])
+
+ if got := fragmentHeader.More(); got != wantFragments[i].more {
+ return fmt.Errorf("fragment #%d: got fragmentHeader.More() = %t, want = %t", i, got, wantFragments[i].more)
+ }
+ if got := fragmentHeader.FragmentOffset(); got != wantFragments[i].offset {
+ return fmt.Errorf("fragment #%d: got fragmentHeader.FragmentOffset() = %d, want = %d", i, got, wantFragments[i].offset)
+ }
+ if got := fragmentHeader.NextHeader(); got != uint8(proto) {
+ return fmt.Errorf("fragment #%d: got fragmentHeader.NextHeader() = %d, want = %d", i, got, uint8(proto))
+ }
+ }
+
+ // Store the reassembled payload as we parse each fragment. The payload
+ // includes the Transport header and everything after.
+ reassembledPayload.AppendView(fragment.TransportHeader().View())
+ reassembledPayload.Append(fragment.Data)
+ }
+
+ result := reassembledPayload.ToView()
+ if diff := cmp.Diff(result, buffer.View(source[sourceIPHeadersLen:])); diff != "" {
+ return fmt.Errorf("reassembledPayload mismatch (-want +got):\n%s", diff)
+ }
+
+ return nil
+}
+
// TestReceiveOnAllNodesMulticastAddr tests that IPv6 endpoints receive ICMP and
// UDP packets destined to the IPv6 link-local all-nodes multicast address.
func TestReceiveOnAllNodesMulticastAddr(t *testing.T) {
@@ -170,8 +250,6 @@ func TestReceiveOnAllNodesMulticastAddr(t *testing.T) {
// packets destined to the IPv6 solicited-node address of an assigned IPv6
// address.
func TestReceiveOnSolicitedNodeAddr(t *testing.T) {
- const nicID = 1
-
tests := []struct {
name string
protocolFactory stack.TransportProtocolFactory
@@ -195,7 +273,7 @@ func TestReceiveOnSolicitedNodeAddr(t *testing.T) {
}
s.SetRouteTable([]tcpip.Route{
- tcpip.Route{
+ {
Destination: header.IPv6EmptySubnet,
NIC: nicID,
},
@@ -295,17 +373,22 @@ func TestAddIpv6Address(t *testing.T) {
}
func TestReceiveIPv6ExtHdrs(t *testing.T) {
- const nicID = 1
-
tests := []struct {
name string
extHdr func(nextHdr uint8) ([]byte, uint8)
shouldAccept bool
+ // Should we expect an ICMP response and if so, with what contents?
+ expectICMP bool
+ ICMPType header.ICMPv6Type
+ ICMPCode header.ICMPv6Code
+ pointer uint32
+ multicast bool
}{
{
name: "None",
extHdr: func(nextHdr uint8) ([]byte, uint8) { return []byte{}, nextHdr },
shouldAccept: true,
+ expectICMP: false,
},
{
name: "hopbyhop with unknown option skippable action",
@@ -336,9 +419,10 @@ func TestReceiveIPv6ExtHdrs(t *testing.T) {
}, hopByHopExtHdrID
},
shouldAccept: false,
+ expectICMP: false,
},
{
- name: "hopbyhop with unknown option discard and send icmp action",
+ name: "hopbyhop with unknown option discard and send icmp action (unicast)",
extHdr: func(nextHdr uint8) ([]byte, uint8) {
return []byte{
nextHdr, 1,
@@ -348,12 +432,58 @@ func TestReceiveIPv6ExtHdrs(t *testing.T) {
// Discard & send ICMP if option is unknown.
191, 6, 1, 2, 3, 4, 5, 6,
+ //^ Unknown option.
}, hopByHopExtHdrID
},
shouldAccept: false,
+ expectICMP: true,
+ ICMPType: header.ICMPv6ParamProblem,
+ ICMPCode: header.ICMPv6UnknownOption,
+ pointer: header.IPv6FixedHeaderSize + 8,
},
{
- name: "hopbyhop with unknown option discard and send icmp action unless multicast dest",
+ name: "hopbyhop with unknown option discard and send icmp action (multicast)",
+ extHdr: func(nextHdr uint8) ([]byte, uint8) {
+ return []byte{
+ nextHdr, 1,
+
+ // Skippable unknown.
+ 63, 4, 1, 2, 3, 4,
+
+ // Discard & send ICMP if option is unknown.
+ 191, 6, 1, 2, 3, 4, 5, 6,
+ //^ Unknown option.
+ }, hopByHopExtHdrID
+ },
+ multicast: true,
+ shouldAccept: false,
+ expectICMP: true,
+ ICMPType: header.ICMPv6ParamProblem,
+ ICMPCode: header.ICMPv6UnknownOption,
+ pointer: header.IPv6FixedHeaderSize + 8,
+ },
+ {
+ name: "hopbyhop with unknown option discard and send icmp action unless multicast dest (unicast)",
+ extHdr: func(nextHdr uint8) ([]byte, uint8) {
+ return []byte{
+ nextHdr, 1,
+
+ // Skippable unknown.
+ 63, 4, 1, 2, 3, 4,
+
+ // Discard & send ICMP unless packet is for multicast destination if
+ // option is unknown.
+ 255, 6, 1, 2, 3, 4, 5, 6,
+ //^ Unknown option.
+ }, hopByHopExtHdrID
+ },
+ expectICMP: true,
+ ICMPType: header.ICMPv6ParamProblem,
+ ICMPCode: header.ICMPv6UnknownOption,
+ pointer: header.IPv6FixedHeaderSize + 8,
+ },
+ {
+ name: "hopbyhop with unknown option discard and send icmp action unless multicast dest (multicast)",
extHdr: func(nextHdr uint8) ([]byte, uint8) {
return []byte{
nextHdr, 1,
@@ -364,39 +494,77 @@ func TestReceiveIPv6ExtHdrs(t *testing.T) {
// Discard & send ICMP unless packet is for multicast destination if
// option is unknown.
255, 6, 1, 2, 3, 4, 5, 6,
+ //^ Unknown option.
}, hopByHopExtHdrID
},
+ multicast: true,
shouldAccept: false,
+ expectICMP: false,
},
{
- name: "routing with zero segments left",
- extHdr: func(nextHdr uint8) ([]byte, uint8) { return []byte{nextHdr, 0, 1, 0, 2, 3, 4, 5}, routingExtHdrID },
+ name: "routing with zero segments left",
+ extHdr: func(nextHdr uint8) ([]byte, uint8) {
+ return []byte{
+ nextHdr, 0,
+ 1, 0, 2, 3, 4, 5,
+ }, routingExtHdrID
+ },
shouldAccept: true,
},
{
- name: "routing with non-zero segments left",
- extHdr: func(nextHdr uint8) ([]byte, uint8) { return []byte{nextHdr, 0, 1, 1, 2, 3, 4, 5}, routingExtHdrID },
+ name: "routing with non-zero segments left",
+ extHdr: func(nextHdr uint8) ([]byte, uint8) {
+ return []byte{
+ nextHdr, 0,
+ 1, 1, 2, 3, 4, 5,
+ }, routingExtHdrID
+ },
shouldAccept: false,
+ expectICMP: true,
+ ICMPType: header.ICMPv6ParamProblem,
+ ICMPCode: header.ICMPv6ErroneousHeader,
+ pointer: header.IPv6FixedHeaderSize + 2,
},
{
- name: "atomic fragment with zero ID",
- extHdr: func(nextHdr uint8) ([]byte, uint8) { return []byte{nextHdr, 0, 0, 0, 0, 0, 0, 0}, fragmentExtHdrID },
+ name: "atomic fragment with zero ID",
+ extHdr: func(nextHdr uint8) ([]byte, uint8) {
+ return []byte{
+ nextHdr, 0,
+ 0, 0, 0, 0, 0, 0,
+ }, fragmentExtHdrID
+ },
shouldAccept: true,
},
{
- name: "atomic fragment with non-zero ID",
- extHdr: func(nextHdr uint8) ([]byte, uint8) { return []byte{nextHdr, 0, 0, 0, 1, 2, 3, 4}, fragmentExtHdrID },
+ name: "atomic fragment with non-zero ID",
+ extHdr: func(nextHdr uint8) ([]byte, uint8) {
+ return []byte{
+ nextHdr, 0,
+ 0, 0, 1, 2, 3, 4,
+ }, fragmentExtHdrID
+ },
shouldAccept: true,
+ expectICMP: false,
},
{
- name: "fragment",
- extHdr: func(nextHdr uint8) ([]byte, uint8) { return []byte{nextHdr, 0, 1, 0, 1, 2, 3, 4}, fragmentExtHdrID },
+ name: "fragment",
+ extHdr: func(nextHdr uint8) ([]byte, uint8) {
+ return []byte{
+ nextHdr, 0,
+ 1, 0, 1, 2, 3, 4,
+ }, fragmentExtHdrID
+ },
shouldAccept: false,
+ expectICMP: false,
},
{
- name: "No next header",
- extHdr: func(nextHdr uint8) ([]byte, uint8) { return []byte{}, noNextHdrID },
+ name: "No next header",
+ extHdr: func(nextHdr uint8) ([]byte, uint8) {
+ return []byte{},
+ noNextHdrID
+ },
shouldAccept: false,
+ expectICMP: false,
},
{
name: "destination with unknown option skippable action",
@@ -412,6 +580,7 @@ func TestReceiveIPv6ExtHdrs(t *testing.T) {
}, destinationExtHdrID
},
shouldAccept: true,
+ expectICMP: false,
},
{
name: "destination with unknown option discard action",
@@ -427,9 +596,10 @@ func TestReceiveIPv6ExtHdrs(t *testing.T) {
}, destinationExtHdrID
},
shouldAccept: false,
+ expectICMP: false,
},
{
- name: "destination with unknown option discard and send icmp action",
+ name: "destination with unknown option discard and send icmp action (unicast)",
extHdr: func(nextHdr uint8) ([]byte, uint8) {
return []byte{
nextHdr, 1,
@@ -439,12 +609,38 @@ func TestReceiveIPv6ExtHdrs(t *testing.T) {
// Discard & send ICMP if option is unknown.
191, 6, 1, 2, 3, 4, 5, 6,
+ //^ 191 is an unknown option.
}, destinationExtHdrID
},
shouldAccept: false,
+ expectICMP: true,
+ ICMPType: header.ICMPv6ParamProblem,
+ ICMPCode: header.ICMPv6UnknownOption,
+ pointer: header.IPv6FixedHeaderSize + 8,
},
{
- name: "destination with unknown option discard and send icmp action unless multicast dest",
+ name: "destination with unknown option discard and send icmp action (muilticast)",
+ extHdr: func(nextHdr uint8) ([]byte, uint8) {
+ return []byte{
+ nextHdr, 1,
+
+ // Skippable unknown.
+ 63, 4, 1, 2, 3, 4,
+
+ // Discard & send ICMP if option is unknown.
+ 191, 6, 1, 2, 3, 4, 5, 6,
+ //^ 191 is an unknown option.
+ }, destinationExtHdrID
+ },
+ multicast: true,
+ shouldAccept: false,
+ expectICMP: true,
+ ICMPType: header.ICMPv6ParamProblem,
+ ICMPCode: header.ICMPv6UnknownOption,
+ pointer: header.IPv6FixedHeaderSize + 8,
+ },
+ {
+ name: "destination with unknown option discard and send icmp action unless multicast dest (unicast)",
extHdr: func(nextHdr uint8) ([]byte, uint8) {
return []byte{
nextHdr, 1,
@@ -455,22 +651,33 @@ func TestReceiveIPv6ExtHdrs(t *testing.T) {
// Discard & send ICMP unless packet is for multicast destination if
// option is unknown.
255, 6, 1, 2, 3, 4, 5, 6,
+ //^ 255 is unknown.
}, destinationExtHdrID
},
shouldAccept: false,
+ expectICMP: true,
+ ICMPType: header.ICMPv6ParamProblem,
+ ICMPCode: header.ICMPv6UnknownOption,
+ pointer: header.IPv6FixedHeaderSize + 8,
},
{
- name: "routing - atomic fragment",
+ name: "destination with unknown option discard and send icmp action unless multicast dest (multicast)",
extHdr: func(nextHdr uint8) ([]byte, uint8) {
return []byte{
- // Routing extension header.
- fragmentExtHdrID, 0, 1, 0, 2, 3, 4, 5,
+ nextHdr, 1,
- // Fragment extension header.
- nextHdr, 0, 0, 0, 1, 2, 3, 4,
- }, routingExtHdrID
+ // Skippable unknown.
+ 63, 4, 1, 2, 3, 4,
+
+ // Discard & send ICMP unless packet is for multicast destination if
+ // option is unknown.
+ 255, 6, 1, 2, 3, 4, 5, 6,
+ //^ 255 is unknown.
+ }, destinationExtHdrID
},
- shouldAccept: true,
+ shouldAccept: false,
+ expectICMP: false,
+ multicast: true,
},
{
name: "atomic fragment - routing",
@@ -504,12 +711,42 @@ func TestReceiveIPv6ExtHdrs(t *testing.T) {
return []byte{
// Routing extension header.
hopByHopExtHdrID, 0, 1, 0, 2, 3, 4, 5,
+ // ^^^ The HopByHop extension header may not appear after the first
+ // extension header.
// Hop By Hop extension header with skippable unknown option.
nextHdr, 0, 62, 4, 1, 2, 3, 4,
}, routingExtHdrID
},
shouldAccept: false,
+ expectICMP: true,
+ ICMPType: header.ICMPv6ParamProblem,
+ ICMPCode: header.ICMPv6UnknownHeader,
+ pointer: header.IPv6FixedHeaderSize,
+ },
+ {
+ name: "routing - hop by hop (with send icmp unknown)",
+ extHdr: func(nextHdr uint8) ([]byte, uint8) {
+ return []byte{
+ // Routing extension header.
+ hopByHopExtHdrID, 0, 1, 0, 2, 3, 4, 5,
+ // ^^^ The HopByHop extension header may not appear after the first
+ // extension header.
+
+ nextHdr, 1,
+
+ // Skippable unknown.
+ 63, 4, 1, 2, 3, 4,
+
+ // Skippable unknown.
+ 191, 6, 1, 2, 3, 4, 5, 6,
+ }, routingExtHdrID
+ },
+ shouldAccept: false,
+ expectICMP: true,
+ ICMPType: header.ICMPv6ParamProblem,
+ ICMPCode: header.ICMPv6UnknownHeader,
+ pointer: header.IPv6FixedHeaderSize,
},
{
name: "No next header",
@@ -553,6 +790,7 @@ func TestReceiveIPv6ExtHdrs(t *testing.T) {
}, hopByHopExtHdrID
},
shouldAccept: false,
+ expectICMP: false,
},
{
name: "hopbyhop (with skippable unknown) - routing - atomic fragment - destination (with discard unknown)",
@@ -573,6 +811,7 @@ func TestReceiveIPv6ExtHdrs(t *testing.T) {
}, hopByHopExtHdrID
},
shouldAccept: false,
+ expectICMP: false,
},
}
@@ -582,7 +821,7 @@ func TestReceiveIPv6ExtHdrs(t *testing.T) {
NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol},
TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol},
})
- e := channel.New(0, 1280, linkAddr1)
+ e := channel.New(1, 1280, linkAddr1)
if err := s.CreateNIC(nicID, e); err != nil {
t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
}
@@ -590,6 +829,14 @@ func TestReceiveIPv6ExtHdrs(t *testing.T) {
t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ProtocolNumber, addr2, err)
}
+ // Add a default route so that a return packet knows where to go.
+ s.SetRouteTable([]tcpip.Route{
+ {
+ Destination: header.IPv6EmptySubnet,
+ NIC: nicID,
+ },
+ })
+
wq := waiter.Queue{}
we, ch := waiter.NewChannelEntry(nil)
wq.EventRegister(&we, waiter.EventIn)
@@ -631,12 +878,16 @@ func TestReceiveIPv6ExtHdrs(t *testing.T) {
// Serialize IPv6 fixed header.
payloadLength := hdr.UsedLength()
ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
+ dstAddr := tcpip.Address(addr2)
+ if test.multicast {
+ dstAddr = header.IPv6AllNodesMulticastAddress
+ }
ip.Encode(&header.IPv6Fields{
PayloadLength: uint16(payloadLength),
NextHeader: ipv6NextHdr,
HopLimit: 255,
SrcAddr: addr1,
- DstAddr: addr2,
+ DstAddr: dstAddr,
})
e.InjectInbound(ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
@@ -650,6 +901,44 @@ func TestReceiveIPv6ExtHdrs(t *testing.T) {
t.Errorf("got UDP Rx Packets = %d, want = 0", got)
}
+ if !test.expectICMP {
+ if p, ok := e.Read(); ok {
+ t.Fatalf("unexpected packet received: %#v", p)
+ }
+ return
+ }
+
+ // ICMP required.
+ p, ok := e.Read()
+ if !ok {
+ t.Fatalf("expected packet wasn't written out")
+ }
+
+ // Pack the output packet into a single buffer.View as the checkers
+ // assume that.
+ vv := buffer.NewVectorisedView(p.Pkt.Size(), p.Pkt.Views())
+ pkt := vv.ToView()
+ if got, want := len(pkt), header.IPv6FixedHeaderSize+header.ICMPv6MinimumSize+hdr.UsedLength(); got != want {
+ t.Fatalf("got an ICMP packet of size = %d, want = %d", got, want)
+ }
+
+ ipHdr := header.IPv6(pkt)
+ checker.IPv6(t, ipHdr, checker.ICMPv6(
+ checker.ICMPv6Type(test.ICMPType),
+ checker.ICMPv6Code(test.ICMPCode)))
+
+ // We know we are looking at no extension headers in the error ICMP
+ // packets.
+ icmpPkt := header.ICMPv6(ipHdr.Payload())
+ // We know we sent small packets that won't be truncated when reflected
+ // back to us.
+ originalPacket := icmpPkt.Payload()
+ if got, want := icmpPkt.TypeSpecific(), test.pointer; got != want {
+ t.Errorf("unexpected ICMPv6 pointer, got = %d, want = %d\n", got, want)
+ }
+ if diff := cmp.Diff(hdr.View(), buffer.View(originalPacket)); diff != "" {
+ t.Errorf("ICMPv6 payload mismatch (-want +got):\n%s", diff)
+ }
return
}
@@ -683,7 +972,6 @@ type fragmentData struct {
func TestReceiveIPv6Fragments(t *testing.T) {
const (
- nicID = 1
udpPayload1Length = 256
udpPayload2Length = 128
// Used to test cases where the fragment blocks are not a multiple of
@@ -1748,7 +2036,7 @@ func TestWriteStats(t *testing.T) {
t.Fatalf("failed to find filter table")
}
ruleIdx := filter.BuiltinChains[stack.Output]
- filter.Rules[ruleIdx].Target = stack.DropTarget{}
+ filter.Rules[ruleIdx].Target = &stack.DropTarget{}
if err := ipt.ReplaceTable(stack.FilterTable, filter, true /* ipv6 */); err != nil {
t.Fatalf("failed to replace table: %v", err)
}
@@ -1770,10 +2058,10 @@ func TestWriteStats(t *testing.T) {
}
// We'll match and DROP the last packet.
ruleIdx := filter.BuiltinChains[stack.Output]
- filter.Rules[ruleIdx].Target = stack.DropTarget{}
+ filter.Rules[ruleIdx].Target = &stack.DropTarget{}
filter.Rules[ruleIdx].Matchers = []stack.Matcher{&limitedMatcher{nPackets - 1}}
// Make sure the next rule is ACCEPT.
- filter.Rules[ruleIdx+1].Target = stack.AcceptTarget{}
+ filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{}
if err := ipt.ReplaceTable(stack.FilterTable, filter, true /* ipv6 */); err != nil {
t.Fatalf("failed to replace table: %v", err)
}
@@ -1815,7 +2103,6 @@ func TestWriteStats(t *testing.T) {
t.Run(test.name, func(t *testing.T) {
ep := testutil.NewMockLinkEndpoint(header.IPv6MinimumMTU, tcpip.ErrInvalidEndpointState, test.allowPackets)
rt := buildRoute(t, ep)
-
var pkts stack.PacketBufferList
for i := 0; i < nPackets; i++ {
pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
@@ -1857,12 +2144,13 @@ func buildRoute(t *testing.T, ep stack.LinkEndpoint) stack.Route {
dst = "\xfc\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02"
)
if err := s.AddAddress(1, ProtocolNumber, src); err != nil {
- t.Fatalf("AddAddress(1, %d, _) failed: %s", ProtocolNumber, err)
+ t.Fatalf("AddAddress(1, %d, %s) failed: %s", ProtocolNumber, src, err)
}
{
- subnet, err := tcpip.NewSubnet(dst, tcpip.AddressMask("\xfc\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff"))
+ mask := tcpip.AddressMask("\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff")
+ subnet, err := tcpip.NewSubnet(dst, mask)
if err != nil {
- t.Fatalf("NewSubnet(_, _) failed: %v", err)
+ t.Fatalf("NewSubnet(%s, %s) failed: %v", dst, mask, err)
}
s.SetRouteTable([]tcpip.Route{{
Destination: subnet,
@@ -1871,7 +2159,7 @@ func buildRoute(t *testing.T, ep stack.LinkEndpoint) stack.Route {
}
rt, err := s.FindRoute(1, src, dst, ProtocolNumber, false /* multicastLoop */)
if err != nil {
- t.Fatalf("got FindRoute(1, _, _, %d, false) = %s, want = nil", ProtocolNumber, err)
+ t.Fatalf("FindRoute(1, %s, %s, %d, false) = %s, want = nil", src, dst, ProtocolNumber, err)
}
return rt
}
@@ -1895,3 +2183,320 @@ func (lm *limitedMatcher) Match(stack.Hook, *stack.PacketBuffer, string) (bool,
lm.limit--
return false, false
}
+
+func TestClearEndpointFromProtocolOnClose(t *testing.T) {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol},
+ })
+ proto := s.NetworkProtocolInstance(ProtocolNumber).(*protocol)
+ ep := proto.NewEndpoint(&testInterface{}, nil, nil, nil).(*endpoint)
+ {
+ proto.mu.Lock()
+ _, hasEP := proto.mu.eps[ep]
+ proto.mu.Unlock()
+ if !hasEP {
+ t.Fatalf("expected protocol to have ep = %p in set of endpoints", ep)
+ }
+ }
+
+ ep.Close()
+
+ {
+ proto.mu.Lock()
+ _, hasEP := proto.mu.eps[ep]
+ proto.mu.Unlock()
+ if hasEP {
+ t.Fatalf("unexpectedly found ep = %p in set of protocol's endpoints", ep)
+ }
+ }
+}
+
+type fragmentInfo struct {
+ offset uint16
+ more bool
+ payloadSize uint16
+}
+
+type fragmentationTestCase struct {
+ description string
+ mtu uint32
+ gso *stack.GSO
+ transHdrLen int
+ extraHdrLen int
+ payloadSize int
+ wantFragments []fragmentInfo
+ expectedFrags int
+}
+
+var fragmentationTests = []fragmentationTestCase{
+ {
+ description: "No Fragmentation",
+ mtu: 1280,
+ gso: &stack.GSO{},
+ transHdrLen: 0,
+ extraHdrLen: header.IPv6MinimumSize,
+ payloadSize: 1000,
+ wantFragments: []fragmentInfo{
+ {offset: 0, payloadSize: 1000, more: false},
+ },
+ },
+ {
+ description: "Fragmented",
+ mtu: 1280,
+ gso: &stack.GSO{},
+ transHdrLen: 0,
+ extraHdrLen: header.IPv6MinimumSize,
+ payloadSize: 2000,
+ wantFragments: []fragmentInfo{
+ {offset: 0, payloadSize: 1240, more: true},
+ {offset: 154, payloadSize: 776, more: false},
+ },
+ },
+ {
+ description: "No fragmentation with big header",
+ mtu: 2000,
+ gso: &stack.GSO{},
+ transHdrLen: 100,
+ extraHdrLen: header.IPv6MinimumSize,
+ payloadSize: 1000,
+ wantFragments: []fragmentInfo{
+ {offset: 0, payloadSize: 1100, more: false},
+ },
+ },
+ {
+ description: "Fragmented with gso nil",
+ mtu: 1280,
+ gso: nil,
+ transHdrLen: 0,
+ extraHdrLen: header.IPv6MinimumSize,
+ payloadSize: 1400,
+ wantFragments: []fragmentInfo{
+ {offset: 0, payloadSize: 1240, more: true},
+ {offset: 154, payloadSize: 176, more: false},
+ },
+ },
+ {
+ description: "Fragmented with big header",
+ mtu: 1280,
+ gso: &stack.GSO{},
+ transHdrLen: 100,
+ extraHdrLen: header.IPv6MinimumSize,
+ payloadSize: 1200,
+ wantFragments: []fragmentInfo{
+ {offset: 0, payloadSize: 1240, more: true},
+ {offset: 154, payloadSize: 76, more: false},
+ },
+ },
+ {
+ description: "Fragmented with big header and prependable bytes",
+ mtu: 1280,
+ gso: &stack.GSO{},
+ transHdrLen: 20,
+ extraHdrLen: header.IPv6MinimumSize + 66,
+ payloadSize: 1500,
+ wantFragments: []fragmentInfo{
+ {offset: 0, payloadSize: 1240, more: true},
+ {offset: 154, payloadSize: 296, more: false},
+ },
+ },
+}
+
+func TestFragmentation(t *testing.T) {
+ const (
+ ttl = 42
+ tos = stack.DefaultTOS
+ transportProto = tcp.ProtocolNumber
+ )
+
+ for _, ft := range fragmentationTests {
+ t.Run(ft.description, func(t *testing.T) {
+ pkt := testutil.MakeRandPkt(ft.transHdrLen, ft.extraHdrLen, []int{ft.payloadSize}, header.IPv6ProtocolNumber)
+ source := pkt.Clone()
+ ep := testutil.NewMockLinkEndpoint(ft.mtu, nil, math.MaxInt32)
+ r := buildRoute(t, ep)
+ err := r.WritePacket(ft.gso, stack.NetworkHeaderParams{
+ Protocol: tcp.ProtocolNumber,
+ TTL: ttl,
+ TOS: stack.DefaultTOS,
+ }, pkt)
+ if err != nil {
+ t.Fatalf("WritePacket(_, _, _): = %s", err)
+ }
+ if got := len(ep.WrittenPackets); got != len(ft.wantFragments) {
+ t.Errorf("got len(ep.WrittenPackets) = %d, want = %d", got, len(ft.wantFragments))
+ }
+ if got := int(r.Stats().IP.PacketsSent.Value()); got != len(ft.wantFragments) {
+ t.Errorf("got c.Route.Stats().IP.PacketsSent.Value() = %d, want = %d", got, len(ft.wantFragments))
+ }
+ if got := r.Stats().IP.OutgoingPacketErrors.Value(); got != 0 {
+ t.Errorf("got r.Stats().IP.OutgoingPacketErrors.Value() = %d, want = 0", got)
+ }
+ if len(ep.WrittenPackets) > 0 {
+ if err := compareFragments(ep.WrittenPackets, source, ft.mtu, ft.wantFragments, tcp.ProtocolNumber); err != nil {
+ t.Error(err)
+ }
+ }
+ })
+ }
+}
+
+func TestFragmentationWritePackets(t *testing.T) {
+ const ttl = 42
+ tests := []struct {
+ description string
+ insertBefore int
+ insertAfter int
+ }{
+ {
+ description: "Single packet",
+ insertBefore: 0,
+ insertAfter: 0,
+ },
+ {
+ description: "With packet before",
+ insertBefore: 1,
+ insertAfter: 0,
+ },
+ {
+ description: "With packet after",
+ insertBefore: 0,
+ insertAfter: 1,
+ },
+ {
+ description: "With packet before and after",
+ insertBefore: 1,
+ insertAfter: 1,
+ },
+ }
+ tinyPacket := testutil.MakeRandPkt(header.TCPMinimumSize, header.IPv6MinimumSize, []int{1}, header.IPv6ProtocolNumber)
+
+ for _, test := range tests {
+ t.Run(test.description, func(t *testing.T) {
+ for _, ft := range fragmentationTests {
+ t.Run(ft.description, func(t *testing.T) {
+ var pkts stack.PacketBufferList
+ for i := 0; i < test.insertBefore; i++ {
+ pkts.PushBack(tinyPacket.Clone())
+ }
+ pkt := testutil.MakeRandPkt(ft.transHdrLen, ft.extraHdrLen, []int{ft.payloadSize}, header.IPv6ProtocolNumber)
+ source := pkt
+ pkts.PushBack(pkt.Clone())
+ for i := 0; i < test.insertAfter; i++ {
+ pkts.PushBack(tinyPacket.Clone())
+ }
+
+ ep := testutil.NewMockLinkEndpoint(ft.mtu, nil, math.MaxInt32)
+ r := buildRoute(t, ep)
+
+ wantTotalPackets := len(ft.wantFragments) + test.insertBefore + test.insertAfter
+ n, err := r.WritePackets(ft.gso, pkts, stack.NetworkHeaderParams{
+ Protocol: tcp.ProtocolNumber,
+ TTL: ttl,
+ TOS: stack.DefaultTOS,
+ })
+ if n != wantTotalPackets || err != nil {
+ t.Errorf("got WritePackets(_, _, _) = (%d, %s), want = (%d, nil)", n, err, wantTotalPackets)
+ }
+ if got := len(ep.WrittenPackets); got != wantTotalPackets {
+ t.Errorf("got len(ep.WrittenPackets) = %d, want = %d", got, wantTotalPackets)
+ }
+ if got := int(r.Stats().IP.PacketsSent.Value()); got != wantTotalPackets {
+ t.Errorf("got c.Route.Stats().IP.PacketsSent.Value() = %d, want = %d", got, wantTotalPackets)
+ }
+ if got := r.Stats().IP.OutgoingPacketErrors.Value(); got != 0 {
+ t.Errorf("got r.Stats().IP.OutgoingPacketErrors.Value() = %d, want = 0", got)
+ }
+
+ if wantTotalPackets == 0 {
+ return
+ }
+
+ fragments := ep.WrittenPackets[test.insertBefore : len(ft.wantFragments)+test.insertBefore]
+ if err := compareFragments(fragments, source, ft.mtu, ft.wantFragments, tcp.ProtocolNumber); err != nil {
+ t.Error(err)
+ }
+ })
+ }
+ })
+ }
+}
+
+// TestFragmentationErrors checks that errors are returned from WritePacket
+// correctly.
+func TestFragmentationErrors(t *testing.T) {
+ const ttl = 42
+
+ tests := []struct {
+ description string
+ mtu uint32
+ transHdrLen int
+ payloadSize int
+ allowPackets int
+ outgoingErrors int
+ mockError *tcpip.Error
+ wantError *tcpip.Error
+ }{
+ {
+ description: "No frag",
+ mtu: 2000,
+ payloadSize: 1000,
+ transHdrLen: 0,
+ allowPackets: 0,
+ outgoingErrors: 1,
+ mockError: tcpip.ErrAborted,
+ wantError: tcpip.ErrAborted,
+ },
+ {
+ description: "Error on first frag",
+ mtu: 1300,
+ payloadSize: 3000,
+ transHdrLen: 0,
+ allowPackets: 0,
+ outgoingErrors: 3,
+ mockError: tcpip.ErrAborted,
+ wantError: tcpip.ErrAborted,
+ },
+ {
+ description: "Error on second frag",
+ mtu: 1500,
+ payloadSize: 4000,
+ transHdrLen: 0,
+ allowPackets: 1,
+ outgoingErrors: 2,
+ mockError: tcpip.ErrAborted,
+ wantError: tcpip.ErrAborted,
+ },
+ {
+ description: "Error on packet with MTU smaller than transport header",
+ mtu: 1280,
+ transHdrLen: 1500,
+ payloadSize: 500,
+ allowPackets: 0,
+ outgoingErrors: 1,
+ mockError: nil,
+ wantError: tcpip.ErrMessageTooLong,
+ },
+ }
+
+ for _, ft := range tests {
+ t.Run(ft.description, func(t *testing.T) {
+ pkt := testutil.MakeRandPkt(ft.transHdrLen, header.IPv6MinimumSize, []int{ft.payloadSize}, header.IPv6ProtocolNumber)
+ ep := testutil.NewMockLinkEndpoint(ft.mtu, ft.mockError, ft.allowPackets)
+ r := buildRoute(t, ep)
+ err := r.WritePacket(&stack.GSO{}, stack.NetworkHeaderParams{
+ Protocol: tcp.ProtocolNumber,
+ TTL: ttl,
+ TOS: stack.DefaultTOS,
+ }, pkt)
+ if err != ft.wantError {
+ t.Errorf("got WritePacket(_, _, _) = %s, want = %s", err, ft.wantError)
+ }
+ if got := int(r.Stats().IP.PacketsSent.Value()); got != ft.allowPackets {
+ t.Errorf("got r.Stats().IP.PacketsSent.Value() = %d, want = %d", got, ft.allowPackets)
+ }
+ if got := int(r.Stats().IP.OutgoingPacketErrors.Value()); got != ft.outgoingErrors {
+ t.Errorf("got r.Stats().IP.OutgoingPacketErrors.Value() = %d, want = %d", got, ft.outgoingErrors)
+ }
+ })
+ }
+}
diff --git a/pkg/tcpip/network/ipv6/ndp.go b/pkg/tcpip/network/ipv6/ndp.go
new file mode 100644
index 000000000..40da011f8
--- /dev/null
+++ b/pkg/tcpip/network/ipv6/ndp.go
@@ -0,0 +1,2013 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package ipv6
+
+import (
+ "fmt"
+ "log"
+ "math/rand"
+ "time"
+
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
+)
+
+const (
+ // defaultRetransmitTimer is the default amount of time to wait between
+ // sending reachability probes.
+ //
+ // Default taken from RETRANS_TIMER of RFC 4861 section 10.
+ defaultRetransmitTimer = time.Second
+
+ // minimumRetransmitTimer is the minimum amount of time to wait between
+ // sending reachability probes.
+ //
+ // Note, RFC 4861 does not impose a minimum Retransmit Timer, but we do here
+ // to make sure the messages are not sent all at once. We also come to this
+ // value because in the RetransmitTimer field of a Router Advertisement, a
+ // value of 0 means unspecified, so the smallest valid value is 1. Note, the
+ // unit of the RetransmitTimer field in the Router Advertisement is
+ // milliseconds.
+ minimumRetransmitTimer = time.Millisecond
+
+ // defaultDupAddrDetectTransmits is the default number of NDP Neighbor
+ // Solicitation messages to send when doing Duplicate Address Detection
+ // for a tentative address.
+ //
+ // Default = 1 (from RFC 4862 section 5.1)
+ defaultDupAddrDetectTransmits = 1
+
+ // defaultMaxRtrSolicitations is the default number of Router
+ // Solicitation messages to send when an IPv6 endpoint becomes enabled.
+ //
+ // Default = 3 (from RFC 4861 section 10).
+ defaultMaxRtrSolicitations = 3
+
+ // defaultRtrSolicitationInterval is the default amount of time between
+ // sending Router Solicitation messages.
+ //
+ // Default = 4s (from 4861 section 10).
+ defaultRtrSolicitationInterval = 4 * time.Second
+
+ // defaultMaxRtrSolicitationDelay is the default maximum amount of time
+ // to wait before sending the first Router Solicitation message.
+ //
+ // Default = 1s (from 4861 section 10).
+ defaultMaxRtrSolicitationDelay = time.Second
+
+ // defaultHandleRAs is the default configuration for whether or not to
+ // handle incoming Router Advertisements as a host.
+ defaultHandleRAs = true
+
+ // defaultDiscoverDefaultRouters is the default configuration for
+ // whether or not to discover default routers from incoming Router
+ // Advertisements, as a host.
+ defaultDiscoverDefaultRouters = true
+
+ // defaultDiscoverOnLinkPrefixes is the default configuration for
+ // whether or not to discover on-link prefixes from incoming Router
+ // Advertisements' Prefix Information option, as a host.
+ defaultDiscoverOnLinkPrefixes = true
+
+ // defaultAutoGenGlobalAddresses is the default configuration for
+ // whether or not to generate global IPv6 addresses in response to
+ // receiving a new Prefix Information option with its Autonomous
+ // Address AutoConfiguration flag set, as a host.
+ //
+ // Default = true.
+ defaultAutoGenGlobalAddresses = true
+
+ // minimumRtrSolicitationInterval is the minimum amount of time to wait
+ // between sending Router Solicitation messages. This limit is imposed
+ // to make sure that Router Solicitation messages are not sent all at
+ // once, defeating the purpose of sending the initial few messages.
+ minimumRtrSolicitationInterval = 500 * time.Millisecond
+
+ // minimumMaxRtrSolicitationDelay is the minimum amount of time to wait
+ // before sending the first Router Solicitation message. It is 0 because
+ // we cannot have a negative delay.
+ minimumMaxRtrSolicitationDelay = 0
+
+ // MaxDiscoveredDefaultRouters is the maximum number of discovered
+ // default routers. The stack should stop discovering new routers after
+ // discovering MaxDiscoveredDefaultRouters routers.
+ //
+ // This value MUST be at minimum 2 as per RFC 4861 section 6.3.4, and
+ // SHOULD be more.
+ MaxDiscoveredDefaultRouters = 10
+
+ // MaxDiscoveredOnLinkPrefixes is the maximum number of discovered
+ // on-link prefixes. The stack should stop discovering new on-link
+ // prefixes after discovering MaxDiscoveredOnLinkPrefixes on-link
+ // prefixes.
+ MaxDiscoveredOnLinkPrefixes = 10
+
+ // validPrefixLenForAutoGen is the expected prefix length that an
+ // address can be generated for. Must be 64 bits as the interface
+ // identifier (IID) is 64 bits and an IPv6 address is 128 bits, so
+ // 128 - 64 = 64.
+ validPrefixLenForAutoGen = 64
+
+ // defaultAutoGenTempGlobalAddresses is the default configuration for whether
+ // or not to generate temporary SLAAC addresses.
+ defaultAutoGenTempGlobalAddresses = true
+
+ // defaultMaxTempAddrValidLifetime is the default maximum valid lifetime
+ // for temporary SLAAC addresses generated as part of RFC 4941.
+ //
+ // Default = 7 days (from RFC 4941 section 5).
+ defaultMaxTempAddrValidLifetime = 7 * 24 * time.Hour
+
+ // defaultMaxTempAddrPreferredLifetime is the default preferred lifetime
+ // for temporary SLAAC addresses generated as part of RFC 4941.
+ //
+ // Default = 1 day (from RFC 4941 section 5).
+ defaultMaxTempAddrPreferredLifetime = 24 * time.Hour
+
+ // defaultRegenAdvanceDuration is the default duration before the deprecation
+ // of a temporary address when a new address will be generated.
+ //
+ // Default = 5s (from RFC 4941 section 5).
+ defaultRegenAdvanceDuration = 5 * time.Second
+
+ // minRegenAdvanceDuration is the minimum duration before the deprecation
+ // of a temporary address when a new address will be generated.
+ minRegenAdvanceDuration = time.Duration(0)
+
+ // maxSLAACAddrLocalRegenAttempts is the maximum number of times to attempt
+ // SLAAC address regenerations in response to an IPv6 endpoint-local conflict.
+ maxSLAACAddrLocalRegenAttempts = 10
+)
+
+var (
+ // MinPrefixInformationValidLifetimeForUpdate is the minimum Valid
+ // Lifetime to update the valid lifetime of a generated address by
+ // SLAAC.
+ //
+ // This is exported as a variable (instead of a constant) so tests
+ // can update it to a smaller value.
+ //
+ // Min = 2hrs.
+ MinPrefixInformationValidLifetimeForUpdate = 2 * time.Hour
+
+ // MaxDesyncFactor is the upper bound for the preferred lifetime's desync
+ // factor for temporary SLAAC addresses.
+ //
+ // This is exported as a variable (instead of a constant) so tests
+ // can update it to a smaller value.
+ //
+ // Must be greater than 0.
+ //
+ // Max = 10m (from RFC 4941 section 5).
+ MaxDesyncFactor = 10 * time.Minute
+
+ // MinMaxTempAddrPreferredLifetime is the minimum value allowed for the
+ // maximum preferred lifetime for temporary SLAAC addresses.
+ //
+ // This is exported as a variable (instead of a constant) so tests
+ // can update it to a smaller value.
+ //
+ // This value guarantees that a temporary address is preferred for at
+ // least 1hr if the SLAAC prefix is valid for at least that time.
+ MinMaxTempAddrPreferredLifetime = defaultRegenAdvanceDuration + MaxDesyncFactor + time.Hour
+
+ // MinMaxTempAddrValidLifetime is the minimum value allowed for the
+ // maximum valid lifetime for temporary SLAAC addresses.
+ //
+ // This is exported as a variable (instead of a constant) so tests
+ // can update it to a smaller value.
+ //
+ // This value guarantees that a temporary address is valid for at least
+ // 2hrs if the SLAAC prefix is valid for at least that time.
+ MinMaxTempAddrValidLifetime = 2 * time.Hour
+)
+
+// NDPEndpoint is an endpoint that supports NDP.
+type NDPEndpoint interface {
+ // SetNDPConfigurations sets the NDP configurations.
+ SetNDPConfigurations(NDPConfigurations)
+}
+
+// DHCPv6ConfigurationFromNDPRA is a configuration available via DHCPv6 that an
+// NDP Router Advertisement informed the Stack about.
+type DHCPv6ConfigurationFromNDPRA int
+
+const (
+ _ DHCPv6ConfigurationFromNDPRA = iota
+
+ // DHCPv6NoConfiguration indicates that no configurations are available via
+ // DHCPv6.
+ DHCPv6NoConfiguration
+
+ // DHCPv6ManagedAddress indicates that addresses are available via DHCPv6.
+ //
+ // DHCPv6ManagedAddress also implies DHCPv6OtherConfigurations because DHCPv6
+ // returns all available configuration information when serving addresses.
+ DHCPv6ManagedAddress
+
+ // DHCPv6OtherConfigurations indicates that other configuration information is
+ // available via DHCPv6.
+ //
+ // Other configurations are configurations other than addresses. Examples of
+ // other configurations are recursive DNS server list, DNS search lists and
+ // default gateway.
+ DHCPv6OtherConfigurations
+)
+
+// NDPDispatcher is the interface integrators of netstack must implement to
+// receive and handle NDP related events.
+type NDPDispatcher interface {
+ // OnDuplicateAddressDetectionStatus is called when the DAD process for an
+ // address (addr) on a NIC (with ID nicID) completes. resolved is set to true
+ // if DAD completed successfully (no duplicate addr detected); false otherwise
+ // (addr was detected to be a duplicate on the link the NIC is a part of, or
+ // it was stopped for some other reason, such as the address being removed).
+ // If an error occured during DAD, err is set and resolved must be ignored.
+ //
+ // This function is not permitted to block indefinitely. This function
+ // is also not permitted to call into the stack.
+ OnDuplicateAddressDetectionStatus(nicID tcpip.NICID, addr tcpip.Address, resolved bool, err *tcpip.Error)
+
+ // OnDefaultRouterDiscovered is called when a new default router is
+ // discovered. Implementations must return true if the newly discovered
+ // router should be remembered.
+ //
+ // This function is not permitted to block indefinitely. This function
+ // is also not permitted to call into the stack.
+ OnDefaultRouterDiscovered(nicID tcpip.NICID, addr tcpip.Address) bool
+
+ // OnDefaultRouterInvalidated is called when a discovered default router that
+ // was remembered is invalidated.
+ //
+ // This function is not permitted to block indefinitely. This function
+ // is also not permitted to call into the stack.
+ OnDefaultRouterInvalidated(nicID tcpip.NICID, addr tcpip.Address)
+
+ // OnOnLinkPrefixDiscovered is called when a new on-link prefix is discovered.
+ // Implementations must return true if the newly discovered on-link prefix
+ // should be remembered.
+ //
+ // This function is not permitted to block indefinitely. This function
+ // is also not permitted to call into the stack.
+ OnOnLinkPrefixDiscovered(nicID tcpip.NICID, prefix tcpip.Subnet) bool
+
+ // OnOnLinkPrefixInvalidated is called when a discovered on-link prefix that
+ // was remembered is invalidated.
+ //
+ // This function is not permitted to block indefinitely. This function
+ // is also not permitted to call into the stack.
+ OnOnLinkPrefixInvalidated(nicID tcpip.NICID, prefix tcpip.Subnet)
+
+ // OnAutoGenAddress is called when a new prefix with its autonomous address-
+ // configuration flag set is received and SLAAC was performed. Implementations
+ // may prevent the stack from assigning the address to the NIC by returning
+ // false.
+ //
+ // This function is not permitted to block indefinitely. It must not
+ // call functions on the stack itself.
+ OnAutoGenAddress(tcpip.NICID, tcpip.AddressWithPrefix) bool
+
+ // OnAutoGenAddressDeprecated is called when an auto-generated address (SLAAC)
+ // is deprecated, but is still considered valid. Note, if an address is
+ // invalidated at the same ime it is deprecated, the deprecation event may not
+ // be received.
+ //
+ // This function is not permitted to block indefinitely. It must not
+ // call functions on the stack itself.
+ OnAutoGenAddressDeprecated(tcpip.NICID, tcpip.AddressWithPrefix)
+
+ // OnAutoGenAddressInvalidated is called when an auto-generated address
+ // (SLAAC) is invalidated.
+ //
+ // This function is not permitted to block indefinitely. It must not
+ // call functions on the stack itself.
+ OnAutoGenAddressInvalidated(tcpip.NICID, tcpip.AddressWithPrefix)
+
+ // OnRecursiveDNSServerOption is called when the stack learns of DNS servers
+ // through NDP. Note, the addresses may contain link-local addresses.
+ //
+ // It is up to the caller to use the DNS Servers only for their valid
+ // lifetime. OnRecursiveDNSServerOption may be called for new or
+ // already known DNS servers. If called with known DNS servers, their
+ // valid lifetimes must be refreshed to lifetime (it may be increased,
+ // decreased, or completely invalidated when lifetime = 0).
+ //
+ // This function is not permitted to block indefinitely. It must not
+ // call functions on the stack itself.
+ OnRecursiveDNSServerOption(nicID tcpip.NICID, addrs []tcpip.Address, lifetime time.Duration)
+
+ // OnDNSSearchListOption is called when the stack learns of DNS search lists
+ // through NDP.
+ //
+ // It is up to the caller to use the domain names in the search list
+ // for only their valid lifetime. OnDNSSearchListOption may be called
+ // with new or already known domain names. If called with known domain
+ // names, their valid lifetimes must be refreshed to lifetime (it may
+ // be increased, decreased or completely invalidated when lifetime = 0.
+ OnDNSSearchListOption(nicID tcpip.NICID, domainNames []string, lifetime time.Duration)
+
+ // OnDHCPv6Configuration is called with an updated configuration that is
+ // available via DHCPv6 for the passed NIC.
+ //
+ // This function is not permitted to block indefinitely. It must not
+ // call functions on the stack itself.
+ OnDHCPv6Configuration(tcpip.NICID, DHCPv6ConfigurationFromNDPRA)
+}
+
+// NDPConfigurations is the NDP configurations for the netstack.
+type NDPConfigurations struct {
+ // The number of Neighbor Solicitation messages to send when doing
+ // Duplicate Address Detection for a tentative address.
+ //
+ // Note, a value of zero effectively disables DAD.
+ DupAddrDetectTransmits uint8
+
+ // The amount of time to wait between sending Neighbor solicitation
+ // messages.
+ //
+ // Must be greater than or equal to 1ms.
+ RetransmitTimer time.Duration
+
+ // The number of Router Solicitation messages to send when the IPv6 endpoint
+ // becomes enabled.
+ MaxRtrSolicitations uint8
+
+ // The amount of time between transmitting Router Solicitation messages.
+ //
+ // Must be greater than or equal to 0.5s.
+ RtrSolicitationInterval time.Duration
+
+ // The maximum amount of time before transmitting the first Router
+ // Solicitation message.
+ //
+ // Must be greater than or equal to 0s.
+ MaxRtrSolicitationDelay time.Duration
+
+ // HandleRAs determines whether or not Router Advertisements are processed.
+ HandleRAs bool
+
+ // DiscoverDefaultRouters determines whether or not default routers are
+ // discovered from Router Advertisements, as per RFC 4861 section 6. This
+ // configuration is ignored if HandleRAs is false.
+ DiscoverDefaultRouters bool
+
+ // DiscoverOnLinkPrefixes determines whether or not on-link prefixes are
+ // discovered from Router Advertisements' Prefix Information option, as per
+ // RFC 4861 section 6. This configuration is ignored if HandleRAs is false.
+ DiscoverOnLinkPrefixes bool
+
+ // AutoGenGlobalAddresses determines whether or not an IPv6 endpoint performs
+ // SLAAC to auto-generate global SLAAC addresses in response to Prefix
+ // Information options, as per RFC 4862.
+ //
+ // Note, if an address was already generated for some unique prefix, as
+ // part of SLAAC, this option does not affect whether or not the
+ // lifetime(s) of the generated address changes; this option only
+ // affects the generation of new addresses as part of SLAAC.
+ AutoGenGlobalAddresses bool
+
+ // AutoGenAddressConflictRetries determines how many times to attempt to retry
+ // generation of a permanent auto-generated address in response to DAD
+ // conflicts.
+ //
+ // If the method used to generate the address does not support creating
+ // alternative addresses (e.g. IIDs based on the modified EUI64 of a NIC's
+ // MAC address), then no attempt is made to resolve the conflict.
+ AutoGenAddressConflictRetries uint8
+
+ // AutoGenTempGlobalAddresses determines whether or not temporary SLAAC
+ // addresses are generated for an IPv6 endpoint as part of SLAAC privacy
+ // extensions, as per RFC 4941.
+ //
+ // Ignored if AutoGenGlobalAddresses is false.
+ AutoGenTempGlobalAddresses bool
+
+ // MaxTempAddrValidLifetime is the maximum valid lifetime for temporary
+ // SLAAC addresses.
+ MaxTempAddrValidLifetime time.Duration
+
+ // MaxTempAddrPreferredLifetime is the maximum preferred lifetime for
+ // temporary SLAAC addresses.
+ MaxTempAddrPreferredLifetime time.Duration
+
+ // RegenAdvanceDuration is the duration before the deprecation of a temporary
+ // address when a new address will be generated.
+ RegenAdvanceDuration time.Duration
+}
+
+// DefaultNDPConfigurations returns an NDPConfigurations populated with
+// default values.
+func DefaultNDPConfigurations() NDPConfigurations {
+ return NDPConfigurations{
+ DupAddrDetectTransmits: defaultDupAddrDetectTransmits,
+ RetransmitTimer: defaultRetransmitTimer,
+ MaxRtrSolicitations: defaultMaxRtrSolicitations,
+ RtrSolicitationInterval: defaultRtrSolicitationInterval,
+ MaxRtrSolicitationDelay: defaultMaxRtrSolicitationDelay,
+ HandleRAs: defaultHandleRAs,
+ DiscoverDefaultRouters: defaultDiscoverDefaultRouters,
+ DiscoverOnLinkPrefixes: defaultDiscoverOnLinkPrefixes,
+ AutoGenGlobalAddresses: defaultAutoGenGlobalAddresses,
+ AutoGenTempGlobalAddresses: defaultAutoGenTempGlobalAddresses,
+ MaxTempAddrValidLifetime: defaultMaxTempAddrValidLifetime,
+ MaxTempAddrPreferredLifetime: defaultMaxTempAddrPreferredLifetime,
+ RegenAdvanceDuration: defaultRegenAdvanceDuration,
+ }
+}
+
+// validate modifies an NDPConfigurations with valid values. If invalid values
+// are present in c, the corresponding default values are used instead.
+func (c *NDPConfigurations) validate() {
+ if c.RetransmitTimer < minimumRetransmitTimer {
+ c.RetransmitTimer = defaultRetransmitTimer
+ }
+
+ if c.RtrSolicitationInterval < minimumRtrSolicitationInterval {
+ c.RtrSolicitationInterval = defaultRtrSolicitationInterval
+ }
+
+ if c.MaxRtrSolicitationDelay < minimumMaxRtrSolicitationDelay {
+ c.MaxRtrSolicitationDelay = defaultMaxRtrSolicitationDelay
+ }
+
+ if c.MaxTempAddrValidLifetime < MinMaxTempAddrValidLifetime {
+ c.MaxTempAddrValidLifetime = MinMaxTempAddrValidLifetime
+ }
+
+ if c.MaxTempAddrPreferredLifetime < MinMaxTempAddrPreferredLifetime || c.MaxTempAddrPreferredLifetime > c.MaxTempAddrValidLifetime {
+ c.MaxTempAddrPreferredLifetime = MinMaxTempAddrPreferredLifetime
+ }
+
+ if c.RegenAdvanceDuration < minRegenAdvanceDuration {
+ c.RegenAdvanceDuration = minRegenAdvanceDuration
+ }
+}
+
+// ndpState is the per-interface NDP state.
+type ndpState struct {
+ // The IPv6 endpoint this ndpState is for.
+ ep *endpoint
+
+ // configs is the per-interface NDP configurations.
+ configs NDPConfigurations
+
+ // The DAD state to send the next NS message, or resolve the address.
+ dad map[tcpip.Address]dadState
+
+ // The default routers discovered through Router Advertisements.
+ defaultRouters map[tcpip.Address]defaultRouterState
+
+ rtrSolicit struct {
+ // The timer used to send the next router solicitation message.
+ timer tcpip.Timer
+
+ // Used to let the Router Solicitation timer know that it has been stopped.
+ //
+ // Must only be read from or written to while protected by the lock of
+ // the IPv6 endpoint this ndpState is associated with. MUST be set when the
+ // timer is set.
+ done *bool
+ }
+
+ // The on-link prefixes discovered through Router Advertisements' Prefix
+ // Information option.
+ onLinkPrefixes map[tcpip.Subnet]onLinkPrefixState
+
+ // The SLAAC prefixes discovered through Router Advertisements' Prefix
+ // Information option.
+ slaacPrefixes map[tcpip.Subnet]slaacPrefixState
+
+ // The last learned DHCPv6 configuration from an NDP RA.
+ dhcpv6Configuration DHCPv6ConfigurationFromNDPRA
+
+ // temporaryIIDHistory is the history value used to generate a new temporary
+ // IID.
+ temporaryIIDHistory [header.IIDSize]byte
+
+ // temporaryAddressDesyncFactor is the preferred lifetime's desync factor for
+ // temporary SLAAC addresses.
+ temporaryAddressDesyncFactor time.Duration
+}
+
+// dadState holds the Duplicate Address Detection timer and channel to signal
+// to the DAD goroutine that DAD should stop.
+type dadState struct {
+ // The DAD timer to send the next NS message, or resolve the address.
+ timer tcpip.Timer
+
+ // Used to let the DAD timer know that it has been stopped.
+ //
+ // Must only be read from or written to while protected by the lock of
+ // the IPv6 endpoint this dadState is associated with.
+ done *bool
+}
+
+// defaultRouterState holds data associated with a default router discovered by
+// a Router Advertisement (RA).
+type defaultRouterState struct {
+ // Job to invalidate the default router.
+ //
+ // Must not be nil.
+ invalidationJob *tcpip.Job
+}
+
+// onLinkPrefixState holds data associated with an on-link prefix discovered by
+// a Router Advertisement's Prefix Information option (PI) when the NDP
+// configurations was configured to do so.
+type onLinkPrefixState struct {
+ // Job to invalidate the on-link prefix.
+ //
+ // Must not be nil.
+ invalidationJob *tcpip.Job
+}
+
+// tempSLAACAddrState holds state associated with a temporary SLAAC address.
+type tempSLAACAddrState struct {
+ // Job to deprecate the temporary SLAAC address.
+ //
+ // Must not be nil.
+ deprecationJob *tcpip.Job
+
+ // Job to invalidate the temporary SLAAC address.
+ //
+ // Must not be nil.
+ invalidationJob *tcpip.Job
+
+ // Job to regenerate the temporary SLAAC address.
+ //
+ // Must not be nil.
+ regenJob *tcpip.Job
+
+ createdAt time.Time
+
+ // The address's endpoint.
+ //
+ // Must not be nil.
+ addressEndpoint stack.AddressEndpoint
+
+ // Has a new temporary SLAAC address already been regenerated?
+ regenerated bool
+}
+
+// slaacPrefixState holds state associated with a SLAAC prefix.
+type slaacPrefixState struct {
+ // Job to deprecate the prefix.
+ //
+ // Must not be nil.
+ deprecationJob *tcpip.Job
+
+ // Job to invalidate the prefix.
+ //
+ // Must not be nil.
+ invalidationJob *tcpip.Job
+
+ // Nonzero only when the address is not valid forever.
+ validUntil time.Time
+
+ // Nonzero only when the address is not preferred forever.
+ preferredUntil time.Time
+
+ // State associated with the stable address generated for the prefix.
+ stableAddr struct {
+ // The address's endpoint.
+ //
+ // May only be nil when the address is being (re-)generated. Otherwise,
+ // must not be nil as all SLAAC prefixes must have a stable address.
+ addressEndpoint stack.AddressEndpoint
+
+ // The number of times an address has been generated locally where the IPv6
+ // endpoint already had the generated address.
+ localGenerationFailures uint8
+ }
+
+ // The temporary (short-lived) addresses generated for the SLAAC prefix.
+ tempAddrs map[tcpip.Address]tempSLAACAddrState
+
+ // The next two fields are used by both stable and temporary addresses
+ // generated for a SLAAC prefix. This is safe as only 1 address is in the
+ // generation and DAD process at any time. That is, no two addresses are
+ // generated at the same time for a given SLAAC prefix.
+
+ // The number of times an address has been generated and added to the IPv6
+ // endpoint.
+ //
+ // Addresses may be regenerated in reseponse to a DAD conflicts.
+ generationAttempts uint8
+
+ // The maximum number of times to attempt regeneration of a SLAAC address
+ // in response to DAD conflicts.
+ maxGenerationAttempts uint8
+}
+
+// startDuplicateAddressDetection performs Duplicate Address Detection.
+//
+// This function must only be called by IPv6 addresses that are currently
+// tentative.
+//
+// The IPv6 endpoint that ndp belongs to MUST be locked.
+func (ndp *ndpState) startDuplicateAddressDetection(addr tcpip.Address, addressEndpoint stack.AddressEndpoint) *tcpip.Error {
+ // addr must be a valid unicast IPv6 address.
+ if !header.IsV6UnicastAddress(addr) {
+ return tcpip.ErrAddressFamilyNotSupported
+ }
+
+ if addressEndpoint.GetKind() != stack.PermanentTentative {
+ // The endpoint should be marked as tentative since we are starting DAD.
+ panic(fmt.Sprintf("ndpdad: addr %s is not tentative on NIC(%d)", addr, ndp.ep.nic.ID()))
+ }
+
+ // Should not attempt to perform DAD on an address that is currently in the
+ // DAD process.
+ if _, ok := ndp.dad[addr]; ok {
+ // Should never happen because we should only ever call this function for
+ // newly created addresses. If we attemped to "add" an address that already
+ // existed, we would get an error since we attempted to add a duplicate
+ // address, or its reference count would have been increased without doing
+ // the work that would have been done for an address that was brand new.
+ // See endpoint.addAddressLocked.
+ panic(fmt.Sprintf("ndpdad: already performing DAD for addr %s on NIC(%d)", addr, ndp.ep.nic.ID()))
+ }
+
+ remaining := ndp.configs.DupAddrDetectTransmits
+ if remaining == 0 {
+ addressEndpoint.SetKind(stack.Permanent)
+
+ // Consider DAD to have resolved even if no DAD messages were actually
+ // transmitted.
+ if ndpDisp := ndp.ep.protocol.ndpDisp; ndpDisp != nil {
+ ndpDisp.OnDuplicateAddressDetectionStatus(ndp.ep.nic.ID(), addr, true, nil)
+ }
+
+ return nil
+ }
+
+ var done bool
+ var timer tcpip.Timer
+ // We initially start a timer to fire immediately because some of the DAD work
+ // cannot be done while holding the IPv6 endpoint's lock. This is effectively
+ // the same as starting a goroutine but we use a timer that fires immediately
+ // so we can reset it for the next DAD iteration.
+ timer = ndp.ep.protocol.stack.Clock().AfterFunc(0, func() {
+ ndp.ep.mu.Lock()
+ defer ndp.ep.mu.Unlock()
+
+ if done {
+ // If we reach this point, it means that the DAD timer fired after
+ // another goroutine already obtained the IPv6 endpoint lock and stopped
+ // DAD before this function obtained the NIC lock. Simply return here and
+ // do nothing further.
+ return
+ }
+
+ if addressEndpoint.GetKind() != stack.PermanentTentative {
+ // The endpoint should still be marked as tentative since we are still
+ // performing DAD on it.
+ panic(fmt.Sprintf("ndpdad: addr %s is no longer tentative on NIC(%d)", addr, ndp.ep.nic.ID()))
+ }
+
+ dadDone := remaining == 0
+
+ var err *tcpip.Error
+ if !dadDone {
+ // Use the unspecified address as the source address when performing DAD.
+ addressEndpoint := ndp.ep.acquireAddressOrCreateTempLocked(header.IPv6Any, true /* createTemp */, stack.NeverPrimaryEndpoint)
+
+ // Do not hold the lock when sending packets which may be a long running
+ // task or may block link address resolution. We know this is safe
+ // because immediately after obtaining the lock again, we check if DAD
+ // has been stopped before doing any work with the IPv6 endpoint. Note,
+ // DAD would be stopped if the IPv6 endpoint was disabled or closed, or if
+ // the address was removed.
+ ndp.ep.mu.Unlock()
+ err = ndp.sendDADPacket(addr, addressEndpoint)
+ ndp.ep.mu.Lock()
+ addressEndpoint.DecRef()
+ }
+
+ if done {
+ // If we reach this point, it means that DAD was stopped after we released
+ // the IPv6 endpoint's read lock and before we obtained the write lock.
+ return
+ }
+
+ if dadDone {
+ // DAD has resolved.
+ addressEndpoint.SetKind(stack.Permanent)
+ } else if err == nil {
+ // DAD is not done and we had no errors when sending the last NDP NS,
+ // schedule the next DAD timer.
+ remaining--
+ timer.Reset(ndp.configs.RetransmitTimer)
+ return
+ }
+
+ // At this point we know that either DAD is done or we hit an error sending
+ // the last NDP NS. Either way, clean up addr's DAD state and let the
+ // integrator know DAD has completed.
+ delete(ndp.dad, addr)
+
+ if ndpDisp := ndp.ep.protocol.ndpDisp; ndpDisp != nil {
+ ndpDisp.OnDuplicateAddressDetectionStatus(ndp.ep.nic.ID(), addr, dadDone, err)
+ }
+
+ // If DAD resolved for a stable SLAAC address, attempt generation of a
+ // temporary SLAAC address.
+ if dadDone && addressEndpoint.ConfigType() == stack.AddressConfigSlaac {
+ // Reset the generation attempts counter as we are starting the generation
+ // of a new address for the SLAAC prefix.
+ ndp.regenerateTempSLAACAddr(addressEndpoint.AddressWithPrefix().Subnet(), true /* resetGenAttempts */)
+ }
+ })
+
+ ndp.dad[addr] = dadState{
+ timer: timer,
+ done: &done,
+ }
+
+ return nil
+}
+
+// sendDADPacket sends a NS message to see if any nodes on ndp's NIC's link owns
+// addr.
+//
+// addr must be a tentative IPv6 address on ndp's IPv6 endpoint.
+//
+// The IPv6 endpoint that ndp belongs to MUST NOT be locked.
+func (ndp *ndpState) sendDADPacket(addr tcpip.Address, addressEndpoint stack.AddressEndpoint) *tcpip.Error {
+ snmc := header.SolicitedNodeAddr(addr)
+
+ r, err := ndp.ep.protocol.stack.FindRoute(ndp.ep.nic.ID(), header.IPv6Any, snmc, ProtocolNumber, false /* multicastLoop */)
+ if err != nil {
+ return err
+ }
+ defer r.Release()
+
+ // Route should resolve immediately since snmc is a multicast address so a
+ // remote link address can be calculated without a resolution process.
+ if c, err := r.Resolve(nil); err != nil {
+ // Do not consider the NIC being unknown or disabled as a fatal error.
+ // Since this method is required to be called when the IPv6 endpoint is not
+ // locked, the NIC could have been disabled or removed by another goroutine.
+ if err == tcpip.ErrUnknownNICID || err != tcpip.ErrInvalidEndpointState {
+ return err
+ }
+
+ panic(fmt.Sprintf("ndp: error when resolving route to send NDP NS for DAD (%s -> %s on NIC(%d)): %s", header.IPv6Any, snmc, ndp.ep.nic.ID(), err))
+ } else if c != nil {
+ panic(fmt.Sprintf("ndp: route resolution not immediate for route to send NDP NS for DAD (%s -> %s on NIC(%d))", header.IPv6Any, snmc, ndp.ep.nic.ID()))
+ }
+
+ icmpData := header.ICMPv6(buffer.NewView(header.ICMPv6NeighborSolicitMinimumSize))
+ icmpData.SetType(header.ICMPv6NeighborSolicit)
+ ns := header.NDPNeighborSolicit(icmpData.NDPPayload())
+ ns.SetTargetAddress(addr)
+ icmpData.SetChecksum(header.ICMPv6Checksum(icmpData, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{}))
+
+ pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
+ ReserveHeaderBytes: int(r.MaxHeaderLength()),
+ Data: buffer.View(icmpData).ToVectorisedView(),
+ })
+
+ sent := r.Stats().ICMP.V6PacketsSent
+ if err := r.WritePacket(nil,
+ stack.NetworkHeaderParams{
+ Protocol: header.ICMPv6ProtocolNumber,
+ TTL: header.NDPHopLimit,
+ }, pkt,
+ ); err != nil {
+ sent.Dropped.Increment()
+ return err
+ }
+ sent.NeighborSolicit.Increment()
+
+ return nil
+}
+
+// stopDuplicateAddressDetection ends a running Duplicate Address Detection
+// process. Note, this may leave the DAD process for a tentative address in
+// such a state forever, unless some other external event resolves the DAD
+// process (receiving an NA from the true owner of addr, or an NS for addr
+// (implying another node is attempting to use addr)). It is up to the caller
+// of this function to handle such a scenario.
+//
+// The IPv6 endpoint that ndp belongs to MUST be locked.
+func (ndp *ndpState) stopDuplicateAddressDetection(addr tcpip.Address) {
+ dad, ok := ndp.dad[addr]
+ if !ok {
+ // Not currently performing DAD on addr, just return.
+ return
+ }
+
+ if dad.timer != nil {
+ dad.timer.Stop()
+ dad.timer = nil
+
+ *dad.done = true
+ dad.done = nil
+ }
+
+ delete(ndp.dad, addr)
+
+ // Let the integrator know DAD did not resolve.
+ if ndpDisp := ndp.ep.protocol.ndpDisp; ndpDisp != nil {
+ ndpDisp.OnDuplicateAddressDetectionStatus(ndp.ep.nic.ID(), addr, false, nil)
+ }
+}
+
+// handleRA handles a Router Advertisement message that arrived on the NIC
+// this ndp is for. Does nothing if the NIC is configured to not handle RAs.
+//
+// The IPv6 endpoint that ndp belongs to MUST be locked.
+func (ndp *ndpState) handleRA(ip tcpip.Address, ra header.NDPRouterAdvert) {
+ // Is the IPv6 endpoint configured to handle RAs at all?
+ //
+ // Currently, the stack does not determine router interface status on a
+ // per-interface basis; it is a protocol-wide configuration, so we check the
+ // protocol's forwarding flag to determine if the IPv6 endpoint is forwarding
+ // packets.
+ if !ndp.configs.HandleRAs || ndp.ep.protocol.Forwarding() {
+ return
+ }
+
+ // Only worry about the DHCPv6 configuration if we have an NDPDispatcher as we
+ // only inform the dispatcher on configuration changes. We do nothing else
+ // with the information.
+ if ndpDisp := ndp.ep.protocol.ndpDisp; ndpDisp != nil {
+ var configuration DHCPv6ConfigurationFromNDPRA
+ switch {
+ case ra.ManagedAddrConfFlag():
+ configuration = DHCPv6ManagedAddress
+
+ case ra.OtherConfFlag():
+ configuration = DHCPv6OtherConfigurations
+
+ default:
+ configuration = DHCPv6NoConfiguration
+ }
+
+ if ndp.dhcpv6Configuration != configuration {
+ ndp.dhcpv6Configuration = configuration
+ ndpDisp.OnDHCPv6Configuration(ndp.ep.nic.ID(), configuration)
+ }
+ }
+
+ // Is the IPv6 endpoint configured to discover default routers?
+ if ndp.configs.DiscoverDefaultRouters {
+ rtr, ok := ndp.defaultRouters[ip]
+ rl := ra.RouterLifetime()
+ switch {
+ case !ok && rl != 0:
+ // This is a new default router we are discovering.
+ //
+ // Only remember it if we currently know about less than
+ // MaxDiscoveredDefaultRouters routers.
+ if len(ndp.defaultRouters) < MaxDiscoveredDefaultRouters {
+ ndp.rememberDefaultRouter(ip, rl)
+ }
+
+ case ok && rl != 0:
+ // This is an already discovered default router. Update
+ // the invalidation job.
+ rtr.invalidationJob.Cancel()
+ rtr.invalidationJob.Schedule(rl)
+ ndp.defaultRouters[ip] = rtr
+
+ case ok && rl == 0:
+ // We know about the router but it is no longer to be
+ // used as a default router so invalidate it.
+ ndp.invalidateDefaultRouter(ip)
+ }
+ }
+
+ // TODO(b/141556115): Do (RetransTimer, ReachableTime)) Parameter
+ // Discovery.
+
+ // We know the options is valid as far as wire format is concerned since
+ // we got the Router Advertisement, as documented by this fn. Given this
+ // we do not check the iterator for errors on calls to Next.
+ it, _ := ra.Options().Iter(false)
+ for opt, done, _ := it.Next(); !done; opt, done, _ = it.Next() {
+ switch opt := opt.(type) {
+ case header.NDPRecursiveDNSServer:
+ if ndp.ep.protocol.ndpDisp == nil {
+ continue
+ }
+
+ addrs, _ := opt.Addresses()
+ ndp.ep.protocol.ndpDisp.OnRecursiveDNSServerOption(ndp.ep.nic.ID(), addrs, opt.Lifetime())
+
+ case header.NDPDNSSearchList:
+ if ndp.ep.protocol.ndpDisp == nil {
+ continue
+ }
+
+ domainNames, _ := opt.DomainNames()
+ ndp.ep.protocol.ndpDisp.OnDNSSearchListOption(ndp.ep.nic.ID(), domainNames, opt.Lifetime())
+
+ case header.NDPPrefixInformation:
+ prefix := opt.Subnet()
+
+ // Is the prefix a link-local?
+ if header.IsV6LinkLocalAddress(prefix.ID()) {
+ // ...Yes, skip as per RFC 4861 section 6.3.4,
+ // and RFC 4862 section 5.5.3.b (for SLAAC).
+ continue
+ }
+
+ // Is the Prefix Length 0?
+ if prefix.Prefix() == 0 {
+ // ...Yes, skip as this is an invalid prefix
+ // as all IPv6 addresses cannot be on-link.
+ continue
+ }
+
+ if opt.OnLinkFlag() {
+ ndp.handleOnLinkPrefixInformation(opt)
+ }
+
+ if opt.AutonomousAddressConfigurationFlag() {
+ ndp.handleAutonomousPrefixInformation(opt)
+ }
+ }
+
+ // TODO(b/141556115): Do (MTU) Parameter Discovery.
+ }
+}
+
+// invalidateDefaultRouter invalidates a discovered default router.
+//
+// The IPv6 endpoint that ndp belongs to MUST be locked.
+func (ndp *ndpState) invalidateDefaultRouter(ip tcpip.Address) {
+ rtr, ok := ndp.defaultRouters[ip]
+
+ // Is the router still discovered?
+ if !ok {
+ // ...Nope, do nothing further.
+ return
+ }
+
+ rtr.invalidationJob.Cancel()
+ delete(ndp.defaultRouters, ip)
+
+ // Let the integrator know a discovered default router is invalidated.
+ if ndpDisp := ndp.ep.protocol.ndpDisp; ndpDisp != nil {
+ ndpDisp.OnDefaultRouterInvalidated(ndp.ep.nic.ID(), ip)
+ }
+}
+
+// rememberDefaultRouter remembers a newly discovered default router with IPv6
+// link-local address ip with lifetime rl.
+//
+// The router identified by ip MUST NOT already be known by the IPv6 endpoint.
+//
+// The IPv6 endpoint that ndp belongs to MUST be locked.
+func (ndp *ndpState) rememberDefaultRouter(ip tcpip.Address, rl time.Duration) {
+ ndpDisp := ndp.ep.protocol.ndpDisp
+ if ndpDisp == nil {
+ return
+ }
+
+ // Inform the integrator when we discovered a default router.
+ if !ndpDisp.OnDefaultRouterDiscovered(ndp.ep.nic.ID(), ip) {
+ // Informed by the integrator to not remember the router, do
+ // nothing further.
+ return
+ }
+
+ state := defaultRouterState{
+ invalidationJob: ndp.ep.protocol.stack.NewJob(&ndp.ep.mu, func() {
+ ndp.invalidateDefaultRouter(ip)
+ }),
+ }
+
+ state.invalidationJob.Schedule(rl)
+
+ ndp.defaultRouters[ip] = state
+}
+
+// rememberOnLinkPrefix remembers a newly discovered on-link prefix with IPv6
+// address with prefix prefix with lifetime l.
+//
+// The prefix identified by prefix MUST NOT already be known.
+//
+// The IPv6 endpoint that ndp belongs to MUST be locked.
+func (ndp *ndpState) rememberOnLinkPrefix(prefix tcpip.Subnet, l time.Duration) {
+ ndpDisp := ndp.ep.protocol.ndpDisp
+ if ndpDisp == nil {
+ return
+ }
+
+ // Inform the integrator when we discovered an on-link prefix.
+ if !ndpDisp.OnOnLinkPrefixDiscovered(ndp.ep.nic.ID(), prefix) {
+ // Informed by the integrator to not remember the prefix, do
+ // nothing further.
+ return
+ }
+
+ state := onLinkPrefixState{
+ invalidationJob: ndp.ep.protocol.stack.NewJob(&ndp.ep.mu, func() {
+ ndp.invalidateOnLinkPrefix(prefix)
+ }),
+ }
+
+ if l < header.NDPInfiniteLifetime {
+ state.invalidationJob.Schedule(l)
+ }
+
+ ndp.onLinkPrefixes[prefix] = state
+}
+
+// invalidateOnLinkPrefix invalidates a discovered on-link prefix.
+//
+// The IPv6 endpoint that ndp belongs to MUST be locked.
+func (ndp *ndpState) invalidateOnLinkPrefix(prefix tcpip.Subnet) {
+ s, ok := ndp.onLinkPrefixes[prefix]
+
+ // Is the on-link prefix still discovered?
+ if !ok {
+ // ...Nope, do nothing further.
+ return
+ }
+
+ s.invalidationJob.Cancel()
+ delete(ndp.onLinkPrefixes, prefix)
+
+ // Let the integrator know a discovered on-link prefix is invalidated.
+ if ndpDisp := ndp.ep.protocol.ndpDisp; ndpDisp != nil {
+ ndpDisp.OnOnLinkPrefixInvalidated(ndp.ep.nic.ID(), prefix)
+ }
+}
+
+// handleOnLinkPrefixInformation handles a Prefix Information option with
+// its on-link flag set, as per RFC 4861 section 6.3.4.
+//
+// handleOnLinkPrefixInformation assumes that the prefix this pi is for is
+// not the link-local prefix and the on-link flag is set.
+//
+// The IPv6 endpoint that ndp belongs to MUST be locked.
+func (ndp *ndpState) handleOnLinkPrefixInformation(pi header.NDPPrefixInformation) {
+ prefix := pi.Subnet()
+ prefixState, ok := ndp.onLinkPrefixes[prefix]
+ vl := pi.ValidLifetime()
+
+ if !ok && vl == 0 {
+ // Don't know about this prefix but it has a zero valid
+ // lifetime, so just ignore.
+ return
+ }
+
+ if !ok && vl != 0 {
+ // This is a new on-link prefix we are discovering
+ //
+ // Only remember it if we currently know about less than
+ // MaxDiscoveredOnLinkPrefixes on-link prefixes.
+ if ndp.configs.DiscoverOnLinkPrefixes && len(ndp.onLinkPrefixes) < MaxDiscoveredOnLinkPrefixes {
+ ndp.rememberOnLinkPrefix(prefix, vl)
+ }
+ return
+ }
+
+ if ok && vl == 0 {
+ // We know about the on-link prefix, but it is
+ // no longer to be considered on-link, so
+ // invalidate it.
+ ndp.invalidateOnLinkPrefix(prefix)
+ return
+ }
+
+ // This is an already discovered on-link prefix with a
+ // new non-zero valid lifetime.
+ //
+ // Update the invalidation job.
+
+ prefixState.invalidationJob.Cancel()
+
+ if vl < header.NDPInfiniteLifetime {
+ // Prefix is valid for a finite lifetime, schedule the job to execute after
+ // the new valid lifetime.
+ prefixState.invalidationJob.Schedule(vl)
+ }
+
+ ndp.onLinkPrefixes[prefix] = prefixState
+}
+
+// handleAutonomousPrefixInformation handles a Prefix Information option with
+// its autonomous flag set, as per RFC 4862 section 5.5.3.
+//
+// handleAutonomousPrefixInformation assumes that the prefix this pi is for is
+// not the link-local prefix and the autonomous flag is set.
+//
+// The IPv6 endpoint that ndp belongs to MUST be locked.
+func (ndp *ndpState) handleAutonomousPrefixInformation(pi header.NDPPrefixInformation) {
+ vl := pi.ValidLifetime()
+ pl := pi.PreferredLifetime()
+
+ // If the preferred lifetime is greater than the valid lifetime,
+ // silently ignore the Prefix Information option, as per RFC 4862
+ // section 5.5.3.c.
+ if pl > vl {
+ return
+ }
+
+ prefix := pi.Subnet()
+
+ // Check if we already maintain SLAAC state for prefix.
+ if state, ok := ndp.slaacPrefixes[prefix]; ok {
+ // As per RFC 4862 section 5.5.3.e, refresh prefix's SLAAC lifetimes.
+ ndp.refreshSLAACPrefixLifetimes(prefix, &state, pl, vl)
+ ndp.slaacPrefixes[prefix] = state
+ return
+ }
+
+ // prefix is a new SLAAC prefix. Do the work as outlined by RFC 4862 section
+ // 5.5.3.d if ndp is configured to auto-generate new addresses via SLAAC.
+ if !ndp.configs.AutoGenGlobalAddresses {
+ return
+ }
+
+ ndp.doSLAAC(prefix, pl, vl)
+}
+
+// doSLAAC generates a new SLAAC address with the provided lifetimes
+// for prefix.
+//
+// pl is the new preferred lifetime. vl is the new valid lifetime.
+//
+// The IPv6 endpoint that ndp belongs to MUST be locked.
+func (ndp *ndpState) doSLAAC(prefix tcpip.Subnet, pl, vl time.Duration) {
+ // If we do not already have an address for this prefix and the valid
+ // lifetime is 0, no need to do anything further, as per RFC 4862
+ // section 5.5.3.d.
+ if vl == 0 {
+ return
+ }
+
+ // Make sure the prefix is valid (as far as its length is concerned) to
+ // generate a valid IPv6 address from an interface identifier (IID), as
+ // per RFC 4862 sectiion 5.5.3.d.
+ if prefix.Prefix() != validPrefixLenForAutoGen {
+ return
+ }
+
+ state := slaacPrefixState{
+ deprecationJob: ndp.ep.protocol.stack.NewJob(&ndp.ep.mu, func() {
+ state, ok := ndp.slaacPrefixes[prefix]
+ if !ok {
+ panic(fmt.Sprintf("ndp: must have a slaacPrefixes entry for the deprecated SLAAC prefix %s", prefix))
+ }
+
+ ndp.deprecateSLAACAddress(state.stableAddr.addressEndpoint)
+ }),
+ invalidationJob: ndp.ep.protocol.stack.NewJob(&ndp.ep.mu, func() {
+ state, ok := ndp.slaacPrefixes[prefix]
+ if !ok {
+ panic(fmt.Sprintf("ndp: must have a slaacPrefixes entry for the invalidated SLAAC prefix %s", prefix))
+ }
+
+ ndp.invalidateSLAACPrefix(prefix, state)
+ }),
+ tempAddrs: make(map[tcpip.Address]tempSLAACAddrState),
+ maxGenerationAttempts: ndp.configs.AutoGenAddressConflictRetries + 1,
+ }
+
+ now := time.Now()
+
+ // The time an address is preferred until is needed to properly generate the
+ // address.
+ if pl < header.NDPInfiniteLifetime {
+ state.preferredUntil = now.Add(pl)
+ }
+
+ if !ndp.generateSLAACAddr(prefix, &state) {
+ // We were unable to generate an address for the prefix, we do not nothing
+ // further as there is no reason to maintain state or jobs for a prefix we
+ // do not have an address for.
+ return
+ }
+
+ // Setup the initial jobs to deprecate and invalidate prefix.
+
+ if pl < header.NDPInfiniteLifetime && pl != 0 {
+ state.deprecationJob.Schedule(pl)
+ }
+
+ if vl < header.NDPInfiniteLifetime {
+ state.invalidationJob.Schedule(vl)
+ state.validUntil = now.Add(vl)
+ }
+
+ // If the address is assigned (DAD resolved), generate a temporary address.
+ if state.stableAddr.addressEndpoint.GetKind() == stack.Permanent {
+ // Reset the generation attempts counter as we are starting the generation
+ // of a new address for the SLAAC prefix.
+ ndp.generateTempSLAACAddr(prefix, &state, true /* resetGenAttempts */)
+ }
+
+ ndp.slaacPrefixes[prefix] = state
+}
+
+// addAndAcquireSLAACAddr adds a SLAAC address to the IPv6 endpoint.
+//
+// The IPv6 endpoint that ndp belongs to MUST be locked.
+func (ndp *ndpState) addAndAcquireSLAACAddr(addr tcpip.AddressWithPrefix, configType stack.AddressConfigType, deprecated bool) stack.AddressEndpoint {
+ // Inform the integrator that we have a new SLAAC address.
+ ndpDisp := ndp.ep.protocol.ndpDisp
+ if ndpDisp == nil {
+ return nil
+ }
+
+ if !ndpDisp.OnAutoGenAddress(ndp.ep.nic.ID(), addr) {
+ // Informed by the integrator not to add the address.
+ return nil
+ }
+
+ addressEndpoint, err := ndp.ep.addAndAcquirePermanentAddressLocked(addr, stack.FirstPrimaryEndpoint, configType, deprecated)
+ if err != nil {
+ panic(fmt.Sprintf("ndp: error when adding SLAAC address %+v: %s", addr, err))
+ }
+
+ return addressEndpoint
+}
+
+// generateSLAACAddr generates a SLAAC address for prefix.
+//
+// Returns true if an address was successfully generated.
+//
+// Panics if the prefix is not a SLAAC prefix or it already has an address.
+//
+// The IPv6 endpoint that ndp belongs to MUST be locked.
+func (ndp *ndpState) generateSLAACAddr(prefix tcpip.Subnet, state *slaacPrefixState) bool {
+ if addressEndpoint := state.stableAddr.addressEndpoint; addressEndpoint != nil {
+ panic(fmt.Sprintf("ndp: SLAAC prefix %s already has a permenant address %s", prefix, addressEndpoint.AddressWithPrefix()))
+ }
+
+ // If we have already reached the maximum address generation attempts for the
+ // prefix, do not generate another address.
+ if state.generationAttempts == state.maxGenerationAttempts {
+ return false
+ }
+
+ var generatedAddr tcpip.AddressWithPrefix
+ addrBytes := []byte(prefix.ID())
+
+ for i := 0; ; i++ {
+ // If we were unable to generate an address after the maximum SLAAC address
+ // local regeneration attempts, do nothing further.
+ if i == maxSLAACAddrLocalRegenAttempts {
+ return false
+ }
+
+ dadCounter := state.generationAttempts + state.stableAddr.localGenerationFailures
+ if oIID := ndp.ep.protocol.opaqueIIDOpts; oIID.NICNameFromID != nil {
+ addrBytes = header.AppendOpaqueInterfaceIdentifier(
+ addrBytes[:header.IIDOffsetInIPv6Address],
+ prefix,
+ oIID.NICNameFromID(ndp.ep.nic.ID(), ndp.ep.nic.Name()),
+ dadCounter,
+ oIID.SecretKey,
+ )
+ } else if dadCounter == 0 {
+ // Modified-EUI64 based IIDs have no way to resolve DAD conflicts, so if
+ // the DAD counter is non-zero, we cannot use this method.
+ //
+ // Only attempt to generate an interface-specific IID if we have a valid
+ // link address.
+ //
+ // TODO(b/141011931): Validate a LinkEndpoint's link address (provided by
+ // LinkEndpoint.LinkAddress) before reaching this point.
+ linkAddr := ndp.ep.nic.LinkAddress()
+ if !header.IsValidUnicastEthernetAddress(linkAddr) {
+ return false
+ }
+
+ // Generate an address within prefix from the modified EUI-64 of ndp's
+ // NIC's Ethernet MAC address.
+ header.EthernetAdddressToModifiedEUI64IntoBuf(linkAddr, addrBytes[header.IIDOffsetInIPv6Address:])
+ } else {
+ // We have no way to regenerate an address in response to an address
+ // conflict when addresses are not generated with opaque IIDs.
+ return false
+ }
+
+ generatedAddr = tcpip.AddressWithPrefix{
+ Address: tcpip.Address(addrBytes),
+ PrefixLen: validPrefixLenForAutoGen,
+ }
+
+ if !ndp.ep.hasPermanentAddressRLocked(generatedAddr.Address) {
+ break
+ }
+
+ state.stableAddr.localGenerationFailures++
+ }
+
+ if addressEndpoint := ndp.addAndAcquireSLAACAddr(generatedAddr, stack.AddressConfigSlaac, time.Since(state.preferredUntil) >= 0 /* deprecated */); addressEndpoint != nil {
+ state.stableAddr.addressEndpoint = addressEndpoint
+ state.generationAttempts++
+ return true
+ }
+
+ return false
+}
+
+// regenerateSLAACAddr regenerates an address for a SLAAC prefix.
+//
+// If generating a new address for the prefix fails, the prefix is invalidated.
+//
+// The IPv6 endpoint that ndp belongs to MUST be locked.
+func (ndp *ndpState) regenerateSLAACAddr(prefix tcpip.Subnet) {
+ state, ok := ndp.slaacPrefixes[prefix]
+ if !ok {
+ panic(fmt.Sprintf("ndp: SLAAC prefix state not found to regenerate address for %s", prefix))
+ }
+
+ if ndp.generateSLAACAddr(prefix, &state) {
+ ndp.slaacPrefixes[prefix] = state
+ return
+ }
+
+ // We were unable to generate a permanent address for the SLAAC prefix so
+ // invalidate the prefix as there is no reason to maintain state for a
+ // SLAAC prefix we do not have an address for.
+ ndp.invalidateSLAACPrefix(prefix, state)
+}
+
+// generateTempSLAACAddr generates a new temporary SLAAC address.
+//
+// If resetGenAttempts is true, the prefix's generation counter is reset.
+//
+// Returns true if a new address was generated.
+func (ndp *ndpState) generateTempSLAACAddr(prefix tcpip.Subnet, prefixState *slaacPrefixState, resetGenAttempts bool) bool {
+ // Are we configured to auto-generate new temporary global addresses for the
+ // prefix?
+ if !ndp.configs.AutoGenTempGlobalAddresses || prefix == header.IPv6LinkLocalPrefix.Subnet() {
+ return false
+ }
+
+ if resetGenAttempts {
+ prefixState.generationAttempts = 0
+ prefixState.maxGenerationAttempts = ndp.configs.AutoGenAddressConflictRetries + 1
+ }
+
+ // If we have already reached the maximum address generation attempts for the
+ // prefix, do not generate another address.
+ if prefixState.generationAttempts == prefixState.maxGenerationAttempts {
+ return false
+ }
+
+ stableAddr := prefixState.stableAddr.addressEndpoint.AddressWithPrefix().Address
+ now := time.Now()
+
+ // As per RFC 4941 section 3.3 step 4, the valid lifetime of a temporary
+ // address is the lower of the valid lifetime of the stable address or the
+ // maximum temporary address valid lifetime.
+ vl := ndp.configs.MaxTempAddrValidLifetime
+ if prefixState.validUntil != (time.Time{}) {
+ if prefixVL := prefixState.validUntil.Sub(now); vl > prefixVL {
+ vl = prefixVL
+ }
+ }
+
+ if vl <= 0 {
+ // Cannot create an address without a valid lifetime.
+ return false
+ }
+
+ // As per RFC 4941 section 3.3 step 4, the preferred lifetime of a temporary
+ // address is the lower of the preferred lifetime of the stable address or the
+ // maximum temporary address preferred lifetime - the temporary address desync
+ // factor.
+ pl := ndp.configs.MaxTempAddrPreferredLifetime - ndp.temporaryAddressDesyncFactor
+ if prefixState.preferredUntil != (time.Time{}) {
+ if prefixPL := prefixState.preferredUntil.Sub(now); pl > prefixPL {
+ // Respect the preferred lifetime of the prefix, as per RFC 4941 section
+ // 3.3 step 4.
+ pl = prefixPL
+ }
+ }
+
+ // As per RFC 4941 section 3.3 step 5, a temporary address is created only if
+ // the calculated preferred lifetime is greater than the advance regeneration
+ // duration. In particular, we MUST NOT create a temporary address with a zero
+ // Preferred Lifetime.
+ if pl <= ndp.configs.RegenAdvanceDuration {
+ return false
+ }
+
+ // Attempt to generate a new address that is not already assigned to the IPv6
+ // endpoint.
+ var generatedAddr tcpip.AddressWithPrefix
+ for i := 0; ; i++ {
+ // If we were unable to generate an address after the maximum SLAAC address
+ // local regeneration attempts, do nothing further.
+ if i == maxSLAACAddrLocalRegenAttempts {
+ return false
+ }
+
+ generatedAddr = header.GenerateTempIPv6SLAACAddr(ndp.temporaryIIDHistory[:], stableAddr)
+ if !ndp.ep.hasPermanentAddressRLocked(generatedAddr.Address) {
+ break
+ }
+ }
+
+ // As per RFC RFC 4941 section 3.3 step 5, we MUST NOT create a temporary
+ // address with a zero preferred lifetime. The checks above ensure this
+ // so we know the address is not deprecated.
+ addressEndpoint := ndp.addAndAcquireSLAACAddr(generatedAddr, stack.AddressConfigSlaacTemp, false /* deprecated */)
+ if addressEndpoint == nil {
+ return false
+ }
+
+ state := tempSLAACAddrState{
+ deprecationJob: ndp.ep.protocol.stack.NewJob(&ndp.ep.mu, func() {
+ prefixState, ok := ndp.slaacPrefixes[prefix]
+ if !ok {
+ panic(fmt.Sprintf("ndp: must have a slaacPrefixes entry for %s to deprecate temporary address %s", prefix, generatedAddr))
+ }
+
+ tempAddrState, ok := prefixState.tempAddrs[generatedAddr.Address]
+ if !ok {
+ panic(fmt.Sprintf("ndp: must have a tempAddr entry to deprecate temporary address %s", generatedAddr))
+ }
+
+ ndp.deprecateSLAACAddress(tempAddrState.addressEndpoint)
+ }),
+ invalidationJob: ndp.ep.protocol.stack.NewJob(&ndp.ep.mu, func() {
+ prefixState, ok := ndp.slaacPrefixes[prefix]
+ if !ok {
+ panic(fmt.Sprintf("ndp: must have a slaacPrefixes entry for %s to invalidate temporary address %s", prefix, generatedAddr))
+ }
+
+ tempAddrState, ok := prefixState.tempAddrs[generatedAddr.Address]
+ if !ok {
+ panic(fmt.Sprintf("ndp: must have a tempAddr entry to invalidate temporary address %s", generatedAddr))
+ }
+
+ ndp.invalidateTempSLAACAddr(prefixState.tempAddrs, generatedAddr.Address, tempAddrState)
+ }),
+ regenJob: ndp.ep.protocol.stack.NewJob(&ndp.ep.mu, func() {
+ prefixState, ok := ndp.slaacPrefixes[prefix]
+ if !ok {
+ panic(fmt.Sprintf("ndp: must have a slaacPrefixes entry for %s to regenerate temporary address after %s", prefix, generatedAddr))
+ }
+
+ tempAddrState, ok := prefixState.tempAddrs[generatedAddr.Address]
+ if !ok {
+ panic(fmt.Sprintf("ndp: must have a tempAddr entry to regenerate temporary address after %s", generatedAddr))
+ }
+
+ // If an address has already been regenerated for this address, don't
+ // regenerate another address.
+ if tempAddrState.regenerated {
+ return
+ }
+
+ // Reset the generation attempts counter as we are starting the generation
+ // of a new address for the SLAAC prefix.
+ tempAddrState.regenerated = ndp.generateTempSLAACAddr(prefix, &prefixState, true /* resetGenAttempts */)
+ prefixState.tempAddrs[generatedAddr.Address] = tempAddrState
+ ndp.slaacPrefixes[prefix] = prefixState
+ }),
+ createdAt: now,
+ addressEndpoint: addressEndpoint,
+ }
+
+ state.deprecationJob.Schedule(pl)
+ state.invalidationJob.Schedule(vl)
+ state.regenJob.Schedule(pl - ndp.configs.RegenAdvanceDuration)
+
+ prefixState.generationAttempts++
+ prefixState.tempAddrs[generatedAddr.Address] = state
+
+ return true
+}
+
+// regenerateTempSLAACAddr regenerates a temporary address for a SLAAC prefix.
+//
+// The IPv6 endpoint that ndp belongs to MUST be locked.
+func (ndp *ndpState) regenerateTempSLAACAddr(prefix tcpip.Subnet, resetGenAttempts bool) {
+ state, ok := ndp.slaacPrefixes[prefix]
+ if !ok {
+ panic(fmt.Sprintf("ndp: SLAAC prefix state not found to regenerate temporary address for %s", prefix))
+ }
+
+ ndp.generateTempSLAACAddr(prefix, &state, resetGenAttempts)
+ ndp.slaacPrefixes[prefix] = state
+}
+
+// refreshSLAACPrefixLifetimes refreshes the lifetimes of a SLAAC prefix.
+//
+// pl is the new preferred lifetime. vl is the new valid lifetime.
+//
+// The IPv6 endpoint that ndp belongs to MUST be locked.
+func (ndp *ndpState) refreshSLAACPrefixLifetimes(prefix tcpip.Subnet, prefixState *slaacPrefixState, pl, vl time.Duration) {
+ // If the preferred lifetime is zero, then the prefix should be deprecated.
+ deprecated := pl == 0
+ if deprecated {
+ ndp.deprecateSLAACAddress(prefixState.stableAddr.addressEndpoint)
+ } else {
+ prefixState.stableAddr.addressEndpoint.SetDeprecated(false)
+ }
+
+ // If prefix was preferred for some finite lifetime before, cancel the
+ // deprecation job so it can be reset.
+ prefixState.deprecationJob.Cancel()
+
+ now := time.Now()
+
+ // Schedule the deprecation job if prefix has a finite preferred lifetime.
+ if pl < header.NDPInfiniteLifetime {
+ if !deprecated {
+ prefixState.deprecationJob.Schedule(pl)
+ }
+ prefixState.preferredUntil = now.Add(pl)
+ } else {
+ prefixState.preferredUntil = time.Time{}
+ }
+
+ // As per RFC 4862 section 5.5.3.e, update the valid lifetime for prefix:
+ //
+ // 1) If the received Valid Lifetime is greater than 2 hours or greater than
+ // RemainingLifetime, set the valid lifetime of the prefix to the
+ // advertised Valid Lifetime.
+ //
+ // 2) If RemainingLifetime is less than or equal to 2 hours, ignore the
+ // advertised Valid Lifetime.
+ //
+ // 3) Otherwise, reset the valid lifetime of the prefix to 2 hours.
+
+ if vl >= header.NDPInfiniteLifetime {
+ // Handle the infinite valid lifetime separately as we do not schedule a
+ // job in this case.
+ prefixState.invalidationJob.Cancel()
+ prefixState.validUntil = time.Time{}
+ } else {
+ var effectiveVl time.Duration
+ var rl time.Duration
+
+ // If the prefix was originally set to be valid forever, assume the
+ // remaining time to be the maximum possible value.
+ if prefixState.validUntil == (time.Time{}) {
+ rl = header.NDPInfiniteLifetime
+ } else {
+ rl = time.Until(prefixState.validUntil)
+ }
+
+ if vl > MinPrefixInformationValidLifetimeForUpdate || vl > rl {
+ effectiveVl = vl
+ } else if rl > MinPrefixInformationValidLifetimeForUpdate {
+ effectiveVl = MinPrefixInformationValidLifetimeForUpdate
+ }
+
+ if effectiveVl != 0 {
+ prefixState.invalidationJob.Cancel()
+ prefixState.invalidationJob.Schedule(effectiveVl)
+ prefixState.validUntil = now.Add(effectiveVl)
+ }
+ }
+
+ // If DAD is not yet complete on the stable address, there is no need to do
+ // work with temporary addresses.
+ if prefixState.stableAddr.addressEndpoint.GetKind() != stack.Permanent {
+ return
+ }
+
+ // Note, we do not need to update the entries in the temporary address map
+ // after updating the jobs because the jobs are held as pointers.
+ var regenForAddr tcpip.Address
+ allAddressesRegenerated := true
+ for tempAddr, tempAddrState := range prefixState.tempAddrs {
+ // As per RFC 4941 section 3.3 step 4, the valid lifetime of a temporary
+ // address is the lower of the valid lifetime of the stable address or the
+ // maximum temporary address valid lifetime. Note, the valid lifetime of a
+ // temporary address is relative to the address's creation time.
+ validUntil := tempAddrState.createdAt.Add(ndp.configs.MaxTempAddrValidLifetime)
+ if prefixState.validUntil != (time.Time{}) && validUntil.Sub(prefixState.validUntil) > 0 {
+ validUntil = prefixState.validUntil
+ }
+
+ // If the address is no longer valid, invalidate it immediately. Otherwise,
+ // reset the invalidation job.
+ newValidLifetime := validUntil.Sub(now)
+ if newValidLifetime <= 0 {
+ ndp.invalidateTempSLAACAddr(prefixState.tempAddrs, tempAddr, tempAddrState)
+ continue
+ }
+ tempAddrState.invalidationJob.Cancel()
+ tempAddrState.invalidationJob.Schedule(newValidLifetime)
+
+ // As per RFC 4941 section 3.3 step 4, the preferred lifetime of a temporary
+ // address is the lower of the preferred lifetime of the stable address or
+ // the maximum temporary address preferred lifetime - the temporary address
+ // desync factor. Note, the preferred lifetime of a temporary address is
+ // relative to the address's creation time.
+ preferredUntil := tempAddrState.createdAt.Add(ndp.configs.MaxTempAddrPreferredLifetime - ndp.temporaryAddressDesyncFactor)
+ if prefixState.preferredUntil != (time.Time{}) && preferredUntil.Sub(prefixState.preferredUntil) > 0 {
+ preferredUntil = prefixState.preferredUntil
+ }
+
+ // If the address is no longer preferred, deprecate it immediately.
+ // Otherwise, schedule the deprecation job again.
+ newPreferredLifetime := preferredUntil.Sub(now)
+ tempAddrState.deprecationJob.Cancel()
+ if newPreferredLifetime <= 0 {
+ ndp.deprecateSLAACAddress(tempAddrState.addressEndpoint)
+ } else {
+ tempAddrState.addressEndpoint.SetDeprecated(false)
+ tempAddrState.deprecationJob.Schedule(newPreferredLifetime)
+ }
+
+ tempAddrState.regenJob.Cancel()
+ if tempAddrState.regenerated {
+ } else {
+ allAddressesRegenerated = false
+
+ if newPreferredLifetime <= ndp.configs.RegenAdvanceDuration {
+ // The new preferred lifetime is less than the advance regeneration
+ // duration so regenerate an address for this temporary address
+ // immediately after we finish iterating over the temporary addresses.
+ regenForAddr = tempAddr
+ } else {
+ tempAddrState.regenJob.Schedule(newPreferredLifetime - ndp.configs.RegenAdvanceDuration)
+ }
+ }
+ }
+
+ // Generate a new temporary address if all of the existing temporary addresses
+ // have been regenerated, or we need to immediately regenerate an address
+ // due to an update in preferred lifetime.
+ //
+ // If each temporay address has already been regenerated, no new temporary
+ // address is generated. To ensure continuation of temporary SLAAC addresses,
+ // we manually try to regenerate an address here.
+ if len(regenForAddr) != 0 || allAddressesRegenerated {
+ // Reset the generation attempts counter as we are starting the generation
+ // of a new address for the SLAAC prefix.
+ if state, ok := prefixState.tempAddrs[regenForAddr]; ndp.generateTempSLAACAddr(prefix, prefixState, true /* resetGenAttempts */) && ok {
+ state.regenerated = true
+ prefixState.tempAddrs[regenForAddr] = state
+ }
+ }
+}
+
+// deprecateSLAACAddress marks the address as deprecated and notifies the NDP
+// dispatcher that address has been deprecated.
+//
+// deprecateSLAACAddress does nothing if the address is already deprecated.
+//
+// The IPv6 endpoint that ndp belongs to MUST be locked.
+func (ndp *ndpState) deprecateSLAACAddress(addressEndpoint stack.AddressEndpoint) {
+ if addressEndpoint.Deprecated() {
+ return
+ }
+
+ addressEndpoint.SetDeprecated(true)
+ if ndpDisp := ndp.ep.protocol.ndpDisp; ndpDisp != nil {
+ ndpDisp.OnAutoGenAddressDeprecated(ndp.ep.nic.ID(), addressEndpoint.AddressWithPrefix())
+ }
+}
+
+// invalidateSLAACPrefix invalidates a SLAAC prefix.
+//
+// The IPv6 endpoint that ndp belongs to MUST be locked.
+func (ndp *ndpState) invalidateSLAACPrefix(prefix tcpip.Subnet, state slaacPrefixState) {
+ ndp.cleanupSLAACPrefixResources(prefix, state)
+
+ if addressEndpoint := state.stableAddr.addressEndpoint; addressEndpoint != nil {
+ // Since we are already invalidating the prefix, do not invalidate the
+ // prefix when removing the address.
+ if err := ndp.ep.removePermanentEndpointLocked(addressEndpoint, false /* allowSLAACInvalidation */); err != nil {
+ panic(fmt.Sprintf("ndp: error removing stable SLAAC address %s: %s", addressEndpoint.AddressWithPrefix(), err))
+ }
+ }
+}
+
+// cleanupSLAACAddrResourcesAndNotify cleans up an invalidated SLAAC address's
+// resources.
+//
+// The IPv6 endpoint that ndp belongs to MUST be locked.
+func (ndp *ndpState) cleanupSLAACAddrResourcesAndNotify(addr tcpip.AddressWithPrefix, invalidatePrefix bool) {
+ if ndpDisp := ndp.ep.protocol.ndpDisp; ndpDisp != nil {
+ ndpDisp.OnAutoGenAddressInvalidated(ndp.ep.nic.ID(), addr)
+ }
+
+ prefix := addr.Subnet()
+ state, ok := ndp.slaacPrefixes[prefix]
+ if !ok || state.stableAddr.addressEndpoint == nil || addr.Address != state.stableAddr.addressEndpoint.AddressWithPrefix().Address {
+ return
+ }
+
+ if !invalidatePrefix {
+ // If the prefix is not being invalidated, disassociate the address from the
+ // prefix and do nothing further.
+ state.stableAddr.addressEndpoint.DecRef()
+ state.stableAddr.addressEndpoint = nil
+ ndp.slaacPrefixes[prefix] = state
+ return
+ }
+
+ ndp.cleanupSLAACPrefixResources(prefix, state)
+}
+
+// cleanupSLAACPrefixResources cleans up a SLAAC prefix's jobs and entry.
+//
+// Panics if the SLAAC prefix is not known.
+//
+// The IPv6 endpoint that ndp belongs to MUST be locked.
+func (ndp *ndpState) cleanupSLAACPrefixResources(prefix tcpip.Subnet, state slaacPrefixState) {
+ // Invalidate all temporary addresses.
+ for tempAddr, tempAddrState := range state.tempAddrs {
+ ndp.invalidateTempSLAACAddr(state.tempAddrs, tempAddr, tempAddrState)
+ }
+
+ if state.stableAddr.addressEndpoint != nil {
+ state.stableAddr.addressEndpoint.DecRef()
+ state.stableAddr.addressEndpoint = nil
+ }
+ state.deprecationJob.Cancel()
+ state.invalidationJob.Cancel()
+ delete(ndp.slaacPrefixes, prefix)
+}
+
+// invalidateTempSLAACAddr invalidates a temporary SLAAC address.
+//
+// The IPv6 endpoint that ndp belongs to MUST be locked.
+func (ndp *ndpState) invalidateTempSLAACAddr(tempAddrs map[tcpip.Address]tempSLAACAddrState, tempAddr tcpip.Address, tempAddrState tempSLAACAddrState) {
+ // Since we are already invalidating the address, do not invalidate the
+ // address when removing the address.
+ if err := ndp.ep.removePermanentEndpointLocked(tempAddrState.addressEndpoint, false /* allowSLAACInvalidation */); err != nil {
+ panic(fmt.Sprintf("error removing temporary SLAAC address %s: %s", tempAddrState.addressEndpoint.AddressWithPrefix(), err))
+ }
+
+ ndp.cleanupTempSLAACAddrResources(tempAddrs, tempAddr, tempAddrState)
+}
+
+// cleanupTempSLAACAddrResourcesAndNotify cleans up an invalidated temporary
+// SLAAC address's resources from ndp.
+//
+// The IPv6 endpoint that ndp belongs to MUST be locked.
+func (ndp *ndpState) cleanupTempSLAACAddrResourcesAndNotify(addr tcpip.AddressWithPrefix, invalidateAddr bool) {
+ if ndpDisp := ndp.ep.protocol.ndpDisp; ndpDisp != nil {
+ ndpDisp.OnAutoGenAddressInvalidated(ndp.ep.nic.ID(), addr)
+ }
+
+ if !invalidateAddr {
+ return
+ }
+
+ prefix := addr.Subnet()
+ state, ok := ndp.slaacPrefixes[prefix]
+ if !ok {
+ panic(fmt.Sprintf("ndp: must have a slaacPrefixes entry to clean up temp addr %s resources", addr))
+ }
+
+ tempAddrState, ok := state.tempAddrs[addr.Address]
+ if !ok {
+ panic(fmt.Sprintf("ndp: must have a tempAddr entry to clean up temp addr %s resources", addr))
+ }
+
+ ndp.cleanupTempSLAACAddrResources(state.tempAddrs, addr.Address, tempAddrState)
+}
+
+// cleanupTempSLAACAddrResourcesAndNotify cleans up a temporary SLAAC address's
+// jobs and entry.
+//
+// The IPv6 endpoint that ndp belongs to MUST be locked.
+func (ndp *ndpState) cleanupTempSLAACAddrResources(tempAddrs map[tcpip.Address]tempSLAACAddrState, tempAddr tcpip.Address, tempAddrState tempSLAACAddrState) {
+ tempAddrState.addressEndpoint.DecRef()
+ tempAddrState.addressEndpoint = nil
+ tempAddrState.deprecationJob.Cancel()
+ tempAddrState.invalidationJob.Cancel()
+ tempAddrState.regenJob.Cancel()
+ delete(tempAddrs, tempAddr)
+}
+
+// removeSLAACAddresses removes all SLAAC addresses.
+//
+// If keepLinkLocal is false, the SLAAC generated link-local address is removed.
+//
+// The IPv6 endpoint that ndp belongs to MUST be locked.
+func (ndp *ndpState) removeSLAACAddresses(keepLinkLocal bool) {
+ linkLocalSubnet := header.IPv6LinkLocalPrefix.Subnet()
+ var linkLocalPrefixes int
+ for prefix, state := range ndp.slaacPrefixes {
+ // RFC 4862 section 5 states that routers are also expected to generate a
+ // link-local address so we do not invalidate them if we are cleaning up
+ // host-only state.
+ if keepLinkLocal && prefix == linkLocalSubnet {
+ linkLocalPrefixes++
+ continue
+ }
+
+ ndp.invalidateSLAACPrefix(prefix, state)
+ }
+
+ if got := len(ndp.slaacPrefixes); got != linkLocalPrefixes {
+ panic(fmt.Sprintf("ndp: still have non-linklocal SLAAC prefixes after cleaning up; found = %d prefixes, of which %d are link-local", got, linkLocalPrefixes))
+ }
+}
+
+// cleanupState cleans up ndp's state.
+//
+// If hostOnly is true, then only host-specific state is cleaned up.
+//
+// This function invalidates all discovered on-link prefixes, discovered
+// routers, and auto-generated addresses.
+//
+// If hostOnly is true, then the link-local auto-generated address aren't
+// invalidated as routers are also expected to generate a link-local address.
+//
+// The IPv6 endpoint that ndp belongs to MUST be locked.
+func (ndp *ndpState) cleanupState(hostOnly bool) {
+ ndp.removeSLAACAddresses(hostOnly /* keepLinkLocal */)
+
+ for prefix := range ndp.onLinkPrefixes {
+ ndp.invalidateOnLinkPrefix(prefix)
+ }
+
+ if got := len(ndp.onLinkPrefixes); got != 0 {
+ panic(fmt.Sprintf("ndp: still have discovered on-link prefixes after cleaning up; found = %d", got))
+ }
+
+ for router := range ndp.defaultRouters {
+ ndp.invalidateDefaultRouter(router)
+ }
+
+ if got := len(ndp.defaultRouters); got != 0 {
+ panic(fmt.Sprintf("ndp: still have discovered default routers after cleaning up; found = %d", got))
+ }
+
+ ndp.dhcpv6Configuration = 0
+}
+
+// startSolicitingRouters starts soliciting routers, as per RFC 4861 section
+// 6.3.7. If routers are already being solicited, this function does nothing.
+//
+// The IPv6 endpoint that ndp belongs to MUST be locked.
+func (ndp *ndpState) startSolicitingRouters() {
+ if ndp.rtrSolicit.timer != nil {
+ // We are already soliciting routers.
+ return
+ }
+
+ remaining := ndp.configs.MaxRtrSolicitations
+ if remaining == 0 {
+ return
+ }
+
+ // Calculate the random delay before sending our first RS, as per RFC
+ // 4861 section 6.3.7.
+ var delay time.Duration
+ if ndp.configs.MaxRtrSolicitationDelay > 0 {
+ delay = time.Duration(rand.Int63n(int64(ndp.configs.MaxRtrSolicitationDelay)))
+ }
+
+ var done bool
+ ndp.rtrSolicit.done = &done
+ ndp.rtrSolicit.timer = ndp.ep.protocol.stack.Clock().AfterFunc(delay, func() {
+ ndp.ep.mu.Lock()
+ if done {
+ // If we reach this point, it means that the RS timer fired after another
+ // goroutine already obtained the IPv6 endpoint lock and stopped
+ // solicitations. Simply return here and do nothing further.
+ ndp.ep.mu.Unlock()
+ return
+ }
+
+ // As per RFC 4861 section 4.1, the source of the RS is an address assigned
+ // to the sending interface, or the unspecified address if no address is
+ // assigned to the sending interface.
+ addressEndpoint := ndp.ep.acquireOutgoingPrimaryAddressRLocked(header.IPv6AllRoutersMulticastAddress, false)
+ if addressEndpoint == nil {
+ // Incase this ends up creating a new temporary address, we need to hold
+ // onto the endpoint until a route is obtained. If we decrement the
+ // reference count before obtaing a route, the address's resources would
+ // be released and attempting to obtain a route after would fail. Once a
+ // route is obtainted, it is safe to decrement the reference count since
+ // obtaining a route increments the address's reference count.
+ addressEndpoint = ndp.ep.acquireAddressOrCreateTempLocked(header.IPv6Any, true /* createTemp */, stack.NeverPrimaryEndpoint)
+ }
+ ndp.ep.mu.Unlock()
+
+ localAddr := addressEndpoint.AddressWithPrefix().Address
+ r, err := ndp.ep.protocol.stack.FindRoute(ndp.ep.nic.ID(), localAddr, header.IPv6AllRoutersMulticastAddress, ProtocolNumber, false /* multicastLoop */)
+ addressEndpoint.DecRef()
+ if err != nil {
+ return
+ }
+ defer r.Release()
+
+ // Route should resolve immediately since
+ // header.IPv6AllRoutersMulticastAddress is a multicast address so a
+ // remote link address can be calculated without a resolution process.
+ if c, err := r.Resolve(nil); err != nil {
+ // Do not consider the NIC being unknown or disabled as a fatal error.
+ // Since this method is required to be called when the IPv6 endpoint is
+ // not locked, the IPv6 endpoint could have been disabled or removed by
+ // another goroutine.
+ if err == tcpip.ErrUnknownNICID || err == tcpip.ErrInvalidEndpointState {
+ return
+ }
+
+ panic(fmt.Sprintf("ndp: error when resolving route to send NDP RS (%s -> %s on NIC(%d)): %s", header.IPv6Any, header.IPv6AllRoutersMulticastAddress, ndp.ep.nic.ID(), err))
+ } else if c != nil {
+ panic(fmt.Sprintf("ndp: route resolution not immediate for route to send NDP RS (%s -> %s on NIC(%d))", header.IPv6Any, header.IPv6AllRoutersMulticastAddress, ndp.ep.nic.ID()))
+ }
+
+ // As per RFC 4861 section 4.1, an NDP RS SHOULD include the source
+ // link-layer address option if the source address of the NDP RS is
+ // specified. This option MUST NOT be included if the source address is
+ // unspecified.
+ //
+ // TODO(b/141011931): Validate a LinkEndpoint's link address (provided by
+ // LinkEndpoint.LinkAddress) before reaching this point.
+ var optsSerializer header.NDPOptionsSerializer
+ if localAddr != header.IPv6Any && header.IsValidUnicastEthernetAddress(r.LocalLinkAddress) {
+ optsSerializer = header.NDPOptionsSerializer{
+ header.NDPSourceLinkLayerAddressOption(r.LocalLinkAddress),
+ }
+ }
+ payloadSize := header.ICMPv6HeaderSize + header.NDPRSMinimumSize + int(optsSerializer.Length())
+ icmpData := header.ICMPv6(buffer.NewView(payloadSize))
+ icmpData.SetType(header.ICMPv6RouterSolicit)
+ rs := header.NDPRouterSolicit(icmpData.NDPPayload())
+ rs.Options().Serialize(optsSerializer)
+ icmpData.SetChecksum(header.ICMPv6Checksum(icmpData, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{}))
+
+ pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
+ ReserveHeaderBytes: int(r.MaxHeaderLength()),
+ Data: buffer.View(icmpData).ToVectorisedView(),
+ })
+
+ sent := r.Stats().ICMP.V6PacketsSent
+ if err := r.WritePacket(nil,
+ stack.NetworkHeaderParams{
+ Protocol: header.ICMPv6ProtocolNumber,
+ TTL: header.NDPHopLimit,
+ }, pkt,
+ ); err != nil {
+ sent.Dropped.Increment()
+ log.Printf("startSolicitingRouters: error writing NDP router solicit message on NIC(%d); err = %s", ndp.ep.nic.ID(), err)
+ // Don't send any more messages if we had an error.
+ remaining = 0
+ } else {
+ sent.RouterSolicit.Increment()
+ remaining--
+ }
+
+ ndp.ep.mu.Lock()
+ if done || remaining == 0 {
+ ndp.rtrSolicit.timer = nil
+ ndp.rtrSolicit.done = nil
+ } else if ndp.rtrSolicit.timer != nil {
+ // Note, we need to explicitly check to make sure that
+ // the timer field is not nil because if it was nil but
+ // we still reached this point, then we know the IPv6 endpoint
+ // was requested to stop soliciting routers so we don't
+ // need to send the next Router Solicitation message.
+ ndp.rtrSolicit.timer.Reset(ndp.configs.RtrSolicitationInterval)
+ }
+ ndp.ep.mu.Unlock()
+ })
+
+}
+
+// stopSolicitingRouters stops soliciting routers. If routers are not currently
+// being solicited, this function does nothing.
+//
+// The IPv6 endpoint that ndp belongs to MUST be locked.
+func (ndp *ndpState) stopSolicitingRouters() {
+ if ndp.rtrSolicit.timer == nil {
+ // Nothing to do.
+ return
+ }
+
+ *ndp.rtrSolicit.done = true
+ ndp.rtrSolicit.timer.Stop()
+ ndp.rtrSolicit.timer = nil
+ ndp.rtrSolicit.done = nil
+}
+
+// initializeTempAddrState initializes state related to temporary SLAAC
+// addresses.
+func (ndp *ndpState) initializeTempAddrState() {
+ header.InitialTempIID(ndp.temporaryIIDHistory[:], ndp.ep.protocol.tempIIDSeed, ndp.ep.nic.ID())
+
+ if MaxDesyncFactor != 0 {
+ ndp.temporaryAddressDesyncFactor = time.Duration(rand.Int63n(int64(MaxDesyncFactor)))
+ }
+}
diff --git a/pkg/tcpip/network/ipv6/ndp_test.go b/pkg/tcpip/network/ipv6/ndp_test.go
index c93d1194f..9033a9ed5 100644
--- a/pkg/tcpip/network/ipv6/ndp_test.go
+++ b/pkg/tcpip/network/ipv6/ndp_test.go
@@ -17,6 +17,7 @@ package ipv6
import (
"strings"
"testing"
+ "time"
"github.com/google/go-cmp/cmp"
"gvisor.dev/gvisor/pkg/tcpip"
@@ -65,10 +66,94 @@ func setupStackAndEndpoint(t *testing.T, llladdr, rlladdr tcpip.Address, useNeig
t.Fatalf("cannot find protocol instance for network protocol %d", ProtocolNumber)
}
- ep := netProto.NewEndpoint(0, &stubLinkAddressCache{}, &stubNUDHandler{}, &stubDispatcher{}, nil, s)
+ ep := netProto.NewEndpoint(&testInterface{}, &stubLinkAddressCache{}, &stubNUDHandler{}, &stubDispatcher{})
+ if err := ep.Enable(); err != nil {
+ t.Fatalf("ep.Enable(): %s", err)
+ }
+ t.Cleanup(ep.Close)
+
return s, ep
}
+var _ NDPDispatcher = (*testNDPDispatcher)(nil)
+
+// testNDPDispatcher is an NDPDispatcher only allows default router discovery.
+type testNDPDispatcher struct {
+ addr tcpip.Address
+}
+
+func (*testNDPDispatcher) OnDuplicateAddressDetectionStatus(tcpip.NICID, tcpip.Address, bool, *tcpip.Error) {
+}
+
+func (t *testNDPDispatcher) OnDefaultRouterDiscovered(_ tcpip.NICID, addr tcpip.Address) bool {
+ t.addr = addr
+ return true
+}
+
+func (t *testNDPDispatcher) OnDefaultRouterInvalidated(_ tcpip.NICID, addr tcpip.Address) {
+ t.addr = addr
+}
+
+func (*testNDPDispatcher) OnOnLinkPrefixDiscovered(tcpip.NICID, tcpip.Subnet) bool {
+ return false
+}
+
+func (*testNDPDispatcher) OnOnLinkPrefixInvalidated(tcpip.NICID, tcpip.Subnet) {
+}
+
+func (*testNDPDispatcher) OnAutoGenAddress(tcpip.NICID, tcpip.AddressWithPrefix) bool {
+ return false
+}
+
+func (*testNDPDispatcher) OnAutoGenAddressDeprecated(tcpip.NICID, tcpip.AddressWithPrefix) {
+}
+
+func (*testNDPDispatcher) OnAutoGenAddressInvalidated(tcpip.NICID, tcpip.AddressWithPrefix) {
+}
+
+func (*testNDPDispatcher) OnRecursiveDNSServerOption(tcpip.NICID, []tcpip.Address, time.Duration) {
+}
+
+func (*testNDPDispatcher) OnDNSSearchListOption(tcpip.NICID, []string, time.Duration) {
+}
+
+func (*testNDPDispatcher) OnDHCPv6Configuration(tcpip.NICID, DHCPv6ConfigurationFromNDPRA) {
+}
+
+func TestStackNDPEndpointInvalidateDefaultRouter(t *testing.T) {
+ var ndpDisp testNDPDispatcher
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocolWithOptions(Options{
+ NDPDisp: &ndpDisp,
+ })},
+ })
+
+ if err := s.CreateNIC(nicID, &stubLinkEndpoint{}); err != nil {
+ t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err)
+ }
+
+ ep, err := s.GetNetworkEndpoint(nicID, ProtocolNumber)
+ if err != nil {
+ t.Fatalf("s.GetNetworkEndpoint(%d, %d): %s", nicID, ProtocolNumber, err)
+ }
+
+ ipv6EP := ep.(*endpoint)
+ ipv6EP.mu.Lock()
+ ipv6EP.mu.ndp.rememberDefaultRouter(lladdr1, time.Hour)
+ ipv6EP.mu.Unlock()
+
+ if ndpDisp.addr != lladdr1 {
+ t.Fatalf("got ndpDisp.addr = %s, want = %s", ndpDisp.addr, lladdr1)
+ }
+
+ ndpDisp.addr = ""
+ ndpEP := ep.(stack.NDPEndpoint)
+ ndpEP.InvalidateDefaultRouter(lladdr1)
+ if ndpDisp.addr != lladdr1 {
+ t.Fatalf("got ndpDisp.addr = %s, want = %s", ndpDisp.addr, lladdr1)
+ }
+}
+
// TestNeighorSolicitationWithSourceLinkLayerOption tests that receiving a
// valid NDP NS message with the Source Link Layer Address option results in a
// new entry in the link address cache for the sender of the message.
@@ -325,7 +410,7 @@ func TestNeighorSolicitationResponse(t *testing.T) {
naDst tcpip.Address
}{
{
- name: "Unspecified source to multicast destination",
+ name: "Unspecified source to solicited-node multicast destination",
nsOpts: nil,
nsSrcLinkAddr: remoteLinkAddr0,
nsSrc: header.IPv6Any,
@@ -352,11 +437,7 @@ func TestNeighorSolicitationResponse(t *testing.T) {
nsSrcLinkAddr: remoteLinkAddr0,
nsSrc: header.IPv6Any,
nsDst: nicAddr,
- nsInvalid: false,
- naDstLinkAddr: remoteLinkAddr0,
- naSolicited: false,
- naSrc: nicAddr,
- naDst: header.IPv6AllNodesMulticastAddress,
+ nsInvalid: true,
},
{
name: "Unspecified source with source ll option to unicast destination",
diff --git a/pkg/tcpip/network/testutil/BUILD b/pkg/tcpip/network/testutil/BUILD
index c9e57dc0d..d0ffc299a 100644
--- a/pkg/tcpip/network/testutil/BUILD
+++ b/pkg/tcpip/network/testutil/BUILD
@@ -8,6 +8,7 @@ go_library(
"testutil.go",
],
visibility = [
+ "//pkg/tcpip/network/fragmentation:__pkg__",
"//pkg/tcpip/network/ipv4:__pkg__",
"//pkg/tcpip/network/ipv6:__pkg__",
],