summaryrefslogtreecommitdiffhomepage
path: root/pkg/tcpip/stack
diff options
context:
space:
mode:
Diffstat (limited to 'pkg/tcpip/stack')
-rw-r--r--pkg/tcpip/stack/BUILD4
-rw-r--r--pkg/tcpip/stack/conntrack.go236
-rw-r--r--pkg/tcpip/stack/forwarding_test.go23
-rw-r--r--pkg/tcpip/stack/hook_string.go41
-rw-r--r--pkg/tcpip/stack/iptables.go31
-rw-r--r--pkg/tcpip/stack/iptables_targets.go92
-rw-r--r--pkg/tcpip/stack/iptables_types.go2
-rw-r--r--pkg/tcpip/stack/ndp_test.go1156
-rw-r--r--pkg/tcpip/stack/neighbor_cache_test.go2
-rw-r--r--pkg/tcpip/stack/neighbor_entry_test.go8
-rw-r--r--pkg/tcpip/stack/nic.go34
-rw-r--r--pkg/tcpip/stack/nic_test.go4
-rw-r--r--pkg/tcpip/stack/packet_buffer.go20
-rw-r--r--pkg/tcpip/stack/packet_buffer_test.go140
-rw-r--r--pkg/tcpip/stack/pending_packets.go8
-rw-r--r--pkg/tcpip/stack/registration.go38
-rw-r--r--pkg/tcpip/stack/route.go20
-rw-r--r--pkg/tcpip/stack/stack.go313
-rw-r--r--pkg/tcpip/stack/stack_global_state.go72
-rw-r--r--pkg/tcpip/stack/stack_options.go4
-rw-r--r--pkg/tcpip/stack/stack_test.go89
-rw-r--r--pkg/tcpip/stack/tcp.go451
-rw-r--r--pkg/tcpip/stack/transport_demuxer.go17
-rw-r--r--pkg/tcpip/stack/transport_test.go6
24 files changed, 1705 insertions, 1106 deletions
diff --git a/pkg/tcpip/stack/BUILD b/pkg/tcpip/stack/BUILD
index 49362333a..2bd6a67f5 100644
--- a/pkg/tcpip/stack/BUILD
+++ b/pkg/tcpip/stack/BUILD
@@ -45,6 +45,7 @@ go_library(
"addressable_endpoint_state.go",
"conntrack.go",
"headertype_string.go",
+ "hook_string.go",
"icmp_rate_limit.go",
"iptables.go",
"iptables_state.go",
@@ -66,6 +67,7 @@ go_library(
"stack.go",
"stack_global_state.go",
"stack_options.go",
+ "tcp.go",
"transport_demuxer.go",
"tuple_list.go",
],
@@ -115,6 +117,7 @@ go_test(
"//pkg/tcpip/network/ipv4",
"//pkg/tcpip/network/ipv6",
"//pkg/tcpip/ports",
+ "//pkg/tcpip/testutil",
"//pkg/tcpip/transport/icmp",
"//pkg/tcpip/transport/udp",
"//pkg/waiter",
@@ -139,6 +142,7 @@ go_test(
"//pkg/tcpip/buffer",
"//pkg/tcpip/faketime",
"//pkg/tcpip/header",
+ "//pkg/tcpip/testutil",
"@com_github_google_go_cmp//cmp:go_default_library",
"@com_github_google_go_cmp//cmp/cmpopts:go_default_library",
],
diff --git a/pkg/tcpip/stack/conntrack.go b/pkg/tcpip/stack/conntrack.go
index 3f083928f..5720e7543 100644
--- a/pkg/tcpip/stack/conntrack.go
+++ b/pkg/tcpip/stack/conntrack.go
@@ -16,6 +16,7 @@ package stack
import (
"encoding/binary"
+ "fmt"
"sync"
"time"
@@ -29,7 +30,7 @@ import (
// The connection is created for a packet if it does not exist. Every
// connection contains two tuples (original and reply). The tuples are
// manipulated if there is a matching NAT rule. The packet is modified by
-// looking at the tuples in the Prerouting and Output hooks.
+// looking at the tuples in each hook.
//
// Currently, only TCP tracking is supported.
@@ -46,12 +47,14 @@ const (
)
// Manipulation type for the connection.
+// TODO(gvisor.dev/issue/5696): Define this as a bit set and support SNAT and
+// DNAT at the same time.
type manipType int
const (
manipNone manipType = iota
- manipDstPrerouting
- manipDstOutput
+ manipSource
+ manipDestination
)
// tuple holds a connection's identifying and manipulating data in one
@@ -108,6 +111,7 @@ type conn struct {
reply tuple
// manip indicates if the packet should be manipulated. It is immutable.
+ // TODO(gvisor.dev/issue/5696): Support updating manipulation type.
manip manipType
// tcbHook indicates if the packet is inbound or outbound to
@@ -124,6 +128,18 @@ type conn struct {
lastUsed time.Time `state:".(unixTime)"`
}
+// newConn creates new connection.
+func newConn(orig, reply tupleID, manip manipType, hook Hook) *conn {
+ conn := conn{
+ manip: manip,
+ tcbHook: hook,
+ lastUsed: time.Now(),
+ }
+ conn.original = tuple{conn: &conn, tupleID: orig}
+ conn.reply = tuple{conn: &conn, tupleID: reply, direction: dirReply}
+ return &conn
+}
+
// timedOut returns whether the connection timed out based on its state.
func (cn *conn) timedOut(now time.Time) bool {
const establishedTimeout = 5 * 24 * time.Hour
@@ -219,18 +235,6 @@ func packetToTupleID(pkt *PacketBuffer) (tupleID, tcpip.Error) {
}, nil
}
-// newConn creates new connection.
-func newConn(orig, reply tupleID, manip manipType, hook Hook) *conn {
- conn := conn{
- manip: manip,
- tcbHook: hook,
- lastUsed: time.Now(),
- }
- conn.original = tuple{conn: &conn, tupleID: orig}
- conn.reply = tuple{conn: &conn, tupleID: reply, direction: dirReply}
- return &conn
-}
-
func (ct *ConnTrack) init() {
ct.mu.Lock()
defer ct.mu.Unlock()
@@ -284,20 +288,41 @@ func (ct *ConnTrack) insertRedirectConn(pkt *PacketBuffer, hook Hook, port uint1
return nil
}
- // Create a new connection and change the port as per the iptables
- // rule. This tuple will be used to manipulate the packet in
- // handlePacket.
replyTID := tid.reply()
replyTID.srcAddr = address
replyTID.srcPort = port
- var manip manipType
- switch hook {
- case Prerouting:
- manip = manipDstPrerouting
- case Output:
- manip = manipDstOutput
+
+ conn, _ := ct.connForTID(tid)
+ if conn != nil {
+ // The connection is already tracked.
+ // TODO(gvisor.dev/issue/5696): Support updating an existing connection.
+ return nil
}
- conn := newConn(tid, replyTID, manip, hook)
+ conn = newConn(tid, replyTID, manipDestination, hook)
+ ct.insertConn(conn)
+ return conn
+}
+
+func (ct *ConnTrack) insertSNATConn(pkt *PacketBuffer, hook Hook, port uint16, address tcpip.Address) *conn {
+ tid, err := packetToTupleID(pkt)
+ if err != nil {
+ return nil
+ }
+ if hook != Input && hook != Postrouting {
+ return nil
+ }
+
+ replyTID := tid.reply()
+ replyTID.dstAddr = address
+ replyTID.dstPort = port
+
+ conn, _ := ct.connForTID(tid)
+ if conn != nil {
+ // The connection is already tracked.
+ // TODO(gvisor.dev/issue/5696): Support updating an existing connection.
+ return nil
+ }
+ conn = newConn(tid, replyTID, manipSource, hook)
ct.insertConn(conn)
return conn
}
@@ -322,6 +347,7 @@ func (ct *ConnTrack) insertConn(conn *conn) {
// Now that we hold the locks, ensure the tuple hasn't been inserted by
// another thread.
+ // TODO(gvisor.dev/issue/5773): Should check conn.reply.tupleID, too?
alreadyInserted := false
for other := ct.buckets[tupleBucket].tuples.Front(); other != nil; other = other.Next() {
if other.tupleID == conn.original.tupleID {
@@ -343,95 +369,17 @@ func (ct *ConnTrack) insertConn(conn *conn) {
}
}
-// handlePacketPrerouting manipulates ports for packets in Prerouting hook.
-// TODO(gvisor.dev/issue/170): Change address for Prerouting hook.
-func handlePacketPrerouting(pkt *PacketBuffer, conn *conn, dir direction) {
- // If this is a noop entry, don't do anything.
- if conn.manip == manipNone {
- return
- }
-
- netHeader := pkt.Network()
- tcpHeader := header.TCP(pkt.TransportHeader().View())
-
- // For prerouting redirection, packets going in the original direction
- // have their destinations modified and replies have their sources
- // modified.
- switch dir {
- case dirOriginal:
- port := conn.reply.srcPort
- tcpHeader.SetDestinationPort(port)
- netHeader.SetDestinationAddress(conn.reply.srcAddr)
- case dirReply:
- port := conn.original.dstPort
- tcpHeader.SetSourcePort(port)
- netHeader.SetSourceAddress(conn.original.dstAddr)
- }
-
- // TODO(gvisor.dev/issue/170): TCP checksums aren't usually validated
- // on inbound packets, so we don't recalculate them. However, we should
- // support cases when they are validated, e.g. when we can't offload
- // receive checksumming.
-
- // After modification, IPv4 packets need a valid checksum.
- if pkt.NetworkProtocolNumber == header.IPv4ProtocolNumber {
- netHeader := header.IPv4(pkt.NetworkHeader().View())
- netHeader.SetChecksum(0)
- netHeader.SetChecksum(^netHeader.CalculateChecksum())
- }
-}
-
-// handlePacketOutput manipulates ports for packets in Output hook.
-func handlePacketOutput(pkt *PacketBuffer, conn *conn, gso *GSO, r *Route, dir direction) {
- // If this is a noop entry, don't do anything.
- if conn.manip == manipNone {
- return
- }
-
- netHeader := pkt.Network()
- tcpHeader := header.TCP(pkt.TransportHeader().View())
-
- // For output redirection, packets going in the original direction
- // have their destinations modified and replies have their sources
- // modified. For prerouting redirection, we only reach this point
- // when replying, so packet sources are modified.
- if conn.manip == manipDstOutput && dir == dirOriginal {
- port := conn.reply.srcPort
- tcpHeader.SetDestinationPort(port)
- netHeader.SetDestinationAddress(conn.reply.srcAddr)
- } else {
- port := conn.original.dstPort
- tcpHeader.SetSourcePort(port)
- netHeader.SetSourceAddress(conn.original.dstAddr)
- }
-
- // Calculate the TCP checksum and set it.
- tcpHeader.SetChecksum(0)
- length := uint16(len(tcpHeader) + pkt.Data().Size())
- xsum := header.PseudoHeaderChecksum(header.TCPProtocolNumber, netHeader.SourceAddress(), netHeader.DestinationAddress(), length)
- if gso != nil && gso.NeedsCsum {
- tcpHeader.SetChecksum(xsum)
- } else if r.RequiresTXTransportChecksum() {
- xsum = header.ChecksumCombine(xsum, pkt.Data().AsRange().Checksum())
- tcpHeader.SetChecksum(^tcpHeader.CalculateChecksum(xsum))
- }
-
- if pkt.NetworkProtocolNumber == header.IPv4ProtocolNumber {
- netHeader := header.IPv4(pkt.NetworkHeader().View())
- netHeader.SetChecksum(0)
- netHeader.SetChecksum(^netHeader.CalculateChecksum())
- }
-}
-
// handlePacket will manipulate the port and address of the packet if the
// connection exists. Returns whether, after the packet traverses the tables,
// it should create a new entry in the table.
-func (ct *ConnTrack) handlePacket(pkt *PacketBuffer, hook Hook, gso *GSO, r *Route) bool {
+func (ct *ConnTrack) handlePacket(pkt *PacketBuffer, hook Hook, r *Route) bool {
if pkt.NatDone {
return false
}
- if hook != Prerouting && hook != Output {
+ switch hook {
+ case Prerouting, Input, Output, Postrouting:
+ default:
return false
}
@@ -441,23 +389,79 @@ func (ct *ConnTrack) handlePacket(pkt *PacketBuffer, hook Hook, gso *GSO, r *Rou
}
conn, dir := ct.connFor(pkt)
- // Connection or Rule not found for the packet.
+ // Connection not found for the packet.
if conn == nil {
- return true
+ // If this is the last hook in the data path for this packet (Input if
+ // incoming, Postrouting if outgoing), indicate that a connection should be
+ // inserted by the end of this hook.
+ return hook == Input || hook == Postrouting
}
+ netHeader := pkt.Network()
tcpHeader := header.TCP(pkt.TransportHeader().View())
if len(tcpHeader) < header.TCPMinimumSize {
return false
}
+ // TODO(gvisor.dev/issue/5748): TCP checksums on inbound packets should be
+ // validated if checksum offloading is off. It may require IP defrag if the
+ // packets are fragmented.
+
+ switch hook {
+ case Prerouting, Output:
+ if conn.manip == manipDestination {
+ switch dir {
+ case dirOriginal:
+ tcpHeader.SetDestinationPort(conn.reply.srcPort)
+ netHeader.SetDestinationAddress(conn.reply.srcAddr)
+ case dirReply:
+ tcpHeader.SetSourcePort(conn.original.dstPort)
+ netHeader.SetSourceAddress(conn.original.dstAddr)
+ }
+ pkt.NatDone = true
+ }
+ case Input, Postrouting:
+ if conn.manip == manipSource {
+ switch dir {
+ case dirOriginal:
+ tcpHeader.SetSourcePort(conn.reply.dstPort)
+ netHeader.SetSourceAddress(conn.reply.dstAddr)
+ case dirReply:
+ tcpHeader.SetDestinationPort(conn.original.srcPort)
+ netHeader.SetDestinationAddress(conn.original.srcAddr)
+ }
+ pkt.NatDone = true
+ }
+ default:
+ panic(fmt.Sprintf("unrecognized hook = %s", hook))
+ }
+ if !pkt.NatDone {
+ return false
+ }
+
switch hook {
- case Prerouting:
- handlePacketPrerouting(pkt, conn, dir)
- case Output:
- handlePacketOutput(pkt, conn, gso, r, dir)
+ case Prerouting, Input:
+ case Output, Postrouting:
+ // Calculate the TCP checksum and set it.
+ tcpHeader.SetChecksum(0)
+ length := uint16(len(tcpHeader) + pkt.Data().Size())
+ xsum := header.PseudoHeaderChecksum(header.TCPProtocolNumber, netHeader.SourceAddress(), netHeader.DestinationAddress(), length)
+ if pkt.GSOOptions.Type != GSONone && pkt.GSOOptions.NeedsCsum {
+ tcpHeader.SetChecksum(xsum)
+ } else if r.RequiresTXTransportChecksum() {
+ xsum = header.ChecksumCombine(xsum, pkt.Data().AsRange().Checksum())
+ tcpHeader.SetChecksum(^tcpHeader.CalculateChecksum(xsum))
+ }
+ default:
+ panic(fmt.Sprintf("unrecognized hook = %s", hook))
+ }
+
+ // After modification, IPv4 packets need a valid checksum.
+ if pkt.NetworkProtocolNumber == header.IPv4ProtocolNumber {
+ netHeader := header.IPv4(pkt.NetworkHeader().View())
+ netHeader.SetChecksum(0)
+ netHeader.SetChecksum(^netHeader.CalculateChecksum())
}
- pkt.NatDone = true
// Update the state of tcb.
// TODO(gvisor.dev/issue/170): Add support in tcpcontrack to handle
@@ -638,8 +642,8 @@ func (ct *ConnTrack) originalDst(epID TransportEndpointID, netProto tcpip.Networ
if conn == nil {
// Not a tracked connection.
return "", 0, &tcpip.ErrNotConnected{}
- } else if conn.manip == manipNone {
- // Unmanipulated connection.
+ } else if conn.manip != manipDestination {
+ // Unmanipulated destination.
return "", 0, &tcpip.ErrInvalidOptionValue{}
}
diff --git a/pkg/tcpip/stack/forwarding_test.go b/pkg/tcpip/stack/forwarding_test.go
index 16ee75bc4..7d3725681 100644
--- a/pkg/tcpip/stack/forwarding_test.go
+++ b/pkg/tcpip/stack/forwarding_test.go
@@ -101,7 +101,7 @@ func (f *fwdTestNetworkEndpoint) HandlePacket(pkt *PacketBuffer) {
ReserveHeaderBytes: int(r.MaxHeaderLength()),
Data: vv.ToView().ToVectorisedView(),
})
- // TODO(b/143425874) Decrease the TTL field in forwarded packets.
+ // TODO(gvisor.dev/issue/1085) Decrease the TTL field in forwarded packets.
_ = r.WriteHeaderIncludedPacket(pkt)
}
@@ -117,7 +117,7 @@ func (f *fwdTestNetworkEndpoint) NetworkProtocolNumber() tcpip.NetworkProtocolNu
return f.proto.Number()
}
-func (f *fwdTestNetworkEndpoint) WritePacket(r *Route, gso *GSO, params NetworkHeaderParams, pkt *PacketBuffer) tcpip.Error {
+func (f *fwdTestNetworkEndpoint) WritePacket(r *Route, params NetworkHeaderParams, pkt *PacketBuffer) tcpip.Error {
// Add the protocol's header to the packet and send it to the link
// endpoint.
b := pkt.NetworkHeader().Push(fwdTestNetHeaderLen)
@@ -125,11 +125,11 @@ func (f *fwdTestNetworkEndpoint) WritePacket(r *Route, gso *GSO, params NetworkH
b[srcAddrOffset] = r.LocalAddress()[0]
b[protocolNumberOffset] = byte(params.Protocol)
- return f.nic.WritePacket(r, gso, fwdTestNetNumber, pkt)
+ return f.nic.WritePacket(r, fwdTestNetNumber, pkt)
}
// WritePackets implements LinkEndpoint.WritePackets.
-func (*fwdTestNetworkEndpoint) WritePackets(r *Route, gso *GSO, pkts PacketBufferList, params NetworkHeaderParams) (int, tcpip.Error) {
+func (*fwdTestNetworkEndpoint) WritePackets(r *Route, pkts PacketBufferList, params NetworkHeaderParams) (int, tcpip.Error) {
panic("not implemented")
}
@@ -139,7 +139,7 @@ func (f *fwdTestNetworkEndpoint) WriteHeaderIncludedPacket(r *Route, pkt *Packet
return &tcpip.ErrMalformedHeader{}
}
- return f.nic.WritePacket(r, nil /* gso */, fwdTestNetNumber, pkt)
+ return f.nic.WritePacket(r, fwdTestNetNumber, pkt)
}
func (f *fwdTestNetworkEndpoint) Close() {
@@ -264,6 +264,8 @@ type fwdTestPacketInfo struct {
Pkt *PacketBuffer
}
+var _ LinkEndpoint = (*fwdTestLinkEndpoint)(nil)
+
type fwdTestLinkEndpoint struct {
dispatcher NetworkDispatcher
mtu uint32
@@ -306,11 +308,6 @@ func (e fwdTestLinkEndpoint) Capabilities() LinkEndpointCapabilities {
return caps | CapabilityResolutionRequired
}
-// GSOMaxSize returns the maximum GSO packet size.
-func (*fwdTestLinkEndpoint) GSOMaxSize() uint32 {
- return 1 << 15
-}
-
// MaxHeaderLength returns the maximum size of the link layer header. Given it
// doesn't have a header, it just returns 0.
func (*fwdTestLinkEndpoint) MaxHeaderLength() uint16 {
@@ -322,7 +319,7 @@ func (e *fwdTestLinkEndpoint) LinkAddress() tcpip.LinkAddress {
return e.linkAddr
}
-func (e fwdTestLinkEndpoint) WritePacket(r RouteInfo, gso *GSO, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) tcpip.Error {
+func (e fwdTestLinkEndpoint) WritePacket(r RouteInfo, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) tcpip.Error {
p := fwdTestPacketInfo{
RemoteLinkAddress: r.RemoteLinkAddress,
LocalLinkAddress: r.LocalLinkAddress,
@@ -338,10 +335,10 @@ func (e fwdTestLinkEndpoint) WritePacket(r RouteInfo, gso *GSO, protocol tcpip.N
}
// WritePackets stores outbound packets into the channel.
-func (e *fwdTestLinkEndpoint) WritePackets(r RouteInfo, gso *GSO, pkts PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, tcpip.Error) {
+func (e *fwdTestLinkEndpoint) WritePackets(r RouteInfo, pkts PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, tcpip.Error) {
n := 0
for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() {
- e.WritePacket(r, gso, protocol, pkt)
+ e.WritePacket(r, protocol, pkt)
n++
}
diff --git a/pkg/tcpip/stack/hook_string.go b/pkg/tcpip/stack/hook_string.go
new file mode 100644
index 000000000..3dc8a7b02
--- /dev/null
+++ b/pkg/tcpip/stack/hook_string.go
@@ -0,0 +1,41 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at //
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Code generated by "stringer -type Hook ."; DO NOT EDIT.
+
+package stack
+
+import "strconv"
+
+func _() {
+ // An "invalid array index" compiler error signifies that the constant values have changed.
+ // Re-run the stringer command to generate them again.
+ var x [1]struct{}
+ _ = x[Prerouting-0]
+ _ = x[Input-1]
+ _ = x[Forward-2]
+ _ = x[Output-3]
+ _ = x[Postrouting-4]
+ _ = x[NumHooks-5]
+}
+
+const _Hook_name = "PreroutingInputForwardOutputPostroutingNumHooks"
+
+var _Hook_index = [...]uint8{0, 10, 15, 22, 28, 39, 47}
+
+func (i Hook) String() string {
+ if i >= Hook(len(_Hook_index)-1) {
+ return "Hook(" + strconv.FormatInt(int64(i), 10) + ")"
+ }
+ return _Hook_name[_Hook_index[i]:_Hook_index[i+1]]
+}
diff --git a/pkg/tcpip/stack/iptables.go b/pkg/tcpip/stack/iptables.go
index 52890f6eb..e2894c548 100644
--- a/pkg/tcpip/stack/iptables.go
+++ b/pkg/tcpip/stack/iptables.go
@@ -175,9 +175,10 @@ func DefaultTables() *IPTables {
},
},
priorities: [NumHooks][]TableID{
- Prerouting: {MangleID, NATID},
- Input: {NATID, FilterID},
- Output: {MangleID, NATID, FilterID},
+ Prerouting: {MangleID, NATID},
+ Input: {NATID, FilterID},
+ Output: {MangleID, NATID, FilterID},
+ Postrouting: {MangleID, NATID},
},
connections: ConnTrack{
seed: generateRandUint32(),
@@ -266,12 +267,12 @@ const (
// should continue traversing the network stack and false when it should be
// dropped.
//
-// TODO(gvisor.dev/issue/170): PacketBuffer should hold the GSO and route, from
+// TODO(gvisor.dev/issue/170): PacketBuffer should hold the route, from
// which address can be gathered. Currently, address is only needed for
// prerouting.
//
// Precondition: pkt.NetworkHeader is set.
-func (it *IPTables) Check(hook Hook, pkt *PacketBuffer, gso *GSO, r *Route, preroutingAddr tcpip.Address, inNicName, outNicName string) bool {
+func (it *IPTables) Check(hook Hook, pkt *PacketBuffer, r *Route, preroutingAddr tcpip.Address, inNicName, outNicName string) bool {
if pkt.NetworkProtocolNumber != header.IPv4ProtocolNumber && pkt.NetworkProtocolNumber != header.IPv6ProtocolNumber {
return true
}
@@ -285,7 +286,7 @@ func (it *IPTables) Check(hook Hook, pkt *PacketBuffer, gso *GSO, r *Route, prer
// Packets are manipulated only if connection and matching
// NAT rule exists.
- shouldTrack := it.connections.handlePacket(pkt, hook, gso, r)
+ shouldTrack := it.connections.handlePacket(pkt, hook, r)
// Go through each table containing the hook.
priorities := it.priorities[hook]
@@ -302,7 +303,7 @@ func (it *IPTables) Check(hook Hook, pkt *PacketBuffer, gso *GSO, r *Route, prer
table = it.v4Tables[tableID]
}
ruleIdx := table.BuiltinChains[hook]
- switch verdict := it.checkChain(hook, pkt, table, ruleIdx, gso, r, preroutingAddr, inNicName, outNicName); verdict {
+ switch verdict := it.checkChain(hook, pkt, table, ruleIdx, r, preroutingAddr, inNicName, outNicName); verdict {
// If the table returns Accept, move on to the next table.
case chainAccept:
continue
@@ -313,7 +314,7 @@ func (it *IPTables) Check(hook Hook, pkt *PacketBuffer, gso *GSO, r *Route, prer
// Any Return from a built-in chain means we have to
// call the underflow.
underflow := table.Rules[table.Underflows[hook]]
- switch v, _ := underflow.Target.Action(pkt, &it.connections, hook, gso, r, preroutingAddr); v {
+ switch v, _ := underflow.Target.Action(pkt, &it.connections, hook, r, preroutingAddr); v {
case RuleAccept:
continue
case RuleDrop:
@@ -385,10 +386,10 @@ func (it *IPTables) startReaper(interval time.Duration) {
//
// NOTE: unlike the Check API the returned map contains packets that should be
// dropped.
-func (it *IPTables) CheckPackets(hook Hook, pkts PacketBufferList, gso *GSO, r *Route, inNicName, outNicName string) (drop map[*PacketBuffer]struct{}, natPkts map[*PacketBuffer]struct{}) {
+func (it *IPTables) CheckPackets(hook Hook, pkts PacketBufferList, r *Route, inNicName, outNicName string) (drop map[*PacketBuffer]struct{}, natPkts map[*PacketBuffer]struct{}) {
for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() {
if !pkt.NatDone {
- if ok := it.Check(hook, pkt, gso, r, "", inNicName, outNicName); !ok {
+ if ok := it.Check(hook, pkt, r, "", inNicName, outNicName); !ok {
if drop == nil {
drop = make(map[*PacketBuffer]struct{})
}
@@ -408,11 +409,11 @@ func (it *IPTables) CheckPackets(hook Hook, pkts PacketBufferList, gso *GSO, r *
// Preconditions:
// * pkt is a IPv4 packet of at least length header.IPv4MinimumSize.
// * pkt.NetworkHeader is not nil.
-func (it *IPTables) checkChain(hook Hook, pkt *PacketBuffer, table Table, ruleIdx int, gso *GSO, r *Route, preroutingAddr tcpip.Address, inNicName, outNicName string) chainVerdict {
+func (it *IPTables) checkChain(hook Hook, pkt *PacketBuffer, table Table, ruleIdx int, r *Route, preroutingAddr tcpip.Address, inNicName, outNicName string) chainVerdict {
// Start from ruleIdx and walk the list of rules until a rule gives us
// a verdict.
for ruleIdx < len(table.Rules) {
- switch verdict, jumpTo := it.checkRule(hook, pkt, table, ruleIdx, gso, r, preroutingAddr, inNicName, outNicName); verdict {
+ switch verdict, jumpTo := it.checkRule(hook, pkt, table, ruleIdx, r, preroutingAddr, inNicName, outNicName); verdict {
case RuleAccept:
return chainAccept
@@ -429,7 +430,7 @@ func (it *IPTables) checkChain(hook Hook, pkt *PacketBuffer, table Table, ruleId
ruleIdx++
continue
}
- switch verdict := it.checkChain(hook, pkt, table, jumpTo, gso, r, preroutingAddr, inNicName, outNicName); verdict {
+ switch verdict := it.checkChain(hook, pkt, table, jumpTo, r, preroutingAddr, inNicName, outNicName); verdict {
case chainAccept:
return chainAccept
case chainDrop:
@@ -455,7 +456,7 @@ func (it *IPTables) checkChain(hook Hook, pkt *PacketBuffer, table Table, ruleId
// Preconditions:
// * pkt is a IPv4 packet of at least length header.IPv4MinimumSize.
// * pkt.NetworkHeader is not nil.
-func (it *IPTables) checkRule(hook Hook, pkt *PacketBuffer, table Table, ruleIdx int, gso *GSO, r *Route, preroutingAddr tcpip.Address, inNicName, outNicName string) (RuleVerdict, int) {
+func (it *IPTables) checkRule(hook Hook, pkt *PacketBuffer, table Table, ruleIdx int, r *Route, preroutingAddr tcpip.Address, inNicName, outNicName string) (RuleVerdict, int) {
rule := table.Rules[ruleIdx]
// Check whether the packet matches the IP header filter.
@@ -478,7 +479,7 @@ func (it *IPTables) checkRule(hook Hook, pkt *PacketBuffer, table Table, ruleIdx
}
// All the matchers matched, so run the target.
- return rule.Target.Action(pkt, &it.connections, hook, gso, r, preroutingAddr)
+ return rule.Target.Action(pkt, &it.connections, hook, r, preroutingAddr)
}
// OriginalDst returns the original destination of redirected connections. It
diff --git a/pkg/tcpip/stack/iptables_targets.go b/pkg/tcpip/stack/iptables_targets.go
index 0e8b90c9b..2812c89aa 100644
--- a/pkg/tcpip/stack/iptables_targets.go
+++ b/pkg/tcpip/stack/iptables_targets.go
@@ -29,7 +29,7 @@ type AcceptTarget struct {
}
// Action implements Target.Action.
-func (*AcceptTarget) Action(*PacketBuffer, *ConnTrack, Hook, *GSO, *Route, tcpip.Address) (RuleVerdict, int) {
+func (*AcceptTarget) Action(*PacketBuffer, *ConnTrack, Hook, *Route, tcpip.Address) (RuleVerdict, int) {
return RuleAccept, 0
}
@@ -40,7 +40,7 @@ type DropTarget struct {
}
// Action implements Target.Action.
-func (*DropTarget) Action(*PacketBuffer, *ConnTrack, Hook, *GSO, *Route, tcpip.Address) (RuleVerdict, int) {
+func (*DropTarget) Action(*PacketBuffer, *ConnTrack, Hook, *Route, tcpip.Address) (RuleVerdict, int) {
return RuleDrop, 0
}
@@ -52,7 +52,7 @@ type ErrorTarget struct {
}
// Action implements Target.Action.
-func (*ErrorTarget) Action(*PacketBuffer, *ConnTrack, Hook, *GSO, *Route, tcpip.Address) (RuleVerdict, int) {
+func (*ErrorTarget) Action(*PacketBuffer, *ConnTrack, Hook, *Route, tcpip.Address) (RuleVerdict, int) {
log.Debugf("ErrorTarget triggered.")
return RuleDrop, 0
}
@@ -67,7 +67,7 @@ type UserChainTarget struct {
}
// Action implements Target.Action.
-func (*UserChainTarget) Action(*PacketBuffer, *ConnTrack, Hook, *GSO, *Route, tcpip.Address) (RuleVerdict, int) {
+func (*UserChainTarget) Action(*PacketBuffer, *ConnTrack, Hook, *Route, tcpip.Address) (RuleVerdict, int) {
panic("UserChainTarget should never be called.")
}
@@ -79,7 +79,7 @@ type ReturnTarget struct {
}
// Action implements Target.Action.
-func (*ReturnTarget) Action(*PacketBuffer, *ConnTrack, Hook, *GSO, *Route, tcpip.Address) (RuleVerdict, int) {
+func (*ReturnTarget) Action(*PacketBuffer, *ConnTrack, Hook, *Route, tcpip.Address) (RuleVerdict, int) {
return RuleReturn, 0
}
@@ -103,7 +103,7 @@ type RedirectTarget struct {
// TODO(gvisor.dev/issue/170): Parse headers without copying. The current
// implementation only works for Prerouting and calls pkt.Clone(), neither
// of which should be the case.
-func (rt *RedirectTarget) Action(pkt *PacketBuffer, ct *ConnTrack, hook Hook, gso *GSO, r *Route, address tcpip.Address) (RuleVerdict, int) {
+func (rt *RedirectTarget) Action(pkt *PacketBuffer, ct *ConnTrack, hook Hook, r *Route, address tcpip.Address) (RuleVerdict, int) {
// Sanity check.
if rt.NetworkProtocol != pkt.NetworkProtocolNumber {
panic(fmt.Sprintf(
@@ -174,7 +174,85 @@ func (rt *RedirectTarget) Action(pkt *PacketBuffer, ct *ConnTrack, hook Hook, gs
// packet of the connection comes here. Other packets will be
// manipulated in connection tracking.
if conn := ct.insertRedirectConn(pkt, hook, rt.Port, address); conn != nil {
- ct.handlePacket(pkt, hook, gso, r)
+ ct.handlePacket(pkt, hook, r)
+ }
+ default:
+ return RuleDrop, 0
+ }
+
+ return RuleAccept, 0
+}
+
+// SNATTarget modifies the source port/IP in the outgoing packets.
+type SNATTarget struct {
+ Addr tcpip.Address
+ Port uint16
+
+ // NetworkProtocol is the network protocol the target is used with. It
+ // is immutable.
+ NetworkProtocol tcpip.NetworkProtocolNumber
+}
+
+// Action implements Target.Action.
+func (st *SNATTarget) Action(pkt *PacketBuffer, ct *ConnTrack, hook Hook, r *Route, address tcpip.Address) (RuleVerdict, int) {
+ // Sanity check.
+ if st.NetworkProtocol != pkt.NetworkProtocolNumber {
+ panic(fmt.Sprintf(
+ "SNATTarget.Action with NetworkProtocol %d called on packet with NetworkProtocolNumber %d",
+ st.NetworkProtocol, pkt.NetworkProtocolNumber))
+ }
+
+ // Packet is already manipulated.
+ if pkt.NatDone {
+ return RuleAccept, 0
+ }
+
+ // Drop the packet if network and transport header are not set.
+ if pkt.NetworkHeader().View().IsEmpty() || pkt.TransportHeader().View().IsEmpty() {
+ return RuleDrop, 0
+ }
+
+ switch hook {
+ case Postrouting, Input:
+ case Prerouting, Output, Forward:
+ panic(fmt.Sprintf("%s not supported", hook))
+ default:
+ panic(fmt.Sprintf("%s unrecognized", hook))
+ }
+
+ switch protocol := pkt.TransportProtocolNumber; protocol {
+ case header.UDPProtocolNumber:
+ udpHeader := header.UDP(pkt.TransportHeader().View())
+ udpHeader.SetChecksum(0)
+ udpHeader.SetSourcePort(st.Port)
+ netHeader := pkt.Network()
+ netHeader.SetSourceAddress(st.Addr)
+
+ // Only calculate the checksum if offloading isn't supported.
+ if r.RequiresTXTransportChecksum() {
+ length := uint16(pkt.Size()) - uint16(len(pkt.NetworkHeader().View()))
+ xsum := header.PseudoHeaderChecksum(protocol, netHeader.SourceAddress(), netHeader.DestinationAddress(), length)
+ xsum = header.ChecksumCombine(xsum, pkt.Data().AsRange().Checksum())
+ udpHeader.SetChecksum(^udpHeader.CalculateChecksum(xsum))
+ }
+
+ // After modification, IPv4 packets need a valid checksum.
+ if pkt.NetworkProtocolNumber == header.IPv4ProtocolNumber {
+ netHeader := header.IPv4(pkt.NetworkHeader().View())
+ netHeader.SetChecksum(0)
+ netHeader.SetChecksum(^netHeader.CalculateChecksum())
+ }
+ pkt.NatDone = true
+ case header.TCPProtocolNumber:
+ if ct == nil {
+ return RuleAccept, 0
+ }
+
+ // Set up conection for matching NAT rule. Only the first
+ // packet of the connection comes here. Other packets will be
+ // manipulated in connection tracking.
+ if conn := ct.insertSNATConn(pkt, hook, st.Port, st.Addr); conn != nil {
+ ct.handlePacket(pkt, hook, r)
}
default:
return RuleDrop, 0
diff --git a/pkg/tcpip/stack/iptables_types.go b/pkg/tcpip/stack/iptables_types.go
index b0d84befb..4631ab93f 100644
--- a/pkg/tcpip/stack/iptables_types.go
+++ b/pkg/tcpip/stack/iptables_types.go
@@ -345,5 +345,5 @@ type Target interface {
// Action takes an action on the packet and returns a verdict on how
// traversal should (or should not) continue. If the return value is
// Jump, it also returns the index of the rule to jump to.
- Action(packet *PacketBuffer, connections *ConnTrack, hook Hook, gso *GSO, r *Route, address tcpip.Address) (RuleVerdict, int)
+ Action(*PacketBuffer, *ConnTrack, Hook, *Route, tcpip.Address) (RuleVerdict, int)
}
diff --git a/pkg/tcpip/stack/ndp_test.go b/pkg/tcpip/stack/ndp_test.go
index 14124ae66..c585b81b2 100644
--- a/pkg/tcpip/stack/ndp_test.go
+++ b/pkg/tcpip/stack/ndp_test.go
@@ -33,15 +33,19 @@ import (
"gvisor.dev/gvisor/pkg/tcpip/link/loopback"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
"gvisor.dev/gvisor/pkg/tcpip/stack"
+ "gvisor.dev/gvisor/pkg/tcpip/testutil"
"gvisor.dev/gvisor/pkg/tcpip/transport/icmp"
"gvisor.dev/gvisor/pkg/tcpip/transport/udp"
"gvisor.dev/gvisor/pkg/waiter"
)
+var (
+ addr1 = testutil.MustParse6("a00::1")
+ addr2 = testutil.MustParse6("a00::2")
+ addr3 = testutil.MustParse6("a00::3")
+)
+
const (
- addr1 = tcpip.Address("\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01")
- addr2 = tcpip.Address("\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02")
- addr3 = tcpip.Address("\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x03")
linkAddr1 = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x06")
linkAddr2 = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x07")
linkAddr3 = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x08")
@@ -1142,57 +1146,198 @@ func raBufWithPI(ip tcpip.Address, rl uint16, prefix tcpip.AddressWithPrefix, on
})
}
-// TestNoRouterDiscovery tests that router discovery will not be performed if
-// configured not to.
-func TestNoRouterDiscovery(t *testing.T) {
- // Being configured to discover routers means handle and
- // discover are set to true and forwarding is set to false.
- // This tests all possible combinations of the configurations,
- // except for the configuration where handle = true, discover =
- // true and forwarding = false (the required configuration to do
- // router discovery) - that will done in other tests.
- for i := 0; i < 7; i++ {
- handle := i&1 != 0
- discover := i&2 != 0
- forwarding := i&4 == 0
-
- t.Run(fmt.Sprintf("HandleRAs(%t), DiscoverDefaultRouters(%t), Forwarding(%t)", handle, discover, forwarding), func(t *testing.T) {
- ndpDisp := ndpDispatcher{
- routerC: make(chan ndpRouterEvent, 1),
- }
- e := channel.New(0, 1280, linkAddr1)
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
- NDPConfigs: ipv6.NDPConfigurations{
- HandleRAs: handle,
- DiscoverDefaultRouters: discover,
- },
- NDPDisp: &ndpDisp,
- })},
- })
- s.SetForwarding(ipv6.ProtocolNumber, forwarding)
+func TestDynamicConfigurationsDisabled(t *testing.T) {
+ const (
+ nicID = 1
+ maxRtrSolicitDelay = time.Second
+ )
- if err := s.CreateNIC(1, e); err != nil {
- t.Fatalf("CreateNIC(1) = %s", err)
- }
+ prefix := tcpip.AddressWithPrefix{
+ Address: testutil.MustParse6("102:304:506:708::"),
+ PrefixLen: 64,
+ }
- // Rx an RA with non-zero lifetime.
- e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr2, 1000))
- select {
- case <-ndpDisp.routerC:
- t.Fatal("unexpectedly discovered a router when configured not to")
- default:
+ tests := []struct {
+ name string
+ config func(bool) ipv6.NDPConfigurations
+ ra *stack.PacketBuffer
+ }{
+ {
+ name: "No Router Discovery",
+ config: func(enable bool) ipv6.NDPConfigurations {
+ return ipv6.NDPConfigurations{DiscoverDefaultRouters: enable}
+ },
+ ra: raBuf(llAddr2, 1000),
+ },
+ {
+ name: "No Prefix Discovery",
+ config: func(enable bool) ipv6.NDPConfigurations {
+ return ipv6.NDPConfigurations{DiscoverOnLinkPrefixes: enable}
+ },
+ ra: raBufWithPI(llAddr2, 0, prefix, true, false, 10, 0),
+ },
+ {
+ name: "No Autogenerate Addresses",
+ config: func(enable bool) ipv6.NDPConfigurations {
+ return ipv6.NDPConfigurations{AutoGenGlobalAddresses: enable}
+ },
+ ra: raBufWithPI(llAddr2, 0, prefix, false, true, 10, 0),
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ // Being configured to discover routers/prefixes or auto-generate
+ // addresses means RAs must be handled, and router/prefix discovery or
+ // SLAAC must be enabled.
+ //
+ // This tests all possible combinations of the configurations where
+ // router/prefix discovery or SLAAC are disabled.
+ for i := 0; i < 7; i++ {
+ handle := ipv6.HandlingRAsDisabled
+ if i&1 != 0 {
+ handle = ipv6.HandlingRAsEnabledWhenForwardingDisabled
+ }
+ enable := i&2 != 0
+ forwarding := i&4 == 0
+
+ t.Run(fmt.Sprintf("HandleRAs(%s), Forwarding(%t), Enabled(%t)", handle, forwarding, enable), func(t *testing.T) {
+ ndpDisp := ndpDispatcher{
+ routerC: make(chan ndpRouterEvent, 1),
+ prefixC: make(chan ndpPrefixEvent, 1),
+ autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1),
+ }
+ ndpConfigs := test.config(enable)
+ ndpConfigs.HandleRAs = handle
+ ndpConfigs.MaxRtrSolicitations = 1
+ ndpConfigs.RtrSolicitationInterval = maxRtrSolicitDelay
+ ndpConfigs.MaxRtrSolicitationDelay = maxRtrSolicitDelay
+ clock := faketime.NewManualClock()
+ s := stack.New(stack.Options{
+ Clock: clock,
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
+ NDPConfigs: ndpConfigs,
+ NDPDisp: &ndpDisp,
+ })},
+ })
+ if err := s.SetForwarding(ipv6.ProtocolNumber, forwarding); err != nil {
+ t.Fatalf("SetForwarding(%d, %t): %s", ipv6.ProtocolNumber, forwarding, err)
+ }
+
+ e := channel.New(1, 1280, linkAddr1)
+ if err := s.CreateNIC(nicID, e); err != nil {
+ t.Fatalf("CreateNIC(%d, _): %s", nicID, err)
+ }
+
+ handleRAsDisabled := handle == ipv6.HandlingRAsDisabled || forwarding
+ ep, err := s.GetNetworkEndpoint(nicID, ipv6.ProtocolNumber)
+ if err != nil {
+ t.Fatalf("s.GetNetworkEndpoint(%d, %d): %s", nicID, ipv6.ProtocolNumber, err)
+ }
+ stats := ep.Stats()
+ v6Stats, ok := stats.(*ipv6.Stats)
+ if !ok {
+ t.Fatalf("got v6Stats = %T, expected = %T", stats, v6Stats)
+ }
+
+ // Make sure that when handling RAs are enabled, we solicit routers.
+ clock.Advance(maxRtrSolicitDelay)
+ if got, want := v6Stats.ICMP.PacketsSent.RouterSolicit.Value(), boolToUint64(!handleRAsDisabled); got != want {
+ t.Errorf("got v6Stats.ICMP.PacketsSent.RouterSolicit.Value() = %d, want = %d", got, want)
+ }
+ if handleRAsDisabled {
+ if p, ok := e.Read(); ok {
+ t.Errorf("unexpectedly got a packet = %#v", p)
+ }
+ } else if p, ok := e.Read(); !ok {
+ t.Error("expected router solicitation packet")
+ } else if p.Proto != header.IPv6ProtocolNumber {
+ t.Errorf("got Proto = %d, want = %d", p.Proto, header.IPv6ProtocolNumber)
+ } else {
+ if want := header.EthernetAddressFromMulticastIPv6Address(header.IPv6AllRoutersLinkLocalMulticastAddress); p.Route.RemoteLinkAddress != want {
+ t.Errorf("got remote link address = %s, want = %s", p.Route.RemoteLinkAddress, want)
+ }
+
+ checker.IPv6(t, stack.PayloadSince(p.Pkt.NetworkHeader()),
+ checker.SrcAddr(header.IPv6Any),
+ checker.DstAddr(header.IPv6AllRoutersLinkLocalMulticastAddress),
+ checker.TTL(header.NDPHopLimit),
+ checker.NDPRS(checker.NDPRSOptions(nil)),
+ )
+ }
+
+ // Make sure we do not discover any routers or prefixes, or perform
+ // SLAAC on reception of an RA.
+ e.InjectInbound(header.IPv6ProtocolNumber, test.ra.Clone())
+ // Make sure that the unhandled RA stat is only incremented when
+ // handling RAs is disabled.
+ if got, want := v6Stats.UnhandledRouterAdvertisements.Value(), boolToUint64(handleRAsDisabled); got != want {
+ t.Errorf("got v6Stats.UnhandledRouterAdvertisements.Value() = %d, want = %d", got, want)
+ }
+ select {
+ case e := <-ndpDisp.routerC:
+ t.Errorf("unexpectedly discovered a router when configured not to: %#v", e)
+ default:
+ }
+ select {
+ case e := <-ndpDisp.prefixC:
+ t.Errorf("unexpectedly discovered a prefix when configured not to: %#v", e)
+ default:
+ }
+ select {
+ case e := <-ndpDisp.autoGenAddrC:
+ t.Errorf("unexpectedly auto-generated an address when configured not to: %#v", e)
+ default:
+ }
+ })
}
})
}
}
+func boolToUint64(v bool) uint64 {
+ if v {
+ return 1
+ }
+ return 0
+}
+
// Check e to make sure that the event is for addr on nic with ID 1, and the
// discovered flag set to discovered.
func checkRouterEvent(e ndpRouterEvent, addr tcpip.Address, discovered bool) string {
return cmp.Diff(ndpRouterEvent{nicID: 1, addr: addr, discovered: discovered}, e, cmp.AllowUnexported(e))
}
+func testWithRAs(t *testing.T, f func(*testing.T, ipv6.HandleRAsConfiguration, bool)) {
+ tests := [...]struct {
+ name string
+ handleRAs ipv6.HandleRAsConfiguration
+ forwarding bool
+ }{
+ {
+ name: "Handle RAs when forwarding disabled",
+ handleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled,
+ forwarding: false,
+ },
+ {
+ name: "Always Handle RAs with forwarding disabled",
+ handleRAs: ipv6.HandlingRAsAlwaysEnabled,
+ forwarding: false,
+ },
+ {
+ name: "Always Handle RAs with forwarding enabled",
+ handleRAs: ipv6.HandlingRAsAlwaysEnabled,
+ forwarding: true,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ f(t, test.handleRAs, test.forwarding)
+ })
+ }
+}
+
// TestRouterDiscoveryDispatcherNoRemember tests that the stack does not
// remember a discovered router when the dispatcher asks it not to.
func TestRouterDiscoveryDispatcherNoRemember(t *testing.T) {
@@ -1203,7 +1348,7 @@ func TestRouterDiscoveryDispatcherNoRemember(t *testing.T) {
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
NDPConfigs: ipv6.NDPConfigurations{
- HandleRAs: true,
+ HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled,
DiscoverDefaultRouters: true,
},
NDPDisp: &ndpDisp,
@@ -1237,103 +1382,109 @@ func TestRouterDiscoveryDispatcherNoRemember(t *testing.T) {
}
func TestRouterDiscovery(t *testing.T) {
- ndpDisp := ndpDispatcher{
- routerC: make(chan ndpRouterEvent, 1),
- rememberRouter: true,
- }
- e := channel.New(0, 1280, linkAddr1)
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
- NDPConfigs: ipv6.NDPConfigurations{
- HandleRAs: true,
- DiscoverDefaultRouters: true,
- },
- NDPDisp: &ndpDisp,
- })},
- })
+ testWithRAs(t, func(t *testing.T, handleRAs ipv6.HandleRAsConfiguration, forwarding bool) {
+ ndpDisp := ndpDispatcher{
+ routerC: make(chan ndpRouterEvent, 1),
+ rememberRouter: true,
+ }
+ e := channel.New(0, 1280, linkAddr1)
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
+ NDPConfigs: ipv6.NDPConfigurations{
+ HandleRAs: handleRAs,
+ DiscoverDefaultRouters: true,
+ },
+ NDPDisp: &ndpDisp,
+ })},
+ })
- expectRouterEvent := func(addr tcpip.Address, discovered bool) {
- t.Helper()
+ expectRouterEvent := func(addr tcpip.Address, discovered bool) {
+ t.Helper()
- select {
- case e := <-ndpDisp.routerC:
- if diff := checkRouterEvent(e, addr, discovered); diff != "" {
- t.Errorf("router event mismatch (-want +got):\n%s", diff)
+ select {
+ case e := <-ndpDisp.routerC:
+ if diff := checkRouterEvent(e, addr, discovered); diff != "" {
+ t.Errorf("router event mismatch (-want +got):\n%s", diff)
+ }
+ default:
+ t.Fatal("expected router discovery event")
}
- default:
- t.Fatal("expected router discovery event")
}
- }
- expectAsyncRouterInvalidationEvent := func(addr tcpip.Address, timeout time.Duration) {
- t.Helper()
+ expectAsyncRouterInvalidationEvent := func(addr tcpip.Address, timeout time.Duration) {
+ t.Helper()
- select {
- case e := <-ndpDisp.routerC:
- if diff := checkRouterEvent(e, addr, false); diff != "" {
- t.Errorf("router event mismatch (-want +got):\n%s", diff)
+ select {
+ case e := <-ndpDisp.routerC:
+ if diff := checkRouterEvent(e, addr, false); diff != "" {
+ t.Errorf("router event mismatch (-want +got):\n%s", diff)
+ }
+ case <-time.After(timeout):
+ t.Fatal("timed out waiting for router discovery event")
}
- case <-time.After(timeout):
- t.Fatal("timed out waiting for router discovery event")
}
- }
-
- if err := s.CreateNIC(1, e); err != nil {
- t.Fatalf("CreateNIC(1) = %s", err)
- }
- // Rx an RA from lladdr2 with zero lifetime. It should not be
- // remembered.
- e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr2, 0))
- select {
- case <-ndpDisp.routerC:
- t.Fatal("unexpectedly discovered a router with 0 lifetime")
- default:
- }
-
- // Rx an RA from lladdr2 with a huge lifetime.
- e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr2, 1000))
- expectRouterEvent(llAddr2, true)
+ if err := s.SetForwarding(ipv6.ProtocolNumber, forwarding); err != nil {
+ t.Fatalf("SetForwarding(%d, %t): %s", ipv6.ProtocolNumber, forwarding, err)
+ }
- // Rx an RA from another router (lladdr3) with non-zero lifetime.
- const l3LifetimeSeconds = 6
- e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr3, l3LifetimeSeconds))
- expectRouterEvent(llAddr3, true)
+ if err := s.CreateNIC(1, e); err != nil {
+ t.Fatalf("CreateNIC(1) = %s", err)
+ }
- // Rx an RA from lladdr2 with lesser lifetime.
- const l2LifetimeSeconds = 2
- e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr2, l2LifetimeSeconds))
- select {
- case <-ndpDisp.routerC:
- t.Fatal("Should not receive a router event when updating lifetimes for known routers")
- default:
- }
+ // Rx an RA from lladdr2 with zero lifetime. It should not be
+ // remembered.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr2, 0))
+ select {
+ case <-ndpDisp.routerC:
+ t.Fatal("unexpectedly discovered a router with 0 lifetime")
+ default:
+ }
- // Wait for lladdr2's router invalidation job to execute. The lifetime
- // of the router should have been updated to the most recent (smaller)
- // lifetime.
- //
- // Wait for the normal lifetime plus an extra bit for the
- // router to get invalidated. If we don't get an invalidation
- // event after this time, then something is wrong.
- expectAsyncRouterInvalidationEvent(llAddr2, l2LifetimeSeconds*time.Second+defaultAsyncPositiveEventTimeout)
+ // Rx an RA from lladdr2 with a huge lifetime.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr2, 1000))
+ expectRouterEvent(llAddr2, true)
- // Rx an RA from lladdr2 with huge lifetime.
- e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr2, 1000))
- expectRouterEvent(llAddr2, true)
+ // Rx an RA from another router (lladdr3) with non-zero lifetime.
+ const l3LifetimeSeconds = 6
+ e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr3, l3LifetimeSeconds))
+ expectRouterEvent(llAddr3, true)
- // Rx an RA from lladdr2 with zero lifetime. It should be invalidated.
- e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr2, 0))
- expectRouterEvent(llAddr2, false)
+ // Rx an RA from lladdr2 with lesser lifetime.
+ const l2LifetimeSeconds = 2
+ e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr2, l2LifetimeSeconds))
+ select {
+ case <-ndpDisp.routerC:
+ t.Fatal("Should not receive a router event when updating lifetimes for known routers")
+ default:
+ }
- // Wait for lladdr3's router invalidation job to execute. The lifetime
- // of the router should have been updated to the most recent (smaller)
- // lifetime.
- //
- // Wait for the normal lifetime plus an extra bit for the
- // router to get invalidated. If we don't get an invalidation
- // event after this time, then something is wrong.
- expectAsyncRouterInvalidationEvent(llAddr3, l3LifetimeSeconds*time.Second+defaultAsyncPositiveEventTimeout)
+ // Wait for lladdr2's router invalidation job to execute. The lifetime
+ // of the router should have been updated to the most recent (smaller)
+ // lifetime.
+ //
+ // Wait for the normal lifetime plus an extra bit for the
+ // router to get invalidated. If we don't get an invalidation
+ // event after this time, then something is wrong.
+ expectAsyncRouterInvalidationEvent(llAddr2, l2LifetimeSeconds*time.Second+defaultAsyncPositiveEventTimeout)
+
+ // Rx an RA from lladdr2 with huge lifetime.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr2, 1000))
+ expectRouterEvent(llAddr2, true)
+
+ // Rx an RA from lladdr2 with zero lifetime. It should be invalidated.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr2, 0))
+ expectRouterEvent(llAddr2, false)
+
+ // Wait for lladdr3's router invalidation job to execute. The lifetime
+ // of the router should have been updated to the most recent (smaller)
+ // lifetime.
+ //
+ // Wait for the normal lifetime plus an extra bit for the
+ // router to get invalidated. If we don't get an invalidation
+ // event after this time, then something is wrong.
+ expectAsyncRouterInvalidationEvent(llAddr3, l3LifetimeSeconds*time.Second+defaultAsyncPositiveEventTimeout)
+ })
}
// TestRouterDiscoveryMaxRouters tests that only
@@ -1347,7 +1498,7 @@ func TestRouterDiscoveryMaxRouters(t *testing.T) {
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
NDPConfigs: ipv6.NDPConfigurations{
- HandleRAs: true,
+ HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled,
DiscoverDefaultRouters: true,
},
NDPDisp: &ndpDisp,
@@ -1386,57 +1537,6 @@ func TestRouterDiscoveryMaxRouters(t *testing.T) {
}
}
-// TestNoPrefixDiscovery tests that prefix discovery will not be performed if
-// configured not to.
-func TestNoPrefixDiscovery(t *testing.T) {
- prefix := tcpip.AddressWithPrefix{
- Address: tcpip.Address("\x01\x02\x03\x04\x05\x06\x07\x08\x00\x00\x00\x00\x00\x00\x00\x00"),
- PrefixLen: 64,
- }
-
- // Being configured to discover prefixes means handle and
- // discover are set to true and forwarding is set to false.
- // This tests all possible combinations of the configurations,
- // except for the configuration where handle = true, discover =
- // true and forwarding = false (the required configuration to do
- // prefix discovery) - that will done in other tests.
- for i := 0; i < 7; i++ {
- handle := i&1 != 0
- discover := i&2 != 0
- forwarding := i&4 == 0
-
- t.Run(fmt.Sprintf("HandleRAs(%t), DiscoverOnLinkPrefixes(%t), Forwarding(%t)", handle, discover, forwarding), func(t *testing.T) {
- ndpDisp := ndpDispatcher{
- prefixC: make(chan ndpPrefixEvent, 1),
- }
- e := channel.New(0, 1280, linkAddr1)
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
- NDPConfigs: ipv6.NDPConfigurations{
- HandleRAs: handle,
- DiscoverOnLinkPrefixes: discover,
- },
- NDPDisp: &ndpDisp,
- })},
- })
- s.SetForwarding(ipv6.ProtocolNumber, forwarding)
-
- if err := s.CreateNIC(1, e); err != nil {
- t.Fatalf("CreateNIC(1) = %s", err)
- }
-
- // Rx an RA with prefix with non-zero lifetime.
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, false, 10, 0))
-
- select {
- case <-ndpDisp.prefixC:
- t.Fatal("unexpectedly discovered a prefix when configured not to")
- default:
- }
- })
- }
-}
-
// Check e to make sure that the event is for prefix on nic with ID 1, and the
// discovered flag set to discovered.
func checkPrefixEvent(e ndpPrefixEvent, prefix tcpip.Subnet, discovered bool) string {
@@ -1455,8 +1555,7 @@ func TestPrefixDiscoveryDispatcherNoRemember(t *testing.T) {
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
NDPConfigs: ipv6.NDPConfigurations{
- HandleRAs: true,
- DiscoverDefaultRouters: false,
+ HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled,
DiscoverOnLinkPrefixes: true,
},
NDPDisp: &ndpDisp,
@@ -1494,87 +1593,93 @@ func TestPrefixDiscovery(t *testing.T) {
prefix2, subnet2, _ := prefixSubnetAddr(1, "")
prefix3, subnet3, _ := prefixSubnetAddr(2, "")
- ndpDisp := ndpDispatcher{
- prefixC: make(chan ndpPrefixEvent, 1),
- rememberPrefix: true,
- }
- e := channel.New(0, 1280, linkAddr1)
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
- NDPConfigs: ipv6.NDPConfigurations{
- HandleRAs: true,
- DiscoverOnLinkPrefixes: true,
- },
- NDPDisp: &ndpDisp,
- })},
- })
+ testWithRAs(t, func(t *testing.T, handleRAs ipv6.HandleRAsConfiguration, forwarding bool) {
+ ndpDisp := ndpDispatcher{
+ prefixC: make(chan ndpPrefixEvent, 1),
+ rememberPrefix: true,
+ }
+ e := channel.New(0, 1280, linkAddr1)
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
+ NDPConfigs: ipv6.NDPConfigurations{
+ HandleRAs: handleRAs,
+ DiscoverOnLinkPrefixes: true,
+ },
+ NDPDisp: &ndpDisp,
+ })},
+ })
- if err := s.CreateNIC(1, e); err != nil {
- t.Fatalf("CreateNIC(1) = %s", err)
- }
+ if err := s.CreateNIC(1, e); err != nil {
+ t.Fatalf("CreateNIC(1) = %s", err)
+ }
- expectPrefixEvent := func(prefix tcpip.Subnet, discovered bool) {
- t.Helper()
+ expectPrefixEvent := func(prefix tcpip.Subnet, discovered bool) {
+ t.Helper()
- select {
- case e := <-ndpDisp.prefixC:
- if diff := checkPrefixEvent(e, prefix, discovered); diff != "" {
- t.Errorf("prefix event mismatch (-want +got):\n%s", diff)
+ select {
+ case e := <-ndpDisp.prefixC:
+ if diff := checkPrefixEvent(e, prefix, discovered); diff != "" {
+ t.Errorf("prefix event mismatch (-want +got):\n%s", diff)
+ }
+ default:
+ t.Fatal("expected prefix discovery event")
}
- default:
- t.Fatal("expected prefix discovery event")
}
- }
- // Receive an RA with prefix1 in an NDP Prefix Information option (PI)
- // with zero valid lifetime.
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, false, 0, 0))
- select {
- case <-ndpDisp.prefixC:
- t.Fatal("unexpectedly discovered a prefix with 0 lifetime")
- default:
- }
+ if err := s.SetForwarding(ipv6.ProtocolNumber, forwarding); err != nil {
+ t.Fatalf("SetForwarding(%d, %t): %s", ipv6.ProtocolNumber, forwarding, err)
+ }
- // Receive an RA with prefix1 in an NDP Prefix Information option (PI)
- // with non-zero lifetime.
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, false, 100, 0))
- expectPrefixEvent(subnet1, true)
+ // Receive an RA with prefix1 in an NDP Prefix Information option (PI)
+ // with zero valid lifetime.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, false, 0, 0))
+ select {
+ case <-ndpDisp.prefixC:
+ t.Fatal("unexpectedly discovered a prefix with 0 lifetime")
+ default:
+ }
- // Receive an RA with prefix2 in a PI.
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, false, 100, 0))
- expectPrefixEvent(subnet2, true)
+ // Receive an RA with prefix1 in an NDP Prefix Information option (PI)
+ // with non-zero lifetime.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, false, 100, 0))
+ expectPrefixEvent(subnet1, true)
- // Receive an RA with prefix3 in a PI.
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix3, true, false, 100, 0))
- expectPrefixEvent(subnet3, true)
+ // Receive an RA with prefix2 in a PI.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, false, 100, 0))
+ expectPrefixEvent(subnet2, true)
- // Receive an RA with prefix1 in a PI with lifetime = 0.
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, false, 0, 0))
- expectPrefixEvent(subnet1, false)
+ // Receive an RA with prefix3 in a PI.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix3, true, false, 100, 0))
+ expectPrefixEvent(subnet3, true)
- // Receive an RA with prefix2 in a PI with lesser lifetime.
- lifetime := uint32(2)
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, false, lifetime, 0))
- select {
- case <-ndpDisp.prefixC:
- t.Fatal("unexpectedly received prefix event when updating lifetime")
- default:
- }
+ // Receive an RA with prefix1 in a PI with lifetime = 0.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, false, 0, 0))
+ expectPrefixEvent(subnet1, false)
- // Wait for prefix2's most recent invalidation job plus some buffer to
- // expire.
- select {
- case e := <-ndpDisp.prefixC:
- if diff := checkPrefixEvent(e, subnet2, false); diff != "" {
- t.Errorf("prefix event mismatch (-want +got):\n%s", diff)
+ // Receive an RA with prefix2 in a PI with lesser lifetime.
+ lifetime := uint32(2)
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, false, lifetime, 0))
+ select {
+ case <-ndpDisp.prefixC:
+ t.Fatal("unexpectedly received prefix event when updating lifetime")
+ default:
}
- case <-time.After(time.Duration(lifetime)*time.Second + defaultAsyncPositiveEventTimeout):
- t.Fatal("timed out waiting for prefix discovery event")
- }
- // Receive RA to invalidate prefix3.
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix3, true, false, 0, 0))
- expectPrefixEvent(subnet3, false)
+ // Wait for prefix2's most recent invalidation job plus some buffer to
+ // expire.
+ select {
+ case e := <-ndpDisp.prefixC:
+ if diff := checkPrefixEvent(e, subnet2, false); diff != "" {
+ t.Errorf("prefix event mismatch (-want +got):\n%s", diff)
+ }
+ case <-time.After(time.Duration(lifetime)*time.Second + defaultAsyncPositiveEventTimeout):
+ t.Fatal("timed out waiting for prefix discovery event")
+ }
+
+ // Receive RA to invalidate prefix3.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix3, true, false, 0, 0))
+ expectPrefixEvent(subnet3, false)
+ })
}
func TestPrefixDiscoveryWithInfiniteLifetime(t *testing.T) {
@@ -1590,7 +1695,7 @@ func TestPrefixDiscoveryWithInfiniteLifetime(t *testing.T) {
}()
prefix := tcpip.AddressWithPrefix{
- Address: tcpip.Address("\x01\x02\x03\x04\x05\x06\x07\x08\x00\x00\x00\x00\x00\x00\x00\x00"),
+ Address: testutil.MustParse6("102:304:506:708::"),
PrefixLen: 64,
}
subnet := prefix.Subnet()
@@ -1603,7 +1708,7 @@ func TestPrefixDiscoveryWithInfiniteLifetime(t *testing.T) {
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
NDPConfigs: ipv6.NDPConfigurations{
- HandleRAs: true,
+ HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled,
DiscoverOnLinkPrefixes: true,
},
NDPDisp: &ndpDisp,
@@ -1688,7 +1793,7 @@ func TestPrefixDiscoveryMaxOnLinkPrefixes(t *testing.T) {
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
NDPConfigs: ipv6.NDPConfigurations{
- HandleRAs: true,
+ HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled,
DiscoverDefaultRouters: false,
DiscoverOnLinkPrefixes: true,
},
@@ -1753,53 +1858,6 @@ func containsV6Addr(list []tcpip.ProtocolAddress, item tcpip.AddressWithPrefix)
return containsAddr(list, protocolAddress)
}
-// TestNoAutoGenAddr tests that SLAAC is not performed when configured not to.
-func TestNoAutoGenAddr(t *testing.T) {
- prefix, _, _ := prefixSubnetAddr(0, "")
-
- // Being configured to auto-generate addresses means handle and
- // autogen are set to true and forwarding is set to false.
- // This tests all possible combinations of the configurations,
- // except for the configuration where handle = true, autogen =
- // true and forwarding = false (the required configuration to do
- // SLAAC) - that will done in other tests.
- for i := 0; i < 7; i++ {
- handle := i&1 != 0
- autogen := i&2 != 0
- forwarding := i&4 == 0
-
- t.Run(fmt.Sprintf("HandleRAs(%t), AutoGenAddr(%t), Forwarding(%t)", handle, autogen, forwarding), func(t *testing.T) {
- ndpDisp := ndpDispatcher{
- autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1),
- }
- e := channel.New(0, 1280, linkAddr1)
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
- NDPConfigs: ipv6.NDPConfigurations{
- HandleRAs: handle,
- AutoGenGlobalAddresses: autogen,
- },
- NDPDisp: &ndpDisp,
- })},
- })
- s.SetForwarding(ipv6.ProtocolNumber, forwarding)
-
- if err := s.CreateNIC(1, e); err != nil {
- t.Fatalf("CreateNIC(1) = %s", err)
- }
-
- // Rx an RA with prefix with non-zero lifetime.
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, false, true, 10, 0))
-
- select {
- case <-ndpDisp.autoGenAddrC:
- t.Fatal("unexpectedly auto-generated an address when configured not to")
- default:
- }
- })
- }
-}
-
// Check e to make sure that the event is for addr on nic with ID 1, and the
// event type is set to eventType.
func checkAutoGenAddrEvent(e ndpAutoGenAddrEvent, addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) string {
@@ -1808,7 +1866,7 @@ func checkAutoGenAddrEvent(e ndpAutoGenAddrEvent, addr tcpip.AddressWithPrefix,
// TestAutoGenAddr tests that an address is properly generated and invalidated
// when configured to do so.
-func TestAutoGenAddr2(t *testing.T) {
+func TestAutoGenAddr(t *testing.T) {
const newMinVL = 2
newMinVLDuration := newMinVL * time.Second
saved := ipv6.MinPrefixInformationValidLifetimeForUpdate
@@ -1820,96 +1878,102 @@ func TestAutoGenAddr2(t *testing.T) {
prefix1, _, addr1 := prefixSubnetAddr(0, linkAddr1)
prefix2, _, addr2 := prefixSubnetAddr(1, linkAddr1)
- ndpDisp := ndpDispatcher{
- autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1),
- }
- e := channel.New(0, 1280, linkAddr1)
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
- NDPConfigs: ipv6.NDPConfigurations{
- HandleRAs: true,
- AutoGenGlobalAddresses: true,
- },
- NDPDisp: &ndpDisp,
- })},
- })
+ testWithRAs(t, func(t *testing.T, handleRAs ipv6.HandleRAsConfiguration, forwarding bool) {
+ ndpDisp := ndpDispatcher{
+ autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1),
+ }
+ e := channel.New(0, 1280, linkAddr1)
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
+ NDPConfigs: ipv6.NDPConfigurations{
+ HandleRAs: handleRAs,
+ AutoGenGlobalAddresses: true,
+ },
+ NDPDisp: &ndpDisp,
+ })},
+ })
- if err := s.CreateNIC(1, e); err != nil {
- t.Fatalf("CreateNIC(1) = %s", err)
- }
+ if err := s.SetForwarding(ipv6.ProtocolNumber, forwarding); err != nil {
+ t.Fatalf("SetForwarding(%d, %t): %s", ipv6.ProtocolNumber, forwarding, err)
+ }
- expectAutoGenAddrEvent := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) {
- t.Helper()
+ if err := s.CreateNIC(1, e); err != nil {
+ t.Fatalf("CreateNIC(1) = %s", err)
+ }
- select {
- case e := <-ndpDisp.autoGenAddrC:
- if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" {
- t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
+ expectAutoGenAddrEvent := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) {
+ t.Helper()
+
+ select {
+ case e := <-ndpDisp.autoGenAddrC:
+ if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" {
+ t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
+ }
+ default:
+ t.Fatal("expected addr auto gen event")
}
- default:
- t.Fatal("expected addr auto gen event")
}
- }
- // Receive an RA with prefix1 in an NDP Prefix Information option (PI)
- // with zero valid lifetime.
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 0, 0))
- select {
- case <-ndpDisp.autoGenAddrC:
- t.Fatal("unexpectedly auto-generated an address with 0 lifetime")
- default:
- }
+ // Receive an RA with prefix1 in an NDP Prefix Information option (PI)
+ // with zero valid lifetime.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 0, 0))
+ select {
+ case <-ndpDisp.autoGenAddrC:
+ t.Fatal("unexpectedly auto-generated an address with 0 lifetime")
+ default:
+ }
- // Receive an RA with prefix1 in an NDP Prefix Information option (PI)
- // with non-zero lifetime.
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 100, 0))
- expectAutoGenAddrEvent(addr1, newAddr)
- if !containsV6Addr(s.NICInfo()[1].ProtocolAddresses, addr1) {
- t.Fatalf("Should have %s in the list of addresses", addr1)
- }
+ // Receive an RA with prefix1 in an NDP Prefix Information option (PI)
+ // with non-zero lifetime.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 100, 0))
+ expectAutoGenAddrEvent(addr1, newAddr)
+ if !containsV6Addr(s.NICInfo()[1].ProtocolAddresses, addr1) {
+ t.Fatalf("Should have %s in the list of addresses", addr1)
+ }
- // Receive an RA with prefix2 in an NDP Prefix Information option (PI)
- // with preferred lifetime > valid lifetime
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 5, 6))
- select {
- case <-ndpDisp.autoGenAddrC:
- t.Fatal("unexpectedly auto-generated an address with preferred lifetime > valid lifetime")
- default:
- }
+ // Receive an RA with prefix2 in an NDP Prefix Information option (PI)
+ // with preferred lifetime > valid lifetime
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 5, 6))
+ select {
+ case <-ndpDisp.autoGenAddrC:
+ t.Fatal("unexpectedly auto-generated an address with preferred lifetime > valid lifetime")
+ default:
+ }
- // Receive an RA with prefix2 in a PI.
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, 100, 0))
- expectAutoGenAddrEvent(addr2, newAddr)
- if !containsV6Addr(s.NICInfo()[1].ProtocolAddresses, addr1) {
- t.Fatalf("Should have %s in the list of addresses", addr1)
- }
- if !containsV6Addr(s.NICInfo()[1].ProtocolAddresses, addr2) {
- t.Fatalf("Should have %s in the list of addresses", addr2)
- }
+ // Receive an RA with prefix2 in a PI.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, 100, 0))
+ expectAutoGenAddrEvent(addr2, newAddr)
+ if !containsV6Addr(s.NICInfo()[1].ProtocolAddresses, addr1) {
+ t.Fatalf("Should have %s in the list of addresses", addr1)
+ }
+ if !containsV6Addr(s.NICInfo()[1].ProtocolAddresses, addr2) {
+ t.Fatalf("Should have %s in the list of addresses", addr2)
+ }
- // Refresh valid lifetime for addr of prefix1.
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, newMinVL, 0))
- select {
- case <-ndpDisp.autoGenAddrC:
- t.Fatal("unexpectedly auto-generated an address when we already have an address for a prefix")
- default:
- }
+ // Refresh valid lifetime for addr of prefix1.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, newMinVL, 0))
+ select {
+ case <-ndpDisp.autoGenAddrC:
+ t.Fatal("unexpectedly auto-generated an address when we already have an address for a prefix")
+ default:
+ }
- // Wait for addr of prefix1 to be invalidated.
- select {
- case e := <-ndpDisp.autoGenAddrC:
- if diff := checkAutoGenAddrEvent(e, addr1, invalidatedAddr); diff != "" {
- t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
+ // Wait for addr of prefix1 to be invalidated.
+ select {
+ case e := <-ndpDisp.autoGenAddrC:
+ if diff := checkAutoGenAddrEvent(e, addr1, invalidatedAddr); diff != "" {
+ t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
+ }
+ case <-time.After(newMinVLDuration + defaultAsyncPositiveEventTimeout):
+ t.Fatal("timed out waiting for addr auto gen event")
}
- case <-time.After(newMinVLDuration + defaultAsyncPositiveEventTimeout):
- t.Fatal("timed out waiting for addr auto gen event")
- }
- if containsV6Addr(s.NICInfo()[1].ProtocolAddresses, addr1) {
- t.Fatalf("Should not have %s in the list of addresses", addr1)
- }
- if !containsV6Addr(s.NICInfo()[1].ProtocolAddresses, addr2) {
- t.Fatalf("Should have %s in the list of addresses", addr2)
- }
+ if containsV6Addr(s.NICInfo()[1].ProtocolAddresses, addr1) {
+ t.Fatalf("Should not have %s in the list of addresses", addr1)
+ }
+ if !containsV6Addr(s.NICInfo()[1].ProtocolAddresses, addr2) {
+ t.Fatalf("Should have %s in the list of addresses", addr2)
+ }
+ })
}
func addressCheck(addrs []tcpip.ProtocolAddress, containList, notContainList []tcpip.AddressWithPrefix) string {
@@ -1997,7 +2061,7 @@ func TestAutoGenTempAddr(t *testing.T) {
RetransmitTimer: test.retransmitTimer,
},
NDPConfigs: ipv6.NDPConfigurations{
- HandleRAs: true,
+ HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled,
AutoGenGlobalAddresses: true,
AutoGenTempGlobalAddresses: true,
},
@@ -2298,7 +2362,7 @@ func TestNoAutoGenTempAddrWithoutStableAddr(t *testing.T) {
RetransmitTimer: retransmitTimer,
},
NDPConfigs: ipv6.NDPConfigurations{
- HandleRAs: true,
+ HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled,
AutoGenGlobalAddresses: true,
AutoGenTempGlobalAddresses: true,
},
@@ -2385,7 +2449,7 @@ func TestAutoGenTempAddrRegen(t *testing.T) {
}
e := channel.New(0, 1280, linkAddr1)
ndpConfigs := ipv6.NDPConfigurations{
- HandleRAs: true,
+ HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled,
AutoGenGlobalAddresses: true,
AutoGenTempGlobalAddresses: true,
RegenAdvanceDuration: newMinVLDuration - regenAfter,
@@ -2534,7 +2598,7 @@ func TestAutoGenTempAddrRegenJobUpdates(t *testing.T) {
}
e := channel.New(0, 1280, linkAddr1)
ndpConfigs := ipv6.NDPConfigurations{
- HandleRAs: true,
+ HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled,
AutoGenGlobalAddresses: true,
AutoGenTempGlobalAddresses: true,
RegenAdvanceDuration: newMinVLDuration - regenAfter,
@@ -2735,7 +2799,7 @@ func TestMixedSLAACAddrConflictRegen(t *testing.T) {
Clock: clock,
NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
NDPConfigs: ipv6.NDPConfigurations{
- HandleRAs: true,
+ HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled,
AutoGenGlobalAddresses: true,
AutoGenTempGlobalAddresses: test.tempAddrs,
AutoGenAddressConflictRetries: 1,
@@ -2880,7 +2944,7 @@ func stackAndNdpDispatcherWithDefaultRoute(t *testing.T, nicID tcpip.NICID) (*nd
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
NDPConfigs: ipv6.NDPConfigurations{
- HandleRAs: true,
+ HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled,
AutoGenGlobalAddresses: true,
},
NDPDisp: ndpDisp,
@@ -3347,7 +3411,7 @@ func TestAutoGenAddrFiniteToInfiniteToFiniteVL(t *testing.T) {
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
NDPConfigs: ipv6.NDPConfigurations{
- HandleRAs: true,
+ HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled,
AutoGenGlobalAddresses: true,
},
NDPDisp: &ndpDisp,
@@ -3490,7 +3554,7 @@ func TestAutoGenAddrValidLifetimeUpdates(t *testing.T) {
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
NDPConfigs: ipv6.NDPConfigurations{
- HandleRAs: true,
+ HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled,
AutoGenGlobalAddresses: true,
},
NDPDisp: &ndpDisp,
@@ -3557,7 +3621,7 @@ func TestAutoGenAddrRemoval(t *testing.T) {
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
NDPConfigs: ipv6.NDPConfigurations{
- HandleRAs: true,
+ HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled,
AutoGenGlobalAddresses: true,
},
NDPDisp: &ndpDisp,
@@ -3723,7 +3787,7 @@ func TestAutoGenAddrStaticConflict(t *testing.T) {
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
NDPConfigs: ipv6.NDPConfigurations{
- HandleRAs: true,
+ HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled,
AutoGenGlobalAddresses: true,
},
NDPDisp: &ndpDisp,
@@ -3805,7 +3869,7 @@ func TestAutoGenAddrWithOpaqueIID(t *testing.T) {
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
NDPConfigs: ipv6.NDPConfigurations{
- HandleRAs: true,
+ HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled,
AutoGenGlobalAddresses: true,
},
NDPDisp: &ndpDisp,
@@ -3969,7 +4033,7 @@ func TestAutoGenAddrInResponseToDADConflicts(t *testing.T) {
{
name: "Global address",
ndpConfigs: ipv6.NDPConfigurations{
- HandleRAs: true,
+ HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled,
AutoGenGlobalAddresses: true,
},
prepareFn: func(_ *testing.T, _ *ndpDispatcher, e *channel.Endpoint, _ []byte) []tcpip.AddressWithPrefix {
@@ -3996,7 +4060,7 @@ func TestAutoGenAddrInResponseToDADConflicts(t *testing.T) {
{
name: "Temporary address",
ndpConfigs: ipv6.NDPConfigurations{
- HandleRAs: true,
+ HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled,
AutoGenGlobalAddresses: true,
AutoGenTempGlobalAddresses: true,
},
@@ -4146,7 +4210,7 @@ func TestAutoGenAddrWithEUI64IIDNoDADRetries(t *testing.T) {
{
name: "Global address",
ndpConfigs: ipv6.NDPConfigurations{
- HandleRAs: true,
+ HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled,
AutoGenGlobalAddresses: true,
AutoGenAddressConflictRetries: maxRetries,
},
@@ -4274,7 +4338,7 @@ func TestAutoGenAddrContinuesLifetimesAfterRetry(t *testing.T) {
RetransmitTimer: retransmitTimer,
},
NDPConfigs: ipv6.NDPConfigurations{
- HandleRAs: true,
+ HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled,
AutoGenGlobalAddresses: true,
AutoGenAddressConflictRetries: maxRetries,
},
@@ -4480,7 +4544,7 @@ func TestNDPRecursiveDNSServerDispatch(t *testing.T) {
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
NDPConfigs: ipv6.NDPConfigurations{
- HandleRAs: true,
+ HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled,
},
NDPDisp: &ndpDisp,
})},
@@ -4531,7 +4595,7 @@ func TestNDPDNSSearchListDispatch(t *testing.T) {
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
NDPConfigs: ipv6.NDPConfigurations{
- HandleRAs: true,
+ HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled,
},
NDPDisp: &ndpDisp,
})},
@@ -4625,8 +4689,110 @@ func TestNDPDNSSearchListDispatch(t *testing.T) {
}
}
-// TestCleanupNDPState tests that all discovered routers and prefixes, and
-// auto-generated addresses are invalidated when a NIC becomes a router.
+func TestNoCleanupNDPStateWhenForwardingEnabled(t *testing.T) {
+ const (
+ lifetimeSeconds = 999
+ nicID = 1
+ )
+
+ ndpDisp := ndpDispatcher{
+ routerC: make(chan ndpRouterEvent, 1),
+ rememberRouter: true,
+ prefixC: make(chan ndpPrefixEvent, 1),
+ rememberPrefix: true,
+ autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1),
+ }
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
+ AutoGenLinkLocal: true,
+ NDPConfigs: ipv6.NDPConfigurations{
+ HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled,
+ DiscoverDefaultRouters: true,
+ DiscoverOnLinkPrefixes: true,
+ AutoGenGlobalAddresses: true,
+ },
+ NDPDisp: &ndpDisp,
+ })},
+ })
+
+ e1 := channel.New(0, header.IPv6MinimumMTU, linkAddr1)
+ if err := s.CreateNIC(nicID, e1); err != nil {
+ t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
+ }
+ llAddr := tcpip.AddressWithPrefix{Address: llAddr1, PrefixLen: header.IPv6LinkLocalPrefix.PrefixLen}
+ select {
+ case e := <-ndpDisp.autoGenAddrC:
+ if diff := checkAutoGenAddrEvent(e, llAddr, newAddr); diff != "" {
+ t.Errorf("auto-gen addr mismatch (-want +got):\n%s", diff)
+ }
+ default:
+ t.Errorf("expected auto-gen addr event for %s on NIC(%d)", llAddr, nicID)
+ }
+
+ prefix, subnet, addr := prefixSubnetAddr(0, linkAddr1)
+ e1.InjectInbound(
+ header.IPv6ProtocolNumber,
+ raBufWithPI(
+ llAddr3,
+ lifetimeSeconds,
+ prefix,
+ true, /* onLink */
+ true, /* auto */
+ lifetimeSeconds,
+ lifetimeSeconds,
+ ),
+ )
+ select {
+ case e := <-ndpDisp.routerC:
+ if diff := checkRouterEvent(e, llAddr3, true /* discovered */); diff != "" {
+ t.Errorf("router event mismatch (-want +got):\n%s", diff)
+ }
+ default:
+ t.Errorf("expected router event for %s on NIC(%d)", llAddr3, nicID)
+ }
+ select {
+ case e := <-ndpDisp.prefixC:
+ if diff := checkPrefixEvent(e, subnet, true /* discovered */); diff != "" {
+ t.Errorf("router event mismatch (-want +got):\n%s", diff)
+ }
+ default:
+ t.Errorf("expected prefix event for %s on NIC(%d)", prefix, nicID)
+ }
+ select {
+ case e := <-ndpDisp.autoGenAddrC:
+ if diff := checkAutoGenAddrEvent(e, addr, newAddr); diff != "" {
+ t.Errorf("auto-gen addr mismatch (-want +got):\n%s", diff)
+ }
+ default:
+ t.Errorf("expected auto-gen addr event for %s on NIC(%d)", addr, nicID)
+ }
+
+ // Enabling or disabling forwarding should not invalidate discovered prefixes
+ // or routers, or auto-generated address.
+ for _, forwarding := range [...]bool{true, false} {
+ t.Run(fmt.Sprintf("Transition forwarding to %t", forwarding), func(t *testing.T) {
+ if err := s.SetForwarding(ipv6.ProtocolNumber, forwarding); err != nil {
+ t.Fatalf("SetForwarding(%d, %t): %s", ipv6.ProtocolNumber, forwarding, err)
+ }
+ select {
+ case e := <-ndpDisp.routerC:
+ t.Errorf("unexpected router event = %#v", e)
+ default:
+ }
+ select {
+ case e := <-ndpDisp.prefixC:
+ t.Errorf("unexpected prefix event = %#v", e)
+ default:
+ }
+ select {
+ case e := <-ndpDisp.autoGenAddrC:
+ t.Errorf("unexpected auto-gen addr event = %#v", e)
+ default:
+ }
+ })
+ }
+}
+
func TestCleanupNDPState(t *testing.T) {
const (
lifetimeSeconds = 5
@@ -4655,18 +4821,6 @@ func TestCleanupNDPState(t *testing.T) {
maxAutoGenAddrEvents int
skipFinalAddrCheck bool
}{
- // A NIC should still keep its auto-generated link-local address when
- // becoming a router.
- {
- name: "Enable forwarding",
- cleanupFn: func(t *testing.T, s *stack.Stack) {
- t.Helper()
- s.SetForwarding(ipv6.ProtocolNumber, true)
- },
- keepAutoGenLinkLocal: true,
- maxAutoGenAddrEvents: 4,
- },
-
// A NIC should cleanup all NDP state when it is disabled.
{
name: "Disable NIC",
@@ -4718,7 +4872,7 @@ func TestCleanupNDPState(t *testing.T) {
NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
AutoGenLinkLocal: true,
NDPConfigs: ipv6.NDPConfigurations{
- HandleRAs: true,
+ HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled,
DiscoverDefaultRouters: true,
DiscoverOnLinkPrefixes: true,
AutoGenGlobalAddresses: true,
@@ -4991,7 +5145,7 @@ func TestDHCPv6ConfigurationFromNDPDA(t *testing.T) {
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
NDPConfigs: ipv6.NDPConfigurations{
- HandleRAs: true,
+ HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled,
},
NDPDisp: &ndpDisp,
})},
@@ -5182,96 +5336,127 @@ func TestRouterSolicitation(t *testing.T) {
},
}
+ subTests := []struct {
+ name string
+ handleRAs ipv6.HandleRAsConfiguration
+ afterFirstRS func(*testing.T, *stack.Stack)
+ }{
+ {
+ name: "Handle RAs when forwarding disabled",
+ handleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled,
+ afterFirstRS: func(*testing.T, *stack.Stack) {},
+ },
+
+ // Enabling forwarding when RAs are always configured to be handled
+ // should not stop router solicitations.
+ {
+ name: "Handle RAs always",
+ handleRAs: ipv6.HandlingRAsAlwaysEnabled,
+ afterFirstRS: func(t *testing.T, s *stack.Stack) {
+ if err := s.SetForwarding(ipv6.ProtocolNumber, true); err != nil {
+ t.Fatalf("SetForwarding(%d, true): %s", ipv6.ProtocolNumber, err)
+ }
+ },
+ },
+ }
+
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
- clock := faketime.NewManualClock()
- e := channelLinkWithHeaderLength{
- Endpoint: channel.New(int(test.maxRtrSolicit), 1280, test.linkAddr),
- headerLength: test.linkHeaderLen,
- }
- e.Endpoint.LinkEPCapabilities |= stack.CapabilityResolutionRequired
- waitForPkt := func(timeout time.Duration) {
- t.Helper()
+ for _, subTest := range subTests {
+ t.Run(subTest.name, func(t *testing.T) {
+ clock := faketime.NewManualClock()
+ e := channelLinkWithHeaderLength{
+ Endpoint: channel.New(int(test.maxRtrSolicit), 1280, test.linkAddr),
+ headerLength: test.linkHeaderLen,
+ }
+ e.Endpoint.LinkEPCapabilities |= stack.CapabilityResolutionRequired
+ waitForPkt := func(timeout time.Duration) {
+ t.Helper()
+
+ clock.Advance(timeout)
+ p, ok := e.Read()
+ if !ok {
+ t.Fatal("expected router solicitation packet")
+ }
- clock.Advance(timeout)
- p, ok := e.Read()
- if !ok {
- t.Fatal("expected router solicitation packet")
- }
+ if p.Proto != header.IPv6ProtocolNumber {
+ t.Fatalf("got Proto = %d, want = %d", p.Proto, header.IPv6ProtocolNumber)
+ }
- if p.Proto != header.IPv6ProtocolNumber {
- t.Fatalf("got Proto = %d, want = %d", p.Proto, header.IPv6ProtocolNumber)
- }
+ // Make sure the right remote link address is used.
+ if want := header.EthernetAddressFromMulticastIPv6Address(header.IPv6AllRoutersLinkLocalMulticastAddress); p.Route.RemoteLinkAddress != want {
+ t.Errorf("got remote link address = %s, want = %s", p.Route.RemoteLinkAddress, want)
+ }
- // Make sure the right remote link address is used.
- if want := header.EthernetAddressFromMulticastIPv6Address(header.IPv6AllRoutersMulticastAddress); p.Route.RemoteLinkAddress != want {
- t.Errorf("got remote link address = %s, want = %s", p.Route.RemoteLinkAddress, want)
- }
+ checker.IPv6(t, stack.PayloadSince(p.Pkt.NetworkHeader()),
+ checker.SrcAddr(test.expectedSrcAddr),
+ checker.DstAddr(header.IPv6AllRoutersLinkLocalMulticastAddress),
+ checker.TTL(header.NDPHopLimit),
+ checker.NDPRS(checker.NDPRSOptions(test.expectedNDPOpts)),
+ )
- checker.IPv6(t, stack.PayloadSince(p.Pkt.NetworkHeader()),
- checker.SrcAddr(test.expectedSrcAddr),
- checker.DstAddr(header.IPv6AllRoutersMulticastAddress),
- checker.TTL(header.NDPHopLimit),
- checker.NDPRS(checker.NDPRSOptions(test.expectedNDPOpts)),
- )
+ if l, want := p.Pkt.AvailableHeaderBytes(), int(test.linkHeaderLen); l != want {
+ t.Errorf("got p.Pkt.AvailableHeaderBytes() = %d; want = %d", l, want)
+ }
+ }
+ waitForNothing := func(timeout time.Duration) {
+ t.Helper()
- if l, want := p.Pkt.AvailableHeaderBytes(), int(test.linkHeaderLen); l != want {
- t.Errorf("got p.Pkt.AvailableHeaderBytes() = %d; want = %d", l, want)
- }
- }
- waitForNothing := func(timeout time.Duration) {
- t.Helper()
+ clock.Advance(timeout)
+ if p, ok := e.Read(); ok {
+ t.Fatalf("unexpectedly got a packet = %#v", p)
+ }
+ }
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
+ NDPConfigs: ipv6.NDPConfigurations{
+ HandleRAs: subTest.handleRAs,
+ MaxRtrSolicitations: test.maxRtrSolicit,
+ RtrSolicitationInterval: test.rtrSolicitInt,
+ MaxRtrSolicitationDelay: test.maxRtrSolicitDelay,
+ },
+ })},
+ Clock: clock,
+ })
+ if err := s.CreateNIC(nicID, &e); err != nil {
+ t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
+ }
- clock.Advance(timeout)
- if p, ok := e.Read(); ok {
- t.Fatalf("unexpectedly got a packet = %#v", p)
- }
- }
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
- NDPConfigs: ipv6.NDPConfigurations{
- MaxRtrSolicitations: test.maxRtrSolicit,
- RtrSolicitationInterval: test.rtrSolicitInt,
- MaxRtrSolicitationDelay: test.maxRtrSolicitDelay,
- },
- })},
- Clock: clock,
- })
- if err := s.CreateNIC(nicID, &e); err != nil {
- t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
- }
+ if addr := test.nicAddr; addr != "" {
+ if err := s.AddAddress(nicID, header.IPv6ProtocolNumber, addr); err != nil {
+ t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, addr, err)
+ }
+ }
- if addr := test.nicAddr; addr != "" {
- if err := s.AddAddress(nicID, header.IPv6ProtocolNumber, addr); err != nil {
- t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, addr, err)
- }
- }
+ // Make sure each RS is sent at the right time.
+ remaining := test.maxRtrSolicit
+ if remaining > 0 {
+ waitForPkt(test.effectiveMaxRtrSolicitDelay)
+ remaining--
+ }
- // Make sure each RS is sent at the right time.
- remaining := test.maxRtrSolicit
- if remaining > 0 {
- waitForPkt(test.effectiveMaxRtrSolicitDelay)
- remaining--
- }
+ subTest.afterFirstRS(t, s)
- for ; remaining > 0; remaining-- {
- if test.effectiveRtrSolicitInt > defaultAsyncPositiveEventTimeout {
- waitForNothing(test.effectiveRtrSolicitInt - time.Nanosecond)
- waitForPkt(time.Nanosecond)
- } else {
- waitForPkt(test.effectiveRtrSolicitInt)
- }
- }
+ for ; remaining > 0; remaining-- {
+ if test.effectiveRtrSolicitInt > defaultAsyncPositiveEventTimeout {
+ waitForNothing(test.effectiveRtrSolicitInt - time.Nanosecond)
+ waitForPkt(time.Nanosecond)
+ } else {
+ waitForPkt(test.effectiveRtrSolicitInt)
+ }
+ }
- // Make sure no more RS.
- if test.effectiveRtrSolicitInt > test.effectiveMaxRtrSolicitDelay {
- waitForNothing(test.effectiveRtrSolicitInt)
- } else {
- waitForNothing(test.effectiveMaxRtrSolicitDelay)
- }
+ // Make sure no more RS.
+ if test.effectiveRtrSolicitInt > test.effectiveMaxRtrSolicitDelay {
+ waitForNothing(test.effectiveRtrSolicitInt)
+ } else {
+ waitForNothing(test.effectiveMaxRtrSolicitDelay)
+ }
- if got, want := s.Stats().ICMP.V6.PacketsSent.RouterSolicit.Value(), uint64(test.maxRtrSolicit); got != want {
- t.Fatalf("got sent RouterSolicit = %d, want = %d", got, want)
+ if got, want := s.Stats().ICMP.V6.PacketsSent.RouterSolicit.Value(), uint64(test.maxRtrSolicit); got != want {
+ t.Fatalf("got sent RouterSolicit = %d, want = %d", got, want)
+ }
+ })
}
})
}
@@ -5362,13 +5547,14 @@ func TestStopStartSolicitingRouters(t *testing.T) {
}
checker.IPv6(t, stack.PayloadSince(p.Pkt.NetworkHeader()),
checker.SrcAddr(header.IPv6Any),
- checker.DstAddr(header.IPv6AllRoutersMulticastAddress),
+ checker.DstAddr(header.IPv6AllRoutersLinkLocalMulticastAddress),
checker.TTL(header.NDPHopLimit),
checker.NDPRS())
}
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
NDPConfigs: ipv6.NDPConfigurations{
+ HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled,
MaxRtrSolicitations: maxRtrSolicitations,
RtrSolicitationInterval: interval,
MaxRtrSolicitationDelay: delay,
diff --git a/pkg/tcpip/stack/neighbor_cache_test.go b/pkg/tcpip/stack/neighbor_cache_test.go
index 48bb75e2f..9821a18d3 100644
--- a/pkg/tcpip/stack/neighbor_cache_test.go
+++ b/pkg/tcpip/stack/neighbor_cache_test.go
@@ -1556,7 +1556,7 @@ func TestNeighborCacheRetryResolution(t *testing.T) {
func BenchmarkCacheClear(b *testing.B) {
b.StopTimer()
config := DefaultNUDConfigurations()
- clock := &tcpip.StdClock{}
+ clock := tcpip.NewStdClock()
linkRes := newTestNeighborResolver(nil, config, clock)
linkRes.delay = 0
diff --git a/pkg/tcpip/stack/neighbor_entry_test.go b/pkg/tcpip/stack/neighbor_entry_test.go
index bb2b2d705..1d39ee73d 100644
--- a/pkg/tcpip/stack/neighbor_entry_test.go
+++ b/pkg/tcpip/stack/neighbor_entry_test.go
@@ -26,14 +26,13 @@ import (
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/faketime"
"gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/testutil"
)
const (
entryTestNetNumber tcpip.NetworkProtocolNumber = math.MaxUint32
entryTestNICID tcpip.NICID = 1
- entryTestAddr1 = tcpip.Address("\x00\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01")
- entryTestAddr2 = tcpip.Address("\x00\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02")
entryTestLinkAddr1 = tcpip.LinkAddress("\x0a\x00\x00\x00\x00\x01")
entryTestLinkAddr2 = tcpip.LinkAddress("\x0a\x00\x00\x00\x00\x02")
@@ -44,6 +43,11 @@ const (
entryTestNetDefaultMTU = 65536
)
+var (
+ entryTestAddr1 = testutil.MustParse6("a::1")
+ entryTestAddr2 = testutil.MustParse6("a::2")
+)
+
// runImmediatelyScheduledJobs runs all jobs scheduled to run at the current
// time.
func runImmediatelyScheduledJobs(clock *faketime.ManualClock) {
diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go
index ca15c0691..8d615500f 100644
--- a/pkg/tcpip/stack/nic.go
+++ b/pkg/tcpip/stack/nic.go
@@ -316,30 +316,30 @@ func (n *nic) IsLoopback() bool {
}
// WritePacket implements NetworkLinkEndpoint.
-func (n *nic) WritePacket(r *Route, gso *GSO, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) tcpip.Error {
- _, err := n.enqueuePacketBuffer(r, gso, protocol, pkt)
+func (n *nic) WritePacket(r *Route, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) tcpip.Error {
+ _, err := n.enqueuePacketBuffer(r, protocol, pkt)
return err
}
-func (n *nic) writePacketBuffer(r RouteInfo, gso *GSO, protocol tcpip.NetworkProtocolNumber, pkt pendingPacketBuffer) (int, tcpip.Error) {
+func (n *nic) writePacketBuffer(r RouteInfo, protocol tcpip.NetworkProtocolNumber, pkt pendingPacketBuffer) (int, tcpip.Error) {
switch pkt := pkt.(type) {
case *PacketBuffer:
- if err := n.writePacket(r, gso, protocol, pkt); err != nil {
+ if err := n.writePacket(r, protocol, pkt); err != nil {
return 0, err
}
return 1, nil
case *PacketBufferList:
- return n.writePackets(r, gso, protocol, *pkt)
+ return n.writePackets(r, protocol, *pkt)
default:
panic(fmt.Sprintf("unrecognized pending packet buffer type = %T", pkt))
}
}
-func (n *nic) enqueuePacketBuffer(r *Route, gso *GSO, protocol tcpip.NetworkProtocolNumber, pkt pendingPacketBuffer) (int, tcpip.Error) {
+func (n *nic) enqueuePacketBuffer(r *Route, protocol tcpip.NetworkProtocolNumber, pkt pendingPacketBuffer) (int, tcpip.Error) {
routeInfo, _, err := r.resolvedFields(nil)
switch err.(type) {
case nil:
- return n.writePacketBuffer(routeInfo, gso, protocol, pkt)
+ return n.writePacketBuffer(routeInfo, protocol, pkt)
case *tcpip.ErrWouldBlock:
// As per relevant RFCs, we should queue packets while we wait for link
// resolution to complete.
@@ -358,28 +358,27 @@ func (n *nic) enqueuePacketBuffer(r *Route, gso *GSO, protocol tcpip.NetworkProt
// SHOULD be limited to some small value. When a queue overflows, the new
// arrival SHOULD replace the oldest entry. Once address resolution
// completes, the node transmits any queued packets.
- return n.linkResQueue.enqueue(r, gso, protocol, pkt)
+ return n.linkResQueue.enqueue(r, protocol, pkt)
default:
return 0, err
}
}
// WritePacketToRemote implements NetworkInterface.
-func (n *nic) WritePacketToRemote(remoteLinkAddr tcpip.LinkAddress, gso *GSO, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) tcpip.Error {
+func (n *nic) WritePacketToRemote(remoteLinkAddr tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) tcpip.Error {
var r RouteInfo
r.NetProto = protocol
r.RemoteLinkAddress = remoteLinkAddr
- return n.writePacket(r, gso, protocol, pkt)
+ return n.writePacket(r, protocol, pkt)
}
-func (n *nic) writePacket(r RouteInfo, gso *GSO, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) tcpip.Error {
+func (n *nic) writePacket(r RouteInfo, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) tcpip.Error {
// WritePacket takes ownership of pkt, calculate numBytes first.
numBytes := pkt.Size()
pkt.EgressRoute = r
- pkt.GSOOptions = gso
pkt.NetworkProtocolNumber = protocol
- if err := n.LinkEndpoint.WritePacket(r, gso, protocol, pkt); err != nil {
+ if err := n.LinkEndpoint.WritePacket(r, protocol, pkt); err != nil {
return err
}
@@ -389,18 +388,17 @@ func (n *nic) writePacket(r RouteInfo, gso *GSO, protocol tcpip.NetworkProtocolN
}
// WritePackets implements NetworkLinkEndpoint.
-func (n *nic) WritePackets(r *Route, gso *GSO, pkts PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, tcpip.Error) {
- return n.enqueuePacketBuffer(r, gso, protocol, &pkts)
+func (n *nic) WritePackets(r *Route, pkts PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, tcpip.Error) {
+ return n.enqueuePacketBuffer(r, protocol, &pkts)
}
-func (n *nic) writePackets(r RouteInfo, gso *GSO, protocol tcpip.NetworkProtocolNumber, pkts PacketBufferList) (int, tcpip.Error) {
+func (n *nic) writePackets(r RouteInfo, protocol tcpip.NetworkProtocolNumber, pkts PacketBufferList) (int, tcpip.Error) {
for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() {
pkt.EgressRoute = r
- pkt.GSOOptions = gso
pkt.NetworkProtocolNumber = protocol
}
- writtenPackets, err := n.LinkEndpoint.WritePackets(r, gso, pkts, protocol)
+ writtenPackets, err := n.LinkEndpoint.WritePackets(r, pkts, protocol)
n.stats.Tx.Packets.IncrementBy(uint64(writtenPackets))
writtenBytes := 0
for i, pb := 0, pkts.Front(); i < writtenPackets && pb != nil; i, pb = i+1, pb.Next() {
diff --git a/pkg/tcpip/stack/nic_test.go b/pkg/tcpip/stack/nic_test.go
index c0f956e53..8a3005295 100644
--- a/pkg/tcpip/stack/nic_test.go
+++ b/pkg/tcpip/stack/nic_test.go
@@ -65,12 +65,12 @@ func (e *testIPv6Endpoint) MaxHeaderLength() uint16 {
}
// WritePacket implements NetworkEndpoint.WritePacket.
-func (*testIPv6Endpoint) WritePacket(*Route, *GSO, NetworkHeaderParams, *PacketBuffer) tcpip.Error {
+func (*testIPv6Endpoint) WritePacket(*Route, NetworkHeaderParams, *PacketBuffer) tcpip.Error {
return nil
}
// WritePackets implements NetworkEndpoint.WritePackets.
-func (*testIPv6Endpoint) WritePackets(*Route, *GSO, PacketBufferList, NetworkHeaderParams) (int, tcpip.Error) {
+func (*testIPv6Endpoint) WritePackets(*Route, PacketBufferList, NetworkHeaderParams) (int, tcpip.Error) {
// Our tests don't use this so we don't support it.
return 0, &tcpip.ErrNotSupported{}
}
diff --git a/pkg/tcpip/stack/packet_buffer.go b/pkg/tcpip/stack/packet_buffer.go
index 8f288675d..9527416cf 100644
--- a/pkg/tcpip/stack/packet_buffer.go
+++ b/pkg/tcpip/stack/packet_buffer.go
@@ -103,7 +103,7 @@ type PacketBuffer struct {
// The following fields are only set by the qdisc layer when the packet
// is added to a queue.
EgressRoute RouteInfo
- GSOOptions *GSO
+ GSOOptions GSO
// NatDone indicates if the packet has been manipulated as per NAT
// iptables rule.
@@ -299,9 +299,18 @@ func (pk *PacketBuffer) Network() header.Network {
// See PacketBuffer.Data for details about how a packet buffer holds an inbound
// packet.
func (pk *PacketBuffer) CloneToInbound() *PacketBuffer {
- return NewPacketBuffer(PacketBufferOptions{
+ newPk := NewPacketBuffer(PacketBufferOptions{
Data: buffer.NewVectorisedView(pk.Size(), pk.Views()),
})
+ // TODO(gvisor.dev/issue/5696): reimplement conntrack so that no need to
+ // maintain this flag in the packet. Currently conntrack needs this flag to
+ // tell if a noop connection should be inserted at Input hook. Once conntrack
+ // redefines the manipulation field as mutable, we won't need the special noop
+ // connection.
+ if pk.NatDone {
+ newPk.NatDone = true
+ }
+ return newPk
}
// headerInfo stores metadata about a header in a packet.
@@ -355,9 +364,10 @@ func (d PacketData) PullUp(size int) (buffer.View, bool) {
return d.pk.data.PullUp(size)
}
-// TrimFront removes count from the beginning of d. It panics if count >
-// d.Size().
-func (d PacketData) TrimFront(count int) {
+// DeleteFront removes count from the beginning of d. It panics if count >
+// d.Size(). All backing storage references after the front of the d are
+// invalidated.
+func (d PacketData) DeleteFront(count int) {
d.pk.data.TrimFront(count)
}
diff --git a/pkg/tcpip/stack/packet_buffer_test.go b/pkg/tcpip/stack/packet_buffer_test.go
index 6728370c3..bd4eb4fed 100644
--- a/pkg/tcpip/stack/packet_buffer_test.go
+++ b/pkg/tcpip/stack/packet_buffer_test.go
@@ -112,23 +112,13 @@ func TestPacketHeaderPush(t *testing.T) {
if got, want := pk.Size(), allHdrSize+len(test.data); got != want {
t.Errorf("After pk.Size() = %d, want %d", got, want)
}
- checkData(t, pk, test.data)
- checkViewEqual(t, "After pk.Views()", concatViews(pk.Views()...),
- concatViews(test.link, test.network, test.transport, test.data))
- // Check the after values for each header.
- checkPacketHeader(t, "After pk.LinkHeader", pk.LinkHeader(), test.link)
- checkPacketHeader(t, "After pk.NetworkHeader", pk.NetworkHeader(), test.network)
- checkPacketHeader(t, "After pk.TransportHeader", pk.TransportHeader(), test.transport)
- // Check the after values for PayloadSince.
- checkViewEqual(t, "After PayloadSince(LinkHeader)",
- PayloadSince(pk.LinkHeader()),
- concatViews(test.link, test.network, test.transport, test.data))
- checkViewEqual(t, "After PayloadSince(NetworkHeader)",
- PayloadSince(pk.NetworkHeader()),
- concatViews(test.network, test.transport, test.data))
- checkViewEqual(t, "After PayloadSince(TransportHeader)",
- PayloadSince(pk.TransportHeader()),
- concatViews(test.transport, test.data))
+ // Check the after state.
+ checkPacketContents(t, "After ", pk, packetContents{
+ link: test.link,
+ network: test.network,
+ transport: test.transport,
+ data: test.data,
+ })
})
}
}
@@ -199,29 +189,13 @@ func TestPacketHeaderConsume(t *testing.T) {
if got, want := pk.Size(), len(test.data); got != want {
t.Errorf("After pk.Size() = %d, want %d", got, want)
}
- // After state of pk.
- var (
- link = test.data[:test.link]
- network = test.data[test.link:][:test.network]
- transport = test.data[test.link+test.network:][:test.transport]
- payload = test.data[allHdrSize:]
- )
- checkData(t, pk, payload)
- checkViewEqual(t, "After pk.Views()", concatViews(pk.Views()...), test.data)
- // Check the after values for each header.
- checkPacketHeader(t, "After pk.LinkHeader", pk.LinkHeader(), link)
- checkPacketHeader(t, "After pk.NetworkHeader", pk.NetworkHeader(), network)
- checkPacketHeader(t, "After pk.TransportHeader", pk.TransportHeader(), transport)
- // Check the after values for PayloadSince.
- checkViewEqual(t, "After PayloadSince(LinkHeader)",
- PayloadSince(pk.LinkHeader()),
- concatViews(link, network, transport, payload))
- checkViewEqual(t, "After PayloadSince(NetworkHeader)",
- PayloadSince(pk.NetworkHeader()),
- concatViews(network, transport, payload))
- checkViewEqual(t, "After PayloadSince(TransportHeader)",
- PayloadSince(pk.TransportHeader()),
- concatViews(transport, payload))
+ // Check the after state of pk.
+ checkPacketContents(t, "After ", pk, packetContents{
+ link: test.data[:test.link],
+ network: test.data[test.link:][:test.network],
+ transport: test.data[test.link+test.network:][:test.transport],
+ data: test.data[allHdrSize:],
+ })
})
}
}
@@ -252,6 +226,39 @@ func TestPacketHeaderConsumeDataTooShort(t *testing.T) {
})
}
+// This is a very obscure use-case seen in the code that verifies packets
+// before sending them out. It tries to parse the headers to verify.
+// PacketHeader was initially not designed to mix Push() and Consume(), but it
+// works and it's been relied upon. Include a test here.
+func TestPacketHeaderPushConsumeMixed(t *testing.T) {
+ link := makeView(10)
+ network := makeView(20)
+ data := makeView(30)
+
+ initData := append([]byte(nil), network...)
+ initData = append(initData, data...)
+ pk := NewPacketBuffer(PacketBufferOptions{
+ ReserveHeaderBytes: len(link),
+ Data: buffer.NewViewFromBytes(initData).ToVectorisedView(),
+ })
+
+ // 1. Consume network header
+ gotNetwork, ok := pk.NetworkHeader().Consume(len(network))
+ if !ok {
+ t.Fatalf("pk.NetworkHeader().Consume(%d) = _, false; want _, true", len(network))
+ }
+ checkViewEqual(t, "gotNetwork", gotNetwork, network)
+
+ // 2. Push link header
+ copy(pk.LinkHeader().Push(len(link)), link)
+
+ checkPacketContents(t, "" /* prefix */, pk, packetContents{
+ link: link,
+ network: network,
+ data: data,
+ })
+}
+
func TestPacketHeaderPushCalledAtMostOnce(t *testing.T) {
const headerSize = 10
@@ -397,11 +404,11 @@ func TestPacketBufferData(t *testing.T) {
}
})
- // TrimFront
+ // DeleteFront
for _, n := range []int{1, len(tc.data)} {
- t.Run(fmt.Sprintf("TrimFront%d", n), func(t *testing.T) {
+ t.Run(fmt.Sprintf("DeleteFront%d", n), func(t *testing.T) {
pkt := tc.makePkt(t)
- pkt.Data().TrimFront(n)
+ pkt.Data().DeleteFront(n)
checkData(t, pkt, []byte(tc.data)[n:])
})
@@ -494,6 +501,37 @@ func TestPacketBufferData(t *testing.T) {
}
}
+type packetContents struct {
+ link buffer.View
+ network buffer.View
+ transport buffer.View
+ data buffer.View
+}
+
+func checkPacketContents(t *testing.T, prefix string, pk *PacketBuffer, want packetContents) {
+ t.Helper()
+ // Headers.
+ checkPacketHeader(t, prefix+"pk.LinkHeader", pk.LinkHeader(), want.link)
+ checkPacketHeader(t, prefix+"pk.NetworkHeader", pk.NetworkHeader(), want.network)
+ checkPacketHeader(t, prefix+"pk.TransportHeader", pk.TransportHeader(), want.transport)
+ // Data.
+ checkData(t, pk, want.data)
+ // Whole packet.
+ checkViewEqual(t, prefix+"pk.Views()",
+ concatViews(pk.Views()...),
+ concatViews(want.link, want.network, want.transport, want.data))
+ // PayloadSince.
+ checkViewEqual(t, prefix+"PayloadSince(LinkHeader)",
+ PayloadSince(pk.LinkHeader()),
+ concatViews(want.link, want.network, want.transport, want.data))
+ checkViewEqual(t, prefix+"PayloadSince(NetworkHeader)",
+ PayloadSince(pk.NetworkHeader()),
+ concatViews(want.network, want.transport, want.data))
+ checkViewEqual(t, prefix+"PayloadSince(TransportHeader)",
+ PayloadSince(pk.TransportHeader()),
+ concatViews(want.transport, want.data))
+}
+
func checkInitialPacketBuffer(t *testing.T, pk *PacketBuffer, opts PacketBufferOptions) {
t.Helper()
reserved := opts.ReserveHeaderBytes
@@ -510,19 +548,9 @@ func checkInitialPacketBuffer(t *testing.T, pk *PacketBuffer, opts PacketBufferO
if got, want := pk.Size(), len(data); got != want {
t.Errorf("Initial pk.Size() = %d, want %d", got, want)
}
- checkData(t, pk, data)
- checkViewEqual(t, "Initial pk.Views()", concatViews(pk.Views()...), data)
- // Check the initial values for each header.
- checkPacketHeader(t, "Initial pk.LinkHeader", pk.LinkHeader(), nil)
- checkPacketHeader(t, "Initial pk.NetworkHeader", pk.NetworkHeader(), nil)
- checkPacketHeader(t, "Initial pk.TransportHeader", pk.TransportHeader(), nil)
- // Check the initial valies for PayloadSince.
- checkViewEqual(t, "Initial PayloadSince(LinkHeader)",
- PayloadSince(pk.LinkHeader()), data)
- checkViewEqual(t, "Initial PayloadSince(NetworkHeader)",
- PayloadSince(pk.NetworkHeader()), data)
- checkViewEqual(t, "Initial PayloadSince(TransportHeader)",
- PayloadSince(pk.TransportHeader()), data)
+ checkPacketContents(t, "Initial ", pk, packetContents{
+ data: data,
+ })
}
func checkPacketHeader(t *testing.T, name string, h PacketHeader, want []byte) {
diff --git a/pkg/tcpip/stack/pending_packets.go b/pkg/tcpip/stack/pending_packets.go
index e936aa728..13e8907ec 100644
--- a/pkg/tcpip/stack/pending_packets.go
+++ b/pkg/tcpip/stack/pending_packets.go
@@ -46,7 +46,6 @@ func (p *PacketBufferList) len() int {
type pendingPacket struct {
routeInfo RouteInfo
- gso *GSO
proto tcpip.NetworkProtocolNumber
pkt pendingPacketBuffer
}
@@ -119,7 +118,7 @@ func (f *packetsPendingLinkResolution) dequeue(ch <-chan struct{}, linkAddr tcpi
// If the maximum number of pending resolutions is reached, the packets
// associated with the oldest link resolution will be dequeued as if they failed
// link resolution.
-func (f *packetsPendingLinkResolution) enqueue(r *Route, gso *GSO, proto tcpip.NetworkProtocolNumber, pkt pendingPacketBuffer) (int, tcpip.Error) {
+func (f *packetsPendingLinkResolution) enqueue(r *Route, proto tcpip.NetworkProtocolNumber, pkt pendingPacketBuffer) (int, tcpip.Error) {
f.mu.Lock()
// Make sure we attempt resolution while holding f's lock so that we avoid
// a race where link resolution completes before we enqueue the packets.
@@ -137,7 +136,7 @@ func (f *packetsPendingLinkResolution) enqueue(r *Route, gso *GSO, proto tcpip.N
// The route resolved immediately, so we don't need to wait for link
// resolution to send the packet.
f.mu.Unlock()
- return f.nic.writePacketBuffer(routeInfo, gso, proto, pkt)
+ return f.nic.writePacketBuffer(routeInfo, proto, pkt)
case *tcpip.ErrWouldBlock:
// We need to wait for link resolution to complete.
default:
@@ -150,7 +149,6 @@ func (f *packetsPendingLinkResolution) enqueue(r *Route, gso *GSO, proto tcpip.N
packets, ok := f.mu.packets[ch]
packets = append(packets, pendingPacket{
routeInfo: routeInfo,
- gso: gso,
proto: proto,
pkt: pkt,
})
@@ -211,7 +209,7 @@ func (f *packetsPendingLinkResolution) dequeuePackets(packets []pendingPacket, l
for _, p := range packets {
if err == nil {
p.routeInfo.RemoteLinkAddress = linkAddr
- _, _ = f.nic.writePacketBuffer(p.routeInfo, p.gso, p.proto, p.pkt)
+ _, _ = f.nic.writePacketBuffer(p.routeInfo, p.proto, p.pkt)
} else {
f.incrementOutgoingPacketErrors(p.proto, p.pkt)
diff --git a/pkg/tcpip/stack/registration.go b/pkg/tcpip/stack/registration.go
index ff3a385e1..e26225552 100644
--- a/pkg/tcpip/stack/registration.go
+++ b/pkg/tcpip/stack/registration.go
@@ -537,14 +537,14 @@ type NetworkInterface interface {
CheckLocalAddress(tcpip.NetworkProtocolNumber, tcpip.Address) bool
// WritePacketToRemote writes the packet to the given remote link address.
- WritePacketToRemote(tcpip.LinkAddress, *GSO, tcpip.NetworkProtocolNumber, *PacketBuffer) tcpip.Error
+ WritePacketToRemote(tcpip.LinkAddress, tcpip.NetworkProtocolNumber, *PacketBuffer) tcpip.Error
// WritePacket writes a packet with the given protocol through the given
// route.
//
// WritePacket takes ownership of the packet buffer. The packet buffer's
// network and transport header must be set.
- WritePacket(*Route, *GSO, tcpip.NetworkProtocolNumber, *PacketBuffer) tcpip.Error
+ WritePacket(*Route, tcpip.NetworkProtocolNumber, *PacketBuffer) tcpip.Error
// WritePackets writes packets with the given protocol through the given
// route. Must not be called with an empty list of packet buffers.
@@ -554,7 +554,7 @@ type NetworkInterface interface {
// Right now, WritePackets is used only when the software segmentation
// offload is enabled. If it will be used for something else, syscall filters
// may need to be updated.
- WritePackets(*Route, *GSO, PacketBufferList, tcpip.NetworkProtocolNumber) (int, tcpip.Error)
+ WritePackets(*Route, PacketBufferList, tcpip.NetworkProtocolNumber) (int, tcpip.Error)
// HandleNeighborProbe processes an incoming neighbor probe (e.g. ARP
// request or NDP Neighbor Solicitation).
@@ -610,12 +610,12 @@ type NetworkEndpoint interface {
// WritePacket writes a packet to the given destination address and
// protocol. It takes ownership of pkt. pkt.TransportHeader must have
// already been set.
- WritePacket(r *Route, gso *GSO, params NetworkHeaderParams, pkt *PacketBuffer) tcpip.Error
+ WritePacket(r *Route, params NetworkHeaderParams, pkt *PacketBuffer) tcpip.Error
// WritePackets writes packets to the given destination address and
// protocol. pkts must not be zero length. It takes ownership of pkts and
// underlying packets.
- WritePackets(r *Route, gso *GSO, pkts PacketBufferList, params NetworkHeaderParams) (int, tcpip.Error)
+ WritePackets(r *Route, pkts PacketBufferList, params NetworkHeaderParams) (int, tcpip.Error)
// WriteHeaderIncludedPacket writes a packet that includes a network
// header to the given destination address. It takes ownership of pkt.
@@ -756,11 +756,6 @@ const (
CapabilitySaveRestore
CapabilityDisconnectOk
CapabilityLoopback
- CapabilityHardwareGSO
-
- // CapabilitySoftwareGSO indicates the link endpoint supports of sending
- // multiple packets using a single call (LinkEndpoint.WritePackets).
- CapabilitySoftwareGSO
)
// NetworkLinkEndpoint is a data-link layer that supports sending network
@@ -832,7 +827,7 @@ type LinkEndpoint interface {
// To participate in transparent bridging, a LinkEndpoint implementation
// should call eth.Encode with header.EthernetFields.SrcAddr set to
// r.LocalLinkAddress if it is provided.
- WritePacket(RouteInfo, *GSO, tcpip.NetworkProtocolNumber, *PacketBuffer) tcpip.Error
+ WritePacket(RouteInfo, tcpip.NetworkProtocolNumber, *PacketBuffer) tcpip.Error
// WritePackets writes packets with the given protocol and route. Must not be
// called with an empty list of packet buffers.
@@ -842,7 +837,7 @@ type LinkEndpoint interface {
// Right now, WritePackets is used only when the software segmentation
// offload is enabled. If it will be used for something else, syscall filters
// may need to be updated.
- WritePackets(RouteInfo, *GSO, PacketBufferList, tcpip.NetworkProtocolNumber) (int, tcpip.Error)
+ WritePackets(RouteInfo, PacketBufferList, tcpip.NetworkProtocolNumber) (int, tcpip.Error)
}
// InjectableLinkEndpoint is a LinkEndpoint where inbound packets are
@@ -1047,10 +1042,29 @@ type GSO struct {
MaxSize uint32
}
+// SupportedGSO returns the type of segmentation offloading supported.
+type SupportedGSO int
+
+const (
+ // GSONotSupported indicates that segmentation offloading is not supported.
+ GSONotSupported SupportedGSO = iota
+
+ // HWGSOSupported indicates that segmentation offloading may be performed by
+ // the hardware.
+ HWGSOSupported
+
+ // SWGSOSupported indicates that segmentation offloading may be performed in
+ // software.
+ SWGSOSupported
+)
+
// GSOEndpoint provides access to GSO properties.
type GSOEndpoint interface {
// GSOMaxSize returns the maximum GSO packet size.
GSOMaxSize() uint32
+
+ // SupportedGSO returns the supported segmentation offloading.
+ SupportedGSO() SupportedGSO
}
// SoftwareGSOMaxSize is a maximum allowed size of a software GSO segment.
diff --git a/pkg/tcpip/stack/route.go b/pkg/tcpip/stack/route.go
index 39344808d..8a044c073 100644
--- a/pkg/tcpip/stack/route.go
+++ b/pkg/tcpip/stack/route.go
@@ -132,7 +132,7 @@ func constructAndValidateRoute(netProto tcpip.NetworkProtocolNumber, addressEndp
localAddr = addressEndpoint.AddressWithPrefix().Address
}
- if localAddressNIC != outgoingNIC && header.IsV6LinkLocalAddress(localAddr) {
+ if localAddressNIC != outgoingNIC && header.IsV6LinkLocalUnicastAddress(localAddr) {
addressEndpoint.DecRef()
return nil
}
@@ -300,12 +300,18 @@ func (r *Route) RequiresTXTransportChecksum() bool {
// HasSoftwareGSOCapability returns true if the route supports software GSO.
func (r *Route) HasSoftwareGSOCapability() bool {
- return r.outgoingNIC.LinkEndpoint.Capabilities()&CapabilitySoftwareGSO != 0
+ if gso, ok := r.outgoingNIC.LinkEndpoint.(GSOEndpoint); ok {
+ return gso.SupportedGSO() == SWGSOSupported
+ }
+ return false
}
// HasHardwareGSOCapability returns true if the route supports hardware GSO.
func (r *Route) HasHardwareGSOCapability() bool {
- return r.outgoingNIC.LinkEndpoint.Capabilities()&CapabilityHardwareGSO != 0
+ if gso, ok := r.outgoingNIC.LinkEndpoint.(GSOEndpoint); ok {
+ return gso.SupportedGSO() == HWGSOSupported
+ }
+ return false
}
// HasSaveRestoreCapability returns true if the route supports save/restore.
@@ -448,22 +454,22 @@ func (r *Route) isValidForOutgoingRLocked() bool {
}
// WritePacket writes the packet through the given route.
-func (r *Route) WritePacket(gso *GSO, params NetworkHeaderParams, pkt *PacketBuffer) tcpip.Error {
+func (r *Route) WritePacket(params NetworkHeaderParams, pkt *PacketBuffer) tcpip.Error {
if !r.isValidForOutgoing() {
return &tcpip.ErrInvalidEndpointState{}
}
- return r.outgoingNIC.getNetworkEndpoint(r.NetProto()).WritePacket(r, gso, params, pkt)
+ return r.outgoingNIC.getNetworkEndpoint(r.NetProto()).WritePacket(r, params, pkt)
}
// WritePackets writes a list of n packets through the given route and returns
// the number of packets written.
-func (r *Route) WritePackets(gso *GSO, pkts PacketBufferList, params NetworkHeaderParams) (int, tcpip.Error) {
+func (r *Route) WritePackets(pkts PacketBufferList, params NetworkHeaderParams) (int, tcpip.Error) {
if !r.isValidForOutgoing() {
return 0, &tcpip.ErrInvalidEndpointState{}
}
- return r.outgoingNIC.getNetworkEndpoint(r.NetProto()).WritePackets(r, gso, pkts, params)
+ return r.outgoingNIC.getNetworkEndpoint(r.NetProto()).WritePackets(r, pkts, params)
}
// WriteHeaderIncludedPacket writes a packet already containing a network
diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go
index 931a97ddc..436392f23 100644
--- a/pkg/tcpip/stack/stack.go
+++ b/pkg/tcpip/stack/stack.go
@@ -35,7 +35,6 @@ import (
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/ports"
- "gvisor.dev/gvisor/pkg/tcpip/seqnum"
"gvisor.dev/gvisor/pkg/waiter"
)
@@ -56,306 +55,6 @@ type transportProtocolState struct {
defaultHandler func(id TransportEndpointID, pkt *PacketBuffer) bool
}
-// TCPProbeFunc is the expected function type for a TCP probe function to be
-// passed to stack.AddTCPProbe.
-type TCPProbeFunc func(s TCPEndpointState)
-
-// TCPCubicState is used to hold a copy of the internal cubic state when the
-// TCPProbeFunc is invoked.
-type TCPCubicState struct {
- WLastMax float64
- WMax float64
- T time.Time
- TimeSinceLastCongestion time.Duration
- C float64
- K float64
- Beta float64
- WC float64
- WEst float64
-}
-
-// TCPRACKState is used to hold a copy of the internal RACK state when the
-// TCPProbeFunc is invoked.
-type TCPRACKState struct {
- XmitTime time.Time
- EndSequence seqnum.Value
- FACK seqnum.Value
- RTT time.Duration
- Reord bool
- DSACKSeen bool
- ReoWnd time.Duration
- ReoWndIncr uint8
- ReoWndPersist int8
- RTTSeq seqnum.Value
-}
-
-// TCPEndpointID is the unique 4 tuple that identifies a given endpoint.
-type TCPEndpointID struct {
- // LocalPort is the local port associated with the endpoint.
- LocalPort uint16
-
- // LocalAddress is the local [network layer] address associated with
- // the endpoint.
- LocalAddress tcpip.Address
-
- // RemotePort is the remote port associated with the endpoint.
- RemotePort uint16
-
- // RemoteAddress it the remote [network layer] address associated with
- // the endpoint.
- RemoteAddress tcpip.Address
-}
-
-// TCPFastRecoveryState holds a copy of the internal fast recovery state of a
-// TCP endpoint.
-type TCPFastRecoveryState struct {
- // Active if true indicates the endpoint is in fast recovery.
- Active bool
-
- // First is the first unacknowledged sequence number being recovered.
- First seqnum.Value
-
- // Last is the 'recover' sequence number that indicates the point at
- // which we should exit recovery barring any timeouts etc.
- Last seqnum.Value
-
- // MaxCwnd is the maximum value we are permitted to grow the congestion
- // window during recovery. This is set at the time we enter recovery.
- MaxCwnd int
-
- // HighRxt is the highest sequence number which has been retransmitted
- // during the current loss recovery phase.
- // See: RFC 6675 Section 2 for details.
- HighRxt seqnum.Value
-
- // RescueRxt is the highest sequence number which has been
- // optimistically retransmitted to prevent stalling of the ACK clock
- // when there is loss at the end of the window and no new data is
- // available for transmission.
- // See: RFC 6675 Section 2 for details.
- RescueRxt seqnum.Value
-}
-
-// TCPReceiverState holds a copy of the internal state of the receiver for
-// a given TCP endpoint.
-type TCPReceiverState struct {
- // RcvNxt is the TCP variable RCV.NXT.
- RcvNxt seqnum.Value
-
- // RcvAcc is the TCP variable RCV.ACC.
- RcvAcc seqnum.Value
-
- // RcvWndScale is the window scaling to use for inbound segments.
- RcvWndScale uint8
-
- // PendingBufUsed is the number of bytes pending in the receive
- // queue.
- PendingBufUsed int
-}
-
-// TCPSenderState holds a copy of the internal state of the sender for
-// a given TCP Endpoint.
-type TCPSenderState struct {
- // LastSendTime is the time at which we sent the last segment.
- LastSendTime time.Time
-
- // DupAckCount is the number of Duplicate ACK's received.
- DupAckCount int
-
- // SndCwnd is the size of the sending congestion window in packets.
- SndCwnd int
-
- // Ssthresh is the slow start threshold in packets.
- Ssthresh int
-
- // SndCAAckCount is the number of packets consumed in congestion
- // avoidance mode.
- SndCAAckCount int
-
- // Outstanding is the number of packets in flight.
- Outstanding int
-
- // SackedOut is the number of packets which have been selectively acked.
- SackedOut int
-
- // SndWnd is the send window size in bytes.
- SndWnd seqnum.Size
-
- // SndUna is the next unacknowledged sequence number.
- SndUna seqnum.Value
-
- // SndNxt is the sequence number of the next segment to be sent.
- SndNxt seqnum.Value
-
- // RTTMeasureSeqNum is the sequence number being used for the latest RTT
- // measurement.
- RTTMeasureSeqNum seqnum.Value
-
- // RTTMeasureTime is the time when the RTTMeasureSeqNum was sent.
- RTTMeasureTime time.Time
-
- // Closed indicates that the caller has closed the endpoint for sending.
- Closed bool
-
- // SRTT is the smoothed round-trip time as defined in section 2 of
- // RFC 6298.
- SRTT time.Duration
-
- // RTO is the retransmit timeout as defined in section of 2 of RFC 6298.
- RTO time.Duration
-
- // RTTVar is the round-trip time variation as defined in section 2 of
- // RFC 6298.
- RTTVar time.Duration
-
- // SRTTInited if true indicates take a valid RTT measurement has been
- // completed.
- SRTTInited bool
-
- // MaxPayloadSize is the maximum size of the payload of a given segment.
- // It is initialized on demand.
- MaxPayloadSize int
-
- // SndWndScale is the number of bits to shift left when reading the send
- // window size from a segment.
- SndWndScale uint8
-
- // MaxSentAck is the highest acknowledgement number sent till now.
- MaxSentAck seqnum.Value
-
- // FastRecovery holds the fast recovery state for the endpoint.
- FastRecovery TCPFastRecoveryState
-
- // Cubic holds the state related to CUBIC congestion control.
- Cubic TCPCubicState
-
- // RACKState holds the state related to RACK loss detection algorithm.
- RACKState TCPRACKState
-}
-
-// TCPSACKInfo holds TCP SACK related information for a given TCP endpoint.
-type TCPSACKInfo struct {
- // Blocks is the list of SACK Blocks that identify the out of order segments
- // held by a given TCP endpoint.
- Blocks []header.SACKBlock
-
- // ReceivedBlocks are the SACK blocks received by this endpoint
- // from the peer endpoint.
- ReceivedBlocks []header.SACKBlock
-
- // MaxSACKED is the highest sequence number that has been SACKED
- // by the peer.
- MaxSACKED seqnum.Value
-}
-
-// RcvBufAutoTuneParams holds state related to TCP receive buffer auto-tuning.
-type RcvBufAutoTuneParams struct {
- // MeasureTime is the time at which the current measurement
- // was started.
- MeasureTime time.Time
-
- // CopiedBytes is the number of bytes copied to user space since
- // this measure began.
- CopiedBytes int
-
- // PrevCopiedBytes is the number of bytes copied to userspace in
- // the previous RTT period.
- PrevCopiedBytes int
-
- // RcvBufSize is the auto tuned receive buffer size.
- RcvBufSize int
-
- // RTT is the smoothed RTT as measured by observing the time between
- // when a byte is first acknowledged and the receipt of data that is at
- // least one window beyond the sequence number that was acknowledged.
- RTT time.Duration
-
- // RTTVar is the "round-trip time variation" as defined in section 2
- // of RFC6298.
- RTTVar time.Duration
-
- // RTTMeasureSeqNumber is the highest acceptable sequence number at the
- // time this RTT measurement period began.
- RTTMeasureSeqNumber seqnum.Value
-
- // RTTMeasureTime is the absolute time at which the current RTT
- // measurement period began.
- RTTMeasureTime time.Time
-
- // Disabled is true if an explicit receive buffer is set for the
- // endpoint.
- Disabled bool
-}
-
-// TCPEndpointState is a copy of the internal state of a TCP endpoint.
-type TCPEndpointState struct {
- // ID is a copy of the TransportEndpointID for the endpoint.
- ID TCPEndpointID
-
- // SegTime denotes the absolute time when this segment was received.
- SegTime time.Time
-
- // RcvBufSize is the size of the receive socket buffer for the endpoint.
- RcvBufSize int
-
- // RcvBufUsed is the amount of bytes actually held in the receive socket
- // buffer for the endpoint.
- RcvBufUsed int
-
- // RcvBufAutoTuneParams is used to hold state variables to compute
- // the auto tuned receive buffer size.
- RcvAutoParams RcvBufAutoTuneParams
-
- // RcvClosed if true, indicates the endpoint has been closed for reading.
- RcvClosed bool
-
- // SendTSOk is used to indicate when the TS Option has been negotiated.
- // When sendTSOk is true every non-RST segment should carry a TS as per
- // RFC7323#section-1.1.
- SendTSOk bool
-
- // RecentTS is the timestamp that should be sent in the TSEcr field of
- // the timestamp for future segments sent by the endpoint. This field is
- // updated if required when a new segment is received by this endpoint.
- RecentTS uint32
-
- // TSOffset is a randomized offset added to the value of the TSVal field
- // in the timestamp option.
- TSOffset uint32
-
- // SACKPermitted is set to true if the peer sends the TCPSACKPermitted
- // option in the SYN/SYN-ACK.
- SACKPermitted bool
-
- // SACK holds TCP SACK related information for this endpoint.
- SACK TCPSACKInfo
-
- // SndBufSize is the size of the socket send buffer.
- SndBufSize int
-
- // SndBufUsed is the number of bytes held in the socket send buffer.
- SndBufUsed int
-
- // SndClosed indicates that the endpoint has been closed for sends.
- SndClosed bool
-
- // SndBufInQueue is the number of bytes in the send queue.
- SndBufInQueue seqnum.Size
-
- // PacketTooBigCount is used to notify the main protocol routine how
- // many times a "packet too big" control packet is received.
- PacketTooBigCount int
-
- // SndMTU is the smallest MTU seen in the control packets received.
- SndMTU int
-
- // Receiver holds variables related to the TCP receiver for the endpoint.
- Receiver TCPReceiverState
-
- // Sender holds state related to the TCP Sender for the endpoint.
- Sender TCPSenderState
-}
-
// ResumableEndpoint is an endpoint that needs to be resumed after restore.
type ResumableEndpoint interface {
// Resume resumes an endpoint after restore. This can be used to restart
@@ -455,7 +154,7 @@ type Stack struct {
// receiveBufferSize holds the min/default/max receive buffer sizes for
// endpoints other than TCP.
- receiveBufferSize ReceiveBufferSizeOption
+ receiveBufferSize tcpip.ReceiveBufferSizeOption
// tcpInvalidRateLimit is the maximal rate for sending duplicate
// acknowledgements in response to incoming TCP packets that are for an existing
@@ -623,7 +322,7 @@ func (*TransportEndpointInfo) IsEndpointInfo() {}
func New(opts Options) *Stack {
clock := opts.Clock
if clock == nil {
- clock = &tcpip.StdClock{}
+ clock = tcpip.NewStdClock()
}
if opts.UniqueID == nil {
@@ -669,7 +368,7 @@ func New(opts Options) *Stack {
Default: DefaultBufferSize,
Max: DefaultMaxBufferSize,
},
- receiveBufferSize: ReceiveBufferSizeOption{
+ receiveBufferSize: tcpip.ReceiveBufferSizeOption{
Min: MinBufferSize,
Default: DefaultBufferSize,
Max: DefaultMaxBufferSize,
@@ -1344,7 +1043,7 @@ func (s *Stack) FindRoute(id tcpip.NICID, localAddr, remoteAddr tcpip.Address, n
s.mu.RLock()
defer s.mu.RUnlock()
- isLinkLocal := header.IsV6LinkLocalAddress(remoteAddr) || header.IsV6LinkLocalMulticastAddress(remoteAddr)
+ isLinkLocal := header.IsV6LinkLocalUnicastAddress(remoteAddr) || header.IsV6LinkLocalMulticastAddress(remoteAddr)
isLocalBroadcast := remoteAddr == header.IPv4Broadcast
isMulticast := header.IsV4MulticastAddress(remoteAddr) || header.IsV6MulticastAddress(remoteAddr)
isLoopback := header.IsV4LoopbackAddress(remoteAddr) || header.IsV6LoopbackAddress(remoteAddr)
@@ -1381,7 +1080,7 @@ func (s *Stack) FindRoute(id tcpip.NICID, localAddr, remoteAddr tcpip.Address, n
return nil, &tcpip.ErrNetworkUnreachable{}
}
- canForward := s.Forwarding(netProto) && !header.IsV6LinkLocalAddress(localAddr) && !isLinkLocal
+ canForward := s.Forwarding(netProto) && !header.IsV6LinkLocalUnicastAddress(localAddr) && !isLinkLocal
// Find a route to the remote with the route table.
var chosenRoute tcpip.Route
@@ -1874,7 +1573,7 @@ func (s *Stack) WritePacketToRemote(nicID tcpip.NICID, remote tcpip.LinkAddress,
ReserveHeaderBytes: int(nic.MaxHeaderLength()),
Data: payload,
})
- return nic.WritePacketToRemote(remote, nil, netProto, pkt)
+ return nic.WritePacketToRemote(remote, netProto, pkt)
}
// NetworkProtocolInstance returns the protocol instance in the stack for the
diff --git a/pkg/tcpip/stack/stack_global_state.go b/pkg/tcpip/stack/stack_global_state.go
index dfec4258a..33824afd0 100644
--- a/pkg/tcpip/stack/stack_global_state.go
+++ b/pkg/tcpip/stack/stack_global_state.go
@@ -14,6 +14,78 @@
package stack
+import "time"
+
// StackFromEnv is the global stack created in restore run.
// FIXME(b/36201077)
var StackFromEnv *Stack
+
+// saveT is invoked by stateify.
+func (t *TCPCubicState) saveT() unixTime {
+ return unixTime{t.T.Unix(), t.T.UnixNano()}
+}
+
+// loadT is invoked by stateify.
+func (t *TCPCubicState) loadT(unix unixTime) {
+ t.T = time.Unix(unix.second, unix.nano)
+}
+
+// saveXmitTime is invoked by stateify.
+func (t *TCPRACKState) saveXmitTime() unixTime {
+ return unixTime{t.XmitTime.Unix(), t.XmitTime.UnixNano()}
+}
+
+// loadXmitTime is invoked by stateify.
+func (t *TCPRACKState) loadXmitTime(unix unixTime) {
+ t.XmitTime = time.Unix(unix.second, unix.nano)
+}
+
+// saveLastSendTime is invoked by stateify.
+func (t *TCPSenderState) saveLastSendTime() unixTime {
+ return unixTime{t.LastSendTime.Unix(), t.LastSendTime.UnixNano()}
+}
+
+// loadLastSendTime is invoked by stateify.
+func (t *TCPSenderState) loadLastSendTime(unix unixTime) {
+ t.LastSendTime = time.Unix(unix.second, unix.nano)
+}
+
+// saveRTTMeasureTime is invoked by stateify.
+func (t *TCPSenderState) saveRTTMeasureTime() unixTime {
+ return unixTime{t.RTTMeasureTime.Unix(), t.RTTMeasureTime.UnixNano()}
+}
+
+// loadRTTMeasureTime is invoked by stateify.
+func (t *TCPSenderState) loadRTTMeasureTime(unix unixTime) {
+ t.RTTMeasureTime = time.Unix(unix.second, unix.nano)
+}
+
+// saveMeasureTime is invoked by stateify.
+func (r *RcvBufAutoTuneParams) saveMeasureTime() unixTime {
+ return unixTime{r.MeasureTime.Unix(), r.MeasureTime.UnixNano()}
+}
+
+// loadMeasureTime is invoked by stateify.
+func (r *RcvBufAutoTuneParams) loadMeasureTime(unix unixTime) {
+ r.MeasureTime = time.Unix(unix.second, unix.nano)
+}
+
+// saveRTTMeasureTime is invoked by stateify.
+func (r *RcvBufAutoTuneParams) saveRTTMeasureTime() unixTime {
+ return unixTime{r.RTTMeasureTime.Unix(), r.RTTMeasureTime.UnixNano()}
+}
+
+// loadRTTMeasureTime is invoked by stateify.
+func (r *RcvBufAutoTuneParams) loadRTTMeasureTime(unix unixTime) {
+ r.RTTMeasureTime = time.Unix(unix.second, unix.nano)
+}
+
+// saveSegTime is invoked by stateify.
+func (t *TCPEndpointState) saveSegTime() unixTime {
+ return unixTime{t.SegTime.Unix(), t.SegTime.UnixNano()}
+}
+
+// loadSegTime is invoked by stateify.
+func (t *TCPEndpointState) loadSegTime(unix unixTime) {
+ t.SegTime = time.Unix(unix.second, unix.nano)
+}
diff --git a/pkg/tcpip/stack/stack_options.go b/pkg/tcpip/stack/stack_options.go
index 3066f4ffd..80e8e0089 100644
--- a/pkg/tcpip/stack/stack_options.go
+++ b/pkg/tcpip/stack/stack_options.go
@@ -68,7 +68,7 @@ func (s *Stack) SetOption(option interface{}) tcpip.Error {
s.mu.Unlock()
return nil
- case ReceiveBufferSizeOption:
+ case tcpip.ReceiveBufferSizeOption:
// Make sure we don't allow lowering the buffer below minimum
// required for stack to work.
if v.Min < MinBufferSize {
@@ -107,7 +107,7 @@ func (s *Stack) Option(option interface{}) tcpip.Error {
s.mu.RUnlock()
return nil
- case *ReceiveBufferSizeOption:
+ case *tcpip.ReceiveBufferSizeOption:
s.mu.RLock()
*v = s.receiveBufferSize
s.mu.RUnlock()
diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go
index 2814b94b4..d2c40cc43 100644
--- a/pkg/tcpip/stack/stack_test.go
+++ b/pkg/tcpip/stack/stack_test.go
@@ -39,6 +39,7 @@ import (
"gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
"gvisor.dev/gvisor/pkg/tcpip/stack"
+ "gvisor.dev/gvisor/pkg/tcpip/testutil"
"gvisor.dev/gvisor/pkg/tcpip/transport/udp"
)
@@ -137,11 +138,13 @@ func (f *fakeNetworkEndpoint) HandlePacket(pkt *stack.PacketBuffer) {
// Handle control packets.
if netHdr[protocolNumberOffset] == uint8(fakeControlProtocol) {
- nb, ok := pkt.Data().PullUp(fakeNetHeaderLen)
+ hdr, ok := pkt.Data().PullUp(fakeNetHeaderLen)
if !ok {
return
}
- pkt.Data().TrimFront(fakeNetHeaderLen)
+ // DeleteFront invalidates slices. Make a copy before trimming.
+ nb := append([]byte(nil), hdr...)
+ pkt.Data().DeleteFront(fakeNetHeaderLen)
f.dispatcher.DeliverTransportError(
tcpip.Address(nb[srcAddrOffset:srcAddrOffset+1]),
tcpip.Address(nb[dstAddrOffset:dstAddrOffset+1]),
@@ -170,7 +173,7 @@ func (f *fakeNetworkEndpoint) NetworkProtocolNumber() tcpip.NetworkProtocolNumbe
return f.proto.Number()
}
-func (f *fakeNetworkEndpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.NetworkHeaderParams, pkt *stack.PacketBuffer) tcpip.Error {
+func (f *fakeNetworkEndpoint) WritePacket(r *stack.Route, params stack.NetworkHeaderParams, pkt *stack.PacketBuffer) tcpip.Error {
// Increment the sent packet count in the protocol descriptor.
f.proto.sendPacketCount[int(r.RemoteAddress()[0])%len(f.proto.sendPacketCount)]++
@@ -189,11 +192,11 @@ func (f *fakeNetworkEndpoint) WritePacket(r *stack.Route, gso *stack.GSO, params
return nil
}
- return f.nic.WritePacket(r, gso, fakeNetNumber, pkt)
+ return f.nic.WritePacket(r, fakeNetNumber, pkt)
}
// WritePackets implements stack.LinkEndpoint.WritePackets.
-func (*fakeNetworkEndpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.PacketBufferList, params stack.NetworkHeaderParams) (int, tcpip.Error) {
+func (*fakeNetworkEndpoint) WritePackets(r *stack.Route, pkts stack.PacketBufferList, params stack.NetworkHeaderParams) (int, tcpip.Error) {
panic("not implemented")
}
@@ -436,7 +439,7 @@ func sendTo(s *stack.Stack, addr tcpip.Address, payload buffer.View) tcpip.Error
}
func send(r *stack.Route, payload buffer.View) tcpip.Error {
- return r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: fakeTransNumber, TTL: 123, TOS: stack.DefaultTOS}, stack.NewPacketBuffer(stack.PacketBufferOptions{
+ return r.WritePacket(stack.NetworkHeaderParams{Protocol: fakeTransNumber, TTL: 123, TOS: stack.DefaultTOS}, stack.NewPacketBuffer(stack.PacketBufferOptions{
ReserveHeaderBytes: int(r.MaxHeaderLength()),
Data: payload.ToVectorisedView(),
}))
@@ -1461,7 +1464,7 @@ func TestExternalSendWithHandleLocal(t *testing.T) {
if n := ep.Drain(); n != 0 {
t.Fatalf("got ep.Drain() = %d, want = 0", n)
}
- if err := r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{
+ if err := r.WritePacket(stack.NetworkHeaderParams{
Protocol: fakeTransNumber,
TTL: 123,
TOS: stack.DefaultTOS,
@@ -1645,10 +1648,10 @@ func TestOutgoingBroadcastWithRouteTable(t *testing.T) {
defaultAddr := tcpip.AddressWithPrefix{header.IPv4Any, 0}
// Local subnet on NIC1: 192.168.1.58/24, gateway 192.168.1.1.
nic1Addr := tcpip.AddressWithPrefix{"\xc0\xa8\x01\x3a", 24}
- nic1Gateway := tcpip.Address("\xc0\xa8\x01\x01")
+ nic1Gateway := testutil.MustParse4("192.168.1.1")
// Local subnet on NIC2: 10.10.10.5/24, gateway 10.10.10.1.
nic2Addr := tcpip.AddressWithPrefix{"\x0a\x0a\x0a\x05", 24}
- nic2Gateway := tcpip.Address("\x0a\x0a\x0a\x01")
+ nic2Gateway := testutil.MustParse4("10.10.10.1")
// Create a new stack with two NICs.
s := stack.New(stack.Options{
@@ -2789,25 +2792,27 @@ func TestNewPEBOnPromotionToPermanent(t *testing.T) {
func TestIPv6SourceAddressSelectionScopeAndSameAddress(t *testing.T) {
const (
- linkLocalAddr1 = tcpip.Address("\xfe\x80\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01")
- linkLocalAddr2 = tcpip.Address("\xfe\x80\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02")
- linkLocalMulticastAddr = tcpip.Address("\xff\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01")
- uniqueLocalAddr1 = tcpip.Address("\xfc\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01")
- uniqueLocalAddr2 = tcpip.Address("\xfd\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02")
- globalAddr1 = tcpip.Address("\xa0\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01")
- globalAddr2 = tcpip.Address("\xa0\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02")
- globalAddr3 = tcpip.Address("\xa0\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x03")
- ipv4MappedIPv6Addr1 = tcpip.Address("\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff\x00\x00\x00\x01")
- ipv4MappedIPv6Addr2 = tcpip.Address("\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff\x00\x00\x00\x02")
- toredoAddr1 = tcpip.Address("\x20\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01")
- toredoAddr2 = tcpip.Address("\x20\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02")
- ipv6ToIPv4Addr1 = tcpip.Address("\x20\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01")
- ipv6ToIPv4Addr2 = tcpip.Address("\x20\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02")
-
nicID = 1
lifetimeSeconds = 9999
)
+ var (
+ linkLocalAddr1 = testutil.MustParse6("fe80::1")
+ linkLocalAddr2 = testutil.MustParse6("fe80::2")
+ linkLocalMulticastAddr = testutil.MustParse6("ff02::1")
+ uniqueLocalAddr1 = testutil.MustParse6("fc00::1")
+ uniqueLocalAddr2 = testutil.MustParse6("fd00::2")
+ globalAddr1 = testutil.MustParse6("a000::1")
+ globalAddr2 = testutil.MustParse6("a000::2")
+ globalAddr3 = testutil.MustParse6("a000::3")
+ ipv4MappedIPv6Addr1 = testutil.MustParse6("::ffff:0.0.0.1")
+ ipv4MappedIPv6Addr2 = testutil.MustParse6("::ffff:0.0.0.2")
+ toredoAddr1 = testutil.MustParse6("2001::1")
+ toredoAddr2 = testutil.MustParse6("2001::2")
+ ipv6ToIPv4Addr1 = testutil.MustParse6("2002::1")
+ ipv6ToIPv4Addr2 = testutil.MustParse6("2002::2")
+ )
+
prefix1, _, stableGlobalAddr1 := prefixSubnetAddr(0, linkAddr1)
prefix2, _, stableGlobalAddr2 := prefixSubnetAddr(1, linkAddr1)
@@ -3017,7 +3022,7 @@ func TestIPv6SourceAddressSelectionScopeAndSameAddress(t *testing.T) {
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
NDPConfigs: ipv6.NDPConfigurations{
- HandleRAs: true,
+ HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled,
AutoGenGlobalAddresses: true,
AutoGenTempGlobalAddresses: true,
},
@@ -3354,21 +3359,21 @@ func TestStackReceiveBufferSizeOption(t *testing.T) {
const sMin = stack.MinBufferSize
testCases := []struct {
name string
- rs stack.ReceiveBufferSizeOption
+ rs tcpip.ReceiveBufferSizeOption
err tcpip.Error
}{
// Invalid configurations.
- {"min_below_zero", stack.ReceiveBufferSizeOption{Min: -1, Default: sMin, Max: sMin}, &tcpip.ErrInvalidOptionValue{}},
- {"min_zero", stack.ReceiveBufferSizeOption{Min: 0, Default: sMin, Max: sMin}, &tcpip.ErrInvalidOptionValue{}},
- {"default_below_min", stack.ReceiveBufferSizeOption{Min: sMin, Default: sMin - 1, Max: sMin - 1}, &tcpip.ErrInvalidOptionValue{}},
- {"default_above_max", stack.ReceiveBufferSizeOption{Min: sMin, Default: sMin + 1, Max: sMin}, &tcpip.ErrInvalidOptionValue{}},
- {"max_below_min", stack.ReceiveBufferSizeOption{Min: sMin, Default: sMin + 1, Max: sMin - 1}, &tcpip.ErrInvalidOptionValue{}},
+ {"min_below_zero", tcpip.ReceiveBufferSizeOption{Min: -1, Default: sMin, Max: sMin}, &tcpip.ErrInvalidOptionValue{}},
+ {"min_zero", tcpip.ReceiveBufferSizeOption{Min: 0, Default: sMin, Max: sMin}, &tcpip.ErrInvalidOptionValue{}},
+ {"default_below_min", tcpip.ReceiveBufferSizeOption{Min: sMin, Default: sMin - 1, Max: sMin - 1}, &tcpip.ErrInvalidOptionValue{}},
+ {"default_above_max", tcpip.ReceiveBufferSizeOption{Min: sMin, Default: sMin + 1, Max: sMin}, &tcpip.ErrInvalidOptionValue{}},
+ {"max_below_min", tcpip.ReceiveBufferSizeOption{Min: sMin, Default: sMin + 1, Max: sMin - 1}, &tcpip.ErrInvalidOptionValue{}},
// Valid Configurations
- {"in_ascending_order", stack.ReceiveBufferSizeOption{Min: sMin, Default: sMin + 1, Max: sMin + 2}, nil},
- {"all_equal", stack.ReceiveBufferSizeOption{Min: sMin, Default: sMin, Max: sMin}, nil},
- {"min_default_equal", stack.ReceiveBufferSizeOption{Min: sMin, Default: sMin, Max: sMin + 1}, nil},
- {"default_max_equal", stack.ReceiveBufferSizeOption{Min: sMin, Default: sMin + 1, Max: sMin + 1}, nil},
+ {"in_ascending_order", tcpip.ReceiveBufferSizeOption{Min: sMin, Default: sMin + 1, Max: sMin + 2}, nil},
+ {"all_equal", tcpip.ReceiveBufferSizeOption{Min: sMin, Default: sMin, Max: sMin}, nil},
+ {"min_default_equal", tcpip.ReceiveBufferSizeOption{Min: sMin, Default: sMin, Max: sMin + 1}, nil},
+ {"default_max_equal", tcpip.ReceiveBufferSizeOption{Min: sMin, Default: sMin + 1, Max: sMin + 1}, nil},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
@@ -3377,7 +3382,7 @@ func TestStackReceiveBufferSizeOption(t *testing.T) {
if err := s.SetOption(tc.rs); err != tc.err {
t.Fatalf("s.SetOption(%#v) = %v, want: %v", tc.rs, err, tc.err)
}
- var rs stack.ReceiveBufferSizeOption
+ var rs tcpip.ReceiveBufferSizeOption
if tc.err == nil {
if err := s.Option(&rs); err != nil {
t.Fatalf("s.Option(%#v) = %v, want: nil", rs, err)
@@ -3448,7 +3453,7 @@ func TestOutgoingSubnetBroadcast(t *testing.T) {
}
ipv4Subnet := ipv4Addr.Subnet()
ipv4SubnetBcast := ipv4Subnet.Broadcast()
- ipv4Gateway := tcpip.Address("\xc0\xa8\x01\x01")
+ ipv4Gateway := testutil.MustParse4("192.168.1.1")
ipv4AddrPrefix31 := tcpip.AddressWithPrefix{
Address: "\xc0\xa8\x01\x3a",
PrefixLen: 31,
@@ -4352,13 +4357,15 @@ func TestWritePacketToRemote(t *testing.T) {
func TestClearNeighborCacheOnNICDisable(t *testing.T) {
const (
- nicID = 1
-
- ipv4Addr = tcpip.Address("\x01\x02\x03\x04")
- ipv6Addr = tcpip.Address("\x01\x02\x03\x04\x01\x02\x03\x04\x01\x02\x03\x04\x01\x02\x03\x04")
+ nicID = 1
linkAddr = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x06")
)
+ var (
+ ipv4Addr = testutil.MustParse4("1.2.3.4")
+ ipv6Addr = testutil.MustParse6("102:304:102:304:102:304:102:304")
+ )
+
clock := faketime.NewManualClock()
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{arp.NewProtocol, ipv4.NewProtocol, ipv6.NewProtocol},
diff --git a/pkg/tcpip/stack/tcp.go b/pkg/tcpip/stack/tcp.go
new file mode 100644
index 000000000..ddff6e2d6
--- /dev/null
+++ b/pkg/tcpip/stack/tcp.go
@@ -0,0 +1,451 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package stack
+
+import (
+ "time"
+
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/seqnum"
+)
+
+// TCPProbeFunc is the expected function type for a TCP probe function to be
+// passed to stack.AddTCPProbe.
+type TCPProbeFunc func(s TCPEndpointState)
+
+// TCPCubicState is used to hold a copy of the internal cubic state when the
+// TCPProbeFunc is invoked.
+//
+// +stateify savable
+type TCPCubicState struct {
+ // WLastMax is the previous wMax value.
+ WLastMax float64
+
+ // WMax is the value of the congestion window at the time of the last
+ // congestion event.
+ WMax float64
+
+ // T is the time when the current congestion avoidance was entered.
+ T time.Time `state:".(unixTime)"`
+
+ // TimeSinceLastCongestion denotes the time since the current
+ // congestion avoidance was entered.
+ TimeSinceLastCongestion time.Duration
+
+ // C is the cubic constant as specified in RFC8312, page 11.
+ C float64
+
+ // K is the time period (in seconds) that the above function takes to
+ // increase the current window size to WMax if there are no further
+ // congestion events and is calculated using the following equation:
+ //
+ // K = cubic_root(WMax*(1-beta_cubic)/C) (Eq. 2, page 5)
+ K float64
+
+ // Beta is the CUBIC multiplication decrease factor. That is, when a
+ // congestion event is detected, CUBIC reduces its cwnd to
+ // WC(0)=WMax*beta_cubic.
+ Beta float64
+
+ // WC is window computed by CUBIC at time TimeSinceLastCongestion. It's
+ // calculated using the formula:
+ //
+ // WC(TimeSinceLastCongestion) = C*(t-K)^3 + WMax (Eq. 1)
+ WC float64
+
+ // WEst is the window computed by CUBIC at time
+ // TimeSinceLastCongestion+RTT i.e WC(TimeSinceLastCongestion+RTT).
+ WEst float64
+}
+
+// TCPRACKState is used to hold a copy of the internal RACK state when the
+// TCPProbeFunc is invoked.
+//
+// +stateify savable
+type TCPRACKState struct {
+ // XmitTime is the transmission timestamp of the most recent
+ // acknowledged segment.
+ XmitTime time.Time `state:".(unixTime)"`
+
+ // EndSequence is the ending TCP sequence number of the most recent
+ // acknowledged segment.
+ EndSequence seqnum.Value
+
+ // FACK is the highest selectively or cumulatively acknowledged
+ // sequence.
+ FACK seqnum.Value
+
+ // RTT is the round trip time of the most recently delivered packet on
+ // the connection (either cumulatively acknowledged or selectively
+ // acknowledged) that was not marked invalid as a possible spurious
+ // retransmission.
+ RTT time.Duration
+
+ // Reord is true iff reordering has been detected on this connection.
+ Reord bool
+
+ // DSACKSeen is true iff the connection has seen a DSACK.
+ DSACKSeen bool
+
+ // ReoWnd is the reordering window time used for recording packet
+ // transmission times. It is used to defer the moment at which RACK
+ // marks a packet lost.
+ ReoWnd time.Duration
+
+ // ReoWndIncr is the multiplier applied to adjust reorder window.
+ ReoWndIncr uint8
+
+ // ReoWndPersist is the number of loss recoveries before resetting
+ // reorder window.
+ ReoWndPersist int8
+
+ // RTTSeq is the SND.NXT when RTT is updated.
+ RTTSeq seqnum.Value
+}
+
+// TCPEndpointID is the unique 4 tuple that identifies a given endpoint.
+//
+// +stateify savable
+type TCPEndpointID struct {
+ // LocalPort is the local port associated with the endpoint.
+ LocalPort uint16
+
+ // LocalAddress is the local [network layer] address associated with
+ // the endpoint.
+ LocalAddress tcpip.Address
+
+ // RemotePort is the remote port associated with the endpoint.
+ RemotePort uint16
+
+ // RemoteAddress it the remote [network layer] address associated with
+ // the endpoint.
+ RemoteAddress tcpip.Address
+}
+
+// TCPFastRecoveryState holds a copy of the internal fast recovery state of a
+// TCP endpoint.
+//
+// +stateify savable
+type TCPFastRecoveryState struct {
+ // Active if true indicates the endpoint is in fast recovery. The
+ // following fields are only meaningful when Active is true.
+ Active bool
+
+ // First is the first unacknowledged sequence number being recovered.
+ First seqnum.Value
+
+ // Last is the 'recover' sequence number that indicates the point at
+ // which we should exit recovery barring any timeouts etc.
+ Last seqnum.Value
+
+ // MaxCwnd is the maximum value we are permitted to grow the congestion
+ // window during recovery. This is set at the time we enter recovery.
+ // It exists to avoid attacks where the receiver intentionally sends
+ // duplicate acks to artificially inflate the sender's cwnd.
+ MaxCwnd int
+
+ // HighRxt is the highest sequence number which has been retransmitted
+ // during the current loss recovery phase. See: RFC 6675 Section 2 for
+ // details.
+ HighRxt seqnum.Value
+
+ // RescueRxt is the highest sequence number which has been
+ // optimistically retransmitted to prevent stalling of the ACK clock
+ // when there is loss at the end of the window and no new data is
+ // available for transmission. See: RFC 6675 Section 2 for details.
+ RescueRxt seqnum.Value
+}
+
+// TCPReceiverState holds a copy of the internal state of the receiver for a
+// given TCP endpoint.
+//
+// +stateify savable
+type TCPReceiverState struct {
+ // RcvNxt is the TCP variable RCV.NXT.
+ RcvNxt seqnum.Value
+
+ // RcvAcc is one beyond the last acceptable sequence number. That is,
+ // the "largest" sequence value that the receiver has announced to its
+ // peer that it's willing to accept. This may be different than RcvNxt
+ // + (last advertised receive window) if the receive window is reduced;
+ // in that case we have to reduce the window as we receive more data
+ // instead of shrinking it.
+ RcvAcc seqnum.Value
+
+ // RcvWndScale is the window scaling to use for inbound segments.
+ RcvWndScale uint8
+
+ // PendingBufUsed is the number of bytes pending in the receive queue.
+ PendingBufUsed int
+}
+
+// TCPRTTState holds a copy of information about the endpoint's round trip
+// time.
+//
+// +stateify savable
+type TCPRTTState struct {
+ // SRTT is the smoothed round trip time defined in section 2 of RFC
+ // 6298.
+ SRTT time.Duration
+
+ // RTTVar is the round-trip time variation as defined in section 2 of
+ // RFC 6298.
+ RTTVar time.Duration
+
+ // SRTTInited if true indicates that a valid RTT measurement has been
+ // completed.
+ SRTTInited bool
+}
+
+// TCPSenderState holds a copy of the internal state of the sender for a given
+// TCP Endpoint.
+//
+// +stateify savable
+type TCPSenderState struct {
+ // LastSendTime is the timestamp at which we sent the last segment.
+ LastSendTime time.Time `state:".(unixTime)"`
+
+ // DupAckCount is the number of Duplicate ACKs received. It is used for
+ // fast retransmit.
+ DupAckCount int
+
+ // SndCwnd is the size of the sending congestion window in packets.
+ SndCwnd int
+
+ // Ssthresh is the threshold between slow start and congestion
+ // avoidance.
+ Ssthresh int
+
+ // SndCAAckCount is the number of packets acknowledged during
+ // congestion avoidance. When enough packets have been ack'd (typically
+ // cwnd packets), the congestion window is incremented by one.
+ SndCAAckCount int
+
+ // Outstanding is the number of packets that have been sent but not yet
+ // acknowledged.
+ Outstanding int
+
+ // SackedOut is the number of packets which have been selectively
+ // acked.
+ SackedOut int
+
+ // SndWnd is the send window size in bytes.
+ SndWnd seqnum.Size
+
+ // SndUna is the next unacknowledged sequence number.
+ SndUna seqnum.Value
+
+ // SndNxt is the sequence number of the next segment to be sent.
+ SndNxt seqnum.Value
+
+ // RTTMeasureSeqNum is the sequence number being used for the latest
+ // RTT measurement.
+ RTTMeasureSeqNum seqnum.Value
+
+ // RTTMeasureTime is the time when the RTTMeasureSeqNum was sent.
+ RTTMeasureTime time.Time `state:".(unixTime)"`
+
+ // Closed indicates that the caller has closed the endpoint for
+ // sending.
+ Closed bool
+
+ // RTO is the retransmit timeout as defined in section of 2 of RFC
+ // 6298.
+ RTO time.Duration
+
+ // RTTState holds information about the endpoint's round trip time.
+ RTTState TCPRTTState
+
+ // MaxPayloadSize is the maximum size of the payload of a given
+ // segment. It is initialized on demand.
+ MaxPayloadSize int
+
+ // SndWndScale is the number of bits to shift left when reading the
+ // send window size from a segment.
+ SndWndScale uint8
+
+ // MaxSentAck is the highest acknowledgement number sent till now.
+ MaxSentAck seqnum.Value
+
+ // FastRecovery holds the fast recovery state for the endpoint.
+ FastRecovery TCPFastRecoveryState
+
+ // Cubic holds the state related to CUBIC congestion control.
+ Cubic TCPCubicState
+
+ // RACKState holds the state related to RACK loss detection algorithm.
+ RACKState TCPRACKState
+}
+
+// TCPSACKInfo holds TCP SACK related information for a given TCP endpoint.
+//
+// +stateify savable
+type TCPSACKInfo struct {
+ // Blocks is the list of SACK Blocks that identify the out of order
+ // segments held by a given TCP endpoint.
+ Blocks []header.SACKBlock
+
+ // ReceivedBlocks are the SACK blocks received by this endpoint from
+ // the peer endpoint.
+ ReceivedBlocks []header.SACKBlock
+
+ // MaxSACKED is the highest sequence number that has been SACKED by the
+ // peer.
+ MaxSACKED seqnum.Value
+}
+
+// RcvBufAutoTuneParams holds state related to TCP receive buffer auto-tuning.
+//
+// +stateify savable
+type RcvBufAutoTuneParams struct {
+ // MeasureTime is the time at which the current measurement was
+ // started.
+ MeasureTime time.Time `state:".(unixTime)"`
+
+ // CopiedBytes is the number of bytes copied to user space since this
+ // measure began.
+ CopiedBytes int
+
+ // PrevCopiedBytes is the number of bytes copied to userspace in the
+ // previous RTT period.
+ PrevCopiedBytes int
+
+ // RcvBufSize is the auto tuned receive buffer size.
+ RcvBufSize int
+
+ // RTT is the smoothed RTT as measured by observing the time between
+ // when a byte is first acknowledged and the receipt of data that is at
+ // least one window beyond the sequence number that was acknowledged.
+ RTT time.Duration
+
+ // RTTVar is the "round-trip time variation" as defined in section 2 of
+ // RFC6298.
+ RTTVar time.Duration
+
+ // RTTMeasureSeqNumber is the highest acceptable sequence number at the
+ // time this RTT measurement period began.
+ RTTMeasureSeqNumber seqnum.Value
+
+ // RTTMeasureTime is the absolute time at which the current RTT
+ // measurement period began.
+ RTTMeasureTime time.Time `state:".(unixTime)"`
+
+ // Disabled is true if an explicit receive buffer is set for the
+ // endpoint.
+ Disabled bool
+}
+
+// TCPRcvBufState contains information about the state of an endpoint's receive
+// socket buffer.
+//
+// +stateify savable
+type TCPRcvBufState struct {
+ // RcvBufUsed is the amount of bytes actually held in the receive
+ // socket buffer for the endpoint.
+ RcvBufUsed int
+
+ // RcvBufAutoTuneParams is used to hold state variables to compute the
+ // auto tuned receive buffer size.
+ RcvAutoParams RcvBufAutoTuneParams
+
+ // RcvClosed if true, indicates the endpoint has been closed for
+ // reading.
+ RcvClosed bool
+}
+
+// TCPSndBufState contains information about the state of an endpoint's send
+// socket buffer.
+//
+// +stateify savable
+type TCPSndBufState struct {
+ // SndBufSize is the size of the socket send buffer.
+ SndBufSize int
+
+ // SndBufUsed is the number of bytes held in the socket send buffer.
+ SndBufUsed int
+
+ // SndClosed indicates that the endpoint has been closed for sends.
+ SndClosed bool
+
+ // SndBufInQueue is the number of bytes in the send queue.
+ SndBufInQueue seqnum.Size
+
+ // PacketTooBigCount is used to notify the main protocol routine how
+ // many times a "packet too big" control packet is received.
+ PacketTooBigCount int
+
+ // SndMTU is the smallest MTU seen in the control packets received.
+ SndMTU int
+}
+
+// TCPEndpointStateInner contains the members of TCPEndpointState used directly
+// (that is, not within another containing struct) within the endpoint's
+// internal implementation.
+//
+// +stateify savable
+type TCPEndpointStateInner struct {
+ // TSOffset is a randomized offset added to the value of the TSVal
+ // field in the timestamp option.
+ TSOffset uint32
+
+ // SACKPermitted is set to true if the peer sends the TCPSACKPermitted
+ // option in the SYN/SYN-ACK.
+ SACKPermitted bool
+
+ // SendTSOk is used to indicate when the TS Option has been negotiated.
+ // When sendTSOk is true every non-RST segment should carry a TS as per
+ // RFC7323#section-1.1.
+ SendTSOk bool
+
+ // RecentTS is the timestamp that should be sent in the TSEcr field of
+ // the timestamp for future segments sent by the endpoint. This field
+ // is updated if required when a new segment is received by this
+ // endpoint.
+ RecentTS uint32
+}
+
+// TCPEndpointState is a copy of the internal state of a TCP endpoint.
+//
+// +stateify savable
+type TCPEndpointState struct {
+ // TCPEndpointStateInner contains the members of TCPEndpointState used
+ // by the endpoint's internal implementation.
+ TCPEndpointStateInner
+
+ // ID is a copy of the TransportEndpointID for the endpoint.
+ ID TCPEndpointID
+
+ // SegTime denotes the absolute time when this segment was received.
+ SegTime time.Time `state:".(unixTime)"`
+
+ // RcvBufState contains information about the state of the endpoint's
+ // receive socket buffer.
+ RcvBufState TCPRcvBufState
+
+ // SndBufState contains information about the state of the endpoint's
+ // send socket buffer.
+ SndBufState TCPSndBufState
+
+ // SACK holds TCP SACK related information for this endpoint.
+ SACK TCPSACKInfo
+
+ // Receiver holds variables related to the TCP receiver for the
+ // endpoint.
+ Receiver TCPReceiverState
+
+ // Sender holds state related to the TCP Sender for the endpoint.
+ Sender TCPSenderState
+}
diff --git a/pkg/tcpip/stack/transport_demuxer.go b/pkg/tcpip/stack/transport_demuxer.go
index e188efccb..80ad1a9d4 100644
--- a/pkg/tcpip/stack/transport_demuxer.go
+++ b/pkg/tcpip/stack/transport_demuxer.go
@@ -150,16 +150,17 @@ func (epsByNIC *endpointsByNIC) transportEndpoints() []TransportEndpoint {
return eps
}
-// HandlePacket is called by the stack when new packets arrive to this transport
-// endpoint.
-func (epsByNIC *endpointsByNIC) handlePacket(id TransportEndpointID, pkt *PacketBuffer) {
+// handlePacket is called by the stack when new packets arrive to this transport
+// endpoint. It returns false if the packet could not be matched to any
+// transport endpoint, true otherwise.
+func (epsByNIC *endpointsByNIC) handlePacket(id TransportEndpointID, pkt *PacketBuffer) bool {
epsByNIC.mu.RLock()
mpep, ok := epsByNIC.endpoints[pkt.NICID]
if !ok {
if mpep, ok = epsByNIC.endpoints[0]; !ok {
epsByNIC.mu.RUnlock() // Don't use defer for performance reasons.
- return
+ return false
}
}
@@ -168,18 +169,19 @@ func (epsByNIC *endpointsByNIC) handlePacket(id TransportEndpointID, pkt *Packet
if isInboundMulticastOrBroadcast(pkt, id.LocalAddress) {
mpep.handlePacketAll(id, pkt)
epsByNIC.mu.RUnlock() // Don't use defer for performance reasons.
- return
+ return true
}
// multiPortEndpoints are guaranteed to have at least one element.
transEP := selectEndpoint(id, mpep, epsByNIC.seed)
if queuedProtocol, mustQueue := mpep.demux.queuedProtocols[protocolIDs{mpep.netProto, mpep.transProto}]; mustQueue {
queuedProtocol.QueuePacket(transEP, id, pkt)
epsByNIC.mu.RUnlock()
- return
+ return true
}
transEP.HandlePacket(id, pkt)
epsByNIC.mu.RUnlock() // Don't use defer for performance reasons.
+ return true
}
// handleError delivers an error to the transport endpoint identified by id.
@@ -567,8 +569,7 @@ func (d *transportDemuxer) deliverPacket(protocol tcpip.TransportProtocolNumber,
}
return false
}
- ep.handlePacket(id, pkt)
- return true
+ return ep.handlePacket(id, pkt)
}
// deliverRawPacket attempts to deliver the given packet and returns whether it
diff --git a/pkg/tcpip/stack/transport_test.go b/pkg/tcpip/stack/transport_test.go
index 054cced0c..839178809 100644
--- a/pkg/tcpip/stack/transport_test.go
+++ b/pkg/tcpip/stack/transport_test.go
@@ -70,7 +70,7 @@ func (f *fakeTransportEndpoint) SocketOptions() *tcpip.SocketOptions {
func newFakeTransportEndpoint(proto *fakeTransportProtocol, netProto tcpip.NetworkProtocolNumber, s *stack.Stack) tcpip.Endpoint {
ep := &fakeTransportEndpoint{TransportEndpointInfo: stack.TransportEndpointInfo{NetProto: netProto}, proto: proto, uniqueID: s.UniqueID()}
- ep.ops.InitHandler(ep, s, tcpip.GetStackSendBufferLimits)
+ ep.ops.InitHandler(ep, s, tcpip.GetStackSendBufferLimits, tcpip.GetStackReceiveBufferLimits)
return ep
}
@@ -106,7 +106,7 @@ func (f *fakeTransportEndpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions
Data: buffer.View(v).ToVectorisedView(),
})
_ = pkt.TransportHeader().Push(fakeTransHeaderLen)
- if err := f.route.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: fakeTransNumber, TTL: 123, TOS: stack.DefaultTOS}, pkt); err != nil {
+ if err := f.route.WritePacket(stack.NetworkHeaderParams{Protocol: fakeTransNumber, TTL: 123, TOS: stack.DefaultTOS}, pkt); err != nil {
return 0, err
}
@@ -233,7 +233,7 @@ func (f *fakeTransportEndpoint) HandlePacket(id stack.TransportEndpointID, pkt *
peerAddr: route.RemoteAddress(),
route: route,
}
- ep.ops.InitHandler(ep, f.proto.stack, tcpip.GetStackSendBufferLimits)
+ ep.ops.InitHandler(ep, f.proto.stack, tcpip.GetStackSendBufferLimits, tcpip.GetStackReceiveBufferLimits)
f.acceptQueue = append(f.acceptQueue, ep)
}