diff options
Diffstat (limited to 'pkg/tcpip/stack')
-rw-r--r-- | pkg/tcpip/stack/BUILD | 3 | ||||
-rw-r--r-- | pkg/tcpip/stack/addressable_endpoint_state.go | 26 | ||||
-rw-r--r-- | pkg/tcpip/stack/addressable_endpoint_state_test.go | 4 | ||||
-rw-r--r-- | pkg/tcpip/stack/conntrack.go | 67 | ||||
-rw-r--r-- | pkg/tcpip/stack/forwarding_test.go | 30 | ||||
-rw-r--r-- | pkg/tcpip/stack/iptables.go | 4 | ||||
-rw-r--r-- | pkg/tcpip/stack/iptables_targets.go | 105 | ||||
-rw-r--r-- | pkg/tcpip/stack/iptables_types.go | 2 | ||||
-rw-r--r-- | pkg/tcpip/stack/ndp_test.go | 1894 | ||||
-rw-r--r-- | pkg/tcpip/stack/nic.go | 169 | ||||
-rw-r--r-- | pkg/tcpip/stack/nic_test.go | 5 | ||||
-rw-r--r-- | pkg/tcpip/stack/packet_buffer.go | 14 | ||||
-rw-r--r-- | pkg/tcpip/stack/packet_buffer_test.go | 26 | ||||
-rw-r--r-- | pkg/tcpip/stack/registration.go | 52 | ||||
-rw-r--r-- | pkg/tcpip/stack/stack.go | 132 | ||||
-rw-r--r-- | pkg/tcpip/stack/stack_test.go | 690 | ||||
-rw-r--r-- | pkg/tcpip/stack/tcp.go | 12 | ||||
-rw-r--r-- | pkg/tcpip/stack/transport_demuxer.go | 73 | ||||
-rw-r--r-- | pkg/tcpip/stack/transport_demuxer_test.go | 68 | ||||
-rw-r--r-- | pkg/tcpip/stack/transport_test.go | 33 |
20 files changed, 1856 insertions, 1553 deletions
diff --git a/pkg/tcpip/stack/BUILD b/pkg/tcpip/stack/BUILD index 395ff9a07..6c42ab29b 100644 --- a/pkg/tcpip/stack/BUILD +++ b/pkg/tcpip/stack/BUILD @@ -85,6 +85,7 @@ go_library( "//pkg/tcpip/buffer", "//pkg/tcpip/hash/jenkins", "//pkg/tcpip/header", + "//pkg/tcpip/internal/tcp", "//pkg/tcpip/ports", "//pkg/tcpip/seqnum", "//pkg/tcpip/transport/tcpconntrack", @@ -95,7 +96,7 @@ go_library( go_test( name = "stack_x_test", - size = "medium", + size = "small", srcs = [ "addressable_endpoint_state_test.go", "ndp_test.go", diff --git a/pkg/tcpip/stack/addressable_endpoint_state.go b/pkg/tcpip/stack/addressable_endpoint_state.go index ce9cebdaa..7e4b5bf74 100644 --- a/pkg/tcpip/stack/addressable_endpoint_state.go +++ b/pkg/tcpip/stack/addressable_endpoint_state.go @@ -117,10 +117,10 @@ func (a *AddressableEndpointState) releaseAddressStateLocked(addrState *addressS } // AddAndAcquirePermanentAddress implements AddressableEndpoint. -func (a *AddressableEndpointState) AddAndAcquirePermanentAddress(addr tcpip.AddressWithPrefix, peb PrimaryEndpointBehavior, configType AddressConfigType, deprecated bool) (AddressEndpoint, tcpip.Error) { +func (a *AddressableEndpointState) AddAndAcquirePermanentAddress(addr tcpip.AddressWithPrefix, properties AddressProperties) (AddressEndpoint, tcpip.Error) { a.mu.Lock() defer a.mu.Unlock() - ep, err := a.addAndAcquireAddressLocked(addr, peb, configType, deprecated, true /* permanent */) + ep, err := a.addAndAcquireAddressLocked(addr, properties, true /* permanent */) // From https://golang.org/doc/faq#nil_error: // // Under the covers, interfaces are implemented as two elements, a type T and @@ -149,7 +149,7 @@ func (a *AddressableEndpointState) AddAndAcquirePermanentAddress(addr tcpip.Addr func (a *AddressableEndpointState) AddAndAcquireTemporaryAddress(addr tcpip.AddressWithPrefix, peb PrimaryEndpointBehavior) (AddressEndpoint, tcpip.Error) { a.mu.Lock() defer a.mu.Unlock() - ep, err := a.addAndAcquireAddressLocked(addr, peb, AddressConfigStatic, false /* deprecated */, false /* permanent */) + ep, err := a.addAndAcquireAddressLocked(addr, AddressProperties{PEB: peb}, false /* permanent */) // From https://golang.org/doc/faq#nil_error: // // Under the covers, interfaces are implemented as two elements, a type T and @@ -180,7 +180,7 @@ func (a *AddressableEndpointState) AddAndAcquireTemporaryAddress(addr tcpip.Addr // returned, regardless the kind of address that is being added. // // Precondition: a.mu must be write locked. -func (a *AddressableEndpointState) addAndAcquireAddressLocked(addr tcpip.AddressWithPrefix, peb PrimaryEndpointBehavior, configType AddressConfigType, deprecated, permanent bool) (*addressState, tcpip.Error) { +func (a *AddressableEndpointState) addAndAcquireAddressLocked(addr tcpip.AddressWithPrefix, properties AddressProperties, permanent bool) (*addressState, tcpip.Error) { // attemptAddToPrimary is false when the address is already in the primary // address list. attemptAddToPrimary := true @@ -208,7 +208,7 @@ func (a *AddressableEndpointState) addAndAcquireAddressLocked(addr tcpip.Address // We now promote the address. for i, s := range a.mu.primary { if s == addrState { - switch peb { + switch properties.PEB { case CanBePrimaryEndpoint: // The address is already in the primary address list. attemptAddToPrimary = false @@ -222,7 +222,7 @@ func (a *AddressableEndpointState) addAndAcquireAddressLocked(addr tcpip.Address case NeverPrimaryEndpoint: a.mu.primary = append(a.mu.primary[:i], a.mu.primary[i+1:]...) default: - panic(fmt.Sprintf("unrecognized primary endpoint behaviour = %d", peb)) + panic(fmt.Sprintf("unrecognized primary endpoint behaviour = %d", properties.PEB)) } break } @@ -249,7 +249,7 @@ func (a *AddressableEndpointState) addAndAcquireAddressLocked(addr tcpip.Address // or we are adding a new temporary or permanent address. // // The address MUST be write locked at this point. - defer addrState.mu.Unlock() + defer addrState.mu.Unlock() // +checklocksforce if permanent { if addrState.mu.kind.IsPermanent() { @@ -262,11 +262,11 @@ func (a *AddressableEndpointState) addAndAcquireAddressLocked(addr tcpip.Address } // Acquire the address before returning it. addrState.mu.refs++ - addrState.mu.deprecated = deprecated - addrState.mu.configType = configType + addrState.mu.deprecated = properties.Deprecated + addrState.mu.configType = properties.ConfigType if attemptAddToPrimary { - switch peb { + switch properties.PEB { case NeverPrimaryEndpoint: case CanBePrimaryEndpoint: a.mu.primary = append(a.mu.primary, addrState) @@ -285,7 +285,7 @@ func (a *AddressableEndpointState) addAndAcquireAddressLocked(addr tcpip.Address a.mu.primary[0] = addrState } default: - panic(fmt.Sprintf("unrecognized primary endpoint behaviour = %d", peb)) + panic(fmt.Sprintf("unrecognized primary endpoint behaviour = %d", properties.PEB)) } } @@ -489,12 +489,12 @@ func (a *AddressableEndpointState) AcquireAssignedAddressOrMatching(localAddr tc // Proceed to add a new temporary endpoint. addr := localAddr.WithPrefix() - ep, err := a.addAndAcquireAddressLocked(addr, tempPEB, AddressConfigStatic, false /* deprecated */, false /* permanent */) + ep, err := a.addAndAcquireAddressLocked(addr, AddressProperties{PEB: tempPEB}, false /* permanent */) if err != nil { // addAndAcquireAddressLocked only returns an error if the address is // already assigned but we just checked above if the address exists so we // expect no error. - panic(fmt.Sprintf("a.addAndAcquireAddressLocked(%s, %d, %d, false, false): %s", addr, tempPEB, AddressConfigStatic, err)) + panic(fmt.Sprintf("a.addAndAcquireAddressLocked(%s, AddressProperties{PEB: %s}, false): %s", addr, tempPEB, err)) } // From https://golang.org/doc/faq#nil_error: diff --git a/pkg/tcpip/stack/addressable_endpoint_state_test.go b/pkg/tcpip/stack/addressable_endpoint_state_test.go index 140f146f6..c55f85743 100644 --- a/pkg/tcpip/stack/addressable_endpoint_state_test.go +++ b/pkg/tcpip/stack/addressable_endpoint_state_test.go @@ -38,9 +38,9 @@ func TestAddressableEndpointStateCleanup(t *testing.T) { } { - ep, err := s.AddAndAcquirePermanentAddress(addr, stack.NeverPrimaryEndpoint, stack.AddressConfigStatic, false /* deprecated */) + ep, err := s.AddAndAcquirePermanentAddress(addr, stack.AddressProperties{PEB: stack.NeverPrimaryEndpoint}) if err != nil { - t.Fatalf("s.AddAndAcquirePermanentAddress(%s, %d, %d, false): %s", addr, stack.NeverPrimaryEndpoint, stack.AddressConfigStatic, err) + t.Fatalf("s.AddAndAcquirePermanentAddress(%s, AddressProperties{PEB: NeverPrimaryEndpoint}): %s", addr, err) } // We don't need the address endpoint. ep.DecRef() diff --git a/pkg/tcpip/stack/conntrack.go b/pkg/tcpip/stack/conntrack.go index f7fbcbaa7..068dab7ce 100644 --- a/pkg/tcpip/stack/conntrack.go +++ b/pkg/tcpip/stack/conntrack.go @@ -35,7 +35,6 @@ import ( // Currently, only TCP tracking is supported. // Our hash table has 16K buckets. -// TODO(gvisor.dev/issue/170): These should be tunable. const numBuckets = 1 << 14 // Direction of the tuple. @@ -165,8 +164,6 @@ func (cn *conn) updateLocked(tcpHeader header.TCP, hook Hook) { // Update the state of tcb. tcb assumes it's always initialized on the // client. However, we only need to know whether the connection is // established or not, so the client/server distinction isn't important. - // TODO(gvisor.dev/issue/170): Add support in tcpconntrack to handle - // other tcp states. if cn.tcb.IsEmpty() { cn.tcb.Init(tcpHeader) } else if hook == cn.tcbHook { @@ -246,8 +243,7 @@ func (ct *ConnTrack) init() { // connFor gets the conn for pkt if it exists, or returns nil // if it does not. It returns an error when pkt does not contain a valid TCP // header. -// TODO(gvisor.dev/issue/170): Only TCP packets are supported. Need to support -// other transport protocols. +// TODO(gvisor.dev/issue/6168): Support UDP. func (ct *ConnTrack) connFor(pkt *PacketBuffer) (*conn, direction) { tid, err := packetToTupleID(pkt) if err != nil { @@ -367,7 +363,7 @@ func (ct *ConnTrack) insertConn(conn *conn) { // Unlocking can happen in any order. ct.buckets[tupleBucket].mu.Unlock() if tupleBucket != replyBucket { - ct.buckets[replyBucket].mu.Unlock() + ct.buckets[replyBucket].mu.Unlock() // +checklocksforce } } @@ -385,7 +381,7 @@ func (ct *ConnTrack) handlePacket(pkt *PacketBuffer, hook Hook, r *Route) bool { return false } - // TODO(gvisor.dev/issue/170): Support other transport protocols. + // TODO(gvisor.dev/issue/6168): Support UDP. if pkt.Network().TransportProtocol() != header.TCPProtocolNumber { return false } @@ -409,16 +405,23 @@ func (ct *ConnTrack) handlePacket(pkt *PacketBuffer, hook Hook, r *Route) bool { // validated if checksum offloading is off. It may require IP defrag if the // packets are fragmented. + var newAddr tcpip.Address + var newPort uint16 + + updateSRCFields := false + switch hook { case Prerouting, Output: if conn.manip == manipDestination { switch dir { case dirOriginal: - tcpHeader.SetDestinationPort(conn.reply.srcPort) - netHeader.SetDestinationAddress(conn.reply.srcAddr) + newPort = conn.reply.srcPort + newAddr = conn.reply.srcAddr case dirReply: - tcpHeader.SetSourcePort(conn.original.dstPort) - netHeader.SetSourceAddress(conn.original.dstAddr) + newPort = conn.original.dstPort + newAddr = conn.original.dstAddr + + updateSRCFields = true } pkt.NatDone = true } @@ -426,11 +429,13 @@ func (ct *ConnTrack) handlePacket(pkt *PacketBuffer, hook Hook, r *Route) bool { if conn.manip == manipSource { switch dir { case dirOriginal: - tcpHeader.SetSourcePort(conn.reply.dstPort) - netHeader.SetSourceAddress(conn.reply.dstAddr) + newPort = conn.reply.dstPort + newAddr = conn.reply.dstAddr + + updateSRCFields = true case dirReply: - tcpHeader.SetDestinationPort(conn.original.srcPort) - netHeader.SetDestinationAddress(conn.original.srcAddr) + newPort = conn.original.srcPort + newAddr = conn.original.srcAddr } pkt.NatDone = true } @@ -441,33 +446,33 @@ func (ct *ConnTrack) handlePacket(pkt *PacketBuffer, hook Hook, r *Route) bool { return false } + fullChecksum := false + updatePseudoHeader := false switch hook { case Prerouting, Input: case Output, Postrouting: // Calculate the TCP checksum and set it. - tcpHeader.SetChecksum(0) - length := uint16(len(tcpHeader) + pkt.Data().Size()) - xsum := header.PseudoHeaderChecksum(header.TCPProtocolNumber, netHeader.SourceAddress(), netHeader.DestinationAddress(), length) if pkt.GSOOptions.Type != GSONone && pkt.GSOOptions.NeedsCsum { - tcpHeader.SetChecksum(xsum) + updatePseudoHeader = true } else if r.RequiresTXTransportChecksum() { - xsum = header.ChecksumCombine(xsum, pkt.Data().AsRange().Checksum()) - tcpHeader.SetChecksum(^tcpHeader.CalculateChecksum(xsum)) + fullChecksum = true + updatePseudoHeader = true } default: panic(fmt.Sprintf("unrecognized hook = %s", hook)) } - // After modification, IPv4 packets need a valid checksum. - if pkt.NetworkProtocolNumber == header.IPv4ProtocolNumber { - netHeader := header.IPv4(pkt.NetworkHeader().View()) - netHeader.SetChecksum(0) - netHeader.SetChecksum(^netHeader.CalculateChecksum()) - } + rewritePacket( + netHeader, + tcpHeader, + updateSRCFields, + fullChecksum, + updatePseudoHeader, + newPort, + newAddr, + ) // Update the state of tcb. - // TODO(gvisor.dev/issue/170): Add support in tcpcontrack to handle - // other tcp states. conn.mu.Lock() defer conn.mu.Unlock() @@ -544,8 +549,6 @@ func (ct *ConnTrack) bucket(id tupleID) int { // reapUnused returns the next bucket that should be checked and the time after // which it should be called again. func (ct *ConnTrack) reapUnused(start int, prevInterval time.Duration) (int, time.Duration) { - // TODO(gvisor.dev/issue/170): This can be more finely controlled, as - // it is in Linux via sysctl. const fractionPerReaping = 128 const maxExpiredPct = 50 const maxFullTraversal = 60 * time.Second @@ -623,7 +626,7 @@ func (ct *ConnTrack) reapTupleLocked(tuple *tuple, bucket int, now time.Time) bo // Don't re-unlock if both tuples are in the same bucket. if differentBuckets { - ct.buckets[replyBucket].mu.Unlock() + ct.buckets[replyBucket].mu.Unlock() // +checklocksforce } return true diff --git a/pkg/tcpip/stack/forwarding_test.go b/pkg/tcpip/stack/forwarding_test.go index 72f66441f..c2f1f4798 100644 --- a/pkg/tcpip/stack/forwarding_test.go +++ b/pkg/tcpip/stack/forwarding_test.go @@ -181,10 +181,6 @@ func (*fwdTestNetworkProtocol) MinimumPacketSize() int { return fwdTestNetHeaderLen } -func (*fwdTestNetworkProtocol) DefaultPrefixLen() int { - return fwdTestNetDefaultPrefixLen -} - func (*fwdTestNetworkProtocol) ParseAddresses(v buffer.View) (src, dst tcpip.Address) { return tcpip.Address(v[srcAddrOffset : srcAddrOffset+1]), tcpip.Address(v[dstAddrOffset : dstAddrOffset+1]) } @@ -342,6 +338,10 @@ func (e *fwdTestLinkEndpoint) WritePackets(r RouteInfo, pkts PacketBufferList, p return n, nil } +func (*fwdTestLinkEndpoint) WriteRawPacket(*PacketBuffer) tcpip.Error { + return &tcpip.ErrNotSupported{} +} + // Wait implements stack.LinkEndpoint.Wait. func (*fwdTestLinkEndpoint) Wait() {} @@ -380,8 +380,15 @@ func fwdTestNetFactory(t *testing.T, proto *fwdTestNetworkProtocol) (*faketime.M if err := s.CreateNIC(1, ep1); err != nil { t.Fatal("CreateNIC #1 failed:", err) } - if err := s.AddAddress(1, fwdTestNetNumber, "\x01"); err != nil { - t.Fatal("AddAddress #1 failed:", err) + protocolAddr1 := tcpip.ProtocolAddress{ + Protocol: fwdTestNetNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: "\x01", + PrefixLen: fwdTestNetDefaultPrefixLen, + }, + } + if err := s.AddProtocolAddress(1, protocolAddr1, AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", 1, protocolAddr1, err) } // NIC 2 has the link address "b", and added the network address 2. @@ -393,8 +400,15 @@ func fwdTestNetFactory(t *testing.T, proto *fwdTestNetworkProtocol) (*faketime.M if err := s.CreateNIC(2, ep2); err != nil { t.Fatal("CreateNIC #2 failed:", err) } - if err := s.AddAddress(2, fwdTestNetNumber, "\x02"); err != nil { - t.Fatal("AddAddress #2 failed:", err) + protocolAddr2 := tcpip.ProtocolAddress{ + Protocol: fwdTestNetNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: "\x02", + PrefixLen: fwdTestNetDefaultPrefixLen, + }, + } + if err := s.AddProtocolAddress(2, protocolAddr2, AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", 2, protocolAddr2, err) } nic, ok := s.nics[2] diff --git a/pkg/tcpip/stack/iptables.go b/pkg/tcpip/stack/iptables.go index 0a26f6dd8..f152c0d83 100644 --- a/pkg/tcpip/stack/iptables.go +++ b/pkg/tcpip/stack/iptables.go @@ -268,10 +268,6 @@ const ( // should continue traversing the network stack and false when it should be // dropped. // -// TODO(gvisor.dev/issue/170): PacketBuffer should hold the route, from -// which address can be gathered. Currently, address is only needed for -// prerouting. -// // Precondition: pkt.NetworkHeader is set. func (it *IPTables) Check(hook Hook, pkt *PacketBuffer, r *Route, preroutingAddr tcpip.Address, inNicName, outNicName string) bool { if pkt.NetworkProtocolNumber != header.IPv4ProtocolNumber && pkt.NetworkProtocolNumber != header.IPv6ProtocolNumber { diff --git a/pkg/tcpip/stack/iptables_targets.go b/pkg/tcpip/stack/iptables_targets.go index 2812c89aa..96cc899bb 100644 --- a/pkg/tcpip/stack/iptables_targets.go +++ b/pkg/tcpip/stack/iptables_targets.go @@ -87,9 +87,6 @@ func (*ReturnTarget) Action(*PacketBuffer, *ConnTrack, Hook, *Route, tcpip.Addre // destination port/IP. Outgoing packets are redirected to the loopback device, // and incoming packets are redirected to the incoming interface (rather than // forwarded). -// -// TODO(gvisor.dev/issue/170): Other flags need to be added after we support -// them. type RedirectTarget struct { // Port indicates port used to redirect. It is immutable. Port uint16 @@ -100,9 +97,6 @@ type RedirectTarget struct { } // Action implements Target.Action. -// TODO(gvisor.dev/issue/170): Parse headers without copying. The current -// implementation only works for Prerouting and calls pkt.Clone(), neither -// of which should be the case. func (rt *RedirectTarget) Action(pkt *PacketBuffer, ct *ConnTrack, hook Hook, r *Route, address tcpip.Address) (RuleVerdict, int) { // Sanity check. if rt.NetworkProtocol != pkt.NetworkProtocolNumber { @@ -136,34 +130,26 @@ func (rt *RedirectTarget) Action(pkt *PacketBuffer, ct *ConnTrack, hook Hook, r panic("redirect target is supported only on output and prerouting hooks") } - // TODO(gvisor.dev/issue/170): Check Flags in RedirectTarget if - // we need to change dest address (for OUTPUT chain) or ports. switch protocol := pkt.TransportProtocolNumber; protocol { case header.UDPProtocolNumber: udpHeader := header.UDP(pkt.TransportHeader().View()) - udpHeader.SetDestinationPort(rt.Port) - // Calculate UDP checksum and set it. if hook == Output { - udpHeader.SetChecksum(0) - netHeader := pkt.Network() - netHeader.SetDestinationAddress(address) - // Only calculate the checksum if offloading isn't supported. - if r.RequiresTXTransportChecksum() { - length := uint16(pkt.Size()) - uint16(len(pkt.NetworkHeader().View())) - xsum := header.PseudoHeaderChecksum(protocol, netHeader.SourceAddress(), netHeader.DestinationAddress(), length) - xsum = header.ChecksumCombine(xsum, pkt.Data().AsRange().Checksum()) - udpHeader.SetChecksum(^udpHeader.CalculateChecksum(xsum)) - } + requiresChecksum := r.RequiresTXTransportChecksum() + rewritePacket( + pkt.Network(), + udpHeader, + false, /* updateSRCFields */ + requiresChecksum, + requiresChecksum, + rt.Port, + address, + ) + } else { + udpHeader.SetDestinationPort(rt.Port) } - // After modification, IPv4 packets need a valid checksum. - if pkt.NetworkProtocolNumber == header.IPv4ProtocolNumber { - netHeader := header.IPv4(pkt.NetworkHeader().View()) - netHeader.SetChecksum(0) - netHeader.SetChecksum(^netHeader.CalculateChecksum()) - } pkt.NatDone = true case header.TCPProtocolNumber: if ct == nil { @@ -222,26 +208,18 @@ func (st *SNATTarget) Action(pkt *PacketBuffer, ct *ConnTrack, hook Hook, r *Rou switch protocol := pkt.TransportProtocolNumber; protocol { case header.UDPProtocolNumber: - udpHeader := header.UDP(pkt.TransportHeader().View()) - udpHeader.SetChecksum(0) - udpHeader.SetSourcePort(st.Port) - netHeader := pkt.Network() - netHeader.SetSourceAddress(st.Addr) - // Only calculate the checksum if offloading isn't supported. - if r.RequiresTXTransportChecksum() { - length := uint16(pkt.Size()) - uint16(len(pkt.NetworkHeader().View())) - xsum := header.PseudoHeaderChecksum(protocol, netHeader.SourceAddress(), netHeader.DestinationAddress(), length) - xsum = header.ChecksumCombine(xsum, pkt.Data().AsRange().Checksum()) - udpHeader.SetChecksum(^udpHeader.CalculateChecksum(xsum)) - } + requiresChecksum := r.RequiresTXTransportChecksum() + rewritePacket( + pkt.Network(), + header.UDP(pkt.TransportHeader().View()), + true, /* updateSRCFields */ + requiresChecksum, + requiresChecksum, + st.Port, + st.Addr, + ) - // After modification, IPv4 packets need a valid checksum. - if pkt.NetworkProtocolNumber == header.IPv4ProtocolNumber { - netHeader := header.IPv4(pkt.NetworkHeader().View()) - netHeader.SetChecksum(0) - netHeader.SetChecksum(^netHeader.CalculateChecksum()) - } pkt.NatDone = true case header.TCPProtocolNumber: if ct == nil { @@ -260,3 +238,42 @@ func (st *SNATTarget) Action(pkt *PacketBuffer, ct *ConnTrack, hook Hook, r *Rou return RuleAccept, 0 } + +func rewritePacket(n header.Network, t header.ChecksummableTransport, updateSRCFields, fullChecksum, updatePseudoHeader bool, newPort uint16, newAddr tcpip.Address) { + if updateSRCFields { + if fullChecksum { + t.SetSourcePortWithChecksumUpdate(newPort) + } else { + t.SetSourcePort(newPort) + } + } else { + if fullChecksum { + t.SetDestinationPortWithChecksumUpdate(newPort) + } else { + t.SetDestinationPort(newPort) + } + } + + if updatePseudoHeader { + var oldAddr tcpip.Address + if updateSRCFields { + oldAddr = n.SourceAddress() + } else { + oldAddr = n.DestinationAddress() + } + + t.UpdateChecksumPseudoHeaderAddress(oldAddr, newAddr, fullChecksum) + } + + if checksummableNetHeader, ok := n.(header.ChecksummableNetwork); ok { + if updateSRCFields { + checksummableNetHeader.SetSourceAddressWithChecksumUpdate(newAddr) + } else { + checksummableNetHeader.SetDestinationAddressWithChecksumUpdate(newAddr) + } + } else if updateSRCFields { + n.SetSourceAddress(newAddr) + } else { + n.SetDestinationAddress(newAddr) + } +} diff --git a/pkg/tcpip/stack/iptables_types.go b/pkg/tcpip/stack/iptables_types.go index 93592e7f5..66e5f22ac 100644 --- a/pkg/tcpip/stack/iptables_types.go +++ b/pkg/tcpip/stack/iptables_types.go @@ -242,7 +242,6 @@ type IPHeaderFilter struct { func (fl IPHeaderFilter) match(pkt *PacketBuffer, hook Hook, inNicName, outNicName string) bool { // Extract header fields. var ( - // TODO(gvisor.dev/issue/170): Support other filter fields. transProto tcpip.TransportProtocolNumber dstAddr tcpip.Address srcAddr tcpip.Address @@ -291,7 +290,6 @@ func (fl IPHeaderFilter) match(pkt *PacketBuffer, hook Hook, inNicName, outNicNa return true case Postrouting: - // TODO(gvisor.dev/issue/170): Add the check for POSTROUTING. return true default: panic(fmt.Sprintf("unknown hook: %d", hook)) diff --git a/pkg/tcpip/stack/ndp_test.go b/pkg/tcpip/stack/ndp_test.go index 133bacdd0..40b33b6b5 100644 --- a/pkg/tcpip/stack/ndp_test.go +++ b/pkg/tcpip/stack/ndp_test.go @@ -52,17 +52,6 @@ const ( linkAddr4 = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x09") defaultPrefixLen = 128 - - // Extra time to use when waiting for an async event to occur. - defaultAsyncPositiveEventTimeout = 10 * time.Second - - // Extra time to use when waiting for an async event to not occur. - // - // Since a negative check is used to make sure an event did not happen, it is - // okay to use a smaller timeout compared to the positive case since execution - // stall in regards to the monotonic clock will not affect the expected - // outcome. - defaultAsyncNegativeEventTimeout = time.Second ) var ( @@ -112,11 +101,13 @@ type ndpDADEvent struct { res stack.DADResult } -type ndpRouterEvent struct { - nicID tcpip.NICID - addr tcpip.Address - // true if router was discovered, false if invalidated. - discovered bool +type ndpOffLinkRouteEvent struct { + nicID tcpip.NICID + subnet tcpip.Subnet + router tcpip.Address + prf header.NDPRoutePreference + // true if route was updated, false if invalidated. + updated bool } type ndpPrefixEvent struct { @@ -140,6 +131,10 @@ type ndpAutoGenAddrEvent struct { eventType ndpAutoGenAddrEventType } +func (e ndpAutoGenAddrEvent) String() string { + return fmt.Sprintf("%T{nicID=%d addr=%s eventType=%d}", e, e.nicID, e.addr, e.eventType) +} + type ndpRDNSS struct { addrs []tcpip.Address lifetime time.Duration @@ -167,10 +162,8 @@ var _ ipv6.NDPDispatcher = (*ndpDispatcher)(nil) // related events happen for test purposes. type ndpDispatcher struct { dadC chan ndpDADEvent - routerC chan ndpRouterEvent - rememberRouter bool + offLinkRouteC chan ndpOffLinkRouteEvent prefixC chan ndpPrefixEvent - rememberPrefix bool autoGenAddrC chan ndpAutoGenAddrEvent rdnssC chan ndpRDNSSEvent dnsslC chan ndpDNSSLEvent @@ -189,32 +182,35 @@ func (n *ndpDispatcher) OnDuplicateAddressDetectionResult(nicID tcpip.NICID, add } } -// Implements ipv6.NDPDispatcher.OnDefaultRouterDiscovered. -func (n *ndpDispatcher) OnDefaultRouterDiscovered(nicID tcpip.NICID, addr tcpip.Address) bool { - if c := n.routerC; c != nil { - c <- ndpRouterEvent{ +// Implements ipv6.NDPDispatcher.OnOffLinkRouteUpdated. +func (n *ndpDispatcher) OnOffLinkRouteUpdated(nicID tcpip.NICID, subnet tcpip.Subnet, router tcpip.Address, prf header.NDPRoutePreference) { + if c := n.offLinkRouteC; c != nil { + c <- ndpOffLinkRouteEvent{ nicID, - addr, + subnet, + router, + prf, true, } } - - return n.rememberRouter } -// Implements ipv6.NDPDispatcher.OnDefaultRouterInvalidated. -func (n *ndpDispatcher) OnDefaultRouterInvalidated(nicID tcpip.NICID, addr tcpip.Address) { - if c := n.routerC; c != nil { - c <- ndpRouterEvent{ +// Implements ipv6.NDPDispatcher.OnOffLinkRouteInvalidated. +func (n *ndpDispatcher) OnOffLinkRouteInvalidated(nicID tcpip.NICID, subnet tcpip.Subnet, router tcpip.Address) { + if c := n.offLinkRouteC; c != nil { + var prf header.NDPRoutePreference + c <- ndpOffLinkRouteEvent{ nicID, - addr, + subnet, + router, + prf, false, } } } // Implements ipv6.NDPDispatcher.OnOnLinkPrefixDiscovered. -func (n *ndpDispatcher) OnOnLinkPrefixDiscovered(nicID tcpip.NICID, prefix tcpip.Subnet) bool { +func (n *ndpDispatcher) OnOnLinkPrefixDiscovered(nicID tcpip.NICID, prefix tcpip.Subnet) { if c := n.prefixC; c != nil { c <- ndpPrefixEvent{ nicID, @@ -222,8 +218,6 @@ func (n *ndpDispatcher) OnOnLinkPrefixDiscovered(nicID tcpip.NICID, prefix tcpip true, } } - - return n.rememberPrefix } // Implements ipv6.NDPDispatcher.OnOnLinkPrefixInvalidated. @@ -237,7 +231,7 @@ func (n *ndpDispatcher) OnOnLinkPrefixInvalidated(nicID tcpip.NICID, prefix tcpi } } -func (n *ndpDispatcher) OnAutoGenAddress(nicID tcpip.NICID, addr tcpip.AddressWithPrefix) bool { +func (n *ndpDispatcher) OnAutoGenAddress(nicID tcpip.NICID, addr tcpip.AddressWithPrefix) { if c := n.autoGenAddrC; c != nil { c <- ndpAutoGenAddrEvent{ nicID, @@ -245,7 +239,6 @@ func (n *ndpDispatcher) OnAutoGenAddress(nicID tcpip.NICID, addr tcpip.AddressWi newAddr, } } - return true } func (n *ndpDispatcher) OnAutoGenAddressDeprecated(nicID tcpip.NICID, addr tcpip.AddressWithPrefix) { @@ -340,8 +333,12 @@ func TestDADDisabled(t *testing.T) { Address: addr1, PrefixLen: defaultPrefixLen, } - if err := s.AddAddressWithPrefix(nicID, header.IPv6ProtocolNumber, addrWithPrefix); err != nil { - t.Fatalf("AddAddressWithPrefix(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, addrWithPrefix, err) + protocolAddr := tcpip.ProtocolAddress{ + Protocol: header.IPv6ProtocolNumber, + AddressWithPrefix: addrWithPrefix, + } + if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}) = %s", nicID, protocolAddr, err) } // Should get the address immediately since we should not have performed @@ -386,12 +383,15 @@ func TestDADResolveLoopback(t *testing.T) { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) } - addrWithPrefix := tcpip.AddressWithPrefix{ - Address: addr1, - PrefixLen: defaultPrefixLen, + protocolAddr := tcpip.ProtocolAddress{ + Protocol: header.IPv6ProtocolNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: addr1, + PrefixLen: defaultPrefixLen, + }, } - if err := s.AddAddressWithPrefix(nicID, header.IPv6ProtocolNumber, addrWithPrefix); err != nil { - t.Fatalf("AddAddressWithPrefix(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, addrWithPrefix, err) + if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}) = %s", nicID, protocolAddr, err) } // Address should not be considered bound to the NIC yet (DAD ongoing). @@ -524,8 +524,12 @@ func TestDADResolve(t *testing.T) { Address: addr1, PrefixLen: defaultPrefixLen, } - if err := s.AddAddressWithPrefix(nicID, header.IPv6ProtocolNumber, addrWithPrefix); err != nil { - t.Fatalf("AddAddressWithPrefix(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, addrWithPrefix, err) + protocolAddr := tcpip.ProtocolAddress{ + Protocol: header.IPv6ProtocolNumber, + AddressWithPrefix: addrWithPrefix, + } + if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}) = %s", nicID, protocolAddr, err) } // Make sure the address does not resolve before the resolution time has @@ -747,8 +751,12 @@ func TestDADFail(t *testing.T) { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) } - if err := s.AddAddress(nicID, header.IPv6ProtocolNumber, addr1); err != nil { - t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, addr1, err) + protocolAddr := tcpip.ProtocolAddress{ + Protocol: header.IPv6ProtocolNumber, + AddressWithPrefix: addr1.WithPrefix(), + } + if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err) } // Address should not be considered bound to the NIC yet @@ -785,8 +793,8 @@ func TestDADFail(t *testing.T) { // Attempting to add the address again should not fail if the address's // state was cleaned up when DAD failed. - if err := s.AddAddress(nicID, header.IPv6ProtocolNumber, addr1); err != nil { - t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, addr1, err) + if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err) } }) } @@ -858,8 +866,12 @@ func TestDADStop(t *testing.T) { t.Fatalf("CreateNIC(%d, _): %s", nicID, err) } - if err := s.AddAddress(nicID, header.IPv6ProtocolNumber, addr1); err != nil { - t.Fatalf("AddAddress(%d, %d, %s): %s", nicID, header.IPv6ProtocolNumber, addr1, err) + protocolAddr := tcpip.ProtocolAddress{ + Protocol: header.IPv6ProtocolNumber, + AddressWithPrefix: addr1.WithPrefix(), + } + if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err) } // Address should not be considered bound to the NIC yet (DAD ongoing). @@ -982,17 +994,29 @@ func TestSetNDPConfigurations(t *testing.T) { // Add addresses for each NIC. addrWithPrefix1 := tcpip.AddressWithPrefix{Address: addr1, PrefixLen: defaultPrefixLen} - if err := s.AddAddressWithPrefix(nicID1, header.IPv6ProtocolNumber, addrWithPrefix1); err != nil { - t.Fatalf("AddAddressWithPrefix(%d, %d, %s) = %s", nicID1, header.IPv6ProtocolNumber, addrWithPrefix1, err) + protocolAddr1 := tcpip.ProtocolAddress{ + Protocol: header.IPv6ProtocolNumber, + AddressWithPrefix: addrWithPrefix1, + } + if err := s.AddProtocolAddress(nicID1, protocolAddr1, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}) = %s", nicID1, protocolAddr1, err) } addrWithPrefix2 := tcpip.AddressWithPrefix{Address: addr2, PrefixLen: defaultPrefixLen} - if err := s.AddAddressWithPrefix(nicID2, header.IPv6ProtocolNumber, addrWithPrefix2); err != nil { - t.Fatalf("AddAddressWithPrefix(%d, %d, %s) = %s", nicID2, header.IPv6ProtocolNumber, addrWithPrefix2, err) + protocolAddr2 := tcpip.ProtocolAddress{ + Protocol: header.IPv6ProtocolNumber, + AddressWithPrefix: addrWithPrefix2, + } + if err := s.AddProtocolAddress(nicID2, protocolAddr2, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}) = %s", nicID2, protocolAddr2, err) } expectDADEvent(nicID2, addr2) addrWithPrefix3 := tcpip.AddressWithPrefix{Address: addr3, PrefixLen: defaultPrefixLen} - if err := s.AddAddressWithPrefix(nicID3, header.IPv6ProtocolNumber, addrWithPrefix3); err != nil { - t.Fatalf("AddAddressWithPrefix(%d, %d, %s) = %s", nicID3, header.IPv6ProtocolNumber, addrWithPrefix3, err) + protocolAddr3 := tcpip.ProtocolAddress{ + Protocol: header.IPv6ProtocolNumber, + AddressWithPrefix: addrWithPrefix3, + } + if err := s.AddProtocolAddress(nicID3, protocolAddr3, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}) = %s", nicID3, protocolAddr3, err) } expectDADEvent(nicID3, addr3) @@ -1039,9 +1063,12 @@ func TestSetNDPConfigurations(t *testing.T) { } } -// raBufWithOptsAndDHCPv6 returns a valid NDP Router Advertisement with options -// and DHCPv6 configurations specified. -func raBufWithOptsAndDHCPv6(ip tcpip.Address, rl uint16, managedAddress, otherConfigurations bool, optSer header.NDPOptionsSerializer) *stack.PacketBuffer { +// raBuf returns a valid NDP Router Advertisement with options, router +// preference and DHCPv6 configurations specified. +func raBuf(ip tcpip.Address, rl uint16, managedAddress, otherConfigurations bool, prf header.NDPRoutePreference, optSer header.NDPOptionsSerializer) *stack.PacketBuffer { + const flagsByte = 1 + const routerLifetimeOffset = 2 + icmpSize := header.ICMPv6HeaderSize + header.NDPRAMinimumSize + optSer.Length() hdr := buffer.NewPrependable(header.IPv6MinimumSize + icmpSize) pkt := header.ICMPv6(hdr.Prepend(icmpSize)) @@ -1050,19 +1077,19 @@ func raBufWithOptsAndDHCPv6(ip tcpip.Address, rl uint16, managedAddress, otherCo raPayload := pkt.MessageBody() ra := header.NDPRouterAdvert(raPayload) // Populate the Router Lifetime. - binary.BigEndian.PutUint16(raPayload[2:], rl) + binary.BigEndian.PutUint16(raPayload[routerLifetimeOffset:], rl) // Populate the Managed Address flag field. if managedAddress { - // The Managed Addresses flag field is the 7th bit of byte #1 (0-indexing) - // of the RA payload. - raPayload[1] |= 1 << 7 + // The Managed Addresses flag field is the 7th bit of the flags byte. + raPayload[flagsByte] |= 1 << 7 } // Populate the Other Configurations flag field. if otherConfigurations { - // The Other Configurations flag field is the 6th bit of byte #1 - // (0-indexing) of the RA payload. - raPayload[1] |= 1 << 6 + // The Other Configurations flag field is the 6th bit of the flags byte. + raPayload[flagsByte] |= 1 << 6 } + // The Prf field is held in the flags byte. + raPayload[flagsByte] |= byte(prf) << 3 opts := ra.Options() opts.Serialize(optSer) pkt.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{ @@ -1090,7 +1117,7 @@ func raBufWithOptsAndDHCPv6(ip tcpip.Address, rl uint16, managedAddress, otherCo // Note, raBufWithOpts does not populate any of the RA fields other than the // Router Lifetime. func raBufWithOpts(ip tcpip.Address, rl uint16, optSer header.NDPOptionsSerializer) *stack.PacketBuffer { - return raBufWithOptsAndDHCPv6(ip, rl, false, false, optSer) + return raBuf(ip, rl, false /* managedAddress */, false /* otherConfigurations */, 0 /* prf */, optSer) } // raBufWithDHCPv6 returns a valid NDP Router Advertisement with DHCPv6 related @@ -1098,18 +1125,26 @@ func raBufWithOpts(ip tcpip.Address, rl uint16, optSer header.NDPOptionsSerializ // // Note, raBufWithDHCPv6 does not populate any of the RA fields other than the // DHCPv6 related ones. -func raBufWithDHCPv6(ip tcpip.Address, managedAddresses, otherConfiguratiosns bool) *stack.PacketBuffer { - return raBufWithOptsAndDHCPv6(ip, 0, managedAddresses, otherConfiguratiosns, header.NDPOptionsSerializer{}) +func raBufWithDHCPv6(ip tcpip.Address, managedAddresses, otherConfigurations bool) *stack.PacketBuffer { + return raBuf(ip, 0, managedAddresses, otherConfigurations, 0 /* prf */, header.NDPOptionsSerializer{}) } // raBuf returns a valid NDP Router Advertisement. // // Note, raBuf does not populate any of the RA fields other than the // Router Lifetime. -func raBuf(ip tcpip.Address, rl uint16) *stack.PacketBuffer { +func raBufSimple(ip tcpip.Address, rl uint16) *stack.PacketBuffer { return raBufWithOpts(ip, rl, header.NDPOptionsSerializer{}) } +// raBufWithPrf returns a valid NDP Router Advertisement with a preference. +// +// Note, raBufWithPrf does not populate any of the RA fields other than the +// Router Lifetime and Default Router Preference fields. +func raBufWithPrf(ip tcpip.Address, rl uint16, prf header.NDPRoutePreference) *stack.PacketBuffer { + return raBuf(ip, rl, false /* managedAddress */, false /* otherConfigurations */, prf, header.NDPOptionsSerializer{}) +} + // raBufWithPI returns a valid NDP Router Advertisement with a single Prefix // Information option. // @@ -1148,6 +1183,39 @@ func raBufWithPI(ip tcpip.Address, rl uint16, prefix tcpip.AddressWithPrefix, on }) } +// raBufWithRIO returns a valid NDP Router Advertisement with a single Route +// Information option. +// +// All fields in the RA will be zero except the RIO option. +func raBufWithRIO(t *testing.T, ip tcpip.Address, prefix tcpip.AddressWithPrefix, lifetimeSeconds uint32, prf header.NDPRoutePreference) *stack.PacketBuffer { + // buf will hold the route information option after the Type and Length + // fields. + // + // 2.3. Route Information Option + // + // 0 1 2 3 + // 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 + // +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + // | Type | Length | Prefix Length |Resvd|Prf|Resvd| + // +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + // | Route Lifetime | + // +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + // | Prefix (Variable Length) | + // . . + // . . + // +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + var buf [22]byte + buf[0] = uint8(prefix.PrefixLen) + buf[1] = byte(prf) << 3 + binary.BigEndian.PutUint32(buf[2:], lifetimeSeconds) + if n := copy(buf[6:], prefix.Address); n != len(prefix.Address) { + t.Fatalf("got copy(...) = %d, want = %d", n, len(prefix.Address)) + } + return raBufWithOpts(ip, 0 /* router lifetime */, header.NDPOptionsSerializer{ + header.NDPRouteInformation(buf[:]), + }) +} + func TestDynamicConfigurationsDisabled(t *testing.T) { const ( nicID = 1 @@ -1169,7 +1237,7 @@ func TestDynamicConfigurationsDisabled(t *testing.T) { config: func(enable bool) ipv6.NDPConfigurations { return ipv6.NDPConfigurations{DiscoverDefaultRouters: enable} }, - ra: raBuf(llAddr2, 1000), + ra: raBufSimple(llAddr2, 1000), }, { name: "No Prefix Discovery", @@ -1205,9 +1273,9 @@ func TestDynamicConfigurationsDisabled(t *testing.T) { t.Run(fmt.Sprintf("HandleRAs(%s), Forwarding(%t), Enabled(%t)", handle, forwarding, enable), func(t *testing.T) { ndpDisp := ndpDispatcher{ - routerC: make(chan ndpRouterEvent, 1), - prefixC: make(chan ndpPrefixEvent, 1), - autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1), + offLinkRouteC: make(chan ndpOffLinkRouteEvent, 1), + prefixC: make(chan ndpPrefixEvent, 1), + autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1), } ndpConfigs := test.config(enable) ndpConfigs.HandleRAs = handle @@ -1277,8 +1345,8 @@ func TestDynamicConfigurationsDisabled(t *testing.T) { t.Errorf("got v6Stats.UnhandledRouterAdvertisements.Value() = %d, want = %d", got, want) } select { - case e := <-ndpDisp.routerC: - t.Errorf("unexpectedly discovered a router when configured not to: %#v", e) + case e := <-ndpDisp.offLinkRouteC: + t.Errorf("unexpectedly updated an off-link route when configured not to: %#v", e) default: } select { @@ -1304,10 +1372,8 @@ func boolToUint64(v bool) uint64 { return 0 } -// Check e to make sure that the event is for addr on nic with ID 1, and the -// discovered flag set to discovered. -func checkRouterEvent(e ndpRouterEvent, addr tcpip.Address, discovered bool) string { - return cmp.Diff(ndpRouterEvent{nicID: 1, addr: addr, discovered: discovered}, e, cmp.AllowUnexported(e)) +func checkOffLinkRouteEvent(e ndpOffLinkRouteEvent, nicID tcpip.NICID, subnet tcpip.Subnet, router tcpip.Address, prf header.NDPRoutePreference, updated bool) string { + return cmp.Diff(ndpOffLinkRouteEvent{nicID: nicID, subnet: subnet, router: router, prf: prf, updated: updated}, e, cmp.AllowUnexported(e)) } func testWithRAs(t *testing.T, f func(*testing.T, ipv6.HandleRAsConfiguration, bool)) { @@ -1340,167 +1406,176 @@ func testWithRAs(t *testing.T, f func(*testing.T, ipv6.HandleRAsConfiguration, b } } -// TestRouterDiscoveryDispatcherNoRemember tests that the stack does not -// remember a discovered router when the dispatcher asks it not to. -func TestRouterDiscoveryDispatcherNoRemember(t *testing.T) { - ndpDisp := ndpDispatcher{ - routerC: make(chan ndpRouterEvent, 1), - } - e := channel.New(0, 1280, linkAddr1) - clock := faketime.NewManualClock() - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ - NDPConfigs: ipv6.NDPConfigurations{ - HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled, - DiscoverDefaultRouters: true, - }, - NDPDisp: &ndpDisp, - })}, - Clock: clock, - }) +func TestOffLinkRouteDiscovery(t *testing.T) { + const nicID = 1 - if err := s.CreateNIC(1, e); err != nil { - t.Fatalf("CreateNIC(1) = %s", err) - } + moreSpecificPrefix := tcpip.AddressWithPrefix{Address: testutil.MustParse6("a00::"), PrefixLen: 16} + tests := []struct { + name string - // Receive an RA for a router we should not remember. - const lifetimeSeconds = 1 - e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr2, lifetimeSeconds)) - select { - case e := <-ndpDisp.routerC: - if diff := checkRouterEvent(e, llAddr2, true); diff != "" { - t.Errorf("router event mismatch (-want +got):\n%s", diff) - } - default: - t.Fatal("expected router discovery event") - } + discoverDefaultRouters bool + discoverMoreSpecificRoutes bool - // Wait for the invalidation time plus some buffer to make sure we do - // not actually receive any invalidation events as we should not have - // remembered the router in the first place. - clock.Advance(lifetimeSeconds * time.Second) - select { - case <-ndpDisp.routerC: - t.Fatal("should not have received any router events") - default: + dest tcpip.Subnet + ra func(*testing.T, tcpip.Address, uint16, header.NDPRoutePreference) *stack.PacketBuffer + }{ + { + name: "Default router discovery", + discoverDefaultRouters: true, + discoverMoreSpecificRoutes: false, + dest: header.IPv6EmptySubnet, + ra: func(_ *testing.T, router tcpip.Address, lifetimeSeconds uint16, prf header.NDPRoutePreference) *stack.PacketBuffer { + return raBufWithPrf(router, lifetimeSeconds, prf) + }, + }, + { + name: "More-specific route discovery", + discoverDefaultRouters: false, + discoverMoreSpecificRoutes: true, + dest: moreSpecificPrefix.Subnet(), + ra: func(t *testing.T, router tcpip.Address, lifetimeSeconds uint16, prf header.NDPRoutePreference) *stack.PacketBuffer { + return raBufWithRIO(t, router, moreSpecificPrefix, uint32(lifetimeSeconds), prf) + }, + }, } -} -func TestRouterDiscovery(t *testing.T) { - testWithRAs(t, func(t *testing.T, handleRAs ipv6.HandleRAsConfiguration, forwarding bool) { - ndpDisp := ndpDispatcher{ - routerC: make(chan ndpRouterEvent, 1), - rememberRouter: true, - } - e := channel.New(0, 1280, linkAddr1) - clock := faketime.NewManualClock() - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ - NDPConfigs: ipv6.NDPConfigurations{ - HandleRAs: handleRAs, - DiscoverDefaultRouters: true, - }, - NDPDisp: &ndpDisp, - })}, - Clock: clock, - }) + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + testWithRAs(t, func(t *testing.T, handleRAs ipv6.HandleRAsConfiguration, forwarding bool) { + ndpDisp := ndpDispatcher{ + offLinkRouteC: make(chan ndpOffLinkRouteEvent, 1), + } + e := channel.New(0, 1280, linkAddr1) + clock := faketime.NewManualClock() + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ + NDPConfigs: ipv6.NDPConfigurations{ + HandleRAs: handleRAs, + DiscoverDefaultRouters: test.discoverDefaultRouters, + DiscoverMoreSpecificRoutes: test.discoverMoreSpecificRoutes, + }, + NDPDisp: &ndpDisp, + })}, + Clock: clock, + }) - expectRouterEvent := func(addr tcpip.Address, discovered bool) { - t.Helper() + expectOffLinkRouteEvent := func(addr tcpip.Address, prf header.NDPRoutePreference, updated bool) { + t.Helper() - select { - case e := <-ndpDisp.routerC: - if diff := checkRouterEvent(e, addr, discovered); diff != "" { - t.Errorf("router event mismatch (-want +got):\n%s", diff) + select { + case e := <-ndpDisp.offLinkRouteC: + if diff := checkOffLinkRouteEvent(e, nicID, test.dest, addr, prf, updated); diff != "" { + t.Errorf("off-link route event mismatch (-want +got):\n%s", diff) + } + default: + t.Fatal("expected router discovery event") + } } - default: - t.Fatal("expected router discovery event") - } - } - expectAsyncRouterInvalidationEvent := func(addr tcpip.Address, timeout time.Duration) { - t.Helper() + expectAsyncOffLinkRouteInvalidationEvent := func(addr tcpip.Address, timeout time.Duration) { + t.Helper() - clock.Advance(timeout) - select { - case e := <-ndpDisp.routerC: - if diff := checkRouterEvent(e, addr, false); diff != "" { - t.Errorf("router event mismatch (-want +got):\n%s", diff) + clock.Advance(timeout) + select { + case e := <-ndpDisp.offLinkRouteC: + var prf header.NDPRoutePreference + if diff := checkOffLinkRouteEvent(e, nicID, test.dest, addr, prf, false); diff != "" { + t.Errorf("off-link route event mismatch (-want +got):\n%s", diff) + } + default: + t.Fatal("timed out waiting for router discovery event") + } } - default: - t.Fatal("timed out waiting for router discovery event") - } - } - - if err := s.SetForwardingDefaultAndAllNICs(ipv6.ProtocolNumber, forwarding); err != nil { - t.Fatalf("SetForwardingDefaultAndAllNICs(%d, %t): %s", ipv6.ProtocolNumber, forwarding, err) - } - if err := s.CreateNIC(1, e); err != nil { - t.Fatalf("CreateNIC(1) = %s", err) - } + if err := s.SetForwardingDefaultAndAllNICs(ipv6.ProtocolNumber, forwarding); err != nil { + t.Fatalf("SetForwardingDefaultAndAllNICs(%d, %t): %s", ipv6.ProtocolNumber, forwarding, err) + } - // Rx an RA from lladdr2 with zero lifetime. It should not be - // remembered. - e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr2, 0)) - select { - case <-ndpDisp.routerC: - t.Fatal("unexpectedly discovered a router with 0 lifetime") - default: - } + if err := s.CreateNIC(nicID, e); err != nil { + t.Fatalf("CreateNIC(%d, _): %s", nicID, err) + } - // Rx an RA from lladdr2 with a huge lifetime. - e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr2, 1000)) - expectRouterEvent(llAddr2, true) + // Rx an RA from lladdr2 with zero lifetime. It should not be + // remembered. + e.InjectInbound(header.IPv6ProtocolNumber, test.ra(t, llAddr2, 0, header.MediumRoutePreference)) + select { + case <-ndpDisp.offLinkRouteC: + t.Fatal("unexpectedly updated an off-link route with 0 lifetime") + default: + } - // Rx an RA from another router (lladdr3) with non-zero lifetime. - const l3LifetimeSeconds = 6 - e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr3, l3LifetimeSeconds)) - expectRouterEvent(llAddr3, true) + // Discover an off-link route through llAddr2. + e.InjectInbound(header.IPv6ProtocolNumber, test.ra(t, llAddr2, 1000, header.ReservedRoutePreference)) + if test.discoverMoreSpecificRoutes { + // The reserved value is considered invalid with more-specific route + // discovery so we inject the same packet but with the default + // (medium) preference value. + select { + case <-ndpDisp.offLinkRouteC: + t.Fatal("unexpectedly updated an off-link route with a reserved preference value") + default: + } + e.InjectInbound(header.IPv6ProtocolNumber, test.ra(t, llAddr2, 1000, header.MediumRoutePreference)) + } + expectOffLinkRouteEvent(llAddr2, header.MediumRoutePreference, true) + + // Rx an RA from another router (lladdr3) with non-zero lifetime and + // non-default preference value. + const l3LifetimeSeconds = 6 + e.InjectInbound(header.IPv6ProtocolNumber, test.ra(t, llAddr3, l3LifetimeSeconds, header.HighRoutePreference)) + expectOffLinkRouteEvent(llAddr3, header.HighRoutePreference, true) + + // Rx an RA from lladdr2 with lesser lifetime and default (medium) + // preference value. + const l2LifetimeSeconds = 2 + e.InjectInbound(header.IPv6ProtocolNumber, test.ra(t, llAddr2, l2LifetimeSeconds, header.MediumRoutePreference)) + select { + case <-ndpDisp.offLinkRouteC: + t.Fatal("should not receive a off-link route event when updating lifetimes for known routers") + default: + } - // Rx an RA from lladdr2 with lesser lifetime. - const l2LifetimeSeconds = 2 - e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr2, l2LifetimeSeconds)) - select { - case <-ndpDisp.routerC: - t.Fatal("Should not receive a router event when updating lifetimes for known routers") - default: - } + // Rx an RA from lladdr2 with a different preference. + e.InjectInbound(header.IPv6ProtocolNumber, test.ra(t, llAddr2, l2LifetimeSeconds, header.LowRoutePreference)) + expectOffLinkRouteEvent(llAddr2, header.LowRoutePreference, true) - // Wait for lladdr2's router invalidation job to execute. The lifetime - // of the router should have been updated to the most recent (smaller) - // lifetime. - // - // Wait for the normal lifetime plus an extra bit for the - // router to get invalidated. If we don't get an invalidation - // event after this time, then something is wrong. - expectAsyncRouterInvalidationEvent(llAddr2, l2LifetimeSeconds*time.Second) - - // Rx an RA from lladdr2 with huge lifetime. - e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr2, 1000)) - expectRouterEvent(llAddr2, true) - - // Rx an RA from lladdr2 with zero lifetime. It should be invalidated. - e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr2, 0)) - expectRouterEvent(llAddr2, false) - - // Wait for lladdr3's router invalidation job to execute. The lifetime - // of the router should have been updated to the most recent (smaller) - // lifetime. - // - // Wait for the normal lifetime plus an extra bit for the - // router to get invalidated. If we don't get an invalidation - // event after this time, then something is wrong. - expectAsyncRouterInvalidationEvent(llAddr3, l3LifetimeSeconds*time.Second) - }) + // Wait for lladdr2's router invalidation job to execute. The lifetime + // of the router should have been updated to the most recent (smaller) + // lifetime. + // + // Wait for the normal lifetime plus an extra bit for the + // router to get invalidated. If we don't get an invalidation + // event after this time, then something is wrong. + expectAsyncOffLinkRouteInvalidationEvent(llAddr2, l2LifetimeSeconds*time.Second) + + // Rx an RA from lladdr2 with huge lifetime. + e.InjectInbound(header.IPv6ProtocolNumber, test.ra(t, llAddr2, 1000, header.MediumRoutePreference)) + expectOffLinkRouteEvent(llAddr2, header.MediumRoutePreference, true) + + // Rx an RA from lladdr2 with zero lifetime. It should be invalidated. + e.InjectInbound(header.IPv6ProtocolNumber, test.ra(t, llAddr2, 0, header.MediumRoutePreference)) + expectOffLinkRouteEvent(llAddr2, header.MediumRoutePreference, false) + + // Wait for lladdr3's router invalidation job to execute. The lifetime + // of the router should have been updated to the most recent (smaller) + // lifetime. + // + // Wait for the normal lifetime plus an extra bit for the + // router to get invalidated. If we don't get an invalidation + // event after this time, then something is wrong. + expectAsyncOffLinkRouteInvalidationEvent(llAddr3, l3LifetimeSeconds*time.Second) + }) + }) + } } // TestRouterDiscoveryMaxRouters tests that only -// ipv6.MaxDiscoveredDefaultRouters discovered routers are remembered. +// ipv6.MaxDiscoveredOffLinkRoutes discovered routers are remembered. func TestRouterDiscoveryMaxRouters(t *testing.T) { + const nicID = 1 + ndpDisp := ndpDispatcher{ - routerC: make(chan ndpRouterEvent, 1), - rememberRouter: true, + offLinkRouteC: make(chan ndpOffLinkRouteEvent, 1), } e := channel.New(0, 1280, linkAddr1) s := stack.New(stack.Options{ @@ -1513,23 +1588,23 @@ func TestRouterDiscoveryMaxRouters(t *testing.T) { })}, }) - if err := s.CreateNIC(1, e); err != nil { - t.Fatalf("CreateNIC(1) = %s", err) + if err := s.CreateNIC(nicID, e); err != nil { + t.Fatalf("CreateNIC(%d, _): %s", nicID, err) } // Receive an RA from 2 more than the max number of discovered routers. - for i := 1; i <= ipv6.MaxDiscoveredDefaultRouters+2; i++ { + for i := 1; i <= ipv6.MaxDiscoveredOffLinkRoutes+2; i++ { linkAddr := []byte{2, 2, 3, 4, 5, 0} linkAddr[5] = byte(i) llAddr := header.LinkLocalAddr(tcpip.LinkAddress(linkAddr)) - e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr, 5)) + e.InjectInbound(header.IPv6ProtocolNumber, raBufSimple(llAddr, 5)) - if i <= ipv6.MaxDiscoveredDefaultRouters { + if i <= ipv6.MaxDiscoveredOffLinkRoutes { select { - case e := <-ndpDisp.routerC: - if diff := checkRouterEvent(e, llAddr, true); diff != "" { - t.Errorf("router event mismatch (-want +got):\n%s", diff) + case e := <-ndpDisp.offLinkRouteC: + if diff := checkOffLinkRouteEvent(e, nicID, header.IPv6EmptySubnet, llAddr, header.MediumRoutePreference, true); diff != "" { + t.Errorf("off-link route event mismatch (-want +got):\n%s", diff) } default: t.Fatal("expected router discovery event") @@ -1537,7 +1612,7 @@ func TestRouterDiscoveryMaxRouters(t *testing.T) { } else { select { - case <-ndpDisp.routerC: + case <-ndpDisp.offLinkRouteC: t.Fatal("should not have discovered a new router after we already discovered the max number of routers") default: } @@ -1551,54 +1626,6 @@ func checkPrefixEvent(e ndpPrefixEvent, prefix tcpip.Subnet, discovered bool) st return cmp.Diff(ndpPrefixEvent{nicID: 1, prefix: prefix, discovered: discovered}, e, cmp.AllowUnexported(e)) } -// TestPrefixDiscoveryDispatcherNoRemember tests that the stack does not -// remember a discovered on-link prefix when the dispatcher asks it not to. -func TestPrefixDiscoveryDispatcherNoRemember(t *testing.T) { - prefix, subnet, _ := prefixSubnetAddr(0, "") - - ndpDisp := ndpDispatcher{ - prefixC: make(chan ndpPrefixEvent, 1), - } - e := channel.New(0, 1280, linkAddr1) - clock := faketime.NewManualClock() - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ - NDPConfigs: ipv6.NDPConfigurations{ - HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled, - DiscoverOnLinkPrefixes: true, - }, - NDPDisp: &ndpDisp, - })}, - Clock: clock, - }) - - if err := s.CreateNIC(1, e); err != nil { - t.Fatalf("CreateNIC(1) = %s", err) - } - - // Receive an RA with prefix that we should not remember. - const lifetimeSeconds = 1 - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, false, lifetimeSeconds, 0)) - select { - case e := <-ndpDisp.prefixC: - if diff := checkPrefixEvent(e, subnet, true); diff != "" { - t.Errorf("prefix event mismatch (-want +got):\n%s", diff) - } - default: - t.Fatal("expected prefix discovery event") - } - - // Wait for the invalidation time plus some buffer to make sure we do - // not actually receive any invalidation events as we should not have - // remembered the prefix in the first place. - clock.Advance(lifetimeSeconds * time.Second) - select { - case <-ndpDisp.prefixC: - t.Fatal("should not have received any prefix events") - default: - } -} - func TestPrefixDiscovery(t *testing.T) { prefix1, subnet1, _ := prefixSubnetAddr(0, "") prefix2, subnet2, _ := prefixSubnetAddr(1, "") @@ -1606,8 +1633,7 @@ func TestPrefixDiscovery(t *testing.T) { testWithRAs(t, func(t *testing.T, handleRAs ipv6.HandleRAsConfiguration, forwarding bool) { ndpDisp := ndpDispatcher{ - prefixC: make(chan ndpPrefixEvent, 1), - rememberPrefix: true, + prefixC: make(chan ndpPrefixEvent, 1), } e := channel.New(0, 1280, linkAddr1) clock := faketime.NewManualClock() @@ -1697,17 +1723,6 @@ func TestPrefixDiscovery(t *testing.T) { } func TestPrefixDiscoveryWithInfiniteLifetime(t *testing.T) { - // Update the infinite lifetime value to a smaller value so we can test - // that when we receive a PI with such a lifetime value, we do not - // invalidate the prefix. - const testInfiniteLifetimeSeconds = 2 - const testInfiniteLifetime = testInfiniteLifetimeSeconds * time.Second - saved := header.NDPInfiniteLifetime - header.NDPInfiniteLifetime = testInfiniteLifetime - defer func() { - header.NDPInfiniteLifetime = saved - }() - prefix := tcpip.AddressWithPrefix{ Address: testutil.MustParse6("102:304:506:708::"), PrefixLen: 64, @@ -1715,8 +1730,7 @@ func TestPrefixDiscoveryWithInfiniteLifetime(t *testing.T) { subnet := prefix.Subnet() ndpDisp := ndpDispatcher{ - prefixC: make(chan ndpPrefixEvent, 1), - rememberPrefix: true, + prefixC: make(chan ndpPrefixEvent, 1), } e := channel.New(0, 1280, linkAddr1) clock := faketime.NewManualClock() @@ -1750,9 +1764,9 @@ func TestPrefixDiscoveryWithInfiniteLifetime(t *testing.T) { // Receive an RA with prefix in an NDP Prefix Information option (PI) // with infinite valid lifetime which should not get invalidated. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, false, testInfiniteLifetimeSeconds, 0)) + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, false, infiniteLifetimeSeconds, 0)) expectPrefixEvent(subnet, true) - clock.Advance(testInfiniteLifetime) + clock.Advance(header.NDPInfiniteLifetime) select { case <-ndpDisp.prefixC: t.Fatal("unexpectedly invalidated a prefix with infinite lifetime") @@ -1760,9 +1774,8 @@ func TestPrefixDiscoveryWithInfiniteLifetime(t *testing.T) { } // Receive an RA with finite lifetime. - // The prefix should get invalidated after 1s. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, false, testInfiniteLifetimeSeconds-1, 0)) - clock.Advance(testInfiniteLifetime) + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, false, infiniteLifetimeSeconds-1, 0)) + clock.Advance(header.NDPInfiniteLifetime - time.Second) select { case e := <-ndpDisp.prefixC: if diff := checkPrefixEvent(e, subnet, false); diff != "" { @@ -1773,23 +1786,13 @@ func TestPrefixDiscoveryWithInfiniteLifetime(t *testing.T) { } // Receive an RA with finite lifetime. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, false, testInfiniteLifetimeSeconds-1, 0)) + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, false, infiniteLifetimeSeconds-1, 0)) expectPrefixEvent(subnet, true) // Receive an RA with prefix with an infinite lifetime. // The prefix should not be invalidated. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, false, testInfiniteLifetimeSeconds, 0)) - clock.Advance(testInfiniteLifetime) - select { - case <-ndpDisp.prefixC: - t.Fatal("unexpectedly invalidated a prefix with infinite lifetime") - default: - } - - // Receive an RA with a prefix with a lifetime value greater than the - // set infinite lifetime value. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, false, testInfiniteLifetimeSeconds+1, 0)) - clock.Advance((testInfiniteLifetimeSeconds + 1) * time.Second) + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, false, infiniteLifetimeSeconds, 0)) + clock.Advance(header.NDPInfiniteLifetime) select { case <-ndpDisp.prefixC: t.Fatal("unexpectedly invalidated a prefix with infinite lifetime") @@ -1806,8 +1809,7 @@ func TestPrefixDiscoveryWithInfiniteLifetime(t *testing.T) { // ipv6.MaxDiscoveredOnLinkPrefixes discovered on-link prefixes are remembered. func TestPrefixDiscoveryMaxOnLinkPrefixes(t *testing.T) { ndpDisp := ndpDispatcher{ - prefixC: make(chan ndpPrefixEvent, ipv6.MaxDiscoveredOnLinkPrefixes+3), - rememberPrefix: true, + prefixC: make(chan ndpPrefixEvent, ipv6.MaxDiscoveredOnLinkPrefixes+3), } e := channel.New(0, 1280, linkAddr1) s := stack.New(stack.Options{ @@ -1884,17 +1886,12 @@ func checkAutoGenAddrEvent(e ndpAutoGenAddrEvent, addr tcpip.AddressWithPrefix, return cmp.Diff(ndpAutoGenAddrEvent{nicID: 1, addr: addr, eventType: eventType}, e, cmp.AllowUnexported(e)) } +const minVLSeconds = uint32(ipv6.MinPrefixInformationValidLifetimeForUpdate / time.Second) +const infiniteLifetimeSeconds = uint32(header.NDPInfiniteLifetime / time.Second) + // TestAutoGenAddr tests that an address is properly generated and invalidated // when configured to do so. func TestAutoGenAddr(t *testing.T) { - const newMinVL = 2 - newMinVLDuration := newMinVL * time.Second - saved := ipv6.MinPrefixInformationValidLifetimeForUpdate - defer func() { - ipv6.MinPrefixInformationValidLifetimeForUpdate = saved - }() - ipv6.MinPrefixInformationValidLifetimeForUpdate = newMinVLDuration - prefix1, _, addr1 := prefixSubnetAddr(0, linkAddr1) prefix2, _, addr2 := prefixSubnetAddr(1, linkAddr1) @@ -1903,6 +1900,7 @@ func TestAutoGenAddr(t *testing.T) { autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1), } e := channel.New(0, 1280, linkAddr1) + clock := faketime.NewManualClock() s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ NDPConfigs: ipv6.NDPConfigurations{ @@ -1911,6 +1909,7 @@ func TestAutoGenAddr(t *testing.T) { }, NDPDisp: &ndpDisp, })}, + Clock: clock, }) if err := s.SetForwardingDefaultAndAllNICs(ipv6.ProtocolNumber, forwarding); err != nil { @@ -1960,8 +1959,9 @@ func TestAutoGenAddr(t *testing.T) { default: } - // Receive an RA with prefix2 in a PI. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, 100, 0)) + // Receive an RA with prefix2 in a PI with a valid lifetime that exceeds + // the minimum. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, minVLSeconds+1, 0)) expectAutoGenAddrEvent(addr2, newAddr) if !containsV6Addr(s.NICInfo()[1].ProtocolAddresses, addr1) { t.Fatalf("Should have %s in the list of addresses", addr1) @@ -1971,7 +1971,7 @@ func TestAutoGenAddr(t *testing.T) { } // Refresh valid lifetime for addr of prefix1. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, newMinVL, 0)) + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 100, 0)) select { case <-ndpDisp.autoGenAddrC: t.Fatal("unexpectedly auto-generated an address when we already have an address for a prefix") @@ -1979,12 +1979,13 @@ func TestAutoGenAddr(t *testing.T) { } // Wait for addr of prefix1 to be invalidated. + clock.Advance(ipv6.MinPrefixInformationValidLifetimeForUpdate) select { case e := <-ndpDisp.autoGenAddrC: if diff := checkAutoGenAddrEvent(e, addr1, invalidatedAddr); diff != "" { t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) } - case <-time.After(newMinVLDuration + defaultAsyncPositiveEventTimeout): + default: t.Fatal("timed out waiting for addr auto gen event") } if containsV6Addr(s.NICInfo()[1].ProtocolAddresses, addr1) { @@ -2014,20 +2015,7 @@ func addressCheck(addrs []tcpip.ProtocolAddress, containList, notContainList []t // TestAutoGenTempAddr tests that temporary SLAAC addresses are generated when // configured to do so as part of IPv6 Privacy Extensions. func TestAutoGenTempAddr(t *testing.T) { - const ( - nicID = 1 - newMinVL = 5 - newMinVLDuration = newMinVL * time.Second - ) - - savedMinPrefixInformationValidLifetimeForUpdate := ipv6.MinPrefixInformationValidLifetimeForUpdate - savedMaxDesync := ipv6.MaxDesyncFactor - defer func() { - ipv6.MinPrefixInformationValidLifetimeForUpdate = savedMinPrefixInformationValidLifetimeForUpdate - ipv6.MaxDesyncFactor = savedMaxDesync - }() - ipv6.MinPrefixInformationValidLifetimeForUpdate = newMinVLDuration - ipv6.MaxDesyncFactor = time.Nanosecond + const nicID = 1 prefix1, _, addr1 := prefixSubnetAddr(0, linkAddr1) prefix2, _, addr2 := prefixSubnetAddr(1, linkAddr1) @@ -2047,218 +2035,211 @@ func TestAutoGenTempAddr(t *testing.T) { }, } - // This Run will not return until the parallel tests finish. - // - // We need this because we need to do some teardown work after the - // parallel tests complete. - // - // See https://godoc.org/testing#hdr-Subtests_and_Sub_benchmarks for - // more details. - t.Run("group", func(t *testing.T) { - for i, test := range tests { - i := i - test := test - - t.Run(test.name, func(t *testing.T) { - t.Parallel() - - seed := []byte{uint8(i)} - var tempIIDHistory [header.IIDSize]byte - header.InitialTempIID(tempIIDHistory[:], seed, nicID) - newTempAddr := func(stableAddr tcpip.Address) tcpip.AddressWithPrefix { - return header.GenerateTempIPv6SLAACAddr(tempIIDHistory[:], stableAddr) - } - - ndpDisp := ndpDispatcher{ - dadC: make(chan ndpDADEvent, 2), - autoGenAddrC: make(chan ndpAutoGenAddrEvent, 2), - } - e := channel.New(0, 1280, linkAddr1) - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ - DADConfigs: stack.DADConfigurations{ - DupAddrDetectTransmits: test.dupAddrTransmits, - RetransmitTimer: test.retransmitTimer, - }, - NDPConfigs: ipv6.NDPConfigurations{ - HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled, - AutoGenGlobalAddresses: true, - AutoGenTempGlobalAddresses: true, - }, - NDPDisp: &ndpDisp, - TempIIDSeed: seed, - })}, - }) - - if err := s.CreateNIC(nicID, e); err != nil { - t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) - } - - expectAutoGenAddrEvent := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) { - t.Helper() - - select { - case e := <-ndpDisp.autoGenAddrC: - if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" { - t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) - } - default: - t.Fatal("expected addr auto gen event") - } - } - - expectAutoGenAddrEventAsync := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) { - t.Helper() + for i, test := range tests { + t.Run(test.name, func(t *testing.T) { + seed := []byte{uint8(i)} + var tempIIDHistory [header.IIDSize]byte + header.InitialTempIID(tempIIDHistory[:], seed, nicID) + newTempAddr := func(stableAddr tcpip.Address) tcpip.AddressWithPrefix { + return header.GenerateTempIPv6SLAACAddr(tempIIDHistory[:], stableAddr) + } - select { - case e := <-ndpDisp.autoGenAddrC: - if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" { - t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) - } - case <-time.After(defaultAsyncPositiveEventTimeout): - t.Fatal("timed out waiting for addr auto gen event") - } - } + ndpDisp := ndpDispatcher{ + dadC: make(chan ndpDADEvent, 2), + autoGenAddrC: make(chan ndpAutoGenAddrEvent, 2), + } + e := channel.New(0, 1280, linkAddr1) + clock := faketime.NewManualClock() + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ + DADConfigs: stack.DADConfigurations{ + DupAddrDetectTransmits: test.dupAddrTransmits, + RetransmitTimer: test.retransmitTimer, + }, + NDPConfigs: ipv6.NDPConfigurations{ + HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled, + AutoGenGlobalAddresses: true, + AutoGenTempGlobalAddresses: true, + MaxTempAddrValidLifetime: 2 * ipv6.MinPrefixInformationValidLifetimeForUpdate, + MaxTempAddrPreferredLifetime: 2 * ipv6.MinPrefixInformationValidLifetimeForUpdate, + }, + NDPDisp: &ndpDisp, + TempIIDSeed: seed, + })}, + Clock: clock, + }) - expectDADEventAsync := func(addr tcpip.Address) { - t.Helper() + if err := s.CreateNIC(nicID, e); err != nil { + t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) + } - select { - case e := <-ndpDisp.dadC: - if diff := checkDADEvent(e, nicID, addr, &stack.DADSucceeded{}); diff != "" { - t.Errorf("DAD event mismatch (-want +got):\n%s", diff) - } - case <-time.After(time.Duration(test.dupAddrTransmits)*test.retransmitTimer + defaultAsyncPositiveEventTimeout): - t.Fatal("timed out waiting for DAD event") - } - } + expectAutoGenAddrEvent := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) { + t.Helper() - // Receive an RA with prefix1 in an NDP Prefix Information option (PI) - // with zero valid lifetime. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 0, 0)) select { case e := <-ndpDisp.autoGenAddrC: - t.Fatalf("unexpectedly auto-generated an address with 0 lifetime; event = %+v", e) + if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" { + t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) + } default: + t.Fatal("expected addr auto gen event") } + } - // Receive an RA with prefix1 in an NDP Prefix Information option (PI) - // with non-zero valid lifetime. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 100, 0)) - expectAutoGenAddrEvent(addr1, newAddr) - expectDADEventAsync(addr1.Address) + expectAutoGenAddrEventAsync := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) { + t.Helper() + + clock.RunImmediatelyScheduledJobs() select { case e := <-ndpDisp.autoGenAddrC: - t.Fatalf("unexpectedly got an auto gen addr event = %+v", e) + if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" { + t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) + } default: + t.Fatal("timed out waiting for addr auto gen event") } - if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr1}, nil); mismatch != "" { - t.Fatal(mismatch) - } + } - // Receive an RA with prefix1 in an NDP Prefix Information option (PI) - // with non-zero valid & preferred lifetimes. - tempAddr1 := newTempAddr(addr1.Address) - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 100, 100)) - expectAutoGenAddrEvent(tempAddr1, newAddr) - expectDADEventAsync(tempAddr1.Address) - if mismatch := addressCheck(s.NICInfo()[1].ProtocolAddresses, []tcpip.AddressWithPrefix{addr1, tempAddr1}, nil); mismatch != "" { - t.Fatal(mismatch) - } + expectDADEventAsync := func(addr tcpip.Address) { + t.Helper() - // Receive an RA with prefix2 in an NDP Prefix Information option (PI) - // with preferred lifetime > valid lifetime - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 5, 6)) + clock.Advance(time.Duration(test.dupAddrTransmits) * test.retransmitTimer) select { - case e := <-ndpDisp.autoGenAddrC: - t.Fatalf("unexpectedly auto-generated an address with preferred lifetime > valid lifetime; event = %+v", e) + case e := <-ndpDisp.dadC: + if diff := checkDADEvent(e, nicID, addr, &stack.DADSucceeded{}); diff != "" { + t.Errorf("DAD event mismatch (-want +got):\n%s", diff) + } default: + t.Fatal("timed out waiting for DAD event") } - if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr1, tempAddr1}, nil); mismatch != "" { - t.Fatal(mismatch) - } + } - // Receive an RA with prefix2 in a PI w/ non-zero valid and preferred - // lifetimes. - tempAddr2 := newTempAddr(addr2.Address) - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, 100, 100)) - expectAutoGenAddrEvent(addr2, newAddr) - expectDADEventAsync(addr2.Address) - expectAutoGenAddrEventAsync(tempAddr2, newAddr) - expectDADEventAsync(tempAddr2.Address) - if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr1, tempAddr1, addr2, tempAddr2}, nil); mismatch != "" { - t.Fatal(mismatch) - } + // Receive an RA with prefix1 in an NDP Prefix Information option (PI) + // with zero valid lifetime. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 0, 0)) + select { + case e := <-ndpDisp.autoGenAddrC: + t.Fatalf("unexpectedly auto-generated an address with 0 lifetime; event = %+v", e) + default: + } - // Deprecate prefix1. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 100, 0)) - expectAutoGenAddrEvent(addr1, deprecatedAddr) - expectAutoGenAddrEvent(tempAddr1, deprecatedAddr) - if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr1, tempAddr1, addr2, tempAddr2}, nil); mismatch != "" { - t.Fatal(mismatch) - } + // Receive an RA with prefix1 in an NDP Prefix Information option (PI) + // with non-zero valid lifetime. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 100, 0)) + expectAutoGenAddrEvent(addr1, newAddr) + expectDADEventAsync(addr1.Address) + select { + case e := <-ndpDisp.autoGenAddrC: + t.Fatalf("unexpectedly got an auto gen addr event = %+v", e) + default: + } + if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr1}, nil); mismatch != "" { + t.Fatal(mismatch) + } - // Refresh lifetimes for prefix1. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 100, 100)) - if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr1, tempAddr1, addr2, tempAddr2}, nil); mismatch != "" { - t.Fatal(mismatch) - } + // Receive an RA with prefix1 in an NDP Prefix Information option (PI) + // with non-zero valid & preferred lifetimes. + tempAddr1 := newTempAddr(addr1.Address) + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 100, 100)) + expectAutoGenAddrEvent(tempAddr1, newAddr) + expectDADEventAsync(tempAddr1.Address) + if mismatch := addressCheck(s.NICInfo()[1].ProtocolAddresses, []tcpip.AddressWithPrefix{addr1, tempAddr1}, nil); mismatch != "" { + t.Fatal(mismatch) + } - // Reduce valid lifetime and deprecate addresses of prefix1. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, newMinVL, 0)) - expectAutoGenAddrEvent(addr1, deprecatedAddr) - expectAutoGenAddrEvent(tempAddr1, deprecatedAddr) - if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr1, tempAddr1, addr2, tempAddr2}, nil); mismatch != "" { - t.Fatal(mismatch) - } + // Receive an RA with prefix2 in an NDP Prefix Information option (PI) + // with preferred lifetime > valid lifetime + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 5, 6)) + select { + case e := <-ndpDisp.autoGenAddrC: + t.Fatalf("unexpectedly auto-generated an address with preferred lifetime > valid lifetime; event = %+v", e) + default: + } + if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr1, tempAddr1}, nil); mismatch != "" { + t.Fatal(mismatch) + } - // Wait for addrs of prefix1 to be invalidated. They should be - // invalidated at the same time. - select { - case e := <-ndpDisp.autoGenAddrC: - var nextAddr tcpip.AddressWithPrefix - if e.addr == addr1 { - if diff := checkAutoGenAddrEvent(e, addr1, invalidatedAddr); diff != "" { - t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) - } - nextAddr = tempAddr1 - } else { - if diff := checkAutoGenAddrEvent(e, tempAddr1, invalidatedAddr); diff != "" { - t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) - } - nextAddr = addr1 - } + // Receive an RA with prefix2 in a PI with a valid lifetime that exceeds + // the minimum and won't be reached in this test. + tempAddr2 := newTempAddr(addr2.Address) + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, 2*minVLSeconds, 2*minVLSeconds)) + expectAutoGenAddrEvent(addr2, newAddr) + expectDADEventAsync(addr2.Address) + expectAutoGenAddrEventAsync(tempAddr2, newAddr) + expectDADEventAsync(tempAddr2.Address) + if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr1, tempAddr1, addr2, tempAddr2}, nil); mismatch != "" { + t.Fatal(mismatch) + } - select { - case e := <-ndpDisp.autoGenAddrC: - if diff := checkAutoGenAddrEvent(e, nextAddr, invalidatedAddr); diff != "" { - t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) - } - case <-time.After(defaultAsyncPositiveEventTimeout): - t.Fatal("timed out waiting for addr auto gen event") + // Deprecate prefix1. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 100, 0)) + expectAutoGenAddrEvent(addr1, deprecatedAddr) + expectAutoGenAddrEvent(tempAddr1, deprecatedAddr) + if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr1, tempAddr1, addr2, tempAddr2}, nil); mismatch != "" { + t.Fatal(mismatch) + } + + // Refresh lifetimes for prefix1. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 100, 100)) + if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr1, tempAddr1, addr2, tempAddr2}, nil); mismatch != "" { + t.Fatal(mismatch) + } + + // Reduce valid lifetime and deprecate addresses of prefix1. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, minVLSeconds, 0)) + expectAutoGenAddrEvent(addr1, deprecatedAddr) + expectAutoGenAddrEvent(tempAddr1, deprecatedAddr) + if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr1, tempAddr1, addr2, tempAddr2}, nil); mismatch != "" { + t.Fatal(mismatch) + } + + // Wait for addrs of prefix1 to be invalidated. They should be + // invalidated at the same time. + clock.Advance(ipv6.MinPrefixInformationValidLifetimeForUpdate) + select { + case e := <-ndpDisp.autoGenAddrC: + var nextAddr tcpip.AddressWithPrefix + if e.addr == addr1 { + if diff := checkAutoGenAddrEvent(e, addr1, invalidatedAddr); diff != "" { + t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) } - case <-time.After(newMinVLDuration + defaultAsyncPositiveEventTimeout): - t.Fatal("timed out waiting for addr auto gen event") - } - if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr2, tempAddr2}, []tcpip.AddressWithPrefix{addr1, tempAddr1}); mismatch != "" { - t.Fatal(mismatch) + nextAddr = tempAddr1 + } else { + if diff := checkAutoGenAddrEvent(e, tempAddr1, invalidatedAddr); diff != "" { + t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) + } + nextAddr = addr1 } - // Receive an RA with prefix2 in a PI w/ 0 lifetimes. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, 0, 0)) - expectAutoGenAddrEvent(addr2, deprecatedAddr) - expectAutoGenAddrEvent(tempAddr2, deprecatedAddr) select { case e := <-ndpDisp.autoGenAddrC: - t.Errorf("got unexpected auto gen addr event = %+v", e) + if diff := checkAutoGenAddrEvent(e, nextAddr, invalidatedAddr); diff != "" { + t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) + } default: + t.Fatal("timed out waiting for addr auto gen event") } - if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr2, tempAddr2}, []tcpip.AddressWithPrefix{addr1, tempAddr1}); mismatch != "" { - t.Fatal(mismatch) - } - }) - } - }) + default: + t.Fatal("timed out waiting for addr auto gen event") + } + if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr2, tempAddr2}, []tcpip.AddressWithPrefix{addr1, tempAddr1}); mismatch != "" { + t.Fatal(mismatch) + } + + // Receive an RA with prefix2 in a PI w/ 0 lifetimes. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, 0, 0)) + expectAutoGenAddrEvent(addr2, deprecatedAddr) + expectAutoGenAddrEvent(tempAddr2, deprecatedAddr) + select { + case e := <-ndpDisp.autoGenAddrC: + t.Errorf("got unexpected auto gen addr event = %+v", e) + default: + } + if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr2, tempAddr2}, []tcpip.AddressWithPrefix{addr1, tempAddr1}); mismatch != "" { + t.Fatal(mismatch) + } + }) + } } // TestNoAutoGenTempAddrForLinkLocal test that temporary SLAAC addresses are not @@ -2266,12 +2247,6 @@ func TestAutoGenTempAddr(t *testing.T) { func TestNoAutoGenTempAddrForLinkLocal(t *testing.T) { const nicID = 1 - savedMaxDesyncFactor := ipv6.MaxDesyncFactor - defer func() { - ipv6.MaxDesyncFactor = savedMaxDesyncFactor - }() - ipv6.MaxDesyncFactor = time.Nanosecond - tests := []struct { name string dupAddrTransmits uint8 @@ -2287,66 +2262,56 @@ func TestNoAutoGenTempAddrForLinkLocal(t *testing.T) { }, } - // This Run will not return until the parallel tests finish. - // - // We need this because we need to do some teardown work after the - // parallel tests complete. - // - // See https://godoc.org/testing#hdr-Subtests_and_Sub_benchmarks for - // more details. - t.Run("group", func(t *testing.T) { - for _, test := range tests { - test := test - - t.Run(test.name, func(t *testing.T) { - t.Parallel() - - ndpDisp := ndpDispatcher{ - dadC: make(chan ndpDADEvent, 1), - autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1), - } - e := channel.New(0, 1280, linkAddr1) - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ - NDPConfigs: ipv6.NDPConfigurations{ - AutoGenTempGlobalAddresses: true, - }, - NDPDisp: &ndpDisp, - AutoGenLinkLocal: true, - })}, - }) + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + ndpDisp := ndpDispatcher{ + dadC: make(chan ndpDADEvent, 1), + autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1), + } + e := channel.New(0, 1280, linkAddr1) + clock := faketime.NewManualClock() + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ + NDPConfigs: ipv6.NDPConfigurations{ + AutoGenTempGlobalAddresses: true, + }, + NDPDisp: &ndpDisp, + AutoGenLinkLocal: true, + })}, + Clock: clock, + }) - if err := s.CreateNIC(nicID, e); err != nil { - t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) - } + if err := s.CreateNIC(nicID, e); err != nil { + t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) + } - // The stable link-local address should auto-generate and resolve DAD. - select { - case e := <-ndpDisp.autoGenAddrC: - if diff := checkAutoGenAddrEvent(e, tcpip.AddressWithPrefix{Address: llAddr1, PrefixLen: header.IIDOffsetInIPv6Address * 8}, newAddr); diff != "" { - t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) - } - default: - t.Fatal("expected addr auto gen event") + // The stable link-local address should auto-generate and resolve DAD. + select { + case e := <-ndpDisp.autoGenAddrC: + if diff := checkAutoGenAddrEvent(e, tcpip.AddressWithPrefix{Address: llAddr1, PrefixLen: header.IIDOffsetInIPv6Address * 8}, newAddr); diff != "" { + t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) } - select { - case e := <-ndpDisp.dadC: - if diff := checkDADEvent(e, nicID, llAddr1, &stack.DADSucceeded{}); diff != "" { - t.Errorf("DAD event mismatch (-want +got):\n%s", diff) - } - case <-time.After(time.Duration(test.dupAddrTransmits)*test.retransmitTimer + defaultAsyncPositiveEventTimeout): - t.Fatal("timed out waiting for DAD event") + default: + t.Fatal("expected addr auto gen event") + } + clock.Advance(time.Duration(test.dupAddrTransmits) * test.retransmitTimer) + select { + case e := <-ndpDisp.dadC: + if diff := checkDADEvent(e, nicID, llAddr1, &stack.DADSucceeded{}); diff != "" { + t.Errorf("DAD event mismatch (-want +got):\n%s", diff) } + default: + t.Fatal("timed out waiting for DAD event") + } - // No new addresses should be generated. - select { - case e := <-ndpDisp.autoGenAddrC: - t.Errorf("got unxpected auto gen addr event = %+v", e) - case <-time.After(defaultAsyncNegativeEventTimeout): - } - }) - } - }) + // No new addresses should be generated. + select { + case e := <-ndpDisp.autoGenAddrC: + t.Errorf("got unxpected auto gen addr event = %+v", e) + default: + } + }) + } } // TestNoAutoGenTempAddrWithoutStableAddr tests that a temporary SLAAC address @@ -2359,12 +2324,6 @@ func TestNoAutoGenTempAddrWithoutStableAddr(t *testing.T) { retransmitTimer = 2 * time.Second ) - savedMaxDesyncFactor := ipv6.MaxDesyncFactor - defer func() { - ipv6.MaxDesyncFactor = savedMaxDesyncFactor - }() - ipv6.MaxDesyncFactor = 0 - prefix, _, addr := prefixSubnetAddr(0, linkAddr1) var tempIIDHistory [header.IIDSize]byte header.InitialTempIID(tempIIDHistory[:], nil, nicID) @@ -2375,6 +2334,7 @@ func TestNoAutoGenTempAddrWithoutStableAddr(t *testing.T) { autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1), } e := channel.New(0, 1280, linkAddr1) + clock := faketime.NewManualClock() s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ DADConfigs: stack.DADConfigurations{ @@ -2388,6 +2348,7 @@ func TestNoAutoGenTempAddrWithoutStableAddr(t *testing.T) { }, NDPDisp: &ndpDisp, })}, + Clock: clock, }) if err := s.CreateNIC(nicID, e); err != nil { @@ -2417,12 +2378,13 @@ func TestNoAutoGenTempAddrWithoutStableAddr(t *testing.T) { // Wait for DAD to complete for the stable address then expect the temporary // address to be generated. + clock.Advance(dadTransmits * retransmitTimer) select { case e := <-ndpDisp.dadC: if diff := checkDADEvent(e, nicID, addr.Address, &stack.DADSucceeded{}); diff != "" { t.Errorf("DAD event mismatch (-want +got):\n%s", diff) } - case <-time.After(dadTransmits*retransmitTimer + defaultAsyncPositiveEventTimeout): + default: t.Fatal("timed out waiting for DAD event") } select { @@ -2430,7 +2392,7 @@ func TestNoAutoGenTempAddrWithoutStableAddr(t *testing.T) { if diff := checkAutoGenAddrEvent(e, tempAddr, newAddr); diff != "" { t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) } - case <-time.After(defaultAsyncPositiveEventTimeout): + default: t.Fatal("timed out waiting for addr auto gen event") } } @@ -2439,46 +2401,44 @@ func TestNoAutoGenTempAddrWithoutStableAddr(t *testing.T) { // regenerated. func TestAutoGenTempAddrRegen(t *testing.T) { const ( - nicID = 1 - regenAfter = 2 * time.Second - newMinVL = 10 - newMinVLDuration = newMinVL * time.Second - ) + nicID = 1 + regenAdv = 2 * time.Second - savedMaxDesyncFactor := ipv6.MaxDesyncFactor - savedMinMaxTempAddrPreferredLifetime := ipv6.MinMaxTempAddrPreferredLifetime - savedMinMaxTempAddrValidLifetime := ipv6.MinMaxTempAddrValidLifetime - defer func() { - ipv6.MaxDesyncFactor = savedMaxDesyncFactor - ipv6.MinMaxTempAddrPreferredLifetime = savedMinMaxTempAddrPreferredLifetime - ipv6.MinMaxTempAddrValidLifetime = savedMinMaxTempAddrValidLifetime - }() - ipv6.MaxDesyncFactor = 0 - ipv6.MinMaxTempAddrPreferredLifetime = newMinVLDuration - ipv6.MinMaxTempAddrValidLifetime = newMinVLDuration + numTempAddrs = 3 + maxTempAddrValidLifetime = numTempAddrs * ipv6.MinPrefixInformationValidLifetimeForUpdate + ) prefix, _, addr := prefixSubnetAddr(0, linkAddr1) var tempIIDHistory [header.IIDSize]byte header.InitialTempIID(tempIIDHistory[:], nil, nicID) - tempAddr1 := header.GenerateTempIPv6SLAACAddr(tempIIDHistory[:], addr.Address) - tempAddr2 := header.GenerateTempIPv6SLAACAddr(tempIIDHistory[:], addr.Address) - tempAddr3 := header.GenerateTempIPv6SLAACAddr(tempIIDHistory[:], addr.Address) + var tempAddrs [numTempAddrs]tcpip.AddressWithPrefix + for i := 0; i < len(tempAddrs); i++ { + tempAddrs[i] = header.GenerateTempIPv6SLAACAddr(tempIIDHistory[:], addr.Address) + } ndpDisp := ndpDispatcher{ autoGenAddrC: make(chan ndpAutoGenAddrEvent, 2), } e := channel.New(0, 1280, linkAddr1) ndpConfigs := ipv6.NDPConfigurations{ - HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled, - AutoGenGlobalAddresses: true, - AutoGenTempGlobalAddresses: true, - RegenAdvanceDuration: newMinVLDuration - regenAfter, + HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled, + AutoGenGlobalAddresses: true, + AutoGenTempGlobalAddresses: true, + RegenAdvanceDuration: regenAdv, + MaxTempAddrValidLifetime: maxTempAddrValidLifetime, + MaxTempAddrPreferredLifetime: ipv6.MinPrefixInformationValidLifetimeForUpdate, + } + clock := faketime.NewManualClock() + randSource := savingRandSource{ + s: rand.NewSource(time.Now().UnixNano()), } s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ NDPConfigs: ndpConfigs, NDPDisp: &ndpDisp, })}, + Clock: clock, + RandSource: &randSource, }) if err := s.CreateNIC(nicID, e); err != nil { @@ -2501,36 +2461,43 @@ func TestAutoGenTempAddrRegen(t *testing.T) { expectAutoGenAddrEventAsync := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType, timeout time.Duration) { t.Helper() + clock.Advance(timeout) select { case e := <-ndpDisp.autoGenAddrC: if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" { t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) } - case <-time.After(timeout): + default: t.Fatal("timed out waiting for addr auto gen event") } } + tempDesyncFactor := time.Duration(randSource.lastInt63) % ipv6.MaxDesyncFactor + effectiveMaxTempAddrPL := ipv6.MinPrefixInformationValidLifetimeForUpdate - tempDesyncFactor + // The time since the last regeneration before a new temporary address is + // generated. + tempAddrRegenenerationTime := effectiveMaxTempAddrPL - regenAdv + // Receive an RA with prefix1 in an NDP Prefix Information option (PI) // with non-zero valid & preferred lifetimes. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, 100, 100)) + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, minVLSeconds, minVLSeconds)) expectAutoGenAddrEvent(addr, newAddr) - expectAutoGenAddrEvent(tempAddr1, newAddr) - if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr, tempAddr1}, nil); mismatch != "" { + expectAutoGenAddrEvent(tempAddrs[0], newAddr) + if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr, tempAddrs[0]}, nil); mismatch != "" { t.Fatal(mismatch) } // Wait for regeneration - expectAutoGenAddrEventAsync(tempAddr2, newAddr, regenAfter+defaultAsyncPositiveEventTimeout) - if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr, tempAddr1, tempAddr2}, nil); mismatch != "" { + expectAutoGenAddrEventAsync(tempAddrs[1], newAddr, tempAddrRegenenerationTime) + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, minVLSeconds, minVLSeconds)) + if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr, tempAddrs[0], tempAddrs[1]}, nil); mismatch != "" { t.Fatal(mismatch) } + expectAutoGenAddrEventAsync(tempAddrs[0], deprecatedAddr, regenAdv) // Wait for regeneration - expectAutoGenAddrEventAsync(tempAddr3, newAddr, regenAfter+defaultAsyncPositiveEventTimeout) - if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr, tempAddr1, tempAddr2, tempAddr3}, nil); mismatch != "" { - t.Fatal(mismatch) - } + expectAutoGenAddrEventAsync(tempAddrs[2], newAddr, tempAddrRegenenerationTime-regenAdv) + expectAutoGenAddrEventAsync(tempAddrs[1], deprecatedAddr, regenAdv) // Stop generating temporary addresses ndpConfigs.AutoGenTempGlobalAddresses = false @@ -2541,45 +2508,24 @@ func TestAutoGenTempAddrRegen(t *testing.T) { ndpEP.SetNDPConfigurations(ndpConfigs) } + // Refresh lifetimes and wait for the last temporary address to be deprecated. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, minVLSeconds, minVLSeconds)) + expectAutoGenAddrEventAsync(tempAddrs[2], deprecatedAddr, effectiveMaxTempAddrPL-regenAdv) + + // Refresh lifetimes such that the prefix is valid and preferred forever. + // + // This should not affect the lifetimes of temporary addresses because they + // are capped by the maximum valid and preferred lifetimes for temporary + // addresses. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, infiniteLifetimeSeconds, infiniteLifetimeSeconds)) + // Wait for all the temporary addresses to get invalidated. - tempAddrs := []tcpip.AddressWithPrefix{tempAddr1, tempAddr2, tempAddr3} - invalidateAfter := newMinVLDuration - 2*regenAfter + invalidateAfter := maxTempAddrValidLifetime - clock.NowMonotonic().Sub(tcpip.MonotonicTime{}) for _, addr := range tempAddrs { - // Wait for a deprecation then invalidation event, or just an invalidation - // event. We need to cover both cases but cannot deterministically hit both - // cases because the deprecation and invalidation jobs could execute in any - // order. - select { - case e := <-ndpDisp.autoGenAddrC: - if diff := checkAutoGenAddrEvent(e, addr, deprecatedAddr); diff == "" { - // If we get a deprecation event first, we should get an invalidation - // event almost immediately after. - select { - case e := <-ndpDisp.autoGenAddrC: - if diff := checkAutoGenAddrEvent(e, addr, invalidatedAddr); diff != "" { - t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) - } - case <-time.After(defaultAsyncPositiveEventTimeout): - t.Fatal("timed out waiting for addr auto gen event") - } - } else if diff := checkAutoGenAddrEvent(e, addr, invalidatedAddr); diff == "" { - // If we get an invalidation event first, we shouldn't get a deprecation - // event after. - select { - case e := <-ndpDisp.autoGenAddrC: - t.Fatalf("unexpectedly got an auto-generated event = %+v", e) - case <-time.After(defaultAsyncNegativeEventTimeout): - } - } else { - t.Fatalf("got unexpected auto-generated event = %+v", e) - } - case <-time.After(invalidateAfter + defaultAsyncPositiveEventTimeout): - t.Fatal("timed out waiting for addr auto gen event") - } - - invalidateAfter = regenAfter + expectAutoGenAddrEventAsync(addr, invalidatedAddr, invalidateAfter) + invalidateAfter = tempAddrRegenenerationTime } - if mismatch := addressCheck(s.NICInfo()[1].ProtocolAddresses, []tcpip.AddressWithPrefix{addr}, tempAddrs); mismatch != "" { + if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr}, tempAddrs[:]); mismatch != "" { t.Fatal(mismatch) } } @@ -2588,52 +2534,54 @@ func TestAutoGenTempAddrRegen(t *testing.T) { // regeneration job gets updated when refreshing the address's lifetimes. func TestAutoGenTempAddrRegenJobUpdates(t *testing.T) { const ( - nicID = 1 - regenAfter = 2 * time.Second - newMinVL = 10 - newMinVLDuration = newMinVL * time.Second - ) + nicID = 1 + regenAdv = 2 * time.Second - savedMaxDesyncFactor := ipv6.MaxDesyncFactor - savedMinMaxTempAddrPreferredLifetime := ipv6.MinMaxTempAddrPreferredLifetime - savedMinMaxTempAddrValidLifetime := ipv6.MinMaxTempAddrValidLifetime - defer func() { - ipv6.MaxDesyncFactor = savedMaxDesyncFactor - ipv6.MinMaxTempAddrPreferredLifetime = savedMinMaxTempAddrPreferredLifetime - ipv6.MinMaxTempAddrValidLifetime = savedMinMaxTempAddrValidLifetime - }() - ipv6.MaxDesyncFactor = 0 - ipv6.MinMaxTempAddrPreferredLifetime = newMinVLDuration - ipv6.MinMaxTempAddrValidLifetime = newMinVLDuration + numTempAddrs = 3 + maxTempAddrPreferredLifetime = ipv6.MinPrefixInformationValidLifetimeForUpdate + maxTempAddrPreferredLifetimeSeconds = uint32(maxTempAddrPreferredLifetime / time.Second) + ) prefix, _, addr := prefixSubnetAddr(0, linkAddr1) var tempIIDHistory [header.IIDSize]byte header.InitialTempIID(tempIIDHistory[:], nil, nicID) - tempAddr1 := header.GenerateTempIPv6SLAACAddr(tempIIDHistory[:], addr.Address) - tempAddr2 := header.GenerateTempIPv6SLAACAddr(tempIIDHistory[:], addr.Address) - tempAddr3 := header.GenerateTempIPv6SLAACAddr(tempIIDHistory[:], addr.Address) + var tempAddrs [numTempAddrs]tcpip.AddressWithPrefix + for i := 0; i < len(tempAddrs); i++ { + tempAddrs[i] = header.GenerateTempIPv6SLAACAddr(tempIIDHistory[:], addr.Address) + } ndpDisp := ndpDispatcher{ autoGenAddrC: make(chan ndpAutoGenAddrEvent, 2), } e := channel.New(0, 1280, linkAddr1) ndpConfigs := ipv6.NDPConfigurations{ - HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled, - AutoGenGlobalAddresses: true, - AutoGenTempGlobalAddresses: true, - RegenAdvanceDuration: newMinVLDuration - regenAfter, + HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled, + AutoGenGlobalAddresses: true, + AutoGenTempGlobalAddresses: true, + RegenAdvanceDuration: regenAdv, + MaxTempAddrPreferredLifetime: maxTempAddrPreferredLifetime, + MaxTempAddrValidLifetime: maxTempAddrPreferredLifetime * 2, + } + clock := faketime.NewManualClock() + initialTime := clock.NowMonotonic() + randSource := savingRandSource{ + s: rand.NewSource(time.Now().UnixNano()), } s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ NDPConfigs: ndpConfigs, NDPDisp: &ndpDisp, })}, + Clock: clock, + RandSource: &randSource, }) if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) } + tempDesyncFactor := time.Duration(randSource.lastInt63) % ipv6.MaxDesyncFactor + expectAutoGenAddrEvent := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) { t.Helper() @@ -2650,22 +2598,23 @@ func TestAutoGenTempAddrRegenJobUpdates(t *testing.T) { expectAutoGenAddrEventAsync := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType, timeout time.Duration) { t.Helper() + clock.Advance(timeout) select { case e := <-ndpDisp.autoGenAddrC: if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" { t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) } - case <-time.After(timeout): + default: t.Fatal("timed out waiting for addr auto gen event") } } // Receive an RA with prefix1 in an NDP Prefix Information option (PI) // with non-zero valid & preferred lifetimes. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, 100, 100)) + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, maxTempAddrPreferredLifetimeSeconds, maxTempAddrPreferredLifetimeSeconds)) expectAutoGenAddrEvent(addr, newAddr) - expectAutoGenAddrEvent(tempAddr1, newAddr) - if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr, tempAddr1}, nil); mismatch != "" { + expectAutoGenAddrEvent(tempAddrs[0], newAddr) + if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr, tempAddrs[0]}, nil); mismatch != "" { t.Fatal(mismatch) } @@ -2673,13 +2622,27 @@ func TestAutoGenTempAddrRegenJobUpdates(t *testing.T) { // // A new temporary address should be generated after the regeneration // time has passed since the prefix is deprecated. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, 100, 0)) + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, maxTempAddrPreferredLifetimeSeconds, 0)) expectAutoGenAddrEvent(addr, deprecatedAddr) - expectAutoGenAddrEvent(tempAddr1, deprecatedAddr) + expectAutoGenAddrEvent(tempAddrs[0], deprecatedAddr) select { case e := <-ndpDisp.autoGenAddrC: - t.Fatalf("unexpected auto gen addr event = %+v", e) - case <-time.After(regenAfter + defaultAsyncNegativeEventTimeout): + t.Fatalf("unexpected auto gen addr event = %#v", e) + default: + } + + effectiveMaxTempAddrPL := maxTempAddrPreferredLifetime - tempDesyncFactor + // The time since the last regeneration before a new temporary address is + // generated. + tempAddrRegenenerationTime := effectiveMaxTempAddrPL - regenAdv + + // Advance the clock by the regeneration time but don't expect a new temporary + // address as the prefix is deprecated. + clock.Advance(tempAddrRegenenerationTime) + select { + case e := <-ndpDisp.autoGenAddrC: + t.Fatalf("unexpected auto gen addr event = %#v", e) + default: } // Prefer the prefix again. @@ -2687,8 +2650,15 @@ func TestAutoGenTempAddrRegenJobUpdates(t *testing.T) { // A new temporary address should immediately be generated since the // regeneration time has already passed since the last address was generated // - this regeneration does not depend on a job. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, 100, 100)) - expectAutoGenAddrEvent(tempAddr2, newAddr) + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, maxTempAddrPreferredLifetimeSeconds, maxTempAddrPreferredLifetimeSeconds)) + expectAutoGenAddrEvent(tempAddrs[1], newAddr) + // Wait for the first temporary address to be deprecated. + expectAutoGenAddrEventAsync(tempAddrs[0], deprecatedAddr, regenAdv) + select { + case e := <-ndpDisp.autoGenAddrC: + t.Fatalf("unexpected auto gen addr event = %s", e) + default: + } // Increase the maximum lifetimes for temporary addresses to large values // then refresh the lifetimes of the prefix. @@ -2699,34 +2669,30 @@ func TestAutoGenTempAddrRegenJobUpdates(t *testing.T) { // regenerate a new temporary address. Note, new addresses are only // regenerated after the preferred lifetime - the regenerate advance duration // as paased. - ndpConfigs.MaxTempAddrValidLifetime = 100 * time.Second - ndpConfigs.MaxTempAddrPreferredLifetime = 100 * time.Second + const largeLifetimeSeconds = minVLSeconds * 2 + const largeLifetime = time.Duration(largeLifetimeSeconds) * time.Second + ndpConfigs.MaxTempAddrValidLifetime = 2 * largeLifetime + ndpConfigs.MaxTempAddrPreferredLifetime = largeLifetime ipv6Ep, err := s.GetNetworkEndpoint(nicID, header.IPv6ProtocolNumber) if err != nil { t.Fatalf("s.GetNetworkEndpoint(%d, %d): %s", nicID, header.IPv6ProtocolNumber, err) } ndpEP := ipv6Ep.(ipv6.NDPEndpoint) ndpEP.SetNDPConfigurations(ndpConfigs) - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, 100, 100)) + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, largeLifetimeSeconds, largeLifetimeSeconds)) + timeSinceInitialTime := clock.NowMonotonic().Sub(initialTime) + clock.Advance(largeLifetime - timeSinceInitialTime) + expectAutoGenAddrEvent(tempAddrs[0], deprecatedAddr) + // to offset the advement of time to test the first temporary address's + // deprecation after the second was generated + advLess := regenAdv + expectAutoGenAddrEventAsync(tempAddrs[2], newAddr, timeSinceInitialTime-advLess-(tempDesyncFactor+regenAdv)) + expectAutoGenAddrEventAsync(tempAddrs[1], deprecatedAddr, regenAdv) select { case e := <-ndpDisp.autoGenAddrC: t.Fatalf("unexpected auto gen addr event = %+v", e) - case <-time.After(regenAfter + defaultAsyncNegativeEventTimeout): + default: } - - // Set the maximum lifetimes for temporary addresses such that on the next - // RA, the regeneration job gets scheduled again. - // - // The maximum lifetime is the sum of the minimum lifetimes for temporary - // addresses + the time that has already passed since the last address was - // generated so that the regeneration job is needed to generate the next - // address. - newLifetimes := newMinVLDuration + regenAfter + defaultAsyncNegativeEventTimeout - ndpConfigs.MaxTempAddrValidLifetime = newLifetimes - ndpConfigs.MaxTempAddrPreferredLifetime = newLifetimes - ndpEP.SetNDPConfigurations(ndpConfigs) - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, 100, 100)) - expectAutoGenAddrEventAsync(tempAddr3, newAddr, regenAfter+defaultAsyncPositiveEventTimeout) } // TestMixedSLAACAddrConflictRegen tests SLAAC address regeneration in response @@ -2853,8 +2819,12 @@ func TestMixedSLAACAddrConflictRegen(t *testing.T) { continue } - if err := s.AddAddress(nicID, ipv6.ProtocolNumber, test.addrs[j].Address); err != nil { - t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID, ipv6.ProtocolNumber, test.addrs[j].Address, err) + protocolAddr := tcpip.ProtocolAddress{ + Protocol: ipv6.ProtocolNumber, + AddressWithPrefix: test.addrs[j].Address.WithPrefix(), + } + if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err) } manuallyAssignedAddresses[test.addrs[j].Address] = struct{}{} @@ -2954,13 +2924,14 @@ func TestMixedSLAACAddrConflictRegen(t *testing.T) { // stack.Stack will have a default route through the router (llAddr3) installed // and a static link-address (linkAddr3) added to the link address cache for the // router. -func stackAndNdpDispatcherWithDefaultRoute(t *testing.T, nicID tcpip.NICID) (*ndpDispatcher, *channel.Endpoint, *stack.Stack) { +func stackAndNdpDispatcherWithDefaultRoute(t *testing.T, nicID tcpip.NICID) (*ndpDispatcher, *channel.Endpoint, *stack.Stack, *faketime.ManualClock) { t.Helper() ndpDisp := &ndpDispatcher{ autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1), } e := channel.New(0, 1280, linkAddr1) e.LinkEPCapabilities |= stack.CapabilityResolutionRequired + clock := faketime.NewManualClock() s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ NDPConfigs: ipv6.NDPConfigurations{ @@ -2970,6 +2941,7 @@ func stackAndNdpDispatcherWithDefaultRoute(t *testing.T, nicID tcpip.NICID) (*nd NDPDisp: ndpDisp, })}, TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}, + Clock: clock, }) if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) @@ -2983,7 +2955,7 @@ func stackAndNdpDispatcherWithDefaultRoute(t *testing.T, nicID tcpip.NICID) (*nd if err := s.AddStaticNeighbor(nicID, ipv6.ProtocolNumber, llAddr3, linkAddr3); err != nil { t.Fatalf("s.AddStaticNeighbor(%d, %d, %s, %s): %s", nicID, ipv6.ProtocolNumber, llAddr3, linkAddr3, err) } - return ndpDisp, e, s + return ndpDisp, e, s, clock } // addrForNewConnectionTo returns the local address used when creating a new @@ -3057,7 +3029,7 @@ func TestAutoGenAddrDeprecateFromPI(t *testing.T) { prefix1, _, addr1 := prefixSubnetAddr(0, linkAddr1) prefix2, _, addr2 := prefixSubnetAddr(1, linkAddr1) - ndpDisp, e, s := stackAndNdpDispatcherWithDefaultRoute(t, nicID) + ndpDisp, e, s, _ := stackAndNdpDispatcherWithDefaultRoute(t, nicID) expectAutoGenAddrEvent := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) { t.Helper() @@ -3160,19 +3132,11 @@ func TestAutoGenAddrDeprecateFromPI(t *testing.T) { // when its preferred lifetime expires. func TestAutoGenAddrJobDeprecation(t *testing.T) { const nicID = 1 - const newMinVL = 2 - newMinVLDuration := newMinVL * time.Second - - saved := ipv6.MinPrefixInformationValidLifetimeForUpdate - defer func() { - ipv6.MinPrefixInformationValidLifetimeForUpdate = saved - }() - ipv6.MinPrefixInformationValidLifetimeForUpdate = newMinVLDuration prefix1, _, addr1 := prefixSubnetAddr(0, linkAddr1) prefix2, _, addr2 := prefixSubnetAddr(1, linkAddr1) - ndpDisp, e, s := stackAndNdpDispatcherWithDefaultRoute(t, nicID) + ndpDisp, e, s, clock := stackAndNdpDispatcherWithDefaultRoute(t, nicID) expectAutoGenAddrEvent := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) { t.Helper() @@ -3190,12 +3154,13 @@ func TestAutoGenAddrJobDeprecation(t *testing.T) { expectAutoGenAddrEventAfter := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType, timeout time.Duration) { t.Helper() + clock.Advance(timeout) select { case e := <-ndpDisp.autoGenAddrC: if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" { t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) } - case <-time.After(timeout): + default: t.Fatal("timed out waiting for addr auto gen event") } } @@ -3213,7 +3178,7 @@ func TestAutoGenAddrJobDeprecation(t *testing.T) { } // Receive PI for prefix2. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, 100, 100)) + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, infiniteLifetimeSeconds, infiniteLifetimeSeconds)) expectAutoGenAddrEvent(addr2, newAddr) if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr2) { t.Fatalf("should have %s in the list of addresses", addr2) @@ -3232,7 +3197,7 @@ func TestAutoGenAddrJobDeprecation(t *testing.T) { expectPrimaryAddr(addr1) // Refresh lifetime for addr of prefix1. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, newMinVL, newMinVL-1)) + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, minVLSeconds, minVLSeconds-1)) select { case <-ndpDisp.autoGenAddrC: t.Fatal("unexpectedly got an auto-generated event") @@ -3241,7 +3206,7 @@ func TestAutoGenAddrJobDeprecation(t *testing.T) { expectPrimaryAddr(addr1) // Wait for addr of prefix1 to be deprecated. - expectAutoGenAddrEventAfter(addr1, deprecatedAddr, newMinVLDuration-time.Second+defaultAsyncPositiveEventTimeout) + expectAutoGenAddrEventAfter(addr1, deprecatedAddr, ipv6.MinPrefixInformationValidLifetimeForUpdate-time.Second) if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) { t.Fatalf("should not have %s in the list of addresses", addr1) } @@ -3251,6 +3216,7 @@ func TestAutoGenAddrJobDeprecation(t *testing.T) { // addr2 should be the primary endpoint now since addr1 is deprecated but // addr2 is not. expectPrimaryAddr(addr2) + // addr1 is deprecated but if explicitly requested, it should be used. fullAddr1 := tcpip.FullAddress{Addr: addr1.Address, NIC: nicID} if got := addrForNewConnectionWithAddr(t, s, fullAddr1); got != addr1.Address { @@ -3259,7 +3225,7 @@ func TestAutoGenAddrJobDeprecation(t *testing.T) { // Refresh valid lifetime for addr of prefix1, w/ 0 preferred lifetime to make // sure we do not get a deprecation event again. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, newMinVL, 0)) + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, minVLSeconds, 0)) select { case <-ndpDisp.autoGenAddrC: t.Fatal("unexpectedly got an auto-generated event") @@ -3271,7 +3237,7 @@ func TestAutoGenAddrJobDeprecation(t *testing.T) { } // Refresh lifetimes for addr of prefix1. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, newMinVL, newMinVL-1)) + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, minVLSeconds, minVLSeconds-1)) select { case <-ndpDisp.autoGenAddrC: t.Fatal("unexpectedly got an auto-generated event") @@ -3281,7 +3247,7 @@ func TestAutoGenAddrJobDeprecation(t *testing.T) { expectPrimaryAddr(addr1) // Wait for addr of prefix1 to be deprecated. - expectAutoGenAddrEventAfter(addr1, deprecatedAddr, newMinVLDuration-time.Second+defaultAsyncPositiveEventTimeout) + expectAutoGenAddrEventAfter(addr1, deprecatedAddr, ipv6.MinPrefixInformationValidLifetimeForUpdate-time.Second) if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) { t.Fatalf("should not have %s in the list of addresses", addr1) } @@ -3295,7 +3261,7 @@ func TestAutoGenAddrJobDeprecation(t *testing.T) { } // Wait for addr of prefix1 to be invalidated. - expectAutoGenAddrEventAfter(addr1, invalidatedAddr, time.Second+defaultAsyncPositiveEventTimeout) + expectAutoGenAddrEventAfter(addr1, invalidatedAddr, time.Second) if containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) { t.Fatalf("should not have %s in the list of addresses", addr1) } @@ -3305,7 +3271,7 @@ func TestAutoGenAddrJobDeprecation(t *testing.T) { expectPrimaryAddr(addr2) // Refresh both lifetimes for addr of prefix2 to the same value. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, newMinVL, newMinVL)) + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, minVLSeconds, minVLSeconds)) select { case <-ndpDisp.autoGenAddrC: t.Fatal("unexpectedly got an auto-generated event") @@ -3317,6 +3283,17 @@ func TestAutoGenAddrJobDeprecation(t *testing.T) { // cases because the deprecation and invalidation handlers could be handled in // either deprecation then invalidation, or invalidation then deprecation // (which should be cancelled by the invalidation handler). + // + // Since we're about to cause both events to fire, we need the dispatcher + // channel to be able to hold both. + if got, want := len(ndpDisp.autoGenAddrC), 0; got != want { + t.Fatalf("got len(ndpDisp.autoGenAddrC) = %d, want %d", got, want) + } + if got, want := cap(ndpDisp.autoGenAddrC), 1; got != want { + t.Fatalf("got cap(ndpDisp.autoGenAddrC) = %d, want %d", got, want) + } + ndpDisp.autoGenAddrC = make(chan ndpAutoGenAddrEvent, 2) + clock.Advance(ipv6.MinPrefixInformationValidLifetimeForUpdate) select { case e := <-ndpDisp.autoGenAddrC: if diff := checkAutoGenAddrEvent(e, addr2, deprecatedAddr); diff == "" { @@ -3327,21 +3304,21 @@ func TestAutoGenAddrJobDeprecation(t *testing.T) { if diff := checkAutoGenAddrEvent(e, addr2, invalidatedAddr); diff != "" { t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) } - case <-time.After(defaultAsyncPositiveEventTimeout): + default: t.Fatal("timed out waiting for addr auto gen event") } } else if diff := checkAutoGenAddrEvent(e, addr2, invalidatedAddr); diff == "" { - // If we get an invalidation event first, we should not get a deprecation + // If we get an invalidation event first, we should not get a deprecation // event after. select { case <-ndpDisp.autoGenAddrC: t.Fatal("unexpectedly got an auto-generated event") - case <-time.After(defaultAsyncNegativeEventTimeout): + default: } } else { t.Fatalf("got unexpected auto-generated event") } - case <-time.After(newMinVLDuration + defaultAsyncPositiveEventTimeout): + default: t.Fatal("timed out waiting for addr auto gen event") } if containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) { @@ -3378,15 +3355,6 @@ func TestAutoGenAddrJobDeprecation(t *testing.T) { // infinite values. func TestAutoGenAddrFiniteToInfiniteToFiniteVL(t *testing.T) { const infiniteVLSeconds = 2 - const minVLSeconds = 1 - savedIL := header.NDPInfiniteLifetime - savedMinVL := ipv6.MinPrefixInformationValidLifetimeForUpdate - defer func() { - ipv6.MinPrefixInformationValidLifetimeForUpdate = savedMinVL - header.NDPInfiniteLifetime = savedIL - }() - ipv6.MinPrefixInformationValidLifetimeForUpdate = minVLSeconds * time.Second - header.NDPInfiniteLifetime = infiniteVLSeconds * time.Second prefix, _, addr := prefixSubnetAddr(0, linkAddr1) @@ -3410,68 +3378,58 @@ func TestAutoGenAddrFiniteToInfiniteToFiniteVL(t *testing.T) { }, } - // This Run will not return until the parallel tests finish. - // - // We need this because we need to do some teardown work after the - // parallel tests complete. - // - // See https://godoc.org/testing#hdr-Subtests_and_Sub_benchmarks for - // more details. - t.Run("group", func(t *testing.T) { - for _, test := range tests { - test := test - - t.Run(test.name, func(t *testing.T) { - t.Parallel() - - ndpDisp := ndpDispatcher{ - autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1), - } - e := channel.New(0, 1280, linkAddr1) - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ - NDPConfigs: ipv6.NDPConfigurations{ - HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled, - AutoGenGlobalAddresses: true, - }, - NDPDisp: &ndpDisp, - })}, - }) - - if err := s.CreateNIC(1, e); err != nil { - t.Fatalf("CreateNIC(1) = %s", err) - } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + ndpDisp := ndpDispatcher{ + autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1), + } + e := channel.New(0, 1280, linkAddr1) + clock := faketime.NewManualClock() + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ + NDPConfigs: ipv6.NDPConfigurations{ + HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled, + AutoGenGlobalAddresses: true, + }, + NDPDisp: &ndpDisp, + })}, + Clock: clock, + }) - // Receive an RA with finite prefix. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, minVLSeconds, 0)) - select { - case e := <-ndpDisp.autoGenAddrC: - if diff := checkAutoGenAddrEvent(e, addr, newAddr); diff != "" { - t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) - } + if err := s.CreateNIC(1, e); err != nil { + t.Fatalf("CreateNIC(1) = %s", err) + } - default: - t.Fatal("expected addr auto gen event") + // Receive an RA with finite prefix. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, minVLSeconds, 0)) + select { + case e := <-ndpDisp.autoGenAddrC: + if diff := checkAutoGenAddrEvent(e, addr, newAddr); diff != "" { + t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) } - // Receive an new RA with prefix with infinite VL. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, test.infiniteVL, 0)) + default: + t.Fatal("expected addr auto gen event") + } - // Receive a new RA with prefix with finite VL. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, minVLSeconds, 0)) + // Receive an new RA with prefix with infinite VL. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, test.infiniteVL, 0)) - select { - case e := <-ndpDisp.autoGenAddrC: - if diff := checkAutoGenAddrEvent(e, addr, invalidatedAddr); diff != "" { - t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) - } + // Receive a new RA with prefix with finite VL. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, minVLSeconds, 0)) - case <-time.After(minVLSeconds*time.Second + defaultAsyncPositiveEventTimeout): - t.Fatal("timeout waiting for addr auto gen event") + clock.Advance(ipv6.MinPrefixInformationValidLifetimeForUpdate) + select { + case e := <-ndpDisp.autoGenAddrC: + if diff := checkAutoGenAddrEvent(e, addr, invalidatedAddr); diff != "" { + t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) } - }) - } - }) + + default: + t.Fatal("timeout waiting for addr auto gen event") + } + }) + } } // TestAutoGenAddrValidLifetimeUpdates tests that the valid lifetime of an @@ -3479,12 +3437,6 @@ func TestAutoGenAddrFiniteToInfiniteToFiniteVL(t *testing.T) { // RFC 4862 section 5.5.3.e. func TestAutoGenAddrValidLifetimeUpdates(t *testing.T) { const infiniteVL = 4294967295 - const newMinVL = 4 - saved := ipv6.MinPrefixInformationValidLifetimeForUpdate - defer func() { - ipv6.MinPrefixInformationValidLifetimeForUpdate = saved - }() - ipv6.MinPrefixInformationValidLifetimeForUpdate = newMinVL * time.Second prefix, _, addr := prefixSubnetAddr(0, linkAddr1) @@ -3495,137 +3447,129 @@ func TestAutoGenAddrValidLifetimeUpdates(t *testing.T) { evl uint32 }{ // Should update the VL to the minimum VL for updating if the - // new VL is less than newMinVL but was originally greater than + // new VL is less than minVLSeconds but was originally greater than // it. { "LargeVLToVLLessThanMinVLForUpdate", 9999, 1, - newMinVL, + minVLSeconds, }, { "LargeVLTo0", 9999, 0, - newMinVL, + minVLSeconds, }, { "InfiniteVLToVLLessThanMinVLForUpdate", infiniteVL, 1, - newMinVL, + minVLSeconds, }, { "InfiniteVLTo0", infiniteVL, 0, - newMinVL, + minVLSeconds, }, - // Should not update VL if original VL was less than newMinVL - // and the new VL is also less than newMinVL. + // Should not update VL if original VL was less than minVLSeconds + // and the new VL is also less than minVLSeconds. { "ShouldNotUpdateWhenBothOldAndNewAreLessThanMinVLForUpdate", - newMinVL - 1, - newMinVL - 3, - newMinVL - 1, + minVLSeconds - 1, + minVLSeconds - 3, + minVLSeconds - 1, }, // Should take the new VL if the new VL is greater than the - // remaining time or is greater than newMinVL. + // remaining time or is greater than minVLSeconds. { "MorethanMinVLToLesserButStillMoreThanMinVLForUpdate", - newMinVL + 5, - newMinVL + 3, - newMinVL + 3, + minVLSeconds + 5, + minVLSeconds + 3, + minVLSeconds + 3, }, { "SmallVLToGreaterVLButStillLessThanMinVLForUpdate", - newMinVL - 3, - newMinVL - 1, - newMinVL - 1, + minVLSeconds - 3, + minVLSeconds - 1, + minVLSeconds - 1, }, { "SmallVLToGreaterVLThatIsMoreThaMinVLForUpdate", - newMinVL - 3, - newMinVL + 1, - newMinVL + 1, + minVLSeconds - 3, + minVLSeconds + 1, + minVLSeconds + 1, }, } - // This Run will not return until the parallel tests finish. - // - // We need this because we need to do some teardown work after the - // parallel tests complete. - // - // See https://godoc.org/testing#hdr-Subtests_and_Sub_benchmarks for - // more details. - t.Run("group", func(t *testing.T) { - for _, test := range tests { - test := test - - t.Run(test.name, func(t *testing.T) { - t.Parallel() - - ndpDisp := ndpDispatcher{ - autoGenAddrC: make(chan ndpAutoGenAddrEvent, 10), - } - e := channel.New(10, 1280, linkAddr1) - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ - NDPConfigs: ipv6.NDPConfigurations{ - HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled, - AutoGenGlobalAddresses: true, - }, - NDPDisp: &ndpDisp, - })}, - }) + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + ndpDisp := ndpDispatcher{ + autoGenAddrC: make(chan ndpAutoGenAddrEvent, 10), + } + e := channel.New(10, 1280, linkAddr1) + clock := faketime.NewManualClock() + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ + NDPConfigs: ipv6.NDPConfigurations{ + HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled, + AutoGenGlobalAddresses: true, + }, + NDPDisp: &ndpDisp, + })}, + Clock: clock, + }) - if err := s.CreateNIC(1, e); err != nil { - t.Fatalf("CreateNIC(1) = %s", err) - } + if err := s.CreateNIC(1, e); err != nil { + t.Fatalf("CreateNIC(1) = %s", err) + } - // Receive an RA with prefix with initial VL, - // test.ovl. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, test.ovl, 0)) - select { - case e := <-ndpDisp.autoGenAddrC: - if diff := checkAutoGenAddrEvent(e, addr, newAddr); diff != "" { - t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) - } - default: - t.Fatal("expected addr auto gen event") + // Receive an RA with prefix with initial VL, + // test.ovl. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, test.ovl, 0)) + select { + case e := <-ndpDisp.autoGenAddrC: + if diff := checkAutoGenAddrEvent(e, addr, newAddr); diff != "" { + t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) } + default: + t.Fatal("expected addr auto gen event") + } - // Receive an new RA with prefix with new VL, - // test.nvl. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, test.nvl, 0)) + // Receive an new RA with prefix with new VL, + // test.nvl. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, test.nvl, 0)) - // - // Validate that the VL for the address got set - // to test.evl. - // + // + // Validate that the VL for the address got set + // to test.evl. + // - // The address should not be invalidated until the effective valid - // lifetime has passed. - select { - case <-ndpDisp.autoGenAddrC: - t.Fatal("unexpectedly received an auto gen addr event") - case <-time.After(time.Duration(test.evl)*time.Second - defaultAsyncNegativeEventTimeout): - } + // The address should not be invalidated until the effective valid + // lifetime has passed. + const delta = 1 + clock.Advance(time.Duration(test.evl)*time.Second - delta) + select { + case <-ndpDisp.autoGenAddrC: + t.Fatal("unexpectedly received an auto gen addr event") + default: + } - // Wait for the invalidation event. - select { - case e := <-ndpDisp.autoGenAddrC: - if diff := checkAutoGenAddrEvent(e, addr, invalidatedAddr); diff != "" { - t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) - } - case <-time.After(defaultAsyncPositiveEventTimeout): - t.Fatal("timeout waiting for addr auto gen event") + // Wait for the invalidation event. + clock.Advance(delta) + select { + case e := <-ndpDisp.autoGenAddrC: + if diff := checkAutoGenAddrEvent(e, addr, invalidatedAddr); diff != "" { + t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) } - }) - } - }) + default: + t.Fatal("timeout waiting for addr auto gen event") + } + }) + } } // TestAutoGenAddrRemoval tests that when auto-generated addresses are removed @@ -3696,7 +3640,7 @@ func TestAutoGenAddrAfterRemoval(t *testing.T) { prefix1, _, addr1 := prefixSubnetAddr(0, linkAddr1) prefix2, _, addr2 := prefixSubnetAddr(1, linkAddr1) - ndpDisp, e, s := stackAndNdpDispatcherWithDefaultRoute(t, nicID) + ndpDisp, e, s, _ := stackAndNdpDispatcherWithDefaultRoute(t, nicID) expectAutoGenAddrEvent := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) { t.Helper() @@ -3735,8 +3679,9 @@ func TestAutoGenAddrAfterRemoval(t *testing.T) { Protocol: header.IPv6ProtocolNumber, AddressWithPrefix: addr2, } - if err := s.AddProtocolAddressWithOptions(nicID, protoAddr2, stack.FirstPrimaryEndpoint); err != nil { - t.Fatalf("AddProtocolAddressWithOptions(%d, %+v, %d) = %s", nicID, protoAddr2, stack.FirstPrimaryEndpoint, err) + properties := stack.AddressProperties{PEB: stack.FirstPrimaryEndpoint} + if err := s.AddProtocolAddress(nicID, protoAddr2, properties); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, %+v) = %s", nicID, protoAddr2, properties, err) } // addr2 should be more preferred now since it is at the front of the primary // list. @@ -3824,8 +3769,9 @@ func TestAutoGenAddrStaticConflict(t *testing.T) { } // Add the address as a static address before SLAAC tries to add it. - if err := s.AddProtocolAddress(1, tcpip.ProtocolAddress{Protocol: header.IPv6ProtocolNumber, AddressWithPrefix: addr}); err != nil { - t.Fatalf("AddAddress(_, %d, %s) = %s", header.IPv6ProtocolNumber, addr.Address, err) + protocolAddr := tcpip.ProtocolAddress{Protocol: header.IPv6ProtocolNumber, AddressWithPrefix: addr} + if err := s.AddProtocolAddress(1, protocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(1, %+v, {}) = %s", protocolAddr, err) } if !containsV6Addr(s.NICInfo()[1].ProtocolAddresses, addr) { t.Fatalf("Should have %s in the list of addresses", addr1) @@ -3976,13 +3922,6 @@ func TestAutoGenAddrInResponseToDADConflicts(t *testing.T) { const maxMaxRetries = 3 const lifetimeSeconds = 10 - // Needed for the temporary address sub test. - savedMaxDesync := ipv6.MaxDesyncFactor - defer func() { - ipv6.MaxDesyncFactor = savedMaxDesync - }() - ipv6.MaxDesyncFactor = time.Nanosecond - secretKey := makeSecretKey(t) prefix, subnet, _ := prefixSubnetAddr(0, linkAddr1) @@ -4008,22 +3947,24 @@ func TestAutoGenAddrInResponseToDADConflicts(t *testing.T) { } } - expectAutoGenAddrEventAsync := func(t *testing.T, ndpDisp *ndpDispatcher, addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) { + expectAutoGenAddrEventAsync := func(t *testing.T, clock *faketime.ManualClock, ndpDisp *ndpDispatcher, addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) { t.Helper() + clock.RunImmediatelyScheduledJobs() select { case e := <-ndpDisp.autoGenAddrC: if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" { t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) } - case <-time.After(defaultAsyncPositiveEventTimeout): + default: t.Fatal("timed out waiting for addr auto gen event") } } - expectDADEvent := func(t *testing.T, ndpDisp *ndpDispatcher, addr tcpip.Address, res stack.DADResult) { + expectDADEvent := func(t *testing.T, clock *faketime.ManualClock, ndpDisp *ndpDispatcher, addr tcpip.Address, res stack.DADResult) { t.Helper() + clock.RunImmediatelyScheduledJobs() select { case e := <-ndpDisp.dadC: if diff := checkDADEvent(e, nicID, addr, res); diff != "" { @@ -4034,15 +3975,16 @@ func TestAutoGenAddrInResponseToDADConflicts(t *testing.T) { } } - expectDADEventAsync := func(t *testing.T, ndpDisp *ndpDispatcher, addr tcpip.Address, res stack.DADResult) { + expectDADEventAsync := func(t *testing.T, clock *faketime.ManualClock, ndpDisp *ndpDispatcher, addr tcpip.Address, res stack.DADResult) { t.Helper() + clock.Advance(dadTransmits * retransmitTimer) select { case e := <-ndpDisp.dadC: if diff := checkDADEvent(e, nicID, addr, res); diff != "" { t.Errorf("DAD event mismatch (-want +got):\n%s", diff) } - case <-time.After(dadTransmits*retransmitTimer + defaultAsyncPositiveEventTimeout): + default: t.Fatal("timed out waiting for DAD event") } } @@ -4053,7 +3995,7 @@ func TestAutoGenAddrInResponseToDADConflicts(t *testing.T) { name string ndpConfigs ipv6.NDPConfigurations autoGenLinkLocal bool - prepareFn func(t *testing.T, ndpDisp *ndpDispatcher, e *channel.Endpoint, tempIIDHistory []byte) []tcpip.AddressWithPrefix + prepareFn func(t *testing.T, clock *faketime.ManualClock, ndpDisp *ndpDispatcher, e *channel.Endpoint, tempIIDHistory []byte) []tcpip.AddressWithPrefix addrGenFn func(dadCounter uint8, tempIIDHistory []byte) tcpip.AddressWithPrefix }{ { @@ -4062,7 +4004,7 @@ func TestAutoGenAddrInResponseToDADConflicts(t *testing.T) { HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled, AutoGenGlobalAddresses: true, }, - prepareFn: func(_ *testing.T, _ *ndpDispatcher, e *channel.Endpoint, _ []byte) []tcpip.AddressWithPrefix { + prepareFn: func(_ *testing.T, _ *faketime.ManualClock, _ *ndpDispatcher, e *channel.Endpoint, _ []byte) []tcpip.AddressWithPrefix { // Receive an RA with prefix1 in a PI. e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, lifetimeSeconds, lifetimeSeconds)) return nil @@ -4076,7 +4018,7 @@ func TestAutoGenAddrInResponseToDADConflicts(t *testing.T) { name: "LinkLocal address", ndpConfigs: ipv6.NDPConfigurations{}, autoGenLinkLocal: true, - prepareFn: func(*testing.T, *ndpDispatcher, *channel.Endpoint, []byte) []tcpip.AddressWithPrefix { + prepareFn: func(*testing.T, *faketime.ManualClock, *ndpDispatcher, *channel.Endpoint, []byte) []tcpip.AddressWithPrefix { return nil }, addrGenFn: func(dadCounter uint8, _ []byte) tcpip.AddressWithPrefix { @@ -4090,14 +4032,14 @@ func TestAutoGenAddrInResponseToDADConflicts(t *testing.T) { AutoGenGlobalAddresses: true, AutoGenTempGlobalAddresses: true, }, - prepareFn: func(t *testing.T, ndpDisp *ndpDispatcher, e *channel.Endpoint, tempIIDHistory []byte) []tcpip.AddressWithPrefix { + prepareFn: func(t *testing.T, clock *faketime.ManualClock, ndpDisp *ndpDispatcher, e *channel.Endpoint, tempIIDHistory []byte) []tcpip.AddressWithPrefix { header.InitialTempIID(tempIIDHistory, nil, nicID) // Generate a stable SLAAC address so temporary addresses will be // generated. e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, 100, 100)) expectAutoGenAddrEvent(t, ndpDisp, stableAddrForTempAddrTest, newAddr) - expectDADEventAsync(t, ndpDisp, stableAddrForTempAddrTest.Address, &stack.DADSucceeded{}) + expectDADEventAsync(t, clock, ndpDisp, stableAddrForTempAddrTest.Address, &stack.DADSucceeded{}) // The stable address will be assigned throughout the test. return []tcpip.AddressWithPrefix{stableAddrForTempAddrTest} @@ -4109,14 +4051,6 @@ func TestAutoGenAddrInResponseToDADConflicts(t *testing.T) { } for _, addrType := range addrTypes { - // This Run will not return until the parallel tests finish. - // - // We need this because we need to do some teardown work after the parallel - // tests complete and limit the number of parallel tests running at the same - // time to reduce flakes. - // - // See https://godoc.org/testing#hdr-Subtests_and_Sub_benchmarks for - // more details. t.Run(addrType.name, func(t *testing.T) { for maxRetries := uint8(0); maxRetries <= maxMaxRetries; maxRetries++ { for numFailures := uint8(0); numFailures <= maxRetries+1; numFailures++ { @@ -4125,8 +4059,6 @@ func TestAutoGenAddrInResponseToDADConflicts(t *testing.T) { addrType := addrType t.Run(fmt.Sprintf("%d max retries and %d failures", maxRetries, numFailures), func(t *testing.T) { - t.Parallel() - ndpDisp := ndpDispatcher{ dadC: make(chan ndpDADEvent, 1), autoGenAddrC: make(chan ndpAutoGenAddrEvent, 2), @@ -4134,6 +4066,7 @@ func TestAutoGenAddrInResponseToDADConflicts(t *testing.T) { e := channel.New(0, 1280, linkAddr1) ndpConfigs := addrType.ndpConfigs ndpConfigs.AutoGenAddressConflictRetries = maxRetries + clock := faketime.NewManualClock() s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ AutoGenLinkLocal: addrType.autoGenLinkLocal, @@ -4150,6 +4083,7 @@ func TestAutoGenAddrInResponseToDADConflicts(t *testing.T) { SecretKey: secretKey, }, })}, + Clock: clock, }) opts := stack.NICOptions{Name: nicName} if err := s.CreateNICWithOptions(nicID, e, opts); err != nil { @@ -4157,12 +4091,12 @@ func TestAutoGenAddrInResponseToDADConflicts(t *testing.T) { } var tempIIDHistory [header.IIDSize]byte - stableAddrs := addrType.prepareFn(t, &ndpDisp, e, tempIIDHistory[:]) + stableAddrs := addrType.prepareFn(t, clock, &ndpDisp, e, tempIIDHistory[:]) // Simulate DAD conflicts so the address is regenerated. for i := uint8(0); i < numFailures; i++ { addr := addrType.addrGenFn(i, tempIIDHistory[:]) - expectAutoGenAddrEventAsync(t, &ndpDisp, addr, newAddr) + expectAutoGenAddrEventAsync(t, clock, &ndpDisp, addr, newAddr) // Should not have any new addresses assigned to the NIC. if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, stableAddrs, nil); mismatch != "" { @@ -4172,17 +4106,21 @@ func TestAutoGenAddrInResponseToDADConflicts(t *testing.T) { // Simulate a DAD conflict. rxNDPSolicit(e, addr.Address) expectAutoGenAddrEvent(t, &ndpDisp, addr, invalidatedAddr) - expectDADEvent(t, &ndpDisp, addr.Address, &stack.DADDupAddrDetected{}) + expectDADEvent(t, clock, &ndpDisp, addr.Address, &stack.DADDupAddrDetected{}) // Attempting to add the address manually should not fail if the // address's state was cleaned up when DAD failed. - if err := s.AddAddress(nicID, header.IPv6ProtocolNumber, addr.Address); err != nil { - t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, addr.Address, err) + protocolAddr := tcpip.ProtocolAddress{ + Protocol: header.IPv6ProtocolNumber, + AddressWithPrefix: addr, + } + if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err) } if err := s.RemoveAddress(nicID, addr.Address); err != nil { t.Fatalf("RemoveAddress(%d, %s) = %s", nicID, addr.Address, err) } - expectDADEvent(t, &ndpDisp, addr.Address, &stack.DADAborted{}) + expectDADEvent(t, clock, &ndpDisp, addr.Address, &stack.DADAborted{}) } // Should not have any new addresses assigned to the NIC. @@ -4194,8 +4132,8 @@ func TestAutoGenAddrInResponseToDADConflicts(t *testing.T) { // an address after DAD resolves. if maxRetries+1 > numFailures { addr := addrType.addrGenFn(numFailures, tempIIDHistory[:]) - expectAutoGenAddrEventAsync(t, &ndpDisp, addr, newAddr) - expectDADEventAsync(t, &ndpDisp, addr.Address, &stack.DADSucceeded{}) + expectAutoGenAddrEventAsync(t, clock, &ndpDisp, addr, newAddr) + expectDADEventAsync(t, clock, &ndpDisp, addr.Address, &stack.DADSucceeded{}) if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, append(stableAddrs, addr), nil); mismatch != "" { t.Fatal(mismatch) } @@ -4205,7 +4143,7 @@ func TestAutoGenAddrInResponseToDADConflicts(t *testing.T) { select { case e := <-ndpDisp.autoGenAddrC: t.Fatalf("unexpectedly got an auto-generated address event = %+v", e) - case <-time.After(defaultAsyncNegativeEventTimeout): + default: } }) } @@ -4718,11 +4656,9 @@ func TestNoCleanupNDPStateWhenForwardingEnabled(t *testing.T) { ) ndpDisp := ndpDispatcher{ - routerC: make(chan ndpRouterEvent, 1), - rememberRouter: true, - prefixC: make(chan ndpPrefixEvent, 1), - rememberPrefix: true, - autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1), + offLinkRouteC: make(chan ndpOffLinkRouteEvent, 1), + prefixC: make(chan ndpPrefixEvent, 1), + autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1), } s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ @@ -4765,17 +4701,17 @@ func TestNoCleanupNDPStateWhenForwardingEnabled(t *testing.T) { ), ) select { - case e := <-ndpDisp.routerC: - if diff := checkRouterEvent(e, llAddr3, true /* discovered */); diff != "" { - t.Errorf("router event mismatch (-want +got):\n%s", diff) + case e := <-ndpDisp.offLinkRouteC: + if diff := checkOffLinkRouteEvent(e, nicID, header.IPv6EmptySubnet, llAddr3, header.MediumRoutePreference, true /* discovered */); diff != "" { + t.Errorf("off-link route event mismatch (-want +got):\n%s", diff) } default: - t.Errorf("expected router event for %s on NIC(%d)", llAddr3, nicID) + t.Errorf("expected off-link route event for %s on NIC(%d)", llAddr3, nicID) } select { case e := <-ndpDisp.prefixC: if diff := checkPrefixEvent(e, subnet, true /* discovered */); diff != "" { - t.Errorf("router event mismatch (-want +got):\n%s", diff) + t.Errorf("off-link route event mismatch (-want +got):\n%s", diff) } default: t.Errorf("expected prefix event for %s on NIC(%d)", prefix, nicID) @@ -4797,8 +4733,8 @@ func TestNoCleanupNDPStateWhenForwardingEnabled(t *testing.T) { t.Fatalf("SetForwardingDefaultAndAllNICs(%d, %t): %s", ipv6.ProtocolNumber, forwarding, err) } select { - case e := <-ndpDisp.routerC: - t.Errorf("unexpected router event = %#v", e) + case e := <-ndpDisp.offLinkRouteC: + t.Errorf("unexpected off-link route event = %#v", e) default: } select { @@ -4884,11 +4820,9 @@ func TestCleanupNDPState(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { ndpDisp := ndpDispatcher{ - routerC: make(chan ndpRouterEvent, maxRouterAndPrefixEvents), - rememberRouter: true, - prefixC: make(chan ndpPrefixEvent, maxRouterAndPrefixEvents), - rememberPrefix: true, - autoGenAddrC: make(chan ndpAutoGenAddrEvent, test.maxAutoGenAddrEvents), + offLinkRouteC: make(chan ndpOffLinkRouteEvent, maxRouterAndPrefixEvents), + prefixC: make(chan ndpPrefixEvent, maxRouterAndPrefixEvents), + autoGenAddrC: make(chan ndpAutoGenAddrEvent, test.maxAutoGenAddrEvents), } clock := faketime.NewManualClock() s := stack.New(stack.Options{ @@ -4905,14 +4839,14 @@ func TestCleanupNDPState(t *testing.T) { Clock: clock, }) - expectRouterEvent := func() (bool, ndpRouterEvent) { + expectOffLinkRouteEvent := func() (bool, ndpOffLinkRouteEvent) { select { - case e := <-ndpDisp.routerC: + case e := <-ndpDisp.offLinkRouteC: return true, e default: } - return false, ndpRouterEvent{} + return false, ndpOffLinkRouteEvent{} } expectPrefixEvent := func() (bool, ndpPrefixEvent) { @@ -4957,8 +4891,8 @@ func TestCleanupNDPState(t *testing.T) { // multiple addresses. e1.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr3, lifetimeSeconds, prefix1, true, true, lifetimeSeconds, lifetimeSeconds)) - if ok, _ := expectRouterEvent(); !ok { - t.Errorf("expected router event for %s on NIC(%d)", llAddr3, nicID1) + if ok, _ := expectOffLinkRouteEvent(); !ok { + t.Errorf("expected off-link route event for %s on NIC(%d)", llAddr3, nicID1) } if ok, _ := expectPrefixEvent(); !ok { t.Errorf("expected prefix event for %s on NIC(%d)", prefix1, nicID1) @@ -4968,8 +4902,8 @@ func TestCleanupNDPState(t *testing.T) { } e1.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr4, lifetimeSeconds, prefix2, true, true, lifetimeSeconds, lifetimeSeconds)) - if ok, _ := expectRouterEvent(); !ok { - t.Errorf("expected router event for %s on NIC(%d)", llAddr4, nicID1) + if ok, _ := expectOffLinkRouteEvent(); !ok { + t.Errorf("expected off-link route event for %s on NIC(%d)", llAddr4, nicID1) } if ok, _ := expectPrefixEvent(); !ok { t.Errorf("expected prefix event for %s on NIC(%d)", prefix2, nicID1) @@ -4979,8 +4913,8 @@ func TestCleanupNDPState(t *testing.T) { } e2.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr3, lifetimeSeconds, prefix1, true, true, lifetimeSeconds, lifetimeSeconds)) - if ok, _ := expectRouterEvent(); !ok { - t.Errorf("expected router event for %s on NIC(%d)", llAddr3, nicID2) + if ok, _ := expectOffLinkRouteEvent(); !ok { + t.Errorf("expected off-link route event for %s on NIC(%d)", llAddr3, nicID2) } if ok, _ := expectPrefixEvent(); !ok { t.Errorf("expected prefix event for %s on NIC(%d)", prefix1, nicID2) @@ -4990,8 +4924,8 @@ func TestCleanupNDPState(t *testing.T) { } e2.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr4, lifetimeSeconds, prefix2, true, true, lifetimeSeconds, lifetimeSeconds)) - if ok, _ := expectRouterEvent(); !ok { - t.Errorf("expected router event for %s on NIC(%d)", llAddr4, nicID2) + if ok, _ := expectOffLinkRouteEvent(); !ok { + t.Errorf("expected off-link route event for %s on NIC(%d)", llAddr4, nicID2) } if ok, _ := expectPrefixEvent(); !ok { t.Errorf("expected prefix event for %s on NIC(%d)", prefix2, nicID2) @@ -5032,14 +4966,14 @@ func TestCleanupNDPState(t *testing.T) { test.cleanupFn(t, s) // Collect invalidation events after having NDP state cleaned up. - gotRouterEvents := make(map[ndpRouterEvent]int) + gotOffLinkRouteEvents := make(map[ndpOffLinkRouteEvent]int) for i := 0; i < maxRouterAndPrefixEvents; i++ { - ok, e := expectRouterEvent() + ok, e := expectOffLinkRouteEvent() if !ok { - t.Errorf("expected %d router events after becoming a router; got = %d", maxRouterAndPrefixEvents, i) + t.Errorf("expected %d off-link route events after becoming a router; got = %d", maxRouterAndPrefixEvents, i) break } - gotRouterEvents[e]++ + gotOffLinkRouteEvents[e]++ } gotPrefixEvents := make(map[ndpPrefixEvent]int) for i := 0; i < maxRouterAndPrefixEvents; i++ { @@ -5066,14 +5000,14 @@ func TestCleanupNDPState(t *testing.T) { t.FailNow() } - expectedRouterEvents := map[ndpRouterEvent]int{ - {nicID: nicID1, addr: llAddr3, discovered: false}: 1, - {nicID: nicID1, addr: llAddr4, discovered: false}: 1, - {nicID: nicID2, addr: llAddr3, discovered: false}: 1, - {nicID: nicID2, addr: llAddr4, discovered: false}: 1, + expectedOffLinkRouteEvents := map[ndpOffLinkRouteEvent]int{ + {nicID: nicID1, subnet: header.IPv6EmptySubnet, router: llAddr3, updated: false}: 1, + {nicID: nicID1, subnet: header.IPv6EmptySubnet, router: llAddr4, updated: false}: 1, + {nicID: nicID2, subnet: header.IPv6EmptySubnet, router: llAddr3, updated: false}: 1, + {nicID: nicID2, subnet: header.IPv6EmptySubnet, router: llAddr4, updated: false}: 1, } - if diff := cmp.Diff(expectedRouterEvents, gotRouterEvents); diff != "" { - t.Errorf("router events mismatch (-want +got):\n%s", diff) + if diff := cmp.Diff(expectedOffLinkRouteEvents, gotOffLinkRouteEvents); diff != "" { + t.Errorf("off-link route events mismatch (-want +got):\n%s", diff) } expectedPrefixEvents := map[ndpPrefixEvent]int{ {nicID: nicID1, prefix: subnet1, discovered: false}: 1, @@ -5137,8 +5071,8 @@ func TestCleanupNDPState(t *testing.T) { // cancelled when the NDP state was cleaned up). clock.Advance(lifetimeSeconds * time.Second) select { - case <-ndpDisp.routerC: - t.Error("unexpected router event") + case <-ndpDisp.offLinkRouteC: + t.Error("unexpected off-link route event") default: } select { @@ -5163,7 +5097,6 @@ func TestDHCPv6ConfigurationFromNDPDA(t *testing.T) { ndpDisp := ndpDispatcher{ dhcpv6ConfigurationC: make(chan ndpDHCPv6Event, 1), - rememberRouter: true, } e := channel.New(0, 1280, linkAddr1) s := stack.New(stack.Options{ @@ -5464,16 +5397,25 @@ func TestRouterSolicitation(t *testing.T) { RandSource: &randSource, }) - if err := s.CreateNIC(nicID, &e); err != nil { - t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) + opts := stack.NICOptions{Disabled: true} + if err := s.CreateNICWithOptions(nicID, &e, opts); err != nil { + t.Fatalf("CreateNICWithOptions(%d, _, %#v) = %s", nicID, opts, err) } if addr := test.nicAddr; addr != "" { - if err := s.AddAddress(nicID, header.IPv6ProtocolNumber, addr); err != nil { - t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, addr, err) + protocolAddr := tcpip.ProtocolAddress{ + Protocol: header.IPv6ProtocolNumber, + AddressWithPrefix: addr.WithPrefix(), + } + if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err) } } + if err := s.EnableNIC(nicID); err != nil { + t.Fatalf("EnableNIC(%d): %s", nicID, err) + } + // Make sure each RS is sent at the right time. remaining := test.maxRtrSolicit if remaining != 0 { diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go index 378389db2..29d580e76 100644 --- a/pkg/tcpip/stack/nic.go +++ b/pkg/tcpip/stack/nic.go @@ -72,9 +72,15 @@ type nic struct { sync.RWMutex spoofing bool promiscuous bool - // packetEPs is protected by mu, but the contained packetEndpointList are - // not. - packetEPs map[tcpip.NetworkProtocolNumber]*packetEndpointList + } + + packetEPs struct { + mu sync.RWMutex + + // eps is protected by the mutex, but the values contained in it are not. + // + // +checklocks:mu + eps map[tcpip.NetworkProtocolNumber]*packetEndpointList } } @@ -91,6 +97,8 @@ type packetEndpointList struct { mu sync.RWMutex // eps is protected by mu, but the contained PacketEndpoint values are not. + // + // +checklocks:mu eps []PacketEndpoint } @@ -111,6 +119,12 @@ func (p *packetEndpointList) remove(ep PacketEndpoint) { } } +func (p *packetEndpointList) len() int { + p.mu.RLock() + defer p.mu.RUnlock() + return len(p.eps) +} + // forEach calls fn with each endpoints in p while holding the read lock on p. func (p *packetEndpointList) forEach(fn func(PacketEndpoint)) { p.mu.RLock() @@ -143,18 +157,16 @@ func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint, ctx NICC duplicateAddressDetectors: make(map[tcpip.NetworkProtocolNumber]DuplicateAddressDetector), } nic.linkResQueue.init(nic) - nic.mu.packetEPs = make(map[tcpip.NetworkProtocolNumber]*packetEndpointList) + + nic.packetEPs.mu.Lock() + defer nic.packetEPs.mu.Unlock() + + nic.packetEPs.eps = make(map[tcpip.NetworkProtocolNumber]*packetEndpointList) resolutionRequired := ep.Capabilities()&CapabilityResolutionRequired != 0 - // Register supported packet and network endpoint protocols. - for _, netProto := range header.Ethertypes { - nic.mu.packetEPs[netProto] = new(packetEndpointList) - } for _, netProto := range stack.networkProtocols { netNum := netProto.Number() - nic.mu.packetEPs[netNum] = new(packetEndpointList) - netEP := netProto.NewEndpoint(nic, nic) nic.networkEndpoints[netNum] = netEP @@ -365,6 +377,8 @@ func (n *nic) writePacket(r RouteInfo, protocol tcpip.NetworkProtocolNumber, pkt pkt.EgressRoute = r pkt.NetworkProtocolNumber = protocol + n.deliverOutboundPacket(r.RemoteLinkAddress, pkt) + if err := n.LinkEndpoint.WritePacket(r, protocol, pkt); err != nil { return err } @@ -383,6 +397,7 @@ func (n *nic) writePackets(r RouteInfo, protocol tcpip.NetworkProtocolNumber, pk for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() { pkt.EgressRoute = r pkt.NetworkProtocolNumber = protocol + n.deliverOutboundPacket(r.RemoteLinkAddress, pkt) } writtenPackets, err := n.LinkEndpoint.WritePackets(r, pkts, protocol) @@ -501,7 +516,7 @@ func (n *nic) getAddressOrCreateTempInner(protocol tcpip.NetworkProtocolNumber, // addAddress adds a new address to n, so that it starts accepting packets // targeted at the given address (and network protocol). -func (n *nic) addAddress(protocolAddress tcpip.ProtocolAddress, peb PrimaryEndpointBehavior) tcpip.Error { +func (n *nic) addAddress(protocolAddress tcpip.ProtocolAddress, properties AddressProperties) tcpip.Error { ep, ok := n.networkEndpoints[protocolAddress.Protocol] if !ok { return &tcpip.ErrUnknownProtocol{} @@ -512,7 +527,7 @@ func (n *nic) addAddress(protocolAddress tcpip.ProtocolAddress, peb PrimaryEndpo return &tcpip.ErrNotSupported{} } - addressEndpoint, err := addressableEndpoint.AddAndAcquirePermanentAddress(protocolAddress.AddressWithPrefix, peb, AddressConfigStatic, false /* deprecated */) + addressEndpoint, err := addressableEndpoint.AddAndAcquirePermanentAddress(protocolAddress.AddressWithPrefix, properties) if err == nil { // We have no need for the address endpoint. addressEndpoint.DecRef() @@ -699,12 +714,9 @@ func (n *nic) isInGroup(addr tcpip.Address) bool { // This rule applies only to the slice itself, not to the items of the slice; // the ownership of the items is not retained by the caller. func (n *nic) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) { - n.mu.RLock() enabled := n.Enabled() // If the NIC is not yet enabled, don't receive any packets. if !enabled { - n.mu.RUnlock() - n.stats.disabledRx.packets.Increment() n.stats.disabledRx.bytes.IncrementBy(uint64(pkt.Data().Size())) return @@ -715,7 +727,6 @@ func (n *nic) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcp networkEndpoint, ok := n.networkEndpoints[protocol] if !ok { - n.mu.RUnlock() n.stats.unknownL3ProtocolRcvdPackets.Increment() return } @@ -727,44 +738,87 @@ func (n *nic) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcp } pkt.RXTransportChecksumValidated = n.LinkEndpoint.Capabilities()&CapabilityRXChecksumOffload != 0 - // Are any packet type sockets listening for this network protocol? - protoEPs := n.mu.packetEPs[protocol] - // Other packet type sockets that are listening for all protocols. - anyEPs := n.mu.packetEPs[header.EthernetProtocolAll] - n.mu.RUnlock() - // Deliver to interested packet endpoints without holding NIC lock. + var packetEPPkt *PacketBuffer deliverPacketEPs := func(ep PacketEndpoint) { - p := pkt.Clone() - p.PktType = tcpip.PacketHost - ep.HandlePacket(n.id, local, protocol, p) + if packetEPPkt == nil { + // Packet endpoints hold the full packet. + // + // We perform a deep copy because higher-level endpoints may point to + // the middle of a view that is held by a packet endpoint. Save/Restore + // does not support overlapping slices and will panic in this case. + // + // TODO(https://gvisor.dev/issue/6517): Avoid this copy once S/R supports + // overlapping slices (e.g. by passing a shallow copy of pkt to the packet + // endpoint). + packetEPPkt = NewPacketBuffer(PacketBufferOptions{ + Data: PayloadSince(pkt.LinkHeader()).ToVectorisedView(), + }) + // If a link header was populated in the original packet buffer, then + // populate it in the packet buffer we provide to packet endpoints as + // packet endpoints inspect link headers. + packetEPPkt.LinkHeader().Consume(pkt.LinkHeader().View().Size()) + packetEPPkt.PktType = tcpip.PacketHost + } + + ep.HandlePacket(n.id, local, protocol, packetEPPkt.Clone()) } - if protoEPs != nil { + + n.packetEPs.mu.Lock() + // Are any packet type sockets listening for this network protocol? + protoEPs, protoEPsOK := n.packetEPs.eps[protocol] + // Other packet type sockets that are listening for all protocols. + anyEPs, anyEPsOK := n.packetEPs.eps[header.EthernetProtocolAll] + n.packetEPs.mu.Unlock() + + if protoEPsOK { protoEPs.forEach(deliverPacketEPs) } - if anyEPs != nil { + if anyEPsOK { anyEPs.forEach(deliverPacketEPs) } networkEndpoint.HandlePacket(pkt) } -// DeliverOutboundPacket implements NetworkDispatcher.DeliverOutboundPacket. -func (n *nic) DeliverOutboundPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) { - n.mu.RLock() +// deliverOutboundPacket delivers outgoing packets to interested endpoints. +func (n *nic) deliverOutboundPacket(remote tcpip.LinkAddress, pkt *PacketBuffer) { + n.packetEPs.mu.RLock() + defer n.packetEPs.mu.RUnlock() // We do not deliver to protocol specific packet endpoints as on Linux // only ETH_P_ALL endpoints get outbound packets. // Add any other packet sockets that maybe listening for all protocols. - eps := n.mu.packetEPs[header.EthernetProtocolAll] - n.mu.RUnlock() + eps, ok := n.packetEPs.eps[header.EthernetProtocolAll] + if !ok { + return + } + + local := n.LinkAddress() + var packetEPPkt *PacketBuffer eps.forEach(func(ep PacketEndpoint) { - p := pkt.Clone() - p.PktType = tcpip.PacketOutgoing - // Add the link layer header as outgoing packets are intercepted - // before the link layer header is created. - n.LinkEndpoint.AddHeader(local, remote, protocol, p) - ep.HandlePacket(n.id, local, protocol, p) + if packetEPPkt == nil { + // Packet endpoints hold the full packet. + // + // We perform a deep copy because higher-level endpoints may point to + // the middle of a view that is held by a packet endpoint. Save/Restore + // does not support overlapping slices and will panic in this case. + // + // TODO(https://gvisor.dev/issue/6517): Avoid this copy once S/R supports + // overlapping slices (e.g. by passing a shallow copy of pkt to the packet + // endpoint). + packetEPPkt = NewPacketBuffer(PacketBufferOptions{ + ReserveHeaderBytes: pkt.AvailableHeaderBytes(), + Data: PayloadSince(pkt.NetworkHeader()).ToVectorisedView(), + }) + // Add the link layer header as outgoing packets are intercepted before + // the link layer header is created and packet endpoints are interested + // in the link header. + n.LinkEndpoint.AddHeader(local, remote, pkt.NetworkProtocolNumber, packetEPPkt) + packetEPPkt.PktType = tcpip.PacketOutgoing + } + + ep.HandlePacket(n.id, local, pkt.NetworkProtocolNumber, packetEPPkt.Clone()) }) } @@ -779,17 +833,11 @@ func (n *nic) DeliverTransportPacket(protocol tcpip.TransportProtocolNumber, pkt transProto := state.proto - // Raw socket packets are delivered based solely on the transport - // protocol number. We do not inspect the payload to ensure it's - // validly formed. - n.stack.demux.deliverRawPacket(protocol, pkt) - // TransportHeader is empty only when pkt is an ICMP packet or was reassembled // from fragments. if pkt.TransportHeader().View().IsEmpty() { - // TODO(gvisor.dev/issue/170): ICMP packets don't have their TransportHeader - // fields set yet, parse it here. See icmp/protocol.go:protocol.Parse for a - // full explanation. + // ICMP packets don't have their TransportHeader fields set yet, parse it + // here. See icmp/protocol.go:protocol.Parse for a full explanation. if protocol == header.ICMPv4ProtocolNumber || protocol == header.ICMPv6ProtocolNumber { // ICMP packets may be longer, but until icmp.Parse is implemented, here // we parse it using the minimum size. @@ -878,6 +926,17 @@ func (n *nic) DeliverTransportError(local, remote tcpip.Address, net tcpip.Netwo } } +// DeliverRawPacket implements TransportDispatcher. +func (n *nic) DeliverRawPacket(protocol tcpip.TransportProtocolNumber, pkt *PacketBuffer) { + // For ICMPv4 only we validate the header length for compatibility with + // raw(7) ICMP_FILTER. The same check is made in Linux here: + // https://github.com/torvalds/linux/blob/70585216/net/ipv4/raw.c#L189. + if protocol == header.ICMPv4ProtocolNumber && pkt.TransportHeader().View().Size()+pkt.Data().Size() < header.ICMPv4MinimumSize { + return + } + n.stack.demux.deliverRawPacket(protocol, pkt) +} + // ID implements NetworkInterface. func (n *nic) ID() tcpip.NICID { return n.id @@ -912,12 +971,13 @@ func (n *nic) setNUDConfigs(protocol tcpip.NetworkProtocolNumber, c NUDConfigura } func (n *nic) registerPacketEndpoint(netProto tcpip.NetworkProtocolNumber, ep PacketEndpoint) tcpip.Error { - n.mu.Lock() - defer n.mu.Unlock() + n.packetEPs.mu.Lock() + defer n.packetEPs.mu.Unlock() - eps, ok := n.mu.packetEPs[netProto] + eps, ok := n.packetEPs.eps[netProto] if !ok { - return &tcpip.ErrNotSupported{} + eps = new(packetEndpointList) + n.packetEPs.eps[netProto] = eps } eps.add(ep) @@ -925,14 +985,17 @@ func (n *nic) registerPacketEndpoint(netProto tcpip.NetworkProtocolNumber, ep Pa } func (n *nic) unregisterPacketEndpoint(netProto tcpip.NetworkProtocolNumber, ep PacketEndpoint) { - n.mu.Lock() - defer n.mu.Unlock() + n.packetEPs.mu.Lock() + defer n.packetEPs.mu.Unlock() - eps, ok := n.mu.packetEPs[netProto] + eps, ok := n.packetEPs.eps[netProto] if !ok { return } eps.remove(ep) + if eps.len() == 0 { + delete(n.packetEPs.eps, netProto) + } } // isValidForOutgoing returns true if the endpoint can be used to send out a diff --git a/pkg/tcpip/stack/nic_test.go b/pkg/tcpip/stack/nic_test.go index 5cb342f78..c8ad93f29 100644 --- a/pkg/tcpip/stack/nic_test.go +++ b/pkg/tcpip/stack/nic_test.go @@ -127,11 +127,6 @@ func (*testIPv6Protocol) MinimumPacketSize() int { return header.IPv6MinimumSize } -// DefaultPrefixLen implements NetworkProtocol.DefaultPrefixLen. -func (*testIPv6Protocol) DefaultPrefixLen() int { - return header.IPv6AddressSize * 8 -} - // ParseAddresses implements NetworkProtocol.ParseAddresses. func (*testIPv6Protocol) ParseAddresses(v buffer.View) (src, dst tcpip.Address) { h := header.IPv6(v) diff --git a/pkg/tcpip/stack/packet_buffer.go b/pkg/tcpip/stack/packet_buffer.go index 9192d8433..29c22bfd4 100644 --- a/pkg/tcpip/stack/packet_buffer.go +++ b/pkg/tcpip/stack/packet_buffer.go @@ -282,14 +282,12 @@ func (pk *PacketBuffer) headerView(typ headerType) tcpipbuffer.View { return v } -// Clone makes a shallow copy of pk. -// -// Clone should be called in such cases so that no modifications is done to -// underlying packet payload. +// Clone makes a semi-deep copy of pk. The underlying packet payload is +// shared. Hence, no modifications is done to underlying packet payload. func (pk *PacketBuffer) Clone() *PacketBuffer { return &PacketBuffer{ PacketBufferEntry: pk.PacketBufferEntry, - buf: pk.buf, + buf: pk.buf.Clone(), reserved: pk.reserved, pushed: pk.pushed, consumed: pk.consumed, @@ -321,14 +319,14 @@ func (pk *PacketBuffer) Network() header.Network { } } -// CloneToInbound makes a shallow copy of the packet buffer to be used as an -// inbound packet. +// CloneToInbound makes a semi-deep copy of the packet buffer (similar to +// Clone) to be used as an inbound packet. // // See PacketBuffer.Data for details about how a packet buffer holds an inbound // packet. func (pk *PacketBuffer) CloneToInbound() *PacketBuffer { newPk := &PacketBuffer{ - buf: pk.buf, + buf: pk.buf.Clone(), // Treat unfilled header portion as reserved. reserved: pk.AvailableHeaderBytes(), } diff --git a/pkg/tcpip/stack/packet_buffer_test.go b/pkg/tcpip/stack/packet_buffer_test.go index a8da34992..87b023445 100644 --- a/pkg/tcpip/stack/packet_buffer_test.go +++ b/pkg/tcpip/stack/packet_buffer_test.go @@ -123,6 +123,32 @@ func TestPacketHeaderPush(t *testing.T) { } } +func TestPacketBufferClone(t *testing.T) { + data := concatViews(makeView(20), makeView(30), makeView(40)) + pk := NewPacketBuffer(PacketBufferOptions{ + // Make a copy of data to make sure our truth data won't be taint by + // PacketBuffer. + Data: buffer.NewViewFromBytes(data).ToVectorisedView(), + }) + + bytesToDelete := 30 + originalSize := data.Size() + + clonedPks := []*PacketBuffer{ + pk.Clone(), + pk.CloneToInbound(), + } + pk.Data().DeleteFront(bytesToDelete) + if got, want := pk.Data().Size(), originalSize-bytesToDelete; got != want { + t.Errorf("original packet was not changed: size expected = %d, got = %d", want, got) + } + for _, clonedPk := range clonedPks { + if got := clonedPk.Data().Size(); got != originalSize { + t.Errorf("cloned packet should not be modified: expected size = %d, got = %d", originalSize, got) + } + } +} + func TestPacketHeaderConsume(t *testing.T) { for _, test := range []struct { name string diff --git a/pkg/tcpip/stack/registration.go b/pkg/tcpip/stack/registration.go index a038389e0..31b3a554d 100644 --- a/pkg/tcpip/stack/registration.go +++ b/pkg/tcpip/stack/registration.go @@ -265,6 +265,11 @@ type TransportDispatcher interface { // // DeliverTransportError takes ownership of the packet buffer. DeliverTransportError(local, remote tcpip.Address, _ tcpip.NetworkProtocolNumber, _ tcpip.TransportProtocolNumber, _ TransportError, _ *PacketBuffer) + + // DeliverRawPacket delivers a packet to any subscribed raw sockets. + // + // DeliverRawPacket does NOT take ownership of the packet buffer. + DeliverRawPacket(tcpip.TransportProtocolNumber, *PacketBuffer) } // PacketLooping specifies where an outbound packet should be sent. @@ -313,8 +318,7 @@ type PrimaryEndpointBehavior int const ( // CanBePrimaryEndpoint indicates the endpoint can be used as a primary - // endpoint for new connections with no local address. This is the - // default when calling NIC.AddAddress. + // endpoint for new connections with no local address. CanBePrimaryEndpoint PrimaryEndpointBehavior = iota // FirstPrimaryEndpoint indicates the endpoint should be the first @@ -327,6 +331,19 @@ const ( NeverPrimaryEndpoint ) +func (peb PrimaryEndpointBehavior) String() string { + switch peb { + case CanBePrimaryEndpoint: + return "CanBePrimaryEndpoint" + case FirstPrimaryEndpoint: + return "FirstPrimaryEndpoint" + case NeverPrimaryEndpoint: + return "NeverPrimaryEndpoint" + default: + panic(fmt.Sprintf("unknown primary endpoint behavior: %d", peb)) + } +} + // AddressConfigType is the method used to add an address. type AddressConfigType int @@ -346,6 +363,14 @@ const ( AddressConfigSlaacTemp ) +// AddressProperties contains additional properties that can be configured when +// adding an address. +type AddressProperties struct { + PEB PrimaryEndpointBehavior + ConfigType AddressConfigType + Deprecated bool +} + // AssignableAddressEndpoint is a reference counted address endpoint that may be // assigned to a NetworkEndpoint. type AssignableAddressEndpoint interface { @@ -452,7 +477,7 @@ type AddressableEndpoint interface { // Returns *tcpip.ErrDuplicateAddress if the address exists. // // Acquires and returns the AddressEndpoint for the added address. - AddAndAcquirePermanentAddress(addr tcpip.AddressWithPrefix, peb PrimaryEndpointBehavior, configType AddressConfigType, deprecated bool) (AddressEndpoint, tcpip.Error) + AddAndAcquirePermanentAddress(addr tcpip.AddressWithPrefix, properties AddressProperties) (AddressEndpoint, tcpip.Error) // RemovePermanentAddress removes the passed address if it is a permanent // address. @@ -680,9 +705,6 @@ type NetworkProtocol interface { // than this targeted at this protocol. MinimumPacketSize() int - // DefaultPrefixLen returns the protocol's default prefix length. - DefaultPrefixLen() int - // ParseAddresses returns the source and destination addresses stored in a // packet of this protocol. ParseAddresses(v buffer.View) (src, dst tcpip.Address) @@ -728,16 +750,6 @@ type NetworkDispatcher interface { // // DeliverNetworkPacket takes ownership of pkt. DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) - - // DeliverOutboundPacket is called by link layer when a packet is being - // sent out. - // - // pkt.LinkHeader may or may not be set before calling - // DeliverOutboundPacket. Some packets do not have link headers (e.g. - // packets sent via loopback), and won't have the field set. - // - // DeliverOutboundPacket takes ownership of pkt. - DeliverOutboundPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) } // LinkEndpointCapabilities is the type associated with the capabilities @@ -841,6 +853,14 @@ type LinkEndpoint interface { // offload is enabled. If it will be used for something else, syscall filters // may need to be updated. WritePackets(RouteInfo, PacketBufferList, tcpip.NetworkProtocolNumber) (int, tcpip.Error) + + // WriteRawPacket writes a packet directly to the link. + // + // If the link-layer has its own header, the payload must already include the + // header. + // + // WriteRawPacket takes ownership of the packet. + WriteRawPacket(*PacketBuffer) tcpip.Error } // InjectableLinkEndpoint is a LinkEndpoint where inbound packets are diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go index 40d277312..98867a828 100644 --- a/pkg/tcpip/stack/stack.go +++ b/pkg/tcpip/stack/stack.go @@ -72,7 +72,8 @@ type Stack struct { // rawFactory creates raw endpoints. If nil, raw endpoints are // disabled. It is set during Stack creation and is immutable. - rawFactory RawFactory + rawFactory RawFactory + packetEndpointWriteSupported bool demux *transportDemuxer @@ -108,7 +109,7 @@ type Stack struct { handleLocal bool // tables are the iptables packet filtering and manipulation rules. - // TODO(gvisor.dev/issue/170): S/R this field. + // TODO(gvisor.dev/issue/4595): S/R this field. tables *IPTables // resumableEndpoints is a list of endpoints that need to be resumed if the @@ -119,8 +120,7 @@ type Stack struct { // by the stack. icmpRateLimiter *ICMPRateLimiter - // seed is a one-time random value initialized at stack startup - // and is used to seed the TCP port picking on active connections + // seed is a one-time random value initialized at stack startup. // // TODO(gvisor.dev/issue/940): S/R this field. seed uint32 @@ -161,6 +161,10 @@ type Stack struct { // This is required to prevent potential ACK loops. // Setting this to 0 will disable all rate limiting. tcpInvalidRateLimit time.Duration + + // tsOffsetSecret is the secret key for generating timestamp offsets + // initialized at stack startup. + tsOffsetSecret uint32 } // UniqueID is an abstract generator of unique identifiers. @@ -215,6 +219,10 @@ type Options struct { // this is non-nil. RawFactory RawFactory + // AllowPacketEndpointWrite determines if packet endpoints support write + // operations. + AllowPacketEndpointWrite bool + // RandSource is an optional source to use to generate random // numbers. If omitted it defaults to a Source seeded by the data // returned by the stack secure RNG. @@ -356,23 +364,24 @@ func New(opts Options) *Stack { opts.NUDConfigs.resetInvalidFields() s := &Stack{ - transportProtocols: make(map[tcpip.TransportProtocolNumber]*transportProtocolState), - networkProtocols: make(map[tcpip.NetworkProtocolNumber]NetworkProtocol), - nics: make(map[tcpip.NICID]*nic), - defaultForwardingEnabled: make(map[tcpip.NetworkProtocolNumber]struct{}), - cleanupEndpoints: make(map[TransportEndpoint]struct{}), - PortManager: ports.NewPortManager(), - clock: clock, - stats: opts.Stats.FillIn(), - handleLocal: opts.HandleLocal, - tables: opts.IPTables, - icmpRateLimiter: NewICMPRateLimiter(), - seed: seed, - nudConfigs: opts.NUDConfigs, - uniqueIDGenerator: opts.UniqueID, - nudDisp: opts.NUDDisp, - randomGenerator: randomGenerator, - secureRNG: opts.SecureRNG, + transportProtocols: make(map[tcpip.TransportProtocolNumber]*transportProtocolState), + networkProtocols: make(map[tcpip.NetworkProtocolNumber]NetworkProtocol), + nics: make(map[tcpip.NICID]*nic), + packetEndpointWriteSupported: opts.AllowPacketEndpointWrite, + defaultForwardingEnabled: make(map[tcpip.NetworkProtocolNumber]struct{}), + cleanupEndpoints: make(map[TransportEndpoint]struct{}), + PortManager: ports.NewPortManager(), + clock: clock, + stats: opts.Stats.FillIn(), + handleLocal: opts.HandleLocal, + tables: opts.IPTables, + icmpRateLimiter: NewICMPRateLimiter(), + seed: seed, + nudConfigs: opts.NUDConfigs, + uniqueIDGenerator: opts.UniqueID, + nudDisp: opts.NUDDisp, + randomGenerator: randomGenerator, + secureRNG: opts.SecureRNG, sendBufferSize: tcpip.SendBufferSizeOption{ Min: MinBufferSize, Default: DefaultBufferSize, @@ -384,6 +393,7 @@ func New(opts Options) *Stack { Max: DefaultMaxBufferSize, }, tcpInvalidRateLimit: defaultTCPInvalidRateLimit, + tsOffsetSecret: randomGenerator.Uint32(), } // Add specified network protocols. @@ -780,6 +790,9 @@ func (s *Stack) removeNICLocked(id tcpip.NICID) tcpip.Error { if !ok { return &tcpip.ErrUnknownNICID{} } + if nic.IsLoopback() { + return &tcpip.ErrNotSupported{} + } delete(s.nics, id) // Remove routes in-place. n tracks the number of routes written. @@ -903,46 +916,9 @@ type NICStateFlags struct { Loopback bool } -// AddAddress adds a new network-layer address to the specified NIC. -func (s *Stack) AddAddress(id tcpip.NICID, protocol tcpip.NetworkProtocolNumber, addr tcpip.Address) tcpip.Error { - return s.AddAddressWithOptions(id, protocol, addr, CanBePrimaryEndpoint) -} - -// AddAddressWithPrefix is the same as AddAddress, but allows you to specify -// the address prefix. -func (s *Stack) AddAddressWithPrefix(id tcpip.NICID, protocol tcpip.NetworkProtocolNumber, addr tcpip.AddressWithPrefix) tcpip.Error { - ap := tcpip.ProtocolAddress{ - Protocol: protocol, - AddressWithPrefix: addr, - } - return s.AddProtocolAddressWithOptions(id, ap, CanBePrimaryEndpoint) -} - -// AddProtocolAddress adds a new network-layer protocol address to the -// specified NIC. -func (s *Stack) AddProtocolAddress(id tcpip.NICID, protocolAddress tcpip.ProtocolAddress) tcpip.Error { - return s.AddProtocolAddressWithOptions(id, protocolAddress, CanBePrimaryEndpoint) -} - -// AddAddressWithOptions is the same as AddAddress, but allows you to specify -// whether the new endpoint can be primary or not. -func (s *Stack) AddAddressWithOptions(id tcpip.NICID, protocol tcpip.NetworkProtocolNumber, addr tcpip.Address, peb PrimaryEndpointBehavior) tcpip.Error { - netProto, ok := s.networkProtocols[protocol] - if !ok { - return &tcpip.ErrUnknownProtocol{} - } - return s.AddProtocolAddressWithOptions(id, tcpip.ProtocolAddress{ - Protocol: protocol, - AddressWithPrefix: tcpip.AddressWithPrefix{ - Address: addr, - PrefixLen: netProto.DefaultPrefixLen(), - }, - }, peb) -} - -// AddProtocolAddressWithOptions is the same as AddProtocolAddress, but allows -// you to specify whether the new endpoint can be primary or not. -func (s *Stack) AddProtocolAddressWithOptions(id tcpip.NICID, protocolAddress tcpip.ProtocolAddress, peb PrimaryEndpointBehavior) tcpip.Error { +// AddProtocolAddress adds an address to the specified NIC, possibly with extra +// properties. +func (s *Stack) AddProtocolAddress(id tcpip.NICID, protocolAddress tcpip.ProtocolAddress, properties AddressProperties) tcpip.Error { s.mu.RLock() defer s.mu.RUnlock() @@ -951,7 +927,7 @@ func (s *Stack) AddProtocolAddressWithOptions(id tcpip.NICID, protocolAddress tc return &tcpip.ErrUnknownNICID{} } - return nic.addAddress(protocolAddress, peb) + return nic.addAddress(protocolAddress, properties) } // RemoveAddress removes an existing network-layer address from the specified @@ -1646,9 +1622,27 @@ func (s *Stack) WritePacketToRemote(nicID tcpip.NICID, remote tcpip.LinkAddress, ReserveHeaderBytes: int(nic.MaxHeaderLength()), Data: payload, }) + pkt.NetworkProtocolNumber = netProto return nic.WritePacketToRemote(remote, netProto, pkt) } +// WriteRawPacket writes data directly to the specified NIC without adding any +// headers. +func (s *Stack) WriteRawPacket(nicID tcpip.NICID, proto tcpip.NetworkProtocolNumber, payload buffer.VectorisedView) tcpip.Error { + s.mu.RLock() + nic, ok := s.nics[nicID] + s.mu.RUnlock() + if !ok { + return &tcpip.ErrUnknownNICID{} + } + + pkt := NewPacketBuffer(PacketBufferOptions{ + Data: payload, + }) + pkt.NetworkProtocolNumber = proto + return nic.WriteRawPacket(pkt) +} + // NetworkProtocolInstance returns the protocol instance in the stack for the // specified network protocol. This method is public for protocol implementers // and tests to use. @@ -1816,8 +1810,7 @@ func (s *Stack) SetNUDConfigurations(id tcpip.NICID, proto tcpip.NetworkProtocol return nic.setNUDConfigs(proto, c) } -// Seed returns a 32 bit value that can be used as a seed value for port -// picking, ISN generation etc. +// Seed returns a 32 bit value that can be used as a seed value. // // NOTE: The seed is generated once during stack initialization only. func (s *Stack) Seed() uint32 { @@ -1872,9 +1865,8 @@ const ( // ParsePacketBufferTransport parses the provided packet buffer's transport // header. func (s *Stack) ParsePacketBufferTransport(protocol tcpip.TransportProtocolNumber, pkt *PacketBuffer) ParseResult { - // TODO(gvisor.dev/issue/170): ICMP packets don't have their TransportHeader - // fields set yet, parse it here. See icmp/protocol.go:protocol.Parse for a - // full explanation. + // ICMP packets don't have their TransportHeader fields set yet, parse it + // here. See icmp/protocol.go:protocol.Parse for a full explanation. if protocol == header.ICMPv4ProtocolNumber || protocol == header.ICMPv6ProtocolNumber { return ParsedOK } @@ -1942,3 +1934,9 @@ func (s *Stack) IsSubnetBroadcast(nicID tcpip.NICID, protocol tcpip.NetworkProto return false } + +// PacketEndpointWriteSupported returns true iff packet endpoints support write +// operations. +func (s *Stack) PacketEndpointWriteSupported() bool { + return s.packetEndpointWriteSupported +} diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go index 21951d05a..cd4137794 100644 --- a/pkg/tcpip/stack/stack_test.go +++ b/pkg/tcpip/stack/stack_test.go @@ -234,10 +234,6 @@ func (*fakeNetworkProtocol) MinimumPacketSize() int { return fakeNetHeaderLen } -func (*fakeNetworkProtocol) DefaultPrefixLen() int { - return fakeDefaultPrefixLen -} - func (f *fakeNetworkProtocol) PacketCount(intfAddr byte) int { return f.packetCount[int(intfAddr)%len(f.packetCount)] } @@ -349,12 +345,26 @@ func TestNetworkReceive(t *testing.T) { t.Fatal("CreateNIC failed:", err) } - if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil { - t.Fatal("AddAddress failed:", err) + protocolAddr1 := tcpip.ProtocolAddress{ + Protocol: fakeNetNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: "\x01", + PrefixLen: fakeDefaultPrefixLen, + }, + } + if err := s.AddProtocolAddress(1, protocolAddr1, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", 1, protocolAddr1, err) } - if err := s.AddAddress(1, fakeNetNumber, "\x02"); err != nil { - t.Fatal("AddAddress failed:", err) + protocolAddr2 := tcpip.ProtocolAddress{ + Protocol: fakeNetNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: "\x02", + PrefixLen: fakeDefaultPrefixLen, + }, + } + if err := s.AddProtocolAddress(1, protocolAddr2, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", 1, protocolAddr2, err) } fakeNet := s.NetworkProtocolInstance(fakeNetNumber).(*fakeNetworkProtocol) @@ -517,8 +527,15 @@ func TestNetworkSend(t *testing.T) { s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: 1}}) } - if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil { - t.Fatal("AddAddress failed:", err) + protocolAddr := tcpip.ProtocolAddress{ + Protocol: fakeNetNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: "\x01", + PrefixLen: fakeDefaultPrefixLen, + }, + } + if err := s.AddProtocolAddress(1, protocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", 1, protocolAddr, err) } // Make sure that the link-layer endpoint received the outbound packet. @@ -538,12 +555,26 @@ func TestNetworkSendMultiRoute(t *testing.T) { t.Fatal("CreateNIC failed:", err) } - if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil { - t.Fatal("AddAddress failed:", err) + protocolAddr1 := tcpip.ProtocolAddress{ + Protocol: fakeNetNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: "\x01", + PrefixLen: fakeDefaultPrefixLen, + }, + } + if err := s.AddProtocolAddress(1, protocolAddr1, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", 1, protocolAddr1, err) } - if err := s.AddAddress(1, fakeNetNumber, "\x03"); err != nil { - t.Fatal("AddAddress failed:", err) + protocolAddr3 := tcpip.ProtocolAddress{ + Protocol: fakeNetNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: "\x03", + PrefixLen: fakeDefaultPrefixLen, + }, + } + if err := s.AddProtocolAddress(1, protocolAddr3, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", 1, protocolAddr3, err) } ep2 := channel.New(10, defaultMTU, "") @@ -551,12 +582,26 @@ func TestNetworkSendMultiRoute(t *testing.T) { t.Fatal("CreateNIC failed:", err) } - if err := s.AddAddress(2, fakeNetNumber, "\x02"); err != nil { - t.Fatal("AddAddress failed:", err) + protocolAddr2 := tcpip.ProtocolAddress{ + Protocol: fakeNetNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: "\x02", + PrefixLen: fakeDefaultPrefixLen, + }, + } + if err := s.AddProtocolAddress(2, protocolAddr2, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", 2, protocolAddr2, err) } - if err := s.AddAddress(2, fakeNetNumber, "\x04"); err != nil { - t.Fatal("AddAddress failed:", err) + protocolAddr4 := tcpip.ProtocolAddress{ + Protocol: fakeNetNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: "\x04", + PrefixLen: fakeDefaultPrefixLen, + }, + } + if err := s.AddProtocolAddress(2, protocolAddr4, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", 2, protocolAddr4, err) } // Set a route table that sends all packets with odd destination @@ -719,38 +764,59 @@ func TestRemoveUnknownNIC(t *testing.T) { } func TestRemoveNIC(t *testing.T) { - const nicID = 1 + for _, tt := range []struct { + name string + linkep stack.LinkEndpoint + expectErr tcpip.Error + }{ + { + name: "loopback", + linkep: loopback.New(), + expectErr: &tcpip.ErrNotSupported{}, + }, + { + name: "channel", + linkep: channel.New(0, defaultMTU, ""), + expectErr: nil, + }, + } { + t.Run(tt.name, func(t *testing.T) { + const nicID = 1 - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, - }) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, + }) - e := linkEPWithMockedAttach{ - LinkEndpoint: loopback.New(), - } - if err := s.CreateNIC(nicID, &e); err != nil { - t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) - } + e := linkEPWithMockedAttach{ + LinkEndpoint: tt.linkep, + } + if err := s.CreateNIC(nicID, &e); err != nil { + t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) + } - // NIC should be present in NICInfo and attached to a NetworkDispatcher. - allNICInfo := s.NICInfo() - if _, ok := allNICInfo[nicID]; !ok { - t.Errorf("entry for %d missing from allNICInfo = %+v", nicID, allNICInfo) - } - if !e.isAttached() { - t.Fatal("link endpoint not attached to a network dispatcher") - } + // NIC should be present in NICInfo and attached to a NetworkDispatcher. + allNICInfo := s.NICInfo() + if _, ok := allNICInfo[nicID]; !ok { + t.Errorf("entry for %d missing from allNICInfo = %+v", nicID, allNICInfo) + } + if !e.isAttached() { + t.Fatal("link endpoint not attached to a network dispatcher") + } - // Removing a NIC should remove it from NICInfo and e should be detached from - // the NetworkDispatcher. - if err := s.RemoveNIC(nicID); err != nil { - t.Fatalf("s.RemoveNIC(%d): %s", nicID, err) - } - if nicInfo, ok := s.NICInfo()[nicID]; ok { - t.Errorf("got unexpected NICInfo entry for deleted NIC %d = %+v", nicID, nicInfo) - } - if e.isAttached() { - t.Error("link endpoint for removed NIC still attached to a network dispatcher") + // Removing a NIC should remove it from NICInfo and e should be detached from + // the NetworkDispatcher. + if got, want := s.RemoveNIC(nicID), tt.expectErr; got != want { + t.Fatalf("got s.RemoveNIC(%d) = %s, want %s", nicID, got, want) + } + if tt.expectErr == nil { + if nicInfo, ok := s.NICInfo()[nicID]; ok { + t.Errorf("got unexpected NICInfo entry for deleted NIC %d = %+v", nicID, nicInfo) + } + if e.isAttached() { + t.Error("link endpoint for removed NIC still attached to a network dispatcher") + } + } + }) } } @@ -791,8 +857,15 @@ func TestRouteWithDownNIC(t *testing.T) { t.Fatalf("CreateNIC(%d, _): %s", nicID1, err) } - if err := s.AddAddress(nicID1, fakeNetNumber, addr1); err != nil { - t.Fatalf("AddAddress(%d, %d, %s): %s", nicID1, fakeNetNumber, addr1, err) + protocolAddr1 := tcpip.ProtocolAddress{ + Protocol: fakeNetNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: addr1, + PrefixLen: fakeDefaultPrefixLen, + }, + } + if err := s.AddProtocolAddress(nicID1, protocolAddr1, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID1, protocolAddr1, err) } ep2 := channel.New(1, defaultMTU, "") @@ -800,8 +873,15 @@ func TestRouteWithDownNIC(t *testing.T) { t.Fatalf("CreateNIC(%d, _): %s", nicID2, err) } - if err := s.AddAddress(nicID2, fakeNetNumber, addr2); err != nil { - t.Fatalf("AddAddress(%d, %d, %s): %s", nicID2, fakeNetNumber, addr2, err) + protocolAddr2 := tcpip.ProtocolAddress{ + Protocol: fakeNetNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: addr2, + PrefixLen: fakeDefaultPrefixLen, + }, + } + if err := s.AddProtocolAddress(nicID2, protocolAddr2, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID2, protocolAddr2, err) } // Set a route table that sends all packets with odd destination @@ -957,12 +1037,26 @@ func TestRoutes(t *testing.T) { t.Fatal("CreateNIC failed:", err) } - if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil { - t.Fatal("AddAddress failed:", err) + protocolAddr1 := tcpip.ProtocolAddress{ + Protocol: fakeNetNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: "\x01", + PrefixLen: fakeDefaultPrefixLen, + }, + } + if err := s.AddProtocolAddress(1, protocolAddr1, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", 1, protocolAddr1, err) } - if err := s.AddAddress(1, fakeNetNumber, "\x03"); err != nil { - t.Fatal("AddAddress failed:", err) + protocolAddr3 := tcpip.ProtocolAddress{ + Protocol: fakeNetNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: "\x03", + PrefixLen: fakeDefaultPrefixLen, + }, + } + if err := s.AddProtocolAddress(1, protocolAddr3, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", 1, protocolAddr3, err) } ep2 := channel.New(10, defaultMTU, "") @@ -970,12 +1064,26 @@ func TestRoutes(t *testing.T) { t.Fatal("CreateNIC failed:", err) } - if err := s.AddAddress(2, fakeNetNumber, "\x02"); err != nil { - t.Fatal("AddAddress failed:", err) + protocolAddr2 := tcpip.ProtocolAddress{ + Protocol: fakeNetNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: "\x02", + PrefixLen: fakeDefaultPrefixLen, + }, + } + if err := s.AddProtocolAddress(2, protocolAddr2, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", 2, protocolAddr2, err) } - if err := s.AddAddress(2, fakeNetNumber, "\x04"); err != nil { - t.Fatal("AddAddress failed:", err) + protocolAddr4 := tcpip.ProtocolAddress{ + Protocol: fakeNetNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: "\x04", + PrefixLen: fakeDefaultPrefixLen, + }, + } + if err := s.AddProtocolAddress(2, protocolAddr4, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", 2, protocolAddr4, err) } // Set a route table that sends all packets with odd destination @@ -1037,8 +1145,15 @@ func TestAddressRemoval(t *testing.T) { t.Fatal("CreateNIC failed:", err) } - if err := s.AddAddress(1, fakeNetNumber, localAddr); err != nil { - t.Fatal("AddAddress failed:", err) + protocolAddr := tcpip.ProtocolAddress{ + Protocol: fakeNetNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: localAddr, + PrefixLen: fakeDefaultPrefixLen, + }, + } + if err := s.AddProtocolAddress(1, protocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", 1, protocolAddr, err) } { subnet, err := tcpip.NewSubnet("\x00", "\x00") @@ -1087,8 +1202,15 @@ func TestAddressRemovalWithRouteHeld(t *testing.T) { fakeNet := s.NetworkProtocolInstance(fakeNetNumber).(*fakeNetworkProtocol) buf := buffer.NewView(30) - if err := s.AddAddress(1, fakeNetNumber, localAddr); err != nil { - t.Fatal("AddAddress failed:", err) + protocolAddr := tcpip.ProtocolAddress{ + Protocol: fakeNetNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: localAddr, + PrefixLen: fakeDefaultPrefixLen, + }, + } + if err := s.AddProtocolAddress(1, protocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", 1, protocolAddr, err) } { subnet, err := tcpip.NewSubnet("\x00", "\x00") @@ -1221,8 +1343,15 @@ func TestEndpointExpiration(t *testing.T) { // 2. Add Address, everything should work. //----------------------- - if err := s.AddAddress(nicID, fakeNetNumber, localAddr); err != nil { - t.Fatal("AddAddress failed:", err) + protocolAddr := tcpip.ProtocolAddress{ + Protocol: fakeNetNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: localAddr, + PrefixLen: fakeDefaultPrefixLen, + }, + } + if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err) } verifyAddress(t, s, nicID, localAddr) testRecv(t, fakeNet, localAddrByte, ep, buf) @@ -1249,8 +1378,8 @@ func TestEndpointExpiration(t *testing.T) { // 4. Add Address back, everything should work again. //----------------------- - if err := s.AddAddress(nicID, fakeNetNumber, localAddr); err != nil { - t.Fatal("AddAddress failed:", err) + if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err) } verifyAddress(t, s, nicID, localAddr) testRecv(t, fakeNet, localAddrByte, ep, buf) @@ -1289,8 +1418,8 @@ func TestEndpointExpiration(t *testing.T) { // 7. Add Address back, everything should work again. //----------------------- - if err := s.AddAddress(nicID, fakeNetNumber, localAddr); err != nil { - t.Fatal("AddAddress failed:", err) + if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err) } verifyAddress(t, s, nicID, localAddr) testRecv(t, fakeNet, localAddrByte, ep, buf) @@ -1432,8 +1561,15 @@ func TestExternalSendWithHandleLocal(t *testing.T) { if err := s.CreateNIC(nicID, ep); err != nil { t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err) } - if err := s.AddAddress(nicID, fakeNetNumber, localAddr); err != nil { - t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID, fakeNetNumber, localAddr, err) + protocolAddr := tcpip.ProtocolAddress{ + Protocol: fakeNetNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: localAddr, + PrefixLen: fakeDefaultPrefixLen, + }, + } + if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err) } s.SetRouteTable([]tcpip.Route{{Destination: subnet, NIC: nicID}}) @@ -1489,8 +1625,15 @@ func TestSpoofingWithAddress(t *testing.T) { t.Fatal("CreateNIC failed:", err) } - if err := s.AddAddress(1, fakeNetNumber, localAddr); err != nil { - t.Fatal("AddAddress failed:", err) + protocolAddr := tcpip.ProtocolAddress{ + Protocol: fakeNetNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: localAddr, + PrefixLen: fakeDefaultPrefixLen, + }, + } + if err := s.AddProtocolAddress(1, protocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", 1, protocolAddr, err) } { @@ -1612,8 +1755,8 @@ func TestOutgoingBroadcastWithEmptyRouteTable(t *testing.T) { } protoAddr := tcpip.ProtocolAddress{Protocol: fakeNetNumber, AddressWithPrefix: tcpip.AddressWithPrefix{Address: header.IPv4Any}} - if err := s.AddProtocolAddress(1, protoAddr); err != nil { - t.Fatalf("AddProtocolAddress(1, %v) failed: %v", protoAddr, err) + if err := s.AddProtocolAddress(1, protoAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(1, %+v, {}) failed: %s", protoAddr, err) } r, err := s.FindRoute(1, header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, false /* multicastLoop */) if err != nil { @@ -1657,13 +1800,13 @@ func TestOutgoingBroadcastWithRouteTable(t *testing.T) { t.Fatalf("CreateNIC failed: %s", err) } nic1ProtoAddr := tcpip.ProtocolAddress{Protocol: fakeNetNumber, AddressWithPrefix: nic1Addr} - if err := s.AddProtocolAddress(1, nic1ProtoAddr); err != nil { - t.Fatalf("AddProtocolAddress(1, %v) failed: %v", nic1ProtoAddr, err) + if err := s.AddProtocolAddress(1, nic1ProtoAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(1, %+v, {}) failed: %s", nic1ProtoAddr, err) } nic2ProtoAddr := tcpip.ProtocolAddress{Protocol: fakeNetNumber, AddressWithPrefix: nic2Addr} - if err := s.AddProtocolAddress(2, nic2ProtoAddr); err != nil { - t.Fatalf("AddAddress(2, %v) failed: %v", nic2ProtoAddr, err) + if err := s.AddProtocolAddress(2, nic2ProtoAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(2, %+v, {}) failed: %s", nic2ProtoAddr, err) } // Set the initial route table. @@ -1705,7 +1848,7 @@ func TestOutgoingBroadcastWithRouteTable(t *testing.T) { // 2. Case: Having an explicit route for broadcast will select that one. rt = append( []tcpip.Route{ - {Destination: tcpip.AddressWithPrefix{Address: header.IPv4Broadcast, PrefixLen: 8 * header.IPv4AddressSize}.Subnet(), NIC: 1}, + {Destination: header.IPv4Broadcast.WithPrefix().Subnet(), NIC: 1}, }, rt..., ) @@ -1787,8 +1930,15 @@ func TestMulticastOrIPv6LinkLocalNeedsNoRoute(t *testing.T) { t.Fatalf("got FindRoute(1, %v, %v, %v) = %v, want = %v", anyAddr, tc.address, fakeNetNumber, err, want) } - if err := s.AddAddress(1, fakeNetNumber, anyAddr); err != nil { - t.Fatalf("AddAddress(%v, %v) failed: %v", fakeNetNumber, anyAddr, err) + protocolAddr := tcpip.ProtocolAddress{ + Protocol: fakeNetNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: anyAddr, + PrefixLen: fakeDefaultPrefixLen, + }, + } + if err := s.AddProtocolAddress(1, protocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", 1, protocolAddr, err) } if r, err := s.FindRoute(1, anyAddr, tc.address, fakeNetNumber, false /* multicastLoop */); tc.routeNeeded { @@ -1865,22 +2015,27 @@ func TestGetMainNICAddressAddPrimaryNonPrimary(t *testing.T) { // Add an address and in case of a primary one include a // prefixLen. address := tcpip.Address(bytes.Repeat([]byte{byte(i)}, addrLen)) + properties := stack.AddressProperties{PEB: behavior} if behavior == stack.CanBePrimaryEndpoint { protocolAddress := tcpip.ProtocolAddress{ - Protocol: fakeNetNumber, - AddressWithPrefix: tcpip.AddressWithPrefix{ - Address: address, - PrefixLen: addrLen * 8, - }, + Protocol: fakeNetNumber, + AddressWithPrefix: address.WithPrefix(), } - if err := s.AddProtocolAddressWithOptions(nicID, protocolAddress, behavior); err != nil { - t.Fatalf("AddProtocolAddressWithOptions(%d, %#v, %d): %s", nicID, protocolAddress, behavior, err) + if err := s.AddProtocolAddress(nicID, protocolAddress, properties); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, %+v): %s", nicID, protocolAddress, properties, err) } // Remember the address/prefix. primaryAddrAdded[protocolAddress.AddressWithPrefix] = struct{}{} } else { - if err := s.AddAddressWithOptions(nicID, fakeNetNumber, address, behavior); err != nil { - t.Fatalf("AddAddressWithOptions(%d, %d, %s, %d): %s:", nicID, fakeNetNumber, address, behavior, err) + protocolAddress := tcpip.ProtocolAddress{ + Protocol: fakeNetNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: address, + PrefixLen: fakeDefaultPrefixLen, + }, + } + if err := s.AddProtocolAddress(nicID, protocolAddress, properties); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, %+v): %s", nicID, protocolAddress, properties, err) } } } @@ -1975,8 +2130,8 @@ func TestGetMainNICAddressAddRemove(t *testing.T) { PrefixLen: tc.prefixLen, }, } - if err := s.AddProtocolAddress(1, protocolAddress); err != nil { - t.Fatal("AddProtocolAddress failed:", err) + if err := s.AddProtocolAddress(1, protocolAddress, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(1, %+v, {}): %s", protocolAddress, err) } // Check that we get the right initial address and prefix length. @@ -2026,33 +2181,6 @@ func verifyAddresses(t *testing.T, expectedAddresses, gotAddresses []tcpip.Proto } } -func TestAddAddress(t *testing.T) { - const nicID = 1 - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, - }) - ep := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(nicID, ep); err != nil { - t.Fatal("CreateNIC failed:", err) - } - - var addrGen addressGenerator - expectedAddresses := make([]tcpip.ProtocolAddress, 0, 2) - for _, addrLen := range []int{4, 16} { - address := addrGen.next(addrLen) - if err := s.AddAddress(nicID, fakeNetNumber, address); err != nil { - t.Fatalf("AddAddress(address=%s) failed: %s", address, err) - } - expectedAddresses = append(expectedAddresses, tcpip.ProtocolAddress{ - Protocol: fakeNetNumber, - AddressWithPrefix: tcpip.AddressWithPrefix{Address: address, PrefixLen: fakeDefaultPrefixLen}, - }) - } - - gotAddresses := s.AllAddresses()[nicID] - verifyAddresses(t, expectedAddresses, gotAddresses) -} - func TestAddProtocolAddress(t *testing.T) { const nicID = 1 s := stack.New(stack.Options{ @@ -2063,96 +2191,43 @@ func TestAddProtocolAddress(t *testing.T) { t.Fatal("CreateNIC failed:", err) } - var addrGen addressGenerator - addrLenRange := []int{4, 16} - prefixLenRange := []int{8, 13, 20, 32} - expectedAddresses := make([]tcpip.ProtocolAddress, 0, len(addrLenRange)*len(prefixLenRange)) - for _, addrLen := range addrLenRange { - for _, prefixLen := range prefixLenRange { - protocolAddress := tcpip.ProtocolAddress{ - Protocol: fakeNetNumber, - AddressWithPrefix: tcpip.AddressWithPrefix{ - Address: addrGen.next(addrLen), - PrefixLen: prefixLen, - }, - } - if err := s.AddProtocolAddress(nicID, protocolAddress); err != nil { - t.Errorf("AddProtocolAddress(%+v) failed: %s", protocolAddress, err) - } - expectedAddresses = append(expectedAddresses, protocolAddress) - } - } - - gotAddresses := s.AllAddresses()[nicID] - verifyAddresses(t, expectedAddresses, gotAddresses) -} - -func TestAddAddressWithOptions(t *testing.T) { - const nicID = 1 - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, - }) - ep := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(nicID, ep); err != nil { - t.Fatal("CreateNIC failed:", err) - } - addrLenRange := []int{4, 16} behaviorRange := []stack.PrimaryEndpointBehavior{stack.CanBePrimaryEndpoint, stack.FirstPrimaryEndpoint, stack.NeverPrimaryEndpoint} - expectedAddresses := make([]tcpip.ProtocolAddress, 0, len(addrLenRange)*len(behaviorRange)) + configTypeRange := []stack.AddressConfigType{stack.AddressConfigStatic, stack.AddressConfigSlaac, stack.AddressConfigSlaacTemp} + deprecatedRange := []bool{false, true} + wantAddresses := make([]tcpip.ProtocolAddress, 0, len(addrLenRange)*len(behaviorRange)*len(configTypeRange)*len(deprecatedRange)) var addrGen addressGenerator for _, addrLen := range addrLenRange { for _, behavior := range behaviorRange { - address := addrGen.next(addrLen) - if err := s.AddAddressWithOptions(nicID, fakeNetNumber, address, behavior); err != nil { - t.Fatalf("AddAddressWithOptions(address=%s, behavior=%d) failed: %s", address, behavior, err) - } - expectedAddresses = append(expectedAddresses, tcpip.ProtocolAddress{ - Protocol: fakeNetNumber, - AddressWithPrefix: tcpip.AddressWithPrefix{Address: address, PrefixLen: fakeDefaultPrefixLen}, - }) - } - } - - gotAddresses := s.AllAddresses()[nicID] - verifyAddresses(t, expectedAddresses, gotAddresses) -} - -func TestAddProtocolAddressWithOptions(t *testing.T) { - const nicID = 1 - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, - }) - ep := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(nicID, ep); err != nil { - t.Fatal("CreateNIC failed:", err) - } - - addrLenRange := []int{4, 16} - prefixLenRange := []int{8, 13, 20, 32} - behaviorRange := []stack.PrimaryEndpointBehavior{stack.CanBePrimaryEndpoint, stack.FirstPrimaryEndpoint, stack.NeverPrimaryEndpoint} - expectedAddresses := make([]tcpip.ProtocolAddress, 0, len(addrLenRange)*len(prefixLenRange)*len(behaviorRange)) - var addrGen addressGenerator - for _, addrLen := range addrLenRange { - for _, prefixLen := range prefixLenRange { - for _, behavior := range behaviorRange { - protocolAddress := tcpip.ProtocolAddress{ - Protocol: fakeNetNumber, - AddressWithPrefix: tcpip.AddressWithPrefix{ - Address: addrGen.next(addrLen), - PrefixLen: prefixLen, - }, - } - if err := s.AddProtocolAddressWithOptions(nicID, protocolAddress, behavior); err != nil { - t.Fatalf("AddProtocolAddressWithOptions(%+v, %d) failed: %s", protocolAddress, behavior, err) + for _, configType := range configTypeRange { + for _, deprecated := range deprecatedRange { + address := addrGen.next(addrLen) + properties := stack.AddressProperties{ + PEB: behavior, + ConfigType: configType, + Deprecated: deprecated, + } + protocolAddr := tcpip.ProtocolAddress{ + Protocol: fakeNetNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: address, + PrefixLen: fakeDefaultPrefixLen, + }, + } + if err := s.AddProtocolAddress(nicID, protocolAddr, properties); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, %+v) failed: %s", nicID, protocolAddr, properties, err) + } + wantAddresses = append(wantAddresses, tcpip.ProtocolAddress{ + Protocol: fakeNetNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{Address: address, PrefixLen: fakeDefaultPrefixLen}, + }) } - expectedAddresses = append(expectedAddresses, protocolAddress) } } } gotAddresses := s.AllAddresses()[nicID] - verifyAddresses(t, expectedAddresses, gotAddresses) + verifyAddresses(t, wantAddresses, gotAddresses) } func TestCreateNICWithOptions(t *testing.T) { @@ -2269,8 +2344,15 @@ func TestNICStats(t *testing.T) { if err := s.CreateNIC(nicid, ep); err != nil { t.Fatal("CreateNIC failed: ", err) } - if err := s.AddAddress(nicid, fakeNetNumber, nic.addr); err != nil { - t.Fatal("AddAddress failed:", err) + protocolAddr := tcpip.ProtocolAddress{ + Protocol: fakeNetNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: nic.addr, + PrefixLen: fakeDefaultPrefixLen, + }, + } + if err := s.AddProtocolAddress(nicid, protocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicid, protocolAddr, err) } { @@ -2714,8 +2796,16 @@ func TestNewPEBOnPromotionToPermanent(t *testing.T) { // be returned by a call to GetMainNICAddress; // else, it should. const address1 = tcpip.Address("\x01") - if err := s.AddAddressWithOptions(nicID, fakeNetNumber, address1, pi); err != nil { - t.Fatalf("AddAddressWithOptions(%d, %d, %s, %d): %s", nicID, fakeNetNumber, address1, pi, err) + properties := stack.AddressProperties{PEB: pi} + protocolAddr := tcpip.ProtocolAddress{ + Protocol: fakeNetNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: address1, + PrefixLen: fakeDefaultPrefixLen, + }, + } + if err := s.AddProtocolAddress(nicID, protocolAddr, properties); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, %+v): %s", nicID, protocolAddr, properties, err) } addr, err := s.GetMainNICAddress(nicID, fakeNetNumber) if err != nil { @@ -2764,16 +2854,31 @@ func TestNewPEBOnPromotionToPermanent(t *testing.T) { // Add some other address with peb set to // FirstPrimaryEndpoint. const address3 = tcpip.Address("\x03") - if err := s.AddAddressWithOptions(nicID, fakeNetNumber, address3, stack.FirstPrimaryEndpoint); err != nil { - t.Fatalf("AddAddressWithOptions(%d, %d, %s, %d): %s", nicID, fakeNetNumber, address3, stack.FirstPrimaryEndpoint, err) - + protocolAddr3 := tcpip.ProtocolAddress{ + Protocol: fakeNetNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: address3, + PrefixLen: fakeDefaultPrefixLen, + }, + } + properties = stack.AddressProperties{PEB: stack.FirstPrimaryEndpoint} + if err := s.AddProtocolAddress(nicID, protocolAddr3, properties); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, %+v): %s", nicID, protocolAddr3, properties, err) } // Add back the address we removed earlier and // make sure the new peb was respected. // (The address should just be promoted now). - if err := s.AddAddressWithOptions(nicID, fakeNetNumber, address1, ps); err != nil { - t.Fatalf("AddAddressWithOptions(%d, %d, %s, %d): %s", nicID, fakeNetNumber, address1, pi, err) + protocolAddr1 := tcpip.ProtocolAddress{ + Protocol: fakeNetNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: address1, + PrefixLen: fakeDefaultPrefixLen, + }, + } + properties = stack.AddressProperties{PEB: ps} + if err := s.AddProtocolAddress(nicID, protocolAddr1, properties); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, %+v): %s", nicID, protocolAddr1, properties, err) } var primaryAddrs []tcpip.Address for _, pa := range s.NICInfo()[nicID].ProtocolAddresses { @@ -3075,8 +3180,12 @@ func TestIPv6SourceAddressSelectionScopeAndSameAddress(t *testing.T) { } for _, a := range test.nicAddrs { - if err := s.AddAddress(nicID, ipv6.ProtocolNumber, a); err != nil { - t.Errorf("s.AddAddress(%d, %d, %s): %s", nicID, ipv6.ProtocolNumber, a, err) + protocolAddr := tcpip.ProtocolAddress{ + Protocol: ipv6.ProtocolNumber, + AddressWithPrefix: a.WithPrefix(), + } + if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err) } } @@ -3182,8 +3291,12 @@ func TestLeaveIPv6SolicitedNodeAddrBeforeAddrRemoval(t *testing.T) { t.Fatalf("CreateNIC(%d, _): %s", nicID, err) } - if err := s.AddAddress(nicID, ipv6.ProtocolNumber, addr1); err != nil { - t.Fatalf("AddAddress(%d, %d, %s): %s", nicID, ipv6.ProtocolNumber, addr1, err) + protocolAddr := tcpip.ProtocolAddress{ + Protocol: ipv6.ProtocolNumber, + AddressWithPrefix: addr1.WithPrefix(), + } + if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err) } // The NIC should have joined addr1's solicited node multicast address. @@ -3338,8 +3451,8 @@ func TestDoDADWhenNICEnabled(t *testing.T) { PrefixLen: 128, }, } - if err := s.AddProtocolAddress(nicID, addr); err != nil { - t.Fatalf("AddProtocolAddress(%d, %+v): %s", nicID, addr, err) + if err := s.AddProtocolAddress(nicID, addr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, addr, err) } // Address should be in the list of all addresses. @@ -3666,8 +3779,8 @@ func TestOutgoingSubnetBroadcast(t *testing.T) { if err := s.CreateNIC(nicID1, ep); err != nil { t.Fatalf("CreateNIC(%d, _): %s", nicID1, err) } - if err := s.AddProtocolAddress(nicID1, test.nicAddr); err != nil { - t.Fatalf("AddProtocolAddress(%d, %+v): %s", nicID1, test.nicAddr, err) + if err := s.AddProtocolAddress(nicID1, test.nicAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID1, test.nicAddr, err) } s.SetRouteTable(test.routes) @@ -3729,8 +3842,8 @@ func TestResolveWith(t *testing.T) { PrefixLen: 24, }, } - if err := s.AddProtocolAddress(nicID, addr); err != nil { - t.Fatalf("AddProtocolAddress(%d, %#v): %s", nicID, addr, err) + if err := s.AddProtocolAddress(nicID, addr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, addr, err) } s.SetRouteTable([]tcpip.Route{{Destination: header.IPv4EmptySubnet, NIC: nicID}}) @@ -3771,8 +3884,15 @@ func TestRouteReleaseAfterAddrRemoval(t *testing.T) { if err := s.CreateNIC(nicID, ep); err != nil { t.Fatalf("CreateNIC(%d, _): %s", nicID, err) } - if err := s.AddAddress(nicID, fakeNetNumber, localAddr); err != nil { - t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID, fakeNetNumber, localAddr, err) + protocolAddr := tcpip.ProtocolAddress{ + Protocol: fakeNetNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: localAddr, + PrefixLen: fakeDefaultPrefixLen, + }, + } + if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err) } { subnet, err := tcpip.NewSubnet("\x00", "\x00") @@ -3860,8 +3980,8 @@ func TestGetMainNICAddressWhenNICDisabled(t *testing.T) { PrefixLen: 8, }, } - if err := s.AddProtocolAddress(nicID, protocolAddress); err != nil { - t.Fatalf("AddProtocolAddress(%d, %#v): %s", nicID, protocolAddress, err) + if err := s.AddProtocolAddress(nicID, protocolAddress, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddress, err) } // Check that we get the right initial address and prefix length. @@ -3969,44 +4089,44 @@ func TestFindRouteWithForwarding(t *testing.T) { ) type netCfg struct { - proto tcpip.NetworkProtocolNumber - factory stack.NetworkProtocolFactory - nic1Addr tcpip.Address - nic2Addr tcpip.Address - remoteAddr tcpip.Address + proto tcpip.NetworkProtocolNumber + factory stack.NetworkProtocolFactory + nic1AddrWithPrefix tcpip.AddressWithPrefix + nic2AddrWithPrefix tcpip.AddressWithPrefix + remoteAddr tcpip.Address } fakeNetCfg := netCfg{ - proto: fakeNetNumber, - factory: fakeNetFactory, - nic1Addr: nic1Addr, - nic2Addr: nic2Addr, - remoteAddr: remoteAddr, + proto: fakeNetNumber, + factory: fakeNetFactory, + nic1AddrWithPrefix: tcpip.AddressWithPrefix{Address: nic1Addr, PrefixLen: fakeDefaultPrefixLen}, + nic2AddrWithPrefix: tcpip.AddressWithPrefix{Address: nic2Addr, PrefixLen: fakeDefaultPrefixLen}, + remoteAddr: remoteAddr, } globalIPv6Addr1 := tcpip.Address(net.ParseIP("a::1").To16()) globalIPv6Addr2 := tcpip.Address(net.ParseIP("a::2").To16()) ipv6LinkLocalNIC1WithGlobalRemote := netCfg{ - proto: ipv6.ProtocolNumber, - factory: ipv6.NewProtocol, - nic1Addr: llAddr1, - nic2Addr: globalIPv6Addr2, - remoteAddr: globalIPv6Addr1, + proto: ipv6.ProtocolNumber, + factory: ipv6.NewProtocol, + nic1AddrWithPrefix: llAddr1.WithPrefix(), + nic2AddrWithPrefix: globalIPv6Addr2.WithPrefix(), + remoteAddr: globalIPv6Addr1, } ipv6GlobalNIC1WithLinkLocalRemote := netCfg{ - proto: ipv6.ProtocolNumber, - factory: ipv6.NewProtocol, - nic1Addr: globalIPv6Addr1, - nic2Addr: llAddr1, - remoteAddr: llAddr2, + proto: ipv6.ProtocolNumber, + factory: ipv6.NewProtocol, + nic1AddrWithPrefix: globalIPv6Addr1.WithPrefix(), + nic2AddrWithPrefix: llAddr1.WithPrefix(), + remoteAddr: llAddr2, } ipv6GlobalNIC1WithLinkLocalMulticastRemote := netCfg{ - proto: ipv6.ProtocolNumber, - factory: ipv6.NewProtocol, - nic1Addr: globalIPv6Addr1, - nic2Addr: globalIPv6Addr2, - remoteAddr: "\xff\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01", + proto: ipv6.ProtocolNumber, + factory: ipv6.NewProtocol, + nic1AddrWithPrefix: globalIPv6Addr1.WithPrefix(), + nic2AddrWithPrefix: globalIPv6Addr2.WithPrefix(), + remoteAddr: "\xff\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01", } tests := []struct { @@ -4015,8 +4135,8 @@ func TestFindRouteWithForwarding(t *testing.T) { netCfg netCfg forwardingEnabled bool - addrNIC tcpip.NICID - localAddr tcpip.Address + addrNIC tcpip.NICID + localAddrWithPrefix tcpip.AddressWithPrefix findRouteErr tcpip.Error dependentOnForwarding bool @@ -4026,7 +4146,7 @@ func TestFindRouteWithForwarding(t *testing.T) { netCfg: fakeNetCfg, forwardingEnabled: false, addrNIC: nicID1, - localAddr: fakeNetCfg.nic2Addr, + localAddrWithPrefix: fakeNetCfg.nic2AddrWithPrefix, findRouteErr: &tcpip.ErrNoRoute{}, dependentOnForwarding: false, }, @@ -4035,7 +4155,7 @@ func TestFindRouteWithForwarding(t *testing.T) { netCfg: fakeNetCfg, forwardingEnabled: true, addrNIC: nicID1, - localAddr: fakeNetCfg.nic2Addr, + localAddrWithPrefix: fakeNetCfg.nic2AddrWithPrefix, findRouteErr: &tcpip.ErrNoRoute{}, dependentOnForwarding: false, }, @@ -4044,7 +4164,7 @@ func TestFindRouteWithForwarding(t *testing.T) { netCfg: fakeNetCfg, forwardingEnabled: false, addrNIC: nicID1, - localAddr: fakeNetCfg.nic1Addr, + localAddrWithPrefix: fakeNetCfg.nic1AddrWithPrefix, findRouteErr: &tcpip.ErrNoRoute{}, dependentOnForwarding: false, }, @@ -4053,7 +4173,7 @@ func TestFindRouteWithForwarding(t *testing.T) { netCfg: fakeNetCfg, forwardingEnabled: true, addrNIC: nicID1, - localAddr: fakeNetCfg.nic1Addr, + localAddrWithPrefix: fakeNetCfg.nic1AddrWithPrefix, findRouteErr: nil, dependentOnForwarding: true, }, @@ -4062,7 +4182,7 @@ func TestFindRouteWithForwarding(t *testing.T) { netCfg: fakeNetCfg, forwardingEnabled: false, addrNIC: nicID2, - localAddr: fakeNetCfg.nic2Addr, + localAddrWithPrefix: fakeNetCfg.nic2AddrWithPrefix, findRouteErr: nil, dependentOnForwarding: false, }, @@ -4071,7 +4191,7 @@ func TestFindRouteWithForwarding(t *testing.T) { netCfg: fakeNetCfg, forwardingEnabled: true, addrNIC: nicID2, - localAddr: fakeNetCfg.nic2Addr, + localAddrWithPrefix: fakeNetCfg.nic2AddrWithPrefix, findRouteErr: nil, dependentOnForwarding: false, }, @@ -4080,7 +4200,7 @@ func TestFindRouteWithForwarding(t *testing.T) { netCfg: fakeNetCfg, forwardingEnabled: false, addrNIC: nicID2, - localAddr: fakeNetCfg.nic1Addr, + localAddrWithPrefix: fakeNetCfg.nic1AddrWithPrefix, findRouteErr: &tcpip.ErrNoRoute{}, dependentOnForwarding: false, }, @@ -4089,7 +4209,7 @@ func TestFindRouteWithForwarding(t *testing.T) { netCfg: fakeNetCfg, forwardingEnabled: true, addrNIC: nicID2, - localAddr: fakeNetCfg.nic1Addr, + localAddrWithPrefix: fakeNetCfg.nic1AddrWithPrefix, findRouteErr: &tcpip.ErrNoRoute{}, dependentOnForwarding: false, }, @@ -4097,7 +4217,7 @@ func TestFindRouteWithForwarding(t *testing.T) { name: "forwarding disabled and localAddr on same NIC as route", netCfg: fakeNetCfg, forwardingEnabled: false, - localAddr: fakeNetCfg.nic2Addr, + localAddrWithPrefix: fakeNetCfg.nic2AddrWithPrefix, findRouteErr: nil, dependentOnForwarding: false, }, @@ -4105,7 +4225,7 @@ func TestFindRouteWithForwarding(t *testing.T) { name: "forwarding enabled and localAddr on same NIC as route", netCfg: fakeNetCfg, forwardingEnabled: false, - localAddr: fakeNetCfg.nic2Addr, + localAddrWithPrefix: fakeNetCfg.nic2AddrWithPrefix, findRouteErr: nil, dependentOnForwarding: false, }, @@ -4113,7 +4233,7 @@ func TestFindRouteWithForwarding(t *testing.T) { name: "forwarding disabled and localAddr on different NIC as route", netCfg: fakeNetCfg, forwardingEnabled: false, - localAddr: fakeNetCfg.nic1Addr, + localAddrWithPrefix: fakeNetCfg.nic1AddrWithPrefix, findRouteErr: &tcpip.ErrNoRoute{}, dependentOnForwarding: false, }, @@ -4121,7 +4241,7 @@ func TestFindRouteWithForwarding(t *testing.T) { name: "forwarding enabled and localAddr on different NIC as route", netCfg: fakeNetCfg, forwardingEnabled: true, - localAddr: fakeNetCfg.nic1Addr, + localAddrWithPrefix: fakeNetCfg.nic1AddrWithPrefix, findRouteErr: nil, dependentOnForwarding: true, }, @@ -4145,7 +4265,7 @@ func TestFindRouteWithForwarding(t *testing.T) { name: "forwarding disabled and link-local local addr with route on different NIC", netCfg: ipv6LinkLocalNIC1WithGlobalRemote, forwardingEnabled: false, - localAddr: ipv6LinkLocalNIC1WithGlobalRemote.nic1Addr, + localAddrWithPrefix: ipv6LinkLocalNIC1WithGlobalRemote.nic1AddrWithPrefix, findRouteErr: &tcpip.ErrNoRoute{}, dependentOnForwarding: false, }, @@ -4153,7 +4273,7 @@ func TestFindRouteWithForwarding(t *testing.T) { name: "forwarding enabled and link-local local addr with route on same NIC", netCfg: ipv6LinkLocalNIC1WithGlobalRemote, forwardingEnabled: true, - localAddr: ipv6LinkLocalNIC1WithGlobalRemote.nic1Addr, + localAddrWithPrefix: ipv6LinkLocalNIC1WithGlobalRemote.nic1AddrWithPrefix, findRouteErr: &tcpip.ErrNoRoute{}, dependentOnForwarding: false, }, @@ -4161,7 +4281,7 @@ func TestFindRouteWithForwarding(t *testing.T) { name: "forwarding disabled and global local addr with route on same NIC", netCfg: ipv6LinkLocalNIC1WithGlobalRemote, forwardingEnabled: true, - localAddr: ipv6LinkLocalNIC1WithGlobalRemote.nic2Addr, + localAddrWithPrefix: ipv6LinkLocalNIC1WithGlobalRemote.nic2AddrWithPrefix, findRouteErr: nil, dependentOnForwarding: false, }, @@ -4169,7 +4289,7 @@ func TestFindRouteWithForwarding(t *testing.T) { name: "forwarding disabled and link-local local addr with route on same NIC", netCfg: ipv6GlobalNIC1WithLinkLocalRemote, forwardingEnabled: false, - localAddr: ipv6GlobalNIC1WithLinkLocalRemote.nic2Addr, + localAddrWithPrefix: ipv6GlobalNIC1WithLinkLocalRemote.nic2AddrWithPrefix, findRouteErr: nil, dependentOnForwarding: false, }, @@ -4177,7 +4297,7 @@ func TestFindRouteWithForwarding(t *testing.T) { name: "forwarding enabled and link-local local addr with route on same NIC", netCfg: ipv6GlobalNIC1WithLinkLocalRemote, forwardingEnabled: true, - localAddr: ipv6GlobalNIC1WithLinkLocalRemote.nic2Addr, + localAddrWithPrefix: ipv6GlobalNIC1WithLinkLocalRemote.nic2AddrWithPrefix, findRouteErr: nil, dependentOnForwarding: false, }, @@ -4185,7 +4305,7 @@ func TestFindRouteWithForwarding(t *testing.T) { name: "forwarding disabled and global local addr with link-local remote on different NIC", netCfg: ipv6GlobalNIC1WithLinkLocalRemote, forwardingEnabled: false, - localAddr: ipv6GlobalNIC1WithLinkLocalRemote.nic1Addr, + localAddrWithPrefix: ipv6GlobalNIC1WithLinkLocalRemote.nic1AddrWithPrefix, findRouteErr: &tcpip.ErrNetworkUnreachable{}, dependentOnForwarding: false, }, @@ -4193,7 +4313,7 @@ func TestFindRouteWithForwarding(t *testing.T) { name: "forwarding enabled and global local addr with link-local remote on different NIC", netCfg: ipv6GlobalNIC1WithLinkLocalRemote, forwardingEnabled: true, - localAddr: ipv6GlobalNIC1WithLinkLocalRemote.nic1Addr, + localAddrWithPrefix: ipv6GlobalNIC1WithLinkLocalRemote.nic1AddrWithPrefix, findRouteErr: &tcpip.ErrNetworkUnreachable{}, dependentOnForwarding: false, }, @@ -4201,7 +4321,7 @@ func TestFindRouteWithForwarding(t *testing.T) { name: "forwarding disabled and global local addr with link-local multicast remote on different NIC", netCfg: ipv6GlobalNIC1WithLinkLocalMulticastRemote, forwardingEnabled: false, - localAddr: ipv6GlobalNIC1WithLinkLocalMulticastRemote.nic1Addr, + localAddrWithPrefix: ipv6GlobalNIC1WithLinkLocalMulticastRemote.nic1AddrWithPrefix, findRouteErr: &tcpip.ErrNetworkUnreachable{}, dependentOnForwarding: false, }, @@ -4209,7 +4329,7 @@ func TestFindRouteWithForwarding(t *testing.T) { name: "forwarding enabled and global local addr with link-local multicast remote on different NIC", netCfg: ipv6GlobalNIC1WithLinkLocalMulticastRemote, forwardingEnabled: true, - localAddr: ipv6GlobalNIC1WithLinkLocalMulticastRemote.nic1Addr, + localAddrWithPrefix: ipv6GlobalNIC1WithLinkLocalMulticastRemote.nic1AddrWithPrefix, findRouteErr: &tcpip.ErrNetworkUnreachable{}, dependentOnForwarding: false, }, @@ -4217,7 +4337,7 @@ func TestFindRouteWithForwarding(t *testing.T) { name: "forwarding disabled and global local addr with link-local multicast remote on same NIC", netCfg: ipv6GlobalNIC1WithLinkLocalMulticastRemote, forwardingEnabled: false, - localAddr: ipv6GlobalNIC1WithLinkLocalMulticastRemote.nic2Addr, + localAddrWithPrefix: ipv6GlobalNIC1WithLinkLocalMulticastRemote.nic2AddrWithPrefix, findRouteErr: nil, dependentOnForwarding: false, }, @@ -4225,7 +4345,7 @@ func TestFindRouteWithForwarding(t *testing.T) { name: "forwarding enabled and global local addr with link-local multicast remote on same NIC", netCfg: ipv6GlobalNIC1WithLinkLocalMulticastRemote, forwardingEnabled: true, - localAddr: ipv6GlobalNIC1WithLinkLocalMulticastRemote.nic2Addr, + localAddrWithPrefix: ipv6GlobalNIC1WithLinkLocalMulticastRemote.nic2AddrWithPrefix, findRouteErr: nil, dependentOnForwarding: false, }, @@ -4247,12 +4367,20 @@ func TestFindRouteWithForwarding(t *testing.T) { t.Fatalf("CreateNIC(%d, _): %s:", nicID2, err) } - if err := s.AddAddress(nicID1, test.netCfg.proto, test.netCfg.nic1Addr); err != nil { - t.Fatalf("AddAddress(%d, %d, %s): %s", nicID1, test.netCfg.proto, test.netCfg.nic1Addr, err) + protocolAddr1 := tcpip.ProtocolAddress{ + Protocol: test.netCfg.proto, + AddressWithPrefix: test.netCfg.nic1AddrWithPrefix, + } + if err := s.AddProtocolAddress(nicID1, protocolAddr1, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID1, protocolAddr1, err) } - if err := s.AddAddress(nicID2, test.netCfg.proto, test.netCfg.nic2Addr); err != nil { - t.Fatalf("AddAddress(%d, %d, %s): %s", nicID2, test.netCfg.proto, test.netCfg.nic2Addr, err) + protocolAddr2 := tcpip.ProtocolAddress{ + Protocol: test.netCfg.proto, + AddressWithPrefix: test.netCfg.nic2AddrWithPrefix, + } + if err := s.AddProtocolAddress(nicID2, protocolAddr2, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID2, protocolAddr2, err) } if err := s.SetForwardingDefaultAndAllNICs(test.netCfg.proto, test.forwardingEnabled); err != nil { @@ -4261,20 +4389,20 @@ func TestFindRouteWithForwarding(t *testing.T) { s.SetRouteTable([]tcpip.Route{{Destination: test.netCfg.remoteAddr.WithPrefix().Subnet(), NIC: nicID2}}) - r, err := s.FindRoute(test.addrNIC, test.localAddr, test.netCfg.remoteAddr, test.netCfg.proto, false /* multicastLoop */) + r, err := s.FindRoute(test.addrNIC, test.localAddrWithPrefix.Address, test.netCfg.remoteAddr, test.netCfg.proto, false /* multicastLoop */) if err == nil { defer r.Release() } if diff := cmp.Diff(test.findRouteErr, err); diff != "" { - t.Fatalf("unexpected error from FindRoute(%d, %s, %s, %d, false), (-want, +got):\n%s", test.addrNIC, test.localAddr, test.netCfg.remoteAddr, test.netCfg.proto, diff) + t.Fatalf("unexpected error from FindRoute(%d, %s, %s, %d, false), (-want, +got):\n%s", test.addrNIC, test.localAddrWithPrefix.Address, test.netCfg.remoteAddr, test.netCfg.proto, diff) } if test.findRouteErr != nil { return } - if r.LocalAddress() != test.localAddr { - t.Errorf("got r.LocalAddress() = %s, want = %s", r.LocalAddress(), test.localAddr) + if r.LocalAddress() != test.localAddrWithPrefix.Address { + t.Errorf("got r.LocalAddress() = %s, want = %s", r.LocalAddress(), test.localAddrWithPrefix.Address) } if r.RemoteAddress() != test.netCfg.remoteAddr { t.Errorf("got r.RemoteAddress() = %s, want = %s", r.RemoteAddress(), test.netCfg.remoteAddr) @@ -4297,8 +4425,8 @@ func TestFindRouteWithForwarding(t *testing.T) { if !ok { t.Fatal("packet not sent through ep2") } - if pkt.Route.LocalAddress != test.localAddr { - t.Errorf("got pkt.Route.LocalAddress = %s, want = %s", pkt.Route.LocalAddress, test.localAddr) + if pkt.Route.LocalAddress != test.localAddrWithPrefix.Address { + t.Errorf("got pkt.Route.LocalAddress = %s, want = %s", pkt.Route.LocalAddress, test.localAddrWithPrefix.Address) } if pkt.Route.RemoteAddress != test.netCfg.remoteAddr { t.Errorf("got pkt.Route.RemoteAddress = %s, want = %s", pkt.Route.RemoteAddress, test.netCfg.remoteAddr) diff --git a/pkg/tcpip/stack/tcp.go b/pkg/tcpip/stack/tcp.go index e90c1a770..dc7289441 100644 --- a/pkg/tcpip/stack/tcp.go +++ b/pkg/tcpip/stack/tcp.go @@ -19,6 +19,7 @@ import ( "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/internal/tcp" "gvisor.dev/gvisor/pkg/tcpip/seqnum" ) @@ -380,15 +381,18 @@ type TCPSndBufState struct { // SndClosed indicates that the endpoint has been closed for sends. SndClosed bool - // SndBufInQueue is the number of bytes in the send queue. - SndBufInQueue seqnum.Size - // PacketTooBigCount is used to notify the main protocol routine how // many times a "packet too big" control packet is received. PacketTooBigCount int // SndMTU is the smallest MTU seen in the control packets received. SndMTU int + + // AutoTuneSndBufDisabled indicates that the auto tuning of send buffer + // is disabled. + // + // Must be accessed using atomic operations. + AutoTuneSndBufDisabled uint32 } // TCPEndpointStateInner contains the members of TCPEndpointState used directly @@ -399,7 +403,7 @@ type TCPSndBufState struct { type TCPEndpointStateInner struct { // TSOffset is a randomized offset added to the value of the TSVal // field in the timestamp option. - TSOffset uint32 + TSOffset tcp.TSOffset // SACKPermitted is set to true if the peer sends the TCPSACKPermitted // option in the SYN/SYN-ACK. diff --git a/pkg/tcpip/stack/transport_demuxer.go b/pkg/tcpip/stack/transport_demuxer.go index 8a8454a6a..542d9257c 100644 --- a/pkg/tcpip/stack/transport_demuxer.go +++ b/pkg/tcpip/stack/transport_demuxer.go @@ -16,6 +16,7 @@ package stack import ( "fmt" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/hash/jenkins" @@ -31,11 +32,13 @@ type protocolIDs struct { // transportEndpoints manages all endpoints of a given protocol. It has its own // mutex so as to reduce interference between protocols. type transportEndpoints struct { - // mu protects all fields of the transportEndpoints. - mu sync.RWMutex + mu sync.RWMutex + // +checklocks:mu endpoints map[TransportEndpointID]*endpointsByNIC // rawEndpoints contains endpoints for raw sockets, which receive all // traffic of a given protocol regardless of port. + // + // +checklocks:mu rawEndpoints []RawTransportEndpoint } @@ -68,7 +71,7 @@ func (eps *transportEndpoints) transportEndpoints() []TransportEndpoint { // descending order of match quality. If a call to yield returns false, // iterEndpointsLocked stops iteration and returns immediately. // -// Preconditions: eps.mu must be locked. +// +checklocks:eps.mu func (eps *transportEndpoints) iterEndpointsLocked(id TransportEndpointID, yield func(*endpointsByNIC) bool) { // Try to find a match with the id as provided. if ep, ok := eps.endpoints[id]; ok { @@ -109,7 +112,7 @@ func (eps *transportEndpoints) iterEndpointsLocked(id TransportEndpointID, yield // findAllEndpointsLocked returns all endpointsByNIC in eps that match id, in // descending order of match quality. // -// Preconditions: eps.mu must be locked. +// +checklocks:eps.mu func (eps *transportEndpoints) findAllEndpointsLocked(id TransportEndpointID) []*endpointsByNIC { var matchedEPs []*endpointsByNIC eps.iterEndpointsLocked(id, func(ep *endpointsByNIC) bool { @@ -121,7 +124,7 @@ func (eps *transportEndpoints) findAllEndpointsLocked(id TransportEndpointID) [] // findEndpointLocked returns the endpoint that most closely matches the given id. // -// Preconditions: eps.mu must be locked. +// +checklocks:eps.mu func (eps *transportEndpoints) findEndpointLocked(id TransportEndpointID) *endpointsByNIC { var matchedEP *endpointsByNIC eps.iterEndpointsLocked(id, func(ep *endpointsByNIC) bool { @@ -132,10 +135,12 @@ func (eps *transportEndpoints) findEndpointLocked(id TransportEndpointID) *endpo } type endpointsByNIC struct { - mu sync.RWMutex - endpoints map[tcpip.NICID]*multiPortEndpoint // seed is a random secret for a jenkins hash. seed uint32 + + mu sync.RWMutex + // +checklocks:mu + endpoints map[tcpip.NICID]*multiPortEndpoint } func (epsByNIC *endpointsByNIC) transportEndpoints() []TransportEndpoint { @@ -170,7 +175,7 @@ func (epsByNIC *endpointsByNIC) handlePacket(id TransportEndpointID, pkt *Packet return true } // multiPortEndpoints are guaranteed to have at least one element. - transEP := selectEndpoint(id, mpep, epsByNIC.seed) + transEP := mpep.selectEndpoint(id, epsByNIC.seed) if queuedProtocol, mustQueue := mpep.demux.queuedProtocols[protocolIDs{mpep.netProto, mpep.transProto}]; mustQueue { queuedProtocol.QueuePacket(transEP, id, pkt) epsByNIC.mu.RUnlock() @@ -199,7 +204,7 @@ func (epsByNIC *endpointsByNIC) handleError(n *nic, id TransportEndpointID, tran // broadcast like we are doing with handlePacket above? // multiPortEndpoints are guaranteed to have at least one element. - selectEndpoint(id, mpep, epsByNIC.seed).HandleError(transErr, pkt) + mpep.selectEndpoint(id, epsByNIC.seed).HandleError(transErr, pkt) } // registerEndpoint returns true if it succeeds. It fails and returns @@ -215,10 +220,17 @@ func (epsByNIC *endpointsByNIC) registerEndpoint(d *transportDemuxer, netProto t netProto: netProto, transProto: transProto, } - epsByNIC.endpoints[bindToDevice] = multiPortEp } - return multiPortEp.singleRegisterEndpoint(t, flags) + if err := multiPortEp.singleRegisterEndpoint(t, flags); err != nil { + return err + } + // Only add this newly created multiportEndpoint if the singleRegisterEndpoint + // succeeded. + if !ok { + epsByNIC.endpoints[bindToDevice] = multiPortEp + } + return nil } func (epsByNIC *endpointsByNIC) checkEndpoint(flags ports.Flags, bindToDevice tcpip.NICID) tcpip.Error { @@ -325,15 +337,18 @@ func (d *transportDemuxer) checkEndpoint(netProtos []tcpip.NetworkProtocolNumber // // +stateify savable type multiPortEndpoint struct { - mu sync.RWMutex `state:"nosave"` demux *transportDemuxer netProto tcpip.NetworkProtocolNumber transProto tcpip.TransportProtocolNumber + flags ports.FlagCounter + + mu sync.RWMutex `state:"nosave"` // endpoints stores the transport endpoints in the order in which they // were bound. This is required for UDP SO_REUSEADDR. + // + // +checklocks:mu endpoints []TransportEndpoint - flags ports.FlagCounter } func (ep *multiPortEndpoint) transportEndpoints() []TransportEndpoint { @@ -354,13 +369,16 @@ func reciprocalScale(val, n uint32) uint32 { // selectEndpoint calculates a hash of destination and source addresses and // ports then uses it to select a socket. In this case, all packets from one // address will be sent to same endpoint. -func selectEndpoint(id TransportEndpointID, mpep *multiPortEndpoint, seed uint32) TransportEndpoint { - if len(mpep.endpoints) == 1 { - return mpep.endpoints[0] +func (ep *multiPortEndpoint) selectEndpoint(id TransportEndpointID, seed uint32) TransportEndpoint { + ep.mu.RLock() + defer ep.mu.RUnlock() + + if len(ep.endpoints) == 1 { + return ep.endpoints[0] } - if mpep.flags.SharedFlags().ToFlags().Effective().MostRecent { - return mpep.endpoints[len(mpep.endpoints)-1] + if ep.flags.SharedFlags().ToFlags().Effective().MostRecent { + return ep.endpoints[len(ep.endpoints)-1] } payload := []byte{ @@ -376,8 +394,8 @@ func selectEndpoint(id TransportEndpointID, mpep *multiPortEndpoint, seed uint32 h.Write([]byte(id.RemoteAddress)) hash := h.Sum32() - idx := reciprocalScale(hash, uint32(len(mpep.endpoints))) - return mpep.endpoints[idx] + idx := reciprocalScale(hash, uint32(len(ep.endpoints))) + return ep.endpoints[idx] } func (ep *multiPortEndpoint) handlePacketAll(id TransportEndpointID, pkt *PacketBuffer) { @@ -405,7 +423,6 @@ func (ep *multiPortEndpoint) handlePacketAll(id TransportEndpointID, pkt *Packet func (ep *multiPortEndpoint) singleRegisterEndpoint(t TransportEndpoint, flags ports.Flags) tcpip.Error { ep.mu.Lock() defer ep.mu.Unlock() - bits := flags.Bits() & ports.MultiBindFlagMask if len(ep.endpoints) != 0 { @@ -468,17 +485,21 @@ func (d *transportDemuxer) singleRegisterEndpoint(netProto tcpip.NetworkProtocol eps.mu.Lock() defer eps.mu.Unlock() - epsByNIC, ok := eps.endpoints[id] if !ok { epsByNIC = &endpointsByNIC{ endpoints: make(map[tcpip.NICID]*multiPortEndpoint), - seed: d.stack.Seed(), + seed: d.stack.seed, } + } + if err := epsByNIC.registerEndpoint(d, netProto, protocol, ep, flags, bindToDevice); err != nil { + return err + } + // Only add this newly created epsByNIC if registerEndpoint succeeded. + if !ok { eps.endpoints[id] = epsByNIC } - - return epsByNIC.registerEndpoint(d, netProto, protocol, ep, flags, bindToDevice) + return nil } func (d *transportDemuxer) singleCheckEndpoint(netProto tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, flags ports.Flags, bindToDevice tcpip.NICID) tcpip.Error { @@ -646,7 +667,7 @@ func (d *transportDemuxer) findTransportEndpoint(netProto tcpip.NetworkProtocolN } } - ep := selectEndpoint(id, mpep, epsByNIC.seed) + ep := mpep.selectEndpoint(id, epsByNIC.seed) epsByNIC.mu.RUnlock() return ep } diff --git a/pkg/tcpip/stack/transport_demuxer_test.go b/pkg/tcpip/stack/transport_demuxer_test.go index 0972c94de..cd3a8c25a 100644 --- a/pkg/tcpip/stack/transport_demuxer_test.go +++ b/pkg/tcpip/stack/transport_demuxer_test.go @@ -35,7 +35,7 @@ import ( const ( testSrcAddrV6 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01" - testDstAddrV6 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02" + testDstAddrV6 = tcpip.Address("\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02") testSrcAddrV4 = "\x0a\x00\x00\x01" testDstAddrV4 = "\x0a\x00\x00\x02" @@ -64,12 +64,20 @@ func newDualTestContextMultiNIC(t *testing.T, mtu uint32, linkEpIDs []tcpip.NICI } linkEps[linkEpID] = channelEp - if err := s.AddAddress(linkEpID, ipv4.ProtocolNumber, testDstAddrV4); err != nil { - t.Fatalf("AddAddress IPv4 failed: %s", err) + protocolAddrV4 := tcpip.ProtocolAddress{ + Protocol: ipv4.ProtocolNumber, + AddressWithPrefix: tcpip.Address(testDstAddrV4).WithPrefix(), + } + if err := s.AddProtocolAddress(linkEpID, protocolAddrV4, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", linkEpID, protocolAddrV4, err) } - if err := s.AddAddress(linkEpID, ipv6.ProtocolNumber, testDstAddrV6); err != nil { - t.Fatalf("AddAddress IPv6 failed: %s", err) + protocolAddrV6 := tcpip.ProtocolAddress{ + Protocol: ipv6.ProtocolNumber, + AddressWithPrefix: testDstAddrV6.WithPrefix(), + } + if err := s.AddProtocolAddress(linkEpID, protocolAddrV6, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", linkEpID, protocolAddrV6, err) } } @@ -203,6 +211,56 @@ func TestTransportDemuxerRegister(t *testing.T) { } } +func TestTransportDemuxerRegisterMultiple(t *testing.T) { + type test struct { + flags ports.Flags + want tcpip.Error + } + for _, subtest := range []struct { + name string + tests []test + }{ + {"zeroFlags", []test{ + {ports.Flags{}, nil}, + {ports.Flags{}, &tcpip.ErrPortInUse{}}, + }}, + {"multibindFlags", []test{ + // Allow multiple registrations same TransportEndpointID with multibind flags. + {ports.Flags{LoadBalanced: true, MostRecent: true}, nil}, + {ports.Flags{LoadBalanced: true, MostRecent: true}, nil}, + // Disallow registration w/same ID for a non-multibindflag. + {ports.Flags{TupleOnly: true}, &tcpip.ErrPortInUse{}}, + }}, + } { + t.Run(subtest.name, func(t *testing.T) { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}, + }) + var eps []tcpip.Endpoint + for idx, test := range subtest.tests { + var wq waiter.Queue + ep, err := s.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, &wq) + if err != nil { + t.Fatal(err) + } + eps = append(eps, ep) + tEP, ok := ep.(stack.TransportEndpoint) + if !ok { + t.Fatalf("%T does not implement stack.TransportEndpoint", ep) + } + id := stack.TransportEndpointID{LocalPort: 1} + if got, want := s.RegisterTransportEndpoint([]tcpip.NetworkProtocolNumber{ipv4.ProtocolNumber}, udp.ProtocolNumber, id, tEP, test.flags, 0), test.want; got != want { + t.Fatalf("test index: %d, s.RegisterTransportEndpoint(ipv4.ProtocolNumber, udp.ProtocolNumber, _, _, %+v, 0) = %s, want %s", idx, test.flags, got, want) + } + } + for _, ep := range eps { + ep.Close() + } + }) + } +} + // TestBindToDeviceDistribution injects varied packets on input devices and checks that // the distribution of packets received matches expectations. func TestBindToDeviceDistribution(t *testing.T) { diff --git a/pkg/tcpip/stack/transport_test.go b/pkg/tcpip/stack/transport_test.go index 839178809..655931715 100644 --- a/pkg/tcpip/stack/transport_test.go +++ b/pkg/tcpip/stack/transport_test.go @@ -357,8 +357,15 @@ func TestTransportReceive(t *testing.T) { s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: 1}}) } - if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil { - t.Fatalf("AddAddress failed: %v", err) + protocolAddr := tcpip.ProtocolAddress{ + Protocol: fakeNetNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: "\x01", + PrefixLen: fakeDefaultPrefixLen, + }, + } + if err := s.AddProtocolAddress(1, protocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", 1, protocolAddr, err) } // Create endpoint and connect to remote address. @@ -428,8 +435,15 @@ func TestTransportControlReceive(t *testing.T) { s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: 1}}) } - if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil { - t.Fatalf("AddAddress failed: %v", err) + protocolAddr := tcpip.ProtocolAddress{ + Protocol: fakeNetNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: "\x01", + PrefixLen: fakeDefaultPrefixLen, + }, + } + if err := s.AddProtocolAddress(1, protocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", 1, protocolAddr, err) } // Create endpoint and connect to remote address. @@ -497,8 +511,15 @@ func TestTransportSend(t *testing.T) { t.Fatalf("CreateNIC failed: %v", err) } - if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil { - t.Fatalf("AddAddress failed: %v", err) + protocolAddr := tcpip.ProtocolAddress{ + Protocol: fakeNetNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: "\x01", + PrefixLen: fakeDefaultPrefixLen, + }, + } + if err := s.AddProtocolAddress(1, protocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", 1, protocolAddr, err) } { |