summaryrefslogtreecommitdiffhomepage
path: root/pkg/tcpip
diff options
context:
space:
mode:
Diffstat (limited to 'pkg/tcpip')
-rw-r--r--pkg/tcpip/link/fdbased/endpoint.go14
-rw-r--r--pkg/tcpip/network/ipv4/ipv4.go12
-rw-r--r--pkg/tcpip/network/ipv6/ipv6.go20
-rw-r--r--pkg/tcpip/stack/packet_buffer.go2
-rw-r--r--pkg/tcpip/tests/integration/iptables_test.go285
-rw-r--r--pkg/tcpip/transport/tcp/accept.go14
-rw-r--r--pkg/tcpip/transport/tcp/connect.go22
-rw-r--r--pkg/tcpip/transport/tcp/tcp_test.go48
8 files changed, 337 insertions, 80 deletions
diff --git a/pkg/tcpip/link/fdbased/endpoint.go b/pkg/tcpip/link/fdbased/endpoint.go
index bddb1d0a2..735c28da1 100644
--- a/pkg/tcpip/link/fdbased/endpoint.go
+++ b/pkg/tcpip/link/fdbased/endpoint.go
@@ -41,7 +41,6 @@ package fdbased
import (
"fmt"
- "math"
"sync/atomic"
"golang.org/x/sys/unix"
@@ -196,8 +195,12 @@ type Options struct {
// option for an FD with a fanoutID already in use by another FD for a different
// NIC will return an EINVAL.
//
+// Since fanoutID must be unique within the network namespace, we start with
+// the PID to avoid collisions. The only way to be sure of avoiding collisions
+// is to run in a new network namespace.
+//
// Must be accessed using atomic operations.
-var fanoutID int32 = 0
+var fanoutID int32 = int32(unix.Getpid())
// New creates a new fd-based endpoint.
//
@@ -292,11 +295,6 @@ func createInboundDispatcher(e *endpoint, fd int, isSocket bool, fID int32) (lin
}
switch sa.(type) {
case *unix.SockaddrLinklayer:
- // See: PACKET_FANOUT_MAX in net/packet/internal.h
- const packetFanoutMax = 1 << 16
- if fID > packetFanoutMax {
- return nil, fmt.Errorf("host fanoutID limit exceeded, fanoutID must be <= %d", math.MaxUint16)
- }
// Enable PACKET_FANOUT mode if the underlying socket is of type
// AF_PACKET. We do not enable PACKET_FANOUT_FLAG_DEFRAG as that will
// prevent gvisor from receiving fragmented packets and the host does the
@@ -317,7 +315,7 @@ func createInboundDispatcher(e *endpoint, fd int, isSocket bool, fID int32) (lin
//
// See: https://github.com/torvalds/linux/blob/7acac4b3196caee5e21fb5ea53f8bc124e6a16fc/net/packet/af_packet.c#L3881
const fanoutType = unix.PACKET_FANOUT_HASH
- fanoutArg := int(fID) | fanoutType<<16
+ fanoutArg := (int(fID) & 0xffff) | fanoutType<<16
if err := unix.SetsockoptInt(fd, unix.SOL_PACKET, unix.PACKET_FANOUT, fanoutArg); err != nil {
return nil, fmt.Errorf("failed to enable PACKET_FANOUT option: %v", err)
}
diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go
index 6bee55634..c99297a51 100644
--- a/pkg/tcpip/network/ipv4/ipv4.go
+++ b/pkg/tcpip/network/ipv4/ipv4.go
@@ -720,7 +720,8 @@ func (e *endpoint) forwardPacket(pkt *stack.PacketBuffer) ip.ForwardingError {
return nil
}
- ep.handleValidatedPacket(h, pkt)
+ // The packet originally arrived on e so provide its NIC as the input NIC.
+ ep.handleValidatedPacket(h, pkt, e.nic.Name() /* inNICName */)
return nil
}
@@ -836,7 +837,7 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) {
}
}
- e.handleValidatedPacket(h, pkt)
+ e.handleValidatedPacket(h, pkt, e.nic.Name() /* inNICName */)
}
// handleLocalPacket is like HandlePacket except it does not perform the
@@ -855,10 +856,10 @@ func (e *endpoint) handleLocalPacket(pkt *stack.PacketBuffer, canSkipRXChecksum
return
}
- e.handleValidatedPacket(h, pkt)
+ e.handleValidatedPacket(h, pkt, e.nic.Name() /* inNICName */)
}
-func (e *endpoint) handleValidatedPacket(h header.IPv4, pkt *stack.PacketBuffer) {
+func (e *endpoint) handleValidatedPacket(h header.IPv4, pkt *stack.PacketBuffer, inNICName string) {
pkt.NICID = e.nic.ID()
stats := e.stats
stats.ip.ValidPacketsReceived.Increment()
@@ -920,8 +921,7 @@ func (e *endpoint) handleValidatedPacket(h header.IPv4, pkt *stack.PacketBuffer)
// iptables filtering. All packets that reach here are intended for
// this machine and will not be forwarded.
- inNicName := e.protocol.stack.FindNICNameFromID(e.nic.ID())
- if ok := e.protocol.stack.IPTables().Check(stack.Input, pkt, nil, "" /* preroutingAddr */, inNicName, "" /* outNicName */); !ok {
+ if ok := e.protocol.stack.IPTables().Check(stack.Input, pkt, nil, "" /* preroutingAddr */, inNICName, "" /* outNicName */); !ok {
// iptables is telling us to drop the packet.
stats.ip.IPTablesInputDropped.Increment()
return
diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go
index 6103574f7..12763add6 100644
--- a/pkg/tcpip/network/ipv6/ipv6.go
+++ b/pkg/tcpip/network/ipv6/ipv6.go
@@ -991,7 +991,8 @@ func (e *endpoint) forwardPacket(pkt *stack.PacketBuffer) ip.ForwardingError {
return nil
}
- ep.handleValidatedPacket(h, pkt)
+ // The packet originally arrived on e so provide its NIC as the input NIC.
+ ep.handleValidatedPacket(h, pkt, e.nic.Name() /* inNICName */)
return nil
}
@@ -1104,7 +1105,7 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) {
}
}
- e.handleValidatedPacket(h, pkt)
+ e.handleValidatedPacket(h, pkt, e.nic.Name() /* inNICName */)
}
// handleLocalPacket is like HandlePacket except it does not perform the
@@ -1123,10 +1124,10 @@ func (e *endpoint) handleLocalPacket(pkt *stack.PacketBuffer, canSkipRXChecksum
return
}
- e.handleValidatedPacket(h, pkt)
+ e.handleValidatedPacket(h, pkt, e.nic.Name() /* inNICName */)
}
-func (e *endpoint) handleValidatedPacket(h header.IPv6, pkt *stack.PacketBuffer) {
+func (e *endpoint) handleValidatedPacket(h header.IPv6, pkt *stack.PacketBuffer, inNICName string) {
pkt.NICID = e.nic.ID()
stats := e.stats.ip
stats.ValidPacketsReceived.Increment()
@@ -1175,8 +1176,7 @@ func (e *endpoint) handleValidatedPacket(h header.IPv6, pkt *stack.PacketBuffer)
// iptables filtering. All packets that reach here are intended for
// this machine and need not be forwarded.
- inNicName := e.protocol.stack.FindNICNameFromID(e.nic.ID())
- if ok := e.protocol.stack.IPTables().Check(stack.Input, pkt, nil, "" /* preroutingAddr */, inNicName, "" /* outNicName */); !ok {
+ if ok := e.protocol.stack.IPTables().Check(stack.Input, pkt, nil, "" /* preroutingAddr */, inNICName, "" /* outNicName */); !ok {
// iptables is telling us to drop the packet.
stats.IPTablesInputDropped.Increment()
return
@@ -1627,8 +1627,8 @@ func (e *endpoint) NetworkProtocolNumber() tcpip.NetworkProtocolNumber {
func (e *endpoint) AddAndAcquirePermanentAddress(addr tcpip.AddressWithPrefix, peb stack.PrimaryEndpointBehavior, configType stack.AddressConfigType, deprecated bool) (stack.AddressEndpoint, tcpip.Error) {
// TODO(b/169350103): add checks here after making sure we no longer receive
// an empty address.
- e.mu.RLock()
- defer e.mu.RUnlock()
+ e.mu.Lock()
+ defer e.mu.Unlock()
return e.addAndAcquirePermanentAddressLocked(addr, peb, configType, deprecated)
}
@@ -1669,8 +1669,8 @@ func (e *endpoint) addAndAcquirePermanentAddressLocked(addr tcpip.AddressWithPre
// RemovePermanentAddress implements stack.AddressableEndpoint.
func (e *endpoint) RemovePermanentAddress(addr tcpip.Address) tcpip.Error {
- e.mu.RLock()
- defer e.mu.RUnlock()
+ e.mu.Lock()
+ defer e.mu.Unlock()
addressEndpoint := e.getAddressRLocked(addr)
if addressEndpoint == nil || !addressEndpoint.GetKind().IsPermanent() {
diff --git a/pkg/tcpip/stack/packet_buffer.go b/pkg/tcpip/stack/packet_buffer.go
index 4ca702121..9192d8433 100644
--- a/pkg/tcpip/stack/packet_buffer.go
+++ b/pkg/tcpip/stack/packet_buffer.go
@@ -134,7 +134,7 @@ type PacketBuffer struct {
// https://www.man7.org/linux/man-pages/man7/packet.7.html.
PktType tcpip.PacketType
- // NICID is the ID of the interface the network packet was received at.
+ // NICID is the ID of the last interface the network packet was handled at.
NICID tcpip.NICID
// RXTransportChecksumValidated indicates that transport checksum verification
diff --git a/pkg/tcpip/tests/integration/iptables_test.go b/pkg/tcpip/tests/integration/iptables_test.go
index 07ba2b837..f9ab7d0af 100644
--- a/pkg/tcpip/tests/integration/iptables_test.go
+++ b/pkg/tcpip/tests/integration/iptables_test.go
@@ -166,7 +166,7 @@ func TestIPTablesStatsForInput(t *testing.T) {
// Make sure the packet is not dropped by the next rule.
filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{}
if err := ipt.ReplaceTable(stack.FilterID, filter, true /* ipv6 */); err != nil {
- t.Fatalf("ipt.RelaceTable(%d, _, %t): %s", stack.FilterID, true, err)
+ t.Fatalf("ipt.ReplaceTable(%d, _, %t): %s", stack.FilterID, true, err)
}
},
genPacket: genPacketV6,
@@ -187,7 +187,7 @@ func TestIPTablesStatsForInput(t *testing.T) {
filter.Rules[ruleIdx].Matchers = []stack.Matcher{&inputIfNameMatcher{nicName}}
filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{}
if err := ipt.ReplaceTable(stack.FilterID, filter, false /* ipv6 */); err != nil {
- t.Fatalf("ipt.RelaceTable(%d, _, %t): %s", stack.FilterID, false, err)
+ t.Fatalf("ipt.ReplaceTable(%d, _, %t): %s", stack.FilterID, false, err)
}
},
genPacket: genPacketV4,
@@ -207,7 +207,7 @@ func TestIPTablesStatsForInput(t *testing.T) {
filter.Rules[ruleIdx].Target = &stack.DropTarget{}
filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{}
if err := ipt.ReplaceTable(stack.FilterID, filter, true /* ipv6 */); err != nil {
- t.Fatalf("ipt.RelaceTable(%d, _, %t): %s", stack.FilterID, true, err)
+ t.Fatalf("ipt.ReplaceTable(%d, _, %t): %s", stack.FilterID, true, err)
}
},
genPacket: genPacketV6,
@@ -227,7 +227,7 @@ func TestIPTablesStatsForInput(t *testing.T) {
filter.Rules[ruleIdx].Target = &stack.DropTarget{}
filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{}
if err := ipt.ReplaceTable(stack.FilterID, filter, false /* ipv6 */); err != nil {
- t.Fatalf("ipt.RelaceTable(%d, _, %t): %s", stack.FilterID, false, err)
+ t.Fatalf("ipt.ReplaceTable(%d, _, %t): %s", stack.FilterID, false, err)
}
},
genPacket: genPacketV4,
@@ -250,7 +250,7 @@ func TestIPTablesStatsForInput(t *testing.T) {
filter.Rules[ruleIdx].Target = &stack.DropTarget{}
filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{}
if err := ipt.ReplaceTable(stack.FilterID, filter, true /* ipv6 */); err != nil {
- t.Fatalf("ipt.RelaceTable(%d, _, %t): %s", stack.FilterID, true, err)
+ t.Fatalf("ipt.ReplaceTable(%d, _, %t): %s", stack.FilterID, true, err)
}
},
genPacket: genPacketV6,
@@ -273,7 +273,7 @@ func TestIPTablesStatsForInput(t *testing.T) {
filter.Rules[ruleIdx].Target = &stack.DropTarget{}
filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{}
if err := ipt.ReplaceTable(stack.FilterID, filter, false /* ipv6 */); err != nil {
- t.Fatalf("ipt.RelaceTable(%d, _, %t): %s", stack.FilterID, false, err)
+ t.Fatalf("ipt.ReplaceTable(%d, _, %t): %s", stack.FilterID, false, err)
}
},
genPacket: genPacketV4,
@@ -293,7 +293,7 @@ func TestIPTablesStatsForInput(t *testing.T) {
filter.Rules[ruleIdx].Matchers = []stack.Matcher{&inputIfNameMatcher{anotherNicName}}
filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{}
if err := ipt.ReplaceTable(stack.FilterID, filter, true /* ipv6 */); err != nil {
- t.Fatalf("ipt.RelaceTable(%d, _, %t): %s", stack.FilterID, true, err)
+ t.Fatalf("ipt.ReplaceTable(%d, _, %t): %s", stack.FilterID, true, err)
}
},
genPacket: genPacketV6,
@@ -313,7 +313,7 @@ func TestIPTablesStatsForInput(t *testing.T) {
filter.Rules[ruleIdx].Matchers = []stack.Matcher{&inputIfNameMatcher{anotherNicName}}
filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{}
if err := ipt.ReplaceTable(stack.FilterID, filter, false /* ipv6 */); err != nil {
- t.Fatalf("ipt.RelaceTable(%d, _, %t): %s", stack.FilterID, false, err)
+ t.Fatalf("ipt.ReplaceTable(%d, _, %t): %s", stack.FilterID, false, err)
}
},
genPacket: genPacketV4,
@@ -465,7 +465,7 @@ func TestIPTableWritePackets(t *testing.T) {
}
if err := s.IPTables().ReplaceTable(stack.FilterID, table, false /* ipv4 */); err != nil {
- t.Fatalf("RelaceTable(%d, _, false): %s", stack.FilterID, err)
+ t.Fatalf("ReplaceTable(%d, _, false): %s", stack.FilterID, err)
}
},
genPacket: func(r *stack.Route) stack.PacketBufferList {
@@ -556,7 +556,7 @@ func TestIPTableWritePackets(t *testing.T) {
}
if err := s.IPTables().ReplaceTable(stack.FilterID, table, true /* ipv6 */); err != nil {
- t.Fatalf("RelaceTable(%d, _, true): %s", stack.FilterID, err)
+ t.Fatalf("ReplaceTable(%d, _, true): %s", stack.FilterID, err)
}
},
genPacket: func(r *stack.Route) stack.PacketBufferList {
@@ -681,6 +681,32 @@ func forwardedICMPv6EchoReplyChecker(t *testing.T, b []byte, src, dst tcpip.Addr
checker.ICMPv6Type(header.ICMPv6EchoReply)))
}
+func boolToInt(v bool) uint64 {
+ if v {
+ return 1
+ }
+ return 0
+}
+
+func setupDropFilter(hook stack.Hook, f stack.IPHeaderFilter) func(*testing.T, *stack.Stack, tcpip.NetworkProtocolNumber) {
+ return func(t *testing.T, s *stack.Stack, netProto tcpip.NetworkProtocolNumber) {
+ t.Helper()
+
+ ipv6 := netProto == ipv6.ProtocolNumber
+
+ ipt := s.IPTables()
+ filter := ipt.GetTable(stack.FilterID, ipv6)
+ ruleIdx := filter.BuiltinChains[hook]
+ filter.Rules[ruleIdx].Filter = f
+ filter.Rules[ruleIdx].Target = &stack.DropTarget{NetworkProtocol: netProto}
+ // Make sure the packet is not dropped by the next rule.
+ filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{NetworkProtocol: netProto}
+ if err := ipt.ReplaceTable(stack.FilterID, filter, ipv6); err != nil {
+ t.Fatalf("ipt.ReplaceTable(%d, _, %t): %s", stack.FilterID, ipv6, err)
+ }
+ }
+}
+
func TestForwardingHook(t *testing.T) {
const (
nicID1 = 1
@@ -740,32 +766,6 @@ func TestForwardingHook(t *testing.T) {
},
}
- setupDropFilter := func(f stack.IPHeaderFilter) func(*testing.T, *stack.Stack, tcpip.NetworkProtocolNumber) {
- return func(t *testing.T, s *stack.Stack, netProto tcpip.NetworkProtocolNumber) {
- t.Helper()
-
- ipv6 := netProto == ipv6.ProtocolNumber
-
- ipt := s.IPTables()
- filter := ipt.GetTable(stack.FilterID, ipv6)
- ruleIdx := filter.BuiltinChains[stack.Forward]
- filter.Rules[ruleIdx].Filter = f
- filter.Rules[ruleIdx].Target = &stack.DropTarget{NetworkProtocol: netProto}
- // Make sure the packet is not dropped by the next rule.
- filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{NetworkProtocol: netProto}
- if err := ipt.ReplaceTable(stack.FilterID, filter, ipv6); err != nil {
- t.Fatalf("ipt.RelaceTable(%d, _, %t): %s", stack.FilterID, ipv6, err)
- }
- }
- }
-
- boolToInt := func(v bool) uint64 {
- if v {
- return 1
- }
- return 0
- }
-
subTests := []struct {
name string
setupFilter func(*testing.T, *stack.Stack, tcpip.NetworkProtocolNumber)
@@ -779,59 +779,59 @@ func TestForwardingHook(t *testing.T) {
{
name: "Drop",
- setupFilter: setupDropFilter(stack.IPHeaderFilter{}),
+ setupFilter: setupDropFilter(stack.Forward, stack.IPHeaderFilter{}),
expectForward: false,
},
{
name: "Drop with input NIC filtering",
- setupFilter: setupDropFilter(stack.IPHeaderFilter{InputInterface: nic1Name}),
+ setupFilter: setupDropFilter(stack.Forward, stack.IPHeaderFilter{InputInterface: nic1Name}),
expectForward: false,
},
{
name: "Drop with output NIC filtering",
- setupFilter: setupDropFilter(stack.IPHeaderFilter{OutputInterface: nic2Name}),
+ setupFilter: setupDropFilter(stack.Forward, stack.IPHeaderFilter{OutputInterface: nic2Name}),
expectForward: false,
},
{
name: "Drop with input and output NIC filtering",
- setupFilter: setupDropFilter(stack.IPHeaderFilter{InputInterface: nic1Name, OutputInterface: nic2Name}),
+ setupFilter: setupDropFilter(stack.Forward, stack.IPHeaderFilter{InputInterface: nic1Name, OutputInterface: nic2Name}),
expectForward: false,
},
{
name: "Drop with other input NIC filtering",
- setupFilter: setupDropFilter(stack.IPHeaderFilter{InputInterface: otherNICName}),
+ setupFilter: setupDropFilter(stack.Forward, stack.IPHeaderFilter{InputInterface: otherNICName}),
expectForward: true,
},
{
name: "Drop with other output NIC filtering",
- setupFilter: setupDropFilter(stack.IPHeaderFilter{OutputInterface: otherNICName}),
+ setupFilter: setupDropFilter(stack.Forward, stack.IPHeaderFilter{OutputInterface: otherNICName}),
expectForward: true,
},
{
name: "Drop with other input and output NIC filtering",
- setupFilter: setupDropFilter(stack.IPHeaderFilter{InputInterface: otherNICName, OutputInterface: nic2Name}),
+ setupFilter: setupDropFilter(stack.Forward, stack.IPHeaderFilter{InputInterface: otherNICName, OutputInterface: nic2Name}),
expectForward: true,
},
{
name: "Drop with input and other output NIC filtering",
- setupFilter: setupDropFilter(stack.IPHeaderFilter{InputInterface: nic1Name, OutputInterface: otherNICName}),
+ setupFilter: setupDropFilter(stack.Forward, stack.IPHeaderFilter{InputInterface: nic1Name, OutputInterface: otherNICName}),
expectForward: true,
},
{
name: "Drop with other input and other output NIC filtering",
- setupFilter: setupDropFilter(stack.IPHeaderFilter{InputInterface: otherNICName, OutputInterface: otherNICName}),
+ setupFilter: setupDropFilter(stack.Forward, stack.IPHeaderFilter{InputInterface: otherNICName, OutputInterface: otherNICName}),
expectForward: true,
},
{
name: "Drop with inverted input NIC filtering",
- setupFilter: setupDropFilter(stack.IPHeaderFilter{InputInterface: nic1Name, InputInterfaceInvert: true}),
+ setupFilter: setupDropFilter(stack.Forward, stack.IPHeaderFilter{InputInterface: nic1Name, InputInterfaceInvert: true}),
expectForward: true,
},
{
name: "Drop with inverted output NIC filtering",
- setupFilter: setupDropFilter(stack.IPHeaderFilter{OutputInterface: nic2Name, OutputInterfaceInvert: true}),
+ setupFilter: setupDropFilter(stack.Forward, stack.IPHeaderFilter{OutputInterface: nic2Name, OutputInterfaceInvert: true}),
expectForward: true,
},
}
@@ -941,3 +941,194 @@ func TestForwardingHook(t *testing.T) {
})
}
}
+
+func TestInputHookWithLocalForwarding(t *testing.T) {
+ const (
+ nicID1 = 1
+ nicID2 = 2
+
+ nic1Name = "nic1"
+ nic2Name = "nic2"
+
+ otherNICName = "otherNIC"
+ )
+
+ tests := []struct {
+ name string
+ netProto tcpip.NetworkProtocolNumber
+ rx func(*channel.Endpoint)
+ checker func(*testing.T, []byte)
+ }{
+ {
+ name: "IPv4",
+ netProto: ipv4.ProtocolNumber,
+ rx: func(e *channel.Endpoint) {
+ utils.RxICMPv4EchoRequest(e, utils.RemoteIPv4Addr, utils.Ipv4Addr2.AddressWithPrefix.Address, ttl)
+ },
+ checker: func(t *testing.T, b []byte) {
+ checker.IPv4(t, b,
+ checker.SrcAddr(utils.Ipv4Addr2.AddressWithPrefix.Address),
+ checker.DstAddr(utils.RemoteIPv4Addr),
+ checker.ICMPv4(
+ checker.ICMPv4Type(header.ICMPv4EchoReply)))
+ },
+ },
+ {
+ name: "IPv6",
+ netProto: ipv6.ProtocolNumber,
+ rx: func(e *channel.Endpoint) {
+ utils.RxICMPv6EchoRequest(e, utils.RemoteIPv6Addr, utils.Ipv6Addr2.AddressWithPrefix.Address, ttl)
+ },
+ checker: func(t *testing.T, b []byte) {
+ checker.IPv6(t, b,
+ checker.SrcAddr(utils.Ipv6Addr2.AddressWithPrefix.Address),
+ checker.DstAddr(utils.RemoteIPv6Addr),
+ checker.ICMPv6(
+ checker.ICMPv6Type(header.ICMPv6EchoReply)))
+ },
+ },
+ }
+
+ subTests := []struct {
+ name string
+ setupFilter func(*testing.T, *stack.Stack, tcpip.NetworkProtocolNumber)
+ expectDrop bool
+ }{
+ {
+ name: "Accept",
+ setupFilter: func(*testing.T, *stack.Stack, tcpip.NetworkProtocolNumber) { /* no filter */ },
+ expectDrop: false,
+ },
+
+ {
+ name: "Drop",
+ setupFilter: setupDropFilter(stack.Input, stack.IPHeaderFilter{}),
+ expectDrop: true,
+ },
+ {
+ name: "Drop with input NIC filtering on arrival NIC",
+ setupFilter: setupDropFilter(stack.Input, stack.IPHeaderFilter{InputInterface: nic1Name}),
+ expectDrop: true,
+ },
+ {
+ name: "Drop with input NIC filtering on delivered NIC",
+ setupFilter: setupDropFilter(stack.Input, stack.IPHeaderFilter{InputInterface: nic2Name}),
+ expectDrop: false,
+ },
+
+ {
+ name: "Drop with input NIC filtering on other NIC",
+ setupFilter: setupDropFilter(stack.Input, stack.IPHeaderFilter{InputInterface: otherNICName}),
+ expectDrop: false,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ for _, subTest := range subTests {
+ t.Run(subTest.name, func(t *testing.T) {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol},
+ })
+
+ subTest.setupFilter(t, s, test.netProto)
+
+ e1 := channel.New(1, header.IPv6MinimumMTU, "")
+ if err := s.CreateNICWithOptions(nicID1, e1, stack.NICOptions{Name: nic1Name}); err != nil {
+ t.Fatalf("s.CreateNICWithOptions(%d, _, _): %s", nicID1, err)
+ }
+ if err := s.AddProtocolAddress(nicID1, utils.Ipv4Addr1); err != nil {
+ t.Fatalf("s.AddProtocolAddress(%d, %#v): %s", nicID1, utils.Ipv4Addr1, err)
+ }
+ if err := s.AddProtocolAddress(nicID1, utils.Ipv6Addr1); err != nil {
+ t.Fatalf("s.AddProtocolAddress(%d, %#v): %s", nicID1, utils.Ipv6Addr1, err)
+ }
+
+ e2 := channel.New(1, header.IPv6MinimumMTU, "")
+ if err := s.CreateNICWithOptions(nicID2, e2, stack.NICOptions{Name: nic2Name}); err != nil {
+ t.Fatalf("s.CreateNICWithOptions(%d, _, _): %s", nicID2, err)
+ }
+ if err := s.AddProtocolAddress(nicID2, utils.Ipv4Addr2); err != nil {
+ t.Fatalf("s.AddProtocolAddress(%d, %#v): %s", nicID2, utils.Ipv4Addr2, err)
+ }
+ if err := s.AddProtocolAddress(nicID2, utils.Ipv6Addr2); err != nil {
+ t.Fatalf("s.AddProtocolAddress(%d, %#v): %s", nicID2, utils.Ipv6Addr2, err)
+ }
+
+ if err := s.SetForwardingDefaultAndAllNICs(ipv4.ProtocolNumber, true); err != nil {
+ t.Fatalf("s.SetForwardingDefaultAndAllNICs(%d, true): %s", ipv4.ProtocolNumber, err)
+ }
+ if err := s.SetForwardingDefaultAndAllNICs(ipv6.ProtocolNumber, true); err != nil {
+ t.Fatalf("s.SetForwardingDefaultAndAllNICs(%d, true): %s", ipv6.ProtocolNumber, err)
+ }
+
+ s.SetRouteTable([]tcpip.Route{
+ {
+ Destination: header.IPv4EmptySubnet,
+ NIC: nicID1,
+ },
+ {
+ Destination: header.IPv6EmptySubnet,
+ NIC: nicID1,
+ },
+ })
+
+ test.rx(e1)
+
+ ep1, err := s.GetNetworkEndpoint(nicID1, test.netProto)
+ if err != nil {
+ t.Fatalf("s.GetNetworkEndpoint(%d, %d): %s", nicID1, test.netProto, err)
+ }
+ ep1Stats := ep1.Stats()
+ ipEP1Stats, ok := ep1Stats.(stack.IPNetworkEndpointStats)
+ if !ok {
+ t.Fatalf("got ep1Stats = %T, want = stack.IPNetworkEndpointStats", ep1Stats)
+ }
+ ip1Stats := ipEP1Stats.IPStats()
+
+ if got := ip1Stats.PacketsReceived.Value(); got != 1 {
+ t.Errorf("got ip1Stats.PacketsReceived.Value() = %d, want = 1", got)
+ }
+ if got := ip1Stats.ValidPacketsReceived.Value(); got != 1 {
+ t.Errorf("got ip1Stats.ValidPacketsReceived.Value() = %d, want = 1", got)
+ }
+ if got, want := ip1Stats.PacketsSent.Value(), boolToInt(!subTest.expectDrop); got != want {
+ t.Errorf("got ip1Stats.PacketsSent.Value() = %d, want = %d", got, want)
+ }
+
+ ep2, err := s.GetNetworkEndpoint(nicID2, test.netProto)
+ if err != nil {
+ t.Fatalf("s.GetNetworkEndpoint(%d, %d): %s", nicID2, test.netProto, err)
+ }
+ ep2Stats := ep2.Stats()
+ ipEP2Stats, ok := ep2Stats.(stack.IPNetworkEndpointStats)
+ if !ok {
+ t.Fatalf("got ep2Stats = %T, want = stack.IPNetworkEndpointStats", ep2Stats)
+ }
+ ip2Stats := ipEP2Stats.IPStats()
+ if got := ip2Stats.PacketsReceived.Value(); got != 0 {
+ t.Errorf("got ip2Stats.PacketsReceived.Value() = %d, want = 0", got)
+ }
+ if got := ip2Stats.ValidPacketsReceived.Value(); got != 1 {
+ t.Errorf("got ip2Stats.ValidPacketsReceived.Value() = %d, want = 1", got)
+ }
+ if got, want := ip2Stats.IPTablesInputDropped.Value(), boolToInt(subTest.expectDrop); got != want {
+ t.Errorf("got ip2Stats.IPTablesInputDropped.Value() = %d, want = %d", got, want)
+ }
+ if got := ip2Stats.PacketsSent.Value(); got != 0 {
+ t.Errorf("got ip2Stats.PacketsSent.Value() = %d, want = 0", got)
+ }
+
+ if p, ok := e1.Read(); ok == subTest.expectDrop {
+ t.Errorf("got e1.Read() = (%#v, %t), want = (_, %t)", p, ok, !subTest.expectDrop)
+ } else if !subTest.expectDrop {
+ test.checker(t, stack.PayloadSince(p.Pkt.NetworkHeader()))
+ }
+ if p, ok := e2.Read(); ok {
+ t.Errorf("got e1.Read() = (%#v, true), want = (_, false)", p)
+ }
+ })
+ }
+ })
+ }
+}
diff --git a/pkg/tcpip/transport/tcp/accept.go b/pkg/tcpip/transport/tcp/accept.go
index 2c65b737d..d807b13b7 100644
--- a/pkg/tcpip/transport/tcp/accept.go
+++ b/pkg/tcpip/transport/tcp/accept.go
@@ -560,6 +560,10 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) tcpip.Err
}
switch {
+ case s.flags.Contains(header.TCPFlagRst):
+ e.stack.Stats().DroppedPackets.Increment()
+ return nil
+
case s.flags == header.TCPFlagSyn:
if e.acceptQueueIsFull() {
e.stack.Stats().TCP.ListenOverflowSynDrop.Increment()
@@ -611,7 +615,7 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) tcpip.Err
e.stack.Stats().TCP.ListenOverflowSynCookieSent.Increment()
return nil
- case (s.flags & header.TCPFlagAck) != 0:
+ case s.flags.Contains(header.TCPFlagAck):
if e.acceptQueueIsFull() {
// Silently drop the ack as the application can't accept
// the connection at this point. The ack will be
@@ -736,6 +740,13 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) tcpip.Err
mss: rcvdSynOptions.MSS,
})
+ // Requeue the segment if the ACK completing the handshake has more info
+ // to be procesed by the newly established endpoint.
+ if (s.flags.Contains(header.TCPFlagFin) || s.data.Size() > 0) && n.enqueueSegment(s) {
+ s.incRef()
+ n.newSegmentWaker.Assert()
+ }
+
// Do the delivery in a separate goroutine so
// that we don't block the listen loop in case
// the application is slow to accept or stops
@@ -753,6 +764,7 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) tcpip.Err
return nil
default:
+ e.stack.Stats().DroppedPackets.Increment()
return nil
}
}
diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go
index 570e5081c..2137ebc25 100644
--- a/pkg/tcpip/transport/tcp/connect.go
+++ b/pkg/tcpip/transport/tcp/connect.go
@@ -406,11 +406,11 @@ func (h *handshake) synRcvdState(s *segment) tcpip.Error {
h.ep.transitionToStateEstablishedLocked(h)
- // If the segment has data then requeue it for the receiver
- // to process it again once main loop is started.
- if s.data.Size() > 0 {
+ // Requeue the segment if the ACK completing the handshake has more info
+ // to be procesed by the newly established endpoint.
+ if (s.flags.Contains(header.TCPFlagFin) || s.data.Size() > 0) && h.ep.enqueueSegment(s) {
s.incRef()
- h.ep.enqueueSegment(s)
+ h.ep.newSegmentWaker.Assert()
}
return nil
}
@@ -1474,11 +1474,19 @@ func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{
return &tcpip.ErrConnectionReset{}
}
- if n&notifyClose != 0 && closeTimer == nil {
- if e.EndpointState() == StateFinWait2 && e.closed {
+ if n&notifyClose != 0 && e.closed {
+ switch e.EndpointState() {
+ case StateEstablished:
+ // Perform full shutdown if the endpoint is still
+ // established. This can occur when notifyClose
+ // was asserted just before becoming established.
+ e.shutdownLocked(tcpip.ShutdownWrite | tcpip.ShutdownRead)
+ case StateFinWait2:
// The socket has been closed and we are in FIN_WAIT2
// so start the FIN_WAIT2 timer.
- closeTimer = e.stack.Clock().AfterFunc(e.tcpLingerTimeout, closeWaker.Assert)
+ if closeTimer == nil {
+ closeTimer = e.stack.Clock().AfterFunc(e.tcpLingerTimeout, closeWaker.Assert)
+ }
}
}
diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go
index e7ede7662..9bbe9bc3e 100644
--- a/pkg/tcpip/transport/tcp/tcp_test.go
+++ b/pkg/tcpip/transport/tcp/tcp_test.go
@@ -6238,6 +6238,54 @@ func TestPassiveFailedConnectionAttemptIncrement(t *testing.T) {
}
}
+func TestListenDropIncrement(t *testing.T) {
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+ stats := c.Stack().Stats()
+ c.Create(-1 /*epRcvBuf*/)
+
+ if err := c.EP.Bind(tcpip.FullAddress{Addr: context.StackAddr, Port: context.StackPort}); err != nil {
+ t.Fatalf("Bind failed: %s", err)
+ }
+ if err := c.EP.Listen(1 /*backlog*/); err != nil {
+ t.Fatalf("Listen failed: %s", err)
+ }
+
+ initialDropped := stats.DroppedPackets.Value()
+
+ // Send RST, FIN segments, that are expected to be dropped by the listener.
+ c.SendPacket(nil, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: context.StackPort,
+ Flags: header.TCPFlagRst,
+ })
+ c.SendPacket(nil, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: context.StackPort,
+ Flags: header.TCPFlagFin,
+ })
+
+ // To ensure that the RST, FIN sent earlier are indeed received and ignored
+ // by the listener, send a SYN and wait for the SYN to be ACKd.
+ irs := seqnum.Value(context.TestInitialSequenceNumber)
+ c.SendPacket(nil, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: context.StackPort,
+ Flags: header.TCPFlagSyn,
+ SeqNum: irs,
+ })
+ checker.IPv4(t, c.GetPacket(), checker.TCP(checker.SrcPort(context.StackPort),
+ checker.DstPort(context.TestPort),
+ checker.TCPFlags(header.TCPFlagAck|header.TCPFlagSyn),
+ checker.TCPAckNum(uint32(irs)+1),
+ ))
+
+ if got, want := stats.DroppedPackets.Value(), initialDropped+2; got != want {
+ t.Fatalf("got stats.DroppedPackets.Value() = %d, want = %d", got, want)
+ }
+}
+
func TestEndpointBindListenAcceptState(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()