diff options
Diffstat (limited to 'pkg/tcpip/stack')
-rw-r--r-- | pkg/tcpip/stack/conntrack.go | 4 | ||||
-rw-r--r-- | pkg/tcpip/stack/iptables_targets.go | 2 | ||||
-rw-r--r-- | pkg/tcpip/stack/ndp_test.go | 123 | ||||
-rw-r--r-- | pkg/tcpip/stack/neighbor_cache_test.go | 1064 | ||||
-rw-r--r-- | pkg/tcpip/stack/neighbor_entry_test.go | 19 | ||||
-rw-r--r-- | pkg/tcpip/stack/nic.go | 6 | ||||
-rw-r--r-- | pkg/tcpip/stack/packet_buffer.go | 248 | ||||
-rw-r--r-- | pkg/tcpip/stack/packet_buffer_test.go | 224 | ||||
-rw-r--r-- | pkg/tcpip/stack/registration.go | 44 | ||||
-rw-r--r-- | pkg/tcpip/stack/stack.go | 12 | ||||
-rw-r--r-- | pkg/tcpip/stack/stack_test.go | 10 | ||||
-rw-r--r-- | pkg/tcpip/stack/transport_demuxer.go | 6 |
12 files changed, 1066 insertions, 696 deletions
diff --git a/pkg/tcpip/stack/conntrack.go b/pkg/tcpip/stack/conntrack.go index cdb435644..3f083928f 100644 --- a/pkg/tcpip/stack/conntrack.go +++ b/pkg/tcpip/stack/conntrack.go @@ -407,12 +407,12 @@ func handlePacketOutput(pkt *PacketBuffer, conn *conn, gso *GSO, r *Route, dir d // Calculate the TCP checksum and set it. tcpHeader.SetChecksum(0) - length := uint16(len(tcpHeader) + pkt.Data.Size()) + length := uint16(len(tcpHeader) + pkt.Data().Size()) xsum := header.PseudoHeaderChecksum(header.TCPProtocolNumber, netHeader.SourceAddress(), netHeader.DestinationAddress(), length) if gso != nil && gso.NeedsCsum { tcpHeader.SetChecksum(xsum) } else if r.RequiresTXTransportChecksum() { - xsum = header.ChecksumVV(pkt.Data, xsum) + xsum = header.ChecksumCombine(xsum, pkt.Data().AsRange().Checksum()) tcpHeader.SetChecksum(^tcpHeader.CalculateChecksum(xsum)) } diff --git a/pkg/tcpip/stack/iptables_targets.go b/pkg/tcpip/stack/iptables_targets.go index d63e9757c..0e8b90c9b 100644 --- a/pkg/tcpip/stack/iptables_targets.go +++ b/pkg/tcpip/stack/iptables_targets.go @@ -153,7 +153,7 @@ func (rt *RedirectTarget) Action(pkt *PacketBuffer, ct *ConnTrack, hook Hook, gs if r.RequiresTXTransportChecksum() { length := uint16(pkt.Size()) - uint16(len(pkt.NetworkHeader().View())) xsum := header.PseudoHeaderChecksum(protocol, netHeader.SourceAddress(), netHeader.DestinationAddress(), length) - xsum = header.ChecksumVV(pkt.Data, xsum) + xsum = header.ChecksumCombine(xsum, pkt.Data().AsRange().Checksum()) udpHeader.SetChecksum(^udpHeader.CalculateChecksum(xsum)) } } diff --git a/pkg/tcpip/stack/ndp_test.go b/pkg/tcpip/stack/ndp_test.go index 740bdac28..47796a6ba 100644 --- a/pkg/tcpip/stack/ndp_test.go +++ b/pkg/tcpip/stack/ndp_test.go @@ -99,12 +99,11 @@ func prefixSubnetAddr(offset uint8, linkAddr tcpip.LinkAddress) (tcpip.AddressWi } // ndpDADEvent is a set of parameters that was passed to -// ndpDispatcher.OnDuplicateAddressDetectionStatus. +// ndpDispatcher.OnDuplicateAddressDetectionResult. type ndpDADEvent struct { - nicID tcpip.NICID - addr tcpip.Address - resolved bool - err tcpip.Error + nicID tcpip.NICID + addr tcpip.Address + res stack.DADResult } type ndpRouterEvent struct { @@ -173,14 +172,13 @@ type ndpDispatcher struct { dhcpv6ConfigurationC chan ndpDHCPv6Event } -// Implements ipv6.NDPDispatcher.OnDuplicateAddressDetectionStatus. -func (n *ndpDispatcher) OnDuplicateAddressDetectionStatus(nicID tcpip.NICID, addr tcpip.Address, resolved bool, err tcpip.Error) { +// Implements ipv6.NDPDispatcher.OnDuplicateAddressDetectionResult. +func (n *ndpDispatcher) OnDuplicateAddressDetectionResult(nicID tcpip.NICID, addr tcpip.Address, res stack.DADResult) { if n.dadC != nil { n.dadC <- ndpDADEvent{ nicID, addr, - resolved, - err, + res, } } } @@ -311,8 +309,8 @@ func (l *channelLinkWithHeaderLength) MaxHeaderLength() uint16 { // Check e to make sure that the event is for addr on nic with ID 1, and the // resolved flag set to resolved with the specified err. -func checkDADEvent(e ndpDADEvent, nicID tcpip.NICID, addr tcpip.Address, resolved bool, err tcpip.Error) string { - return cmp.Diff(ndpDADEvent{nicID: nicID, addr: addr, resolved: resolved, err: err}, e, cmp.AllowUnexported(e)) +func checkDADEvent(e ndpDADEvent, nicID tcpip.NICID, addr tcpip.Address, res stack.DADResult) string { + return cmp.Diff(ndpDADEvent{nicID: nicID, addr: addr, res: res}, e, cmp.AllowUnexported(e)) } // TestDADDisabled tests that an address successfully resolves immediately @@ -344,8 +342,8 @@ func TestDADDisabled(t *testing.T) { // DAD on it. select { case e := <-ndpDisp.dadC: - if diff := checkDADEvent(e, nicID, addr1, true, nil); diff != "" { - t.Errorf("dad event mismatch (-want +got):\n%s", diff) + if diff := checkDADEvent(e, nicID, addr1, &stack.DADSucceeded{}); diff != "" { + t.Errorf("DAD event mismatch (-want +got):\n%s", diff) } default: t.Fatal("expected DAD event") @@ -491,8 +489,8 @@ func TestDADResolve(t *testing.T) { case <-time.After(defaultAsyncPositiveEventTimeout): t.Fatal("timed out waiting for DAD resolution") case e := <-ndpDisp.dadC: - if diff := checkDADEvent(e, nicID, addr1, true, nil); diff != "" { - t.Errorf("dad event mismatch (-want +got):\n%s", diff) + if diff := checkDADEvent(e, nicID, addr1, &stack.DADSucceeded{}); diff != "" { + t.Errorf("DAD event mismatch (-want +got):\n%s", diff) } } if err := checkGetMainNICAddress(s, nicID, header.IPv6ProtocolNumber, addrWithPrefix); err != nil { @@ -573,7 +571,11 @@ func rxNDPSolicit(e *channel.Endpoint, tgt tcpip.Address) { ns := header.NDPNeighborSolicit(pkt.MessageBody()) ns.SetTargetAddress(tgt) snmc := header.SolicitedNodeAddr(tgt) - pkt.SetChecksum(header.ICMPv6Checksum(pkt, header.IPv6Any, snmc, buffer.VectorisedView{})) + pkt.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{ + Header: pkt, + Src: header.IPv6Any, + Dst: snmc, + })) payloadLength := hdr.UsedLength() ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) ip.Encode(&header.IPv6Fields{ @@ -594,9 +596,10 @@ func TestDADFail(t *testing.T) { const nicID = 1 tests := []struct { - name string - rxPkt func(e *channel.Endpoint, tgt tcpip.Address) - getStat func(s tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter + name string + rxPkt func(e *channel.Endpoint, tgt tcpip.Address) + getStat func(s tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter + expectedHolderLinkAddress tcpip.LinkAddress }{ { name: "RxSolicit", @@ -604,6 +607,7 @@ func TestDADFail(t *testing.T) { getStat: func(s tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { return s.NeighborSolicit }, + expectedHolderLinkAddress: "", }, { name: "RxAdvert", @@ -619,7 +623,11 @@ func TestDADFail(t *testing.T) { na.Options().Serialize(header.NDPOptionsSerializer{ header.NDPTargetLinkLayerAddressOption(linkAddr1), }) - pkt.SetChecksum(header.ICMPv6Checksum(pkt, tgt, header.IPv6AllNodesMulticastAddress, buffer.VectorisedView{})) + pkt.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{ + Header: pkt, + Src: tgt, + Dst: header.IPv6AllNodesMulticastAddress, + })) payloadLength := hdr.UsedLength() ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) ip.Encode(&header.IPv6Fields{ @@ -634,6 +642,7 @@ func TestDADFail(t *testing.T) { getStat: func(s tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { return s.NeighborAdvert }, + expectedHolderLinkAddress: linkAddr1, }, } @@ -683,8 +692,8 @@ func TestDADFail(t *testing.T) { // something is wrong. t.Fatal("timed out waiting for DAD failure") case e := <-ndpDisp.dadC: - if diff := checkDADEvent(e, nicID, addr1, false, nil); diff != "" { - t.Errorf("dad event mismatch (-want +got):\n%s", diff) + if diff := checkDADEvent(e, nicID, addr1, &stack.DADDupAddrDetected{HolderLinkAddress: test.expectedHolderLinkAddress}); diff != "" { + t.Errorf("DAD event mismatch (-want +got):\n%s", diff) } } if err := checkGetMainNICAddress(s, nicID, header.IPv6ProtocolNumber, tcpip.AddressWithPrefix{}); err != nil { @@ -782,8 +791,8 @@ func TestDADStop(t *testing.T) { // time + extra 1s buffer, something is wrong. t.Fatal("timed out waiting for DAD failure") case e := <-ndpDisp.dadC: - if diff := checkDADEvent(e, nicID, addr1, false, &tcpip.ErrAborted{}); diff != "" { - t.Errorf("dad event mismatch (-want +got):\n%s", diff) + if diff := checkDADEvent(e, nicID, addr1, &stack.DADAborted{}); diff != "" { + t.Errorf("DAD event mismatch (-want +got):\n%s", diff) } } @@ -844,8 +853,8 @@ func TestSetNDPConfigurations(t *testing.T) { expectDADEvent := func(nicID tcpip.NICID, addr tcpip.Address) { select { case e := <-ndpDisp.dadC: - if diff := checkDADEvent(e, nicID, addr, true, nil); diff != "" { - t.Errorf("dad event mismatch (-want +got):\n%s", diff) + if diff := checkDADEvent(e, nicID, addr, &stack.DADSucceeded{}); diff != "" { + t.Errorf("DAD event mismatch (-want +got):\n%s", diff) } default: t.Fatalf("expected DAD event for %s", addr) @@ -936,8 +945,8 @@ func TestSetNDPConfigurations(t *testing.T) { // means something is wrong. t.Fatal("timed out waiting for DAD resolution") case e := <-ndpDisp.dadC: - if diff := checkDADEvent(e, nicID1, addr1, true, nil); diff != "" { - t.Errorf("dad event mismatch (-want +got):\n%s", diff) + if diff := checkDADEvent(e, nicID1, addr1, &stack.DADSucceeded{}); diff != "" { + t.Errorf("DAD event mismatch (-want +got):\n%s", diff) } } if err := checkGetMainNICAddress(s, nicID1, header.IPv6ProtocolNumber, addrWithPrefix1); err != nil { @@ -973,7 +982,11 @@ func raBufWithOptsAndDHCPv6(ip tcpip.Address, rl uint16, managedAddress, otherCo } opts := ra.Options() opts.Serialize(optSer) - pkt.SetChecksum(header.ICMPv6Checksum(pkt, ip, header.IPv6AllNodesMulticastAddress, buffer.VectorisedView{})) + pkt.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{ + Header: pkt, + Src: ip, + Dst: header.IPv6AllNodesMulticastAddress, + })) payloadLength := hdr.UsedLength() iph := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) iph.Encode(&header.IPv6Fields{ @@ -1951,8 +1964,8 @@ func TestAutoGenTempAddr(t *testing.T) { select { case e := <-ndpDisp.dadC: - if diff := checkDADEvent(e, nicID, addr, true, nil); diff != "" { - t.Errorf("dad event mismatch (-want +got):\n%s", diff) + if diff := checkDADEvent(e, nicID, addr, &stack.DADSucceeded{}); diff != "" { + t.Errorf("DAD event mismatch (-want +got):\n%s", diff) } case <-time.After(time.Duration(test.dupAddrTransmits)*test.retransmitTimer + defaultAsyncPositiveEventTimeout): t.Fatal("timed out waiting for DAD event") @@ -2157,8 +2170,8 @@ func TestNoAutoGenTempAddrForLinkLocal(t *testing.T) { } select { case e := <-ndpDisp.dadC: - if diff := checkDADEvent(e, nicID, llAddr1, true, nil); diff != "" { - t.Errorf("dad event mismatch (-want +got):\n%s", diff) + if diff := checkDADEvent(e, nicID, llAddr1, &stack.DADSucceeded{}); diff != "" { + t.Errorf("DAD event mismatch (-want +got):\n%s", diff) } case <-time.After(time.Duration(test.dupAddrTransmits)*test.retransmitTimer + defaultAsyncPositiveEventTimeout): t.Fatal("timed out waiting for DAD event") @@ -2245,8 +2258,8 @@ func TestNoAutoGenTempAddrWithoutStableAddr(t *testing.T) { // address to be generated. select { case e := <-ndpDisp.dadC: - if diff := checkDADEvent(e, nicID, addr.Address, true, nil); diff != "" { - t.Errorf("dad event mismatch (-want +got):\n%s", diff) + if diff := checkDADEvent(e, nicID, addr.Address, &stack.DADSucceeded{}); diff != "" { + t.Errorf("DAD event mismatch (-want +got):\n%s", diff) } case <-time.After(dadTransmits*retransmitTimer + defaultAsyncPositiveEventTimeout): t.Fatal("timed out waiting for DAD event") @@ -2711,8 +2724,8 @@ func TestMixedSLAACAddrConflictRegen(t *testing.T) { t.Helper() clock.Advance(dupAddrTransmits * retransmitTimer) - if diff := checkDADEvent(<-ndpDisp.dadC, nicID, addr, true, nil); diff != "" { - t.Errorf("dad event mismatch (-want +got):\n%s", diff) + if diff := checkDADEvent(<-ndpDisp.dadC, nicID, addr, &stack.DADSucceeded{}); diff != "" { + t.Errorf("DAD event mismatch (-want +got):\n%s", diff) } } @@ -2742,8 +2755,8 @@ func TestMixedSLAACAddrConflictRegen(t *testing.T) { rxNDPSolicit(e, addr.Address) select { case e := <-ndpDisp.dadC: - if diff := checkDADEvent(e, nicID, addr.Address, false, nil); diff != "" { - t.Errorf("dad event mismatch (-want +got):\n%s", diff) + if diff := checkDADEvent(e, nicID, addr.Address, &stack.DADDupAddrDetected{}); diff != "" { + t.Errorf("DAD event mismatch (-want +got):\n%s", diff) } default: t.Fatal("expected DAD event") @@ -3841,26 +3854,26 @@ func TestAutoGenAddrInResponseToDADConflicts(t *testing.T) { } } - expectDADEvent := func(t *testing.T, ndpDisp *ndpDispatcher, addr tcpip.Address, resolved bool, err tcpip.Error) { + expectDADEvent := func(t *testing.T, ndpDisp *ndpDispatcher, addr tcpip.Address, res stack.DADResult) { t.Helper() select { case e := <-ndpDisp.dadC: - if diff := checkDADEvent(e, nicID, addr, resolved, err); diff != "" { - t.Errorf("dad event mismatch (-want +got):\n%s", diff) + if diff := checkDADEvent(e, nicID, addr, res); diff != "" { + t.Errorf("DAD event mismatch (-want +got):\n%s", diff) } default: t.Fatal("expected DAD event") } } - expectDADEventAsync := func(t *testing.T, ndpDisp *ndpDispatcher, addr tcpip.Address, resolved bool) { + expectDADEventAsync := func(t *testing.T, ndpDisp *ndpDispatcher, addr tcpip.Address, res stack.DADResult) { t.Helper() select { case e := <-ndpDisp.dadC: - if diff := checkDADEvent(e, nicID, addr, resolved, nil); diff != "" { - t.Errorf("dad event mismatch (-want +got):\n%s", diff) + if diff := checkDADEvent(e, nicID, addr, res); diff != "" { + t.Errorf("DAD event mismatch (-want +got):\n%s", diff) } case <-time.After(dadTransmits*retransmitTimer + defaultAsyncPositiveEventTimeout): t.Fatal("timed out waiting for DAD event") @@ -3917,7 +3930,7 @@ func TestAutoGenAddrInResponseToDADConflicts(t *testing.T) { // generated. e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, 100, 100)) expectAutoGenAddrEvent(t, ndpDisp, stableAddrForTempAddrTest, newAddr) - expectDADEventAsync(t, ndpDisp, stableAddrForTempAddrTest.Address, true) + expectDADEventAsync(t, ndpDisp, stableAddrForTempAddrTest.Address, &stack.DADSucceeded{}) // The stable address will be assigned throughout the test. return []tcpip.AddressWithPrefix{stableAddrForTempAddrTest} @@ -3992,7 +4005,7 @@ func TestAutoGenAddrInResponseToDADConflicts(t *testing.T) { // Simulate a DAD conflict. rxNDPSolicit(e, addr.Address) expectAutoGenAddrEvent(t, &ndpDisp, addr, invalidatedAddr) - expectDADEvent(t, &ndpDisp, addr.Address, false, nil) + expectDADEvent(t, &ndpDisp, addr.Address, &stack.DADDupAddrDetected{}) // Attempting to add the address manually should not fail if the // address's state was cleaned up when DAD failed. @@ -4002,7 +4015,7 @@ func TestAutoGenAddrInResponseToDADConflicts(t *testing.T) { if err := s.RemoveAddress(nicID, addr.Address); err != nil { t.Fatalf("RemoveAddress(%d, %s) = %s", nicID, addr.Address, err) } - expectDADEvent(t, &ndpDisp, addr.Address, false, &tcpip.ErrAborted{}) + expectDADEvent(t, &ndpDisp, addr.Address, &stack.DADAborted{}) } // Should not have any new addresses assigned to the NIC. @@ -4015,7 +4028,7 @@ func TestAutoGenAddrInResponseToDADConflicts(t *testing.T) { if maxRetries+1 > numFailures { addr := addrType.addrGenFn(numFailures, tempIIDHistory[:]) expectAutoGenAddrEventAsync(t, &ndpDisp, addr, newAddr) - expectDADEventAsync(t, &ndpDisp, addr.Address, true) + expectDADEventAsync(t, &ndpDisp, addr.Address, &stack.DADSucceeded{}) if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, append(stableAddrs, addr), nil); mismatch != "" { t.Fatal(mismatch) } @@ -4132,8 +4145,8 @@ func TestAutoGenAddrWithEUI64IIDNoDADRetries(t *testing.T) { expectAutoGenAddrEvent(addr, invalidatedAddr) select { case e := <-ndpDisp.dadC: - if diff := checkDADEvent(e, nicID, addr.Address, false, nil); diff != "" { - t.Errorf("dad event mismatch (-want +got):\n%s", diff) + if diff := checkDADEvent(e, nicID, addr.Address, &stack.DADDupAddrDetected{}); diff != "" { + t.Errorf("DAD event mismatch (-want +got):\n%s", diff) } default: t.Fatal("expected DAD event") @@ -4231,8 +4244,8 @@ func TestAutoGenAddrContinuesLifetimesAfterRetry(t *testing.T) { expectAutoGenAddrEvent(addr, invalidatedAddr) select { case e := <-ndpDisp.dadC: - if diff := checkDADEvent(e, nicID, addr.Address, false, nil); diff != "" { - t.Errorf("dad event mismatch (-want +got):\n%s", diff) + if diff := checkDADEvent(e, nicID, addr.Address, &stack.DADDupAddrDetected{}); diff != "" { + t.Errorf("DAD event mismatch (-want +got):\n%s", diff) } default: t.Fatal("expected DAD event") @@ -4243,8 +4256,8 @@ func TestAutoGenAddrContinuesLifetimesAfterRetry(t *testing.T) { expectAutoGenAddrEvent(addr, newAddr) select { case e := <-ndpDisp.dadC: - if diff := checkDADEvent(e, nicID, addr.Address, true, nil); diff != "" { - t.Errorf("dad event mismatch (-want +got):\n%s", diff) + if diff := checkDADEvent(e, nicID, addr.Address, &stack.DADSucceeded{}); diff != "" { + t.Errorf("DAD event mismatch (-want +got):\n%s", diff) } case <-time.After(dadTransmits*retransmitTimer + defaultAsyncPositiveEventTimeout): t.Fatal("timed out waiting for DAD event") diff --git a/pkg/tcpip/stack/neighbor_cache_test.go b/pkg/tcpip/stack/neighbor_cache_test.go index afff1b434..48bb75e2f 100644 --- a/pkg/tcpip/stack/neighbor_cache_test.go +++ b/pkg/tcpip/stack/neighbor_cache_test.go @@ -59,21 +59,24 @@ const ( infiniteDuration = time.Duration(math.MaxInt64) ) -// entryDiffOpts returns the options passed to cmp.Diff to compare neighbor -// entries. The UpdatedAtNanos field is ignored due to a lack of a -// deterministic method to predict the time that an event will be dispatched. -func entryDiffOpts() []cmp.Option { +// unorderedEventsDiffOpts returns options passed to cmp.Diff to sort slices of +// events for cases where ordering must be ignored. +func unorderedEventsDiffOpts() []cmp.Option { return []cmp.Option{ - cmpopts.IgnoreFields(NeighborEntry{}, "UpdatedAtNanos"), + cmpopts.SortSlices(func(a, b testEntryEventInfo) bool { + return strings.Compare(string(a.Entry.Addr), string(b.Entry.Addr)) < 0 + }), } } -// entryDiffOptsWithSort is like entryDiffOpts but also includes an option to -// sort slices of entries for cases where ordering must be ignored. -func entryDiffOptsWithSort() []cmp.Option { - return append(entryDiffOpts(), cmpopts.SortSlices(func(a, b NeighborEntry) bool { - return strings.Compare(string(a.Addr), string(b.Addr)) < 0 - })) +// unorderedEntriesDiffOpts returns options passed to cmp.Diff to sort slices of +// entries for cases where ordering must be ignored. +func unorderedEntriesDiffOpts() []cmp.Option { + return []cmp.Option{ + cmpopts.SortSlices(func(a, b NeighborEntry) bool { + return strings.Compare(string(a.Addr), string(b.Addr)) < 0 + }), + } } func newTestNeighborResolver(nudDisp NUDDispatcher, config NUDConfigurations, clock tcpip.Clock) *testNeighborResolver { @@ -280,48 +283,105 @@ func TestNeighborCacheSetConfig(t *testing.T) { } } -func TestNeighborCacheEntry(t *testing.T) { - c := DefaultNUDConfigurations() - nudDisp := testNUDDispatcher{} - clock := faketime.NewManualClock() - linkRes := newTestNeighborResolver(&nudDisp, c, clock) +func addReachableEntryWithRemoved(nudDisp *testNUDDispatcher, clock *faketime.ManualClock, linkRes *testNeighborResolver, entry NeighborEntry, removed []NeighborEntry) error { + var gotLinkResolutionResult LinkResolutionResult - entry, ok := linkRes.entries.entry(0) - if !ok { - t.Fatal("linkRes.entries.entry(0) not found") - } - _, _, err := linkRes.neigh.entry(entry.Addr, "", nil) + _, ch, err := linkRes.neigh.entry(entry.Addr, "", func(r LinkResolutionResult) { + gotLinkResolutionResult = r + }) if _, ok := err.(*tcpip.ErrWouldBlock); !ok { - t.Errorf("got linkRes.neigh.entry(%s, '', nil) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{}) + return fmt.Errorf("got linkRes.neigh.entry(%s, '', _) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{}) } - clock.Advance(typicalLatency) + { + var wantEvents []testEntryEventInfo - wantEvents := []testEntryEventInfo{ - { + for _, removedEntry := range removed { + wantEvents = append(wantEvents, testEntryEventInfo{ + EventType: entryTestRemoved, + NICID: 1, + Entry: NeighborEntry{ + Addr: removedEntry.Addr, + LinkAddr: removedEntry.LinkAddr, + State: Reachable, + UpdatedAtNanos: clock.NowNanoseconds(), + }, + }) + } + + wantEvents = append(wantEvents, testEntryEventInfo{ EventType: entryTestAdded, NICID: 1, Entry: NeighborEntry{ - Addr: entry.Addr, - State: Incomplete, + Addr: entry.Addr, + LinkAddr: "", + State: Incomplete, + UpdatedAtNanos: clock.NowNanoseconds(), }, - }, - { - EventType: entryTestChanged, - NICID: 1, - Entry: NeighborEntry{ - Addr: entry.Addr, - LinkAddr: entry.LinkAddr, - State: Reachable, + }) + + nudDisp.mu.Lock() + diff := cmp.Diff(wantEvents, nudDisp.mu.events) + nudDisp.mu.events = nil + nudDisp.mu.Unlock() + if diff != "" { + return fmt.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff) + } + } + + clock.Advance(typicalLatency) + + select { + case <-ch: + default: + return fmt.Errorf("expected notification from done channel returned by linkRes.neigh.entry(%s, '', _)", entry.Addr) + } + wantLinkResolutionResult := LinkResolutionResult{LinkAddress: entry.LinkAddr, Err: nil} + if diff := cmp.Diff(wantLinkResolutionResult, gotLinkResolutionResult); diff != "" { + return fmt.Errorf("got link resolution result mismatch (-want +got):\n%s", diff) + } + + { + wantEvents := []testEntryEventInfo{ + { + EventType: entryTestChanged, + NICID: 1, + Entry: NeighborEntry{ + Addr: entry.Addr, + LinkAddr: entry.LinkAddr, + State: Reachable, + UpdatedAtNanos: clock.NowNanoseconds(), + }, }, - }, + } + nudDisp.mu.Lock() + diff := cmp.Diff(wantEvents, nudDisp.mu.events) + nudDisp.mu.events = nil + nudDisp.mu.Unlock() + if diff != "" { + return fmt.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff) + } } - nudDisp.mu.Lock() - diff := cmp.Diff(wantEvents, nudDisp.mu.events, eventDiffOpts()...) - nudDisp.mu.events = nil - nudDisp.mu.Unlock() - if diff != "" { - t.Fatalf("nud dispatcher events mismatch (-want, +got):\n%s", diff) + + return nil +} + +func addReachableEntry(nudDisp *testNUDDispatcher, clock *faketime.ManualClock, linkRes *testNeighborResolver, entry NeighborEntry) error { + return addReachableEntryWithRemoved(nudDisp, clock, linkRes, entry, nil /* removed */) +} + +func TestNeighborCacheEntry(t *testing.T) { + c := DefaultNUDConfigurations() + nudDisp := testNUDDispatcher{} + clock := faketime.NewManualClock() + linkRes := newTestNeighborResolver(&nudDisp, c, clock) + + entry, ok := linkRes.entries.entry(0) + if !ok { + t.Fatal("got linkRes.entries.entry(0) = _, false, want = true ") + } + if err := addReachableEntry(&nudDisp, clock, linkRes, entry); err != nil { + t.Fatalf("addReachableEntry(...) = %s", err) } if _, _, err := linkRes.neigh.entry(entry.Addr, "", nil); err != nil { @@ -345,41 +405,10 @@ func TestNeighborCacheRemoveEntry(t *testing.T) { entry, ok := linkRes.entries.entry(0) if !ok { - t.Fatal("linkRes.entries.entry(0) not found") + t.Fatal("got linkRes.entries.entry(0) = _, false, want = true ") } - - _, _, err := linkRes.neigh.entry(entry.Addr, "", nil) - if _, ok := err.(*tcpip.ErrWouldBlock); !ok { - t.Errorf("got linkRes.neigh.entry(%s, '', nil) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{}) - } - - clock.Advance(typicalLatency) - - wantEvents := []testEntryEventInfo{ - { - EventType: entryTestAdded, - NICID: 1, - Entry: NeighborEntry{ - Addr: entry.Addr, - State: Incomplete, - }, - }, - { - EventType: entryTestChanged, - NICID: 1, - Entry: NeighborEntry{ - Addr: entry.Addr, - LinkAddr: entry.LinkAddr, - State: Reachable, - }, - }, - } - nudDisp.mu.Lock() - diff := cmp.Diff(wantEvents, nudDisp.mu.events, eventDiffOpts()...) - nudDisp.mu.events = nil - nudDisp.mu.Unlock() - if diff != "" { - t.Fatalf("nud dispatcher events mismatch (-want, +got):\n%s", diff) + if err := addReachableEntry(&nudDisp, clock, linkRes, entry); err != nil { + t.Fatalf("addReachableEntry(...) = %s", err) } linkRes.neigh.removeEntry(entry.Addr) @@ -390,14 +419,15 @@ func TestNeighborCacheRemoveEntry(t *testing.T) { EventType: entryTestRemoved, NICID: 1, Entry: NeighborEntry{ - Addr: entry.Addr, - LinkAddr: entry.LinkAddr, - State: Reachable, + Addr: entry.Addr, + LinkAddr: entry.LinkAddr, + State: Reachable, + UpdatedAtNanos: clock.NowNanoseconds(), }, }, } nudDisp.mu.Lock() - diff := cmp.Diff(wantEvents, nudDisp.mu.events, eventDiffOpts()...) + diff := cmp.Diff(wantEvents, nudDisp.mu.events) nudDisp.mu.Unlock() if diff != "" { t.Fatalf("nud dispatcher events mismatch (-want, +got):\n%s", diff) @@ -439,18 +469,7 @@ func (c *testContext) overflowCache(opts overflowOptions) error { // Fill the neighbor cache to capacity to verify the LRU eviction strategy is // working properly after the entry removal. for i := opts.startAtEntryIndex; i < c.linkRes.entries.size(); i++ { - // Add a new entry - entry, ok := c.linkRes.entries.entry(i) - if !ok { - return fmt.Errorf("c.linkRes.entries.entry(%d) not found", i) - } - _, _, err := c.linkRes.neigh.entry(entry.Addr, "", nil) - if _, ok := err.(*tcpip.ErrWouldBlock); !ok { - return fmt.Errorf("got c.linkRes.neigh.entry(%s, '', nil) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{}) - } - c.clock.Advance(c.linkRes.neigh.config().RetransmitTimer) - - var wantEvents []testEntryEventInfo + var removedEntries []NeighborEntry // When beyond the full capacity, the cache will evict an entry as per the // LRU eviction strategy. Note that the number of static entries should not @@ -458,63 +477,40 @@ func (c *testContext) overflowCache(opts overflowOptions) error { if i >= neighborCacheSize+opts.startAtEntryIndex { removedEntry, ok := c.linkRes.entries.entry(i - neighborCacheSize) if !ok { - return fmt.Errorf("linkRes.entries.entry(%d) not found", i-neighborCacheSize) + return fmt.Errorf("got linkRes.entries.entry(%d) = _, false, want = true", i-neighborCacheSize) } - wantEvents = append(wantEvents, testEntryEventInfo{ - EventType: entryTestRemoved, - NICID: 1, - Entry: NeighborEntry{ - Addr: removedEntry.Addr, - LinkAddr: removedEntry.LinkAddr, - State: Reachable, - }, - }) + removedEntries = append(removedEntries, removedEntry) } - wantEvents = append(wantEvents, testEntryEventInfo{ - EventType: entryTestAdded, - NICID: 1, - Entry: NeighborEntry{ - Addr: entry.Addr, - State: Incomplete, - }, - }, testEntryEventInfo{ - EventType: entryTestChanged, - NICID: 1, - Entry: NeighborEntry{ - Addr: entry.Addr, - LinkAddr: entry.LinkAddr, - State: Reachable, - }, - }) - - c.nudDisp.mu.Lock() - diff := cmp.Diff(wantEvents, c.nudDisp.mu.events, eventDiffOpts()...) - c.nudDisp.mu.events = nil - c.nudDisp.mu.Unlock() - if diff != "" { - return fmt.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff) + entry, ok := c.linkRes.entries.entry(i) + if !ok { + return fmt.Errorf("got c.linkRes.entries.entry(%d) = _, false, want = true", i) + } + if err := addReachableEntryWithRemoved(c.nudDisp, c.clock, c.linkRes, entry, removedEntries); err != nil { + return fmt.Errorf("addReachableEntryWithRemoved(...) = %s", err) } } // Expect to find only the most recent entries. The order of entries reported // by entries() is nondeterministic, so entries have to be sorted before // comparison. - wantUnsortedEntries := opts.wantStaticEntries + wantUnorderedEntries := opts.wantStaticEntries for i := c.linkRes.entries.size() - neighborCacheSize; i < c.linkRes.entries.size(); i++ { entry, ok := c.linkRes.entries.entry(i) if !ok { - return fmt.Errorf("c.linkRes.entries.entry(%d) not found", i) + return fmt.Errorf("got c.linkRes.entries.entry(%d) = _, false, want = true", i) } + durationReachableNanos := int64(c.linkRes.entries.size()-i-1) * typicalLatency.Nanoseconds() wantEntry := NeighborEntry{ - Addr: entry.Addr, - LinkAddr: entry.LinkAddr, - State: Reachable, + Addr: entry.Addr, + LinkAddr: entry.LinkAddr, + State: Reachable, + UpdatedAtNanos: c.clock.NowNanoseconds() - durationReachableNanos, } - wantUnsortedEntries = append(wantUnsortedEntries, wantEntry) + wantUnorderedEntries = append(wantUnorderedEntries, wantEntry) } - if diff := cmp.Diff(wantUnsortedEntries, c.linkRes.neigh.entries(), entryDiffOptsWithSort()...); diff != "" { + if diff := cmp.Diff(wantUnorderedEntries, c.linkRes.neigh.entries(), unorderedEntriesDiffOpts()...); diff != "" { return fmt.Errorf("neighbor entries mismatch (-want, +got):\n%s", diff) } @@ -560,38 +556,10 @@ func TestNeighborCacheRemoveEntryThenOverflow(t *testing.T) { // Add a dynamic entry entry, ok := c.linkRes.entries.entry(0) if !ok { - t.Fatal("c.linkRes.entries.entry(0) not found") + t.Fatal("got c.linkRes.entries.entry(0) = _, false, want = true ") } - _, _, err := c.linkRes.neigh.entry(entry.Addr, "", nil) - if _, ok := err.(*tcpip.ErrWouldBlock); !ok { - t.Errorf("got c.linkRes.neigh.entry(%s, '', nil) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{}) - } - c.clock.Advance(c.linkRes.neigh.config().RetransmitTimer) - wantEvents := []testEntryEventInfo{ - { - EventType: entryTestAdded, - NICID: 1, - Entry: NeighborEntry{ - Addr: entry.Addr, - State: Incomplete, - }, - }, - { - EventType: entryTestChanged, - NICID: 1, - Entry: NeighborEntry{ - Addr: entry.Addr, - LinkAddr: entry.LinkAddr, - State: Reachable, - }, - }, - } - c.nudDisp.mu.Lock() - diff := cmp.Diff(wantEvents, c.nudDisp.mu.events, eventDiffOpts()...) - c.nudDisp.mu.events = nil - c.nudDisp.mu.Unlock() - if diff != "" { - t.Fatalf("nud dispatcher events mismatch (-want, +got):\n%s", diff) + if err := addReachableEntry(c.nudDisp, c.clock, c.linkRes, entry); err != nil { + t.Fatalf("addReachableEntry(...) = %s", err) } // Remove the entry @@ -603,14 +571,15 @@ func TestNeighborCacheRemoveEntryThenOverflow(t *testing.T) { EventType: entryTestRemoved, NICID: 1, Entry: NeighborEntry{ - Addr: entry.Addr, - LinkAddr: entry.LinkAddr, - State: Reachable, + Addr: entry.Addr, + LinkAddr: entry.LinkAddr, + State: Reachable, + UpdatedAtNanos: c.clock.NowNanoseconds(), }, }, } c.nudDisp.mu.Lock() - diff := cmp.Diff(wantEvents, c.nudDisp.mu.events, eventDiffOpts()...) + diff := cmp.Diff(wantEvents, c.nudDisp.mu.events) c.nudDisp.mu.events = nil c.nudDisp.mu.Unlock() if diff != "" { @@ -636,33 +605,36 @@ func TestNeighborCacheDuplicateStaticEntryWithSameLinkAddress(t *testing.T) { // Add a static entry entry, ok := c.linkRes.entries.entry(0) if !ok { - t.Fatal("c.linkRes.entries.entry(0) not found") + t.Fatal("got c.linkRes.entries.entry(0) = _, false, want = true ") } staticLinkAddr := entry.LinkAddr + "static" c.linkRes.neigh.addStaticEntry(entry.Addr, staticLinkAddr) - wantEvents := []testEntryEventInfo{ - { - EventType: entryTestAdded, - NICID: 1, - Entry: NeighborEntry{ - Addr: entry.Addr, - LinkAddr: staticLinkAddr, - State: Static, + + { + wantEvents := []testEntryEventInfo{ + { + EventType: entryTestAdded, + NICID: 1, + Entry: NeighborEntry{ + Addr: entry.Addr, + LinkAddr: staticLinkAddr, + State: Static, + UpdatedAtNanos: c.clock.NowNanoseconds(), + }, }, - }, - } - c.nudDisp.mu.Lock() - diff := cmp.Diff(wantEvents, c.nudDisp.mu.events, eventDiffOpts()...) - c.nudDisp.mu.events = nil - c.nudDisp.mu.Unlock() - if diff != "" { - t.Fatalf("nud dispatcher events mismatch (-want, +got):\n%s", diff) + } + c.nudDisp.mu.Lock() + diff := cmp.Diff(wantEvents, c.nudDisp.mu.events) + c.nudDisp.mu.events = nil + c.nudDisp.mu.Unlock() + if diff != "" { + t.Fatalf("nud dispatcher events mismatch (-want, +got):\n%s", diff) + } } - // Remove the static entry that was just added + // Add a duplicate static entry with the same link address. c.linkRes.neigh.addStaticEntry(entry.Addr, staticLinkAddr) - // No more events should have been dispatched. c.nudDisp.mu.Lock() defer c.nudDisp.mu.Unlock() if diff := cmp.Diff([]testEntryEventInfo(nil), c.nudDisp.mu.events); diff != "" { @@ -680,48 +652,56 @@ func TestNeighborCacheDuplicateStaticEntryWithDifferentLinkAddress(t *testing.T) // Add a static entry entry, ok := c.linkRes.entries.entry(0) if !ok { - t.Fatal("c.linkRes.entries.entry(0) not found") + t.Fatal("got c.linkRes.entries.entry(0) = _, false, want = true ") } staticLinkAddr := entry.LinkAddr + "static" c.linkRes.neigh.addStaticEntry(entry.Addr, staticLinkAddr) - wantEvents := []testEntryEventInfo{ - { - EventType: entryTestAdded, - NICID: 1, - Entry: NeighborEntry{ - Addr: entry.Addr, - LinkAddr: staticLinkAddr, - State: Static, + + { + wantEvents := []testEntryEventInfo{ + { + EventType: entryTestAdded, + NICID: 1, + Entry: NeighborEntry{ + Addr: entry.Addr, + LinkAddr: staticLinkAddr, + State: Static, + UpdatedAtNanos: c.clock.NowNanoseconds(), + }, }, - }, - } - c.nudDisp.mu.Lock() - diff := cmp.Diff(wantEvents, c.nudDisp.mu.events, eventDiffOpts()...) - c.nudDisp.mu.events = nil - c.nudDisp.mu.Unlock() - if diff != "" { - t.Fatalf("nud dispatcher events mismatch (-want, +got):\n%s", diff) + } + c.nudDisp.mu.Lock() + diff := cmp.Diff(wantEvents, c.nudDisp.mu.events) + c.nudDisp.mu.events = nil + c.nudDisp.mu.Unlock() + if diff != "" { + t.Fatalf("nud dispatcher events mismatch (-want, +got):\n%s", diff) + } } // Add a duplicate entry with a different link address staticLinkAddr += "duplicate" c.linkRes.neigh.addStaticEntry(entry.Addr, staticLinkAddr) + { wantEvents := []testEntryEventInfo{ { EventType: entryTestChanged, NICID: 1, Entry: NeighborEntry{ - Addr: entry.Addr, - LinkAddr: staticLinkAddr, - State: Static, + Addr: entry.Addr, + LinkAddr: staticLinkAddr, + State: Static, + UpdatedAtNanos: c.clock.NowNanoseconds(), }, }, } c.nudDisp.mu.Lock() - defer c.nudDisp.mu.Unlock() - if diff := cmp.Diff(wantEvents, c.nudDisp.mu.events, eventDiffOpts()...); diff != "" { - t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff) + diff := cmp.Diff(wantEvents, c.nudDisp.mu.events) + c.nudDisp.mu.events = nil + c.nudDisp.mu.Unlock() + if diff != "" { + t.Fatalf("nud dispatcher events mismatch (-want, +got):\n%s", diff) } } } @@ -742,45 +722,51 @@ func TestNeighborCacheRemoveStaticEntryThenOverflow(t *testing.T) { // Add a static entry entry, ok := c.linkRes.entries.entry(0) if !ok { - t.Fatal("c.linkRes.entries.entry(0) not found") + t.Fatal("got c.linkRes.entries.entry(0) = _, false, want = true ") } staticLinkAddr := entry.LinkAddr + "static" c.linkRes.neigh.addStaticEntry(entry.Addr, staticLinkAddr) - wantEvents := []testEntryEventInfo{ - { - EventType: entryTestAdded, - NICID: 1, - Entry: NeighborEntry{ - Addr: entry.Addr, - LinkAddr: staticLinkAddr, - State: Static, + + { + wantEvents := []testEntryEventInfo{ + { + EventType: entryTestAdded, + NICID: 1, + Entry: NeighborEntry{ + Addr: entry.Addr, + LinkAddr: staticLinkAddr, + State: Static, + UpdatedAtNanos: c.clock.NowNanoseconds(), + }, }, - }, - } - c.nudDisp.mu.Lock() - diff := cmp.Diff(wantEvents, c.nudDisp.mu.events, eventDiffOpts()...) - c.nudDisp.mu.events = nil - c.nudDisp.mu.Unlock() - if diff != "" { - t.Fatalf("nud dispatcher events mismatch (-want, +got):\n%s", diff) + } + c.nudDisp.mu.Lock() + diff := cmp.Diff(wantEvents, c.nudDisp.mu.events) + c.nudDisp.mu.events = nil + c.nudDisp.mu.Unlock() + if diff != "" { + t.Fatalf("nud dispatcher events mismatch (-want, +got):\n%s", diff) + } } // Remove the static entry that was just added c.linkRes.neigh.removeEntry(entry.Addr) + { wantEvents := []testEntryEventInfo{ { EventType: entryTestRemoved, NICID: 1, Entry: NeighborEntry{ - Addr: entry.Addr, - LinkAddr: staticLinkAddr, - State: Static, + Addr: entry.Addr, + LinkAddr: staticLinkAddr, + State: Static, + UpdatedAtNanos: c.clock.NowNanoseconds(), }, }, } c.nudDisp.mu.Lock() - diff := cmp.Diff(wantEvents, c.nudDisp.mu.events, eventDiffOpts()...) + diff := cmp.Diff(wantEvents, c.nudDisp.mu.events) c.nudDisp.mu.events = nil c.nudDisp.mu.Unlock() if diff != "" { @@ -812,66 +798,41 @@ func TestNeighborCacheOverwriteWithStaticEntryThenOverflow(t *testing.T) { // Add a dynamic entry entry, ok := c.linkRes.entries.entry(0) if !ok { - t.Fatal("c.linkRes.entries.entry(0) not found") - } - _, _, err := c.linkRes.neigh.entry(entry.Addr, "", nil) - if _, ok := err.(*tcpip.ErrWouldBlock); !ok { - t.Errorf("got c.linkRes.neigh.entry(%s, '', nil) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{}) - } - c.clock.Advance(typicalLatency) - wantEvents := []testEntryEventInfo{ - { - EventType: entryTestAdded, - NICID: 1, - Entry: NeighborEntry{ - Addr: entry.Addr, - State: Incomplete, - }, - }, - { - EventType: entryTestChanged, - NICID: 1, - Entry: NeighborEntry{ - Addr: entry.Addr, - LinkAddr: entry.LinkAddr, - State: Reachable, - }, - }, + t.Fatal("got c.linkRes.entries.entry(0) = _, false, want = true ") } - c.nudDisp.mu.Lock() - diff := cmp.Diff(wantEvents, c.nudDisp.mu.events, eventDiffOpts()...) - c.nudDisp.mu.events = nil - c.nudDisp.mu.Unlock() - if diff != "" { - t.Fatalf("nud dispatcher events mismatch (-want, +got):\n%s", diff) + if err := addReachableEntry(c.nudDisp, c.clock, c.linkRes, entry); err != nil { + t.Fatalf("addReachableEntry(...) = %s", err) } // Override the entry with a static one using the same address staticLinkAddr := entry.LinkAddr + "static" c.linkRes.neigh.addStaticEntry(entry.Addr, staticLinkAddr) + { wantEvents := []testEntryEventInfo{ { EventType: entryTestRemoved, NICID: 1, Entry: NeighborEntry{ - Addr: entry.Addr, - LinkAddr: entry.LinkAddr, - State: Reachable, + Addr: entry.Addr, + LinkAddr: entry.LinkAddr, + State: Reachable, + UpdatedAtNanos: c.clock.NowNanoseconds(), }, }, { EventType: entryTestAdded, NICID: 1, Entry: NeighborEntry{ - Addr: entry.Addr, - LinkAddr: staticLinkAddr, - State: Static, + Addr: entry.Addr, + LinkAddr: staticLinkAddr, + State: Static, + UpdatedAtNanos: c.clock.NowNanoseconds(), }, }, } c.nudDisp.mu.Lock() - diff := cmp.Diff(wantEvents, c.nudDisp.mu.events, eventDiffOpts()...) + diff := cmp.Diff(wantEvents, c.nudDisp.mu.events) c.nudDisp.mu.events = nil c.nudDisp.mu.Unlock() if diff != "" { @@ -883,9 +844,10 @@ func TestNeighborCacheOverwriteWithStaticEntryThenOverflow(t *testing.T) { startAtEntryIndex: 1, wantStaticEntries: []NeighborEntry{ { - Addr: entry.Addr, - LinkAddr: staticLinkAddr, - State: Static, + Addr: entry.Addr, + LinkAddr: staticLinkAddr, + State: Static, + UpdatedAtNanos: c.clock.NowNanoseconds(), }, }, } @@ -905,7 +867,7 @@ func TestNeighborCacheAddStaticEntryThenOverflow(t *testing.T) { entry, ok := c.linkRes.entries.entry(0) if !ok { - t.Fatal("c.linkRes.entries.entry(0) not found") + t.Fatal("got c.linkRes.entries.entry(0) = _, false, want = true ") } c.linkRes.neigh.addStaticEntry(entry.Addr, entry.LinkAddr) e, _, err := c.linkRes.neigh.entry(entry.Addr, "", nil) @@ -913,40 +875,45 @@ func TestNeighborCacheAddStaticEntryThenOverflow(t *testing.T) { t.Errorf("unexpected error from c.linkRes.neigh.entry(%s, \"\", nil): %s", entry.Addr, err) } want := NeighborEntry{ - Addr: entry.Addr, - LinkAddr: entry.LinkAddr, - State: Static, + Addr: entry.Addr, + LinkAddr: entry.LinkAddr, + State: Static, + UpdatedAtNanos: c.clock.NowNanoseconds(), } - if diff := cmp.Diff(want, e, entryDiffOpts()...); diff != "" { + if diff := cmp.Diff(want, e); diff != "" { t.Errorf("c.linkRes.neigh.entry(%s, \"\", nil) mismatch (-want, +got):\n%s", entry.Addr, diff) } - wantEvents := []testEntryEventInfo{ - { - EventType: entryTestAdded, - NICID: 1, - Entry: NeighborEntry{ - Addr: entry.Addr, - LinkAddr: entry.LinkAddr, - State: Static, + { + wantEvents := []testEntryEventInfo{ + { + EventType: entryTestAdded, + NICID: 1, + Entry: NeighborEntry{ + Addr: entry.Addr, + LinkAddr: entry.LinkAddr, + State: Static, + UpdatedAtNanos: c.clock.NowNanoseconds(), + }, }, - }, - } - c.nudDisp.mu.Lock() - diff := cmp.Diff(wantEvents, c.nudDisp.mu.events, eventDiffOpts()...) - c.nudDisp.mu.events = nil - c.nudDisp.mu.Unlock() - if diff != "" { - t.Fatalf("nud dispatcher events mismatch (-want, +got):\n%s", diff) + } + c.nudDisp.mu.Lock() + diff := cmp.Diff(wantEvents, c.nudDisp.mu.events) + c.nudDisp.mu.events = nil + c.nudDisp.mu.Unlock() + if diff != "" { + t.Fatalf("nud dispatcher events mismatch (-want, +got):\n%s", diff) + } } opts := overflowOptions{ startAtEntryIndex: 1, wantStaticEntries: []NeighborEntry{ { - Addr: entry.Addr, - LinkAddr: entry.LinkAddr, - State: Static, + Addr: entry.Addr, + LinkAddr: entry.LinkAddr, + State: Static, + UpdatedAtNanos: c.clock.NowNanoseconds(), }, }, } @@ -965,39 +932,10 @@ func TestNeighborCacheClear(t *testing.T) { // Add a dynamic entry. entry, ok := linkRes.entries.entry(0) if !ok { - t.Fatal("linkRes.entries.entry(0) not found") + t.Fatal("got linkRes.entries.entry(0) = _, false, want = true ") } - _, _, err := linkRes.neigh.entry(entry.Addr, "", nil) - if _, ok := err.(*tcpip.ErrWouldBlock); !ok { - t.Errorf("got linkRes.neigh.entry(%s, '', nil) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{}) - } - clock.Advance(typicalLatency) - - wantEvents := []testEntryEventInfo{ - { - EventType: entryTestAdded, - NICID: 1, - Entry: NeighborEntry{ - Addr: entry.Addr, - State: Incomplete, - }, - }, - { - EventType: entryTestChanged, - NICID: 1, - Entry: NeighborEntry{ - Addr: entry.Addr, - LinkAddr: entry.LinkAddr, - State: Reachable, - }, - }, - } - nudDisp.mu.Lock() - diff := cmp.Diff(wantEvents, nudDisp.mu.events, eventDiffOpts()...) - nudDisp.mu.events = nil - nudDisp.mu.Unlock() - if diff != "" { - t.Fatalf("nud dispatcher events mismatch (-want, +got):\n%s", diff) + if err := addReachableEntry(&nudDisp, clock, linkRes, entry); err != nil { + t.Fatalf("addReachableEntry(...) = %s", err) } // Add a static entry. @@ -1009,14 +947,15 @@ func TestNeighborCacheClear(t *testing.T) { EventType: entryTestAdded, NICID: 1, Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Static, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Static, + UpdatedAtNanos: clock.NowNanoseconds(), }, }, } nudDisp.mu.Lock() - diff := cmp.Diff(wantEvents, nudDisp.mu.events, eventDiffOpts()...) + diff := cmp.Diff(wantEvents, nudDisp.mu.events) nudDisp.mu.events = nil nudDisp.mu.Unlock() if diff != "" { @@ -1028,30 +967,32 @@ func TestNeighborCacheClear(t *testing.T) { linkRes.neigh.clear() // Remove events dispatched from clear() have no deterministic order so they - // need to be sorted beforehand. - wantUnsortedEvents := []testEntryEventInfo{ + // need to be sorted before comparison. + wantUnorderedEvents := []testEntryEventInfo{ { EventType: entryTestRemoved, NICID: 1, Entry: NeighborEntry{ - Addr: entry.Addr, - LinkAddr: entry.LinkAddr, - State: Reachable, + Addr: entry.Addr, + LinkAddr: entry.LinkAddr, + State: Reachable, + UpdatedAtNanos: clock.NowNanoseconds(), }, }, { EventType: entryTestRemoved, NICID: 1, Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Static, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Static, + UpdatedAtNanos: clock.NowNanoseconds(), }, }, } nudDisp.mu.Lock() defer nudDisp.mu.Unlock() - if diff := cmp.Diff(wantUnsortedEvents, nudDisp.mu.events, eventDiffOptsWithSort()...); diff != "" { + if diff := cmp.Diff(wantUnorderedEvents, nudDisp.mu.events, unorderedEventsDiffOpts()...); diff != "" { t.Errorf("nud dispatcher events mismatch (-want, +got):\n%s", diff) } } @@ -1071,56 +1012,30 @@ func TestNeighborCacheClearThenOverflow(t *testing.T) { // Add a dynamic entry entry, ok := c.linkRes.entries.entry(0) if !ok { - t.Fatal("c.linkRes.entries.entry(0) not found") - } - _, _, err := c.linkRes.neigh.entry(entry.Addr, "", nil) - if _, ok := err.(*tcpip.ErrWouldBlock); !ok { - t.Errorf("got c.linkRes.neigh.entry(%s, '', nil) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{}) - } - c.clock.Advance(typicalLatency) - wantEvents := []testEntryEventInfo{ - { - EventType: entryTestAdded, - NICID: 1, - Entry: NeighborEntry{ - Addr: entry.Addr, - State: Incomplete, - }, - }, - { - EventType: entryTestChanged, - NICID: 1, - Entry: NeighborEntry{ - Addr: entry.Addr, - LinkAddr: entry.LinkAddr, - State: Reachable, - }, - }, + t.Fatal("got c.linkRes.entries.entry(0) = _, false, want = true ") } - c.nudDisp.mu.Lock() - diff := cmp.Diff(wantEvents, c.nudDisp.mu.events, eventDiffOpts()...) - c.nudDisp.mu.events = nil - c.nudDisp.mu.Unlock() - if diff != "" { - t.Fatalf("nud dispatcher events mismatch (-want, +got):\n%s", diff) + if err := addReachableEntry(c.nudDisp, c.clock, c.linkRes, entry); err != nil { + t.Fatalf("addReachableEntry(...) = %s", err) } // Clear the cache. c.linkRes.neigh.clear() + { wantEvents := []testEntryEventInfo{ { EventType: entryTestRemoved, NICID: 1, Entry: NeighborEntry{ - Addr: entry.Addr, - LinkAddr: entry.LinkAddr, - State: Reachable, + Addr: entry.Addr, + LinkAddr: entry.LinkAddr, + State: Reachable, + UpdatedAtNanos: c.clock.NowNanoseconds(), }, }, } c.nudDisp.mu.Lock() - diff := cmp.Diff(wantEvents, c.nudDisp.mu.events, eventDiffOpts()...) + diff := cmp.Diff(wantEvents, c.nudDisp.mu.events) c.nudDisp.mu.events = nil c.nudDisp.mu.Unlock() if diff != "" { @@ -1147,10 +1062,7 @@ func TestNeighborCacheKeepFrequentlyUsed(t *testing.T) { clock := faketime.NewManualClock() linkRes := newTestNeighborResolver(&nudDisp, config, clock) - frequentlyUsedEntry, ok := linkRes.entries.entry(0) - if !ok { - t.Fatal("linkRes.entries.entry(0) not found") - } + startedAt := clock.NowNanoseconds() // The following logic is very similar to overflowCache, but // periodically refreshes the frequently used entry. @@ -1159,50 +1071,18 @@ func TestNeighborCacheKeepFrequentlyUsed(t *testing.T) { for i := 0; i < neighborCacheSize; i++ { entry, ok := linkRes.entries.entry(i) if !ok { - t.Fatalf("linkRes.entries.entry(%d) not found", i) + t.Fatalf("got linkRes.entries.entry(%d) = _, false, want = true", i) } - _, ch, err := linkRes.neigh.entry(entry.Addr, "", func(r LinkResolutionResult) { - if diff := cmp.Diff(LinkResolutionResult{LinkAddress: entry.LinkAddr, Err: nil}, r); diff != "" { - t.Fatalf("got link resolution result mismatch (-want +got):\n%s", diff) - } - }) - if _, ok := err.(*tcpip.ErrWouldBlock); !ok { - t.Errorf("got linkRes.neigh.entry(%s, '', _, _, nil) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{}) - } - clock.Advance(typicalLatency) - select { - case <-ch: - default: - t.Fatalf("expected notification from done channel returned by linkRes.neigh.entry(%s, '', _, _, nil)", entry.Addr) - } - wantEvents := []testEntryEventInfo{ - { - EventType: entryTestAdded, - NICID: 1, - Entry: NeighborEntry{ - Addr: entry.Addr, - State: Incomplete, - }, - }, - { - EventType: entryTestChanged, - NICID: 1, - Entry: NeighborEntry{ - Addr: entry.Addr, - LinkAddr: entry.LinkAddr, - State: Reachable, - }, - }, - } - nudDisp.mu.Lock() - diff := cmp.Diff(wantEvents, nudDisp.mu.events, eventDiffOpts()...) - nudDisp.mu.events = nil - nudDisp.mu.Unlock() - if diff != "" { - t.Fatalf("nud dispatcher events mismatch (-want, +got):\n%s", diff) + if err := addReachableEntry(&nudDisp, clock, linkRes, entry); err != nil { + t.Fatalf("addReachableEntry(...) = %s", err) } } + frequentlyUsedEntry, ok := linkRes.entries.entry(0) + if !ok { + t.Fatal("got linkRes.entries.entry(0) = _, false, want = true ") + } + // Keep adding more entries for i := neighborCacheSize; i < linkRes.entries.size(); i++ { // Periodically refresh the frequently used entry @@ -1214,63 +1094,17 @@ func TestNeighborCacheKeepFrequentlyUsed(t *testing.T) { entry, ok := linkRes.entries.entry(i) if !ok { - t.Fatalf("linkRes.entries.entry(%d) not found", i) - } - - _, ch, err := linkRes.neigh.entry(entry.Addr, "", func(r LinkResolutionResult) { - if diff := cmp.Diff(LinkResolutionResult{LinkAddress: entry.LinkAddr, Err: nil}, r); diff != "" { - t.Fatalf("got link resolution result mismatch (-want +got):\n%s", diff) - } - }) - if _, ok := err.(*tcpip.ErrWouldBlock); !ok { - t.Errorf("got linkRes.neigh.entry(%s, '', _, _, nil) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{}) - } - clock.Advance(typicalLatency) - select { - case <-ch: - default: - t.Fatalf("expected notification from done channel returned by linkRes.neigh.entry(%s, '', _, _, nil)", entry.Addr) + t.Fatalf("got linkRes.entries.entry(%d) = _, false, want = true", i) } // An entry should have been removed, as per the LRU eviction strategy removedEntry, ok := linkRes.entries.entry(i - neighborCacheSize + 1) if !ok { - t.Fatalf("linkRes.entries.entry(%d) not found", i-neighborCacheSize+1) - } - wantEvents := []testEntryEventInfo{ - { - EventType: entryTestRemoved, - NICID: 1, - Entry: NeighborEntry{ - Addr: removedEntry.Addr, - LinkAddr: removedEntry.LinkAddr, - State: Reachable, - }, - }, - { - EventType: entryTestAdded, - NICID: 1, - Entry: NeighborEntry{ - Addr: entry.Addr, - State: Incomplete, - }, - }, - { - EventType: entryTestChanged, - NICID: 1, - Entry: NeighborEntry{ - Addr: entry.Addr, - LinkAddr: entry.LinkAddr, - State: Reachable, - }, - }, + t.Fatalf("got linkRes.entries.entry(%d) = _, false, want = true", i-neighborCacheSize+1) } - nudDisp.mu.Lock() - diff := cmp.Diff(wantEvents, nudDisp.mu.events, eventDiffOpts()...) - nudDisp.mu.events = nil - nudDisp.mu.Unlock() - if diff != "" { - t.Fatalf("nud dispatcher events mismatch (-want, +got):\n%s", diff) + + if err := addReachableEntryWithRemoved(&nudDisp, clock, linkRes, entry, []NeighborEntry{removedEntry}); err != nil { + t.Fatalf("addReachableEntryWithRemoved(...) = %s", err) } } @@ -1282,23 +1116,27 @@ func TestNeighborCacheKeepFrequentlyUsed(t *testing.T) { Addr: frequentlyUsedEntry.Addr, LinkAddr: frequentlyUsedEntry.LinkAddr, State: Reachable, + // Can be inferred since the frequently used entry is the first to + // be created and transitioned to Reachable. + UpdatedAtNanos: startedAt + typicalLatency.Nanoseconds(), }, } for i := linkRes.entries.size() - neighborCacheSize + 1; i < linkRes.entries.size(); i++ { entry, ok := linkRes.entries.entry(i) if !ok { - t.Fatalf("linkRes.entries.entry(%d) not found", i) - } - wantEntry := NeighborEntry{ - Addr: entry.Addr, - LinkAddr: entry.LinkAddr, - State: Reachable, + t.Fatalf("got linkRes.entries.entry(%d) = _, false, want = true", i) } - wantUnsortedEntries = append(wantUnsortedEntries, wantEntry) + durationReachableNanos := int64(linkRes.entries.size()-i-1) * typicalLatency.Nanoseconds() + wantUnsortedEntries = append(wantUnsortedEntries, NeighborEntry{ + Addr: entry.Addr, + LinkAddr: entry.LinkAddr, + State: Reachable, + UpdatedAtNanos: clock.NowNanoseconds() - durationReachableNanos, + }) } - if diff := cmp.Diff(wantUnsortedEntries, linkRes.neigh.entries(), entryDiffOptsWithSort()...); diff != "" { + if diff := cmp.Diff(wantUnsortedEntries, linkRes.neigh.entries(), unorderedEntriesDiffOpts()...); diff != "" { t.Errorf("neighbor entries mismatch (-want, +got):\n%s", diff) } @@ -1350,17 +1188,18 @@ func TestNeighborCacheConcurrent(t *testing.T) { for i := linkRes.entries.size() - neighborCacheSize; i < linkRes.entries.size(); i++ { entry, ok := linkRes.entries.entry(i) if !ok { - t.Errorf("linkRes.entries.entry(%d) not found", i) + t.Errorf("got linkRes.entries.entry(%d) = _, false, want = true", i) } - wantEntry := NeighborEntry{ - Addr: entry.Addr, - LinkAddr: entry.LinkAddr, - State: Reachable, - } - wantUnsortedEntries = append(wantUnsortedEntries, wantEntry) + durationReachableNanos := int64(linkRes.entries.size()-i-1) * typicalLatency.Nanoseconds() + wantUnsortedEntries = append(wantUnsortedEntries, NeighborEntry{ + Addr: entry.Addr, + LinkAddr: entry.LinkAddr, + State: Reachable, + UpdatedAtNanos: clock.NowNanoseconds() - durationReachableNanos, + }) } - if diff := cmp.Diff(wantUnsortedEntries, linkRes.neigh.entries(), entryDiffOptsWithSort()...); diff != "" { + if diff := cmp.Diff(wantUnsortedEntries, linkRes.neigh.entries(), unorderedEntriesDiffOpts()...); diff != "" { t.Errorf("neighbor entries mismatch (-want, +got):\n%s", diff) } } @@ -1372,44 +1211,12 @@ func TestNeighborCacheReplace(t *testing.T) { clock := faketime.NewManualClock() linkRes := newTestNeighborResolver(&nudDisp, config, clock) - // Add an entry entry, ok := linkRes.entries.entry(0) if !ok { - t.Fatal("linkRes.entries.entry(0) not found") - } - - _, ch, err := linkRes.neigh.entry(entry.Addr, "", func(r LinkResolutionResult) { - if diff := cmp.Diff(LinkResolutionResult{LinkAddress: entry.LinkAddr, Err: nil}, r); diff != "" { - t.Fatalf("got link resolution result mismatch (-want +got):\n%s", diff) - } - }) - if _, ok := err.(*tcpip.ErrWouldBlock); !ok { - t.Fatalf("got linkRes.neigh.entry(%s, '', _, _, nil) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{}) - } - clock.Advance(typicalLatency) - select { - case <-ch: - default: - t.Fatalf("expected notification from done channel returned by linkRes.neigh.entry(%s, '', _, _, nil)", entry.Addr) + t.Fatal("got linkRes.entries.entry(0) = _, false, want = true ") } - - // Verify the entry exists - { - e, _, err := linkRes.neigh.entry(entry.Addr, "", nil) - if err != nil { - t.Errorf("unexpected error from linkRes.neigh.entry(%s, '', _, _, nil): %s", entry.Addr, err) - } - if t.Failed() { - t.FailNow() - } - want := NeighborEntry{ - Addr: entry.Addr, - LinkAddr: entry.LinkAddr, - State: Reachable, - } - if diff := cmp.Diff(want, e, entryDiffOpts()...); diff != "" { - t.Errorf("linkRes.neigh.entry(%s, '', _, _, nil) mismatch (-want, +got):\n%s", entry.Addr, diff) - } + if err := addReachableEntry(&nudDisp, clock, linkRes, entry); err != nil { + t.Fatalf("addReachableEntry(...) = %s", err) } // Notify of a link address change @@ -1417,7 +1224,7 @@ func TestNeighborCacheReplace(t *testing.T) { { entry, ok := linkRes.entries.entry(1) if !ok { - t.Fatal("linkRes.entries.entry(1) not found") + t.Fatal("got linkRes.entries.entry(1) = _, false, want = true") } updatedLinkAddr = entry.LinkAddr } @@ -1437,29 +1244,31 @@ func TestNeighborCacheReplace(t *testing.T) { t.Fatalf("linkRes.neigh.entry(%s, '', nil): %s", entry.Addr, err) } want := NeighborEntry{ - Addr: entry.Addr, - LinkAddr: updatedLinkAddr, - State: Delay, + Addr: entry.Addr, + LinkAddr: updatedLinkAddr, + State: Delay, + UpdatedAtNanos: clock.NowNanoseconds(), } - if diff := cmp.Diff(want, e, entryDiffOpts()...); diff != "" { + if diff := cmp.Diff(want, e); diff != "" { t.Errorf("linkRes.neigh.entry(%s, '', nil) mismatch (-want, +got):\n%s", entry.Addr, diff) } - clock.Advance(config.DelayFirstProbeTime + typicalLatency) } + clock.Advance(config.DelayFirstProbeTime + typicalLatency) + // Verify that the neighbor is now reachable. { e, _, err := linkRes.neigh.entry(entry.Addr, "", nil) - clock.Advance(typicalLatency) if err != nil { t.Errorf("unexpected error from linkRes.neigh.entry(%s, '', nil): %s", entry.Addr, err) } want := NeighborEntry{ - Addr: entry.Addr, - LinkAddr: updatedLinkAddr, - State: Reachable, + Addr: entry.Addr, + LinkAddr: updatedLinkAddr, + State: Reachable, + UpdatedAtNanos: clock.NowNanoseconds(), } - if diff := cmp.Diff(want, e, entryDiffOpts()...); diff != "" { + if diff := cmp.Diff(want, e); diff != "" { t.Errorf("linkRes.neigh.entry(%s, '', nil) mismatch (-want, +got):\n%s", entry.Addr, diff) } } @@ -1479,25 +1288,12 @@ func TestNeighborCacheResolutionFailed(t *testing.T) { entry, ok := linkRes.entries.entry(0) if !ok { - t.Fatal("linkRes.entries.entry(0) not found") + t.Fatal("got linkRes.entries.entry(0) = _, false, want = true ") } // First, sanity check that resolution is working - { - _, ch, err := linkRes.neigh.entry(entry.Addr, "", func(r LinkResolutionResult) { - if diff := cmp.Diff(LinkResolutionResult{LinkAddress: entry.LinkAddr, Err: nil}, r); diff != "" { - t.Fatalf("got link resolution result mismatch (-want +got):\n%s", diff) - } - }) - if _, ok := err.(*tcpip.ErrWouldBlock); !ok { - t.Fatalf("got linkRes.neigh.entry(%s, '', _, _, nil) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{}) - } - clock.Advance(typicalLatency) - select { - case <-ch: - default: - t.Fatalf("expected notification from done channel returned by linkRes.neigh.entry(%s, '', _, _, nil)", entry.Addr) - } + if err := addReachableEntry(&nudDisp, clock, linkRes, entry); err != nil { + t.Fatalf("addReachableEntry(...) = %s", err) } got, _, err := linkRes.neigh.entry(entry.Addr, "", nil) @@ -1505,11 +1301,12 @@ func TestNeighborCacheResolutionFailed(t *testing.T) { t.Fatalf("unexpected error from linkRes.neigh.entry(%s, '', nil): %s", entry.Addr, err) } want := NeighborEntry{ - Addr: entry.Addr, - LinkAddr: entry.LinkAddr, - State: Reachable, + Addr: entry.Addr, + LinkAddr: entry.LinkAddr, + State: Reachable, + UpdatedAtNanos: clock.NowNanoseconds(), } - if diff := cmp.Diff(want, got, entryDiffOpts()...); diff != "" { + if diff := cmp.Diff(want, got); diff != "" { t.Errorf("linkRes.neigh.entry(%s, '', nil) mismatch (-want, +got):\n%s", entry.Addr, diff) } @@ -1524,14 +1321,14 @@ func TestNeighborCacheResolutionFailed(t *testing.T) { } }) if _, ok := err.(*tcpip.ErrWouldBlock); !ok { - t.Fatalf("got linkRes.neigh.entry(%s, '', _, _, nil) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{}) + t.Fatalf("got linkRes.neigh.entry(%s, '', _) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{}) } waitFor := config.DelayFirstProbeTime + typicalLatency*time.Duration(config.MaxMulticastProbes) clock.Advance(waitFor) select { case <-ch: default: - t.Fatalf("expected notification from done channel returned by linkRes.neigh.entry(%s, '', _, _, nil)", entry.Addr) + t.Fatalf("expected notification from done channel returned by linkRes.neigh.entry(%s, '', _)", entry.Addr) } } @@ -1555,7 +1352,7 @@ func TestNeighborCacheResolutionTimeout(t *testing.T) { entry, ok := linkRes.entries.entry(0) if !ok { - t.Fatal("linkRes.entries.entry(0) not found") + t.Fatal("got linkRes.entries.entry(0) = _, false, want = true ") } _, ch, err := linkRes.neigh.entry(entry.Addr, "", func(r LinkResolutionResult) { @@ -1564,7 +1361,7 @@ func TestNeighborCacheResolutionTimeout(t *testing.T) { } }) if _, ok := err.(*tcpip.ErrWouldBlock); !ok { - t.Fatalf("got linkRes.neigh.entry(%s, '', _, _, nil) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{}) + t.Fatalf("got linkRes.neigh.entry(%s, '', _) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{}) } waitFor := config.RetransmitTimer * time.Duration(config.MaxMulticastProbes) clock.Advance(waitFor) @@ -1572,7 +1369,7 @@ func TestNeighborCacheResolutionTimeout(t *testing.T) { select { case <-ch: default: - t.Fatalf("expected notification from done channel returned by linkRes.neigh.entry(%s, '', _, _, nil)", entry.Addr) + t.Fatalf("expected notification from done channel returned by linkRes.neigh.entry(%s, '', _)", entry.Addr) } } @@ -1580,14 +1377,15 @@ func TestNeighborCacheResolutionTimeout(t *testing.T) { // failing to perform address resolution. func TestNeighborCacheRetryResolution(t *testing.T) { config := DefaultNUDConfigurations() + nudDisp := testNUDDispatcher{} clock := faketime.NewManualClock() - linkRes := newTestNeighborResolver(nil, config, clock) + linkRes := newTestNeighborResolver(&nudDisp, config, clock) // Simulate a faulty link. linkRes.dropReplies = true entry, ok := linkRes.entries.entry(0) if !ok { - t.Fatal("linkRes.entries.entry(0) not found") + t.Fatal("got linkRes.entries.entry(0) = _, false, want = true ") } // Perform address resolution with a faulty link, which will fail. @@ -1598,27 +1396,75 @@ func TestNeighborCacheRetryResolution(t *testing.T) { } }) if _, ok := err.(*tcpip.ErrWouldBlock); !ok { - t.Fatalf("got linkRes.neigh.entry(%s, '', _, _, nil) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{}) + t.Fatalf("got linkRes.neigh.entry(%s, '', _) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{}) + } + + { + wantEvents := []testEntryEventInfo{ + { + EventType: entryTestAdded, + NICID: 1, + Entry: NeighborEntry{ + Addr: entry.Addr, + LinkAddr: "", + State: Incomplete, + UpdatedAtNanos: clock.NowNanoseconds(), + }, + }, + } + nudDisp.mu.Lock() + diff := cmp.Diff(wantEvents, nudDisp.mu.events) + nudDisp.mu.events = nil + nudDisp.mu.Unlock() + if diff != "" { + t.Fatalf("nud dispatcher events mismatch (-want, +got):\n%s", diff) + } } + waitFor := config.RetransmitTimer * time.Duration(config.MaxMulticastProbes) clock.Advance(waitFor) select { case <-ch: default: - t.Fatalf("expected notification from done channel returned by linkRes.neigh.entry(%s, '', _, _, nil)", entry.Addr) + t.Fatalf("expected notification from done channel returned by linkRes.neigh.entry(%s, '', _)", entry.Addr) } - } - wantEntries := []NeighborEntry{ { - Addr: entry.Addr, - LinkAddr: "", - State: Unreachable, - }, - } - if diff := cmp.Diff(linkRes.neigh.entries(), wantEntries, entryDiffOptsWithSort()...); diff != "" { - t.Fatalf("neighbor entries mismatch (-got, +want):\n%s", diff) + wantEvents := []testEntryEventInfo{ + { + EventType: entryTestChanged, + NICID: 1, + Entry: NeighborEntry{ + Addr: entry.Addr, + LinkAddr: "", + State: Unreachable, + UpdatedAtNanos: clock.NowNanoseconds(), + }, + }, + } + nudDisp.mu.Lock() + diff := cmp.Diff(wantEvents, nudDisp.mu.events) + nudDisp.mu.events = nil + nudDisp.mu.Unlock() + if diff != "" { + t.Fatalf("nud dispatcher events mismatch (-want, +got):\n%s", diff) + } + } + + { + wantEntries := []NeighborEntry{ + { + Addr: entry.Addr, + LinkAddr: "", + State: Unreachable, + UpdatedAtNanos: clock.NowNanoseconds(), + }, + } + if diff := cmp.Diff(linkRes.neigh.entries(), wantEntries, unorderedEntriesDiffOpts()...); diff != "" { + t.Fatalf("neighbor entries mismatch (-got, +want):\n%s", diff) + } + } } // Retry address resolution with a working link. @@ -1635,28 +1481,74 @@ func TestNeighborCacheRetryResolution(t *testing.T) { if incompleteEntry.State != Incomplete { t.Fatalf("got entry.State = %s, want = %s", incompleteEntry.State, Incomplete) } + + { + wantEvents := []testEntryEventInfo{ + { + EventType: entryTestChanged, + NICID: 1, + Entry: NeighborEntry{ + Addr: entry.Addr, + LinkAddr: "", + State: Incomplete, + UpdatedAtNanos: clock.NowNanoseconds(), + }, + }, + } + nudDisp.mu.Lock() + diff := cmp.Diff(wantEvents, nudDisp.mu.events) + nudDisp.mu.events = nil + nudDisp.mu.Unlock() + if diff != "" { + t.Fatalf("nud dispatcher events mismatch (-want, +got):\n%s", diff) + } + } + clock.Advance(typicalLatency) select { case <-ch: - if !ok { - t.Fatal("expected successful address resolution") + default: + t.Fatalf("expected notification from done channel returned by linkRes.neigh.entry(%s, '', _)", entry.Addr) + } + + { + wantEvents := []testEntryEventInfo{ + { + EventType: entryTestChanged, + NICID: 1, + Entry: NeighborEntry{ + Addr: entry.Addr, + LinkAddr: entry.LinkAddr, + State: Reachable, + UpdatedAtNanos: clock.NowNanoseconds(), + }, + }, } - reachableEntry, _, err := linkRes.neigh.entry(entry.Addr, "", nil) - if err != nil { - t.Fatalf("linkRes.neigh.entry(%s, '', _, _, nil): %v", entry.Addr, err) + nudDisp.mu.Lock() + diff := cmp.Diff(wantEvents, nudDisp.mu.events) + nudDisp.mu.events = nil + nudDisp.mu.Unlock() + if diff != "" { + t.Fatalf("nud dispatcher events mismatch (-want, +got):\n%s", diff) } - if reachableEntry.Addr != entry.Addr { - t.Fatalf("got entry.Addr = %s, want = %s", reachableEntry.Addr, entry.Addr) + } + + { + gotEntry, _, err := linkRes.neigh.entry(entry.Addr, "", nil) + if err != nil { + t.Fatalf("linkRes.neigh.entry(%s, '', _): %s", entry.Addr, err) } - if reachableEntry.LinkAddr != entry.LinkAddr { - t.Fatalf("got entry.LinkAddr = %s, want = %s", reachableEntry.LinkAddr, entry.LinkAddr) + + wantEntry := NeighborEntry{ + Addr: entry.Addr, + LinkAddr: entry.LinkAddr, + State: Reachable, + UpdatedAtNanos: clock.NowNanoseconds(), } - if reachableEntry.State != Reachable { - t.Fatalf("got entry.State = %s, want = %s", reachableEntry.State.String(), Reachable.String()) + if diff := cmp.Diff(gotEntry, wantEntry); diff != "" { + t.Fatalf("neighbor entry mismatch (-got, +want):\n%s", diff) } - default: - t.Fatalf("expected notification from done channel returned by linkRes.neigh.entry(%s, '', _, _, nil)", entry.Addr) } } } @@ -1674,7 +1566,7 @@ func BenchmarkCacheClear(b *testing.B) { for i := 0; i < cacheSize; i++ { entry, ok := linkRes.entries.entry(i) if !ok { - b.Fatalf("linkRes.entries.entry(%d) not found", i) + b.Fatalf("got linkRes.entries.entry(%d) = _, false, want = true", i) } _, ch, err := linkRes.neigh.entry(entry.Addr, "", func(r LinkResolutionResult) { @@ -1683,13 +1575,13 @@ func BenchmarkCacheClear(b *testing.B) { } }) if _, ok := err.(*tcpip.ErrWouldBlock); !ok { - b.Fatalf("got linkRes.neigh.entry(%s, '', _, _, nil) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{}) + b.Fatalf("got linkRes.neigh.entry(%s, '', _) = %v, want = %s", entry.Addr, err, &tcpip.ErrWouldBlock{}) } select { case <-ch: default: - b.Fatalf("expected notification from done channel returned by linkRes.neigh.entry(%s, '', _, _, nil)", entry.Addr) + b.Fatalf("expected notification from done channel returned by linkRes.neigh.entry(%s, '', _)", entry.Addr) } } diff --git a/pkg/tcpip/stack/neighbor_entry_test.go b/pkg/tcpip/stack/neighbor_entry_test.go index baae7dfe1..bb2b2d705 100644 --- a/pkg/tcpip/stack/neighbor_entry_test.go +++ b/pkg/tcpip/stack/neighbor_entry_test.go @@ -18,13 +18,11 @@ import ( "fmt" "math" "math/rand" - "strings" "sync" "testing" "time" "github.com/google/go-cmp/cmp" - "github.com/google/go-cmp/cmp/cmpopts" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/faketime" "gvisor.dev/gvisor/pkg/tcpip/header" @@ -52,23 +50,6 @@ func runImmediatelyScheduledJobs(clock *faketime.ManualClock) { clock.Advance(immediateDuration) } -// eventDiffOpts are the options passed to cmp.Diff to compare entry events. -// The UpdatedAtNanos field is ignored due to a lack of a deterministic method -// to predict the time that an event will be dispatched. -func eventDiffOpts() []cmp.Option { - return []cmp.Option{ - cmpopts.IgnoreFields(NeighborEntry{}, "UpdatedAtNanos"), - } -} - -// eventDiffOptsWithSort is like eventDiffOpts but also includes an option to -// sort slices of events for cases where ordering must be ignored. -func eventDiffOptsWithSort() []cmp.Option { - return append(eventDiffOpts(), cmpopts.SortSlices(func(a, b testEntryEventInfo) bool { - return strings.Compare(string(a.Entry.Addr), string(b.Entry.Addr)) < 0 - })) -} - // The following unit tests exercise every state transition and verify its // behavior with RFC 4681 and RFC 7048. // diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go index f9323d545..62f7c880e 100644 --- a/pkg/tcpip/stack/nic.go +++ b/pkg/tcpip/stack/nic.go @@ -725,12 +725,12 @@ func (n *nic) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcp n.mu.RUnlock() n.stats.DisabledRx.Packets.Increment() - n.stats.DisabledRx.Bytes.IncrementBy(uint64(pkt.Data.Size())) + n.stats.DisabledRx.Bytes.IncrementBy(uint64(pkt.Data().Size())) return } n.stats.Rx.Packets.Increment() - n.stats.Rx.Bytes.IncrementBy(uint64(pkt.Data.Size())) + n.stats.Rx.Bytes.IncrementBy(uint64(pkt.Data().Size())) networkEndpoint, ok := n.networkEndpoints[protocol] if !ok { @@ -881,7 +881,7 @@ func (n *nic) DeliverTransportError(local, remote tcpip.Address, net tcpip.Netwo // ICMPv4 only guarantees that 8 bytes of the transport protocol will // be present in the payload. We know that the ports are within the // first 8 bytes for all known transport protocols. - transHeader, ok := pkt.Data.PullUp(8) + transHeader, ok := pkt.Data().PullUp(8) if !ok { return } diff --git a/pkg/tcpip/stack/packet_buffer.go b/pkg/tcpip/stack/packet_buffer.go index 4f013b212..8f288675d 100644 --- a/pkg/tcpip/stack/packet_buffer.go +++ b/pkg/tcpip/stack/packet_buffer.go @@ -59,7 +59,7 @@ type PacketBuffer struct { // PacketBuffers. PacketBufferEntry - // Data holds the payload of the packet. + // data holds the payload of the packet. // // For inbound packets, Data is initially the whole packet. Then gets moved to // headers via PacketHeader.Consume, when the packet is being parsed. @@ -69,7 +69,7 @@ type PacketBuffer struct { // // The bytes backing Data are immutable, a.k.a. users shouldn't write to its // backing storage. - Data buffer.VectorisedView + data buffer.VectorisedView // headers stores metadata about each header. headers [numHeaderType]headerInfo @@ -127,7 +127,7 @@ type PacketBuffer struct { // NewPacketBuffer creates a new PacketBuffer with opts. func NewPacketBuffer(opts PacketBufferOptions) *PacketBuffer { pk := &PacketBuffer{ - Data: opts.Data, + data: opts.Data, } if opts.ReserveHeaderBytes != 0 { pk.header = buffer.NewPrependable(opts.ReserveHeaderBytes) @@ -184,13 +184,18 @@ func (pk *PacketBuffer) HeaderSize() int { // Size returns the size of packet in bytes. func (pk *PacketBuffer) Size() int { - return pk.HeaderSize() + pk.Data.Size() + return pk.HeaderSize() + pk.data.Size() } // MemSize returns the estimation size of the pk in memory, including backing // buffer data. func (pk *PacketBuffer) MemSize() int { - return pk.HeaderSize() + pk.Data.MemSize() + packetBufferStructSize + return pk.HeaderSize() + pk.data.MemSize() + packetBufferStructSize +} + +// Data returns the handle to data portion of pk. +func (pk *PacketBuffer) Data() PacketData { + return PacketData{pk: pk} } // Views returns the underlying storage of the whole packet. @@ -204,7 +209,7 @@ func (pk *PacketBuffer) Views() []buffer.View { } } - dataViews := pk.Data.Views() + dataViews := pk.data.Views() var vs []buffer.View if useHeader { @@ -242,11 +247,11 @@ func (pk *PacketBuffer) consume(typ headerType, size int) (v buffer.View, consum if h.buf != nil { panic(fmt.Sprintf("consume must not be called twice: type %s", typ)) } - v, ok := pk.Data.PullUp(size) + v, ok := pk.data.PullUp(size) if !ok { return } - pk.Data.TrimFront(size) + pk.data.TrimFront(size) h.buf = v return h.buf, true } @@ -258,7 +263,7 @@ func (pk *PacketBuffer) consume(typ headerType, size int) (v buffer.View, consum func (pk *PacketBuffer) Clone() *PacketBuffer { return &PacketBuffer{ PacketBufferEntry: pk.PacketBufferEntry, - Data: pk.Data.Clone(nil), + data: pk.data.Clone(nil), headers: pk.headers, header: pk.header, Hash: pk.Hash, @@ -339,13 +344,234 @@ func (h PacketHeader) Consume(size int) (v buffer.View, consumed bool) { return h.pk.consume(h.typ, size) } +// PacketData represents the data portion of a PacketBuffer. +type PacketData struct { + pk *PacketBuffer +} + +// PullUp returns a contiguous view of size bytes from the beginning of d. +// Callers should not write to or keep the view for later use. +func (d PacketData) PullUp(size int) (buffer.View, bool) { + return d.pk.data.PullUp(size) +} + +// TrimFront removes count from the beginning of d. It panics if count > +// d.Size(). +func (d PacketData) TrimFront(count int) { + d.pk.data.TrimFront(count) +} + +// CapLength reduces d to at most length bytes. +func (d PacketData) CapLength(length int) { + d.pk.data.CapLength(length) +} + +// Views returns the underlying storage of d in a slice of Views. Caller should +// not modify the returned slice. +func (d PacketData) Views() []buffer.View { + return d.pk.data.Views() +} + +// AppendView appends v into d, taking the ownership of v. +func (d PacketData) AppendView(v buffer.View) { + d.pk.data.AppendView(v) +} + +// ReadFromData moves at most count bytes from the beginning of srcData to the +// end of d and returns the number of bytes moved. +func (d PacketData) ReadFromData(srcData PacketData, count int) int { + return srcData.pk.data.ReadToVV(&d.pk.data, count) +} + +// ReadFromVV moves at most count bytes from the beginning of srcVV to the end +// of d and returns the number of bytes moved. +func (d PacketData) ReadFromVV(srcVV *buffer.VectorisedView, count int) int { + return srcVV.ReadToVV(&d.pk.data, count) +} + +// Size returns the number of bytes in the data payload of the packet. +func (d PacketData) Size() int { + return d.pk.data.Size() +} + +// AsRange returns a Range representing the current data payload of the packet. +func (d PacketData) AsRange() Range { + return Range{ + pk: d.pk, + offset: d.pk.HeaderSize(), + length: d.Size(), + } +} + +// ExtractVV returns a VectorisedView of d. This method has the semantic to +// destruct the underlying packet, hence the packet cannot be used again. +// +// This method exists for compatibility between PacketBuffer and VectorisedView. +// It may be removed later and should be used with care. +func (d PacketData) ExtractVV() buffer.VectorisedView { + return d.pk.data +} + +// Replace replaces the data portion of the packet with vv, taking the ownership +// of vv. +// +// This method exists for compatibility between PacketBuffer and VectorisedView. +// It may be removed later and should be used with care. +func (d PacketData) Replace(vv buffer.VectorisedView) { + d.pk.data = vv +} + +// Range represents a contiguous subportion of a PacketBuffer. +type Range struct { + pk *PacketBuffer + offset int + length int +} + +// Size returns the number of bytes in r. +func (r Range) Size() int { + return r.length +} + +// SubRange returns a new Range starting at off bytes of r. It returns an empty +// range if off is out-of-bounds. +func (r Range) SubRange(off int) Range { + if off > r.length { + return Range{pk: r.pk} + } + return Range{ + pk: r.pk, + offset: r.offset + off, + length: r.length - off, + } +} + +// Capped returns a new Range with the same starting point of r and length +// capped at max. +func (r Range) Capped(max int) Range { + if r.length <= max { + return r + } + return Range{ + pk: r.pk, + offset: r.offset, + length: max, + } +} + +// AsView returns the backing storage of r if possible. It will allocate a new +// View if r spans multiple pieces internally. Caller should not write to the +// returned View in any way. +func (r Range) AsView() buffer.View { + var allocated bool + var v buffer.View + r.iterate(func(b []byte) { + if v == nil { + // v has not been assigned, allowing first view to be returned. + v = b + } else { + // v has been assigned. This range spans more than a view, a new view + // needs to be allocated. + if !allocated { + allocated = true + all := make([]byte, 0, r.length) + all = append(all, v...) + v = all + } + v = append(v, b...) + } + }) + return v +} + +// ToOwnedView returns a owned copy of data in r. +func (r Range) ToOwnedView() buffer.View { + if r.length == 0 { + return nil + } + all := make([]byte, 0, r.length) + r.iterate(func(b []byte) { + all = append(all, b...) + }) + return all +} + +// Checksum calculates the RFC 1071 checksum for the underlying bytes of r. +func (r Range) Checksum() uint16 { + var c header.Checksumer + r.iterate(c.Add) + return c.Checksum() +} + +// iterate calls fn for each piece in r. fn is always called with a non-empty +// slice. +func (r Range) iterate(fn func([]byte)) { + w := window{ + offset: r.offset, + length: r.length, + } + // Header portion. + for i := range r.pk.headers { + if b := w.process(r.pk.headers[i].buf); len(b) > 0 { + fn(b) + } + if w.isDone() { + break + } + } + // Data portion. + if !w.isDone() { + for _, v := range r.pk.data.Views() { + if b := w.process(v); len(b) > 0 { + fn(b) + } + if w.isDone() { + break + } + } + } +} + +// window represents contiguous region of byte stream. User would call process() +// to input bytes, and obtain a subslice that is inside the window. +type window struct { + offset int + length int +} + +// isDone returns true if the window has passed and further process() calls will +// always return an empty slice. This can be used to end processing early. +func (w *window) isDone() bool { + return w.length == 0 +} + +// process feeds b in and returns a subslice that is inside the window. The +// returned slice will be a subslice of b, and it does not keep b after method +// returns. This method may return an empty slice if nothing in b is inside the +// window. +func (w *window) process(b []byte) (inWindow []byte) { + if w.offset >= len(b) { + w.offset -= len(b) + return nil + } + if w.offset > 0 { + b = b[w.offset:] + w.offset = 0 + } + if w.length < len(b) { + b = b[:w.length] + } + w.length -= len(b) + return b +} + // PayloadSince returns packet payload starting from and including a particular // header. // // The returned View is owned by the caller - its backing buffer is separate // from the packet header's underlying packet buffer. func PayloadSince(h PacketHeader) buffer.View { - size := h.pk.Data.Size() + size := h.pk.data.Size() for _, hinfo := range h.pk.headers[h.typ:] { size += len(hinfo.buf) } @@ -356,7 +582,7 @@ func PayloadSince(h PacketHeader) buffer.View { v = append(v, hinfo.buf...) } - for _, view := range h.pk.Data.Views() { + for _, view := range h.pk.data.Views() { v = append(v, view...) } diff --git a/pkg/tcpip/stack/packet_buffer_test.go b/pkg/tcpip/stack/packet_buffer_test.go index c6fa8da5f..6728370c3 100644 --- a/pkg/tcpip/stack/packet_buffer_test.go +++ b/pkg/tcpip/stack/packet_buffer_test.go @@ -15,9 +15,11 @@ package stack import ( "bytes" + "fmt" "testing" "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/header" ) func TestPacketHeaderPush(t *testing.T) { @@ -110,7 +112,7 @@ func TestPacketHeaderPush(t *testing.T) { if got, want := pk.Size(), allHdrSize+len(test.data); got != want { t.Errorf("After pk.Size() = %d, want %d", got, want) } - checkViewEqual(t, "After pk.Data.Views()", concatViews(pk.Data.Views()...), test.data) + checkData(t, pk, test.data) checkViewEqual(t, "After pk.Views()", concatViews(pk.Views()...), concatViews(test.link, test.network, test.transport, test.data)) // Check the after values for each header. @@ -204,7 +206,7 @@ func TestPacketHeaderConsume(t *testing.T) { transport = test.data[test.link+test.network:][:test.transport] payload = test.data[allHdrSize:] ) - checkViewEqual(t, "After pk.Data.Views()", concatViews(pk.Data.Views()...), payload) + checkData(t, pk, payload) checkViewEqual(t, "After pk.Views()", concatViews(pk.Views()...), test.data) // Check the after values for each header. checkPacketHeader(t, "After pk.LinkHeader", pk.LinkHeader(), link) @@ -340,6 +342,158 @@ func TestPacketHeaderConsumeThenPushPanics(t *testing.T) { } } +func TestPacketBufferData(t *testing.T) { + for _, tc := range []struct { + name string + makePkt func(*testing.T) *PacketBuffer + data string + }{ + { + name: "inbound packet", + makePkt: func(*testing.T) *PacketBuffer { + pkt := NewPacketBuffer(PacketBufferOptions{ + Data: vv("aabbbbccccccDATA"), + }) + pkt.LinkHeader().Consume(2) + pkt.NetworkHeader().Consume(4) + pkt.TransportHeader().Consume(6) + return pkt + }, + data: "DATA", + }, + { + name: "outbound packet", + makePkt: func(*testing.T) *PacketBuffer { + pkt := NewPacketBuffer(PacketBufferOptions{ + ReserveHeaderBytes: 12, + Data: vv("DATA"), + }) + copy(pkt.TransportHeader().Push(6), []byte("cccccc")) + copy(pkt.NetworkHeader().Push(4), []byte("bbbb")) + copy(pkt.LinkHeader().Push(2), []byte("aa")) + return pkt + }, + data: "DATA", + }, + } { + t.Run(tc.name, func(t *testing.T) { + // PullUp + for _, n := range []int{1, len(tc.data)} { + t.Run(fmt.Sprintf("PullUp%d", n), func(t *testing.T) { + pkt := tc.makePkt(t) + v, ok := pkt.Data().PullUp(n) + wantV := []byte(tc.data)[:n] + if !ok || !bytes.Equal(v, wantV) { + t.Errorf("pkt.Data().PullUp(%d) = %q, %t; want %q, true", n, v, ok, wantV) + } + }) + } + t.Run("PullUpOutOfBounds", func(t *testing.T) { + n := len(tc.data) + 1 + pkt := tc.makePkt(t) + v, ok := pkt.Data().PullUp(n) + if ok || v != nil { + t.Errorf("pkt.Data().PullUp(%d) = %q, %t; want nil, false", n, v, ok) + } + }) + + // TrimFront + for _, n := range []int{1, len(tc.data)} { + t.Run(fmt.Sprintf("TrimFront%d", n), func(t *testing.T) { + pkt := tc.makePkt(t) + pkt.Data().TrimFront(n) + + checkData(t, pkt, []byte(tc.data)[n:]) + }) + } + + // CapLength + for _, n := range []int{0, 1, len(tc.data)} { + t.Run(fmt.Sprintf("CapLength%d", n), func(t *testing.T) { + pkt := tc.makePkt(t) + pkt.Data().CapLength(n) + + want := []byte(tc.data) + if n < len(want) { + want = want[:n] + } + checkData(t, pkt, want) + }) + } + + // Views + t.Run("Views", func(t *testing.T) { + pkt := tc.makePkt(t) + checkData(t, pkt, []byte(tc.data)) + }) + + // AppendView + t.Run("AppendView", func(t *testing.T) { + s := "APPEND" + + pkt := tc.makePkt(t) + pkt.Data().AppendView(buffer.View(s)) + + checkData(t, pkt, []byte(tc.data+s)) + }) + + // ReadFromData/VV + for _, n := range []int{0, 1, 2, 7, 10, 14, 20} { + t.Run(fmt.Sprintf("ReadFromData%d", n), func(t *testing.T) { + s := "TO READ" + otherPkt := NewPacketBuffer(PacketBufferOptions{ + Data: vv(s, s), + }) + s += s + + pkt := tc.makePkt(t) + pkt.Data().ReadFromData(otherPkt.Data(), n) + + if n < len(s) { + s = s[:n] + } + checkData(t, pkt, []byte(tc.data+s)) + }) + t.Run(fmt.Sprintf("ReadFromVV%d", n), func(t *testing.T) { + s := "TO READ" + srcVV := vv(s, s) + s += s + + pkt := tc.makePkt(t) + pkt.Data().ReadFromVV(&srcVV, n) + + if n < len(s) { + s = s[:n] + } + checkData(t, pkt, []byte(tc.data+s)) + }) + } + + // ExtractVV + t.Run("ExtractVV", func(t *testing.T) { + pkt := tc.makePkt(t) + extractedVV := pkt.Data().ExtractVV() + + got := extractedVV.ToOwnedView() + want := []byte(tc.data) + if !bytes.Equal(got, want) { + t.Errorf("pkt.Data().ExtractVV().ToOwnedView() = %q, want %q", got, want) + } + }) + + // Replace + t.Run("Replace", func(t *testing.T) { + s := "REPLACED" + + pkt := tc.makePkt(t) + pkt.Data().Replace(vv(s)) + + checkData(t, pkt, []byte(s)) + }) + }) + } +} + func checkInitialPacketBuffer(t *testing.T, pk *PacketBuffer, opts PacketBufferOptions) { t.Helper() reserved := opts.ReserveHeaderBytes @@ -356,7 +510,7 @@ func checkInitialPacketBuffer(t *testing.T, pk *PacketBuffer, opts PacketBufferO if got, want := pk.Size(), len(data); got != want { t.Errorf("Initial pk.Size() = %d, want %d", got, want) } - checkViewEqual(t, "Initial pk.Data.Views()", concatViews(pk.Data.Views()...), data) + checkData(t, pk, data) checkViewEqual(t, "Initial pk.Views()", concatViews(pk.Views()...), data) // Check the initial values for each header. checkPacketHeader(t, "Initial pk.LinkHeader", pk.LinkHeader(), nil) @@ -383,6 +537,70 @@ func checkViewEqual(t *testing.T, what string, got, want buffer.View) { } } +func checkData(t *testing.T, pkt *PacketBuffer, want []byte) { + t.Helper() + if got := concatViews(pkt.Data().Views()...); !bytes.Equal(got, want) { + t.Errorf("pkt.Data().Views() = %x, want %x", got, want) + } + if got := pkt.Data().Size(); got != len(want) { + t.Errorf("pkt.Data().Size() = %d, want %d", got, len(want)) + } + + t.Run("AsRange", func(t *testing.T) { + // Full range + checkRange(t, pkt.Data().AsRange(), want) + + // SubRange + for _, off := range []int{0, 1, len(want), len(want) + 1} { + t.Run(fmt.Sprintf("SubRange%d", off), func(t *testing.T) { + // Empty when off is greater than the size of range. + var sub []byte + if off < len(want) { + sub = want[off:] + } + checkRange(t, pkt.Data().AsRange().SubRange(off), sub) + }) + } + + // Capped + for _, n := range []int{0, 1, len(want), len(want) + 1} { + t.Run(fmt.Sprintf("Capped%d", n), func(t *testing.T) { + sub := want + if n < len(sub) { + sub = sub[:n] + } + checkRange(t, pkt.Data().AsRange().Capped(n), sub) + }) + } + }) +} + +func checkRange(t *testing.T, r Range, data []byte) { + if got, want := r.Size(), len(data); got != want { + t.Errorf("r.Size() = %d, want %d", got, want) + } + if got := r.AsView(); !bytes.Equal(got, data) { + t.Errorf("r.AsView() = %x, want %x", got, data) + } + if got := r.ToOwnedView(); !bytes.Equal(got, data) { + t.Errorf("r.ToOwnedView() = %x, want %x", got, data) + } + if got, want := r.Checksum(), header.Checksum(data, 0 /* initial */); got != want { + t.Errorf("r.Checksum() = %x, want %x", got, want) + } +} + +func vv(pieces ...string) buffer.VectorisedView { + var views []buffer.View + var size int + for _, p := range pieces { + v := buffer.View([]byte(p)) + size += len(v) + views = append(views, v) + } + return buffer.NewVectorisedView(size, views) +} + func makeView(size int) buffer.View { b := byte(size) return bytes.Repeat([]byte{b}, size) diff --git a/pkg/tcpip/stack/registration.go b/pkg/tcpip/stack/registration.go index 43e9e4beb..85f0f471a 100644 --- a/pkg/tcpip/stack/registration.go +++ b/pkg/tcpip/stack/registration.go @@ -852,18 +852,46 @@ type InjectableLinkEndpoint interface { InjectOutbound(dest tcpip.Address, packet []byte) tcpip.Error } -// DADResult is the result of a duplicate address detection process. -type DADResult struct { - // Resolved is true when DAD completed without detecting a duplicate address - // on the link. - // - // Ignored when Err is non-nil. - Resolved bool +// DADResult is a marker interface for the result of a duplicate address +// detection process. +type DADResult interface { + isDADResult() +} + +var _ DADResult = (*DADSucceeded)(nil) + +// DADSucceeded indicates DAD completed without finding any duplicate addresses. +type DADSucceeded struct{} - // Err is an error encountered while performing DAD. +func (*DADSucceeded) isDADResult() {} + +var _ DADResult = (*DADError)(nil) + +// DADError indicates DAD hit an error. +type DADError struct { Err tcpip.Error } +func (*DADError) isDADResult() {} + +var _ DADResult = (*DADAborted)(nil) + +// DADAborted indicates DAD was aborted. +type DADAborted struct{} + +func (*DADAborted) isDADResult() {} + +var _ DADResult = (*DADDupAddrDetected)(nil) + +// DADDupAddrDetected indicates DAD detected a duplicate address. +type DADDupAddrDetected struct { + // HolderLinkAddress is the link address of the node that holds the duplicate + // address. + HolderLinkAddress tcpip.LinkAddress +} + +func (*DADDupAddrDetected) isDADResult() {} + // DADCompletionHandler is a handler for DAD completion. type DADCompletionHandler func(DADResult) diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go index de94ddfda..53370c354 100644 --- a/pkg/tcpip/stack/stack.go +++ b/pkg/tcpip/stack/stack.go @@ -813,6 +813,18 @@ func (s *Stack) Forwarding(protocolNum tcpip.NetworkProtocolNumber) bool { return forwardingProtocol.Forwarding() } +// PortRange returns the UDP and TCP inclusive range of ephemeral ports used in +// both IPv4 and IPv6. +func (s *Stack) PortRange() (uint16, uint16) { + return s.PortManager.PortRange() +} + +// SetPortRange sets the UDP and TCP IPv4 and IPv6 ephemeral port range +// (inclusive). +func (s *Stack) SetPortRange(start uint16, end uint16) tcpip.Error { + return s.PortManager.SetPortRange(start, end) +} + // SetRouteTable assigns the route table to be used by this stack. It // specifies which NIC to use for given destination address ranges. // diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go index 8e39e828c..880219007 100644 --- a/pkg/tcpip/stack/stack_test.go +++ b/pkg/tcpip/stack/stack_test.go @@ -137,11 +137,11 @@ func (f *fakeNetworkEndpoint) HandlePacket(pkt *stack.PacketBuffer) { // Handle control packets. if netHdr[protocolNumberOffset] == uint8(fakeControlProtocol) { - nb, ok := pkt.Data.PullUp(fakeNetHeaderLen) + nb, ok := pkt.Data().PullUp(fakeNetHeaderLen) if !ok { return } - pkt.Data.TrimFront(fakeNetHeaderLen) + pkt.Data().TrimFront(fakeNetHeaderLen) f.dispatcher.DeliverTransportError( tcpip.Address(nb[srcAddrOffset:srcAddrOffset+1]), tcpip.Address(nb[dstAddrOffset:dstAddrOffset+1]), @@ -2605,7 +2605,7 @@ func TestNICAutoGenAddrDoesDAD(t *testing.T) { // means something is wrong. t.Fatal("timed out waiting for DAD resolution") case e := <-ndpDisp.dadC: - if diff := checkDADEvent(e, nicID, linkLocalAddr, true, nil); diff != "" { + if diff := checkDADEvent(e, nicID, linkLocalAddr, &stack.DADSucceeded{}); diff != "" { t.Errorf("dad event mismatch (-want +got):\n%s", diff) } } @@ -3289,7 +3289,7 @@ func TestDoDADWhenNICEnabled(t *testing.T) { case <-time.After(dadTransmits*retransmitTimer + defaultAsyncPositiveEventTimeout): t.Fatal("timed out waiting for DAD resolution") case e := <-ndpDisp.dadC: - if diff := checkDADEvent(e, nicID, addr.AddressWithPrefix.Address, true, nil); diff != "" { + if diff := checkDADEvent(e, nicID, addr.AddressWithPrefix.Address, &stack.DADSucceeded{}); diff != "" { t.Errorf("dad event mismatch (-want +got):\n%s", diff) } } @@ -4294,7 +4294,7 @@ func TestWritePacketToRemote(t *testing.T) { if pkt.Route.RemoteLinkAddress != linkAddr2 { t.Fatalf("pkt.Route.RemoteAddress = %s, want %s", pkt.Route.RemoteLinkAddress, linkAddr2) } - if diff := cmp.Diff(pkt.Pkt.Data.ToView(), buffer.View(test.payload)); diff != "" { + if diff := cmp.Diff(pkt.Pkt.Data().AsRange().ToOwnedView(), buffer.View(test.payload)); diff != "" { t.Errorf("pkt.Pkt.Data mismatch (-want +got):\n%s", diff) } }) diff --git a/pkg/tcpip/stack/transport_demuxer.go b/pkg/tcpip/stack/transport_demuxer.go index e799f9290..e188efccb 100644 --- a/pkg/tcpip/stack/transport_demuxer.go +++ b/pkg/tcpip/stack/transport_demuxer.go @@ -359,7 +359,7 @@ func selectEndpoint(id TransportEndpointID, mpep *multiPortEndpoint, seed uint32 return mpep.endpoints[0] } - if mpep.flags.IntersectionRefs().ToFlags().Effective().MostRecent { + if mpep.flags.SharedFlags().ToFlags().Effective().MostRecent { return mpep.endpoints[len(mpep.endpoints)-1] } @@ -410,7 +410,7 @@ func (ep *multiPortEndpoint) singleRegisterEndpoint(t TransportEndpoint, flags p if len(ep.endpoints) != 0 { // If it was previously bound, we need to check if we can bind again. - if ep.flags.TotalRefs() > 0 && bits&ep.flags.IntersectionRefs() == 0 { + if ep.flags.TotalRefs() > 0 && bits&ep.flags.SharedFlags() == 0 { return &tcpip.ErrPortInUse{} } } @@ -429,7 +429,7 @@ func (ep *multiPortEndpoint) singleCheckEndpoint(flags ports.Flags) tcpip.Error if len(ep.endpoints) != 0 { // If it was previously bound, we need to check if we can bind again. - if ep.flags.TotalRefs() > 0 && bits&ep.flags.IntersectionRefs() == 0 { + if ep.flags.TotalRefs() > 0 && bits&ep.flags.SharedFlags() == 0 { return &tcpip.ErrPortInUse{} } } |