summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--pkg/tcpip/network/ipv4/ipv4.go60
-rw-r--r--pkg/tcpip/network/ipv6/ipv6.go58
-rw-r--r--pkg/tcpip/tests/integration/iptables_test.go310
3 files changed, 358 insertions, 70 deletions
diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go
index e1e05e39c..14cf786d2 100644
--- a/pkg/tcpip/network/ipv4/ipv4.go
+++ b/pkg/tcpip/network/ipv4/ipv4.go
@@ -441,47 +441,37 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe
// iptables filtering. All packets that reach here are locally
// generated.
dropped, natPkts := e.protocol.stack.IPTables().CheckPackets(stack.Output, pkts, gso, r, "", outNicName)
- if len(dropped) == 0 && len(natPkts) == 0 {
- // Fast path: If no packets are to be dropped then we can just invoke the
- // faster WritePackets API directly.
- n, err := e.nic.WritePackets(r, gso, pkts, ProtocolNumber)
- stats.PacketsSent.IncrementBy(uint64(n))
- if err != nil {
- stats.OutgoingPacketErrors.IncrementBy(uint64(pkts.Len() - n))
- }
- return n, err
- }
stats.IPTablesOutputDropped.IncrementBy(uint64(len(dropped)))
+ for pkt := range dropped {
+ pkts.Remove(pkt)
+ }
- // Slow path as we are dropping some packets in the batch degrade to
- // emitting one packet at a time.
- n := 0
- for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() {
- if _, ok := dropped[pkt]; ok {
+ // The NAT-ed packets may now be destined for us.
+ locallyDelivered := 0
+ for pkt := range natPkts {
+ ep := e.protocol.stack.FindNetworkEndpoint(ProtocolNumber, header.IPv4(pkt.NetworkHeader().View()).DestinationAddress())
+ if ep == nil {
+ // The NAT-ed packet is still destined for some remote node.
continue
}
- if _, ok := natPkts[pkt]; ok {
- netHeader := header.IPv4(pkt.NetworkHeader().View())
- if ep := e.protocol.stack.FindNetworkEndpoint(ProtocolNumber, netHeader.DestinationAddress()); ep != nil {
- // Since we rewrote the packet but it is being routed back to us, we
- // can safely assume the checksum is valid.
- ep.(*endpoint).handleLocalPacket(pkt, true /* canSkipRXChecksum */)
- n++
- continue
- }
- }
- if err := e.nic.WritePacket(r, gso, ProtocolNumber, pkt); err != nil {
- stats.PacketsSent.IncrementBy(uint64(n))
- stats.OutgoingPacketErrors.IncrementBy(uint64(pkts.Len() - n - len(dropped)))
- // Dropped packets aren't errors, so include them in
- // the return value.
- return n + len(dropped), err
- }
- n++
+
+ // Do not send the locally destined packet out the NIC.
+ pkts.Remove(pkt)
+
+ // Deliver the packet locally.
+ ep.(*endpoint).handleLocalPacket(pkt, true)
+ locallyDelivered++
+
}
- stats.PacketsSent.IncrementBy(uint64(n))
+
+ // The rest of the packets can be delivered to the NIC as a batch.
+ pktsLen := pkts.Len()
+ written, err := e.nic.WritePackets(r, gso, pkts, ProtocolNumber)
+ stats.PacketsSent.IncrementBy(uint64(written))
+ stats.OutgoingPacketErrors.IncrementBy(uint64(pktsLen - written))
+
// Dropped packets aren't errors, so include them in the return value.
- return n + len(dropped), nil
+ return locallyDelivered + written + len(dropped), err
}
// WriteHeaderIncludedPacket implements stack.NetworkEndpoint.
diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go
index 5cad546b8..c21c587ba 100644
--- a/pkg/tcpip/network/ipv6/ipv6.go
+++ b/pkg/tcpip/network/ipv6/ipv6.go
@@ -742,48 +742,36 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe
// generated.
outNicName := e.protocol.stack.FindNICNameFromID(e.nic.ID())
dropped, natPkts := e.protocol.stack.IPTables().CheckPackets(stack.Output, pkts, gso, r, "" /* inNicName */, outNicName)
- if len(dropped) == 0 && len(natPkts) == 0 {
- // Fast path: If no packets are to be dropped then we can just invoke the
- // faster WritePackets API directly.
- n, err := e.nic.WritePackets(r, gso, pkts, ProtocolNumber)
- stats.PacketsSent.IncrementBy(uint64(n))
- if err != nil {
- stats.OutgoingPacketErrors.IncrementBy(uint64(pkts.Len() - n))
- }
- return n, err
- }
stats.IPTablesOutputDropped.IncrementBy(uint64(len(dropped)))
+ for pkt := range dropped {
+ pkts.Remove(pkt)
+ }
- // Slow path as we are dropping some packets in the batch degrade to
- // emitting one packet at a time.
- n := 0
- for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() {
- if _, ok := dropped[pkt]; ok {
+ // The NAT-ed packets may now be destined for us.
+ locallyDelivered := 0
+ for pkt := range natPkts {
+ ep := e.protocol.stack.FindNetworkEndpoint(ProtocolNumber, header.IPv6(pkt.NetworkHeader().View()).DestinationAddress())
+ if ep == nil {
+ // The NAT-ed packet is still destined for some remote node.
continue
}
- if _, ok := natPkts[pkt]; ok {
- netHeader := header.IPv6(pkt.NetworkHeader().View())
- if ep := e.protocol.stack.FindNetworkEndpoint(ProtocolNumber, netHeader.DestinationAddress()); ep != nil {
- // Since we rewrote the packet but it is being routed back to us, we
- // can safely assume the checksum is valid.
- ep.(*endpoint).handleLocalPacket(pkt, true /* canSkipRXChecksum */)
- n++
- continue
- }
- }
- if err := e.nic.WritePacket(r, gso, ProtocolNumber, pkt); err != nil {
- stats.PacketsSent.IncrementBy(uint64(n))
- stats.OutgoingPacketErrors.IncrementBy(uint64(pkts.Len() - n + len(dropped)))
- // Dropped packets aren't errors, so include them in
- // the return value.
- return n + len(dropped), err
- }
- n++
+
+ // Do not send the locally destined packet out the NIC.
+ pkts.Remove(pkt)
+
+ // Deliver the packet locally.
+ ep.(*endpoint).handleLocalPacket(pkt, true)
+ locallyDelivered++
}
- stats.PacketsSent.IncrementBy(uint64(n))
+ // The rest of the packets can be delivered to the NIC as a batch.
+ pktsLen := pkts.Len()
+ written, err := e.nic.WritePackets(r, gso, pkts, ProtocolNumber)
+ stats.PacketsSent.IncrementBy(uint64(written))
+ stats.OutgoingPacketErrors.IncrementBy(uint64(pktsLen - written))
+
// Dropped packets aren't errors, so include them in the return value.
- return n + len(dropped), nil
+ return locallyDelivered + written + len(dropped), err
}
// WriteHeaderIncludedPacket implements stack.NetworkEndpoint.
diff --git a/pkg/tcpip/tests/integration/iptables_test.go b/pkg/tcpip/tests/integration/iptables_test.go
index 21a8dd291..b56706357 100644
--- a/pkg/tcpip/tests/integration/iptables_test.go
+++ b/pkg/tcpip/tests/integration/iptables_test.go
@@ -24,6 +24,7 @@ import (
"gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
"gvisor.dev/gvisor/pkg/tcpip/stack"
+ "gvisor.dev/gvisor/pkg/tcpip/transport/udp"
)
type inputIfNameMatcher struct {
@@ -334,3 +335,312 @@ func TestIPTablesStatsForInput(t *testing.T) {
})
}
}
+
+var _ stack.LinkEndpoint = (*channelEndpointWithoutWritePacket)(nil)
+
+// channelEndpointWithoutWritePacket is a channel endpoint that does not support
+// stack.LinkEndpoint.WritePacket.
+type channelEndpointWithoutWritePacket struct {
+ *channel.Endpoint
+
+ t *testing.T
+}
+
+func (c *channelEndpointWithoutWritePacket) WritePacket(stack.RouteInfo, *stack.GSO, tcpip.NetworkProtocolNumber, *stack.PacketBuffer) tcpip.Error {
+ c.t.Error("unexpectedly called WritePacket; all writes should go through WritePackets")
+ return &tcpip.ErrNotSupported{}
+}
+
+var _ stack.Matcher = (*udpSourcePortMatcher)(nil)
+
+type udpSourcePortMatcher struct {
+ port uint16
+}
+
+func (*udpSourcePortMatcher) Name() string {
+ return "udpSourcePortMatcher"
+}
+
+func (m *udpSourcePortMatcher) Match(_ stack.Hook, pkt *stack.PacketBuffer, _, _ string) (matches, hotdrop bool) {
+ udp := header.UDP(pkt.TransportHeader().View())
+ if len(udp) < header.UDPMinimumSize {
+ // Drop immediately as the packet is invalid.
+ return false, true
+ }
+
+ return udp.SourcePort() == m.port, false
+}
+
+func TestIPTableWritePackets(t *testing.T) {
+ const (
+ nicID = 1
+
+ dropLocalPort = localPort - 1
+ acceptPackets = 2
+ dropPackets = 3
+ )
+
+ udpHdr := func(hdr buffer.View, srcAddr, dstAddr tcpip.Address, srcPort, dstPort uint16) {
+ u := header.UDP(hdr)
+ u.Encode(&header.UDPFields{
+ SrcPort: srcPort,
+ DstPort: dstPort,
+ Length: header.UDPMinimumSize,
+ })
+ sum := header.PseudoHeaderChecksum(udp.ProtocolNumber, srcAddr, dstAddr, header.UDPMinimumSize)
+ sum = header.Checksum(hdr, sum)
+ u.SetChecksum(^u.CalculateChecksum(sum))
+ }
+
+ tests := []struct {
+ name string
+ setupFilter func(*testing.T, *stack.Stack)
+ genPacket func(*stack.Route) stack.PacketBufferList
+ proto tcpip.NetworkProtocolNumber
+ remoteAddr tcpip.Address
+ expectSent uint64
+ expectOutputDropped uint64
+ }{
+ {
+ name: "IPv4 Accept",
+ setupFilter: func(*testing.T, *stack.Stack) { /* no filter */ },
+ genPacket: func(r *stack.Route) stack.PacketBufferList {
+ var pkts stack.PacketBufferList
+
+ pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
+ ReserveHeaderBytes: int(r.MaxHeaderLength() + header.UDPMinimumSize),
+ })
+ hdr := pkt.TransportHeader().Push(header.UDPMinimumSize)
+ udpHdr(hdr, r.LocalAddress, r.RemoteAddress, localPort, remotePort)
+ pkts.PushFront(pkt)
+
+ return pkts
+ },
+ proto: header.IPv4ProtocolNumber,
+ remoteAddr: dstAddrV4,
+ expectSent: 1,
+ expectOutputDropped: 0,
+ },
+ {
+ name: "IPv4 Drop Other Port",
+ setupFilter: func(t *testing.T, s *stack.Stack) {
+ t.Helper()
+
+ table := stack.Table{
+ Rules: []stack.Rule{
+ {
+ Target: &stack.AcceptTarget{NetworkProtocol: header.IPv4ProtocolNumber},
+ },
+ {
+ Target: &stack.AcceptTarget{NetworkProtocol: header.IPv4ProtocolNumber},
+ },
+ {
+ Matchers: []stack.Matcher{&udpSourcePortMatcher{port: dropLocalPort}},
+ Target: &stack.DropTarget{NetworkProtocol: header.IPv4ProtocolNumber},
+ },
+ {
+ Target: &stack.AcceptTarget{NetworkProtocol: header.IPv4ProtocolNumber},
+ },
+ {
+ Target: &stack.ErrorTarget{NetworkProtocol: header.IPv4ProtocolNumber},
+ },
+ },
+ BuiltinChains: [stack.NumHooks]int{
+ stack.Prerouting: stack.HookUnset,
+ stack.Input: 0,
+ stack.Forward: 1,
+ stack.Output: 2,
+ stack.Postrouting: stack.HookUnset,
+ },
+ Underflows: [stack.NumHooks]int{
+ stack.Prerouting: stack.HookUnset,
+ stack.Input: 0,
+ stack.Forward: 1,
+ stack.Output: 2,
+ stack.Postrouting: stack.HookUnset,
+ },
+ }
+
+ if err := s.IPTables().ReplaceTable(stack.FilterID, table, false /* ipv4 */); err != nil {
+ t.Fatalf("RelaceTable(%d, _, false): %s", stack.FilterID, err)
+ }
+ },
+ genPacket: func(r *stack.Route) stack.PacketBufferList {
+ var pkts stack.PacketBufferList
+
+ for i := 0; i < acceptPackets; i++ {
+ pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
+ ReserveHeaderBytes: int(r.MaxHeaderLength() + header.UDPMinimumSize),
+ })
+ hdr := pkt.TransportHeader().Push(header.UDPMinimumSize)
+ udpHdr(hdr, r.LocalAddress, r.RemoteAddress, localPort, remotePort)
+ pkts.PushFront(pkt)
+ }
+ for i := 0; i < dropPackets; i++ {
+ pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
+ ReserveHeaderBytes: int(r.MaxHeaderLength() + header.UDPMinimumSize),
+ })
+ hdr := pkt.TransportHeader().Push(header.UDPMinimumSize)
+ udpHdr(hdr, r.LocalAddress, r.RemoteAddress, dropLocalPort, remotePort)
+ pkts.PushFront(pkt)
+ }
+
+ return pkts
+ },
+ proto: header.IPv4ProtocolNumber,
+ remoteAddr: dstAddrV4,
+ expectSent: acceptPackets,
+ expectOutputDropped: dropPackets,
+ },
+ {
+ name: "IPv6 Accept",
+ setupFilter: func(*testing.T, *stack.Stack) { /* no filter */ },
+ genPacket: func(r *stack.Route) stack.PacketBufferList {
+ var pkts stack.PacketBufferList
+
+ pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
+ ReserveHeaderBytes: int(r.MaxHeaderLength() + header.UDPMinimumSize),
+ })
+ hdr := pkt.TransportHeader().Push(header.UDPMinimumSize)
+ udpHdr(hdr, r.LocalAddress, r.RemoteAddress, localPort, remotePort)
+ pkts.PushFront(pkt)
+
+ return pkts
+ },
+ proto: header.IPv6ProtocolNumber,
+ remoteAddr: dstAddrV6,
+ expectSent: 1,
+ expectOutputDropped: 0,
+ },
+ {
+ name: "IPv6 Drop Other Port",
+ setupFilter: func(t *testing.T, s *stack.Stack) {
+ t.Helper()
+
+ table := stack.Table{
+ Rules: []stack.Rule{
+ {
+ Target: &stack.AcceptTarget{NetworkProtocol: header.IPv6ProtocolNumber},
+ },
+ {
+ Target: &stack.AcceptTarget{NetworkProtocol: header.IPv6ProtocolNumber},
+ },
+ {
+ Matchers: []stack.Matcher{&udpSourcePortMatcher{port: dropLocalPort}},
+ Target: &stack.DropTarget{NetworkProtocol: header.IPv6ProtocolNumber},
+ },
+ {
+ Target: &stack.AcceptTarget{NetworkProtocol: header.IPv6ProtocolNumber},
+ },
+ {
+ Target: &stack.ErrorTarget{NetworkProtocol: header.IPv6ProtocolNumber},
+ },
+ },
+ BuiltinChains: [stack.NumHooks]int{
+ stack.Prerouting: stack.HookUnset,
+ stack.Input: 0,
+ stack.Forward: 1,
+ stack.Output: 2,
+ stack.Postrouting: stack.HookUnset,
+ },
+ Underflows: [stack.NumHooks]int{
+ stack.Prerouting: stack.HookUnset,
+ stack.Input: 0,
+ stack.Forward: 1,
+ stack.Output: 2,
+ stack.Postrouting: stack.HookUnset,
+ },
+ }
+
+ if err := s.IPTables().ReplaceTable(stack.FilterID, table, true /* ipv6 */); err != nil {
+ t.Fatalf("RelaceTable(%d, _, true): %s", stack.FilterID, err)
+ }
+ },
+ genPacket: func(r *stack.Route) stack.PacketBufferList {
+ var pkts stack.PacketBufferList
+
+ for i := 0; i < acceptPackets; i++ {
+ pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
+ ReserveHeaderBytes: int(r.MaxHeaderLength() + header.UDPMinimumSize),
+ })
+ hdr := pkt.TransportHeader().Push(header.UDPMinimumSize)
+ udpHdr(hdr, r.LocalAddress, r.RemoteAddress, localPort, remotePort)
+ pkts.PushFront(pkt)
+ }
+ for i := 0; i < dropPackets; i++ {
+ pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
+ ReserveHeaderBytes: int(r.MaxHeaderLength() + header.UDPMinimumSize),
+ })
+ hdr := pkt.TransportHeader().Push(header.UDPMinimumSize)
+ udpHdr(hdr, r.LocalAddress, r.RemoteAddress, dropLocalPort, remotePort)
+ pkts.PushFront(pkt)
+ }
+
+ return pkts
+ },
+ proto: header.IPv6ProtocolNumber,
+ remoteAddr: dstAddrV6,
+ expectSent: acceptPackets,
+ expectOutputDropped: dropPackets,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol},
+ TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol},
+ })
+ e := channelEndpointWithoutWritePacket{
+ Endpoint: channel.New(4, header.IPv6MinimumMTU, linkAddr),
+ t: t,
+ }
+ if err := s.CreateNIC(nicID, &e); err != nil {
+ t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
+ }
+ if err := s.AddAddress(nicID, header.IPv6ProtocolNumber, srcAddrV6); err != nil {
+ t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, srcAddrV6, err)
+ }
+ if err := s.AddAddress(nicID, header.IPv4ProtocolNumber, srcAddrV4); err != nil {
+ t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv4ProtocolNumber, srcAddrV4, err)
+ }
+
+ s.SetRouteTable([]tcpip.Route{
+ {
+ Destination: header.IPv4EmptySubnet,
+ NIC: nicID,
+ },
+ {
+ Destination: header.IPv6EmptySubnet,
+ NIC: nicID,
+ },
+ })
+
+ test.setupFilter(t, s)
+
+ r, err := s.FindRoute(nicID, "", test.remoteAddr, test.proto, false)
+ if err != nil {
+ t.Fatalf("FindRoute(%d, '', %s, %d, false): %s", nicID, test.remoteAddr, test.proto, err)
+ }
+ defer r.Release()
+
+ pkts := test.genPacket(r)
+ pktsLen := pkts.Len()
+ if n, err := r.WritePackets(nil /* gso */, pkts, stack.NetworkHeaderParams{
+ Protocol: header.UDPProtocolNumber,
+ TTL: 64,
+ }); err != nil {
+ t.Fatalf("WritePackets(...): %s", err)
+ } else if n != pktsLen {
+ t.Fatalf("got WritePackets(...) = %d, want = %d", n, pktsLen)
+ }
+
+ if got := s.Stats().IP.PacketsSent.Value(); got != test.expectSent {
+ t.Errorf("got PacketSent = %d, want = %d", got, test.expectSent)
+ }
+ if got := s.Stats().IP.IPTablesOutputDropped.Value(); got != test.expectOutputDropped {
+ t.Errorf("got IPTablesOutputDropped = %d, want = %d", got, test.expectOutputDropped)
+ }
+ })
+ }
+}