summaryrefslogtreecommitdiffhomepage
path: root/pkg/tcpip
diff options
context:
space:
mode:
Diffstat (limited to 'pkg/tcpip')
-rw-r--r--pkg/tcpip/link/nested/BUILD31
-rw-r--r--pkg/tcpip/link/nested/nested.go131
-rw-r--r--pkg/tcpip/link/nested/nested_test.go105
-rw-r--r--pkg/tcpip/link/sniffer/BUILD1
-rw-r--r--pkg/tcpip/link/sniffer/sniffer.go75
-rw-r--r--pkg/tcpip/stack/iptables.go18
-rw-r--r--pkg/tcpip/stack/iptables_types.go27
-rw-r--r--pkg/tcpip/stack/ndp_test.go157
-rw-r--r--pkg/tcpip/stack/stack.go6
-rw-r--r--pkg/tcpip/stack/stack_test.go2
-rw-r--r--pkg/tcpip/tcpip.go32
-rw-r--r--pkg/tcpip/transport/raw/endpoint.go72
-rw-r--r--pkg/tcpip/transport/tcp/BUILD10
-rw-r--r--pkg/tcpip/transport/tcp/connect.go12
-rw-r--r--pkg/tcpip/transport/tcp/endpoint.go16
-rw-r--r--pkg/tcpip/transport/tcp/endpoint_state.go2
-rw-r--r--pkg/tcpip/transport/tcp/protocol.go59
-rw-r--r--pkg/tcpip/transport/tcp/tcp_sack_test.go4
-rw-r--r--pkg/tcpip/transport/tcp/tcp_test.go41
-rw-r--r--pkg/tcpip/transport/tcp/testing/context/context.go10
-rw-r--r--pkg/tcpip/transport/tcp/timer.go1
-rw-r--r--pkg/tcpip/transport/tcp/timer_test.go47
-rw-r--r--pkg/tcpip/transport/udp/endpoint.go75
-rw-r--r--pkg/tcpip/transport/udp/protocol.go70
-rw-r--r--pkg/tcpip/transport/udp/udp_test.go320
25 files changed, 1005 insertions, 319 deletions
diff --git a/pkg/tcpip/link/nested/BUILD b/pkg/tcpip/link/nested/BUILD
new file mode 100644
index 000000000..bdd5276ad
--- /dev/null
+++ b/pkg/tcpip/link/nested/BUILD
@@ -0,0 +1,31 @@
+load("//tools:defs.bzl", "go_library", "go_test")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "nested",
+ srcs = [
+ "nested.go",
+ ],
+ visibility = ["//visibility:public"],
+ deps = [
+ "//pkg/sync",
+ "//pkg/tcpip",
+ "//pkg/tcpip/buffer",
+ "//pkg/tcpip/stack",
+ ],
+)
+
+go_test(
+ name = "nested_test",
+ size = "small",
+ srcs = [
+ "nested_test.go",
+ ],
+ deps = [
+ "//pkg/tcpip",
+ "//pkg/tcpip/header",
+ "//pkg/tcpip/link/nested",
+ "//pkg/tcpip/stack",
+ ],
+)
diff --git a/pkg/tcpip/link/nested/nested.go b/pkg/tcpip/link/nested/nested.go
new file mode 100644
index 000000000..2998f9c4f
--- /dev/null
+++ b/pkg/tcpip/link/nested/nested.go
@@ -0,0 +1,131 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package nested provides helpers to implement the pattern of nested
+// stack.LinkEndpoints.
+package nested
+
+import (
+ "gvisor.dev/gvisor/pkg/sync"
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
+)
+
+// Endpoint is a wrapper around stack.LinkEndpoint and stack.NetworkDispatcher
+// that can be used to implement nesting safely by providing lifecycle
+// concurrency guards.
+//
+// See the tests in this package for example usage.
+type Endpoint struct {
+ child stack.LinkEndpoint
+ embedder stack.NetworkDispatcher
+
+ // mu protects dispatcher.
+ mu sync.RWMutex
+ dispatcher stack.NetworkDispatcher
+}
+
+var _ stack.GSOEndpoint = (*Endpoint)(nil)
+var _ stack.LinkEndpoint = (*Endpoint)(nil)
+var _ stack.NetworkDispatcher = (*Endpoint)(nil)
+
+// Init initializes a nested.Endpoint that uses embedder as the dispatcher for
+// child on Attach.
+//
+// See the tests in this package for example usage.
+func (e *Endpoint) Init(child stack.LinkEndpoint, embedder stack.NetworkDispatcher) {
+ e.child = child
+ e.embedder = embedder
+}
+
+// DeliverNetworkPacket implements stack.NetworkDispatcher.
+func (e *Endpoint) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
+ e.mu.RLock()
+ d := e.dispatcher
+ e.mu.RUnlock()
+ if d != nil {
+ d.DeliverNetworkPacket(remote, local, protocol, pkt)
+ }
+}
+
+// Attach implements stack.LinkEndpoint.
+func (e *Endpoint) Attach(dispatcher stack.NetworkDispatcher) {
+ e.mu.Lock()
+ e.dispatcher = dispatcher
+ e.mu.Unlock()
+ // If we're attaching to a valid dispatcher, pass embedder as the dispatcher
+ // to our child, otherwise detach the child by giving it a nil dispatcher.
+ var pass stack.NetworkDispatcher
+ if dispatcher != nil {
+ pass = e.embedder
+ }
+ e.child.Attach(pass)
+}
+
+// IsAttached implements stack.LinkEndpoint.
+func (e *Endpoint) IsAttached() bool {
+ e.mu.RLock()
+ isAttached := e.dispatcher != nil
+ e.mu.RUnlock()
+ return isAttached
+}
+
+// MTU implements stack.LinkEndpoint.
+func (e *Endpoint) MTU() uint32 {
+ return e.child.MTU()
+}
+
+// Capabilities implements stack.LinkEndpoint.
+func (e *Endpoint) Capabilities() stack.LinkEndpointCapabilities {
+ return e.child.Capabilities()
+}
+
+// MaxHeaderLength implements stack.LinkEndpoint.
+func (e *Endpoint) MaxHeaderLength() uint16 {
+ return e.child.MaxHeaderLength()
+}
+
+// LinkAddress implements stack.LinkEndpoint.
+func (e *Endpoint) LinkAddress() tcpip.LinkAddress {
+ return e.child.LinkAddress()
+}
+
+// WritePacket implements stack.LinkEndpoint.
+func (e *Endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error {
+ return e.child.WritePacket(r, gso, protocol, pkt)
+}
+
+// WritePackets implements stack.LinkEndpoint.
+func (e *Endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) {
+ return e.child.WritePackets(r, gso, pkts, protocol)
+}
+
+// WriteRawPacket implements stack.LinkEndpoint.
+func (e *Endpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error {
+ return e.child.WriteRawPacket(vv)
+}
+
+// Wait implements stack.LinkEndpoint.
+func (e *Endpoint) Wait() {
+ e.child.Wait()
+}
+
+// GSOMaxSize implements stack.GSOEndpoint.
+func (e *Endpoint) GSOMaxSize() uint32 {
+ if e, ok := e.child.(stack.GSOEndpoint); ok {
+ return e.GSOMaxSize()
+ }
+ return 0
+}
diff --git a/pkg/tcpip/link/nested/nested_test.go b/pkg/tcpip/link/nested/nested_test.go
new file mode 100644
index 000000000..c1a219f02
--- /dev/null
+++ b/pkg/tcpip/link/nested/nested_test.go
@@ -0,0 +1,105 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package nested_test
+
+import (
+ "testing"
+
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/link/nested"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
+)
+
+type parentEndpoint struct {
+ nested.Endpoint
+}
+
+var _ stack.LinkEndpoint = (*parentEndpoint)(nil)
+var _ stack.NetworkDispatcher = (*parentEndpoint)(nil)
+
+type childEndpoint struct {
+ stack.LinkEndpoint
+ dispatcher stack.NetworkDispatcher
+}
+
+var _ stack.LinkEndpoint = (*childEndpoint)(nil)
+
+func (c *childEndpoint) Attach(dispatcher stack.NetworkDispatcher) {
+ c.dispatcher = dispatcher
+}
+
+func (c *childEndpoint) IsAttached() bool {
+ return c.dispatcher != nil
+}
+
+type counterDispatcher struct {
+ count int
+}
+
+var _ stack.NetworkDispatcher = (*counterDispatcher)(nil)
+
+func (d *counterDispatcher) DeliverNetworkPacket(tcpip.LinkAddress, tcpip.LinkAddress, tcpip.NetworkProtocolNumber, *stack.PacketBuffer) {
+ d.count++
+}
+
+func TestNestedLinkEndpoint(t *testing.T) {
+ const emptyAddress = tcpip.LinkAddress("")
+
+ var (
+ childEP childEndpoint
+ nestedEP parentEndpoint
+ disp counterDispatcher
+ )
+ nestedEP.Endpoint.Init(&childEP, &nestedEP)
+
+ if childEP.IsAttached() {
+ t.Error("On init, childEP.IsAttached() = true, want = false")
+ }
+ if nestedEP.IsAttached() {
+ t.Error("On init, nestedEP.IsAttached() = true, want = false")
+ }
+
+ nestedEP.Attach(&disp)
+ if disp.count != 0 {
+ t.Fatalf("After attach, got disp.count = %d, want = 0", disp.count)
+ }
+ if !childEP.IsAttached() {
+ t.Error("After attach, childEP.IsAttached() = false, want = true")
+ }
+ if !nestedEP.IsAttached() {
+ t.Error("After attach, nestedEP.IsAttached() = false, want = true")
+ }
+
+ nestedEP.DeliverNetworkPacket(emptyAddress, emptyAddress, header.IPv4ProtocolNumber, &stack.PacketBuffer{})
+ if disp.count != 1 {
+ t.Errorf("After first packet with dispatcher attached, got disp.count = %d, want = 1", disp.count)
+ }
+
+ nestedEP.Attach(nil)
+ if childEP.IsAttached() {
+ t.Error("After detach, childEP.IsAttached() = true, want = false")
+ }
+ if nestedEP.IsAttached() {
+ t.Error("After detach, nestedEP.IsAttached() = true, want = false")
+ }
+
+ disp.count = 0
+ nestedEP.DeliverNetworkPacket(emptyAddress, emptyAddress, header.IPv4ProtocolNumber, &stack.PacketBuffer{})
+ if disp.count != 0 {
+ t.Errorf("After second packet with dispatcher detached, got disp.count = %d, want = 0", disp.count)
+ }
+
+}
diff --git a/pkg/tcpip/link/sniffer/BUILD b/pkg/tcpip/link/sniffer/BUILD
index 230a8d53a..7cbc305e7 100644
--- a/pkg/tcpip/link/sniffer/BUILD
+++ b/pkg/tcpip/link/sniffer/BUILD
@@ -14,6 +14,7 @@ go_library(
"//pkg/tcpip",
"//pkg/tcpip/buffer",
"//pkg/tcpip/header",
+ "//pkg/tcpip/link/nested",
"//pkg/tcpip/stack",
],
)
diff --git a/pkg/tcpip/link/sniffer/sniffer.go b/pkg/tcpip/link/sniffer/sniffer.go
index f2e47b6a7..d9cd4e83a 100644
--- a/pkg/tcpip/link/sniffer/sniffer.go
+++ b/pkg/tcpip/link/sniffer/sniffer.go
@@ -31,6 +31,7 @@ import (
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/link/nested"
"gvisor.dev/gvisor/pkg/tcpip/stack"
)
@@ -48,18 +49,21 @@ var LogPackets uint32 = 1
var LogPacketsToPCAP uint32 = 1
type endpoint struct {
- dispatcher stack.NetworkDispatcher
- lower stack.LinkEndpoint
+ nested.Endpoint
writer io.Writer
maxPCAPLen uint32
}
+var _ stack.GSOEndpoint = (*endpoint)(nil)
+var _ stack.LinkEndpoint = (*endpoint)(nil)
+var _ stack.NetworkDispatcher = (*endpoint)(nil)
+
// New creates a new sniffer link-layer endpoint. It wraps around another
// endpoint and logs packets and they traverse the endpoint.
func New(lower stack.LinkEndpoint) stack.LinkEndpoint {
- return &endpoint{
- lower: lower,
- }
+ sniffer := &endpoint{}
+ sniffer.Endpoint.Init(lower, sniffer)
+ return sniffer
}
func zoneOffset() (int32, error) {
@@ -103,11 +107,12 @@ func NewWithWriter(lower stack.LinkEndpoint, writer io.Writer, snapLen uint32) (
if err := writePCAPHeader(writer, snapLen); err != nil {
return nil, err
}
- return &endpoint{
- lower: lower,
+ sniffer := &endpoint{
writer: writer,
maxPCAPLen: snapLen,
- }, nil
+ }
+ sniffer.Endpoint.Init(lower, sniffer)
+ return sniffer, nil
}
// DeliverNetworkPacket implements the stack.NetworkDispatcher interface. It is
@@ -115,50 +120,7 @@ func NewWithWriter(lower stack.LinkEndpoint, writer io.Writer, snapLen uint32) (
// logs the packet before forwarding to the actual dispatcher.
func (e *endpoint) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
e.dumpPacket("recv", nil, protocol, pkt)
- e.dispatcher.DeliverNetworkPacket(remote, local, protocol, pkt)
-}
-
-// Attach implements the stack.LinkEndpoint interface. It saves the dispatcher
-// and registers with the lower endpoint as its dispatcher so that "e" is called
-// for inbound packets.
-func (e *endpoint) Attach(dispatcher stack.NetworkDispatcher) {
- e.dispatcher = dispatcher
- e.lower.Attach(e)
-}
-
-// IsAttached implements stack.LinkEndpoint.IsAttached.
-func (e *endpoint) IsAttached() bool {
- return e.dispatcher != nil
-}
-
-// MTU implements stack.LinkEndpoint.MTU. It just forwards the request to the
-// lower endpoint.
-func (e *endpoint) MTU() uint32 {
- return e.lower.MTU()
-}
-
-// Capabilities implements stack.LinkEndpoint.Capabilities. It just forwards the
-// request to the lower endpoint.
-func (e *endpoint) Capabilities() stack.LinkEndpointCapabilities {
- return e.lower.Capabilities()
-}
-
-// MaxHeaderLength implements the stack.LinkEndpoint interface. It just forwards
-// the request to the lower endpoint.
-func (e *endpoint) MaxHeaderLength() uint16 {
- return e.lower.MaxHeaderLength()
-}
-
-func (e *endpoint) LinkAddress() tcpip.LinkAddress {
- return e.lower.LinkAddress()
-}
-
-// GSOMaxSize returns the maximum GSO packet size.
-func (e *endpoint) GSOMaxSize() uint32 {
- if gso, ok := e.lower.(stack.GSOEndpoint); ok {
- return gso.GSOMaxSize()
- }
- return 0
+ e.Endpoint.DeliverNetworkPacket(remote, local, protocol, pkt)
}
func (e *endpoint) dumpPacket(prefix string, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
@@ -203,7 +165,7 @@ func (e *endpoint) dumpPacket(prefix string, gso *stack.GSO, protocol tcpip.Netw
// forwards the request to the lower endpoint.
func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error {
e.dumpPacket("send", gso, protocol, pkt)
- return e.lower.WritePacket(r, gso, protocol, pkt)
+ return e.Endpoint.WritePacket(r, gso, protocol, pkt)
}
// WritePackets implements the stack.LinkEndpoint interface. It is called by
@@ -213,7 +175,7 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe
for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() {
e.dumpPacket("send", gso, protocol, pkt)
}
- return e.lower.WritePackets(r, gso, pkts, protocol)
+ return e.Endpoint.WritePackets(r, gso, pkts, protocol)
}
// WriteRawPacket implements stack.LinkEndpoint.WriteRawPacket.
@@ -221,12 +183,9 @@ func (e *endpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error {
e.dumpPacket("send", nil, 0, &stack.PacketBuffer{
Data: vv,
})
- return e.lower.WriteRawPacket(vv)
+ return e.Endpoint.WriteRawPacket(vv)
}
-// Wait implements stack.LinkEndpoint.Wait.
-func (e *endpoint) Wait() { e.lower.Wait() }
-
func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer, gso *stack.GSO) {
// Figure out the network layer info.
var transProto uint8
diff --git a/pkg/tcpip/stack/iptables.go b/pkg/tcpip/stack/iptables.go
index 4e9b404c8..62d4eb1b6 100644
--- a/pkg/tcpip/stack/iptables.go
+++ b/pkg/tcpip/stack/iptables.go
@@ -170,17 +170,10 @@ func (it *IPTables) GetTable(name string) (Table, bool) {
func (it *IPTables) ReplaceTable(name string, table Table) {
it.mu.Lock()
defer it.mu.Unlock()
+ it.modified = true
it.tables[name] = table
}
-// ModifyTables acquires write-lock and calls fn with internal name-to-table
-// map. This function can be used to update multiple tables atomically.
-func (it *IPTables) ModifyTables(fn func(map[string]Table)) {
- it.mu.Lock()
- defer it.mu.Unlock()
- fn(it.tables)
-}
-
// GetPriorities returns slice of priorities associated with hook.
func (it *IPTables) GetPriorities(hook Hook) []string {
it.mu.RLock()
@@ -209,6 +202,15 @@ const (
//
// Precondition: pkt.NetworkHeader is set.
func (it *IPTables) Check(hook Hook, pkt *PacketBuffer, gso *GSO, r *Route, address tcpip.Address, nicName string) bool {
+ // Many users never configure iptables. Spare them the cost of rule
+ // traversal if rules have never been set.
+ it.mu.RLock()
+ if !it.modified {
+ it.mu.RUnlock()
+ return true
+ }
+ it.mu.RUnlock()
+
// Packets are manipulated only if connection and matching
// NAT rule exists.
it.connections.HandlePacket(pkt, hook, gso, r)
diff --git a/pkg/tcpip/stack/iptables_types.go b/pkg/tcpip/stack/iptables_types.go
index 4a6a5c6f1..7026990c4 100644
--- a/pkg/tcpip/stack/iptables_types.go
+++ b/pkg/tcpip/stack/iptables_types.go
@@ -79,11 +79,11 @@ const (
// IPTables holds all the tables for a netstack.
type IPTables struct {
- // mu protects tables and priorities.
+ // mu protects tables, priorities, and modified.
mu sync.RWMutex
- // tables maps table names to tables. User tables have arbitrary names. mu
- // needs to be locked for accessing.
+ // tables maps table names to tables. User tables have arbitrary names.
+ // mu needs to be locked for accessing.
tables map[string]Table
// priorities maps each hook to a list of table names. The order of the
@@ -91,11 +91,16 @@ type IPTables struct {
// hook. mu needs to be locked for accessing.
priorities map[Hook][]string
+ // modified is whether tables have been modified at least once. It is
+ // used to elide the iptables performance overhead for workloads that
+ // don't utilize iptables.
+ modified bool
+
connections ConnTrackTable
}
// A Table defines a set of chains and hooks into the network stack. It is
-// really just a list of rules with some metadata for entrypoints and such.
+// really just a list of rules.
type Table struct {
// Rules holds the rules that make up the table.
Rules []Rule
@@ -110,10 +115,6 @@ type Table struct {
// UserChains holds user-defined chains for the keyed by name. Users
// can give their chains arbitrary names.
UserChains map[string]int
-
- // Metadata holds information about the Table that is useful to users
- // of IPTables, but not to the netstack IPTables code itself.
- metadata interface{}
}
// ValidHooks returns a bitmap of the builtin hooks for the given table.
@@ -125,16 +126,6 @@ func (table *Table) ValidHooks() uint32 {
return hooks
}
-// Metadata returns the metadata object stored in table.
-func (table *Table) Metadata() interface{} {
- return table.metadata
-}
-
-// SetMetadata sets the metadata object stored in table.
-func (table *Table) SetMetadata(metadata interface{}) {
- table.metadata = metadata
-}
-
// A Rule is a packet processing rule. It consists of two pieces. First it
// contains zero or more matchers, each of which is a specification of which
// packets this rule applies to. If there are no matchers in the rule, it
diff --git a/pkg/tcpip/stack/ndp_test.go b/pkg/tcpip/stack/ndp_test.go
index ae326b3ab..6f86abc98 100644
--- a/pkg/tcpip/stack/ndp_test.go
+++ b/pkg/tcpip/stack/ndp_test.go
@@ -36,15 +36,24 @@ import (
)
const (
- addr1 = tcpip.Address("\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01")
- addr2 = tcpip.Address("\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02")
- addr3 = tcpip.Address("\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x03")
- linkAddr1 = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x06")
- linkAddr2 = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x07")
- linkAddr3 = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x08")
- linkAddr4 = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x09")
- defaultTimeout = 100 * time.Millisecond
- defaultAsyncEventTimeout = time.Second
+ addr1 = tcpip.Address("\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01")
+ addr2 = tcpip.Address("\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02")
+ addr3 = tcpip.Address("\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x03")
+ linkAddr1 = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x06")
+ linkAddr2 = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x07")
+ linkAddr3 = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x08")
+ linkAddr4 = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x09")
+
+ // Extra time to use when waiting for an async event to occur.
+ defaultAsyncPositiveEventTimeout = 10 * time.Second
+
+ // Extra time to use when waiting for an async event to not occur.
+ //
+ // Since a negative check is used to make sure an event did not happen, it is
+ // okay to use a smaller timeout compared to the positive case since execution
+ // stall in regards to the monotonic clock will not affect the expected
+ // outcome.
+ defaultAsyncNegativeEventTimeout = time.Second
)
var (
@@ -442,7 +451,7 @@ func TestDADResolve(t *testing.T) {
// Make sure the address does not resolve before the resolution time has
// passed.
- time.Sleep(test.expectedRetransmitTimer*time.Duration(test.dupAddrDetectTransmits) - defaultAsyncEventTimeout)
+ time.Sleep(test.expectedRetransmitTimer*time.Duration(test.dupAddrDetectTransmits) - defaultAsyncNegativeEventTimeout)
if addr, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber); err != nil {
t.Errorf("got stack.GetMainNICAddress(%d, %d) = (_, %s), want = (_, nil)", nicID, header.IPv6ProtocolNumber, err)
} else if want := (tcpip.AddressWithPrefix{}); addr != want {
@@ -471,7 +480,7 @@ func TestDADResolve(t *testing.T) {
// Wait for DAD to resolve.
select {
- case <-time.After(2 * defaultAsyncEventTimeout):
+ case <-time.After(defaultAsyncPositiveEventTimeout):
t.Fatal("timed out waiting for DAD resolution")
case e := <-ndpDisp.dadC:
if diff := checkDADEvent(e, nicID, addr1, true, nil); diff != "" {
@@ -1169,7 +1178,7 @@ func TestRouterDiscoveryDispatcherNoRemember(t *testing.T) {
select {
case <-ndpDisp.routerC:
t.Fatal("should not have received any router events")
- case <-time.After(lifetimeSeconds*time.Second + defaultTimeout):
+ case <-time.After(lifetimeSeconds*time.Second + defaultAsyncNegativeEventTimeout):
}
}
@@ -1252,7 +1261,7 @@ func TestRouterDiscovery(t *testing.T) {
// Wait for the normal lifetime plus an extra bit for the
// router to get invalidated. If we don't get an invalidation
// event after this time, then something is wrong.
- expectAsyncRouterInvalidationEvent(llAddr2, l2LifetimeSeconds*time.Second+defaultAsyncEventTimeout)
+ expectAsyncRouterInvalidationEvent(llAddr2, l2LifetimeSeconds*time.Second+defaultAsyncPositiveEventTimeout)
// Rx an RA from lladdr2 with huge lifetime.
e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr2, 1000))
@@ -1269,7 +1278,7 @@ func TestRouterDiscovery(t *testing.T) {
// Wait for the normal lifetime plus an extra bit for the
// router to get invalidated. If we don't get an invalidation
// event after this time, then something is wrong.
- expectAsyncRouterInvalidationEvent(llAddr3, l3LifetimeSeconds*time.Second+defaultAsyncEventTimeout)
+ expectAsyncRouterInvalidationEvent(llAddr3, l3LifetimeSeconds*time.Second+defaultAsyncPositiveEventTimeout)
}
// TestRouterDiscoveryMaxRouters tests that only
@@ -1418,7 +1427,7 @@ func TestPrefixDiscoveryDispatcherNoRemember(t *testing.T) {
select {
case <-ndpDisp.prefixC:
t.Fatal("should not have received any prefix events")
- case <-time.After(lifetimeSeconds*time.Second + defaultTimeout):
+ case <-time.After(lifetimeSeconds*time.Second + defaultAsyncNegativeEventTimeout):
}
}
@@ -1500,7 +1509,7 @@ func TestPrefixDiscovery(t *testing.T) {
if diff := checkPrefixEvent(e, subnet2, false); diff != "" {
t.Errorf("prefix event mismatch (-want +got):\n%s", diff)
}
- case <-time.After(time.Duration(lifetime)*time.Second + defaultAsyncEventTimeout):
+ case <-time.After(time.Duration(lifetime)*time.Second + defaultAsyncPositiveEventTimeout):
t.Fatal("timed out waiting for prefix discovery event")
}
@@ -1565,7 +1574,7 @@ func TestPrefixDiscoveryWithInfiniteLifetime(t *testing.T) {
select {
case <-ndpDisp.prefixC:
t.Fatal("unexpectedly invalidated a prefix with infinite lifetime")
- case <-time.After(testInfiniteLifetime + defaultTimeout):
+ case <-time.After(testInfiniteLifetime + defaultAsyncNegativeEventTimeout):
}
// Receive an RA with finite lifetime.
@@ -1590,7 +1599,7 @@ func TestPrefixDiscoveryWithInfiniteLifetime(t *testing.T) {
select {
case <-ndpDisp.prefixC:
t.Fatal("unexpectedly invalidated a prefix with infinite lifetime")
- case <-time.After(testInfiniteLifetime + defaultTimeout):
+ case <-time.After(testInfiniteLifetime + defaultAsyncNegativeEventTimeout):
}
// Receive an RA with a prefix with a lifetime value greater than the
@@ -1599,7 +1608,7 @@ func TestPrefixDiscoveryWithInfiniteLifetime(t *testing.T) {
select {
case <-ndpDisp.prefixC:
t.Fatal("unexpectedly invalidated a prefix with infinite lifetime")
- case <-time.After((testInfiniteLifetimeSeconds+1)*time.Second + defaultTimeout):
+ case <-time.After((testInfiniteLifetimeSeconds+1)*time.Second + defaultAsyncNegativeEventTimeout):
}
// Receive an RA with 0 lifetime.
@@ -1835,7 +1844,7 @@ func TestAutoGenAddr(t *testing.T) {
if diff := checkAutoGenAddrEvent(e, addr1, invalidatedAddr); diff != "" {
t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
}
- case <-time.After(newMinVLDuration + defaultAsyncEventTimeout):
+ case <-time.After(newMinVLDuration + defaultAsyncPositiveEventTimeout):
t.Fatal("timed out waiting for addr auto gen event")
}
if containsV6Addr(s.NICInfo()[1].ProtocolAddresses, addr1) {
@@ -1962,7 +1971,7 @@ func TestAutoGenTempAddr(t *testing.T) {
if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" {
t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
}
- case <-time.After(defaultAsyncEventTimeout):
+ case <-time.After(defaultAsyncPositiveEventTimeout):
t.Fatal("timed out waiting for addr auto gen event")
}
}
@@ -1975,7 +1984,7 @@ func TestAutoGenTempAddr(t *testing.T) {
if diff := checkDADEvent(e, nicID, addr, true, nil); diff != "" {
t.Errorf("dad event mismatch (-want +got):\n%s", diff)
}
- case <-time.After(time.Duration(test.dupAddrTransmits)*test.retransmitTimer + defaultAsyncEventTimeout):
+ case <-time.After(time.Duration(test.dupAddrTransmits)*test.retransmitTimer + defaultAsyncPositiveEventTimeout):
t.Fatal("timed out waiting for DAD event")
}
}
@@ -2081,10 +2090,10 @@ func TestAutoGenTempAddr(t *testing.T) {
if diff := checkAutoGenAddrEvent(e, nextAddr, invalidatedAddr); diff != "" {
t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
}
- case <-time.After(defaultTimeout):
+ case <-time.After(defaultAsyncPositiveEventTimeout):
t.Fatal("timed out waiting for addr auto gen event")
}
- case <-time.After(newMinVLDuration + defaultTimeout):
+ case <-time.After(newMinVLDuration + defaultAsyncPositiveEventTimeout):
t.Fatal("timed out waiting for addr auto gen event")
}
if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr2, tempAddr2}, []tcpip.AddressWithPrefix{addr1, tempAddr1}); mismatch != "" {
@@ -2180,7 +2189,7 @@ func TestNoAutoGenTempAddrForLinkLocal(t *testing.T) {
if diff := checkDADEvent(e, nicID, llAddr1, true, nil); diff != "" {
t.Errorf("dad event mismatch (-want +got):\n%s", diff)
}
- case <-time.After(time.Duration(test.dupAddrTransmits)*test.retransmitTimer + defaultAsyncEventTimeout):
+ case <-time.After(time.Duration(test.dupAddrTransmits)*test.retransmitTimer + defaultAsyncPositiveEventTimeout):
t.Fatal("timed out waiting for DAD event")
}
@@ -2188,7 +2197,7 @@ func TestNoAutoGenTempAddrForLinkLocal(t *testing.T) {
select {
case e := <-ndpDisp.autoGenAddrC:
t.Errorf("got unxpected auto gen addr event = %+v", e)
- case <-time.After(defaultAsyncEventTimeout):
+ case <-time.After(defaultAsyncNegativeEventTimeout):
}
})
}
@@ -2265,7 +2274,7 @@ func TestNoAutoGenTempAddrWithoutStableAddr(t *testing.T) {
if diff := checkDADEvent(e, nicID, addr.Address, true, nil); diff != "" {
t.Errorf("dad event mismatch (-want +got):\n%s", diff)
}
- case <-time.After(dadTransmits*retransmitTimer + defaultAsyncEventTimeout):
+ case <-time.After(dadTransmits*retransmitTimer + defaultAsyncPositiveEventTimeout):
t.Fatal("timed out waiting for DAD event")
}
select {
@@ -2273,7 +2282,7 @@ func TestNoAutoGenTempAddrWithoutStableAddr(t *testing.T) {
if diff := checkAutoGenAddrEvent(e, tempAddr, newAddr); diff != "" {
t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
}
- case <-time.After(defaultAsyncEventTimeout):
+ case <-time.After(defaultAsyncPositiveEventTimeout):
t.Fatal("timed out waiting for addr auto gen event")
}
}
@@ -2363,13 +2372,13 @@ func TestAutoGenTempAddrRegen(t *testing.T) {
}
// Wait for regeneration
- expectAutoGenAddrEventAsync(tempAddr2, newAddr, regenAfter+defaultAsyncEventTimeout)
+ expectAutoGenAddrEventAsync(tempAddr2, newAddr, regenAfter+defaultAsyncPositiveEventTimeout)
if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr, tempAddr1, tempAddr2}, nil); mismatch != "" {
t.Fatal(mismatch)
}
// Wait for regeneration
- expectAutoGenAddrEventAsync(tempAddr3, newAddr, regenAfter+defaultAsyncEventTimeout)
+ expectAutoGenAddrEventAsync(tempAddr3, newAddr, regenAfter+defaultAsyncPositiveEventTimeout)
if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr, tempAddr1, tempAddr2, tempAddr3}, nil); mismatch != "" {
t.Fatal(mismatch)
}
@@ -2398,7 +2407,7 @@ func TestAutoGenTempAddrRegen(t *testing.T) {
if diff := checkAutoGenAddrEvent(e, addr, invalidatedAddr); diff != "" {
t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
}
- case <-time.After(defaultAsyncEventTimeout):
+ case <-time.After(defaultAsyncPositiveEventTimeout):
t.Fatal("timed out waiting for addr auto gen event")
}
} else if diff := checkAutoGenAddrEvent(e, addr, invalidatedAddr); diff == "" {
@@ -2407,12 +2416,12 @@ func TestAutoGenTempAddrRegen(t *testing.T) {
select {
case e := <-ndpDisp.autoGenAddrC:
t.Fatalf("unexpectedly got an auto-generated event = %+v", e)
- case <-time.After(defaultTimeout):
+ case <-time.After(defaultAsyncNegativeEventTimeout):
}
} else {
t.Fatalf("got unexpected auto-generated event = %+v", e)
}
- case <-time.After(invalidateAfter + defaultAsyncEventTimeout):
+ case <-time.After(invalidateAfter + defaultAsyncPositiveEventTimeout):
t.Fatal("timed out waiting for addr auto gen event")
}
@@ -2517,7 +2526,7 @@ func TestAutoGenTempAddrRegenTimerUpdates(t *testing.T) {
select {
case e := <-ndpDisp.autoGenAddrC:
t.Fatalf("unexpected auto gen addr event = %+v", e)
- case <-time.After(regenAfter + defaultAsyncEventTimeout):
+ case <-time.After(regenAfter + defaultAsyncNegativeEventTimeout):
}
// Prefer the prefix again.
@@ -2546,7 +2555,7 @@ func TestAutoGenTempAddrRegenTimerUpdates(t *testing.T) {
select {
case e := <-ndpDisp.autoGenAddrC:
t.Fatalf("unexpected auto gen addr event = %+v", e)
- case <-time.After(regenAfter + defaultAsyncEventTimeout):
+ case <-time.After(regenAfter + defaultAsyncNegativeEventTimeout):
}
// Set the maximum lifetimes for temporary addresses such that on the next
@@ -2556,14 +2565,14 @@ func TestAutoGenTempAddrRegenTimerUpdates(t *testing.T) {
// addresses + the time that has already passed since the last address was
// generated so that the regeneration timer is needed to generate the next
// address.
- newLifetimes := newMinVLDuration + regenAfter + defaultAsyncEventTimeout
+ newLifetimes := newMinVLDuration + regenAfter + defaultAsyncNegativeEventTimeout
ndpConfigs.MaxTempAddrValidLifetime = newLifetimes
ndpConfigs.MaxTempAddrPreferredLifetime = newLifetimes
if err := s.SetNDPConfigurations(nicID, ndpConfigs); err != nil {
t.Fatalf("s.SetNDPConfigurations(%d, _): %s", nicID, err)
}
e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, 100, 100))
- expectAutoGenAddrEventAsync(tempAddr3, newAddr, regenAfter+defaultAsyncEventTimeout)
+ expectAutoGenAddrEventAsync(tempAddr3, newAddr, regenAfter+defaultAsyncPositiveEventTimeout)
}
// TestMixedSLAACAddrConflictRegen tests SLAAC address regeneration in response
@@ -2711,7 +2720,7 @@ func TestMixedSLAACAddrConflictRegen(t *testing.T) {
if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" {
t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
}
- case <-time.After(defaultAsyncEventTimeout):
+ case <-time.After(defaultAsyncPositiveEventTimeout):
t.Fatal("timed out waiting for addr auto gen event")
}
}
@@ -2724,7 +2733,7 @@ func TestMixedSLAACAddrConflictRegen(t *testing.T) {
if diff := checkDADEvent(e, nicID, addr, true, nil); diff != "" {
t.Errorf("dad event mismatch (-want +got):\n%s", diff)
}
- case <-time.After(dupAddrTransmits*retransmitTimer + defaultAsyncEventTimeout):
+ case <-time.After(dupAddrTransmits*retransmitTimer + defaultAsyncPositiveEventTimeout):
t.Fatal("timed out waiting for DAD event")
}
}
@@ -3070,7 +3079,7 @@ func TestAutoGenAddrTimerDeprecation(t *testing.T) {
expectPrimaryAddr(addr1)
// Wait for addr of prefix1 to be deprecated.
- expectAutoGenAddrEventAfter(addr1, deprecatedAddr, newMinVLDuration-time.Second+defaultAsyncEventTimeout)
+ expectAutoGenAddrEventAfter(addr1, deprecatedAddr, newMinVLDuration-time.Second+defaultAsyncPositiveEventTimeout)
if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) {
t.Fatalf("should not have %s in the list of addresses", addr1)
}
@@ -3110,7 +3119,7 @@ func TestAutoGenAddrTimerDeprecation(t *testing.T) {
expectPrimaryAddr(addr1)
// Wait for addr of prefix1 to be deprecated.
- expectAutoGenAddrEventAfter(addr1, deprecatedAddr, newMinVLDuration-time.Second+defaultAsyncEventTimeout)
+ expectAutoGenAddrEventAfter(addr1, deprecatedAddr, newMinVLDuration-time.Second+defaultAsyncPositiveEventTimeout)
if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) {
t.Fatalf("should not have %s in the list of addresses", addr1)
}
@@ -3124,7 +3133,7 @@ func TestAutoGenAddrTimerDeprecation(t *testing.T) {
}
// Wait for addr of prefix1 to be invalidated.
- expectAutoGenAddrEventAfter(addr1, invalidatedAddr, time.Second+defaultAsyncEventTimeout)
+ expectAutoGenAddrEventAfter(addr1, invalidatedAddr, time.Second+defaultAsyncPositiveEventTimeout)
if containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) {
t.Fatalf("should not have %s in the list of addresses", addr1)
}
@@ -3156,7 +3165,7 @@ func TestAutoGenAddrTimerDeprecation(t *testing.T) {
if diff := checkAutoGenAddrEvent(e, addr2, invalidatedAddr); diff != "" {
t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
}
- case <-time.After(defaultAsyncEventTimeout):
+ case <-time.After(defaultAsyncPositiveEventTimeout):
t.Fatal("timed out waiting for addr auto gen event")
}
} else if diff := checkAutoGenAddrEvent(e, addr2, invalidatedAddr); diff == "" {
@@ -3165,12 +3174,12 @@ func TestAutoGenAddrTimerDeprecation(t *testing.T) {
select {
case <-ndpDisp.autoGenAddrC:
t.Fatal("unexpectedly got an auto-generated event")
- case <-time.After(defaultTimeout):
+ case <-time.After(defaultAsyncNegativeEventTimeout):
}
} else {
t.Fatalf("got unexpected auto-generated event")
}
- case <-time.After(newMinVLDuration + defaultAsyncEventTimeout):
+ case <-time.After(newMinVLDuration + defaultAsyncPositiveEventTimeout):
t.Fatal("timed out waiting for addr auto gen event")
}
if containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) {
@@ -3295,7 +3304,7 @@ func TestAutoGenAddrFiniteToInfiniteToFiniteVL(t *testing.T) {
t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
}
- case <-time.After(minVLSeconds*time.Second + defaultAsyncEventTimeout):
+ case <-time.After(minVLSeconds*time.Second + defaultAsyncPositiveEventTimeout):
t.Fatal("timeout waiting for addr auto gen event")
}
})
@@ -3439,7 +3448,7 @@ func TestAutoGenAddrValidLifetimeUpdates(t *testing.T) {
select {
case <-ndpDisp.autoGenAddrC:
t.Fatal("unexpectedly received an auto gen addr event")
- case <-time.After(time.Duration(test.evl)*time.Second - defaultAsyncEventTimeout):
+ case <-time.After(time.Duration(test.evl)*time.Second - defaultAsyncNegativeEventTimeout):
}
// Wait for the invalidation event.
@@ -3448,7 +3457,7 @@ func TestAutoGenAddrValidLifetimeUpdates(t *testing.T) {
if diff := checkAutoGenAddrEvent(e, addr, invalidatedAddr); diff != "" {
t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
}
- case <-time.After(2 * defaultAsyncEventTimeout):
+ case <-time.After(defaultAsyncPositiveEventTimeout):
t.Fatal("timeout waiting for addr auto gen event")
}
})
@@ -3509,7 +3518,7 @@ func TestAutoGenAddrRemoval(t *testing.T) {
select {
case <-ndpDisp.autoGenAddrC:
t.Fatal("unexpectedly received an auto gen addr event")
- case <-time.After(lifetimeSeconds*time.Second + defaultTimeout):
+ case <-time.After(lifetimeSeconds*time.Second + defaultAsyncNegativeEventTimeout):
}
}
@@ -3672,7 +3681,7 @@ func TestAutoGenAddrStaticConflict(t *testing.T) {
select {
case <-ndpDisp.autoGenAddrC:
t.Fatal("unexpectedly received an auto gen addr event")
- case <-time.After(lifetimeSeconds*time.Second + defaultTimeout):
+ case <-time.After(lifetimeSeconds*time.Second + defaultAsyncNegativeEventTimeout):
}
if !containsV6Addr(s.NICInfo()[1].ProtocolAddresses, addr) {
t.Fatalf("Should have %s in the list of addresses", addr1)
@@ -3770,7 +3779,7 @@ func TestAutoGenAddrWithOpaqueIID(t *testing.T) {
if diff := checkAutoGenAddrEvent(e, addr1, invalidatedAddr); diff != "" {
t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
}
- case <-time.After(validLifetimeSecondPrefix1*time.Second + defaultAsyncEventTimeout):
+ case <-time.After(validLifetimeSecondPrefix1*time.Second + defaultAsyncPositiveEventTimeout):
t.Fatal("timed out waiting for addr auto gen event")
}
if containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) {
@@ -3837,7 +3846,7 @@ func TestAutoGenAddrInResponseToDADConflicts(t *testing.T) {
if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" {
t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
}
- case <-time.After(defaultAsyncEventTimeout):
+ case <-time.After(defaultAsyncPositiveEventTimeout):
t.Fatal("timed out waiting for addr auto gen event")
}
}
@@ -3863,7 +3872,7 @@ func TestAutoGenAddrInResponseToDADConflicts(t *testing.T) {
if diff := checkDADEvent(e, nicID, addr, resolved, nil); diff != "" {
t.Errorf("dad event mismatch (-want +got):\n%s", diff)
}
- case <-time.After(dadTransmits*retransmitTimer + defaultAsyncEventTimeout):
+ case <-time.After(dadTransmits*retransmitTimer + defaultAsyncPositiveEventTimeout):
t.Fatal("timed out waiting for DAD event")
}
}
@@ -4030,7 +4039,7 @@ func TestAutoGenAddrInResponseToDADConflicts(t *testing.T) {
select {
case e := <-ndpDisp.autoGenAddrC:
t.Fatalf("unexpectedly got an auto-generated address event = %+v", e)
- case <-time.After(defaultAsyncEventTimeout):
+ case <-time.After(defaultAsyncNegativeEventTimeout):
}
})
}
@@ -4149,7 +4158,7 @@ func TestAutoGenAddrWithEUI64IIDNoDADRetries(t *testing.T) {
select {
case e := <-ndpDisp.autoGenAddrC:
t.Fatalf("unexpectedly got an auto-generated address event = %+v", e)
- case <-time.After(defaultAsyncEventTimeout):
+ case <-time.After(defaultAsyncNegativeEventTimeout):
}
})
}
@@ -4251,7 +4260,7 @@ func TestAutoGenAddrContinuesLifetimesAfterRetry(t *testing.T) {
if diff := checkDADEvent(e, nicID, addr.Address, true, nil); diff != "" {
t.Errorf("dad event mismatch (-want +got):\n%s", diff)
}
- case <-time.After(dadTransmits*retransmitTimer + defaultAsyncEventTimeout):
+ case <-time.After(dadTransmits*retransmitTimer + defaultAsyncPositiveEventTimeout):
t.Fatal("timed out waiting for DAD event")
}
@@ -4277,7 +4286,7 @@ func TestAutoGenAddrContinuesLifetimesAfterRetry(t *testing.T) {
if diff := checkAutoGenAddrEvent(e, addr, invalidatedAddr); diff != "" {
t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
}
- case <-time.After(defaultAsyncEventTimeout):
+ case <-time.After(defaultAsyncPositiveEventTimeout):
t.Fatal("timed out waiting for invalidated auto gen addr event after deprecation")
}
} else {
@@ -4285,7 +4294,7 @@ func TestAutoGenAddrContinuesLifetimesAfterRetry(t *testing.T) {
t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
}
}
- case <-time.After(lifetimeSeconds*time.Second - failureTimer - dadTransmits*retransmitTimer + defaultAsyncEventTimeout):
+ case <-time.After(lifetimeSeconds*time.Second - failureTimer - dadTransmits*retransmitTimer + defaultAsyncPositiveEventTimeout):
t.Fatal("timed out waiting for auto gen addr event")
}
}
@@ -4869,7 +4878,7 @@ func TestCleanupNDPState(t *testing.T) {
// Should not get any more events (invalidation timers should have been
// cancelled when the NDP state was cleaned up).
- time.Sleep(lifetimeSeconds*time.Second + defaultTimeout)
+ time.Sleep(lifetimeSeconds*time.Second + defaultAsyncNegativeEventTimeout)
select {
case <-ndpDisp.routerC:
t.Error("unexpected router event")
@@ -5172,24 +5181,24 @@ func TestRouterSolicitation(t *testing.T) {
// Make sure each RS is sent at the right time.
remaining := test.maxRtrSolicit
if remaining > 0 {
- waitForPkt(test.effectiveMaxRtrSolicitDelay + defaultAsyncEventTimeout)
+ waitForPkt(test.effectiveMaxRtrSolicitDelay + defaultAsyncPositiveEventTimeout)
remaining--
}
for ; remaining > 0; remaining-- {
- if test.effectiveRtrSolicitInt > defaultAsyncEventTimeout {
- waitForNothing(test.effectiveRtrSolicitInt - defaultAsyncEventTimeout)
- waitForPkt(2 * defaultAsyncEventTimeout)
+ if test.effectiveRtrSolicitInt > defaultAsyncPositiveEventTimeout {
+ waitForNothing(test.effectiveRtrSolicitInt - defaultAsyncNegativeEventTimeout)
+ waitForPkt(defaultAsyncPositiveEventTimeout)
} else {
- waitForPkt(test.effectiveRtrSolicitInt * defaultAsyncEventTimeout)
+ waitForPkt(test.effectiveRtrSolicitInt + defaultAsyncPositiveEventTimeout)
}
}
// Make sure no more RS.
if test.effectiveRtrSolicitInt > test.effectiveMaxRtrSolicitDelay {
- waitForNothing(test.effectiveRtrSolicitInt + defaultAsyncEventTimeout)
+ waitForNothing(test.effectiveRtrSolicitInt + defaultAsyncNegativeEventTimeout)
} else {
- waitForNothing(test.effectiveMaxRtrSolicitDelay + defaultAsyncEventTimeout)
+ waitForNothing(test.effectiveMaxRtrSolicitDelay + defaultAsyncNegativeEventTimeout)
}
// Make sure the counter got properly
@@ -5305,11 +5314,11 @@ func TestStopStartSolicitingRouters(t *testing.T) {
// Stop soliciting routers.
test.stopFn(t, s, true /* first */)
- ctx, cancel := context.WithTimeout(context.Background(), delay+defaultAsyncEventTimeout)
+ ctx, cancel := context.WithTimeout(context.Background(), delay+defaultAsyncNegativeEventTimeout)
defer cancel()
if _, ok := e.ReadContext(ctx); ok {
// A single RS may have been sent before solicitations were stopped.
- ctx, cancel := context.WithTimeout(context.Background(), interval+defaultAsyncEventTimeout)
+ ctx, cancel := context.WithTimeout(context.Background(), interval+defaultAsyncNegativeEventTimeout)
defer cancel()
if _, ok = e.ReadContext(ctx); ok {
t.Fatal("should not have sent more than one RS message")
@@ -5319,7 +5328,7 @@ func TestStopStartSolicitingRouters(t *testing.T) {
// Stopping router solicitations after it has already been stopped should
// do nothing.
test.stopFn(t, s, false /* first */)
- ctx, cancel = context.WithTimeout(context.Background(), delay+defaultAsyncEventTimeout)
+ ctx, cancel = context.WithTimeout(context.Background(), delay+defaultAsyncNegativeEventTimeout)
defer cancel()
if _, ok := e.ReadContext(ctx); ok {
t.Fatal("unexpectedly got a packet after router solicitation has been stopepd")
@@ -5332,10 +5341,10 @@ func TestStopStartSolicitingRouters(t *testing.T) {
// Start soliciting routers.
test.startFn(t, s)
- waitForPkt(delay + defaultAsyncEventTimeout)
- waitForPkt(interval + defaultAsyncEventTimeout)
- waitForPkt(interval + defaultAsyncEventTimeout)
- ctx, cancel = context.WithTimeout(context.Background(), interval+defaultAsyncEventTimeout)
+ waitForPkt(delay + defaultAsyncPositiveEventTimeout)
+ waitForPkt(interval + defaultAsyncPositiveEventTimeout)
+ waitForPkt(interval + defaultAsyncPositiveEventTimeout)
+ ctx, cancel = context.WithTimeout(context.Background(), interval+defaultAsyncNegativeEventTimeout)
defer cancel()
if _, ok := e.ReadContext(ctx); ok {
t.Fatal("unexpectedly got an extra packet after sending out the expected RSs")
@@ -5344,7 +5353,7 @@ func TestStopStartSolicitingRouters(t *testing.T) {
// Starting router solicitations after it has already completed should do
// nothing.
test.startFn(t, s)
- ctx, cancel = context.WithTimeout(context.Background(), delay+defaultAsyncEventTimeout)
+ ctx, cancel = context.WithTimeout(context.Background(), delay+defaultAsyncNegativeEventTimeout)
defer cancel()
if _, ok := e.ReadContext(ctx); ok {
t.Fatal("unexpectedly got a packet after finishing router solicitations")
diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go
index a2190341c..51abe32a7 100644
--- a/pkg/tcpip/stack/stack.go
+++ b/pkg/tcpip/stack/stack.go
@@ -1033,14 +1033,14 @@ func (s *Stack) removeNICLocked(id tcpip.NICID) *tcpip.Error {
// Remove routes in-place. n tracks the number of routes written.
n := 0
for i, r := range s.routeTable {
+ s.routeTable[i] = tcpip.Route{}
if r.NIC != id {
// Keep this route.
- if i > n {
- s.routeTable[n] = r
- }
+ s.routeTable[n] = r
n++
}
}
+
s.routeTable = s.routeTable[:n]
return nic.remove()
diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go
index ffef9bc2c..5aacbf53e 100644
--- a/pkg/tcpip/stack/stack_test.go
+++ b/pkg/tcpip/stack/stack_test.go
@@ -3305,7 +3305,7 @@ func TestDoDADWhenNICEnabled(t *testing.T) {
// Wait for DAD to resolve.
select {
- case <-time.After(dadTransmits*retransmitTimer + defaultAsyncEventTimeout):
+ case <-time.After(dadTransmits*retransmitTimer + defaultAsyncPositiveEventTimeout):
t.Fatal("timed out waiting for DAD resolution")
case e := <-ndpDisp.dadC:
if diff := checkDADEvent(e, nicID, addr.AddressWithPrefix.Address, true, nil); diff != "" {
diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go
index b7b227328..956232a44 100644
--- a/pkg/tcpip/tcpip.go
+++ b/pkg/tcpip/tcpip.go
@@ -813,6 +813,32 @@ type OutOfBandInlineOption int
// a default TTL.
type DefaultTTLOption uint8
+// StackSACKEnabled is used by stack.(*Stack).TransportProtocolOption to
+// enable/disable SACK support in TCP. See: https://tools.ietf.org/html/rfc2018.
+type StackSACKEnabled bool
+
+// StackDelayEnabled is used by stack.(Stack*).TransportProtocolOption to
+// enable/disable Nagle's algorithm in TCP.
+type StackDelayEnabled bool
+
+// StackSendBufferSizeOption is used by stack.(Stack*).TransportProtocolOption
+// to get/set the default, min and max send buffer sizes.
+type StackSendBufferSizeOption struct {
+ Min int
+ Default int
+ Max int
+}
+
+// StackReceiveBufferSizeOption is used by
+// stack.(Stack*).TransportProtocolOption to get/set the default, min and max
+// receive buffer sizes.
+type StackReceiveBufferSizeOption struct {
+ Min int
+ Default int
+ Max int
+}
+
+//
// IPPacketInfo is the message struture for IP_PKTINFO.
//
// +stateify savable
@@ -1198,6 +1224,9 @@ type UDPStats struct {
// PacketSendErrors is the number of datagrams failed to be sent.
PacketSendErrors *StatCounter
+
+ // ChecksumErrors is the number of datagrams dropped due to bad checksums.
+ ChecksumErrors *StatCounter
}
// Stats holds statistics about the networking stack.
@@ -1241,6 +1270,9 @@ type ReceiveErrors struct {
// ClosedReceiver is the number of received packets dropped because
// of receiving endpoint state being closed.
ClosedReceiver StatCounter
+
+ // ChecksumErrors is the number of packets dropped due to bad checksums.
+ ChecksumErrors StatCounter
}
// SendErrors collects packet send errors within the transport layer for
diff --git a/pkg/tcpip/transport/raw/endpoint.go b/pkg/tcpip/transport/raw/endpoint.go
index a406d815e..6a7977259 100644
--- a/pkg/tcpip/transport/raw/endpoint.go
+++ b/pkg/tcpip/transport/raw/endpoint.go
@@ -26,6 +26,8 @@
package raw
import (
+ "fmt"
+
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
@@ -66,16 +68,17 @@ type endpoint struct {
// protected by rcvMu.
rcvMu sync.Mutex `state:"nosave"`
rcvList rawPacketList
- rcvBufSizeMax int `state:".(int)"`
rcvBufSize int
+ rcvBufSizeMax int `state:".(int)"`
rcvClosed bool
// The following fields are protected by mu.
- mu sync.RWMutex `state:"nosave"`
- sndBufSize int
- closed bool
- connected bool
- bound bool
+ mu sync.RWMutex `state:"nosave"`
+ sndBufSize int
+ sndBufSizeMax int
+ closed bool
+ connected bool
+ bound bool
// route is the route to a remote network endpoint. It is set via
// Connect(), and is valid only when conneted is true.
route stack.Route `state:"manual"`
@@ -103,10 +106,21 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProt
},
waiterQueue: waiterQueue,
rcvBufSizeMax: 32 * 1024,
- sndBufSize: 32 * 1024,
+ sndBufSizeMax: 32 * 1024,
associated: associated,
}
+ // Override with stack defaults.
+ var ss tcpip.StackSendBufferSizeOption
+ if err := s.TransportProtocolOption(transProto, &ss); err == nil {
+ e.sndBufSizeMax = ss.Default
+ }
+
+ var rs tcpip.StackReceiveBufferSizeOption
+ if err := s.TransportProtocolOption(transProto, &rs); err == nil {
+ e.rcvBufSizeMax = rs.Default
+ }
+
// Unassociated endpoints are write-only and users call Write() with IP
// headers included. Because they're write-only, We don't need to
// register with the stack.
@@ -523,7 +537,46 @@ func (e *endpoint) SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error {
// SetSockOptInt implements tcpip.Endpoint.SetSockOptInt.
func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error {
- return tcpip.ErrUnknownProtocolOption
+ switch opt {
+ case tcpip.SendBufferSizeOption:
+ // Make sure the send buffer size is within the min and max
+ // allowed.
+ var ss tcpip.StackSendBufferSizeOption
+ if err := e.stack.TransportProtocolOption(e.TransProto, &ss); err != nil {
+ panic(fmt.Sprintf("s.TransportProtocolOption(%d, %+v) = %s", e.TransProto, ss, err))
+ }
+ if v > ss.Max {
+ v = ss.Max
+ }
+ if v < ss.Min {
+ v = ss.Min
+ }
+ e.mu.Lock()
+ e.sndBufSizeMax = v
+ e.mu.Unlock()
+ return nil
+
+ case tcpip.ReceiveBufferSizeOption:
+ // Make sure the receive buffer size is within the min and max
+ // allowed.
+ var rs tcpip.StackReceiveBufferSizeOption
+ if err := e.stack.TransportProtocolOption(e.TransProto, &rs); err != nil {
+ panic(fmt.Sprintf("s.TransportProtocolOption(%d, %+v) = %s", e.TransProto, rs, err))
+ }
+ if v > rs.Max {
+ v = rs.Max
+ }
+ if v < rs.Min {
+ v = rs.Min
+ }
+ e.rcvMu.Lock()
+ e.rcvBufSizeMax = v
+ e.rcvMu.Unlock()
+ return nil
+
+ default:
+ return tcpip.ErrUnknownProtocolOption
+ }
}
// GetSockOpt implements tcpip.Endpoint.GetSockOpt.
@@ -563,7 +616,7 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) {
case tcpip.SendBufferSizeOption:
e.mu.Lock()
- v := e.sndBufSize
+ v := e.sndBufSizeMax
e.mu.Unlock()
return v, nil
@@ -636,7 +689,6 @@ func (e *endpoint) HandlePacket(route *stack.Route, pkt *stack.PacketBuffer) {
e.rcvList.PushBack(packet)
e.rcvBufSize += packet.data.Size()
-
e.rcvMu.Unlock()
e.stats.PacketsReceived.Increment()
// Notify waiters that there's data to be read.
diff --git a/pkg/tcpip/transport/tcp/BUILD b/pkg/tcpip/transport/tcp/BUILD
index e26f01fae..6baeda8e4 100644
--- a/pkg/tcpip/transport/tcp/BUILD
+++ b/pkg/tcpip/transport/tcp/BUILD
@@ -76,7 +76,7 @@ go_library(
)
go_test(
- name = "tcp_test",
+ name = "tcp_x_test",
size = "medium",
srcs = [
"dual_stack_test.go",
@@ -115,3 +115,11 @@ go_test(
"//pkg/tcpip/seqnum",
],
)
+
+go_test(
+ name = "tcp_test",
+ size = "small",
+ srcs = ["timer_test.go"],
+ library = ":tcp",
+ deps = ["//pkg/sleep"],
+)
diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go
index 91ee3b0be..377643b82 100644
--- a/pkg/tcpip/transport/tcp/connect.go
+++ b/pkg/tcpip/transport/tcp/connect.go
@@ -521,7 +521,7 @@ func (h *handshake) execute() *tcpip.Error {
s.AddWaker(&h.ep.newSegmentWaker, wakerForNewSegment)
defer s.Done()
- var sackEnabled SACKEnabled
+ var sackEnabled tcpip.StackSACKEnabled
if err := h.ep.stack.TransportProtocolOption(ProtocolNumber, &sackEnabled); err != nil {
// If stack returned an error when checking for SACKEnabled
// status then just default to switching off SACK negotiation.
@@ -1516,6 +1516,7 @@ func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{
// Main loop. Handle segments until both send and receive ends of the
// connection have completed.
cleanupOnError := func(err *tcpip.Error) {
+ e.stack.Stats().TCP.CurrentConnected.Decrement()
e.workerCleanup = true
if err != nil {
e.resetConnectionLocked(err)
@@ -1568,11 +1569,14 @@ loop:
reuseTW = e.doTimeWait()
}
- // Mark endpoint as closed.
- if e.EndpointState() != StateError {
- e.transitionToStateCloseLocked()
+ // Handle any StateError transition from StateTimeWait.
+ if e.EndpointState() == StateError {
+ cleanupOnError(nil)
+ return nil
}
+ e.transitionToStateCloseLocked()
+
// Lock released below.
epilogue()
diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go
index f225b00e7..10df2bcd5 100644
--- a/pkg/tcpip/transport/tcp/endpoint.go
+++ b/pkg/tcpip/transport/tcp/endpoint.go
@@ -851,12 +851,12 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQue
maxSynRetries: DefaultSynRetries,
}
- var ss SendBufferSizeOption
+ var ss tcpip.StackSendBufferSizeOption
if err := s.TransportProtocolOption(ProtocolNumber, &ss); err == nil {
e.sndBufSize = ss.Default
}
- var rs ReceiveBufferSizeOption
+ var rs tcpip.StackReceiveBufferSizeOption
if err := s.TransportProtocolOption(ProtocolNumber, &rs); err == nil {
e.rcvBufSize = rs.Default
}
@@ -871,7 +871,7 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQue
e.rcvAutoParams.disabled = !bool(mrb)
}
- var de DelayEnabled
+ var de tcpip.StackDelayEnabled
if err := s.TransportProtocolOption(ProtocolNumber, &de); err == nil && de {
e.SetSockOptBool(tcpip.DelayOption, true)
}
@@ -1588,7 +1588,7 @@ func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error {
case tcpip.ReceiveBufferSizeOption:
// Make sure the receive buffer size is within the min and max
// allowed.
- var rs ReceiveBufferSizeOption
+ var rs tcpip.StackReceiveBufferSizeOption
if err := e.stack.TransportProtocolOption(ProtocolNumber, &rs); err == nil {
if v < rs.Min {
v = rs.Min
@@ -1638,7 +1638,7 @@ func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error {
case tcpip.SendBufferSizeOption:
// Make sure the send buffer size is within the min and max
// allowed.
- var ss SendBufferSizeOption
+ var ss tcpip.StackSendBufferSizeOption
if err := e.stack.TransportProtocolOption(ProtocolNumber, &ss); err == nil {
if v < ss.Min {
v = ss.Min
@@ -1678,7 +1678,7 @@ func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error {
return tcpip.ErrInvalidOptionValue
}
}
- var rs ReceiveBufferSizeOption
+ var rs tcpip.StackReceiveBufferSizeOption
if err := e.stack.TransportProtocolOption(ProtocolNumber, &rs); err == nil {
if v < rs.Min/2 {
v = rs.Min / 2
@@ -2609,7 +2609,7 @@ func (e *endpoint) receiveBufferSize() int {
}
func (e *endpoint) maxReceiveBufferSize() int {
- var rs ReceiveBufferSizeOption
+ var rs tcpip.StackReceiveBufferSizeOption
if err := e.stack.TransportProtocolOption(ProtocolNumber, &rs); err != nil {
// As a fallback return the hardcoded max buffer size.
return MaxBufferSize
@@ -2690,7 +2690,7 @@ func timeStampOffset() uint32 {
// if the SYN options indicate that the SACK option was negotiated and the TCP
// stack is configured to enable TCP SACK option.
func (e *endpoint) maybeEnableSACKPermitted(synOpts *header.TCPSynOptions) {
- var v SACKEnabled
+ var v tcpip.StackSACKEnabled
if err := e.stack.TransportProtocolOption(ProtocolNumber, &v); err != nil {
// Stack doesn't support SACK. So just return.
return
diff --git a/pkg/tcpip/transport/tcp/endpoint_state.go b/pkg/tcpip/transport/tcp/endpoint_state.go
index cbb779666..0bebec2d1 100644
--- a/pkg/tcpip/transport/tcp/endpoint_state.go
+++ b/pkg/tcpip/transport/tcp/endpoint_state.go
@@ -186,7 +186,7 @@ func (e *endpoint) Resume(s *stack.Stack) {
epState := e.origEndpointState
switch epState {
case StateInitial, StateBound, StateListen, StateConnecting, StateEstablished:
- var ss SendBufferSizeOption
+ var ss tcpip.StackSendBufferSizeOption
if err := e.stack.TransportProtocolOption(ProtocolNumber, &ss); err == nil {
if e.sndBufSize < ss.Min || e.sndBufSize > ss.Max {
panic(fmt.Sprintf("endpoint.sndBufSize %d is outside the min and max allowed [%d, %d]", e.sndBufSize, ss.Min, ss.Max))
diff --git a/pkg/tcpip/transport/tcp/protocol.go b/pkg/tcpip/transport/tcp/protocol.go
index 73b8a6782..3cff55afa 100644
--- a/pkg/tcpip/transport/tcp/protocol.go
+++ b/pkg/tcpip/transport/tcp/protocol.go
@@ -71,29 +71,6 @@ const (
DefaultSynRetries = 6
)
-// SACKEnabled option can be used to enable SACK support in the TCP
-// protocol. See: https://tools.ietf.org/html/rfc2018.
-type SACKEnabled bool
-
-// DelayEnabled option can be used to enable Nagle's algorithm in the TCP protocol.
-type DelayEnabled bool
-
-// SendBufferSizeOption allows the default, min and max send buffer sizes for
-// TCP endpoints to be queried or configured.
-type SendBufferSizeOption struct {
- Min int
- Default int
- Max int
-}
-
-// ReceiveBufferSizeOption allows the default, min and max receive buffer size
-// for TCP endpoints to be queried or configured.
-type ReceiveBufferSizeOption struct {
- Min int
- Default int
- Max int
-}
-
const (
ccReno = "reno"
ccCubic = "cubic"
@@ -160,8 +137,8 @@ type protocol struct {
mu sync.RWMutex
sackEnabled bool
delayEnabled bool
- sendBufferSize SendBufferSizeOption
- recvBufferSize ReceiveBufferSizeOption
+ sendBufferSize tcpip.StackSendBufferSizeOption
+ recvBufferSize tcpip.StackReceiveBufferSizeOption
congestionControl string
availableCongestionControl []string
moderateReceiveBuffer bool
@@ -272,19 +249,19 @@ func replyWithReset(s *segment, tos, ttl uint8) {
// SetOption implements stack.TransportProtocol.SetOption.
func (p *protocol) SetOption(option interface{}) *tcpip.Error {
switch v := option.(type) {
- case SACKEnabled:
+ case tcpip.StackSACKEnabled:
p.mu.Lock()
p.sackEnabled = bool(v)
p.mu.Unlock()
return nil
- case DelayEnabled:
+ case tcpip.StackDelayEnabled:
p.mu.Lock()
p.delayEnabled = bool(v)
p.mu.Unlock()
return nil
- case SendBufferSizeOption:
+ case tcpip.StackSendBufferSizeOption:
if v.Min <= 0 || v.Default < v.Min || v.Default > v.Max {
return tcpip.ErrInvalidOptionValue
}
@@ -293,7 +270,7 @@ func (p *protocol) SetOption(option interface{}) *tcpip.Error {
p.mu.Unlock()
return nil
- case ReceiveBufferSizeOption:
+ case tcpip.StackReceiveBufferSizeOption:
if v.Min <= 0 || v.Default < v.Min || v.Default > v.Max {
return tcpip.ErrInvalidOptionValue
}
@@ -386,25 +363,25 @@ func (p *protocol) SetOption(option interface{}) *tcpip.Error {
// Option implements stack.TransportProtocol.Option.
func (p *protocol) Option(option interface{}) *tcpip.Error {
switch v := option.(type) {
- case *SACKEnabled:
+ case *tcpip.StackSACKEnabled:
p.mu.RLock()
- *v = SACKEnabled(p.sackEnabled)
+ *v = tcpip.StackSACKEnabled(p.sackEnabled)
p.mu.RUnlock()
return nil
- case *DelayEnabled:
+ case *tcpip.StackDelayEnabled:
p.mu.RLock()
- *v = DelayEnabled(p.delayEnabled)
+ *v = tcpip.StackDelayEnabled(p.delayEnabled)
p.mu.RUnlock()
return nil
- case *SendBufferSizeOption:
+ case *tcpip.StackSendBufferSizeOption:
p.mu.RLock()
*v = p.sendBufferSize
p.mu.RUnlock()
return nil
- case *ReceiveBufferSizeOption:
+ case *tcpip.StackReceiveBufferSizeOption:
p.mu.RLock()
*v = p.recvBufferSize
p.mu.RUnlock()
@@ -514,8 +491,16 @@ func (*protocol) Parse(pkt *stack.PacketBuffer) bool {
// NewProtocol returns a TCP transport protocol.
func NewProtocol() stack.TransportProtocol {
return &protocol{
- sendBufferSize: SendBufferSizeOption{MinBufferSize, DefaultSendBufferSize, MaxBufferSize},
- recvBufferSize: ReceiveBufferSizeOption{MinBufferSize, DefaultReceiveBufferSize, MaxBufferSize},
+ sendBufferSize: tcpip.StackSendBufferSizeOption{
+ Min: MinBufferSize,
+ Default: DefaultSendBufferSize,
+ Max: MaxBufferSize,
+ },
+ recvBufferSize: tcpip.StackReceiveBufferSizeOption{
+ Min: MinBufferSize,
+ Default: DefaultReceiveBufferSize,
+ Max: MaxBufferSize,
+ },
congestionControl: ccReno,
availableCongestionControl: []string{ccReno, ccCubic},
tcpLingerTimeout: DefaultTCPLingerTimeout,
diff --git a/pkg/tcpip/transport/tcp/tcp_sack_test.go b/pkg/tcpip/transport/tcp/tcp_sack_test.go
index fcc165f17..812e503bc 100644
--- a/pkg/tcpip/transport/tcp/tcp_sack_test.go
+++ b/pkg/tcpip/transport/tcp/tcp_sack_test.go
@@ -46,8 +46,8 @@ func createConnectedWithSACKAndTS(c *context.Context) *context.RawEndpoint {
func setStackSACKPermitted(t *testing.T, c *context.Context, enable bool) {
t.Helper()
- if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcp.SACKEnabled(enable)); err != nil {
- t.Fatalf("c.s.SetTransportProtocolOption(tcp.ProtocolNumber, SACKEnabled(%v) = %v", enable, err)
+ if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.StackSACKEnabled(enable)); err != nil {
+ t.Fatalf("c.s.SetTransportProtocolOption(tcp.ProtocolNumber, StackSACKEnabled(%t) = %s", enable, err)
}
}
diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go
index c6ffa7a9d..aca6a7951 100644
--- a/pkg/tcpip/transport/tcp/tcp_test.go
+++ b/pkg/tcpip/transport/tcp/tcp_test.go
@@ -2967,6 +2967,9 @@ loop:
if got := c.Stack().Stats().TCP.CurrentEstablished.Value(); got != 0 {
t.Errorf("got stats.TCP.CurrentEstablished.Value() = %d, want = 0", got)
}
+ if got := c.Stack().Stats().TCP.CurrentConnected.Value(); got != 0 {
+ t.Errorf("got stats.TCP.CurrentConnected.Value() = %d, want = 0", got)
+ }
}
func TestSendOnResetConnection(t *testing.T) {
@@ -3050,6 +3053,9 @@ func TestMaxRetransmitsTimeout(t *testing.T) {
if got := c.Stack().Stats().TCP.EstablishedTimedout.Value(); got != 1 {
t.Errorf("got c.Stack().Stats().TCP.EstablishedTimedout.Value() = %d, want = 1", got)
}
+ if got := c.Stack().Stats().TCP.CurrentConnected.Value(); got != 0 {
+ t.Errorf("got stats.TCP.CurrentConnected.Value() = %d, want = 0", got)
+ }
}
// TestMaxRTO tests if the retransmit interval caps to MaxRTO.
@@ -3981,7 +3987,10 @@ func TestDefaultBufferSizes(t *testing.T) {
checkRecvBufferSize(t, ep, tcp.DefaultReceiveBufferSize)
// Change the default send buffer size.
- if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.SendBufferSizeOption{1, tcp.DefaultSendBufferSize * 2, tcp.DefaultSendBufferSize * 20}); err != nil {
+ if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.StackSendBufferSizeOption{
+ Min: 1,
+ Default: tcp.DefaultSendBufferSize * 2,
+ Max: tcp.DefaultSendBufferSize * 20}); err != nil {
t.Fatalf("SetTransportProtocolOption failed: %s", err)
}
@@ -3995,8 +4004,11 @@ func TestDefaultBufferSizes(t *testing.T) {
checkRecvBufferSize(t, ep, tcp.DefaultReceiveBufferSize)
// Change the default receive buffer size.
- if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.ReceiveBufferSizeOption{1, tcp.DefaultReceiveBufferSize * 3, tcp.DefaultReceiveBufferSize * 30}); err != nil {
- t.Fatalf("SetTransportProtocolOption failed: %s", err)
+ if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.StackReceiveBufferSizeOption{
+ Min: 1,
+ Default: tcp.DefaultReceiveBufferSize * 3,
+ Max: tcp.DefaultReceiveBufferSize * 30}); err != nil {
+ t.Fatalf("SetTransportProtocolOption failed: %v", err)
}
ep.Close()
@@ -4023,11 +4035,11 @@ func TestMinMaxBufferSizes(t *testing.T) {
defer ep.Close()
// Change the min/max values for send/receive
- if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.ReceiveBufferSizeOption{200, tcp.DefaultReceiveBufferSize * 2, tcp.DefaultReceiveBufferSize * 20}); err != nil {
+ if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.StackReceiveBufferSizeOption{Min: 200, Default: tcp.DefaultReceiveBufferSize * 2, Max: tcp.DefaultReceiveBufferSize * 20}); err != nil {
t.Fatalf("SetTransportProtocolOption failed: %s", err)
}
- if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.SendBufferSizeOption{300, tcp.DefaultSendBufferSize * 3, tcp.DefaultSendBufferSize * 30}); err != nil {
+ if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.StackSendBufferSizeOption{Min: 300, Default: tcp.DefaultSendBufferSize * 3, Max: tcp.DefaultSendBufferSize * 30}); err != nil {
t.Fatalf("SetTransportProtocolOption failed: %s", err)
}
@@ -4754,6 +4766,9 @@ func TestKeepalive(t *testing.T) {
if got := c.Stack().Stats().TCP.CurrentEstablished.Value(); got != 0 {
t.Errorf("got stats.TCP.CurrentEstablished.Value() = %d, want = 0", got)
}
+ if got := c.Stack().Stats().TCP.CurrentConnected.Value(); got != 0 {
+ t.Errorf("got stats.TCP.CurrentConnected.Value() = %d, want = 0", got)
+ }
}
func executeHandshake(t *testing.T, c *context.Context, srcPort uint16, synCookieInUse bool) (irs, iss seqnum.Value) {
@@ -5663,7 +5678,7 @@ func TestReceiveBufferAutoTuningApplicationLimited(t *testing.T) {
// the segment queue holding unprocessed packets is limited to 500.
const receiveBufferSize = 80 << 10 // 80KB.
const maxReceiveBufferSize = receiveBufferSize * 10
- if err := stk.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.ReceiveBufferSizeOption{1, receiveBufferSize, maxReceiveBufferSize}); err != nil {
+ if err := stk.SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.StackReceiveBufferSizeOption{Min: 1, Default: receiveBufferSize, Max: maxReceiveBufferSize}); err != nil {
t.Fatalf("SetTransportProtocolOption failed: %s", err)
}
@@ -5784,7 +5799,7 @@ func TestReceiveBufferAutoTuning(t *testing.T) {
// the segment queue holding unprocessed packets is limited to 300.
const receiveBufferSize = 80 << 10 // 80KB.
const maxReceiveBufferSize = receiveBufferSize * 10
- if err := stk.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.ReceiveBufferSizeOption{1, receiveBufferSize, maxReceiveBufferSize}); err != nil {
+ if err := stk.SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.StackReceiveBufferSizeOption{Min: 1, Default: receiveBufferSize, Max: maxReceiveBufferSize}); err != nil {
t.Fatalf("SetTransportProtocolOption failed: %s", err)
}
@@ -5926,7 +5941,7 @@ func TestDelayEnabled(t *testing.T) {
checkDelayOption(t, c, false, false) // Delay is disabled by default.
for _, v := range []struct {
- delayEnabled tcp.DelayEnabled
+ delayEnabled tcpip.StackDelayEnabled
wantDelayOption bool
}{
{delayEnabled: false, wantDelayOption: false},
@@ -5941,10 +5956,10 @@ func TestDelayEnabled(t *testing.T) {
}
}
-func checkDelayOption(t *testing.T, c *context.Context, wantDelayEnabled tcp.DelayEnabled, wantDelayOption bool) {
+func checkDelayOption(t *testing.T, c *context.Context, wantDelayEnabled tcpip.StackDelayEnabled, wantDelayOption bool) {
t.Helper()
- var gotDelayEnabled tcp.DelayEnabled
+ var gotDelayEnabled tcpip.StackDelayEnabled
if err := c.Stack().TransportProtocolOption(tcp.ProtocolNumber, &gotDelayEnabled); err != nil {
t.Fatalf("TransportProtocolOption(tcp, &gotDelayEnabled) failed: %s", err)
}
@@ -6771,6 +6786,9 @@ func TestTCPUserTimeout(t *testing.T) {
if got, want := c.Stack().Stats().TCP.EstablishedTimedout.Value(), origEstablishedTimedout+1; got != want {
t.Errorf("got c.Stack().Stats().TCP.EstablishedTimedout = %d, want = %d", got, want)
}
+ if got := c.Stack().Stats().TCP.CurrentConnected.Value(); got != 0 {
+ t.Errorf("got stats.TCP.CurrentConnected.Value() = %d, want = 0", got)
+ }
}
func TestKeepaliveWithUserTimeout(t *testing.T) {
@@ -6842,6 +6860,9 @@ func TestKeepaliveWithUserTimeout(t *testing.T) {
if got, want := c.Stack().Stats().TCP.EstablishedTimedout.Value(), origEstablishedTimedout+1; got != want {
t.Errorf("got c.Stack().Stats().TCP.EstablishedTimedout = %d, want = %d", got, want)
}
+ if got := c.Stack().Stats().TCP.CurrentConnected.Value(); got != 0 {
+ t.Errorf("got stats.TCP.CurrentConnected.Value() = %d, want = 0", got)
+ }
}
func TestIncreaseWindowOnReceive(t *testing.T) {
diff --git a/pkg/tcpip/transport/tcp/testing/context/context.go b/pkg/tcpip/transport/tcp/testing/context/context.go
index 9721f6caf..9e262c272 100644
--- a/pkg/tcpip/transport/tcp/testing/context/context.go
+++ b/pkg/tcpip/transport/tcp/testing/context/context.go
@@ -144,12 +144,12 @@ func New(t *testing.T, mtu uint32) *Context {
})
// Allow minimum send/receive buffer sizes to be 1 during tests.
- if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.SendBufferSizeOption{1, tcp.DefaultSendBufferSize, 10 * tcp.DefaultSendBufferSize}); err != nil {
- t.Fatalf("SetTransportProtocolOption failed: %v", err)
+ if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.StackSendBufferSizeOption{Min: 1, Default: tcp.DefaultSendBufferSize, Max: 10 * tcp.DefaultSendBufferSize}); err != nil {
+ t.Fatalf("SetTransportProtocolOption failed: %s", err)
}
- if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.ReceiveBufferSizeOption{1, tcp.DefaultReceiveBufferSize, 10 * tcp.DefaultReceiveBufferSize}); err != nil {
- t.Fatalf("SetTransportProtocolOption failed: %v", err)
+ if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.StackReceiveBufferSizeOption{Min: 1, Default: tcp.DefaultReceiveBufferSize, Max: 10 * tcp.DefaultReceiveBufferSize}); err != nil {
+ t.Fatalf("SetTransportProtocolOption failed: %s", err)
}
// Increase minimum RTO in tests to avoid test flakes due to early
@@ -1091,7 +1091,7 @@ func (c *Context) PassiveConnectWithOptions(maxPayload, wndScale int, synOptions
// SACKEnabled returns true if the TCP Protocol option SACKEnabled is set to true
// for the Stack in the context.
func (c *Context) SACKEnabled() bool {
- var v tcp.SACKEnabled
+ var v tcpip.StackSACKEnabled
if err := c.Stack().TransportProtocolOption(tcp.ProtocolNumber, &v); err != nil {
// Stack doesn't support SACK. So just return.
return false
diff --git a/pkg/tcpip/transport/tcp/timer.go b/pkg/tcpip/transport/tcp/timer.go
index c70525f27..7981d469b 100644
--- a/pkg/tcpip/transport/tcp/timer.go
+++ b/pkg/tcpip/transport/tcp/timer.go
@@ -85,6 +85,7 @@ func (t *timer) init(w *sleep.Waker) {
// cleanup frees all resources associated with the timer.
func (t *timer) cleanup() {
t.timer.Stop()
+ *t = timer{}
}
// checkExpiration checks if the given timer has actually expired, it should be
diff --git a/pkg/tcpip/transport/tcp/timer_test.go b/pkg/tcpip/transport/tcp/timer_test.go
new file mode 100644
index 000000000..dbd6dff54
--- /dev/null
+++ b/pkg/tcpip/transport/tcp/timer_test.go
@@ -0,0 +1,47 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tcp
+
+import (
+ "testing"
+ "time"
+
+ "gvisor.dev/gvisor/pkg/sleep"
+)
+
+func TestCleanup(t *testing.T) {
+ const (
+ timerDurationSeconds = 2
+ isAssertedTimeoutSeconds = timerDurationSeconds + 1
+ )
+
+ tmr := timer{}
+ w := sleep.Waker{}
+ tmr.init(&w)
+ tmr.enable(timerDurationSeconds * time.Second)
+ tmr.cleanup()
+
+ if want := (timer{}); tmr != want {
+ t.Errorf("got tmr = %+v, want = %+v", tmr, want)
+ }
+
+ // The waker should not be asserted.
+ for i := 0; i < isAssertedTimeoutSeconds; i++ {
+ time.Sleep(time.Second)
+ if w.IsAsserted() {
+ t.Fatalf("waker asserted unexpectedly")
+ }
+ }
+}
diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go
index df5efbf6a..40d66ef09 100644
--- a/pkg/tcpip/transport/udp/endpoint.go
+++ b/pkg/tcpip/transport/udp/endpoint.go
@@ -15,6 +15,8 @@
package udp
import (
+ "fmt"
+
"gvisor.dev/gvisor/pkg/sleep"
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
@@ -94,6 +96,7 @@ type endpoint struct {
// The following fields are protected by the mu mutex.
mu sync.RWMutex `state:"nosave"`
sndBufSize int
+ sndBufSizeMax int
state EndpointState
route stack.Route `state:"manual"`
dstPort uint16
@@ -159,7 +162,7 @@ type multicastMembership struct {
}
func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) *endpoint {
- return &endpoint{
+ e := &endpoint{
stack: s,
TransportEndpointInfo: stack.TransportEndpointInfo{
NetProto: netProto,
@@ -181,10 +184,23 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQue
multicastTTL: 1,
multicastLoop: true,
rcvBufSizeMax: 32 * 1024,
- sndBufSize: 32 * 1024,
+ sndBufSizeMax: 32 * 1024,
state: StateInitial,
uniqueID: s.UniqueID(),
}
+
+ // Override with stack defaults.
+ var ss tcpip.StackSendBufferSizeOption
+ if err := s.TransportProtocolOption(ProtocolNumber, &ss); err == nil {
+ e.sndBufSizeMax = ss.Default
+ }
+
+ var rs tcpip.StackReceiveBufferSizeOption
+ if err := s.TransportProtocolOption(ProtocolNumber, &rs); err == nil {
+ e.rcvBufSizeMax = rs.Default
+ }
+
+ return e
}
// UniqueID implements stack.TransportEndpoint.UniqueID.
@@ -611,8 +627,43 @@ func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error {
e.mu.Unlock()
case tcpip.ReceiveBufferSizeOption:
+ // Make sure the receive buffer size is within the min and max
+ // allowed.
+ var rs tcpip.StackReceiveBufferSizeOption
+ if err := e.stack.TransportProtocolOption(ProtocolNumber, &rs); err != nil {
+ panic(fmt.Sprintf("e.stack.TransportProtocolOption(%d, %+v) = %s", ProtocolNumber, rs, err))
+ }
+
+ if v < rs.Min {
+ v = rs.Min
+ }
+ if v > rs.Max {
+ v = rs.Max
+ }
+
+ e.mu.Lock()
+ e.rcvBufSizeMax = v
+ e.mu.Unlock()
+ return nil
case tcpip.SendBufferSizeOption:
+ // Make sure the send buffer size is within the min and max
+ // allowed.
+ var ss tcpip.StackSendBufferSizeOption
+ if err := e.stack.TransportProtocolOption(ProtocolNumber, &ss); err != nil {
+ panic(fmt.Sprintf("e.stack.TransportProtocolOption(%d, %+v) = %s", ProtocolNumber, ss, err))
+ }
+
+ if v < ss.Min {
+ v = ss.Min
+ }
+ if v > ss.Max {
+ v = ss.Max
+ }
+ e.mu.Lock()
+ e.sndBufSizeMax = v
+ e.mu.Unlock()
+ return nil
}
return nil
@@ -861,7 +912,7 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) {
case tcpip.SendBufferSizeOption:
e.mu.Lock()
- v := e.sndBufSize
+ v := e.sndBufSizeMax
e.mu.Unlock()
return v, nil
@@ -1299,6 +1350,24 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pk
return
}
+ // Verify checksum unless RX checksum offload is enabled.
+ // On IPv4, UDP checksum is optional, and a zero value means
+ // the transmitter omitted the checksum generation (RFC768).
+ // On IPv6, UDP checksum is not optional (RFC2460 Section 8.1).
+ if r.Capabilities()&stack.CapabilityRXChecksumOffload == 0 &&
+ (hdr.Checksum() != 0 || r.NetProto == header.IPv6ProtocolNumber) {
+ xsum := r.PseudoHeaderChecksum(ProtocolNumber, hdr.Length())
+ for _, v := range pkt.Data.Views() {
+ xsum = header.Checksum(v, xsum)
+ }
+ if hdr.CalculateChecksum(xsum) != 0xffff {
+ // Checksum Error.
+ e.stack.Stats().UDP.ChecksumErrors.Increment()
+ e.stats.ReceiveErrors.ChecksumErrors.Increment()
+ return
+ }
+ }
+
e.rcvMu.Lock()
e.stack.Stats().UDP.PacketsReceived.Increment()
e.stats.PacketsReceived.Increment()
diff --git a/pkg/tcpip/transport/udp/protocol.go b/pkg/tcpip/transport/udp/protocol.go
index 4218e7d03..fc93f93c0 100644
--- a/pkg/tcpip/transport/udp/protocol.go
+++ b/pkg/tcpip/transport/udp/protocol.go
@@ -21,6 +21,7 @@
package udp
import (
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
@@ -32,9 +33,27 @@ import (
const (
// ProtocolNumber is the udp protocol number.
ProtocolNumber = header.UDPProtocolNumber
+
+ // MinBufferSize is the smallest size of a receive or send buffer.
+ MinBufferSize = 4 << 10 // 4KiB bytes.
+
+ // DefaultSendBufferSize is the default size of the send buffer for
+ // an endpoint.
+ DefaultSendBufferSize = 32 << 10 // 32KiB
+
+ // DefaultReceiveBufferSize is the default size of the receive buffer
+ // for an endpoint.
+ DefaultReceiveBufferSize = 32 << 10 // 32KiB
+
+ // MaxBufferSize is the largest size a receive/send buffer can grow to.
+ MaxBufferSize = 4 << 20 // 4MiB
)
-type protocol struct{}
+type protocol struct {
+ mu sync.RWMutex
+ sendBufferSize tcpip.StackSendBufferSizeOption
+ recvBufferSize tcpip.StackReceiveBufferSizeOption
+}
// Number returns the udp protocol number.
func (*protocol) Number() tcpip.TransportProtocolNumber {
@@ -183,13 +202,49 @@ func (p *protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.Trans
}
// SetOption implements stack.TransportProtocol.SetOption.
-func (*protocol) SetOption(option interface{}) *tcpip.Error {
- return tcpip.ErrUnknownProtocolOption
+func (p *protocol) SetOption(option interface{}) *tcpip.Error {
+ switch v := option.(type) {
+ case tcpip.StackSendBufferSizeOption:
+ if v.Min <= 0 || v.Default < v.Min || v.Default > v.Max {
+ return tcpip.ErrInvalidOptionValue
+ }
+ p.mu.Lock()
+ p.sendBufferSize = v
+ p.mu.Unlock()
+ return nil
+
+ case tcpip.StackReceiveBufferSizeOption:
+ if v.Min <= 0 || v.Default < v.Min || v.Default > v.Max {
+ return tcpip.ErrInvalidOptionValue
+ }
+ p.mu.Lock()
+ p.recvBufferSize = v
+ p.mu.Unlock()
+ return nil
+
+ default:
+ return tcpip.ErrUnknownProtocolOption
+ }
}
// Option implements stack.TransportProtocol.Option.
-func (*protocol) Option(option interface{}) *tcpip.Error {
- return tcpip.ErrUnknownProtocolOption
+func (p *protocol) Option(option interface{}) *tcpip.Error {
+ switch v := option.(type) {
+ case *tcpip.StackSendBufferSizeOption:
+ p.mu.RLock()
+ *v = p.sendBufferSize
+ p.mu.RUnlock()
+ return nil
+
+ case *tcpip.StackReceiveBufferSizeOption:
+ p.mu.RLock()
+ *v = p.recvBufferSize
+ p.mu.RUnlock()
+ return nil
+
+ default:
+ return tcpip.ErrUnknownProtocolOption
+ }
}
// Close implements stack.TransportProtocol.Close.
@@ -212,5 +267,8 @@ func (*protocol) Parse(pkt *stack.PacketBuffer) bool {
// NewProtocol returns a UDP transport protocol.
func NewProtocol() stack.TransportProtocol {
- return &protocol{}
+ return &protocol{
+ sendBufferSize: tcpip.StackSendBufferSizeOption{Min: MinBufferSize, Default: DefaultSendBufferSize, Max: MaxBufferSize},
+ recvBufferSize: tcpip.StackReceiveBufferSizeOption{Min: MinBufferSize, Default: DefaultReceiveBufferSize, Max: MaxBufferSize},
+ }
}
diff --git a/pkg/tcpip/transport/udp/udp_test.go b/pkg/tcpip/transport/udp/udp_test.go
index 313a3f117..ff9f60cf9 100644
--- a/pkg/tcpip/transport/udp/udp_test.go
+++ b/pkg/tcpip/transport/udp/udp_test.go
@@ -292,15 +292,15 @@ func newDualTestContextWithOptions(t *testing.T, mtu uint32, options stack.Optio
wep = sniffer.New(ep)
}
if err := s.CreateNIC(1, wep); err != nil {
- t.Fatalf("CreateNIC failed: %v", err)
+ t.Fatalf("CreateNIC failed: %s", err)
}
if err := s.AddAddress(1, ipv4.ProtocolNumber, stackAddr); err != nil {
- t.Fatalf("AddAddress failed: %v", err)
+ t.Fatalf("AddAddress failed: %s", err)
}
if err := s.AddAddress(1, ipv6.ProtocolNumber, stackV6Addr); err != nil {
- t.Fatalf("AddAddress failed: %v", err)
+ t.Fatalf("AddAddress failed: %s", err)
}
s.SetRouteTable([]tcpip.Route{
@@ -391,17 +391,21 @@ func (c *testContext) injectPacket(flow testFlow, payload []byte) {
h := flow.header4Tuple(incoming)
if flow.isV4() {
- c.injectV4Packet(payload, &h, true /* valid */)
+ buf := c.buildV4Packet(payload, &h)
+ c.linkEP.InjectInbound(ipv4.ProtocolNumber, &stack.PacketBuffer{
+ Data: buf.ToVectorisedView(),
+ })
} else {
- c.injectV6Packet(payload, &h, true /* valid */)
+ buf := c.buildV6Packet(payload, &h)
+ c.linkEP.InjectInbound(ipv6.ProtocolNumber, &stack.PacketBuffer{
+ Data: buf.ToVectorisedView(),
+ })
}
}
-// injectV6Packet creates a V6 test packet with the given payload and header
-// values, and injects it into the link endpoint. valid indicates if the
-// caller intends to inject a packet with a valid or an invalid UDP header.
-// We can invalidate the header by corrupting the UDP payload length.
-func (c *testContext) injectV6Packet(payload []byte, h *header4Tuple, valid bool) {
+// buildV6Packet creates a V6 test packet with the given payload and header
+// values in a buffer.
+func (c *testContext) buildV6Packet(payload []byte, h *header4Tuple) buffer.View {
// Allocate a buffer for data and headers.
buf := buffer.NewView(header.UDPMinimumSize + header.IPv6MinimumSize + len(payload))
payloadStart := len(buf) - len(payload)
@@ -420,16 +424,10 @@ func (c *testContext) injectV6Packet(payload []byte, h *header4Tuple, valid bool
// Initialize the UDP header.
u := header.UDP(buf[header.IPv6MinimumSize:])
- l := uint16(header.UDPMinimumSize + len(payload))
- if !valid {
- // Change the UDP payload length to corrupt the header
- // as requested by the caller.
- l++
- }
u.Encode(&header.UDPFields{
SrcPort: h.srcAddr.Port,
DstPort: h.dstAddr.Port,
- Length: l,
+ Length: uint16(header.UDPMinimumSize + len(payload)),
})
// Calculate the UDP pseudo-header checksum.
@@ -439,17 +437,12 @@ func (c *testContext) injectV6Packet(payload []byte, h *header4Tuple, valid bool
xsum = header.Checksum(payload, xsum)
u.SetChecksum(^u.CalculateChecksum(xsum))
- // Inject packet.
- c.linkEP.InjectInbound(ipv6.ProtocolNumber, &stack.PacketBuffer{
- Data: buf.ToVectorisedView(),
- })
+ return buf
}
-// injectV4Packet creates a V4 test packet with the given payload and header
-// values, and injects it into the link endpoint. valid indicates if the
-// caller intends to inject a packet with a valid or an invalid UDP header.
-// We can invalidate the header by corrupting the UDP payload length.
-func (c *testContext) injectV4Packet(payload []byte, h *header4Tuple, valid bool) {
+// buildV4Packet creates a V4 test packet with the given payload and header
+// values in a buffer.
+func (c *testContext) buildV4Packet(payload []byte, h *header4Tuple) buffer.View {
// Allocate a buffer for data and headers.
buf := buffer.NewView(header.UDPMinimumSize + header.IPv4MinimumSize + len(payload))
payloadStart := len(buf) - len(payload)
@@ -483,11 +476,7 @@ func (c *testContext) injectV4Packet(payload []byte, h *header4Tuple, valid bool
xsum = header.Checksum(payload, xsum)
u.SetChecksum(^u.CalculateChecksum(xsum))
- // Inject packet.
-
- c.linkEP.InjectInbound(ipv4.ProtocolNumber, &stack.PacketBuffer{
- Data: buf.ToVectorisedView(),
- })
+ return buf
}
func newPayload() []byte {
@@ -509,7 +498,7 @@ func TestBindToDeviceOption(t *testing.T) {
ep, err := s.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{})
if err != nil {
- t.Fatalf("NewEndpoint failed; %v", err)
+ t.Fatalf("NewEndpoint failed; %s", err)
}
defer ep.Close()
@@ -643,7 +632,7 @@ func TestBindEphemeralPort(t *testing.T) {
c.createEndpoint(ipv6.ProtocolNumber)
if err := c.ep.Bind(tcpip.FullAddress{}); err != nil {
- t.Fatalf("ep.Bind(...) failed: %v", err)
+ t.Fatalf("ep.Bind(...) failed: %s", err)
}
}
@@ -654,19 +643,19 @@ func TestBindReservedPort(t *testing.T) {
c.createEndpoint(ipv6.ProtocolNumber)
if err := c.ep.Connect(tcpip.FullAddress{Addr: testV6Addr, Port: testPort}); err != nil {
- c.t.Fatalf("Connect failed: %v", err)
+ c.t.Fatalf("Connect failed: %s", err)
}
addr, err := c.ep.GetLocalAddress()
if err != nil {
- t.Fatalf("GetLocalAddress failed: %v", err)
+ t.Fatalf("GetLocalAddress failed: %s", err)
}
// We can't bind the address reserved by the connected endpoint above.
{
ep, err := c.s.NewEndpoint(udp.ProtocolNumber, ipv6.ProtocolNumber, &c.wq)
if err != nil {
- t.Fatalf("NewEndpoint failed: %v", err)
+ t.Fatalf("NewEndpoint failed: %s", err)
}
defer ep.Close()
if got, want := ep.Bind(addr), tcpip.ErrPortInUse; got != want {
@@ -677,7 +666,7 @@ func TestBindReservedPort(t *testing.T) {
func() {
ep, err := c.s.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, &c.wq)
if err != nil {
- t.Fatalf("NewEndpoint failed: %v", err)
+ t.Fatalf("NewEndpoint failed: %s", err)
}
defer ep.Close()
// We can't bind ipv4-any on the port reserved by the connected endpoint
@@ -687,7 +676,7 @@ func TestBindReservedPort(t *testing.T) {
}
// We can bind an ipv4 address on this port, though.
if err := ep.Bind(tcpip.FullAddress{Addr: stackAddr, Port: addr.Port}); err != nil {
- t.Fatalf("ep.Bind(...) failed: %v", err)
+ t.Fatalf("ep.Bind(...) failed: %s", err)
}
}()
@@ -697,11 +686,11 @@ func TestBindReservedPort(t *testing.T) {
func() {
ep, err := c.s.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, &c.wq)
if err != nil {
- t.Fatalf("NewEndpoint failed: %v", err)
+ t.Fatalf("NewEndpoint failed: %s", err)
}
defer ep.Close()
if err := ep.Bind(tcpip.FullAddress{Port: addr.Port}); err != nil {
- t.Fatalf("ep.Bind(...) failed: %v", err)
+ t.Fatalf("ep.Bind(...) failed: %s", err)
}
}()
}
@@ -714,7 +703,7 @@ func TestV4ReadOnV6(t *testing.T) {
// Bind to wildcard.
if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil {
- c.t.Fatalf("Bind failed: %v", err)
+ c.t.Fatalf("Bind failed: %s", err)
}
// Test acceptance.
@@ -729,7 +718,7 @@ func TestV4ReadOnBoundToV4MappedWildcard(t *testing.T) {
// Bind to v4 mapped wildcard.
if err := c.ep.Bind(tcpip.FullAddress{Addr: v4MappedWildcardAddr, Port: stackPort}); err != nil {
- c.t.Fatalf("Bind failed: %v", err)
+ c.t.Fatalf("Bind failed: %s", err)
}
// Test acceptance.
@@ -744,7 +733,7 @@ func TestV4ReadOnBoundToV4Mapped(t *testing.T) {
// Bind to local address.
if err := c.ep.Bind(tcpip.FullAddress{Addr: stackV4MappedAddr, Port: stackPort}); err != nil {
- c.t.Fatalf("Bind failed: %v", err)
+ c.t.Fatalf("Bind failed: %s", err)
}
// Test acceptance.
@@ -759,7 +748,7 @@ func TestV6ReadOnV6(t *testing.T) {
// Bind to wildcard.
if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil {
- c.t.Fatalf("Bind failed: %v", err)
+ c.t.Fatalf("Bind failed: %s", err)
}
// Test acceptance.
@@ -796,7 +785,10 @@ func TestV4ReadSelfSource(t *testing.T) {
h := unicastV4.header4Tuple(incoming)
h.srcAddr = h.dstAddr
- c.injectV4Packet(payload, &h, true /* valid */)
+ buf := c.buildV4Packet(payload, &h)
+ c.linkEP.InjectInbound(ipv4.ProtocolNumber, &stack.PacketBuffer{
+ Data: buf.ToVectorisedView(),
+ })
if got := c.s.Stats().IP.InvalidSourceAddressesReceived.Value(); got != tt.wantInvalidSource {
t.Errorf("c.s.Stats().IP.InvalidSourceAddressesReceived got %d, want %d", got, tt.wantInvalidSource)
@@ -817,7 +809,7 @@ func TestV4ReadOnV4(t *testing.T) {
// Bind to wildcard.
if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil {
- c.t.Fatalf("Bind failed: %v", err)
+ c.t.Fatalf("Bind failed: %s", err)
}
// Test acceptance.
@@ -955,7 +947,7 @@ func testWriteInternal(c *testContext, flow testFlow, setDest bool, checkers ...
payload := buffer.View(newPayload())
n, _, err := c.ep.Write(tcpip.SlicePayload(payload), writeOpts)
if err != nil {
- c.t.Fatalf("Write failed: %v", err)
+ c.t.Fatalf("Write failed: %s", err)
}
if n != int64(len(payload)) {
c.t.Fatalf("Bad number of bytes written: got %v, want %v", n, len(payload))
@@ -1005,7 +997,7 @@ func TestDualWriteBoundToWildcard(t *testing.T) {
// Bind to wildcard.
if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil {
- c.t.Fatalf("Bind failed: %v", err)
+ c.t.Fatalf("Bind failed: %s", err)
}
p := testDualWrite(c)
@@ -1022,7 +1014,7 @@ func TestDualWriteConnectedToV6(t *testing.T) {
// Connect to v6 address.
if err := c.ep.Connect(tcpip.FullAddress{Addr: testV6Addr, Port: testPort}); err != nil {
- c.t.Fatalf("Bind failed: %v", err)
+ c.t.Fatalf("Bind failed: %s", err)
}
testWrite(c, unicastV6)
@@ -1043,7 +1035,7 @@ func TestDualWriteConnectedToV4Mapped(t *testing.T) {
// Connect to v4 mapped address.
if err := c.ep.Connect(tcpip.FullAddress{Addr: testV4MappedAddr, Port: testPort}); err != nil {
- c.t.Fatalf("Bind failed: %v", err)
+ c.t.Fatalf("Bind failed: %s", err)
}
testWrite(c, unicastV4in6)
@@ -1070,7 +1062,7 @@ func TestV6WriteOnBoundToV4Mapped(t *testing.T) {
// Bind to v4 mapped address.
if err := c.ep.Bind(tcpip.FullAddress{Addr: stackV4MappedAddr, Port: stackPort}); err != nil {
- c.t.Fatalf("Bind failed: %v", err)
+ c.t.Fatalf("Bind failed: %s", err)
}
// Write to v6 address.
@@ -1085,7 +1077,7 @@ func TestV6WriteOnConnected(t *testing.T) {
// Connect to v6 address.
if err := c.ep.Connect(tcpip.FullAddress{Addr: testV6Addr, Port: testPort}); err != nil {
- c.t.Fatalf("Connect failed: %v", err)
+ c.t.Fatalf("Connect failed: %s", err)
}
testWriteWithoutDestination(c, unicastV6)
@@ -1099,7 +1091,7 @@ func TestV4WriteOnConnected(t *testing.T) {
// Connect to v4 mapped address.
if err := c.ep.Connect(tcpip.FullAddress{Addr: testV4MappedAddr, Port: testPort}); err != nil {
- c.t.Fatalf("Connect failed: %v", err)
+ c.t.Fatalf("Connect failed: %s", err)
}
testWriteWithoutDestination(c, unicastV4)
@@ -1234,7 +1226,7 @@ func TestReadIncrementsPacketsReceived(t *testing.T) {
// Bind to wildcard.
if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil {
- c.t.Fatalf("Bind failed: %v", err)
+ c.t.Fatalf("Bind failed: %s", err)
}
testRead(c, unicastV4)
@@ -1506,12 +1498,12 @@ func TestMulticastInterfaceOption(t *testing.T) {
Port: stackPort,
}
if err := c.ep.Connect(addr); err != nil {
- c.t.Fatalf("Connect failed: %v", err)
+ c.t.Fatalf("Connect failed: %s", err)
}
}
if err := c.ep.SetSockOpt(ifoptSet); err != nil {
- c.t.Fatalf("SetSockOpt failed: %v", err)
+ c.t.Fatalf("SetSockOpt failed: %s", err)
}
// Verify multicast interface addr and NIC were set correctly.
@@ -1519,7 +1511,7 @@ func TestMulticastInterfaceOption(t *testing.T) {
ifoptWant := tcpip.MulticastInterfaceOption{NIC: 1, InterfaceAddr: ifoptSet.InterfaceAddr}
var ifoptGot tcpip.MulticastInterfaceOption
if err := c.ep.GetSockOpt(&ifoptGot); err != nil {
- c.t.Fatalf("GetSockOpt failed: %v", err)
+ c.t.Fatalf("GetSockOpt failed: %s", err)
}
if ifoptGot != ifoptWant {
c.t.Errorf("got GetSockOpt() = %#v, want = %#v", ifoptGot, ifoptWant)
@@ -1691,7 +1683,7 @@ func TestV6UnknownDestination(t *testing.T) {
}
// TestIncrementMalformedPacketsReceived verifies if the malformed received
-// global and endpoint stats get incremented.
+// global and endpoint stats are incremented.
func TestIncrementMalformedPacketsReceived(t *testing.T) {
c := newDualTestContext(t, defaultMTU)
defer c.cleanup()
@@ -1699,20 +1691,25 @@ func TestIncrementMalformedPacketsReceived(t *testing.T) {
c.createEndpoint(ipv6.ProtocolNumber)
// Bind to wildcard.
if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil {
- c.t.Fatalf("Bind failed: %v", err)
+ c.t.Fatalf("Bind failed: %s", err)
}
payload := newPayload()
- c.t.Helper()
h := unicastV6.header4Tuple(incoming)
- c.injectV6Packet(payload, &h, false /* !valid */)
+ buf := c.buildV6Packet(payload, &h)
+ // Invalidate the packet length field in the UDP header by adding one.
+ u := header.UDP(buf[header.IPv6MinimumSize:])
+ u.SetLength(u.Length() + 1)
+ c.linkEP.InjectInbound(ipv6.ProtocolNumber, &stack.PacketBuffer{
+ Data: buf.ToVectorisedView(),
+ })
- var want uint64 = 1
+ const want = 1
if got := c.s.Stats().UDP.MalformedPacketsReceived.Value(); got != want {
- t.Errorf("got stats.UDP.MalformedPacketsReceived.Value() = %v, want = %v", got, want)
+ t.Errorf("got stats.UDP.MalformedPacketsReceived.Value() = %d, want = %d", got, want)
}
if got := c.ep.Stats().(*tcpip.TransportEndpointStats).ReceiveErrors.MalformedPacketsReceived.Value(); got != want {
- t.Errorf("got EP Stats.ReceiveErrors.MalformedPacketsReceived stats = %v, want = %v", got, want)
+ t.Errorf("got EP Stats.ReceiveErrors.MalformedPacketsReceived stats = %d, want = %d", got, want)
}
}
@@ -1728,7 +1725,6 @@ func TestShortHeader(t *testing.T) {
c.t.Fatalf("Bind failed: %s", err)
}
- c.t.Helper()
h := unicastV6.header4Tuple(incoming)
// Allocate a buffer for an IPv6 and too-short UDP header.
@@ -1768,6 +1764,190 @@ func TestShortHeader(t *testing.T) {
}
}
+// TestIncrementChecksumErrorsV4 verifies if a checksum error is detected,
+// global and endpoint stats are incremented.
+func TestIncrementChecksumErrorsV4(t *testing.T) {
+ c := newDualTestContext(t, defaultMTU)
+ defer c.cleanup()
+
+ c.createEndpoint(ipv4.ProtocolNumber)
+ // Bind to wildcard.
+ if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil {
+ c.t.Fatalf("Bind failed: %s", err)
+ }
+
+ payload := newPayload()
+ h := unicastV4.header4Tuple(incoming)
+ buf := c.buildV4Packet(payload, &h)
+ // Invalidate the checksum field in the UDP header by adding one.
+ u := header.UDP(buf[header.IPv4MinimumSize:])
+ u.SetChecksum(u.Checksum() + 1)
+ c.linkEP.InjectInbound(ipv4.ProtocolNumber, &stack.PacketBuffer{
+ Data: buf.ToVectorisedView(),
+ })
+
+ const want = 1
+ if got := c.s.Stats().UDP.ChecksumErrors.Value(); got != want {
+ t.Errorf("got stats.UDP.ChecksumErrors.Value() = %d, want = %d", got, want)
+ }
+ if got := c.ep.Stats().(*tcpip.TransportEndpointStats).ReceiveErrors.ChecksumErrors.Value(); got != want {
+ t.Errorf("got EP Stats.ReceiveErrors.ChecksumErrors stats = %d, want = %d", got, want)
+ }
+}
+
+// TestIncrementChecksumErrorsV6 verifies if a checksum error is detected,
+// global and endpoint stats are incremented.
+func TestIncrementChecksumErrorsV6(t *testing.T) {
+ c := newDualTestContext(t, defaultMTU)
+ defer c.cleanup()
+
+ c.createEndpoint(ipv6.ProtocolNumber)
+ // Bind to wildcard.
+ if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil {
+ c.t.Fatalf("Bind failed: %s", err)
+ }
+
+ payload := newPayload()
+ h := unicastV6.header4Tuple(incoming)
+ buf := c.buildV6Packet(payload, &h)
+ // Invalidate the checksum field in the UDP header by adding one.
+ u := header.UDP(buf[header.IPv6MinimumSize:])
+ u.SetChecksum(u.Checksum() + 1)
+ c.linkEP.InjectInbound(ipv6.ProtocolNumber, &stack.PacketBuffer{
+ Data: buf.ToVectorisedView(),
+ })
+
+ const want = 1
+ if got := c.s.Stats().UDP.ChecksumErrors.Value(); got != want {
+ t.Errorf("got stats.UDP.ChecksumErrors.Value() = %d, want = %d", got, want)
+ }
+ if got := c.ep.Stats().(*tcpip.TransportEndpointStats).ReceiveErrors.ChecksumErrors.Value(); got != want {
+ t.Errorf("got EP Stats.ReceiveErrors.ChecksumErrors stats = %d, want = %d", got, want)
+ }
+}
+
+// TestPayloadModifiedV4 verifies if a checksum error is detected,
+// global and endpoint stats are incremented.
+func TestPayloadModifiedV4(t *testing.T) {
+ c := newDualTestContext(t, defaultMTU)
+ defer c.cleanup()
+
+ c.createEndpoint(ipv4.ProtocolNumber)
+ // Bind to wildcard.
+ if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil {
+ c.t.Fatalf("Bind failed: %s", err)
+ }
+
+ payload := newPayload()
+ h := unicastV4.header4Tuple(incoming)
+ buf := c.buildV4Packet(payload, &h)
+ // Modify the payload so that the checksum value in the UDP header will be incorrect.
+ buf[len(buf)-1]++
+ c.linkEP.InjectInbound(ipv4.ProtocolNumber, &stack.PacketBuffer{
+ Data: buf.ToVectorisedView(),
+ })
+
+ const want = 1
+ if got := c.s.Stats().UDP.ChecksumErrors.Value(); got != want {
+ t.Errorf("got stats.UDP.ChecksumErrors.Value() = %d, want = %d", got, want)
+ }
+ if got := c.ep.Stats().(*tcpip.TransportEndpointStats).ReceiveErrors.ChecksumErrors.Value(); got != want {
+ t.Errorf("got EP Stats.ReceiveErrors.ChecksumErrors stats = %d, want = %d", got, want)
+ }
+}
+
+// TestPayloadModifiedV6 verifies if a checksum error is detected,
+// global and endpoint stats are incremented.
+func TestPayloadModifiedV6(t *testing.T) {
+ c := newDualTestContext(t, defaultMTU)
+ defer c.cleanup()
+
+ c.createEndpoint(ipv6.ProtocolNumber)
+ // Bind to wildcard.
+ if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil {
+ c.t.Fatalf("Bind failed: %s", err)
+ }
+
+ payload := newPayload()
+ h := unicastV6.header4Tuple(incoming)
+ buf := c.buildV6Packet(payload, &h)
+ // Modify the payload so that the checksum value in the UDP header will be incorrect.
+ buf[len(buf)-1]++
+ c.linkEP.InjectInbound(ipv6.ProtocolNumber, &stack.PacketBuffer{
+ Data: buf.ToVectorisedView(),
+ })
+
+ const want = 1
+ if got := c.s.Stats().UDP.ChecksumErrors.Value(); got != want {
+ t.Errorf("got stats.UDP.ChecksumErrors.Value() = %d, want = %d", got, want)
+ }
+ if got := c.ep.Stats().(*tcpip.TransportEndpointStats).ReceiveErrors.ChecksumErrors.Value(); got != want {
+ t.Errorf("got EP Stats.ReceiveErrors.ChecksumErrors stats = %d, want = %d", got, want)
+ }
+}
+
+// TestChecksumZeroV4 verifies if the checksum value is zero, global and
+// endpoint states are *not* incremented (UDP checksum is optional on IPv4).
+func TestChecksumZeroV4(t *testing.T) {
+ c := newDualTestContext(t, defaultMTU)
+ defer c.cleanup()
+
+ c.createEndpoint(ipv4.ProtocolNumber)
+ // Bind to wildcard.
+ if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil {
+ c.t.Fatalf("Bind failed: %s", err)
+ }
+
+ payload := newPayload()
+ h := unicastV4.header4Tuple(incoming)
+ buf := c.buildV4Packet(payload, &h)
+ // Set the checksum field in the UDP header to zero.
+ u := header.UDP(buf[header.IPv4MinimumSize:])
+ u.SetChecksum(0)
+ c.linkEP.InjectInbound(ipv4.ProtocolNumber, &stack.PacketBuffer{
+ Data: buf.ToVectorisedView(),
+ })
+
+ const want = 0
+ if got := c.s.Stats().UDP.ChecksumErrors.Value(); got != want {
+ t.Errorf("got stats.UDP.ChecksumErrors.Value() = %d, want = %d", got, want)
+ }
+ if got := c.ep.Stats().(*tcpip.TransportEndpointStats).ReceiveErrors.ChecksumErrors.Value(); got != want {
+ t.Errorf("got EP Stats.ReceiveErrors.ChecksumErrors stats = %d, want = %d", got, want)
+ }
+}
+
+// TestChecksumZeroV6 verifies if the checksum value is zero, global and
+// endpoint states are incremented (UDP checksum is *not* optional on IPv6).
+func TestChecksumZeroV6(t *testing.T) {
+ c := newDualTestContext(t, defaultMTU)
+ defer c.cleanup()
+
+ c.createEndpoint(ipv6.ProtocolNumber)
+ // Bind to wildcard.
+ if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil {
+ c.t.Fatalf("Bind failed: %s", err)
+ }
+
+ payload := newPayload()
+ h := unicastV6.header4Tuple(incoming)
+ buf := c.buildV6Packet(payload, &h)
+ // Set the checksum field in the UDP header to zero.
+ u := header.UDP(buf[header.IPv6MinimumSize:])
+ u.SetChecksum(0)
+ c.linkEP.InjectInbound(ipv6.ProtocolNumber, &stack.PacketBuffer{
+ Data: buf.ToVectorisedView(),
+ })
+
+ const want = 1
+ if got := c.s.Stats().UDP.ChecksumErrors.Value(); got != want {
+ t.Errorf("got stats.UDP.ChecksumErrors.Value() = %d, want = %d", got, want)
+ }
+ if got := c.ep.Stats().(*tcpip.TransportEndpointStats).ReceiveErrors.ChecksumErrors.Value(); got != want {
+ t.Errorf("got EP Stats.ReceiveErrors.ChecksumErrors stats = %d, want = %d", got, want)
+ }
+}
+
// TestShutdownRead verifies endpoint read shutdown and error
// stats increment on packet receive.
func TestShutdownRead(t *testing.T) {
@@ -1778,15 +1958,15 @@ func TestShutdownRead(t *testing.T) {
// Bind to wildcard.
if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil {
- c.t.Fatalf("Bind failed: %v", err)
+ c.t.Fatalf("Bind failed: %s", err)
}
if err := c.ep.Connect(tcpip.FullAddress{Addr: testV6Addr, Port: testPort}); err != nil {
- c.t.Fatalf("Connect failed: %v", err)
+ c.t.Fatalf("Connect failed: %s", err)
}
if err := c.ep.Shutdown(tcpip.ShutdownRead); err != nil {
- t.Fatalf("Shutdown failed: %v", err)
+ t.Fatalf("Shutdown failed: %s", err)
}
testFailingRead(c, unicastV6, true /* expectReadError */)
@@ -1809,11 +1989,11 @@ func TestShutdownWrite(t *testing.T) {
c.createEndpoint(ipv6.ProtocolNumber)
if err := c.ep.Connect(tcpip.FullAddress{Addr: testV6Addr, Port: testPort}); err != nil {
- c.t.Fatalf("Connect failed: %v", err)
+ c.t.Fatalf("Connect failed: %s", err)
}
if err := c.ep.Shutdown(tcpip.ShutdownWrite); err != nil {
- t.Fatalf("Shutdown failed: %v", err)
+ t.Fatalf("Shutdown failed: %s", err)
}
testFailingWrite(c, unicastV6, tcpip.ErrClosedForSend)