diff options
-rw-r--r-- | pkg/tcpip/network/ipv4/ipv4.go | 60 | ||||
-rw-r--r-- | pkg/tcpip/network/ipv6/ipv6.go | 58 | ||||
-rw-r--r-- | pkg/tcpip/tests/integration/iptables_test.go | 310 |
3 files changed, 358 insertions, 70 deletions
diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go index e1e05e39c..14cf786d2 100644 --- a/pkg/tcpip/network/ipv4/ipv4.go +++ b/pkg/tcpip/network/ipv4/ipv4.go @@ -441,47 +441,37 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe // iptables filtering. All packets that reach here are locally // generated. dropped, natPkts := e.protocol.stack.IPTables().CheckPackets(stack.Output, pkts, gso, r, "", outNicName) - if len(dropped) == 0 && len(natPkts) == 0 { - // Fast path: If no packets are to be dropped then we can just invoke the - // faster WritePackets API directly. - n, err := e.nic.WritePackets(r, gso, pkts, ProtocolNumber) - stats.PacketsSent.IncrementBy(uint64(n)) - if err != nil { - stats.OutgoingPacketErrors.IncrementBy(uint64(pkts.Len() - n)) - } - return n, err - } stats.IPTablesOutputDropped.IncrementBy(uint64(len(dropped))) + for pkt := range dropped { + pkts.Remove(pkt) + } - // Slow path as we are dropping some packets in the batch degrade to - // emitting one packet at a time. - n := 0 - for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() { - if _, ok := dropped[pkt]; ok { + // The NAT-ed packets may now be destined for us. + locallyDelivered := 0 + for pkt := range natPkts { + ep := e.protocol.stack.FindNetworkEndpoint(ProtocolNumber, header.IPv4(pkt.NetworkHeader().View()).DestinationAddress()) + if ep == nil { + // The NAT-ed packet is still destined for some remote node. continue } - if _, ok := natPkts[pkt]; ok { - netHeader := header.IPv4(pkt.NetworkHeader().View()) - if ep := e.protocol.stack.FindNetworkEndpoint(ProtocolNumber, netHeader.DestinationAddress()); ep != nil { - // Since we rewrote the packet but it is being routed back to us, we - // can safely assume the checksum is valid. - ep.(*endpoint).handleLocalPacket(pkt, true /* canSkipRXChecksum */) - n++ - continue - } - } - if err := e.nic.WritePacket(r, gso, ProtocolNumber, pkt); err != nil { - stats.PacketsSent.IncrementBy(uint64(n)) - stats.OutgoingPacketErrors.IncrementBy(uint64(pkts.Len() - n - len(dropped))) - // Dropped packets aren't errors, so include them in - // the return value. - return n + len(dropped), err - } - n++ + + // Do not send the locally destined packet out the NIC. + pkts.Remove(pkt) + + // Deliver the packet locally. + ep.(*endpoint).handleLocalPacket(pkt, true) + locallyDelivered++ + } - stats.PacketsSent.IncrementBy(uint64(n)) + + // The rest of the packets can be delivered to the NIC as a batch. + pktsLen := pkts.Len() + written, err := e.nic.WritePackets(r, gso, pkts, ProtocolNumber) + stats.PacketsSent.IncrementBy(uint64(written)) + stats.OutgoingPacketErrors.IncrementBy(uint64(pktsLen - written)) + // Dropped packets aren't errors, so include them in the return value. - return n + len(dropped), nil + return locallyDelivered + written + len(dropped), err } // WriteHeaderIncludedPacket implements stack.NetworkEndpoint. diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go index 5cad546b8..c21c587ba 100644 --- a/pkg/tcpip/network/ipv6/ipv6.go +++ b/pkg/tcpip/network/ipv6/ipv6.go @@ -742,48 +742,36 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe // generated. outNicName := e.protocol.stack.FindNICNameFromID(e.nic.ID()) dropped, natPkts := e.protocol.stack.IPTables().CheckPackets(stack.Output, pkts, gso, r, "" /* inNicName */, outNicName) - if len(dropped) == 0 && len(natPkts) == 0 { - // Fast path: If no packets are to be dropped then we can just invoke the - // faster WritePackets API directly. - n, err := e.nic.WritePackets(r, gso, pkts, ProtocolNumber) - stats.PacketsSent.IncrementBy(uint64(n)) - if err != nil { - stats.OutgoingPacketErrors.IncrementBy(uint64(pkts.Len() - n)) - } - return n, err - } stats.IPTablesOutputDropped.IncrementBy(uint64(len(dropped))) + for pkt := range dropped { + pkts.Remove(pkt) + } - // Slow path as we are dropping some packets in the batch degrade to - // emitting one packet at a time. - n := 0 - for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() { - if _, ok := dropped[pkt]; ok { + // The NAT-ed packets may now be destined for us. + locallyDelivered := 0 + for pkt := range natPkts { + ep := e.protocol.stack.FindNetworkEndpoint(ProtocolNumber, header.IPv6(pkt.NetworkHeader().View()).DestinationAddress()) + if ep == nil { + // The NAT-ed packet is still destined for some remote node. continue } - if _, ok := natPkts[pkt]; ok { - netHeader := header.IPv6(pkt.NetworkHeader().View()) - if ep := e.protocol.stack.FindNetworkEndpoint(ProtocolNumber, netHeader.DestinationAddress()); ep != nil { - // Since we rewrote the packet but it is being routed back to us, we - // can safely assume the checksum is valid. - ep.(*endpoint).handleLocalPacket(pkt, true /* canSkipRXChecksum */) - n++ - continue - } - } - if err := e.nic.WritePacket(r, gso, ProtocolNumber, pkt); err != nil { - stats.PacketsSent.IncrementBy(uint64(n)) - stats.OutgoingPacketErrors.IncrementBy(uint64(pkts.Len() - n + len(dropped))) - // Dropped packets aren't errors, so include them in - // the return value. - return n + len(dropped), err - } - n++ + + // Do not send the locally destined packet out the NIC. + pkts.Remove(pkt) + + // Deliver the packet locally. + ep.(*endpoint).handleLocalPacket(pkt, true) + locallyDelivered++ } - stats.PacketsSent.IncrementBy(uint64(n)) + // The rest of the packets can be delivered to the NIC as a batch. + pktsLen := pkts.Len() + written, err := e.nic.WritePackets(r, gso, pkts, ProtocolNumber) + stats.PacketsSent.IncrementBy(uint64(written)) + stats.OutgoingPacketErrors.IncrementBy(uint64(pktsLen - written)) + // Dropped packets aren't errors, so include them in the return value. - return n + len(dropped), nil + return locallyDelivered + written + len(dropped), err } // WriteHeaderIncludedPacket implements stack.NetworkEndpoint. diff --git a/pkg/tcpip/tests/integration/iptables_test.go b/pkg/tcpip/tests/integration/iptables_test.go index 21a8dd291..b56706357 100644 --- a/pkg/tcpip/tests/integration/iptables_test.go +++ b/pkg/tcpip/tests/integration/iptables_test.go @@ -24,6 +24,7 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" "gvisor.dev/gvisor/pkg/tcpip/stack" + "gvisor.dev/gvisor/pkg/tcpip/transport/udp" ) type inputIfNameMatcher struct { @@ -334,3 +335,312 @@ func TestIPTablesStatsForInput(t *testing.T) { }) } } + +var _ stack.LinkEndpoint = (*channelEndpointWithoutWritePacket)(nil) + +// channelEndpointWithoutWritePacket is a channel endpoint that does not support +// stack.LinkEndpoint.WritePacket. +type channelEndpointWithoutWritePacket struct { + *channel.Endpoint + + t *testing.T +} + +func (c *channelEndpointWithoutWritePacket) WritePacket(stack.RouteInfo, *stack.GSO, tcpip.NetworkProtocolNumber, *stack.PacketBuffer) tcpip.Error { + c.t.Error("unexpectedly called WritePacket; all writes should go through WritePackets") + return &tcpip.ErrNotSupported{} +} + +var _ stack.Matcher = (*udpSourcePortMatcher)(nil) + +type udpSourcePortMatcher struct { + port uint16 +} + +func (*udpSourcePortMatcher) Name() string { + return "udpSourcePortMatcher" +} + +func (m *udpSourcePortMatcher) Match(_ stack.Hook, pkt *stack.PacketBuffer, _, _ string) (matches, hotdrop bool) { + udp := header.UDP(pkt.TransportHeader().View()) + if len(udp) < header.UDPMinimumSize { + // Drop immediately as the packet is invalid. + return false, true + } + + return udp.SourcePort() == m.port, false +} + +func TestIPTableWritePackets(t *testing.T) { + const ( + nicID = 1 + + dropLocalPort = localPort - 1 + acceptPackets = 2 + dropPackets = 3 + ) + + udpHdr := func(hdr buffer.View, srcAddr, dstAddr tcpip.Address, srcPort, dstPort uint16) { + u := header.UDP(hdr) + u.Encode(&header.UDPFields{ + SrcPort: srcPort, + DstPort: dstPort, + Length: header.UDPMinimumSize, + }) + sum := header.PseudoHeaderChecksum(udp.ProtocolNumber, srcAddr, dstAddr, header.UDPMinimumSize) + sum = header.Checksum(hdr, sum) + u.SetChecksum(^u.CalculateChecksum(sum)) + } + + tests := []struct { + name string + setupFilter func(*testing.T, *stack.Stack) + genPacket func(*stack.Route) stack.PacketBufferList + proto tcpip.NetworkProtocolNumber + remoteAddr tcpip.Address + expectSent uint64 + expectOutputDropped uint64 + }{ + { + name: "IPv4 Accept", + setupFilter: func(*testing.T, *stack.Stack) { /* no filter */ }, + genPacket: func(r *stack.Route) stack.PacketBufferList { + var pkts stack.PacketBufferList + + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: int(r.MaxHeaderLength() + header.UDPMinimumSize), + }) + hdr := pkt.TransportHeader().Push(header.UDPMinimumSize) + udpHdr(hdr, r.LocalAddress, r.RemoteAddress, localPort, remotePort) + pkts.PushFront(pkt) + + return pkts + }, + proto: header.IPv4ProtocolNumber, + remoteAddr: dstAddrV4, + expectSent: 1, + expectOutputDropped: 0, + }, + { + name: "IPv4 Drop Other Port", + setupFilter: func(t *testing.T, s *stack.Stack) { + t.Helper() + + table := stack.Table{ + Rules: []stack.Rule{ + { + Target: &stack.AcceptTarget{NetworkProtocol: header.IPv4ProtocolNumber}, + }, + { + Target: &stack.AcceptTarget{NetworkProtocol: header.IPv4ProtocolNumber}, + }, + { + Matchers: []stack.Matcher{&udpSourcePortMatcher{port: dropLocalPort}}, + Target: &stack.DropTarget{NetworkProtocol: header.IPv4ProtocolNumber}, + }, + { + Target: &stack.AcceptTarget{NetworkProtocol: header.IPv4ProtocolNumber}, + }, + { + Target: &stack.ErrorTarget{NetworkProtocol: header.IPv4ProtocolNumber}, + }, + }, + BuiltinChains: [stack.NumHooks]int{ + stack.Prerouting: stack.HookUnset, + stack.Input: 0, + stack.Forward: 1, + stack.Output: 2, + stack.Postrouting: stack.HookUnset, + }, + Underflows: [stack.NumHooks]int{ + stack.Prerouting: stack.HookUnset, + stack.Input: 0, + stack.Forward: 1, + stack.Output: 2, + stack.Postrouting: stack.HookUnset, + }, + } + + if err := s.IPTables().ReplaceTable(stack.FilterID, table, false /* ipv4 */); err != nil { + t.Fatalf("RelaceTable(%d, _, false): %s", stack.FilterID, err) + } + }, + genPacket: func(r *stack.Route) stack.PacketBufferList { + var pkts stack.PacketBufferList + + for i := 0; i < acceptPackets; i++ { + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: int(r.MaxHeaderLength() + header.UDPMinimumSize), + }) + hdr := pkt.TransportHeader().Push(header.UDPMinimumSize) + udpHdr(hdr, r.LocalAddress, r.RemoteAddress, localPort, remotePort) + pkts.PushFront(pkt) + } + for i := 0; i < dropPackets; i++ { + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: int(r.MaxHeaderLength() + header.UDPMinimumSize), + }) + hdr := pkt.TransportHeader().Push(header.UDPMinimumSize) + udpHdr(hdr, r.LocalAddress, r.RemoteAddress, dropLocalPort, remotePort) + pkts.PushFront(pkt) + } + + return pkts + }, + proto: header.IPv4ProtocolNumber, + remoteAddr: dstAddrV4, + expectSent: acceptPackets, + expectOutputDropped: dropPackets, + }, + { + name: "IPv6 Accept", + setupFilter: func(*testing.T, *stack.Stack) { /* no filter */ }, + genPacket: func(r *stack.Route) stack.PacketBufferList { + var pkts stack.PacketBufferList + + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: int(r.MaxHeaderLength() + header.UDPMinimumSize), + }) + hdr := pkt.TransportHeader().Push(header.UDPMinimumSize) + udpHdr(hdr, r.LocalAddress, r.RemoteAddress, localPort, remotePort) + pkts.PushFront(pkt) + + return pkts + }, + proto: header.IPv6ProtocolNumber, + remoteAddr: dstAddrV6, + expectSent: 1, + expectOutputDropped: 0, + }, + { + name: "IPv6 Drop Other Port", + setupFilter: func(t *testing.T, s *stack.Stack) { + t.Helper() + + table := stack.Table{ + Rules: []stack.Rule{ + { + Target: &stack.AcceptTarget{NetworkProtocol: header.IPv6ProtocolNumber}, + }, + { + Target: &stack.AcceptTarget{NetworkProtocol: header.IPv6ProtocolNumber}, + }, + { + Matchers: []stack.Matcher{&udpSourcePortMatcher{port: dropLocalPort}}, + Target: &stack.DropTarget{NetworkProtocol: header.IPv6ProtocolNumber}, + }, + { + Target: &stack.AcceptTarget{NetworkProtocol: header.IPv6ProtocolNumber}, + }, + { + Target: &stack.ErrorTarget{NetworkProtocol: header.IPv6ProtocolNumber}, + }, + }, + BuiltinChains: [stack.NumHooks]int{ + stack.Prerouting: stack.HookUnset, + stack.Input: 0, + stack.Forward: 1, + stack.Output: 2, + stack.Postrouting: stack.HookUnset, + }, + Underflows: [stack.NumHooks]int{ + stack.Prerouting: stack.HookUnset, + stack.Input: 0, + stack.Forward: 1, + stack.Output: 2, + stack.Postrouting: stack.HookUnset, + }, + } + + if err := s.IPTables().ReplaceTable(stack.FilterID, table, true /* ipv6 */); err != nil { + t.Fatalf("RelaceTable(%d, _, true): %s", stack.FilterID, err) + } + }, + genPacket: func(r *stack.Route) stack.PacketBufferList { + var pkts stack.PacketBufferList + + for i := 0; i < acceptPackets; i++ { + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: int(r.MaxHeaderLength() + header.UDPMinimumSize), + }) + hdr := pkt.TransportHeader().Push(header.UDPMinimumSize) + udpHdr(hdr, r.LocalAddress, r.RemoteAddress, localPort, remotePort) + pkts.PushFront(pkt) + } + for i := 0; i < dropPackets; i++ { + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: int(r.MaxHeaderLength() + header.UDPMinimumSize), + }) + hdr := pkt.TransportHeader().Push(header.UDPMinimumSize) + udpHdr(hdr, r.LocalAddress, r.RemoteAddress, dropLocalPort, remotePort) + pkts.PushFront(pkt) + } + + return pkts + }, + proto: header.IPv6ProtocolNumber, + remoteAddr: dstAddrV6, + expectSent: acceptPackets, + expectOutputDropped: dropPackets, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}, + }) + e := channelEndpointWithoutWritePacket{ + Endpoint: channel.New(4, header.IPv6MinimumMTU, linkAddr), + t: t, + } + if err := s.CreateNIC(nicID, &e); err != nil { + t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) + } + if err := s.AddAddress(nicID, header.IPv6ProtocolNumber, srcAddrV6); err != nil { + t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, srcAddrV6, err) + } + if err := s.AddAddress(nicID, header.IPv4ProtocolNumber, srcAddrV4); err != nil { + t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv4ProtocolNumber, srcAddrV4, err) + } + + s.SetRouteTable([]tcpip.Route{ + { + Destination: header.IPv4EmptySubnet, + NIC: nicID, + }, + { + Destination: header.IPv6EmptySubnet, + NIC: nicID, + }, + }) + + test.setupFilter(t, s) + + r, err := s.FindRoute(nicID, "", test.remoteAddr, test.proto, false) + if err != nil { + t.Fatalf("FindRoute(%d, '', %s, %d, false): %s", nicID, test.remoteAddr, test.proto, err) + } + defer r.Release() + + pkts := test.genPacket(r) + pktsLen := pkts.Len() + if n, err := r.WritePackets(nil /* gso */, pkts, stack.NetworkHeaderParams{ + Protocol: header.UDPProtocolNumber, + TTL: 64, + }); err != nil { + t.Fatalf("WritePackets(...): %s", err) + } else if n != pktsLen { + t.Fatalf("got WritePackets(...) = %d, want = %d", n, pktsLen) + } + + if got := s.Stats().IP.PacketsSent.Value(); got != test.expectSent { + t.Errorf("got PacketSent = %d, want = %d", got, test.expectSent) + } + if got := s.Stats().IP.IPTablesOutputDropped.Value(); got != test.expectOutputDropped { + t.Errorf("got IPTablesOutputDropped = %d, want = %d", got, test.expectOutputDropped) + } + }) + } +} |