summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGhanan Gowripalan <ghanan@google.com>2021-02-05 18:41:37 -0800
committergVisor bot <gvisor-bot@google.com>2021-02-05 18:44:04 -0800
commit83b764d9d2193e2e01f3a60792f3468c1843c5a8 (patch)
tree3cb303660a15cfd0b2150ee3d93966636dbb3054
parent120c8e34687129c919ae45263c14b239a0a5d343 (diff)
Batch write packets after iptables checks
After IPTables checks a batch of packets, we can write packets that are not dropped or locally destined as a batch instead of individually. This previously caused a bug since WritePacket* functions expect to take ownership of passed PacketBuffer{List}. WritePackets assumed the list of PacketBuffers will not be invalidated when calling WritePacket for each PacketBuffer in the list, but this is not true. WritePacket may add the passed PacketBuffer into a different list, which would modify the PacketBuffer in such a way that it no longer points to the next PacketBuffer to write. Example: Given a PB list of PB_a -> PB_b -> PB_c, WritePackets may be iterating over the list and calling WritePacket for each PB. When WritePacket takes PB_a, it may add it to a new list, which would update pointers such that PB_a no longer points to PB_b. Test: integration_test.TestIPTableWritePackets. PiperOrigin-RevId: 355969560
-rw-r--r--pkg/tcpip/network/ipv4/ipv4.go60
-rw-r--r--pkg/tcpip/network/ipv6/ipv6.go58
-rw-r--r--pkg/tcpip/tests/integration/iptables_test.go310
3 files changed, 358 insertions, 70 deletions
diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go
index e1e05e39c..14cf786d2 100644
--- a/pkg/tcpip/network/ipv4/ipv4.go
+++ b/pkg/tcpip/network/ipv4/ipv4.go
@@ -441,47 +441,37 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe
// iptables filtering. All packets that reach here are locally
// generated.
dropped, natPkts := e.protocol.stack.IPTables().CheckPackets(stack.Output, pkts, gso, r, "", outNicName)
- if len(dropped) == 0 && len(natPkts) == 0 {
- // Fast path: If no packets are to be dropped then we can just invoke the
- // faster WritePackets API directly.
- n, err := e.nic.WritePackets(r, gso, pkts, ProtocolNumber)
- stats.PacketsSent.IncrementBy(uint64(n))
- if err != nil {
- stats.OutgoingPacketErrors.IncrementBy(uint64(pkts.Len() - n))
- }
- return n, err
- }
stats.IPTablesOutputDropped.IncrementBy(uint64(len(dropped)))
+ for pkt := range dropped {
+ pkts.Remove(pkt)
+ }
- // Slow path as we are dropping some packets in the batch degrade to
- // emitting one packet at a time.
- n := 0
- for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() {
- if _, ok := dropped[pkt]; ok {
+ // The NAT-ed packets may now be destined for us.
+ locallyDelivered := 0
+ for pkt := range natPkts {
+ ep := e.protocol.stack.FindNetworkEndpoint(ProtocolNumber, header.IPv4(pkt.NetworkHeader().View()).DestinationAddress())
+ if ep == nil {
+ // The NAT-ed packet is still destined for some remote node.
continue
}
- if _, ok := natPkts[pkt]; ok {
- netHeader := header.IPv4(pkt.NetworkHeader().View())
- if ep := e.protocol.stack.FindNetworkEndpoint(ProtocolNumber, netHeader.DestinationAddress()); ep != nil {
- // Since we rewrote the packet but it is being routed back to us, we
- // can safely assume the checksum is valid.
- ep.(*endpoint).handleLocalPacket(pkt, true /* canSkipRXChecksum */)
- n++
- continue
- }
- }
- if err := e.nic.WritePacket(r, gso, ProtocolNumber, pkt); err != nil {
- stats.PacketsSent.IncrementBy(uint64(n))
- stats.OutgoingPacketErrors.IncrementBy(uint64(pkts.Len() - n - len(dropped)))
- // Dropped packets aren't errors, so include them in
- // the return value.
- return n + len(dropped), err
- }
- n++
+
+ // Do not send the locally destined packet out the NIC.
+ pkts.Remove(pkt)
+
+ // Deliver the packet locally.
+ ep.(*endpoint).handleLocalPacket(pkt, true)
+ locallyDelivered++
+
}
- stats.PacketsSent.IncrementBy(uint64(n))
+
+ // The rest of the packets can be delivered to the NIC as a batch.
+ pktsLen := pkts.Len()
+ written, err := e.nic.WritePackets(r, gso, pkts, ProtocolNumber)
+ stats.PacketsSent.IncrementBy(uint64(written))
+ stats.OutgoingPacketErrors.IncrementBy(uint64(pktsLen - written))
+
// Dropped packets aren't errors, so include them in the return value.
- return n + len(dropped), nil
+ return locallyDelivered + written + len(dropped), err
}
// WriteHeaderIncludedPacket implements stack.NetworkEndpoint.
diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go
index 5cad546b8..c21c587ba 100644
--- a/pkg/tcpip/network/ipv6/ipv6.go
+++ b/pkg/tcpip/network/ipv6/ipv6.go
@@ -742,48 +742,36 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe
// generated.
outNicName := e.protocol.stack.FindNICNameFromID(e.nic.ID())
dropped, natPkts := e.protocol.stack.IPTables().CheckPackets(stack.Output, pkts, gso, r, "" /* inNicName */, outNicName)
- if len(dropped) == 0 && len(natPkts) == 0 {
- // Fast path: If no packets are to be dropped then we can just invoke the
- // faster WritePackets API directly.
- n, err := e.nic.WritePackets(r, gso, pkts, ProtocolNumber)
- stats.PacketsSent.IncrementBy(uint64(n))
- if err != nil {
- stats.OutgoingPacketErrors.IncrementBy(uint64(pkts.Len() - n))
- }
- return n, err
- }
stats.IPTablesOutputDropped.IncrementBy(uint64(len(dropped)))
+ for pkt := range dropped {
+ pkts.Remove(pkt)
+ }
- // Slow path as we are dropping some packets in the batch degrade to
- // emitting one packet at a time.
- n := 0
- for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() {
- if _, ok := dropped[pkt]; ok {
+ // The NAT-ed packets may now be destined for us.
+ locallyDelivered := 0
+ for pkt := range natPkts {
+ ep := e.protocol.stack.FindNetworkEndpoint(ProtocolNumber, header.IPv6(pkt.NetworkHeader().View()).DestinationAddress())
+ if ep == nil {
+ // The NAT-ed packet is still destined for some remote node.
continue
}
- if _, ok := natPkts[pkt]; ok {
- netHeader := header.IPv6(pkt.NetworkHeader().View())
- if ep := e.protocol.stack.FindNetworkEndpoint(ProtocolNumber, netHeader.DestinationAddress()); ep != nil {
- // Since we rewrote the packet but it is being routed back to us, we
- // can safely assume the checksum is valid.
- ep.(*endpoint).handleLocalPacket(pkt, true /* canSkipRXChecksum */)
- n++
- continue
- }
- }
- if err := e.nic.WritePacket(r, gso, ProtocolNumber, pkt); err != nil {
- stats.PacketsSent.IncrementBy(uint64(n))
- stats.OutgoingPacketErrors.IncrementBy(uint64(pkts.Len() - n + len(dropped)))
- // Dropped packets aren't errors, so include them in
- // the return value.
- return n + len(dropped), err
- }
- n++
+
+ // Do not send the locally destined packet out the NIC.
+ pkts.Remove(pkt)
+
+ // Deliver the packet locally.
+ ep.(*endpoint).handleLocalPacket(pkt, true)
+ locallyDelivered++
}
- stats.PacketsSent.IncrementBy(uint64(n))
+ // The rest of the packets can be delivered to the NIC as a batch.
+ pktsLen := pkts.Len()
+ written, err := e.nic.WritePackets(r, gso, pkts, ProtocolNumber)
+ stats.PacketsSent.IncrementBy(uint64(written))
+ stats.OutgoingPacketErrors.IncrementBy(uint64(pktsLen - written))
+
// Dropped packets aren't errors, so include them in the return value.
- return n + len(dropped), nil
+ return locallyDelivered + written + len(dropped), err
}
// WriteHeaderIncludedPacket implements stack.NetworkEndpoint.
diff --git a/pkg/tcpip/tests/integration/iptables_test.go b/pkg/tcpip/tests/integration/iptables_test.go
index 21a8dd291..b56706357 100644
--- a/pkg/tcpip/tests/integration/iptables_test.go
+++ b/pkg/tcpip/tests/integration/iptables_test.go
@@ -24,6 +24,7 @@ import (
"gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
"gvisor.dev/gvisor/pkg/tcpip/stack"
+ "gvisor.dev/gvisor/pkg/tcpip/transport/udp"
)
type inputIfNameMatcher struct {
@@ -334,3 +335,312 @@ func TestIPTablesStatsForInput(t *testing.T) {
})
}
}
+
+// Compile-time check that channelEndpointWithoutWritePacket satisfies
+// stack.LinkEndpoint.
+var _ stack.LinkEndpoint = (*channelEndpointWithoutWritePacket)(nil)
+
+// channelEndpointWithoutWritePacket is a channel endpoint that does not support
+// stack.LinkEndpoint.WritePacket.
+//
+// It is used to detect code paths that write packets one at a time instead of
+// going through the batched WritePackets API.
+type channelEndpointWithoutWritePacket struct {
+ // Endpoint is the embedded channel endpoint whose methods are promoted onto
+ // this type.
+ *channel.Endpoint
+
+ // t is the test that is marked as failed if WritePacket is called.
+ t *testing.T
+}
+
+func (c *channelEndpointWithoutWritePacket) WritePacket(stack.RouteInfo, *stack.GSO, tcpip.NetworkProtocolNumber, *stack.PacketBuffer) tcpip.Error {
+ c.t.Error("unexpectedly called WritePacket; all writes should go through WritePackets")
+ return &tcpip.ErrNotSupported{}
+}
+
+// Compile-time check that udpSourcePortMatcher satisfies stack.Matcher.
+var _ stack.Matcher = (*udpSourcePortMatcher)(nil)
+
+// udpSourcePortMatcher is a stack.Matcher that matches packets by the source
+// port found in their UDP transport header.
+type udpSourcePortMatcher struct {
+ // port is the UDP source port a packet must carry to match.
+ port uint16
+}
+
+func (*udpSourcePortMatcher) Name() string {
+ return "udpSourcePortMatcher"
+}
+
+// Match implements stack.Matcher.Match.
+//
+// The packet's transport header is interpreted as UDP. A header of at least
+// header.UDPMinimumSize bytes matches when its source port equals m.port and
+// never requests a hotdrop. A shorter header never matches and sets hotdrop,
+// since such a packet is treated as invalid. The hook and the two string
+// arguments are ignored.
+func (m *udpSourcePortMatcher) Match(_ stack.Hook, pkt *stack.PacketBuffer, _, _ string) (matches, hotdrop bool) {
+ transportHdr := header.UDP(pkt.TransportHeader().View())
+ if len(transportHdr) >= header.UDPMinimumSize {
+ srcPort := transportHdr.SourcePort()
+ return srcPort == m.port, false
+ }
+
+ // Too short to be a UDP header: drop immediately as the packet is invalid.
+ return false, true
+}
+
+// TestIPTableWritePackets checks that a batch of locally generated packets
+// written through Route.WritePackets is filtered by the Output hook and that
+// the remaining packets are handed to the link endpoint as a batch.
+//
+// The NIC is backed by channelEndpointWithoutWritePacket, so any fallback to
+// the single-packet WritePacket path fails the test. For each case the test
+// asserts that WritePackets reports every packet in the batch as handled, and
+// that the IP PacketsSent and IPTablesOutputDropped stats match expectations.
+func TestIPTableWritePackets(t *testing.T) {
+ const (
+ nicID = 1
+
+ dropLocalPort = localPort - 1
+ acceptPackets = 2
+ dropPackets = 3
+ )
+
+ // udpHdr encodes a payload-less UDP header into hdr and sets its checksum.
+ udpHdr := func(hdr buffer.View, srcAddr, dstAddr tcpip.Address, srcPort, dstPort uint16) {
+ u := header.UDP(hdr)
+ u.Encode(&header.UDPFields{
+ SrcPort: srcPort,
+ DstPort: dstPort,
+ Length: header.UDPMinimumSize,
+ })
+ sum := header.PseudoHeaderChecksum(udp.ProtocolNumber, srcAddr, dstAddr, header.UDPMinimumSize)
+ sum = header.Checksum(hdr, sum)
+ u.SetChecksum(^u.CalculateChecksum(sum))
+ }
+
+ // pushPackets prepends count header-only UDP packets, sent from srcPort to
+ // remotePort along r, to pkts.
+ pushPackets := func(pkts *stack.PacketBufferList, r *stack.Route, count int, srcPort uint16) {
+ for i := 0; i < count; i++ {
+ pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
+ ReserveHeaderBytes: int(r.MaxHeaderLength() + header.UDPMinimumSize),
+ })
+ hdr := pkt.TransportHeader().Push(header.UDPMinimumSize)
+ udpHdr(hdr, r.LocalAddress, r.RemoteAddress, srcPort, remotePort)
+ pkts.PushFront(pkt)
+ }
+ }
+
+ // genAccepted builds a batch holding a single packet from localPort.
+ genAccepted := func(r *stack.Route) stack.PacketBufferList {
+ var pkts stack.PacketBufferList
+ pushPackets(&pkts, r, 1, localPort)
+ return pkts
+ }
+
+ // genMixed builds a batch of acceptPackets packets from localPort and
+ // dropPackets packets from dropLocalPort.
+ genMixed := func(r *stack.Route) stack.PacketBufferList {
+ var pkts stack.PacketBufferList
+ pushPackets(&pkts, r, acceptPackets, localPort)
+ pushPackets(&pkts, r, dropPackets, dropLocalPort)
+ return pkts
+ }
+
+ // noFilter leaves the stack's default tables untouched.
+ noFilter := func(*testing.T, *stack.Stack) { /* no filter */ }
+
+ // dropFilter returns a setup function that installs a filter table for
+ // proto whose Output chain begins with a rule dropping UDP packets sent
+ // from dropLocalPort. ipv6 selects which filter table is replaced.
+ dropFilter := func(proto tcpip.NetworkProtocolNumber, ipv6 bool) func(*testing.T, *stack.Stack) {
+ return func(t *testing.T, s *stack.Stack) {
+ t.Helper()
+
+ table := stack.Table{
+ Rules: []stack.Rule{
+ // Rule 0: Input chain.
+ {
+ Target: &stack.AcceptTarget{NetworkProtocol: proto},
+ },
+ // Rule 1: Forward chain.
+ {
+ Target: &stack.AcceptTarget{NetworkProtocol: proto},
+ },
+ // Rule 2: start of the Output chain.
+ {
+ Matchers: []stack.Matcher{&udpSourcePortMatcher{port: dropLocalPort}},
+ Target: &stack.DropTarget{NetworkProtocol: proto},
+ },
+ {
+ Target: &stack.AcceptTarget{NetworkProtocol: proto},
+ },
+ {
+ Target: &stack.ErrorTarget{NetworkProtocol: proto},
+ },
+ },
+ BuiltinChains: [stack.NumHooks]int{
+ stack.Prerouting: stack.HookUnset,
+ stack.Input: 0,
+ stack.Forward: 1,
+ stack.Output: 2,
+ stack.Postrouting: stack.HookUnset,
+ },
+ Underflows: [stack.NumHooks]int{
+ stack.Prerouting: stack.HookUnset,
+ stack.Input: 0,
+ stack.Forward: 1,
+ stack.Output: 2,
+ stack.Postrouting: stack.HookUnset,
+ },
+ }
+
+ if err := s.IPTables().ReplaceTable(stack.FilterID, table, ipv6); err != nil {
+ t.Fatalf("ReplaceTable(%d, _, %t): %s", stack.FilterID, ipv6, err)
+ }
+ }
+ }
+
+ tests := []struct {
+ name string
+ setupFilter func(*testing.T, *stack.Stack)
+ genPacket func(*stack.Route) stack.PacketBufferList
+ proto tcpip.NetworkProtocolNumber
+ remoteAddr tcpip.Address
+ expectSent uint64
+ expectOutputDropped uint64
+ }{
+ {
+ name: "IPv4 Accept",
+ setupFilter: noFilter,
+ genPacket: genAccepted,
+ proto: header.IPv4ProtocolNumber,
+ remoteAddr: dstAddrV4,
+ expectSent: 1,
+ expectOutputDropped: 0,
+ },
+ {
+ name: "IPv4 Drop Other Port",
+ setupFilter: dropFilter(header.IPv4ProtocolNumber, false /* ipv6 */),
+ genPacket: genMixed,
+ proto: header.IPv4ProtocolNumber,
+ remoteAddr: dstAddrV4,
+ expectSent: acceptPackets,
+ expectOutputDropped: dropPackets,
+ },
+ {
+ name: "IPv6 Accept",
+ setupFilter: noFilter,
+ genPacket: genAccepted,
+ proto: header.IPv6ProtocolNumber,
+ remoteAddr: dstAddrV6,
+ expectSent: 1,
+ expectOutputDropped: 0,
+ },
+ {
+ name: "IPv6 Drop Other Port",
+ setupFilter: dropFilter(header.IPv6ProtocolNumber, true /* ipv6 */),
+ genPacket: genMixed,
+ proto: header.IPv6ProtocolNumber,
+ remoteAddr: dstAddrV6,
+ expectSent: acceptPackets,
+ expectOutputDropped: dropPackets,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol},
+ TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol},
+ })
+ e := channelEndpointWithoutWritePacket{
+ Endpoint: channel.New(4, header.IPv6MinimumMTU, linkAddr),
+ t: t,
+ }
+ if err := s.CreateNIC(nicID, &e); err != nil {
+ t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
+ }
+ if err := s.AddAddress(nicID, header.IPv6ProtocolNumber, srcAddrV6); err != nil {
+ t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, srcAddrV6, err)
+ }
+ if err := s.AddAddress(nicID, header.IPv4ProtocolNumber, srcAddrV4); err != nil {
+ t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv4ProtocolNumber, srcAddrV4, err)
+ }
+
+ s.SetRouteTable([]tcpip.Route{
+ {
+ Destination: header.IPv4EmptySubnet,
+ NIC: nicID,
+ },
+ {
+ Destination: header.IPv6EmptySubnet,
+ NIC: nicID,
+ },
+ })
+
+ test.setupFilter(t, s)
+
+ r, err := s.FindRoute(nicID, "", test.remoteAddr, test.proto, false)
+ if err != nil {
+ t.Fatalf("FindRoute(%d, '', %s, %d, false): %s", nicID, test.remoteAddr, test.proto, err)
+ }
+ defer r.Release()
+
+ pkts := test.genPacket(r)
+ // Record the batch size before the write; the same value is the
+ // expected return of WritePackets, dropped packets included.
+ pktsLen := pkts.Len()
+ if n, err := r.WritePackets(nil /* gso */, pkts, stack.NetworkHeaderParams{
+ Protocol: header.UDPProtocolNumber,
+ TTL: 64,
+ }); err != nil {
+ t.Fatalf("WritePackets(...): %s", err)
+ } else if n != pktsLen {
+ t.Fatalf("got WritePackets(...) = %d, want = %d", n, pktsLen)
+ }
+
+ if got := s.Stats().IP.PacketsSent.Value(); got != test.expectSent {
+ t.Errorf("got PacketsSent = %d, want = %d", got, test.expectSent)
+ }
+ if got := s.Stats().IP.IPTablesOutputDropped.Value(); got != test.expectOutputDropped {
+ t.Errorf("got IPTablesOutputDropped = %d, want = %d", got, test.expectOutputDropped)
+ }
+ })
+ }
+}