summaryrefslogtreecommitdiffhomepage
path: root/pkg/tcpip
diff options
context:
space:
mode:
Diffstat (limited to 'pkg/tcpip')
-rw-r--r--pkg/tcpip/BUILD1
-rw-r--r--pkg/tcpip/adapters/gonet/gonet.go85
-rw-r--r--pkg/tcpip/adapters/gonet/gonet_test.go18
-rw-r--r--pkg/tcpip/checker/checker.go4
-rw-r--r--pkg/tcpip/errors.go538
-rw-r--r--pkg/tcpip/faketime/BUILD5
-rw-r--r--pkg/tcpip/faketime/faketime.go318
-rw-r--r--pkg/tcpip/header/ipv4.go58
-rw-r--r--pkg/tcpip/header/ipv6.go4
-rw-r--r--pkg/tcpip/header/ipv6_test.go8
-rw-r--r--pkg/tcpip/link/channel/channel.go4
-rw-r--r--pkg/tcpip/link/ethernet/ethernet.go4
-rw-r--r--pkg/tcpip/link/fdbased/BUILD1
-rw-r--r--pkg/tcpip/link/fdbased/endpoint.go16
-rw-r--r--pkg/tcpip/link/fdbased/endpoint_test.go125
-rw-r--r--pkg/tcpip/link/fdbased/mmap.go4
-rw-r--r--pkg/tcpip/link/fdbased/packet_dispatchers.go260
-rw-r--r--pkg/tcpip/link/loopback/loopback.go4
-rw-r--r--pkg/tcpip/link/muxed/injectable.go12
-rw-r--r--pkg/tcpip/link/nested/nested.go4
-rw-r--r--pkg/tcpip/link/packetsocket/endpoint.go4
-rw-r--r--pkg/tcpip/link/pipe/pipe.go34
-rw-r--r--pkg/tcpip/link/qdisc/fifo/endpoint.go17
-rw-r--r--pkg/tcpip/link/rawfile/BUILD1
-rw-r--r--pkg/tcpip/link/rawfile/errors.go87
-rw-r--r--pkg/tcpip/link/rawfile/errors_test.go15
-rw-r--r--pkg/tcpip/link/rawfile/rawfile_unsafe.go12
-rw-r--r--pkg/tcpip/link/sharedmem/sharedmem.go6
-rw-r--r--pkg/tcpip/link/sharedmem/sharedmem_test.go19
-rw-r--r--pkg/tcpip/link/sniffer/sniffer.go4
-rw-r--r--pkg/tcpip/link/tun/device.go4
-rw-r--r--pkg/tcpip/link/waitable/waitable.go4
-rw-r--r--pkg/tcpip/link/waitable/waitable_test.go4
-rw-r--r--pkg/tcpip/network/BUILD1
-rw-r--r--pkg/tcpip/network/arp/BUILD18
-rw-r--r--pkg/tcpip/network/arp/arp.go146
-rw-r--r--pkg/tcpip/network/arp/arp_test.go210
-rw-r--r--pkg/tcpip/network/arp/stats.go70
-rw-r--r--pkg/tcpip/network/arp/stats_test.go93
-rw-r--r--pkg/tcpip/network/ip/BUILD5
-rw-r--r--pkg/tcpip/network/ip/generic_multicast_protocol.go4
-rw-r--r--pkg/tcpip/network/ip/generic_multicast_protocol_test.go4
-rw-r--r--pkg/tcpip/network/ip/stats.go100
-rw-r--r--pkg/tcpip/network/ip_test.go290
-rw-r--r--pkg/tcpip/network/ipv4/BUILD13
-rw-r--r--pkg/tcpip/network/ipv4/icmp.go103
-rw-r--r--pkg/tcpip/network/ipv4/igmp.go60
-rw-r--r--pkg/tcpip/network/ipv4/ipv4.go328
-rw-r--r--pkg/tcpip/network/ipv4/ipv4_test.go134
-rw-r--r--pkg/tcpip/network/ipv4/stats.go190
-rw-r--r--pkg/tcpip/network/ipv4/stats_test.go99
-rw-r--r--pkg/tcpip/network/ipv6/BUILD1
-rw-r--r--pkg/tcpip/network/ipv6/icmp.go205
-rw-r--r--pkg/tcpip/network/ipv6/icmp_test.go135
-rw-r--r--pkg/tcpip/network/ipv6/ipv6.go217
-rw-r--r--pkg/tcpip/network/ipv6/ipv6_test.go118
-rw-r--r--pkg/tcpip/network/ipv6/mld.go24
-rw-r--r--pkg/tcpip/network/ipv6/ndp.go27
-rw-r--r--pkg/tcpip/network/ipv6/ndp_test.go239
-rw-r--r--pkg/tcpip/network/ipv6/stats.go132
-rw-r--r--pkg/tcpip/network/testutil/BUILD2
-rw-r--r--pkg/tcpip/network/testutil/testutil.go76
-rw-r--r--pkg/tcpip/network/testutil/testutil_unsafe.go26
-rw-r--r--pkg/tcpip/ports/BUILD1
-rw-r--r--pkg/tcpip/ports/ports.go16
-rw-r--r--pkg/tcpip/ports/ports_test.go145
-rw-r--r--pkg/tcpip/sample/tun_tcp_connect/BUILD1
-rw-r--r--pkg/tcpip/sample/tun_tcp_connect/main.go36
-rw-r--r--pkg/tcpip/sample/tun_tcp_echo/main.go11
-rw-r--r--pkg/tcpip/socketops.go108
-rw-r--r--pkg/tcpip/stack/addressable_endpoint_state.go34
-rw-r--r--pkg/tcpip/stack/conntrack.go12
-rw-r--r--pkg/tcpip/stack/forwarding_test.go78
-rw-r--r--pkg/tcpip/stack/iptables.go30
-rw-r--r--pkg/tcpip/stack/iptables_types.go64
-rw-r--r--pkg/tcpip/stack/linkaddrcache.go145
-rw-r--r--pkg/tcpip/stack/linkaddrcache_test.go111
-rw-r--r--pkg/tcpip/stack/ndp_test.go186
-rw-r--r--pkg/tcpip/stack/neighbor_cache.go13
-rw-r--r--pkg/tcpip/stack/neighbor_cache_test.go313
-rw-r--r--pkg/tcpip/stack/neighbor_entry.go13
-rw-r--r--pkg/tcpip/stack/neighbor_entry_test.go342
-rw-r--r--pkg/tcpip/stack/nic.go173
-rw-r--r--pkg/tcpip/stack/nic_test.go37
-rw-r--r--pkg/tcpip/stack/nud.go4
-rw-r--r--pkg/tcpip/stack/nud_test.go15
-rw-r--r--pkg/tcpip/stack/pending_packets.go241
-rw-r--r--pkg/tcpip/stack/registration.go80
-rw-r--r--pkg/tcpip/stack/route.go124
-rw-r--r--pkg/tcpip/stack/stack.go264
-rw-r--r--pkg/tcpip/stack/stack_options.go32
-rw-r--r--pkg/tcpip/stack/stack_test.go419
-rw-r--r--pkg/tcpip/stack/transport_demuxer.go28
-rw-r--r--pkg/tcpip/stack/transport_demuxer_test.go8
-rw-r--r--pkg/tcpip/stack/transport_test.go89
-rw-r--r--pkg/tcpip/tcpip.go436
-rw-r--r--pkg/tcpip/tests/integration/BUILD2
-rw-r--r--pkg/tcpip/tests/integration/forward_test.go321
-rw-r--r--pkg/tcpip/tests/integration/iptables_test.go336
-rw-r--r--pkg/tcpip/tests/integration/link_resolution_test.go778
-rw-r--r--pkg/tcpip/tests/integration/loopback_test.go35
-rw-r--r--pkg/tcpip/tests/integration/multicast_broadcast_test.go17
-rw-r--r--pkg/tcpip/tests/integration/route_test.go39
-rw-r--r--pkg/tcpip/transport/icmp/endpoint.go143
-rw-r--r--pkg/tcpip/transport/icmp/endpoint_state.go5
-rw-r--r--pkg/tcpip/transport/icmp/protocol.go18
-rw-r--r--pkg/tcpip/transport/packet/endpoint.go121
-rw-r--r--pkg/tcpip/transport/packet/endpoint_state.go22
-rw-r--r--pkg/tcpip/transport/raw/endpoint.go165
-rw-r--r--pkg/tcpip/transport/raw/endpoint_state.go7
-rw-r--r--pkg/tcpip/transport/raw/protocol.go4
-rw-r--r--pkg/tcpip/transport/tcp/BUILD1
-rw-r--r--pkg/tcpip/transport/tcp/accept.go18
-rw-r--r--pkg/tcpip/transport/tcp/connect.go127
-rw-r--r--pkg/tcpip/transport/tcp/dual_stack_test.go22
-rw-r--r--pkg/tcpip/transport/tcp/endpoint.go426
-rw-r--r--pkg/tcpip/transport/tcp/endpoint_state.go54
-rw-r--r--pkg/tcpip/transport/tcp/forwarder.go4
-rw-r--r--pkg/tcpip/transport/tcp/protocol.go26
-rw-r--r--pkg/tcpip/transport/tcp/rack.go154
-rw-r--r--pkg/tcpip/transport/tcp/rcv.go8
-rw-r--r--pkg/tcpip/transport/tcp/snd.go76
-rw-r--r--pkg/tcpip/transport/tcp/tcp_noracedetector_test.go38
-rw-r--r--pkg/tcpip/transport/tcp/tcp_rack_test.go84
-rw-r--r--pkg/tcpip/transport/tcp/tcp_sack_test.go8
-rw-r--r--pkg/tcpip/transport/tcp/tcp_test.go442
-rw-r--r--pkg/tcpip/transport/tcp/tcp_timestamp_test.go13
-rw-r--r--pkg/tcpip/transport/tcp/testing/context/context.go13
-rw-r--r--pkg/tcpip/transport/udp/endpoint.go201
-rw-r--r--pkg/tcpip/transport/udp/endpoint_state.go25
-rw-r--r--pkg/tcpip/transport/udp/forwarder.go4
-rw-r--r--pkg/tcpip/transport/udp/protocol.go14
-rw-r--r--pkg/tcpip/transport/udp/udp_test.go119
133 files changed, 7999 insertions, 4254 deletions
diff --git a/pkg/tcpip/BUILD b/pkg/tcpip/BUILD
index e7924e5c2..f979d22f0 100644
--- a/pkg/tcpip/BUILD
+++ b/pkg/tcpip/BUILD
@@ -18,6 +18,7 @@ go_template_instance(
go_library(
name = "tcpip",
srcs = [
+ "errors.go",
"sock_err_list.go",
"socketops.go",
"tcpip.go",
diff --git a/pkg/tcpip/adapters/gonet/gonet.go b/pkg/tcpip/adapters/gonet/gonet.go
index fdeec12d3..c188aaa18 100644
--- a/pkg/tcpip/adapters/gonet/gonet.go
+++ b/pkg/tcpip/adapters/gonet/gonet.go
@@ -16,6 +16,7 @@
package gonet
import (
+ "bytes"
"context"
"errors"
"io"
@@ -247,7 +248,7 @@ func NewTCPConn(wq *waiter.Queue, ep tcpip.Endpoint) *TCPConn {
func (l *TCPListener) Accept() (net.Conn, error) {
n, wq, err := l.ep.Accept(nil)
- if err == tcpip.ErrWouldBlock {
+ if _, ok := err.(*tcpip.ErrWouldBlock); ok {
// Create wait queue entry that notifies a channel.
waitEntry, notifyCh := waiter.NewChannelEntry(nil)
l.wq.EventRegister(&waitEntry, waiter.EventIn)
@@ -256,7 +257,7 @@ func (l *TCPListener) Accept() (net.Conn, error) {
for {
n, wq, err = l.ep.Accept(nil)
- if err != tcpip.ErrWouldBlock {
+ if _, ok := err.(*tcpip.ErrWouldBlock); !ok {
break
}
@@ -297,14 +298,14 @@ func commonRead(b []byte, ep tcpip.Endpoint, wq *waiter.Queue, deadline <-chan s
opts := tcpip.ReadOptions{NeedRemoteAddr: addr != nil}
res, err := ep.Read(&w, opts)
- if err == tcpip.ErrWouldBlock {
+ if _, ok := err.(*tcpip.ErrWouldBlock); ok {
// Create wait queue entry that notifies a channel.
waitEntry, notifyCh := waiter.NewChannelEntry(nil)
wq.EventRegister(&waitEntry, waiter.EventIn)
defer wq.EventUnregister(&waitEntry)
for {
res, err = ep.Read(&w, opts)
- if err != tcpip.ErrWouldBlock {
+ if _, ok := err.(*tcpip.ErrWouldBlock); !ok {
break
}
select {
@@ -315,7 +316,7 @@ func commonRead(b []byte, ep tcpip.Endpoint, wq *waiter.Queue, deadline <-chan s
}
}
- if err == tcpip.ErrClosedForReceive {
+ if _, ok := err.(*tcpip.ErrClosedForReceive); ok {
return 0, io.EOF
}
@@ -354,10 +355,8 @@ func (c *TCPConn) Write(b []byte) (int, error) {
default:
}
- v := buffer.NewViewFromBytes(b)
-
// We must handle two soft failure conditions simultaneously:
- // 1. Write may write nothing and return tcpip.ErrWouldBlock.
+ // 1. Write may write nothing and return *tcpip.ErrWouldBlock.
// If this happens, we need to register for notifications if we have
// not already and wait to try again.
// 2. Write may write fewer than the full number of bytes and return
@@ -368,22 +367,23 @@ func (c *TCPConn) Write(b []byte) (int, error) {
// There is no guarantee that all of the condition #1s will occur before
// all of the condition #2s or visa-versa.
var (
- err *tcpip.Error
- nbytes int
- reg bool
- notifyCh chan struct{}
+ r bytes.Reader
+ nbytes int
+ entry waiter.Entry
+ ch <-chan struct{}
)
- for nbytes < len(b) && (err == tcpip.ErrWouldBlock || err == nil) {
- if err == tcpip.ErrWouldBlock {
- if !reg {
- // Only register once.
- reg = true
-
- // Create wait queue entry that notifies a channel.
- var waitEntry waiter.Entry
- waitEntry, notifyCh = waiter.NewChannelEntry(nil)
- c.wq.EventRegister(&waitEntry, waiter.EventOut)
- defer c.wq.EventUnregister(&waitEntry)
+ for nbytes != len(b) {
+ r.Reset(b[nbytes:])
+ n, err := c.ep.Write(&r, tcpip.WriteOptions{})
+ nbytes += int(n)
+ switch err.(type) {
+ case nil:
+ case *tcpip.ErrWouldBlock:
+ if ch == nil {
+ entry, ch = waiter.NewChannelEntry(nil)
+
+ c.wq.EventRegister(&entry, waiter.EventOut)
+ defer c.wq.EventUnregister(&entry)
} else {
// Don't wait immediately after registration in case more data
// became available between when we last checked and when we setup
@@ -391,22 +391,15 @@ func (c *TCPConn) Write(b []byte) (int, error) {
select {
case <-deadline:
return nbytes, c.newOpError("write", &timeoutError{})
- case <-notifyCh:
+ case <-ch:
+ continue
}
}
+ default:
+ return nbytes, c.newOpError("write", errors.New(err.String()))
}
-
- var n int64
- n, err = c.ep.Write(tcpip.SlicePayload(v), tcpip.WriteOptions{})
- nbytes += int(n)
- v.TrimFront(int(n))
- }
-
- if err == nil {
- return nbytes, nil
}
-
- return nbytes, c.newOpError("write", errors.New(err.String()))
+ return nbytes, nil
}
// Close implements net.Conn.Close.
@@ -502,7 +495,7 @@ func DialContextTCP(ctx context.Context, s *stack.Stack, addr tcpip.FullAddress,
}
err = ep.Connect(addr)
- if err == tcpip.ErrConnectStarted {
+ if _, ok := err.(*tcpip.ErrConnectStarted); ok {
select {
case <-ctx.Done():
ep.Close()
@@ -644,17 +637,19 @@ func (c *UDPConn) WriteTo(b []byte, addr net.Addr) (int, error) {
}
// If we're being called by Write, there is no addr
- wopts := tcpip.WriteOptions{}
+ writeOptions := tcpip.WriteOptions{}
if addr != nil {
ua := addr.(*net.UDPAddr)
- wopts.To = &tcpip.FullAddress{Addr: tcpip.Address(ua.IP), Port: uint16(ua.Port)}
+ writeOptions.To = &tcpip.FullAddress{
+ Addr: tcpip.Address(ua.IP),
+ Port: uint16(ua.Port),
+ }
}
- v := buffer.NewView(len(b))
- copy(v, b)
-
- n, err := c.ep.Write(tcpip.SlicePayload(v), wopts)
- if err == tcpip.ErrWouldBlock {
+ var r bytes.Reader
+ r.Reset(b)
+ n, err := c.ep.Write(&r, writeOptions)
+ if _, ok := err.(*tcpip.ErrWouldBlock); ok {
// Create wait queue entry that notifies a channel.
waitEntry, notifyCh := waiter.NewChannelEntry(nil)
c.wq.EventRegister(&waitEntry, waiter.EventOut)
@@ -666,8 +661,8 @@ func (c *UDPConn) WriteTo(b []byte, addr net.Addr) (int, error) {
case <-notifyCh:
}
- n, err = c.ep.Write(tcpip.SlicePayload(v), wopts)
- if err != tcpip.ErrWouldBlock {
+ n, err = c.ep.Write(&r, writeOptions)
+ if _, ok := err.(*tcpip.ErrWouldBlock); !ok {
break
}
}
diff --git a/pkg/tcpip/adapters/gonet/gonet_test.go b/pkg/tcpip/adapters/gonet/gonet_test.go
index b196324c7..2b3ea4bdf 100644
--- a/pkg/tcpip/adapters/gonet/gonet_test.go
+++ b/pkg/tcpip/adapters/gonet/gonet_test.go
@@ -58,7 +58,7 @@ func TestTimeouts(t *testing.T) {
}
}
-func newLoopbackStack() (*stack.Stack, *tcpip.Error) {
+func newLoopbackStack() (*stack.Stack, tcpip.Error) {
// Create the stack and add a NIC.
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol},
@@ -94,7 +94,7 @@ type testConnection struct {
ep tcpip.Endpoint
}
-func connect(s *stack.Stack, addr tcpip.FullAddress) (*testConnection, *tcpip.Error) {
+func connect(s *stack.Stack, addr tcpip.FullAddress) (*testConnection, tcpip.Error) {
wq := &waiter.Queue{}
ep, err := s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq)
if err != nil {
@@ -105,7 +105,7 @@ func connect(s *stack.Stack, addr tcpip.FullAddress) (*testConnection, *tcpip.Er
wq.EventRegister(&entry, waiter.EventOut)
err = ep.Connect(addr)
- if err == tcpip.ErrConnectStarted {
+ if _, ok := err.(*tcpip.ErrConnectStarted); ok {
<-ch
err = ep.LastError()
}
@@ -660,11 +660,13 @@ func TestTCPDialError(t *testing.T) {
ip := tcpip.Address(net.IPv4(169, 254, 10, 1).To4())
addr := tcpip.FullAddress{NICID, ip, 11211}
- _, err := DialTCP(s, addr, ipv4.ProtocolNumber)
- got, ok := err.(*net.OpError)
- want := tcpip.ErrNoRoute
- if !ok || got.Err.Error() != want.String() {
- t.Errorf("Got DialTCP() = %v, want = %v", err, tcpip.ErrNoRoute)
+ switch _, err := DialTCP(s, addr, ipv4.ProtocolNumber); err := err.(type) {
+ case *net.OpError:
+ if err.Err.Error() != (&tcpip.ErrNoRoute{}).String() {
+ t.Errorf("got DialTCP() = %s, want = %s", err, &tcpip.ErrNoRoute{})
+ }
+ default:
+ t.Errorf("got DialTCP(...) = %v, want %s", err, &tcpip.ErrNoRoute{})
}
}
diff --git a/pkg/tcpip/checker/checker.go b/pkg/tcpip/checker/checker.go
index 0ac2000ca..07b4393a4 100644
--- a/pkg/tcpip/checker/checker.go
+++ b/pkg/tcpip/checker/checker.go
@@ -1,4 +1,4 @@
-// Copyright 2018 The gVisor Authors.
+// Copyright 2021 The gVisor Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -234,7 +234,7 @@ func IPv4RouterAlert() NetworkChecker {
for {
opt, done, err := iterator.Next()
if err != nil {
- t.Fatalf("error acquiring next IPv4 option %s", err)
+ t.Fatalf("error acquiring next IPv4 option at offset %d", err.Pointer)
}
if done {
break
diff --git a/pkg/tcpip/errors.go b/pkg/tcpip/errors.go
new file mode 100644
index 000000000..af46da1d2
--- /dev/null
+++ b/pkg/tcpip/errors.go
@@ -0,0 +1,538 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tcpip
+
+import (
+ "fmt"
+)
+
+// Error represents an error in the netstack error space.
+//
+// The error interface is intentionally omitted to avoid loss of type
+// information that would occur if these errors were passed as error.
+type Error interface {
+ isError()
+
+ // IgnoreStats indicates whether this error should be included in failure
+ // counts in tcpip.Stats structs.
+ IgnoreStats() bool
+
+ fmt.Stringer
+}
+
+// ErrAborted indicates the operation was aborted.
+//
+// +stateify savable
+type ErrAborted struct{}
+
+func (*ErrAborted) isError() {}
+
+// IgnoreStats implements Error.
+func (*ErrAborted) IgnoreStats() bool {
+ return false
+}
+func (*ErrAborted) String() string {
+ return "operation aborted"
+}
+
+// ErrAddressFamilyNotSupported indicates the operation does not support the
+// given address family.
+//
+// +stateify savable
+type ErrAddressFamilyNotSupported struct{}
+
+func (*ErrAddressFamilyNotSupported) isError() {}
+
+// IgnoreStats implements Error.
+func (*ErrAddressFamilyNotSupported) IgnoreStats() bool {
+ return false
+}
+func (*ErrAddressFamilyNotSupported) String() string {
+ return "address family not supported by protocol"
+}
+
+// ErrAlreadyBound indicates the endpoint is already bound.
+//
+// +stateify savable
+type ErrAlreadyBound struct{}
+
+func (*ErrAlreadyBound) isError() {}
+
+// IgnoreStats implements Error.
+func (*ErrAlreadyBound) IgnoreStats() bool {
+ return true
+}
+func (*ErrAlreadyBound) String() string { return "endpoint already bound" }
+
+// ErrAlreadyConnected indicates the endpoint is already connected.
+//
+// +stateify savable
+type ErrAlreadyConnected struct{}
+
+func (*ErrAlreadyConnected) isError() {}
+
+// IgnoreStats implements Error.
+func (*ErrAlreadyConnected) IgnoreStats() bool {
+ return true
+}
+func (*ErrAlreadyConnected) String() string { return "endpoint is already connected" }
+
+// ErrAlreadyConnecting indicates the endpoint is already connecting.
+//
+// +stateify savable
+type ErrAlreadyConnecting struct{}
+
+func (*ErrAlreadyConnecting) isError() {}
+
+// IgnoreStats implements Error.
+func (*ErrAlreadyConnecting) IgnoreStats() bool {
+ return true
+}
+func (*ErrAlreadyConnecting) String() string { return "endpoint is already connecting" }
+
+// ErrBadAddress indicates a bad address was provided.
+//
+// +stateify savable
+type ErrBadAddress struct{}
+
+func (*ErrBadAddress) isError() {}
+
+// IgnoreStats implements Error.
+func (*ErrBadAddress) IgnoreStats() bool {
+ return false
+}
+func (*ErrBadAddress) String() string { return "bad address" }
+
+// ErrBadBuffer indicates a bad buffer was provided.
+//
+// +stateify savable
+type ErrBadBuffer struct{}
+
+func (*ErrBadBuffer) isError() {}
+
+// IgnoreStats implements Error.
+func (*ErrBadBuffer) IgnoreStats() bool {
+ return false
+}
+func (*ErrBadBuffer) String() string { return "bad buffer" }
+
+// ErrBadLocalAddress indicates a bad local address was provided.
+//
+// +stateify savable
+type ErrBadLocalAddress struct{}
+
+func (*ErrBadLocalAddress) isError() {}
+
+// IgnoreStats implements Error.
+func (*ErrBadLocalAddress) IgnoreStats() bool {
+ return false
+}
+func (*ErrBadLocalAddress) String() string { return "bad local address" }
+
+// ErrBroadcastDisabled indicates broadcast is not enabled on the endpoint.
+//
+// +stateify savable
+type ErrBroadcastDisabled struct{}
+
+func (*ErrBroadcastDisabled) isError() {}
+
+// IgnoreStats implements Error.
+func (*ErrBroadcastDisabled) IgnoreStats() bool {
+ return false
+}
+func (*ErrBroadcastDisabled) String() string { return "broadcast socket option disabled" }
+
+// ErrClosedForReceive indicates the endpoint is closed for incoming data.
+//
+// +stateify savable
+type ErrClosedForReceive struct{}
+
+func (*ErrClosedForReceive) isError() {}
+
+// IgnoreStats implements Error.
+func (*ErrClosedForReceive) IgnoreStats() bool {
+ return false
+}
+func (*ErrClosedForReceive) String() string { return "endpoint is closed for receive" }
+
+// ErrClosedForSend indicates the endpoint is closed for outgoing data.
+//
+// +stateify savable
+type ErrClosedForSend struct{}
+
+func (*ErrClosedForSend) isError() {}
+
+// IgnoreStats implements Error.
+func (*ErrClosedForSend) IgnoreStats() bool {
+ return false
+}
+func (*ErrClosedForSend) String() string { return "endpoint is closed for send" }
+
+// ErrConnectStarted indicates the endpoint is connecting asynchronously.
+//
+// +stateify savable
+type ErrConnectStarted struct{}
+
+func (*ErrConnectStarted) isError() {}
+
+// IgnoreStats implements Error.
+func (*ErrConnectStarted) IgnoreStats() bool {
+ return true
+}
+func (*ErrConnectStarted) String() string { return "connection attempt started" }
+
+// ErrConnectionAborted indicates the connection was aborted.
+//
+// +stateify savable
+type ErrConnectionAborted struct{}
+
+func (*ErrConnectionAborted) isError() {}
+
+// IgnoreStats implements Error.
+func (*ErrConnectionAborted) IgnoreStats() bool {
+ return false
+}
+func (*ErrConnectionAborted) String() string { return "connection aborted" }
+
+// ErrConnectionRefused indicates the connection was refused.
+//
+// +stateify savable
+type ErrConnectionRefused struct{}
+
+func (*ErrConnectionRefused) isError() {}
+
+// IgnoreStats implements Error.
+func (*ErrConnectionRefused) IgnoreStats() bool {
+ return false
+}
+func (*ErrConnectionRefused) String() string { return "connection was refused" }
+
+// ErrConnectionReset indicates the connection was reset.
+//
+// +stateify savable
+type ErrConnectionReset struct{}
+
+func (*ErrConnectionReset) isError() {}
+
+// IgnoreStats implements Error.
+func (*ErrConnectionReset) IgnoreStats() bool {
+ return false
+}
+func (*ErrConnectionReset) String() string { return "connection reset by peer" }
+
+// ErrDestinationRequired indicates the operation requires a destination
+// address, and one was not provided.
+//
+// +stateify savable
+type ErrDestinationRequired struct{}
+
+func (*ErrDestinationRequired) isError() {}
+
+// IgnoreStats implements Error.
+func (*ErrDestinationRequired) IgnoreStats() bool {
+ return false
+}
+func (*ErrDestinationRequired) String() string { return "destination address is required" }
+
+// ErrDuplicateAddress indicates the operation encountered a duplicate address.
+//
+// +stateify savable
+type ErrDuplicateAddress struct{}
+
+func (*ErrDuplicateAddress) isError() {}
+
+// IgnoreStats implements Error.
+func (*ErrDuplicateAddress) IgnoreStats() bool {
+ return false
+}
+func (*ErrDuplicateAddress) String() string { return "duplicate address" }
+
+// ErrDuplicateNICID indicates the operation encountered a duplicate NIC ID.
+//
+// +stateify savable
+type ErrDuplicateNICID struct{}
+
+func (*ErrDuplicateNICID) isError() {}
+
+// IgnoreStats implements Error.
+func (*ErrDuplicateNICID) IgnoreStats() bool {
+ return false
+}
+func (*ErrDuplicateNICID) String() string { return "duplicate nic id" }
+
+// ErrInvalidEndpointState indicates the endpoint is in an invalid state.
+//
+// +stateify savable
+type ErrInvalidEndpointState struct{}
+
+func (*ErrInvalidEndpointState) isError() {}
+
+// IgnoreStats implements Error.
+func (*ErrInvalidEndpointState) IgnoreStats() bool {
+ return false
+}
+func (*ErrInvalidEndpointState) String() string { return "endpoint is in invalid state" }
+
+// ErrInvalidOptionValue indicates an invalid option value was provided.
+//
+// +stateify savable
+type ErrInvalidOptionValue struct{}
+
+func (*ErrInvalidOptionValue) isError() {}
+
+// IgnoreStats implements Error.
+func (*ErrInvalidOptionValue) IgnoreStats() bool {
+ return false
+}
+func (*ErrInvalidOptionValue) String() string { return "invalid option value specified" }
+
+// ErrMalformedHeader indicates the operation encountered a malformed header.
+//
+// +stateify savable
+type ErrMalformedHeader struct{}
+
+func (*ErrMalformedHeader) isError() {}
+
+// IgnoreStats implements Error.
+func (*ErrMalformedHeader) IgnoreStats() bool {
+ return false
+}
+func (*ErrMalformedHeader) String() string { return "header is malformed" }
+
+// ErrMessageTooLong indicates the operation encountered a message whose length
+// exceeds the maximum permitted.
+//
+// +stateify savable
+type ErrMessageTooLong struct{}
+
+func (*ErrMessageTooLong) isError() {}
+
+// IgnoreStats implements Error.
+func (*ErrMessageTooLong) IgnoreStats() bool {
+ return false
+}
+func (*ErrMessageTooLong) String() string { return "message too long" }
+
+// ErrNetworkUnreachable indicates the operation is not able to reach the
+// destination network.
+//
+// +stateify savable
+type ErrNetworkUnreachable struct{}
+
+func (*ErrNetworkUnreachable) isError() {}
+
+// IgnoreStats implements Error.
+func (*ErrNetworkUnreachable) IgnoreStats() bool {
+ return false
+}
+func (*ErrNetworkUnreachable) String() string { return "network is unreachable" }
+
+// ErrNoBufferSpace indicates no buffer space is available.
+//
+// +stateify savable
+type ErrNoBufferSpace struct{}
+
+func (*ErrNoBufferSpace) isError() {}
+
+// IgnoreStats implements Error.
+func (*ErrNoBufferSpace) IgnoreStats() bool {
+ return false
+}
+func (*ErrNoBufferSpace) String() string { return "no buffer space available" }
+
+// ErrNoPortAvailable indicates no port could be allocated for the operation.
+//
+// +stateify savable
+type ErrNoPortAvailable struct{}
+
+func (*ErrNoPortAvailable) isError() {}
+
+// IgnoreStats implements Error.
+func (*ErrNoPortAvailable) IgnoreStats() bool {
+ return false
+}
+func (*ErrNoPortAvailable) String() string { return "no ports are available" }
+
+// ErrNoRoute indicates the operation is not able to find a route to the
+// destination.
+//
+// +stateify savable
+type ErrNoRoute struct{}
+
+func (*ErrNoRoute) isError() {}
+
+// IgnoreStats implements Error.
+func (*ErrNoRoute) IgnoreStats() bool {
+ return false
+}
+func (*ErrNoRoute) String() string { return "no route" }
+
+// ErrNoSuchFile is used to indicate that ENOENT should be returned the to
+// calling application.
+//
+// +stateify savable
+type ErrNoSuchFile struct{}
+
+func (*ErrNoSuchFile) isError() {}
+
+// IgnoreStats implements Error.
+func (*ErrNoSuchFile) IgnoreStats() bool {
+ return false
+}
+func (*ErrNoSuchFile) String() string { return "no such file" }
+
+// ErrNotConnected indicates the endpoint is not connected.
+//
+// +stateify savable
+type ErrNotConnected struct{}
+
+func (*ErrNotConnected) isError() {}
+
+// IgnoreStats implements Error.
+func (*ErrNotConnected) IgnoreStats() bool {
+ return false
+}
+func (*ErrNotConnected) String() string { return "endpoint not connected" }
+
+// ErrNotPermitted indicates the operation is not permitted.
+//
+// +stateify savable
+type ErrNotPermitted struct{}
+
+func (*ErrNotPermitted) isError() {}
+
+// IgnoreStats implements Error.
+func (*ErrNotPermitted) IgnoreStats() bool {
+ return false
+}
+func (*ErrNotPermitted) String() string { return "operation not permitted" }
+
+// ErrNotSupported indicates the operation is not supported.
+//
+// +stateify savable
+type ErrNotSupported struct{}
+
+func (*ErrNotSupported) isError() {}
+
+// IgnoreStats implements Error.
+func (*ErrNotSupported) IgnoreStats() bool {
+ return false
+}
+func (*ErrNotSupported) String() string { return "operation not supported" }
+
+// ErrPortInUse indicates the provided port is in use.
+//
+// +stateify savable
+type ErrPortInUse struct{}
+
+func (*ErrPortInUse) isError() {}
+
+// IgnoreStats implements Error.
+func (*ErrPortInUse) IgnoreStats() bool {
+ return false
+}
+func (*ErrPortInUse) String() string { return "port is in use" }
+
+// ErrQueueSizeNotSupported indicates the endpoint does not allow queue size
+// operation.
+//
+// +stateify savable
+type ErrQueueSizeNotSupported struct{}
+
+func (*ErrQueueSizeNotSupported) isError() {}
+
+// IgnoreStats implements Error.
+func (*ErrQueueSizeNotSupported) IgnoreStats() bool {
+ return false
+}
+func (*ErrQueueSizeNotSupported) String() string { return "queue size querying not supported" }
+
+// ErrTimeout indicates the operation timed out.
+//
+// +stateify savable
+type ErrTimeout struct{}
+
+func (*ErrTimeout) isError() {}
+
+// IgnoreStats implements Error.
+func (*ErrTimeout) IgnoreStats() bool {
+ return false
+}
+func (*ErrTimeout) String() string { return "operation timed out" }
+
+// ErrUnknownDevice indicates an unknown device identifier was provided.
+//
+// +stateify savable
+type ErrUnknownDevice struct{}
+
+func (*ErrUnknownDevice) isError() {}
+
+// IgnoreStats implements Error.
+func (*ErrUnknownDevice) IgnoreStats() bool {
+ return false
+}
+func (*ErrUnknownDevice) String() string { return "unknown device" }
+
+// ErrUnknownNICID indicates an unknown NIC ID was provided.
+//
+// +stateify savable
+type ErrUnknownNICID struct{}
+
+func (*ErrUnknownNICID) isError() {}
+
+// IgnoreStats implements Error.
+func (*ErrUnknownNICID) IgnoreStats() bool {
+ return false
+}
+func (*ErrUnknownNICID) String() string { return "unknown nic id" }
+
+// ErrUnknownProtocol indicates an unknown protocol was requested.
+//
+// +stateify savable
+type ErrUnknownProtocol struct{}
+
+func (*ErrUnknownProtocol) isError() {}
+
+// IgnoreStats implements Error.
+func (*ErrUnknownProtocol) IgnoreStats() bool {
+ return false
+}
+func (*ErrUnknownProtocol) String() string { return "unknown protocol" }
+
+// ErrUnknownProtocolOption indicates an unknown protocol option was provided.
+//
+// +stateify savable
+type ErrUnknownProtocolOption struct{}
+
+func (*ErrUnknownProtocolOption) isError() {}
+
+// IgnoreStats implements Error.
+func (*ErrUnknownProtocolOption) IgnoreStats() bool {
+ return false
+}
+func (*ErrUnknownProtocolOption) String() string { return "unknown option for protocol" }
+
+// ErrWouldBlock indicates the operation would block.
+//
+// +stateify savable
+type ErrWouldBlock struct{}
+
+func (*ErrWouldBlock) isError() {}
+
+// IgnoreStats implements Error.
+func (*ErrWouldBlock) IgnoreStats() bool {
+ return true
+}
+func (*ErrWouldBlock) String() string { return "operation would block" }
diff --git a/pkg/tcpip/faketime/BUILD b/pkg/tcpip/faketime/BUILD
index 114d43df3..bb9d44aff 100644
--- a/pkg/tcpip/faketime/BUILD
+++ b/pkg/tcpip/faketime/BUILD
@@ -6,10 +6,7 @@ go_library(
name = "faketime",
srcs = ["faketime.go"],
visibility = ["//visibility:public"],
- deps = [
- "//pkg/tcpip",
- "@com_github_dpjacques_clockwork//:go_default_library",
- ],
+ deps = ["//pkg/tcpip"],
)
go_test(
diff --git a/pkg/tcpip/faketime/faketime.go b/pkg/tcpip/faketime/faketime.go
index f7a4fbde1..fb819d7a8 100644
--- a/pkg/tcpip/faketime/faketime.go
+++ b/pkg/tcpip/faketime/faketime.go
@@ -17,10 +17,10 @@ package faketime
import (
"container/heap"
+ "fmt"
"sync"
"time"
- "github.com/dpjacques/clockwork"
"gvisor.dev/gvisor/pkg/tcpip"
)
@@ -44,38 +44,85 @@ func (*NullClock) AfterFunc(time.Duration, func()) tcpip.Timer {
return nil
}
+type notificationChannels struct {
+ mu struct {
+ sync.Mutex
+
+ ch []<-chan struct{}
+ }
+}
+
+func (n *notificationChannels) add(ch <-chan struct{}) {
+ n.mu.Lock()
+ defer n.mu.Unlock()
+ n.mu.ch = append(n.mu.ch, ch)
+}
+
+// wait returns once all the notification channels are readable.
+//
+// Channels that are added while waiting on existing channels will be waited on
+// as well.
+func (n *notificationChannels) wait() {
+ for {
+ n.mu.Lock()
+ ch := n.mu.ch
+ n.mu.ch = nil
+ n.mu.Unlock()
+
+ if len(ch) == 0 {
+ break
+ }
+
+ for _, c := range ch {
+ <-c
+ }
+ }
+}
+
// ManualClock implements tcpip.Clock and only advances manually with Advance
// method.
type ManualClock struct {
- clock clockwork.FakeClock
+ // runningTimers tracks the completion of timer callbacks that began running
+ // immediately upon their scheduling. It is used to ensure the proper ordering
+ // of timer callback dispatch.
+ runningTimers notificationChannels
+
+ mu struct {
+ sync.RWMutex
- // mu protects the fields below.
- mu sync.RWMutex
+ // now is the current (fake) time of the clock.
+ now time.Time
- // times is min-heap of times. A heap is used for quick retrieval of the next
- // upcoming time of scheduled work.
- times *timeHeap
+ // times is min-heap of times.
+ times timeHeap
- // waitGroups stores one WaitGroup for all work scheduled to execute at the
- // same time via AfterFunc. This allows parallel execution of all functions
- // passed to AfterFunc scheduled for the same time.
- waitGroups map[time.Time]*sync.WaitGroup
+ // timers holds the timers scheduled for each time.
+ timers map[time.Time]map[*manualTimer]struct{}
+ }
}
// NewManualClock creates a new ManualClock instance.
func NewManualClock() *ManualClock {
- return &ManualClock{
- clock: clockwork.NewFakeClock(),
- times: &timeHeap{},
- waitGroups: make(map[time.Time]*sync.WaitGroup),
- }
+ c := &ManualClock{}
+
+ c.mu.Lock()
+ defer c.mu.Unlock()
+
+ // Set the initial time to a non-zero value since the zero value is used to
+ // detect inactive timers.
+ c.mu.now = time.Unix(0, 0)
+ c.mu.timers = make(map[time.Time]map[*manualTimer]struct{})
+
+ return c
}
var _ tcpip.Clock = (*ManualClock)(nil)
// NowNanoseconds implements tcpip.Clock.NowNanoseconds.
func (mc *ManualClock) NowNanoseconds() int64 {
- return mc.clock.Now().UnixNano()
+ mc.mu.RLock()
+ defer mc.mu.RUnlock()
+ return mc.mu.now.UnixNano()
}
// NowMonotonic implements tcpip.Clock.NowMonotonic.
@@ -85,128 +132,203 @@ func (mc *ManualClock) NowMonotonic() int64 {
// AfterFunc implements tcpip.Clock.AfterFunc.
func (mc *ManualClock) AfterFunc(d time.Duration, f func()) tcpip.Timer {
- until := mc.clock.Now().Add(d)
- wg := mc.addWait(until)
- return &manualTimer{
+ mt := &manualTimer{
clock: mc,
- until: until,
- timer: mc.clock.AfterFunc(d, func() {
- defer wg.Done()
- f()
- }),
+ f: f,
}
-}
-// addWait adds an additional wait to the WaitGroup for parallel execution of
-// all work scheduled for t. Returns a reference to the WaitGroup modified.
-func (mc *ManualClock) addWait(t time.Time) *sync.WaitGroup {
- mc.mu.RLock()
- wg, ok := mc.waitGroups[t]
- mc.mu.RUnlock()
+ mc.mu.Lock()
+ defer mc.mu.Unlock()
+
+ mt.mu.Lock()
+ defer mt.mu.Unlock()
- if ok {
- wg.Add(1)
- return wg
+ mc.resetTimerLocked(mt, d)
+ return mt
+}
+
+// resetTimerLocked schedules a timer to be fired after the given duration.
+//
+// Precondition: mc.mu and mt.mu must be locked.
+func (mc *ManualClock) resetTimerLocked(mt *manualTimer, d time.Duration) {
+ if !mt.mu.firesAt.IsZero() {
+ panic("tried to reset an active timer")
}
- mc.mu.Lock()
- heap.Push(mc.times, t)
- mc.mu.Unlock()
+ t := mc.mu.now.Add(d)
- wg = &sync.WaitGroup{}
- wg.Add(1)
+ if !mc.mu.now.Before(t) {
+ // If the timer is scheduled to fire immediately, call its callback
+ // in a new goroutine immediately.
+ //
+ // It needs to be called in its own goroutine to escape its current
+ // execution context - like an actual timer.
+ ch := make(chan struct{})
+ mc.runningTimers.add(ch)
- mc.mu.Lock()
- mc.waitGroups[t] = wg
- mc.mu.Unlock()
+ go func() {
+ defer close(ch)
+
+ mt.f()
+ }()
- return wg
+ return
+ }
+
+ mt.mu.firesAt = t
+
+ timers, ok := mc.mu.timers[t]
+ if !ok {
+ timers = make(map[*manualTimer]struct{})
+ mc.mu.timers[t] = timers
+ heap.Push(&mc.mu.times, t)
+ }
+
+ timers[mt] = struct{}{}
}
-// removeWait removes a wait from the WaitGroup for parallel execution of all
-// work scheduled for t.
-func (mc *ManualClock) removeWait(t time.Time) {
- mc.mu.RLock()
- defer mc.mu.RUnlock()
+// stopTimerLocked stops a timer from firing.
+//
+// Precondition: mc.mu and mt.mu must be locked.
+func (mc *ManualClock) stopTimerLocked(mt *manualTimer) {
+ t := mt.mu.firesAt
+ mt.mu.firesAt = time.Time{}
+
+ if t.IsZero() {
+ panic("tried to stop an inactive timer")
+ }
- wg := mc.waitGroups[t]
- wg.Done()
+ timers, ok := mc.mu.timers[t]
+ if !ok {
+ err := fmt.Sprintf("tried to stop an active timer but the clock does not have anything scheduled for the timer @ t = %s %p\nScheduled timers @:", t.UTC(), mt)
+ for t := range mc.mu.timers {
+ err += fmt.Sprintf("%s\n", t.UTC())
+ }
+ panic(err)
+ }
+
+ if _, ok := timers[mt]; !ok {
+ panic(fmt.Sprintf("did not have an entry in timers for an active timer @ t = %s", t.UTC()))
+ }
+
+ delete(timers, mt)
+
+ if len(timers) == 0 {
+ delete(mc.mu.timers, t)
+ }
}
// Advance executes all work that have been scheduled to execute within d from
-// the current time. Blocks until all work has completed execution.
+// the current time. Blocks until all work has completed execution.
func (mc *ManualClock) Advance(d time.Duration) {
- // Block until all the work is done
- until := mc.clock.Now().Add(d)
- for {
- mc.mu.Lock()
- if mc.times.Len() == 0 {
- mc.mu.Unlock()
- break
- }
+ // We spawn goroutines for timers that were scheduled to fire at the time of
+ // being reset. Wait for those goroutines to complete before proceeding so
+ // that timer callbacks are called in the right order.
+ mc.runningTimers.wait()
- t := heap.Pop(mc.times).(time.Time)
+ mc.mu.Lock()
+ defer mc.mu.Unlock()
+
+ until := mc.mu.now.Add(d)
+ for mc.mu.times.Len() > 0 {
+ t := heap.Pop(&mc.mu.times).(time.Time)
if t.After(until) {
// No work to do
- heap.Push(mc.times, t)
- mc.mu.Unlock()
+ heap.Push(&mc.mu.times, t)
break
}
- mc.mu.Unlock()
- diff := t.Sub(mc.clock.Now())
- mc.clock.Advance(diff)
+ timers := mc.mu.timers[t]
+ delete(mc.mu.timers, t)
+
+ mc.mu.now = t
+
+ // Mark the timers as inactive since they will be fired.
+ //
+ // This needs to be done while holding mc's lock because we remove the entry
+ // in the map of timers for the current time. If an attempt to stop a
+ // timer is made after mc's lock was dropped but before the timer is
+ // marked inactive, we would panic since no entry exists for the time when
+ // the timer was expected to fire.
+ for mt := range timers {
+ mt.mu.Lock()
+ mt.mu.firesAt = time.Time{}
+ mt.mu.Unlock()
+ }
- mc.mu.RLock()
- wg := mc.waitGroups[t]
- mc.mu.RUnlock()
+ // Release the lock before calling the timer's callback fn since the
+ // callback fn might try to schedule a timer which requires obtaining
+ // mc's lock.
+ mc.mu.Unlock()
- wg.Wait()
+ for mt := range timers {
+ mt.f()
+ }
+ // The timer callbacks may have scheduled a timer to fire immediately.
+ // We spawn goroutines for these timers and need to wait for them to
+ // finish before proceeding so that timer callbacks are called in the
+ // right order.
+ mc.runningTimers.wait()
mc.mu.Lock()
- delete(mc.waitGroups, t)
- mc.mu.Unlock()
}
- if now := mc.clock.Now(); until.After(now) {
- mc.clock.Advance(until.Sub(now))
+
+ mc.mu.now = until
+}
+
+func (mc *ManualClock) resetTimer(mt *manualTimer, d time.Duration) {
+ mc.mu.Lock()
+ defer mc.mu.Unlock()
+
+ mt.mu.Lock()
+ defer mt.mu.Unlock()
+
+ if !mt.mu.firesAt.IsZero() {
+ mc.stopTimerLocked(mt)
}
+
+ mc.resetTimerLocked(mt, d)
+}
+
+func (mc *ManualClock) stopTimer(mt *manualTimer) bool {
+ mc.mu.Lock()
+ defer mc.mu.Unlock()
+
+ mt.mu.Lock()
+ defer mt.mu.Unlock()
+
+ if mt.mu.firesAt.IsZero() {
+ return false
+ }
+
+ mc.stopTimerLocked(mt)
+ return true
}
type manualTimer struct {
clock *ManualClock
- timer clockwork.Timer
+ f func()
- mu sync.RWMutex
- until time.Time
+ mu struct {
+ sync.Mutex
+
+ // firesAt is the time when the timer will fire.
+ //
+ // Zero only when the timer is not active.
+ firesAt time.Time
+ }
}
var _ tcpip.Timer = (*manualTimer)(nil)
// Reset implements tcpip.Timer.Reset.
-func (t *manualTimer) Reset(d time.Duration) {
- if !t.timer.Reset(d) {
- return
- }
-
- t.mu.Lock()
- defer t.mu.Unlock()
-
- t.clock.removeWait(t.until)
- t.until = t.clock.clock.Now().Add(d)
- t.clock.addWait(t.until)
+func (mt *manualTimer) Reset(d time.Duration) {
+ mt.clock.resetTimer(mt, d)
}
// Stop implements tcpip.Timer.Stop.
-func (t *manualTimer) Stop() bool {
- if !t.timer.Stop() {
- return false
- }
-
- t.mu.RLock()
- defer t.mu.RUnlock()
-
- t.clock.removeWait(t.until)
- return true
+func (mt *manualTimer) Stop() bool {
+ return mt.clock.stopTimer(mt)
}
type timeHeap []time.Time
diff --git a/pkg/tcpip/header/ipv4.go b/pkg/tcpip/header/ipv4.go
index e6103f4bc..48ca60319 100644
--- a/pkg/tcpip/header/ipv4.go
+++ b/pkg/tcpip/header/ipv4.go
@@ -1,4 +1,4 @@
-// Copyright 2018 The gVisor Authors.
+// Copyright 2021 The gVisor Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -16,7 +16,6 @@ package header
import (
"encoding/binary"
- "errors"
"fmt"
"gvisor.dev/gvisor/pkg/tcpip"
@@ -481,15 +480,13 @@ const (
IPv4OptionLengthOffset = 1
)
-// Potential errors when parsing generic IP options.
-var (
- ErrIPv4OptZeroLength = errors.New("zero length IP option")
- ErrIPv4OptDuplicate = errors.New("duplicate IP option")
- ErrIPv4OptInvalid = errors.New("invalid IP option")
- ErrIPv4OptMalformed = errors.New("malformed IP option")
- ErrIPv4OptionTruncated = errors.New("truncated IP option")
- ErrIPv4OptionAddress = errors.New("bad IP option address")
-)
+// IPv4OptParameterProblem indicates that a Parameter Problem message
+// should be generated, and gives the offset in the current entity
+// that should be used in that packet.
+type IPv4OptParameterProblem struct {
+ Pointer uint8
+ NeedICMP bool
+}
// IPv4Option is an interface representing various option types.
type IPv4Option interface {
@@ -583,8 +580,9 @@ func (i *IPv4OptionIterator) Finalize() IPv4Options {
// It returns
// - A slice of bytes holding the next option or nil if there is error.
// - A boolean which is true if parsing of all the options is complete.
-// - An error which is non-nil if an error condition was encountered.
-func (i *IPv4OptionIterator) Next() (IPv4Option, bool, error) {
+// Undefined in the case of error.
+// - An error indication which is non-nil if an error condition was found.
+func (i *IPv4OptionIterator) Next() (IPv4Option, bool, *IPv4OptParameterProblem) {
// The opts slice gets shorter as we process the options. When we have no
// bytes left we are done.
if len(i.options) == 0 {
@@ -606,24 +604,22 @@ func (i *IPv4OptionIterator) Next() (IPv4Option, bool, error) {
// There are no more single byte options defined. All the rest have a length
// field so we need to sanity check it.
if len(i.options) == 1 {
- return nil, true, ErrIPv4OptMalformed
+ return nil, false, &IPv4OptParameterProblem{
+ Pointer: i.ErrCursor,
+ NeedICMP: true,
+ }
}
optLen := i.options[IPv4OptionLengthOffset]
- if optLen == 0 {
- i.ErrCursor++
- return nil, true, ErrIPv4OptZeroLength
- }
+ if optLen <= IPv4OptionLengthOffset || optLen > uint8(len(i.options)) {
+ // The actual error is in the length (2nd byte of the option) but we
+ // return the start of the option for compatibility with Linux.
- if optLen == 1 {
- i.ErrCursor++
- return nil, true, ErrIPv4OptMalformed
- }
-
- if optLen > uint8(len(i.options)) {
- i.ErrCursor++
- return nil, true, ErrIPv4OptionTruncated
+ return nil, false, &IPv4OptParameterProblem{
+ Pointer: i.ErrCursor,
+ NeedICMP: true,
+ }
}
optionBody := i.options[:optLen]
@@ -635,7 +631,10 @@ func (i *IPv4OptionIterator) Next() (IPv4Option, bool, error) {
case IPv4OptionTimestampType:
if optLen < IPv4OptionTimestampHdrLength {
i.ErrCursor++
- return nil, true, ErrIPv4OptMalformed
+ return nil, false, &IPv4OptParameterProblem{
+ Pointer: i.ErrCursor,
+ NeedICMP: true,
+ }
}
retval := IPv4OptionTimestamp(optionBody)
return &retval, false, nil
@@ -643,7 +642,10 @@ func (i *IPv4OptionIterator) Next() (IPv4Option, bool, error) {
case IPv4OptionRecordRouteType:
if optLen < IPv4OptionRecordRouteHdrLength {
i.ErrCursor++
- return nil, true, ErrIPv4OptMalformed
+ return nil, false, &IPv4OptParameterProblem{
+ Pointer: i.ErrCursor,
+ NeedICMP: true,
+ }
}
retval := IPv4OptionRecordRoute(optionBody)
return &retval, false, nil
diff --git a/pkg/tcpip/header/ipv6.go b/pkg/tcpip/header/ipv6.go
index 5580d6a78..f2403978c 100644
--- a/pkg/tcpip/header/ipv6.go
+++ b/pkg/tcpip/header/ipv6.go
@@ -453,9 +453,9 @@ const (
)
// ScopeForIPv6Address returns the scope for an IPv6 address.
-func ScopeForIPv6Address(addr tcpip.Address) (IPv6AddressScope, *tcpip.Error) {
+func ScopeForIPv6Address(addr tcpip.Address) (IPv6AddressScope, tcpip.Error) {
if len(addr) != IPv6AddressSize {
- return GlobalScope, tcpip.ErrBadAddress
+ return GlobalScope, &tcpip.ErrBadAddress{}
}
switch {
diff --git a/pkg/tcpip/header/ipv6_test.go b/pkg/tcpip/header/ipv6_test.go
index e3fbd64f3..f10f446a6 100644
--- a/pkg/tcpip/header/ipv6_test.go
+++ b/pkg/tcpip/header/ipv6_test.go
@@ -299,7 +299,7 @@ func TestScopeForIPv6Address(t *testing.T) {
name string
addr tcpip.Address
scope header.IPv6AddressScope
- err *tcpip.Error
+ err tcpip.Error
}{
{
name: "Unique Local",
@@ -329,15 +329,15 @@ func TestScopeForIPv6Address(t *testing.T) {
name: "IPv4",
addr: "\x01\x02\x03\x04",
scope: header.GlobalScope,
- err: tcpip.ErrBadAddress,
+ err: &tcpip.ErrBadAddress{},
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
got, err := header.ScopeForIPv6Address(test.addr)
- if err != test.err {
- t.Errorf("got header.IsV6UniqueLocalAddress(%s) = (_, %v), want = (_, %v)", test.addr, err, test.err)
+ if diff := cmp.Diff(test.err, err); diff != "" {
+ t.Errorf("unexpected error from header.IsV6UniqueLocalAddress(%s), (-want, +got):\n%s", test.addr, diff)
}
if got != test.scope {
t.Errorf("got header.IsV6UniqueLocalAddress(%s) = (%d, _), want = (%d, _)", test.addr, got, test.scope)
diff --git a/pkg/tcpip/link/channel/channel.go b/pkg/tcpip/link/channel/channel.go
index a068d93a4..cd76272de 100644
--- a/pkg/tcpip/link/channel/channel.go
+++ b/pkg/tcpip/link/channel/channel.go
@@ -229,7 +229,7 @@ func (e *Endpoint) LinkAddress() tcpip.LinkAddress {
}
// WritePacket stores outbound packets into the channel.
-func (e *Endpoint) WritePacket(r stack.RouteInfo, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error {
+func (e *Endpoint) WritePacket(r stack.RouteInfo, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error {
p := PacketInfo{
Pkt: pkt,
Proto: protocol,
@@ -243,7 +243,7 @@ func (e *Endpoint) WritePacket(r stack.RouteInfo, gso *stack.GSO, protocol tcpip
}
// WritePackets stores outbound packets into the channel.
-func (e *Endpoint) WritePackets(r stack.RouteInfo, gso *stack.GSO, pkts stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) {
+func (e *Endpoint) WritePackets(r stack.RouteInfo, gso *stack.GSO, pkts stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, tcpip.Error) {
n := 0
for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() {
p := PacketInfo{
diff --git a/pkg/tcpip/link/ethernet/ethernet.go b/pkg/tcpip/link/ethernet/ethernet.go
index 2f2d9d4ac..d873766a6 100644
--- a/pkg/tcpip/link/ethernet/ethernet.go
+++ b/pkg/tcpip/link/ethernet/ethernet.go
@@ -61,13 +61,13 @@ func (e *Endpoint) Capabilities() stack.LinkEndpointCapabilities {
}
// WritePacket implements stack.LinkEndpoint.
-func (e *Endpoint) WritePacket(r stack.RouteInfo, gso *stack.GSO, proto tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error {
+func (e *Endpoint) WritePacket(r stack.RouteInfo, gso *stack.GSO, proto tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error {
e.AddHeader(e.Endpoint.LinkAddress(), r.RemoteLinkAddress, proto, pkt)
return e.Endpoint.WritePacket(r, gso, proto, pkt)
}
// WritePackets implements stack.LinkEndpoint.
-func (e *Endpoint) WritePackets(r stack.RouteInfo, gso *stack.GSO, pkts stack.PacketBufferList, proto tcpip.NetworkProtocolNumber) (int, *tcpip.Error) {
+func (e *Endpoint) WritePackets(r stack.RouteInfo, gso *stack.GSO, pkts stack.PacketBufferList, proto tcpip.NetworkProtocolNumber) (int, tcpip.Error) {
linkAddr := e.Endpoint.LinkAddress()
for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() {
diff --git a/pkg/tcpip/link/fdbased/BUILD b/pkg/tcpip/link/fdbased/BUILD
index 10072eac1..ae1394ebf 100644
--- a/pkg/tcpip/link/fdbased/BUILD
+++ b/pkg/tcpip/link/fdbased/BUILD
@@ -35,7 +35,6 @@ go_test(
"//pkg/tcpip",
"//pkg/tcpip/buffer",
"//pkg/tcpip/header",
- "//pkg/tcpip/link/rawfile",
"//pkg/tcpip/stack",
"@com_github_google_go_cmp//cmp:go_default_library",
],
diff --git a/pkg/tcpip/link/fdbased/endpoint.go b/pkg/tcpip/link/fdbased/endpoint.go
index f86c383d8..0164d851b 100644
--- a/pkg/tcpip/link/fdbased/endpoint.go
+++ b/pkg/tcpip/link/fdbased/endpoint.go
@@ -57,7 +57,7 @@ import (
// linkDispatcher reads packets from the link FD and dispatches them to the
// NetworkDispatcher.
type linkDispatcher interface {
- dispatch() (bool, *tcpip.Error)
+ dispatch() (bool, tcpip.Error)
}
// PacketDispatchMode are the various supported methods of receiving and
@@ -118,7 +118,7 @@ type endpoint struct {
// closed is a function to be called when the FD's peer (if any) closes
// its end of the communication pipe.
- closed func(*tcpip.Error)
+ closed func(tcpip.Error)
inboundDispatchers []linkDispatcher
dispatcher stack.NetworkDispatcher
@@ -149,7 +149,7 @@ type Options struct {
// ClosedFunc is a function to be called when an endpoint's peer (if
// any) closes its end of the communication pipe.
- ClosedFunc func(*tcpip.Error)
+ ClosedFunc func(tcpip.Error)
// Address is the link address for this endpoint. Only used if
// EthernetHeader is true.
@@ -411,7 +411,7 @@ func (e *endpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.Net
// WritePacket writes outbound packets to the file descriptor. If it is not
// currently writable, the packet is dropped.
-func (e *endpoint) WritePacket(r stack.RouteInfo, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error {
+func (e *endpoint) WritePacket(r stack.RouteInfo, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error {
if e.hdrSize > 0 {
e.AddHeader(r.LocalLinkAddress, r.RemoteLinkAddress, protocol, pkt)
}
@@ -451,7 +451,7 @@ func (e *endpoint) WritePacket(r stack.RouteInfo, gso *stack.GSO, protocol tcpip
return rawfile.NonBlockingWriteIovec(fd, builder.Build())
}
-func (e *endpoint) sendBatch(batchFD int, batch []*stack.PacketBuffer) (int, *tcpip.Error) {
+func (e *endpoint) sendBatch(batchFD int, batch []*stack.PacketBuffer) (int, tcpip.Error) {
// Send a batch of packets through batchFD.
mmsgHdrs := make([]rawfile.MMsgHdr, 0, len(batch))
for _, pkt := range batch {
@@ -518,7 +518,7 @@ func (e *endpoint) sendBatch(batchFD int, batch []*stack.PacketBuffer) (int, *tc
// - pkt.EgressRoute
// - pkt.GSOOptions
// - pkt.NetworkProtocolNumber
-func (e *endpoint) WritePackets(_ stack.RouteInfo, _ *stack.GSO, pkts stack.PacketBufferList, _ tcpip.NetworkProtocolNumber) (int, *tcpip.Error) {
+func (e *endpoint) WritePackets(_ stack.RouteInfo, _ *stack.GSO, pkts stack.PacketBufferList, _ tcpip.NetworkProtocolNumber) (int, tcpip.Error) {
// Preallocate to avoid repeated reallocation as we append to batch.
// batchSz is 47 because when SWGSO is in use then a single 65KB TCP
// segment can get split into 46 segments of 1420 bytes and a single 216
@@ -562,13 +562,13 @@ func viewsEqual(vs1, vs2 []buffer.View) bool {
}
// InjectOutobund implements stack.InjectableEndpoint.InjectOutbound.
-func (e *endpoint) InjectOutbound(dest tcpip.Address, packet []byte) *tcpip.Error {
+func (e *endpoint) InjectOutbound(dest tcpip.Address, packet []byte) tcpip.Error {
return rawfile.NonBlockingWrite(e.fds[0], packet)
}
// dispatchLoop reads packets from the file descriptor in a loop and dispatches
// them to the network stack.
-func (e *endpoint) dispatchLoop(inboundDispatcher linkDispatcher) *tcpip.Error {
+func (e *endpoint) dispatchLoop(inboundDispatcher linkDispatcher) tcpip.Error {
for {
cont, err := inboundDispatcher.dispatch()
if err != nil || !cont {
diff --git a/pkg/tcpip/link/fdbased/endpoint_test.go b/pkg/tcpip/link/fdbased/endpoint_test.go
index 90da22d34..e82371798 100644
--- a/pkg/tcpip/link/fdbased/endpoint_test.go
+++ b/pkg/tcpip/link/fdbased/endpoint_test.go
@@ -30,7 +30,6 @@ import (
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
- "gvisor.dev/gvisor/pkg/tcpip/link/rawfile"
"gvisor.dev/gvisor/pkg/tcpip/stack"
)
@@ -96,7 +95,7 @@ func newContext(t *testing.T, opt *Options) *context {
}
done := make(chan struct{}, 2)
- opt.ClosedFunc = func(*tcpip.Error) {
+ opt.ClosedFunc = func(tcpip.Error) {
done <- struct{}{}
}
@@ -465,67 +464,85 @@ var capLengthTestCases = []struct {
config: []int{1, 2, 3},
n: 3,
wantUsed: 2,
- wantLengths: []int{1, 2, 3},
+ wantLengths: []int{1, 2},
},
}
-func TestReadVDispatcherCapLength(t *testing.T) {
+func TestIovecBuffer(t *testing.T) {
for _, c := range capLengthTestCases {
- // fd does not matter for this test.
- d := readVDispatcher{fd: -1, e: &endpoint{}}
- d.views = make([]buffer.View, len(c.config))
- d.iovecs = make([]syscall.Iovec, len(c.config))
- d.allocateViews(c.config)
-
- used := d.capViews(c.n, c.config)
- if used != c.wantUsed {
- t.Errorf("Test %q failed when calling capViews(%d, %v). Got %d. Want %d", c.comment, c.n, c.config, used, c.wantUsed)
- }
- lengths := make([]int, len(d.views))
- for i, v := range d.views {
- lengths[i] = len(v)
- }
- if !reflect.DeepEqual(lengths, c.wantLengths) {
- t.Errorf("Test %q failed when calling capViews(%d, %v). Got %v. Want %v", c.comment, c.n, c.config, lengths, c.wantLengths)
- }
- }
-}
+ t.Run(c.comment, func(t *testing.T) {
+ b := newIovecBuffer(c.config, false /* skipsVnetHdr */)
-func TestRecvMMsgDispatcherCapLength(t *testing.T) {
- for _, c := range capLengthTestCases {
- d := recvMMsgDispatcher{
- fd: -1, // fd does not matter for this test.
- e: &endpoint{},
- views: make([][]buffer.View, 1),
- iovecs: make([][]syscall.Iovec, 1),
- msgHdrs: make([]rawfile.MMsgHdr, 1),
- }
+ // Test initial allocation.
+ iovecs := b.nextIovecs()
+ if got, want := len(iovecs), len(c.config); got != want {
+ t.Fatalf("len(iovecs) = %d, want %d", got, want)
+ }
- for i := range d.views {
- d.views[i] = make([]buffer.View, len(c.config))
- }
- for i := range d.iovecs {
- d.iovecs[i] = make([]syscall.Iovec, len(c.config))
- }
- for k, msgHdr := range d.msgHdrs {
- msgHdr.Msg.Iov = &d.iovecs[k][0]
- msgHdr.Msg.Iovlen = uint64(len(c.config))
- }
+ // Make a copy as iovecs points to internal slice. We will need this state
+ // later.
+ oldIovecs := append([]syscall.Iovec(nil), iovecs...)
- d.allocateViews(c.config)
+ // Test the views that get pulled.
+ vv := b.pullViews(c.n)
+ var lengths []int
+ for _, v := range vv.Views() {
+ lengths = append(lengths, len(v))
+ }
+ if !reflect.DeepEqual(lengths, c.wantLengths) {
+ t.Errorf("Pulled view lengths = %v, want %v", lengths, c.wantLengths)
+ }
- used := d.capViews(0, c.n, c.config)
- if used != c.wantUsed {
- t.Errorf("Test %q failed when calling capViews(%d, %v). Got %d. Want %d", c.comment, c.n, c.config, used, c.wantUsed)
- }
- lengths := make([]int, len(d.views[0]))
- for i, v := range d.views[0] {
- lengths[i] = len(v)
- }
- if !reflect.DeepEqual(lengths, c.wantLengths) {
- t.Errorf("Test %q failed when calling capViews(%d, %v). Got %v. Want %v", c.comment, c.n, c.config, lengths, c.wantLengths)
- }
+ // Test that new views get reallocated.
+ for i, newIov := range b.nextIovecs() {
+ if i < c.wantUsed {
+ if newIov.Base == oldIovecs[i].Base {
+ t.Errorf("b.views[%d] should have been reallocated", i)
+ }
+ } else {
+ if newIov.Base != oldIovecs[i].Base {
+ t.Errorf("b.views[%d] should not have been reallocated", i)
+ }
+ }
+ }
+ })
+ }
+}
+func TestIovecBufferSkipVnetHdr(t *testing.T) {
+ for _, test := range []struct {
+ desc string
+ readN int
+ wantLen int
+ }{
+ {
+ desc: "nothing read",
+ readN: 0,
+ wantLen: 0,
+ },
+ {
+ desc: "smaller than vnet header",
+ readN: virtioNetHdrSize - 1,
+ wantLen: 0,
+ },
+ {
+ desc: "header skipped",
+ readN: virtioNetHdrSize + 100,
+ wantLen: 100,
+ },
+ } {
+ t.Run(test.desc, func(t *testing.T) {
+ b := newIovecBuffer([]int{10, 20, 50, 50}, true)
+ // Pretend a read happend.
+ b.nextIovecs()
+ vv := b.pullViews(test.readN)
+ if got, want := vv.Size(), test.wantLen; got != want {
+ t.Errorf("b.pullView(%d).Size() = %d; want %d", test.readN, got, want)
+ }
+ if got, want := len(vv.ToOwnedView()), test.wantLen; got != want {
+ t.Errorf("b.pullView(%d).ToOwnedView() has length %d; want %d", test.readN, got, want)
+ }
+ })
}
}
diff --git a/pkg/tcpip/link/fdbased/mmap.go b/pkg/tcpip/link/fdbased/mmap.go
index c475dda20..a2b63fe6b 100644
--- a/pkg/tcpip/link/fdbased/mmap.go
+++ b/pkg/tcpip/link/fdbased/mmap.go
@@ -129,7 +129,7 @@ type packetMMapDispatcher struct {
ringOffset int
}
-func (d *packetMMapDispatcher) readMMappedPacket() ([]byte, *tcpip.Error) {
+func (d *packetMMapDispatcher) readMMappedPacket() ([]byte, tcpip.Error) {
hdr := tPacketHdr(d.ringBuffer[d.ringOffset*tpFrameSize:])
for hdr.tpStatus()&tpStatusUser == 0 {
event := rawfile.PollEvent{
@@ -163,7 +163,7 @@ func (d *packetMMapDispatcher) readMMappedPacket() ([]byte, *tcpip.Error) {
// dispatch reads packets from an mmaped ring buffer and dispatches them to the
// network stack.
-func (d *packetMMapDispatcher) dispatch() (bool, *tcpip.Error) {
+func (d *packetMMapDispatcher) dispatch() (bool, tcpip.Error) {
pkt, err := d.readMMappedPacket()
if err != nil {
return false, err
diff --git a/pkg/tcpip/link/fdbased/packet_dispatchers.go b/pkg/tcpip/link/fdbased/packet_dispatchers.go
index 8c3ca86d6..ecae1ad2d 100644
--- a/pkg/tcpip/link/fdbased/packet_dispatchers.go
+++ b/pkg/tcpip/link/fdbased/packet_dispatchers.go
@@ -29,92 +29,124 @@ import (
// BufConfig defines the shape of the vectorised view used to read packets from the NIC.
var BufConfig = []int{128, 256, 256, 512, 1024, 2048, 4096, 8192, 16384, 32768}
-// readVDispatcher uses readv() system call to read inbound packets and
-// dispatches them.
-type readVDispatcher struct {
- // fd is the file descriptor used to send and receive packets.
- fd int
-
- // e is the endpoint this dispatcher is attached to.
- e *endpoint
-
+type iovecBuffer struct {
// views are the actual buffers that hold the packet contents.
views []buffer.View
// iovecs are initialized with base pointers/len of the corresponding
- // entries in the views defined above, except when GSO is enabled then
- // the first iovec points to a buffer for the vnet header which is
- // stripped before the views are passed up the stack for further
+ // entries in the views defined above, except when GSO is enabled
+ // (skipsVnetHdr) then the first iovec points to a buffer for the vnet header
+ // which is stripped before the views are passed up the stack for further
// processing.
iovecs []syscall.Iovec
+
+ // sizes is an array of buffer sizes for the underlying views. sizes is
+ // immutable.
+ sizes []int
+
+ // skipsVnetHdr is true if virtioNetHdr is to skipped.
+ skipsVnetHdr bool
}
-func newReadVDispatcher(fd int, e *endpoint) (linkDispatcher, error) {
- d := &readVDispatcher{fd: fd, e: e}
- d.views = make([]buffer.View, len(BufConfig))
- iovLen := len(BufConfig)
- if d.e.Capabilities()&stack.CapabilityHardwareGSO != 0 {
- iovLen++
+func newIovecBuffer(sizes []int, skipsVnetHdr bool) *iovecBuffer {
+ b := &iovecBuffer{
+ views: make([]buffer.View, len(sizes)),
+ sizes: sizes,
+ skipsVnetHdr: skipsVnetHdr,
}
- d.iovecs = make([]syscall.Iovec, iovLen)
- return d, nil
+ niov := len(b.views)
+ if b.skipsVnetHdr {
+ niov++
+ }
+ b.iovecs = make([]syscall.Iovec, niov)
+ return b
}
-func (d *readVDispatcher) allocateViews(bufConfig []int) {
- var vnetHdr [virtioNetHdrSize]byte
+func (b *iovecBuffer) nextIovecs() []syscall.Iovec {
vnetHdrOff := 0
- if d.e.Capabilities()&stack.CapabilityHardwareGSO != 0 {
+ if b.skipsVnetHdr {
+ var vnetHdr [virtioNetHdrSize]byte
// The kernel adds virtioNetHdr before each packet, but
// we don't use it, so so we allocate a buffer for it,
// add it in iovecs but don't add it in a view.
- d.iovecs[0] = syscall.Iovec{
+ b.iovecs[0] = syscall.Iovec{
Base: &vnetHdr[0],
Len: uint64(virtioNetHdrSize),
}
vnetHdrOff++
}
- for i := 0; i < len(bufConfig); i++ {
- if d.views[i] != nil {
+ for i := range b.views {
+ if b.views[i] != nil {
break
}
- b := buffer.NewView(bufConfig[i])
- d.views[i] = b
- d.iovecs[i+vnetHdrOff] = syscall.Iovec{
- Base: &b[0],
- Len: uint64(len(b)),
+ v := buffer.NewView(b.sizes[i])
+ b.views[i] = v
+ b.iovecs[i+vnetHdrOff] = syscall.Iovec{
+ Base: &v[0],
+ Len: uint64(len(v)),
}
}
+ return b.iovecs
}
-func (d *readVDispatcher) capViews(n int, buffers []int) int {
+func (b *iovecBuffer) pullViews(n int) buffer.VectorisedView {
+ var views []buffer.View
c := 0
- for i, s := range buffers {
- c += s
+ if b.skipsVnetHdr {
+ c += virtioNetHdrSize
if c >= n {
- d.views[i].CapLength(s - (c - n))
- return i + 1
+ // Nothing in the packet.
+ return buffer.NewVectorisedView(0, nil)
+ }
+ }
+ for i, v := range b.views {
+ c += len(v)
+ if c >= n {
+ b.views[i].CapLength(len(v) - (c - n))
+ views = append([]buffer.View(nil), b.views[:i+1]...)
+ break
}
}
- return len(buffers)
+ // Remove the first len(views) used views from the state.
+ for i := range views {
+ b.views[i] = nil
+ }
+ if b.skipsVnetHdr {
+ // Exclude the size of the vnet header.
+ n -= virtioNetHdrSize
+ }
+ return buffer.NewVectorisedView(n, views)
}
-// dispatch reads one packet from the file descriptor and dispatches it.
-func (d *readVDispatcher) dispatch() (bool, *tcpip.Error) {
- d.allocateViews(BufConfig)
+// readVDispatcher uses readv() system call to read inbound packets and
+// dispatches them.
+type readVDispatcher struct {
+ // fd is the file descriptor used to send and receive packets.
+ fd int
+
+ // e is the endpoint this dispatcher is attached to.
+ e *endpoint
+
+ // buf is the iovec buffer that contains the packet contents.
+ buf *iovecBuffer
+}
+
+func newReadVDispatcher(fd int, e *endpoint) (linkDispatcher, error) {
+ d := &readVDispatcher{fd: fd, e: e}
+ skipsVnetHdr := d.e.Capabilities()&stack.CapabilityHardwareGSO != 0
+ d.buf = newIovecBuffer(BufConfig, skipsVnetHdr)
+ return d, nil
+}
- n, err := rawfile.BlockingReadv(d.fd, d.iovecs)
+// dispatch reads one packet from the file descriptor and dispatches it.
+func (d *readVDispatcher) dispatch() (bool, tcpip.Error) {
+ n, err := rawfile.BlockingReadv(d.fd, d.buf.nextIovecs())
if n == 0 || err != nil {
return false, err
}
- if d.e.Capabilities()&stack.CapabilityHardwareGSO != 0 {
- // Skip virtioNetHdr which is added before each packet, it
- // isn't used and it isn't in a view.
- n -= virtioNetHdrSize
- }
- used := d.capViews(n, BufConfig)
pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
- Data: buffer.NewVectorisedView(n, append([]buffer.View(nil), d.views[:used]...)),
+ Data: d.buf.pullViews(n),
})
var (
@@ -133,7 +165,12 @@ func (d *readVDispatcher) dispatch() (bool, *tcpip.Error) {
} else {
// We don't get any indication of what the packet is, so try to guess
// if it's an IPv4 or IPv6 packet.
- switch header.IPVersion(d.views[0]) {
+ // IP version information is at the first octet, so pulling up 1 byte.
+ h, ok := pkt.Data.PullUp(1)
+ if !ok {
+ return true, nil
+ }
+ switch header.IPVersion(h) {
case header.IPv4Version:
p = header.IPv4ProtocolNumber
case header.IPv6Version:
@@ -145,11 +182,6 @@ func (d *readVDispatcher) dispatch() (bool, *tcpip.Error) {
d.e.dispatcher.DeliverNetworkPacket(remote, local, p, pkt)
- // Prepare e.views for another packet: release used views.
- for i := 0; i < used; i++ {
- d.views[i] = nil
- }
-
return true, nil
}
@@ -162,15 +194,8 @@ type recvMMsgDispatcher struct {
// e is the endpoint this dispatcher is attached to.
e *endpoint
- // views is an array of array of buffers that contain packet contents.
- views [][]buffer.View
-
- // iovecs is an array of array of iovec records where each iovec base
- // pointer and length are initialzed to the corresponding view above,
- // except when GSO is enabled then the first iovec in each array of
- // iovecs points to a buffer for the vnet header which is stripped
- // before the views are passed up the stack for further processing.
- iovecs [][]syscall.Iovec
+ // bufs is an array of iovec buffers that contain packet contents.
+ bufs []*iovecBuffer
// msgHdrs is an array of MMsgHdr objects where each MMsghdr is used to
// reference an array of iovecs in the iovecs field defined above. This
@@ -187,74 +212,32 @@ const (
func newRecvMMsgDispatcher(fd int, e *endpoint) (linkDispatcher, error) {
d := &recvMMsgDispatcher{
- fd: fd,
- e: e,
- }
- d.views = make([][]buffer.View, MaxMsgsPerRecv)
- for i := range d.views {
- d.views[i] = make([]buffer.View, len(BufConfig))
- }
- d.iovecs = make([][]syscall.Iovec, MaxMsgsPerRecv)
- iovLen := len(BufConfig)
- if d.e.Capabilities()&stack.CapabilityHardwareGSO != 0 {
- // virtioNetHdr is prepended before each packet.
- iovLen++
+ fd: fd,
+ e: e,
+ bufs: make([]*iovecBuffer, MaxMsgsPerRecv),
+ msgHdrs: make([]rawfile.MMsgHdr, MaxMsgsPerRecv),
}
- for i := range d.iovecs {
- d.iovecs[i] = make([]syscall.Iovec, iovLen)
- }
- d.msgHdrs = make([]rawfile.MMsgHdr, MaxMsgsPerRecv)
- for i := range d.msgHdrs {
- d.msgHdrs[i].Msg.Iov = &d.iovecs[i][0]
- d.msgHdrs[i].Msg.Iovlen = uint64(iovLen)
+ skipsVnetHdr := d.e.Capabilities()&stack.CapabilityHardwareGSO != 0
+ for i := range d.bufs {
+ d.bufs[i] = newIovecBuffer(BufConfig, skipsVnetHdr)
}
return d, nil
}
-func (d *recvMMsgDispatcher) capViews(k, n int, buffers []int) int {
- c := 0
- for i, s := range buffers {
- c += s
- if c >= n {
- d.views[k][i].CapLength(s - (c - n))
- return i + 1
- }
- }
- return len(buffers)
-}
-
-func (d *recvMMsgDispatcher) allocateViews(bufConfig []int) {
- for k := 0; k < len(d.views); k++ {
- var vnetHdr [virtioNetHdrSize]byte
- vnetHdrOff := 0
- if d.e.Capabilities()&stack.CapabilityHardwareGSO != 0 {
- // The kernel adds virtioNetHdr before each packet, but
- // we don't use it, so so we allocate a buffer for it,
- // add it in iovecs but don't add it in a view.
- d.iovecs[k][0] = syscall.Iovec{
- Base: &vnetHdr[0],
- Len: uint64(virtioNetHdrSize),
- }
- vnetHdrOff++
- }
- for i := 0; i < len(bufConfig); i++ {
- if d.views[k][i] != nil {
- break
- }
- b := buffer.NewView(bufConfig[i])
- d.views[k][i] = b
- d.iovecs[k][i+vnetHdrOff] = syscall.Iovec{
- Base: &b[0],
- Len: uint64(len(b)),
- }
- }
- }
-}
-
// recvMMsgDispatch reads more than one packet at a time from the file
// descriptor and dispatches it.
-func (d *recvMMsgDispatcher) dispatch() (bool, *tcpip.Error) {
- d.allocateViews(BufConfig)
+func (d *recvMMsgDispatcher) dispatch() (bool, tcpip.Error) {
+ // Fill message headers.
+ for k := range d.msgHdrs {
+ if d.msgHdrs[k].Msg.Iovlen > 0 {
+ break
+ }
+ iovecs := d.bufs[k].nextIovecs()
+ iovLen := len(iovecs)
+ d.msgHdrs[k].Len = 0
+ d.msgHdrs[k].Msg.Iov = &iovecs[0]
+ d.msgHdrs[k].Msg.Iovlen = uint64(iovLen)
+ }
nMsgs, err := rawfile.BlockingRecvMMsg(d.fd, d.msgHdrs)
if err != nil {
@@ -263,15 +246,14 @@ func (d *recvMMsgDispatcher) dispatch() (bool, *tcpip.Error) {
// Process each of received packets.
for k := 0; k < nMsgs; k++ {
n := int(d.msgHdrs[k].Len)
- if d.e.Capabilities()&stack.CapabilityHardwareGSO != 0 {
- n -= virtioNetHdrSize
- }
- used := d.capViews(k, int(n), BufConfig)
pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
- Data: buffer.NewVectorisedView(int(n), append([]buffer.View(nil), d.views[k][:used]...)),
+ Data: d.bufs[k].pullViews(n),
})
+ // Mark that this iovec has been processed.
+ d.msgHdrs[k].Msg.Iovlen = 0
+
var (
p tcpip.NetworkProtocolNumber
remote, local tcpip.LinkAddress
@@ -288,26 +270,24 @@ func (d *recvMMsgDispatcher) dispatch() (bool, *tcpip.Error) {
} else {
// We don't get any indication of what the packet is, so try to guess
// if it's an IPv4 or IPv6 packet.
- switch header.IPVersion(d.views[k][0]) {
+ // IP version information is at the first octet, so pulling up 1 byte.
+ h, ok := pkt.Data.PullUp(1)
+ if !ok {
+ // Skip this packet.
+ continue
+ }
+ switch header.IPVersion(h) {
case header.IPv4Version:
p = header.IPv4ProtocolNumber
case header.IPv6Version:
p = header.IPv6ProtocolNumber
default:
- return true, nil
+ // Skip this packet.
+ continue
}
}
d.e.dispatcher.DeliverNetworkPacket(remote, local, p, pkt)
-
- // Prepare e.views for another packet: release used views.
- for i := 0; i < used; i++ {
- d.views[k][i] = nil
- }
- }
-
- for k := 0; k < nMsgs; k++ {
- d.msgHdrs[k].Len = 0
}
return true, nil
diff --git a/pkg/tcpip/link/loopback/loopback.go b/pkg/tcpip/link/loopback/loopback.go
index ac6a6be87..691467870 100644
--- a/pkg/tcpip/link/loopback/loopback.go
+++ b/pkg/tcpip/link/loopback/loopback.go
@@ -76,7 +76,7 @@ func (*endpoint) Wait() {}
// WritePacket implements stack.LinkEndpoint.WritePacket. It delivers outbound
// packets to the network-layer dispatcher.
-func (e *endpoint) WritePacket(_ stack.RouteInfo, _ *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error {
+func (e *endpoint) WritePacket(_ stack.RouteInfo, _ *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error {
// Construct data as the unparsed portion for the loopback packet.
data := buffer.NewVectorisedView(pkt.Size(), pkt.Views())
@@ -92,7 +92,7 @@ func (e *endpoint) WritePacket(_ stack.RouteInfo, _ *stack.GSO, protocol tcpip.N
}
// WritePackets implements stack.LinkEndpoint.WritePackets.
-func (e *endpoint) WritePackets(stack.RouteInfo, *stack.GSO, stack.PacketBufferList, tcpip.NetworkProtocolNumber) (int, *tcpip.Error) {
+func (e *endpoint) WritePackets(stack.RouteInfo, *stack.GSO, stack.PacketBufferList, tcpip.NetworkProtocolNumber) (int, tcpip.Error) {
panic("not implemented")
}
diff --git a/pkg/tcpip/link/muxed/injectable.go b/pkg/tcpip/link/muxed/injectable.go
index 316f508e6..668f72eee 100644
--- a/pkg/tcpip/link/muxed/injectable.go
+++ b/pkg/tcpip/link/muxed/injectable.go
@@ -87,10 +87,10 @@ func (m *InjectableEndpoint) InjectInbound(protocol tcpip.NetworkProtocolNumber,
// WritePackets writes outbound packets to the appropriate
// LinkInjectableEndpoint based on the RemoteAddress. HandleLocal only works if
// r.RemoteAddress has a route registered in this endpoint.
-func (m *InjectableEndpoint) WritePackets(r stack.RouteInfo, gso *stack.GSO, pkts stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) {
+func (m *InjectableEndpoint) WritePackets(r stack.RouteInfo, gso *stack.GSO, pkts stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, tcpip.Error) {
endpoint, ok := m.routes[r.RemoteAddress]
if !ok {
- return 0, tcpip.ErrNoRoute
+ return 0, &tcpip.ErrNoRoute{}
}
return endpoint.WritePackets(r, gso, pkts, protocol)
}
@@ -98,19 +98,19 @@ func (m *InjectableEndpoint) WritePackets(r stack.RouteInfo, gso *stack.GSO, pkt
// WritePacket writes outbound packets to the appropriate LinkInjectableEndpoint
// based on the RemoteAddress. HandleLocal only works if r.RemoteAddress has a
// route registered in this endpoint.
-func (m *InjectableEndpoint) WritePacket(r stack.RouteInfo, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error {
+func (m *InjectableEndpoint) WritePacket(r stack.RouteInfo, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error {
if endpoint, ok := m.routes[r.RemoteAddress]; ok {
return endpoint.WritePacket(r, gso, protocol, pkt)
}
- return tcpip.ErrNoRoute
+ return &tcpip.ErrNoRoute{}
}
// InjectOutbound writes outbound packets to the appropriate
// LinkInjectableEndpoint based on the dest address.
-func (m *InjectableEndpoint) InjectOutbound(dest tcpip.Address, packet []byte) *tcpip.Error {
+func (m *InjectableEndpoint) InjectOutbound(dest tcpip.Address, packet []byte) tcpip.Error {
endpoint, ok := m.routes[dest]
if !ok {
- return tcpip.ErrNoRoute
+ return &tcpip.ErrNoRoute{}
}
return endpoint.InjectOutbound(dest, packet)
}
diff --git a/pkg/tcpip/link/nested/nested.go b/pkg/tcpip/link/nested/nested.go
index 814a54f23..97ad9fdd5 100644
--- a/pkg/tcpip/link/nested/nested.go
+++ b/pkg/tcpip/link/nested/nested.go
@@ -113,12 +113,12 @@ func (e *Endpoint) LinkAddress() tcpip.LinkAddress {
}
// WritePacket implements stack.LinkEndpoint.
-func (e *Endpoint) WritePacket(r stack.RouteInfo, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error {
+func (e *Endpoint) WritePacket(r stack.RouteInfo, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error {
return e.child.WritePacket(r, gso, protocol, pkt)
}
// WritePackets implements stack.LinkEndpoint.
-func (e *Endpoint) WritePackets(r stack.RouteInfo, gso *stack.GSO, pkts stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) {
+func (e *Endpoint) WritePackets(r stack.RouteInfo, gso *stack.GSO, pkts stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, tcpip.Error) {
return e.child.WritePackets(r, gso, pkts, protocol)
}
diff --git a/pkg/tcpip/link/packetsocket/endpoint.go b/pkg/tcpip/link/packetsocket/endpoint.go
index c95cdd681..6cbe18a56 100644
--- a/pkg/tcpip/link/packetsocket/endpoint.go
+++ b/pkg/tcpip/link/packetsocket/endpoint.go
@@ -35,13 +35,13 @@ func New(lower stack.LinkEndpoint) stack.LinkEndpoint {
}
// WritePacket implements stack.LinkEndpoint.WritePacket.
-func (e *endpoint) WritePacket(r stack.RouteInfo, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error {
+func (e *endpoint) WritePacket(r stack.RouteInfo, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error {
e.Endpoint.DeliverOutboundPacket(r.RemoteLinkAddress, r.LocalLinkAddress, protocol, pkt)
return e.Endpoint.WritePacket(r, gso, protocol, pkt)
}
// WritePackets implements stack.LinkEndpoint.WritePackets.
-func (e *endpoint) WritePackets(r stack.RouteInfo, gso *stack.GSO, pkts stack.PacketBufferList, proto tcpip.NetworkProtocolNumber) (int, *tcpip.Error) {
+func (e *endpoint) WritePackets(r stack.RouteInfo, gso *stack.GSO, pkts stack.PacketBufferList, proto tcpip.NetworkProtocolNumber) (int, tcpip.Error) {
for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() {
e.Endpoint.DeliverOutboundPacket(r.RemoteLinkAddress, r.LocalLinkAddress, pkt.NetworkProtocolNumber, pkt)
}
diff --git a/pkg/tcpip/link/pipe/pipe.go b/pkg/tcpip/link/pipe/pipe.go
index d6e83a414..bbe84f220 100644
--- a/pkg/tcpip/link/pipe/pipe.go
+++ b/pkg/tcpip/link/pipe/pipe.go
@@ -45,12 +45,7 @@ type Endpoint struct {
linkAddr tcpip.LinkAddress
}
-// WritePacket implements stack.LinkEndpoint.
-func (e *Endpoint) WritePacket(r stack.RouteInfo, _ *stack.GSO, proto tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error {
- if !e.linked.IsAttached() {
- return nil
- }
-
+func (e *Endpoint) deliverPackets(r stack.RouteInfo, proto tcpip.NetworkProtocolNumber, pkts stack.PacketBufferList) {
// Note that the local address from the perspective of this endpoint is the
// remote address from the perspective of the other end of the pipe
// (e.linked). Similarly, the remote address from the perspective of this
@@ -70,16 +65,33 @@ func (e *Endpoint) WritePacket(r stack.RouteInfo, _ *stack.GSO, proto tcpip.Netw
//
// TODO(gvisor.dev/issue/5289): don't use a new goroutine once we support send
// and receive queues.
- go e.linked.dispatcher.DeliverNetworkPacket(r.LocalLinkAddress /* remote */, r.RemoteLinkAddress /* local */, proto, stack.NewPacketBuffer(stack.PacketBufferOptions{
- Data: buffer.NewVectorisedView(pkt.Size(), pkt.Views()),
- }))
+ go func() {
+ for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() {
+ e.linked.dispatcher.DeliverNetworkPacket(r.LocalLinkAddress /* remote */, r.RemoteLinkAddress /* local */, proto, stack.NewPacketBuffer(stack.PacketBufferOptions{
+ Data: buffer.NewVectorisedView(pkt.Size(), pkt.Views()),
+ }))
+ }
+ }()
+}
+
+// WritePacket implements stack.LinkEndpoint.
+func (e *Endpoint) WritePacket(r stack.RouteInfo, _ *stack.GSO, proto tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error {
+ if e.linked.IsAttached() {
+ var pkts stack.PacketBufferList
+ pkts.PushBack(pkt)
+ e.deliverPackets(r, proto, pkts)
+ }
return nil
}
// WritePackets implements stack.LinkEndpoint.
-func (*Endpoint) WritePackets(stack.RouteInfo, *stack.GSO, stack.PacketBufferList, tcpip.NetworkProtocolNumber) (int, *tcpip.Error) {
- panic("not implemented")
+func (e *Endpoint) WritePackets(r stack.RouteInfo, _ *stack.GSO, pkts stack.PacketBufferList, proto tcpip.NetworkProtocolNumber) (int, tcpip.Error) {
+ if e.linked.IsAttached() {
+ e.deliverPackets(r, proto, pkts)
+ }
+
+ return pkts.Len(), nil
}
// Attach implements stack.LinkEndpoint.
diff --git a/pkg/tcpip/link/qdisc/fifo/endpoint.go b/pkg/tcpip/link/qdisc/fifo/endpoint.go
index 87035b034..128ef6e87 100644
--- a/pkg/tcpip/link/qdisc/fifo/endpoint.go
+++ b/pkg/tcpip/link/qdisc/fifo/endpoint.go
@@ -150,7 +150,7 @@ func (e *endpoint) GSOMaxSize() uint32 {
}
// WritePacket implements stack.LinkEndpoint.WritePacket.
-func (e *endpoint) WritePacket(r stack.RouteInfo, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error {
+func (e *endpoint) WritePacket(r stack.RouteInfo, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error {
// WritePacket caller's do not set the following fields in PacketBuffer
// so we populate them here.
pkt.EgressRoute = r
@@ -158,26 +158,29 @@ func (e *endpoint) WritePacket(r stack.RouteInfo, gso *stack.GSO, protocol tcpip
pkt.NetworkProtocolNumber = protocol
d := e.dispatchers[int(pkt.Hash)%len(e.dispatchers)]
if !d.q.enqueue(pkt) {
- return tcpip.ErrNoBufferSpace
+ return &tcpip.ErrNoBufferSpace{}
}
d.newPacketWaker.Assert()
return nil
}
// WritePackets implements stack.LinkEndpoint.WritePackets.
-func (e *endpoint) WritePackets(r stack.RouteInfo, gso *stack.GSO, pkts stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) {
+//
+// Being a batch API, each packet in pkts should have the following
+// fields populated:
+// - pkt.EgressRoute
+// - pkt.GSOOptions
+// - pkt.NetworkProtocolNumber
+func (e *endpoint) WritePackets(r stack.RouteInfo, gso *stack.GSO, pkts stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, tcpip.Error) {
enqueued := 0
for pkt := pkts.Front(); pkt != nil; {
- pkt.EgressRoute = r
- pkt.GSOOptions = gso
- pkt.NetworkProtocolNumber = protocol
d := e.dispatchers[int(pkt.Hash)%len(e.dispatchers)]
nxt := pkt.Next()
if !d.q.enqueue(pkt) {
if enqueued > 0 {
d.newPacketWaker.Assert()
}
- return enqueued, tcpip.ErrNoBufferSpace
+ return enqueued, &tcpip.ErrNoBufferSpace{}
}
pkt = nxt
enqueued++
diff --git a/pkg/tcpip/link/rawfile/BUILD b/pkg/tcpip/link/rawfile/BUILD
index 6c410c5a6..e1047da50 100644
--- a/pkg/tcpip/link/rawfile/BUILD
+++ b/pkg/tcpip/link/rawfile/BUILD
@@ -27,5 +27,6 @@ go_test(
library = "rawfile",
deps = [
"//pkg/tcpip",
+ "@com_github_google_go_cmp//cmp:go_default_library",
],
)
diff --git a/pkg/tcpip/link/rawfile/errors.go b/pkg/tcpip/link/rawfile/errors.go
index 604868fd8..406b97709 100644
--- a/pkg/tcpip/link/rawfile/errors.go
+++ b/pkg/tcpip/link/rawfile/errors.go
@@ -17,7 +17,6 @@
package rawfile
import (
- "fmt"
"syscall"
"gvisor.dev/gvisor/pkg/tcpip"
@@ -25,48 +24,54 @@ import (
const maxErrno = 134
-var translations [maxErrno]*tcpip.Error
-
// TranslateErrno translate an errno from the syscall package into a
-// *tcpip.Error.
+// tcpip.Error.
//
// Valid, but unrecognized errnos will be translated to
-// tcpip.ErrInvalidEndpointState (EINVAL).
-func TranslateErrno(e syscall.Errno) *tcpip.Error {
- if e > 0 && e < syscall.Errno(len(translations)) {
- if err := translations[e]; err != nil {
- return err
- }
- }
- return tcpip.ErrInvalidEndpointState
-}
-
-func addTranslation(host syscall.Errno, trans *tcpip.Error) {
- if translations[host] != nil {
- panic(fmt.Sprintf("duplicate translation for host errno %q (%d)", host.Error(), host))
+// *tcpip.ErrInvalidEndpointState (EINVAL).
+func TranslateErrno(e syscall.Errno) tcpip.Error {
+ switch e {
+ case syscall.EEXIST:
+ return &tcpip.ErrDuplicateAddress{}
+ case syscall.ENETUNREACH:
+ return &tcpip.ErrNoRoute{}
+ case syscall.EINVAL:
+ return &tcpip.ErrInvalidEndpointState{}
+ case syscall.EALREADY:
+ return &tcpip.ErrAlreadyConnecting{}
+ case syscall.EISCONN:
+ return &tcpip.ErrAlreadyConnected{}
+ case syscall.EADDRINUSE:
+ return &tcpip.ErrPortInUse{}
+ case syscall.EADDRNOTAVAIL:
+ return &tcpip.ErrBadLocalAddress{}
+ case syscall.EPIPE:
+ return &tcpip.ErrClosedForSend{}
+ case syscall.EWOULDBLOCK:
+ return &tcpip.ErrWouldBlock{}
+ case syscall.ECONNREFUSED:
+ return &tcpip.ErrConnectionRefused{}
+ case syscall.ETIMEDOUT:
+ return &tcpip.ErrTimeout{}
+ case syscall.EINPROGRESS:
+ return &tcpip.ErrConnectStarted{}
+ case syscall.EDESTADDRREQ:
+ return &tcpip.ErrDestinationRequired{}
+ case syscall.ENOTSUP:
+ return &tcpip.ErrNotSupported{}
+ case syscall.ENOTTY:
+ return &tcpip.ErrQueueSizeNotSupported{}
+ case syscall.ENOTCONN:
+ return &tcpip.ErrNotConnected{}
+ case syscall.ECONNRESET:
+ return &tcpip.ErrConnectionReset{}
+ case syscall.ECONNABORTED:
+ return &tcpip.ErrConnectionAborted{}
+ case syscall.EMSGSIZE:
+ return &tcpip.ErrMessageTooLong{}
+ case syscall.ENOBUFS:
+ return &tcpip.ErrNoBufferSpace{}
+ default:
+ return &tcpip.ErrInvalidEndpointState{}
}
- translations[host] = trans
-}
-
-func init() {
- addTranslation(syscall.EEXIST, tcpip.ErrDuplicateAddress)
- addTranslation(syscall.ENETUNREACH, tcpip.ErrNoRoute)
- addTranslation(syscall.EINVAL, tcpip.ErrInvalidEndpointState)
- addTranslation(syscall.EALREADY, tcpip.ErrAlreadyConnecting)
- addTranslation(syscall.EISCONN, tcpip.ErrAlreadyConnected)
- addTranslation(syscall.EADDRINUSE, tcpip.ErrPortInUse)
- addTranslation(syscall.EADDRNOTAVAIL, tcpip.ErrBadLocalAddress)
- addTranslation(syscall.EPIPE, tcpip.ErrClosedForSend)
- addTranslation(syscall.EWOULDBLOCK, tcpip.ErrWouldBlock)
- addTranslation(syscall.ECONNREFUSED, tcpip.ErrConnectionRefused)
- addTranslation(syscall.ETIMEDOUT, tcpip.ErrTimeout)
- addTranslation(syscall.EINPROGRESS, tcpip.ErrConnectStarted)
- addTranslation(syscall.EDESTADDRREQ, tcpip.ErrDestinationRequired)
- addTranslation(syscall.ENOTSUP, tcpip.ErrNotSupported)
- addTranslation(syscall.ENOTTY, tcpip.ErrQueueSizeNotSupported)
- addTranslation(syscall.ENOTCONN, tcpip.ErrNotConnected)
- addTranslation(syscall.ECONNRESET, tcpip.ErrConnectionReset)
- addTranslation(syscall.ECONNABORTED, tcpip.ErrConnectionAborted)
- addTranslation(syscall.EMSGSIZE, tcpip.ErrMessageTooLong)
- addTranslation(syscall.ENOBUFS, tcpip.ErrNoBufferSpace)
}
diff --git a/pkg/tcpip/link/rawfile/errors_test.go b/pkg/tcpip/link/rawfile/errors_test.go
index e4cdc66bd..61aea1744 100644
--- a/pkg/tcpip/link/rawfile/errors_test.go
+++ b/pkg/tcpip/link/rawfile/errors_test.go
@@ -20,34 +20,35 @@ import (
"syscall"
"testing"
+ "github.com/google/go-cmp/cmp"
"gvisor.dev/gvisor/pkg/tcpip"
)
func TestTranslateErrno(t *testing.T) {
for _, test := range []struct {
errno syscall.Errno
- translated *tcpip.Error
+ translated tcpip.Error
}{
{
errno: syscall.Errno(0),
- translated: tcpip.ErrInvalidEndpointState,
+ translated: &tcpip.ErrInvalidEndpointState{},
},
{
errno: syscall.Errno(maxErrno),
- translated: tcpip.ErrInvalidEndpointState,
+ translated: &tcpip.ErrInvalidEndpointState{},
},
{
errno: syscall.Errno(514),
- translated: tcpip.ErrInvalidEndpointState,
+ translated: &tcpip.ErrInvalidEndpointState{},
},
{
errno: syscall.EEXIST,
- translated: tcpip.ErrDuplicateAddress,
+ translated: &tcpip.ErrDuplicateAddress{},
},
} {
got := TranslateErrno(test.errno)
- if got != test.translated {
- t.Errorf("TranslateErrno(%q) = %q, want %q", test.errno, got, test.translated)
+ if diff := cmp.Diff(test.translated, got); diff != "" {
+ t.Errorf("unexpected result from TranslateErrno(%q), (-want, +got):\n%s", test.errno, diff)
}
}
}
diff --git a/pkg/tcpip/link/rawfile/rawfile_unsafe.go b/pkg/tcpip/link/rawfile/rawfile_unsafe.go
index f4c32c2da..06f3ee21e 100644
--- a/pkg/tcpip/link/rawfile/rawfile_unsafe.go
+++ b/pkg/tcpip/link/rawfile/rawfile_unsafe.go
@@ -52,7 +52,7 @@ func GetMTU(name string) (uint32, error) {
// NonBlockingWrite writes the given buffer to a file descriptor. It fails if
// partial data is written.
-func NonBlockingWrite(fd int, buf []byte) *tcpip.Error {
+func NonBlockingWrite(fd int, buf []byte) tcpip.Error {
var ptr unsafe.Pointer
if len(buf) > 0 {
ptr = unsafe.Pointer(&buf[0])
@@ -68,7 +68,7 @@ func NonBlockingWrite(fd int, buf []byte) *tcpip.Error {
// NonBlockingWriteIovec writes iovec to a file descriptor in a single syscall.
// It fails if partial data is written.
-func NonBlockingWriteIovec(fd int, iovec []syscall.Iovec) *tcpip.Error {
+func NonBlockingWriteIovec(fd int, iovec []syscall.Iovec) tcpip.Error {
iovecLen := uintptr(len(iovec))
_, _, e := syscall.RawSyscall(syscall.SYS_WRITEV, uintptr(fd), uintptr(unsafe.Pointer(&iovec[0])), iovecLen)
if e != 0 {
@@ -78,7 +78,7 @@ func NonBlockingWriteIovec(fd int, iovec []syscall.Iovec) *tcpip.Error {
}
// NonBlockingSendMMsg sends multiple messages on a socket.
-func NonBlockingSendMMsg(fd int, msgHdrs []MMsgHdr) (int, *tcpip.Error) {
+func NonBlockingSendMMsg(fd int, msgHdrs []MMsgHdr) (int, tcpip.Error) {
n, _, e := syscall.RawSyscall6(unix.SYS_SENDMMSG, uintptr(fd), uintptr(unsafe.Pointer(&msgHdrs[0])), uintptr(len(msgHdrs)), syscall.MSG_DONTWAIT, 0, 0)
if e != 0 {
return 0, TranslateErrno(e)
@@ -97,7 +97,7 @@ type PollEvent struct {
// BlockingRead reads from a file descriptor that is set up as non-blocking. If
// no data is available, it will block in a poll() syscall until the file
// descriptor becomes readable.
-func BlockingRead(fd int, b []byte) (int, *tcpip.Error) {
+func BlockingRead(fd int, b []byte) (int, tcpip.Error) {
for {
n, _, e := syscall.RawSyscall(syscall.SYS_READ, uintptr(fd), uintptr(unsafe.Pointer(&b[0])), uintptr(len(b)))
if e == 0 {
@@ -119,7 +119,7 @@ func BlockingRead(fd int, b []byte) (int, *tcpip.Error) {
// BlockingReadv reads from a file descriptor that is set up as non-blocking and
// stores the data in a list of iovecs buffers. If no data is available, it will
// block in a poll() syscall until the file descriptor becomes readable.
-func BlockingReadv(fd int, iovecs []syscall.Iovec) (int, *tcpip.Error) {
+func BlockingReadv(fd int, iovecs []syscall.Iovec) (int, tcpip.Error) {
for {
n, _, e := syscall.RawSyscall(syscall.SYS_READV, uintptr(fd), uintptr(unsafe.Pointer(&iovecs[0])), uintptr(len(iovecs)))
if e == 0 {
@@ -149,7 +149,7 @@ type MMsgHdr struct {
// and stores the received messages in a slice of MMsgHdr structures. If no data
// is available, it will block in a poll() syscall until the file descriptor
// becomes readable.
-func BlockingRecvMMsg(fd int, msgHdrs []MMsgHdr) (int, *tcpip.Error) {
+func BlockingRecvMMsg(fd int, msgHdrs []MMsgHdr) (int, tcpip.Error) {
for {
n, _, e := syscall.RawSyscall6(syscall.SYS_RECVMMSG, uintptr(fd), uintptr(unsafe.Pointer(&msgHdrs[0])), uintptr(len(msgHdrs)), syscall.MSG_DONTWAIT, 0, 0)
if e == 0 {
diff --git a/pkg/tcpip/link/sharedmem/sharedmem.go b/pkg/tcpip/link/sharedmem/sharedmem.go
index 6c937c858..2599bc406 100644
--- a/pkg/tcpip/link/sharedmem/sharedmem.go
+++ b/pkg/tcpip/link/sharedmem/sharedmem.go
@@ -203,7 +203,7 @@ func (e *endpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.Net
// WritePacket writes outbound packets to the file descriptor. If it is not
// currently writable, the packet is dropped.
-func (e *endpoint) WritePacket(r stack.RouteInfo, _ *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error {
+func (e *endpoint) WritePacket(r stack.RouteInfo, _ *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error {
e.AddHeader(r.LocalLinkAddress, r.RemoteLinkAddress, protocol, pkt)
views := pkt.Views()
@@ -213,14 +213,14 @@ func (e *endpoint) WritePacket(r stack.RouteInfo, _ *stack.GSO, protocol tcpip.N
e.mu.Unlock()
if !ok {
- return tcpip.ErrWouldBlock
+ return &tcpip.ErrWouldBlock{}
}
return nil
}
// WritePackets implements stack.LinkEndpoint.WritePackets.
-func (*endpoint) WritePackets(stack.RouteInfo, *stack.GSO, stack.PacketBufferList, tcpip.NetworkProtocolNumber) (int, *tcpip.Error) {
+func (*endpoint) WritePackets(stack.RouteInfo, *stack.GSO, stack.PacketBufferList, tcpip.NetworkProtocolNumber) (int, tcpip.Error) {
panic("not implemented")
}
diff --git a/pkg/tcpip/link/sharedmem/sharedmem_test.go b/pkg/tcpip/link/sharedmem/sharedmem_test.go
index 23242b9e0..d480ad656 100644
--- a/pkg/tcpip/link/sharedmem/sharedmem_test.go
+++ b/pkg/tcpip/link/sharedmem/sharedmem_test.go
@@ -425,8 +425,9 @@ func TestFillTxQueue(t *testing.T) {
ReserveHeaderBytes: int(c.ep.MaxHeaderLength()),
Data: buf.ToVectorisedView(),
})
- if want, err := tcpip.ErrWouldBlock, c.ep.WritePacket(r, nil /* gso */, header.IPv4ProtocolNumber, pkt); err != want {
- t.Fatalf("WritePacket return unexpected result: got %v, want %v", err, want)
+ err := c.ep.WritePacket(r, nil /* gso */, header.IPv4ProtocolNumber, pkt)
+ if _, ok := err.(*tcpip.ErrWouldBlock); !ok {
+ t.Fatalf("got WritePacket(...) = %v, want %s", err, &tcpip.ErrWouldBlock{})
}
}
@@ -493,8 +494,9 @@ func TestFillTxQueueAfterBadCompletion(t *testing.T) {
ReserveHeaderBytes: int(c.ep.MaxHeaderLength()),
Data: buf.ToVectorisedView(),
})
- if want, err := tcpip.ErrWouldBlock, c.ep.WritePacket(r, nil /* gso */, header.IPv4ProtocolNumber, pkt); err != want {
- t.Fatalf("WritePacket return unexpected result: got %v, want %v", err, want)
+ err := c.ep.WritePacket(r, nil /* gso */, header.IPv4ProtocolNumber, pkt)
+ if _, ok := err.(*tcpip.ErrWouldBlock); !ok {
+ t.Fatalf("got WritePacket(...) = %v, want %s", err, &tcpip.ErrWouldBlock{})
}
}
@@ -538,8 +540,8 @@ func TestFillTxMemory(t *testing.T) {
Data: buf.ToVectorisedView(),
})
err := c.ep.WritePacket(r, nil /* gso */, header.IPv4ProtocolNumber, pkt)
- if want := tcpip.ErrWouldBlock; err != want {
- t.Fatalf("WritePacket return unexpected result: got %v, want %v", err, want)
+ if _, ok := err.(*tcpip.ErrWouldBlock); !ok {
+ t.Fatalf("got WritePacket(...) = %v, want %s", err, &tcpip.ErrWouldBlock{})
}
}
@@ -579,8 +581,9 @@ func TestFillTxMemoryWithMultiBuffer(t *testing.T) {
ReserveHeaderBytes: int(c.ep.MaxHeaderLength()),
Data: buffer.NewView(bufferSize).ToVectorisedView(),
})
- if want, err := tcpip.ErrWouldBlock, c.ep.WritePacket(r, nil /* gso */, header.IPv4ProtocolNumber, pkt); err != want {
- t.Fatalf("WritePacket return unexpected result: got %v, want %v", err, want)
+ err := c.ep.WritePacket(r, nil /* gso */, header.IPv4ProtocolNumber, pkt)
+ if _, ok := err.(*tcpip.ErrWouldBlock); !ok {
+ t.Fatalf("got WritePacket(...) = %v, want %s", err, &tcpip.ErrWouldBlock{})
}
}
diff --git a/pkg/tcpip/link/sniffer/sniffer.go b/pkg/tcpip/link/sniffer/sniffer.go
index 5859851d8..bd2b8d4bf 100644
--- a/pkg/tcpip/link/sniffer/sniffer.go
+++ b/pkg/tcpip/link/sniffer/sniffer.go
@@ -187,7 +187,7 @@ func (e *endpoint) dumpPacket(dir direction, gso *stack.GSO, protocol tcpip.Netw
// WritePacket implements the stack.LinkEndpoint interface. It is called by
// higher-level protocols to write packets; it just logs the packet and
// forwards the request to the lower endpoint.
-func (e *endpoint) WritePacket(r stack.RouteInfo, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error {
+func (e *endpoint) WritePacket(r stack.RouteInfo, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error {
e.dumpPacket(directionSend, gso, protocol, pkt)
return e.Endpoint.WritePacket(r, gso, protocol, pkt)
}
@@ -195,7 +195,7 @@ func (e *endpoint) WritePacket(r stack.RouteInfo, gso *stack.GSO, protocol tcpip
// WritePackets implements the stack.LinkEndpoint interface. It is called by
// higher-level protocols to write packets; it just logs the packet and
// forwards the request to the lower endpoint.
-func (e *endpoint) WritePackets(r stack.RouteInfo, gso *stack.GSO, pkts stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) {
+func (e *endpoint) WritePackets(r stack.RouteInfo, gso *stack.GSO, pkts stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, tcpip.Error) {
for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() {
e.dumpPacket(directionSend, gso, protocol, pkt)
}
diff --git a/pkg/tcpip/link/tun/device.go b/pkg/tcpip/link/tun/device.go
index bfac358f4..3829ca9c9 100644
--- a/pkg/tcpip/link/tun/device.go
+++ b/pkg/tcpip/link/tun/device.go
@@ -149,10 +149,10 @@ func attachOrCreateNIC(s *stack.Stack, name, prefix string, linkCaps stack.LinkE
err := s.CreateNICWithOptions(endpoint.nicID, endpoint, stack.NICOptions{
Name: endpoint.name,
})
- switch err {
+ switch err.(type) {
case nil:
return endpoint, nil
- case tcpip.ErrDuplicateNICID:
+ case *tcpip.ErrDuplicateNICID:
// Race detected: A NIC has been created in between.
continue
default:
diff --git a/pkg/tcpip/link/waitable/waitable.go b/pkg/tcpip/link/waitable/waitable.go
index 30f1ad540..20259b285 100644
--- a/pkg/tcpip/link/waitable/waitable.go
+++ b/pkg/tcpip/link/waitable/waitable.go
@@ -108,7 +108,7 @@ func (e *Endpoint) LinkAddress() tcpip.LinkAddress {
// WritePacket implements stack.LinkEndpoint.WritePacket. It is called by
// higher-level protocols to write packets. It only forwards packets to the
// lower endpoint if Wait or WaitWrite haven't been called.
-func (e *Endpoint) WritePacket(r stack.RouteInfo, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error {
+func (e *Endpoint) WritePacket(r stack.RouteInfo, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error {
if !e.writeGate.Enter() {
return nil
}
@@ -121,7 +121,7 @@ func (e *Endpoint) WritePacket(r stack.RouteInfo, gso *stack.GSO, protocol tcpip
// WritePackets implements stack.LinkEndpoint.WritePackets. It is called by
// higher-level protocols to write packets. It only forwards packets to the
// lower endpoint if Wait or WaitWrite haven't been called.
-func (e *Endpoint) WritePackets(r stack.RouteInfo, gso *stack.GSO, pkts stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) {
+func (e *Endpoint) WritePackets(r stack.RouteInfo, gso *stack.GSO, pkts stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, tcpip.Error) {
if !e.writeGate.Enter() {
return pkts.Len(), nil
}
diff --git a/pkg/tcpip/link/waitable/waitable_test.go b/pkg/tcpip/link/waitable/waitable_test.go
index b139de7dd..e368a9eaa 100644
--- a/pkg/tcpip/link/waitable/waitable_test.go
+++ b/pkg/tcpip/link/waitable/waitable_test.go
@@ -69,13 +69,13 @@ func (e *countedEndpoint) LinkAddress() tcpip.LinkAddress {
return e.linkAddr
}
-func (e *countedEndpoint) WritePacket(stack.RouteInfo, *stack.GSO, tcpip.NetworkProtocolNumber, *stack.PacketBuffer) *tcpip.Error {
+func (e *countedEndpoint) WritePacket(stack.RouteInfo, *stack.GSO, tcpip.NetworkProtocolNumber, *stack.PacketBuffer) tcpip.Error {
e.writeCount++
return nil
}
// WritePackets implements stack.LinkEndpoint.WritePackets.
-func (e *countedEndpoint) WritePackets(_ stack.RouteInfo, _ *stack.GSO, pkts stack.PacketBufferList, _ tcpip.NetworkProtocolNumber) (int, *tcpip.Error) {
+func (e *countedEndpoint) WritePackets(_ stack.RouteInfo, _ *stack.GSO, pkts stack.PacketBufferList, _ tcpip.NetworkProtocolNumber) (int, tcpip.Error) {
e.writeCount += pkts.Len()
return pkts.Len(), nil
}
diff --git a/pkg/tcpip/network/BUILD b/pkg/tcpip/network/BUILD
index 9ebf31b78..0caa65251 100644
--- a/pkg/tcpip/network/BUILD
+++ b/pkg/tcpip/network/BUILD
@@ -25,5 +25,6 @@ go_test(
"//pkg/tcpip/transport/icmp",
"//pkg/tcpip/transport/tcp",
"//pkg/tcpip/transport/udp",
+ "@com_github_google_go_cmp//cmp:go_default_library",
],
)
diff --git a/pkg/tcpip/network/arp/BUILD b/pkg/tcpip/network/arp/BUILD
index 8a6bcfc2c..c7ab876bf 100644
--- a/pkg/tcpip/network/arp/BUILD
+++ b/pkg/tcpip/network/arp/BUILD
@@ -4,9 +4,13 @@ package(licenses = ["notice"])
go_library(
name = "arp",
- srcs = ["arp.go"],
+ srcs = [
+ "arp.go",
+ "stats.go",
+ ],
visibility = ["//visibility:public"],
deps = [
+ "//pkg/sync",
"//pkg/tcpip",
"//pkg/tcpip/buffer",
"//pkg/tcpip/header",
@@ -33,3 +37,15 @@ go_test(
"@com_github_google_go_cmp//cmp/cmpopts:go_default_library",
],
)
+
+go_test(
+ name = "stats_test",
+ size = "small",
+ srcs = ["stats_test.go"],
+ library = ":arp",
+ deps = [
+ "//pkg/tcpip",
+ "//pkg/tcpip/network/testutil",
+ "//pkg/tcpip/stack",
+ ],
+)
diff --git a/pkg/tcpip/network/arp/arp.go b/pkg/tcpip/network/arp/arp.go
index 3259d052f..7838cc753 100644
--- a/pkg/tcpip/network/arp/arp.go
+++ b/pkg/tcpip/network/arp/arp.go
@@ -19,8 +19,10 @@ package arp
import (
"fmt"
+ "reflect"
"sync/atomic"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
@@ -50,11 +52,12 @@ type endpoint struct {
nic stack.NetworkInterface
linkAddrCache stack.LinkAddressCache
nud stack.NUDHandler
+ stats sharedStats
}
-func (e *endpoint) Enable() *tcpip.Error {
+func (e *endpoint) Enable() tcpip.Error {
if !e.nic.Enabled() {
- return tcpip.ErrNotPermitted
+ return &tcpip.ErrNotPermitted{}
}
e.setEnabled(true)
@@ -98,10 +101,12 @@ func (e *endpoint) MaxHeaderLength() uint16 {
return e.nic.MaxHeaderLength() + header.ARPSize
}
-func (*endpoint) Close() {}
+func (e *endpoint) Close() {
+ e.protocol.forgetEndpoint(e.nic.ID())
+}
-func (*endpoint) WritePacket(*stack.Route, *stack.GSO, stack.NetworkHeaderParams, *stack.PacketBuffer) *tcpip.Error {
- return tcpip.ErrNotSupported
+func (*endpoint) WritePacket(*stack.Route, *stack.GSO, stack.NetworkHeaderParams, *stack.PacketBuffer) tcpip.Error {
+ return &tcpip.ErrNotSupported{}
}
// NetworkProtocolNumber implements stack.NetworkEndpoint.NetworkProtocolNumber.
@@ -110,51 +115,45 @@ func (*endpoint) NetworkProtocolNumber() tcpip.NetworkProtocolNumber {
}
// WritePackets implements stack.NetworkEndpoint.WritePackets.
-func (*endpoint) WritePackets(*stack.Route, *stack.GSO, stack.PacketBufferList, stack.NetworkHeaderParams) (int, *tcpip.Error) {
- return 0, tcpip.ErrNotSupported
+func (*endpoint) WritePackets(*stack.Route, *stack.GSO, stack.PacketBufferList, stack.NetworkHeaderParams) (int, tcpip.Error) {
+ return 0, &tcpip.ErrNotSupported{}
}
-func (*endpoint) WriteHeaderIncludedPacket(*stack.Route, *stack.PacketBuffer) *tcpip.Error {
- return tcpip.ErrNotSupported
+func (*endpoint) WriteHeaderIncludedPacket(*stack.Route, *stack.PacketBuffer) tcpip.Error {
+ return &tcpip.ErrNotSupported{}
}
func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) {
- stats := e.protocol.stack.Stats().ARP
- stats.PacketsReceived.Increment()
+ stats := e.stats.arp
+ stats.packetsReceived.Increment()
if !e.isEnabled() {
- stats.DisabledPacketsReceived.Increment()
+ stats.disabledPacketsReceived.Increment()
return
}
h := header.ARP(pkt.NetworkHeader().View())
if !h.IsValid() {
- stats.MalformedPacketsReceived.Increment()
+ stats.malformedPacketsReceived.Increment()
return
}
switch h.Op() {
case header.ARPRequest:
- stats.RequestsReceived.Increment()
+ stats.requestsReceived.Increment()
localAddr := tcpip.Address(h.ProtocolAddressTarget())
+ if e.protocol.stack.CheckLocalAddress(e.nic.ID(), header.IPv4ProtocolNumber, localAddr) == 0 {
+ stats.requestsReceivedUnknownTargetAddress.Increment()
+ return // we have no useful answer, ignore the request
+ }
+
+ remoteAddr := tcpip.Address(h.ProtocolAddressSender())
+ remoteLinkAddr := tcpip.LinkAddress(h.HardwareAddressSender())
+
if e.nud == nil {
- if e.linkAddrCache.CheckLocalAddress(e.nic.ID(), header.IPv4ProtocolNumber, localAddr) == 0 {
- stats.RequestsReceivedUnknownTargetAddress.Increment()
- return // we have no useful answer, ignore the request
- }
-
- addr := tcpip.Address(h.ProtocolAddressSender())
- linkAddr := tcpip.LinkAddress(h.HardwareAddressSender())
- e.linkAddrCache.AddLinkAddress(e.nic.ID(), addr, linkAddr)
+ e.linkAddrCache.AddLinkAddress(remoteAddr, remoteLinkAddr)
} else {
- if e.protocol.stack.CheckLocalAddress(e.nic.ID(), header.IPv4ProtocolNumber, localAddr) == 0 {
- stats.RequestsReceivedUnknownTargetAddress.Increment()
- return // we have no useful answer, ignore the request
- }
-
- remoteAddr := tcpip.Address(h.ProtocolAddressSender())
- remoteLinkAddr := tcpip.LinkAddress(h.HardwareAddressSender())
e.nud.HandleProbe(remoteAddr, ProtocolNumber, remoteLinkAddr, e.protocol)
}
@@ -186,18 +185,18 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) {
// Send the packet to the (new) target hardware address on the same
// hardware on which the request was received.
if err := e.nic.WritePacketToRemote(tcpip.LinkAddress(origSender), nil /* gso */, ProtocolNumber, respPkt); err != nil {
- stats.OutgoingRepliesDropped.Increment()
+ stats.outgoingRepliesDropped.Increment()
} else {
- stats.OutgoingRepliesSent.Increment()
+ stats.outgoingRepliesSent.Increment()
}
case header.ARPReply:
- stats.RepliesReceived.Increment()
+ stats.repliesReceived.Increment()
addr := tcpip.Address(h.ProtocolAddressSender())
linkAddr := tcpip.LinkAddress(h.HardwareAddressSender())
if e.nud == nil {
- e.linkAddrCache.AddLinkAddress(e.nic.ID(), addr, linkAddr)
+ e.linkAddrCache.AddLinkAddress(addr, linkAddr)
return
}
@@ -216,9 +215,25 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) {
}
}
+// Stats implements stack.NetworkEndpoint.
+func (e *endpoint) Stats() stack.NetworkEndpointStats {
+ return &e.stats.localStats
+}
+
+var _ stack.NetworkProtocol = (*protocol)(nil)
+var _ stack.LinkAddressResolver = (*protocol)(nil)
+
// protocol implements stack.NetworkProtocol and stack.LinkAddressResolver.
type protocol struct {
stack *stack.Stack
+
+ mu struct {
+ sync.RWMutex
+
+ // eps is keyed by NICID to allow protocol methods to retrieve the correct
+ // endpoint depending on the NIC.
+ eps map[tcpip.NICID]*endpoint
+ }
}
func (p *protocol) Number() tcpip.NetworkProtocolNumber { return ProtocolNumber }
@@ -236,39 +251,62 @@ func (p *protocol) NewEndpoint(nic stack.NetworkInterface, linkAddrCache stack.L
linkAddrCache: linkAddrCache,
nud: nud,
}
+
+ tcpip.InitStatCounters(reflect.ValueOf(&e.stats.localStats).Elem())
+
+ stackStats := p.stack.Stats()
+ e.stats.arp.init(&e.stats.localStats.ARP, &stackStats.ARP)
+
+ p.mu.Lock()
+ p.mu.eps[nic.ID()] = e
+ p.mu.Unlock()
+
return e
}
+func (p *protocol) forgetEndpoint(nicID tcpip.NICID) {
+ p.mu.Lock()
+ defer p.mu.Unlock()
+ delete(p.mu.eps, nicID)
+}
+
// LinkAddressProtocol implements stack.LinkAddressResolver.LinkAddressProtocol.
func (*protocol) LinkAddressProtocol() tcpip.NetworkProtocolNumber {
return header.IPv4ProtocolNumber
}
// LinkAddressRequest implements stack.LinkAddressResolver.LinkAddressRequest.
-func (p *protocol) LinkAddressRequest(targetAddr, localAddr tcpip.Address, remoteLinkAddr tcpip.LinkAddress, nic stack.NetworkInterface) *tcpip.Error {
- stats := p.stack.Stats().ARP
+func (p *protocol) LinkAddressRequest(targetAddr, localAddr tcpip.Address, remoteLinkAddr tcpip.LinkAddress, nic stack.NetworkInterface) tcpip.Error {
+ nicID := nic.ID()
+
+ p.mu.Lock()
+ netEP, ok := p.mu.eps[nicID]
+ p.mu.Unlock()
+ if !ok {
+ return &tcpip.ErrNotConnected{}
+ }
+
+ stats := netEP.stats.arp
if len(remoteLinkAddr) == 0 {
remoteLinkAddr = header.EthernetBroadcastAddress
}
- nicID := nic.ID()
if len(localAddr) == 0 {
- addr, err := p.stack.GetMainNICAddress(nicID, header.IPv4ProtocolNumber)
- if err != nil {
- stats.OutgoingRequestInterfaceHasNoLocalAddressErrors.Increment()
- return err
+ addr, ok := p.stack.GetMainNICAddress(nicID, header.IPv4ProtocolNumber)
+ if !ok {
+ return &tcpip.ErrUnknownNICID{}
}
if len(addr.Address) == 0 {
- stats.OutgoingRequestNetworkUnreachableErrors.Increment()
- return tcpip.ErrNetworkUnreachable
+ stats.outgoingRequestInterfaceHasNoLocalAddressErrors.Increment()
+ return &tcpip.ErrNetworkUnreachable{}
}
localAddr = addr.Address
} else if p.stack.CheckLocalAddress(nicID, header.IPv4ProtocolNumber, localAddr) == 0 {
- stats.OutgoingRequestBadLocalAddressErrors.Increment()
- return tcpip.ErrBadLocalAddress
+ stats.outgoingRequestBadLocalAddressErrors.Increment()
+ return &tcpip.ErrBadLocalAddress{}
}
pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
@@ -288,10 +326,10 @@ func (p *protocol) LinkAddressRequest(targetAddr, localAddr tcpip.Address, remot
panic(fmt.Sprintf("copied %d bytes, expected %d bytes", n, header.IPv4AddressSize))
}
if err := nic.WritePacketToRemote(remoteLinkAddr, nil /* gso */, ProtocolNumber, pkt); err != nil {
- stats.OutgoingRequestsDropped.Increment()
+ stats.outgoingRequestsDropped.Increment()
return err
}
- stats.OutgoingRequestsSent.Increment()
+ stats.outgoingRequestsSent.Increment()
return nil
}
@@ -307,13 +345,13 @@ func (*protocol) ResolveStaticAddress(addr tcpip.Address) (tcpip.LinkAddress, bo
}
// SetOption implements stack.NetworkProtocol.SetOption.
-func (*protocol) SetOption(tcpip.SettableNetworkProtocolOption) *tcpip.Error {
- return tcpip.ErrUnknownProtocolOption
+func (*protocol) SetOption(tcpip.SettableNetworkProtocolOption) tcpip.Error {
+ return &tcpip.ErrUnknownProtocolOption{}
}
// Option implements stack.NetworkProtocol.Option.
-func (*protocol) Option(tcpip.GettableNetworkProtocolOption) *tcpip.Error {
- return tcpip.ErrUnknownProtocolOption
+func (*protocol) Option(tcpip.GettableNetworkProtocolOption) tcpip.Error {
+ return &tcpip.ErrUnknownProtocolOption{}
}
// Close implements stack.TransportProtocol.Close.
@@ -329,5 +367,11 @@ func (*protocol) Parse(pkt *stack.PacketBuffer) (proto tcpip.TransportProtocolNu
// NewProtocol returns an ARP network protocol.
func NewProtocol(s *stack.Stack) stack.NetworkProtocol {
- return &protocol{stack: s}
+ return &protocol{
+ stack: s,
+ mu: struct {
+ sync.RWMutex
+ eps map[tcpip.NICID]*endpoint
+ }{eps: make(map[tcpip.NICID]*endpoint)},
+ }
}
diff --git a/pkg/tcpip/network/arp/arp_test.go b/pkg/tcpip/network/arp/arp_test.go
index 0536e1698..b0f07aa44 100644
--- a/pkg/tcpip/network/arp/arp_test.go
+++ b/pkg/tcpip/network/arp/arp_test.go
@@ -125,8 +125,8 @@ func (d *arpDispatcher) OnNeighborRemoved(nicID tcpip.NICID, entry stack.Neighbo
func (d *arpDispatcher) waitForEvent(ctx context.Context, want eventInfo) error {
select {
case got := <-d.C:
- if diff := cmp.Diff(got, want, cmp.AllowUnexported(got), cmpopts.IgnoreFields(stack.NeighborEntry{}, "UpdatedAtNanos")); diff != "" {
- return fmt.Errorf("got invalid event (-got +want):\n%s", diff)
+ if diff := cmp.Diff(want, got, cmp.AllowUnexported(got), cmpopts.IgnoreFields(stack.NeighborEntry{}, "UpdatedAtNanos")); diff != "" {
+ return fmt.Errorf("got invalid event (-want +got):\n%s", diff)
}
case <-ctx.Done():
return fmt.Errorf("%s for %s", ctx.Err(), want)
@@ -537,7 +537,7 @@ type testInterface struct {
nicID tcpip.NICID
- writeErr *tcpip.Error
+ writeErr tcpip.Error
}
func (t *testInterface) ID() tcpip.NICID {
@@ -560,15 +560,15 @@ func (*testInterface) Promiscuous() bool {
return false
}
-func (t *testInterface) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error {
+func (t *testInterface) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error {
return t.LinkEndpoint.WritePacket(r.Fields(), gso, protocol, pkt)
}
-func (t *testInterface) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) {
+func (t *testInterface) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, tcpip.Error) {
return t.LinkEndpoint.WritePackets(r.Fields(), gso, pkts, protocol)
}
-func (t *testInterface) WritePacketToRemote(remoteLinkAddr tcpip.LinkAddress, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error {
+func (t *testInterface) WritePacketToRemote(remoteLinkAddr tcpip.LinkAddress, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error {
if t.writeErr != nil {
return t.writeErr
}
@@ -585,104 +585,122 @@ func TestLinkAddressRequest(t *testing.T) {
testAddr := tcpip.Address([]byte{1, 2, 3, 4})
tests := []struct {
- name string
- nicAddr tcpip.Address
- localAddr tcpip.Address
- remoteLinkAddr tcpip.LinkAddress
-
- linkErr *tcpip.Error
- expectedErr *tcpip.Error
- expectedLocalAddr tcpip.Address
- expectedRemoteLinkAddr tcpip.LinkAddress
- expectedRequestsSent uint64
- expectedRequestBadLocalAddressErrors uint64
- expectedRequestNetworkUnreachableErrors uint64
- expectedRequestDroppedErrors uint64
+ name string
+ nicAddr tcpip.Address
+ localAddr tcpip.Address
+ remoteLinkAddr tcpip.LinkAddress
+ linkErr tcpip.Error
+ expectedErr tcpip.Error
+ expectedLocalAddr tcpip.Address
+ expectedRemoteLinkAddr tcpip.LinkAddress
+ expectedRequestsSent uint64
+ expectedRequestBadLocalAddressErrors uint64
+ expectedRequestInterfaceHasNoLocalAddressErrors uint64
+ expectedRequestDroppedErrors uint64
}{
{
- name: "Unicast",
- nicAddr: stackAddr,
- localAddr: stackAddr,
- remoteLinkAddr: remoteLinkAddr,
- expectedLocalAddr: stackAddr,
- expectedRemoteLinkAddr: remoteLinkAddr,
- expectedRequestsSent: 1,
- expectedRequestBadLocalAddressErrors: 0,
- expectedRequestNetworkUnreachableErrors: 0,
+ name: "Unicast",
+ nicAddr: stackAddr,
+ localAddr: stackAddr,
+ remoteLinkAddr: remoteLinkAddr,
+ expectedLocalAddr: stackAddr,
+ expectedRemoteLinkAddr: remoteLinkAddr,
+ expectedRequestsSent: 1,
+ expectedRequestBadLocalAddressErrors: 0,
+ expectedRequestInterfaceHasNoLocalAddressErrors: 0,
+ expectedRequestDroppedErrors: 0,
},
{
- name: "Multicast",
- nicAddr: stackAddr,
- localAddr: stackAddr,
- remoteLinkAddr: "",
- expectedLocalAddr: stackAddr,
- expectedRemoteLinkAddr: header.EthernetBroadcastAddress,
- expectedRequestsSent: 1,
- expectedRequestBadLocalAddressErrors: 0,
- expectedRequestNetworkUnreachableErrors: 0,
+ name: "Multicast",
+ nicAddr: stackAddr,
+ localAddr: stackAddr,
+ remoteLinkAddr: "",
+ expectedLocalAddr: stackAddr,
+ expectedRemoteLinkAddr: header.EthernetBroadcastAddress,
+ expectedRequestsSent: 1,
+ expectedRequestBadLocalAddressErrors: 0,
+ expectedRequestInterfaceHasNoLocalAddressErrors: 0,
+ expectedRequestDroppedErrors: 0,
},
{
- name: "Unicast with unspecified source",
- nicAddr: stackAddr,
- remoteLinkAddr: remoteLinkAddr,
- expectedLocalAddr: stackAddr,
- expectedRemoteLinkAddr: remoteLinkAddr,
- expectedRequestsSent: 1,
- expectedRequestBadLocalAddressErrors: 0,
- expectedRequestNetworkUnreachableErrors: 0,
+ name: "Unicast with unspecified source",
+ nicAddr: stackAddr,
+ localAddr: "",
+ remoteLinkAddr: remoteLinkAddr,
+ expectedLocalAddr: stackAddr,
+ expectedRemoteLinkAddr: remoteLinkAddr,
+ expectedRequestsSent: 1,
+ expectedRequestBadLocalAddressErrors: 0,
+ expectedRequestInterfaceHasNoLocalAddressErrors: 0,
+ expectedRequestDroppedErrors: 0,
},
{
- name: "Multicast with unspecified source",
- nicAddr: stackAddr,
- remoteLinkAddr: "",
- expectedLocalAddr: stackAddr,
- expectedRemoteLinkAddr: header.EthernetBroadcastAddress,
- expectedRequestsSent: 1,
- expectedRequestBadLocalAddressErrors: 0,
- expectedRequestNetworkUnreachableErrors: 0,
+ name: "Multicast with unspecified source",
+ nicAddr: stackAddr,
+ localAddr: "",
+ remoteLinkAddr: "",
+ expectedLocalAddr: stackAddr,
+ expectedRemoteLinkAddr: header.EthernetBroadcastAddress,
+ expectedRequestsSent: 1,
+ expectedRequestBadLocalAddressErrors: 0,
+ expectedRequestInterfaceHasNoLocalAddressErrors: 0,
+ expectedRequestDroppedErrors: 0,
},
{
- name: "Unicast with unassigned address",
- localAddr: testAddr,
- remoteLinkAddr: remoteLinkAddr,
- expectedErr: tcpip.ErrBadLocalAddress,
- expectedRequestsSent: 0,
- expectedRequestBadLocalAddressErrors: 1,
- expectedRequestNetworkUnreachableErrors: 0,
+ name: "Unicast with unassigned address",
+ nicAddr: stackAddr,
+ localAddr: testAddr,
+ remoteLinkAddr: remoteLinkAddr,
+ expectedErr: &tcpip.ErrBadLocalAddress{},
+ expectedRequestsSent: 0,
+ expectedRequestBadLocalAddressErrors: 1,
+ expectedRequestInterfaceHasNoLocalAddressErrors: 0,
+ expectedRequestDroppedErrors: 0,
},
{
- name: "Multicast with unassigned address",
- localAddr: testAddr,
- remoteLinkAddr: "",
- expectedErr: tcpip.ErrBadLocalAddress,
- expectedRequestsSent: 0,
- expectedRequestBadLocalAddressErrors: 1,
- expectedRequestNetworkUnreachableErrors: 0,
+ name: "Multicast with unassigned address",
+ nicAddr: stackAddr,
+ localAddr: testAddr,
+ remoteLinkAddr: "",
+ expectedErr: &tcpip.ErrBadLocalAddress{},
+ expectedRequestsSent: 0,
+ expectedRequestBadLocalAddressErrors: 1,
+ expectedRequestInterfaceHasNoLocalAddressErrors: 0,
+ expectedRequestDroppedErrors: 0,
},
{
- name: "Unicast with no local address available",
- remoteLinkAddr: remoteLinkAddr,
- expectedErr: tcpip.ErrNetworkUnreachable,
- expectedRequestsSent: 0,
- expectedRequestBadLocalAddressErrors: 0,
- expectedRequestNetworkUnreachableErrors: 1,
+ name: "Unicast with no local address available",
+ nicAddr: "",
+ localAddr: "",
+ remoteLinkAddr: remoteLinkAddr,
+ expectedErr: &tcpip.ErrNetworkUnreachable{},
+ expectedRequestsSent: 0,
+ expectedRequestBadLocalAddressErrors: 0,
+ expectedRequestInterfaceHasNoLocalAddressErrors: 1,
+ expectedRequestDroppedErrors: 0,
},
{
- name: "Multicast with no local address available",
- remoteLinkAddr: "",
- expectedErr: tcpip.ErrNetworkUnreachable,
- expectedRequestsSent: 0,
- expectedRequestBadLocalAddressErrors: 0,
- expectedRequestNetworkUnreachableErrors: 1,
+ name: "Multicast with no local address available",
+ nicAddr: "",
+ localAddr: "",
+ remoteLinkAddr: "",
+ expectedErr: &tcpip.ErrNetworkUnreachable{},
+ expectedRequestsSent: 0,
+ expectedRequestBadLocalAddressErrors: 0,
+ expectedRequestInterfaceHasNoLocalAddressErrors: 1,
+ expectedRequestDroppedErrors: 0,
},
{
- name: "Link error",
- nicAddr: stackAddr,
- localAddr: stackAddr,
- remoteLinkAddr: remoteLinkAddr,
- linkErr: tcpip.ErrInvalidEndpointState,
- expectedErr: tcpip.ErrInvalidEndpointState,
- expectedRequestDroppedErrors: 1,
+ name: "Link error",
+ nicAddr: stackAddr,
+ localAddr: stackAddr,
+ remoteLinkAddr: remoteLinkAddr,
+ linkErr: &tcpip.ErrInvalidEndpointState{},
+ expectedErr: &tcpip.ErrInvalidEndpointState{},
+ expectedRequestsSent: 0,
+ expectedRequestBadLocalAddressErrors: 0,
+ expectedRequestInterfaceHasNoLocalAddressErrors: 0,
+ expectedRequestDroppedErrors: 1,
},
}
@@ -714,19 +732,20 @@ func TestLinkAddressRequest(t *testing.T) {
// link endpoint even though the stack uses the real NIC to validate the
// local address.
iface := testInterface{LinkEndpoint: linkEP, nicID: nicID, writeErr: test.linkErr}
- if err := linkRes.LinkAddressRequest(remoteAddr, test.localAddr, test.remoteLinkAddr, &iface); err != test.expectedErr {
- t.Fatalf("got p.LinkAddressRequest(%s, %s, %s, _) = %s, want = %s", remoteAddr, test.localAddr, test.remoteLinkAddr, err, test.expectedErr)
+ err := linkRes.LinkAddressRequest(remoteAddr, test.localAddr, test.remoteLinkAddr, &iface)
+ if diff := cmp.Diff(test.expectedErr, err); diff != "" {
+ t.Fatalf("unexpected error from p.LinkAddressRequest(%s, %s, %s, _), (-want, +got):\n%s", remoteAddr, test.localAddr, test.remoteLinkAddr, diff)
}
if got := s.Stats().ARP.OutgoingRequestsSent.Value(); got != test.expectedRequestsSent {
t.Errorf("got s.Stats().ARP.OutgoingRequestsSent.Value() = %d, want = %d", got, test.expectedRequestsSent)
}
+ if got := s.Stats().ARP.OutgoingRequestInterfaceHasNoLocalAddressErrors.Value(); got != test.expectedRequestInterfaceHasNoLocalAddressErrors {
+ t.Errorf("got s.Stats().ARP.OutgoingRequestInterfaceHasNoLocalAddressErrors.Value() = %d, want = %d", got, test.expectedRequestInterfaceHasNoLocalAddressErrors)
+ }
if got := s.Stats().ARP.OutgoingRequestBadLocalAddressErrors.Value(); got != test.expectedRequestBadLocalAddressErrors {
t.Errorf("got s.Stats().ARP.OutgoingRequestBadLocalAddressErrors.Value() = %d, want = %d", got, test.expectedRequestBadLocalAddressErrors)
}
- if got := s.Stats().ARP.OutgoingRequestNetworkUnreachableErrors.Value(); got != test.expectedRequestNetworkUnreachableErrors {
- t.Errorf("got s.Stats().ARP.OutgoingRequestNetworkUnreachableErrors.Value() = %d, want = %d", got, test.expectedRequestNetworkUnreachableErrors)
- }
if got := s.Stats().ARP.OutgoingRequestsDropped.Value(); got != test.expectedRequestDroppedErrors {
t.Errorf("got s.Stats().ARP.OutgoingRequestsDropped.Value() = %d, want = %d", got, test.expectedRequestDroppedErrors)
}
@@ -774,11 +793,8 @@ func TestLinkAddressRequestWithoutNIC(t *testing.T) {
t.Fatal("expected ARP protocol to implement stack.LinkAddressResolver")
}
- if err := linkRes.LinkAddressRequest(remoteAddr, "", remoteLinkAddr, &testInterface{nicID: nicID}); err != tcpip.ErrUnknownNICID {
- t.Fatalf("got p.LinkAddressRequest(%s, %s, %s, _) = %s, want = %s", remoteAddr, "", remoteLinkAddr, err, tcpip.ErrUnknownNICID)
- }
-
- if got := s.Stats().ARP.OutgoingRequestInterfaceHasNoLocalAddressErrors.Value(); got != 1 {
- t.Errorf("got s.Stats().ARP.OutgoingRequestInterfaceHasNoLocalAddressErrors.Value() = %d, want = 1", got)
+ err := linkRes.LinkAddressRequest(remoteAddr, "", remoteLinkAddr, &testInterface{nicID: nicID})
+ if _, ok := err.(*tcpip.ErrNotConnected); !ok {
+ t.Fatalf("got p.LinkAddressRequest(%s, %s, %s, _) = %s, want = %s", remoteAddr, "", remoteLinkAddr, err, &tcpip.ErrNotConnected{})
}
}
diff --git a/pkg/tcpip/network/arp/stats.go b/pkg/tcpip/network/arp/stats.go
new file mode 100644
index 000000000..6d7194c6c
--- /dev/null
+++ b/pkg/tcpip/network/arp/stats.go
@@ -0,0 +1,70 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package arp
+
+import (
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
+)
+
+var _ stack.NetworkEndpointStats = (*Stats)(nil)
+
+// Stats holds statistics related to ARP.
+type Stats struct {
+ // ARP holds ARP statistics.
+ ARP tcpip.ARPStats
+}
+
+// IsNetworkEndpointStats implements stack.NetworkEndpointStats.
+func (*Stats) IsNetworkEndpointStats() {}
+
+type sharedStats struct {
+ localStats Stats
+ arp multiCounterARPStats
+}
+
+// LINT.IfChange(multiCounterARPStats)
+
+type multiCounterARPStats struct {
+ packetsReceived tcpip.MultiCounterStat
+ disabledPacketsReceived tcpip.MultiCounterStat
+ malformedPacketsReceived tcpip.MultiCounterStat
+ requestsReceived tcpip.MultiCounterStat
+ requestsReceivedUnknownTargetAddress tcpip.MultiCounterStat
+ outgoingRequestInterfaceHasNoLocalAddressErrors tcpip.MultiCounterStat
+ outgoingRequestBadLocalAddressErrors tcpip.MultiCounterStat
+ outgoingRequestsDropped tcpip.MultiCounterStat
+ outgoingRequestsSent tcpip.MultiCounterStat
+ repliesReceived tcpip.MultiCounterStat
+ outgoingRepliesDropped tcpip.MultiCounterStat
+ outgoingRepliesSent tcpip.MultiCounterStat
+}
+
+func (m *multiCounterARPStats) init(a, b *tcpip.ARPStats) {
+ m.packetsReceived.Init(a.PacketsReceived, b.PacketsReceived)
+ m.disabledPacketsReceived.Init(a.DisabledPacketsReceived, b.DisabledPacketsReceived)
+ m.malformedPacketsReceived.Init(a.MalformedPacketsReceived, b.MalformedPacketsReceived)
+ m.requestsReceived.Init(a.RequestsReceived, b.RequestsReceived)
+ m.requestsReceivedUnknownTargetAddress.Init(a.RequestsReceivedUnknownTargetAddress, b.RequestsReceivedUnknownTargetAddress)
+ m.outgoingRequestInterfaceHasNoLocalAddressErrors.Init(a.OutgoingRequestInterfaceHasNoLocalAddressErrors, b.OutgoingRequestInterfaceHasNoLocalAddressErrors)
+ m.outgoingRequestBadLocalAddressErrors.Init(a.OutgoingRequestBadLocalAddressErrors, b.OutgoingRequestBadLocalAddressErrors)
+ m.outgoingRequestsDropped.Init(a.OutgoingRequestsDropped, b.OutgoingRequestsDropped)
+ m.outgoingRequestsSent.Init(a.OutgoingRequestsSent, b.OutgoingRequestsSent)
+ m.repliesReceived.Init(a.RepliesReceived, b.RepliesReceived)
+ m.outgoingRepliesDropped.Init(a.OutgoingRepliesDropped, b.OutgoingRepliesDropped)
+ m.outgoingRepliesSent.Init(a.OutgoingRepliesSent, b.OutgoingRepliesSent)
+}
+
+// LINT.ThenChange(../../tcpip.go:ARPStats)
diff --git a/pkg/tcpip/network/arp/stats_test.go b/pkg/tcpip/network/arp/stats_test.go
new file mode 100644
index 000000000..036fdf739
--- /dev/null
+++ b/pkg/tcpip/network/arp/stats_test.go
@@ -0,0 +1,93 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package arp
+
+import (
+ "reflect"
+ "testing"
+
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/network/testutil"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
+)
+
+var _ stack.NetworkInterface = (*testInterface)(nil)
+
+type testInterface struct {
+ stack.NetworkInterface
+ nicID tcpip.NICID
+}
+
+func (t *testInterface) ID() tcpip.NICID {
+ return t.nicID
+}
+
+func knownNICIDs(proto *protocol) []tcpip.NICID {
+ var nicIDs []tcpip.NICID
+
+ for k := range proto.mu.eps {
+ nicIDs = append(nicIDs, k)
+ }
+
+ return nicIDs
+}
+
+func TestClearEndpointFromProtocolOnClose(t *testing.T) {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol},
+ })
+ proto := s.NetworkProtocolInstance(ProtocolNumber).(*protocol)
+ nic := testInterface{nicID: 1}
+ ep := proto.NewEndpoint(&nic, nil, nil, nil).(*endpoint)
+ var nicIDs []tcpip.NICID
+
+ proto.mu.Lock()
+ foundEP, hasEndpointBeforeClose := proto.mu.eps[nic.ID()]
+ nicIDs = knownNICIDs(proto)
+ proto.mu.Unlock()
+
+ if !hasEndpointBeforeClose {
+ t.Fatalf("expected to find the nic id %d in the protocol's endpoint map (%v)", nic.ID(), nicIDs)
+ }
+ if foundEP != ep {
+ t.Fatalf("found an incorrect endpoint mapped to nic id %d", nic.ID())
+ }
+
+ ep.Close()
+
+ proto.mu.Lock()
+ _, hasEndpointAfterClose := proto.mu.eps[nic.ID()]
+ nicIDs = knownNICIDs(proto)
+ proto.mu.Unlock()
+ if hasEndpointAfterClose {
+ t.Fatalf("unexpectedly found an endpoint mapped to the nic id %d in the protocol's known nic ids (%v)", nic.ID(), nicIDs)
+ }
+}
+
+func TestMultiCounterStatsInitialization(t *testing.T) {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol},
+ })
+ proto := s.NetworkProtocolInstance(ProtocolNumber).(*protocol)
+ var nic testInterface
+ ep := proto.NewEndpoint(&nic, nil, nil, nil).(*endpoint)
+ // At this point, the Stack's stats and the NetworkEndpoint's stats are
+ // expected to be bound by a MultiCounterStat.
+ refStack := s.Stats()
+ refEP := ep.stats.localStats
+ if err := testutil.ValidateMultiCounterStats(reflect.ValueOf(&ep.stats.arp).Elem(), []reflect.Value{reflect.ValueOf(&refEP.ARP).Elem(), reflect.ValueOf(&refStack.ARP).Elem()}); err != nil {
+ t.Error(err)
+ }
+}
diff --git a/pkg/tcpip/network/ip/BUILD b/pkg/tcpip/network/ip/BUILD
index ca1247c1e..411bca25d 100644
--- a/pkg/tcpip/network/ip/BUILD
+++ b/pkg/tcpip/network/ip/BUILD
@@ -4,7 +4,10 @@ package(licenses = ["notice"])
go_library(
name = "ip",
- srcs = ["generic_multicast_protocol.go"],
+ srcs = [
+ "generic_multicast_protocol.go",
+ "stats.go",
+ ],
visibility = ["//visibility:public"],
deps = [
"//pkg/sync",
diff --git a/pkg/tcpip/network/ip/generic_multicast_protocol.go b/pkg/tcpip/network/ip/generic_multicast_protocol.go
index f2f0e069c..a81f5c8c3 100644
--- a/pkg/tcpip/network/ip/generic_multicast_protocol.go
+++ b/pkg/tcpip/network/ip/generic_multicast_protocol.go
@@ -174,10 +174,10 @@ type MulticastGroupProtocol interface {
//
// Returns false if the caller should queue the report to be sent later. Note,
// returning false does not mean that the receiver hit an error.
- SendReport(groupAddress tcpip.Address) (sent bool, err *tcpip.Error)
+ SendReport(groupAddress tcpip.Address) (sent bool, err tcpip.Error)
// SendLeave sends a multicast leave for the specified group address.
- SendLeave(groupAddress tcpip.Address) *tcpip.Error
+ SendLeave(groupAddress tcpip.Address) tcpip.Error
}
// GenericMulticastProtocolState is the per interface generic multicast protocol
diff --git a/pkg/tcpip/network/ip/generic_multicast_protocol_test.go b/pkg/tcpip/network/ip/generic_multicast_protocol_test.go
index 85593f211..d5d5a449e 100644
--- a/pkg/tcpip/network/ip/generic_multicast_protocol_test.go
+++ b/pkg/tcpip/network/ip/generic_multicast_protocol_test.go
@@ -141,7 +141,7 @@ func (m *mockMulticastGroupProtocol) Enabled() bool {
// SendReport implements ip.MulticastGroupProtocol.
//
// Precondition: m.mu must be locked.
-func (m *mockMulticastGroupProtocol) SendReport(groupAddress tcpip.Address) (bool, *tcpip.Error) {
+func (m *mockMulticastGroupProtocol) SendReport(groupAddress tcpip.Address) (bool, tcpip.Error) {
if m.mu.TryLock() {
m.mu.Unlock()
m.t.Fatalf("got write lock, expected to not take the lock; generic multicast protocol must take the write lock before sending report for %s", groupAddress)
@@ -158,7 +158,7 @@ func (m *mockMulticastGroupProtocol) SendReport(groupAddress tcpip.Address) (boo
// SendLeave implements ip.MulticastGroupProtocol.
//
// Precondition: m.mu must be locked.
-func (m *mockMulticastGroupProtocol) SendLeave(groupAddress tcpip.Address) *tcpip.Error {
+func (m *mockMulticastGroupProtocol) SendLeave(groupAddress tcpip.Address) tcpip.Error {
if m.mu.TryLock() {
m.mu.Unlock()
m.t.Fatalf("got write lock, expected to not take the lock; generic multicast protocol must take the write lock before sending leave for %s", groupAddress)
diff --git a/pkg/tcpip/network/ip/stats.go b/pkg/tcpip/network/ip/stats.go
new file mode 100644
index 000000000..898f8b356
--- /dev/null
+++ b/pkg/tcpip/network/ip/stats.go
@@ -0,0 +1,100 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package ip
+
+import "gvisor.dev/gvisor/pkg/tcpip"
+
+// LINT.IfChange(MultiCounterIPStats)
+
+// MultiCounterIPStats holds IP statistics, each counter may have several
+// versions.
+type MultiCounterIPStats struct {
+ // PacketsReceived is the total number of IP packets received from the link
+ // layer.
+ PacketsReceived tcpip.MultiCounterStat
+
+ // DisabledPacketsReceived is the total number of IP packets received from the
+ // link layer when the IP layer is disabled.
+ DisabledPacketsReceived tcpip.MultiCounterStat
+
+ // InvalidDestinationAddressesReceived is the total number of IP packets
+ // received with an unknown or invalid destination address.
+ InvalidDestinationAddressesReceived tcpip.MultiCounterStat
+
+ // InvalidSourceAddressesReceived is the total number of IP packets received
+ // with a source address that should never have been received on the wire.
+ InvalidSourceAddressesReceived tcpip.MultiCounterStat
+
+ // PacketsDelivered is the total number of incoming IP packets that are
+ // successfully delivered to the transport layer.
+ PacketsDelivered tcpip.MultiCounterStat
+
+ // PacketsSent is the total number of IP packets sent via WritePacket.
+ PacketsSent tcpip.MultiCounterStat
+
+ // OutgoingPacketErrors is the total number of IP packets which failed to
+ // write to a link-layer endpoint.
+ OutgoingPacketErrors tcpip.MultiCounterStat
+
+ // MalformedPacketsReceived is the total number of IP Packets that were
+ // dropped due to the IP packet header failing validation checks.
+ MalformedPacketsReceived tcpip.MultiCounterStat
+
+ // MalformedFragmentsReceived is the total number of IP Fragments that were
+ // dropped due to the fragment failing validation checks.
+ MalformedFragmentsReceived tcpip.MultiCounterStat
+
+ // IPTablesPreroutingDropped is the total number of IP packets dropped in the
+ // Prerouting chain.
+ IPTablesPreroutingDropped tcpip.MultiCounterStat
+
+ // IPTablesInputDropped is the total number of IP packets dropped in the Input
+ // chain.
+ IPTablesInputDropped tcpip.MultiCounterStat
+
+ // IPTablesOutputDropped is the total number of IP packets dropped in the
+ // Output chain.
+ IPTablesOutputDropped tcpip.MultiCounterStat
+
+ // OptionTSReceived is the number of Timestamp options seen.
+ OptionTSReceived tcpip.MultiCounterStat
+
+ // OptionRRReceived is the number of Record Route options seen.
+ OptionRRReceived tcpip.MultiCounterStat
+
+ // OptionUnknownReceived is the number of unknown IP options seen.
+ OptionUnknownReceived tcpip.MultiCounterStat
+}
+
+// Init sets internal counters to track a and b counters.
+func (m *MultiCounterIPStats) Init(a, b *tcpip.IPStats) {
+ m.PacketsReceived.Init(a.PacketsReceived, b.PacketsReceived)
+ m.DisabledPacketsReceived.Init(a.DisabledPacketsReceived, b.DisabledPacketsReceived)
+ m.InvalidDestinationAddressesReceived.Init(a.InvalidDestinationAddressesReceived, b.InvalidDestinationAddressesReceived)
+ m.InvalidSourceAddressesReceived.Init(a.InvalidSourceAddressesReceived, b.InvalidSourceAddressesReceived)
+ m.PacketsDelivered.Init(a.PacketsDelivered, b.PacketsDelivered)
+ m.PacketsSent.Init(a.PacketsSent, b.PacketsSent)
+ m.OutgoingPacketErrors.Init(a.OutgoingPacketErrors, b.OutgoingPacketErrors)
+ m.MalformedPacketsReceived.Init(a.MalformedPacketsReceived, b.MalformedPacketsReceived)
+ m.MalformedFragmentsReceived.Init(a.MalformedFragmentsReceived, b.MalformedFragmentsReceived)
+ m.IPTablesPreroutingDropped.Init(a.IPTablesPreroutingDropped, b.IPTablesPreroutingDropped)
+ m.IPTablesInputDropped.Init(a.IPTablesInputDropped, b.IPTablesInputDropped)
+ m.IPTablesOutputDropped.Init(a.IPTablesOutputDropped, b.IPTablesOutputDropped)
+ m.OptionTSReceived.Init(a.OptionTSReceived, b.OptionTSReceived)
+ m.OptionRRReceived.Init(a.OptionRRReceived, b.OptionRRReceived)
+ m.OptionUnknownReceived.Init(a.OptionUnknownReceived, b.OptionUnknownReceived)
+}
+
+// LINT.ThenChange(:MultiCounterIPStats, ../../tcpip.go:IPStats)
diff --git a/pkg/tcpip/network/ip_test.go b/pkg/tcpip/network/ip_test.go
index 3005973d7..47cce79bb 100644
--- a/pkg/tcpip/network/ip_test.go
+++ b/pkg/tcpip/network/ip_test.go
@@ -18,6 +18,7 @@ import (
"strings"
"testing"
+ "github.com/google/go-cmp/cmp"
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
@@ -167,7 +168,7 @@ func (*testObject) Wait() {}
// WritePacket is called by network endpoints after producing a packet and
// writing it to the link endpoint. This is used by the test object to verify
// that the produced packet is as expected.
-func (t *testObject) WritePacket(_ *stack.Route, _ *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error {
+func (t *testObject) WritePacket(_ *stack.Route, _ *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error {
var prot tcpip.TransportProtocolNumber
var srcAddr tcpip.Address
var dstAddr tcpip.Address
@@ -189,7 +190,7 @@ func (t *testObject) WritePacket(_ *stack.Route, _ *stack.GSO, protocol tcpip.Ne
}
// WritePackets implements stack.LinkEndpoint.WritePackets.
-func (*testObject) WritePackets(_ *stack.Route, _ *stack.GSO, pkt stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) {
+func (*testObject) WritePackets(_ *stack.Route, _ *stack.GSO, pkt stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, tcpip.Error) {
panic("not implemented")
}
@@ -203,7 +204,7 @@ func (*testObject) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.Net
panic("not implemented")
}
-func buildIPv4Route(local, remote tcpip.Address) (*stack.Route, *tcpip.Error) {
+func buildIPv4Route(local, remote tcpip.Address) (*stack.Route, tcpip.Error) {
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol},
TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol, tcp.NewProtocol},
@@ -219,7 +220,7 @@ func buildIPv4Route(local, remote tcpip.Address) (*stack.Route, *tcpip.Error) {
return s.FindRoute(nicID, local, remote, ipv4.ProtocolNumber, false /* multicastLoop */)
}
-func buildIPv6Route(local, remote tcpip.Address) (*stack.Route, *tcpip.Error) {
+func buildIPv6Route(local, remote tcpip.Address) (*stack.Route, tcpip.Error) {
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol},
TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol, tcp.NewProtocol},
@@ -235,14 +236,14 @@ func buildIPv6Route(local, remote tcpip.Address) (*stack.Route, *tcpip.Error) {
return s.FindRoute(nicID, local, remote, ipv6.ProtocolNumber, false /* multicastLoop */)
}
-func buildDummyStackWithLinkEndpoint(t *testing.T) (*stack.Stack, *channel.Endpoint) {
+func buildDummyStackWithLinkEndpoint(t *testing.T, mtu uint32) (*stack.Stack, *channel.Endpoint) {
t.Helper()
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol},
TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol, tcp.NewProtocol},
})
- e := channel.New(0, 1280, "")
+ e := channel.New(1, mtu, "")
if err := s.CreateNIC(nicID, e); err != nil {
t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
}
@@ -263,7 +264,7 @@ func buildDummyStackWithLinkEndpoint(t *testing.T) (*stack.Stack, *channel.Endpo
func buildDummyStack(t *testing.T) *stack.Stack {
t.Helper()
- s, _ := buildDummyStackWithLinkEndpoint(t)
+ s, _ := buildDummyStackWithLinkEndpoint(t, header.IPv6MinimumMTU)
return s
}
@@ -306,8 +307,8 @@ func (t *testInterface) setEnabled(v bool) {
t.mu.disabled = !v
}
-func (*testInterface) WritePacketToRemote(tcpip.LinkAddress, *stack.GSO, tcpip.NetworkProtocolNumber, *stack.PacketBuffer) *tcpip.Error {
- return tcpip.ErrNotSupported
+func (*testInterface) WritePacketToRemote(tcpip.LinkAddress, *stack.GSO, tcpip.NetworkProtocolNumber, *stack.PacketBuffer) tcpip.Error {
+ return &tcpip.ErrNotSupported{}
}
func TestSourceAddressValidation(t *testing.T) {
@@ -416,7 +417,7 @@ func TestSourceAddressValidation(t *testing.T) {
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
- s, e := buildDummyStackWithLinkEndpoint(t)
+ s, e := buildDummyStackWithLinkEndpoint(t, header.IPv6MinimumMTU)
test.rxICMP(e, test.srcAddress)
var wantValid uint64
@@ -479,8 +480,9 @@ func TestEnableWhenNICDisabled(t *testing.T) {
// Attempting to enable the endpoint while the NIC is disabled should
// fail.
nic.setEnabled(false)
- if err := ep.Enable(); err != tcpip.ErrNotPermitted {
- t.Fatalf("got ep.Enable() = %s, want = %s", err, tcpip.ErrNotPermitted)
+ err := ep.Enable()
+ if _, ok := err.(*tcpip.ErrNotPermitted); !ok {
+ t.Fatalf("got ep.Enable() = %s, want = %s", err, &tcpip.ErrNotPermitted{})
}
// ep should consider the NIC's enabled status when determining its own
// enabled status so we "enable" the NIC to read just the endpoint's
@@ -1122,7 +1124,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) {
remoteAddr tcpip.Address
pktGen func(*testing.T, tcpip.Address) buffer.VectorisedView
checker func(*testing.T, *stack.PacketBuffer, tcpip.Address)
- expectedErr *tcpip.Error
+ expectedErr tcpip.Error
}{
{
name: "IPv4",
@@ -1187,7 +1189,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) {
ip.SetHeaderLength(header.IPv4MinimumSize - 1)
return hdr.View().ToVectorisedView()
},
- expectedErr: tcpip.ErrMalformedHeader,
+ expectedErr: &tcpip.ErrMalformedHeader{},
},
{
name: "IPv4 too small",
@@ -1205,7 +1207,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) {
})
return buffer.View(ip[:len(ip)-1]).ToVectorisedView()
},
- expectedErr: tcpip.ErrMalformedHeader,
+ expectedErr: &tcpip.ErrMalformedHeader{},
},
{
name: "IPv4 minimum size",
@@ -1465,7 +1467,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) {
})
return buffer.View(ip[:len(ip)-1]).ToVectorisedView()
},
- expectedErr: tcpip.ErrMalformedHeader,
+ expectedErr: &tcpip.ErrMalformedHeader{},
},
}
@@ -1490,7 +1492,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) {
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{test.protoFactory},
})
- e := channel.New(1, 1280, "")
+ e := channel.New(1, header.IPv6MinimumMTU, "")
if err := s.CreateNIC(nicID, e); err != nil {
t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err)
}
@@ -1506,10 +1508,13 @@ func TestWriteHeaderIncludedPacket(t *testing.T) {
}
defer r.Release()
- if err := r.WriteHeaderIncludedPacket(stack.NewPacketBuffer(stack.PacketBufferOptions{
- Data: test.pktGen(t, subTest.srcAddr),
- })); err != test.expectedErr {
- t.Fatalf("got r.WriteHeaderIncludedPacket(_) = %s, want = %s", err, test.expectedErr)
+ {
+ err := r.WriteHeaderIncludedPacket(stack.NewPacketBuffer(stack.PacketBufferOptions{
+ Data: test.pktGen(t, subTest.srcAddr),
+ }))
+ if diff := cmp.Diff(test.expectedErr, err); diff != "" {
+ t.Fatalf("unexpected error from r.WriteHeaderIncludedPacket(_), (-want, +got):\n%s", diff)
+ }
}
if test.expectedErr != nil {
@@ -1526,3 +1531,246 @@ func TestWriteHeaderIncludedPacket(t *testing.T) {
})
}
}
+
+// Test that the included data in an ICMP error packet conforms to the
+// requirements of RFC 972, RFC 4443 section 2.4 and RFC 1812 Section 4.3.2.3
+func TestICMPInclusionSize(t *testing.T) {
+ const (
+ replyHeaderLength4 = header.IPv4MinimumSize + header.IPv4MinimumSize + header.ICMPv4MinimumSize
+ replyHeaderLength6 = header.IPv6MinimumSize + header.IPv6MinimumSize + header.ICMPv6MinimumSize
+ targetSize4 = header.IPv4MinimumProcessableDatagramSize
+ targetSize6 = header.IPv6MinimumMTU
+ // A protocol number that will cause an error response.
+ reservedProtocol = 254
+ )
+
+ // IPv4 function to create a IP packet and send it to the stack.
+ // The packet should generate an error response. We can do that by using an
+ // unknown transport protocol (254).
+ rxIPv4Bad := func(e *channel.Endpoint, src tcpip.Address, payload []byte) buffer.View {
+ totalLen := header.IPv4MinimumSize + len(payload)
+ hdr := buffer.NewPrependable(header.IPv4MinimumSize)
+ ip := header.IPv4(hdr.Prepend(header.IPv4MinimumSize))
+ ip.Encode(&header.IPv4Fields{
+ TotalLength: uint16(totalLen),
+ Protocol: reservedProtocol,
+ TTL: ipv4.DefaultTTL,
+ SrcAddr: src,
+ DstAddr: localIPv4Addr,
+ })
+ ip.SetChecksum(^ip.CalculateChecksum())
+ vv := hdr.View().ToVectorisedView()
+ vv.AppendView(buffer.View(payload))
+ // Take a copy before InjectInbound takes ownership of vv
+ // as vv may be changed during the call.
+ v := vv.ToView()
+ e.InjectInbound(header.IPv4ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
+ Data: vv,
+ }))
+ return v
+ }
+
+ // IPv6 function to create a packet and send it to the stack.
+ // The packet should be errant in a way that causes the stack to send an
+ // ICMP error response and have enough data to allow the testing of the
+ // inclusion of the errant packet. Use `unknown next header' to generate
+ // the error.
+ rxIPv6Bad := func(e *channel.Endpoint, src tcpip.Address, payload []byte) buffer.View {
+ hdr := buffer.NewPrependable(header.IPv6MinimumSize)
+ ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
+ ip.Encode(&header.IPv6Fields{
+ PayloadLength: uint16(len(payload)),
+ TransportProtocol: reservedProtocol,
+ HopLimit: ipv6.DefaultTTL,
+ SrcAddr: src,
+ DstAddr: localIPv6Addr,
+ })
+ vv := hdr.View().ToVectorisedView()
+ vv.AppendView(buffer.View(payload))
+ // Take a copy before InjectInbound takes ownership of vv
+ // as vv may be changed during the call.
+ v := vv.ToView()
+
+ e.InjectInbound(header.IPv6ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
+ Data: vv,
+ }))
+ return v
+ }
+
+ v4Checker := func(t *testing.T, pkt *stack.PacketBuffer, payload buffer.View) {
+ // We already know the entire packet is the right size so we can use its
+ // length to calculate the right payload size to check.
+ expectedPayloadLength := pkt.Size() - header.IPv4MinimumSize - header.ICMPv4MinimumSize
+ checker.IPv4(t, stack.PayloadSince(pkt.NetworkHeader()),
+ checker.SrcAddr(localIPv4Addr),
+ checker.DstAddr(remoteIPv4Addr),
+ checker.IPv4HeaderLength(header.IPv4MinimumSize),
+ checker.IPFullLength(uint16(header.IPv4MinimumSize+header.ICMPv4MinimumSize+expectedPayloadLength)),
+ checker.ICMPv4(
+ checker.ICMPv4Checksum(),
+ checker.ICMPv4Type(header.ICMPv4DstUnreachable),
+ checker.ICMPv4Code(header.ICMPv4ProtoUnreachable),
+ checker.ICMPv4Payload(payload[:expectedPayloadLength]),
+ ),
+ )
+ }
+
+ v6Checker := func(t *testing.T, pkt *stack.PacketBuffer, payload buffer.View) {
+ // We already know the entire packet is the right size so we can use its
+ // length to calculate the right payload size to check.
+ expectedPayloadLength := pkt.Size() - header.IPv6MinimumSize - header.ICMPv6MinimumSize
+ checker.IPv6(t, stack.PayloadSince(pkt.NetworkHeader()),
+ checker.SrcAddr(localIPv6Addr),
+ checker.DstAddr(remoteIPv6Addr),
+ checker.IPFullLength(uint16(header.IPv6MinimumSize+header.ICMPv6MinimumSize+expectedPayloadLength)),
+ checker.ICMPv6(
+ checker.ICMPv6Type(header.ICMPv6ParamProblem),
+ checker.ICMPv6Code(header.ICMPv6UnknownHeader),
+ checker.ICMPv6Payload(payload[:expectedPayloadLength]),
+ ),
+ )
+ }
+ tests := []struct {
+ name string
+ srcAddress tcpip.Address
+ injector func(*channel.Endpoint, tcpip.Address, []byte) buffer.View
+ checker func(*testing.T, *stack.PacketBuffer, buffer.View)
+ payloadLength int // Not including IP header.
+ linkMTU uint32 // Largest IP packet that the link can send as payload.
+ replyLength int // Total size of IP/ICMP packet expected back.
+ }{
+ {
+ name: "IPv4 exact match",
+ srcAddress: remoteIPv4Addr,
+ injector: rxIPv4Bad,
+ checker: v4Checker,
+ payloadLength: targetSize4 - replyHeaderLength4,
+ linkMTU: targetSize4,
+ replyLength: targetSize4,
+ },
+ {
+ name: "IPv4 larger MTU",
+ srcAddress: remoteIPv4Addr,
+ injector: rxIPv4Bad,
+ checker: v4Checker,
+ payloadLength: targetSize4,
+ linkMTU: targetSize4 + 1000,
+ replyLength: targetSize4,
+ },
+ {
+ name: "IPv4 smaller MTU",
+ srcAddress: remoteIPv4Addr,
+ injector: rxIPv4Bad,
+ checker: v4Checker,
+ payloadLength: targetSize4,
+ linkMTU: targetSize4 - 50,
+ replyLength: targetSize4 - 50,
+ },
+ {
+ name: "IPv4 payload exceeds",
+ srcAddress: remoteIPv4Addr,
+ injector: rxIPv4Bad,
+ checker: v4Checker,
+ payloadLength: targetSize4 + 10,
+ linkMTU: targetSize4,
+ replyLength: targetSize4,
+ },
+ {
+ name: "IPv4 1 byte less",
+ srcAddress: remoteIPv4Addr,
+ injector: rxIPv4Bad,
+ checker: v4Checker,
+ payloadLength: targetSize4 - replyHeaderLength4 - 1,
+ linkMTU: targetSize4,
+ replyLength: targetSize4 - 1,
+ },
+ {
+ name: "IPv4 No payload",
+ srcAddress: remoteIPv4Addr,
+ injector: rxIPv4Bad,
+ checker: v4Checker,
+ payloadLength: 0,
+ linkMTU: targetSize4,
+ replyLength: replyHeaderLength4,
+ },
+ {
+ name: "IPv6 exact match",
+ srcAddress: remoteIPv6Addr,
+ injector: rxIPv6Bad,
+ checker: v6Checker,
+ payloadLength: targetSize6 - replyHeaderLength6,
+ linkMTU: targetSize6,
+ replyLength: targetSize6,
+ },
+ {
+ name: "IPv6 larger MTU",
+ srcAddress: remoteIPv6Addr,
+ injector: rxIPv6Bad,
+ checker: v6Checker,
+ payloadLength: targetSize6,
+ linkMTU: targetSize6 + 400,
+ replyLength: targetSize6,
+ },
+ // NB. No "smaller MTU" test here as less than 1280 is not permitted
+ // in IPv6.
+ {
+ name: "IPv6 payload exceeds",
+ srcAddress: remoteIPv6Addr,
+ injector: rxIPv6Bad,
+ checker: v6Checker,
+ payloadLength: targetSize6,
+ linkMTU: targetSize6,
+ replyLength: targetSize6,
+ },
+ {
+ name: "IPv6 1 byte less",
+ srcAddress: remoteIPv6Addr,
+ injector: rxIPv6Bad,
+ checker: v6Checker,
+ payloadLength: targetSize6 - replyHeaderLength6 - 1,
+ linkMTU: targetSize6,
+ replyLength: targetSize6 - 1,
+ },
+ {
+ name: "IPv6 no payload",
+ srcAddress: remoteIPv6Addr,
+ injector: rxIPv6Bad,
+ checker: v6Checker,
+ payloadLength: 0,
+ linkMTU: targetSize6,
+ replyLength: replyHeaderLength6,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ s, e := buildDummyStackWithLinkEndpoint(t, test.linkMTU)
+ // Allocate and initialize the payload view.
+ payload := buffer.NewView(test.payloadLength)
+ for i := 0; i < len(payload); i++ {
+ payload[i] = uint8(i)
+ }
+ // Default routes for IPv4&6 so ICMP can find a route to the remote
+ // node when attempting to send the ICMP error Reply.
+ s.SetRouteTable([]tcpip.Route{
+ {
+ Destination: header.IPv4EmptySubnet,
+ NIC: nicID,
+ },
+ {
+ Destination: header.IPv6EmptySubnet,
+ NIC: nicID,
+ },
+ })
+ v := test.injector(e, test.srcAddress, payload)
+ pkt, ok := e.Read()
+ if !ok {
+ t.Fatal("expected a packet to be written")
+ }
+ if got, want := pkt.Pkt.Size(), test.replyLength; got != want {
+ t.Fatalf("got %d bytes of icmp error packet, want %d", got, want)
+ }
+ test.checker(t, pkt.Pkt, v)
+ })
+ }
+}
diff --git a/pkg/tcpip/network/ipv4/BUILD b/pkg/tcpip/network/ipv4/BUILD
index 32f53f217..330a7d170 100644
--- a/pkg/tcpip/network/ipv4/BUILD
+++ b/pkg/tcpip/network/ipv4/BUILD
@@ -8,6 +8,7 @@ go_library(
"icmp.go",
"igmp.go",
"ipv4.go",
+ "stats.go",
],
visibility = ["//visibility:public"],
deps = [
@@ -49,3 +50,15 @@ go_test(
"@com_github_google_go_cmp//cmp:go_default_library",
],
)
+
+go_test(
+ name = "stats_test",
+ size = "small",
+ srcs = ["stats_test.go"],
+ library = ":ipv4",
+ deps = [
+ "//pkg/tcpip",
+ "//pkg/tcpip/network/testutil",
+ "//pkg/tcpip/stack",
+ ],
+)
diff --git a/pkg/tcpip/network/ipv4/icmp.go b/pkg/tcpip/network/ipv4/icmp.go
index 8e392f86c..3d93a2cd0 100644
--- a/pkg/tcpip/network/ipv4/icmp.go
+++ b/pkg/tcpip/network/ipv4/icmp.go
@@ -1,4 +1,4 @@
-// Copyright 2018 The gVisor Authors.
+// Copyright 2021 The gVisor Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -15,7 +15,6 @@
package ipv4
import (
- "errors"
"fmt"
"gvisor.dev/gvisor/pkg/tcpip"
@@ -62,21 +61,20 @@ func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, pkt *stack
}
func (e *endpoint) handleICMP(pkt *stack.PacketBuffer) {
- stats := e.protocol.stack.Stats()
- received := stats.ICMP.V4.PacketsReceived
+ received := e.stats.icmp.packetsReceived
// TODO(gvisor.dev/issue/170): ICMP packets don't have their
// TransportHeader fields set. See icmp/protocol.go:protocol.Parse for a
// full explanation.
v, ok := pkt.Data.PullUp(header.ICMPv4MinimumSize)
if !ok {
- received.Invalid.Increment()
+ received.invalid.Increment()
return
}
h := header.ICMPv4(v)
// Only do in-stack processing if the checksum is correct.
if header.ChecksumVV(pkt.Data, 0 /* initial */) != 0xffff {
- received.Invalid.Increment()
+ received.invalid.Increment()
// It's possible that a raw socket expects to receive this regardless
// of checksum errors. If it's an echo request we know it's safe because
// we are the only handler, however other types do not cope well with
@@ -106,19 +104,14 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer) {
} else {
op = &optionUsageReceive{}
}
- aux, tmp, err := e.processIPOptions(pkt, opts, op)
- if err != nil {
- switch {
- case
- errors.Is(err, header.ErrIPv4OptDuplicate),
- errors.Is(err, errIPv4RecordRouteOptInvalidLength),
- errors.Is(err, errIPv4RecordRouteOptInvalidPointer),
- errors.Is(err, errIPv4TimestampOptInvalidLength),
- errors.Is(err, errIPv4TimestampOptInvalidPointer),
- errors.Is(err, errIPv4TimestampOptOverflow):
- _ = e.protocol.returnError(&icmpReasonParamProblem{pointer: aux}, pkt)
- stats.MalformedRcvdPackets.Increment()
- stats.IP.MalformedPacketsReceived.Increment()
+ tmp, optProblem := e.processIPOptions(pkt, opts, op)
+ if optProblem != nil {
+ if optProblem.NeedICMP {
+ _ = e.protocol.returnError(&icmpReasonParamProblem{
+ pointer: optProblem.Pointer,
+ }, pkt)
+ e.protocol.stack.Stats().MalformedRcvdPackets.Increment()
+ e.stats.ip.MalformedPacketsReceived.Increment()
}
return
}
@@ -128,11 +121,11 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer) {
// TODO(b/112892170): Meaningfully handle all ICMP types.
switch h.Type() {
case header.ICMPv4Echo:
- received.Echo.Increment()
+ received.echo.Increment()
- sent := stats.ICMP.V4.PacketsSent
+ sent := e.stats.icmp.packetsSent
if !e.protocol.stack.AllowICMPMessage() {
- sent.RateLimited.Increment()
+ sent.rateLimited.Increment()
return
}
@@ -213,18 +206,18 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer) {
replyPkt.TransportProtocolNumber = header.ICMPv4ProtocolNumber
if err := r.WriteHeaderIncludedPacket(replyPkt); err != nil {
- sent.Dropped.Increment()
+ sent.dropped.Increment()
return
}
- sent.EchoReply.Increment()
+ sent.echoReply.Increment()
case header.ICMPv4EchoReply:
- received.EchoReply.Increment()
+ received.echoReply.Increment()
e.dispatcher.DeliverTransportPacket(header.ICMPv4ProtocolNumber, pkt)
case header.ICMPv4DstUnreachable:
- received.DstUnreachable.Increment()
+ received.dstUnreachable.Increment()
pkt.Data.TrimFront(header.ICMPv4MinimumSize)
switch h.Code() {
@@ -243,31 +236,31 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer) {
}
case header.ICMPv4SrcQuench:
- received.SrcQuench.Increment()
+ received.srcQuench.Increment()
case header.ICMPv4Redirect:
- received.Redirect.Increment()
+ received.redirect.Increment()
case header.ICMPv4TimeExceeded:
- received.TimeExceeded.Increment()
+ received.timeExceeded.Increment()
case header.ICMPv4ParamProblem:
- received.ParamProblem.Increment()
+ received.paramProblem.Increment()
case header.ICMPv4Timestamp:
- received.Timestamp.Increment()
+ received.timestamp.Increment()
case header.ICMPv4TimestampReply:
- received.TimestampReply.Increment()
+ received.timestampReply.Increment()
case header.ICMPv4InfoRequest:
- received.InfoRequest.Increment()
+ received.infoRequest.Increment()
case header.ICMPv4InfoReply:
- received.InfoReply.Increment()
+ received.infoReply.Increment()
default:
- received.Invalid.Increment()
+ received.invalid.Increment()
}
}
@@ -317,7 +310,7 @@ func (*icmpReasonParamProblem) isICMPReason() {}
// the problematic packet. It incorporates as much of that packet as
// possible as well as any error metadata as is available. returnError
// expects pkt to hold a valid IPv4 packet as per the wire format.
-func (p *protocol) returnError(reason icmpReason, pkt *stack.PacketBuffer) *tcpip.Error {
+func (p *protocol) returnError(reason icmpReason, pkt *stack.PacketBuffer) tcpip.Error {
origIPHdr := header.IPv4(pkt.NetworkHeader().View())
origIPHdrSrc := origIPHdr.SourceAddress()
origIPHdrDst := origIPHdr.DestinationAddress()
@@ -379,9 +372,17 @@ func (p *protocol) returnError(reason icmpReason, pkt *stack.PacketBuffer) *tcpi
}
defer route.Release()
- sent := p.stack.Stats().ICMP.V4.PacketsSent
+ p.mu.Lock()
+ netEP, ok := p.mu.eps[pkt.NICID]
+ p.mu.Unlock()
+ if !ok {
+ return &tcpip.ErrNotConnected{}
+ }
+
+ sent := netEP.stats.icmp.packetsSent
+
if !p.stack.AllowICMPMessage() {
- sent.RateLimited.Increment()
+ sent.rateLimited.Increment()
return nil
}
@@ -435,13 +436,13 @@ func (p *protocol) returnError(reason icmpReason, pkt *stack.PacketBuffer) *tcpi
// systems implement the RFC 1812 definition and not the original
// requirement. We treat 8 bytes as the minimum but will try send more.
mtu := int(route.MTU())
- if mtu > header.IPv4MinimumProcessableDatagramSize {
- mtu = header.IPv4MinimumProcessableDatagramSize
+ const maxIPData = header.IPv4MinimumProcessableDatagramSize - header.IPv4MinimumSize
+ if mtu > maxIPData {
+ mtu = maxIPData
}
- headerLen := int(route.MaxHeaderLength()) + header.ICMPv4MinimumSize
- available := int(mtu) - headerLen
+ available := mtu - header.ICMPv4MinimumSize
- if available < header.IPv4MinimumSize+header.ICMPv4MinimumErrorPayloadSize {
+ if available < len(origIPHdr)+header.ICMPv4MinimumErrorPayloadSize {
return nil
}
@@ -464,36 +465,36 @@ func (p *protocol) returnError(reason icmpReason, pkt *stack.PacketBuffer) *tcpi
payload.CapLength(payloadLen)
icmpPkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
- ReserveHeaderBytes: headerLen,
+ ReserveHeaderBytes: int(route.MaxHeaderLength()) + header.ICMPv4MinimumSize,
Data: payload,
})
icmpPkt.TransportProtocolNumber = header.ICMPv4ProtocolNumber
icmpHdr := header.ICMPv4(icmpPkt.TransportHeader().Push(header.ICMPv4MinimumSize))
- var counter *tcpip.StatCounter
+ var counter tcpip.MultiCounterStat
switch reason := reason.(type) {
case *icmpReasonPortUnreachable:
icmpHdr.SetType(header.ICMPv4DstUnreachable)
icmpHdr.SetCode(header.ICMPv4PortUnreachable)
- counter = sent.DstUnreachable
+ counter = sent.dstUnreachable
case *icmpReasonProtoUnreachable:
icmpHdr.SetType(header.ICMPv4DstUnreachable)
icmpHdr.SetCode(header.ICMPv4ProtoUnreachable)
- counter = sent.DstUnreachable
+ counter = sent.dstUnreachable
case *icmpReasonTTLExceeded:
icmpHdr.SetType(header.ICMPv4TimeExceeded)
icmpHdr.SetCode(header.ICMPv4TTLExceeded)
- counter = sent.TimeExceeded
+ counter = sent.timeExceeded
case *icmpReasonReassemblyTimeout:
icmpHdr.SetType(header.ICMPv4TimeExceeded)
icmpHdr.SetCode(header.ICMPv4ReassemblyTimeout)
- counter = sent.TimeExceeded
+ counter = sent.timeExceeded
case *icmpReasonParamProblem:
icmpHdr.SetType(header.ICMPv4ParamProblem)
icmpHdr.SetCode(header.ICMPv4UnusedCode)
icmpHdr.SetPointer(reason.pointer)
- counter = sent.ParamProblem
+ counter = sent.paramProblem
default:
panic(fmt.Sprintf("unsupported ICMP type %T", reason))
}
@@ -508,7 +509,7 @@ func (p *protocol) returnError(reason icmpReason, pkt *stack.PacketBuffer) *tcpi
},
icmpPkt,
); err != nil {
- sent.Dropped.Increment()
+ sent.dropped.Increment()
return err
}
counter.Increment()
diff --git a/pkg/tcpip/network/ipv4/igmp.go b/pkg/tcpip/network/ipv4/igmp.go
index d9b5fe6ed..4cd0b3256 100644
--- a/pkg/tcpip/network/ipv4/igmp.go
+++ b/pkg/tcpip/network/ipv4/igmp.go
@@ -103,7 +103,7 @@ func (igmp *igmpState) Enabled() bool {
// SendReport implements ip.MulticastGroupProtocol.
//
// Precondition: igmp.ep.mu must be read locked.
-func (igmp *igmpState) SendReport(groupAddress tcpip.Address) (bool, *tcpip.Error) {
+func (igmp *igmpState) SendReport(groupAddress tcpip.Address) (bool, tcpip.Error) {
igmpType := header.IGMPv2MembershipReport
if igmp.v1Present() {
igmpType = header.IGMPv1MembershipReport
@@ -114,7 +114,7 @@ func (igmp *igmpState) SendReport(groupAddress tcpip.Address) (bool, *tcpip.Erro
// SendLeave implements ip.MulticastGroupProtocol.
//
// Precondition: igmp.ep.mu must be read locked.
-func (igmp *igmpState) SendLeave(groupAddress tcpip.Address) *tcpip.Error {
+func (igmp *igmpState) SendLeave(groupAddress tcpip.Address) tcpip.Error {
// As per RFC 2236 Section 6, Page 8: "If the interface state says the
// Querier is running IGMPv1, this action SHOULD be skipped. If the flag
// saying we were the last host to report is cleared, this action MAY be
@@ -149,51 +149,49 @@ func (igmp *igmpState) init(ep *endpoint) {
//
// Precondition: igmp.ep.mu must be locked.
func (igmp *igmpState) handleIGMP(pkt *stack.PacketBuffer) {
- stats := igmp.ep.protocol.stack.Stats()
- received := stats.IGMP.PacketsReceived
+ received := igmp.ep.stats.igmp.packetsReceived
headerView, ok := pkt.Data.PullUp(header.IGMPMinimumSize)
if !ok {
- received.Invalid.Increment()
+ received.invalid.Increment()
return
}
h := header.IGMP(headerView)
- // Temporarily reset the checksum field to 0 in order to calculate the proper
- // checksum.
- wantChecksum := h.Checksum()
- h.SetChecksum(0)
- gotChecksum := ^header.ChecksumVV(pkt.Data, 0 /* initial */)
- h.SetChecksum(wantChecksum)
-
- if gotChecksum != wantChecksum {
- received.ChecksumErrors.Increment()
+ // As per RFC 1071 section 1.3,
+ //
+ // To check a checksum, the 1's complement sum is computed over the
+ // same set of octets, including the checksum field. If the result
+ // is all 1 bits (-0 in 1's complement arithmetic), the check
+ // succeeds.
+ if header.ChecksumVV(pkt.Data, 0 /* initial */) != 0xFFFF {
+ received.checksumErrors.Increment()
return
}
switch h.Type() {
case header.IGMPMembershipQuery:
- received.MembershipQuery.Increment()
+ received.membershipQuery.Increment()
if len(headerView) < header.IGMPQueryMinimumSize {
- received.Invalid.Increment()
+ received.invalid.Increment()
return
}
igmp.handleMembershipQuery(h.GroupAddress(), h.MaxRespTime())
case header.IGMPv1MembershipReport:
- received.V1MembershipReport.Increment()
+ received.v1MembershipReport.Increment()
if len(headerView) < header.IGMPReportMinimumSize {
- received.Invalid.Increment()
+ received.invalid.Increment()
return
}
igmp.handleMembershipReport(h.GroupAddress())
case header.IGMPv2MembershipReport:
- received.V2MembershipReport.Increment()
+ received.v2MembershipReport.Increment()
if len(headerView) < header.IGMPReportMinimumSize {
- received.Invalid.Increment()
+ received.invalid.Increment()
return
}
igmp.handleMembershipReport(h.GroupAddress())
case header.IGMPLeaveGroup:
- received.LeaveGroup.Increment()
+ received.leaveGroup.Increment()
// As per RFC 2236 Section 6, Page 7: "IGMP messages other than Query or
// Report, are ignored in all states"
@@ -201,7 +199,7 @@ func (igmp *igmpState) handleIGMP(pkt *stack.PacketBuffer) {
// As per RFC 2236 Section 2.1 Page 3: "Unrecognized message types should
// be silently ignored. New message types may be used by newer versions of
// IGMP, by multicast routing protocols, or other uses."
- received.Unrecognized.Increment()
+ received.unrecognized.Increment()
}
}
@@ -244,7 +242,7 @@ func (igmp *igmpState) handleMembershipReport(groupAddress tcpip.Address) {
// writePacket assembles and sends an IGMP packet.
//
// Precondition: igmp.ep.mu must be read locked.
-func (igmp *igmpState) writePacket(destAddress tcpip.Address, groupAddress tcpip.Address, igmpType header.IGMPType) (bool, *tcpip.Error) {
+func (igmp *igmpState) writePacket(destAddress tcpip.Address, groupAddress tcpip.Address, igmpType header.IGMPType) (bool, tcpip.Error) {
igmpData := header.IGMP(buffer.NewView(header.IGMPReportMinimumSize))
igmpData.SetType(igmpType)
igmpData.SetGroupAddress(groupAddress)
@@ -272,18 +270,18 @@ func (igmp *igmpState) writePacket(destAddress tcpip.Address, groupAddress tcpip
panic(fmt.Sprintf("failed to add IP header: %s", err))
}
- sentStats := igmp.ep.protocol.stack.Stats().IGMP.PacketsSent
+ sentStats := igmp.ep.stats.igmp.packetsSent
if err := igmp.ep.nic.WritePacketToRemote(header.EthernetAddressFromMulticastIPv4Address(destAddress), nil /* gso */, ProtocolNumber, pkt); err != nil {
- sentStats.Dropped.Increment()
+ sentStats.dropped.Increment()
return false, err
}
switch igmpType {
case header.IGMPv1MembershipReport:
- sentStats.V1MembershipReport.Increment()
+ sentStats.v1MembershipReport.Increment()
case header.IGMPv2MembershipReport:
- sentStats.V2MembershipReport.Increment()
+ sentStats.v2MembershipReport.Increment()
case header.IGMPLeaveGroup:
- sentStats.LeaveGroup.Increment()
+ sentStats.leaveGroup.Increment()
default:
panic(fmt.Sprintf("unrecognized igmp type = %d", igmpType))
}
@@ -295,7 +293,7 @@ func (igmp *igmpState) writePacket(destAddress tcpip.Address, groupAddress tcpip
// messages.
//
// If the group already exists in the membership map, returns
-// tcpip.ErrDuplicateAddress.
+// *tcpip.ErrDuplicateAddress.
//
// Precondition: igmp.ep.mu must be locked.
func (igmp *igmpState) joinGroup(groupAddress tcpip.Address) {
@@ -314,13 +312,13 @@ func (igmp *igmpState) isInGroup(groupAddress tcpip.Address) bool {
// if required.
//
// Precondition: igmp.ep.mu must be locked.
-func (igmp *igmpState) leaveGroup(groupAddress tcpip.Address) *tcpip.Error {
+func (igmp *igmpState) leaveGroup(groupAddress tcpip.Address) tcpip.Error {
// LeaveGroup returns false only if the group was not joined.
if igmp.genericMulticastProtocol.LeaveGroupLocked(groupAddress) {
return nil
}
- return tcpip.ErrBadLocalAddress
+ return &tcpip.ErrBadLocalAddress{}
}
// softLeaveAll leaves all groups from the perspective of IGMP, but remains
diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go
index bb25a76fe..e5c80699d 100644
--- a/pkg/tcpip/network/ipv4/ipv4.go
+++ b/pkg/tcpip/network/ipv4/ipv4.go
@@ -1,4 +1,4 @@
-// Copyright 2018 The gVisor Authors.
+// Copyright 2021 The gVisor Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -16,9 +16,9 @@
package ipv4
import (
- "errors"
"fmt"
"math"
+ "reflect"
"sync/atomic"
"time"
@@ -73,6 +73,7 @@ type endpoint struct {
nic stack.NetworkInterface
dispatcher stack.TransportDispatcher
protocol *protocol
+ stats sharedStats
// enabled is set to 1 when the enpoint is enabled and 0 when it is
// disabled.
@@ -114,18 +115,36 @@ func (p *protocol) NewEndpoint(nic stack.NetworkInterface, _ stack.LinkAddressCa
e.mu.addressableEndpointState.Init(e)
e.mu.igmp.init(e)
e.mu.Unlock()
+
+ tcpip.InitStatCounters(reflect.ValueOf(&e.stats.localStats).Elem())
+
+ stackStats := p.stack.Stats()
+ e.stats.ip.Init(&e.stats.localStats.IP, &stackStats.IP)
+ e.stats.icmp.init(&e.stats.localStats.ICMP, &stackStats.ICMP.V4)
+ e.stats.igmp.init(&e.stats.localStats.IGMP, &stackStats.IGMP)
+
+ p.mu.Lock()
+ p.mu.eps[nic.ID()] = e
+ p.mu.Unlock()
+
return e
}
+func (p *protocol) forgetEndpoint(nicID tcpip.NICID) {
+ p.mu.Lock()
+ defer p.mu.Unlock()
+ delete(p.mu.eps, nicID)
+}
+
// Enable implements stack.NetworkEndpoint.
-func (e *endpoint) Enable() *tcpip.Error {
+func (e *endpoint) Enable() tcpip.Error {
e.mu.Lock()
defer e.mu.Unlock()
// If the NIC is not enabled, the endpoint can't do anything meaningful so
// don't enable the endpoint.
if !e.nic.Enabled() {
- return tcpip.ErrNotPermitted
+ return &tcpip.ErrNotPermitted{}
}
// If the endpoint is already enabled, there is nothing for it to do.
@@ -193,7 +212,9 @@ func (e *endpoint) disableLocked() {
}
// The endpoint may have already left the multicast group.
- if err := e.leaveGroupLocked(header.IPv4AllSystems); err != nil && err != tcpip.ErrBadLocalAddress {
+ switch err := e.leaveGroupLocked(header.IPv4AllSystems); err.(type) {
+ case nil, *tcpip.ErrBadLocalAddress:
+ default:
panic(fmt.Sprintf("unexpected error when leaving group = %s: %s", header.IPv4AllSystems, err))
}
@@ -202,7 +223,9 @@ func (e *endpoint) disableLocked() {
e.mu.igmp.softLeaveAll()
// The address may have already been removed.
- if err := e.mu.addressableEndpointState.RemovePermanentAddress(ipv4BroadcastAddr.Address); err != nil && err != tcpip.ErrBadLocalAddress {
+ switch err := e.mu.addressableEndpointState.RemovePermanentAddress(ipv4BroadcastAddr.Address); err.(type) {
+ case nil, *tcpip.ErrBadLocalAddress:
+ default:
panic(fmt.Sprintf("unexpected error when removing address = %s: %s", ipv4BroadcastAddr.Address, err))
}
@@ -237,7 +260,7 @@ func (e *endpoint) NetworkProtocolNumber() tcpip.NetworkProtocolNumber {
return e.protocol.Number()
}
-func (e *endpoint) addIPHeader(srcAddr, dstAddr tcpip.Address, pkt *stack.PacketBuffer, params stack.NetworkHeaderParams, options header.IPv4OptionsSerializer) *tcpip.Error {
+func (e *endpoint) addIPHeader(srcAddr, dstAddr tcpip.Address, pkt *stack.PacketBuffer, params stack.NetworkHeaderParams, options header.IPv4OptionsSerializer) tcpip.Error {
hdrLen := header.IPv4MinimumSize
var optLen int
if options != nil {
@@ -245,12 +268,12 @@ func (e *endpoint) addIPHeader(srcAddr, dstAddr tcpip.Address, pkt *stack.Packet
}
hdrLen += optLen
if hdrLen > header.IPv4MaximumHeaderSize {
- return tcpip.ErrMessageTooLong
+ return &tcpip.ErrMessageTooLong{}
}
ip := header.IPv4(pkt.NetworkHeader().Push(hdrLen))
length := pkt.Size()
if length > math.MaxUint16 {
- return tcpip.ErrMessageTooLong
+ return &tcpip.ErrMessageTooLong{}
}
// RFC 6864 section 4.3 mandates uniqueness of ID values for non-atomic
// datagrams. Since the DF bit is never being set here, all datagrams
@@ -275,7 +298,7 @@ func (e *endpoint) addIPHeader(srcAddr, dstAddr tcpip.Address, pkt *stack.Packet
// fragment. It returns the number of fragments handled and the number of
// fragments left to be processed. The IP header must already be present in the
// original packet.
-func (e *endpoint) handleFragments(r *stack.Route, gso *stack.GSO, networkMTU uint32, pkt *stack.PacketBuffer, handler func(*stack.PacketBuffer) *tcpip.Error) (int, int, *tcpip.Error) {
+func (e *endpoint) handleFragments(r *stack.Route, gso *stack.GSO, networkMTU uint32, pkt *stack.PacketBuffer, handler func(*stack.PacketBuffer) tcpip.Error) (int, int, tcpip.Error) {
// Round the MTU down to align to 8 bytes.
fragmentPayloadSize := networkMTU &^ 7
networkHeader := header.IPv4(pkt.NetworkHeader().View())
@@ -295,17 +318,17 @@ func (e *endpoint) handleFragments(r *stack.Route, gso *stack.GSO, networkMTU ui
}
// WritePacket writes a packet to the given destination address and protocol.
-func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.NetworkHeaderParams, pkt *stack.PacketBuffer) *tcpip.Error {
+func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.NetworkHeaderParams, pkt *stack.PacketBuffer) tcpip.Error {
if err := e.addIPHeader(r.LocalAddress, r.RemoteAddress, pkt, params, nil /* options */); err != nil {
return err
}
// iptables filtering. All packets that reach here are locally
// generated.
- nicName := e.protocol.stack.FindNICNameFromID(e.nic.ID())
- if ok := e.protocol.stack.IPTables().Check(stack.Output, pkt, gso, r, "", nicName); !ok {
+ outNicName := e.protocol.stack.FindNICNameFromID(e.nic.ID())
+ if ok := e.protocol.stack.IPTables().Check(stack.Output, pkt, gso, r, "" /* preroutingAddr */, "" /* inNicName */, outNicName); !ok {
// iptables is telling us to drop the packet.
- e.protocol.stack.Stats().IP.IPTablesOutputDropped.Increment()
+ e.stats.ip.IPTablesOutputDropped.Increment()
return nil
}
@@ -334,7 +357,7 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.Netw
return e.writePacket(r, gso, pkt, false /* headerIncluded */)
}
-func (e *endpoint) writePacket(r *stack.Route, gso *stack.GSO, pkt *stack.PacketBuffer, headerIncluded bool) *tcpip.Error {
+func (e *endpoint) writePacket(r *stack.Route, gso *stack.GSO, pkt *stack.PacketBuffer, headerIncluded bool) tcpip.Error {
if r.Loop&stack.PacketLoop != 0 {
pkt := pkt.CloneToInbound()
if e.protocol.stack.ParsePacketBuffer(ProtocolNumber, pkt) == stack.ParsedOK {
@@ -349,35 +372,37 @@ func (e *endpoint) writePacket(r *stack.Route, gso *stack.GSO, pkt *stack.Packet
return nil
}
+ stats := e.stats.ip
+
networkMTU, err := calculateNetworkMTU(e.nic.MTU(), uint32(pkt.NetworkHeader().View().Size()))
if err != nil {
- r.Stats().IP.OutgoingPacketErrors.Increment()
+ stats.OutgoingPacketErrors.Increment()
return err
}
if packetMustBeFragmented(pkt, networkMTU, gso) {
- sent, remain, err := e.handleFragments(r, gso, networkMTU, pkt, func(fragPkt *stack.PacketBuffer) *tcpip.Error {
+ sent, remain, err := e.handleFragments(r, gso, networkMTU, pkt, func(fragPkt *stack.PacketBuffer) tcpip.Error {
// TODO(gvisor.dev/issue/3884): Evaluate whether we want to send each
// fragment one by one using WritePacket() (current strategy) or if we
// want to create a PacketBufferList from the fragments and feed it to
// WritePackets(). It'll be faster but cost more memory.
return e.nic.WritePacket(r, gso, ProtocolNumber, fragPkt)
})
- r.Stats().IP.PacketsSent.IncrementBy(uint64(sent))
- r.Stats().IP.OutgoingPacketErrors.IncrementBy(uint64(remain))
+ stats.PacketsSent.IncrementBy(uint64(sent))
+ stats.OutgoingPacketErrors.IncrementBy(uint64(remain))
return err
}
if err := e.nic.WritePacket(r, gso, ProtocolNumber, pkt); err != nil {
- r.Stats().IP.OutgoingPacketErrors.Increment()
+ stats.OutgoingPacketErrors.Increment()
return err
}
- r.Stats().IP.PacketsSent.Increment()
+ stats.PacketsSent.Increment()
return nil
}
// WritePackets implements stack.NetworkEndpoint.WritePackets.
-func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.PacketBufferList, params stack.NetworkHeaderParams) (int, *tcpip.Error) {
+func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.PacketBufferList, params stack.NetworkHeaderParams) (int, tcpip.Error) {
if r.Loop&stack.PacketLoop != 0 {
panic("multiple packets in local loop")
}
@@ -385,6 +410,8 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe
return pkts.Len(), nil
}
+ stats := e.stats.ip
+
for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() {
if err := e.addIPHeader(r.LocalAddress, r.RemoteAddress, pkt, params, nil /* options */); err != nil {
return 0, err
@@ -392,7 +419,7 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe
networkMTU, err := calculateNetworkMTU(e.nic.MTU(), uint32(pkt.NetworkHeader().View().Size()))
if err != nil {
- r.Stats().IP.OutgoingPacketErrors.IncrementBy(uint64(pkts.Len()))
+ stats.OutgoingPacketErrors.IncrementBy(uint64(pkts.Len()))
return 0, err
}
@@ -400,7 +427,7 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe
// Keep track of the packet that is about to be fragmented so it can be
// removed once the fragmentation is done.
originalPkt := pkt
- if _, _, err := e.handleFragments(r, gso, networkMTU, pkt, func(fragPkt *stack.PacketBuffer) *tcpip.Error {
+ if _, _, err := e.handleFragments(r, gso, networkMTU, pkt, func(fragPkt *stack.PacketBuffer) tcpip.Error {
// Modify the packet list in place with the new fragments.
pkts.InsertAfter(pkt, fragPkt)
pkt = fragPkt
@@ -413,21 +440,21 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe
}
}
- nicName := e.protocol.stack.FindNICNameFromID(e.nic.ID())
+ outNicName := e.protocol.stack.FindNICNameFromID(e.nic.ID())
// iptables filtering. All packets that reach here are locally
// generated.
- dropped, natPkts := e.protocol.stack.IPTables().CheckPackets(stack.Output, pkts, gso, r, nicName)
+ dropped, natPkts := e.protocol.stack.IPTables().CheckPackets(stack.Output, pkts, gso, r, "", outNicName)
if len(dropped) == 0 && len(natPkts) == 0 {
// Fast path: If no packets are to be dropped then we can just invoke the
// faster WritePackets API directly.
n, err := e.nic.WritePackets(r, gso, pkts, ProtocolNumber)
- r.Stats().IP.PacketsSent.IncrementBy(uint64(n))
+ stats.PacketsSent.IncrementBy(uint64(n))
if err != nil {
- r.Stats().IP.OutgoingPacketErrors.IncrementBy(uint64(pkts.Len() - n))
+ stats.OutgoingPacketErrors.IncrementBy(uint64(pkts.Len() - n))
}
return n, err
}
- r.Stats().IP.IPTablesOutputDropped.IncrementBy(uint64(len(dropped)))
+ stats.IPTablesOutputDropped.IncrementBy(uint64(len(dropped)))
// Slow path as we are dropping some packets in the batch degrade to
// emitting one packet at a time.
@@ -451,36 +478,36 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe
}
}
if err := e.nic.WritePacket(r, gso, ProtocolNumber, pkt); err != nil {
- r.Stats().IP.PacketsSent.IncrementBy(uint64(n))
- r.Stats().IP.OutgoingPacketErrors.IncrementBy(uint64(pkts.Len() - n - len(dropped)))
+ stats.PacketsSent.IncrementBy(uint64(n))
+ stats.OutgoingPacketErrors.IncrementBy(uint64(pkts.Len() - n - len(dropped)))
// Dropped packets aren't errors, so include them in
// the return value.
return n + len(dropped), err
}
n++
}
- r.Stats().IP.PacketsSent.IncrementBy(uint64(n))
+ stats.PacketsSent.IncrementBy(uint64(n))
// Dropped packets aren't errors, so include them in the return value.
return n + len(dropped), nil
}
// WriteHeaderIncludedPacket implements stack.NetworkEndpoint.
-func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack.PacketBuffer) *tcpip.Error {
+func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack.PacketBuffer) tcpip.Error {
// The packet already has an IP header, but there are a few required
// checks.
h, ok := pkt.Data.PullUp(header.IPv4MinimumSize)
if !ok {
- return tcpip.ErrMalformedHeader
+ return &tcpip.ErrMalformedHeader{}
}
hdrLen := header.IPv4(h).HeaderLength()
if hdrLen < header.IPv4MinimumSize {
- return tcpip.ErrMalformedHeader
+ return &tcpip.ErrMalformedHeader{}
}
h, ok = pkt.Data.PullUp(int(hdrLen))
if !ok {
- return tcpip.ErrMalformedHeader
+ return &tcpip.ErrMalformedHeader{}
}
ip := header.IPv4(h)
@@ -518,14 +545,14 @@ func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack.PacketBu
// wire format. We also want to check if the header's fields are valid before
// sending the packet.
if !parse.IPv4(pkt) || !header.IPv4(pkt.NetworkHeader().View()).IsValid(pktSize) {
- return tcpip.ErrMalformedHeader
+ return &tcpip.ErrMalformedHeader{}
}
return e.writePacket(r, nil /* gso */, pkt, true /* headerIncluded */)
}
// forwardPacket attempts to forward a packet to its final destination.
-func (e *endpoint) forwardPacket(pkt *stack.PacketBuffer) *tcpip.Error {
+func (e *endpoint) forwardPacket(pkt *stack.PacketBuffer) tcpip.Error {
h := header.IPv4(pkt.NetworkHeader().View())
ttl := h.TTL()
if ttl == 0 {
@@ -545,7 +572,7 @@ func (e *endpoint) forwardPacket(pkt *stack.PacketBuffer) *tcpip.Error {
networkEndpoint.(*endpoint).handlePacket(pkt)
return nil
}
- if err != tcpip.ErrBadAddress {
+ if _, ok := err.(*tcpip.ErrBadAddress); !ok {
return err
}
@@ -577,19 +604,21 @@ func (e *endpoint) forwardPacket(pkt *stack.PacketBuffer) *tcpip.Error {
// HandlePacket is called by the link layer when new ipv4 packets arrive for
// this endpoint.
func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) {
- stats := e.protocol.stack.Stats()
- stats.IP.PacketsReceived.Increment()
+ stats := e.stats.ip
+
+ stats.PacketsReceived.Increment()
if !e.isEnabled() {
- stats.IP.DisabledPacketsReceived.Increment()
+ stats.DisabledPacketsReceived.Increment()
return
}
// Loopback traffic skips the prerouting chain.
if !e.nic.IsLoopback() {
- if ok := e.protocol.stack.IPTables().Check(stack.Prerouting, pkt, nil, nil, e.MainAddress().Address, ""); !ok {
+ inNicName := e.protocol.stack.FindNICNameFromID(e.nic.ID())
+ if ok := e.protocol.stack.IPTables().Check(stack.Prerouting, pkt, nil, nil, e.MainAddress().Address, inNicName, "" /* outNicName */); !ok {
// iptables is telling us to drop the packet.
- stats.IP.IPTablesPreroutingDropped.Increment()
+ stats.IPTablesPreroutingDropped.Increment()
return
}
}
@@ -601,11 +630,11 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) {
// iptables hook.
func (e *endpoint) handlePacket(pkt *stack.PacketBuffer) {
pkt.NICID = e.nic.ID()
- stats := e.protocol.stack.Stats()
+ stats := e.stats
h := header.IPv4(pkt.NetworkHeader().View())
if !h.IsValid(pkt.Data.Size() + pkt.NetworkHeader().View().Size() + pkt.TransportHeader().View().Size()) {
- stats.IP.MalformedPacketsReceived.Increment()
+ stats.ip.MalformedPacketsReceived.Increment()
return
}
@@ -631,7 +660,7 @@ func (e *endpoint) handlePacket(pkt *stack.PacketBuffer) {
// is all 1 bits (-0 in 1's complement arithmetic), the check
// succeeds.
if h.CalculateChecksum() != 0xffff {
- stats.IP.MalformedPacketsReceived.Increment()
+ stats.ip.MalformedPacketsReceived.Increment()
return
}
@@ -643,7 +672,7 @@ func (e *endpoint) handlePacket(pkt *stack.PacketBuffer) {
// be one of its own IP addresses (but not a broadcast or
// multicast address).
if srcAddr == header.IPv4Broadcast || header.IsV4MulticastAddress(srcAddr) {
- stats.IP.InvalidSourceAddressesReceived.Increment()
+ stats.ip.InvalidSourceAddressesReceived.Increment()
return
}
// Make sure the source address is not a subnet-local broadcast address.
@@ -651,7 +680,7 @@ func (e *endpoint) handlePacket(pkt *stack.PacketBuffer) {
subnet := addressEndpoint.Subnet()
addressEndpoint.DecRef()
if subnet.IsBroadcast(srcAddr) {
- stats.IP.InvalidSourceAddressesReceived.Increment()
+ stats.ip.InvalidSourceAddressesReceived.Increment()
return
}
}
@@ -664,7 +693,7 @@ func (e *endpoint) handlePacket(pkt *stack.PacketBuffer) {
pkt.NetworkPacketInfo.LocalAddressBroadcast = subnet.IsBroadcast(dstAddr) || dstAddr == header.IPv4Broadcast
} else if !e.IsInGroup(dstAddr) {
if !e.protocol.Forwarding() {
- stats.IP.InvalidDestinationAddressesReceived.Increment()
+ stats.ip.InvalidDestinationAddressesReceived.Increment()
return
}
@@ -674,9 +703,10 @@ func (e *endpoint) handlePacket(pkt *stack.PacketBuffer) {
// iptables filtering. All packets that reach here are intended for
// this machine and will not be forwarded.
- if ok := e.protocol.stack.IPTables().Check(stack.Input, pkt, nil, nil, "", ""); !ok {
+ inNicName := e.protocol.stack.FindNICNameFromID(e.nic.ID())
+ if ok := e.protocol.stack.IPTables().Check(stack.Input, pkt, nil, nil, "" /* preroutingAddr */, inNicName, "" /* outNicName */); !ok {
// iptables is telling us to drop the packet.
- stats.IP.IPTablesInputDropped.Increment()
+ stats.ip.IPTablesInputDropped.Increment()
return
}
@@ -684,8 +714,8 @@ func (e *endpoint) handlePacket(pkt *stack.PacketBuffer) {
if pkt.Data.Size()+pkt.TransportHeader().View().Size() == 0 {
// Drop the packet as it's marked as a fragment but has
// no payload.
- stats.IP.MalformedPacketsReceived.Increment()
- stats.IP.MalformedFragmentsReceived.Increment()
+ stats.ip.MalformedPacketsReceived.Increment()
+ stats.ip.MalformedFragmentsReceived.Increment()
return
}
// The packet is a fragment, let's try to reassemble it.
@@ -698,8 +728,8 @@ func (e *endpoint) handlePacket(pkt *stack.PacketBuffer) {
// size). Otherwise the packet would've been rejected as invalid before
// reaching here.
if int(start)+pkt.Data.Size() > header.IPv4MaximumPayloadSize {
- stats.IP.MalformedPacketsReceived.Increment()
- stats.IP.MalformedFragmentsReceived.Increment()
+ stats.ip.MalformedPacketsReceived.Increment()
+ stats.ip.MalformedFragmentsReceived.Increment()
return
}
@@ -720,8 +750,8 @@ func (e *endpoint) handlePacket(pkt *stack.PacketBuffer) {
pkt,
)
if err != nil {
- stats.IP.MalformedPacketsReceived.Increment()
- stats.IP.MalformedFragmentsReceived.Increment()
+ stats.ip.MalformedPacketsReceived.Increment()
+ stats.ip.MalformedFragmentsReceived.Increment()
return
}
if !ready {
@@ -734,7 +764,7 @@ func (e *endpoint) handlePacket(pkt *stack.PacketBuffer) {
h.SetTotalLength(uint16(pkt.Data.Size() + len((h))))
h.SetFlagsFragmentOffset(0, 0)
}
- stats.IP.PacketsDelivered.Increment()
+ stats.ip.PacketsDelivered.Increment()
p := h.TransportProtocol()
if p == header.ICMPv4ProtocolNumber {
@@ -755,19 +785,13 @@ func (e *endpoint) handlePacket(pkt *stack.PacketBuffer) {
// TODO(gvisor.dev/issue/4586):
// When we add forwarding support we should use the verified options
// rather than just throwing them away.
- aux, _, err := e.processIPOptions(pkt, opts, &optionUsageReceive{})
- if err != nil {
- switch {
- case
- errors.Is(err, header.ErrIPv4OptDuplicate),
- errors.Is(err, errIPv4RecordRouteOptInvalidPointer),
- errors.Is(err, errIPv4RecordRouteOptInvalidLength),
- errors.Is(err, errIPv4TimestampOptInvalidLength),
- errors.Is(err, errIPv4TimestampOptInvalidPointer),
- errors.Is(err, errIPv4TimestampOptOverflow):
- _ = e.protocol.returnError(&icmpReasonParamProblem{pointer: aux}, pkt)
- stats.MalformedRcvdPackets.Increment()
- stats.IP.MalformedPacketsReceived.Increment()
+ if _, optProblem := e.processIPOptions(pkt, opts, &optionUsageReceive{}); optProblem != nil {
+ if optProblem.NeedICMP {
+ _ = e.protocol.returnError(&icmpReasonParamProblem{
+ pointer: optProblem.Pointer,
+ }, pkt)
+ e.protocol.stack.Stats().MalformedRcvdPackets.Increment()
+ stats.ip.MalformedPacketsReceived.Increment()
}
return
}
@@ -800,10 +824,12 @@ func (e *endpoint) Close() {
e.disableLocked()
e.mu.addressableEndpointState.Cleanup()
+
+ e.protocol.forgetEndpoint(e.nic.ID())
}
// AddAndAcquirePermanentAddress implements stack.AddressableEndpoint.
-func (e *endpoint) AddAndAcquirePermanentAddress(addr tcpip.AddressWithPrefix, peb stack.PrimaryEndpointBehavior, configType stack.AddressConfigType, deprecated bool) (stack.AddressEndpoint, *tcpip.Error) {
+func (e *endpoint) AddAndAcquirePermanentAddress(addr tcpip.AddressWithPrefix, peb stack.PrimaryEndpointBehavior, configType stack.AddressConfigType, deprecated bool) (stack.AddressEndpoint, tcpip.Error) {
e.mu.Lock()
defer e.mu.Unlock()
@@ -815,7 +841,7 @@ func (e *endpoint) AddAndAcquirePermanentAddress(addr tcpip.AddressWithPrefix, p
}
// RemovePermanentAddress implements stack.AddressableEndpoint.
-func (e *endpoint) RemovePermanentAddress(addr tcpip.Address) *tcpip.Error {
+func (e *endpoint) RemovePermanentAddress(addr tcpip.Address) tcpip.Error {
e.mu.Lock()
defer e.mu.Unlock()
return e.mu.addressableEndpointState.RemovePermanentAddress(addr)
@@ -872,7 +898,7 @@ func (e *endpoint) PermanentAddresses() []tcpip.AddressWithPrefix {
}
// JoinGroup implements stack.GroupAddressableEndpoint.
-func (e *endpoint) JoinGroup(addr tcpip.Address) *tcpip.Error {
+func (e *endpoint) JoinGroup(addr tcpip.Address) tcpip.Error {
e.mu.Lock()
defer e.mu.Unlock()
return e.joinGroupLocked(addr)
@@ -881,9 +907,9 @@ func (e *endpoint) JoinGroup(addr tcpip.Address) *tcpip.Error {
// joinGroupLocked is like JoinGroup but with locking requirements.
//
// Precondition: e.mu must be locked.
-func (e *endpoint) joinGroupLocked(addr tcpip.Address) *tcpip.Error {
+func (e *endpoint) joinGroupLocked(addr tcpip.Address) tcpip.Error {
if !header.IsV4MulticastAddress(addr) {
- return tcpip.ErrBadAddress
+ return &tcpip.ErrBadAddress{}
}
e.mu.igmp.joinGroup(addr)
@@ -891,7 +917,7 @@ func (e *endpoint) joinGroupLocked(addr tcpip.Address) *tcpip.Error {
}
// LeaveGroup implements stack.GroupAddressableEndpoint.
-func (e *endpoint) LeaveGroup(addr tcpip.Address) *tcpip.Error {
+func (e *endpoint) LeaveGroup(addr tcpip.Address) tcpip.Error {
e.mu.Lock()
defer e.mu.Unlock()
return e.leaveGroupLocked(addr)
@@ -900,7 +926,7 @@ func (e *endpoint) LeaveGroup(addr tcpip.Address) *tcpip.Error {
// leaveGroupLocked is like LeaveGroup but with locking requirements.
//
// Precondition: e.mu must be locked.
-func (e *endpoint) leaveGroupLocked(addr tcpip.Address) *tcpip.Error {
+func (e *endpoint) leaveGroupLocked(addr tcpip.Address) tcpip.Error {
return e.mu.igmp.leaveGroup(addr)
}
@@ -911,6 +937,11 @@ func (e *endpoint) IsInGroup(addr tcpip.Address) bool {
return e.mu.igmp.isInGroup(addr)
}
+// Stats implements stack.NetworkEndpoint.
+func (e *endpoint) Stats() stack.NetworkEndpointStats {
+ return &e.stats.localStats
+}
+
var _ stack.ForwardingNetworkProtocol = (*protocol)(nil)
var _ stack.NetworkProtocol = (*protocol)(nil)
var _ fragmentation.TimeoutHandler = (*protocol)(nil)
@@ -918,6 +949,14 @@ var _ fragmentation.TimeoutHandler = (*protocol)(nil)
type protocol struct {
stack *stack.Stack
+ mu struct {
+ sync.RWMutex
+
+ // eps is keyed by NICID to allow protocol methods to retrieve an endpoint
+ // when handling a packet, by looking at which NIC handled the packet.
+ eps map[tcpip.NICID]*endpoint
+ }
+
// defaultTTL is the current default TTL for the protocol. Only the
// uint8 portion of it is meaningful.
//
@@ -960,24 +999,24 @@ func (*protocol) ParseAddresses(v buffer.View) (src, dst tcpip.Address) {
}
// SetOption implements NetworkProtocol.SetOption.
-func (p *protocol) SetOption(option tcpip.SettableNetworkProtocolOption) *tcpip.Error {
+func (p *protocol) SetOption(option tcpip.SettableNetworkProtocolOption) tcpip.Error {
switch v := option.(type) {
case *tcpip.DefaultTTLOption:
p.SetDefaultTTL(uint8(*v))
return nil
default:
- return tcpip.ErrUnknownProtocolOption
+ return &tcpip.ErrUnknownProtocolOption{}
}
}
// Option implements NetworkProtocol.Option.
-func (p *protocol) Option(option tcpip.GettableNetworkProtocolOption) *tcpip.Error {
+func (p *protocol) Option(option tcpip.GettableNetworkProtocolOption) tcpip.Error {
switch v := option.(type) {
case *tcpip.DefaultTTLOption:
*v = tcpip.DefaultTTLOption(p.DefaultTTL())
return nil
default:
- return tcpip.ErrUnknownProtocolOption
+ return &tcpip.ErrUnknownProtocolOption{}
}
}
@@ -1023,9 +1062,9 @@ func (p *protocol) SetForwarding(v bool) {
// calculateNetworkMTU calculates the network-layer payload MTU based on the
// link-layer payload mtu.
-func calculateNetworkMTU(linkMTU, networkHeaderSize uint32) (uint32, *tcpip.Error) {
+func calculateNetworkMTU(linkMTU, networkHeaderSize uint32) (uint32, tcpip.Error) {
if linkMTU < header.IPv4MinimumMTU {
- return 0, tcpip.ErrInvalidEndpointState
+ return 0, &tcpip.ErrInvalidEndpointState{}
}
// As per RFC 791 section 3.1, an IPv4 header cannot exceed 60 bytes in
@@ -1033,7 +1072,7 @@ func calculateNetworkMTU(linkMTU, networkHeaderSize uint32) (uint32, *tcpip.Erro
// The maximal internet header is 60 octets, and a typical internet header
// is 20 octets, allowing a margin for headers of higher level protocols.
if networkHeaderSize > header.IPv4MaximumHeaderSize {
- return 0, tcpip.ErrMalformedHeader
+ return 0, &tcpip.ErrMalformedHeader{}
}
networkMTU := linkMTU
@@ -1095,6 +1134,7 @@ func NewProtocolWithOptions(opts Options) stack.NetworkProtocolFactory {
options: opts,
}
p.fragmentation = fragmentation.NewFragmentation(fragmentblockSize, fragmentation.HighFragThreshold, fragmentation.LowFragThreshold, ReassembleTimeout, s.Clock(), p)
+ p.mu.eps = make(map[tcpip.NICID]*endpoint)
return p
}
}
@@ -1192,16 +1232,9 @@ func (*optionUsageEcho) actions() optionActions {
}
}
-var (
- errIPv4TimestampOptInvalidLength = errors.New("invalid Timestamp length")
- errIPv4TimestampOptInvalidPointer = errors.New("invalid Timestamp pointer")
- errIPv4TimestampOptOverflow = errors.New("overflow in Timestamp")
- errIPv4TimestampOptInvalidFlags = errors.New("invalid Timestamp flags")
-)
-
// handleTimestamp does any required processing on a Timestamp option
// in place.
-func handleTimestamp(tsOpt header.IPv4OptionTimestamp, localAddress tcpip.Address, clock tcpip.Clock, usage optionsUsage) (uint8, error) {
+func handleTimestamp(tsOpt header.IPv4OptionTimestamp, localAddress tcpip.Address, clock tcpip.Clock, usage optionsUsage) *header.IPv4OptParameterProblem {
flags := tsOpt.Flags()
var entrySize uint8
switch flags {
@@ -1212,7 +1245,10 @@ func handleTimestamp(tsOpt header.IPv4OptionTimestamp, localAddress tcpip.Addres
header.IPv4OptionTimestampWithPredefinedIPFlag:
entrySize = header.IPv4OptionTimestampWithAddrSize
default:
- return header.IPv4OptTSOFLWAndFLGOffset, errIPv4TimestampOptInvalidFlags
+ return &header.IPv4OptParameterProblem{
+ Pointer: header.IPv4OptTSOFLWAndFLGOffset,
+ NeedICMP: true,
+ }
}
pointer := tsOpt.Pointer()
@@ -1220,7 +1256,10 @@ func handleTimestamp(tsOpt header.IPv4OptionTimestamp, localAddress tcpip.Addres
// Since the pointer is 1 based, and the header is 4 bytes long the
// pointer must point beyond the header therefore 4 or less is bad.
if pointer <= header.IPv4OptionTimestampHdrLength {
- return header.IPv4OptTSPointerOffset, errIPv4TimestampOptInvalidPointer
+ return &header.IPv4OptParameterProblem{
+ Pointer: header.IPv4OptTSPointerOffset,
+ NeedICMP: true,
+ }
}
// To simplify processing below, base further work on the array of timestamps
// beyond the header, rather than on the whole option. Also to aid
@@ -1254,14 +1293,17 @@ func handleTimestamp(tsOpt header.IPv4OptionTimestamp, localAddress tcpip.Addres
// timestamp, but the overflow count is incremented by one.
if flags == header.IPv4OptionTimestampWithPredefinedIPFlag {
// By definition we have nothing to do.
- return 0, nil
+ return nil
}
if tsOpt.IncOverflow() != 0 {
- return 0, nil
+ return nil
}
// The overflow count is also full.
- return header.IPv4OptTSOFLWAndFLGOffset, errIPv4TimestampOptOverflow
+ return &header.IPv4OptParameterProblem{
+ Pointer: header.IPv4OptTSOFLWAndFLGOffset,
+ NeedICMP: true,
+ }
}
if nextSlot+entrySize > dataLength {
// The data area isn't full but there isn't room for a new entry.
@@ -1280,32 +1322,36 @@ func handleTimestamp(tsOpt header.IPv4OptionTimestamp, localAddress tcpip.Addres
if dataLength%entrySize != 0 {
// The Data section size should be a multiple of the expected
// timestamp entry size.
- return header.IPv4OptionLengthOffset, errIPv4TimestampOptInvalidLength
+ return &header.IPv4OptParameterProblem{
+ Pointer: header.IPv4OptionLengthOffset,
+ NeedICMP: false,
+ }
}
// If the size is OK, the pointer must be corrupted.
}
- return header.IPv4OptTSPointerOffset, errIPv4TimestampOptInvalidPointer
+ return &header.IPv4OptParameterProblem{
+ Pointer: header.IPv4OptTSPointerOffset,
+ NeedICMP: true,
+ }
}
if usage.actions().timestamp == optionProcess {
tsOpt.UpdateTimestamp(localAddress, clock)
}
- return 0, nil
+ return nil
}
-var (
- errIPv4RecordRouteOptInvalidLength = errors.New("invalid length in Record Route")
- errIPv4RecordRouteOptInvalidPointer = errors.New("invalid pointer in Record Route")
-)
-
// handleRecordRoute checks and processes a Record route option. It is much
// like the timestamp type 1 option, but without timestamps. The passed in
// address is stored in the option in the correct spot if possible.
-func handleRecordRoute(rrOpt header.IPv4OptionRecordRoute, localAddress tcpip.Address, usage optionsUsage) (uint8, error) {
+func handleRecordRoute(rrOpt header.IPv4OptionRecordRoute, localAddress tcpip.Address, usage optionsUsage) *header.IPv4OptParameterProblem {
optlen := rrOpt.Size()
if optlen < header.IPv4AddressSize+header.IPv4OptionRecordRouteHdrLength {
- return header.IPv4OptionLengthOffset, errIPv4RecordRouteOptInvalidLength
+ return &header.IPv4OptParameterProblem{
+ Pointer: header.IPv4OptionLengthOffset,
+ NeedICMP: true,
+ }
}
pointer := rrOpt.Pointer()
@@ -1315,7 +1361,10 @@ func handleRecordRoute(rrOpt header.IPv4OptionRecordRoute, localAddress tcpip.Ad
// Since the pointer is 1 based, and the header is 3 bytes long the
// pointer must point beyond the header therefore 3 or less is bad.
if pointer <= header.IPv4OptionRecordRouteHdrLength {
- return header.IPv4OptRRPointerOffset, errIPv4RecordRouteOptInvalidPointer
+ return &header.IPv4OptParameterProblem{
+ Pointer: header.IPv4OptRRPointerOffset,
+ NeedICMP: true,
+ }
}
// RFC 791 page 21 says
@@ -1332,7 +1381,7 @@ func handleRecordRoute(rrOpt header.IPv4OptionRecordRoute, localAddress tcpip.Ad
// of these words is a copy/paste error from the timestamp option where
// there are two failure reasons given.
if pointer > optlen {
- return 0, nil
+ return nil
}
// The data area isn't full but there isn't room for a new entry.
@@ -1357,17 +1406,23 @@ func handleRecordRoute(rrOpt header.IPv4OptionRecordRoute, localAddress tcpip.Ad
// }
if (optlen-header.IPv4OptionRecordRouteHdrLength)%header.IPv4AddressSize != 0 {
// Length is bad, not on integral number of slots.
- return header.IPv4OptionLengthOffset, errIPv4RecordRouteOptInvalidLength
+ return &header.IPv4OptParameterProblem{
+ Pointer: header.IPv4OptionLengthOffset,
+ NeedICMP: true,
+ }
}
// If not length, the fault must be with the pointer.
}
- return header.IPv4OptRRPointerOffset, errIPv4RecordRouteOptInvalidPointer
+ return &header.IPv4OptParameterProblem{
+ Pointer: header.IPv4OptRRPointerOffset,
+ NeedICMP: true,
+ }
}
if usage.actions().recordRoute == optionVerify {
- return 0, nil
+ return nil
}
rrOpt.StoreAddress(localAddress)
- return 0, nil
+ return nil
}
// processIPOptions parses the IPv4 options and produces a new set of options
@@ -1378,8 +1433,8 @@ func handleRecordRoute(rrOpt header.IPv4OptionRecordRoute, localAddress tcpip.Ad
// - The location of an error if there was one (or 0 if no error)
// - If there is an error, information as to what it was was.
// - The replacement option set.
-func (e *endpoint) processIPOptions(pkt *stack.PacketBuffer, orig header.IPv4Options, usage optionsUsage) (uint8, header.IPv4Options, error) {
- stats := e.protocol.stack.Stats()
+func (e *endpoint) processIPOptions(pkt *stack.PacketBuffer, orig header.IPv4Options, usage optionsUsage) (header.IPv4Options, *header.IPv4OptParameterProblem) {
+ stats := e.stats.ip
opts := header.IPv4Options(orig)
optIter := opts.MakeIterator()
@@ -1392,21 +1447,23 @@ func (e *endpoint) processIPOptions(pkt *stack.PacketBuffer, orig header.IPv4Opt
// This will need tweaking when we start really forwarding packets
// as we may need to get two addresses, for rx and tx interfaces.
// We will also have to take usage into account.
- prefixedAddress, err := e.protocol.stack.GetMainNICAddress(e.nic.ID(), ProtocolNumber)
+ prefixedAddress, ok := e.protocol.stack.GetMainNICAddress(e.nic.ID(), ProtocolNumber)
localAddress := prefixedAddress.Address
- if err != nil {
+ if !ok {
h := header.IPv4(pkt.NetworkHeader().View())
dstAddr := h.DestinationAddress()
if pkt.NetworkPacketInfo.LocalAddressBroadcast || header.IsV4MulticastAddress(dstAddr) {
- return 0 /* errCursor */, nil, header.ErrIPv4OptionAddress
+ return nil, &header.IPv4OptParameterProblem{
+ NeedICMP: false,
+ }
}
localAddress = dstAddr
}
for {
- option, done, err := optIter.Next()
- if done || err != nil {
- return optIter.ErrCursor, optIter.Finalize(), err
+ option, done, optProblem := optIter.Next()
+ if done || optProblem != nil {
+ return optIter.Finalize(), optProblem
}
optType := option.Type()
if optType == header.IPv4OptionNOPType {
@@ -1415,44 +1472,47 @@ func (e *endpoint) processIPOptions(pkt *stack.PacketBuffer, orig header.IPv4Opt
}
if optType == header.IPv4OptionListEndType {
optIter.PushNOPOrEnd(optType)
- return 0 /* errCursor */, optIter.Finalize(), nil /* err */
+ return optIter.Finalize(), nil
}
// check for repeating options (multiple NOPs are OK)
if seenOptions[optType] {
- return optIter.ErrCursor, nil, header.ErrIPv4OptDuplicate
+ return nil, &header.IPv4OptParameterProblem{
+ Pointer: optIter.ErrCursor,
+ NeedICMP: true,
+ }
}
seenOptions[optType] = true
optLen := int(option.Size())
switch option := option.(type) {
case *header.IPv4OptionTimestamp:
- stats.IP.OptionTSReceived.Increment()
+ stats.OptionTSReceived.Increment()
if usage.actions().timestamp != optionRemove {
clock := e.protocol.stack.Clock()
newBuffer := optIter.RemainingBuffer()[:len(*option)]
_ = copy(newBuffer, option.Contents())
- offset, err := handleTimestamp(header.IPv4OptionTimestamp(newBuffer), localAddress, clock, usage)
- if err != nil {
- return optIter.ErrCursor + offset, nil, err
+ if optProblem := handleTimestamp(header.IPv4OptionTimestamp(newBuffer), localAddress, clock, usage); optProblem != nil {
+ optProblem.Pointer += optIter.ErrCursor
+ return nil, optProblem
}
optIter.ConsumeBuffer(optLen)
}
case *header.IPv4OptionRecordRoute:
- stats.IP.OptionRRReceived.Increment()
+ stats.OptionRRReceived.Increment()
if usage.actions().recordRoute != optionRemove {
newBuffer := optIter.RemainingBuffer()[:len(*option)]
_ = copy(newBuffer, option.Contents())
- offset, err := handleRecordRoute(header.IPv4OptionRecordRoute(newBuffer), localAddress, usage)
- if err != nil {
- return optIter.ErrCursor + offset, nil, err
+ if optProblem := handleRecordRoute(header.IPv4OptionRecordRoute(newBuffer), localAddress, usage); optProblem != nil {
+ optProblem.Pointer += optIter.ErrCursor
+ return nil, optProblem
}
optIter.ConsumeBuffer(optLen)
}
default:
- stats.IP.OptionUnknownReceived.Increment()
+ stats.OptionUnknownReceived.Increment()
if usage.actions().unknown == optionPass {
newBuffer := optIter.RemainingBuffer()[:optLen]
// Arguments already heavily checked.. ignore result.
diff --git a/pkg/tcpip/network/ipv4/ipv4_test.go b/pkg/tcpip/network/ipv4/ipv4_test.go
index a9e137c24..ed5899f0b 100644
--- a/pkg/tcpip/network/ipv4/ipv4_test.go
+++ b/pkg/tcpip/network/ipv4/ipv4_test.go
@@ -1,4 +1,4 @@
-// Copyright 2018 The gVisor Authors.
+// Copyright 2021 The gVisor Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -78,8 +78,11 @@ func TestExcludeBroadcast(t *testing.T) {
defer ep.Close()
// Cannot connect using a broadcast address as the source.
- if err := ep.Connect(randomAddr); err != tcpip.ErrNoRoute {
- t.Errorf("got ep.Connect(...) = %v, want = %v", err, tcpip.ErrNoRoute)
+ {
+ err := ep.Connect(randomAddr)
+ if _, ok := err.(*tcpip.ErrNoRoute); !ok {
+ t.Errorf("got ep.Connect(...) = %v, want = %v", err, &tcpip.ErrNoRoute{})
+ }
}
// However, we can bind to a broadcast address to listen.
@@ -270,6 +273,11 @@ func TestIPv4Sanity(t *testing.T) {
nicID = 1
randomSequence = 123
randomIdent = 42
+ // In some cases Linux sets the error pointer to the start of the option
+ // (offset 0) instead of the actual wrong value, which is the length byte
+ // (offset 1). For compatibility we must do the same. Use this constant
+ // to indicate where this happens.
+ pointerOffsetForInvalidLength = 0
)
var (
ipv4Addr = tcpip.AddressWithPrefix{
@@ -439,6 +447,21 @@ func TestIPv4Sanity(t *testing.T) {
replyOptions: header.IPv4Options{},
},
{
+ name: "bad option - no length",
+ maxTotalLength: ipv4.MaxTotalSize,
+ transportProtocol: uint8(header.ICMPv4ProtocolNumber),
+ TTL: ttl,
+ options: header.IPv4Options{
+ 1, 1, 1, 68,
+ // ^-start of timestamp.. but no length..
+ },
+ shouldFail: true,
+ expectErrorICMP: true,
+ ICMPType: header.ICMPv4ParamProblem,
+ ICMPCode: header.ICMPv4UnusedCode,
+ paramProblemPointer: header.IPv4MinimumSize + 3,
+ },
+ {
name: "bad option - length 0",
maxTotalLength: ipv4.MaxTotalSize,
transportProtocol: uint8(header.ICMPv4ProtocolNumber),
@@ -448,7 +471,27 @@ func TestIPv4Sanity(t *testing.T) {
// ^
1, 2, 3, 4,
},
- shouldFail: true,
+ shouldFail: true,
+ expectErrorICMP: true,
+ ICMPType: header.ICMPv4ParamProblem,
+ ICMPCode: header.ICMPv4UnusedCode,
+ paramProblemPointer: header.IPv4MinimumSize + pointerOffsetForInvalidLength,
+ },
+ {
+ name: "bad option - length 1",
+ maxTotalLength: ipv4.MaxTotalSize,
+ transportProtocol: uint8(header.ICMPv4ProtocolNumber),
+ TTL: ttl,
+ options: header.IPv4Options{
+ 68, 1, 9, 0,
+ // ^
+ 1, 2, 3, 4,
+ },
+ shouldFail: true,
+ expectErrorICMP: true,
+ ICMPType: header.ICMPv4ParamProblem,
+ ICMPCode: header.ICMPv4UnusedCode,
+ paramProblemPointer: header.IPv4MinimumSize + pointerOffsetForInvalidLength,
},
{
name: "bad option - length big",
@@ -462,7 +505,11 @@ func TestIPv4Sanity(t *testing.T) {
// space is not possible. (Second byte)
1, 2, 3, 4,
},
- shouldFail: true,
+ shouldFail: true,
+ expectErrorICMP: true,
+ ICMPType: header.ICMPv4ParamProblem,
+ ICMPCode: header.ICMPv4UnusedCode,
+ paramProblemPointer: header.IPv4MinimumSize + pointerOffsetForInvalidLength,
},
{
// This tests for some linux compatible behaviour.
@@ -484,7 +531,7 @@ func TestIPv4Sanity(t *testing.T) {
expectErrorICMP: true,
ICMPType: header.ICMPv4ParamProblem,
ICMPCode: header.ICMPv4UnusedCode,
- paramProblemPointer: header.IPv4MinimumSize + 2,
+ paramProblemPointer: header.IPv4MinimumSize + header.IPv4OptTSPointerOffset,
},
{
name: "multiple type 0 with room",
@@ -589,7 +636,7 @@ func TestIPv4Sanity(t *testing.T) {
expectErrorICMP: true,
ICMPType: header.ICMPv4ParamProblem,
ICMPCode: header.ICMPv4UnusedCode,
- paramProblemPointer: header.IPv4MinimumSize + 2,
+ paramProblemPointer: header.IPv4MinimumSize + header.IPv4OptTSPointerOffset,
},
{
name: "valid timestamp pointer",
@@ -624,7 +671,7 @@ func TestIPv4Sanity(t *testing.T) {
expectErrorICMP: true,
ICMPType: header.ICMPv4ParamProblem,
ICMPCode: header.ICMPv4UnusedCode,
- paramProblemPointer: header.IPv4MinimumSize + 2,
+ paramProblemPointer: header.IPv4MinimumSize + header.IPv4OptTSPointerOffset,
},
// End of option list with illegal option after it, which should be ignored.
{
@@ -636,24 +683,31 @@ func TestIPv4Sanity(t *testing.T) {
68, 12, 13, 0x11,
192, 168, 1, 12,
1, 2, 3, 4,
- 0, 10, 3, 99,
+ 0, 10, 3, 99, // EOL followed by junk
},
replyOptions: header.IPv4Options{
68, 12, 13, 0x21,
192, 168, 1, 12,
1, 2, 3, 4,
- 0, 0, 0, 0, // 3 bytes unknown option
- }, // ^ End of options hides following bytes.
+ 0, // End of Options hides following bytes.
+ 0, 0, 0, // 3 bytes unknown option removed.
+ },
},
{
- // Timestamp with a size too small.
+ // Timestamp with a size much too small.
name: "timestamp truncated",
maxTotalLength: ipv4.MaxTotalSize,
transportProtocol: uint8(header.ICMPv4ProtocolNumber),
TTL: ttl,
- options: header.IPv4Options{68, 1, 0, 0},
- // ^ Smallest possible is 8.
- shouldFail: true,
+ options: header.IPv4Options{
+ 68, 1, 0, 0,
+ // ^ Smallest possible is 8. Linux points at the 68.
+ },
+ shouldFail: true,
+ expectErrorICMP: true,
+ ICMPType: header.ICMPv4ParamProblem,
+ ICMPCode: header.ICMPv4UnusedCode,
+ paramProblemPointer: header.IPv4MinimumSize + pointerOffsetForInvalidLength,
},
{
name: "single record route with room",
@@ -751,7 +805,7 @@ func TestIPv4Sanity(t *testing.T) {
expectErrorICMP: true,
ICMPType: header.ICMPv4ParamProblem,
ICMPCode: header.ICMPv4UnusedCode,
- paramProblemPointer: header.IPv4MinimumSize + 2,
+ paramProblemPointer: header.IPv4MinimumSize + header.IPv4OptRRPointerOffset,
},
{
// Pointer must be 4 or more as it must point past the 3 byte header
@@ -769,7 +823,7 @@ func TestIPv4Sanity(t *testing.T) {
expectErrorICMP: true,
ICMPType: header.ICMPv4ParamProblem,
ICMPCode: header.ICMPv4UnusedCode,
- paramProblemPointer: header.IPv4MinimumSize + 2,
+ paramProblemPointer: header.IPv4MinimumSize + header.IPv4OptRRPointerOffset,
},
{
// Pointer must be 4 or more as it must point past the 3 byte header
@@ -808,8 +862,7 @@ func TestIPv4Sanity(t *testing.T) {
expectErrorICMP: true,
ICMPType: header.ICMPv4ParamProblem,
ICMPCode: header.ICMPv4UnusedCode,
- paramProblemPointer: header.IPv4MinimumSize + 2,
- replyOptions: header.IPv4Options{},
+ paramProblemPointer: header.IPv4MinimumSize + header.IPv4OptRRPointerOffset,
},
{
name: "duplicate record route",
@@ -828,7 +881,6 @@ func TestIPv4Sanity(t *testing.T) {
ICMPType: header.ICMPv4ParamProblem,
ICMPCode: header.ICMPv4UnusedCode,
paramProblemPointer: header.IPv4MinimumSize + 7,
- replyOptions: header.IPv4Options{},
},
}
@@ -884,7 +936,6 @@ func TestIPv4Sanity(t *testing.T) {
if test.maxTotalLength < totalLen {
totalLen = test.maxTotalLength
}
-
ip.Encode(&header.IPv4Fields{
TotalLength: totalLen,
Protocol: test.transportProtocol,
@@ -1328,8 +1379,8 @@ func TestFragmentationErrors(t *testing.T) {
payloadSize int
allowPackets int
outgoingErrors int
- mockError *tcpip.Error
- wantError *tcpip.Error
+ mockError tcpip.Error
+ wantError tcpip.Error
}{
{
description: "No frag",
@@ -1338,8 +1389,8 @@ func TestFragmentationErrors(t *testing.T) {
transportHeaderLength: 0,
allowPackets: 0,
outgoingErrors: 1,
- mockError: tcpip.ErrAborted,
- wantError: tcpip.ErrAborted,
+ mockError: &tcpip.ErrAborted{},
+ wantError: &tcpip.ErrAborted{},
},
{
description: "Error on first frag",
@@ -1348,8 +1399,8 @@ func TestFragmentationErrors(t *testing.T) {
transportHeaderLength: 0,
allowPackets: 0,
outgoingErrors: 3,
- mockError: tcpip.ErrAborted,
- wantError: tcpip.ErrAborted,
+ mockError: &tcpip.ErrAborted{},
+ wantError: &tcpip.ErrAborted{},
},
{
description: "Error on second frag",
@@ -1358,8 +1409,8 @@ func TestFragmentationErrors(t *testing.T) {
transportHeaderLength: 0,
allowPackets: 1,
outgoingErrors: 2,
- mockError: tcpip.ErrAborted,
- wantError: tcpip.ErrAborted,
+ mockError: &tcpip.ErrAborted{},
+ wantError: &tcpip.ErrAborted{},
},
{
description: "Error on first frag MTU smaller than header",
@@ -1368,8 +1419,8 @@ func TestFragmentationErrors(t *testing.T) {
payloadSize: 500,
allowPackets: 0,
outgoingErrors: 4,
- mockError: tcpip.ErrAborted,
- wantError: tcpip.ErrAborted,
+ mockError: &tcpip.ErrAborted{},
+ wantError: &tcpip.ErrAborted{},
},
{
description: "Error when MTU is smaller than IPv4 minimum MTU",
@@ -1379,7 +1430,7 @@ func TestFragmentationErrors(t *testing.T) {
allowPackets: 0,
outgoingErrors: 1,
mockError: nil,
- wantError: tcpip.ErrInvalidEndpointState,
+ wantError: &tcpip.ErrInvalidEndpointState{},
},
}
@@ -1393,8 +1444,8 @@ func TestFragmentationErrors(t *testing.T) {
TTL: ttl,
TOS: stack.DefaultTOS,
}, pkt)
- if err != ft.wantError {
- t.Errorf("got WritePacket(_, _, _) = %s, want = %s", err, ft.wantError)
+ if diff := cmp.Diff(ft.wantError, err); diff != "" {
+ t.Fatalf("unexpected error from r.WritePacket(_, _, _), (-want, +got):\n%s", diff)
}
if got := int(r.Stats().IP.PacketsSent.Value()); got != ft.allowPackets {
t.Errorf("got r.Stats().IP.PacketsSent.Value() = %d, want = %d", got, ft.allowPackets)
@@ -2427,8 +2478,9 @@ func TestReceiveFragments(t *testing.T) {
}
}
- if res, err := ep.Read(ioutil.Discard, tcpip.ReadOptions{}); err != tcpip.ErrWouldBlock {
- t.Fatalf("(last) got Read = (%v, %v), want = (_, %s)", res, err, tcpip.ErrWouldBlock)
+ res, err := ep.Read(ioutil.Discard, tcpip.ReadOptions{})
+ if _, ok := err.(*tcpip.ErrWouldBlock); !ok {
+ t.Fatalf("(last) got Read = (%#v, %v), want = (_, %s)", res, err, &tcpip.ErrWouldBlock{})
}
})
}
@@ -2506,11 +2558,11 @@ func TestWriteStats(t *testing.T) {
// Parameterize the tests to run with both WritePacket and WritePackets.
writers := []struct {
name string
- writePackets func(*stack.Route, stack.PacketBufferList) (int, *tcpip.Error)
+ writePackets func(*stack.Route, stack.PacketBufferList) (int, tcpip.Error)
}{
{
name: "WritePacket",
- writePackets: func(rt *stack.Route, pkts stack.PacketBufferList) (int, *tcpip.Error) {
+ writePackets: func(rt *stack.Route, pkts stack.PacketBufferList) (int, tcpip.Error) {
nWritten := 0
for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() {
if err := rt.WritePacket(nil, stack.NetworkHeaderParams{}, pkt); err != nil {
@@ -2522,7 +2574,7 @@ func TestWriteStats(t *testing.T) {
},
}, {
name: "WritePackets",
- writePackets: func(rt *stack.Route, pkts stack.PacketBufferList) (int, *tcpip.Error) {
+ writePackets: func(rt *stack.Route, pkts stack.PacketBufferList) (int, tcpip.Error) {
return rt.WritePackets(nil, pkts, stack.NetworkHeaderParams{})
},
},
@@ -2532,7 +2584,7 @@ func TestWriteStats(t *testing.T) {
t.Run(writer.name, func(t *testing.T) {
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
- ep := testutil.NewMockLinkEndpoint(header.IPv4MinimumMTU, tcpip.ErrInvalidEndpointState, test.allowPackets)
+ ep := testutil.NewMockLinkEndpoint(header.IPv4MinimumMTU, &tcpip.ErrInvalidEndpointState{}, test.allowPackets)
rt := buildRoute(t, ep)
var pkts stack.PacketBufferList
@@ -2608,7 +2660,7 @@ func (*limitedMatcher) Name() string {
}
// Match implements Matcher.Match.
-func (lm *limitedMatcher) Match(stack.Hook, *stack.PacketBuffer, string) (bool, bool) {
+func (lm *limitedMatcher) Match(stack.Hook, *stack.PacketBuffer, string, string) (bool, bool) {
if lm.limit == 0 {
return true, false
}
diff --git a/pkg/tcpip/network/ipv4/stats.go b/pkg/tcpip/network/ipv4/stats.go
new file mode 100644
index 000000000..bee72c649
--- /dev/null
+++ b/pkg/tcpip/network/ipv4/stats.go
@@ -0,0 +1,190 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package ipv4
+
+import (
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/network/ip"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
+)
+
+var _ stack.IPNetworkEndpointStats = (*Stats)(nil)
+
+// Stats holds statistics related to the IPv4 protocol family.
+type Stats struct {
+ // IP holds IPv4 statistics.
+ IP tcpip.IPStats
+
+ // IGMP holds IGMP statistics.
+ IGMP tcpip.IGMPStats
+
+ // ICMP holds ICMPv4 statistics.
+ ICMP tcpip.ICMPv4Stats
+}
+
+// IsNetworkEndpointStats implements stack.NetworkEndpointStats.
+func (*Stats) IsNetworkEndpointStats() {}
+
+// IPStats implements stack.IPNetworkEndointStats
+func (s *Stats) IPStats() *tcpip.IPStats {
+ return &s.IP
+}
+
+type sharedStats struct {
+ localStats Stats
+ ip ip.MultiCounterIPStats
+ icmp multiCounterICMPv4Stats
+ igmp multiCounterIGMPStats
+}
+
+// LINT.IfChange(multiCounterICMPv4PacketStats)
+
+type multiCounterICMPv4PacketStats struct {
+ echo tcpip.MultiCounterStat
+ echoReply tcpip.MultiCounterStat
+ dstUnreachable tcpip.MultiCounterStat
+ srcQuench tcpip.MultiCounterStat
+ redirect tcpip.MultiCounterStat
+ timeExceeded tcpip.MultiCounterStat
+ paramProblem tcpip.MultiCounterStat
+ timestamp tcpip.MultiCounterStat
+ timestampReply tcpip.MultiCounterStat
+ infoRequest tcpip.MultiCounterStat
+ infoReply tcpip.MultiCounterStat
+}
+
+func (m *multiCounterICMPv4PacketStats) init(a, b *tcpip.ICMPv4PacketStats) {
+ m.echo.Init(a.Echo, b.Echo)
+ m.echoReply.Init(a.EchoReply, b.EchoReply)
+ m.dstUnreachable.Init(a.DstUnreachable, b.DstUnreachable)
+ m.srcQuench.Init(a.SrcQuench, b.SrcQuench)
+ m.redirect.Init(a.Redirect, b.Redirect)
+ m.timeExceeded.Init(a.TimeExceeded, b.TimeExceeded)
+ m.paramProblem.Init(a.ParamProblem, b.ParamProblem)
+ m.timestamp.Init(a.Timestamp, b.Timestamp)
+ m.timestampReply.Init(a.TimestampReply, b.TimestampReply)
+ m.infoRequest.Init(a.InfoRequest, b.InfoRequest)
+ m.infoReply.Init(a.InfoReply, b.InfoReply)
+}
+
+// LINT.ThenChange(../../tcpip.go:ICMPv4PacketStats)
+
+// LINT.IfChange(multiCounterICMPv4SentPacketStats)
+
+type multiCounterICMPv4SentPacketStats struct {
+ multiCounterICMPv4PacketStats
+ dropped tcpip.MultiCounterStat
+ rateLimited tcpip.MultiCounterStat
+}
+
+func (m *multiCounterICMPv4SentPacketStats) init(a, b *tcpip.ICMPv4SentPacketStats) {
+ m.multiCounterICMPv4PacketStats.init(&a.ICMPv4PacketStats, &b.ICMPv4PacketStats)
+ m.dropped.Init(a.Dropped, b.Dropped)
+ m.rateLimited.Init(a.RateLimited, b.RateLimited)
+}
+
+// LINT.ThenChange(../../tcpip.go:ICMPv4SentPacketStats)
+
+// LINT.IfChange(multiCounterICMPv4ReceivedPacketStats)
+
+type multiCounterICMPv4ReceivedPacketStats struct {
+ multiCounterICMPv4PacketStats
+ invalid tcpip.MultiCounterStat
+}
+
+func (m *multiCounterICMPv4ReceivedPacketStats) init(a, b *tcpip.ICMPv4ReceivedPacketStats) {
+ m.multiCounterICMPv4PacketStats.init(&a.ICMPv4PacketStats, &b.ICMPv4PacketStats)
+ m.invalid.Init(a.Invalid, b.Invalid)
+}
+
+// LINT.ThenChange(../../tcpip.go:ICMPv4ReceivedPacketStats)
+
+// LINT.IfChange(multiCounterICMPv4Stats)
+
+type multiCounterICMPv4Stats struct {
+ packetsSent multiCounterICMPv4SentPacketStats
+ packetsReceived multiCounterICMPv4ReceivedPacketStats
+}
+
+func (m *multiCounterICMPv4Stats) init(a, b *tcpip.ICMPv4Stats) {
+ m.packetsSent.init(&a.PacketsSent, &b.PacketsSent)
+ m.packetsReceived.init(&a.PacketsReceived, &b.PacketsReceived)
+}
+
+// LINT.ThenChange(../../tcpip.go:ICMPv4Stats)
+
+// LINT.IfChange(multiCounterIGMPPacketStats)
+
+type multiCounterIGMPPacketStats struct {
+ membershipQuery tcpip.MultiCounterStat
+ v1MembershipReport tcpip.MultiCounterStat
+ v2MembershipReport tcpip.MultiCounterStat
+ leaveGroup tcpip.MultiCounterStat
+}
+
+func (m *multiCounterIGMPPacketStats) init(a, b *tcpip.IGMPPacketStats) {
+ m.membershipQuery.Init(a.MembershipQuery, b.MembershipQuery)
+ m.v1MembershipReport.Init(a.V1MembershipReport, b.V1MembershipReport)
+ m.v2MembershipReport.Init(a.V2MembershipReport, b.V2MembershipReport)
+ m.leaveGroup.Init(a.LeaveGroup, b.LeaveGroup)
+}
+
+// LINT.ThenChange(../../tcpip.go:IGMPPacketStats)
+
+// LINT.IfChange(multiCounterIGMPSentPacketStats)
+
+type multiCounterIGMPSentPacketStats struct {
+ multiCounterIGMPPacketStats
+ dropped tcpip.MultiCounterStat
+}
+
+func (m *multiCounterIGMPSentPacketStats) init(a, b *tcpip.IGMPSentPacketStats) {
+ m.multiCounterIGMPPacketStats.init(&a.IGMPPacketStats, &b.IGMPPacketStats)
+ m.dropped.Init(a.Dropped, b.Dropped)
+}
+
+// LINT.ThenChange(../../tcpip.go:IGMPSentPacketStats)
+
+// LINT.IfChange(multiCounterIGMPReceivedPacketStats)
+
+type multiCounterIGMPReceivedPacketStats struct {
+ multiCounterIGMPPacketStats
+ invalid tcpip.MultiCounterStat
+ checksumErrors tcpip.MultiCounterStat
+ unrecognized tcpip.MultiCounterStat
+}
+
+func (m *multiCounterIGMPReceivedPacketStats) init(a, b *tcpip.IGMPReceivedPacketStats) {
+ m.multiCounterIGMPPacketStats.init(&a.IGMPPacketStats, &b.IGMPPacketStats)
+ m.invalid.Init(a.Invalid, b.Invalid)
+ m.checksumErrors.Init(a.ChecksumErrors, b.ChecksumErrors)
+ m.unrecognized.Init(a.Unrecognized, b.Unrecognized)
+}
+
+// LINT.ThenChange(../../tcpip.go:IGMPReceivedPacketStats)
+
+// LINT.IfChange(multiCounterIGMPStats)
+
+type multiCounterIGMPStats struct {
+ packetsSent multiCounterIGMPSentPacketStats
+ packetsReceived multiCounterIGMPReceivedPacketStats
+}
+
+func (m *multiCounterIGMPStats) init(a, b *tcpip.IGMPStats) {
+ m.packetsSent.init(&a.PacketsSent, &b.PacketsSent)
+ m.packetsReceived.init(&a.PacketsReceived, &b.PacketsReceived)
+}
+
+// LINT.ThenChange(../../tcpip.go:IGMPStats)
diff --git a/pkg/tcpip/network/ipv4/stats_test.go b/pkg/tcpip/network/ipv4/stats_test.go
new file mode 100644
index 000000000..b28e7dcde
--- /dev/null
+++ b/pkg/tcpip/network/ipv4/stats_test.go
@@ -0,0 +1,99 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package ipv4
+
+import (
+ "reflect"
+ "testing"
+
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/network/testutil"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
+)
+
+var _ stack.NetworkInterface = (*testInterface)(nil)
+
+type testInterface struct {
+ stack.NetworkInterface
+ nicID tcpip.NICID
+}
+
+func (t *testInterface) ID() tcpip.NICID {
+ return t.nicID
+}
+
+func knownNICIDs(proto *protocol) []tcpip.NICID {
+ var nicIDs []tcpip.NICID
+
+ for k := range proto.mu.eps {
+ nicIDs = append(nicIDs, k)
+ }
+
+ return nicIDs
+}
+
+func TestClearEndpointFromProtocolOnClose(t *testing.T) {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol},
+ })
+ proto := s.NetworkProtocolInstance(ProtocolNumber).(*protocol)
+ nic := testInterface{nicID: 1}
+ ep := proto.NewEndpoint(&nic, nil, nil, nil).(*endpoint)
+ var nicIDs []tcpip.NICID
+
+ proto.mu.Lock()
+ foundEP, hasEndpointBeforeClose := proto.mu.eps[nic.ID()]
+ nicIDs = knownNICIDs(proto)
+ proto.mu.Unlock()
+
+ if !hasEndpointBeforeClose {
+ t.Fatalf("expected to find the nic id %d in the protocol's endpoint map (%v)", nic.ID(), nicIDs)
+ }
+ if foundEP != ep {
+ t.Fatalf("found an incorrect endpoint mapped to nic id %d", nic.ID())
+ }
+
+ ep.Close()
+
+ proto.mu.Lock()
+ _, hasEP := proto.mu.eps[nic.ID()]
+ nicIDs = knownNICIDs(proto)
+ proto.mu.Unlock()
+ if hasEP {
+ t.Fatalf("unexpectedly found an endpoint mapped to the nic id %d in the protocol's known nic ids (%v)", nic.ID(), nicIDs)
+ }
+}
+
+func TestMultiCounterStatsInitialization(t *testing.T) {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol},
+ })
+ proto := s.NetworkProtocolInstance(ProtocolNumber).(*protocol)
+ var nic testInterface
+ ep := proto.NewEndpoint(&nic, nil, nil, nil).(*endpoint)
+ // At this point, the Stack's stats and the NetworkEndpoint's stats are
+ // expected to be bound by a MultiCounterStat.
+ refStack := s.Stats()
+ refEP := ep.stats.localStats
+ if err := testutil.ValidateMultiCounterStats(reflect.ValueOf(&ep.stats.ip).Elem(), []reflect.Value{reflect.ValueOf(&refEP.IP).Elem(), reflect.ValueOf(&refStack.IP).Elem()}); err != nil {
+ t.Error(err)
+ }
+ if err := testutil.ValidateMultiCounterStats(reflect.ValueOf(&ep.stats.icmp).Elem(), []reflect.Value{reflect.ValueOf(&refEP.ICMP).Elem(), reflect.ValueOf(&refStack.ICMP.V4).Elem()}); err != nil {
+ t.Error(err)
+ }
+ if err := testutil.ValidateMultiCounterStats(reflect.ValueOf(&ep.stats.igmp).Elem(), []reflect.Value{reflect.ValueOf(&refEP.IGMP).Elem(), reflect.ValueOf(&refStack.IGMP).Elem()}); err != nil {
+ t.Error(err)
+ }
+}
diff --git a/pkg/tcpip/network/ipv6/BUILD b/pkg/tcpip/network/ipv6/BUILD
index afa45aefe..0c5f8d683 100644
--- a/pkg/tcpip/network/ipv6/BUILD
+++ b/pkg/tcpip/network/ipv6/BUILD
@@ -10,6 +10,7 @@ go_library(
"ipv6.go",
"mld.go",
"ndp.go",
+ "stats.go",
],
visibility = ["//visibility:public"],
deps = [
diff --git a/pkg/tcpip/network/ipv6/icmp.go b/pkg/tcpip/network/ipv6/icmp.go
index 6ee162713..7298bd061 100644
--- a/pkg/tcpip/network/ipv6/icmp.go
+++ b/pkg/tcpip/network/ipv6/icmp.go
@@ -1,4 +1,4 @@
-// Copyright 2018 The gVisor Authors.
+// Copyright 2021 The gVisor Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -125,15 +125,14 @@ func getTargetLinkAddr(it header.NDPOptionIterator) (tcpip.LinkAddress, bool) {
}
func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool) {
- stats := e.protocol.stack.Stats().ICMP
- sent := stats.V6.PacketsSent
- received := stats.V6.PacketsReceived
+ sent := e.stats.icmp.packetsSent
+ received := e.stats.icmp.packetsReceived
// TODO(gvisor.dev/issue/170): ICMP packets don't have their
// TransportHeader fields set. See icmp/protocol.go:protocol.Parse for a
// full explanation.
v, ok := pkt.Data.PullUp(header.ICMPv6HeaderSize)
if !ok {
- received.Invalid.Increment()
+ received.invalid.Increment()
return
}
h := header.ICMPv6(v)
@@ -147,7 +146,7 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool) {
payload := pkt.Data.Clone(nil)
payload.TrimFront(len(h))
if got, want := h.Checksum(), header.ICMPv6Checksum(h, srcAddr, dstAddr, payload); got != want {
- received.Invalid.Increment()
+ received.invalid.Increment()
return
}
@@ -165,10 +164,10 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool) {
// TODO(b/112892170): Meaningfully handle all ICMP types.
switch icmpType := h.Type(); icmpType {
case header.ICMPv6PacketTooBig:
- received.PacketTooBig.Increment()
+ received.packetTooBig.Increment()
hdr, ok := pkt.Data.PullUp(header.ICMPv6PacketTooBigMinimumSize)
if !ok {
- received.Invalid.Increment()
+ received.invalid.Increment()
return
}
pkt.Data.TrimFront(header.ICMPv6PacketTooBigMinimumSize)
@@ -179,10 +178,10 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool) {
e.handleControl(stack.ControlPacketTooBig, networkMTU, pkt)
case header.ICMPv6DstUnreachable:
- received.DstUnreachable.Increment()
+ received.dstUnreachable.Increment()
hdr, ok := pkt.Data.PullUp(header.ICMPv6DstUnreachableMinimumSize)
if !ok {
- received.Invalid.Increment()
+ received.invalid.Increment()
return
}
pkt.Data.TrimFront(header.ICMPv6DstUnreachableMinimumSize)
@@ -194,9 +193,9 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool) {
}
case header.ICMPv6NeighborSolicit:
- received.NeighborSolicit.Increment()
+ received.neighborSolicit.Increment()
if !isNDPValid() || pkt.Data.Size() < header.ICMPv6NeighborSolicitMinimumSize {
- received.Invalid.Increment()
+ received.invalid.Increment()
return
}
@@ -210,7 +209,7 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool) {
// As per RFC 4861 section 4.3, the Target Address MUST NOT be a multicast
// address.
if header.IsV6MulticastAddress(targetAddr) {
- received.Invalid.Increment()
+ received.invalid.Increment()
return
}
@@ -238,7 +237,9 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool) {
//
// TODO(gvisor.dev/issue/4046): Handle the scenario when a duplicate
// address is detected for an assigned address.
- if err := e.dupTentativeAddrDetected(targetAddr); err != nil && err != tcpip.ErrBadAddress && err != tcpip.ErrInvalidEndpointState {
+ switch err := e.dupTentativeAddrDetected(targetAddr); err.(type) {
+ case nil, *tcpip.ErrBadAddress, *tcpip.ErrInvalidEndpointState:
+ default:
panic(fmt.Sprintf("unexpected error handling duplicate tentative address: %s", err))
}
}
@@ -263,13 +264,13 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool) {
if err != nil {
// Options are not valid as per the wire format, silently drop the
// packet.
- received.Invalid.Increment()
+ received.invalid.Increment()
return
}
sourceLinkAddr, ok = getSourceLinkAddr(it)
if !ok {
- received.Invalid.Increment()
+ received.invalid.Increment()
return
}
}
@@ -282,16 +283,16 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool) {
unspecifiedSource := srcAddr == header.IPv6Any
if len(sourceLinkAddr) == 0 {
if header.IsV6MulticastAddress(dstAddr) && !unspecifiedSource {
- received.Invalid.Increment()
+ received.invalid.Increment()
return
}
} else if unspecifiedSource {
- received.Invalid.Increment()
+ received.invalid.Increment()
return
} else if e.nud != nil {
e.nud.HandleProbe(srcAddr, header.IPv6ProtocolNumber, sourceLinkAddr, e.protocol)
} else {
- e.linkAddrCache.AddLinkAddress(e.nic.ID(), srcAddr, sourceLinkAddr)
+ e.linkAddrCache.AddLinkAddress(srcAddr, sourceLinkAddr)
}
// As per RFC 4861 section 7.1.1:
@@ -301,7 +302,7 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool) {
// - If the IP source address is the unspecified address, the IP
// destination address is a solicited-node multicast address.
if unspecifiedSource && !header.IsSolicitedNodeAddr(dstAddr) {
- received.Invalid.Increment()
+ received.invalid.Increment()
return
}
@@ -379,15 +380,15 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool) {
// The IP Hop Limit field has a value of 255, i.e., the packet
// could not possibly have been forwarded by a router.
if err := r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: header.NDPHopLimit, TOS: stack.DefaultTOS}, pkt); err != nil {
- sent.Dropped.Increment()
+ sent.dropped.Increment()
return
}
- sent.NeighborAdvert.Increment()
+ sent.neighborAdvert.Increment()
case header.ICMPv6NeighborAdvert:
- received.NeighborAdvert.Increment()
+ received.neighborAdvert.Increment()
if !isNDPValid() || pkt.Data.Size() < header.ICMPv6NeighborAdvertMinimumSize {
- received.Invalid.Increment()
+ received.invalid.Increment()
return
}
@@ -414,16 +415,18 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool) {
//
// TODO(gvisor.dev/issue/4046): Handle the scenario when a duplicate
// address is detected for an assigned address.
- if err := e.dupTentativeAddrDetected(targetAddr); err != nil && err != tcpip.ErrBadAddress && err != tcpip.ErrInvalidEndpointState {
+ switch err := e.dupTentativeAddrDetected(targetAddr); err.(type) {
+ case nil, *tcpip.ErrBadAddress, *tcpip.ErrInvalidEndpointState:
+ return
+ default:
panic(fmt.Sprintf("unexpected error handling duplicate tentative address: %s", err))
}
- return
}
it, err := na.Options().Iter(false /* check */)
if err != nil {
// If we have a malformed NDP NA option, drop the packet.
- received.Invalid.Increment()
+ received.invalid.Increment()
return
}
@@ -438,7 +441,18 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool) {
// DAD.
targetLinkAddr, ok := getTargetLinkAddr(it)
if !ok {
- received.Invalid.Increment()
+ received.invalid.Increment()
+ return
+ }
+
+ // As per RFC 4861 section 7.1.2:
+ // A node MUST silently discard any received Neighbor Advertisement
+ // messages that do not satisfy all of the following validity checks:
+ // ...
+ // - If the IP Destination Address is a multicast address the
+ // Solicited flag is zero.
+ if header.IsV6MulticastAddress(dstAddr) && na.SolicitedFlag() {
+ received.invalid.Increment()
return
}
@@ -446,7 +460,7 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool) {
// address cache with the link address for the target of the message.
if e.nud == nil {
if len(targetLinkAddr) != 0 {
- e.linkAddrCache.AddLinkAddress(e.nic.ID(), targetAddr, targetLinkAddr)
+ e.linkAddrCache.AddLinkAddress(targetAddr, targetLinkAddr)
}
return
}
@@ -458,10 +472,10 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool) {
})
case header.ICMPv6EchoRequest:
- received.EchoRequest.Increment()
+ received.echoRequest.Increment()
icmpHdr, ok := pkt.TransportHeader().Consume(header.ICMPv6EchoMinimumSize)
if !ok {
- received.Invalid.Increment()
+ received.invalid.Increment()
return
}
@@ -493,27 +507,27 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool) {
TTL: r.DefaultTTL(),
TOS: stack.DefaultTOS,
}, replyPkt); err != nil {
- sent.Dropped.Increment()
+ sent.dropped.Increment()
return
}
- sent.EchoReply.Increment()
+ sent.echoReply.Increment()
case header.ICMPv6EchoReply:
- received.EchoReply.Increment()
+ received.echoReply.Increment()
if pkt.Data.Size() < header.ICMPv6EchoMinimumSize {
- received.Invalid.Increment()
+ received.invalid.Increment()
return
}
e.dispatcher.DeliverTransportPacket(header.ICMPv6ProtocolNumber, pkt)
case header.ICMPv6TimeExceeded:
- received.TimeExceeded.Increment()
+ received.timeExceeded.Increment()
case header.ICMPv6ParamProblem:
- received.ParamProblem.Increment()
+ received.paramProblem.Increment()
case header.ICMPv6RouterSolicit:
- received.RouterSolicit.Increment()
+ received.routerSolicit.Increment()
//
// Validate the RS as per RFC 4861 section 6.1.1.
@@ -521,7 +535,7 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool) {
// Is the NDP payload of sufficient size to hold a Router Solictation?
if !isNDPValid() || pkt.Data.Size()-header.ICMPv6HeaderSize < header.NDPRSMinimumSize {
- received.Invalid.Increment()
+ received.invalid.Increment()
return
}
@@ -530,7 +544,7 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool) {
// Is the networking stack operating as a router?
if !stack.Forwarding(ProtocolNumber) {
// ... No, silently drop the packet.
- received.RouterOnlyPacketsDroppedByHost.Increment()
+ received.routerOnlyPacketsDroppedByHost.Increment()
return
}
@@ -540,13 +554,13 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool) {
it, err := rs.Options().Iter(false /* check */)
if err != nil {
// Options are not valid as per the wire format, silently drop the packet.
- received.Invalid.Increment()
+ received.invalid.Increment()
return
}
sourceLinkAddr, ok := getSourceLinkAddr(it)
if !ok {
- received.Invalid.Increment()
+ received.invalid.Increment()
return
}
@@ -557,7 +571,7 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool) {
// NOT be included when the source IP address is the unspecified address.
// Otherwise, it SHOULD be included on link layers that have addresses.
if srcAddr == header.IPv6Any {
- received.Invalid.Increment()
+ received.invalid.Increment()
return
}
@@ -569,7 +583,7 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool) {
}
case header.ICMPv6RouterAdvert:
- received.RouterAdvert.Increment()
+ received.routerAdvert.Increment()
//
// Validate the RA as per RFC 4861 section 6.1.2.
@@ -577,7 +591,7 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool) {
// Is the NDP payload of sufficient size to hold a Router Advertisement?
if !isNDPValid() || pkt.Data.Size()-header.ICMPv6HeaderSize < header.NDPRAMinimumSize {
- received.Invalid.Increment()
+ received.invalid.Increment()
return
}
@@ -586,7 +600,7 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool) {
// Is the IP Source Address a link-local address?
if !header.IsV6LinkLocalAddress(routerAddr) {
// ...No, silently drop the packet.
- received.Invalid.Increment()
+ received.invalid.Increment()
return
}
@@ -596,13 +610,13 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool) {
it, err := ra.Options().Iter(false /* check */)
if err != nil {
// Options are not valid as per the wire format, silently drop the packet.
- received.Invalid.Increment()
+ received.invalid.Increment()
return
}
sourceLinkAddr, ok := getSourceLinkAddr(it)
if !ok {
- received.Invalid.Increment()
+ received.invalid.Increment()
return
}
@@ -638,26 +652,26 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool) {
// link-layer address be modified due to receiving one of the above
// messages, the state SHOULD also be set to STALE to provide prompt
// verification that the path to the new link-layer address is working."
- received.RedirectMsg.Increment()
+ received.redirectMsg.Increment()
if !isNDPValid() {
- received.Invalid.Increment()
+ received.invalid.Increment()
return
}
case header.ICMPv6MulticastListenerQuery, header.ICMPv6MulticastListenerReport, header.ICMPv6MulticastListenerDone:
switch icmpType {
case header.ICMPv6MulticastListenerQuery:
- received.MulticastListenerQuery.Increment()
+ received.multicastListenerQuery.Increment()
case header.ICMPv6MulticastListenerReport:
- received.MulticastListenerReport.Increment()
+ received.multicastListenerReport.Increment()
case header.ICMPv6MulticastListenerDone:
- received.MulticastListenerDone.Increment()
+ received.multicastListenerDone.Increment()
default:
panic(fmt.Sprintf("unrecognized MLD message = %d", icmpType))
}
if pkt.Data.Size()-header.ICMPv6HeaderSize < header.MLDMinimumSize {
- received.Invalid.Increment()
+ received.invalid.Increment()
return
}
@@ -676,7 +690,7 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool) {
}
default:
- received.Unrecognized.Increment()
+ received.unrecognized.Increment()
}
}
@@ -688,26 +702,39 @@ func (*protocol) LinkAddressProtocol() tcpip.NetworkProtocolNumber {
}
// LinkAddressRequest implements stack.LinkAddressResolver.
-func (p *protocol) LinkAddressRequest(targetAddr, localAddr tcpip.Address, remoteLinkAddr tcpip.LinkAddress, nic stack.NetworkInterface) *tcpip.Error {
+func (p *protocol) LinkAddressRequest(targetAddr, localAddr tcpip.Address, remoteLinkAddr tcpip.LinkAddress, nic stack.NetworkInterface) tcpip.Error {
+ nicID := nic.ID()
+
+ p.mu.Lock()
+ netEP, ok := p.mu.eps[nicID]
+ p.mu.Unlock()
+ if !ok {
+ return &tcpip.ErrNotConnected{}
+ }
+
remoteAddr := targetAddr
if len(remoteLinkAddr) == 0 {
remoteAddr = header.SolicitedNodeAddr(targetAddr)
remoteLinkAddr = header.EthernetAddressFromMulticastIPv6Address(remoteAddr)
}
- r, err := p.stack.FindRoute(nic.ID(), localAddr, remoteAddr, ProtocolNumber, false /* multicastLoop */)
- if err != nil {
- return err
+ if len(localAddr) == 0 {
+ addressEndpoint := netEP.AcquireOutgoingPrimaryAddress(remoteAddr, false /* allowExpired */)
+ if addressEndpoint == nil {
+ return &tcpip.ErrNetworkUnreachable{}
+ }
+
+ localAddr = addressEndpoint.AddressWithPrefix().Address
+ } else if p.stack.CheckLocalAddress(nicID, ProtocolNumber, localAddr) == 0 {
+ return &tcpip.ErrBadLocalAddress{}
}
- defer r.Release()
- r.ResolveWith(remoteLinkAddr)
optsSerializer := header.NDPOptionsSerializer{
header.NDPSourceLinkLayerAddressOption(nic.LinkAddress()),
}
neighborSolicitSize := header.ICMPv6NeighborSolicitMinimumSize + optsSerializer.Length()
pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
- ReserveHeaderBytes: int(r.MaxHeaderLength()) + neighborSolicitSize,
+ ReserveHeaderBytes: int(nic.MaxHeaderLength()) + header.IPv6FixedHeaderSize + neighborSolicitSize,
})
pkt.TransportProtocolNumber = header.ICMPv6ProtocolNumber
packet := header.ICMPv6(pkt.TransportHeader().Push(neighborSolicitSize))
@@ -715,18 +742,23 @@ func (p *protocol) LinkAddressRequest(targetAddr, localAddr tcpip.Address, remot
ns := header.NDPNeighborSolicit(packet.MessageBody())
ns.SetTargetAddress(targetAddr)
ns.Options().Serialize(optsSerializer)
- packet.SetChecksum(header.ICMPv6Checksum(packet, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{}))
+ packet.SetChecksum(header.ICMPv6Checksum(packet, localAddr, remoteAddr, buffer.VectorisedView{}))
- stat := p.stack.Stats().ICMP.V6.PacketsSent
- if err := r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{
+ if err := addIPHeader(localAddr, remoteAddr, pkt, stack.NetworkHeaderParams{
Protocol: header.ICMPv6ProtocolNumber,
TTL: header.NDPHopLimit,
- }, pkt); err != nil {
- stat.Dropped.Increment()
+ }, header.IPv6ExtHdrSerializer{}); err != nil {
+ panic(fmt.Sprintf("failed to add IP header: %s", err))
+ }
+
+ stat := netEP.stats.icmp.packetsSent
+
+ if err := nic.WritePacketToRemote(remoteLinkAddr, nil /* gso */, ProtocolNumber, pkt); err != nil {
+ stat.dropped.Increment()
return err
}
- stat.NeighborSolicit.Increment()
+ stat.neighborSolicit.Increment()
return nil
}
@@ -796,7 +828,7 @@ func (*icmpReasonReassemblyTimeout) isICMPReason() {}
// returnError takes an error descriptor and generates the appropriate ICMP
// error packet for IPv6 and sends it.
-func (p *protocol) returnError(reason icmpReason, pkt *stack.PacketBuffer) *tcpip.Error {
+func (p *protocol) returnError(reason icmpReason, pkt *stack.PacketBuffer) tcpip.Error {
origIPHdr := header.IPv6(pkt.NetworkHeader().View())
origIPHdrSrc := origIPHdr.SourceAddress()
origIPHdrDst := origIPHdr.DestinationAddress()
@@ -863,10 +895,17 @@ func (p *protocol) returnError(reason icmpReason, pkt *stack.PacketBuffer) *tcpi
}
defer route.Release()
- stats := p.stack.Stats().ICMP
- sent := stats.V6.PacketsSent
+ p.mu.Lock()
+ netEP, ok := p.mu.eps[pkt.NICID]
+ p.mu.Unlock()
+ if !ok {
+ return &tcpip.ErrNotConnected{}
+ }
+
+ sent := netEP.stats.icmp.packetsSent
+
if !p.stack.AllowICMPMessage() {
- sent.RateLimited.Increment()
+ sent.rateLimited.Increment()
return nil
}
@@ -897,11 +936,11 @@ func (p *protocol) returnError(reason icmpReason, pkt *stack.PacketBuffer) *tcpi
// the error message packet exceed the minimum IPv6 MTU
// [IPv6].
mtu := int(route.MTU())
- if mtu > header.IPv6MinimumMTU {
- mtu = header.IPv6MinimumMTU
+ const maxIPv6Data = header.IPv6MinimumMTU - header.IPv6FixedHeaderSize
+ if mtu > maxIPv6Data {
+ mtu = maxIPv6Data
}
- headerLen := int(route.MaxHeaderLength()) + header.ICMPv6ErrorHeaderSize
- available := int(mtu) - headerLen
+ available := mtu - header.ICMPv6ErrorHeaderSize
if available < header.IPv6MinimumSize {
return nil
}
@@ -915,31 +954,31 @@ func (p *protocol) returnError(reason icmpReason, pkt *stack.PacketBuffer) *tcpi
payload.CapLength(payloadLen)
newPkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
- ReserveHeaderBytes: headerLen,
+ ReserveHeaderBytes: int(route.MaxHeaderLength()) + header.ICMPv6ErrorHeaderSize,
Data: payload,
})
newPkt.TransportProtocolNumber = header.ICMPv6ProtocolNumber
icmpHdr := header.ICMPv6(newPkt.TransportHeader().Push(header.ICMPv6DstUnreachableMinimumSize))
- var counter *tcpip.StatCounter
+ var counter tcpip.MultiCounterStat
switch reason := reason.(type) {
case *icmpReasonParameterProblem:
icmpHdr.SetType(header.ICMPv6ParamProblem)
icmpHdr.SetCode(reason.code)
icmpHdr.SetTypeSpecific(reason.pointer)
- counter = sent.ParamProblem
+ counter = sent.paramProblem
case *icmpReasonPortUnreachable:
icmpHdr.SetType(header.ICMPv6DstUnreachable)
icmpHdr.SetCode(header.ICMPv6PortUnreachable)
- counter = sent.DstUnreachable
+ counter = sent.dstUnreachable
case *icmpReasonHopLimitExceeded:
icmpHdr.SetType(header.ICMPv6TimeExceeded)
icmpHdr.SetCode(header.ICMPv6HopLimitExceeded)
- counter = sent.TimeExceeded
+ counter = sent.timeExceeded
case *icmpReasonReassemblyTimeout:
icmpHdr.SetType(header.ICMPv6TimeExceeded)
icmpHdr.SetCode(header.ICMPv6ReassemblyTimeout)
- counter = sent.TimeExceeded
+ counter = sent.timeExceeded
default:
panic(fmt.Sprintf("unsupported ICMP type %T", reason))
}
@@ -953,7 +992,7 @@ func (p *protocol) returnError(reason icmpReason, pkt *stack.PacketBuffer) *tcpi
},
newPkt,
); err != nil {
- sent.Dropped.Increment()
+ sent.dropped.Increment()
return err
}
counter.Increment()
diff --git a/pkg/tcpip/network/ipv6/icmp_test.go b/pkg/tcpip/network/ipv6/icmp_test.go
index b1e6a70a2..db1c2e663 100644
--- a/pkg/tcpip/network/ipv6/icmp_test.go
+++ b/pkg/tcpip/network/ipv6/icmp_test.go
@@ -15,6 +15,7 @@
package ipv6
import (
+ "bytes"
"context"
"net"
"reflect"
@@ -22,6 +23,7 @@ import (
"testing"
"time"
+ "github.com/google/go-cmp/cmp"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/checker"
@@ -77,7 +79,7 @@ func (*stubLinkEndpoint) LinkAddress() tcpip.LinkAddress {
return ""
}
-func (*stubLinkEndpoint) WritePacket(stack.RouteInfo, *stack.GSO, tcpip.NetworkProtocolNumber, *stack.PacketBuffer) *tcpip.Error {
+func (*stubLinkEndpoint) WritePacket(stack.RouteInfo, *stack.GSO, tcpip.NetworkProtocolNumber, *stack.PacketBuffer) tcpip.Error {
return nil
}
@@ -91,16 +93,11 @@ func (*stubDispatcher) DeliverTransportPacket(tcpip.TransportProtocolNumber, *st
return stack.TransportPacketHandled
}
-type stubLinkAddressCache struct {
- stack.LinkAddressCache
-}
+var _ stack.LinkAddressCache = (*stubLinkAddressCache)(nil)
-func (*stubLinkAddressCache) CheckLocalAddress(tcpip.NICID, tcpip.NetworkProtocolNumber, tcpip.Address) tcpip.NICID {
- return 0
-}
+type stubLinkAddressCache struct{}
-func (*stubLinkAddressCache) AddLinkAddress(tcpip.NICID, tcpip.Address, tcpip.LinkAddress) {
-}
+func (*stubLinkAddressCache) AddLinkAddress(tcpip.Address, tcpip.LinkAddress) {}
type stubNUDHandler struct {
probeCount int
@@ -148,15 +145,15 @@ func (*testInterface) Promiscuous() bool {
return false
}
-func (t *testInterface) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error {
+func (t *testInterface) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error {
return t.LinkEndpoint.WritePacket(r.Fields(), gso, protocol, pkt)
}
-func (t *testInterface) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) {
+func (t *testInterface) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, tcpip.Error) {
return t.LinkEndpoint.WritePackets(r.Fields(), gso, pkts, protocol)
}
-func (t *testInterface) WritePacketToRemote(remoteLinkAddr tcpip.LinkAddress, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error {
+func (t *testInterface) WritePacketToRemote(remoteLinkAddr tcpip.LinkAddress, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error {
var r stack.RouteInfo
r.NetProto = protocol
r.RemoteLinkAddress = remoteLinkAddr
@@ -643,7 +640,6 @@ func TestLinkResolution(t *testing.T) {
pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6EchoMinimumSize))
pkt.SetType(header.ICMPv6EchoRequest)
pkt.SetChecksum(header.ICMPv6Checksum(pkt, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{}))
- payload := tcpip.SlicePayload(hdr.View())
// We can't send our payload directly over the route because that
// doesn't provoke NDP discovery.
@@ -653,8 +649,12 @@ func TestLinkResolution(t *testing.T) {
t.Fatalf("NewEndpoint(_) = (_, %s), want = (_, nil)", err)
}
- if _, err := ep.Write(payload, tcpip.WriteOptions{To: &tcpip.FullAddress{NIC: nicID, Addr: lladdr1}}); err != nil {
- t.Fatalf("ep.Write(_): %s", err)
+ {
+ var r bytes.Reader
+ r.Reset(hdr.View())
+ if _, err := ep.Write(&r, tcpip.WriteOptions{To: &tcpip.FullAddress{NIC: nicID, Addr: lladdr1}}); err != nil {
+ t.Fatalf("ep.Write(_): %s", err)
+ }
}
for _, args := range []routeArgs{
{src: c.linkEP0, dst: c.linkEP1, typ: header.ICMPv6NeighborSolicit, remoteLinkAddr: header.EthernetAddressFromMulticastIPv6Address(header.SolicitedNodeAddr(lladdr1))},
@@ -1283,7 +1283,7 @@ func TestLinkAddressRequest(t *testing.T) {
localAddr tcpip.Address
remoteLinkAddr tcpip.LinkAddress
- expectedErr *tcpip.Error
+ expectedErr tcpip.Error
expectedRemoteAddr tcpip.Address
expectedRemoteLinkAddr tcpip.LinkAddress
}{
@@ -1321,79 +1321,80 @@ func TestLinkAddressRequest(t *testing.T) {
name: "Unicast with unassigned address",
localAddr: lladdr1,
remoteLinkAddr: linkAddr1,
- expectedErr: tcpip.ErrNetworkUnreachable,
+ expectedErr: &tcpip.ErrBadLocalAddress{},
},
{
name: "Multicast with unassigned address",
localAddr: lladdr1,
remoteLinkAddr: "",
- expectedErr: tcpip.ErrNetworkUnreachable,
+ expectedErr: &tcpip.ErrBadLocalAddress{},
},
{
name: "Unicast with no local address available",
remoteLinkAddr: linkAddr1,
- expectedErr: tcpip.ErrNetworkUnreachable,
+ expectedErr: &tcpip.ErrNetworkUnreachable{},
},
{
name: "Multicast with no local address available",
remoteLinkAddr: "",
- expectedErr: tcpip.ErrNetworkUnreachable,
+ expectedErr: &tcpip.ErrNetworkUnreachable{},
},
}
for _, test := range tests {
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol},
- })
- p := s.NetworkProtocolInstance(ProtocolNumber)
- linkRes, ok := p.(stack.LinkAddressResolver)
- if !ok {
- t.Fatalf("expected IPv6 protocol to implement stack.LinkAddressResolver")
- }
+ t.Run(test.name, func(t *testing.T) {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol},
+ })
+ p := s.NetworkProtocolInstance(ProtocolNumber)
+ linkRes, ok := p.(stack.LinkAddressResolver)
+ if !ok {
+ t.Fatalf("expected IPv6 protocol to implement stack.LinkAddressResolver")
+ }
- linkEP := channel.New(defaultChannelSize, defaultMTU, linkAddr0)
- if err := s.CreateNIC(nicID, linkEP); err != nil {
- t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err)
- }
- if len(test.nicAddr) != 0 {
- if err := s.AddAddress(nicID, ProtocolNumber, test.nicAddr); err != nil {
- t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID, ProtocolNumber, test.nicAddr, err)
+ linkEP := channel.New(defaultChannelSize, defaultMTU, linkAddr0)
+ if err := s.CreateNIC(nicID, linkEP); err != nil {
+ t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err)
+ }
+ if len(test.nicAddr) != 0 {
+ if err := s.AddAddress(nicID, ProtocolNumber, test.nicAddr); err != nil {
+ t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID, ProtocolNumber, test.nicAddr, err)
+ }
}
- }
- // We pass a test network interface to LinkAddressRequest with the same NIC
- // ID and link endpoint used by the NIC we created earlier so that we can
- // mock a link address request and observe the packets sent to the link
- // endpoint even though the stack uses the real NIC.
- if err := linkRes.LinkAddressRequest(lladdr0, test.localAddr, test.remoteLinkAddr, &testInterface{LinkEndpoint: linkEP, nicID: nicID}); err != test.expectedErr {
- t.Errorf("got p.LinkAddressRequest(%s, %s, %s, _) = %s, want = %s", lladdr0, test.localAddr, test.remoteLinkAddr, err, test.expectedErr)
- }
+ // We pass a test network interface to LinkAddressRequest with the same NIC
+ // ID and link endpoint used by the NIC we created earlier so that we can
+ // mock a link address request and observe the packets sent to the link
+ // endpoint even though the stack uses the real NIC.
+ err := linkRes.LinkAddressRequest(lladdr0, test.localAddr, test.remoteLinkAddr, &testInterface{LinkEndpoint: linkEP, nicID: nicID})
+ if diff := cmp.Diff(test.expectedErr, err); diff != "" {
+ t.Fatalf("unexpected error from p.LinkAddressRequest(%s, %s, %s, _), (-want, +got):\n%s", lladdr0, test.localAddr, test.remoteLinkAddr, diff)
+ }
- if test.expectedErr != nil {
- return
- }
+ if test.expectedErr != nil {
+ return
+ }
- pkt, ok := linkEP.Read()
- if !ok {
- t.Fatal("expected to send a link address request")
- }
- if pkt.Route.RemoteLinkAddress != test.expectedRemoteLinkAddr {
- t.Errorf("got pkt.Route.RemoteLinkAddress = %s, want = %s", pkt.Route.RemoteLinkAddress, test.expectedRemoteLinkAddr)
- }
- if pkt.Route.RemoteAddress != test.expectedRemoteAddr {
- t.Errorf("got pkt.Route.RemoteAddress = %s, want = %s", pkt.Route.RemoteAddress, test.expectedRemoteAddr)
- }
- if pkt.Route.LocalAddress != lladdr1 {
- t.Errorf("got pkt.Route.LocalAddress = %s, want = %s", pkt.Route.LocalAddress, lladdr1)
- }
- checker.IPv6(t, stack.PayloadSince(pkt.Pkt.NetworkHeader()),
- checker.SrcAddr(lladdr1),
- checker.DstAddr(test.expectedRemoteAddr),
- checker.TTL(header.NDPHopLimit),
- checker.NDPNS(
- checker.NDPNSTargetAddress(lladdr0),
- checker.NDPNSOptions([]header.NDPOption{header.NDPSourceLinkLayerAddressOption(linkAddr0)}),
- ))
+ pkt, ok := linkEP.Read()
+ if !ok {
+ t.Fatal("expected to send a link address request")
+ }
+
+ var want stack.RouteInfo
+ want.NetProto = ProtocolNumber
+ want.RemoteLinkAddress = test.expectedRemoteLinkAddr
+ if diff := cmp.Diff(want, pkt.Route, cmp.AllowUnexported(want)); diff != "" {
+ t.Errorf("route info mismatch (-want +got):\n%s", diff)
+ }
+ checker.IPv6(t, stack.PayloadSince(pkt.Pkt.NetworkHeader()),
+ checker.SrcAddr(lladdr1),
+ checker.DstAddr(test.expectedRemoteAddr),
+ checker.TTL(header.NDPHopLimit),
+ checker.NDPNS(
+ checker.NDPNSTargetAddress(lladdr0),
+ checker.NDPNSOptions([]header.NDPOption{header.NDPSourceLinkLayerAddressOption(linkAddr0)}),
+ ))
+ })
}
}
diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go
index ae4a8f508..94043ed4e 100644
--- a/pkg/tcpip/network/ipv6/ipv6.go
+++ b/pkg/tcpip/network/ipv6/ipv6.go
@@ -20,6 +20,7 @@ import (
"fmt"
"hash/fnv"
"math"
+ "reflect"
"sort"
"sync/atomic"
"time"
@@ -177,6 +178,7 @@ type endpoint struct {
dispatcher stack.TransportDispatcher
protocol *protocol
stack *stack.Stack
+ stats sharedStats
// enabled is set to 1 when the endpoint is enabled and 0 when it is
// disabled.
@@ -305,17 +307,17 @@ func (e *endpoint) hasTentativeAddr(addr tcpip.Address) bool {
// dupTentativeAddrDetected removes the tentative address if it exists. If the
// address was generated via SLAAC, an attempt is made to generate a new
// address.
-func (e *endpoint) dupTentativeAddrDetected(addr tcpip.Address) *tcpip.Error {
+func (e *endpoint) dupTentativeAddrDetected(addr tcpip.Address) tcpip.Error {
e.mu.Lock()
defer e.mu.Unlock()
addressEndpoint := e.getAddressRLocked(addr)
if addressEndpoint == nil {
- return tcpip.ErrBadAddress
+ return &tcpip.ErrBadAddress{}
}
if addressEndpoint.GetKind() != stack.PermanentTentative {
- return tcpip.ErrInvalidEndpointState
+ return &tcpip.ErrInvalidEndpointState{}
}
// If the address is a SLAAC address, do not invalidate its SLAAC prefix as an
@@ -367,14 +369,14 @@ func (e *endpoint) transitionForwarding(forwarding bool) {
}
// Enable implements stack.NetworkEndpoint.
-func (e *endpoint) Enable() *tcpip.Error {
+func (e *endpoint) Enable() tcpip.Error {
e.mu.Lock()
defer e.mu.Unlock()
// If the NIC is not enabled, the endpoint can't do anything meaningful so
// don't enable the endpoint.
if !e.nic.Enabled() {
- return tcpip.ErrNotPermitted
+ return &tcpip.ErrNotPermitted{}
}
// If the endpoint is already enabled, there is nothing for it to do.
@@ -416,7 +418,7 @@ func (e *endpoint) Enable() *tcpip.Error {
//
// Addresses may have aleady completed DAD but in the time since the endpoint
// was last enabled, other devices may have acquired the same addresses.
- var err *tcpip.Error
+ var err tcpip.Error
e.mu.addressableEndpointState.ForEachEndpoint(func(addressEndpoint stack.AddressEndpoint) bool {
addr := addressEndpoint.AddressWithPrefix().Address
if !header.IsV6UnicastAddress(addr) {
@@ -497,7 +499,9 @@ func (e *endpoint) disableLocked() {
e.stopDADForPermanentAddressesLocked()
// The endpoint may have already left the multicast group.
- if err := e.leaveGroupLocked(header.IPv6AllNodesMulticastAddress); err != nil && err != tcpip.ErrBadLocalAddress {
+ switch err := e.leaveGroupLocked(header.IPv6AllNodesMulticastAddress); err.(type) {
+ case nil, *tcpip.ErrBadLocalAddress:
+ default:
panic(fmt.Sprintf("unexpected error when leaving group = %s: %s", header.IPv6AllNodesMulticastAddress, err))
}
@@ -553,11 +557,11 @@ func (e *endpoint) MaxHeaderLength() uint16 {
return e.nic.MaxHeaderLength() + header.IPv6MinimumSize
}
-func (e *endpoint) addIPHeader(srcAddr, dstAddr tcpip.Address, pkt *stack.PacketBuffer, params stack.NetworkHeaderParams, extensionHeaders header.IPv6ExtHdrSerializer) *tcpip.Error {
+func addIPHeader(srcAddr, dstAddr tcpip.Address, pkt *stack.PacketBuffer, params stack.NetworkHeaderParams, extensionHeaders header.IPv6ExtHdrSerializer) tcpip.Error {
extHdrsLen := extensionHeaders.Length()
length := pkt.Size() + extensionHeaders.Length()
if length > math.MaxUint16 {
- return tcpip.ErrMessageTooLong
+ return &tcpip.ErrMessageTooLong{}
}
ip := header.IPv6(pkt.NetworkHeader().Push(header.IPv6MinimumSize + extHdrsLen))
ip.Encode(&header.IPv6Fields{
@@ -583,7 +587,7 @@ func packetMustBeFragmented(pkt *stack.PacketBuffer, networkMTU uint32, gso *sta
// fragments left to be processed. The IP header must already be present in the
// original packet. The transport header protocol number is required to avoid
// parsing the IPv6 extension headers.
-func (e *endpoint) handleFragments(r *stack.Route, gso *stack.GSO, networkMTU uint32, pkt *stack.PacketBuffer, transProto tcpip.TransportProtocolNumber, handler func(*stack.PacketBuffer) *tcpip.Error) (int, int, *tcpip.Error) {
+func (e *endpoint) handleFragments(r *stack.Route, gso *stack.GSO, networkMTU uint32, pkt *stack.PacketBuffer, transProto tcpip.TransportProtocolNumber, handler func(*stack.PacketBuffer) tcpip.Error) (int, int, tcpip.Error) {
networkHeader := header.IPv6(pkt.NetworkHeader().View())
// TODO(gvisor.dev/issue/3912): Once the Authentication or ESP Headers are
@@ -596,13 +600,13 @@ func (e *endpoint) handleFragments(r *stack.Route, gso *stack.GSO, networkMTU ui
// of 8 as per RFC 8200 section 4.5:
// Each complete fragment, except possibly the last ("rightmost") one, is
// an integer multiple of 8 octets long.
- return 0, 1, tcpip.ErrMessageTooLong
+ return 0, 1, &tcpip.ErrMessageTooLong{}
}
if fragmentPayloadLen < uint32(pkt.TransportHeader().View().Size()) {
// As per RFC 8200 Section 4.5, the Transport Header is expected to be small
// enough to fit in the first fragment.
- return 0, 1, tcpip.ErrMessageTooLong
+ return 0, 1, &tcpip.ErrMessageTooLong{}
}
pf := fragmentation.MakePacketFragmenter(pkt, fragmentPayloadLen, calculateFragmentReserve(pkt))
@@ -622,17 +626,17 @@ func (e *endpoint) handleFragments(r *stack.Route, gso *stack.GSO, networkMTU ui
}
// WritePacket writes a packet to the given destination address and protocol.
-func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.NetworkHeaderParams, pkt *stack.PacketBuffer) *tcpip.Error {
- if err := e.addIPHeader(r.LocalAddress, r.RemoteAddress, pkt, params, nil /* extensionHeaders */); err != nil {
+func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.NetworkHeaderParams, pkt *stack.PacketBuffer) tcpip.Error {
+ if err := addIPHeader(r.LocalAddress, r.RemoteAddress, pkt, params, nil /* extensionHeaders */); err != nil {
return err
}
// iptables filtering. All packets that reach here are locally
// generated.
- nicName := e.protocol.stack.FindNICNameFromID(e.nic.ID())
- if ok := e.protocol.stack.IPTables().Check(stack.Output, pkt, gso, r, "", nicName); !ok {
+ outNicName := e.protocol.stack.FindNICNameFromID(e.nic.ID())
+ if ok := e.protocol.stack.IPTables().Check(stack.Output, pkt, gso, r, "" /* preroutingAddr */, "" /* inNicName */, outNicName); !ok {
// iptables is telling us to drop the packet.
- e.protocol.stack.Stats().IP.IPTablesOutputDropped.Increment()
+ e.stats.ip.IPTablesOutputDropped.Increment()
return nil
}
@@ -660,7 +664,7 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.Netw
return e.writePacket(r, gso, pkt, params.Protocol, false /* headerIncluded */)
}
-func (e *endpoint) writePacket(r *stack.Route, gso *stack.GSO, pkt *stack.PacketBuffer, protocol tcpip.TransportProtocolNumber, headerIncluded bool) *tcpip.Error {
+func (e *endpoint) writePacket(r *stack.Route, gso *stack.GSO, pkt *stack.PacketBuffer, protocol tcpip.TransportProtocolNumber, headerIncluded bool) tcpip.Error {
if r.Loop&stack.PacketLoop != 0 {
pkt := pkt.CloneToInbound()
if e.protocol.stack.ParsePacketBuffer(ProtocolNumber, pkt) == stack.ParsedOK {
@@ -675,36 +679,37 @@ func (e *endpoint) writePacket(r *stack.Route, gso *stack.GSO, pkt *stack.Packet
return nil
}
+ stats := e.stats.ip
networkMTU, err := calculateNetworkMTU(e.nic.MTU(), uint32(pkt.NetworkHeader().View().Size()))
if err != nil {
- r.Stats().IP.OutgoingPacketErrors.Increment()
+ stats.OutgoingPacketErrors.Increment()
return err
}
if packetMustBeFragmented(pkt, networkMTU, gso) {
- sent, remain, err := e.handleFragments(r, gso, networkMTU, pkt, protocol, func(fragPkt *stack.PacketBuffer) *tcpip.Error {
+ sent, remain, err := e.handleFragments(r, gso, networkMTU, pkt, protocol, func(fragPkt *stack.PacketBuffer) tcpip.Error {
// TODO(gvisor.dev/issue/3884): Evaluate whether we want to send each
// fragment one by one using WritePacket() (current strategy) or if we
// want to create a PacketBufferList from the fragments and feed it to
// WritePackets(). It'll be faster but cost more memory.
return e.nic.WritePacket(r, gso, ProtocolNumber, fragPkt)
})
- r.Stats().IP.PacketsSent.IncrementBy(uint64(sent))
- r.Stats().IP.OutgoingPacketErrors.IncrementBy(uint64(remain))
+ stats.PacketsSent.IncrementBy(uint64(sent))
+ stats.OutgoingPacketErrors.IncrementBy(uint64(remain))
return err
}
if err := e.nic.WritePacket(r, gso, ProtocolNumber, pkt); err != nil {
- r.Stats().IP.OutgoingPacketErrors.Increment()
+ stats.OutgoingPacketErrors.Increment()
return err
}
- r.Stats().IP.PacketsSent.Increment()
+ stats.PacketsSent.Increment()
return nil
}
// WritePackets implements stack.NetworkEndpoint.WritePackets.
-func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.PacketBufferList, params stack.NetworkHeaderParams) (int, *tcpip.Error) {
+func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.PacketBufferList, params stack.NetworkHeaderParams) (int, tcpip.Error) {
if r.Loop&stack.PacketLoop != 0 {
panic("not implemented")
}
@@ -712,28 +717,29 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe
return pkts.Len(), nil
}
+ stats := e.stats.ip
linkMTU := e.nic.MTU()
for pb := pkts.Front(); pb != nil; pb = pb.Next() {
- if err := e.addIPHeader(r.LocalAddress, r.RemoteAddress, pb, params, nil /* extensionHeaders */); err != nil {
+ if err := addIPHeader(r.LocalAddress, r.RemoteAddress, pb, params, nil /* extensionHeaders */); err != nil {
return 0, err
}
networkMTU, err := calculateNetworkMTU(linkMTU, uint32(pb.NetworkHeader().View().Size()))
if err != nil {
- r.Stats().IP.OutgoingPacketErrors.IncrementBy(uint64(pkts.Len()))
+ stats.OutgoingPacketErrors.IncrementBy(uint64(pkts.Len()))
return 0, err
}
if packetMustBeFragmented(pb, networkMTU, gso) {
// Keep track of the packet that is about to be fragmented so it can be
// removed once the fragmentation is done.
originalPkt := pb
- if _, _, err := e.handleFragments(r, gso, networkMTU, pb, params.Protocol, func(fragPkt *stack.PacketBuffer) *tcpip.Error {
+ if _, _, err := e.handleFragments(r, gso, networkMTU, pb, params.Protocol, func(fragPkt *stack.PacketBuffer) tcpip.Error {
// Modify the packet list in place with the new fragments.
pkts.InsertAfter(pb, fragPkt)
pb = fragPkt
return nil
}); err != nil {
- r.Stats().IP.OutgoingPacketErrors.IncrementBy(uint64(pkts.Len()))
+ stats.OutgoingPacketErrors.IncrementBy(uint64(pkts.Len()))
return 0, err
}
// Remove the packet that was just fragmented and process the rest.
@@ -743,19 +749,19 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe
// iptables filtering. All packets that reach here are locally
// generated.
- nicName := e.protocol.stack.FindNICNameFromID(e.nic.ID())
- dropped, natPkts := e.protocol.stack.IPTables().CheckPackets(stack.Output, pkts, gso, r, nicName)
+ outNicName := e.protocol.stack.FindNICNameFromID(e.nic.ID())
+ dropped, natPkts := e.protocol.stack.IPTables().CheckPackets(stack.Output, pkts, gso, r, "" /* inNicName */, outNicName)
if len(dropped) == 0 && len(natPkts) == 0 {
// Fast path: If no packets are to be dropped then we can just invoke the
// faster WritePackets API directly.
n, err := e.nic.WritePackets(r, gso, pkts, ProtocolNumber)
- r.Stats().IP.PacketsSent.IncrementBy(uint64(n))
+ stats.PacketsSent.IncrementBy(uint64(n))
if err != nil {
- r.Stats().IP.OutgoingPacketErrors.IncrementBy(uint64(pkts.Len() - n))
+ stats.OutgoingPacketErrors.IncrementBy(uint64(pkts.Len() - n))
}
return n, err
}
- r.Stats().IP.IPTablesOutputDropped.IncrementBy(uint64(len(dropped)))
+ stats.IPTablesOutputDropped.IncrementBy(uint64(len(dropped)))
// Slow path as we are dropping some packets in the batch degrade to
// emitting one packet at a time.
@@ -779,8 +785,8 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe
}
}
if err := e.nic.WritePacket(r, gso, ProtocolNumber, pkt); err != nil {
- r.Stats().IP.PacketsSent.IncrementBy(uint64(n))
- r.Stats().IP.OutgoingPacketErrors.IncrementBy(uint64(pkts.Len() - n + len(dropped)))
+ stats.PacketsSent.IncrementBy(uint64(n))
+ stats.OutgoingPacketErrors.IncrementBy(uint64(pkts.Len() - n + len(dropped)))
// Dropped packets aren't errors, so include them in
// the return value.
return n + len(dropped), err
@@ -788,17 +794,17 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe
n++
}
- r.Stats().IP.PacketsSent.IncrementBy(uint64(n))
+ stats.PacketsSent.IncrementBy(uint64(n))
// Dropped packets aren't errors, so include them in the return value.
return n + len(dropped), nil
}
// WriteHeaderIncludedPacket implements stack.NetworkEndpoint.
-func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack.PacketBuffer) *tcpip.Error {
+func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack.PacketBuffer) tcpip.Error {
// The packet already has an IP header, but there are a few required checks.
h, ok := pkt.Data.PullUp(header.IPv6MinimumSize)
if !ok {
- return tcpip.ErrMalformedHeader
+ return &tcpip.ErrMalformedHeader{}
}
ip := header.IPv6(h)
@@ -823,14 +829,14 @@ func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack.PacketBu
// sending the packet.
proto, _, _, _, ok := parse.IPv6(pkt)
if !ok || !header.IPv6(pkt.NetworkHeader().View()).IsValid(pktSize) {
- return tcpip.ErrMalformedHeader
+ return &tcpip.ErrMalformedHeader{}
}
return e.writePacket(r, nil /* gso */, pkt, proto, true /* headerIncluded */)
}
// forwardPacket attempts to forward a packet to its final destination.
-func (e *endpoint) forwardPacket(pkt *stack.PacketBuffer) *tcpip.Error {
+func (e *endpoint) forwardPacket(pkt *stack.PacketBuffer) tcpip.Error {
h := header.IPv6(pkt.NetworkHeader().View())
hopLimit := h.HopLimit()
if hopLimit <= 1 {
@@ -852,7 +858,7 @@ func (e *endpoint) forwardPacket(pkt *stack.PacketBuffer) *tcpip.Error {
networkEndpoint.(*endpoint).handlePacket(pkt)
return nil
}
- if err != tcpip.ErrBadAddress {
+ if _, ok := err.(*tcpip.ErrBadAddress); !ok {
return err
}
@@ -882,19 +888,21 @@ func (e *endpoint) forwardPacket(pkt *stack.PacketBuffer) *tcpip.Error {
// HandlePacket is called by the link layer when new ipv6 packets arrive for
// this endpoint.
func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) {
- stats := e.protocol.stack.Stats()
- stats.IP.PacketsReceived.Increment()
+ stats := e.stats.ip
+
+ stats.PacketsReceived.Increment()
if !e.isEnabled() {
- stats.IP.DisabledPacketsReceived.Increment()
+ stats.DisabledPacketsReceived.Increment()
return
}
// Loopback traffic skips the prerouting chain.
if !e.nic.IsLoopback() {
- if ok := e.protocol.stack.IPTables().Check(stack.Prerouting, pkt, nil, nil, e.MainAddress().Address, ""); !ok {
+ inNicName := e.protocol.stack.FindNICNameFromID(e.nic.ID())
+ if ok := e.protocol.stack.IPTables().Check(stack.Prerouting, pkt, nil, nil, e.MainAddress().Address, inNicName, "" /* outNicName */); !ok {
// iptables is telling us to drop the packet.
- stats.IP.IPTablesPreroutingDropped.Increment()
+ stats.IPTablesPreroutingDropped.Increment()
return
}
}
@@ -906,11 +914,11 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) {
// iptables hook.
func (e *endpoint) handlePacket(pkt *stack.PacketBuffer) {
pkt.NICID = e.nic.ID()
- stats := e.protocol.stack.Stats()
+ stats := e.stats.ip
h := header.IPv6(pkt.NetworkHeader().View())
if !h.IsValid(pkt.Data.Size() + pkt.NetworkHeader().View().Size() + pkt.TransportHeader().View().Size()) {
- stats.IP.MalformedPacketsReceived.Increment()
+ stats.MalformedPacketsReceived.Increment()
return
}
srcAddr := h.SourceAddress()
@@ -920,7 +928,7 @@ func (e *endpoint) handlePacket(pkt *stack.PacketBuffer) {
// Multicast addresses must not be used as source addresses in IPv6
// packets or appear in any Routing header.
if header.IsV6MulticastAddress(srcAddr) {
- stats.IP.InvalidSourceAddressesReceived.Increment()
+ stats.InvalidSourceAddressesReceived.Increment()
return
}
@@ -930,7 +938,7 @@ func (e *endpoint) handlePacket(pkt *stack.PacketBuffer) {
addressEndpoint.DecRef()
} else if !e.IsInGroup(dstAddr) {
if !e.protocol.Forwarding() {
- stats.IP.InvalidDestinationAddressesReceived.Increment()
+ stats.InvalidDestinationAddressesReceived.Increment()
return
}
@@ -950,9 +958,10 @@ func (e *endpoint) handlePacket(pkt *stack.PacketBuffer) {
// iptables filtering. All packets that reach here are intended for
// this machine and need not be forwarded.
- if ok := e.protocol.stack.IPTables().Check(stack.Input, pkt, nil, nil, "", ""); !ok {
+ inNicName := e.protocol.stack.FindNICNameFromID(e.nic.ID())
+ if ok := e.protocol.stack.IPTables().Check(stack.Input, pkt, nil, nil, "" /* preroutingAddr */, inNicName, "" /* outNicName */); !ok {
// iptables is telling us to drop the packet.
- stats.IP.IPTablesInputDropped.Increment()
+ stats.IPTablesInputDropped.Increment()
return
}
@@ -962,7 +971,7 @@ func (e *endpoint) handlePacket(pkt *stack.PacketBuffer) {
previousHeaderStart := it.HeaderOffset()
extHdr, done, err := it.Next()
if err != nil {
- stats.IP.MalformedPacketsReceived.Increment()
+ stats.MalformedPacketsReceived.Increment()
return
}
if done {
@@ -986,7 +995,7 @@ func (e *endpoint) handlePacket(pkt *stack.PacketBuffer) {
for {
opt, done, err := optsIt.Next()
if err != nil {
- stats.IP.MalformedPacketsReceived.Increment()
+ stats.MalformedPacketsReceived.Increment()
return
}
if done {
@@ -1075,8 +1084,8 @@ func (e *endpoint) handlePacket(pkt *stack.PacketBuffer) {
for {
it, done, err := it.Next()
if err != nil {
- stats.IP.MalformedPacketsReceived.Increment()
- stats.IP.MalformedFragmentsReceived.Increment()
+ stats.MalformedPacketsReceived.Increment()
+ stats.MalformedFragmentsReceived.Increment()
return
}
if done {
@@ -1103,8 +1112,8 @@ func (e *endpoint) handlePacket(pkt *stack.PacketBuffer) {
switch lastHdr.(type) {
case header.IPv6RawPayloadHeader:
default:
- stats.IP.MalformedPacketsReceived.Increment()
- stats.IP.MalformedFragmentsReceived.Increment()
+ stats.MalformedPacketsReceived.Increment()
+ stats.MalformedFragmentsReceived.Increment()
return
}
}
@@ -1112,8 +1121,8 @@ func (e *endpoint) handlePacket(pkt *stack.PacketBuffer) {
fragmentPayloadLen := rawPayload.Buf.Size()
if fragmentPayloadLen == 0 {
// Drop the packet as it's marked as a fragment but has no payload.
- stats.IP.MalformedPacketsReceived.Increment()
- stats.IP.MalformedFragmentsReceived.Increment()
+ stats.MalformedPacketsReceived.Increment()
+ stats.MalformedFragmentsReceived.Increment()
return
}
@@ -1126,8 +1135,8 @@ func (e *endpoint) handlePacket(pkt *stack.PacketBuffer) {
// of the fragment, pointing to the Payload Length field of the
// fragment packet.
if extHdr.More() && fragmentPayloadLen%header.IPv6FragmentExtHdrFragmentOffsetBytesPerUnit != 0 {
- stats.IP.MalformedPacketsReceived.Increment()
- stats.IP.MalformedFragmentsReceived.Increment()
+ stats.MalformedPacketsReceived.Increment()
+ stats.MalformedFragmentsReceived.Increment()
_ = e.protocol.returnError(&icmpReasonParameterProblem{
code: header.ICMPv6ErroneousHeader,
pointer: header.IPv6PayloadLenOffset,
@@ -1147,8 +1156,8 @@ func (e *endpoint) handlePacket(pkt *stack.PacketBuffer) {
// the fragment, pointing to the Fragment Offset field of the fragment
// packet.
if int(start)+fragmentPayloadLen > header.IPv6MaximumPayloadSize {
- stats.IP.MalformedPacketsReceived.Increment()
- stats.IP.MalformedFragmentsReceived.Increment()
+ stats.MalformedPacketsReceived.Increment()
+ stats.MalformedFragmentsReceived.Increment()
_ = e.protocol.returnError(&icmpReasonParameterProblem{
code: header.ICMPv6ErroneousHeader,
pointer: fragmentFieldOffset,
@@ -1173,8 +1182,8 @@ func (e *endpoint) handlePacket(pkt *stack.PacketBuffer) {
pkt,
)
if err != nil {
- stats.IP.MalformedPacketsReceived.Increment()
- stats.IP.MalformedFragmentsReceived.Increment()
+ stats.MalformedPacketsReceived.Increment()
+ stats.MalformedFragmentsReceived.Increment()
return
}
@@ -1194,7 +1203,7 @@ func (e *endpoint) handlePacket(pkt *stack.PacketBuffer) {
for {
opt, done, err := optsIt.Next()
if err != nil {
- stats.IP.MalformedPacketsReceived.Increment()
+ stats.MalformedPacketsReceived.Increment()
return
}
if done {
@@ -1244,12 +1253,12 @@ func (e *endpoint) handlePacket(pkt *stack.PacketBuffer) {
extHdr.Buf.TrimFront(pkt.TransportHeader().View().Size())
pkt.Data = extHdr.Buf
- stats.IP.PacketsDelivered.Increment()
+ stats.PacketsDelivered.Increment()
if p := tcpip.TransportProtocolNumber(extHdr.Identifier); p == header.ICMPv6ProtocolNumber {
pkt.TransportProtocolNumber = p
e.handleICMP(pkt, hasFragmentHeader)
} else {
- stats.IP.PacketsDelivered.Increment()
+ stats.PacketsDelivered.Increment()
switch res := e.dispatcher.DeliverTransportPacket(p, pkt); res {
case stack.TransportPacketHandled:
case stack.TransportPacketDestinationPortUnreachable:
@@ -1314,7 +1323,7 @@ func (e *endpoint) Close() {
e.mu.addressableEndpointState.Cleanup()
e.mu.Unlock()
- e.protocol.forgetEndpoint(e)
+ e.protocol.forgetEndpoint(e.nic.ID())
}
// NetworkProtocolNumber implements stack.NetworkEndpoint.NetworkProtocolNumber.
@@ -1323,7 +1332,7 @@ func (e *endpoint) NetworkProtocolNumber() tcpip.NetworkProtocolNumber {
}
// AddAndAcquirePermanentAddress implements stack.AddressableEndpoint.
-func (e *endpoint) AddAndAcquirePermanentAddress(addr tcpip.AddressWithPrefix, peb stack.PrimaryEndpointBehavior, configType stack.AddressConfigType, deprecated bool) (stack.AddressEndpoint, *tcpip.Error) {
+func (e *endpoint) AddAndAcquirePermanentAddress(addr tcpip.AddressWithPrefix, peb stack.PrimaryEndpointBehavior, configType stack.AddressConfigType, deprecated bool) (stack.AddressEndpoint, tcpip.Error) {
// TODO(b/169350103): add checks here after making sure we no longer receive
// an empty address.
e.mu.Lock()
@@ -1338,7 +1347,7 @@ func (e *endpoint) AddAndAcquirePermanentAddress(addr tcpip.AddressWithPrefix, p
// solicited-node multicast group and start duplicate address detection.
//
// Precondition: e.mu must be write locked.
-func (e *endpoint) addAndAcquirePermanentAddressLocked(addr tcpip.AddressWithPrefix, peb stack.PrimaryEndpointBehavior, configType stack.AddressConfigType, deprecated bool) (stack.AddressEndpoint, *tcpip.Error) {
+func (e *endpoint) addAndAcquirePermanentAddressLocked(addr tcpip.AddressWithPrefix, peb stack.PrimaryEndpointBehavior, configType stack.AddressConfigType, deprecated bool) (stack.AddressEndpoint, tcpip.Error) {
addressEndpoint, err := e.mu.addressableEndpointState.AddAndAcquirePermanentAddress(addr, peb, configType, deprecated)
if err != nil {
return nil, err
@@ -1367,13 +1376,13 @@ func (e *endpoint) addAndAcquirePermanentAddressLocked(addr tcpip.AddressWithPre
}
// RemovePermanentAddress implements stack.AddressableEndpoint.
-func (e *endpoint) RemovePermanentAddress(addr tcpip.Address) *tcpip.Error {
+func (e *endpoint) RemovePermanentAddress(addr tcpip.Address) tcpip.Error {
e.mu.Lock()
defer e.mu.Unlock()
addressEndpoint := e.getAddressRLocked(addr)
if addressEndpoint == nil || !addressEndpoint.GetKind().IsPermanent() {
- return tcpip.ErrBadLocalAddress
+ return &tcpip.ErrBadLocalAddress{}
}
return e.removePermanentEndpointLocked(addressEndpoint, true)
@@ -1383,7 +1392,7 @@ func (e *endpoint) RemovePermanentAddress(addr tcpip.Address) *tcpip.Error {
// it works with a stack.AddressEndpoint.
//
// Precondition: e.mu must be write locked.
-func (e *endpoint) removePermanentEndpointLocked(addressEndpoint stack.AddressEndpoint, allowSLAACInvalidation bool) *tcpip.Error {
+func (e *endpoint) removePermanentEndpointLocked(addressEndpoint stack.AddressEndpoint, allowSLAACInvalidation bool) tcpip.Error {
addr := addressEndpoint.AddressWithPrefix()
unicast := header.IsV6UnicastAddress(addr.Address)
if unicast {
@@ -1408,12 +1417,12 @@ func (e *endpoint) removePermanentEndpointLocked(addressEndpoint stack.AddressEn
}
snmc := header.SolicitedNodeAddr(addr.Address)
+ err := e.leaveGroupLocked(snmc)
// The endpoint may have already left the multicast group.
- if err := e.leaveGroupLocked(snmc); err != nil && err != tcpip.ErrBadLocalAddress {
- return err
+ if _, ok := err.(*tcpip.ErrBadLocalAddress); ok {
+ err = nil
}
-
- return nil
+ return err
}
// hasPermanentAddressLocked returns true if the endpoint has a permanent
@@ -1623,7 +1632,7 @@ func (e *endpoint) PermanentAddresses() []tcpip.AddressWithPrefix {
}
// JoinGroup implements stack.GroupAddressableEndpoint.
-func (e *endpoint) JoinGroup(addr tcpip.Address) *tcpip.Error {
+func (e *endpoint) JoinGroup(addr tcpip.Address) tcpip.Error {
e.mu.Lock()
defer e.mu.Unlock()
return e.joinGroupLocked(addr)
@@ -1632,9 +1641,9 @@ func (e *endpoint) JoinGroup(addr tcpip.Address) *tcpip.Error {
// joinGroupLocked is like JoinGroup but with locking requirements.
//
// Precondition: e.mu must be locked.
-func (e *endpoint) joinGroupLocked(addr tcpip.Address) *tcpip.Error {
+func (e *endpoint) joinGroupLocked(addr tcpip.Address) tcpip.Error {
if !header.IsV6MulticastAddress(addr) {
- return tcpip.ErrBadAddress
+ return &tcpip.ErrBadAddress{}
}
e.mu.mld.joinGroup(addr)
@@ -1642,7 +1651,7 @@ func (e *endpoint) joinGroupLocked(addr tcpip.Address) *tcpip.Error {
}
// LeaveGroup implements stack.GroupAddressableEndpoint.
-func (e *endpoint) LeaveGroup(addr tcpip.Address) *tcpip.Error {
+func (e *endpoint) LeaveGroup(addr tcpip.Address) tcpip.Error {
e.mu.Lock()
defer e.mu.Unlock()
return e.leaveGroupLocked(addr)
@@ -1651,7 +1660,7 @@ func (e *endpoint) LeaveGroup(addr tcpip.Address) *tcpip.Error {
// leaveGroupLocked is like LeaveGroup but with locking requirements.
//
// Precondition: e.mu must be locked.
-func (e *endpoint) leaveGroupLocked(addr tcpip.Address) *tcpip.Error {
+func (e *endpoint) leaveGroupLocked(addr tcpip.Address) tcpip.Error {
return e.mu.mld.leaveGroup(addr)
}
@@ -1662,6 +1671,11 @@ func (e *endpoint) IsInGroup(addr tcpip.Address) bool {
return e.mu.mld.isInGroup(addr)
}
+// Stats implements stack.NetworkEndpoint.
+func (e *endpoint) Stats() stack.NetworkEndpointStats {
+ return &e.stats.localStats
+}
+
var _ stack.ForwardingNetworkProtocol = (*protocol)(nil)
var _ stack.NetworkProtocol = (*protocol)(nil)
var _ fragmentation.TimeoutHandler = (*protocol)(nil)
@@ -1673,7 +1687,9 @@ type protocol struct {
mu struct {
sync.RWMutex
- eps map[*endpoint]struct{}
+ // eps is keyed by NICID to allow protocol methods to retrieve an endpoint
+ // when handling a packet, by looking at which NIC handled the packet.
+ eps map[tcpip.NICID]*endpoint
}
ids []uint32
@@ -1730,37 +1746,42 @@ func (p *protocol) NewEndpoint(nic stack.NetworkInterface, linkAddrCache stack.L
e.mu.mld.init(e)
e.mu.Unlock()
+ stackStats := p.stack.Stats()
+ tcpip.InitStatCounters(reflect.ValueOf(&e.stats.localStats).Elem())
+ e.stats.ip.Init(&e.stats.localStats.IP, &stackStats.IP)
+ e.stats.icmp.init(&e.stats.localStats.ICMP, &stackStats.ICMP.V6)
+
p.mu.Lock()
defer p.mu.Unlock()
- p.mu.eps[e] = struct{}{}
+ p.mu.eps[nic.ID()] = e
return e
}
-func (p *protocol) forgetEndpoint(e *endpoint) {
+func (p *protocol) forgetEndpoint(nicID tcpip.NICID) {
p.mu.Lock()
defer p.mu.Unlock()
- delete(p.mu.eps, e)
+ delete(p.mu.eps, nicID)
}
// SetOption implements NetworkProtocol.SetOption.
-func (p *protocol) SetOption(option tcpip.SettableNetworkProtocolOption) *tcpip.Error {
+func (p *protocol) SetOption(option tcpip.SettableNetworkProtocolOption) tcpip.Error {
switch v := option.(type) {
case *tcpip.DefaultTTLOption:
p.SetDefaultTTL(uint8(*v))
return nil
default:
- return tcpip.ErrUnknownProtocolOption
+ return &tcpip.ErrUnknownProtocolOption{}
}
}
// Option implements NetworkProtocol.Option.
-func (p *protocol) Option(option tcpip.GettableNetworkProtocolOption) *tcpip.Error {
+func (p *protocol) Option(option tcpip.GettableNetworkProtocolOption) tcpip.Error {
switch v := option.(type) {
case *tcpip.DefaultTTLOption:
*v = tcpip.DefaultTTLOption(p.DefaultTTL())
return nil
default:
- return tcpip.ErrUnknownProtocolOption
+ return &tcpip.ErrUnknownProtocolOption{}
}
}
@@ -1814,7 +1835,7 @@ func (p *protocol) SetForwarding(v bool) {
return
}
- for ep := range p.mu.eps {
+ for _, ep := range p.mu.eps {
ep.transitionForwarding(v)
}
}
@@ -1823,9 +1844,9 @@ func (p *protocol) SetForwarding(v bool) {
// link-layer payload MTU and the length of every IPv6 header.
// Note that this is different than the Payload Length field of the IPv6 header,
// which includes the length of the extension headers.
-func calculateNetworkMTU(linkMTU, networkHeadersLen uint32) (uint32, *tcpip.Error) {
+func calculateNetworkMTU(linkMTU, networkHeadersLen uint32) (uint32, tcpip.Error) {
if linkMTU < header.IPv6MinimumMTU {
- return 0, tcpip.ErrInvalidEndpointState
+ return 0, &tcpip.ErrInvalidEndpointState{}
}
// As per RFC 7112 section 5, we should discard packets if their IPv6 header
@@ -1836,7 +1857,7 @@ func calculateNetworkMTU(linkMTU, networkHeadersLen uint32) (uint32, *tcpip.Erro
// bytes ensures that the header chain length does not exceed the IPv6
// minimum MTU.
if networkHeadersLen > header.IPv6MinimumMTU {
- return 0, tcpip.ErrMalformedHeader
+ return 0, &tcpip.ErrMalformedHeader{}
}
networkMTU := linkMTU - uint32(networkHeadersLen)
@@ -1906,7 +1927,7 @@ func NewProtocolWithOptions(opts Options) stack.NetworkProtocolFactory {
hashIV: hashIV,
}
p.fragmentation = fragmentation.NewFragmentation(header.IPv6FragmentExtHdrFragmentOffsetBytesPerUnit, fragmentation.HighFragThreshold, fragmentation.LowFragThreshold, ReassembleTimeout, s.Clock(), p)
- p.mu.eps = make(map[*endpoint]struct{})
+ p.mu.eps = make(map[tcpip.NICID]*endpoint)
p.SetDefaultTTL(DefaultTTL)
return p
}
diff --git a/pkg/tcpip/network/ipv6/ipv6_test.go b/pkg/tcpip/network/ipv6/ipv6_test.go
index b65c9d060..8248052a3 100644
--- a/pkg/tcpip/network/ipv6/ipv6_test.go
+++ b/pkg/tcpip/network/ipv6/ipv6_test.go
@@ -21,6 +21,7 @@ import (
"io/ioutil"
"math"
"net"
+ "reflect"
"testing"
"github.com/google/go-cmp/cmp"
@@ -370,12 +371,10 @@ func TestAddIpv6Address(t *testing.T) {
t.Fatalf("AddAddress(_, %d, nil) = %s", ProtocolNumber, err)
}
- addr, err := s.GetMainNICAddress(1, header.IPv6ProtocolNumber)
- if err != nil {
- t.Fatalf("stack.GetMainNICAddress(_, _) err = %s", err)
- }
- if addr.Address != test.addr {
- t.Fatalf("got stack.GetMainNICAddress(_, _) = %s, want = %s", addr.Address, test.addr)
+ if addr, ok := s.GetMainNICAddress(1, header.IPv6ProtocolNumber); !ok {
+ t.Fatalf("got stack.GetMainNICAddress(1, %d) = (_, false), want = (_, true)", header.IPv6ProtocolNumber)
+ } else if addr.Address != test.addr {
+ t.Fatalf("got stack.GetMainNICAddress(1_, %d) = (%s, true), want = (%s, true)", header.IPv6ProtocolNumber, addr.Address, test.addr)
}
})
}
@@ -997,8 +996,9 @@ func TestReceiveIPv6ExtHdrs(t *testing.T) {
}
// Should not have any more UDP packets.
- if res, err := ep.Read(ioutil.Discard, tcpip.ReadOptions{}); err != tcpip.ErrWouldBlock {
- t.Fatalf("got Read = (%v, %v), want = (_, %s)", res, err, tcpip.ErrWouldBlock)
+ res, err := ep.Read(ioutil.Discard, tcpip.ReadOptions{})
+ if _, ok := err.(*tcpip.ErrWouldBlock); !ok {
+ t.Fatalf("got Read = (%v, %v), want = (_, %s)", res, err, &tcpip.ErrWouldBlock{})
}
})
}
@@ -1989,8 +1989,9 @@ func TestReceiveIPv6Fragments(t *testing.T) {
}
}
- if res, err := ep.Read(ioutil.Discard, tcpip.ReadOptions{}); err != tcpip.ErrWouldBlock {
- t.Fatalf("(last) got Read = (%v, %v), want = (_, %s)", res, err, tcpip.ErrWouldBlock)
+ res, err := ep.Read(ioutil.Discard, tcpip.ReadOptions{})
+ if _, ok := err.(*tcpip.ErrWouldBlock); !ok {
+ t.Fatalf("(last) got Read = (%v, %v), want = (_, %s)", res, err, &tcpip.ErrWouldBlock{})
}
})
}
@@ -2473,11 +2474,11 @@ func TestWriteStats(t *testing.T) {
writers := []struct {
name string
- writePackets func(*stack.Route, stack.PacketBufferList) (int, *tcpip.Error)
+ writePackets func(*stack.Route, stack.PacketBufferList) (int, tcpip.Error)
}{
{
name: "WritePacket",
- writePackets: func(rt *stack.Route, pkts stack.PacketBufferList) (int, *tcpip.Error) {
+ writePackets: func(rt *stack.Route, pkts stack.PacketBufferList) (int, tcpip.Error) {
nWritten := 0
for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() {
if err := rt.WritePacket(nil, stack.NetworkHeaderParams{}, pkt); err != nil {
@@ -2489,7 +2490,7 @@ func TestWriteStats(t *testing.T) {
},
}, {
name: "WritePackets",
- writePackets: func(rt *stack.Route, pkts stack.PacketBufferList) (int, *tcpip.Error) {
+ writePackets: func(rt *stack.Route, pkts stack.PacketBufferList) (int, tcpip.Error) {
return rt.WritePackets(nil, pkts, stack.NetworkHeaderParams{})
},
},
@@ -2499,7 +2500,7 @@ func TestWriteStats(t *testing.T) {
t.Run(writer.name, func(t *testing.T) {
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
- ep := testutil.NewMockLinkEndpoint(header.IPv6MinimumMTU, tcpip.ErrInvalidEndpointState, test.allowPackets)
+ ep := testutil.NewMockLinkEndpoint(header.IPv6MinimumMTU, &tcpip.ErrInvalidEndpointState{}, test.allowPackets)
rt := buildRoute(t, ep)
var pkts stack.PacketBufferList
for i := 0; i < nPackets; i++ {
@@ -2574,7 +2575,7 @@ func (*limitedMatcher) Name() string {
}
// Match implements Matcher.Match.
-func (lm *limitedMatcher) Match(stack.Hook, *stack.PacketBuffer, string) (bool, bool) {
+func (lm *limitedMatcher) Match(stack.Hook, *stack.PacketBuffer, string, string) (bool, bool) {
if lm.limit == 0 {
return true, false
}
@@ -2582,30 +2583,44 @@ func (lm *limitedMatcher) Match(stack.Hook, *stack.PacketBuffer, string) (bool,
return false, false
}
+func knownNICIDs(proto *protocol) []tcpip.NICID {
+ var nicIDs []tcpip.NICID
+
+ for k := range proto.mu.eps {
+ nicIDs = append(nicIDs, k)
+ }
+
+ return nicIDs
+}
+
func TestClearEndpointFromProtocolOnClose(t *testing.T) {
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol},
})
proto := s.NetworkProtocolInstance(ProtocolNumber).(*protocol)
- ep := proto.NewEndpoint(&testInterface{}, nil, nil, nil).(*endpoint)
- {
- proto.mu.Lock()
- _, hasEP := proto.mu.eps[ep]
- proto.mu.Unlock()
- if !hasEP {
- t.Fatalf("expected protocol to have ep = %p in set of endpoints", ep)
- }
+ var nic testInterface
+ ep := proto.NewEndpoint(&nic, nil, nil, nil).(*endpoint)
+ var nicIDs []tcpip.NICID
+
+ proto.mu.Lock()
+ foundEP, hasEndpointBeforeClose := proto.mu.eps[nic.ID()]
+ nicIDs = knownNICIDs(proto)
+ proto.mu.Unlock()
+ if !hasEndpointBeforeClose {
+ t.Fatalf("expected to find the nic id %d in the protocol's known nic ids (%v)", nic.ID(), nicIDs)
+ }
+ if foundEP != ep {
+ t.Fatalf("found an incorrect endpoint mapped to nic id %d", nic.ID())
}
ep.Close()
- {
- proto.mu.Lock()
- _, hasEP := proto.mu.eps[ep]
- proto.mu.Unlock()
- if hasEP {
- t.Fatalf("unexpectedly found ep = %p in set of protocol's endpoints", ep)
- }
+ proto.mu.Lock()
+ _, hasEndpointAfterClose := proto.mu.eps[nic.ID()]
+ nicIDs = knownNICIDs(proto)
+ proto.mu.Unlock()
+ if hasEndpointAfterClose {
+ t.Fatalf("unexpectedly found an endpoint mapped to the nic id %d in the protocol's known nic ids (%v)", nic.ID(), nicIDs)
}
}
@@ -2819,8 +2834,8 @@ func TestFragmentationErrors(t *testing.T) {
payloadSize int
allowPackets int
outgoingErrors int
- mockError *tcpip.Error
- wantError *tcpip.Error
+ mockError tcpip.Error
+ wantError tcpip.Error
}{
{
description: "No frag",
@@ -2829,8 +2844,8 @@ func TestFragmentationErrors(t *testing.T) {
transHdrLen: 0,
allowPackets: 0,
outgoingErrors: 1,
- mockError: tcpip.ErrAborted,
- wantError: tcpip.ErrAborted,
+ mockError: &tcpip.ErrAborted{},
+ wantError: &tcpip.ErrAborted{},
},
{
description: "Error on first frag",
@@ -2839,8 +2854,8 @@ func TestFragmentationErrors(t *testing.T) {
transHdrLen: 0,
allowPackets: 0,
outgoingErrors: 3,
- mockError: tcpip.ErrAborted,
- wantError: tcpip.ErrAborted,
+ mockError: &tcpip.ErrAborted{},
+ wantError: &tcpip.ErrAborted{},
},
{
description: "Error on second frag",
@@ -2849,8 +2864,8 @@ func TestFragmentationErrors(t *testing.T) {
transHdrLen: 0,
allowPackets: 1,
outgoingErrors: 2,
- mockError: tcpip.ErrAborted,
- wantError: tcpip.ErrAborted,
+ mockError: &tcpip.ErrAborted{},
+ wantError: &tcpip.ErrAborted{},
},
{
description: "Error when MTU is smaller than transport header",
@@ -2860,7 +2875,7 @@ func TestFragmentationErrors(t *testing.T) {
allowPackets: 0,
outgoingErrors: 1,
mockError: nil,
- wantError: tcpip.ErrMessageTooLong,
+ wantError: &tcpip.ErrMessageTooLong{},
},
{
description: "Error when MTU is smaller than IPv6 minimum MTU",
@@ -2870,7 +2885,7 @@ func TestFragmentationErrors(t *testing.T) {
allowPackets: 0,
outgoingErrors: 1,
mockError: nil,
- wantError: tcpip.ErrInvalidEndpointState,
+ wantError: &tcpip.ErrInvalidEndpointState{},
},
}
@@ -2884,8 +2899,8 @@ func TestFragmentationErrors(t *testing.T) {
TTL: ttl,
TOS: stack.DefaultTOS,
}, pkt)
- if err != ft.wantError {
- t.Errorf("got WritePacket(_, _, _) = %s, want = %s", err, ft.wantError)
+ if diff := cmp.Diff(ft.wantError, err); diff != "" {
+ t.Errorf("unexpected error from WritePacket(_, _, _), (-want, +got):\n%s", diff)
}
if got := int(r.Stats().IP.PacketsSent.Value()); got != ft.allowPackets {
t.Errorf("got r.Stats().IP.PacketsSent.Value() = %d, want = %d", got, ft.allowPackets)
@@ -3053,3 +3068,22 @@ func TestForwarding(t *testing.T) {
})
}
}
+
+func TestMultiCounterStatsInitialization(t *testing.T) {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol},
+ })
+ proto := s.NetworkProtocolInstance(ProtocolNumber).(*protocol)
+ var nic testInterface
+ ep := proto.NewEndpoint(&nic, nil, nil, nil).(*endpoint)
+ // At this point, the Stack's stats and the NetworkEndpoint's stats are
+ // supposed to be bound.
+ refStack := s.Stats()
+ refEP := ep.stats.localStats
+ if err := testutil.ValidateMultiCounterStats(reflect.ValueOf(&ep.stats.ip).Elem(), []reflect.Value{reflect.ValueOf(&refStack.IP).Elem(), reflect.ValueOf(&refEP.IP).Elem()}); err != nil {
+ t.Error(err)
+ }
+ if err := testutil.ValidateMultiCounterStats(reflect.ValueOf(&ep.stats.icmp).Elem(), []reflect.Value{reflect.ValueOf(&refStack.ICMP.V6).Elem(), reflect.ValueOf(&refEP.ICMP).Elem()}); err != nil {
+ t.Error(err)
+ }
+}
diff --git a/pkg/tcpip/network/ipv6/mld.go b/pkg/tcpip/network/ipv6/mld.go
index ec54d88cc..2cc0dfebd 100644
--- a/pkg/tcpip/network/ipv6/mld.go
+++ b/pkg/tcpip/network/ipv6/mld.go
@@ -68,14 +68,14 @@ func (mld *mldState) Enabled() bool {
// SendReport implements ip.MulticastGroupProtocol.
//
// Precondition: mld.ep.mu must be read locked.
-func (mld *mldState) SendReport(groupAddress tcpip.Address) (bool, *tcpip.Error) {
+func (mld *mldState) SendReport(groupAddress tcpip.Address) (bool, tcpip.Error) {
return mld.writePacket(groupAddress, groupAddress, header.ICMPv6MulticastListenerReport)
}
// SendLeave implements ip.MulticastGroupProtocol.
//
// Precondition: mld.ep.mu must be read locked.
-func (mld *mldState) SendLeave(groupAddress tcpip.Address) *tcpip.Error {
+func (mld *mldState) SendLeave(groupAddress tcpip.Address) tcpip.Error {
_, err := mld.writePacket(header.IPv6AllRoutersMulticastAddress, groupAddress, header.ICMPv6MulticastListenerDone)
return err
}
@@ -112,7 +112,7 @@ func (mld *mldState) handleMulticastListenerReport(mldHdr header.MLD) {
// joinGroup handles joining a new group and sending and scheduling the required
// messages.
//
-// If the group is already joined, returns tcpip.ErrDuplicateAddress.
+// If the group is already joined, returns *tcpip.ErrDuplicateAddress.
//
// Precondition: mld.ep.mu must be locked.
func (mld *mldState) joinGroup(groupAddress tcpip.Address) {
@@ -131,13 +131,13 @@ func (mld *mldState) isInGroup(groupAddress tcpip.Address) bool {
// required.
//
// Precondition: mld.ep.mu must be locked.
-func (mld *mldState) leaveGroup(groupAddress tcpip.Address) *tcpip.Error {
+func (mld *mldState) leaveGroup(groupAddress tcpip.Address) tcpip.Error {
// LeaveGroup returns false only if the group was not joined.
if mld.genericMulticastProtocol.LeaveGroupLocked(groupAddress) {
return nil
}
- return tcpip.ErrBadLocalAddress
+ return &tcpip.ErrBadLocalAddress{}
}
// softLeaveAll leaves all groups from the perspective of MLD, but remains
@@ -166,14 +166,14 @@ func (mld *mldState) sendQueuedReports() {
// writePacket assembles and sends an MLD packet.
//
// Precondition: mld.ep.mu must be read locked.
-func (mld *mldState) writePacket(destAddress, groupAddress tcpip.Address, mldType header.ICMPv6Type) (bool, *tcpip.Error) {
- sentStats := mld.ep.protocol.stack.Stats().ICMP.V6.PacketsSent
- var mldStat *tcpip.StatCounter
+func (mld *mldState) writePacket(destAddress, groupAddress tcpip.Address, mldType header.ICMPv6Type) (bool, tcpip.Error) {
+ sentStats := mld.ep.stats.icmp.packetsSent
+ var mldStat tcpip.MultiCounterStat
switch mldType {
case header.ICMPv6MulticastListenerReport:
- mldStat = sentStats.MulticastListenerReport
+ mldStat = sentStats.multicastListenerReport
case header.ICMPv6MulticastListenerDone:
- mldStat = sentStats.MulticastListenerDone
+ mldStat = sentStats.multicastListenerDone
default:
panic(fmt.Sprintf("unrecognized mld type = %d", mldType))
}
@@ -249,14 +249,14 @@ func (mld *mldState) writePacket(destAddress, groupAddress tcpip.Address, mldTyp
Data: buffer.View(icmp).ToVectorisedView(),
})
- if err := mld.ep.addIPHeader(localAddress, destAddress, pkt, stack.NetworkHeaderParams{
+ if err := addIPHeader(localAddress, destAddress, pkt, stack.NetworkHeaderParams{
Protocol: header.ICMPv6ProtocolNumber,
TTL: header.MLDHopLimit,
}, extensionHeaders); err != nil {
panic(fmt.Sprintf("failed to add IP header: %s", err))
}
if err := mld.ep.nic.WritePacketToRemote(header.EthernetAddressFromMulticastIPv6Address(destAddress), nil /* gso */, ProtocolNumber, pkt); err != nil {
- sentStats.Dropped.Increment()
+ sentStats.dropped.Increment()
return false, err
}
mldStat.Increment()
diff --git a/pkg/tcpip/network/ipv6/ndp.go b/pkg/tcpip/network/ipv6/ndp.go
index 1d8fee50b..d7dde1767 100644
--- a/pkg/tcpip/network/ipv6/ndp.go
+++ b/pkg/tcpip/network/ipv6/ndp.go
@@ -241,7 +241,7 @@ type NDPDispatcher interface {
//
// This function is not permitted to block indefinitely. This function
// is also not permitted to call into the stack.
- OnDuplicateAddressDetectionStatus(nicID tcpip.NICID, addr tcpip.Address, resolved bool, err *tcpip.Error)
+ OnDuplicateAddressDetectionStatus(nicID tcpip.NICID, addr tcpip.Address, resolved bool, err tcpip.Error)
// OnDefaultRouterDiscovered is called when a new default router is
// discovered. Implementations must return true if the newly discovered
@@ -614,10 +614,10 @@ type slaacPrefixState struct {
// tentative.
//
// The IPv6 endpoint that ndp belongs to MUST be locked.
-func (ndp *ndpState) startDuplicateAddressDetection(addr tcpip.Address, addressEndpoint stack.AddressEndpoint) *tcpip.Error {
+func (ndp *ndpState) startDuplicateAddressDetection(addr tcpip.Address, addressEndpoint stack.AddressEndpoint) tcpip.Error {
// addr must be a valid unicast IPv6 address.
if !header.IsV6UnicastAddress(addr) {
- return tcpip.ErrAddressFamilyNotSupported
+ return &tcpip.ErrAddressFamilyNotSupported{}
}
if addressEndpoint.GetKind() != stack.PermanentTentative {
@@ -666,7 +666,7 @@ func (ndp *ndpState) startDuplicateAddressDetection(addr tcpip.Address, addressE
dadDone := remaining == 0
- var err *tcpip.Error
+ var err tcpip.Error
if !dadDone {
err = ndp.sendDADPacket(addr, addressEndpoint)
}
@@ -717,7 +717,7 @@ func (ndp *ndpState) startDuplicateAddressDetection(addr tcpip.Address, addressE
// addr.
//
// addr must be a tentative IPv6 address on ndp's IPv6 endpoint.
-func (ndp *ndpState) sendDADPacket(addr tcpip.Address, addressEndpoint stack.AddressEndpoint) *tcpip.Error {
+func (ndp *ndpState) sendDADPacket(addr tcpip.Address, addressEndpoint stack.AddressEndpoint) tcpip.Error {
snmc := header.SolicitedNodeAddr(addr)
icmp := header.ICMPv6(buffer.NewView(header.ICMPv6NeighborSolicitMinimumSize))
@@ -731,8 +731,8 @@ func (ndp *ndpState) sendDADPacket(addr tcpip.Address, addressEndpoint stack.Add
Data: buffer.View(icmp).ToVectorisedView(),
})
- sent := ndp.ep.protocol.stack.Stats().ICMP.V6.PacketsSent
- if err := ndp.ep.addIPHeader(header.IPv6Any, snmc, pkt, stack.NetworkHeaderParams{
+ sent := ndp.ep.stats.icmp.packetsSent
+ if err := addIPHeader(header.IPv6Any, snmc, pkt, stack.NetworkHeaderParams{
Protocol: header.ICMPv6ProtocolNumber,
TTL: header.NDPHopLimit,
}, nil /* extensionHeaders */); err != nil {
@@ -740,10 +740,11 @@ func (ndp *ndpState) sendDADPacket(addr tcpip.Address, addressEndpoint stack.Add
}
if err := ndp.ep.nic.WritePacketToRemote(header.EthernetAddressFromMulticastIPv6Address(snmc), nil /* gso */, ProtocolNumber, pkt); err != nil {
- sent.Dropped.Increment()
+ sent.dropped.Increment()
return err
}
- sent.NeighborSolicit.Increment()
+ sent.neighborSolicit.Increment()
+
return nil
}
@@ -1855,20 +1856,20 @@ func (ndp *ndpState) startSolicitingRouters() {
Data: buffer.View(icmpData).ToVectorisedView(),
})
- sent := ndp.ep.protocol.stack.Stats().ICMP.V6.PacketsSent
- if err := ndp.ep.addIPHeader(localAddr, header.IPv6AllRoutersMulticastAddress, pkt, stack.NetworkHeaderParams{
+ sent := ndp.ep.stats.icmp.packetsSent
+ if err := addIPHeader(localAddr, header.IPv6AllRoutersMulticastAddress, pkt, stack.NetworkHeaderParams{
Protocol: header.ICMPv6ProtocolNumber,
TTL: header.NDPHopLimit,
}, nil /* extensionHeaders */); err != nil {
panic(fmt.Sprintf("failed to add IP header: %s", err))
}
if err := ndp.ep.nic.WritePacketToRemote(header.EthernetAddressFromMulticastIPv6Address(header.IPv6AllRoutersMulticastAddress), nil /* gso */, ProtocolNumber, pkt); err != nil {
- sent.Dropped.Increment()
+ sent.dropped.Increment()
log.Printf("startSolicitingRouters: error writing NDP router solicit message on NIC(%d); err = %s", ndp.ep.nic.ID(), err)
// Don't send any more messages if we had an error.
remaining = 0
} else {
- sent.RouterSolicit.Increment()
+ sent.routerSolicit.Increment()
remaining--
}
diff --git a/pkg/tcpip/network/ipv6/ndp_test.go b/pkg/tcpip/network/ipv6/ndp_test.go
index b1a5a5510..1d38b8b05 100644
--- a/pkg/tcpip/network/ipv6/ndp_test.go
+++ b/pkg/tcpip/network/ipv6/ndp_test.go
@@ -90,7 +90,7 @@ type testNDPDispatcher struct {
addr tcpip.Address
}
-func (*testNDPDispatcher) OnDuplicateAddressDetectionStatus(tcpip.NICID, tcpip.Address, bool, *tcpip.Error) {
+func (*testNDPDispatcher) OnDuplicateAddressDetectionStatus(tcpip.NICID, tcpip.Address, bool, tcpip.Error) {
}
func (t *testNDPDispatcher) OnDefaultRouterDiscovered(_ tcpip.NICID, addr tcpip.Address) bool {
@@ -162,10 +162,15 @@ func TestStackNDPEndpointInvalidateDefaultRouter(t *testing.T) {
}
}
-// TestNeighorSolicitationWithSourceLinkLayerOption tests that receiving a
+type linkResolutionResult struct {
+ linkAddr tcpip.LinkAddress
+ ok bool
+}
+
+// TestNeighborSolicitationWithSourceLinkLayerOption tests that receiving a
// valid NDP NS message with the Source Link Layer Address option results in a
// new entry in the link address cache for the sender of the message.
-func TestNeighorSolicitationWithSourceLinkLayerOption(t *testing.T) {
+func TestNeighborSolicitationWithSourceLinkLayerOption(t *testing.T) {
const nicID = 1
tests := []struct {
@@ -231,45 +236,40 @@ func TestNeighorSolicitationWithSourceLinkLayerOption(t *testing.T) {
Data: hdr.View().ToVectorisedView(),
}))
- linkAddr, c, err := s.GetLinkAddress(nicID, lladdr1, lladdr0, ProtocolNumber, nil)
- if linkAddr != test.expectedLinkAddr {
- t.Errorf("got link address = %s, want = %s", linkAddr, test.expectedLinkAddr)
- }
-
- if test.expectedLinkAddr != "" {
- if err != nil {
- t.Errorf("s.GetLinkAddress(%d, %s, %s, %d, nil): %s", nicID, lladdr1, lladdr0, ProtocolNumber, err)
- }
- if c != nil {
- t.Errorf("got unexpected channel")
- }
+ ch := make(chan stack.LinkResolutionResult, 1)
+ err := s.GetLinkAddress(nicID, lladdr1, lladdr0, ProtocolNumber, func(r stack.LinkResolutionResult) {
+ ch <- r
+ })
- // Invalid count should not have increased.
- if got := invalid.Value(); got != 0 {
- t.Errorf("got invalid = %d, want = 0", got)
+ wantInvalid := uint64(0)
+ wantSucccess := true
+ if len(test.expectedLinkAddr) == 0 {
+ wantInvalid = 1
+ wantSucccess = false
+ if _, ok := err.(*tcpip.ErrWouldBlock); !ok {
+ t.Errorf("got s.GetLinkAddress(%d, %s, %s, %d, _) = %s, want = %s", nicID, lladdr1, lladdr0, ProtocolNumber, err, &tcpip.ErrWouldBlock{})
}
} else {
- if err != tcpip.ErrWouldBlock {
- t.Errorf("got s.GetLinkAddress(%d, %s, %s, %d, nil) = (_, _, %v), want = (_, _, %s)", nicID, lladdr1, lladdr0, ProtocolNumber, err, tcpip.ErrWouldBlock)
- }
- if c == nil {
- t.Errorf("expected channel from call to s.GetLinkAddress(%d, %s, %s, %d, nil)", nicID, lladdr1, lladdr0, ProtocolNumber)
+ if err != nil {
+ t.Errorf("got s.GetLinkAddress(%d, %s, %s, %d, _) = %s, want = nil", nicID, lladdr1, lladdr0, ProtocolNumber, err)
}
+ }
- // Invalid count should have increased.
- if got := invalid.Value(); got != 1 {
- t.Errorf("got invalid = %d, want = 1", got)
- }
+ if diff := cmp.Diff(stack.LinkResolutionResult{LinkAddress: test.expectedLinkAddr, Success: wantSucccess}, <-ch); diff != "" {
+ t.Errorf("linkResolutionResult mismatch (-want +got):\n%s", diff)
+ }
+ if got := invalid.Value(); got != wantInvalid {
+ t.Errorf("got invalid = %d, want = %d", got, wantInvalid)
}
})
}
}
-// TestNeighorSolicitationWithSourceLinkLayerOptionUsingNeighborCache tests
+// TestNeighborSolicitationWithSourceLinkLayerOptionUsingNeighborCache tests
// that receiving a valid NDP NS message with the Source Link Layer Address
// option results in a new entry in the link address cache for the sender of
// the message.
-func TestNeighorSolicitationWithSourceLinkLayerOptionUsingNeighborCache(t *testing.T) {
+func TestNeighborSolicitationWithSourceLinkLayerOptionUsingNeighborCache(t *testing.T) {
const nicID = 1
tests := []struct {
@@ -382,7 +382,7 @@ func TestNeighorSolicitationWithSourceLinkLayerOptionUsingNeighborCache(t *testi
}
}
-func TestNeighorSolicitationResponse(t *testing.T) {
+func TestNeighborSolicitationResponse(t *testing.T) {
const nicID = 1
nicAddr := lladdr0
remoteAddr := lladdr1
@@ -640,18 +640,12 @@ func TestNeighorSolicitationResponse(t *testing.T) {
t.Fatal("expected an NDP NS response")
}
- if p.Route.LocalAddress != nicAddr {
- t.Errorf("got p.Route.LocalAddress = %s, want = %s", p.Route.LocalAddress, nicAddr)
- }
- if p.Route.LocalLinkAddress != nicLinkAddr {
- t.Errorf("p.Route.LocalLinkAddress = %s, want = %s", p.Route.LocalLinkAddress, nicLinkAddr)
- }
respNSDst := header.SolicitedNodeAddr(test.nsSrc)
- if p.Route.RemoteAddress != respNSDst {
- t.Errorf("got p.Route.RemoteAddress = %s, want = %s", p.Route.RemoteAddress, respNSDst)
- }
- if want := header.EthernetAddressFromMulticastIPv6Address(respNSDst); p.Route.RemoteLinkAddress != want {
- t.Errorf("got p.Route.RemoteLinkAddress = %s, want = %s", p.Route.RemoteLinkAddress, want)
+ var want stack.RouteInfo
+ want.NetProto = ProtocolNumber
+ want.RemoteLinkAddress = header.EthernetAddressFromMulticastIPv6Address(respNSDst)
+ if diff := cmp.Diff(want, p.Route, cmp.AllowUnexported(want)); diff != "" {
+ t.Errorf("route info mismatch (-want +got):\n%s", diff)
}
checker.IPv6(t, stack.PayloadSince(p.Pkt.NetworkHeader()),
@@ -727,10 +721,10 @@ func TestNeighorSolicitationResponse(t *testing.T) {
}
}
-// TestNeighorAdvertisementWithTargetLinkLayerOption tests that receiving a
+// TestNeighborAdvertisementWithTargetLinkLayerOption tests that receiving a
// valid NDP NA message with the Target Link Layer Address option results in a
// new entry in the link address cache for the target of the message.
-func TestNeighorAdvertisementWithTargetLinkLayerOption(t *testing.T) {
+func TestNeighborAdvertisementWithTargetLinkLayerOption(t *testing.T) {
const nicID = 1
tests := []struct {
@@ -803,45 +797,40 @@ func TestNeighorAdvertisementWithTargetLinkLayerOption(t *testing.T) {
Data: hdr.View().ToVectorisedView(),
}))
- linkAddr, c, err := s.GetLinkAddress(nicID, lladdr1, lladdr0, ProtocolNumber, nil)
- if linkAddr != test.expectedLinkAddr {
- t.Errorf("got link address = %s, want = %s", linkAddr, test.expectedLinkAddr)
- }
-
- if test.expectedLinkAddr != "" {
- if err != nil {
- t.Errorf("s.GetLinkAddress(%d, %s, %s, %d, nil): %s", nicID, lladdr1, lladdr0, ProtocolNumber, err)
- }
- if c != nil {
- t.Errorf("got unexpected channel")
- }
+ ch := make(chan stack.LinkResolutionResult, 1)
+ err := s.GetLinkAddress(nicID, lladdr1, lladdr0, ProtocolNumber, func(r stack.LinkResolutionResult) {
+ ch <- r
+ })
- // Invalid count should not have increased.
- if got := invalid.Value(); got != 0 {
- t.Errorf("got invalid = %d, want = 0", got)
+ wantInvalid := uint64(0)
+ wantSucccess := true
+ if len(test.expectedLinkAddr) == 0 {
+ wantInvalid = 1
+ wantSucccess = false
+ if _, ok := err.(*tcpip.ErrWouldBlock); !ok {
+ t.Errorf("got s.GetLinkAddress(%d, %s, %s, %d, _) = %s, want = %s", nicID, lladdr1, lladdr0, ProtocolNumber, err, &tcpip.ErrWouldBlock{})
}
} else {
- if err != tcpip.ErrWouldBlock {
- t.Errorf("got s.GetLinkAddress(%d, %s, %s, %d, nil) = (_, _, %v), want = (_, _, %s)", nicID, lladdr1, lladdr0, ProtocolNumber, err, tcpip.ErrWouldBlock)
- }
- if c == nil {
- t.Errorf("expected channel from call to s.GetLinkAddress(%d, %s, %s, %d, nil)", nicID, lladdr1, lladdr0, ProtocolNumber)
+ if err != nil {
+ t.Errorf("got s.GetLinkAddress(%d, %s, %s, %d, _) = %s, want = nil", nicID, lladdr1, lladdr0, ProtocolNumber, err)
}
+ }
- // Invalid count should have increased.
- if got := invalid.Value(); got != 1 {
- t.Errorf("got invalid = %d, want = 1", got)
- }
+ if diff := cmp.Diff(stack.LinkResolutionResult{LinkAddress: test.expectedLinkAddr, Success: wantSucccess}, <-ch); diff != "" {
+ t.Errorf("linkResolutionResult mismatch (-want +got):\n%s", diff)
+ }
+ if got := invalid.Value(); got != wantInvalid {
+ t.Errorf("got invalid = %d, want = %d", got, wantInvalid)
}
})
}
}
-// TestNeighorAdvertisementWithTargetLinkLayerOptionUsingNeighborCache tests
+// TestNeighborAdvertisementWithTargetLinkLayerOptionUsingNeighborCache tests
// that receiving a valid NDP NA message with the Target Link Layer Address
// option does not result in a new entry in the neighbor cache for the target
// of the message.
-func TestNeighorAdvertisementWithTargetLinkLayerOptionUsingNeighborCache(t *testing.T) {
+func TestNeighborAdvertisementWithTargetLinkLayerOptionUsingNeighborCache(t *testing.T) {
const nicID = 1
tests := []struct {
@@ -1183,6 +1172,118 @@ func TestNDPValidation(t *testing.T) {
}
+// TestNeighborAdvertisementValidation tests that the NIC validates received
+// Neighbor Advertisements.
+//
+// In particular, if the IP Destination Address is a multicast address, and the
+// Solicited flag is not zero, the Neighbor Advertisement is invalid and should
+// be discarded.
+func TestNeighborAdvertisementValidation(t *testing.T) {
+ tests := []struct {
+ name string
+ ipDstAddr tcpip.Address
+ solicitedFlag bool
+ valid bool
+ }{
+ {
+ name: "Multicast IP destination address with Solicited flag set",
+ ipDstAddr: header.IPv6AllNodesMulticastAddress,
+ solicitedFlag: true,
+ valid: false,
+ },
+ {
+ name: "Multicast IP destination address with Solicited flag unset",
+ ipDstAddr: header.IPv6AllNodesMulticastAddress,
+ solicitedFlag: false,
+ valid: true,
+ },
+ {
+ name: "Unicast IP destination address with Solicited flag set",
+ ipDstAddr: lladdr0,
+ solicitedFlag: true,
+ valid: true,
+ },
+ {
+ name: "Unicast IP destination address with Solicited flag unset",
+ ipDstAddr: lladdr0,
+ solicitedFlag: false,
+ valid: true,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol},
+ UseNeighborCache: true,
+ })
+ e := channel.New(0, header.IPv6MinimumMTU, linkAddr0)
+ e.LinkEPCapabilities |= stack.CapabilityResolutionRequired
+ if err := s.CreateNIC(nicID, e); err != nil {
+ t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
+ }
+ if err := s.AddAddress(nicID, ProtocolNumber, lladdr0); err != nil {
+ t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ProtocolNumber, lladdr0, err)
+ }
+
+ ndpNASize := header.ICMPv6NeighborAdvertMinimumSize
+ hdr := buffer.NewPrependable(header.IPv6MinimumSize + ndpNASize)
+ pkt := header.ICMPv6(hdr.Prepend(ndpNASize))
+ pkt.SetType(header.ICMPv6NeighborAdvert)
+ na := header.NDPNeighborAdvert(pkt.MessageBody())
+ na.SetTargetAddress(lladdr1)
+ na.SetSolicitedFlag(test.solicitedFlag)
+ pkt.SetChecksum(header.ICMPv6Checksum(pkt, lladdr1, test.ipDstAddr, buffer.VectorisedView{}))
+ payloadLength := hdr.UsedLength()
+ ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
+ ip.Encode(&header.IPv6Fields{
+ PayloadLength: uint16(payloadLength),
+ TransportProtocol: header.ICMPv6ProtocolNumber,
+ HopLimit: 255,
+ SrcAddr: lladdr1,
+ DstAddr: test.ipDstAddr,
+ })
+
+ stats := s.Stats().ICMP.V6.PacketsReceived
+ invalid := stats.Invalid
+ rxNA := stats.NeighborAdvert
+
+ if got := rxNA.Value(); got != 0 {
+ t.Fatalf("got rxNA = %d, want = 0", got)
+ }
+ if got := invalid.Value(); got != 0 {
+ t.Fatalf("got invalid = %d, want = 0", got)
+ }
+
+ e.InjectInbound(header.IPv6ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
+ Data: hdr.View().ToVectorisedView(),
+ }))
+
+ if got := rxNA.Value(); got != 1 {
+ t.Fatalf("got rxNA = %d, want = 1", got)
+ }
+ var wantInvalid uint64 = 1
+ if test.valid {
+ wantInvalid = 0
+ }
+ if got := invalid.Value(); got != wantInvalid {
+ t.Fatalf("got invalid = %d, want = %d", got, wantInvalid)
+ }
+ // As per RFC 4861 section 7.2.5:
+ // When a valid Neighbor Advertisement is received ...
+ // If no entry exists, the advertisement SHOULD be silently discarded.
+ // There is no need to create an entry if none exists, since the
+ // recipient has apparently not initiated any communication with the
+ // target.
+ if neighbors, err := s.Neighbors(nicID); err != nil {
+ t.Fatalf("s.Neighbors(%d): %s", nicID, err)
+ } else if len(neighbors) != 0 {
+ t.Fatalf("got len(neighbors) = %d, want = 0; neighbors = %#v", len(neighbors), neighbors)
+ }
+ })
+ }
+}
+
// TestRouterAdvertValidation tests that when the NIC is configured to handle
// NDP Router Advertisement packets, it validates the Router Advertisement
// properly before handling them.
diff --git a/pkg/tcpip/network/ipv6/stats.go b/pkg/tcpip/network/ipv6/stats.go
new file mode 100644
index 000000000..0839be3cd
--- /dev/null
+++ b/pkg/tcpip/network/ipv6/stats.go
@@ -0,0 +1,132 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package ipv6
+
+import (
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/network/ip"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
+)
+
+var _ stack.IPNetworkEndpointStats = (*Stats)(nil)
+
+// Stats holds statistics related to the IPv6 protocol family.
+type Stats struct {
+ // IP holds IPv6 statistics.
+ IP tcpip.IPStats
+
+ // ICMP holds ICMPv6 statistics.
+ ICMP tcpip.ICMPv6Stats
+}
+
+// IsNetworkEndpointStats implements stack.NetworkEndpointStats.
+func (*Stats) IsNetworkEndpointStats() {}
+
+// IPStats implements stack.IPNetworkEndointStats
+func (s *Stats) IPStats() *tcpip.IPStats {
+ return &s.IP
+}
+
+type sharedStats struct {
+ localStats Stats
+ ip ip.MultiCounterIPStats
+ icmp multiCounterICMPv6Stats
+}
+
+// LINT.IfChange(multiCounterICMPv6PacketStats)
+
+type multiCounterICMPv6PacketStats struct {
+ echoRequest tcpip.MultiCounterStat
+ echoReply tcpip.MultiCounterStat
+ dstUnreachable tcpip.MultiCounterStat
+ packetTooBig tcpip.MultiCounterStat
+ timeExceeded tcpip.MultiCounterStat
+ paramProblem tcpip.MultiCounterStat
+ routerSolicit tcpip.MultiCounterStat
+ routerAdvert tcpip.MultiCounterStat
+ neighborSolicit tcpip.MultiCounterStat
+ neighborAdvert tcpip.MultiCounterStat
+ redirectMsg tcpip.MultiCounterStat
+ multicastListenerQuery tcpip.MultiCounterStat
+ multicastListenerReport tcpip.MultiCounterStat
+ multicastListenerDone tcpip.MultiCounterStat
+}
+
+func (m *multiCounterICMPv6PacketStats) init(a, b *tcpip.ICMPv6PacketStats) {
+ m.echoRequest.Init(a.EchoRequest, b.EchoRequest)
+ m.echoReply.Init(a.EchoReply, b.EchoReply)
+ m.dstUnreachable.Init(a.DstUnreachable, b.DstUnreachable)
+ m.packetTooBig.Init(a.PacketTooBig, b.PacketTooBig)
+ m.timeExceeded.Init(a.TimeExceeded, b.TimeExceeded)
+ m.paramProblem.Init(a.ParamProblem, b.ParamProblem)
+ m.routerSolicit.Init(a.RouterSolicit, b.RouterSolicit)
+ m.routerAdvert.Init(a.RouterAdvert, b.RouterAdvert)
+ m.neighborSolicit.Init(a.NeighborSolicit, b.NeighborSolicit)
+ m.neighborAdvert.Init(a.NeighborAdvert, b.NeighborAdvert)
+ m.redirectMsg.Init(a.RedirectMsg, b.RedirectMsg)
+ m.multicastListenerQuery.Init(a.MulticastListenerQuery, b.MulticastListenerQuery)
+ m.multicastListenerReport.Init(a.MulticastListenerReport, b.MulticastListenerReport)
+ m.multicastListenerDone.Init(a.MulticastListenerDone, b.MulticastListenerDone)
+}
+
+// LINT.ThenChange(../../tcpip.go:ICMPv6PacketStats)
+
+// LINT.IfChange(multiCounterICMPv6SentPacketStats)
+
+type multiCounterICMPv6SentPacketStats struct {
+ multiCounterICMPv6PacketStats
+ dropped tcpip.MultiCounterStat
+ rateLimited tcpip.MultiCounterStat
+}
+
+func (m *multiCounterICMPv6SentPacketStats) init(a, b *tcpip.ICMPv6SentPacketStats) {
+ m.multiCounterICMPv6PacketStats.init(&a.ICMPv6PacketStats, &b.ICMPv6PacketStats)
+ m.dropped.Init(a.Dropped, b.Dropped)
+ m.rateLimited.Init(a.RateLimited, b.RateLimited)
+}
+
+// LINT.ThenChange(../../tcpip.go:ICMPv6SentPacketStats)
+
+// LINT.IfChange(multiCounterICMPv6ReceivedPacketStats)
+
+type multiCounterICMPv6ReceivedPacketStats struct {
+ multiCounterICMPv6PacketStats
+ unrecognized tcpip.MultiCounterStat
+ invalid tcpip.MultiCounterStat
+ routerOnlyPacketsDroppedByHost tcpip.MultiCounterStat
+}
+
+func (m *multiCounterICMPv6ReceivedPacketStats) init(a, b *tcpip.ICMPv6ReceivedPacketStats) {
+ m.multiCounterICMPv6PacketStats.init(&a.ICMPv6PacketStats, &b.ICMPv6PacketStats)
+ m.unrecognized.Init(a.Unrecognized, b.Unrecognized)
+ m.invalid.Init(a.Invalid, b.Invalid)
+ m.routerOnlyPacketsDroppedByHost.Init(a.RouterOnlyPacketsDroppedByHost, b.RouterOnlyPacketsDroppedByHost)
+}
+
+// LINT.ThenChange(../../tcpip.go:ICMPv6ReceivedPacketStats)
+
+// LINT.IfChange(multiCounterICMPv6Stats)
+
+type multiCounterICMPv6Stats struct {
+ packetsSent multiCounterICMPv6SentPacketStats
+ packetsReceived multiCounterICMPv6ReceivedPacketStats
+}
+
+func (m *multiCounterICMPv6Stats) init(a, b *tcpip.ICMPv6Stats) {
+ m.packetsSent.init(&a.PacketsSent, &b.PacketsSent)
+ m.packetsReceived.init(&a.PacketsReceived, &b.PacketsReceived)
+}
+
+// LINT.ThenChange(../../tcpip.go:ICMPv6Stats)
diff --git a/pkg/tcpip/network/testutil/BUILD b/pkg/tcpip/network/testutil/BUILD
index d0ffc299a..bd62c4482 100644
--- a/pkg/tcpip/network/testutil/BUILD
+++ b/pkg/tcpip/network/testutil/BUILD
@@ -6,8 +6,10 @@ go_library(
name = "testutil",
srcs = [
"testutil.go",
+ "testutil_unsafe.go",
],
visibility = [
+ "//pkg/tcpip/network/arp:__pkg__",
"//pkg/tcpip/network/fragmentation:__pkg__",
"//pkg/tcpip/network/ipv4:__pkg__",
"//pkg/tcpip/network/ipv6:__pkg__",
diff --git a/pkg/tcpip/network/testutil/testutil.go b/pkg/tcpip/network/testutil/testutil.go
index 3af44991f..f5fa77b65 100644
--- a/pkg/tcpip/network/testutil/testutil.go
+++ b/pkg/tcpip/network/testutil/testutil.go
@@ -19,6 +19,8 @@ package testutil
import (
"fmt"
"math/rand"
+ "reflect"
+ "strings"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
@@ -33,7 +35,7 @@ type MockLinkEndpoint struct {
WrittenPackets []*stack.PacketBuffer
mtu uint32
- err *tcpip.Error
+ err tcpip.Error
allowPackets int
}
@@ -41,7 +43,7 @@ type MockLinkEndpoint struct {
//
// err is the error that will be returned once allowPackets packets are written
// to the endpoint.
-func NewMockLinkEndpoint(mtu uint32, err *tcpip.Error, allowPackets int) *MockLinkEndpoint {
+func NewMockLinkEndpoint(mtu uint32, err tcpip.Error, allowPackets int) *MockLinkEndpoint {
return &MockLinkEndpoint{
mtu: mtu,
err: err,
@@ -62,7 +64,7 @@ func (*MockLinkEndpoint) MaxHeaderLength() uint16 { return 0 }
func (*MockLinkEndpoint) LinkAddress() tcpip.LinkAddress { return "" }
// WritePacket implements LinkEndpoint.WritePacket.
-func (ep *MockLinkEndpoint) WritePacket(_ stack.RouteInfo, _ *stack.GSO, _ tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error {
+func (ep *MockLinkEndpoint) WritePacket(_ stack.RouteInfo, _ *stack.GSO, _ tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error {
if ep.allowPackets == 0 {
return ep.err
}
@@ -72,7 +74,7 @@ func (ep *MockLinkEndpoint) WritePacket(_ stack.RouteInfo, _ *stack.GSO, _ tcpip
}
// WritePackets implements LinkEndpoint.WritePackets.
-func (ep *MockLinkEndpoint) WritePackets(r stack.RouteInfo, gso *stack.GSO, pkts stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) {
+func (ep *MockLinkEndpoint) WritePackets(r stack.RouteInfo, gso *stack.GSO, pkts stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, tcpip.Error) {
var n int
for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() {
@@ -127,3 +129,69 @@ func MakeRandPkt(transportHeaderLength int, extraHeaderReserveLength int, viewSi
}
return pkt
}
+
+func checkFieldCounts(ref, multi reflect.Value) error {
+ refTypeName := ref.Type().Name()
+ multiTypeName := multi.Type().Name()
+ refNumField := ref.NumField()
+ multiNumField := multi.NumField()
+
+ if refNumField != multiNumField {
+ return fmt.Errorf("type %s has an incorrect number of fields: got = %d, want = %d (same as type %s)", multiTypeName, multiNumField, refNumField, refTypeName)
+ }
+
+ return nil
+}
+
+func validateField(ref reflect.Value, refName string, m tcpip.MultiCounterStat, multiName string) error {
+ s, ok := ref.Addr().Interface().(**tcpip.StatCounter)
+ if !ok {
+ return fmt.Errorf("expected ref type's to be *StatCounter, but its type is %s", ref.Type().Elem().Name())
+ }
+
+ // The field names are expected to match (case insensitive).
+ if !strings.EqualFold(refName, multiName) {
+ return fmt.Errorf("wrong field name: got = %s, want = %s", multiName, refName)
+ }
+
+ base := (*s).Value()
+ m.Increment()
+ if (*s).Value() != base+1 {
+ return fmt.Errorf("updates to the '%s MultiCounterStat' counters are not reflected in the '%s CounterStat'", multiName, refName)
+ }
+
+ return nil
+}
+
+// ValidateMultiCounterStats verifies that every counter stored in multi is
+// correctly tracking its counterpart in the given counters.
+func ValidateMultiCounterStats(multi reflect.Value, counters []reflect.Value) error {
+ for _, c := range counters {
+ if err := checkFieldCounts(c, multi); err != nil {
+ return err
+ }
+ }
+
+ for i := 0; i < multi.NumField(); i++ {
+ multiName := multi.Type().Field(i).Name
+ multiUnsafe := unsafeExposeUnexportedFields(multi.Field(i))
+
+ if m, ok := multiUnsafe.Addr().Interface().(*tcpip.MultiCounterStat); ok {
+ for _, c := range counters {
+ if err := validateField(unsafeExposeUnexportedFields(c.Field(i)), c.Type().Field(i).Name, *m, multiName); err != nil {
+ return err
+ }
+ }
+ } else {
+ var countersNextField []reflect.Value
+ for _, c := range counters {
+ countersNextField = append(countersNextField, c.Field(i))
+ }
+ if err := ValidateMultiCounterStats(multi.Field(i), countersNextField); err != nil {
+ return err
+ }
+ }
+ }
+
+ return nil
+}
diff --git a/pkg/tcpip/network/testutil/testutil_unsafe.go b/pkg/tcpip/network/testutil/testutil_unsafe.go
new file mode 100644
index 000000000..5ff764800
--- /dev/null
+++ b/pkg/tcpip/network/testutil/testutil_unsafe.go
@@ -0,0 +1,26 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package testutil
+
+import (
+ "reflect"
+ "unsafe"
+)
+
+// unsafeExposeUnexportedFields takes a Value and returns a version of it in
+// which even unexported fields can be read and written.
+func unsafeExposeUnexportedFields(a reflect.Value) reflect.Value {
+ return reflect.NewAt(a.Type(), unsafe.Pointer(a.UnsafeAddr())).Elem()
+}
diff --git a/pkg/tcpip/ports/BUILD b/pkg/tcpip/ports/BUILD
index 2bad05a2e..57abec5c9 100644
--- a/pkg/tcpip/ports/BUILD
+++ b/pkg/tcpip/ports/BUILD
@@ -18,5 +18,6 @@ go_test(
library = ":ports",
deps = [
"//pkg/tcpip",
+ "@com_github_google_go_cmp//cmp:go_default_library",
],
)
diff --git a/pkg/tcpip/ports/ports.go b/pkg/tcpip/ports/ports.go
index d87193650..11dbdbbcf 100644
--- a/pkg/tcpip/ports/ports.go
+++ b/pkg/tcpip/ports/ports.go
@@ -329,7 +329,7 @@ func NewPortManager() *PortManager {
// possible ephemeral ports, allowing the caller to decide whether a given port
// is suitable for its needs, and stopping when a port is found or an error
// occurs.
-func (s *PortManager) PickEphemeralPort(testPort func(p uint16) (bool, *tcpip.Error)) (port uint16, err *tcpip.Error) {
+func (s *PortManager) PickEphemeralPort(testPort func(p uint16) (bool, tcpip.Error)) (port uint16, err tcpip.Error) {
offset := uint32(rand.Int31n(numEphemeralPorts))
return s.pickEphemeralPort(offset, numEphemeralPorts, testPort)
}
@@ -348,7 +348,7 @@ func (s *PortManager) incPortHint() {
// iterates over all ephemeral ports, allowing the caller to decide whether a
// given port is suitable for its needs and stopping when a port is found or an
// error occurs.
-func (s *PortManager) PickEphemeralPortStable(offset uint32, testPort func(p uint16) (bool, *tcpip.Error)) (port uint16, err *tcpip.Error) {
+func (s *PortManager) PickEphemeralPortStable(offset uint32, testPort func(p uint16) (bool, tcpip.Error)) (port uint16, err tcpip.Error) {
p, err := s.pickEphemeralPort(s.portHint()+offset, numEphemeralPorts, testPort)
if err == nil {
s.incPortHint()
@@ -361,7 +361,7 @@ func (s *PortManager) PickEphemeralPortStable(offset uint32, testPort func(p uin
// and iterates over the number of ports specified by count and allows the
// caller to decide whether a given port is suitable for its needs, and stopping
// when a port is found or an error occurs.
-func (s *PortManager) pickEphemeralPort(offset, count uint32, testPort func(p uint16) (bool, *tcpip.Error)) (port uint16, err *tcpip.Error) {
+func (s *PortManager) pickEphemeralPort(offset, count uint32, testPort func(p uint16) (bool, tcpip.Error)) (port uint16, err tcpip.Error) {
for i := uint32(0); i < count; i++ {
port = uint16(FirstEphemeral + (offset+i)%count)
ok, err := testPort(port)
@@ -374,7 +374,7 @@ func (s *PortManager) pickEphemeralPort(offset, count uint32, testPort func(p ui
}
}
- return 0, tcpip.ErrNoPortAvailable
+ return 0, &tcpip.ErrNoPortAvailable{}
}
// IsPortAvailable tests if the given port is available on all given protocols.
@@ -404,7 +404,7 @@ func (s *PortManager) isPortAvailableLocked(networks []tcpip.NetworkProtocolNumb
// An optional testPort closure can be passed in which if provided will be used
// to test if the picked port can be used. The function should return true if
// the port is safe to use, false otherwise.
-func (s *PortManager) ReservePort(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, flags Flags, bindToDevice tcpip.NICID, dest tcpip.FullAddress, testPort func(port uint16) bool) (reservedPort uint16, err *tcpip.Error) {
+func (s *PortManager) ReservePort(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, flags Flags, bindToDevice tcpip.NICID, dest tcpip.FullAddress, testPort func(port uint16) bool) (reservedPort uint16, err tcpip.Error) {
s.mu.Lock()
defer s.mu.Unlock()
@@ -414,17 +414,17 @@ func (s *PortManager) ReservePort(networks []tcpip.NetworkProtocolNumber, transp
// protocols.
if port != 0 {
if !s.reserveSpecificPort(networks, transport, addr, port, flags, bindToDevice, dst) {
- return 0, tcpip.ErrPortInUse
+ return 0, &tcpip.ErrPortInUse{}
}
if testPort != nil && !testPort(port) {
s.releasePortLocked(networks, transport, addr, port, flags.Bits(), bindToDevice, dst)
- return 0, tcpip.ErrPortInUse
+ return 0, &tcpip.ErrPortInUse{}
}
return port, nil
}
// A port wasn't specified, so try to find one.
- return s.PickEphemeralPort(func(p uint16) (bool, *tcpip.Error) {
+ return s.PickEphemeralPort(func(p uint16) (bool, tcpip.Error) {
if !s.reserveSpecificPort(networks, transport, addr, p, flags, bindToDevice, dst) {
return false, nil
}
diff --git a/pkg/tcpip/ports/ports_test.go b/pkg/tcpip/ports/ports_test.go
index 4bc949fd8..e70fbb72b 100644
--- a/pkg/tcpip/ports/ports_test.go
+++ b/pkg/tcpip/ports/ports_test.go
@@ -18,6 +18,7 @@ import (
"math/rand"
"testing"
+ "github.com/google/go-cmp/cmp"
"gvisor.dev/gvisor/pkg/tcpip"
)
@@ -32,7 +33,7 @@ const (
type portReserveTestAction struct {
port uint16
ip tcpip.Address
- want *tcpip.Error
+ want tcpip.Error
flags Flags
release bool
device tcpip.NICID
@@ -50,19 +51,19 @@ func TestPortReservation(t *testing.T) {
{port: 80, ip: fakeIPAddress, want: nil},
{port: 80, ip: fakeIPAddress1, want: nil},
/* N.B. Order of tests matters! */
- {port: 80, ip: anyIPAddress, want: tcpip.ErrPortInUse},
- {port: 80, ip: fakeIPAddress, want: tcpip.ErrPortInUse, flags: Flags{LoadBalanced: true}},
+ {port: 80, ip: anyIPAddress, want: &tcpip.ErrPortInUse{}},
+ {port: 80, ip: fakeIPAddress, want: &tcpip.ErrPortInUse{}, flags: Flags{LoadBalanced: true}},
},
},
{
tname: "bind to inaddr any",
actions: []portReserveTestAction{
{port: 22, ip: anyIPAddress, want: nil},
- {port: 22, ip: fakeIPAddress, want: tcpip.ErrPortInUse},
+ {port: 22, ip: fakeIPAddress, want: &tcpip.ErrPortInUse{}},
/* release fakeIPAddress, but anyIPAddress is still inuse */
{port: 22, ip: fakeIPAddress, release: true},
- {port: 22, ip: fakeIPAddress, want: tcpip.ErrPortInUse},
- {port: 22, ip: fakeIPAddress, want: tcpip.ErrPortInUse, flags: Flags{LoadBalanced: true}},
+ {port: 22, ip: fakeIPAddress, want: &tcpip.ErrPortInUse{}},
+ {port: 22, ip: fakeIPAddress, want: &tcpip.ErrPortInUse{}, flags: Flags{LoadBalanced: true}},
/* Release port 22 from any IP address, then try to reserve fake IP address on 22 */
{port: 22, ip: anyIPAddress, want: nil, release: true},
{port: 22, ip: fakeIPAddress, want: nil},
@@ -80,8 +81,8 @@ func TestPortReservation(t *testing.T) {
{port: 25, ip: fakeIPAddress, flags: Flags{LoadBalanced: true}, want: nil},
{port: 25, ip: fakeIPAddress, flags: Flags{LoadBalanced: true}, want: nil},
- {port: 25, ip: fakeIPAddress, flags: Flags{}, want: tcpip.ErrPortInUse},
- {port: 25, ip: anyIPAddress, flags: Flags{}, want: tcpip.ErrPortInUse},
+ {port: 25, ip: fakeIPAddress, flags: Flags{}, want: &tcpip.ErrPortInUse{}},
+ {port: 25, ip: anyIPAddress, flags: Flags{}, want: &tcpip.ErrPortInUse{}},
{port: 25, ip: anyIPAddress, flags: Flags{LoadBalanced: true}, want: nil},
},
@@ -91,14 +92,14 @@ func TestPortReservation(t *testing.T) {
{port: 24, ip: anyIPAddress, flags: Flags{LoadBalanced: true}, want: nil},
{port: 24, ip: anyIPAddress, flags: Flags{LoadBalanced: true}, want: nil},
- {port: 24, ip: anyIPAddress, flags: Flags{}, want: tcpip.ErrPortInUse},
- {port: 24, ip: fakeIPAddress, flags: Flags{}, want: tcpip.ErrPortInUse},
+ {port: 24, ip: anyIPAddress, flags: Flags{}, want: &tcpip.ErrPortInUse{}},
+ {port: 24, ip: fakeIPAddress, flags: Flags{}, want: &tcpip.ErrPortInUse{}},
{port: 24, ip: fakeIPAddress, flags: Flags{LoadBalanced: true}, want: nil},
{port: 24, ip: fakeIPAddress, flags: Flags{LoadBalanced: true}, release: true, want: nil},
{port: 24, ip: anyIPAddress, flags: Flags{LoadBalanced: true}, release: true},
- {port: 24, ip: anyIPAddress, flags: Flags{}, want: tcpip.ErrPortInUse},
+ {port: 24, ip: anyIPAddress, flags: Flags{}, want: &tcpip.ErrPortInUse{}},
{port: 24, ip: anyIPAddress, flags: Flags{LoadBalanced: true}, release: true},
{port: 24, ip: anyIPAddress, flags: Flags{}, want: nil},
@@ -107,7 +108,7 @@ func TestPortReservation(t *testing.T) {
tname: "bind twice with device fails",
actions: []portReserveTestAction{
{port: 24, ip: fakeIPAddress, device: 3, want: nil},
- {port: 24, ip: fakeIPAddress, device: 3, want: tcpip.ErrPortInUse},
+ {port: 24, ip: fakeIPAddress, device: 3, want: &tcpip.ErrPortInUse{}},
},
}, {
tname: "bind to device",
@@ -119,50 +120,50 @@ func TestPortReservation(t *testing.T) {
tname: "bind to device and then without device",
actions: []portReserveTestAction{
{port: 24, ip: fakeIPAddress, device: 123, want: nil},
- {port: 24, ip: fakeIPAddress, device: 0, want: tcpip.ErrPortInUse},
+ {port: 24, ip: fakeIPAddress, device: 0, want: &tcpip.ErrPortInUse{}},
},
}, {
tname: "bind without device",
actions: []portReserveTestAction{
{port: 24, ip: fakeIPAddress, want: nil},
- {port: 24, ip: fakeIPAddress, device: 123, want: tcpip.ErrPortInUse},
- {port: 24, ip: fakeIPAddress, device: 123, flags: Flags{LoadBalanced: true}, want: tcpip.ErrPortInUse},
- {port: 24, ip: fakeIPAddress, want: tcpip.ErrPortInUse},
- {port: 24, ip: fakeIPAddress, flags: Flags{LoadBalanced: true}, want: tcpip.ErrPortInUse},
+ {port: 24, ip: fakeIPAddress, device: 123, want: &tcpip.ErrPortInUse{}},
+ {port: 24, ip: fakeIPAddress, device: 123, flags: Flags{LoadBalanced: true}, want: &tcpip.ErrPortInUse{}},
+ {port: 24, ip: fakeIPAddress, want: &tcpip.ErrPortInUse{}},
+ {port: 24, ip: fakeIPAddress, flags: Flags{LoadBalanced: true}, want: &tcpip.ErrPortInUse{}},
},
}, {
tname: "bind with device",
actions: []portReserveTestAction{
{port: 24, ip: fakeIPAddress, device: 123, want: nil},
- {port: 24, ip: fakeIPAddress, device: 123, want: tcpip.ErrPortInUse},
- {port: 24, ip: fakeIPAddress, device: 123, flags: Flags{LoadBalanced: true}, want: tcpip.ErrPortInUse},
- {port: 24, ip: fakeIPAddress, device: 0, want: tcpip.ErrPortInUse},
- {port: 24, ip: fakeIPAddress, device: 0, flags: Flags{LoadBalanced: true}, want: tcpip.ErrPortInUse},
+ {port: 24, ip: fakeIPAddress, device: 123, want: &tcpip.ErrPortInUse{}},
+ {port: 24, ip: fakeIPAddress, device: 123, flags: Flags{LoadBalanced: true}, want: &tcpip.ErrPortInUse{}},
+ {port: 24, ip: fakeIPAddress, device: 0, want: &tcpip.ErrPortInUse{}},
+ {port: 24, ip: fakeIPAddress, device: 0, flags: Flags{LoadBalanced: true}, want: &tcpip.ErrPortInUse{}},
{port: 24, ip: fakeIPAddress, device: 456, flags: Flags{LoadBalanced: true}, want: nil},
{port: 24, ip: fakeIPAddress, device: 789, want: nil},
- {port: 24, ip: fakeIPAddress, want: tcpip.ErrPortInUse},
- {port: 24, ip: fakeIPAddress, flags: Flags{LoadBalanced: true}, want: tcpip.ErrPortInUse},
+ {port: 24, ip: fakeIPAddress, want: &tcpip.ErrPortInUse{}},
+ {port: 24, ip: fakeIPAddress, flags: Flags{LoadBalanced: true}, want: &tcpip.ErrPortInUse{}},
},
}, {
tname: "bind with reuseport",
actions: []portReserveTestAction{
{port: 24, ip: fakeIPAddress, flags: Flags{LoadBalanced: true}, want: nil},
- {port: 24, ip: fakeIPAddress, device: 123, want: tcpip.ErrPortInUse},
+ {port: 24, ip: fakeIPAddress, device: 123, want: &tcpip.ErrPortInUse{}},
{port: 24, ip: fakeIPAddress, device: 123, flags: Flags{LoadBalanced: true}, want: nil},
- {port: 24, ip: fakeIPAddress, device: 0, want: tcpip.ErrPortInUse},
+ {port: 24, ip: fakeIPAddress, device: 0, want: &tcpip.ErrPortInUse{}},
{port: 24, ip: fakeIPAddress, device: 0, flags: Flags{LoadBalanced: true}, want: nil},
},
}, {
tname: "binding with reuseport and device",
actions: []portReserveTestAction{
{port: 24, ip: fakeIPAddress, device: 123, flags: Flags{LoadBalanced: true}, want: nil},
- {port: 24, ip: fakeIPAddress, device: 123, want: tcpip.ErrPortInUse},
+ {port: 24, ip: fakeIPAddress, device: 123, want: &tcpip.ErrPortInUse{}},
{port: 24, ip: fakeIPAddress, device: 123, flags: Flags{LoadBalanced: true}, want: nil},
- {port: 24, ip: fakeIPAddress, device: 0, want: tcpip.ErrPortInUse},
+ {port: 24, ip: fakeIPAddress, device: 0, want: &tcpip.ErrPortInUse{}},
{port: 24, ip: fakeIPAddress, device: 456, flags: Flags{LoadBalanced: true}, want: nil},
{port: 24, ip: fakeIPAddress, device: 0, flags: Flags{LoadBalanced: true}, want: nil},
{port: 24, ip: fakeIPAddress, device: 789, flags: Flags{LoadBalanced: true}, want: nil},
- {port: 24, ip: fakeIPAddress, device: 999, want: tcpip.ErrPortInUse},
+ {port: 24, ip: fakeIPAddress, device: 999, want: &tcpip.ErrPortInUse{}},
},
}, {
tname: "mixing reuseport and not reuseport by binding to device",
@@ -177,14 +178,14 @@ func TestPortReservation(t *testing.T) {
actions: []portReserveTestAction{
{port: 24, ip: fakeIPAddress, device: 123, flags: Flags{LoadBalanced: true}, want: nil},
{port: 24, ip: fakeIPAddress, device: 456, want: nil},
- {port: 24, ip: fakeIPAddress, device: 0, flags: Flags{LoadBalanced: true}, want: tcpip.ErrPortInUse},
+ {port: 24, ip: fakeIPAddress, device: 0, flags: Flags{LoadBalanced: true}, want: &tcpip.ErrPortInUse{}},
},
}, {
tname: "bind and release",
actions: []portReserveTestAction{
{port: 24, ip: fakeIPAddress, device: 123, flags: Flags{LoadBalanced: true}, want: nil},
{port: 24, ip: fakeIPAddress, device: 0, flags: Flags{LoadBalanced: true}, want: nil},
- {port: 24, ip: fakeIPAddress, device: 345, flags: Flags{}, want: tcpip.ErrPortInUse},
+ {port: 24, ip: fakeIPAddress, device: 345, flags: Flags{}, want: &tcpip.ErrPortInUse{}},
{port: 24, ip: fakeIPAddress, device: 789, flags: Flags{LoadBalanced: true}, want: nil},
// Release the bind to device 0 and try again.
@@ -195,7 +196,7 @@ func TestPortReservation(t *testing.T) {
tname: "bind twice with reuseport once",
actions: []portReserveTestAction{
{port: 24, ip: fakeIPAddress, device: 123, flags: Flags{}, want: nil},
- {port: 24, ip: fakeIPAddress, device: 0, flags: Flags{LoadBalanced: true}, want: tcpip.ErrPortInUse},
+ {port: 24, ip: fakeIPAddress, device: 0, flags: Flags{LoadBalanced: true}, want: &tcpip.ErrPortInUse{}},
},
}, {
tname: "release an unreserved device",
@@ -213,16 +214,16 @@ func TestPortReservation(t *testing.T) {
tname: "bind with reuseaddr",
actions: []portReserveTestAction{
{port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true}, want: nil},
- {port: 24, ip: fakeIPAddress, device: 123, want: tcpip.ErrPortInUse},
+ {port: 24, ip: fakeIPAddress, device: 123, want: &tcpip.ErrPortInUse{}},
{port: 24, ip: fakeIPAddress, device: 123, flags: Flags{MostRecent: true}, want: nil},
- {port: 24, ip: fakeIPAddress, device: 0, want: tcpip.ErrPortInUse},
+ {port: 24, ip: fakeIPAddress, device: 0, want: &tcpip.ErrPortInUse{}},
{port: 24, ip: fakeIPAddress, device: 0, flags: Flags{MostRecent: true}, want: nil},
},
}, {
tname: "bind twice with reuseaddr once",
actions: []portReserveTestAction{
{port: 24, ip: fakeIPAddress, device: 123, flags: Flags{}, want: nil},
- {port: 24, ip: fakeIPAddress, device: 0, flags: Flags{MostRecent: true}, want: tcpip.ErrPortInUse},
+ {port: 24, ip: fakeIPAddress, device: 0, flags: Flags{MostRecent: true}, want: &tcpip.ErrPortInUse{}},
},
}, {
tname: "bind with reuseaddr and reuseport",
@@ -236,14 +237,14 @@ func TestPortReservation(t *testing.T) {
actions: []portReserveTestAction{
{port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true, LoadBalanced: true}, want: nil},
{port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true}, want: nil},
- {port: 24, ip: fakeIPAddress, flags: Flags{LoadBalanced: true}, want: tcpip.ErrPortInUse},
+ {port: 24, ip: fakeIPAddress, flags: Flags{LoadBalanced: true}, want: &tcpip.ErrPortInUse{}},
},
}, {
tname: "bind with reuseaddr and reuseport, and then reuseport",
actions: []portReserveTestAction{
{port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true, LoadBalanced: true}, want: nil},
{port: 24, ip: fakeIPAddress, flags: Flags{LoadBalanced: true}, want: nil},
- {port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true}, want: tcpip.ErrPortInUse},
+ {port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true}, want: &tcpip.ErrPortInUse{}},
},
}, {
tname: "bind with reuseaddr and reuseport twice, and then reuseaddr",
@@ -264,14 +265,14 @@ func TestPortReservation(t *testing.T) {
actions: []portReserveTestAction{
{port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true}, want: nil},
{port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true, LoadBalanced: true}, want: nil},
- {port: 24, ip: fakeIPAddress, flags: Flags{LoadBalanced: true}, want: tcpip.ErrPortInUse},
+ {port: 24, ip: fakeIPAddress, flags: Flags{LoadBalanced: true}, want: &tcpip.ErrPortInUse{}},
},
}, {
tname: "bind with reuseport, and then reuseaddr and reuseport",
actions: []portReserveTestAction{
{port: 24, ip: fakeIPAddress, flags: Flags{LoadBalanced: true}, want: nil},
{port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true, LoadBalanced: true}, want: nil},
- {port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true}, want: tcpip.ErrPortInUse},
+ {port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true}, want: &tcpip.ErrPortInUse{}},
},
}, {
tname: "bind tuple with reuseaddr, and then wildcard with reuseaddr",
@@ -283,7 +284,7 @@ func TestPortReservation(t *testing.T) {
tname: "bind tuple with reuseaddr, and then wildcard",
actions: []portReserveTestAction{
{port: 24, ip: fakeIPAddress, flags: Flags{TupleOnly: true}, dest: tcpip.FullAddress{Addr: fakeIPAddress, Port: 24}, want: nil},
- {port: 24, ip: fakeIPAddress, want: tcpip.ErrPortInUse},
+ {port: 24, ip: fakeIPAddress, want: &tcpip.ErrPortInUse{}},
},
}, {
tname: "bind wildcard with reuseaddr, and then tuple with reuseaddr",
@@ -295,7 +296,7 @@ func TestPortReservation(t *testing.T) {
tname: "bind tuple with reuseaddr, and then wildcard",
actions: []portReserveTestAction{
{port: 24, ip: fakeIPAddress, want: nil},
- {port: 24, ip: fakeIPAddress, flags: Flags{TupleOnly: true}, dest: tcpip.FullAddress{Addr: fakeIPAddress, Port: 24}, want: tcpip.ErrPortInUse},
+ {port: 24, ip: fakeIPAddress, flags: Flags{TupleOnly: true}, dest: tcpip.FullAddress{Addr: fakeIPAddress, Port: 24}, want: &tcpip.ErrPortInUse{}},
},
}, {
tname: "bind two tuples with reuseaddr",
@@ -313,7 +314,7 @@ func TestPortReservation(t *testing.T) {
tname: "bind wildcard, and then tuple with reuseaddr",
actions: []portReserveTestAction{
{port: 24, ip: fakeIPAddress, dest: tcpip.FullAddress{}, want: nil},
- {port: 24, ip: fakeIPAddress, flags: Flags{TupleOnly: true}, dest: tcpip.FullAddress{Addr: fakeIPAddress, Port: 24}, want: tcpip.ErrPortInUse},
+ {port: 24, ip: fakeIPAddress, flags: Flags{TupleOnly: true}, dest: tcpip.FullAddress{Addr: fakeIPAddress, Port: 24}, want: &tcpip.ErrPortInUse{}},
},
}, {
tname: "bind wildcard twice with reuseaddr",
@@ -333,8 +334,8 @@ func TestPortReservation(t *testing.T) {
continue
}
gotPort, err := pm.ReservePort(net, fakeTransNumber, test.ip, test.port, test.flags, test.device, test.dest, nil /* testPort */)
- if err != test.want {
- t.Fatalf("ReservePort(.., .., %s, %d, %+v, %d, %v) = %v, want %v", test.ip, test.port, test.flags, test.device, test.dest, err, test.want)
+ if diff := cmp.Diff(test.want, err); diff != "" {
+ t.Fatalf("unexpected error from ReservePort(.., .., %s, %d, %+v, %d, %v), (-want, +got):\n%s", test.ip, test.port, test.flags, test.device, test.dest, diff)
}
if test.port == 0 && (gotPort == 0 || gotPort < FirstEphemeral) {
t.Fatalf("ReservePort(.., .., .., 0, ..) = %d, want port number >= %d to be picked", gotPort, FirstEphemeral)
@@ -345,30 +346,29 @@ func TestPortReservation(t *testing.T) {
}
func TestPickEphemeralPort(t *testing.T) {
- customErr := &tcpip.Error{}
for _, test := range []struct {
name string
- f func(port uint16) (bool, *tcpip.Error)
- wantErr *tcpip.Error
+ f func(port uint16) (bool, tcpip.Error)
+ wantErr tcpip.Error
wantPort uint16
}{
{
name: "no-port-available",
- f: func(port uint16) (bool, *tcpip.Error) {
+ f: func(port uint16) (bool, tcpip.Error) {
return false, nil
},
- wantErr: tcpip.ErrNoPortAvailable,
+ wantErr: &tcpip.ErrNoPortAvailable{},
},
{
name: "port-tester-error",
- f: func(port uint16) (bool, *tcpip.Error) {
- return false, customErr
+ f: func(port uint16) (bool, tcpip.Error) {
+ return false, &tcpip.ErrBadBuffer{}
},
- wantErr: customErr,
+ wantErr: &tcpip.ErrBadBuffer{},
},
{
name: "only-port-16042-available",
- f: func(port uint16) (bool, *tcpip.Error) {
+ f: func(port uint16) (bool, tcpip.Error) {
if port == FirstEphemeral+42 {
return true, nil
}
@@ -378,49 +378,52 @@ func TestPickEphemeralPort(t *testing.T) {
},
{
name: "only-port-under-16000-available",
- f: func(port uint16) (bool, *tcpip.Error) {
+ f: func(port uint16) (bool, tcpip.Error) {
if port < FirstEphemeral {
return true, nil
}
return false, nil
},
- wantErr: tcpip.ErrNoPortAvailable,
+ wantErr: &tcpip.ErrNoPortAvailable{},
},
} {
t.Run(test.name, func(t *testing.T) {
pm := NewPortManager()
- if port, err := pm.PickEphemeralPort(test.f); port != test.wantPort || err != test.wantErr {
- t.Errorf("PickEphemeralPort(..) = (port %d, err %v); want (port %d, err %v)", port, err, test.wantPort, test.wantErr)
+ port, err := pm.PickEphemeralPort(test.f)
+ if diff := cmp.Diff(test.wantErr, err); diff != "" {
+ t.Fatalf("unexpected error from PickEphemeralPort(..), (-want, +got):\n%s", diff)
+ }
+ if port != test.wantPort {
+ t.Errorf("got PickEphemeralPort(..) = (%d, nil); want (%d, nil)", port, test.wantPort)
}
})
}
}
func TestPickEphemeralPortStable(t *testing.T) {
- customErr := &tcpip.Error{}
for _, test := range []struct {
name string
- f func(port uint16) (bool, *tcpip.Error)
- wantErr *tcpip.Error
+ f func(port uint16) (bool, tcpip.Error)
+ wantErr tcpip.Error
wantPort uint16
}{
{
name: "no-port-available",
- f: func(port uint16) (bool, *tcpip.Error) {
+ f: func(port uint16) (bool, tcpip.Error) {
return false, nil
},
- wantErr: tcpip.ErrNoPortAvailable,
+ wantErr: &tcpip.ErrNoPortAvailable{},
},
{
name: "port-tester-error",
- f: func(port uint16) (bool, *tcpip.Error) {
- return false, customErr
+ f: func(port uint16) (bool, tcpip.Error) {
+ return false, &tcpip.ErrBadBuffer{}
},
- wantErr: customErr,
+ wantErr: &tcpip.ErrBadBuffer{},
},
{
name: "only-port-16042-available",
- f: func(port uint16) (bool, *tcpip.Error) {
+ f: func(port uint16) (bool, tcpip.Error) {
if port == FirstEphemeral+42 {
return true, nil
}
@@ -430,20 +433,24 @@ func TestPickEphemeralPortStable(t *testing.T) {
},
{
name: "only-port-under-16000-available",
- f: func(port uint16) (bool, *tcpip.Error) {
+ f: func(port uint16) (bool, tcpip.Error) {
if port < FirstEphemeral {
return true, nil
}
return false, nil
},
- wantErr: tcpip.ErrNoPortAvailable,
+ wantErr: &tcpip.ErrNoPortAvailable{},
},
} {
t.Run(test.name, func(t *testing.T) {
pm := NewPortManager()
portOffset := uint32(rand.Int31n(int32(numEphemeralPorts)))
- if port, err := pm.PickEphemeralPortStable(portOffset, test.f); port != test.wantPort || err != test.wantErr {
- t.Errorf("PickEphemeralPort(..) = (port %d, err %v); want (port %d, err %v)", port, err, test.wantPort, test.wantErr)
+ port, err := pm.PickEphemeralPortStable(portOffset, test.f)
+ if diff := cmp.Diff(test.wantErr, err); diff != "" {
+ t.Fatalf("unexpected error from PickEphemeralPort(..), (-want, +got):\n%s", diff)
+ }
+ if port != test.wantPort {
+ t.Errorf("got PickEphemeralPort(..) = (%d, nil); want (%d, nil)", port, test.wantPort)
}
})
}
diff --git a/pkg/tcpip/sample/tun_tcp_connect/BUILD b/pkg/tcpip/sample/tun_tcp_connect/BUILD
index cf0a5fefe..db9b91815 100644
--- a/pkg/tcpip/sample/tun_tcp_connect/BUILD
+++ b/pkg/tcpip/sample/tun_tcp_connect/BUILD
@@ -8,7 +8,6 @@ go_binary(
visibility = ["//:sandbox"],
deps = [
"//pkg/tcpip",
- "//pkg/tcpip/buffer",
"//pkg/tcpip/header",
"//pkg/tcpip/link/fdbased",
"//pkg/tcpip/link/rawfile",
diff --git a/pkg/tcpip/sample/tun_tcp_connect/main.go b/pkg/tcpip/sample/tun_tcp_connect/main.go
index 3b4f900e3..856ea998d 100644
--- a/pkg/tcpip/sample/tun_tcp_connect/main.go
+++ b/pkg/tcpip/sample/tun_tcp_connect/main.go
@@ -41,7 +41,7 @@
package main
import (
- "bufio"
+ "bytes"
"fmt"
"log"
"math/rand"
@@ -51,7 +51,6 @@ import (
"time"
"gvisor.dev/gvisor/pkg/tcpip"
- "gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/link/fdbased"
"gvisor.dev/gvisor/pkg/tcpip/link/rawfile"
@@ -71,24 +70,21 @@ func writer(ch chan struct{}, ep tcpip.Endpoint) {
close(ch)
}()
- r := bufio.NewReader(os.Stdin)
- for {
- v := buffer.NewView(1024)
- n, err := r.Read(v)
- if err != nil {
- return
- }
-
- v.CapLength(n)
- for len(v) > 0 {
- n, err := ep.Write(tcpip.SlicePayload(v), tcpip.WriteOptions{})
- if err != nil {
- fmt.Println("Write failed:", err)
- return
+ var b bytes.Buffer
+ if err := func() error {
+ for {
+ if _, err := b.ReadFrom(os.Stdin); err != nil {
+ return fmt.Errorf("b.ReadFrom failed: %w", err)
}
- v.TrimFront(int(n))
+ for b.Len() != 0 {
+ if _, err := ep.Write(&b, tcpip.WriteOptions{Atomic: true}); err != nil {
+ return fmt.Errorf("ep.Write failed: %s", err)
+ }
+ }
}
+ }(); err != nil {
+ fmt.Println(err)
}
}
@@ -179,7 +175,7 @@ func main() {
waitEntry, notifyCh := waiter.NewChannelEntry(nil)
wq.EventRegister(&waitEntry, waiter.EventOut)
terr := ep.Connect(remote)
- if terr == tcpip.ErrConnectStarted {
+ if _, ok := terr.(*tcpip.ErrConnectStarted); ok {
fmt.Println("Connect is pending...")
<-notifyCh
terr = ep.LastError()
@@ -202,11 +198,11 @@ func main() {
for {
_, err := ep.Read(os.Stdout, tcpip.ReadOptions{})
if err != nil {
- if err == tcpip.ErrClosedForReceive {
+ if _, ok := err.(*tcpip.ErrClosedForReceive); ok {
break
}
- if err == tcpip.ErrWouldBlock {
+ if _, ok := err.(*tcpip.ErrWouldBlock); ok {
<-notifyCh
continue
}
diff --git a/pkg/tcpip/sample/tun_tcp_echo/main.go b/pkg/tcpip/sample/tun_tcp_echo/main.go
index 3ac562756..9b23df3a9 100644
--- a/pkg/tcpip/sample/tun_tcp_echo/main.go
+++ b/pkg/tcpip/sample/tun_tcp_echo/main.go
@@ -20,6 +20,7 @@
package main
import (
+ "bytes"
"flag"
"io"
"log"
@@ -50,7 +51,7 @@ type endpointWriter struct {
}
type tcpipError struct {
- inner *tcpip.Error
+ inner tcpip.Error
}
func (e *tcpipError) Error() string {
@@ -58,7 +59,9 @@ func (e *tcpipError) Error() string {
}
func (e *endpointWriter) Write(p []byte) (int, error) {
- n, err := e.ep.Write(tcpip.SlicePayload(p), tcpip.WriteOptions{})
+ var r bytes.Reader
+ r.Reset(p)
+ n, err := e.ep.Write(&r, tcpip.WriteOptions{})
if err != nil {
return int(n), &tcpipError{
inner: err,
@@ -86,7 +89,7 @@ func echo(wq *waiter.Queue, ep tcpip.Endpoint) {
for {
_, err := ep.Read(&w, tcpip.ReadOptions{})
if err != nil {
- if err == tcpip.ErrWouldBlock {
+ if _, ok := err.(*tcpip.ErrWouldBlock); ok {
<-notifyCh
continue
}
@@ -214,7 +217,7 @@ func main() {
for {
n, wq, err := ep.Accept(nil)
if err != nil {
- if err == tcpip.ErrWouldBlock {
+ if _, ok := err.(*tcpip.ErrWouldBlock); ok {
<-notifyCh
continue
}
diff --git a/pkg/tcpip/socketops.go b/pkg/tcpip/socketops.go
index f3ad40fdf..019d6a63c 100644
--- a/pkg/tcpip/socketops.go
+++ b/pkg/tcpip/socketops.go
@@ -15,11 +15,16 @@
package tcpip
import (
+ "math"
"sync/atomic"
"gvisor.dev/gvisor/pkg/sync"
)
+// PacketOverheadFactor is used to multiply the value provided by the user on a
+// SetSockOpt for setting the send/receive buffer sizes sockets.
+const PacketOverheadFactor = 2
+
// SocketOptionsHandler holds methods that help define endpoint specific
// behavior for socket level socket options. These must be implemented by
// endpoints to get notified when socket level options are set.
@@ -41,13 +46,19 @@ type SocketOptionsHandler interface {
OnCorkOptionSet(v bool)
// LastError is invoked when SO_ERROR is read for an endpoint.
- LastError() *Error
+ LastError() Error
// UpdateLastError updates the endpoint specific last error field.
- UpdateLastError(err *Error)
+ UpdateLastError(err Error)
// HasNIC is invoked to check if the NIC is valid for SO_BINDTODEVICE.
HasNIC(v int32) bool
+
+ // GetSendBufferSize is invoked to get the SO_SNDBUFSIZE.
+ GetSendBufferSize() (int64, Error)
+
+ // IsUnixSocket is invoked to check if the socket is of unix domain.
+ IsUnixSocket() bool
}
// DefaultSocketOptionsHandler is an embeddable type that implements no-op
@@ -72,18 +83,39 @@ func (*DefaultSocketOptionsHandler) OnDelayOptionSet(bool) {}
func (*DefaultSocketOptionsHandler) OnCorkOptionSet(bool) {}
// LastError implements SocketOptionsHandler.LastError.
-func (*DefaultSocketOptionsHandler) LastError() *Error {
+func (*DefaultSocketOptionsHandler) LastError() Error {
return nil
}
// UpdateLastError implements SocketOptionsHandler.UpdateLastError.
-func (*DefaultSocketOptionsHandler) UpdateLastError(*Error) {}
+func (*DefaultSocketOptionsHandler) UpdateLastError(Error) {}
// HasNIC implements SocketOptionsHandler.HasNIC.
func (*DefaultSocketOptionsHandler) HasNIC(int32) bool {
return false
}
+// GetSendBufferSize implements SocketOptionsHandler.GetSendBufferSize.
+func (*DefaultSocketOptionsHandler) GetSendBufferSize() (int64, Error) {
+ return 0, nil
+}
+
+// IsUnixSocket implements SocketOptionsHandler.IsUnixSocket.
+func (*DefaultSocketOptionsHandler) IsUnixSocket() bool {
+ return false
+}
+
+// StackHandler holds methods to access the stack options. These must be
+// implemented by the stack.
+type StackHandler interface {
+ // Option allows retrieving stack wide options.
+ Option(option interface{}) Error
+
+ // TransportProtocolOption allows retrieving individual protocol level
+ // option values.
+ TransportProtocolOption(proto TransportProtocolNumber, option GettableTransportProtocolOption) Error
+}
+
// SocketOptions contains all the variables which store values for SOL_SOCKET,
// SOL_IP, SOL_IPV6 and SOL_TCP level options.
//
@@ -91,6 +123,9 @@ func (*DefaultSocketOptionsHandler) HasNIC(int32) bool {
type SocketOptions struct {
handler SocketOptionsHandler
+ // StackHandler is initialized at the creation time and will not change.
+ stackHandler StackHandler `state:"manual"`
+
// These fields are accessed and modified using atomic operations.
// broadcastEnabled determines whether datagram sockets are allowed to
@@ -170,6 +205,14 @@ type SocketOptions struct {
// bindToDevice determines the device to which the socket is bound.
bindToDevice int32
+ // getSendBufferLimits provides the handler to get the min, default and
+ // max size for send buffer. It is initialized at the creation time and
+ // will not change.
+ getSendBufferLimits GetSendBufferLimits `state:"manual"`
+
+ // sendBufferSize determines the send buffer size for this socket.
+ sendBufferSize int64
+
// mu protects the access to the below fields.
mu sync.Mutex `state:"nosave"`
@@ -180,8 +223,10 @@ type SocketOptions struct {
// InitHandler initializes the handler. This must be called before using the
// socket options utility.
-func (so *SocketOptions) InitHandler(handler SocketOptionsHandler) {
+func (so *SocketOptions) InitHandler(handler SocketOptionsHandler, stack StackHandler, getSendBufferLimits GetSendBufferLimits) {
so.handler = handler
+ so.stackHandler = stack
+ so.getSendBufferLimits = getSendBufferLimits
}
func storeAtomicBool(addr *uint32, v bool) {
@@ -193,7 +238,7 @@ func storeAtomicBool(addr *uint32, v bool) {
}
// SetLastError sets the last error for a socket.
-func (so *SocketOptions) SetLastError(err *Error) {
+func (so *SocketOptions) SetLastError(err Error) {
so.handler.UpdateLastError(err)
}
@@ -378,7 +423,7 @@ func (so *SocketOptions) SetRecvError(v bool) {
}
// GetLastError gets value for SO_ERROR option.
-func (so *SocketOptions) GetLastError() *Error {
+func (so *SocketOptions) GetLastError() Error {
return so.handler.LastError()
}
@@ -435,7 +480,7 @@ type SockError struct {
sockErrorEntry
// Err is the error caused by the errant packet.
- Err *Error
+ Err Error
// ErrOrigin indicates the error origin.
ErrOrigin SockErrOrigin
// ErrType is the type in the ICMP header.
@@ -493,7 +538,7 @@ func (so *SocketOptions) QueueErr(err *SockError) {
}
// QueueLocalErr queues a local error onto the local queue.
-func (so *SocketOptions) QueueLocalErr(err *Error, net NetworkProtocolNumber, info uint32, dst FullAddress, payload []byte) {
+func (so *SocketOptions) QueueLocalErr(err Error, net NetworkProtocolNumber, info uint32, dst FullAddress, payload []byte) {
so.QueueErr(&SockError{
Err: err,
ErrOrigin: SockExtErrorOriginLocal,
@@ -510,11 +555,52 @@ func (so *SocketOptions) GetBindToDevice() int32 {
}
// SetBindToDevice sets value for SO_BINDTODEVICE option.
-func (so *SocketOptions) SetBindToDevice(bindToDevice int32) *Error {
+func (so *SocketOptions) SetBindToDevice(bindToDevice int32) Error {
if !so.handler.HasNIC(bindToDevice) {
- return ErrUnknownDevice
+ return &ErrUnknownDevice{}
}
atomic.StoreInt32(&so.bindToDevice, bindToDevice)
return nil
}
+
+// GetSendBufferSize gets value for SO_SNDBUF option.
+func (so *SocketOptions) GetSendBufferSize() (int64, Error) {
+ if so.handler.IsUnixSocket() {
+ return so.handler.GetSendBufferSize()
+ }
+ return atomic.LoadInt64(&so.sendBufferSize), nil
+}
+
+// SetSendBufferSize sets value for SO_SNDBUF option. notify indicates if the
+// stack handler should be invoked to set the send buffer size.
+func (so *SocketOptions) SetSendBufferSize(sendBufferSize int64, notify bool) {
+ if so.handler.IsUnixSocket() {
+ return
+ }
+
+ v := sendBufferSize
+ if notify {
+ // TODO(b/176170271): Notify waiters after size has grown.
+ // Make sure the send buffer size is within the min and max
+ // allowed.
+ ss := so.getSendBufferLimits(so.stackHandler)
+ min := int64(ss.Min)
+ max := int64(ss.Max)
+ // Validate the send buffer size with min and max values.
+ // Multiply it by factor of 2.
+ if v > max {
+ v = max
+ }
+
+ if v < math.MaxInt32/PacketOverheadFactor {
+ v *= PacketOverheadFactor
+ if v < min {
+ v = min
+ }
+ } else {
+ v = math.MaxInt32
+ }
+ }
+ atomic.StoreInt64(&so.sendBufferSize, v)
+}
diff --git a/pkg/tcpip/stack/addressable_endpoint_state.go b/pkg/tcpip/stack/addressable_endpoint_state.go
index cd423bf71..e5590ecc0 100644
--- a/pkg/tcpip/stack/addressable_endpoint_state.go
+++ b/pkg/tcpip/stack/addressable_endpoint_state.go
@@ -117,7 +117,7 @@ func (a *AddressableEndpointState) releaseAddressStateLocked(addrState *addressS
}
// AddAndAcquirePermanentAddress implements AddressableEndpoint.
-func (a *AddressableEndpointState) AddAndAcquirePermanentAddress(addr tcpip.AddressWithPrefix, peb PrimaryEndpointBehavior, configType AddressConfigType, deprecated bool) (AddressEndpoint, *tcpip.Error) {
+func (a *AddressableEndpointState) AddAndAcquirePermanentAddress(addr tcpip.AddressWithPrefix, peb PrimaryEndpointBehavior, configType AddressConfigType, deprecated bool) (AddressEndpoint, tcpip.Error) {
a.mu.Lock()
defer a.mu.Unlock()
ep, err := a.addAndAcquireAddressLocked(addr, peb, configType, deprecated, true /* permanent */)
@@ -143,10 +143,10 @@ func (a *AddressableEndpointState) AddAndAcquirePermanentAddress(addr tcpip.Addr
// AddAndAcquireTemporaryAddress adds a temporary address.
//
-// Returns tcpip.ErrDuplicateAddress if the address exists.
+// Returns *tcpip.ErrDuplicateAddress if the address exists.
//
// The temporary address's endpoint is acquired and returned.
-func (a *AddressableEndpointState) AddAndAcquireTemporaryAddress(addr tcpip.AddressWithPrefix, peb PrimaryEndpointBehavior) (AddressEndpoint, *tcpip.Error) {
+func (a *AddressableEndpointState) AddAndAcquireTemporaryAddress(addr tcpip.AddressWithPrefix, peb PrimaryEndpointBehavior) (AddressEndpoint, tcpip.Error) {
a.mu.Lock()
defer a.mu.Unlock()
ep, err := a.addAndAcquireAddressLocked(addr, peb, AddressConfigStatic, false /* deprecated */, false /* permanent */)
@@ -176,11 +176,11 @@ func (a *AddressableEndpointState) AddAndAcquireTemporaryAddress(addr tcpip.Addr
// If the addressable endpoint already has the address in a non-permanent state,
// and addAndAcquireAddressLocked is adding a permanent address, that address is
// promoted in place and its properties set to the properties provided. If the
-// address already exists in any other state, then tcpip.ErrDuplicateAddress is
+// address already exists in any other state, then *tcpip.ErrDuplicateAddress is
// returned, regardless the kind of address that is being added.
//
// Precondition: a.mu must be write locked.
-func (a *AddressableEndpointState) addAndAcquireAddressLocked(addr tcpip.AddressWithPrefix, peb PrimaryEndpointBehavior, configType AddressConfigType, deprecated, permanent bool) (*addressState, *tcpip.Error) {
+func (a *AddressableEndpointState) addAndAcquireAddressLocked(addr tcpip.AddressWithPrefix, peb PrimaryEndpointBehavior, configType AddressConfigType, deprecated, permanent bool) (*addressState, tcpip.Error) {
// attemptAddToPrimary is false when the address is already in the primary
// address list.
attemptAddToPrimary := true
@@ -190,7 +190,7 @@ func (a *AddressableEndpointState) addAndAcquireAddressLocked(addr tcpip.Address
// We are adding a non-permanent address but the address exists. No need
// to go any further since we can only promote existing temporary/expired
// addresses to permanent.
- return nil, tcpip.ErrDuplicateAddress
+ return nil, &tcpip.ErrDuplicateAddress{}
}
addrState.mu.Lock()
@@ -198,7 +198,7 @@ func (a *AddressableEndpointState) addAndAcquireAddressLocked(addr tcpip.Address
addrState.mu.Unlock()
// We are adding a permanent address but a permanent address already
// exists.
- return nil, tcpip.ErrDuplicateAddress
+ return nil, &tcpip.ErrDuplicateAddress{}
}
if addrState.mu.refs == 0 {
@@ -293,7 +293,7 @@ func (a *AddressableEndpointState) addAndAcquireAddressLocked(addr tcpip.Address
}
// RemovePermanentAddress implements AddressableEndpoint.
-func (a *AddressableEndpointState) RemovePermanentAddress(addr tcpip.Address) *tcpip.Error {
+func (a *AddressableEndpointState) RemovePermanentAddress(addr tcpip.Address) tcpip.Error {
a.mu.Lock()
defer a.mu.Unlock()
return a.removePermanentAddressLocked(addr)
@@ -303,10 +303,10 @@ func (a *AddressableEndpointState) RemovePermanentAddress(addr tcpip.Address) *t
// requirements.
//
// Precondition: a.mu must be write locked.
-func (a *AddressableEndpointState) removePermanentAddressLocked(addr tcpip.Address) *tcpip.Error {
+func (a *AddressableEndpointState) removePermanentAddressLocked(addr tcpip.Address) tcpip.Error {
addrState, ok := a.mu.endpoints[addr]
if !ok {
- return tcpip.ErrBadLocalAddress
+ return &tcpip.ErrBadLocalAddress{}
}
return a.removePermanentEndpointLocked(addrState)
@@ -314,10 +314,10 @@ func (a *AddressableEndpointState) removePermanentAddressLocked(addr tcpip.Addre
// RemovePermanentEndpoint removes the passed endpoint if it is associated with
// a and permanent.
-func (a *AddressableEndpointState) RemovePermanentEndpoint(ep AddressEndpoint) *tcpip.Error {
+func (a *AddressableEndpointState) RemovePermanentEndpoint(ep AddressEndpoint) tcpip.Error {
addrState, ok := ep.(*addressState)
if !ok || addrState.addressableEndpointState != a {
- return tcpip.ErrInvalidEndpointState
+ return &tcpip.ErrInvalidEndpointState{}
}
a.mu.Lock()
@@ -329,9 +329,9 @@ func (a *AddressableEndpointState) RemovePermanentEndpoint(ep AddressEndpoint) *
// requirements.
//
// Precondition: a.mu must be write locked.
-func (a *AddressableEndpointState) removePermanentEndpointLocked(addrState *addressState) *tcpip.Error {
+func (a *AddressableEndpointState) removePermanentEndpointLocked(addrState *addressState) tcpip.Error {
if !addrState.GetKind().IsPermanent() {
- return tcpip.ErrBadLocalAddress
+ return &tcpip.ErrBadLocalAddress{}
}
addrState.SetKind(PermanentExpired)
@@ -574,9 +574,11 @@ func (a *AddressableEndpointState) Cleanup() {
defer a.mu.Unlock()
for _, ep := range a.mu.endpoints {
- // removePermanentEndpointLocked returns tcpip.ErrBadLocalAddress if ep is
+ // removePermanentEndpointLocked returns *tcpip.ErrBadLocalAddress if ep is
// not a permanent address.
- if err := a.removePermanentEndpointLocked(ep); err != nil && err != tcpip.ErrBadLocalAddress {
+ switch err := a.removePermanentEndpointLocked(ep); err.(type) {
+ case nil, *tcpip.ErrBadLocalAddress:
+ default:
panic(fmt.Sprintf("unexpected error from removePermanentEndpointLocked(%s): %s", ep.addr, err))
}
}
diff --git a/pkg/tcpip/stack/conntrack.go b/pkg/tcpip/stack/conntrack.go
index 5e649cca6..54617f2e6 100644
--- a/pkg/tcpip/stack/conntrack.go
+++ b/pkg/tcpip/stack/conntrack.go
@@ -198,15 +198,15 @@ type bucket struct {
// TCP header.
//
// Preconditions: pkt.NetworkHeader() is valid.
-func packetToTupleID(pkt *PacketBuffer) (tupleID, *tcpip.Error) {
+func packetToTupleID(pkt *PacketBuffer) (tupleID, tcpip.Error) {
netHeader := pkt.Network()
if netHeader.TransportProtocol() != header.TCPProtocolNumber {
- return tupleID{}, tcpip.ErrUnknownProtocol
+ return tupleID{}, &tcpip.ErrUnknownProtocol{}
}
tcpHeader := header.TCP(pkt.TransportHeader().View())
if len(tcpHeader) < header.TCPMinimumSize {
- return tupleID{}, tcpip.ErrUnknownProtocol
+ return tupleID{}, &tcpip.ErrUnknownProtocol{}
}
return tupleID{
@@ -617,7 +617,7 @@ func (ct *ConnTrack) reapTupleLocked(tuple *tuple, bucket int, now time.Time) bo
return true
}
-func (ct *ConnTrack) originalDst(epID TransportEndpointID, netProto tcpip.NetworkProtocolNumber) (tcpip.Address, uint16, *tcpip.Error) {
+func (ct *ConnTrack) originalDst(epID TransportEndpointID, netProto tcpip.NetworkProtocolNumber) (tcpip.Address, uint16, tcpip.Error) {
// Lookup the connection. The reply's original destination
// describes the original address.
tid := tupleID{
@@ -631,10 +631,10 @@ func (ct *ConnTrack) originalDst(epID TransportEndpointID, netProto tcpip.Networ
conn, _ := ct.connForTID(tid)
if conn == nil {
// Not a tracked connection.
- return "", 0, tcpip.ErrNotConnected
+ return "", 0, &tcpip.ErrNotConnected{}
} else if conn.manip == manipNone {
// Unmanipulated connection.
- return "", 0, tcpip.ErrInvalidOptionValue
+ return "", 0, &tcpip.ErrInvalidOptionValue{}
}
return conn.original.dstAddr, conn.original.dstPort, nil
diff --git a/pkg/tcpip/stack/forwarding_test.go b/pkg/tcpip/stack/forwarding_test.go
index 4908848e9..63a42a2ea 100644
--- a/pkg/tcpip/stack/forwarding_test.go
+++ b/pkg/tcpip/stack/forwarding_test.go
@@ -41,6 +41,8 @@ const (
protocolNumberOffset = 2
)
+var _ NetworkEndpoint = (*fwdTestNetworkEndpoint)(nil)
+
// fwdTestNetworkEndpoint is a network-layer protocol endpoint.
// Headers of this protocol are fwdTestNetHeaderLen bytes, but we currently only
// use the first three: destination address, source address, and transport
@@ -53,9 +55,7 @@ type fwdTestNetworkEndpoint struct {
dispatcher TransportDispatcher
}
-var _ NetworkEndpoint = (*fwdTestNetworkEndpoint)(nil)
-
-func (*fwdTestNetworkEndpoint) Enable() *tcpip.Error {
+func (*fwdTestNetworkEndpoint) Enable() tcpip.Error {
return nil
}
@@ -104,7 +104,7 @@ func (f *fwdTestNetworkEndpoint) MaxHeaderLength() uint16 {
return f.nic.MaxHeaderLength() + fwdTestNetHeaderLen
}
-func (f *fwdTestNetworkEndpoint) PseudoHeaderChecksum(protocol tcpip.TransportProtocolNumber, dstAddr tcpip.Address) uint16 {
+func (*fwdTestNetworkEndpoint) PseudoHeaderChecksum(protocol tcpip.TransportProtocolNumber, dstAddr tcpip.Address) uint16 {
return 0
}
@@ -112,7 +112,7 @@ func (f *fwdTestNetworkEndpoint) NetworkProtocolNumber() tcpip.NetworkProtocolNu
return f.proto.Number()
}
-func (f *fwdTestNetworkEndpoint) WritePacket(r *Route, gso *GSO, params NetworkHeaderParams, pkt *PacketBuffer) *tcpip.Error {
+func (f *fwdTestNetworkEndpoint) WritePacket(r *Route, gso *GSO, params NetworkHeaderParams, pkt *PacketBuffer) tcpip.Error {
// Add the protocol's header to the packet and send it to the link
// endpoint.
b := pkt.NetworkHeader().Push(fwdTestNetHeaderLen)
@@ -124,14 +124,14 @@ func (f *fwdTestNetworkEndpoint) WritePacket(r *Route, gso *GSO, params NetworkH
}
// WritePackets implements LinkEndpoint.WritePackets.
-func (f *fwdTestNetworkEndpoint) WritePackets(r *Route, gso *GSO, pkts PacketBufferList, params NetworkHeaderParams) (int, *tcpip.Error) {
+func (*fwdTestNetworkEndpoint) WritePackets(r *Route, gso *GSO, pkts PacketBufferList, params NetworkHeaderParams) (int, tcpip.Error) {
panic("not implemented")
}
-func (f *fwdTestNetworkEndpoint) WriteHeaderIncludedPacket(r *Route, pkt *PacketBuffer) *tcpip.Error {
+func (f *fwdTestNetworkEndpoint) WriteHeaderIncludedPacket(r *Route, pkt *PacketBuffer) tcpip.Error {
// The network header should not already be populated.
if _, ok := pkt.NetworkHeader().Consume(fwdTestNetHeaderLen); !ok {
- return tcpip.ErrMalformedHeader
+ return &tcpip.ErrMalformedHeader{}
}
return f.nic.WritePacket(r, nil /* gso */, fwdTestNetNumber, pkt)
@@ -141,6 +141,21 @@ func (f *fwdTestNetworkEndpoint) Close() {
f.AddressableEndpointState.Cleanup()
}
+// Stats implements stack.NetworkEndpoint.
+func (*fwdTestNetworkEndpoint) Stats() NetworkEndpointStats {
+ return &fwdTestNetworkEndpointStats{}
+}
+
+var _ NetworkEndpointStats = (*fwdTestNetworkEndpointStats)(nil)
+
+type fwdTestNetworkEndpointStats struct{}
+
+// IsNetworkEndpointStats implements stack.NetworkEndpointStats.
+func (*fwdTestNetworkEndpointStats) IsNetworkEndpointStats() {}
+
+var _ LinkAddressResolver = (*fwdTestNetworkProtocol)(nil)
+var _ NetworkProtocol = (*fwdTestNetworkProtocol)(nil)
+
// fwdTestNetworkProtocol is a network-layer protocol that implements Address
// resolution.
type fwdTestNetworkProtocol struct {
@@ -158,18 +173,15 @@ type fwdTestNetworkProtocol struct {
}
}
-var _ NetworkProtocol = (*fwdTestNetworkProtocol)(nil)
-var _ LinkAddressResolver = (*fwdTestNetworkProtocol)(nil)
-
-func (f *fwdTestNetworkProtocol) Number() tcpip.NetworkProtocolNumber {
+func (*fwdTestNetworkProtocol) Number() tcpip.NetworkProtocolNumber {
return fwdTestNetNumber
}
-func (f *fwdTestNetworkProtocol) MinimumPacketSize() int {
+func (*fwdTestNetworkProtocol) MinimumPacketSize() int {
return fwdTestNetHeaderLen
}
-func (f *fwdTestNetworkProtocol) DefaultPrefixLen() int {
+func (*fwdTestNetworkProtocol) DefaultPrefixLen() int {
return fwdTestNetDefaultPrefixLen
}
@@ -195,19 +207,19 @@ func (f *fwdTestNetworkProtocol) NewEndpoint(nic NetworkInterface, _ LinkAddress
return e
}
-func (*fwdTestNetworkProtocol) SetOption(tcpip.SettableNetworkProtocolOption) *tcpip.Error {
- return tcpip.ErrUnknownProtocolOption
+func (*fwdTestNetworkProtocol) SetOption(tcpip.SettableNetworkProtocolOption) tcpip.Error {
+ return &tcpip.ErrUnknownProtocolOption{}
}
-func (*fwdTestNetworkProtocol) Option(tcpip.GettableNetworkProtocolOption) *tcpip.Error {
- return tcpip.ErrUnknownProtocolOption
+func (*fwdTestNetworkProtocol) Option(tcpip.GettableNetworkProtocolOption) tcpip.Error {
+ return &tcpip.ErrUnknownProtocolOption{}
}
func (*fwdTestNetworkProtocol) Close() {}
func (*fwdTestNetworkProtocol) Wait() {}
-func (f *fwdTestNetworkProtocol) LinkAddressRequest(addr, _ tcpip.Address, remoteLinkAddr tcpip.LinkAddress, _ NetworkInterface) *tcpip.Error {
+func (f *fwdTestNetworkProtocol) LinkAddressRequest(addr, _ tcpip.Address, remoteLinkAddr tcpip.LinkAddress, _ NetworkInterface) tcpip.Error {
if f.onLinkAddressResolved != nil {
time.AfterFunc(f.addrResolveDelay, func() {
f.onLinkAddressResolved(f.addrCache, f.neigh, addr, remoteLinkAddr)
@@ -307,7 +319,7 @@ func (e *fwdTestLinkEndpoint) LinkAddress() tcpip.LinkAddress {
return e.linkAddr
}
-func (e fwdTestLinkEndpoint) WritePacket(r RouteInfo, gso *GSO, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) *tcpip.Error {
+func (e fwdTestLinkEndpoint) WritePacket(r RouteInfo, gso *GSO, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) tcpip.Error {
p := fwdTestPacketInfo{
RemoteLinkAddress: r.RemoteLinkAddress,
LocalLinkAddress: r.LocalLinkAddress,
@@ -323,7 +335,7 @@ func (e fwdTestLinkEndpoint) WritePacket(r RouteInfo, gso *GSO, protocol tcpip.N
}
// WritePackets stores outbound packets into the channel.
-func (e *fwdTestLinkEndpoint) WritePackets(r RouteInfo, gso *GSO, pkts PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) {
+func (e *fwdTestLinkEndpoint) WritePackets(r RouteInfo, gso *GSO, pkts PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, tcpip.Error) {
n := 0
for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() {
e.WritePacket(r, gso, protocol, pkt)
@@ -356,10 +368,6 @@ func fwdTestNetFactory(t *testing.T, proto *fwdTestNetworkProtocol, useNeighborC
UseNeighborCache: useNeighborCache,
})
- if !useNeighborCache {
- proto.addrCache = s.linkAddrCache
- }
-
// Enable forwarding.
s.SetForwarding(proto.Number(), true)
@@ -389,13 +397,15 @@ func fwdTestNetFactory(t *testing.T, proto *fwdTestNetworkProtocol, useNeighborC
t.Fatal("AddAddress #2 failed:", err)
}
+ nic, ok := s.nics[2]
+ if !ok {
+ t.Fatal("NIC 2 does not exist")
+ }
if useNeighborCache {
// Control the neighbor cache for NIC 2.
- nic, ok := s.nics[2]
- if !ok {
- t.Fatal("failed to get the neighbor cache for NIC 2")
- }
proto.neigh = nic.neigh
+ } else {
+ proto.addrCache = nic.linkAddrCache
}
// Route all packets to NIC 2.
@@ -481,7 +491,7 @@ func TestForwardingWithFakeResolver(t *testing.T) {
addrResolveDelay: 500 * time.Millisecond,
onLinkAddressResolved: func(cache *linkAddrCache, neigh *neighborCache, addr tcpip.Address, _ tcpip.LinkAddress) {
// Any address will be resolved to the link address "c".
- cache.add(tcpip.FullAddress{NIC: 2, Addr: addr}, "c")
+ cache.AddLinkAddress(addr, "c")
},
},
},
@@ -607,7 +617,7 @@ func TestForwardingWithFakeResolverPartialTimeout(t *testing.T) {
// Only packets to address 3 will be resolved to the
// link address "c".
if addr == "\x03" {
- cache.add(tcpip.FullAddress{NIC: 2, Addr: addr}, "c")
+ cache.AddLinkAddress(addr, "c")
}
},
},
@@ -692,7 +702,7 @@ func TestForwardingWithFakeResolverTwoPackets(t *testing.T) {
addrResolveDelay: 500 * time.Millisecond,
onLinkAddressResolved: func(cache *linkAddrCache, neigh *neighborCache, addr tcpip.Address, _ tcpip.LinkAddress) {
// Any packets will be resolved to the link address "c".
- cache.add(tcpip.FullAddress{NIC: 2, Addr: addr}, "c")
+ cache.AddLinkAddress(addr, "c")
},
},
},
@@ -768,7 +778,7 @@ func TestForwardingWithFakeResolverManyPackets(t *testing.T) {
addrResolveDelay: 500 * time.Millisecond,
onLinkAddressResolved: func(cache *linkAddrCache, neigh *neighborCache, addr tcpip.Address, _ tcpip.LinkAddress) {
// Any packets will be resolved to the link address "c".
- cache.add(tcpip.FullAddress{NIC: 2, Addr: addr}, "c")
+ cache.AddLinkAddress(addr, "c")
},
},
},
@@ -858,7 +868,7 @@ func TestForwardingWithFakeResolverManyResolutions(t *testing.T) {
addrResolveDelay: 500 * time.Millisecond,
onLinkAddressResolved: func(cache *linkAddrCache, neigh *neighborCache, addr tcpip.Address, _ tcpip.LinkAddress) {
// Any packets will be resolved to the link address "c".
- cache.add(tcpip.FullAddress{NIC: 2, Addr: addr}, "c")
+ cache.AddLinkAddress(addr, "c")
},
},
},
diff --git a/pkg/tcpip/stack/iptables.go b/pkg/tcpip/stack/iptables.go
index 09c7811fa..63832c200 100644
--- a/pkg/tcpip/stack/iptables.go
+++ b/pkg/tcpip/stack/iptables.go
@@ -229,7 +229,7 @@ func (it *IPTables) GetTable(id TableID, ipv6 bool) Table {
// ReplaceTable replaces or inserts table by name. It panics when an invalid id
// is provided.
-func (it *IPTables) ReplaceTable(id TableID, table Table, ipv6 bool) *tcpip.Error {
+func (it *IPTables) ReplaceTable(id TableID, table Table, ipv6 bool) tcpip.Error {
it.mu.Lock()
defer it.mu.Unlock()
// If iptables is being enabled, initialize the conntrack table and
@@ -267,11 +267,11 @@ const (
// dropped.
//
// TODO(gvisor.dev/issue/170): PacketBuffer should hold the GSO and route, from
-// which address and nicName can be gathered. Currently, address is only
-// needed for prerouting and nicName is only needed for output.
+// which address can be gathered. Currently, address is only needed for
+// prerouting.
//
// Precondition: pkt.NetworkHeader is set.
-func (it *IPTables) Check(hook Hook, pkt *PacketBuffer, gso *GSO, r *Route, preroutingAddr tcpip.Address, nicName string) bool {
+func (it *IPTables) Check(hook Hook, pkt *PacketBuffer, gso *GSO, r *Route, preroutingAddr tcpip.Address, inNicName, outNicName string) bool {
if pkt.NetworkProtocolNumber != header.IPv4ProtocolNumber && pkt.NetworkProtocolNumber != header.IPv6ProtocolNumber {
return true
}
@@ -302,7 +302,7 @@ func (it *IPTables) Check(hook Hook, pkt *PacketBuffer, gso *GSO, r *Route, prer
table = it.v4Tables[tableID]
}
ruleIdx := table.BuiltinChains[hook]
- switch verdict := it.checkChain(hook, pkt, table, ruleIdx, gso, r, preroutingAddr, nicName); verdict {
+ switch verdict := it.checkChain(hook, pkt, table, ruleIdx, gso, r, preroutingAddr, inNicName, outNicName); verdict {
// If the table returns Accept, move on to the next table.
case chainAccept:
continue
@@ -385,10 +385,10 @@ func (it *IPTables) startReaper(interval time.Duration) {
//
// NOTE: unlike the Check API the returned map contains packets that should be
// dropped.
-func (it *IPTables) CheckPackets(hook Hook, pkts PacketBufferList, gso *GSO, r *Route, nicName string) (drop map[*PacketBuffer]struct{}, natPkts map[*PacketBuffer]struct{}) {
+func (it *IPTables) CheckPackets(hook Hook, pkts PacketBufferList, gso *GSO, r *Route, inNicName, outNicName string) (drop map[*PacketBuffer]struct{}, natPkts map[*PacketBuffer]struct{}) {
for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() {
if !pkt.NatDone {
- if ok := it.Check(hook, pkt, gso, r, "", nicName); !ok {
+ if ok := it.Check(hook, pkt, gso, r, "", inNicName, outNicName); !ok {
if drop == nil {
drop = make(map[*PacketBuffer]struct{})
}
@@ -408,11 +408,11 @@ func (it *IPTables) CheckPackets(hook Hook, pkts PacketBufferList, gso *GSO, r *
// Preconditions:
// * pkt is a IPv4 packet of at least length header.IPv4MinimumSize.
// * pkt.NetworkHeader is not nil.
-func (it *IPTables) checkChain(hook Hook, pkt *PacketBuffer, table Table, ruleIdx int, gso *GSO, r *Route, preroutingAddr tcpip.Address, nicName string) chainVerdict {
+func (it *IPTables) checkChain(hook Hook, pkt *PacketBuffer, table Table, ruleIdx int, gso *GSO, r *Route, preroutingAddr tcpip.Address, inNicName, outNicName string) chainVerdict {
// Start from ruleIdx and walk the list of rules until a rule gives us
// a verdict.
for ruleIdx < len(table.Rules) {
- switch verdict, jumpTo := it.checkRule(hook, pkt, table, ruleIdx, gso, r, preroutingAddr, nicName); verdict {
+ switch verdict, jumpTo := it.checkRule(hook, pkt, table, ruleIdx, gso, r, preroutingAddr, inNicName, outNicName); verdict {
case RuleAccept:
return chainAccept
@@ -429,7 +429,7 @@ func (it *IPTables) checkChain(hook Hook, pkt *PacketBuffer, table Table, ruleId
ruleIdx++
continue
}
- switch verdict := it.checkChain(hook, pkt, table, jumpTo, gso, r, preroutingAddr, nicName); verdict {
+ switch verdict := it.checkChain(hook, pkt, table, jumpTo, gso, r, preroutingAddr, inNicName, outNicName); verdict {
case chainAccept:
return chainAccept
case chainDrop:
@@ -455,11 +455,11 @@ func (it *IPTables) checkChain(hook Hook, pkt *PacketBuffer, table Table, ruleId
// Preconditions:
// * pkt is a IPv4 packet of at least length header.IPv4MinimumSize.
// * pkt.NetworkHeader is not nil.
-func (it *IPTables) checkRule(hook Hook, pkt *PacketBuffer, table Table, ruleIdx int, gso *GSO, r *Route, preroutingAddr tcpip.Address, nicName string) (RuleVerdict, int) {
+func (it *IPTables) checkRule(hook Hook, pkt *PacketBuffer, table Table, ruleIdx int, gso *GSO, r *Route, preroutingAddr tcpip.Address, inNicName, outNicName string) (RuleVerdict, int) {
rule := table.Rules[ruleIdx]
// Check whether the packet matches the IP header filter.
- if !rule.Filter.match(pkt, hook, nicName) {
+ if !rule.Filter.match(pkt, hook, inNicName, outNicName) {
// Continue on to the next rule.
return RuleJump, ruleIdx + 1
}
@@ -467,7 +467,7 @@ func (it *IPTables) checkRule(hook Hook, pkt *PacketBuffer, table Table, ruleIdx
// Go through each rule matcher. If they all match, run
// the rule target.
for _, matcher := range rule.Matchers {
- matches, hotdrop := matcher.Match(hook, pkt, "")
+ matches, hotdrop := matcher.Match(hook, pkt, inNicName, outNicName)
if hotdrop {
return RuleDrop, 0
}
@@ -483,11 +483,11 @@ func (it *IPTables) checkRule(hook Hook, pkt *PacketBuffer, table Table, ruleIdx
// OriginalDst returns the original destination of redirected connections. It
// returns an error if the connection doesn't exist or isn't redirected.
-func (it *IPTables) OriginalDst(epID TransportEndpointID, netProto tcpip.NetworkProtocolNumber) (tcpip.Address, uint16, *tcpip.Error) {
+func (it *IPTables) OriginalDst(epID TransportEndpointID, netProto tcpip.NetworkProtocolNumber) (tcpip.Address, uint16, tcpip.Error) {
it.mu.RLock()
defer it.mu.RUnlock()
if !it.modified {
- return "", 0, tcpip.ErrNotConnected
+ return "", 0, &tcpip.ErrNotConnected{}
}
return it.connections.originalDst(epID, netProto)
}
diff --git a/pkg/tcpip/stack/iptables_types.go b/pkg/tcpip/stack/iptables_types.go
index 56a3e7861..fd9d61e39 100644
--- a/pkg/tcpip/stack/iptables_types.go
+++ b/pkg/tcpip/stack/iptables_types.go
@@ -210,8 +210,19 @@ type IPHeaderFilter struct {
// filter will match packets that fail the source comparison.
SrcInvert bool
- // OutputInterface matches the name of the outgoing interface for the
- // packet.
+ // InputInterface matches the name of the incoming interface for the packet.
+ InputInterface string
+
+ // InputInterfaceMask masks the characters of the interface name when
+ // comparing with InputInterface.
+ InputInterfaceMask string
+
+ // InputInterfaceInvert inverts the meaning of incoming interface check,
+ // i.e. when true the filter will match packets that fail the incoming
+ // interface comparison.
+ InputInterfaceInvert bool
+
+ // OutputInterface matches the name of the outgoing interface for the packet.
OutputInterface string
// OutputInterfaceMask masks the characters of the interface name when
@@ -228,7 +239,7 @@ type IPHeaderFilter struct {
//
// Preconditions: pkt.NetworkHeader is set and is at least of the minimal IPv4
// or IPv6 header length.
-func (fl IPHeaderFilter) match(pkt *PacketBuffer, hook Hook, nicName string) bool {
+func (fl IPHeaderFilter) match(pkt *PacketBuffer, hook Hook, inNicName, outNicName string) bool {
// Extract header fields.
var (
// TODO(gvisor.dev/issue/170): Support other filter fields.
@@ -264,26 +275,35 @@ func (fl IPHeaderFilter) match(pkt *PacketBuffer, hook Hook, nicName string) boo
return false
}
- // Check the output interface.
- // TODO(gvisor.dev/issue/170): Add the check for FORWARD and POSTROUTING
- // hooks after supported.
- if hook == Output {
- n := len(fl.OutputInterface)
- if n == 0 {
- return true
- }
-
- // If the interface name ends with '+', any interface which
- // begins with the name should be matched.
- ifName := fl.OutputInterface
- matches := nicName == ifName
- if strings.HasSuffix(ifName, "+") {
- matches = strings.HasPrefix(nicName, ifName[:n-1])
- }
- return fl.OutputInterfaceInvert != matches
+ switch hook {
+ case Prerouting, Input:
+ return matchIfName(inNicName, fl.InputInterface, fl.InputInterfaceInvert)
+ case Output:
+ return matchIfName(outNicName, fl.OutputInterface, fl.OutputInterfaceInvert)
+ case Forward, Postrouting:
+ // TODO(gvisor.dev/issue/170): Add the check for FORWARD and POSTROUTING
+ // hooks after supported.
+ return true
+ default:
+ panic(fmt.Sprintf("unknown hook: %d", hook))
}
+}
- return true
+func matchIfName(nicName string, ifName string, invert bool) bool {
+ n := len(ifName)
+ if n == 0 {
+ // If the interface name is omitted in the filter, any interface will match.
+ return true
+ }
+ // If the interface name ends with '+', any interface which begins with the
+ // name should be matched.
+ var matches bool
+ if strings.HasSuffix(ifName, "+") {
+ matches = strings.HasPrefix(nicName, ifName[:n-1])
+ } else {
+ matches = nicName == ifName
+ }
+ return matches != invert
}
// NetworkProtocol returns the protocol (IPv4 or IPv6) on to which the header
@@ -320,7 +340,7 @@ type Matcher interface {
// used for suspicious packets.
//
// Precondition: packet.NetworkHeader is set.
- Match(hook Hook, packet *PacketBuffer, interfaceName string) (matches bool, hotdrop bool)
+ Match(hook Hook, packet *PacketBuffer, inputInterfaceName, outputInterfaceName string) (matches bool, hotdrop bool)
}
// A Target is the interface for taking an action for a packet.
diff --git a/pkg/tcpip/stack/linkaddrcache.go b/pkg/tcpip/stack/linkaddrcache.go
index b600a1cab..930b8f795 100644
--- a/pkg/tcpip/stack/linkaddrcache.go
+++ b/pkg/tcpip/stack/linkaddrcache.go
@@ -24,12 +24,16 @@ import (
const linkAddrCacheSize = 512 // max cache entries
+var _ LinkAddressCache = (*linkAddrCache)(nil)
+
// linkAddrCache is a fixed-sized cache mapping IP addresses to link addresses.
//
// The entries are stored in a ring buffer, oldest entry replaced first.
//
// This struct is safe for concurrent use.
type linkAddrCache struct {
+ nic *NIC
+
// ageLimit is how long a cache entry is valid for.
ageLimit time.Duration
@@ -41,9 +45,9 @@ type linkAddrCache struct {
// resolved before failing.
resolutionAttempts int
- cache struct {
+ mu struct {
sync.Mutex
- table map[tcpip.FullAddress]*linkAddrEntry
+ table map[tcpip.Address]*linkAddrEntry
lru linkAddrEntryList
}
}
@@ -77,31 +81,42 @@ type linkAddrEntry struct {
// linkAddrEntryEntry access is synchronized by the linkAddrCache lock.
linkAddrEntryEntry
- // TODO(gvisor.dev/issue/5150): move these fields under mu.
- // mu protects the fields below.
- mu sync.RWMutex
+ cache *linkAddrCache
- addr tcpip.FullAddress
- linkAddr tcpip.LinkAddress
- expiration time.Time
- s entryState
+ mu struct {
+ sync.RWMutex
- // done is closed when address resolution is complete. It is nil iff s is
- // incomplete and resolution is not yet in progress.
- done chan struct{}
+ addr tcpip.Address
+ linkAddr tcpip.LinkAddress
+ expiration time.Time
+ s entryState
- // onResolve is called with the result of address resolution.
- onResolve []func(tcpip.LinkAddress, bool)
+ // done is closed when address resolution is complete. It is nil iff s is
+ // incomplete and resolution is not yet in progress.
+ done chan struct{}
+
+ // onResolve is called with the result of address resolution.
+ onResolve []func(LinkResolutionResult)
+ }
}
func (e *linkAddrEntry) notifyCompletionLocked(linkAddr tcpip.LinkAddress) {
- for _, callback := range e.onResolve {
- callback(linkAddr, len(linkAddr) != 0)
+ res := LinkResolutionResult{LinkAddress: linkAddr, Success: len(linkAddr) != 0}
+ for _, callback := range e.mu.onResolve {
+ callback(res)
}
- e.onResolve = nil
- if ch := e.done; ch != nil {
+ e.mu.onResolve = nil
+ if ch := e.mu.done; ch != nil {
close(ch)
- e.done = nil
+ e.mu.done = nil
+ // Dequeue the pending packets in a new goroutine to not hold up the current
+ // goroutine as writing packets may be a costly operation.
+ //
+ // At the time of writing, when writing packets, a neighbor's link address
+ // is resolved (which ends up obtaining the entry's lock) while holding the
+ // link resolution queue's lock. Dequeuing packets in a new goroutine avoids
+ // a lock ordering violation.
+ go e.cache.nic.linkResQueue.dequeue(ch, linkAddr, len(linkAddr) != 0)
}
}
@@ -114,30 +129,30 @@ func (e *linkAddrEntry) notifyCompletionLocked(linkAddr tcpip.LinkAddress) {
//
// Precondition: e.mu must be locked
func (e *linkAddrEntry) changeStateLocked(ns entryState, expiration time.Time) {
- if e.s == incomplete && ns == ready {
- e.notifyCompletionLocked(e.linkAddr)
+ if e.mu.s == incomplete && ns == ready {
+ e.notifyCompletionLocked(e.mu.linkAddr)
}
- if expiration.IsZero() || expiration.After(e.expiration) {
- e.expiration = expiration
+ if expiration.IsZero() || expiration.After(e.mu.expiration) {
+ e.mu.expiration = expiration
}
- e.s = ns
+ e.mu.s = ns
}
// add adds a k -> v mapping to the cache.
-func (c *linkAddrCache) add(k tcpip.FullAddress, v tcpip.LinkAddress) {
+func (c *linkAddrCache) AddLinkAddress(k tcpip.Address, v tcpip.LinkAddress) {
// Calculate expiration time before acquiring the lock, since expiration is
// relative to the time when information was learned, rather than when it
// happened to be inserted into the cache.
expiration := time.Now().Add(c.ageLimit)
- c.cache.Lock()
+ c.mu.Lock()
entry := c.getOrCreateEntryLocked(k)
- c.cache.Unlock()
-
entry.mu.Lock()
defer entry.mu.Unlock()
- entry.linkAddr = v
+ c.mu.Unlock()
+
+ entry.mu.linkAddr = v
entry.changeStateLocked(ready, expiration)
}
@@ -150,19 +165,19 @@ func (c *linkAddrCache) add(k tcpip.FullAddress, v tcpip.LinkAddress) {
// reset to state incomplete, and returned. If no matching entry exists and the
// cache is not full, a new entry with state incomplete is allocated and
// returned.
-func (c *linkAddrCache) getOrCreateEntryLocked(k tcpip.FullAddress) *linkAddrEntry {
- if entry, ok := c.cache.table[k]; ok {
- c.cache.lru.Remove(entry)
- c.cache.lru.PushFront(entry)
+func (c *linkAddrCache) getOrCreateEntryLocked(k tcpip.Address) *linkAddrEntry {
+ if entry, ok := c.mu.table[k]; ok {
+ c.mu.lru.Remove(entry)
+ c.mu.lru.PushFront(entry)
return entry
}
var entry *linkAddrEntry
- if len(c.cache.table) == linkAddrCacheSize {
- entry = c.cache.lru.Back()
+ if len(c.mu.table) == linkAddrCacheSize {
+ entry = c.mu.lru.Back()
entry.mu.Lock()
- delete(c.cache.table, entry.addr)
- c.cache.lru.Remove(entry)
+ delete(c.mu.table, entry.mu.addr)
+ c.mu.lru.Remove(entry)
// Wake waiters and mark the soon-to-be-reused entry as expired.
entry.notifyCompletionLocked("" /* linkAddr */)
@@ -172,53 +187,56 @@ func (c *linkAddrCache) getOrCreateEntryLocked(k tcpip.FullAddress) *linkAddrEnt
}
*entry = linkAddrEntry{
- addr: k,
- s: incomplete,
+ cache: c,
}
- c.cache.table[k] = entry
- c.cache.lru.PushFront(entry)
+ entry.mu.Lock()
+ entry.mu.addr = k
+ entry.mu.s = incomplete
+ entry.mu.Unlock()
+ c.mu.table[k] = entry
+ c.mu.lru.PushFront(entry)
return entry
}
// get reports any known link address for k.
-func (c *linkAddrCache) get(k tcpip.FullAddress, linkRes LinkAddressResolver, localAddr tcpip.Address, nic NetworkInterface, onResolve func(tcpip.LinkAddress, bool)) (tcpip.LinkAddress, <-chan struct{}, *tcpip.Error) {
- c.cache.Lock()
- defer c.cache.Unlock()
+func (c *linkAddrCache) get(k tcpip.Address, linkRes LinkAddressResolver, localAddr tcpip.Address, nic NetworkInterface, onResolve func(LinkResolutionResult)) (tcpip.LinkAddress, <-chan struct{}, tcpip.Error) {
+ c.mu.Lock()
entry := c.getOrCreateEntryLocked(k)
entry.mu.Lock()
defer entry.mu.Unlock()
+ c.mu.Unlock()
- switch s := entry.s; s {
+ switch s := entry.mu.s; s {
case ready:
- if !time.Now().After(entry.expiration) {
+ if !time.Now().After(entry.mu.expiration) {
// Not expired.
if onResolve != nil {
- onResolve(entry.linkAddr, true)
+ onResolve(LinkResolutionResult{LinkAddress: entry.mu.linkAddr, Success: true})
}
- return entry.linkAddr, nil, nil
+ return entry.mu.linkAddr, nil, nil
}
entry.changeStateLocked(incomplete, time.Time{})
fallthrough
case incomplete:
if onResolve != nil {
- entry.onResolve = append(entry.onResolve, onResolve)
+ entry.mu.onResolve = append(entry.mu.onResolve, onResolve)
}
- if entry.done == nil {
- entry.done = make(chan struct{})
- go c.startAddressResolution(k, linkRes, localAddr, nic, entry.done) // S/R-SAFE: link non-savable; wakers dropped synchronously.
+ if entry.mu.done == nil {
+ entry.mu.done = make(chan struct{})
+ go c.startAddressResolution(k, linkRes, localAddr, nic, entry.mu.done) // S/R-SAFE: link non-savable; wakers dropped synchronously.
}
- return entry.linkAddr, entry.done, tcpip.ErrWouldBlock
+ return entry.mu.linkAddr, entry.mu.done, &tcpip.ErrWouldBlock{}
default:
panic(fmt.Sprintf("invalid cache entry state: %s", s))
}
}
-func (c *linkAddrCache) startAddressResolution(k tcpip.FullAddress, linkRes LinkAddressResolver, localAddr tcpip.Address, nic NetworkInterface, done <-chan struct{}) {
+func (c *linkAddrCache) startAddressResolution(k tcpip.Address, linkRes LinkAddressResolver, localAddr tcpip.Address, nic NetworkInterface, done <-chan struct{}) {
for i := 0; ; i++ {
// Send link request, then wait for the timeout limit and check
// whether the request succeeded.
- linkRes.LinkAddressRequest(k.Addr, localAddr, "" /* linkAddr */, nic)
+ linkRes.LinkAddressRequest(k, localAddr, "" /* linkAddr */, nic)
select {
case now := <-time.After(c.resolutionTimeout):
@@ -234,10 +252,10 @@ func (c *linkAddrCache) startAddressResolution(k tcpip.FullAddress, linkRes Link
// checkLinkRequest checks whether previous attempt to resolve address has
// succeeded and mark the entry accordingly. Returns true if request can stop,
// false if another request should be sent.
-func (c *linkAddrCache) checkLinkRequest(now time.Time, k tcpip.FullAddress, attempt int) bool {
- c.cache.Lock()
- defer c.cache.Unlock()
- entry, ok := c.cache.table[k]
+func (c *linkAddrCache) checkLinkRequest(now time.Time, k tcpip.Address, attempt int) bool {
+ c.mu.Lock()
+ defer c.mu.Unlock()
+ entry, ok := c.mu.table[k]
if !ok {
// Entry was evicted from the cache.
return true
@@ -245,7 +263,7 @@ func (c *linkAddrCache) checkLinkRequest(now time.Time, k tcpip.FullAddress, att
entry.mu.Lock()
defer entry.mu.Unlock()
- switch s := entry.s; s {
+ switch s := entry.mu.s; s {
case ready:
// Entry was made ready by resolver.
case incomplete:
@@ -255,19 +273,20 @@ func (c *linkAddrCache) checkLinkRequest(now time.Time, k tcpip.FullAddress, att
}
// Max number of retries reached, delete entry.
entry.notifyCompletionLocked("" /* linkAddr */)
- delete(c.cache.table, k)
+ delete(c.mu.table, k)
default:
panic(fmt.Sprintf("invalid cache entry state: %s", s))
}
return true
}
-func newLinkAddrCache(ageLimit, resolutionTimeout time.Duration, resolutionAttempts int) *linkAddrCache {
+func newLinkAddrCache(nic *NIC, ageLimit, resolutionTimeout time.Duration, resolutionAttempts int) *linkAddrCache {
c := &linkAddrCache{
+ nic: nic,
ageLimit: ageLimit,
resolutionTimeout: resolutionTimeout,
resolutionAttempts: resolutionAttempts,
}
- c.cache.table = make(map[tcpip.FullAddress]*linkAddrEntry, linkAddrCacheSize)
+ c.mu.table = make(map[tcpip.Address]*linkAddrEntry, linkAddrCacheSize)
return c
}
diff --git a/pkg/tcpip/stack/linkaddrcache_test.go b/pkg/tcpip/stack/linkaddrcache_test.go
index d7ac6cf5f..466a5e8d9 100644
--- a/pkg/tcpip/stack/linkaddrcache_test.go
+++ b/pkg/tcpip/stack/linkaddrcache_test.go
@@ -26,7 +26,7 @@ import (
)
type testaddr struct {
- addr tcpip.FullAddress
+ addr tcpip.Address
linkAddr tcpip.LinkAddress
}
@@ -35,7 +35,7 @@ var testAddrs = func() []testaddr {
for i := 0; i < 4*linkAddrCacheSize; i++ {
addr := fmt.Sprintf("Addr%06d", i)
addrs = append(addrs, testaddr{
- addr: tcpip.FullAddress{NIC: 1, Addr: tcpip.Address(addr)},
+ addr: tcpip.Address(addr),
linkAddr: tcpip.LinkAddress("Link" + addr),
})
}
@@ -48,7 +48,7 @@ type testLinkAddressResolver struct {
onLinkAddressRequest func()
}
-func (r *testLinkAddressResolver) LinkAddressRequest(targetAddr, _ tcpip.Address, _ tcpip.LinkAddress, _ NetworkInterface) *tcpip.Error {
+func (r *testLinkAddressResolver) LinkAddressRequest(targetAddr, _ tcpip.Address, _ tcpip.LinkAddress, _ NetworkInterface) tcpip.Error {
// TODO(gvisor.dev/issue/5141): Use a fake clock.
time.AfterFunc(r.delay, func() { r.fakeRequest(targetAddr) })
if f := r.onLinkAddressRequest; f != nil {
@@ -59,8 +59,8 @@ func (r *testLinkAddressResolver) LinkAddressRequest(targetAddr, _ tcpip.Address
func (r *testLinkAddressResolver) fakeRequest(addr tcpip.Address) {
for _, ta := range testAddrs {
- if ta.addr.Addr == addr {
- r.cache.add(ta.addr, ta.linkAddr)
+ if ta.addr == addr {
+ r.cache.AddLinkAddress(ta.addr, ta.linkAddr)
break
}
}
@@ -77,13 +77,13 @@ func (*testLinkAddressResolver) LinkAddressProtocol() tcpip.NetworkProtocolNumbe
return 1
}
-func getBlocking(c *linkAddrCache, addr tcpip.FullAddress, linkRes LinkAddressResolver) (tcpip.LinkAddress, *tcpip.Error) {
+func getBlocking(c *linkAddrCache, addr tcpip.Address, linkRes LinkAddressResolver) (tcpip.LinkAddress, tcpip.Error) {
var attemptedResolution bool
for {
got, ch, err := c.get(addr, linkRes, "", nil, nil)
- if err == tcpip.ErrWouldBlock {
+ if _, ok := err.(*tcpip.ErrWouldBlock); ok {
if attemptedResolution {
- return got, tcpip.ErrTimeout
+ return got, &tcpip.ErrTimeout{}
}
attemptedResolution = true
<-ch
@@ -93,17 +93,23 @@ func getBlocking(c *linkAddrCache, addr tcpip.FullAddress, linkRes LinkAddressRe
}
}
+func newEmptyNIC() *NIC {
+ n := &NIC{}
+ n.linkResQueue.init(n)
+ return n
+}
+
func TestCacheOverflow(t *testing.T) {
- c := newLinkAddrCache(1<<63-1, 1*time.Second, 3)
+ c := newLinkAddrCache(newEmptyNIC(), 1<<63-1, 1*time.Second, 3)
for i := len(testAddrs) - 1; i >= 0; i-- {
e := testAddrs[i]
- c.add(e.addr, e.linkAddr)
+ c.AddLinkAddress(e.addr, e.linkAddr)
got, _, err := c.get(e.addr, nil, "", nil, nil)
if err != nil {
- t.Errorf("insert %d, c.get(%q)=%q, got error: %v", i, string(e.addr.Addr), got, err)
+ t.Errorf("insert %d, c.get(%s, nil, '', nil, nil): %s", i, e.addr, err)
}
if got != e.linkAddr {
- t.Errorf("insert %d, c.get(%q)=%q, want %q", i, string(e.addr.Addr), got, e.linkAddr)
+ t.Errorf("insert %d, got c.get(%s, nil, '', nil, nil) = %s, want = %s", i, e.addr, got, e.linkAddr)
}
}
// Expect to find at least half of the most recent entries.
@@ -111,25 +117,25 @@ func TestCacheOverflow(t *testing.T) {
e := testAddrs[i]
got, _, err := c.get(e.addr, nil, "", nil, nil)
if err != nil {
- t.Errorf("check %d, c.get(%q)=%q, got error: %v", i, string(e.addr.Addr), got, err)
+ t.Errorf("check %d, c.get(%s, nil, '', nil, nil): %s", i, e.addr, err)
}
if got != e.linkAddr {
- t.Errorf("check %d, c.get(%q)=%q, want %q", i, string(e.addr.Addr), got, e.linkAddr)
+ t.Errorf("check %d, got c.get(%s, nil, '', nil, nil) = %s, want = %s", i, e.addr, got, e.linkAddr)
}
}
// The earliest entries should no longer be in the cache.
- c.cache.Lock()
- defer c.cache.Unlock()
+ c.mu.Lock()
+ defer c.mu.Unlock()
for i := len(testAddrs) - 1; i >= len(testAddrs)-linkAddrCacheSize; i-- {
e := testAddrs[i]
- if entry, ok := c.cache.table[e.addr]; ok {
- t.Errorf("unexpected entry at c.cache.table[%q]: %#v", string(e.addr.Addr), entry)
+ if entry, ok := c.mu.table[e.addr]; ok {
+ t.Errorf("unexpected entry at c.mu.table[%s]: %#v", e.addr, entry)
}
}
}
func TestCacheConcurrent(t *testing.T) {
- c := newLinkAddrCache(1<<63-1, 1*time.Second, 3)
+ c := newLinkAddrCache(newEmptyNIC(), 1<<63-1, 1*time.Second, 3)
linkRes := &testLinkAddressResolver{cache: c}
var wg sync.WaitGroup
@@ -137,7 +143,7 @@ func TestCacheConcurrent(t *testing.T) {
wg.Add(1)
go func() {
for _, e := range testAddrs {
- c.add(e.addr, e.linkAddr)
+ c.AddLinkAddress(e.addr, e.linkAddr)
}
wg.Done()
}()
@@ -150,52 +156,53 @@ func TestCacheConcurrent(t *testing.T) {
e := testAddrs[len(testAddrs)-1]
got, _, err := c.get(e.addr, linkRes, "", nil, nil)
if err != nil {
- t.Errorf("c.get(%q)=%q, got error: %v", string(e.addr.Addr), got, err)
+ t.Errorf("c.get(%s, _, '', nil, nil): %s", e.addr, err)
}
if got != e.linkAddr {
- t.Errorf("c.get(%q)=%q, want %q", string(e.addr.Addr), got, e.linkAddr)
+ t.Errorf("got c.get(%s, _, '', nil, nil) = %s, want = %s", e.addr, got, e.linkAddr)
}
e = testAddrs[0]
- c.cache.Lock()
- defer c.cache.Unlock()
- if entry, ok := c.cache.table[e.addr]; ok {
- t.Errorf("unexpected entry at c.cache.table[%q]: %#v", string(e.addr.Addr), entry)
+ c.mu.Lock()
+ defer c.mu.Unlock()
+ if entry, ok := c.mu.table[e.addr]; ok {
+ t.Errorf("unexpected entry at c.mu.table[%s]: %#v", e.addr, entry)
}
}
func TestCacheAgeLimit(t *testing.T) {
- c := newLinkAddrCache(1*time.Millisecond, 1*time.Second, 3)
+ c := newLinkAddrCache(newEmptyNIC(), 1*time.Millisecond, 1*time.Second, 3)
linkRes := &testLinkAddressResolver{cache: c}
e := testAddrs[0]
- c.add(e.addr, e.linkAddr)
+ c.AddLinkAddress(e.addr, e.linkAddr)
time.Sleep(50 * time.Millisecond)
- if _, _, err := c.get(e.addr, linkRes, "", nil, nil); err != tcpip.ErrWouldBlock {
- t.Errorf("got c.get(%q) = %s, want = ErrWouldBlock", string(e.addr.Addr), err)
+ _, _, err := c.get(e.addr, linkRes, "", nil, nil)
+ if _, ok := err.(*tcpip.ErrWouldBlock); !ok {
+ t.Errorf("got c.get(%s, _, '', nil, nil) = %s, want = ErrWouldBlock", e.addr, err)
}
}
func TestCacheReplace(t *testing.T) {
- c := newLinkAddrCache(1<<63-1, 1*time.Second, 3)
+ c := newLinkAddrCache(newEmptyNIC(), 1<<63-1, 1*time.Second, 3)
e := testAddrs[0]
l2 := e.linkAddr + "2"
- c.add(e.addr, e.linkAddr)
+ c.AddLinkAddress(e.addr, e.linkAddr)
got, _, err := c.get(e.addr, nil, "", nil, nil)
if err != nil {
- t.Errorf("c.get(%q)=%q, got error: %v", string(e.addr.Addr), got, err)
+ t.Errorf("c.get(%s, nil, '', nil, nil): %s", e.addr, err)
}
if got != e.linkAddr {
- t.Errorf("c.get(%q)=%q, want %q", string(e.addr.Addr), got, e.linkAddr)
+ t.Errorf("got c.get(%s, nil, '', nil, nil) = %s, want = %s", e.addr, got, e.linkAddr)
}
- c.add(e.addr, l2)
+ c.AddLinkAddress(e.addr, l2)
got, _, err = c.get(e.addr, nil, "", nil, nil)
if err != nil {
- t.Errorf("c.get(%q)=%q, got error: %v", string(e.addr.Addr), got, err)
+ t.Errorf("c.get(%s, nil, '', nil, nil): %s", e.addr, err)
}
if got != l2 {
- t.Errorf("c.get(%q)=%q, want %q", string(e.addr.Addr), got, l2)
+ t.Errorf("got c.get(%s, nil, '', nil, nil) = %s, want = %s", e.addr, got, l2)
}
}
@@ -206,15 +213,15 @@ func TestCacheResolution(t *testing.T) {
//
// Using a large resolution timeout decreases the probability of experiencing
// this race condition and does not affect how long this test takes to run.
- c := newLinkAddrCache(1<<63-1, math.MaxInt64, 1)
+ c := newLinkAddrCache(newEmptyNIC(), 1<<63-1, math.MaxInt64, 1)
linkRes := &testLinkAddressResolver{cache: c}
for i, ta := range testAddrs {
got, err := getBlocking(c, ta.addr, linkRes)
if err != nil {
- t.Errorf("check %d, c.get(%q)=%q, got error: %v", i, string(ta.addr.Addr), got, err)
+ t.Errorf("check %d, getBlocking(_, %s, _): %s", i, ta.addr, err)
}
if got != ta.linkAddr {
- t.Errorf("check %d, c.get(%q)=%q, want %q", i, string(ta.addr.Addr), got, ta.linkAddr)
+ t.Errorf("check %d, got getBlocking(_, %s, _) = %s, want = %s", i, ta.addr, got, ta.linkAddr)
}
}
@@ -223,16 +230,16 @@ func TestCacheResolution(t *testing.T) {
e := testAddrs[len(testAddrs)-1]
got, _, err := c.get(e.addr, linkRes, "", nil, nil)
if err != nil {
- t.Errorf("c.get(%q)=%q, got error: %v", string(e.addr.Addr), got, err)
+ t.Errorf("c.get(%s, _, '', nil, nil): %s", e.addr, err)
}
if got != e.linkAddr {
- t.Errorf("c.get(%q)=%q, want %q", string(e.addr.Addr), got, e.linkAddr)
+ t.Errorf("got c.get(%s, _, '', nil, nil) = %s, want = %s", e.addr, got, e.linkAddr)
}
}
}
func TestCacheResolutionFailed(t *testing.T) {
- c := newLinkAddrCache(1<<63-1, 10*time.Millisecond, 5)
+ c := newLinkAddrCache(newEmptyNIC(), 1<<63-1, 10*time.Millisecond, 5)
linkRes := &testLinkAddressResolver{cache: c}
var requestCount uint32
@@ -244,17 +251,18 @@ func TestCacheResolutionFailed(t *testing.T) {
e := testAddrs[0]
got, err := getBlocking(c, e.addr, linkRes)
if err != nil {
- t.Errorf("c.get(%q)=%q, got error: %v", string(e.addr.Addr), got, err)
+ t.Errorf("getBlocking(_, %s, _): %s", e.addr, err)
}
if got != e.linkAddr {
- t.Errorf("c.get(%q)=%q, want %q", string(e.addr.Addr), got, e.linkAddr)
+ t.Errorf("got getBlocking(_, %s, _) = %s, want = %s", e.addr, got, e.linkAddr)
}
before := atomic.LoadUint32(&requestCount)
- e.addr.Addr += "2"
- if a, err := getBlocking(c, e.addr, linkRes); err != tcpip.ErrTimeout {
- t.Errorf("got getBlocking(_, %#v, _) = (%s, %s), want = (_, %s)", e.addr, a, err, tcpip.ErrTimeout)
+ e.addr += "2"
+ a, err := getBlocking(c, e.addr, linkRes)
+ if _, ok := err.(*tcpip.ErrTimeout); !ok {
+ t.Errorf("got getBlocking(_, %s, _) = (%s, %s), want = (_, %s)", e.addr, a, err, &tcpip.ErrTimeout{})
}
if got, want := int(atomic.LoadUint32(&requestCount)-before), c.resolutionAttempts; got != want {
@@ -265,11 +273,12 @@ func TestCacheResolutionFailed(t *testing.T) {
func TestCacheResolutionTimeout(t *testing.T) {
resolverDelay := 500 * time.Millisecond
expiration := resolverDelay / 10
- c := newLinkAddrCache(expiration, 1*time.Millisecond, 3)
+ c := newLinkAddrCache(newEmptyNIC(), expiration, 1*time.Millisecond, 3)
linkRes := &testLinkAddressResolver{cache: c, delay: resolverDelay}
e := testAddrs[0]
- if a, err := getBlocking(c, e.addr, linkRes); err != tcpip.ErrTimeout {
- t.Errorf("got getBlocking(_, %#v, _) = (%s, %s), want = (_, %s)", e.addr, a, err, tcpip.ErrTimeout)
+ a, err := getBlocking(c, e.addr, linkRes)
+ if _, ok := err.(*tcpip.ErrTimeout); !ok {
+ t.Errorf("got getBlocking(_, %s, _) = (%s, %s), want = (_, %s)", e.addr, a, err, &tcpip.ErrTimeout{})
}
}
diff --git a/pkg/tcpip/stack/ndp_test.go b/pkg/tcpip/stack/ndp_test.go
index 61636cae5..64383bc7c 100644
--- a/pkg/tcpip/stack/ndp_test.go
+++ b/pkg/tcpip/stack/ndp_test.go
@@ -45,6 +45,8 @@ const (
linkAddr3 = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x08")
linkAddr4 = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x09")
+ defaultPrefixLen = 128
+
// Extra time to use when waiting for an async event to occur.
defaultAsyncPositiveEventTimeout = 10 * time.Second
@@ -102,7 +104,7 @@ type ndpDADEvent struct {
nicID tcpip.NICID
addr tcpip.Address
resolved bool
- err *tcpip.Error
+ err tcpip.Error
}
type ndpRouterEvent struct {
@@ -172,7 +174,7 @@ type ndpDispatcher struct {
}
// Implements ipv6.NDPDispatcher.OnDuplicateAddressDetectionStatus.
-func (n *ndpDispatcher) OnDuplicateAddressDetectionStatus(nicID tcpip.NICID, addr tcpip.Address, resolved bool, err *tcpip.Error) {
+func (n *ndpDispatcher) OnDuplicateAddressDetectionStatus(nicID tcpip.NICID, addr tcpip.Address, resolved bool, err tcpip.Error) {
if n.dadC != nil {
n.dadC <- ndpDADEvent{
nicID,
@@ -309,7 +311,7 @@ func (l *channelLinkWithHeaderLength) MaxHeaderLength() uint16 {
// Check e to make sure that the event is for addr on nic with ID 1, and the
// resolved flag set to resolved with the specified err.
-func checkDADEvent(e ndpDADEvent, nicID tcpip.NICID, addr tcpip.Address, resolved bool, err *tcpip.Error) string {
+func checkDADEvent(e ndpDADEvent, nicID tcpip.NICID, addr tcpip.Address, resolved bool, err tcpip.Error) string {
return cmp.Diff(ndpDADEvent{nicID: nicID, addr: addr, resolved: resolved, err: err}, e, cmp.AllowUnexported(e))
}
@@ -330,8 +332,12 @@ func TestDADDisabled(t *testing.T) {
t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
}
- if err := s.AddAddress(nicID, header.IPv6ProtocolNumber, addr1); err != nil {
- t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, addr1, err)
+ addrWithPrefix := tcpip.AddressWithPrefix{
+ Address: addr1,
+ PrefixLen: defaultPrefixLen,
+ }
+ if err := s.AddAddressWithPrefix(nicID, header.IPv6ProtocolNumber, addrWithPrefix); err != nil {
+ t.Fatalf("AddAddressWithPrefix(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, addrWithPrefix, err)
}
// Should get the address immediately since we should not have performed
@@ -344,12 +350,8 @@ func TestDADDisabled(t *testing.T) {
default:
t.Fatal("expected DAD event")
}
- addr, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber)
- if err != nil {
- t.Fatalf("stack.GetMainNICAddress(%d, %d) err = %s", nicID, header.IPv6ProtocolNumber, err)
- }
- if addr.Address != addr1 {
- t.Fatalf("got stack.GetMainNICAddress(%d, %d) = %s, want = %s", nicID, header.IPv6ProtocolNumber, addr, addr1)
+ if err := checkGetMainNICAddress(s, nicID, header.IPv6ProtocolNumber, addrWithPrefix); err != nil {
+ t.Fatal(err)
}
// We should not have sent any NDP NS messages.
@@ -440,31 +442,31 @@ func TestDADResolve(t *testing.T) {
NIC: nicID,
}})
- if err := s.AddAddress(nicID, header.IPv6ProtocolNumber, addr1); err != nil {
- t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, addr1, err)
+ addrWithPrefix := tcpip.AddressWithPrefix{
+ Address: addr1,
+ PrefixLen: defaultPrefixLen,
+ }
+ if err := s.AddAddressWithPrefix(nicID, header.IPv6ProtocolNumber, addrWithPrefix); err != nil {
+ t.Fatalf("AddAddressWithPrefix(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, addrWithPrefix, err)
}
// Address should not be considered bound to the NIC yet (DAD ongoing).
- if addr, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber); err != nil {
- t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %s), want = (_, nil)", nicID, header.IPv6ProtocolNumber, err)
- } else if want := (tcpip.AddressWithPrefix{}); addr != want {
- t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (%s, nil), want = (%s, nil)", nicID, header.IPv6ProtocolNumber, addr, want)
+ if err := checkGetMainNICAddress(s, nicID, header.IPv6ProtocolNumber, tcpip.AddressWithPrefix{}); err != nil {
+ t.Fatal(err)
}
// Make sure the address does not resolve before the resolution time has
// passed.
time.Sleep(test.expectedRetransmitTimer*time.Duration(test.dupAddrDetectTransmits) - defaultAsyncNegativeEventTimeout)
- if addr, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber); err != nil {
- t.Errorf("got stack.GetMainNICAddress(%d, %d) = (_, %s), want = (_, nil)", nicID, header.IPv6ProtocolNumber, err)
- } else if want := (tcpip.AddressWithPrefix{}); addr != want {
- t.Errorf("got stack.GetMainNICAddress(%d, %d) = (%s, nil), want = (%s, nil)", nicID, header.IPv6ProtocolNumber, addr, want)
+ if err := checkGetMainNICAddress(s, nicID, header.IPv6ProtocolNumber, tcpip.AddressWithPrefix{}); err != nil {
+ t.Error(err)
}
// Should not get a route even if we specify the local address as the
// tentative address.
{
r, err := s.FindRoute(nicID, "", addr2, header.IPv6ProtocolNumber, false)
- if err != tcpip.ErrNoRoute {
- t.Errorf("got FindRoute(%d, '', %s, %d, false) = (%+v, %v), want = (_, %s)", nicID, addr2, header.IPv6ProtocolNumber, r, err, tcpip.ErrNoRoute)
+ if _, ok := err.(*tcpip.ErrNoRoute); !ok {
+ t.Errorf("got FindRoute(%d, '', %s, %d, false) = (%+v, %v), want = (_, %s)", nicID, addr2, header.IPv6ProtocolNumber, r, err, &tcpip.ErrNoRoute{})
}
if r != nil {
r.Release()
@@ -472,8 +474,8 @@ func TestDADResolve(t *testing.T) {
}
{
r, err := s.FindRoute(nicID, addr1, addr2, header.IPv6ProtocolNumber, false)
- if err != tcpip.ErrNoRoute {
- t.Errorf("got FindRoute(%d, %s, %s, %d, false) = (%+v, %v), want = (_, %s)", nicID, addr1, addr2, header.IPv6ProtocolNumber, r, err, tcpip.ErrNoRoute)
+ if _, ok := err.(*tcpip.ErrNoRoute); !ok {
+ t.Errorf("got FindRoute(%d, %s, %s, %d, false) = (%+v, %v), want = (_, %s)", nicID, addr1, addr2, header.IPv6ProtocolNumber, r, err, &tcpip.ErrNoRoute{})
}
if r != nil {
r.Release()
@@ -493,10 +495,8 @@ func TestDADResolve(t *testing.T) {
t.Errorf("dad event mismatch (-want +got):\n%s", diff)
}
}
- if addr, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber); err != nil {
- t.Errorf("got stack.GetMainNICAddress(%d, %d) = (_, %s), want = (_, nil)", nicID, header.IPv6ProtocolNumber, err)
- } else if addr.Address != addr1 {
- t.Errorf("got stack.GetMainNICAddress(%d, %d) = %s, want = %s", nicID, header.IPv6ProtocolNumber, addr, addr1)
+ if err := checkGetMainNICAddress(s, nicID, header.IPv6ProtocolNumber, addrWithPrefix); err != nil {
+ t.Error(err)
}
// Should get a route using the address now that it is resolved.
{
@@ -662,12 +662,8 @@ func TestDADFail(t *testing.T) {
// Address should not be considered bound to the NIC yet
// (DAD ongoing).
- addr, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber)
- if err != nil {
- t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %v), want = (_, nil)", nicID, header.IPv6ProtocolNumber, err)
- }
- if want := (tcpip.AddressWithPrefix{}); addr != want {
- t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (%s, nil), want = (%s, nil)", nicID, header.IPv6ProtocolNumber, addr, want)
+ if err := checkGetMainNICAddress(s, nicID, header.IPv6ProtocolNumber, tcpip.AddressWithPrefix{}); err != nil {
+ t.Fatal(err)
}
// Receive a packet to simulate an address conflict.
@@ -691,12 +687,8 @@ func TestDADFail(t *testing.T) {
t.Errorf("dad event mismatch (-want +got):\n%s", diff)
}
}
- addr, err = s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber)
- if err != nil {
- t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %v), want = (_, nil)", nicID, header.IPv6ProtocolNumber, err)
- }
- if want := (tcpip.AddressWithPrefix{}); addr != want {
- t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (%s, nil), want = (%s, nil)", nicID, header.IPv6ProtocolNumber, addr, want)
+ if err := checkGetMainNICAddress(s, nicID, header.IPv6ProtocolNumber, tcpip.AddressWithPrefix{}); err != nil {
+ t.Fatal(err)
}
// Attempting to add the address again should not fail if the address's
@@ -777,12 +769,8 @@ func TestDADStop(t *testing.T) {
}
// Address should not be considered bound to the NIC yet (DAD ongoing).
- addr, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber)
- if err != nil {
- t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %v), want = (_, nil)", nicID, header.IPv6ProtocolNumber, err)
- }
- if want := (tcpip.AddressWithPrefix{}); addr != want {
- t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (%s, nil), want = (%s, nil)", nicID, header.IPv6ProtocolNumber, addr, want)
+ if err := checkGetMainNICAddress(s, nicID, header.IPv6ProtocolNumber, tcpip.AddressWithPrefix{}); err != nil {
+ t.Fatal(err)
}
test.stopFn(t, s)
@@ -800,12 +788,8 @@ func TestDADStop(t *testing.T) {
}
if !test.skipFinalAddrCheck {
- addr, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber)
- if err != nil {
- t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %v), want = (_, nil)", nicID, header.IPv6ProtocolNumber, err)
- }
- if want := (tcpip.AddressWithPrefix{}); addr != want {
- t.Errorf("got stack.GetMainNICAddress(%d, %d) = (%s, nil), want = (%s, nil)", nicID, header.IPv6ProtocolNumber, addr, want)
+ if err := checkGetMainNICAddress(s, nicID, header.IPv6ProtocolNumber, tcpip.AddressWithPrefix{}); err != nil {
+ t.Fatal(err)
}
}
@@ -901,26 +885,25 @@ func TestSetNDPConfigurations(t *testing.T) {
}
// Add addresses for each NIC.
- if err := s.AddAddress(nicID1, header.IPv6ProtocolNumber, addr1); err != nil {
- t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID1, header.IPv6ProtocolNumber, addr1, err)
+ addrWithPrefix1 := tcpip.AddressWithPrefix{Address: addr1, PrefixLen: defaultPrefixLen}
+ if err := s.AddAddressWithPrefix(nicID1, header.IPv6ProtocolNumber, addrWithPrefix1); err != nil {
+ t.Fatalf("AddAddressWithPrefix(%d, %d, %s) = %s", nicID1, header.IPv6ProtocolNumber, addrWithPrefix1, err)
}
- if err := s.AddAddress(nicID2, header.IPv6ProtocolNumber, addr2); err != nil {
- t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID2, header.IPv6ProtocolNumber, addr2, err)
+ addrWithPrefix2 := tcpip.AddressWithPrefix{Address: addr2, PrefixLen: defaultPrefixLen}
+ if err := s.AddAddressWithPrefix(nicID2, header.IPv6ProtocolNumber, addrWithPrefix2); err != nil {
+ t.Fatalf("AddAddressWithPrefix(%d, %d, %s) = %s", nicID2, header.IPv6ProtocolNumber, addrWithPrefix2, err)
}
expectDADEvent(nicID2, addr2)
- if err := s.AddAddress(nicID3, header.IPv6ProtocolNumber, addr3); err != nil {
- t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID3, header.IPv6ProtocolNumber, addr3, err)
+ addrWithPrefix3 := tcpip.AddressWithPrefix{Address: addr3, PrefixLen: defaultPrefixLen}
+ if err := s.AddAddressWithPrefix(nicID3, header.IPv6ProtocolNumber, addrWithPrefix3); err != nil {
+ t.Fatalf("AddAddressWithPrefix(%d, %d, %s) = %s", nicID3, header.IPv6ProtocolNumber, addrWithPrefix3, err)
}
expectDADEvent(nicID3, addr3)
// Address should not be considered bound to NIC(1) yet
// (DAD ongoing).
- addr, err := s.GetMainNICAddress(nicID1, header.IPv6ProtocolNumber)
- if err != nil {
- t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %v), want = (_, nil)", nicID1, header.IPv6ProtocolNumber, err)
- }
- if want := (tcpip.AddressWithPrefix{}); addr != want {
- t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (%s, nil), want = (%s, nil)", nicID1, header.IPv6ProtocolNumber, addr, want)
+ if err := checkGetMainNICAddress(s, nicID1, header.IPv6ProtocolNumber, tcpip.AddressWithPrefix{}); err != nil {
+ t.Fatal(err)
}
// Should get the address on NIC(2) and NIC(3)
@@ -928,31 +911,19 @@ func TestSetNDPConfigurations(t *testing.T) {
// it as the stack was configured to not do DAD by
// default and we only updated the NDP configurations on
// NIC(1).
- addr, err = s.GetMainNICAddress(nicID2, header.IPv6ProtocolNumber)
- if err != nil {
- t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %v), want = (_, nil)", nicID2, header.IPv6ProtocolNumber, err)
- }
- if addr.Address != addr2 {
- t.Fatalf("got stack.GetMainNICAddress(%d, %d) = %s, want = %s", nicID2, header.IPv6ProtocolNumber, addr, addr2)
+ if err := checkGetMainNICAddress(s, nicID2, header.IPv6ProtocolNumber, addrWithPrefix2); err != nil {
+ t.Fatal(err)
}
- addr, err = s.GetMainNICAddress(nicID3, header.IPv6ProtocolNumber)
- if err != nil {
- t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %v), want = (_, nil)", nicID3, header.IPv6ProtocolNumber, err)
- }
- if addr.Address != addr3 {
- t.Fatalf("got stack.GetMainNICAddress(%d, %d) = %s, want = %s", nicID3, header.IPv6ProtocolNumber, addr, addr3)
+ if err := checkGetMainNICAddress(s, nicID3, header.IPv6ProtocolNumber, addrWithPrefix3); err != nil {
+ t.Fatal(err)
}
// Sleep until right (500ms before) before resolution to
// make sure the address didn't resolve on NIC(1) yet.
const delta = 500 * time.Millisecond
time.Sleep(time.Duration(test.dupAddrDetectTransmits)*test.expectedRetransmitTimer - delta)
- addr, err = s.GetMainNICAddress(nicID1, header.IPv6ProtocolNumber)
- if err != nil {
- t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %v), want = (_, nil)", nicID1, header.IPv6ProtocolNumber, err)
- }
- if want := (tcpip.AddressWithPrefix{}); addr != want {
- t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (%s, nil), want = (%s, nil)", nicID1, header.IPv6ProtocolNumber, addr, want)
+ if err := checkGetMainNICAddress(s, nicID1, header.IPv6ProtocolNumber, tcpip.AddressWithPrefix{}); err != nil {
+ t.Fatal(err)
}
// Wait for DAD to resolve.
@@ -970,12 +941,8 @@ func TestSetNDPConfigurations(t *testing.T) {
t.Errorf("dad event mismatch (-want +got):\n%s", diff)
}
}
- addr, err = s.GetMainNICAddress(nicID1, header.IPv6ProtocolNumber)
- if err != nil {
- t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %v), want = (_, nil)", nicID1, header.IPv6ProtocolNumber, err)
- }
- if addr.Address != addr1 {
- t.Fatalf("got stack.GetMainNICAddress(%d, %d) = %s, want = %s", nicID1, header.IPv6ProtocolNumber, addr, addr1)
+ if err := checkGetMainNICAddress(s, nicID1, header.IPv6ProtocolNumber, addrWithPrefix1); err != nil {
+ t.Fatal(err)
}
})
}
@@ -2808,6 +2775,7 @@ func stackAndNdpDispatcherWithDefaultRoute(t *testing.T, nicID tcpip.NICID, useN
autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1),
}
e := channel.New(0, 1280, linkAddr1)
+ e.LinkEPCapabilities |= stack.CapabilityResolutionRequired
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
NDPConfigs: ipv6.NDPConfigurations{
@@ -2827,10 +2795,15 @@ func stackAndNdpDispatcherWithDefaultRoute(t *testing.T, nicID tcpip.NICID, useN
Gateway: llAddr3,
NIC: nicID,
}})
+
if useNeighborCache {
- s.AddStaticNeighbor(nicID, llAddr3, linkAddr3)
+ if err := s.AddStaticNeighbor(nicID, llAddr3, linkAddr3); err != nil {
+ t.Fatalf("s.AddStaticNeighbor(%d, %s, %s): %s", nicID, llAddr3, linkAddr3, err)
+ }
} else {
- s.AddLinkAddress(nicID, llAddr3, linkAddr3)
+ if err := s.AddLinkAddress(nicID, llAddr3, linkAddr3); err != nil {
+ t.Fatalf("s.AddLinkAddress(%d, %s, %s): %s", nicID, llAddr3, linkAddr3, err)
+ }
}
return ndpDisp, e, s
}
@@ -2940,10 +2913,8 @@ func TestAutoGenAddrDeprecateFromPI(t *testing.T) {
expectPrimaryAddr := func(addr tcpip.AddressWithPrefix) {
t.Helper()
- if got, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber); err != nil {
- t.Fatalf("s.GetMainNICAddress(%d, %d): %s", nicID, header.IPv6ProtocolNumber, err)
- } else if got != addr {
- t.Errorf("got s.GetMainNICAddress(%d, %d) = %s, want = %s", nicID, header.IPv6ProtocolNumber, got, addr)
+ if err := checkGetMainNICAddress(s, nicID, header.IPv6ProtocolNumber, addr); err != nil {
+ t.Fatal(err)
}
if got := addrForNewConnection(t, s); got != addr.Address {
@@ -3088,10 +3059,8 @@ func TestAutoGenAddrJobDeprecation(t *testing.T) {
expectPrimaryAddr := func(addr tcpip.AddressWithPrefix) {
t.Helper()
- if got, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber); err != nil {
- t.Fatalf("s.GetMainNICAddress(%d, %d): %s", nicID, header.IPv6ProtocolNumber, err)
- } else if got != addr {
- t.Errorf("got s.GetMainNICAddress(%d, %d) = %s, want = %s", nicID, header.IPv6ProtocolNumber, got, addr)
+ if err := checkGetMainNICAddress(s, nicID, header.IPv6ProtocolNumber, addr); err != nil {
+ t.Fatal(err)
}
if got := addrForNewConnection(t, s); got != addr.Address {
@@ -3238,10 +3207,8 @@ func TestAutoGenAddrJobDeprecation(t *testing.T) {
t.Fatalf("should not have %s in the list of addresses", addr2)
}
// Should not have any primary endpoints.
- if got, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber); err != nil {
- t.Fatalf("s.GetMainNICAddress(%d, %d): %s", nicID, header.IPv6ProtocolNumber, err)
- } else if want := (tcpip.AddressWithPrefix{}); got != want {
- t.Errorf("got s.GetMainNICAddress(%d, %d) = %s, want = %s", nicID, header.IPv6ProtocolNumber, got, want)
+ if err := checkGetMainNICAddress(s, nicID, header.IPv6ProtocolNumber, tcpip.AddressWithPrefix{}); err != nil {
+ t.Fatal(err)
}
wq := waiter.Queue{}
we, ch := waiter.NewChannelEntry(nil)
@@ -3255,8 +3222,11 @@ func TestAutoGenAddrJobDeprecation(t *testing.T) {
defer ep.Close()
ep.SocketOptions().SetV6Only(true)
- if err := ep.Connect(dstAddr); err != tcpip.ErrNoRoute {
- t.Errorf("got ep.Connect(%+v) = %s, want = %s", dstAddr, err, tcpip.ErrNoRoute)
+ {
+ err := ep.Connect(dstAddr)
+ if _, ok := err.(*tcpip.ErrNoRoute); !ok {
+ t.Errorf("got ep.Connect(%+v) = %s, want = %s", dstAddr, err, &tcpip.ErrNoRoute{})
+ }
}
})
}
@@ -3615,10 +3585,8 @@ func TestAutoGenAddrAfterRemoval(t *testing.T) {
expectPrimaryAddr := func(addr tcpip.AddressWithPrefix) {
t.Helper()
- if got, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber); err != nil {
- t.Fatalf("s.GetMainNICAddress(%d, %d): %s", nicID, header.IPv6ProtocolNumber, err)
- } else if got != addr {
- t.Errorf("got s.GetMainNICAddress(%d, %d) = %s, want = %s", nicID, header.IPv6ProtocolNumber, got, addr)
+ if err := checkGetMainNICAddress(s, nicID, header.IPv6ProtocolNumber, addr); err != nil {
+ t.Fatal(err)
}
if got := addrForNewConnection(t, s); got != addr.Address {
diff --git a/pkg/tcpip/stack/neighbor_cache.go b/pkg/tcpip/stack/neighbor_cache.go
index acee72572..88a3ff776 100644
--- a/pkg/tcpip/stack/neighbor_cache.go
+++ b/pkg/tcpip/stack/neighbor_cache.go
@@ -126,7 +126,7 @@ func (n *neighborCache) getOrCreateEntry(remoteAddr tcpip.Address, linkRes LinkA
// packet prompting NUD/link address resolution.
//
// TODO(gvisor.dev/issue/5151): Don't return the neighbor entry.
-func (n *neighborCache) entry(remoteAddr, localAddr tcpip.Address, linkRes LinkAddressResolver, onResolve func(tcpip.LinkAddress, bool)) (NeighborEntry, <-chan struct{}, *tcpip.Error) {
+func (n *neighborCache) entry(remoteAddr, localAddr tcpip.Address, linkRes LinkAddressResolver, onResolve func(LinkResolutionResult)) (NeighborEntry, <-chan struct{}, tcpip.Error) {
entry := n.getOrCreateEntry(remoteAddr, linkRes)
entry.mu.Lock()
defer entry.mu.Unlock()
@@ -142,7 +142,7 @@ func (n *neighborCache) entry(remoteAddr, localAddr tcpip.Address, linkRes LinkA
// a node continues sending packets to that neighbor using the cached
// link-layer address."
if onResolve != nil {
- onResolve(entry.neigh.LinkAddr, true)
+ onResolve(LinkResolutionResult{LinkAddress: entry.neigh.LinkAddr, Success: true})
}
return entry.neigh, nil, nil
case Unknown, Incomplete, Failed:
@@ -154,7 +154,7 @@ func (n *neighborCache) entry(remoteAddr, localAddr tcpip.Address, linkRes LinkA
entry.done = make(chan struct{})
}
entry.handlePacketQueuedLocked(localAddr)
- return entry.neigh, entry.done, tcpip.ErrWouldBlock
+ return entry.neigh, entry.done, &tcpip.ErrWouldBlock{}
default:
panic(fmt.Sprintf("Invalid cache entry state: %s", s))
}
@@ -297,10 +297,9 @@ func (n *neighborCache) HandleConfirmation(addr tcpip.Address, linkAddr tcpip.Li
// no matching entry for the remote address.
}
-// HandleUpperLevelConfirmation implements
-// NUDHandler.HandleUpperLevelConfirmation by following the logic defined in
-// RFC 4861 section 7.3.1.
-func (n *neighborCache) HandleUpperLevelConfirmation(addr tcpip.Address) {
+// handleUpperLevelConfirmation processes a confirmation of reachablity from
+// some protocol that operates at a layer above the IP/link layer.
+func (n *neighborCache) handleUpperLevelConfirmation(addr tcpip.Address) {
n.mu.RLock()
entry, ok := n.cache[addr]
n.mu.RUnlock()
diff --git a/pkg/tcpip/stack/neighbor_cache_test.go b/pkg/tcpip/stack/neighbor_cache_test.go
index b96a56612..2870e4f66 100644
--- a/pkg/tcpip/stack/neighbor_cache_test.go
+++ b/pkg/tcpip/stack/neighbor_cache_test.go
@@ -194,7 +194,7 @@ type testNeighborResolver struct {
var _ LinkAddressResolver = (*testNeighborResolver)(nil)
-func (r *testNeighborResolver) LinkAddressRequest(targetAddr, _ tcpip.Address, _ tcpip.LinkAddress, _ NetworkInterface) *tcpip.Error {
+func (r *testNeighborResolver) LinkAddressRequest(targetAddr, _ tcpip.Address, _ tcpip.LinkAddress, _ NetworkInterface) tcpip.Error {
if !r.dropReplies {
// Delay handling the request to emulate network latency.
r.clock.AfterFunc(r.delay, func() {
@@ -251,8 +251,8 @@ func TestNeighborCacheGetConfig(t *testing.T) {
// No events should have been dispatched.
nudDisp.mu.Lock()
defer nudDisp.mu.Unlock()
- if diff := cmp.Diff(nudDisp.events, []testEntryEventInfo(nil)); diff != "" {
- t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ if diff := cmp.Diff([]testEntryEventInfo(nil), nudDisp.events); diff != "" {
+ t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff)
}
}
@@ -273,8 +273,8 @@ func TestNeighborCacheSetConfig(t *testing.T) {
// No events should have been dispatched.
nudDisp.mu.Lock()
defer nudDisp.mu.Unlock()
- if diff := cmp.Diff(nudDisp.events, []testEntryEventInfo(nil)); diff != "" {
- t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ if diff := cmp.Diff([]testEntryEventInfo(nil), nudDisp.events); diff != "" {
+ t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff)
}
}
@@ -295,8 +295,9 @@ func TestNeighborCacheEntry(t *testing.T) {
if !ok {
t.Fatal("store.entry(0) not found")
}
- if _, _, err := neigh.entry(entry.Addr, "", linkRes, nil); err != tcpip.ErrWouldBlock {
- t.Errorf("got neigh.entry(%s, '', _, nil, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock)
+ _, _, err := neigh.entry(entry.Addr, "", linkRes, nil)
+ if _, ok := err.(*tcpip.ErrWouldBlock); !ok {
+ t.Errorf("got neigh.entry(%s, '', _, nil, nil) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{})
}
clock.Advance(typicalLatency)
@@ -321,11 +322,11 @@ func TestNeighborCacheEntry(t *testing.T) {
},
}
nudDisp.mu.Lock()
- diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...)
+ diff := cmp.Diff(wantEvents, nudDisp.events, eventDiffOpts()...)
nudDisp.events = nil
nudDisp.mu.Unlock()
if diff != "" {
- t.Fatalf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ t.Fatalf("nud dispatcher events mismatch (-want, +got):\n%s", diff)
}
if _, _, err := neigh.entry(entry.Addr, "", linkRes, nil); err != nil {
@@ -335,8 +336,8 @@ func TestNeighborCacheEntry(t *testing.T) {
// No more events should have been dispatched.
nudDisp.mu.Lock()
defer nudDisp.mu.Unlock()
- if diff := cmp.Diff(nudDisp.events, []testEntryEventInfo(nil)); diff != "" {
- t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ if diff := cmp.Diff([]testEntryEventInfo(nil), nudDisp.events); diff != "" {
+ t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff)
}
}
@@ -359,8 +360,9 @@ func TestNeighborCacheRemoveEntry(t *testing.T) {
t.Fatal("store.entry(0) not found")
}
- if _, _, err := neigh.entry(entry.Addr, "", linkRes, nil); err != tcpip.ErrWouldBlock {
- t.Errorf("got neigh.entry(%s, '', _, nil, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock)
+ _, _, err := neigh.entry(entry.Addr, "", linkRes, nil)
+ if _, ok := err.(*tcpip.ErrWouldBlock); !ok {
+ t.Errorf("got neigh.entry(%s, '', _, nil, nil) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{})
}
clock.Advance(typicalLatency)
@@ -385,11 +387,11 @@ func TestNeighborCacheRemoveEntry(t *testing.T) {
},
}
nudDisp.mu.Lock()
- diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...)
+ diff := cmp.Diff(wantEvents, nudDisp.events, eventDiffOpts()...)
nudDisp.events = nil
nudDisp.mu.Unlock()
if diff != "" {
- t.Fatalf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ t.Fatalf("nud dispatcher events mismatch (-want, +got):\n%s", diff)
}
neigh.removeEntry(entry.Addr)
@@ -407,15 +409,18 @@ func TestNeighborCacheRemoveEntry(t *testing.T) {
},
}
nudDisp.mu.Lock()
- diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...)
+ diff := cmp.Diff(wantEvents, nudDisp.events, eventDiffOpts()...)
nudDisp.mu.Unlock()
if diff != "" {
- t.Fatalf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ t.Fatalf("nud dispatcher events mismatch (-want, +got):\n%s", diff)
}
}
- if _, _, err := neigh.entry(entry.Addr, "", linkRes, nil); err != tcpip.ErrWouldBlock {
- t.Errorf("got neigh.entry(%s, '', _, nil, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock)
+ {
+ _, _, err := neigh.entry(entry.Addr, "", linkRes, nil)
+ if _, ok := err.(*tcpip.ErrWouldBlock); !ok {
+ t.Errorf("got neigh.entry(%s, '', _, nil, nil) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{})
+ }
}
}
@@ -462,8 +467,9 @@ func (c *testContext) overflowCache(opts overflowOptions) error {
if !ok {
return fmt.Errorf("c.store.entry(%d) not found", i)
}
- if _, _, err := c.neigh.entry(entry.Addr, "", c.linkRes, nil); err != tcpip.ErrWouldBlock {
- return fmt.Errorf("got c.neigh.entry(%s, '', _, nil, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock)
+ _, _, err := c.neigh.entry(entry.Addr, "", c.linkRes, nil)
+ if _, ok := err.(*tcpip.ErrWouldBlock); !ok {
+ return fmt.Errorf("got c.neigh.entry(%s, '', _, nil, nil) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{})
}
c.clock.Advance(c.neigh.config().RetransmitTimer)
@@ -506,11 +512,11 @@ func (c *testContext) overflowCache(opts overflowOptions) error {
})
c.nudDisp.mu.Lock()
- diff := cmp.Diff(c.nudDisp.events, wantEvents, eventDiffOpts()...)
+ diff := cmp.Diff(wantEvents, c.nudDisp.events, eventDiffOpts()...)
c.nudDisp.events = nil
c.nudDisp.mu.Unlock()
if diff != "" {
- return fmt.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ return fmt.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff)
}
}
@@ -531,15 +537,15 @@ func (c *testContext) overflowCache(opts overflowOptions) error {
wantUnsortedEntries = append(wantUnsortedEntries, wantEntry)
}
- if diff := cmp.Diff(c.neigh.entries(), wantUnsortedEntries, entryDiffOptsWithSort()...); diff != "" {
- return fmt.Errorf("neighbor entries mismatch (-got, +want):\n%s", diff)
+ if diff := cmp.Diff(wantUnsortedEntries, c.neigh.entries(), entryDiffOptsWithSort()...); diff != "" {
+ return fmt.Errorf("neighbor entries mismatch (-want, +got):\n%s", diff)
}
// No more events should have been dispatched.
c.nudDisp.mu.Lock()
defer c.nudDisp.mu.Unlock()
- if diff := cmp.Diff(c.nudDisp.events, []testEntryEventInfo(nil)); diff != "" {
- return fmt.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ if diff := cmp.Diff([]testEntryEventInfo(nil), c.nudDisp.events); diff != "" {
+ return fmt.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff)
}
return nil
@@ -579,8 +585,9 @@ func TestNeighborCacheRemoveEntryThenOverflow(t *testing.T) {
if !ok {
t.Fatal("c.store.entry(0) not found")
}
- if _, _, err := c.neigh.entry(entry.Addr, "", c.linkRes, nil); err != tcpip.ErrWouldBlock {
- t.Errorf("got c.neigh.entry(%s, '', _, nil, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock)
+ _, _, err := c.neigh.entry(entry.Addr, "", c.linkRes, nil)
+ if _, ok := err.(*tcpip.ErrWouldBlock); !ok {
+ t.Errorf("got c.neigh.entry(%s, '', _, nil, nil) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{})
}
c.clock.Advance(c.neigh.config().RetransmitTimer)
wantEvents := []testEntryEventInfo{
@@ -603,11 +610,11 @@ func TestNeighborCacheRemoveEntryThenOverflow(t *testing.T) {
},
}
c.nudDisp.mu.Lock()
- diff := cmp.Diff(c.nudDisp.events, wantEvents, eventDiffOpts()...)
+ diff := cmp.Diff(wantEvents, c.nudDisp.events, eventDiffOpts()...)
c.nudDisp.events = nil
c.nudDisp.mu.Unlock()
if diff != "" {
- t.Fatalf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ t.Fatalf("nud dispatcher events mismatch (-want, +got):\n%s", diff)
}
// Remove the entry
@@ -626,11 +633,11 @@ func TestNeighborCacheRemoveEntryThenOverflow(t *testing.T) {
},
}
c.nudDisp.mu.Lock()
- diff := cmp.Diff(c.nudDisp.events, wantEvents, eventDiffOpts()...)
+ diff := cmp.Diff(wantEvents, c.nudDisp.events, eventDiffOpts()...)
c.nudDisp.events = nil
c.nudDisp.mu.Unlock()
if diff != "" {
- t.Fatalf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ t.Fatalf("nud dispatcher events mismatch (-want, +got):\n%s", diff)
}
}
@@ -668,11 +675,11 @@ func TestNeighborCacheDuplicateStaticEntryWithSameLinkAddress(t *testing.T) {
},
}
c.nudDisp.mu.Lock()
- diff := cmp.Diff(c.nudDisp.events, wantEvents, eventDiffOpts()...)
+ diff := cmp.Diff(wantEvents, c.nudDisp.events, eventDiffOpts()...)
c.nudDisp.events = nil
c.nudDisp.mu.Unlock()
if diff != "" {
- t.Fatalf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ t.Fatalf("nud dispatcher events mismatch (-want, +got):\n%s", diff)
}
// Remove the static entry that was just added
@@ -681,8 +688,8 @@ func TestNeighborCacheDuplicateStaticEntryWithSameLinkAddress(t *testing.T) {
// No more events should have been dispatched.
c.nudDisp.mu.Lock()
defer c.nudDisp.mu.Unlock()
- if diff := cmp.Diff(c.nudDisp.events, []testEntryEventInfo(nil)); diff != "" {
- t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ if diff := cmp.Diff([]testEntryEventInfo(nil), c.nudDisp.events); diff != "" {
+ t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff)
}
}
@@ -712,11 +719,11 @@ func TestNeighborCacheDuplicateStaticEntryWithDifferentLinkAddress(t *testing.T)
},
}
c.nudDisp.mu.Lock()
- diff := cmp.Diff(c.nudDisp.events, wantEvents, eventDiffOpts()...)
+ diff := cmp.Diff(wantEvents, c.nudDisp.events, eventDiffOpts()...)
c.nudDisp.events = nil
c.nudDisp.mu.Unlock()
if diff != "" {
- t.Fatalf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ t.Fatalf("nud dispatcher events mismatch (-want, +got):\n%s", diff)
}
// Add a duplicate entry with a different link address
@@ -736,8 +743,8 @@ func TestNeighborCacheDuplicateStaticEntryWithDifferentLinkAddress(t *testing.T)
}
c.nudDisp.mu.Lock()
defer c.nudDisp.mu.Unlock()
- if diff := cmp.Diff(c.nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" {
- t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ if diff := cmp.Diff(wantEvents, c.nudDisp.events, eventDiffOpts()...); diff != "" {
+ t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff)
}
}
}
@@ -774,11 +781,11 @@ func TestNeighborCacheRemoveStaticEntryThenOverflow(t *testing.T) {
},
}
c.nudDisp.mu.Lock()
- diff := cmp.Diff(c.nudDisp.events, wantEvents, eventDiffOpts()...)
+ diff := cmp.Diff(wantEvents, c.nudDisp.events, eventDiffOpts()...)
c.nudDisp.events = nil
c.nudDisp.mu.Unlock()
if diff != "" {
- t.Fatalf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ t.Fatalf("nud dispatcher events mismatch (-want, +got):\n%s", diff)
}
// Remove the static entry that was just added
@@ -796,11 +803,11 @@ func TestNeighborCacheRemoveStaticEntryThenOverflow(t *testing.T) {
},
}
c.nudDisp.mu.Lock()
- diff := cmp.Diff(c.nudDisp.events, wantEvents, eventDiffOpts()...)
+ diff := cmp.Diff(wantEvents, c.nudDisp.events, eventDiffOpts()...)
c.nudDisp.events = nil
c.nudDisp.mu.Unlock()
if diff != "" {
- t.Fatalf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ t.Fatalf("nud dispatcher events mismatch (-want, +got):\n%s", diff)
}
}
@@ -830,8 +837,9 @@ func TestNeighborCacheOverwriteWithStaticEntryThenOverflow(t *testing.T) {
if !ok {
t.Fatal("c.store.entry(0) not found")
}
- if _, _, err := c.neigh.entry(entry.Addr, "", c.linkRes, nil); err != tcpip.ErrWouldBlock {
- t.Errorf("got c.neigh.entry(%s, '', _, nil, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock)
+ _, _, err := c.neigh.entry(entry.Addr, "", c.linkRes, nil)
+ if _, ok := err.(*tcpip.ErrWouldBlock); !ok {
+ t.Errorf("got c.neigh.entry(%s, '', _, nil, nil) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{})
}
c.clock.Advance(typicalLatency)
wantEvents := []testEntryEventInfo{
@@ -854,11 +862,11 @@ func TestNeighborCacheOverwriteWithStaticEntryThenOverflow(t *testing.T) {
},
}
c.nudDisp.mu.Lock()
- diff := cmp.Diff(c.nudDisp.events, wantEvents, eventDiffOpts()...)
+ diff := cmp.Diff(wantEvents, c.nudDisp.events, eventDiffOpts()...)
c.nudDisp.events = nil
c.nudDisp.mu.Unlock()
if diff != "" {
- t.Fatalf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ t.Fatalf("nud dispatcher events mismatch (-want, +got):\n%s", diff)
}
// Override the entry with a static one using the same address
@@ -886,11 +894,11 @@ func TestNeighborCacheOverwriteWithStaticEntryThenOverflow(t *testing.T) {
},
}
c.nudDisp.mu.Lock()
- diff := cmp.Diff(c.nudDisp.events, wantEvents, eventDiffOpts()...)
+ diff := cmp.Diff(wantEvents, c.nudDisp.events, eventDiffOpts()...)
c.nudDisp.events = nil
c.nudDisp.mu.Unlock()
if diff != "" {
- t.Fatalf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ t.Fatalf("nud dispatcher events mismatch (-want, +got):\n%s", diff)
}
}
@@ -932,8 +940,8 @@ func TestNeighborCacheAddStaticEntryThenOverflow(t *testing.T) {
LinkAddr: entry.LinkAddr,
State: Static,
}
- if diff := cmp.Diff(e, want, entryDiffOpts()...); diff != "" {
- t.Errorf("c.neigh.entry(%s, \"\", _, nil, nil) mismatch (-got, +want):\n%s", entry.Addr, diff)
+ if diff := cmp.Diff(want, e, entryDiffOpts()...); diff != "" {
+ t.Errorf("c.neigh.entry(%s, \"\", _, nil, nil) mismatch (-want, +got):\n%s", entry.Addr, diff)
}
wantEvents := []testEntryEventInfo{
@@ -948,11 +956,11 @@ func TestNeighborCacheAddStaticEntryThenOverflow(t *testing.T) {
},
}
c.nudDisp.mu.Lock()
- diff := cmp.Diff(c.nudDisp.events, wantEvents, eventDiffOpts()...)
+ diff := cmp.Diff(wantEvents, c.nudDisp.events, eventDiffOpts()...)
c.nudDisp.events = nil
c.nudDisp.mu.Unlock()
if diff != "" {
- t.Fatalf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ t.Fatalf("nud dispatcher events mismatch (-want, +got):\n%s", diff)
}
opts := overflowOptions{
@@ -989,8 +997,9 @@ func TestNeighborCacheClear(t *testing.T) {
if !ok {
t.Fatal("store.entry(0) not found")
}
- if _, _, err := neigh.entry(entry.Addr, "", linkRes, nil); err != tcpip.ErrWouldBlock {
- t.Errorf("got neigh.entry(%s, '', _, nil, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock)
+ _, _, err := neigh.entry(entry.Addr, "", linkRes, nil)
+ if _, ok := err.(*tcpip.ErrWouldBlock); !ok {
+ t.Errorf("got neigh.entry(%s, '', _, nil, nil) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{})
}
clock.Advance(typicalLatency)
@@ -1014,11 +1023,11 @@ func TestNeighborCacheClear(t *testing.T) {
},
}
nudDisp.mu.Lock()
- diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...)
+ diff := cmp.Diff(wantEvents, nudDisp.events, eventDiffOpts()...)
nudDisp.events = nil
nudDisp.mu.Unlock()
if diff != "" {
- t.Fatalf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ t.Fatalf("nud dispatcher events mismatch (-want, +got):\n%s", diff)
}
// Add a static entry.
@@ -1037,11 +1046,11 @@ func TestNeighborCacheClear(t *testing.T) {
},
}
nudDisp.mu.Lock()
- diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...)
+ diff := cmp.Diff(wantEvents, nudDisp.events, eventDiffOpts()...)
nudDisp.events = nil
nudDisp.mu.Unlock()
if diff != "" {
- t.Fatalf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ t.Fatalf("nud dispatcher events mismatch (-want, +got):\n%s", diff)
}
}
@@ -1072,8 +1081,8 @@ func TestNeighborCacheClear(t *testing.T) {
}
nudDisp.mu.Lock()
defer nudDisp.mu.Unlock()
- if diff := cmp.Diff(nudDisp.events, wantUnsortedEvents, eventDiffOptsWithSort()...); diff != "" {
- t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ if diff := cmp.Diff(wantUnsortedEvents, nudDisp.events, eventDiffOptsWithSort()...); diff != "" {
+ t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff)
}
}
@@ -1094,8 +1103,9 @@ func TestNeighborCacheClearThenOverflow(t *testing.T) {
if !ok {
t.Fatal("c.store.entry(0) not found")
}
- if _, _, err := c.neigh.entry(entry.Addr, "", c.linkRes, nil); err != tcpip.ErrWouldBlock {
- t.Errorf("got c.neigh.entry(%s, '', _, nil, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock)
+ _, _, err := c.neigh.entry(entry.Addr, "", c.linkRes, nil)
+ if _, ok := err.(*tcpip.ErrWouldBlock); !ok {
+ t.Errorf("got c.neigh.entry(%s, '', _, nil, nil) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{})
}
c.clock.Advance(typicalLatency)
wantEvents := []testEntryEventInfo{
@@ -1118,11 +1128,11 @@ func TestNeighborCacheClearThenOverflow(t *testing.T) {
},
}
c.nudDisp.mu.Lock()
- diff := cmp.Diff(c.nudDisp.events, wantEvents, eventDiffOpts()...)
+ diff := cmp.Diff(wantEvents, c.nudDisp.events, eventDiffOpts()...)
c.nudDisp.events = nil
c.nudDisp.mu.Unlock()
if diff != "" {
- t.Fatalf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ t.Fatalf("nud dispatcher events mismatch (-want, +got):\n%s", diff)
}
// Clear the cache.
@@ -1140,11 +1150,11 @@ func TestNeighborCacheClearThenOverflow(t *testing.T) {
},
}
c.nudDisp.mu.Lock()
- diff := cmp.Diff(c.nudDisp.events, wantEvents, eventDiffOpts()...)
+ diff := cmp.Diff(wantEvents, c.nudDisp.events, eventDiffOpts()...)
c.nudDisp.events = nil
c.nudDisp.mu.Unlock()
if diff != "" {
- t.Fatalf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ t.Fatalf("nud dispatcher events mismatch (-want, +got):\n%s", diff)
}
}
@@ -1188,16 +1198,13 @@ func TestNeighborCacheKeepFrequentlyUsed(t *testing.T) {
if !ok {
t.Fatalf("store.entry(%d) not found", i)
}
- _, ch, err := neigh.entry(entry.Addr, "", linkRes, func(linkAddr tcpip.LinkAddress, ok bool) {
- if !ok {
- t.Fatal("expected successful address resolution")
- }
- if linkAddr != entry.LinkAddr {
- t.Fatalf("got linkAddr = %s, want = %s", linkAddr, entry.LinkAddr)
+ _, ch, err := neigh.entry(entry.Addr, "", linkRes, func(r LinkResolutionResult) {
+ if diff := cmp.Diff(LinkResolutionResult{LinkAddress: entry.LinkAddr, Success: true}, r); diff != "" {
+ t.Fatalf("got link resolution result mismatch (-want +got):\n%s", diff)
}
})
- if err != tcpip.ErrWouldBlock {
- t.Errorf("got neigh.entry(%s, '', _, _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock)
+ if _, ok := err.(*tcpip.ErrWouldBlock); !ok {
+ t.Errorf("got neigh.entry(%s, '', _, _, nil) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{})
}
clock.Advance(typicalLatency)
select {
@@ -1225,11 +1232,11 @@ func TestNeighborCacheKeepFrequentlyUsed(t *testing.T) {
},
}
nudDisp.mu.Lock()
- diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...)
+ diff := cmp.Diff(wantEvents, nudDisp.events, eventDiffOpts()...)
nudDisp.events = nil
nudDisp.mu.Unlock()
if diff != "" {
- t.Fatalf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ t.Fatalf("nud dispatcher events mismatch (-want, +got):\n%s", diff)
}
}
@@ -1247,16 +1254,13 @@ func TestNeighborCacheKeepFrequentlyUsed(t *testing.T) {
t.Fatalf("store.entry(%d) not found", i)
}
- _, ch, err := neigh.entry(entry.Addr, "", linkRes, func(linkAddr tcpip.LinkAddress, ok bool) {
- if !ok {
- t.Fatal("expected successful address resolution")
- }
- if linkAddr != entry.LinkAddr {
- t.Fatalf("got linkAddr = %s, want = %s", linkAddr, entry.LinkAddr)
+ _, ch, err := neigh.entry(entry.Addr, "", linkRes, func(r LinkResolutionResult) {
+ if diff := cmp.Diff(LinkResolutionResult{LinkAddress: entry.LinkAddr, Success: true}, r); diff != "" {
+ t.Fatalf("got link resolution result mismatch (-want +got):\n%s", diff)
}
})
- if err != tcpip.ErrWouldBlock {
- t.Errorf("got neigh.entry(%s, '', _, _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock)
+ if _, ok := err.(*tcpip.ErrWouldBlock); !ok {
+ t.Errorf("got neigh.entry(%s, '', _, _, nil) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{})
}
clock.Advance(typicalLatency)
select {
@@ -1299,11 +1303,11 @@ func TestNeighborCacheKeepFrequentlyUsed(t *testing.T) {
},
}
nudDisp.mu.Lock()
- diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...)
+ diff := cmp.Diff(wantEvents, nudDisp.events, eventDiffOpts()...)
nudDisp.events = nil
nudDisp.mu.Unlock()
if diff != "" {
- t.Fatalf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ t.Fatalf("nud dispatcher events mismatch (-want, +got):\n%s", diff)
}
}
@@ -1331,15 +1335,15 @@ func TestNeighborCacheKeepFrequentlyUsed(t *testing.T) {
wantUnsortedEntries = append(wantUnsortedEntries, wantEntry)
}
- if diff := cmp.Diff(neigh.entries(), wantUnsortedEntries, entryDiffOptsWithSort()...); diff != "" {
- t.Errorf("neighbor entries mismatch (-got, +want):\n%s", diff)
+ if diff := cmp.Diff(wantUnsortedEntries, neigh.entries(), entryDiffOptsWithSort()...); diff != "" {
+ t.Errorf("neighbor entries mismatch (-want, +got):\n%s", diff)
}
// No more events should have been dispatched.
nudDisp.mu.Lock()
defer nudDisp.mu.Unlock()
- if diff := cmp.Diff(nudDisp.events, []testEntryEventInfo(nil)); diff != "" {
- t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ if diff := cmp.Diff([]testEntryEventInfo(nil), nudDisp.events); diff != "" {
+ t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff)
}
}
@@ -1366,8 +1370,10 @@ func TestNeighborCacheConcurrent(t *testing.T) {
wg.Add(1)
go func(entry NeighborEntry) {
defer wg.Done()
- if e, _, err := neigh.entry(entry.Addr, "", linkRes, nil); err != nil && err != tcpip.ErrWouldBlock {
- t.Errorf("got neigh.entry(%s, '', _, nil, nil) = (%+v, _, %s), want (_, _, nil) or (_, _, %s)", entry.Addr, e, err, tcpip.ErrWouldBlock)
+ switch e, _, err := neigh.entry(entry.Addr, "", linkRes, nil); err.(type) {
+ case nil, *tcpip.ErrWouldBlock:
+ default:
+ t.Errorf("got neigh.entry(%s, '', _, nil, nil) = (%+v, _, %s), want (_, _, nil) or (_, _, %s)", entry.Addr, e, err, &tcpip.ErrWouldBlock{})
}
}(entry)
}
@@ -1398,8 +1404,8 @@ func TestNeighborCacheConcurrent(t *testing.T) {
wantUnsortedEntries = append(wantUnsortedEntries, wantEntry)
}
- if diff := cmp.Diff(neigh.entries(), wantUnsortedEntries, entryDiffOptsWithSort()...); diff != "" {
- t.Errorf("neighbor entries mismatch (-got, +want):\n%s", diff)
+ if diff := cmp.Diff(wantUnsortedEntries, neigh.entries(), entryDiffOptsWithSort()...); diff != "" {
+ t.Errorf("neighbor entries mismatch (-want, +got):\n%s", diff)
}
}
@@ -1423,16 +1429,13 @@ func TestNeighborCacheReplace(t *testing.T) {
t.Fatal("store.entry(0) not found")
}
- _, ch, err := neigh.entry(entry.Addr, "", linkRes, func(linkAddr tcpip.LinkAddress, ok bool) {
- if !ok {
- t.Fatal("expected successful address resolution")
- }
- if linkAddr != entry.LinkAddr {
- t.Fatalf("got linkAddr = %s, want = %s", linkAddr, entry.LinkAddr)
+ _, ch, err := neigh.entry(entry.Addr, "", linkRes, func(r LinkResolutionResult) {
+ if diff := cmp.Diff(LinkResolutionResult{LinkAddress: entry.LinkAddr, Success: true}, r); diff != "" {
+ t.Fatalf("got link resolution result mismatch (-want +got):\n%s", diff)
}
})
- if err != tcpip.ErrWouldBlock {
- t.Fatalf("got neigh.entry(%s, '', _, _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock)
+ if _, ok := err.(*tcpip.ErrWouldBlock); !ok {
+ t.Fatalf("got neigh.entry(%s, '', _, _, nil) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{})
}
clock.Advance(typicalLatency)
select {
@@ -1455,8 +1458,8 @@ func TestNeighborCacheReplace(t *testing.T) {
LinkAddr: entry.LinkAddr,
State: Reachable,
}
- if diff := cmp.Diff(e, want, entryDiffOpts()...); diff != "" {
- t.Errorf("neigh.entry(%s, '', _, _, nil) mismatch (-got, +want):\n%s", entry.Addr, diff)
+ if diff := cmp.Diff(want, e, entryDiffOpts()...); diff != "" {
+ t.Errorf("neigh.entry(%s, '', _, _, nil) mismatch (-want, +got):\n%s", entry.Addr, diff)
}
}
@@ -1489,8 +1492,8 @@ func TestNeighborCacheReplace(t *testing.T) {
LinkAddr: updatedLinkAddr,
State: Delay,
}
- if diff := cmp.Diff(e, want, entryDiffOpts()...); diff != "" {
- t.Errorf("neigh.entry(%s, '', _, nil, nil) mismatch (-got, +want):\n%s", entry.Addr, diff)
+ if diff := cmp.Diff(want, e, entryDiffOpts()...); diff != "" {
+ t.Errorf("neigh.entry(%s, '', _, nil, nil) mismatch (-want, +got):\n%s", entry.Addr, diff)
}
clock.Advance(config.DelayFirstProbeTime + typicalLatency)
}
@@ -1507,8 +1510,8 @@ func TestNeighborCacheReplace(t *testing.T) {
LinkAddr: updatedLinkAddr,
State: Reachable,
}
- if diff := cmp.Diff(e, want, entryDiffOpts()...); diff != "" {
- t.Errorf("neigh.entry(%s, '', _, nil, nil) mismatch (-got, +want):\n%s", entry.Addr, diff)
+ if diff := cmp.Diff(want, e, entryDiffOpts()...); diff != "" {
+ t.Errorf("neigh.entry(%s, '', _, nil, nil) mismatch (-want, +got):\n%s", entry.Addr, diff)
}
}
}
@@ -1539,16 +1542,13 @@ func TestNeighborCacheResolutionFailed(t *testing.T) {
// First, sanity check that resolution is working
{
- _, ch, err := neigh.entry(entry.Addr, "", linkRes, func(linkAddr tcpip.LinkAddress, ok bool) {
- if !ok {
- t.Fatal("expected successful address resolution")
- }
- if linkAddr != entry.LinkAddr {
- t.Fatalf("got linkAddr = %s, want = %s", linkAddr, entry.LinkAddr)
+ _, ch, err := neigh.entry(entry.Addr, "", linkRes, func(r LinkResolutionResult) {
+ if diff := cmp.Diff(LinkResolutionResult{LinkAddress: entry.LinkAddr, Success: true}, r); diff != "" {
+ t.Fatalf("got link resolution result mismatch (-want +got):\n%s", diff)
}
})
- if err != tcpip.ErrWouldBlock {
- t.Fatalf("got neigh.entry(%s, '', _, _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock)
+ if _, ok := err.(*tcpip.ErrWouldBlock); !ok {
+ t.Fatalf("got neigh.entry(%s, '', _, _, nil) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{})
}
clock.Advance(typicalLatency)
select {
@@ -1567,8 +1567,8 @@ func TestNeighborCacheResolutionFailed(t *testing.T) {
LinkAddr: entry.LinkAddr,
State: Reachable,
}
- if diff := cmp.Diff(got, want, entryDiffOpts()...); diff != "" {
- t.Errorf("neigh.entry(%s, '', _, nil, nil) mismatch (-got, +want):\n%s", entry.Addr, diff)
+ if diff := cmp.Diff(want, got, entryDiffOpts()...); diff != "" {
+ t.Errorf("neigh.entry(%s, '', _, nil, nil) mismatch (-want, +got):\n%s", entry.Addr, diff)
}
// Verify address resolution fails for an unknown address.
@@ -1576,19 +1576,13 @@ func TestNeighborCacheResolutionFailed(t *testing.T) {
entry.Addr += "2"
{
- _, ch, err := neigh.entry(entry.Addr, "", linkRes, func(linkAddr tcpip.LinkAddress, ok bool) {
- if ok {
- t.Error("expected unsuccessful address resolution")
- }
- if len(linkAddr) != 0 {
- t.Fatalf("got linkAddr = %s, want = \"\"", linkAddr)
- }
- if t.Failed() {
- t.FailNow()
+ _, ch, err := neigh.entry(entry.Addr, "", linkRes, func(r LinkResolutionResult) {
+ if diff := cmp.Diff(LinkResolutionResult{Success: false}, r); diff != "" {
+ t.Fatalf("got link resolution result mismatch (-want +got):\n%s", diff)
}
})
- if err != tcpip.ErrWouldBlock {
- t.Fatalf("got neigh.entry(%s, '', _, _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock)
+ if _, ok := err.(*tcpip.ErrWouldBlock); !ok {
+ t.Fatalf("got neigh.entry(%s, '', _, _, nil) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{})
}
waitFor := config.DelayFirstProbeTime + typicalLatency*time.Duration(config.MaxMulticastProbes)
clock.Advance(waitFor)
@@ -1627,19 +1621,13 @@ func TestNeighborCacheResolutionTimeout(t *testing.T) {
t.Fatal("store.entry(0) not found")
}
- _, ch, err := neigh.entry(entry.Addr, "", linkRes, func(linkAddr tcpip.LinkAddress, ok bool) {
- if ok {
- t.Error("expected unsuccessful address resolution")
- }
- if len(linkAddr) != 0 {
- t.Fatalf("got linkAddr = %s, want = \"\"", linkAddr)
- }
- if t.Failed() {
- t.FailNow()
+ _, ch, err := neigh.entry(entry.Addr, "", linkRes, func(r LinkResolutionResult) {
+ if diff := cmp.Diff(LinkResolutionResult{Success: false}, r); diff != "" {
+ t.Fatalf("got link resolution result mismatch (-want +got):\n%s", diff)
}
})
- if err != tcpip.ErrWouldBlock {
- t.Fatalf("got neigh.entry(%s, '', _, _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock)
+ if _, ok := err.(*tcpip.ErrWouldBlock); !ok {
+ t.Fatalf("got neigh.entry(%s, '', _, _, nil) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{})
}
waitFor := config.RetransmitTimer * time.Duration(config.MaxMulticastProbes)
clock.Advance(waitFor)
@@ -1674,19 +1662,13 @@ func TestNeighborCacheRetryResolution(t *testing.T) {
// Perform address resolution with a faulty link, which will fail.
{
- _, ch, err := neigh.entry(entry.Addr, "", linkRes, func(linkAddr tcpip.LinkAddress, ok bool) {
- if ok {
- t.Error("expected unsuccessful address resolution")
- }
- if len(linkAddr) != 0 {
- t.Fatalf("got linkAddr = %s, want = \"\"", linkAddr)
- }
- if t.Failed() {
- t.FailNow()
+ _, ch, err := neigh.entry(entry.Addr, "", linkRes, func(r LinkResolutionResult) {
+ if diff := cmp.Diff(LinkResolutionResult{Success: false}, r); diff != "" {
+ t.Fatalf("got link resolution result mismatch (-want +got):\n%s", diff)
}
})
- if err != tcpip.ErrWouldBlock {
- t.Fatalf("got neigh.entry(%s, '', _, _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock)
+ if _, ok := err.(*tcpip.ErrWouldBlock); !ok {
+ t.Fatalf("got neigh.entry(%s, '', _, _, nil) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{})
}
waitFor := config.RetransmitTimer * time.Duration(config.MaxMulticastProbes)
clock.Advance(waitFor)
@@ -1713,13 +1695,13 @@ func TestNeighborCacheRetryResolution(t *testing.T) {
// Retry address resolution with a working link.
linkRes.dropReplies = false
{
- incompleteEntry, ch, err := neigh.entry(entry.Addr, "", linkRes, func(linkAddr tcpip.LinkAddress, ok bool) {
- if linkAddr != entry.LinkAddr {
- t.Fatalf("got linkAddr = %s, want = %s", linkAddr, entry.LinkAddr)
+ incompleteEntry, ch, err := neigh.entry(entry.Addr, "", linkRes, func(r LinkResolutionResult) {
+ if diff := cmp.Diff(LinkResolutionResult{LinkAddress: entry.LinkAddr, Success: true}, r); diff != "" {
+ t.Fatalf("got link resolution result mismatch (-want +got):\n%s", diff)
}
})
- if err != tcpip.ErrWouldBlock {
- t.Fatalf("got neigh.entry(%s, '', _, _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock)
+ if _, ok := err.(*tcpip.ErrWouldBlock); !ok {
+ t.Fatalf("got neigh.entry(%s, '', _, _, nil) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{})
}
if incompleteEntry.State != Incomplete {
t.Fatalf("got entry.State = %s, want = %s", incompleteEntry.State, Incomplete)
@@ -1772,16 +1754,13 @@ func BenchmarkCacheClear(b *testing.B) {
b.Fatalf("store.entry(%d) not found", i)
}
- _, ch, err := neigh.entry(entry.Addr, "", linkRes, func(linkAddr tcpip.LinkAddress, ok bool) {
- if !ok {
- b.Fatal("expected successful address resolution")
- }
- if linkAddr != entry.LinkAddr {
- b.Fatalf("got linkAddr = %s, want = %s", linkAddr, entry.LinkAddr)
+ _, ch, err := neigh.entry(entry.Addr, "", linkRes, func(r LinkResolutionResult) {
+ if diff := cmp.Diff(LinkResolutionResult{LinkAddress: entry.LinkAddr, Success: true}, r); diff != "" {
+ b.Fatalf("got link resolution result mismatch (-want +got):\n%s", diff)
}
})
- if err != tcpip.ErrWouldBlock {
- b.Fatalf("got neigh.entry(%s, '', _, _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock)
+ if _, ok := err.(*tcpip.ErrWouldBlock); !ok {
+ b.Fatalf("got neigh.entry(%s, '', _, _, nil) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{})
}
select {
diff --git a/pkg/tcpip/stack/neighbor_entry.go b/pkg/tcpip/stack/neighbor_entry.go
index 75afb3001..53ac9bb6e 100644
--- a/pkg/tcpip/stack/neighbor_entry.go
+++ b/pkg/tcpip/stack/neighbor_entry.go
@@ -96,7 +96,7 @@ type neighborEntry struct {
done chan struct{}
// onResolve is called with the result of address resolution.
- onResolve []func(tcpip.LinkAddress, bool)
+ onResolve []func(LinkResolutionResult)
isRouter bool
job *tcpip.Job
@@ -143,13 +143,22 @@ func newStaticNeighborEntry(nic *NIC, addr tcpip.Address, linkAddr tcpip.LinkAdd
//
// Precondition: e.mu MUST be locked.
func (e *neighborEntry) notifyCompletionLocked(succeeded bool) {
+ res := LinkResolutionResult{LinkAddress: e.neigh.LinkAddr, Success: succeeded}
for _, callback := range e.onResolve {
- callback(e.neigh.LinkAddr, succeeded)
+ callback(res)
}
e.onResolve = nil
if ch := e.done; ch != nil {
close(ch)
e.done = nil
+ // Dequeue the pending packets in a new goroutine to not hold up the current
+ // goroutine as writing packets may be a costly operation.
+ //
+ // At the time of writing, when writing packets, a neighbor's link address
+ // is resolved (which ends up obtaining the entry's lock) while holding the
+ // link resolution queue's lock. Dequeuing packets in a new goroutine avoids
+ // a lock ordering violation.
+ go e.nic.linkResQueue.dequeue(ch, e.neigh.LinkAddr, succeeded)
}
}
diff --git a/pkg/tcpip/stack/neighbor_entry_test.go b/pkg/tcpip/stack/neighbor_entry_test.go
index ec34ffa5a..140b8ca00 100644
--- a/pkg/tcpip/stack/neighbor_entry_test.go
+++ b/pkg/tcpip/stack/neighbor_entry_test.go
@@ -193,7 +193,7 @@ func (p entryTestProbeInfo) String() string {
// LinkAddressRequest sends a request for the LinkAddress of addr. Broadcasts
// to the local network if linkAddr is the zero value.
-func (r *entryTestLinkResolver) LinkAddressRequest(targetAddr, localAddr tcpip.Address, linkAddr tcpip.LinkAddress, _ NetworkInterface) *tcpip.Error {
+func (r *entryTestLinkResolver) LinkAddressRequest(targetAddr, localAddr tcpip.Address, linkAddr tcpip.LinkAddress, _ NetworkInterface) tcpip.Error {
p := entryTestProbeInfo{
RemoteAddress: targetAddr,
RemoteLinkAddress: linkAddr,
@@ -266,16 +266,16 @@ func TestEntryInitiallyUnknown(t *testing.T) {
// No probes should have been sent.
linkRes.mu.Lock()
- diff := cmp.Diff(linkRes.probes, []entryTestProbeInfo(nil))
+ diff := cmp.Diff([]entryTestProbeInfo(nil), linkRes.probes)
linkRes.mu.Unlock()
if diff != "" {
- t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ t.Fatalf("link address resolver probes mismatch (-want, +got):\n%s", diff)
}
// No events should have been dispatched.
nudDisp.mu.Lock()
- if diff := cmp.Diff(nudDisp.events, []testEntryEventInfo(nil)); diff != "" {
- t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ if diff := cmp.Diff([]testEntryEventInfo(nil), nudDisp.events); diff != "" {
+ t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff)
}
nudDisp.mu.Unlock()
}
@@ -299,16 +299,16 @@ func TestEntryUnknownToUnknownWhenConfirmationWithUnknownAddress(t *testing.T) {
// No probes should have been sent.
linkRes.mu.Lock()
- diff := cmp.Diff(linkRes.probes, []entryTestProbeInfo(nil))
+ diff := cmp.Diff([]entryTestProbeInfo(nil), linkRes.probes)
linkRes.mu.Unlock()
if diff != "" {
- t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ t.Fatalf("link address resolver probes mismatch (-want, +got):\n%s", diff)
}
// No events should have been dispatched.
nudDisp.mu.Lock()
- if diff := cmp.Diff(nudDisp.events, []testEntryEventInfo(nil)); diff != "" {
- t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ if diff := cmp.Diff([]testEntryEventInfo(nil), nudDisp.events); diff != "" {
+ t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff)
}
nudDisp.mu.Unlock()
}
@@ -333,10 +333,10 @@ func TestEntryUnknownToIncomplete(t *testing.T) {
},
}
linkRes.mu.Lock()
- diff := cmp.Diff(linkRes.probes, wantProbes)
+ diff := cmp.Diff(wantProbes, linkRes.probes)
linkRes.mu.Unlock()
if diff != "" {
- t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ t.Fatalf("link address resolver probes mismatch (-want, +got):\n%s", diff)
}
wantEvents := []testEntryEventInfo{
@@ -352,10 +352,10 @@ func TestEntryUnknownToIncomplete(t *testing.T) {
}
{
nudDisp.mu.Lock()
- diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...)
+ diff := cmp.Diff(wantEvents, nudDisp.events, eventDiffOpts()...)
nudDisp.mu.Unlock()
if diff != "" {
- t.Fatalf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ t.Fatalf("nud dispatcher events mismatch (-want, +got):\n%s", diff)
}
}
}
@@ -374,10 +374,10 @@ func TestEntryUnknownToStale(t *testing.T) {
// No probes should have been sent.
runImmediatelyScheduledJobs(clock)
linkRes.mu.Lock()
- diff := cmp.Diff(linkRes.probes, []entryTestProbeInfo(nil))
+ diff := cmp.Diff([]entryTestProbeInfo(nil), linkRes.probes)
linkRes.mu.Unlock()
if diff != "" {
- t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ t.Fatalf("link address resolver probes mismatch (-want, +got):\n%s", diff)
}
wantEvents := []testEntryEventInfo{
@@ -392,8 +392,8 @@ func TestEntryUnknownToStale(t *testing.T) {
},
}
nudDisp.mu.Lock()
- if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" {
- t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ if diff := cmp.Diff(wantEvents, nudDisp.events, eventDiffOpts()...); diff != "" {
+ t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff)
}
nudDisp.mu.Unlock()
}
@@ -427,11 +427,11 @@ func TestEntryIncompleteToIncompleteDoesNotChangeUpdatedAt(t *testing.T) {
},
}
linkRes.mu.Lock()
- diff := cmp.Diff(linkRes.probes, wantProbes)
+ diff := cmp.Diff(wantProbes, linkRes.probes)
linkRes.probes = nil
linkRes.mu.Unlock()
if diff != "" {
- t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ t.Fatalf("link address resolver probes mismatch (-want, +got):\n%s", diff)
}
e.mu.Lock()
@@ -453,10 +453,10 @@ func TestEntryIncompleteToIncompleteDoesNotChangeUpdatedAt(t *testing.T) {
},
}
linkRes.mu.Lock()
- diff := cmp.Diff(linkRes.probes, wantProbes)
+ diff := cmp.Diff(wantProbes, linkRes.probes)
linkRes.mu.Unlock()
if diff != "" {
- t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ t.Fatalf("link address resolver probes mismatch (-want, +got):\n%s", diff)
}
}
@@ -483,8 +483,8 @@ func TestEntryIncompleteToIncompleteDoesNotChangeUpdatedAt(t *testing.T) {
},
}
nudDisp.mu.Lock()
- if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" {
- t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ if diff := cmp.Diff(wantEvents, nudDisp.events, eventDiffOpts()...); diff != "" {
+ t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff)
}
nudDisp.mu.Unlock()
@@ -515,10 +515,10 @@ func TestEntryIncompleteToReachable(t *testing.T) {
},
}
linkRes.mu.Lock()
- diff := cmp.Diff(linkRes.probes, wantProbes)
+ diff := cmp.Diff(wantProbes, linkRes.probes)
linkRes.mu.Unlock()
if diff != "" {
- t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ t.Fatalf("link address resolver probes mismatch (-want, +got):\n%s", diff)
}
e.mu.Lock()
@@ -553,8 +553,8 @@ func TestEntryIncompleteToReachable(t *testing.T) {
},
}
nudDisp.mu.Lock()
- if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" {
- t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ if diff := cmp.Diff(wantEvents, nudDisp.events, eventDiffOpts()...); diff != "" {
+ t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff)
}
nudDisp.mu.Unlock()
}
@@ -579,10 +579,10 @@ func TestEntryIncompleteToReachableWithRouterFlag(t *testing.T) {
},
}
linkRes.mu.Lock()
- diff := cmp.Diff(linkRes.probes, wantProbes)
+ diff := cmp.Diff(wantProbes, linkRes.probes)
linkRes.mu.Unlock()
if diff != "" {
- t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ t.Fatalf("link address resolver probes mismatch (-want, +got):\n%s", diff)
}
e.mu.Lock()
@@ -620,8 +620,8 @@ func TestEntryIncompleteToReachableWithRouterFlag(t *testing.T) {
},
}
nudDisp.mu.Lock()
- if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" {
- t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ if diff := cmp.Diff(wantEvents, nudDisp.events, eventDiffOpts()...); diff != "" {
+ t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff)
}
nudDisp.mu.Unlock()
}
@@ -646,10 +646,10 @@ func TestEntryIncompleteToStaleWhenUnsolicitedConfirmation(t *testing.T) {
},
}
linkRes.mu.Lock()
- diff := cmp.Diff(linkRes.probes, wantProbes)
+ diff := cmp.Diff(wantProbes, linkRes.probes)
linkRes.mu.Unlock()
if diff != "" {
- t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ t.Fatalf("link address resolver probes mismatch (-want, +got):\n%s", diff)
}
e.mu.Lock()
@@ -684,8 +684,8 @@ func TestEntryIncompleteToStaleWhenUnsolicitedConfirmation(t *testing.T) {
},
}
nudDisp.mu.Lock()
- if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" {
- t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ if diff := cmp.Diff(wantEvents, nudDisp.events, eventDiffOpts()...); diff != "" {
+ t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff)
}
nudDisp.mu.Unlock()
}
@@ -710,10 +710,10 @@ func TestEntryIncompleteToStaleWhenProbe(t *testing.T) {
},
}
linkRes.mu.Lock()
- diff := cmp.Diff(linkRes.probes, wantProbes)
+ diff := cmp.Diff(wantProbes, linkRes.probes)
linkRes.mu.Unlock()
if diff != "" {
- t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ t.Fatalf("link address resolver probes mismatch (-want, +got):\n%s", diff)
}
e.mu.Lock()
@@ -744,8 +744,8 @@ func TestEntryIncompleteToStaleWhenProbe(t *testing.T) {
},
}
nudDisp.mu.Lock()
- if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" {
- t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ if diff := cmp.Diff(wantEvents, nudDisp.events, eventDiffOpts()...); diff != "" {
+ t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff)
}
nudDisp.mu.Unlock()
}
@@ -785,10 +785,10 @@ func TestEntryIncompleteToFailed(t *testing.T) {
},
}
linkRes.mu.Lock()
- diff := cmp.Diff(linkRes.probes, wantProbes)
+ diff := cmp.Diff(wantProbes, linkRes.probes)
linkRes.mu.Unlock()
if diff != "" {
- t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ t.Fatalf("link address resolver probes mismatch (-want, +got):\n%s", diff)
}
wantEvents := []testEntryEventInfo{
@@ -812,8 +812,8 @@ func TestEntryIncompleteToFailed(t *testing.T) {
},
}
nudDisp.mu.Lock()
- if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" {
- t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ if diff := cmp.Diff(wantEvents, nudDisp.events, eventDiffOpts()...); diff != "" {
+ t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff)
}
nudDisp.mu.Unlock()
@@ -850,10 +850,10 @@ func TestEntryStaysReachableWhenConfirmationWithRouterFlag(t *testing.T) {
},
}
linkRes.mu.Lock()
- diff := cmp.Diff(linkRes.probes, wantProbes)
+ diff := cmp.Diff(wantProbes, linkRes.probes)
linkRes.mu.Unlock()
if diff != "" {
- t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ t.Fatalf("link address resolver probes mismatch (-want, +got):\n%s", diff)
}
e.mu.Lock()
@@ -903,8 +903,8 @@ func TestEntryStaysReachableWhenConfirmationWithRouterFlag(t *testing.T) {
},
}
nudDisp.mu.Lock()
- if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" {
- t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ if diff := cmp.Diff(wantEvents, nudDisp.events, eventDiffOpts()...); diff != "" {
+ t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff)
}
nudDisp.mu.Unlock()
@@ -932,10 +932,10 @@ func TestEntryStaysReachableWhenProbeWithSameAddress(t *testing.T) {
},
}
linkRes.mu.Lock()
- diff := cmp.Diff(linkRes.probes, wantProbes)
+ diff := cmp.Diff(wantProbes, linkRes.probes)
linkRes.mu.Unlock()
if diff != "" {
- t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ t.Fatalf("link address resolver probes mismatch (-want, +got):\n%s", diff)
}
e.mu.Lock()
@@ -977,8 +977,8 @@ func TestEntryStaysReachableWhenProbeWithSameAddress(t *testing.T) {
},
}
nudDisp.mu.Lock()
- if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" {
- t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ if diff := cmp.Diff(wantEvents, nudDisp.events, eventDiffOpts()...); diff != "" {
+ t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff)
}
nudDisp.mu.Unlock()
}
@@ -1005,10 +1005,10 @@ func TestEntryReachableToStaleWhenTimeout(t *testing.T) {
},
}
linkRes.mu.Lock()
- diff := cmp.Diff(linkRes.probes, wantProbes)
+ diff := cmp.Diff(wantProbes, linkRes.probes)
linkRes.mu.Unlock()
if diff != "" {
- t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ t.Fatalf("link address resolver probes mismatch (-want, +got):\n%s", diff)
}
e.mu.Lock()
@@ -1054,8 +1054,8 @@ func TestEntryReachableToStaleWhenTimeout(t *testing.T) {
},
}
nudDisp.mu.Lock()
- if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" {
- t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ if diff := cmp.Diff(wantEvents, nudDisp.events, eventDiffOpts()...); diff != "" {
+ t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff)
}
nudDisp.mu.Unlock()
@@ -1083,10 +1083,10 @@ func TestEntryReachableToStaleWhenProbeWithDifferentAddress(t *testing.T) {
},
}
linkRes.mu.Lock()
- diff := cmp.Diff(linkRes.probes, wantProbes)
+ diff := cmp.Diff(wantProbes, linkRes.probes)
linkRes.mu.Unlock()
if diff != "" {
- t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ t.Fatalf("link address resolver probes mismatch (-want, +got):\n%s", diff)
}
e.mu.Lock()
@@ -1134,8 +1134,8 @@ func TestEntryReachableToStaleWhenProbeWithDifferentAddress(t *testing.T) {
},
}
nudDisp.mu.Lock()
- if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" {
- t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ if diff := cmp.Diff(wantEvents, nudDisp.events, eventDiffOpts()...); diff != "" {
+ t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff)
}
nudDisp.mu.Unlock()
}
@@ -1157,10 +1157,10 @@ func TestEntryReachableToStaleWhenConfirmationWithDifferentAddress(t *testing.T)
},
}
linkRes.mu.Lock()
- diff := cmp.Diff(linkRes.probes, wantProbes)
+ diff := cmp.Diff(wantProbes, linkRes.probes)
linkRes.mu.Unlock()
if diff != "" {
- t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ t.Fatalf("link address resolver probes mismatch (-want, +got):\n%s", diff)
}
e.mu.Lock()
@@ -1212,8 +1212,8 @@ func TestEntryReachableToStaleWhenConfirmationWithDifferentAddress(t *testing.T)
},
}
nudDisp.mu.Lock()
- if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" {
- t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ if diff := cmp.Diff(wantEvents, nudDisp.events, eventDiffOpts()...); diff != "" {
+ t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff)
}
nudDisp.mu.Unlock()
}
@@ -1235,10 +1235,10 @@ func TestEntryReachableToStaleWhenConfirmationWithDifferentAddressAndOverride(t
},
}
linkRes.mu.Lock()
- diff := cmp.Diff(linkRes.probes, wantProbes)
+ diff := cmp.Diff(wantProbes, linkRes.probes)
linkRes.mu.Unlock()
if diff != "" {
- t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ t.Fatalf("link address resolver probes mismatch (-want, +got):\n%s", diff)
}
e.mu.Lock()
@@ -1290,8 +1290,8 @@ func TestEntryReachableToStaleWhenConfirmationWithDifferentAddressAndOverride(t
},
}
nudDisp.mu.Lock()
- if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" {
- t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ if diff := cmp.Diff(wantEvents, nudDisp.events, eventDiffOpts()...); diff != "" {
+ t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff)
}
nudDisp.mu.Unlock()
}
@@ -1313,10 +1313,10 @@ func TestEntryStaysStaleWhenProbeWithSameAddress(t *testing.T) {
},
}
linkRes.mu.Lock()
- diff := cmp.Diff(linkRes.probes, wantProbes)
+ diff := cmp.Diff(wantProbes, linkRes.probes)
linkRes.mu.Unlock()
if diff != "" {
- t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ t.Fatalf("link address resolver probes mismatch (-want, +got):\n%s", diff)
}
e.mu.Lock()
@@ -1358,8 +1358,8 @@ func TestEntryStaysStaleWhenProbeWithSameAddress(t *testing.T) {
},
}
nudDisp.mu.Lock()
- if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" {
- t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ if diff := cmp.Diff(wantEvents, nudDisp.events, eventDiffOpts()...); diff != "" {
+ t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff)
}
nudDisp.mu.Unlock()
}
@@ -1381,10 +1381,10 @@ func TestEntryStaleToReachableWhenSolicitedOverrideConfirmation(t *testing.T) {
},
}
linkRes.mu.Lock()
- diff := cmp.Diff(linkRes.probes, wantProbes)
+ diff := cmp.Diff(wantProbes, linkRes.probes)
linkRes.mu.Unlock()
if diff != "" {
- t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ t.Fatalf("link address resolver probes mismatch (-want, +got):\n%s", diff)
}
e.mu.Lock()
@@ -1439,8 +1439,8 @@ func TestEntryStaleToReachableWhenSolicitedOverrideConfirmation(t *testing.T) {
},
}
nudDisp.mu.Lock()
- if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" {
- t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ if diff := cmp.Diff(wantEvents, nudDisp.events, eventDiffOpts()...); diff != "" {
+ t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff)
}
nudDisp.mu.Unlock()
}
@@ -1462,10 +1462,10 @@ func TestEntryStaleToReachableWhenSolicitedConfirmationWithoutAddress(t *testing
},
}
linkRes.mu.Lock()
- diff := cmp.Diff(linkRes.probes, wantProbes)
+ diff := cmp.Diff(wantProbes, linkRes.probes)
linkRes.mu.Unlock()
if diff != "" {
- t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ t.Fatalf("link address resolver probes mismatch (-want, +got):\n%s", diff)
}
e.mu.Lock()
@@ -1520,8 +1520,8 @@ func TestEntryStaleToReachableWhenSolicitedConfirmationWithoutAddress(t *testing
},
}
nudDisp.mu.Lock()
- if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" {
- t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ if diff := cmp.Diff(wantEvents, nudDisp.events, eventDiffOpts()...); diff != "" {
+ t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff)
}
nudDisp.mu.Unlock()
}
@@ -1543,10 +1543,10 @@ func TestEntryStaleToStaleWhenOverrideConfirmation(t *testing.T) {
},
}
linkRes.mu.Lock()
- diff := cmp.Diff(linkRes.probes, wantProbes)
+ diff := cmp.Diff(wantProbes, linkRes.probes)
linkRes.mu.Unlock()
if diff != "" {
- t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ t.Fatalf("link address resolver probes mismatch (-want, +got):\n%s", diff)
}
e.mu.Lock()
@@ -1601,8 +1601,8 @@ func TestEntryStaleToStaleWhenOverrideConfirmation(t *testing.T) {
},
}
nudDisp.mu.Lock()
- if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" {
- t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ if diff := cmp.Diff(wantEvents, nudDisp.events, eventDiffOpts()...); diff != "" {
+ t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff)
}
nudDisp.mu.Unlock()
}
@@ -1624,10 +1624,10 @@ func TestEntryStaleToStaleWhenProbeUpdateAddress(t *testing.T) {
},
}
linkRes.mu.Lock()
- diff := cmp.Diff(linkRes.probes, wantProbes)
+ diff := cmp.Diff(wantProbes, linkRes.probes)
linkRes.mu.Unlock()
if diff != "" {
- t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ t.Fatalf("link address resolver probes mismatch (-want, +got):\n%s", diff)
}
e.mu.Lock()
@@ -1678,8 +1678,8 @@ func TestEntryStaleToStaleWhenProbeUpdateAddress(t *testing.T) {
},
}
nudDisp.mu.Lock()
- if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" {
- t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ if diff := cmp.Diff(wantEvents, nudDisp.events, eventDiffOpts()...); diff != "" {
+ t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff)
}
nudDisp.mu.Unlock()
}
@@ -1701,10 +1701,10 @@ func TestEntryStaleToDelay(t *testing.T) {
},
}
linkRes.mu.Lock()
- diff := cmp.Diff(linkRes.probes, wantProbes)
+ diff := cmp.Diff(wantProbes, linkRes.probes)
linkRes.mu.Unlock()
if diff != "" {
- t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ t.Fatalf("link address resolver probes mismatch (-want, +got):\n%s", diff)
}
e.mu.Lock()
@@ -1752,8 +1752,8 @@ func TestEntryStaleToDelay(t *testing.T) {
},
}
nudDisp.mu.Lock()
- if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" {
- t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ if diff := cmp.Diff(wantEvents, nudDisp.events, eventDiffOpts()...); diff != "" {
+ t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff)
}
nudDisp.mu.Unlock()
}
@@ -1780,10 +1780,10 @@ func TestEntryDelayToReachableWhenUpperLevelConfirmation(t *testing.T) {
},
}
linkRes.mu.Lock()
- diff := cmp.Diff(linkRes.probes, wantProbes)
+ diff := cmp.Diff(wantProbes, linkRes.probes)
linkRes.mu.Unlock()
if diff != "" {
- t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ t.Fatalf("link address resolver probes mismatch (-want, +got):\n%s", diff)
}
e.mu.Lock()
@@ -1851,8 +1851,8 @@ func TestEntryDelayToReachableWhenUpperLevelConfirmation(t *testing.T) {
},
}
nudDisp.mu.Lock()
- if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" {
- t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ if diff := cmp.Diff(wantEvents, nudDisp.events, eventDiffOpts()...); diff != "" {
+ t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff)
}
nudDisp.mu.Unlock()
}
@@ -1880,10 +1880,10 @@ func TestEntryDelayToReachableWhenSolicitedOverrideConfirmation(t *testing.T) {
},
}
linkRes.mu.Lock()
- diff := cmp.Diff(linkRes.probes, wantProbes)
+ diff := cmp.Diff(wantProbes, linkRes.probes)
linkRes.mu.Unlock()
if diff != "" {
- t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ t.Fatalf("link address resolver probes mismatch (-want, +got):\n%s", diff)
}
e.mu.Lock()
@@ -1958,8 +1958,8 @@ func TestEntryDelayToReachableWhenSolicitedOverrideConfirmation(t *testing.T) {
},
}
nudDisp.mu.Lock()
- if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" {
- t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ if diff := cmp.Diff(wantEvents, nudDisp.events, eventDiffOpts()...); diff != "" {
+ t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff)
}
nudDisp.mu.Unlock()
}
@@ -1987,10 +1987,10 @@ func TestEntryDelayToReachableWhenSolicitedConfirmationWithoutAddress(t *testing
},
}
linkRes.mu.Lock()
- diff := cmp.Diff(linkRes.probes, wantProbes)
+ diff := cmp.Diff(wantProbes, linkRes.probes)
linkRes.mu.Unlock()
if diff != "" {
- t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ t.Fatalf("link address resolver probes mismatch (-want, +got):\n%s", diff)
}
e.mu.Lock()
@@ -2065,8 +2065,8 @@ func TestEntryDelayToReachableWhenSolicitedConfirmationWithoutAddress(t *testing
},
}
nudDisp.mu.Lock()
- if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" {
- t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ if diff := cmp.Diff(wantEvents, nudDisp.events, eventDiffOpts()...); diff != "" {
+ t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff)
}
nudDisp.mu.Unlock()
}
@@ -2088,10 +2088,10 @@ func TestEntryStaysDelayWhenOverrideConfirmationWithSameAddress(t *testing.T) {
},
}
linkRes.mu.Lock()
- diff := cmp.Diff(linkRes.probes, wantProbes)
+ diff := cmp.Diff(wantProbes, linkRes.probes)
linkRes.mu.Unlock()
if diff != "" {
- t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ t.Fatalf("link address resolver probes mismatch (-want, +got):\n%s", diff)
}
e.mu.Lock()
@@ -2147,8 +2147,8 @@ func TestEntryStaysDelayWhenOverrideConfirmationWithSameAddress(t *testing.T) {
},
}
nudDisp.mu.Lock()
- if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" {
- t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ if diff := cmp.Diff(wantEvents, nudDisp.events, eventDiffOpts()...); diff != "" {
+ t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff)
}
nudDisp.mu.Unlock()
}
@@ -2170,10 +2170,10 @@ func TestEntryDelayToStaleWhenProbeWithDifferentAddress(t *testing.T) {
},
}
linkRes.mu.Lock()
- diff := cmp.Diff(linkRes.probes, wantProbes)
+ diff := cmp.Diff(wantProbes, linkRes.probes)
linkRes.mu.Unlock()
if diff != "" {
- t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ t.Fatalf("link address resolver probes mismatch (-want, +got):\n%s", diff)
}
e.mu.Lock()
@@ -2231,8 +2231,8 @@ func TestEntryDelayToStaleWhenProbeWithDifferentAddress(t *testing.T) {
},
}
nudDisp.mu.Lock()
- if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" {
- t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ if diff := cmp.Diff(wantEvents, nudDisp.events, eventDiffOpts()...); diff != "" {
+ t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff)
}
nudDisp.mu.Unlock()
}
@@ -2254,10 +2254,10 @@ func TestEntryDelayToStaleWhenConfirmationWithDifferentAddress(t *testing.T) {
},
}
linkRes.mu.Lock()
- diff := cmp.Diff(linkRes.probes, wantProbes)
+ diff := cmp.Diff(wantProbes, linkRes.probes)
linkRes.mu.Unlock()
if diff != "" {
- t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ t.Fatalf("link address resolver probes mismatch (-want, +got):\n%s", diff)
}
e.mu.Lock()
@@ -2319,8 +2319,8 @@ func TestEntryDelayToStaleWhenConfirmationWithDifferentAddress(t *testing.T) {
},
}
nudDisp.mu.Lock()
- if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" {
- t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ if diff := cmp.Diff(wantEvents, nudDisp.events, eventDiffOpts()...); diff != "" {
+ t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff)
}
nudDisp.mu.Unlock()
}
@@ -2343,11 +2343,11 @@ func TestEntryDelayToProbe(t *testing.T) {
},
}
linkRes.mu.Lock()
- diff := cmp.Diff(linkRes.probes, wantProbes)
+ diff := cmp.Diff(wantProbes, linkRes.probes)
linkRes.probes = nil
linkRes.mu.Unlock()
if diff != "" {
- t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ t.Fatalf("link address resolver probes mismatch (-want, +got):\n%s", diff)
}
}
@@ -2372,10 +2372,10 @@ func TestEntryDelayToProbe(t *testing.T) {
},
}
linkRes.mu.Lock()
- diff := cmp.Diff(linkRes.probes, wantProbes)
+ diff := cmp.Diff(wantProbes, linkRes.probes)
linkRes.mu.Unlock()
if diff != "" {
- t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ t.Fatalf("link address resolver probes mismatch (-want, +got):\n%s", diff)
}
}
@@ -2418,8 +2418,8 @@ func TestEntryDelayToProbe(t *testing.T) {
},
}
nudDisp.mu.Lock()
- if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" {
- t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ if diff := cmp.Diff(wantEvents, nudDisp.events, eventDiffOpts()...); diff != "" {
+ t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff)
}
nudDisp.mu.Unlock()
@@ -2448,11 +2448,11 @@ func TestEntryProbeToStaleWhenProbeWithDifferentAddress(t *testing.T) {
},
}
linkRes.mu.Lock()
- diff := cmp.Diff(linkRes.probes, wantProbes)
+ diff := cmp.Diff(wantProbes, linkRes.probes)
linkRes.probes = nil
linkRes.mu.Unlock()
if diff != "" {
- t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ t.Fatalf("link address resolver probes mismatch (-want, +got):\n%s", diff)
}
}
@@ -2474,10 +2474,10 @@ func TestEntryProbeToStaleWhenProbeWithDifferentAddress(t *testing.T) {
},
}
linkRes.mu.Lock()
- diff := cmp.Diff(linkRes.probes, wantProbes)
+ diff := cmp.Diff(wantProbes, linkRes.probes)
linkRes.mu.Unlock()
if diff != "" {
- t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ t.Fatalf("link address resolver probes mismatch (-want, +got):\n%s", diff)
}
}
@@ -2539,8 +2539,8 @@ func TestEntryProbeToStaleWhenProbeWithDifferentAddress(t *testing.T) {
},
}
nudDisp.mu.Lock()
- if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" {
- t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ if diff := cmp.Diff(wantEvents, nudDisp.events, eventDiffOpts()...); diff != "" {
+ t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff)
}
nudDisp.mu.Unlock()
}
@@ -2563,11 +2563,11 @@ func TestEntryProbeToStaleWhenConfirmationWithDifferentAddress(t *testing.T) {
},
}
linkRes.mu.Lock()
- diff := cmp.Diff(linkRes.probes, wantProbes)
+ diff := cmp.Diff(wantProbes, linkRes.probes)
linkRes.probes = nil
linkRes.mu.Unlock()
if diff != "" {
- t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ t.Fatalf("link address resolver probes mismatch (-want, +got):\n%s", diff)
}
}
@@ -2589,10 +2589,10 @@ func TestEntryProbeToStaleWhenConfirmationWithDifferentAddress(t *testing.T) {
},
}
linkRes.mu.Lock()
- diff := cmp.Diff(linkRes.probes, wantProbes)
+ diff := cmp.Diff(wantProbes, linkRes.probes)
linkRes.mu.Unlock()
if diff != "" {
- t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ t.Fatalf("link address resolver probes mismatch (-want, +got):\n%s", diff)
}
}
@@ -2658,8 +2658,8 @@ func TestEntryProbeToStaleWhenConfirmationWithDifferentAddress(t *testing.T) {
},
}
nudDisp.mu.Lock()
- if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" {
- t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ if diff := cmp.Diff(wantEvents, nudDisp.events, eventDiffOpts()...); diff != "" {
+ t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff)
}
nudDisp.mu.Unlock()
}
@@ -2682,11 +2682,11 @@ func TestEntryStaysProbeWhenOverrideConfirmationWithSameAddress(t *testing.T) {
},
}
linkRes.mu.Lock()
- diff := cmp.Diff(linkRes.probes, wantProbes)
+ diff := cmp.Diff(wantProbes, linkRes.probes)
linkRes.probes = nil
linkRes.mu.Unlock()
if diff != "" {
- t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ t.Fatalf("link address resolver probes mismatch (-want, +got):\n%s", diff)
}
}
@@ -2709,10 +2709,10 @@ func TestEntryStaysProbeWhenOverrideConfirmationWithSameAddress(t *testing.T) {
},
}
linkRes.mu.Lock()
- diff := cmp.Diff(linkRes.probes, wantProbes)
+ diff := cmp.Diff(wantProbes, linkRes.probes)
linkRes.mu.Unlock()
if diff != "" {
- t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ t.Fatalf("link address resolver probes mismatch (-want, +got):\n%s", diff)
}
}
@@ -2772,8 +2772,8 @@ func TestEntryStaysProbeWhenOverrideConfirmationWithSameAddress(t *testing.T) {
},
}
nudDisp.mu.Lock()
- if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" {
- t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ if diff := cmp.Diff(wantEvents, nudDisp.events, eventDiffOpts()...); diff != "" {
+ t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff)
}
nudDisp.mu.Unlock()
}
@@ -2806,10 +2806,10 @@ func TestEntryUnknownToStaleToProbeToReachable(t *testing.T) {
},
}
linkRes.mu.Lock()
- diff := cmp.Diff(linkRes.probes, wantProbes)
+ diff := cmp.Diff(wantProbes, linkRes.probes)
linkRes.mu.Unlock()
if diff != "" {
- t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ t.Fatalf("link address resolver probes mismatch (-want, +got):\n%s", diff)
}
e.mu.Lock()
@@ -2878,8 +2878,8 @@ func TestEntryUnknownToStaleToProbeToReachable(t *testing.T) {
},
}
nudDisp.mu.Lock()
- if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" {
- t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ if diff := cmp.Diff(wantEvents, nudDisp.events, eventDiffOpts()...); diff != "" {
+ t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff)
}
nudDisp.mu.Unlock()
}
@@ -2907,11 +2907,11 @@ func TestEntryProbeToReachableWhenSolicitedOverrideConfirmation(t *testing.T) {
},
}
linkRes.mu.Lock()
- diff := cmp.Diff(linkRes.probes, wantProbes)
+ diff := cmp.Diff(wantProbes, linkRes.probes)
linkRes.probes = nil
linkRes.mu.Unlock()
if diff != "" {
- t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ t.Fatalf("link address resolver probes mismatch (-want, +got):\n%s", diff)
}
}
@@ -2933,10 +2933,10 @@ func TestEntryProbeToReachableWhenSolicitedOverrideConfirmation(t *testing.T) {
},
}
linkRes.mu.Lock()
- diff := cmp.Diff(linkRes.probes, wantProbes)
+ diff := cmp.Diff(wantProbes, linkRes.probes)
linkRes.mu.Unlock()
if diff != "" {
- t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ t.Fatalf("link address resolver probes mismatch (-want, +got):\n%s", diff)
}
}
@@ -3015,8 +3015,8 @@ func TestEntryProbeToReachableWhenSolicitedOverrideConfirmation(t *testing.T) {
},
}
nudDisp.mu.Lock()
- if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" {
- t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ if diff := cmp.Diff(wantEvents, nudDisp.events, eventDiffOpts()...); diff != "" {
+ t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff)
}
nudDisp.mu.Unlock()
}
@@ -3044,11 +3044,11 @@ func TestEntryProbeToReachableWhenSolicitedConfirmationWithSameAddress(t *testin
},
}
linkRes.mu.Lock()
- diff := cmp.Diff(linkRes.probes, wantProbes)
+ diff := cmp.Diff(wantProbes, linkRes.probes)
linkRes.probes = nil
linkRes.mu.Unlock()
if diff != "" {
- t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ t.Fatalf("link address resolver probes mismatch (-want, +got):\n%s", diff)
}
}
@@ -3070,10 +3070,10 @@ func TestEntryProbeToReachableWhenSolicitedConfirmationWithSameAddress(t *testin
},
}
linkRes.mu.Lock()
- diff := cmp.Diff(linkRes.probes, wantProbes)
+ diff := cmp.Diff(wantProbes, linkRes.probes)
linkRes.mu.Unlock()
if diff != "" {
- t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ t.Fatalf("link address resolver probes mismatch (-want, +got):\n%s", diff)
}
}
@@ -3149,8 +3149,8 @@ func TestEntryProbeToReachableWhenSolicitedConfirmationWithSameAddress(t *testin
},
}
nudDisp.mu.Lock()
- if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" {
- t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ if diff := cmp.Diff(wantEvents, nudDisp.events, eventDiffOpts()...); diff != "" {
+ t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff)
}
nudDisp.mu.Unlock()
}
@@ -3178,11 +3178,11 @@ func TestEntryProbeToReachableWhenSolicitedConfirmationWithoutAddress(t *testing
},
}
linkRes.mu.Lock()
- diff := cmp.Diff(linkRes.probes, wantProbes)
+ diff := cmp.Diff(wantProbes, linkRes.probes)
linkRes.probes = nil
linkRes.mu.Unlock()
if diff != "" {
- t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ t.Fatalf("link address resolver probes mismatch (-want, +got):\n%s", diff)
}
}
@@ -3204,10 +3204,10 @@ func TestEntryProbeToReachableWhenSolicitedConfirmationWithoutAddress(t *testing
},
}
linkRes.mu.Lock()
- diff := cmp.Diff(linkRes.probes, wantProbes)
+ diff := cmp.Diff(wantProbes, linkRes.probes)
linkRes.mu.Unlock()
if diff != "" {
- t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ t.Fatalf("link address resolver probes mismatch (-want, +got):\n%s", diff)
}
}
@@ -3283,8 +3283,8 @@ func TestEntryProbeToReachableWhenSolicitedConfirmationWithoutAddress(t *testing
},
}
nudDisp.mu.Lock()
- if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" {
- t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ if diff := cmp.Diff(wantEvents, nudDisp.events, eventDiffOpts()...); diff != "" {
+ t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff)
}
nudDisp.mu.Unlock()
}
@@ -3309,11 +3309,11 @@ func TestEntryProbeToFailed(t *testing.T) {
},
}
linkRes.mu.Lock()
- diff := cmp.Diff(linkRes.probes, wantProbes)
+ diff := cmp.Diff(wantProbes, linkRes.probes)
linkRes.probes = nil
linkRes.mu.Unlock()
if diff != "" {
- t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ t.Fatalf("link address resolver probes mismatch (-want, +got):\n%s", diff)
}
}
@@ -3336,11 +3336,11 @@ func TestEntryProbeToFailed(t *testing.T) {
},
}
linkRes.mu.Lock()
- diff := cmp.Diff(linkRes.probes, wantProbes)
+ diff := cmp.Diff(wantProbes, linkRes.probes)
linkRes.probes = nil
linkRes.mu.Unlock()
if diff != "" {
- t.Fatalf("link address resolver probe #%d mismatch (-got, +want):\n%s", i+1, diff)
+ t.Fatalf("link address resolver probe #%d mismatch (-want, +got):\n%s", i+1, diff)
}
e.mu.Lock()
@@ -3406,8 +3406,8 @@ func TestEntryProbeToFailed(t *testing.T) {
},
}
nudDisp.mu.Lock()
- if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" {
- t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ if diff := cmp.Diff(wantEvents, nudDisp.events, eventDiffOpts()...); diff != "" {
+ t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff)
}
nudDisp.mu.Unlock()
}
@@ -3449,10 +3449,10 @@ func TestEntryFailedToIncomplete(t *testing.T) {
},
}
linkRes.mu.Lock()
- diff := cmp.Diff(linkRes.probes, wantProbes)
+ diff := cmp.Diff(wantProbes, linkRes.probes)
linkRes.mu.Unlock()
if diff != "" {
- t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ t.Fatalf("link address resolver probes mismatch (-want, +got):\n%s", diff)
}
e.mu.Lock()
@@ -3498,8 +3498,8 @@ func TestEntryFailedToIncomplete(t *testing.T) {
},
}
nudDisp.mu.Lock()
- if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" {
- t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ if diff := cmp.Diff(wantEvents, nudDisp.events, eventDiffOpts()...); diff != "" {
+ t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff)
}
nudDisp.mu.Unlock()
}
diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go
index 0f545f255..e56a624fe 100644
--- a/pkg/tcpip/stack/nic.go
+++ b/pkg/tcpip/stack/nic.go
@@ -53,6 +53,8 @@ type NIC struct {
// complete.
linkResQueue packetsPendingLinkResolution
+ linkAddrCache *linkAddrCache
+
mu struct {
sync.RWMutex
spoofing bool
@@ -138,7 +140,8 @@ func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint, ctx NICC
stats: makeNICStats(),
networkEndpoints: make(map[tcpip.NetworkProtocolNumber]NetworkEndpoint),
}
- nic.linkResQueue.init()
+ nic.linkResQueue.init(nic)
+ nic.linkAddrCache = newLinkAddrCache(nic, ageLimit, resolutionTimeout, resolutionAttempts)
nic.mu.packetEPs = make(map[tcpip.NetworkProtocolNumber]*packetEndpointList)
// Check for Neighbor Unreachability Detection support.
@@ -167,7 +170,7 @@ func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint, ctx NICC
for _, netProto := range stack.networkProtocols {
netNum := netProto.Number()
nic.mu.packetEPs[netNum] = new(packetEndpointList)
- nic.networkEndpoints[netNum] = netProto.NewEndpoint(nic, stack, nud, nic)
+ nic.networkEndpoints[netNum] = netProto.NewEndpoint(nic, nic.linkAddrCache, nud, nic)
}
nic.LinkEndpoint.Attach(nic)
@@ -228,7 +231,9 @@ func (n *NIC) disableLocked() {
//
// This matches linux's behaviour at the time of writing:
// https://github.com/torvalds/linux/blob/71c061d2443814de15e177489d5cc00a4a253ef3/net/core/neighbour.c#L371
- if err := n.clearNeighbors(); err != nil && err != tcpip.ErrNotSupported {
+ switch err := n.clearNeighbors(); err.(type) {
+ case nil, *tcpip.ErrNotSupported:
+ default:
panic(fmt.Sprintf("n.clearNeighbors(): %s", err))
}
@@ -243,7 +248,7 @@ func (n *NIC) disableLocked() {
// address (ff02::1), start DAD for permanent addresses, and start soliciting
// routers if the stack is not operating as a router. If the stack is also
// configured to auto-generate a link-local address, one will be generated.
-func (n *NIC) enable() *tcpip.Error {
+func (n *NIC) enable() tcpip.Error {
n.mu.Lock()
defer n.mu.Unlock()
@@ -263,7 +268,7 @@ func (n *NIC) enable() *tcpip.Error {
// remove detaches NIC from the link endpoint and releases network endpoint
// resources. This guarantees no packets between this NIC and the network
// stack.
-func (n *NIC) remove() *tcpip.Error {
+func (n *NIC) remove() tcpip.Error {
n.mu.Lock()
defer n.mu.Unlock()
@@ -299,48 +304,69 @@ func (n *NIC) IsLoopback() bool {
}
// WritePacket implements NetworkLinkEndpoint.
-func (n *NIC) WritePacket(r *Route, gso *GSO, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) *tcpip.Error {
- // As per relevant RFCs, we should queue packets while we wait for link
- // resolution to complete.
- //
- // RFC 1122 section 2.3.2.2 (for IPv4):
- // The link layer SHOULD save (rather than discard) at least
- // one (the latest) packet of each set of packets destined to
- // the same unresolved IP address, and transmit the saved
- // packet when the address has been resolved.
- //
- // RFC 4861 section 7.2.2 (for IPv6):
- // While waiting for address resolution to complete, the sender MUST, for
- // each neighbor, retain a small queue of packets waiting for address
- // resolution to complete. The queue MUST hold at least one packet, and MAY
- // contain more. However, the number of queued packets per neighbor SHOULD
- // be limited to some small value. When a queue overflows, the new arrival
- // SHOULD replace the oldest entry. Once address resolution completes, the
- // node transmits any queued packets.
- if ch, err := r.Resolve(nil); err != nil {
- if err == tcpip.ErrWouldBlock {
- r.Acquire()
- n.linkResQueue.enqueue(ch, r, protocol, pkt)
- return nil
+func (n *NIC) WritePacket(r *Route, gso *GSO, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) tcpip.Error {
+ _, err := n.enqueuePacketBuffer(r, gso, protocol, pkt)
+ return err
+}
+
+func (n *NIC) writePacketBuffer(r RouteInfo, gso *GSO, protocol tcpip.NetworkProtocolNumber, pkt pendingPacketBuffer) (int, tcpip.Error) {
+ switch pkt := pkt.(type) {
+ case *PacketBuffer:
+ if err := n.writePacket(r, gso, protocol, pkt); err != nil {
+ return 0, err
}
- return err
+ return 1, nil
+ case *PacketBufferList:
+ return n.writePackets(r, gso, protocol, *pkt)
+ default:
+ panic(fmt.Sprintf("unrecognized pending packet buffer type = %T", pkt))
}
+}
- return n.writePacket(r.Fields(), gso, protocol, pkt)
+func (n *NIC) enqueuePacketBuffer(r *Route, gso *GSO, protocol tcpip.NetworkProtocolNumber, pkt pendingPacketBuffer) (int, tcpip.Error) {
+ routeInfo, _, err := r.resolvedFields(nil)
+ switch err.(type) {
+ case nil:
+ return n.writePacketBuffer(routeInfo, gso, protocol, pkt)
+ case *tcpip.ErrWouldBlock:
+ // As per relevant RFCs, we should queue packets while we wait for link
+ // resolution to complete.
+ //
+ // RFC 1122 section 2.3.2.2 (for IPv4):
+ // The link layer SHOULD save (rather than discard) at least
+ // one (the latest) packet of each set of packets destined to
+ // the same unresolved IP address, and transmit the saved
+ // packet when the address has been resolved.
+ //
+ // RFC 4861 section 7.2.2 (for IPv6):
+ // While waiting for address resolution to complete, the sender MUST, for
+ // each neighbor, retain a small queue of packets waiting for address
+ // resolution to complete. The queue MUST hold at least one packet, and
+ // MAY contain more. However, the number of queued packets per neighbor
+ // SHOULD be limited to some small value. When a queue overflows, the new
+ // arrival SHOULD replace the oldest entry. Once address resolution
+ // completes, the node transmits any queued packets.
+ return n.linkResQueue.enqueue(r, gso, protocol, pkt)
+ default:
+ return 0, err
+ }
}
// WritePacketToRemote implements NetworkInterface.
-func (n *NIC) WritePacketToRemote(remoteLinkAddr tcpip.LinkAddress, gso *GSO, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) *tcpip.Error {
+func (n *NIC) WritePacketToRemote(remoteLinkAddr tcpip.LinkAddress, gso *GSO, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) tcpip.Error {
var r RouteInfo
r.NetProto = protocol
r.RemoteLinkAddress = remoteLinkAddr
return n.writePacket(r, gso, protocol, pkt)
}
-func (n *NIC) writePacket(r RouteInfo, gso *GSO, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) *tcpip.Error {
+func (n *NIC) writePacket(r RouteInfo, gso *GSO, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) tcpip.Error {
// WritePacket takes ownership of pkt, calculate numBytes first.
numBytes := pkt.Size()
+ pkt.EgressRoute = r
+ pkt.GSOOptions = gso
+ pkt.NetworkProtocolNumber = protocol
if err := n.LinkEndpoint.WritePacket(r, gso, protocol, pkt); err != nil {
return err
}
@@ -351,10 +377,18 @@ func (n *NIC) writePacket(r RouteInfo, gso *GSO, protocol tcpip.NetworkProtocolN
}
// WritePackets implements NetworkLinkEndpoint.
-func (n *NIC) WritePackets(r *Route, gso *GSO, pkts PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) {
- // TODO(gvisor.dev/issue/4458): Queue packets whie link address resolution
- // is being peformed like WritePacket.
- writtenPackets, err := n.LinkEndpoint.WritePackets(r.Fields(), gso, pkts, protocol)
+func (n *NIC) WritePackets(r *Route, gso *GSO, pkts PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, tcpip.Error) {
+ return n.enqueuePacketBuffer(r, gso, protocol, &pkts)
+}
+
+func (n *NIC) writePackets(r RouteInfo, gso *GSO, protocol tcpip.NetworkProtocolNumber, pkts PacketBufferList) (int, tcpip.Error) {
+ for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() {
+ pkt.EgressRoute = r
+ pkt.GSOOptions = gso
+ pkt.NetworkProtocolNumber = protocol
+ }
+
+ writtenPackets, err := n.LinkEndpoint.WritePackets(r, gso, pkts, protocol)
n.stats.Tx.Packets.IncrementBy(uint64(writtenPackets))
writtenBytes := 0
for i, pb := 0, pkts.Front(); i < writtenPackets && pb != nil; i, pb = i+1, pb.Next() {
@@ -463,15 +497,15 @@ func (n *NIC) getAddressOrCreateTempInner(protocol tcpip.NetworkProtocolNumber,
// addAddress adds a new address to n, so that it starts accepting packets
// targeted at the given address (and network protocol).
-func (n *NIC) addAddress(protocolAddress tcpip.ProtocolAddress, peb PrimaryEndpointBehavior) *tcpip.Error {
+func (n *NIC) addAddress(protocolAddress tcpip.ProtocolAddress, peb PrimaryEndpointBehavior) tcpip.Error {
ep, ok := n.networkEndpoints[protocolAddress.Protocol]
if !ok {
- return tcpip.ErrUnknownProtocol
+ return &tcpip.ErrUnknownProtocol{}
}
addressableEndpoint, ok := ep.(AddressableEndpoint)
if !ok {
- return tcpip.ErrNotSupported
+ return &tcpip.ErrNotSupported{}
}
addressEndpoint, err := addressableEndpoint.AddAndAcquirePermanentAddress(protocolAddress.AddressWithPrefix, peb, AddressConfigStatic, false /* deprecated */)
@@ -535,63 +569,70 @@ func (n *NIC) primaryAddress(proto tcpip.NetworkProtocolNumber) tcpip.AddressWit
}
// removeAddress removes an address from n.
-func (n *NIC) removeAddress(addr tcpip.Address) *tcpip.Error {
+func (n *NIC) removeAddress(addr tcpip.Address) tcpip.Error {
for _, ep := range n.networkEndpoints {
addressableEndpoint, ok := ep.(AddressableEndpoint)
if !ok {
continue
}
- if err := addressableEndpoint.RemovePermanentAddress(addr); err == tcpip.ErrBadLocalAddress {
+ switch err := addressableEndpoint.RemovePermanentAddress(addr); err.(type) {
+ case *tcpip.ErrBadLocalAddress:
continue
- } else {
+ default:
return err
}
}
- return tcpip.ErrBadLocalAddress
+ return &tcpip.ErrBadLocalAddress{}
+}
+
+func (n *NIC) confirmReachable(addr tcpip.Address) {
+ if n := n.neigh; n != nil {
+ n.handleUpperLevelConfirmation(addr)
+ }
}
-func (n *NIC) getNeighborLinkAddress(addr, localAddr tcpip.Address, linkRes LinkAddressResolver, onResolve func(tcpip.LinkAddress, bool)) (tcpip.LinkAddress, <-chan struct{}, *tcpip.Error) {
+func (n *NIC) getNeighborLinkAddress(addr, localAddr tcpip.Address, linkRes LinkAddressResolver, onResolve func(LinkResolutionResult)) (tcpip.LinkAddress, <-chan struct{}, tcpip.Error) {
if n.neigh != nil {
entry, ch, err := n.neigh.entry(addr, localAddr, linkRes, onResolve)
return entry.LinkAddr, ch, err
}
- return n.stack.linkAddrCache.get(tcpip.FullAddress{NIC: n.ID(), Addr: addr}, linkRes, localAddr, n, onResolve)
+ return n.linkAddrCache.get(addr, linkRes, localAddr, n, onResolve)
}
-func (n *NIC) neighbors() ([]NeighborEntry, *tcpip.Error) {
+func (n *NIC) neighbors() ([]NeighborEntry, tcpip.Error) {
if n.neigh == nil {
- return nil, tcpip.ErrNotSupported
+ return nil, &tcpip.ErrNotSupported{}
}
return n.neigh.entries(), nil
}
-func (n *NIC) addStaticNeighbor(addr tcpip.Address, linkAddress tcpip.LinkAddress) *tcpip.Error {
+func (n *NIC) addStaticNeighbor(addr tcpip.Address, linkAddress tcpip.LinkAddress) tcpip.Error {
if n.neigh == nil {
- return tcpip.ErrNotSupported
+ return &tcpip.ErrNotSupported{}
}
n.neigh.addStaticEntry(addr, linkAddress)
return nil
}
-func (n *NIC) removeNeighbor(addr tcpip.Address) *tcpip.Error {
+func (n *NIC) removeNeighbor(addr tcpip.Address) tcpip.Error {
if n.neigh == nil {
- return tcpip.ErrNotSupported
+ return &tcpip.ErrNotSupported{}
}
if !n.neigh.removeEntry(addr) {
- return tcpip.ErrBadAddress
+ return &tcpip.ErrBadAddress{}
}
return nil
}
-func (n *NIC) clearNeighbors() *tcpip.Error {
+func (n *NIC) clearNeighbors() tcpip.Error {
if n.neigh == nil {
- return tcpip.ErrNotSupported
+ return &tcpip.ErrNotSupported{}
}
n.neigh.clear()
@@ -600,7 +641,7 @@ func (n *NIC) clearNeighbors() *tcpip.Error {
// joinGroup adds a new endpoint for the given multicast address, if none
// exists yet. Otherwise it just increments its count.
-func (n *NIC) joinGroup(protocol tcpip.NetworkProtocolNumber, addr tcpip.Address) *tcpip.Error {
+func (n *NIC) joinGroup(protocol tcpip.NetworkProtocolNumber, addr tcpip.Address) tcpip.Error {
// TODO(b/143102137): When implementing MLD, make sure MLD packets are
// not sent unless a valid link-local address is available for use on n
// as an MLD packet's source address must be a link-local address as
@@ -608,12 +649,12 @@ func (n *NIC) joinGroup(protocol tcpip.NetworkProtocolNumber, addr tcpip.Address
ep, ok := n.networkEndpoints[protocol]
if !ok {
- return tcpip.ErrNotSupported
+ return &tcpip.ErrNotSupported{}
}
gep, ok := ep.(GroupAddressableEndpoint)
if !ok {
- return tcpip.ErrNotSupported
+ return &tcpip.ErrNotSupported{}
}
return gep.JoinGroup(addr)
@@ -621,15 +662,15 @@ func (n *NIC) joinGroup(protocol tcpip.NetworkProtocolNumber, addr tcpip.Address
// leaveGroup decrements the count for the given multicast address, and when it
// reaches zero removes the endpoint for this address.
-func (n *NIC) leaveGroup(protocol tcpip.NetworkProtocolNumber, addr tcpip.Address) *tcpip.Error {
+func (n *NIC) leaveGroup(protocol tcpip.NetworkProtocolNumber, addr tcpip.Address) tcpip.Error {
ep, ok := n.networkEndpoints[protocol]
if !ok {
- return tcpip.ErrNotSupported
+ return &tcpip.ErrNotSupported{}
}
gep, ok := ep.(GroupAddressableEndpoint)
if !ok {
- return tcpip.ErrNotSupported
+ return &tcpip.ErrNotSupported{}
}
return gep.LeaveGroup(addr)
@@ -879,9 +920,9 @@ func (n *NIC) Name() string {
}
// nudConfigs gets the NUD configurations for n.
-func (n *NIC) nudConfigs() (NUDConfigurations, *tcpip.Error) {
+func (n *NIC) nudConfigs() (NUDConfigurations, tcpip.Error) {
if n.neigh == nil {
- return NUDConfigurations{}, tcpip.ErrNotSupported
+ return NUDConfigurations{}, &tcpip.ErrNotSupported{}
}
return n.neigh.config(), nil
}
@@ -890,22 +931,22 @@ func (n *NIC) nudConfigs() (NUDConfigurations, *tcpip.Error) {
//
// Note, if c contains invalid NUD configuration values, it will be fixed to
// use default values for the erroneous values.
-func (n *NIC) setNUDConfigs(c NUDConfigurations) *tcpip.Error {
+func (n *NIC) setNUDConfigs(c NUDConfigurations) tcpip.Error {
if n.neigh == nil {
- return tcpip.ErrNotSupported
+ return &tcpip.ErrNotSupported{}
}
c.resetInvalidFields()
n.neigh.setConfig(c)
return nil
}
-func (n *NIC) registerPacketEndpoint(netProto tcpip.NetworkProtocolNumber, ep PacketEndpoint) *tcpip.Error {
+func (n *NIC) registerPacketEndpoint(netProto tcpip.NetworkProtocolNumber, ep PacketEndpoint) tcpip.Error {
n.mu.Lock()
defer n.mu.Unlock()
eps, ok := n.mu.packetEPs[netProto]
if !ok {
- return tcpip.ErrNotSupported
+ return &tcpip.ErrNotSupported{}
}
eps.add(ep)
diff --git a/pkg/tcpip/stack/nic_test.go b/pkg/tcpip/stack/nic_test.go
index 5b5c58afb..2f719fbe5 100644
--- a/pkg/tcpip/stack/nic_test.go
+++ b/pkg/tcpip/stack/nic_test.go
@@ -39,7 +39,7 @@ type testIPv6Endpoint struct {
invalidatedRtr tcpip.Address
}
-func (*testIPv6Endpoint) Enable() *tcpip.Error {
+func (*testIPv6Endpoint) Enable() tcpip.Error {
return nil
}
@@ -65,21 +65,21 @@ func (e *testIPv6Endpoint) MaxHeaderLength() uint16 {
}
// WritePacket implements NetworkEndpoint.WritePacket.
-func (*testIPv6Endpoint) WritePacket(*Route, *GSO, NetworkHeaderParams, *PacketBuffer) *tcpip.Error {
+func (*testIPv6Endpoint) WritePacket(*Route, *GSO, NetworkHeaderParams, *PacketBuffer) tcpip.Error {
return nil
}
// WritePackets implements NetworkEndpoint.WritePackets.
-func (*testIPv6Endpoint) WritePackets(*Route, *GSO, PacketBufferList, NetworkHeaderParams) (int, *tcpip.Error) {
+func (*testIPv6Endpoint) WritePackets(*Route, *GSO, PacketBufferList, NetworkHeaderParams) (int, tcpip.Error) {
// Our tests don't use this so we don't support it.
- return 0, tcpip.ErrNotSupported
+ return 0, &tcpip.ErrNotSupported{}
}
// WriteHeaderIncludedPacket implements
// NetworkEndpoint.WriteHeaderIncludedPacket.
-func (*testIPv6Endpoint) WriteHeaderIncludedPacket(*Route, *PacketBuffer) *tcpip.Error {
+func (*testIPv6Endpoint) WriteHeaderIncludedPacket(*Route, *PacketBuffer) tcpip.Error {
// Our tests don't use this so we don't support it.
- return tcpip.ErrNotSupported
+ return &tcpip.ErrNotSupported{}
}
// HandlePacket implements NetworkEndpoint.HandlePacket.
@@ -99,11 +99,20 @@ func (e *testIPv6Endpoint) InvalidateDefaultRouter(rtr tcpip.Address) {
e.invalidatedRtr = rtr
}
-var _ NetworkProtocol = (*testIPv6Protocol)(nil)
+// Stats implements NetworkEndpoint.
+func (*testIPv6Endpoint) Stats() NetworkEndpointStats {
+ return &testIPv6EndpointStats{}
+}
+
+var _ NetworkEndpointStats = (*testIPv6EndpointStats)(nil)
+
+type testIPv6EndpointStats struct{}
+
+// IsNetworkEndpointStats implements stack.NetworkEndpointStats.
+func (*testIPv6EndpointStats) IsNetworkEndpointStats() {}
+
+var _ LinkAddressResolver = (*testIPv6Protocol)(nil)
-// An IPv6 NetworkProtocol that supports the bare minimum to make a stack
-// believe it supports IPv6.
-//
// We use this instead of ipv6.protocol because the ipv6 package depends on
// the stack package which this test lives in, causing a cyclic dependency.
type testIPv6Protocol struct{}
@@ -140,12 +149,12 @@ func (p *testIPv6Protocol) NewEndpoint(nic NetworkInterface, _ LinkAddressCache,
}
// SetOption implements NetworkProtocol.SetOption.
-func (*testIPv6Protocol) SetOption(tcpip.SettableNetworkProtocolOption) *tcpip.Error {
+func (*testIPv6Protocol) SetOption(tcpip.SettableNetworkProtocolOption) tcpip.Error {
return nil
}
// Option implements NetworkProtocol.Option.
-func (*testIPv6Protocol) Option(tcpip.GettableNetworkProtocolOption) *tcpip.Error {
+func (*testIPv6Protocol) Option(tcpip.GettableNetworkProtocolOption) tcpip.Error {
return nil
}
@@ -160,15 +169,13 @@ func (*testIPv6Protocol) Parse(*PacketBuffer) (tcpip.TransportProtocolNumber, bo
return 0, false, false
}
-var _ LinkAddressResolver = (*testIPv6Protocol)(nil)
-
// LinkAddressProtocol implements LinkAddressResolver.
func (*testIPv6Protocol) LinkAddressProtocol() tcpip.NetworkProtocolNumber {
return header.IPv6ProtocolNumber
}
// LinkAddressRequest implements LinkAddressResolver.
-func (*testIPv6Protocol) LinkAddressRequest(_, _ tcpip.Address, _ tcpip.LinkAddress, _ NetworkInterface) *tcpip.Error {
+func (*testIPv6Protocol) LinkAddressRequest(_, _ tcpip.Address, _ tcpip.LinkAddress, _ NetworkInterface) tcpip.Error {
return nil
}
diff --git a/pkg/tcpip/stack/nud.go b/pkg/tcpip/stack/nud.go
index 12d67409a..77926e289 100644
--- a/pkg/tcpip/stack/nud.go
+++ b/pkg/tcpip/stack/nud.go
@@ -174,10 +174,6 @@ type NUDHandler interface {
// HandleConfirmation processes an incoming neighbor confirmation (e.g. ARP
// reply or Neighbor Advertisement for ARP or NDP, respectively).
HandleConfirmation(addr tcpip.Address, linkAddr tcpip.LinkAddress, flags ReachabilityConfirmationFlags)
-
- // HandleUpperLevelConfirmation processes an incoming upper-level protocol
- // (e.g. TCP acknowledgements) reachability confirmation.
- HandleUpperLevelConfirmation(addr tcpip.Address)
}
// NUDConfigurations is the NUD configurations for the netstack. This is used
diff --git a/pkg/tcpip/stack/nud_test.go b/pkg/tcpip/stack/nud_test.go
index 7bca1373e..ebfd5eb45 100644
--- a/pkg/tcpip/stack/nud_test.go
+++ b/pkg/tcpip/stack/nud_test.go
@@ -65,8 +65,9 @@ func TestSetNUDConfigurationFailsForBadNICID(t *testing.T) {
// No NIC with ID 1 yet.
config := stack.NUDConfigurations{}
- if err := s.SetNUDConfigurations(1, config); err != tcpip.ErrUnknownNICID {
- t.Fatalf("got s.SetNDPConfigurations(1, %+v) = %v, want = %s", config, err, tcpip.ErrUnknownNICID)
+ err := s.SetNUDConfigurations(1, config)
+ if _, ok := err.(*tcpip.ErrUnknownNICID); !ok {
+ t.Fatalf("got s.SetNDPConfigurations(1, %+v) = %v, want = %s", config, err, &tcpip.ErrUnknownNICID{})
}
}
@@ -90,8 +91,9 @@ func TestNUDConfigurationFailsForNotSupported(t *testing.T) {
if err := s.CreateNIC(nicID, e); err != nil {
t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
}
- if _, err := s.NUDConfigurations(nicID); err != tcpip.ErrNotSupported {
- t.Fatalf("got s.NDPConfigurations(%d) = %v, want = %s", nicID, err, tcpip.ErrNotSupported)
+ _, err := s.NUDConfigurations(nicID)
+ if _, ok := err.(*tcpip.ErrNotSupported); !ok {
+ t.Fatalf("got s.NDPConfigurations(%d) = %v, want = %s", nicID, err, &tcpip.ErrNotSupported{})
}
}
@@ -117,8 +119,9 @@ func TestSetNUDConfigurationFailsForNotSupported(t *testing.T) {
}
config := stack.NUDConfigurations{}
- if err := s.SetNUDConfigurations(nicID, config); err != tcpip.ErrNotSupported {
- t.Fatalf("got s.SetNDPConfigurations(%d, %+v) = %v, want = %s", nicID, config, err, tcpip.ErrNotSupported)
+ err := s.SetNUDConfigurations(nicID, config)
+ if _, ok := err.(*tcpip.ErrNotSupported); !ok {
+ t.Fatalf("got s.SetNDPConfigurations(%d, %+v) = %v, want = %s", nicID, config, err, &tcpip.ErrNotSupported{})
}
}
diff --git a/pkg/tcpip/stack/pending_packets.go b/pkg/tcpip/stack/pending_packets.go
index 41529ffd5..1c651e216 100644
--- a/pkg/tcpip/stack/pending_packets.go
+++ b/pkg/tcpip/stack/pending_packets.go
@@ -28,108 +28,205 @@ const (
maxPendingPacketsPerResolution = 256
)
+// pendingPacketBuffer is a pending packet buffer.
+//
+// TODO(gvisor.dev/issue/5331): Drop this when we drop WritePacket and only use
+// WritePackets so we can use a PacketBufferList everywhere.
+type pendingPacketBuffer interface {
+ len() int
+}
+
+func (*PacketBuffer) len() int {
+ return 1
+}
+
+func (p *PacketBufferList) len() int {
+ return p.Len()
+}
+
type pendingPacket struct {
- route *Route
- proto tcpip.NetworkProtocolNumber
- pkt *PacketBuffer
+ routeInfo RouteInfo
+ gso *GSO
+ proto tcpip.NetworkProtocolNumber
+ pkt pendingPacketBuffer
}
// packetsPendingLinkResolution is a queue of packets pending link resolution.
//
// Once link resolution completes successfully, the packets will be written.
type packetsPendingLinkResolution struct {
- sync.Mutex
+ nic *NIC
- // The packets to send once the resolver completes.
- packets map[<-chan struct{}][]pendingPacket
+ mu struct {
+ sync.Mutex
- // FIFO of channels used to cancel the oldest goroutine waiting for
- // link-address resolution.
- cancelChans []chan struct{}
-}
+ // The packets to send once the resolver completes.
+ //
+ // The link resolution channel is used as the key for this map.
+ packets map[<-chan struct{}][]pendingPacket
-func (f *packetsPendingLinkResolution) init() {
- f.Lock()
- defer f.Unlock()
- f.packets = make(map[<-chan struct{}][]pendingPacket)
+ // FIFO of channels used to cancel the oldest goroutine waiting for
+ // link-address resolution.
+ //
+ // cancelChans holds the same channels that are used as keys to packets.
+ cancelChans []<-chan struct{}
+ }
}
-func (f *packetsPendingLinkResolution) enqueue(ch <-chan struct{}, r *Route, proto tcpip.NetworkProtocolNumber, pkt *PacketBuffer) {
- f.Lock()
- defer f.Unlock()
+func (f *packetsPendingLinkResolution) incrementOutgoingPacketErrors(proto tcpip.NetworkProtocolNumber, pkt pendingPacketBuffer) {
+ n := uint64(pkt.len())
+ f.nic.stack.stats.IP.OutgoingPacketErrors.IncrementBy(n)
- packets, ok := f.packets[ch]
- if len(packets) == maxPendingPacketsPerResolution {
- p := packets[0]
- packets[0] = pendingPacket{}
- packets = packets[1:]
- p.route.Stats().IP.OutgoingPacketErrors.Increment()
- p.route.Release()
+ if ipEndpointStats, ok := f.nic.getNetworkEndpoint(proto).Stats().(IPNetworkEndpointStats); ok {
+ ipEndpointStats.IPStats().OutgoingPacketErrors.IncrementBy(n)
}
+}
- if l := len(packets); l >= maxPendingPacketsPerResolution {
- panic(fmt.Sprintf("max pending packets for resolution reached; got %d packets, max = %d", l, maxPendingPacketsPerResolution))
+func (f *packetsPendingLinkResolution) init(nic *NIC) {
+ f.mu.Lock()
+ defer f.mu.Unlock()
+ f.nic = nic
+ f.mu.packets = make(map[<-chan struct{}][]pendingPacket)
+}
+
+// dequeue any pending packets associated with ch.
+//
+// If success is true, packets will be written and sent to the given remote link
+// address.
+func (f *packetsPendingLinkResolution) dequeue(ch <-chan struct{}, linkAddr tcpip.LinkAddress, success bool) {
+ f.mu.Lock()
+ packets, ok := f.mu.packets[ch]
+ delete(f.mu.packets, ch)
+
+ if ok {
+ for i, cancelChan := range f.mu.cancelChans {
+ if cancelChan == ch {
+ f.mu.cancelChans = append(f.mu.cancelChans[:i], f.mu.cancelChans[i+1:]...)
+ break
+ }
+ }
}
- f.packets[ch] = append(packets, pendingPacket{
- route: r,
- proto: proto,
- pkt: pkt,
- })
+ f.mu.Unlock()
if ok {
- return
+ f.dequeuePackets(packets, linkAddr, success)
}
+}
- // Wait for the link-address resolution to complete.
- cancel := f.newCancelChannelLocked()
- go func() {
- cancelled := false
- select {
- case <-ch:
- case <-cancel:
- cancelled = true
- }
+// enqueue a packet to be sent once link resolution completes.
+//
+// If the maximum number of pending resolutions is reached, the packets
+// associated with the oldest link resolution will be dequeued as if they failed
+// link resolution.
+func (f *packetsPendingLinkResolution) enqueue(r *Route, gso *GSO, proto tcpip.NetworkProtocolNumber, pkt pendingPacketBuffer) (int, tcpip.Error) {
+ f.mu.Lock()
+ // Make sure we attempt resolution while holding f's lock so that we avoid
+ // a race where link resolution completes before we enqueue the packets.
+ //
+ // A @ T1: Call ResolvedFields (get link resolution channel)
+ // B @ T2: Complete link resolution, dequeue pending packets
+ // C @ T1: Enqueue packet that already completed link resolution (which will
+ // never dequeue)
+ //
+ // To make sure B does not interleave with A and C, we make sure A and C are
+ // done while holding the lock.
+ routeInfo, ch, err := r.resolvedFields(nil)
+ switch err.(type) {
+ case nil:
+ // The route resolved immediately, so we don't need to wait for link
+ // resolution to send the packet.
+ f.mu.Unlock()
+ return f.nic.writePacketBuffer(routeInfo, gso, proto, pkt)
+ case *tcpip.ErrWouldBlock:
+ // We need to wait for link resolution to complete.
+ default:
+ f.mu.Unlock()
+ return 0, err
+ }
- f.Lock()
- packets, ok := f.packets[ch]
- delete(f.packets, ch)
- f.Unlock()
+ defer f.mu.Unlock()
- if !ok {
- panic(fmt.Sprintf("link-resolution goroutine woke up but no entry exists in the queue of packets"))
- }
+ packets, ok := f.mu.packets[ch]
+ packets = append(packets, pendingPacket{
+ routeInfo: routeInfo,
+ gso: gso,
+ proto: proto,
+ pkt: pkt,
+ })
- for _, p := range packets {
- if cancelled || p.route.IsResolutionRequired() {
- p.route.Stats().IP.OutgoingPacketErrors.Increment()
+ if len(packets) > maxPendingPacketsPerResolution {
+ f.incrementOutgoingPacketErrors(packets[0].proto, packets[0].pkt)
+ packets[0] = pendingPacket{}
+ packets = packets[1:]
- if linkResolvableEP, ok := p.route.outgoingNIC.getNetworkEndpoint(p.route.NetProto).(LinkResolvableNetworkEndpoint); ok {
- linkResolvableEP.HandleLinkResolutionFailure(pkt)
- }
- } else {
- p.route.outgoingNIC.writePacket(p.route.Fields(), nil /* gso */, p.proto, p.pkt)
- }
- p.route.Release()
+ if numPackets := len(packets); numPackets != maxPendingPacketsPerResolution {
+ panic(fmt.Sprintf("holding more queued packets than expected; got = %d, want <= %d", numPackets, maxPendingPacketsPerResolution))
}
- }()
+ }
+
+ f.mu.packets[ch] = packets
+
+ if ok {
+ return pkt.len(), nil
+ }
+
+ cancelledPackets := f.newCancelChannelLocked(ch)
+
+ if len(cancelledPackets) != 0 {
+ // Dequeue the pending packets in a new goroutine to not hold up the current
+ // goroutine as handing link resolution failures may be a costly operation.
+ go f.dequeuePackets(cancelledPackets, "" /* linkAddr */, false /* success */)
+ }
+
+ return pkt.len(), nil
}
-// newCancelChannel creates a channel that can cancel a pending forwarding
-// activity. The oldest channel is closed if the number of open channels would
-// exceed maxPendingResolutions.
-func (f *packetsPendingLinkResolution) newCancelChannelLocked() chan struct{} {
- if len(f.cancelChans) == maxPendingResolutions {
- ch := f.cancelChans[0]
- f.cancelChans[0] = nil
- f.cancelChans = f.cancelChans[1:]
- close(ch)
+// newCancelChannelLocked appends the link resolution channel to a FIFO. If the
+// maximum number of pending resolutions is reached, the oldest channel will be
+// removed and its associated pending packets will be returned.
+func (f *packetsPendingLinkResolution) newCancelChannelLocked(newCH <-chan struct{}) []pendingPacket {
+ f.mu.cancelChans = append(f.mu.cancelChans, newCH)
+ if len(f.mu.cancelChans) <= maxPendingResolutions {
+ return nil
}
- if l := len(f.cancelChans); l >= maxPendingResolutions {
+
+ ch := f.mu.cancelChans[0]
+ f.mu.cancelChans[0] = nil
+ f.mu.cancelChans = f.mu.cancelChans[1:]
+ if l := len(f.mu.cancelChans); l > maxPendingResolutions {
panic(fmt.Sprintf("max pending resolutions reached; got %d active resolutions, max = %d", l, maxPendingResolutions))
}
- ch := make(chan struct{})
- f.cancelChans = append(f.cancelChans, ch)
- return ch
+ packets, ok := f.mu.packets[ch]
+ if !ok {
+ panic("must have a packet queue for an uncancelled channel")
+ }
+ delete(f.mu.packets, ch)
+
+ return packets
+}
+
+func (f *packetsPendingLinkResolution) dequeuePackets(packets []pendingPacket, linkAddr tcpip.LinkAddress, success bool) {
+ for _, p := range packets {
+ if success {
+ p.routeInfo.RemoteLinkAddress = linkAddr
+ _, _ = f.nic.writePacketBuffer(p.routeInfo, p.gso, p.proto, p.pkt)
+ } else {
+ f.incrementOutgoingPacketErrors(p.proto, p.pkt)
+
+ if linkResolvableEP, ok := f.nic.getNetworkEndpoint(p.proto).(LinkResolvableNetworkEndpoint); ok {
+ switch pkt := p.pkt.(type) {
+ case *PacketBuffer:
+ linkResolvableEP.HandleLinkResolutionFailure(pkt)
+ case *PacketBufferList:
+ for pb := pkt.Front(); pb != nil; pb = pb.Next() {
+ linkResolvableEP.HandleLinkResolutionFailure(pb)
+ }
+ default:
+ panic(fmt.Sprintf("unrecognized pending packet buffer type = %T", p.pkt))
+ }
+ }
+ }
+ }
}
diff --git a/pkg/tcpip/stack/registration.go b/pkg/tcpip/stack/registration.go
index 68c113b6a..510da8689 100644
--- a/pkg/tcpip/stack/registration.go
+++ b/pkg/tcpip/stack/registration.go
@@ -172,10 +172,10 @@ type TransportProtocol interface {
Number() tcpip.TransportProtocolNumber
// NewEndpoint creates a new endpoint of the transport protocol.
- NewEndpoint(netProto tcpip.NetworkProtocolNumber, waitQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error)
+ NewEndpoint(netProto tcpip.NetworkProtocolNumber, waitQueue *waiter.Queue) (tcpip.Endpoint, tcpip.Error)
// NewRawEndpoint creates a new raw endpoint of the transport protocol.
- NewRawEndpoint(netProto tcpip.NetworkProtocolNumber, waitQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error)
+ NewRawEndpoint(netProto tcpip.NetworkProtocolNumber, waitQueue *waiter.Queue) (tcpip.Endpoint, tcpip.Error)
// MinimumPacketSize returns the minimum valid packet size of this
// transport protocol. The stack automatically drops any packets smaller
@@ -184,7 +184,7 @@ type TransportProtocol interface {
// ParsePorts returns the source and destination ports stored in a
// packet of this protocol.
- ParsePorts(v buffer.View) (src, dst uint16, err *tcpip.Error)
+ ParsePorts(v buffer.View) (src, dst uint16, err tcpip.Error)
// HandleUnknownDestinationPacket handles packets targeted at this
// protocol that don't match any existing endpoint. For example,
@@ -197,12 +197,12 @@ type TransportProtocol interface {
// SetOption allows enabling/disabling protocol specific features.
// SetOption returns an error if the option is not supported or the
// provided option value is invalid.
- SetOption(option tcpip.SettableTransportProtocolOption) *tcpip.Error
+ SetOption(option tcpip.SettableTransportProtocolOption) tcpip.Error
// Option allows retrieving protocol specific option values.
// Option returns an error if the option is not supported or the
// provided option value is invalid.
- Option(option tcpip.GettableTransportProtocolOption) *tcpip.Error
+ Option(option tcpip.GettableTransportProtocolOption) tcpip.Error
// Close requests that any worker goroutines owned by the protocol
// stop.
@@ -289,10 +289,10 @@ type NetworkHeaderParams struct {
// endpoints may associate themselves with the same identifier (group address).
type GroupAddressableEndpoint interface {
// JoinGroup joins the specified group.
- JoinGroup(group tcpip.Address) *tcpip.Error
+ JoinGroup(group tcpip.Address) tcpip.Error
// LeaveGroup attempts to leave the specified group.
- LeaveGroup(group tcpip.Address) *tcpip.Error
+ LeaveGroup(group tcpip.Address) tcpip.Error
// IsInGroup returns true if the endpoint is a member of the specified group.
IsInGroup(group tcpip.Address) bool
@@ -440,17 +440,17 @@ func (k AddressKind) IsPermanent() bool {
type AddressableEndpoint interface {
// AddAndAcquirePermanentAddress adds the passed permanent address.
//
- // Returns tcpip.ErrDuplicateAddress if the address exists.
+ // Returns *tcpip.ErrDuplicateAddress if the address exists.
//
// Acquires and returns the AddressEndpoint for the added address.
- AddAndAcquirePermanentAddress(addr tcpip.AddressWithPrefix, peb PrimaryEndpointBehavior, configType AddressConfigType, deprecated bool) (AddressEndpoint, *tcpip.Error)
+ AddAndAcquirePermanentAddress(addr tcpip.AddressWithPrefix, peb PrimaryEndpointBehavior, configType AddressConfigType, deprecated bool) (AddressEndpoint, tcpip.Error)
// RemovePermanentAddress removes the passed address if it is a permanent
// address.
//
- // Returns tcpip.ErrBadLocalAddress if the endpoint does not have the passed
+ // Returns *tcpip.ErrBadLocalAddress if the endpoint does not have the passed
// permanent address.
- RemovePermanentAddress(addr tcpip.Address) *tcpip.Error
+ RemovePermanentAddress(addr tcpip.Address) tcpip.Error
// MainAddress returns the endpoint's primary permanent address.
MainAddress() tcpip.AddressWithPrefix
@@ -512,14 +512,14 @@ type NetworkInterface interface {
Promiscuous() bool
// WritePacketToRemote writes the packet to the given remote link address.
- WritePacketToRemote(tcpip.LinkAddress, *GSO, tcpip.NetworkProtocolNumber, *PacketBuffer) *tcpip.Error
+ WritePacketToRemote(tcpip.LinkAddress, *GSO, tcpip.NetworkProtocolNumber, *PacketBuffer) tcpip.Error
// WritePacket writes a packet with the given protocol through the given
// route.
//
// WritePacket takes ownership of the packet buffer. The packet buffer's
// network and transport header must be set.
- WritePacket(*Route, *GSO, tcpip.NetworkProtocolNumber, *PacketBuffer) *tcpip.Error
+ WritePacket(*Route, *GSO, tcpip.NetworkProtocolNumber, *PacketBuffer) tcpip.Error
// WritePackets writes packets with the given protocol through the given
// route. Must not be called with an empty list of packet buffers.
@@ -529,7 +529,7 @@ type NetworkInterface interface {
// Right now, WritePackets is used only when the software segmentation
// offload is enabled. If it will be used for something else, syscall filters
// may need to be updated.
- WritePackets(*Route, *GSO, PacketBufferList, tcpip.NetworkProtocolNumber) (int, *tcpip.Error)
+ WritePackets(*Route, *GSO, PacketBufferList, tcpip.NetworkProtocolNumber) (int, tcpip.Error)
}
// LinkResolvableNetworkEndpoint handles link resolution events.
@@ -547,8 +547,8 @@ type NetworkEndpoint interface {
// Must only be called when the stack is in a state that allows the endpoint
// to send and receive packets.
//
- // Returns tcpip.ErrNotPermitted if the endpoint cannot be enabled.
- Enable() *tcpip.Error
+ // Returns *tcpip.ErrNotPermitted if the endpoint cannot be enabled.
+ Enable() tcpip.Error
// Enabled returns true if the endpoint is enabled.
Enabled() bool
@@ -574,16 +574,16 @@ type NetworkEndpoint interface {
// WritePacket writes a packet to the given destination address and
// protocol. It takes ownership of pkt. pkt.TransportHeader must have
// already been set.
- WritePacket(r *Route, gso *GSO, params NetworkHeaderParams, pkt *PacketBuffer) *tcpip.Error
+ WritePacket(r *Route, gso *GSO, params NetworkHeaderParams, pkt *PacketBuffer) tcpip.Error
// WritePackets writes packets to the given destination address and
// protocol. pkts must not be zero length. It takes ownership of pkts and
// underlying packets.
- WritePackets(r *Route, gso *GSO, pkts PacketBufferList, params NetworkHeaderParams) (int, *tcpip.Error)
+ WritePackets(r *Route, gso *GSO, pkts PacketBufferList, params NetworkHeaderParams) (int, tcpip.Error)
// WriteHeaderIncludedPacket writes a packet that includes a network
// header to the given destination address. It takes ownership of pkt.
- WriteHeaderIncludedPacket(r *Route, pkt *PacketBuffer) *tcpip.Error
+ WriteHeaderIncludedPacket(r *Route, pkt *PacketBuffer) tcpip.Error
// HandlePacket is called by the link layer when new packets arrive to
// this network endpoint. It sets pkt.NetworkHeader.
@@ -597,6 +597,26 @@ type NetworkEndpoint interface {
// NetworkProtocolNumber returns the tcpip.NetworkProtocolNumber for
// this endpoint.
NetworkProtocolNumber() tcpip.NetworkProtocolNumber
+
+ // Stats returns a reference to the network endpoint stats.
+ Stats() NetworkEndpointStats
+}
+
+// NetworkEndpointStats is the interface implemented by each network endpoint
+// stats struct.
+type NetworkEndpointStats interface {
+ // IsNetworkEndpointStats is an empty method to implement the
+ // NetworkEndpointStats marker interface.
+ IsNetworkEndpointStats()
+}
+
+// IPNetworkEndpointStats is a NetworkEndpointStats that tracks IP-related
+// statistics.
+type IPNetworkEndpointStats interface {
+ NetworkEndpointStats
+
+ // IPStats returns the IP statistics of a network endpoint.
+ IPStats() *tcpip.IPStats
}
// ForwardingNetworkProtocol is a NetworkProtocol that may forward packets.
@@ -634,12 +654,12 @@ type NetworkProtocol interface {
// SetOption allows enabling/disabling protocol specific features.
// SetOption returns an error if the option is not supported or the
// provided option value is invalid.
- SetOption(option tcpip.SettableNetworkProtocolOption) *tcpip.Error
+ SetOption(option tcpip.SettableNetworkProtocolOption) tcpip.Error
// Option allows retrieving protocol specific option values.
// Option returns an error if the option is not supported or the
// provided option value is invalid.
- Option(option tcpip.GettableNetworkProtocolOption) *tcpip.Error
+ Option(option tcpip.GettableNetworkProtocolOption) tcpip.Error
// Close requests that any worker goroutines owned by the protocol
// stop.
@@ -776,7 +796,7 @@ type LinkEndpoint interface {
// To participate in transparent bridging, a LinkEndpoint implementation
// should call eth.Encode with header.EthernetFields.SrcAddr set to
// r.LocalLinkAddress if it is provided.
- WritePacket(RouteInfo, *GSO, tcpip.NetworkProtocolNumber, *PacketBuffer) *tcpip.Error
+ WritePacket(RouteInfo, *GSO, tcpip.NetworkProtocolNumber, *PacketBuffer) tcpip.Error
// WritePackets writes packets with the given protocol and route. Must not be
// called with an empty list of packet buffers.
@@ -786,7 +806,7 @@ type LinkEndpoint interface {
// Right now, WritePackets is used only when the software segmentation
// offload is enabled. If it will be used for something else, syscall filters
// may need to be updated.
- WritePackets(RouteInfo, *GSO, PacketBufferList, tcpip.NetworkProtocolNumber) (int, *tcpip.Error)
+ WritePackets(RouteInfo, *GSO, PacketBufferList, tcpip.NetworkProtocolNumber) (int, tcpip.Error)
}
// InjectableLinkEndpoint is a LinkEndpoint where inbound packets are
@@ -801,7 +821,7 @@ type InjectableLinkEndpoint interface {
// link.
//
// dest is used by endpoints with multiple raw destinations.
- InjectOutbound(dest tcpip.Address, packet []byte) *tcpip.Error
+ InjectOutbound(dest tcpip.Address, packet []byte) tcpip.Error
}
// A LinkAddressResolver is an extension to a NetworkProtocol that
@@ -813,7 +833,7 @@ type LinkAddressResolver interface {
//
// The request is sent from the passed network interface. If the interface
// local address is unspecified, any interface local address may be used.
- LinkAddressRequest(targetAddr, localAddr tcpip.Address, remoteLinkAddr tcpip.LinkAddress, nic NetworkInterface) *tcpip.Error
+ LinkAddressRequest(targetAddr, localAddr tcpip.Address, remoteLinkAddr tcpip.LinkAddress, nic NetworkInterface) tcpip.Error
// ResolveStaticAddress attempts to resolve address without sending
// requests. It either resolves the name immediately or returns the
@@ -829,12 +849,8 @@ type LinkAddressResolver interface {
// A LinkAddressCache caches link addresses.
type LinkAddressCache interface {
- // CheckLocalAddress determines if the given local address exists, and if it
- // does not exist.
- CheckLocalAddress(nicID tcpip.NICID, protocol tcpip.NetworkProtocolNumber, addr tcpip.Address) tcpip.NICID
-
// AddLinkAddress adds a link address to the cache.
- AddLinkAddress(nicID tcpip.NICID, addr tcpip.Address, linkAddr tcpip.LinkAddress)
+ AddLinkAddress(addr tcpip.Address, linkAddr tcpip.LinkAddress)
}
// RawFactory produces endpoints for writing various types of raw packets.
@@ -842,11 +858,11 @@ type RawFactory interface {
// NewUnassociatedEndpoint produces endpoints for writing packets not
// associated with a particular transport protocol. Such endpoints can
// be used to write arbitrary packets that include the network header.
- NewUnassociatedEndpoint(stack *Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error)
+ NewUnassociatedEndpoint(stack *Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, tcpip.Error)
// NewPacketEndpoint produces endpoints for reading and writing packets
// that include network and (when cooked is false) link layer headers.
- NewPacketEndpoint(stack *Stack, cooked bool, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error)
+ NewPacketEndpoint(stack *Stack, cooked bool, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, tcpip.Error)
}
// GSOType is the type of GSO segments.
diff --git a/pkg/tcpip/stack/route.go b/pkg/tcpip/stack/route.go
index 1ff7b3a37..4ae0f2a1a 100644
--- a/pkg/tcpip/stack/route.go
+++ b/pkg/tcpip/stack/route.go
@@ -86,12 +86,21 @@ type RouteInfo struct {
RemoteLinkAddress tcpip.LinkAddress
}
-// Fields returns a RouteInfo with all of r's exported fields. This allows
-// callers to store the route's fields without retaining a reference to it.
+// Fields returns a RouteInfo with all of the known values for the route's
+// fields.
+//
+// If any fields are unknown (e.g. remote link address when it is waiting for
+// link address resolution), they will be unset.
func (r *Route) Fields() RouteInfo {
+ r.mu.RLock()
+ defer r.mu.RUnlock()
+ return r.fieldsLocked()
+}
+
+func (r *Route) fieldsLocked() RouteInfo {
return RouteInfo{
routeInfo: r.routeInfo,
- RemoteLinkAddress: r.RemoteLinkAddress(),
+ RemoteLinkAddress: r.mu.remoteLinkAddress,
}
}
@@ -306,32 +315,43 @@ func (r *Route) ResolveWith(addr tcpip.LinkAddress) {
r.mu.remoteLinkAddress = addr
}
-// Resolve attempts to resolve the link address if necessary.
+// ResolvedFieldsResult is the result of a route resolution attempt.
+type ResolvedFieldsResult struct {
+ RouteInfo RouteInfo
+ Success bool
+}
+
+// ResolvedFields attempts to resolve the remote link address if it is not
+// known.
//
-// Returns tcpip.ErrWouldBlock if address resolution requires blocking (e.g.
-// waiting for ARP reply). If address resolution is required, a notification
-// channel is also returned for the caller to block on. The channel is closed
-// once address resolution is complete (successful or not). If a callback is
-// provided, it will be called when address resolution is complete, regardless
+// If a callback is provided, it will be called before ResolvedFields returns
+// when address resolution is not required. If address resolution is required,
+// the callback will be called once address resolution is complete, regardless
// of success or failure.
-func (r *Route) Resolve(afterResolve func()) (<-chan struct{}, *tcpip.Error) {
- r.mu.Lock()
-
- if !r.isResolutionRequiredRLocked() {
- // Nothing to do if there is no cache (which does the resolution on cache miss) or
- // link address is already known.
- r.mu.Unlock()
- return nil, nil
- }
-
- // Increment the route's reference count because finishResolution retains a
- // reference to the route and releases it when called.
- r.acquireLocked()
- r.mu.Unlock()
+//
+// Note, the route will not cache the remote link address when address
+// resolution completes.
+func (r *Route) ResolvedFields(afterResolve func(ResolvedFieldsResult)) tcpip.Error {
+ _, _, err := r.resolvedFields(afterResolve)
+ return err
+}
- nextAddr := r.NextHop
- if nextAddr == "" {
- nextAddr = r.RemoteAddress
+// resolvedFields is like ResolvedFields but also returns a notification channel
+// when address resolution is required. This channel will become readable once
+// address resolution is complete.
+//
+// The route's fields will also be returned, regardless of whether address
+// resolution is required or not.
+func (r *Route) resolvedFields(afterResolve func(ResolvedFieldsResult)) (RouteInfo, <-chan struct{}, tcpip.Error) {
+ r.mu.RLock()
+ fields := r.fieldsLocked()
+ resolutionRequired := r.isResolutionRequiredRLocked()
+ r.mu.RUnlock()
+ if !resolutionRequired {
+ if afterResolve != nil {
+ afterResolve(ResolvedFieldsResult{RouteInfo: fields, Success: true})
+ }
+ return fields, nil, nil
}
// If specified, the local address used for link address resolution must be an
@@ -341,18 +361,27 @@ func (r *Route) Resolve(afterResolve func()) (<-chan struct{}, *tcpip.Error) {
linkAddressResolutionRequestLocalAddr = r.LocalAddress
}
- finishResolution := func(linkAddress tcpip.LinkAddress, ok bool) {
- if ok {
- r.ResolveWith(linkAddress)
- }
+ afterResolveFields := fields
+ linkAddr, ch, err := r.outgoingNIC.getNeighborLinkAddress(r.nextHop(), linkAddressResolutionRequestLocalAddr, r.linkRes, func(r LinkResolutionResult) {
if afterResolve != nil {
- afterResolve()
+ if r.Success {
+ afterResolveFields.RemoteLinkAddress = r.LinkAddress
+ }
+
+ afterResolve(ResolvedFieldsResult{RouteInfo: afterResolveFields, Success: r.Success})
}
- r.Release()
+ })
+ if err == nil {
+ fields.RemoteLinkAddress = linkAddr
}
+ return fields, ch, err
+}
- _, ch, err := r.outgoingNIC.getNeighborLinkAddress(nextAddr, linkAddressResolutionRequestLocalAddr, r.linkRes, finishResolution)
- return ch, err
+func (r *Route) nextHop() tcpip.Address {
+ if len(r.NextHop) == 0 {
+ return r.RemoteAddress
+ }
+ return r.NextHop
}
// local returns true if the route is a local route.
@@ -371,11 +400,7 @@ func (r *Route) IsResolutionRequired() bool {
}
func (r *Route) isResolutionRequiredRLocked() bool {
- if !r.isValidForOutgoingRLocked() || r.mu.remoteLinkAddress != "" || r.local() {
- return false
- }
-
- return r.linkRes != nil
+ return len(r.mu.remoteLinkAddress) == 0 && r.linkRes != nil && r.isValidForOutgoingRLocked() && !r.local()
}
func (r *Route) isValidForOutgoing() bool {
@@ -404,9 +429,9 @@ func (r *Route) isValidForOutgoingRLocked() bool {
}
// WritePacket writes the packet through the given route.
-func (r *Route) WritePacket(gso *GSO, params NetworkHeaderParams, pkt *PacketBuffer) *tcpip.Error {
+func (r *Route) WritePacket(gso *GSO, params NetworkHeaderParams, pkt *PacketBuffer) tcpip.Error {
if !r.isValidForOutgoing() {
- return tcpip.ErrInvalidEndpointState
+ return &tcpip.ErrInvalidEndpointState{}
}
return r.outgoingNIC.getNetworkEndpoint(r.NetProto).WritePacket(r, gso, params, pkt)
@@ -414,9 +439,9 @@ func (r *Route) WritePacket(gso *GSO, params NetworkHeaderParams, pkt *PacketBuf
// WritePackets writes a list of n packets through the given route and returns
// the number of packets written.
-func (r *Route) WritePackets(gso *GSO, pkts PacketBufferList, params NetworkHeaderParams) (int, *tcpip.Error) {
+func (r *Route) WritePackets(gso *GSO, pkts PacketBufferList, params NetworkHeaderParams) (int, tcpip.Error) {
if !r.isValidForOutgoing() {
- return 0, tcpip.ErrInvalidEndpointState
+ return 0, &tcpip.ErrInvalidEndpointState{}
}
return r.outgoingNIC.getNetworkEndpoint(r.NetProto).WritePackets(r, gso, pkts, params)
@@ -424,9 +449,9 @@ func (r *Route) WritePackets(gso *GSO, pkts PacketBufferList, params NetworkHead
// WriteHeaderIncludedPacket writes a packet already containing a network
// header through the given route.
-func (r *Route) WriteHeaderIncludedPacket(pkt *PacketBuffer) *tcpip.Error {
+func (r *Route) WriteHeaderIncludedPacket(pkt *PacketBuffer) tcpip.Error {
if !r.isValidForOutgoing() {
- return tcpip.ErrInvalidEndpointState
+ return &tcpip.ErrInvalidEndpointState{}
}
return r.outgoingNIC.getNetworkEndpoint(r.NetProto).WriteHeaderIncludedPacket(r, pkt)
@@ -496,3 +521,12 @@ func (r *Route) IsOutboundBroadcast() bool {
// Only IPv4 has a notion of broadcast.
return r.isV4Broadcast(r.RemoteAddress)
}
+
+// ConfirmReachable informs the network/link layer that the neighbour used for
+// the route is reachable.
+//
+// "Reachable" is defined as having full-duplex communication between the
+// local and remote ends of the route.
+func (r *Route) ConfirmReachable() {
+ r.outgoingNIC.confirmReachable(r.nextHop())
+}
diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go
index c0aec61a6..119c4c505 100644
--- a/pkg/tcpip/stack/stack.go
+++ b/pkg/tcpip/stack/stack.go
@@ -76,12 +76,16 @@ type TCPCubicState struct {
// TCPRACKState is used to hold a copy of the internal RACK state when the
// TCPProbeFunc is invoked.
type TCPRACKState struct {
- XmitTime time.Time
- EndSequence seqnum.Value
- FACK seqnum.Value
- RTT time.Duration
- Reord bool
- DSACKSeen bool
+ XmitTime time.Time
+ EndSequence seqnum.Value
+ FACK seqnum.Value
+ RTT time.Duration
+ Reord bool
+ DSACKSeen bool
+ ReoWnd time.Duration
+ ReoWndIncr uint8
+ ReoWndPersist int8
+ RTTSeq seqnum.Value
}
// TCPEndpointID is the unique 4 tuple that identifies a given endpoint.
@@ -382,8 +386,6 @@ type Stack struct {
stats tcpip.Stats
- linkAddrCache *linkAddrCache
-
mu sync.RWMutex
nics map[tcpip.NICID]*NIC
@@ -446,7 +448,7 @@ type Stack struct {
// sendBufferSize holds the min/default/max send buffer sizes for
// endpoints other than TCP.
- sendBufferSize SendBufferSizeOption
+ sendBufferSize tcpip.SendBufferSizeOption
// receiveBufferSize holds the min/default/max receive buffer sizes for
// endpoints other than TCP.
@@ -554,7 +556,7 @@ type TransportEndpointInfo struct {
// incompatible with the receiver.
//
// Preconditon: the parent endpoint mu must be held while calling this method.
-func (t *TransportEndpointInfo) AddrNetProtoLocked(addr tcpip.FullAddress, v6only bool) (tcpip.FullAddress, tcpip.NetworkProtocolNumber, *tcpip.Error) {
+func (t *TransportEndpointInfo) AddrNetProtoLocked(addr tcpip.FullAddress, v6only bool) (tcpip.FullAddress, tcpip.NetworkProtocolNumber, tcpip.Error) {
netProto := t.NetProto
switch len(addr.Addr) {
case header.IPv4AddressSize:
@@ -572,11 +574,11 @@ func (t *TransportEndpointInfo) AddrNetProtoLocked(addr tcpip.FullAddress, v6onl
switch len(t.ID.LocalAddress) {
case header.IPv4AddressSize:
if len(addr.Addr) == header.IPv6AddressSize {
- return tcpip.FullAddress{}, 0, tcpip.ErrInvalidEndpointState
+ return tcpip.FullAddress{}, 0, &tcpip.ErrInvalidEndpointState{}
}
case header.IPv6AddressSize:
if len(addr.Addr) == header.IPv4AddressSize {
- return tcpip.FullAddress{}, 0, tcpip.ErrNetworkUnreachable
+ return tcpip.FullAddress{}, 0, &tcpip.ErrNetworkUnreachable{}
}
}
@@ -584,10 +586,10 @@ func (t *TransportEndpointInfo) AddrNetProtoLocked(addr tcpip.FullAddress, v6onl
case netProto == t.NetProto:
case netProto == header.IPv4ProtocolNumber && t.NetProto == header.IPv6ProtocolNumber:
if v6only {
- return tcpip.FullAddress{}, 0, tcpip.ErrNoRoute
+ return tcpip.FullAddress{}, 0, &tcpip.ErrNoRoute{}
}
default:
- return tcpip.FullAddress{}, 0, tcpip.ErrInvalidEndpointState
+ return tcpip.FullAddress{}, 0, &tcpip.ErrInvalidEndpointState{}
}
return addr, netProto, nil
@@ -636,7 +638,6 @@ func New(opts Options) *Stack {
linkAddrResolvers: make(map[tcpip.NetworkProtocolNumber]LinkAddressResolver),
nics: make(map[tcpip.NICID]*NIC),
cleanupEndpoints: make(map[TransportEndpoint]struct{}),
- linkAddrCache: newLinkAddrCache(ageLimit, resolutionTimeout, resolutionAttempts),
PortManager: ports.NewPortManager(),
clock: clock,
stats: opts.Stats.FillIn(),
@@ -649,7 +650,7 @@ func New(opts Options) *Stack {
uniqueIDGenerator: opts.UniqueID,
nudDisp: opts.NUDDisp,
randomGenerator: mathrand.New(randSrc),
- sendBufferSize: SendBufferSizeOption{
+ sendBufferSize: tcpip.SendBufferSizeOption{
Min: MinBufferSize,
Default: DefaultBufferSize,
Max: DefaultMaxBufferSize,
@@ -701,10 +702,10 @@ func (s *Stack) UniqueID() uint64 {
// options. This method returns an error if the protocol is not supported or
// option is not supported by the protocol implementation or the provided value
// is incorrect.
-func (s *Stack) SetNetworkProtocolOption(network tcpip.NetworkProtocolNumber, option tcpip.SettableNetworkProtocolOption) *tcpip.Error {
+func (s *Stack) SetNetworkProtocolOption(network tcpip.NetworkProtocolNumber, option tcpip.SettableNetworkProtocolOption) tcpip.Error {
netProto, ok := s.networkProtocols[network]
if !ok {
- return tcpip.ErrUnknownProtocol
+ return &tcpip.ErrUnknownProtocol{}
}
return netProto.SetOption(option)
}
@@ -718,10 +719,10 @@ func (s *Stack) SetNetworkProtocolOption(network tcpip.NetworkProtocolNumber, op
// if err != nil {
// ...
// }
-func (s *Stack) NetworkProtocolOption(network tcpip.NetworkProtocolNumber, option tcpip.GettableNetworkProtocolOption) *tcpip.Error {
+func (s *Stack) NetworkProtocolOption(network tcpip.NetworkProtocolNumber, option tcpip.GettableNetworkProtocolOption) tcpip.Error {
netProto, ok := s.networkProtocols[network]
if !ok {
- return tcpip.ErrUnknownProtocol
+ return &tcpip.ErrUnknownProtocol{}
}
return netProto.Option(option)
}
@@ -730,10 +731,10 @@ func (s *Stack) NetworkProtocolOption(network tcpip.NetworkProtocolNumber, optio
// options. This method returns an error if the protocol is not supported or
// option is not supported by the protocol implementation or the provided value
// is incorrect.
-func (s *Stack) SetTransportProtocolOption(transport tcpip.TransportProtocolNumber, option tcpip.SettableTransportProtocolOption) *tcpip.Error {
+func (s *Stack) SetTransportProtocolOption(transport tcpip.TransportProtocolNumber, option tcpip.SettableTransportProtocolOption) tcpip.Error {
transProtoState, ok := s.transportProtocols[transport]
if !ok {
- return tcpip.ErrUnknownProtocol
+ return &tcpip.ErrUnknownProtocol{}
}
return transProtoState.proto.SetOption(option)
}
@@ -745,10 +746,10 @@ func (s *Stack) SetTransportProtocolOption(transport tcpip.TransportProtocolNumb
// if err := s.TransportProtocolOption(tcpip.TCPProtocolNumber, &v); err != nil {
// ...
// }
-func (s *Stack) TransportProtocolOption(transport tcpip.TransportProtocolNumber, option tcpip.GettableTransportProtocolOption) *tcpip.Error {
+func (s *Stack) TransportProtocolOption(transport tcpip.TransportProtocolNumber, option tcpip.GettableTransportProtocolOption) tcpip.Error {
transProtoState, ok := s.transportProtocols[transport]
if !ok {
- return tcpip.ErrUnknownProtocol
+ return &tcpip.ErrUnknownProtocol{}
}
return transProtoState.proto.Option(option)
}
@@ -781,15 +782,15 @@ func (s *Stack) Stats() tcpip.Stats {
// SetForwarding enables or disables packet forwarding between NICs for the
// passed protocol.
-func (s *Stack) SetForwarding(protocolNum tcpip.NetworkProtocolNumber, enable bool) *tcpip.Error {
+func (s *Stack) SetForwarding(protocolNum tcpip.NetworkProtocolNumber, enable bool) tcpip.Error {
protocol, ok := s.networkProtocols[protocolNum]
if !ok {
- return tcpip.ErrUnknownProtocol
+ return &tcpip.ErrUnknownProtocol{}
}
forwardingProtocol, ok := protocol.(ForwardingNetworkProtocol)
if !ok {
- return tcpip.ErrNotSupported
+ return &tcpip.ErrNotSupported{}
}
forwardingProtocol.SetForwarding(enable)
@@ -852,10 +853,10 @@ func (s *Stack) RemoveRoutes(match func(tcpip.Route) bool) {
}
// NewEndpoint creates a new transport layer endpoint of the given protocol.
-func (s *Stack) NewEndpoint(transport tcpip.TransportProtocolNumber, network tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
+func (s *Stack) NewEndpoint(transport tcpip.TransportProtocolNumber, network tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, tcpip.Error) {
t, ok := s.transportProtocols[transport]
if !ok {
- return nil, tcpip.ErrUnknownProtocol
+ return nil, &tcpip.ErrUnknownProtocol{}
}
return t.proto.NewEndpoint(network, waiterQueue)
@@ -864,9 +865,9 @@ func (s *Stack) NewEndpoint(transport tcpip.TransportProtocolNumber, network tcp
// NewRawEndpoint creates a new raw transport layer endpoint of the given
// protocol. Raw endpoints receive all traffic for a given protocol regardless
// of address.
-func (s *Stack) NewRawEndpoint(transport tcpip.TransportProtocolNumber, network tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue, associated bool) (tcpip.Endpoint, *tcpip.Error) {
+func (s *Stack) NewRawEndpoint(transport tcpip.TransportProtocolNumber, network tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue, associated bool) (tcpip.Endpoint, tcpip.Error) {
if s.rawFactory == nil {
- return nil, tcpip.ErrNotPermitted
+ return nil, &tcpip.ErrNotPermitted{}
}
if !associated {
@@ -875,7 +876,7 @@ func (s *Stack) NewRawEndpoint(transport tcpip.TransportProtocolNumber, network
t, ok := s.transportProtocols[transport]
if !ok {
- return nil, tcpip.ErrUnknownProtocol
+ return nil, &tcpip.ErrUnknownProtocol{}
}
return t.proto.NewRawEndpoint(network, waiterQueue)
@@ -883,9 +884,9 @@ func (s *Stack) NewRawEndpoint(transport tcpip.TransportProtocolNumber, network
// NewPacketEndpoint creates a new packet endpoint listening for the given
// netProto.
-func (s *Stack) NewPacketEndpoint(cooked bool, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
+func (s *Stack) NewPacketEndpoint(cooked bool, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, tcpip.Error) {
if s.rawFactory == nil {
- return nil, tcpip.ErrNotPermitted
+ return nil, &tcpip.ErrNotPermitted{}
}
return s.rawFactory.NewPacketEndpoint(s, cooked, netProto, waiterQueue)
@@ -916,20 +917,20 @@ type NICOptions struct {
// NICs can be configured.
//
// LinkEndpoint.Attach will be called to bind ep with a NetworkDispatcher.
-func (s *Stack) CreateNICWithOptions(id tcpip.NICID, ep LinkEndpoint, opts NICOptions) *tcpip.Error {
+func (s *Stack) CreateNICWithOptions(id tcpip.NICID, ep LinkEndpoint, opts NICOptions) tcpip.Error {
s.mu.Lock()
defer s.mu.Unlock()
// Make sure id is unique.
if _, ok := s.nics[id]; ok {
- return tcpip.ErrDuplicateNICID
+ return &tcpip.ErrDuplicateNICID{}
}
// Make sure name is unique, unless unnamed.
if opts.Name != "" {
for _, n := range s.nics {
if n.Name() == opts.Name {
- return tcpip.ErrDuplicateNICID
+ return &tcpip.ErrDuplicateNICID{}
}
}
}
@@ -945,7 +946,7 @@ func (s *Stack) CreateNICWithOptions(id tcpip.NICID, ep LinkEndpoint, opts NICOp
// CreateNIC creates a NIC with the provided id and LinkEndpoint and calls
// LinkEndpoint.Attach to bind ep with a NetworkDispatcher.
-func (s *Stack) CreateNIC(id tcpip.NICID, ep LinkEndpoint) *tcpip.Error {
+func (s *Stack) CreateNIC(id tcpip.NICID, ep LinkEndpoint) tcpip.Error {
return s.CreateNICWithOptions(id, ep, NICOptions{})
}
@@ -963,26 +964,26 @@ func (s *Stack) GetLinkEndpointByName(name string) LinkEndpoint {
// EnableNIC enables the given NIC so that the link-layer endpoint can start
// delivering packets to it.
-func (s *Stack) EnableNIC(id tcpip.NICID) *tcpip.Error {
+func (s *Stack) EnableNIC(id tcpip.NICID) tcpip.Error {
s.mu.RLock()
defer s.mu.RUnlock()
nic, ok := s.nics[id]
if !ok {
- return tcpip.ErrUnknownNICID
+ return &tcpip.ErrUnknownNICID{}
}
return nic.enable()
}
// DisableNIC disables the given NIC.
-func (s *Stack) DisableNIC(id tcpip.NICID) *tcpip.Error {
+func (s *Stack) DisableNIC(id tcpip.NICID) tcpip.Error {
s.mu.RLock()
defer s.mu.RUnlock()
nic, ok := s.nics[id]
if !ok {
- return tcpip.ErrUnknownNICID
+ return &tcpip.ErrUnknownNICID{}
}
nic.disable()
@@ -1003,7 +1004,7 @@ func (s *Stack) CheckNIC(id tcpip.NICID) bool {
}
// RemoveNIC removes NIC and all related routes from the network stack.
-func (s *Stack) RemoveNIC(id tcpip.NICID) *tcpip.Error {
+func (s *Stack) RemoveNIC(id tcpip.NICID) tcpip.Error {
s.mu.Lock()
defer s.mu.Unlock()
@@ -1013,10 +1014,10 @@ func (s *Stack) RemoveNIC(id tcpip.NICID) *tcpip.Error {
// removeNICLocked removes NIC and all related routes from the network stack.
//
// s.mu must be locked.
-func (s *Stack) removeNICLocked(id tcpip.NICID) *tcpip.Error {
+func (s *Stack) removeNICLocked(id tcpip.NICID) tcpip.Error {
nic, ok := s.nics[id]
if !ok {
- return tcpip.ErrUnknownNICID
+ return &tcpip.ErrUnknownNICID{}
}
delete(s.nics, id)
@@ -1050,6 +1051,9 @@ type NICInfo struct {
Stats NICStats
+ // NetworkStats holds the stats of each NetworkEndpoint bound to the NIC.
+ NetworkStats map[tcpip.NetworkProtocolNumber]NetworkEndpointStats
+
// Context is user-supplied data optionally supplied in CreateNICWithOptions.
// See type NICOptions for more details.
Context NICContext
@@ -1081,6 +1085,12 @@ func (s *Stack) NICInfo() map[tcpip.NICID]NICInfo {
Promiscuous: nic.Promiscuous(),
Loopback: nic.IsLoopback(),
}
+
+ netStats := make(map[tcpip.NetworkProtocolNumber]NetworkEndpointStats)
+ for proto, netEP := range nic.networkEndpoints {
+ netStats[proto] = netEP.Stats()
+ }
+
nics[id] = NICInfo{
Name: nic.name,
LinkAddress: nic.LinkEndpoint.LinkAddress(),
@@ -1088,6 +1098,7 @@ func (s *Stack) NICInfo() map[tcpip.NICID]NICInfo {
Flags: flags,
MTU: nic.LinkEndpoint.MTU(),
Stats: nic.stats,
+ NetworkStats: netStats,
Context: nic.context,
ARPHardwareType: nic.LinkEndpoint.ARPHardwareType(),
}
@@ -1111,13 +1122,13 @@ type NICStateFlags struct {
}
// AddAddress adds a new network-layer address to the specified NIC.
-func (s *Stack) AddAddress(id tcpip.NICID, protocol tcpip.NetworkProtocolNumber, addr tcpip.Address) *tcpip.Error {
+func (s *Stack) AddAddress(id tcpip.NICID, protocol tcpip.NetworkProtocolNumber, addr tcpip.Address) tcpip.Error {
return s.AddAddressWithOptions(id, protocol, addr, CanBePrimaryEndpoint)
}
// AddAddressWithPrefix is the same as AddAddress, but allows you to specify
// the address prefix.
-func (s *Stack) AddAddressWithPrefix(id tcpip.NICID, protocol tcpip.NetworkProtocolNumber, addr tcpip.AddressWithPrefix) *tcpip.Error {
+func (s *Stack) AddAddressWithPrefix(id tcpip.NICID, protocol tcpip.NetworkProtocolNumber, addr tcpip.AddressWithPrefix) tcpip.Error {
ap := tcpip.ProtocolAddress{
Protocol: protocol,
AddressWithPrefix: addr,
@@ -1127,16 +1138,16 @@ func (s *Stack) AddAddressWithPrefix(id tcpip.NICID, protocol tcpip.NetworkProto
// AddProtocolAddress adds a new network-layer protocol address to the
// specified NIC.
-func (s *Stack) AddProtocolAddress(id tcpip.NICID, protocolAddress tcpip.ProtocolAddress) *tcpip.Error {
+func (s *Stack) AddProtocolAddress(id tcpip.NICID, protocolAddress tcpip.ProtocolAddress) tcpip.Error {
return s.AddProtocolAddressWithOptions(id, protocolAddress, CanBePrimaryEndpoint)
}
// AddAddressWithOptions is the same as AddAddress, but allows you to specify
// whether the new endpoint can be primary or not.
-func (s *Stack) AddAddressWithOptions(id tcpip.NICID, protocol tcpip.NetworkProtocolNumber, addr tcpip.Address, peb PrimaryEndpointBehavior) *tcpip.Error {
+func (s *Stack) AddAddressWithOptions(id tcpip.NICID, protocol tcpip.NetworkProtocolNumber, addr tcpip.Address, peb PrimaryEndpointBehavior) tcpip.Error {
netProto, ok := s.networkProtocols[protocol]
if !ok {
- return tcpip.ErrUnknownProtocol
+ return &tcpip.ErrUnknownProtocol{}
}
return s.AddProtocolAddressWithOptions(id, tcpip.ProtocolAddress{
Protocol: protocol,
@@ -1149,13 +1160,13 @@ func (s *Stack) AddAddressWithOptions(id tcpip.NICID, protocol tcpip.NetworkProt
// AddProtocolAddressWithOptions is the same as AddProtocolAddress, but allows
// you to specify whether the new endpoint can be primary or not.
-func (s *Stack) AddProtocolAddressWithOptions(id tcpip.NICID, protocolAddress tcpip.ProtocolAddress, peb PrimaryEndpointBehavior) *tcpip.Error {
+func (s *Stack) AddProtocolAddressWithOptions(id tcpip.NICID, protocolAddress tcpip.ProtocolAddress, peb PrimaryEndpointBehavior) tcpip.Error {
s.mu.RLock()
defer s.mu.RUnlock()
nic, ok := s.nics[id]
if !ok {
- return tcpip.ErrUnknownNICID
+ return &tcpip.ErrUnknownNICID{}
}
return nic.addAddress(protocolAddress, peb)
@@ -1163,7 +1174,7 @@ func (s *Stack) AddProtocolAddressWithOptions(id tcpip.NICID, protocolAddress tc
// RemoveAddress removes an existing network-layer address from the specified
// NIC.
-func (s *Stack) RemoveAddress(id tcpip.NICID, addr tcpip.Address) *tcpip.Error {
+func (s *Stack) RemoveAddress(id tcpip.NICID, addr tcpip.Address) tcpip.Error {
s.mu.RLock()
defer s.mu.RUnlock()
@@ -1171,7 +1182,7 @@ func (s *Stack) RemoveAddress(id tcpip.NICID, addr tcpip.Address) *tcpip.Error {
return nic.removeAddress(addr)
}
- return tcpip.ErrUnknownNICID
+ return &tcpip.ErrUnknownNICID{}
}
// AllAddresses returns a map of NICIDs to their protocol addresses (primary
@@ -1189,19 +1200,19 @@ func (s *Stack) AllAddresses() map[tcpip.NICID][]tcpip.ProtocolAddress {
// GetMainNICAddress returns the first non-deprecated primary address and prefix
// for the given NIC and protocol. If no non-deprecated primary address exists,
-// a deprecated primary address and prefix will be returned. Returns an error if
+// a deprecated primary address and prefix will be returned. Returns false if
// the NIC doesn't exist and an empty value if the NIC doesn't have a primary
// address for the given protocol.
-func (s *Stack) GetMainNICAddress(id tcpip.NICID, protocol tcpip.NetworkProtocolNumber) (tcpip.AddressWithPrefix, *tcpip.Error) {
+func (s *Stack) GetMainNICAddress(id tcpip.NICID, protocol tcpip.NetworkProtocolNumber) (tcpip.AddressWithPrefix, bool) {
s.mu.RLock()
defer s.mu.RUnlock()
nic, ok := s.nics[id]
if !ok {
- return tcpip.AddressWithPrefix{}, tcpip.ErrUnknownNICID
+ return tcpip.AddressWithPrefix{}, false
}
- return nic.primaryAddress(protocol), nil
+ return nic.primaryAddress(protocol), true
}
func (s *Stack) getAddressEP(nic *NIC, localAddr, remoteAddr tcpip.Address, netProto tcpip.NetworkProtocolNumber) AssignableAddressEndpoint {
@@ -1301,7 +1312,7 @@ func (s *Stack) findLocalRouteRLocked(localAddressNICID tcpip.NICID, localAddr,
// If no local address is provided, the stack will select a local address. If no
// remote address is provided, the stack wil use a remote address equal to the
// local address.
-func (s *Stack) FindRoute(id tcpip.NICID, localAddr, remoteAddr tcpip.Address, netProto tcpip.NetworkProtocolNumber, multicastLoop bool) (*Route, *tcpip.Error) {
+func (s *Stack) FindRoute(id tcpip.NICID, localAddr, remoteAddr tcpip.Address, netProto tcpip.NetworkProtocolNumber, multicastLoop bool) (*Route, tcpip.Error) {
s.mu.RLock()
defer s.mu.RUnlock()
@@ -1337,9 +1348,9 @@ func (s *Stack) FindRoute(id tcpip.NICID, localAddr, remoteAddr tcpip.Address, n
}
if isLoopback {
- return nil, tcpip.ErrBadLocalAddress
+ return nil, &tcpip.ErrBadLocalAddress{}
}
- return nil, tcpip.ErrNetworkUnreachable
+ return nil, &tcpip.ErrNetworkUnreachable{}
}
canForward := s.Forwarding(netProto) && !header.IsV6LinkLocalAddress(localAddr) && !isLinkLocal
@@ -1405,7 +1416,7 @@ func (s *Stack) FindRoute(id tcpip.NICID, localAddr, remoteAddr tcpip.Address, n
}
}
- return nil, tcpip.ErrNoRoute
+ return nil, &tcpip.ErrNoRoute{}
}
if id == 0 {
@@ -1425,12 +1436,12 @@ func (s *Stack) FindRoute(id tcpip.NICID, localAddr, remoteAddr tcpip.Address, n
}
if needRoute {
- return nil, tcpip.ErrNoRoute
+ return nil, &tcpip.ErrNoRoute{}
}
if header.IsV6LoopbackAddress(remoteAddr) {
- return nil, tcpip.ErrBadLocalAddress
+ return nil, &tcpip.ErrBadLocalAddress{}
}
- return nil, tcpip.ErrNetworkUnreachable
+ return nil, &tcpip.ErrNetworkUnreachable{}
}
// CheckNetworkProtocol checks if a given network protocol is enabled in the
@@ -1476,13 +1487,13 @@ func (s *Stack) CheckLocalAddress(nicID tcpip.NICID, protocol tcpip.NetworkProto
}
// SetPromiscuousMode enables or disables promiscuous mode in the given NIC.
-func (s *Stack) SetPromiscuousMode(nicID tcpip.NICID, enable bool) *tcpip.Error {
+func (s *Stack) SetPromiscuousMode(nicID tcpip.NICID, enable bool) tcpip.Error {
s.mu.RLock()
defer s.mu.RUnlock()
nic, ok := s.nics[nicID]
if !ok {
- return tcpip.ErrUnknownNICID
+ return &tcpip.ErrUnknownNICID{}
}
nic.setPromiscuousMode(enable)
@@ -1492,13 +1503,13 @@ func (s *Stack) SetPromiscuousMode(nicID tcpip.NICID, enable bool) *tcpip.Error
// SetSpoofing enables or disables address spoofing in the given NIC, allowing
// endpoints to bind to any address in the NIC.
-func (s *Stack) SetSpoofing(nicID tcpip.NICID, enable bool) *tcpip.Error {
+func (s *Stack) SetSpoofing(nicID tcpip.NICID, enable bool) tcpip.Error {
s.mu.RLock()
defer s.mu.RUnlock()
nic, ok := s.nics[nicID]
if !ok {
- return tcpip.ErrUnknownNICID
+ return &tcpip.ErrUnknownNICID{}
}
nic.setSpoofing(enable)
@@ -1506,17 +1517,27 @@ func (s *Stack) SetSpoofing(nicID tcpip.NICID, enable bool) *tcpip.Error {
return nil
}
-// AddLinkAddress adds a link address to the stack link cache.
-func (s *Stack) AddLinkAddress(nicID tcpip.NICID, addr tcpip.Address, linkAddr tcpip.LinkAddress) {
- fullAddr := tcpip.FullAddress{NIC: nicID, Addr: addr}
- s.linkAddrCache.add(fullAddr, linkAddr)
- // TODO: provide a way for a transport endpoint to receive a signal
- // that AddLinkAddress for a particular address has been called.
+// AddLinkAddress adds a link address for the neighbor on the specified NIC.
+func (s *Stack) AddLinkAddress(nicID tcpip.NICID, neighbor tcpip.Address, linkAddr tcpip.LinkAddress) tcpip.Error {
+ s.mu.RLock()
+ defer s.mu.RUnlock()
+
+ nic, ok := s.nics[nicID]
+ if !ok {
+ return &tcpip.ErrUnknownNICID{}
+ }
+
+ nic.linkAddrCache.AddLinkAddress(neighbor, linkAddr)
+ return nil
+}
+
+// LinkResolutionResult is the result of a link address resolution attempt.
+type LinkResolutionResult struct {
+ LinkAddress tcpip.LinkAddress
+ Success bool
}
-// GetLinkAddress finds the link address corresponding to a neighbor's address.
-//
-// Returns a link address for the remote address, if readily available.
+// GetLinkAddress finds the link address corresponding to a network address.
//
// Returns ErrNotSupported if the stack is not configured with a link address
// resolver for the specified network protocol.
@@ -1525,53 +1546,56 @@ func (s *Stack) AddLinkAddress(nicID tcpip.NICID, addr tcpip.Address, linkAddr t
// with a notification channel for the caller to block on. Triggers address
// resolution asynchronously.
//
-// If onResolve is provided, it will be called either immediately, if
-// resolution is not required, or when address resolution is complete, with
-// the resolved link address and whether resolution succeeded. After any
-// callbacks have been called, the returned notification channel is closed.
+// onResolve will be called either immediately, if resolution is not required,
+// or when address resolution is complete, with the resolved link address and
+// whether resolution succeeded.
//
// If specified, the local address must be an address local to the interface
// the neighbor cache belongs to. The local address is the source address of
// a packet prompting NUD/link address resolution.
-//
-// TODO(gvisor.dev/issue/5151): Don't return the link address.
-func (s *Stack) GetLinkAddress(nicID tcpip.NICID, addr, localAddr tcpip.Address, protocol tcpip.NetworkProtocolNumber, onResolve func(tcpip.LinkAddress, bool)) (tcpip.LinkAddress, <-chan struct{}, *tcpip.Error) {
+func (s *Stack) GetLinkAddress(nicID tcpip.NICID, addr, localAddr tcpip.Address, protocol tcpip.NetworkProtocolNumber, onResolve func(LinkResolutionResult)) tcpip.Error {
s.mu.RLock()
nic, ok := s.nics[nicID]
s.mu.RUnlock()
if !ok {
- return "", nil, tcpip.ErrUnknownNICID
+ return &tcpip.ErrUnknownNICID{}
}
linkRes, ok := s.linkAddrResolvers[protocol]
if !ok {
- return "", nil, tcpip.ErrNotSupported
+ return &tcpip.ErrNotSupported{}
+ }
+
+ if linkAddr, ok := linkRes.ResolveStaticAddress(addr); ok {
+ onResolve(LinkResolutionResult{LinkAddress: linkAddr, Success: true})
+ return nil
}
- return nic.getNeighborLinkAddress(addr, localAddr, linkRes, onResolve)
+ _, _, err := nic.getNeighborLinkAddress(addr, localAddr, linkRes, onResolve)
+ return err
}
// Neighbors returns all IP to MAC address associations.
-func (s *Stack) Neighbors(nicID tcpip.NICID) ([]NeighborEntry, *tcpip.Error) {
+func (s *Stack) Neighbors(nicID tcpip.NICID) ([]NeighborEntry, tcpip.Error) {
s.mu.RLock()
nic, ok := s.nics[nicID]
s.mu.RUnlock()
if !ok {
- return nil, tcpip.ErrUnknownNICID
+ return nil, &tcpip.ErrUnknownNICID{}
}
return nic.neighbors()
}
// AddStaticNeighbor statically associates an IP address to a MAC address.
-func (s *Stack) AddStaticNeighbor(nicID tcpip.NICID, addr tcpip.Address, linkAddr tcpip.LinkAddress) *tcpip.Error {
+func (s *Stack) AddStaticNeighbor(nicID tcpip.NICID, addr tcpip.Address, linkAddr tcpip.LinkAddress) tcpip.Error {
s.mu.RLock()
nic, ok := s.nics[nicID]
s.mu.RUnlock()
if !ok {
- return tcpip.ErrUnknownNICID
+ return &tcpip.ErrUnknownNICID{}
}
return nic.addStaticNeighbor(addr, linkAddr)
@@ -1580,26 +1604,26 @@ func (s *Stack) AddStaticNeighbor(nicID tcpip.NICID, addr tcpip.Address, linkAdd
// RemoveNeighbor removes an IP to MAC address association previously created
// either automically or by AddStaticNeighbor. Returns ErrBadAddress if there
// is no association with the provided address.
-func (s *Stack) RemoveNeighbor(nicID tcpip.NICID, addr tcpip.Address) *tcpip.Error {
+func (s *Stack) RemoveNeighbor(nicID tcpip.NICID, addr tcpip.Address) tcpip.Error {
s.mu.RLock()
nic, ok := s.nics[nicID]
s.mu.RUnlock()
if !ok {
- return tcpip.ErrUnknownNICID
+ return &tcpip.ErrUnknownNICID{}
}
return nic.removeNeighbor(addr)
}
// ClearNeighbors removes all IP to MAC address associations.
-func (s *Stack) ClearNeighbors(nicID tcpip.NICID) *tcpip.Error {
+func (s *Stack) ClearNeighbors(nicID tcpip.NICID) tcpip.Error {
s.mu.RLock()
nic, ok := s.nics[nicID]
s.mu.RUnlock()
if !ok {
- return tcpip.ErrUnknownNICID
+ return &tcpip.ErrUnknownNICID{}
}
return nic.clearNeighbors()
@@ -1609,25 +1633,25 @@ func (s *Stack) ClearNeighbors(nicID tcpip.NICID) *tcpip.Error {
// transport dispatcher. Received packets that match the provided id will be
// delivered to the given endpoint; specifying a nic is optional, but
// nic-specific IDs have precedence over global ones.
-func (s *Stack) RegisterTransportEndpoint(nicID tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, flags ports.Flags, bindToDevice tcpip.NICID) *tcpip.Error {
+func (s *Stack) RegisterTransportEndpoint(netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, flags ports.Flags, bindToDevice tcpip.NICID) tcpip.Error {
return s.demux.registerEndpoint(netProtos, protocol, id, ep, flags, bindToDevice)
}
// CheckRegisterTransportEndpoint checks if an endpoint can be registered with
// the stack transport dispatcher.
-func (s *Stack) CheckRegisterTransportEndpoint(nicID tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, flags ports.Flags, bindToDevice tcpip.NICID) *tcpip.Error {
+func (s *Stack) CheckRegisterTransportEndpoint(netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, flags ports.Flags, bindToDevice tcpip.NICID) tcpip.Error {
return s.demux.checkEndpoint(netProtos, protocol, id, flags, bindToDevice)
}
// UnregisterTransportEndpoint removes the endpoint with the given id from the
// stack transport dispatcher.
-func (s *Stack) UnregisterTransportEndpoint(nicID tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, flags ports.Flags, bindToDevice tcpip.NICID) {
+func (s *Stack) UnregisterTransportEndpoint(netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, flags ports.Flags, bindToDevice tcpip.NICID) {
s.demux.unregisterEndpoint(netProtos, protocol, id, ep, flags, bindToDevice)
}
// StartTransportEndpointCleanup removes the endpoint with the given id from
// the stack transport dispatcher. It also transitions it to the cleanup stage.
-func (s *Stack) StartTransportEndpointCleanup(nicID tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, flags ports.Flags, bindToDevice tcpip.NICID) {
+func (s *Stack) StartTransportEndpointCleanup(netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, flags ports.Flags, bindToDevice tcpip.NICID) {
s.cleanupEndpointsMu.Lock()
s.cleanupEndpoints[ep] = struct{}{}
s.cleanupEndpointsMu.Unlock()
@@ -1652,13 +1676,13 @@ func (s *Stack) FindTransportEndpoint(netProto tcpip.NetworkProtocolNumber, tran
// RegisterRawTransportEndpoint registers the given endpoint with the stack
// transport dispatcher. Received packets that match the provided transport
// protocol will be delivered to the given endpoint.
-func (s *Stack) RegisterRawTransportEndpoint(nicID tcpip.NICID, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, ep RawTransportEndpoint) *tcpip.Error {
+func (s *Stack) RegisterRawTransportEndpoint(netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, ep RawTransportEndpoint) tcpip.Error {
return s.demux.registerRawEndpoint(netProto, transProto, ep)
}
// UnregisterRawTransportEndpoint removes the endpoint for the transport
// protocol from the stack transport dispatcher.
-func (s *Stack) UnregisterRawTransportEndpoint(nicID tcpip.NICID, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, ep RawTransportEndpoint) {
+func (s *Stack) UnregisterRawTransportEndpoint(netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, ep RawTransportEndpoint) {
s.demux.unregisterRawEndpoint(netProto, transProto, ep)
}
@@ -1762,7 +1786,7 @@ func (s *Stack) Resume() {
// RegisterPacketEndpoint registers ep with the stack, causing it to receive
// all traffic of the specified netProto on the given NIC. If nicID is 0, it
// receives traffic from every NIC.
-func (s *Stack) RegisterPacketEndpoint(nicID tcpip.NICID, netProto tcpip.NetworkProtocolNumber, ep PacketEndpoint) *tcpip.Error {
+func (s *Stack) RegisterPacketEndpoint(nicID tcpip.NICID, netProto tcpip.NetworkProtocolNumber, ep PacketEndpoint) tcpip.Error {
s.mu.Lock()
defer s.mu.Unlock()
@@ -1781,7 +1805,7 @@ func (s *Stack) RegisterPacketEndpoint(nicID tcpip.NICID, netProto tcpip.Network
// Capture on a specific device.
nic, ok := s.nics[nicID]
if !ok {
- return tcpip.ErrUnknownNICID
+ return &tcpip.ErrUnknownNICID{}
}
if err := nic.registerPacketEndpoint(netProto, ep); err != nil {
return err
@@ -1819,12 +1843,12 @@ func (s *Stack) unregisterPacketEndpointLocked(nicID tcpip.NICID, netProto tcpip
// WritePacketToRemote writes a payload on the specified NIC using the provided
// network protocol and remote link address.
-func (s *Stack) WritePacketToRemote(nicID tcpip.NICID, remote tcpip.LinkAddress, netProto tcpip.NetworkProtocolNumber, payload buffer.VectorisedView) *tcpip.Error {
+func (s *Stack) WritePacketToRemote(nicID tcpip.NICID, remote tcpip.LinkAddress, netProto tcpip.NetworkProtocolNumber, payload buffer.VectorisedView) tcpip.Error {
s.mu.Lock()
nic, ok := s.nics[nicID]
s.mu.Unlock()
if !ok {
- return tcpip.ErrUnknownDevice
+ return &tcpip.ErrUnknownDevice{}
}
pkt := NewPacketBuffer(PacketBufferOptions{
ReserveHeaderBytes: int(nic.MaxHeaderLength()),
@@ -1889,37 +1913,37 @@ func (s *Stack) RemoveTCPProbe() {
}
// JoinGroup joins the given multicast group on the given NIC.
-func (s *Stack) JoinGroup(protocol tcpip.NetworkProtocolNumber, nicID tcpip.NICID, multicastAddr tcpip.Address) *tcpip.Error {
+func (s *Stack) JoinGroup(protocol tcpip.NetworkProtocolNumber, nicID tcpip.NICID, multicastAddr tcpip.Address) tcpip.Error {
s.mu.RLock()
defer s.mu.RUnlock()
if nic, ok := s.nics[nicID]; ok {
return nic.joinGroup(protocol, multicastAddr)
}
- return tcpip.ErrUnknownNICID
+ return &tcpip.ErrUnknownNICID{}
}
// LeaveGroup leaves the given multicast group on the given NIC.
-func (s *Stack) LeaveGroup(protocol tcpip.NetworkProtocolNumber, nicID tcpip.NICID, multicastAddr tcpip.Address) *tcpip.Error {
+func (s *Stack) LeaveGroup(protocol tcpip.NetworkProtocolNumber, nicID tcpip.NICID, multicastAddr tcpip.Address) tcpip.Error {
s.mu.RLock()
defer s.mu.RUnlock()
if nic, ok := s.nics[nicID]; ok {
return nic.leaveGroup(protocol, multicastAddr)
}
- return tcpip.ErrUnknownNICID
+ return &tcpip.ErrUnknownNICID{}
}
// IsInGroup returns true if the NIC with ID nicID has joined the multicast
// group multicastAddr.
-func (s *Stack) IsInGroup(nicID tcpip.NICID, multicastAddr tcpip.Address) (bool, *tcpip.Error) {
+func (s *Stack) IsInGroup(nicID tcpip.NICID, multicastAddr tcpip.Address) (bool, tcpip.Error) {
s.mu.RLock()
defer s.mu.RUnlock()
if nic, ok := s.nics[nicID]; ok {
return nic.isInGroup(multicastAddr), nil
}
- return false, tcpip.ErrUnknownNICID
+ return false, &tcpip.ErrUnknownNICID{}
}
// IPTables returns the stack's iptables.
@@ -1959,26 +1983,26 @@ func (s *Stack) AllowICMPMessage() bool {
// GetNetworkEndpoint returns the NetworkEndpoint with the specified protocol
// number installed on the specified NIC.
-func (s *Stack) GetNetworkEndpoint(nicID tcpip.NICID, proto tcpip.NetworkProtocolNumber) (NetworkEndpoint, *tcpip.Error) {
+func (s *Stack) GetNetworkEndpoint(nicID tcpip.NICID, proto tcpip.NetworkProtocolNumber) (NetworkEndpoint, tcpip.Error) {
s.mu.Lock()
defer s.mu.Unlock()
nic, ok := s.nics[nicID]
if !ok {
- return nil, tcpip.ErrUnknownNICID
+ return nil, &tcpip.ErrUnknownNICID{}
}
return nic.getNetworkEndpoint(proto), nil
}
// NUDConfigurations gets the per-interface NUD configurations.
-func (s *Stack) NUDConfigurations(id tcpip.NICID) (NUDConfigurations, *tcpip.Error) {
+func (s *Stack) NUDConfigurations(id tcpip.NICID) (NUDConfigurations, tcpip.Error) {
s.mu.RLock()
nic, ok := s.nics[id]
s.mu.RUnlock()
if !ok {
- return NUDConfigurations{}, tcpip.ErrUnknownNICID
+ return NUDConfigurations{}, &tcpip.ErrUnknownNICID{}
}
return nic.nudConfigs()
@@ -1988,13 +2012,13 @@ func (s *Stack) NUDConfigurations(id tcpip.NICID) (NUDConfigurations, *tcpip.Err
//
// Note, if c contains invalid NUD configuration values, it will be fixed to
// use default values for the erroneous values.
-func (s *Stack) SetNUDConfigurations(id tcpip.NICID, c NUDConfigurations) *tcpip.Error {
+func (s *Stack) SetNUDConfigurations(id tcpip.NICID, c NUDConfigurations) tcpip.Error {
s.mu.RLock()
nic, ok := s.nics[id]
s.mu.RUnlock()
if !ok {
- return tcpip.ErrUnknownNICID
+ return &tcpip.ErrUnknownNICID{}
}
return nic.setNUDConfigs(c)
@@ -2036,7 +2060,7 @@ func generateRandInt64() int64 {
}
// FindNetworkEndpoint returns the network endpoint for the given address.
-func (s *Stack) FindNetworkEndpoint(netProto tcpip.NetworkProtocolNumber, address tcpip.Address) (NetworkEndpoint, *tcpip.Error) {
+func (s *Stack) FindNetworkEndpoint(netProto tcpip.NetworkProtocolNumber, address tcpip.Address) (NetworkEndpoint, tcpip.Error) {
s.mu.RLock()
defer s.mu.RUnlock()
@@ -2048,7 +2072,7 @@ func (s *Stack) FindNetworkEndpoint(netProto tcpip.NetworkProtocolNumber, addres
addressEndpoint.DecRef()
return nic.getNetworkEndpoint(netProto), nil
}
- return nil, tcpip.ErrBadAddress
+ return nil, &tcpip.ErrBadAddress{}
}
// FindNICNameFromID returns the name of the NIC for the given NICID.
diff --git a/pkg/tcpip/stack/stack_options.go b/pkg/tcpip/stack/stack_options.go
index 0b093e6c5..8d9b20b7e 100644
--- a/pkg/tcpip/stack/stack_options.go
+++ b/pkg/tcpip/stack/stack_options.go
@@ -14,7 +14,9 @@
package stack
-import "gvisor.dev/gvisor/pkg/tcpip"
+import (
+ "gvisor.dev/gvisor/pkg/tcpip"
+)
const (
// MinBufferSize is the smallest size of a receive or send buffer.
@@ -29,14 +31,6 @@ const (
DefaultMaxBufferSize = 4 << 20 // 4 MiB
)
-// SendBufferSizeOption is used by stack.(Stack*).Option/SetOption to
-// get/set the default, min and max send buffer sizes.
-type SendBufferSizeOption struct {
- Min int
- Default int
- Max int
-}
-
// ReceiveBufferSizeOption is used by stack.(Stack*).Option/SetOption to
// get/set the default, min and max receive buffer sizes.
type ReceiveBufferSizeOption struct {
@@ -46,17 +40,17 @@ type ReceiveBufferSizeOption struct {
}
// SetOption allows setting stack wide options.
-func (s *Stack) SetOption(option interface{}) *tcpip.Error {
+func (s *Stack) SetOption(option interface{}) tcpip.Error {
switch v := option.(type) {
- case SendBufferSizeOption:
+ case tcpip.SendBufferSizeOption:
// Make sure we don't allow lowering the buffer below minimum
// required for stack to work.
if v.Min < MinBufferSize {
- return tcpip.ErrInvalidOptionValue
+ return &tcpip.ErrInvalidOptionValue{}
}
if v.Default < v.Min || v.Default > v.Max {
- return tcpip.ErrInvalidOptionValue
+ return &tcpip.ErrInvalidOptionValue{}
}
s.mu.Lock()
@@ -68,11 +62,11 @@ func (s *Stack) SetOption(option interface{}) *tcpip.Error {
// Make sure we don't allow lowering the buffer below minimum
// required for stack to work.
if v.Min < MinBufferSize {
- return tcpip.ErrInvalidOptionValue
+ return &tcpip.ErrInvalidOptionValue{}
}
if v.Default < v.Min || v.Default > v.Max {
- return tcpip.ErrInvalidOptionValue
+ return &tcpip.ErrInvalidOptionValue{}
}
s.mu.Lock()
@@ -81,14 +75,14 @@ func (s *Stack) SetOption(option interface{}) *tcpip.Error {
return nil
default:
- return tcpip.ErrUnknownProtocolOption
+ return &tcpip.ErrUnknownProtocolOption{}
}
}
// Option allows retrieving stack wide options.
-func (s *Stack) Option(option interface{}) *tcpip.Error {
+func (s *Stack) Option(option interface{}) tcpip.Error {
switch v := option.(type) {
- case *SendBufferSizeOption:
+ case *tcpip.SendBufferSizeOption:
s.mu.RLock()
*v = s.sendBufferSize
s.mu.RUnlock()
@@ -101,6 +95,6 @@ func (s *Stack) Option(option interface{}) *tcpip.Error {
return nil
default:
- return tcpip.ErrUnknownProtocolOption
+ return &tcpip.ErrUnknownProtocolOption{}
}
}
diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go
index 7e935ddff..41f95811f 100644
--- a/pkg/tcpip/stack/stack_test.go
+++ b/pkg/tcpip/stack/stack_test.go
@@ -60,6 +60,15 @@ const (
protocolNumberOffset = 2
)
+func checkGetMainNICAddress(s *stack.Stack, nicID tcpip.NICID, proto tcpip.NetworkProtocolNumber, want tcpip.AddressWithPrefix) error {
+ if addr, ok := s.GetMainNICAddress(nicID, proto); !ok {
+ return fmt.Errorf("got stack.GetMainNICAddress(%d, %d) = (_, false), want = (_, true)", nicID, proto)
+ } else if addr != want {
+ return fmt.Errorf("got stack.GetMainNICAddress(%d, %d) = (%s, true), want = (%s, true)", nicID, proto, addr, want)
+ }
+ return nil
+}
+
// fakeNetworkEndpoint is a network-layer protocol endpoint. It counts sent and
// received packets; the counts of all endpoints are aggregated in the protocol
// descriptor.
@@ -81,7 +90,7 @@ type fakeNetworkEndpoint struct {
dispatcher stack.TransportDispatcher
}
-func (f *fakeNetworkEndpoint) Enable() *tcpip.Error {
+func (f *fakeNetworkEndpoint) Enable() tcpip.Error {
f.mu.Lock()
defer f.mu.Unlock()
f.mu.enabled = true
@@ -145,7 +154,7 @@ func (f *fakeNetworkEndpoint) MaxHeaderLength() uint16 {
return f.nic.MaxHeaderLength() + fakeNetHeaderLen
}
-func (f *fakeNetworkEndpoint) PseudoHeaderChecksum(protocol tcpip.TransportProtocolNumber, dstAddr tcpip.Address) uint16 {
+func (*fakeNetworkEndpoint) PseudoHeaderChecksum(protocol tcpip.TransportProtocolNumber, dstAddr tcpip.Address) uint16 {
return 0
}
@@ -153,7 +162,7 @@ func (f *fakeNetworkEndpoint) NetworkProtocolNumber() tcpip.NetworkProtocolNumbe
return f.proto.Number()
}
-func (f *fakeNetworkEndpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.NetworkHeaderParams, pkt *stack.PacketBuffer) *tcpip.Error {
+func (f *fakeNetworkEndpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.NetworkHeaderParams, pkt *stack.PacketBuffer) tcpip.Error {
// Increment the sent packet count in the protocol descriptor.
f.proto.sendPacketCount[int(r.RemoteAddress[0])%len(f.proto.sendPacketCount)]++
@@ -176,18 +185,30 @@ func (f *fakeNetworkEndpoint) WritePacket(r *stack.Route, gso *stack.GSO, params
}
// WritePackets implements stack.LinkEndpoint.WritePackets.
-func (f *fakeNetworkEndpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.PacketBufferList, params stack.NetworkHeaderParams) (int, *tcpip.Error) {
+func (*fakeNetworkEndpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.PacketBufferList, params stack.NetworkHeaderParams) (int, tcpip.Error) {
panic("not implemented")
}
-func (*fakeNetworkEndpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack.PacketBuffer) *tcpip.Error {
- return tcpip.ErrNotSupported
+func (*fakeNetworkEndpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack.PacketBuffer) tcpip.Error {
+ return &tcpip.ErrNotSupported{}
}
func (f *fakeNetworkEndpoint) Close() {
f.AddressableEndpointState.Cleanup()
}
+// Stats implements NetworkEndpoint.
+func (*fakeNetworkEndpoint) Stats() stack.NetworkEndpointStats {
+ return &fakeNetworkEndpointStats{}
+}
+
+var _ stack.NetworkEndpointStats = (*fakeNetworkEndpointStats)(nil)
+
+type fakeNetworkEndpointStats struct{}
+
+// IsNetworkEndpointStats implements stack.NetworkEndpointStats.
+func (*fakeNetworkEndpointStats) IsNetworkEndpointStats() {}
+
// fakeNetworkProtocol is a network-layer protocol descriptor. It aggregates the
// number of packets sent and received via endpoints of this protocol. The index
// where packets are added is given by the packet's destination address MOD 10.
@@ -202,15 +223,15 @@ type fakeNetworkProtocol struct {
}
}
-func (f *fakeNetworkProtocol) Number() tcpip.NetworkProtocolNumber {
+func (*fakeNetworkProtocol) Number() tcpip.NetworkProtocolNumber {
return fakeNetNumber
}
-func (f *fakeNetworkProtocol) MinimumPacketSize() int {
+func (*fakeNetworkProtocol) MinimumPacketSize() int {
return fakeNetHeaderLen
}
-func (f *fakeNetworkProtocol) DefaultPrefixLen() int {
+func (*fakeNetworkProtocol) DefaultPrefixLen() int {
return fakeDefaultPrefixLen
}
@@ -232,23 +253,23 @@ func (f *fakeNetworkProtocol) NewEndpoint(nic stack.NetworkInterface, _ stack.Li
return e
}
-func (f *fakeNetworkProtocol) SetOption(option tcpip.SettableNetworkProtocolOption) *tcpip.Error {
+func (f *fakeNetworkProtocol) SetOption(option tcpip.SettableNetworkProtocolOption) tcpip.Error {
switch v := option.(type) {
case *tcpip.DefaultTTLOption:
f.defaultTTL = uint8(*v)
return nil
default:
- return tcpip.ErrUnknownProtocolOption
+ return &tcpip.ErrUnknownProtocolOption{}
}
}
-func (f *fakeNetworkProtocol) Option(option tcpip.GettableNetworkProtocolOption) *tcpip.Error {
+func (f *fakeNetworkProtocol) Option(option tcpip.GettableNetworkProtocolOption) tcpip.Error {
switch v := option.(type) {
case *tcpip.DefaultTTLOption:
*v = tcpip.DefaultTTLOption(f.defaultTTL)
return nil
default:
- return tcpip.ErrUnknownProtocolOption
+ return &tcpip.ErrUnknownProtocolOption{}
}
}
@@ -397,7 +418,7 @@ func TestNetworkReceive(t *testing.T) {
}
}
-func sendTo(s *stack.Stack, addr tcpip.Address, payload buffer.View) *tcpip.Error {
+func sendTo(s *stack.Stack, addr tcpip.Address, payload buffer.View) tcpip.Error {
r, err := s.FindRoute(0, "", addr, fakeNetNumber, false /* multicastLoop */)
if err != nil {
return err
@@ -406,7 +427,7 @@ func sendTo(s *stack.Stack, addr tcpip.Address, payload buffer.View) *tcpip.Erro
return send(r, payload)
}
-func send(r *stack.Route, payload buffer.View) *tcpip.Error {
+func send(r *stack.Route, payload buffer.View) tcpip.Error {
return r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: fakeTransNumber, TTL: 123, TOS: stack.DefaultTOS}, stack.NewPacketBuffer(stack.PacketBufferOptions{
ReserveHeaderBytes: int(r.MaxHeaderLength()),
Data: payload.ToVectorisedView(),
@@ -435,14 +456,14 @@ func testSend(t *testing.T, r *stack.Route, ep *channel.Endpoint, payload buffer
}
}
-func testFailingSend(t *testing.T, r *stack.Route, ep *channel.Endpoint, payload buffer.View, wantErr *tcpip.Error) {
+func testFailingSend(t *testing.T, r *stack.Route, ep *channel.Endpoint, payload buffer.View, wantErr tcpip.Error) {
t.Helper()
if gotErr := send(r, payload); gotErr != wantErr {
t.Errorf("send failed: got = %s, want = %s ", gotErr, wantErr)
}
}
-func testFailingSendTo(t *testing.T, s *stack.Stack, addr tcpip.Address, ep *channel.Endpoint, payload buffer.View, wantErr *tcpip.Error) {
+func testFailingSendTo(t *testing.T, s *stack.Stack, addr tcpip.Address, ep *channel.Endpoint, payload buffer.View, wantErr tcpip.Error) {
t.Helper()
if gotErr := sendTo(s, addr, payload); gotErr != wantErr {
t.Errorf("sendto failed: got = %s, want = %s ", gotErr, wantErr)
@@ -579,8 +600,8 @@ func testRoute(t *testing.T, s *stack.Stack, nic tcpip.NICID, srcAddr, dstAddr,
func testNoRoute(t *testing.T, s *stack.Stack, nic tcpip.NICID, srcAddr, dstAddr tcpip.Address) {
_, err := s.FindRoute(nic, srcAddr, dstAddr, fakeNetNumber, false /* multicastLoop */)
- if err != tcpip.ErrNoRoute {
- t.Fatalf("FindRoute returned unexpected error, got = %v, want = %s", err, tcpip.ErrNoRoute)
+ if _, ok := err.(*tcpip.ErrNoRoute); !ok {
+ t.Fatalf("FindRoute returned unexpected error, got = %v, want = %s", err, &tcpip.ErrNoRoute{})
}
}
@@ -628,8 +649,9 @@ func TestDisableUnknownNIC(t *testing.T) {
NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory},
})
- if err := s.DisableNIC(1); err != tcpip.ErrUnknownNICID {
- t.Fatalf("got s.DisableNIC(1) = %v, want = %s", err, tcpip.ErrUnknownNICID)
+ err := s.DisableNIC(1)
+ if _, ok := err.(*tcpip.ErrUnknownNICID); !ok {
+ t.Fatalf("got s.DisableNIC(1) = %v, want = %s", err, &tcpip.ErrUnknownNICID{})
}
}
@@ -687,8 +709,9 @@ func TestRemoveUnknownNIC(t *testing.T) {
NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory},
})
- if err := s.RemoveNIC(1); err != tcpip.ErrUnknownNICID {
- t.Fatalf("got s.RemoveNIC(1) = %v, want = %s", err, tcpip.ErrUnknownNICID)
+ err := s.RemoveNIC(1)
+ if _, ok := err.(*tcpip.ErrUnknownNICID); !ok {
+ t.Fatalf("got s.RemoveNIC(1) = %v, want = %s", err, &tcpip.ErrUnknownNICID{})
}
}
@@ -731,8 +754,8 @@ func TestRemoveNIC(t *testing.T) {
func TestRouteWithDownNIC(t *testing.T) {
tests := []struct {
name string
- downFn func(s *stack.Stack, nicID tcpip.NICID) *tcpip.Error
- upFn func(s *stack.Stack, nicID tcpip.NICID) *tcpip.Error
+ downFn func(s *stack.Stack, nicID tcpip.NICID) tcpip.Error
+ upFn func(s *stack.Stack, nicID tcpip.NICID) tcpip.Error
}{
{
name: "Disabled NIC",
@@ -890,15 +913,15 @@ func TestRouteWithDownNIC(t *testing.T) {
if err := test.downFn(s, nicID1); err != nil {
t.Fatalf("test.downFn(_, %d): %s", nicID1, err)
}
- testFailingSend(t, r1, ep1, buf, tcpip.ErrInvalidEndpointState)
+ testFailingSend(t, r1, ep1, buf, &tcpip.ErrInvalidEndpointState{})
testSend(t, r2, ep2, buf)
// Writes with Routes that use NIC2 after being brought down should fail.
if err := test.downFn(s, nicID2); err != nil {
t.Fatalf("test.downFn(_, %d): %s", nicID2, err)
}
- testFailingSend(t, r1, ep1, buf, tcpip.ErrInvalidEndpointState)
- testFailingSend(t, r2, ep2, buf, tcpip.ErrInvalidEndpointState)
+ testFailingSend(t, r1, ep1, buf, &tcpip.ErrInvalidEndpointState{})
+ testFailingSend(t, r2, ep2, buf, &tcpip.ErrInvalidEndpointState{})
if upFn := test.upFn; upFn != nil {
// Writes with Routes that use NIC1 after being brought up should
@@ -911,7 +934,7 @@ func TestRouteWithDownNIC(t *testing.T) {
t.Fatalf("test.upFn(_, %d): %s", nicID1, err)
}
testSend(t, r1, ep1, buf)
- testFailingSend(t, r2, ep2, buf, tcpip.ErrInvalidEndpointState)
+ testFailingSend(t, r2, ep2, buf, &tcpip.ErrInvalidEndpointState{})
}
})
}
@@ -1036,11 +1059,12 @@ func TestAddressRemoval(t *testing.T) {
t.Fatal("RemoveAddress failed:", err)
}
testFailingRecv(t, fakeNet, localAddrByte, ep, buf)
- testFailingSendTo(t, s, remoteAddr, ep, nil, tcpip.ErrNoRoute)
+ testFailingSendTo(t, s, remoteAddr, ep, nil, &tcpip.ErrNoRoute{})
// Check that removing the same address fails.
- if err := s.RemoveAddress(1, localAddr); err != tcpip.ErrBadLocalAddress {
- t.Fatalf("RemoveAddress returned unexpected error, got = %v, want = %s", err, tcpip.ErrBadLocalAddress)
+ err := s.RemoveAddress(1, localAddr)
+ if _, ok := err.(*tcpip.ErrBadLocalAddress); !ok {
+ t.Fatalf("RemoveAddress returned unexpected error, got = %v, want = %s", err, &tcpip.ErrBadLocalAddress{})
}
}
@@ -1087,12 +1111,15 @@ func TestAddressRemovalWithRouteHeld(t *testing.T) {
t.Fatal("RemoveAddress failed:", err)
}
testFailingRecv(t, fakeNet, localAddrByte, ep, buf)
- testFailingSend(t, r, ep, nil, tcpip.ErrInvalidEndpointState)
- testFailingSendTo(t, s, remoteAddr, ep, nil, tcpip.ErrNoRoute)
+ testFailingSend(t, r, ep, nil, &tcpip.ErrInvalidEndpointState{})
+ testFailingSendTo(t, s, remoteAddr, ep, nil, &tcpip.ErrNoRoute{})
// Check that removing the same address fails.
- if err := s.RemoveAddress(1, localAddr); err != tcpip.ErrBadLocalAddress {
- t.Fatalf("RemoveAddress returned unexpected error, got = %v, want = %s", err, tcpip.ErrBadLocalAddress)
+ {
+ err := s.RemoveAddress(1, localAddr)
+ if _, ok := err.(*tcpip.ErrBadLocalAddress); !ok {
+ t.Fatalf("RemoveAddress returned unexpected error, got = %v, want = %s", err, &tcpip.ErrBadLocalAddress{})
+ }
}
}
@@ -1186,7 +1213,7 @@ func TestEndpointExpiration(t *testing.T) {
// FIXME(b/139841518):Spoofing doesn't work if there is no primary address.
// testSendTo(t, s, remoteAddr, ep, nil)
} else {
- testFailingSendTo(t, s, remoteAddr, ep, nil, tcpip.ErrNoRoute)
+ testFailingSendTo(t, s, remoteAddr, ep, nil, &tcpip.ErrNoRoute{})
}
// 2. Add Address, everything should work.
@@ -1214,7 +1241,7 @@ func TestEndpointExpiration(t *testing.T) {
// FIXME(b/139841518):Spoofing doesn't work if there is no primary address.
// testSendTo(t, s, remoteAddr, ep, nil)
} else {
- testFailingSendTo(t, s, remoteAddr, ep, nil, tcpip.ErrNoRoute)
+ testFailingSendTo(t, s, remoteAddr, ep, nil, &tcpip.ErrNoRoute{})
}
// 4. Add Address back, everything should work again.
@@ -1253,8 +1280,8 @@ func TestEndpointExpiration(t *testing.T) {
testSend(t, r, ep, nil)
testSendTo(t, s, remoteAddr, ep, nil)
} else {
- testFailingSend(t, r, ep, nil, tcpip.ErrInvalidEndpointState)
- testFailingSendTo(t, s, remoteAddr, ep, nil, tcpip.ErrNoRoute)
+ testFailingSend(t, r, ep, nil, &tcpip.ErrInvalidEndpointState{})
+ testFailingSendTo(t, s, remoteAddr, ep, nil, &tcpip.ErrNoRoute{})
}
// 7. Add Address back, everything should work again.
@@ -1290,7 +1317,7 @@ func TestEndpointExpiration(t *testing.T) {
// FIXME(b/139841518):Spoofing doesn't work if there is no primary address.
// testSendTo(t, s, remoteAddr, ep, nil)
} else {
- testFailingSendTo(t, s, remoteAddr, ep, nil, tcpip.ErrNoRoute)
+ testFailingSendTo(t, s, remoteAddr, ep, nil, &tcpip.ErrNoRoute{})
}
})
}
@@ -1333,8 +1360,8 @@ func TestPromiscuousMode(t *testing.T) {
// Check that we can't get a route as there is no local address.
_, err := s.FindRoute(0, "", "\x02", fakeNetNumber, false /* multicastLoop */)
- if err != tcpip.ErrNoRoute {
- t.Fatalf("FindRoute returned unexpected error: got = %v, want = %s", err, tcpip.ErrNoRoute)
+ if _, ok := err.(*tcpip.ErrNoRoute); !ok {
+ t.Fatalf("FindRoute returned unexpected error: got = %v, want = %s", err, &tcpip.ErrNoRoute{})
}
// Set promiscuous mode to false, then check that packet can't be
@@ -1540,7 +1567,7 @@ func TestSpoofingNoAddress(t *testing.T) {
t.Errorf("FindRoute succeeded with route %+v when it should have failed", r)
}
// Sending a packet fails.
- testFailingSendTo(t, s, dstAddr, ep, nil, tcpip.ErrNoRoute)
+ testFailingSendTo(t, s, dstAddr, ep, nil, &tcpip.ErrNoRoute{})
// With address spoofing enabled, FindRoute permits any address to be used
// as the source.
@@ -1590,8 +1617,11 @@ func TestOutgoingBroadcastWithEmptyRouteTable(t *testing.T) {
s.SetRouteTable([]tcpip.Route{})
// If there is no endpoint, it won't work.
- if _, err := s.FindRoute(1, header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, false /* multicastLoop */); err != tcpip.ErrNetworkUnreachable {
- t.Fatalf("got FindRoute(1, %s, %s, %d) = %s, want = %s", header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, err, tcpip.ErrNetworkUnreachable)
+ {
+ _, err := s.FindRoute(1, header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, false /* multicastLoop */)
+ if _, ok := err.(*tcpip.ErrNetworkUnreachable); !ok {
+ t.Fatalf("got FindRoute(1, %s, %s, %d) = %s, want = %s", header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, err, &tcpip.ErrNetworkUnreachable{})
+ }
}
protoAddr := tcpip.ProtocolAddress{Protocol: fakeNetNumber, AddressWithPrefix: tcpip.AddressWithPrefix{header.IPv4Any, 0}}
@@ -1610,8 +1640,11 @@ func TestOutgoingBroadcastWithEmptyRouteTable(t *testing.T) {
}
// If the NIC doesn't exist, it won't work.
- if _, err := s.FindRoute(2, header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, false /* multicastLoop */); err != tcpip.ErrNetworkUnreachable {
- t.Fatalf("got FindRoute(2, %v, %v, %d) = %v want = %v", header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, err, tcpip.ErrNetworkUnreachable)
+ {
+ _, err := s.FindRoute(2, header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, false /* multicastLoop */)
+ if _, ok := err.(*tcpip.ErrNetworkUnreachable); !ok {
+ t.Fatalf("got FindRoute(2, %v, %v, %d) = %v want = %v", header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, err, &tcpip.ErrNetworkUnreachable{})
+ }
}
}
@@ -1753,9 +1786,9 @@ func TestMulticastOrIPv6LinkLocalNeedsNoRoute(t *testing.T) {
anyAddr = header.IPv6Any
}
- want := tcpip.ErrNetworkUnreachable
+ var want tcpip.Error = &tcpip.ErrNetworkUnreachable{}
if tc.routeNeeded {
- want = tcpip.ErrNoRoute
+ want = &tcpip.ErrNoRoute{}
}
// If there is no endpoint, it won't work.
@@ -1769,8 +1802,8 @@ func TestMulticastOrIPv6LinkLocalNeedsNoRoute(t *testing.T) {
if r, err := s.FindRoute(1, anyAddr, tc.address, fakeNetNumber, false /* multicastLoop */); tc.routeNeeded {
// Route table is empty but we need a route, this should cause an error.
- if err != tcpip.ErrNoRoute {
- t.Fatalf("got FindRoute(1, %v, %v, %v) = %v, want = %v", anyAddr, tc.address, fakeNetNumber, err, tcpip.ErrNoRoute)
+ if _, ok := err.(*tcpip.ErrNoRoute); !ok {
+ t.Fatalf("got FindRoute(1, %v, %v, %v) = %v, want = %v", anyAddr, tc.address, fakeNetNumber, err, &tcpip.ErrNoRoute{})
}
} else {
if err != nil {
@@ -1861,20 +1894,20 @@ func TestGetMainNICAddressAddPrimaryNonPrimary(t *testing.T) {
// Check that GetMainNICAddress returns an address if at least
// one primary address was added. In that case make sure the
// address/prefixLen matches what we added.
- gotAddr, err := s.GetMainNICAddress(1, fakeNetNumber)
- if err != nil {
- t.Fatal("GetMainNICAddress failed:", err)
+ gotAddr, ok := s.GetMainNICAddress(1, fakeNetNumber)
+ if !ok {
+ t.Fatalf("got GetMainNICAddress(1, %d) = (_, false), want = (_, true)", fakeNetNumber)
}
if len(primaryAddrAdded) == 0 {
// No primary addresses present.
if wantAddr := (tcpip.AddressWithPrefix{}); gotAddr != wantAddr {
- t.Fatalf("GetMainNICAddress: got addr = %s, want = %s", gotAddr, wantAddr)
+ t.Fatalf("got GetMainNICAddress(1, %d) = (%s, true), want = (%s, true)", fakeNetNumber, gotAddr, wantAddr)
}
} else {
// At least one primary address was added, verify the returned
// address is in the list of primary addresses we added.
if _, ok := primaryAddrAdded[gotAddr]; !ok {
- t.Fatalf("GetMainNICAddress: got = %s, want any in {%v}", gotAddr, primaryAddrAdded)
+ t.Fatalf("got GetMainNICAddress(1, %d) = (%s, true), want = (%s, true)", fakeNetNumber, gotAddr, primaryAddrAdded)
}
}
})
@@ -1915,12 +1948,8 @@ func TestGetMainNICAddressAddRemove(t *testing.T) {
}
// Check that we get the right initial address and prefix length.
- gotAddr, err := s.GetMainNICAddress(1, fakeNetNumber)
- if err != nil {
- t.Fatal("GetMainNICAddress failed:", err)
- }
- if wantAddr := protocolAddress.AddressWithPrefix; gotAddr != wantAddr {
- t.Fatalf("got s.GetMainNICAddress(...) = %s, want = %s", gotAddr, wantAddr)
+ if err := checkGetMainNICAddress(s, 1, fakeNetNumber, protocolAddress.AddressWithPrefix); err != nil {
+ t.Fatal(err)
}
if err := s.RemoveAddress(1, protocolAddress.AddressWithPrefix.Address); err != nil {
@@ -1928,12 +1957,8 @@ func TestGetMainNICAddressAddRemove(t *testing.T) {
}
// Check that we get no address after removal.
- gotAddr, err = s.GetMainNICAddress(1, fakeNetNumber)
- if err != nil {
- t.Fatal("GetMainNICAddress failed:", err)
- }
- if wantAddr := (tcpip.AddressWithPrefix{}); gotAddr != wantAddr {
- t.Fatalf("got GetMainNICAddress(...) = %s, want = %s", gotAddr, wantAddr)
+ if err := checkGetMainNICAddress(s, 1, fakeNetNumber, tcpip.AddressWithPrefix{}); err != nil {
+ t.Fatal(err)
}
})
}
@@ -2102,7 +2127,7 @@ func TestCreateNICWithOptions(t *testing.T) {
type callArgsAndExpect struct {
nicID tcpip.NICID
opts stack.NICOptions
- err *tcpip.Error
+ err tcpip.Error
}
tests := []struct {
@@ -2120,7 +2145,7 @@ func TestCreateNICWithOptions(t *testing.T) {
{
nicID: tcpip.NICID(1),
opts: stack.NICOptions{Name: "eth2"},
- err: tcpip.ErrDuplicateNICID,
+ err: &tcpip.ErrDuplicateNICID{},
},
},
},
@@ -2135,7 +2160,7 @@ func TestCreateNICWithOptions(t *testing.T) {
{
nicID: tcpip.NICID(2),
opts: stack.NICOptions{Name: "lo"},
- err: tcpip.ErrDuplicateNICID,
+ err: &tcpip.ErrDuplicateNICID{},
},
},
},
@@ -2165,7 +2190,7 @@ func TestCreateNICWithOptions(t *testing.T) {
{
nicID: tcpip.NICID(1),
opts: stack.NICOptions{},
- err: tcpip.ErrDuplicateNICID,
+ err: &tcpip.ErrDuplicateNICID{},
},
},
},
@@ -2474,12 +2499,12 @@ func TestNICAutoGenLinkLocalAddr(t *testing.T) {
}
}
- gotMainAddr, err := s.GetMainNICAddress(1, header.IPv6ProtocolNumber)
- if err != nil {
- t.Fatalf("stack.GetMainNICAddress(_, _) err = %s", err)
+ // Check that we get no address after removal.
+ if err := checkGetMainNICAddress(s, 1, fakeNetNumber, tcpip.AddressWithPrefix{}); err != nil {
+ t.Fatal(err)
}
- if gotMainAddr != expectedMainAddr {
- t.Fatalf("got stack.GetMainNICAddress(_, _) = %s, want = %s", gotMainAddr, expectedMainAddr)
+ if err := checkGetMainNICAddress(s, 1, header.IPv6ProtocolNumber, expectedMainAddr); err != nil {
+ t.Fatal(err)
}
})
}
@@ -2525,12 +2550,8 @@ func TestNoLinkLocalAutoGenForLoopbackNIC(t *testing.T) {
t.Fatalf("CreateNICWithOptions(%d, _, %+v) = %s", nicID, nicOpts, err)
}
- addr, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber)
- if err != nil {
- t.Fatalf("stack.GetMainNICAddress(%d, _) err = %s", nicID, err)
- }
- if want := (tcpip.AddressWithPrefix{}); addr != want {
- t.Errorf("got stack.GetMainNICAddress(%d, _) = %s, want = %s", nicID, addr, want)
+ if err := checkGetMainNICAddress(s, 1, header.IPv6ProtocolNumber, tcpip.AddressWithPrefix{}); err != nil {
+ t.Fatal(err)
}
})
}
@@ -2561,12 +2582,8 @@ func TestNICAutoGenAddrDoesDAD(t *testing.T) {
// Address should not be considered bound to the
// NIC yet (DAD ongoing).
- addr, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber)
- if err != nil {
- t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %v), want = (_, nil)", nicID, header.IPv6ProtocolNumber, err)
- }
- if want := (tcpip.AddressWithPrefix{}); addr != want {
- t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (%s, nil), want = (%s, nil)", nicID, header.IPv6ProtocolNumber, addr, want)
+ if err := checkGetMainNICAddress(s, nicID, header.IPv6ProtocolNumber, tcpip.AddressWithPrefix{}); err != nil {
+ t.Fatal(err)
}
linkLocalAddr := header.LinkLocalAddr(linkAddr1)
@@ -2584,12 +2601,8 @@ func TestNICAutoGenAddrDoesDAD(t *testing.T) {
t.Errorf("dad event mismatch (-want +got):\n%s", diff)
}
}
- addr, err = s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber)
- if err != nil {
- t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %v), want = (_, nil)", nicID, header.IPv6ProtocolNumber, err)
- }
- if want := (tcpip.AddressWithPrefix{Address: linkLocalAddr, PrefixLen: header.IPv6LinkLocalPrefix.PrefixLen}); addr != want {
- t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (%s, nil), want = (%s, nil)", nicID, header.IPv6ProtocolNumber, addr, want)
+ if err := checkGetMainNICAddress(s, nicID, header.IPv6ProtocolNumber, tcpip.AddressWithPrefix{Address: linkLocalAddr, PrefixLen: header.IPv6LinkLocalPrefix.PrefixLen}); err != nil {
+ t.Fatal(err)
}
}
@@ -2621,17 +2634,17 @@ func TestNewPEBOnPromotionToPermanent(t *testing.T) {
if err := s.AddAddressWithOptions(1, fakeNetNumber, "\x01", pi); err != nil {
t.Fatal("AddAddressWithOptions failed:", err)
}
- addr, err := s.GetMainNICAddress(1, fakeNetNumber)
- if err != nil {
- t.Fatal("s.GetMainNICAddress failed:", err)
+ addr, ok := s.GetMainNICAddress(1, fakeNetNumber)
+ if !ok {
+ t.Fatalf("GetMainNICAddress(1, %d) = (_, false), want = (_, true)", fakeNetNumber)
}
if pi == stack.NeverPrimaryEndpoint {
if want := (tcpip.AddressWithPrefix{}); addr != want {
- t.Fatalf("got GetMainNICAddress = %s, want = %s", addr, want)
+ t.Fatalf("got GetMainNICAddress(1, %d) = (%s, true), want = (%s, true)", fakeNetNumber, addr, want)
}
} else if addr.Address != "\x01" {
- t.Fatalf("got GetMainNICAddress = %s, want = 1", addr.Address)
+ t.Fatalf("got GetMainNICAddress(1, %d) = (%s, true), want = (1, true)", fakeNetNumber, addr.Address)
}
{
@@ -2710,18 +2723,17 @@ func TestNewPEBOnPromotionToPermanent(t *testing.T) {
if err := s.RemoveAddress(1, "\x03"); err != nil {
t.Fatalf("RemoveAddress failed: %v", err)
}
- addr, err = s.GetMainNICAddress(1, fakeNetNumber)
- if err != nil {
- t.Fatalf("s.GetMainNICAddress failed: %v", err)
+ addr, ok = s.GetMainNICAddress(1, fakeNetNumber)
+ if !ok {
+ t.Fatalf("got GetMainNICAddress(1, %d) = (_, false), want = (_, true)", fakeNetNumber)
}
if ps == stack.NeverPrimaryEndpoint {
if want := (tcpip.AddressWithPrefix{}); addr != want {
- t.Fatalf("got GetMainNICAddress = %s, want = %s", addr, want)
-
+ t.Fatalf("got GetMainNICAddress(1, %d) = (%s, true), want = (%s, true)", fakeNetNumber, addr, want)
}
} else {
if addr.Address != "\x01" {
- t.Fatalf("got GetMainNICAddress = %s, want = 1", addr.Address)
+ t.Fatalf("got GetMainNICAddress(1, %d) = (%s, true), want = (1, true)", fakeNetNumber, addr.Address)
}
}
})
@@ -3247,12 +3259,8 @@ func TestDoDADWhenNICEnabled(t *testing.T) {
}
// Address should be tentative so it should not be a main address.
- got, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber)
- if err != nil {
- t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %v), want = (_, nil)", nicID, header.IPv6ProtocolNumber, err)
- }
- if want := (tcpip.AddressWithPrefix{}); got != want {
- t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (%s, nil), want = (%s, nil)", nicID, header.IPv6ProtocolNumber, got, want)
+ if err := checkGetMainNICAddress(s, nicID, header.IPv6ProtocolNumber, tcpip.AddressWithPrefix{}); err != nil {
+ t.Fatal(err)
}
// Enabling the NIC should start DAD for the address.
@@ -3264,12 +3272,8 @@ func TestDoDADWhenNICEnabled(t *testing.T) {
}
// Address should not be considered bound to the NIC yet (DAD ongoing).
- got, err = s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber)
- if err != nil {
- t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %v), want = (_, nil)", nicID, header.IPv6ProtocolNumber, err)
- }
- if want := (tcpip.AddressWithPrefix{}); got != want {
- t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (%s, nil), want = (%s, nil)", nicID, header.IPv6ProtocolNumber, got, want)
+ if err := checkGetMainNICAddress(s, nicID, header.IPv6ProtocolNumber, tcpip.AddressWithPrefix{}); err != nil {
+ t.Fatal(err)
}
// Wait for DAD to resolve.
@@ -3284,12 +3288,8 @@ func TestDoDADWhenNICEnabled(t *testing.T) {
if addrs := s.AllAddresses()[nicID]; !containsV6Addr(addrs, addr.AddressWithPrefix) {
t.Fatalf("got s.AllAddresses()[%d] = %+v, want = %+v", nicID, addrs, addr)
}
- got, err = s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber)
- if err != nil {
- t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %v), want = (_, nil)", nicID, header.IPv6ProtocolNumber, err)
- }
- if got != addr.AddressWithPrefix {
- t.Fatalf("got stack.GetMainNICAddress(%d, %d) = %s, want = %s", nicID, header.IPv6ProtocolNumber, got, addr.AddressWithPrefix)
+ if err := checkGetMainNICAddress(s, nicID, header.IPv6ProtocolNumber, addr.AddressWithPrefix); err != nil {
+ t.Fatal(err)
}
// Enabling the NIC again should be a no-op.
@@ -3299,12 +3299,8 @@ func TestDoDADWhenNICEnabled(t *testing.T) {
if addrs := s.AllAddresses()[nicID]; !containsV6Addr(addrs, addr.AddressWithPrefix) {
t.Fatalf("got s.AllAddresses()[%d] = %+v, want = %+v", nicID, addrs, addr)
}
- got, err = s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber)
- if err != nil {
- t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %v), want = (_, nil)", nicID, header.IPv6ProtocolNumber, err)
- }
- if got != addr.AddressWithPrefix {
- t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (%s, nil), want = (%s, nil)", nicID, header.IPv6ProtocolNumber, got, addr.AddressWithPrefix)
+ if err := checkGetMainNICAddress(s, nicID, header.IPv6ProtocolNumber, addr.AddressWithPrefix); err != nil {
+ t.Fatal(err)
}
}
@@ -3313,14 +3309,14 @@ func TestStackReceiveBufferSizeOption(t *testing.T) {
testCases := []struct {
name string
rs stack.ReceiveBufferSizeOption
- err *tcpip.Error
+ err tcpip.Error
}{
// Invalid configurations.
- {"min_below_zero", stack.ReceiveBufferSizeOption{Min: -1, Default: sMin, Max: sMin}, tcpip.ErrInvalidOptionValue},
- {"min_zero", stack.ReceiveBufferSizeOption{Min: 0, Default: sMin, Max: sMin}, tcpip.ErrInvalidOptionValue},
- {"default_below_min", stack.ReceiveBufferSizeOption{Min: sMin, Default: sMin - 1, Max: sMin - 1}, tcpip.ErrInvalidOptionValue},
- {"default_above_max", stack.ReceiveBufferSizeOption{Min: sMin, Default: sMin + 1, Max: sMin}, tcpip.ErrInvalidOptionValue},
- {"max_below_min", stack.ReceiveBufferSizeOption{Min: sMin, Default: sMin + 1, Max: sMin - 1}, tcpip.ErrInvalidOptionValue},
+ {"min_below_zero", stack.ReceiveBufferSizeOption{Min: -1, Default: sMin, Max: sMin}, &tcpip.ErrInvalidOptionValue{}},
+ {"min_zero", stack.ReceiveBufferSizeOption{Min: 0, Default: sMin, Max: sMin}, &tcpip.ErrInvalidOptionValue{}},
+ {"default_below_min", stack.ReceiveBufferSizeOption{Min: sMin, Default: sMin - 1, Max: sMin - 1}, &tcpip.ErrInvalidOptionValue{}},
+ {"default_above_max", stack.ReceiveBufferSizeOption{Min: sMin, Default: sMin + 1, Max: sMin}, &tcpip.ErrInvalidOptionValue{}},
+ {"max_below_min", stack.ReceiveBufferSizeOption{Min: sMin, Default: sMin + 1, Max: sMin - 1}, &tcpip.ErrInvalidOptionValue{}},
// Valid Configurations
{"in_ascending_order", stack.ReceiveBufferSizeOption{Min: sMin, Default: sMin + 1, Max: sMin + 2}, nil},
@@ -3352,31 +3348,32 @@ func TestStackSendBufferSizeOption(t *testing.T) {
const sMin = stack.MinBufferSize
testCases := []struct {
name string
- ss stack.SendBufferSizeOption
- err *tcpip.Error
+ ss tcpip.SendBufferSizeOption
+ err tcpip.Error
}{
// Invalid configurations.
- {"min_below_zero", stack.SendBufferSizeOption{Min: -1, Default: sMin, Max: sMin}, tcpip.ErrInvalidOptionValue},
- {"min_zero", stack.SendBufferSizeOption{Min: 0, Default: sMin, Max: sMin}, tcpip.ErrInvalidOptionValue},
- {"default_below_min", stack.SendBufferSizeOption{Min: 0, Default: sMin - 1, Max: sMin - 1}, tcpip.ErrInvalidOptionValue},
- {"default_above_max", stack.SendBufferSizeOption{Min: 0, Default: sMin + 1, Max: sMin}, tcpip.ErrInvalidOptionValue},
- {"max_below_min", stack.SendBufferSizeOption{Min: sMin, Default: sMin + 1, Max: sMin - 1}, tcpip.ErrInvalidOptionValue},
+ {"min_below_zero", tcpip.SendBufferSizeOption{Min: -1, Default: sMin, Max: sMin}, &tcpip.ErrInvalidOptionValue{}},
+ {"min_zero", tcpip.SendBufferSizeOption{Min: 0, Default: sMin, Max: sMin}, &tcpip.ErrInvalidOptionValue{}},
+ {"default_below_min", tcpip.SendBufferSizeOption{Min: 0, Default: sMin - 1, Max: sMin - 1}, &tcpip.ErrInvalidOptionValue{}},
+ {"default_above_max", tcpip.SendBufferSizeOption{Min: 0, Default: sMin + 1, Max: sMin}, &tcpip.ErrInvalidOptionValue{}},
+ {"max_below_min", tcpip.SendBufferSizeOption{Min: sMin, Default: sMin + 1, Max: sMin - 1}, &tcpip.ErrInvalidOptionValue{}},
// Valid Configurations
- {"in_ascending_order", stack.SendBufferSizeOption{Min: sMin, Default: sMin + 1, Max: sMin + 2}, nil},
- {"all_equal", stack.SendBufferSizeOption{Min: sMin, Default: sMin, Max: sMin}, nil},
- {"min_default_equal", stack.SendBufferSizeOption{Min: sMin, Default: sMin, Max: sMin + 1}, nil},
- {"default_max_equal", stack.SendBufferSizeOption{Min: sMin, Default: sMin + 1, Max: sMin + 1}, nil},
+ {"in_ascending_order", tcpip.SendBufferSizeOption{Min: sMin, Default: sMin + 1, Max: sMin + 2}, nil},
+ {"all_equal", tcpip.SendBufferSizeOption{Min: sMin, Default: sMin, Max: sMin}, nil},
+ {"min_default_equal", tcpip.SendBufferSizeOption{Min: sMin, Default: sMin, Max: sMin + 1}, nil},
+ {"default_max_equal", tcpip.SendBufferSizeOption{Min: sMin, Default: sMin + 1, Max: sMin + 1}, nil},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
s := stack.New(stack.Options{})
defer s.Close()
- if err := s.SetOption(tc.ss); err != tc.err {
- t.Fatalf("s.SetOption(%+v) = %v, want: %v", tc.ss, err, tc.err)
+ err := s.SetOption(tc.ss)
+ if diff := cmp.Diff(tc.err, err); diff != "" {
+ t.Fatalf("unexpected error from s.SetOption(%+v), (-want, +got):\n%s", tc.ss, diff)
}
- var ss stack.SendBufferSizeOption
if tc.err == nil {
+ var ss tcpip.SendBufferSizeOption
if err := s.Option(&ss); err != nil {
t.Fatalf("s.Option(%+v) = %v, want: nil", ss, err)
}
@@ -3778,20 +3775,16 @@ func TestGetMainNICAddressWhenNICDisabled(t *testing.T) {
}
// Check that we get the right initial address and prefix length.
- if gotAddr, err := s.GetMainNICAddress(nicID, fakeNetNumber); err != nil {
- t.Fatalf("GetMainNICAddress(%d, %d): %s", nicID, fakeNetNumber, err)
- } else if gotAddr != protocolAddress.AddressWithPrefix {
- t.Fatalf("got GetMainNICAddress(%d, %d) = %s, want = %s", nicID, fakeNetNumber, gotAddr, protocolAddress.AddressWithPrefix)
+ if err := checkGetMainNICAddress(s, nicID, fakeNetNumber, protocolAddress.AddressWithPrefix); err != nil {
+ t.Fatal(err)
}
// Should still get the address when the NIC is diabled.
if err := s.DisableNIC(nicID); err != nil {
t.Fatalf("DisableNIC(%d): %s", nicID, err)
}
- if gotAddr, err := s.GetMainNICAddress(nicID, fakeNetNumber); err != nil {
- t.Fatalf("GetMainNICAddress(%d, %d): %s", nicID, fakeNetNumber, err)
- } else if gotAddr != protocolAddress.AddressWithPrefix {
- t.Fatalf("got GetMainNICAddress(%d, %d) = %s, want = %s", nicID, fakeNetNumber, gotAddr, protocolAddress.AddressWithPrefix)
+ if err := checkGetMainNICAddress(s, nicID, fakeNetNumber, protocolAddress.AddressWithPrefix); err != nil {
+ t.Fatal(err)
}
}
@@ -3939,7 +3932,7 @@ func TestFindRouteWithForwarding(t *testing.T) {
addrNIC tcpip.NICID
localAddr tcpip.Address
- findRouteErr *tcpip.Error
+ findRouteErr tcpip.Error
dependentOnForwarding bool
}{
{
@@ -3948,7 +3941,7 @@ func TestFindRouteWithForwarding(t *testing.T) {
forwardingEnabled: false,
addrNIC: nicID1,
localAddr: fakeNetCfg.nic2Addr,
- findRouteErr: tcpip.ErrNoRoute,
+ findRouteErr: &tcpip.ErrNoRoute{},
dependentOnForwarding: false,
},
{
@@ -3957,7 +3950,7 @@ func TestFindRouteWithForwarding(t *testing.T) {
forwardingEnabled: true,
addrNIC: nicID1,
localAddr: fakeNetCfg.nic2Addr,
- findRouteErr: tcpip.ErrNoRoute,
+ findRouteErr: &tcpip.ErrNoRoute{},
dependentOnForwarding: false,
},
{
@@ -3966,7 +3959,7 @@ func TestFindRouteWithForwarding(t *testing.T) {
forwardingEnabled: false,
addrNIC: nicID1,
localAddr: fakeNetCfg.nic1Addr,
- findRouteErr: tcpip.ErrNoRoute,
+ findRouteErr: &tcpip.ErrNoRoute{},
dependentOnForwarding: false,
},
{
@@ -4002,7 +3995,7 @@ func TestFindRouteWithForwarding(t *testing.T) {
forwardingEnabled: false,
addrNIC: nicID2,
localAddr: fakeNetCfg.nic1Addr,
- findRouteErr: tcpip.ErrNoRoute,
+ findRouteErr: &tcpip.ErrNoRoute{},
dependentOnForwarding: false,
},
{
@@ -4011,7 +4004,7 @@ func TestFindRouteWithForwarding(t *testing.T) {
forwardingEnabled: true,
addrNIC: nicID2,
localAddr: fakeNetCfg.nic1Addr,
- findRouteErr: tcpip.ErrNoRoute,
+ findRouteErr: &tcpip.ErrNoRoute{},
dependentOnForwarding: false,
},
{
@@ -4035,7 +4028,7 @@ func TestFindRouteWithForwarding(t *testing.T) {
netCfg: fakeNetCfg,
forwardingEnabled: false,
localAddr: fakeNetCfg.nic1Addr,
- findRouteErr: tcpip.ErrNoRoute,
+ findRouteErr: &tcpip.ErrNoRoute{},
dependentOnForwarding: false,
},
{
@@ -4051,7 +4044,7 @@ func TestFindRouteWithForwarding(t *testing.T) {
netCfg: ipv6LinkLocalNIC1WithGlobalRemote,
forwardingEnabled: false,
addrNIC: nicID1,
- findRouteErr: tcpip.ErrNoRoute,
+ findRouteErr: &tcpip.ErrNoRoute{},
dependentOnForwarding: false,
},
{
@@ -4059,7 +4052,7 @@ func TestFindRouteWithForwarding(t *testing.T) {
netCfg: ipv6LinkLocalNIC1WithGlobalRemote,
forwardingEnabled: true,
addrNIC: nicID1,
- findRouteErr: tcpip.ErrNoRoute,
+ findRouteErr: &tcpip.ErrNoRoute{},
dependentOnForwarding: false,
},
{
@@ -4067,7 +4060,7 @@ func TestFindRouteWithForwarding(t *testing.T) {
netCfg: ipv6LinkLocalNIC1WithGlobalRemote,
forwardingEnabled: false,
localAddr: ipv6LinkLocalNIC1WithGlobalRemote.nic1Addr,
- findRouteErr: tcpip.ErrNoRoute,
+ findRouteErr: &tcpip.ErrNoRoute{},
dependentOnForwarding: false,
},
{
@@ -4075,7 +4068,7 @@ func TestFindRouteWithForwarding(t *testing.T) {
netCfg: ipv6LinkLocalNIC1WithGlobalRemote,
forwardingEnabled: true,
localAddr: ipv6LinkLocalNIC1WithGlobalRemote.nic1Addr,
- findRouteErr: tcpip.ErrNoRoute,
+ findRouteErr: &tcpip.ErrNoRoute{},
dependentOnForwarding: false,
},
{
@@ -4107,7 +4100,7 @@ func TestFindRouteWithForwarding(t *testing.T) {
netCfg: ipv6GlobalNIC1WithLinkLocalRemote,
forwardingEnabled: false,
localAddr: ipv6GlobalNIC1WithLinkLocalRemote.nic1Addr,
- findRouteErr: tcpip.ErrNetworkUnreachable,
+ findRouteErr: &tcpip.ErrNetworkUnreachable{},
dependentOnForwarding: false,
},
{
@@ -4115,7 +4108,7 @@ func TestFindRouteWithForwarding(t *testing.T) {
netCfg: ipv6GlobalNIC1WithLinkLocalRemote,
forwardingEnabled: true,
localAddr: ipv6GlobalNIC1WithLinkLocalRemote.nic1Addr,
- findRouteErr: tcpip.ErrNetworkUnreachable,
+ findRouteErr: &tcpip.ErrNetworkUnreachable{},
dependentOnForwarding: false,
},
{
@@ -4123,7 +4116,7 @@ func TestFindRouteWithForwarding(t *testing.T) {
netCfg: ipv6GlobalNIC1WithLinkLocalMulticastRemote,
forwardingEnabled: false,
localAddr: ipv6GlobalNIC1WithLinkLocalMulticastRemote.nic1Addr,
- findRouteErr: tcpip.ErrNetworkUnreachable,
+ findRouteErr: &tcpip.ErrNetworkUnreachable{},
dependentOnForwarding: false,
},
{
@@ -4131,7 +4124,7 @@ func TestFindRouteWithForwarding(t *testing.T) {
netCfg: ipv6GlobalNIC1WithLinkLocalMulticastRemote,
forwardingEnabled: true,
localAddr: ipv6GlobalNIC1WithLinkLocalMulticastRemote.nic1Addr,
- findRouteErr: tcpip.ErrNetworkUnreachable,
+ findRouteErr: &tcpip.ErrNetworkUnreachable{},
dependentOnForwarding: false,
},
{
@@ -4186,8 +4179,8 @@ func TestFindRouteWithForwarding(t *testing.T) {
if r != nil {
defer r.Release()
}
- if err != test.findRouteErr {
- t.Fatalf("FindRoute(%d, %s, %s, %d, false) = %s, want = %s", test.addrNIC, test.localAddr, test.netCfg.remoteAddr, test.netCfg.proto, err, test.findRouteErr)
+ if diff := cmp.Diff(test.findRouteErr, err); diff != "" {
+ t.Fatalf("unexpected error from FindRoute(%d, %s, %s, %d, false), (-want, +got):\n%s", test.addrNIC, test.localAddr, test.netCfg.remoteAddr, test.netCfg.proto, diff)
}
if test.findRouteErr != nil {
@@ -4234,8 +4227,11 @@ func TestFindRouteWithForwarding(t *testing.T) {
if err := s.SetForwarding(test.netCfg.proto, false); err != nil {
t.Fatalf("SetForwarding(%d, false): %s", test.netCfg.proto, err)
}
- if err := send(r, data); err != tcpip.ErrInvalidEndpointState {
- t.Fatalf("got send(_, _) = %s, want = %s", err, tcpip.ErrInvalidEndpointState)
+ {
+ err := send(r, data)
+ if _, ok := err.(*tcpip.ErrInvalidEndpointState); !ok {
+ t.Fatalf("got send(_, _) = %s, want = %s", err, &tcpip.ErrInvalidEndpointState{})
+ }
}
if n := ep1.Drain(); n != 0 {
t.Errorf("got %d unexpected packets from ep1", n)
@@ -4297,8 +4293,9 @@ func TestWritePacketToRemote(t *testing.T) {
}
t.Run("InvalidNICID", func(t *testing.T) {
- if got, want := s.WritePacketToRemote(234, linkAddr2, header.IPv4ProtocolNumber, buffer.View([]byte{1}).ToVectorisedView()), tcpip.ErrUnknownDevice; got != want {
- t.Fatalf("s.WritePacketToRemote(_, _, _, _) = %s, want = %s", got, want)
+ err := s.WritePacketToRemote(234, linkAddr2, header.IPv4ProtocolNumber, buffer.View([]byte{1}).ToVectorisedView())
+ if _, ok := err.(*tcpip.ErrUnknownDevice); !ok {
+ t.Fatalf("s.WritePacketToRemote(_, _, _, _) = %s, want = %s", err, &tcpip.ErrUnknownDevice{})
}
pkt, ok := e.Read()
if got, want := ok, false; got != want {
@@ -4372,10 +4369,64 @@ func TestGetLinkAddressErrors(t *testing.T) {
t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
}
- if addr, _, err := s.GetLinkAddress(unknownNICID, "", "", ipv4.ProtocolNumber, nil); err != tcpip.ErrUnknownNICID {
- t.Errorf("got s.GetLinkAddress(%d, '', '', %d, nil) = (%s, _, %s), want = (_, _, %s)", unknownNICID, ipv4.ProtocolNumber, addr, err, tcpip.ErrUnknownNICID)
+ {
+ err := s.GetLinkAddress(unknownNICID, "", "", ipv4.ProtocolNumber, nil)
+ if _, ok := err.(*tcpip.ErrUnknownNICID); !ok {
+ t.Errorf("got s.GetLinkAddress(%d, '', '', %d, nil) = %s, want = %s", unknownNICID, ipv4.ProtocolNumber, err, &tcpip.ErrUnknownNICID{})
+ }
+ }
+ {
+ err := s.GetLinkAddress(nicID, "", "", ipv4.ProtocolNumber, nil)
+ if _, ok := err.(*tcpip.ErrNotSupported); !ok {
+ t.Errorf("got s.GetLinkAddress(%d, '', '', %d, nil) = %s, want = %s", unknownNICID, ipv4.ProtocolNumber, err, &tcpip.ErrNotSupported{})
+ }
+ }
+}
+
+func TestStaticGetLinkAddress(t *testing.T) {
+ const (
+ nicID = 1
+ )
+
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{arp.NewProtocol, ipv4.NewProtocol, ipv6.NewProtocol},
+ })
+ if err := s.CreateNIC(nicID, channel.New(0, 0, "")); err != nil {
+ t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
+ }
+
+ tests := []struct {
+ name string
+ proto tcpip.NetworkProtocolNumber
+ addr tcpip.Address
+ expectedLinkAddr tcpip.LinkAddress
+ }{
+ {
+ name: "IPv4",
+ proto: ipv4.ProtocolNumber,
+ addr: header.IPv4Broadcast,
+ expectedLinkAddr: header.EthernetBroadcastAddress,
+ },
+ {
+ name: "IPv6",
+ proto: ipv6.ProtocolNumber,
+ addr: header.IPv6AllNodesMulticastAddress,
+ expectedLinkAddr: header.EthernetAddressFromMulticastIPv6Address(header.IPv6AllNodesMulticastAddress),
+ },
}
- if addr, _, err := s.GetLinkAddress(nicID, "", "", ipv4.ProtocolNumber, nil); err != tcpip.ErrNotSupported {
- t.Errorf("got s.GetLinkAddress(%d, '', '', %d, nil) = (%s, _, %s), want = (_, _, %s)", unknownNICID, ipv4.ProtocolNumber, addr, err, tcpip.ErrNotSupported)
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ ch := make(chan stack.LinkResolutionResult, 1)
+ if err := s.GetLinkAddress(nicID, test.addr, "", test.proto, func(r stack.LinkResolutionResult) {
+ ch <- r
+ }); err != nil {
+ t.Fatalf("s.GetLinkAddress(%d, %s, '', %d, _): %s", nicID, test.addr, test.proto, err)
+ }
+
+ if diff := cmp.Diff(stack.LinkResolutionResult{LinkAddress: test.expectedLinkAddr, Success: true}, <-ch); diff != "" {
+ t.Fatalf("link resolution result mismatch (-want +got):\n%s", diff)
+ }
+ })
}
}
diff --git a/pkg/tcpip/stack/transport_demuxer.go b/pkg/tcpip/stack/transport_demuxer.go
index 07b2818d2..26eceb804 100644
--- a/pkg/tcpip/stack/transport_demuxer.go
+++ b/pkg/tcpip/stack/transport_demuxer.go
@@ -205,7 +205,7 @@ func (epsByNIC *endpointsByNIC) handleControlPacket(n *NIC, id TransportEndpoint
// registerEndpoint returns true if it succeeds. It fails and returns
// false if ep already has an element with the same key.
-func (epsByNIC *endpointsByNIC) registerEndpoint(d *transportDemuxer, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, t TransportEndpoint, flags ports.Flags, bindToDevice tcpip.NICID) *tcpip.Error {
+func (epsByNIC *endpointsByNIC) registerEndpoint(d *transportDemuxer, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, t TransportEndpoint, flags ports.Flags, bindToDevice tcpip.NICID) tcpip.Error {
epsByNIC.mu.Lock()
defer epsByNIC.mu.Unlock()
@@ -222,7 +222,7 @@ func (epsByNIC *endpointsByNIC) registerEndpoint(d *transportDemuxer, netProto t
return multiPortEp.singleRegisterEndpoint(t, flags)
}
-func (epsByNIC *endpointsByNIC) checkEndpoint(d *transportDemuxer, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, flags ports.Flags, bindToDevice tcpip.NICID) *tcpip.Error {
+func (epsByNIC *endpointsByNIC) checkEndpoint(d *transportDemuxer, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, flags ports.Flags, bindToDevice tcpip.NICID) tcpip.Error {
epsByNIC.mu.RLock()
defer epsByNIC.mu.RUnlock()
@@ -294,7 +294,7 @@ func newTransportDemuxer(stack *Stack) *transportDemuxer {
// registerEndpoint registers the given endpoint with the dispatcher such that
// packets that match the endpoint ID are delivered to it.
-func (d *transportDemuxer) registerEndpoint(netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, flags ports.Flags, bindToDevice tcpip.NICID) *tcpip.Error {
+func (d *transportDemuxer) registerEndpoint(netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, flags ports.Flags, bindToDevice tcpip.NICID) tcpip.Error {
for i, n := range netProtos {
if err := d.singleRegisterEndpoint(n, protocol, id, ep, flags, bindToDevice); err != nil {
d.unregisterEndpoint(netProtos[:i], protocol, id, ep, flags, bindToDevice)
@@ -306,7 +306,7 @@ func (d *transportDemuxer) registerEndpoint(netProtos []tcpip.NetworkProtocolNum
}
// checkEndpoint checks if an endpoint can be registered with the dispatcher.
-func (d *transportDemuxer) checkEndpoint(netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, flags ports.Flags, bindToDevice tcpip.NICID) *tcpip.Error {
+func (d *transportDemuxer) checkEndpoint(netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, flags ports.Flags, bindToDevice tcpip.NICID) tcpip.Error {
for _, n := range netProtos {
if err := d.singleCheckEndpoint(n, protocol, id, flags, bindToDevice); err != nil {
return err
@@ -403,7 +403,7 @@ func (ep *multiPortEndpoint) handlePacketAll(id TransportEndpointID, pkt *Packet
// singleRegisterEndpoint tries to add an endpoint to the multiPortEndpoint
// list. The list might be empty already.
-func (ep *multiPortEndpoint) singleRegisterEndpoint(t TransportEndpoint, flags ports.Flags) *tcpip.Error {
+func (ep *multiPortEndpoint) singleRegisterEndpoint(t TransportEndpoint, flags ports.Flags) tcpip.Error {
ep.mu.Lock()
defer ep.mu.Unlock()
@@ -412,7 +412,7 @@ func (ep *multiPortEndpoint) singleRegisterEndpoint(t TransportEndpoint, flags p
if len(ep.endpoints) != 0 {
// If it was previously bound, we need to check if we can bind again.
if ep.flags.TotalRefs() > 0 && bits&ep.flags.IntersectionRefs() == 0 {
- return tcpip.ErrPortInUse
+ return &tcpip.ErrPortInUse{}
}
}
@@ -422,7 +422,7 @@ func (ep *multiPortEndpoint) singleRegisterEndpoint(t TransportEndpoint, flags p
return nil
}
-func (ep *multiPortEndpoint) singleCheckEndpoint(flags ports.Flags) *tcpip.Error {
+func (ep *multiPortEndpoint) singleCheckEndpoint(flags ports.Flags) tcpip.Error {
ep.mu.RLock()
defer ep.mu.RUnlock()
@@ -431,7 +431,7 @@ func (ep *multiPortEndpoint) singleCheckEndpoint(flags ports.Flags) *tcpip.Error
if len(ep.endpoints) != 0 {
// If it was previously bound, we need to check if we can bind again.
if ep.flags.TotalRefs() > 0 && bits&ep.flags.IntersectionRefs() == 0 {
- return tcpip.ErrPortInUse
+ return &tcpip.ErrPortInUse{}
}
}
@@ -456,7 +456,7 @@ func (ep *multiPortEndpoint) unregisterEndpoint(t TransportEndpoint, flags ports
return len(ep.endpoints) == 0
}
-func (d *transportDemuxer) singleRegisterEndpoint(netProto tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, flags ports.Flags, bindToDevice tcpip.NICID) *tcpip.Error {
+func (d *transportDemuxer) singleRegisterEndpoint(netProto tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, flags ports.Flags, bindToDevice tcpip.NICID) tcpip.Error {
if id.RemotePort != 0 {
// SO_REUSEPORT only applies to bound/listening endpoints.
flags.LoadBalanced = false
@@ -464,7 +464,7 @@ func (d *transportDemuxer) singleRegisterEndpoint(netProto tcpip.NetworkProtocol
eps, ok := d.protocol[protocolIDs{netProto, protocol}]
if !ok {
- return tcpip.ErrUnknownProtocol
+ return &tcpip.ErrUnknownProtocol{}
}
eps.mu.Lock()
@@ -482,7 +482,7 @@ func (d *transportDemuxer) singleRegisterEndpoint(netProto tcpip.NetworkProtocol
return epsByNIC.registerEndpoint(d, netProto, protocol, ep, flags, bindToDevice)
}
-func (d *transportDemuxer) singleCheckEndpoint(netProto tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, flags ports.Flags, bindToDevice tcpip.NICID) *tcpip.Error {
+func (d *transportDemuxer) singleCheckEndpoint(netProto tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, flags ports.Flags, bindToDevice tcpip.NICID) tcpip.Error {
if id.RemotePort != 0 {
// SO_REUSEPORT only applies to bound/listening endpoints.
flags.LoadBalanced = false
@@ -490,7 +490,7 @@ func (d *transportDemuxer) singleCheckEndpoint(netProto tcpip.NetworkProtocolNum
eps, ok := d.protocol[protocolIDs{netProto, protocol}]
if !ok {
- return tcpip.ErrUnknownProtocol
+ return &tcpip.ErrUnknownProtocol{}
}
eps.mu.RLock()
@@ -649,10 +649,10 @@ func (d *transportDemuxer) findTransportEndpoint(netProto tcpip.NetworkProtocolN
// that packets of the appropriate protocol are delivered to it. A single
// packet can be sent to one or more raw endpoints along with a non-raw
// endpoint.
-func (d *transportDemuxer) registerRawEndpoint(netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, ep RawTransportEndpoint) *tcpip.Error {
+func (d *transportDemuxer) registerRawEndpoint(netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, ep RawTransportEndpoint) tcpip.Error {
eps, ok := d.protocol[protocolIDs{netProto, transProto}]
if !ok {
- return tcpip.ErrNotSupported
+ return &tcpip.ErrNotSupported{}
}
eps.mu.Lock()
diff --git a/pkg/tcpip/stack/transport_demuxer_test.go b/pkg/tcpip/stack/transport_demuxer_test.go
index 57e1f8354..10cbbe589 100644
--- a/pkg/tcpip/stack/transport_demuxer_test.go
+++ b/pkg/tcpip/stack/transport_demuxer_test.go
@@ -175,9 +175,9 @@ func TestTransportDemuxerRegister(t *testing.T) {
for _, test := range []struct {
name string
proto tcpip.NetworkProtocolNumber
- want *tcpip.Error
+ want tcpip.Error
}{
- {"failure", ipv6.ProtocolNumber, tcpip.ErrUnknownProtocol},
+ {"failure", ipv6.ProtocolNumber, &tcpip.ErrUnknownProtocol{}},
{"success", ipv4.ProtocolNumber, nil},
} {
t.Run(test.name, func(t *testing.T) {
@@ -194,7 +194,7 @@ func TestTransportDemuxerRegister(t *testing.T) {
if !ok {
t.Fatalf("%T does not implement stack.TransportEndpoint", ep)
}
- if got, want := s.RegisterTransportEndpoint(0, []tcpip.NetworkProtocolNumber{test.proto}, udp.ProtocolNumber, stack.TransportEndpointID{}, tEP, ports.Flags{}, 0), test.want; got != want {
+ if got, want := s.RegisterTransportEndpoint([]tcpip.NetworkProtocolNumber{test.proto}, udp.ProtocolNumber, stack.TransportEndpointID{}, tEP, ports.Flags{}, 0), test.want; got != want {
t.Fatalf("s.RegisterTransportEndpoint(...) = %s, want %s", got, want)
}
})
@@ -294,7 +294,7 @@ func TestBindToDeviceDistribution(t *testing.T) {
defer wq.EventUnregister(&we)
defer close(ch)
- var err *tcpip.Error
+ var err tcpip.Error
ep, err := c.s.NewEndpoint(udp.ProtocolNumber, netProtoNum, &wq)
if err != nil {
t.Fatalf("NewEndpoint failed: %s", err)
diff --git a/pkg/tcpip/stack/transport_test.go b/pkg/tcpip/stack/transport_test.go
index 9d39533a1..cf5de747b 100644
--- a/pkg/tcpip/stack/transport_test.go
+++ b/pkg/tcpip/stack/transport_test.go
@@ -15,6 +15,7 @@
package stack_test
import (
+ "bytes"
"io"
"testing"
@@ -67,9 +68,9 @@ func (f *fakeTransportEndpoint) SocketOptions() *tcpip.SocketOptions {
return &f.ops
}
-func newFakeTransportEndpoint(proto *fakeTransportProtocol, netProto tcpip.NetworkProtocolNumber, uniqueID uint64) tcpip.Endpoint {
- ep := &fakeTransportEndpoint{TransportEndpointInfo: stack.TransportEndpointInfo{NetProto: netProto}, proto: proto, uniqueID: uniqueID}
- ep.ops.InitHandler(ep)
+func newFakeTransportEndpoint(proto *fakeTransportProtocol, netProto tcpip.NetworkProtocolNumber, s *stack.Stack) tcpip.Endpoint {
+ ep := &fakeTransportEndpoint{TransportEndpointInfo: stack.TransportEndpointInfo{NetProto: netProto}, proto: proto, uniqueID: s.UniqueID()}
+ ep.ops.InitHandler(ep, s, tcpip.GetStackSendBufferLimits)
return ep
}
@@ -86,19 +87,20 @@ func (*fakeTransportEndpoint) Readiness(mask waiter.EventMask) waiter.EventMask
return mask
}
-func (*fakeTransportEndpoint) Read(io.Writer, tcpip.ReadOptions) (tcpip.ReadResult, *tcpip.Error) {
+func (*fakeTransportEndpoint) Read(io.Writer, tcpip.ReadOptions) (tcpip.ReadResult, tcpip.Error) {
return tcpip.ReadResult{}, nil
}
-func (f *fakeTransportEndpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, *tcpip.Error) {
+func (f *fakeTransportEndpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcpip.Error) {
if len(f.route.RemoteAddress) == 0 {
- return 0, tcpip.ErrNoRoute
+ return 0, &tcpip.ErrNoRoute{}
}
- v, err := p.FullPayload()
- if err != nil {
- return 0, err
+ v := make([]byte, p.Len())
+ if _, err := io.ReadFull(p, v); err != nil {
+ return 0, &tcpip.ErrBadBuffer{}
}
+
pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
ReserveHeaderBytes: int(f.route.MaxHeaderLength()) + fakeTransHeaderLen,
Data: buffer.View(v).ToVectorisedView(),
@@ -112,42 +114,42 @@ func (f *fakeTransportEndpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions
}
// SetSockOpt sets a socket option. Currently not supported.
-func (*fakeTransportEndpoint) SetSockOpt(tcpip.SettableSocketOption) *tcpip.Error {
- return tcpip.ErrInvalidEndpointState
+func (*fakeTransportEndpoint) SetSockOpt(tcpip.SettableSocketOption) tcpip.Error {
+ return &tcpip.ErrInvalidEndpointState{}
}
// SetSockOptInt sets a socket option. Currently not supported.
-func (*fakeTransportEndpoint) SetSockOptInt(tcpip.SockOptInt, int) *tcpip.Error {
- return tcpip.ErrInvalidEndpointState
+func (*fakeTransportEndpoint) SetSockOptInt(tcpip.SockOptInt, int) tcpip.Error {
+ return &tcpip.ErrInvalidEndpointState{}
}
// GetSockOptInt implements tcpip.Endpoint.GetSockOptInt.
-func (*fakeTransportEndpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) {
- return -1, tcpip.ErrUnknownProtocolOption
+func (*fakeTransportEndpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, tcpip.Error) {
+ return -1, &tcpip.ErrUnknownProtocolOption{}
}
// GetSockOpt implements tcpip.Endpoint.GetSockOpt.
-func (*fakeTransportEndpoint) GetSockOpt(tcpip.GettableSocketOption) *tcpip.Error {
- return tcpip.ErrInvalidEndpointState
+func (*fakeTransportEndpoint) GetSockOpt(tcpip.GettableSocketOption) tcpip.Error {
+ return &tcpip.ErrInvalidEndpointState{}
}
// Disconnect implements tcpip.Endpoint.Disconnect.
-func (*fakeTransportEndpoint) Disconnect() *tcpip.Error {
- return tcpip.ErrNotSupported
+func (*fakeTransportEndpoint) Disconnect() tcpip.Error {
+ return &tcpip.ErrNotSupported{}
}
-func (f *fakeTransportEndpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
+func (f *fakeTransportEndpoint) Connect(addr tcpip.FullAddress) tcpip.Error {
f.peerAddr = addr.Addr
// Find the route.
r, err := f.proto.stack.FindRoute(addr.NIC, "", addr.Addr, fakeNetNumber, false /* multicastLoop */)
if err != nil {
- return tcpip.ErrNoRoute
+ return &tcpip.ErrNoRoute{}
}
// Try to register so that we can start receiving packets.
f.ID.RemoteAddress = addr.Addr
- err = f.proto.stack.RegisterTransportEndpoint(0, []tcpip.NetworkProtocolNumber{fakeNetNumber}, fakeTransNumber, f.ID, f, ports.Flags{}, 0 /* bindToDevice */)
+ err = f.proto.stack.RegisterTransportEndpoint([]tcpip.NetworkProtocolNumber{fakeNetNumber}, fakeTransNumber, f.ID, f, ports.Flags{}, 0 /* bindToDevice */)
if err != nil {
r.Release()
return err
@@ -162,22 +164,22 @@ func (f *fakeTransportEndpoint) UniqueID() uint64 {
return f.uniqueID
}
-func (*fakeTransportEndpoint) ConnectEndpoint(e tcpip.Endpoint) *tcpip.Error {
+func (*fakeTransportEndpoint) ConnectEndpoint(e tcpip.Endpoint) tcpip.Error {
return nil
}
-func (*fakeTransportEndpoint) Shutdown(tcpip.ShutdownFlags) *tcpip.Error {
+func (*fakeTransportEndpoint) Shutdown(tcpip.ShutdownFlags) tcpip.Error {
return nil
}
func (*fakeTransportEndpoint) Reset() {
}
-func (*fakeTransportEndpoint) Listen(int) *tcpip.Error {
+func (*fakeTransportEndpoint) Listen(int) tcpip.Error {
return nil
}
-func (f *fakeTransportEndpoint) Accept(*tcpip.FullAddress) (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) {
+func (f *fakeTransportEndpoint) Accept(*tcpip.FullAddress) (tcpip.Endpoint, *waiter.Queue, tcpip.Error) {
if len(f.acceptQueue) == 0 {
return nil, nil, nil
}
@@ -186,9 +188,8 @@ func (f *fakeTransportEndpoint) Accept(*tcpip.FullAddress) (tcpip.Endpoint, *wai
return a, nil, nil
}
-func (f *fakeTransportEndpoint) Bind(a tcpip.FullAddress) *tcpip.Error {
+func (f *fakeTransportEndpoint) Bind(a tcpip.FullAddress) tcpip.Error {
if err := f.proto.stack.RegisterTransportEndpoint(
- a.NIC,
[]tcpip.NetworkProtocolNumber{fakeNetNumber},
fakeTransNumber,
stack.TransportEndpointID{LocalAddress: a.Addr},
@@ -202,11 +203,11 @@ func (f *fakeTransportEndpoint) Bind(a tcpip.FullAddress) *tcpip.Error {
return nil
}
-func (*fakeTransportEndpoint) GetLocalAddress() (tcpip.FullAddress, *tcpip.Error) {
+func (*fakeTransportEndpoint) GetLocalAddress() (tcpip.FullAddress, tcpip.Error) {
return tcpip.FullAddress{}, nil
}
-func (*fakeTransportEndpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Error) {
+func (*fakeTransportEndpoint) GetRemoteAddress() (tcpip.FullAddress, tcpip.Error) {
return tcpip.FullAddress{}, nil
}
@@ -232,7 +233,7 @@ func (f *fakeTransportEndpoint) HandlePacket(id stack.TransportEndpointID, pkt *
peerAddr: route.RemoteAddress,
route: route,
}
- ep.ops.InitHandler(ep)
+ ep.ops.InitHandler(ep, f.proto.stack, tcpip.GetStackSendBufferLimits)
f.acceptQueue = append(f.acceptQueue, ep)
}
@@ -251,7 +252,7 @@ func (*fakeTransportEndpoint) Resume(*stack.Stack) {}
func (*fakeTransportEndpoint) Wait() {}
-func (*fakeTransportEndpoint) LastError() *tcpip.Error {
+func (*fakeTransportEndpoint) LastError() tcpip.Error {
return nil
}
@@ -279,19 +280,19 @@ func (*fakeTransportProtocol) Number() tcpip.TransportProtocolNumber {
return fakeTransNumber
}
-func (f *fakeTransportProtocol) NewEndpoint(netProto tcpip.NetworkProtocolNumber, _ *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
- return newFakeTransportEndpoint(f, netProto, f.stack.UniqueID()), nil
+func (f *fakeTransportProtocol) NewEndpoint(netProto tcpip.NetworkProtocolNumber, _ *waiter.Queue) (tcpip.Endpoint, tcpip.Error) {
+ return newFakeTransportEndpoint(f, netProto, f.stack), nil
}
-func (*fakeTransportProtocol) NewRawEndpoint(tcpip.NetworkProtocolNumber, *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
- return nil, tcpip.ErrUnknownProtocol
+func (*fakeTransportProtocol) NewRawEndpoint(tcpip.NetworkProtocolNumber, *waiter.Queue) (tcpip.Endpoint, tcpip.Error) {
+ return nil, &tcpip.ErrUnknownProtocol{}
}
func (*fakeTransportProtocol) MinimumPacketSize() int {
return fakeTransHeaderLen
}
-func (*fakeTransportProtocol) ParsePorts(buffer.View) (src, dst uint16, err *tcpip.Error) {
+func (*fakeTransportProtocol) ParsePorts(buffer.View) (src, dst uint16, err tcpip.Error) {
return 0, 0, nil
}
@@ -299,23 +300,23 @@ func (*fakeTransportProtocol) HandleUnknownDestinationPacket(stack.TransportEndp
return stack.UnknownDestinationPacketHandled
}
-func (f *fakeTransportProtocol) SetOption(option tcpip.SettableTransportProtocolOption) *tcpip.Error {
+func (f *fakeTransportProtocol) SetOption(option tcpip.SettableTransportProtocolOption) tcpip.Error {
switch v := option.(type) {
case *tcpip.TCPModerateReceiveBufferOption:
f.opts.good = bool(*v)
return nil
default:
- return tcpip.ErrUnknownProtocolOption
+ return &tcpip.ErrUnknownProtocolOption{}
}
}
-func (f *fakeTransportProtocol) Option(option tcpip.GettableTransportProtocolOption) *tcpip.Error {
+func (f *fakeTransportProtocol) Option(option tcpip.GettableTransportProtocolOption) tcpip.Error {
switch v := option.(type) {
case *tcpip.TCPModerateReceiveBufferOption:
*v = tcpip.TCPModerateReceiveBufferOption(f.opts.good)
return nil
default:
- return tcpip.ErrUnknownProtocolOption
+ return &tcpip.ErrUnknownProtocolOption{}
}
}
@@ -520,8 +521,10 @@ func TestTransportSend(t *testing.T) {
}
// Create buffer that will hold the payload.
- view := buffer.NewView(30)
- if _, err := ep.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
+ b := make([]byte, 30)
+ var r bytes.Reader
+ r.Reset(b)
+ if _, err := ep.Write(&r, tcpip.WriteOptions{}); err != nil {
t.Fatalf("write failed: %v", err)
}
diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go
index 56aac093c..c500a0d1c 100644
--- a/pkg/tcpip/tcpip.go
+++ b/pkg/tcpip/tcpip.go
@@ -29,6 +29,7 @@
package tcpip
import (
+ "bytes"
"errors"
"fmt"
"io"
@@ -46,141 +47,6 @@ import (
// Using header.IPv4AddressSize would cause an import cycle.
const ipv4AddressSize = 4
-// Error represents an error in the netstack error space. Using a special type
-// ensures that errors outside of this space are not accidentally introduced.
-//
-// All errors must have unique msg strings.
-//
-// +stateify savable
-type Error struct {
- msg string
-
- ignoreStats bool
-}
-
-// String implements fmt.Stringer.String.
-func (e *Error) String() string {
- if e == nil {
- return "<nil>"
- }
- return e.msg
-}
-
-// IgnoreStats indicates whether this error type should be included in failure
-// counts in tcpip.Stats structs.
-func (e *Error) IgnoreStats() bool {
- return e.ignoreStats
-}
-
-// Errors that can be returned by the network stack.
-var (
- ErrUnknownProtocol = &Error{msg: "unknown protocol"}
- ErrUnknownNICID = &Error{msg: "unknown nic id"}
- ErrUnknownDevice = &Error{msg: "unknown device"}
- ErrUnknownProtocolOption = &Error{msg: "unknown option for protocol"}
- ErrDuplicateNICID = &Error{msg: "duplicate nic id"}
- ErrDuplicateAddress = &Error{msg: "duplicate address"}
- ErrNoRoute = &Error{msg: "no route"}
- ErrBadLinkEndpoint = &Error{msg: "bad link layer endpoint"}
- ErrAlreadyBound = &Error{msg: "endpoint already bound", ignoreStats: true}
- ErrInvalidEndpointState = &Error{msg: "endpoint is in invalid state"}
- ErrAlreadyConnecting = &Error{msg: "endpoint is already connecting", ignoreStats: true}
- ErrAlreadyConnected = &Error{msg: "endpoint is already connected", ignoreStats: true}
- ErrNoPortAvailable = &Error{msg: "no ports are available"}
- ErrPortInUse = &Error{msg: "port is in use"}
- ErrBadLocalAddress = &Error{msg: "bad local address"}
- ErrClosedForSend = &Error{msg: "endpoint is closed for send"}
- ErrClosedForReceive = &Error{msg: "endpoint is closed for receive"}
- ErrWouldBlock = &Error{msg: "operation would block", ignoreStats: true}
- ErrConnectionRefused = &Error{msg: "connection was refused"}
- ErrTimeout = &Error{msg: "operation timed out"}
- ErrAborted = &Error{msg: "operation aborted"}
- ErrConnectStarted = &Error{msg: "connection attempt started", ignoreStats: true}
- ErrDestinationRequired = &Error{msg: "destination address is required"}
- ErrNotSupported = &Error{msg: "operation not supported"}
- ErrQueueSizeNotSupported = &Error{msg: "queue size querying not supported"}
- ErrNotConnected = &Error{msg: "endpoint not connected"}
- ErrConnectionReset = &Error{msg: "connection reset by peer"}
- ErrConnectionAborted = &Error{msg: "connection aborted"}
- ErrNoSuchFile = &Error{msg: "no such file"}
- ErrInvalidOptionValue = &Error{msg: "invalid option value specified"}
- ErrBadAddress = &Error{msg: "bad address"}
- ErrNetworkUnreachable = &Error{msg: "network is unreachable"}
- ErrMessageTooLong = &Error{msg: "message too long"}
- ErrNoBufferSpace = &Error{msg: "no buffer space available"}
- ErrBroadcastDisabled = &Error{msg: "broadcast socket option disabled"}
- ErrNotPermitted = &Error{msg: "operation not permitted"}
- ErrAddressFamilyNotSupported = &Error{msg: "address family not supported by protocol"}
- ErrMalformedHeader = &Error{msg: "header is malformed"}
- ErrBadBuffer = &Error{msg: "bad buffer"}
-)
-
-var messageToError map[string]*Error
-
-var populate sync.Once
-
-// StringToError converts an error message to the error.
-func StringToError(s string) *Error {
- populate.Do(func() {
- var errors = []*Error{
- ErrUnknownProtocol,
- ErrUnknownNICID,
- ErrUnknownDevice,
- ErrUnknownProtocolOption,
- ErrDuplicateNICID,
- ErrDuplicateAddress,
- ErrNoRoute,
- ErrBadLinkEndpoint,
- ErrAlreadyBound,
- ErrInvalidEndpointState,
- ErrAlreadyConnecting,
- ErrAlreadyConnected,
- ErrNoPortAvailable,
- ErrPortInUse,
- ErrBadLocalAddress,
- ErrClosedForSend,
- ErrClosedForReceive,
- ErrWouldBlock,
- ErrConnectionRefused,
- ErrTimeout,
- ErrAborted,
- ErrConnectStarted,
- ErrDestinationRequired,
- ErrNotSupported,
- ErrQueueSizeNotSupported,
- ErrNotConnected,
- ErrConnectionReset,
- ErrConnectionAborted,
- ErrNoSuchFile,
- ErrInvalidOptionValue,
- ErrBadAddress,
- ErrNetworkUnreachable,
- ErrMessageTooLong,
- ErrNoBufferSpace,
- ErrBroadcastDisabled,
- ErrNotPermitted,
- ErrAddressFamilyNotSupported,
- ErrMalformedHeader,
- ErrBadBuffer,
- }
-
- messageToError = make(map[string]*Error)
- for _, e := range errors {
- if messageToError[e.String()] != nil {
- panic("tcpip errors with duplicated message: " + e.String())
- }
- messageToError[e.String()] = e
- }
- })
-
- e, ok := messageToError[s]
- if !ok {
- panic("unknown error message: " + s)
- }
-
- return e
-}
-
// Errors related to Subnet
var (
errSubnetLengthMismatch = errors.New("subnet length of address and mask differ")
@@ -194,7 +60,7 @@ type ErrSaveRejection struct {
}
// Error returns a sensible description of the save rejection error.
-func (e ErrSaveRejection) Error() string {
+func (e *ErrSaveRejection) Error() string {
return "save rejected due to unsupported networking state: " + e.Err.Error()
}
@@ -471,30 +337,15 @@ type FullAddress struct {
// This interface allows the endpoint to request the amount of data it needs
// based on internal buffers without exposing them.
type Payloader interface {
- // FullPayload returns all available bytes.
- FullPayload() ([]byte, *Error)
+ io.Reader
- // Payload returns a slice containing at most size bytes.
- Payload(size int) ([]byte, *Error)
+ // Len returns the number of bytes of the unread portion of the
+ // Reader.
+ Len() int
}
-// SlicePayload implements Payloader for slices.
-//
-// This is typically used for tests.
-type SlicePayload []byte
-
-// FullPayload implements Payloader.FullPayload.
-func (s SlicePayload) FullPayload() ([]byte, *Error) {
- return s, nil
-}
-
-// Payload implements Payloader.Payload.
-func (s SlicePayload) Payload(size int) ([]byte, *Error) {
- if size > len(s) {
- size = len(s)
- }
- return s[:size], nil
-}
+var _ Payloader = (*bytes.Buffer)(nil)
+var _ Payloader = (*bytes.Reader)(nil)
var _ io.Writer = (*SliceWriter)(nil)
@@ -647,7 +498,7 @@ type Endpoint interface {
// If non-zero number of bytes are successfully read and written to dst, err
// must be nil. Otherwise, if dst failed to write anything, ErrBadBuffer
// should be returned.
- Read(dst io.Writer, opts ReadOptions) (res ReadResult, err *Error)
+ Read(io.Writer, ReadOptions) (ReadResult, Error)
// Write writes data to the endpoint's peer. This method does not block if
// the data cannot be written.
@@ -662,7 +513,7 @@ type Endpoint interface {
// stream (TCP) Endpoints may return partial writes, and even then only
// in the case where writing additional data would block. Other Endpoints
// will either write the entire message or return an error.
- Write(Payloader, WriteOptions) (int64, *Error)
+ Write(Payloader, WriteOptions) (int64, Error)
// Connect connects the endpoint to its peer. Specifying a NIC is
// optional.
@@ -676,21 +527,21 @@ type Endpoint interface {
// connected returns nil. Calling connect again results in ErrAlreadyConnected.
// Anything else -- the attempt to connect failed.
//
- // If address.Addr is empty, this means that Enpoint has to be
+ // If address.Addr is empty, this means that Endpoint has to be
// disconnected if this is supported, otherwise
// ErrAddressFamilyNotSupported must be returned.
- Connect(address FullAddress) *Error
+ Connect(address FullAddress) Error
// Disconnect disconnects the endpoint from its peer.
- Disconnect() *Error
+ Disconnect() Error
// Shutdown closes the read and/or write end of the endpoint connection
// to its peer.
- Shutdown(flags ShutdownFlags) *Error
+ Shutdown(flags ShutdownFlags) Error
// Listen puts the endpoint in "listen" mode, which allows it to accept
// new connections.
- Listen(backlog int) *Error
+ Listen(backlog int) Error
// Accept returns a new endpoint if a peer has established a connection
// to an endpoint previously set to listen mode. This method does not
@@ -700,36 +551,36 @@ type Endpoint interface {
//
// If peerAddr is not nil then it is populated with the peer address of the
// returned endpoint.
- Accept(peerAddr *FullAddress) (Endpoint, *waiter.Queue, *Error)
+ Accept(peerAddr *FullAddress) (Endpoint, *waiter.Queue, Error)
// Bind binds the endpoint to a specific local address and port.
// Specifying a NIC is optional.
- Bind(address FullAddress) *Error
+ Bind(address FullAddress) Error
// GetLocalAddress returns the address to which the endpoint is bound.
- GetLocalAddress() (FullAddress, *Error)
+ GetLocalAddress() (FullAddress, Error)
// GetRemoteAddress returns the address to which the endpoint is
// connected.
- GetRemoteAddress() (FullAddress, *Error)
+ GetRemoteAddress() (FullAddress, Error)
// Readiness returns the current readiness of the endpoint. For example,
// if waiter.EventIn is set, the endpoint is immediately readable.
Readiness(mask waiter.EventMask) waiter.EventMask
// SetSockOpt sets a socket option.
- SetSockOpt(opt SettableSocketOption) *Error
+ SetSockOpt(opt SettableSocketOption) Error
// SetSockOptInt sets a socket option, for simple cases where a value
// has the int type.
- SetSockOptInt(opt SockOptInt, v int) *Error
+ SetSockOptInt(opt SockOptInt, v int) Error
// GetSockOpt gets a socket option.
- GetSockOpt(opt GettableSocketOption) *Error
+ GetSockOpt(opt GettableSocketOption) Error
// GetSockOptInt gets a socket option for simple cases where a return
// value has the int type.
- GetSockOptInt(SockOptInt) (int, *Error)
+ GetSockOptInt(SockOptInt) (int, Error)
// State returns a socket's lifecycle state. The returned value is
// protocol-specific and is primarily used for diagnostics.
@@ -752,7 +603,7 @@ type Endpoint interface {
SetOwner(owner PacketOwner)
// LastError clears and returns the last error reported by the endpoint.
- LastError() *Error
+ LastError() Error
// SocketOptions returns the structure which contains all the socket
// level options.
@@ -840,10 +691,6 @@ const (
// number of unread bytes in the input buffer should be returned.
ReceiveQueueSizeOption
- // SendBufferSizeOption is used by SetSockOptInt/GetSockOptInt to
- // specify the send buffer size option.
- SendBufferSizeOption
-
// ReceiveBufferSizeOption is used by SetSockOptInt/GetSockOptInt to
// specify the receive buffer size option.
ReceiveBufferSizeOption
@@ -1011,12 +858,54 @@ type SettableSocketOption interface {
isSettableSocketOption()
}
+// CongestionControlState indicates the current congestion control state for
+// TCP sender.
+type CongestionControlState int
+
+const (
+ // Open indicates that the sender is receiving acks in order and
+ // no loss or dupACK's etc have been detected.
+ Open CongestionControlState = iota
+ // RTORecovery indicates that an RTO has occurred and the sender
+ // has entered an RTO based recovery phase.
+ RTORecovery
+ // FastRecovery indicates that the sender has entered FastRecovery
+ // based on receiving nDupAck's. This state is entered only when
+ // SACK is not in use.
+ FastRecovery
+ // SACKRecovery indicates that the sender has entered SACK based
+ // recovery.
+ SACKRecovery
+ // Disorder indicates the sender either received some SACK blocks
+ // or dupACK's.
+ Disorder
+)
+
// TCPInfoOption is used by GetSockOpt to expose TCP statistics.
//
// TODO(b/64800844): Add and populate stat fields.
type TCPInfoOption struct {
- RTT time.Duration
+ // RTT is the smoothed round trip time.
+ RTT time.Duration
+
+ // RTTVar is the round trip time variation.
RTTVar time.Duration
+
+ // RTO is the retransmission timeout for the endpoint.
+ RTO time.Duration
+
+ // CcState is the congestion control state.
+ CcState CongestionControlState
+
+ // SndCwnd is the congestion window, in packets.
+ SndCwnd uint32
+
+ // SndSsthresh is the threshold between slow start and congestion
+ // avoidance.
+ SndSsthresh uint32
+
+ // ReorderSeen indicates if reordering is seen in the endpoint.
+ ReorderSeen bool
}
func (*TCPInfoOption) isGettableSocketOption() {}
@@ -1248,6 +1137,31 @@ type IPPacketInfo struct {
DestinationAddr Address
}
+// SendBufferSizeOption is used by stack.(Stack*).Option/SetOption to
+// get/set the default, min and max send buffer sizes.
+type SendBufferSizeOption struct {
+ // Min is the minimum size for send buffer.
+ Min int
+
+ // Default is the default size for send buffer.
+ Default int
+
+ // Max is the maximum size for send buffer.
+ Max int
+}
+
+// GetSendBufferLimits is used to get the send buffer size limits.
+type GetSendBufferLimits func(StackHandler) SendBufferSizeOption
+
+// GetStackSendBufferLimits is used to get default, min and max send buffer size.
+func GetStackSendBufferLimits(so StackHandler) SendBufferSizeOption {
+ var ss SendBufferSizeOption
+ if err := so.Option(&ss); err != nil {
+ panic(fmt.Sprintf("s.Option(%#v) = %s", ss, err))
+ }
+ return ss
+}
+
// Route is a row in the routing table. It specifies through which NIC (and
// gateway) sets of packets should be routed. A row is considered viable if the
// masked target address matches the destination address in the row.
@@ -1317,8 +1231,33 @@ func (s *StatCounter) String() string {
return strconv.FormatUint(s.Value(), 10)
}
+// A MultiCounterStat keeps track of two counters at once.
+type MultiCounterStat struct {
+ a, b *StatCounter
+}
+
+// Init sets both internal counters to point to a and b.
+func (m *MultiCounterStat) Init(a, b *StatCounter) {
+ m.a = a
+ m.b = b
+}
+
+// Increment adds one to the counters.
+func (m *MultiCounterStat) Increment() {
+ m.a.Increment()
+ m.b.Increment()
+}
+
+// IncrementBy increments the counters by v.
+func (m *MultiCounterStat) IncrementBy(v uint64) {
+ m.a.IncrementBy(v)
+ m.b.IncrementBy(v)
+}
+
// ICMPv4PacketStats enumerates counts for all ICMPv4 packet types.
type ICMPv4PacketStats struct {
+ // LINT.IfChange(ICMPv4PacketStats)
+
// Echo is the total number of ICMPv4 echo packets counted.
Echo *StatCounter
@@ -1358,10 +1297,56 @@ type ICMPv4PacketStats struct {
// InfoReply is the total number of ICMPv4 information reply packets
// counted.
InfoReply *StatCounter
+
+ // LINT.ThenChange(network/ipv4/stats.go:multiCounterICMPv4PacketStats)
+}
+
+// ICMPv4SentPacketStats collects outbound ICMPv4-specific stats.
+type ICMPv4SentPacketStats struct {
+ // LINT.IfChange(ICMPv4SentPacketStats)
+
+ ICMPv4PacketStats
+
+ // Dropped is the total number of ICMPv4 packets dropped due to link
+ // layer errors.
+ Dropped *StatCounter
+
+ // RateLimited is the total number of ICMPv4 packets dropped due to
+ // rate limit being exceeded.
+ RateLimited *StatCounter
+
+ // LINT.ThenChange(network/ipv4/stats.go:multiCounterICMPv4SentPacketStats)
+}
+
+// ICMPv4ReceivedPacketStats collects inbound ICMPv4-specific stats.
+type ICMPv4ReceivedPacketStats struct {
+ // LINT.IfChange(ICMPv4ReceivedPacketStats)
+
+ ICMPv4PacketStats
+
+ // Invalid is the total number of invalid ICMPv4 packets received.
+ Invalid *StatCounter
+
+ // LINT.ThenChange(network/ipv4/stats.go:multiCounterICMPv4ReceivedPacketStats)
+}
+
+// ICMPv4Stats collects ICMPv4-specific stats.
+type ICMPv4Stats struct {
+ // LINT.IfChange(ICMPv4Stats)
+
+ // PacketsSent contains statistics about sent packets.
+ PacketsSent ICMPv4SentPacketStats
+
+ // PacketsReceived contains statistics about received packets.
+ PacketsReceived ICMPv4ReceivedPacketStats
+
+ // LINT.ThenChange(network/ipv4/stats.go:multiCounterICMPv4Stats)
}
// ICMPv6PacketStats enumerates counts for all ICMPv6 packet types.
type ICMPv6PacketStats struct {
+ // LINT.IfChange(ICMPv6PacketStats)
+
// EchoRequest is the total number of ICMPv6 echo request packets
// counted.
EchoRequest *StatCounter
@@ -1416,32 +1401,14 @@ type ICMPv6PacketStats struct {
// MulticastListenerDone is the total number of Multicast Listener Done
// messages counted.
MulticastListenerDone *StatCounter
-}
-
-// ICMPv4SentPacketStats collects outbound ICMPv4-specific stats.
-type ICMPv4SentPacketStats struct {
- ICMPv4PacketStats
-
- // Dropped is the total number of ICMPv4 packets dropped due to link
- // layer errors.
- Dropped *StatCounter
- // RateLimited is the total number of ICMPv6 packets dropped due to
- // rate limit being exceeded.
- RateLimited *StatCounter
-}
-
-// ICMPv4ReceivedPacketStats collects inbound ICMPv4-specific stats.
-type ICMPv4ReceivedPacketStats struct {
- ICMPv4PacketStats
-
- // Invalid is the total number of ICMPv4 packets received that the
- // transport layer could not parse.
- Invalid *StatCounter
+ // LINT.ThenChange(network/ipv6/stats.go:multiCounterICMPv6PacketStats)
}
// ICMPv6SentPacketStats collects outbound ICMPv6-specific stats.
type ICMPv6SentPacketStats struct {
+ // LINT.IfChange(ICMPv6SentPacketStats)
+
ICMPv6PacketStats
// Dropped is the total number of ICMPv6 packets dropped due to link
@@ -1451,47 +1418,41 @@ type ICMPv6SentPacketStats struct {
// RateLimited is the total number of ICMPv6 packets dropped due to
// rate limit being exceeded.
RateLimited *StatCounter
+
+ // LINT.ThenChange(network/ipv6/stats.go:multiCounterICMPv6SentPacketStats)
}
// ICMPv6ReceivedPacketStats collects inbound ICMPv6-specific stats.
type ICMPv6ReceivedPacketStats struct {
+ // LINT.IfChange(ICMPv6ReceivedPacketStats)
+
ICMPv6PacketStats
// Unrecognized is the total number of ICMPv6 packets received that the
// transport layer does not know how to parse.
Unrecognized *StatCounter
- // Invalid is the total number of ICMPv6 packets received that the
- // transport layer could not parse.
+ // Invalid is the total number of invalid ICMPv6 packets received.
Invalid *StatCounter
// RouterOnlyPacketsDroppedByHost is the total number of ICMPv6 packets
// dropped due to being router-specific packets.
RouterOnlyPacketsDroppedByHost *StatCounter
-}
-
-// ICMPv4Stats collects ICMPv4-specific stats.
-type ICMPv4Stats struct {
- // ICMPv4SentPacketStats contains counts of sent packets by ICMPv4 packet type
- // and a single count of packets which failed to write to the link
- // layer.
- PacketsSent ICMPv4SentPacketStats
- // ICMPv4ReceivedPacketStats contains counts of received packets by ICMPv4
- // packet type and a single count of invalid packets received.
- PacketsReceived ICMPv4ReceivedPacketStats
+ // LINT.ThenChange(network/ipv6/stats.go:multiCounterICMPv6ReceivedPacketStats)
}
// ICMPv6Stats collects ICMPv6-specific stats.
type ICMPv6Stats struct {
- // ICMPv6SentPacketStats contains counts of sent packets by ICMPv6 packet type
- // and a single count of packets which failed to write to the link
- // layer.
+ // LINT.IfChange(ICMPv6Stats)
+
+ // PacketsSent contains statistics about sent packets.
PacketsSent ICMPv6SentPacketStats
- // ICMPv6ReceivedPacketStats contains counts of received packets by ICMPv6
- // packet type and a single count of invalid packets received.
+ // PacketsReceived contains statistics about received packets.
PacketsReceived ICMPv6ReceivedPacketStats
+
+ // LINT.ThenChange(network/ipv6/stats.go:multiCounterICMPv6Stats)
}
// ICMPStats collects ICMP-specific stats (both v4 and v6).
@@ -1505,6 +1466,8 @@ type ICMPStats struct {
// IGMPPacketStats enumerates counts for all IGMP packet types.
type IGMPPacketStats struct {
+ // LINT.IfChange(IGMPPacketStats)
+
// MembershipQuery is the total number of Membership Query messages counted.
MembershipQuery *StatCounter
@@ -1518,22 +1481,29 @@ type IGMPPacketStats struct {
// LeaveGroup is the total number of Leave Group messages counted.
LeaveGroup *StatCounter
+
+ // LINT.ThenChange(network/ipv4/stats.go:multiCounterIGMPPacketStats)
}
// IGMPSentPacketStats collects outbound IGMP-specific stats.
type IGMPSentPacketStats struct {
+ // LINT.IfChange(IGMPSentPacketStats)
+
IGMPPacketStats
// Dropped is the total number of IGMP packets dropped.
Dropped *StatCounter
+
+ // LINT.ThenChange(network/ipv4/stats.go:multiCounterIGMPSentPacketStats)
}
// IGMPReceivedPacketStats collects inbound IGMP-specific stats.
type IGMPReceivedPacketStats struct {
+ // LINT.IfChange(IGMPReceivedPacketStats)
+
IGMPPacketStats
- // Invalid is the total number of IGMP packets received that IGMP could not
- // parse.
+ // Invalid is the total number of invalid IGMP packets received.
Invalid *StatCounter
// ChecksumErrors is the total number of IGMP packets dropped due to bad
@@ -1543,21 +1513,27 @@ type IGMPReceivedPacketStats struct {
// Unrecognized is the total number of unrecognized messages counted, these
// are silently ignored for forward-compatibilty.
Unrecognized *StatCounter
+
+ // LINT.ThenChange(network/ipv4/stats.go:multiCounterIGMPReceivedPacketStats)
}
-// IGMPStats colelcts IGMP-specific stats.
+// IGMPStats collects IGMP-specific stats.
type IGMPStats struct {
- // IGMPSentPacketStats contains counts of sent packets by IGMP packet type
- // and a single count of invalid packets received.
+ // LINT.IfChange(IGMPStats)
+
+ // PacketsSent contains statistics about sent packets.
PacketsSent IGMPSentPacketStats
- // IGMPReceivedPacketStats contains counts of received packets by IGMP packet
- // type and a single count of invalid packets received.
+ // PacketsReceived contains statistics about received packets.
PacketsReceived IGMPReceivedPacketStats
+
+ // LINT.ThenChange(network/ipv4/stats.go:multiCounterIGMPStats)
}
// IPStats collects IP-specific stats (both v4 and v6).
type IPStats struct {
+ // LINT.IfChange(IPStats)
+
// PacketsReceived is the total number of IP packets received from the
// link layer.
PacketsReceived *StatCounter
@@ -1575,7 +1551,7 @@ type IPStats struct {
InvalidSourceAddressesReceived *StatCounter
// PacketsDelivered is the total number of incoming IP packets that
- // are successfully delivered to the transport layer via HandlePacket.
+ // are successfully delivered to the transport layer.
PacketsDelivered *StatCounter
// PacketsSent is the total number of IP packets sent via WritePacket.
@@ -1613,10 +1589,14 @@ type IPStats struct {
// OptionUnknownReceived is the number of unknown IP options seen.
OptionUnknownReceived *StatCounter
+
+ // LINT.ThenChange(network/ip/stats.go:MultiCounterIPStats)
}
// ARPStats collects ARP-specific stats.
type ARPStats struct {
+ // LINT.IfChange(ARPStats)
+
// PacketsReceived is the number of ARP packets received from the link layer.
PacketsReceived *StatCounter
@@ -1644,10 +1624,6 @@ type ARPStats struct {
// ARP request with a bad local address.
OutgoingRequestBadLocalAddressErrors *StatCounter
- // OutgoingRequestNetworkUnreachableErrors is the number of failures to send
- // an ARP request with a network unreachable error.
- OutgoingRequestNetworkUnreachableErrors *StatCounter
-
// OutgoingRequestsDropped is the number of ARP requests which failed to write
// to a link-layer endpoint.
OutgoingRequestsDropped *StatCounter
@@ -1666,6 +1642,8 @@ type ARPStats struct {
// OutgoingRepliesSent is the number of ARP replies successfully written to a
// link-layer endpoint.
OutgoingRepliesSent *StatCounter
+
+ // LINT.ThenChange(network/arp/stats.go:multiCounterARPStats)
}
// TCPStats collects TCP-specific stats.
diff --git a/pkg/tcpip/tests/integration/BUILD b/pkg/tcpip/tests/integration/BUILD
index 1742a178d..71695b630 100644
--- a/pkg/tcpip/tests/integration/BUILD
+++ b/pkg/tcpip/tests/integration/BUILD
@@ -7,6 +7,7 @@ go_test(
size = "small",
srcs = [
"forward_test.go",
+ "iptables_test.go",
"link_resolution_test.go",
"loopback_test.go",
"multicast_broadcast_test.go",
@@ -16,6 +17,7 @@ go_test(
"//pkg/tcpip",
"//pkg/tcpip/buffer",
"//pkg/tcpip/checker",
+ "//pkg/tcpip/faketime",
"//pkg/tcpip/header",
"//pkg/tcpip/link/channel",
"//pkg/tcpip/link/ethernet",
diff --git a/pkg/tcpip/tests/integration/forward_test.go b/pkg/tcpip/tests/integration/forward_test.go
index ac9670f9a..38e1881c7 100644
--- a/pkg/tcpip/tests/integration/forward_test.go
+++ b/pkg/tcpip/tests/integration/forward_test.go
@@ -38,96 +38,207 @@ import (
var _ stack.NetworkDispatcher = (*endpointWithDestinationCheck)(nil)
var _ stack.LinkEndpoint = (*endpointWithDestinationCheck)(nil)
-// newEthernetEndpoint returns an ethernet link endpoint that wraps an inner
-// link endpoint and checks the destination link address before delivering
-// network packets to the network dispatcher.
-//
-// See ethernet.Endpoint for more details.
-func newEthernetEndpoint(ep stack.LinkEndpoint) *endpointWithDestinationCheck {
- var e endpointWithDestinationCheck
- e.Endpoint.Init(ethernet.New(ep), &e)
- return &e
-}
-
-// endpointWithDestinationCheck is a link endpoint that checks the destination
-// link address before delivering network packets to the network dispatcher.
-type endpointWithDestinationCheck struct {
- nested.Endpoint
-}
-
-// DeliverNetworkPacket implements stack.NetworkDispatcher.
-func (e *endpointWithDestinationCheck) DeliverNetworkPacket(src, dst tcpip.LinkAddress, proto tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
- if dst == e.Endpoint.LinkAddress() || dst == header.EthernetBroadcastAddress || header.IsMulticastEthernetAddress(dst) {
- e.Endpoint.DeliverNetworkPacket(src, dst, proto, pkt)
- }
-}
-
-func TestForwarding(t *testing.T) {
- const (
- host1NICID = 1
- routerNICID1 = 2
- routerNICID2 = 3
- host2NICID = 4
-
- listenPort = 8080
- )
+const (
+ host1NICID = 1
+ routerNICID1 = 2
+ routerNICID2 = 3
+ host2NICID = 4
+)
- host1IPv4Addr := tcpip.ProtocolAddress{
+var (
+ host1IPv4Addr = tcpip.ProtocolAddress{
Protocol: ipv4.ProtocolNumber,
AddressWithPrefix: tcpip.AddressWithPrefix{
Address: tcpip.Address(net.ParseIP("192.168.0.2").To4()),
PrefixLen: 24,
},
}
- routerNIC1IPv4Addr := tcpip.ProtocolAddress{
+ routerNIC1IPv4Addr = tcpip.ProtocolAddress{
Protocol: ipv4.ProtocolNumber,
AddressWithPrefix: tcpip.AddressWithPrefix{
Address: tcpip.Address(net.ParseIP("192.168.0.1").To4()),
PrefixLen: 24,
},
}
- routerNIC2IPv4Addr := tcpip.ProtocolAddress{
+ routerNIC2IPv4Addr = tcpip.ProtocolAddress{
Protocol: ipv4.ProtocolNumber,
AddressWithPrefix: tcpip.AddressWithPrefix{
Address: tcpip.Address(net.ParseIP("10.0.0.1").To4()),
PrefixLen: 8,
},
}
- host2IPv4Addr := tcpip.ProtocolAddress{
+ host2IPv4Addr = tcpip.ProtocolAddress{
Protocol: ipv4.ProtocolNumber,
AddressWithPrefix: tcpip.AddressWithPrefix{
Address: tcpip.Address(net.ParseIP("10.0.0.2").To4()),
PrefixLen: 8,
},
}
- host1IPv6Addr := tcpip.ProtocolAddress{
+ host1IPv6Addr = tcpip.ProtocolAddress{
Protocol: ipv6.ProtocolNumber,
AddressWithPrefix: tcpip.AddressWithPrefix{
Address: tcpip.Address(net.ParseIP("a::2").To16()),
PrefixLen: 64,
},
}
- routerNIC1IPv6Addr := tcpip.ProtocolAddress{
+ routerNIC1IPv6Addr = tcpip.ProtocolAddress{
Protocol: ipv6.ProtocolNumber,
AddressWithPrefix: tcpip.AddressWithPrefix{
Address: tcpip.Address(net.ParseIP("a::1").To16()),
PrefixLen: 64,
},
}
- routerNIC2IPv6Addr := tcpip.ProtocolAddress{
+ routerNIC2IPv6Addr = tcpip.ProtocolAddress{
Protocol: ipv6.ProtocolNumber,
AddressWithPrefix: tcpip.AddressWithPrefix{
Address: tcpip.Address(net.ParseIP("b::1").To16()),
PrefixLen: 64,
},
}
- host2IPv6Addr := tcpip.ProtocolAddress{
+ host2IPv6Addr = tcpip.ProtocolAddress{
Protocol: ipv6.ProtocolNumber,
AddressWithPrefix: tcpip.AddressWithPrefix{
Address: tcpip.Address(net.ParseIP("b::2").To16()),
PrefixLen: 64,
},
}
+)
+
+func setupRoutedStacks(t *testing.T, host1Stack, routerStack, host2Stack *stack.Stack) {
+ host1NIC, routerNIC1 := pipe.New(linkAddr1, linkAddr2)
+ routerNIC2, host2NIC := pipe.New(linkAddr3, linkAddr4)
+
+ if err := host1Stack.CreateNIC(host1NICID, newEthernetEndpoint(host1NIC)); err != nil {
+ t.Fatalf("host1Stack.CreateNIC(%d, _): %s", host1NICID, err)
+ }
+ if err := routerStack.CreateNIC(routerNICID1, newEthernetEndpoint(routerNIC1)); err != nil {
+ t.Fatalf("routerStack.CreateNIC(%d, _): %s", routerNICID1, err)
+ }
+ if err := routerStack.CreateNIC(routerNICID2, newEthernetEndpoint(routerNIC2)); err != nil {
+ t.Fatalf("routerStack.CreateNIC(%d, _): %s", routerNICID2, err)
+ }
+ if err := host2Stack.CreateNIC(host2NICID, newEthernetEndpoint(host2NIC)); err != nil {
+ t.Fatalf("host2Stack.CreateNIC(%d, _): %s", host2NICID, err)
+ }
+
+ if err := routerStack.SetForwarding(ipv4.ProtocolNumber, true); err != nil {
+ t.Fatalf("routerStack.SetForwarding(%d): %s", ipv4.ProtocolNumber, err)
+ }
+ if err := routerStack.SetForwarding(ipv6.ProtocolNumber, true); err != nil {
+ t.Fatalf("routerStack.SetForwarding(%d): %s", ipv6.ProtocolNumber, err)
+ }
+
+ if err := host1Stack.AddProtocolAddress(host1NICID, host1IPv4Addr); err != nil {
+ t.Fatalf("host1Stack.AddProtocolAddress(%d, %#v): %s", host1NICID, host1IPv4Addr, err)
+ }
+ if err := routerStack.AddProtocolAddress(routerNICID1, routerNIC1IPv4Addr); err != nil {
+ t.Fatalf("routerStack.AddProtocolAddress(%d, %#v): %s", routerNICID1, routerNIC1IPv4Addr, err)
+ }
+ if err := routerStack.AddProtocolAddress(routerNICID2, routerNIC2IPv4Addr); err != nil {
+ t.Fatalf("routerStack.AddProtocolAddress(%d, %#v): %s", routerNICID2, routerNIC2IPv4Addr, err)
+ }
+ if err := host2Stack.AddProtocolAddress(host2NICID, host2IPv4Addr); err != nil {
+ t.Fatalf("host2Stack.AddProtocolAddress(%d, %#v): %s", host2NICID, host2IPv4Addr, err)
+ }
+ if err := host1Stack.AddProtocolAddress(host1NICID, host1IPv6Addr); err != nil {
+ t.Fatalf("host1Stack.AddProtocolAddress(%d, %#v): %s", host1NICID, host1IPv6Addr, err)
+ }
+ if err := routerStack.AddProtocolAddress(routerNICID1, routerNIC1IPv6Addr); err != nil {
+ t.Fatalf("routerStack.AddProtocolAddress(%d, %#v): %s", routerNICID1, routerNIC1IPv6Addr, err)
+ }
+ if err := routerStack.AddProtocolAddress(routerNICID2, routerNIC2IPv6Addr); err != nil {
+ t.Fatalf("routerStack.AddProtocolAddress(%d, %#v): %s", routerNICID2, routerNIC2IPv6Addr, err)
+ }
+ if err := host2Stack.AddProtocolAddress(host2NICID, host2IPv6Addr); err != nil {
+ t.Fatalf("host2Stack.AddProtocolAddress(%d, %#v): %s", host2NICID, host2IPv6Addr, err)
+ }
+
+ host1Stack.SetRouteTable([]tcpip.Route{
+ {
+ Destination: host1IPv4Addr.AddressWithPrefix.Subnet(),
+ NIC: host1NICID,
+ },
+ {
+ Destination: host1IPv6Addr.AddressWithPrefix.Subnet(),
+ NIC: host1NICID,
+ },
+ {
+ Destination: host2IPv4Addr.AddressWithPrefix.Subnet(),
+ Gateway: routerNIC1IPv4Addr.AddressWithPrefix.Address,
+ NIC: host1NICID,
+ },
+ {
+ Destination: host2IPv6Addr.AddressWithPrefix.Subnet(),
+ Gateway: routerNIC1IPv6Addr.AddressWithPrefix.Address,
+ NIC: host1NICID,
+ },
+ })
+ routerStack.SetRouteTable([]tcpip.Route{
+ {
+ Destination: routerNIC1IPv4Addr.AddressWithPrefix.Subnet(),
+ NIC: routerNICID1,
+ },
+ {
+ Destination: routerNIC1IPv6Addr.AddressWithPrefix.Subnet(),
+ NIC: routerNICID1,
+ },
+ {
+ Destination: routerNIC2IPv4Addr.AddressWithPrefix.Subnet(),
+ NIC: routerNICID2,
+ },
+ {
+ Destination: routerNIC2IPv6Addr.AddressWithPrefix.Subnet(),
+ NIC: routerNICID2,
+ },
+ })
+ host2Stack.SetRouteTable([]tcpip.Route{
+ {
+ Destination: host2IPv4Addr.AddressWithPrefix.Subnet(),
+ NIC: host2NICID,
+ },
+ {
+ Destination: host2IPv6Addr.AddressWithPrefix.Subnet(),
+ NIC: host2NICID,
+ },
+ {
+ Destination: host1IPv4Addr.AddressWithPrefix.Subnet(),
+ Gateway: routerNIC2IPv4Addr.AddressWithPrefix.Address,
+ NIC: host2NICID,
+ },
+ {
+ Destination: host1IPv6Addr.AddressWithPrefix.Subnet(),
+ Gateway: routerNIC2IPv6Addr.AddressWithPrefix.Address,
+ NIC: host2NICID,
+ },
+ })
+}
+
+// newEthernetEndpoint returns an ethernet link endpoint that wraps an inner
+// link endpoint and checks the destination link address before delivering
+// network packets to the network dispatcher.
+//
+// See ethernet.Endpoint for more details.
+func newEthernetEndpoint(ep stack.LinkEndpoint) *endpointWithDestinationCheck {
+ var e endpointWithDestinationCheck
+ e.Endpoint.Init(ethernet.New(ep), &e)
+ return &e
+}
+
+// endpointWithDestinationCheck is a link endpoint that checks the destination
+// link address before delivering network packets to the network dispatcher.
+type endpointWithDestinationCheck struct {
+ nested.Endpoint
+}
+
+// DeliverNetworkPacket implements stack.NetworkDispatcher.
+func (e *endpointWithDestinationCheck) DeliverNetworkPacket(src, dst tcpip.LinkAddress, proto tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
+ if dst == e.Endpoint.LinkAddress() || dst == header.EthernetBroadcastAddress || header.IsMulticastEthernetAddress(dst) {
+ e.Endpoint.DeliverNetworkPacket(src, dst, proto, pkt)
+ }
+}
+
+func TestForwarding(t *testing.T) {
+ const listenPort = 8080
type endpointAndAddresses struct {
serverEP tcpip.Endpoint
@@ -229,7 +340,7 @@ func TestForwarding(t *testing.T) {
subTests := []struct {
name string
proto tcpip.TransportProtocolNumber
- expectedConnectErr *tcpip.Error
+ expectedConnectErr tcpip.Error
setupServerSide func(t *testing.T, ep tcpip.Endpoint, ch <-chan struct{}, clientAddr tcpip.FullAddress) (tcpip.Endpoint, chan struct{})
needRemoteAddr bool
}{
@@ -250,7 +361,7 @@ func TestForwarding(t *testing.T) {
{
name: "TCP",
proto: tcp.ProtocolNumber,
- expectedConnectErr: tcpip.ErrConnectStarted,
+ expectedConnectErr: &tcpip.ErrConnectStarted{},
setupServerSide: func(t *testing.T, ep tcpip.Endpoint, ch <-chan struct{}, clientAddr tcpip.FullAddress) (tcpip.Endpoint, chan struct{}) {
t.Helper()
@@ -260,7 +371,7 @@ func TestForwarding(t *testing.T) {
var addr tcpip.FullAddress
for {
newEP, wq, err := ep.Accept(&addr)
- if err == tcpip.ErrWouldBlock {
+ if _, ok := err.(*tcpip.ErrWouldBlock); ok {
<-ch
continue
}
@@ -294,113 +405,7 @@ func TestForwarding(t *testing.T) {
host1Stack := stack.New(stackOpts)
routerStack := stack.New(stackOpts)
host2Stack := stack.New(stackOpts)
-
- host1NIC, routerNIC1 := pipe.New(linkAddr1, linkAddr2)
- routerNIC2, host2NIC := pipe.New(linkAddr3, linkAddr4)
-
- if err := host1Stack.CreateNIC(host1NICID, newEthernetEndpoint(host1NIC)); err != nil {
- t.Fatalf("host1Stack.CreateNIC(%d, _): %s", host1NICID, err)
- }
- if err := routerStack.CreateNIC(routerNICID1, newEthernetEndpoint(routerNIC1)); err != nil {
- t.Fatalf("routerStack.CreateNIC(%d, _): %s", routerNICID1, err)
- }
- if err := routerStack.CreateNIC(routerNICID2, newEthernetEndpoint(routerNIC2)); err != nil {
- t.Fatalf("routerStack.CreateNIC(%d, _): %s", routerNICID2, err)
- }
- if err := host2Stack.CreateNIC(host2NICID, newEthernetEndpoint(host2NIC)); err != nil {
- t.Fatalf("host2Stack.CreateNIC(%d, _): %s", host2NICID, err)
- }
-
- if err := routerStack.SetForwarding(ipv4.ProtocolNumber, true); err != nil {
- t.Fatalf("routerStack.SetForwarding(%d): %s", ipv4.ProtocolNumber, err)
- }
- if err := routerStack.SetForwarding(ipv6.ProtocolNumber, true); err != nil {
- t.Fatalf("routerStack.SetForwarding(%d): %s", ipv6.ProtocolNumber, err)
- }
-
- if err := host1Stack.AddProtocolAddress(host1NICID, host1IPv4Addr); err != nil {
- t.Fatalf("host1Stack.AddProtocolAddress(%d, %#v): %s", host1NICID, host1IPv4Addr, err)
- }
- if err := routerStack.AddProtocolAddress(routerNICID1, routerNIC1IPv4Addr); err != nil {
- t.Fatalf("routerStack.AddProtocolAddress(%d, %#v): %s", routerNICID1, routerNIC1IPv4Addr, err)
- }
- if err := routerStack.AddProtocolAddress(routerNICID2, routerNIC2IPv4Addr); err != nil {
- t.Fatalf("routerStack.AddProtocolAddress(%d, %#v): %s", routerNICID2, routerNIC2IPv4Addr, err)
- }
- if err := host2Stack.AddProtocolAddress(host2NICID, host2IPv4Addr); err != nil {
- t.Fatalf("host2Stack.AddProtocolAddress(%d, %#v): %s", host2NICID, host2IPv4Addr, err)
- }
- if err := host1Stack.AddProtocolAddress(host1NICID, host1IPv6Addr); err != nil {
- t.Fatalf("host1Stack.AddProtocolAddress(%d, %#v): %s", host1NICID, host1IPv6Addr, err)
- }
- if err := routerStack.AddProtocolAddress(routerNICID1, routerNIC1IPv6Addr); err != nil {
- t.Fatalf("routerStack.AddProtocolAddress(%d, %#v): %s", routerNICID1, routerNIC1IPv6Addr, err)
- }
- if err := routerStack.AddProtocolAddress(routerNICID2, routerNIC2IPv6Addr); err != nil {
- t.Fatalf("routerStack.AddProtocolAddress(%d, %#v): %s", routerNICID2, routerNIC2IPv6Addr, err)
- }
- if err := host2Stack.AddProtocolAddress(host2NICID, host2IPv6Addr); err != nil {
- t.Fatalf("host2Stack.AddProtocolAddress(%d, %#v): %s", host2NICID, host2IPv6Addr, err)
- }
-
- host1Stack.SetRouteTable([]tcpip.Route{
- {
- Destination: host1IPv4Addr.AddressWithPrefix.Subnet(),
- NIC: host1NICID,
- },
- {
- Destination: host1IPv6Addr.AddressWithPrefix.Subnet(),
- NIC: host1NICID,
- },
- {
- Destination: host2IPv4Addr.AddressWithPrefix.Subnet(),
- Gateway: routerNIC1IPv4Addr.AddressWithPrefix.Address,
- NIC: host1NICID,
- },
- {
- Destination: host2IPv6Addr.AddressWithPrefix.Subnet(),
- Gateway: routerNIC1IPv6Addr.AddressWithPrefix.Address,
- NIC: host1NICID,
- },
- })
- routerStack.SetRouteTable([]tcpip.Route{
- {
- Destination: routerNIC1IPv4Addr.AddressWithPrefix.Subnet(),
- NIC: routerNICID1,
- },
- {
- Destination: routerNIC1IPv6Addr.AddressWithPrefix.Subnet(),
- NIC: routerNICID1,
- },
- {
- Destination: routerNIC2IPv4Addr.AddressWithPrefix.Subnet(),
- NIC: routerNICID2,
- },
- {
- Destination: routerNIC2IPv6Addr.AddressWithPrefix.Subnet(),
- NIC: routerNICID2,
- },
- })
- host2Stack.SetRouteTable([]tcpip.Route{
- {
- Destination: host2IPv4Addr.AddressWithPrefix.Subnet(),
- NIC: host2NICID,
- },
- {
- Destination: host2IPv6Addr.AddressWithPrefix.Subnet(),
- NIC: host2NICID,
- },
- {
- Destination: host1IPv4Addr.AddressWithPrefix.Subnet(),
- Gateway: routerNIC2IPv4Addr.AddressWithPrefix.Address,
- NIC: host2NICID,
- },
- {
- Destination: host1IPv6Addr.AddressWithPrefix.Subnet(),
- Gateway: routerNIC2IPv6Addr.AddressWithPrefix.Address,
- NIC: host2NICID,
- },
- })
+ setupRoutedStacks(t, host1Stack, routerStack, host2Stack)
epsAndAddrs := test.epAndAddrs(t, host1Stack, routerStack, host2Stack, subTest.proto)
defer epsAndAddrs.serverEP.Close()
@@ -415,8 +420,11 @@ func TestForwarding(t *testing.T) {
t.Fatalf("epsAndAddrs.clientEP.Bind(%#v): %s", clientAddr, err)
}
- if err := epsAndAddrs.clientEP.Connect(serverAddr); err != subTest.expectedConnectErr {
- t.Fatalf("got epsAndAddrs.clientEP.Connect(%#v) = %s, want = %s", serverAddr, err, subTest.expectedConnectErr)
+ {
+ err := epsAndAddrs.clientEP.Connect(serverAddr)
+ if diff := cmp.Diff(subTest.expectedConnectErr, err); diff != "" {
+ t.Fatalf("unexpected error from epsAndAddrs.clientEP.Connect(%#v), (-want, +got):\n%s", serverAddr, diff)
+ }
}
if addr, err := epsAndAddrs.clientEP.GetLocalAddress(); err != nil {
t.Fatalf("epsAndAddrs.clientEP.GetLocalAddress(): %s", err)
@@ -436,9 +444,10 @@ func TestForwarding(t *testing.T) {
write := func(ep tcpip.Endpoint, data []byte) {
t.Helper()
- dataPayload := tcpip.SlicePayload(data)
+ var r bytes.Reader
+ r.Reset(data)
var wOpts tcpip.WriteOptions
- n, err := ep.Write(dataPayload, wOpts)
+ n, err := ep.Write(&r, wOpts)
if err != nil {
t.Fatalf("ep.Write(_, %#v): %s", wOpts, err)
}
@@ -486,7 +495,7 @@ func TestForwarding(t *testing.T) {
read(serverCH, serverEP, data, clientAddr)
- data = tcpip.SlicePayload([]byte{5, 6, 7, 8, 9, 10, 11, 12})
+ data = []byte{5, 6, 7, 8, 9, 10, 11, 12}
write(serverEP, data)
read(epsAndAddrs.clientReadableCH, epsAndAddrs.clientEP, data, serverAddr)
})
diff --git a/pkg/tcpip/tests/integration/iptables_test.go b/pkg/tcpip/tests/integration/iptables_test.go
new file mode 100644
index 000000000..21a8dd291
--- /dev/null
+++ b/pkg/tcpip/tests/integration/iptables_test.go
@@ -0,0 +1,336 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package integration_test
+
+import (
+ "testing"
+
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/link/channel"
+ "gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
+ "gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
+)
+
+type inputIfNameMatcher struct {
+ name string
+}
+
+var _ stack.Matcher = (*inputIfNameMatcher)(nil)
+
+func (*inputIfNameMatcher) Name() string {
+ return "inputIfNameMatcher"
+}
+
+func (im *inputIfNameMatcher) Match(hook stack.Hook, _ *stack.PacketBuffer, inNicName, _ string) (bool, bool) {
+ return (hook == stack.Input && im.name != "" && im.name == inNicName), false
+}
+
+const (
+ nicID = 1
+ nicName = "nic1"
+ anotherNicName = "nic2"
+ linkAddr = tcpip.LinkAddress("\x0a\x0b\x0c\x0d\x0e\x0e")
+ srcAddrV4 = "\x0a\x00\x00\x01"
+ dstAddrV4 = "\x0a\x00\x00\x02"
+ srcAddrV6 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01"
+ dstAddrV6 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02"
+ payloadSize = 20
+)
+
+func genStackV6(t *testing.T) (*stack.Stack, *channel.Endpoint) {
+ t.Helper()
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol},
+ })
+ e := channel.New(0, header.IPv6MinimumMTU, linkAddr)
+ nicOpts := stack.NICOptions{Name: nicName}
+ if err := s.CreateNICWithOptions(nicID, e, nicOpts); err != nil {
+ t.Fatalf("CreateNICWithOptions(%d, _, %#v) = %s", nicID, nicOpts, err)
+ }
+ if err := s.AddAddress(nicID, header.IPv6ProtocolNumber, dstAddrV6); err != nil {
+ t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, dstAddrV6, err)
+ }
+ return s, e
+}
+
+func genStackV4(t *testing.T) (*stack.Stack, *channel.Endpoint) {
+ t.Helper()
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol},
+ })
+ e := channel.New(0, header.IPv4MinimumMTU, linkAddr)
+ nicOpts := stack.NICOptions{Name: nicName}
+ if err := s.CreateNICWithOptions(nicID, e, nicOpts); err != nil {
+ t.Fatalf("CreateNICWithOptions(%d, _, %#v) = %s", nicID, nicOpts, err)
+ }
+ if err := s.AddAddress(nicID, header.IPv4ProtocolNumber, dstAddrV4); err != nil {
+ t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv4ProtocolNumber, dstAddrV4, err)
+ }
+ return s, e
+}
+
+func genPacketV6() *stack.PacketBuffer {
+ pktSize := header.IPv6MinimumSize + payloadSize
+ hdr := buffer.NewPrependable(pktSize)
+ ip := header.IPv6(hdr.Prepend(pktSize))
+ ip.Encode(&header.IPv6Fields{
+ PayloadLength: payloadSize,
+ TransportProtocol: 99,
+ HopLimit: 255,
+ SrcAddr: srcAddrV6,
+ DstAddr: dstAddrV6,
+ })
+ vv := hdr.View().ToVectorisedView()
+ return stack.NewPacketBuffer(stack.PacketBufferOptions{Data: vv})
+}
+
+func genPacketV4() *stack.PacketBuffer {
+ pktSize := header.IPv4MinimumSize + payloadSize
+ hdr := buffer.NewPrependable(pktSize)
+ ip := header.IPv4(hdr.Prepend(pktSize))
+ ip.Encode(&header.IPv4Fields{
+ TOS: 0,
+ TotalLength: uint16(pktSize),
+ ID: 1,
+ Flags: 0,
+ FragmentOffset: 16,
+ TTL: 48,
+ Protocol: 99,
+ SrcAddr: srcAddrV4,
+ DstAddr: dstAddrV4,
+ })
+ ip.SetChecksum(0)
+ ip.SetChecksum(^ip.CalculateChecksum())
+ vv := hdr.View().ToVectorisedView()
+ return stack.NewPacketBuffer(stack.PacketBufferOptions{Data: vv})
+}
+
+func TestIPTablesStatsForInput(t *testing.T) {
+ tests := []struct {
+ name string
+ setupStack func(*testing.T) (*stack.Stack, *channel.Endpoint)
+ setupFilter func(*testing.T, *stack.Stack)
+ genPacket func() *stack.PacketBuffer
+ proto tcpip.NetworkProtocolNumber
+ expectReceived int
+ expectInputDropped int
+ }{
+ {
+ name: "IPv6 Accept",
+ setupStack: genStackV6,
+ setupFilter: func(*testing.T, *stack.Stack) { /* no filter */ },
+ genPacket: genPacketV6,
+ proto: header.IPv6ProtocolNumber,
+ expectReceived: 1,
+ expectInputDropped: 0,
+ },
+ {
+ name: "IPv4 Accept",
+ setupStack: genStackV4,
+ setupFilter: func(*testing.T, *stack.Stack) { /* no filter */ },
+ genPacket: genPacketV4,
+ proto: header.IPv4ProtocolNumber,
+ expectReceived: 1,
+ expectInputDropped: 0,
+ },
+ {
+ name: "IPv6 Drop (input interface matches)",
+ setupStack: genStackV6,
+ setupFilter: func(t *testing.T, s *stack.Stack) {
+ t.Helper()
+ ipt := s.IPTables()
+ filter := ipt.GetTable(stack.FilterID, true /* ipv6 */)
+ ruleIdx := filter.BuiltinChains[stack.Input]
+ filter.Rules[ruleIdx].Filter = stack.IPHeaderFilter{InputInterface: nicName}
+ filter.Rules[ruleIdx].Target = &stack.DropTarget{}
+ filter.Rules[ruleIdx].Matchers = []stack.Matcher{&inputIfNameMatcher{nicName}}
+ // Make sure the packet is not dropped by the next rule.
+ filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{}
+ if err := ipt.ReplaceTable(stack.FilterID, filter, true /* ipv6 */); err != nil {
+ t.Fatalf("ipt.RelaceTable(%d, _, %t): %s", stack.FilterID, true, err)
+ }
+ },
+ genPacket: genPacketV6,
+ proto: header.IPv6ProtocolNumber,
+ expectReceived: 1,
+ expectInputDropped: 1,
+ },
+ {
+ name: "IPv4 Drop (input interface matches)",
+ setupStack: genStackV4,
+ setupFilter: func(t *testing.T, s *stack.Stack) {
+ t.Helper()
+ ipt := s.IPTables()
+ filter := ipt.GetTable(stack.FilterID, false /* ipv6 */)
+ ruleIdx := filter.BuiltinChains[stack.Input]
+ filter.Rules[ruleIdx].Filter = stack.IPHeaderFilter{InputInterface: nicName}
+ filter.Rules[ruleIdx].Target = &stack.DropTarget{}
+ filter.Rules[ruleIdx].Matchers = []stack.Matcher{&inputIfNameMatcher{nicName}}
+ filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{}
+ if err := ipt.ReplaceTable(stack.FilterID, filter, false /* ipv6 */); err != nil {
+ t.Fatalf("ipt.RelaceTable(%d, _, %t): %s", stack.FilterID, false, err)
+ }
+ },
+ genPacket: genPacketV4,
+ proto: header.IPv4ProtocolNumber,
+ expectReceived: 1,
+ expectInputDropped: 1,
+ },
+ {
+ name: "IPv6 Accept (input interface does not match)",
+ setupStack: genStackV6,
+ setupFilter: func(t *testing.T, s *stack.Stack) {
+ t.Helper()
+ ipt := s.IPTables()
+ filter := ipt.GetTable(stack.FilterID, true /* ipv6 */)
+ ruleIdx := filter.BuiltinChains[stack.Input]
+ filter.Rules[ruleIdx].Filter = stack.IPHeaderFilter{InputInterface: anotherNicName}
+ filter.Rules[ruleIdx].Target = &stack.DropTarget{}
+ filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{}
+ if err := ipt.ReplaceTable(stack.FilterID, filter, true /* ipv6 */); err != nil {
+ t.Fatalf("ipt.RelaceTable(%d, _, %t): %s", stack.FilterID, true, err)
+ }
+ },
+ genPacket: genPacketV6,
+ proto: header.IPv6ProtocolNumber,
+ expectReceived: 1,
+ expectInputDropped: 0,
+ },
+ {
+ name: "IPv4 Accept (input interface does not match)",
+ setupStack: genStackV4,
+ setupFilter: func(t *testing.T, s *stack.Stack) {
+ t.Helper()
+ ipt := s.IPTables()
+ filter := ipt.GetTable(stack.FilterID, false /* ipv6 */)
+ ruleIdx := filter.BuiltinChains[stack.Input]
+ filter.Rules[ruleIdx].Filter = stack.IPHeaderFilter{InputInterface: anotherNicName}
+ filter.Rules[ruleIdx].Target = &stack.DropTarget{}
+ filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{}
+ if err := ipt.ReplaceTable(stack.FilterID, filter, false /* ipv6 */); err != nil {
+ t.Fatalf("ipt.RelaceTable(%d, _, %t): %s", stack.FilterID, false, err)
+ }
+ },
+ genPacket: genPacketV4,
+ proto: header.IPv4ProtocolNumber,
+ expectReceived: 1,
+ expectInputDropped: 0,
+ },
+ {
+ name: "IPv6 Drop (input interface does not match but invert is true)",
+ setupStack: genStackV6,
+ setupFilter: func(t *testing.T, s *stack.Stack) {
+ t.Helper()
+ ipt := s.IPTables()
+ filter := ipt.GetTable(stack.FilterID, true /* ipv6 */)
+ ruleIdx := filter.BuiltinChains[stack.Input]
+ filter.Rules[ruleIdx].Filter = stack.IPHeaderFilter{
+ InputInterface: anotherNicName,
+ InputInterfaceInvert: true,
+ }
+ filter.Rules[ruleIdx].Target = &stack.DropTarget{}
+ filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{}
+ if err := ipt.ReplaceTable(stack.FilterID, filter, true /* ipv6 */); err != nil {
+ t.Fatalf("ipt.RelaceTable(%d, _, %t): %s", stack.FilterID, true, err)
+ }
+ },
+ genPacket: genPacketV6,
+ proto: header.IPv6ProtocolNumber,
+ expectReceived: 1,
+ expectInputDropped: 1,
+ },
+ {
+ name: "IPv4 Drop (input interface does not match but invert is true)",
+ setupStack: genStackV4,
+ setupFilter: func(t *testing.T, s *stack.Stack) {
+ t.Helper()
+ ipt := s.IPTables()
+ filter := ipt.GetTable(stack.FilterID, false /* ipv6 */)
+ ruleIdx := filter.BuiltinChains[stack.Input]
+ filter.Rules[ruleIdx].Filter = stack.IPHeaderFilter{
+ InputInterface: anotherNicName,
+ InputInterfaceInvert: true,
+ }
+ filter.Rules[ruleIdx].Target = &stack.DropTarget{}
+ filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{}
+ if err := ipt.ReplaceTable(stack.FilterID, filter, false /* ipv6 */); err != nil {
+ t.Fatalf("ipt.RelaceTable(%d, _, %t): %s", stack.FilterID, false, err)
+ }
+ },
+ genPacket: genPacketV4,
+ proto: header.IPv4ProtocolNumber,
+ expectReceived: 1,
+ expectInputDropped: 1,
+ },
+ {
+ name: "IPv6 Accept (input interface does not match using a matcher)",
+ setupStack: genStackV6,
+ setupFilter: func(t *testing.T, s *stack.Stack) {
+ t.Helper()
+ ipt := s.IPTables()
+ filter := ipt.GetTable(stack.FilterID, true /* ipv6 */)
+ ruleIdx := filter.BuiltinChains[stack.Input]
+ filter.Rules[ruleIdx].Target = &stack.DropTarget{}
+ filter.Rules[ruleIdx].Matchers = []stack.Matcher{&inputIfNameMatcher{anotherNicName}}
+ filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{}
+ if err := ipt.ReplaceTable(stack.FilterID, filter, true /* ipv6 */); err != nil {
+ t.Fatalf("ipt.RelaceTable(%d, _, %t): %s", stack.FilterID, true, err)
+ }
+ },
+ genPacket: genPacketV6,
+ proto: header.IPv6ProtocolNumber,
+ expectReceived: 1,
+ expectInputDropped: 0,
+ },
+ {
+ name: "IPv4 Accept (input interface does not match using a matcher)",
+ setupStack: genStackV4,
+ setupFilter: func(t *testing.T, s *stack.Stack) {
+ t.Helper()
+ ipt := s.IPTables()
+ filter := ipt.GetTable(stack.FilterID, false /* ipv6 */)
+ ruleIdx := filter.BuiltinChains[stack.Input]
+ filter.Rules[ruleIdx].Target = &stack.DropTarget{}
+ filter.Rules[ruleIdx].Matchers = []stack.Matcher{&inputIfNameMatcher{anotherNicName}}
+ filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{}
+ if err := ipt.ReplaceTable(stack.FilterID, filter, false /* ipv6 */); err != nil {
+ t.Fatalf("ipt.RelaceTable(%d, _, %t): %s", stack.FilterID, false, err)
+ }
+ },
+ genPacket: genPacketV4,
+ proto: header.IPv4ProtocolNumber,
+ expectReceived: 1,
+ expectInputDropped: 0,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ s, e := test.setupStack(t)
+ test.setupFilter(t, s)
+ e.InjectInbound(test.proto, test.genPacket())
+
+ if got := int(s.Stats().IP.PacketsReceived.Value()); got != test.expectReceived {
+ t.Errorf("got PacketReceived = %d, want = %d", got, test.expectReceived)
+ }
+ if got := int(s.Stats().IP.IPTablesInputDropped.Value()); got != test.expectInputDropped {
+ t.Errorf("got IPTablesInputDropped = %d, want = %d", got, test.expectInputDropped)
+ }
+ })
+ }
+}
diff --git a/pkg/tcpip/tests/integration/link_resolution_test.go b/pkg/tcpip/tests/integration/link_resolution_test.go
index af32d3009..7069352f2 100644
--- a/pkg/tcpip/tests/integration/link_resolution_test.go
+++ b/pkg/tcpip/tests/integration/link_resolution_test.go
@@ -19,11 +19,14 @@ import (
"fmt"
"net"
"testing"
+ "time"
"github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
"gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/checker"
+ "gvisor.dev/gvisor/pkg/tcpip/faketime"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/link/pipe"
"gvisor.dev/gvisor/pkg/tcpip/network/arp"
@@ -32,6 +35,7 @@ import (
"gvisor.dev/gvisor/pkg/tcpip/stack"
"gvisor.dev/gvisor/pkg/tcpip/transport/icmp"
"gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
+ "gvisor.dev/gvisor/pkg/tcpip/transport/udp"
"gvisor.dev/gvisor/pkg/waiter"
)
@@ -207,8 +211,10 @@ func TestPing(t *testing.T) {
defer ep.Close()
icmpBuf := test.icmpBuf(t)
+ var r bytes.Reader
+ r.Reset(icmpBuf)
wOpts := tcpip.WriteOptions{To: &tcpip.FullAddress{Addr: test.remoteAddr}}
- if n, err := ep.Write(tcpip.SlicePayload(icmpBuf), wOpts); err != nil {
+ if n, err := ep.Write(&r, wOpts); err != nil {
t.Fatalf("ep.Write(_, _): %s", err)
} else if want := int64(len(icmpBuf)); n != want {
t.Fatalf("got ep.Write(_, _) = (%d, _), want = (%d, _)", n, want)
@@ -251,7 +257,7 @@ func TestTCPLinkResolutionFailure(t *testing.T) {
name string
netProto tcpip.NetworkProtocolNumber
remoteAddr tcpip.Address
- expectedWriteErr *tcpip.Error
+ expectedWriteErr tcpip.Error
sockError tcpip.SockError
}{
{
@@ -270,9 +276,9 @@ func TestTCPLinkResolutionFailure(t *testing.T) {
name: "IPv4 without resolvable remote",
netProto: ipv4.ProtocolNumber,
remoteAddr: ipv4Addr3.AddressWithPrefix.Address,
- expectedWriteErr: tcpip.ErrNoRoute,
+ expectedWriteErr: &tcpip.ErrNoRoute{},
sockError: tcpip.SockError{
- Err: tcpip.ErrNoRoute,
+ Err: &tcpip.ErrNoRoute{},
ErrType: byte(header.ICMPv4DstUnreachable),
ErrCode: byte(header.ICMPv4HostUnreachable),
ErrOrigin: tcpip.SockExtErrorOriginICMP,
@@ -292,9 +298,9 @@ func TestTCPLinkResolutionFailure(t *testing.T) {
name: "IPv6 without resolvable remote",
netProto: ipv6.ProtocolNumber,
remoteAddr: ipv6Addr3.AddressWithPrefix.Address,
- expectedWriteErr: tcpip.ErrNoRoute,
+ expectedWriteErr: &tcpip.ErrNoRoute{},
sockError: tcpip.SockError{
- Err: tcpip.ErrNoRoute,
+ Err: &tcpip.ErrNoRoute{},
ErrType: byte(header.ICMPv6DstUnreachable),
ErrCode: byte(header.ICMPv6AddressUnreachable),
ErrOrigin: tcpip.SockExtErrorOriginICMP6,
@@ -351,16 +357,24 @@ func TestTCPLinkResolutionFailure(t *testing.T) {
remoteAddr := listenerAddr
remoteAddr.Addr = test.remoteAddr
- if err := clientEP.Connect(remoteAddr); err != tcpip.ErrConnectStarted {
- t.Fatalf("got clientEP.Connect(%#v) = %s, want = %s", remoteAddr, err, tcpip.ErrConnectStarted)
+ {
+ err := clientEP.Connect(remoteAddr)
+ if _, ok := err.(*tcpip.ErrConnectStarted); !ok {
+ t.Fatalf("got clientEP.Connect(%#v) = %s, want = %s", remoteAddr, err, &tcpip.ErrConnectStarted{})
+ }
}
// Wait for an error due to link resolution failing, or the endpoint to be
// writable.
<-ch
- var wOpts tcpip.WriteOptions
- if n, err := clientEP.Write(tcpip.SlicePayload(nil), wOpts); err != test.expectedWriteErr {
- t.Errorf("got clientEP.Write(nil, %#v) = (%d, %s), want = (_, %s)", wOpts, n, err, test.expectedWriteErr)
+ {
+ var r bytes.Reader
+ r.Reset([]byte{0})
+ var wOpts tcpip.WriteOptions
+ _, err := clientEP.Write(&r, wOpts)
+ if diff := cmp.Diff(test.expectedWriteErr, err); diff != "" {
+ t.Errorf("unexpected error from clientEP.Write(_, %#v), (-want, +got):\n%s", wOpts, diff)
+ }
}
if test.expectedWriteErr == nil {
@@ -374,7 +388,7 @@ func TestTCPLinkResolutionFailure(t *testing.T) {
sockErrCmpOpts := []cmp.Option{
cmpopts.IgnoreUnexported(tcpip.SockError{}),
- cmp.Comparer(func(a, b *tcpip.Error) bool {
+ cmp.Comparer(func(a, b tcpip.Error) bool {
// tcpip.Error holds an unexported field but the errors netstack uses
// are pre defined so we can simply compare pointers.
return a == b
@@ -404,20 +418,134 @@ func TestGetLinkAddress(t *testing.T) {
)
tests := []struct {
- name string
- netProto tcpip.NetworkProtocolNumber
- remoteAddr tcpip.Address
- expectedLinkAddr bool
+ name string
+ netProto tcpip.NetworkProtocolNumber
+ remoteAddr tcpip.Address
+ expectedOk bool
}{
{
- name: "IPv4",
+ name: "IPv4 resolvable",
netProto: ipv4.ProtocolNumber,
remoteAddr: ipv4Addr2.AddressWithPrefix.Address,
+ expectedOk: true,
},
{
- name: "IPv6",
+ name: "IPv6 resolvable",
netProto: ipv6.ProtocolNumber,
remoteAddr: ipv6Addr2.AddressWithPrefix.Address,
+ expectedOk: true,
+ },
+ {
+ name: "IPv4 not resolvable",
+ netProto: ipv4.ProtocolNumber,
+ remoteAddr: ipv4Addr3.AddressWithPrefix.Address,
+ expectedOk: false,
+ },
+ {
+ name: "IPv6 not resolvable",
+ netProto: ipv6.ProtocolNumber,
+ remoteAddr: ipv6Addr3.AddressWithPrefix.Address,
+ expectedOk: false,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ for _, useNeighborCache := range []bool{true, false} {
+ t.Run(fmt.Sprintf("UseNeighborCache=%t", useNeighborCache), func(t *testing.T) {
+ stackOpts := stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{arp.NewProtocol, ipv4.NewProtocol, ipv6.NewProtocol},
+ UseNeighborCache: useNeighborCache,
+ }
+
+ host1Stack, _ := setupStack(t, stackOpts, host1NICID, host2NICID)
+
+ ch := make(chan stack.LinkResolutionResult, 1)
+ err := host1Stack.GetLinkAddress(host1NICID, test.remoteAddr, "", test.netProto, func(r stack.LinkResolutionResult) {
+ ch <- r
+ })
+ if _, ok := err.(*tcpip.ErrWouldBlock); !ok {
+ t.Fatalf("got host1Stack.GetLinkAddress(%d, %s, '', %d, _) = %s, want = %s", host1NICID, test.remoteAddr, test.netProto, err, &tcpip.ErrWouldBlock{})
+ }
+ wantRes := stack.LinkResolutionResult{Success: test.expectedOk}
+ if test.expectedOk {
+ wantRes.LinkAddress = linkAddr2
+ }
+ if diff := cmp.Diff(wantRes, <-ch); diff != "" {
+ t.Fatalf("link resolution result mismatch (-want +got):\n%s", diff)
+ }
+ })
+ }
+ })
+ }
+}
+
+func TestRouteResolvedFields(t *testing.T) {
+ const (
+ host1NICID = 1
+ host2NICID = 4
+ )
+
+ tests := []struct {
+ name string
+ netProto tcpip.NetworkProtocolNumber
+ localAddr tcpip.Address
+ remoteAddr tcpip.Address
+ immediatelyResolvable bool
+ expectedSuccess bool
+ expectedLinkAddr tcpip.LinkAddress
+ }{
+ {
+ name: "IPv4 immediately resolvable",
+ netProto: ipv4.ProtocolNumber,
+ localAddr: ipv4Addr1.AddressWithPrefix.Address,
+ remoteAddr: header.IPv4AllSystems,
+ immediatelyResolvable: true,
+ expectedSuccess: true,
+ expectedLinkAddr: header.EthernetAddressFromMulticastIPv4Address(header.IPv4AllSystems),
+ },
+ {
+ name: "IPv6 immediately resolvable",
+ netProto: ipv6.ProtocolNumber,
+ localAddr: ipv6Addr1.AddressWithPrefix.Address,
+ remoteAddr: header.IPv6AllNodesMulticastAddress,
+ immediatelyResolvable: true,
+ expectedSuccess: true,
+ expectedLinkAddr: header.EthernetAddressFromMulticastIPv6Address(header.IPv6AllNodesMulticastAddress),
+ },
+ {
+ name: "IPv4 resolvable",
+ netProto: ipv4.ProtocolNumber,
+ localAddr: ipv4Addr1.AddressWithPrefix.Address,
+ remoteAddr: ipv4Addr2.AddressWithPrefix.Address,
+ immediatelyResolvable: false,
+ expectedSuccess: true,
+ expectedLinkAddr: linkAddr2,
+ },
+ {
+ name: "IPv6 resolvable",
+ netProto: ipv6.ProtocolNumber,
+ localAddr: ipv6Addr1.AddressWithPrefix.Address,
+ remoteAddr: ipv6Addr2.AddressWithPrefix.Address,
+ immediatelyResolvable: false,
+ expectedSuccess: true,
+ expectedLinkAddr: linkAddr2,
+ },
+ {
+ name: "IPv4 not resolvable",
+ netProto: ipv4.ProtocolNumber,
+ localAddr: ipv4Addr1.AddressWithPrefix.Address,
+ remoteAddr: ipv4Addr3.AddressWithPrefix.Address,
+ immediatelyResolvable: false,
+ expectedSuccess: false,
+ },
+ {
+ name: "IPv6 not resolvable",
+ netProto: ipv6.ProtocolNumber,
+ localAddr: ipv6Addr1.AddressWithPrefix.Address,
+ remoteAddr: ipv6Addr3.AddressWithPrefix.Address,
+ immediatelyResolvable: false,
+ expectedSuccess: false,
},
}
@@ -431,28 +559,618 @@ func TestGetLinkAddress(t *testing.T) {
}
host1Stack, _ := setupStack(t, stackOpts, host1NICID, host2NICID)
+ r, err := host1Stack.FindRoute(host1NICID, "", test.remoteAddr, test.netProto, false /* multicastLoop */)
+ if err != nil {
+ t.Fatalf("host1Stack.FindRoute(%d, '', %s, %d, false): %s", host1NICID, test.remoteAddr, test.netProto, err)
+ }
+ defer r.Release()
+
+ var wantRouteInfo stack.RouteInfo
+ wantRouteInfo.LocalLinkAddress = linkAddr1
+ wantRouteInfo.LocalAddress = test.localAddr
+ wantRouteInfo.RemoteAddress = test.remoteAddr
+ wantRouteInfo.NetProto = test.netProto
+ wantRouteInfo.Loop = stack.PacketOut
+ wantRouteInfo.RemoteLinkAddress = test.expectedLinkAddr
+
+ ch := make(chan stack.ResolvedFieldsResult, 1)
+
+ if !test.immediatelyResolvable {
+ wantUnresolvedRouteInfo := wantRouteInfo
+ wantUnresolvedRouteInfo.RemoteLinkAddress = ""
- for i := 0; i < 2; i++ {
- addr, ch, err := host1Stack.GetLinkAddress(host1NICID, test.remoteAddr, "", test.netProto, func(tcpip.LinkAddress, bool) {})
- var want *tcpip.Error
- if i == 0 {
- want = tcpip.ErrWouldBlock
+ err := r.ResolvedFields(func(r stack.ResolvedFieldsResult) {
+ ch <- r
+ })
+ if _, ok := err.(*tcpip.ErrWouldBlock); !ok {
+ t.Errorf("got r.ResolvedFields(_) = %s, want = %s", err, &tcpip.ErrWouldBlock{})
}
- if err != want {
- t.Fatalf("got host1Stack.GetLinkAddress(%d, %s, '', %d, _) = (%s, _, %s), want = (_, _, %s)", host1NICID, test.remoteAddr, test.netProto, addr, err, want)
+ if diff := cmp.Diff(stack.ResolvedFieldsResult{RouteInfo: wantRouteInfo, Success: test.expectedSuccess}, <-ch, cmp.AllowUnexported(stack.RouteInfo{})); diff != "" {
+ t.Errorf("route resolve result mismatch (-want +got):\n%s", diff)
}
- if i == 0 {
- <-ch
- continue
+ if !test.expectedSuccess {
+ return
}
- if addr != linkAddr2 {
- t.Fatalf("got addr = %s, want = %s", addr, linkAddr2)
+ // At this point the neighbor table should be populated so the route
+ // should be immediately resolvable.
+ }
+
+ if err := r.ResolvedFields(func(r stack.ResolvedFieldsResult) {
+ ch <- r
+ }); err != nil {
+ t.Errorf("r.ResolvedFields(_): %s", err)
+ }
+ select {
+ case routeResolveRes := <-ch:
+ if diff := cmp.Diff(stack.ResolvedFieldsResult{RouteInfo: wantRouteInfo, Success: true}, routeResolveRes, cmp.AllowUnexported(stack.RouteInfo{})); diff != "" {
+ t.Errorf("route resolve result from resolved route mismatch (-want +got):\n%s", diff)
}
+ default:
+ t.Fatal("expected route to be immediately resolvable")
}
})
}
})
}
}
+
+func TestWritePacketsLinkResolution(t *testing.T) {
+ const (
+ host1NICID = 1
+ host2NICID = 4
+ )
+
+ tests := []struct {
+ name string
+ netProto tcpip.NetworkProtocolNumber
+ remoteAddr tcpip.Address
+ expectedWriteErr tcpip.Error
+ }{
+ {
+ name: "IPv4",
+ netProto: ipv4.ProtocolNumber,
+ remoteAddr: ipv4Addr2.AddressWithPrefix.Address,
+ expectedWriteErr: nil,
+ },
+ {
+ name: "IPv6",
+ netProto: ipv6.ProtocolNumber,
+ remoteAddr: ipv6Addr2.AddressWithPrefix.Address,
+ expectedWriteErr: nil,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ stackOpts := stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{arp.NewProtocol, ipv4.NewProtocol, ipv6.NewProtocol},
+ TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol},
+ }
+
+ host1Stack, host2Stack := setupStack(t, stackOpts, host1NICID, host2NICID)
+
+ var serverWQ waiter.Queue
+ serverWE, serverCH := waiter.NewChannelEntry(nil)
+ serverWQ.EventRegister(&serverWE, waiter.EventIn)
+ serverEP, err := host2Stack.NewEndpoint(udp.ProtocolNumber, test.netProto, &serverWQ)
+ if err != nil {
+ t.Fatalf("host2Stack.NewEndpoint(%d, %d, _): %s", udp.ProtocolNumber, test.netProto, err)
+ }
+ defer serverEP.Close()
+
+ serverAddr := tcpip.FullAddress{Port: 1234}
+ if err := serverEP.Bind(serverAddr); err != nil {
+ t.Fatalf("serverEP.Bind(%#v): %s", serverAddr, err)
+ }
+
+ r, err := host1Stack.FindRoute(host1NICID, "", test.remoteAddr, test.netProto, false /* multicastLoop */)
+ if err != nil {
+ t.Fatalf("host1Stack.FindRoute(%d, '', %s, %d, false): %s", host1NICID, test.remoteAddr, test.netProto, err)
+ }
+ defer r.Release()
+
+ data := []byte{1, 2}
+ var pkts stack.PacketBufferList
+ for _, d := range data {
+ pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
+ ReserveHeaderBytes: header.UDPMinimumSize + int(r.MaxHeaderLength()),
+ Data: buffer.View([]byte{d}).ToVectorisedView(),
+ })
+ pkt.TransportProtocolNumber = udp.ProtocolNumber
+ length := uint16(pkt.Size())
+ udpHdr := header.UDP(pkt.TransportHeader().Push(header.UDPMinimumSize))
+ udpHdr.Encode(&header.UDPFields{
+ SrcPort: 5555,
+ DstPort: serverAddr.Port,
+ Length: length,
+ })
+ xsum := r.PseudoHeaderChecksum(udp.ProtocolNumber, length)
+ for _, v := range pkt.Data.Views() {
+ xsum = header.Checksum(v, xsum)
+ }
+ udpHdr.SetChecksum(^udpHdr.CalculateChecksum(xsum))
+
+ pkts.PushBack(pkt)
+ }
+
+ params := stack.NetworkHeaderParams{
+ Protocol: udp.ProtocolNumber,
+ TTL: 64,
+ TOS: stack.DefaultTOS,
+ }
+
+ if n, err := r.WritePackets(nil /* gso */, pkts, params); err != nil {
+ t.Fatalf("r.WritePackets(nil, %#v, _): %s", params, err)
+ } else if want := pkts.Len(); want != n {
+ t.Fatalf("got r.WritePackets(nil, %#v, _) = %d, want = %d", n, params, want)
+ }
+
+ var writer bytes.Buffer
+ count := 0
+ for {
+ var rOpts tcpip.ReadOptions
+ res, err := serverEP.Read(&writer, rOpts)
+ if err != nil {
+ if _, ok := err.(*tcpip.ErrWouldBlock); ok {
+ // Should not have anymore bytes to read after we read the sent
+ // number of bytes.
+ if count == len(data) {
+ break
+ }
+
+ <-serverCH
+ continue
+ }
+
+ t.Fatalf("serverEP.Read(_, %#v): %s", rOpts, err)
+ }
+ count += res.Count
+ }
+
+ if got, want := host2Stack.Stats().UDP.PacketsReceived.Value(), uint64(len(data)); got != want {
+ t.Errorf("got host2Stack.Stats().UDP.PacketsReceived.Value() = %d, want = %d", got, want)
+ }
+ if diff := cmp.Diff(data, writer.Bytes()); diff != "" {
+ t.Errorf("read bytes mismatch (-want +got):\n%s", diff)
+ }
+ })
+ }
+}
+
+type eventType int
+
+const (
+ entryAdded eventType = iota
+ entryChanged
+ entryRemoved
+)
+
+func (t eventType) String() string {
+ switch t {
+ case entryAdded:
+ return "add"
+ case entryChanged:
+ return "change"
+ case entryRemoved:
+ return "remove"
+ default:
+ return fmt.Sprintf("unknown (%d)", t)
+ }
+}
+
+type eventInfo struct {
+ eventType eventType
+ nicID tcpip.NICID
+ entry stack.NeighborEntry
+}
+
+func (e eventInfo) String() string {
+ return fmt.Sprintf("%s event for NIC #%d, %#v", e.eventType, e.nicID, e.entry)
+}
+
+var _ stack.NUDDispatcher = (*nudDispatcher)(nil)
+
+type nudDispatcher struct {
+ c chan eventInfo
+}
+
+func (d *nudDispatcher) OnNeighborAdded(nicID tcpip.NICID, entry stack.NeighborEntry) {
+ e := eventInfo{
+ eventType: entryAdded,
+ nicID: nicID,
+ entry: entry,
+ }
+ d.c <- e
+}
+
+func (d *nudDispatcher) OnNeighborChanged(nicID tcpip.NICID, entry stack.NeighborEntry) {
+ e := eventInfo{
+ eventType: entryChanged,
+ nicID: nicID,
+ entry: entry,
+ }
+ d.c <- e
+}
+
+func (d *nudDispatcher) OnNeighborRemoved(nicID tcpip.NICID, entry stack.NeighborEntry) {
+ e := eventInfo{
+ eventType: entryRemoved,
+ nicID: nicID,
+ entry: entry,
+ }
+ d.c <- e
+}
+
+func (d *nudDispatcher) waitForEvent(want eventInfo) error {
+ if diff := cmp.Diff(want, <-d.c, cmp.AllowUnexported(eventInfo{}), cmpopts.IgnoreFields(stack.NeighborEntry{}, "UpdatedAtNanos")); diff != "" {
+ return fmt.Errorf("got invalid event (-want +got):\n%s", diff)
+ }
+ return nil
+}
+
+// TestTCPConfirmNeighborReachability tests that TCP informs layers beneath it
+// that the neighbor used for a route is reachable.
+func TestTCPConfirmNeighborReachability(t *testing.T) {
+ tests := []struct {
+ name string
+ netProto tcpip.NetworkProtocolNumber
+ remoteAddr tcpip.Address
+ neighborAddr tcpip.Address
+ getEndpoints func(*testing.T, *stack.Stack, *stack.Stack, *stack.Stack) (tcpip.Endpoint, tcpip.Endpoint, <-chan struct{})
+ isHost1Listener bool
+ }{
+ {
+ name: "IPv4 active connection through neighbor",
+ netProto: ipv4.ProtocolNumber,
+ remoteAddr: host2IPv4Addr.AddressWithPrefix.Address,
+ neighborAddr: routerNIC1IPv4Addr.AddressWithPrefix.Address,
+ getEndpoints: func(t *testing.T, host1Stack, _, host2Stack *stack.Stack) (tcpip.Endpoint, tcpip.Endpoint, <-chan struct{}) {
+ var listenerWQ waiter.Queue
+ listenerEP, err := host2Stack.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &listenerWQ)
+ if err != nil {
+ t.Fatalf("host2Stack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv4.ProtocolNumber, err)
+ }
+
+ var clientWQ waiter.Queue
+ clientWE, clientCH := waiter.NewChannelEntry(nil)
+ clientWQ.EventRegister(&clientWE, waiter.EventOut)
+ clientEP, err := host1Stack.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &clientWQ)
+ if err != nil {
+ listenerEP.Close()
+ t.Fatalf("host1Stack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv4.ProtocolNumber, err)
+ }
+
+ return listenerEP, clientEP, clientCH
+ },
+ },
+ {
+ name: "IPv6 active connection through neighbor",
+ netProto: ipv6.ProtocolNumber,
+ remoteAddr: host2IPv6Addr.AddressWithPrefix.Address,
+ neighborAddr: routerNIC1IPv6Addr.AddressWithPrefix.Address,
+ getEndpoints: func(t *testing.T, host1Stack, _, host2Stack *stack.Stack) (tcpip.Endpoint, tcpip.Endpoint, <-chan struct{}) {
+ var listenerWQ waiter.Queue
+ listenerEP, err := host2Stack.NewEndpoint(tcp.ProtocolNumber, ipv6.ProtocolNumber, &listenerWQ)
+ if err != nil {
+ t.Fatalf("host2Stack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv6.ProtocolNumber, err)
+ }
+
+ var clientWQ waiter.Queue
+ clientWE, clientCH := waiter.NewChannelEntry(nil)
+ clientWQ.EventRegister(&clientWE, waiter.EventOut)
+ clientEP, err := host1Stack.NewEndpoint(tcp.ProtocolNumber, ipv6.ProtocolNumber, &clientWQ)
+ if err != nil {
+ listenerEP.Close()
+ t.Fatalf("host1Stack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv6.ProtocolNumber, err)
+ }
+
+ return listenerEP, clientEP, clientCH
+ },
+ },
+ {
+ name: "IPv4 active connection to neighbor",
+ netProto: ipv4.ProtocolNumber,
+ remoteAddr: routerNIC1IPv4Addr.AddressWithPrefix.Address,
+ neighborAddr: routerNIC1IPv4Addr.AddressWithPrefix.Address,
+ getEndpoints: func(t *testing.T, host1Stack, routerStack, _ *stack.Stack) (tcpip.Endpoint, tcpip.Endpoint, <-chan struct{}) {
+ var listenerWQ waiter.Queue
+ listenerEP, err := routerStack.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &listenerWQ)
+ if err != nil {
+ t.Fatalf("routerStack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv4.ProtocolNumber, err)
+ }
+
+ var clientWQ waiter.Queue
+ clientWE, clientCH := waiter.NewChannelEntry(nil)
+ clientWQ.EventRegister(&clientWE, waiter.EventOut)
+ clientEP, err := host1Stack.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &clientWQ)
+ if err != nil {
+ listenerEP.Close()
+ t.Fatalf("host1Stack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv4.ProtocolNumber, err)
+ }
+
+ return listenerEP, clientEP, clientCH
+ },
+ },
+ {
+ name: "IPv6 active connection to neighbor",
+ netProto: ipv6.ProtocolNumber,
+ remoteAddr: routerNIC1IPv6Addr.AddressWithPrefix.Address,
+ neighborAddr: routerNIC1IPv6Addr.AddressWithPrefix.Address,
+ getEndpoints: func(t *testing.T, host1Stack, routerStack, _ *stack.Stack) (tcpip.Endpoint, tcpip.Endpoint, <-chan struct{}) {
+ var listenerWQ waiter.Queue
+ listenerEP, err := routerStack.NewEndpoint(tcp.ProtocolNumber, ipv6.ProtocolNumber, &listenerWQ)
+ if err != nil {
+ t.Fatalf("routerStack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv6.ProtocolNumber, err)
+ }
+
+ var clientWQ waiter.Queue
+ clientWE, clientCH := waiter.NewChannelEntry(nil)
+ clientWQ.EventRegister(&clientWE, waiter.EventOut)
+ clientEP, err := host1Stack.NewEndpoint(tcp.ProtocolNumber, ipv6.ProtocolNumber, &clientWQ)
+ if err != nil {
+ listenerEP.Close()
+ t.Fatalf("host1Stack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv6.ProtocolNumber, err)
+ }
+
+ return listenerEP, clientEP, clientCH
+ },
+ },
+ {
+ name: "IPv4 passive connection to neighbor",
+ netProto: ipv4.ProtocolNumber,
+ remoteAddr: host1IPv4Addr.AddressWithPrefix.Address,
+ neighborAddr: routerNIC1IPv4Addr.AddressWithPrefix.Address,
+ getEndpoints: func(t *testing.T, host1Stack, routerStack, _ *stack.Stack) (tcpip.Endpoint, tcpip.Endpoint, <-chan struct{}) {
+ var listenerWQ waiter.Queue
+ listenerEP, err := host1Stack.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &listenerWQ)
+ if err != nil {
+ t.Fatalf("host1Stack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv4.ProtocolNumber, err)
+ }
+
+ var clientWQ waiter.Queue
+ clientWE, clientCH := waiter.NewChannelEntry(nil)
+ clientWQ.EventRegister(&clientWE, waiter.EventOut)
+ clientEP, err := routerStack.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &clientWQ)
+ if err != nil {
+ listenerEP.Close()
+ t.Fatalf("routerStack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv4.ProtocolNumber, err)
+ }
+
+ return listenerEP, clientEP, clientCH
+ },
+ isHost1Listener: true,
+ },
+ {
+ name: "IPv6 passive connection to neighbor",
+ netProto: ipv6.ProtocolNumber,
+ remoteAddr: host1IPv6Addr.AddressWithPrefix.Address,
+ neighborAddr: routerNIC1IPv6Addr.AddressWithPrefix.Address,
+ getEndpoints: func(t *testing.T, host1Stack, routerStack, _ *stack.Stack) (tcpip.Endpoint, tcpip.Endpoint, <-chan struct{}) {
+ var listenerWQ waiter.Queue
+ listenerEP, err := host1Stack.NewEndpoint(tcp.ProtocolNumber, ipv6.ProtocolNumber, &listenerWQ)
+ if err != nil {
+ t.Fatalf("host1Stack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv6.ProtocolNumber, err)
+ }
+
+ var clientWQ waiter.Queue
+ clientWE, clientCH := waiter.NewChannelEntry(nil)
+ clientWQ.EventRegister(&clientWE, waiter.EventOut)
+ clientEP, err := routerStack.NewEndpoint(tcp.ProtocolNumber, ipv6.ProtocolNumber, &clientWQ)
+ if err != nil {
+ listenerEP.Close()
+ t.Fatalf("routerStack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv6.ProtocolNumber, err)
+ }
+
+ return listenerEP, clientEP, clientCH
+ },
+ isHost1Listener: true,
+ },
+ {
+ name: "IPv4 passive connection through neighbor",
+ netProto: ipv4.ProtocolNumber,
+ remoteAddr: host1IPv4Addr.AddressWithPrefix.Address,
+ neighborAddr: routerNIC1IPv4Addr.AddressWithPrefix.Address,
+ getEndpoints: func(t *testing.T, host1Stack, _, host2Stack *stack.Stack) (tcpip.Endpoint, tcpip.Endpoint, <-chan struct{}) {
+ var listenerWQ waiter.Queue
+ listenerEP, err := host1Stack.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &listenerWQ)
+ if err != nil {
+ t.Fatalf("host1Stack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv4.ProtocolNumber, err)
+ }
+
+ var clientWQ waiter.Queue
+ clientWE, clientCH := waiter.NewChannelEntry(nil)
+ clientWQ.EventRegister(&clientWE, waiter.EventOut)
+ clientEP, err := host2Stack.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &clientWQ)
+ if err != nil {
+ listenerEP.Close()
+ t.Fatalf("host2Stack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv4.ProtocolNumber, err)
+ }
+
+ return listenerEP, clientEP, clientCH
+ },
+ isHost1Listener: true,
+ },
+ {
+ name: "IPv6 passive connection through neighbor",
+ netProto: ipv6.ProtocolNumber,
+ remoteAddr: host1IPv6Addr.AddressWithPrefix.Address,
+ neighborAddr: routerNIC1IPv6Addr.AddressWithPrefix.Address,
+ getEndpoints: func(t *testing.T, host1Stack, _, host2Stack *stack.Stack) (tcpip.Endpoint, tcpip.Endpoint, <-chan struct{}) {
+ var listenerWQ waiter.Queue
+ listenerEP, err := host1Stack.NewEndpoint(tcp.ProtocolNumber, ipv6.ProtocolNumber, &listenerWQ)
+ if err != nil {
+ t.Fatalf("host1Stack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv6.ProtocolNumber, err)
+ }
+
+ var clientWQ waiter.Queue
+ clientWE, clientCH := waiter.NewChannelEntry(nil)
+ clientWQ.EventRegister(&clientWE, waiter.EventOut)
+ clientEP, err := host2Stack.NewEndpoint(tcp.ProtocolNumber, ipv6.ProtocolNumber, &clientWQ)
+ if err != nil {
+ listenerEP.Close()
+ t.Fatalf("host2Stack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv6.ProtocolNumber, err)
+ }
+
+ return listenerEP, clientEP, clientCH
+ },
+ isHost1Listener: true,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ clock := faketime.NewManualClock()
+ nudDisp := nudDispatcher{
+ c: make(chan eventInfo, 3),
+ }
+ stackOpts := stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{arp.NewProtocol, ipv4.NewProtocol, ipv6.NewProtocol},
+ TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol},
+ Clock: clock,
+ UseNeighborCache: true,
+ }
+ host1StackOpts := stackOpts
+ host1StackOpts.NUDDisp = &nudDisp
+
+ host1Stack := stack.New(host1StackOpts)
+ routerStack := stack.New(stackOpts)
+ host2Stack := stack.New(stackOpts)
+ setupRoutedStacks(t, host1Stack, routerStack, host2Stack)
+
+ // Add a reachable dynamic entry to our neighbor table for the remote.
+ {
+ ch := make(chan stack.LinkResolutionResult, 1)
+ err := host1Stack.GetLinkAddress(host1NICID, test.neighborAddr, "", test.netProto, func(r stack.LinkResolutionResult) {
+ ch <- r
+ })
+ if _, ok := err.(*tcpip.ErrWouldBlock); !ok {
+ t.Fatalf("got host1Stack.GetLinkAddress(%d, %s, '', %d, _) = %s, want = %s", host1NICID, test.neighborAddr, test.netProto, err, &tcpip.ErrWouldBlock{})
+ }
+ if diff := cmp.Diff(stack.LinkResolutionResult{LinkAddress: linkAddr2, Success: true}, <-ch); diff != "" {
+ t.Fatalf("link resolution mismatch (-want +got):\n%s", diff)
+ }
+ }
+ if err := nudDisp.waitForEvent(eventInfo{
+ eventType: entryAdded,
+ nicID: host1NICID,
+ entry: stack.NeighborEntry{State: stack.Incomplete, Addr: test.neighborAddr},
+ }); err != nil {
+ t.Fatalf("error waiting for initial NUD event: %s", err)
+ }
+ if err := nudDisp.waitForEvent(eventInfo{
+ eventType: entryChanged,
+ nicID: host1NICID,
+ entry: stack.NeighborEntry{State: stack.Reachable, Addr: test.neighborAddr, LinkAddr: linkAddr2},
+ }); err != nil {
+ t.Fatalf("error waiting for reachable NUD event: %s", err)
+ }
+
+ // Wait for the remote's neighbor entry to be stale before creating a
+ // TCP connection from host1 to some remote.
+ nudConfigs, err := host1Stack.NUDConfigurations(host1NICID)
+ if err != nil {
+ t.Fatalf("host1Stack.NUDConfigurations(%d): %s", host1NICID, err)
+ }
+ // The maximum reachable time for a neighbor is some maximum random factor
+ // applied to the base reachable time.
+ //
+ // See NUDConfigurations.BaseReachableTime for more information.
+ maxReachableTime := time.Duration(float32(nudConfigs.BaseReachableTime) * nudConfigs.MaxRandomFactor)
+ clock.Advance(maxReachableTime)
+ if err := nudDisp.waitForEvent(eventInfo{
+ eventType: entryChanged,
+ nicID: host1NICID,
+ entry: stack.NeighborEntry{State: stack.Stale, Addr: test.neighborAddr, LinkAddr: linkAddr2},
+ }); err != nil {
+ t.Fatalf("error waiting for stale NUD event: %s", err)
+ }
+
+ listenerEP, clientEP, clientCH := test.getEndpoints(t, host1Stack, routerStack, host2Stack)
+ defer listenerEP.Close()
+ defer clientEP.Close()
+ listenerAddr := tcpip.FullAddress{Addr: test.remoteAddr, Port: 1234}
+ if err := listenerEP.Bind(listenerAddr); err != nil {
+ t.Fatalf("listenerEP.Bind(%#v): %s", listenerAddr, err)
+ }
+ if err := listenerEP.Listen(1); err != nil {
+ t.Fatalf("listenerEP.Listen(1): %s", err)
+ }
+ {
+ err := clientEP.Connect(listenerAddr)
+ if _, ok := err.(*tcpip.ErrConnectStarted); !ok {
+ t.Fatalf("got clientEP.Connect(%#v) = %s, want = %s", listenerAddr, err, &tcpip.ErrConnectStarted{})
+ }
+ }
+
+ // Wait for the TCP handshake to complete then make sure the neighbor is
+ // reachable without entering the probe state as TCP should provide NUD
+ // with confirmation that the neighbor is reachable (indicated by a
+ // successful 3-way handshake).
+ <-clientCH
+ if err := nudDisp.waitForEvent(eventInfo{
+ eventType: entryChanged,
+ nicID: host1NICID,
+ entry: stack.NeighborEntry{State: stack.Delay, Addr: test.neighborAddr, LinkAddr: linkAddr2},
+ }); err != nil {
+ t.Fatalf("error waiting for delay NUD event: %s", err)
+ }
+ if err := nudDisp.waitForEvent(eventInfo{
+ eventType: entryChanged,
+ nicID: host1NICID,
+ entry: stack.NeighborEntry{State: stack.Reachable, Addr: test.neighborAddr, LinkAddr: linkAddr2},
+ }); err != nil {
+ t.Fatalf("error waiting for reachable NUD event: %s", err)
+ }
+
+ // Wait for the neighbor to be stale again then send data to the remote.
+ //
+ // On successful transmission, the neighbor should become reachable
+ // without probing the neighbor as a TCP ACK would be received which is an
+ // indication of the neighbor being reachable.
+ clock.Advance(maxReachableTime)
+ if err := nudDisp.waitForEvent(eventInfo{
+ eventType: entryChanged,
+ nicID: host1NICID,
+ entry: stack.NeighborEntry{State: stack.Stale, Addr: test.neighborAddr, LinkAddr: linkAddr2},
+ }); err != nil {
+ t.Fatalf("error waiting for stale NUD event: %s", err)
+ }
+ var r bytes.Reader
+ r.Reset([]byte{0})
+ var wOpts tcpip.WriteOptions
+ if _, err := clientEP.Write(&r, wOpts); err != nil {
+ t.Errorf("clientEP.Write(_, %#v): %s", wOpts, err)
+ }
+ if err := nudDisp.waitForEvent(eventInfo{
+ eventType: entryChanged,
+ nicID: host1NICID,
+ entry: stack.NeighborEntry{State: stack.Delay, Addr: test.neighborAddr, LinkAddr: linkAddr2},
+ }); err != nil {
+ t.Fatalf("error waiting for delay NUD event: %s", err)
+ }
+ if test.isHost1Listener {
+ // If host1 is not the client, host1 does not send any data so TCP
+ // has no way to know it is making forward progress. Because of this,
+ // TCP should not mark the route reachable and NUD should go through the
+ // probe state.
+ clock.Advance(nudConfigs.DelayFirstProbeTime)
+ if err := nudDisp.waitForEvent(eventInfo{
+ eventType: entryChanged,
+ nicID: host1NICID,
+ entry: stack.NeighborEntry{State: stack.Probe, Addr: test.neighborAddr, LinkAddr: linkAddr2},
+ }); err != nil {
+ t.Fatalf("error waiting for probe NUD event: %s", err)
+ }
+ }
+ if err := nudDisp.waitForEvent(eventInfo{
+ eventType: entryChanged,
+ nicID: host1NICID,
+ entry: stack.NeighborEntry{State: stack.Reachable, Addr: test.neighborAddr, LinkAddr: linkAddr2},
+ }); err != nil {
+ t.Fatalf("error waiting for reachable NUD event: %s", err)
+ }
+ })
+ }
+}
diff --git a/pkg/tcpip/tests/integration/loopback_test.go b/pkg/tcpip/tests/integration/loopback_test.go
index 3b13ba04d..ab67762ef 100644
--- a/pkg/tcpip/tests/integration/loopback_test.go
+++ b/pkg/tcpip/tests/integration/loopback_test.go
@@ -37,7 +37,7 @@ var _ ipv6.NDPDispatcher = (*ndpDispatcher)(nil)
type ndpDispatcher struct{}
-func (*ndpDispatcher) OnDuplicateAddressDetectionStatus(tcpip.NICID, tcpip.Address, bool, *tcpip.Error) {
+func (*ndpDispatcher) OnDuplicateAddressDetectionStatus(tcpip.NICID, tcpip.Address, bool, tcpip.Error) {
}
func (*ndpDispatcher) OnDefaultRouterDiscovered(tcpip.NICID, tcpip.Address) bool {
@@ -232,7 +232,9 @@ func TestLoopbackAcceptAllInSubnetUDP(t *testing.T) {
Port: localPort,
},
}
- n, err := sep.Write(tcpip.SlicePayload(data), wopts)
+ var r bytes.Reader
+ r.Reset(data)
+ n, err := sep.Write(&r, wopts)
if err != nil {
t.Fatalf("sep.Write(_, _): %s", err)
}
@@ -260,8 +262,8 @@ func TestLoopbackAcceptAllInSubnetUDP(t *testing.T) {
if diff := cmp.Diff(data, buf.Bytes()); diff != "" {
t.Errorf("got UDP payload mismatch (-want +got):\n%s", diff)
}
- } else if err != tcpip.ErrWouldBlock {
- t.Fatalf("got rep.Read = (%v, %s) [with data %x], want = (_, %s)", res, err, buf.Bytes(), tcpip.ErrWouldBlock)
+ } else if _, ok := err.(*tcpip.ErrWouldBlock); !ok {
+ t.Fatalf("got rep.Read = (%v, %s) [with data %x], want = (_, %s)", res, err, buf.Bytes(), &tcpip.ErrWouldBlock{})
}
})
}
@@ -320,11 +322,14 @@ func TestLoopbackSubnetLifetimeBoundToAddr(t *testing.T) {
if err := s.RemoveAddress(nicID, protoAddr.AddressWithPrefix.Address); err != nil {
t.Fatalf("s.RemoveAddress(%d, %s): %s", nicID, protoAddr.AddressWithPrefix.Address, err)
}
- if err := r.WritePacket(nil /* gso */, params, stack.NewPacketBuffer(stack.PacketBufferOptions{
- ReserveHeaderBytes: int(r.MaxHeaderLength()),
- Data: data.ToVectorisedView(),
- })); err != tcpip.ErrInvalidEndpointState {
- t.Fatalf("got r.WritePacket(nil, %#v, _) = %s, want = %s", params, err, tcpip.ErrInvalidEndpointState)
+ {
+ err := r.WritePacket(nil /* gso */, params, stack.NewPacketBuffer(stack.PacketBufferOptions{
+ ReserveHeaderBytes: int(r.MaxHeaderLength()),
+ Data: data.ToVectorisedView(),
+ }))
+ if _, ok := err.(*tcpip.ErrInvalidEndpointState); !ok {
+ t.Fatalf("got r.WritePacket(nil, %#v, _) = %s, want = %s", params, err, &tcpip.ErrInvalidEndpointState{})
+ }
}
}
@@ -468,13 +473,17 @@ func TestLoopbackAcceptAllInSubnetTCP(t *testing.T) {
Addr: test.dstAddr,
Port: localPort,
}
- if err := connectingEndpoint.Connect(connectAddr); err != tcpip.ErrConnectStarted {
- t.Fatalf("connectingEndpoint.Connect(%#v): %s", connectAddr, err)
+ {
+ err := connectingEndpoint.Connect(connectAddr)
+ if _, ok := err.(*tcpip.ErrConnectStarted); !ok {
+ t.Fatalf("connectingEndpoint.Connect(%#v): %s", connectAddr, err)
+ }
}
if !test.expectAccept {
- if _, _, err := listeningEndpoint.Accept(nil); err != tcpip.ErrWouldBlock {
- t.Fatalf("got listeningEndpoint.Accept(nil) = %s, want = %s", err, tcpip.ErrWouldBlock)
+ _, _, err := listeningEndpoint.Accept(nil)
+ if _, ok := err.(*tcpip.ErrWouldBlock); !ok {
+ t.Fatalf("got listeningEndpoint.Accept(nil) = %s, want = %s", err, &tcpip.ErrWouldBlock{})
}
return
}
diff --git a/pkg/tcpip/tests/integration/multicast_broadcast_test.go b/pkg/tcpip/tests/integration/multicast_broadcast_test.go
index ce7c16bd1..d685fdd36 100644
--- a/pkg/tcpip/tests/integration/multicast_broadcast_test.go
+++ b/pkg/tcpip/tests/integration/multicast_broadcast_test.go
@@ -479,8 +479,8 @@ func TestIncomingMulticastAndBroadcast(t *testing.T) {
if diff := cmp.Diff(data, buf.Bytes()); diff != "" {
t.Errorf("got UDP payload mismatch (-want +got):\n%s", diff)
}
- } else if err != tcpip.ErrWouldBlock {
- t.Fatalf("got Read = (%v, %s) [with data %x], want = (_, %s)", res, err, buf.Bytes(), tcpip.ErrWouldBlock)
+ } else if _, ok := err.(*tcpip.ErrWouldBlock); !ok {
+ t.Fatalf("got Read = (%v, %s) [with data %x], want = (_, %s)", res, err, buf.Bytes(), &tcpip.ErrWouldBlock{})
}
})
}
@@ -586,8 +586,10 @@ func TestReuseAddrAndBroadcast(t *testing.T) {
Port: localPort,
},
}
- data := tcpip.SlicePayload([]byte{byte(i), 2, 3, 4})
- if n, err := wep.ep.Write(data, writeOpts); err != nil {
+ data := []byte{byte(i), 2, 3, 4}
+ var r bytes.Reader
+ r.Reset(data)
+ if n, err := wep.ep.Write(&r, writeOpts); err != nil {
t.Fatalf("eps[%d].Write(_, _): %s", i, err)
} else if want := int64(len(data)); n != want {
t.Fatalf("got eps[%d].Write(_, _) = (%d, nil), want = (%d, nil)", i, n, want)
@@ -759,8 +761,11 @@ func TestUDPAddRemoveMembershipSocketOption(t *testing.T) {
if err := ep.SetSockOpt(&removeOpt); err != nil {
t.Fatalf("ep.SetSockOpt(&%#v): %s", removeOpt, err)
}
- if _, err := ep.Read(&buf, tcpip.ReadOptions{}); err != tcpip.ErrWouldBlock {
- t.Fatalf("got ep.Read = (_, %s), want = (_, %s)", err, tcpip.ErrWouldBlock)
+ {
+ _, err := ep.Read(&buf, tcpip.ReadOptions{})
+ if _, ok := err.(*tcpip.ErrWouldBlock); !ok {
+ t.Fatalf("got ep.Read = (_, %s), want = (_, %s)", err, &tcpip.ErrWouldBlock{})
+ }
}
})
}
diff --git a/pkg/tcpip/tests/integration/route_test.go b/pkg/tcpip/tests/integration/route_test.go
index b222d2b05..9654c9527 100644
--- a/pkg/tcpip/tests/integration/route_test.go
+++ b/pkg/tcpip/tests/integration/route_test.go
@@ -81,7 +81,7 @@ func TestLocalPing(t *testing.T) {
linkEndpoint func() stack.LinkEndpoint
localAddr tcpip.Address
icmpBuf func(*testing.T) buffer.View
- expectedConnectErr *tcpip.Error
+ expectedConnectErr tcpip.Error
checkLinkEndpoint func(t *testing.T, e stack.LinkEndpoint)
}{
{
@@ -126,7 +126,7 @@ func TestLocalPing(t *testing.T) {
netProto: ipv4.ProtocolNumber,
linkEndpoint: loopback.New,
icmpBuf: ipv4ICMPBuf,
- expectedConnectErr: tcpip.ErrNoRoute,
+ expectedConnectErr: &tcpip.ErrNoRoute{},
checkLinkEndpoint: func(*testing.T, stack.LinkEndpoint) {},
},
{
@@ -135,7 +135,7 @@ func TestLocalPing(t *testing.T) {
netProto: ipv6.ProtocolNumber,
linkEndpoint: loopback.New,
icmpBuf: ipv6ICMPBuf,
- expectedConnectErr: tcpip.ErrNoRoute,
+ expectedConnectErr: &tcpip.ErrNoRoute{},
checkLinkEndpoint: func(*testing.T, stack.LinkEndpoint) {},
},
{
@@ -144,7 +144,7 @@ func TestLocalPing(t *testing.T) {
netProto: ipv4.ProtocolNumber,
linkEndpoint: channelEP,
icmpBuf: ipv4ICMPBuf,
- expectedConnectErr: tcpip.ErrNoRoute,
+ expectedConnectErr: &tcpip.ErrNoRoute{},
checkLinkEndpoint: channelEPCheck,
},
{
@@ -153,7 +153,7 @@ func TestLocalPing(t *testing.T) {
netProto: ipv6.ProtocolNumber,
linkEndpoint: channelEP,
icmpBuf: ipv6ICMPBuf,
- expectedConnectErr: tcpip.ErrNoRoute,
+ expectedConnectErr: &tcpip.ErrNoRoute{},
checkLinkEndpoint: channelEPCheck,
},
}
@@ -186,17 +186,22 @@ func TestLocalPing(t *testing.T) {
defer ep.Close()
connAddr := tcpip.FullAddress{Addr: test.localAddr}
- if err := ep.Connect(connAddr); err != test.expectedConnectErr {
- t.Fatalf("got ep.Connect(%#v) = %s, want = %s", connAddr, err, test.expectedConnectErr)
+ {
+ err := ep.Connect(connAddr)
+ if diff := cmp.Diff(test.expectedConnectErr, err); diff != "" {
+ t.Fatalf("unexpected error from ep.Connect(%#v), (-want, +got):\n%s", connAddr, diff)
+ }
}
if test.expectedConnectErr != nil {
return
}
- payload := tcpip.SlicePayload(test.icmpBuf(t))
+ payload := test.icmpBuf(t)
+ var r bytes.Reader
+ r.Reset(payload)
var wOpts tcpip.WriteOptions
- if n, err := ep.Write(payload, wOpts); err != nil {
+ if n, err := ep.Write(&r, wOpts); err != nil {
t.Fatalf("ep.Write(%#v, %#v): %s", payload, wOpts, err)
} else if n != int64(len(payload)) {
t.Fatalf("got ep.Write(%#v, %#v) = (%d, nil), want = (%d, nil)", payload, wOpts, n, len(payload))
@@ -261,12 +266,12 @@ func TestLocalUDP(t *testing.T) {
subTests := []struct {
name string
addAddress bool
- expectedWriteErr *tcpip.Error
+ expectedWriteErr tcpip.Error
}{
{
name: "Unassigned local address",
addAddress: false,
- expectedWriteErr: tcpip.ErrNoRoute,
+ expectedWriteErr: &tcpip.ErrNoRoute{},
},
{
name: "Assigned local address",
@@ -329,12 +334,14 @@ func TestLocalUDP(t *testing.T) {
Port: 80,
}
- clientPayload := tcpip.SlicePayload([]byte{1, 2, 3, 4})
+ clientPayload := []byte{1, 2, 3, 4}
{
+ var r bytes.Reader
+ r.Reset(clientPayload)
wOpts := tcpip.WriteOptions{
To: &serverAddr,
}
- if n, err := client.Write(clientPayload, wOpts); err != subTest.expectedWriteErr {
+ if n, err := client.Write(&r, wOpts); err != subTest.expectedWriteErr {
t.Fatalf("got client.Write(%#v, %#v) = (%d, %s), want = (_, %s)", clientPayload, wOpts, n, err, subTest.expectedWriteErr)
} else if subTest.expectedWriteErr != nil {
// Nothing else to test if we expected not to be able to send the
@@ -376,12 +383,14 @@ func TestLocalUDP(t *testing.T) {
}
}
- serverPayload := tcpip.SlicePayload([]byte{1, 2, 3, 4})
+ serverPayload := []byte{1, 2, 3, 4}
{
+ var r bytes.Reader
+ r.Reset(serverPayload)
wOpts := tcpip.WriteOptions{
To: &clientAddr,
}
- if n, err := server.Write(serverPayload, wOpts); err != nil {
+ if n, err := server.Write(&r, wOpts); err != nil {
t.Fatalf("server.Write(%#v, %#v): %s", serverPayload, wOpts, err)
} else if n != int64(len(serverPayload)) {
t.Fatalf("got server.Write(%#v, %#v) = (%d, nil), want = (%d, nil)", serverPayload, wOpts, n, len(serverPayload))
diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go
index 256e19296..3cf05520d 100644
--- a/pkg/tcpip/transport/icmp/endpoint.go
+++ b/pkg/tcpip/transport/icmp/endpoint.go
@@ -69,8 +69,7 @@ type endpoint struct {
rcvClosed bool
// The following fields are protected by the mu mutex.
- mu sync.RWMutex `state:"nosave"`
- sndBufSize int
+ mu sync.RWMutex `state:"nosave"`
// shutdownFlags represent the current shutdown state of the endpoint.
shutdownFlags tcpip.ShutdownFlags
state endpointState
@@ -85,7 +84,7 @@ type endpoint struct {
ops tcpip.SocketOptions
}
-func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
+func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, tcpip.Error) {
ep := &endpoint{
stack: s,
TransportEndpointInfo: stack.TransportEndpointInfo{
@@ -94,11 +93,17 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProt
},
waiterQueue: waiterQueue,
rcvBufSizeMax: 32 * 1024,
- sndBufSize: 32 * 1024,
state: stateInitial,
uniqueID: s.UniqueID(),
}
- ep.ops.InitHandler(ep)
+ ep.ops.InitHandler(ep, ep.stack, tcpip.GetStackSendBufferLimits)
+ ep.ops.SetSendBufferSize(32*1024, false /* notify */)
+
+ // Override with stack defaults.
+ var ss tcpip.SendBufferSizeOption
+ if err := s.Option(&ss); err == nil {
+ ep.ops.SetSendBufferSize(int64(ss.Default), false /* notify */)
+ }
return ep, nil
}
@@ -119,7 +124,7 @@ func (e *endpoint) Close() {
e.shutdownFlags = tcpip.ShutdownRead | tcpip.ShutdownWrite
switch e.state {
case stateBound, stateConnected:
- e.stack.UnregisterTransportEndpoint(e.RegisterNICID, []tcpip.NetworkProtocolNumber{e.NetProto}, e.TransProto, e.ID, e, ports.Flags{}, 0 /* bindToDevice */)
+ e.stack.UnregisterTransportEndpoint([]tcpip.NetworkProtocolNumber{e.NetProto}, e.TransProto, e.ID, e, ports.Flags{}, 0 /* bindToDevice */)
}
// Close the receive list and drain it.
@@ -154,14 +159,14 @@ func (e *endpoint) SetOwner(owner tcpip.PacketOwner) {
}
// Read implements tcpip.Endpoint.Read.
-func (e *endpoint) Read(dst io.Writer, opts tcpip.ReadOptions) (tcpip.ReadResult, *tcpip.Error) {
+func (e *endpoint) Read(dst io.Writer, opts tcpip.ReadOptions) (tcpip.ReadResult, tcpip.Error) {
e.rcvMu.Lock()
if e.rcvList.Empty() {
- err := tcpip.ErrWouldBlock
+ var err tcpip.Error = &tcpip.ErrWouldBlock{}
if e.rcvClosed {
e.stats.ReadErrors.ReadClosed.Increment()
- err = tcpip.ErrClosedForReceive
+ err = &tcpip.ErrClosedForReceive{}
}
e.rcvMu.Unlock()
return tcpip.ReadResult{}, err
@@ -188,7 +193,7 @@ func (e *endpoint) Read(dst io.Writer, opts tcpip.ReadOptions) (tcpip.ReadResult
n, err := p.data.ReadTo(dst, opts.Peek)
if n == 0 && err != nil {
- return res, tcpip.ErrBadBuffer
+ return res, &tcpip.ErrBadBuffer{}
}
res.Count = n
return res, nil
@@ -199,7 +204,7 @@ func (e *endpoint) Read(dst io.Writer, opts tcpip.ReadOptions) (tcpip.ReadResult
// reacquire the mutex in exclusive mode.
//
// Returns true for retry if preparation should be retried.
-func (e *endpoint) prepareForWrite(to *tcpip.FullAddress) (retry bool, err *tcpip.Error) {
+func (e *endpoint) prepareForWrite(to *tcpip.FullAddress) (retry bool, err tcpip.Error) {
switch e.state {
case stateInitial:
case stateConnected:
@@ -207,11 +212,11 @@ func (e *endpoint) prepareForWrite(to *tcpip.FullAddress) (retry bool, err *tcpi
case stateBound:
if to == nil {
- return false, tcpip.ErrDestinationRequired
+ return false, &tcpip.ErrDestinationRequired{}
}
return false, nil
default:
- return false, tcpip.ErrInvalidEndpointState
+ return false, &tcpip.ErrInvalidEndpointState{}
}
e.mu.RUnlock()
@@ -236,18 +241,18 @@ func (e *endpoint) prepareForWrite(to *tcpip.FullAddress) (retry bool, err *tcpi
// Write writes data to the endpoint's peer. This method does not block
// if the data cannot be written.
-func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, *tcpip.Error) {
+func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcpip.Error) {
n, err := e.write(p, opts)
- switch err {
+ switch err.(type) {
case nil:
e.stats.PacketsSent.Increment()
- case tcpip.ErrMessageTooLong, tcpip.ErrInvalidOptionValue:
+ case *tcpip.ErrMessageTooLong, *tcpip.ErrInvalidOptionValue:
e.stats.WriteErrors.InvalidArgs.Increment()
- case tcpip.ErrClosedForSend:
+ case *tcpip.ErrClosedForSend:
e.stats.WriteErrors.WriteClosed.Increment()
- case tcpip.ErrInvalidEndpointState:
+ case *tcpip.ErrInvalidEndpointState:
e.stats.WriteErrors.InvalidEndpointState.Increment()
- case tcpip.ErrNoRoute, tcpip.ErrBroadcastDisabled, tcpip.ErrNetworkUnreachable:
+ case *tcpip.ErrNoRoute, *tcpip.ErrBroadcastDisabled, *tcpip.ErrNetworkUnreachable:
// Errors indicating any problem with IP routing of the packet.
e.stats.SendErrors.NoRoute.Increment()
default:
@@ -257,10 +262,10 @@ func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, *tc
return n, err
}
-func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, *tcpip.Error) {
+func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcpip.Error) {
// MSG_MORE is unimplemented. (This also means that MSG_EOR is a no-op.)
if opts.More {
- return 0, tcpip.ErrInvalidOptionValue
+ return 0, &tcpip.ErrInvalidOptionValue{}
}
to := opts.To
@@ -270,7 +275,7 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, *tc
// If we've shutdown with SHUT_WR we are in an invalid state for sending.
if e.shutdownFlags&tcpip.ShutdownWrite != 0 {
- return 0, tcpip.ErrClosedForSend
+ return 0, &tcpip.ErrClosedForSend{}
}
// Prepare for write.
@@ -292,7 +297,7 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, *tc
nicID := to.NIC
if e.BindNICID != 0 {
if nicID != 0 && nicID != e.BindNICID {
- return 0, tcpip.ErrNoRoute
+ return 0, &tcpip.ErrNoRoute{}
}
nicID = e.BindNICID
@@ -313,11 +318,12 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, *tc
route = r
}
- v, err := p.FullPayload()
- if err != nil {
- return 0, err
+ v := make([]byte, p.Len())
+ if _, err := io.ReadFull(p, v); err != nil {
+ return 0, &tcpip.ErrBadBuffer{}
}
+ var err tcpip.Error
switch e.NetProto {
case header.IPv4ProtocolNumber:
err = send4(route, e.ID.LocalPort, v, e.ttl, e.owner)
@@ -334,12 +340,12 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, *tc
}
// SetSockOpt sets a socket option.
-func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) *tcpip.Error {
+func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) tcpip.Error {
return nil
}
// SetSockOptInt sets a socket option. Currently not supported.
-func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error {
+func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) tcpip.Error {
switch opt {
case tcpip.TTLOption:
e.mu.Lock()
@@ -351,7 +357,7 @@ func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error {
}
// GetSockOptInt implements tcpip.Endpoint.GetSockOptInt.
-func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) {
+func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, tcpip.Error) {
switch opt {
case tcpip.ReceiveQueueSizeOption:
v := 0
@@ -362,11 +368,6 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) {
}
e.rcvMu.Unlock()
return v, nil
- case tcpip.SendBufferSizeOption:
- e.mu.Lock()
- v := e.sndBufSize
- e.mu.Unlock()
- return v, nil
case tcpip.ReceiveBufferSizeOption:
e.rcvMu.Lock()
@@ -381,18 +382,18 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) {
return v, nil
default:
- return -1, tcpip.ErrUnknownProtocolOption
+ return -1, &tcpip.ErrUnknownProtocolOption{}
}
}
// GetSockOpt implements tcpip.Endpoint.GetSockOpt.
-func (e *endpoint) GetSockOpt(opt tcpip.GettableSocketOption) *tcpip.Error {
- return tcpip.ErrUnknownProtocolOption
+func (e *endpoint) GetSockOpt(opt tcpip.GettableSocketOption) tcpip.Error {
+ return &tcpip.ErrUnknownProtocolOption{}
}
-func send4(r *stack.Route, ident uint16, data buffer.View, ttl uint8, owner tcpip.PacketOwner) *tcpip.Error {
+func send4(r *stack.Route, ident uint16, data buffer.View, ttl uint8, owner tcpip.PacketOwner) tcpip.Error {
if len(data) < header.ICMPv4MinimumSize {
- return tcpip.ErrInvalidEndpointState
+ return &tcpip.ErrInvalidEndpointState{}
}
pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
@@ -410,7 +411,7 @@ func send4(r *stack.Route, ident uint16, data buffer.View, ttl uint8, owner tcpi
// Linux performs these basic checks.
if icmpv4.Type() != header.ICMPv4Echo || icmpv4.Code() != 0 {
- return tcpip.ErrInvalidEndpointState
+ return &tcpip.ErrInvalidEndpointState{}
}
icmpv4.SetChecksum(0)
@@ -424,9 +425,9 @@ func send4(r *stack.Route, ident uint16, data buffer.View, ttl uint8, owner tcpi
return r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv4ProtocolNumber, TTL: ttl, TOS: stack.DefaultTOS}, pkt)
}
-func send6(r *stack.Route, ident uint16, data buffer.View, ttl uint8) *tcpip.Error {
+func send6(r *stack.Route, ident uint16, data buffer.View, ttl uint8) tcpip.Error {
if len(data) < header.ICMPv6EchoMinimumSize {
- return tcpip.ErrInvalidEndpointState
+ return &tcpip.ErrInvalidEndpointState{}
}
pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
@@ -441,7 +442,7 @@ func send6(r *stack.Route, ident uint16, data buffer.View, ttl uint8) *tcpip.Err
data = data[header.ICMPv6MinimumSize:]
if icmpv6.Type() != header.ICMPv6EchoRequest || icmpv6.Code() != 0 {
- return tcpip.ErrInvalidEndpointState
+ return &tcpip.ErrInvalidEndpointState{}
}
dataVV := data.ToVectorisedView()
@@ -456,7 +457,7 @@ func send6(r *stack.Route, ident uint16, data buffer.View, ttl uint8) *tcpip.Err
// checkV4MappedLocked determines the effective network protocol and converts
// addr to its canonical form.
-func (e *endpoint) checkV4MappedLocked(addr tcpip.FullAddress) (tcpip.FullAddress, tcpip.NetworkProtocolNumber, *tcpip.Error) {
+func (e *endpoint) checkV4MappedLocked(addr tcpip.FullAddress) (tcpip.FullAddress, tcpip.NetworkProtocolNumber, tcpip.Error) {
unwrapped, netProto, err := e.TransportEndpointInfo.AddrNetProtoLocked(addr, false /* v6only */)
if err != nil {
return tcpip.FullAddress{}, 0, err
@@ -465,12 +466,12 @@ func (e *endpoint) checkV4MappedLocked(addr tcpip.FullAddress) (tcpip.FullAddres
}
// Disconnect implements tcpip.Endpoint.Disconnect.
-func (*endpoint) Disconnect() *tcpip.Error {
- return tcpip.ErrNotSupported
+func (*endpoint) Disconnect() tcpip.Error {
+ return &tcpip.ErrNotSupported{}
}
// Connect connects the endpoint to its peer. Specifying a NIC is optional.
-func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
+func (e *endpoint) Connect(addr tcpip.FullAddress) tcpip.Error {
e.mu.Lock()
defer e.mu.Unlock()
@@ -485,12 +486,12 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
}
if nicID != 0 && nicID != e.BindNICID {
- return tcpip.ErrInvalidEndpointState
+ return &tcpip.ErrInvalidEndpointState{}
}
nicID = e.BindNICID
default:
- return tcpip.ErrInvalidEndpointState
+ return &tcpip.ErrInvalidEndpointState{}
}
addr, netProto, err := e.checkV4MappedLocked(addr)
@@ -535,19 +536,19 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
}
// ConnectEndpoint is not supported.
-func (*endpoint) ConnectEndpoint(tcpip.Endpoint) *tcpip.Error {
- return tcpip.ErrInvalidEndpointState
+func (*endpoint) ConnectEndpoint(tcpip.Endpoint) tcpip.Error {
+ return &tcpip.ErrInvalidEndpointState{}
}
// Shutdown closes the read and/or write end of the endpoint connection
// to its peer.
-func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error {
+func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) tcpip.Error {
e.mu.Lock()
defer e.mu.Unlock()
e.shutdownFlags |= flags
if e.state != stateConnected {
- return tcpip.ErrNotConnected
+ return &tcpip.ErrNotConnected{}
}
if flags&tcpip.ShutdownRead != 0 {
@@ -565,31 +566,31 @@ func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error {
}
// Listen is not supported by UDP, it just fails.
-func (*endpoint) Listen(int) *tcpip.Error {
- return tcpip.ErrNotSupported
+func (*endpoint) Listen(int) tcpip.Error {
+ return &tcpip.ErrNotSupported{}
}
// Accept is not supported by UDP, it just fails.
-func (*endpoint) Accept(*tcpip.FullAddress) (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) {
- return nil, nil, tcpip.ErrNotSupported
+func (*endpoint) Accept(*tcpip.FullAddress) (tcpip.Endpoint, *waiter.Queue, tcpip.Error) {
+ return nil, nil, &tcpip.ErrNotSupported{}
}
-func (e *endpoint) registerWithStack(nicID tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, id stack.TransportEndpointID) (stack.TransportEndpointID, *tcpip.Error) {
+func (e *endpoint) registerWithStack(nicID tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, id stack.TransportEndpointID) (stack.TransportEndpointID, tcpip.Error) {
if id.LocalPort != 0 {
// The endpoint already has a local port, just attempt to
// register it.
- err := e.stack.RegisterTransportEndpoint(nicID, netProtos, e.TransProto, id, e, ports.Flags{}, 0 /* bindToDevice */)
+ err := e.stack.RegisterTransportEndpoint(netProtos, e.TransProto, id, e, ports.Flags{}, 0 /* bindToDevice */)
return id, err
}
// We need to find a port for the endpoint.
- _, err := e.stack.PickEphemeralPort(func(p uint16) (bool, *tcpip.Error) {
+ _, err := e.stack.PickEphemeralPort(func(p uint16) (bool, tcpip.Error) {
id.LocalPort = p
- err := e.stack.RegisterTransportEndpoint(nicID, netProtos, e.TransProto, id, e, ports.Flags{}, 0 /* bindtodevice */)
- switch err {
+ err := e.stack.RegisterTransportEndpoint(netProtos, e.TransProto, id, e, ports.Flags{}, 0 /* bindtodevice */)
+ switch err.(type) {
case nil:
return true, nil
- case tcpip.ErrPortInUse:
+ case *tcpip.ErrPortInUse:
return false, nil
default:
return false, err
@@ -599,11 +600,11 @@ func (e *endpoint) registerWithStack(nicID tcpip.NICID, netProtos []tcpip.Networ
return id, err
}
-func (e *endpoint) bindLocked(addr tcpip.FullAddress) *tcpip.Error {
+func (e *endpoint) bindLocked(addr tcpip.FullAddress) tcpip.Error {
// Don't allow binding once endpoint is not in the initial state
// anymore.
if e.state != stateInitial {
- return tcpip.ErrInvalidEndpointState
+ return &tcpip.ErrInvalidEndpointState{}
}
addr, netProto, err := e.checkV4MappedLocked(addr)
@@ -619,7 +620,7 @@ func (e *endpoint) bindLocked(addr tcpip.FullAddress) *tcpip.Error {
if len(addr.Addr) != 0 {
// A local address was specified, verify that it's valid.
if e.stack.CheckLocalAddress(addr.NIC, netProto, addr.Addr) == 0 {
- return tcpip.ErrBadLocalAddress
+ return &tcpip.ErrBadLocalAddress{}
}
}
@@ -647,7 +648,7 @@ func (e *endpoint) bindLocked(addr tcpip.FullAddress) *tcpip.Error {
// Bind binds the endpoint to a specific local address and port.
// Specifying a NIC is optional.
-func (e *endpoint) Bind(addr tcpip.FullAddress) *tcpip.Error {
+func (e *endpoint) Bind(addr tcpip.FullAddress) tcpip.Error {
e.mu.Lock()
defer e.mu.Unlock()
@@ -663,7 +664,7 @@ func (e *endpoint) Bind(addr tcpip.FullAddress) *tcpip.Error {
}
// GetLocalAddress returns the address to which the endpoint is bound.
-func (e *endpoint) GetLocalAddress() (tcpip.FullAddress, *tcpip.Error) {
+func (e *endpoint) GetLocalAddress() (tcpip.FullAddress, tcpip.Error) {
e.mu.RLock()
defer e.mu.RUnlock()
@@ -675,12 +676,12 @@ func (e *endpoint) GetLocalAddress() (tcpip.FullAddress, *tcpip.Error) {
}
// GetRemoteAddress returns the address to which the endpoint is connected.
-func (e *endpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Error) {
+func (e *endpoint) GetRemoteAddress() (tcpip.FullAddress, tcpip.Error) {
e.mu.RLock()
defer e.mu.RUnlock()
if e.state != stateConnected {
- return tcpip.FullAddress{}, tcpip.ErrNotConnected
+ return tcpip.FullAddress{}, &tcpip.ErrNotConnected{}
}
return tcpip.FullAddress{
@@ -805,7 +806,7 @@ func (e *endpoint) Stats() tcpip.EndpointStats {
func (*endpoint) Wait() {}
// LastError implements tcpip.Endpoint.LastError.
-func (*endpoint) LastError() *tcpip.Error {
+func (*endpoint) LastError() tcpip.Error {
return nil
}
diff --git a/pkg/tcpip/transport/icmp/endpoint_state.go b/pkg/tcpip/transport/icmp/endpoint_state.go
index 9d263c0ec..c9fa9974a 100644
--- a/pkg/tcpip/transport/icmp/endpoint_state.go
+++ b/pkg/tcpip/transport/icmp/endpoint_state.go
@@ -69,12 +69,13 @@ func (e *endpoint) afterLoad() {
// Resume implements tcpip.ResumableEndpoint.Resume.
func (e *endpoint) Resume(s *stack.Stack) {
e.stack = s
+ e.ops.InitHandler(e, e.stack, tcpip.GetStackSendBufferLimits)
if e.state != stateBound && e.state != stateConnected {
return
}
- var err *tcpip.Error
+ var err tcpip.Error
if e.state == stateConnected {
e.route, err = e.stack.FindRoute(e.RegisterNICID, e.BindAddr, e.ID.RemoteAddress, e.NetProto, false /* multicastLoop */)
if err != nil {
@@ -84,7 +85,7 @@ func (e *endpoint) Resume(s *stack.Stack) {
e.ID.LocalAddress = e.route.LocalAddress
} else if len(e.ID.LocalAddress) != 0 { // stateBound
if e.stack.CheckLocalAddress(e.RegisterNICID, e.NetProto, e.ID.LocalAddress) == 0 {
- panic(tcpip.ErrBadLocalAddress)
+ panic(&tcpip.ErrBadLocalAddress{})
}
}
diff --git a/pkg/tcpip/transport/icmp/protocol.go b/pkg/tcpip/transport/icmp/protocol.go
index 3820e5dc7..47f7dd1cb 100644
--- a/pkg/tcpip/transport/icmp/protocol.go
+++ b/pkg/tcpip/transport/icmp/protocol.go
@@ -59,18 +59,18 @@ func (p *protocol) netProto() tcpip.NetworkProtocolNumber {
// NewEndpoint creates a new icmp endpoint. It implements
// stack.TransportProtocol.NewEndpoint.
-func (p *protocol) NewEndpoint(netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
+func (p *protocol) NewEndpoint(netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, tcpip.Error) {
if netProto != p.netProto() {
- return nil, tcpip.ErrUnknownProtocol
+ return nil, &tcpip.ErrUnknownProtocol{}
}
return newEndpoint(p.stack, netProto, p.number, waiterQueue)
}
// NewRawEndpoint creates a new raw icmp endpoint. It implements
// stack.TransportProtocol.NewRawEndpoint.
-func (p *protocol) NewRawEndpoint(netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
+func (p *protocol) NewRawEndpoint(netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, tcpip.Error) {
if netProto != p.netProto() {
- return nil, tcpip.ErrUnknownProtocol
+ return nil, &tcpip.ErrUnknownProtocol{}
}
return raw.NewEndpoint(p.stack, netProto, p.number, waiterQueue)
}
@@ -87,7 +87,7 @@ func (p *protocol) MinimumPacketSize() int {
}
// ParsePorts in case of ICMP sets src to 0, dst to ICMP ID, and err to nil.
-func (p *protocol) ParsePorts(v buffer.View) (src, dst uint16, err *tcpip.Error) {
+func (p *protocol) ParsePorts(v buffer.View) (src, dst uint16, err tcpip.Error) {
switch p.number {
case ProtocolNumber4:
hdr := header.ICMPv4(v)
@@ -106,13 +106,13 @@ func (*protocol) HandleUnknownDestinationPacket(stack.TransportEndpointID, *stac
}
// SetOption implements stack.TransportProtocol.SetOption.
-func (*protocol) SetOption(tcpip.SettableTransportProtocolOption) *tcpip.Error {
- return tcpip.ErrUnknownProtocolOption
+func (*protocol) SetOption(tcpip.SettableTransportProtocolOption) tcpip.Error {
+ return &tcpip.ErrUnknownProtocolOption{}
}
// Option implements stack.TransportProtocol.Option.
-func (*protocol) Option(tcpip.GettableTransportProtocolOption) *tcpip.Error {
- return tcpip.ErrUnknownProtocolOption
+func (*protocol) Option(tcpip.GettableTransportProtocolOption) tcpip.Error {
+ return &tcpip.ErrUnknownProtocolOption{}
}
// Close implements stack.TransportProtocol.Close.
diff --git a/pkg/tcpip/transport/packet/endpoint.go b/pkg/tcpip/transport/packet/endpoint.go
index c0d6fb442..73bb66830 100644
--- a/pkg/tcpip/transport/packet/endpoint.go
+++ b/pkg/tcpip/transport/packet/endpoint.go
@@ -79,24 +79,22 @@ type endpoint struct {
rcvClosed bool
// The following fields are protected by mu.
- mu sync.RWMutex `state:"nosave"`
- sndBufSize int
- sndBufSizeMax int
- closed bool
- stats tcpip.TransportEndpointStats `state:"nosave"`
- bound bool
- boundNIC tcpip.NICID
+ mu sync.RWMutex `state:"nosave"`
+ closed bool
+ stats tcpip.TransportEndpointStats `state:"nosave"`
+ bound bool
+ boundNIC tcpip.NICID
// lastErrorMu protects lastError.
- lastErrorMu sync.Mutex `state:"nosave"`
- lastError *tcpip.Error `state:".(string)"`
+ lastErrorMu sync.Mutex `state:"nosave"`
+ lastError tcpip.Error
// ops is used to get socket level options.
ops tcpip.SocketOptions
}
// NewEndpoint returns a new packet endpoint.
-func NewEndpoint(s *stack.Stack, cooked bool, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
+func NewEndpoint(s *stack.Stack, cooked bool, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, tcpip.Error) {
ep := &endpoint{
stack: s,
TransportEndpointInfo: stack.TransportEndpointInfo{
@@ -106,14 +104,13 @@ func NewEndpoint(s *stack.Stack, cooked bool, netProto tcpip.NetworkProtocolNumb
netProto: netProto,
waiterQueue: waiterQueue,
rcvBufSizeMax: 32 * 1024,
- sndBufSize: 32 * 1024,
}
- ep.ops.InitHandler(ep)
+ ep.ops.InitHandler(ep, ep.stack, tcpip.GetStackSendBufferLimits)
// Override with stack defaults.
- var ss stack.SendBufferSizeOption
+ var ss tcpip.SendBufferSizeOption
if err := s.Option(&ss); err == nil {
- ep.sndBufSizeMax = ss.Default
+ ep.ops.SetSendBufferSize(int64(ss.Default), false /* notify */)
}
var rs stack.ReceiveBufferSizeOption
@@ -162,16 +159,16 @@ func (ep *endpoint) Close() {
func (ep *endpoint) ModerateRecvBuf(copied int) {}
// Read implements tcpip.Endpoint.Read.
-func (ep *endpoint) Read(dst io.Writer, opts tcpip.ReadOptions) (tcpip.ReadResult, *tcpip.Error) {
+func (ep *endpoint) Read(dst io.Writer, opts tcpip.ReadOptions) (tcpip.ReadResult, tcpip.Error) {
ep.rcvMu.Lock()
// If there's no data to read, return that read would block or that the
// endpoint is closed.
if ep.rcvList.Empty() {
- err := tcpip.ErrWouldBlock
+ var err tcpip.Error = &tcpip.ErrWouldBlock{}
if ep.rcvClosed {
ep.stats.ReadErrors.ReadClosed.Increment()
- err = tcpip.ErrClosedForReceive
+ err = &tcpip.ErrClosedForReceive{}
}
ep.rcvMu.Unlock()
return tcpip.ReadResult{}, err
@@ -201,49 +198,49 @@ func (ep *endpoint) Read(dst io.Writer, opts tcpip.ReadOptions) (tcpip.ReadResul
n, err := packet.data.ReadTo(dst, opts.Peek)
if n == 0 && err != nil {
- return res, tcpip.ErrBadBuffer
+ return res, &tcpip.ErrBadBuffer{}
}
res.Count = n
return res, nil
}
-func (*endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, *tcpip.Error) {
+func (*endpoint) Write(tcpip.Payloader, tcpip.WriteOptions) (int64, tcpip.Error) {
// TODO(gvisor.dev/issue/173): Implement.
- return 0, tcpip.ErrInvalidOptionValue
+ return 0, &tcpip.ErrInvalidOptionValue{}
}
// Disconnect implements tcpip.Endpoint.Disconnect. Packet sockets cannot be
// disconnected, and this function always returns tpcip.ErrNotSupported.
-func (*endpoint) Disconnect() *tcpip.Error {
- return tcpip.ErrNotSupported
+func (*endpoint) Disconnect() tcpip.Error {
+ return &tcpip.ErrNotSupported{}
}
// Connect implements tcpip.Endpoint.Connect. Packet sockets cannot be
-// connected, and this function always returnes tcpip.ErrNotSupported.
-func (*endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
- return tcpip.ErrNotSupported
+// connected, and this function always returnes *tcpip.ErrNotSupported.
+func (*endpoint) Connect(addr tcpip.FullAddress) tcpip.Error {
+ return &tcpip.ErrNotSupported{}
}
// Shutdown implements tcpip.Endpoint.Shutdown. Packet sockets cannot be used
-// with Shutdown, and this function always returns tcpip.ErrNotSupported.
-func (*endpoint) Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error {
- return tcpip.ErrNotSupported
+// with Shutdown, and this function always returns *tcpip.ErrNotSupported.
+func (*endpoint) Shutdown(flags tcpip.ShutdownFlags) tcpip.Error {
+ return &tcpip.ErrNotSupported{}
}
// Listen implements tcpip.Endpoint.Listen. Packet sockets cannot be used with
-// Listen, and this function always returns tcpip.ErrNotSupported.
-func (*endpoint) Listen(backlog int) *tcpip.Error {
- return tcpip.ErrNotSupported
+// Listen, and this function always returns *tcpip.ErrNotSupported.
+func (*endpoint) Listen(backlog int) tcpip.Error {
+ return &tcpip.ErrNotSupported{}
}
// Accept implements tcpip.Endpoint.Accept. Packet sockets cannot be used with
-// Accept, and this function always returns tcpip.ErrNotSupported.
-func (*endpoint) Accept(*tcpip.FullAddress) (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) {
- return nil, nil, tcpip.ErrNotSupported
+// Accept, and this function always returns *tcpip.ErrNotSupported.
+func (*endpoint) Accept(*tcpip.FullAddress) (tcpip.Endpoint, *waiter.Queue, tcpip.Error) {
+ return nil, nil, &tcpip.ErrNotSupported{}
}
// Bind implements tcpip.Endpoint.Bind.
-func (ep *endpoint) Bind(addr tcpip.FullAddress) *tcpip.Error {
+func (ep *endpoint) Bind(addr tcpip.FullAddress) tcpip.Error {
// TODO(gvisor.dev/issue/173): Add Bind support.
// "By default, all packets of the specified protocol type are passed
@@ -277,14 +274,14 @@ func (ep *endpoint) Bind(addr tcpip.FullAddress) *tcpip.Error {
}
// GetLocalAddress implements tcpip.Endpoint.GetLocalAddress.
-func (*endpoint) GetLocalAddress() (tcpip.FullAddress, *tcpip.Error) {
- return tcpip.FullAddress{}, tcpip.ErrNotSupported
+func (*endpoint) GetLocalAddress() (tcpip.FullAddress, tcpip.Error) {
+ return tcpip.FullAddress{}, &tcpip.ErrNotSupported{}
}
// GetRemoteAddress implements tcpip.Endpoint.GetRemoteAddress.
-func (*endpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Error) {
+func (*endpoint) GetRemoteAddress() (tcpip.FullAddress, tcpip.Error) {
// Even a connected socket doesn't return a remote address.
- return tcpip.FullAddress{}, tcpip.ErrNotConnected
+ return tcpip.FullAddress{}, &tcpip.ErrNotConnected{}
}
// Readiness implements tcpip.Endpoint.Readiness.
@@ -306,38 +303,20 @@ func (ep *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask {
// SetSockOpt implements tcpip.Endpoint.SetSockOpt. Packet sockets cannot be
// used with SetSockOpt, and this function always returns
-// tcpip.ErrNotSupported.
-func (ep *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) *tcpip.Error {
+// *tcpip.ErrNotSupported.
+func (ep *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) tcpip.Error {
switch opt.(type) {
case *tcpip.SocketDetachFilterOption:
return nil
default:
- return tcpip.ErrUnknownProtocolOption
+ return &tcpip.ErrUnknownProtocolOption{}
}
}
// SetSockOptInt implements tcpip.Endpoint.SetSockOptInt.
-func (ep *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error {
+func (ep *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) tcpip.Error {
switch opt {
- case tcpip.SendBufferSizeOption:
- // Make sure the send buffer size is within the min and max
- // allowed.
- var ss stack.SendBufferSizeOption
- if err := ep.stack.Option(&ss); err != nil {
- panic(fmt.Sprintf("s.Option(%#v) = %s", ss, err))
- }
- if v > ss.Max {
- v = ss.Max
- }
- if v < ss.Min {
- v = ss.Min
- }
- ep.mu.Lock()
- ep.sndBufSizeMax = v
- ep.mu.Unlock()
- return nil
-
case tcpip.ReceiveBufferSizeOption:
// Make sure the receive buffer size is within the min and max
// allowed.
@@ -357,11 +336,11 @@ func (ep *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error {
return nil
default:
- return tcpip.ErrUnknownProtocolOption
+ return &tcpip.ErrUnknownProtocolOption{}
}
}
-func (ep *endpoint) LastError() *tcpip.Error {
+func (ep *endpoint) LastError() tcpip.Error {
ep.lastErrorMu.Lock()
defer ep.lastErrorMu.Unlock()
@@ -371,19 +350,19 @@ func (ep *endpoint) LastError() *tcpip.Error {
}
// UpdateLastError implements tcpip.SocketOptionsHandler.UpdateLastError.
-func (ep *endpoint) UpdateLastError(err *tcpip.Error) {
+func (ep *endpoint) UpdateLastError(err tcpip.Error) {
ep.lastErrorMu.Lock()
ep.lastError = err
ep.lastErrorMu.Unlock()
}
// GetSockOpt implements tcpip.Endpoint.GetSockOpt.
-func (ep *endpoint) GetSockOpt(opt tcpip.GettableSocketOption) *tcpip.Error {
- return tcpip.ErrNotSupported
+func (ep *endpoint) GetSockOpt(opt tcpip.GettableSocketOption) tcpip.Error {
+ return &tcpip.ErrNotSupported{}
}
// GetSockOptInt implements tcpip.Endpoint.GetSockOptInt.
-func (ep *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) {
+func (ep *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, tcpip.Error) {
switch opt {
case tcpip.ReceiveQueueSizeOption:
v := 0
@@ -395,12 +374,6 @@ func (ep *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) {
ep.rcvMu.Unlock()
return v, nil
- case tcpip.SendBufferSizeOption:
- ep.mu.Lock()
- v := ep.sndBufSizeMax
- ep.mu.Unlock()
- return v, nil
-
case tcpip.ReceiveBufferSizeOption:
ep.rcvMu.Lock()
v := ep.rcvBufSizeMax
@@ -408,7 +381,7 @@ func (ep *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) {
return v, nil
default:
- return -1, tcpip.ErrUnknownProtocolOption
+ return -1, &tcpip.ErrUnknownProtocolOption{}
}
}
diff --git a/pkg/tcpip/transport/packet/endpoint_state.go b/pkg/tcpip/transport/packet/endpoint_state.go
index e2fa96d17..ece662c0d 100644
--- a/pkg/tcpip/transport/packet/endpoint_state.go
+++ b/pkg/tcpip/transport/packet/endpoint_state.go
@@ -63,29 +63,11 @@ func (ep *endpoint) loadRcvBufSizeMax(max int) {
// afterLoad is invoked by stateify.
func (ep *endpoint) afterLoad() {
- // StackFromEnv is a stack used specifically for save/restore.
ep.stack = stack.StackFromEnv
+ ep.ops.InitHandler(ep, ep.stack, tcpip.GetStackSendBufferLimits)
// TODO(gvisor.dev/173): Once bind is supported, choose the right NIC.
if err := ep.stack.RegisterPacketEndpoint(0, ep.netProto, ep); err != nil {
- panic(*err)
+ panic(err)
}
}
-
-// saveLastError is invoked by stateify.
-func (ep *endpoint) saveLastError() string {
- if ep.lastError == nil {
- return ""
- }
-
- return ep.lastError.String()
-}
-
-// loadLastError is invoked by stateify.
-func (ep *endpoint) loadLastError(s string) {
- if s == "" {
- return
- }
-
- ep.lastError = tcpip.StringToError(s)
-}
diff --git a/pkg/tcpip/transport/raw/endpoint.go b/pkg/tcpip/transport/raw/endpoint.go
index ae743f75e..9c9ccc0ff 100644
--- a/pkg/tcpip/transport/raw/endpoint.go
+++ b/pkg/tcpip/transport/raw/endpoint.go
@@ -76,12 +76,10 @@ type endpoint struct {
rcvClosed bool
// The following fields are protected by mu.
- mu sync.RWMutex `state:"nosave"`
- sndBufSize int
- sndBufSizeMax int
- closed bool
- connected bool
- bound bool
+ mu sync.RWMutex `state:"nosave"`
+ closed bool
+ connected bool
+ bound bool
// route is the route to a remote network endpoint. It is set via
// Connect(), and is valid only when conneted is true.
route *stack.Route `state:"manual"`
@@ -95,13 +93,13 @@ type endpoint struct {
}
// NewEndpoint returns a raw endpoint for the given protocols.
-func NewEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
+func NewEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, tcpip.Error) {
return newEndpoint(stack, netProto, transProto, waiterQueue, true /* associated */)
}
-func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, waiterQueue *waiter.Queue, associated bool) (tcpip.Endpoint, *tcpip.Error) {
+func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, waiterQueue *waiter.Queue, associated bool) (tcpip.Endpoint, tcpip.Error) {
if netProto != header.IPv4ProtocolNumber && netProto != header.IPv6ProtocolNumber {
- return nil, tcpip.ErrUnknownProtocol
+ return nil, &tcpip.ErrUnknownProtocol{}
}
e := &endpoint{
@@ -112,16 +110,16 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProt
},
waiterQueue: waiterQueue,
rcvBufSizeMax: 32 * 1024,
- sndBufSizeMax: 32 * 1024,
associated: associated,
}
- e.ops.InitHandler(e)
+ e.ops.InitHandler(e, e.stack, tcpip.GetStackSendBufferLimits)
e.ops.SetHeaderIncluded(!associated)
+ e.ops.SetSendBufferSize(32*1024, false /* notify */)
// Override with stack defaults.
- var ss stack.SendBufferSizeOption
+ var ss tcpip.SendBufferSizeOption
if err := s.Option(&ss); err == nil {
- e.sndBufSizeMax = ss.Default
+ e.ops.SetSendBufferSize(int64(ss.Default), false /* notify */)
}
var rs stack.ReceiveBufferSizeOption
@@ -138,7 +136,7 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProt
return e, nil
}
- if err := e.stack.RegisterRawTransportEndpoint(e.RegisterNICID, e.NetProto, e.TransProto, e); err != nil {
+ if err := e.stack.RegisterRawTransportEndpoint(e.NetProto, e.TransProto, e); err != nil {
return nil, err
}
@@ -159,7 +157,7 @@ func (e *endpoint) Close() {
return
}
- e.stack.UnregisterRawTransportEndpoint(e.RegisterNICID, e.NetProto, e.TransProto, e)
+ e.stack.UnregisterRawTransportEndpoint(e.NetProto, e.TransProto, e)
e.rcvMu.Lock()
defer e.rcvMu.Unlock()
@@ -191,16 +189,16 @@ func (e *endpoint) SetOwner(owner tcpip.PacketOwner) {
}
// Read implements tcpip.Endpoint.Read.
-func (e *endpoint) Read(dst io.Writer, opts tcpip.ReadOptions) (tcpip.ReadResult, *tcpip.Error) {
+func (e *endpoint) Read(dst io.Writer, opts tcpip.ReadOptions) (tcpip.ReadResult, tcpip.Error) {
e.rcvMu.Lock()
// If there's no data to read, return that read would block or that the
// endpoint is closed.
if e.rcvList.Empty() {
- err := tcpip.ErrWouldBlock
+ var err tcpip.Error = &tcpip.ErrWouldBlock{}
if e.rcvClosed {
e.stats.ReadErrors.ReadClosed.Increment()
- err = tcpip.ErrClosedForReceive
+ err = &tcpip.ErrClosedForReceive{}
}
e.rcvMu.Unlock()
return tcpip.ReadResult{}, err
@@ -227,37 +225,37 @@ func (e *endpoint) Read(dst io.Writer, opts tcpip.ReadOptions) (tcpip.ReadResult
n, err := pkt.data.ReadTo(dst, opts.Peek)
if n == 0 && err != nil {
- return res, tcpip.ErrBadBuffer
+ return res, &tcpip.ErrBadBuffer{}
}
res.Count = n
return res, nil
}
// Write implements tcpip.Endpoint.Write.
-func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, *tcpip.Error) {
+func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcpip.Error) {
// We can create, but not write to, unassociated IPv6 endpoints.
if !e.associated && e.TransportEndpointInfo.NetProto == header.IPv6ProtocolNumber {
- return 0, tcpip.ErrInvalidOptionValue
+ return 0, &tcpip.ErrInvalidOptionValue{}
}
if opts.To != nil {
// Raw sockets do not support sending to a IPv4 address on a IPv6 endpoint.
if e.TransportEndpointInfo.NetProto == header.IPv6ProtocolNumber && len(opts.To.Addr) != header.IPv6AddressSize {
- return 0, tcpip.ErrInvalidOptionValue
+ return 0, &tcpip.ErrInvalidOptionValue{}
}
}
n, err := e.write(p, opts)
- switch err {
+ switch err.(type) {
case nil:
e.stats.PacketsSent.Increment()
- case tcpip.ErrMessageTooLong, tcpip.ErrInvalidOptionValue:
+ case *tcpip.ErrMessageTooLong, *tcpip.ErrInvalidOptionValue:
e.stats.WriteErrors.InvalidArgs.Increment()
- case tcpip.ErrClosedForSend:
+ case *tcpip.ErrClosedForSend:
e.stats.WriteErrors.WriteClosed.Increment()
- case tcpip.ErrInvalidEndpointState:
+ case *tcpip.ErrInvalidEndpointState:
e.stats.WriteErrors.InvalidEndpointState.Increment()
- case tcpip.ErrNoRoute, tcpip.ErrBroadcastDisabled, tcpip.ErrNetworkUnreachable:
+ case *tcpip.ErrNoRoute, *tcpip.ErrBroadcastDisabled, *tcpip.ErrNetworkUnreachable:
// Errors indicating any problem with IP routing of the packet.
e.stats.SendErrors.NoRoute.Increment()
default:
@@ -267,22 +265,22 @@ func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, *tc
return n, err
}
-func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, *tcpip.Error) {
+func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcpip.Error) {
// MSG_MORE is unimplemented. This also means that MSG_EOR is a no-op.
if opts.More {
- return 0, tcpip.ErrInvalidOptionValue
+ return 0, &tcpip.ErrInvalidOptionValue{}
}
e.mu.RLock()
defer e.mu.RUnlock()
if e.closed {
- return 0, tcpip.ErrInvalidEndpointState
+ return 0, &tcpip.ErrInvalidEndpointState{}
}
- payloadBytes, err := p.FullPayload()
- if err != nil {
- return 0, err
+ payloadBytes := make([]byte, p.Len())
+ if _, err := io.ReadFull(p, payloadBytes); err != nil {
+ return 0, &tcpip.ErrBadBuffer{}
}
// If this is an unassociated socket and callee provided a nonzero
@@ -290,7 +288,7 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, *tc
if e.ops.GetHeaderIncluded() {
ip := header.IPv4(payloadBytes)
if !ip.IsValid(len(payloadBytes)) {
- return 0, tcpip.ErrInvalidOptionValue
+ return 0, &tcpip.ErrInvalidOptionValue{}
}
dstAddr := ip.DestinationAddress()
// Update dstAddr with the address in the IP header, unless
@@ -311,7 +309,7 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, *tc
// If the user doesn't specify a destination, they should have
// connected to another address.
if !e.connected {
- return 0, tcpip.ErrDestinationRequired
+ return 0, &tcpip.ErrDestinationRequired{}
}
return e.finishWrite(payloadBytes, e.route)
@@ -321,7 +319,7 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, *tc
// goes through a different NIC than the endpoint was bound to.
nic := opts.To.NIC
if e.bound && nic != 0 && nic != e.BindNICID {
- return 0, tcpip.ErrNoRoute
+ return 0, &tcpip.ErrNoRoute{}
}
// Find the route to the destination. If BindAddress is 0,
@@ -338,7 +336,7 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, *tc
// finishWrite writes the payload to a route. It resolves the route if
// necessary. It's really just a helper to make defer unnecessary in Write.
-func (e *endpoint) finishWrite(payloadBytes []byte, route *stack.Route) (int64, *tcpip.Error) {
+func (e *endpoint) finishWrite(payloadBytes []byte, route *stack.Route) (int64, tcpip.Error) {
if e.ops.GetHeaderIncluded() {
pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
Data: buffer.View(payloadBytes).ToVectorisedView(),
@@ -365,22 +363,22 @@ func (e *endpoint) finishWrite(payloadBytes []byte, route *stack.Route) (int64,
}
// Disconnect implements tcpip.Endpoint.Disconnect.
-func (*endpoint) Disconnect() *tcpip.Error {
- return tcpip.ErrNotSupported
+func (*endpoint) Disconnect() tcpip.Error {
+ return &tcpip.ErrNotSupported{}
}
// Connect implements tcpip.Endpoint.Connect.
-func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
+func (e *endpoint) Connect(addr tcpip.FullAddress) tcpip.Error {
// Raw sockets do not support connecting to a IPv4 address on a IPv6 endpoint.
if e.TransportEndpointInfo.NetProto == header.IPv6ProtocolNumber && len(addr.Addr) != header.IPv6AddressSize {
- return tcpip.ErrAddressFamilyNotSupported
+ return &tcpip.ErrAddressFamilyNotSupported{}
}
e.mu.Lock()
defer e.mu.Unlock()
if e.closed {
- return tcpip.ErrInvalidEndpointState
+ return &tcpip.ErrInvalidEndpointState{}
}
nic := addr.NIC
@@ -395,7 +393,7 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
} else if addr.NIC != e.BindNICID {
// We're bound and addr specifies a NIC. They must be
// the same.
- return tcpip.ErrInvalidEndpointState
+ return &tcpip.ErrInvalidEndpointState{}
}
}
@@ -407,15 +405,18 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
if e.associated {
// Re-register the endpoint with the appropriate NIC.
- if err := e.stack.RegisterRawTransportEndpoint(addr.NIC, e.NetProto, e.TransProto, e); err != nil {
+ if err := e.stack.RegisterRawTransportEndpoint(e.NetProto, e.TransProto, e); err != nil {
route.Release()
return err
}
- e.stack.UnregisterRawTransportEndpoint(e.RegisterNICID, e.NetProto, e.TransProto, e)
+ e.stack.UnregisterRawTransportEndpoint(e.NetProto, e.TransProto, e)
e.RegisterNICID = nic
}
- // Save the route we've connected via.
+ if e.route != nil {
+ // If the endpoint was previously connected then release any previous route.
+ e.route.Release()
+ }
e.route = route
e.connected = true
@@ -423,42 +424,42 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
}
// Shutdown implements tcpip.Endpoint.Shutdown. It's a noop for raw sockets.
-func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error {
+func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) tcpip.Error {
e.mu.Lock()
defer e.mu.Unlock()
if !e.connected {
- return tcpip.ErrNotConnected
+ return &tcpip.ErrNotConnected{}
}
return nil
}
// Listen implements tcpip.Endpoint.Listen.
-func (*endpoint) Listen(backlog int) *tcpip.Error {
- return tcpip.ErrNotSupported
+func (*endpoint) Listen(backlog int) tcpip.Error {
+ return &tcpip.ErrNotSupported{}
}
// Accept implements tcpip.Endpoint.Accept.
-func (*endpoint) Accept(*tcpip.FullAddress) (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) {
- return nil, nil, tcpip.ErrNotSupported
+func (*endpoint) Accept(*tcpip.FullAddress) (tcpip.Endpoint, *waiter.Queue, tcpip.Error) {
+ return nil, nil, &tcpip.ErrNotSupported{}
}
// Bind implements tcpip.Endpoint.Bind.
-func (e *endpoint) Bind(addr tcpip.FullAddress) *tcpip.Error {
+func (e *endpoint) Bind(addr tcpip.FullAddress) tcpip.Error {
e.mu.Lock()
defer e.mu.Unlock()
// If a local address was specified, verify that it's valid.
- if len(addr.Addr) != 0 && e.stack.CheckLocalAddress(addr.NIC, e.NetProto, addr.Addr) == 0 {
- return tcpip.ErrBadLocalAddress
+ if len(addr.Addr) != 0 && e.stack.CheckLocalAddress(e.RegisterNICID, e.NetProto, addr.Addr) == 0 {
+ return &tcpip.ErrBadLocalAddress{}
}
if e.associated {
// Re-register the endpoint with the appropriate NIC.
- if err := e.stack.RegisterRawTransportEndpoint(addr.NIC, e.NetProto, e.TransProto, e); err != nil {
+ if err := e.stack.RegisterRawTransportEndpoint(e.NetProto, e.TransProto, e); err != nil {
return err
}
- e.stack.UnregisterRawTransportEndpoint(e.RegisterNICID, e.NetProto, e.TransProto, e)
+ e.stack.UnregisterRawTransportEndpoint(e.NetProto, e.TransProto, e)
e.RegisterNICID = addr.NIC
e.BindNICID = addr.NIC
}
@@ -470,14 +471,14 @@ func (e *endpoint) Bind(addr tcpip.FullAddress) *tcpip.Error {
}
// GetLocalAddress implements tcpip.Endpoint.GetLocalAddress.
-func (*endpoint) GetLocalAddress() (tcpip.FullAddress, *tcpip.Error) {
- return tcpip.FullAddress{}, tcpip.ErrNotSupported
+func (*endpoint) GetLocalAddress() (tcpip.FullAddress, tcpip.Error) {
+ return tcpip.FullAddress{}, &tcpip.ErrNotSupported{}
}
// GetRemoteAddress implements tcpip.Endpoint.GetRemoteAddress.
-func (*endpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Error) {
+func (*endpoint) GetRemoteAddress() (tcpip.FullAddress, tcpip.Error) {
// Even a connected socket doesn't return a remote address.
- return tcpip.FullAddress{}, tcpip.ErrNotConnected
+ return tcpip.FullAddress{}, &tcpip.ErrNotConnected{}
}
// Readiness implements tcpip.Endpoint.Readiness.
@@ -498,37 +499,19 @@ func (e *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask {
}
// SetSockOpt implements tcpip.Endpoint.SetSockOpt.
-func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) *tcpip.Error {
+func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) tcpip.Error {
switch opt.(type) {
case *tcpip.SocketDetachFilterOption:
return nil
default:
- return tcpip.ErrUnknownProtocolOption
+ return &tcpip.ErrUnknownProtocolOption{}
}
}
// SetSockOptInt implements tcpip.Endpoint.SetSockOptInt.
-func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error {
+func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) tcpip.Error {
switch opt {
- case tcpip.SendBufferSizeOption:
- // Make sure the send buffer size is within the min and max
- // allowed.
- var ss stack.SendBufferSizeOption
- if err := e.stack.Option(&ss); err != nil {
- panic(fmt.Sprintf("s.Option(%#v) = %s", ss, err))
- }
- if v > ss.Max {
- v = ss.Max
- }
- if v < ss.Min {
- v = ss.Min
- }
- e.mu.Lock()
- e.sndBufSizeMax = v
- e.mu.Unlock()
- return nil
-
case tcpip.ReceiveBufferSizeOption:
// Make sure the receive buffer size is within the min and max
// allowed.
@@ -548,17 +531,17 @@ func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error {
return nil
default:
- return tcpip.ErrUnknownProtocolOption
+ return &tcpip.ErrUnknownProtocolOption{}
}
}
// GetSockOpt implements tcpip.Endpoint.GetSockOpt.
-func (e *endpoint) GetSockOpt(opt tcpip.GettableSocketOption) *tcpip.Error {
- return tcpip.ErrUnknownProtocolOption
+func (e *endpoint) GetSockOpt(opt tcpip.GettableSocketOption) tcpip.Error {
+ return &tcpip.ErrUnknownProtocolOption{}
}
// GetSockOptInt implements tcpip.Endpoint.GetSockOptInt.
-func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) {
+func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, tcpip.Error) {
switch opt {
case tcpip.ReceiveQueueSizeOption:
v := 0
@@ -570,12 +553,6 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) {
e.rcvMu.Unlock()
return v, nil
- case tcpip.SendBufferSizeOption:
- e.mu.Lock()
- v := e.sndBufSizeMax
- e.mu.Unlock()
- return v, nil
-
case tcpip.ReceiveBufferSizeOption:
e.rcvMu.Lock()
v := e.rcvBufSizeMax
@@ -583,7 +560,7 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) {
return v, nil
default:
- return -1, tcpip.ErrUnknownProtocolOption
+ return -1, &tcpip.ErrUnknownProtocolOption{}
}
}
@@ -703,7 +680,7 @@ func (e *endpoint) Stats() tcpip.EndpointStats {
func (*endpoint) Wait() {}
// LastError implements tcpip.Endpoint.LastError.
-func (*endpoint) LastError() *tcpip.Error {
+func (*endpoint) LastError() tcpip.Error {
return nil
}
diff --git a/pkg/tcpip/transport/raw/endpoint_state.go b/pkg/tcpip/transport/raw/endpoint_state.go
index 4a7e1c039..263ec5146 100644
--- a/pkg/tcpip/transport/raw/endpoint_state.go
+++ b/pkg/tcpip/transport/raw/endpoint_state.go
@@ -69,10 +69,11 @@ func (e *endpoint) afterLoad() {
// Resume implements tcpip.ResumableEndpoint.Resume.
func (e *endpoint) Resume(s *stack.Stack) {
e.stack = s
+ e.ops.InitHandler(e, e.stack, tcpip.GetStackSendBufferLimits)
// If the endpoint is connected, re-connect.
if e.connected {
- var err *tcpip.Error
+ var err tcpip.Error
// TODO(gvisor.dev/issue/4906): Properly restore the route with the right
// remote address. We used to pass e.remote.RemoteAddress which was
// effectively the empty address but since moving e.route to hold a pointer
@@ -88,12 +89,12 @@ func (e *endpoint) Resume(s *stack.Stack) {
// If the endpoint is bound, re-bind.
if e.bound {
if e.stack.CheckLocalAddress(e.RegisterNICID, e.NetProto, e.BindAddr) == 0 {
- panic(tcpip.ErrBadLocalAddress)
+ panic(&tcpip.ErrBadLocalAddress{})
}
}
if e.associated {
- if err := e.stack.RegisterRawTransportEndpoint(e.RegisterNICID, e.NetProto, e.TransProto, e); err != nil {
+ if err := e.stack.RegisterRawTransportEndpoint(e.NetProto, e.TransProto, e); err != nil {
panic(err)
}
}
diff --git a/pkg/tcpip/transport/raw/protocol.go b/pkg/tcpip/transport/raw/protocol.go
index f30aa2a4a..e393b993d 100644
--- a/pkg/tcpip/transport/raw/protocol.go
+++ b/pkg/tcpip/transport/raw/protocol.go
@@ -25,11 +25,11 @@ import (
type EndpointFactory struct{}
// NewUnassociatedEndpoint implements stack.RawFactory.NewUnassociatedEndpoint.
-func (EndpointFactory) NewUnassociatedEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
+func (EndpointFactory) NewUnassociatedEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, tcpip.Error) {
return newEndpoint(stack, netProto, transProto, waiterQueue, false /* associated */)
}
// NewPacketEndpoint implements stack.RawFactory.NewPacketEndpoint.
-func (EndpointFactory) NewPacketEndpoint(stack *stack.Stack, cooked bool, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
+func (EndpointFactory) NewPacketEndpoint(stack *stack.Stack, cooked bool, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, tcpip.Error) {
return packet.NewEndpoint(stack, cooked, netProto, waiterQueue)
}
diff --git a/pkg/tcpip/transport/tcp/BUILD b/pkg/tcpip/transport/tcp/BUILD
index 7e81203ba..fcdd032c5 100644
--- a/pkg/tcpip/transport/tcp/BUILD
+++ b/pkg/tcpip/transport/tcp/BUILD
@@ -99,7 +99,6 @@ go_test(
"//pkg/rand",
"//pkg/sync",
"//pkg/tcpip",
- "//pkg/tcpip/buffer",
"//pkg/tcpip/checker",
"//pkg/tcpip/header",
"//pkg/tcpip/link/loopback",
diff --git a/pkg/tcpip/transport/tcp/accept.go b/pkg/tcpip/transport/tcp/accept.go
index 6921de0f1..842c1622b 100644
--- a/pkg/tcpip/transport/tcp/accept.go
+++ b/pkg/tcpip/transport/tcp/accept.go
@@ -199,7 +199,7 @@ func (l *listenContext) isCookieValid(id stack.TransportEndpointID, cookie seqnu
// createConnectingEndpoint creates a new endpoint in a connecting state, with
// the connection parameters given by the arguments.
-func (l *listenContext) createConnectingEndpoint(s *segment, iss seqnum.Value, irs seqnum.Value, rcvdSynOpts *header.TCPSynOptions, queue *waiter.Queue) (*endpoint, *tcpip.Error) {
+func (l *listenContext) createConnectingEndpoint(s *segment, iss seqnum.Value, irs seqnum.Value, rcvdSynOpts *header.TCPSynOptions, queue *waiter.Queue) (*endpoint, tcpip.Error) {
// Create a new endpoint.
netProto := l.netProto
if netProto == 0 {
@@ -242,7 +242,7 @@ func (l *listenContext) createConnectingEndpoint(s *segment, iss seqnum.Value, i
// On success, a handshake h is returned with h.ep.mu held.
//
// Precondition: if l.listenEP != nil, l.listenEP.mu must be locked.
-func (l *listenContext) startHandshake(s *segment, opts *header.TCPSynOptions, queue *waiter.Queue, owner tcpip.PacketOwner) (*handshake, *tcpip.Error) {
+func (l *listenContext) startHandshake(s *segment, opts *header.TCPSynOptions, queue *waiter.Queue, owner tcpip.PacketOwner) (*handshake, tcpip.Error) {
// Create new endpoint.
irs := s.sequenceNumber
isn := generateSecureISN(s.id, l.stack.Seed())
@@ -267,7 +267,7 @@ func (l *listenContext) startHandshake(s *segment, opts *header.TCPSynOptions, q
ep.mu.Unlock()
ep.Close()
- return nil, tcpip.ErrConnectionAborted
+ return nil, &tcpip.ErrConnectionAborted{}
}
l.addPendingEndpoint(ep)
@@ -281,14 +281,14 @@ func (l *listenContext) startHandshake(s *segment, opts *header.TCPSynOptions, q
l.removePendingEndpoint(ep)
- return nil, tcpip.ErrConnectionAborted
+ return nil, &tcpip.ErrConnectionAborted{}
}
deferAccept = l.listenEP.deferAccept
}
// Register new endpoint so that packets are routed to it.
- if err := ep.stack.RegisterTransportEndpoint(ep.boundNICID, ep.effectiveNetProtos, ProtocolNumber, ep.ID, ep, ep.boundPortFlags, ep.boundBindToDevice); err != nil {
+ if err := ep.stack.RegisterTransportEndpoint(ep.effectiveNetProtos, ProtocolNumber, ep.ID, ep, ep.boundPortFlags, ep.boundBindToDevice); err != nil {
ep.mu.Unlock()
ep.Close()
@@ -313,7 +313,7 @@ func (l *listenContext) startHandshake(s *segment, opts *header.TCPSynOptions, q
// established endpoint is returned with e.mu held.
//
// Precondition: if l.listenEP != nil, l.listenEP.mu must be locked.
-func (l *listenContext) performHandshake(s *segment, opts *header.TCPSynOptions, queue *waiter.Queue, owner tcpip.PacketOwner) (*endpoint, *tcpip.Error) {
+func (l *listenContext) performHandshake(s *segment, opts *header.TCPSynOptions, queue *waiter.Queue, owner tcpip.PacketOwner) (*endpoint, tcpip.Error) {
h, err := l.startHandshake(s, opts, queue, owner)
if err != nil {
return nil, err
@@ -467,7 +467,7 @@ func (e *endpoint) notifyAborted() {
// cookies to accept connections.
//
// Precondition: if ctx.listenEP != nil, ctx.listenEP.mu must be locked.
-func (e *endpoint) handleSynSegment(ctx *listenContext, s *segment, opts *header.TCPSynOptions) *tcpip.Error {
+func (e *endpoint) handleSynSegment(ctx *listenContext, s *segment, opts *header.TCPSynOptions) tcpip.Error {
defer s.decRef()
h, err := ctx.startHandshake(s, opts, &waiter.Queue{}, e.owner)
@@ -522,7 +522,7 @@ func (e *endpoint) acceptQueueIsFull() bool {
// and needs to handle it.
//
// Precondition: if ctx.listenEP != nil, ctx.listenEP.mu must be locked.
-func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) *tcpip.Error {
+func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) tcpip.Error {
e.rcvListMu.Lock()
rcvClosed := e.rcvClosed
e.rcvListMu.Unlock()
@@ -692,7 +692,7 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) *tcpip.Er
}
// Register new endpoint so that packets are routed to it.
- if err := n.stack.RegisterTransportEndpoint(n.boundNICID, n.effectiveNetProtos, ProtocolNumber, n.ID, n, n.boundPortFlags, n.boundBindToDevice); err != nil {
+ if err := n.stack.RegisterTransportEndpoint(n.effectiveNetProtos, ProtocolNumber, n.ID, n, n.boundPortFlags, n.boundBindToDevice); err != nil {
n.mu.Unlock()
n.Close()
diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go
index f711cd4df..4695b66d6 100644
--- a/pkg/tcpip/transport/tcp/connect.go
+++ b/pkg/tcpip/transport/tcp/connect.go
@@ -226,7 +226,7 @@ func (h *handshake) checkAck(s *segment) bool {
// synSentState handles a segment received when the TCP 3-way handshake is in
// the SYN-SENT state.
-func (h *handshake) synSentState(s *segment) *tcpip.Error {
+func (h *handshake) synSentState(s *segment) tcpip.Error {
// RFC 793, page 37, states that in the SYN-SENT state, a reset is
// acceptable if the ack field acknowledges the SYN.
if s.flagIsSet(header.TCPFlagRst) {
@@ -237,7 +237,7 @@ func (h *handshake) synSentState(s *segment) *tcpip.Error {
h.ep.workerCleanup = true
// Although the RFC above calls out ECONNRESET, Linux actually returns
// ECONNREFUSED here so we do as well.
- return tcpip.ErrConnectionRefused
+ return &tcpip.ErrConnectionRefused{}
}
return nil
}
@@ -314,12 +314,12 @@ func (h *handshake) synSentState(s *segment) *tcpip.Error {
// synRcvdState handles a segment received when the TCP 3-way handshake is in
// the SYN-RCVD state.
-func (h *handshake) synRcvdState(s *segment) *tcpip.Error {
+func (h *handshake) synRcvdState(s *segment) tcpip.Error {
if s.flagIsSet(header.TCPFlagRst) {
// RFC 793, page 37, states that in the SYN-RCVD state, a reset
// is acceptable if the sequence number is in the window.
if s.sequenceNumber.InWindow(h.ackNum, h.rcvWnd) {
- return tcpip.ErrConnectionRefused
+ return &tcpip.ErrConnectionRefused{}
}
return nil
}
@@ -349,7 +349,7 @@ func (h *handshake) synRcvdState(s *segment) *tcpip.Error {
h.ep.sendRaw(buffer.VectorisedView{}, header.TCPFlagRst|header.TCPFlagAck, seq, ack, 0)
if !h.active {
- return tcpip.ErrInvalidEndpointState
+ return &tcpip.ErrInvalidEndpointState{}
}
h.resetState()
@@ -412,7 +412,7 @@ func (h *handshake) synRcvdState(s *segment) *tcpip.Error {
return nil
}
-func (h *handshake) handleSegment(s *segment) *tcpip.Error {
+func (h *handshake) handleSegment(s *segment) tcpip.Error {
h.sndWnd = s.window
if !s.flagIsSet(header.TCPFlagSyn) && h.sndWndScale > 0 {
h.sndWnd <<= uint8(h.sndWndScale)
@@ -429,7 +429,7 @@ func (h *handshake) handleSegment(s *segment) *tcpip.Error {
// processSegments goes through the segment queue and processes up to
// maxSegmentsPerWake (if they're available).
-func (h *handshake) processSegments() *tcpip.Error {
+func (h *handshake) processSegments() tcpip.Error {
for i := 0; i < maxSegmentsPerWake; i++ {
s := h.ep.segmentQueue.dequeue()
if s == nil {
@@ -505,7 +505,7 @@ func (h *handshake) start() {
}
// complete completes the TCP 3-way handshake initiated by h.start().
-func (h *handshake) complete() *tcpip.Error {
+func (h *handshake) complete() tcpip.Error {
// Set up the wakers.
var s sleep.Sleeper
resendWaker := sleep.Waker{}
@@ -555,7 +555,7 @@ func (h *handshake) complete() *tcpip.Error {
case wakerForNotification:
n := h.ep.fetchNotifications()
if (n&notifyClose)|(n&notifyAbort) != 0 {
- return tcpip.ErrAborted
+ return &tcpip.ErrAborted{}
}
if n&notifyDrain != 0 {
for !h.ep.segmentQueue.empty() {
@@ -593,19 +593,19 @@ type backoffTimer struct {
t *time.Timer
}
-func newBackoffTimer(timeout, maxTimeout time.Duration, f func()) (*backoffTimer, *tcpip.Error) {
+func newBackoffTimer(timeout, maxTimeout time.Duration, f func()) (*backoffTimer, tcpip.Error) {
if timeout > maxTimeout {
- return nil, tcpip.ErrTimeout
+ return nil, &tcpip.ErrTimeout{}
}
bt := &backoffTimer{timeout: timeout, maxTimeout: maxTimeout}
bt.t = time.AfterFunc(timeout, f)
return bt, nil
}
-func (bt *backoffTimer) reset() *tcpip.Error {
+func (bt *backoffTimer) reset() tcpip.Error {
bt.timeout *= 2
if bt.timeout > MaxRTO {
- return tcpip.ErrTimeout
+ return &tcpip.ErrTimeout{}
}
bt.t.Reset(bt.timeout)
return nil
@@ -706,7 +706,7 @@ type tcpFields struct {
txHash uint32
}
-func (e *endpoint) sendSynTCP(r *stack.Route, tf tcpFields, opts header.TCPSynOptions) *tcpip.Error {
+func (e *endpoint) sendSynTCP(r *stack.Route, tf tcpFields, opts header.TCPSynOptions) tcpip.Error {
tf.opts = makeSynOptions(opts)
// We ignore SYN send errors and let the callers re-attempt send.
if err := e.sendTCP(r, tf, buffer.VectorisedView{}, nil); err != nil {
@@ -716,7 +716,7 @@ func (e *endpoint) sendSynTCP(r *stack.Route, tf tcpFields, opts header.TCPSynOp
return nil
}
-func (e *endpoint) sendTCP(r *stack.Route, tf tcpFields, data buffer.VectorisedView, gso *stack.GSO) *tcpip.Error {
+func (e *endpoint) sendTCP(r *stack.Route, tf tcpFields, data buffer.VectorisedView, gso *stack.GSO) tcpip.Error {
tf.txHash = e.txHash
if err := sendTCP(r, tf, data, gso, e.owner); err != nil {
e.stats.SendErrors.SegmentSendToNetworkFailed.Increment()
@@ -755,7 +755,7 @@ func buildTCPHdr(r *stack.Route, tf tcpFields, pkt *stack.PacketBuffer, gso *sta
}
}
-func sendTCPBatch(r *stack.Route, tf tcpFields, data buffer.VectorisedView, gso *stack.GSO, owner tcpip.PacketOwner) *tcpip.Error {
+func sendTCPBatch(r *stack.Route, tf tcpFields, data buffer.VectorisedView, gso *stack.GSO, owner tcpip.PacketOwner) tcpip.Error {
// We need to shallow clone the VectorisedView here as ReadToView will
// split the VectorisedView and Trim underlying views as it splits. Not
// doing the clone here will cause the underlying views of data itself
@@ -803,7 +803,7 @@ func sendTCPBatch(r *stack.Route, tf tcpFields, data buffer.VectorisedView, gso
// sendTCP sends a TCP segment with the provided options via the provided
// network endpoint and under the provided identity.
-func sendTCP(r *stack.Route, tf tcpFields, data buffer.VectorisedView, gso *stack.GSO, owner tcpip.PacketOwner) *tcpip.Error {
+func sendTCP(r *stack.Route, tf tcpFields, data buffer.VectorisedView, gso *stack.GSO, owner tcpip.PacketOwner) tcpip.Error {
optLen := len(tf.opts)
if tf.rcvWnd > math.MaxUint16 {
tf.rcvWnd = math.MaxUint16
@@ -875,7 +875,7 @@ func (e *endpoint) makeOptions(sackBlocks []header.SACKBlock) []byte {
}
// sendRaw sends a TCP segment to the endpoint's peer.
-func (e *endpoint) sendRaw(data buffer.VectorisedView, flags byte, seq, ack seqnum.Value, rcvWnd seqnum.Size) *tcpip.Error {
+func (e *endpoint) sendRaw(data buffer.VectorisedView, flags byte, seq, ack seqnum.Value, rcvWnd seqnum.Size) tcpip.Error {
var sackBlocks []header.SACKBlock
if e.EndpointState() == StateEstablished && e.rcv.pendingRcvdSegments.Len() > 0 && (flags&header.TCPFlagAck != 0) {
sackBlocks = e.sack.Blocks[:e.sack.NumBlocks]
@@ -895,55 +895,60 @@ func (e *endpoint) sendRaw(data buffer.VectorisedView, flags byte, seq, ack seqn
return err
}
-func (e *endpoint) handleWrite() *tcpip.Error {
- // Move packets from send queue to send list. The queue is accessible
- // from other goroutines and protected by the send mutex, while the send
- // list is only accessible from the handler goroutine, so it needs no
- // mutexes.
+func (e *endpoint) handleWrite() {
e.sndBufMu.Lock()
+ next := e.drainSendQueueLocked()
+ e.sndBufMu.Unlock()
+
+ e.sendData(next)
+}
+// Move packets from send queue to send list.
+//
+// Precondition: e.sndBufMu must be locked.
+func (e *endpoint) drainSendQueueLocked() *segment {
first := e.sndQueue.Front()
if first != nil {
e.snd.writeList.PushBackList(&e.sndQueue)
e.sndBufInQueue = 0
}
+ return first
+}
- e.sndBufMu.Unlock()
-
+// Precondition: e.mu must be locked.
+func (e *endpoint) sendData(next *segment) {
// Initialize the next segment to write if it's currently nil.
if e.snd.writeNext == nil {
- e.snd.writeNext = first
+ e.snd.writeNext = next
}
// Push out any new packets.
e.snd.sendData()
-
- return nil
}
-func (e *endpoint) handleClose() *tcpip.Error {
+func (e *endpoint) handleClose() {
if !e.EndpointState().connected() {
- return nil
+ return
}
// Drain the send queue.
e.handleWrite()
// Mark send side as closed.
e.snd.closed = true
-
- return nil
}
// resetConnectionLocked puts the endpoint in an error state with the given
// error code and sends a RST if and only if the error is not ErrConnectionReset
// indicating that the connection is being reset due to receiving a RST. This
// method must only be called from the protocol goroutine.
-func (e *endpoint) resetConnectionLocked(err *tcpip.Error) {
+func (e *endpoint) resetConnectionLocked(err tcpip.Error) {
// Only send a reset if the connection is being aborted for a reason
// other than receiving a reset.
e.setEndpointState(StateError)
e.hardError = err
- if err != tcpip.ErrConnectionReset && err != tcpip.ErrTimeout {
+ switch err.(type) {
+ case *tcpip.ErrConnectionReset, *tcpip.ErrTimeout:
+ default:
// The exact sequence number to be used for the RST is the same as the
// one used by Linux. We need to handle the case of window being shrunk
// which can cause sndNxt to be outside the acceptable window on the
@@ -1053,7 +1058,7 @@ func (e *endpoint) drainClosingSegmentQueue() {
}
}
-func (e *endpoint) handleReset(s *segment) (ok bool, err *tcpip.Error) {
+func (e *endpoint) handleReset(s *segment) (ok bool, err tcpip.Error) {
if e.rcv.acceptable(s.sequenceNumber, 0) {
// RFC 793, page 37 states that "in all states
// except SYN-SENT, all reset (RST) segments are
@@ -1081,7 +1086,7 @@ func (e *endpoint) handleReset(s *segment) (ok bool, err *tcpip.Error) {
// delete the TCB, and return.
case StateCloseWait:
e.transitionToStateCloseLocked()
- e.hardError = tcpip.ErrAborted
+ e.hardError = &tcpip.ErrAborted{}
e.notifyProtocolGoroutine(notifyTickleWorker)
return false, nil
default:
@@ -1094,14 +1099,14 @@ func (e *endpoint) handleReset(s *segment) (ok bool, err *tcpip.Error) {
// handleSegment is invoked from the processor goroutine
// rather than the worker goroutine.
e.notifyProtocolGoroutine(notifyResetByPeer)
- return false, tcpip.ErrConnectionReset
+ return false, &tcpip.ErrConnectionReset{}
}
}
return true, nil
}
// handleSegments processes all inbound segments.
-func (e *endpoint) handleSegments(fastPath bool) *tcpip.Error {
+func (e *endpoint) handleSegments(fastPath bool) tcpip.Error {
checkRequeue := true
for i := 0; i < maxSegmentsPerWake; i++ {
if e.EndpointState().closed() {
@@ -1148,7 +1153,7 @@ func (e *endpoint) probeSegment() {
// handleSegment handles a given segment and notifies the worker goroutine if
// if the connection should be terminated.
-func (e *endpoint) handleSegment(s *segment) (cont bool, err *tcpip.Error) {
+func (e *endpoint) handleSegment(s *segment) (cont bool, err tcpip.Error) {
// Invoke the tcp probe if installed. The tcp probe function will update
// the TCPEndpointState after the segment is processed.
defer e.probeSegment()
@@ -1222,7 +1227,7 @@ func (e *endpoint) handleSegment(s *segment) (cont bool, err *tcpip.Error) {
// keepaliveTimerExpired is called when the keepaliveTimer fires. We send TCP
// keepalive packets periodically when the connection is idle. If we don't hear
// from the other side after a number of tries, we terminate the connection.
-func (e *endpoint) keepaliveTimerExpired() *tcpip.Error {
+func (e *endpoint) keepaliveTimerExpired() tcpip.Error {
userTimeout := e.userTimeout
e.keepalive.Lock()
@@ -1236,13 +1241,13 @@ func (e *endpoint) keepaliveTimerExpired() *tcpip.Error {
if userTimeout != 0 && time.Since(e.rcv.lastRcvdAckTime) >= userTimeout && e.keepalive.unacked > 0 {
e.keepalive.Unlock()
e.stack.Stats().TCP.EstablishedTimedout.Increment()
- return tcpip.ErrTimeout
+ return &tcpip.ErrTimeout{}
}
if e.keepalive.unacked >= e.keepalive.count {
e.keepalive.Unlock()
e.stack.Stats().TCP.EstablishedTimedout.Increment()
- return tcpip.ErrTimeout
+ return &tcpip.ErrTimeout{}
}
// RFC1122 4.2.3.6: TCP keepalive is a dataless ACK with
@@ -1286,7 +1291,7 @@ func (e *endpoint) disableKeepaliveTimer() {
// protocolMainLoop is the main loop of the TCP protocol. It runs in its own
// goroutine and is responsible for sending segments and handling received
// segments.
-func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{}) *tcpip.Error {
+func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{}) tcpip.Error {
e.mu.Lock()
var closeTimer *time.Timer
var closeWaker sleep.Waker
@@ -1332,6 +1337,14 @@ func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{
}
}
+ // Reaching this point means that we successfully completed the 3-way
+ // handshake with our peer.
+ //
+ // Completing the 3-way handshake is an indication that the route is valid
+ // and the remote is reachable as the only way we can complete a handshake
+ // is if our SYN reached the remote and their ACK reached us.
+ e.route.ConfirmReachable()
+
drained := e.drainDone != nil
if drained {
close(e.drainDone)
@@ -1344,19 +1357,25 @@ func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{
// wakes up.
funcs := []struct {
w *sleep.Waker
- f func() *tcpip.Error
+ f func() tcpip.Error
}{
{
w: &e.sndWaker,
- f: e.handleWrite,
+ f: func() tcpip.Error {
+ e.handleWrite()
+ return nil
+ },
},
{
w: &e.sndCloseWaker,
- f: e.handleClose,
+ f: func() tcpip.Error {
+ e.handleClose()
+ return nil
+ },
},
{
w: &closeWaker,
- f: func() *tcpip.Error {
+ f: func() tcpip.Error {
// This means the socket is being closed due
// to the TCP-FIN-WAIT2 timeout was hit. Just
// mark the socket as closed.
@@ -1367,10 +1386,10 @@ func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{
},
{
w: &e.snd.resendWaker,
- f: func() *tcpip.Error {
+ f: func() tcpip.Error {
if !e.snd.retransmitTimerExpired() {
e.stack.Stats().TCP.EstablishedTimedout.Increment()
- return tcpip.ErrTimeout
+ return &tcpip.ErrTimeout{}
}
return nil
},
@@ -1381,7 +1400,7 @@ func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{
},
{
w: &e.newSegmentWaker,
- f: func() *tcpip.Error {
+ f: func() tcpip.Error {
return e.handleSegments(false /* fastPath */)
},
},
@@ -1391,7 +1410,7 @@ func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{
},
{
w: &e.notificationWaker,
- f: func() *tcpip.Error {
+ f: func() tcpip.Error {
n := e.fetchNotifications()
if n&notifyNonZeroReceiveWindow != 0 {
e.rcv.nonZeroWindow()
@@ -1408,11 +1427,11 @@ func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{
}
if n&notifyReset != 0 || n&notifyAbort != 0 {
- return tcpip.ErrConnectionAborted
+ return &tcpip.ErrConnectionAborted{}
}
if n&notifyResetByPeer != 0 {
- return tcpip.ErrConnectionReset
+ return &tcpip.ErrConnectionReset{}
}
if n&notifyClose != 0 && closeTimer == nil {
@@ -1491,7 +1510,7 @@ func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{
// Main loop. Handle segments until both send and receive ends of the
// connection have completed.
- cleanupOnError := func(err *tcpip.Error) {
+ cleanupOnError := func(err tcpip.Error) {
e.stack.Stats().TCP.CurrentConnected.Decrement()
e.workerCleanup = true
if err != nil {
diff --git a/pkg/tcpip/transport/tcp/dual_stack_test.go b/pkg/tcpip/transport/tcp/dual_stack_test.go
index 1d1b01a6c..2d90246e4 100644
--- a/pkg/tcpip/transport/tcp/dual_stack_test.go
+++ b/pkg/tcpip/transport/tcp/dual_stack_test.go
@@ -15,11 +15,11 @@
package tcp_test
import (
+ "strings"
"testing"
"time"
"gvisor.dev/gvisor/pkg/tcpip"
- "gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/checker"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
@@ -37,7 +37,7 @@ func TestV4MappedConnectOnV6Only(t *testing.T) {
// Start connection attempt, it must fail.
err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestV4MappedAddr, Port: context.TestPort})
- if err != tcpip.ErrNoRoute {
+ if _, ok := err.(*tcpip.ErrNoRoute); !ok {
t.Fatalf("Unexpected return value from Connect: %v", err)
}
}
@@ -49,7 +49,7 @@ func testV4Connect(t *testing.T, c *context.Context, checkers ...checker.Network
defer c.WQ.EventUnregister(&we)
err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestV4MappedAddr, Port: context.TestPort})
- if err != tcpip.ErrConnectStarted {
+ if _, ok := err.(*tcpip.ErrConnectStarted); !ok {
t.Fatalf("Unexpected return value from Connect: %v", err)
}
@@ -156,7 +156,7 @@ func testV6Connect(t *testing.T, c *context.Context, checkers ...checker.Network
defer c.WQ.EventUnregister(&we)
err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestV6Addr, Port: context.TestPort})
- if err != tcpip.ErrConnectStarted {
+ if _, ok := err.(*tcpip.ErrConnectStarted); !ok {
t.Fatalf("Unexpected return value from Connect: %v", err)
}
@@ -391,7 +391,7 @@ func testV4Accept(t *testing.T, c *context.Context) {
defer c.WQ.EventUnregister(&we)
nep, _, err := c.EP.Accept(nil)
- if err == tcpip.ErrWouldBlock {
+ if _, ok := err.(*tcpip.ErrWouldBlock); ok {
// Wait for connection to be established.
select {
case <-ch:
@@ -415,8 +415,10 @@ func testV4Accept(t *testing.T, c *context.Context) {
t.Fatalf("Unexpected remote address: got %v, want %v", addr.Addr, context.TestAddr)
}
+ var r strings.Reader
data := "Don't panic"
- nep.Write(tcpip.SlicePayload(buffer.NewViewFromBytes([]byte(data))), tcpip.WriteOptions{})
+ r.Reset(data)
+ nep.Write(&r, tcpip.WriteOptions{})
b = c.GetPacket()
tcp = header.TCP(header.IPv4(b).Payload())
if string(tcp.Payload()) != data {
@@ -523,7 +525,7 @@ func TestV6AcceptOnV6(t *testing.T) {
defer c.WQ.EventUnregister(&we)
var addr tcpip.FullAddress
_, _, err := c.EP.Accept(&addr)
- if err == tcpip.ErrWouldBlock {
+ if _, ok := err.(*tcpip.ErrWouldBlock); ok {
// Wait for connection to be established.
select {
case <-ch:
@@ -547,7 +549,7 @@ func TestV4AcceptOnV4(t *testing.T) {
defer c.Cleanup()
// Create TCP endpoint.
- var err *tcpip.Error
+ var err tcpip.Error
c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ)
if err != nil {
t.Fatalf("NewEndpoint failed: %v", err)
@@ -611,7 +613,7 @@ func testV4ListenClose(t *testing.T, c *context.Context) {
c.WQ.EventRegister(&we, waiter.EventIn)
defer c.WQ.EventUnregister(&we)
nep, _, err := c.EP.Accept(nil)
- if err == tcpip.ErrWouldBlock {
+ if _, ok := err.(*tcpip.ErrWouldBlock); ok {
// Wait for connection to be established.
select {
case <-ch:
@@ -633,7 +635,7 @@ func TestV4ListenCloseOnV4(t *testing.T) {
defer c.Cleanup()
// Create TCP endpoint.
- var err *tcpip.Error
+ var err tcpip.Error
c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ)
if err != nil {
t.Fatalf("NewEndpoint failed: %v", err)
diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go
index ea509ac73..6e4e26c39 100644
--- a/pkg/tcpip/transport/tcp/endpoint.go
+++ b/pkg/tcpip/transport/tcp/endpoint.go
@@ -386,12 +386,12 @@ type endpoint struct {
// hardError is meaningful only when state is stateError. It stores the
// error to be returned when read/write syscalls are called and the
// endpoint is in this state. hardError is protected by endpoint mu.
- hardError *tcpip.Error `state:".(string)"`
+ hardError tcpip.Error
// lastError represents the last error that the endpoint reported;
// access to it is protected by the following mutex.
- lastErrorMu sync.Mutex `state:"nosave"`
- lastError *tcpip.Error `state:".(string)"`
+ lastErrorMu sync.Mutex `state:"nosave"`
+ lastError tcpip.Error
// rcvReadMu synchronizes calls to Read.
//
@@ -557,7 +557,6 @@ type endpoint struct {
// When the send side is closed, the protocol goroutine is notified via
// sndCloseWaker, and sndClosed is set to true.
sndBufMu sync.Mutex `state:"nosave"`
- sndBufSize int
sndBufUsed int
sndClosed bool
sndBufInQueue seqnum.Size
@@ -869,7 +868,6 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQue
waiterQueue: waiterQueue,
state: StateInitial,
rcvBufSize: DefaultReceiveBufferSize,
- sndBufSize: DefaultSendBufferSize,
sndMTU: int(math.MaxInt32),
keepalive: keepalive{
// Linux defaults.
@@ -882,13 +880,14 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQue
windowClamp: DefaultReceiveBufferSize,
maxSynRetries: DefaultSynRetries,
}
- e.ops.InitHandler(e)
+ e.ops.InitHandler(e, e.stack, GetTCPSendBufferLimits)
e.ops.SetMulticastLoop(true)
e.ops.SetQuickAck(true)
+ e.ops.SetSendBufferSize(DefaultSendBufferSize, false /* notify */)
var ss tcpip.TCPSendBufferSizeRangeOption
if err := s.TransportProtocolOption(ProtocolNumber, &ss); err == nil {
- e.sndBufSize = ss.Default
+ e.ops.SetSendBufferSize(int64(ss.Default), false /* notify */)
}
var rs tcpip.TCPReceiveBufferSizeRangeOption
@@ -967,7 +966,8 @@ func (e *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask {
// Determine if the endpoint is writable if requested.
if (mask & waiter.EventOut) != 0 {
e.sndBufMu.Lock()
- if e.sndClosed || e.sndBufUsed < e.sndBufSize {
+ sndBufSize := e.getSendBufferSize()
+ if e.sndClosed || e.sndBufUsed < sndBufSize {
result |= waiter.EventOut
}
e.sndBufMu.Unlock()
@@ -1059,7 +1059,7 @@ func (e *endpoint) Close() {
if isResetState {
// Close the endpoint without doing full shutdown and
// send a RST.
- e.resetConnectionLocked(tcpip.ErrConnectionAborted)
+ e.resetConnectionLocked(&tcpip.ErrConnectionAborted{})
e.closeNoShutdownLocked()
// Wake up worker to close the endpoint.
@@ -1087,7 +1087,7 @@ func (e *endpoint) closeNoShutdownLocked() {
// in Listen() when trying to register.
if e.EndpointState() == StateListen && e.isPortReserved {
if e.isRegistered {
- e.stack.StartTransportEndpointCleanup(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.boundPortFlags, e.boundBindToDevice)
+ e.stack.StartTransportEndpointCleanup(e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.boundPortFlags, e.boundBindToDevice)
e.isRegistered = false
}
@@ -1161,7 +1161,7 @@ func (e *endpoint) cleanupLocked() {
e.workerCleanup = false
if e.isRegistered {
- e.stack.StartTransportEndpointCleanup(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.boundPortFlags, e.boundBindToDevice)
+ e.stack.StartTransportEndpointCleanup(e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.boundPortFlags, e.boundBindToDevice)
e.isRegistered = false
}
@@ -1293,14 +1293,14 @@ func (e *endpoint) SetOwner(owner tcpip.PacketOwner) {
}
// Preconditions: e.mu must be held to call this function.
-func (e *endpoint) hardErrorLocked() *tcpip.Error {
+func (e *endpoint) hardErrorLocked() tcpip.Error {
err := e.hardError
e.hardError = nil
return err
}
// Preconditions: e.mu must be held to call this function.
-func (e *endpoint) lastErrorLocked() *tcpip.Error {
+func (e *endpoint) lastErrorLocked() tcpip.Error {
e.lastErrorMu.Lock()
defer e.lastErrorMu.Unlock()
err := e.lastError
@@ -1309,7 +1309,7 @@ func (e *endpoint) lastErrorLocked() *tcpip.Error {
}
// LastError implements tcpip.Endpoint.LastError.
-func (e *endpoint) LastError() *tcpip.Error {
+func (e *endpoint) LastError() tcpip.Error {
e.LockUser()
defer e.UnlockUser()
if err := e.hardErrorLocked(); err != nil {
@@ -1319,7 +1319,7 @@ func (e *endpoint) LastError() *tcpip.Error {
}
// UpdateLastError implements tcpip.SocketOptionsHandler.UpdateLastError.
-func (e *endpoint) UpdateLastError(err *tcpip.Error) {
+func (e *endpoint) UpdateLastError(err tcpip.Error) {
e.LockUser()
e.lastErrorMu.Lock()
e.lastError = err
@@ -1328,7 +1328,7 @@ func (e *endpoint) UpdateLastError(err *tcpip.Error) {
}
// Read implements tcpip.Endpoint.Read.
-func (e *endpoint) Read(dst io.Writer, opts tcpip.ReadOptions) (tcpip.ReadResult, *tcpip.Error) {
+func (e *endpoint) Read(dst io.Writer, opts tcpip.ReadOptions) (tcpip.ReadResult, tcpip.Error) {
e.rcvReadMu.Lock()
defer e.rcvReadMu.Unlock()
@@ -1337,7 +1337,7 @@ func (e *endpoint) Read(dst io.Writer, opts tcpip.ReadOptions) (tcpip.ReadResult
// can remove segments from the list through commitRead().
first, last, serr := e.startRead()
if serr != nil {
- if serr == tcpip.ErrClosedForReceive {
+ if _, ok := serr.(*tcpip.ErrClosedForReceive); ok {
e.stats.ReadErrors.ReadClosed.Increment()
}
return tcpip.ReadResult{}, serr
@@ -1377,7 +1377,7 @@ func (e *endpoint) Read(dst io.Writer, opts tcpip.ReadOptions) (tcpip.ReadResult
// If something is read, we must report it. Report error when nothing is read.
if done == 0 && err != nil {
- return tcpip.ReadResult{}, tcpip.ErrBadBuffer
+ return tcpip.ReadResult{}, &tcpip.ErrBadBuffer{}
}
return tcpip.ReadResult{
Count: done,
@@ -1389,7 +1389,7 @@ func (e *endpoint) Read(dst io.Writer, opts tcpip.ReadOptions) (tcpip.ReadResult
// inclusive range of segments that can be read.
//
// Precondition: e.rcvReadMu must be held.
-func (e *endpoint) startRead() (first, last *segment, err *tcpip.Error) {
+func (e *endpoint) startRead() (first, last *segment, err tcpip.Error) {
e.LockUser()
defer e.UnlockUser()
@@ -1398,7 +1398,7 @@ func (e *endpoint) startRead() (first, last *segment, err *tcpip.Error) {
// on a receive. It can expect to read any data after the handshake
// is complete. RFC793, section 3.9, p58.
if e.EndpointState() == StateSynSent {
- return nil, nil, tcpip.ErrWouldBlock
+ return nil, nil, &tcpip.ErrWouldBlock{}
}
// The endpoint can be read if it's connected, or if it's already closed
@@ -1414,17 +1414,17 @@ func (e *endpoint) startRead() (first, last *segment, err *tcpip.Error) {
if err := e.hardErrorLocked(); err != nil {
return nil, nil, err
}
- return nil, nil, tcpip.ErrClosedForReceive
+ return nil, nil, &tcpip.ErrClosedForReceive{}
}
e.stats.ReadErrors.NotConnected.Increment()
- return nil, nil, tcpip.ErrNotConnected
+ return nil, nil, &tcpip.ErrNotConnected{}
}
if e.rcvBufUsed == 0 {
if e.rcvClosed || !e.EndpointState().connected() {
- return nil, nil, tcpip.ErrClosedForReceive
+ return nil, nil, &tcpip.ErrClosedForReceive{}
}
- return nil, nil, tcpip.ErrWouldBlock
+ return nil, nil, &tcpip.ErrWouldBlock{}
}
return e.rcvList.Front(), e.rcvList.Back(), nil
@@ -1476,106 +1476,117 @@ func (e *endpoint) commitRead(done int) *segment {
// moment. If the endpoint is not writable then it returns an error
// indicating the reason why it's not writable.
// Caller must hold e.mu and e.sndBufMu
-func (e *endpoint) isEndpointWritableLocked() (int, *tcpip.Error) {
+func (e *endpoint) isEndpointWritableLocked() (int, tcpip.Error) {
// The endpoint cannot be written to if it's not connected.
switch s := e.EndpointState(); {
case s == StateError:
if err := e.hardErrorLocked(); err != nil {
return 0, err
}
- return 0, tcpip.ErrClosedForSend
+ return 0, &tcpip.ErrClosedForSend{}
case !s.connecting() && !s.connected():
- return 0, tcpip.ErrClosedForSend
+ return 0, &tcpip.ErrClosedForSend{}
case s.connecting():
// As per RFC793, page 56, a send request arriving when in connecting
// state, can be queued to be completed after the state becomes
// connected. Return an error code for the caller of endpoint Write to
// try again, until the connection handshake is complete.
- return 0, tcpip.ErrWouldBlock
+ return 0, &tcpip.ErrWouldBlock{}
}
// Check if the connection has already been closed for sends.
if e.sndClosed {
- return 0, tcpip.ErrClosedForSend
+ return 0, &tcpip.ErrClosedForSend{}
}
- avail := e.sndBufSize - e.sndBufUsed
+ sndBufSize := e.getSendBufferSize()
+ avail := sndBufSize - e.sndBufUsed
if avail <= 0 {
- return 0, tcpip.ErrWouldBlock
+ return 0, &tcpip.ErrWouldBlock{}
}
return avail, nil
}
// Write writes data to the endpoint's peer.
-func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, *tcpip.Error) {
+func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcpip.Error) {
// Linux completely ignores any address passed to sendto(2) for TCP sockets
// (without the MSG_FASTOPEN flag). Corking is unimplemented, so opts.More
// and opts.EndOfRecord are also ignored.
e.LockUser()
- e.sndBufMu.Lock()
-
- avail, err := e.isEndpointWritableLocked()
- if err != nil {
- e.sndBufMu.Unlock()
- e.UnlockUser()
- e.stats.WriteErrors.WriteClosed.Increment()
- return 0, err
- }
-
- // We can release locks while copying data.
- //
- // This is not possible if atomic is set, because we can't allow the
- // available buffer space to be consumed by some other caller while we
- // are copying data in.
- if !opts.Atomic {
- e.sndBufMu.Unlock()
- e.UnlockUser()
- }
-
- // Fetch data.
- v, perr := p.Payload(avail)
- if perr != nil || len(v) == 0 {
- // Note that perr may be nil if len(v) == 0.
- if opts.Atomic {
- e.sndBufMu.Unlock()
- e.UnlockUser()
- }
- return 0, perr
- }
+ defer e.UnlockUser()
- if !opts.Atomic {
- // Since we released locks in between it's possible that the
- // endpoint transitioned to a CLOSED/ERROR states so make
- // sure endpoint is still writable before trying to write.
- e.LockUser()
+ nextSeg, n, err := func() (*segment, int, tcpip.Error) {
e.sndBufMu.Lock()
+ defer e.sndBufMu.Unlock()
+
avail, err := e.isEndpointWritableLocked()
if err != nil {
- e.sndBufMu.Unlock()
- e.UnlockUser()
e.stats.WriteErrors.WriteClosed.Increment()
- return 0, err
+ return nil, 0, err
}
- // Discard any excess data copied in due to avail being reduced due
- // to a simultaneous write call to the socket.
- if avail < len(v) {
- v = v[:avail]
+ v, err := func() ([]byte, tcpip.Error) {
+ // We can release locks while copying data.
+ //
+ // This is not possible if atomic is set, because we can't allow the
+ // available buffer space to be consumed by some other caller while we
+ // are copying data in.
+ if !opts.Atomic {
+ e.sndBufMu.Unlock()
+ defer e.sndBufMu.Lock()
+
+ e.UnlockUser()
+ defer e.LockUser()
+ }
+
+ // Fetch data.
+ if l := p.Len(); l < avail {
+ avail = l
+ }
+ if avail == 0 {
+ return nil, nil
+ }
+ v := make([]byte, avail)
+ if _, err := io.ReadFull(p, v); err != nil {
+ return nil, &tcpip.ErrBadBuffer{}
+ }
+ return v, nil
+ }()
+ if len(v) == 0 || err != nil {
+ return nil, 0, err
+ }
+
+ if !opts.Atomic {
+ // Since we released locks in between it's possible that the
+ // endpoint transitioned to a CLOSED/ERROR states so make
+ // sure endpoint is still writable before trying to write.
+ avail, err := e.isEndpointWritableLocked()
+ if err != nil {
+ e.stats.WriteErrors.WriteClosed.Increment()
+ return nil, 0, err
+ }
+
+ // Discard any excess data copied in due to avail being reduced due
+ // to a simultaneous write call to the socket.
+ if avail < len(v) {
+ v = v[:avail]
+ }
}
- }
- // Add data to the send queue.
- s := newOutgoingSegment(e.ID, v)
- e.sndBufUsed += len(v)
- e.sndBufInQueue += seqnum.Size(len(v))
- e.sndQueue.PushBack(s)
- e.sndBufMu.Unlock()
+ // Add data to the send queue.
+ s := newOutgoingSegment(e.ID, v)
+ e.sndBufUsed += len(v)
+ e.sndBufInQueue += seqnum.Size(len(v))
+ e.sndQueue.PushBack(s)
- // Do the work inline.
- e.handleWrite()
- e.UnlockUser()
- return int64(len(v)), nil
+ return e.drainSendQueueLocked(), len(v), nil
+ }()
+ if err != nil {
+ return 0, err
+ }
+ e.sendData(nextSeg)
+ return int64(n), nil
}
// selectWindowLocked returns the new window without checking for shrinking or scaling
@@ -1682,8 +1693,16 @@ func (e *endpoint) OnCorkOptionSet(v bool) {
}
}
+func (e *endpoint) getSendBufferSize() int {
+ sndBufSize, err := e.ops.GetSendBufferSize()
+ if err != nil {
+ panic(fmt.Sprintf("e.ops.GetSendBufferSize() = %s", err))
+ }
+ return int(sndBufSize)
+}
+
// SetSockOptInt sets a socket option.
-func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error {
+func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) tcpip.Error {
// Lower 2 bits represents ECN bits. RFC 3168, section 23.1
const inetECNMask = 3
@@ -1711,7 +1730,7 @@ func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error {
case tcpip.MaxSegOption:
userMSS := v
if userMSS < header.TCPMinimumMSS || userMSS > header.TCPMaximumMSS {
- return tcpip.ErrInvalidOptionValue
+ return &tcpip.ErrInvalidOptionValue{}
}
e.LockUser()
e.userMSS = uint16(userMSS)
@@ -1722,7 +1741,7 @@ func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error {
// Return not supported if attempting to set this option to
// anything other than path MTU discovery disabled.
if v != tcpip.PMTUDiscoveryDont {
- return tcpip.ErrNotSupported
+ return &tcpip.ErrNotSupported{}
}
case tcpip.ReceiveBufferSizeOption:
@@ -1775,31 +1794,6 @@ func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error {
e.rcvListMu.Unlock()
e.UnlockUser()
- case tcpip.SendBufferSizeOption:
- // Make sure the send buffer size is within the min and max
- // allowed.
- var ss tcpip.TCPSendBufferSizeRangeOption
- if err := e.stack.TransportProtocolOption(ProtocolNumber, &ss); err != nil {
- panic(fmt.Sprintf("e.stack.TransportProtocolOption(%d, %#v) = %s", ProtocolNumber, &ss, err))
- }
-
- if v > ss.Max {
- v = ss.Max
- }
-
- if v < math.MaxInt32/SegOverheadFactor {
- v *= SegOverheadFactor
- if v < ss.Min {
- v = ss.Min
- }
- } else {
- v = math.MaxInt32
- }
-
- e.sndBufMu.Lock()
- e.sndBufSize = v
- e.sndBufMu.Unlock()
-
case tcpip.TTLOption:
e.LockUser()
e.ttl = uint8(v)
@@ -1807,7 +1801,7 @@ func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error {
case tcpip.TCPSynCountOption:
if v < 1 || v > 255 {
- return tcpip.ErrInvalidOptionValue
+ return &tcpip.ErrInvalidOptionValue{}
}
e.LockUser()
e.maxSynRetries = uint8(v)
@@ -1823,7 +1817,7 @@ func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error {
return nil
default:
e.UnlockUser()
- return tcpip.ErrInvalidOptionValue
+ return &tcpip.ErrInvalidOptionValue{}
}
}
var rs tcpip.TCPReceiveBufferSizeRangeOption
@@ -1844,7 +1838,7 @@ func (e *endpoint) HasNIC(id int32) bool {
}
// SetSockOpt sets a socket option.
-func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) *tcpip.Error {
+func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) tcpip.Error {
switch v := opt.(type) {
case *tcpip.KeepaliveIdleOption:
e.keepalive.Lock()
@@ -1890,7 +1884,7 @@ func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) *tcpip.Error {
// Linux returns ENOENT when an invalid congestion
// control algorithm is specified.
- return tcpip.ErrNoSuchFile
+ return &tcpip.ErrNoSuchFile{}
case *tcpip.TCPLingerTimeoutOption:
e.LockUser()
@@ -1933,13 +1927,13 @@ func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) *tcpip.Error {
}
// readyReceiveSize returns the number of bytes ready to be received.
-func (e *endpoint) readyReceiveSize() (int, *tcpip.Error) {
+func (e *endpoint) readyReceiveSize() (int, tcpip.Error) {
e.LockUser()
defer e.UnlockUser()
// The endpoint cannot be in listen state.
if e.EndpointState() == StateListen {
- return 0, tcpip.ErrInvalidEndpointState
+ return 0, &tcpip.ErrInvalidEndpointState{}
}
e.rcvListMu.Lock()
@@ -1949,7 +1943,7 @@ func (e *endpoint) readyReceiveSize() (int, *tcpip.Error) {
}
// GetSockOptInt implements tcpip.Endpoint.GetSockOptInt.
-func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) {
+func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, tcpip.Error) {
switch opt {
case tcpip.KeepaliveCountOption:
e.keepalive.Lock()
@@ -1985,12 +1979,6 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) {
case tcpip.ReceiveQueueSizeOption:
return e.readyReceiveSize()
- case tcpip.SendBufferSizeOption:
- e.sndBufMu.Lock()
- v := e.sndBufSize
- e.sndBufMu.Unlock()
- return v, nil
-
case tcpip.ReceiveBufferSizeOption:
e.rcvListMu.Lock()
v := e.rcvBufSize
@@ -2019,24 +2007,38 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) {
return 1, nil
default:
- return -1, tcpip.ErrUnknownProtocolOption
+ return -1, &tcpip.ErrUnknownProtocolOption{}
+ }
+}
+
+func (e *endpoint) getTCPInfo() tcpip.TCPInfoOption {
+ info := tcpip.TCPInfoOption{}
+ e.LockUser()
+ snd := e.snd
+ if snd != nil {
+ // We do not calculate RTT before sending the data packets. If
+ // the connection did not send and receive data, then RTT will
+ // be zero.
+ snd.rtt.Lock()
+ info.RTT = snd.rtt.srtt
+ info.RTTVar = snd.rtt.rttvar
+ snd.rtt.Unlock()
+
+ info.RTO = snd.rto
+ info.CcState = snd.state
+ info.SndSsthresh = uint32(snd.sndSsthresh)
+ info.SndCwnd = uint32(snd.sndCwnd)
+ info.ReorderSeen = snd.rc.reorderSeen
}
+ e.UnlockUser()
+ return info
}
// GetSockOpt implements tcpip.Endpoint.GetSockOpt.
-func (e *endpoint) GetSockOpt(opt tcpip.GettableSocketOption) *tcpip.Error {
+func (e *endpoint) GetSockOpt(opt tcpip.GettableSocketOption) tcpip.Error {
switch o := opt.(type) {
case *tcpip.TCPInfoOption:
- *o = tcpip.TCPInfoOption{}
- e.LockUser()
- snd := e.snd
- e.UnlockUser()
- if snd != nil {
- snd.rtt.Lock()
- o.RTT = snd.rtt.srtt
- o.RTTVar = snd.rtt.rttvar
- snd.rtt.Unlock()
- }
+ *o = e.getTCPInfo()
case *tcpip.KeepaliveIdleOption:
e.keepalive.Lock()
@@ -2082,14 +2084,14 @@ func (e *endpoint) GetSockOpt(opt tcpip.GettableSocketOption) *tcpip.Error {
}
default:
- return tcpip.ErrUnknownProtocolOption
+ return &tcpip.ErrUnknownProtocolOption{}
}
return nil
}
// checkV4MappedLocked determines the effective network protocol and converts
// addr to its canonical form.
-func (e *endpoint) checkV4MappedLocked(addr tcpip.FullAddress) (tcpip.FullAddress, tcpip.NetworkProtocolNumber, *tcpip.Error) {
+func (e *endpoint) checkV4MappedLocked(addr tcpip.FullAddress) (tcpip.FullAddress, tcpip.NetworkProtocolNumber, tcpip.Error) {
unwrapped, netProto, err := e.TransportEndpointInfo.AddrNetProtoLocked(addr, e.ops.GetV6Only())
if err != nil {
return tcpip.FullAddress{}, 0, err
@@ -2098,18 +2100,20 @@ func (e *endpoint) checkV4MappedLocked(addr tcpip.FullAddress) (tcpip.FullAddres
}
// Disconnect implements tcpip.Endpoint.Disconnect.
-func (*endpoint) Disconnect() *tcpip.Error {
- return tcpip.ErrNotSupported
+func (*endpoint) Disconnect() tcpip.Error {
+ return &tcpip.ErrNotSupported{}
}
// Connect connects the endpoint to its peer.
-func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
+func (e *endpoint) Connect(addr tcpip.FullAddress) tcpip.Error {
err := e.connect(addr, true, true)
- if err != nil && !err.IgnoreStats() {
- // Connect failed. Let's wake up any waiters.
- e.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.EventIn | waiter.EventOut)
- e.stack.Stats().TCP.FailedConnectionAttempts.Increment()
- e.stats.FailedConnectionAttempts.Increment()
+ if err != nil {
+ if !err.IgnoreStats() {
+ // Connect failed. Let's wake up any waiters.
+ e.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.EventIn | waiter.EventOut)
+ e.stack.Stats().TCP.FailedConnectionAttempts.Increment()
+ e.stats.FailedConnectionAttempts.Increment()
+ }
}
return err
}
@@ -2120,7 +2124,7 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
// created (so no new handshaking is done); for stack-accepted connections not
// yet accepted by the app, they are restored without running the main goroutine
// here.
-func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) *tcpip.Error {
+func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) tcpip.Error {
e.LockUser()
defer e.UnlockUser()
@@ -2139,7 +2143,7 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) *tc
return nil
}
// Otherwise return that it's already connected.
- return tcpip.ErrAlreadyConnected
+ return &tcpip.ErrAlreadyConnected{}
}
nicID := addr.NIC
@@ -2152,7 +2156,7 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) *tc
}
if nicID != 0 && nicID != e.boundNICID {
- return tcpip.ErrNoRoute
+ return &tcpip.ErrNoRoute{}
}
nicID = e.boundNICID
@@ -2164,16 +2168,16 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) *tc
case StateConnecting, StateSynSent, StateSynRecv:
// A connection request has already been issued but hasn't completed
// yet.
- return tcpip.ErrAlreadyConnecting
+ return &tcpip.ErrAlreadyConnecting{}
case StateError:
if err := e.hardErrorLocked(); err != nil {
return err
}
- return tcpip.ErrConnectionAborted
+ return &tcpip.ErrConnectionAborted{}
default:
- return tcpip.ErrInvalidEndpointState
+ return &tcpip.ErrInvalidEndpointState{}
}
// Find a route to the desired destination.
@@ -2190,7 +2194,7 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) *tc
if e.ID.LocalPort != 0 {
// The endpoint is bound to a port, attempt to register it.
- err := e.stack.RegisterTransportEndpoint(nicID, netProtos, ProtocolNumber, e.ID, e, e.boundPortFlags, e.boundBindToDevice)
+ err := e.stack.RegisterTransportEndpoint(netProtos, ProtocolNumber, e.ID, e, e.boundPortFlags, e.boundBindToDevice)
if err != nil {
return err
}
@@ -2229,12 +2233,12 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) *tc
}
bindToDevice := tcpip.NICID(e.ops.GetBindToDevice())
- if _, err := e.stack.PickEphemeralPortStable(portOffset, func(p uint16) (bool, *tcpip.Error) {
+ if _, err := e.stack.PickEphemeralPortStable(portOffset, func(p uint16) (bool, tcpip.Error) {
if sameAddr && p == e.ID.RemotePort {
return false, nil
}
if _, err := e.stack.ReservePort(netProtos, ProtocolNumber, e.ID.LocalAddress, p, e.portFlags, bindToDevice, addr, nil /* testPort */); err != nil {
- if err != tcpip.ErrPortInUse || !reuse {
+ if _, ok := err.(*tcpip.ErrPortInUse); !ok || !reuse {
return false, nil
}
transEPID := e.ID
@@ -2278,9 +2282,9 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) *tc
id := e.ID
id.LocalPort = p
- if err := e.stack.RegisterTransportEndpoint(nicID, netProtos, ProtocolNumber, id, e, e.portFlags, bindToDevice); err != nil {
+ if err := e.stack.RegisterTransportEndpoint(netProtos, ProtocolNumber, id, e, e.portFlags, bindToDevice); err != nil {
e.stack.ReleasePort(netProtos, ProtocolNumber, e.ID.LocalAddress, p, e.portFlags, bindToDevice, addr)
- if err == tcpip.ErrPortInUse {
+ if _, ok := err.(*tcpip.ErrPortInUse); ok {
return false, nil
}
return false, err
@@ -2335,23 +2339,23 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) *tc
go e.protocolMainLoop(handshake, nil) // S/R-SAFE: will be drained before save.
}
- return tcpip.ErrConnectStarted
+ return &tcpip.ErrConnectStarted{}
}
// ConnectEndpoint is not supported.
-func (*endpoint) ConnectEndpoint(tcpip.Endpoint) *tcpip.Error {
- return tcpip.ErrInvalidEndpointState
+func (*endpoint) ConnectEndpoint(tcpip.Endpoint) tcpip.Error {
+ return &tcpip.ErrInvalidEndpointState{}
}
// Shutdown closes the read and/or write end of the endpoint connection to its
// peer.
-func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error {
+func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) tcpip.Error {
e.LockUser()
defer e.UnlockUser()
return e.shutdownLocked(flags)
}
-func (e *endpoint) shutdownLocked(flags tcpip.ShutdownFlags) *tcpip.Error {
+func (e *endpoint) shutdownLocked(flags tcpip.ShutdownFlags) tcpip.Error {
e.shutdownFlags |= flags
switch {
case e.EndpointState().connected():
@@ -2366,7 +2370,7 @@ func (e *endpoint) shutdownLocked(flags tcpip.ShutdownFlags) *tcpip.Error {
// If we're fully closed and we have unread data we need to abort
// the connection with a RST.
if e.shutdownFlags&tcpip.ShutdownWrite != 0 && rcvBufUsed > 0 {
- e.resetConnectionLocked(tcpip.ErrConnectionAborted)
+ e.resetConnectionLocked(&tcpip.ErrConnectionAborted{})
// Wake up worker to terminate loop.
e.notifyProtocolGoroutine(notifyTickleWorker)
return nil
@@ -2380,7 +2384,7 @@ func (e *endpoint) shutdownLocked(flags tcpip.ShutdownFlags) *tcpip.Error {
// Already closed.
e.sndBufMu.Unlock()
if e.EndpointState() == StateTimeWait {
- return tcpip.ErrNotConnected
+ return &tcpip.ErrNotConnected{}
}
return nil
}
@@ -2413,22 +2417,24 @@ func (e *endpoint) shutdownLocked(flags tcpip.ShutdownFlags) *tcpip.Error {
}
return nil
default:
- return tcpip.ErrNotConnected
+ return &tcpip.ErrNotConnected{}
}
}
// Listen puts the endpoint in "listen" mode, which allows it to accept
// new connections.
-func (e *endpoint) Listen(backlog int) *tcpip.Error {
+func (e *endpoint) Listen(backlog int) tcpip.Error {
err := e.listen(backlog)
- if err != nil && !err.IgnoreStats() {
- e.stack.Stats().TCP.FailedConnectionAttempts.Increment()
- e.stats.FailedConnectionAttempts.Increment()
+ if err != nil {
+ if !err.IgnoreStats() {
+ e.stack.Stats().TCP.FailedConnectionAttempts.Increment()
+ e.stats.FailedConnectionAttempts.Increment()
+ }
}
return err
}
-func (e *endpoint) listen(backlog int) *tcpip.Error {
+func (e *endpoint) listen(backlog int) tcpip.Error {
e.LockUser()
defer e.UnlockUser()
@@ -2446,7 +2452,7 @@ func (e *endpoint) listen(backlog int) *tcpip.Error {
// Adjust the size of the channel iff we can fix
// existing pending connections into the new one.
if len(e.acceptedChan) > backlog {
- return tcpip.ErrInvalidEndpointState
+ return &tcpip.ErrInvalidEndpointState{}
}
if cap(e.acceptedChan) == backlog {
return nil
@@ -2478,11 +2484,11 @@ func (e *endpoint) listen(backlog int) *tcpip.Error {
// Endpoint must be bound before it can transition to listen mode.
if e.EndpointState() != StateBound {
e.stats.ReadErrors.InvalidEndpointState.Increment()
- return tcpip.ErrInvalidEndpointState
+ return &tcpip.ErrInvalidEndpointState{}
}
// Register the endpoint.
- if err := e.stack.RegisterTransportEndpoint(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.boundPortFlags, e.boundBindToDevice); err != nil {
+ if err := e.stack.RegisterTransportEndpoint(e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.boundPortFlags, e.boundBindToDevice); err != nil {
return err
}
@@ -2518,7 +2524,7 @@ func (e *endpoint) startAcceptedLoop() {
// to an endpoint previously set to listen mode.
//
// addr if not-nil will contain the peer address of the returned endpoint.
-func (e *endpoint) Accept(peerAddr *tcpip.FullAddress) (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) {
+func (e *endpoint) Accept(peerAddr *tcpip.FullAddress) (tcpip.Endpoint, *waiter.Queue, tcpip.Error) {
e.LockUser()
defer e.UnlockUser()
@@ -2527,7 +2533,7 @@ func (e *endpoint) Accept(peerAddr *tcpip.FullAddress) (tcpip.Endpoint, *waiter.
e.rcvListMu.Unlock()
// Endpoint must be in listen state before it can accept connections.
if rcvClosed || e.EndpointState() != StateListen {
- return nil, nil, tcpip.ErrInvalidEndpointState
+ return nil, nil, &tcpip.ErrInvalidEndpointState{}
}
// Get the new accepted endpoint.
@@ -2538,7 +2544,7 @@ func (e *endpoint) Accept(peerAddr *tcpip.FullAddress) (tcpip.Endpoint, *waiter.
case n = <-e.acceptedChan:
e.acceptCond.Signal()
default:
- return nil, nil, tcpip.ErrWouldBlock
+ return nil, nil, &tcpip.ErrWouldBlock{}
}
if peerAddr != nil {
*peerAddr = n.getRemoteAddress()
@@ -2547,19 +2553,19 @@ func (e *endpoint) Accept(peerAddr *tcpip.FullAddress) (tcpip.Endpoint, *waiter.
}
// Bind binds the endpoint to a specific local port and optionally address.
-func (e *endpoint) Bind(addr tcpip.FullAddress) (err *tcpip.Error) {
+func (e *endpoint) Bind(addr tcpip.FullAddress) (err tcpip.Error) {
e.LockUser()
defer e.UnlockUser()
return e.bindLocked(addr)
}
-func (e *endpoint) bindLocked(addr tcpip.FullAddress) (err *tcpip.Error) {
+func (e *endpoint) bindLocked(addr tcpip.FullAddress) (err tcpip.Error) {
// Don't allow binding once endpoint is not in the initial state
// anymore. This is because once the endpoint goes into a connected or
// listen state, it is already bound.
if e.EndpointState() != StateInitial {
- return tcpip.ErrAlreadyBound
+ return &tcpip.ErrAlreadyBound{}
}
e.BindAddr = addr.Addr
@@ -2587,7 +2593,7 @@ func (e *endpoint) bindLocked(addr tcpip.FullAddress) (err *tcpip.Error) {
if len(addr.Addr) != 0 {
nic = e.stack.CheckLocalAddress(addr.NIC, netProto, addr.Addr)
if nic == 0 {
- return tcpip.ErrBadLocalAddress
+ return &tcpip.ErrBadLocalAddress{}
}
e.ID.LocalAddress = addr.Addr
}
@@ -2604,7 +2610,7 @@ func (e *endpoint) bindLocked(addr tcpip.FullAddress) (err *tcpip.Error) {
// demuxer. Further connected endpoints always have a remote
// address/port. Hence this will only return an error if there is a matching
// listening endpoint.
- if err := e.stack.CheckRegisterTransportEndpoint(nic, netProtos, ProtocolNumber, id, e.portFlags, bindToDevice); err != nil {
+ if err := e.stack.CheckRegisterTransportEndpoint(netProtos, ProtocolNumber, id, e.portFlags, bindToDevice); err != nil {
return false
}
return true
@@ -2628,7 +2634,7 @@ func (e *endpoint) bindLocked(addr tcpip.FullAddress) (err *tcpip.Error) {
}
// GetLocalAddress returns the address to which the endpoint is bound.
-func (e *endpoint) GetLocalAddress() (tcpip.FullAddress, *tcpip.Error) {
+func (e *endpoint) GetLocalAddress() (tcpip.FullAddress, tcpip.Error) {
e.LockUser()
defer e.UnlockUser()
@@ -2640,12 +2646,12 @@ func (e *endpoint) GetLocalAddress() (tcpip.FullAddress, *tcpip.Error) {
}
// GetRemoteAddress returns the address to which the endpoint is connected.
-func (e *endpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Error) {
+func (e *endpoint) GetRemoteAddress() (tcpip.FullAddress, tcpip.Error) {
e.LockUser()
defer e.UnlockUser()
if !e.EndpointState().connected() {
- return tcpip.FullAddress{}, tcpip.ErrNotConnected
+ return tcpip.FullAddress{}, &tcpip.ErrNotConnected{}
}
return e.getRemoteAddress(), nil
@@ -2677,7 +2683,7 @@ func (e *endpoint) enqueueSegment(s *segment) bool {
return true
}
-func (e *endpoint) onICMPError(err *tcpip.Error, errType byte, errCode byte, extra uint32, pkt *stack.PacketBuffer) {
+func (e *endpoint) onICMPError(err tcpip.Error, errType byte, errCode byte, extra uint32, pkt *stack.PacketBuffer) {
// Update last error first.
e.lastErrorMu.Lock()
e.lastError = err
@@ -2726,26 +2732,27 @@ func (e *endpoint) HandleControlPacket(typ stack.ControlType, extra uint32, pkt
e.notifyProtocolGoroutine(notifyMTUChanged)
case stack.ControlNoRoute:
- e.onICMPError(tcpip.ErrNoRoute, byte(header.ICMPv4DstUnreachable), byte(header.ICMPv4HostUnreachable), extra, pkt)
+ e.onICMPError(&tcpip.ErrNoRoute{}, byte(header.ICMPv4DstUnreachable), byte(header.ICMPv4HostUnreachable), extra, pkt)
case stack.ControlAddressUnreachable:
- e.onICMPError(tcpip.ErrNoRoute, byte(header.ICMPv6DstUnreachable), byte(header.ICMPv6AddressUnreachable), extra, pkt)
+ e.onICMPError(&tcpip.ErrNoRoute{}, byte(header.ICMPv6DstUnreachable), byte(header.ICMPv6AddressUnreachable), extra, pkt)
case stack.ControlNetworkUnreachable:
- e.onICMPError(tcpip.ErrNetworkUnreachable, byte(header.ICMPv6DstUnreachable), byte(header.ICMPv6NetworkUnreachable), extra, pkt)
+ e.onICMPError(&tcpip.ErrNetworkUnreachable{}, byte(header.ICMPv6DstUnreachable), byte(header.ICMPv6NetworkUnreachable), extra, pkt)
}
}
// updateSndBufferUsage is called by the protocol goroutine when room opens up
// in the send buffer. The number of newly available bytes is v.
func (e *endpoint) updateSndBufferUsage(v int) {
+ sendBufferSize := e.getSendBufferSize()
e.sndBufMu.Lock()
- notify := e.sndBufUsed >= e.sndBufSize>>1
+ notify := e.sndBufUsed >= sendBufferSize>>1
e.sndBufUsed -= v
- // We only notify when there is half the sndBufSize available after
+ // We only notify when there is half the sendBufferSize available after
// a full buffer event occurs. This ensures that we don't wake up
// writers to queue just 1-2 segments and go back to sleep.
- notify = notify && e.sndBufUsed < e.sndBufSize>>1
+ notify = notify && e.sndBufUsed < int(sendBufferSize)>>1
e.sndBufMu.Unlock()
if notify {
@@ -2957,8 +2964,9 @@ func (e *endpoint) completeState() stack.TCPEndpointState {
s.SACK.ReceivedBlocks, s.SACK.MaxSACKED = e.scoreboard.Copy()
// Copy endpoint send state.
+ sndBufSize := e.getSendBufferSize()
e.sndBufMu.Lock()
- s.SndBufSize = e.sndBufSize
+ s.SndBufSize = sndBufSize
s.SndBufUsed = e.sndBufUsed
s.SndClosed = e.sndClosed
s.SndBufInQueue = e.sndBufInQueue
@@ -3023,12 +3031,16 @@ func (e *endpoint) completeState() stack.TCPEndpointState {
rc := &e.snd.rc
s.Sender.RACKState = stack.TCPRACKState{
- XmitTime: rc.xmitTime,
- EndSequence: rc.endSequence,
- FACK: rc.fack,
- RTT: rc.rtt,
- Reord: rc.reorderSeen,
- DSACKSeen: rc.dsackSeen,
+ XmitTime: rc.xmitTime,
+ EndSequence: rc.endSequence,
+ FACK: rc.fack,
+ RTT: rc.rtt,
+ Reord: rc.reorderSeen,
+ DSACKSeen: rc.dsackSeen,
+ ReoWnd: rc.reoWnd,
+ ReoWndIncr: rc.reoWndIncr,
+ ReoWndPersist: rc.reoWndPersist,
+ RTTSeq: rc.rttSeq,
}
return s
}
@@ -3103,3 +3115,17 @@ func (e *endpoint) Wait() {
func (e *endpoint) SocketOptions() *tcpip.SocketOptions {
return &e.ops
}
+
+// GetTCPSendBufferLimits is used to get send buffer size limits for TCP.
+func GetTCPSendBufferLimits(s tcpip.StackHandler) tcpip.SendBufferSizeOption {
+ var ss tcpip.TCPSendBufferSizeRangeOption
+ if err := s.TransportProtocolOption(header.TCPProtocolNumber, &ss); err != nil {
+ panic(fmt.Sprintf("s.TransportProtocolOption(%d, %#v) = %s", header.TCPProtocolNumber, ss, err))
+ }
+
+ return tcpip.SendBufferSizeOption{
+ Min: ss.Min,
+ Default: ss.Default,
+ Max: ss.Max,
+ }
+}
diff --git a/pkg/tcpip/transport/tcp/endpoint_state.go b/pkg/tcpip/transport/tcp/endpoint_state.go
index ba67176b5..c21dbc682 100644
--- a/pkg/tcpip/transport/tcp/endpoint_state.go
+++ b/pkg/tcpip/transport/tcp/endpoint_state.go
@@ -55,9 +55,11 @@ func (e *endpoint) beforeSave() {
case epState.connected() || epState.handshake():
if !e.route.HasSaveRestoreCapability() {
if !e.route.HasDisconncetOkCapability() {
- panic(tcpip.ErrSaveRejection{fmt.Errorf("endpoint cannot be saved in connected state: local %v:%d, remote %v:%d", e.ID.LocalAddress, e.ID.LocalPort, e.ID.RemoteAddress, e.ID.RemotePort)})
+ panic(&tcpip.ErrSaveRejection{
+ Err: fmt.Errorf("endpoint cannot be saved in connected state: local %s:%d, remote %s:%d", e.ID.LocalAddress, e.ID.LocalPort, e.ID.RemoteAddress, e.ID.RemotePort),
+ })
}
- e.resetConnectionLocked(tcpip.ErrConnectionAborted)
+ e.resetConnectionLocked(&tcpip.ErrConnectionAborted{})
e.mu.Unlock()
e.Close()
e.mu.Lock()
@@ -179,14 +181,16 @@ func (e *endpoint) afterLoad() {
// Resume implements tcpip.ResumableEndpoint.Resume.
func (e *endpoint) Resume(s *stack.Stack) {
e.stack = s
+ e.ops.InitHandler(e, e.stack, GetTCPSendBufferLimits)
e.segmentQueue.thaw()
epState := e.origEndpointState
switch epState {
case StateInitial, StateBound, StateListen, StateConnecting, StateEstablished:
var ss tcpip.TCPSendBufferSizeRangeOption
if err := e.stack.TransportProtocolOption(ProtocolNumber, &ss); err == nil {
- if e.sndBufSize < ss.Min || e.sndBufSize > ss.Max {
- panic(fmt.Sprintf("endpoint.sndBufSize %d is outside the min and max allowed [%d, %d]", e.sndBufSize, ss.Min, ss.Max))
+ sendBufferSize := e.getSendBufferSize()
+ if sendBufferSize < ss.Min || sendBufferSize > ss.Max {
+ panic(fmt.Sprintf("endpoint sendBufferSize %d is outside the min and max allowed [%d, %d]", sendBufferSize, ss.Min, ss.Max))
}
}
@@ -228,7 +232,8 @@ func (e *endpoint) Resume(s *stack.Stack) {
// Reset the scoreboard to reinitialize the sack information as
// we do not restore SACK information.
e.scoreboard.Reset()
- if err := e.connect(tcpip.FullAddress{NIC: e.boundNICID, Addr: e.connectingAddress, Port: e.ID.RemotePort}, false, e.workerRunning); err != tcpip.ErrConnectStarted {
+ err := e.connect(tcpip.FullAddress{NIC: e.boundNICID, Addr: e.connectingAddress, Port: e.ID.RemotePort}, false, e.workerRunning)
+ if _, ok := err.(*tcpip.ErrConnectStarted); !ok {
panic("endpoint connecting failed: " + err.String())
}
e.mu.Lock()
@@ -265,7 +270,8 @@ func (e *endpoint) Resume(s *stack.Stack) {
connectedLoading.Wait()
listenLoading.Wait()
bind()
- if err := e.Connect(tcpip.FullAddress{NIC: e.boundNICID, Addr: e.connectingAddress, Port: e.ID.RemotePort}); err != tcpip.ErrConnectStarted {
+ err := e.Connect(tcpip.FullAddress{NIC: e.boundNICID, Addr: e.connectingAddress, Port: e.ID.RemotePort})
+ if _, ok := err.(*tcpip.ErrConnectStarted); !ok {
panic("endpoint connecting failed: " + err.String())
}
connectingLoading.Done()
@@ -292,24 +298,6 @@ func (e *endpoint) Resume(s *stack.Stack) {
}
}
-// saveLastError is invoked by stateify.
-func (e *endpoint) saveLastError() string {
- if e.lastError == nil {
- return ""
- }
-
- return e.lastError.String()
-}
-
-// loadLastError is invoked by stateify.
-func (e *endpoint) loadLastError(s string) {
- if s == "" {
- return
- }
-
- e.lastError = tcpip.StringToError(s)
-}
-
// saveRecentTSTime is invoked by stateify.
func (e *endpoint) saveRecentTSTime() unixTime {
return unixTime{e.recentTSTime.Unix(), e.recentTSTime.UnixNano()}
@@ -320,24 +308,6 @@ func (e *endpoint) loadRecentTSTime(unix unixTime) {
e.recentTSTime = time.Unix(unix.second, unix.nano)
}
-// saveHardError is invoked by stateify.
-func (e *endpoint) saveHardError() string {
- if e.hardError == nil {
- return ""
- }
-
- return e.hardError.String()
-}
-
-// loadHardError is invoked by stateify.
-func (e *endpoint) loadHardError(s string) {
- if s == "" {
- return
- }
-
- e.hardError = tcpip.StringToError(s)
-}
-
// saveMeasureTime is invoked by stateify.
func (r *rcvBufAutoTuneParams) saveMeasureTime() unixTime {
return unixTime{r.measureTime.Unix(), r.measureTime.UnixNano()}
diff --git a/pkg/tcpip/transport/tcp/forwarder.go b/pkg/tcpip/transport/tcp/forwarder.go
index 596178625..2f9fe7ee0 100644
--- a/pkg/tcpip/transport/tcp/forwarder.go
+++ b/pkg/tcpip/transport/tcp/forwarder.go
@@ -143,12 +143,12 @@ func (r *ForwarderRequest) Complete(sendReset bool) {
// CreateEndpoint creates a TCP endpoint for the connection request, performing
// the 3-way handshake in the process.
-func (r *ForwarderRequest) CreateEndpoint(queue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
+func (r *ForwarderRequest) CreateEndpoint(queue *waiter.Queue) (tcpip.Endpoint, tcpip.Error) {
r.mu.Lock()
defer r.mu.Unlock()
if r.segment == nil {
- return nil, tcpip.ErrInvalidEndpointState
+ return nil, &tcpip.ErrInvalidEndpointState{}
}
f := r.forwarder
diff --git a/pkg/tcpip/transport/tcp/protocol.go b/pkg/tcpip/transport/tcp/protocol.go
index 1720370c9..04012cd40 100644
--- a/pkg/tcpip/transport/tcp/protocol.go
+++ b/pkg/tcpip/transport/tcp/protocol.go
@@ -161,13 +161,13 @@ func (*protocol) Number() tcpip.TransportProtocolNumber {
}
// NewEndpoint creates a new tcp endpoint.
-func (p *protocol) NewEndpoint(netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
+func (p *protocol) NewEndpoint(netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, tcpip.Error) {
return newEndpoint(p.stack, netProto, waiterQueue), nil
}
// NewRawEndpoint creates a new raw TCP endpoint. Raw TCP sockets are currently
// unsupported. It implements stack.TransportProtocol.NewRawEndpoint.
-func (p *protocol) NewRawEndpoint(netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
+func (p *protocol) NewRawEndpoint(netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, tcpip.Error) {
return raw.NewEndpoint(p.stack, netProto, header.TCPProtocolNumber, waiterQueue)
}
@@ -178,7 +178,7 @@ func (*protocol) MinimumPacketSize() int {
// ParsePorts returns the source and destination ports stored in the given tcp
// packet.
-func (*protocol) ParsePorts(v buffer.View) (src, dst uint16, err *tcpip.Error) {
+func (*protocol) ParsePorts(v buffer.View) (src, dst uint16, err tcpip.Error) {
h := header.TCP(v)
return h.SourcePort(), h.DestinationPort(), nil
}
@@ -216,7 +216,7 @@ func (p *protocol) HandleUnknownDestinationPacket(id stack.TransportEndpointID,
// replyWithReset replies to the given segment with a reset segment.
//
// If the passed TTL is 0, then the route's default TTL will be used.
-func replyWithReset(stack *stack.Stack, s *segment, tos, ttl uint8) *tcpip.Error {
+func replyWithReset(stack *stack.Stack, s *segment, tos, ttl uint8) tcpip.Error {
route, err := stack.FindRoute(s.nicID, s.dstAddr, s.srcAddr, s.netProto, false /* multicastLoop */)
if err != nil {
return err
@@ -261,7 +261,7 @@ func replyWithReset(stack *stack.Stack, s *segment, tos, ttl uint8) *tcpip.Error
}
// SetOption implements stack.TransportProtocol.SetOption.
-func (p *protocol) SetOption(option tcpip.SettableTransportProtocolOption) *tcpip.Error {
+func (p *protocol) SetOption(option tcpip.SettableTransportProtocolOption) tcpip.Error {
switch v := option.(type) {
case *tcpip.TCPSACKEnabled:
p.mu.Lock()
@@ -283,7 +283,7 @@ func (p *protocol) SetOption(option tcpip.SettableTransportProtocolOption) *tcpi
case *tcpip.TCPSendBufferSizeRangeOption:
if v.Min <= 0 || v.Default < v.Min || v.Default > v.Max {
- return tcpip.ErrInvalidOptionValue
+ return &tcpip.ErrInvalidOptionValue{}
}
p.mu.Lock()
p.sendBufferSize = *v
@@ -292,7 +292,7 @@ func (p *protocol) SetOption(option tcpip.SettableTransportProtocolOption) *tcpi
case *tcpip.TCPReceiveBufferSizeRangeOption:
if v.Min <= 0 || v.Default < v.Min || v.Default > v.Max {
- return tcpip.ErrInvalidOptionValue
+ return &tcpip.ErrInvalidOptionValue{}
}
p.mu.Lock()
p.recvBufferSize = *v
@@ -310,7 +310,7 @@ func (p *protocol) SetOption(option tcpip.SettableTransportProtocolOption) *tcpi
}
// linux returns ENOENT when an invalid congestion control
// is specified.
- return tcpip.ErrNoSuchFile
+ return &tcpip.ErrNoSuchFile{}
case *tcpip.TCPModerateReceiveBufferOption:
p.mu.Lock()
@@ -340,7 +340,7 @@ func (p *protocol) SetOption(option tcpip.SettableTransportProtocolOption) *tcpi
case *tcpip.TCPTimeWaitReuseOption:
if *v < tcpip.TCPTimeWaitReuseDisabled || *v > tcpip.TCPTimeWaitReuseLoopbackOnly {
- return tcpip.ErrInvalidOptionValue
+ return &tcpip.ErrInvalidOptionValue{}
}
p.mu.Lock()
p.timeWaitReuse = *v
@@ -381,7 +381,7 @@ func (p *protocol) SetOption(option tcpip.SettableTransportProtocolOption) *tcpi
case *tcpip.TCPSynRetriesOption:
if *v < 1 || *v > 255 {
- return tcpip.ErrInvalidOptionValue
+ return &tcpip.ErrInvalidOptionValue{}
}
p.mu.Lock()
p.synRetries = uint8(*v)
@@ -389,12 +389,12 @@ func (p *protocol) SetOption(option tcpip.SettableTransportProtocolOption) *tcpi
return nil
default:
- return tcpip.ErrUnknownProtocolOption
+ return &tcpip.ErrUnknownProtocolOption{}
}
}
// Option implements stack.TransportProtocol.Option.
-func (p *protocol) Option(option tcpip.GettableTransportProtocolOption) *tcpip.Error {
+func (p *protocol) Option(option tcpip.GettableTransportProtocolOption) tcpip.Error {
switch v := option.(type) {
case *tcpip.TCPSACKEnabled:
p.mu.RLock()
@@ -493,7 +493,7 @@ func (p *protocol) Option(option tcpip.GettableTransportProtocolOption) *tcpip.E
return nil
default:
- return tcpip.ErrUnknownProtocolOption
+ return &tcpip.ErrUnknownProtocolOption{}
}
}
diff --git a/pkg/tcpip/transport/tcp/rack.go b/pkg/tcpip/transport/tcp/rack.go
index 307bacca5..d85cb405a 100644
--- a/pkg/tcpip/transport/tcp/rack.go
+++ b/pkg/tcpip/transport/tcp/rack.go
@@ -22,12 +22,21 @@ import (
"gvisor.dev/gvisor/pkg/tcpip/seqnum"
)
-// wcDelayedACKTimeout is the recommended maximum delayed ACK timer value as
-// defined in https://tools.ietf.org/html/draft-ietf-tcpm-rack-08#section-7.5.
-// It stands for worst case delayed ACK timer (WCDelAckT). When FlightSize is
-// 1, PTO is inflated by WCDelAckT time to compensate for a potential long
-// delayed ACK timer at the receiver.
-const wcDelayedACKTimeout = 200 * time.Millisecond
+const (
+ // wcDelayedACKTimeout is the recommended maximum delayed ACK timer
+ // value as defined in the RFC. It stands for worst case delayed ACK
+ // timer (WCDelAckT). When FlightSize is 1, PTO is inflated by
+ // WCDelAckT time to compensate for a potential long delayed ACK timer
+ // at the receiver.
+ // See: https://tools.ietf.org/html/draft-ietf-tcpm-rack-08#section-7.5.
+ wcDelayedACKTimeout = 200 * time.Millisecond
+
+ // tcpRACKRecoveryThreshold is the number of loss recoveries for which
+ // the reorder window is inflated and after that the reorder window is
+ // reset to its initial value of minRTT/4.
+ // See: https://tools.ietf.org/html/draft-ietf-tcpm-rack-08#section-7.2.
+ tcpRACKRecoveryThreshold = 16
+)
// RACK is a loss detection algorithm used in TCP to detect packet loss and
// reordering using transmission timestamp of the packets instead of packet or
@@ -44,6 +53,11 @@ type rackControl struct {
// endSequence is the ending TCP sequence number of rackControl.seg.
endSequence seqnum.Value
+ // exitedRecovery indicates if the connection is exiting loss recovery.
+ // This flag is set if the sender is leaving the recovery after
+ // receiving an ACK and is reset during updating of reorder window.
+ exitedRecovery bool
+
// fack is the highest selectively or cumulatively acknowledged
// sequence.
fack seqnum.Value
@@ -51,15 +65,30 @@ type rackControl struct {
// minRTT is the estimated minimum RTT of the connection.
minRTT time.Duration
+ // reorderSeen indicates if reordering has been detected on this
+ // connection.
+ reorderSeen bool
+
+ // reoWnd is the reordering window time used for recording packet
+ // transmission times. It is used to defer the moment at which RACK
+ // marks a packet lost.
+ reoWnd time.Duration
+
+ // reoWndIncr is the multiplier applied to adjust reorder window.
+ reoWndIncr uint8
+
+ // reoWndPersist is the number of loss recoveries before resetting
+ // reorder window.
+ reoWndPersist int8
+
// rtt is the RTT of the most recently delivered packet on the
// connection (either cumulatively acknowledged or selectively
// acknowledged) that was not marked invalid as a possible spurious
// retransmission.
rtt time.Duration
- // reorderSeen indicates if reordering has been detected on this
- // connection.
- reorderSeen bool
+ // rttSeq is the SND.NXT when rtt is updated.
+ rttSeq seqnum.Value
// xmitTime is the latest transmission timestamp of rackControl.seg.
xmitTime time.Time `state:".(unixTime)"`
@@ -75,29 +104,36 @@ type rackControl struct {
// tlpHighRxt the value of sender.sndNxt at the time of sending
// a TLP retransmission.
tlpHighRxt seqnum.Value
+
+ // snd is a reference to the sender.
+ snd *sender
}
// init initializes RACK specific fields.
-func (rc *rackControl) init() {
+func (rc *rackControl) init(snd *sender, iss seqnum.Value) {
+ rc.fack = iss
+ rc.reoWndIncr = 1
+ rc.snd = snd
rc.probeTimer.init(&rc.probeWaker)
}
// update will update the RACK related fields when an ACK has been received.
-// See: https://tools.ietf.org/html/draft-ietf-tcpm-rack-08#section-7.2
-func (rc *rackControl) update(seg *segment, ackSeg *segment, offset uint32) {
+// See: https://tools.ietf.org/html/draft-ietf-tcpm-rack-09#section-6.2
+func (rc *rackControl) update(seg *segment, ackSeg *segment) {
rtt := time.Now().Sub(seg.xmitTime)
+ tsOffset := rc.snd.ep.tsOffset
// If the ACK is for a retransmitted packet, do not update if it is a
// spurious inference which is determined by below checks:
- // 1. When Timestamping option is available, if the TSVal is less than the
- // transmit time of the most recent retransmitted packet.
+ // 1. When Timestamping option is available, if the TSVal is less than
+ // the transmit time of the most recent retransmitted packet.
// 2. When RTT calculated for the packet is less than the smoothed RTT
// for the connection.
// See: https://tools.ietf.org/html/draft-ietf-tcpm-rack-08#section-7.2
// step 2
if seg.xmitCount > 1 {
if ackSeg.parsedOptions.TS && ackSeg.parsedOptions.TSEcr != 0 {
- if ackSeg.parsedOptions.TSEcr < tcpTimeStamp(seg.xmitTime, offset) {
+ if ackSeg.parsedOptions.TSEcr < tcpTimeStamp(seg.xmitTime, tsOffset) {
return
}
}
@@ -149,9 +185,8 @@ func (rc *rackControl) detectReorder(seg *segment) {
}
}
-// setDSACKSeen updates rack control if duplicate SACK is seen by the connection.
-func (rc *rackControl) setDSACKSeen() {
- rc.dsackSeen = true
+func (rc *rackControl) setDSACKSeen(dsackSeen bool) {
+ rc.dsackSeen = dsackSeen
}
// shouldSchedulePTO dictates whether we should schedule a PTO or not.
@@ -162,7 +197,7 @@ func (s *sender) shouldSchedulePTO() bool {
// The connection supports SACK.
s.ep.sackPermitted &&
// The connection is not in loss recovery.
- (s.state != RTORecovery && s.state != SACKRecovery) &&
+ (s.state != tcpip.RTORecovery && s.state != tcpip.SACKRecovery) &&
// The connection has no SACKed sequences in the SACK scoreboard.
s.ep.scoreboard.Sacked() == 0
}
@@ -193,7 +228,7 @@ func (s *sender) schedulePTO() {
// probeTimerExpired is the same as TLP_send_probe() as defined in
// https://tools.ietf.org/html/draft-ietf-tcpm-rack-08#section-7.5.2.
-func (s *sender) probeTimerExpired() *tcpip.Error {
+func (s *sender) probeTimerExpired() tcpip.Error {
if !s.rc.probeTimer.checkExpiration() {
return nil
}
@@ -272,3 +307,82 @@ func (s *sender) detectTLPRecovery(ack seqnum.Value, rcvdSeg *segment) {
}
}
}
+
+// updateRACKReorderWindow updates the reorder window.
+// See: https://tools.ietf.org/html/draft-ietf-tcpm-rack-08#section-7.2
+// * Step 4: Update RACK reordering window
+// To handle the prevalent small degree of reordering, RACK.reo_wnd serves as
+// an allowance for settling time before marking a packet lost. RACK starts
+// initially with a conservative window of min_RTT/4. If no reordering has
+// been observed RACK uses reo_wnd of zero during loss recovery, in order to
+// retransmit quickly, or when the number of DUPACKs exceeds the classic
+// DUPACKthreshold.
+func (rc *rackControl) updateRACKReorderWindow(ackSeg *segment) {
+ dsackSeen := rc.dsackSeen
+ snd := rc.snd
+
+ // React to DSACK once per round trip.
+ // If SND.UNA < RACK.rtt_seq:
+ // RACK.dsack = false
+ if snd.sndUna.LessThan(rc.rttSeq) {
+ dsackSeen = false
+ }
+
+ // If RACK.dsack:
+ // RACK.reo_wnd_incr += 1
+ // RACK.dsack = false
+ // RACK.rtt_seq = SND.NXT
+ // RACK.reo_wnd_persist = 16
+ if dsackSeen {
+ rc.reoWndIncr++
+ dsackSeen = false
+ rc.rttSeq = snd.sndNxt
+ rc.reoWndPersist = tcpRACKRecoveryThreshold
+ } else if rc.exitedRecovery {
+ // Else if exiting loss recovery:
+ // RACK.reo_wnd_persist -= 1
+ // If RACK.reo_wnd_persist <= 0:
+ // RACK.reo_wnd_incr = 1
+ rc.reoWndPersist--
+ if rc.reoWndPersist <= 0 {
+ rc.reoWndIncr = 1
+ }
+ rc.exitedRecovery = false
+ }
+
+ // Reorder window is zero during loss recovery, or when the number of
+ // DUPACKs exceeds the classic DUPACKthreshold.
+ // If RACK.reord is FALSE:
+ // If in loss recovery: (If in fast or timeout recovery)
+ // RACK.reo_wnd = 0
+ // Return
+ // Else if RACK.pkts_sacked >= RACK.dupthresh:
+ // RACK.reo_wnd = 0
+ // return
+ if !rc.reorderSeen {
+ if snd.state == tcpip.RTORecovery || snd.state == tcpip.SACKRecovery {
+ rc.reoWnd = 0
+ return
+ }
+
+ if snd.sackedOut >= nDupAckThreshold {
+ rc.reoWnd = 0
+ return
+ }
+ }
+
+ // Calculate reorder window.
+ // RACK.reo_wnd = RACK.min_RTT / 4 * RACK.reo_wnd_incr
+ // RACK.reo_wnd = min(RACK.reo_wnd, SRTT)
+ snd.rtt.Lock()
+ srtt := snd.rtt.srtt
+ snd.rtt.Unlock()
+ rc.reoWnd = time.Duration((int64(rc.minRTT) / 4) * int64(rc.reoWndIncr))
+ if srtt < rc.reoWnd {
+ rc.reoWnd = srtt
+ }
+}
+
+func (rc *rackControl) exitRecovery() {
+ rc.exitedRecovery = true
+}
diff --git a/pkg/tcpip/transport/tcp/rcv.go b/pkg/tcpip/transport/tcp/rcv.go
index 405a6dce7..7a7c402c4 100644
--- a/pkg/tcpip/transport/tcp/rcv.go
+++ b/pkg/tcpip/transport/tcp/rcv.go
@@ -347,7 +347,7 @@ func (r *receiver) updateRTT() {
r.ep.rcvListMu.Unlock()
}
-func (r *receiver) handleRcvdSegmentClosing(s *segment, state EndpointState, closed bool) (drop bool, err *tcpip.Error) {
+func (r *receiver) handleRcvdSegmentClosing(s *segment, state EndpointState, closed bool) (drop bool, err tcpip.Error) {
r.ep.rcvListMu.Lock()
rcvClosed := r.ep.rcvClosed || r.closed
r.ep.rcvListMu.Unlock()
@@ -395,7 +395,7 @@ func (r *receiver) handleRcvdSegmentClosing(s *segment, state EndpointState, clo
// trigger a RST.
endDataSeq := s.sequenceNumber.Add(seqnum.Size(s.data.Size()))
if state != StateCloseWait && rcvClosed && r.rcvNxt.LessThan(endDataSeq) {
- return true, tcpip.ErrConnectionAborted
+ return true, &tcpip.ErrConnectionAborted{}
}
if state == StateFinWait1 {
break
@@ -424,7 +424,7 @@ func (r *receiver) handleRcvdSegmentClosing(s *segment, state EndpointState, clo
// the last actual data octet in a segment in
// which it occurs.
if closed && (!s.flagIsSet(header.TCPFlagFin) || s.sequenceNumber.Add(s.logicalLen()) != r.rcvNxt+1) {
- return true, tcpip.ErrConnectionAborted
+ return true, &tcpip.ErrConnectionAborted{}
}
}
@@ -443,7 +443,7 @@ func (r *receiver) handleRcvdSegmentClosing(s *segment, state EndpointState, clo
// handleRcvdSegment handles TCP segments directed at the connection managed by
// r as they arrive. It is called by the protocol main loop.
-func (r *receiver) handleRcvdSegment(s *segment) (drop bool, err *tcpip.Error) {
+func (r *receiver) handleRcvdSegment(s *segment) (drop bool, err tcpip.Error) {
state := r.ep.EndpointState()
closed := r.ep.closed
diff --git a/pkg/tcpip/transport/tcp/snd.go b/pkg/tcpip/transport/tcp/snd.go
index 079d90848..dfc8fd248 100644
--- a/pkg/tcpip/transport/tcp/snd.go
+++ b/pkg/tcpip/transport/tcp/snd.go
@@ -48,28 +48,6 @@ const (
MaxRetries = 15
)
-// ccState indicates the current congestion control state for this sender.
-type ccState int
-
-const (
- // Open indicates that the sender is receiving acks in order and
- // no loss or dupACK's etc have been detected.
- Open ccState = iota
- // RTORecovery indicates that an RTO has occurred and the sender
- // has entered an RTO based recovery phase.
- RTORecovery
- // FastRecovery indicates that the sender has entered FastRecovery
- // based on receiving nDupAck's. This state is entered only when
- // SACK is not in use.
- FastRecovery
- // SACKRecovery indicates that the sender has entered SACK based
- // recovery.
- SACKRecovery
- // Disorder indicates the sender either received some SACK blocks
- // or dupACK's.
- Disorder
-)
-
// congestionControl is an interface that must be implemented by any supported
// congestion control algorithm.
type congestionControl interface {
@@ -204,7 +182,7 @@ type sender struct {
maxSentAck seqnum.Value
// state is the current state of congestion control for this endpoint.
- state ccState
+ state tcpip.CongestionControlState
// cc is the congestion control algorithm in use for this sender.
cc congestionControl
@@ -280,14 +258,9 @@ func newSender(ep *endpoint, iss, irs seqnum.Value, sndWnd seqnum.Size, mss uint
highRxt: iss,
rescueRxt: iss,
},
- rc: rackControl{
- fack: iss,
- },
gso: ep.gso != nil,
}
- s.rc.init()
-
if s.gso {
s.ep.gso.MSS = uint16(maxPayloadSize)
}
@@ -295,6 +268,7 @@ func newSender(ep *endpoint, iss, irs seqnum.Value, sndWnd seqnum.Size, mss uint
s.cc = s.initCongestionControl(ep.cc)
s.lr = s.initLossRecovery()
+ s.rc.init(s, iss)
// A negative sndWndScale means that no scaling is in use, otherwise we
// store the scaling value.
@@ -593,7 +567,7 @@ func (s *sender) retransmitTimerExpired() bool {
s.leaveRecovery()
}
- s.state = RTORecovery
+ s.state = tcpip.RTORecovery
s.cc.HandleRTOExpired()
// Mark the next segment to be sent as the first unacknowledged one and
@@ -1018,7 +992,7 @@ func (s *sender) sendData() {
// "A TCP SHOULD set cwnd to no more than RW before beginning
// transmission if the TCP has not sent data in the interval exceeding
// the retrasmission timeout."
- if !s.fr.active && s.state != RTORecovery && time.Now().Sub(s.lastSendTime) > s.rto {
+ if !s.fr.active && s.state != tcpip.RTORecovery && time.Now().Sub(s.lastSendTime) > s.rto {
if s.sndCwnd > InitialCwnd {
s.sndCwnd = InitialCwnd
}
@@ -1062,14 +1036,14 @@ func (s *sender) enterRecovery() {
s.fr.highRxt = s.sndUna
s.fr.rescueRxt = s.sndUna
if s.ep.sackPermitted {
- s.state = SACKRecovery
+ s.state = tcpip.SACKRecovery
s.ep.stack.Stats().TCP.SACKRecovery.Increment()
// Set TLPRxtOut to false according to
// https://tools.ietf.org/html/draft-ietf-tcpm-rack-08#section-7.6.1.
s.rc.tlpRxtOut = false
return
}
- s.state = FastRecovery
+ s.state = tcpip.FastRecovery
s.ep.stack.Stats().TCP.FastRecovery.Increment()
}
@@ -1080,7 +1054,6 @@ func (s *sender) leaveRecovery() {
// Deflate cwnd. It had been artificially inflated when new dups arrived.
s.sndCwnd = s.sndSsthresh
-
s.cc.PostRecovery()
}
@@ -1166,7 +1139,7 @@ func (s *sender) detectLoss(seg *segment) (fastRetransmit bool) {
s.fr.highRxt = s.sndUna - 1
// Do run SetPipe() to calculate the outstanding segments.
s.SetPipe()
- s.state = Disorder
+ s.state = tcpip.Disorder
return false
}
@@ -1217,11 +1190,13 @@ func (s *sender) isDupAck(seg *segment) bool {
// See: https://tools.ietf.org/html/draft-ietf-tcpm-rack-08#section-7.2
// steps 2 and 3.
func (s *sender) walkSACK(rcvdSeg *segment) {
+ s.rc.setDSACKSeen(false)
+
// Look for DSACK block.
idx := 0
n := len(rcvdSeg.parsedOptions.SACKBlocks)
if checkDSACK(rcvdSeg) {
- s.rc.setDSACKSeen()
+ s.rc.setDSACKSeen(true)
idx = 1
n--
}
@@ -1242,7 +1217,7 @@ func (s *sender) walkSACK(rcvdSeg *segment) {
for _, sb := range sackBlocks {
for seg != nil && seg.sequenceNumber.LessThan(sb.End) && seg.xmitCount != 0 {
if sb.Start.LessThanEq(seg.sequenceNumber) && !seg.acked {
- s.rc.update(seg, rcvdSeg, s.ep.tsOffset)
+ s.rc.update(seg, rcvdSeg)
s.rc.detectReorder(seg)
seg.acked = true
s.sackedOut += s.pCount(seg, s.maxPayloadSize)
@@ -1412,6 +1387,17 @@ func (s *sender) handleRcvdSegment(rcvdSeg *segment) {
acked := s.sndUna.Size(ack)
s.sndUna = ack
+ // The remote ACK-ing at least 1 byte is an indication that we have a
+ // full-duplex connection to the remote as the only way we will receive an
+ // ACK is if the remote received data that we previously sent.
+ //
+ // As of writing, linux seems to only confirm a route as reachable when
+ // forward progress is made which is indicated by an ACK that removes data
+ // from the retransmit queue.
+ if acked > 0 {
+ s.ep.route.ConfirmReachable()
+ }
+
ackLeft := acked
originalOutstanding := s.outstanding
for ackLeft > 0 {
@@ -1435,7 +1421,7 @@ func (s *sender) handleRcvdSegment(rcvdSeg *segment) {
// Update the RACK fields if SACK is enabled.
if s.ep.sackPermitted && !seg.acked {
- s.rc.update(seg, rcvdSeg, s.ep.tsOffset)
+ s.rc.update(seg, rcvdSeg)
s.rc.detectReorder(seg)
}
@@ -1464,7 +1450,11 @@ func (s *sender) handleRcvdSegment(rcvdSeg *segment) {
if !s.fr.active {
s.cc.Update(originalOutstanding - s.outstanding)
if s.fr.last.LessThan(s.sndUna) {
- s.state = Open
+ s.state = tcpip.Open
+ // Update RACK when we are exiting fast or RTO
+ // recovery as described in the RFC
+ // draft-ietf-tcpm-rack-08 Section-7.2 Step 4.
+ s.rc.exitRecovery()
}
}
@@ -1488,6 +1478,12 @@ func (s *sender) handleRcvdSegment(rcvdSeg *segment) {
}
}
+ // Update RACK reorder window.
+ // See: https://tools.ietf.org/html/draft-ietf-tcpm-rack-08#section-7.2
+ // * Upon receiving an ACK:
+ // * Step 4: Update RACK reordering window
+ s.rc.updateRACKReorderWindow(rcvdSeg)
+
// Now that we've popped all acknowledged data from the retransmit
// queue, retransmit if needed.
if s.fr.active {
@@ -1508,7 +1504,7 @@ func (s *sender) handleRcvdSegment(rcvdSeg *segment) {
}
// sendSegment sends the specified segment.
-func (s *sender) sendSegment(seg *segment) *tcpip.Error {
+func (s *sender) sendSegment(seg *segment) tcpip.Error {
if seg.xmitCount > 0 {
s.ep.stack.Stats().TCP.Retransmits.Increment()
s.ep.stats.SendErrors.Retransmits.Increment()
@@ -1539,7 +1535,7 @@ func (s *sender) sendSegment(seg *segment) *tcpip.Error {
// sendSegmentFromView sends a new segment containing the given payload, flags
// and sequence number.
-func (s *sender) sendSegmentFromView(data buffer.VectorisedView, flags byte, seq seqnum.Value) *tcpip.Error {
+func (s *sender) sendSegmentFromView(data buffer.VectorisedView, flags byte, seq seqnum.Value) tcpip.Error {
s.lastSendTime = time.Now()
if seq == s.rttMeasureSeqNum {
s.rttMeasureTime = s.lastSendTime
diff --git a/pkg/tcpip/transport/tcp/tcp_noracedetector_test.go b/pkg/tcpip/transport/tcp/tcp_noracedetector_test.go
index f7aaee23f..ced3a9c58 100644
--- a/pkg/tcpip/transport/tcp/tcp_noracedetector_test.go
+++ b/pkg/tcpip/transport/tcp/tcp_noracedetector_test.go
@@ -21,13 +21,13 @@
package tcp_test
import (
+ "bytes"
"fmt"
"math"
"testing"
"time"
"gvisor.dev/gvisor/pkg/tcpip"
- "gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
"gvisor.dev/gvisor/pkg/tcpip/transport/tcp/testing/context"
@@ -42,14 +42,16 @@ func TestFastRecovery(t *testing.T) {
c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
const iterations = 3
- data := buffer.NewView(2 * maxPayload * (tcp.InitialCwnd << (iterations + 1)))
+ data := make([]byte, 2*maxPayload*(tcp.InitialCwnd<<(iterations+1)))
for i := range data {
data[i] = byte(i)
}
// Write all the data in one shot. Packets will only be written at the
// MTU size though.
- if _, err := c.EP.Write(tcpip.SlicePayload(data), tcpip.WriteOptions{}); err != nil {
+ var r bytes.Reader
+ r.Reset(data)
+ if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil {
t.Fatalf("Write failed: %s", err)
}
@@ -207,14 +209,16 @@ func TestExponentialIncreaseDuringSlowStart(t *testing.T) {
c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
const iterations = 3
- data := buffer.NewView(maxPayload * (tcp.InitialCwnd << (iterations + 1)))
+ data := make([]byte, maxPayload*(tcp.InitialCwnd<<(iterations+1)))
for i := range data {
data[i] = byte(i)
}
// Write all the data in one shot. Packets will only be written at the
// MTU size though.
- if _, err := c.EP.Write(tcpip.SlicePayload(data), tcpip.WriteOptions{}); err != nil {
+ var r bytes.Reader
+ r.Reset(data)
+ if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil {
t.Fatalf("Write failed: %s", err)
}
@@ -249,14 +253,16 @@ func TestCongestionAvoidance(t *testing.T) {
c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
const iterations = 3
- data := buffer.NewView(2 * maxPayload * (tcp.InitialCwnd << (iterations + 1)))
+ data := make([]byte, 2*maxPayload*(tcp.InitialCwnd<<(iterations+1)))
for i := range data {
data[i] = byte(i)
}
// Write all the data in one shot. Packets will only be written at the
// MTU size though.
- if _, err := c.EP.Write(tcpip.SlicePayload(data), tcpip.WriteOptions{}); err != nil {
+ var r bytes.Reader
+ r.Reset(data)
+ if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil {
t.Fatalf("Write failed: %s", err)
}
@@ -353,15 +359,16 @@ func TestCubicCongestionAvoidance(t *testing.T) {
c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
const iterations = 3
- data := buffer.NewView(2 * maxPayload * (tcp.InitialCwnd << (iterations + 1)))
-
+ data := make([]byte, 2*maxPayload*(tcp.InitialCwnd<<(iterations+1)))
for i := range data {
data[i] = byte(i)
}
// Write all the data in one shot. Packets will only be written at the
// MTU size though.
- if _, err := c.EP.Write(tcpip.SlicePayload(data), tcpip.WriteOptions{}); err != nil {
+ var r bytes.Reader
+ r.Reset(data)
+ if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil {
t.Fatalf("Write failed: %s", err)
}
@@ -462,19 +469,20 @@ func TestRetransmit(t *testing.T) {
c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
const iterations = 3
- data := buffer.NewView(maxPayload * (tcp.InitialCwnd << (iterations + 1)))
+ data := make([]byte, maxPayload*(tcp.InitialCwnd<<(iterations+1)))
for i := range data {
data[i] = byte(i)
}
// Write all the data in two shots. Packets will only be written at the
// MTU size though.
- half := data[:len(data)/2]
- if _, err := c.EP.Write(tcpip.SlicePayload(half), tcpip.WriteOptions{}); err != nil {
+ var r bytes.Reader
+ r.Reset(data[:len(data)/2])
+ if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil {
t.Fatalf("Write failed: %s", err)
}
- half = data[len(data)/2:]
- if _, err := c.EP.Write(tcpip.SlicePayload(half), tcpip.WriteOptions{}); err != nil {
+ r.Reset(data[len(data)/2:])
+ if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil {
t.Fatalf("Write failed: %s", err)
}
diff --git a/pkg/tcpip/transport/tcp/tcp_rack_test.go b/pkg/tcpip/transport/tcp/tcp_rack_test.go
index 342eb5eb8..a6a26b705 100644
--- a/pkg/tcpip/transport/tcp/tcp_rack_test.go
+++ b/pkg/tcpip/transport/tcp/tcp_rack_test.go
@@ -15,11 +15,12 @@
package tcp_test
import (
+ "bytes"
+ "fmt"
"testing"
"time"
"gvisor.dev/gvisor/pkg/tcpip"
- "gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/seqnum"
"gvisor.dev/gvisor/pkg/tcpip/stack"
@@ -61,14 +62,16 @@ func TestRACKUpdate(t *testing.T) {
setStackSACKPermitted(t, c, true)
createConnectedWithSACKAndTS(c)
- data := buffer.NewView(maxPayload)
+ data := make([]byte, maxPayload)
for i := range data {
data[i] = byte(i)
}
// Write the data.
xmitTime = time.Now()
- if _, err := c.EP.Write(tcpip.SlicePayload(data), tcpip.WriteOptions{}); err != nil {
+ var r bytes.Reader
+ r.Reset(data)
+ if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil {
t.Fatalf("Write failed: %s", err)
}
@@ -114,13 +117,15 @@ func TestRACKDetectReorder(t *testing.T) {
})
setStackSACKPermitted(t, c, true)
createConnectedWithSACKAndTS(c)
- data := buffer.NewView(ackNumToVerify * maxPayload)
+ data := make([]byte, ackNumToVerify*maxPayload)
for i := range data {
data[i] = byte(i)
}
// Write the data.
- if _, err := c.EP.Write(tcpip.SlicePayload(data), tcpip.WriteOptions{}); err != nil {
+ var r bytes.Reader
+ r.Reset(data)
+ if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil {
t.Fatalf("Write failed: %s", err)
}
@@ -141,17 +146,19 @@ func TestRACKDetectReorder(t *testing.T) {
<-probeDone
}
-func sendAndReceive(t *testing.T, c *context.Context, numPackets int) buffer.View {
+func sendAndReceive(t *testing.T, c *context.Context, numPackets int) []byte {
setStackSACKPermitted(t, c, true)
createConnectedWithSACKAndTS(c)
- data := buffer.NewView(numPackets * maxPayload)
+ data := make([]byte, numPackets*maxPayload)
for i := range data {
data[i] = byte(i)
}
// Write the data.
- if _, err := c.EP.Write(tcpip.SlicePayload(data), tcpip.WriteOptions{}); err != nil {
+ var r bytes.Reader
+ r.Reset(data)
+ if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil {
t.Fatalf("Write failed: %s", err)
}
@@ -528,3 +535,64 @@ func TestRACKWithInvalidDSACKBlock(t *testing.T) {
// ACK before the test completes.
<-probeDone
}
+
+func addReorderWindowCheckerProbe(c *context.Context, numACK int, probeDone chan error) {
+ var n int
+ c.Stack().AddTCPProbe(func(state stack.TCPEndpointState) {
+ // Validate that RACK detects DSACK.
+ n++
+ if n < numACK {
+ return
+ }
+
+ if state.Sender.RACKState.ReoWnd == 0 || state.Sender.RACKState.ReoWnd > state.Sender.SRTT {
+ probeDone <- fmt.Errorf("got RACKState.ReoWnd: %v, expected it to be greater than 0 and less than %v", state.Sender.RACKState.ReoWnd, state.Sender.SRTT)
+ return
+ }
+
+ if state.Sender.RACKState.ReoWndIncr != 1 {
+ probeDone <- fmt.Errorf("got RACKState.ReoWndIncr: %v, want: 1", state.Sender.RACKState.ReoWndIncr)
+ return
+ }
+
+ if state.Sender.RACKState.ReoWndPersist > 0 {
+ probeDone <- fmt.Errorf("got RACKState.ReoWndPersist: %v, want: greater than 0", state.Sender.RACKState.ReoWndPersist)
+ return
+ }
+ probeDone <- nil
+ })
+}
+
+func TestRACKCheckReorderWindow(t *testing.T) {
+ c := context.New(t, uint32(mtu))
+ defer c.Cleanup()
+
+ probeDone := make(chan error)
+ const ackNumToVerify = 3
+ addReorderWindowCheckerProbe(c, ackNumToVerify, probeDone)
+
+ const numPackets = 7
+ sendAndReceive(t, c, numPackets)
+
+ // Send ACK for #1 packet.
+ bytesRead := maxPayload
+ seq := seqnum.Value(context.TestInitialSequenceNumber).Add(1)
+ c.SendAck(seq, bytesRead)
+
+ // Missing [2-6] packets and SACK #7 packet.
+ seq = seqnum.Value(context.TestInitialSequenceNumber).Add(1)
+ start := c.IRS.Add(1 + seqnum.Size(6*maxPayload))
+ end := start.Add(seqnum.Size(maxPayload))
+ c.SendAckWithSACK(seq, bytesRead, []header.SACKBlock{{start, end}})
+
+ // Received delayed packets [2-6] which indicates there is reordering
+ // in the connection.
+ bytesRead += 6 * maxPayload
+ c.SendAck(seq, bytesRead)
+
+ // Wait for the probe function to finish processing the ACK before the
+ // test completes.
+ if err := <-probeDone; err != nil {
+ t.Fatalf("unexpected values for RACK variables: %v", err)
+ }
+}
diff --git a/pkg/tcpip/transport/tcp/tcp_sack_test.go b/pkg/tcpip/transport/tcp/tcp_sack_test.go
index 6635bb815..5024bc925 100644
--- a/pkg/tcpip/transport/tcp/tcp_sack_test.go
+++ b/pkg/tcpip/transport/tcp/tcp_sack_test.go
@@ -15,6 +15,7 @@
package tcp_test
import (
+ "bytes"
"fmt"
"log"
"reflect"
@@ -22,7 +23,6 @@ import (
"time"
"gvisor.dev/gvisor/pkg/tcpip"
- "gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/seqnum"
"gvisor.dev/gvisor/pkg/tcpip/stack"
@@ -395,14 +395,16 @@ func TestSACKRecovery(t *testing.T) {
createConnectedWithSACKAndTS(c)
const iterations = 3
- data := buffer.NewView(2 * maxPayload * (tcp.InitialCwnd << (iterations + 1)))
+ data := make([]byte, 2*maxPayload*(tcp.InitialCwnd<<(iterations+1)))
for i := range data {
data[i] = byte(i)
}
// Write all the data in one shot. Packets will only be written at the
// MTU size though.
- if _, err := c.EP.Write(tcpip.SlicePayload(data), tcpip.WriteOptions{}); err != nil {
+ var r bytes.Reader
+ r.Reset(data)
+ if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil {
t.Fatalf("Write failed: %s", err)
}
diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go
index 93683b921..da2730e27 100644
--- a/pkg/tcpip/transport/tcp/tcp_test.go
+++ b/pkg/tcpip/transport/tcp/tcp_test.go
@@ -19,6 +19,7 @@ import (
"fmt"
"io/ioutil"
"math"
+ "strings"
"testing"
"time"
@@ -26,7 +27,6 @@ import (
"gvisor.dev/gvisor/pkg/rand"
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
- "gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/checker"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/link/loopback"
@@ -48,7 +48,7 @@ type endpointTester struct {
}
// CheckReadError issues a read to the endpoint and checking for an error.
-func (e *endpointTester) CheckReadError(t *testing.T, want *tcpip.Error) {
+func (e *endpointTester) CheckReadError(t *testing.T, want tcpip.Error) {
t.Helper()
res, got := e.ep.Read(ioutil.Discard, tcpip.ReadOptions{})
if got != want {
@@ -87,7 +87,7 @@ func (e *endpointTester) CheckReadFull(t *testing.T, count int, notifyRead <-cha
}
for w.N != 0 {
_, err := e.ep.Read(&w, tcpip.ReadOptions{})
- if err == tcpip.ErrWouldBlock {
+ if _, ok := err.(*tcpip.ErrWouldBlock); ok {
// Wait for receive to be notified.
select {
case <-notifyRead:
@@ -128,8 +128,11 @@ func TestGiveUpConnect(t *testing.T) {
wq.EventRegister(&waitEntry, waiter.EventHUp)
defer wq.EventUnregister(&waitEntry)
- if err := ep.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}); err != tcpip.ErrConnectStarted {
- t.Fatalf("got ep.Connect(...) = %s, want = %s", err, tcpip.ErrConnectStarted)
+ {
+ err := ep.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort})
+ if _, ok := err.(*tcpip.ErrConnectStarted); !ok {
+ t.Fatalf("got ep.Connect(...) = %v, want = %s", err, &tcpip.ErrConnectStarted{})
+ }
}
// Close the connection, wait for completion.
@@ -140,8 +143,11 @@ func TestGiveUpConnect(t *testing.T) {
// Call Connect again to retreive the handshake failure status
// and stats updates.
- if err := ep.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}); err != tcpip.ErrAborted {
- t.Fatalf("got ep.Connect(...) = %s, want = %s", err, tcpip.ErrAborted)
+ {
+ err := ep.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort})
+ if _, ok := err.(*tcpip.ErrAborted); !ok {
+ t.Fatalf("got ep.Connect(...) = %v, want = %s", err, &tcpip.ErrAborted{})
+ }
}
if got := c.Stack().Stats().TCP.FailedConnectionAttempts.Value(); got != 1 {
@@ -194,8 +200,11 @@ func TestActiveFailedConnectionAttemptIncrement(t *testing.T) {
c.EP = ep
want := stats.TCP.FailedConnectionAttempts.Value() + 1
- if err := c.EP.Connect(tcpip.FullAddress{NIC: 2, Addr: context.TestAddr, Port: context.TestPort}); err != tcpip.ErrNoRoute {
- t.Errorf("got c.EP.Connect(...) = %s, want = %s", err, tcpip.ErrNoRoute)
+ {
+ err := c.EP.Connect(tcpip.FullAddress{NIC: 2, Addr: context.TestAddr, Port: context.TestPort})
+ if _, ok := err.(*tcpip.ErrNoRoute); !ok {
+ t.Errorf("got c.EP.Connect(...) = %v, want = %s", err, &tcpip.ErrNoRoute{})
+ }
}
if got := stats.TCP.FailedConnectionAttempts.Value(); got != want {
@@ -211,7 +220,7 @@ func TestCloseWithoutConnect(t *testing.T) {
defer c.Cleanup()
// Create TCP endpoint.
- var err *tcpip.Error
+ var err tcpip.Error
c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ)
if err != nil {
t.Fatalf("NewEndpoint failed: %s", err)
@@ -384,7 +393,7 @@ func TestTCPResetSentForACKWhenNotUsingSynCookies(t *testing.T) {
defer wq.EventUnregister(&we)
c.EP, _, err = ep.Accept(nil)
- if err == tcpip.ErrWouldBlock {
+ if _, ok := err.(*tcpip.ErrWouldBlock); ok {
// Wait for connection to be established.
select {
case <-ch:
@@ -925,8 +934,11 @@ func TestUserSuppliedMSSOnConnect(t *testing.T) {
ws := tcp.FindWndScale(seqnum.Size(rcvBufSize))
connectAddr := tcpip.FullAddress{Addr: ip.connectAddr, Port: context.TestPort}
- if err := c.EP.Connect(connectAddr); err != tcpip.ErrConnectStarted {
- t.Fatalf("Connect(%+v): %s", connectAddr, err)
+ {
+ err := c.EP.Connect(connectAddr)
+ if _, ok := err.(*tcpip.ErrConnectStarted); !ok {
+ t.Fatalf("Connect(%+v): %s", connectAddr, err)
+ }
}
// Receive SYN packet with our user supplied MSS.
@@ -1347,10 +1359,9 @@ func TestTOSV4(t *testing.T) {
testV4Connect(t, c, checker.TOS(tos, 0))
data := []byte{1, 2, 3}
- view := buffer.NewView(len(data))
- copy(view, data)
-
- if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
+ var r bytes.Reader
+ r.Reset(data)
+ if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil {
t.Fatalf("Write failed: %s", err)
}
@@ -1396,10 +1407,9 @@ func TestTrafficClassV6(t *testing.T) {
testV6Connect(t, c, checker.TOS(tos, 0))
data := []byte{1, 2, 3}
- view := buffer.NewView(len(data))
- copy(view, data)
-
- if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
+ var r bytes.Reader
+ r.Reset(data)
+ if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil {
t.Fatalf("Write failed: %s", err)
}
@@ -1444,7 +1454,8 @@ func TestConnectBindToDevice(t *testing.T) {
c.WQ.EventRegister(&waitEntry, waiter.EventOut)
defer c.WQ.EventUnregister(&waitEntry)
- if err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}); err != tcpip.ErrConnectStarted {
+ err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort})
+ if _, ok := err.(*tcpip.ErrConnectStarted); !ok {
t.Fatalf("unexpected return value from Connect: %s", err)
}
@@ -1504,8 +1515,9 @@ func TestSynSent(t *testing.T) {
defer c.WQ.EventUnregister(&waitEntry)
addr := tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}
- if err := c.EP.Connect(addr); err != tcpip.ErrConnectStarted {
- t.Fatalf("got Connect(%+v) = %s, want %s", addr, err, tcpip.ErrConnectStarted)
+ err := c.EP.Connect(addr)
+ if _, ok := err.(*tcpip.ErrConnectStarted); !ok {
+ t.Fatalf("got Connect(%+v) = %v, want %s", addr, err, &tcpip.ErrConnectStarted{})
}
// Receive SYN packet.
@@ -1550,9 +1562,9 @@ func TestSynSent(t *testing.T) {
ept := endpointTester{c.EP}
if test.reset {
- ept.CheckReadError(t, tcpip.ErrConnectionRefused)
+ ept.CheckReadError(t, &tcpip.ErrConnectionRefused{})
} else {
- ept.CheckReadError(t, tcpip.ErrAborted)
+ ept.CheckReadError(t, &tcpip.ErrAborted{})
}
if got := c.Stack().Stats().TCP.CurrentConnected.Value(); got != 0 {
@@ -1578,7 +1590,7 @@ func TestOutOfOrderReceive(t *testing.T) {
defer c.WQ.EventUnregister(&we)
ept := endpointTester{c.EP}
- ept.CheckReadError(t, tcpip.ErrWouldBlock)
+ ept.CheckReadError(t, &tcpip.ErrWouldBlock{})
// Send second half of data first, with seqnum 3 ahead of expected.
data := []byte{1, 2, 3, 4, 5, 6}
@@ -1603,7 +1615,7 @@ func TestOutOfOrderReceive(t *testing.T) {
// Wait 200ms and check that no data has been received.
time.Sleep(200 * time.Millisecond)
- ept.CheckReadError(t, tcpip.ErrWouldBlock)
+ ept.CheckReadError(t, &tcpip.ErrWouldBlock{})
// Send the first 3 bytes now.
c.SendPacket(data[:3], &context.Headers{
@@ -1642,7 +1654,7 @@ func TestOutOfOrderFlood(t *testing.T) {
c.CreateConnected(789, 30000, rcvBufSz)
ept := endpointTester{c.EP}
- ept.CheckReadError(t, tcpip.ErrWouldBlock)
+ ept.CheckReadError(t, &tcpip.ErrWouldBlock{})
// Send 100 packets before the actual one that is expected.
data := []byte{1, 2, 3, 4, 5, 6}
@@ -1718,7 +1730,7 @@ func TestRstOnCloseWithUnreadData(t *testing.T) {
defer c.WQ.EventUnregister(&we)
ept := endpointTester{c.EP}
- ept.CheckReadError(t, tcpip.ErrWouldBlock)
+ ept.CheckReadError(t, &tcpip.ErrWouldBlock{})
data := []byte{1, 2, 3}
c.SendPacket(data, &context.Headers{
@@ -1786,7 +1798,7 @@ func TestRstOnCloseWithUnreadDataFinConvertRst(t *testing.T) {
defer c.WQ.EventUnregister(&we)
ept := endpointTester{c.EP}
- ept.CheckReadError(t, tcpip.ErrWouldBlock)
+ ept.CheckReadError(t, &tcpip.ErrWouldBlock{})
data := []byte{1, 2, 3}
c.SendPacket(data, &context.Headers{
@@ -1868,13 +1880,13 @@ func TestShutdownRead(t *testing.T) {
c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
ept := endpointTester{c.EP}
- ept.CheckReadError(t, tcpip.ErrWouldBlock)
+ ept.CheckReadError(t, &tcpip.ErrWouldBlock{})
if err := c.EP.Shutdown(tcpip.ShutdownRead); err != nil {
t.Fatalf("Shutdown failed: %s", err)
}
- ept.CheckReadError(t, tcpip.ErrClosedForReceive)
+ ept.CheckReadError(t, &tcpip.ErrClosedForReceive{})
var want uint64 = 1
if got := c.EP.Stats().(*tcp.Stats).ReadErrors.ReadClosed.Value(); got != want {
t.Fatalf("got EP stats Stats.ReadErrors.ReadClosed got %d want %d", got, want)
@@ -1893,7 +1905,7 @@ func TestFullWindowReceive(t *testing.T) {
defer c.WQ.EventUnregister(&we)
ept := endpointTester{c.EP}
- ept.CheckReadError(t, tcpip.ErrWouldBlock)
+ ept.CheckReadError(t, &tcpip.ErrWouldBlock{})
// Fill up the window w/ tcp.SegOverheadFactor*rcvBufSz as netstack multiplies
// the provided buffer value by tcp.SegOverheadFactor to calculate the actual
@@ -2054,7 +2066,7 @@ func TestNoWindowShrinking(t *testing.T) {
defer c.WQ.EventUnregister(&we)
ept := endpointTester{c.EP}
- ept.CheckReadError(t, tcpip.ErrWouldBlock)
+ ept.CheckReadError(t, &tcpip.ErrWouldBlock{})
// Send a 1 byte payload so that we can record the current receive window.
// Send a payload of half the size of rcvBufSize.
@@ -2176,10 +2188,9 @@ func TestSimpleSend(t *testing.T) {
c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
data := []byte{1, 2, 3}
- view := buffer.NewView(len(data))
- copy(view, data)
-
- if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
+ var r bytes.Reader
+ r.Reset(data)
+ if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil {
t.Fatalf("Write failed: %s", err)
}
@@ -2217,10 +2228,9 @@ func TestZeroWindowSend(t *testing.T) {
c.CreateConnected(789 /* iss */, 0 /* rcvWnd */, -1 /* epRcvBuf */)
data := []byte{1, 2, 3}
- view := buffer.NewView(len(data))
- copy(view, data)
-
- if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
+ var r bytes.Reader
+ r.Reset(data)
+ if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil {
t.Fatalf("Write failed: %s", err)
}
@@ -2285,10 +2295,9 @@ func TestScaledWindowConnect(t *testing.T) {
})
data := []byte{1, 2, 3}
- view := buffer.NewView(len(data))
- copy(view, data)
-
- if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
+ var r bytes.Reader
+ r.Reset(data)
+ if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil {
t.Fatalf("Write failed: %s", err)
}
@@ -2317,10 +2326,9 @@ func TestNonScaledWindowConnect(t *testing.T) {
c.CreateConnected(789, 30000, 65535*3)
data := []byte{1, 2, 3}
- view := buffer.NewView(len(data))
- copy(view, data)
-
- if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
+ var r bytes.Reader
+ r.Reset(data)
+ if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil {
t.Fatalf("Write failed: %s", err)
}
@@ -2376,7 +2384,7 @@ func TestScaledWindowAccept(t *testing.T) {
defer wq.EventUnregister(&we)
c.EP, _, err = ep.Accept(nil)
- if err == tcpip.ErrWouldBlock {
+ if _, ok := err.(*tcpip.ErrWouldBlock); ok {
// Wait for connection to be established.
select {
case <-ch:
@@ -2391,10 +2399,9 @@ func TestScaledWindowAccept(t *testing.T) {
}
data := []byte{1, 2, 3}
- view := buffer.NewView(len(data))
- copy(view, data)
-
- if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
+ var r bytes.Reader
+ r.Reset(data)
+ if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil {
t.Fatalf("Write failed: %s", err)
}
@@ -2450,7 +2457,7 @@ func TestNonScaledWindowAccept(t *testing.T) {
defer wq.EventUnregister(&we)
c.EP, _, err = ep.Accept(nil)
- if err == tcpip.ErrWouldBlock {
+ if _, ok := err.(*tcpip.ErrWouldBlock); ok {
// Wait for connection to be established.
select {
case <-ch:
@@ -2465,10 +2472,9 @@ func TestNonScaledWindowAccept(t *testing.T) {
}
data := []byte{1, 2, 3}
- view := buffer.NewView(len(data))
- copy(view, data)
-
- if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
+ var r bytes.Reader
+ r.Reset(data)
+ if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil {
t.Fatalf("Write failed: %s", err)
}
@@ -2632,9 +2638,10 @@ func TestSegmentMerging(t *testing.T) {
// Send tcp.InitialCwnd number of segments to fill up
// InitialWindow but don't ACK. That should prevent
// anymore packets from going out.
+ var r bytes.Reader
for i := 0; i < tcp.InitialCwnd; i++ {
- view := buffer.NewViewFromBytes([]byte{0})
- if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
+ r.Reset([]byte{0})
+ if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil {
t.Fatalf("Write #%d failed: %s", i+1, err)
}
}
@@ -2644,8 +2651,8 @@ func TestSegmentMerging(t *testing.T) {
var allData []byte
for i, data := range [][]byte{{1, 2, 3, 4}, {5, 6, 7}, {8, 9}, {10}, {11}} {
allData = append(allData, data...)
- view := buffer.NewViewFromBytes(data)
- if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
+ r.Reset(data)
+ if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil {
t.Fatalf("Write #%d failed: %s", i+1, err)
}
}
@@ -2714,8 +2721,9 @@ func TestDelay(t *testing.T) {
var allData []byte
for i, data := range [][]byte{{0}, {1, 2, 3, 4}, {5, 6, 7}, {8, 9}, {10}, {11}} {
allData = append(allData, data...)
- view := buffer.NewViewFromBytes(data)
- if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
+ var r bytes.Reader
+ r.Reset(data)
+ if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil {
t.Fatalf("Write #%d failed: %s", i+1, err)
}
}
@@ -2761,8 +2769,9 @@ func TestUndelay(t *testing.T) {
allData := [][]byte{{0}, {1, 2, 3}}
for i, data := range allData {
- view := buffer.NewViewFromBytes(data)
- if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
+ var r bytes.Reader
+ r.Reset(data)
+ if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil {
t.Fatalf("Write #%d failed: %s", i+1, err)
}
}
@@ -2845,8 +2854,9 @@ func TestMSSNotDelayed(t *testing.T) {
allData := [][]byte{{0}, make([]byte, maxPayload), make([]byte, maxPayload)}
for i, data := range allData {
- view := buffer.NewViewFromBytes(data)
- if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
+ var r bytes.Reader
+ r.Reset(data)
+ if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil {
t.Fatalf("Write #%d failed: %s", i+1, err)
}
}
@@ -2894,10 +2904,9 @@ func testBrokenUpWrite(t *testing.T, c *context.Context, maxPayload int) {
data[i] = byte(i)
}
- view := buffer.NewView(len(data))
- copy(view, data)
-
- if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
+ var r bytes.Reader
+ r.Reset(data)
+ if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil {
t.Fatalf("Write failed: %s", err)
}
@@ -2963,7 +2972,7 @@ func TestSetTTL(t *testing.T) {
c := context.New(t, 65535)
defer c.Cleanup()
- var err *tcpip.Error
+ var err tcpip.Error
c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{})
if err != nil {
t.Fatalf("NewEndpoint failed: %s", err)
@@ -2973,8 +2982,11 @@ func TestSetTTL(t *testing.T) {
t.Fatalf("SetSockOptInt(TTLOption, %d) failed: %s", wantTTL, err)
}
- if err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}); err != tcpip.ErrConnectStarted {
- t.Fatalf("unexpected return value from Connect: %s", err)
+ {
+ err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort})
+ if _, ok := err.(*tcpip.ErrConnectStarted); !ok {
+ t.Fatalf("unexpected return value from Connect: %s", err)
+ }
}
// Receive SYN packet.
@@ -3034,7 +3046,7 @@ func TestPassiveSendMSSLessThanMTU(t *testing.T) {
defer wq.EventUnregister(&we)
c.EP, _, err = ep.Accept(nil)
- if err == tcpip.ErrWouldBlock {
+ if _, ok := err.(*tcpip.ErrWouldBlock); ok {
// Wait for connection to be established.
select {
case <-ch:
@@ -3090,7 +3102,7 @@ func TestSynCookiePassiveSendMSSLessThanMTU(t *testing.T) {
defer wq.EventUnregister(&we)
c.EP, _, err = ep.Accept(nil)
- if err == tcpip.ErrWouldBlock {
+ if _, ok := err.(*tcpip.ErrWouldBlock); ok {
// Wait for connection to be established.
select {
case <-ch:
@@ -3115,9 +3127,9 @@ func TestForwarderSendMSSLessThanMTU(t *testing.T) {
defer c.Cleanup()
s := c.Stack()
- ch := make(chan *tcpip.Error, 1)
+ ch := make(chan tcpip.Error, 1)
f := tcp.NewForwarder(s, 65536, 10, func(r *tcp.ForwarderRequest) {
- var err *tcpip.Error
+ var err tcpip.Error
c.EP, err = r.CreateEndpoint(&c.WQ)
ch <- err
})
@@ -3146,7 +3158,7 @@ func TestSynOptionsOnActiveConnect(t *testing.T) {
defer c.Cleanup()
// Create TCP endpoint.
- var err *tcpip.Error
+ var err tcpip.Error
c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ)
if err != nil {
t.Fatalf("NewEndpoint failed: %s", err)
@@ -3165,8 +3177,11 @@ func TestSynOptionsOnActiveConnect(t *testing.T) {
c.WQ.EventRegister(&we, waiter.EventOut)
defer c.WQ.EventUnregister(&we)
- if err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}); err != tcpip.ErrConnectStarted {
- t.Fatalf("got c.EP.Connect(...) = %s, want = %s", err, tcpip.ErrConnectStarted)
+ {
+ err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort})
+ if _, ok := err.(*tcpip.ErrConnectStarted); !ok {
+ t.Fatalf("got c.EP.Connect(...) = %v, want = %s", err, &tcpip.ErrConnectStarted{})
+ }
}
// Receive SYN packet.
@@ -3276,22 +3291,23 @@ func TestReceiveOnResetConnection(t *testing.T) {
loop:
for {
- switch _, err := c.EP.Read(ioutil.Discard, tcpip.ReadOptions{}); err {
- case tcpip.ErrWouldBlock:
+ switch _, err := c.EP.Read(ioutil.Discard, tcpip.ReadOptions{}); err.(type) {
+ case *tcpip.ErrWouldBlock:
select {
case <-ch:
// Expect the state to be StateError and subsequent Reads to fail with HardError.
- if _, err := c.EP.Read(ioutil.Discard, tcpip.ReadOptions{}); err != tcpip.ErrConnectionReset {
- t.Fatalf("got c.EP.Read() = %s, want = %s", err, tcpip.ErrConnectionReset)
+ _, err := c.EP.Read(ioutil.Discard, tcpip.ReadOptions{})
+ if _, ok := err.(*tcpip.ErrConnectionReset); !ok {
+ t.Fatalf("got c.EP.Read() = %v, want = %s", err, &tcpip.ErrConnectionReset{})
}
break loop
case <-time.After(1 * time.Second):
t.Fatalf("Timed out waiting for reset to arrive")
}
- case tcpip.ErrConnectionReset:
+ case *tcpip.ErrConnectionReset:
break loop
default:
- t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrConnectionReset)
+ t.Fatalf("got c.EP.Read(nil) = %v, want = %s", err, &tcpip.ErrConnectionReset{})
}
}
@@ -3328,9 +3344,11 @@ func TestSendOnResetConnection(t *testing.T) {
time.Sleep(1 * time.Second)
// Try to write.
- view := buffer.NewView(10)
- if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != tcpip.ErrConnectionReset {
- t.Fatalf("got c.EP.Write(...) = %s, want = %s", err, tcpip.ErrConnectionReset)
+ var r bytes.Reader
+ r.Reset(make([]byte, 10))
+ _, err := c.EP.Write(&r, tcpip.WriteOptions{})
+ if _, ok := err.(*tcpip.ErrConnectionReset); !ok {
+ t.Fatalf("got c.EP.Write(...) = %v, want = %s", err, &tcpip.ErrConnectionReset{})
}
}
@@ -3352,7 +3370,9 @@ func TestMaxRetransmitsTimeout(t *testing.T) {
c.WQ.EventRegister(&waitEntry, waiter.EventHUp)
defer c.WQ.EventUnregister(&waitEntry)
- _, err := c.EP.Write(tcpip.SlicePayload(buffer.NewView(1)), tcpip.WriteOptions{})
+ var r bytes.Reader
+ r.Reset(make([]byte, 1))
+ _, err := c.EP.Write(&r, tcpip.WriteOptions{})
if err != nil {
t.Fatalf("Write failed: %s", err)
}
@@ -3409,7 +3429,9 @@ func TestMaxRTO(t *testing.T) {
c.CreateConnected(789 /* iss */, 30000 /* rcvWnd */, -1 /* epRcvBuf */)
- _, err := c.EP.Write(tcpip.SlicePayload(buffer.NewView(1)), tcpip.WriteOptions{})
+ var r bytes.Reader
+ r.Reset(make([]byte, 1))
+ _, err := c.EP.Write(&r, tcpip.WriteOptions{})
if err != nil {
t.Fatalf("Write failed: %s", err)
}
@@ -3458,7 +3480,9 @@ func TestRetransmitIPv4IDUniqueness(t *testing.T) {
t.Fatalf("disabling PMTU discovery via sockopt to force DF=0 failed: %s", err)
}
- if _, err := c.EP.Write(tcpip.SlicePayload(buffer.NewView(tc.size)), tcpip.WriteOptions{}); err != nil {
+ var r bytes.Reader
+ r.Reset(make([]byte, tc.size))
+ if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil {
t.Fatalf("Write failed: %s", err)
}
pkt := c.GetPacket()
@@ -3595,8 +3619,10 @@ func TestFinWithNoPendingData(t *testing.T) {
c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
// Write something out, and have it acknowledged.
- view := buffer.NewView(10)
- if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
+ view := make([]byte, 10)
+ var r bytes.Reader
+ r.Reset(view)
+ if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil {
t.Fatalf("Write failed: %s", err)
}
@@ -3667,9 +3693,11 @@ func TestFinWithPendingDataCwndFull(t *testing.T) {
// Write enough segments to fill the congestion window before ACK'ing
// any of them.
- view := buffer.NewView(10)
+ view := make([]byte, 10)
+ var r bytes.Reader
for i := tcp.InitialCwnd; i > 0; i-- {
- if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
+ r.Reset(view)
+ if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil {
t.Fatalf("Write failed: %s", err)
}
}
@@ -3754,8 +3782,10 @@ func TestFinWithPendingData(t *testing.T) {
c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
// Write something out, and acknowledge it to get cwnd to 2.
- view := buffer.NewView(10)
- if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
+ view := make([]byte, 10)
+ var r bytes.Reader
+ r.Reset(view)
+ if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil {
t.Fatalf("Write failed: %s", err)
}
@@ -3781,7 +3811,8 @@ func TestFinWithPendingData(t *testing.T) {
})
// Write new data, but don't acknowledge it.
- if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
+ r.Reset(view)
+ if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil {
t.Fatalf("Write failed: %s", err)
}
@@ -3841,8 +3872,10 @@ func TestFinWithPartialAck(t *testing.T) {
// Write something out, and acknowledge it to get cwnd to 2. Also send
// FIN from the test side.
- view := buffer.NewView(10)
- if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
+ view := make([]byte, 10)
+ var r bytes.Reader
+ r.Reset(view)
+ if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil {
t.Fatalf("Write failed: %s", err)
}
@@ -3879,7 +3912,8 @@ func TestFinWithPartialAck(t *testing.T) {
)
// Write new data, but don't acknowledge it.
- if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
+ r.Reset(view)
+ if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil {
t.Fatalf("Write failed: %s", err)
}
@@ -3985,8 +4019,10 @@ func scaledSendWindow(t *testing.T, scale uint8) {
})
// Send some data. Check that it's capped by the window size.
- view := buffer.NewView(65535)
- if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
+ view := make([]byte, 65535)
+ var r bytes.Reader
+ r.Reset(view)
+ if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil {
t.Fatalf("Write failed: %s", err)
}
@@ -4170,7 +4206,7 @@ func TestReadAfterClosedState(t *testing.T) {
defer c.WQ.EventUnregister(&we)
ept := endpointTester{c.EP}
- ept.CheckReadError(t, tcpip.ErrWouldBlock)
+ ept.CheckReadError(t, &tcpip.ErrWouldBlock{})
// Shutdown immediately for write, check that we get a FIN.
if err := c.EP.Shutdown(tcpip.ShutdownWrite); err != nil {
@@ -4249,10 +4285,13 @@ func TestReadAfterClosedState(t *testing.T) {
// Now that we drained the queue, check that functions fail with the
// right error code.
- ept.CheckReadError(t, tcpip.ErrClosedForReceive)
+ ept.CheckReadError(t, &tcpip.ErrClosedForReceive{})
var buf bytes.Buffer
- if _, err := c.EP.Read(&buf, tcpip.ReadOptions{Peek: true}); err != tcpip.ErrClosedForReceive {
- t.Fatalf("c.EP.Read(_, {Peek: true}) = %v, %s; want _, %s", res, err, tcpip.ErrClosedForReceive)
+ {
+ _, err := c.EP.Read(&buf, tcpip.ReadOptions{Peek: true})
+ if _, ok := err.(*tcpip.ErrClosedForReceive); !ok {
+ t.Fatalf("c.EP.Read(_, {Peek: true}) = %v, %s; want _, %s", res, err, &tcpip.ErrClosedForReceive{})
+ }
}
}
@@ -4263,7 +4302,7 @@ func TestReusePort(t *testing.T) {
defer c.Cleanup()
// First case, just an endpoint that was bound.
- var err *tcpip.Error
+ var err tcpip.Error
c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{})
if err != nil {
t.Fatalf("NewEndpoint failed; %s", err)
@@ -4293,8 +4332,11 @@ func TestReusePort(t *testing.T) {
if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
t.Fatalf("Bind failed: %s", err)
}
- if err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}); err != tcpip.ErrConnectStarted {
- t.Fatalf("got c.EP.Connect(...) = %s, want = %s", err, tcpip.ErrConnectStarted)
+ {
+ err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort})
+ if _, ok := err.(*tcpip.ErrConnectStarted); !ok {
+ t.Fatalf("got c.EP.Connect(...) = %v, want = %s", err, &tcpip.ErrConnectStarted{})
+ }
}
c.EP.Close()
@@ -4351,9 +4393,9 @@ func checkRecvBufferSize(t *testing.T, ep tcpip.Endpoint, v int) {
func checkSendBufferSize(t *testing.T, ep tcpip.Endpoint, v int) {
t.Helper()
- s, err := ep.GetSockOptInt(tcpip.SendBufferSizeOption)
+ s, err := ep.SocketOptions().GetSendBufferSize()
if err != nil {
- t.Fatalf("GetSockOpt failed: %s", err)
+ t.Fatalf("GetSendBufferSize failed: %s", err)
}
if int(s) != v {
@@ -4459,9 +4501,7 @@ func TestMinMaxBufferSizes(t *testing.T) {
checkRecvBufferSize(t, ep, 200)
- if err := ep.SetSockOptInt(tcpip.SendBufferSizeOption, 149); err != nil {
- t.Fatalf("SetSockOptInt(SendBufferSizeOption, 299) failed: %s", err)
- }
+ ep.SocketOptions().SetSendBufferSize(149, true)
checkSendBufferSize(t, ep, 300)
@@ -4473,9 +4513,7 @@ func TestMinMaxBufferSizes(t *testing.T) {
// Values above max are capped at max and then doubled.
checkRecvBufferSize(t, ep, tcp.DefaultReceiveBufferSize*20*2)
- if err := ep.SetSockOptInt(tcpip.SendBufferSizeOption, 1+tcp.DefaultSendBufferSize*30); err != nil {
- t.Fatalf("SetSockOptInt(SendBufferSizeOption) failed: %s", err)
- }
+ ep.SocketOptions().SetSendBufferSize(1+tcp.DefaultSendBufferSize*30, true)
// Values above max are capped at max and then doubled.
checkSendBufferSize(t, ep, tcp.DefaultSendBufferSize*30*2)
@@ -4505,11 +4543,11 @@ func TestBindToDeviceOption(t *testing.T) {
testActions := []struct {
name string
setBindToDevice *tcpip.NICID
- setBindToDeviceError *tcpip.Error
+ setBindToDeviceError tcpip.Error
getBindToDevice int32
}{
{"GetDefaultValue", nil, nil, 0},
- {"BindToNonExistent", nicIDPtr(999), tcpip.ErrUnknownDevice, 0},
+ {"BindToNonExistent", nicIDPtr(999), &tcpip.ErrUnknownDevice{}, 0},
{"BindToExistent", nicIDPtr(321), nil, 321},
{"UnbindToDevice", nicIDPtr(0), nil, 0},
}
@@ -4529,7 +4567,7 @@ func TestBindToDeviceOption(t *testing.T) {
}
}
-func makeStack() (*stack.Stack, *tcpip.Error) {
+func makeStack() (*stack.Stack, tcpip.Error) {
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{
ipv4.NewProtocol,
@@ -4599,8 +4637,11 @@ func TestSelfConnect(t *testing.T) {
wq.EventRegister(&waitEntry, waiter.EventOut)
defer wq.EventUnregister(&waitEntry)
- if err := ep.Connect(tcpip.FullAddress{Addr: context.StackAddr, Port: context.StackPort}); err != tcpip.ErrConnectStarted {
- t.Fatalf("got ep.Connect(...) = %s, want = %s", err, tcpip.ErrConnectStarted)
+ {
+ err := ep.Connect(tcpip.FullAddress{Addr: context.StackAddr, Port: context.StackPort})
+ if _, ok := err.(*tcpip.ErrConnectStarted); !ok {
+ t.Fatalf("got ep.Connect(...) = %v, want = %s", err, &tcpip.ErrConnectStarted{})
+ }
}
<-notifyCh
@@ -4610,9 +4651,9 @@ func TestSelfConnect(t *testing.T) {
// Write something.
data := []byte{1, 2, 3}
- view := buffer.NewView(len(data))
- copy(view, data)
- if _, err := ep.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
+ var r bytes.Reader
+ r.Reset(data)
+ if _, err := ep.Write(&r, tcpip.WriteOptions{}); err != nil {
t.Fatalf("Write failed: %s", err)
}
@@ -4752,9 +4793,9 @@ func TestConnectAvoidsBoundPorts(t *testing.T) {
t.Fatalf("Bind(%d) failed: %s", i, err)
}
}
- want := tcpip.ErrConnectStarted
+ var want tcpip.Error = &tcpip.ErrConnectStarted{}
if collides {
- want = tcpip.ErrNoPortAvailable
+ want = &tcpip.ErrNoPortAvailable{}
}
if err := makeEP(candidateNetwork).Connect(tcpip.FullAddress{Addr: address(t, candidateAddressType, false), Port: 31337}); err != want {
t.Fatalf("got ep.Connect(..) = %s, want = %s", err, want)
@@ -4785,12 +4826,13 @@ func TestPathMTUDiscovery(t *testing.T) {
// Send 3200 bytes of data.
const writeSize = 3200
- data := buffer.NewView(writeSize)
+ data := make([]byte, writeSize)
for i := range data {
data[i] = byte(i)
}
-
- if _, err := c.EP.Write(tcpip.SlicePayload(data), tcpip.WriteOptions{}); err != nil {
+ var r bytes.Reader
+ r.Reset(data)
+ if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil {
t.Fatalf("Write failed: %s", err)
}
@@ -4878,11 +4920,11 @@ func TestTCPEndpointProbe(t *testing.T) {
func TestStackSetCongestionControl(t *testing.T) {
testCases := []struct {
cc tcpip.CongestionControlOption
- err *tcpip.Error
+ err tcpip.Error
}{
{"reno", nil},
{"cubic", nil},
- {"blahblah", tcpip.ErrNoSuchFile},
+ {"blahblah", &tcpip.ErrNoSuchFile{}},
}
for _, tc := range testCases {
@@ -4964,11 +5006,11 @@ func TestStackSetAvailableCongestionControl(t *testing.T) {
func TestEndpointSetCongestionControl(t *testing.T) {
testCases := []struct {
cc tcpip.CongestionControlOption
- err *tcpip.Error
+ err tcpip.Error
}{
{"reno", nil},
{"cubic", nil},
- {"blahblah", tcpip.ErrNoSuchFile},
+ {"blahblah", &tcpip.ErrNoSuchFile{}},
}
for _, connected := range []bool{false, true} {
@@ -4978,7 +5020,7 @@ func TestEndpointSetCongestionControl(t *testing.T) {
defer c.Cleanup()
// Create TCP endpoint.
- var err *tcpip.Error
+ var err tcpip.Error
c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ)
if err != nil {
t.Fatalf("NewEndpoint failed: %s", err)
@@ -5074,12 +5116,14 @@ func TestKeepalive(t *testing.T) {
// Check that the connection is still alive.
ept := endpointTester{c.EP}
- ept.CheckReadError(t, tcpip.ErrWouldBlock)
+ ept.CheckReadError(t, &tcpip.ErrWouldBlock{})
// Send some data and wait before ACKing it. Keepalives should be disabled
// during this period.
- view := buffer.NewView(3)
- if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
+ view := make([]byte, 3)
+ var r bytes.Reader
+ r.Reset(view)
+ if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil {
t.Fatalf("Write failed: %s", err)
}
@@ -5163,7 +5207,7 @@ func TestKeepalive(t *testing.T) {
t.Errorf("got c.Stack().Stats().TCP.EstablishedTimedout.Value() = %d, want = 1", got)
}
- ept.CheckReadError(t, tcpip.ErrTimeout)
+ ept.CheckReadError(t, &tcpip.ErrTimeout{})
if got := c.Stack().Stats().TCP.CurrentEstablished.Value(); got != 0 {
t.Errorf("got stats.TCP.CurrentEstablished.Value() = %d, want = 0", got)
@@ -5270,7 +5314,7 @@ func TestListenBacklogFull(t *testing.T) {
defer c.Cleanup()
// Create TCP endpoint.
- var err *tcpip.Error
+ var err tcpip.Error
c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ)
if err != nil {
t.Fatalf("NewEndpoint failed: %s", err)
@@ -5313,7 +5357,7 @@ func TestListenBacklogFull(t *testing.T) {
for i := 0; i < listenBacklog; i++ {
_, _, err = c.EP.Accept(nil)
- if err == tcpip.ErrWouldBlock {
+ if _, ok := err.(*tcpip.ErrWouldBlock); ok {
// Wait for connection to be established.
select {
case <-ch:
@@ -5330,7 +5374,7 @@ func TestListenBacklogFull(t *testing.T) {
// Now verify that there are no more connections that can be accepted.
_, _, err = c.EP.Accept(nil)
- if err != tcpip.ErrWouldBlock {
+ if _, ok := err.(*tcpip.ErrWouldBlock); !ok {
select {
case <-ch:
t.Fatalf("unexpected endpoint delivered on Accept: %+v", c.EP)
@@ -5342,7 +5386,7 @@ func TestListenBacklogFull(t *testing.T) {
executeHandshake(t, c, context.TestPort+lastPortOffset, false /*synCookieInUse */)
newEP, _, err := c.EP.Accept(nil)
- if err == tcpip.ErrWouldBlock {
+ if _, ok := err.(*tcpip.ErrWouldBlock); ok {
// Wait for connection to be established.
select {
case <-ch:
@@ -5358,7 +5402,9 @@ func TestListenBacklogFull(t *testing.T) {
// Now verify that the TCP socket is usable and in a connected state.
data := "Don't panic"
- newEP.Write(tcpip.SlicePayload(buffer.NewViewFromBytes([]byte(data))), tcpip.WriteOptions{})
+ var r strings.Reader
+ r.Reset(data)
+ newEP.Write(&r, tcpip.WriteOptions{})
b := c.GetPacket()
tcp := header.TCP(header.IPv4(b).Payload())
if string(tcp.Payload()) != data {
@@ -5583,7 +5629,7 @@ func TestListenSynRcvdQueueFull(t *testing.T) {
defer c.Cleanup()
// Create TCP endpoint.
- var err *tcpip.Error
+ var err tcpip.Error
c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ)
if err != nil {
t.Fatalf("NewEndpoint failed: %s", err)
@@ -5658,7 +5704,7 @@ func TestListenSynRcvdQueueFull(t *testing.T) {
defer c.WQ.EventUnregister(&we)
newEP, _, err := c.EP.Accept(nil)
- if err == tcpip.ErrWouldBlock {
+ if _, ok := err.(*tcpip.ErrWouldBlock); ok {
// Wait for connection to be established.
select {
case <-ch:
@@ -5674,7 +5720,9 @@ func TestListenSynRcvdQueueFull(t *testing.T) {
// Now verify that the TCP socket is usable and in a connected state.
data := "Don't panic"
- newEP.Write(tcpip.SlicePayload(buffer.NewViewFromBytes([]byte(data))), tcpip.WriteOptions{})
+ var r strings.Reader
+ r.Reset(data)
+ newEP.Write(&r, tcpip.WriteOptions{})
pkt := c.GetPacket()
tcp = header.TCP(header.IPv4(pkt).Payload())
if string(tcp.Payload()) != data {
@@ -5692,7 +5740,7 @@ func TestListenBacklogFullSynCookieInUse(t *testing.T) {
}
// Create TCP endpoint.
- var err *tcpip.Error
+ var err tcpip.Error
c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ)
if err != nil {
t.Fatalf("NewEndpoint failed: %s", err)
@@ -5733,7 +5781,7 @@ func TestListenBacklogFullSynCookieInUse(t *testing.T) {
defer c.WQ.EventUnregister(&we)
_, _, err = c.EP.Accept(nil)
- if err == tcpip.ErrWouldBlock {
+ if _, ok := err.(*tcpip.ErrWouldBlock); ok {
// Wait for connection to be established.
select {
case <-ch:
@@ -5749,7 +5797,7 @@ func TestListenBacklogFullSynCookieInUse(t *testing.T) {
// Now verify that there are no more connections that can be accepted.
_, _, err = c.EP.Accept(nil)
- if err != tcpip.ErrWouldBlock {
+ if _, ok := err.(*tcpip.ErrWouldBlock); !ok {
select {
case <-ch:
t.Fatalf("unexpected endpoint delivered on Accept: %+v", c.EP)
@@ -5763,7 +5811,7 @@ func TestSYNRetransmit(t *testing.T) {
defer c.Cleanup()
// Create TCP endpoint.
- var err *tcpip.Error
+ var err tcpip.Error
c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ)
if err != nil {
t.Fatalf("NewEndpoint failed: %s", err)
@@ -5807,7 +5855,7 @@ func TestSynRcvdBadSeqNumber(t *testing.T) {
defer c.Cleanup()
// Create TCP endpoint.
- var err *tcpip.Error
+ var err tcpip.Error
c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ)
if err != nil {
t.Fatalf("NewEndpoint failed: %s", err)
@@ -5882,12 +5930,13 @@ func TestSynRcvdBadSeqNumber(t *testing.T) {
})
newEP, _, err := c.EP.Accept(nil)
-
- if err != nil && err != tcpip.ErrWouldBlock {
+ switch err.(type) {
+ case nil, *tcpip.ErrWouldBlock:
+ default:
t.Fatalf("Accept failed: %s", err)
}
- if err == tcpip.ErrWouldBlock {
+ if _, ok := err.(*tcpip.ErrWouldBlock); ok {
// Try to accept the connections in the backlog.
we, ch := waiter.NewChannelEntry(nil)
c.WQ.EventRegister(&we, waiter.EventIn)
@@ -5908,7 +5957,9 @@ func TestSynRcvdBadSeqNumber(t *testing.T) {
// Now verify that the TCP socket is usable and in a connected state.
data := "Don't panic"
- if _, err := newEP.Write(tcpip.SlicePayload(buffer.NewViewFromBytes([]byte(data))), tcpip.WriteOptions{}); err != nil {
+ var r strings.Reader
+ r.Reset(data)
+ if _, err := newEP.Write(&r, tcpip.WriteOptions{}); err != nil {
t.Fatalf("Write failed: %s", err)
}
@@ -5953,7 +6004,7 @@ func TestPassiveConnectionAttemptIncrement(t *testing.T) {
// Verify that there is only one acceptable connection at this point.
_, _, err = c.EP.Accept(nil)
- if err == tcpip.ErrWouldBlock {
+ if _, ok := err.(*tcpip.ErrWouldBlock); ok {
// Wait for connection to be established.
select {
case <-ch:
@@ -6023,7 +6074,7 @@ func TestPassiveFailedConnectionAttemptIncrement(t *testing.T) {
// Now check that there is one acceptable connections.
_, _, err = c.EP.Accept(nil)
- if err == tcpip.ErrWouldBlock {
+ if _, ok := err.(*tcpip.ErrWouldBlock); ok {
// Wait for connection to be established.
select {
case <-ch:
@@ -6055,7 +6106,7 @@ func TestEndpointBindListenAcceptState(t *testing.T) {
}
ept := endpointTester{ep}
- ept.CheckReadError(t, tcpip.ErrNotConnected)
+ ept.CheckReadError(t, &tcpip.ErrNotConnected{})
if got := ep.Stats().(*tcp.Stats).ReadErrors.NotConnected.Value(); got != 1 {
t.Errorf("got EP stats Stats.ReadErrors.NotConnected got %d want %d", got, 1)
}
@@ -6075,7 +6126,7 @@ func TestEndpointBindListenAcceptState(t *testing.T) {
defer wq.EventUnregister(&we)
aep, _, err := ep.Accept(nil)
- if err == tcpip.ErrWouldBlock {
+ if _, ok := err.(*tcpip.ErrWouldBlock); ok {
// Wait for connection to be established.
select {
case <-ch:
@@ -6091,8 +6142,11 @@ func TestEndpointBindListenAcceptState(t *testing.T) {
if got, want := tcp.EndpointState(aep.State()), tcp.StateEstablished; got != want {
t.Errorf("unexpected endpoint state: want %s, got %s", want, got)
}
- if err := aep.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}); err != tcpip.ErrAlreadyConnected {
- t.Errorf("unexpected error attempting to call connect on an established endpoint, got: %s, want: %s", err, tcpip.ErrAlreadyConnected)
+ {
+ err := aep.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort})
+ if _, ok := err.(*tcpip.ErrAlreadyConnected); !ok {
+ t.Errorf("unexpected error attempting to call connect on an established endpoint, got: %v, want: %s", err, &tcpip.ErrAlreadyConnected{})
+ }
}
// Listening endpoint remains in listen state.
if got, want := tcp.EndpointState(ep.State()), tcp.StateListen; got != want {
@@ -6211,7 +6265,7 @@ func TestReceiveBufferAutoTuningApplicationLimited(t *testing.T) {
// window increases to the full available buffer size.
for {
_, err := c.EP.Read(ioutil.Discard, tcpip.ReadOptions{})
- if err == tcpip.ErrWouldBlock {
+ if _, ok := err.(*tcpip.ErrWouldBlock); ok {
break
}
}
@@ -6335,7 +6389,7 @@ func TestReceiveBufferAutoTuning(t *testing.T) {
totalCopied := 0
for {
res, err := c.EP.Read(ioutil.Discard, tcpip.ReadOptions{})
- if err == tcpip.ErrWouldBlock {
+ if _, ok := err.(*tcpip.ErrWouldBlock); ok {
break
}
totalCopied += res.Count
@@ -6527,7 +6581,7 @@ func TestTCPTimeWaitRSTIgnored(t *testing.T) {
defer wq.EventUnregister(&we)
c.EP, _, err = ep.Accept(nil)
- if err == tcpip.ErrWouldBlock {
+ if _, ok := err.(*tcpip.ErrWouldBlock); ok {
// Wait for connection to be established.
select {
case <-ch:
@@ -6646,7 +6700,7 @@ func TestTCPTimeWaitOutOfOrder(t *testing.T) {
defer wq.EventUnregister(&we)
c.EP, _, err = ep.Accept(nil)
- if err == tcpip.ErrWouldBlock {
+ if _, ok := err.(*tcpip.ErrWouldBlock); ok {
// Wait for connection to be established.
select {
case <-ch:
@@ -6753,7 +6807,7 @@ func TestTCPTimeWaitNewSyn(t *testing.T) {
defer wq.EventUnregister(&we)
c.EP, _, err = ep.Accept(nil)
- if err == tcpip.ErrWouldBlock {
+ if _, ok := err.(*tcpip.ErrWouldBlock); ok {
// Wait for connection to be established.
select {
case <-ch:
@@ -6843,7 +6897,7 @@ func TestTCPTimeWaitNewSyn(t *testing.T) {
// Try to accept the connection.
c.EP, _, err = ep.Accept(nil)
- if err == tcpip.ErrWouldBlock {
+ if _, ok := err.(*tcpip.ErrWouldBlock); ok {
// Wait for connection to be established.
select {
case <-ch:
@@ -6917,7 +6971,7 @@ func TestTCPTimeWaitDuplicateFINExtendsTimeWait(t *testing.T) {
defer wq.EventUnregister(&we)
c.EP, _, err = ep.Accept(nil)
- if err == tcpip.ErrWouldBlock {
+ if _, ok := err.(*tcpip.ErrWouldBlock); ok {
// Wait for connection to be established.
select {
case <-ch:
@@ -7067,7 +7121,7 @@ func TestTCPCloseWithData(t *testing.T) {
defer wq.EventUnregister(&we)
c.EP, _, err = ep.Accept(nil)
- if err == tcpip.ErrWouldBlock {
+ if _, ok := err.(*tcpip.ErrWouldBlock); ok {
// Wait for connection to be established.
select {
case <-ch:
@@ -7103,10 +7157,10 @@ func TestTCPCloseWithData(t *testing.T) {
// Now write a few bytes and then close the endpoint.
data := []byte{1, 2, 3}
- view := buffer.NewView(len(data))
- copy(view, data)
- if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
+ var r bytes.Reader
+ r.Reset(data)
+ if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil {
t.Fatalf("Write failed: %s", err)
}
@@ -7204,8 +7258,10 @@ func TestTCPUserTimeout(t *testing.T) {
}
// Send some data and wait before ACKing it.
- view := buffer.NewView(3)
- if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
+ view := make([]byte, 3)
+ var r bytes.Reader
+ r.Reset(view)
+ if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil {
t.Fatalf("Write failed: %s", err)
}
@@ -7256,7 +7312,7 @@ func TestTCPUserTimeout(t *testing.T) {
)
ept := endpointTester{c.EP}
- ept.CheckReadError(t, tcpip.ErrTimeout)
+ ept.CheckReadError(t, &tcpip.ErrTimeout{})
if got, want := c.Stack().Stats().TCP.EstablishedTimedout.Value(), origEstablishedTimedout+1; got != want {
t.Errorf("got c.Stack().Stats().TCP.EstablishedTimedout = %d, want = %d", got, want)
@@ -7300,7 +7356,7 @@ func TestKeepaliveWithUserTimeout(t *testing.T) {
// Check that the connection is still alive.
ept := endpointTester{c.EP}
- ept.CheckReadError(t, tcpip.ErrWouldBlock)
+ ept.CheckReadError(t, &tcpip.ErrWouldBlock{})
// Now receive 1 keepalives, but don't ACK it.
b := c.GetPacket()
@@ -7339,7 +7395,7 @@ func TestKeepaliveWithUserTimeout(t *testing.T) {
),
)
- ept.CheckReadError(t, tcpip.ErrTimeout)
+ ept.CheckReadError(t, &tcpip.ErrTimeout{})
if got, want := c.Stack().Stats().TCP.EstablishedTimedout.Value(), origEstablishedTimedout+1; got != want {
t.Errorf("got c.Stack().Stats().TCP.EstablishedTimedout = %d, want = %d", got, want)
}
@@ -7494,8 +7550,9 @@ func TestTCPDeferAccept(t *testing.T) {
irs, iss := executeHandshake(t, c, context.TestPort, false /* synCookiesInUse */)
- if _, _, err := c.EP.Accept(nil); err != tcpip.ErrWouldBlock {
- t.Fatalf("got c.EP.Accept(nil) = %s, want: %s", err, tcpip.ErrWouldBlock)
+ _, _, err := c.EP.Accept(nil)
+ if _, ok := err.(*tcpip.ErrWouldBlock); !ok {
+ t.Fatalf("got c.EP.Accept(nil) = %v, want: %s", err, &tcpip.ErrWouldBlock{})
}
// Send data. This should result in an acceptable endpoint.
@@ -7552,8 +7609,9 @@ func TestTCPDeferAcceptTimeout(t *testing.T) {
irs, iss := executeHandshake(t, c, context.TestPort, false /* synCookiesInUse */)
- if _, _, err := c.EP.Accept(nil); err != tcpip.ErrWouldBlock {
- t.Fatalf("got c.EP.Accept(nil) = %s, want: %s", err, tcpip.ErrWouldBlock)
+ _, _, err := c.EP.Accept(nil)
+ if _, ok := err.(*tcpip.ErrWouldBlock); !ok {
+ t.Fatalf("got c.EP.Accept(nil) = %v, want: %s", err, &tcpip.ErrWouldBlock{})
}
// Sleep for a little of the tcpDeferAccept timeout.
@@ -7675,13 +7733,13 @@ func TestSetStackTimeWaitReuse(t *testing.T) {
s := c.Stack()
testCases := []struct {
v int
- err *tcpip.Error
+ err tcpip.Error
}{
{int(tcpip.TCPTimeWaitReuseDisabled), nil},
{int(tcpip.TCPTimeWaitReuseGlobal), nil},
{int(tcpip.TCPTimeWaitReuseLoopbackOnly), nil},
- {int(tcpip.TCPTimeWaitReuseLoopbackOnly) + 1, tcpip.ErrInvalidOptionValue},
- {int(tcpip.TCPTimeWaitReuseDisabled) - 1, tcpip.ErrInvalidOptionValue},
+ {int(tcpip.TCPTimeWaitReuseLoopbackOnly) + 1, &tcpip.ErrInvalidOptionValue{}},
+ {int(tcpip.TCPTimeWaitReuseDisabled) - 1, &tcpip.ErrInvalidOptionValue{}},
}
for _, tc := range testCases {
diff --git a/pkg/tcpip/transport/tcp/tcp_timestamp_test.go b/pkg/tcpip/transport/tcp/tcp_timestamp_test.go
index b65091c3c..5a9745ad7 100644
--- a/pkg/tcpip/transport/tcp/tcp_timestamp_test.go
+++ b/pkg/tcpip/transport/tcp/tcp_timestamp_test.go
@@ -22,7 +22,6 @@ import (
"github.com/google/go-cmp/cmp"
"gvisor.dev/gvisor/pkg/tcpip"
- "gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/checker"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
@@ -152,10 +151,10 @@ func timeStampEnabledAccept(t *testing.T, cookieEnabled bool, wndScale int, wndS
// Now send some data and validate that timestamp is echoed correctly in the ACK.
data := []byte{1, 2, 3}
- view := buffer.NewView(len(data))
- copy(view, data)
- if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
+ var r bytes.Reader
+ r.Reset(data)
+ if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil {
t.Fatalf("Unexpected error from Write: %s", err)
}
@@ -215,10 +214,10 @@ func timeStampDisabledAccept(t *testing.T, cookieEnabled bool, wndScale int, wnd
// Now send some data with the accepted connection endpoint and validate
// that no timestamp option is sent in the TCP segment.
data := []byte{1, 2, 3}
- view := buffer.NewView(len(data))
- copy(view, data)
- if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
+ var r bytes.Reader
+ r.Reset(data)
+ if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil {
t.Fatalf("Unexpected error from Write: %s", err)
}
diff --git a/pkg/tcpip/transport/tcp/testing/context/context.go b/pkg/tcpip/transport/tcp/testing/context/context.go
index ee55f030c..b1cb9a324 100644
--- a/pkg/tcpip/transport/tcp/testing/context/context.go
+++ b/pkg/tcpip/transport/tcp/testing/context/context.go
@@ -586,7 +586,7 @@ func (c *Context) ReceiveNonBlockingAndCheckPacket(data []byte, offset, size int
// is true then it sets the IP_V6ONLY option on the socket to make it a IPv6
// only endpoint instead of a default dual stack socket.
func (c *Context) CreateV6Endpoint(v6only bool) {
- var err *tcpip.Error
+ var err tcpip.Error
c.EP, err = c.s.NewEndpoint(tcp.ProtocolNumber, ipv6.ProtocolNumber, &c.WQ)
if err != nil {
c.t.Fatalf("NewEndpoint failed: %v", err)
@@ -689,7 +689,8 @@ func (c *Context) Connect(iss seqnum.Value, rcvWnd seqnum.Size, options []byte)
c.WQ.EventRegister(&waitEntry, waiter.EventOut)
defer c.WQ.EventUnregister(&waitEntry)
- if err := c.EP.Connect(tcpip.FullAddress{Addr: TestAddr, Port: TestPort}); err != tcpip.ErrConnectStarted {
+ err := c.EP.Connect(tcpip.FullAddress{Addr: TestAddr, Port: TestPort})
+ if _, ok := err.(*tcpip.ErrConnectStarted); !ok {
c.t.Fatalf("Unexpected return value from Connect: %v", err)
}
@@ -749,7 +750,7 @@ func (c *Context) Connect(iss seqnum.Value, rcvWnd seqnum.Size, options []byte)
// Create creates a TCP endpoint.
func (c *Context) Create(epRcvBuf int) {
// Create TCP endpoint.
- var err *tcpip.Error
+ var err tcpip.Error
c.EP, err = c.s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ)
if err != nil {
c.t.Fatalf("NewEndpoint failed: %v", err)
@@ -887,7 +888,7 @@ func (r *RawEndpoint) VerifyACKHasSACK(sackBlocks []header.SACKBlock) {
// It also verifies where required(eg.Timestamp) that the ACK to the SYN-ACK
// does not carry an option that was not requested.
func (c *Context) CreateConnectedWithOptions(wantOptions header.TCPSynOptions) *RawEndpoint {
- var err *tcpip.Error
+ var err tcpip.Error
c.EP, err = c.s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ)
if err != nil {
c.t.Fatalf("c.s.NewEndpoint(tcp, ipv4...) = %v", err)
@@ -903,7 +904,7 @@ func (c *Context) CreateConnectedWithOptions(wantOptions header.TCPSynOptions) *
testFullAddr := tcpip.FullAddress{Addr: TestAddr, Port: TestPort}
err = c.EP.Connect(testFullAddr)
- if err != tcpip.ErrConnectStarted {
+ if _, ok := err.(*tcpip.ErrConnectStarted); !ok {
c.t.Fatalf("c.ep.Connect(%v) = %v", testFullAddr, err)
}
// Receive SYN packet.
@@ -1054,7 +1055,7 @@ func (c *Context) AcceptWithOptions(wndScale int, synOptions header.TCPSynOption
defer wq.EventUnregister(&we)
c.EP, _, err = ep.Accept(nil)
- if err == tcpip.ErrWouldBlock {
+ if _, ok := err.(*tcpip.ErrWouldBlock); ok {
// Wait for connection to be established.
select {
case <-ch:
diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go
index 9f9b3d510..31a5ddce9 100644
--- a/pkg/tcpip/transport/udp/endpoint.go
+++ b/pkg/tcpip/transport/udp/endpoint.go
@@ -97,9 +97,7 @@ type endpoint struct {
rcvClosed bool
// The following fields are protected by the mu mutex.
- mu sync.RWMutex `state:"nosave"`
- sndBufSize int
- sndBufSizeMax int
+ mu sync.RWMutex `state:"nosave"`
// state must be read/set using the EndpointState()/setEndpointState()
// methods.
state EndpointState
@@ -111,8 +109,8 @@ type endpoint struct {
multicastNICID tcpip.NICID
portFlags ports.Flags
- lastErrorMu sync.Mutex `state:"nosave"`
- lastError *tcpip.Error `state:".(string)"`
+ lastErrorMu sync.Mutex `state:"nosave"`
+ lastError tcpip.Error
// Values used to reserve a port or register a transport endpoint.
// (which ever happens first).
@@ -176,18 +174,18 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQue
// Linux defaults to TTL=1.
multicastTTL: 1,
rcvBufSizeMax: 32 * 1024,
- sndBufSizeMax: 32 * 1024,
multicastMemberships: make(map[multicastMembership]struct{}),
state: StateInitial,
uniqueID: s.UniqueID(),
}
- e.ops.InitHandler(e)
+ e.ops.InitHandler(e, e.stack, tcpip.GetStackSendBufferLimits)
e.ops.SetMulticastLoop(true)
+ e.ops.SetSendBufferSize(32*1024, false /* notify */)
// Override with stack defaults.
- var ss stack.SendBufferSizeOption
+ var ss tcpip.SendBufferSizeOption
if err := s.Option(&ss); err == nil {
- e.sndBufSizeMax = ss.Default
+ e.ops.SetSendBufferSize(int64(ss.Default), false /* notify */)
}
var rs stack.ReceiveBufferSizeOption
@@ -217,7 +215,7 @@ func (e *endpoint) UniqueID() uint64 {
return e.uniqueID
}
-func (e *endpoint) LastError() *tcpip.Error {
+func (e *endpoint) LastError() tcpip.Error {
e.lastErrorMu.Lock()
defer e.lastErrorMu.Unlock()
@@ -227,7 +225,7 @@ func (e *endpoint) LastError() *tcpip.Error {
}
// UpdateLastError implements tcpip.SocketOptionsHandler.UpdateLastError.
-func (e *endpoint) UpdateLastError(err *tcpip.Error) {
+func (e *endpoint) UpdateLastError(err tcpip.Error) {
e.lastErrorMu.Lock()
e.lastError = err
e.lastErrorMu.Unlock()
@@ -246,7 +244,7 @@ func (e *endpoint) Close() {
switch e.EndpointState() {
case StateBound, StateConnected:
- e.stack.UnregisterTransportEndpoint(e.RegisterNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.boundPortFlags, e.boundBindToDevice)
+ e.stack.UnregisterTransportEndpoint(e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.boundPortFlags, e.boundBindToDevice)
e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.ID.LocalAddress, e.ID.LocalPort, e.boundPortFlags, e.boundBindToDevice, tcpip.FullAddress{})
e.boundBindToDevice = 0
e.boundPortFlags = ports.Flags{}
@@ -284,7 +282,7 @@ func (e *endpoint) Close() {
func (e *endpoint) ModerateRecvBuf(copied int) {}
// Read implements tcpip.Endpoint.Read.
-func (e *endpoint) Read(dst io.Writer, opts tcpip.ReadOptions) (tcpip.ReadResult, *tcpip.Error) {
+func (e *endpoint) Read(dst io.Writer, opts tcpip.ReadOptions) (tcpip.ReadResult, tcpip.Error) {
if err := e.LastError(); err != nil {
return tcpip.ReadResult{}, err
}
@@ -292,10 +290,10 @@ func (e *endpoint) Read(dst io.Writer, opts tcpip.ReadOptions) (tcpip.ReadResult
e.rcvMu.Lock()
if e.rcvList.Empty() {
- err := tcpip.ErrWouldBlock
+ var err tcpip.Error = &tcpip.ErrWouldBlock{}
if e.rcvClosed {
e.stats.ReadErrors.ReadClosed.Increment()
- err = tcpip.ErrClosedForReceive
+ err = &tcpip.ErrClosedForReceive{}
}
e.rcvMu.Unlock()
return tcpip.ReadResult{}, err
@@ -342,7 +340,7 @@ func (e *endpoint) Read(dst io.Writer, opts tcpip.ReadOptions) (tcpip.ReadResult
n, err := p.data.ReadTo(dst, opts.Peek)
if n == 0 && err != nil {
- return res, tcpip.ErrBadBuffer
+ return res, &tcpip.ErrBadBuffer{}
}
res.Count = n
return res, nil
@@ -353,7 +351,7 @@ func (e *endpoint) Read(dst io.Writer, opts tcpip.ReadOptions) (tcpip.ReadResult
// reacquire the mutex in exclusive mode.
//
// Returns true for retry if preparation should be retried.
-func (e *endpoint) prepareForWrite(to *tcpip.FullAddress) (retry bool, err *tcpip.Error) {
+func (e *endpoint) prepareForWrite(to *tcpip.FullAddress) (retry bool, err tcpip.Error) {
switch e.EndpointState() {
case StateInitial:
case StateConnected:
@@ -361,11 +359,11 @@ func (e *endpoint) prepareForWrite(to *tcpip.FullAddress) (retry bool, err *tcpi
case StateBound:
if to == nil {
- return false, tcpip.ErrDestinationRequired
+ return false, &tcpip.ErrDestinationRequired{}
}
return false, nil
default:
- return false, tcpip.ErrInvalidEndpointState
+ return false, &tcpip.ErrInvalidEndpointState{}
}
e.mu.RUnlock()
@@ -391,7 +389,7 @@ func (e *endpoint) prepareForWrite(to *tcpip.FullAddress) (retry bool, err *tcpi
// connectRoute establishes a route to the specified interface or the
// configured multicast interface if no interface is specified and the
// specified address is a multicast address.
-func (e *endpoint) connectRoute(nicID tcpip.NICID, addr tcpip.FullAddress, netProto tcpip.NetworkProtocolNumber) (*stack.Route, tcpip.NICID, *tcpip.Error) {
+func (e *endpoint) connectRoute(nicID tcpip.NICID, addr tcpip.FullAddress, netProto tcpip.NetworkProtocolNumber) (*stack.Route, tcpip.NICID, tcpip.Error) {
localAddr := e.ID.LocalAddress
if e.isBroadcastOrMulticast(nicID, netProto, localAddr) {
// A packet can only originate from a unicast address (i.e., an interface).
@@ -417,18 +415,18 @@ func (e *endpoint) connectRoute(nicID tcpip.NICID, addr tcpip.FullAddress, netPr
// Write writes data to the endpoint's peer. This method does not block
// if the data cannot be written.
-func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, *tcpip.Error) {
+func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcpip.Error) {
n, err := e.write(p, opts)
- switch err {
+ switch err.(type) {
case nil:
e.stats.PacketsSent.Increment()
- case tcpip.ErrMessageTooLong, tcpip.ErrInvalidOptionValue:
+ case *tcpip.ErrMessageTooLong, *tcpip.ErrInvalidOptionValue:
e.stats.WriteErrors.InvalidArgs.Increment()
- case tcpip.ErrClosedForSend:
+ case *tcpip.ErrClosedForSend:
e.stats.WriteErrors.WriteClosed.Increment()
- case tcpip.ErrInvalidEndpointState:
+ case *tcpip.ErrInvalidEndpointState:
e.stats.WriteErrors.InvalidEndpointState.Increment()
- case tcpip.ErrNoRoute, tcpip.ErrBroadcastDisabled, tcpip.ErrNetworkUnreachable:
+ case *tcpip.ErrNoRoute, *tcpip.ErrBroadcastDisabled, *tcpip.ErrNetworkUnreachable:
// Errors indicating any problem with IP routing of the packet.
e.stats.SendErrors.NoRoute.Increment()
default:
@@ -438,14 +436,14 @@ func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, *tc
return n, err
}
-func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, *tcpip.Error) {
+func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcpip.Error) {
if err := e.LastError(); err != nil {
return 0, err
}
// MSG_MORE is unimplemented. (This also means that MSG_EOR is a no-op.)
if opts.More {
- return 0, tcpip.ErrInvalidOptionValue
+ return 0, &tcpip.ErrInvalidOptionValue{}
}
to := opts.To
@@ -461,7 +459,7 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, *tc
// If we've shutdown with SHUT_WR we are in an invalid state for sending.
if e.shutdownFlags&tcpip.ShutdownWrite != 0 {
- return 0, tcpip.ErrClosedForSend
+ return 0, &tcpip.ErrClosedForSend{}
}
// Prepare for write.
@@ -482,9 +480,12 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, *tc
// Reject destination address if it goes through a different
// NIC than the endpoint was bound to.
nicID := to.NIC
+ if nicID == 0 {
+ nicID = tcpip.NICID(e.ops.GetBindToDevice())
+ }
if e.BindNICID != 0 {
if nicID != 0 && nicID != e.BindNICID {
- return 0, tcpip.ErrNoRoute
+ return 0, &tcpip.ErrNoRoute{}
}
nicID = e.BindNICID
@@ -492,7 +493,7 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, *tc
if to.Port == 0 {
// Port 0 is an invalid port to send to.
- return 0, tcpip.ErrInvalidEndpointState
+ return 0, &tcpip.ErrInvalidEndpointState{}
}
dst, netProto, err := e.checkV4MappedLocked(*to)
@@ -511,19 +512,19 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, *tc
}
if !e.ops.GetBroadcast() && route.IsOutboundBroadcast() {
- return 0, tcpip.ErrBroadcastDisabled
+ return 0, &tcpip.ErrBroadcastDisabled{}
}
- v, err := p.FullPayload()
- if err != nil {
- return 0, err
+ v := make([]byte, p.Len())
+ if _, err := io.ReadFull(p, v); err != nil {
+ return 0, &tcpip.ErrBadBuffer{}
}
if len(v) > header.UDPMaximumPacketSize {
// Payload can't possibly fit in a packet.
so := e.SocketOptions()
if so.GetRecvError() {
so.QueueLocalErr(
- tcpip.ErrMessageTooLong,
+ &tcpip.ErrMessageTooLong{},
route.NetProto,
header.UDPMaximumPacketSize,
tcpip.FullAddress{
@@ -534,7 +535,7 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, *tc
v,
)
}
- return 0, tcpip.ErrMessageTooLong
+ return 0, &tcpip.ErrMessageTooLong{}
}
ttl := e.ttl
@@ -584,13 +585,13 @@ func (e *endpoint) OnReusePortSet(v bool) {
}
// SetSockOptInt implements tcpip.Endpoint.SetSockOptInt.
-func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error {
+func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) tcpip.Error {
switch opt {
case tcpip.MTUDiscoverOption:
// Return not supported if the value is not disabling path
// MTU discovery.
if v != tcpip.PMTUDiscoveryDont {
- return tcpip.ErrNotSupported
+ return &tcpip.ErrNotSupported{}
}
case tcpip.MulticastTTLOption:
@@ -632,25 +633,6 @@ func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error {
e.rcvBufSizeMax = v
e.mu.Unlock()
return nil
- case tcpip.SendBufferSizeOption:
- // Make sure the send buffer size is within the min and max
- // allowed.
- var ss stack.SendBufferSizeOption
- if err := e.stack.Option(&ss); err != nil {
- panic(fmt.Sprintf("e.stack.Option(%#v) = %s", ss, err))
- }
-
- if v < ss.Min {
- v = ss.Min
- }
- if v > ss.Max {
- v = ss.Max
- }
-
- e.mu.Lock()
- e.sndBufSizeMax = v
- e.mu.Unlock()
- return nil
}
return nil
@@ -661,7 +643,7 @@ func (e *endpoint) HasNIC(id int32) bool {
}
// SetSockOpt implements tcpip.Endpoint.SetSockOpt.
-func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) *tcpip.Error {
+func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) tcpip.Error {
switch v := opt.(type) {
case *tcpip.MulticastInterfaceOption:
e.mu.Lock()
@@ -683,17 +665,17 @@ func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) *tcpip.Error {
if nic != 0 {
if !e.stack.CheckNIC(nic) {
- return tcpip.ErrBadLocalAddress
+ return &tcpip.ErrBadLocalAddress{}
}
} else {
nic = e.stack.CheckLocalAddress(0, netProto, addr)
if nic == 0 {
- return tcpip.ErrBadLocalAddress
+ return &tcpip.ErrBadLocalAddress{}
}
}
if e.BindNICID != 0 && e.BindNICID != nic {
- return tcpip.ErrInvalidEndpointState
+ return &tcpip.ErrInvalidEndpointState{}
}
e.multicastNICID = nic
@@ -701,7 +683,7 @@ func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) *tcpip.Error {
case *tcpip.AddMembershipOption:
if !header.IsV4MulticastAddress(v.MulticastAddr) && !header.IsV6MulticastAddress(v.MulticastAddr) {
- return tcpip.ErrInvalidOptionValue
+ return &tcpip.ErrInvalidOptionValue{}
}
nicID := v.NIC
@@ -717,7 +699,7 @@ func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) *tcpip.Error {
nicID = e.stack.CheckLocalAddress(nicID, e.NetProto, v.InterfaceAddr)
}
if nicID == 0 {
- return tcpip.ErrUnknownDevice
+ return &tcpip.ErrUnknownDevice{}
}
memToInsert := multicastMembership{nicID: nicID, multicastAddr: v.MulticastAddr}
@@ -726,7 +708,7 @@ func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) *tcpip.Error {
defer e.mu.Unlock()
if _, ok := e.multicastMemberships[memToInsert]; ok {
- return tcpip.ErrPortInUse
+ return &tcpip.ErrPortInUse{}
}
if err := e.stack.JoinGroup(e.NetProto, nicID, v.MulticastAddr); err != nil {
@@ -737,7 +719,7 @@ func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) *tcpip.Error {
case *tcpip.RemoveMembershipOption:
if !header.IsV4MulticastAddress(v.MulticastAddr) && !header.IsV6MulticastAddress(v.MulticastAddr) {
- return tcpip.ErrInvalidOptionValue
+ return &tcpip.ErrInvalidOptionValue{}
}
nicID := v.NIC
@@ -752,7 +734,7 @@ func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) *tcpip.Error {
nicID = e.stack.CheckLocalAddress(nicID, e.NetProto, v.InterfaceAddr)
}
if nicID == 0 {
- return tcpip.ErrUnknownDevice
+ return &tcpip.ErrUnknownDevice{}
}
memToRemove := multicastMembership{nicID: nicID, multicastAddr: v.MulticastAddr}
@@ -761,7 +743,7 @@ func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) *tcpip.Error {
defer e.mu.Unlock()
if _, ok := e.multicastMemberships[memToRemove]; !ok {
- return tcpip.ErrBadLocalAddress
+ return &tcpip.ErrBadLocalAddress{}
}
if err := e.stack.LeaveGroup(e.NetProto, nicID, v.MulticastAddr); err != nil {
@@ -777,7 +759,7 @@ func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) *tcpip.Error {
}
// GetSockOptInt implements tcpip.Endpoint.GetSockOptInt.
-func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) {
+func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, tcpip.Error) {
switch opt {
case tcpip.IPv4TOSOption:
e.mu.RLock()
@@ -811,12 +793,6 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) {
e.rcvMu.Unlock()
return v, nil
- case tcpip.SendBufferSizeOption:
- e.mu.Lock()
- v := e.sndBufSizeMax
- e.mu.Unlock()
- return v, nil
-
case tcpip.ReceiveBufferSizeOption:
e.rcvMu.Lock()
v := e.rcvBufSizeMax
@@ -830,12 +806,12 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) {
return v, nil
default:
- return -1, tcpip.ErrUnknownProtocolOption
+ return -1, &tcpip.ErrUnknownProtocolOption{}
}
}
// GetSockOpt implements tcpip.Endpoint.GetSockOpt.
-func (e *endpoint) GetSockOpt(opt tcpip.GettableSocketOption) *tcpip.Error {
+func (e *endpoint) GetSockOpt(opt tcpip.GettableSocketOption) tcpip.Error {
switch o := opt.(type) {
case *tcpip.MulticastInterfaceOption:
e.mu.Lock()
@@ -846,14 +822,14 @@ func (e *endpoint) GetSockOpt(opt tcpip.GettableSocketOption) *tcpip.Error {
e.mu.Unlock()
default:
- return tcpip.ErrUnknownProtocolOption
+ return &tcpip.ErrUnknownProtocolOption{}
}
return nil
}
// sendUDP sends a UDP segment via the provided network endpoint and under the
// provided identity.
-func sendUDP(r *stack.Route, data buffer.VectorisedView, localPort, remotePort uint16, ttl uint8, useDefaultTTL bool, tos uint8, owner tcpip.PacketOwner, noChecksum bool) *tcpip.Error {
+func sendUDP(r *stack.Route, data buffer.VectorisedView, localPort, remotePort uint16, ttl uint8, useDefaultTTL bool, tos uint8, owner tcpip.PacketOwner, noChecksum bool) tcpip.Error {
pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
ReserveHeaderBytes: header.UDPMinimumSize + int(r.MaxHeaderLength()),
Data: data,
@@ -903,7 +879,7 @@ func sendUDP(r *stack.Route, data buffer.VectorisedView, localPort, remotePort u
// checkV4MappedLocked determines the effective network protocol and converts
// addr to its canonical form.
-func (e *endpoint) checkV4MappedLocked(addr tcpip.FullAddress) (tcpip.FullAddress, tcpip.NetworkProtocolNumber, *tcpip.Error) {
+func (e *endpoint) checkV4MappedLocked(addr tcpip.FullAddress) (tcpip.FullAddress, tcpip.NetworkProtocolNumber, tcpip.Error) {
unwrapped, netProto, err := e.TransportEndpointInfo.AddrNetProtoLocked(addr, e.ops.GetV6Only())
if err != nil {
return tcpip.FullAddress{}, 0, err
@@ -912,7 +888,7 @@ func (e *endpoint) checkV4MappedLocked(addr tcpip.FullAddress) (tcpip.FullAddres
}
// Disconnect implements tcpip.Endpoint.Disconnect.
-func (e *endpoint) Disconnect() *tcpip.Error {
+func (e *endpoint) Disconnect() tcpip.Error {
e.mu.Lock()
defer e.mu.Unlock()
@@ -930,12 +906,12 @@ func (e *endpoint) Disconnect() *tcpip.Error {
// Exclude ephemerally bound endpoints.
if e.BindNICID != 0 || e.ID.LocalAddress == "" {
- var err *tcpip.Error
+ var err tcpip.Error
id = stack.TransportEndpointID{
LocalPort: e.ID.LocalPort,
LocalAddress: e.ID.LocalAddress,
}
- id, btd, err = e.registerWithStack(e.RegisterNICID, e.effectiveNetProtos, id)
+ id, btd, err = e.registerWithStack(e.effectiveNetProtos, id)
if err != nil {
return err
}
@@ -950,7 +926,7 @@ func (e *endpoint) Disconnect() *tcpip.Error {
e.setEndpointState(StateInitial)
}
- e.stack.UnregisterTransportEndpoint(e.RegisterNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, boundPortFlags, e.boundBindToDevice)
+ e.stack.UnregisterTransportEndpoint(e.effectiveNetProtos, ProtocolNumber, e.ID, e, boundPortFlags, e.boundBindToDevice)
e.ID = id
e.boundBindToDevice = btd
e.route.Release()
@@ -961,10 +937,10 @@ func (e *endpoint) Disconnect() *tcpip.Error {
}
// Connect connects the endpoint to its peer. Specifying a NIC is optional.
-func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
+func (e *endpoint) Connect(addr tcpip.FullAddress) tcpip.Error {
if addr.Port == 0 {
// We don't support connecting to port zero.
- return tcpip.ErrInvalidEndpointState
+ return &tcpip.ErrInvalidEndpointState{}
}
e.mu.Lock()
@@ -981,12 +957,12 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
}
if nicID != 0 && nicID != e.BindNICID {
- return tcpip.ErrInvalidEndpointState
+ return &tcpip.ErrInvalidEndpointState{}
}
nicID = e.BindNICID
default:
- return tcpip.ErrInvalidEndpointState
+ return &tcpip.ErrInvalidEndpointState{}
}
addr, netProto, err := e.checkV4MappedLocked(addr)
@@ -1023,7 +999,7 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
oldPortFlags := e.boundPortFlags
- id, btd, err := e.registerWithStack(nicID, netProtos, id)
+ id, btd, err := e.registerWithStack(netProtos, id)
if err != nil {
r.Release()
return err
@@ -1031,11 +1007,16 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
// Remove the old registration.
if e.ID.LocalPort != 0 {
- e.stack.UnregisterTransportEndpoint(e.RegisterNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, oldPortFlags, e.boundBindToDevice)
+ e.stack.UnregisterTransportEndpoint(e.effectiveNetProtos, ProtocolNumber, e.ID, e, oldPortFlags, e.boundBindToDevice)
}
e.ID = id
e.boundBindToDevice = btd
+ if e.route != nil {
+ // If the endpoint was already connected then make sure we release the
+ // previous route.
+ e.route.Release()
+ }
e.route = r
e.dstPort = addr.Port
e.RegisterNICID = nicID
@@ -1051,20 +1032,20 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
}
// ConnectEndpoint is not supported.
-func (*endpoint) ConnectEndpoint(tcpip.Endpoint) *tcpip.Error {
- return tcpip.ErrInvalidEndpointState
+func (*endpoint) ConnectEndpoint(tcpip.Endpoint) tcpip.Error {
+ return &tcpip.ErrInvalidEndpointState{}
}
// Shutdown closes the read and/or write end of the endpoint connection
// to its peer.
-func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error {
+func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) tcpip.Error {
e.mu.Lock()
defer e.mu.Unlock()
// A socket in the bound state can still receive multicast messages,
// so we need to notify waiters on shutdown.
if state := e.EndpointState(); state != StateBound && state != StateConnected {
- return tcpip.ErrNotConnected
+ return &tcpip.ErrNotConnected{}
}
e.shutdownFlags |= flags
@@ -1084,16 +1065,16 @@ func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error {
}
// Listen is not supported by UDP, it just fails.
-func (*endpoint) Listen(int) *tcpip.Error {
- return tcpip.ErrNotSupported
+func (*endpoint) Listen(int) tcpip.Error {
+ return &tcpip.ErrNotSupported{}
}
// Accept is not supported by UDP, it just fails.
-func (*endpoint) Accept(*tcpip.FullAddress) (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) {
- return nil, nil, tcpip.ErrNotSupported
+func (*endpoint) Accept(*tcpip.FullAddress) (tcpip.Endpoint, *waiter.Queue, tcpip.Error) {
+ return nil, nil, &tcpip.ErrNotSupported{}
}
-func (e *endpoint) registerWithStack(nicID tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, id stack.TransportEndpointID) (stack.TransportEndpointID, tcpip.NICID, *tcpip.Error) {
+func (e *endpoint) registerWithStack(netProtos []tcpip.NetworkProtocolNumber, id stack.TransportEndpointID) (stack.TransportEndpointID, tcpip.NICID, tcpip.Error) {
bindToDevice := tcpip.NICID(e.ops.GetBindToDevice())
if e.ID.LocalPort == 0 {
port, err := e.stack.ReservePort(netProtos, ProtocolNumber, id.LocalAddress, id.LocalPort, e.portFlags, bindToDevice, tcpip.FullAddress{}, nil /* testPort */)
@@ -1104,7 +1085,7 @@ func (e *endpoint) registerWithStack(nicID tcpip.NICID, netProtos []tcpip.Networ
}
e.boundPortFlags = e.portFlags
- err := e.stack.RegisterTransportEndpoint(nicID, netProtos, ProtocolNumber, id, e, e.boundPortFlags, bindToDevice)
+ err := e.stack.RegisterTransportEndpoint(netProtos, ProtocolNumber, id, e, e.boundPortFlags, bindToDevice)
if err != nil {
e.stack.ReleasePort(netProtos, ProtocolNumber, id.LocalAddress, id.LocalPort, e.boundPortFlags, bindToDevice, tcpip.FullAddress{})
e.boundPortFlags = ports.Flags{}
@@ -1112,11 +1093,11 @@ func (e *endpoint) registerWithStack(nicID tcpip.NICID, netProtos []tcpip.Networ
return id, bindToDevice, err
}
-func (e *endpoint) bindLocked(addr tcpip.FullAddress) *tcpip.Error {
+func (e *endpoint) bindLocked(addr tcpip.FullAddress) tcpip.Error {
// Don't allow binding once endpoint is not in the initial state
// anymore.
if e.EndpointState() != StateInitial {
- return tcpip.ErrInvalidEndpointState
+ return &tcpip.ErrInvalidEndpointState{}
}
addr, netProto, err := e.checkV4MappedLocked(addr)
@@ -1140,7 +1121,7 @@ func (e *endpoint) bindLocked(addr tcpip.FullAddress) *tcpip.Error {
// A local unicast address was specified, verify that it's valid.
nicID = e.stack.CheckLocalAddress(addr.NIC, netProto, addr.Addr)
if nicID == 0 {
- return tcpip.ErrBadLocalAddress
+ return &tcpip.ErrBadLocalAddress{}
}
}
@@ -1148,7 +1129,7 @@ func (e *endpoint) bindLocked(addr tcpip.FullAddress) *tcpip.Error {
LocalPort: addr.Port,
LocalAddress: addr.Addr,
}
- id, btd, err := e.registerWithStack(nicID, netProtos, id)
+ id, btd, err := e.registerWithStack(netProtos, id)
if err != nil {
return err
}
@@ -1170,7 +1151,7 @@ func (e *endpoint) bindLocked(addr tcpip.FullAddress) *tcpip.Error {
// Bind binds the endpoint to a specific local address and port.
// Specifying a NIC is optional.
-func (e *endpoint) Bind(addr tcpip.FullAddress) *tcpip.Error {
+func (e *endpoint) Bind(addr tcpip.FullAddress) tcpip.Error {
e.mu.Lock()
defer e.mu.Unlock()
@@ -1186,7 +1167,7 @@ func (e *endpoint) Bind(addr tcpip.FullAddress) *tcpip.Error {
}
// GetLocalAddress returns the address to which the endpoint is bound.
-func (e *endpoint) GetLocalAddress() (tcpip.FullAddress, *tcpip.Error) {
+func (e *endpoint) GetLocalAddress() (tcpip.FullAddress, tcpip.Error) {
e.mu.RLock()
defer e.mu.RUnlock()
@@ -1203,12 +1184,12 @@ func (e *endpoint) GetLocalAddress() (tcpip.FullAddress, *tcpip.Error) {
}
// GetRemoteAddress returns the address to which the endpoint is connected.
-func (e *endpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Error) {
+func (e *endpoint) GetRemoteAddress() (tcpip.FullAddress, tcpip.Error) {
e.mu.RLock()
defer e.mu.RUnlock()
if e.EndpointState() != StateConnected {
- return tcpip.FullAddress{}, tcpip.ErrNotConnected
+ return tcpip.FullAddress{}, &tcpip.ErrNotConnected{}
}
return tcpip.FullAddress{
@@ -1341,7 +1322,7 @@ func (e *endpoint) HandlePacket(id stack.TransportEndpointID, pkt *stack.PacketB
}
}
-func (e *endpoint) onICMPError(err *tcpip.Error, errType byte, errCode byte, extra uint32, pkt *stack.PacketBuffer) {
+func (e *endpoint) onICMPError(err tcpip.Error, errType byte, errCode byte, extra uint32, pkt *stack.PacketBuffer) {
// Update last error first.
e.lastErrorMu.Lock()
e.lastError = err
@@ -1397,7 +1378,7 @@ func (e *endpoint) HandleControlPacket(typ stack.ControlType, extra uint32, pkt
default:
panic(fmt.Sprintf("unsupported net proto for infering ICMP type and code: %d", pkt.NetworkProtocolNumber))
}
- e.onICMPError(tcpip.ErrConnectionRefused, errType, errCode, extra, pkt)
+ e.onICMPError(&tcpip.ErrConnectionRefused{}, errType, errCode, extra, pkt)
return
}
}
diff --git a/pkg/tcpip/transport/udp/endpoint_state.go b/pkg/tcpip/transport/udp/endpoint_state.go
index 13b72dc88..21a6aa460 100644
--- a/pkg/tcpip/transport/udp/endpoint_state.go
+++ b/pkg/tcpip/transport/udp/endpoint_state.go
@@ -37,24 +37,6 @@ func (u *udpPacket) loadData(data buffer.VectorisedView) {
u.data = data
}
-// saveLastError is invoked by stateify.
-func (e *endpoint) saveLastError() string {
- if e.lastError == nil {
- return ""
- }
-
- return e.lastError.String()
-}
-
-// loadLastError is invoked by stateify.
-func (e *endpoint) loadLastError(s string) {
- if s == "" {
- return
- }
-
- e.lastError = tcpip.StringToError(s)
-}
-
// beforeSave is invoked by stateify.
func (e *endpoint) beforeSave() {
// Stop incoming packets from being handled (and mutate endpoint state).
@@ -91,6 +73,7 @@ func (e *endpoint) Resume(s *stack.Stack) {
defer e.mu.Unlock()
e.stack = s
+ e.ops.InitHandler(e, e.stack, tcpip.GetStackSendBufferLimits)
for m := range e.multicastMemberships {
if err := e.stack.JoinGroup(e.NetProto, m.nicID, m.multicastAddr); err != nil {
@@ -113,7 +96,7 @@ func (e *endpoint) Resume(s *stack.Stack) {
netProto = header.IPv6ProtocolNumber
}
- var err *tcpip.Error
+ var err tcpip.Error
if state == StateConnected {
e.route, err = e.stack.FindRoute(e.RegisterNICID, e.ID.LocalAddress, e.ID.RemoteAddress, netProto, e.ops.GetMulticastLoop())
if err != nil {
@@ -122,7 +105,7 @@ func (e *endpoint) Resume(s *stack.Stack) {
} else if len(e.ID.LocalAddress) != 0 && !e.isBroadcastOrMulticast(e.RegisterNICID, netProto, e.ID.LocalAddress) { // stateBound
// A local unicast address is specified, verify that it's valid.
if e.stack.CheckLocalAddress(e.RegisterNICID, netProto, e.ID.LocalAddress) == 0 {
- panic(tcpip.ErrBadLocalAddress)
+ panic(&tcpip.ErrBadLocalAddress{})
}
}
@@ -131,7 +114,7 @@ func (e *endpoint) Resume(s *stack.Stack) {
// pass it to the reservation machinery.
id := e.ID
e.ID.LocalPort = 0
- e.ID, e.boundBindToDevice, err = e.registerWithStack(e.RegisterNICID, e.effectiveNetProtos, id)
+ e.ID, e.boundBindToDevice, err = e.registerWithStack(e.effectiveNetProtos, id)
if err != nil {
panic(err)
}
diff --git a/pkg/tcpip/transport/udp/forwarder.go b/pkg/tcpip/transport/udp/forwarder.go
index 49e673d58..705ad1f64 100644
--- a/pkg/tcpip/transport/udp/forwarder.go
+++ b/pkg/tcpip/transport/udp/forwarder.go
@@ -69,7 +69,7 @@ func (r *ForwarderRequest) ID() stack.TransportEndpointID {
}
// CreateEndpoint creates a connected UDP endpoint for the session request.
-func (r *ForwarderRequest) CreateEndpoint(queue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
+func (r *ForwarderRequest) CreateEndpoint(queue *waiter.Queue) (tcpip.Endpoint, tcpip.Error) {
netHdr := r.pkt.Network()
route, err := r.stack.FindRoute(r.pkt.NICID, netHdr.DestinationAddress(), netHdr.SourceAddress(), r.pkt.NetworkProtocolNumber, false /* multicastLoop */)
if err != nil {
@@ -77,7 +77,7 @@ func (r *ForwarderRequest) CreateEndpoint(queue *waiter.Queue) (tcpip.Endpoint,
}
ep := newEndpoint(r.stack, r.pkt.NetworkProtocolNumber, queue)
- if err := r.stack.RegisterTransportEndpoint(r.pkt.NICID, []tcpip.NetworkProtocolNumber{r.pkt.NetworkProtocolNumber}, ProtocolNumber, r.id, ep, ep.portFlags, tcpip.NICID(ep.ops.GetBindToDevice())); err != nil {
+ if err := r.stack.RegisterTransportEndpoint([]tcpip.NetworkProtocolNumber{r.pkt.NetworkProtocolNumber}, ProtocolNumber, r.id, ep, ep.portFlags, tcpip.NICID(ep.ops.GetBindToDevice())); err != nil {
ep.Close()
route.Release()
return nil, err
diff --git a/pkg/tcpip/transport/udp/protocol.go b/pkg/tcpip/transport/udp/protocol.go
index 91420edd3..427fdd0c9 100644
--- a/pkg/tcpip/transport/udp/protocol.go
+++ b/pkg/tcpip/transport/udp/protocol.go
@@ -54,13 +54,13 @@ func (*protocol) Number() tcpip.TransportProtocolNumber {
}
// NewEndpoint creates a new udp endpoint.
-func (p *protocol) NewEndpoint(netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
+func (p *protocol) NewEndpoint(netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, tcpip.Error) {
return newEndpoint(p.stack, netProto, waiterQueue), nil
}
// NewRawEndpoint creates a new raw UDP endpoint. It implements
// stack.TransportProtocol.NewRawEndpoint.
-func (p *protocol) NewRawEndpoint(netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
+func (p *protocol) NewRawEndpoint(netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, tcpip.Error) {
return raw.NewEndpoint(p.stack, netProto, header.UDPProtocolNumber, waiterQueue)
}
@@ -71,7 +71,7 @@ func (*protocol) MinimumPacketSize() int {
// ParsePorts returns the source and destination ports stored in the given udp
// packet.
-func (*protocol) ParsePorts(v buffer.View) (src, dst uint16, err *tcpip.Error) {
+func (*protocol) ParsePorts(v buffer.View) (src, dst uint16, err tcpip.Error) {
h := header.UDP(v)
return h.SourcePort(), h.DestinationPort(), nil
}
@@ -94,13 +94,13 @@ func (p *protocol) HandleUnknownDestinationPacket(id stack.TransportEndpointID,
}
// SetOption implements stack.TransportProtocol.SetOption.
-func (*protocol) SetOption(tcpip.SettableTransportProtocolOption) *tcpip.Error {
- return tcpip.ErrUnknownProtocolOption
+func (*protocol) SetOption(tcpip.SettableTransportProtocolOption) tcpip.Error {
+ return &tcpip.ErrUnknownProtocolOption{}
}
// Option implements stack.TransportProtocol.Option.
-func (*protocol) Option(tcpip.GettableTransportProtocolOption) *tcpip.Error {
- return tcpip.ErrUnknownProtocolOption
+func (*protocol) Option(tcpip.GettableTransportProtocolOption) tcpip.Error {
+ return &tcpip.ErrUnknownProtocolOption{}
}
// Close implements stack.TransportProtocol.Close.
diff --git a/pkg/tcpip/transport/udp/udp_test.go b/pkg/tcpip/transport/udp/udp_test.go
index 4e2123fe9..64c5298d3 100644
--- a/pkg/tcpip/transport/udp/udp_test.go
+++ b/pkg/tcpip/transport/udp/udp_test.go
@@ -353,7 +353,7 @@ func (c *testContext) cleanup() {
func (c *testContext) createEndpoint(proto tcpip.NetworkProtocolNumber) {
c.t.Helper()
- var err *tcpip.Error
+ var err tcpip.Error
c.ep, err = c.s.NewEndpoint(udp.ProtocolNumber, proto, &c.wq)
if err != nil {
c.t.Fatal("NewEndpoint failed: ", err)
@@ -555,11 +555,11 @@ func TestBindToDeviceOption(t *testing.T) {
testActions := []struct {
name string
setBindToDevice *tcpip.NICID
- setBindToDeviceError *tcpip.Error
+ setBindToDeviceError tcpip.Error
getBindToDevice int32
}{
{"GetDefaultValue", nil, nil, 0},
- {"BindToNonExistent", nicIDPtr(999), tcpip.ErrUnknownDevice, 0},
+ {"BindToNonExistent", nicIDPtr(999), &tcpip.ErrUnknownDevice{}, 0},
{"BindToExistent", nicIDPtr(321), nil, 321},
{"UnbindToDevice", nicIDPtr(0), nil, 0},
}
@@ -599,7 +599,7 @@ func testReadInternal(c *testContext, flow testFlow, packetShouldBeDropped, expe
var buf bytes.Buffer
res, err := c.ep.Read(&buf, tcpip.ReadOptions{NeedRemoteAddr: true})
- if err == tcpip.ErrWouldBlock {
+ if _, ok := err.(*tcpip.ErrWouldBlock); ok {
// Wait for data to become available.
select {
case <-ch:
@@ -703,8 +703,11 @@ func TestBindReservedPort(t *testing.T) {
t.Fatalf("NewEndpoint failed: %s", err)
}
defer ep.Close()
- if got, want := ep.Bind(addr), tcpip.ErrPortInUse; got != want {
- t.Fatalf("got ep.Bind(...) = %s, want = %s", got, want)
+ {
+ err := ep.Bind(addr)
+ if _, ok := err.(*tcpip.ErrPortInUse); !ok {
+ t.Fatalf("got ep.Bind(...) = %s, want = %s", err, &tcpip.ErrPortInUse{})
+ }
}
}
@@ -716,8 +719,11 @@ func TestBindReservedPort(t *testing.T) {
defer ep.Close()
// We can't bind ipv4-any on the port reserved by the connected endpoint
// above, since the endpoint is dual-stack.
- if got, want := ep.Bind(tcpip.FullAddress{Port: addr.Port}), tcpip.ErrPortInUse; got != want {
- t.Fatalf("got ep.Bind(...) = %s, want = %s", got, want)
+ {
+ err := ep.Bind(tcpip.FullAddress{Port: addr.Port})
+ if _, ok := err.(*tcpip.ErrPortInUse); !ok {
+ t.Fatalf("got ep.Bind(...) = %s, want = %s", err, &tcpip.ErrPortInUse{})
+ }
}
// We can bind an ipv4 address on this port, though.
if err := ep.Bind(tcpip.FullAddress{Addr: stackAddr, Port: addr.Port}); err != nil {
@@ -806,11 +812,11 @@ func TestV4ReadSelfSource(t *testing.T) {
for _, tt := range []struct {
name string
handleLocal bool
- wantErr *tcpip.Error
+ wantErr tcpip.Error
wantInvalidSource uint64
}{
{"HandleLocal", false, nil, 0},
- {"NoHandleLocal", true, tcpip.ErrWouldBlock, 1},
+ {"NoHandleLocal", true, &tcpip.ErrWouldBlock{}, 1},
} {
t.Run(tt.name, func(t *testing.T) {
c := newDualTestContextWithOptions(t, defaultMTU, stack.Options{
@@ -959,15 +965,16 @@ func TestV4ReadBroadcastOnBoundToWildcard(t *testing.T) {
// testFailingWrite sends a packet of the given test flow into the UDP endpoint
// and verifies it fails with the provided error code.
-func testFailingWrite(c *testContext, flow testFlow, wantErr *tcpip.Error) {
+func testFailingWrite(c *testContext, flow testFlow, wantErr tcpip.Error) {
c.t.Helper()
// Take a snapshot of the stats to validate them at the end of the test.
epstats := c.ep.Stats().(*tcpip.TransportEndpointStats).Clone()
h := flow.header4Tuple(outgoing)
writeDstAddr := flow.mapAddrIfApplicable(h.dstAddr.Addr)
- payload := buffer.View(newPayload())
- _, gotErr := c.ep.Write(tcpip.SlicePayload(payload), tcpip.WriteOptions{
+ var r bytes.Reader
+ r.Reset(newPayload())
+ _, gotErr := c.ep.Write(&r, tcpip.WriteOptions{
To: &tcpip.FullAddress{Addr: writeDstAddr, Port: h.dstAddr.Port},
})
c.checkEndpointWriteStats(1, epstats, gotErr)
@@ -1007,8 +1014,10 @@ func testWriteNoVerify(c *testContext, flow testFlow, setDest bool) buffer.View
To: &tcpip.FullAddress{Addr: writeDstAddr, Port: h.dstAddr.Port},
}
}
- payload := buffer.View(newPayload())
- n, err := c.ep.Write(tcpip.SlicePayload(payload), writeOpts)
+ var r bytes.Reader
+ payload := newPayload()
+ r.Reset(payload)
+ n, err := c.ep.Write(&r, writeOpts)
if err != nil {
c.t.Fatalf("Write failed: %s", err)
}
@@ -1089,7 +1098,7 @@ func TestDualWriteConnectedToV6(t *testing.T) {
testWrite(c, unicastV6)
// Write to V4 mapped address.
- testFailingWrite(c, unicastV4in6, tcpip.ErrNetworkUnreachable)
+ testFailingWrite(c, unicastV4in6, &tcpip.ErrNetworkUnreachable{})
const want = 1
if got := c.ep.Stats().(*tcpip.TransportEndpointStats).SendErrors.NoRoute.Value(); got != want {
c.t.Fatalf("Endpoint stat not updated. got %d want %d", got, want)
@@ -1110,7 +1119,7 @@ func TestDualWriteConnectedToV4Mapped(t *testing.T) {
testWrite(c, unicastV4in6)
// Write to v6 address.
- testFailingWrite(c, unicastV6, tcpip.ErrInvalidEndpointState)
+ testFailingWrite(c, unicastV6, &tcpip.ErrInvalidEndpointState{})
}
func TestV4WriteOnV6Only(t *testing.T) {
@@ -1120,7 +1129,7 @@ func TestV4WriteOnV6Only(t *testing.T) {
c.createEndpointForFlow(unicastV6Only)
// Write to V4 mapped address.
- testFailingWrite(c, unicastV4in6, tcpip.ErrNoRoute)
+ testFailingWrite(c, unicastV4in6, &tcpip.ErrNoRoute{})
}
func TestV6WriteOnBoundToV4Mapped(t *testing.T) {
@@ -1135,7 +1144,7 @@ func TestV6WriteOnBoundToV4Mapped(t *testing.T) {
}
// Write to v6 address.
- testFailingWrite(c, unicastV6, tcpip.ErrInvalidEndpointState)
+ testFailingWrite(c, unicastV6, &tcpip.ErrInvalidEndpointState{})
}
func TestV6WriteOnConnected(t *testing.T) {
@@ -1183,8 +1192,10 @@ func TestWriteOnConnectedInvalidPort(t *testing.T) {
writeOpts := tcpip.WriteOptions{
To: &tcpip.FullAddress{Addr: stackAddr, Port: invalidPort},
}
- payload := buffer.View(newPayload())
- n, err := c.ep.Write(tcpip.SlicePayload(payload), writeOpts)
+ var r bytes.Reader
+ payload := newPayload()
+ r.Reset(payload)
+ n, err := c.ep.Write(&r, writeOpts)
if err != nil {
c.t.Fatalf("c.ep.Write(...) = %+s, want nil", err)
}
@@ -1192,8 +1203,11 @@ func TestWriteOnConnectedInvalidPort(t *testing.T) {
c.t.Fatalf("c.ep.Write(...) wrote %d bytes, want %d bytes", got, want)
}
- if err := c.ep.LastError(); err != tcpip.ErrConnectionRefused {
- c.t.Fatalf("expected c.ep.LastError() == ErrConnectionRefused, got: %+v", err)
+ {
+ err := c.ep.LastError()
+ if _, ok := err.(*tcpip.ErrConnectionRefused); !ok {
+ c.t.Fatalf("expected c.ep.LastError() == ErrConnectionRefused, got: %+v", err)
+ }
}
})
}
@@ -2303,21 +2317,21 @@ func TestShutdownWrite(t *testing.T) {
t.Fatalf("Shutdown failed: %s", err)
}
- testFailingWrite(c, unicastV6, tcpip.ErrClosedForSend)
+ testFailingWrite(c, unicastV6, &tcpip.ErrClosedForSend{})
}
-func (c *testContext) checkEndpointWriteStats(incr uint64, want tcpip.TransportEndpointStats, err *tcpip.Error) {
+func (c *testContext) checkEndpointWriteStats(incr uint64, want tcpip.TransportEndpointStats, err tcpip.Error) {
got := c.ep.Stats().(*tcpip.TransportEndpointStats).Clone()
- switch err {
+ switch err.(type) {
case nil:
want.PacketsSent.IncrementBy(incr)
- case tcpip.ErrMessageTooLong, tcpip.ErrInvalidOptionValue:
+ case *tcpip.ErrMessageTooLong, *tcpip.ErrInvalidOptionValue:
want.WriteErrors.InvalidArgs.IncrementBy(incr)
- case tcpip.ErrClosedForSend:
+ case *tcpip.ErrClosedForSend:
want.WriteErrors.WriteClosed.IncrementBy(incr)
- case tcpip.ErrInvalidEndpointState:
+ case *tcpip.ErrInvalidEndpointState:
want.WriteErrors.InvalidEndpointState.IncrementBy(incr)
- case tcpip.ErrNoRoute, tcpip.ErrBroadcastDisabled, tcpip.ErrNetworkUnreachable:
+ case *tcpip.ErrNoRoute, *tcpip.ErrBroadcastDisabled, *tcpip.ErrNetworkUnreachable:
want.SendErrors.NoRoute.IncrementBy(incr)
default:
want.SendErrors.SendToNetworkFailed.IncrementBy(incr)
@@ -2327,11 +2341,11 @@ func (c *testContext) checkEndpointWriteStats(incr uint64, want tcpip.TransportE
}
}
-func (c *testContext) checkEndpointReadStats(incr uint64, want tcpip.TransportEndpointStats, err *tcpip.Error) {
+func (c *testContext) checkEndpointReadStats(incr uint64, want tcpip.TransportEndpointStats, err tcpip.Error) {
got := c.ep.Stats().(*tcpip.TransportEndpointStats).Clone()
- switch err {
- case nil, tcpip.ErrWouldBlock:
- case tcpip.ErrClosedForReceive:
+ switch err.(type) {
+ case nil, *tcpip.ErrWouldBlock:
+ case *tcpip.ErrClosedForReceive:
want.ReadErrors.ReadClosed.IncrementBy(incr)
default:
c.t.Errorf("Endpoint error missing stats update err %v", err)
@@ -2497,31 +2511,54 @@ func TestOutgoingSubnetBroadcast(t *testing.T) {
}
defer ep.Close()
- data := tcpip.SlicePayload([]byte{1, 2, 3, 4})
+ var r bytes.Reader
+ data := []byte{1, 2, 3, 4}
to := tcpip.FullAddress{
Addr: test.remoteAddr,
Port: 80,
}
opts := tcpip.WriteOptions{To: &to}
- expectedErrWithoutBcastOpt := tcpip.ErrBroadcastDisabled
+ expectedErrWithoutBcastOpt := func(err tcpip.Error) tcpip.Error {
+ if _, ok := err.(*tcpip.ErrBroadcastDisabled); ok {
+ return nil
+ }
+ return &tcpip.ErrBroadcastDisabled{}
+ }
if !test.requiresBroadcastOpt {
expectedErrWithoutBcastOpt = nil
}
- if n, err := ep.Write(data, opts); err != expectedErrWithoutBcastOpt {
- t.Fatalf("got ep.Write(_, %#v) = (%d, %s), want = (_, %s)", opts, n, err, expectedErrWithoutBcastOpt)
+ r.Reset(data)
+ {
+ n, err := ep.Write(&r, opts)
+ if expectedErrWithoutBcastOpt != nil {
+ if want := expectedErrWithoutBcastOpt(err); want != nil {
+ t.Fatalf("got ep.Write(_, %#v) = (%d, %s), want = (_, %s)", opts, n, err, want)
+ }
+ } else if err != nil {
+ t.Fatalf("got ep.Write(_, %#v) = (%d, %s), want = (_, nil)", opts, n, err)
+ }
}
ep.SocketOptions().SetBroadcast(true)
- if n, err := ep.Write(data, opts); err != nil {
+ r.Reset(data)
+ if n, err := ep.Write(&r, opts); err != nil {
t.Fatalf("got ep.Write(_, %#v) = (%d, %s), want = (_, nil)", opts, n, err)
}
ep.SocketOptions().SetBroadcast(false)
- if n, err := ep.Write(data, opts); err != expectedErrWithoutBcastOpt {
- t.Fatalf("got ep.Write(_, %#v) = (%d, %s), want = (_, %s)", opts, n, err, expectedErrWithoutBcastOpt)
+ r.Reset(data)
+ {
+ n, err := ep.Write(&r, opts)
+ if expectedErrWithoutBcastOpt != nil {
+ if want := expectedErrWithoutBcastOpt(err); want != nil {
+ t.Fatalf("got ep.Write(_, %#v) = (%d, %s), want = (_, %s)", opts, n, err, want)
+ }
+ } else if err != nil {
+ t.Fatalf("got ep.Write(_, %#v) = (%d, %s), want = (_, nil)", opts, n, err)
+ }
}
})
}