diff options
Diffstat (limited to 'pkg/tcpip')
-rw-r--r-- | pkg/tcpip/link/fdbased/endpoint.go | 14 | ||||
-rw-r--r-- | pkg/tcpip/network/ipv4/ipv4.go | 12 | ||||
-rw-r--r-- | pkg/tcpip/network/ipv6/ipv6.go | 20 | ||||
-rw-r--r-- | pkg/tcpip/stack/packet_buffer.go | 2 | ||||
-rw-r--r-- | pkg/tcpip/tests/integration/iptables_test.go | 285 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/accept.go | 14 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/connect.go | 22 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/tcp_test.go | 48 |
8 files changed, 337 insertions, 80 deletions
diff --git a/pkg/tcpip/link/fdbased/endpoint.go b/pkg/tcpip/link/fdbased/endpoint.go index bddb1d0a2..735c28da1 100644 --- a/pkg/tcpip/link/fdbased/endpoint.go +++ b/pkg/tcpip/link/fdbased/endpoint.go @@ -41,7 +41,6 @@ package fdbased import ( "fmt" - "math" "sync/atomic" "golang.org/x/sys/unix" @@ -196,8 +195,12 @@ type Options struct { // option for an FD with a fanoutID already in use by another FD for a different // NIC will return an EINVAL. // +// Since fanoutID must be unique within the network namespace, we start with +// the PID to avoid collisions. The only way to be sure of avoiding collisions +// is to run in a new network namespace. +// // Must be accessed using atomic operations. -var fanoutID int32 = 0 +var fanoutID int32 = int32(unix.Getpid()) // New creates a new fd-based endpoint. // @@ -292,11 +295,6 @@ func createInboundDispatcher(e *endpoint, fd int, isSocket bool, fID int32) (lin } switch sa.(type) { case *unix.SockaddrLinklayer: - // See: PACKET_FANOUT_MAX in net/packet/internal.h - const packetFanoutMax = 1 << 16 - if fID > packetFanoutMax { - return nil, fmt.Errorf("host fanoutID limit exceeded, fanoutID must be <= %d", math.MaxUint16) - } // Enable PACKET_FANOUT mode if the underlying socket is of type // AF_PACKET. We do not enable PACKET_FANOUT_FLAG_DEFRAG as that will // prevent gvisor from receiving fragmented packets and the host does the @@ -317,7 +315,7 @@ func createInboundDispatcher(e *endpoint, fd int, isSocket bool, fID int32) (lin // // See: https://github.com/torvalds/linux/blob/7acac4b3196caee5e21fb5ea53f8bc124e6a16fc/net/packet/af_packet.c#L3881 const fanoutType = unix.PACKET_FANOUT_HASH - fanoutArg := int(fID) | fanoutType<<16 + fanoutArg := (int(fID) & 0xffff) | fanoutType<<16 if err := unix.SetsockoptInt(fd, unix.SOL_PACKET, unix.PACKET_FANOUT, fanoutArg); err != nil { return nil, fmt.Errorf("failed to enable PACKET_FANOUT option: %v", err) } diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go index 6bee55634..c99297a51 100644 --- a/pkg/tcpip/network/ipv4/ipv4.go +++ b/pkg/tcpip/network/ipv4/ipv4.go @@ -720,7 +720,8 @@ func (e *endpoint) forwardPacket(pkt *stack.PacketBuffer) ip.ForwardingError { return nil } - ep.handleValidatedPacket(h, pkt) + // The packet originally arrived on e so provide its NIC as the input NIC. + ep.handleValidatedPacket(h, pkt, e.nic.Name() /* inNICName */) return nil } @@ -836,7 +837,7 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) { } } - e.handleValidatedPacket(h, pkt) + e.handleValidatedPacket(h, pkt, e.nic.Name() /* inNICName */) } // handleLocalPacket is like HandlePacket except it does not perform the @@ -855,10 +856,10 @@ func (e *endpoint) handleLocalPacket(pkt *stack.PacketBuffer, canSkipRXChecksum return } - e.handleValidatedPacket(h, pkt) + e.handleValidatedPacket(h, pkt, e.nic.Name() /* inNICName */) } -func (e *endpoint) handleValidatedPacket(h header.IPv4, pkt *stack.PacketBuffer) { +func (e *endpoint) handleValidatedPacket(h header.IPv4, pkt *stack.PacketBuffer, inNICName string) { pkt.NICID = e.nic.ID() stats := e.stats stats.ip.ValidPacketsReceived.Increment() @@ -920,8 +921,7 @@ func (e *endpoint) handleValidatedPacket(h header.IPv4, pkt *stack.PacketBuffer) // iptables filtering. All packets that reach here are intended for // this machine and will not be forwarded. - inNicName := e.protocol.stack.FindNICNameFromID(e.nic.ID()) - if ok := e.protocol.stack.IPTables().Check(stack.Input, pkt, nil, "" /* preroutingAddr */, inNicName, "" /* outNicName */); !ok { + if ok := e.protocol.stack.IPTables().Check(stack.Input, pkt, nil, "" /* preroutingAddr */, inNICName, "" /* outNicName */); !ok { // iptables is telling us to drop the packet. stats.ip.IPTablesInputDropped.Increment() return diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go index 6103574f7..12763add6 100644 --- a/pkg/tcpip/network/ipv6/ipv6.go +++ b/pkg/tcpip/network/ipv6/ipv6.go @@ -991,7 +991,8 @@ func (e *endpoint) forwardPacket(pkt *stack.PacketBuffer) ip.ForwardingError { return nil } - ep.handleValidatedPacket(h, pkt) + // The packet originally arrived on e so provide its NIC as the input NIC. + ep.handleValidatedPacket(h, pkt, e.nic.Name() /* inNICName */) return nil } @@ -1104,7 +1105,7 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) { } } - e.handleValidatedPacket(h, pkt) + e.handleValidatedPacket(h, pkt, e.nic.Name() /* inNICName */) } // handleLocalPacket is like HandlePacket except it does not perform the @@ -1123,10 +1124,10 @@ func (e *endpoint) handleLocalPacket(pkt *stack.PacketBuffer, canSkipRXChecksum return } - e.handleValidatedPacket(h, pkt) + e.handleValidatedPacket(h, pkt, e.nic.Name() /* inNICName */) } -func (e *endpoint) handleValidatedPacket(h header.IPv6, pkt *stack.PacketBuffer) { +func (e *endpoint) handleValidatedPacket(h header.IPv6, pkt *stack.PacketBuffer, inNICName string) { pkt.NICID = e.nic.ID() stats := e.stats.ip stats.ValidPacketsReceived.Increment() @@ -1175,8 +1176,7 @@ func (e *endpoint) handleValidatedPacket(h header.IPv6, pkt *stack.PacketBuffer) // iptables filtering. All packets that reach here are intended for // this machine and need not be forwarded. - inNicName := e.protocol.stack.FindNICNameFromID(e.nic.ID()) - if ok := e.protocol.stack.IPTables().Check(stack.Input, pkt, nil, "" /* preroutingAddr */, inNicName, "" /* outNicName */); !ok { + if ok := e.protocol.stack.IPTables().Check(stack.Input, pkt, nil, "" /* preroutingAddr */, inNICName, "" /* outNicName */); !ok { // iptables is telling us to drop the packet. stats.IPTablesInputDropped.Increment() return @@ -1627,8 +1627,8 @@ func (e *endpoint) NetworkProtocolNumber() tcpip.NetworkProtocolNumber { func (e *endpoint) AddAndAcquirePermanentAddress(addr tcpip.AddressWithPrefix, peb stack.PrimaryEndpointBehavior, configType stack.AddressConfigType, deprecated bool) (stack.AddressEndpoint, tcpip.Error) { // TODO(b/169350103): add checks here after making sure we no longer receive // an empty address. - e.mu.RLock() - defer e.mu.RUnlock() + e.mu.Lock() + defer e.mu.Unlock() return e.addAndAcquirePermanentAddressLocked(addr, peb, configType, deprecated) } @@ -1669,8 +1669,8 @@ func (e *endpoint) addAndAcquirePermanentAddressLocked(addr tcpip.AddressWithPre // RemovePermanentAddress implements stack.AddressableEndpoint. func (e *endpoint) RemovePermanentAddress(addr tcpip.Address) tcpip.Error { - e.mu.RLock() - defer e.mu.RUnlock() + e.mu.Lock() + defer e.mu.Unlock() addressEndpoint := e.getAddressRLocked(addr) if addressEndpoint == nil || !addressEndpoint.GetKind().IsPermanent() { diff --git a/pkg/tcpip/stack/packet_buffer.go b/pkg/tcpip/stack/packet_buffer.go index 4ca702121..9192d8433 100644 --- a/pkg/tcpip/stack/packet_buffer.go +++ b/pkg/tcpip/stack/packet_buffer.go @@ -134,7 +134,7 @@ type PacketBuffer struct { // https://www.man7.org/linux/man-pages/man7/packet.7.html. PktType tcpip.PacketType - // NICID is the ID of the interface the network packet was received at. + // NICID is the ID of the last interface the network packet was handled at. NICID tcpip.NICID // RXTransportChecksumValidated indicates that transport checksum verification diff --git a/pkg/tcpip/tests/integration/iptables_test.go b/pkg/tcpip/tests/integration/iptables_test.go index 07ba2b837..f9ab7d0af 100644 --- a/pkg/tcpip/tests/integration/iptables_test.go +++ b/pkg/tcpip/tests/integration/iptables_test.go @@ -166,7 +166,7 @@ func TestIPTablesStatsForInput(t *testing.T) { // Make sure the packet is not dropped by the next rule. filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{} if err := ipt.ReplaceTable(stack.FilterID, filter, true /* ipv6 */); err != nil { - t.Fatalf("ipt.RelaceTable(%d, _, %t): %s", stack.FilterID, true, err) + t.Fatalf("ipt.ReplaceTable(%d, _, %t): %s", stack.FilterID, true, err) } }, genPacket: genPacketV6, @@ -187,7 +187,7 @@ func TestIPTablesStatsForInput(t *testing.T) { filter.Rules[ruleIdx].Matchers = []stack.Matcher{&inputIfNameMatcher{nicName}} filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{} if err := ipt.ReplaceTable(stack.FilterID, filter, false /* ipv6 */); err != nil { - t.Fatalf("ipt.RelaceTable(%d, _, %t): %s", stack.FilterID, false, err) + t.Fatalf("ipt.ReplaceTable(%d, _, %t): %s", stack.FilterID, false, err) } }, genPacket: genPacketV4, @@ -207,7 +207,7 @@ func TestIPTablesStatsForInput(t *testing.T) { filter.Rules[ruleIdx].Target = &stack.DropTarget{} filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{} if err := ipt.ReplaceTable(stack.FilterID, filter, true /* ipv6 */); err != nil { - t.Fatalf("ipt.RelaceTable(%d, _, %t): %s", stack.FilterID, true, err) + t.Fatalf("ipt.ReplaceTable(%d, _, %t): %s", stack.FilterID, true, err) } }, genPacket: genPacketV6, @@ -227,7 +227,7 @@ func TestIPTablesStatsForInput(t *testing.T) { filter.Rules[ruleIdx].Target = &stack.DropTarget{} filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{} if err := ipt.ReplaceTable(stack.FilterID, filter, false /* ipv6 */); err != nil { - t.Fatalf("ipt.RelaceTable(%d, _, %t): %s", stack.FilterID, false, err) + t.Fatalf("ipt.ReplaceTable(%d, _, %t): %s", stack.FilterID, false, err) } }, genPacket: genPacketV4, @@ -250,7 +250,7 @@ func TestIPTablesStatsForInput(t *testing.T) { filter.Rules[ruleIdx].Target = &stack.DropTarget{} filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{} if err := ipt.ReplaceTable(stack.FilterID, filter, true /* ipv6 */); err != nil { - t.Fatalf("ipt.RelaceTable(%d, _, %t): %s", stack.FilterID, true, err) + t.Fatalf("ipt.ReplaceTable(%d, _, %t): %s", stack.FilterID, true, err) } }, genPacket: genPacketV6, @@ -273,7 +273,7 @@ func TestIPTablesStatsForInput(t *testing.T) { filter.Rules[ruleIdx].Target = &stack.DropTarget{} filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{} if err := ipt.ReplaceTable(stack.FilterID, filter, false /* ipv6 */); err != nil { - t.Fatalf("ipt.RelaceTable(%d, _, %t): %s", stack.FilterID, false, err) + t.Fatalf("ipt.ReplaceTable(%d, _, %t): %s", stack.FilterID, false, err) } }, genPacket: genPacketV4, @@ -293,7 +293,7 @@ func TestIPTablesStatsForInput(t *testing.T) { filter.Rules[ruleIdx].Matchers = []stack.Matcher{&inputIfNameMatcher{anotherNicName}} filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{} if err := ipt.ReplaceTable(stack.FilterID, filter, true /* ipv6 */); err != nil { - t.Fatalf("ipt.RelaceTable(%d, _, %t): %s", stack.FilterID, true, err) + t.Fatalf("ipt.ReplaceTable(%d, _, %t): %s", stack.FilterID, true, err) } }, genPacket: genPacketV6, @@ -313,7 +313,7 @@ func TestIPTablesStatsForInput(t *testing.T) { filter.Rules[ruleIdx].Matchers = []stack.Matcher{&inputIfNameMatcher{anotherNicName}} filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{} if err := ipt.ReplaceTable(stack.FilterID, filter, false /* ipv6 */); err != nil { - t.Fatalf("ipt.RelaceTable(%d, _, %t): %s", stack.FilterID, false, err) + t.Fatalf("ipt.ReplaceTable(%d, _, %t): %s", stack.FilterID, false, err) } }, genPacket: genPacketV4, @@ -465,7 +465,7 @@ func TestIPTableWritePackets(t *testing.T) { } if err := s.IPTables().ReplaceTable(stack.FilterID, table, false /* ipv4 */); err != nil { - t.Fatalf("RelaceTable(%d, _, false): %s", stack.FilterID, err) + t.Fatalf("ReplaceTable(%d, _, false): %s", stack.FilterID, err) } }, genPacket: func(r *stack.Route) stack.PacketBufferList { @@ -556,7 +556,7 @@ func TestIPTableWritePackets(t *testing.T) { } if err := s.IPTables().ReplaceTable(stack.FilterID, table, true /* ipv6 */); err != nil { - t.Fatalf("RelaceTable(%d, _, true): %s", stack.FilterID, err) + t.Fatalf("ReplaceTable(%d, _, true): %s", stack.FilterID, err) } }, genPacket: func(r *stack.Route) stack.PacketBufferList { @@ -681,6 +681,32 @@ func forwardedICMPv6EchoReplyChecker(t *testing.T, b []byte, src, dst tcpip.Addr checker.ICMPv6Type(header.ICMPv6EchoReply))) } +func boolToInt(v bool) uint64 { + if v { + return 1 + } + return 0 +} + +func setupDropFilter(hook stack.Hook, f stack.IPHeaderFilter) func(*testing.T, *stack.Stack, tcpip.NetworkProtocolNumber) { + return func(t *testing.T, s *stack.Stack, netProto tcpip.NetworkProtocolNumber) { + t.Helper() + + ipv6 := netProto == ipv6.ProtocolNumber + + ipt := s.IPTables() + filter := ipt.GetTable(stack.FilterID, ipv6) + ruleIdx := filter.BuiltinChains[hook] + filter.Rules[ruleIdx].Filter = f + filter.Rules[ruleIdx].Target = &stack.DropTarget{NetworkProtocol: netProto} + // Make sure the packet is not dropped by the next rule. + filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{NetworkProtocol: netProto} + if err := ipt.ReplaceTable(stack.FilterID, filter, ipv6); err != nil { + t.Fatalf("ipt.ReplaceTable(%d, _, %t): %s", stack.FilterID, ipv6, err) + } + } +} + func TestForwardingHook(t *testing.T) { const ( nicID1 = 1 @@ -740,32 +766,6 @@ func TestForwardingHook(t *testing.T) { }, } - setupDropFilter := func(f stack.IPHeaderFilter) func(*testing.T, *stack.Stack, tcpip.NetworkProtocolNumber) { - return func(t *testing.T, s *stack.Stack, netProto tcpip.NetworkProtocolNumber) { - t.Helper() - - ipv6 := netProto == ipv6.ProtocolNumber - - ipt := s.IPTables() - filter := ipt.GetTable(stack.FilterID, ipv6) - ruleIdx := filter.BuiltinChains[stack.Forward] - filter.Rules[ruleIdx].Filter = f - filter.Rules[ruleIdx].Target = &stack.DropTarget{NetworkProtocol: netProto} - // Make sure the packet is not dropped by the next rule. - filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{NetworkProtocol: netProto} - if err := ipt.ReplaceTable(stack.FilterID, filter, ipv6); err != nil { - t.Fatalf("ipt.RelaceTable(%d, _, %t): %s", stack.FilterID, ipv6, err) - } - } - } - - boolToInt := func(v bool) uint64 { - if v { - return 1 - } - return 0 - } - subTests := []struct { name string setupFilter func(*testing.T, *stack.Stack, tcpip.NetworkProtocolNumber) @@ -779,59 +779,59 @@ func TestForwardingHook(t *testing.T) { { name: "Drop", - setupFilter: setupDropFilter(stack.IPHeaderFilter{}), + setupFilter: setupDropFilter(stack.Forward, stack.IPHeaderFilter{}), expectForward: false, }, { name: "Drop with input NIC filtering", - setupFilter: setupDropFilter(stack.IPHeaderFilter{InputInterface: nic1Name}), + setupFilter: setupDropFilter(stack.Forward, stack.IPHeaderFilter{InputInterface: nic1Name}), expectForward: false, }, { name: "Drop with output NIC filtering", - setupFilter: setupDropFilter(stack.IPHeaderFilter{OutputInterface: nic2Name}), + setupFilter: setupDropFilter(stack.Forward, stack.IPHeaderFilter{OutputInterface: nic2Name}), expectForward: false, }, { name: "Drop with input and output NIC filtering", - setupFilter: setupDropFilter(stack.IPHeaderFilter{InputInterface: nic1Name, OutputInterface: nic2Name}), + setupFilter: setupDropFilter(stack.Forward, stack.IPHeaderFilter{InputInterface: nic1Name, OutputInterface: nic2Name}), expectForward: false, }, { name: "Drop with other input NIC filtering", - setupFilter: setupDropFilter(stack.IPHeaderFilter{InputInterface: otherNICName}), + setupFilter: setupDropFilter(stack.Forward, stack.IPHeaderFilter{InputInterface: otherNICName}), expectForward: true, }, { name: "Drop with other output NIC filtering", - setupFilter: setupDropFilter(stack.IPHeaderFilter{OutputInterface: otherNICName}), + setupFilter: setupDropFilter(stack.Forward, stack.IPHeaderFilter{OutputInterface: otherNICName}), expectForward: true, }, { name: "Drop with other input and output NIC filtering", - setupFilter: setupDropFilter(stack.IPHeaderFilter{InputInterface: otherNICName, OutputInterface: nic2Name}), + setupFilter: setupDropFilter(stack.Forward, stack.IPHeaderFilter{InputInterface: otherNICName, OutputInterface: nic2Name}), expectForward: true, }, { name: "Drop with input and other output NIC filtering", - setupFilter: setupDropFilter(stack.IPHeaderFilter{InputInterface: nic1Name, OutputInterface: otherNICName}), + setupFilter: setupDropFilter(stack.Forward, stack.IPHeaderFilter{InputInterface: nic1Name, OutputInterface: otherNICName}), expectForward: true, }, { name: "Drop with other input and other output NIC filtering", - setupFilter: setupDropFilter(stack.IPHeaderFilter{InputInterface: otherNICName, OutputInterface: otherNICName}), + setupFilter: setupDropFilter(stack.Forward, stack.IPHeaderFilter{InputInterface: otherNICName, OutputInterface: otherNICName}), expectForward: true, }, { name: "Drop with inverted input NIC filtering", - setupFilter: setupDropFilter(stack.IPHeaderFilter{InputInterface: nic1Name, InputInterfaceInvert: true}), + setupFilter: setupDropFilter(stack.Forward, stack.IPHeaderFilter{InputInterface: nic1Name, InputInterfaceInvert: true}), expectForward: true, }, { name: "Drop with inverted output NIC filtering", - setupFilter: setupDropFilter(stack.IPHeaderFilter{OutputInterface: nic2Name, OutputInterfaceInvert: true}), + setupFilter: setupDropFilter(stack.Forward, stack.IPHeaderFilter{OutputInterface: nic2Name, OutputInterfaceInvert: true}), expectForward: true, }, } @@ -941,3 +941,194 @@ func TestForwardingHook(t *testing.T) { }) } } + +func TestInputHookWithLocalForwarding(t *testing.T) { + const ( + nicID1 = 1 + nicID2 = 2 + + nic1Name = "nic1" + nic2Name = "nic2" + + otherNICName = "otherNIC" + ) + + tests := []struct { + name string + netProto tcpip.NetworkProtocolNumber + rx func(*channel.Endpoint) + checker func(*testing.T, []byte) + }{ + { + name: "IPv4", + netProto: ipv4.ProtocolNumber, + rx: func(e *channel.Endpoint) { + utils.RxICMPv4EchoRequest(e, utils.RemoteIPv4Addr, utils.Ipv4Addr2.AddressWithPrefix.Address, ttl) + }, + checker: func(t *testing.T, b []byte) { + checker.IPv4(t, b, + checker.SrcAddr(utils.Ipv4Addr2.AddressWithPrefix.Address), + checker.DstAddr(utils.RemoteIPv4Addr), + checker.ICMPv4( + checker.ICMPv4Type(header.ICMPv4EchoReply))) + }, + }, + { + name: "IPv6", + netProto: ipv6.ProtocolNumber, + rx: func(e *channel.Endpoint) { + utils.RxICMPv6EchoRequest(e, utils.RemoteIPv6Addr, utils.Ipv6Addr2.AddressWithPrefix.Address, ttl) + }, + checker: func(t *testing.T, b []byte) { + checker.IPv6(t, b, + checker.SrcAddr(utils.Ipv6Addr2.AddressWithPrefix.Address), + checker.DstAddr(utils.RemoteIPv6Addr), + checker.ICMPv6( + checker.ICMPv6Type(header.ICMPv6EchoReply))) + }, + }, + } + + subTests := []struct { + name string + setupFilter func(*testing.T, *stack.Stack, tcpip.NetworkProtocolNumber) + expectDrop bool + }{ + { + name: "Accept", + setupFilter: func(*testing.T, *stack.Stack, tcpip.NetworkProtocolNumber) { /* no filter */ }, + expectDrop: false, + }, + + { + name: "Drop", + setupFilter: setupDropFilter(stack.Input, stack.IPHeaderFilter{}), + expectDrop: true, + }, + { + name: "Drop with input NIC filtering on arrival NIC", + setupFilter: setupDropFilter(stack.Input, stack.IPHeaderFilter{InputInterface: nic1Name}), + expectDrop: true, + }, + { + name: "Drop with input NIC filtering on delivered NIC", + setupFilter: setupDropFilter(stack.Input, stack.IPHeaderFilter{InputInterface: nic2Name}), + expectDrop: false, + }, + + { + name: "Drop with input NIC filtering on other NIC", + setupFilter: setupDropFilter(stack.Input, stack.IPHeaderFilter{InputInterface: otherNICName}), + expectDrop: false, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + for _, subTest := range subTests { + t.Run(subTest.name, func(t *testing.T) { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol}, + }) + + subTest.setupFilter(t, s, test.netProto) + + e1 := channel.New(1, header.IPv6MinimumMTU, "") + if err := s.CreateNICWithOptions(nicID1, e1, stack.NICOptions{Name: nic1Name}); err != nil { + t.Fatalf("s.CreateNICWithOptions(%d, _, _): %s", nicID1, err) + } + if err := s.AddProtocolAddress(nicID1, utils.Ipv4Addr1); err != nil { + t.Fatalf("s.AddProtocolAddress(%d, %#v): %s", nicID1, utils.Ipv4Addr1, err) + } + if err := s.AddProtocolAddress(nicID1, utils.Ipv6Addr1); err != nil { + t.Fatalf("s.AddProtocolAddress(%d, %#v): %s", nicID1, utils.Ipv6Addr1, err) + } + + e2 := channel.New(1, header.IPv6MinimumMTU, "") + if err := s.CreateNICWithOptions(nicID2, e2, stack.NICOptions{Name: nic2Name}); err != nil { + t.Fatalf("s.CreateNICWithOptions(%d, _, _): %s", nicID2, err) + } + if err := s.AddProtocolAddress(nicID2, utils.Ipv4Addr2); err != nil { + t.Fatalf("s.AddProtocolAddress(%d, %#v): %s", nicID2, utils.Ipv4Addr2, err) + } + if err := s.AddProtocolAddress(nicID2, utils.Ipv6Addr2); err != nil { + t.Fatalf("s.AddProtocolAddress(%d, %#v): %s", nicID2, utils.Ipv6Addr2, err) + } + + if err := s.SetForwardingDefaultAndAllNICs(ipv4.ProtocolNumber, true); err != nil { + t.Fatalf("s.SetForwardingDefaultAndAllNICs(%d, true): %s", ipv4.ProtocolNumber, err) + } + if err := s.SetForwardingDefaultAndAllNICs(ipv6.ProtocolNumber, true); err != nil { + t.Fatalf("s.SetForwardingDefaultAndAllNICs(%d, true): %s", ipv6.ProtocolNumber, err) + } + + s.SetRouteTable([]tcpip.Route{ + { + Destination: header.IPv4EmptySubnet, + NIC: nicID1, + }, + { + Destination: header.IPv6EmptySubnet, + NIC: nicID1, + }, + }) + + test.rx(e1) + + ep1, err := s.GetNetworkEndpoint(nicID1, test.netProto) + if err != nil { + t.Fatalf("s.GetNetworkEndpoint(%d, %d): %s", nicID1, test.netProto, err) + } + ep1Stats := ep1.Stats() + ipEP1Stats, ok := ep1Stats.(stack.IPNetworkEndpointStats) + if !ok { + t.Fatalf("got ep1Stats = %T, want = stack.IPNetworkEndpointStats", ep1Stats) + } + ip1Stats := ipEP1Stats.IPStats() + + if got := ip1Stats.PacketsReceived.Value(); got != 1 { + t.Errorf("got ip1Stats.PacketsReceived.Value() = %d, want = 1", got) + } + if got := ip1Stats.ValidPacketsReceived.Value(); got != 1 { + t.Errorf("got ip1Stats.ValidPacketsReceived.Value() = %d, want = 1", got) + } + if got, want := ip1Stats.PacketsSent.Value(), boolToInt(!subTest.expectDrop); got != want { + t.Errorf("got ip1Stats.PacketsSent.Value() = %d, want = %d", got, want) + } + + ep2, err := s.GetNetworkEndpoint(nicID2, test.netProto) + if err != nil { + t.Fatalf("s.GetNetworkEndpoint(%d, %d): %s", nicID2, test.netProto, err) + } + ep2Stats := ep2.Stats() + ipEP2Stats, ok := ep2Stats.(stack.IPNetworkEndpointStats) + if !ok { + t.Fatalf("got ep2Stats = %T, want = stack.IPNetworkEndpointStats", ep2Stats) + } + ip2Stats := ipEP2Stats.IPStats() + if got := ip2Stats.PacketsReceived.Value(); got != 0 { + t.Errorf("got ip2Stats.PacketsReceived.Value() = %d, want = 0", got) + } + if got := ip2Stats.ValidPacketsReceived.Value(); got != 1 { + t.Errorf("got ip2Stats.ValidPacketsReceived.Value() = %d, want = 1", got) + } + if got, want := ip2Stats.IPTablesInputDropped.Value(), boolToInt(subTest.expectDrop); got != want { + t.Errorf("got ip2Stats.IPTablesInputDropped.Value() = %d, want = %d", got, want) + } + if got := ip2Stats.PacketsSent.Value(); got != 0 { + t.Errorf("got ip2Stats.PacketsSent.Value() = %d, want = 0", got) + } + + if p, ok := e1.Read(); ok == subTest.expectDrop { + t.Errorf("got e1.Read() = (%#v, %t), want = (_, %t)", p, ok, !subTest.expectDrop) + } else if !subTest.expectDrop { + test.checker(t, stack.PayloadSince(p.Pkt.NetworkHeader())) + } + if p, ok := e2.Read(); ok { + t.Errorf("got e1.Read() = (%#v, true), want = (_, false)", p) + } + }) + } + }) + } +} diff --git a/pkg/tcpip/transport/tcp/accept.go b/pkg/tcpip/transport/tcp/accept.go index 2c65b737d..d807b13b7 100644 --- a/pkg/tcpip/transport/tcp/accept.go +++ b/pkg/tcpip/transport/tcp/accept.go @@ -560,6 +560,10 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) tcpip.Err } switch { + case s.flags.Contains(header.TCPFlagRst): + e.stack.Stats().DroppedPackets.Increment() + return nil + case s.flags == header.TCPFlagSyn: if e.acceptQueueIsFull() { e.stack.Stats().TCP.ListenOverflowSynDrop.Increment() @@ -611,7 +615,7 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) tcpip.Err e.stack.Stats().TCP.ListenOverflowSynCookieSent.Increment() return nil - case (s.flags & header.TCPFlagAck) != 0: + case s.flags.Contains(header.TCPFlagAck): if e.acceptQueueIsFull() { // Silently drop the ack as the application can't accept // the connection at this point. The ack will be @@ -736,6 +740,13 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) tcpip.Err mss: rcvdSynOptions.MSS, }) + // Requeue the segment if the ACK completing the handshake has more info + // to be procesed by the newly established endpoint. + if (s.flags.Contains(header.TCPFlagFin) || s.data.Size() > 0) && n.enqueueSegment(s) { + s.incRef() + n.newSegmentWaker.Assert() + } + // Do the delivery in a separate goroutine so // that we don't block the listen loop in case // the application is slow to accept or stops @@ -753,6 +764,7 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) tcpip.Err return nil default: + e.stack.Stats().DroppedPackets.Increment() return nil } } diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go index 570e5081c..2137ebc25 100644 --- a/pkg/tcpip/transport/tcp/connect.go +++ b/pkg/tcpip/transport/tcp/connect.go @@ -406,11 +406,11 @@ func (h *handshake) synRcvdState(s *segment) tcpip.Error { h.ep.transitionToStateEstablishedLocked(h) - // If the segment has data then requeue it for the receiver - // to process it again once main loop is started. - if s.data.Size() > 0 { + // Requeue the segment if the ACK completing the handshake has more info + // to be procesed by the newly established endpoint. + if (s.flags.Contains(header.TCPFlagFin) || s.data.Size() > 0) && h.ep.enqueueSegment(s) { s.incRef() - h.ep.enqueueSegment(s) + h.ep.newSegmentWaker.Assert() } return nil } @@ -1474,11 +1474,19 @@ func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{ return &tcpip.ErrConnectionReset{} } - if n¬ifyClose != 0 && closeTimer == nil { - if e.EndpointState() == StateFinWait2 && e.closed { + if n¬ifyClose != 0 && e.closed { + switch e.EndpointState() { + case StateEstablished: + // Perform full shutdown if the endpoint is still + // established. This can occur when notifyClose + // was asserted just before becoming established. + e.shutdownLocked(tcpip.ShutdownWrite | tcpip.ShutdownRead) + case StateFinWait2: // The socket has been closed and we are in FIN_WAIT2 // so start the FIN_WAIT2 timer. - closeTimer = e.stack.Clock().AfterFunc(e.tcpLingerTimeout, closeWaker.Assert) + if closeTimer == nil { + closeTimer = e.stack.Clock().AfterFunc(e.tcpLingerTimeout, closeWaker.Assert) + } } } diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go index e7ede7662..9bbe9bc3e 100644 --- a/pkg/tcpip/transport/tcp/tcp_test.go +++ b/pkg/tcpip/transport/tcp/tcp_test.go @@ -6238,6 +6238,54 @@ func TestPassiveFailedConnectionAttemptIncrement(t *testing.T) { } } +func TestListenDropIncrement(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + stats := c.Stack().Stats() + c.Create(-1 /*epRcvBuf*/) + + if err := c.EP.Bind(tcpip.FullAddress{Addr: context.StackAddr, Port: context.StackPort}); err != nil { + t.Fatalf("Bind failed: %s", err) + } + if err := c.EP.Listen(1 /*backlog*/); err != nil { + t.Fatalf("Listen failed: %s", err) + } + + initialDropped := stats.DroppedPackets.Value() + + // Send RST, FIN segments, that are expected to be dropped by the listener. + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagRst, + }) + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagFin, + }) + + // To ensure that the RST, FIN sent earlier are indeed received and ignored + // by the listener, send a SYN and wait for the SYN to be ACKd. + irs := seqnum.Value(context.TestInitialSequenceNumber) + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagSyn, + SeqNum: irs, + }) + checker.IPv4(t, c.GetPacket(), checker.TCP(checker.SrcPort(context.StackPort), + checker.DstPort(context.TestPort), + checker.TCPFlags(header.TCPFlagAck|header.TCPFlagSyn), + checker.TCPAckNum(uint32(irs)+1), + )) + + if got, want := stats.DroppedPackets.Value(), initialDropped+2; got != want { + t.Fatalf("got stats.DroppedPackets.Value() = %d, want = %d", got, want) + } +} + func TestEndpointBindListenAcceptState(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() |