diff options
Diffstat (limited to 'pkg/tcpip/stack')
24 files changed, 1705 insertions, 1106 deletions
diff --git a/pkg/tcpip/stack/BUILD b/pkg/tcpip/stack/BUILD index 49362333a..2bd6a67f5 100644 --- a/pkg/tcpip/stack/BUILD +++ b/pkg/tcpip/stack/BUILD @@ -45,6 +45,7 @@ go_library( "addressable_endpoint_state.go", "conntrack.go", "headertype_string.go", + "hook_string.go", "icmp_rate_limit.go", "iptables.go", "iptables_state.go", @@ -66,6 +67,7 @@ go_library( "stack.go", "stack_global_state.go", "stack_options.go", + "tcp.go", "transport_demuxer.go", "tuple_list.go", ], @@ -115,6 +117,7 @@ go_test( "//pkg/tcpip/network/ipv4", "//pkg/tcpip/network/ipv6", "//pkg/tcpip/ports", + "//pkg/tcpip/testutil", "//pkg/tcpip/transport/icmp", "//pkg/tcpip/transport/udp", "//pkg/waiter", @@ -139,6 +142,7 @@ go_test( "//pkg/tcpip/buffer", "//pkg/tcpip/faketime", "//pkg/tcpip/header", + "//pkg/tcpip/testutil", "@com_github_google_go_cmp//cmp:go_default_library", "@com_github_google_go_cmp//cmp/cmpopts:go_default_library", ], diff --git a/pkg/tcpip/stack/conntrack.go b/pkg/tcpip/stack/conntrack.go index 3f083928f..5720e7543 100644 --- a/pkg/tcpip/stack/conntrack.go +++ b/pkg/tcpip/stack/conntrack.go @@ -16,6 +16,7 @@ package stack import ( "encoding/binary" + "fmt" "sync" "time" @@ -29,7 +30,7 @@ import ( // The connection is created for a packet if it does not exist. Every // connection contains two tuples (original and reply). The tuples are // manipulated if there is a matching NAT rule. The packet is modified by -// looking at the tuples in the Prerouting and Output hooks. +// looking at the tuples in each hook. // // Currently, only TCP tracking is supported. @@ -46,12 +47,14 @@ const ( ) // Manipulation type for the connection. +// TODO(gvisor.dev/issue/5696): Define this as a bit set and support SNAT and +// DNAT at the same time. type manipType int const ( manipNone manipType = iota - manipDstPrerouting - manipDstOutput + manipSource + manipDestination ) // tuple holds a connection's identifying and manipulating data in one @@ -108,6 +111,7 @@ type conn struct { reply tuple // manip indicates if the packet should be manipulated. It is immutable. + // TODO(gvisor.dev/issue/5696): Support updating manipulation type. manip manipType // tcbHook indicates if the packet is inbound or outbound to @@ -124,6 +128,18 @@ type conn struct { lastUsed time.Time `state:".(unixTime)"` } +// newConn creates new connection. +func newConn(orig, reply tupleID, manip manipType, hook Hook) *conn { + conn := conn{ + manip: manip, + tcbHook: hook, + lastUsed: time.Now(), + } + conn.original = tuple{conn: &conn, tupleID: orig} + conn.reply = tuple{conn: &conn, tupleID: reply, direction: dirReply} + return &conn +} + // timedOut returns whether the connection timed out based on its state. func (cn *conn) timedOut(now time.Time) bool { const establishedTimeout = 5 * 24 * time.Hour @@ -219,18 +235,6 @@ func packetToTupleID(pkt *PacketBuffer) (tupleID, tcpip.Error) { }, nil } -// newConn creates new connection. -func newConn(orig, reply tupleID, manip manipType, hook Hook) *conn { - conn := conn{ - manip: manip, - tcbHook: hook, - lastUsed: time.Now(), - } - conn.original = tuple{conn: &conn, tupleID: orig} - conn.reply = tuple{conn: &conn, tupleID: reply, direction: dirReply} - return &conn -} - func (ct *ConnTrack) init() { ct.mu.Lock() defer ct.mu.Unlock() @@ -284,20 +288,41 @@ func (ct *ConnTrack) insertRedirectConn(pkt *PacketBuffer, hook Hook, port uint1 return nil } - // Create a new connection and change the port as per the iptables - // rule. This tuple will be used to manipulate the packet in - // handlePacket. replyTID := tid.reply() replyTID.srcAddr = address replyTID.srcPort = port - var manip manipType - switch hook { - case Prerouting: - manip = manipDstPrerouting - case Output: - manip = manipDstOutput + + conn, _ := ct.connForTID(tid) + if conn != nil { + // The connection is already tracked. + // TODO(gvisor.dev/issue/5696): Support updating an existing connection. + return nil } - conn := newConn(tid, replyTID, manip, hook) + conn = newConn(tid, replyTID, manipDestination, hook) + ct.insertConn(conn) + return conn +} + +func (ct *ConnTrack) insertSNATConn(pkt *PacketBuffer, hook Hook, port uint16, address tcpip.Address) *conn { + tid, err := packetToTupleID(pkt) + if err != nil { + return nil + } + if hook != Input && hook != Postrouting { + return nil + } + + replyTID := tid.reply() + replyTID.dstAddr = address + replyTID.dstPort = port + + conn, _ := ct.connForTID(tid) + if conn != nil { + // The connection is already tracked. + // TODO(gvisor.dev/issue/5696): Support updating an existing connection. + return nil + } + conn = newConn(tid, replyTID, manipSource, hook) ct.insertConn(conn) return conn } @@ -322,6 +347,7 @@ func (ct *ConnTrack) insertConn(conn *conn) { // Now that we hold the locks, ensure the tuple hasn't been inserted by // another thread. + // TODO(gvisor.dev/issue/5773): Should check conn.reply.tupleID, too? alreadyInserted := false for other := ct.buckets[tupleBucket].tuples.Front(); other != nil; other = other.Next() { if other.tupleID == conn.original.tupleID { @@ -343,95 +369,17 @@ func (ct *ConnTrack) insertConn(conn *conn) { } } -// handlePacketPrerouting manipulates ports for packets in Prerouting hook. -// TODO(gvisor.dev/issue/170): Change address for Prerouting hook. -func handlePacketPrerouting(pkt *PacketBuffer, conn *conn, dir direction) { - // If this is a noop entry, don't do anything. - if conn.manip == manipNone { - return - } - - netHeader := pkt.Network() - tcpHeader := header.TCP(pkt.TransportHeader().View()) - - // For prerouting redirection, packets going in the original direction - // have their destinations modified and replies have their sources - // modified. - switch dir { - case dirOriginal: - port := conn.reply.srcPort - tcpHeader.SetDestinationPort(port) - netHeader.SetDestinationAddress(conn.reply.srcAddr) - case dirReply: - port := conn.original.dstPort - tcpHeader.SetSourcePort(port) - netHeader.SetSourceAddress(conn.original.dstAddr) - } - - // TODO(gvisor.dev/issue/170): TCP checksums aren't usually validated - // on inbound packets, so we don't recalculate them. However, we should - // support cases when they are validated, e.g. when we can't offload - // receive checksumming. - - // After modification, IPv4 packets need a valid checksum. - if pkt.NetworkProtocolNumber == header.IPv4ProtocolNumber { - netHeader := header.IPv4(pkt.NetworkHeader().View()) - netHeader.SetChecksum(0) - netHeader.SetChecksum(^netHeader.CalculateChecksum()) - } -} - -// handlePacketOutput manipulates ports for packets in Output hook. -func handlePacketOutput(pkt *PacketBuffer, conn *conn, gso *GSO, r *Route, dir direction) { - // If this is a noop entry, don't do anything. - if conn.manip == manipNone { - return - } - - netHeader := pkt.Network() - tcpHeader := header.TCP(pkt.TransportHeader().View()) - - // For output redirection, packets going in the original direction - // have their destinations modified and replies have their sources - // modified. For prerouting redirection, we only reach this point - // when replying, so packet sources are modified. - if conn.manip == manipDstOutput && dir == dirOriginal { - port := conn.reply.srcPort - tcpHeader.SetDestinationPort(port) - netHeader.SetDestinationAddress(conn.reply.srcAddr) - } else { - port := conn.original.dstPort - tcpHeader.SetSourcePort(port) - netHeader.SetSourceAddress(conn.original.dstAddr) - } - - // Calculate the TCP checksum and set it. - tcpHeader.SetChecksum(0) - length := uint16(len(tcpHeader) + pkt.Data().Size()) - xsum := header.PseudoHeaderChecksum(header.TCPProtocolNumber, netHeader.SourceAddress(), netHeader.DestinationAddress(), length) - if gso != nil && gso.NeedsCsum { - tcpHeader.SetChecksum(xsum) - } else if r.RequiresTXTransportChecksum() { - xsum = header.ChecksumCombine(xsum, pkt.Data().AsRange().Checksum()) - tcpHeader.SetChecksum(^tcpHeader.CalculateChecksum(xsum)) - } - - if pkt.NetworkProtocolNumber == header.IPv4ProtocolNumber { - netHeader := header.IPv4(pkt.NetworkHeader().View()) - netHeader.SetChecksum(0) - netHeader.SetChecksum(^netHeader.CalculateChecksum()) - } -} - // handlePacket will manipulate the port and address of the packet if the // connection exists. Returns whether, after the packet traverses the tables, // it should create a new entry in the table. -func (ct *ConnTrack) handlePacket(pkt *PacketBuffer, hook Hook, gso *GSO, r *Route) bool { +func (ct *ConnTrack) handlePacket(pkt *PacketBuffer, hook Hook, r *Route) bool { if pkt.NatDone { return false } - if hook != Prerouting && hook != Output { + switch hook { + case Prerouting, Input, Output, Postrouting: + default: return false } @@ -441,23 +389,79 @@ func (ct *ConnTrack) handlePacket(pkt *PacketBuffer, hook Hook, gso *GSO, r *Rou } conn, dir := ct.connFor(pkt) - // Connection or Rule not found for the packet. + // Connection not found for the packet. if conn == nil { - return true + // If this is the last hook in the data path for this packet (Input if + // incoming, Postrouting if outgoing), indicate that a connection should be + // inserted by the end of this hook. + return hook == Input || hook == Postrouting } + netHeader := pkt.Network() tcpHeader := header.TCP(pkt.TransportHeader().View()) if len(tcpHeader) < header.TCPMinimumSize { return false } + // TODO(gvisor.dev/issue/5748): TCP checksums on inbound packets should be + // validated if checksum offloading is off. It may require IP defrag if the + // packets are fragmented. + + switch hook { + case Prerouting, Output: + if conn.manip == manipDestination { + switch dir { + case dirOriginal: + tcpHeader.SetDestinationPort(conn.reply.srcPort) + netHeader.SetDestinationAddress(conn.reply.srcAddr) + case dirReply: + tcpHeader.SetSourcePort(conn.original.dstPort) + netHeader.SetSourceAddress(conn.original.dstAddr) + } + pkt.NatDone = true + } + case Input, Postrouting: + if conn.manip == manipSource { + switch dir { + case dirOriginal: + tcpHeader.SetSourcePort(conn.reply.dstPort) + netHeader.SetSourceAddress(conn.reply.dstAddr) + case dirReply: + tcpHeader.SetDestinationPort(conn.original.srcPort) + netHeader.SetDestinationAddress(conn.original.srcAddr) + } + pkt.NatDone = true + } + default: + panic(fmt.Sprintf("unrecognized hook = %s", hook)) + } + if !pkt.NatDone { + return false + } + switch hook { - case Prerouting: - handlePacketPrerouting(pkt, conn, dir) - case Output: - handlePacketOutput(pkt, conn, gso, r, dir) + case Prerouting, Input: + case Output, Postrouting: + // Calculate the TCP checksum and set it. + tcpHeader.SetChecksum(0) + length := uint16(len(tcpHeader) + pkt.Data().Size()) + xsum := header.PseudoHeaderChecksum(header.TCPProtocolNumber, netHeader.SourceAddress(), netHeader.DestinationAddress(), length) + if pkt.GSOOptions.Type != GSONone && pkt.GSOOptions.NeedsCsum { + tcpHeader.SetChecksum(xsum) + } else if r.RequiresTXTransportChecksum() { + xsum = header.ChecksumCombine(xsum, pkt.Data().AsRange().Checksum()) + tcpHeader.SetChecksum(^tcpHeader.CalculateChecksum(xsum)) + } + default: + panic(fmt.Sprintf("unrecognized hook = %s", hook)) + } + + // After modification, IPv4 packets need a valid checksum. + if pkt.NetworkProtocolNumber == header.IPv4ProtocolNumber { + netHeader := header.IPv4(pkt.NetworkHeader().View()) + netHeader.SetChecksum(0) + netHeader.SetChecksum(^netHeader.CalculateChecksum()) } - pkt.NatDone = true // Update the state of tcb. // TODO(gvisor.dev/issue/170): Add support in tcpcontrack to handle @@ -638,8 +642,8 @@ func (ct *ConnTrack) originalDst(epID TransportEndpointID, netProto tcpip.Networ if conn == nil { // Not a tracked connection. return "", 0, &tcpip.ErrNotConnected{} - } else if conn.manip == manipNone { - // Unmanipulated connection. + } else if conn.manip != manipDestination { + // Unmanipulated destination. return "", 0, &tcpip.ErrInvalidOptionValue{} } diff --git a/pkg/tcpip/stack/forwarding_test.go b/pkg/tcpip/stack/forwarding_test.go index 16ee75bc4..7d3725681 100644 --- a/pkg/tcpip/stack/forwarding_test.go +++ b/pkg/tcpip/stack/forwarding_test.go @@ -101,7 +101,7 @@ func (f *fwdTestNetworkEndpoint) HandlePacket(pkt *PacketBuffer) { ReserveHeaderBytes: int(r.MaxHeaderLength()), Data: vv.ToView().ToVectorisedView(), }) - // TODO(b/143425874) Decrease the TTL field in forwarded packets. + // TODO(gvisor.dev/issue/1085) Decrease the TTL field in forwarded packets. _ = r.WriteHeaderIncludedPacket(pkt) } @@ -117,7 +117,7 @@ func (f *fwdTestNetworkEndpoint) NetworkProtocolNumber() tcpip.NetworkProtocolNu return f.proto.Number() } -func (f *fwdTestNetworkEndpoint) WritePacket(r *Route, gso *GSO, params NetworkHeaderParams, pkt *PacketBuffer) tcpip.Error { +func (f *fwdTestNetworkEndpoint) WritePacket(r *Route, params NetworkHeaderParams, pkt *PacketBuffer) tcpip.Error { // Add the protocol's header to the packet and send it to the link // endpoint. b := pkt.NetworkHeader().Push(fwdTestNetHeaderLen) @@ -125,11 +125,11 @@ func (f *fwdTestNetworkEndpoint) WritePacket(r *Route, gso *GSO, params NetworkH b[srcAddrOffset] = r.LocalAddress()[0] b[protocolNumberOffset] = byte(params.Protocol) - return f.nic.WritePacket(r, gso, fwdTestNetNumber, pkt) + return f.nic.WritePacket(r, fwdTestNetNumber, pkt) } // WritePackets implements LinkEndpoint.WritePackets. -func (*fwdTestNetworkEndpoint) WritePackets(r *Route, gso *GSO, pkts PacketBufferList, params NetworkHeaderParams) (int, tcpip.Error) { +func (*fwdTestNetworkEndpoint) WritePackets(r *Route, pkts PacketBufferList, params NetworkHeaderParams) (int, tcpip.Error) { panic("not implemented") } @@ -139,7 +139,7 @@ func (f *fwdTestNetworkEndpoint) WriteHeaderIncludedPacket(r *Route, pkt *Packet return &tcpip.ErrMalformedHeader{} } - return f.nic.WritePacket(r, nil /* gso */, fwdTestNetNumber, pkt) + return f.nic.WritePacket(r, fwdTestNetNumber, pkt) } func (f *fwdTestNetworkEndpoint) Close() { @@ -264,6 +264,8 @@ type fwdTestPacketInfo struct { Pkt *PacketBuffer } +var _ LinkEndpoint = (*fwdTestLinkEndpoint)(nil) + type fwdTestLinkEndpoint struct { dispatcher NetworkDispatcher mtu uint32 @@ -306,11 +308,6 @@ func (e fwdTestLinkEndpoint) Capabilities() LinkEndpointCapabilities { return caps | CapabilityResolutionRequired } -// GSOMaxSize returns the maximum GSO packet size. -func (*fwdTestLinkEndpoint) GSOMaxSize() uint32 { - return 1 << 15 -} - // MaxHeaderLength returns the maximum size of the link layer header. Given it // doesn't have a header, it just returns 0. func (*fwdTestLinkEndpoint) MaxHeaderLength() uint16 { @@ -322,7 +319,7 @@ func (e *fwdTestLinkEndpoint) LinkAddress() tcpip.LinkAddress { return e.linkAddr } -func (e fwdTestLinkEndpoint) WritePacket(r RouteInfo, gso *GSO, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) tcpip.Error { +func (e fwdTestLinkEndpoint) WritePacket(r RouteInfo, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) tcpip.Error { p := fwdTestPacketInfo{ RemoteLinkAddress: r.RemoteLinkAddress, LocalLinkAddress: r.LocalLinkAddress, @@ -338,10 +335,10 @@ func (e fwdTestLinkEndpoint) WritePacket(r RouteInfo, gso *GSO, protocol tcpip.N } // WritePackets stores outbound packets into the channel. -func (e *fwdTestLinkEndpoint) WritePackets(r RouteInfo, gso *GSO, pkts PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, tcpip.Error) { +func (e *fwdTestLinkEndpoint) WritePackets(r RouteInfo, pkts PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, tcpip.Error) { n := 0 for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() { - e.WritePacket(r, gso, protocol, pkt) + e.WritePacket(r, protocol, pkt) n++ } diff --git a/pkg/tcpip/stack/hook_string.go b/pkg/tcpip/stack/hook_string.go new file mode 100644 index 000000000..3dc8a7b02 --- /dev/null +++ b/pkg/tcpip/stack/hook_string.go @@ -0,0 +1,41 @@ +// Copyright 2021 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at // +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by "stringer -type Hook ."; DO NOT EDIT. + +package stack + +import "strconv" + +func _() { + // An "invalid array index" compiler error signifies that the constant values have changed. + // Re-run the stringer command to generate them again. + var x [1]struct{} + _ = x[Prerouting-0] + _ = x[Input-1] + _ = x[Forward-2] + _ = x[Output-3] + _ = x[Postrouting-4] + _ = x[NumHooks-5] +} + +const _Hook_name = "PreroutingInputForwardOutputPostroutingNumHooks" + +var _Hook_index = [...]uint8{0, 10, 15, 22, 28, 39, 47} + +func (i Hook) String() string { + if i >= Hook(len(_Hook_index)-1) { + return "Hook(" + strconv.FormatInt(int64(i), 10) + ")" + } + return _Hook_name[_Hook_index[i]:_Hook_index[i+1]] +} diff --git a/pkg/tcpip/stack/iptables.go b/pkg/tcpip/stack/iptables.go index 52890f6eb..e2894c548 100644 --- a/pkg/tcpip/stack/iptables.go +++ b/pkg/tcpip/stack/iptables.go @@ -175,9 +175,10 @@ func DefaultTables() *IPTables { }, }, priorities: [NumHooks][]TableID{ - Prerouting: {MangleID, NATID}, - Input: {NATID, FilterID}, - Output: {MangleID, NATID, FilterID}, + Prerouting: {MangleID, NATID}, + Input: {NATID, FilterID}, + Output: {MangleID, NATID, FilterID}, + Postrouting: {MangleID, NATID}, }, connections: ConnTrack{ seed: generateRandUint32(), @@ -266,12 +267,12 @@ const ( // should continue traversing the network stack and false when it should be // dropped. // -// TODO(gvisor.dev/issue/170): PacketBuffer should hold the GSO and route, from +// TODO(gvisor.dev/issue/170): PacketBuffer should hold the route, from // which address can be gathered. Currently, address is only needed for // prerouting. // // Precondition: pkt.NetworkHeader is set. -func (it *IPTables) Check(hook Hook, pkt *PacketBuffer, gso *GSO, r *Route, preroutingAddr tcpip.Address, inNicName, outNicName string) bool { +func (it *IPTables) Check(hook Hook, pkt *PacketBuffer, r *Route, preroutingAddr tcpip.Address, inNicName, outNicName string) bool { if pkt.NetworkProtocolNumber != header.IPv4ProtocolNumber && pkt.NetworkProtocolNumber != header.IPv6ProtocolNumber { return true } @@ -285,7 +286,7 @@ func (it *IPTables) Check(hook Hook, pkt *PacketBuffer, gso *GSO, r *Route, prer // Packets are manipulated only if connection and matching // NAT rule exists. - shouldTrack := it.connections.handlePacket(pkt, hook, gso, r) + shouldTrack := it.connections.handlePacket(pkt, hook, r) // Go through each table containing the hook. priorities := it.priorities[hook] @@ -302,7 +303,7 @@ func (it *IPTables) Check(hook Hook, pkt *PacketBuffer, gso *GSO, r *Route, prer table = it.v4Tables[tableID] } ruleIdx := table.BuiltinChains[hook] - switch verdict := it.checkChain(hook, pkt, table, ruleIdx, gso, r, preroutingAddr, inNicName, outNicName); verdict { + switch verdict := it.checkChain(hook, pkt, table, ruleIdx, r, preroutingAddr, inNicName, outNicName); verdict { // If the table returns Accept, move on to the next table. case chainAccept: continue @@ -313,7 +314,7 @@ func (it *IPTables) Check(hook Hook, pkt *PacketBuffer, gso *GSO, r *Route, prer // Any Return from a built-in chain means we have to // call the underflow. underflow := table.Rules[table.Underflows[hook]] - switch v, _ := underflow.Target.Action(pkt, &it.connections, hook, gso, r, preroutingAddr); v { + switch v, _ := underflow.Target.Action(pkt, &it.connections, hook, r, preroutingAddr); v { case RuleAccept: continue case RuleDrop: @@ -385,10 +386,10 @@ func (it *IPTables) startReaper(interval time.Duration) { // // NOTE: unlike the Check API the returned map contains packets that should be // dropped. -func (it *IPTables) CheckPackets(hook Hook, pkts PacketBufferList, gso *GSO, r *Route, inNicName, outNicName string) (drop map[*PacketBuffer]struct{}, natPkts map[*PacketBuffer]struct{}) { +func (it *IPTables) CheckPackets(hook Hook, pkts PacketBufferList, r *Route, inNicName, outNicName string) (drop map[*PacketBuffer]struct{}, natPkts map[*PacketBuffer]struct{}) { for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() { if !pkt.NatDone { - if ok := it.Check(hook, pkt, gso, r, "", inNicName, outNicName); !ok { + if ok := it.Check(hook, pkt, r, "", inNicName, outNicName); !ok { if drop == nil { drop = make(map[*PacketBuffer]struct{}) } @@ -408,11 +409,11 @@ func (it *IPTables) CheckPackets(hook Hook, pkts PacketBufferList, gso *GSO, r * // Preconditions: // * pkt is a IPv4 packet of at least length header.IPv4MinimumSize. // * pkt.NetworkHeader is not nil. -func (it *IPTables) checkChain(hook Hook, pkt *PacketBuffer, table Table, ruleIdx int, gso *GSO, r *Route, preroutingAddr tcpip.Address, inNicName, outNicName string) chainVerdict { +func (it *IPTables) checkChain(hook Hook, pkt *PacketBuffer, table Table, ruleIdx int, r *Route, preroutingAddr tcpip.Address, inNicName, outNicName string) chainVerdict { // Start from ruleIdx and walk the list of rules until a rule gives us // a verdict. for ruleIdx < len(table.Rules) { - switch verdict, jumpTo := it.checkRule(hook, pkt, table, ruleIdx, gso, r, preroutingAddr, inNicName, outNicName); verdict { + switch verdict, jumpTo := it.checkRule(hook, pkt, table, ruleIdx, r, preroutingAddr, inNicName, outNicName); verdict { case RuleAccept: return chainAccept @@ -429,7 +430,7 @@ func (it *IPTables) checkChain(hook Hook, pkt *PacketBuffer, table Table, ruleId ruleIdx++ continue } - switch verdict := it.checkChain(hook, pkt, table, jumpTo, gso, r, preroutingAddr, inNicName, outNicName); verdict { + switch verdict := it.checkChain(hook, pkt, table, jumpTo, r, preroutingAddr, inNicName, outNicName); verdict { case chainAccept: return chainAccept case chainDrop: @@ -455,7 +456,7 @@ func (it *IPTables) checkChain(hook Hook, pkt *PacketBuffer, table Table, ruleId // Preconditions: // * pkt is a IPv4 packet of at least length header.IPv4MinimumSize. // * pkt.NetworkHeader is not nil. -func (it *IPTables) checkRule(hook Hook, pkt *PacketBuffer, table Table, ruleIdx int, gso *GSO, r *Route, preroutingAddr tcpip.Address, inNicName, outNicName string) (RuleVerdict, int) { +func (it *IPTables) checkRule(hook Hook, pkt *PacketBuffer, table Table, ruleIdx int, r *Route, preroutingAddr tcpip.Address, inNicName, outNicName string) (RuleVerdict, int) { rule := table.Rules[ruleIdx] // Check whether the packet matches the IP header filter. @@ -478,7 +479,7 @@ func (it *IPTables) checkRule(hook Hook, pkt *PacketBuffer, table Table, ruleIdx } // All the matchers matched, so run the target. - return rule.Target.Action(pkt, &it.connections, hook, gso, r, preroutingAddr) + return rule.Target.Action(pkt, &it.connections, hook, r, preroutingAddr) } // OriginalDst returns the original destination of redirected connections. It diff --git a/pkg/tcpip/stack/iptables_targets.go b/pkg/tcpip/stack/iptables_targets.go index 0e8b90c9b..2812c89aa 100644 --- a/pkg/tcpip/stack/iptables_targets.go +++ b/pkg/tcpip/stack/iptables_targets.go @@ -29,7 +29,7 @@ type AcceptTarget struct { } // Action implements Target.Action. -func (*AcceptTarget) Action(*PacketBuffer, *ConnTrack, Hook, *GSO, *Route, tcpip.Address) (RuleVerdict, int) { +func (*AcceptTarget) Action(*PacketBuffer, *ConnTrack, Hook, *Route, tcpip.Address) (RuleVerdict, int) { return RuleAccept, 0 } @@ -40,7 +40,7 @@ type DropTarget struct { } // Action implements Target.Action. -func (*DropTarget) Action(*PacketBuffer, *ConnTrack, Hook, *GSO, *Route, tcpip.Address) (RuleVerdict, int) { +func (*DropTarget) Action(*PacketBuffer, *ConnTrack, Hook, *Route, tcpip.Address) (RuleVerdict, int) { return RuleDrop, 0 } @@ -52,7 +52,7 @@ type ErrorTarget struct { } // Action implements Target.Action. -func (*ErrorTarget) Action(*PacketBuffer, *ConnTrack, Hook, *GSO, *Route, tcpip.Address) (RuleVerdict, int) { +func (*ErrorTarget) Action(*PacketBuffer, *ConnTrack, Hook, *Route, tcpip.Address) (RuleVerdict, int) { log.Debugf("ErrorTarget triggered.") return RuleDrop, 0 } @@ -67,7 +67,7 @@ type UserChainTarget struct { } // Action implements Target.Action. -func (*UserChainTarget) Action(*PacketBuffer, *ConnTrack, Hook, *GSO, *Route, tcpip.Address) (RuleVerdict, int) { +func (*UserChainTarget) Action(*PacketBuffer, *ConnTrack, Hook, *Route, tcpip.Address) (RuleVerdict, int) { panic("UserChainTarget should never be called.") } @@ -79,7 +79,7 @@ type ReturnTarget struct { } // Action implements Target.Action. -func (*ReturnTarget) Action(*PacketBuffer, *ConnTrack, Hook, *GSO, *Route, tcpip.Address) (RuleVerdict, int) { +func (*ReturnTarget) Action(*PacketBuffer, *ConnTrack, Hook, *Route, tcpip.Address) (RuleVerdict, int) { return RuleReturn, 0 } @@ -103,7 +103,7 @@ type RedirectTarget struct { // TODO(gvisor.dev/issue/170): Parse headers without copying. The current // implementation only works for Prerouting and calls pkt.Clone(), neither // of which should be the case. -func (rt *RedirectTarget) Action(pkt *PacketBuffer, ct *ConnTrack, hook Hook, gso *GSO, r *Route, address tcpip.Address) (RuleVerdict, int) { +func (rt *RedirectTarget) Action(pkt *PacketBuffer, ct *ConnTrack, hook Hook, r *Route, address tcpip.Address) (RuleVerdict, int) { // Sanity check. if rt.NetworkProtocol != pkt.NetworkProtocolNumber { panic(fmt.Sprintf( @@ -174,7 +174,85 @@ func (rt *RedirectTarget) Action(pkt *PacketBuffer, ct *ConnTrack, hook Hook, gs // packet of the connection comes here. Other packets will be // manipulated in connection tracking. if conn := ct.insertRedirectConn(pkt, hook, rt.Port, address); conn != nil { - ct.handlePacket(pkt, hook, gso, r) + ct.handlePacket(pkt, hook, r) + } + default: + return RuleDrop, 0 + } + + return RuleAccept, 0 +} + +// SNATTarget modifies the source port/IP in the outgoing packets. +type SNATTarget struct { + Addr tcpip.Address + Port uint16 + + // NetworkProtocol is the network protocol the target is used with. It + // is immutable. + NetworkProtocol tcpip.NetworkProtocolNumber +} + +// Action implements Target.Action. +func (st *SNATTarget) Action(pkt *PacketBuffer, ct *ConnTrack, hook Hook, r *Route, address tcpip.Address) (RuleVerdict, int) { + // Sanity check. + if st.NetworkProtocol != pkt.NetworkProtocolNumber { + panic(fmt.Sprintf( + "SNATTarget.Action with NetworkProtocol %d called on packet with NetworkProtocolNumber %d", + st.NetworkProtocol, pkt.NetworkProtocolNumber)) + } + + // Packet is already manipulated. + if pkt.NatDone { + return RuleAccept, 0 + } + + // Drop the packet if network and transport header are not set. + if pkt.NetworkHeader().View().IsEmpty() || pkt.TransportHeader().View().IsEmpty() { + return RuleDrop, 0 + } + + switch hook { + case Postrouting, Input: + case Prerouting, Output, Forward: + panic(fmt.Sprintf("%s not supported", hook)) + default: + panic(fmt.Sprintf("%s unrecognized", hook)) + } + + switch protocol := pkt.TransportProtocolNumber; protocol { + case header.UDPProtocolNumber: + udpHeader := header.UDP(pkt.TransportHeader().View()) + udpHeader.SetChecksum(0) + udpHeader.SetSourcePort(st.Port) + netHeader := pkt.Network() + netHeader.SetSourceAddress(st.Addr) + + // Only calculate the checksum if offloading isn't supported. + if r.RequiresTXTransportChecksum() { + length := uint16(pkt.Size()) - uint16(len(pkt.NetworkHeader().View())) + xsum := header.PseudoHeaderChecksum(protocol, netHeader.SourceAddress(), netHeader.DestinationAddress(), length) + xsum = header.ChecksumCombine(xsum, pkt.Data().AsRange().Checksum()) + udpHeader.SetChecksum(^udpHeader.CalculateChecksum(xsum)) + } + + // After modification, IPv4 packets need a valid checksum. + if pkt.NetworkProtocolNumber == header.IPv4ProtocolNumber { + netHeader := header.IPv4(pkt.NetworkHeader().View()) + netHeader.SetChecksum(0) + netHeader.SetChecksum(^netHeader.CalculateChecksum()) + } + pkt.NatDone = true + case header.TCPProtocolNumber: + if ct == nil { + return RuleAccept, 0 + } + + // Set up conection for matching NAT rule. Only the first + // packet of the connection comes here. Other packets will be + // manipulated in connection tracking. + if conn := ct.insertSNATConn(pkt, hook, st.Port, st.Addr); conn != nil { + ct.handlePacket(pkt, hook, r) } default: return RuleDrop, 0 diff --git a/pkg/tcpip/stack/iptables_types.go b/pkg/tcpip/stack/iptables_types.go index b0d84befb..4631ab93f 100644 --- a/pkg/tcpip/stack/iptables_types.go +++ b/pkg/tcpip/stack/iptables_types.go @@ -345,5 +345,5 @@ type Target interface { // Action takes an action on the packet and returns a verdict on how // traversal should (or should not) continue. If the return value is // Jump, it also returns the index of the rule to jump to. - Action(packet *PacketBuffer, connections *ConnTrack, hook Hook, gso *GSO, r *Route, address tcpip.Address) (RuleVerdict, int) + Action(*PacketBuffer, *ConnTrack, Hook, *Route, tcpip.Address) (RuleVerdict, int) } diff --git a/pkg/tcpip/stack/ndp_test.go b/pkg/tcpip/stack/ndp_test.go index 14124ae66..c585b81b2 100644 --- a/pkg/tcpip/stack/ndp_test.go +++ b/pkg/tcpip/stack/ndp_test.go @@ -33,15 +33,19 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/link/loopback" "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" "gvisor.dev/gvisor/pkg/tcpip/stack" + "gvisor.dev/gvisor/pkg/tcpip/testutil" "gvisor.dev/gvisor/pkg/tcpip/transport/icmp" "gvisor.dev/gvisor/pkg/tcpip/transport/udp" "gvisor.dev/gvisor/pkg/waiter" ) +var ( + addr1 = testutil.MustParse6("a00::1") + addr2 = testutil.MustParse6("a00::2") + addr3 = testutil.MustParse6("a00::3") +) + const ( - addr1 = tcpip.Address("\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01") - addr2 = tcpip.Address("\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02") - addr3 = tcpip.Address("\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x03") linkAddr1 = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x06") linkAddr2 = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x07") linkAddr3 = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x08") @@ -1142,57 +1146,198 @@ func raBufWithPI(ip tcpip.Address, rl uint16, prefix tcpip.AddressWithPrefix, on }) } -// TestNoRouterDiscovery tests that router discovery will not be performed if -// configured not to. -func TestNoRouterDiscovery(t *testing.T) { - // Being configured to discover routers means handle and - // discover are set to true and forwarding is set to false. - // This tests all possible combinations of the configurations, - // except for the configuration where handle = true, discover = - // true and forwarding = false (the required configuration to do - // router discovery) - that will done in other tests. - for i := 0; i < 7; i++ { - handle := i&1 != 0 - discover := i&2 != 0 - forwarding := i&4 == 0 - - t.Run(fmt.Sprintf("HandleRAs(%t), DiscoverDefaultRouters(%t), Forwarding(%t)", handle, discover, forwarding), func(t *testing.T) { - ndpDisp := ndpDispatcher{ - routerC: make(chan ndpRouterEvent, 1), - } - e := channel.New(0, 1280, linkAddr1) - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ - NDPConfigs: ipv6.NDPConfigurations{ - HandleRAs: handle, - DiscoverDefaultRouters: discover, - }, - NDPDisp: &ndpDisp, - })}, - }) - s.SetForwarding(ipv6.ProtocolNumber, forwarding) +func TestDynamicConfigurationsDisabled(t *testing.T) { + const ( + nicID = 1 + maxRtrSolicitDelay = time.Second + ) - if err := s.CreateNIC(1, e); err != nil { - t.Fatalf("CreateNIC(1) = %s", err) - } + prefix := tcpip.AddressWithPrefix{ + Address: testutil.MustParse6("102:304:506:708::"), + PrefixLen: 64, + } - // Rx an RA with non-zero lifetime. - e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr2, 1000)) - select { - case <-ndpDisp.routerC: - t.Fatal("unexpectedly discovered a router when configured not to") - default: + tests := []struct { + name string + config func(bool) ipv6.NDPConfigurations + ra *stack.PacketBuffer + }{ + { + name: "No Router Discovery", + config: func(enable bool) ipv6.NDPConfigurations { + return ipv6.NDPConfigurations{DiscoverDefaultRouters: enable} + }, + ra: raBuf(llAddr2, 1000), + }, + { + name: "No Prefix Discovery", + config: func(enable bool) ipv6.NDPConfigurations { + return ipv6.NDPConfigurations{DiscoverOnLinkPrefixes: enable} + }, + ra: raBufWithPI(llAddr2, 0, prefix, true, false, 10, 0), + }, + { + name: "No Autogenerate Addresses", + config: func(enable bool) ipv6.NDPConfigurations { + return ipv6.NDPConfigurations{AutoGenGlobalAddresses: enable} + }, + ra: raBufWithPI(llAddr2, 0, prefix, false, true, 10, 0), + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + // Being configured to discover routers/prefixes or auto-generate + // addresses means RAs must be handled, and router/prefix discovery or + // SLAAC must be enabled. + // + // This tests all possible combinations of the configurations where + // router/prefix discovery or SLAAC are disabled. + for i := 0; i < 7; i++ { + handle := ipv6.HandlingRAsDisabled + if i&1 != 0 { + handle = ipv6.HandlingRAsEnabledWhenForwardingDisabled + } + enable := i&2 != 0 + forwarding := i&4 == 0 + + t.Run(fmt.Sprintf("HandleRAs(%s), Forwarding(%t), Enabled(%t)", handle, forwarding, enable), func(t *testing.T) { + ndpDisp := ndpDispatcher{ + routerC: make(chan ndpRouterEvent, 1), + prefixC: make(chan ndpPrefixEvent, 1), + autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1), + } + ndpConfigs := test.config(enable) + ndpConfigs.HandleRAs = handle + ndpConfigs.MaxRtrSolicitations = 1 + ndpConfigs.RtrSolicitationInterval = maxRtrSolicitDelay + ndpConfigs.MaxRtrSolicitationDelay = maxRtrSolicitDelay + clock := faketime.NewManualClock() + s := stack.New(stack.Options{ + Clock: clock, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ + NDPConfigs: ndpConfigs, + NDPDisp: &ndpDisp, + })}, + }) + if err := s.SetForwarding(ipv6.ProtocolNumber, forwarding); err != nil { + t.Fatalf("SetForwarding(%d, %t): %s", ipv6.ProtocolNumber, forwarding, err) + } + + e := channel.New(1, 1280, linkAddr1) + if err := s.CreateNIC(nicID, e); err != nil { + t.Fatalf("CreateNIC(%d, _): %s", nicID, err) + } + + handleRAsDisabled := handle == ipv6.HandlingRAsDisabled || forwarding + ep, err := s.GetNetworkEndpoint(nicID, ipv6.ProtocolNumber) + if err != nil { + t.Fatalf("s.GetNetworkEndpoint(%d, %d): %s", nicID, ipv6.ProtocolNumber, err) + } + stats := ep.Stats() + v6Stats, ok := stats.(*ipv6.Stats) + if !ok { + t.Fatalf("got v6Stats = %T, expected = %T", stats, v6Stats) + } + + // Make sure that when handling RAs are enabled, we solicit routers. + clock.Advance(maxRtrSolicitDelay) + if got, want := v6Stats.ICMP.PacketsSent.RouterSolicit.Value(), boolToUint64(!handleRAsDisabled); got != want { + t.Errorf("got v6Stats.ICMP.PacketsSent.RouterSolicit.Value() = %d, want = %d", got, want) + } + if handleRAsDisabled { + if p, ok := e.Read(); ok { + t.Errorf("unexpectedly got a packet = %#v", p) + } + } else if p, ok := e.Read(); !ok { + t.Error("expected router solicitation packet") + } else if p.Proto != header.IPv6ProtocolNumber { + t.Errorf("got Proto = %d, want = %d", p.Proto, header.IPv6ProtocolNumber) + } else { + if want := header.EthernetAddressFromMulticastIPv6Address(header.IPv6AllRoutersLinkLocalMulticastAddress); p.Route.RemoteLinkAddress != want { + t.Errorf("got remote link address = %s, want = %s", p.Route.RemoteLinkAddress, want) + } + + checker.IPv6(t, stack.PayloadSince(p.Pkt.NetworkHeader()), + checker.SrcAddr(header.IPv6Any), + checker.DstAddr(header.IPv6AllRoutersLinkLocalMulticastAddress), + checker.TTL(header.NDPHopLimit), + checker.NDPRS(checker.NDPRSOptions(nil)), + ) + } + + // Make sure we do not discover any routers or prefixes, or perform + // SLAAC on reception of an RA. + e.InjectInbound(header.IPv6ProtocolNumber, test.ra.Clone()) + // Make sure that the unhandled RA stat is only incremented when + // handling RAs is disabled. + if got, want := v6Stats.UnhandledRouterAdvertisements.Value(), boolToUint64(handleRAsDisabled); got != want { + t.Errorf("got v6Stats.UnhandledRouterAdvertisements.Value() = %d, want = %d", got, want) + } + select { + case e := <-ndpDisp.routerC: + t.Errorf("unexpectedly discovered a router when configured not to: %#v", e) + default: + } + select { + case e := <-ndpDisp.prefixC: + t.Errorf("unexpectedly discovered a prefix when configured not to: %#v", e) + default: + } + select { + case e := <-ndpDisp.autoGenAddrC: + t.Errorf("unexpectedly auto-generated an address when configured not to: %#v", e) + default: + } + }) } }) } } +func boolToUint64(v bool) uint64 { + if v { + return 1 + } + return 0 +} + // Check e to make sure that the event is for addr on nic with ID 1, and the // discovered flag set to discovered. func checkRouterEvent(e ndpRouterEvent, addr tcpip.Address, discovered bool) string { return cmp.Diff(ndpRouterEvent{nicID: 1, addr: addr, discovered: discovered}, e, cmp.AllowUnexported(e)) } +func testWithRAs(t *testing.T, f func(*testing.T, ipv6.HandleRAsConfiguration, bool)) { + tests := [...]struct { + name string + handleRAs ipv6.HandleRAsConfiguration + forwarding bool + }{ + { + name: "Handle RAs when forwarding disabled", + handleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled, + forwarding: false, + }, + { + name: "Always Handle RAs with forwarding disabled", + handleRAs: ipv6.HandlingRAsAlwaysEnabled, + forwarding: false, + }, + { + name: "Always Handle RAs with forwarding enabled", + handleRAs: ipv6.HandlingRAsAlwaysEnabled, + forwarding: true, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + f(t, test.handleRAs, test.forwarding) + }) + } +} + // TestRouterDiscoveryDispatcherNoRemember tests that the stack does not // remember a discovered router when the dispatcher asks it not to. func TestRouterDiscoveryDispatcherNoRemember(t *testing.T) { @@ -1203,7 +1348,7 @@ func TestRouterDiscoveryDispatcherNoRemember(t *testing.T) { s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ NDPConfigs: ipv6.NDPConfigurations{ - HandleRAs: true, + HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled, DiscoverDefaultRouters: true, }, NDPDisp: &ndpDisp, @@ -1237,103 +1382,109 @@ func TestRouterDiscoveryDispatcherNoRemember(t *testing.T) { } func TestRouterDiscovery(t *testing.T) { - ndpDisp := ndpDispatcher{ - routerC: make(chan ndpRouterEvent, 1), - rememberRouter: true, - } - e := channel.New(0, 1280, linkAddr1) - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ - NDPConfigs: ipv6.NDPConfigurations{ - HandleRAs: true, - DiscoverDefaultRouters: true, - }, - NDPDisp: &ndpDisp, - })}, - }) + testWithRAs(t, func(t *testing.T, handleRAs ipv6.HandleRAsConfiguration, forwarding bool) { + ndpDisp := ndpDispatcher{ + routerC: make(chan ndpRouterEvent, 1), + rememberRouter: true, + } + e := channel.New(0, 1280, linkAddr1) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ + NDPConfigs: ipv6.NDPConfigurations{ + HandleRAs: handleRAs, + DiscoverDefaultRouters: true, + }, + NDPDisp: &ndpDisp, + })}, + }) - expectRouterEvent := func(addr tcpip.Address, discovered bool) { - t.Helper() + expectRouterEvent := func(addr tcpip.Address, discovered bool) { + t.Helper() - select { - case e := <-ndpDisp.routerC: - if diff := checkRouterEvent(e, addr, discovered); diff != "" { - t.Errorf("router event mismatch (-want +got):\n%s", diff) + select { + case e := <-ndpDisp.routerC: + if diff := checkRouterEvent(e, addr, discovered); diff != "" { + t.Errorf("router event mismatch (-want +got):\n%s", diff) + } + default: + t.Fatal("expected router discovery event") } - default: - t.Fatal("expected router discovery event") } - } - expectAsyncRouterInvalidationEvent := func(addr tcpip.Address, timeout time.Duration) { - t.Helper() + expectAsyncRouterInvalidationEvent := func(addr tcpip.Address, timeout time.Duration) { + t.Helper() - select { - case e := <-ndpDisp.routerC: - if diff := checkRouterEvent(e, addr, false); diff != "" { - t.Errorf("router event mismatch (-want +got):\n%s", diff) + select { + case e := <-ndpDisp.routerC: + if diff := checkRouterEvent(e, addr, false); diff != "" { + t.Errorf("router event mismatch (-want +got):\n%s", diff) + } + case <-time.After(timeout): + t.Fatal("timed out waiting for router discovery event") } - case <-time.After(timeout): - t.Fatal("timed out waiting for router discovery event") } - } - - if err := s.CreateNIC(1, e); err != nil { - t.Fatalf("CreateNIC(1) = %s", err) - } - // Rx an RA from lladdr2 with zero lifetime. It should not be - // remembered. - e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr2, 0)) - select { - case <-ndpDisp.routerC: - t.Fatal("unexpectedly discovered a router with 0 lifetime") - default: - } - - // Rx an RA from lladdr2 with a huge lifetime. - e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr2, 1000)) - expectRouterEvent(llAddr2, true) + if err := s.SetForwarding(ipv6.ProtocolNumber, forwarding); err != nil { + t.Fatalf("SetForwarding(%d, %t): %s", ipv6.ProtocolNumber, forwarding, err) + } - // Rx an RA from another router (lladdr3) with non-zero lifetime. - const l3LifetimeSeconds = 6 - e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr3, l3LifetimeSeconds)) - expectRouterEvent(llAddr3, true) + if err := s.CreateNIC(1, e); err != nil { + t.Fatalf("CreateNIC(1) = %s", err) + } - // Rx an RA from lladdr2 with lesser lifetime. - const l2LifetimeSeconds = 2 - e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr2, l2LifetimeSeconds)) - select { - case <-ndpDisp.routerC: - t.Fatal("Should not receive a router event when updating lifetimes for known routers") - default: - } + // Rx an RA from lladdr2 with zero lifetime. It should not be + // remembered. + e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr2, 0)) + select { + case <-ndpDisp.routerC: + t.Fatal("unexpectedly discovered a router with 0 lifetime") + default: + } - // Wait for lladdr2's router invalidation job to execute. The lifetime - // of the router should have been updated to the most recent (smaller) - // lifetime. - // - // Wait for the normal lifetime plus an extra bit for the - // router to get invalidated. If we don't get an invalidation - // event after this time, then something is wrong. - expectAsyncRouterInvalidationEvent(llAddr2, l2LifetimeSeconds*time.Second+defaultAsyncPositiveEventTimeout) + // Rx an RA from lladdr2 with a huge lifetime. + e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr2, 1000)) + expectRouterEvent(llAddr2, true) - // Rx an RA from lladdr2 with huge lifetime. - e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr2, 1000)) - expectRouterEvent(llAddr2, true) + // Rx an RA from another router (lladdr3) with non-zero lifetime. + const l3LifetimeSeconds = 6 + e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr3, l3LifetimeSeconds)) + expectRouterEvent(llAddr3, true) - // Rx an RA from lladdr2 with zero lifetime. It should be invalidated. - e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr2, 0)) - expectRouterEvent(llAddr2, false) + // Rx an RA from lladdr2 with lesser lifetime. + const l2LifetimeSeconds = 2 + e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr2, l2LifetimeSeconds)) + select { + case <-ndpDisp.routerC: + t.Fatal("Should not receive a router event when updating lifetimes for known routers") + default: + } - // Wait for lladdr3's router invalidation job to execute. The lifetime - // of the router should have been updated to the most recent (smaller) - // lifetime. - // - // Wait for the normal lifetime plus an extra bit for the - // router to get invalidated. If we don't get an invalidation - // event after this time, then something is wrong. - expectAsyncRouterInvalidationEvent(llAddr3, l3LifetimeSeconds*time.Second+defaultAsyncPositiveEventTimeout) + // Wait for lladdr2's router invalidation job to execute. The lifetime + // of the router should have been updated to the most recent (smaller) + // lifetime. + // + // Wait for the normal lifetime plus an extra bit for the + // router to get invalidated. If we don't get an invalidation + // event after this time, then something is wrong. + expectAsyncRouterInvalidationEvent(llAddr2, l2LifetimeSeconds*time.Second+defaultAsyncPositiveEventTimeout) + + // Rx an RA from lladdr2 with huge lifetime. + e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr2, 1000)) + expectRouterEvent(llAddr2, true) + + // Rx an RA from lladdr2 with zero lifetime. It should be invalidated. + e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr2, 0)) + expectRouterEvent(llAddr2, false) + + // Wait for lladdr3's router invalidation job to execute. The lifetime + // of the router should have been updated to the most recent (smaller) + // lifetime. + // + // Wait for the normal lifetime plus an extra bit for the + // router to get invalidated. If we don't get an invalidation + // event after this time, then something is wrong. + expectAsyncRouterInvalidationEvent(llAddr3, l3LifetimeSeconds*time.Second+defaultAsyncPositiveEventTimeout) + }) } // TestRouterDiscoveryMaxRouters tests that only @@ -1347,7 +1498,7 @@ func TestRouterDiscoveryMaxRouters(t *testing.T) { s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ NDPConfigs: ipv6.NDPConfigurations{ - HandleRAs: true, + HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled, DiscoverDefaultRouters: true, }, NDPDisp: &ndpDisp, @@ -1386,57 +1537,6 @@ func TestRouterDiscoveryMaxRouters(t *testing.T) { } } -// TestNoPrefixDiscovery tests that prefix discovery will not be performed if -// configured not to. -func TestNoPrefixDiscovery(t *testing.T) { - prefix := tcpip.AddressWithPrefix{ - Address: tcpip.Address("\x01\x02\x03\x04\x05\x06\x07\x08\x00\x00\x00\x00\x00\x00\x00\x00"), - PrefixLen: 64, - } - - // Being configured to discover prefixes means handle and - // discover are set to true and forwarding is set to false. - // This tests all possible combinations of the configurations, - // except for the configuration where handle = true, discover = - // true and forwarding = false (the required configuration to do - // prefix discovery) - that will done in other tests. - for i := 0; i < 7; i++ { - handle := i&1 != 0 - discover := i&2 != 0 - forwarding := i&4 == 0 - - t.Run(fmt.Sprintf("HandleRAs(%t), DiscoverOnLinkPrefixes(%t), Forwarding(%t)", handle, discover, forwarding), func(t *testing.T) { - ndpDisp := ndpDispatcher{ - prefixC: make(chan ndpPrefixEvent, 1), - } - e := channel.New(0, 1280, linkAddr1) - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ - NDPConfigs: ipv6.NDPConfigurations{ - HandleRAs: handle, - DiscoverOnLinkPrefixes: discover, - }, - NDPDisp: &ndpDisp, - })}, - }) - s.SetForwarding(ipv6.ProtocolNumber, forwarding) - - if err := s.CreateNIC(1, e); err != nil { - t.Fatalf("CreateNIC(1) = %s", err) - } - - // Rx an RA with prefix with non-zero lifetime. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, false, 10, 0)) - - select { - case <-ndpDisp.prefixC: - t.Fatal("unexpectedly discovered a prefix when configured not to") - default: - } - }) - } -} - // Check e to make sure that the event is for prefix on nic with ID 1, and the // discovered flag set to discovered. func checkPrefixEvent(e ndpPrefixEvent, prefix tcpip.Subnet, discovered bool) string { @@ -1455,8 +1555,7 @@ func TestPrefixDiscoveryDispatcherNoRemember(t *testing.T) { s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ NDPConfigs: ipv6.NDPConfigurations{ - HandleRAs: true, - DiscoverDefaultRouters: false, + HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled, DiscoverOnLinkPrefixes: true, }, NDPDisp: &ndpDisp, @@ -1494,87 +1593,93 @@ func TestPrefixDiscovery(t *testing.T) { prefix2, subnet2, _ := prefixSubnetAddr(1, "") prefix3, subnet3, _ := prefixSubnetAddr(2, "") - ndpDisp := ndpDispatcher{ - prefixC: make(chan ndpPrefixEvent, 1), - rememberPrefix: true, - } - e := channel.New(0, 1280, linkAddr1) - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ - NDPConfigs: ipv6.NDPConfigurations{ - HandleRAs: true, - DiscoverOnLinkPrefixes: true, - }, - NDPDisp: &ndpDisp, - })}, - }) + testWithRAs(t, func(t *testing.T, handleRAs ipv6.HandleRAsConfiguration, forwarding bool) { + ndpDisp := ndpDispatcher{ + prefixC: make(chan ndpPrefixEvent, 1), + rememberPrefix: true, + } + e := channel.New(0, 1280, linkAddr1) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ + NDPConfigs: ipv6.NDPConfigurations{ + HandleRAs: handleRAs, + DiscoverOnLinkPrefixes: true, + }, + NDPDisp: &ndpDisp, + })}, + }) - if err := s.CreateNIC(1, e); err != nil { - t.Fatalf("CreateNIC(1) = %s", err) - } + if err := s.CreateNIC(1, e); err != nil { + t.Fatalf("CreateNIC(1) = %s", err) + } - expectPrefixEvent := func(prefix tcpip.Subnet, discovered bool) { - t.Helper() + expectPrefixEvent := func(prefix tcpip.Subnet, discovered bool) { + t.Helper() - select { - case e := <-ndpDisp.prefixC: - if diff := checkPrefixEvent(e, prefix, discovered); diff != "" { - t.Errorf("prefix event mismatch (-want +got):\n%s", diff) + select { + case e := <-ndpDisp.prefixC: + if diff := checkPrefixEvent(e, prefix, discovered); diff != "" { + t.Errorf("prefix event mismatch (-want +got):\n%s", diff) + } + default: + t.Fatal("expected prefix discovery event") } - default: - t.Fatal("expected prefix discovery event") } - } - // Receive an RA with prefix1 in an NDP Prefix Information option (PI) - // with zero valid lifetime. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, false, 0, 0)) - select { - case <-ndpDisp.prefixC: - t.Fatal("unexpectedly discovered a prefix with 0 lifetime") - default: - } + if err := s.SetForwarding(ipv6.ProtocolNumber, forwarding); err != nil { + t.Fatalf("SetForwarding(%d, %t): %s", ipv6.ProtocolNumber, forwarding, err) + } - // Receive an RA with prefix1 in an NDP Prefix Information option (PI) - // with non-zero lifetime. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, false, 100, 0)) - expectPrefixEvent(subnet1, true) + // Receive an RA with prefix1 in an NDP Prefix Information option (PI) + // with zero valid lifetime. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, false, 0, 0)) + select { + case <-ndpDisp.prefixC: + t.Fatal("unexpectedly discovered a prefix with 0 lifetime") + default: + } - // Receive an RA with prefix2 in a PI. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, false, 100, 0)) - expectPrefixEvent(subnet2, true) + // Receive an RA with prefix1 in an NDP Prefix Information option (PI) + // with non-zero lifetime. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, false, 100, 0)) + expectPrefixEvent(subnet1, true) - // Receive an RA with prefix3 in a PI. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix3, true, false, 100, 0)) - expectPrefixEvent(subnet3, true) + // Receive an RA with prefix2 in a PI. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, false, 100, 0)) + expectPrefixEvent(subnet2, true) - // Receive an RA with prefix1 in a PI with lifetime = 0. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, false, 0, 0)) - expectPrefixEvent(subnet1, false) + // Receive an RA with prefix3 in a PI. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix3, true, false, 100, 0)) + expectPrefixEvent(subnet3, true) - // Receive an RA with prefix2 in a PI with lesser lifetime. - lifetime := uint32(2) - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, false, lifetime, 0)) - select { - case <-ndpDisp.prefixC: - t.Fatal("unexpectedly received prefix event when updating lifetime") - default: - } + // Receive an RA with prefix1 in a PI with lifetime = 0. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, false, 0, 0)) + expectPrefixEvent(subnet1, false) - // Wait for prefix2's most recent invalidation job plus some buffer to - // expire. - select { - case e := <-ndpDisp.prefixC: - if diff := checkPrefixEvent(e, subnet2, false); diff != "" { - t.Errorf("prefix event mismatch (-want +got):\n%s", diff) + // Receive an RA with prefix2 in a PI with lesser lifetime. + lifetime := uint32(2) + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, false, lifetime, 0)) + select { + case <-ndpDisp.prefixC: + t.Fatal("unexpectedly received prefix event when updating lifetime") + default: } - case <-time.After(time.Duration(lifetime)*time.Second + defaultAsyncPositiveEventTimeout): - t.Fatal("timed out waiting for prefix discovery event") - } - // Receive RA to invalidate prefix3. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix3, true, false, 0, 0)) - expectPrefixEvent(subnet3, false) + // Wait for prefix2's most recent invalidation job plus some buffer to + // expire. + select { + case e := <-ndpDisp.prefixC: + if diff := checkPrefixEvent(e, subnet2, false); diff != "" { + t.Errorf("prefix event mismatch (-want +got):\n%s", diff) + } + case <-time.After(time.Duration(lifetime)*time.Second + defaultAsyncPositiveEventTimeout): + t.Fatal("timed out waiting for prefix discovery event") + } + + // Receive RA to invalidate prefix3. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix3, true, false, 0, 0)) + expectPrefixEvent(subnet3, false) + }) } func TestPrefixDiscoveryWithInfiniteLifetime(t *testing.T) { @@ -1590,7 +1695,7 @@ func TestPrefixDiscoveryWithInfiniteLifetime(t *testing.T) { }() prefix := tcpip.AddressWithPrefix{ - Address: tcpip.Address("\x01\x02\x03\x04\x05\x06\x07\x08\x00\x00\x00\x00\x00\x00\x00\x00"), + Address: testutil.MustParse6("102:304:506:708::"), PrefixLen: 64, } subnet := prefix.Subnet() @@ -1603,7 +1708,7 @@ func TestPrefixDiscoveryWithInfiniteLifetime(t *testing.T) { s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ NDPConfigs: ipv6.NDPConfigurations{ - HandleRAs: true, + HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled, DiscoverOnLinkPrefixes: true, }, NDPDisp: &ndpDisp, @@ -1688,7 +1793,7 @@ func TestPrefixDiscoveryMaxOnLinkPrefixes(t *testing.T) { s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ NDPConfigs: ipv6.NDPConfigurations{ - HandleRAs: true, + HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled, DiscoverDefaultRouters: false, DiscoverOnLinkPrefixes: true, }, @@ -1753,53 +1858,6 @@ func containsV6Addr(list []tcpip.ProtocolAddress, item tcpip.AddressWithPrefix) return containsAddr(list, protocolAddress) } -// TestNoAutoGenAddr tests that SLAAC is not performed when configured not to. -func TestNoAutoGenAddr(t *testing.T) { - prefix, _, _ := prefixSubnetAddr(0, "") - - // Being configured to auto-generate addresses means handle and - // autogen are set to true and forwarding is set to false. - // This tests all possible combinations of the configurations, - // except for the configuration where handle = true, autogen = - // true and forwarding = false (the required configuration to do - // SLAAC) - that will done in other tests. - for i := 0; i < 7; i++ { - handle := i&1 != 0 - autogen := i&2 != 0 - forwarding := i&4 == 0 - - t.Run(fmt.Sprintf("HandleRAs(%t), AutoGenAddr(%t), Forwarding(%t)", handle, autogen, forwarding), func(t *testing.T) { - ndpDisp := ndpDispatcher{ - autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1), - } - e := channel.New(0, 1280, linkAddr1) - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ - NDPConfigs: ipv6.NDPConfigurations{ - HandleRAs: handle, - AutoGenGlobalAddresses: autogen, - }, - NDPDisp: &ndpDisp, - })}, - }) - s.SetForwarding(ipv6.ProtocolNumber, forwarding) - - if err := s.CreateNIC(1, e); err != nil { - t.Fatalf("CreateNIC(1) = %s", err) - } - - // Rx an RA with prefix with non-zero lifetime. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, false, true, 10, 0)) - - select { - case <-ndpDisp.autoGenAddrC: - t.Fatal("unexpectedly auto-generated an address when configured not to") - default: - } - }) - } -} - // Check e to make sure that the event is for addr on nic with ID 1, and the // event type is set to eventType. func checkAutoGenAddrEvent(e ndpAutoGenAddrEvent, addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) string { @@ -1808,7 +1866,7 @@ func checkAutoGenAddrEvent(e ndpAutoGenAddrEvent, addr tcpip.AddressWithPrefix, // TestAutoGenAddr tests that an address is properly generated and invalidated // when configured to do so. -func TestAutoGenAddr2(t *testing.T) { +func TestAutoGenAddr(t *testing.T) { const newMinVL = 2 newMinVLDuration := newMinVL * time.Second saved := ipv6.MinPrefixInformationValidLifetimeForUpdate @@ -1820,96 +1878,102 @@ func TestAutoGenAddr2(t *testing.T) { prefix1, _, addr1 := prefixSubnetAddr(0, linkAddr1) prefix2, _, addr2 := prefixSubnetAddr(1, linkAddr1) - ndpDisp := ndpDispatcher{ - autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1), - } - e := channel.New(0, 1280, linkAddr1) - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ - NDPConfigs: ipv6.NDPConfigurations{ - HandleRAs: true, - AutoGenGlobalAddresses: true, - }, - NDPDisp: &ndpDisp, - })}, - }) + testWithRAs(t, func(t *testing.T, handleRAs ipv6.HandleRAsConfiguration, forwarding bool) { + ndpDisp := ndpDispatcher{ + autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1), + } + e := channel.New(0, 1280, linkAddr1) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ + NDPConfigs: ipv6.NDPConfigurations{ + HandleRAs: handleRAs, + AutoGenGlobalAddresses: true, + }, + NDPDisp: &ndpDisp, + })}, + }) - if err := s.CreateNIC(1, e); err != nil { - t.Fatalf("CreateNIC(1) = %s", err) - } + if err := s.SetForwarding(ipv6.ProtocolNumber, forwarding); err != nil { + t.Fatalf("SetForwarding(%d, %t): %s", ipv6.ProtocolNumber, forwarding, err) + } - expectAutoGenAddrEvent := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) { - t.Helper() + if err := s.CreateNIC(1, e); err != nil { + t.Fatalf("CreateNIC(1) = %s", err) + } - select { - case e := <-ndpDisp.autoGenAddrC: - if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" { - t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) + expectAutoGenAddrEvent := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) { + t.Helper() + + select { + case e := <-ndpDisp.autoGenAddrC: + if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" { + t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) + } + default: + t.Fatal("expected addr auto gen event") } - default: - t.Fatal("expected addr auto gen event") } - } - // Receive an RA with prefix1 in an NDP Prefix Information option (PI) - // with zero valid lifetime. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 0, 0)) - select { - case <-ndpDisp.autoGenAddrC: - t.Fatal("unexpectedly auto-generated an address with 0 lifetime") - default: - } + // Receive an RA with prefix1 in an NDP Prefix Information option (PI) + // with zero valid lifetime. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 0, 0)) + select { + case <-ndpDisp.autoGenAddrC: + t.Fatal("unexpectedly auto-generated an address with 0 lifetime") + default: + } - // Receive an RA with prefix1 in an NDP Prefix Information option (PI) - // with non-zero lifetime. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 100, 0)) - expectAutoGenAddrEvent(addr1, newAddr) - if !containsV6Addr(s.NICInfo()[1].ProtocolAddresses, addr1) { - t.Fatalf("Should have %s in the list of addresses", addr1) - } + // Receive an RA with prefix1 in an NDP Prefix Information option (PI) + // with non-zero lifetime. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 100, 0)) + expectAutoGenAddrEvent(addr1, newAddr) + if !containsV6Addr(s.NICInfo()[1].ProtocolAddresses, addr1) { + t.Fatalf("Should have %s in the list of addresses", addr1) + } - // Receive an RA with prefix2 in an NDP Prefix Information option (PI) - // with preferred lifetime > valid lifetime - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 5, 6)) - select { - case <-ndpDisp.autoGenAddrC: - t.Fatal("unexpectedly auto-generated an address with preferred lifetime > valid lifetime") - default: - } + // Receive an RA with prefix2 in an NDP Prefix Information option (PI) + // with preferred lifetime > valid lifetime + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 5, 6)) + select { + case <-ndpDisp.autoGenAddrC: + t.Fatal("unexpectedly auto-generated an address with preferred lifetime > valid lifetime") + default: + } - // Receive an RA with prefix2 in a PI. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, 100, 0)) - expectAutoGenAddrEvent(addr2, newAddr) - if !containsV6Addr(s.NICInfo()[1].ProtocolAddresses, addr1) { - t.Fatalf("Should have %s in the list of addresses", addr1) - } - if !containsV6Addr(s.NICInfo()[1].ProtocolAddresses, addr2) { - t.Fatalf("Should have %s in the list of addresses", addr2) - } + // Receive an RA with prefix2 in a PI. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, 100, 0)) + expectAutoGenAddrEvent(addr2, newAddr) + if !containsV6Addr(s.NICInfo()[1].ProtocolAddresses, addr1) { + t.Fatalf("Should have %s in the list of addresses", addr1) + } + if !containsV6Addr(s.NICInfo()[1].ProtocolAddresses, addr2) { + t.Fatalf("Should have %s in the list of addresses", addr2) + } - // Refresh valid lifetime for addr of prefix1. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, newMinVL, 0)) - select { - case <-ndpDisp.autoGenAddrC: - t.Fatal("unexpectedly auto-generated an address when we already have an address for a prefix") - default: - } + // Refresh valid lifetime for addr of prefix1. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, newMinVL, 0)) + select { + case <-ndpDisp.autoGenAddrC: + t.Fatal("unexpectedly auto-generated an address when we already have an address for a prefix") + default: + } - // Wait for addr of prefix1 to be invalidated. - select { - case e := <-ndpDisp.autoGenAddrC: - if diff := checkAutoGenAddrEvent(e, addr1, invalidatedAddr); diff != "" { - t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) + // Wait for addr of prefix1 to be invalidated. + select { + case e := <-ndpDisp.autoGenAddrC: + if diff := checkAutoGenAddrEvent(e, addr1, invalidatedAddr); diff != "" { + t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) + } + case <-time.After(newMinVLDuration + defaultAsyncPositiveEventTimeout): + t.Fatal("timed out waiting for addr auto gen event") } - case <-time.After(newMinVLDuration + defaultAsyncPositiveEventTimeout): - t.Fatal("timed out waiting for addr auto gen event") - } - if containsV6Addr(s.NICInfo()[1].ProtocolAddresses, addr1) { - t.Fatalf("Should not have %s in the list of addresses", addr1) - } - if !containsV6Addr(s.NICInfo()[1].ProtocolAddresses, addr2) { - t.Fatalf("Should have %s in the list of addresses", addr2) - } + if containsV6Addr(s.NICInfo()[1].ProtocolAddresses, addr1) { + t.Fatalf("Should not have %s in the list of addresses", addr1) + } + if !containsV6Addr(s.NICInfo()[1].ProtocolAddresses, addr2) { + t.Fatalf("Should have %s in the list of addresses", addr2) + } + }) } func addressCheck(addrs []tcpip.ProtocolAddress, containList, notContainList []tcpip.AddressWithPrefix) string { @@ -1997,7 +2061,7 @@ func TestAutoGenTempAddr(t *testing.T) { RetransmitTimer: test.retransmitTimer, }, NDPConfigs: ipv6.NDPConfigurations{ - HandleRAs: true, + HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled, AutoGenGlobalAddresses: true, AutoGenTempGlobalAddresses: true, }, @@ -2298,7 +2362,7 @@ func TestNoAutoGenTempAddrWithoutStableAddr(t *testing.T) { RetransmitTimer: retransmitTimer, }, NDPConfigs: ipv6.NDPConfigurations{ - HandleRAs: true, + HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled, AutoGenGlobalAddresses: true, AutoGenTempGlobalAddresses: true, }, @@ -2385,7 +2449,7 @@ func TestAutoGenTempAddrRegen(t *testing.T) { } e := channel.New(0, 1280, linkAddr1) ndpConfigs := ipv6.NDPConfigurations{ - HandleRAs: true, + HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled, AutoGenGlobalAddresses: true, AutoGenTempGlobalAddresses: true, RegenAdvanceDuration: newMinVLDuration - regenAfter, @@ -2534,7 +2598,7 @@ func TestAutoGenTempAddrRegenJobUpdates(t *testing.T) { } e := channel.New(0, 1280, linkAddr1) ndpConfigs := ipv6.NDPConfigurations{ - HandleRAs: true, + HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled, AutoGenGlobalAddresses: true, AutoGenTempGlobalAddresses: true, RegenAdvanceDuration: newMinVLDuration - regenAfter, @@ -2735,7 +2799,7 @@ func TestMixedSLAACAddrConflictRegen(t *testing.T) { Clock: clock, NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ NDPConfigs: ipv6.NDPConfigurations{ - HandleRAs: true, + HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled, AutoGenGlobalAddresses: true, AutoGenTempGlobalAddresses: test.tempAddrs, AutoGenAddressConflictRetries: 1, @@ -2880,7 +2944,7 @@ func stackAndNdpDispatcherWithDefaultRoute(t *testing.T, nicID tcpip.NICID) (*nd s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ NDPConfigs: ipv6.NDPConfigurations{ - HandleRAs: true, + HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled, AutoGenGlobalAddresses: true, }, NDPDisp: ndpDisp, @@ -3347,7 +3411,7 @@ func TestAutoGenAddrFiniteToInfiniteToFiniteVL(t *testing.T) { s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ NDPConfigs: ipv6.NDPConfigurations{ - HandleRAs: true, + HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled, AutoGenGlobalAddresses: true, }, NDPDisp: &ndpDisp, @@ -3490,7 +3554,7 @@ func TestAutoGenAddrValidLifetimeUpdates(t *testing.T) { s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ NDPConfigs: ipv6.NDPConfigurations{ - HandleRAs: true, + HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled, AutoGenGlobalAddresses: true, }, NDPDisp: &ndpDisp, @@ -3557,7 +3621,7 @@ func TestAutoGenAddrRemoval(t *testing.T) { s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ NDPConfigs: ipv6.NDPConfigurations{ - HandleRAs: true, + HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled, AutoGenGlobalAddresses: true, }, NDPDisp: &ndpDisp, @@ -3723,7 +3787,7 @@ func TestAutoGenAddrStaticConflict(t *testing.T) { s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ NDPConfigs: ipv6.NDPConfigurations{ - HandleRAs: true, + HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled, AutoGenGlobalAddresses: true, }, NDPDisp: &ndpDisp, @@ -3805,7 +3869,7 @@ func TestAutoGenAddrWithOpaqueIID(t *testing.T) { s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ NDPConfigs: ipv6.NDPConfigurations{ - HandleRAs: true, + HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled, AutoGenGlobalAddresses: true, }, NDPDisp: &ndpDisp, @@ -3969,7 +4033,7 @@ func TestAutoGenAddrInResponseToDADConflicts(t *testing.T) { { name: "Global address", ndpConfigs: ipv6.NDPConfigurations{ - HandleRAs: true, + HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled, AutoGenGlobalAddresses: true, }, prepareFn: func(_ *testing.T, _ *ndpDispatcher, e *channel.Endpoint, _ []byte) []tcpip.AddressWithPrefix { @@ -3996,7 +4060,7 @@ func TestAutoGenAddrInResponseToDADConflicts(t *testing.T) { { name: "Temporary address", ndpConfigs: ipv6.NDPConfigurations{ - HandleRAs: true, + HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled, AutoGenGlobalAddresses: true, AutoGenTempGlobalAddresses: true, }, @@ -4146,7 +4210,7 @@ func TestAutoGenAddrWithEUI64IIDNoDADRetries(t *testing.T) { { name: "Global address", ndpConfigs: ipv6.NDPConfigurations{ - HandleRAs: true, + HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled, AutoGenGlobalAddresses: true, AutoGenAddressConflictRetries: maxRetries, }, @@ -4274,7 +4338,7 @@ func TestAutoGenAddrContinuesLifetimesAfterRetry(t *testing.T) { RetransmitTimer: retransmitTimer, }, NDPConfigs: ipv6.NDPConfigurations{ - HandleRAs: true, + HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled, AutoGenGlobalAddresses: true, AutoGenAddressConflictRetries: maxRetries, }, @@ -4480,7 +4544,7 @@ func TestNDPRecursiveDNSServerDispatch(t *testing.T) { s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ NDPConfigs: ipv6.NDPConfigurations{ - HandleRAs: true, + HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled, }, NDPDisp: &ndpDisp, })}, @@ -4531,7 +4595,7 @@ func TestNDPDNSSearchListDispatch(t *testing.T) { s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ NDPConfigs: ipv6.NDPConfigurations{ - HandleRAs: true, + HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled, }, NDPDisp: &ndpDisp, })}, @@ -4625,8 +4689,110 @@ func TestNDPDNSSearchListDispatch(t *testing.T) { } } -// TestCleanupNDPState tests that all discovered routers and prefixes, and -// auto-generated addresses are invalidated when a NIC becomes a router. +func TestNoCleanupNDPStateWhenForwardingEnabled(t *testing.T) { + const ( + lifetimeSeconds = 999 + nicID = 1 + ) + + ndpDisp := ndpDispatcher{ + routerC: make(chan ndpRouterEvent, 1), + rememberRouter: true, + prefixC: make(chan ndpPrefixEvent, 1), + rememberPrefix: true, + autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1), + } + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ + AutoGenLinkLocal: true, + NDPConfigs: ipv6.NDPConfigurations{ + HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled, + DiscoverDefaultRouters: true, + DiscoverOnLinkPrefixes: true, + AutoGenGlobalAddresses: true, + }, + NDPDisp: &ndpDisp, + })}, + }) + + e1 := channel.New(0, header.IPv6MinimumMTU, linkAddr1) + if err := s.CreateNIC(nicID, e1); err != nil { + t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) + } + llAddr := tcpip.AddressWithPrefix{Address: llAddr1, PrefixLen: header.IPv6LinkLocalPrefix.PrefixLen} + select { + case e := <-ndpDisp.autoGenAddrC: + if diff := checkAutoGenAddrEvent(e, llAddr, newAddr); diff != "" { + t.Errorf("auto-gen addr mismatch (-want +got):\n%s", diff) + } + default: + t.Errorf("expected auto-gen addr event for %s on NIC(%d)", llAddr, nicID) + } + + prefix, subnet, addr := prefixSubnetAddr(0, linkAddr1) + e1.InjectInbound( + header.IPv6ProtocolNumber, + raBufWithPI( + llAddr3, + lifetimeSeconds, + prefix, + true, /* onLink */ + true, /* auto */ + lifetimeSeconds, + lifetimeSeconds, + ), + ) + select { + case e := <-ndpDisp.routerC: + if diff := checkRouterEvent(e, llAddr3, true /* discovered */); diff != "" { + t.Errorf("router event mismatch (-want +got):\n%s", diff) + } + default: + t.Errorf("expected router event for %s on NIC(%d)", llAddr3, nicID) + } + select { + case e := <-ndpDisp.prefixC: + if diff := checkPrefixEvent(e, subnet, true /* discovered */); diff != "" { + t.Errorf("router event mismatch (-want +got):\n%s", diff) + } + default: + t.Errorf("expected prefix event for %s on NIC(%d)", prefix, nicID) + } + select { + case e := <-ndpDisp.autoGenAddrC: + if diff := checkAutoGenAddrEvent(e, addr, newAddr); diff != "" { + t.Errorf("auto-gen addr mismatch (-want +got):\n%s", diff) + } + default: + t.Errorf("expected auto-gen addr event for %s on NIC(%d)", addr, nicID) + } + + // Enabling or disabling forwarding should not invalidate discovered prefixes + // or routers, or auto-generated address. + for _, forwarding := range [...]bool{true, false} { + t.Run(fmt.Sprintf("Transition forwarding to %t", forwarding), func(t *testing.T) { + if err := s.SetForwarding(ipv6.ProtocolNumber, forwarding); err != nil { + t.Fatalf("SetForwarding(%d, %t): %s", ipv6.ProtocolNumber, forwarding, err) + } + select { + case e := <-ndpDisp.routerC: + t.Errorf("unexpected router event = %#v", e) + default: + } + select { + case e := <-ndpDisp.prefixC: + t.Errorf("unexpected prefix event = %#v", e) + default: + } + select { + case e := <-ndpDisp.autoGenAddrC: + t.Errorf("unexpected auto-gen addr event = %#v", e) + default: + } + }) + } +} + func TestCleanupNDPState(t *testing.T) { const ( lifetimeSeconds = 5 @@ -4655,18 +4821,6 @@ func TestCleanupNDPState(t *testing.T) { maxAutoGenAddrEvents int skipFinalAddrCheck bool }{ - // A NIC should still keep its auto-generated link-local address when - // becoming a router. - { - name: "Enable forwarding", - cleanupFn: func(t *testing.T, s *stack.Stack) { - t.Helper() - s.SetForwarding(ipv6.ProtocolNumber, true) - }, - keepAutoGenLinkLocal: true, - maxAutoGenAddrEvents: 4, - }, - // A NIC should cleanup all NDP state when it is disabled. { name: "Disable NIC", @@ -4718,7 +4872,7 @@ func TestCleanupNDPState(t *testing.T) { NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ AutoGenLinkLocal: true, NDPConfigs: ipv6.NDPConfigurations{ - HandleRAs: true, + HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled, DiscoverDefaultRouters: true, DiscoverOnLinkPrefixes: true, AutoGenGlobalAddresses: true, @@ -4991,7 +5145,7 @@ func TestDHCPv6ConfigurationFromNDPDA(t *testing.T) { s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ NDPConfigs: ipv6.NDPConfigurations{ - HandleRAs: true, + HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled, }, NDPDisp: &ndpDisp, })}, @@ -5182,96 +5336,127 @@ func TestRouterSolicitation(t *testing.T) { }, } + subTests := []struct { + name string + handleRAs ipv6.HandleRAsConfiguration + afterFirstRS func(*testing.T, *stack.Stack) + }{ + { + name: "Handle RAs when forwarding disabled", + handleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled, + afterFirstRS: func(*testing.T, *stack.Stack) {}, + }, + + // Enabling forwarding when RAs are always configured to be handled + // should not stop router solicitations. + { + name: "Handle RAs always", + handleRAs: ipv6.HandlingRAsAlwaysEnabled, + afterFirstRS: func(t *testing.T, s *stack.Stack) { + if err := s.SetForwarding(ipv6.ProtocolNumber, true); err != nil { + t.Fatalf("SetForwarding(%d, true): %s", ipv6.ProtocolNumber, err) + } + }, + }, + } + for _, test := range tests { t.Run(test.name, func(t *testing.T) { - clock := faketime.NewManualClock() - e := channelLinkWithHeaderLength{ - Endpoint: channel.New(int(test.maxRtrSolicit), 1280, test.linkAddr), - headerLength: test.linkHeaderLen, - } - e.Endpoint.LinkEPCapabilities |= stack.CapabilityResolutionRequired - waitForPkt := func(timeout time.Duration) { - t.Helper() + for _, subTest := range subTests { + t.Run(subTest.name, func(t *testing.T) { + clock := faketime.NewManualClock() + e := channelLinkWithHeaderLength{ + Endpoint: channel.New(int(test.maxRtrSolicit), 1280, test.linkAddr), + headerLength: test.linkHeaderLen, + } + e.Endpoint.LinkEPCapabilities |= stack.CapabilityResolutionRequired + waitForPkt := func(timeout time.Duration) { + t.Helper() + + clock.Advance(timeout) + p, ok := e.Read() + if !ok { + t.Fatal("expected router solicitation packet") + } - clock.Advance(timeout) - p, ok := e.Read() - if !ok { - t.Fatal("expected router solicitation packet") - } + if p.Proto != header.IPv6ProtocolNumber { + t.Fatalf("got Proto = %d, want = %d", p.Proto, header.IPv6ProtocolNumber) + } - if p.Proto != header.IPv6ProtocolNumber { - t.Fatalf("got Proto = %d, want = %d", p.Proto, header.IPv6ProtocolNumber) - } + // Make sure the right remote link address is used. + if want := header.EthernetAddressFromMulticastIPv6Address(header.IPv6AllRoutersLinkLocalMulticastAddress); p.Route.RemoteLinkAddress != want { + t.Errorf("got remote link address = %s, want = %s", p.Route.RemoteLinkAddress, want) + } - // Make sure the right remote link address is used. - if want := header.EthernetAddressFromMulticastIPv6Address(header.IPv6AllRoutersMulticastAddress); p.Route.RemoteLinkAddress != want { - t.Errorf("got remote link address = %s, want = %s", p.Route.RemoteLinkAddress, want) - } + checker.IPv6(t, stack.PayloadSince(p.Pkt.NetworkHeader()), + checker.SrcAddr(test.expectedSrcAddr), + checker.DstAddr(header.IPv6AllRoutersLinkLocalMulticastAddress), + checker.TTL(header.NDPHopLimit), + checker.NDPRS(checker.NDPRSOptions(test.expectedNDPOpts)), + ) - checker.IPv6(t, stack.PayloadSince(p.Pkt.NetworkHeader()), - checker.SrcAddr(test.expectedSrcAddr), - checker.DstAddr(header.IPv6AllRoutersMulticastAddress), - checker.TTL(header.NDPHopLimit), - checker.NDPRS(checker.NDPRSOptions(test.expectedNDPOpts)), - ) + if l, want := p.Pkt.AvailableHeaderBytes(), int(test.linkHeaderLen); l != want { + t.Errorf("got p.Pkt.AvailableHeaderBytes() = %d; want = %d", l, want) + } + } + waitForNothing := func(timeout time.Duration) { + t.Helper() - if l, want := p.Pkt.AvailableHeaderBytes(), int(test.linkHeaderLen); l != want { - t.Errorf("got p.Pkt.AvailableHeaderBytes() = %d; want = %d", l, want) - } - } - waitForNothing := func(timeout time.Duration) { - t.Helper() + clock.Advance(timeout) + if p, ok := e.Read(); ok { + t.Fatalf("unexpectedly got a packet = %#v", p) + } + } + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ + NDPConfigs: ipv6.NDPConfigurations{ + HandleRAs: subTest.handleRAs, + MaxRtrSolicitations: test.maxRtrSolicit, + RtrSolicitationInterval: test.rtrSolicitInt, + MaxRtrSolicitationDelay: test.maxRtrSolicitDelay, + }, + })}, + Clock: clock, + }) + if err := s.CreateNIC(nicID, &e); err != nil { + t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) + } - clock.Advance(timeout) - if p, ok := e.Read(); ok { - t.Fatalf("unexpectedly got a packet = %#v", p) - } - } - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ - NDPConfigs: ipv6.NDPConfigurations{ - MaxRtrSolicitations: test.maxRtrSolicit, - RtrSolicitationInterval: test.rtrSolicitInt, - MaxRtrSolicitationDelay: test.maxRtrSolicitDelay, - }, - })}, - Clock: clock, - }) - if err := s.CreateNIC(nicID, &e); err != nil { - t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) - } + if addr := test.nicAddr; addr != "" { + if err := s.AddAddress(nicID, header.IPv6ProtocolNumber, addr); err != nil { + t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, addr, err) + } + } - if addr := test.nicAddr; addr != "" { - if err := s.AddAddress(nicID, header.IPv6ProtocolNumber, addr); err != nil { - t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, addr, err) - } - } + // Make sure each RS is sent at the right time. + remaining := test.maxRtrSolicit + if remaining > 0 { + waitForPkt(test.effectiveMaxRtrSolicitDelay) + remaining-- + } - // Make sure each RS is sent at the right time. - remaining := test.maxRtrSolicit - if remaining > 0 { - waitForPkt(test.effectiveMaxRtrSolicitDelay) - remaining-- - } + subTest.afterFirstRS(t, s) - for ; remaining > 0; remaining-- { - if test.effectiveRtrSolicitInt > defaultAsyncPositiveEventTimeout { - waitForNothing(test.effectiveRtrSolicitInt - time.Nanosecond) - waitForPkt(time.Nanosecond) - } else { - waitForPkt(test.effectiveRtrSolicitInt) - } - } + for ; remaining > 0; remaining-- { + if test.effectiveRtrSolicitInt > defaultAsyncPositiveEventTimeout { + waitForNothing(test.effectiveRtrSolicitInt - time.Nanosecond) + waitForPkt(time.Nanosecond) + } else { + waitForPkt(test.effectiveRtrSolicitInt) + } + } - // Make sure no more RS. - if test.effectiveRtrSolicitInt > test.effectiveMaxRtrSolicitDelay { - waitForNothing(test.effectiveRtrSolicitInt) - } else { - waitForNothing(test.effectiveMaxRtrSolicitDelay) - } + // Make sure no more RS. + if test.effectiveRtrSolicitInt > test.effectiveMaxRtrSolicitDelay { + waitForNothing(test.effectiveRtrSolicitInt) + } else { + waitForNothing(test.effectiveMaxRtrSolicitDelay) + } - if got, want := s.Stats().ICMP.V6.PacketsSent.RouterSolicit.Value(), uint64(test.maxRtrSolicit); got != want { - t.Fatalf("got sent RouterSolicit = %d, want = %d", got, want) + if got, want := s.Stats().ICMP.V6.PacketsSent.RouterSolicit.Value(), uint64(test.maxRtrSolicit); got != want { + t.Fatalf("got sent RouterSolicit = %d, want = %d", got, want) + } + }) } }) } @@ -5362,13 +5547,14 @@ func TestStopStartSolicitingRouters(t *testing.T) { } checker.IPv6(t, stack.PayloadSince(p.Pkt.NetworkHeader()), checker.SrcAddr(header.IPv6Any), - checker.DstAddr(header.IPv6AllRoutersMulticastAddress), + checker.DstAddr(header.IPv6AllRoutersLinkLocalMulticastAddress), checker.TTL(header.NDPHopLimit), checker.NDPRS()) } s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ NDPConfigs: ipv6.NDPConfigurations{ + HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled, MaxRtrSolicitations: maxRtrSolicitations, RtrSolicitationInterval: interval, MaxRtrSolicitationDelay: delay, diff --git a/pkg/tcpip/stack/neighbor_cache_test.go b/pkg/tcpip/stack/neighbor_cache_test.go index 48bb75e2f..9821a18d3 100644 --- a/pkg/tcpip/stack/neighbor_cache_test.go +++ b/pkg/tcpip/stack/neighbor_cache_test.go @@ -1556,7 +1556,7 @@ func TestNeighborCacheRetryResolution(t *testing.T) { func BenchmarkCacheClear(b *testing.B) { b.StopTimer() config := DefaultNUDConfigurations() - clock := &tcpip.StdClock{} + clock := tcpip.NewStdClock() linkRes := newTestNeighborResolver(nil, config, clock) linkRes.delay = 0 diff --git a/pkg/tcpip/stack/neighbor_entry_test.go b/pkg/tcpip/stack/neighbor_entry_test.go index bb2b2d705..1d39ee73d 100644 --- a/pkg/tcpip/stack/neighbor_entry_test.go +++ b/pkg/tcpip/stack/neighbor_entry_test.go @@ -26,14 +26,13 @@ import ( "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/faketime" "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/testutil" ) const ( entryTestNetNumber tcpip.NetworkProtocolNumber = math.MaxUint32 entryTestNICID tcpip.NICID = 1 - entryTestAddr1 = tcpip.Address("\x00\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01") - entryTestAddr2 = tcpip.Address("\x00\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02") entryTestLinkAddr1 = tcpip.LinkAddress("\x0a\x00\x00\x00\x00\x01") entryTestLinkAddr2 = tcpip.LinkAddress("\x0a\x00\x00\x00\x00\x02") @@ -44,6 +43,11 @@ const ( entryTestNetDefaultMTU = 65536 ) +var ( + entryTestAddr1 = testutil.MustParse6("a::1") + entryTestAddr2 = testutil.MustParse6("a::2") +) + // runImmediatelyScheduledJobs runs all jobs scheduled to run at the current // time. func runImmediatelyScheduledJobs(clock *faketime.ManualClock) { diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go index ca15c0691..8d615500f 100644 --- a/pkg/tcpip/stack/nic.go +++ b/pkg/tcpip/stack/nic.go @@ -316,30 +316,30 @@ func (n *nic) IsLoopback() bool { } // WritePacket implements NetworkLinkEndpoint. -func (n *nic) WritePacket(r *Route, gso *GSO, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) tcpip.Error { - _, err := n.enqueuePacketBuffer(r, gso, protocol, pkt) +func (n *nic) WritePacket(r *Route, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) tcpip.Error { + _, err := n.enqueuePacketBuffer(r, protocol, pkt) return err } -func (n *nic) writePacketBuffer(r RouteInfo, gso *GSO, protocol tcpip.NetworkProtocolNumber, pkt pendingPacketBuffer) (int, tcpip.Error) { +func (n *nic) writePacketBuffer(r RouteInfo, protocol tcpip.NetworkProtocolNumber, pkt pendingPacketBuffer) (int, tcpip.Error) { switch pkt := pkt.(type) { case *PacketBuffer: - if err := n.writePacket(r, gso, protocol, pkt); err != nil { + if err := n.writePacket(r, protocol, pkt); err != nil { return 0, err } return 1, nil case *PacketBufferList: - return n.writePackets(r, gso, protocol, *pkt) + return n.writePackets(r, protocol, *pkt) default: panic(fmt.Sprintf("unrecognized pending packet buffer type = %T", pkt)) } } -func (n *nic) enqueuePacketBuffer(r *Route, gso *GSO, protocol tcpip.NetworkProtocolNumber, pkt pendingPacketBuffer) (int, tcpip.Error) { +func (n *nic) enqueuePacketBuffer(r *Route, protocol tcpip.NetworkProtocolNumber, pkt pendingPacketBuffer) (int, tcpip.Error) { routeInfo, _, err := r.resolvedFields(nil) switch err.(type) { case nil: - return n.writePacketBuffer(routeInfo, gso, protocol, pkt) + return n.writePacketBuffer(routeInfo, protocol, pkt) case *tcpip.ErrWouldBlock: // As per relevant RFCs, we should queue packets while we wait for link // resolution to complete. @@ -358,28 +358,27 @@ func (n *nic) enqueuePacketBuffer(r *Route, gso *GSO, protocol tcpip.NetworkProt // SHOULD be limited to some small value. When a queue overflows, the new // arrival SHOULD replace the oldest entry. Once address resolution // completes, the node transmits any queued packets. - return n.linkResQueue.enqueue(r, gso, protocol, pkt) + return n.linkResQueue.enqueue(r, protocol, pkt) default: return 0, err } } // WritePacketToRemote implements NetworkInterface. -func (n *nic) WritePacketToRemote(remoteLinkAddr tcpip.LinkAddress, gso *GSO, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) tcpip.Error { +func (n *nic) WritePacketToRemote(remoteLinkAddr tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) tcpip.Error { var r RouteInfo r.NetProto = protocol r.RemoteLinkAddress = remoteLinkAddr - return n.writePacket(r, gso, protocol, pkt) + return n.writePacket(r, protocol, pkt) } -func (n *nic) writePacket(r RouteInfo, gso *GSO, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) tcpip.Error { +func (n *nic) writePacket(r RouteInfo, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) tcpip.Error { // WritePacket takes ownership of pkt, calculate numBytes first. numBytes := pkt.Size() pkt.EgressRoute = r - pkt.GSOOptions = gso pkt.NetworkProtocolNumber = protocol - if err := n.LinkEndpoint.WritePacket(r, gso, protocol, pkt); err != nil { + if err := n.LinkEndpoint.WritePacket(r, protocol, pkt); err != nil { return err } @@ -389,18 +388,17 @@ func (n *nic) writePacket(r RouteInfo, gso *GSO, protocol tcpip.NetworkProtocolN } // WritePackets implements NetworkLinkEndpoint. -func (n *nic) WritePackets(r *Route, gso *GSO, pkts PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, tcpip.Error) { - return n.enqueuePacketBuffer(r, gso, protocol, &pkts) +func (n *nic) WritePackets(r *Route, pkts PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, tcpip.Error) { + return n.enqueuePacketBuffer(r, protocol, &pkts) } -func (n *nic) writePackets(r RouteInfo, gso *GSO, protocol tcpip.NetworkProtocolNumber, pkts PacketBufferList) (int, tcpip.Error) { +func (n *nic) writePackets(r RouteInfo, protocol tcpip.NetworkProtocolNumber, pkts PacketBufferList) (int, tcpip.Error) { for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() { pkt.EgressRoute = r - pkt.GSOOptions = gso pkt.NetworkProtocolNumber = protocol } - writtenPackets, err := n.LinkEndpoint.WritePackets(r, gso, pkts, protocol) + writtenPackets, err := n.LinkEndpoint.WritePackets(r, pkts, protocol) n.stats.Tx.Packets.IncrementBy(uint64(writtenPackets)) writtenBytes := 0 for i, pb := 0, pkts.Front(); i < writtenPackets && pb != nil; i, pb = i+1, pb.Next() { diff --git a/pkg/tcpip/stack/nic_test.go b/pkg/tcpip/stack/nic_test.go index c0f956e53..8a3005295 100644 --- a/pkg/tcpip/stack/nic_test.go +++ b/pkg/tcpip/stack/nic_test.go @@ -65,12 +65,12 @@ func (e *testIPv6Endpoint) MaxHeaderLength() uint16 { } // WritePacket implements NetworkEndpoint.WritePacket. -func (*testIPv6Endpoint) WritePacket(*Route, *GSO, NetworkHeaderParams, *PacketBuffer) tcpip.Error { +func (*testIPv6Endpoint) WritePacket(*Route, NetworkHeaderParams, *PacketBuffer) tcpip.Error { return nil } // WritePackets implements NetworkEndpoint.WritePackets. -func (*testIPv6Endpoint) WritePackets(*Route, *GSO, PacketBufferList, NetworkHeaderParams) (int, tcpip.Error) { +func (*testIPv6Endpoint) WritePackets(*Route, PacketBufferList, NetworkHeaderParams) (int, tcpip.Error) { // Our tests don't use this so we don't support it. return 0, &tcpip.ErrNotSupported{} } diff --git a/pkg/tcpip/stack/packet_buffer.go b/pkg/tcpip/stack/packet_buffer.go index 8f288675d..9527416cf 100644 --- a/pkg/tcpip/stack/packet_buffer.go +++ b/pkg/tcpip/stack/packet_buffer.go @@ -103,7 +103,7 @@ type PacketBuffer struct { // The following fields are only set by the qdisc layer when the packet // is added to a queue. EgressRoute RouteInfo - GSOOptions *GSO + GSOOptions GSO // NatDone indicates if the packet has been manipulated as per NAT // iptables rule. @@ -299,9 +299,18 @@ func (pk *PacketBuffer) Network() header.Network { // See PacketBuffer.Data for details about how a packet buffer holds an inbound // packet. func (pk *PacketBuffer) CloneToInbound() *PacketBuffer { - return NewPacketBuffer(PacketBufferOptions{ + newPk := NewPacketBuffer(PacketBufferOptions{ Data: buffer.NewVectorisedView(pk.Size(), pk.Views()), }) + // TODO(gvisor.dev/issue/5696): reimplement conntrack so that no need to + // maintain this flag in the packet. Currently conntrack needs this flag to + // tell if a noop connection should be inserted at Input hook. Once conntrack + // redefines the manipulation field as mutable, we won't need the special noop + // connection. + if pk.NatDone { + newPk.NatDone = true + } + return newPk } // headerInfo stores metadata about a header in a packet. @@ -355,9 +364,10 @@ func (d PacketData) PullUp(size int) (buffer.View, bool) { return d.pk.data.PullUp(size) } -// TrimFront removes count from the beginning of d. It panics if count > -// d.Size(). -func (d PacketData) TrimFront(count int) { +// DeleteFront removes count from the beginning of d. It panics if count > +// d.Size(). All backing storage references after the front of the d are +// invalidated. +func (d PacketData) DeleteFront(count int) { d.pk.data.TrimFront(count) } diff --git a/pkg/tcpip/stack/packet_buffer_test.go b/pkg/tcpip/stack/packet_buffer_test.go index 6728370c3..bd4eb4fed 100644 --- a/pkg/tcpip/stack/packet_buffer_test.go +++ b/pkg/tcpip/stack/packet_buffer_test.go @@ -112,23 +112,13 @@ func TestPacketHeaderPush(t *testing.T) { if got, want := pk.Size(), allHdrSize+len(test.data); got != want { t.Errorf("After pk.Size() = %d, want %d", got, want) } - checkData(t, pk, test.data) - checkViewEqual(t, "After pk.Views()", concatViews(pk.Views()...), - concatViews(test.link, test.network, test.transport, test.data)) - // Check the after values for each header. - checkPacketHeader(t, "After pk.LinkHeader", pk.LinkHeader(), test.link) - checkPacketHeader(t, "After pk.NetworkHeader", pk.NetworkHeader(), test.network) - checkPacketHeader(t, "After pk.TransportHeader", pk.TransportHeader(), test.transport) - // Check the after values for PayloadSince. - checkViewEqual(t, "After PayloadSince(LinkHeader)", - PayloadSince(pk.LinkHeader()), - concatViews(test.link, test.network, test.transport, test.data)) - checkViewEqual(t, "After PayloadSince(NetworkHeader)", - PayloadSince(pk.NetworkHeader()), - concatViews(test.network, test.transport, test.data)) - checkViewEqual(t, "After PayloadSince(TransportHeader)", - PayloadSince(pk.TransportHeader()), - concatViews(test.transport, test.data)) + // Check the after state. + checkPacketContents(t, "After ", pk, packetContents{ + link: test.link, + network: test.network, + transport: test.transport, + data: test.data, + }) }) } } @@ -199,29 +189,13 @@ func TestPacketHeaderConsume(t *testing.T) { if got, want := pk.Size(), len(test.data); got != want { t.Errorf("After pk.Size() = %d, want %d", got, want) } - // After state of pk. - var ( - link = test.data[:test.link] - network = test.data[test.link:][:test.network] - transport = test.data[test.link+test.network:][:test.transport] - payload = test.data[allHdrSize:] - ) - checkData(t, pk, payload) - checkViewEqual(t, "After pk.Views()", concatViews(pk.Views()...), test.data) - // Check the after values for each header. - checkPacketHeader(t, "After pk.LinkHeader", pk.LinkHeader(), link) - checkPacketHeader(t, "After pk.NetworkHeader", pk.NetworkHeader(), network) - checkPacketHeader(t, "After pk.TransportHeader", pk.TransportHeader(), transport) - // Check the after values for PayloadSince. - checkViewEqual(t, "After PayloadSince(LinkHeader)", - PayloadSince(pk.LinkHeader()), - concatViews(link, network, transport, payload)) - checkViewEqual(t, "After PayloadSince(NetworkHeader)", - PayloadSince(pk.NetworkHeader()), - concatViews(network, transport, payload)) - checkViewEqual(t, "After PayloadSince(TransportHeader)", - PayloadSince(pk.TransportHeader()), - concatViews(transport, payload)) + // Check the after state of pk. + checkPacketContents(t, "After ", pk, packetContents{ + link: test.data[:test.link], + network: test.data[test.link:][:test.network], + transport: test.data[test.link+test.network:][:test.transport], + data: test.data[allHdrSize:], + }) }) } } @@ -252,6 +226,39 @@ func TestPacketHeaderConsumeDataTooShort(t *testing.T) { }) } +// This is a very obscure use-case seen in the code that verifies packets +// before sending them out. It tries to parse the headers to verify. +// PacketHeader was initially not designed to mix Push() and Consume(), but it +// works and it's been relied upon. Include a test here. +func TestPacketHeaderPushConsumeMixed(t *testing.T) { + link := makeView(10) + network := makeView(20) + data := makeView(30) + + initData := append([]byte(nil), network...) + initData = append(initData, data...) + pk := NewPacketBuffer(PacketBufferOptions{ + ReserveHeaderBytes: len(link), + Data: buffer.NewViewFromBytes(initData).ToVectorisedView(), + }) + + // 1. Consume network header + gotNetwork, ok := pk.NetworkHeader().Consume(len(network)) + if !ok { + t.Fatalf("pk.NetworkHeader().Consume(%d) = _, false; want _, true", len(network)) + } + checkViewEqual(t, "gotNetwork", gotNetwork, network) + + // 2. Push link header + copy(pk.LinkHeader().Push(len(link)), link) + + checkPacketContents(t, "" /* prefix */, pk, packetContents{ + link: link, + network: network, + data: data, + }) +} + func TestPacketHeaderPushCalledAtMostOnce(t *testing.T) { const headerSize = 10 @@ -397,11 +404,11 @@ func TestPacketBufferData(t *testing.T) { } }) - // TrimFront + // DeleteFront for _, n := range []int{1, len(tc.data)} { - t.Run(fmt.Sprintf("TrimFront%d", n), func(t *testing.T) { + t.Run(fmt.Sprintf("DeleteFront%d", n), func(t *testing.T) { pkt := tc.makePkt(t) - pkt.Data().TrimFront(n) + pkt.Data().DeleteFront(n) checkData(t, pkt, []byte(tc.data)[n:]) }) @@ -494,6 +501,37 @@ func TestPacketBufferData(t *testing.T) { } } +type packetContents struct { + link buffer.View + network buffer.View + transport buffer.View + data buffer.View +} + +func checkPacketContents(t *testing.T, prefix string, pk *PacketBuffer, want packetContents) { + t.Helper() + // Headers. + checkPacketHeader(t, prefix+"pk.LinkHeader", pk.LinkHeader(), want.link) + checkPacketHeader(t, prefix+"pk.NetworkHeader", pk.NetworkHeader(), want.network) + checkPacketHeader(t, prefix+"pk.TransportHeader", pk.TransportHeader(), want.transport) + // Data. + checkData(t, pk, want.data) + // Whole packet. + checkViewEqual(t, prefix+"pk.Views()", + concatViews(pk.Views()...), + concatViews(want.link, want.network, want.transport, want.data)) + // PayloadSince. + checkViewEqual(t, prefix+"PayloadSince(LinkHeader)", + PayloadSince(pk.LinkHeader()), + concatViews(want.link, want.network, want.transport, want.data)) + checkViewEqual(t, prefix+"PayloadSince(NetworkHeader)", + PayloadSince(pk.NetworkHeader()), + concatViews(want.network, want.transport, want.data)) + checkViewEqual(t, prefix+"PayloadSince(TransportHeader)", + PayloadSince(pk.TransportHeader()), + concatViews(want.transport, want.data)) +} + func checkInitialPacketBuffer(t *testing.T, pk *PacketBuffer, opts PacketBufferOptions) { t.Helper() reserved := opts.ReserveHeaderBytes @@ -510,19 +548,9 @@ func checkInitialPacketBuffer(t *testing.T, pk *PacketBuffer, opts PacketBufferO if got, want := pk.Size(), len(data); got != want { t.Errorf("Initial pk.Size() = %d, want %d", got, want) } - checkData(t, pk, data) - checkViewEqual(t, "Initial pk.Views()", concatViews(pk.Views()...), data) - // Check the initial values for each header. - checkPacketHeader(t, "Initial pk.LinkHeader", pk.LinkHeader(), nil) - checkPacketHeader(t, "Initial pk.NetworkHeader", pk.NetworkHeader(), nil) - checkPacketHeader(t, "Initial pk.TransportHeader", pk.TransportHeader(), nil) - // Check the initial valies for PayloadSince. - checkViewEqual(t, "Initial PayloadSince(LinkHeader)", - PayloadSince(pk.LinkHeader()), data) - checkViewEqual(t, "Initial PayloadSince(NetworkHeader)", - PayloadSince(pk.NetworkHeader()), data) - checkViewEqual(t, "Initial PayloadSince(TransportHeader)", - PayloadSince(pk.TransportHeader()), data) + checkPacketContents(t, "Initial ", pk, packetContents{ + data: data, + }) } func checkPacketHeader(t *testing.T, name string, h PacketHeader, want []byte) { diff --git a/pkg/tcpip/stack/pending_packets.go b/pkg/tcpip/stack/pending_packets.go index e936aa728..13e8907ec 100644 --- a/pkg/tcpip/stack/pending_packets.go +++ b/pkg/tcpip/stack/pending_packets.go @@ -46,7 +46,6 @@ func (p *PacketBufferList) len() int { type pendingPacket struct { routeInfo RouteInfo - gso *GSO proto tcpip.NetworkProtocolNumber pkt pendingPacketBuffer } @@ -119,7 +118,7 @@ func (f *packetsPendingLinkResolution) dequeue(ch <-chan struct{}, linkAddr tcpi // If the maximum number of pending resolutions is reached, the packets // associated with the oldest link resolution will be dequeued as if they failed // link resolution. -func (f *packetsPendingLinkResolution) enqueue(r *Route, gso *GSO, proto tcpip.NetworkProtocolNumber, pkt pendingPacketBuffer) (int, tcpip.Error) { +func (f *packetsPendingLinkResolution) enqueue(r *Route, proto tcpip.NetworkProtocolNumber, pkt pendingPacketBuffer) (int, tcpip.Error) { f.mu.Lock() // Make sure we attempt resolution while holding f's lock so that we avoid // a race where link resolution completes before we enqueue the packets. @@ -137,7 +136,7 @@ func (f *packetsPendingLinkResolution) enqueue(r *Route, gso *GSO, proto tcpip.N // The route resolved immediately, so we don't need to wait for link // resolution to send the packet. f.mu.Unlock() - return f.nic.writePacketBuffer(routeInfo, gso, proto, pkt) + return f.nic.writePacketBuffer(routeInfo, proto, pkt) case *tcpip.ErrWouldBlock: // We need to wait for link resolution to complete. default: @@ -150,7 +149,6 @@ func (f *packetsPendingLinkResolution) enqueue(r *Route, gso *GSO, proto tcpip.N packets, ok := f.mu.packets[ch] packets = append(packets, pendingPacket{ routeInfo: routeInfo, - gso: gso, proto: proto, pkt: pkt, }) @@ -211,7 +209,7 @@ func (f *packetsPendingLinkResolution) dequeuePackets(packets []pendingPacket, l for _, p := range packets { if err == nil { p.routeInfo.RemoteLinkAddress = linkAddr - _, _ = f.nic.writePacketBuffer(p.routeInfo, p.gso, p.proto, p.pkt) + _, _ = f.nic.writePacketBuffer(p.routeInfo, p.proto, p.pkt) } else { f.incrementOutgoingPacketErrors(p.proto, p.pkt) diff --git a/pkg/tcpip/stack/registration.go b/pkg/tcpip/stack/registration.go index ff3a385e1..e26225552 100644 --- a/pkg/tcpip/stack/registration.go +++ b/pkg/tcpip/stack/registration.go @@ -537,14 +537,14 @@ type NetworkInterface interface { CheckLocalAddress(tcpip.NetworkProtocolNumber, tcpip.Address) bool // WritePacketToRemote writes the packet to the given remote link address. - WritePacketToRemote(tcpip.LinkAddress, *GSO, tcpip.NetworkProtocolNumber, *PacketBuffer) tcpip.Error + WritePacketToRemote(tcpip.LinkAddress, tcpip.NetworkProtocolNumber, *PacketBuffer) tcpip.Error // WritePacket writes a packet with the given protocol through the given // route. // // WritePacket takes ownership of the packet buffer. The packet buffer's // network and transport header must be set. - WritePacket(*Route, *GSO, tcpip.NetworkProtocolNumber, *PacketBuffer) tcpip.Error + WritePacket(*Route, tcpip.NetworkProtocolNumber, *PacketBuffer) tcpip.Error // WritePackets writes packets with the given protocol through the given // route. Must not be called with an empty list of packet buffers. @@ -554,7 +554,7 @@ type NetworkInterface interface { // Right now, WritePackets is used only when the software segmentation // offload is enabled. If it will be used for something else, syscall filters // may need to be updated. - WritePackets(*Route, *GSO, PacketBufferList, tcpip.NetworkProtocolNumber) (int, tcpip.Error) + WritePackets(*Route, PacketBufferList, tcpip.NetworkProtocolNumber) (int, tcpip.Error) // HandleNeighborProbe processes an incoming neighbor probe (e.g. ARP // request or NDP Neighbor Solicitation). @@ -610,12 +610,12 @@ type NetworkEndpoint interface { // WritePacket writes a packet to the given destination address and // protocol. It takes ownership of pkt. pkt.TransportHeader must have // already been set. - WritePacket(r *Route, gso *GSO, params NetworkHeaderParams, pkt *PacketBuffer) tcpip.Error + WritePacket(r *Route, params NetworkHeaderParams, pkt *PacketBuffer) tcpip.Error // WritePackets writes packets to the given destination address and // protocol. pkts must not be zero length. It takes ownership of pkts and // underlying packets. - WritePackets(r *Route, gso *GSO, pkts PacketBufferList, params NetworkHeaderParams) (int, tcpip.Error) + WritePackets(r *Route, pkts PacketBufferList, params NetworkHeaderParams) (int, tcpip.Error) // WriteHeaderIncludedPacket writes a packet that includes a network // header to the given destination address. It takes ownership of pkt. @@ -756,11 +756,6 @@ const ( CapabilitySaveRestore CapabilityDisconnectOk CapabilityLoopback - CapabilityHardwareGSO - - // CapabilitySoftwareGSO indicates the link endpoint supports of sending - // multiple packets using a single call (LinkEndpoint.WritePackets). - CapabilitySoftwareGSO ) // NetworkLinkEndpoint is a data-link layer that supports sending network @@ -832,7 +827,7 @@ type LinkEndpoint interface { // To participate in transparent bridging, a LinkEndpoint implementation // should call eth.Encode with header.EthernetFields.SrcAddr set to // r.LocalLinkAddress if it is provided. - WritePacket(RouteInfo, *GSO, tcpip.NetworkProtocolNumber, *PacketBuffer) tcpip.Error + WritePacket(RouteInfo, tcpip.NetworkProtocolNumber, *PacketBuffer) tcpip.Error // WritePackets writes packets with the given protocol and route. Must not be // called with an empty list of packet buffers. @@ -842,7 +837,7 @@ type LinkEndpoint interface { // Right now, WritePackets is used only when the software segmentation // offload is enabled. If it will be used for something else, syscall filters // may need to be updated. - WritePackets(RouteInfo, *GSO, PacketBufferList, tcpip.NetworkProtocolNumber) (int, tcpip.Error) + WritePackets(RouteInfo, PacketBufferList, tcpip.NetworkProtocolNumber) (int, tcpip.Error) } // InjectableLinkEndpoint is a LinkEndpoint where inbound packets are @@ -1047,10 +1042,29 @@ type GSO struct { MaxSize uint32 } +// SupportedGSO returns the type of segmentation offloading supported. +type SupportedGSO int + +const ( + // GSONotSupported indicates that segmentation offloading is not supported. + GSONotSupported SupportedGSO = iota + + // HWGSOSupported indicates that segmentation offloading may be performed by + // the hardware. + HWGSOSupported + + // SWGSOSupported indicates that segmentation offloading may be performed in + // software. + SWGSOSupported +) + // GSOEndpoint provides access to GSO properties. type GSOEndpoint interface { // GSOMaxSize returns the maximum GSO packet size. GSOMaxSize() uint32 + + // SupportedGSO returns the supported segmentation offloading. + SupportedGSO() SupportedGSO } // SoftwareGSOMaxSize is a maximum allowed size of a software GSO segment. diff --git a/pkg/tcpip/stack/route.go b/pkg/tcpip/stack/route.go index 39344808d..8a044c073 100644 --- a/pkg/tcpip/stack/route.go +++ b/pkg/tcpip/stack/route.go @@ -132,7 +132,7 @@ func constructAndValidateRoute(netProto tcpip.NetworkProtocolNumber, addressEndp localAddr = addressEndpoint.AddressWithPrefix().Address } - if localAddressNIC != outgoingNIC && header.IsV6LinkLocalAddress(localAddr) { + if localAddressNIC != outgoingNIC && header.IsV6LinkLocalUnicastAddress(localAddr) { addressEndpoint.DecRef() return nil } @@ -300,12 +300,18 @@ func (r *Route) RequiresTXTransportChecksum() bool { // HasSoftwareGSOCapability returns true if the route supports software GSO. func (r *Route) HasSoftwareGSOCapability() bool { - return r.outgoingNIC.LinkEndpoint.Capabilities()&CapabilitySoftwareGSO != 0 + if gso, ok := r.outgoingNIC.LinkEndpoint.(GSOEndpoint); ok { + return gso.SupportedGSO() == SWGSOSupported + } + return false } // HasHardwareGSOCapability returns true if the route supports hardware GSO. func (r *Route) HasHardwareGSOCapability() bool { - return r.outgoingNIC.LinkEndpoint.Capabilities()&CapabilityHardwareGSO != 0 + if gso, ok := r.outgoingNIC.LinkEndpoint.(GSOEndpoint); ok { + return gso.SupportedGSO() == HWGSOSupported + } + return false } // HasSaveRestoreCapability returns true if the route supports save/restore. @@ -448,22 +454,22 @@ func (r *Route) isValidForOutgoingRLocked() bool { } // WritePacket writes the packet through the given route. -func (r *Route) WritePacket(gso *GSO, params NetworkHeaderParams, pkt *PacketBuffer) tcpip.Error { +func (r *Route) WritePacket(params NetworkHeaderParams, pkt *PacketBuffer) tcpip.Error { if !r.isValidForOutgoing() { return &tcpip.ErrInvalidEndpointState{} } - return r.outgoingNIC.getNetworkEndpoint(r.NetProto()).WritePacket(r, gso, params, pkt) + return r.outgoingNIC.getNetworkEndpoint(r.NetProto()).WritePacket(r, params, pkt) } // WritePackets writes a list of n packets through the given route and returns // the number of packets written. -func (r *Route) WritePackets(gso *GSO, pkts PacketBufferList, params NetworkHeaderParams) (int, tcpip.Error) { +func (r *Route) WritePackets(pkts PacketBufferList, params NetworkHeaderParams) (int, tcpip.Error) { if !r.isValidForOutgoing() { return 0, &tcpip.ErrInvalidEndpointState{} } - return r.outgoingNIC.getNetworkEndpoint(r.NetProto()).WritePackets(r, gso, pkts, params) + return r.outgoingNIC.getNetworkEndpoint(r.NetProto()).WritePackets(r, pkts, params) } // WriteHeaderIncludedPacket writes a packet already containing a network diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go index 931a97ddc..436392f23 100644 --- a/pkg/tcpip/stack/stack.go +++ b/pkg/tcpip/stack/stack.go @@ -35,7 +35,6 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/ports" - "gvisor.dev/gvisor/pkg/tcpip/seqnum" "gvisor.dev/gvisor/pkg/waiter" ) @@ -56,306 +55,6 @@ type transportProtocolState struct { defaultHandler func(id TransportEndpointID, pkt *PacketBuffer) bool } -// TCPProbeFunc is the expected function type for a TCP probe function to be -// passed to stack.AddTCPProbe. -type TCPProbeFunc func(s TCPEndpointState) - -// TCPCubicState is used to hold a copy of the internal cubic state when the -// TCPProbeFunc is invoked. -type TCPCubicState struct { - WLastMax float64 - WMax float64 - T time.Time - TimeSinceLastCongestion time.Duration - C float64 - K float64 - Beta float64 - WC float64 - WEst float64 -} - -// TCPRACKState is used to hold a copy of the internal RACK state when the -// TCPProbeFunc is invoked. -type TCPRACKState struct { - XmitTime time.Time - EndSequence seqnum.Value - FACK seqnum.Value - RTT time.Duration - Reord bool - DSACKSeen bool - ReoWnd time.Duration - ReoWndIncr uint8 - ReoWndPersist int8 - RTTSeq seqnum.Value -} - -// TCPEndpointID is the unique 4 tuple that identifies a given endpoint. -type TCPEndpointID struct { - // LocalPort is the local port associated with the endpoint. - LocalPort uint16 - - // LocalAddress is the local [network layer] address associated with - // the endpoint. - LocalAddress tcpip.Address - - // RemotePort is the remote port associated with the endpoint. - RemotePort uint16 - - // RemoteAddress it the remote [network layer] address associated with - // the endpoint. - RemoteAddress tcpip.Address -} - -// TCPFastRecoveryState holds a copy of the internal fast recovery state of a -// TCP endpoint. -type TCPFastRecoveryState struct { - // Active if true indicates the endpoint is in fast recovery. - Active bool - - // First is the first unacknowledged sequence number being recovered. - First seqnum.Value - - // Last is the 'recover' sequence number that indicates the point at - // which we should exit recovery barring any timeouts etc. - Last seqnum.Value - - // MaxCwnd is the maximum value we are permitted to grow the congestion - // window during recovery. This is set at the time we enter recovery. - MaxCwnd int - - // HighRxt is the highest sequence number which has been retransmitted - // during the current loss recovery phase. - // See: RFC 6675 Section 2 for details. - HighRxt seqnum.Value - - // RescueRxt is the highest sequence number which has been - // optimistically retransmitted to prevent stalling of the ACK clock - // when there is loss at the end of the window and no new data is - // available for transmission. - // See: RFC 6675 Section 2 for details. - RescueRxt seqnum.Value -} - -// TCPReceiverState holds a copy of the internal state of the receiver for -// a given TCP endpoint. -type TCPReceiverState struct { - // RcvNxt is the TCP variable RCV.NXT. - RcvNxt seqnum.Value - - // RcvAcc is the TCP variable RCV.ACC. - RcvAcc seqnum.Value - - // RcvWndScale is the window scaling to use for inbound segments. - RcvWndScale uint8 - - // PendingBufUsed is the number of bytes pending in the receive - // queue. - PendingBufUsed int -} - -// TCPSenderState holds a copy of the internal state of the sender for -// a given TCP Endpoint. -type TCPSenderState struct { - // LastSendTime is the time at which we sent the last segment. - LastSendTime time.Time - - // DupAckCount is the number of Duplicate ACK's received. - DupAckCount int - - // SndCwnd is the size of the sending congestion window in packets. - SndCwnd int - - // Ssthresh is the slow start threshold in packets. - Ssthresh int - - // SndCAAckCount is the number of packets consumed in congestion - // avoidance mode. - SndCAAckCount int - - // Outstanding is the number of packets in flight. - Outstanding int - - // SackedOut is the number of packets which have been selectively acked. - SackedOut int - - // SndWnd is the send window size in bytes. - SndWnd seqnum.Size - - // SndUna is the next unacknowledged sequence number. - SndUna seqnum.Value - - // SndNxt is the sequence number of the next segment to be sent. - SndNxt seqnum.Value - - // RTTMeasureSeqNum is the sequence number being used for the latest RTT - // measurement. - RTTMeasureSeqNum seqnum.Value - - // RTTMeasureTime is the time when the RTTMeasureSeqNum was sent. - RTTMeasureTime time.Time - - // Closed indicates that the caller has closed the endpoint for sending. - Closed bool - - // SRTT is the smoothed round-trip time as defined in section 2 of - // RFC 6298. - SRTT time.Duration - - // RTO is the retransmit timeout as defined in section of 2 of RFC 6298. - RTO time.Duration - - // RTTVar is the round-trip time variation as defined in section 2 of - // RFC 6298. - RTTVar time.Duration - - // SRTTInited if true indicates take a valid RTT measurement has been - // completed. - SRTTInited bool - - // MaxPayloadSize is the maximum size of the payload of a given segment. - // It is initialized on demand. - MaxPayloadSize int - - // SndWndScale is the number of bits to shift left when reading the send - // window size from a segment. - SndWndScale uint8 - - // MaxSentAck is the highest acknowledgement number sent till now. - MaxSentAck seqnum.Value - - // FastRecovery holds the fast recovery state for the endpoint. - FastRecovery TCPFastRecoveryState - - // Cubic holds the state related to CUBIC congestion control. - Cubic TCPCubicState - - // RACKState holds the state related to RACK loss detection algorithm. - RACKState TCPRACKState -} - -// TCPSACKInfo holds TCP SACK related information for a given TCP endpoint. -type TCPSACKInfo struct { - // Blocks is the list of SACK Blocks that identify the out of order segments - // held by a given TCP endpoint. - Blocks []header.SACKBlock - - // ReceivedBlocks are the SACK blocks received by this endpoint - // from the peer endpoint. - ReceivedBlocks []header.SACKBlock - - // MaxSACKED is the highest sequence number that has been SACKED - // by the peer. - MaxSACKED seqnum.Value -} - -// RcvBufAutoTuneParams holds state related to TCP receive buffer auto-tuning. -type RcvBufAutoTuneParams struct { - // MeasureTime is the time at which the current measurement - // was started. - MeasureTime time.Time - - // CopiedBytes is the number of bytes copied to user space since - // this measure began. - CopiedBytes int - - // PrevCopiedBytes is the number of bytes copied to userspace in - // the previous RTT period. - PrevCopiedBytes int - - // RcvBufSize is the auto tuned receive buffer size. - RcvBufSize int - - // RTT is the smoothed RTT as measured by observing the time between - // when a byte is first acknowledged and the receipt of data that is at - // least one window beyond the sequence number that was acknowledged. - RTT time.Duration - - // RTTVar is the "round-trip time variation" as defined in section 2 - // of RFC6298. - RTTVar time.Duration - - // RTTMeasureSeqNumber is the highest acceptable sequence number at the - // time this RTT measurement period began. - RTTMeasureSeqNumber seqnum.Value - - // RTTMeasureTime is the absolute time at which the current RTT - // measurement period began. - RTTMeasureTime time.Time - - // Disabled is true if an explicit receive buffer is set for the - // endpoint. - Disabled bool -} - -// TCPEndpointState is a copy of the internal state of a TCP endpoint. -type TCPEndpointState struct { - // ID is a copy of the TransportEndpointID for the endpoint. - ID TCPEndpointID - - // SegTime denotes the absolute time when this segment was received. - SegTime time.Time - - // RcvBufSize is the size of the receive socket buffer for the endpoint. - RcvBufSize int - - // RcvBufUsed is the amount of bytes actually held in the receive socket - // buffer for the endpoint. - RcvBufUsed int - - // RcvBufAutoTuneParams is used to hold state variables to compute - // the auto tuned receive buffer size. - RcvAutoParams RcvBufAutoTuneParams - - // RcvClosed if true, indicates the endpoint has been closed for reading. - RcvClosed bool - - // SendTSOk is used to indicate when the TS Option has been negotiated. - // When sendTSOk is true every non-RST segment should carry a TS as per - // RFC7323#section-1.1. - SendTSOk bool - - // RecentTS is the timestamp that should be sent in the TSEcr field of - // the timestamp for future segments sent by the endpoint. This field is - // updated if required when a new segment is received by this endpoint. - RecentTS uint32 - - // TSOffset is a randomized offset added to the value of the TSVal field - // in the timestamp option. - TSOffset uint32 - - // SACKPermitted is set to true if the peer sends the TCPSACKPermitted - // option in the SYN/SYN-ACK. - SACKPermitted bool - - // SACK holds TCP SACK related information for this endpoint. - SACK TCPSACKInfo - - // SndBufSize is the size of the socket send buffer. - SndBufSize int - - // SndBufUsed is the number of bytes held in the socket send buffer. - SndBufUsed int - - // SndClosed indicates that the endpoint has been closed for sends. - SndClosed bool - - // SndBufInQueue is the number of bytes in the send queue. - SndBufInQueue seqnum.Size - - // PacketTooBigCount is used to notify the main protocol routine how - // many times a "packet too big" control packet is received. - PacketTooBigCount int - - // SndMTU is the smallest MTU seen in the control packets received. - SndMTU int - - // Receiver holds variables related to the TCP receiver for the endpoint. - Receiver TCPReceiverState - - // Sender holds state related to the TCP Sender for the endpoint. - Sender TCPSenderState -} - // ResumableEndpoint is an endpoint that needs to be resumed after restore. type ResumableEndpoint interface { // Resume resumes an endpoint after restore. This can be used to restart @@ -455,7 +154,7 @@ type Stack struct { // receiveBufferSize holds the min/default/max receive buffer sizes for // endpoints other than TCP. - receiveBufferSize ReceiveBufferSizeOption + receiveBufferSize tcpip.ReceiveBufferSizeOption // tcpInvalidRateLimit is the maximal rate for sending duplicate // acknowledgements in response to incoming TCP packets that are for an existing @@ -623,7 +322,7 @@ func (*TransportEndpointInfo) IsEndpointInfo() {} func New(opts Options) *Stack { clock := opts.Clock if clock == nil { - clock = &tcpip.StdClock{} + clock = tcpip.NewStdClock() } if opts.UniqueID == nil { @@ -669,7 +368,7 @@ func New(opts Options) *Stack { Default: DefaultBufferSize, Max: DefaultMaxBufferSize, }, - receiveBufferSize: ReceiveBufferSizeOption{ + receiveBufferSize: tcpip.ReceiveBufferSizeOption{ Min: MinBufferSize, Default: DefaultBufferSize, Max: DefaultMaxBufferSize, @@ -1344,7 +1043,7 @@ func (s *Stack) FindRoute(id tcpip.NICID, localAddr, remoteAddr tcpip.Address, n s.mu.RLock() defer s.mu.RUnlock() - isLinkLocal := header.IsV6LinkLocalAddress(remoteAddr) || header.IsV6LinkLocalMulticastAddress(remoteAddr) + isLinkLocal := header.IsV6LinkLocalUnicastAddress(remoteAddr) || header.IsV6LinkLocalMulticastAddress(remoteAddr) isLocalBroadcast := remoteAddr == header.IPv4Broadcast isMulticast := header.IsV4MulticastAddress(remoteAddr) || header.IsV6MulticastAddress(remoteAddr) isLoopback := header.IsV4LoopbackAddress(remoteAddr) || header.IsV6LoopbackAddress(remoteAddr) @@ -1381,7 +1080,7 @@ func (s *Stack) FindRoute(id tcpip.NICID, localAddr, remoteAddr tcpip.Address, n return nil, &tcpip.ErrNetworkUnreachable{} } - canForward := s.Forwarding(netProto) && !header.IsV6LinkLocalAddress(localAddr) && !isLinkLocal + canForward := s.Forwarding(netProto) && !header.IsV6LinkLocalUnicastAddress(localAddr) && !isLinkLocal // Find a route to the remote with the route table. var chosenRoute tcpip.Route @@ -1874,7 +1573,7 @@ func (s *Stack) WritePacketToRemote(nicID tcpip.NICID, remote tcpip.LinkAddress, ReserveHeaderBytes: int(nic.MaxHeaderLength()), Data: payload, }) - return nic.WritePacketToRemote(remote, nil, netProto, pkt) + return nic.WritePacketToRemote(remote, netProto, pkt) } // NetworkProtocolInstance returns the protocol instance in the stack for the diff --git a/pkg/tcpip/stack/stack_global_state.go b/pkg/tcpip/stack/stack_global_state.go index dfec4258a..33824afd0 100644 --- a/pkg/tcpip/stack/stack_global_state.go +++ b/pkg/tcpip/stack/stack_global_state.go @@ -14,6 +14,78 @@ package stack +import "time" + // StackFromEnv is the global stack created in restore run. // FIXME(b/36201077) var StackFromEnv *Stack + +// saveT is invoked by stateify. +func (t *TCPCubicState) saveT() unixTime { + return unixTime{t.T.Unix(), t.T.UnixNano()} +} + +// loadT is invoked by stateify. +func (t *TCPCubicState) loadT(unix unixTime) { + t.T = time.Unix(unix.second, unix.nano) +} + +// saveXmitTime is invoked by stateify. +func (t *TCPRACKState) saveXmitTime() unixTime { + return unixTime{t.XmitTime.Unix(), t.XmitTime.UnixNano()} +} + +// loadXmitTime is invoked by stateify. +func (t *TCPRACKState) loadXmitTime(unix unixTime) { + t.XmitTime = time.Unix(unix.second, unix.nano) +} + +// saveLastSendTime is invoked by stateify. +func (t *TCPSenderState) saveLastSendTime() unixTime { + return unixTime{t.LastSendTime.Unix(), t.LastSendTime.UnixNano()} +} + +// loadLastSendTime is invoked by stateify. +func (t *TCPSenderState) loadLastSendTime(unix unixTime) { + t.LastSendTime = time.Unix(unix.second, unix.nano) +} + +// saveRTTMeasureTime is invoked by stateify. +func (t *TCPSenderState) saveRTTMeasureTime() unixTime { + return unixTime{t.RTTMeasureTime.Unix(), t.RTTMeasureTime.UnixNano()} +} + +// loadRTTMeasureTime is invoked by stateify. +func (t *TCPSenderState) loadRTTMeasureTime(unix unixTime) { + t.RTTMeasureTime = time.Unix(unix.second, unix.nano) +} + +// saveMeasureTime is invoked by stateify. +func (r *RcvBufAutoTuneParams) saveMeasureTime() unixTime { + return unixTime{r.MeasureTime.Unix(), r.MeasureTime.UnixNano()} +} + +// loadMeasureTime is invoked by stateify. +func (r *RcvBufAutoTuneParams) loadMeasureTime(unix unixTime) { + r.MeasureTime = time.Unix(unix.second, unix.nano) +} + +// saveRTTMeasureTime is invoked by stateify. +func (r *RcvBufAutoTuneParams) saveRTTMeasureTime() unixTime { + return unixTime{r.RTTMeasureTime.Unix(), r.RTTMeasureTime.UnixNano()} +} + +// loadRTTMeasureTime is invoked by stateify. +func (r *RcvBufAutoTuneParams) loadRTTMeasureTime(unix unixTime) { + r.RTTMeasureTime = time.Unix(unix.second, unix.nano) +} + +// saveSegTime is invoked by stateify. +func (t *TCPEndpointState) saveSegTime() unixTime { + return unixTime{t.SegTime.Unix(), t.SegTime.UnixNano()} +} + +// loadSegTime is invoked by stateify. +func (t *TCPEndpointState) loadSegTime(unix unixTime) { + t.SegTime = time.Unix(unix.second, unix.nano) +} diff --git a/pkg/tcpip/stack/stack_options.go b/pkg/tcpip/stack/stack_options.go index 3066f4ffd..80e8e0089 100644 --- a/pkg/tcpip/stack/stack_options.go +++ b/pkg/tcpip/stack/stack_options.go @@ -68,7 +68,7 @@ func (s *Stack) SetOption(option interface{}) tcpip.Error { s.mu.Unlock() return nil - case ReceiveBufferSizeOption: + case tcpip.ReceiveBufferSizeOption: // Make sure we don't allow lowering the buffer below minimum // required for stack to work. if v.Min < MinBufferSize { @@ -107,7 +107,7 @@ func (s *Stack) Option(option interface{}) tcpip.Error { s.mu.RUnlock() return nil - case *ReceiveBufferSizeOption: + case *tcpip.ReceiveBufferSizeOption: s.mu.RLock() *v = s.receiveBufferSize s.mu.RUnlock() diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go index 2814b94b4..d2c40cc43 100644 --- a/pkg/tcpip/stack/stack_test.go +++ b/pkg/tcpip/stack/stack_test.go @@ -39,6 +39,7 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" "gvisor.dev/gvisor/pkg/tcpip/stack" + "gvisor.dev/gvisor/pkg/tcpip/testutil" "gvisor.dev/gvisor/pkg/tcpip/transport/udp" ) @@ -137,11 +138,13 @@ func (f *fakeNetworkEndpoint) HandlePacket(pkt *stack.PacketBuffer) { // Handle control packets. if netHdr[protocolNumberOffset] == uint8(fakeControlProtocol) { - nb, ok := pkt.Data().PullUp(fakeNetHeaderLen) + hdr, ok := pkt.Data().PullUp(fakeNetHeaderLen) if !ok { return } - pkt.Data().TrimFront(fakeNetHeaderLen) + // DeleteFront invalidates slices. Make a copy before trimming. + nb := append([]byte(nil), hdr...) + pkt.Data().DeleteFront(fakeNetHeaderLen) f.dispatcher.DeliverTransportError( tcpip.Address(nb[srcAddrOffset:srcAddrOffset+1]), tcpip.Address(nb[dstAddrOffset:dstAddrOffset+1]), @@ -170,7 +173,7 @@ func (f *fakeNetworkEndpoint) NetworkProtocolNumber() tcpip.NetworkProtocolNumbe return f.proto.Number() } -func (f *fakeNetworkEndpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.NetworkHeaderParams, pkt *stack.PacketBuffer) tcpip.Error { +func (f *fakeNetworkEndpoint) WritePacket(r *stack.Route, params stack.NetworkHeaderParams, pkt *stack.PacketBuffer) tcpip.Error { // Increment the sent packet count in the protocol descriptor. f.proto.sendPacketCount[int(r.RemoteAddress()[0])%len(f.proto.sendPacketCount)]++ @@ -189,11 +192,11 @@ func (f *fakeNetworkEndpoint) WritePacket(r *stack.Route, gso *stack.GSO, params return nil } - return f.nic.WritePacket(r, gso, fakeNetNumber, pkt) + return f.nic.WritePacket(r, fakeNetNumber, pkt) } // WritePackets implements stack.LinkEndpoint.WritePackets. -func (*fakeNetworkEndpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.PacketBufferList, params stack.NetworkHeaderParams) (int, tcpip.Error) { +func (*fakeNetworkEndpoint) WritePackets(r *stack.Route, pkts stack.PacketBufferList, params stack.NetworkHeaderParams) (int, tcpip.Error) { panic("not implemented") } @@ -436,7 +439,7 @@ func sendTo(s *stack.Stack, addr tcpip.Address, payload buffer.View) tcpip.Error } func send(r *stack.Route, payload buffer.View) tcpip.Error { - return r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: fakeTransNumber, TTL: 123, TOS: stack.DefaultTOS}, stack.NewPacketBuffer(stack.PacketBufferOptions{ + return r.WritePacket(stack.NetworkHeaderParams{Protocol: fakeTransNumber, TTL: 123, TOS: stack.DefaultTOS}, stack.NewPacketBuffer(stack.PacketBufferOptions{ ReserveHeaderBytes: int(r.MaxHeaderLength()), Data: payload.ToVectorisedView(), })) @@ -1461,7 +1464,7 @@ func TestExternalSendWithHandleLocal(t *testing.T) { if n := ep.Drain(); n != 0 { t.Fatalf("got ep.Drain() = %d, want = 0", n) } - if err := r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{ + if err := r.WritePacket(stack.NetworkHeaderParams{ Protocol: fakeTransNumber, TTL: 123, TOS: stack.DefaultTOS, @@ -1645,10 +1648,10 @@ func TestOutgoingBroadcastWithRouteTable(t *testing.T) { defaultAddr := tcpip.AddressWithPrefix{header.IPv4Any, 0} // Local subnet on NIC1: 192.168.1.58/24, gateway 192.168.1.1. nic1Addr := tcpip.AddressWithPrefix{"\xc0\xa8\x01\x3a", 24} - nic1Gateway := tcpip.Address("\xc0\xa8\x01\x01") + nic1Gateway := testutil.MustParse4("192.168.1.1") // Local subnet on NIC2: 10.10.10.5/24, gateway 10.10.10.1. nic2Addr := tcpip.AddressWithPrefix{"\x0a\x0a\x0a\x05", 24} - nic2Gateway := tcpip.Address("\x0a\x0a\x0a\x01") + nic2Gateway := testutil.MustParse4("10.10.10.1") // Create a new stack with two NICs. s := stack.New(stack.Options{ @@ -2789,25 +2792,27 @@ func TestNewPEBOnPromotionToPermanent(t *testing.T) { func TestIPv6SourceAddressSelectionScopeAndSameAddress(t *testing.T) { const ( - linkLocalAddr1 = tcpip.Address("\xfe\x80\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01") - linkLocalAddr2 = tcpip.Address("\xfe\x80\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02") - linkLocalMulticastAddr = tcpip.Address("\xff\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01") - uniqueLocalAddr1 = tcpip.Address("\xfc\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01") - uniqueLocalAddr2 = tcpip.Address("\xfd\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02") - globalAddr1 = tcpip.Address("\xa0\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01") - globalAddr2 = tcpip.Address("\xa0\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02") - globalAddr3 = tcpip.Address("\xa0\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x03") - ipv4MappedIPv6Addr1 = tcpip.Address("\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff\x00\x00\x00\x01") - ipv4MappedIPv6Addr2 = tcpip.Address("\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff\x00\x00\x00\x02") - toredoAddr1 = tcpip.Address("\x20\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01") - toredoAddr2 = tcpip.Address("\x20\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02") - ipv6ToIPv4Addr1 = tcpip.Address("\x20\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01") - ipv6ToIPv4Addr2 = tcpip.Address("\x20\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02") - nicID = 1 lifetimeSeconds = 9999 ) + var ( + linkLocalAddr1 = testutil.MustParse6("fe80::1") + linkLocalAddr2 = testutil.MustParse6("fe80::2") + linkLocalMulticastAddr = testutil.MustParse6("ff02::1") + uniqueLocalAddr1 = testutil.MustParse6("fc00::1") + uniqueLocalAddr2 = testutil.MustParse6("fd00::2") + globalAddr1 = testutil.MustParse6("a000::1") + globalAddr2 = testutil.MustParse6("a000::2") + globalAddr3 = testutil.MustParse6("a000::3") + ipv4MappedIPv6Addr1 = testutil.MustParse6("::ffff:0.0.0.1") + ipv4MappedIPv6Addr2 = testutil.MustParse6("::ffff:0.0.0.2") + toredoAddr1 = testutil.MustParse6("2001::1") + toredoAddr2 = testutil.MustParse6("2001::2") + ipv6ToIPv4Addr1 = testutil.MustParse6("2002::1") + ipv6ToIPv4Addr2 = testutil.MustParse6("2002::2") + ) + prefix1, _, stableGlobalAddr1 := prefixSubnetAddr(0, linkAddr1) prefix2, _, stableGlobalAddr2 := prefixSubnetAddr(1, linkAddr1) @@ -3017,7 +3022,7 @@ func TestIPv6SourceAddressSelectionScopeAndSameAddress(t *testing.T) { s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ NDPConfigs: ipv6.NDPConfigurations{ - HandleRAs: true, + HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled, AutoGenGlobalAddresses: true, AutoGenTempGlobalAddresses: true, }, @@ -3354,21 +3359,21 @@ func TestStackReceiveBufferSizeOption(t *testing.T) { const sMin = stack.MinBufferSize testCases := []struct { name string - rs stack.ReceiveBufferSizeOption + rs tcpip.ReceiveBufferSizeOption err tcpip.Error }{ // Invalid configurations. - {"min_below_zero", stack.ReceiveBufferSizeOption{Min: -1, Default: sMin, Max: sMin}, &tcpip.ErrInvalidOptionValue{}}, - {"min_zero", stack.ReceiveBufferSizeOption{Min: 0, Default: sMin, Max: sMin}, &tcpip.ErrInvalidOptionValue{}}, - {"default_below_min", stack.ReceiveBufferSizeOption{Min: sMin, Default: sMin - 1, Max: sMin - 1}, &tcpip.ErrInvalidOptionValue{}}, - {"default_above_max", stack.ReceiveBufferSizeOption{Min: sMin, Default: sMin + 1, Max: sMin}, &tcpip.ErrInvalidOptionValue{}}, - {"max_below_min", stack.ReceiveBufferSizeOption{Min: sMin, Default: sMin + 1, Max: sMin - 1}, &tcpip.ErrInvalidOptionValue{}}, + {"min_below_zero", tcpip.ReceiveBufferSizeOption{Min: -1, Default: sMin, Max: sMin}, &tcpip.ErrInvalidOptionValue{}}, + {"min_zero", tcpip.ReceiveBufferSizeOption{Min: 0, Default: sMin, Max: sMin}, &tcpip.ErrInvalidOptionValue{}}, + {"default_below_min", tcpip.ReceiveBufferSizeOption{Min: sMin, Default: sMin - 1, Max: sMin - 1}, &tcpip.ErrInvalidOptionValue{}}, + {"default_above_max", tcpip.ReceiveBufferSizeOption{Min: sMin, Default: sMin + 1, Max: sMin}, &tcpip.ErrInvalidOptionValue{}}, + {"max_below_min", tcpip.ReceiveBufferSizeOption{Min: sMin, Default: sMin + 1, Max: sMin - 1}, &tcpip.ErrInvalidOptionValue{}}, // Valid Configurations - {"in_ascending_order", stack.ReceiveBufferSizeOption{Min: sMin, Default: sMin + 1, Max: sMin + 2}, nil}, - {"all_equal", stack.ReceiveBufferSizeOption{Min: sMin, Default: sMin, Max: sMin}, nil}, - {"min_default_equal", stack.ReceiveBufferSizeOption{Min: sMin, Default: sMin, Max: sMin + 1}, nil}, - {"default_max_equal", stack.ReceiveBufferSizeOption{Min: sMin, Default: sMin + 1, Max: sMin + 1}, nil}, + {"in_ascending_order", tcpip.ReceiveBufferSizeOption{Min: sMin, Default: sMin + 1, Max: sMin + 2}, nil}, + {"all_equal", tcpip.ReceiveBufferSizeOption{Min: sMin, Default: sMin, Max: sMin}, nil}, + {"min_default_equal", tcpip.ReceiveBufferSizeOption{Min: sMin, Default: sMin, Max: sMin + 1}, nil}, + {"default_max_equal", tcpip.ReceiveBufferSizeOption{Min: sMin, Default: sMin + 1, Max: sMin + 1}, nil}, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { @@ -3377,7 +3382,7 @@ func TestStackReceiveBufferSizeOption(t *testing.T) { if err := s.SetOption(tc.rs); err != tc.err { t.Fatalf("s.SetOption(%#v) = %v, want: %v", tc.rs, err, tc.err) } - var rs stack.ReceiveBufferSizeOption + var rs tcpip.ReceiveBufferSizeOption if tc.err == nil { if err := s.Option(&rs); err != nil { t.Fatalf("s.Option(%#v) = %v, want: nil", rs, err) @@ -3448,7 +3453,7 @@ func TestOutgoingSubnetBroadcast(t *testing.T) { } ipv4Subnet := ipv4Addr.Subnet() ipv4SubnetBcast := ipv4Subnet.Broadcast() - ipv4Gateway := tcpip.Address("\xc0\xa8\x01\x01") + ipv4Gateway := testutil.MustParse4("192.168.1.1") ipv4AddrPrefix31 := tcpip.AddressWithPrefix{ Address: "\xc0\xa8\x01\x3a", PrefixLen: 31, @@ -4352,13 +4357,15 @@ func TestWritePacketToRemote(t *testing.T) { func TestClearNeighborCacheOnNICDisable(t *testing.T) { const ( - nicID = 1 - - ipv4Addr = tcpip.Address("\x01\x02\x03\x04") - ipv6Addr = tcpip.Address("\x01\x02\x03\x04\x01\x02\x03\x04\x01\x02\x03\x04\x01\x02\x03\x04") + nicID = 1 linkAddr = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x06") ) + var ( + ipv4Addr = testutil.MustParse4("1.2.3.4") + ipv6Addr = testutil.MustParse6("102:304:102:304:102:304:102:304") + ) + clock := faketime.NewManualClock() s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{arp.NewProtocol, ipv4.NewProtocol, ipv6.NewProtocol}, diff --git a/pkg/tcpip/stack/tcp.go b/pkg/tcpip/stack/tcp.go new file mode 100644 index 000000000..ddff6e2d6 --- /dev/null +++ b/pkg/tcpip/stack/tcp.go @@ -0,0 +1,451 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package stack + +import ( + "time" + + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/seqnum" +) + +// TCPProbeFunc is the expected function type for a TCP probe function to be +// passed to stack.AddTCPProbe. +type TCPProbeFunc func(s TCPEndpointState) + +// TCPCubicState is used to hold a copy of the internal cubic state when the +// TCPProbeFunc is invoked. +// +// +stateify savable +type TCPCubicState struct { + // WLastMax is the previous wMax value. + WLastMax float64 + + // WMax is the value of the congestion window at the time of the last + // congestion event. + WMax float64 + + // T is the time when the current congestion avoidance was entered. + T time.Time `state:".(unixTime)"` + + // TimeSinceLastCongestion denotes the time since the current + // congestion avoidance was entered. + TimeSinceLastCongestion time.Duration + + // C is the cubic constant as specified in RFC8312, page 11. + C float64 + + // K is the time period (in seconds) that the above function takes to + // increase the current window size to WMax if there are no further + // congestion events and is calculated using the following equation: + // + // K = cubic_root(WMax*(1-beta_cubic)/C) (Eq. 2, page 5) + K float64 + + // Beta is the CUBIC multiplication decrease factor. That is, when a + // congestion event is detected, CUBIC reduces its cwnd to + // WC(0)=WMax*beta_cubic. + Beta float64 + + // WC is window computed by CUBIC at time TimeSinceLastCongestion. It's + // calculated using the formula: + // + // WC(TimeSinceLastCongestion) = C*(t-K)^3 + WMax (Eq. 1) + WC float64 + + // WEst is the window computed by CUBIC at time + // TimeSinceLastCongestion+RTT i.e WC(TimeSinceLastCongestion+RTT). + WEst float64 +} + +// TCPRACKState is used to hold a copy of the internal RACK state when the +// TCPProbeFunc is invoked. +// +// +stateify savable +type TCPRACKState struct { + // XmitTime is the transmission timestamp of the most recent + // acknowledged segment. + XmitTime time.Time `state:".(unixTime)"` + + // EndSequence is the ending TCP sequence number of the most recent + // acknowledged segment. + EndSequence seqnum.Value + + // FACK is the highest selectively or cumulatively acknowledged + // sequence. + FACK seqnum.Value + + // RTT is the round trip time of the most recently delivered packet on + // the connection (either cumulatively acknowledged or selectively + // acknowledged) that was not marked invalid as a possible spurious + // retransmission. + RTT time.Duration + + // Reord is true iff reordering has been detected on this connection. + Reord bool + + // DSACKSeen is true iff the connection has seen a DSACK. + DSACKSeen bool + + // ReoWnd is the reordering window time used for recording packet + // transmission times. It is used to defer the moment at which RACK + // marks a packet lost. + ReoWnd time.Duration + + // ReoWndIncr is the multiplier applied to adjust reorder window. + ReoWndIncr uint8 + + // ReoWndPersist is the number of loss recoveries before resetting + // reorder window. + ReoWndPersist int8 + + // RTTSeq is the SND.NXT when RTT is updated. + RTTSeq seqnum.Value +} + +// TCPEndpointID is the unique 4 tuple that identifies a given endpoint. +// +// +stateify savable +type TCPEndpointID struct { + // LocalPort is the local port associated with the endpoint. + LocalPort uint16 + + // LocalAddress is the local [network layer] address associated with + // the endpoint. + LocalAddress tcpip.Address + + // RemotePort is the remote port associated with the endpoint. + RemotePort uint16 + + // RemoteAddress it the remote [network layer] address associated with + // the endpoint. + RemoteAddress tcpip.Address +} + +// TCPFastRecoveryState holds a copy of the internal fast recovery state of a +// TCP endpoint. +// +// +stateify savable +type TCPFastRecoveryState struct { + // Active if true indicates the endpoint is in fast recovery. The + // following fields are only meaningful when Active is true. + Active bool + + // First is the first unacknowledged sequence number being recovered. + First seqnum.Value + + // Last is the 'recover' sequence number that indicates the point at + // which we should exit recovery barring any timeouts etc. + Last seqnum.Value + + // MaxCwnd is the maximum value we are permitted to grow the congestion + // window during recovery. This is set at the time we enter recovery. + // It exists to avoid attacks where the receiver intentionally sends + // duplicate acks to artificially inflate the sender's cwnd. + MaxCwnd int + + // HighRxt is the highest sequence number which has been retransmitted + // during the current loss recovery phase. See: RFC 6675 Section 2 for + // details. + HighRxt seqnum.Value + + // RescueRxt is the highest sequence number which has been + // optimistically retransmitted to prevent stalling of the ACK clock + // when there is loss at the end of the window and no new data is + // available for transmission. See: RFC 6675 Section 2 for details. + RescueRxt seqnum.Value +} + +// TCPReceiverState holds a copy of the internal state of the receiver for a +// given TCP endpoint. +// +// +stateify savable +type TCPReceiverState struct { + // RcvNxt is the TCP variable RCV.NXT. + RcvNxt seqnum.Value + + // RcvAcc is one beyond the last acceptable sequence number. That is, + // the "largest" sequence value that the receiver has announced to its + // peer that it's willing to accept. This may be different than RcvNxt + // + (last advertised receive window) if the receive window is reduced; + // in that case we have to reduce the window as we receive more data + // instead of shrinking it. + RcvAcc seqnum.Value + + // RcvWndScale is the window scaling to use for inbound segments. + RcvWndScale uint8 + + // PendingBufUsed is the number of bytes pending in the receive queue. + PendingBufUsed int +} + +// TCPRTTState holds a copy of information about the endpoint's round trip +// time. +// +// +stateify savable +type TCPRTTState struct { + // SRTT is the smoothed round trip time defined in section 2 of RFC + // 6298. + SRTT time.Duration + + // RTTVar is the round-trip time variation as defined in section 2 of + // RFC 6298. + RTTVar time.Duration + + // SRTTInited if true indicates that a valid RTT measurement has been + // completed. + SRTTInited bool +} + +// TCPSenderState holds a copy of the internal state of the sender for a given +// TCP Endpoint. +// +// +stateify savable +type TCPSenderState struct { + // LastSendTime is the timestamp at which we sent the last segment. + LastSendTime time.Time `state:".(unixTime)"` + + // DupAckCount is the number of Duplicate ACKs received. It is used for + // fast retransmit. + DupAckCount int + + // SndCwnd is the size of the sending congestion window in packets. + SndCwnd int + + // Ssthresh is the threshold between slow start and congestion + // avoidance. + Ssthresh int + + // SndCAAckCount is the number of packets acknowledged during + // congestion avoidance. When enough packets have been ack'd (typically + // cwnd packets), the congestion window is incremented by one. + SndCAAckCount int + + // Outstanding is the number of packets that have been sent but not yet + // acknowledged. + Outstanding int + + // SackedOut is the number of packets which have been selectively + // acked. + SackedOut int + + // SndWnd is the send window size in bytes. + SndWnd seqnum.Size + + // SndUna is the next unacknowledged sequence number. + SndUna seqnum.Value + + // SndNxt is the sequence number of the next segment to be sent. + SndNxt seqnum.Value + + // RTTMeasureSeqNum is the sequence number being used for the latest + // RTT measurement. + RTTMeasureSeqNum seqnum.Value + + // RTTMeasureTime is the time when the RTTMeasureSeqNum was sent. + RTTMeasureTime time.Time `state:".(unixTime)"` + + // Closed indicates that the caller has closed the endpoint for + // sending. + Closed bool + + // RTO is the retransmit timeout as defined in section of 2 of RFC + // 6298. + RTO time.Duration + + // RTTState holds information about the endpoint's round trip time. + RTTState TCPRTTState + + // MaxPayloadSize is the maximum size of the payload of a given + // segment. It is initialized on demand. + MaxPayloadSize int + + // SndWndScale is the number of bits to shift left when reading the + // send window size from a segment. + SndWndScale uint8 + + // MaxSentAck is the highest acknowledgement number sent till now. + MaxSentAck seqnum.Value + + // FastRecovery holds the fast recovery state for the endpoint. + FastRecovery TCPFastRecoveryState + + // Cubic holds the state related to CUBIC congestion control. + Cubic TCPCubicState + + // RACKState holds the state related to RACK loss detection algorithm. + RACKState TCPRACKState +} + +// TCPSACKInfo holds TCP SACK related information for a given TCP endpoint. +// +// +stateify savable +type TCPSACKInfo struct { + // Blocks is the list of SACK Blocks that identify the out of order + // segments held by a given TCP endpoint. + Blocks []header.SACKBlock + + // ReceivedBlocks are the SACK blocks received by this endpoint from + // the peer endpoint. + ReceivedBlocks []header.SACKBlock + + // MaxSACKED is the highest sequence number that has been SACKED by the + // peer. + MaxSACKED seqnum.Value +} + +// RcvBufAutoTuneParams holds state related to TCP receive buffer auto-tuning. +// +// +stateify savable +type RcvBufAutoTuneParams struct { + // MeasureTime is the time at which the current measurement was + // started. + MeasureTime time.Time `state:".(unixTime)"` + + // CopiedBytes is the number of bytes copied to user space since this + // measure began. + CopiedBytes int + + // PrevCopiedBytes is the number of bytes copied to userspace in the + // previous RTT period. + PrevCopiedBytes int + + // RcvBufSize is the auto tuned receive buffer size. + RcvBufSize int + + // RTT is the smoothed RTT as measured by observing the time between + // when a byte is first acknowledged and the receipt of data that is at + // least one window beyond the sequence number that was acknowledged. + RTT time.Duration + + // RTTVar is the "round-trip time variation" as defined in section 2 of + // RFC6298. + RTTVar time.Duration + + // RTTMeasureSeqNumber is the highest acceptable sequence number at the + // time this RTT measurement period began. + RTTMeasureSeqNumber seqnum.Value + + // RTTMeasureTime is the absolute time at which the current RTT + // measurement period began. + RTTMeasureTime time.Time `state:".(unixTime)"` + + // Disabled is true if an explicit receive buffer is set for the + // endpoint. + Disabled bool +} + +// TCPRcvBufState contains information about the state of an endpoint's receive +// socket buffer. +// +// +stateify savable +type TCPRcvBufState struct { + // RcvBufUsed is the amount of bytes actually held in the receive + // socket buffer for the endpoint. + RcvBufUsed int + + // RcvBufAutoTuneParams is used to hold state variables to compute the + // auto tuned receive buffer size. + RcvAutoParams RcvBufAutoTuneParams + + // RcvClosed if true, indicates the endpoint has been closed for + // reading. + RcvClosed bool +} + +// TCPSndBufState contains information about the state of an endpoint's send +// socket buffer. +// +// +stateify savable +type TCPSndBufState struct { + // SndBufSize is the size of the socket send buffer. + SndBufSize int + + // SndBufUsed is the number of bytes held in the socket send buffer. + SndBufUsed int + + // SndClosed indicates that the endpoint has been closed for sends. + SndClosed bool + + // SndBufInQueue is the number of bytes in the send queue. + SndBufInQueue seqnum.Size + + // PacketTooBigCount is used to notify the main protocol routine how + // many times a "packet too big" control packet is received. + PacketTooBigCount int + + // SndMTU is the smallest MTU seen in the control packets received. + SndMTU int +} + +// TCPEndpointStateInner contains the members of TCPEndpointState used directly +// (that is, not within another containing struct) within the endpoint's +// internal implementation. +// +// +stateify savable +type TCPEndpointStateInner struct { + // TSOffset is a randomized offset added to the value of the TSVal + // field in the timestamp option. + TSOffset uint32 + + // SACKPermitted is set to true if the peer sends the TCPSACKPermitted + // option in the SYN/SYN-ACK. + SACKPermitted bool + + // SendTSOk is used to indicate when the TS Option has been negotiated. + // When sendTSOk is true every non-RST segment should carry a TS as per + // RFC7323#section-1.1. + SendTSOk bool + + // RecentTS is the timestamp that should be sent in the TSEcr field of + // the timestamp for future segments sent by the endpoint. This field + // is updated if required when a new segment is received by this + // endpoint. + RecentTS uint32 +} + +// TCPEndpointState is a copy of the internal state of a TCP endpoint. +// +// +stateify savable +type TCPEndpointState struct { + // TCPEndpointStateInner contains the members of TCPEndpointState used + // by the endpoint's internal implementation. + TCPEndpointStateInner + + // ID is a copy of the TransportEndpointID for the endpoint. + ID TCPEndpointID + + // SegTime denotes the absolute time when this segment was received. + SegTime time.Time `state:".(unixTime)"` + + // RcvBufState contains information about the state of the endpoint's + // receive socket buffer. + RcvBufState TCPRcvBufState + + // SndBufState contains information about the state of the endpoint's + // send socket buffer. + SndBufState TCPSndBufState + + // SACK holds TCP SACK related information for this endpoint. + SACK TCPSACKInfo + + // Receiver holds variables related to the TCP receiver for the + // endpoint. + Receiver TCPReceiverState + + // Sender holds state related to the TCP Sender for the endpoint. + Sender TCPSenderState +} diff --git a/pkg/tcpip/stack/transport_demuxer.go b/pkg/tcpip/stack/transport_demuxer.go index e188efccb..80ad1a9d4 100644 --- a/pkg/tcpip/stack/transport_demuxer.go +++ b/pkg/tcpip/stack/transport_demuxer.go @@ -150,16 +150,17 @@ func (epsByNIC *endpointsByNIC) transportEndpoints() []TransportEndpoint { return eps } -// HandlePacket is called by the stack when new packets arrive to this transport -// endpoint. -func (epsByNIC *endpointsByNIC) handlePacket(id TransportEndpointID, pkt *PacketBuffer) { +// handlePacket is called by the stack when new packets arrive to this transport +// endpoint. It returns false if the packet could not be matched to any +// transport endpoint, true otherwise. +func (epsByNIC *endpointsByNIC) handlePacket(id TransportEndpointID, pkt *PacketBuffer) bool { epsByNIC.mu.RLock() mpep, ok := epsByNIC.endpoints[pkt.NICID] if !ok { if mpep, ok = epsByNIC.endpoints[0]; !ok { epsByNIC.mu.RUnlock() // Don't use defer for performance reasons. - return + return false } } @@ -168,18 +169,19 @@ func (epsByNIC *endpointsByNIC) handlePacket(id TransportEndpointID, pkt *Packet if isInboundMulticastOrBroadcast(pkt, id.LocalAddress) { mpep.handlePacketAll(id, pkt) epsByNIC.mu.RUnlock() // Don't use defer for performance reasons. - return + return true } // multiPortEndpoints are guaranteed to have at least one element. transEP := selectEndpoint(id, mpep, epsByNIC.seed) if queuedProtocol, mustQueue := mpep.demux.queuedProtocols[protocolIDs{mpep.netProto, mpep.transProto}]; mustQueue { queuedProtocol.QueuePacket(transEP, id, pkt) epsByNIC.mu.RUnlock() - return + return true } transEP.HandlePacket(id, pkt) epsByNIC.mu.RUnlock() // Don't use defer for performance reasons. + return true } // handleError delivers an error to the transport endpoint identified by id. @@ -567,8 +569,7 @@ func (d *transportDemuxer) deliverPacket(protocol tcpip.TransportProtocolNumber, } return false } - ep.handlePacket(id, pkt) - return true + return ep.handlePacket(id, pkt) } // deliverRawPacket attempts to deliver the given packet and returns whether it diff --git a/pkg/tcpip/stack/transport_test.go b/pkg/tcpip/stack/transport_test.go index 054cced0c..839178809 100644 --- a/pkg/tcpip/stack/transport_test.go +++ b/pkg/tcpip/stack/transport_test.go @@ -70,7 +70,7 @@ func (f *fakeTransportEndpoint) SocketOptions() *tcpip.SocketOptions { func newFakeTransportEndpoint(proto *fakeTransportProtocol, netProto tcpip.NetworkProtocolNumber, s *stack.Stack) tcpip.Endpoint { ep := &fakeTransportEndpoint{TransportEndpointInfo: stack.TransportEndpointInfo{NetProto: netProto}, proto: proto, uniqueID: s.UniqueID()} - ep.ops.InitHandler(ep, s, tcpip.GetStackSendBufferLimits) + ep.ops.InitHandler(ep, s, tcpip.GetStackSendBufferLimits, tcpip.GetStackReceiveBufferLimits) return ep } @@ -106,7 +106,7 @@ func (f *fakeTransportEndpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions Data: buffer.View(v).ToVectorisedView(), }) _ = pkt.TransportHeader().Push(fakeTransHeaderLen) - if err := f.route.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: fakeTransNumber, TTL: 123, TOS: stack.DefaultTOS}, pkt); err != nil { + if err := f.route.WritePacket(stack.NetworkHeaderParams{Protocol: fakeTransNumber, TTL: 123, TOS: stack.DefaultTOS}, pkt); err != nil { return 0, err } @@ -233,7 +233,7 @@ func (f *fakeTransportEndpoint) HandlePacket(id stack.TransportEndpointID, pkt * peerAddr: route.RemoteAddress(), route: route, } - ep.ops.InitHandler(ep, f.proto.stack, tcpip.GetStackSendBufferLimits) + ep.ops.InitHandler(ep, f.proto.stack, tcpip.GetStackSendBufferLimits, tcpip.GetStackReceiveBufferLimits) f.acceptQueue = append(f.acceptQueue, ep) } |