diff options
Diffstat (limited to 'pkg/tcpip/network/ipv4')
-rw-r--r-- | pkg/tcpip/network/ipv4/BUILD | 13 | ||||
-rw-r--r-- | pkg/tcpip/network/ipv4/icmp.go | 69 | ||||
-rw-r--r-- | pkg/tcpip/network/ipv4/igmp.go | 33 | ||||
-rw-r--r-- | pkg/tcpip/network/ipv4/ipv4.go | 115 | ||||
-rw-r--r-- | pkg/tcpip/network/ipv4/stats.go | 190 | ||||
-rw-r--r-- | pkg/tcpip/network/ipv4/stats_test.go | 101 |
6 files changed, 436 insertions, 85 deletions
diff --git a/pkg/tcpip/network/ipv4/BUILD b/pkg/tcpip/network/ipv4/BUILD index 32f53f217..330a7d170 100644 --- a/pkg/tcpip/network/ipv4/BUILD +++ b/pkg/tcpip/network/ipv4/BUILD @@ -8,6 +8,7 @@ go_library( "icmp.go", "igmp.go", "ipv4.go", + "stats.go", ], visibility = ["//visibility:public"], deps = [ @@ -49,3 +50,15 @@ go_test( "@com_github_google_go_cmp//cmp:go_default_library", ], ) + +go_test( + name = "stats_test", + size = "small", + srcs = ["stats_test.go"], + library = ":ipv4", + deps = [ + "//pkg/tcpip", + "//pkg/tcpip/network/testutil", + "//pkg/tcpip/stack", + ], +) diff --git a/pkg/tcpip/network/ipv4/icmp.go b/pkg/tcpip/network/ipv4/icmp.go index 8e392f86c..3f60de749 100644 --- a/pkg/tcpip/network/ipv4/icmp.go +++ b/pkg/tcpip/network/ipv4/icmp.go @@ -62,21 +62,20 @@ func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, pkt *stack } func (e *endpoint) handleICMP(pkt *stack.PacketBuffer) { - stats := e.protocol.stack.Stats() - received := stats.ICMP.V4.PacketsReceived + received := e.stats.icmp.packetsReceived // TODO(gvisor.dev/issue/170): ICMP packets don't have their // TransportHeader fields set. See icmp/protocol.go:protocol.Parse for a // full explanation. v, ok := pkt.Data.PullUp(header.ICMPv4MinimumSize) if !ok { - received.Invalid.Increment() + received.invalid.Increment() return } h := header.ICMPv4(v) // Only do in-stack processing if the checksum is correct. if header.ChecksumVV(pkt.Data, 0 /* initial */) != 0xffff { - received.Invalid.Increment() + received.invalid.Increment() // It's possible that a raw socket expects to receive this regardless // of checksum errors. If it's an echo request we know it's safe because // we are the only handler, however other types do not cope well with @@ -117,8 +116,8 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer) { errors.Is(err, errIPv4TimestampOptInvalidPointer), errors.Is(err, errIPv4TimestampOptOverflow): _ = e.protocol.returnError(&icmpReasonParamProblem{pointer: aux}, pkt) - stats.MalformedRcvdPackets.Increment() - stats.IP.MalformedPacketsReceived.Increment() + e.protocol.stack.Stats().MalformedRcvdPackets.Increment() + e.stats.ip.MalformedPacketsReceived.Increment() } return } @@ -128,11 +127,11 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer) { // TODO(b/112892170): Meaningfully handle all ICMP types. switch h.Type() { case header.ICMPv4Echo: - received.Echo.Increment() + received.echo.Increment() - sent := stats.ICMP.V4.PacketsSent + sent := e.stats.icmp.packetsSent if !e.protocol.stack.AllowICMPMessage() { - sent.RateLimited.Increment() + sent.rateLimited.Increment() return } @@ -213,18 +212,18 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer) { replyPkt.TransportProtocolNumber = header.ICMPv4ProtocolNumber if err := r.WriteHeaderIncludedPacket(replyPkt); err != nil { - sent.Dropped.Increment() + sent.dropped.Increment() return } - sent.EchoReply.Increment() + sent.echoReply.Increment() case header.ICMPv4EchoReply: - received.EchoReply.Increment() + received.echoReply.Increment() e.dispatcher.DeliverTransportPacket(header.ICMPv4ProtocolNumber, pkt) case header.ICMPv4DstUnreachable: - received.DstUnreachable.Increment() + received.dstUnreachable.Increment() pkt.Data.TrimFront(header.ICMPv4MinimumSize) switch h.Code() { @@ -243,31 +242,31 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer) { } case header.ICMPv4SrcQuench: - received.SrcQuench.Increment() + received.srcQuench.Increment() case header.ICMPv4Redirect: - received.Redirect.Increment() + received.redirect.Increment() case header.ICMPv4TimeExceeded: - received.TimeExceeded.Increment() + received.timeExceeded.Increment() case header.ICMPv4ParamProblem: - received.ParamProblem.Increment() + received.paramProblem.Increment() case header.ICMPv4Timestamp: - received.Timestamp.Increment() + received.timestamp.Increment() case header.ICMPv4TimestampReply: - received.TimestampReply.Increment() + received.timestampReply.Increment() case header.ICMPv4InfoRequest: - received.InfoRequest.Increment() + received.infoRequest.Increment() case header.ICMPv4InfoReply: - received.InfoReply.Increment() + received.infoReply.Increment() default: - received.Invalid.Increment() + received.invalid.Increment() } } @@ -379,9 +378,17 @@ func (p *protocol) returnError(reason icmpReason, pkt *stack.PacketBuffer) *tcpi } defer route.Release() - sent := p.stack.Stats().ICMP.V4.PacketsSent + p.mu.Lock() + netEP, ok := p.mu.eps[pkt.NICID] + p.mu.Unlock() + if !ok { + return tcpip.ErrNotConnected + } + + sent := netEP.stats.icmp.packetsSent + if !p.stack.AllowICMPMessage() { - sent.RateLimited.Increment() + sent.rateLimited.Increment() return nil } @@ -471,29 +478,29 @@ func (p *protocol) returnError(reason icmpReason, pkt *stack.PacketBuffer) *tcpi icmpPkt.TransportProtocolNumber = header.ICMPv4ProtocolNumber icmpHdr := header.ICMPv4(icmpPkt.TransportHeader().Push(header.ICMPv4MinimumSize)) - var counter *tcpip.StatCounter + var counter tcpip.MultiCounterStat switch reason := reason.(type) { case *icmpReasonPortUnreachable: icmpHdr.SetType(header.ICMPv4DstUnreachable) icmpHdr.SetCode(header.ICMPv4PortUnreachable) - counter = sent.DstUnreachable + counter = sent.dstUnreachable case *icmpReasonProtoUnreachable: icmpHdr.SetType(header.ICMPv4DstUnreachable) icmpHdr.SetCode(header.ICMPv4ProtoUnreachable) - counter = sent.DstUnreachable + counter = sent.dstUnreachable case *icmpReasonTTLExceeded: icmpHdr.SetType(header.ICMPv4TimeExceeded) icmpHdr.SetCode(header.ICMPv4TTLExceeded) - counter = sent.TimeExceeded + counter = sent.timeExceeded case *icmpReasonReassemblyTimeout: icmpHdr.SetType(header.ICMPv4TimeExceeded) icmpHdr.SetCode(header.ICMPv4ReassemblyTimeout) - counter = sent.TimeExceeded + counter = sent.timeExceeded case *icmpReasonParamProblem: icmpHdr.SetType(header.ICMPv4ParamProblem) icmpHdr.SetCode(header.ICMPv4UnusedCode) icmpHdr.SetPointer(reason.pointer) - counter = sent.ParamProblem + counter = sent.paramProblem default: panic(fmt.Sprintf("unsupported ICMP type %T", reason)) } @@ -508,7 +515,7 @@ func (p *protocol) returnError(reason icmpReason, pkt *stack.PacketBuffer) *tcpi }, icmpPkt, ); err != nil { - sent.Dropped.Increment() + sent.dropped.Increment() return err } counter.Increment() diff --git a/pkg/tcpip/network/ipv4/igmp.go b/pkg/tcpip/network/ipv4/igmp.go index d9b5fe6ed..9515fde45 100644 --- a/pkg/tcpip/network/ipv4/igmp.go +++ b/pkg/tcpip/network/ipv4/igmp.go @@ -149,11 +149,10 @@ func (igmp *igmpState) init(ep *endpoint) { // // Precondition: igmp.ep.mu must be locked. func (igmp *igmpState) handleIGMP(pkt *stack.PacketBuffer) { - stats := igmp.ep.protocol.stack.Stats() - received := stats.IGMP.PacketsReceived + received := igmp.ep.stats.igmp.packetsReceived headerView, ok := pkt.Data.PullUp(header.IGMPMinimumSize) if !ok { - received.Invalid.Increment() + received.invalid.Increment() return } h := header.IGMP(headerView) @@ -166,34 +165,34 @@ func (igmp *igmpState) handleIGMP(pkt *stack.PacketBuffer) { h.SetChecksum(wantChecksum) if gotChecksum != wantChecksum { - received.ChecksumErrors.Increment() + received.checksumErrors.Increment() return } switch h.Type() { case header.IGMPMembershipQuery: - received.MembershipQuery.Increment() + received.membershipQuery.Increment() if len(headerView) < header.IGMPQueryMinimumSize { - received.Invalid.Increment() + received.invalid.Increment() return } igmp.handleMembershipQuery(h.GroupAddress(), h.MaxRespTime()) case header.IGMPv1MembershipReport: - received.V1MembershipReport.Increment() + received.v1MembershipReport.Increment() if len(headerView) < header.IGMPReportMinimumSize { - received.Invalid.Increment() + received.invalid.Increment() return } igmp.handleMembershipReport(h.GroupAddress()) case header.IGMPv2MembershipReport: - received.V2MembershipReport.Increment() + received.v2MembershipReport.Increment() if len(headerView) < header.IGMPReportMinimumSize { - received.Invalid.Increment() + received.invalid.Increment() return } igmp.handleMembershipReport(h.GroupAddress()) case header.IGMPLeaveGroup: - received.LeaveGroup.Increment() + received.leaveGroup.Increment() // As per RFC 2236 Section 6, Page 7: "IGMP messages other than Query or // Report, are ignored in all states" @@ -201,7 +200,7 @@ func (igmp *igmpState) handleIGMP(pkt *stack.PacketBuffer) { // As per RFC 2236 Section 2.1 Page 3: "Unrecognized message types should // be silently ignored. New message types may be used by newer versions of // IGMP, by multicast routing protocols, or other uses." - received.Unrecognized.Increment() + received.unrecognized.Increment() } } @@ -272,18 +271,18 @@ func (igmp *igmpState) writePacket(destAddress tcpip.Address, groupAddress tcpip panic(fmt.Sprintf("failed to add IP header: %s", err)) } - sentStats := igmp.ep.protocol.stack.Stats().IGMP.PacketsSent + sentStats := igmp.ep.stats.igmp.packetsSent if err := igmp.ep.nic.WritePacketToRemote(header.EthernetAddressFromMulticastIPv4Address(destAddress), nil /* gso */, ProtocolNumber, pkt); err != nil { - sentStats.Dropped.Increment() + sentStats.dropped.Increment() return false, err } switch igmpType { case header.IGMPv1MembershipReport: - sentStats.V1MembershipReport.Increment() + sentStats.v1MembershipReport.Increment() case header.IGMPv2MembershipReport: - sentStats.V2MembershipReport.Increment() + sentStats.v2MembershipReport.Increment() case header.IGMPLeaveGroup: - sentStats.LeaveGroup.Increment() + sentStats.leaveGroup.Increment() default: panic(fmt.Sprintf("unrecognized igmp type = %d", igmpType)) } diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go index bb25a76fe..7f03696ae 100644 --- a/pkg/tcpip/network/ipv4/ipv4.go +++ b/pkg/tcpip/network/ipv4/ipv4.go @@ -19,6 +19,7 @@ import ( "errors" "fmt" "math" + "reflect" "sync/atomic" "time" @@ -73,6 +74,7 @@ type endpoint struct { nic stack.NetworkInterface dispatcher stack.TransportDispatcher protocol *protocol + stats sharedStats // enabled is set to 1 when the enpoint is enabled and 0 when it is // disabled. @@ -114,9 +116,27 @@ func (p *protocol) NewEndpoint(nic stack.NetworkInterface, _ stack.LinkAddressCa e.mu.addressableEndpointState.Init(e) e.mu.igmp.init(e) e.mu.Unlock() + + tcpip.InitStatCounters(reflect.ValueOf(&e.stats.localStats).Elem()) + + stackStats := p.stack.Stats() + e.stats.ip.Init(&e.stats.localStats.IP, &stackStats.IP) + e.stats.icmp.init(&e.stats.localStats.ICMP, &stackStats.ICMP.V4) + e.stats.igmp.init(&e.stats.localStats.IGMP, &stackStats.IGMP) + + p.mu.Lock() + p.mu.eps[nic.ID()] = e + p.mu.Unlock() + return e } +func (p *protocol) forgetEndpoint(nicID tcpip.NICID) { + p.mu.Lock() + defer p.mu.Unlock() + delete(p.mu.eps, nicID) +} + // Enable implements stack.NetworkEndpoint. func (e *endpoint) Enable() *tcpip.Error { e.mu.Lock() @@ -305,7 +325,7 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.Netw nicName := e.protocol.stack.FindNICNameFromID(e.nic.ID()) if ok := e.protocol.stack.IPTables().Check(stack.Output, pkt, gso, r, "", nicName); !ok { // iptables is telling us to drop the packet. - e.protocol.stack.Stats().IP.IPTablesOutputDropped.Increment() + e.stats.ip.IPTablesOutputDropped.Increment() return nil } @@ -349,9 +369,11 @@ func (e *endpoint) writePacket(r *stack.Route, gso *stack.GSO, pkt *stack.Packet return nil } + stats := e.stats.ip + networkMTU, err := calculateNetworkMTU(e.nic.MTU(), uint32(pkt.NetworkHeader().View().Size())) if err != nil { - r.Stats().IP.OutgoingPacketErrors.Increment() + stats.OutgoingPacketErrors.Increment() return err } @@ -363,16 +385,16 @@ func (e *endpoint) writePacket(r *stack.Route, gso *stack.GSO, pkt *stack.Packet // WritePackets(). It'll be faster but cost more memory. return e.nic.WritePacket(r, gso, ProtocolNumber, fragPkt) }) - r.Stats().IP.PacketsSent.IncrementBy(uint64(sent)) - r.Stats().IP.OutgoingPacketErrors.IncrementBy(uint64(remain)) + stats.PacketsSent.IncrementBy(uint64(sent)) + stats.OutgoingPacketErrors.IncrementBy(uint64(remain)) return err } if err := e.nic.WritePacket(r, gso, ProtocolNumber, pkt); err != nil { - r.Stats().IP.OutgoingPacketErrors.Increment() + stats.OutgoingPacketErrors.Increment() return err } - r.Stats().IP.PacketsSent.Increment() + stats.PacketsSent.Increment() return nil } @@ -385,6 +407,8 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe return pkts.Len(), nil } + stats := e.stats.ip + for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() { if err := e.addIPHeader(r.LocalAddress, r.RemoteAddress, pkt, params, nil /* options */); err != nil { return 0, err @@ -392,7 +416,7 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe networkMTU, err := calculateNetworkMTU(e.nic.MTU(), uint32(pkt.NetworkHeader().View().Size())) if err != nil { - r.Stats().IP.OutgoingPacketErrors.IncrementBy(uint64(pkts.Len())) + stats.OutgoingPacketErrors.IncrementBy(uint64(pkts.Len())) return 0, err } @@ -421,13 +445,13 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe // Fast path: If no packets are to be dropped then we can just invoke the // faster WritePackets API directly. n, err := e.nic.WritePackets(r, gso, pkts, ProtocolNumber) - r.Stats().IP.PacketsSent.IncrementBy(uint64(n)) + stats.PacketsSent.IncrementBy(uint64(n)) if err != nil { - r.Stats().IP.OutgoingPacketErrors.IncrementBy(uint64(pkts.Len() - n)) + stats.OutgoingPacketErrors.IncrementBy(uint64(pkts.Len() - n)) } return n, err } - r.Stats().IP.IPTablesOutputDropped.IncrementBy(uint64(len(dropped))) + stats.IPTablesOutputDropped.IncrementBy(uint64(len(dropped))) // Slow path as we are dropping some packets in the batch degrade to // emitting one packet at a time. @@ -451,15 +475,15 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe } } if err := e.nic.WritePacket(r, gso, ProtocolNumber, pkt); err != nil { - r.Stats().IP.PacketsSent.IncrementBy(uint64(n)) - r.Stats().IP.OutgoingPacketErrors.IncrementBy(uint64(pkts.Len() - n - len(dropped))) + stats.PacketsSent.IncrementBy(uint64(n)) + stats.OutgoingPacketErrors.IncrementBy(uint64(pkts.Len() - n - len(dropped))) // Dropped packets aren't errors, so include them in // the return value. return n + len(dropped), err } n++ } - r.Stats().IP.PacketsSent.IncrementBy(uint64(n)) + stats.PacketsSent.IncrementBy(uint64(n)) // Dropped packets aren't errors, so include them in the return value. return n + len(dropped), nil } @@ -577,11 +601,12 @@ func (e *endpoint) forwardPacket(pkt *stack.PacketBuffer) *tcpip.Error { // HandlePacket is called by the link layer when new ipv4 packets arrive for // this endpoint. func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) { - stats := e.protocol.stack.Stats() - stats.IP.PacketsReceived.Increment() + stats := e.stats.ip + + stats.PacketsReceived.Increment() if !e.isEnabled() { - stats.IP.DisabledPacketsReceived.Increment() + stats.DisabledPacketsReceived.Increment() return } @@ -589,7 +614,7 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) { if !e.nic.IsLoopback() { if ok := e.protocol.stack.IPTables().Check(stack.Prerouting, pkt, nil, nil, e.MainAddress().Address, ""); !ok { // iptables is telling us to drop the packet. - stats.IP.IPTablesPreroutingDropped.Increment() + stats.IPTablesPreroutingDropped.Increment() return } } @@ -601,11 +626,11 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) { // iptables hook. func (e *endpoint) handlePacket(pkt *stack.PacketBuffer) { pkt.NICID = e.nic.ID() - stats := e.protocol.stack.Stats() + stats := e.stats h := header.IPv4(pkt.NetworkHeader().View()) if !h.IsValid(pkt.Data.Size() + pkt.NetworkHeader().View().Size() + pkt.TransportHeader().View().Size()) { - stats.IP.MalformedPacketsReceived.Increment() + stats.ip.MalformedPacketsReceived.Increment() return } @@ -631,7 +656,7 @@ func (e *endpoint) handlePacket(pkt *stack.PacketBuffer) { // is all 1 bits (-0 in 1's complement arithmetic), the check // succeeds. if h.CalculateChecksum() != 0xffff { - stats.IP.MalformedPacketsReceived.Increment() + stats.ip.MalformedPacketsReceived.Increment() return } @@ -643,7 +668,7 @@ func (e *endpoint) handlePacket(pkt *stack.PacketBuffer) { // be one of its own IP addresses (but not a broadcast or // multicast address). if srcAddr == header.IPv4Broadcast || header.IsV4MulticastAddress(srcAddr) { - stats.IP.InvalidSourceAddressesReceived.Increment() + stats.ip.InvalidSourceAddressesReceived.Increment() return } // Make sure the source address is not a subnet-local broadcast address. @@ -651,7 +676,7 @@ func (e *endpoint) handlePacket(pkt *stack.PacketBuffer) { subnet := addressEndpoint.Subnet() addressEndpoint.DecRef() if subnet.IsBroadcast(srcAddr) { - stats.IP.InvalidSourceAddressesReceived.Increment() + stats.ip.InvalidSourceAddressesReceived.Increment() return } } @@ -664,7 +689,7 @@ func (e *endpoint) handlePacket(pkt *stack.PacketBuffer) { pkt.NetworkPacketInfo.LocalAddressBroadcast = subnet.IsBroadcast(dstAddr) || dstAddr == header.IPv4Broadcast } else if !e.IsInGroup(dstAddr) { if !e.protocol.Forwarding() { - stats.IP.InvalidDestinationAddressesReceived.Increment() + stats.ip.InvalidDestinationAddressesReceived.Increment() return } @@ -676,7 +701,7 @@ func (e *endpoint) handlePacket(pkt *stack.PacketBuffer) { // this machine and will not be forwarded. if ok := e.protocol.stack.IPTables().Check(stack.Input, pkt, nil, nil, "", ""); !ok { // iptables is telling us to drop the packet. - stats.IP.IPTablesInputDropped.Increment() + stats.ip.IPTablesInputDropped.Increment() return } @@ -684,8 +709,8 @@ func (e *endpoint) handlePacket(pkt *stack.PacketBuffer) { if pkt.Data.Size()+pkt.TransportHeader().View().Size() == 0 { // Drop the packet as it's marked as a fragment but has // no payload. - stats.IP.MalformedPacketsReceived.Increment() - stats.IP.MalformedFragmentsReceived.Increment() + stats.ip.MalformedPacketsReceived.Increment() + stats.ip.MalformedFragmentsReceived.Increment() return } // The packet is a fragment, let's try to reassemble it. @@ -698,8 +723,8 @@ func (e *endpoint) handlePacket(pkt *stack.PacketBuffer) { // size). Otherwise the packet would've been rejected as invalid before // reaching here. if int(start)+pkt.Data.Size() > header.IPv4MaximumPayloadSize { - stats.IP.MalformedPacketsReceived.Increment() - stats.IP.MalformedFragmentsReceived.Increment() + stats.ip.MalformedPacketsReceived.Increment() + stats.ip.MalformedFragmentsReceived.Increment() return } @@ -720,8 +745,8 @@ func (e *endpoint) handlePacket(pkt *stack.PacketBuffer) { pkt, ) if err != nil { - stats.IP.MalformedPacketsReceived.Increment() - stats.IP.MalformedFragmentsReceived.Increment() + stats.ip.MalformedPacketsReceived.Increment() + stats.ip.MalformedFragmentsReceived.Increment() return } if !ready { @@ -734,7 +759,7 @@ func (e *endpoint) handlePacket(pkt *stack.PacketBuffer) { h.SetTotalLength(uint16(pkt.Data.Size() + len((h)))) h.SetFlagsFragmentOffset(0, 0) } - stats.IP.PacketsDelivered.Increment() + stats.ip.PacketsDelivered.Increment() p := h.TransportProtocol() if p == header.ICMPv4ProtocolNumber { @@ -766,8 +791,8 @@ func (e *endpoint) handlePacket(pkt *stack.PacketBuffer) { errors.Is(err, errIPv4TimestampOptInvalidPointer), errors.Is(err, errIPv4TimestampOptOverflow): _ = e.protocol.returnError(&icmpReasonParamProblem{pointer: aux}, pkt) - stats.MalformedRcvdPackets.Increment() - stats.IP.MalformedPacketsReceived.Increment() + e.protocol.stack.Stats().MalformedRcvdPackets.Increment() + stats.ip.MalformedPacketsReceived.Increment() } return } @@ -800,6 +825,8 @@ func (e *endpoint) Close() { e.disableLocked() e.mu.addressableEndpointState.Cleanup() + + e.protocol.forgetEndpoint(e.nic.ID()) } // AddAndAcquirePermanentAddress implements stack.AddressableEndpoint. @@ -911,6 +938,11 @@ func (e *endpoint) IsInGroup(addr tcpip.Address) bool { return e.mu.igmp.isInGroup(addr) } +// Stats implements stack.NetworkEndpoint. +func (e *endpoint) Stats() stack.NetworkEndpointStats { + return &e.stats.localStats +} + var _ stack.ForwardingNetworkProtocol = (*protocol)(nil) var _ stack.NetworkProtocol = (*protocol)(nil) var _ fragmentation.TimeoutHandler = (*protocol)(nil) @@ -918,6 +950,14 @@ var _ fragmentation.TimeoutHandler = (*protocol)(nil) type protocol struct { stack *stack.Stack + mu struct { + sync.RWMutex + + // eps is keyed by NICID to allow protocol methods to retrieve an endpoint + // when handling a packet, by looking at which NIC handled the packet. + eps map[tcpip.NICID]*endpoint + } + // defaultTTL is the current default TTL for the protocol. Only the // uint8 portion of it is meaningful. // @@ -1095,6 +1135,7 @@ func NewProtocolWithOptions(opts Options) stack.NetworkProtocolFactory { options: opts, } p.fragmentation = fragmentation.NewFragmentation(fragmentblockSize, fragmentation.HighFragThreshold, fragmentation.LowFragThreshold, ReassembleTimeout, s.Clock(), p) + p.mu.eps = make(map[tcpip.NICID]*endpoint) return p } } @@ -1379,7 +1420,7 @@ func handleRecordRoute(rrOpt header.IPv4OptionRecordRoute, localAddress tcpip.Ad // - If there is an error, information as to what it was was. // - The replacement option set. func (e *endpoint) processIPOptions(pkt *stack.PacketBuffer, orig header.IPv4Options, usage optionsUsage) (uint8, header.IPv4Options, error) { - stats := e.protocol.stack.Stats() + stats := e.stats.ip opts := header.IPv4Options(orig) optIter := opts.MakeIterator() @@ -1427,7 +1468,7 @@ func (e *endpoint) processIPOptions(pkt *stack.PacketBuffer, orig header.IPv4Opt optLen := int(option.Size()) switch option := option.(type) { case *header.IPv4OptionTimestamp: - stats.IP.OptionTSReceived.Increment() + stats.OptionTSReceived.Increment() if usage.actions().timestamp != optionRemove { clock := e.protocol.stack.Clock() newBuffer := optIter.RemainingBuffer()[:len(*option)] @@ -1440,7 +1481,7 @@ func (e *endpoint) processIPOptions(pkt *stack.PacketBuffer, orig header.IPv4Opt } case *header.IPv4OptionRecordRoute: - stats.IP.OptionRRReceived.Increment() + stats.OptionRRReceived.Increment() if usage.actions().recordRoute != optionRemove { newBuffer := optIter.RemainingBuffer()[:len(*option)] _ = copy(newBuffer, option.Contents()) @@ -1452,7 +1493,7 @@ func (e *endpoint) processIPOptions(pkt *stack.PacketBuffer, orig header.IPv4Opt } default: - stats.IP.OptionUnknownReceived.Increment() + stats.OptionUnknownReceived.Increment() if usage.actions().unknown == optionPass { newBuffer := optIter.RemainingBuffer()[:optLen] // Arguments already heavily checked.. ignore result. diff --git a/pkg/tcpip/network/ipv4/stats.go b/pkg/tcpip/network/ipv4/stats.go new file mode 100644 index 000000000..7620728f9 --- /dev/null +++ b/pkg/tcpip/network/ipv4/stats.go @@ -0,0 +1,190 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package ipv4 + +import ( + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/network/ip" + "gvisor.dev/gvisor/pkg/tcpip/stack" +) + +var _ stack.IPNetworkEndpointStats = (*Stats)(nil) + +// Stats holds statistics related to the IPv4 protocol family. +type Stats struct { + // IP holds IPv4 statistics. + IP tcpip.IPStats + + // IGMP holds IGMP statistics. + IGMP tcpip.IGMPStats + + // ICMP holds ICMPv4 statistics. + ICMP tcpip.ICMPv4Stats +} + +// IsNetworkEndpointStats implements stack.NetworkEndpointStats. +func (s *Stats) IsNetworkEndpointStats() {} + +// IPStats implements stack.IPNetworkEndointStats +func (s *Stats) IPStats() *tcpip.IPStats { + return &s.IP +} + +type sharedStats struct { + localStats Stats + ip ip.MultiCounterIPStats + icmp multiCounterICMPv4Stats + igmp multiCounterIGMPStats +} + +// LINT.IfChange(multiCounterICMPv4PacketStats) + +type multiCounterICMPv4PacketStats struct { + echo tcpip.MultiCounterStat + echoReply tcpip.MultiCounterStat + dstUnreachable tcpip.MultiCounterStat + srcQuench tcpip.MultiCounterStat + redirect tcpip.MultiCounterStat + timeExceeded tcpip.MultiCounterStat + paramProblem tcpip.MultiCounterStat + timestamp tcpip.MultiCounterStat + timestampReply tcpip.MultiCounterStat + infoRequest tcpip.MultiCounterStat + infoReply tcpip.MultiCounterStat +} + +func (m *multiCounterICMPv4PacketStats) init(a, b *tcpip.ICMPv4PacketStats) { + m.echo.Init(a.Echo, b.Echo) + m.echoReply.Init(a.EchoReply, b.EchoReply) + m.dstUnreachable.Init(a.DstUnreachable, b.DstUnreachable) + m.srcQuench.Init(a.SrcQuench, b.SrcQuench) + m.redirect.Init(a.Redirect, b.Redirect) + m.timeExceeded.Init(a.TimeExceeded, b.TimeExceeded) + m.paramProblem.Init(a.ParamProblem, b.ParamProblem) + m.timestamp.Init(a.Timestamp, b.Timestamp) + m.timestampReply.Init(a.TimestampReply, b.TimestampReply) + m.infoRequest.Init(a.InfoRequest, b.InfoRequest) + m.infoReply.Init(a.InfoReply, b.InfoReply) +} + +// LINT.ThenChange(../../tcpip.go:ICMPv4PacketStats) + +// LINT.IfChange(multiCounterICMPv4SentPacketStats) + +type multiCounterICMPv4SentPacketStats struct { + multiCounterICMPv4PacketStats + dropped tcpip.MultiCounterStat + rateLimited tcpip.MultiCounterStat +} + +func (m *multiCounterICMPv4SentPacketStats) init(a, b *tcpip.ICMPv4SentPacketStats) { + m.multiCounterICMPv4PacketStats.init(&a.ICMPv4PacketStats, &b.ICMPv4PacketStats) + m.dropped.Init(a.Dropped, b.Dropped) + m.rateLimited.Init(a.RateLimited, b.RateLimited) +} + +// LINT.ThenChange(../../tcpip.go:ICMPv4SentPacketStats) + +// LINT.IfChange(multiCounterICMPv4ReceivedPacketStats) + +type multiCounterICMPv4ReceivedPacketStats struct { + multiCounterICMPv4PacketStats + invalid tcpip.MultiCounterStat +} + +func (m *multiCounterICMPv4ReceivedPacketStats) init(a, b *tcpip.ICMPv4ReceivedPacketStats) { + m.multiCounterICMPv4PacketStats.init(&a.ICMPv4PacketStats, &b.ICMPv4PacketStats) + m.invalid.Init(a.Invalid, b.Invalid) +} + +// LINT.ThenChange(../../tcpip.go:ICMPv4ReceivedPacketStats) + +// LINT.IfChange(multiCounterICMPv4Stats) + +type multiCounterICMPv4Stats struct { + packetsSent multiCounterICMPv4SentPacketStats + packetsReceived multiCounterICMPv4ReceivedPacketStats +} + +func (m *multiCounterICMPv4Stats) init(a, b *tcpip.ICMPv4Stats) { + m.packetsSent.init(&a.PacketsSent, &b.PacketsSent) + m.packetsReceived.init(&a.PacketsReceived, &b.PacketsReceived) +} + +// LINT.ThenChange(../../tcpip.go:ICMPv4Stats) + +// LINT.IfChange(multiCounterIGMPPacketStats) + +type multiCounterIGMPPacketStats struct { + membershipQuery tcpip.MultiCounterStat + v1MembershipReport tcpip.MultiCounterStat + v2MembershipReport tcpip.MultiCounterStat + leaveGroup tcpip.MultiCounterStat +} + +func (m *multiCounterIGMPPacketStats) init(a, b *tcpip.IGMPPacketStats) { + m.membershipQuery.Init(a.MembershipQuery, b.MembershipQuery) + m.v1MembershipReport.Init(a.V1MembershipReport, b.V1MembershipReport) + m.v2MembershipReport.Init(a.V2MembershipReport, b.V2MembershipReport) + m.leaveGroup.Init(a.LeaveGroup, b.LeaveGroup) +} + +// LINT.ThenChange(../../tcpip.go:IGMPPacketStats) + +// LINT.IfChange(multiCounterIGMPSentPacketStats) + +type multiCounterIGMPSentPacketStats struct { + multiCounterIGMPPacketStats + dropped tcpip.MultiCounterStat +} + +func (m *multiCounterIGMPSentPacketStats) init(a, b *tcpip.IGMPSentPacketStats) { + m.multiCounterIGMPPacketStats.init(&a.IGMPPacketStats, &b.IGMPPacketStats) + m.dropped.Init(a.Dropped, b.Dropped) +} + +// LINT.ThenChange(../../tcpip.go:IGMPSentPacketStats) + +// LINT.IfChange(multiCounterIGMPReceivedPacketStats) + +type multiCounterIGMPReceivedPacketStats struct { + multiCounterIGMPPacketStats + invalid tcpip.MultiCounterStat + checksumErrors tcpip.MultiCounterStat + unrecognized tcpip.MultiCounterStat +} + +func (m *multiCounterIGMPReceivedPacketStats) init(a, b *tcpip.IGMPReceivedPacketStats) { + m.multiCounterIGMPPacketStats.init(&a.IGMPPacketStats, &b.IGMPPacketStats) + m.invalid.Init(a.Invalid, b.Invalid) + m.checksumErrors.Init(a.ChecksumErrors, b.ChecksumErrors) + m.unrecognized.Init(a.Unrecognized, b.Unrecognized) +} + +// LINT.ThenChange(../../tcpip.go:IGMPReceivedPacketStats) + +// LINT.IfChange(multiCounterIGMPStats) + +type multiCounterIGMPStats struct { + packetsSent multiCounterIGMPSentPacketStats + packetsReceived multiCounterIGMPReceivedPacketStats +} + +func (m *multiCounterIGMPStats) init(a, b *tcpip.IGMPStats) { + m.packetsSent.init(&a.PacketsSent, &b.PacketsSent) + m.packetsReceived.init(&a.PacketsReceived, &b.PacketsReceived) +} + +// LINT.ThenChange(../../tcpip.go:IGMPStats) diff --git a/pkg/tcpip/network/ipv4/stats_test.go b/pkg/tcpip/network/ipv4/stats_test.go new file mode 100644 index 000000000..84641bcf4 --- /dev/null +++ b/pkg/tcpip/network/ipv4/stats_test.go @@ -0,0 +1,101 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package ipv4 + +import ( + "reflect" + "testing" + + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/network/testutil" + "gvisor.dev/gvisor/pkg/tcpip/stack" +) + +var _ stack.NetworkInterface = (*testInterface)(nil) + +type testInterface struct { + stack.NetworkInterface + nicID tcpip.NICID +} + +func (t *testInterface) ID() tcpip.NICID { + return t.nicID +} + +func getKnownNICIDs(proto *protocol) []tcpip.NICID { + var nicIDs []tcpip.NICID + + for k := range proto.mu.eps { + nicIDs = append(nicIDs, k) + } + + return nicIDs +} + +func TestClearEndpointFromProtocolOnClose(t *testing.T) { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, + }) + proto := s.NetworkProtocolInstance(ProtocolNumber).(*protocol) + nic := testInterface{nicID: 1} + ep := proto.NewEndpoint(&nic, nil, nil, nil).(*endpoint) + { + proto.mu.Lock() + foundEP, hasEP := proto.mu.eps[nic.ID()] + nicIDs := getKnownNICIDs(proto) + proto.mu.Unlock() + + if !hasEP { + t.Fatalf("expected to find the nic id %d in the protocol's endpoint map (%v)", nic.ID(), nicIDs) + } + if foundEP != ep { + t.Fatalf("expected protocol to map endpoint %p to nic id %d, but endpoint %p was found instead", ep, nic.ID(), foundEP) + } + } + + ep.Close() + + { + proto.mu.Lock() + _, hasEP := proto.mu.eps[nic.ID()] + nicIDs := getKnownNICIDs(proto) + proto.mu.Unlock() + if hasEP { + t.Fatalf("unexpectedly found an endpoint mapped to the nic id %d in the protocol's known nic ids (%v)", nic.ID(), nicIDs) + } + } +} + +func TestMultiCounterStatsInitialization(t *testing.T) { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, + }) + proto := s.NetworkProtocolInstance(ProtocolNumber).(*protocol) + var nic testInterface + ep := proto.NewEndpoint(&nic, nil, nil, nil).(*endpoint) + // At this point, the Stack's stats and the NetworkEndpoint's stats are + // expected to be bound by a MultiCounterStat. + refStack := s.Stats() + refEP := ep.stats.localStats + if err := testutil.ValidateMultiCounterStats(reflect.ValueOf(&ep.stats.ip).Elem(), []reflect.Value{reflect.ValueOf(&refEP.IP).Elem(), reflect.ValueOf(&refStack.IP).Elem()}); err != nil { + t.Error(err) + } + if err := testutil.ValidateMultiCounterStats(reflect.ValueOf(&ep.stats.icmp).Elem(), []reflect.Value{reflect.ValueOf(&refEP.ICMP).Elem(), reflect.ValueOf(&refStack.ICMP.V4).Elem()}); err != nil { + t.Error(err) + } + if err := testutil.ValidateMultiCounterStats(reflect.ValueOf(&ep.stats.igmp).Elem(), []reflect.Value{reflect.ValueOf(&refEP.IGMP).Elem(), reflect.ValueOf(&refStack.IGMP).Elem()}); err != nil { + t.Error(err) + } +} |