summaryrefslogtreecommitdiffhomepage
path: root/pkg/tcpip/network
diff options
context:
space:
mode:
authorKevin Krakauer <krakauer@google.com>2020-09-17 21:52:54 -0700
committergVisor bot <gvisor-bot@google.com>2020-09-17 21:54:48 -0700
commit0b8d306e64f89e0d63a558f69a846d36beeea51d (patch)
treef38fdb8ab4a7016a3cbf406fbd01ec8ddd0fd4cc /pkg/tcpip/network
parentd34bda027309695e3e6fb6f92a5839cd1f21173e (diff)
ip6tables: filter table support
`ip6tables -t filter` is now usable. NAT support will come in a future CL. #3549 PiperOrigin-RevId: 332381801
Diffstat (limited to 'pkg/tcpip/network')
-rw-r--r--pkg/tcpip/network/ipv4/ipv4.go14
-rw-r--r--pkg/tcpip/network/ipv6/ipv6.go74
-rw-r--r--pkg/tcpip/network/ipv6/ipv6_test.go208
3 files changed, 288 insertions, 8 deletions
diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go
index f4394749d..a75c4cdda 100644
--- a/pkg/tcpip/network/ipv4/ipv4.go
+++ b/pkg/tcpip/network/ipv4/ipv4.go
@@ -238,11 +238,13 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.Netw
return nil
}
- // If the packet is manipulated as per NAT Ouput rules, handle packet
- // based on destination address and do not send the packet to link layer.
- // TODO(gvisor.dev/issue/170): We should do this for every packet, rather than
- // only NATted packets, but removing this check short circuits broadcasts
- // before they are sent out to other hosts.
+ // If the packet is manipulated as per NAT Output rules, handle packet
+ // based on destination address and do not send the packet to link
+ // layer.
+ //
+ // TODO(gvisor.dev/issue/170): We should do this for every
+ // packet, rather than only NATted packets, but removing this check
+ // short circuits broadcasts before they are sent out to other hosts.
if pkt.NatDone {
netHeader := header.IPv4(pkt.NetworkHeader().View())
ep, err := e.stack.FindNetworkEndpoint(header.IPv4ProtocolNumber, netHeader.DestinationAddress())
@@ -298,7 +300,7 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe
return n, err
}
- // Slow Path as we are dropping some packets in the batch degrade to
+ // Slow path as we are dropping some packets in the batch degrade to
// emitting one packet at a time.
n := 0
for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() {
diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go
index e821a8bff..fc8dfea42 100644
--- a/pkg/tcpip/network/ipv6/ipv6.go
+++ b/pkg/tcpip/network/ipv6/ipv6.go
@@ -107,6 +107,31 @@ func (e *endpoint) addIPHeader(r *stack.Route, pkt *stack.PacketBuffer, params s
func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.NetworkHeaderParams, pkt *stack.PacketBuffer) *tcpip.Error {
e.addIPHeader(r, pkt, params)
+ // iptables filtering. All packets that reach here are locally
+ // generated.
+ nicName := e.stack.FindNICNameFromID(e.NICID())
+ ipt := e.stack.IPTables()
+ if ok := ipt.Check(stack.Output, pkt, gso, r, "", nicName); !ok {
+ // iptables is telling us to drop the packet.
+ return nil
+ }
+
+ // If the packet is manipulated as per NAT Output rules, handle packet
+ // based on destination address and do not send the packet to link
+ // layer.
+ //
+ // TODO(gvisor.dev/issue/170): We should do this for every
+ // packet, rather than only NATted packets, but removing this check
+ // short circuits broadcasts before they are sent out to other hosts.
+ if pkt.NatDone {
+ netHeader := header.IPv6(pkt.NetworkHeader().View())
+ if ep, err := e.stack.FindNetworkEndpoint(header.IPv6ProtocolNumber, netHeader.DestinationAddress()); err == nil {
+ route := r.ReverseRoute(netHeader.SourceAddress(), netHeader.DestinationAddress())
+ ep.HandlePacket(&route, pkt)
+ return nil
+ }
+ }
+
if r.Loop&stack.PacketLoop != 0 {
loopedR := r.MakeLoopedRoute()
@@ -138,9 +163,46 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe
e.addIPHeader(r, pb, params)
}
- n, err := e.linkEP.WritePackets(r, gso, pkts, ProtocolNumber)
+ // iptables filtering. All packets that reach here are locally
+ // generated.
+ nicName := e.stack.FindNICNameFromID(e.NICID())
+ ipt := e.stack.IPTables()
+ dropped, natPkts := ipt.CheckPackets(stack.Output, pkts, gso, r, nicName)
+ if len(dropped) == 0 && len(natPkts) == 0 {
+ // Fast path: If no packets are to be dropped then we can just invoke the
+ // faster WritePackets API directly.
+ n, err := e.linkEP.WritePackets(r, gso, pkts, ProtocolNumber)
+ r.Stats().IP.PacketsSent.IncrementBy(uint64(n))
+ return n, err
+ }
+
+ // Slow path as we are dropping some packets in the batch degrade to
+ // emitting one packet at a time.
+ n := 0
+ for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() {
+ if _, ok := dropped[pkt]; ok {
+ continue
+ }
+ if _, ok := natPkts[pkt]; ok {
+ netHeader := header.IPv6(pkt.NetworkHeader().View())
+ if ep, err := e.stack.FindNetworkEndpoint(header.IPv6ProtocolNumber, netHeader.DestinationAddress()); err == nil {
+ src := netHeader.SourceAddress()
+ dst := netHeader.DestinationAddress()
+ route := r.ReverseRoute(src, dst)
+ ep.HandlePacket(&route, pkt)
+ n++
+ continue
+ }
+ }
+ if err := e.linkEP.WritePacket(r, gso, ProtocolNumber, pkt); err != nil {
+ r.Stats().IP.PacketsSent.IncrementBy(uint64(n))
+ return n, err
+ }
+ n++
+ }
+
r.Stats().IP.PacketsSent.IncrementBy(uint64(n))
- return n, err
+ return n, nil
}
// WriteHeaderIncludedPacker implements stack.NetworkEndpoint. It is not yet
@@ -169,6 +231,14 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) {
it := header.MakeIPv6PayloadIterator(header.IPv6ExtensionHeaderIdentifier(h.NextHeader()), vv)
hasFragmentHeader := false
+ // iptables filtering. All packets that reach here are intended for
+ // this machine and will not be forwarded.
+ ipt := e.stack.IPTables()
+ if ok := ipt.Check(stack.Input, pkt, nil, nil, "", ""); !ok {
+ // iptables is telling us to drop the packet.
+ return
+ }
+
for firstHeader := true; ; firstHeader = false {
extHdr, done, err := it.Next()
if err != nil {
diff --git a/pkg/tcpip/network/ipv6/ipv6_test.go b/pkg/tcpip/network/ipv6/ipv6_test.go
index 5f9822c49..354d3b60d 100644
--- a/pkg/tcpip/network/ipv6/ipv6_test.go
+++ b/pkg/tcpip/network/ipv6/ipv6_test.go
@@ -1709,3 +1709,211 @@ func TestInvalidIPv6Fragments(t *testing.T) {
})
}
}
+
+func TestWritePacketsStats(t *testing.T) {
+ const nPackets = 3
+ tests := []struct {
+ name string
+ setup func(*testing.T, *stack.Stack)
+ linkEP stack.LinkEndpoint
+ expectSent int
+ }{
+ {
+ name: "Accept all",
+ // No setup needed, tables accept everything by default.
+ setup: func(*testing.T, *stack.Stack) {},
+ linkEP: &limitedEP{nPackets},
+ expectSent: nPackets,
+ }, {
+ name: "Accept all with error",
+ // No setup needed, tables accept everything by default.
+ setup: func(*testing.T, *stack.Stack) {},
+ linkEP: &limitedEP{nPackets - 1},
+ expectSent: nPackets - 1,
+ }, {
+ name: "Drop all",
+ setup: func(t *testing.T, stk *stack.Stack) {
+ // Install Output DROP rule.
+ t.Helper()
+ ipt := stk.IPTables()
+ filter, ok := ipt.GetTable(stack.FilterTable, true /* ipv6 */)
+ if !ok {
+ t.Fatalf("failed to find filter table")
+ }
+ ruleIdx := filter.BuiltinChains[stack.Output]
+ filter.Rules[ruleIdx].Target = stack.DropTarget{}
+ if err := ipt.ReplaceTable(stack.FilterTable, filter, true /* ipv6 */); err != nil {
+ t.Fatalf("failed to replace table: %v", err)
+ }
+ },
+ linkEP: &limitedEP{nPackets},
+ expectSent: 0,
+ }, {
+ name: "Drop some",
+ setup: func(t *testing.T, stk *stack.Stack) {
+ // Install Output DROP rule that matches only 1
+ // of the 3 packets.
+ t.Helper()
+ ipt := stk.IPTables()
+ filter, ok := ipt.GetTable(stack.FilterTable, true /* ipv6 */)
+ if !ok {
+ t.Fatalf("failed to find filter table")
+ }
+ // We'll match and DROP the last packet.
+ ruleIdx := filter.BuiltinChains[stack.Output]
+ filter.Rules[ruleIdx].Target = stack.DropTarget{}
+ filter.Rules[ruleIdx].Matchers = []stack.Matcher{&limitedMatcher{nPackets - 1}}
+ // Make sure the next rule is ACCEPT.
+ filter.Rules[ruleIdx+1].Target = stack.AcceptTarget{}
+ if err := ipt.ReplaceTable(stack.FilterTable, filter, true /* ipv6 */); err != nil {
+ t.Fatalf("failed to replace table: %v", err)
+ }
+ },
+ linkEP: &limitedEP{nPackets},
+ expectSent: nPackets - 1,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ rt := buildRoute(t, nil, test.linkEP)
+
+ var pbl stack.PacketBufferList
+ for i := 0; i < nPackets; i++ {
+ pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
+ ReserveHeaderBytes: header.UDPMinimumSize + int(rt.MaxHeaderLength()),
+ Data: buffer.NewView(1).ToVectorisedView(),
+ })
+ pkt.TransportHeader().Push(header.UDPMinimumSize)
+ pbl.PushBack(pkt)
+ }
+
+ test.setup(t, rt.Stack())
+
+ nWritten, err := rt.WritePackets(nil, pbl, stack.NetworkHeaderParams{})
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ got := int(rt.Stats().IP.PacketsSent.Value())
+ if got != test.expectSent {
+ t.Errorf("sent %d packets, but expected to send %d", got, test.expectSent)
+ }
+ if got != nWritten {
+ t.Errorf("sent %d packets, WritePackets returned %d", got, nWritten)
+ }
+ })
+ }
+}
+
+func buildRoute(t *testing.T, packetCollectorErrors []*tcpip.Error, linkEP stack.LinkEndpoint) stack.Route {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{NewProtocol()},
+ })
+ s.CreateNIC(1, linkEP)
+ const (
+ src = "\xfc\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01"
+ dst = "\xfc\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02"
+ )
+ s.AddAddress(1, ProtocolNumber, src)
+ {
+ subnet, err := tcpip.NewSubnet(dst, tcpip.AddressMask("\xfc\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff"))
+ if err != nil {
+ t.Fatal(err)
+ }
+ s.SetRouteTable([]tcpip.Route{{
+ Destination: subnet,
+ NIC: 1,
+ }})
+ }
+ rt, err := s.FindRoute(0, src, dst, ProtocolNumber, false /* multicastLoop */)
+ if err != nil {
+ t.Fatalf("s.FindRoute got %v, want %v", err, nil)
+ }
+ return rt
+}
+
+// limitedEP is a link endpoint that writes up to a certain number of packets
+// before returning errors.
+type limitedEP struct {
+ limit int
+}
+
+// MTU implements LinkEndpoint.MTU.
+func (*limitedEP) MTU() uint32 { return 0 }
+
+// Capabilities implements LinkEndpoint.Capabilities.
+func (*limitedEP) Capabilities() stack.LinkEndpointCapabilities { return 0 }
+
+// MaxHeaderLength implements LinkEndpoint.MaxHeaderLength.
+func (*limitedEP) MaxHeaderLength() uint16 { return 0 }
+
+// LinkAddress implements LinkEndpoint.LinkAddress.
+func (*limitedEP) LinkAddress() tcpip.LinkAddress { return "" }
+
+// WritePacket implements LinkEndpoint.WritePacket.
+func (ep *limitedEP) WritePacket(*stack.Route, *stack.GSO, tcpip.NetworkProtocolNumber, *stack.PacketBuffer) *tcpip.Error {
+ if ep.limit == 0 {
+ return tcpip.ErrInvalidEndpointState
+ }
+ ep.limit--
+ return nil
+}
+
+// WritePackets implements LinkEndpoint.WritePackets.
+func (ep *limitedEP) WritePackets(_ *stack.Route, _ *stack.GSO, pkts stack.PacketBufferList, _ tcpip.NetworkProtocolNumber) (int, *tcpip.Error) {
+ if ep.limit == 0 {
+ return 0, tcpip.ErrInvalidEndpointState
+ }
+ nWritten := ep.limit
+ if nWritten > pkts.Len() {
+ nWritten = pkts.Len()
+ }
+ ep.limit -= nWritten
+ return nWritten, nil
+}
+
+// WriteRawPacket implements LinkEndpoint.WriteRawPacket.
+func (ep *limitedEP) WriteRawPacket(_ buffer.VectorisedView) *tcpip.Error {
+ if ep.limit == 0 {
+ return tcpip.ErrInvalidEndpointState
+ }
+ ep.limit--
+ return nil
+}
+
+// Attach implements LinkEndpoint.Attach.
+func (*limitedEP) Attach(_ stack.NetworkDispatcher) {}
+
+// IsAttached implements LinkEndpoint.IsAttached.
+func (*limitedEP) IsAttached() bool { return false }
+
+// Wait implements LinkEndpoint.Wait.
+func (*limitedEP) Wait() {}
+
+// ARPHardwareType implements LinkEndpoint.ARPHardwareType.
+func (*limitedEP) ARPHardwareType() header.ARPHardwareType { return header.ARPHardwareEther }
+
+// AddHeader implements LinkEndpoint.AddHeader.
+func (*limitedEP) AddHeader(_, _ tcpip.LinkAddress, _ tcpip.NetworkProtocolNumber, _ *stack.PacketBuffer) {
+}
+
+// limitedMatcher is an iptables matcher that matches after a certain number of
+// packets are checked against it.
+type limitedMatcher struct {
+ limit int
+}
+
+// Name implements Matcher.Name.
+func (*limitedMatcher) Name() string {
+ return "limitedMatcher"
+}
+
+// Match implements Matcher.Match.
+func (lm *limitedMatcher) Match(stack.Hook, *stack.PacketBuffer, string) (bool, bool) {
+ if lm.limit == 0 {
+ return true, false
+ }
+ lm.limit--
+ return false, false
+}