summaryrefslogtreecommitdiffhomepage
path: root/pkg
diff options
context:
space:
mode:
authorKevin Krakauer <krakauer@google.com>2020-02-06 17:07:04 -0800
committerKevin Krakauer <krakauer@google.com>2020-02-06 17:07:04 -0800
commitd98287f5eb40a9c91668b7511824c05d542e0599 (patch)
treef8430747db6e1c02fe0eb45a7dad0899d06bd072 /pkg
parentbf0ea204e9415a181c63ee10078cca753df14f7e (diff)
parent16561e461e82f8d846ef1f3ada990270ef39ccc6 (diff)
Merge branch 'master' into tcp-matchers-submit
Diffstat (limited to 'pkg')
-rw-r--r--pkg/metric/metric.go1
-rw-r--r--pkg/p9/BUILD3
-rw-r--r--pkg/p9/client.go9
-rw-r--r--pkg/pool/BUILD25
-rw-r--r--pkg/pool/pool.go (renamed from pkg/p9/pool.go)26
-rw-r--r--pkg/pool/pool_test.go (renamed from pkg/p9/pool_test.go)8
-rw-r--r--pkg/sentry/arch/arch_x86.go4
-rw-r--r--pkg/sentry/arch/signal_amd64.go2
-rw-r--r--pkg/sentry/fs/file_overlay_test.go1
-rw-r--r--pkg/sentry/fs/proc/README.md4
-rw-r--r--pkg/sentry/kernel/BUILD1
-rw-r--r--pkg/sentry/kernel/kernel.go3
-rw-r--r--pkg/sentry/kernel/kernel_opts.go20
-rw-r--r--pkg/sentry/socket/hostinet/BUILD1
-rw-r--r--pkg/sentry/socket/hostinet/socket.go5
-rw-r--r--pkg/sentry/socket/hostinet/sockopt_impl.go27
-rw-r--r--pkg/sentry/socket/netstack/netstack.go7
-rw-r--r--pkg/tcpip/adapters/gonet/gonet.go111
-rw-r--r--pkg/tcpip/adapters/gonet/gonet_test.go22
-rw-r--r--pkg/tcpip/network/ipv6/icmp.go44
-rw-r--r--pkg/tcpip/network/ipv6/icmp_test.go186
-rw-r--r--pkg/tcpip/network/ipv6/ndp_test.go278
-rw-r--r--pkg/tcpip/stack/BUILD6
-rw-r--r--pkg/tcpip/stack/ndp_test.go82
-rw-r--r--pkg/tcpip/stack/nic.go84
-rw-r--r--pkg/tcpip/stack/nic_test.go62
-rw-r--r--pkg/tcpip/stack/stack_test.go115
-rw-r--r--pkg/tcpip/tcpip.go12
-rw-r--r--pkg/tcpip/transport/tcp/endpoint.go7
-rw-r--r--pkg/tcpip/transport/tcp/tcp_test.go9
30 files changed, 839 insertions, 326 deletions
diff --git a/pkg/metric/metric.go b/pkg/metric/metric.go
index 93d4f2b8c..006fcd9ab 100644
--- a/pkg/metric/metric.go
+++ b/pkg/metric/metric.go
@@ -46,7 +46,6 @@ var (
//
// TODO(b/67298402): Support non-cumulative metrics.
// TODO(b/67298427): Support metric fields.
-//
type Uint64Metric struct {
// value is the actual value of the metric. It must be accessed
// atomically.
diff --git a/pkg/p9/BUILD b/pkg/p9/BUILD
index 4ccc1de86..8904afad9 100644
--- a/pkg/p9/BUILD
+++ b/pkg/p9/BUILD
@@ -16,7 +16,6 @@ go_library(
"messages.go",
"p9.go",
"path_tree.go",
- "pool.go",
"server.go",
"transport.go",
"transport_flipcall.go",
@@ -27,6 +26,7 @@ go_library(
"//pkg/fdchannel",
"//pkg/flipcall",
"//pkg/log",
+ "//pkg/pool",
"//pkg/sync",
"//pkg/unet",
"@org_golang_x_sys//unix:go_default_library",
@@ -41,7 +41,6 @@ go_test(
"client_test.go",
"messages_test.go",
"p9_test.go",
- "pool_test.go",
"transport_test.go",
"version_test.go",
],
diff --git a/pkg/p9/client.go b/pkg/p9/client.go
index 4045e41fa..a6f493b82 100644
--- a/pkg/p9/client.go
+++ b/pkg/p9/client.go
@@ -22,6 +22,7 @@ import (
"golang.org/x/sys/unix"
"gvisor.dev/gvisor/pkg/flipcall"
"gvisor.dev/gvisor/pkg/log"
+ "gvisor.dev/gvisor/pkg/pool"
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/unet"
)
@@ -74,10 +75,10 @@ type Client struct {
socket *unet.Socket
// tagPool is the collection of available tags.
- tagPool pool
+ tagPool pool.Pool
// fidPool is the collection of available fids.
- fidPool pool
+ fidPool pool.Pool
// messageSize is the maximum total size of a message.
messageSize uint32
@@ -155,8 +156,8 @@ func NewClient(socket *unet.Socket, messageSize uint32, version string) (*Client
}
c := &Client{
socket: socket,
- tagPool: pool{start: 1, limit: uint64(NoTag)},
- fidPool: pool{start: 1, limit: uint64(NoFID)},
+ tagPool: pool.Pool{Start: 1, Limit: uint64(NoTag)},
+ fidPool: pool.Pool{Start: 1, Limit: uint64(NoFID)},
pending: make(map[Tag]*response),
recvr: make(chan bool, 1),
messageSize: messageSize,
diff --git a/pkg/pool/BUILD b/pkg/pool/BUILD
new file mode 100644
index 000000000..7b1c6b75b
--- /dev/null
+++ b/pkg/pool/BUILD
@@ -0,0 +1,25 @@
+load("//tools:defs.bzl", "go_library", "go_test")
+
+package(
+ default_visibility = ["//visibility:public"],
+ licenses = ["notice"],
+)
+
+go_library(
+ name = "pool",
+ srcs = [
+ "pool.go",
+ ],
+ deps = [
+ "//pkg/sync",
+ ],
+)
+
+go_test(
+ name = "pool_test",
+ size = "small",
+ srcs = [
+ "pool_test.go",
+ ],
+ library = ":pool",
+)
diff --git a/pkg/p9/pool.go b/pkg/pool/pool.go
index 2b14a5ce3..a1b2e0cfe 100644
--- a/pkg/p9/pool.go
+++ b/pkg/pool/pool.go
@@ -12,33 +12,31 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-package p9
+package pool
import (
"gvisor.dev/gvisor/pkg/sync"
)
-// pool is a simple allocator.
-//
-// It is used for both tags and FIDs.
-type pool struct {
+// Pool is a simple allocator.
+type Pool struct {
mu sync.Mutex
// cache is the set of returned values.
cache []uint64
- // start is the starting value (if needed).
- start uint64
+ // Start is the starting value (if needed).
+ Start uint64
// max is the current maximum issued.
max uint64
- // limit is the upper limit.
- limit uint64
+ // Limit is the upper limit.
+ Limit uint64
}
// Get gets a value from the pool.
-func (p *pool) Get() (uint64, bool) {
+func (p *Pool) Get() (uint64, bool) {
p.mu.Lock()
defer p.mu.Unlock()
@@ -50,18 +48,18 @@ func (p *pool) Get() (uint64, bool) {
}
// Over the limit?
- if p.start == p.limit {
+ if p.Start == p.Limit {
return 0, false
}
// Generate a new value.
- v := p.start
- p.start++
+ v := p.Start
+ p.Start++
return v, true
}
// Put returns a value to the pool.
-func (p *pool) Put(v uint64) {
+func (p *Pool) Put(v uint64) {
p.mu.Lock()
p.cache = append(p.cache, v)
p.mu.Unlock()
diff --git a/pkg/p9/pool_test.go b/pkg/pool/pool_test.go
index e4746b8da..d928439c1 100644
--- a/pkg/p9/pool_test.go
+++ b/pkg/pool/pool_test.go
@@ -12,14 +12,14 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-package p9
+package pool
import (
"testing"
)
func TestPoolUnique(t *testing.T) {
- p := pool{start: 1, limit: 3}
+ p := Pool{Start: 1, Limit: 3}
got := make(map[uint64]bool)
for {
@@ -39,7 +39,7 @@ func TestPoolUnique(t *testing.T) {
}
func TestExausted(t *testing.T) {
- p := pool{start: 1, limit: 500}
+ p := Pool{Start: 1, Limit: 500}
for i := 0; i < 499; i++ {
_, ok := p.Get()
if !ok {
@@ -54,7 +54,7 @@ func TestExausted(t *testing.T) {
}
func TestPoolRecycle(t *testing.T) {
- p := pool{start: 1, limit: 500}
+ p := Pool{Start: 1, Limit: 500}
n1, _ := p.Get()
p.Put(n1)
n2, _ := p.Get()
diff --git a/pkg/sentry/arch/arch_x86.go b/pkg/sentry/arch/arch_x86.go
index a18093155..3db8bd34b 100644
--- a/pkg/sentry/arch/arch_x86.go
+++ b/pkg/sentry/arch/arch_x86.go
@@ -114,6 +114,10 @@ func newX86FPStateSlice() []byte {
size, align := cpuid.HostFeatureSet().ExtendedStateSize()
capacity := size
// Always use at least 4096 bytes.
+ //
+ // For the KVM platform, this state is a fixed 4096 bytes, so make sure
+ // that the underlying array is at _least_ that size otherwise we will
+ // corrupt random memory. This is not a pleasant thing to debug.
if capacity < 4096 {
capacity = 4096
}
diff --git a/pkg/sentry/arch/signal_amd64.go b/pkg/sentry/arch/signal_amd64.go
index 81b92bb43..6fb756f0e 100644
--- a/pkg/sentry/arch/signal_amd64.go
+++ b/pkg/sentry/arch/signal_amd64.go
@@ -55,7 +55,7 @@ type SignalContext64 struct {
Trapno uint64
Oldmask linux.SignalSet
Cr2 uint64
- // Pointer to a struct _fpstate.
+ // Pointer to a struct _fpstate. See b/33003106#comment8.
Fpstate uint64
Reserved [8]uint64
}
diff --git a/pkg/sentry/fs/file_overlay_test.go b/pkg/sentry/fs/file_overlay_test.go
index 02538bb4f..a76d87e3a 100644
--- a/pkg/sentry/fs/file_overlay_test.go
+++ b/pkg/sentry/fs/file_overlay_test.go
@@ -177,6 +177,7 @@ func TestReaddirRevalidation(t *testing.T) {
// TestReaddirOverlayFrozen tests that calling Readdir on an overlay file with
// a frozen dirent tree does not make Readdir calls to the underlying files.
+// This is a regression test for b/114808269.
func TestReaddirOverlayFrozen(t *testing.T) {
ctx := contexttest.Context(t)
diff --git a/pkg/sentry/fs/proc/README.md b/pkg/sentry/fs/proc/README.md
index 5d4ec6c7b..6667a0916 100644
--- a/pkg/sentry/fs/proc/README.md
+++ b/pkg/sentry/fs/proc/README.md
@@ -11,6 +11,8 @@ inconsistency, please file a bug.
The following files are implemented:
+<!-- mdformat off(don't wrap the table) -->
+
| File /proc/ | Content |
| :------------------------ | :---------------------------------------------------- |
| [cpuinfo](#cpuinfo) | Info about the CPU |
@@ -22,6 +24,8 @@ The following files are implemented:
| [uptime](#uptime) | Wall clock since boot, combined idle time of all cpus |
| [version](#version) | Kernel version |
+<!-- mdformat on -->
+
### cpuinfo
```bash
diff --git a/pkg/sentry/kernel/BUILD b/pkg/sentry/kernel/BUILD
index a27628c0a..2231d6973 100644
--- a/pkg/sentry/kernel/BUILD
+++ b/pkg/sentry/kernel/BUILD
@@ -91,6 +91,7 @@ go_library(
"fs_context.go",
"ipc_namespace.go",
"kernel.go",
+ "kernel_opts.go",
"kernel_state.go",
"pending_signals.go",
"pending_signals_list.go",
diff --git a/pkg/sentry/kernel/kernel.go b/pkg/sentry/kernel/kernel.go
index dcd6e91c4..3ee760ba2 100644
--- a/pkg/sentry/kernel/kernel.go
+++ b/pkg/sentry/kernel/kernel.go
@@ -235,6 +235,9 @@ type Kernel struct {
// events. This is initialized lazily on the first unimplemented
// syscall.
unimplementedSyscallEmitter eventchannel.Emitter `state:"nosave"`
+
+ // SpecialOpts contains special kernel options.
+ SpecialOpts
}
// InitKernelArgs holds arguments to Init.
diff --git a/pkg/sentry/kernel/kernel_opts.go b/pkg/sentry/kernel/kernel_opts.go
new file mode 100644
index 000000000..2e66ec587
--- /dev/null
+++ b/pkg/sentry/kernel/kernel_opts.go
@@ -0,0 +1,20 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package kernel
+
+// SpecialOpts contains non-standard options for the kernel.
+//
+// +stateify savable
+type SpecialOpts struct{}
diff --git a/pkg/sentry/socket/hostinet/BUILD b/pkg/sentry/socket/hostinet/BUILD
index 5a07d5d0e..023bad156 100644
--- a/pkg/sentry/socket/hostinet/BUILD
+++ b/pkg/sentry/socket/hostinet/BUILD
@@ -10,6 +10,7 @@ go_library(
"save_restore.go",
"socket.go",
"socket_unsafe.go",
+ "sockopt_impl.go",
"stack.go",
],
visibility = ["//pkg/sentry:internal"],
diff --git a/pkg/sentry/socket/hostinet/socket.go b/pkg/sentry/socket/hostinet/socket.go
index 34f63986f..de76388ac 100644
--- a/pkg/sentry/socket/hostinet/socket.go
+++ b/pkg/sentry/socket/hostinet/socket.go
@@ -285,7 +285,7 @@ func (s *socketOperations) GetSockOpt(t *kernel.Task, level int, name int, outPt
}
// Whitelist options and constrain option length.
- var optlen int
+ optlen := getSockOptLen(t, level, name)
switch level {
case linux.SOL_IP:
switch name {
@@ -330,7 +330,7 @@ func (s *socketOperations) GetSockOpt(t *kernel.Task, level int, name int, outPt
// SetSockOpt implements socket.Socket.SetSockOpt.
func (s *socketOperations) SetSockOpt(t *kernel.Task, level int, name int, opt []byte) *syserr.Error {
// Whitelist options and constrain option length.
- var optlen int
+ optlen := setSockOptLen(t, level, name)
switch level {
case linux.SOL_IP:
switch name {
@@ -353,6 +353,7 @@ func (s *socketOperations) SetSockOpt(t *kernel.Task, level int, name int, opt [
optlen = sizeofInt32
}
}
+
if optlen == 0 {
// Pretend to accept socket options we don't understand. This seems
// dangerous, but it's what netstack does...
diff --git a/pkg/sentry/socket/hostinet/sockopt_impl.go b/pkg/sentry/socket/hostinet/sockopt_impl.go
new file mode 100644
index 000000000..8a783712e
--- /dev/null
+++ b/pkg/sentry/socket/hostinet/sockopt_impl.go
@@ -0,0 +1,27 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package hostinet
+
+import (
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+)
+
+func getSockOptLen(t *kernel.Task, level, name int) int {
+ return 0 // No custom options.
+}
+
+func setSockOptLen(t *kernel.Task, level, name int) int {
+ return 0 // No custom options.
+}
diff --git a/pkg/sentry/socket/netstack/netstack.go b/pkg/sentry/socket/netstack/netstack.go
index 049d04bf2..ed2fbcceb 100644
--- a/pkg/sentry/socket/netstack/netstack.go
+++ b/pkg/sentry/socket/netstack/netstack.go
@@ -2229,11 +2229,16 @@ func (s *SocketOperations) coalescingRead(ctx context.Context, dst usermem.IOSeq
var copied int
// Copy as many views as possible into the user-provided buffer.
- for dst.NumBytes() != 0 {
+ for {
+ // Always do at least one fetchReadView, even if the number of bytes to
+ // read is 0.
err = s.fetchReadView()
if err != nil {
break
}
+ if dst.NumBytes() == 0 {
+ break
+ }
var n int
var e error
diff --git a/pkg/tcpip/adapters/gonet/gonet.go b/pkg/tcpip/adapters/gonet/gonet.go
index 711969b9b..6e0db2741 100644
--- a/pkg/tcpip/adapters/gonet/gonet.go
+++ b/pkg/tcpip/adapters/gonet/gonet.go
@@ -43,18 +43,28 @@ func (e *timeoutError) Error() string { return "i/o timeout" }
func (e *timeoutError) Timeout() bool { return true }
func (e *timeoutError) Temporary() bool { return true }
-// A Listener is a wrapper around a tcpip endpoint that implements
+// A TCPListener is a wrapper around a TCP tcpip.Endpoint that implements
// net.Listener.
-type Listener struct {
+type TCPListener struct {
stack *stack.Stack
ep tcpip.Endpoint
wq *waiter.Queue
cancel chan struct{}
}
-// NewListener creates a new Listener.
-func NewListener(s *stack.Stack, addr tcpip.FullAddress, network tcpip.NetworkProtocolNumber) (*Listener, error) {
- // Create TCP endpoint, bind it, then start listening.
+// NewTCPListener creates a new TCPListener from a listening tcpip.Endpoint.
+func NewTCPListener(s *stack.Stack, wq *waiter.Queue, ep tcpip.Endpoint) *TCPListener {
+ return &TCPListener{
+ stack: s,
+ ep: ep,
+ wq: wq,
+ cancel: make(chan struct{}),
+ }
+}
+
+// ListenTCP creates a new TCPListener.
+func ListenTCP(s *stack.Stack, addr tcpip.FullAddress, network tcpip.NetworkProtocolNumber) (*TCPListener, error) {
+ // Create a TCP endpoint, bind it, then start listening.
var wq waiter.Queue
ep, err := s.NewEndpoint(tcp.ProtocolNumber, network, &wq)
if err != nil {
@@ -81,28 +91,23 @@ func NewListener(s *stack.Stack, addr tcpip.FullAddress, network tcpip.NetworkPr
}
}
- return &Listener{
- stack: s,
- ep: ep,
- wq: &wq,
- cancel: make(chan struct{}),
- }, nil
+ return NewTCPListener(s, &wq, ep), nil
}
// Close implements net.Listener.Close.
-func (l *Listener) Close() error {
+func (l *TCPListener) Close() error {
l.ep.Close()
return nil
}
// Shutdown stops the HTTP server.
-func (l *Listener) Shutdown() {
+func (l *TCPListener) Shutdown() {
l.ep.Shutdown(tcpip.ShutdownWrite | tcpip.ShutdownRead)
close(l.cancel) // broadcast cancellation
}
// Addr implements net.Listener.Addr.
-func (l *Listener) Addr() net.Addr {
+func (l *TCPListener) Addr() net.Addr {
a, err := l.ep.GetLocalAddress()
if err != nil {
return nil
@@ -208,9 +213,9 @@ func (d *deadlineTimer) SetDeadline(t time.Time) error {
return nil
}
-// A Conn is a wrapper around a tcpip.Endpoint that implements the net.Conn
+// A TCPConn is a wrapper around a TCP tcpip.Endpoint that implements the net.Conn
// interface.
-type Conn struct {
+type TCPConn struct {
deadlineTimer
wq *waiter.Queue
@@ -228,9 +233,9 @@ type Conn struct {
read buffer.View
}
-// NewConn creates a new Conn.
-func NewConn(wq *waiter.Queue, ep tcpip.Endpoint) *Conn {
- c := &Conn{
+// NewTCPConn creates a new TCPConn.
+func NewTCPConn(wq *waiter.Queue, ep tcpip.Endpoint) *TCPConn {
+ c := &TCPConn{
wq: wq,
ep: ep,
}
@@ -239,7 +244,7 @@ func NewConn(wq *waiter.Queue, ep tcpip.Endpoint) *Conn {
}
// Accept implements net.Conn.Accept.
-func (l *Listener) Accept() (net.Conn, error) {
+func (l *TCPListener) Accept() (net.Conn, error) {
n, wq, err := l.ep.Accept()
if err == tcpip.ErrWouldBlock {
@@ -272,7 +277,7 @@ func (l *Listener) Accept() (net.Conn, error) {
}
}
- return NewConn(wq, n), nil
+ return NewTCPConn(wq, n), nil
}
type opErrorer interface {
@@ -323,7 +328,7 @@ func commonRead(ep tcpip.Endpoint, wq *waiter.Queue, deadline <-chan struct{}, a
}
// Read implements net.Conn.Read.
-func (c *Conn) Read(b []byte) (int, error) {
+func (c *TCPConn) Read(b []byte) (int, error) {
c.readMu.Lock()
defer c.readMu.Unlock()
@@ -352,7 +357,7 @@ func (c *Conn) Read(b []byte) (int, error) {
}
// Write implements net.Conn.Write.
-func (c *Conn) Write(b []byte) (int, error) {
+func (c *TCPConn) Write(b []byte) (int, error) {
deadline := c.writeCancel()
// Check if deadlineTimer has already expired.
@@ -431,7 +436,7 @@ func (c *Conn) Write(b []byte) (int, error) {
}
// Close implements net.Conn.Close.
-func (c *Conn) Close() error {
+func (c *TCPConn) Close() error {
c.ep.Close()
return nil
}
@@ -440,7 +445,7 @@ func (c *Conn) Close() error {
// should just use Close.
//
// A TCP Half-Close is performed the same as CloseRead for *net.TCPConn.
-func (c *Conn) CloseRead() error {
+func (c *TCPConn) CloseRead() error {
if terr := c.ep.Shutdown(tcpip.ShutdownRead); terr != nil {
return c.newOpError("close", errors.New(terr.String()))
}
@@ -451,7 +456,7 @@ func (c *Conn) CloseRead() error {
// should just use Close.
//
// A TCP Half-Close is performed the same as CloseWrite for *net.TCPConn.
-func (c *Conn) CloseWrite() error {
+func (c *TCPConn) CloseWrite() error {
if terr := c.ep.Shutdown(tcpip.ShutdownWrite); terr != nil {
return c.newOpError("close", errors.New(terr.String()))
}
@@ -459,7 +464,7 @@ func (c *Conn) CloseWrite() error {
}
// LocalAddr implements net.Conn.LocalAddr.
-func (c *Conn) LocalAddr() net.Addr {
+func (c *TCPConn) LocalAddr() net.Addr {
a, err := c.ep.GetLocalAddress()
if err != nil {
return nil
@@ -468,7 +473,7 @@ func (c *Conn) LocalAddr() net.Addr {
}
// RemoteAddr implements net.Conn.RemoteAddr.
-func (c *Conn) RemoteAddr() net.Addr {
+func (c *TCPConn) RemoteAddr() net.Addr {
a, err := c.ep.GetRemoteAddress()
if err != nil {
return nil
@@ -476,7 +481,7 @@ func (c *Conn) RemoteAddr() net.Addr {
return fullToTCPAddr(a)
}
-func (c *Conn) newOpError(op string, err error) *net.OpError {
+func (c *TCPConn) newOpError(op string, err error) *net.OpError {
return &net.OpError{
Op: op,
Net: "tcp",
@@ -494,14 +499,14 @@ func fullToUDPAddr(addr tcpip.FullAddress) *net.UDPAddr {
return &net.UDPAddr{IP: net.IP(addr.Addr), Port: int(addr.Port)}
}
-// DialTCP creates a new TCP Conn connected to the specified address.
-func DialTCP(s *stack.Stack, addr tcpip.FullAddress, network tcpip.NetworkProtocolNumber) (*Conn, error) {
+// DialTCP creates a new TCPConn connected to the specified address.
+func DialTCP(s *stack.Stack, addr tcpip.FullAddress, network tcpip.NetworkProtocolNumber) (*TCPConn, error) {
return DialContextTCP(context.Background(), s, addr, network)
}
-// DialContextTCP creates a new TCP Conn connected to the specified address
+// DialContextTCP creates a new TCPConn connected to the specified address
// with the option of adding cancellation and timeouts.
-func DialContextTCP(ctx context.Context, s *stack.Stack, addr tcpip.FullAddress, network tcpip.NetworkProtocolNumber) (*Conn, error) {
+func DialContextTCP(ctx context.Context, s *stack.Stack, addr tcpip.FullAddress, network tcpip.NetworkProtocolNumber) (*TCPConn, error) {
// Create TCP endpoint, then connect.
var wq waiter.Queue
ep, err := s.NewEndpoint(tcp.ProtocolNumber, network, &wq)
@@ -543,12 +548,12 @@ func DialContextTCP(ctx context.Context, s *stack.Stack, addr tcpip.FullAddress,
}
}
- return NewConn(&wq, ep), nil
+ return NewTCPConn(&wq, ep), nil
}
-// A PacketConn is a wrapper around a tcpip endpoint that implements
-// net.PacketConn.
-type PacketConn struct {
+// A UDPConn is a wrapper around a UDP tcpip.Endpoint that implements
+// net.Conn and net.PacketConn.
+type UDPConn struct {
deadlineTimer
stack *stack.Stack
@@ -556,9 +561,9 @@ type PacketConn struct {
wq *waiter.Queue
}
-// NewPacketConn creates a new PacketConn.
-func NewPacketConn(s *stack.Stack, wq *waiter.Queue, ep tcpip.Endpoint) *PacketConn {
- c := &PacketConn{
+// NewUDPConn creates a new UDPConn.
+func NewUDPConn(s *stack.Stack, wq *waiter.Queue, ep tcpip.Endpoint) *UDPConn {
+ c := &UDPConn{
stack: s,
ep: ep,
wq: wq,
@@ -567,12 +572,12 @@ func NewPacketConn(s *stack.Stack, wq *waiter.Queue, ep tcpip.Endpoint) *PacketC
return c
}
-// DialUDP creates a new PacketConn.
+// DialUDP creates a new UDPConn.
//
// If laddr is nil, a local address is automatically chosen.
//
-// If raddr is nil, the PacketConn is left unconnected.
-func DialUDP(s *stack.Stack, laddr, raddr *tcpip.FullAddress, network tcpip.NetworkProtocolNumber) (*PacketConn, error) {
+// If raddr is nil, the UDPConn is left unconnected.
+func DialUDP(s *stack.Stack, laddr, raddr *tcpip.FullAddress, network tcpip.NetworkProtocolNumber) (*UDPConn, error) {
var wq waiter.Queue
ep, err := s.NewEndpoint(udp.ProtocolNumber, network, &wq)
if err != nil {
@@ -591,7 +596,7 @@ func DialUDP(s *stack.Stack, laddr, raddr *tcpip.FullAddress, network tcpip.Netw
}
}
- c := NewPacketConn(s, &wq, ep)
+ c := NewUDPConn(s, &wq, ep)
if raddr != nil {
if err := c.ep.Connect(*raddr); err != nil {
@@ -608,11 +613,11 @@ func DialUDP(s *stack.Stack, laddr, raddr *tcpip.FullAddress, network tcpip.Netw
return c, nil
}
-func (c *PacketConn) newOpError(op string, err error) *net.OpError {
+func (c *UDPConn) newOpError(op string, err error) *net.OpError {
return c.newRemoteOpError(op, nil, err)
}
-func (c *PacketConn) newRemoteOpError(op string, remote net.Addr, err error) *net.OpError {
+func (c *UDPConn) newRemoteOpError(op string, remote net.Addr, err error) *net.OpError {
return &net.OpError{
Op: op,
Net: "udp",
@@ -623,7 +628,7 @@ func (c *PacketConn) newRemoteOpError(op string, remote net.Addr, err error) *ne
}
// RemoteAddr implements net.Conn.RemoteAddr.
-func (c *PacketConn) RemoteAddr() net.Addr {
+func (c *UDPConn) RemoteAddr() net.Addr {
a, err := c.ep.GetRemoteAddress()
if err != nil {
return nil
@@ -632,13 +637,13 @@ func (c *PacketConn) RemoteAddr() net.Addr {
}
// Read implements net.Conn.Read
-func (c *PacketConn) Read(b []byte) (int, error) {
+func (c *UDPConn) Read(b []byte) (int, error) {
bytesRead, _, err := c.ReadFrom(b)
return bytesRead, err
}
// ReadFrom implements net.PacketConn.ReadFrom.
-func (c *PacketConn) ReadFrom(b []byte) (int, net.Addr, error) {
+func (c *UDPConn) ReadFrom(b []byte) (int, net.Addr, error) {
deadline := c.readCancel()
var addr tcpip.FullAddress
@@ -650,12 +655,12 @@ func (c *PacketConn) ReadFrom(b []byte) (int, net.Addr, error) {
return copy(b, read), fullToUDPAddr(addr), nil
}
-func (c *PacketConn) Write(b []byte) (int, error) {
+func (c *UDPConn) Write(b []byte) (int, error) {
return c.WriteTo(b, nil)
}
// WriteTo implements net.PacketConn.WriteTo.
-func (c *PacketConn) WriteTo(b []byte, addr net.Addr) (int, error) {
+func (c *UDPConn) WriteTo(b []byte, addr net.Addr) (int, error) {
deadline := c.writeCancel()
// Check if deadline has already expired.
@@ -713,13 +718,13 @@ func (c *PacketConn) WriteTo(b []byte, addr net.Addr) (int, error) {
}
// Close implements net.PacketConn.Close.
-func (c *PacketConn) Close() error {
+func (c *UDPConn) Close() error {
c.ep.Close()
return nil
}
// LocalAddr implements net.PacketConn.LocalAddr.
-func (c *PacketConn) LocalAddr() net.Addr {
+func (c *UDPConn) LocalAddr() net.Addr {
a, err := c.ep.GetLocalAddress()
if err != nil {
return nil
diff --git a/pkg/tcpip/adapters/gonet/gonet_test.go b/pkg/tcpip/adapters/gonet/gonet_test.go
index ee077ae83..ea0a0409a 100644
--- a/pkg/tcpip/adapters/gonet/gonet_test.go
+++ b/pkg/tcpip/adapters/gonet/gonet_test.go
@@ -41,7 +41,7 @@ const (
)
func TestTimeouts(t *testing.T) {
- nc := NewConn(nil, nil)
+ nc := NewTCPConn(nil, nil)
dlfs := []struct {
name string
f func(time.Time) error
@@ -132,7 +132,7 @@ func TestCloseReader(t *testing.T) {
s.AddAddress(NICID, ipv4.ProtocolNumber, addr.Addr)
- l, e := NewListener(s, addr, ipv4.ProtocolNumber)
+ l, e := ListenTCP(s, addr, ipv4.ProtocolNumber)
if e != nil {
t.Fatalf("NewListener() = %v", e)
}
@@ -168,7 +168,7 @@ func TestCloseReader(t *testing.T) {
sender.close()
}
-// TestCloseReaderWithForwarder tests that Conn.Close() wakes Conn.Read() when
+// TestCloseReaderWithForwarder tests that TCPConn.Close wakes TCPConn.Read when
// using tcp.Forwarder.
func TestCloseReaderWithForwarder(t *testing.T) {
s, err := newLoopbackStack()
@@ -192,7 +192,7 @@ func TestCloseReaderWithForwarder(t *testing.T) {
defer ep.Close()
r.Complete(false)
- c := NewConn(&wq, ep)
+ c := NewTCPConn(&wq, ep)
// Give c.Read() a chance to block before closing the connection.
time.AfterFunc(time.Millisecond*50, func() {
@@ -238,7 +238,7 @@ func TestCloseRead(t *testing.T) {
defer ep.Close()
r.Complete(false)
- c := NewConn(&wq, ep)
+ c := NewTCPConn(&wq, ep)
buf := make([]byte, 256)
n, e := c.Read(buf)
@@ -257,7 +257,7 @@ func TestCloseRead(t *testing.T) {
if terr != nil {
t.Fatalf("connect() = %v", terr)
}
- c := NewConn(tc.wq, tc.ep)
+ c := NewTCPConn(tc.wq, tc.ep)
if err := c.CloseRead(); err != nil {
t.Errorf("c.CloseRead() = %v", err)
@@ -291,7 +291,7 @@ func TestCloseWrite(t *testing.T) {
defer ep.Close()
r.Complete(false)
- c := NewConn(&wq, ep)
+ c := NewTCPConn(&wq, ep)
n, e := c.Read(make([]byte, 256))
if n != 0 || e != io.EOF {
@@ -309,7 +309,7 @@ func TestCloseWrite(t *testing.T) {
if terr != nil {
t.Fatalf("connect() = %v", terr)
}
- c := NewConn(tc.wq, tc.ep)
+ c := NewTCPConn(tc.wq, tc.ep)
if err := c.CloseWrite(); err != nil {
t.Errorf("c.CloseWrite() = %v", err)
@@ -353,7 +353,7 @@ func TestUDPForwarder(t *testing.T) {
}
defer ep.Close()
- c := NewConn(&wq, ep)
+ c := NewTCPConn(&wq, ep)
buf := make([]byte, 256)
n, e := c.Read(buf)
@@ -396,7 +396,7 @@ func TestDeadlineChange(t *testing.T) {
s.AddAddress(NICID, ipv4.ProtocolNumber, addr.Addr)
- l, e := NewListener(s, addr, ipv4.ProtocolNumber)
+ l, e := ListenTCP(s, addr, ipv4.ProtocolNumber)
if e != nil {
t.Fatalf("NewListener() = %v", e)
}
@@ -541,7 +541,7 @@ func makePipe() (c1, c2 net.Conn, stop func(), err error) {
addr := tcpip.FullAddress{NICID, ip, 11211}
s.AddAddress(NICID, ipv4.ProtocolNumber, ip)
- l, err := NewListener(s, addr, ipv4.ProtocolNumber)
+ l, err := ListenTCP(s, addr, ipv4.ProtocolNumber)
if err != nil {
return nil, nil, nil, fmt.Errorf("NewListener: %v", err)
}
diff --git a/pkg/tcpip/network/ipv6/icmp.go b/pkg/tcpip/network/ipv6/icmp.go
index 60817d36d..45dc757c7 100644
--- a/pkg/tcpip/network/ipv6/icmp.go
+++ b/pkg/tcpip/network/ipv6/icmp.go
@@ -15,6 +15,8 @@
package ipv6
import (
+ "log"
+
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
@@ -194,7 +196,11 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, pkt tcpip.P
// TODO(b/148429853): Properly process the NS message and do Neighbor
// Unreachability Detection.
for {
- opt, done, _ := it.Next()
+ opt, done, err := it.Next()
+ if err != nil {
+ // This should never happen as Iter(true) above did not return an error.
+ log.Fatalf("unexpected error when iterating over NDP options: %s", err)
+ }
if done {
break
}
@@ -253,21 +259,25 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, pkt tcpip.P
}
na := header.NDPNeighborAdvert(h.NDPPayload())
+ it, err := na.Options().Iter(true)
+ if err != nil {
+ // If we have a malformed NDP NA option, drop the packet.
+ received.Invalid.Increment()
+ return
+ }
+
targetAddr := na.TargetAddress()
stack := r.Stack()
rxNICID := r.NICID()
- isTentative, err := stack.IsAddrTentative(rxNICID, targetAddr)
- if err != nil {
+ if isTentative, err := stack.IsAddrTentative(rxNICID, targetAddr); err != nil {
// We will only get an error if rxNICID is unrecognized,
// which should not happen. For now short-circuit this
// packet.
//
// TODO(b/141002840): Handle this better?
return
- }
-
- if isTentative {
+ } else if isTentative {
// We just got an NA from a node that owns an address we
// are performing DAD on, implying the address is not
// unique. In this case we let the stack know so it can
@@ -283,13 +293,29 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, pkt tcpip.P
// scenario is beyond the scope of RFC 4862. As such, we simply
// ignore such a scenario for now and proceed as normal.
//
+ // If the NA message has the target link layer option, update the link
+ // address cache with the link address for the target of the message.
+ //
// TODO(b/143147598): Handle the scenario described above. Also
// inform the netstack integration that a duplicate address was
// detected outside of DAD.
+ //
+ // TODO(b/148429853): Properly process the NA message and do Neighbor
+ // Unreachability Detection.
+ for {
+ opt, done, err := it.Next()
+ if err != nil {
+ // This should never happen as Iter(true) above did not return an error.
+ log.Fatalf("unexpected error when iterating over NDP options: %s", err)
+ }
+ if done {
+ break
+ }
- e.linkAddrCache.AddLinkAddress(e.nicID, targetAddr, r.RemoteLinkAddress)
- if targetAddr != r.RemoteAddress {
- e.linkAddrCache.AddLinkAddress(e.nicID, r.RemoteAddress, r.RemoteLinkAddress)
+ switch opt := opt.(type) {
+ case header.NDPTargetLinkLayerAddressOption:
+ e.linkAddrCache.AddLinkAddress(e.nicID, targetAddr, opt.EthernetAddress())
+ }
}
case header.ICMPv6EchoRequest:
diff --git a/pkg/tcpip/network/ipv6/icmp_test.go b/pkg/tcpip/network/ipv6/icmp_test.go
index d0e930e20..50c4b6474 100644
--- a/pkg/tcpip/network/ipv6/icmp_test.go
+++ b/pkg/tcpip/network/ipv6/icmp_test.go
@@ -121,21 +121,60 @@ func TestICMPCounts(t *testing.T) {
}
defer r.Release()
+ var tllData [header.NDPLinkLayerAddressSize]byte
+ header.NDPOptions(tllData[:]).Serialize(header.NDPOptionsSerializer{
+ header.NDPTargetLinkLayerAddressOption(linkAddr1),
+ })
+
types := []struct {
- typ header.ICMPv6Type
- size int
+ typ header.ICMPv6Type
+ size int
+ extraData []byte
}{
- {header.ICMPv6DstUnreachable, header.ICMPv6DstUnreachableMinimumSize},
- {header.ICMPv6PacketTooBig, header.ICMPv6PacketTooBigMinimumSize},
- {header.ICMPv6TimeExceeded, header.ICMPv6MinimumSize},
- {header.ICMPv6ParamProblem, header.ICMPv6MinimumSize},
- {header.ICMPv6EchoRequest, header.ICMPv6EchoMinimumSize},
- {header.ICMPv6EchoReply, header.ICMPv6EchoMinimumSize},
- {header.ICMPv6RouterSolicit, header.ICMPv6MinimumSize},
- {header.ICMPv6RouterAdvert, header.ICMPv6HeaderSize + header.NDPRAMinimumSize},
- {header.ICMPv6NeighborSolicit, header.ICMPv6NeighborSolicitMinimumSize},
- {header.ICMPv6NeighborAdvert, header.ICMPv6NeighborAdvertSize},
- {header.ICMPv6RedirectMsg, header.ICMPv6MinimumSize},
+ {
+ typ: header.ICMPv6DstUnreachable,
+ size: header.ICMPv6DstUnreachableMinimumSize,
+ },
+ {
+ typ: header.ICMPv6PacketTooBig,
+ size: header.ICMPv6PacketTooBigMinimumSize,
+ },
+ {
+ typ: header.ICMPv6TimeExceeded,
+ size: header.ICMPv6MinimumSize,
+ },
+ {
+ typ: header.ICMPv6ParamProblem,
+ size: header.ICMPv6MinimumSize,
+ },
+ {
+ typ: header.ICMPv6EchoRequest,
+ size: header.ICMPv6EchoMinimumSize,
+ },
+ {
+ typ: header.ICMPv6EchoReply,
+ size: header.ICMPv6EchoMinimumSize,
+ },
+ {
+ typ: header.ICMPv6RouterSolicit,
+ size: header.ICMPv6MinimumSize,
+ },
+ {
+ typ: header.ICMPv6RouterAdvert,
+ size: header.ICMPv6HeaderSize + header.NDPRAMinimumSize,
+ },
+ {
+ typ: header.ICMPv6NeighborSolicit,
+ size: header.ICMPv6NeighborSolicitMinimumSize},
+ {
+ typ: header.ICMPv6NeighborAdvert,
+ size: header.ICMPv6NeighborAdvertMinimumSize,
+ extraData: tllData[:],
+ },
+ {
+ typ: header.ICMPv6RedirectMsg,
+ size: header.ICMPv6MinimumSize,
+ },
}
handleIPv6Payload := func(hdr buffer.Prependable) {
@@ -154,10 +193,13 @@ func TestICMPCounts(t *testing.T) {
}
for _, typ := range types {
- hdr := buffer.NewPrependable(header.IPv6MinimumSize + typ.size)
+ extraDataLen := len(typ.extraData)
+ hdr := buffer.NewPrependable(header.IPv6MinimumSize + typ.size + extraDataLen)
+ extraData := buffer.View(hdr.Prepend(extraDataLen))
+ copy(extraData, typ.extraData)
pkt := header.ICMPv6(hdr.Prepend(typ.size))
pkt.SetType(typ.typ)
- pkt.SetChecksum(header.ICMPv6Checksum(pkt, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{}))
+ pkt.SetChecksum(header.ICMPv6Checksum(pkt, r.LocalAddress, r.RemoteAddress, extraData.ToVectorisedView()))
handleIPv6Payload(hdr)
}
@@ -372,97 +414,104 @@ func TestLinkResolution(t *testing.T) {
}
func TestICMPChecksumValidationSimple(t *testing.T) {
+ var tllData [header.NDPLinkLayerAddressSize]byte
+ header.NDPOptions(tllData[:]).Serialize(header.NDPOptionsSerializer{
+ header.NDPTargetLinkLayerAddressOption(linkAddr1),
+ })
+
types := []struct {
name string
typ header.ICMPv6Type
size int
+ extraData []byte
statCounter func(tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter
}{
{
- "DstUnreachable",
- header.ICMPv6DstUnreachable,
- header.ICMPv6DstUnreachableMinimumSize,
- func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
+ name: "DstUnreachable",
+ typ: header.ICMPv6DstUnreachable,
+ size: header.ICMPv6DstUnreachableMinimumSize,
+ statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
return stats.DstUnreachable
},
},
{
- "PacketTooBig",
- header.ICMPv6PacketTooBig,
- header.ICMPv6PacketTooBigMinimumSize,
- func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
+ name: "PacketTooBig",
+ typ: header.ICMPv6PacketTooBig,
+ size: header.ICMPv6PacketTooBigMinimumSize,
+ statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
return stats.PacketTooBig
},
},
{
- "TimeExceeded",
- header.ICMPv6TimeExceeded,
- header.ICMPv6MinimumSize,
- func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
+ name: "TimeExceeded",
+ typ: header.ICMPv6TimeExceeded,
+ size: header.ICMPv6MinimumSize,
+ statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
return stats.TimeExceeded
},
},
{
- "ParamProblem",
- header.ICMPv6ParamProblem,
- header.ICMPv6MinimumSize,
- func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
+ name: "ParamProblem",
+ typ: header.ICMPv6ParamProblem,
+ size: header.ICMPv6MinimumSize,
+ statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
return stats.ParamProblem
},
},
{
- "EchoRequest",
- header.ICMPv6EchoRequest,
- header.ICMPv6EchoMinimumSize,
- func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
+ name: "EchoRequest",
+ typ: header.ICMPv6EchoRequest,
+ size: header.ICMPv6EchoMinimumSize,
+ statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
return stats.EchoRequest
},
},
{
- "EchoReply",
- header.ICMPv6EchoReply,
- header.ICMPv6EchoMinimumSize,
- func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
+ name: "EchoReply",
+ typ: header.ICMPv6EchoReply,
+ size: header.ICMPv6EchoMinimumSize,
+ statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
return stats.EchoReply
},
},
{
- "RouterSolicit",
- header.ICMPv6RouterSolicit,
- header.ICMPv6MinimumSize,
- func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
+ name: "RouterSolicit",
+ typ: header.ICMPv6RouterSolicit,
+ size: header.ICMPv6MinimumSize,
+ statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
return stats.RouterSolicit
},
},
{
- "RouterAdvert",
- header.ICMPv6RouterAdvert,
- header.ICMPv6HeaderSize + header.NDPRAMinimumSize,
- func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
+ name: "RouterAdvert",
+ typ: header.ICMPv6RouterAdvert,
+ size: header.ICMPv6HeaderSize + header.NDPRAMinimumSize,
+ statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
return stats.RouterAdvert
},
},
{
- "NeighborSolicit",
- header.ICMPv6NeighborSolicit,
- header.ICMPv6NeighborSolicitMinimumSize,
- func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
+ name: "NeighborSolicit",
+ typ: header.ICMPv6NeighborSolicit,
+ size: header.ICMPv6NeighborSolicitMinimumSize,
+ statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
return stats.NeighborSolicit
},
},
{
- "NeighborAdvert",
- header.ICMPv6NeighborAdvert,
- header.ICMPv6NeighborAdvertSize,
- func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
+ name: "NeighborAdvert",
+ typ: header.ICMPv6NeighborAdvert,
+ size: header.ICMPv6NeighborAdvertMinimumSize,
+ extraData: tllData[:],
+ statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
return stats.NeighborAdvert
},
},
{
- "RedirectMsg",
- header.ICMPv6RedirectMsg,
- header.ICMPv6MinimumSize,
- func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
+ name: "RedirectMsg",
+ typ: header.ICMPv6RedirectMsg,
+ size: header.ICMPv6MinimumSize,
+ statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
return stats.RedirectMsg
},
},
@@ -494,16 +543,19 @@ func TestICMPChecksumValidationSimple(t *testing.T) {
)
}
- handleIPv6Payload := func(typ header.ICMPv6Type, size int, checksum bool) {
- hdr := buffer.NewPrependable(header.IPv6MinimumSize + size)
- pkt := header.ICMPv6(hdr.Prepend(size))
- pkt.SetType(typ)
+ handleIPv6Payload := func(checksum bool) {
+ extraDataLen := len(typ.extraData)
+ hdr := buffer.NewPrependable(header.IPv6MinimumSize + typ.size + extraDataLen)
+ extraData := buffer.View(hdr.Prepend(extraDataLen))
+ copy(extraData, typ.extraData)
+ pkt := header.ICMPv6(hdr.Prepend(typ.size))
+ pkt.SetType(typ.typ)
if checksum {
- pkt.SetChecksum(header.ICMPv6Checksum(pkt, lladdr1, lladdr0, buffer.VectorisedView{}))
+ pkt.SetChecksum(header.ICMPv6Checksum(pkt, lladdr1, lladdr0, extraData.ToVectorisedView()))
}
ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
ip.Encode(&header.IPv6Fields{
- PayloadLength: uint16(size),
+ PayloadLength: uint16(typ.size + extraDataLen),
NextHeader: uint8(header.ICMPv6ProtocolNumber),
HopLimit: header.NDPHopLimit,
SrcAddr: lladdr1,
@@ -528,7 +580,7 @@ func TestICMPChecksumValidationSimple(t *testing.T) {
// Without setting checksum, the incoming packet should
// be invalid.
- handleIPv6Payload(typ.typ, typ.size, false)
+ handleIPv6Payload(false)
if got := invalid.Value(); got != 1 {
t.Fatalf("got invalid = %d, want = 1", got)
}
@@ -538,7 +590,7 @@ func TestICMPChecksumValidationSimple(t *testing.T) {
}
// When checksum is set, it should be received.
- handleIPv6Payload(typ.typ, typ.size, true)
+ handleIPv6Payload(true)
if got := typStat.Value(); got != 1 {
t.Fatalf("got %s = %d, want = 1", typ.name, got)
}
diff --git a/pkg/tcpip/network/ipv6/ndp_test.go b/pkg/tcpip/network/ipv6/ndp_test.go
index bd732f93f..c9395de52 100644
--- a/pkg/tcpip/network/ipv6/ndp_test.go
+++ b/pkg/tcpip/network/ipv6/ndp_test.go
@@ -70,76 +70,29 @@ func setupStackAndEndpoint(t *testing.T, llladdr, rlladdr tcpip.Address) (*stack
return s, ep
}
-// TestNeighorSolicitationWithSourceLinkLayerOption tests that receiving an
-// NDP NS message with the Source Link Layer Address option results in a
+// TestNeighorSolicitationWithSourceLinkLayerOption tests that receiving a
+// valid NDP NS message with the Source Link Layer Address option results in a
// new entry in the link address cache for the sender of the message.
func TestNeighorSolicitationWithSourceLinkLayerOption(t *testing.T) {
const nicID = 1
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{NewProtocol()},
- })
- e := channel.New(0, 1280, linkAddr0)
- if err := s.CreateNIC(nicID, e); err != nil {
- t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
- }
- if err := s.AddAddress(nicID, ProtocolNumber, lladdr0); err != nil {
- t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ProtocolNumber, lladdr0, err)
- }
-
- ndpNSSize := header.ICMPv6NeighborSolicitMinimumSize + header.NDPLinkLayerAddressSize
- hdr := buffer.NewPrependable(header.IPv6MinimumSize + ndpNSSize)
- pkt := header.ICMPv6(hdr.Prepend(ndpNSSize))
- pkt.SetType(header.ICMPv6NeighborSolicit)
- ns := header.NDPNeighborSolicit(pkt.NDPPayload())
- ns.SetTargetAddress(lladdr0)
- ns.Options().Serialize(header.NDPOptionsSerializer{
- header.NDPSourceLinkLayerAddressOption(linkAddr1),
- })
- pkt.SetChecksum(header.ICMPv6Checksum(pkt, lladdr1, lladdr0, buffer.VectorisedView{}))
- payloadLength := hdr.UsedLength()
- ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
- ip.Encode(&header.IPv6Fields{
- PayloadLength: uint16(payloadLength),
- NextHeader: uint8(header.ICMPv6ProtocolNumber),
- HopLimit: 255,
- SrcAddr: lladdr1,
- DstAddr: lladdr0,
- })
- e.InjectInbound(ProtocolNumber, tcpip.PacketBuffer{
- Data: hdr.View().ToVectorisedView(),
- })
-
- linkAddr, c, err := s.GetLinkAddress(nicID, lladdr1, lladdr0, ProtocolNumber, nil)
- if err != nil {
- t.Errorf("s.GetLinkAddress(%d, %s, %s, %d, nil): %s", nicID, lladdr1, lladdr0, ProtocolNumber, err)
- }
- if c != nil {
- t.Errorf("got unexpected channel")
- }
- if linkAddr != linkAddr1 {
- t.Errorf("got link address = %s, want = %s", linkAddr, linkAddr1)
- }
-}
-
-// TestNeighorSolicitationWithInvalidSourceLinkLayerOption tests that receiving
-// an NDP NS message with an invalid Source Link Layer Address option does not
-// result in a new entry in the link address cache for the sender of the
-// message.
-func TestNeighorSolicitationWithInvalidSourceLinkLayerOption(t *testing.T) {
- const nicID = 1
-
tests := []struct {
- name string
- optsBuf []byte
+ name string
+ optsBuf []byte
+ expectedLinkAddr tcpip.LinkAddress
}{
{
+ name: "Valid",
+ optsBuf: []byte{1, 1, 2, 3, 4, 5, 6, 7},
+ expectedLinkAddr: "\x02\x03\x04\x05\x06\x07",
+ },
+ {
name: "Too Small",
- optsBuf: []byte{1, 1, 1, 2, 3, 4, 5},
+ optsBuf: []byte{1, 1, 2, 3, 4, 5, 6},
},
{
name: "Invalid Length",
- optsBuf: []byte{1, 2, 1, 2, 3, 4, 5, 6},
+ optsBuf: []byte{1, 2, 2, 3, 4, 5, 6, 7},
},
}
@@ -186,20 +139,138 @@ func TestNeighorSolicitationWithInvalidSourceLinkLayerOption(t *testing.T) {
Data: hdr.View().ToVectorisedView(),
})
- // Invalid count should have increased.
- if got := invalid.Value(); got != 1 {
- t.Fatalf("got invalid = %d, want = 1", got)
+ linkAddr, c, err := s.GetLinkAddress(nicID, lladdr1, lladdr0, ProtocolNumber, nil)
+ if linkAddr != test.expectedLinkAddr {
+ t.Errorf("got link address = %s, want = %s", linkAddr, test.expectedLinkAddr)
}
- linkAddr, c, err := s.GetLinkAddress(nicID, lladdr1, lladdr0, ProtocolNumber, nil)
- if err != tcpip.ErrWouldBlock {
- t.Errorf("got s.GetLinkAddress(%d, %s, %s, %d, nil) = (_, _, %v), want = (_, _, %s)", nicID, lladdr1, lladdr0, ProtocolNumber, err, tcpip.ErrWouldBlock)
+ if test.expectedLinkAddr != "" {
+ if err != nil {
+ t.Errorf("s.GetLinkAddress(%d, %s, %s, %d, nil): %s", nicID, lladdr1, lladdr0, ProtocolNumber, err)
+ }
+ if c != nil {
+ t.Errorf("got unexpected channel")
+ }
+
+ // Invalid count should not have increased.
+ if got := invalid.Value(); got != 0 {
+ t.Errorf("got invalid = %d, want = 0", got)
+ }
+ } else {
+ if err != tcpip.ErrWouldBlock {
+ t.Errorf("got s.GetLinkAddress(%d, %s, %s, %d, nil) = (_, _, %v), want = (_, _, %s)", nicID, lladdr1, lladdr0, ProtocolNumber, err, tcpip.ErrWouldBlock)
+ }
+ if c == nil {
+ t.Errorf("expected channel from call to s.GetLinkAddress(%d, %s, %s, %d, nil)", nicID, lladdr1, lladdr0, ProtocolNumber)
+ }
+
+ // Invalid count should have increased.
+ if got := invalid.Value(); got != 1 {
+ t.Errorf("got invalid = %d, want = 1", got)
+ }
+ }
+ })
+ }
+}
+
+// TestNeighorAdvertisementWithTargetLinkLayerOption tests that receiving a
+// valid NDP NA message with the Target Link Layer Address option results in a
+// new entry in the link address cache for the target of the message.
+func TestNeighorAdvertisementWithTargetLinkLayerOption(t *testing.T) {
+ const nicID = 1
+
+ tests := []struct {
+ name string
+ optsBuf []byte
+ expectedLinkAddr tcpip.LinkAddress
+ }{
+ {
+ name: "Valid",
+ optsBuf: []byte{2, 1, 2, 3, 4, 5, 6, 7},
+ expectedLinkAddr: "\x02\x03\x04\x05\x06\x07",
+ },
+ {
+ name: "Too Small",
+ optsBuf: []byte{2, 1, 2, 3, 4, 5, 6},
+ },
+ {
+ name: "Invalid Length",
+ optsBuf: []byte{2, 2, 2, 3, 4, 5, 6, 7},
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{NewProtocol()},
+ })
+ e := channel.New(0, 1280, linkAddr0)
+ if err := s.CreateNIC(nicID, e); err != nil {
+ t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
+ }
+ if err := s.AddAddress(nicID, ProtocolNumber, lladdr0); err != nil {
+ t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ProtocolNumber, lladdr0, err)
+ }
+
+ ndpNASize := header.ICMPv6NeighborAdvertMinimumSize + len(test.optsBuf)
+ hdr := buffer.NewPrependable(header.IPv6MinimumSize + ndpNASize)
+ pkt := header.ICMPv6(hdr.Prepend(ndpNASize))
+ pkt.SetType(header.ICMPv6NeighborAdvert)
+ ns := header.NDPNeighborAdvert(pkt.NDPPayload())
+ ns.SetTargetAddress(lladdr1)
+ opts := ns.Options()
+ copy(opts, test.optsBuf)
+ pkt.SetChecksum(header.ICMPv6Checksum(pkt, lladdr1, lladdr0, buffer.VectorisedView{}))
+ payloadLength := hdr.UsedLength()
+ ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
+ ip.Encode(&header.IPv6Fields{
+ PayloadLength: uint16(payloadLength),
+ NextHeader: uint8(header.ICMPv6ProtocolNumber),
+ HopLimit: 255,
+ SrcAddr: lladdr1,
+ DstAddr: lladdr0,
+ })
+
+ invalid := s.Stats().ICMP.V6PacketsReceived.Invalid
+
+ // Invalid count should initially be 0.
+ if got := invalid.Value(); got != 0 {
+ t.Fatalf("got invalid = %d, want = 0", got)
}
- if c == nil {
- t.Errorf("expected channel from call to s.GetLinkAddress(%d, %s, %s, %d, nil)", nicID, lladdr1, lladdr0, ProtocolNumber)
+
+ e.InjectInbound(ProtocolNumber, tcpip.PacketBuffer{
+ Data: hdr.View().ToVectorisedView(),
+ })
+
+ linkAddr, c, err := s.GetLinkAddress(nicID, lladdr1, lladdr0, ProtocolNumber, nil)
+ if linkAddr != test.expectedLinkAddr {
+ t.Errorf("got link address = %s, want = %s", linkAddr, test.expectedLinkAddr)
}
- if linkAddr != "" {
- t.Errorf("got s.GetLinkAddress(%d, %s, %s, %d, nil) = (%s, _, ), want = ('', _, _)", nicID, lladdr1, lladdr0, ProtocolNumber, linkAddr)
+
+ if test.expectedLinkAddr != "" {
+ if err != nil {
+ t.Errorf("s.GetLinkAddress(%d, %s, %s, %d, nil): %s", nicID, lladdr1, lladdr0, ProtocolNumber, err)
+ }
+ if c != nil {
+ t.Errorf("got unexpected channel")
+ }
+
+ // Invalid count should not have increased.
+ if got := invalid.Value(); got != 0 {
+ t.Errorf("got invalid = %d, want = 0", got)
+ }
+ } else {
+ if err != tcpip.ErrWouldBlock {
+ t.Errorf("got s.GetLinkAddress(%d, %s, %s, %d, nil) = (_, _, %v), want = (_, _, %s)", nicID, lladdr1, lladdr0, ProtocolNumber, err, tcpip.ErrWouldBlock)
+ }
+ if c == nil {
+ t.Errorf("expected channel from call to s.GetLinkAddress(%d, %s, %s, %d, nil)", nicID, lladdr1, lladdr0, ProtocolNumber)
+ }
+
+ // Invalid count should have increased.
+ if got := invalid.Value(); got != 1 {
+ t.Errorf("got invalid = %d, want = 1", got)
+ }
}
})
}
@@ -238,27 +309,59 @@ func TestHopLimitValidation(t *testing.T) {
})
}
+ var tllData [header.NDPLinkLayerAddressSize]byte
+ header.NDPOptions(tllData[:]).Serialize(header.NDPOptionsSerializer{
+ header.NDPTargetLinkLayerAddressOption(linkAddr1),
+ })
+
types := []struct {
name string
typ header.ICMPv6Type
size int
+ extraData []byte
statCounter func(tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter
}{
- {"RouterSolicit", header.ICMPv6RouterSolicit, header.ICMPv6MinimumSize, func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
- return stats.RouterSolicit
- }},
- {"RouterAdvert", header.ICMPv6RouterAdvert, header.ICMPv6HeaderSize + header.NDPRAMinimumSize, func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
- return stats.RouterAdvert
- }},
- {"NeighborSolicit", header.ICMPv6NeighborSolicit, header.ICMPv6NeighborSolicitMinimumSize, func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
- return stats.NeighborSolicit
- }},
- {"NeighborAdvert", header.ICMPv6NeighborAdvert, header.ICMPv6NeighborAdvertSize, func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
- return stats.NeighborAdvert
- }},
- {"RedirectMsg", header.ICMPv6RedirectMsg, header.ICMPv6MinimumSize, func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
- return stats.RedirectMsg
- }},
+ {
+ name: "RouterSolicit",
+ typ: header.ICMPv6RouterSolicit,
+ size: header.ICMPv6MinimumSize,
+ statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
+ return stats.RouterSolicit
+ },
+ },
+ {
+ name: "RouterAdvert",
+ typ: header.ICMPv6RouterAdvert,
+ size: header.ICMPv6HeaderSize + header.NDPRAMinimumSize,
+ statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
+ return stats.RouterAdvert
+ },
+ },
+ {
+ name: "NeighborSolicit",
+ typ: header.ICMPv6NeighborSolicit,
+ size: header.ICMPv6NeighborSolicitMinimumSize,
+ statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
+ return stats.NeighborSolicit
+ },
+ },
+ {
+ name: "NeighborAdvert",
+ typ: header.ICMPv6NeighborAdvert,
+ size: header.ICMPv6NeighborAdvertMinimumSize,
+ extraData: tllData[:],
+ statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
+ return stats.NeighborAdvert
+ },
+ },
+ {
+ name: "RedirectMsg",
+ typ: header.ICMPv6RedirectMsg,
+ size: header.ICMPv6MinimumSize,
+ statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
+ return stats.RedirectMsg
+ },
+ },
}
for _, typ := range types {
@@ -270,10 +373,13 @@ func TestHopLimitValidation(t *testing.T) {
invalid := stats.Invalid
typStat := typ.statCounter(stats)
- hdr := buffer.NewPrependable(header.IPv6MinimumSize + typ.size)
+ extraDataLen := len(typ.extraData)
+ hdr := buffer.NewPrependable(header.IPv6MinimumSize + typ.size + extraDataLen)
+ extraData := buffer.View(hdr.Prepend(extraDataLen))
+ copy(extraData, typ.extraData)
pkt := header.ICMPv6(hdr.Prepend(typ.size))
pkt.SetType(typ.typ)
- pkt.SetChecksum(header.ICMPv6Checksum(pkt, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{}))
+ pkt.SetChecksum(header.ICMPv6Checksum(pkt, r.LocalAddress, r.RemoteAddress, extraData.ToVectorisedView()))
// Invalid count should initially be 0.
if got := invalid.Value(); got != 0 {
diff --git a/pkg/tcpip/stack/BUILD b/pkg/tcpip/stack/BUILD
index f5b750046..705cf01ee 100644
--- a/pkg/tcpip/stack/BUILD
+++ b/pkg/tcpip/stack/BUILD
@@ -78,11 +78,15 @@ go_test(
go_test(
name = "stack_test",
size = "small",
- srcs = ["linkaddrcache_test.go"],
+ srcs = [
+ "linkaddrcache_test.go",
+ "nic_test.go",
+ ],
library = ":stack",
deps = [
"//pkg/sleep",
"//pkg/sync",
"//pkg/tcpip",
+ "//pkg/tcpip/buffer",
],
)
diff --git a/pkg/tcpip/stack/ndp_test.go b/pkg/tcpip/stack/ndp_test.go
index 8af8565f7..1e575bdaf 100644
--- a/pkg/tcpip/stack/ndp_test.go
+++ b/pkg/tcpip/stack/ndp_test.go
@@ -478,13 +478,17 @@ func TestDADFail(t *testing.T) {
{
"RxAdvert",
func(tgt tcpip.Address) buffer.Prependable {
- hdr := buffer.NewPrependable(header.IPv6MinimumSize + header.ICMPv6NeighborAdvertSize)
- pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6NeighborAdvertSize))
+ naSize := header.ICMPv6NeighborAdvertMinimumSize + header.NDPLinkLayerAddressSize
+ hdr := buffer.NewPrependable(header.IPv6MinimumSize + naSize)
+ pkt := header.ICMPv6(hdr.Prepend(naSize))
pkt.SetType(header.ICMPv6NeighborAdvert)
na := header.NDPNeighborAdvert(pkt.NDPPayload())
na.SetSolicitedFlag(true)
na.SetOverrideFlag(true)
na.SetTargetAddress(tgt)
+ na.Options().Serialize(header.NDPOptionsSerializer{
+ header.NDPTargetLinkLayerAddressOption(linkAddr1),
+ })
pkt.SetChecksum(header.ICMPv6Checksum(pkt, tgt, header.IPv6AllNodesMulticastAddress, buffer.VectorisedView{}))
payloadLength := hdr.UsedLength()
ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
@@ -1535,7 +1539,7 @@ func TestPrefixDiscoveryMaxOnLinkPrefixes(t *testing.T) {
}
// Checks to see if list contains an IPv6 address, item.
-func contains(list []tcpip.ProtocolAddress, item tcpip.AddressWithPrefix) bool {
+func containsV6Addr(list []tcpip.ProtocolAddress, item tcpip.AddressWithPrefix) bool {
protocolAddress := tcpip.ProtocolAddress{
Protocol: header.IPv6ProtocolNumber,
AddressWithPrefix: item,
@@ -1661,7 +1665,7 @@ func TestAutoGenAddr(t *testing.T) {
// with non-zero lifetime.
e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 100, 0))
expectAutoGenAddrEvent(addr1, newAddr)
- if !contains(s.NICInfo()[1].ProtocolAddresses, addr1) {
+ if !containsV6Addr(s.NICInfo()[1].ProtocolAddresses, addr1) {
t.Fatalf("Should have %s in the list of addresses", addr1)
}
@@ -1677,10 +1681,10 @@ func TestAutoGenAddr(t *testing.T) {
// Receive an RA with prefix2 in a PI.
e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, 100, 0))
expectAutoGenAddrEvent(addr2, newAddr)
- if !contains(s.NICInfo()[1].ProtocolAddresses, addr1) {
+ if !containsV6Addr(s.NICInfo()[1].ProtocolAddresses, addr1) {
t.Fatalf("Should have %s in the list of addresses", addr1)
}
- if !contains(s.NICInfo()[1].ProtocolAddresses, addr2) {
+ if !containsV6Addr(s.NICInfo()[1].ProtocolAddresses, addr2) {
t.Fatalf("Should have %s in the list of addresses", addr2)
}
@@ -1701,10 +1705,10 @@ func TestAutoGenAddr(t *testing.T) {
case <-time.After(newMinVLDuration + defaultAsyncEventTimeout):
t.Fatal("timed out waiting for addr auto gen event")
}
- if contains(s.NICInfo()[1].ProtocolAddresses, addr1) {
+ if containsV6Addr(s.NICInfo()[1].ProtocolAddresses, addr1) {
t.Fatalf("Should not have %s in the list of addresses", addr1)
}
- if !contains(s.NICInfo()[1].ProtocolAddresses, addr2) {
+ if !containsV6Addr(s.NICInfo()[1].ProtocolAddresses, addr2) {
t.Fatalf("Should have %s in the list of addresses", addr2)
}
}
@@ -1849,7 +1853,7 @@ func TestAutoGenAddrDeprecateFromPI(t *testing.T) {
// Receive PI for prefix1.
e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 100, 100))
expectAutoGenAddrEvent(addr1, newAddr)
- if !contains(s.NICInfo()[nicID].ProtocolAddresses, addr1) {
+ if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) {
t.Fatalf("should have %s in the list of addresses", addr1)
}
expectPrimaryAddr(addr1)
@@ -1857,7 +1861,7 @@ func TestAutoGenAddrDeprecateFromPI(t *testing.T) {
// Deprecate addr for prefix1 immedaitely.
e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 100, 0))
expectAutoGenAddrEvent(addr1, deprecatedAddr)
- if !contains(s.NICInfo()[nicID].ProtocolAddresses, addr1) {
+ if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) {
t.Fatalf("should have %s in the list of addresses", addr1)
}
// addr should still be the primary endpoint as there are no other addresses.
@@ -1875,7 +1879,7 @@ func TestAutoGenAddrDeprecateFromPI(t *testing.T) {
// Receive PI for prefix2.
e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, 100, 100))
expectAutoGenAddrEvent(addr2, newAddr)
- if !contains(s.NICInfo()[nicID].ProtocolAddresses, addr2) {
+ if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr2) {
t.Fatalf("should have %s in the list of addresses", addr2)
}
expectPrimaryAddr(addr2)
@@ -1883,7 +1887,7 @@ func TestAutoGenAddrDeprecateFromPI(t *testing.T) {
// Deprecate addr for prefix2 immedaitely.
e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, 100, 0))
expectAutoGenAddrEvent(addr2, deprecatedAddr)
- if !contains(s.NICInfo()[nicID].ProtocolAddresses, addr2) {
+ if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr2) {
t.Fatalf("should have %s in the list of addresses", addr2)
}
// addr1 should be the primary endpoint now since addr2 is deprecated but
@@ -1978,7 +1982,7 @@ func TestAutoGenAddrTimerDeprecation(t *testing.T) {
// Receive PI for prefix2.
e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, 100, 100))
expectAutoGenAddrEvent(addr2, newAddr)
- if !contains(s.NICInfo()[nicID].ProtocolAddresses, addr2) {
+ if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr2) {
t.Fatalf("should have %s in the list of addresses", addr2)
}
expectPrimaryAddr(addr2)
@@ -1986,10 +1990,10 @@ func TestAutoGenAddrTimerDeprecation(t *testing.T) {
// Receive a PI for prefix1.
e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 100, 90))
expectAutoGenAddrEvent(addr1, newAddr)
- if !contains(s.NICInfo()[nicID].ProtocolAddresses, addr1) {
+ if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) {
t.Fatalf("should have %s in the list of addresses", addr1)
}
- if !contains(s.NICInfo()[nicID].ProtocolAddresses, addr2) {
+ if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr2) {
t.Fatalf("should have %s in the list of addresses", addr2)
}
expectPrimaryAddr(addr1)
@@ -2005,10 +2009,10 @@ func TestAutoGenAddrTimerDeprecation(t *testing.T) {
// Wait for addr of prefix1 to be deprecated.
expectAutoGenAddrEventAfter(addr1, deprecatedAddr, newMinVLDuration-time.Second+defaultAsyncEventTimeout)
- if !contains(s.NICInfo()[nicID].ProtocolAddresses, addr1) {
+ if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) {
t.Fatalf("should not have %s in the list of addresses", addr1)
}
- if !contains(s.NICInfo()[nicID].ProtocolAddresses, addr2) {
+ if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr2) {
t.Fatalf("should have %s in the list of addresses", addr2)
}
// addr2 should be the primary endpoint now since addr1 is deprecated but
@@ -2045,10 +2049,10 @@ func TestAutoGenAddrTimerDeprecation(t *testing.T) {
// Wait for addr of prefix1 to be deprecated.
expectAutoGenAddrEventAfter(addr1, deprecatedAddr, newMinVLDuration-time.Second+defaultAsyncEventTimeout)
- if !contains(s.NICInfo()[nicID].ProtocolAddresses, addr1) {
+ if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) {
t.Fatalf("should not have %s in the list of addresses", addr1)
}
- if !contains(s.NICInfo()[nicID].ProtocolAddresses, addr2) {
+ if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr2) {
t.Fatalf("should have %s in the list of addresses", addr2)
}
// addr2 should be the primary endpoint now since it is not deprecated.
@@ -2059,10 +2063,10 @@ func TestAutoGenAddrTimerDeprecation(t *testing.T) {
// Wait for addr of prefix1 to be invalidated.
expectAutoGenAddrEventAfter(addr1, invalidatedAddr, time.Second+defaultAsyncEventTimeout)
- if contains(s.NICInfo()[nicID].ProtocolAddresses, addr1) {
+ if containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) {
t.Fatalf("should not have %s in the list of addresses", addr1)
}
- if !contains(s.NICInfo()[nicID].ProtocolAddresses, addr2) {
+ if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr2) {
t.Fatalf("should have %s in the list of addresses", addr2)
}
expectPrimaryAddr(addr2)
@@ -2108,10 +2112,10 @@ func TestAutoGenAddrTimerDeprecation(t *testing.T) {
case <-time.After(newMinVLDuration + defaultAsyncEventTimeout):
t.Fatal("timed out waiting for addr auto gen event")
}
- if contains(s.NICInfo()[nicID].ProtocolAddresses, addr1) {
+ if containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) {
t.Fatalf("should not have %s in the list of addresses", addr1)
}
- if contains(s.NICInfo()[nicID].ProtocolAddresses, addr2) {
+ if containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr2) {
t.Fatalf("should not have %s in the list of addresses", addr2)
}
// Should not have any primary endpoints.
@@ -2596,7 +2600,7 @@ func TestAutoGenAddrStaticConflict(t *testing.T) {
if err := s.AddProtocolAddress(1, tcpip.ProtocolAddress{Protocol: header.IPv6ProtocolNumber, AddressWithPrefix: addr}); err != nil {
t.Fatalf("AddAddress(_, %d, %s) = %s", header.IPv6ProtocolNumber, addr.Address, err)
}
- if !contains(s.NICInfo()[1].ProtocolAddresses, addr) {
+ if !containsV6Addr(s.NICInfo()[1].ProtocolAddresses, addr) {
t.Fatalf("Should have %s in the list of addresses", addr1)
}
@@ -2609,7 +2613,7 @@ func TestAutoGenAddrStaticConflict(t *testing.T) {
t.Fatal("unexpectedly received an auto gen addr event for an address we already have statically")
default:
}
- if !contains(s.NICInfo()[1].ProtocolAddresses, addr) {
+ if !containsV6Addr(s.NICInfo()[1].ProtocolAddresses, addr) {
t.Fatalf("Should have %s in the list of addresses", addr1)
}
@@ -2620,7 +2624,7 @@ func TestAutoGenAddrStaticConflict(t *testing.T) {
t.Fatal("unexpectedly received an auto gen addr event")
case <-time.After(lifetimeSeconds*time.Second + defaultTimeout):
}
- if !contains(s.NICInfo()[1].ProtocolAddresses, addr) {
+ if !containsV6Addr(s.NICInfo()[1].ProtocolAddresses, addr) {
t.Fatalf("Should have %s in the list of addresses", addr1)
}
}
@@ -2698,17 +2702,17 @@ func TestAutoGenAddrWithOpaqueIID(t *testing.T) {
const validLifetimeSecondPrefix1 = 1
e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, validLifetimeSecondPrefix1, 0))
expectAutoGenAddrEvent(addr1, newAddr)
- if !contains(s.NICInfo()[nicID].ProtocolAddresses, addr1) {
+ if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) {
t.Fatalf("should have %s in the list of addresses", addr1)
}
// Receive an RA with prefix2 in a PI with a large valid lifetime.
e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, 100, 0))
expectAutoGenAddrEvent(addr2, newAddr)
- if !contains(s.NICInfo()[nicID].ProtocolAddresses, addr1) {
+ if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) {
t.Fatalf("should have %s in the list of addresses", addr1)
}
- if !contains(s.NICInfo()[nicID].ProtocolAddresses, addr2) {
+ if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr2) {
t.Fatalf("should have %s in the list of addresses", addr2)
}
@@ -2721,10 +2725,10 @@ func TestAutoGenAddrWithOpaqueIID(t *testing.T) {
case <-time.After(validLifetimeSecondPrefix1*time.Second + defaultAsyncEventTimeout):
t.Fatal("timed out waiting for addr auto gen event")
}
- if contains(s.NICInfo()[nicID].ProtocolAddresses, addr1) {
+ if containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) {
t.Fatalf("should not have %s in the list of addresses", addr1)
}
- if !contains(s.NICInfo()[nicID].ProtocolAddresses, addr2) {
+ if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr2) {
t.Fatalf("should have %s in the list of addresses", addr2)
}
}
@@ -3010,16 +3014,16 @@ func TestCleanupHostOnlyStateOnBecomingRouter(t *testing.T) {
nicinfo := s.NICInfo()
nic1Addrs := nicinfo[nicID1].ProtocolAddresses
nic2Addrs := nicinfo[nicID2].ProtocolAddresses
- if !contains(nic1Addrs, e1Addr1) {
+ if !containsV6Addr(nic1Addrs, e1Addr1) {
t.Errorf("missing %s from the list of addresses for NIC(%d): %+v", e1Addr1, nicID1, nic1Addrs)
}
- if !contains(nic1Addrs, e1Addr2) {
+ if !containsV6Addr(nic1Addrs, e1Addr2) {
t.Errorf("missing %s from the list of addresses for NIC(%d): %+v", e1Addr2, nicID1, nic1Addrs)
}
- if !contains(nic2Addrs, e2Addr1) {
+ if !containsV6Addr(nic2Addrs, e2Addr1) {
t.Errorf("missing %s from the list of addresses for NIC(%d): %+v", e2Addr1, nicID2, nic2Addrs)
}
- if !contains(nic2Addrs, e2Addr2) {
+ if !containsV6Addr(nic2Addrs, e2Addr2) {
t.Errorf("missing %s from the list of addresses for NIC(%d): %+v", e2Addr2, nicID2, nic2Addrs)
}
@@ -3098,16 +3102,16 @@ func TestCleanupHostOnlyStateOnBecomingRouter(t *testing.T) {
nicinfo = s.NICInfo()
nic1Addrs = nicinfo[nicID1].ProtocolAddresses
nic2Addrs = nicinfo[nicID2].ProtocolAddresses
- if contains(nic1Addrs, e1Addr1) {
+ if containsV6Addr(nic1Addrs, e1Addr1) {
t.Errorf("still have %s in the list of addresses for NIC(%d): %+v", e1Addr1, nicID1, nic1Addrs)
}
- if contains(nic1Addrs, e1Addr2) {
+ if containsV6Addr(nic1Addrs, e1Addr2) {
t.Errorf("still have %s in the list of addresses for NIC(%d): %+v", e1Addr2, nicID1, nic1Addrs)
}
- if contains(nic2Addrs, e2Addr1) {
+ if containsV6Addr(nic2Addrs, e2Addr1) {
t.Errorf("still have %s in the list of addresses for NIC(%d): %+v", e2Addr1, nicID2, nic2Addrs)
}
- if contains(nic2Addrs, e2Addr2) {
+ if containsV6Addr(nic2Addrs, e2Addr2) {
t.Errorf("still have %s in the list of addresses for NIC(%d): %+v", e2Addr2, nicID2, nic2Addrs)
}
diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go
index 7dad9a8cb..682e9c416 100644
--- a/pkg/tcpip/stack/nic.go
+++ b/pkg/tcpip/stack/nic.go
@@ -16,6 +16,7 @@ package stack
import (
"log"
+ "reflect"
"sort"
"strings"
"sync/atomic"
@@ -39,6 +40,7 @@ type NIC struct {
mu struct {
sync.RWMutex
+ enabled bool
spoofing bool
promiscuous bool
primary map[tcpip.NetworkProtocolNumber][]*referencedNetworkEndpoint
@@ -56,6 +58,14 @@ type NIC struct {
type NICStats struct {
Tx DirectionStats
Rx DirectionStats
+
+ DisabledRx DirectionStats
+}
+
+func makeNICStats() NICStats {
+ var s NICStats
+ tcpip.InitStatCounters(reflect.ValueOf(&s).Elem())
+ return s
}
// DirectionStats includes packet and byte counts.
@@ -99,16 +109,7 @@ func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint, ctx NICC
name: name,
linkEP: ep,
context: ctx,
- stats: NICStats{
- Tx: DirectionStats{
- Packets: &tcpip.StatCounter{},
- Bytes: &tcpip.StatCounter{},
- },
- Rx: DirectionStats{
- Packets: &tcpip.StatCounter{},
- Bytes: &tcpip.StatCounter{},
- },
- },
+ stats: makeNICStats(),
}
nic.mu.primary = make(map[tcpip.NetworkProtocolNumber][]*referencedNetworkEndpoint)
nic.mu.endpoints = make(map[NetworkEndpointID]*referencedNetworkEndpoint)
@@ -137,14 +138,30 @@ func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint, ctx NICC
// enable enables the NIC. enable will attach the link to its LinkEndpoint and
// join the IPv6 All-Nodes Multicast address (ff02::1).
func (n *NIC) enable() *tcpip.Error {
+ n.mu.RLock()
+ enabled := n.mu.enabled
+ n.mu.RUnlock()
+ if enabled {
+ return nil
+ }
+
+ n.mu.Lock()
+ defer n.mu.Unlock()
+
+ if n.mu.enabled {
+ return nil
+ }
+
+ n.mu.enabled = true
+
n.attachLinkEndpoint()
// Create an endpoint to receive broadcast packets on this interface.
if _, ok := n.stack.networkProtocols[header.IPv4ProtocolNumber]; ok {
- if err := n.AddAddress(tcpip.ProtocolAddress{
+ if _, err := n.addAddressLocked(tcpip.ProtocolAddress{
Protocol: header.IPv4ProtocolNumber,
AddressWithPrefix: tcpip.AddressWithPrefix{header.IPv4Broadcast, 8 * header.IPv4AddressSize},
- }, NeverPrimaryEndpoint); err != nil {
+ }, NeverPrimaryEndpoint, permanent, static, false /* deprecated */); err != nil {
return err
}
}
@@ -166,8 +183,22 @@ func (n *NIC) enable() *tcpip.Error {
return nil
}
- n.mu.Lock()
- defer n.mu.Unlock()
+ // Perform DAD on the all the unicast IPv6 endpoints that are in the permanent
+ // state.
+ //
+ // Addresses may have aleady completed DAD but in the time since the NIC was
+ // last enabled, other devices may have acquired the same addresses.
+ for _, r := range n.mu.endpoints {
+ addr := r.ep.ID().LocalAddress
+ if k := r.getKind(); (k != permanent && k != permanentTentative) || !header.IsV6UnicastAddress(addr) {
+ continue
+ }
+
+ r.setKind(permanentTentative)
+ if err := n.mu.ndp.startDuplicateAddressDetection(addr, r); err != nil {
+ return err
+ }
+ }
if err := n.joinGroupLocked(header.IPv6ProtocolNumber, header.IPv6AllNodesMulticastAddress); err != nil {
return err
@@ -633,7 +664,9 @@ func (n *NIC) addAddressLocked(protocolAddress tcpip.ProtocolAddress, peb Primar
isIPv6Unicast := protocolAddress.Protocol == header.IPv6ProtocolNumber && header.IsV6UnicastAddress(protocolAddress.AddressWithPrefix.Address)
// If the address is an IPv6 address and it is a permanent address,
- // mark it as tentative so it goes through the DAD process.
+ // mark it as tentative so it goes through the DAD process if the NIC is
+ // enabled. If the NIC is not enabled, DAD will be started when the NIC is
+ // enabled.
if isIPv6Unicast && kind == permanent {
kind = permanentTentative
}
@@ -668,8 +701,8 @@ func (n *NIC) addAddressLocked(protocolAddress tcpip.ProtocolAddress, peb Primar
n.insertPrimaryEndpointLocked(ref, peb)
- // If we are adding a tentative IPv6 address, start DAD.
- if isIPv6Unicast && kind == permanentTentative {
+ // If we are adding a tentative IPv6 address, start DAD if the NIC is enabled.
+ if isIPv6Unicast && kind == permanentTentative && n.mu.enabled {
if err := n.mu.ndp.startDuplicateAddressDetection(protocolAddress.AddressWithPrefix.Address, ref); err != nil {
return nil, err
}
@@ -700,9 +733,7 @@ func (n *NIC) AllAddresses() []tcpip.ProtocolAddress {
// Don't include tentative, expired or temporary endpoints to
// avoid confusion and prevent the caller from using those.
switch ref.getKind() {
- case permanentTentative, permanentExpired, temporary:
- // TODO(b/140898488): Should tentative addresses be
- // returned?
+ case permanentExpired, temporary:
continue
}
addrs = append(addrs, tcpip.ProtocolAddress{
@@ -1016,11 +1047,23 @@ func handlePacket(protocol tcpip.NetworkProtocolNumber, dst, src tcpip.Address,
// This rule applies only to the slice itself, not to the items of the slice;
// the ownership of the items is not retained by the caller.
func (n *NIC) DeliverNetworkPacket(linkEP LinkEndpoint, remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) {
+ n.mu.RLock()
+ enabled := n.mu.enabled
+ // If the NIC is not yet enabled, don't receive any packets.
+ if !enabled {
+ n.mu.RUnlock()
+
+ n.stats.DisabledRx.Packets.Increment()
+ n.stats.DisabledRx.Bytes.IncrementBy(uint64(pkt.Data.Size()))
+ return
+ }
+
n.stats.Rx.Packets.Increment()
n.stats.Rx.Bytes.IncrementBy(uint64(pkt.Data.Size()))
netProto, ok := n.stack.networkProtocols[protocol]
if !ok {
+ n.mu.RUnlock()
n.stack.stats.UnknownProtocolRcvdPackets.Increment()
return
}
@@ -1032,7 +1075,6 @@ func (n *NIC) DeliverNetworkPacket(linkEP LinkEndpoint, remote, local tcpip.Link
}
// Are any packet sockets listening for this network protocol?
- n.mu.RLock()
packetEPs := n.mu.packetEPs[protocol]
// Check whether there are packet sockets listening for every protocol.
// If we received a packet with protocol EthernetProtocolAll, then the
diff --git a/pkg/tcpip/stack/nic_test.go b/pkg/tcpip/stack/nic_test.go
new file mode 100644
index 000000000..edaee3b86
--- /dev/null
+++ b/pkg/tcpip/stack/nic_test.go
@@ -0,0 +1,62 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package stack
+
+import (
+ "testing"
+
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
+)
+
+func TestDisabledRxStatsWhenNICDisabled(t *testing.T) {
+ // When the NIC is disabled, the only field that matters is the stats field.
+ // This test is limited to stats counter checks.
+ nic := NIC{
+ stats: makeNICStats(),
+ }
+
+ if got := nic.stats.DisabledRx.Packets.Value(); got != 0 {
+ t.Errorf("got DisabledRx.Packets = %d, want = 0", got)
+ }
+ if got := nic.stats.DisabledRx.Bytes.Value(); got != 0 {
+ t.Errorf("got DisabledRx.Bytes = %d, want = 0", got)
+ }
+ if got := nic.stats.Rx.Packets.Value(); got != 0 {
+ t.Errorf("got Rx.Packets = %d, want = 0", got)
+ }
+ if got := nic.stats.Rx.Bytes.Value(); got != 0 {
+ t.Errorf("got Rx.Bytes = %d, want = 0", got)
+ }
+
+ if t.Failed() {
+ t.FailNow()
+ }
+
+ nic.DeliverNetworkPacket(nil, "", "", 0, tcpip.PacketBuffer{Data: buffer.View([]byte{1, 2, 3, 4}).ToVectorisedView()})
+
+ if got := nic.stats.DisabledRx.Packets.Value(); got != 1 {
+ t.Errorf("got DisabledRx.Packets = %d, want = 1", got)
+ }
+ if got := nic.stats.DisabledRx.Bytes.Value(); got != 4 {
+ t.Errorf("got DisabledRx.Bytes = %d, want = 4", got)
+ }
+ if got := nic.stats.Rx.Packets.Value(); got != 0 {
+ t.Errorf("got Rx.Packets = %d, want = 0", got)
+ }
+ if got := nic.stats.Rx.Bytes.Value(); got != 0 {
+ t.Errorf("got Rx.Bytes = %d, want = 0", got)
+ }
+}
diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go
index 834fe9487..243868f3a 100644
--- a/pkg/tcpip/stack/stack_test.go
+++ b/pkg/tcpip/stack/stack_test.go
@@ -2561,3 +2561,118 @@ func TestIPv6SourceAddressSelectionScopeAndSameAddress(t *testing.T) {
})
}
}
+
+// TestDoDADWhenNICEnabled tests that IPv6 endpoints that were added while a NIC
+// was disabled have DAD performed on them when the NIC is enabled.
+func TestDoDADWhenNICEnabled(t *testing.T) {
+ t.Parallel()
+
+ const dadTransmits = 1
+ const retransmitTimer = time.Second
+ const nicID = 1
+
+ ndpDisp := ndpDispatcher{
+ dadC: make(chan ndpDADEvent),
+ }
+ opts := stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
+ NDPConfigs: stack.NDPConfigurations{
+ DupAddrDetectTransmits: dadTransmits,
+ RetransmitTimer: retransmitTimer,
+ },
+ NDPDisp: &ndpDisp,
+ }
+
+ e := channel.New(dadTransmits, 1280, linkAddr1)
+ s := stack.New(opts)
+ nicOpts := stack.NICOptions{Disabled: true}
+ if err := s.CreateNICWithOptions(nicID, e, nicOpts); err != nil {
+ t.Fatalf("CreateNIC(%d, _, %+v) = %s", nicID, nicOpts, err)
+ }
+
+ addr := tcpip.ProtocolAddress{
+ Protocol: header.IPv6ProtocolNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: llAddr1,
+ PrefixLen: 128,
+ },
+ }
+ if err := s.AddProtocolAddress(nicID, addr); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v): %s", nicID, addr, err)
+ }
+
+ // Address should be in the list of all addresses.
+ if addrs := s.AllAddresses()[nicID]; !containsV6Addr(addrs, addr.AddressWithPrefix) {
+ t.Fatalf("got s.AllAddresses()[%d] = %+v, want = %+v", nicID, addrs, addr)
+ }
+
+ // Address should be tentative so it should not be a main address.
+ got, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber)
+ if err != nil {
+ t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %v), want = (_, nil)", nicID, header.IPv6ProtocolNumber, err)
+ }
+ if want := (tcpip.AddressWithPrefix{}); got != want {
+ t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (%s, nil), want = (%s, nil)", nicID, header.IPv6ProtocolNumber, got, want)
+ }
+
+ // Enabling the NIC should start DAD for the address.
+ if err := s.EnableNIC(nicID); err != nil {
+ t.Fatalf("s.EnableNIC(%d): %s", nicID, err)
+ }
+ if addrs := s.AllAddresses()[nicID]; !containsV6Addr(addrs, addr.AddressWithPrefix) {
+ t.Fatalf("got s.AllAddresses()[%d] = %+v, want = %+v", nicID, addrs, addr)
+ }
+
+ // Address should not be considered bound to the NIC yet (DAD ongoing).
+ got, err = s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber)
+ if err != nil {
+ t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %v), want = (_, nil)", nicID, header.IPv6ProtocolNumber, err)
+ }
+ if want := (tcpip.AddressWithPrefix{}); got != want {
+ t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (%s, nil), want = (%s, nil)", nicID, header.IPv6ProtocolNumber, got, want)
+ }
+
+ // Wait for DAD to resolve.
+ select {
+ case <-time.After(dadTransmits*retransmitTimer + defaultAsyncEventTimeout):
+ t.Fatal("timed out waiting for DAD resolution")
+ case e := <-ndpDisp.dadC:
+ if e.err != nil {
+ t.Fatal("got DAD error: ", e.err)
+ }
+ if e.nicID != nicID {
+ t.Fatalf("got DAD event w/ nicID = %d, want = %d", e.nicID, nicID)
+ }
+ if e.addr != addr.AddressWithPrefix.Address {
+ t.Fatalf("got DAD event w/ addr = %s, want = %s", e.addr, addr.AddressWithPrefix.Address)
+ }
+ if !e.resolved {
+ t.Fatal("got DAD event w/ resolved = false, want = true")
+ }
+ }
+ if addrs := s.AllAddresses()[nicID]; !containsV6Addr(addrs, addr.AddressWithPrefix) {
+ t.Fatalf("got s.AllAddresses()[%d] = %+v, want = %+v", nicID, addrs, addr)
+ }
+ got, err = s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber)
+ if err != nil {
+ t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %v), want = (_, nil)", nicID, header.IPv6ProtocolNumber, err)
+ }
+ if got != addr.AddressWithPrefix {
+ t.Fatalf("got stack.GetMainNICAddress(%d, %d) = %s, want = %s", nicID, header.IPv6ProtocolNumber, got, addr.AddressWithPrefix)
+ }
+
+ // Enabling the NIC again should be a no-op.
+ if err := s.EnableNIC(nicID); err != nil {
+ t.Fatalf("s.EnableNIC(%d): %s", nicID, err)
+ }
+ if addrs := s.AllAddresses()[nicID]; !containsV6Addr(addrs, addr.AddressWithPrefix) {
+ t.Fatalf("got s.AllAddresses()[%d] = %+v, want = %+v", nicID, addrs, addr)
+ }
+ got, err = s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber)
+ if err != nil {
+ t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %v), want = (_, nil)", nicID, header.IPv6ProtocolNumber, err)
+ }
+ if got != addr.AddressWithPrefix {
+ t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (%s, nil), want = (%s, nil)", nicID, header.IPv6ProtocolNumber, got, addr.AddressWithPrefix)
+ }
+}
diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go
index 0fa141d58..0e944712f 100644
--- a/pkg/tcpip/tcpip.go
+++ b/pkg/tcpip/tcpip.go
@@ -1124,6 +1124,10 @@ type ReadErrors struct {
// InvalidEndpointState is the number of times we found the endpoint state
// to be unexpected.
InvalidEndpointState StatCounter
+
+ // NotConnected is the number of times we tried to read but found that the
+ // endpoint was not connected.
+ NotConnected StatCounter
}
// WriteErrors collects packet write errors from an endpoint write call.
@@ -1166,7 +1170,9 @@ type TransportEndpointStats struct {
// marker interface.
func (*TransportEndpointStats) IsEndpointStats() {}
-func fillIn(v reflect.Value) {
+// InitStatCounters initializes v's fields with nil StatCounter fields to new
+// StatCounters.
+func InitStatCounters(v reflect.Value) {
for i := 0; i < v.NumField(); i++ {
v := v.Field(i)
if s, ok := v.Addr().Interface().(**StatCounter); ok {
@@ -1174,14 +1180,14 @@ func fillIn(v reflect.Value) {
*s = new(StatCounter)
}
} else {
- fillIn(v)
+ InitStatCounters(v)
}
}
}
// FillIn returns a copy of s with nil fields initialized to new StatCounters.
func (s Stats) FillIn() Stats {
- fillIn(reflect.ValueOf(&s).Elem())
+ InitStatCounters(reflect.ValueOf(&s).Elem())
return s
}
diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go
index b5a8e15ee..f2be0e651 100644
--- a/pkg/tcpip/transport/tcp/endpoint.go
+++ b/pkg/tcpip/transport/tcp/endpoint.go
@@ -1003,8 +1003,8 @@ func (e *endpoint) Read(*tcpip.FullAddress) (buffer.View, tcpip.ControlMessages,
if s == StateError {
return buffer.View{}, tcpip.ControlMessages{}, he
}
- e.stats.ReadErrors.InvalidEndpointState.Increment()
- return buffer.View{}, tcpip.ControlMessages{}, tcpip.ErrInvalidEndpointState
+ e.stats.ReadErrors.NotConnected.Increment()
+ return buffer.View{}, tcpip.ControlMessages{}, tcpip.ErrNotConnected
}
v, err := e.readLocked()
@@ -2166,6 +2166,9 @@ func (e *endpoint) listen(backlog int) *tcpip.Error {
e.isRegistered = true
e.setEndpointState(StateListen)
+ // The channel may be non-nil when we're restoring the endpoint, and it
+ // may be pre-populated with some previously accepted (but not Accepted)
+ // endpoints.
if e.acceptedChan == nil {
e.acceptedChan = make(chan *endpoint, backlog)
}
diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go
index 2c1505067..cc118c993 100644
--- a/pkg/tcpip/transport/tcp/tcp_test.go
+++ b/pkg/tcpip/transport/tcp/tcp_test.go
@@ -5405,12 +5405,11 @@ func TestEndpointBindListenAcceptState(t *testing.T) {
t.Errorf("Unexpected endpoint state: want %v, got %v", want, got)
}
- // Expect InvalidEndpointState errors on a read at this point.
- if _, _, err := ep.Read(nil); err != tcpip.ErrInvalidEndpointState {
- t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrInvalidEndpointState)
+ if _, _, err := ep.Read(nil); err != tcpip.ErrNotConnected {
+ t.Errorf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrNotConnected)
}
- if got := ep.Stats().(*tcp.Stats).ReadErrors.InvalidEndpointState.Value(); got != 1 {
- t.Fatalf("got EP stats Stats.ReadErrors.InvalidEndpointState got %v want %v", got, 1)
+ if got := ep.Stats().(*tcp.Stats).ReadErrors.NotConnected.Value(); got != 1 {
+ t.Errorf("got EP stats Stats.ReadErrors.NotConnected got %v want %v", got, 1)
}
if err := ep.Listen(10); err != nil {