diff options
-rw-r--r-- | pkg/tcpip/link/sniffer/sniffer.go | 20 | ||||
-rw-r--r-- | pkg/tcpip/network/BUILD | 3 | ||||
-rw-r--r-- | pkg/tcpip/network/ip_test.go | 149 | ||||
-rw-r--r-- | pkg/tcpip/network/ipv4/ipv4.go | 2 | ||||
-rw-r--r-- | pkg/tcpip/network/ipv6/ipv6.go | 2 | ||||
-rw-r--r-- | pkg/tcpip/stack/nic.go | 7 | ||||
-rw-r--r-- | pkg/tcpip/stack/route.go | 11 | ||||
-rw-r--r-- | pkg/tcpip/stack/transport_demuxer.go | 5 | ||||
-rw-r--r-- | pkg/tcpip/tcpip.go | 198 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/connect.go | 10 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/endpoint.go | 22 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/tcp_test.go | 428 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/tcp_timestamp_test.go | 3 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/testing/context/context.go | 19 | ||||
-rw-r--r-- | pkg/tcpip/transport/udp/endpoint.go | 11 | ||||
-rw-r--r-- | pkg/tcpip/transport/udp/udp_test.go | 38 |
16 files changed, 703 insertions, 225 deletions
diff --git a/pkg/tcpip/link/sniffer/sniffer.go b/pkg/tcpip/link/sniffer/sniffer.go index 3836cebb7..414a7edb4 100644 --- a/pkg/tcpip/link/sniffer/sniffer.go +++ b/pkg/tcpip/link/sniffer/sniffer.go @@ -118,7 +118,7 @@ func NewWithFile(lower tcpip.LinkEndpointID, file *os.File, snapLen uint32) (tcp // logs the packet before forwarding to the actual dispatcher. func (e *endpoint) DeliverNetworkPacket(linkEP stack.LinkEndpoint, remoteLinkAddr tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, vv *buffer.VectorisedView) { if atomic.LoadUint32(&LogPackets) == 1 && e.file == nil { - LogPacket("recv", protocol, vv.First(), nil) + logPacket("recv", protocol, vv.First(), nil) } if e.file != nil && atomic.LoadUint32(&LogPacketsToFile) == 1 { vs := vv.Views() @@ -190,7 +190,7 @@ func (e *endpoint) LinkAddress() tcpip.LinkAddress { // the request to the lower endpoint. func (e *endpoint) WritePacket(r *stack.Route, hdr *buffer.Prependable, payload buffer.View, protocol tcpip.NetworkProtocolNumber) *tcpip.Error { if atomic.LoadUint32(&LogPackets) == 1 && e.file == nil { - LogPacket("send", protocol, hdr.UsedBytes(), payload) + logPacket("send", protocol, hdr.UsedBytes(), payload) } if e.file != nil && atomic.LoadUint32(&LogPacketsToFile) == 1 { hdrBuf := hdr.UsedBytes() @@ -226,8 +226,7 @@ func (e *endpoint) WritePacket(r *stack.Route, hdr *buffer.Prependable, payload return e.lower.WritePacket(r, hdr, payload, protocol) } -// LogPacket logs the given packet. -func LogPacket(prefix string, protocol tcpip.NetworkProtocolNumber, b, plb []byte) { +func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, b, plb []byte) { // Figure out the network layer info. var transProto uint8 src := tcpip.Address("unknown") @@ -316,9 +315,19 @@ func LogPacket(prefix string, protocol tcpip.NetworkProtocolNumber, b, plb []byt case header.TCPProtocolNumber: transName = "tcp" tcp := header.TCP(b) + offset := int(tcp.DataOffset()) + if offset < header.TCPMinimumSize { + details += fmt.Sprintf("invalid packet: tcp data offset too small %d", offset) + break + } + if offset > len(tcp) { + details += fmt.Sprintf("invalid packet: tcp data offset %d larger than packet buffer length %d", offset, len(tcp)) + break + } + srcPort = tcp.SourcePort() dstPort = tcp.DestinationPort() - size -= uint16(tcp.DataOffset()) + size -= uint16(offset) // Initialize the TCP flags. flags := tcp.Flags() @@ -334,6 +343,7 @@ func LogPacket(prefix string, protocol tcpip.NetworkProtocolNumber, b, plb []byt } else { details += fmt.Sprintf(" options: %+v", tcp.ParsedOptions()) } + default: log.Infof("%s %v -> %v unknown transport protocol: %d", prefix, src, dst, transProto) return diff --git a/pkg/tcpip/network/BUILD b/pkg/tcpip/network/BUILD index 9a26b46c4..25a3c98b6 100644 --- a/pkg/tcpip/network/BUILD +++ b/pkg/tcpip/network/BUILD @@ -12,8 +12,11 @@ go_test( "//pkg/tcpip", "//pkg/tcpip/buffer", "//pkg/tcpip/header", + "//pkg/tcpip/link/loopback", "//pkg/tcpip/network/ipv4", "//pkg/tcpip/network/ipv6", "//pkg/tcpip/stack", + "//pkg/tcpip/transport/tcp", + "//pkg/tcpip/transport/udp", ], ) diff --git a/pkg/tcpip/network/ip_test.go b/pkg/tcpip/network/ip_test.go index f1edebe27..4475a75cf 100644 --- a/pkg/tcpip/network/ip_test.go +++ b/pkg/tcpip/network/ip_test.go @@ -20,9 +20,25 @@ import ( "gvisor.googlesource.com/gvisor/pkg/tcpip" "gvisor.googlesource.com/gvisor/pkg/tcpip/buffer" "gvisor.googlesource.com/gvisor/pkg/tcpip/header" + "gvisor.googlesource.com/gvisor/pkg/tcpip/link/loopback" "gvisor.googlesource.com/gvisor/pkg/tcpip/network/ipv4" "gvisor.googlesource.com/gvisor/pkg/tcpip/network/ipv6" "gvisor.googlesource.com/gvisor/pkg/tcpip/stack" + "gvisor.googlesource.com/gvisor/pkg/tcpip/transport/tcp" + "gvisor.googlesource.com/gvisor/pkg/tcpip/transport/udp" +) + +const ( + localIpv4Addr = "\x0a\x00\x00\x01" + remoteIpv4Addr = "\x0a\x00\x00\x02" + ipv4SubnetAddr = "\x0a\x00\x00\x00" + ipv4SubnetMask = "\xff\xff\xff\x00" + ipv4Gateway = "\x0a\x00\x00\x03" + localIpv6Addr = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01" + remoteIpv6Addr = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02" + ipv6SubnetAddr = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00" + ipv6SubnetMask = "\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\x00" + ipv6Gateway = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x03" ) // testObject implements two interfaces: LinkEndpoint and TransportDispatcher. @@ -152,10 +168,38 @@ func (t *testObject) WritePacket(_ *stack.Route, hdr *buffer.Prependable, payloa return nil } +func buildIPv4Route(local, remote tcpip.Address) (stack.Route, *tcpip.Error) { + s := stack.New([]string{ipv4.ProtocolName}, []string{udp.ProtocolName, tcp.ProtocolName}, stack.Options{}) + s.CreateNIC(1, loopback.New()) + s.AddAddress(1, ipv4.ProtocolNumber, local) + s.SetRouteTable([]tcpip.Route{{ + Destination: ipv4SubnetAddr, + Mask: ipv4SubnetMask, + Gateway: ipv4Gateway, + NIC: 1, + }}) + + return s.FindRoute(1, local, remote, ipv4.ProtocolNumber) +} + +func buildIPv6Route(local, remote tcpip.Address) (stack.Route, *tcpip.Error) { + s := stack.New([]string{ipv6.ProtocolName}, []string{udp.ProtocolName, tcp.ProtocolName}, stack.Options{}) + s.CreateNIC(1, loopback.New()) + s.AddAddress(1, ipv6.ProtocolNumber, local) + s.SetRouteTable([]tcpip.Route{{ + Destination: ipv6SubnetAddr, + Mask: ipv6SubnetMask, + Gateway: ipv6Gateway, + NIC: 1, + }}) + + return s.FindRoute(1, local, remote, ipv6.ProtocolNumber) +} + func TestIPv4Send(t *testing.T) { o := testObject{t: t, v4: true} proto := ipv4.NewProtocol() - ep, err := proto.NewEndpoint(1, "\x0a\x00\x00\x01", nil, nil, &o) + ep, err := proto.NewEndpoint(1, localIpv4Addr, nil, nil, &o) if err != nil { t.Fatalf("NewEndpoint failed: %v", err) } @@ -171,13 +215,13 @@ func TestIPv4Send(t *testing.T) { // Issue the write. o.protocol = 123 - o.srcAddr = "\x0a\x00\x00\x01" - o.dstAddr = "\x0a\x00\x00\x02" + o.srcAddr = localIpv4Addr + o.dstAddr = remoteIpv4Addr o.contents = payload - r := stack.Route{ - RemoteAddress: o.dstAddr, - LocalAddress: o.srcAddr, + r, err := buildIPv4Route(localIpv4Addr, remoteIpv4Addr) + if err != nil { + t.Fatalf("could not find route: %v", err) } if err := ep.WritePacket(&r, &hdr, payload, 123); err != nil { t.Fatalf("WritePacket failed: %v", err) @@ -187,7 +231,7 @@ func TestIPv4Send(t *testing.T) { func TestIPv4Receive(t *testing.T) { o := testObject{t: t, v4: true} proto := ipv4.NewProtocol() - ep, err := proto.NewEndpoint(1, "\x0a\x00\x00\x01", nil, &o, nil) + ep, err := proto.NewEndpoint(1, localIpv4Addr, nil, &o, nil) if err != nil { t.Fatalf("NewEndpoint failed: %v", err) } @@ -200,8 +244,8 @@ func TestIPv4Receive(t *testing.T) { TotalLength: uint16(totalLen), TTL: 20, Protocol: 10, - SrcAddr: "\x0a\x00\x00\x02", - DstAddr: "\x0a\x00\x00\x01", + SrcAddr: remoteIpv4Addr, + DstAddr: localIpv4Addr, }) // Make payload be non-zero. @@ -211,13 +255,13 @@ func TestIPv4Receive(t *testing.T) { // Give packet to ipv4 endpoint, dispatcher will validate that it's ok. o.protocol = 10 - o.srcAddr = "\x0a\x00\x00\x02" - o.dstAddr = "\x0a\x00\x00\x01" + o.srcAddr = remoteIpv4Addr + o.dstAddr = localIpv4Addr o.contents = view[header.IPv4MinimumSize:totalLen] - r := stack.Route{ - LocalAddress: o.dstAddr, - RemoteAddress: o.srcAddr, + r, err := buildIPv4Route(localIpv4Addr, remoteIpv4Addr) + if err != nil { + t.Fatalf("could not find route: %v", err) } var views [1]buffer.View vv := view.ToVectorisedView(views) @@ -248,7 +292,7 @@ func TestIPv4ReceiveControl(t *testing.T) { {"Zero-length packet", 0, 0, header.ICMPv4PortUnreachable, stack.ControlPortUnreachable, 0, 2*header.IPv4MinimumSize + header.ICMPv4DstUnreachableMinimumSize + 8}, } r := stack.Route{ - LocalAddress: "\x0a\x00\x00\x01", + LocalAddress: localIpv4Addr, RemoteAddress: "\x0a\x00\x00\xbb", } for _, c := range cases { @@ -256,7 +300,7 @@ func TestIPv4ReceiveControl(t *testing.T) { var views [1]buffer.View o := testObject{t: t} proto := ipv4.NewProtocol() - ep, err := proto.NewEndpoint(1, "\x0a\x00\x00\x01", nil, &o, nil) + ep, err := proto.NewEndpoint(1, localIpv4Addr, nil, &o, nil) if err != nil { t.Fatalf("NewEndpoint failed: %v", err) } @@ -273,7 +317,7 @@ func TestIPv4ReceiveControl(t *testing.T) { TTL: 20, Protocol: uint8(header.ICMPv4ProtocolNumber), SrcAddr: "\x0a\x00\x00\xbb", - DstAddr: "\x0a\x00\x00\x01", + DstAddr: localIpv4Addr, }) // Create the ICMP header. @@ -290,8 +334,8 @@ func TestIPv4ReceiveControl(t *testing.T) { TTL: 20, Protocol: 10, FragmentOffset: c.fragmentOffset, - SrcAddr: "\x0a\x00\x00\x01", - DstAddr: "\x0a\x00\x00\x02", + SrcAddr: localIpv4Addr, + DstAddr: remoteIpv4Addr, }) // Make payload be non-zero. @@ -302,8 +346,8 @@ func TestIPv4ReceiveControl(t *testing.T) { // Give packet to IPv4 endpoint, dispatcher will validate that // it's ok. o.protocol = 10 - o.srcAddr = "\x0a\x00\x00\x02" - o.dstAddr = "\x0a\x00\x00\x01" + o.srcAddr = remoteIpv4Addr + o.dstAddr = localIpv4Addr o.contents = view[dataOffset:] o.typ = c.expectedTyp o.extra = c.expectedExtra @@ -321,7 +365,7 @@ func TestIPv4ReceiveControl(t *testing.T) { func TestIPv4FragmentationReceive(t *testing.T) { o := testObject{t: t, v4: true} proto := ipv4.NewProtocol() - ep, err := proto.NewEndpoint(1, "\x0a\x00\x00\x01", nil, &o, nil) + ep, err := proto.NewEndpoint(1, localIpv4Addr, nil, &o, nil) if err != nil { t.Fatalf("NewEndpoint failed: %v", err) } @@ -337,8 +381,8 @@ func TestIPv4FragmentationReceive(t *testing.T) { Protocol: 10, FragmentOffset: 0, Flags: header.IPv4FlagMoreFragments, - SrcAddr: "\x0a\x00\x00\x02", - DstAddr: "\x0a\x00\x00\x01", + SrcAddr: remoteIpv4Addr, + DstAddr: localIpv4Addr, }) // Make payload be non-zero. for i := header.IPv4MinimumSize; i < totalLen; i++ { @@ -353,8 +397,8 @@ func TestIPv4FragmentationReceive(t *testing.T) { TTL: 20, Protocol: 10, FragmentOffset: 24, - SrcAddr: "\x0a\x00\x00\x02", - DstAddr: "\x0a\x00\x00\x01", + SrcAddr: remoteIpv4Addr, + DstAddr: localIpv4Addr, }) // Make payload be non-zero. for i := header.IPv4MinimumSize; i < totalLen; i++ { @@ -363,13 +407,13 @@ func TestIPv4FragmentationReceive(t *testing.T) { // Give packet to ipv4 endpoint, dispatcher will validate that it's ok. o.protocol = 10 - o.srcAddr = "\x0a\x00\x00\x02" - o.dstAddr = "\x0a\x00\x00\x01" + o.srcAddr = remoteIpv4Addr + o.dstAddr = localIpv4Addr o.contents = append(frag1[header.IPv4MinimumSize:totalLen], frag2[header.IPv4MinimumSize:totalLen]...) - r := stack.Route{ - LocalAddress: o.dstAddr, - RemoteAddress: o.srcAddr, + r, err := buildIPv4Route(localIpv4Addr, remoteIpv4Addr) + if err != nil { + t.Fatalf("could not find route: %v", err) } // Send first segment. @@ -393,7 +437,7 @@ func TestIPv4FragmentationReceive(t *testing.T) { func TestIPv6Send(t *testing.T) { o := testObject{t: t} proto := ipv6.NewProtocol() - ep, err := proto.NewEndpoint(1, "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01", nil, nil, &o) + ep, err := proto.NewEndpoint(1, localIpv6Addr, nil, nil, &o) if err != nil { t.Fatalf("NewEndpoint failed: %v", err) } @@ -409,13 +453,13 @@ func TestIPv6Send(t *testing.T) { // Issue the write. o.protocol = 123 - o.srcAddr = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01" - o.dstAddr = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02" + o.srcAddr = localIpv6Addr + o.dstAddr = remoteIpv6Addr o.contents = payload - r := stack.Route{ - RemoteAddress: o.dstAddr, - LocalAddress: o.srcAddr, + r, err := buildIPv6Route(localIpv6Addr, remoteIpv6Addr) + if err != nil { + t.Fatalf("could not find route: %v", err) } if err := ep.WritePacket(&r, &hdr, payload, 123); err != nil { t.Fatalf("WritePacket failed: %v", err) @@ -425,7 +469,7 @@ func TestIPv6Send(t *testing.T) { func TestIPv6Receive(t *testing.T) { o := testObject{t: t} proto := ipv6.NewProtocol() - ep, err := proto.NewEndpoint(1, "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01", nil, &o, nil) + ep, err := proto.NewEndpoint(1, localIpv6Addr, nil, &o, nil) if err != nil { t.Fatalf("NewEndpoint failed: %v", err) } @@ -437,8 +481,8 @@ func TestIPv6Receive(t *testing.T) { PayloadLength: uint16(totalLen - header.IPv6MinimumSize), NextHeader: 10, HopLimit: 20, - SrcAddr: "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02", - DstAddr: "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01", + SrcAddr: remoteIpv6Addr, + DstAddr: localIpv6Addr, }) // Make payload be non-zero. @@ -448,14 +492,15 @@ func TestIPv6Receive(t *testing.T) { // Give packet to ipv6 endpoint, dispatcher will validate that it's ok. o.protocol = 10 - o.srcAddr = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02" - o.dstAddr = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01" + o.srcAddr = remoteIpv6Addr + o.dstAddr = localIpv6Addr o.contents = view[header.IPv6MinimumSize:totalLen] - r := stack.Route{ - LocalAddress: o.dstAddr, - RemoteAddress: o.srcAddr, + r, err := buildIPv6Route(localIpv6Addr, remoteIpv6Addr) + if err != nil { + t.Fatalf("could not find route: %v", err) } + var views [1]buffer.View vv := view.ToVectorisedView(views) ep.HandlePacket(&r, &vv) @@ -491,7 +536,7 @@ func TestIPv6ReceiveControl(t *testing.T) { {"Zero-length packet", 0, nil, header.ICMPv6DstUnreachable, header.ICMPv6PortUnreachable, stack.ControlPortUnreachable, 0, 2*header.IPv6MinimumSize + header.ICMPv6DstUnreachableMinimumSize + 8}, } r := stack.Route{ - LocalAddress: "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01", + LocalAddress: localIpv6Addr, RemoteAddress: "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xaa", } for _, c := range cases { @@ -499,7 +544,7 @@ func TestIPv6ReceiveControl(t *testing.T) { var views [1]buffer.View o := testObject{t: t} proto := ipv6.NewProtocol() - ep, err := proto.NewEndpoint(1, "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01", nil, &o, nil) + ep, err := proto.NewEndpoint(1, localIpv6Addr, nil, &o, nil) if err != nil { t.Fatalf("NewEndpoint failed: %v", err) } @@ -519,7 +564,7 @@ func TestIPv6ReceiveControl(t *testing.T) { NextHeader: uint8(header.ICMPv6ProtocolNumber), HopLimit: 20, SrcAddr: "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xaa", - DstAddr: "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01", + DstAddr: localIpv6Addr, }) // Create the ICMP header. @@ -534,8 +579,8 @@ func TestIPv6ReceiveControl(t *testing.T) { PayloadLength: 100, NextHeader: 10, HopLimit: 20, - SrcAddr: "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01", - DstAddr: "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02", + SrcAddr: localIpv6Addr, + DstAddr: remoteIpv6Addr, }) // Build the fragmentation header if needed. @@ -558,8 +603,8 @@ func TestIPv6ReceiveControl(t *testing.T) { // Give packet to IPv6 endpoint, dispatcher will validate that // it's ok. o.protocol = 10 - o.srcAddr = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02" - o.dstAddr = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01" + o.srcAddr = remoteIpv6Addr + o.dstAddr = localIpv6Addr o.contents = view[dataOffset:] o.typ = c.expectedTyp o.extra = c.expectedExtra diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go index e5db44b5d..e6b1f5128 100644 --- a/pkg/tcpip/network/ipv4/ipv4.go +++ b/pkg/tcpip/network/ipv4/ipv4.go @@ -121,6 +121,7 @@ func (e *endpoint) WritePacket(r *stack.Route, hdr *buffer.Prependable, payload DstAddr: r.RemoteAddress, }) ip.SetChecksum(^ip.CalculateChecksum()) + r.Stats().IP.PacketsSent.Increment() return e.linkEP.WritePacket(r, hdr, payload, ProtocolNumber) } @@ -153,6 +154,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, vv *buffer.VectorisedView) { e.handleICMP(r, vv) return } + r.Stats().IP.PacketsDelivered.Increment() e.dispatcher.DeliverTransportPacket(r, p, vv) } diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go index 67d8cc670..f48d120c5 100644 --- a/pkg/tcpip/network/ipv6/ipv6.go +++ b/pkg/tcpip/network/ipv6/ipv6.go @@ -97,6 +97,7 @@ func (e *endpoint) WritePacket(r *stack.Route, hdr *buffer.Prependable, payload SrcAddr: tcpip.Address(e.address[:]), DstAddr: r.RemoteAddress, }) + r.Stats().IP.PacketsSent.Increment() return e.linkEP.WritePacket(r, hdr, payload, ProtocolNumber) } @@ -118,6 +119,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, vv *buffer.VectorisedView) { return } + r.Stats().IP.PacketsDelivered.Increment() e.dispatcher.DeliverTransportPacket(r, p, vv) } diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go index 284732874..c158e2005 100644 --- a/pkg/tcpip/stack/nic.go +++ b/pkg/tcpip/stack/nic.go @@ -22,6 +22,7 @@ import ( "gvisor.googlesource.com/gvisor/pkg/ilist" "gvisor.googlesource.com/gvisor/pkg/tcpip" "gvisor.googlesource.com/gvisor/pkg/tcpip/buffer" + "gvisor.googlesource.com/gvisor/pkg/tcpip/header" ) // NIC represents a "network interface card" to which the networking stack is @@ -286,6 +287,10 @@ func (n *NIC) DeliverNetworkPacket(linkEP LinkEndpoint, remoteLinkAddr tcpip.Lin return } + if netProto.Number() == header.IPv4ProtocolNumber || netProto.Number() == header.IPv6ProtocolNumber { + n.stack.stats.IP.PacketsReceived.Increment() + } + if len(vv.First()) < netProto.MinimumPacketSize() { n.stack.stats.MalformedRcvdPackets.Increment() return @@ -330,7 +335,7 @@ func (n *NIC) DeliverNetworkPacket(linkEP LinkEndpoint, remoteLinkAddr tcpip.Lin } if ref == nil { - n.stack.stats.UnknownNetworkEndpointRcvdPackets.Increment() + n.stack.stats.IP.InvalidAddressesReceived.Increment() return } diff --git a/pkg/tcpip/stack/route.go b/pkg/tcpip/stack/route.go index 423f428df..63a20e031 100644 --- a/pkg/tcpip/stack/route.go +++ b/pkg/tcpip/stack/route.go @@ -70,6 +70,11 @@ func (r *Route) MaxHeaderLength() uint16 { return r.ref.ep.MaxHeaderLength() } +// Stats returns a mutable copy of current stats. +func (r *Route) Stats() tcpip.Stats { + return r.ref.nic.stack.Stats() +} + // PseudoHeaderChecksum forwards the call to the network endpoint's // implementation. func (r *Route) PseudoHeaderChecksum(protocol tcpip.TransportProtocolNumber) uint16 { @@ -125,7 +130,11 @@ func (r *Route) IsResolutionRequired() bool { // WritePacket writes the packet through the given route. func (r *Route) WritePacket(hdr *buffer.Prependable, payload buffer.View, protocol tcpip.TransportProtocolNumber) *tcpip.Error { - return r.ref.ep.WritePacket(r, hdr, payload, protocol) + err := r.ref.ep.WritePacket(r, hdr, payload, protocol) + if err == tcpip.ErrNoRoute { + r.Stats().IP.OutgoingPacketErrors.Increment() + } + return err } // MTU returns the MTU of the underlying network endpoint. diff --git a/pkg/tcpip/stack/transport_demuxer.go b/pkg/tcpip/stack/transport_demuxer.go index dbaa9c829..862afa693 100644 --- a/pkg/tcpip/stack/transport_demuxer.go +++ b/pkg/tcpip/stack/transport_demuxer.go @@ -19,6 +19,7 @@ import ( "gvisor.googlesource.com/gvisor/pkg/tcpip" "gvisor.googlesource.com/gvisor/pkg/tcpip/buffer" + "gvisor.googlesource.com/gvisor/pkg/tcpip/header" ) type protocolIDs struct { @@ -111,6 +112,10 @@ func (d *transportDemuxer) deliverPacket(r *Route, protocol tcpip.TransportProto // Fail if we didn't find one. if ep == nil { + // UDP packet could not be delivered to an unknown destination port. + if protocol == header.UDPProtocolNumber { + r.Stats().UDP.UnknownPortErrors.Increment() + } return false } diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go index e323aea8c..166d37004 100644 --- a/pkg/tcpip/tcpip.go +++ b/pkg/tcpip/tcpip.go @@ -31,6 +31,7 @@ package tcpip import ( "errors" "fmt" + "reflect" "strconv" "strings" "sync" @@ -47,48 +48,56 @@ import ( // Note: to support save / restore, it is important that all tcpip errors have // distinct error messages. type Error struct { - string + msg string + + ignoreStats bool } // String implements fmt.Stringer.String. func (e *Error) String() string { - return e.string + return e.msg +} + +// IgnoreStats indicates whether this error type should be included in failure +// counts in tcpip.Stats structs. +func (e *Error) IgnoreStats() bool { + return e.ignoreStats } // Errors that can be returned by the network stack. var ( - ErrUnknownProtocol = &Error{"unknown protocol"} - ErrUnknownNICID = &Error{"unknown nic id"} - ErrUnknownProtocolOption = &Error{"unknown option for protocol"} - ErrDuplicateNICID = &Error{"duplicate nic id"} - ErrDuplicateAddress = &Error{"duplicate address"} - ErrNoRoute = &Error{"no route"} - ErrBadLinkEndpoint = &Error{"bad link layer endpoint"} - ErrAlreadyBound = &Error{"endpoint already bound"} - ErrInvalidEndpointState = &Error{"endpoint is in invalid state"} - ErrAlreadyConnecting = &Error{"endpoint is already connecting"} - ErrAlreadyConnected = &Error{"endpoint is already connected"} - ErrNoPortAvailable = &Error{"no ports are available"} - ErrPortInUse = &Error{"port is in use"} - ErrBadLocalAddress = &Error{"bad local address"} - ErrClosedForSend = &Error{"endpoint is closed for send"} - ErrClosedForReceive = &Error{"endpoint is closed for receive"} - ErrWouldBlock = &Error{"operation would block"} - ErrConnectionRefused = &Error{"connection was refused"} - ErrTimeout = &Error{"operation timed out"} - ErrAborted = &Error{"operation aborted"} - ErrConnectStarted = &Error{"connection attempt started"} - ErrDestinationRequired = &Error{"destination address is required"} - ErrNotSupported = &Error{"operation not supported"} - ErrQueueSizeNotSupported = &Error{"queue size querying not supported"} - ErrNotConnected = &Error{"endpoint not connected"} - ErrConnectionReset = &Error{"connection reset by peer"} - ErrConnectionAborted = &Error{"connection aborted"} - ErrNoSuchFile = &Error{"no such file"} - ErrInvalidOptionValue = &Error{"invalid option value specified"} - ErrNoLinkAddress = &Error{"no remote link address"} - ErrBadAddress = &Error{"bad address"} - ErrNetworkUnreachable = &Error{"network is unreachable"} + ErrUnknownProtocol = &Error{msg: "unknown protocol"} + ErrUnknownNICID = &Error{msg: "unknown nic id"} + ErrUnknownProtocolOption = &Error{msg: "unknown option for protocol"} + ErrDuplicateNICID = &Error{msg: "duplicate nic id"} + ErrDuplicateAddress = &Error{msg: "duplicate address"} + ErrNoRoute = &Error{msg: "no route"} + ErrBadLinkEndpoint = &Error{msg: "bad link layer endpoint"} + ErrAlreadyBound = &Error{msg: "endpoint already bound", ignoreStats: true} + ErrInvalidEndpointState = &Error{msg: "endpoint is in invalid state"} + ErrAlreadyConnecting = &Error{msg: "endpoint is already connecting", ignoreStats: true} + ErrAlreadyConnected = &Error{msg: "endpoint is already connected", ignoreStats: true} + ErrNoPortAvailable = &Error{msg: "no ports are available"} + ErrPortInUse = &Error{msg: "port is in use"} + ErrBadLocalAddress = &Error{msg: "bad local address"} + ErrClosedForSend = &Error{msg: "endpoint is closed for send"} + ErrClosedForReceive = &Error{msg: "endpoint is closed for receive"} + ErrWouldBlock = &Error{msg: "operation would block", ignoreStats: true} + ErrConnectionRefused = &Error{msg: "connection was refused"} + ErrTimeout = &Error{msg: "operation timed out"} + ErrAborted = &Error{msg: "operation aborted"} + ErrConnectStarted = &Error{msg: "connection attempt started", ignoreStats: true} + ErrDestinationRequired = &Error{msg: "destination address is required"} + ErrNotSupported = &Error{msg: "operation not supported"} + ErrQueueSizeNotSupported = &Error{msg: "queue size querying not supported"} + ErrNotConnected = &Error{msg: "endpoint not connected"} + ErrConnectionReset = &Error{msg: "connection reset by peer"} + ErrConnectionAborted = &Error{msg: "connection aborted"} + ErrNoSuchFile = &Error{msg: "no such file"} + ErrInvalidOptionValue = &Error{msg: "invalid option value specified"} + ErrNoLinkAddress = &Error{msg: "no remote link address"} + ErrBadAddress = &Error{msg: "bad address"} + ErrNetworkUnreachable = &Error{msg: "network is unreachable"} ) // Errors related to Subnet @@ -473,7 +482,7 @@ type StatCounter struct { // Increment adds one to the counter. func (s *StatCounter) Increment() { - atomic.AddUint64(&s.count, 1) + s.IncrementBy(1) } // Value returns the current value of the counter. @@ -486,6 +495,82 @@ func (s *StatCounter) IncrementBy(v uint64) { atomic.AddUint64(&s.count, v) } +// IPStats collects IP-specific stats (both v4 and v6). +type IPStats struct { + // PacketsReceived is the total number of IP packets received from the link + // layer in nic.DeliverNetworkPacket. + PacketsReceived *StatCounter + + // InvalidAddressesReceived is the total number of IP packets received + // with an unknown or invalid destination address. + InvalidAddressesReceived *StatCounter + + // PacketsDelivered is the total number of incoming IP packets that + // are successfully delivered to the transport layer via HandlePacket. + PacketsDelivered *StatCounter + + // PacketsSent is the total number of IP packets sent via WritePacket. + PacketsSent *StatCounter + + // OutgoingPacketErrors is the total number of IP packets which failed + // to write to a link-layer endpoint. + OutgoingPacketErrors *StatCounter +} + +// TCPStats collects TCP-specific stats. +type TCPStats struct { + // ActiveConnectionOpenings is the number of connections opened successfully + // via Connect. + ActiveConnectionOpenings *StatCounter + + // PassiveConnectionOpenings is the number of connections opened + // successfully via Listen. + PassiveConnectionOpenings *StatCounter + + // FailedConnectionAttempts is the number of calls to Connect or Listen + // (active and passive openings, respectively) that end in an error. + FailedConnectionAttempts *StatCounter + + // ValidSegmentsReceived is the number of TCP segments received that the + // transport layer successfully parsed. + ValidSegmentsReceived *StatCounter + + // InvalidSegmentsReceived is the number of TCP segments received that + // the transport layer could not parse. + InvalidSegmentsReceived *StatCounter + + // SegmentsSent is the number of TCP segments sent. + SegmentsSent *StatCounter + + // ResetsSent is the number of TCP resets sent. + ResetsSent *StatCounter + + // ResetsReceived is the number of TCP resets received. + ResetsReceived *StatCounter +} + +// UDPStats collects UDP-specific stats. +type UDPStats struct { + // PacketsReceived is the number of UDP datagrams received via + // HandlePacket. + PacketsReceived *StatCounter + + // UnknownPortErrors is the number of incoming UDP datagrams dropped + // because they did not have a known destination port. + UnknownPortErrors *StatCounter + + // ReceiveBufferErrors is the number of incoming UDP datagrams dropped + // due to the receiving buffer being in an invalid state. + ReceiveBufferErrors *StatCounter + + // MalformedPacketsReceived is the number of incoming UDP datagrams + // dropped due to the UDP header being in a malformed state. + MalformedPacketsReceived *StatCounter + + // PacketsSent is the number of UDP datagrams sent via sendUDP. + PacketsSent *StatCounter +} + // Stats holds statistics about the networking stack. // // All fields are optional. @@ -494,33 +579,42 @@ type Stats struct { // stack that were for an unknown or unsupported protocol. UnknownProtocolRcvdPackets *StatCounter - // UnknownNetworkEndpointRcvdPackets is the number of packets received - // by the stack that were for a supported network protocol, but whose - // destination address didn't having a matching endpoint. - UnknownNetworkEndpointRcvdPackets *StatCounter - // MalformedRcvPackets is the number of packets received by the stack // that were deemed malformed. MalformedRcvdPackets *StatCounter // DroppedPackets is the number of packets dropped due to full queues. DroppedPackets *StatCounter + + // IP breaks out IP-specific stats (both v4 and v6). + IP IPStats + + // TCP breaks out TCP-specific stats. + TCP TCPStats + + // UDP breaks out UDP-specific stats. + UDP UDPStats +} + +func fillIn(v reflect.Value) { + for i := 0; i < v.NumField(); i++ { + v := v.Field(i) + switch v.Kind() { + case reflect.Ptr: + if s, ok := v.Addr().Interface().(**StatCounter); ok { + if *s == nil { + *s = &StatCounter{} + } + } + case reflect.Struct: + fillIn(v) + } + } } // FillIn returns a copy of s with nil fields initialized to new StatCounters. func (s Stats) FillIn() Stats { - if s.UnknownProtocolRcvdPackets == nil { - s.UnknownProtocolRcvdPackets = &StatCounter{} - } - if s.UnknownNetworkEndpointRcvdPackets == nil { - s.UnknownNetworkEndpointRcvdPackets = &StatCounter{} - } - if s.MalformedRcvdPackets == nil { - s.MalformedRcvdPackets = &StatCounter{} - } - if s.DroppedPackets == nil { - s.DroppedPackets = &StatCounter{} - } + fillIn(reflect.ValueOf(&s).Elem()) return s } diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go index 58d7942f3..14282d399 100644 --- a/pkg/tcpip/transport/tcp/connect.go +++ b/pkg/tcpip/transport/tcp/connect.go @@ -604,6 +604,11 @@ func sendTCPWithOptions(r *stack.Route, id stack.TransportEndpointID, data buffe tcp.SetChecksum(^tcp.CalculateChecksum(xsum, length)) } + r.Stats().TCP.SegmentsSent.Increment() + if (flags & flagRst) != 0 { + r.Stats().TCP.ResetsSent.Increment() + } + return r.WritePacket(&hdr, data, ProtocolNumber) } @@ -641,6 +646,11 @@ func sendTCP(r *stack.Route, id stack.TransportEndpointID, data buffer.View, fla tcp.SetChecksum(^tcp.CalculateChecksum(xsum, length)) } + r.Stats().TCP.SegmentsSent.Increment() + if (flags & flagRst) != 0 { + r.Stats().TCP.ResetsSent.Increment() + } + return r.WritePacket(&hdr, data, ProtocolNumber) } diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go index bdcba39c6..cbbbbc084 100644 --- a/pkg/tcpip/transport/tcp/endpoint.go +++ b/pkg/tcpip/transport/tcp/endpoint.go @@ -829,9 +829,14 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { // created (so no new handshaking is done); for stack-accepted connections not // yet accepted by the app, they are restored without running the main goroutine // here. -func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) *tcpip.Error { +func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) (err *tcpip.Error) { e.mu.Lock() defer e.mu.Unlock() + defer func() { + if err != nil && !err.IgnoreStats() { + e.stack.Stats().TCP.FailedConnectionAttempts.Increment() + } + }() connectingAddr := addr.Addr @@ -960,6 +965,7 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) *tc if run { e.workerRunning = true + e.stack.Stats().TCP.ActiveConnectionOpenings.Increment() go e.protocolMainLoop(handshake) // S/R-SAFE: will be drained before save. } @@ -1032,9 +1038,14 @@ func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error { // Listen puts the endpoint in "listen" mode, which allows it to accept // new connections. -func (e *endpoint) Listen(backlog int) *tcpip.Error { +func (e *endpoint) Listen(backlog int) (err *tcpip.Error) { e.mu.Lock() defer e.mu.Unlock() + defer func() { + if err != nil && !err.IgnoreStats() { + e.stack.Stats().TCP.FailedConnectionAttempts.Increment() + } + }() // Allow the backlog to be adjusted if the endpoint is not shutting down. // When the endpoint shuts down, it sets workerCleanup to true, and from @@ -1075,6 +1086,7 @@ func (e *endpoint) Listen(backlog int) *tcpip.Error { } e.workerRunning = true + e.stack.Stats().TCP.PassiveConnectionOpenings.Increment() go e.protocolListenLoop( // S/R-SAFE: drained on save. seqnum.Size(e.receiveBufferAvailable())) @@ -1226,10 +1238,16 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, vv s := newSegment(r, id, vv) if !s.parse() { e.stack.Stats().MalformedRcvdPackets.Increment() + e.stack.Stats().TCP.InvalidSegmentsReceived.Increment() s.decRef() return } + e.stack.Stats().TCP.ValidSegmentsReceived.Increment() + if (s.flags & flagRst) != 0 { + e.stack.Stats().TCP.ResetsReceived.Increment() + } + // Send packet to worker goroutine. if e.segmentQueue.enqueue(s) { e.newSegmentWaker.Assert() diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go index 74318c012..71d70a597 100644 --- a/pkg/tcpip/transport/tcp/tcp_test.go +++ b/pkg/tcpip/transport/tcp/tcp_test.go @@ -62,7 +62,7 @@ func TestGiveUpConnect(t *testing.T) { defer wq.EventUnregister(&waitEntry) if err := ep.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}); err != tcpip.ErrConnectStarted { - t.Fatalf("Unexpected return value from Connect: %v", err) + t.Fatalf("got ep.Connect(...) = %v, want = %v", err, tcpip.ErrConnectStarted) } // Close the connection, wait for completion. @@ -75,6 +75,187 @@ func TestGiveUpConnect(t *testing.T) { } } +func TestConnectIncrementActiveConnection(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + stats := c.Stack().Stats() + want := stats.TCP.ActiveConnectionOpenings.Value() + 1 + + c.CreateConnected(789, 30000, nil) + if got := stats.TCP.ActiveConnectionOpenings.Value(); got != want { + t.Errorf("got stats.TCP.ActtiveConnectionOpenings.Value() = %v, want = %v", got, want) + } +} + +func TestConnectDoesNotIncrementFailedConnectionAttempts(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + stats := c.Stack().Stats() + want := stats.TCP.FailedConnectionAttempts.Value() + + c.CreateConnected(789, 30000, nil) + if got := stats.TCP.FailedConnectionAttempts.Value(); got != want { + t.Errorf("got stats.TCP.FailedConnectionOpenings.Value() = %v, want = %v", got, want) + } +} + +func TestActiveFailedConnectionAttemptIncrement(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + stats := c.Stack().Stats() + ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ) + if err != nil { + t.Fatalf("NewEndpoint failed: %v", err) + } + c.EP = ep + want := stats.TCP.FailedConnectionAttempts.Value() + 1 + + if err := c.EP.Connect(tcpip.FullAddress{NIC: 2, Addr: context.TestAddr, Port: context.TestPort}); err != tcpip.ErrNoRoute { + t.Errorf("got c.EP.Connect(...) = %v, want = %v", err, tcpip.ErrNoRoute) + } + + if got := stats.TCP.FailedConnectionAttempts.Value(); got != want { + t.Errorf("got stats.TCP.FailedConnectionAttempts.Value() = %v, want = %v", got, want) + } +} + +func TestPassiveConnectionAttemptIncrement(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + stats := c.Stack().Stats() + want := stats.TCP.PassiveConnectionOpenings.Value() + 1 + ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ) + if err != nil { + t.Fatalf("NewEndpoint failed: %v", err) + } + + if err := ep.Bind(tcpip.FullAddress{Addr: context.StackAddr, Port: context.StackPort}, nil); err != nil { + t.Fatalf("Bind failed: %v", err) + } + if err := ep.Listen(1); err != nil { + t.Fatalf("Listen failed: %v", err) + } + + if got := stats.TCP.PassiveConnectionOpenings.Value(); got != want { + t.Errorf("got stats.TCP.PassiveConnectionOpenings.Value() = %v, want = %v", got, want) + } +} + +func TestPassiveFailedConnectionAttemptIncrement(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + stats := c.Stack().Stats() + ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ) + if err != nil { + t.Fatalf("NewEndpoint failed: %v", err) + } + c.EP = ep + want := stats.TCP.FailedConnectionAttempts.Value() + 1 + + if err := ep.Listen(1); err != tcpip.ErrInvalidEndpointState { + t.Errorf("got ep.Listen(1) = %v, want = %v", err, tcpip.ErrInvalidEndpointState) + } + + if got := stats.TCP.FailedConnectionAttempts.Value(); got != want { + t.Errorf("got stats.TCP.FailedConnectionAttempts.Value() = %v, want = %v", got, want) + } +} + +func TestTCPSegmentsSentIncrement(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + stats := c.Stack().Stats() + // SYN and ACK + want := stats.TCP.SegmentsSent.Value() + 2 + c.CreateConnected(789, 30000, nil) + + if got := stats.TCP.SegmentsSent.Value(); got != want { + t.Errorf("got stats.TCP.SegmentsSent.Value() = %v, want = %v", got, want) + } +} + +func TestTCPResetsSentIncrement(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + stats := c.Stack().Stats() + wq := &waiter.Queue{} + ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq) + if err != nil { + t.Fatalf("NewEndpoint failed: %v", err) + } + want := stats.TCP.SegmentsSent.Value() + 1 + + if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}, nil); err != nil { + t.Fatalf("Bind failed: %v", err) + } + + if err := ep.Listen(10); err != nil { + t.Fatalf("Listen failed: %v", err) + } + + // Send a SYN request. + iss := seqnum.Value(789) + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagSyn, + SeqNum: iss, + }) + + // Receive the SYN-ACK reply. + b := c.GetPacket() + tcp := header.TCP(header.IPv4(b).Payload()) + c.IRS = seqnum.Value(tcp.SequenceNumber()) + + ackHeaders := &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagAck, + SeqNum: iss + 1, + // If the AckNum is not the increment of the last sequence number, a RST + // segment is sent back in response. + AckNum: c.IRS + 2, + } + + // Send ACK. + c.SendPacket(nil, ackHeaders) + + c.GetPacket() + if got := stats.TCP.ResetsSent.Value(); got != want { + t.Errorf("got stats.TCP.ResetsSent.Value() = %v, want = %v", got, want) + } +} + +func TestTCPResetsReceivedIncrement(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + stats := c.Stack().Stats() + want := stats.TCP.ResetsReceived.Value() + 1 + ackNum := seqnum.Value(789) + rcvWnd := seqnum.Size(30000) + c.CreateConnected(ackNum, rcvWnd, nil) + + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: c.Port, + SeqNum: c.IRS.Add(2), + AckNum: ackNum.Add(2), + RcvWnd: rcvWnd, + Flags: header.TCPFlagRst, + }) + + if got := stats.TCP.ResetsReceived.Value(); got != want { + t.Errorf("got stats.TCP.ResetsReceived.Value() = %v, want = %v", got, want) + } +} + func TestActiveHandshake(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() @@ -159,7 +340,7 @@ func TestSimpleReceive(t *testing.T) { defer c.WQ.EventUnregister(&we) if _, _, err := c.EP.Read(nil); err != tcpip.ErrWouldBlock { - t.Fatalf("Unexpected error from Read: %v", err) + t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrWouldBlock) } data := []byte{1, 2, 3} @@ -182,11 +363,11 @@ func TestSimpleReceive(t *testing.T) { // Receive data. v, _, err := c.EP.Read(nil) if err != nil { - t.Fatalf("Unexpected error from Read: %v", err) + t.Fatalf("Read failed: %v", err) } - if bytes.Compare(data, v) != 0 { - t.Fatalf("Data is different: expected %v, got %v", data, v) + if !bytes.Equal(data, v) { + t.Fatalf("got data = %v, want = %v", v, data) } // Check that ACK is received. @@ -211,7 +392,7 @@ func TestOutOfOrderReceive(t *testing.T) { defer c.WQ.EventUnregister(&we) if _, _, err := c.EP.Read(nil); err != tcpip.ErrWouldBlock { - t.Fatalf("Unexpected error from Read: %v", err) + t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrWouldBlock) } // Send second half of data first, with seqnum 3 ahead of expected. @@ -238,7 +419,7 @@ func TestOutOfOrderReceive(t *testing.T) { // Wait 200ms and check that no data has been received. time.Sleep(200 * time.Millisecond) if _, _, err := c.EP.Read(nil); err != tcpip.ErrWouldBlock { - t.Fatalf("Unexpected error from Read: %v", err) + t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrWouldBlock) } // Send the first 3 bytes now. @@ -265,15 +446,15 @@ func TestOutOfOrderReceive(t *testing.T) { } continue } - t.Fatalf("Unexpected error from Read: %v", err) + t.Fatalf("Read failed: %v", err) } read = append(read, v...) } // Check that we received the data in proper order. - if bytes.Compare(data, read) != 0 { - t.Fatalf("Data is different: expected %v, got %v", data, read) + if !bytes.Equal(data, read) { + t.Fatalf("got data = %v, want = %v", read, data) } // Check that the whole data is acknowledged. @@ -296,7 +477,7 @@ func TestOutOfOrderFlood(t *testing.T) { c.CreateConnected(789, 30000, &opt) if _, _, err := c.EP.Read(nil); err != tcpip.ErrWouldBlock { - t.Fatalf("Unexpected error from Read: %v", err) + t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrWouldBlock) } // Send 100 packets before the actual one that is expected. @@ -373,7 +554,7 @@ func TestRstOnCloseWithUnreadData(t *testing.T) { defer c.WQ.EventUnregister(&we) if _, _, err := c.EP.Read(nil); err != tcpip.ErrWouldBlock { - t.Fatalf("Unexpected error from Read: %v", err) + t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrWouldBlock) } data := []byte{1, 2, 3} @@ -438,7 +619,7 @@ func TestRstOnCloseWithUnreadDataFinConvertRst(t *testing.T) { defer c.WQ.EventUnregister(&we) if _, _, err := c.EP.Read(nil); err != tcpip.ErrWouldBlock { - t.Fatalf("Unexpected error from Read: %v", err) + t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrWouldBlock) } data := []byte{1, 2, 3} @@ -517,7 +698,7 @@ func TestFullWindowReceive(t *testing.T) { _, _, err := c.EP.Read(nil) if err != tcpip.ErrWouldBlock { - t.Fatalf("Unexpected error from Read: %v", err) + t.Fatalf("Read failed: %v", err) } // Fill up the window. @@ -552,11 +733,11 @@ func TestFullWindowReceive(t *testing.T) { // Receive data and check it. v, _, err := c.EP.Read(nil) if err != nil { - t.Fatalf("Unexpected error from Read: %v", err) + t.Fatalf("Read failed: %v", err) } - if bytes.Compare(data, v) != 0 { - t.Fatalf("Data is different: expected %v, got %v", data, v) + if !bytes.Equal(data, v) { + t.Fatalf("got data = %v, want = %v", v, data) } // Check that we get an ACK for the newly non-zero window. @@ -588,9 +769,8 @@ func TestNoWindowShrinking(t *testing.T) { c.WQ.EventRegister(&we, waiter.EventIn) defer c.WQ.EventUnregister(&we) - _, _, err := c.EP.Read(nil) - if err != tcpip.ErrWouldBlock { - t.Fatalf("Unexpected error from Read: %v", err) + if _, _, err := c.EP.Read(nil); err != tcpip.ErrWouldBlock { + t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrWouldBlock) } // Send 3 bytes, check that the peer acknowledges them. @@ -654,14 +834,14 @@ func TestNoWindowShrinking(t *testing.T) { for len(read) < len(data) { v, _, err := c.EP.Read(nil) if err != nil { - t.Fatalf("Unexpected error from Read: %v", err) + t.Fatalf("Read failed: %v", err) } read = append(read, v...) } - if bytes.Compare(data, read) != 0 { - t.Fatalf("Data is different: expected %v, got %v", data, read) + if !bytes.Equal(data, read) { + t.Fatalf("got data = %v, want = %v", read, data) } // Check that we get an ACK for the newly non-zero window, which is the @@ -688,7 +868,7 @@ func TestSimpleSend(t *testing.T) { copy(view, data) if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { - t.Fatalf("Unexpected error from Write: %v", err) + t.Fatalf("Write failed: %v", err) } // Check that data is received. @@ -703,8 +883,8 @@ func TestSimpleSend(t *testing.T) { ), ) - if p := b[header.IPv4MinimumSize+header.TCPMinimumSize:]; bytes.Compare(data, p) != 0 { - t.Fatalf("Data is different: expected %v, got %v", data, p) + if p := b[header.IPv4MinimumSize+header.TCPMinimumSize:]; !bytes.Equal(data, p) { + t.Fatalf("got data = %v, want = %v", p, data) } // Acknowledge the data. @@ -730,7 +910,7 @@ func TestZeroWindowSend(t *testing.T) { _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}) if err != nil { - t.Fatalf("Unexpected error from Write: %v", err) + t.Fatalf("Write failed: %v", err) } // Since the window is currently zero, check that no packet is received. @@ -758,8 +938,8 @@ func TestZeroWindowSend(t *testing.T) { ), ) - if p := b[header.IPv4MinimumSize+header.TCPMinimumSize:]; bytes.Compare(data, p) != 0 { - t.Fatalf("Data is different: expected %v, got %v", data, p) + if p := b[header.IPv4MinimumSize+header.TCPMinimumSize:]; !bytes.Equal(data, p) { + t.Fatalf("got data = %v, want = %v", p, data) } // Acknowledge the data. @@ -790,7 +970,7 @@ func TestScaledWindowConnect(t *testing.T) { copy(view, data) if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { - t.Fatalf("Unexpected error from Write: %v", err) + t.Fatalf("Write failed: %v", err) } // Check that data is received, and that advertised window is 0xbfff, @@ -823,7 +1003,7 @@ func TestNonScaledWindowConnect(t *testing.T) { copy(view, data) if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { - t.Fatalf("Unexpected error from Write: %v", err) + t.Fatalf("Write failed: %v", err) } // Check that data is received, and that advertised window is 0xffff, @@ -896,7 +1076,7 @@ func TestScaledWindowAccept(t *testing.T) { copy(view, data) if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { - t.Fatalf("Unexpected error from Write: %v", err) + t.Fatalf("Write failed: %v", err) } // Check that data is received, and that advertised window is 0xbfff, @@ -969,7 +1149,7 @@ func TestNonScaledWindowAccept(t *testing.T) { copy(view, data) if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { - t.Fatalf("Unexpected error from Write: %v", err) + t.Fatalf("Write failed: %v", err) } // Check that data is received, and that advertised window is 0xffff, @@ -1057,7 +1237,7 @@ func TestZeroScaledWindowReceive(t *testing.T) { // Read some data. An ack should be sent in response to that. v, _, err := c.EP.Read(nil) if err != nil { - t.Fatalf("Unexpected error from Read: %v", err) + t.Fatalf("Read failed: %v", err) } checker.IPv4(t, c.GetPacket(), @@ -1084,7 +1264,7 @@ func testBrokenUpWrite(t *testing.T, c *context.Context, maxPayload int) { copy(view, data) if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { - t.Fatalf("Unexpected error from Write: %v", err) + t.Fatalf("Write failed: %v", err) } // Check that data is received in chunks. @@ -1105,8 +1285,8 @@ func testBrokenUpWrite(t *testing.T, c *context.Context, maxPayload int) { ) pdata := data[bytesReceived : bytesReceived+payloadLen] - if p := tcp.Payload(); bytes.Compare(pdata, p) != 0 { - t.Fatalf("Data is different: expected %v, got %v", pdata, p) + if p := tcp.Payload(); !bytes.Equal(pdata, p) { + t.Fatalf("got data = %v, want = %v", p, pdata) } bytesReceived += payloadLen var options []byte @@ -1325,9 +1505,8 @@ func TestSynOptionsOnActiveConnect(t *testing.T) { c.WQ.EventRegister(&we, waiter.EventOut) defer c.WQ.EventUnregister(&we) - err = c.EP.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}) - if err != tcpip.ErrConnectStarted { - t.Fatalf("Unexpected return value from Connect: %v", err) + if err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}); err != tcpip.ErrConnectStarted { + t.Fatalf("got c.EP.Connect(...) = %v, want = %v", err, tcpip.ErrConnectStarted) } // Receive SYN packet. @@ -1380,9 +1559,8 @@ func TestSynOptionsOnActiveConnect(t *testing.T) { // Wait for connection to be established. select { case <-ch: - err = c.EP.GetSockOpt(tcpip.ErrorOption{}) - if err != nil { - t.Fatalf("Unexpected error when connecting: %v", err) + if err := c.EP.GetSockOpt(tcpip.ErrorOption{}); err != nil { + t.Fatalf("GetSockOpt failed: %v", err) } case <-time.After(1 * time.Second): t.Fatalf("Timed out waiting for connection") @@ -1439,8 +1617,6 @@ func TestReceiveOnResetConnection(t *testing.T) { loop: for { switch _, _, err := c.EP.Read(nil); err { - case nil: - t.Fatalf("Unexpected success.") case tcpip.ErrWouldBlock: select { case <-ch: @@ -1450,7 +1626,7 @@ loop: case tcpip.ErrConnectionReset: break loop default: - t.Fatalf("Unexpected error: want %v, got %v", tcpip.ErrConnectionReset, err) + t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrConnectionReset) } } } @@ -1475,9 +1651,8 @@ func TestSendOnResetConnection(t *testing.T) { // Try to write. view := buffer.NewView(10) - _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}) - if err != tcpip.ErrConnectionReset { - t.Fatalf("Unexpected error from Write: want %v, got %v", tcpip.ErrConnectionReset, err) + if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != tcpip.ErrConnectionReset { + t.Fatalf("got c.EP.Write(...) = %v, want = %v", err, tcpip.ErrConnectionReset) } } @@ -1489,7 +1664,7 @@ func TestFinImmediately(t *testing.T) { // Shutdown immediately, check that we get a FIN. if err := c.EP.Shutdown(tcpip.ShutdownWrite); err != nil { - t.Fatalf("Unexpected error from Shutdown: %v", err) + t.Fatalf("Shutdown failed: %v", err) } checker.IPv4(t, c.GetPacket(), @@ -1532,7 +1707,7 @@ func TestFinRetransmit(t *testing.T) { // Shutdown immediately, check that we get a FIN. if err := c.EP.Shutdown(tcpip.ShutdownWrite); err != nil { - t.Fatalf("Unexpected error from Shutdown: %v", err) + t.Fatalf("Shutdown failed: %v", err) } checker.IPv4(t, c.GetPacket(), @@ -1587,7 +1762,7 @@ func TestFinWithNoPendingData(t *testing.T) { // Write something out, and have it acknowledged. view := buffer.NewView(10) if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { - t.Fatalf("Unexpected error from Write: %v", err) + t.Fatalf("Write failed: %v", err) } next := uint32(c.IRS) + 1 @@ -1613,7 +1788,7 @@ func TestFinWithNoPendingData(t *testing.T) { // Shutdown, check that we get a FIN. if err := c.EP.Shutdown(tcpip.ShutdownWrite); err != nil { - t.Fatalf("Unexpected error from Shutdown: %v", err) + t.Fatalf("Shutdown failed: %v", err) } checker.IPv4(t, c.GetPacket(), @@ -1660,7 +1835,7 @@ func TestFinWithPendingDataCwndFull(t *testing.T) { view := buffer.NewView(10) for i := tcp.InitialCwnd; i > 0; i-- { if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { - t.Fatalf("Unexpected error from Write: %v", err) + t.Fatalf("Write failed: %v", err) } } @@ -1682,7 +1857,7 @@ func TestFinWithPendingDataCwndFull(t *testing.T) { // because the congestion window doesn't allow it. Wait until a // retransmit is received. if err := c.EP.Shutdown(tcpip.ShutdownWrite); err != nil { - t.Fatalf("Unexpected error from Shutdown: %v", err) + t.Fatalf("Shutdown failed: %v", err) } checker.IPv4(t, c.GetPacket(), @@ -1746,7 +1921,7 @@ func TestFinWithPendingData(t *testing.T) { // Write something out, and acknowledge it to get cwnd to 2. view := buffer.NewView(10) if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { - t.Fatalf("Unexpected error from Write: %v", err) + t.Fatalf("Write failed: %v", err) } next := uint32(c.IRS) + 1 @@ -1772,7 +1947,7 @@ func TestFinWithPendingData(t *testing.T) { // Write new data, but don't acknowledge it. if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { - t.Fatalf("Unexpected error from Write: %v", err) + t.Fatalf("Write failed: %v", err) } checker.IPv4(t, c.GetPacket(), @@ -1788,7 +1963,7 @@ func TestFinWithPendingData(t *testing.T) { // Shutdown the connection, check that we do get a FIN. if err := c.EP.Shutdown(tcpip.ShutdownWrite); err != nil { - t.Fatalf("Unexpected error from Shutdown: %v", err) + t.Fatalf("Shutdown failed: %v", err) } checker.IPv4(t, c.GetPacket(), @@ -1833,7 +2008,7 @@ func TestFinWithPartialAck(t *testing.T) { // FIN from the test side. view := buffer.NewView(10) if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { - t.Fatalf("Unexpected error from Write: %v", err) + t.Fatalf("Write failed: %v", err) } next := uint32(c.IRS) + 1 @@ -1870,7 +2045,7 @@ func TestFinWithPartialAck(t *testing.T) { // Write new data, but don't acknowledge it. if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { - t.Fatalf("Unexpected error from Write: %v", err) + t.Fatalf("Write failed: %v", err) } checker.IPv4(t, c.GetPacket(), @@ -1886,7 +2061,7 @@ func TestFinWithPartialAck(t *testing.T) { // Shutdown the connection, check that we do get a FIN. if err := c.EP.Shutdown(tcpip.ShutdownWrite); err != nil { - t.Fatalf("Unexpected error from Shutdown: %v", err) + t.Fatalf("Shutdown failed: %v", err) } checker.IPv4(t, c.GetPacket(), @@ -1940,7 +2115,7 @@ func TestExponentialIncreaseDuringSlowStart(t *testing.T) { // Write all the data in one shot. Packets will only be written at the // MTU size though. if _, err := c.EP.Write(tcpip.SlicePayload(data), tcpip.WriteOptions{}); err != nil { - t.Fatalf("Unexpected error from Write: %v", err) + t.Fatalf("Write failed: %v", err) } expected := tcp.InitialCwnd @@ -1982,7 +2157,7 @@ func TestCongestionAvoidance(t *testing.T) { // Write all the data in one shot. Packets will only be written at the // MTU size though. if _, err := c.EP.Write(tcpip.SlicePayload(data), tcpip.WriteOptions{}); err != nil { - t.Fatalf("Unexpected error from Write: %v", err) + t.Fatalf("Write failed: %v", err) } // Do slow start for a few iterations. @@ -2087,7 +2262,7 @@ func TestCubicCongestionAvoidance(t *testing.T) { // Write all the data in one shot. Packets will only be written at the // MTU size though. if _, err := c.EP.Write(tcpip.SlicePayload(data), tcpip.WriteOptions{}); err != nil { - t.Fatalf("Unexpected error from Write: %v", err) + t.Fatalf("Write failed: %v", err) } // Do slow start for a few iterations. @@ -2157,18 +2332,18 @@ func TestCubicCongestionAvoidance(t *testing.T) { // If our estimate was correct there should be no more pending packets. // We attempt to read a packet a few times with a short sleep in between // to ensure that we don't see the sender send any unexpected packets. - packetsUnexpected := 0 + unexpectedPackets := 0 for { gotPacket := c.ReceiveNonBlockingAndCheckPacket(data, bytesRead, maxPayload) if !gotPacket { break } bytesRead += maxPayload - packetsUnexpected++ + unexpectedPackets++ time.Sleep(1 * time.Millisecond) } - if packetsUnexpected != 0 { - t.Fatalf("received %d unexpected packets for iteration %d", packetsUnexpected, i) + if unexpectedPackets != 0 { + t.Fatalf("received %d unexpected packets for iteration %d", unexpectedPackets, i) } // Check we don't receive any more packets on this iteration. // The timeout can't be too high or we'll trigger a timeout. @@ -2195,7 +2370,7 @@ func TestFastRecovery(t *testing.T) { // Write all the data in one shot. Packets will only be written at the // MTU size though. if _, err := c.EP.Write(tcpip.SlicePayload(data), tcpip.WriteOptions{}); err != nil { - t.Fatalf("Unexpected error from Write: %v", err) + t.Fatalf("Write failed: %v", err) } // Do slow start for a few iterations. @@ -2327,11 +2502,11 @@ func TestRetransmit(t *testing.T) { // MTU size though. half := data[:len(data)/2] if _, err := c.EP.Write(tcpip.SlicePayload(half), tcpip.WriteOptions{}); err != nil { - t.Fatalf("Unexpected error from Write: %v", err) + t.Fatalf("Write failed: %v", err) } half = data[len(data)/2:] if _, err := c.EP.Write(tcpip.SlicePayload(half), tcpip.WriteOptions{}); err != nil { - t.Fatalf("Unexpected error from Write: %v", err) + t.Fatalf("Write failed: %v", err) } // Do slow start for a few iterations. @@ -2429,7 +2604,7 @@ func scaledSendWindow(t *testing.T, scale uint8) { // Send some data. Check that it's capped by the window size. view := buffer.NewView(65535) if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { - t.Fatalf("Unexpected error from Write: %v", err) + t.Fatalf("Write failed: %v", err) } // Check that only data that fits in the scaled window is sent. @@ -2458,6 +2633,52 @@ func TestScaledSendWindow(t *testing.T) { } } +func TestReceivedValidSegmentCountIncrement(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + c.CreateConnected(789, 30000, nil) + stats := c.Stack().Stats() + want := stats.TCP.ValidSegmentsReceived.Value() + 1 + + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: c.Port, + Flags: header.TCPFlagAck, + SeqNum: seqnum.Value(790), + AckNum: c.IRS.Add(1), + RcvWnd: 30000, + }) + + if got := stats.TCP.ValidSegmentsReceived.Value(); got != want { + t.Errorf("got stats.TCP.ValidSegmentsReceived.Value() = %v, want = %v", got, want) + } +} + +func TestReceivedInvalidSegmentCountIncrement(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + c.CreateConnected(789, 30000, nil) + stats := c.Stack().Stats() + want := stats.TCP.InvalidSegmentsReceived.Value() + 1 + vv := c.BuildSegment(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: c.Port, + Flags: header.TCPFlagAck, + SeqNum: seqnum.Value(790), + AckNum: c.IRS.Add(1), + RcvWnd: 30000, + }) + tcpbuf := vv.ByteSlice()[0][header.IPv4MinimumSize:] + // 12 is the TCP header data offset. + tcpbuf[12] = ((header.TCPMinimumSize - 1) / 4) << 4 + + c.SendSegment(&vv) + + if got := stats.TCP.InvalidSegmentsReceived.Value(); got != want { + t.Errorf("got stats.TCP.InvalidSegmentsReceived.Value() = %v, want = %v", got, want) + } +} + func TestReceivedSegmentQueuing(t *testing.T) { // This test sends 200 segments containing a few bytes each to an // endpoint and checks that they're all received and acknowledged by @@ -2519,12 +2740,12 @@ func TestReadAfterClosedState(t *testing.T) { defer c.WQ.EventUnregister(&we) if _, _, err := c.EP.Read(nil); err != tcpip.ErrWouldBlock { - t.Fatalf("Unexpected error from Read: %v", err) + t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrWouldBlock) } // Shutdown immediately for write, check that we get a FIN. if err := c.EP.Shutdown(tcpip.ShutdownWrite); err != nil { - t.Fatalf("Unexpected error from Shutdown: %v", err) + t.Fatalf("Shutdown failed: %v", err) } checker.IPv4(t, c.GetPacket(), @@ -2572,32 +2793,32 @@ func TestReadAfterClosedState(t *testing.T) { peekBuf := make([]byte, 10) n, _, err := c.EP.Peek([][]byte{peekBuf}) if err != nil { - t.Fatalf("Unexpected error from Peek: %v", err) + t.Fatalf("Peek failed: %v", err) } peekBuf = peekBuf[:n] - if bytes.Compare(data, peekBuf) != 0 { - t.Fatalf("Data is different: expected %v, got %v", data, peekBuf) + if !bytes.Equal(data, peekBuf) { + t.Fatalf("got data = %v, want = %v", peekBuf, data) } // Receive data. v, _, err := c.EP.Read(nil) if err != nil { - t.Fatalf("Unexpected error from Read: %v", err) + t.Fatalf("Read failed: %v", err) } - if bytes.Compare(data, v) != 0 { - t.Fatalf("Data is different: expected %v, got %v", data, v) + if !bytes.Equal(data, v) { + t.Fatalf("got data = %v, want = %v", v, data) } // Now that we drained the queue, check that functions fail with the // right error code. if _, _, err := c.EP.Read(nil); err != tcpip.ErrClosedForReceive { - t.Fatalf("Unexpected return from Read: got %v, want %v", err, tcpip.ErrClosedForReceive) + t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrClosedForReceive) } if _, _, err := c.EP.Peek([][]byte{peekBuf}); err != tcpip.ErrClosedForReceive { - t.Fatalf("Unexpected return from Peek: got %v, want %v", err, tcpip.ErrClosedForReceive) + t.Fatalf("got c.EP.Peek(...) = %v, want = %v", err, tcpip.ErrClosedForReceive) } } @@ -2635,9 +2856,8 @@ func TestReusePort(t *testing.T) { if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}, nil); err != nil { t.Fatalf("Bind failed: %v", err) } - err = c.EP.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}) - if err != tcpip.ErrConnectStarted { - t.Fatalf("Unexpected return value from Connect: %v", err) + if err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}); err != tcpip.ErrConnectStarted { + t.Fatalf("got c.EP.Connect(...) = %v, want = %v", err, tcpip.ErrConnectStarted) } c.EP.Close() @@ -2658,8 +2878,7 @@ func TestReusePort(t *testing.T) { if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}, nil); err != nil { t.Fatalf("Bind failed: %v", err) } - err = c.EP.Listen(10) - if err != nil { + if err := c.EP.Listen(10); err != nil { t.Fatalf("Listen failed: %v", err) } c.EP.Close() @@ -2671,8 +2890,7 @@ func TestReusePort(t *testing.T) { if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}, nil); err != nil { t.Fatalf("Bind failed: %v", err) } - err = c.EP.Listen(10) - if err != nil { + if err := c.EP.Listen(10); err != nil { t.Fatalf("Listen failed: %v", err) } } @@ -2686,7 +2904,7 @@ func checkRecvBufferSize(t *testing.T, ep tcpip.Endpoint, v int) { } if int(s) != v { - t.Fatalf("Bad receive buffer size: want=%v, got=%v", v, s) + t.Fatalf("got receive buffer size = %v, want = %v", s, v) } } @@ -2699,7 +2917,7 @@ func checkSendBufferSize(t *testing.T, ep tcpip.Endpoint, v int) { } if int(s) != v { - t.Fatalf("Bad send buffer size: want=%v, got=%v", v, s) + t.Fatalf("got send buffer size = %v, want = %v", s, v) } } @@ -2840,14 +3058,12 @@ func TestSelfConnect(t *testing.T) { wq.EventRegister(&waitEntry, waiter.EventOut) defer wq.EventUnregister(&waitEntry) - err = ep.Connect(tcpip.FullAddress{Addr: context.StackAddr, Port: context.StackPort}) - if err != tcpip.ErrConnectStarted { - t.Fatalf("Unexpected return value from Connect: %v", err) + if err := ep.Connect(tcpip.FullAddress{Addr: context.StackAddr, Port: context.StackPort}); err != tcpip.ErrConnectStarted { + t.Fatalf("got ep.Connect(...) = %v, want = %v", err, tcpip.ErrConnectStarted) } <-notifyCh - err = ep.GetSockOpt(tcpip.ErrorOption{}) - if err != nil { + if err := ep.GetSockOpt(tcpip.ErrorOption{}); err != nil { t.Fatalf("Connect failed: %v", err) } @@ -2855,7 +3071,7 @@ func TestSelfConnect(t *testing.T) { data := []byte{1, 2, 3} view := buffer.NewView(len(data)) copy(view, data) - if _, err = ep.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { + if _, err := ep.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { t.Fatalf("Write failed: %v", err) } @@ -2874,8 +3090,8 @@ func TestSelfConnect(t *testing.T) { } } - if bytes.Compare(data, rd) != 0 { - t.Fatalf("Data is different: want=%v, got=%v", data, rd) + if !bytes.Equal(data, rd) { + t.Fatalf("got data = %v, want = %v", rd, data) } } @@ -2949,16 +3165,16 @@ func TestTCPEndpointProbe(t *testing.T) { // We don't do an extensive validation of every field but a // basic sanity test. if got, want := state.ID.LocalAddress, tcpip.Address(context.StackAddr); got != want { - t.Fatalf("unexpected LocalAddress got: %q, want: %q", got, want) + t.Fatalf("got LocalAddress: %q, want: %q", got, want) } if got, want := state.ID.LocalPort, c.Port; got != want { - t.Fatalf("unexpected LocalPort got: %d, want: %d", got, want) + t.Fatalf("got LocalPort: %d, want: %d", got, want) } if got, want := state.ID.RemoteAddress, tcpip.Address(context.TestAddr); got != want { - t.Fatalf("unexpected RemoteAddress got: %q, want: %q", got, want) + t.Fatalf("got RemoteAddress: %q, want: %q", got, want) } if got, want := state.ID.RemotePort, uint16(context.TestPort); got != want { - t.Fatalf("unexpected RemotePort got: %d, want: %d", got, want) + t.Fatalf("got RemotePort: %d, want: %d", got, want) } invoked <- struct{}{} @@ -3008,7 +3224,7 @@ func TestSetCongestionControl(t *testing.T) { t.Fatalf("s.TransportProtocolOption(%v, %v) = %v", tcp.ProtocolNumber, &cc, err) } if got, want := cc, tc.cc; got != want { - t.Fatalf("unexpected value for congestion control got: %v, want: %v", got, want) + t.Fatalf("got congestion control: %v, want: %v", got, want) } }) } @@ -3026,7 +3242,7 @@ func TestAvailableCongestionControl(t *testing.T) { t.Fatalf("s.TransportProtocolOption(%v, %v) = %v", tcp.ProtocolNumber, &aCC, err) } if got, want := aCC, tcp.AvailableCongestionControlOption("reno cubic"); got != want { - t.Fatalf("unexpected value for AvailableCongestionControlOption: got: %v, want: %v", got, want) + t.Fatalf("got tcp.AvailableCongestionControlOption: %v, want: %v", got, want) } } @@ -3048,7 +3264,7 @@ func TestSetAvailableCongestionControl(t *testing.T) { t.Fatalf("s.TransportProtocolOption(%v, %v) = %v", tcp.ProtocolNumber, &cc, err) } if got, want := cc, tcp.AvailableCongestionControlOption("reno cubic"); got != want { - t.Fatalf("unexpected value for available congestion control got: %v, want: %v", got, want) + t.Fatalf("got tcp.AvailableCongestionControlOption: %v, want: %v", got, want) } } diff --git a/pkg/tcpip/transport/tcp/tcp_timestamp_test.go b/pkg/tcpip/transport/tcp/tcp_timestamp_test.go index a529d9e72..894ead507 100644 --- a/pkg/tcpip/transport/tcp/tcp_timestamp_test.go +++ b/pkg/tcpip/transport/tcp/tcp_timestamp_test.go @@ -267,8 +267,7 @@ func TestSegmentDropWhenTimestampMissing(t *testing.T) { c.WQ.EventRegister(&we, waiter.EventIn) defer c.WQ.EventUnregister(&we) - stk := c.Stack() - droppedPacketsStat := stk.Stats().DroppedPackets + droppedPacketsStat := c.Stack().Stats().DroppedPackets droppedPackets := droppedPacketsStat.Value() data := []byte{1, 2, 3} // Save the sequence number as we will reset it later down diff --git a/pkg/tcpip/transport/tcp/testing/context/context.go b/pkg/tcpip/transport/tcp/testing/context/context.go index 6b5786140..c46af4b8b 100644 --- a/pkg/tcpip/transport/tcp/testing/context/context.go +++ b/pkg/tcpip/transport/tcp/testing/context/context.go @@ -295,9 +295,8 @@ func (c *Context) SendICMPPacket(typ header.ICMPv4Type, code uint8, p1, p2 []byt c.linkEP.Inject(ipv4.ProtocolNumber, &vv) } -// SendPacket builds and sends a TCP segment(with the provided payload & TCP -// headers) in an IPv4 packet via the link layer endpoint. -func (c *Context) SendPacket(payload []byte, h *Headers) { +// BuildSegment builds a TCP segment based on the given Headers and payload. +func (c *Context) BuildSegment(payload []byte, h *Headers) buffer.VectorisedView { // Allocate a buffer for data and headers. buf := buffer.NewView(header.TCPMinimumSize + header.IPv4MinimumSize + len(h.TCPOpts) + len(payload)) copy(buf[len(buf)-len(payload):], payload) @@ -340,6 +339,20 @@ func (c *Context) SendPacket(payload []byte, h *Headers) { // Inject packet. var views [1]buffer.View vv := buf.ToVectorisedView(views) + + return vv +} + +// SendSegment sends a TCP segment that has already been built and written to a +// buffer.VectorisedView. +func (c *Context) SendSegment(s *buffer.VectorisedView) { + c.linkEP.Inject(ipv4.ProtocolNumber, s) +} + +// SendPacket builds and sends a TCP segment(with the provided payload & TCP +// headers) in an IPv4 packet via the link layer endpoint. +func (c *Context) SendPacket(payload []byte, h *Headers) { + vv := c.BuildSegment(payload, h) c.linkEP.Inject(ipv4.ProtocolNumber, &vv) } diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go index b2d7f9779..6a12c2f08 100644 --- a/pkg/tcpip/transport/udp/endpoint.go +++ b/pkg/tcpip/transport/udp/endpoint.go @@ -327,7 +327,10 @@ func (e *endpoint) Write(p tcpip.Payload, opts tcpip.WriteOptions) (uintptr, *tc if err != nil { return 0, err } - sendUDP(route, v, e.id.LocalPort, dstPort) + + if err := sendUDP(route, v, e.id.LocalPort, dstPort); err != nil { + return 0, err + } return uintptr(len(v)), nil } @@ -447,6 +450,9 @@ func sendUDP(r *stack.Route, data buffer.View, localPort, remotePort uint16) *tc udp.SetChecksum(^udp.CalculateChecksum(xsum, length)) } + // Track count of packets sent. + r.Stats().UDP.PacketsSent.Increment() + return r.WritePacket(&hdr, data, ProtocolNumber) } @@ -758,15 +764,18 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, vv hdr := header.UDP(vv.First()) if int(hdr.Length()) > vv.Size() { // Malformed packet. + e.stack.Stats().UDP.MalformedPacketsReceived.Increment() return } vv.TrimFront(header.UDPMinimumSize) e.rcvMu.Lock() + e.stack.Stats().UDP.PacketsReceived.Increment() // Drop the packet if our buffer is currently full. if !e.rcvReady || e.rcvClosed || e.rcvBufSize >= e.rcvBufSizeMax { + e.stack.Stats().UDP.ReceiveBufferErrors.Increment() e.rcvMu.Unlock() return } diff --git a/pkg/tcpip/transport/udp/udp_test.go b/pkg/tcpip/transport/udp/udp_test.go index 7203d7705..c1c099900 100644 --- a/pkg/tcpip/transport/udp/udp_test.go +++ b/pkg/tcpip/transport/udp/udp_test.go @@ -661,3 +661,41 @@ func TestV4WriteOnConnected(t *testing.T) { c.t.Fatalf("Bad payload: got %x, want %x", udp.Payload(), payload) } } + +func TestReadIncrementsPacketsReceived(t *testing.T) { + c := newDualTestContext(t, defaultMTU) + defer c.cleanup() + + // Create IPv4 UDP endpoint + var err *tcpip.Error + c.ep, err = c.s.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, &c.wq) + if err != nil { + c.t.Fatalf("NewEndpoint failed: %v", err) + } + + // Bind to wildcard. + if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}, nil); err != nil { + c.t.Fatalf("Bind failed: %v", err) + } + + testV4Read(c) + + var want uint64 = 1 + if got := c.s.Stats().UDP.PacketsReceived.Value(); got != want { + c.t.Fatalf("Read did not increment PacketsReceived: got %v, want %v", got, want) + } +} + +func TestWriteIncrementsPacketsSent(t *testing.T) { + c := newDualTestContext(t, defaultMTU) + defer c.cleanup() + + c.createV6Endpoint(false) + + testDualWrite(c) + + var want uint64 = 2 + if got := c.s.Stats().UDP.PacketsSent.Value(); got != want { + c.t.Fatalf("Write did not increment PacketsSent: got %v, want %v", got, want) + } +} |