diff options
author | Kevin Krakauer <krakauer@google.com> | 2020-09-17 21:52:54 -0700 |
---|---|---|
committer | gVisor bot <gvisor-bot@google.com> | 2020-09-17 21:54:48 -0700 |
commit | 0b8d306e64f89e0d63a558f69a846d36beeea51d (patch) | |
tree | f38fdb8ab4a7016a3cbf406fbd01ec8ddd0fd4cc | |
parent | d34bda027309695e3e6fb6f92a5839cd1f21173e (diff) |
ip6tables: filter table support
`ip6tables -t filter` is now usable. NAT support will come in a future CL.
#3549
PiperOrigin-RevId: 332381801
-rw-r--r-- | Makefile | 2 | ||||
-rw-r--r-- | pkg/sentry/socket/netfilter/netfilter.go | 4 | ||||
-rw-r--r-- | pkg/sentry/socket/netfilter/tcp_matcher.go | 32 | ||||
-rw-r--r-- | pkg/sentry/socket/netfilter/udp_matcher.go | 32 | ||||
-rw-r--r-- | pkg/tcpip/network/ipv4/ipv4.go | 14 | ||||
-rw-r--r-- | pkg/tcpip/network/ipv6/ipv6.go | 74 | ||||
-rw-r--r-- | pkg/tcpip/network/ipv6/ipv6_test.go | 208 | ||||
-rw-r--r-- | pkg/tcpip/stack/iptables.go | 120 | ||||
-rw-r--r-- | pkg/tcpip/stack/iptables_types.go | 65 | ||||
-rw-r--r-- | pkg/tcpip/stack/nic.go | 3 | ||||
-rw-r--r-- | pkg/tcpip/stack/packet_buffer.go | 2 | ||||
-rw-r--r-- | test/iptables/iptables_test.go | 52 |
12 files changed, 515 insertions, 93 deletions
@@ -200,6 +200,8 @@ kvm-tests: load-basic-images .PHONY: kvm-tests iptables-tests: load-iptables + @sudo modprobe iptable_filter + @sudo modprobe ip6table_filter @$(call submake,test-runtime RUNTIME="runc" TARGETS="//test/iptables:iptables_test") @$(call submake,install-test-runtime RUNTIME="iptables" ARGS="--net-raw") @$(call submake,test-runtime RUNTIME="iptables" TARGETS="//test/iptables:iptables_test") diff --git a/pkg/sentry/socket/netfilter/netfilter.go b/pkg/sentry/socket/netfilter/netfilter.go index 3e1735079..871ea80ee 100644 --- a/pkg/sentry/socket/netfilter/netfilter.go +++ b/pkg/sentry/socket/netfilter/netfilter.go @@ -146,6 +146,10 @@ func SetEntries(stk *stack.Stack, optVal []byte, ipv6 bool) *syserr.Error { case stack.FilterTable: table = stack.EmptyFilterTable() case stack.NATTable: + if ipv6 { + nflog("IPv6 redirection not yet supported (gvisor.dev/issue/3549)") + return syserr.ErrInvalidArgument + } table = stack.EmptyNATTable() default: nflog("we don't yet support writing to the %q table (gvisor.dev/issue/170)", replace.Name.String()) diff --git a/pkg/sentry/socket/netfilter/tcp_matcher.go b/pkg/sentry/socket/netfilter/tcp_matcher.go index 0bfd6c1f4..844acfede 100644 --- a/pkg/sentry/socket/netfilter/tcp_matcher.go +++ b/pkg/sentry/socket/netfilter/tcp_matcher.go @@ -97,17 +97,33 @@ func (*TCPMatcher) Name() string { // Match implements Matcher.Match. func (tm *TCPMatcher) Match(hook stack.Hook, pkt *stack.PacketBuffer, interfaceName string) (bool, bool) { - netHeader := header.IPv4(pkt.NetworkHeader().View()) + // TODO(gvisor.dev/issue/170): Proto checks should ultimately be moved + // into the stack.Check codepath as matchers are added. + switch pkt.NetworkProtocolNumber { + case header.IPv4ProtocolNumber: + netHeader := header.IPv4(pkt.NetworkHeader().View()) + if netHeader.TransportProtocol() != header.TCPProtocolNumber { + return false, false + } - if netHeader.TransportProtocol() != header.TCPProtocolNumber { - return false, false - } + // We don't match fragments. + if frag := netHeader.FragmentOffset(); frag != 0 { + if frag == 1 { + return false, true + } + return false, false + } - // We dont't match fragments. - if frag := netHeader.FragmentOffset(); frag != 0 { - if frag == 1 { - return false, true + case header.IPv6ProtocolNumber: + // As in Linux, we do not perform an IPv6 fragment check. See + // xt_action_param.fragoff in + // include/linux/netfilter/x_tables.h. + if header.IPv6(pkt.NetworkHeader().View()).TransportProtocol() != header.TCPProtocolNumber { + return false, false } + + default: + // We don't know the network protocol. return false, false } diff --git a/pkg/sentry/socket/netfilter/udp_matcher.go b/pkg/sentry/socket/netfilter/udp_matcher.go index 7ed05461d..63201201c 100644 --- a/pkg/sentry/socket/netfilter/udp_matcher.go +++ b/pkg/sentry/socket/netfilter/udp_matcher.go @@ -94,19 +94,33 @@ func (*UDPMatcher) Name() string { // Match implements Matcher.Match. func (um *UDPMatcher) Match(hook stack.Hook, pkt *stack.PacketBuffer, interfaceName string) (bool, bool) { - netHeader := header.IPv4(pkt.NetworkHeader().View()) - // TODO(gvisor.dev/issue/170): Proto checks should ultimately be moved // into the stack.Check codepath as matchers are added. - if netHeader.TransportProtocol() != header.UDPProtocolNumber { - return false, false - } + switch pkt.NetworkProtocolNumber { + case header.IPv4ProtocolNumber: + netHeader := header.IPv4(pkt.NetworkHeader().View()) + if netHeader.TransportProtocol() != header.UDPProtocolNumber { + return false, false + } - // We dont't match fragments. - if frag := netHeader.FragmentOffset(); frag != 0 { - if frag == 1 { - return false, true + // We don't match fragments. + if frag := netHeader.FragmentOffset(); frag != 0 { + if frag == 1 { + return false, true + } + return false, false } + + case header.IPv6ProtocolNumber: + // As in Linux, we do not perform an IPv6 fragment check. See + // xt_action_param.fragoff in + // include/linux/netfilter/x_tables.h. + if header.IPv6(pkt.NetworkHeader().View()).TransportProtocol() != header.UDPProtocolNumber { + return false, false + } + + default: + // We don't know the network protocol. return false, false } diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go index f4394749d..a75c4cdda 100644 --- a/pkg/tcpip/network/ipv4/ipv4.go +++ b/pkg/tcpip/network/ipv4/ipv4.go @@ -238,11 +238,13 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.Netw return nil } - // If the packet is manipulated as per NAT Ouput rules, handle packet - // based on destination address and do not send the packet to link layer. - // TODO(gvisor.dev/issue/170): We should do this for every packet, rather than - // only NATted packets, but removing this check short circuits broadcasts - // before they are sent out to other hosts. + // If the packet is manipulated as per NAT Output rules, handle packet + // based on destination address and do not send the packet to link + // layer. + // + // TODO(gvisor.dev/issue/170): We should do this for every + // packet, rather than only NATted packets, but removing this check + // short circuits broadcasts before they are sent out to other hosts. if pkt.NatDone { netHeader := header.IPv4(pkt.NetworkHeader().View()) ep, err := e.stack.FindNetworkEndpoint(header.IPv4ProtocolNumber, netHeader.DestinationAddress()) @@ -298,7 +300,7 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe return n, err } - // Slow Path as we are dropping some packets in the batch degrade to + // Slow path as we are dropping some packets in the batch degrade to // emitting one packet at a time. n := 0 for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() { diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go index e821a8bff..fc8dfea42 100644 --- a/pkg/tcpip/network/ipv6/ipv6.go +++ b/pkg/tcpip/network/ipv6/ipv6.go @@ -107,6 +107,31 @@ func (e *endpoint) addIPHeader(r *stack.Route, pkt *stack.PacketBuffer, params s func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.NetworkHeaderParams, pkt *stack.PacketBuffer) *tcpip.Error { e.addIPHeader(r, pkt, params) + // iptables filtering. All packets that reach here are locally + // generated. + nicName := e.stack.FindNICNameFromID(e.NICID()) + ipt := e.stack.IPTables() + if ok := ipt.Check(stack.Output, pkt, gso, r, "", nicName); !ok { + // iptables is telling us to drop the packet. + return nil + } + + // If the packet is manipulated as per NAT Output rules, handle packet + // based on destination address and do not send the packet to link + // layer. + // + // TODO(gvisor.dev/issue/170): We should do this for every + // packet, rather than only NATted packets, but removing this check + // short circuits broadcasts before they are sent out to other hosts. + if pkt.NatDone { + netHeader := header.IPv6(pkt.NetworkHeader().View()) + if ep, err := e.stack.FindNetworkEndpoint(header.IPv6ProtocolNumber, netHeader.DestinationAddress()); err == nil { + route := r.ReverseRoute(netHeader.SourceAddress(), netHeader.DestinationAddress()) + ep.HandlePacket(&route, pkt) + return nil + } + } + if r.Loop&stack.PacketLoop != 0 { loopedR := r.MakeLoopedRoute() @@ -138,9 +163,46 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe e.addIPHeader(r, pb, params) } - n, err := e.linkEP.WritePackets(r, gso, pkts, ProtocolNumber) + // iptables filtering. All packets that reach here are locally + // generated. + nicName := e.stack.FindNICNameFromID(e.NICID()) + ipt := e.stack.IPTables() + dropped, natPkts := ipt.CheckPackets(stack.Output, pkts, gso, r, nicName) + if len(dropped) == 0 && len(natPkts) == 0 { + // Fast path: If no packets are to be dropped then we can just invoke the + // faster WritePackets API directly. + n, err := e.linkEP.WritePackets(r, gso, pkts, ProtocolNumber) + r.Stats().IP.PacketsSent.IncrementBy(uint64(n)) + return n, err + } + + // Slow path as we are dropping some packets in the batch degrade to + // emitting one packet at a time. + n := 0 + for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() { + if _, ok := dropped[pkt]; ok { + continue + } + if _, ok := natPkts[pkt]; ok { + netHeader := header.IPv6(pkt.NetworkHeader().View()) + if ep, err := e.stack.FindNetworkEndpoint(header.IPv6ProtocolNumber, netHeader.DestinationAddress()); err == nil { + src := netHeader.SourceAddress() + dst := netHeader.DestinationAddress() + route := r.ReverseRoute(src, dst) + ep.HandlePacket(&route, pkt) + n++ + continue + } + } + if err := e.linkEP.WritePacket(r, gso, ProtocolNumber, pkt); err != nil { + r.Stats().IP.PacketsSent.IncrementBy(uint64(n)) + return n, err + } + n++ + } + r.Stats().IP.PacketsSent.IncrementBy(uint64(n)) - return n, err + return n, nil } // WriteHeaderIncludedPacker implements stack.NetworkEndpoint. It is not yet @@ -169,6 +231,14 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { it := header.MakeIPv6PayloadIterator(header.IPv6ExtensionHeaderIdentifier(h.NextHeader()), vv) hasFragmentHeader := false + // iptables filtering. All packets that reach here are intended for + // this machine and will not be forwarded. + ipt := e.stack.IPTables() + if ok := ipt.Check(stack.Input, pkt, nil, nil, "", ""); !ok { + // iptables is telling us to drop the packet. + return + } + for firstHeader := true; ; firstHeader = false { extHdr, done, err := it.Next() if err != nil { diff --git a/pkg/tcpip/network/ipv6/ipv6_test.go b/pkg/tcpip/network/ipv6/ipv6_test.go index 5f9822c49..354d3b60d 100644 --- a/pkg/tcpip/network/ipv6/ipv6_test.go +++ b/pkg/tcpip/network/ipv6/ipv6_test.go @@ -1709,3 +1709,211 @@ func TestInvalidIPv6Fragments(t *testing.T) { }) } } + +func TestWritePacketsStats(t *testing.T) { + const nPackets = 3 + tests := []struct { + name string + setup func(*testing.T, *stack.Stack) + linkEP stack.LinkEndpoint + expectSent int + }{ + { + name: "Accept all", + // No setup needed, tables accept everything by default. + setup: func(*testing.T, *stack.Stack) {}, + linkEP: &limitedEP{nPackets}, + expectSent: nPackets, + }, { + name: "Accept all with error", + // No setup needed, tables accept everything by default. + setup: func(*testing.T, *stack.Stack) {}, + linkEP: &limitedEP{nPackets - 1}, + expectSent: nPackets - 1, + }, { + name: "Drop all", + setup: func(t *testing.T, stk *stack.Stack) { + // Install Output DROP rule. + t.Helper() + ipt := stk.IPTables() + filter, ok := ipt.GetTable(stack.FilterTable, true /* ipv6 */) + if !ok { + t.Fatalf("failed to find filter table") + } + ruleIdx := filter.BuiltinChains[stack.Output] + filter.Rules[ruleIdx].Target = stack.DropTarget{} + if err := ipt.ReplaceTable(stack.FilterTable, filter, true /* ipv6 */); err != nil { + t.Fatalf("failed to replace table: %v", err) + } + }, + linkEP: &limitedEP{nPackets}, + expectSent: 0, + }, { + name: "Drop some", + setup: func(t *testing.T, stk *stack.Stack) { + // Install Output DROP rule that matches only 1 + // of the 3 packets. + t.Helper() + ipt := stk.IPTables() + filter, ok := ipt.GetTable(stack.FilterTable, true /* ipv6 */) + if !ok { + t.Fatalf("failed to find filter table") + } + // We'll match and DROP the last packet. + ruleIdx := filter.BuiltinChains[stack.Output] + filter.Rules[ruleIdx].Target = stack.DropTarget{} + filter.Rules[ruleIdx].Matchers = []stack.Matcher{&limitedMatcher{nPackets - 1}} + // Make sure the next rule is ACCEPT. + filter.Rules[ruleIdx+1].Target = stack.AcceptTarget{} + if err := ipt.ReplaceTable(stack.FilterTable, filter, true /* ipv6 */); err != nil { + t.Fatalf("failed to replace table: %v", err) + } + }, + linkEP: &limitedEP{nPackets}, + expectSent: nPackets - 1, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + rt := buildRoute(t, nil, test.linkEP) + + var pbl stack.PacketBufferList + for i := 0; i < nPackets; i++ { + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: header.UDPMinimumSize + int(rt.MaxHeaderLength()), + Data: buffer.NewView(1).ToVectorisedView(), + }) + pkt.TransportHeader().Push(header.UDPMinimumSize) + pbl.PushBack(pkt) + } + + test.setup(t, rt.Stack()) + + nWritten, err := rt.WritePackets(nil, pbl, stack.NetworkHeaderParams{}) + if err != nil { + t.Fatal(err) + } + + got := int(rt.Stats().IP.PacketsSent.Value()) + if got != test.expectSent { + t.Errorf("sent %d packets, but expected to send %d", got, test.expectSent) + } + if got != nWritten { + t.Errorf("sent %d packets, WritePackets returned %d", got, nWritten) + } + }) + } +} + +func buildRoute(t *testing.T, packetCollectorErrors []*tcpip.Error, linkEP stack.LinkEndpoint) stack.Route { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{NewProtocol()}, + }) + s.CreateNIC(1, linkEP) + const ( + src = "\xfc\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01" + dst = "\xfc\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02" + ) + s.AddAddress(1, ProtocolNumber, src) + { + subnet, err := tcpip.NewSubnet(dst, tcpip.AddressMask("\xfc\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff")) + if err != nil { + t.Fatal(err) + } + s.SetRouteTable([]tcpip.Route{{ + Destination: subnet, + NIC: 1, + }}) + } + rt, err := s.FindRoute(0, src, dst, ProtocolNumber, false /* multicastLoop */) + if err != nil { + t.Fatalf("s.FindRoute got %v, want %v", err, nil) + } + return rt +} + +// limitedEP is a link endpoint that writes up to a certain number of packets +// before returning errors. +type limitedEP struct { + limit int +} + +// MTU implements LinkEndpoint.MTU. +func (*limitedEP) MTU() uint32 { return 0 } + +// Capabilities implements LinkEndpoint.Capabilities. +func (*limitedEP) Capabilities() stack.LinkEndpointCapabilities { return 0 } + +// MaxHeaderLength implements LinkEndpoint.MaxHeaderLength. +func (*limitedEP) MaxHeaderLength() uint16 { return 0 } + +// LinkAddress implements LinkEndpoint.LinkAddress. +func (*limitedEP) LinkAddress() tcpip.LinkAddress { return "" } + +// WritePacket implements LinkEndpoint.WritePacket. +func (ep *limitedEP) WritePacket(*stack.Route, *stack.GSO, tcpip.NetworkProtocolNumber, *stack.PacketBuffer) *tcpip.Error { + if ep.limit == 0 { + return tcpip.ErrInvalidEndpointState + } + ep.limit-- + return nil +} + +// WritePackets implements LinkEndpoint.WritePackets. +func (ep *limitedEP) WritePackets(_ *stack.Route, _ *stack.GSO, pkts stack.PacketBufferList, _ tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { + if ep.limit == 0 { + return 0, tcpip.ErrInvalidEndpointState + } + nWritten := ep.limit + if nWritten > pkts.Len() { + nWritten = pkts.Len() + } + ep.limit -= nWritten + return nWritten, nil +} + +// WriteRawPacket implements LinkEndpoint.WriteRawPacket. +func (ep *limitedEP) WriteRawPacket(_ buffer.VectorisedView) *tcpip.Error { + if ep.limit == 0 { + return tcpip.ErrInvalidEndpointState + } + ep.limit-- + return nil +} + +// Attach implements LinkEndpoint.Attach. +func (*limitedEP) Attach(_ stack.NetworkDispatcher) {} + +// IsAttached implements LinkEndpoint.IsAttached. +func (*limitedEP) IsAttached() bool { return false } + +// Wait implements LinkEndpoint.Wait. +func (*limitedEP) Wait() {} + +// ARPHardwareType implements LinkEndpoint.ARPHardwareType. +func (*limitedEP) ARPHardwareType() header.ARPHardwareType { return header.ARPHardwareEther } + +// AddHeader implements LinkEndpoint.AddHeader. +func (*limitedEP) AddHeader(_, _ tcpip.LinkAddress, _ tcpip.NetworkProtocolNumber, _ *stack.PacketBuffer) { +} + +// limitedMatcher is an iptables matcher that matches after a certain number of +// packets are checked against it. +type limitedMatcher struct { + limit int +} + +// Name implements Matcher.Name. +func (*limitedMatcher) Name() string { + return "limitedMatcher" +} + +// Match implements Matcher.Match. +func (lm *limitedMatcher) Match(stack.Hook, *stack.PacketBuffer, string) (bool, bool) { + if lm.limit == 0 { + return true, false + } + lm.limit-- + return false, false +} diff --git a/pkg/tcpip/stack/iptables.go b/pkg/tcpip/stack/iptables.go index 0e33cbe92..b6ef04d32 100644 --- a/pkg/tcpip/stack/iptables.go +++ b/pkg/tcpip/stack/iptables.go @@ -57,7 +57,72 @@ const reaperDelay = 5 * time.Second // all packets. func DefaultTables() *IPTables { return &IPTables{ - tables: [numTables]Table{ + v4Tables: [numTables]Table{ + natID: Table{ + Rules: []Rule{ + Rule{Target: AcceptTarget{}}, + Rule{Target: AcceptTarget{}}, + Rule{Target: AcceptTarget{}}, + Rule{Target: AcceptTarget{}}, + Rule{Target: ErrorTarget{}}, + }, + BuiltinChains: [NumHooks]int{ + Prerouting: 0, + Input: 1, + Forward: HookUnset, + Output: 2, + Postrouting: 3, + }, + Underflows: [NumHooks]int{ + Prerouting: 0, + Input: 1, + Forward: HookUnset, + Output: 2, + Postrouting: 3, + }, + }, + mangleID: Table{ + Rules: []Rule{ + Rule{Target: AcceptTarget{}}, + Rule{Target: AcceptTarget{}}, + Rule{Target: ErrorTarget{}}, + }, + BuiltinChains: [NumHooks]int{ + Prerouting: 0, + Output: 1, + }, + Underflows: [NumHooks]int{ + Prerouting: 0, + Input: HookUnset, + Forward: HookUnset, + Output: 1, + Postrouting: HookUnset, + }, + }, + filterID: Table{ + Rules: []Rule{ + Rule{Target: AcceptTarget{}}, + Rule{Target: AcceptTarget{}}, + Rule{Target: AcceptTarget{}}, + Rule{Target: ErrorTarget{}}, + }, + BuiltinChains: [NumHooks]int{ + Prerouting: HookUnset, + Input: 0, + Forward: 1, + Output: 2, + Postrouting: HookUnset, + }, + Underflows: [NumHooks]int{ + Prerouting: HookUnset, + Input: 0, + Forward: 1, + Output: 2, + Postrouting: HookUnset, + }, + }, + }, + v6Tables: [numTables]Table{ natID: Table{ Rules: []Rule{ Rule{Target: AcceptTarget{}}, @@ -166,25 +231,20 @@ func EmptyNATTable() Table { // GetTable returns a table by name. func (it *IPTables) GetTable(name string, ipv6 bool) (Table, bool) { - // TODO(gvisor.dev/issue/3549): Enable IPv6. - if ipv6 { - return Table{}, false - } id, ok := nameToID[name] if !ok { return Table{}, false } it.mu.RLock() defer it.mu.RUnlock() - return it.tables[id], true + if ipv6 { + return it.v6Tables[id], true + } + return it.v4Tables[id], true } // ReplaceTable replaces or inserts table by name. func (it *IPTables) ReplaceTable(name string, table Table, ipv6 bool) *tcpip.Error { - // TODO(gvisor.dev/issue/3549): Enable IPv6. - if ipv6 { - return tcpip.ErrInvalidOptionValue - } id, ok := nameToID[name] if !ok { return tcpip.ErrInvalidOptionValue @@ -198,7 +258,11 @@ func (it *IPTables) ReplaceTable(name string, table Table, ipv6 bool) *tcpip.Err it.startReaper(reaperDelay) } it.modified = true - it.tables[id] = table + if ipv6 { + it.v6Tables[id] = table + } else { + it.v4Tables[id] = table + } return nil } @@ -221,8 +285,17 @@ const ( // should continue traversing the network stack and false when it should be // dropped. // +// TODO(gvisor.dev/issue/170): PacketBuffer should hold the GSO and route, from +// which address and nicName can be gathered. Currently, address is only +// needed for prerouting and nicName is only needed for output. +// +// TODO(gvisor.dev/issue/170): Dropped packets should be counted. +// // Precondition: pkt.NetworkHeader is set. -func (it *IPTables) Check(hook Hook, pkt *PacketBuffer, gso *GSO, r *Route, address tcpip.Address, nicName string) bool { +func (it *IPTables) Check(hook Hook, pkt *PacketBuffer, gso *GSO, r *Route, preroutingAddr tcpip.Address, nicName string) bool { + if pkt.NetworkProtocolNumber != header.IPv4ProtocolNumber && pkt.NetworkProtocolNumber != header.IPv6ProtocolNumber { + return true + } // Many users never configure iptables. Spare them the cost of rule // traversal if rules have never been set. it.mu.RLock() @@ -243,9 +316,14 @@ func (it *IPTables) Check(hook Hook, pkt *PacketBuffer, gso *GSO, r *Route, addr if tableID == natID && pkt.NatDone { continue } - table := it.tables[tableID] + var table Table + if pkt.NetworkProtocolNumber == header.IPv6ProtocolNumber { + table = it.v6Tables[tableID] + } else { + table = it.v4Tables[tableID] + } ruleIdx := table.BuiltinChains[hook] - switch verdict := it.checkChain(hook, pkt, table, ruleIdx, gso, r, address, nicName); verdict { + switch verdict := it.checkChain(hook, pkt, table, ruleIdx, gso, r, preroutingAddr, nicName); verdict { // If the table returns Accept, move on to the next table. case chainAccept: continue @@ -256,7 +334,7 @@ func (it *IPTables) Check(hook Hook, pkt *PacketBuffer, gso *GSO, r *Route, addr // Any Return from a built-in chain means we have to // call the underflow. underflow := table.Rules[table.Underflows[hook]] - switch v, _ := underflow.Target.Action(pkt, &it.connections, hook, gso, r, address); v { + switch v, _ := underflow.Target.Action(pkt, &it.connections, hook, gso, r, preroutingAddr); v { case RuleAccept: continue case RuleDrop: @@ -351,11 +429,11 @@ func (it *IPTables) CheckPackets(hook Hook, pkts PacketBufferList, gso *GSO, r * // Preconditions: // * pkt is a IPv4 packet of at least length header.IPv4MinimumSize. // * pkt.NetworkHeader is not nil. -func (it *IPTables) checkChain(hook Hook, pkt *PacketBuffer, table Table, ruleIdx int, gso *GSO, r *Route, address tcpip.Address, nicName string) chainVerdict { +func (it *IPTables) checkChain(hook Hook, pkt *PacketBuffer, table Table, ruleIdx int, gso *GSO, r *Route, preroutingAddr tcpip.Address, nicName string) chainVerdict { // Start from ruleIdx and walk the list of rules until a rule gives us // a verdict. for ruleIdx < len(table.Rules) { - switch verdict, jumpTo := it.checkRule(hook, pkt, table, ruleIdx, gso, r, address, nicName); verdict { + switch verdict, jumpTo := it.checkRule(hook, pkt, table, ruleIdx, gso, r, preroutingAddr, nicName); verdict { case RuleAccept: return chainAccept @@ -372,7 +450,7 @@ func (it *IPTables) checkChain(hook Hook, pkt *PacketBuffer, table Table, ruleId ruleIdx++ continue } - switch verdict := it.checkChain(hook, pkt, table, jumpTo, gso, r, address, nicName); verdict { + switch verdict := it.checkChain(hook, pkt, table, jumpTo, gso, r, preroutingAddr, nicName); verdict { case chainAccept: return chainAccept case chainDrop: @@ -398,11 +476,11 @@ func (it *IPTables) checkChain(hook Hook, pkt *PacketBuffer, table Table, ruleId // Preconditions: // * pkt is a IPv4 packet of at least length header.IPv4MinimumSize. // * pkt.NetworkHeader is not nil. -func (it *IPTables) checkRule(hook Hook, pkt *PacketBuffer, table Table, ruleIdx int, gso *GSO, r *Route, address tcpip.Address, nicName string) (RuleVerdict, int) { +func (it *IPTables) checkRule(hook Hook, pkt *PacketBuffer, table Table, ruleIdx int, gso *GSO, r *Route, preroutingAddr tcpip.Address, nicName string) (RuleVerdict, int) { rule := table.Rules[ruleIdx] // Check whether the packet matches the IP header filter. - if !rule.Filter.match(header.IPv4(pkt.NetworkHeader().View()), hook, nicName) { + if !rule.Filter.match(pkt, hook, nicName) { // Continue on to the next rule. return RuleJump, ruleIdx + 1 } @@ -421,7 +499,7 @@ func (it *IPTables) checkRule(hook Hook, pkt *PacketBuffer, table Table, ruleIdx } // All the matchers matched, so run the target. - return rule.Target.Action(pkt, &it.connections, hook, gso, r, address) + return rule.Target.Action(pkt, &it.connections, hook, gso, r, preroutingAddr) } // OriginalDst returns the original destination of redirected connections. It diff --git a/pkg/tcpip/stack/iptables_types.go b/pkg/tcpip/stack/iptables_types.go index fbbd2f50f..093ee6881 100644 --- a/pkg/tcpip/stack/iptables_types.go +++ b/pkg/tcpip/stack/iptables_types.go @@ -15,6 +15,7 @@ package stack import ( + "fmt" "strings" "sync" @@ -81,26 +82,25 @@ const ( // // +stateify savable type IPTables struct { - // mu protects tables, priorities, and modified. + // mu protects v4Tables, v6Tables, and modified. mu sync.RWMutex - - // tables maps tableIDs to tables. Holds builtin tables only, not user - // tables. mu must be locked for accessing. - tables [numTables]Table - - // priorities maps each hook to a list of table names. The order of the - // list is the order in which each table should be visited for that - // hook. mu needs to be locked for accessing. - priorities [NumHooks][]tableID - + // v4Tables and v6tables map tableIDs to tables. They hold builtin + // tables only, not user tables. mu must be locked for accessing. + v4Tables [numTables]Table + v6Tables [numTables]Table // modified is whether tables have been modified at least once. It is // used to elide the iptables performance overhead for workloads that // don't utilize iptables. modified bool + // priorities maps each hook to a list of table names. The order of the + // list is the order in which each table should be visited for that + // hook. It is immutable. + priorities [NumHooks][]tableID + connections ConnTrack - // reaperDone can be signalled to stop the reaper goroutine. + // reaperDone can be signaled to stop the reaper goroutine. reaperDone chan struct{} } @@ -148,7 +148,7 @@ type Rule struct { Target Target } -// IPHeaderFilter holds basic IP filtering data common to every rule. +// IPHeaderFilter performs basic IP header matching common to every rule. // // +stateify savable type IPHeaderFilter struct { @@ -196,16 +196,43 @@ type IPHeaderFilter struct { OutputInterfaceInvert bool } -// match returns whether hdr matches the filter. -func (fl IPHeaderFilter) match(hdr header.IPv4, hook Hook, nicName string) bool { - // TODO(gvisor.dev/issue/170): Support other fields of the filter. +// match returns whether pkt matches the filter. +// +// Preconditions: pkt.NetworkHeader is set and is at least of the minimal IPv4 +// or IPv6 header length. +func (fl IPHeaderFilter) match(pkt *PacketBuffer, hook Hook, nicName string) bool { + // Extract header fields. + var ( + // TODO(gvisor.dev/issue/170): Support other filter fields. + transProto tcpip.TransportProtocolNumber + dstAddr tcpip.Address + srcAddr tcpip.Address + ) + switch proto := pkt.NetworkProtocolNumber; proto { + case header.IPv4ProtocolNumber: + hdr := header.IPv4(pkt.NetworkHeader().View()) + transProto = hdr.TransportProtocol() + dstAddr = hdr.DestinationAddress() + srcAddr = hdr.SourceAddress() + + case header.IPv6ProtocolNumber: + hdr := header.IPv6(pkt.NetworkHeader().View()) + transProto = hdr.TransportProtocol() + dstAddr = hdr.DestinationAddress() + srcAddr = hdr.SourceAddress() + + default: + panic(fmt.Sprintf("unknown network protocol with EtherType: %d", proto)) + } + // Check the transport protocol. - if fl.Protocol != 0 && fl.Protocol != hdr.TransportProtocol() { + if fl.CheckProtocol && fl.Protocol != transProto { return false } - // Check the source and destination IPs. - if !filterAddress(hdr.DestinationAddress(), fl.DstMask, fl.Dst, fl.DstInvert) || !filterAddress(hdr.SourceAddress(), fl.SrcMask, fl.Src, fl.SrcInvert) { + // Check the addresses. + if !filterAddress(dstAddr, fl.DstMask, fl.Dst, fl.DstInvert) || + !filterAddress(srcAddr, fl.SrcMask, fl.Src, fl.SrcInvert) { return false } diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go index 1f1a1426b..821d3feb9 100644 --- a/pkg/tcpip/stack/nic.go +++ b/pkg/tcpip/stack/nic.go @@ -1282,9 +1282,8 @@ func (n *NIC) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcp return } - // TODO(gvisor.dev/issue/170): Not supporting iptables for IPv6 yet. // Loopback traffic skips the prerouting chain. - if protocol == header.IPv4ProtocolNumber && !n.isLoopback() { + if !n.isLoopback() { // iptables filtering. ipt := n.stack.IPTables() address := n.primaryAddress(protocol) diff --git a/pkg/tcpip/stack/packet_buffer.go b/pkg/tcpip/stack/packet_buffer.go index 17b8beebb..1932aaeb7 100644 --- a/pkg/tcpip/stack/packet_buffer.go +++ b/pkg/tcpip/stack/packet_buffer.go @@ -80,7 +80,7 @@ type PacketBuffer struct { // data are held in the same underlying buffer storage. header buffer.Prependable - // NetworkProtocol is only valid when NetworkHeader is set. + // NetworkProtocolNumber is only valid when NetworkHeader is set. // TODO(gvisor.dev/issue/3574): Remove the separately passed protocol // numbers in registration APIs that take a PacketBuffer. NetworkProtocolNumber tcpip.NetworkProtocolNumber diff --git a/test/iptables/iptables_test.go b/test/iptables/iptables_test.go index e2beb30d5..398f70ecd 100644 --- a/test/iptables/iptables_test.go +++ b/test/iptables/iptables_test.go @@ -48,6 +48,13 @@ func singleTest(t *testing.T, test TestCase) { } } +// TODO(gvisor.dev/issue/3549): IPv6 NAT support. +func ipv4Test(t *testing.T, test TestCase) { + t.Run("IPv4", func(t *testing.T) { + iptablesTest(t, test, false) + }) +} + func iptablesTest(t *testing.T, test TestCase, ipv6 bool) { if _, ok := Tests[test.Name()]; !ok { t.Fatalf("no test found with name %q. Has it been registered?", test.Name()) @@ -72,11 +79,6 @@ func iptablesTest(t *testing.T, test TestCase, ipv6 bool) { d.CleanUp(context.Background()) }() - // TODO(gvisor.dev/issue/170): Skipping IPv6 gVisor tests. - if ipv6 && dockerutil.Runtime() != "runc" { - t.Skip("gVisor ip6tables not yet implemented") - } - // Create and start the container. opts := dockerutil.RunOpts{ Image: "iptables", @@ -314,75 +316,75 @@ func TestInputInvertDestination(t *testing.T) { singleTest(t, FilterInputInvertDestination{}) } -func TestOutputDestination(t *testing.T) { +func TestFilterOutputDestination(t *testing.T) { singleTest(t, FilterOutputDestination{}) } -func TestOutputInvertDestination(t *testing.T) { +func TestFilterOutputInvertDestination(t *testing.T) { singleTest(t, FilterOutputInvertDestination{}) } func TestNATPreRedirectUDPPort(t *testing.T) { - singleTest(t, NATPreRedirectUDPPort{}) + ipv4Test(t, NATPreRedirectUDPPort{}) } func TestNATPreRedirectTCPPort(t *testing.T) { - singleTest(t, NATPreRedirectTCPPort{}) + ipv4Test(t, NATPreRedirectTCPPort{}) } func TestNATPreRedirectTCPOutgoing(t *testing.T) { - singleTest(t, NATPreRedirectTCPOutgoing{}) + ipv4Test(t, NATPreRedirectTCPOutgoing{}) } func TestNATOutRedirectTCPIncoming(t *testing.T) { - singleTest(t, NATOutRedirectTCPIncoming{}) + ipv4Test(t, NATOutRedirectTCPIncoming{}) } func TestNATOutRedirectUDPPort(t *testing.T) { - singleTest(t, NATOutRedirectUDPPort{}) + ipv4Test(t, NATOutRedirectUDPPort{}) } func TestNATOutRedirectTCPPort(t *testing.T) { - singleTest(t, NATOutRedirectTCPPort{}) + ipv4Test(t, NATOutRedirectTCPPort{}) } func TestNATDropUDP(t *testing.T) { - singleTest(t, NATDropUDP{}) + ipv4Test(t, NATDropUDP{}) } func TestNATAcceptAll(t *testing.T) { - singleTest(t, NATAcceptAll{}) + ipv4Test(t, NATAcceptAll{}) } func TestNATOutRedirectIP(t *testing.T) { - singleTest(t, NATOutRedirectIP{}) + ipv4Test(t, NATOutRedirectIP{}) } func TestNATOutDontRedirectIP(t *testing.T) { - singleTest(t, NATOutDontRedirectIP{}) + ipv4Test(t, NATOutDontRedirectIP{}) } func TestNATOutRedirectInvert(t *testing.T) { - singleTest(t, NATOutRedirectInvert{}) + ipv4Test(t, NATOutRedirectInvert{}) } func TestNATPreRedirectIP(t *testing.T) { - singleTest(t, NATPreRedirectIP{}) + ipv4Test(t, NATPreRedirectIP{}) } func TestNATPreDontRedirectIP(t *testing.T) { - singleTest(t, NATPreDontRedirectIP{}) + ipv4Test(t, NATPreDontRedirectIP{}) } func TestNATPreRedirectInvert(t *testing.T) { - singleTest(t, NATPreRedirectInvert{}) + ipv4Test(t, NATPreRedirectInvert{}) } func TestNATRedirectRequiresProtocol(t *testing.T) { - singleTest(t, NATRedirectRequiresProtocol{}) + ipv4Test(t, NATRedirectRequiresProtocol{}) } func TestNATLoopbackSkipsPrerouting(t *testing.T) { - singleTest(t, NATLoopbackSkipsPrerouting{}) + ipv4Test(t, NATLoopbackSkipsPrerouting{}) } func TestInputSource(t *testing.T) { @@ -419,9 +421,9 @@ func TestFilterAddrs(t *testing.T) { } func TestNATPreOriginalDst(t *testing.T) { - singleTest(t, NATPreOriginalDst{}) + ipv4Test(t, NATPreOriginalDst{}) } func TestNATOutOriginalDst(t *testing.T) { - singleTest(t, NATOutOriginalDst{}) + ipv4Test(t, NATOutOriginalDst{}) } |