summaryrefslogtreecommitdiffhomepage
path: root/pkg/tcpip/network/ipv4
diff options
context:
space:
mode:
Diffstat (limited to 'pkg/tcpip/network/ipv4')
-rw-r--r--pkg/tcpip/network/ipv4/BUILD13
-rw-r--r--pkg/tcpip/network/ipv4/icmp.go69
-rw-r--r--pkg/tcpip/network/ipv4/igmp.go33
-rw-r--r--pkg/tcpip/network/ipv4/ipv4.go115
-rw-r--r--pkg/tcpip/network/ipv4/stats.go190
-rw-r--r--pkg/tcpip/network/ipv4/stats_test.go101
6 files changed, 436 insertions, 85 deletions
diff --git a/pkg/tcpip/network/ipv4/BUILD b/pkg/tcpip/network/ipv4/BUILD
index 32f53f217..330a7d170 100644
--- a/pkg/tcpip/network/ipv4/BUILD
+++ b/pkg/tcpip/network/ipv4/BUILD
@@ -8,6 +8,7 @@ go_library(
"icmp.go",
"igmp.go",
"ipv4.go",
+ "stats.go",
],
visibility = ["//visibility:public"],
deps = [
@@ -49,3 +50,15 @@ go_test(
"@com_github_google_go_cmp//cmp:go_default_library",
],
)
+
+go_test(
+ name = "stats_test",
+ size = "small",
+ srcs = ["stats_test.go"],
+ library = ":ipv4",
+ deps = [
+ "//pkg/tcpip",
+ "//pkg/tcpip/network/testutil",
+ "//pkg/tcpip/stack",
+ ],
+)
diff --git a/pkg/tcpip/network/ipv4/icmp.go b/pkg/tcpip/network/ipv4/icmp.go
index 8e392f86c..3f60de749 100644
--- a/pkg/tcpip/network/ipv4/icmp.go
+++ b/pkg/tcpip/network/ipv4/icmp.go
@@ -62,21 +62,20 @@ func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, pkt *stack
}
func (e *endpoint) handleICMP(pkt *stack.PacketBuffer) {
- stats := e.protocol.stack.Stats()
- received := stats.ICMP.V4.PacketsReceived
+ received := e.stats.icmp.packetsReceived
// TODO(gvisor.dev/issue/170): ICMP packets don't have their
// TransportHeader fields set. See icmp/protocol.go:protocol.Parse for a
// full explanation.
v, ok := pkt.Data.PullUp(header.ICMPv4MinimumSize)
if !ok {
- received.Invalid.Increment()
+ received.invalid.Increment()
return
}
h := header.ICMPv4(v)
// Only do in-stack processing if the checksum is correct.
if header.ChecksumVV(pkt.Data, 0 /* initial */) != 0xffff {
- received.Invalid.Increment()
+ received.invalid.Increment()
// It's possible that a raw socket expects to receive this regardless
// of checksum errors. If it's an echo request we know it's safe because
// we are the only handler, however other types do not cope well with
@@ -117,8 +116,8 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer) {
errors.Is(err, errIPv4TimestampOptInvalidPointer),
errors.Is(err, errIPv4TimestampOptOverflow):
_ = e.protocol.returnError(&icmpReasonParamProblem{pointer: aux}, pkt)
- stats.MalformedRcvdPackets.Increment()
- stats.IP.MalformedPacketsReceived.Increment()
+ e.protocol.stack.Stats().MalformedRcvdPackets.Increment()
+ e.stats.ip.MalformedPacketsReceived.Increment()
}
return
}
@@ -128,11 +127,11 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer) {
// TODO(b/112892170): Meaningfully handle all ICMP types.
switch h.Type() {
case header.ICMPv4Echo:
- received.Echo.Increment()
+ received.echo.Increment()
- sent := stats.ICMP.V4.PacketsSent
+ sent := e.stats.icmp.packetsSent
if !e.protocol.stack.AllowICMPMessage() {
- sent.RateLimited.Increment()
+ sent.rateLimited.Increment()
return
}
@@ -213,18 +212,18 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer) {
replyPkt.TransportProtocolNumber = header.ICMPv4ProtocolNumber
if err := r.WriteHeaderIncludedPacket(replyPkt); err != nil {
- sent.Dropped.Increment()
+ sent.dropped.Increment()
return
}
- sent.EchoReply.Increment()
+ sent.echoReply.Increment()
case header.ICMPv4EchoReply:
- received.EchoReply.Increment()
+ received.echoReply.Increment()
e.dispatcher.DeliverTransportPacket(header.ICMPv4ProtocolNumber, pkt)
case header.ICMPv4DstUnreachable:
- received.DstUnreachable.Increment()
+ received.dstUnreachable.Increment()
pkt.Data.TrimFront(header.ICMPv4MinimumSize)
switch h.Code() {
@@ -243,31 +242,31 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer) {
}
case header.ICMPv4SrcQuench:
- received.SrcQuench.Increment()
+ received.srcQuench.Increment()
case header.ICMPv4Redirect:
- received.Redirect.Increment()
+ received.redirect.Increment()
case header.ICMPv4TimeExceeded:
- received.TimeExceeded.Increment()
+ received.timeExceeded.Increment()
case header.ICMPv4ParamProblem:
- received.ParamProblem.Increment()
+ received.paramProblem.Increment()
case header.ICMPv4Timestamp:
- received.Timestamp.Increment()
+ received.timestamp.Increment()
case header.ICMPv4TimestampReply:
- received.TimestampReply.Increment()
+ received.timestampReply.Increment()
case header.ICMPv4InfoRequest:
- received.InfoRequest.Increment()
+ received.infoRequest.Increment()
case header.ICMPv4InfoReply:
- received.InfoReply.Increment()
+ received.infoReply.Increment()
default:
- received.Invalid.Increment()
+ received.invalid.Increment()
}
}
@@ -379,9 +378,17 @@ func (p *protocol) returnError(reason icmpReason, pkt *stack.PacketBuffer) *tcpi
}
defer route.Release()
- sent := p.stack.Stats().ICMP.V4.PacketsSent
+ p.mu.Lock()
+ netEP, ok := p.mu.eps[pkt.NICID]
+ p.mu.Unlock()
+ if !ok {
+ return tcpip.ErrNotConnected
+ }
+
+ sent := netEP.stats.icmp.packetsSent
+
if !p.stack.AllowICMPMessage() {
- sent.RateLimited.Increment()
+ sent.rateLimited.Increment()
return nil
}
@@ -471,29 +478,29 @@ func (p *protocol) returnError(reason icmpReason, pkt *stack.PacketBuffer) *tcpi
icmpPkt.TransportProtocolNumber = header.ICMPv4ProtocolNumber
icmpHdr := header.ICMPv4(icmpPkt.TransportHeader().Push(header.ICMPv4MinimumSize))
- var counter *tcpip.StatCounter
+ var counter tcpip.MultiCounterStat
switch reason := reason.(type) {
case *icmpReasonPortUnreachable:
icmpHdr.SetType(header.ICMPv4DstUnreachable)
icmpHdr.SetCode(header.ICMPv4PortUnreachable)
- counter = sent.DstUnreachable
+ counter = sent.dstUnreachable
case *icmpReasonProtoUnreachable:
icmpHdr.SetType(header.ICMPv4DstUnreachable)
icmpHdr.SetCode(header.ICMPv4ProtoUnreachable)
- counter = sent.DstUnreachable
+ counter = sent.dstUnreachable
case *icmpReasonTTLExceeded:
icmpHdr.SetType(header.ICMPv4TimeExceeded)
icmpHdr.SetCode(header.ICMPv4TTLExceeded)
- counter = sent.TimeExceeded
+ counter = sent.timeExceeded
case *icmpReasonReassemblyTimeout:
icmpHdr.SetType(header.ICMPv4TimeExceeded)
icmpHdr.SetCode(header.ICMPv4ReassemblyTimeout)
- counter = sent.TimeExceeded
+ counter = sent.timeExceeded
case *icmpReasonParamProblem:
icmpHdr.SetType(header.ICMPv4ParamProblem)
icmpHdr.SetCode(header.ICMPv4UnusedCode)
icmpHdr.SetPointer(reason.pointer)
- counter = sent.ParamProblem
+ counter = sent.paramProblem
default:
panic(fmt.Sprintf("unsupported ICMP type %T", reason))
}
@@ -508,7 +515,7 @@ func (p *protocol) returnError(reason icmpReason, pkt *stack.PacketBuffer) *tcpi
},
icmpPkt,
); err != nil {
- sent.Dropped.Increment()
+ sent.dropped.Increment()
return err
}
counter.Increment()
diff --git a/pkg/tcpip/network/ipv4/igmp.go b/pkg/tcpip/network/ipv4/igmp.go
index d9b5fe6ed..9515fde45 100644
--- a/pkg/tcpip/network/ipv4/igmp.go
+++ b/pkg/tcpip/network/ipv4/igmp.go
@@ -149,11 +149,10 @@ func (igmp *igmpState) init(ep *endpoint) {
//
// Precondition: igmp.ep.mu must be locked.
func (igmp *igmpState) handleIGMP(pkt *stack.PacketBuffer) {
- stats := igmp.ep.protocol.stack.Stats()
- received := stats.IGMP.PacketsReceived
+ received := igmp.ep.stats.igmp.packetsReceived
headerView, ok := pkt.Data.PullUp(header.IGMPMinimumSize)
if !ok {
- received.Invalid.Increment()
+ received.invalid.Increment()
return
}
h := header.IGMP(headerView)
@@ -166,34 +165,34 @@ func (igmp *igmpState) handleIGMP(pkt *stack.PacketBuffer) {
h.SetChecksum(wantChecksum)
if gotChecksum != wantChecksum {
- received.ChecksumErrors.Increment()
+ received.checksumErrors.Increment()
return
}
switch h.Type() {
case header.IGMPMembershipQuery:
- received.MembershipQuery.Increment()
+ received.membershipQuery.Increment()
if len(headerView) < header.IGMPQueryMinimumSize {
- received.Invalid.Increment()
+ received.invalid.Increment()
return
}
igmp.handleMembershipQuery(h.GroupAddress(), h.MaxRespTime())
case header.IGMPv1MembershipReport:
- received.V1MembershipReport.Increment()
+ received.v1MembershipReport.Increment()
if len(headerView) < header.IGMPReportMinimumSize {
- received.Invalid.Increment()
+ received.invalid.Increment()
return
}
igmp.handleMembershipReport(h.GroupAddress())
case header.IGMPv2MembershipReport:
- received.V2MembershipReport.Increment()
+ received.v2MembershipReport.Increment()
if len(headerView) < header.IGMPReportMinimumSize {
- received.Invalid.Increment()
+ received.invalid.Increment()
return
}
igmp.handleMembershipReport(h.GroupAddress())
case header.IGMPLeaveGroup:
- received.LeaveGroup.Increment()
+ received.leaveGroup.Increment()
// As per RFC 2236 Section 6, Page 7: "IGMP messages other than Query or
// Report, are ignored in all states"
@@ -201,7 +200,7 @@ func (igmp *igmpState) handleIGMP(pkt *stack.PacketBuffer) {
// As per RFC 2236 Section 2.1 Page 3: "Unrecognized message types should
// be silently ignored. New message types may be used by newer versions of
// IGMP, by multicast routing protocols, or other uses."
- received.Unrecognized.Increment()
+ received.unrecognized.Increment()
}
}
@@ -272,18 +271,18 @@ func (igmp *igmpState) writePacket(destAddress tcpip.Address, groupAddress tcpip
panic(fmt.Sprintf("failed to add IP header: %s", err))
}
- sentStats := igmp.ep.protocol.stack.Stats().IGMP.PacketsSent
+ sentStats := igmp.ep.stats.igmp.packetsSent
if err := igmp.ep.nic.WritePacketToRemote(header.EthernetAddressFromMulticastIPv4Address(destAddress), nil /* gso */, ProtocolNumber, pkt); err != nil {
- sentStats.Dropped.Increment()
+ sentStats.dropped.Increment()
return false, err
}
switch igmpType {
case header.IGMPv1MembershipReport:
- sentStats.V1MembershipReport.Increment()
+ sentStats.v1MembershipReport.Increment()
case header.IGMPv2MembershipReport:
- sentStats.V2MembershipReport.Increment()
+ sentStats.v2MembershipReport.Increment()
case header.IGMPLeaveGroup:
- sentStats.LeaveGroup.Increment()
+ sentStats.leaveGroup.Increment()
default:
panic(fmt.Sprintf("unrecognized igmp type = %d", igmpType))
}
diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go
index bb25a76fe..7f03696ae 100644
--- a/pkg/tcpip/network/ipv4/ipv4.go
+++ b/pkg/tcpip/network/ipv4/ipv4.go
@@ -19,6 +19,7 @@ import (
"errors"
"fmt"
"math"
+ "reflect"
"sync/atomic"
"time"
@@ -73,6 +74,7 @@ type endpoint struct {
nic stack.NetworkInterface
dispatcher stack.TransportDispatcher
protocol *protocol
+ stats sharedStats
// enabled is set to 1 when the enpoint is enabled and 0 when it is
// disabled.
@@ -114,9 +116,27 @@ func (p *protocol) NewEndpoint(nic stack.NetworkInterface, _ stack.LinkAddressCa
e.mu.addressableEndpointState.Init(e)
e.mu.igmp.init(e)
e.mu.Unlock()
+
+ tcpip.InitStatCounters(reflect.ValueOf(&e.stats.localStats).Elem())
+
+ stackStats := p.stack.Stats()
+ e.stats.ip.Init(&e.stats.localStats.IP, &stackStats.IP)
+ e.stats.icmp.init(&e.stats.localStats.ICMP, &stackStats.ICMP.V4)
+ e.stats.igmp.init(&e.stats.localStats.IGMP, &stackStats.IGMP)
+
+ p.mu.Lock()
+ p.mu.eps[nic.ID()] = e
+ p.mu.Unlock()
+
return e
}
+func (p *protocol) forgetEndpoint(nicID tcpip.NICID) {
+ p.mu.Lock()
+ defer p.mu.Unlock()
+ delete(p.mu.eps, nicID)
+}
+
// Enable implements stack.NetworkEndpoint.
func (e *endpoint) Enable() *tcpip.Error {
e.mu.Lock()
@@ -305,7 +325,7 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.Netw
nicName := e.protocol.stack.FindNICNameFromID(e.nic.ID())
if ok := e.protocol.stack.IPTables().Check(stack.Output, pkt, gso, r, "", nicName); !ok {
// iptables is telling us to drop the packet.
- e.protocol.stack.Stats().IP.IPTablesOutputDropped.Increment()
+ e.stats.ip.IPTablesOutputDropped.Increment()
return nil
}
@@ -349,9 +369,11 @@ func (e *endpoint) writePacket(r *stack.Route, gso *stack.GSO, pkt *stack.Packet
return nil
}
+ stats := e.stats.ip
+
networkMTU, err := calculateNetworkMTU(e.nic.MTU(), uint32(pkt.NetworkHeader().View().Size()))
if err != nil {
- r.Stats().IP.OutgoingPacketErrors.Increment()
+ stats.OutgoingPacketErrors.Increment()
return err
}
@@ -363,16 +385,16 @@ func (e *endpoint) writePacket(r *stack.Route, gso *stack.GSO, pkt *stack.Packet
// WritePackets(). It'll be faster but cost more memory.
return e.nic.WritePacket(r, gso, ProtocolNumber, fragPkt)
})
- r.Stats().IP.PacketsSent.IncrementBy(uint64(sent))
- r.Stats().IP.OutgoingPacketErrors.IncrementBy(uint64(remain))
+ stats.PacketsSent.IncrementBy(uint64(sent))
+ stats.OutgoingPacketErrors.IncrementBy(uint64(remain))
return err
}
if err := e.nic.WritePacket(r, gso, ProtocolNumber, pkt); err != nil {
- r.Stats().IP.OutgoingPacketErrors.Increment()
+ stats.OutgoingPacketErrors.Increment()
return err
}
- r.Stats().IP.PacketsSent.Increment()
+ stats.PacketsSent.Increment()
return nil
}
@@ -385,6 +407,8 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe
return pkts.Len(), nil
}
+ stats := e.stats.ip
+
for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() {
if err := e.addIPHeader(r.LocalAddress, r.RemoteAddress, pkt, params, nil /* options */); err != nil {
return 0, err
@@ -392,7 +416,7 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe
networkMTU, err := calculateNetworkMTU(e.nic.MTU(), uint32(pkt.NetworkHeader().View().Size()))
if err != nil {
- r.Stats().IP.OutgoingPacketErrors.IncrementBy(uint64(pkts.Len()))
+ stats.OutgoingPacketErrors.IncrementBy(uint64(pkts.Len()))
return 0, err
}
@@ -421,13 +445,13 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe
// Fast path: If no packets are to be dropped then we can just invoke the
// faster WritePackets API directly.
n, err := e.nic.WritePackets(r, gso, pkts, ProtocolNumber)
- r.Stats().IP.PacketsSent.IncrementBy(uint64(n))
+ stats.PacketsSent.IncrementBy(uint64(n))
if err != nil {
- r.Stats().IP.OutgoingPacketErrors.IncrementBy(uint64(pkts.Len() - n))
+ stats.OutgoingPacketErrors.IncrementBy(uint64(pkts.Len() - n))
}
return n, err
}
- r.Stats().IP.IPTablesOutputDropped.IncrementBy(uint64(len(dropped)))
+ stats.IPTablesOutputDropped.IncrementBy(uint64(len(dropped)))
// Slow path as we are dropping some packets in the batch degrade to
// emitting one packet at a time.
@@ -451,15 +475,15 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe
}
}
if err := e.nic.WritePacket(r, gso, ProtocolNumber, pkt); err != nil {
- r.Stats().IP.PacketsSent.IncrementBy(uint64(n))
- r.Stats().IP.OutgoingPacketErrors.IncrementBy(uint64(pkts.Len() - n - len(dropped)))
+ stats.PacketsSent.IncrementBy(uint64(n))
+ stats.OutgoingPacketErrors.IncrementBy(uint64(pkts.Len() - n - len(dropped)))
// Dropped packets aren't errors, so include them in
// the return value.
return n + len(dropped), err
}
n++
}
- r.Stats().IP.PacketsSent.IncrementBy(uint64(n))
+ stats.PacketsSent.IncrementBy(uint64(n))
// Dropped packets aren't errors, so include them in the return value.
return n + len(dropped), nil
}
@@ -577,11 +601,12 @@ func (e *endpoint) forwardPacket(pkt *stack.PacketBuffer) *tcpip.Error {
// HandlePacket is called by the link layer when new ipv4 packets arrive for
// this endpoint.
func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) {
- stats := e.protocol.stack.Stats()
- stats.IP.PacketsReceived.Increment()
+ stats := e.stats.ip
+
+ stats.PacketsReceived.Increment()
if !e.isEnabled() {
- stats.IP.DisabledPacketsReceived.Increment()
+ stats.DisabledPacketsReceived.Increment()
return
}
@@ -589,7 +614,7 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) {
if !e.nic.IsLoopback() {
if ok := e.protocol.stack.IPTables().Check(stack.Prerouting, pkt, nil, nil, e.MainAddress().Address, ""); !ok {
// iptables is telling us to drop the packet.
- stats.IP.IPTablesPreroutingDropped.Increment()
+ stats.IPTablesPreroutingDropped.Increment()
return
}
}
@@ -601,11 +626,11 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) {
// iptables hook.
func (e *endpoint) handlePacket(pkt *stack.PacketBuffer) {
pkt.NICID = e.nic.ID()
- stats := e.protocol.stack.Stats()
+ stats := e.stats
h := header.IPv4(pkt.NetworkHeader().View())
if !h.IsValid(pkt.Data.Size() + pkt.NetworkHeader().View().Size() + pkt.TransportHeader().View().Size()) {
- stats.IP.MalformedPacketsReceived.Increment()
+ stats.ip.MalformedPacketsReceived.Increment()
return
}
@@ -631,7 +656,7 @@ func (e *endpoint) handlePacket(pkt *stack.PacketBuffer) {
// is all 1 bits (-0 in 1's complement arithmetic), the check
// succeeds.
if h.CalculateChecksum() != 0xffff {
- stats.IP.MalformedPacketsReceived.Increment()
+ stats.ip.MalformedPacketsReceived.Increment()
return
}
@@ -643,7 +668,7 @@ func (e *endpoint) handlePacket(pkt *stack.PacketBuffer) {
// be one of its own IP addresses (but not a broadcast or
// multicast address).
if srcAddr == header.IPv4Broadcast || header.IsV4MulticastAddress(srcAddr) {
- stats.IP.InvalidSourceAddressesReceived.Increment()
+ stats.ip.InvalidSourceAddressesReceived.Increment()
return
}
// Make sure the source address is not a subnet-local broadcast address.
@@ -651,7 +676,7 @@ func (e *endpoint) handlePacket(pkt *stack.PacketBuffer) {
subnet := addressEndpoint.Subnet()
addressEndpoint.DecRef()
if subnet.IsBroadcast(srcAddr) {
- stats.IP.InvalidSourceAddressesReceived.Increment()
+ stats.ip.InvalidSourceAddressesReceived.Increment()
return
}
}
@@ -664,7 +689,7 @@ func (e *endpoint) handlePacket(pkt *stack.PacketBuffer) {
pkt.NetworkPacketInfo.LocalAddressBroadcast = subnet.IsBroadcast(dstAddr) || dstAddr == header.IPv4Broadcast
} else if !e.IsInGroup(dstAddr) {
if !e.protocol.Forwarding() {
- stats.IP.InvalidDestinationAddressesReceived.Increment()
+ stats.ip.InvalidDestinationAddressesReceived.Increment()
return
}
@@ -676,7 +701,7 @@ func (e *endpoint) handlePacket(pkt *stack.PacketBuffer) {
// this machine and will not be forwarded.
if ok := e.protocol.stack.IPTables().Check(stack.Input, pkt, nil, nil, "", ""); !ok {
// iptables is telling us to drop the packet.
- stats.IP.IPTablesInputDropped.Increment()
+ stats.ip.IPTablesInputDropped.Increment()
return
}
@@ -684,8 +709,8 @@ func (e *endpoint) handlePacket(pkt *stack.PacketBuffer) {
if pkt.Data.Size()+pkt.TransportHeader().View().Size() == 0 {
// Drop the packet as it's marked as a fragment but has
// no payload.
- stats.IP.MalformedPacketsReceived.Increment()
- stats.IP.MalformedFragmentsReceived.Increment()
+ stats.ip.MalformedPacketsReceived.Increment()
+ stats.ip.MalformedFragmentsReceived.Increment()
return
}
// The packet is a fragment, let's try to reassemble it.
@@ -698,8 +723,8 @@ func (e *endpoint) handlePacket(pkt *stack.PacketBuffer) {
// size). Otherwise the packet would've been rejected as invalid before
// reaching here.
if int(start)+pkt.Data.Size() > header.IPv4MaximumPayloadSize {
- stats.IP.MalformedPacketsReceived.Increment()
- stats.IP.MalformedFragmentsReceived.Increment()
+ stats.ip.MalformedPacketsReceived.Increment()
+ stats.ip.MalformedFragmentsReceived.Increment()
return
}
@@ -720,8 +745,8 @@ func (e *endpoint) handlePacket(pkt *stack.PacketBuffer) {
pkt,
)
if err != nil {
- stats.IP.MalformedPacketsReceived.Increment()
- stats.IP.MalformedFragmentsReceived.Increment()
+ stats.ip.MalformedPacketsReceived.Increment()
+ stats.ip.MalformedFragmentsReceived.Increment()
return
}
if !ready {
@@ -734,7 +759,7 @@ func (e *endpoint) handlePacket(pkt *stack.PacketBuffer) {
h.SetTotalLength(uint16(pkt.Data.Size() + len((h))))
h.SetFlagsFragmentOffset(0, 0)
}
- stats.IP.PacketsDelivered.Increment()
+ stats.ip.PacketsDelivered.Increment()
p := h.TransportProtocol()
if p == header.ICMPv4ProtocolNumber {
@@ -766,8 +791,8 @@ func (e *endpoint) handlePacket(pkt *stack.PacketBuffer) {
errors.Is(err, errIPv4TimestampOptInvalidPointer),
errors.Is(err, errIPv4TimestampOptOverflow):
_ = e.protocol.returnError(&icmpReasonParamProblem{pointer: aux}, pkt)
- stats.MalformedRcvdPackets.Increment()
- stats.IP.MalformedPacketsReceived.Increment()
+ e.protocol.stack.Stats().MalformedRcvdPackets.Increment()
+ stats.ip.MalformedPacketsReceived.Increment()
}
return
}
@@ -800,6 +825,8 @@ func (e *endpoint) Close() {
e.disableLocked()
e.mu.addressableEndpointState.Cleanup()
+
+ e.protocol.forgetEndpoint(e.nic.ID())
}
// AddAndAcquirePermanentAddress implements stack.AddressableEndpoint.
@@ -911,6 +938,11 @@ func (e *endpoint) IsInGroup(addr tcpip.Address) bool {
return e.mu.igmp.isInGroup(addr)
}
+// Stats implements stack.NetworkEndpoint.
+func (e *endpoint) Stats() stack.NetworkEndpointStats {
+ return &e.stats.localStats
+}
+
var _ stack.ForwardingNetworkProtocol = (*protocol)(nil)
var _ stack.NetworkProtocol = (*protocol)(nil)
var _ fragmentation.TimeoutHandler = (*protocol)(nil)
@@ -918,6 +950,14 @@ var _ fragmentation.TimeoutHandler = (*protocol)(nil)
type protocol struct {
stack *stack.Stack
+ mu struct {
+ sync.RWMutex
+
+ // eps is keyed by NICID to allow protocol methods to retrieve an endpoint
+ // when handling a packet, by looking at which NIC handled the packet.
+ eps map[tcpip.NICID]*endpoint
+ }
+
// defaultTTL is the current default TTL for the protocol. Only the
// uint8 portion of it is meaningful.
//
@@ -1095,6 +1135,7 @@ func NewProtocolWithOptions(opts Options) stack.NetworkProtocolFactory {
options: opts,
}
p.fragmentation = fragmentation.NewFragmentation(fragmentblockSize, fragmentation.HighFragThreshold, fragmentation.LowFragThreshold, ReassembleTimeout, s.Clock(), p)
+ p.mu.eps = make(map[tcpip.NICID]*endpoint)
return p
}
}
@@ -1379,7 +1420,7 @@ func handleRecordRoute(rrOpt header.IPv4OptionRecordRoute, localAddress tcpip.Ad
// - If there is an error, information as to what it was was.
// - The replacement option set.
func (e *endpoint) processIPOptions(pkt *stack.PacketBuffer, orig header.IPv4Options, usage optionsUsage) (uint8, header.IPv4Options, error) {
- stats := e.protocol.stack.Stats()
+ stats := e.stats.ip
opts := header.IPv4Options(orig)
optIter := opts.MakeIterator()
@@ -1427,7 +1468,7 @@ func (e *endpoint) processIPOptions(pkt *stack.PacketBuffer, orig header.IPv4Opt
optLen := int(option.Size())
switch option := option.(type) {
case *header.IPv4OptionTimestamp:
- stats.IP.OptionTSReceived.Increment()
+ stats.OptionTSReceived.Increment()
if usage.actions().timestamp != optionRemove {
clock := e.protocol.stack.Clock()
newBuffer := optIter.RemainingBuffer()[:len(*option)]
@@ -1440,7 +1481,7 @@ func (e *endpoint) processIPOptions(pkt *stack.PacketBuffer, orig header.IPv4Opt
}
case *header.IPv4OptionRecordRoute:
- stats.IP.OptionRRReceived.Increment()
+ stats.OptionRRReceived.Increment()
if usage.actions().recordRoute != optionRemove {
newBuffer := optIter.RemainingBuffer()[:len(*option)]
_ = copy(newBuffer, option.Contents())
@@ -1452,7 +1493,7 @@ func (e *endpoint) processIPOptions(pkt *stack.PacketBuffer, orig header.IPv4Opt
}
default:
- stats.IP.OptionUnknownReceived.Increment()
+ stats.OptionUnknownReceived.Increment()
if usage.actions().unknown == optionPass {
newBuffer := optIter.RemainingBuffer()[:optLen]
// Arguments already heavily checked.. ignore result.
diff --git a/pkg/tcpip/network/ipv4/stats.go b/pkg/tcpip/network/ipv4/stats.go
new file mode 100644
index 000000000..7620728f9
--- /dev/null
+++ b/pkg/tcpip/network/ipv4/stats.go
@@ -0,0 +1,190 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package ipv4
+
+import (
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/network/ip"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
+)
+
+var _ stack.IPNetworkEndpointStats = (*Stats)(nil)
+
+// Stats holds statistics related to the IPv4 protocol family.
+type Stats struct {
+ // IP holds IPv4 statistics.
+ IP tcpip.IPStats
+
+ // IGMP holds IGMP statistics.
+ IGMP tcpip.IGMPStats
+
+ // ICMP holds ICMPv4 statistics.
+ ICMP tcpip.ICMPv4Stats
+}
+
+// IsNetworkEndpointStats implements stack.NetworkEndpointStats.
+func (s *Stats) IsNetworkEndpointStats() {}
+
+// IPStats implements stack.IPNetworkEndointStats
+func (s *Stats) IPStats() *tcpip.IPStats {
+ return &s.IP
+}
+
+type sharedStats struct {
+ localStats Stats
+ ip ip.MultiCounterIPStats
+ icmp multiCounterICMPv4Stats
+ igmp multiCounterIGMPStats
+}
+
+// LINT.IfChange(multiCounterICMPv4PacketStats)
+
+type multiCounterICMPv4PacketStats struct {
+ echo tcpip.MultiCounterStat
+ echoReply tcpip.MultiCounterStat
+ dstUnreachable tcpip.MultiCounterStat
+ srcQuench tcpip.MultiCounterStat
+ redirect tcpip.MultiCounterStat
+ timeExceeded tcpip.MultiCounterStat
+ paramProblem tcpip.MultiCounterStat
+ timestamp tcpip.MultiCounterStat
+ timestampReply tcpip.MultiCounterStat
+ infoRequest tcpip.MultiCounterStat
+ infoReply tcpip.MultiCounterStat
+}
+
+func (m *multiCounterICMPv4PacketStats) init(a, b *tcpip.ICMPv4PacketStats) {
+ m.echo.Init(a.Echo, b.Echo)
+ m.echoReply.Init(a.EchoReply, b.EchoReply)
+ m.dstUnreachable.Init(a.DstUnreachable, b.DstUnreachable)
+ m.srcQuench.Init(a.SrcQuench, b.SrcQuench)
+ m.redirect.Init(a.Redirect, b.Redirect)
+ m.timeExceeded.Init(a.TimeExceeded, b.TimeExceeded)
+ m.paramProblem.Init(a.ParamProblem, b.ParamProblem)
+ m.timestamp.Init(a.Timestamp, b.Timestamp)
+ m.timestampReply.Init(a.TimestampReply, b.TimestampReply)
+ m.infoRequest.Init(a.InfoRequest, b.InfoRequest)
+ m.infoReply.Init(a.InfoReply, b.InfoReply)
+}
+
+// LINT.ThenChange(../../tcpip.go:ICMPv4PacketStats)
+
+// LINT.IfChange(multiCounterICMPv4SentPacketStats)
+
+type multiCounterICMPv4SentPacketStats struct {
+ multiCounterICMPv4PacketStats
+ dropped tcpip.MultiCounterStat
+ rateLimited tcpip.MultiCounterStat
+}
+
+func (m *multiCounterICMPv4SentPacketStats) init(a, b *tcpip.ICMPv4SentPacketStats) {
+ m.multiCounterICMPv4PacketStats.init(&a.ICMPv4PacketStats, &b.ICMPv4PacketStats)
+ m.dropped.Init(a.Dropped, b.Dropped)
+ m.rateLimited.Init(a.RateLimited, b.RateLimited)
+}
+
+// LINT.ThenChange(../../tcpip.go:ICMPv4SentPacketStats)
+
+// LINT.IfChange(multiCounterICMPv4ReceivedPacketStats)
+
+type multiCounterICMPv4ReceivedPacketStats struct {
+ multiCounterICMPv4PacketStats
+ invalid tcpip.MultiCounterStat
+}
+
+func (m *multiCounterICMPv4ReceivedPacketStats) init(a, b *tcpip.ICMPv4ReceivedPacketStats) {
+ m.multiCounterICMPv4PacketStats.init(&a.ICMPv4PacketStats, &b.ICMPv4PacketStats)
+ m.invalid.Init(a.Invalid, b.Invalid)
+}
+
+// LINT.ThenChange(../../tcpip.go:ICMPv4ReceivedPacketStats)
+
+// LINT.IfChange(multiCounterICMPv4Stats)
+
+type multiCounterICMPv4Stats struct {
+ packetsSent multiCounterICMPv4SentPacketStats
+ packetsReceived multiCounterICMPv4ReceivedPacketStats
+}
+
+func (m *multiCounterICMPv4Stats) init(a, b *tcpip.ICMPv4Stats) {
+ m.packetsSent.init(&a.PacketsSent, &b.PacketsSent)
+ m.packetsReceived.init(&a.PacketsReceived, &b.PacketsReceived)
+}
+
+// LINT.ThenChange(../../tcpip.go:ICMPv4Stats)
+
+// LINT.IfChange(multiCounterIGMPPacketStats)
+
+type multiCounterIGMPPacketStats struct {
+ membershipQuery tcpip.MultiCounterStat
+ v1MembershipReport tcpip.MultiCounterStat
+ v2MembershipReport tcpip.MultiCounterStat
+ leaveGroup tcpip.MultiCounterStat
+}
+
+func (m *multiCounterIGMPPacketStats) init(a, b *tcpip.IGMPPacketStats) {
+ m.membershipQuery.Init(a.MembershipQuery, b.MembershipQuery)
+ m.v1MembershipReport.Init(a.V1MembershipReport, b.V1MembershipReport)
+ m.v2MembershipReport.Init(a.V2MembershipReport, b.V2MembershipReport)
+ m.leaveGroup.Init(a.LeaveGroup, b.LeaveGroup)
+}
+
+// LINT.ThenChange(../../tcpip.go:IGMPPacketStats)
+
+// LINT.IfChange(multiCounterIGMPSentPacketStats)
+
+type multiCounterIGMPSentPacketStats struct {
+ multiCounterIGMPPacketStats
+ dropped tcpip.MultiCounterStat
+}
+
+func (m *multiCounterIGMPSentPacketStats) init(a, b *tcpip.IGMPSentPacketStats) {
+ m.multiCounterIGMPPacketStats.init(&a.IGMPPacketStats, &b.IGMPPacketStats)
+ m.dropped.Init(a.Dropped, b.Dropped)
+}
+
+// LINT.ThenChange(../../tcpip.go:IGMPSentPacketStats)
+
+// LINT.IfChange(multiCounterIGMPReceivedPacketStats)
+
+type multiCounterIGMPReceivedPacketStats struct {
+ multiCounterIGMPPacketStats
+ invalid tcpip.MultiCounterStat
+ checksumErrors tcpip.MultiCounterStat
+ unrecognized tcpip.MultiCounterStat
+}
+
+func (m *multiCounterIGMPReceivedPacketStats) init(a, b *tcpip.IGMPReceivedPacketStats) {
+ m.multiCounterIGMPPacketStats.init(&a.IGMPPacketStats, &b.IGMPPacketStats)
+ m.invalid.Init(a.Invalid, b.Invalid)
+ m.checksumErrors.Init(a.ChecksumErrors, b.ChecksumErrors)
+ m.unrecognized.Init(a.Unrecognized, b.Unrecognized)
+}
+
+// LINT.ThenChange(../../tcpip.go:IGMPReceivedPacketStats)
+
+// LINT.IfChange(multiCounterIGMPStats)
+
+type multiCounterIGMPStats struct {
+ packetsSent multiCounterIGMPSentPacketStats
+ packetsReceived multiCounterIGMPReceivedPacketStats
+}
+
+func (m *multiCounterIGMPStats) init(a, b *tcpip.IGMPStats) {
+ m.packetsSent.init(&a.PacketsSent, &b.PacketsSent)
+ m.packetsReceived.init(&a.PacketsReceived, &b.PacketsReceived)
+}
+
+// LINT.ThenChange(../../tcpip.go:IGMPStats)
diff --git a/pkg/tcpip/network/ipv4/stats_test.go b/pkg/tcpip/network/ipv4/stats_test.go
new file mode 100644
index 000000000..84641bcf4
--- /dev/null
+++ b/pkg/tcpip/network/ipv4/stats_test.go
@@ -0,0 +1,101 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package ipv4
+
+import (
+ "reflect"
+ "testing"
+
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/network/testutil"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
+)
+
+var _ stack.NetworkInterface = (*testInterface)(nil)
+
+type testInterface struct {
+ stack.NetworkInterface
+ nicID tcpip.NICID
+}
+
+func (t *testInterface) ID() tcpip.NICID {
+ return t.nicID
+}
+
+func getKnownNICIDs(proto *protocol) []tcpip.NICID {
+ var nicIDs []tcpip.NICID
+
+ for k := range proto.mu.eps {
+ nicIDs = append(nicIDs, k)
+ }
+
+ return nicIDs
+}
+
+func TestClearEndpointFromProtocolOnClose(t *testing.T) {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol},
+ })
+ proto := s.NetworkProtocolInstance(ProtocolNumber).(*protocol)
+ nic := testInterface{nicID: 1}
+ ep := proto.NewEndpoint(&nic, nil, nil, nil).(*endpoint)
+ {
+ proto.mu.Lock()
+ foundEP, hasEP := proto.mu.eps[nic.ID()]
+ nicIDs := getKnownNICIDs(proto)
+ proto.mu.Unlock()
+
+ if !hasEP {
+ t.Fatalf("expected to find the nic id %d in the protocol's endpoint map (%v)", nic.ID(), nicIDs)
+ }
+ if foundEP != ep {
+ t.Fatalf("expected protocol to map endpoint %p to nic id %d, but endpoint %p was found instead", ep, nic.ID(), foundEP)
+ }
+ }
+
+ ep.Close()
+
+ {
+ proto.mu.Lock()
+ _, hasEP := proto.mu.eps[nic.ID()]
+ nicIDs := getKnownNICIDs(proto)
+ proto.mu.Unlock()
+ if hasEP {
+ t.Fatalf("unexpectedly found an endpoint mapped to the nic id %d in the protocol's known nic ids (%v)", nic.ID(), nicIDs)
+ }
+ }
+}
+
+func TestMultiCounterStatsInitialization(t *testing.T) {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol},
+ })
+ proto := s.NetworkProtocolInstance(ProtocolNumber).(*protocol)
+ var nic testInterface
+ ep := proto.NewEndpoint(&nic, nil, nil, nil).(*endpoint)
+ // At this point, the Stack's stats and the NetworkEndpoint's stats are
+ // expected to be bound by a MultiCounterStat.
+ refStack := s.Stats()
+ refEP := ep.stats.localStats
+ if err := testutil.ValidateMultiCounterStats(reflect.ValueOf(&ep.stats.ip).Elem(), []reflect.Value{reflect.ValueOf(&refEP.IP).Elem(), reflect.ValueOf(&refStack.IP).Elem()}); err != nil {
+ t.Error(err)
+ }
+ if err := testutil.ValidateMultiCounterStats(reflect.ValueOf(&ep.stats.icmp).Elem(), []reflect.Value{reflect.ValueOf(&refEP.ICMP).Elem(), reflect.ValueOf(&refStack.ICMP.V4).Elem()}); err != nil {
+ t.Error(err)
+ }
+ if err := testutil.ValidateMultiCounterStats(reflect.ValueOf(&ep.stats.igmp).Elem(), []reflect.Value{reflect.ValueOf(&refEP.IGMP).Elem(), reflect.ValueOf(&refStack.IGMP).Elem()}); err != nil {
+ t.Error(err)
+ }
+}