diff options
Diffstat (limited to 'pkg/tcpip/stack')
25 files changed, 1754 insertions, 1769 deletions
diff --git a/pkg/tcpip/stack/BUILD b/pkg/tcpip/stack/BUILD index 84aa6a9e4..e0847e58a 100644 --- a/pkg/tcpip/stack/BUILD +++ b/pkg/tcpip/stack/BUILD @@ -56,6 +56,7 @@ go_library( "neighbor_entry_list.go", "neighborstate_string.go", "nic.go", + "nic_stats.go", "nud.go", "packet_buffer.go", "packet_buffer_list.go", @@ -94,7 +95,7 @@ go_library( go_test( name = "stack_x_test", - size = "medium", + size = "small", srcs = [ "addressable_endpoint_state_test.go", "ndp_test.go", diff --git a/pkg/tcpip/stack/conntrack.go b/pkg/tcpip/stack/conntrack.go index 5720e7543..782e74b24 100644 --- a/pkg/tcpip/stack/conntrack.go +++ b/pkg/tcpip/stack/conntrack.go @@ -35,7 +35,6 @@ import ( // Currently, only TCP tracking is supported. // Our hash table has 16K buckets. -// TODO(gvisor.dev/issue/170): These should be tunable. const numBuckets = 1 << 14 // Direction of the tuple. @@ -125,6 +124,8 @@ type conn struct { tcb tcpconntrack.TCB // lastUsed is the last time the connection saw a relevant packet, and // is updated by each packet on the connection. It is protected by mu. + // + // TODO(gvisor.dev/issue/5939): do not use the ambient clock. lastUsed time.Time `state:".(unixTime)"` } @@ -163,8 +164,6 @@ func (cn *conn) updateLocked(tcpHeader header.TCP, hook Hook) { // Update the state of tcb. tcb assumes it's always initialized on the // client. However, we only need to know whether the connection is // established or not, so the client/server distinction isn't important. - // TODO(gvisor.dev/issue/170): Add support in tcpconntrack to handle - // other tcp states. if cn.tcb.IsEmpty() { cn.tcb.Init(tcpHeader) } else if hook == cn.tcbHook { @@ -244,8 +243,7 @@ func (ct *ConnTrack) init() { // connFor gets the conn for pkt if it exists, or returns nil // if it does not. It returns an error when pkt does not contain a valid TCP // header. -// TODO(gvisor.dev/issue/170): Only TCP packets are supported. Need to support -// other transport protocols. +// TODO(gvisor.dev/issue/6168): Support UDP. func (ct *ConnTrack) connFor(pkt *PacketBuffer) (*conn, direction) { tid, err := packetToTupleID(pkt) if err != nil { @@ -383,7 +381,7 @@ func (ct *ConnTrack) handlePacket(pkt *PacketBuffer, hook Hook, r *Route) bool { return false } - // TODO(gvisor.dev/issue/170): Support other transport protocols. + // TODO(gvisor.dev/issue/6168): Support UDP. if pkt.Network().TransportProtocol() != header.TCPProtocolNumber { return false } @@ -407,16 +405,23 @@ func (ct *ConnTrack) handlePacket(pkt *PacketBuffer, hook Hook, r *Route) bool { // validated if checksum offloading is off. It may require IP defrag if the // packets are fragmented. + var newAddr tcpip.Address + var newPort uint16 + + updateSRCFields := false + switch hook { case Prerouting, Output: if conn.manip == manipDestination { switch dir { case dirOriginal: - tcpHeader.SetDestinationPort(conn.reply.srcPort) - netHeader.SetDestinationAddress(conn.reply.srcAddr) + newPort = conn.reply.srcPort + newAddr = conn.reply.srcAddr case dirReply: - tcpHeader.SetSourcePort(conn.original.dstPort) - netHeader.SetSourceAddress(conn.original.dstAddr) + newPort = conn.original.dstPort + newAddr = conn.original.dstAddr + + updateSRCFields = true } pkt.NatDone = true } @@ -424,11 +429,13 @@ func (ct *ConnTrack) handlePacket(pkt *PacketBuffer, hook Hook, r *Route) bool { if conn.manip == manipSource { switch dir { case dirOriginal: - tcpHeader.SetSourcePort(conn.reply.dstPort) - netHeader.SetSourceAddress(conn.reply.dstAddr) + newPort = conn.reply.dstPort + newAddr = conn.reply.dstAddr + + updateSRCFields = true case dirReply: - tcpHeader.SetDestinationPort(conn.original.srcPort) - netHeader.SetDestinationAddress(conn.original.srcAddr) + newPort = conn.original.srcPort + newAddr = conn.original.srcAddr } pkt.NatDone = true } @@ -439,33 +446,33 @@ func (ct *ConnTrack) handlePacket(pkt *PacketBuffer, hook Hook, r *Route) bool { return false } + fullChecksum := false + updatePseudoHeader := false switch hook { case Prerouting, Input: case Output, Postrouting: // Calculate the TCP checksum and set it. - tcpHeader.SetChecksum(0) - length := uint16(len(tcpHeader) + pkt.Data().Size()) - xsum := header.PseudoHeaderChecksum(header.TCPProtocolNumber, netHeader.SourceAddress(), netHeader.DestinationAddress(), length) if pkt.GSOOptions.Type != GSONone && pkt.GSOOptions.NeedsCsum { - tcpHeader.SetChecksum(xsum) + updatePseudoHeader = true } else if r.RequiresTXTransportChecksum() { - xsum = header.ChecksumCombine(xsum, pkt.Data().AsRange().Checksum()) - tcpHeader.SetChecksum(^tcpHeader.CalculateChecksum(xsum)) + fullChecksum = true + updatePseudoHeader = true } default: panic(fmt.Sprintf("unrecognized hook = %s", hook)) } - // After modification, IPv4 packets need a valid checksum. - if pkt.NetworkProtocolNumber == header.IPv4ProtocolNumber { - netHeader := header.IPv4(pkt.NetworkHeader().View()) - netHeader.SetChecksum(0) - netHeader.SetChecksum(^netHeader.CalculateChecksum()) - } + rewritePacket( + netHeader, + tcpHeader, + updateSRCFields, + fullChecksum, + updatePseudoHeader, + newPort, + newAddr, + ) // Update the state of tcb. - // TODO(gvisor.dev/issue/170): Add support in tcpcontrack to handle - // other tcp states. conn.mu.Lock() defer conn.mu.Unlock() @@ -542,8 +549,6 @@ func (ct *ConnTrack) bucket(id tupleID) int { // reapUnused returns the next bucket that should be checked and the time after // which it should be called again. func (ct *ConnTrack) reapUnused(start int, prevInterval time.Duration) (int, time.Duration) { - // TODO(gvisor.dev/issue/170): This can be more finely controlled, as - // it is in Linux via sysctl. const fractionPerReaping = 128 const maxExpiredPct = 50 const maxFullTraversal = 60 * time.Second diff --git a/pkg/tcpip/stack/forwarding_test.go b/pkg/tcpip/stack/forwarding_test.go index 7107d598d..72f66441f 100644 --- a/pkg/tcpip/stack/forwarding_test.go +++ b/pkg/tcpip/stack/forwarding_test.go @@ -23,6 +23,7 @@ import ( "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/faketime" "gvisor.dev/gvisor/pkg/tcpip/header" ) @@ -114,10 +115,6 @@ func (f *fwdTestNetworkEndpoint) MaxHeaderLength() uint16 { return f.nic.MaxHeaderLength() + fwdTestNetHeaderLen } -func (*fwdTestNetworkEndpoint) PseudoHeaderChecksum(protocol tcpip.TransportProtocolNumber, dstAddr tcpip.Address) uint16 { - return 0 -} - func (f *fwdTestNetworkEndpoint) NetworkProtocolNumber() tcpip.NetworkProtocolNumber { return f.proto.Number() } @@ -134,7 +131,7 @@ func (f *fwdTestNetworkEndpoint) WritePacket(r *Route, params NetworkHeaderParam } // WritePackets implements LinkEndpoint.WritePackets. -func (*fwdTestNetworkEndpoint) WritePackets(r *Route, pkts PacketBufferList, params NetworkHeaderParams) (int, tcpip.Error) { +func (*fwdTestNetworkEndpoint) WritePackets(*Route, PacketBufferList, NetworkHeaderParams) (int, tcpip.Error) { panic("not implemented") } @@ -224,7 +221,7 @@ func (*fwdTestNetworkProtocol) Wait() {} func (f *fwdTestNetworkEndpoint) LinkAddressRequest(addr, _ tcpip.Address, remoteLinkAddr tcpip.LinkAddress) tcpip.Error { if fn := f.proto.onLinkAddressResolved; fn != nil { - time.AfterFunc(f.proto.addrResolveDelay, func() { + f.proto.stack.clock.AfterFunc(f.proto.addrResolveDelay, func() { fn(f.proto.neigh, addr, remoteLinkAddr) }) } @@ -319,7 +316,7 @@ func (e *fwdTestLinkEndpoint) LinkAddress() tcpip.LinkAddress { return e.linkAddr } -func (e fwdTestLinkEndpoint) WritePacket(r RouteInfo, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) tcpip.Error { +func (e fwdTestLinkEndpoint) WritePacket(r RouteInfo, _ tcpip.NetworkProtocolNumber, pkt *PacketBuffer) tcpip.Error { p := fwdTestPacketInfo{ RemoteLinkAddress: r.RemoteLinkAddress, LocalLinkAddress: r.LocalLinkAddress, @@ -354,17 +351,19 @@ func (*fwdTestLinkEndpoint) ARPHardwareType() header.ARPHardwareType { } // AddHeader implements stack.LinkEndpoint.AddHeader. -func (e *fwdTestLinkEndpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) { +func (e *fwdTestLinkEndpoint) AddHeader(tcpip.LinkAddress, tcpip.LinkAddress, tcpip.NetworkProtocolNumber, *PacketBuffer) { panic("not implemented") } -func fwdTestNetFactory(t *testing.T, proto *fwdTestNetworkProtocol) (ep1, ep2 *fwdTestLinkEndpoint) { +func fwdTestNetFactory(t *testing.T, proto *fwdTestNetworkProtocol) (*faketime.ManualClock, *fwdTestLinkEndpoint, *fwdTestLinkEndpoint) { + clock := faketime.NewManualClock() // Create a stack with the network protocol and two NICs. s := New(Options{ NetworkProtocols: []NetworkProtocolFactory{func(s *Stack) NetworkProtocol { proto.stack = s return proto }}, + Clock: clock, }) protoNum := proto.Number() @@ -373,7 +372,7 @@ func fwdTestNetFactory(t *testing.T, proto *fwdTestNetworkProtocol) (ep1, ep2 *f } // NIC 1 has the link address "a", and added the network address 1. - ep1 = &fwdTestLinkEndpoint{ + ep1 := &fwdTestLinkEndpoint{ C: make(chan fwdTestPacketInfo, 300), mtu: fwdTestNetDefaultMTU, linkAddr: "a", @@ -386,7 +385,7 @@ func fwdTestNetFactory(t *testing.T, proto *fwdTestNetworkProtocol) (ep1, ep2 *f } // NIC 2 has the link address "b", and added the network address 2. - ep2 = &fwdTestLinkEndpoint{ + ep2 := &fwdTestLinkEndpoint{ C: make(chan fwdTestPacketInfo, 300), mtu: fwdTestNetDefaultMTU, linkAddr: "b", @@ -416,7 +415,7 @@ func fwdTestNetFactory(t *testing.T, proto *fwdTestNetworkProtocol) (ep1, ep2 *f s.SetRouteTable([]tcpip.Route{{Destination: subnet, NIC: 2}}) } - return ep1, ep2 + return clock, ep1, ep2 } func TestForwardingWithStaticResolver(t *testing.T) { @@ -432,7 +431,7 @@ func TestForwardingWithStaticResolver(t *testing.T) { }, } - ep1, ep2 := fwdTestNetFactory(t, proto) + clock, ep1, ep2 := fwdTestNetFactory(t, proto) // Inject an inbound packet to address 3 on NIC 1, and see if it is // forwarded to NIC 2. @@ -444,6 +443,7 @@ func TestForwardingWithStaticResolver(t *testing.T) { var p fwdTestPacketInfo + clock.Advance(proto.addrResolveDelay) select { case p = <-ep2.C: default: @@ -475,7 +475,7 @@ func TestForwardingWithFakeResolver(t *testing.T) { }) }, } - ep1, ep2 := fwdTestNetFactory(t, &proto) + clock, ep1, ep2 := fwdTestNetFactory(t, &proto) // Inject an inbound packet to address 3 on NIC 1, and see if it is // forwarded to NIC 2. @@ -487,9 +487,10 @@ func TestForwardingWithFakeResolver(t *testing.T) { var p fwdTestPacketInfo + clock.Advance(proto.addrResolveDelay) select { case p = <-ep2.C: - case <-time.After(time.Second): + default: t.Fatal("packet not forwarded") } @@ -508,7 +509,7 @@ func TestForwardingWithNoResolver(t *testing.T) { // Whether or not we use the neighbor cache here does not matter since // neither linkAddrCache nor neighborCache will be used. - ep1, ep2 := fwdTestNetFactory(t, proto) + clock, ep1, ep2 := fwdTestNetFactory(t, proto) // inject an inbound packet to address 3 on NIC 1, and see if it is // forwarded to NIC 2. @@ -518,10 +519,11 @@ func TestForwardingWithNoResolver(t *testing.T) { Data: buf.ToVectorisedView(), })) + clock.Advance(proto.addrResolveDelay) select { case <-ep2.C: t.Fatal("Packet should not be forwarded") - case <-time.After(time.Second): + default: } } @@ -533,7 +535,7 @@ func TestForwardingResolutionFailsForQueuedPackets(t *testing.T) { }, } - ep1, ep2 := fwdTestNetFactory(t, proto) + clock, ep1, ep2 := fwdTestNetFactory(t, proto) const numPackets int = 5 // These packets will all be enqueued in the packet queue to wait for link @@ -547,12 +549,12 @@ func TestForwardingResolutionFailsForQueuedPackets(t *testing.T) { } // All packets should fail resolution. - // TODO(gvisor.dev/issue/5141): Use a fake clock. for i := 0; i < numPackets; i++ { + clock.Advance(proto.addrResolveDelay) select { case got := <-ep2.C: t.Fatalf("got %#v; packets should have failed resolution and not been forwarded", got) - case <-time.After(100 * time.Millisecond): + default: } } } @@ -576,7 +578,7 @@ func TestForwardingWithFakeResolverPartialTimeout(t *testing.T) { } }, } - ep1, ep2 := fwdTestNetFactory(t, &proto) + clock, ep1, ep2 := fwdTestNetFactory(t, &proto) // Inject an inbound packet to address 4 on NIC 1. This packet should // not be forwarded. @@ -596,9 +598,10 @@ func TestForwardingWithFakeResolverPartialTimeout(t *testing.T) { var p fwdTestPacketInfo + clock.Advance(proto.addrResolveDelay) select { case p = <-ep2.C: - case <-time.After(time.Second): + default: t.Fatal("packet not forwarded") } @@ -631,7 +634,7 @@ func TestForwardingWithFakeResolverTwoPackets(t *testing.T) { }) }, } - ep1, ep2 := fwdTestNetFactory(t, &proto) + clock, ep1, ep2 := fwdTestNetFactory(t, &proto) // Inject two inbound packets to address 3 on NIC 1. for i := 0; i < 2; i++ { @@ -645,9 +648,10 @@ func TestForwardingWithFakeResolverTwoPackets(t *testing.T) { for i := 0; i < 2; i++ { var p fwdTestPacketInfo + clock.Advance(proto.addrResolveDelay) select { case p = <-ep2.C: - case <-time.After(time.Second): + default: t.Fatal("packet not forwarded") } @@ -681,7 +685,7 @@ func TestForwardingWithFakeResolverManyPackets(t *testing.T) { }) }, } - ep1, ep2 := fwdTestNetFactory(t, &proto) + clock, ep1, ep2 := fwdTestNetFactory(t, &proto) for i := 0; i < maxPendingPacketsPerResolution+5; i++ { // Inject inbound 'maxPendingPacketsPerResolution + 5' packets on NIC 1. @@ -697,9 +701,10 @@ func TestForwardingWithFakeResolverManyPackets(t *testing.T) { for i := 0; i < maxPendingPacketsPerResolution; i++ { var p fwdTestPacketInfo + clock.Advance(proto.addrResolveDelay) select { case p = <-ep2.C: - case <-time.After(time.Second): + default: t.Fatal("packet not forwarded") } @@ -745,7 +750,7 @@ func TestForwardingWithFakeResolverManyResolutions(t *testing.T) { }) }, } - ep1, ep2 := fwdTestNetFactory(t, &proto) + clock, ep1, ep2 := fwdTestNetFactory(t, &proto) for i := 0; i < maxPendingResolutions+5; i++ { // Inject inbound 'maxPendingResolutions + 5' packets on NIC 1. @@ -761,9 +766,10 @@ func TestForwardingWithFakeResolverManyResolutions(t *testing.T) { for i := 0; i < maxPendingResolutions; i++ { var p fwdTestPacketInfo + clock.Advance(proto.addrResolveDelay) select { case p = <-ep2.C: - case <-time.After(time.Second): + default: t.Fatal("packet not forwarded") } diff --git a/pkg/tcpip/stack/iptables.go b/pkg/tcpip/stack/iptables.go index 3670d5995..f152c0d83 100644 --- a/pkg/tcpip/stack/iptables.go +++ b/pkg/tcpip/stack/iptables.go @@ -42,7 +42,7 @@ const reaperDelay = 5 * time.Second // DefaultTables returns a default set of tables. Each chain is set to accept // all packets. -func DefaultTables() *IPTables { +func DefaultTables(seed uint32) *IPTables { return &IPTables{ v4Tables: [NumTables]Table{ NATID: { @@ -182,7 +182,7 @@ func DefaultTables() *IPTables { Postrouting: {MangleID, NATID}, }, connections: ConnTrack{ - seed: generateRandUint32(), + seed: seed, }, reaperDone: make(chan struct{}, 1), } @@ -268,10 +268,6 @@ const ( // should continue traversing the network stack and false when it should be // dropped. // -// TODO(gvisor.dev/issue/170): PacketBuffer should hold the route, from -// which address can be gathered. Currently, address is only needed for -// prerouting. -// // Precondition: pkt.NetworkHeader is set. func (it *IPTables) Check(hook Hook, pkt *PacketBuffer, r *Route, preroutingAddr tcpip.Address, inNicName, outNicName string) bool { if pkt.NetworkProtocolNumber != header.IPv4ProtocolNumber && pkt.NetworkProtocolNumber != header.IPv6ProtocolNumber { @@ -371,6 +367,7 @@ func (it *IPTables) startReaper(interval time.Duration) { select { case <-it.reaperDone: return + // TODO(gvisor.dev/issue/5939): do not use the ambient clock. case <-time.After(interval): bucket, interval = it.connections.reapUnused(bucket, interval) } diff --git a/pkg/tcpip/stack/iptables_targets.go b/pkg/tcpip/stack/iptables_targets.go index 2812c89aa..96cc899bb 100644 --- a/pkg/tcpip/stack/iptables_targets.go +++ b/pkg/tcpip/stack/iptables_targets.go @@ -87,9 +87,6 @@ func (*ReturnTarget) Action(*PacketBuffer, *ConnTrack, Hook, *Route, tcpip.Addre // destination port/IP. Outgoing packets are redirected to the loopback device, // and incoming packets are redirected to the incoming interface (rather than // forwarded). -// -// TODO(gvisor.dev/issue/170): Other flags need to be added after we support -// them. type RedirectTarget struct { // Port indicates port used to redirect. It is immutable. Port uint16 @@ -100,9 +97,6 @@ type RedirectTarget struct { } // Action implements Target.Action. -// TODO(gvisor.dev/issue/170): Parse headers without copying. The current -// implementation only works for Prerouting and calls pkt.Clone(), neither -// of which should be the case. func (rt *RedirectTarget) Action(pkt *PacketBuffer, ct *ConnTrack, hook Hook, r *Route, address tcpip.Address) (RuleVerdict, int) { // Sanity check. if rt.NetworkProtocol != pkt.NetworkProtocolNumber { @@ -136,34 +130,26 @@ func (rt *RedirectTarget) Action(pkt *PacketBuffer, ct *ConnTrack, hook Hook, r panic("redirect target is supported only on output and prerouting hooks") } - // TODO(gvisor.dev/issue/170): Check Flags in RedirectTarget if - // we need to change dest address (for OUTPUT chain) or ports. switch protocol := pkt.TransportProtocolNumber; protocol { case header.UDPProtocolNumber: udpHeader := header.UDP(pkt.TransportHeader().View()) - udpHeader.SetDestinationPort(rt.Port) - // Calculate UDP checksum and set it. if hook == Output { - udpHeader.SetChecksum(0) - netHeader := pkt.Network() - netHeader.SetDestinationAddress(address) - // Only calculate the checksum if offloading isn't supported. - if r.RequiresTXTransportChecksum() { - length := uint16(pkt.Size()) - uint16(len(pkt.NetworkHeader().View())) - xsum := header.PseudoHeaderChecksum(protocol, netHeader.SourceAddress(), netHeader.DestinationAddress(), length) - xsum = header.ChecksumCombine(xsum, pkt.Data().AsRange().Checksum()) - udpHeader.SetChecksum(^udpHeader.CalculateChecksum(xsum)) - } + requiresChecksum := r.RequiresTXTransportChecksum() + rewritePacket( + pkt.Network(), + udpHeader, + false, /* updateSRCFields */ + requiresChecksum, + requiresChecksum, + rt.Port, + address, + ) + } else { + udpHeader.SetDestinationPort(rt.Port) } - // After modification, IPv4 packets need a valid checksum. - if pkt.NetworkProtocolNumber == header.IPv4ProtocolNumber { - netHeader := header.IPv4(pkt.NetworkHeader().View()) - netHeader.SetChecksum(0) - netHeader.SetChecksum(^netHeader.CalculateChecksum()) - } pkt.NatDone = true case header.TCPProtocolNumber: if ct == nil { @@ -222,26 +208,18 @@ func (st *SNATTarget) Action(pkt *PacketBuffer, ct *ConnTrack, hook Hook, r *Rou switch protocol := pkt.TransportProtocolNumber; protocol { case header.UDPProtocolNumber: - udpHeader := header.UDP(pkt.TransportHeader().View()) - udpHeader.SetChecksum(0) - udpHeader.SetSourcePort(st.Port) - netHeader := pkt.Network() - netHeader.SetSourceAddress(st.Addr) - // Only calculate the checksum if offloading isn't supported. - if r.RequiresTXTransportChecksum() { - length := uint16(pkt.Size()) - uint16(len(pkt.NetworkHeader().View())) - xsum := header.PseudoHeaderChecksum(protocol, netHeader.SourceAddress(), netHeader.DestinationAddress(), length) - xsum = header.ChecksumCombine(xsum, pkt.Data().AsRange().Checksum()) - udpHeader.SetChecksum(^udpHeader.CalculateChecksum(xsum)) - } + requiresChecksum := r.RequiresTXTransportChecksum() + rewritePacket( + pkt.Network(), + header.UDP(pkt.TransportHeader().View()), + true, /* updateSRCFields */ + requiresChecksum, + requiresChecksum, + st.Port, + st.Addr, + ) - // After modification, IPv4 packets need a valid checksum. - if pkt.NetworkProtocolNumber == header.IPv4ProtocolNumber { - netHeader := header.IPv4(pkt.NetworkHeader().View()) - netHeader.SetChecksum(0) - netHeader.SetChecksum(^netHeader.CalculateChecksum()) - } pkt.NatDone = true case header.TCPProtocolNumber: if ct == nil { @@ -260,3 +238,42 @@ func (st *SNATTarget) Action(pkt *PacketBuffer, ct *ConnTrack, hook Hook, r *Rou return RuleAccept, 0 } + +func rewritePacket(n header.Network, t header.ChecksummableTransport, updateSRCFields, fullChecksum, updatePseudoHeader bool, newPort uint16, newAddr tcpip.Address) { + if updateSRCFields { + if fullChecksum { + t.SetSourcePortWithChecksumUpdate(newPort) + } else { + t.SetSourcePort(newPort) + } + } else { + if fullChecksum { + t.SetDestinationPortWithChecksumUpdate(newPort) + } else { + t.SetDestinationPort(newPort) + } + } + + if updatePseudoHeader { + var oldAddr tcpip.Address + if updateSRCFields { + oldAddr = n.SourceAddress() + } else { + oldAddr = n.DestinationAddress() + } + + t.UpdateChecksumPseudoHeaderAddress(oldAddr, newAddr, fullChecksum) + } + + if checksummableNetHeader, ok := n.(header.ChecksummableNetwork); ok { + if updateSRCFields { + checksummableNetHeader.SetSourceAddressWithChecksumUpdate(newAddr) + } else { + checksummableNetHeader.SetDestinationAddressWithChecksumUpdate(newAddr) + } + } else if updateSRCFields { + n.SetSourceAddress(newAddr) + } else { + n.SetDestinationAddress(newAddr) + } +} diff --git a/pkg/tcpip/stack/iptables_types.go b/pkg/tcpip/stack/iptables_types.go index 93592e7f5..66e5f22ac 100644 --- a/pkg/tcpip/stack/iptables_types.go +++ b/pkg/tcpip/stack/iptables_types.go @@ -242,7 +242,6 @@ type IPHeaderFilter struct { func (fl IPHeaderFilter) match(pkt *PacketBuffer, hook Hook, inNicName, outNicName string) bool { // Extract header fields. var ( - // TODO(gvisor.dev/issue/170): Support other filter fields. transProto tcpip.TransportProtocolNumber dstAddr tcpip.Address srcAddr tcpip.Address @@ -291,7 +290,6 @@ func (fl IPHeaderFilter) match(pkt *PacketBuffer, hook Hook, inNicName, outNicNa return true case Postrouting: - // TODO(gvisor.dev/issue/170): Add the check for POSTROUTING. return true default: panic(fmt.Sprintf("unknown hook: %d", hook)) diff --git a/pkg/tcpip/stack/ndp_test.go b/pkg/tcpip/stack/ndp_test.go index ac2fa777e..9623d9c28 100644 --- a/pkg/tcpip/stack/ndp_test.go +++ b/pkg/tcpip/stack/ndp_test.go @@ -16,14 +16,14 @@ package stack_test import ( "bytes" - "context" "encoding/binary" "fmt" + "math/rand" "testing" "time" "github.com/google/go-cmp/cmp" - "gvisor.dev/gvisor/pkg/rand" + cryptorand "gvisor.dev/gvisor/pkg/rand" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/checker" @@ -52,17 +52,6 @@ const ( linkAddr4 = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x09") defaultPrefixLen = 128 - - // Extra time to use when waiting for an async event to occur. - defaultAsyncPositiveEventTimeout = 10 * time.Second - - // Extra time to use when waiting for an async event to not occur. - // - // Since a negative check is used to make sure an event did not happen, it is - // okay to use a smaller timeout compared to the positive case since execution - // stall in regards to the monotonic clock will not affect the expected - // outcome. - defaultAsyncNegativeEventTimeout = time.Second ) var ( @@ -112,11 +101,13 @@ type ndpDADEvent struct { res stack.DADResult } -type ndpRouterEvent struct { - nicID tcpip.NICID - addr tcpip.Address - // true if router was discovered, false if invalidated. - discovered bool +type ndpOffLinkRouteEvent struct { + nicID tcpip.NICID + subnet tcpip.Subnet + router tcpip.Address + prf header.NDPRoutePreference + // true if route was updated, false if invalidated. + updated bool } type ndpPrefixEvent struct { @@ -140,6 +131,10 @@ type ndpAutoGenAddrEvent struct { eventType ndpAutoGenAddrEventType } +func (e ndpAutoGenAddrEvent) String() string { + return fmt.Sprintf("%T{nicID=%d addr=%s eventType=%d}", e, e.nicID, e.addr, e.eventType) +} + type ndpRDNSS struct { addrs []tcpip.Address lifetime time.Duration @@ -167,10 +162,8 @@ var _ ipv6.NDPDispatcher = (*ndpDispatcher)(nil) // related events happen for test purposes. type ndpDispatcher struct { dadC chan ndpDADEvent - routerC chan ndpRouterEvent - rememberRouter bool + offLinkRouteC chan ndpOffLinkRouteEvent prefixC chan ndpPrefixEvent - rememberPrefix bool autoGenAddrC chan ndpAutoGenAddrEvent rdnssC chan ndpRDNSSEvent dnsslC chan ndpDNSSLEvent @@ -189,32 +182,35 @@ func (n *ndpDispatcher) OnDuplicateAddressDetectionResult(nicID tcpip.NICID, add } } -// Implements ipv6.NDPDispatcher.OnDefaultRouterDiscovered. -func (n *ndpDispatcher) OnDefaultRouterDiscovered(nicID tcpip.NICID, addr tcpip.Address) bool { - if c := n.routerC; c != nil { - c <- ndpRouterEvent{ +// Implements ipv6.NDPDispatcher.OnOffLinkRouteUpdated. +func (n *ndpDispatcher) OnOffLinkRouteUpdated(nicID tcpip.NICID, subnet tcpip.Subnet, router tcpip.Address, prf header.NDPRoutePreference) { + if c := n.offLinkRouteC; c != nil { + c <- ndpOffLinkRouteEvent{ nicID, - addr, + subnet, + router, + prf, true, } } - - return n.rememberRouter } -// Implements ipv6.NDPDispatcher.OnDefaultRouterInvalidated. -func (n *ndpDispatcher) OnDefaultRouterInvalidated(nicID tcpip.NICID, addr tcpip.Address) { - if c := n.routerC; c != nil { - c <- ndpRouterEvent{ +// Implements ipv6.NDPDispatcher.OnOffLinkRouteInvalidated. +func (n *ndpDispatcher) OnOffLinkRouteInvalidated(nicID tcpip.NICID, subnet tcpip.Subnet, router tcpip.Address) { + if c := n.offLinkRouteC; c != nil { + var prf header.NDPRoutePreference + c <- ndpOffLinkRouteEvent{ nicID, - addr, + subnet, + router, + prf, false, } } } // Implements ipv6.NDPDispatcher.OnOnLinkPrefixDiscovered. -func (n *ndpDispatcher) OnOnLinkPrefixDiscovered(nicID tcpip.NICID, prefix tcpip.Subnet) bool { +func (n *ndpDispatcher) OnOnLinkPrefixDiscovered(nicID tcpip.NICID, prefix tcpip.Subnet) { if c := n.prefixC; c != nil { c <- ndpPrefixEvent{ nicID, @@ -222,8 +218,6 @@ func (n *ndpDispatcher) OnOnLinkPrefixDiscovered(nicID tcpip.NICID, prefix tcpip true, } } - - return n.rememberPrefix } // Implements ipv6.NDPDispatcher.OnOnLinkPrefixInvalidated. @@ -237,7 +231,7 @@ func (n *ndpDispatcher) OnOnLinkPrefixInvalidated(nicID tcpip.NICID, prefix tcpi } } -func (n *ndpDispatcher) OnAutoGenAddress(nicID tcpip.NICID, addr tcpip.AddressWithPrefix) bool { +func (n *ndpDispatcher) OnAutoGenAddress(nicID tcpip.NICID, addr tcpip.AddressWithPrefix) { if c := n.autoGenAddrC; c != nil { c <- ndpAutoGenAddrEvent{ nicID, @@ -245,7 +239,6 @@ func (n *ndpDispatcher) OnAutoGenAddress(nicID tcpip.NICID, addr tcpip.AddressWi newAddr, } } - return true } func (n *ndpDispatcher) OnAutoGenAddressDeprecated(nicID tcpip.NICID, addr tcpip.AddressWithPrefix) { @@ -497,8 +490,9 @@ func TestDADResolve(t *testing.T) { clock := faketime.NewManualClock() s := stack.New(stack.Options{ - Clock: clock, - SecureRNG: &secureRNG, + Clock: clock, + RandSource: rand.NewSource(time.Now().UnixNano()), + SecureRNG: &secureRNG, NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ NDPDisp: &ndpDisp, DADConfigs: stack.DADConfigurations{ @@ -605,7 +599,10 @@ func TestDADResolve(t *testing.T) { // Validate the sent Neighbor Solicitation messages. for i := uint8(0); i < test.dupAddrDetectTransmits; i++ { - p, _ := e.ReadContext(context.Background()) + p, ok := e.Read() + if !ok { + t.Fatal("packet didn't arrive") + } // Make sure its an IPv6 packet. if p.Proto != header.IPv6ProtocolNumber { @@ -731,11 +728,13 @@ func TestDADFail(t *testing.T) { dadConfigs.RetransmitTimer = time.Second * 2 e := channel.New(0, 1280, linkAddr1) + clock := faketime.NewManualClock() s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ NDPDisp: &ndpDisp, DADConfigs: dadConfigs, })}, + Clock: clock, }) if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) @@ -761,16 +760,17 @@ func TestDADFail(t *testing.T) { // Wait for DAD to fail and make sure the address did // not get resolved. + clock.Advance(time.Duration(dadConfigs.DupAddrDetectTransmits) * dadConfigs.RetransmitTimer) select { - case <-time.After(time.Duration(dadConfigs.DupAddrDetectTransmits)*dadConfigs.RetransmitTimer + time.Second): - // If we don't get a failure event after the - // expected resolution time + extra 1s buffer, - // something is wrong. - t.Fatal("timed out waiting for DAD failure") case e := <-ndpDisp.dadC: if diff := checkDADEvent(e, nicID, addr1, &stack.DADDupAddrDetected{HolderLinkAddress: test.expectedHolderLinkAddress}); diff != "" { t.Errorf("DAD event mismatch (-want +got):\n%s", diff) } + default: + // If we don't get a failure event after the + // expected resolution time + extra 1s buffer, + // something is wrong. + t.Fatal("timed out waiting for DAD failure") } if err := checkGetMainNICAddress(s, nicID, header.IPv6ProtocolNumber, tcpip.AddressWithPrefix{}); err != nil { t.Fatal(err) @@ -839,11 +839,13 @@ func TestDADStop(t *testing.T) { } e := channel.New(0, 1280, linkAddr1) + clock := faketime.NewManualClock() s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ NDPDisp: &ndpDisp, DADConfigs: dadConfigs, })}, + Clock: clock, }) if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("CreateNIC(%d, _): %s", nicID, err) @@ -861,15 +863,16 @@ func TestDADStop(t *testing.T) { test.stopFn(t, s) // Wait for DAD to fail (since the address was removed during DAD). + clock.Advance(time.Duration(dadConfigs.DupAddrDetectTransmits) * dadConfigs.RetransmitTimer) select { - case <-time.After(time.Duration(dadConfigs.DupAddrDetectTransmits)*dadConfigs.RetransmitTimer + time.Second): - // If we don't get a failure event after the expected resolution - // time + extra 1s buffer, something is wrong. - t.Fatal("timed out waiting for DAD failure") case e := <-ndpDisp.dadC: if diff := checkDADEvent(e, nicID, addr1, &stack.DADAborted{}); diff != "" { t.Errorf("DAD event mismatch (-want +got):\n%s", diff) } + default: + // If we don't get a failure event after the expected resolution + // time + extra 1s buffer, something is wrong. + t.Fatal("timed out waiting for DAD failure") } if !test.skipFinalAddrCheck { @@ -920,10 +923,12 @@ func TestSetNDPConfigurations(t *testing.T) { dadC: make(chan ndpDADEvent, 1), } e := channel.New(0, 1280, linkAddr1) + clock := faketime.NewManualClock() s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ NDPDisp: &ndpDisp, })}, + Clock: clock, }) expectDADEvent := func(nicID tcpip.NICID, addr tcpip.Address) { @@ -1002,28 +1007,23 @@ func TestSetNDPConfigurations(t *testing.T) { t.Fatal(err) } - // Sleep until right (500ms before) before resolution to - // make sure the address didn't resolve on NIC(1) yet. - const delta = 500 * time.Millisecond - time.Sleep(time.Duration(test.dupAddrDetectTransmits)*test.expectedRetransmitTimer - delta) + // Sleep until right before resolution to make sure the address didn't + // resolve on NIC(1) yet. + const delta = 1 + clock.Advance(time.Duration(test.dupAddrDetectTransmits)*test.expectedRetransmitTimer - delta) if err := checkGetMainNICAddress(s, nicID1, header.IPv6ProtocolNumber, tcpip.AddressWithPrefix{}); err != nil { t.Fatal(err) } // Wait for DAD to resolve. + clock.Advance(delta) select { - case <-time.After(2 * delta): - // We should get a resolution event after 500ms - // (delta) since we wait for 500ms less than the - // expected resolution time above to make sure - // that the address did not yet resolve. Waiting - // for 1s (2x delta) without a resolution event - // means something is wrong. - t.Fatal("timed out waiting for DAD resolution") case e := <-ndpDisp.dadC: if diff := checkDADEvent(e, nicID1, addr1, &stack.DADSucceeded{}); diff != "" { t.Errorf("DAD event mismatch (-want +got):\n%s", diff) } + default: + t.Fatal("timed out waiting for DAD resolution") } if err := checkGetMainNICAddress(s, nicID1, header.IPv6ProtocolNumber, addrWithPrefix1); err != nil { t.Fatal(err) @@ -1032,10 +1032,13 @@ func TestSetNDPConfigurations(t *testing.T) { } } -// raBufWithOptsAndDHCPv6 returns a valid NDP Router Advertisement with options -// and DHCPv6 configurations specified. -func raBufWithOptsAndDHCPv6(ip tcpip.Address, rl uint16, managedAddress, otherConfigurations bool, optSer header.NDPOptionsSerializer) *stack.PacketBuffer { - icmpSize := header.ICMPv6HeaderSize + header.NDPRAMinimumSize + int(optSer.Length()) +// raBuf returns a valid NDP Router Advertisement with options, router +// preference and DHCPv6 configurations specified. +func raBuf(ip tcpip.Address, rl uint16, managedAddress, otherConfigurations bool, prf header.NDPRoutePreference, optSer header.NDPOptionsSerializer) *stack.PacketBuffer { + const flagsByte = 1 + const routerLifetimeOffset = 2 + + icmpSize := header.ICMPv6HeaderSize + header.NDPRAMinimumSize + optSer.Length() hdr := buffer.NewPrependable(header.IPv6MinimumSize + icmpSize) pkt := header.ICMPv6(hdr.Prepend(icmpSize)) pkt.SetType(header.ICMPv6RouterAdvert) @@ -1043,19 +1046,19 @@ func raBufWithOptsAndDHCPv6(ip tcpip.Address, rl uint16, managedAddress, otherCo raPayload := pkt.MessageBody() ra := header.NDPRouterAdvert(raPayload) // Populate the Router Lifetime. - binary.BigEndian.PutUint16(raPayload[2:], rl) + binary.BigEndian.PutUint16(raPayload[routerLifetimeOffset:], rl) // Populate the Managed Address flag field. if managedAddress { - // The Managed Addresses flag field is the 7th bit of byte #1 (0-indexing) - // of the RA payload. - raPayload[1] |= (1 << 7) + // The Managed Addresses flag field is the 7th bit of the flags byte. + raPayload[flagsByte] |= 1 << 7 } // Populate the Other Configurations flag field. if otherConfigurations { - // The Other Configurations flag field is the 6th bit of byte #1 - // (0-indexing) of the RA payload. - raPayload[1] |= (1 << 6) + // The Other Configurations flag field is the 6th bit of the flags byte. + raPayload[flagsByte] |= 1 << 6 } + // The Prf field is held in the flags byte. + raPayload[flagsByte] |= byte(prf) << 3 opts := ra.Options() opts.Serialize(optSer) pkt.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{ @@ -1083,7 +1086,7 @@ func raBufWithOptsAndDHCPv6(ip tcpip.Address, rl uint16, managedAddress, otherCo // Note, raBufWithOpts does not populate any of the RA fields other than the // Router Lifetime. func raBufWithOpts(ip tcpip.Address, rl uint16, optSer header.NDPOptionsSerializer) *stack.PacketBuffer { - return raBufWithOptsAndDHCPv6(ip, rl, false, false, optSer) + return raBuf(ip, rl, false /* managedAddress */, false /* otherConfigurations */, 0 /* prf */, optSer) } // raBufWithDHCPv6 returns a valid NDP Router Advertisement with DHCPv6 related @@ -1091,18 +1094,26 @@ func raBufWithOpts(ip tcpip.Address, rl uint16, optSer header.NDPOptionsSerializ // // Note, raBufWithDHCPv6 does not populate any of the RA fields other than the // DHCPv6 related ones. -func raBufWithDHCPv6(ip tcpip.Address, managedAddresses, otherConfiguratiosns bool) *stack.PacketBuffer { - return raBufWithOptsAndDHCPv6(ip, 0, managedAddresses, otherConfiguratiosns, header.NDPOptionsSerializer{}) +func raBufWithDHCPv6(ip tcpip.Address, managedAddresses, otherConfigurations bool) *stack.PacketBuffer { + return raBuf(ip, 0, managedAddresses, otherConfigurations, 0 /* prf */, header.NDPOptionsSerializer{}) } // raBuf returns a valid NDP Router Advertisement. // // Note, raBuf does not populate any of the RA fields other than the // Router Lifetime. -func raBuf(ip tcpip.Address, rl uint16) *stack.PacketBuffer { +func raBufSimple(ip tcpip.Address, rl uint16) *stack.PacketBuffer { return raBufWithOpts(ip, rl, header.NDPOptionsSerializer{}) } +// raBufWithPrf returns a valid NDP Router Advertisement with a preference. +// +// Note, raBufWithPrf does not populate any of the RA fields other than the +// Router Lifetime and Default Router Preference fields. +func raBufWithPrf(ip tcpip.Address, rl uint16, prf header.NDPRoutePreference) *stack.PacketBuffer { + return raBuf(ip, rl, false /* managedAddress */, false /* otherConfigurations */, prf, header.NDPOptionsSerializer{}) +} + // raBufWithPI returns a valid NDP Router Advertisement with a single Prefix // Information option. // @@ -1162,7 +1173,7 @@ func TestDynamicConfigurationsDisabled(t *testing.T) { config: func(enable bool) ipv6.NDPConfigurations { return ipv6.NDPConfigurations{DiscoverDefaultRouters: enable} }, - ra: raBuf(llAddr2, 1000), + ra: raBufSimple(llAddr2, 1000), }, { name: "No Prefix Discovery", @@ -1198,9 +1209,9 @@ func TestDynamicConfigurationsDisabled(t *testing.T) { t.Run(fmt.Sprintf("HandleRAs(%s), Forwarding(%t), Enabled(%t)", handle, forwarding, enable), func(t *testing.T) { ndpDisp := ndpDispatcher{ - routerC: make(chan ndpRouterEvent, 1), - prefixC: make(chan ndpPrefixEvent, 1), - autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1), + offLinkRouteC: make(chan ndpOffLinkRouteEvent, 1), + prefixC: make(chan ndpPrefixEvent, 1), + autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1), } ndpConfigs := test.config(enable) ndpConfigs.HandleRAs = handle @@ -1270,8 +1281,8 @@ func TestDynamicConfigurationsDisabled(t *testing.T) { t.Errorf("got v6Stats.UnhandledRouterAdvertisements.Value() = %d, want = %d", got, want) } select { - case e := <-ndpDisp.routerC: - t.Errorf("unexpectedly discovered a router when configured not to: %#v", e) + case e := <-ndpDisp.offLinkRouteC: + t.Errorf("unexpectedly updated an off-link route when configured not to: %#v", e) default: } select { @@ -1297,10 +1308,8 @@ func boolToUint64(v bool) uint64 { return 0 } -// Check e to make sure that the event is for addr on nic with ID 1, and the -// discovered flag set to discovered. -func checkRouterEvent(e ndpRouterEvent, addr tcpip.Address, discovered bool) string { - return cmp.Diff(ndpRouterEvent{nicID: 1, addr: addr, discovered: discovered}, e, cmp.AllowUnexported(e)) +func checkOffLinkRouteEvent(e ndpOffLinkRouteEvent, nicID tcpip.NICID, router tcpip.Address, prf header.NDPRoutePreference, updated bool) string { + return cmp.Diff(ndpOffLinkRouteEvent{nicID: nicID, subnet: header.IPv6EmptySubnet, router: router, prf: prf, updated: updated}, e, cmp.AllowUnexported(e)) } func testWithRAs(t *testing.T, f func(*testing.T, ipv6.HandleRAsConfiguration, bool)) { @@ -1333,56 +1342,15 @@ func testWithRAs(t *testing.T, f func(*testing.T, ipv6.HandleRAsConfiguration, b } } -// TestRouterDiscoveryDispatcherNoRemember tests that the stack does not -// remember a discovered router when the dispatcher asks it not to. -func TestRouterDiscoveryDispatcherNoRemember(t *testing.T) { - ndpDisp := ndpDispatcher{ - routerC: make(chan ndpRouterEvent, 1), - } - e := channel.New(0, 1280, linkAddr1) - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ - NDPConfigs: ipv6.NDPConfigurations{ - HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled, - DiscoverDefaultRouters: true, - }, - NDPDisp: &ndpDisp, - })}, - }) - - if err := s.CreateNIC(1, e); err != nil { - t.Fatalf("CreateNIC(1) = %s", err) - } - - // Receive an RA for a router we should not remember. - const lifetimeSeconds = 1 - e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr2, lifetimeSeconds)) - select { - case e := <-ndpDisp.routerC: - if diff := checkRouterEvent(e, llAddr2, true); diff != "" { - t.Errorf("router event mismatch (-want +got):\n%s", diff) - } - default: - t.Fatal("expected router discovery event") - } - - // Wait for the invalidation time plus some buffer to make sure we do - // not actually receive any invalidation events as we should not have - // remembered the router in the first place. - select { - case <-ndpDisp.routerC: - t.Fatal("should not have received any router events") - case <-time.After(lifetimeSeconds*time.Second + defaultAsyncNegativeEventTimeout): - } -} - func TestRouterDiscovery(t *testing.T) { + const nicID = 1 + testWithRAs(t, func(t *testing.T, handleRAs ipv6.HandleRAsConfiguration, forwarding bool) { ndpDisp := ndpDispatcher{ - routerC: make(chan ndpRouterEvent, 1), - rememberRouter: true, + offLinkRouteC: make(chan ndpOffLinkRouteEvent, 1), } e := channel.New(0, 1280, linkAddr1) + clock := faketime.NewManualClock() s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ NDPConfigs: ipv6.NDPConfigurations{ @@ -1391,30 +1359,33 @@ func TestRouterDiscovery(t *testing.T) { }, NDPDisp: &ndpDisp, })}, + Clock: clock, }) - expectRouterEvent := func(addr tcpip.Address, discovered bool) { + expectOffLinkRouteEvent := func(addr tcpip.Address, prf header.NDPRoutePreference, updated bool) { t.Helper() select { - case e := <-ndpDisp.routerC: - if diff := checkRouterEvent(e, addr, discovered); diff != "" { - t.Errorf("router event mismatch (-want +got):\n%s", diff) + case e := <-ndpDisp.offLinkRouteC: + if diff := checkOffLinkRouteEvent(e, nicID, addr, prf, updated); diff != "" { + t.Errorf("off-link route event mismatch (-want +got):\n%s", diff) } default: t.Fatal("expected router discovery event") } } - expectAsyncRouterInvalidationEvent := func(addr tcpip.Address, timeout time.Duration) { + expectAsyncOffLinkRouteInvalidationEvent := func(addr tcpip.Address, timeout time.Duration) { t.Helper() + clock.Advance(timeout) select { - case e := <-ndpDisp.routerC: - if diff := checkRouterEvent(e, addr, false); diff != "" { - t.Errorf("router event mismatch (-want +got):\n%s", diff) + case e := <-ndpDisp.offLinkRouteC: + var prf header.NDPRoutePreference + if diff := checkOffLinkRouteEvent(e, nicID, addr, prf, false); diff != "" { + t.Errorf("off-link route event mismatch (-want +got):\n%s", diff) } - case <-time.After(timeout): + default: t.Fatal("timed out waiting for router discovery event") } } @@ -1423,37 +1394,44 @@ func TestRouterDiscovery(t *testing.T) { t.Fatalf("SetForwardingDefaultAndAllNICs(%d, %t): %s", ipv6.ProtocolNumber, forwarding, err) } - if err := s.CreateNIC(1, e); err != nil { - t.Fatalf("CreateNIC(1) = %s", err) + if err := s.CreateNIC(nicID, e); err != nil { + t.Fatalf("CreateNIC(%d, _): %s", nicID, err) } // Rx an RA from lladdr2 with zero lifetime. It should not be // remembered. - e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr2, 0)) + e.InjectInbound(header.IPv6ProtocolNumber, raBufSimple(llAddr2, 0)) select { - case <-ndpDisp.routerC: - t.Fatal("unexpectedly discovered a router with 0 lifetime") + case <-ndpDisp.offLinkRouteC: + t.Fatal("unexpectedly updated an off-link route with 0 lifetime") default: } - // Rx an RA from lladdr2 with a huge lifetime. - e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr2, 1000)) - expectRouterEvent(llAddr2, true) + // Rx an RA from lladdr2 with a huge lifetime and reserved preference value + // (which should be interpreted as the default (medium) preference value). + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPrf(llAddr2, 1000, header.ReservedRoutePreference)) + expectOffLinkRouteEvent(llAddr2, header.MediumRoutePreference, true) - // Rx an RA from another router (lladdr3) with non-zero lifetime. + // Rx an RA from another router (lladdr3) with non-zero lifetime and + // non-default preference value. const l3LifetimeSeconds = 6 - e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr3, l3LifetimeSeconds)) - expectRouterEvent(llAddr3, true) + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPrf(llAddr3, l3LifetimeSeconds, header.HighRoutePreference)) + expectOffLinkRouteEvent(llAddr3, header.HighRoutePreference, true) - // Rx an RA from lladdr2 with lesser lifetime. + // Rx an RA from lladdr2 with lesser lifetime and default (medium) + // preference value. const l2LifetimeSeconds = 2 - e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr2, l2LifetimeSeconds)) + e.InjectInbound(header.IPv6ProtocolNumber, raBufSimple(llAddr2, l2LifetimeSeconds)) select { - case <-ndpDisp.routerC: - t.Fatal("Should not receive a router event when updating lifetimes for known routers") + case <-ndpDisp.offLinkRouteC: + t.Fatal("should not receive a off-link route event when updating lifetimes for known routers") default: } + // Rx an RA from lladdr2 with a different preference. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPrf(llAddr2, l2LifetimeSeconds, header.LowRoutePreference)) + expectOffLinkRouteEvent(llAddr2, header.LowRoutePreference, true) + // Wait for lladdr2's router invalidation job to execute. The lifetime // of the router should have been updated to the most recent (smaller) // lifetime. @@ -1461,15 +1439,15 @@ func TestRouterDiscovery(t *testing.T) { // Wait for the normal lifetime plus an extra bit for the // router to get invalidated. If we don't get an invalidation // event after this time, then something is wrong. - expectAsyncRouterInvalidationEvent(llAddr2, l2LifetimeSeconds*time.Second+defaultAsyncPositiveEventTimeout) + expectAsyncOffLinkRouteInvalidationEvent(llAddr2, l2LifetimeSeconds*time.Second) // Rx an RA from lladdr2 with huge lifetime. - e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr2, 1000)) - expectRouterEvent(llAddr2, true) + e.InjectInbound(header.IPv6ProtocolNumber, raBufSimple(llAddr2, 1000)) + expectOffLinkRouteEvent(llAddr2, header.MediumRoutePreference, true) // Rx an RA from lladdr2 with zero lifetime. It should be invalidated. - e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr2, 0)) - expectRouterEvent(llAddr2, false) + e.InjectInbound(header.IPv6ProtocolNumber, raBufSimple(llAddr2, 0)) + expectOffLinkRouteEvent(llAddr2, header.MediumRoutePreference, false) // Wait for lladdr3's router invalidation job to execute. The lifetime // of the router should have been updated to the most recent (smaller) @@ -1478,16 +1456,17 @@ func TestRouterDiscovery(t *testing.T) { // Wait for the normal lifetime plus an extra bit for the // router to get invalidated. If we don't get an invalidation // event after this time, then something is wrong. - expectAsyncRouterInvalidationEvent(llAddr3, l3LifetimeSeconds*time.Second+defaultAsyncPositiveEventTimeout) + expectAsyncOffLinkRouteInvalidationEvent(llAddr3, l3LifetimeSeconds*time.Second) }) } // TestRouterDiscoveryMaxRouters tests that only -// ipv6.MaxDiscoveredDefaultRouters discovered routers are remembered. +// ipv6.MaxDiscoveredOffLinkRoutes discovered routers are remembered. func TestRouterDiscoveryMaxRouters(t *testing.T) { + const nicID = 1 + ndpDisp := ndpDispatcher{ - routerC: make(chan ndpRouterEvent, 1), - rememberRouter: true, + offLinkRouteC: make(chan ndpOffLinkRouteEvent, 1), } e := channel.New(0, 1280, linkAddr1) s := stack.New(stack.Options{ @@ -1500,23 +1479,23 @@ func TestRouterDiscoveryMaxRouters(t *testing.T) { })}, }) - if err := s.CreateNIC(1, e); err != nil { - t.Fatalf("CreateNIC(1) = %s", err) + if err := s.CreateNIC(nicID, e); err != nil { + t.Fatalf("CreateNIC(%d, _): %s", nicID, err) } // Receive an RA from 2 more than the max number of discovered routers. - for i := 1; i <= ipv6.MaxDiscoveredDefaultRouters+2; i++ { + for i := 1; i <= ipv6.MaxDiscoveredOffLinkRoutes+2; i++ { linkAddr := []byte{2, 2, 3, 4, 5, 0} linkAddr[5] = byte(i) llAddr := header.LinkLocalAddr(tcpip.LinkAddress(linkAddr)) - e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr, 5)) + e.InjectInbound(header.IPv6ProtocolNumber, raBufSimple(llAddr, 5)) - if i <= ipv6.MaxDiscoveredDefaultRouters { + if i <= ipv6.MaxDiscoveredOffLinkRoutes { select { - case e := <-ndpDisp.routerC: - if diff := checkRouterEvent(e, llAddr, true); diff != "" { - t.Errorf("router event mismatch (-want +got):\n%s", diff) + case e := <-ndpDisp.offLinkRouteC: + if diff := checkOffLinkRouteEvent(e, nicID, llAddr, header.MediumRoutePreference, true); diff != "" { + t.Errorf("off-link route event mismatch (-want +got):\n%s", diff) } default: t.Fatal("expected router discovery event") @@ -1524,7 +1503,7 @@ func TestRouterDiscoveryMaxRouters(t *testing.T) { } else { select { - case <-ndpDisp.routerC: + case <-ndpDisp.offLinkRouteC: t.Fatal("should not have discovered a new router after we already discovered the max number of routers") default: } @@ -1538,51 +1517,6 @@ func checkPrefixEvent(e ndpPrefixEvent, prefix tcpip.Subnet, discovered bool) st return cmp.Diff(ndpPrefixEvent{nicID: 1, prefix: prefix, discovered: discovered}, e, cmp.AllowUnexported(e)) } -// TestPrefixDiscoveryDispatcherNoRemember tests that the stack does not -// remember a discovered on-link prefix when the dispatcher asks it not to. -func TestPrefixDiscoveryDispatcherNoRemember(t *testing.T) { - prefix, subnet, _ := prefixSubnetAddr(0, "") - - ndpDisp := ndpDispatcher{ - prefixC: make(chan ndpPrefixEvent, 1), - } - e := channel.New(0, 1280, linkAddr1) - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ - NDPConfigs: ipv6.NDPConfigurations{ - HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled, - DiscoverOnLinkPrefixes: true, - }, - NDPDisp: &ndpDisp, - })}, - }) - - if err := s.CreateNIC(1, e); err != nil { - t.Fatalf("CreateNIC(1) = %s", err) - } - - // Receive an RA with prefix that we should not remember. - const lifetimeSeconds = 1 - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, false, lifetimeSeconds, 0)) - select { - case e := <-ndpDisp.prefixC: - if diff := checkPrefixEvent(e, subnet, true); diff != "" { - t.Errorf("prefix event mismatch (-want +got):\n%s", diff) - } - default: - t.Fatal("expected prefix discovery event") - } - - // Wait for the invalidation time plus some buffer to make sure we do - // not actually receive any invalidation events as we should not have - // remembered the prefix in the first place. - select { - case <-ndpDisp.prefixC: - t.Fatal("should not have received any prefix events") - case <-time.After(lifetimeSeconds*time.Second + defaultAsyncNegativeEventTimeout): - } -} - func TestPrefixDiscovery(t *testing.T) { prefix1, subnet1, _ := prefixSubnetAddr(0, "") prefix2, subnet2, _ := prefixSubnetAddr(1, "") @@ -1590,10 +1524,10 @@ func TestPrefixDiscovery(t *testing.T) { testWithRAs(t, func(t *testing.T, handleRAs ipv6.HandleRAsConfiguration, forwarding bool) { ndpDisp := ndpDispatcher{ - prefixC: make(chan ndpPrefixEvent, 1), - rememberPrefix: true, + prefixC: make(chan ndpPrefixEvent, 1), } e := channel.New(0, 1280, linkAddr1) + clock := faketime.NewManualClock() s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ NDPConfigs: ipv6.NDPConfigurations{ @@ -1602,6 +1536,7 @@ func TestPrefixDiscovery(t *testing.T) { }, NDPDisp: &ndpDisp, })}, + Clock: clock, }) if err := s.CreateNIC(1, e); err != nil { @@ -1662,12 +1597,13 @@ func TestPrefixDiscovery(t *testing.T) { // Wait for prefix2's most recent invalidation job plus some buffer to // expire. + clock.Advance(time.Duration(lifetime) * time.Second) select { case e := <-ndpDisp.prefixC: if diff := checkPrefixEvent(e, subnet2, false); diff != "" { t.Errorf("prefix event mismatch (-want +got):\n%s", diff) } - case <-time.After(time.Duration(lifetime)*time.Second + defaultAsyncPositiveEventTimeout): + default: t.Fatal("timed out waiting for prefix discovery event") } @@ -1678,17 +1614,6 @@ func TestPrefixDiscovery(t *testing.T) { } func TestPrefixDiscoveryWithInfiniteLifetime(t *testing.T) { - // Update the infinite lifetime value to a smaller value so we can test - // that when we receive a PI with such a lifetime value, we do not - // invalidate the prefix. - const testInfiniteLifetimeSeconds = 2 - const testInfiniteLifetime = testInfiniteLifetimeSeconds * time.Second - saved := header.NDPInfiniteLifetime - header.NDPInfiniteLifetime = testInfiniteLifetime - defer func() { - header.NDPInfiniteLifetime = saved - }() - prefix := tcpip.AddressWithPrefix{ Address: testutil.MustParse6("102:304:506:708::"), PrefixLen: 64, @@ -1696,10 +1621,10 @@ func TestPrefixDiscoveryWithInfiniteLifetime(t *testing.T) { subnet := prefix.Subnet() ndpDisp := ndpDispatcher{ - prefixC: make(chan ndpPrefixEvent, 1), - rememberPrefix: true, + prefixC: make(chan ndpPrefixEvent, 1), } e := channel.New(0, 1280, linkAddr1) + clock := faketime.NewManualClock() s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ NDPConfigs: ipv6.NDPConfigurations{ @@ -1708,6 +1633,7 @@ func TestPrefixDiscoveryWithInfiniteLifetime(t *testing.T) { }, NDPDisp: &ndpDisp, })}, + Clock: clock, }) if err := s.CreateNIC(1, e); err != nil { @@ -1729,46 +1655,39 @@ func TestPrefixDiscoveryWithInfiniteLifetime(t *testing.T) { // Receive an RA with prefix in an NDP Prefix Information option (PI) // with infinite valid lifetime which should not get invalidated. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, false, testInfiniteLifetimeSeconds, 0)) + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, false, infiniteLifetimeSeconds, 0)) expectPrefixEvent(subnet, true) + clock.Advance(header.NDPInfiniteLifetime) select { case <-ndpDisp.prefixC: t.Fatal("unexpectedly invalidated a prefix with infinite lifetime") - case <-time.After(testInfiniteLifetime + defaultAsyncNegativeEventTimeout): + default: } // Receive an RA with finite lifetime. - // The prefix should get invalidated after 1s. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, false, testInfiniteLifetimeSeconds-1, 0)) + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, false, infiniteLifetimeSeconds-1, 0)) + clock.Advance(header.NDPInfiniteLifetime - time.Second) select { case e := <-ndpDisp.prefixC: if diff := checkPrefixEvent(e, subnet, false); diff != "" { t.Errorf("prefix event mismatch (-want +got):\n%s", diff) } - case <-time.After(testInfiniteLifetime): + default: t.Fatal("timed out waiting for prefix discovery event") } // Receive an RA with finite lifetime. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, false, testInfiniteLifetimeSeconds-1, 0)) + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, false, infiniteLifetimeSeconds-1, 0)) expectPrefixEvent(subnet, true) // Receive an RA with prefix with an infinite lifetime. // The prefix should not be invalidated. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, false, testInfiniteLifetimeSeconds, 0)) - select { - case <-ndpDisp.prefixC: - t.Fatal("unexpectedly invalidated a prefix with infinite lifetime") - case <-time.After(testInfiniteLifetime + defaultAsyncNegativeEventTimeout): - } - - // Receive an RA with a prefix with a lifetime value greater than the - // set infinite lifetime value. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, false, testInfiniteLifetimeSeconds+1, 0)) + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, false, infiniteLifetimeSeconds, 0)) + clock.Advance(header.NDPInfiniteLifetime) select { case <-ndpDisp.prefixC: t.Fatal("unexpectedly invalidated a prefix with infinite lifetime") - case <-time.After((testInfiniteLifetimeSeconds+1)*time.Second + defaultAsyncNegativeEventTimeout): + default: } // Receive an RA with 0 lifetime. @@ -1781,8 +1700,7 @@ func TestPrefixDiscoveryWithInfiniteLifetime(t *testing.T) { // ipv6.MaxDiscoveredOnLinkPrefixes discovered on-link prefixes are remembered. func TestPrefixDiscoveryMaxOnLinkPrefixes(t *testing.T) { ndpDisp := ndpDispatcher{ - prefixC: make(chan ndpPrefixEvent, ipv6.MaxDiscoveredOnLinkPrefixes+3), - rememberPrefix: true, + prefixC: make(chan ndpPrefixEvent, ipv6.MaxDiscoveredOnLinkPrefixes+3), } e := channel.New(0, 1280, linkAddr1) s := stack.New(stack.Options{ @@ -1859,17 +1777,12 @@ func checkAutoGenAddrEvent(e ndpAutoGenAddrEvent, addr tcpip.AddressWithPrefix, return cmp.Diff(ndpAutoGenAddrEvent{nicID: 1, addr: addr, eventType: eventType}, e, cmp.AllowUnexported(e)) } +const minVLSeconds = uint32(ipv6.MinPrefixInformationValidLifetimeForUpdate / time.Second) +const infiniteLifetimeSeconds = uint32(header.NDPInfiniteLifetime / time.Second) + // TestAutoGenAddr tests that an address is properly generated and invalidated // when configured to do so. func TestAutoGenAddr(t *testing.T) { - const newMinVL = 2 - newMinVLDuration := newMinVL * time.Second - saved := ipv6.MinPrefixInformationValidLifetimeForUpdate - defer func() { - ipv6.MinPrefixInformationValidLifetimeForUpdate = saved - }() - ipv6.MinPrefixInformationValidLifetimeForUpdate = newMinVLDuration - prefix1, _, addr1 := prefixSubnetAddr(0, linkAddr1) prefix2, _, addr2 := prefixSubnetAddr(1, linkAddr1) @@ -1878,6 +1791,7 @@ func TestAutoGenAddr(t *testing.T) { autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1), } e := channel.New(0, 1280, linkAddr1) + clock := faketime.NewManualClock() s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ NDPConfigs: ipv6.NDPConfigurations{ @@ -1886,6 +1800,7 @@ func TestAutoGenAddr(t *testing.T) { }, NDPDisp: &ndpDisp, })}, + Clock: clock, }) if err := s.SetForwardingDefaultAndAllNICs(ipv6.ProtocolNumber, forwarding); err != nil { @@ -1935,8 +1850,9 @@ func TestAutoGenAddr(t *testing.T) { default: } - // Receive an RA with prefix2 in a PI. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, 100, 0)) + // Receive an RA with prefix2 in a PI with a valid lifetime that exceeds + // the minimum. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, minVLSeconds+1, 0)) expectAutoGenAddrEvent(addr2, newAddr) if !containsV6Addr(s.NICInfo()[1].ProtocolAddresses, addr1) { t.Fatalf("Should have %s in the list of addresses", addr1) @@ -1946,7 +1862,7 @@ func TestAutoGenAddr(t *testing.T) { } // Refresh valid lifetime for addr of prefix1. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, newMinVL, 0)) + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 100, 0)) select { case <-ndpDisp.autoGenAddrC: t.Fatal("unexpectedly auto-generated an address when we already have an address for a prefix") @@ -1954,12 +1870,13 @@ func TestAutoGenAddr(t *testing.T) { } // Wait for addr of prefix1 to be invalidated. + clock.Advance(ipv6.MinPrefixInformationValidLifetimeForUpdate) select { case e := <-ndpDisp.autoGenAddrC: if diff := checkAutoGenAddrEvent(e, addr1, invalidatedAddr); diff != "" { t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) } - case <-time.After(newMinVLDuration + defaultAsyncPositiveEventTimeout): + default: t.Fatal("timed out waiting for addr auto gen event") } if containsV6Addr(s.NICInfo()[1].ProtocolAddresses, addr1) { @@ -1989,20 +1906,7 @@ func addressCheck(addrs []tcpip.ProtocolAddress, containList, notContainList []t // TestAutoGenTempAddr tests that temporary SLAAC addresses are generated when // configured to do so as part of IPv6 Privacy Extensions. func TestAutoGenTempAddr(t *testing.T) { - const ( - nicID = 1 - newMinVL = 5 - newMinVLDuration = newMinVL * time.Second - ) - - savedMinPrefixInformationValidLifetimeForUpdate := ipv6.MinPrefixInformationValidLifetimeForUpdate - savedMaxDesync := ipv6.MaxDesyncFactor - defer func() { - ipv6.MinPrefixInformationValidLifetimeForUpdate = savedMinPrefixInformationValidLifetimeForUpdate - ipv6.MaxDesyncFactor = savedMaxDesync - }() - ipv6.MinPrefixInformationValidLifetimeForUpdate = newMinVLDuration - ipv6.MaxDesyncFactor = time.Nanosecond + const nicID = 1 prefix1, _, addr1 := prefixSubnetAddr(0, linkAddr1) prefix2, _, addr2 := prefixSubnetAddr(1, linkAddr1) @@ -2022,218 +1926,211 @@ func TestAutoGenTempAddr(t *testing.T) { }, } - // This Run will not return until the parallel tests finish. - // - // We need this because we need to do some teardown work after the - // parallel tests complete. - // - // See https://godoc.org/testing#hdr-Subtests_and_Sub_benchmarks for - // more details. - t.Run("group", func(t *testing.T) { - for i, test := range tests { - i := i - test := test - - t.Run(test.name, func(t *testing.T) { - t.Parallel() - - seed := []byte{uint8(i)} - var tempIIDHistory [header.IIDSize]byte - header.InitialTempIID(tempIIDHistory[:], seed, nicID) - newTempAddr := func(stableAddr tcpip.Address) tcpip.AddressWithPrefix { - return header.GenerateTempIPv6SLAACAddr(tempIIDHistory[:], stableAddr) - } - - ndpDisp := ndpDispatcher{ - dadC: make(chan ndpDADEvent, 2), - autoGenAddrC: make(chan ndpAutoGenAddrEvent, 2), - } - e := channel.New(0, 1280, linkAddr1) - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ - DADConfigs: stack.DADConfigurations{ - DupAddrDetectTransmits: test.dupAddrTransmits, - RetransmitTimer: test.retransmitTimer, - }, - NDPConfigs: ipv6.NDPConfigurations{ - HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled, - AutoGenGlobalAddresses: true, - AutoGenTempGlobalAddresses: true, - }, - NDPDisp: &ndpDisp, - TempIIDSeed: seed, - })}, - }) - - if err := s.CreateNIC(nicID, e); err != nil { - t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) - } - - expectAutoGenAddrEvent := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) { - t.Helper() - - select { - case e := <-ndpDisp.autoGenAddrC: - if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" { - t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) - } - default: - t.Fatal("expected addr auto gen event") - } - } - - expectAutoGenAddrEventAsync := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) { - t.Helper() + for i, test := range tests { + t.Run(test.name, func(t *testing.T) { + seed := []byte{uint8(i)} + var tempIIDHistory [header.IIDSize]byte + header.InitialTempIID(tempIIDHistory[:], seed, nicID) + newTempAddr := func(stableAddr tcpip.Address) tcpip.AddressWithPrefix { + return header.GenerateTempIPv6SLAACAddr(tempIIDHistory[:], stableAddr) + } - select { - case e := <-ndpDisp.autoGenAddrC: - if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" { - t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) - } - case <-time.After(defaultAsyncPositiveEventTimeout): - t.Fatal("timed out waiting for addr auto gen event") - } - } + ndpDisp := ndpDispatcher{ + dadC: make(chan ndpDADEvent, 2), + autoGenAddrC: make(chan ndpAutoGenAddrEvent, 2), + } + e := channel.New(0, 1280, linkAddr1) + clock := faketime.NewManualClock() + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ + DADConfigs: stack.DADConfigurations{ + DupAddrDetectTransmits: test.dupAddrTransmits, + RetransmitTimer: test.retransmitTimer, + }, + NDPConfigs: ipv6.NDPConfigurations{ + HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled, + AutoGenGlobalAddresses: true, + AutoGenTempGlobalAddresses: true, + MaxTempAddrValidLifetime: 2 * ipv6.MinPrefixInformationValidLifetimeForUpdate, + MaxTempAddrPreferredLifetime: 2 * ipv6.MinPrefixInformationValidLifetimeForUpdate, + }, + NDPDisp: &ndpDisp, + TempIIDSeed: seed, + })}, + Clock: clock, + }) - expectDADEventAsync := func(addr tcpip.Address) { - t.Helper() + if err := s.CreateNIC(nicID, e); err != nil { + t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) + } - select { - case e := <-ndpDisp.dadC: - if diff := checkDADEvent(e, nicID, addr, &stack.DADSucceeded{}); diff != "" { - t.Errorf("DAD event mismatch (-want +got):\n%s", diff) - } - case <-time.After(time.Duration(test.dupAddrTransmits)*test.retransmitTimer + defaultAsyncPositiveEventTimeout): - t.Fatal("timed out waiting for DAD event") - } - } + expectAutoGenAddrEvent := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) { + t.Helper() - // Receive an RA with prefix1 in an NDP Prefix Information option (PI) - // with zero valid lifetime. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 0, 0)) select { case e := <-ndpDisp.autoGenAddrC: - t.Fatalf("unexpectedly auto-generated an address with 0 lifetime; event = %+v", e) + if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" { + t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) + } default: + t.Fatal("expected addr auto gen event") } + } - // Receive an RA with prefix1 in an NDP Prefix Information option (PI) - // with non-zero valid lifetime. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 100, 0)) - expectAutoGenAddrEvent(addr1, newAddr) - expectDADEventAsync(addr1.Address) + expectAutoGenAddrEventAsync := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) { + t.Helper() + + clock.RunImmediatelyScheduledJobs() select { case e := <-ndpDisp.autoGenAddrC: - t.Fatalf("unexpectedly got an auto gen addr event = %+v", e) + if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" { + t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) + } default: + t.Fatal("timed out waiting for addr auto gen event") } - if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr1}, nil); mismatch != "" { - t.Fatal(mismatch) - } + } - // Receive an RA with prefix1 in an NDP Prefix Information option (PI) - // with non-zero valid & preferred lifetimes. - tempAddr1 := newTempAddr(addr1.Address) - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 100, 100)) - expectAutoGenAddrEvent(tempAddr1, newAddr) - expectDADEventAsync(tempAddr1.Address) - if mismatch := addressCheck(s.NICInfo()[1].ProtocolAddresses, []tcpip.AddressWithPrefix{addr1, tempAddr1}, nil); mismatch != "" { - t.Fatal(mismatch) - } + expectDADEventAsync := func(addr tcpip.Address) { + t.Helper() - // Receive an RA with prefix2 in an NDP Prefix Information option (PI) - // with preferred lifetime > valid lifetime - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 5, 6)) + clock.Advance(time.Duration(test.dupAddrTransmits) * test.retransmitTimer) select { - case e := <-ndpDisp.autoGenAddrC: - t.Fatalf("unexpectedly auto-generated an address with preferred lifetime > valid lifetime; event = %+v", e) + case e := <-ndpDisp.dadC: + if diff := checkDADEvent(e, nicID, addr, &stack.DADSucceeded{}); diff != "" { + t.Errorf("DAD event mismatch (-want +got):\n%s", diff) + } default: + t.Fatal("timed out waiting for DAD event") } - if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr1, tempAddr1}, nil); mismatch != "" { - t.Fatal(mismatch) - } + } - // Receive an RA with prefix2 in a PI w/ non-zero valid and preferred - // lifetimes. - tempAddr2 := newTempAddr(addr2.Address) - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, 100, 100)) - expectAutoGenAddrEvent(addr2, newAddr) - expectDADEventAsync(addr2.Address) - expectAutoGenAddrEventAsync(tempAddr2, newAddr) - expectDADEventAsync(tempAddr2.Address) - if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr1, tempAddr1, addr2, tempAddr2}, nil); mismatch != "" { - t.Fatal(mismatch) - } + // Receive an RA with prefix1 in an NDP Prefix Information option (PI) + // with zero valid lifetime. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 0, 0)) + select { + case e := <-ndpDisp.autoGenAddrC: + t.Fatalf("unexpectedly auto-generated an address with 0 lifetime; event = %+v", e) + default: + } - // Deprecate prefix1. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 100, 0)) - expectAutoGenAddrEvent(addr1, deprecatedAddr) - expectAutoGenAddrEvent(tempAddr1, deprecatedAddr) - if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr1, tempAddr1, addr2, tempAddr2}, nil); mismatch != "" { - t.Fatal(mismatch) - } + // Receive an RA with prefix1 in an NDP Prefix Information option (PI) + // with non-zero valid lifetime. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 100, 0)) + expectAutoGenAddrEvent(addr1, newAddr) + expectDADEventAsync(addr1.Address) + select { + case e := <-ndpDisp.autoGenAddrC: + t.Fatalf("unexpectedly got an auto gen addr event = %+v", e) + default: + } + if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr1}, nil); mismatch != "" { + t.Fatal(mismatch) + } - // Refresh lifetimes for prefix1. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 100, 100)) - if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr1, tempAddr1, addr2, tempAddr2}, nil); mismatch != "" { - t.Fatal(mismatch) - } + // Receive an RA with prefix1 in an NDP Prefix Information option (PI) + // with non-zero valid & preferred lifetimes. + tempAddr1 := newTempAddr(addr1.Address) + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 100, 100)) + expectAutoGenAddrEvent(tempAddr1, newAddr) + expectDADEventAsync(tempAddr1.Address) + if mismatch := addressCheck(s.NICInfo()[1].ProtocolAddresses, []tcpip.AddressWithPrefix{addr1, tempAddr1}, nil); mismatch != "" { + t.Fatal(mismatch) + } - // Reduce valid lifetime and deprecate addresses of prefix1. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, newMinVL, 0)) - expectAutoGenAddrEvent(addr1, deprecatedAddr) - expectAutoGenAddrEvent(tempAddr1, deprecatedAddr) - if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr1, tempAddr1, addr2, tempAddr2}, nil); mismatch != "" { - t.Fatal(mismatch) - } + // Receive an RA with prefix2 in an NDP Prefix Information option (PI) + // with preferred lifetime > valid lifetime + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 5, 6)) + select { + case e := <-ndpDisp.autoGenAddrC: + t.Fatalf("unexpectedly auto-generated an address with preferred lifetime > valid lifetime; event = %+v", e) + default: + } + if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr1, tempAddr1}, nil); mismatch != "" { + t.Fatal(mismatch) + } - // Wait for addrs of prefix1 to be invalidated. They should be - // invalidated at the same time. - select { - case e := <-ndpDisp.autoGenAddrC: - var nextAddr tcpip.AddressWithPrefix - if e.addr == addr1 { - if diff := checkAutoGenAddrEvent(e, addr1, invalidatedAddr); diff != "" { - t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) - } - nextAddr = tempAddr1 - } else { - if diff := checkAutoGenAddrEvent(e, tempAddr1, invalidatedAddr); diff != "" { - t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) - } - nextAddr = addr1 - } + // Receive an RA with prefix2 in a PI with a valid lifetime that exceeds + // the minimum and won't be reached in this test. + tempAddr2 := newTempAddr(addr2.Address) + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, 2*minVLSeconds, 2*minVLSeconds)) + expectAutoGenAddrEvent(addr2, newAddr) + expectDADEventAsync(addr2.Address) + expectAutoGenAddrEventAsync(tempAddr2, newAddr) + expectDADEventAsync(tempAddr2.Address) + if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr1, tempAddr1, addr2, tempAddr2}, nil); mismatch != "" { + t.Fatal(mismatch) + } - select { - case e := <-ndpDisp.autoGenAddrC: - if diff := checkAutoGenAddrEvent(e, nextAddr, invalidatedAddr); diff != "" { - t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) - } - case <-time.After(defaultAsyncPositiveEventTimeout): - t.Fatal("timed out waiting for addr auto gen event") + // Deprecate prefix1. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 100, 0)) + expectAutoGenAddrEvent(addr1, deprecatedAddr) + expectAutoGenAddrEvent(tempAddr1, deprecatedAddr) + if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr1, tempAddr1, addr2, tempAddr2}, nil); mismatch != "" { + t.Fatal(mismatch) + } + + // Refresh lifetimes for prefix1. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 100, 100)) + if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr1, tempAddr1, addr2, tempAddr2}, nil); mismatch != "" { + t.Fatal(mismatch) + } + + // Reduce valid lifetime and deprecate addresses of prefix1. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, minVLSeconds, 0)) + expectAutoGenAddrEvent(addr1, deprecatedAddr) + expectAutoGenAddrEvent(tempAddr1, deprecatedAddr) + if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr1, tempAddr1, addr2, tempAddr2}, nil); mismatch != "" { + t.Fatal(mismatch) + } + + // Wait for addrs of prefix1 to be invalidated. They should be + // invalidated at the same time. + clock.Advance(ipv6.MinPrefixInformationValidLifetimeForUpdate) + select { + case e := <-ndpDisp.autoGenAddrC: + var nextAddr tcpip.AddressWithPrefix + if e.addr == addr1 { + if diff := checkAutoGenAddrEvent(e, addr1, invalidatedAddr); diff != "" { + t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) } - case <-time.After(newMinVLDuration + defaultAsyncPositiveEventTimeout): - t.Fatal("timed out waiting for addr auto gen event") - } - if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr2, tempAddr2}, []tcpip.AddressWithPrefix{addr1, tempAddr1}); mismatch != "" { - t.Fatal(mismatch) + nextAddr = tempAddr1 + } else { + if diff := checkAutoGenAddrEvent(e, tempAddr1, invalidatedAddr); diff != "" { + t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) + } + nextAddr = addr1 } - // Receive an RA with prefix2 in a PI w/ 0 lifetimes. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, 0, 0)) - expectAutoGenAddrEvent(addr2, deprecatedAddr) - expectAutoGenAddrEvent(tempAddr2, deprecatedAddr) select { case e := <-ndpDisp.autoGenAddrC: - t.Errorf("got unexpected auto gen addr event = %+v", e) + if diff := checkAutoGenAddrEvent(e, nextAddr, invalidatedAddr); diff != "" { + t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) + } default: + t.Fatal("timed out waiting for addr auto gen event") } - if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr2, tempAddr2}, []tcpip.AddressWithPrefix{addr1, tempAddr1}); mismatch != "" { - t.Fatal(mismatch) - } - }) - } - }) + default: + t.Fatal("timed out waiting for addr auto gen event") + } + if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr2, tempAddr2}, []tcpip.AddressWithPrefix{addr1, tempAddr1}); mismatch != "" { + t.Fatal(mismatch) + } + + // Receive an RA with prefix2 in a PI w/ 0 lifetimes. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, 0, 0)) + expectAutoGenAddrEvent(addr2, deprecatedAddr) + expectAutoGenAddrEvent(tempAddr2, deprecatedAddr) + select { + case e := <-ndpDisp.autoGenAddrC: + t.Errorf("got unexpected auto gen addr event = %+v", e) + default: + } + if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr2, tempAddr2}, []tcpip.AddressWithPrefix{addr1, tempAddr1}); mismatch != "" { + t.Fatal(mismatch) + } + }) + } } // TestNoAutoGenTempAddrForLinkLocal test that temporary SLAAC addresses are not @@ -2241,12 +2138,6 @@ func TestAutoGenTempAddr(t *testing.T) { func TestNoAutoGenTempAddrForLinkLocal(t *testing.T) { const nicID = 1 - savedMaxDesyncFactor := ipv6.MaxDesyncFactor - defer func() { - ipv6.MaxDesyncFactor = savedMaxDesyncFactor - }() - ipv6.MaxDesyncFactor = time.Nanosecond - tests := []struct { name string dupAddrTransmits uint8 @@ -2262,66 +2153,56 @@ func TestNoAutoGenTempAddrForLinkLocal(t *testing.T) { }, } - // This Run will not return until the parallel tests finish. - // - // We need this because we need to do some teardown work after the - // parallel tests complete. - // - // See https://godoc.org/testing#hdr-Subtests_and_Sub_benchmarks for - // more details. - t.Run("group", func(t *testing.T) { - for _, test := range tests { - test := test - - t.Run(test.name, func(t *testing.T) { - t.Parallel() - - ndpDisp := ndpDispatcher{ - dadC: make(chan ndpDADEvent, 1), - autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1), - } - e := channel.New(0, 1280, linkAddr1) - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ - NDPConfigs: ipv6.NDPConfigurations{ - AutoGenTempGlobalAddresses: true, - }, - NDPDisp: &ndpDisp, - AutoGenLinkLocal: true, - })}, - }) + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + ndpDisp := ndpDispatcher{ + dadC: make(chan ndpDADEvent, 1), + autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1), + } + e := channel.New(0, 1280, linkAddr1) + clock := faketime.NewManualClock() + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ + NDPConfigs: ipv6.NDPConfigurations{ + AutoGenTempGlobalAddresses: true, + }, + NDPDisp: &ndpDisp, + AutoGenLinkLocal: true, + })}, + Clock: clock, + }) - if err := s.CreateNIC(nicID, e); err != nil { - t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) - } + if err := s.CreateNIC(nicID, e); err != nil { + t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) + } - // The stable link-local address should auto-generate and resolve DAD. - select { - case e := <-ndpDisp.autoGenAddrC: - if diff := checkAutoGenAddrEvent(e, tcpip.AddressWithPrefix{Address: llAddr1, PrefixLen: header.IIDOffsetInIPv6Address * 8}, newAddr); diff != "" { - t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) - } - default: - t.Fatal("expected addr auto gen event") + // The stable link-local address should auto-generate and resolve DAD. + select { + case e := <-ndpDisp.autoGenAddrC: + if diff := checkAutoGenAddrEvent(e, tcpip.AddressWithPrefix{Address: llAddr1, PrefixLen: header.IIDOffsetInIPv6Address * 8}, newAddr); diff != "" { + t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) } - select { - case e := <-ndpDisp.dadC: - if diff := checkDADEvent(e, nicID, llAddr1, &stack.DADSucceeded{}); diff != "" { - t.Errorf("DAD event mismatch (-want +got):\n%s", diff) - } - case <-time.After(time.Duration(test.dupAddrTransmits)*test.retransmitTimer + defaultAsyncPositiveEventTimeout): - t.Fatal("timed out waiting for DAD event") + default: + t.Fatal("expected addr auto gen event") + } + clock.Advance(time.Duration(test.dupAddrTransmits) * test.retransmitTimer) + select { + case e := <-ndpDisp.dadC: + if diff := checkDADEvent(e, nicID, llAddr1, &stack.DADSucceeded{}); diff != "" { + t.Errorf("DAD event mismatch (-want +got):\n%s", diff) } + default: + t.Fatal("timed out waiting for DAD event") + } - // No new addresses should be generated. - select { - case e := <-ndpDisp.autoGenAddrC: - t.Errorf("got unxpected auto gen addr event = %+v", e) - case <-time.After(defaultAsyncNegativeEventTimeout): - } - }) - } - }) + // No new addresses should be generated. + select { + case e := <-ndpDisp.autoGenAddrC: + t.Errorf("got unxpected auto gen addr event = %+v", e) + default: + } + }) + } } // TestNoAutoGenTempAddrWithoutStableAddr tests that a temporary SLAAC address @@ -2334,12 +2215,6 @@ func TestNoAutoGenTempAddrWithoutStableAddr(t *testing.T) { retransmitTimer = 2 * time.Second ) - savedMaxDesyncFactor := ipv6.MaxDesyncFactor - defer func() { - ipv6.MaxDesyncFactor = savedMaxDesyncFactor - }() - ipv6.MaxDesyncFactor = 0 - prefix, _, addr := prefixSubnetAddr(0, linkAddr1) var tempIIDHistory [header.IIDSize]byte header.InitialTempIID(tempIIDHistory[:], nil, nicID) @@ -2350,6 +2225,7 @@ func TestNoAutoGenTempAddrWithoutStableAddr(t *testing.T) { autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1), } e := channel.New(0, 1280, linkAddr1) + clock := faketime.NewManualClock() s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ DADConfigs: stack.DADConfigurations{ @@ -2363,6 +2239,7 @@ func TestNoAutoGenTempAddrWithoutStableAddr(t *testing.T) { }, NDPDisp: &ndpDisp, })}, + Clock: clock, }) if err := s.CreateNIC(nicID, e); err != nil { @@ -2392,12 +2269,13 @@ func TestNoAutoGenTempAddrWithoutStableAddr(t *testing.T) { // Wait for DAD to complete for the stable address then expect the temporary // address to be generated. + clock.Advance(dadTransmits * retransmitTimer) select { case e := <-ndpDisp.dadC: if diff := checkDADEvent(e, nicID, addr.Address, &stack.DADSucceeded{}); diff != "" { t.Errorf("DAD event mismatch (-want +got):\n%s", diff) } - case <-time.After(dadTransmits*retransmitTimer + defaultAsyncPositiveEventTimeout): + default: t.Fatal("timed out waiting for DAD event") } select { @@ -2405,7 +2283,7 @@ func TestNoAutoGenTempAddrWithoutStableAddr(t *testing.T) { if diff := checkAutoGenAddrEvent(e, tempAddr, newAddr); diff != "" { t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) } - case <-time.After(defaultAsyncPositiveEventTimeout): + default: t.Fatal("timed out waiting for addr auto gen event") } } @@ -2414,46 +2292,44 @@ func TestNoAutoGenTempAddrWithoutStableAddr(t *testing.T) { // regenerated. func TestAutoGenTempAddrRegen(t *testing.T) { const ( - nicID = 1 - regenAfter = 2 * time.Second - newMinVL = 10 - newMinVLDuration = newMinVL * time.Second - ) + nicID = 1 + regenAdv = 2 * time.Second - savedMaxDesyncFactor := ipv6.MaxDesyncFactor - savedMinMaxTempAddrPreferredLifetime := ipv6.MinMaxTempAddrPreferredLifetime - savedMinMaxTempAddrValidLifetime := ipv6.MinMaxTempAddrValidLifetime - defer func() { - ipv6.MaxDesyncFactor = savedMaxDesyncFactor - ipv6.MinMaxTempAddrPreferredLifetime = savedMinMaxTempAddrPreferredLifetime - ipv6.MinMaxTempAddrValidLifetime = savedMinMaxTempAddrValidLifetime - }() - ipv6.MaxDesyncFactor = 0 - ipv6.MinMaxTempAddrPreferredLifetime = newMinVLDuration - ipv6.MinMaxTempAddrValidLifetime = newMinVLDuration + numTempAddrs = 3 + maxTempAddrValidLifetime = numTempAddrs * ipv6.MinPrefixInformationValidLifetimeForUpdate + ) prefix, _, addr := prefixSubnetAddr(0, linkAddr1) var tempIIDHistory [header.IIDSize]byte header.InitialTempIID(tempIIDHistory[:], nil, nicID) - tempAddr1 := header.GenerateTempIPv6SLAACAddr(tempIIDHistory[:], addr.Address) - tempAddr2 := header.GenerateTempIPv6SLAACAddr(tempIIDHistory[:], addr.Address) - tempAddr3 := header.GenerateTempIPv6SLAACAddr(tempIIDHistory[:], addr.Address) + var tempAddrs [numTempAddrs]tcpip.AddressWithPrefix + for i := 0; i < len(tempAddrs); i++ { + tempAddrs[i] = header.GenerateTempIPv6SLAACAddr(tempIIDHistory[:], addr.Address) + } ndpDisp := ndpDispatcher{ autoGenAddrC: make(chan ndpAutoGenAddrEvent, 2), } e := channel.New(0, 1280, linkAddr1) ndpConfigs := ipv6.NDPConfigurations{ - HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled, - AutoGenGlobalAddresses: true, - AutoGenTempGlobalAddresses: true, - RegenAdvanceDuration: newMinVLDuration - regenAfter, + HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled, + AutoGenGlobalAddresses: true, + AutoGenTempGlobalAddresses: true, + RegenAdvanceDuration: regenAdv, + MaxTempAddrValidLifetime: maxTempAddrValidLifetime, + MaxTempAddrPreferredLifetime: ipv6.MinPrefixInformationValidLifetimeForUpdate, + } + clock := faketime.NewManualClock() + randSource := savingRandSource{ + s: rand.NewSource(time.Now().UnixNano()), } s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ NDPConfigs: ndpConfigs, NDPDisp: &ndpDisp, })}, + Clock: clock, + RandSource: &randSource, }) if err := s.CreateNIC(nicID, e); err != nil { @@ -2476,36 +2352,43 @@ func TestAutoGenTempAddrRegen(t *testing.T) { expectAutoGenAddrEventAsync := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType, timeout time.Duration) { t.Helper() + clock.Advance(timeout) select { case e := <-ndpDisp.autoGenAddrC: if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" { t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) } - case <-time.After(timeout): + default: t.Fatal("timed out waiting for addr auto gen event") } } + tempDesyncFactor := time.Duration(randSource.lastInt63) % ipv6.MaxDesyncFactor + effectiveMaxTempAddrPL := ipv6.MinPrefixInformationValidLifetimeForUpdate - tempDesyncFactor + // The time since the last regeneration before a new temporary address is + // generated. + tempAddrRegenenerationTime := effectiveMaxTempAddrPL - regenAdv + // Receive an RA with prefix1 in an NDP Prefix Information option (PI) // with non-zero valid & preferred lifetimes. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, 100, 100)) + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, minVLSeconds, minVLSeconds)) expectAutoGenAddrEvent(addr, newAddr) - expectAutoGenAddrEvent(tempAddr1, newAddr) - if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr, tempAddr1}, nil); mismatch != "" { + expectAutoGenAddrEvent(tempAddrs[0], newAddr) + if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr, tempAddrs[0]}, nil); mismatch != "" { t.Fatal(mismatch) } // Wait for regeneration - expectAutoGenAddrEventAsync(tempAddr2, newAddr, regenAfter+defaultAsyncPositiveEventTimeout) - if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr, tempAddr1, tempAddr2}, nil); mismatch != "" { + expectAutoGenAddrEventAsync(tempAddrs[1], newAddr, tempAddrRegenenerationTime) + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, minVLSeconds, minVLSeconds)) + if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr, tempAddrs[0], tempAddrs[1]}, nil); mismatch != "" { t.Fatal(mismatch) } + expectAutoGenAddrEventAsync(tempAddrs[0], deprecatedAddr, regenAdv) // Wait for regeneration - expectAutoGenAddrEventAsync(tempAddr3, newAddr, regenAfter+defaultAsyncPositiveEventTimeout) - if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr, tempAddr1, tempAddr2, tempAddr3}, nil); mismatch != "" { - t.Fatal(mismatch) - } + expectAutoGenAddrEventAsync(tempAddrs[2], newAddr, tempAddrRegenenerationTime-regenAdv) + expectAutoGenAddrEventAsync(tempAddrs[1], deprecatedAddr, regenAdv) // Stop generating temporary addresses ndpConfigs.AutoGenTempGlobalAddresses = false @@ -2516,45 +2399,24 @@ func TestAutoGenTempAddrRegen(t *testing.T) { ndpEP.SetNDPConfigurations(ndpConfigs) } + // Refresh lifetimes and wait for the last temporary address to be deprecated. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, minVLSeconds, minVLSeconds)) + expectAutoGenAddrEventAsync(tempAddrs[2], deprecatedAddr, effectiveMaxTempAddrPL-regenAdv) + + // Refresh lifetimes such that the prefix is valid and preferred forever. + // + // This should not affect the lifetimes of temporary addresses because they + // are capped by the maximum valid and preferred lifetimes for temporary + // addresses. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, infiniteLifetimeSeconds, infiniteLifetimeSeconds)) + // Wait for all the temporary addresses to get invalidated. - tempAddrs := []tcpip.AddressWithPrefix{tempAddr1, tempAddr2, tempAddr3} - invalidateAfter := newMinVLDuration - 2*regenAfter + invalidateAfter := maxTempAddrValidLifetime - clock.NowMonotonic().Sub(tcpip.MonotonicTime{}) for _, addr := range tempAddrs { - // Wait for a deprecation then invalidation event, or just an invalidation - // event. We need to cover both cases but cannot deterministically hit both - // cases because the deprecation and invalidation jobs could execute in any - // order. - select { - case e := <-ndpDisp.autoGenAddrC: - if diff := checkAutoGenAddrEvent(e, addr, deprecatedAddr); diff == "" { - // If we get a deprecation event first, we should get an invalidation - // event almost immediately after. - select { - case e := <-ndpDisp.autoGenAddrC: - if diff := checkAutoGenAddrEvent(e, addr, invalidatedAddr); diff != "" { - t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) - } - case <-time.After(defaultAsyncPositiveEventTimeout): - t.Fatal("timed out waiting for addr auto gen event") - } - } else if diff := checkAutoGenAddrEvent(e, addr, invalidatedAddr); diff == "" { - // If we get an invalidation event first, we shouldn't get a deprecation - // event after. - select { - case e := <-ndpDisp.autoGenAddrC: - t.Fatalf("unexpectedly got an auto-generated event = %+v", e) - case <-time.After(defaultAsyncNegativeEventTimeout): - } - } else { - t.Fatalf("got unexpected auto-generated event = %+v", e) - } - case <-time.After(invalidateAfter + defaultAsyncPositiveEventTimeout): - t.Fatal("timed out waiting for addr auto gen event") - } - - invalidateAfter = regenAfter + expectAutoGenAddrEventAsync(addr, invalidatedAddr, invalidateAfter) + invalidateAfter = tempAddrRegenenerationTime } - if mismatch := addressCheck(s.NICInfo()[1].ProtocolAddresses, []tcpip.AddressWithPrefix{addr}, tempAddrs); mismatch != "" { + if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr}, tempAddrs[:]); mismatch != "" { t.Fatal(mismatch) } } @@ -2563,52 +2425,54 @@ func TestAutoGenTempAddrRegen(t *testing.T) { // regeneration job gets updated when refreshing the address's lifetimes. func TestAutoGenTempAddrRegenJobUpdates(t *testing.T) { const ( - nicID = 1 - regenAfter = 2 * time.Second - newMinVL = 10 - newMinVLDuration = newMinVL * time.Second - ) + nicID = 1 + regenAdv = 2 * time.Second - savedMaxDesyncFactor := ipv6.MaxDesyncFactor - savedMinMaxTempAddrPreferredLifetime := ipv6.MinMaxTempAddrPreferredLifetime - savedMinMaxTempAddrValidLifetime := ipv6.MinMaxTempAddrValidLifetime - defer func() { - ipv6.MaxDesyncFactor = savedMaxDesyncFactor - ipv6.MinMaxTempAddrPreferredLifetime = savedMinMaxTempAddrPreferredLifetime - ipv6.MinMaxTempAddrValidLifetime = savedMinMaxTempAddrValidLifetime - }() - ipv6.MaxDesyncFactor = 0 - ipv6.MinMaxTempAddrPreferredLifetime = newMinVLDuration - ipv6.MinMaxTempAddrValidLifetime = newMinVLDuration + numTempAddrs = 3 + maxTempAddrPreferredLifetime = ipv6.MinPrefixInformationValidLifetimeForUpdate + maxTempAddrPreferredLifetimeSeconds = uint32(maxTempAddrPreferredLifetime / time.Second) + ) prefix, _, addr := prefixSubnetAddr(0, linkAddr1) var tempIIDHistory [header.IIDSize]byte header.InitialTempIID(tempIIDHistory[:], nil, nicID) - tempAddr1 := header.GenerateTempIPv6SLAACAddr(tempIIDHistory[:], addr.Address) - tempAddr2 := header.GenerateTempIPv6SLAACAddr(tempIIDHistory[:], addr.Address) - tempAddr3 := header.GenerateTempIPv6SLAACAddr(tempIIDHistory[:], addr.Address) + var tempAddrs [numTempAddrs]tcpip.AddressWithPrefix + for i := 0; i < len(tempAddrs); i++ { + tempAddrs[i] = header.GenerateTempIPv6SLAACAddr(tempIIDHistory[:], addr.Address) + } ndpDisp := ndpDispatcher{ autoGenAddrC: make(chan ndpAutoGenAddrEvent, 2), } e := channel.New(0, 1280, linkAddr1) ndpConfigs := ipv6.NDPConfigurations{ - HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled, - AutoGenGlobalAddresses: true, - AutoGenTempGlobalAddresses: true, - RegenAdvanceDuration: newMinVLDuration - regenAfter, + HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled, + AutoGenGlobalAddresses: true, + AutoGenTempGlobalAddresses: true, + RegenAdvanceDuration: regenAdv, + MaxTempAddrPreferredLifetime: maxTempAddrPreferredLifetime, + MaxTempAddrValidLifetime: maxTempAddrPreferredLifetime * 2, + } + clock := faketime.NewManualClock() + initialTime := clock.NowMonotonic() + randSource := savingRandSource{ + s: rand.NewSource(time.Now().UnixNano()), } s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ NDPConfigs: ndpConfigs, NDPDisp: &ndpDisp, })}, + Clock: clock, + RandSource: &randSource, }) if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) } + tempDesyncFactor := time.Duration(randSource.lastInt63) % ipv6.MaxDesyncFactor + expectAutoGenAddrEvent := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) { t.Helper() @@ -2625,22 +2489,23 @@ func TestAutoGenTempAddrRegenJobUpdates(t *testing.T) { expectAutoGenAddrEventAsync := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType, timeout time.Duration) { t.Helper() + clock.Advance(timeout) select { case e := <-ndpDisp.autoGenAddrC: if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" { t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) } - case <-time.After(timeout): + default: t.Fatal("timed out waiting for addr auto gen event") } } // Receive an RA with prefix1 in an NDP Prefix Information option (PI) // with non-zero valid & preferred lifetimes. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, 100, 100)) + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, maxTempAddrPreferredLifetimeSeconds, maxTempAddrPreferredLifetimeSeconds)) expectAutoGenAddrEvent(addr, newAddr) - expectAutoGenAddrEvent(tempAddr1, newAddr) - if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr, tempAddr1}, nil); mismatch != "" { + expectAutoGenAddrEvent(tempAddrs[0], newAddr) + if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr, tempAddrs[0]}, nil); mismatch != "" { t.Fatal(mismatch) } @@ -2648,13 +2513,27 @@ func TestAutoGenTempAddrRegenJobUpdates(t *testing.T) { // // A new temporary address should be generated after the regeneration // time has passed since the prefix is deprecated. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, 100, 0)) + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, maxTempAddrPreferredLifetimeSeconds, 0)) expectAutoGenAddrEvent(addr, deprecatedAddr) - expectAutoGenAddrEvent(tempAddr1, deprecatedAddr) + expectAutoGenAddrEvent(tempAddrs[0], deprecatedAddr) select { case e := <-ndpDisp.autoGenAddrC: - t.Fatalf("unexpected auto gen addr event = %+v", e) - case <-time.After(regenAfter + defaultAsyncNegativeEventTimeout): + t.Fatalf("unexpected auto gen addr event = %#v", e) + default: + } + + effectiveMaxTempAddrPL := maxTempAddrPreferredLifetime - tempDesyncFactor + // The time since the last regeneration before a new temporary address is + // generated. + tempAddrRegenenerationTime := effectiveMaxTempAddrPL - regenAdv + + // Advance the clock by the regeneration time but don't expect a new temporary + // address as the prefix is deprecated. + clock.Advance(tempAddrRegenenerationTime) + select { + case e := <-ndpDisp.autoGenAddrC: + t.Fatalf("unexpected auto gen addr event = %#v", e) + default: } // Prefer the prefix again. @@ -2662,8 +2541,15 @@ func TestAutoGenTempAddrRegenJobUpdates(t *testing.T) { // A new temporary address should immediately be generated since the // regeneration time has already passed since the last address was generated // - this regeneration does not depend on a job. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, 100, 100)) - expectAutoGenAddrEvent(tempAddr2, newAddr) + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, maxTempAddrPreferredLifetimeSeconds, maxTempAddrPreferredLifetimeSeconds)) + expectAutoGenAddrEvent(tempAddrs[1], newAddr) + // Wait for the first temporary address to be deprecated. + expectAutoGenAddrEventAsync(tempAddrs[0], deprecatedAddr, regenAdv) + select { + case e := <-ndpDisp.autoGenAddrC: + t.Fatalf("unexpected auto gen addr event = %s", e) + default: + } // Increase the maximum lifetimes for temporary addresses to large values // then refresh the lifetimes of the prefix. @@ -2674,34 +2560,30 @@ func TestAutoGenTempAddrRegenJobUpdates(t *testing.T) { // regenerate a new temporary address. Note, new addresses are only // regenerated after the preferred lifetime - the regenerate advance duration // as paased. - ndpConfigs.MaxTempAddrValidLifetime = 100 * time.Second - ndpConfigs.MaxTempAddrPreferredLifetime = 100 * time.Second + const largeLifetimeSeconds = minVLSeconds * 2 + const largeLifetime = time.Duration(largeLifetimeSeconds) * time.Second + ndpConfigs.MaxTempAddrValidLifetime = 2 * largeLifetime + ndpConfigs.MaxTempAddrPreferredLifetime = largeLifetime ipv6Ep, err := s.GetNetworkEndpoint(nicID, header.IPv6ProtocolNumber) if err != nil { t.Fatalf("s.GetNetworkEndpoint(%d, %d): %s", nicID, header.IPv6ProtocolNumber, err) } ndpEP := ipv6Ep.(ipv6.NDPEndpoint) ndpEP.SetNDPConfigurations(ndpConfigs) - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, 100, 100)) + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, largeLifetimeSeconds, largeLifetimeSeconds)) + timeSinceInitialTime := clock.NowMonotonic().Sub(initialTime) + clock.Advance(largeLifetime - timeSinceInitialTime) + expectAutoGenAddrEvent(tempAddrs[0], deprecatedAddr) + // to offset the advement of time to test the first temporary address's + // deprecation after the second was generated + advLess := regenAdv + expectAutoGenAddrEventAsync(tempAddrs[2], newAddr, timeSinceInitialTime-advLess-(tempDesyncFactor+regenAdv)) + expectAutoGenAddrEventAsync(tempAddrs[1], deprecatedAddr, regenAdv) select { case e := <-ndpDisp.autoGenAddrC: t.Fatalf("unexpected auto gen addr event = %+v", e) - case <-time.After(regenAfter + defaultAsyncNegativeEventTimeout): + default: } - - // Set the maximum lifetimes for temporary addresses such that on the next - // RA, the regeneration job gets scheduled again. - // - // The maximum lifetime is the sum of the minimum lifetimes for temporary - // addresses + the time that has already passed since the last address was - // generated so that the regeneration job is needed to generate the next - // address. - newLifetimes := newMinVLDuration + regenAfter + defaultAsyncNegativeEventTimeout - ndpConfigs.MaxTempAddrValidLifetime = newLifetimes - ndpConfigs.MaxTempAddrPreferredLifetime = newLifetimes - ndpEP.SetNDPConfigurations(ndpConfigs) - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, 100, 100)) - expectAutoGenAddrEventAsync(tempAddr3, newAddr, regenAfter+defaultAsyncPositiveEventTimeout) } // TestMixedSLAACAddrConflictRegen tests SLAAC address regeneration in response @@ -2929,13 +2811,14 @@ func TestMixedSLAACAddrConflictRegen(t *testing.T) { // stack.Stack will have a default route through the router (llAddr3) installed // and a static link-address (linkAddr3) added to the link address cache for the // router. -func stackAndNdpDispatcherWithDefaultRoute(t *testing.T, nicID tcpip.NICID) (*ndpDispatcher, *channel.Endpoint, *stack.Stack) { +func stackAndNdpDispatcherWithDefaultRoute(t *testing.T, nicID tcpip.NICID) (*ndpDispatcher, *channel.Endpoint, *stack.Stack, *faketime.ManualClock) { t.Helper() ndpDisp := &ndpDispatcher{ autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1), } e := channel.New(0, 1280, linkAddr1) e.LinkEPCapabilities |= stack.CapabilityResolutionRequired + clock := faketime.NewManualClock() s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ NDPConfigs: ipv6.NDPConfigurations{ @@ -2945,6 +2828,7 @@ func stackAndNdpDispatcherWithDefaultRoute(t *testing.T, nicID tcpip.NICID) (*nd NDPDisp: ndpDisp, })}, TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}, + Clock: clock, }) if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) @@ -2958,7 +2842,7 @@ func stackAndNdpDispatcherWithDefaultRoute(t *testing.T, nicID tcpip.NICID) (*nd if err := s.AddStaticNeighbor(nicID, ipv6.ProtocolNumber, llAddr3, linkAddr3); err != nil { t.Fatalf("s.AddStaticNeighbor(%d, %d, %s, %s): %s", nicID, ipv6.ProtocolNumber, llAddr3, linkAddr3, err) } - return ndpDisp, e, s + return ndpDisp, e, s, clock } // addrForNewConnectionTo returns the local address used when creating a new @@ -3032,7 +2916,7 @@ func TestAutoGenAddrDeprecateFromPI(t *testing.T) { prefix1, _, addr1 := prefixSubnetAddr(0, linkAddr1) prefix2, _, addr2 := prefixSubnetAddr(1, linkAddr1) - ndpDisp, e, s := stackAndNdpDispatcherWithDefaultRoute(t, nicID) + ndpDisp, e, s, _ := stackAndNdpDispatcherWithDefaultRoute(t, nicID) expectAutoGenAddrEvent := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) { t.Helper() @@ -3135,19 +3019,11 @@ func TestAutoGenAddrDeprecateFromPI(t *testing.T) { // when its preferred lifetime expires. func TestAutoGenAddrJobDeprecation(t *testing.T) { const nicID = 1 - const newMinVL = 2 - newMinVLDuration := newMinVL * time.Second - - saved := ipv6.MinPrefixInformationValidLifetimeForUpdate - defer func() { - ipv6.MinPrefixInformationValidLifetimeForUpdate = saved - }() - ipv6.MinPrefixInformationValidLifetimeForUpdate = newMinVLDuration prefix1, _, addr1 := prefixSubnetAddr(0, linkAddr1) prefix2, _, addr2 := prefixSubnetAddr(1, linkAddr1) - ndpDisp, e, s := stackAndNdpDispatcherWithDefaultRoute(t, nicID) + ndpDisp, e, s, clock := stackAndNdpDispatcherWithDefaultRoute(t, nicID) expectAutoGenAddrEvent := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) { t.Helper() @@ -3165,12 +3041,13 @@ func TestAutoGenAddrJobDeprecation(t *testing.T) { expectAutoGenAddrEventAfter := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType, timeout time.Duration) { t.Helper() + clock.Advance(timeout) select { case e := <-ndpDisp.autoGenAddrC: if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" { t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) } - case <-time.After(timeout): + default: t.Fatal("timed out waiting for addr auto gen event") } } @@ -3188,7 +3065,7 @@ func TestAutoGenAddrJobDeprecation(t *testing.T) { } // Receive PI for prefix2. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, 100, 100)) + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, infiniteLifetimeSeconds, infiniteLifetimeSeconds)) expectAutoGenAddrEvent(addr2, newAddr) if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr2) { t.Fatalf("should have %s in the list of addresses", addr2) @@ -3207,7 +3084,7 @@ func TestAutoGenAddrJobDeprecation(t *testing.T) { expectPrimaryAddr(addr1) // Refresh lifetime for addr of prefix1. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, newMinVL, newMinVL-1)) + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, minVLSeconds, minVLSeconds-1)) select { case <-ndpDisp.autoGenAddrC: t.Fatal("unexpectedly got an auto-generated event") @@ -3216,7 +3093,7 @@ func TestAutoGenAddrJobDeprecation(t *testing.T) { expectPrimaryAddr(addr1) // Wait for addr of prefix1 to be deprecated. - expectAutoGenAddrEventAfter(addr1, deprecatedAddr, newMinVLDuration-time.Second+defaultAsyncPositiveEventTimeout) + expectAutoGenAddrEventAfter(addr1, deprecatedAddr, ipv6.MinPrefixInformationValidLifetimeForUpdate-time.Second) if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) { t.Fatalf("should not have %s in the list of addresses", addr1) } @@ -3226,6 +3103,7 @@ func TestAutoGenAddrJobDeprecation(t *testing.T) { // addr2 should be the primary endpoint now since addr1 is deprecated but // addr2 is not. expectPrimaryAddr(addr2) + // addr1 is deprecated but if explicitly requested, it should be used. fullAddr1 := tcpip.FullAddress{Addr: addr1.Address, NIC: nicID} if got := addrForNewConnectionWithAddr(t, s, fullAddr1); got != addr1.Address { @@ -3234,7 +3112,7 @@ func TestAutoGenAddrJobDeprecation(t *testing.T) { // Refresh valid lifetime for addr of prefix1, w/ 0 preferred lifetime to make // sure we do not get a deprecation event again. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, newMinVL, 0)) + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, minVLSeconds, 0)) select { case <-ndpDisp.autoGenAddrC: t.Fatal("unexpectedly got an auto-generated event") @@ -3246,7 +3124,7 @@ func TestAutoGenAddrJobDeprecation(t *testing.T) { } // Refresh lifetimes for addr of prefix1. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, newMinVL, newMinVL-1)) + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, minVLSeconds, minVLSeconds-1)) select { case <-ndpDisp.autoGenAddrC: t.Fatal("unexpectedly got an auto-generated event") @@ -3256,7 +3134,7 @@ func TestAutoGenAddrJobDeprecation(t *testing.T) { expectPrimaryAddr(addr1) // Wait for addr of prefix1 to be deprecated. - expectAutoGenAddrEventAfter(addr1, deprecatedAddr, newMinVLDuration-time.Second+defaultAsyncPositiveEventTimeout) + expectAutoGenAddrEventAfter(addr1, deprecatedAddr, ipv6.MinPrefixInformationValidLifetimeForUpdate-time.Second) if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) { t.Fatalf("should not have %s in the list of addresses", addr1) } @@ -3270,7 +3148,7 @@ func TestAutoGenAddrJobDeprecation(t *testing.T) { } // Wait for addr of prefix1 to be invalidated. - expectAutoGenAddrEventAfter(addr1, invalidatedAddr, time.Second+defaultAsyncPositiveEventTimeout) + expectAutoGenAddrEventAfter(addr1, invalidatedAddr, time.Second) if containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) { t.Fatalf("should not have %s in the list of addresses", addr1) } @@ -3280,7 +3158,7 @@ func TestAutoGenAddrJobDeprecation(t *testing.T) { expectPrimaryAddr(addr2) // Refresh both lifetimes for addr of prefix2 to the same value. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, newMinVL, newMinVL)) + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, minVLSeconds, minVLSeconds)) select { case <-ndpDisp.autoGenAddrC: t.Fatal("unexpectedly got an auto-generated event") @@ -3292,6 +3170,17 @@ func TestAutoGenAddrJobDeprecation(t *testing.T) { // cases because the deprecation and invalidation handlers could be handled in // either deprecation then invalidation, or invalidation then deprecation // (which should be cancelled by the invalidation handler). + // + // Since we're about to cause both events to fire, we need the dispatcher + // channel to be able to hold both. + if got, want := len(ndpDisp.autoGenAddrC), 0; got != want { + t.Fatalf("got len(ndpDisp.autoGenAddrC) = %d, want %d", got, want) + } + if got, want := cap(ndpDisp.autoGenAddrC), 1; got != want { + t.Fatalf("got cap(ndpDisp.autoGenAddrC) = %d, want %d", got, want) + } + ndpDisp.autoGenAddrC = make(chan ndpAutoGenAddrEvent, 2) + clock.Advance(ipv6.MinPrefixInformationValidLifetimeForUpdate) select { case e := <-ndpDisp.autoGenAddrC: if diff := checkAutoGenAddrEvent(e, addr2, deprecatedAddr); diff == "" { @@ -3302,21 +3191,21 @@ func TestAutoGenAddrJobDeprecation(t *testing.T) { if diff := checkAutoGenAddrEvent(e, addr2, invalidatedAddr); diff != "" { t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) } - case <-time.After(defaultAsyncPositiveEventTimeout): + default: t.Fatal("timed out waiting for addr auto gen event") } } else if diff := checkAutoGenAddrEvent(e, addr2, invalidatedAddr); diff == "" { - // If we get an invalidation event first, we should not get a deprecation + // If we get an invalidation event first, we should not get a deprecation // event after. select { case <-ndpDisp.autoGenAddrC: t.Fatal("unexpectedly got an auto-generated event") - case <-time.After(defaultAsyncNegativeEventTimeout): + default: } } else { t.Fatalf("got unexpected auto-generated event") } - case <-time.After(newMinVLDuration + defaultAsyncPositiveEventTimeout): + default: t.Fatal("timed out waiting for addr auto gen event") } if containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) { @@ -3353,15 +3242,6 @@ func TestAutoGenAddrJobDeprecation(t *testing.T) { // infinite values. func TestAutoGenAddrFiniteToInfiniteToFiniteVL(t *testing.T) { const infiniteVLSeconds = 2 - const minVLSeconds = 1 - savedIL := header.NDPInfiniteLifetime - savedMinVL := ipv6.MinPrefixInformationValidLifetimeForUpdate - defer func() { - ipv6.MinPrefixInformationValidLifetimeForUpdate = savedMinVL - header.NDPInfiniteLifetime = savedIL - }() - ipv6.MinPrefixInformationValidLifetimeForUpdate = minVLSeconds * time.Second - header.NDPInfiniteLifetime = infiniteVLSeconds * time.Second prefix, _, addr := prefixSubnetAddr(0, linkAddr1) @@ -3385,68 +3265,58 @@ func TestAutoGenAddrFiniteToInfiniteToFiniteVL(t *testing.T) { }, } - // This Run will not return until the parallel tests finish. - // - // We need this because we need to do some teardown work after the - // parallel tests complete. - // - // See https://godoc.org/testing#hdr-Subtests_and_Sub_benchmarks for - // more details. - t.Run("group", func(t *testing.T) { - for _, test := range tests { - test := test - - t.Run(test.name, func(t *testing.T) { - t.Parallel() - - ndpDisp := ndpDispatcher{ - autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1), - } - e := channel.New(0, 1280, linkAddr1) - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ - NDPConfigs: ipv6.NDPConfigurations{ - HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled, - AutoGenGlobalAddresses: true, - }, - NDPDisp: &ndpDisp, - })}, - }) - - if err := s.CreateNIC(1, e); err != nil { - t.Fatalf("CreateNIC(1) = %s", err) - } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + ndpDisp := ndpDispatcher{ + autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1), + } + e := channel.New(0, 1280, linkAddr1) + clock := faketime.NewManualClock() + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ + NDPConfigs: ipv6.NDPConfigurations{ + HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled, + AutoGenGlobalAddresses: true, + }, + NDPDisp: &ndpDisp, + })}, + Clock: clock, + }) - // Receive an RA with finite prefix. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, minVLSeconds, 0)) - select { - case e := <-ndpDisp.autoGenAddrC: - if diff := checkAutoGenAddrEvent(e, addr, newAddr); diff != "" { - t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) - } + if err := s.CreateNIC(1, e); err != nil { + t.Fatalf("CreateNIC(1) = %s", err) + } - default: - t.Fatal("expected addr auto gen event") + // Receive an RA with finite prefix. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, minVLSeconds, 0)) + select { + case e := <-ndpDisp.autoGenAddrC: + if diff := checkAutoGenAddrEvent(e, addr, newAddr); diff != "" { + t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) } - // Receive an new RA with prefix with infinite VL. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, test.infiniteVL, 0)) + default: + t.Fatal("expected addr auto gen event") + } - // Receive a new RA with prefix with finite VL. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, minVLSeconds, 0)) + // Receive an new RA with prefix with infinite VL. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, test.infiniteVL, 0)) - select { - case e := <-ndpDisp.autoGenAddrC: - if diff := checkAutoGenAddrEvent(e, addr, invalidatedAddr); diff != "" { - t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) - } + // Receive a new RA with prefix with finite VL. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, minVLSeconds, 0)) - case <-time.After(minVLSeconds*time.Second + defaultAsyncPositiveEventTimeout): - t.Fatal("timeout waiting for addr auto gen event") + clock.Advance(ipv6.MinPrefixInformationValidLifetimeForUpdate) + select { + case e := <-ndpDisp.autoGenAddrC: + if diff := checkAutoGenAddrEvent(e, addr, invalidatedAddr); diff != "" { + t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) } - }) - } - }) + + default: + t.Fatal("timeout waiting for addr auto gen event") + } + }) + } } // TestAutoGenAddrValidLifetimeUpdates tests that the valid lifetime of an @@ -3454,12 +3324,6 @@ func TestAutoGenAddrFiniteToInfiniteToFiniteVL(t *testing.T) { // RFC 4862 section 5.5.3.e. func TestAutoGenAddrValidLifetimeUpdates(t *testing.T) { const infiniteVL = 4294967295 - const newMinVL = 4 - saved := ipv6.MinPrefixInformationValidLifetimeForUpdate - defer func() { - ipv6.MinPrefixInformationValidLifetimeForUpdate = saved - }() - ipv6.MinPrefixInformationValidLifetimeForUpdate = newMinVL * time.Second prefix, _, addr := prefixSubnetAddr(0, linkAddr1) @@ -3470,137 +3334,129 @@ func TestAutoGenAddrValidLifetimeUpdates(t *testing.T) { evl uint32 }{ // Should update the VL to the minimum VL for updating if the - // new VL is less than newMinVL but was originally greater than + // new VL is less than minVLSeconds but was originally greater than // it. { "LargeVLToVLLessThanMinVLForUpdate", 9999, 1, - newMinVL, + minVLSeconds, }, { "LargeVLTo0", 9999, 0, - newMinVL, + minVLSeconds, }, { "InfiniteVLToVLLessThanMinVLForUpdate", infiniteVL, 1, - newMinVL, + minVLSeconds, }, { "InfiniteVLTo0", infiniteVL, 0, - newMinVL, + minVLSeconds, }, - // Should not update VL if original VL was less than newMinVL - // and the new VL is also less than newMinVL. + // Should not update VL if original VL was less than minVLSeconds + // and the new VL is also less than minVLSeconds. { "ShouldNotUpdateWhenBothOldAndNewAreLessThanMinVLForUpdate", - newMinVL - 1, - newMinVL - 3, - newMinVL - 1, + minVLSeconds - 1, + minVLSeconds - 3, + minVLSeconds - 1, }, // Should take the new VL if the new VL is greater than the - // remaining time or is greater than newMinVL. + // remaining time or is greater than minVLSeconds. { "MorethanMinVLToLesserButStillMoreThanMinVLForUpdate", - newMinVL + 5, - newMinVL + 3, - newMinVL + 3, + minVLSeconds + 5, + minVLSeconds + 3, + minVLSeconds + 3, }, { "SmallVLToGreaterVLButStillLessThanMinVLForUpdate", - newMinVL - 3, - newMinVL - 1, - newMinVL - 1, + minVLSeconds - 3, + minVLSeconds - 1, + minVLSeconds - 1, }, { "SmallVLToGreaterVLThatIsMoreThaMinVLForUpdate", - newMinVL - 3, - newMinVL + 1, - newMinVL + 1, + minVLSeconds - 3, + minVLSeconds + 1, + minVLSeconds + 1, }, } - // This Run will not return until the parallel tests finish. - // - // We need this because we need to do some teardown work after the - // parallel tests complete. - // - // See https://godoc.org/testing#hdr-Subtests_and_Sub_benchmarks for - // more details. - t.Run("group", func(t *testing.T) { - for _, test := range tests { - test := test - - t.Run(test.name, func(t *testing.T) { - t.Parallel() - - ndpDisp := ndpDispatcher{ - autoGenAddrC: make(chan ndpAutoGenAddrEvent, 10), - } - e := channel.New(10, 1280, linkAddr1) - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ - NDPConfigs: ipv6.NDPConfigurations{ - HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled, - AutoGenGlobalAddresses: true, - }, - NDPDisp: &ndpDisp, - })}, - }) + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + ndpDisp := ndpDispatcher{ + autoGenAddrC: make(chan ndpAutoGenAddrEvent, 10), + } + e := channel.New(10, 1280, linkAddr1) + clock := faketime.NewManualClock() + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ + NDPConfigs: ipv6.NDPConfigurations{ + HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled, + AutoGenGlobalAddresses: true, + }, + NDPDisp: &ndpDisp, + })}, + Clock: clock, + }) - if err := s.CreateNIC(1, e); err != nil { - t.Fatalf("CreateNIC(1) = %s", err) - } + if err := s.CreateNIC(1, e); err != nil { + t.Fatalf("CreateNIC(1) = %s", err) + } - // Receive an RA with prefix with initial VL, - // test.ovl. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, test.ovl, 0)) - select { - case e := <-ndpDisp.autoGenAddrC: - if diff := checkAutoGenAddrEvent(e, addr, newAddr); diff != "" { - t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) - } - default: - t.Fatal("expected addr auto gen event") + // Receive an RA with prefix with initial VL, + // test.ovl. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, test.ovl, 0)) + select { + case e := <-ndpDisp.autoGenAddrC: + if diff := checkAutoGenAddrEvent(e, addr, newAddr); diff != "" { + t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) } + default: + t.Fatal("expected addr auto gen event") + } - // Receive an new RA with prefix with new VL, - // test.nvl. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, test.nvl, 0)) + // Receive an new RA with prefix with new VL, + // test.nvl. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, test.nvl, 0)) - // - // Validate that the VL for the address got set - // to test.evl. - // + // + // Validate that the VL for the address got set + // to test.evl. + // - // The address should not be invalidated until the effective valid - // lifetime has passed. - select { - case <-ndpDisp.autoGenAddrC: - t.Fatal("unexpectedly received an auto gen addr event") - case <-time.After(time.Duration(test.evl)*time.Second - defaultAsyncNegativeEventTimeout): - } + // The address should not be invalidated until the effective valid + // lifetime has passed. + const delta = 1 + clock.Advance(time.Duration(test.evl)*time.Second - delta) + select { + case <-ndpDisp.autoGenAddrC: + t.Fatal("unexpectedly received an auto gen addr event") + default: + } - // Wait for the invalidation event. - select { - case e := <-ndpDisp.autoGenAddrC: - if diff := checkAutoGenAddrEvent(e, addr, invalidatedAddr); diff != "" { - t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) - } - case <-time.After(defaultAsyncPositiveEventTimeout): - t.Fatal("timeout waiting for addr auto gen event") + // Wait for the invalidation event. + clock.Advance(delta) + select { + case e := <-ndpDisp.autoGenAddrC: + if diff := checkAutoGenAddrEvent(e, addr, invalidatedAddr); diff != "" { + t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) } - }) - } - }) + default: + t.Fatal("timeout waiting for addr auto gen event") + } + }) + } } // TestAutoGenAddrRemoval tests that when auto-generated addresses are removed @@ -3613,6 +3469,7 @@ func TestAutoGenAddrRemoval(t *testing.T) { autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1), } e := channel.New(0, 1280, linkAddr1) + clock := faketime.NewManualClock() s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ NDPConfigs: ipv6.NDPConfigurations{ @@ -3621,6 +3478,7 @@ func TestAutoGenAddrRemoval(t *testing.T) { }, NDPDisp: &ndpDisp, })}, + Clock: clock, }) if err := s.CreateNIC(1, e); err != nil { @@ -3654,10 +3512,11 @@ func TestAutoGenAddrRemoval(t *testing.T) { // Wait for the original valid lifetime to make sure the original job got // cancelled/cleaned up. + clock.Advance(lifetimeSeconds * time.Second) select { case <-ndpDisp.autoGenAddrC: t.Fatal("unexpectedly received an auto gen addr event") - case <-time.After(lifetimeSeconds*time.Second + defaultAsyncNegativeEventTimeout): + default: } } @@ -3668,7 +3527,7 @@ func TestAutoGenAddrAfterRemoval(t *testing.T) { prefix1, _, addr1 := prefixSubnetAddr(0, linkAddr1) prefix2, _, addr2 := prefixSubnetAddr(1, linkAddr1) - ndpDisp, e, s := stackAndNdpDispatcherWithDefaultRoute(t, nicID) + ndpDisp, e, s, _ := stackAndNdpDispatcherWithDefaultRoute(t, nicID) expectAutoGenAddrEvent := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) { t.Helper() @@ -3779,6 +3638,7 @@ func TestAutoGenAddrStaticConflict(t *testing.T) { autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1), } e := channel.New(0, 1280, linkAddr1) + clock := faketime.NewManualClock() s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ NDPConfigs: ipv6.NDPConfigurations{ @@ -3787,6 +3647,7 @@ func TestAutoGenAddrStaticConflict(t *testing.T) { }, NDPDisp: &ndpDisp, })}, + Clock: clock, }) if err := s.CreateNIC(1, e); err != nil { @@ -3816,30 +3677,36 @@ func TestAutoGenAddrStaticConflict(t *testing.T) { // Should not get an invalidation event after the PI's invalidation // time. + clock.Advance(lifetimeSeconds * time.Second) select { case <-ndpDisp.autoGenAddrC: t.Fatal("unexpectedly received an auto gen addr event") - case <-time.After(lifetimeSeconds*time.Second + defaultAsyncNegativeEventTimeout): + default: } if !containsV6Addr(s.NICInfo()[1].ProtocolAddresses, addr) { t.Fatalf("Should have %s in the list of addresses", addr1) } } +func makeSecretKey(t *testing.T) []byte { + secretKey := make([]byte, header.OpaqueIIDSecretKeyMinBytes) + n, err := cryptorand.Read(secretKey) + if err != nil { + t.Fatalf("cryptorand.Read(_): %s", err) + } + if l := len(secretKey); n != l { + t.Fatalf("got cryptorand.Read(_) = (%d, nil), want = (%d, nil)", n, l) + } + return secretKey +} + // TestAutoGenAddrWithOpaqueIID tests that SLAAC generated addresses will use // opaque interface identifiers when configured to do so. func TestAutoGenAddrWithOpaqueIID(t *testing.T) { const nicID = 1 const nicName = "nic1" - var secretKeyBuf [header.OpaqueIIDSecretKeyMinBytes]byte - secretKey := secretKeyBuf[:] - n, err := rand.Read(secretKey) - if err != nil { - t.Fatalf("rand.Read(_): %s", err) - } - if n != header.OpaqueIIDSecretKeyMinBytes { - t.Fatalf("got rand.Read(_) = (%d, _), want = (%d, _)", n, header.OpaqueIIDSecretKeyMinBytes) - } + + secretKey := makeSecretKey(t) prefix1, subnet1, _ := prefixSubnetAddr(0, linkAddr1) prefix2, subnet2, _ := prefixSubnetAddr(1, linkAddr1) @@ -3861,6 +3728,7 @@ func TestAutoGenAddrWithOpaqueIID(t *testing.T) { autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1), } e := channel.New(0, 1280, linkAddr1) + clock := faketime.NewManualClock() s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ NDPConfigs: ipv6.NDPConfigurations{ @@ -3875,6 +3743,7 @@ func TestAutoGenAddrWithOpaqueIID(t *testing.T) { SecretKey: secretKey, }, })}, + Clock: clock, }) opts := stack.NICOptions{Name: nicName} if err := s.CreateNICWithOptions(nicID, e, opts); err != nil { @@ -3913,12 +3782,13 @@ func TestAutoGenAddrWithOpaqueIID(t *testing.T) { } // Wait for addr of prefix1 to be invalidated. + clock.Advance(validLifetimeSecondPrefix1 * time.Second) select { case e := <-ndpDisp.autoGenAddrC: if diff := checkAutoGenAddrEvent(e, addr1, invalidatedAddr); diff != "" { t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) } - case <-time.After(validLifetimeSecondPrefix1*time.Second + defaultAsyncPositiveEventTimeout): + default: t.Fatal("timed out waiting for addr auto gen event") } if containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) { @@ -3937,22 +3807,7 @@ func TestAutoGenAddrInResponseToDADConflicts(t *testing.T) { const maxMaxRetries = 3 const lifetimeSeconds = 10 - // Needed for the temporary address sub test. - savedMaxDesync := ipv6.MaxDesyncFactor - defer func() { - ipv6.MaxDesyncFactor = savedMaxDesync - }() - ipv6.MaxDesyncFactor = time.Nanosecond - - var secretKeyBuf [header.OpaqueIIDSecretKeyMinBytes]byte - secretKey := secretKeyBuf[:] - n, err := rand.Read(secretKey) - if err != nil { - t.Fatalf("rand.Read(_): %s", err) - } - if n != header.OpaqueIIDSecretKeyMinBytes { - t.Fatalf("got rand.Read(_) = (%d, _), want = (%d, _)", n, header.OpaqueIIDSecretKeyMinBytes) - } + secretKey := makeSecretKey(t) prefix, subnet, _ := prefixSubnetAddr(0, linkAddr1) @@ -3977,22 +3832,24 @@ func TestAutoGenAddrInResponseToDADConflicts(t *testing.T) { } } - expectAutoGenAddrEventAsync := func(t *testing.T, ndpDisp *ndpDispatcher, addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) { + expectAutoGenAddrEventAsync := func(t *testing.T, clock *faketime.ManualClock, ndpDisp *ndpDispatcher, addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) { t.Helper() + clock.RunImmediatelyScheduledJobs() select { case e := <-ndpDisp.autoGenAddrC: if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" { t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) } - case <-time.After(defaultAsyncPositiveEventTimeout): + default: t.Fatal("timed out waiting for addr auto gen event") } } - expectDADEvent := func(t *testing.T, ndpDisp *ndpDispatcher, addr tcpip.Address, res stack.DADResult) { + expectDADEvent := func(t *testing.T, clock *faketime.ManualClock, ndpDisp *ndpDispatcher, addr tcpip.Address, res stack.DADResult) { t.Helper() + clock.RunImmediatelyScheduledJobs() select { case e := <-ndpDisp.dadC: if diff := checkDADEvent(e, nicID, addr, res); diff != "" { @@ -4003,15 +3860,16 @@ func TestAutoGenAddrInResponseToDADConflicts(t *testing.T) { } } - expectDADEventAsync := func(t *testing.T, ndpDisp *ndpDispatcher, addr tcpip.Address, res stack.DADResult) { + expectDADEventAsync := func(t *testing.T, clock *faketime.ManualClock, ndpDisp *ndpDispatcher, addr tcpip.Address, res stack.DADResult) { t.Helper() + clock.Advance(dadTransmits * retransmitTimer) select { case e := <-ndpDisp.dadC: if diff := checkDADEvent(e, nicID, addr, res); diff != "" { t.Errorf("DAD event mismatch (-want +got):\n%s", diff) } - case <-time.After(dadTransmits*retransmitTimer + defaultAsyncPositiveEventTimeout): + default: t.Fatal("timed out waiting for DAD event") } } @@ -4022,7 +3880,7 @@ func TestAutoGenAddrInResponseToDADConflicts(t *testing.T) { name string ndpConfigs ipv6.NDPConfigurations autoGenLinkLocal bool - prepareFn func(t *testing.T, ndpDisp *ndpDispatcher, e *channel.Endpoint, tempIIDHistory []byte) []tcpip.AddressWithPrefix + prepareFn func(t *testing.T, clock *faketime.ManualClock, ndpDisp *ndpDispatcher, e *channel.Endpoint, tempIIDHistory []byte) []tcpip.AddressWithPrefix addrGenFn func(dadCounter uint8, tempIIDHistory []byte) tcpip.AddressWithPrefix }{ { @@ -4031,7 +3889,7 @@ func TestAutoGenAddrInResponseToDADConflicts(t *testing.T) { HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled, AutoGenGlobalAddresses: true, }, - prepareFn: func(_ *testing.T, _ *ndpDispatcher, e *channel.Endpoint, _ []byte) []tcpip.AddressWithPrefix { + prepareFn: func(_ *testing.T, _ *faketime.ManualClock, _ *ndpDispatcher, e *channel.Endpoint, _ []byte) []tcpip.AddressWithPrefix { // Receive an RA with prefix1 in a PI. e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, lifetimeSeconds, lifetimeSeconds)) return nil @@ -4045,7 +3903,7 @@ func TestAutoGenAddrInResponseToDADConflicts(t *testing.T) { name: "LinkLocal address", ndpConfigs: ipv6.NDPConfigurations{}, autoGenLinkLocal: true, - prepareFn: func(*testing.T, *ndpDispatcher, *channel.Endpoint, []byte) []tcpip.AddressWithPrefix { + prepareFn: func(*testing.T, *faketime.ManualClock, *ndpDispatcher, *channel.Endpoint, []byte) []tcpip.AddressWithPrefix { return nil }, addrGenFn: func(dadCounter uint8, _ []byte) tcpip.AddressWithPrefix { @@ -4059,14 +3917,14 @@ func TestAutoGenAddrInResponseToDADConflicts(t *testing.T) { AutoGenGlobalAddresses: true, AutoGenTempGlobalAddresses: true, }, - prepareFn: func(t *testing.T, ndpDisp *ndpDispatcher, e *channel.Endpoint, tempIIDHistory []byte) []tcpip.AddressWithPrefix { + prepareFn: func(t *testing.T, clock *faketime.ManualClock, ndpDisp *ndpDispatcher, e *channel.Endpoint, tempIIDHistory []byte) []tcpip.AddressWithPrefix { header.InitialTempIID(tempIIDHistory, nil, nicID) // Generate a stable SLAAC address so temporary addresses will be // generated. e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, 100, 100)) expectAutoGenAddrEvent(t, ndpDisp, stableAddrForTempAddrTest, newAddr) - expectDADEventAsync(t, ndpDisp, stableAddrForTempAddrTest.Address, &stack.DADSucceeded{}) + expectDADEventAsync(t, clock, ndpDisp, stableAddrForTempAddrTest.Address, &stack.DADSucceeded{}) // The stable address will be assigned throughout the test. return []tcpip.AddressWithPrefix{stableAddrForTempAddrTest} @@ -4078,14 +3936,6 @@ func TestAutoGenAddrInResponseToDADConflicts(t *testing.T) { } for _, addrType := range addrTypes { - // This Run will not return until the parallel tests finish. - // - // We need this because we need to do some teardown work after the parallel - // tests complete and limit the number of parallel tests running at the same - // time to reduce flakes. - // - // See https://godoc.org/testing#hdr-Subtests_and_Sub_benchmarks for - // more details. t.Run(addrType.name, func(t *testing.T) { for maxRetries := uint8(0); maxRetries <= maxMaxRetries; maxRetries++ { for numFailures := uint8(0); numFailures <= maxRetries+1; numFailures++ { @@ -4094,8 +3944,6 @@ func TestAutoGenAddrInResponseToDADConflicts(t *testing.T) { addrType := addrType t.Run(fmt.Sprintf("%d max retries and %d failures", maxRetries, numFailures), func(t *testing.T) { - t.Parallel() - ndpDisp := ndpDispatcher{ dadC: make(chan ndpDADEvent, 1), autoGenAddrC: make(chan ndpAutoGenAddrEvent, 2), @@ -4103,6 +3951,7 @@ func TestAutoGenAddrInResponseToDADConflicts(t *testing.T) { e := channel.New(0, 1280, linkAddr1) ndpConfigs := addrType.ndpConfigs ndpConfigs.AutoGenAddressConflictRetries = maxRetries + clock := faketime.NewManualClock() s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ AutoGenLinkLocal: addrType.autoGenLinkLocal, @@ -4119,6 +3968,7 @@ func TestAutoGenAddrInResponseToDADConflicts(t *testing.T) { SecretKey: secretKey, }, })}, + Clock: clock, }) opts := stack.NICOptions{Name: nicName} if err := s.CreateNICWithOptions(nicID, e, opts); err != nil { @@ -4126,12 +3976,12 @@ func TestAutoGenAddrInResponseToDADConflicts(t *testing.T) { } var tempIIDHistory [header.IIDSize]byte - stableAddrs := addrType.prepareFn(t, &ndpDisp, e, tempIIDHistory[:]) + stableAddrs := addrType.prepareFn(t, clock, &ndpDisp, e, tempIIDHistory[:]) // Simulate DAD conflicts so the address is regenerated. for i := uint8(0); i < numFailures; i++ { addr := addrType.addrGenFn(i, tempIIDHistory[:]) - expectAutoGenAddrEventAsync(t, &ndpDisp, addr, newAddr) + expectAutoGenAddrEventAsync(t, clock, &ndpDisp, addr, newAddr) // Should not have any new addresses assigned to the NIC. if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, stableAddrs, nil); mismatch != "" { @@ -4141,7 +3991,7 @@ func TestAutoGenAddrInResponseToDADConflicts(t *testing.T) { // Simulate a DAD conflict. rxNDPSolicit(e, addr.Address) expectAutoGenAddrEvent(t, &ndpDisp, addr, invalidatedAddr) - expectDADEvent(t, &ndpDisp, addr.Address, &stack.DADDupAddrDetected{}) + expectDADEvent(t, clock, &ndpDisp, addr.Address, &stack.DADDupAddrDetected{}) // Attempting to add the address manually should not fail if the // address's state was cleaned up when DAD failed. @@ -4151,7 +4001,7 @@ func TestAutoGenAddrInResponseToDADConflicts(t *testing.T) { if err := s.RemoveAddress(nicID, addr.Address); err != nil { t.Fatalf("RemoveAddress(%d, %s) = %s", nicID, addr.Address, err) } - expectDADEvent(t, &ndpDisp, addr.Address, &stack.DADAborted{}) + expectDADEvent(t, clock, &ndpDisp, addr.Address, &stack.DADAborted{}) } // Should not have any new addresses assigned to the NIC. @@ -4163,8 +4013,8 @@ func TestAutoGenAddrInResponseToDADConflicts(t *testing.T) { // an address after DAD resolves. if maxRetries+1 > numFailures { addr := addrType.addrGenFn(numFailures, tempIIDHistory[:]) - expectAutoGenAddrEventAsync(t, &ndpDisp, addr, newAddr) - expectDADEventAsync(t, &ndpDisp, addr.Address, &stack.DADSucceeded{}) + expectAutoGenAddrEventAsync(t, clock, &ndpDisp, addr, newAddr) + expectDADEventAsync(t, clock, &ndpDisp, addr.Address, &stack.DADSucceeded{}) if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, append(stableAddrs, addr), nil); mismatch != "" { t.Fatal(mismatch) } @@ -4174,7 +4024,7 @@ func TestAutoGenAddrInResponseToDADConflicts(t *testing.T) { select { case e := <-ndpDisp.autoGenAddrC: t.Fatalf("unexpectedly got an auto-generated address event = %+v", e) - case <-time.After(defaultAsyncNegativeEventTimeout): + default: } }) } @@ -4231,13 +4081,12 @@ func TestAutoGenAddrWithEUI64IIDNoDADRetries(t *testing.T) { addrType := addrType t.Run(addrType.name, func(t *testing.T) { - t.Parallel() - ndpDisp := ndpDispatcher{ dadC: make(chan ndpDADEvent, 1), autoGenAddrC: make(chan ndpAutoGenAddrEvent, 2), } e := channel.New(0, 1280, linkAddr1) + clock := faketime.NewManualClock() s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ AutoGenLinkLocal: addrType.autoGenLinkLocal, @@ -4248,6 +4097,7 @@ func TestAutoGenAddrWithEUI64IIDNoDADRetries(t *testing.T) { RetransmitTimer: retransmitTimer, }, })}, + Clock: clock, }) if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) @@ -4292,7 +4142,7 @@ func TestAutoGenAddrWithEUI64IIDNoDADRetries(t *testing.T) { select { case e := <-ndpDisp.autoGenAddrC: t.Fatalf("unexpectedly got an auto-generated address event = %+v", e) - case <-time.After(defaultAsyncNegativeEventTimeout): + default: } }) } @@ -4309,15 +4159,7 @@ func TestAutoGenAddrContinuesLifetimesAfterRetry(t *testing.T) { const maxRetries = 1 const lifetimeSeconds = 5 - var secretKeyBuf [header.OpaqueIIDSecretKeyMinBytes]byte - secretKey := secretKeyBuf[:] - n, err := rand.Read(secretKey) - if err != nil { - t.Fatalf("rand.Read(_): %s", err) - } - if n != header.OpaqueIIDSecretKeyMinBytes { - t.Fatalf("got rand.Read(_) = (%d, _), want = (%d, _)", n, header.OpaqueIIDSecretKeyMinBytes) - } + secretKey := makeSecretKey(t) prefix, subnet, _ := prefixSubnetAddr(0, linkAddr1) @@ -4326,6 +4168,7 @@ func TestAutoGenAddrContinuesLifetimesAfterRetry(t *testing.T) { autoGenAddrC: make(chan ndpAutoGenAddrEvent, 2), } e := channel.New(0, 1280, linkAddr1) + clock := faketime.NewManualClock() s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ DADConfigs: stack.DADConfigurations{ @@ -4345,6 +4188,7 @@ func TestAutoGenAddrContinuesLifetimesAfterRetry(t *testing.T) { SecretKey: secretKey, }, })}, + Clock: clock, }) opts := stack.NICOptions{Name: nicName} if err := s.CreateNICWithOptions(nicID, e, opts); err != nil { @@ -4375,7 +4219,7 @@ func TestAutoGenAddrContinuesLifetimesAfterRetry(t *testing.T) { expectAutoGenAddrEvent(addr, newAddr) // Simulate a DAD conflict after some time has passed. - time.Sleep(failureTimer) + clock.Advance(failureTimer) rxNDPSolicit(e, addr.Address) expectAutoGenAddrEvent(addr, invalidatedAddr) select { @@ -4390,12 +4234,13 @@ func TestAutoGenAddrContinuesLifetimesAfterRetry(t *testing.T) { // Let the next address resolve. addr.Address = tcpip.Address(header.AppendOpaqueInterfaceIdentifier(addrBytes[:header.IIDOffsetInIPv6Address], subnet, nicName, 1, secretKey)) expectAutoGenAddrEvent(addr, newAddr) + clock.Advance(dadTransmits * retransmitTimer) select { case e := <-ndpDisp.dadC: if diff := checkDADEvent(e, nicID, addr.Address, &stack.DADSucceeded{}); diff != "" { t.Errorf("DAD event mismatch (-want +got):\n%s", diff) } - case <-time.After(dadTransmits*retransmitTimer + defaultAsyncPositiveEventTimeout): + default: t.Fatal("timed out waiting for DAD event") } @@ -4409,6 +4254,7 @@ func TestAutoGenAddrContinuesLifetimesAfterRetry(t *testing.T) { // // We expect either just the invalidation event or the deprecation event // followed by the invalidation event. + clock.Advance(lifetimeSeconds*time.Second - failureTimer - dadTransmits*retransmitTimer) select { case e := <-ndpDisp.autoGenAddrC: if e.eventType == deprecatedAddr { @@ -4421,7 +4267,7 @@ func TestAutoGenAddrContinuesLifetimesAfterRetry(t *testing.T) { if diff := checkAutoGenAddrEvent(e, addr, invalidatedAddr); diff != "" { t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) } - case <-time.After(defaultAsyncPositiveEventTimeout): + default: t.Fatal("timed out waiting for invalidated auto gen addr event after deprecation") } } else { @@ -4429,7 +4275,7 @@ func TestAutoGenAddrContinuesLifetimesAfterRetry(t *testing.T) { t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) } } - case <-time.After(lifetimeSeconds*time.Second - failureTimer - dadTransmits*retransmitTimer + defaultAsyncPositiveEventTimeout): + default: t.Fatal("timed out waiting for auto gen addr event") } } @@ -4691,11 +4537,9 @@ func TestNoCleanupNDPStateWhenForwardingEnabled(t *testing.T) { ) ndpDisp := ndpDispatcher{ - routerC: make(chan ndpRouterEvent, 1), - rememberRouter: true, - prefixC: make(chan ndpPrefixEvent, 1), - rememberPrefix: true, - autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1), + offLinkRouteC: make(chan ndpOffLinkRouteEvent, 1), + prefixC: make(chan ndpPrefixEvent, 1), + autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1), } s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ @@ -4738,17 +4582,17 @@ func TestNoCleanupNDPStateWhenForwardingEnabled(t *testing.T) { ), ) select { - case e := <-ndpDisp.routerC: - if diff := checkRouterEvent(e, llAddr3, true /* discovered */); diff != "" { - t.Errorf("router event mismatch (-want +got):\n%s", diff) + case e := <-ndpDisp.offLinkRouteC: + if diff := checkOffLinkRouteEvent(e, nicID, llAddr3, header.MediumRoutePreference, true /* discovered */); diff != "" { + t.Errorf("off-link route event mismatch (-want +got):\n%s", diff) } default: - t.Errorf("expected router event for %s on NIC(%d)", llAddr3, nicID) + t.Errorf("expected off-link route event for %s on NIC(%d)", llAddr3, nicID) } select { case e := <-ndpDisp.prefixC: if diff := checkPrefixEvent(e, subnet, true /* discovered */); diff != "" { - t.Errorf("router event mismatch (-want +got):\n%s", diff) + t.Errorf("off-link route event mismatch (-want +got):\n%s", diff) } default: t.Errorf("expected prefix event for %s on NIC(%d)", prefix, nicID) @@ -4770,8 +4614,8 @@ func TestNoCleanupNDPStateWhenForwardingEnabled(t *testing.T) { t.Fatalf("SetForwardingDefaultAndAllNICs(%d, %t): %s", ipv6.ProtocolNumber, forwarding, err) } select { - case e := <-ndpDisp.routerC: - t.Errorf("unexpected router event = %#v", e) + case e := <-ndpDisp.offLinkRouteC: + t.Errorf("unexpected off-link route event = %#v", e) default: } select { @@ -4857,12 +4701,11 @@ func TestCleanupNDPState(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { ndpDisp := ndpDispatcher{ - routerC: make(chan ndpRouterEvent, maxRouterAndPrefixEvents), - rememberRouter: true, - prefixC: make(chan ndpPrefixEvent, maxRouterAndPrefixEvents), - rememberPrefix: true, - autoGenAddrC: make(chan ndpAutoGenAddrEvent, test.maxAutoGenAddrEvents), + offLinkRouteC: make(chan ndpOffLinkRouteEvent, maxRouterAndPrefixEvents), + prefixC: make(chan ndpPrefixEvent, maxRouterAndPrefixEvents), + autoGenAddrC: make(chan ndpAutoGenAddrEvent, test.maxAutoGenAddrEvents), } + clock := faketime.NewManualClock() s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ AutoGenLinkLocal: true, @@ -4874,16 +4717,17 @@ func TestCleanupNDPState(t *testing.T) { }, NDPDisp: &ndpDisp, })}, + Clock: clock, }) - expectRouterEvent := func() (bool, ndpRouterEvent) { + expectOffLinkRouteEvent := func() (bool, ndpOffLinkRouteEvent) { select { - case e := <-ndpDisp.routerC: + case e := <-ndpDisp.offLinkRouteC: return true, e default: } - return false, ndpRouterEvent{} + return false, ndpOffLinkRouteEvent{} } expectPrefixEvent := func() (bool, ndpPrefixEvent) { @@ -4928,8 +4772,8 @@ func TestCleanupNDPState(t *testing.T) { // multiple addresses. e1.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr3, lifetimeSeconds, prefix1, true, true, lifetimeSeconds, lifetimeSeconds)) - if ok, _ := expectRouterEvent(); !ok { - t.Errorf("expected router event for %s on NIC(%d)", llAddr3, nicID1) + if ok, _ := expectOffLinkRouteEvent(); !ok { + t.Errorf("expected off-link route event for %s on NIC(%d)", llAddr3, nicID1) } if ok, _ := expectPrefixEvent(); !ok { t.Errorf("expected prefix event for %s on NIC(%d)", prefix1, nicID1) @@ -4939,8 +4783,8 @@ func TestCleanupNDPState(t *testing.T) { } e1.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr4, lifetimeSeconds, prefix2, true, true, lifetimeSeconds, lifetimeSeconds)) - if ok, _ := expectRouterEvent(); !ok { - t.Errorf("expected router event for %s on NIC(%d)", llAddr4, nicID1) + if ok, _ := expectOffLinkRouteEvent(); !ok { + t.Errorf("expected off-link route event for %s on NIC(%d)", llAddr4, nicID1) } if ok, _ := expectPrefixEvent(); !ok { t.Errorf("expected prefix event for %s on NIC(%d)", prefix2, nicID1) @@ -4950,8 +4794,8 @@ func TestCleanupNDPState(t *testing.T) { } e2.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr3, lifetimeSeconds, prefix1, true, true, lifetimeSeconds, lifetimeSeconds)) - if ok, _ := expectRouterEvent(); !ok { - t.Errorf("expected router event for %s on NIC(%d)", llAddr3, nicID2) + if ok, _ := expectOffLinkRouteEvent(); !ok { + t.Errorf("expected off-link route event for %s on NIC(%d)", llAddr3, nicID2) } if ok, _ := expectPrefixEvent(); !ok { t.Errorf("expected prefix event for %s on NIC(%d)", prefix1, nicID2) @@ -4961,8 +4805,8 @@ func TestCleanupNDPState(t *testing.T) { } e2.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr4, lifetimeSeconds, prefix2, true, true, lifetimeSeconds, lifetimeSeconds)) - if ok, _ := expectRouterEvent(); !ok { - t.Errorf("expected router event for %s on NIC(%d)", llAddr4, nicID2) + if ok, _ := expectOffLinkRouteEvent(); !ok { + t.Errorf("expected off-link route event for %s on NIC(%d)", llAddr4, nicID2) } if ok, _ := expectPrefixEvent(); !ok { t.Errorf("expected prefix event for %s on NIC(%d)", prefix2, nicID2) @@ -5003,14 +4847,14 @@ func TestCleanupNDPState(t *testing.T) { test.cleanupFn(t, s) // Collect invalidation events after having NDP state cleaned up. - gotRouterEvents := make(map[ndpRouterEvent]int) + gotOffLinkRouteEvents := make(map[ndpOffLinkRouteEvent]int) for i := 0; i < maxRouterAndPrefixEvents; i++ { - ok, e := expectRouterEvent() + ok, e := expectOffLinkRouteEvent() if !ok { - t.Errorf("expected %d router events after becoming a router; got = %d", maxRouterAndPrefixEvents, i) + t.Errorf("expected %d off-link route events after becoming a router; got = %d", maxRouterAndPrefixEvents, i) break } - gotRouterEvents[e]++ + gotOffLinkRouteEvents[e]++ } gotPrefixEvents := make(map[ndpPrefixEvent]int) for i := 0; i < maxRouterAndPrefixEvents; i++ { @@ -5037,14 +4881,14 @@ func TestCleanupNDPState(t *testing.T) { t.FailNow() } - expectedRouterEvents := map[ndpRouterEvent]int{ - {nicID: nicID1, addr: llAddr3, discovered: false}: 1, - {nicID: nicID1, addr: llAddr4, discovered: false}: 1, - {nicID: nicID2, addr: llAddr3, discovered: false}: 1, - {nicID: nicID2, addr: llAddr4, discovered: false}: 1, + expectedOffLinkRouteEvents := map[ndpOffLinkRouteEvent]int{ + {nicID: nicID1, subnet: header.IPv6EmptySubnet, router: llAddr3, updated: false}: 1, + {nicID: nicID1, subnet: header.IPv6EmptySubnet, router: llAddr4, updated: false}: 1, + {nicID: nicID2, subnet: header.IPv6EmptySubnet, router: llAddr3, updated: false}: 1, + {nicID: nicID2, subnet: header.IPv6EmptySubnet, router: llAddr4, updated: false}: 1, } - if diff := cmp.Diff(expectedRouterEvents, gotRouterEvents); diff != "" { - t.Errorf("router events mismatch (-want +got):\n%s", diff) + if diff := cmp.Diff(expectedOffLinkRouteEvents, gotOffLinkRouteEvents); diff != "" { + t.Errorf("off-link route events mismatch (-want +got):\n%s", diff) } expectedPrefixEvents := map[ndpPrefixEvent]int{ {nicID: nicID1, prefix: subnet1, discovered: false}: 1, @@ -5106,10 +4950,10 @@ func TestCleanupNDPState(t *testing.T) { // Should not get any more events (invalidation timers should have been // cancelled when the NDP state was cleaned up). - time.Sleep(lifetimeSeconds*time.Second + defaultAsyncNegativeEventTimeout) + clock.Advance(lifetimeSeconds * time.Second) select { - case <-ndpDisp.routerC: - t.Error("unexpected router event") + case <-ndpDisp.offLinkRouteC: + t.Error("unexpected off-link route event") default: } select { @@ -5134,7 +4978,6 @@ func TestDHCPv6ConfigurationFromNDPDA(t *testing.T) { ndpDisp := ndpDispatcher{ dhcpv6ConfigurationC: make(chan ndpDHCPv6Event, 1), - rememberRouter: true, } e := channel.New(0, 1280, linkAddr1) s := stack.New(stack.Options{ @@ -5236,6 +5079,23 @@ func TestDHCPv6ConfigurationFromNDPDA(t *testing.T) { expectNoDHCPv6Event() } +var _ rand.Source = (*savingRandSource)(nil) + +type savingRandSource struct { + s rand.Source + + lastInt63 int64 +} + +func (d *savingRandSource) Int63() int64 { + i := d.s.Int63() + d.lastInt63 = i + return i +} +func (d *savingRandSource) Seed(seed int64) { + d.s.Seed(seed) +} + // TestRouterSolicitation tests the initial Router Solicitations that are sent // when a NIC newly becomes enabled. func TestRouterSolicitation(t *testing.T) { @@ -5402,6 +5262,9 @@ func TestRouterSolicitation(t *testing.T) { t.Fatalf("unexpectedly got a packet = %#v", p) } } + randSource := savingRandSource{ + s: rand.NewSource(time.Now().UnixNano()), + } s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ NDPConfigs: ipv6.NDPConfigurations{ @@ -5411,8 +5274,10 @@ func TestRouterSolicitation(t *testing.T) { MaxRtrSolicitationDelay: test.maxRtrSolicitDelay, }, })}, - Clock: clock, + Clock: clock, + RandSource: &randSource, }) + if err := s.CreateNIC(nicID, &e); err != nil { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) } @@ -5425,19 +5290,27 @@ func TestRouterSolicitation(t *testing.T) { // Make sure each RS is sent at the right time. remaining := test.maxRtrSolicit - if remaining > 0 { - waitForPkt(test.effectiveMaxRtrSolicitDelay) + if remaining != 0 { + maxRtrSolicitDelay := test.maxRtrSolicitDelay + if maxRtrSolicitDelay < 0 { + maxRtrSolicitDelay = ipv6.DefaultNDPConfigurations().MaxRtrSolicitationDelay + } + var actualRtrSolicitDelay time.Duration + if maxRtrSolicitDelay != 0 { + actualRtrSolicitDelay = time.Duration(randSource.lastInt63) % maxRtrSolicitDelay + } + waitForPkt(actualRtrSolicitDelay) remaining-- } subTest.afterFirstRS(t, s) - for ; remaining > 0; remaining-- { - if test.effectiveRtrSolicitInt > defaultAsyncPositiveEventTimeout { + for ; remaining != 0; remaining-- { + if test.effectiveRtrSolicitInt != 0 { waitForNothing(test.effectiveRtrSolicitInt - time.Nanosecond) waitForPkt(time.Nanosecond) } else { - waitForPkt(test.effectiveRtrSolicitInt) + waitForPkt(0) } } @@ -5533,12 +5406,11 @@ func TestStopStartSolicitingRouters(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { e := channel.New(maxRtrSolicitations, 1280, linkAddr1) - waitForPkt := func(timeout time.Duration) { + waitForPkt := func(clock *faketime.ManualClock, timeout time.Duration) { t.Helper() - ctx, cancel := context.WithTimeout(context.Background(), timeout) - defer cancel() - p, ok := e.ReadContext(ctx) + clock.Advance(timeout) + p, ok := e.Read() if !ok { t.Fatal("timed out waiting for packet") } @@ -5552,6 +5424,7 @@ func TestStopStartSolicitingRouters(t *testing.T) { checker.TTL(header.NDPHopLimit), checker.NDPRS()) } + clock := faketime.NewManualClock() s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ NDPConfigs: ipv6.NDPConfigurations{ @@ -5561,6 +5434,7 @@ func TestStopStartSolicitingRouters(t *testing.T) { MaxRtrSolicitationDelay: delay, }, })}, + Clock: clock, }) if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) @@ -5568,13 +5442,11 @@ func TestStopStartSolicitingRouters(t *testing.T) { // Stop soliciting routers. test.stopFn(t, s, true /* first */) - ctx, cancel := context.WithTimeout(context.Background(), delay+defaultAsyncNegativeEventTimeout) - defer cancel() - if _, ok := e.ReadContext(ctx); ok { + clock.Advance(delay) + if _, ok := e.Read(); ok { // A single RS may have been sent before solicitations were stopped. - ctx, cancel := context.WithTimeout(context.Background(), interval+defaultAsyncNegativeEventTimeout) - defer cancel() - if _, ok = e.ReadContext(ctx); ok { + clock.Advance(interval) + if _, ok = e.Read(); ok { t.Fatal("should not have sent more than one RS message") } } @@ -5582,9 +5454,8 @@ func TestStopStartSolicitingRouters(t *testing.T) { // Stopping router solicitations after it has already been stopped should // do nothing. test.stopFn(t, s, false /* first */) - ctx, cancel = context.WithTimeout(context.Background(), delay+defaultAsyncNegativeEventTimeout) - defer cancel() - if _, ok := e.ReadContext(ctx); ok { + clock.Advance(delay) + if _, ok := e.Read(); ok { t.Fatal("unexpectedly got a packet after router solicitation has been stopepd") } @@ -5595,21 +5466,19 @@ func TestStopStartSolicitingRouters(t *testing.T) { // Start soliciting routers. test.startFn(t, s) - waitForPkt(delay + defaultAsyncPositiveEventTimeout) - waitForPkt(interval + defaultAsyncPositiveEventTimeout) - waitForPkt(interval + defaultAsyncPositiveEventTimeout) - ctx, cancel = context.WithTimeout(context.Background(), interval+defaultAsyncNegativeEventTimeout) - defer cancel() - if _, ok := e.ReadContext(ctx); ok { + waitForPkt(clock, delay) + waitForPkt(clock, interval) + waitForPkt(clock, interval) + clock.Advance(interval) + if _, ok := e.Read(); ok { t.Fatal("unexpectedly got an extra packet after sending out the expected RSs") } // Starting router solicitations after it has already completed should do // nothing. test.startFn(t, s) - ctx, cancel = context.WithTimeout(context.Background(), delay+defaultAsyncNegativeEventTimeout) - defer cancel() - if _, ok := e.ReadContext(ctx); ok { + clock.Advance(interval) + if _, ok := e.Read(); ok { t.Fatal("unexpectedly got a packet after finishing router solicitations") } }) diff --git a/pkg/tcpip/stack/neighbor_cache.go b/pkg/tcpip/stack/neighbor_cache.go index 509f5ce5c..08857e1a9 100644 --- a/pkg/tcpip/stack/neighbor_cache.go +++ b/pkg/tcpip/stack/neighbor_cache.go @@ -310,7 +310,7 @@ func (n *neighborCache) handleUpperLevelConfirmation(addr tcpip.Address) { func (n *neighborCache) init(nic *nic, r LinkAddressResolver) { *n = neighborCache{ nic: nic, - state: NewNUDState(nic.stack.nudConfigs, nic.stack.randomGenerator), + state: NewNUDState(nic.stack.nudConfigs, nic.stack.clock, nic.stack.randomGenerator), linkRes: r, } n.mu.Lock() diff --git a/pkg/tcpip/stack/neighbor_cache_test.go b/pkg/tcpip/stack/neighbor_cache_test.go index 9821a18d3..7de25fe37 100644 --- a/pkg/tcpip/stack/neighbor_cache_test.go +++ b/pkg/tcpip/stack/neighbor_cache_test.go @@ -15,8 +15,6 @@ package stack import ( - "bytes" - "encoding/binary" "fmt" "math" "math/rand" @@ -48,9 +46,6 @@ const ( // be sent to all nodes. testEntryBroadcastAddr = tcpip.Address("broadcast") - // testEntryLocalAddr is the source address of neighbor probes. - testEntryLocalAddr = tcpip.Address("local_addr") - // testEntryBroadcastLinkAddr is a special link address sent back to // multicast neighbor probes. testEntryBroadcastLinkAddr = tcpip.LinkAddress("mac_broadcast") @@ -95,7 +90,7 @@ func newTestNeighborResolver(nudDisp NUDDispatcher, config NUDConfigurations, cl randomGenerator: rng, }, id: 1, - stats: makeNICStats(), + stats: makeNICStats(tcpip.NICStats{}.FillIn()), }, linkRes) return linkRes } @@ -106,20 +101,24 @@ type testEntryStore struct { entriesMap map[tcpip.Address]NeighborEntry } -func toAddress(i int) tcpip.Address { - buf := new(bytes.Buffer) - binary.Write(buf, binary.BigEndian, uint8(1)) - binary.Write(buf, binary.BigEndian, uint8(0)) - binary.Write(buf, binary.BigEndian, uint16(i)) - return tcpip.Address(buf.String()) +func toAddress(i uint16) tcpip.Address { + return tcpip.Address([]byte{ + 1, + 0, + byte(i >> 8), + byte(i), + }) } -func toLinkAddress(i int) tcpip.LinkAddress { - buf := new(bytes.Buffer) - binary.Write(buf, binary.BigEndian, uint8(1)) - binary.Write(buf, binary.BigEndian, uint8(0)) - binary.Write(buf, binary.BigEndian, uint32(i)) - return tcpip.LinkAddress(buf.String()) +func toLinkAddress(i uint16) tcpip.LinkAddress { + return tcpip.LinkAddress([]byte{ + 1, + 0, + 0, + 0, + byte(i >> 8), + byte(i), + }) } // newTestEntryStore returns a testEntryStore pre-populated with entries. @@ -127,7 +126,7 @@ func newTestEntryStore() *testEntryStore { store := &testEntryStore{ entriesMap: make(map[tcpip.Address]NeighborEntry), } - for i := 0; i < entryStoreSize; i++ { + for i := uint16(0); i < entryStoreSize; i++ { addr := toAddress(i) linkAddr := toLinkAddress(i) @@ -140,15 +139,15 @@ func newTestEntryStore() *testEntryStore { } // size returns the number of entries in the store. -func (s *testEntryStore) size() int { +func (s *testEntryStore) size() uint16 { s.mu.RLock() defer s.mu.RUnlock() - return len(s.entriesMap) + return uint16(len(s.entriesMap)) } // entry returns the entry at index i. Returns an empty entry and false if i is // out of bounds. -func (s *testEntryStore) entry(i int) (NeighborEntry, bool) { +func (s *testEntryStore) entry(i uint16) (NeighborEntry, bool) { return s.entryByAddr(toAddress(i)) } @@ -166,7 +165,7 @@ func (s *testEntryStore) entries() []NeighborEntry { entries := make([]NeighborEntry, 0, len(s.entriesMap)) s.mu.RLock() defer s.mu.RUnlock() - for i := 0; i < entryStoreSize; i++ { + for i := uint16(0); i < entryStoreSize; i++ { addr := toAddress(i) if entry, ok := s.entriesMap[addr]; ok { entries = append(entries, entry) @@ -176,7 +175,7 @@ func (s *testEntryStore) entries() []NeighborEntry { } // set modifies the link addresses of an entry. -func (s *testEntryStore) set(i int, linkAddr tcpip.LinkAddress) { +func (s *testEntryStore) set(i uint16, linkAddr tcpip.LinkAddress) { addr := toAddress(i) s.mu.Lock() defer s.mu.Unlock() @@ -236,13 +235,6 @@ func (*testNeighborResolver) LinkAddressProtocol() tcpip.NetworkProtocolNumber { return 0 } -type entryEvent struct { - nicID tcpip.NICID - address tcpip.Address - linkAddr tcpip.LinkAddress - state NeighborState -} - func TestNeighborCacheGetConfig(t *testing.T) { nudDisp := testNUDDispatcher{} c := DefaultNUDConfigurations() @@ -301,10 +293,10 @@ func addReachableEntryWithRemoved(nudDisp *testNUDDispatcher, clock *faketime.Ma EventType: entryTestRemoved, NICID: 1, Entry: NeighborEntry{ - Addr: removedEntry.Addr, - LinkAddr: removedEntry.LinkAddr, - State: Reachable, - UpdatedAtNanos: clock.NowNanoseconds(), + Addr: removedEntry.Addr, + LinkAddr: removedEntry.LinkAddr, + State: Reachable, + UpdatedAt: clock.Now(), }, }) } @@ -313,10 +305,10 @@ func addReachableEntryWithRemoved(nudDisp *testNUDDispatcher, clock *faketime.Ma EventType: entryTestAdded, NICID: 1, Entry: NeighborEntry{ - Addr: entry.Addr, - LinkAddr: "", - State: Incomplete, - UpdatedAtNanos: clock.NowNanoseconds(), + Addr: entry.Addr, + LinkAddr: "", + State: Incomplete, + UpdatedAt: clock.Now(), }, }) @@ -347,10 +339,10 @@ func addReachableEntryWithRemoved(nudDisp *testNUDDispatcher, clock *faketime.Ma EventType: entryTestChanged, NICID: 1, Entry: NeighborEntry{ - Addr: entry.Addr, - LinkAddr: entry.LinkAddr, - State: Reachable, - UpdatedAtNanos: clock.NowNanoseconds(), + Addr: entry.Addr, + LinkAddr: entry.LinkAddr, + State: Reachable, + UpdatedAt: clock.Now(), }, }, } @@ -419,10 +411,10 @@ func TestNeighborCacheRemoveEntry(t *testing.T) { EventType: entryTestRemoved, NICID: 1, Entry: NeighborEntry{ - Addr: entry.Addr, - LinkAddr: entry.LinkAddr, - State: Reachable, - UpdatedAtNanos: clock.NowNanoseconds(), + Addr: entry.Addr, + LinkAddr: entry.LinkAddr, + State: Reachable, + UpdatedAt: clock.Now(), }, }, } @@ -461,7 +453,7 @@ func newTestContext(c NUDConfigurations) testContext { } type overflowOptions struct { - startAtEntryIndex int + startAtEntryIndex uint16 wantStaticEntries []NeighborEntry } @@ -500,12 +492,12 @@ func (c *testContext) overflowCache(opts overflowOptions) error { if !ok { return fmt.Errorf("got c.linkRes.entries.entry(%d) = _, false, want = true", i) } - durationReachableNanos := int64(c.linkRes.entries.size()-i-1) * typicalLatency.Nanoseconds() + durationReachableNanos := time.Duration(c.linkRes.entries.size()-i-1) * typicalLatency wantEntry := NeighborEntry{ - Addr: entry.Addr, - LinkAddr: entry.LinkAddr, - State: Reachable, - UpdatedAtNanos: c.clock.NowNanoseconds() - durationReachableNanos, + Addr: entry.Addr, + LinkAddr: entry.LinkAddr, + State: Reachable, + UpdatedAt: c.clock.Now().Add(-durationReachableNanos), } wantUnorderedEntries = append(wantUnorderedEntries, wantEntry) } @@ -571,10 +563,10 @@ func TestNeighborCacheRemoveEntryThenOverflow(t *testing.T) { EventType: entryTestRemoved, NICID: 1, Entry: NeighborEntry{ - Addr: entry.Addr, - LinkAddr: entry.LinkAddr, - State: Reachable, - UpdatedAtNanos: c.clock.NowNanoseconds(), + Addr: entry.Addr, + LinkAddr: entry.LinkAddr, + State: Reachable, + UpdatedAt: c.clock.Now(), }, }, } @@ -616,10 +608,10 @@ func TestNeighborCacheDuplicateStaticEntryWithSameLinkAddress(t *testing.T) { EventType: entryTestAdded, NICID: 1, Entry: NeighborEntry{ - Addr: entry.Addr, - LinkAddr: staticLinkAddr, - State: Static, - UpdatedAtNanos: c.clock.NowNanoseconds(), + Addr: entry.Addr, + LinkAddr: staticLinkAddr, + State: Static, + UpdatedAt: c.clock.Now(), }, }, } @@ -663,10 +655,10 @@ func TestNeighborCacheDuplicateStaticEntryWithDifferentLinkAddress(t *testing.T) EventType: entryTestAdded, NICID: 1, Entry: NeighborEntry{ - Addr: entry.Addr, - LinkAddr: staticLinkAddr, - State: Static, - UpdatedAtNanos: c.clock.NowNanoseconds(), + Addr: entry.Addr, + LinkAddr: staticLinkAddr, + State: Static, + UpdatedAt: c.clock.Now(), }, }, } @@ -689,10 +681,10 @@ func TestNeighborCacheDuplicateStaticEntryWithDifferentLinkAddress(t *testing.T) EventType: entryTestChanged, NICID: 1, Entry: NeighborEntry{ - Addr: entry.Addr, - LinkAddr: staticLinkAddr, - State: Static, - UpdatedAtNanos: c.clock.NowNanoseconds(), + Addr: entry.Addr, + LinkAddr: staticLinkAddr, + State: Static, + UpdatedAt: c.clock.Now(), }, }, } @@ -733,10 +725,10 @@ func TestNeighborCacheRemoveStaticEntryThenOverflow(t *testing.T) { EventType: entryTestAdded, NICID: 1, Entry: NeighborEntry{ - Addr: entry.Addr, - LinkAddr: staticLinkAddr, - State: Static, - UpdatedAtNanos: c.clock.NowNanoseconds(), + Addr: entry.Addr, + LinkAddr: staticLinkAddr, + State: Static, + UpdatedAt: c.clock.Now(), }, }, } @@ -758,10 +750,10 @@ func TestNeighborCacheRemoveStaticEntryThenOverflow(t *testing.T) { EventType: entryTestRemoved, NICID: 1, Entry: NeighborEntry{ - Addr: entry.Addr, - LinkAddr: staticLinkAddr, - State: Static, - UpdatedAtNanos: c.clock.NowNanoseconds(), + Addr: entry.Addr, + LinkAddr: staticLinkAddr, + State: Static, + UpdatedAt: c.clock.Now(), }, }, } @@ -814,20 +806,20 @@ func TestNeighborCacheOverwriteWithStaticEntryThenOverflow(t *testing.T) { EventType: entryTestRemoved, NICID: 1, Entry: NeighborEntry{ - Addr: entry.Addr, - LinkAddr: entry.LinkAddr, - State: Reachable, - UpdatedAtNanos: c.clock.NowNanoseconds(), + Addr: entry.Addr, + LinkAddr: entry.LinkAddr, + State: Reachable, + UpdatedAt: c.clock.Now(), }, }, { EventType: entryTestAdded, NICID: 1, Entry: NeighborEntry{ - Addr: entry.Addr, - LinkAddr: staticLinkAddr, - State: Static, - UpdatedAtNanos: c.clock.NowNanoseconds(), + Addr: entry.Addr, + LinkAddr: staticLinkAddr, + State: Static, + UpdatedAt: c.clock.Now(), }, }, } @@ -844,10 +836,10 @@ func TestNeighborCacheOverwriteWithStaticEntryThenOverflow(t *testing.T) { startAtEntryIndex: 1, wantStaticEntries: []NeighborEntry{ { - Addr: entry.Addr, - LinkAddr: staticLinkAddr, - State: Static, - UpdatedAtNanos: c.clock.NowNanoseconds(), + Addr: entry.Addr, + LinkAddr: staticLinkAddr, + State: Static, + UpdatedAt: c.clock.Now(), }, }, } @@ -875,10 +867,10 @@ func TestNeighborCacheAddStaticEntryThenOverflow(t *testing.T) { t.Errorf("unexpected error from c.linkRes.neigh.entry(%s, \"\", nil): %s", entry.Addr, err) } want := NeighborEntry{ - Addr: entry.Addr, - LinkAddr: entry.LinkAddr, - State: Static, - UpdatedAtNanos: c.clock.NowNanoseconds(), + Addr: entry.Addr, + LinkAddr: entry.LinkAddr, + State: Static, + UpdatedAt: c.clock.Now(), } if diff := cmp.Diff(want, e); diff != "" { t.Errorf("c.linkRes.neigh.entry(%s, \"\", nil) mismatch (-want, +got):\n%s", entry.Addr, diff) @@ -890,10 +882,10 @@ func TestNeighborCacheAddStaticEntryThenOverflow(t *testing.T) { EventType: entryTestAdded, NICID: 1, Entry: NeighborEntry{ - Addr: entry.Addr, - LinkAddr: entry.LinkAddr, - State: Static, - UpdatedAtNanos: c.clock.NowNanoseconds(), + Addr: entry.Addr, + LinkAddr: entry.LinkAddr, + State: Static, + UpdatedAt: c.clock.Now(), }, }, } @@ -910,10 +902,10 @@ func TestNeighborCacheAddStaticEntryThenOverflow(t *testing.T) { startAtEntryIndex: 1, wantStaticEntries: []NeighborEntry{ { - Addr: entry.Addr, - LinkAddr: entry.LinkAddr, - State: Static, - UpdatedAtNanos: c.clock.NowNanoseconds(), + Addr: entry.Addr, + LinkAddr: entry.LinkAddr, + State: Static, + UpdatedAt: c.clock.Now(), }, }, } @@ -947,10 +939,10 @@ func TestNeighborCacheClear(t *testing.T) { EventType: entryTestAdded, NICID: 1, Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Static, - UpdatedAtNanos: clock.NowNanoseconds(), + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Static, + UpdatedAt: clock.Now(), }, }, } @@ -973,20 +965,20 @@ func TestNeighborCacheClear(t *testing.T) { EventType: entryTestRemoved, NICID: 1, Entry: NeighborEntry{ - Addr: entry.Addr, - LinkAddr: entry.LinkAddr, - State: Reachable, - UpdatedAtNanos: clock.NowNanoseconds(), + Addr: entry.Addr, + LinkAddr: entry.LinkAddr, + State: Reachable, + UpdatedAt: clock.Now(), }, }, { EventType: entryTestRemoved, NICID: 1, Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Static, - UpdatedAtNanos: clock.NowNanoseconds(), + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Static, + UpdatedAt: clock.Now(), }, }, } @@ -1027,10 +1019,10 @@ func TestNeighborCacheClearThenOverflow(t *testing.T) { EventType: entryTestRemoved, NICID: 1, Entry: NeighborEntry{ - Addr: entry.Addr, - LinkAddr: entry.LinkAddr, - State: Reachable, - UpdatedAtNanos: c.clock.NowNanoseconds(), + Addr: entry.Addr, + LinkAddr: entry.LinkAddr, + State: Reachable, + UpdatedAt: c.clock.Now(), }, }, } @@ -1062,13 +1054,13 @@ func TestNeighborCacheKeepFrequentlyUsed(t *testing.T) { clock := faketime.NewManualClock() linkRes := newTestNeighborResolver(&nudDisp, config, clock) - startedAt := clock.NowNanoseconds() + startedAt := clock.Now() // The following logic is very similar to overflowCache, but // periodically refreshes the frequently used entry. // Fill the neighbor cache to capacity - for i := 0; i < neighborCacheSize; i++ { + for i := uint16(0); i < neighborCacheSize; i++ { entry, ok := linkRes.entries.entry(i) if !ok { t.Fatalf("got linkRes.entries.entry(%d) = _, false, want = true", i) @@ -1084,7 +1076,7 @@ func TestNeighborCacheKeepFrequentlyUsed(t *testing.T) { } // Keep adding more entries - for i := neighborCacheSize; i < linkRes.entries.size(); i++ { + for i := uint16(neighborCacheSize); i < linkRes.entries.size(); i++ { // Periodically refresh the frequently used entry if i%(neighborCacheSize/2) == 0 { if _, _, err := linkRes.neigh.entry(frequentlyUsedEntry.Addr, "", nil); err != nil { @@ -1118,7 +1110,7 @@ func TestNeighborCacheKeepFrequentlyUsed(t *testing.T) { State: Reachable, // Can be inferred since the frequently used entry is the first to // be created and transitioned to Reachable. - UpdatedAtNanos: startedAt + typicalLatency.Nanoseconds(), + UpdatedAt: startedAt.Add(typicalLatency), }, } @@ -1127,12 +1119,12 @@ func TestNeighborCacheKeepFrequentlyUsed(t *testing.T) { if !ok { t.Fatalf("got linkRes.entries.entry(%d) = _, false, want = true", i) } - durationReachableNanos := int64(linkRes.entries.size()-i-1) * typicalLatency.Nanoseconds() + durationReachableNanos := time.Duration(linkRes.entries.size()-i-1) * typicalLatency wantUnsortedEntries = append(wantUnsortedEntries, NeighborEntry{ - Addr: entry.Addr, - LinkAddr: entry.LinkAddr, - State: Reachable, - UpdatedAtNanos: clock.NowNanoseconds() - durationReachableNanos, + Addr: entry.Addr, + LinkAddr: entry.LinkAddr, + State: Reachable, + UpdatedAt: clock.Now().Add(-durationReachableNanos), }) } @@ -1190,12 +1182,12 @@ func TestNeighborCacheConcurrent(t *testing.T) { if !ok { t.Errorf("got linkRes.entries.entry(%d) = _, false, want = true", i) } - durationReachableNanos := int64(linkRes.entries.size()-i-1) * typicalLatency.Nanoseconds() + durationReachableNanos := time.Duration(linkRes.entries.size()-i-1) * typicalLatency wantUnsortedEntries = append(wantUnsortedEntries, NeighborEntry{ - Addr: entry.Addr, - LinkAddr: entry.LinkAddr, - State: Reachable, - UpdatedAtNanos: clock.NowNanoseconds() - durationReachableNanos, + Addr: entry.Addr, + LinkAddr: entry.LinkAddr, + State: Reachable, + UpdatedAt: clock.Now().Add(-durationReachableNanos), }) } @@ -1244,10 +1236,10 @@ func TestNeighborCacheReplace(t *testing.T) { t.Fatalf("linkRes.neigh.entry(%s, '', nil): %s", entry.Addr, err) } want := NeighborEntry{ - Addr: entry.Addr, - LinkAddr: updatedLinkAddr, - State: Delay, - UpdatedAtNanos: clock.NowNanoseconds(), + Addr: entry.Addr, + LinkAddr: updatedLinkAddr, + State: Delay, + UpdatedAt: clock.Now(), } if diff := cmp.Diff(want, e); diff != "" { t.Errorf("linkRes.neigh.entry(%s, '', nil) mismatch (-want, +got):\n%s", entry.Addr, diff) @@ -1263,10 +1255,10 @@ func TestNeighborCacheReplace(t *testing.T) { t.Errorf("unexpected error from linkRes.neigh.entry(%s, '', nil): %s", entry.Addr, err) } want := NeighborEntry{ - Addr: entry.Addr, - LinkAddr: updatedLinkAddr, - State: Reachable, - UpdatedAtNanos: clock.NowNanoseconds(), + Addr: entry.Addr, + LinkAddr: updatedLinkAddr, + State: Reachable, + UpdatedAt: clock.Now(), } if diff := cmp.Diff(want, e); diff != "" { t.Errorf("linkRes.neigh.entry(%s, '', nil) mismatch (-want, +got):\n%s", entry.Addr, diff) @@ -1301,10 +1293,10 @@ func TestNeighborCacheResolutionFailed(t *testing.T) { t.Fatalf("unexpected error from linkRes.neigh.entry(%s, '', nil): %s", entry.Addr, err) } want := NeighborEntry{ - Addr: entry.Addr, - LinkAddr: entry.LinkAddr, - State: Reachable, - UpdatedAtNanos: clock.NowNanoseconds(), + Addr: entry.Addr, + LinkAddr: entry.LinkAddr, + State: Reachable, + UpdatedAt: clock.Now(), } if diff := cmp.Diff(want, got); diff != "" { t.Errorf("linkRes.neigh.entry(%s, '', nil) mismatch (-want, +got):\n%s", entry.Addr, diff) @@ -1405,10 +1397,10 @@ func TestNeighborCacheRetryResolution(t *testing.T) { EventType: entryTestAdded, NICID: 1, Entry: NeighborEntry{ - Addr: entry.Addr, - LinkAddr: "", - State: Incomplete, - UpdatedAtNanos: clock.NowNanoseconds(), + Addr: entry.Addr, + LinkAddr: "", + State: Incomplete, + UpdatedAt: clock.Now(), }, }, } @@ -1436,10 +1428,10 @@ func TestNeighborCacheRetryResolution(t *testing.T) { EventType: entryTestChanged, NICID: 1, Entry: NeighborEntry{ - Addr: entry.Addr, - LinkAddr: "", - State: Unreachable, - UpdatedAtNanos: clock.NowNanoseconds(), + Addr: entry.Addr, + LinkAddr: "", + State: Unreachable, + UpdatedAt: clock.Now(), }, }, } @@ -1455,10 +1447,10 @@ func TestNeighborCacheRetryResolution(t *testing.T) { { wantEntries := []NeighborEntry{ { - Addr: entry.Addr, - LinkAddr: "", - State: Unreachable, - UpdatedAtNanos: clock.NowNanoseconds(), + Addr: entry.Addr, + LinkAddr: "", + State: Unreachable, + UpdatedAt: clock.Now(), }, } if diff := cmp.Diff(linkRes.neigh.entries(), wantEntries, unorderedEntriesDiffOpts()...); diff != "" { @@ -1488,10 +1480,10 @@ func TestNeighborCacheRetryResolution(t *testing.T) { EventType: entryTestChanged, NICID: 1, Entry: NeighborEntry{ - Addr: entry.Addr, - LinkAddr: "", - State: Incomplete, - UpdatedAtNanos: clock.NowNanoseconds(), + Addr: entry.Addr, + LinkAddr: "", + State: Incomplete, + UpdatedAt: clock.Now(), }, }, } @@ -1518,10 +1510,10 @@ func TestNeighborCacheRetryResolution(t *testing.T) { EventType: entryTestChanged, NICID: 1, Entry: NeighborEntry{ - Addr: entry.Addr, - LinkAddr: entry.LinkAddr, - State: Reachable, - UpdatedAtNanos: clock.NowNanoseconds(), + Addr: entry.Addr, + LinkAddr: entry.LinkAddr, + State: Reachable, + UpdatedAt: clock.Now(), }, }, } @@ -1541,10 +1533,10 @@ func TestNeighborCacheRetryResolution(t *testing.T) { } wantEntry := NeighborEntry{ - Addr: entry.Addr, - LinkAddr: entry.LinkAddr, - State: Reachable, - UpdatedAtNanos: clock.NowNanoseconds(), + Addr: entry.Addr, + LinkAddr: entry.LinkAddr, + State: Reachable, + UpdatedAt: clock.Now(), } if diff := cmp.Diff(gotEntry, wantEntry); diff != "" { t.Fatalf("neighbor entry mismatch (-got, +want):\n%s", diff) @@ -1561,9 +1553,9 @@ func BenchmarkCacheClear(b *testing.B) { linkRes.delay = 0 // Clear for every possible size of the cache - for cacheSize := 0; cacheSize < neighborCacheSize; cacheSize++ { + for cacheSize := uint16(0); cacheSize < neighborCacheSize; cacheSize++ { // Fill the neighbor cache to capacity. - for i := 0; i < cacheSize; i++ { + for i := uint16(0); i < cacheSize; i++ { entry, ok := linkRes.entries.entry(i) if !ok { b.Fatalf("got linkRes.entries.entry(%d) = _, false, want = true", i) diff --git a/pkg/tcpip/stack/neighbor_entry.go b/pkg/tcpip/stack/neighbor_entry.go index 6d95e1664..0a59eecdd 100644 --- a/pkg/tcpip/stack/neighbor_entry.go +++ b/pkg/tcpip/stack/neighbor_entry.go @@ -31,10 +31,10 @@ const ( // NeighborEntry describes a neighboring device in the local network. type NeighborEntry struct { - Addr tcpip.Address - LinkAddr tcpip.LinkAddress - State NeighborState - UpdatedAtNanos int64 + Addr tcpip.Address + LinkAddr tcpip.LinkAddress + State NeighborState + UpdatedAt time.Time } // NeighborState defines the state of a NeighborEntry within the Neighbor @@ -138,10 +138,10 @@ func newNeighborEntry(cache *neighborCache, remoteAddr tcpip.Address, nudState * // calling `setStateLocked`. func newStaticNeighborEntry(cache *neighborCache, addr tcpip.Address, linkAddr tcpip.LinkAddress, state *NUDState) *neighborEntry { entry := NeighborEntry{ - Addr: addr, - LinkAddr: linkAddr, - State: Static, - UpdatedAtNanos: cache.nic.stack.clock.NowNanoseconds(), + Addr: addr, + LinkAddr: linkAddr, + State: Static, + UpdatedAt: cache.nic.stack.clock.Now(), } n := &neighborEntry{ cache: cache, @@ -166,14 +166,20 @@ func (e *neighborEntry) notifyCompletionLocked(err tcpip.Error) { if ch := e.mu.done; ch != nil { close(ch) e.mu.done = nil - // Dequeue the pending packets in a new goroutine to not hold up the current + // Dequeue the pending packets asynchronously to not hold up the current // goroutine as writing packets may be a costly operation. // // At the time of writing, when writing packets, a neighbor's link address // is resolved (which ends up obtaining the entry's lock) while holding the - // link resolution queue's lock. Dequeuing packets in a new goroutine avoids - // a lock ordering violation. - go e.cache.nic.linkResQueue.dequeue(ch, e.mu.neigh.LinkAddr, err) + // link resolution queue's lock. Dequeuing packets asynchronously avoids a + // lock ordering violation. + // + // NB: this is equivalent to spawning a goroutine directly using the go + // keyword but allows tests that use manual clocks to deterministically + // wait for this work to complete. + e.cache.nic.stack.clock.AfterFunc(0, func() { + e.cache.nic.linkResQueue.dequeue(ch, e.mu.neigh.LinkAddr, err) + }) } } @@ -224,7 +230,7 @@ func (e *neighborEntry) cancelTimerLocked() { // // Precondition: e.mu MUST be locked. func (e *neighborEntry) removeLocked() { - e.mu.neigh.UpdatedAtNanos = e.cache.nic.stack.clock.NowNanoseconds() + e.mu.neigh.UpdatedAt = e.cache.nic.stack.clock.Now() e.dispatchRemoveEventLocked() e.cancelTimerLocked() // TODO(https://gvisor.dev/issues/5583): test the case where this function is @@ -246,7 +252,7 @@ func (e *neighborEntry) setStateLocked(next NeighborState) { prev := e.mu.neigh.State e.mu.neigh.State = next - e.mu.neigh.UpdatedAtNanos = e.cache.nic.stack.clock.NowNanoseconds() + e.mu.neigh.UpdatedAt = e.cache.nic.stack.clock.Now() config := e.nudState.Config() switch next { @@ -307,7 +313,7 @@ func (e *neighborEntry) setStateLocked(next NeighborState) { // a shared lock. e.mu.timer = timer{ done: &done, - timer: e.cache.nic.stack.Clock().AfterFunc(0, func() { + timer: e.cache.nic.stack.Clock().AfterFunc(immediateDuration, func() { var err tcpip.Error = &tcpip.ErrTimeout{} if remaining != 0 { err = e.cache.linkRes.LinkAddressRequest(addr, "" /* localAddr */, linkAddr) @@ -354,14 +360,14 @@ func (e *neighborEntry) handlePacketQueuedLocked(localAddr tcpip.Address) { case Unknown, Unreachable: prev := e.mu.neigh.State e.mu.neigh.State = Incomplete - e.mu.neigh.UpdatedAtNanos = e.cache.nic.stack.clock.NowNanoseconds() + e.mu.neigh.UpdatedAt = e.cache.nic.stack.clock.Now() switch prev { case Unknown: e.dispatchAddEventLocked() case Unreachable: e.dispatchChangeEventLocked() - e.cache.nic.stats.Neighbor.UnreachableEntryLookups.Increment() + e.cache.nic.stats.neighbor.unreachableEntryLookups.Increment() } config := e.nudState.Config() @@ -378,7 +384,7 @@ func (e *neighborEntry) handlePacketQueuedLocked(localAddr tcpip.Address) { // a shared lock. e.mu.timer = timer{ done: &done, - timer: e.cache.nic.stack.Clock().AfterFunc(0, func() { + timer: e.cache.nic.stack.Clock().AfterFunc(immediateDuration, func() { var err tcpip.Error = &tcpip.ErrTimeout{} if remaining != 0 { // As per RFC 4861 section 7.2.2: diff --git a/pkg/tcpip/stack/neighbor_entry_test.go b/pkg/tcpip/stack/neighbor_entry_test.go index 1d39ee73d..59d86d6d4 100644 --- a/pkg/tcpip/stack/neighbor_entry_test.go +++ b/pkg/tcpip/stack/neighbor_entry_test.go @@ -36,11 +36,6 @@ const ( entryTestLinkAddr1 = tcpip.LinkAddress("\x0a\x00\x00\x00\x00\x01") entryTestLinkAddr2 = tcpip.LinkAddress("\x0a\x00\x00\x00\x00\x02") - - // entryTestNetDefaultMTU is the MTU, in bytes, used throughout the tests, - // except where another value is explicitly used. It is chosen to match the - // MTU of loopback interfaces on Linux systems. - entryTestNetDefaultMTU = 65536 ) var ( @@ -196,13 +191,13 @@ func (r *entryTestLinkResolver) LinkAddressRequest(targetAddr, localAddr tcpip.A // ResolveStaticAddress attempts to resolve address without sending requests. // It either resolves the name immediately or returns the empty LinkAddress. -func (r *entryTestLinkResolver) ResolveStaticAddress(addr tcpip.Address) (tcpip.LinkAddress, bool) { +func (*entryTestLinkResolver) ResolveStaticAddress(tcpip.Address) (tcpip.LinkAddress, bool) { return "", false } // LinkAddressProtocol returns the network protocol of the addresses this // resolver can resolve. -func (r *entryTestLinkResolver) LinkAddressProtocol() tcpip.NetworkProtocolNumber { +func (*entryTestLinkResolver) LinkAddressProtocol() tcpip.NetworkProtocolNumber { return entryTestNetNumber } @@ -219,7 +214,7 @@ func entryTestSetup(c NUDConfigurations) (*neighborEntry, *testNUDDispatcher, *e nudConfigs: c, randomGenerator: rand.New(rand.NewSource(time.Now().UnixNano())), }, - stats: makeNICStats(), + stats: makeNICStats(tcpip.NICStats{}.FillIn()), } netEP := (&testIPv6Protocol{}).NewEndpoint(&nic, nil) nic.networkEndpoints = map[tcpip.NetworkProtocolNumber]NetworkEndpoint{ @@ -354,10 +349,10 @@ func unknownToIncomplete(e *neighborEntry, nudDisp *testNUDDispatcher, linkRes * EventType: entryTestAdded, NICID: entryTestNICID, Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: tcpip.LinkAddress(""), - State: Incomplete, - UpdatedAtNanos: clock.NowNanoseconds(), + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + UpdatedAt: clock.Now(), }, }, } @@ -415,10 +410,10 @@ func unknownToStale(e *neighborEntry, nudDisp *testNUDDispatcher, linkRes *entry EventType: entryTestAdded, NICID: entryTestNICID, Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Stale, - UpdatedAtNanos: clock.NowNanoseconds(), + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + UpdatedAt: clock.Now(), }, }, } @@ -446,7 +441,7 @@ func TestEntryIncompleteToIncompleteDoesNotChangeUpdatedAt(t *testing.T) { // UpdatedAt should remain the same during address resolution. e.mu.Lock() - startedAt := e.mu.neigh.UpdatedAtNanos + startedAt := e.mu.neigh.UpdatedAt e.mu.Unlock() // Wait for the rest of the reachability probe transmissions, signifying @@ -470,7 +465,7 @@ func TestEntryIncompleteToIncompleteDoesNotChangeUpdatedAt(t *testing.T) { } e.mu.Lock() - if got, want := e.mu.neigh.UpdatedAtNanos, startedAt; got != want { + if got, want := e.mu.neigh.UpdatedAt, startedAt; got != want { t.Errorf("got e.mu.neigh.UpdatedAt = %q, want = %q", got, want) } e.mu.Unlock() @@ -485,10 +480,10 @@ func TestEntryIncompleteToIncompleteDoesNotChangeUpdatedAt(t *testing.T) { EventType: entryTestChanged, NICID: entryTestNICID, Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: tcpip.LinkAddress(""), - State: Unreachable, - UpdatedAtNanos: clock.NowNanoseconds(), + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Unreachable, + UpdatedAt: clock.Now(), }, }, } @@ -547,10 +542,10 @@ func incompleteToReachableWithFlags(e *neighborEntry, nudDisp *testNUDDispatcher EventType: entryTestChanged, NICID: entryTestNICID, Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Reachable, - UpdatedAtNanos: clock.NowNanoseconds(), + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Reachable, + UpdatedAt: clock.Now(), }, }, } @@ -644,10 +639,10 @@ func TestEntryIncompleteToStaleWhenUnsolicitedConfirmation(t *testing.T) { EventType: entryTestChanged, NICID: entryTestNICID, Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Stale, - UpdatedAtNanos: clock.NowNanoseconds(), + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + UpdatedAt: clock.Now(), }, }, } @@ -678,10 +673,10 @@ func TestEntryIncompleteToStaleWhenProbe(t *testing.T) { EventType: entryTestChanged, NICID: entryTestNICID, Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Stale, - UpdatedAtNanos: clock.NowNanoseconds(), + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + UpdatedAt: clock.Now(), }, }, } @@ -757,10 +752,10 @@ func incompleteToUnreachable(c NUDConfigurations, e *neighborEntry, nudDisp *tes EventType: entryTestChanged, NICID: entryTestNICID, Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: tcpip.LinkAddress(""), - State: Unreachable, - UpdatedAtNanos: clock.NowNanoseconds(), + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Unreachable, + UpdatedAt: clock.Now(), }, }, } @@ -943,10 +938,10 @@ func reachableToStale(c NUDConfigurations, e *neighborEntry, nudDisp *testNUDDis EventType: entryTestChanged, NICID: entryTestNICID, Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Stale, - UpdatedAtNanos: clock.NowNanoseconds(), + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + UpdatedAt: clock.Now(), }, }, } @@ -998,10 +993,10 @@ func TestEntryReachableToStaleWhenProbeWithDifferentAddress(t *testing.T) { EventType: entryTestChanged, NICID: entryTestNICID, Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr2, - State: Stale, - UpdatedAtNanos: clock.NowNanoseconds(), + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr2, + State: Stale, + UpdatedAt: clock.Now(), }, }, } @@ -1050,10 +1045,10 @@ func TestEntryReachableToStaleWhenConfirmationWithDifferentAddress(t *testing.T) EventType: entryTestChanged, NICID: entryTestNICID, Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Stale, - UpdatedAtNanos: clock.NowNanoseconds(), + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + UpdatedAt: clock.Now(), }, }, } @@ -1102,10 +1097,10 @@ func TestEntryReachableToStaleWhenConfirmationWithDifferentAddressAndOverride(t EventType: entryTestChanged, NICID: entryTestNICID, Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr2, - State: Stale, - UpdatedAtNanos: clock.NowNanoseconds(), + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr2, + State: Stale, + UpdatedAt: clock.Now(), }, }, } @@ -1191,10 +1186,10 @@ func TestEntryStaleToReachableWhenSolicitedOverrideConfirmation(t *testing.T) { EventType: entryTestChanged, NICID: entryTestNICID, Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr2, - State: Reachable, - UpdatedAtNanos: clock.NowNanoseconds(), + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr2, + State: Reachable, + UpdatedAt: clock.Now(), }, }, } @@ -1243,10 +1238,10 @@ func TestEntryStaleToReachableWhenSolicitedConfirmationWithoutAddress(t *testing EventType: entryTestChanged, NICID: entryTestNICID, Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Reachable, - UpdatedAtNanos: clock.NowNanoseconds(), + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Reachable, + UpdatedAt: clock.Now(), }, }, } @@ -1284,10 +1279,10 @@ func TestEntryStaleToStaleWhenOverrideConfirmation(t *testing.T) { EventType: entryTestChanged, NICID: entryTestNICID, Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr2, - State: Stale, - UpdatedAtNanos: clock.NowNanoseconds(), + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr2, + State: Stale, + UpdatedAt: clock.Now(), }, }, } @@ -1332,10 +1327,10 @@ func TestEntryStaleToStaleWhenProbeUpdateAddress(t *testing.T) { EventType: entryTestChanged, NICID: entryTestNICID, Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr2, - State: Stale, - UpdatedAtNanos: clock.NowNanoseconds(), + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr2, + State: Stale, + UpdatedAt: clock.Now(), }, }, } @@ -1391,10 +1386,10 @@ func staleToDelay(e *neighborEntry, nudDisp *testNUDDispatcher, linkRes *entryTe EventType: entryTestChanged, NICID: entryTestNICID, Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Delay, - UpdatedAtNanos: clock.NowNanoseconds(), + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Delay, + UpdatedAt: clock.Now(), }, }, } @@ -1443,10 +1438,10 @@ func TestEntryDelayToReachableWhenUpperLevelConfirmation(t *testing.T) { EventType: entryTestChanged, NICID: entryTestNICID, Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Reachable, - UpdatedAtNanos: clock.NowNanoseconds(), + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Reachable, + UpdatedAt: clock.Now(), }, }, } @@ -1498,10 +1493,10 @@ func TestEntryDelayToReachableWhenSolicitedOverrideConfirmation(t *testing.T) { EventType: entryTestChanged, NICID: entryTestNICID, Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr2, - State: Reachable, - UpdatedAtNanos: clock.NowNanoseconds(), + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr2, + State: Reachable, + UpdatedAt: clock.Now(), }, }, } @@ -1553,10 +1548,10 @@ func TestEntryDelayToReachableWhenSolicitedConfirmationWithoutAddress(t *testing EventType: entryTestChanged, NICID: entryTestNICID, Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Reachable, - UpdatedAtNanos: clock.NowNanoseconds(), + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Reachable, + UpdatedAt: clock.Now(), }, }, } @@ -1645,10 +1640,10 @@ func TestEntryDelayToStaleWhenProbeWithDifferentAddress(t *testing.T) { EventType: entryTestChanged, NICID: entryTestNICID, Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr2, - State: Stale, - UpdatedAtNanos: clock.NowNanoseconds(), + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr2, + State: Stale, + UpdatedAt: clock.Now(), }, }, } @@ -1697,10 +1692,10 @@ func TestEntryDelayToStaleWhenConfirmationWithDifferentAddress(t *testing.T) { EventType: entryTestChanged, NICID: entryTestNICID, Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr2, - State: Stale, - UpdatedAtNanos: clock.NowNanoseconds(), + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr2, + State: Stale, + UpdatedAt: clock.Now(), }, }, } @@ -1770,10 +1765,10 @@ func delayToProbe(c NUDConfigurations, e *neighborEntry, nudDisp *testNUDDispatc EventType: entryTestChanged, NICID: entryTestNICID, Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Probe, - UpdatedAtNanos: clock.NowNanoseconds(), + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Probe, + UpdatedAt: clock.Now(), }, }, } @@ -1827,10 +1822,10 @@ func TestEntryProbeToStaleWhenProbeWithDifferentAddress(t *testing.T) { EventType: entryTestChanged, NICID: entryTestNICID, Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr2, - State: Stale, - UpdatedAtNanos: clock.NowNanoseconds(), + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr2, + State: Stale, + UpdatedAt: clock.Now(), }, }, } @@ -1882,10 +1877,10 @@ func TestEntryProbeToStaleWhenConfirmationWithDifferentAddress(t *testing.T) { EventType: entryTestChanged, NICID: entryTestNICID, Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr2, - State: Stale, - UpdatedAtNanos: clock.NowNanoseconds(), + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr2, + State: Stale, + UpdatedAt: clock.Now(), }, }, } @@ -2003,10 +1998,10 @@ func probeToReachableWithFlags(e *neighborEntry, nudDisp *testNUDDispatcher, lin EventType: entryTestChanged, NICID: entryTestNICID, Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: linkAddr, - State: Reachable, - UpdatedAtNanos: clock.NowNanoseconds(), + Addr: entryTestAddr1, + LinkAddr: linkAddr, + State: Reachable, + UpdatedAt: clock.Now(), }, }, } @@ -2155,10 +2150,10 @@ func probeToUnreachable(c NUDConfigurations, e *neighborEntry, nudDisp *testNUDD EventType: entryTestChanged, NICID: entryTestNICID, Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Unreachable, - UpdatedAtNanos: clock.NowNanoseconds(), + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Unreachable, + UpdatedAt: clock.Now(), }, }, } @@ -2227,10 +2222,10 @@ func unreachableToIncomplete(e *neighborEntry, nudDisp *testNUDDispatcher, linkR EventType: entryTestChanged, NICID: entryTestNICID, Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: tcpip.LinkAddress(""), - State: Incomplete, - UpdatedAtNanos: clock.NowNanoseconds(), + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + UpdatedAt: clock.Now(), }, }, } diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go index dbba2c79f..b854d868c 100644 --- a/pkg/tcpip/stack/nic.go +++ b/pkg/tcpip/stack/nic.go @@ -51,7 +51,7 @@ type nic struct { name string context NICContext - stats NICStats + stats sharedStats // The network endpoints themselves may be modified by calling the interface's // methods, but the map reference and entries must be constant. @@ -78,26 +78,13 @@ type nic struct { } } -// NICStats hold statistics for a NIC. -type NICStats struct { - Tx DirectionStats - Rx DirectionStats - - DisabledRx DirectionStats - - Neighbor NeighborStats -} - -func makeNICStats() NICStats { - var s NICStats - tcpip.InitStatCounters(reflect.ValueOf(&s).Elem()) - return s -} - -// DirectionStats includes packet and byte counts. -type DirectionStats struct { - Packets *tcpip.StatCounter - Bytes *tcpip.StatCounter +// makeNICStats initializes the NIC statistics and associates them to the global +// NIC statistics. +func makeNICStats(global tcpip.NICStats) sharedStats { + var stats sharedStats + tcpip.InitStatCounters(reflect.ValueOf(&stats.local).Elem()) + stats.init(&stats.local, &global) + return stats } type packetEndpointList struct { @@ -150,7 +137,7 @@ func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint, ctx NICC id: id, name: name, context: ctx, - stats: makeNICStats(), + stats: makeNICStats(stack.Stats().NICs), networkEndpoints: make(map[tcpip.NetworkProtocolNumber]NetworkEndpoint), linkAddrResolvers: make(map[tcpip.NetworkProtocolNumber]*linkResolver), duplicateAddressDetectors: make(map[tcpip.NetworkProtocolNumber]DuplicateAddressDetector), @@ -382,8 +369,8 @@ func (n *nic) writePacket(r RouteInfo, protocol tcpip.NetworkProtocolNumber, pkt return err } - n.stats.Tx.Packets.Increment() - n.stats.Tx.Bytes.IncrementBy(uint64(numBytes)) + n.stats.tx.packets.Increment() + n.stats.tx.bytes.IncrementBy(uint64(numBytes)) return nil } @@ -399,13 +386,13 @@ func (n *nic) writePackets(r RouteInfo, protocol tcpip.NetworkProtocolNumber, pk } writtenPackets, err := n.LinkEndpoint.WritePackets(r, pkts, protocol) - n.stats.Tx.Packets.IncrementBy(uint64(writtenPackets)) + n.stats.tx.packets.IncrementBy(uint64(writtenPackets)) writtenBytes := 0 for i, pb := 0, pkts.Front(); i < writtenPackets && pb != nil; i, pb = i+1, pb.Next() { writtenBytes += pb.Size() } - n.stats.Tx.Bytes.IncrementBy(uint64(writtenBytes)) + n.stats.tx.bytes.IncrementBy(uint64(writtenBytes)) return writtenPackets, err } @@ -718,18 +705,18 @@ func (n *nic) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcp if !enabled { n.mu.RUnlock() - n.stats.DisabledRx.Packets.Increment() - n.stats.DisabledRx.Bytes.IncrementBy(uint64(pkt.Data().Size())) + n.stats.disabledRx.packets.Increment() + n.stats.disabledRx.bytes.IncrementBy(uint64(pkt.Data().Size())) return } - n.stats.Rx.Packets.Increment() - n.stats.Rx.Bytes.IncrementBy(uint64(pkt.Data().Size())) + n.stats.rx.packets.Increment() + n.stats.rx.bytes.IncrementBy(uint64(pkt.Data().Size())) networkEndpoint, ok := n.networkEndpoints[protocol] if !ok { n.mu.RUnlock() - n.stack.stats.UnknownProtocolRcvdPackets.Increment() + n.stats.unknownL3ProtocolRcvdPackets.Increment() return } @@ -786,41 +773,35 @@ func (n *nic) DeliverOutboundPacket(remote, local tcpip.LinkAddress, protocol tc func (n *nic) DeliverTransportPacket(protocol tcpip.TransportProtocolNumber, pkt *PacketBuffer) TransportPacketDisposition { state, ok := n.stack.transportProtocols[protocol] if !ok { - n.stack.stats.UnknownProtocolRcvdPackets.Increment() + n.stats.unknownL4ProtocolRcvdPackets.Increment() return TransportPacketProtocolUnreachable } transProto := state.proto - // Raw socket packets are delivered based solely on the transport - // protocol number. We do not inspect the payload to ensure it's - // validly formed. - n.stack.demux.deliverRawPacket(protocol, pkt) - // TransportHeader is empty only when pkt is an ICMP packet or was reassembled // from fragments. if pkt.TransportHeader().View().IsEmpty() { - // TODO(gvisor.dev/issue/170): ICMP packets don't have their TransportHeader - // fields set yet, parse it here. See icmp/protocol.go:protocol.Parse for a - // full explanation. + // ICMP packets don't have their TransportHeader fields set yet, parse it + // here. See icmp/protocol.go:protocol.Parse for a full explanation. if protocol == header.ICMPv4ProtocolNumber || protocol == header.ICMPv6ProtocolNumber { // ICMP packets may be longer, but until icmp.Parse is implemented, here // we parse it using the minimum size. if _, ok := pkt.TransportHeader().Consume(transProto.MinimumPacketSize()); !ok { - n.stack.stats.MalformedRcvdPackets.Increment() + n.stats.malformedL4RcvdPackets.Increment() // We consider a malformed transport packet handled because there is // nothing the caller can do. return TransportPacketHandled } } else if !transProto.Parse(pkt) { - n.stack.stats.MalformedRcvdPackets.Increment() + n.stats.malformedL4RcvdPackets.Increment() return TransportPacketHandled } } srcPort, dstPort, err := transProto.ParsePorts(pkt.TransportHeader().View()) if err != nil { - n.stack.stats.MalformedRcvdPackets.Increment() + n.stats.malformedL4RcvdPackets.Increment() return TransportPacketHandled } @@ -852,7 +833,7 @@ func (n *nic) DeliverTransportPacket(protocol tcpip.TransportProtocolNumber, pkt // If it doesn't handle it then we should do so. switch res := transProto.HandleUnknownDestinationPacket(id, pkt); res { case UnknownDestinationPacketMalformed: - n.stack.stats.MalformedRcvdPackets.Increment() + n.stats.malformedL4RcvdPackets.Increment() return TransportPacketHandled case UnknownDestinationPacketUnhandled: return TransportPacketDestinationPortUnreachable @@ -891,6 +872,17 @@ func (n *nic) DeliverTransportError(local, remote tcpip.Address, net tcpip.Netwo } } +// DeliverRawPacket implements TransportDispatcher. +func (n *nic) DeliverRawPacket(protocol tcpip.TransportProtocolNumber, pkt *PacketBuffer) { + // For ICMPv4 only we validate the header length for compatibility with + // raw(7) ICMP_FILTER. The same check is made in Linux here: + // https://github.com/torvalds/linux/blob/70585216/net/ipv4/raw.c#L189. + if protocol == header.ICMPv4ProtocolNumber && pkt.TransportHeader().View().Size()+pkt.Data().Size() < header.ICMPv4MinimumSize { + return + } + n.stack.demux.deliverRawPacket(protocol, pkt) +} + // ID implements NetworkInterface. func (n *nic) ID() tcpip.NICID { return n.id diff --git a/pkg/tcpip/stack/nic_stats.go b/pkg/tcpip/stack/nic_stats.go new file mode 100644 index 000000000..1773d5e8d --- /dev/null +++ b/pkg/tcpip/stack/nic_stats.go @@ -0,0 +1,74 @@ +// Copyright 2021 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package stack + +import ( + "gvisor.dev/gvisor/pkg/tcpip" +) + +type sharedStats struct { + local tcpip.NICStats + multiCounterNICStats +} + +// LINT.IfChange(multiCounterNICPacketStats) + +type multiCounterNICPacketStats struct { + packets tcpip.MultiCounterStat + bytes tcpip.MultiCounterStat +} + +func (m *multiCounterNICPacketStats) init(a, b *tcpip.NICPacketStats) { + m.packets.Init(a.Packets, b.Packets) + m.bytes.Init(a.Bytes, b.Bytes) +} + +// LINT.ThenChange(../../tcpip.go:NICPacketStats) + +// LINT.IfChange(multiCounterNICNeighborStats) + +type multiCounterNICNeighborStats struct { + unreachableEntryLookups tcpip.MultiCounterStat +} + +func (m *multiCounterNICNeighborStats) init(a, b *tcpip.NICNeighborStats) { + m.unreachableEntryLookups.Init(a.UnreachableEntryLookups, b.UnreachableEntryLookups) +} + +// LINT.ThenChange(../../tcpip.go:NICNeighborStats) + +// LINT.IfChange(multiCounterNICStats) + +type multiCounterNICStats struct { + unknownL3ProtocolRcvdPackets tcpip.MultiCounterStat + unknownL4ProtocolRcvdPackets tcpip.MultiCounterStat + malformedL4RcvdPackets tcpip.MultiCounterStat + tx multiCounterNICPacketStats + rx multiCounterNICPacketStats + disabledRx multiCounterNICPacketStats + neighbor multiCounterNICNeighborStats +} + +func (m *multiCounterNICStats) init(a, b *tcpip.NICStats) { + m.unknownL3ProtocolRcvdPackets.Init(a.UnknownL3ProtocolRcvdPackets, b.UnknownL3ProtocolRcvdPackets) + m.unknownL4ProtocolRcvdPackets.Init(a.UnknownL4ProtocolRcvdPackets, b.UnknownL4ProtocolRcvdPackets) + m.malformedL4RcvdPackets.Init(a.MalformedL4RcvdPackets, b.MalformedL4RcvdPackets) + m.tx.init(&a.Tx, &b.Tx) + m.rx.init(&a.Rx, &b.Rx) + m.disabledRx.init(&a.DisabledRx, &b.DisabledRx) + m.neighbor.init(&a.Neighbor, &b.Neighbor) +} + +// LINT.ThenChange(../../tcpip.go:NICStats) diff --git a/pkg/tcpip/stack/nic_test.go b/pkg/tcpip/stack/nic_test.go index 8a3005295..5cb342f78 100644 --- a/pkg/tcpip/stack/nic_test.go +++ b/pkg/tcpip/stack/nic_test.go @@ -15,11 +15,13 @@ package stack import ( + "reflect" "testing" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/testutil" ) var _ AddressableEndpoint = (*testIPv6Endpoint)(nil) @@ -171,19 +173,19 @@ func TestDisabledRxStatsWhenNICDisabled(t *testing.T) { // When the NIC is disabled, the only field that matters is the stats field. // This test is limited to stats counter checks. nic := nic{ - stats: makeNICStats(), + stats: makeNICStats(tcpip.NICStats{}.FillIn()), } - if got := nic.stats.DisabledRx.Packets.Value(); got != 0 { + if got := nic.stats.local.DisabledRx.Packets.Value(); got != 0 { t.Errorf("got DisabledRx.Packets = %d, want = 0", got) } - if got := nic.stats.DisabledRx.Bytes.Value(); got != 0 { + if got := nic.stats.local.DisabledRx.Bytes.Value(); got != 0 { t.Errorf("got DisabledRx.Bytes = %d, want = 0", got) } - if got := nic.stats.Rx.Packets.Value(); got != 0 { + if got := nic.stats.local.Rx.Packets.Value(); got != 0 { t.Errorf("got Rx.Packets = %d, want = 0", got) } - if got := nic.stats.Rx.Bytes.Value(); got != 0 { + if got := nic.stats.local.Rx.Bytes.Value(); got != 0 { t.Errorf("got Rx.Bytes = %d, want = 0", got) } @@ -195,16 +197,28 @@ func TestDisabledRxStatsWhenNICDisabled(t *testing.T) { Data: buffer.View([]byte{1, 2, 3, 4}).ToVectorisedView(), })) - if got := nic.stats.DisabledRx.Packets.Value(); got != 1 { + if got := nic.stats.local.DisabledRx.Packets.Value(); got != 1 { t.Errorf("got DisabledRx.Packets = %d, want = 1", got) } - if got := nic.stats.DisabledRx.Bytes.Value(); got != 4 { + if got := nic.stats.local.DisabledRx.Bytes.Value(); got != 4 { t.Errorf("got DisabledRx.Bytes = %d, want = 4", got) } - if got := nic.stats.Rx.Packets.Value(); got != 0 { + if got := nic.stats.local.Rx.Packets.Value(); got != 0 { t.Errorf("got Rx.Packets = %d, want = 0", got) } - if got := nic.stats.Rx.Bytes.Value(); got != 0 { + if got := nic.stats.local.Rx.Bytes.Value(); got != 0 { t.Errorf("got Rx.Bytes = %d, want = 0", got) } } + +func TestMultiCounterStatsInitialization(t *testing.T) { + global := tcpip.NICStats{}.FillIn() + nic := nic{ + stats: makeNICStats(global), + } + multi := nic.stats.multiCounterNICStats + local := nic.stats.local + if err := testutil.ValidateMultiCounterStats(reflect.ValueOf(&multi).Elem(), []reflect.Value{reflect.ValueOf(&local).Elem(), reflect.ValueOf(&global).Elem()}); err != nil { + t.Error(err) + } +} diff --git a/pkg/tcpip/stack/nud.go b/pkg/tcpip/stack/nud.go index 5a94e9ac6..ca9822bca 100644 --- a/pkg/tcpip/stack/nud.go +++ b/pkg/tcpip/stack/nud.go @@ -16,6 +16,7 @@ package stack import ( "math" + "math/rand" "sync" "time" @@ -313,45 +314,36 @@ func calcMaxRandomFactor(minRandomFactor float32) float32 { return defaultMaxRandomFactor } -// A Rand is a source of random numbers. -type Rand interface { - // Float32 returns, as a float32, a pseudo-random number in [0.0,1.0). - Float32() float32 -} - // NUDState stores states needed for calculating reachable time. type NUDState struct { - rng Rand + clock tcpip.Clock + rng *rand.Rand - // mu protects the fields below. - // - // It is necessary for NUDState to handle its own locking since neighbor - // entries may access the NUD state from within the goroutine spawned by - // time.AfterFunc(). This goroutine may run concurrently with the main - // process for controlling the neighbor cache and would otherwise introduce - // race conditions if NUDState was not locked properly. - mu sync.RWMutex + mu struct { + sync.RWMutex - config NUDConfigurations + config NUDConfigurations - // reachableTime is the duration to wait for a REACHABLE entry to - // transition into STALE after inactivity. This value is calculated with - // the algorithm defined in RFC 4861 section 6.3.2. - reachableTime time.Duration + // reachableTime is the duration to wait for a REACHABLE entry to + // transition into STALE after inactivity. This value is calculated with + // the algorithm defined in RFC 4861 section 6.3.2. + reachableTime time.Duration - expiration time.Time - prevBaseReachableTime time.Duration - prevMinRandomFactor float32 - prevMaxRandomFactor float32 + expiration time.Time + prevBaseReachableTime time.Duration + prevMinRandomFactor float32 + prevMaxRandomFactor float32 + } } // NewNUDState returns new NUDState using c as configuration and the specified // random number generator for use in recomputing ReachableTime. -func NewNUDState(c NUDConfigurations, rng Rand) *NUDState { +func NewNUDState(c NUDConfigurations, clock tcpip.Clock, rng *rand.Rand) *NUDState { s := &NUDState{ - rng: rng, + clock: clock, + rng: rng, } - s.config = c + s.mu.config = c return s } @@ -359,14 +351,14 @@ func NewNUDState(c NUDConfigurations, rng Rand) *NUDState { func (s *NUDState) Config() NUDConfigurations { s.mu.RLock() defer s.mu.RUnlock() - return s.config + return s.mu.config } // SetConfig replaces the existing NUD configurations with c. func (s *NUDState) SetConfig(c NUDConfigurations) { s.mu.Lock() defer s.mu.Unlock() - s.config = c + s.mu.config = c } // ReachableTime returns the duration to wait for a REACHABLE entry to @@ -377,13 +369,13 @@ func (s *NUDState) ReachableTime() time.Duration { s.mu.Lock() defer s.mu.Unlock() - if time.Now().After(s.expiration) || - s.config.BaseReachableTime != s.prevBaseReachableTime || - s.config.MinRandomFactor != s.prevMinRandomFactor || - s.config.MaxRandomFactor != s.prevMaxRandomFactor { + if s.clock.Now().After(s.mu.expiration) || + s.mu.config.BaseReachableTime != s.mu.prevBaseReachableTime || + s.mu.config.MinRandomFactor != s.mu.prevMinRandomFactor || + s.mu.config.MaxRandomFactor != s.mu.prevMaxRandomFactor { s.recomputeReachableTimeLocked() } - return s.reachableTime + return s.mu.reachableTime } // recomputeReachableTimeLocked forces a recalculation of ReachableTime using @@ -408,23 +400,23 @@ func (s *NUDState) ReachableTime() time.Duration { // // s.mu MUST be locked for writing. func (s *NUDState) recomputeReachableTimeLocked() { - s.prevBaseReachableTime = s.config.BaseReachableTime - s.prevMinRandomFactor = s.config.MinRandomFactor - s.prevMaxRandomFactor = s.config.MaxRandomFactor + s.mu.prevBaseReachableTime = s.mu.config.BaseReachableTime + s.mu.prevMinRandomFactor = s.mu.config.MinRandomFactor + s.mu.prevMaxRandomFactor = s.mu.config.MaxRandomFactor - randomFactor := s.config.MinRandomFactor + s.rng.Float32()*(s.config.MaxRandomFactor-s.config.MinRandomFactor) + randomFactor := s.mu.config.MinRandomFactor + s.rng.Float32()*(s.mu.config.MaxRandomFactor-s.mu.config.MinRandomFactor) // Check for overflow, given that minRandomFactor and maxRandomFactor are // guaranteed to be positive numbers. - if float32(math.MaxInt64)/randomFactor < float32(s.config.BaseReachableTime) { - s.reachableTime = time.Duration(math.MaxInt64) + if math.MaxInt64/randomFactor < float32(s.mu.config.BaseReachableTime) { + s.mu.reachableTime = time.Duration(math.MaxInt64) } else if randomFactor == 1 { // Avoid loss of precision when a large base reachable time is used. - s.reachableTime = s.config.BaseReachableTime + s.mu.reachableTime = s.mu.config.BaseReachableTime } else { - reachableTime := int64(float32(s.config.BaseReachableTime) * randomFactor) - s.reachableTime = time.Duration(reachableTime) + reachableTime := int64(float32(s.mu.config.BaseReachableTime) * randomFactor) + s.mu.reachableTime = time.Duration(reachableTime) } - s.expiration = time.Now().Add(2 * time.Hour) + s.mu.expiration = s.clock.Now().Add(2 * time.Hour) } diff --git a/pkg/tcpip/stack/nud_test.go b/pkg/tcpip/stack/nud_test.go index e1253f310..1aeb2f8a5 100644 --- a/pkg/tcpip/stack/nud_test.go +++ b/pkg/tcpip/stack/nud_test.go @@ -16,6 +16,7 @@ package stack_test import ( "math" + "math/rand" "testing" "time" @@ -28,17 +29,15 @@ import ( ) const ( - defaultBaseReachableTime = 30 * time.Second - minimumBaseReachableTime = time.Millisecond - defaultMinRandomFactor = 0.5 - defaultMaxRandomFactor = 1.5 - defaultRetransmitTimer = time.Second - minimumRetransmitTimer = time.Millisecond - defaultDelayFirstProbeTime = 5 * time.Second - defaultMaxMulticastProbes = 3 - defaultMaxUnicastProbes = 3 - defaultMaxAnycastDelayTime = time.Second - defaultMaxReachbilityConfirmations = 3 + defaultBaseReachableTime = 30 * time.Second + minimumBaseReachableTime = time.Millisecond + defaultMinRandomFactor = 0.5 + defaultMaxRandomFactor = 1.5 + defaultRetransmitTimer = time.Second + minimumRetransmitTimer = time.Millisecond + defaultDelayFirstProbeTime = 5 * time.Second + defaultMaxMulticastProbes = 3 + defaultMaxUnicastProbes = 3 defaultFakeRandomNum = 0.5 ) @@ -48,12 +47,14 @@ type fakeRand struct { num float32 } -var _ stack.Rand = (*fakeRand)(nil) +var _ rand.Source = (*fakeRand)(nil) -func (f *fakeRand) Float32() float32 { - return f.num +func (f *fakeRand) Int63() int64 { + return int64(f.num * float32(1<<63)) } +func (*fakeRand) Seed(int64) {} + func TestNUDFunctions(t *testing.T) { const nicID = 1 @@ -169,7 +170,7 @@ func TestNUDFunctions(t *testing.T) { t.Errorf("s.Neigbors(%d, %d) error mismatch (-want +got):\n%s", test.nicID, ipv6.ProtocolNumber, diff) } else if test.expectedErr == nil { if diff := cmp.Diff( - []stack.NeighborEntry{{Addr: llAddr2, LinkAddr: linkAddr1, State: stack.Static, UpdatedAtNanos: clock.NowNanoseconds()}}, + []stack.NeighborEntry{{Addr: llAddr2, LinkAddr: linkAddr1, State: stack.Static, UpdatedAt: clock.Now()}}, neighbors, ); diff != "" { t.Errorf("neighbors mismatch (-want +got):\n%s", diff) @@ -710,7 +711,8 @@ func TestNUDStateReachableTime(t *testing.T) { rng := fakeRand{ num: defaultFakeRandomNum, } - s := stack.NewNUDState(c, &rng) + var clock faketime.NullClock + s := stack.NewNUDState(c, &clock, rand.New(&rng)) if got, want := s.ReachableTime(), test.want; got != want { t.Errorf("got ReachableTime = %q, want = %q", got, want) } @@ -782,7 +784,8 @@ func TestNUDStateRecomputeReachableTime(t *testing.T) { rng := fakeRand{ num: defaultFakeRandomNum, } - s := stack.NewNUDState(c, &rng) + var clock faketime.NullClock + s := stack.NewNUDState(c, &clock, rand.New(&rng)) old := s.ReachableTime() if got, want := s.ReachableTime(), old; got != want { diff --git a/pkg/tcpip/stack/packet_buffer.go b/pkg/tcpip/stack/packet_buffer.go index 01652fbe7..9192d8433 100644 --- a/pkg/tcpip/stack/packet_buffer.go +++ b/pkg/tcpip/stack/packet_buffer.go @@ -134,7 +134,7 @@ type PacketBuffer struct { // https://www.man7.org/linux/man-pages/man7/packet.7.html. PktType tcpip.PacketType - // NICID is the ID of the interface the network packet was received at. + // NICID is the ID of the last interface the network packet was handled at. NICID tcpip.NICID // RXTransportChecksumValidated indicates that transport checksum verification @@ -245,10 +245,10 @@ func (pk *PacketBuffer) dataOffset() int { func (pk *PacketBuffer) push(typ headerType, size int) tcpipbuffer.View { h := &pk.headers[typ] if h.length > 0 { - panic(fmt.Sprintf("push must not be called twice: type %s", typ)) + panic(fmt.Sprintf("push(%s, %d) called after previous push", typ, size)) } if pk.pushed+size > pk.reserved { - panic("not enough headroom reserved") + panic(fmt.Sprintf("push(%s, %d) overflows; pushed=%d reserved=%d", typ, size, pk.pushed, pk.reserved)) } pk.pushed += size h.offset = -pk.pushed diff --git a/pkg/tcpip/stack/rand.go b/pkg/tcpip/stack/rand.go index 421fb5c15..c8294eb6e 100644 --- a/pkg/tcpip/stack/rand.go +++ b/pkg/tcpip/stack/rand.go @@ -15,7 +15,7 @@ package stack import ( - mathrand "math/rand" + "math/rand" "gvisor.dev/gvisor/pkg/sync" ) @@ -23,7 +23,7 @@ import ( // lockedRandomSource provides a threadsafe rand.Source. type lockedRandomSource struct { mu sync.Mutex - src mathrand.Source + src rand.Source } func (r *lockedRandomSource) Int63() (n int64) { diff --git a/pkg/tcpip/stack/registration.go b/pkg/tcpip/stack/registration.go index 85bb87b4b..dfe2c886f 100644 --- a/pkg/tcpip/stack/registration.go +++ b/pkg/tcpip/stack/registration.go @@ -265,6 +265,11 @@ type TransportDispatcher interface { // // DeliverTransportError takes ownership of the packet buffer. DeliverTransportError(local, remote tcpip.Address, _ tcpip.NetworkProtocolNumber, _ tcpip.TransportProtocolNumber, _ TransportError, _ *PacketBuffer) + + // DeliverRawPacket delivers a packet to any subscribed raw sockets. + // + // DeliverRawPacket does NOT take ownership of the packet buffer. + DeliverRawPacket(tcpip.TransportProtocolNumber, *PacketBuffer) } // PacketLooping specifies where an outbound packet should be sent. @@ -420,7 +425,7 @@ const ( PermanentExpired // Temporary is an endpoint, created on a one-off basis to temporarily - // consider the NIC bound an an address that it is not explictiy bound to + // consider the NIC bound an an address that it is not explicitly bound to // (such as a permanent address). Its reference count must not be biased by 1 // so that the address is removed immediately when references to it are no // longer held. @@ -630,7 +635,7 @@ type NetworkEndpoint interface { // HandlePacket takes ownership of pkt. HandlePacket(pkt *PacketBuffer) - // Close is called when the endpoint is reomved from a stack. + // Close is called when the endpoint is removed from a stack. Close() // NetworkProtocolNumber returns the tcpip.NetworkProtocolNumber for @@ -968,7 +973,7 @@ type DuplicateAddressDetector interface { // called with the result of the original DAD request. CheckDuplicateAddress(tcpip.Address, DADCompletionHandler) DADCheckAddressDisposition - // SetDADConfiguations sets the configurations for DAD. + // SetDADConfigurations sets the configurations for DAD. SetDADConfigurations(c DADConfigurations) // DuplicateAddressProtocol returns the network protocol the receiver can @@ -979,7 +984,7 @@ type DuplicateAddressDetector interface { // LinkAddressResolver handles link address resolution for a network protocol. type LinkAddressResolver interface { // LinkAddressRequest sends a request for the link address of the target - // address. The request is broadcasted on the local network if a remote link + // address. The request is broadcast on the local network if a remote link // address is not provided. LinkAddressRequest(targetAddr, localAddr tcpip.Address, remoteLinkAddr tcpip.LinkAddress) tcpip.Error @@ -1072,4 +1077,4 @@ type GSOEndpoint interface { // SoftwareGSOMaxSize is a maximum allowed size of a software GSO segment. // This isn't a hard limit, because it is never set into packet headers. -const SoftwareGSOMaxSize = (1 << 16) +const SoftwareGSOMaxSize = 1 << 16 diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go index 8814f45a6..81fabe29a 100644 --- a/pkg/tcpip/stack/stack.go +++ b/pkg/tcpip/stack/stack.go @@ -20,17 +20,16 @@ package stack import ( - "bytes" "encoding/binary" "fmt" "io" - mathrand "math/rand" + "math/rand" "sync/atomic" "time" "golang.org/x/time/rate" "gvisor.dev/gvisor/pkg/atomicbitops" - "gvisor.dev/gvisor/pkg/rand" + cryptorand "gvisor.dev/gvisor/pkg/rand" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" @@ -40,13 +39,6 @@ import ( ) const ( - // ageLimit is set to the same cache stale time used in Linux. - ageLimit = 1 * time.Minute - // resolutionTimeout is set to the same ARP timeout used in Linux. - resolutionTimeout = 1 * time.Second - // resolutionAttempts is set to the same ARP retries used in Linux. - resolutionAttempts = 3 - // DefaultTOS is the default type of service value for network endpoints. DefaultTOS = 0 ) @@ -116,7 +108,7 @@ type Stack struct { handleLocal bool // tables are the iptables packet filtering and manipulation rules. - // TODO(gvisor.dev/issue/170): S/R this field. + // TODO(gvisor.dev/issue/4595): S/R this field. tables *IPTables // resumableEndpoints is a list of endpoints that need to be resumed if the @@ -145,7 +137,7 @@ type Stack struct { // randomGenerator is an injectable pseudo random generator that can be // used when a random number is required. - randomGenerator *mathrand.Rand + randomGenerator *rand.Rand // secureRNG is a cryptographically secure random number generator. secureRNG io.Reader @@ -196,9 +188,9 @@ type Options struct { // TransportProtocols lists the transport protocols to enable. TransportProtocols []TransportProtocolFactory - // Clock is an optional clock source used for timestampping packets. + // Clock is an optional clock used for timekeeping. // - // If no Clock is specified, the clock source will be time.Now. + // If Clock is nil, tcpip.NewStdClock() will be used. Clock tcpip.Clock // Stats are optional statistic counters. @@ -225,15 +217,21 @@ type Options struct { // RandSource is an optional source to use to generate random // numbers. If omitted it defaults to a Source seeded by the data - // returned by rand.Read(). + // returned by the stack secure RNG. // // RandSource must be thread-safe. - RandSource mathrand.Source + RandSource rand.Source - // IPTables are the initial iptables rules. If nil, iptables will allow + // IPTables are the initial iptables rules. If nil, DefaultIPTables will be + // used to construct the initial iptables rules. // all traffic. IPTables *IPTables + // DefaultIPTables is an optional iptables rules constructor that is called + // if IPTables is nil. If both fields are nil, iptables will allow all + // traffic. + DefaultIPTables func(uint32) *IPTables + // SecureRNG is a cryptographically secure random number generator. SecureRNG io.Reader } @@ -331,23 +329,32 @@ func New(opts Options) *Stack { opts.UniqueID = new(uniqueIDGenerator) } + if opts.SecureRNG == nil { + opts.SecureRNG = cryptorand.Reader + } + randSrc := opts.RandSource if randSrc == nil { - // Source provided by mathrand.NewSource is not thread-safe so + var v int64 + if err := binary.Read(opts.SecureRNG, binary.LittleEndian, &v); err != nil { + panic(err) + } + // Source provided by rand.NewSource is not thread-safe so // we wrap it in a simple thread-safe version. - randSrc = &lockedRandomSource{src: mathrand.NewSource(generateRandInt64())} + randSrc = &lockedRandomSource{src: rand.NewSource(v)} } + randomGenerator := rand.New(randSrc) + seed := randomGenerator.Uint32() if opts.IPTables == nil { - opts.IPTables = DefaultTables() + if opts.DefaultIPTables == nil { + opts.DefaultIPTables = DefaultTables + } + opts.IPTables = opts.DefaultIPTables(seed) } opts.NUDConfigs.resetInvalidFields() - if opts.SecureRNG == nil { - opts.SecureRNG = rand.Reader - } - s := &Stack{ transportProtocols: make(map[tcpip.TransportProtocolNumber]*transportProtocolState), networkProtocols: make(map[tcpip.NetworkProtocolNumber]NetworkProtocol), @@ -360,11 +367,11 @@ func New(opts Options) *Stack { handleLocal: opts.HandleLocal, tables: opts.IPTables, icmpRateLimiter: NewICMPRateLimiter(), - seed: generateRandUint32(), + seed: seed, nudConfigs: opts.NUDConfigs, uniqueIDGenerator: opts.UniqueID, nudDisp: opts.NUDDisp, - randomGenerator: mathrand.New(randSrc), + randomGenerator: randomGenerator, secureRNG: opts.SecureRNG, sendBufferSize: tcpip.SendBufferSizeOption{ Min: MinBufferSize, @@ -804,7 +811,7 @@ type NICInfo struct { // MTU is the maximum transmission unit. MTU uint32 - Stats NICStats + Stats tcpip.NICStats // NetworkStats holds the stats of each NetworkEndpoint bound to the NIC. NetworkStats map[tcpip.NetworkProtocolNumber]NetworkEndpointStats @@ -856,7 +863,7 @@ func (s *Stack) NICInfo() map[tcpip.NICID]NICInfo { ProtocolAddresses: nic.primaryAddresses(), Flags: flags, MTU: nic.LinkEndpoint.MTU(), - Stats: nic.stats, + Stats: nic.stats.local, NetworkStats: netStats, Context: nic.context, ARPHardwareType: nic.LinkEndpoint.ARPHardwareType(), @@ -1819,7 +1826,7 @@ func (s *Stack) Seed() uint32 { // Rand returns a reference to a pseudo random generator that can be used // to generate random numbers as required. -func (s *Stack) Rand() *mathrand.Rand { +func (s *Stack) Rand() *rand.Rand { return s.randomGenerator } @@ -1829,27 +1836,6 @@ func (s *Stack) SecureRNG() io.Reader { return s.secureRNG } -func generateRandUint32() uint32 { - b := make([]byte, 4) - if _, err := rand.Read(b); err != nil { - panic(err) - } - return binary.LittleEndian.Uint32(b) -} - -func generateRandInt64() int64 { - b := make([]byte, 8) - if _, err := rand.Read(b); err != nil { - panic(err) - } - buf := bytes.NewReader(b) - var v int64 - if err := binary.Read(buf, binary.LittleEndian, &v); err != nil { - panic(err) - } - return v -} - // FindNICNameFromID returns the name of the NIC for the given NICID. func (s *Stack) FindNICNameFromID(id tcpip.NICID) string { s.mu.RLock() @@ -1886,9 +1872,8 @@ const ( // ParsePacketBufferTransport parses the provided packet buffer's transport // header. func (s *Stack) ParsePacketBufferTransport(protocol tcpip.TransportProtocolNumber, pkt *PacketBuffer) ParseResult { - // TODO(gvisor.dev/issue/170): ICMP packets don't have their TransportHeader - // fields set yet, parse it here. See icmp/protocol.go:protocol.Parse for a - // full explanation. + // ICMP packets don't have their TransportHeader fields set yet, parse it + // here. See icmp/protocol.go:protocol.Parse for a full explanation. if protocol == header.ICMPv4ProtocolNumber || protocol == header.ICMPv6ProtocolNumber { return ParsedOK } diff --git a/pkg/tcpip/stack/stack_global_state.go b/pkg/tcpip/stack/stack_global_state.go index 33824afd0..dfec4258a 100644 --- a/pkg/tcpip/stack/stack_global_state.go +++ b/pkg/tcpip/stack/stack_global_state.go @@ -14,78 +14,6 @@ package stack -import "time" - // StackFromEnv is the global stack created in restore run. // FIXME(b/36201077) var StackFromEnv *Stack - -// saveT is invoked by stateify. -func (t *TCPCubicState) saveT() unixTime { - return unixTime{t.T.Unix(), t.T.UnixNano()} -} - -// loadT is invoked by stateify. -func (t *TCPCubicState) loadT(unix unixTime) { - t.T = time.Unix(unix.second, unix.nano) -} - -// saveXmitTime is invoked by stateify. -func (t *TCPRACKState) saveXmitTime() unixTime { - return unixTime{t.XmitTime.Unix(), t.XmitTime.UnixNano()} -} - -// loadXmitTime is invoked by stateify. -func (t *TCPRACKState) loadXmitTime(unix unixTime) { - t.XmitTime = time.Unix(unix.second, unix.nano) -} - -// saveLastSendTime is invoked by stateify. -func (t *TCPSenderState) saveLastSendTime() unixTime { - return unixTime{t.LastSendTime.Unix(), t.LastSendTime.UnixNano()} -} - -// loadLastSendTime is invoked by stateify. -func (t *TCPSenderState) loadLastSendTime(unix unixTime) { - t.LastSendTime = time.Unix(unix.second, unix.nano) -} - -// saveRTTMeasureTime is invoked by stateify. -func (t *TCPSenderState) saveRTTMeasureTime() unixTime { - return unixTime{t.RTTMeasureTime.Unix(), t.RTTMeasureTime.UnixNano()} -} - -// loadRTTMeasureTime is invoked by stateify. -func (t *TCPSenderState) loadRTTMeasureTime(unix unixTime) { - t.RTTMeasureTime = time.Unix(unix.second, unix.nano) -} - -// saveMeasureTime is invoked by stateify. -func (r *RcvBufAutoTuneParams) saveMeasureTime() unixTime { - return unixTime{r.MeasureTime.Unix(), r.MeasureTime.UnixNano()} -} - -// loadMeasureTime is invoked by stateify. -func (r *RcvBufAutoTuneParams) loadMeasureTime(unix unixTime) { - r.MeasureTime = time.Unix(unix.second, unix.nano) -} - -// saveRTTMeasureTime is invoked by stateify. -func (r *RcvBufAutoTuneParams) saveRTTMeasureTime() unixTime { - return unixTime{r.RTTMeasureTime.Unix(), r.RTTMeasureTime.UnixNano()} -} - -// loadRTTMeasureTime is invoked by stateify. -func (r *RcvBufAutoTuneParams) loadRTTMeasureTime(unix unixTime) { - r.RTTMeasureTime = time.Unix(unix.second, unix.nano) -} - -// saveSegTime is invoked by stateify. -func (t *TCPEndpointState) saveSegTime() unixTime { - return unixTime{t.SegTime.Unix(), t.SegTime.UnixNano()} -} - -// loadSegTime is invoked by stateify. -func (t *TCPEndpointState) loadSegTime(unix unixTime) { - t.SegTime = time.Unix(unix.second, unix.nano) -} diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go index 02d54d29b..21951d05a 100644 --- a/pkg/tcpip/stack/stack_test.go +++ b/pkg/tcpip/stack/stack_test.go @@ -166,10 +166,6 @@ func (f *fakeNetworkEndpoint) MaxHeaderLength() uint16 { return f.nic.MaxHeaderLength() + fakeNetHeaderLen } -func (*fakeNetworkEndpoint) PseudoHeaderChecksum(protocol tcpip.TransportProtocolNumber, dstAddr tcpip.Address) uint16 { - return 0 -} - func (f *fakeNetworkEndpoint) NetworkProtocolNumber() tcpip.NetworkProtocolNumber { return f.proto.Number() } @@ -197,11 +193,11 @@ func (f *fakeNetworkEndpoint) WritePacket(r *stack.Route, params stack.NetworkHe } // WritePackets implements stack.LinkEndpoint.WritePackets. -func (*fakeNetworkEndpoint) WritePackets(r *stack.Route, pkts stack.PacketBufferList, params stack.NetworkHeaderParams) (int, tcpip.Error) { +func (*fakeNetworkEndpoint) WritePackets(*stack.Route, stack.PacketBufferList, stack.NetworkHeaderParams) (int, tcpip.Error) { panic("not implemented") } -func (*fakeNetworkEndpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack.PacketBuffer) tcpip.Error { +func (*fakeNetworkEndpoint) WriteHeaderIncludedPacket(*stack.Route, *stack.PacketBuffer) tcpip.Error { return &tcpip.ErrNotSupported{} } @@ -463,14 +459,14 @@ func testSend(t *testing.T, r *stack.Route, ep *channel.Endpoint, payload buffer } } -func testFailingSend(t *testing.T, r *stack.Route, ep *channel.Endpoint, payload buffer.View, wantErr tcpip.Error) { +func testFailingSend(t *testing.T, r *stack.Route, payload buffer.View, wantErr tcpip.Error) { t.Helper() if gotErr := send(r, payload); gotErr != wantErr { t.Errorf("send failed: got = %s, want = %s ", gotErr, wantErr) } } -func testFailingSendTo(t *testing.T, s *stack.Stack, addr tcpip.Address, ep *channel.Endpoint, payload buffer.View, wantErr tcpip.Error) { +func testFailingSendTo(t *testing.T, s *stack.Stack, addr tcpip.Address, payload buffer.View, wantErr tcpip.Error) { t.Helper() if gotErr := sendTo(s, addr, payload); gotErr != wantErr { t.Errorf("sendto failed: got = %s, want = %s ", gotErr, wantErr) @@ -920,15 +916,15 @@ func TestRouteWithDownNIC(t *testing.T) { if err := test.downFn(s, nicID1); err != nil { t.Fatalf("test.downFn(_, %d): %s", nicID1, err) } - testFailingSend(t, r1, ep1, buf, &tcpip.ErrInvalidEndpointState{}) + testFailingSend(t, r1, buf, &tcpip.ErrInvalidEndpointState{}) testSend(t, r2, ep2, buf) // Writes with Routes that use NIC2 after being brought down should fail. if err := test.downFn(s, nicID2); err != nil { t.Fatalf("test.downFn(_, %d): %s", nicID2, err) } - testFailingSend(t, r1, ep1, buf, &tcpip.ErrInvalidEndpointState{}) - testFailingSend(t, r2, ep2, buf, &tcpip.ErrInvalidEndpointState{}) + testFailingSend(t, r1, buf, &tcpip.ErrInvalidEndpointState{}) + testFailingSend(t, r2, buf, &tcpip.ErrInvalidEndpointState{}) if upFn := test.upFn; upFn != nil { // Writes with Routes that use NIC1 after being brought up should @@ -941,7 +937,7 @@ func TestRouteWithDownNIC(t *testing.T) { t.Fatalf("test.upFn(_, %d): %s", nicID1, err) } testSend(t, r1, ep1, buf) - testFailingSend(t, r2, ep2, buf, &tcpip.ErrInvalidEndpointState{}) + testFailingSend(t, r2, buf, &tcpip.ErrInvalidEndpointState{}) } }) } @@ -1066,7 +1062,7 @@ func TestAddressRemoval(t *testing.T) { t.Fatal("RemoveAddress failed:", err) } testFailingRecv(t, fakeNet, localAddrByte, ep, buf) - testFailingSendTo(t, s, remoteAddr, ep, nil, &tcpip.ErrNoRoute{}) + testFailingSendTo(t, s, remoteAddr, nil, &tcpip.ErrNoRoute{}) // Check that removing the same address fails. err := s.RemoveAddress(1, localAddr) @@ -1118,8 +1114,8 @@ func TestAddressRemovalWithRouteHeld(t *testing.T) { t.Fatal("RemoveAddress failed:", err) } testFailingRecv(t, fakeNet, localAddrByte, ep, buf) - testFailingSend(t, r, ep, nil, &tcpip.ErrInvalidEndpointState{}) - testFailingSendTo(t, s, remoteAddr, ep, nil, &tcpip.ErrNoRoute{}) + testFailingSend(t, r, nil, &tcpip.ErrInvalidEndpointState{}) + testFailingSendTo(t, s, remoteAddr, nil, &tcpip.ErrNoRoute{}) // Check that removing the same address fails. { @@ -1140,7 +1136,7 @@ func verifyAddress(t *testing.T, s *stack.Stack, nicID tcpip.NICID, addr tcpip.A // No address given, verify that there is no address assigned to the NIC. for _, a := range info.ProtocolAddresses { if a.Protocol == fakeNetNumber && a.AddressWithPrefix != (tcpip.AddressWithPrefix{}) { - t.Errorf("verify no-address: got = %s, want = %s", a.AddressWithPrefix, (tcpip.AddressWithPrefix{})) + t.Errorf("verify no-address: got = %s, want = %s", a.AddressWithPrefix, tcpip.AddressWithPrefix{}) } } return @@ -1220,7 +1216,7 @@ func TestEndpointExpiration(t *testing.T) { // FIXME(b/139841518):Spoofing doesn't work if there is no primary address. // testSendTo(t, s, remoteAddr, ep, nil) } else { - testFailingSendTo(t, s, remoteAddr, ep, nil, &tcpip.ErrNoRoute{}) + testFailingSendTo(t, s, remoteAddr, nil, &tcpip.ErrNoRoute{}) } // 2. Add Address, everything should work. @@ -1248,7 +1244,7 @@ func TestEndpointExpiration(t *testing.T) { // FIXME(b/139841518):Spoofing doesn't work if there is no primary address. // testSendTo(t, s, remoteAddr, ep, nil) } else { - testFailingSendTo(t, s, remoteAddr, ep, nil, &tcpip.ErrNoRoute{}) + testFailingSendTo(t, s, remoteAddr, nil, &tcpip.ErrNoRoute{}) } // 4. Add Address back, everything should work again. @@ -1287,8 +1283,8 @@ func TestEndpointExpiration(t *testing.T) { testSend(t, r, ep, nil) testSendTo(t, s, remoteAddr, ep, nil) } else { - testFailingSend(t, r, ep, nil, &tcpip.ErrInvalidEndpointState{}) - testFailingSendTo(t, s, remoteAddr, ep, nil, &tcpip.ErrNoRoute{}) + testFailingSend(t, r, nil, &tcpip.ErrInvalidEndpointState{}) + testFailingSendTo(t, s, remoteAddr, nil, &tcpip.ErrNoRoute{}) } // 7. Add Address back, everything should work again. @@ -1324,7 +1320,7 @@ func TestEndpointExpiration(t *testing.T) { // FIXME(b/139841518):Spoofing doesn't work if there is no primary address. // testSendTo(t, s, remoteAddr, ep, nil) } else { - testFailingSendTo(t, s, remoteAddr, ep, nil, &tcpip.ErrNoRoute{}) + testFailingSendTo(t, s, remoteAddr, nil, &tcpip.ErrNoRoute{}) } }) } @@ -1574,7 +1570,7 @@ func TestSpoofingNoAddress(t *testing.T) { t.Errorf("FindRoute succeeded with route %+v when it should have failed", r) } // Sending a packet fails. - testFailingSendTo(t, s, dstAddr, ep, nil, &tcpip.ErrNoRoute{}) + testFailingSendTo(t, s, dstAddr, nil, &tcpip.ErrNoRoute{}) // With address spoofing enabled, FindRoute permits any address to be used // as the source. @@ -1615,7 +1611,7 @@ func TestOutgoingBroadcastWithEmptyRouteTable(t *testing.T) { } } - protoAddr := tcpip.ProtocolAddress{Protocol: fakeNetNumber, AddressWithPrefix: tcpip.AddressWithPrefix{header.IPv4Any, 0}} + protoAddr := tcpip.ProtocolAddress{Protocol: fakeNetNumber, AddressWithPrefix: tcpip.AddressWithPrefix{Address: header.IPv4Any}} if err := s.AddProtocolAddress(1, protoAddr); err != nil { t.Fatalf("AddProtocolAddress(1, %v) failed: %v", protoAddr, err) } @@ -1641,12 +1637,12 @@ func TestOutgoingBroadcastWithEmptyRouteTable(t *testing.T) { } func TestOutgoingBroadcastWithRouteTable(t *testing.T) { - defaultAddr := tcpip.AddressWithPrefix{header.IPv4Any, 0} + defaultAddr := tcpip.AddressWithPrefix{Address: header.IPv4Any} // Local subnet on NIC1: 192.168.1.58/24, gateway 192.168.1.1. - nic1Addr := tcpip.AddressWithPrefix{"\xc0\xa8\x01\x3a", 24} + nic1Addr := tcpip.AddressWithPrefix{Address: "\xc0\xa8\x01\x3a", PrefixLen: 24} nic1Gateway := testutil.MustParse4("192.168.1.1") // Local subnet on NIC2: 10.10.10.5/24, gateway 10.10.10.1. - nic2Addr := tcpip.AddressWithPrefix{"\x0a\x0a\x0a\x05", 24} + nic2Addr := tcpip.AddressWithPrefix{Address: "\x0a\x0a\x0a\x05", PrefixLen: 24} nic2Gateway := testutil.MustParse4("10.10.10.1") // Create a new stack with two NICs. @@ -1660,12 +1656,12 @@ func TestOutgoingBroadcastWithRouteTable(t *testing.T) { if err := s.CreateNIC(2, ep); err != nil { t.Fatalf("CreateNIC failed: %s", err) } - nic1ProtoAddr := tcpip.ProtocolAddress{fakeNetNumber, nic1Addr} + nic1ProtoAddr := tcpip.ProtocolAddress{Protocol: fakeNetNumber, AddressWithPrefix: nic1Addr} if err := s.AddProtocolAddress(1, nic1ProtoAddr); err != nil { t.Fatalf("AddProtocolAddress(1, %v) failed: %v", nic1ProtoAddr, err) } - nic2ProtoAddr := tcpip.ProtocolAddress{fakeNetNumber, nic2Addr} + nic2ProtoAddr := tcpip.ProtocolAddress{Protocol: fakeNetNumber, AddressWithPrefix: nic2Addr} if err := s.AddProtocolAddress(2, nic2ProtoAddr); err != nil { t.Fatalf("AddAddress(2, %v) failed: %v", nic2ProtoAddr, err) } @@ -1709,7 +1705,7 @@ func TestOutgoingBroadcastWithRouteTable(t *testing.T) { // 2. Case: Having an explicit route for broadcast will select that one. rt = append( []tcpip.Route{ - {Destination: tcpip.AddressWithPrefix{header.IPv4Broadcast, 8 * header.IPv4AddressSize}.Subnet(), NIC: 1}, + {Destination: tcpip.AddressWithPrefix{Address: header.IPv4Broadcast, PrefixLen: 8 * header.IPv4AddressSize}.Subnet(), NIC: 1}, }, rt..., ) @@ -2049,7 +2045,7 @@ func TestAddAddress(t *testing.T) { } expectedAddresses = append(expectedAddresses, tcpip.ProtocolAddress{ Protocol: fakeNetNumber, - AddressWithPrefix: tcpip.AddressWithPrefix{address, fakeDefaultPrefixLen}, + AddressWithPrefix: tcpip.AddressWithPrefix{Address: address, PrefixLen: fakeDefaultPrefixLen}, }) } @@ -2113,7 +2109,7 @@ func TestAddAddressWithOptions(t *testing.T) { } expectedAddresses = append(expectedAddresses, tcpip.ProtocolAddress{ Protocol: fakeNetNumber, - AddressWithPrefix: tcpip.AddressWithPrefix{address, fakeDefaultPrefixLen}, + AddressWithPrefix: tcpip.AddressWithPrefix{Address: address, PrefixLen: fakeDefaultPrefixLen}, }) } } @@ -2234,7 +2230,7 @@ func TestCreateNICWithOptions(t *testing.T) { for _, test := range tests { t.Run(test.desc, func(t *testing.T) { s := stack.New(stack.Options{}) - ep := channel.New(0, 0, tcpip.LinkAddress("\x00\x00\x00\x00\x00\x00")) + ep := channel.New(0, 0, "\x00\x00\x00\x00\x00\x00") for _, call := range test.calls { if got, want := s.CreateNICWithOptions(call.nicID, ep, call.opts), call.err; got != want { t.Fatalf("CreateNICWithOptions(%v, _, %+v) = %v, want %v", call.nicID, call.opts, got, want) @@ -2248,46 +2244,87 @@ func TestNICStats(t *testing.T) { s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, }) - ep1 := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(1, ep1); err != nil { - t.Fatal("CreateNIC failed: ", err) - } - if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil { - t.Fatal("AddAddress failed:", err) + + nics := []struct { + addr tcpip.Address + txByteCount int + rxByteCount int + }{ + { + addr: "\x01", + txByteCount: 30, + rxByteCount: 10, + }, + { + addr: "\x02", + txByteCount: 50, + rxByteCount: 20, + }, } - // Route all packets for address \x01 to NIC 1. - { - subnet, err := tcpip.NewSubnet("\x01", "\xff") - if err != nil { - t.Fatal(err) + + var txBytesTotal, rxBytesTotal, txPacketsTotal, rxPacketsTotal int + for i, nic := range nics { + nicid := tcpip.NICID(i) + ep := channel.New(1, defaultMTU, "") + if err := s.CreateNIC(nicid, ep); err != nil { + t.Fatal("CreateNIC failed: ", err) + } + if err := s.AddAddress(nicid, fakeNetNumber, nic.addr); err != nil { + t.Fatal("AddAddress failed:", err) } - s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: 1}}) - } - // Send a packet to address 1. - buf := buffer.NewView(30) - ep1.InjectInbound(fakeNetNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: buf.ToVectorisedView(), - })) - if got, want := s.NICInfo()[1].Stats.Rx.Packets.Value(), uint64(1); got != want { - t.Errorf("got Rx.Packets.Value() = %d, want = %d", got, want) - } + { + subnet, err := tcpip.NewSubnet(nic.addr, "\xff") + if err != nil { + t.Fatal(err) + } + s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: nicid}}) + } + + nicStats := s.NICInfo()[nicid].Stats - if got, want := s.NICInfo()[1].Stats.Rx.Bytes.Value(), uint64(len(buf)); got != want { - t.Errorf("got Rx.Bytes.Value() = %d, want = %d", got, want) + // Inbound packet. + rxBuffer := buffer.NewView(nic.rxByteCount) + ep.InjectInbound(fakeNetNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: rxBuffer.ToVectorisedView(), + })) + if got, want := nicStats.Rx.Packets.Value(), uint64(1); got != want { + t.Errorf("got Rx.Packets.Value() = %d, want = %d", got, want) + } + if got, want := nicStats.Rx.Bytes.Value(), uint64(nic.rxByteCount); got != want { + t.Errorf("got Rx.Bytes.Value() = %d, want = %d", got, want) + } + rxPacketsTotal++ + rxBytesTotal += nic.rxByteCount + + // Outbound packet. + txBuffer := buffer.NewView(nic.txByteCount) + actualTxLength := nic.txByteCount + fakeNetHeaderLen + if err := sendTo(s, nic.addr, txBuffer); err != nil { + t.Fatal("sendTo failed: ", err) + } + want := ep.Drain() + if got := nicStats.Tx.Packets.Value(); got != uint64(want) { + t.Errorf("got Tx.Packets.Value() = %d, ep.Drain() = %d", got, want) + } + if got, want := nicStats.Tx.Bytes.Value(), uint64(actualTxLength); got != want { + t.Errorf("got Tx.Bytes.Value() = %d, want = %d", got, want) + } + txPacketsTotal += want + txBytesTotal += actualTxLength } - payload := buffer.NewView(10) - // Write a packet out via the address for NIC 1 - if err := sendTo(s, "\x01", payload); err != nil { - t.Fatal("sendTo failed: ", err) + // Now verify that each NIC stats was correctly aggregated at the stack level. + if got, want := s.Stats().NICs.Rx.Packets.Value(), uint64(rxPacketsTotal); got != want { + t.Errorf("got s.Stats().NIC.Rx.Packets.Value() = %d, want = %d", got, want) } - want := uint64(ep1.Drain()) - if got := s.NICInfo()[1].Stats.Tx.Packets.Value(); got != want { - t.Errorf("got Tx.Packets.Value() = %d, ep1.Drain() = %d", got, want) + if got, want := s.Stats().NICs.Rx.Bytes.Value(), uint64(rxBytesTotal); got != want { + t.Errorf("got s.Stats().Rx.Bytes.Value() = %d, want = %d", got, want) } - - if got, want := s.NICInfo()[1].Stats.Tx.Bytes.Value(), uint64(len(payload)+fakeNetHeaderLen); got != want { + if got, want := s.Stats().NICs.Tx.Packets.Value(), uint64(txPacketsTotal); got != want { + t.Errorf("got Tx.Packets.Value() = %d, ep.Drain() = %d", got, want) + } + if got, want := s.Stats().NICs.Tx.Bytes.Value(), uint64(txBytesTotal); got != want { t.Errorf("got Tx.Bytes.Value() = %d, want = %d", got, want) } } @@ -2316,7 +2353,7 @@ func TestNICContextPreservation(t *testing.T) { t.Run(test.name, func(t *testing.T) { s := stack.New(stack.Options{}) id := tcpip.NICID(1) - ep := channel.New(0, 0, tcpip.LinkAddress("\x00\x00\x00\x00\x00\x00")) + ep := channel.New(0, 0, "\x00\x00\x00\x00\x00\x00") if err := s.CreateNICWithOptions(id, ep, test.opts); err != nil { t.Fatalf("got stack.CreateNICWithOptions(%d, %+v, %+v) = %s, want nil", id, ep, test.opts, err) } @@ -2603,15 +2640,17 @@ func TestNICAutoGenAddrDoesDAD(t *testing.T) { const nicID = 1 ndpDisp := ndpDispatcher{ - dadC: make(chan ndpDADEvent), + dadC: make(chan ndpDADEvent, 1), } dadConfigs := stack.DefaultDADConfigurations() + clock := faketime.NewManualClock() opts := stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ AutoGenLinkLocal: true, NDPDisp: &ndpDisp, DADConfigs: dadConfigs, })}, + Clock: clock, } e := channel.New(int(dadConfigs.DupAddrDetectTransmits), 1280, linkAddr1) @@ -2629,17 +2668,18 @@ func TestNICAutoGenAddrDoesDAD(t *testing.T) { linkLocalAddr := header.LinkLocalAddr(linkAddr1) // Wait for DAD to resolve. + clock.Advance(time.Duration(dadConfigs.DupAddrDetectTransmits) * dadConfigs.RetransmitTimer) select { - case <-time.After(time.Duration(dadConfigs.DupAddrDetectTransmits)*dadConfigs.RetransmitTimer + time.Second): + case e := <-ndpDisp.dadC: + if diff := checkDADEvent(e, nicID, linkLocalAddr, &stack.DADSucceeded{}); diff != "" { + t.Errorf("dad event mismatch (-want +got):\n%s", diff) + } + default: // We should get a resolution event after 1s (default time to // resolve as per default NDP configurations). Waiting for that // resolution time + an extra 1s without a resolution event // means something is wrong. t.Fatal("timed out waiting for DAD resolution") - case e := <-ndpDisp.dadC: - if diff := checkDADEvent(e, nicID, linkLocalAddr, &stack.DADSucceeded{}); diff != "" { - t.Errorf("dad event mismatch (-want +got):\n%s", diff) - } } if err := checkGetMainNICAddress(s, nicID, header.IPv6ProtocolNumber, tcpip.AddressWithPrefix{Address: linkLocalAddr, PrefixLen: header.IPv6LinkLocalPrefix.PrefixLen}); err != nil { t.Fatal(err) @@ -3270,8 +3310,9 @@ func TestDoDADWhenNICEnabled(t *testing.T) { const nicID = 1 ndpDisp := ndpDispatcher{ - dadC: make(chan ndpDADEvent), + dadC: make(chan ndpDADEvent, 1), } + clock := faketime.NewManualClock() opts := stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ DADConfigs: stack.DADConfigurations{ @@ -3280,6 +3321,7 @@ func TestDoDADWhenNICEnabled(t *testing.T) { }, NDPDisp: &ndpDisp, })}, + Clock: clock, } e := channel.New(dadTransmits, 1280, linkAddr1) @@ -3324,13 +3366,14 @@ func TestDoDADWhenNICEnabled(t *testing.T) { } // Wait for DAD to resolve. + clock.Advance(dadTransmits * retransmitTimer) select { - case <-time.After(dadTransmits*retransmitTimer + defaultAsyncPositiveEventTimeout): - t.Fatal("timed out waiting for DAD resolution") case e := <-ndpDisp.dadC: if diff := checkDADEvent(e, nicID, addr.AddressWithPrefix.Address, &stack.DADSucceeded{}); diff != "" { t.Errorf("dad event mismatch (-want +got):\n%s", diff) } + default: + t.Fatal("timed out waiting for DAD resolution") } if addrs := s.AllAddresses()[nicID]; !containsV6Addr(addrs, addr.AddressWithPrefix) { t.Fatalf("got s.AllAddresses()[%d] = %+v, want = %+v", nicID, addrs, addr) @@ -3837,8 +3880,6 @@ func TestGetMainNICAddressWhenNICDisabled(t *testing.T) { // TestAddRoute tests Stack.AddRoute func TestAddRoute(t *testing.T) { - const nicID = 1 - s := stack.New(stack.Options{}) subnet1, err := tcpip.NewSubnet("\x00", "\x00") @@ -3875,8 +3916,6 @@ func TestAddRoute(t *testing.T) { // TestRemoveRoutes tests Stack.RemoveRoutes func TestRemoveRoutes(t *testing.T) { - const nicID = 1 - s := stack.New(stack.Options{}) addressToRemove := tcpip.Address("\x01") @@ -4223,7 +4262,7 @@ func TestFindRouteWithForwarding(t *testing.T) { s.SetRouteTable([]tcpip.Route{{Destination: test.netCfg.remoteAddr.WithPrefix().Subnet(), NIC: nicID2}}) r, err := s.FindRoute(test.addrNIC, test.localAddr, test.netCfg.remoteAddr, test.netCfg.proto, false /* multicastLoop */) - if r != nil { + if err == nil { defer r.Release() } if diff := cmp.Diff(test.findRouteErr, err); diff != "" { @@ -4394,7 +4433,7 @@ func TestClearNeighborCacheOnNICDisable(t *testing.T) { if neighbors, err := s.Neighbors(nicID, addr.proto); err != nil { t.Fatalf("s.Neighbors(%d, %d): %s", nicID, addr.proto, err) } else if diff := cmp.Diff( - []stack.NeighborEntry{{Addr: addr.addr, LinkAddr: linkAddr, State: stack.Static, UpdatedAtNanos: clock.NowNanoseconds()}}, + []stack.NeighborEntry{{Addr: addr.addr, LinkAddr: linkAddr, State: stack.Static, UpdatedAt: clock.Now()}}, neighbors, ); diff != "" { t.Fatalf("proto=%d neighbors mismatch (-want +got):\n%s", addr.proto, diff) diff --git a/pkg/tcpip/stack/tcp.go b/pkg/tcpip/stack/tcp.go index ddff6e2d6..90a8ba6cf 100644 --- a/pkg/tcpip/stack/tcp.go +++ b/pkg/tcpip/stack/tcp.go @@ -39,7 +39,7 @@ type TCPCubicState struct { WMax float64 // T is the time when the current congestion avoidance was entered. - T time.Time `state:".(unixTime)"` + T tcpip.MonotonicTime // TimeSinceLastCongestion denotes the time since the current // congestion avoidance was entered. @@ -78,7 +78,7 @@ type TCPCubicState struct { type TCPRACKState struct { // XmitTime is the transmission timestamp of the most recent // acknowledged segment. - XmitTime time.Time `state:".(unixTime)"` + XmitTime tcpip.MonotonicTime // EndSequence is the ending TCP sequence number of the most recent // acknowledged segment. @@ -216,7 +216,7 @@ type TCPRTTState struct { // +stateify savable type TCPSenderState struct { // LastSendTime is the timestamp at which we sent the last segment. - LastSendTime time.Time `state:".(unixTime)"` + LastSendTime tcpip.MonotonicTime // DupAckCount is the number of Duplicate ACKs received. It is used for // fast retransmit. @@ -256,7 +256,7 @@ type TCPSenderState struct { RTTMeasureSeqNum seqnum.Value // RTTMeasureTime is the time when the RTTMeasureSeqNum was sent. - RTTMeasureTime time.Time `state:".(unixTime)"` + RTTMeasureTime tcpip.MonotonicTime // Closed indicates that the caller has closed the endpoint for // sending. @@ -313,7 +313,7 @@ type TCPSACKInfo struct { type RcvBufAutoTuneParams struct { // MeasureTime is the time at which the current measurement was // started. - MeasureTime time.Time `state:".(unixTime)"` + MeasureTime tcpip.MonotonicTime // CopiedBytes is the number of bytes copied to user space since this // measure began. @@ -341,7 +341,7 @@ type RcvBufAutoTuneParams struct { // RTTMeasureTime is the absolute time at which the current RTT // measurement period began. - RTTMeasureTime time.Time `state:".(unixTime)"` + RTTMeasureTime tcpip.MonotonicTime // Disabled is true if an explicit receive buffer is set for the // endpoint. @@ -380,9 +380,6 @@ type TCPSndBufState struct { // SndClosed indicates that the endpoint has been closed for sends. SndClosed bool - // SndBufInQueue is the number of bytes in the send queue. - SndBufInQueue seqnum.Size - // PacketTooBigCount is used to notify the main protocol routine how // many times a "packet too big" control packet is received. PacketTooBigCount int @@ -429,7 +426,7 @@ type TCPEndpointState struct { ID TCPEndpointID // SegTime denotes the absolute time when this segment was received. - SegTime time.Time `state:".(unixTime)"` + SegTime tcpip.MonotonicTime // RcvBufState contains information about the state of the endpoint's // receive socket buffer. diff --git a/pkg/tcpip/stack/transport_demuxer.go b/pkg/tcpip/stack/transport_demuxer.go index 80ad1a9d4..dda57e225 100644 --- a/pkg/tcpip/stack/transport_demuxer.go +++ b/pkg/tcpip/stack/transport_demuxer.go @@ -16,7 +16,6 @@ package stack import ( "fmt" - "math/rand" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" @@ -217,13 +216,20 @@ func (epsByNIC *endpointsByNIC) registerEndpoint(d *transportDemuxer, netProto t netProto: netProto, transProto: transProto, } - epsByNIC.endpoints[bindToDevice] = multiPortEp } - return multiPortEp.singleRegisterEndpoint(t, flags) + if err := multiPortEp.singleRegisterEndpoint(t, flags); err != nil { + return err + } + // Only add this newly created multiportEndpoint if the singleRegisterEndpoint + // succeeded. + if !ok { + epsByNIC.endpoints[bindToDevice] = multiPortEp + } + return nil } -func (epsByNIC *endpointsByNIC) checkEndpoint(d *transportDemuxer, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, flags ports.Flags, bindToDevice tcpip.NICID) tcpip.Error { +func (epsByNIC *endpointsByNIC) checkEndpoint(flags ports.Flags, bindToDevice tcpip.NICID) tcpip.Error { epsByNIC.mu.RLock() defer epsByNIC.mu.RUnlock() @@ -407,7 +413,6 @@ func (ep *multiPortEndpoint) handlePacketAll(id TransportEndpointID, pkt *Packet func (ep *multiPortEndpoint) singleRegisterEndpoint(t TransportEndpoint, flags ports.Flags) tcpip.Error { ep.mu.Lock() defer ep.mu.Unlock() - bits := flags.Bits() & ports.MultiBindFlagMask if len(ep.endpoints) != 0 { @@ -470,17 +475,21 @@ func (d *transportDemuxer) singleRegisterEndpoint(netProto tcpip.NetworkProtocol eps.mu.Lock() defer eps.mu.Unlock() - epsByNIC, ok := eps.endpoints[id] if !ok { epsByNIC = &endpointsByNIC{ endpoints: make(map[tcpip.NICID]*multiPortEndpoint), - seed: rand.Uint32(), + seed: d.stack.Seed(), } + } + if err := epsByNIC.registerEndpoint(d, netProto, protocol, ep, flags, bindToDevice); err != nil { + return err + } + // Only add this newly created epsByNIC if registerEndpoint succeeded. + if !ok { eps.endpoints[id] = epsByNIC } - - return epsByNIC.registerEndpoint(d, netProto, protocol, ep, flags, bindToDevice) + return nil } func (d *transportDemuxer) singleCheckEndpoint(netProto tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, flags ports.Flags, bindToDevice tcpip.NICID) tcpip.Error { @@ -502,7 +511,7 @@ func (d *transportDemuxer) singleCheckEndpoint(netProto tcpip.NetworkProtocolNum return nil } - return epsByNIC.checkEndpoint(d, netProto, protocol, flags, bindToDevice) + return epsByNIC.checkEndpoint(flags, bindToDevice) } // unregisterEndpoint unregisters the endpoint with the given id such that it diff --git a/pkg/tcpip/stack/transport_demuxer_test.go b/pkg/tcpip/stack/transport_demuxer_test.go index 4848495c9..45b09110d 100644 --- a/pkg/tcpip/stack/transport_demuxer_test.go +++ b/pkg/tcpip/stack/transport_demuxer_test.go @@ -18,6 +18,7 @@ import ( "io/ioutil" "math" "math/rand" + "strconv" "testing" "gvisor.dev/gvisor/pkg/tcpip" @@ -84,7 +85,8 @@ func newDualTestContextMultiNIC(t *testing.T, mtu uint32, linkEpIDs []tcpip.NICI } type headers struct { - srcPort, dstPort uint16 + srcPort uint16 + dstPort uint16 } func newPayload() []byte { @@ -201,6 +203,56 @@ func TestTransportDemuxerRegister(t *testing.T) { } } +func TestTransportDemuxerRegisterMultiple(t *testing.T) { + type test struct { + flags ports.Flags + want tcpip.Error + } + for _, subtest := range []struct { + name string + tests []test + }{ + {"zeroFlags", []test{ + {ports.Flags{}, nil}, + {ports.Flags{}, &tcpip.ErrPortInUse{}}, + }}, + {"multibindFlags", []test{ + // Allow multiple registrations same TransportEndpointID with multibind flags. + {ports.Flags{LoadBalanced: true, MostRecent: true}, nil}, + {ports.Flags{LoadBalanced: true, MostRecent: true}, nil}, + // Disallow registration w/same ID for a non-multibindflag. + {ports.Flags{TupleOnly: true}, &tcpip.ErrPortInUse{}}, + }}, + } { + t.Run(subtest.name, func(t *testing.T) { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}, + }) + var eps []tcpip.Endpoint + for idx, test := range subtest.tests { + var wq waiter.Queue + ep, err := s.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, &wq) + if err != nil { + t.Fatal(err) + } + eps = append(eps, ep) + tEP, ok := ep.(stack.TransportEndpoint) + if !ok { + t.Fatalf("%T does not implement stack.TransportEndpoint", ep) + } + id := stack.TransportEndpointID{LocalPort: 1} + if got, want := s.RegisterTransportEndpoint([]tcpip.NetworkProtocolNumber{ipv4.ProtocolNumber}, udp.ProtocolNumber, id, tEP, test.flags, 0), test.want; got != want { + t.Fatalf("test index: %d, s.RegisterTransportEndpoint(ipv4.ProtocolNumber, udp.ProtocolNumber, _, _, %+v, 0) = %s, want %s", idx, test.flags, got, want) + } + } + for _, ep := range eps { + ep.Close() + } + }) + } +} + // TestBindToDeviceDistribution injects varied packets on input devices and checks that // the distribution of packets received matches expectations. func TestBindToDeviceDistribution(t *testing.T) { @@ -208,7 +260,7 @@ func TestBindToDeviceDistribution(t *testing.T) { reuse bool bindToDevice tcpip.NICID } - for _, test := range []struct { + tcs := []struct { name string // endpoints will received the inject packets. endpoints []endpointSockopts @@ -217,29 +269,29 @@ func TestBindToDeviceDistribution(t *testing.T) { wantDistributions map[tcpip.NICID][]float64 }{ { - "BindPortReuse", + name: "BindPortReuse", // 5 endpoints that all have reuse set. - []endpointSockopts{ + endpoints: []endpointSockopts{ {reuse: true, bindToDevice: 0}, {reuse: true, bindToDevice: 0}, {reuse: true, bindToDevice: 0}, {reuse: true, bindToDevice: 0}, {reuse: true, bindToDevice: 0}, }, - map[tcpip.NICID][]float64{ + wantDistributions: map[tcpip.NICID][]float64{ // Injected packets on dev0 get distributed evenly. 1: {0.2, 0.2, 0.2, 0.2, 0.2}, }, }, { - "BindToDevice", + name: "BindToDevice", // 3 endpoints with various bindings. - []endpointSockopts{ + endpoints: []endpointSockopts{ {reuse: false, bindToDevice: 1}, {reuse: false, bindToDevice: 2}, {reuse: false, bindToDevice: 3}, }, - map[tcpip.NICID][]float64{ + wantDistributions: map[tcpip.NICID][]float64{ // Injected packets on dev0 go only to the endpoint bound to dev0. 1: {1, 0, 0}, // Injected packets on dev1 go only to the endpoint bound to dev1. @@ -249,9 +301,9 @@ func TestBindToDeviceDistribution(t *testing.T) { }, }, { - "ReuseAndBindToDevice", + name: "ReuseAndBindToDevice", // 6 endpoints with various bindings. - []endpointSockopts{ + endpoints: []endpointSockopts{ {reuse: true, bindToDevice: 1}, {reuse: true, bindToDevice: 1}, {reuse: true, bindToDevice: 2}, @@ -259,7 +311,7 @@ func TestBindToDeviceDistribution(t *testing.T) { {reuse: true, bindToDevice: 2}, {reuse: true, bindToDevice: 0}, }, - map[tcpip.NICID][]float64{ + wantDistributions: map[tcpip.NICID][]float64{ // Injected packets on dev0 get distributed among endpoints bound to // dev0. 1: {0.5, 0.5, 0, 0, 0, 0}, @@ -270,35 +322,42 @@ func TestBindToDeviceDistribution(t *testing.T) { 1000: {0, 0, 0, 0, 0, 1}, }, }, - } { - for protoName, netProtoNum := range map[string]tcpip.NetworkProtocolNumber{ - "IPv4": ipv4.ProtocolNumber, - "IPv6": ipv6.ProtocolNumber, - } { + } + protos := map[string]tcpip.NetworkProtocolNumber{ + "IPv4": ipv4.ProtocolNumber, + "IPv6": ipv6.ProtocolNumber, + } + + for _, test := range tcs { + for protoName, protoNum := range protos { for device, wantDistribution := range test.wantDistributions { - t.Run(test.name+protoName+string(device), func(t *testing.T) { + t.Run(test.name+protoName+"-"+strconv.Itoa(int(device)), func(t *testing.T) { + // Create the NICs. var devices []tcpip.NICID for d := range test.wantDistributions { devices = append(devices, d) } c := newDualTestContextMultiNIC(t, defaultMTU, devices) + // Create endpoints and bind each to a NIC, sometimes reusing ports. eps := make(map[tcpip.Endpoint]int) - pollChannel := make(chan tcpip.Endpoint) for i, endpoint := range test.endpoints { // Try to receive the data. wq := waiter.Queue{} we, ch := waiter.NewChannelEntry(nil) wq.EventRegister(&we, waiter.ReadableEvents) - defer wq.EventUnregister(&we) - defer close(ch) + t.Cleanup(func() { + wq.EventUnregister(&we) + close(ch) + }) var err tcpip.Error - ep, err := c.s.NewEndpoint(udp.ProtocolNumber, netProtoNum, &wq) + ep, err := c.s.NewEndpoint(udp.ProtocolNumber, protoNum, &wq) if err != nil { t.Fatalf("NewEndpoint failed: %s", err) } + t.Cleanup(ep.Close) eps[ep] = i go func(ep tcpip.Endpoint) { @@ -307,32 +366,34 @@ func TestBindToDeviceDistribution(t *testing.T) { } }(ep) - defer ep.Close() ep.SocketOptions().SetReusePort(endpoint.reuse) if err := ep.SocketOptions().SetBindToDevice(int32(endpoint.bindToDevice)); err != nil { t.Fatalf("SetSockOpt(&%T(%d)) on endpoint %d failed: %s", endpoint.bindToDevice, endpoint.bindToDevice, i, err) } var dstAddr tcpip.Address - switch netProtoNum { + switch protoNum { case ipv4.ProtocolNumber: dstAddr = testDstAddrV4 case ipv6.ProtocolNumber: dstAddr = testDstAddrV6 default: - t.Fatalf("unexpected protocol number: %d", netProtoNum) + t.Fatalf("unexpected protocol number: %d", protoNum) } if err := ep.Bind(tcpip.FullAddress{Addr: dstAddr, Port: testDstPort}); err != nil { t.Fatalf("ep.Bind(...) on endpoint %d failed: %s", i, err) } } - npackets := 100000 - nports := 10000 + // Send packets across a range of ports, checking that packets from + // the same source port are always demultiplexed to the same + // destination endpoint. + npackets := 10_000 + nports := 1_000 if got, want := len(test.endpoints), len(wantDistribution); got != want { t.Fatalf("got len(test.endpoints) = %d, want %d", got, want) } - ports := make(map[uint16]tcpip.Endpoint) + endpoints := make(map[uint16]tcpip.Endpoint) stats := make(map[tcpip.Endpoint]int) for i := 0; i < npackets; i++ { // Send a packet. @@ -342,13 +403,13 @@ func TestBindToDeviceDistribution(t *testing.T) { srcPort: testSrcPort + port, dstPort: testDstPort, } - switch netProtoNum { + switch protoNum { case ipv4.ProtocolNumber: c.sendV4Packet(payload, hdrs, device) case ipv6.ProtocolNumber: c.sendV6Packet(payload, hdrs, device) default: - t.Fatalf("unexpected protocol number: %d", netProtoNum) + t.Fatalf("unexpected protocol number: %d", protoNum) } ep := <-pollChannel @@ -357,11 +418,11 @@ func TestBindToDeviceDistribution(t *testing.T) { } stats[ep]++ if i < nports { - ports[uint16(i)] = ep + endpoints[uint16(i)] = ep } else { // Check that all packets from one client are handled by the same // socket. - if want, got := ports[port], ep; want != got { + if want, got := endpoints[port], ep; want != got { t.Fatalf("Packet sent on port %d expected on endpoint %d but received on endpoint %d", port, eps[want], eps[got]) } } |