diff options
Diffstat (limited to 'pkg/tcpip/stack')
-rw-r--r-- | pkg/tcpip/stack/BUILD | 1 | ||||
-rw-r--r-- | pkg/tcpip/stack/forwarding_test.go | 10 | ||||
-rw-r--r-- | pkg/tcpip/stack/neighbor_cache_test.go | 64 | ||||
-rw-r--r-- | pkg/tcpip/stack/neighbor_entry.go | 6 | ||||
-rw-r--r-- | pkg/tcpip/stack/neighbor_entry_test.go | 11 | ||||
-rw-r--r-- | pkg/tcpip/stack/nic.go | 59 | ||||
-rw-r--r-- | pkg/tcpip/stack/nic_stats.go | 74 | ||||
-rw-r--r-- | pkg/tcpip/stack/nic_test.go | 32 | ||||
-rw-r--r-- | pkg/tcpip/stack/nud.go | 71 | ||||
-rw-r--r-- | pkg/tcpip/stack/nud_test.go | 20 | ||||
-rw-r--r-- | pkg/tcpip/stack/packet_buffer.go | 4 | ||||
-rw-r--r-- | pkg/tcpip/stack/stack.go | 11 | ||||
-rw-r--r-- | pkg/tcpip/stack/stack_test.go | 173 |
13 files changed, 307 insertions, 229 deletions
diff --git a/pkg/tcpip/stack/BUILD b/pkg/tcpip/stack/BUILD index 84aa6a9e4..395ff9a07 100644 --- a/pkg/tcpip/stack/BUILD +++ b/pkg/tcpip/stack/BUILD @@ -56,6 +56,7 @@ go_library( "neighbor_entry_list.go", "neighborstate_string.go", "nic.go", + "nic_stats.go", "nud.go", "packet_buffer.go", "packet_buffer_list.go", diff --git a/pkg/tcpip/stack/forwarding_test.go b/pkg/tcpip/stack/forwarding_test.go index 7107d598d..d971db010 100644 --- a/pkg/tcpip/stack/forwarding_test.go +++ b/pkg/tcpip/stack/forwarding_test.go @@ -114,10 +114,6 @@ func (f *fwdTestNetworkEndpoint) MaxHeaderLength() uint16 { return f.nic.MaxHeaderLength() + fwdTestNetHeaderLen } -func (*fwdTestNetworkEndpoint) PseudoHeaderChecksum(protocol tcpip.TransportProtocolNumber, dstAddr tcpip.Address) uint16 { - return 0 -} - func (f *fwdTestNetworkEndpoint) NetworkProtocolNumber() tcpip.NetworkProtocolNumber { return f.proto.Number() } @@ -134,7 +130,7 @@ func (f *fwdTestNetworkEndpoint) WritePacket(r *Route, params NetworkHeaderParam } // WritePackets implements LinkEndpoint.WritePackets. -func (*fwdTestNetworkEndpoint) WritePackets(r *Route, pkts PacketBufferList, params NetworkHeaderParams) (int, tcpip.Error) { +func (*fwdTestNetworkEndpoint) WritePackets(*Route, PacketBufferList, NetworkHeaderParams) (int, tcpip.Error) { panic("not implemented") } @@ -319,7 +315,7 @@ func (e *fwdTestLinkEndpoint) LinkAddress() tcpip.LinkAddress { return e.linkAddr } -func (e fwdTestLinkEndpoint) WritePacket(r RouteInfo, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) tcpip.Error { +func (e fwdTestLinkEndpoint) WritePacket(r RouteInfo, _ tcpip.NetworkProtocolNumber, pkt *PacketBuffer) tcpip.Error { p := fwdTestPacketInfo{ RemoteLinkAddress: r.RemoteLinkAddress, LocalLinkAddress: r.LocalLinkAddress, @@ -354,7 +350,7 @@ func (*fwdTestLinkEndpoint) ARPHardwareType() header.ARPHardwareType { } // AddHeader implements stack.LinkEndpoint.AddHeader. -func (e *fwdTestLinkEndpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) { +func (e *fwdTestLinkEndpoint) AddHeader(tcpip.LinkAddress, tcpip.LinkAddress, tcpip.NetworkProtocolNumber, *PacketBuffer) { panic("not implemented") } diff --git a/pkg/tcpip/stack/neighbor_cache_test.go b/pkg/tcpip/stack/neighbor_cache_test.go index 9821a18d3..90881169d 100644 --- a/pkg/tcpip/stack/neighbor_cache_test.go +++ b/pkg/tcpip/stack/neighbor_cache_test.go @@ -15,8 +15,6 @@ package stack import ( - "bytes" - "encoding/binary" "fmt" "math" "math/rand" @@ -48,9 +46,6 @@ const ( // be sent to all nodes. testEntryBroadcastAddr = tcpip.Address("broadcast") - // testEntryLocalAddr is the source address of neighbor probes. - testEntryLocalAddr = tcpip.Address("local_addr") - // testEntryBroadcastLinkAddr is a special link address sent back to // multicast neighbor probes. testEntryBroadcastLinkAddr = tcpip.LinkAddress("mac_broadcast") @@ -95,7 +90,7 @@ func newTestNeighborResolver(nudDisp NUDDispatcher, config NUDConfigurations, cl randomGenerator: rng, }, id: 1, - stats: makeNICStats(), + stats: makeNICStats(tcpip.NICStats{}.FillIn()), }, linkRes) return linkRes } @@ -106,20 +101,24 @@ type testEntryStore struct { entriesMap map[tcpip.Address]NeighborEntry } -func toAddress(i int) tcpip.Address { - buf := new(bytes.Buffer) - binary.Write(buf, binary.BigEndian, uint8(1)) - binary.Write(buf, binary.BigEndian, uint8(0)) - binary.Write(buf, binary.BigEndian, uint16(i)) - return tcpip.Address(buf.String()) +func toAddress(i uint16) tcpip.Address { + return tcpip.Address([]byte{ + 1, + 0, + byte(i >> 8), + byte(i), + }) } -func toLinkAddress(i int) tcpip.LinkAddress { - buf := new(bytes.Buffer) - binary.Write(buf, binary.BigEndian, uint8(1)) - binary.Write(buf, binary.BigEndian, uint8(0)) - binary.Write(buf, binary.BigEndian, uint32(i)) - return tcpip.LinkAddress(buf.String()) +func toLinkAddress(i uint16) tcpip.LinkAddress { + return tcpip.LinkAddress([]byte{ + 1, + 0, + 0, + 0, + byte(i >> 8), + byte(i), + }) } // newTestEntryStore returns a testEntryStore pre-populated with entries. @@ -127,7 +126,7 @@ func newTestEntryStore() *testEntryStore { store := &testEntryStore{ entriesMap: make(map[tcpip.Address]NeighborEntry), } - for i := 0; i < entryStoreSize; i++ { + for i := uint16(0); i < entryStoreSize; i++ { addr := toAddress(i) linkAddr := toLinkAddress(i) @@ -140,15 +139,15 @@ func newTestEntryStore() *testEntryStore { } // size returns the number of entries in the store. -func (s *testEntryStore) size() int { +func (s *testEntryStore) size() uint16 { s.mu.RLock() defer s.mu.RUnlock() - return len(s.entriesMap) + return uint16(len(s.entriesMap)) } // entry returns the entry at index i. Returns an empty entry and false if i is // out of bounds. -func (s *testEntryStore) entry(i int) (NeighborEntry, bool) { +func (s *testEntryStore) entry(i uint16) (NeighborEntry, bool) { return s.entryByAddr(toAddress(i)) } @@ -166,7 +165,7 @@ func (s *testEntryStore) entries() []NeighborEntry { entries := make([]NeighborEntry, 0, len(s.entriesMap)) s.mu.RLock() defer s.mu.RUnlock() - for i := 0; i < entryStoreSize; i++ { + for i := uint16(0); i < entryStoreSize; i++ { addr := toAddress(i) if entry, ok := s.entriesMap[addr]; ok { entries = append(entries, entry) @@ -176,7 +175,7 @@ func (s *testEntryStore) entries() []NeighborEntry { } // set modifies the link addresses of an entry. -func (s *testEntryStore) set(i int, linkAddr tcpip.LinkAddress) { +func (s *testEntryStore) set(i uint16, linkAddr tcpip.LinkAddress) { addr := toAddress(i) s.mu.Lock() defer s.mu.Unlock() @@ -236,13 +235,6 @@ func (*testNeighborResolver) LinkAddressProtocol() tcpip.NetworkProtocolNumber { return 0 } -type entryEvent struct { - nicID tcpip.NICID - address tcpip.Address - linkAddr tcpip.LinkAddress - state NeighborState -} - func TestNeighborCacheGetConfig(t *testing.T) { nudDisp := testNUDDispatcher{} c := DefaultNUDConfigurations() @@ -461,7 +453,7 @@ func newTestContext(c NUDConfigurations) testContext { } type overflowOptions struct { - startAtEntryIndex int + startAtEntryIndex uint16 wantStaticEntries []NeighborEntry } @@ -1068,7 +1060,7 @@ func TestNeighborCacheKeepFrequentlyUsed(t *testing.T) { // periodically refreshes the frequently used entry. // Fill the neighbor cache to capacity - for i := 0; i < neighborCacheSize; i++ { + for i := uint16(0); i < neighborCacheSize; i++ { entry, ok := linkRes.entries.entry(i) if !ok { t.Fatalf("got linkRes.entries.entry(%d) = _, false, want = true", i) @@ -1084,7 +1076,7 @@ func TestNeighborCacheKeepFrequentlyUsed(t *testing.T) { } // Keep adding more entries - for i := neighborCacheSize; i < linkRes.entries.size(); i++ { + for i := uint16(neighborCacheSize); i < linkRes.entries.size(); i++ { // Periodically refresh the frequently used entry if i%(neighborCacheSize/2) == 0 { if _, _, err := linkRes.neigh.entry(frequentlyUsedEntry.Addr, "", nil); err != nil { @@ -1561,9 +1553,9 @@ func BenchmarkCacheClear(b *testing.B) { linkRes.delay = 0 // Clear for every possible size of the cache - for cacheSize := 0; cacheSize < neighborCacheSize; cacheSize++ { + for cacheSize := uint16(0); cacheSize < neighborCacheSize; cacheSize++ { // Fill the neighbor cache to capacity. - for i := 0; i < cacheSize; i++ { + for i := uint16(0); i < cacheSize; i++ { entry, ok := linkRes.entries.entry(i) if !ok { b.Fatalf("got linkRes.entries.entry(%d) = _, false, want = true", i) diff --git a/pkg/tcpip/stack/neighbor_entry.go b/pkg/tcpip/stack/neighbor_entry.go index 6d95e1664..463d017fc 100644 --- a/pkg/tcpip/stack/neighbor_entry.go +++ b/pkg/tcpip/stack/neighbor_entry.go @@ -307,7 +307,7 @@ func (e *neighborEntry) setStateLocked(next NeighborState) { // a shared lock. e.mu.timer = timer{ done: &done, - timer: e.cache.nic.stack.Clock().AfterFunc(0, func() { + timer: e.cache.nic.stack.Clock().AfterFunc(immediateDuration, func() { var err tcpip.Error = &tcpip.ErrTimeout{} if remaining != 0 { err = e.cache.linkRes.LinkAddressRequest(addr, "" /* localAddr */, linkAddr) @@ -361,7 +361,7 @@ func (e *neighborEntry) handlePacketQueuedLocked(localAddr tcpip.Address) { e.dispatchAddEventLocked() case Unreachable: e.dispatchChangeEventLocked() - e.cache.nic.stats.Neighbor.UnreachableEntryLookups.Increment() + e.cache.nic.stats.neighbor.unreachableEntryLookups.Increment() } config := e.nudState.Config() @@ -378,7 +378,7 @@ func (e *neighborEntry) handlePacketQueuedLocked(localAddr tcpip.Address) { // a shared lock. e.mu.timer = timer{ done: &done, - timer: e.cache.nic.stack.Clock().AfterFunc(0, func() { + timer: e.cache.nic.stack.Clock().AfterFunc(immediateDuration, func() { var err tcpip.Error = &tcpip.ErrTimeout{} if remaining != 0 { // As per RFC 4861 section 7.2.2: diff --git a/pkg/tcpip/stack/neighbor_entry_test.go b/pkg/tcpip/stack/neighbor_entry_test.go index 1d39ee73d..c2a291244 100644 --- a/pkg/tcpip/stack/neighbor_entry_test.go +++ b/pkg/tcpip/stack/neighbor_entry_test.go @@ -36,11 +36,6 @@ const ( entryTestLinkAddr1 = tcpip.LinkAddress("\x0a\x00\x00\x00\x00\x01") entryTestLinkAddr2 = tcpip.LinkAddress("\x0a\x00\x00\x00\x00\x02") - - // entryTestNetDefaultMTU is the MTU, in bytes, used throughout the tests, - // except where another value is explicitly used. It is chosen to match the - // MTU of loopback interfaces on Linux systems. - entryTestNetDefaultMTU = 65536 ) var ( @@ -196,13 +191,13 @@ func (r *entryTestLinkResolver) LinkAddressRequest(targetAddr, localAddr tcpip.A // ResolveStaticAddress attempts to resolve address without sending requests. // It either resolves the name immediately or returns the empty LinkAddress. -func (r *entryTestLinkResolver) ResolveStaticAddress(addr tcpip.Address) (tcpip.LinkAddress, bool) { +func (*entryTestLinkResolver) ResolveStaticAddress(tcpip.Address) (tcpip.LinkAddress, bool) { return "", false } // LinkAddressProtocol returns the network protocol of the addresses this // resolver can resolve. -func (r *entryTestLinkResolver) LinkAddressProtocol() tcpip.NetworkProtocolNumber { +func (*entryTestLinkResolver) LinkAddressProtocol() tcpip.NetworkProtocolNumber { return entryTestNetNumber } @@ -219,7 +214,7 @@ func entryTestSetup(c NUDConfigurations) (*neighborEntry, *testNUDDispatcher, *e nudConfigs: c, randomGenerator: rand.New(rand.NewSource(time.Now().UnixNano())), }, - stats: makeNICStats(), + stats: makeNICStats(tcpip.NICStats{}.FillIn()), } netEP := (&testIPv6Protocol{}).NewEndpoint(&nic, nil) nic.networkEndpoints = map[tcpip.NetworkProtocolNumber]NetworkEndpoint{ diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go index dbba2c79f..378389db2 100644 --- a/pkg/tcpip/stack/nic.go +++ b/pkg/tcpip/stack/nic.go @@ -51,7 +51,7 @@ type nic struct { name string context NICContext - stats NICStats + stats sharedStats // The network endpoints themselves may be modified by calling the interface's // methods, but the map reference and entries must be constant. @@ -78,26 +78,13 @@ type nic struct { } } -// NICStats hold statistics for a NIC. -type NICStats struct { - Tx DirectionStats - Rx DirectionStats - - DisabledRx DirectionStats - - Neighbor NeighborStats -} - -func makeNICStats() NICStats { - var s NICStats - tcpip.InitStatCounters(reflect.ValueOf(&s).Elem()) - return s -} - -// DirectionStats includes packet and byte counts. -type DirectionStats struct { - Packets *tcpip.StatCounter - Bytes *tcpip.StatCounter +// makeNICStats initializes the NIC statistics and associates them to the global +// NIC statistics. +func makeNICStats(global tcpip.NICStats) sharedStats { + var stats sharedStats + tcpip.InitStatCounters(reflect.ValueOf(&stats.local).Elem()) + stats.init(&stats.local, &global) + return stats } type packetEndpointList struct { @@ -150,7 +137,7 @@ func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint, ctx NICC id: id, name: name, context: ctx, - stats: makeNICStats(), + stats: makeNICStats(stack.Stats().NICs), networkEndpoints: make(map[tcpip.NetworkProtocolNumber]NetworkEndpoint), linkAddrResolvers: make(map[tcpip.NetworkProtocolNumber]*linkResolver), duplicateAddressDetectors: make(map[tcpip.NetworkProtocolNumber]DuplicateAddressDetector), @@ -382,8 +369,8 @@ func (n *nic) writePacket(r RouteInfo, protocol tcpip.NetworkProtocolNumber, pkt return err } - n.stats.Tx.Packets.Increment() - n.stats.Tx.Bytes.IncrementBy(uint64(numBytes)) + n.stats.tx.packets.Increment() + n.stats.tx.bytes.IncrementBy(uint64(numBytes)) return nil } @@ -399,13 +386,13 @@ func (n *nic) writePackets(r RouteInfo, protocol tcpip.NetworkProtocolNumber, pk } writtenPackets, err := n.LinkEndpoint.WritePackets(r, pkts, protocol) - n.stats.Tx.Packets.IncrementBy(uint64(writtenPackets)) + n.stats.tx.packets.IncrementBy(uint64(writtenPackets)) writtenBytes := 0 for i, pb := 0, pkts.Front(); i < writtenPackets && pb != nil; i, pb = i+1, pb.Next() { writtenBytes += pb.Size() } - n.stats.Tx.Bytes.IncrementBy(uint64(writtenBytes)) + n.stats.tx.bytes.IncrementBy(uint64(writtenBytes)) return writtenPackets, err } @@ -718,18 +705,18 @@ func (n *nic) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcp if !enabled { n.mu.RUnlock() - n.stats.DisabledRx.Packets.Increment() - n.stats.DisabledRx.Bytes.IncrementBy(uint64(pkt.Data().Size())) + n.stats.disabledRx.packets.Increment() + n.stats.disabledRx.bytes.IncrementBy(uint64(pkt.Data().Size())) return } - n.stats.Rx.Packets.Increment() - n.stats.Rx.Bytes.IncrementBy(uint64(pkt.Data().Size())) + n.stats.rx.packets.Increment() + n.stats.rx.bytes.IncrementBy(uint64(pkt.Data().Size())) networkEndpoint, ok := n.networkEndpoints[protocol] if !ok { n.mu.RUnlock() - n.stack.stats.UnknownProtocolRcvdPackets.Increment() + n.stats.unknownL3ProtocolRcvdPackets.Increment() return } @@ -786,7 +773,7 @@ func (n *nic) DeliverOutboundPacket(remote, local tcpip.LinkAddress, protocol tc func (n *nic) DeliverTransportPacket(protocol tcpip.TransportProtocolNumber, pkt *PacketBuffer) TransportPacketDisposition { state, ok := n.stack.transportProtocols[protocol] if !ok { - n.stack.stats.UnknownProtocolRcvdPackets.Increment() + n.stats.unknownL4ProtocolRcvdPackets.Increment() return TransportPacketProtocolUnreachable } @@ -807,20 +794,20 @@ func (n *nic) DeliverTransportPacket(protocol tcpip.TransportProtocolNumber, pkt // ICMP packets may be longer, but until icmp.Parse is implemented, here // we parse it using the minimum size. if _, ok := pkt.TransportHeader().Consume(transProto.MinimumPacketSize()); !ok { - n.stack.stats.MalformedRcvdPackets.Increment() + n.stats.malformedL4RcvdPackets.Increment() // We consider a malformed transport packet handled because there is // nothing the caller can do. return TransportPacketHandled } } else if !transProto.Parse(pkt) { - n.stack.stats.MalformedRcvdPackets.Increment() + n.stats.malformedL4RcvdPackets.Increment() return TransportPacketHandled } } srcPort, dstPort, err := transProto.ParsePorts(pkt.TransportHeader().View()) if err != nil { - n.stack.stats.MalformedRcvdPackets.Increment() + n.stats.malformedL4RcvdPackets.Increment() return TransportPacketHandled } @@ -852,7 +839,7 @@ func (n *nic) DeliverTransportPacket(protocol tcpip.TransportProtocolNumber, pkt // If it doesn't handle it then we should do so. switch res := transProto.HandleUnknownDestinationPacket(id, pkt); res { case UnknownDestinationPacketMalformed: - n.stack.stats.MalformedRcvdPackets.Increment() + n.stats.malformedL4RcvdPackets.Increment() return TransportPacketHandled case UnknownDestinationPacketUnhandled: return TransportPacketDestinationPortUnreachable diff --git a/pkg/tcpip/stack/nic_stats.go b/pkg/tcpip/stack/nic_stats.go new file mode 100644 index 000000000..1773d5e8d --- /dev/null +++ b/pkg/tcpip/stack/nic_stats.go @@ -0,0 +1,74 @@ +// Copyright 2021 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package stack + +import ( + "gvisor.dev/gvisor/pkg/tcpip" +) + +type sharedStats struct { + local tcpip.NICStats + multiCounterNICStats +} + +// LINT.IfChange(multiCounterNICPacketStats) + +type multiCounterNICPacketStats struct { + packets tcpip.MultiCounterStat + bytes tcpip.MultiCounterStat +} + +func (m *multiCounterNICPacketStats) init(a, b *tcpip.NICPacketStats) { + m.packets.Init(a.Packets, b.Packets) + m.bytes.Init(a.Bytes, b.Bytes) +} + +// LINT.ThenChange(../../tcpip.go:NICPacketStats) + +// LINT.IfChange(multiCounterNICNeighborStats) + +type multiCounterNICNeighborStats struct { + unreachableEntryLookups tcpip.MultiCounterStat +} + +func (m *multiCounterNICNeighborStats) init(a, b *tcpip.NICNeighborStats) { + m.unreachableEntryLookups.Init(a.UnreachableEntryLookups, b.UnreachableEntryLookups) +} + +// LINT.ThenChange(../../tcpip.go:NICNeighborStats) + +// LINT.IfChange(multiCounterNICStats) + +type multiCounterNICStats struct { + unknownL3ProtocolRcvdPackets tcpip.MultiCounterStat + unknownL4ProtocolRcvdPackets tcpip.MultiCounterStat + malformedL4RcvdPackets tcpip.MultiCounterStat + tx multiCounterNICPacketStats + rx multiCounterNICPacketStats + disabledRx multiCounterNICPacketStats + neighbor multiCounterNICNeighborStats +} + +func (m *multiCounterNICStats) init(a, b *tcpip.NICStats) { + m.unknownL3ProtocolRcvdPackets.Init(a.UnknownL3ProtocolRcvdPackets, b.UnknownL3ProtocolRcvdPackets) + m.unknownL4ProtocolRcvdPackets.Init(a.UnknownL4ProtocolRcvdPackets, b.UnknownL4ProtocolRcvdPackets) + m.malformedL4RcvdPackets.Init(a.MalformedL4RcvdPackets, b.MalformedL4RcvdPackets) + m.tx.init(&a.Tx, &b.Tx) + m.rx.init(&a.Rx, &b.Rx) + m.disabledRx.init(&a.DisabledRx, &b.DisabledRx) + m.neighbor.init(&a.Neighbor, &b.Neighbor) +} + +// LINT.ThenChange(../../tcpip.go:NICStats) diff --git a/pkg/tcpip/stack/nic_test.go b/pkg/tcpip/stack/nic_test.go index 8a3005295..5cb342f78 100644 --- a/pkg/tcpip/stack/nic_test.go +++ b/pkg/tcpip/stack/nic_test.go @@ -15,11 +15,13 @@ package stack import ( + "reflect" "testing" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/testutil" ) var _ AddressableEndpoint = (*testIPv6Endpoint)(nil) @@ -171,19 +173,19 @@ func TestDisabledRxStatsWhenNICDisabled(t *testing.T) { // When the NIC is disabled, the only field that matters is the stats field. // This test is limited to stats counter checks. nic := nic{ - stats: makeNICStats(), + stats: makeNICStats(tcpip.NICStats{}.FillIn()), } - if got := nic.stats.DisabledRx.Packets.Value(); got != 0 { + if got := nic.stats.local.DisabledRx.Packets.Value(); got != 0 { t.Errorf("got DisabledRx.Packets = %d, want = 0", got) } - if got := nic.stats.DisabledRx.Bytes.Value(); got != 0 { + if got := nic.stats.local.DisabledRx.Bytes.Value(); got != 0 { t.Errorf("got DisabledRx.Bytes = %d, want = 0", got) } - if got := nic.stats.Rx.Packets.Value(); got != 0 { + if got := nic.stats.local.Rx.Packets.Value(); got != 0 { t.Errorf("got Rx.Packets = %d, want = 0", got) } - if got := nic.stats.Rx.Bytes.Value(); got != 0 { + if got := nic.stats.local.Rx.Bytes.Value(); got != 0 { t.Errorf("got Rx.Bytes = %d, want = 0", got) } @@ -195,16 +197,28 @@ func TestDisabledRxStatsWhenNICDisabled(t *testing.T) { Data: buffer.View([]byte{1, 2, 3, 4}).ToVectorisedView(), })) - if got := nic.stats.DisabledRx.Packets.Value(); got != 1 { + if got := nic.stats.local.DisabledRx.Packets.Value(); got != 1 { t.Errorf("got DisabledRx.Packets = %d, want = 1", got) } - if got := nic.stats.DisabledRx.Bytes.Value(); got != 4 { + if got := nic.stats.local.DisabledRx.Bytes.Value(); got != 4 { t.Errorf("got DisabledRx.Bytes = %d, want = 4", got) } - if got := nic.stats.Rx.Packets.Value(); got != 0 { + if got := nic.stats.local.Rx.Packets.Value(); got != 0 { t.Errorf("got Rx.Packets = %d, want = 0", got) } - if got := nic.stats.Rx.Bytes.Value(); got != 0 { + if got := nic.stats.local.Rx.Bytes.Value(); got != 0 { t.Errorf("got Rx.Bytes = %d, want = 0", got) } } + +func TestMultiCounterStatsInitialization(t *testing.T) { + global := tcpip.NICStats{}.FillIn() + nic := nic{ + stats: makeNICStats(global), + } + multi := nic.stats.multiCounterNICStats + local := nic.stats.local + if err := testutil.ValidateMultiCounterStats(reflect.ValueOf(&multi).Elem(), []reflect.Value{reflect.ValueOf(&local).Elem(), reflect.ValueOf(&global).Elem()}); err != nil { + t.Error(err) + } +} diff --git a/pkg/tcpip/stack/nud.go b/pkg/tcpip/stack/nud.go index 5a94e9ac6..02f905351 100644 --- a/pkg/tcpip/stack/nud.go +++ b/pkg/tcpip/stack/nud.go @@ -323,26 +323,21 @@ type Rand interface { type NUDState struct { rng Rand - // mu protects the fields below. - // - // It is necessary for NUDState to handle its own locking since neighbor - // entries may access the NUD state from within the goroutine spawned by - // time.AfterFunc(). This goroutine may run concurrently with the main - // process for controlling the neighbor cache and would otherwise introduce - // race conditions if NUDState was not locked properly. - mu sync.RWMutex - - config NUDConfigurations - - // reachableTime is the duration to wait for a REACHABLE entry to - // transition into STALE after inactivity. This value is calculated with - // the algorithm defined in RFC 4861 section 6.3.2. - reachableTime time.Duration - - expiration time.Time - prevBaseReachableTime time.Duration - prevMinRandomFactor float32 - prevMaxRandomFactor float32 + mu struct { + sync.RWMutex + + config NUDConfigurations + + // reachableTime is the duration to wait for a REACHABLE entry to + // transition into STALE after inactivity. This value is calculated with + // the algorithm defined in RFC 4861 section 6.3.2. + reachableTime time.Duration + + expiration time.Time + prevBaseReachableTime time.Duration + prevMinRandomFactor float32 + prevMaxRandomFactor float32 + } } // NewNUDState returns new NUDState using c as configuration and the specified @@ -351,7 +346,7 @@ func NewNUDState(c NUDConfigurations, rng Rand) *NUDState { s := &NUDState{ rng: rng, } - s.config = c + s.mu.config = c return s } @@ -359,14 +354,14 @@ func NewNUDState(c NUDConfigurations, rng Rand) *NUDState { func (s *NUDState) Config() NUDConfigurations { s.mu.RLock() defer s.mu.RUnlock() - return s.config + return s.mu.config } // SetConfig replaces the existing NUD configurations with c. func (s *NUDState) SetConfig(c NUDConfigurations) { s.mu.Lock() defer s.mu.Unlock() - s.config = c + s.mu.config = c } // ReachableTime returns the duration to wait for a REACHABLE entry to @@ -377,13 +372,13 @@ func (s *NUDState) ReachableTime() time.Duration { s.mu.Lock() defer s.mu.Unlock() - if time.Now().After(s.expiration) || - s.config.BaseReachableTime != s.prevBaseReachableTime || - s.config.MinRandomFactor != s.prevMinRandomFactor || - s.config.MaxRandomFactor != s.prevMaxRandomFactor { + if time.Now().After(s.mu.expiration) || + s.mu.config.BaseReachableTime != s.mu.prevBaseReachableTime || + s.mu.config.MinRandomFactor != s.mu.prevMinRandomFactor || + s.mu.config.MaxRandomFactor != s.mu.prevMaxRandomFactor { s.recomputeReachableTimeLocked() } - return s.reachableTime + return s.mu.reachableTime } // recomputeReachableTimeLocked forces a recalculation of ReachableTime using @@ -408,23 +403,23 @@ func (s *NUDState) ReachableTime() time.Duration { // // s.mu MUST be locked for writing. func (s *NUDState) recomputeReachableTimeLocked() { - s.prevBaseReachableTime = s.config.BaseReachableTime - s.prevMinRandomFactor = s.config.MinRandomFactor - s.prevMaxRandomFactor = s.config.MaxRandomFactor + s.mu.prevBaseReachableTime = s.mu.config.BaseReachableTime + s.mu.prevMinRandomFactor = s.mu.config.MinRandomFactor + s.mu.prevMaxRandomFactor = s.mu.config.MaxRandomFactor - randomFactor := s.config.MinRandomFactor + s.rng.Float32()*(s.config.MaxRandomFactor-s.config.MinRandomFactor) + randomFactor := s.mu.config.MinRandomFactor + s.rng.Float32()*(s.mu.config.MaxRandomFactor-s.mu.config.MinRandomFactor) // Check for overflow, given that minRandomFactor and maxRandomFactor are // guaranteed to be positive numbers. - if float32(math.MaxInt64)/randomFactor < float32(s.config.BaseReachableTime) { - s.reachableTime = time.Duration(math.MaxInt64) + if math.MaxInt64/randomFactor < float32(s.mu.config.BaseReachableTime) { + s.mu.reachableTime = time.Duration(math.MaxInt64) } else if randomFactor == 1 { // Avoid loss of precision when a large base reachable time is used. - s.reachableTime = s.config.BaseReachableTime + s.mu.reachableTime = s.mu.config.BaseReachableTime } else { - reachableTime := int64(float32(s.config.BaseReachableTime) * randomFactor) - s.reachableTime = time.Duration(reachableTime) + reachableTime := int64(float32(s.mu.config.BaseReachableTime) * randomFactor) + s.mu.reachableTime = time.Duration(reachableTime) } - s.expiration = time.Now().Add(2 * time.Hour) + s.mu.expiration = time.Now().Add(2 * time.Hour) } diff --git a/pkg/tcpip/stack/nud_test.go b/pkg/tcpip/stack/nud_test.go index e1253f310..6ba97d626 100644 --- a/pkg/tcpip/stack/nud_test.go +++ b/pkg/tcpip/stack/nud_test.go @@ -28,17 +28,15 @@ import ( ) const ( - defaultBaseReachableTime = 30 * time.Second - minimumBaseReachableTime = time.Millisecond - defaultMinRandomFactor = 0.5 - defaultMaxRandomFactor = 1.5 - defaultRetransmitTimer = time.Second - minimumRetransmitTimer = time.Millisecond - defaultDelayFirstProbeTime = 5 * time.Second - defaultMaxMulticastProbes = 3 - defaultMaxUnicastProbes = 3 - defaultMaxAnycastDelayTime = time.Second - defaultMaxReachbilityConfirmations = 3 + defaultBaseReachableTime = 30 * time.Second + minimumBaseReachableTime = time.Millisecond + defaultMinRandomFactor = 0.5 + defaultMaxRandomFactor = 1.5 + defaultRetransmitTimer = time.Second + minimumRetransmitTimer = time.Millisecond + defaultDelayFirstProbeTime = 5 * time.Second + defaultMaxMulticastProbes = 3 + defaultMaxUnicastProbes = 3 defaultFakeRandomNum = 0.5 ) diff --git a/pkg/tcpip/stack/packet_buffer.go b/pkg/tcpip/stack/packet_buffer.go index 01652fbe7..4ca702121 100644 --- a/pkg/tcpip/stack/packet_buffer.go +++ b/pkg/tcpip/stack/packet_buffer.go @@ -245,10 +245,10 @@ func (pk *PacketBuffer) dataOffset() int { func (pk *PacketBuffer) push(typ headerType, size int) tcpipbuffer.View { h := &pk.headers[typ] if h.length > 0 { - panic(fmt.Sprintf("push must not be called twice: type %s", typ)) + panic(fmt.Sprintf("push(%s, %d) called after previous push", typ, size)) } if pk.pushed+size > pk.reserved { - panic("not enough headroom reserved") + panic(fmt.Sprintf("push(%s, %d) overflows; pushed=%d reserved=%d", typ, size, pk.pushed, pk.reserved)) } pk.pushed += size h.offset = -pk.pushed diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go index 8814f45a6..72760a4a7 100644 --- a/pkg/tcpip/stack/stack.go +++ b/pkg/tcpip/stack/stack.go @@ -40,13 +40,6 @@ import ( ) const ( - // ageLimit is set to the same cache stale time used in Linux. - ageLimit = 1 * time.Minute - // resolutionTimeout is set to the same ARP timeout used in Linux. - resolutionTimeout = 1 * time.Second - // resolutionAttempts is set to the same ARP retries used in Linux. - resolutionAttempts = 3 - // DefaultTOS is the default type of service value for network endpoints. DefaultTOS = 0 ) @@ -804,7 +797,7 @@ type NICInfo struct { // MTU is the maximum transmission unit. MTU uint32 - Stats NICStats + Stats tcpip.NICStats // NetworkStats holds the stats of each NetworkEndpoint bound to the NIC. NetworkStats map[tcpip.NetworkProtocolNumber]NetworkEndpointStats @@ -856,7 +849,7 @@ func (s *Stack) NICInfo() map[tcpip.NICID]NICInfo { ProtocolAddresses: nic.primaryAddresses(), Flags: flags, MTU: nic.LinkEndpoint.MTU(), - Stats: nic.stats, + Stats: nic.stats.local, NetworkStats: netStats, Context: nic.context, ARPHardwareType: nic.LinkEndpoint.ARPHardwareType(), diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go index 02d54d29b..73e0f0d58 100644 --- a/pkg/tcpip/stack/stack_test.go +++ b/pkg/tcpip/stack/stack_test.go @@ -166,10 +166,6 @@ func (f *fakeNetworkEndpoint) MaxHeaderLength() uint16 { return f.nic.MaxHeaderLength() + fakeNetHeaderLen } -func (*fakeNetworkEndpoint) PseudoHeaderChecksum(protocol tcpip.TransportProtocolNumber, dstAddr tcpip.Address) uint16 { - return 0 -} - func (f *fakeNetworkEndpoint) NetworkProtocolNumber() tcpip.NetworkProtocolNumber { return f.proto.Number() } @@ -197,11 +193,11 @@ func (f *fakeNetworkEndpoint) WritePacket(r *stack.Route, params stack.NetworkHe } // WritePackets implements stack.LinkEndpoint.WritePackets. -func (*fakeNetworkEndpoint) WritePackets(r *stack.Route, pkts stack.PacketBufferList, params stack.NetworkHeaderParams) (int, tcpip.Error) { +func (*fakeNetworkEndpoint) WritePackets(*stack.Route, stack.PacketBufferList, stack.NetworkHeaderParams) (int, tcpip.Error) { panic("not implemented") } -func (*fakeNetworkEndpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack.PacketBuffer) tcpip.Error { +func (*fakeNetworkEndpoint) WriteHeaderIncludedPacket(*stack.Route, *stack.PacketBuffer) tcpip.Error { return &tcpip.ErrNotSupported{} } @@ -463,14 +459,14 @@ func testSend(t *testing.T, r *stack.Route, ep *channel.Endpoint, payload buffer } } -func testFailingSend(t *testing.T, r *stack.Route, ep *channel.Endpoint, payload buffer.View, wantErr tcpip.Error) { +func testFailingSend(t *testing.T, r *stack.Route, payload buffer.View, wantErr tcpip.Error) { t.Helper() if gotErr := send(r, payload); gotErr != wantErr { t.Errorf("send failed: got = %s, want = %s ", gotErr, wantErr) } } -func testFailingSendTo(t *testing.T, s *stack.Stack, addr tcpip.Address, ep *channel.Endpoint, payload buffer.View, wantErr tcpip.Error) { +func testFailingSendTo(t *testing.T, s *stack.Stack, addr tcpip.Address, payload buffer.View, wantErr tcpip.Error) { t.Helper() if gotErr := sendTo(s, addr, payload); gotErr != wantErr { t.Errorf("sendto failed: got = %s, want = %s ", gotErr, wantErr) @@ -920,15 +916,15 @@ func TestRouteWithDownNIC(t *testing.T) { if err := test.downFn(s, nicID1); err != nil { t.Fatalf("test.downFn(_, %d): %s", nicID1, err) } - testFailingSend(t, r1, ep1, buf, &tcpip.ErrInvalidEndpointState{}) + testFailingSend(t, r1, buf, &tcpip.ErrInvalidEndpointState{}) testSend(t, r2, ep2, buf) // Writes with Routes that use NIC2 after being brought down should fail. if err := test.downFn(s, nicID2); err != nil { t.Fatalf("test.downFn(_, %d): %s", nicID2, err) } - testFailingSend(t, r1, ep1, buf, &tcpip.ErrInvalidEndpointState{}) - testFailingSend(t, r2, ep2, buf, &tcpip.ErrInvalidEndpointState{}) + testFailingSend(t, r1, buf, &tcpip.ErrInvalidEndpointState{}) + testFailingSend(t, r2, buf, &tcpip.ErrInvalidEndpointState{}) if upFn := test.upFn; upFn != nil { // Writes with Routes that use NIC1 after being brought up should @@ -941,7 +937,7 @@ func TestRouteWithDownNIC(t *testing.T) { t.Fatalf("test.upFn(_, %d): %s", nicID1, err) } testSend(t, r1, ep1, buf) - testFailingSend(t, r2, ep2, buf, &tcpip.ErrInvalidEndpointState{}) + testFailingSend(t, r2, buf, &tcpip.ErrInvalidEndpointState{}) } }) } @@ -1066,7 +1062,7 @@ func TestAddressRemoval(t *testing.T) { t.Fatal("RemoveAddress failed:", err) } testFailingRecv(t, fakeNet, localAddrByte, ep, buf) - testFailingSendTo(t, s, remoteAddr, ep, nil, &tcpip.ErrNoRoute{}) + testFailingSendTo(t, s, remoteAddr, nil, &tcpip.ErrNoRoute{}) // Check that removing the same address fails. err := s.RemoveAddress(1, localAddr) @@ -1118,8 +1114,8 @@ func TestAddressRemovalWithRouteHeld(t *testing.T) { t.Fatal("RemoveAddress failed:", err) } testFailingRecv(t, fakeNet, localAddrByte, ep, buf) - testFailingSend(t, r, ep, nil, &tcpip.ErrInvalidEndpointState{}) - testFailingSendTo(t, s, remoteAddr, ep, nil, &tcpip.ErrNoRoute{}) + testFailingSend(t, r, nil, &tcpip.ErrInvalidEndpointState{}) + testFailingSendTo(t, s, remoteAddr, nil, &tcpip.ErrNoRoute{}) // Check that removing the same address fails. { @@ -1140,7 +1136,7 @@ func verifyAddress(t *testing.T, s *stack.Stack, nicID tcpip.NICID, addr tcpip.A // No address given, verify that there is no address assigned to the NIC. for _, a := range info.ProtocolAddresses { if a.Protocol == fakeNetNumber && a.AddressWithPrefix != (tcpip.AddressWithPrefix{}) { - t.Errorf("verify no-address: got = %s, want = %s", a.AddressWithPrefix, (tcpip.AddressWithPrefix{})) + t.Errorf("verify no-address: got = %s, want = %s", a.AddressWithPrefix, tcpip.AddressWithPrefix{}) } } return @@ -1220,7 +1216,7 @@ func TestEndpointExpiration(t *testing.T) { // FIXME(b/139841518):Spoofing doesn't work if there is no primary address. // testSendTo(t, s, remoteAddr, ep, nil) } else { - testFailingSendTo(t, s, remoteAddr, ep, nil, &tcpip.ErrNoRoute{}) + testFailingSendTo(t, s, remoteAddr, nil, &tcpip.ErrNoRoute{}) } // 2. Add Address, everything should work. @@ -1248,7 +1244,7 @@ func TestEndpointExpiration(t *testing.T) { // FIXME(b/139841518):Spoofing doesn't work if there is no primary address. // testSendTo(t, s, remoteAddr, ep, nil) } else { - testFailingSendTo(t, s, remoteAddr, ep, nil, &tcpip.ErrNoRoute{}) + testFailingSendTo(t, s, remoteAddr, nil, &tcpip.ErrNoRoute{}) } // 4. Add Address back, everything should work again. @@ -1287,8 +1283,8 @@ func TestEndpointExpiration(t *testing.T) { testSend(t, r, ep, nil) testSendTo(t, s, remoteAddr, ep, nil) } else { - testFailingSend(t, r, ep, nil, &tcpip.ErrInvalidEndpointState{}) - testFailingSendTo(t, s, remoteAddr, ep, nil, &tcpip.ErrNoRoute{}) + testFailingSend(t, r, nil, &tcpip.ErrInvalidEndpointState{}) + testFailingSendTo(t, s, remoteAddr, nil, &tcpip.ErrNoRoute{}) } // 7. Add Address back, everything should work again. @@ -1324,7 +1320,7 @@ func TestEndpointExpiration(t *testing.T) { // FIXME(b/139841518):Spoofing doesn't work if there is no primary address. // testSendTo(t, s, remoteAddr, ep, nil) } else { - testFailingSendTo(t, s, remoteAddr, ep, nil, &tcpip.ErrNoRoute{}) + testFailingSendTo(t, s, remoteAddr, nil, &tcpip.ErrNoRoute{}) } }) } @@ -1574,7 +1570,7 @@ func TestSpoofingNoAddress(t *testing.T) { t.Errorf("FindRoute succeeded with route %+v when it should have failed", r) } // Sending a packet fails. - testFailingSendTo(t, s, dstAddr, ep, nil, &tcpip.ErrNoRoute{}) + testFailingSendTo(t, s, dstAddr, nil, &tcpip.ErrNoRoute{}) // With address spoofing enabled, FindRoute permits any address to be used // as the source. @@ -1615,7 +1611,7 @@ func TestOutgoingBroadcastWithEmptyRouteTable(t *testing.T) { } } - protoAddr := tcpip.ProtocolAddress{Protocol: fakeNetNumber, AddressWithPrefix: tcpip.AddressWithPrefix{header.IPv4Any, 0}} + protoAddr := tcpip.ProtocolAddress{Protocol: fakeNetNumber, AddressWithPrefix: tcpip.AddressWithPrefix{Address: header.IPv4Any}} if err := s.AddProtocolAddress(1, protoAddr); err != nil { t.Fatalf("AddProtocolAddress(1, %v) failed: %v", protoAddr, err) } @@ -1641,12 +1637,12 @@ func TestOutgoingBroadcastWithEmptyRouteTable(t *testing.T) { } func TestOutgoingBroadcastWithRouteTable(t *testing.T) { - defaultAddr := tcpip.AddressWithPrefix{header.IPv4Any, 0} + defaultAddr := tcpip.AddressWithPrefix{Address: header.IPv4Any} // Local subnet on NIC1: 192.168.1.58/24, gateway 192.168.1.1. - nic1Addr := tcpip.AddressWithPrefix{"\xc0\xa8\x01\x3a", 24} + nic1Addr := tcpip.AddressWithPrefix{Address: "\xc0\xa8\x01\x3a", PrefixLen: 24} nic1Gateway := testutil.MustParse4("192.168.1.1") // Local subnet on NIC2: 10.10.10.5/24, gateway 10.10.10.1. - nic2Addr := tcpip.AddressWithPrefix{"\x0a\x0a\x0a\x05", 24} + nic2Addr := tcpip.AddressWithPrefix{Address: "\x0a\x0a\x0a\x05", PrefixLen: 24} nic2Gateway := testutil.MustParse4("10.10.10.1") // Create a new stack with two NICs. @@ -1660,12 +1656,12 @@ func TestOutgoingBroadcastWithRouteTable(t *testing.T) { if err := s.CreateNIC(2, ep); err != nil { t.Fatalf("CreateNIC failed: %s", err) } - nic1ProtoAddr := tcpip.ProtocolAddress{fakeNetNumber, nic1Addr} + nic1ProtoAddr := tcpip.ProtocolAddress{Protocol: fakeNetNumber, AddressWithPrefix: nic1Addr} if err := s.AddProtocolAddress(1, nic1ProtoAddr); err != nil { t.Fatalf("AddProtocolAddress(1, %v) failed: %v", nic1ProtoAddr, err) } - nic2ProtoAddr := tcpip.ProtocolAddress{fakeNetNumber, nic2Addr} + nic2ProtoAddr := tcpip.ProtocolAddress{Protocol: fakeNetNumber, AddressWithPrefix: nic2Addr} if err := s.AddProtocolAddress(2, nic2ProtoAddr); err != nil { t.Fatalf("AddAddress(2, %v) failed: %v", nic2ProtoAddr, err) } @@ -1709,7 +1705,7 @@ func TestOutgoingBroadcastWithRouteTable(t *testing.T) { // 2. Case: Having an explicit route for broadcast will select that one. rt = append( []tcpip.Route{ - {Destination: tcpip.AddressWithPrefix{header.IPv4Broadcast, 8 * header.IPv4AddressSize}.Subnet(), NIC: 1}, + {Destination: tcpip.AddressWithPrefix{Address: header.IPv4Broadcast, PrefixLen: 8 * header.IPv4AddressSize}.Subnet(), NIC: 1}, }, rt..., ) @@ -2049,7 +2045,7 @@ func TestAddAddress(t *testing.T) { } expectedAddresses = append(expectedAddresses, tcpip.ProtocolAddress{ Protocol: fakeNetNumber, - AddressWithPrefix: tcpip.AddressWithPrefix{address, fakeDefaultPrefixLen}, + AddressWithPrefix: tcpip.AddressWithPrefix{Address: address, PrefixLen: fakeDefaultPrefixLen}, }) } @@ -2113,7 +2109,7 @@ func TestAddAddressWithOptions(t *testing.T) { } expectedAddresses = append(expectedAddresses, tcpip.ProtocolAddress{ Protocol: fakeNetNumber, - AddressWithPrefix: tcpip.AddressWithPrefix{address, fakeDefaultPrefixLen}, + AddressWithPrefix: tcpip.AddressWithPrefix{Address: address, PrefixLen: fakeDefaultPrefixLen}, }) } } @@ -2234,7 +2230,7 @@ func TestCreateNICWithOptions(t *testing.T) { for _, test := range tests { t.Run(test.desc, func(t *testing.T) { s := stack.New(stack.Options{}) - ep := channel.New(0, 0, tcpip.LinkAddress("\x00\x00\x00\x00\x00\x00")) + ep := channel.New(0, 0, "\x00\x00\x00\x00\x00\x00") for _, call := range test.calls { if got, want := s.CreateNICWithOptions(call.nicID, ep, call.opts), call.err; got != want { t.Fatalf("CreateNICWithOptions(%v, _, %+v) = %v, want %v", call.nicID, call.opts, got, want) @@ -2248,46 +2244,87 @@ func TestNICStats(t *testing.T) { s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, }) - ep1 := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(1, ep1); err != nil { - t.Fatal("CreateNIC failed: ", err) - } - if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil { - t.Fatal("AddAddress failed:", err) + + nics := []struct { + addr tcpip.Address + txByteCount int + rxByteCount int + }{ + { + addr: "\x01", + txByteCount: 30, + rxByteCount: 10, + }, + { + addr: "\x02", + txByteCount: 50, + rxByteCount: 20, + }, } - // Route all packets for address \x01 to NIC 1. - { - subnet, err := tcpip.NewSubnet("\x01", "\xff") - if err != nil { - t.Fatal(err) + + var txBytesTotal, rxBytesTotal, txPacketsTotal, rxPacketsTotal int + for i, nic := range nics { + nicid := tcpip.NICID(i) + ep := channel.New(1, defaultMTU, "") + if err := s.CreateNIC(nicid, ep); err != nil { + t.Fatal("CreateNIC failed: ", err) + } + if err := s.AddAddress(nicid, fakeNetNumber, nic.addr); err != nil { + t.Fatal("AddAddress failed:", err) } - s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: 1}}) - } - // Send a packet to address 1. - buf := buffer.NewView(30) - ep1.InjectInbound(fakeNetNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: buf.ToVectorisedView(), - })) - if got, want := s.NICInfo()[1].Stats.Rx.Packets.Value(), uint64(1); got != want { - t.Errorf("got Rx.Packets.Value() = %d, want = %d", got, want) - } + { + subnet, err := tcpip.NewSubnet(nic.addr, "\xff") + if err != nil { + t.Fatal(err) + } + s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: nicid}}) + } - if got, want := s.NICInfo()[1].Stats.Rx.Bytes.Value(), uint64(len(buf)); got != want { - t.Errorf("got Rx.Bytes.Value() = %d, want = %d", got, want) + nicStats := s.NICInfo()[nicid].Stats + + // Inbound packet. + rxBuffer := buffer.NewView(nic.rxByteCount) + ep.InjectInbound(fakeNetNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: rxBuffer.ToVectorisedView(), + })) + if got, want := nicStats.Rx.Packets.Value(), uint64(1); got != want { + t.Errorf("got Rx.Packets.Value() = %d, want = %d", got, want) + } + if got, want := nicStats.Rx.Bytes.Value(), uint64(nic.rxByteCount); got != want { + t.Errorf("got Rx.Bytes.Value() = %d, want = %d", got, want) + } + rxPacketsTotal++ + rxBytesTotal += nic.rxByteCount + + // Outbound packet. + txBuffer := buffer.NewView(nic.txByteCount) + actualTxLength := nic.txByteCount + fakeNetHeaderLen + if err := sendTo(s, nic.addr, txBuffer); err != nil { + t.Fatal("sendTo failed: ", err) + } + want := ep.Drain() + if got := nicStats.Tx.Packets.Value(); got != uint64(want) { + t.Errorf("got Tx.Packets.Value() = %d, ep.Drain() = %d", got, want) + } + if got, want := nicStats.Tx.Bytes.Value(), uint64(actualTxLength); got != want { + t.Errorf("got Tx.Bytes.Value() = %d, want = %d", got, want) + } + txPacketsTotal += want + txBytesTotal += actualTxLength } - payload := buffer.NewView(10) - // Write a packet out via the address for NIC 1 - if err := sendTo(s, "\x01", payload); err != nil { - t.Fatal("sendTo failed: ", err) + // Now verify that each NIC stats was correctly aggregated at the stack level. + if got, want := s.Stats().NICs.Rx.Packets.Value(), uint64(rxPacketsTotal); got != want { + t.Errorf("got s.Stats().NIC.Rx.Packets.Value() = %d, want = %d", got, want) } - want := uint64(ep1.Drain()) - if got := s.NICInfo()[1].Stats.Tx.Packets.Value(); got != want { - t.Errorf("got Tx.Packets.Value() = %d, ep1.Drain() = %d", got, want) + if got, want := s.Stats().NICs.Rx.Bytes.Value(), uint64(rxBytesTotal); got != want { + t.Errorf("got s.Stats().Rx.Bytes.Value() = %d, want = %d", got, want) } - - if got, want := s.NICInfo()[1].Stats.Tx.Bytes.Value(), uint64(len(payload)+fakeNetHeaderLen); got != want { + if got, want := s.Stats().NICs.Tx.Packets.Value(), uint64(txPacketsTotal); got != want { + t.Errorf("got Tx.Packets.Value() = %d, ep.Drain() = %d", got, want) + } + if got, want := s.Stats().NICs.Tx.Bytes.Value(), uint64(txBytesTotal); got != want { t.Errorf("got Tx.Bytes.Value() = %d, want = %d", got, want) } } @@ -2316,7 +2353,7 @@ func TestNICContextPreservation(t *testing.T) { t.Run(test.name, func(t *testing.T) { s := stack.New(stack.Options{}) id := tcpip.NICID(1) - ep := channel.New(0, 0, tcpip.LinkAddress("\x00\x00\x00\x00\x00\x00")) + ep := channel.New(0, 0, "\x00\x00\x00\x00\x00\x00") if err := s.CreateNICWithOptions(id, ep, test.opts); err != nil { t.Fatalf("got stack.CreateNICWithOptions(%d, %+v, %+v) = %s, want nil", id, ep, test.opts, err) } @@ -3837,8 +3874,6 @@ func TestGetMainNICAddressWhenNICDisabled(t *testing.T) { // TestAddRoute tests Stack.AddRoute func TestAddRoute(t *testing.T) { - const nicID = 1 - s := stack.New(stack.Options{}) subnet1, err := tcpip.NewSubnet("\x00", "\x00") @@ -3875,8 +3910,6 @@ func TestAddRoute(t *testing.T) { // TestRemoveRoutes tests Stack.RemoveRoutes func TestRemoveRoutes(t *testing.T) { - const nicID = 1 - s := stack.New(stack.Options{}) addressToRemove := tcpip.Address("\x01") @@ -4223,7 +4256,7 @@ func TestFindRouteWithForwarding(t *testing.T) { s.SetRouteTable([]tcpip.Route{{Destination: test.netCfg.remoteAddr.WithPrefix().Subnet(), NIC: nicID2}}) r, err := s.FindRoute(test.addrNIC, test.localAddr, test.netCfg.remoteAddr, test.netCfg.proto, false /* multicastLoop */) - if r != nil { + if err == nil { defer r.Release() } if diff := cmp.Diff(test.findRouteErr, err); diff != "" { |