diff options
author | Arthur Sfez <asfez@google.com> | 2021-01-19 15:05:17 -0800 |
---|---|---|
committer | gVisor bot <gvisor-bot@google.com> | 2021-01-19 15:07:39 -0800 |
commit | be17b94446b2f96c2a3d531fe20271537c77c8aa (patch) | |
tree | 1274841ecbb71f37195676354908deea9bf0d24c /pkg/tcpip/network | |
parent | 833ba3590b422d453012e5b2ec2e780211d9caf9 (diff) |
Per NIC NetworkEndpoint statistics
To facilitate the debugging of multi-homed setup, track Network
protocols statistics for each endpoint. Note that the original
stack-wide stats still exist.
A new type of statistic counter is introduced, which track two
versions of a stat at the same time. This lets a network endpoint
increment both the local stat and the stack-wide stat at the same
time.
Fixes #4605
PiperOrigin-RevId: 352663276
Diffstat (limited to 'pkg/tcpip/network')
-rw-r--r-- | pkg/tcpip/network/arp/arp.go | 14 | ||||
-rw-r--r-- | pkg/tcpip/network/ip/BUILD | 5 | ||||
-rw-r--r-- | pkg/tcpip/network/ip/stats.go | 100 | ||||
-rw-r--r-- | pkg/tcpip/network/ipv4/BUILD | 13 | ||||
-rw-r--r-- | pkg/tcpip/network/ipv4/icmp.go | 69 | ||||
-rw-r--r-- | pkg/tcpip/network/ipv4/igmp.go | 33 | ||||
-rw-r--r-- | pkg/tcpip/network/ipv4/ipv4.go | 115 | ||||
-rw-r--r-- | pkg/tcpip/network/ipv4/stats.go | 190 | ||||
-rw-r--r-- | pkg/tcpip/network/ipv4/stats_test.go | 101 | ||||
-rw-r--r-- | pkg/tcpip/network/ipv6/BUILD | 1 | ||||
-rw-r--r-- | pkg/tcpip/network/ipv6/icmp.go | 135 | ||||
-rw-r--r-- | pkg/tcpip/network/ipv6/ipv6.go | 111 | ||||
-rw-r--r-- | pkg/tcpip/network/ipv6/ipv6_test.go | 46 | ||||
-rw-r--r-- | pkg/tcpip/network/ipv6/mld.go | 10 | ||||
-rw-r--r-- | pkg/tcpip/network/ipv6/ndp.go | 13 | ||||
-rw-r--r-- | pkg/tcpip/network/ipv6/stats.go | 132 | ||||
-rw-r--r-- | pkg/tcpip/network/testutil/BUILD | 1 | ||||
-rw-r--r-- | pkg/tcpip/network/testutil/testutil.go | 68 | ||||
-rw-r--r-- | pkg/tcpip/network/testutil/testutil_unsafe.go | 26 |
19 files changed, 973 insertions, 210 deletions
diff --git a/pkg/tcpip/network/arp/arp.go b/pkg/tcpip/network/arp/arp.go index d54e7fe86..1d4d2966e 100644 --- a/pkg/tcpip/network/arp/arp.go +++ b/pkg/tcpip/network/arp/arp.go @@ -210,6 +210,20 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) { } } +// Stats implements stack.NetworkEndpoint. +func (e *endpoint) Stats() stack.NetworkEndpointStats { + // TODO(gvisor.dev/issues/4963): Record statistics for ARP. + return &Stats{} +} + +var _ stack.NetworkEndpointStats = (*Stats)(nil) + +// Stats holds ARP statistics. +type Stats struct{} + +// IsNetworkEndpointStats implements stack.NetworkEndpointStats. +func (*Stats) IsNetworkEndpointStats() {} + // protocol implements stack.NetworkProtocol and stack.LinkAddressResolver. type protocol struct { stack *stack.Stack diff --git a/pkg/tcpip/network/ip/BUILD b/pkg/tcpip/network/ip/BUILD index ca1247c1e..411bca25d 100644 --- a/pkg/tcpip/network/ip/BUILD +++ b/pkg/tcpip/network/ip/BUILD @@ -4,7 +4,10 @@ package(licenses = ["notice"]) go_library( name = "ip", - srcs = ["generic_multicast_protocol.go"], + srcs = [ + "generic_multicast_protocol.go", + "stats.go", + ], visibility = ["//visibility:public"], deps = [ "//pkg/sync", diff --git a/pkg/tcpip/network/ip/stats.go b/pkg/tcpip/network/ip/stats.go new file mode 100644 index 000000000..898f8b356 --- /dev/null +++ b/pkg/tcpip/network/ip/stats.go @@ -0,0 +1,100 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package ip + +import "gvisor.dev/gvisor/pkg/tcpip" + +// LINT.IfChange(MultiCounterIPStats) + +// MultiCounterIPStats holds IP statistics, each counter may have several +// versions. +type MultiCounterIPStats struct { + // PacketsReceived is the total number of IP packets received from the link + // layer. + PacketsReceived tcpip.MultiCounterStat + + // DisabledPacketsReceived is the total number of IP packets received from the + // link layer when the IP layer is disabled. + DisabledPacketsReceived tcpip.MultiCounterStat + + // InvalidDestinationAddressesReceived is the total number of IP packets + // received with an unknown or invalid destination address. + InvalidDestinationAddressesReceived tcpip.MultiCounterStat + + // InvalidSourceAddressesReceived is the total number of IP packets received + // with a source address that should never have been received on the wire. + InvalidSourceAddressesReceived tcpip.MultiCounterStat + + // PacketsDelivered is the total number of incoming IP packets that are + // successfully delivered to the transport layer. + PacketsDelivered tcpip.MultiCounterStat + + // PacketsSent is the total number of IP packets sent via WritePacket. + PacketsSent tcpip.MultiCounterStat + + // OutgoingPacketErrors is the total number of IP packets which failed to + // write to a link-layer endpoint. + OutgoingPacketErrors tcpip.MultiCounterStat + + // MalformedPacketsReceived is the total number of IP Packets that were + // dropped due to the IP packet header failing validation checks. + MalformedPacketsReceived tcpip.MultiCounterStat + + // MalformedFragmentsReceived is the total number of IP Fragments that were + // dropped due to the fragment failing validation checks. + MalformedFragmentsReceived tcpip.MultiCounterStat + + // IPTablesPreroutingDropped is the total number of IP packets dropped in the + // Prerouting chain. + IPTablesPreroutingDropped tcpip.MultiCounterStat + + // IPTablesInputDropped is the total number of IP packets dropped in the Input + // chain. + IPTablesInputDropped tcpip.MultiCounterStat + + // IPTablesOutputDropped is the total number of IP packets dropped in the + // Output chain. + IPTablesOutputDropped tcpip.MultiCounterStat + + // OptionTSReceived is the number of Timestamp options seen. + OptionTSReceived tcpip.MultiCounterStat + + // OptionRRReceived is the number of Record Route options seen. + OptionRRReceived tcpip.MultiCounterStat + + // OptionUnknownReceived is the number of unknown IP options seen. + OptionUnknownReceived tcpip.MultiCounterStat +} + +// Init sets internal counters to track a and b counters. +func (m *MultiCounterIPStats) Init(a, b *tcpip.IPStats) { + m.PacketsReceived.Init(a.PacketsReceived, b.PacketsReceived) + m.DisabledPacketsReceived.Init(a.DisabledPacketsReceived, b.DisabledPacketsReceived) + m.InvalidDestinationAddressesReceived.Init(a.InvalidDestinationAddressesReceived, b.InvalidDestinationAddressesReceived) + m.InvalidSourceAddressesReceived.Init(a.InvalidSourceAddressesReceived, b.InvalidSourceAddressesReceived) + m.PacketsDelivered.Init(a.PacketsDelivered, b.PacketsDelivered) + m.PacketsSent.Init(a.PacketsSent, b.PacketsSent) + m.OutgoingPacketErrors.Init(a.OutgoingPacketErrors, b.OutgoingPacketErrors) + m.MalformedPacketsReceived.Init(a.MalformedPacketsReceived, b.MalformedPacketsReceived) + m.MalformedFragmentsReceived.Init(a.MalformedFragmentsReceived, b.MalformedFragmentsReceived) + m.IPTablesPreroutingDropped.Init(a.IPTablesPreroutingDropped, b.IPTablesPreroutingDropped) + m.IPTablesInputDropped.Init(a.IPTablesInputDropped, b.IPTablesInputDropped) + m.IPTablesOutputDropped.Init(a.IPTablesOutputDropped, b.IPTablesOutputDropped) + m.OptionTSReceived.Init(a.OptionTSReceived, b.OptionTSReceived) + m.OptionRRReceived.Init(a.OptionRRReceived, b.OptionRRReceived) + m.OptionUnknownReceived.Init(a.OptionUnknownReceived, b.OptionUnknownReceived) +} + +// LINT.ThenChange(:MultiCounterIPStats, ../../tcpip.go:IPStats) diff --git a/pkg/tcpip/network/ipv4/BUILD b/pkg/tcpip/network/ipv4/BUILD index 32f53f217..330a7d170 100644 --- a/pkg/tcpip/network/ipv4/BUILD +++ b/pkg/tcpip/network/ipv4/BUILD @@ -8,6 +8,7 @@ go_library( "icmp.go", "igmp.go", "ipv4.go", + "stats.go", ], visibility = ["//visibility:public"], deps = [ @@ -49,3 +50,15 @@ go_test( "@com_github_google_go_cmp//cmp:go_default_library", ], ) + +go_test( + name = "stats_test", + size = "small", + srcs = ["stats_test.go"], + library = ":ipv4", + deps = [ + "//pkg/tcpip", + "//pkg/tcpip/network/testutil", + "//pkg/tcpip/stack", + ], +) diff --git a/pkg/tcpip/network/ipv4/icmp.go b/pkg/tcpip/network/ipv4/icmp.go index 8e392f86c..3f60de749 100644 --- a/pkg/tcpip/network/ipv4/icmp.go +++ b/pkg/tcpip/network/ipv4/icmp.go @@ -62,21 +62,20 @@ func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, pkt *stack } func (e *endpoint) handleICMP(pkt *stack.PacketBuffer) { - stats := e.protocol.stack.Stats() - received := stats.ICMP.V4.PacketsReceived + received := e.stats.icmp.packetsReceived // TODO(gvisor.dev/issue/170): ICMP packets don't have their // TransportHeader fields set. See icmp/protocol.go:protocol.Parse for a // full explanation. v, ok := pkt.Data.PullUp(header.ICMPv4MinimumSize) if !ok { - received.Invalid.Increment() + received.invalid.Increment() return } h := header.ICMPv4(v) // Only do in-stack processing if the checksum is correct. if header.ChecksumVV(pkt.Data, 0 /* initial */) != 0xffff { - received.Invalid.Increment() + received.invalid.Increment() // It's possible that a raw socket expects to receive this regardless // of checksum errors. If it's an echo request we know it's safe because // we are the only handler, however other types do not cope well with @@ -117,8 +116,8 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer) { errors.Is(err, errIPv4TimestampOptInvalidPointer), errors.Is(err, errIPv4TimestampOptOverflow): _ = e.protocol.returnError(&icmpReasonParamProblem{pointer: aux}, pkt) - stats.MalformedRcvdPackets.Increment() - stats.IP.MalformedPacketsReceived.Increment() + e.protocol.stack.Stats().MalformedRcvdPackets.Increment() + e.stats.ip.MalformedPacketsReceived.Increment() } return } @@ -128,11 +127,11 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer) { // TODO(b/112892170): Meaningfully handle all ICMP types. switch h.Type() { case header.ICMPv4Echo: - received.Echo.Increment() + received.echo.Increment() - sent := stats.ICMP.V4.PacketsSent + sent := e.stats.icmp.packetsSent if !e.protocol.stack.AllowICMPMessage() { - sent.RateLimited.Increment() + sent.rateLimited.Increment() return } @@ -213,18 +212,18 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer) { replyPkt.TransportProtocolNumber = header.ICMPv4ProtocolNumber if err := r.WriteHeaderIncludedPacket(replyPkt); err != nil { - sent.Dropped.Increment() + sent.dropped.Increment() return } - sent.EchoReply.Increment() + sent.echoReply.Increment() case header.ICMPv4EchoReply: - received.EchoReply.Increment() + received.echoReply.Increment() e.dispatcher.DeliverTransportPacket(header.ICMPv4ProtocolNumber, pkt) case header.ICMPv4DstUnreachable: - received.DstUnreachable.Increment() + received.dstUnreachable.Increment() pkt.Data.TrimFront(header.ICMPv4MinimumSize) switch h.Code() { @@ -243,31 +242,31 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer) { } case header.ICMPv4SrcQuench: - received.SrcQuench.Increment() + received.srcQuench.Increment() case header.ICMPv4Redirect: - received.Redirect.Increment() + received.redirect.Increment() case header.ICMPv4TimeExceeded: - received.TimeExceeded.Increment() + received.timeExceeded.Increment() case header.ICMPv4ParamProblem: - received.ParamProblem.Increment() + received.paramProblem.Increment() case header.ICMPv4Timestamp: - received.Timestamp.Increment() + received.timestamp.Increment() case header.ICMPv4TimestampReply: - received.TimestampReply.Increment() + received.timestampReply.Increment() case header.ICMPv4InfoRequest: - received.InfoRequest.Increment() + received.infoRequest.Increment() case header.ICMPv4InfoReply: - received.InfoReply.Increment() + received.infoReply.Increment() default: - received.Invalid.Increment() + received.invalid.Increment() } } @@ -379,9 +378,17 @@ func (p *protocol) returnError(reason icmpReason, pkt *stack.PacketBuffer) *tcpi } defer route.Release() - sent := p.stack.Stats().ICMP.V4.PacketsSent + p.mu.Lock() + netEP, ok := p.mu.eps[pkt.NICID] + p.mu.Unlock() + if !ok { + return tcpip.ErrNotConnected + } + + sent := netEP.stats.icmp.packetsSent + if !p.stack.AllowICMPMessage() { - sent.RateLimited.Increment() + sent.rateLimited.Increment() return nil } @@ -471,29 +478,29 @@ func (p *protocol) returnError(reason icmpReason, pkt *stack.PacketBuffer) *tcpi icmpPkt.TransportProtocolNumber = header.ICMPv4ProtocolNumber icmpHdr := header.ICMPv4(icmpPkt.TransportHeader().Push(header.ICMPv4MinimumSize)) - var counter *tcpip.StatCounter + var counter tcpip.MultiCounterStat switch reason := reason.(type) { case *icmpReasonPortUnreachable: icmpHdr.SetType(header.ICMPv4DstUnreachable) icmpHdr.SetCode(header.ICMPv4PortUnreachable) - counter = sent.DstUnreachable + counter = sent.dstUnreachable case *icmpReasonProtoUnreachable: icmpHdr.SetType(header.ICMPv4DstUnreachable) icmpHdr.SetCode(header.ICMPv4ProtoUnreachable) - counter = sent.DstUnreachable + counter = sent.dstUnreachable case *icmpReasonTTLExceeded: icmpHdr.SetType(header.ICMPv4TimeExceeded) icmpHdr.SetCode(header.ICMPv4TTLExceeded) - counter = sent.TimeExceeded + counter = sent.timeExceeded case *icmpReasonReassemblyTimeout: icmpHdr.SetType(header.ICMPv4TimeExceeded) icmpHdr.SetCode(header.ICMPv4ReassemblyTimeout) - counter = sent.TimeExceeded + counter = sent.timeExceeded case *icmpReasonParamProblem: icmpHdr.SetType(header.ICMPv4ParamProblem) icmpHdr.SetCode(header.ICMPv4UnusedCode) icmpHdr.SetPointer(reason.pointer) - counter = sent.ParamProblem + counter = sent.paramProblem default: panic(fmt.Sprintf("unsupported ICMP type %T", reason)) } @@ -508,7 +515,7 @@ func (p *protocol) returnError(reason icmpReason, pkt *stack.PacketBuffer) *tcpi }, icmpPkt, ); err != nil { - sent.Dropped.Increment() + sent.dropped.Increment() return err } counter.Increment() diff --git a/pkg/tcpip/network/ipv4/igmp.go b/pkg/tcpip/network/ipv4/igmp.go index d9b5fe6ed..9515fde45 100644 --- a/pkg/tcpip/network/ipv4/igmp.go +++ b/pkg/tcpip/network/ipv4/igmp.go @@ -149,11 +149,10 @@ func (igmp *igmpState) init(ep *endpoint) { // // Precondition: igmp.ep.mu must be locked. func (igmp *igmpState) handleIGMP(pkt *stack.PacketBuffer) { - stats := igmp.ep.protocol.stack.Stats() - received := stats.IGMP.PacketsReceived + received := igmp.ep.stats.igmp.packetsReceived headerView, ok := pkt.Data.PullUp(header.IGMPMinimumSize) if !ok { - received.Invalid.Increment() + received.invalid.Increment() return } h := header.IGMP(headerView) @@ -166,34 +165,34 @@ func (igmp *igmpState) handleIGMP(pkt *stack.PacketBuffer) { h.SetChecksum(wantChecksum) if gotChecksum != wantChecksum { - received.ChecksumErrors.Increment() + received.checksumErrors.Increment() return } switch h.Type() { case header.IGMPMembershipQuery: - received.MembershipQuery.Increment() + received.membershipQuery.Increment() if len(headerView) < header.IGMPQueryMinimumSize { - received.Invalid.Increment() + received.invalid.Increment() return } igmp.handleMembershipQuery(h.GroupAddress(), h.MaxRespTime()) case header.IGMPv1MembershipReport: - received.V1MembershipReport.Increment() + received.v1MembershipReport.Increment() if len(headerView) < header.IGMPReportMinimumSize { - received.Invalid.Increment() + received.invalid.Increment() return } igmp.handleMembershipReport(h.GroupAddress()) case header.IGMPv2MembershipReport: - received.V2MembershipReport.Increment() + received.v2MembershipReport.Increment() if len(headerView) < header.IGMPReportMinimumSize { - received.Invalid.Increment() + received.invalid.Increment() return } igmp.handleMembershipReport(h.GroupAddress()) case header.IGMPLeaveGroup: - received.LeaveGroup.Increment() + received.leaveGroup.Increment() // As per RFC 2236 Section 6, Page 7: "IGMP messages other than Query or // Report, are ignored in all states" @@ -201,7 +200,7 @@ func (igmp *igmpState) handleIGMP(pkt *stack.PacketBuffer) { // As per RFC 2236 Section 2.1 Page 3: "Unrecognized message types should // be silently ignored. New message types may be used by newer versions of // IGMP, by multicast routing protocols, or other uses." - received.Unrecognized.Increment() + received.unrecognized.Increment() } } @@ -272,18 +271,18 @@ func (igmp *igmpState) writePacket(destAddress tcpip.Address, groupAddress tcpip panic(fmt.Sprintf("failed to add IP header: %s", err)) } - sentStats := igmp.ep.protocol.stack.Stats().IGMP.PacketsSent + sentStats := igmp.ep.stats.igmp.packetsSent if err := igmp.ep.nic.WritePacketToRemote(header.EthernetAddressFromMulticastIPv4Address(destAddress), nil /* gso */, ProtocolNumber, pkt); err != nil { - sentStats.Dropped.Increment() + sentStats.dropped.Increment() return false, err } switch igmpType { case header.IGMPv1MembershipReport: - sentStats.V1MembershipReport.Increment() + sentStats.v1MembershipReport.Increment() case header.IGMPv2MembershipReport: - sentStats.V2MembershipReport.Increment() + sentStats.v2MembershipReport.Increment() case header.IGMPLeaveGroup: - sentStats.LeaveGroup.Increment() + sentStats.leaveGroup.Increment() default: panic(fmt.Sprintf("unrecognized igmp type = %d", igmpType)) } diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go index bb25a76fe..7f03696ae 100644 --- a/pkg/tcpip/network/ipv4/ipv4.go +++ b/pkg/tcpip/network/ipv4/ipv4.go @@ -19,6 +19,7 @@ import ( "errors" "fmt" "math" + "reflect" "sync/atomic" "time" @@ -73,6 +74,7 @@ type endpoint struct { nic stack.NetworkInterface dispatcher stack.TransportDispatcher protocol *protocol + stats sharedStats // enabled is set to 1 when the enpoint is enabled and 0 when it is // disabled. @@ -114,9 +116,27 @@ func (p *protocol) NewEndpoint(nic stack.NetworkInterface, _ stack.LinkAddressCa e.mu.addressableEndpointState.Init(e) e.mu.igmp.init(e) e.mu.Unlock() + + tcpip.InitStatCounters(reflect.ValueOf(&e.stats.localStats).Elem()) + + stackStats := p.stack.Stats() + e.stats.ip.Init(&e.stats.localStats.IP, &stackStats.IP) + e.stats.icmp.init(&e.stats.localStats.ICMP, &stackStats.ICMP.V4) + e.stats.igmp.init(&e.stats.localStats.IGMP, &stackStats.IGMP) + + p.mu.Lock() + p.mu.eps[nic.ID()] = e + p.mu.Unlock() + return e } +func (p *protocol) forgetEndpoint(nicID tcpip.NICID) { + p.mu.Lock() + defer p.mu.Unlock() + delete(p.mu.eps, nicID) +} + // Enable implements stack.NetworkEndpoint. func (e *endpoint) Enable() *tcpip.Error { e.mu.Lock() @@ -305,7 +325,7 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.Netw nicName := e.protocol.stack.FindNICNameFromID(e.nic.ID()) if ok := e.protocol.stack.IPTables().Check(stack.Output, pkt, gso, r, "", nicName); !ok { // iptables is telling us to drop the packet. - e.protocol.stack.Stats().IP.IPTablesOutputDropped.Increment() + e.stats.ip.IPTablesOutputDropped.Increment() return nil } @@ -349,9 +369,11 @@ func (e *endpoint) writePacket(r *stack.Route, gso *stack.GSO, pkt *stack.Packet return nil } + stats := e.stats.ip + networkMTU, err := calculateNetworkMTU(e.nic.MTU(), uint32(pkt.NetworkHeader().View().Size())) if err != nil { - r.Stats().IP.OutgoingPacketErrors.Increment() + stats.OutgoingPacketErrors.Increment() return err } @@ -363,16 +385,16 @@ func (e *endpoint) writePacket(r *stack.Route, gso *stack.GSO, pkt *stack.Packet // WritePackets(). It'll be faster but cost more memory. return e.nic.WritePacket(r, gso, ProtocolNumber, fragPkt) }) - r.Stats().IP.PacketsSent.IncrementBy(uint64(sent)) - r.Stats().IP.OutgoingPacketErrors.IncrementBy(uint64(remain)) + stats.PacketsSent.IncrementBy(uint64(sent)) + stats.OutgoingPacketErrors.IncrementBy(uint64(remain)) return err } if err := e.nic.WritePacket(r, gso, ProtocolNumber, pkt); err != nil { - r.Stats().IP.OutgoingPacketErrors.Increment() + stats.OutgoingPacketErrors.Increment() return err } - r.Stats().IP.PacketsSent.Increment() + stats.PacketsSent.Increment() return nil } @@ -385,6 +407,8 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe return pkts.Len(), nil } + stats := e.stats.ip + for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() { if err := e.addIPHeader(r.LocalAddress, r.RemoteAddress, pkt, params, nil /* options */); err != nil { return 0, err @@ -392,7 +416,7 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe networkMTU, err := calculateNetworkMTU(e.nic.MTU(), uint32(pkt.NetworkHeader().View().Size())) if err != nil { - r.Stats().IP.OutgoingPacketErrors.IncrementBy(uint64(pkts.Len())) + stats.OutgoingPacketErrors.IncrementBy(uint64(pkts.Len())) return 0, err } @@ -421,13 +445,13 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe // Fast path: If no packets are to be dropped then we can just invoke the // faster WritePackets API directly. n, err := e.nic.WritePackets(r, gso, pkts, ProtocolNumber) - r.Stats().IP.PacketsSent.IncrementBy(uint64(n)) + stats.PacketsSent.IncrementBy(uint64(n)) if err != nil { - r.Stats().IP.OutgoingPacketErrors.IncrementBy(uint64(pkts.Len() - n)) + stats.OutgoingPacketErrors.IncrementBy(uint64(pkts.Len() - n)) } return n, err } - r.Stats().IP.IPTablesOutputDropped.IncrementBy(uint64(len(dropped))) + stats.IPTablesOutputDropped.IncrementBy(uint64(len(dropped))) // Slow path as we are dropping some packets in the batch degrade to // emitting one packet at a time. @@ -451,15 +475,15 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe } } if err := e.nic.WritePacket(r, gso, ProtocolNumber, pkt); err != nil { - r.Stats().IP.PacketsSent.IncrementBy(uint64(n)) - r.Stats().IP.OutgoingPacketErrors.IncrementBy(uint64(pkts.Len() - n - len(dropped))) + stats.PacketsSent.IncrementBy(uint64(n)) + stats.OutgoingPacketErrors.IncrementBy(uint64(pkts.Len() - n - len(dropped))) // Dropped packets aren't errors, so include them in // the return value. return n + len(dropped), err } n++ } - r.Stats().IP.PacketsSent.IncrementBy(uint64(n)) + stats.PacketsSent.IncrementBy(uint64(n)) // Dropped packets aren't errors, so include them in the return value. return n + len(dropped), nil } @@ -577,11 +601,12 @@ func (e *endpoint) forwardPacket(pkt *stack.PacketBuffer) *tcpip.Error { // HandlePacket is called by the link layer when new ipv4 packets arrive for // this endpoint. func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) { - stats := e.protocol.stack.Stats() - stats.IP.PacketsReceived.Increment() + stats := e.stats.ip + + stats.PacketsReceived.Increment() if !e.isEnabled() { - stats.IP.DisabledPacketsReceived.Increment() + stats.DisabledPacketsReceived.Increment() return } @@ -589,7 +614,7 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) { if !e.nic.IsLoopback() { if ok := e.protocol.stack.IPTables().Check(stack.Prerouting, pkt, nil, nil, e.MainAddress().Address, ""); !ok { // iptables is telling us to drop the packet. - stats.IP.IPTablesPreroutingDropped.Increment() + stats.IPTablesPreroutingDropped.Increment() return } } @@ -601,11 +626,11 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) { // iptables hook. func (e *endpoint) handlePacket(pkt *stack.PacketBuffer) { pkt.NICID = e.nic.ID() - stats := e.protocol.stack.Stats() + stats := e.stats h := header.IPv4(pkt.NetworkHeader().View()) if !h.IsValid(pkt.Data.Size() + pkt.NetworkHeader().View().Size() + pkt.TransportHeader().View().Size()) { - stats.IP.MalformedPacketsReceived.Increment() + stats.ip.MalformedPacketsReceived.Increment() return } @@ -631,7 +656,7 @@ func (e *endpoint) handlePacket(pkt *stack.PacketBuffer) { // is all 1 bits (-0 in 1's complement arithmetic), the check // succeeds. if h.CalculateChecksum() != 0xffff { - stats.IP.MalformedPacketsReceived.Increment() + stats.ip.MalformedPacketsReceived.Increment() return } @@ -643,7 +668,7 @@ func (e *endpoint) handlePacket(pkt *stack.PacketBuffer) { // be one of its own IP addresses (but not a broadcast or // multicast address). if srcAddr == header.IPv4Broadcast || header.IsV4MulticastAddress(srcAddr) { - stats.IP.InvalidSourceAddressesReceived.Increment() + stats.ip.InvalidSourceAddressesReceived.Increment() return } // Make sure the source address is not a subnet-local broadcast address. @@ -651,7 +676,7 @@ func (e *endpoint) handlePacket(pkt *stack.PacketBuffer) { subnet := addressEndpoint.Subnet() addressEndpoint.DecRef() if subnet.IsBroadcast(srcAddr) { - stats.IP.InvalidSourceAddressesReceived.Increment() + stats.ip.InvalidSourceAddressesReceived.Increment() return } } @@ -664,7 +689,7 @@ func (e *endpoint) handlePacket(pkt *stack.PacketBuffer) { pkt.NetworkPacketInfo.LocalAddressBroadcast = subnet.IsBroadcast(dstAddr) || dstAddr == header.IPv4Broadcast } else if !e.IsInGroup(dstAddr) { if !e.protocol.Forwarding() { - stats.IP.InvalidDestinationAddressesReceived.Increment() + stats.ip.InvalidDestinationAddressesReceived.Increment() return } @@ -676,7 +701,7 @@ func (e *endpoint) handlePacket(pkt *stack.PacketBuffer) { // this machine and will not be forwarded. if ok := e.protocol.stack.IPTables().Check(stack.Input, pkt, nil, nil, "", ""); !ok { // iptables is telling us to drop the packet. - stats.IP.IPTablesInputDropped.Increment() + stats.ip.IPTablesInputDropped.Increment() return } @@ -684,8 +709,8 @@ func (e *endpoint) handlePacket(pkt *stack.PacketBuffer) { if pkt.Data.Size()+pkt.TransportHeader().View().Size() == 0 { // Drop the packet as it's marked as a fragment but has // no payload. - stats.IP.MalformedPacketsReceived.Increment() - stats.IP.MalformedFragmentsReceived.Increment() + stats.ip.MalformedPacketsReceived.Increment() + stats.ip.MalformedFragmentsReceived.Increment() return } // The packet is a fragment, let's try to reassemble it. @@ -698,8 +723,8 @@ func (e *endpoint) handlePacket(pkt *stack.PacketBuffer) { // size). Otherwise the packet would've been rejected as invalid before // reaching here. if int(start)+pkt.Data.Size() > header.IPv4MaximumPayloadSize { - stats.IP.MalformedPacketsReceived.Increment() - stats.IP.MalformedFragmentsReceived.Increment() + stats.ip.MalformedPacketsReceived.Increment() + stats.ip.MalformedFragmentsReceived.Increment() return } @@ -720,8 +745,8 @@ func (e *endpoint) handlePacket(pkt *stack.PacketBuffer) { pkt, ) if err != nil { - stats.IP.MalformedPacketsReceived.Increment() - stats.IP.MalformedFragmentsReceived.Increment() + stats.ip.MalformedPacketsReceived.Increment() + stats.ip.MalformedFragmentsReceived.Increment() return } if !ready { @@ -734,7 +759,7 @@ func (e *endpoint) handlePacket(pkt *stack.PacketBuffer) { h.SetTotalLength(uint16(pkt.Data.Size() + len((h)))) h.SetFlagsFragmentOffset(0, 0) } - stats.IP.PacketsDelivered.Increment() + stats.ip.PacketsDelivered.Increment() p := h.TransportProtocol() if p == header.ICMPv4ProtocolNumber { @@ -766,8 +791,8 @@ func (e *endpoint) handlePacket(pkt *stack.PacketBuffer) { errors.Is(err, errIPv4TimestampOptInvalidPointer), errors.Is(err, errIPv4TimestampOptOverflow): _ = e.protocol.returnError(&icmpReasonParamProblem{pointer: aux}, pkt) - stats.MalformedRcvdPackets.Increment() - stats.IP.MalformedPacketsReceived.Increment() + e.protocol.stack.Stats().MalformedRcvdPackets.Increment() + stats.ip.MalformedPacketsReceived.Increment() } return } @@ -800,6 +825,8 @@ func (e *endpoint) Close() { e.disableLocked() e.mu.addressableEndpointState.Cleanup() + + e.protocol.forgetEndpoint(e.nic.ID()) } // AddAndAcquirePermanentAddress implements stack.AddressableEndpoint. @@ -911,6 +938,11 @@ func (e *endpoint) IsInGroup(addr tcpip.Address) bool { return e.mu.igmp.isInGroup(addr) } +// Stats implements stack.NetworkEndpoint. +func (e *endpoint) Stats() stack.NetworkEndpointStats { + return &e.stats.localStats +} + var _ stack.ForwardingNetworkProtocol = (*protocol)(nil) var _ stack.NetworkProtocol = (*protocol)(nil) var _ fragmentation.TimeoutHandler = (*protocol)(nil) @@ -918,6 +950,14 @@ var _ fragmentation.TimeoutHandler = (*protocol)(nil) type protocol struct { stack *stack.Stack + mu struct { + sync.RWMutex + + // eps is keyed by NICID to allow protocol methods to retrieve an endpoint + // when handling a packet, by looking at which NIC handled the packet. + eps map[tcpip.NICID]*endpoint + } + // defaultTTL is the current default TTL for the protocol. Only the // uint8 portion of it is meaningful. // @@ -1095,6 +1135,7 @@ func NewProtocolWithOptions(opts Options) stack.NetworkProtocolFactory { options: opts, } p.fragmentation = fragmentation.NewFragmentation(fragmentblockSize, fragmentation.HighFragThreshold, fragmentation.LowFragThreshold, ReassembleTimeout, s.Clock(), p) + p.mu.eps = make(map[tcpip.NICID]*endpoint) return p } } @@ -1379,7 +1420,7 @@ func handleRecordRoute(rrOpt header.IPv4OptionRecordRoute, localAddress tcpip.Ad // - If there is an error, information as to what it was was. // - The replacement option set. func (e *endpoint) processIPOptions(pkt *stack.PacketBuffer, orig header.IPv4Options, usage optionsUsage) (uint8, header.IPv4Options, error) { - stats := e.protocol.stack.Stats() + stats := e.stats.ip opts := header.IPv4Options(orig) optIter := opts.MakeIterator() @@ -1427,7 +1468,7 @@ func (e *endpoint) processIPOptions(pkt *stack.PacketBuffer, orig header.IPv4Opt optLen := int(option.Size()) switch option := option.(type) { case *header.IPv4OptionTimestamp: - stats.IP.OptionTSReceived.Increment() + stats.OptionTSReceived.Increment() if usage.actions().timestamp != optionRemove { clock := e.protocol.stack.Clock() newBuffer := optIter.RemainingBuffer()[:len(*option)] @@ -1440,7 +1481,7 @@ func (e *endpoint) processIPOptions(pkt *stack.PacketBuffer, orig header.IPv4Opt } case *header.IPv4OptionRecordRoute: - stats.IP.OptionRRReceived.Increment() + stats.OptionRRReceived.Increment() if usage.actions().recordRoute != optionRemove { newBuffer := optIter.RemainingBuffer()[:len(*option)] _ = copy(newBuffer, option.Contents()) @@ -1452,7 +1493,7 @@ func (e *endpoint) processIPOptions(pkt *stack.PacketBuffer, orig header.IPv4Opt } default: - stats.IP.OptionUnknownReceived.Increment() + stats.OptionUnknownReceived.Increment() if usage.actions().unknown == optionPass { newBuffer := optIter.RemainingBuffer()[:optLen] // Arguments already heavily checked.. ignore result. diff --git a/pkg/tcpip/network/ipv4/stats.go b/pkg/tcpip/network/ipv4/stats.go new file mode 100644 index 000000000..7620728f9 --- /dev/null +++ b/pkg/tcpip/network/ipv4/stats.go @@ -0,0 +1,190 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package ipv4 + +import ( + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/network/ip" + "gvisor.dev/gvisor/pkg/tcpip/stack" +) + +var _ stack.IPNetworkEndpointStats = (*Stats)(nil) + +// Stats holds statistics related to the IPv4 protocol family. +type Stats struct { + // IP holds IPv4 statistics. + IP tcpip.IPStats + + // IGMP holds IGMP statistics. + IGMP tcpip.IGMPStats + + // ICMP holds ICMPv4 statistics. + ICMP tcpip.ICMPv4Stats +} + +// IsNetworkEndpointStats implements stack.NetworkEndpointStats. +func (s *Stats) IsNetworkEndpointStats() {} + +// IPStats implements stack.IPNetworkEndointStats +func (s *Stats) IPStats() *tcpip.IPStats { + return &s.IP +} + +type sharedStats struct { + localStats Stats + ip ip.MultiCounterIPStats + icmp multiCounterICMPv4Stats + igmp multiCounterIGMPStats +} + +// LINT.IfChange(multiCounterICMPv4PacketStats) + +type multiCounterICMPv4PacketStats struct { + echo tcpip.MultiCounterStat + echoReply tcpip.MultiCounterStat + dstUnreachable tcpip.MultiCounterStat + srcQuench tcpip.MultiCounterStat + redirect tcpip.MultiCounterStat + timeExceeded tcpip.MultiCounterStat + paramProblem tcpip.MultiCounterStat + timestamp tcpip.MultiCounterStat + timestampReply tcpip.MultiCounterStat + infoRequest tcpip.MultiCounterStat + infoReply tcpip.MultiCounterStat +} + +func (m *multiCounterICMPv4PacketStats) init(a, b *tcpip.ICMPv4PacketStats) { + m.echo.Init(a.Echo, b.Echo) + m.echoReply.Init(a.EchoReply, b.EchoReply) + m.dstUnreachable.Init(a.DstUnreachable, b.DstUnreachable) + m.srcQuench.Init(a.SrcQuench, b.SrcQuench) + m.redirect.Init(a.Redirect, b.Redirect) + m.timeExceeded.Init(a.TimeExceeded, b.TimeExceeded) + m.paramProblem.Init(a.ParamProblem, b.ParamProblem) + m.timestamp.Init(a.Timestamp, b.Timestamp) + m.timestampReply.Init(a.TimestampReply, b.TimestampReply) + m.infoRequest.Init(a.InfoRequest, b.InfoRequest) + m.infoReply.Init(a.InfoReply, b.InfoReply) +} + +// LINT.ThenChange(../../tcpip.go:ICMPv4PacketStats) + +// LINT.IfChange(multiCounterICMPv4SentPacketStats) + +type multiCounterICMPv4SentPacketStats struct { + multiCounterICMPv4PacketStats + dropped tcpip.MultiCounterStat + rateLimited tcpip.MultiCounterStat +} + +func (m *multiCounterICMPv4SentPacketStats) init(a, b *tcpip.ICMPv4SentPacketStats) { + m.multiCounterICMPv4PacketStats.init(&a.ICMPv4PacketStats, &b.ICMPv4PacketStats) + m.dropped.Init(a.Dropped, b.Dropped) + m.rateLimited.Init(a.RateLimited, b.RateLimited) +} + +// LINT.ThenChange(../../tcpip.go:ICMPv4SentPacketStats) + +// LINT.IfChange(multiCounterICMPv4ReceivedPacketStats) + +type multiCounterICMPv4ReceivedPacketStats struct { + multiCounterICMPv4PacketStats + invalid tcpip.MultiCounterStat +} + +func (m *multiCounterICMPv4ReceivedPacketStats) init(a, b *tcpip.ICMPv4ReceivedPacketStats) { + m.multiCounterICMPv4PacketStats.init(&a.ICMPv4PacketStats, &b.ICMPv4PacketStats) + m.invalid.Init(a.Invalid, b.Invalid) +} + +// LINT.ThenChange(../../tcpip.go:ICMPv4ReceivedPacketStats) + +// LINT.IfChange(multiCounterICMPv4Stats) + +type multiCounterICMPv4Stats struct { + packetsSent multiCounterICMPv4SentPacketStats + packetsReceived multiCounterICMPv4ReceivedPacketStats +} + +func (m *multiCounterICMPv4Stats) init(a, b *tcpip.ICMPv4Stats) { + m.packetsSent.init(&a.PacketsSent, &b.PacketsSent) + m.packetsReceived.init(&a.PacketsReceived, &b.PacketsReceived) +} + +// LINT.ThenChange(../../tcpip.go:ICMPv4Stats) + +// LINT.IfChange(multiCounterIGMPPacketStats) + +type multiCounterIGMPPacketStats struct { + membershipQuery tcpip.MultiCounterStat + v1MembershipReport tcpip.MultiCounterStat + v2MembershipReport tcpip.MultiCounterStat + leaveGroup tcpip.MultiCounterStat +} + +func (m *multiCounterIGMPPacketStats) init(a, b *tcpip.IGMPPacketStats) { + m.membershipQuery.Init(a.MembershipQuery, b.MembershipQuery) + m.v1MembershipReport.Init(a.V1MembershipReport, b.V1MembershipReport) + m.v2MembershipReport.Init(a.V2MembershipReport, b.V2MembershipReport) + m.leaveGroup.Init(a.LeaveGroup, b.LeaveGroup) +} + +// LINT.ThenChange(../../tcpip.go:IGMPPacketStats) + +// LINT.IfChange(multiCounterIGMPSentPacketStats) + +type multiCounterIGMPSentPacketStats struct { + multiCounterIGMPPacketStats + dropped tcpip.MultiCounterStat +} + +func (m *multiCounterIGMPSentPacketStats) init(a, b *tcpip.IGMPSentPacketStats) { + m.multiCounterIGMPPacketStats.init(&a.IGMPPacketStats, &b.IGMPPacketStats) + m.dropped.Init(a.Dropped, b.Dropped) +} + +// LINT.ThenChange(../../tcpip.go:IGMPSentPacketStats) + +// LINT.IfChange(multiCounterIGMPReceivedPacketStats) + +type multiCounterIGMPReceivedPacketStats struct { + multiCounterIGMPPacketStats + invalid tcpip.MultiCounterStat + checksumErrors tcpip.MultiCounterStat + unrecognized tcpip.MultiCounterStat +} + +func (m *multiCounterIGMPReceivedPacketStats) init(a, b *tcpip.IGMPReceivedPacketStats) { + m.multiCounterIGMPPacketStats.init(&a.IGMPPacketStats, &b.IGMPPacketStats) + m.invalid.Init(a.Invalid, b.Invalid) + m.checksumErrors.Init(a.ChecksumErrors, b.ChecksumErrors) + m.unrecognized.Init(a.Unrecognized, b.Unrecognized) +} + +// LINT.ThenChange(../../tcpip.go:IGMPReceivedPacketStats) + +// LINT.IfChange(multiCounterIGMPStats) + +type multiCounterIGMPStats struct { + packetsSent multiCounterIGMPSentPacketStats + packetsReceived multiCounterIGMPReceivedPacketStats +} + +func (m *multiCounterIGMPStats) init(a, b *tcpip.IGMPStats) { + m.packetsSent.init(&a.PacketsSent, &b.PacketsSent) + m.packetsReceived.init(&a.PacketsReceived, &b.PacketsReceived) +} + +// LINT.ThenChange(../../tcpip.go:IGMPStats) diff --git a/pkg/tcpip/network/ipv4/stats_test.go b/pkg/tcpip/network/ipv4/stats_test.go new file mode 100644 index 000000000..84641bcf4 --- /dev/null +++ b/pkg/tcpip/network/ipv4/stats_test.go @@ -0,0 +1,101 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package ipv4 + +import ( + "reflect" + "testing" + + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/network/testutil" + "gvisor.dev/gvisor/pkg/tcpip/stack" +) + +var _ stack.NetworkInterface = (*testInterface)(nil) + +type testInterface struct { + stack.NetworkInterface + nicID tcpip.NICID +} + +func (t *testInterface) ID() tcpip.NICID { + return t.nicID +} + +func getKnownNICIDs(proto *protocol) []tcpip.NICID { + var nicIDs []tcpip.NICID + + for k := range proto.mu.eps { + nicIDs = append(nicIDs, k) + } + + return nicIDs +} + +func TestClearEndpointFromProtocolOnClose(t *testing.T) { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, + }) + proto := s.NetworkProtocolInstance(ProtocolNumber).(*protocol) + nic := testInterface{nicID: 1} + ep := proto.NewEndpoint(&nic, nil, nil, nil).(*endpoint) + { + proto.mu.Lock() + foundEP, hasEP := proto.mu.eps[nic.ID()] + nicIDs := getKnownNICIDs(proto) + proto.mu.Unlock() + + if !hasEP { + t.Fatalf("expected to find the nic id %d in the protocol's endpoint map (%v)", nic.ID(), nicIDs) + } + if foundEP != ep { + t.Fatalf("expected protocol to map endpoint %p to nic id %d, but endpoint %p was found instead", ep, nic.ID(), foundEP) + } + } + + ep.Close() + + { + proto.mu.Lock() + _, hasEP := proto.mu.eps[nic.ID()] + nicIDs := getKnownNICIDs(proto) + proto.mu.Unlock() + if hasEP { + t.Fatalf("unexpectedly found an endpoint mapped to the nic id %d in the protocol's known nic ids (%v)", nic.ID(), nicIDs) + } + } +} + +func TestMultiCounterStatsInitialization(t *testing.T) { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, + }) + proto := s.NetworkProtocolInstance(ProtocolNumber).(*protocol) + var nic testInterface + ep := proto.NewEndpoint(&nic, nil, nil, nil).(*endpoint) + // At this point, the Stack's stats and the NetworkEndpoint's stats are + // expected to be bound by a MultiCounterStat. + refStack := s.Stats() + refEP := ep.stats.localStats + if err := testutil.ValidateMultiCounterStats(reflect.ValueOf(&ep.stats.ip).Elem(), []reflect.Value{reflect.ValueOf(&refEP.IP).Elem(), reflect.ValueOf(&refStack.IP).Elem()}); err != nil { + t.Error(err) + } + if err := testutil.ValidateMultiCounterStats(reflect.ValueOf(&ep.stats.icmp).Elem(), []reflect.Value{reflect.ValueOf(&refEP.ICMP).Elem(), reflect.ValueOf(&refStack.ICMP.V4).Elem()}); err != nil { + t.Error(err) + } + if err := testutil.ValidateMultiCounterStats(reflect.ValueOf(&ep.stats.igmp).Elem(), []reflect.Value{reflect.ValueOf(&refEP.IGMP).Elem(), reflect.ValueOf(&refStack.IGMP).Elem()}); err != nil { + t.Error(err) + } +} diff --git a/pkg/tcpip/network/ipv6/BUILD b/pkg/tcpip/network/ipv6/BUILD index afa45aefe..0c5f8d683 100644 --- a/pkg/tcpip/network/ipv6/BUILD +++ b/pkg/tcpip/network/ipv6/BUILD @@ -10,6 +10,7 @@ go_library( "ipv6.go", "mld.go", "ndp.go", + "stats.go", ], visibility = ["//visibility:public"], deps = [ diff --git a/pkg/tcpip/network/ipv6/icmp.go b/pkg/tcpip/network/ipv6/icmp.go index 6ee162713..47e8aa11a 100644 --- a/pkg/tcpip/network/ipv6/icmp.go +++ b/pkg/tcpip/network/ipv6/icmp.go @@ -125,15 +125,14 @@ func getTargetLinkAddr(it header.NDPOptionIterator) (tcpip.LinkAddress, bool) { } func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool) { - stats := e.protocol.stack.Stats().ICMP - sent := stats.V6.PacketsSent - received := stats.V6.PacketsReceived + sent := e.stats.icmp.packetsSent + received := e.stats.icmp.packetsReceived // TODO(gvisor.dev/issue/170): ICMP packets don't have their // TransportHeader fields set. See icmp/protocol.go:protocol.Parse for a // full explanation. v, ok := pkt.Data.PullUp(header.ICMPv6HeaderSize) if !ok { - received.Invalid.Increment() + received.invalid.Increment() return } h := header.ICMPv6(v) @@ -147,7 +146,7 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool) { payload := pkt.Data.Clone(nil) payload.TrimFront(len(h)) if got, want := h.Checksum(), header.ICMPv6Checksum(h, srcAddr, dstAddr, payload); got != want { - received.Invalid.Increment() + received.invalid.Increment() return } @@ -165,10 +164,10 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool) { // TODO(b/112892170): Meaningfully handle all ICMP types. switch icmpType := h.Type(); icmpType { case header.ICMPv6PacketTooBig: - received.PacketTooBig.Increment() + received.packetTooBig.Increment() hdr, ok := pkt.Data.PullUp(header.ICMPv6PacketTooBigMinimumSize) if !ok { - received.Invalid.Increment() + received.invalid.Increment() return } pkt.Data.TrimFront(header.ICMPv6PacketTooBigMinimumSize) @@ -179,10 +178,10 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool) { e.handleControl(stack.ControlPacketTooBig, networkMTU, pkt) case header.ICMPv6DstUnreachable: - received.DstUnreachable.Increment() + received.dstUnreachable.Increment() hdr, ok := pkt.Data.PullUp(header.ICMPv6DstUnreachableMinimumSize) if !ok { - received.Invalid.Increment() + received.invalid.Increment() return } pkt.Data.TrimFront(header.ICMPv6DstUnreachableMinimumSize) @@ -194,9 +193,9 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool) { } case header.ICMPv6NeighborSolicit: - received.NeighborSolicit.Increment() + received.neighborSolicit.Increment() if !isNDPValid() || pkt.Data.Size() < header.ICMPv6NeighborSolicitMinimumSize { - received.Invalid.Increment() + received.invalid.Increment() return } @@ -210,7 +209,7 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool) { // As per RFC 4861 section 4.3, the Target Address MUST NOT be a multicast // address. if header.IsV6MulticastAddress(targetAddr) { - received.Invalid.Increment() + received.invalid.Increment() return } @@ -263,13 +262,13 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool) { if err != nil { // Options are not valid as per the wire format, silently drop the // packet. - received.Invalid.Increment() + received.invalid.Increment() return } sourceLinkAddr, ok = getSourceLinkAddr(it) if !ok { - received.Invalid.Increment() + received.invalid.Increment() return } } @@ -282,11 +281,11 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool) { unspecifiedSource := srcAddr == header.IPv6Any if len(sourceLinkAddr) == 0 { if header.IsV6MulticastAddress(dstAddr) && !unspecifiedSource { - received.Invalid.Increment() + received.invalid.Increment() return } } else if unspecifiedSource { - received.Invalid.Increment() + received.invalid.Increment() return } else if e.nud != nil { e.nud.HandleProbe(srcAddr, header.IPv6ProtocolNumber, sourceLinkAddr, e.protocol) @@ -301,7 +300,7 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool) { // - If the IP source address is the unspecified address, the IP // destination address is a solicited-node multicast address. if unspecifiedSource && !header.IsSolicitedNodeAddr(dstAddr) { - received.Invalid.Increment() + received.invalid.Increment() return } @@ -379,15 +378,15 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool) { // The IP Hop Limit field has a value of 255, i.e., the packet // could not possibly have been forwarded by a router. if err := r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: header.NDPHopLimit, TOS: stack.DefaultTOS}, pkt); err != nil { - sent.Dropped.Increment() + sent.dropped.Increment() return } - sent.NeighborAdvert.Increment() + sent.neighborAdvert.Increment() case header.ICMPv6NeighborAdvert: - received.NeighborAdvert.Increment() + received.neighborAdvert.Increment() if !isNDPValid() || pkt.Data.Size() < header.ICMPv6NeighborAdvertMinimumSize { - received.Invalid.Increment() + received.invalid.Increment() return } @@ -423,7 +422,7 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool) { it, err := na.Options().Iter(false /* check */) if err != nil { // If we have a malformed NDP NA option, drop the packet. - received.Invalid.Increment() + received.invalid.Increment() return } @@ -438,7 +437,7 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool) { // DAD. targetLinkAddr, ok := getTargetLinkAddr(it) if !ok { - received.Invalid.Increment() + received.invalid.Increment() return } @@ -458,10 +457,10 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool) { }) case header.ICMPv6EchoRequest: - received.EchoRequest.Increment() + received.echoRequest.Increment() icmpHdr, ok := pkt.TransportHeader().Consume(header.ICMPv6EchoMinimumSize) if !ok { - received.Invalid.Increment() + received.invalid.Increment() return } @@ -493,27 +492,27 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool) { TTL: r.DefaultTTL(), TOS: stack.DefaultTOS, }, replyPkt); err != nil { - sent.Dropped.Increment() + sent.dropped.Increment() return } - sent.EchoReply.Increment() + sent.echoReply.Increment() case header.ICMPv6EchoReply: - received.EchoReply.Increment() + received.echoReply.Increment() if pkt.Data.Size() < header.ICMPv6EchoMinimumSize { - received.Invalid.Increment() + received.invalid.Increment() return } e.dispatcher.DeliverTransportPacket(header.ICMPv6ProtocolNumber, pkt) case header.ICMPv6TimeExceeded: - received.TimeExceeded.Increment() + received.timeExceeded.Increment() case header.ICMPv6ParamProblem: - received.ParamProblem.Increment() + received.paramProblem.Increment() case header.ICMPv6RouterSolicit: - received.RouterSolicit.Increment() + received.routerSolicit.Increment() // // Validate the RS as per RFC 4861 section 6.1.1. @@ -521,7 +520,7 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool) { // Is the NDP payload of sufficient size to hold a Router Solictation? if !isNDPValid() || pkt.Data.Size()-header.ICMPv6HeaderSize < header.NDPRSMinimumSize { - received.Invalid.Increment() + received.invalid.Increment() return } @@ -530,7 +529,7 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool) { // Is the networking stack operating as a router? if !stack.Forwarding(ProtocolNumber) { // ... No, silently drop the packet. - received.RouterOnlyPacketsDroppedByHost.Increment() + received.routerOnlyPacketsDroppedByHost.Increment() return } @@ -540,13 +539,13 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool) { it, err := rs.Options().Iter(false /* check */) if err != nil { // Options are not valid as per the wire format, silently drop the packet. - received.Invalid.Increment() + received.invalid.Increment() return } sourceLinkAddr, ok := getSourceLinkAddr(it) if !ok { - received.Invalid.Increment() + received.invalid.Increment() return } @@ -557,7 +556,7 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool) { // NOT be included when the source IP address is the unspecified address. // Otherwise, it SHOULD be included on link layers that have addresses. if srcAddr == header.IPv6Any { - received.Invalid.Increment() + received.invalid.Increment() return } @@ -569,7 +568,7 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool) { } case header.ICMPv6RouterAdvert: - received.RouterAdvert.Increment() + received.routerAdvert.Increment() // // Validate the RA as per RFC 4861 section 6.1.2. @@ -577,7 +576,7 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool) { // Is the NDP payload of sufficient size to hold a Router Advertisement? if !isNDPValid() || pkt.Data.Size()-header.ICMPv6HeaderSize < header.NDPRAMinimumSize { - received.Invalid.Increment() + received.invalid.Increment() return } @@ -586,7 +585,7 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool) { // Is the IP Source Address a link-local address? if !header.IsV6LinkLocalAddress(routerAddr) { // ...No, silently drop the packet. - received.Invalid.Increment() + received.invalid.Increment() return } @@ -596,13 +595,13 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool) { it, err := ra.Options().Iter(false /* check */) if err != nil { // Options are not valid as per the wire format, silently drop the packet. - received.Invalid.Increment() + received.invalid.Increment() return } sourceLinkAddr, ok := getSourceLinkAddr(it) if !ok { - received.Invalid.Increment() + received.invalid.Increment() return } @@ -638,26 +637,26 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool) { // link-layer address be modified due to receiving one of the above // messages, the state SHOULD also be set to STALE to provide prompt // verification that the path to the new link-layer address is working." - received.RedirectMsg.Increment() + received.redirectMsg.Increment() if !isNDPValid() { - received.Invalid.Increment() + received.invalid.Increment() return } case header.ICMPv6MulticastListenerQuery, header.ICMPv6MulticastListenerReport, header.ICMPv6MulticastListenerDone: switch icmpType { case header.ICMPv6MulticastListenerQuery: - received.MulticastListenerQuery.Increment() + received.multicastListenerQuery.Increment() case header.ICMPv6MulticastListenerReport: - received.MulticastListenerReport.Increment() + received.multicastListenerReport.Increment() case header.ICMPv6MulticastListenerDone: - received.MulticastListenerDone.Increment() + received.multicastListenerDone.Increment() default: panic(fmt.Sprintf("unrecognized MLD message = %d", icmpType)) } if pkt.Data.Size()-header.ICMPv6HeaderSize < header.MLDMinimumSize { - received.Invalid.Increment() + received.invalid.Increment() return } @@ -676,7 +675,7 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool) { } default: - received.Unrecognized.Increment() + received.unrecognized.Increment() } } @@ -717,16 +716,23 @@ func (p *protocol) LinkAddressRequest(targetAddr, localAddr tcpip.Address, remot ns.Options().Serialize(optsSerializer) packet.SetChecksum(header.ICMPv6Checksum(packet, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{})) - stat := p.stack.Stats().ICMP.V6.PacketsSent + p.mu.Lock() + netEP, ok := p.mu.eps[nic.ID()] + p.mu.Unlock() + if !ok { + return tcpip.ErrNotConnected + } + stat := netEP.stats.icmp.packetsSent + if err := r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{ Protocol: header.ICMPv6ProtocolNumber, TTL: header.NDPHopLimit, }, pkt); err != nil { - stat.Dropped.Increment() + stat.dropped.Increment() return err } - stat.NeighborSolicit.Increment() + stat.neighborSolicit.Increment() return nil } @@ -863,10 +869,17 @@ func (p *protocol) returnError(reason icmpReason, pkt *stack.PacketBuffer) *tcpi } defer route.Release() - stats := p.stack.Stats().ICMP - sent := stats.V6.PacketsSent + p.mu.Lock() + netEP, ok := p.mu.eps[pkt.NICID] + p.mu.Unlock() + if !ok { + return tcpip.ErrNotConnected + } + + sent := netEP.stats.icmp.packetsSent + if !p.stack.AllowICMPMessage() { - sent.RateLimited.Increment() + sent.rateLimited.Increment() return nil } @@ -921,25 +934,25 @@ func (p *protocol) returnError(reason icmpReason, pkt *stack.PacketBuffer) *tcpi newPkt.TransportProtocolNumber = header.ICMPv6ProtocolNumber icmpHdr := header.ICMPv6(newPkt.TransportHeader().Push(header.ICMPv6DstUnreachableMinimumSize)) - var counter *tcpip.StatCounter + var counter tcpip.MultiCounterStat switch reason := reason.(type) { case *icmpReasonParameterProblem: icmpHdr.SetType(header.ICMPv6ParamProblem) icmpHdr.SetCode(reason.code) icmpHdr.SetTypeSpecific(reason.pointer) - counter = sent.ParamProblem + counter = sent.paramProblem case *icmpReasonPortUnreachable: icmpHdr.SetType(header.ICMPv6DstUnreachable) icmpHdr.SetCode(header.ICMPv6PortUnreachable) - counter = sent.DstUnreachable + counter = sent.dstUnreachable case *icmpReasonHopLimitExceeded: icmpHdr.SetType(header.ICMPv6TimeExceeded) icmpHdr.SetCode(header.ICMPv6HopLimitExceeded) - counter = sent.TimeExceeded + counter = sent.timeExceeded case *icmpReasonReassemblyTimeout: icmpHdr.SetType(header.ICMPv6TimeExceeded) icmpHdr.SetCode(header.ICMPv6ReassemblyTimeout) - counter = sent.TimeExceeded + counter = sent.timeExceeded default: panic(fmt.Sprintf("unsupported ICMP type %T", reason)) } @@ -953,7 +966,7 @@ func (p *protocol) returnError(reason icmpReason, pkt *stack.PacketBuffer) *tcpi }, newPkt, ); err != nil { - sent.Dropped.Increment() + sent.dropped.Increment() return err } counter.Increment() diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go index ae4a8f508..37884505e 100644 --- a/pkg/tcpip/network/ipv6/ipv6.go +++ b/pkg/tcpip/network/ipv6/ipv6.go @@ -20,6 +20,7 @@ import ( "fmt" "hash/fnv" "math" + "reflect" "sort" "sync/atomic" "time" @@ -177,6 +178,7 @@ type endpoint struct { dispatcher stack.TransportDispatcher protocol *protocol stack *stack.Stack + stats sharedStats // enabled is set to 1 when the endpoint is enabled and 0 when it is // disabled. @@ -632,7 +634,7 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.Netw nicName := e.protocol.stack.FindNICNameFromID(e.nic.ID()) if ok := e.protocol.stack.IPTables().Check(stack.Output, pkt, gso, r, "", nicName); !ok { // iptables is telling us to drop the packet. - e.protocol.stack.Stats().IP.IPTablesOutputDropped.Increment() + e.stats.ip.IPTablesOutputDropped.Increment() return nil } @@ -675,9 +677,10 @@ func (e *endpoint) writePacket(r *stack.Route, gso *stack.GSO, pkt *stack.Packet return nil } + stats := e.stats.ip networkMTU, err := calculateNetworkMTU(e.nic.MTU(), uint32(pkt.NetworkHeader().View().Size())) if err != nil { - r.Stats().IP.OutgoingPacketErrors.Increment() + stats.OutgoingPacketErrors.Increment() return err } @@ -689,17 +692,17 @@ func (e *endpoint) writePacket(r *stack.Route, gso *stack.GSO, pkt *stack.Packet // WritePackets(). It'll be faster but cost more memory. return e.nic.WritePacket(r, gso, ProtocolNumber, fragPkt) }) - r.Stats().IP.PacketsSent.IncrementBy(uint64(sent)) - r.Stats().IP.OutgoingPacketErrors.IncrementBy(uint64(remain)) + stats.PacketsSent.IncrementBy(uint64(sent)) + stats.OutgoingPacketErrors.IncrementBy(uint64(remain)) return err } if err := e.nic.WritePacket(r, gso, ProtocolNumber, pkt); err != nil { - r.Stats().IP.OutgoingPacketErrors.Increment() + stats.OutgoingPacketErrors.Increment() return err } - r.Stats().IP.PacketsSent.Increment() + stats.PacketsSent.Increment() return nil } @@ -712,6 +715,7 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe return pkts.Len(), nil } + stats := e.stats.ip linkMTU := e.nic.MTU() for pb := pkts.Front(); pb != nil; pb = pb.Next() { if err := e.addIPHeader(r.LocalAddress, r.RemoteAddress, pb, params, nil /* extensionHeaders */); err != nil { @@ -720,7 +724,7 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe networkMTU, err := calculateNetworkMTU(linkMTU, uint32(pb.NetworkHeader().View().Size())) if err != nil { - r.Stats().IP.OutgoingPacketErrors.IncrementBy(uint64(pkts.Len())) + stats.OutgoingPacketErrors.IncrementBy(uint64(pkts.Len())) return 0, err } if packetMustBeFragmented(pb, networkMTU, gso) { @@ -733,7 +737,7 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe pb = fragPkt return nil }); err != nil { - r.Stats().IP.OutgoingPacketErrors.IncrementBy(uint64(pkts.Len())) + stats.OutgoingPacketErrors.IncrementBy(uint64(pkts.Len())) return 0, err } // Remove the packet that was just fragmented and process the rest. @@ -749,13 +753,13 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe // Fast path: If no packets are to be dropped then we can just invoke the // faster WritePackets API directly. n, err := e.nic.WritePackets(r, gso, pkts, ProtocolNumber) - r.Stats().IP.PacketsSent.IncrementBy(uint64(n)) + stats.PacketsSent.IncrementBy(uint64(n)) if err != nil { - r.Stats().IP.OutgoingPacketErrors.IncrementBy(uint64(pkts.Len() - n)) + stats.OutgoingPacketErrors.IncrementBy(uint64(pkts.Len() - n)) } return n, err } - r.Stats().IP.IPTablesOutputDropped.IncrementBy(uint64(len(dropped))) + stats.IPTablesOutputDropped.IncrementBy(uint64(len(dropped))) // Slow path as we are dropping some packets in the batch degrade to // emitting one packet at a time. @@ -779,8 +783,8 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe } } if err := e.nic.WritePacket(r, gso, ProtocolNumber, pkt); err != nil { - r.Stats().IP.PacketsSent.IncrementBy(uint64(n)) - r.Stats().IP.OutgoingPacketErrors.IncrementBy(uint64(pkts.Len() - n + len(dropped))) + stats.PacketsSent.IncrementBy(uint64(n)) + stats.OutgoingPacketErrors.IncrementBy(uint64(pkts.Len() - n + len(dropped))) // Dropped packets aren't errors, so include them in // the return value. return n + len(dropped), err @@ -788,7 +792,7 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe n++ } - r.Stats().IP.PacketsSent.IncrementBy(uint64(n)) + stats.PacketsSent.IncrementBy(uint64(n)) // Dropped packets aren't errors, so include them in the return value. return n + len(dropped), nil } @@ -882,11 +886,12 @@ func (e *endpoint) forwardPacket(pkt *stack.PacketBuffer) *tcpip.Error { // HandlePacket is called by the link layer when new ipv6 packets arrive for // this endpoint. func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) { - stats := e.protocol.stack.Stats() - stats.IP.PacketsReceived.Increment() + stats := e.stats.ip + + stats.PacketsReceived.Increment() if !e.isEnabled() { - stats.IP.DisabledPacketsReceived.Increment() + stats.DisabledPacketsReceived.Increment() return } @@ -894,7 +899,7 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) { if !e.nic.IsLoopback() { if ok := e.protocol.stack.IPTables().Check(stack.Prerouting, pkt, nil, nil, e.MainAddress().Address, ""); !ok { // iptables is telling us to drop the packet. - stats.IP.IPTablesPreroutingDropped.Increment() + stats.IPTablesPreroutingDropped.Increment() return } } @@ -906,11 +911,11 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) { // iptables hook. func (e *endpoint) handlePacket(pkt *stack.PacketBuffer) { pkt.NICID = e.nic.ID() - stats := e.protocol.stack.Stats() + stats := e.stats.ip h := header.IPv6(pkt.NetworkHeader().View()) if !h.IsValid(pkt.Data.Size() + pkt.NetworkHeader().View().Size() + pkt.TransportHeader().View().Size()) { - stats.IP.MalformedPacketsReceived.Increment() + stats.MalformedPacketsReceived.Increment() return } srcAddr := h.SourceAddress() @@ -920,7 +925,7 @@ func (e *endpoint) handlePacket(pkt *stack.PacketBuffer) { // Multicast addresses must not be used as source addresses in IPv6 // packets or appear in any Routing header. if header.IsV6MulticastAddress(srcAddr) { - stats.IP.InvalidSourceAddressesReceived.Increment() + stats.InvalidSourceAddressesReceived.Increment() return } @@ -930,7 +935,7 @@ func (e *endpoint) handlePacket(pkt *stack.PacketBuffer) { addressEndpoint.DecRef() } else if !e.IsInGroup(dstAddr) { if !e.protocol.Forwarding() { - stats.IP.InvalidDestinationAddressesReceived.Increment() + stats.InvalidDestinationAddressesReceived.Increment() return } @@ -952,7 +957,7 @@ func (e *endpoint) handlePacket(pkt *stack.PacketBuffer) { // this machine and need not be forwarded. if ok := e.protocol.stack.IPTables().Check(stack.Input, pkt, nil, nil, "", ""); !ok { // iptables is telling us to drop the packet. - stats.IP.IPTablesInputDropped.Increment() + stats.IPTablesInputDropped.Increment() return } @@ -962,7 +967,7 @@ func (e *endpoint) handlePacket(pkt *stack.PacketBuffer) { previousHeaderStart := it.HeaderOffset() extHdr, done, err := it.Next() if err != nil { - stats.IP.MalformedPacketsReceived.Increment() + stats.MalformedPacketsReceived.Increment() return } if done { @@ -986,7 +991,7 @@ func (e *endpoint) handlePacket(pkt *stack.PacketBuffer) { for { opt, done, err := optsIt.Next() if err != nil { - stats.IP.MalformedPacketsReceived.Increment() + stats.MalformedPacketsReceived.Increment() return } if done { @@ -1075,8 +1080,8 @@ func (e *endpoint) handlePacket(pkt *stack.PacketBuffer) { for { it, done, err := it.Next() if err != nil { - stats.IP.MalformedPacketsReceived.Increment() - stats.IP.MalformedFragmentsReceived.Increment() + stats.MalformedPacketsReceived.Increment() + stats.MalformedFragmentsReceived.Increment() return } if done { @@ -1103,8 +1108,8 @@ func (e *endpoint) handlePacket(pkt *stack.PacketBuffer) { switch lastHdr.(type) { case header.IPv6RawPayloadHeader: default: - stats.IP.MalformedPacketsReceived.Increment() - stats.IP.MalformedFragmentsReceived.Increment() + stats.MalformedPacketsReceived.Increment() + stats.MalformedFragmentsReceived.Increment() return } } @@ -1112,8 +1117,8 @@ func (e *endpoint) handlePacket(pkt *stack.PacketBuffer) { fragmentPayloadLen := rawPayload.Buf.Size() if fragmentPayloadLen == 0 { // Drop the packet as it's marked as a fragment but has no payload. - stats.IP.MalformedPacketsReceived.Increment() - stats.IP.MalformedFragmentsReceived.Increment() + stats.MalformedPacketsReceived.Increment() + stats.MalformedFragmentsReceived.Increment() return } @@ -1126,8 +1131,8 @@ func (e *endpoint) handlePacket(pkt *stack.PacketBuffer) { // of the fragment, pointing to the Payload Length field of the // fragment packet. if extHdr.More() && fragmentPayloadLen%header.IPv6FragmentExtHdrFragmentOffsetBytesPerUnit != 0 { - stats.IP.MalformedPacketsReceived.Increment() - stats.IP.MalformedFragmentsReceived.Increment() + stats.MalformedPacketsReceived.Increment() + stats.MalformedFragmentsReceived.Increment() _ = e.protocol.returnError(&icmpReasonParameterProblem{ code: header.ICMPv6ErroneousHeader, pointer: header.IPv6PayloadLenOffset, @@ -1147,8 +1152,8 @@ func (e *endpoint) handlePacket(pkt *stack.PacketBuffer) { // the fragment, pointing to the Fragment Offset field of the fragment // packet. if int(start)+fragmentPayloadLen > header.IPv6MaximumPayloadSize { - stats.IP.MalformedPacketsReceived.Increment() - stats.IP.MalformedFragmentsReceived.Increment() + stats.MalformedPacketsReceived.Increment() + stats.MalformedFragmentsReceived.Increment() _ = e.protocol.returnError(&icmpReasonParameterProblem{ code: header.ICMPv6ErroneousHeader, pointer: fragmentFieldOffset, @@ -1173,8 +1178,8 @@ func (e *endpoint) handlePacket(pkt *stack.PacketBuffer) { pkt, ) if err != nil { - stats.IP.MalformedPacketsReceived.Increment() - stats.IP.MalformedFragmentsReceived.Increment() + stats.MalformedPacketsReceived.Increment() + stats.MalformedFragmentsReceived.Increment() return } @@ -1194,7 +1199,7 @@ func (e *endpoint) handlePacket(pkt *stack.PacketBuffer) { for { opt, done, err := optsIt.Next() if err != nil { - stats.IP.MalformedPacketsReceived.Increment() + stats.MalformedPacketsReceived.Increment() return } if done { @@ -1244,12 +1249,12 @@ func (e *endpoint) handlePacket(pkt *stack.PacketBuffer) { extHdr.Buf.TrimFront(pkt.TransportHeader().View().Size()) pkt.Data = extHdr.Buf - stats.IP.PacketsDelivered.Increment() + stats.PacketsDelivered.Increment() if p := tcpip.TransportProtocolNumber(extHdr.Identifier); p == header.ICMPv6ProtocolNumber { pkt.TransportProtocolNumber = p e.handleICMP(pkt, hasFragmentHeader) } else { - stats.IP.PacketsDelivered.Increment() + stats.PacketsDelivered.Increment() switch res := e.dispatcher.DeliverTransportPacket(p, pkt); res { case stack.TransportPacketHandled: case stack.TransportPacketDestinationPortUnreachable: @@ -1314,7 +1319,7 @@ func (e *endpoint) Close() { e.mu.addressableEndpointState.Cleanup() e.mu.Unlock() - e.protocol.forgetEndpoint(e) + e.protocol.forgetEndpoint(e.nic.ID()) } // NetworkProtocolNumber implements stack.NetworkEndpoint.NetworkProtocolNumber. @@ -1662,6 +1667,11 @@ func (e *endpoint) IsInGroup(addr tcpip.Address) bool { return e.mu.mld.isInGroup(addr) } +// Stats implements stack.NetworkEndpoint. +func (e *endpoint) Stats() stack.NetworkEndpointStats { + return &e.stats.localStats +} + var _ stack.ForwardingNetworkProtocol = (*protocol)(nil) var _ stack.NetworkProtocol = (*protocol)(nil) var _ fragmentation.TimeoutHandler = (*protocol)(nil) @@ -1673,7 +1683,9 @@ type protocol struct { mu struct { sync.RWMutex - eps map[*endpoint]struct{} + // eps is keyed by NICID to allow protocol methods to retrieve an endpoint + // when handling a packet, by looking at which NIC handled the packet. + eps map[tcpip.NICID]*endpoint } ids []uint32 @@ -1730,16 +1742,21 @@ func (p *protocol) NewEndpoint(nic stack.NetworkInterface, linkAddrCache stack.L e.mu.mld.init(e) e.mu.Unlock() + stackStats := p.stack.Stats() + tcpip.InitStatCounters(reflect.ValueOf(&e.stats.localStats).Elem()) + e.stats.ip.Init(&e.stats.localStats.IP, &stackStats.IP) + e.stats.icmp.init(&e.stats.localStats.ICMP, &stackStats.ICMP.V6) + p.mu.Lock() defer p.mu.Unlock() - p.mu.eps[e] = struct{}{} + p.mu.eps[nic.ID()] = e return e } -func (p *protocol) forgetEndpoint(e *endpoint) { +func (p *protocol) forgetEndpoint(nicID tcpip.NICID) { p.mu.Lock() defer p.mu.Unlock() - delete(p.mu.eps, e) + delete(p.mu.eps, nicID) } // SetOption implements NetworkProtocol.SetOption. @@ -1814,7 +1831,7 @@ func (p *protocol) SetForwarding(v bool) { return } - for ep := range p.mu.eps { + for _, ep := range p.mu.eps { ep.transitionForwarding(v) } } @@ -1906,7 +1923,7 @@ func NewProtocolWithOptions(opts Options) stack.NetworkProtocolFactory { hashIV: hashIV, } p.fragmentation = fragmentation.NewFragmentation(header.IPv6FragmentExtHdrFragmentOffsetBytesPerUnit, fragmentation.HighFragThreshold, fragmentation.LowFragThreshold, ReassembleTimeout, s.Clock(), p) - p.mu.eps = make(map[*endpoint]struct{}) + p.mu.eps = make(map[tcpip.NICID]*endpoint) p.SetDefaultTTL(DefaultTTL) return p } diff --git a/pkg/tcpip/network/ipv6/ipv6_test.go b/pkg/tcpip/network/ipv6/ipv6_test.go index b65c9d060..aa892d043 100644 --- a/pkg/tcpip/network/ipv6/ipv6_test.go +++ b/pkg/tcpip/network/ipv6/ipv6_test.go @@ -21,6 +21,7 @@ import ( "io/ioutil" "math" "net" + "reflect" "testing" "github.com/google/go-cmp/cmp" @@ -2582,18 +2583,33 @@ func (lm *limitedMatcher) Match(stack.Hook, *stack.PacketBuffer, string) (bool, return false, false } +func getKnownNICIDs(proto *protocol) []tcpip.NICID { + var nicIDs []tcpip.NICID + + for k := range proto.mu.eps { + nicIDs = append(nicIDs, k) + } + + return nicIDs +} + func TestClearEndpointFromProtocolOnClose(t *testing.T) { s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, }) proto := s.NetworkProtocolInstance(ProtocolNumber).(*protocol) - ep := proto.NewEndpoint(&testInterface{}, nil, nil, nil).(*endpoint) + var nic testInterface + ep := proto.NewEndpoint(&nic, nil, nil, nil).(*endpoint) { proto.mu.Lock() - _, hasEP := proto.mu.eps[ep] + foundEP, hasEP := proto.mu.eps[nic.ID()] + nicIDs := getKnownNICIDs(proto) proto.mu.Unlock() if !hasEP { - t.Fatalf("expected protocol to have ep = %p in set of endpoints", ep) + t.Fatalf("expected to find the nic id %d in the protocol's known nic ids (%v)", nic.ID(), nicIDs) + } + if foundEP != ep { + t.Fatalf("expected protocol to map endpoint %p to nic id %d, but endpoint %p was found instead", ep, nic.ID(), foundEP) } } @@ -2601,10 +2617,11 @@ func TestClearEndpointFromProtocolOnClose(t *testing.T) { { proto.mu.Lock() - _, hasEP := proto.mu.eps[ep] + _, hasEP := proto.mu.eps[nic.ID()] + nicIDs := getKnownNICIDs(proto) proto.mu.Unlock() if hasEP { - t.Fatalf("unexpectedly found ep = %p in set of protocol's endpoints", ep) + t.Fatalf("unexpectedly found an endpoint mapped to the nic id %d in the protocol's known nic ids (%v)", nic.ID(), nicIDs) } } } @@ -3053,3 +3070,22 @@ func TestForwarding(t *testing.T) { }) } } + +func TestMultiCounterStatsInitialization(t *testing.T) { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, + }) + proto := s.NetworkProtocolInstance(ProtocolNumber).(*protocol) + var nic testInterface + ep := proto.NewEndpoint(&nic, nil, nil, nil).(*endpoint) + // At this point, the Stack's stats and the NetworkEndpoint's stats are + // supposed to be bound. + refStack := s.Stats() + refEP := ep.stats.localStats + if err := testutil.ValidateMultiCounterStats(reflect.ValueOf(&ep.stats.ip).Elem(), []reflect.Value{reflect.ValueOf(&refStack.IP).Elem(), reflect.ValueOf(&refEP.IP).Elem()}); err != nil { + t.Error(err) + } + if err := testutil.ValidateMultiCounterStats(reflect.ValueOf(&ep.stats.icmp).Elem(), []reflect.Value{reflect.ValueOf(&refStack.ICMP.V6).Elem(), reflect.ValueOf(&refEP.ICMP).Elem()}); err != nil { + t.Error(err) + } +} diff --git a/pkg/tcpip/network/ipv6/mld.go b/pkg/tcpip/network/ipv6/mld.go index ec54d88cc..78d86e523 100644 --- a/pkg/tcpip/network/ipv6/mld.go +++ b/pkg/tcpip/network/ipv6/mld.go @@ -167,13 +167,13 @@ func (mld *mldState) sendQueuedReports() { // // Precondition: mld.ep.mu must be read locked. func (mld *mldState) writePacket(destAddress, groupAddress tcpip.Address, mldType header.ICMPv6Type) (bool, *tcpip.Error) { - sentStats := mld.ep.protocol.stack.Stats().ICMP.V6.PacketsSent - var mldStat *tcpip.StatCounter + sentStats := mld.ep.stats.icmp.packetsSent + var mldStat tcpip.MultiCounterStat switch mldType { case header.ICMPv6MulticastListenerReport: - mldStat = sentStats.MulticastListenerReport + mldStat = sentStats.multicastListenerReport case header.ICMPv6MulticastListenerDone: - mldStat = sentStats.MulticastListenerDone + mldStat = sentStats.multicastListenerDone default: panic(fmt.Sprintf("unrecognized mld type = %d", mldType)) } @@ -256,7 +256,7 @@ func (mld *mldState) writePacket(destAddress, groupAddress tcpip.Address, mldTyp panic(fmt.Sprintf("failed to add IP header: %s", err)) } if err := mld.ep.nic.WritePacketToRemote(header.EthernetAddressFromMulticastIPv6Address(destAddress), nil /* gso */, ProtocolNumber, pkt); err != nil { - sentStats.Dropped.Increment() + sentStats.dropped.Increment() return false, err } mldStat.Increment() diff --git a/pkg/tcpip/network/ipv6/ndp.go b/pkg/tcpip/network/ipv6/ndp.go index 1d8fee50b..41112a0c4 100644 --- a/pkg/tcpip/network/ipv6/ndp.go +++ b/pkg/tcpip/network/ipv6/ndp.go @@ -731,7 +731,7 @@ func (ndp *ndpState) sendDADPacket(addr tcpip.Address, addressEndpoint stack.Add Data: buffer.View(icmp).ToVectorisedView(), }) - sent := ndp.ep.protocol.stack.Stats().ICMP.V6.PacketsSent + sent := ndp.ep.stats.icmp.packetsSent if err := ndp.ep.addIPHeader(header.IPv6Any, snmc, pkt, stack.NetworkHeaderParams{ Protocol: header.ICMPv6ProtocolNumber, TTL: header.NDPHopLimit, @@ -740,10 +740,11 @@ func (ndp *ndpState) sendDADPacket(addr tcpip.Address, addressEndpoint stack.Add } if err := ndp.ep.nic.WritePacketToRemote(header.EthernetAddressFromMulticastIPv6Address(snmc), nil /* gso */, ProtocolNumber, pkt); err != nil { - sent.Dropped.Increment() + sent.dropped.Increment() return err } - sent.NeighborSolicit.Increment() + sent.neighborSolicit.Increment() + return nil } @@ -1855,7 +1856,7 @@ func (ndp *ndpState) startSolicitingRouters() { Data: buffer.View(icmpData).ToVectorisedView(), }) - sent := ndp.ep.protocol.stack.Stats().ICMP.V6.PacketsSent + sent := ndp.ep.stats.icmp.packetsSent if err := ndp.ep.addIPHeader(localAddr, header.IPv6AllRoutersMulticastAddress, pkt, stack.NetworkHeaderParams{ Protocol: header.ICMPv6ProtocolNumber, TTL: header.NDPHopLimit, @@ -1863,12 +1864,12 @@ func (ndp *ndpState) startSolicitingRouters() { panic(fmt.Sprintf("failed to add IP header: %s", err)) } if err := ndp.ep.nic.WritePacketToRemote(header.EthernetAddressFromMulticastIPv6Address(header.IPv6AllRoutersMulticastAddress), nil /* gso */, ProtocolNumber, pkt); err != nil { - sent.Dropped.Increment() + sent.dropped.Increment() log.Printf("startSolicitingRouters: error writing NDP router solicit message on NIC(%d); err = %s", ndp.ep.nic.ID(), err) // Don't send any more messages if we had an error. remaining = 0 } else { - sent.RouterSolicit.Increment() + sent.routerSolicit.Increment() remaining-- } diff --git a/pkg/tcpip/network/ipv6/stats.go b/pkg/tcpip/network/ipv6/stats.go new file mode 100644 index 000000000..a2f2f4f78 --- /dev/null +++ b/pkg/tcpip/network/ipv6/stats.go @@ -0,0 +1,132 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package ipv6 + +import ( + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/network/ip" + "gvisor.dev/gvisor/pkg/tcpip/stack" +) + +var _ stack.IPNetworkEndpointStats = (*Stats)(nil) + +// Stats holds statistics related to the IPv6 protocol family. +type Stats struct { + // IP holds IPv6 statistics. + IP tcpip.IPStats + + // ICMP holds ICMPv6 statistics. + ICMP tcpip.ICMPv6Stats +} + +// IsNetworkEndpointStats implements stack.NetworkEndpointStats. +func (s *Stats) IsNetworkEndpointStats() {} + +// IPStats implements stack.IPNetworkEndointStats +func (s *Stats) IPStats() *tcpip.IPStats { + return &s.IP +} + +type sharedStats struct { + localStats Stats + ip ip.MultiCounterIPStats + icmp multiCounterICMPv6Stats +} + +// LINT.IfChange(multiCounterICMPv6PacketStats) + +type multiCounterICMPv6PacketStats struct { + echoRequest tcpip.MultiCounterStat + echoReply tcpip.MultiCounterStat + dstUnreachable tcpip.MultiCounterStat + packetTooBig tcpip.MultiCounterStat + timeExceeded tcpip.MultiCounterStat + paramProblem tcpip.MultiCounterStat + routerSolicit tcpip.MultiCounterStat + routerAdvert tcpip.MultiCounterStat + neighborSolicit tcpip.MultiCounterStat + neighborAdvert tcpip.MultiCounterStat + redirectMsg tcpip.MultiCounterStat + multicastListenerQuery tcpip.MultiCounterStat + multicastListenerReport tcpip.MultiCounterStat + multicastListenerDone tcpip.MultiCounterStat +} + +func (m *multiCounterICMPv6PacketStats) init(a, b *tcpip.ICMPv6PacketStats) { + m.echoRequest.Init(a.EchoRequest, b.EchoRequest) + m.echoReply.Init(a.EchoReply, b.EchoReply) + m.dstUnreachable.Init(a.DstUnreachable, b.DstUnreachable) + m.packetTooBig.Init(a.PacketTooBig, b.PacketTooBig) + m.timeExceeded.Init(a.TimeExceeded, b.TimeExceeded) + m.paramProblem.Init(a.ParamProblem, b.ParamProblem) + m.routerSolicit.Init(a.RouterSolicit, b.RouterSolicit) + m.routerAdvert.Init(a.RouterAdvert, b.RouterAdvert) + m.neighborSolicit.Init(a.NeighborSolicit, b.NeighborSolicit) + m.neighborAdvert.Init(a.NeighborAdvert, b.NeighborAdvert) + m.redirectMsg.Init(a.RedirectMsg, b.RedirectMsg) + m.multicastListenerQuery.Init(a.MulticastListenerQuery, b.MulticastListenerQuery) + m.multicastListenerReport.Init(a.MulticastListenerReport, b.MulticastListenerReport) + m.multicastListenerDone.Init(a.MulticastListenerDone, b.MulticastListenerDone) +} + +// LINT.ThenChange(../../tcpip.go:ICMPv6PacketStats) + +// LINT.IfChange(multiCounterICMPv6SentPacketStats) + +type multiCounterICMPv6SentPacketStats struct { + multiCounterICMPv6PacketStats + dropped tcpip.MultiCounterStat + rateLimited tcpip.MultiCounterStat +} + +func (m *multiCounterICMPv6SentPacketStats) init(a, b *tcpip.ICMPv6SentPacketStats) { + m.multiCounterICMPv6PacketStats.init(&a.ICMPv6PacketStats, &b.ICMPv6PacketStats) + m.dropped.Init(a.Dropped, b.Dropped) + m.rateLimited.Init(a.RateLimited, b.RateLimited) +} + +// LINT.ThenChange(../../tcpip.go:ICMPv6SentPacketStats) + +// LINT.IfChange(multiCounterICMPv6ReceivedPacketStats) + +type multiCounterICMPv6ReceivedPacketStats struct { + multiCounterICMPv6PacketStats + unrecognized tcpip.MultiCounterStat + invalid tcpip.MultiCounterStat + routerOnlyPacketsDroppedByHost tcpip.MultiCounterStat +} + +func (m *multiCounterICMPv6ReceivedPacketStats) init(a, b *tcpip.ICMPv6ReceivedPacketStats) { + m.multiCounterICMPv6PacketStats.init(&a.ICMPv6PacketStats, &b.ICMPv6PacketStats) + m.unrecognized.Init(a.Unrecognized, b.Unrecognized) + m.invalid.Init(a.Invalid, b.Invalid) + m.routerOnlyPacketsDroppedByHost.Init(a.RouterOnlyPacketsDroppedByHost, b.RouterOnlyPacketsDroppedByHost) +} + +// LINT.ThenChange(../../tcpip.go:ICMPv6ReceivedPacketStats) + +// LINT.IfChange(multiCounterICMPv6Stats) + +type multiCounterICMPv6Stats struct { + packetsSent multiCounterICMPv6SentPacketStats + packetsReceived multiCounterICMPv6ReceivedPacketStats +} + +func (m *multiCounterICMPv6Stats) init(a, b *tcpip.ICMPv6Stats) { + m.packetsSent.init(&a.PacketsSent, &b.PacketsSent) + m.packetsReceived.init(&a.PacketsReceived, &b.PacketsReceived) +} + +// LINT.ThenChange(../../tcpip.go:ICMPv6Stats) diff --git a/pkg/tcpip/network/testutil/BUILD b/pkg/tcpip/network/testutil/BUILD index d0ffc299a..652b92a21 100644 --- a/pkg/tcpip/network/testutil/BUILD +++ b/pkg/tcpip/network/testutil/BUILD @@ -6,6 +6,7 @@ go_library( name = "testutil", srcs = [ "testutil.go", + "testutil_unsafe.go", ], visibility = [ "//pkg/tcpip/network/fragmentation:__pkg__", diff --git a/pkg/tcpip/network/testutil/testutil.go b/pkg/tcpip/network/testutil/testutil.go index 3af44991f..9bd009374 100644 --- a/pkg/tcpip/network/testutil/testutil.go +++ b/pkg/tcpip/network/testutil/testutil.go @@ -19,6 +19,8 @@ package testutil import ( "fmt" "math/rand" + "reflect" + "strings" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" @@ -127,3 +129,69 @@ func MakeRandPkt(transportHeaderLength int, extraHeaderReserveLength int, viewSi } return pkt } + +func checkFieldCounts(ref, multi reflect.Value) error { + refTypeName := ref.Type().Name() + multiTypeName := multi.Type().Name() + refNumField := ref.NumField() + multiNumField := multi.NumField() + + if refNumField != multiNumField { + return fmt.Errorf("type %s has an incorrect number of fields: got = %d, want = %d (same as type %s)", multiTypeName, multiNumField, refNumField, refTypeName) + } + + return nil +} + +func validateField(ref reflect.Value, refName string, m tcpip.MultiCounterStat, multiName string) error { + s, ok := ref.Addr().Interface().(**tcpip.StatCounter) + if !ok { + return fmt.Errorf("expected ref type's to be *StatCounter, but its type is %s", ref.Type().Elem().Name()) + } + + // The field names are expected to match (case insensitive). + if !strings.EqualFold(refName, multiName) { + return fmt.Errorf("wrong field name: got = %s, want = %s", multiName, refName) + } + + base := (*s).Value() + m.Increment() + if (*s).Value() != base+1 { + return fmt.Errorf("updates to the '%s MultiCounterStat' counters are not reflected in the '%s CounterStat'", multiName, refName) + } + + return nil +} + +// ValidateMultiCounterStats verifies that every counter stored in multi is +// correctly tracking its counterpart in the given counters. +func ValidateMultiCounterStats(multi reflect.Value, counters []reflect.Value) error { + for _, c := range counters { + if err := checkFieldCounts(c, multi); err != nil { + return err + } + } + + for i := 0; i < multi.NumField(); i++ { + multiName := multi.Type().Field(i).Name + multiUnsafe := unsafeExposeUnexportedFields(multi.Field(i)) + + if m, ok := multiUnsafe.Addr().Interface().(*tcpip.MultiCounterStat); ok { + for _, c := range counters { + if err := validateField(unsafeExposeUnexportedFields(c.Field(i)), c.Type().Field(i).Name, *m, multiName); err != nil { + return err + } + } + } else { + var countersNextField []reflect.Value + for _, c := range counters { + countersNextField = append(countersNextField, c.Field(i)) + } + if err := ValidateMultiCounterStats(multi.Field(i), countersNextField); err != nil { + return err + } + } + } + + return nil +} diff --git a/pkg/tcpip/network/testutil/testutil_unsafe.go b/pkg/tcpip/network/testutil/testutil_unsafe.go new file mode 100644 index 000000000..5ff764800 --- /dev/null +++ b/pkg/tcpip/network/testutil/testutil_unsafe.go @@ -0,0 +1,26 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package testutil + +import ( + "reflect" + "unsafe" +) + +// unsafeExposeUnexportedFields takes a Value and returns a version of it in +// which even unexported fields can be read and written. +func unsafeExposeUnexportedFields(a reflect.Value) reflect.Value { + return reflect.NewAt(a.Type(), unsafe.Pointer(a.UnsafeAddr())).Elem() +} |