summaryrefslogtreecommitdiffhomepage
path: root/pkg/tcpip
diff options
context:
space:
mode:
Diffstat (limited to 'pkg/tcpip')
-rw-r--r--pkg/tcpip/header/parse/BUILD15
-rw-r--r--pkg/tcpip/header/parse/parse.go166
-rw-r--r--pkg/tcpip/link/sniffer/BUILD1
-rw-r--r--pkg/tcpip/link/sniffer/sniffer.go60
-rw-r--r--pkg/tcpip/network/arp/BUILD1
-rw-r--r--pkg/tcpip/network/arp/arp.go7
-rw-r--r--pkg/tcpip/network/ipv4/BUILD1
-rw-r--r--pkg/tcpip/network/ipv4/ipv4.go56
-rw-r--r--pkg/tcpip/network/ipv4/ipv4_test.go249
-rw-r--r--pkg/tcpip/network/ipv6/BUILD1
-rw-r--r--pkg/tcpip/network/ipv6/ipv6.go153
-rw-r--r--pkg/tcpip/network/ipv6/ipv6_test.go247
-rw-r--r--pkg/tcpip/stack/iptables.go118
-rw-r--r--pkg/tcpip/stack/iptables_types.go65
-rw-r--r--pkg/tcpip/stack/nic.go4
-rw-r--r--pkg/tcpip/stack/packet_buffer.go2
-rw-r--r--pkg/tcpip/tcpip.go12
-rw-r--r--pkg/tcpip/transport/icmp/endpoint.go22
-rw-r--r--pkg/tcpip/transport/packet/endpoint.go23
-rw-r--r--pkg/tcpip/transport/raw/endpoint.go23
-rw-r--r--pkg/tcpip/transport/tcp/BUILD1
-rw-r--r--pkg/tcpip/transport/tcp/endpoint.go19
-rw-r--r--pkg/tcpip/transport/tcp/protocol.go18
-rw-r--r--pkg/tcpip/transport/udp/BUILD1
-rw-r--r--pkg/tcpip/transport/udp/endpoint.go13
-rw-r--r--pkg/tcpip/transport/udp/protocol.go4
26 files changed, 1069 insertions, 213 deletions
diff --git a/pkg/tcpip/header/parse/BUILD b/pkg/tcpip/header/parse/BUILD
new file mode 100644
index 000000000..2adee9288
--- /dev/null
+++ b/pkg/tcpip/header/parse/BUILD
@@ -0,0 +1,15 @@
+load("//tools:defs.bzl", "go_library")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "parse",
+ srcs = ["parse.go"],
+ visibility = ["//visibility:public"],
+ deps = [
+ "//pkg/tcpip",
+ "//pkg/tcpip/buffer",
+ "//pkg/tcpip/header",
+ "//pkg/tcpip/stack",
+ ],
+)
diff --git a/pkg/tcpip/header/parse/parse.go b/pkg/tcpip/header/parse/parse.go
new file mode 100644
index 000000000..522135557
--- /dev/null
+++ b/pkg/tcpip/header/parse/parse.go
@@ -0,0 +1,166 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package parse provides utilities to parse packets.
+package parse
+
+import (
+ "fmt"
+
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
+)
+
+// ARP populates pkt's network header with an ARP header found in
+// pkt.Data.
+//
+// Returns true if the header was successfully parsed.
+func ARP(pkt *stack.PacketBuffer) bool {
+ _, ok := pkt.NetworkHeader().Consume(header.ARPSize)
+ if ok {
+ pkt.NetworkProtocolNumber = header.ARPProtocolNumber
+ }
+ return ok
+}
+
+// IPv4 parses an IPv4 packet found in pkt.Data and populates pkt's network
+// header with the IPv4 header.
+//
+// Returns true if the header was successfully parsed.
+func IPv4(pkt *stack.PacketBuffer) bool {
+ hdr, ok := pkt.Data.PullUp(header.IPv4MinimumSize)
+ if !ok {
+ return false
+ }
+ ipHdr := header.IPv4(hdr)
+
+ // Header may have options, determine the true header length.
+ headerLen := int(ipHdr.HeaderLength())
+ if headerLen < header.IPv4MinimumSize {
+ // TODO(gvisor.dev/issue/2404): Per RFC 791, IHL needs to be at least 5 in
+ // order for the packet to be valid. Figure out if we want to reject this
+ // case.
+ headerLen = header.IPv4MinimumSize
+ }
+ hdr, ok = pkt.NetworkHeader().Consume(headerLen)
+ if !ok {
+ return false
+ }
+ ipHdr = header.IPv4(hdr)
+
+ pkt.NetworkProtocolNumber = header.IPv4ProtocolNumber
+ pkt.Data.CapLength(int(ipHdr.TotalLength()) - len(hdr))
+ return true
+}
+
+// IPv6 parses an IPv6 packet found in pkt.Data and populates pkt's network
+// header with the IPv6 header.
+func IPv6(pkt *stack.PacketBuffer) (proto tcpip.TransportProtocolNumber, fragID uint32, fragOffset uint16, fragMore bool, ok bool) {
+ hdr, ok := pkt.Data.PullUp(header.IPv6MinimumSize)
+ if !ok {
+ return 0, 0, 0, false, false
+ }
+ ipHdr := header.IPv6(hdr)
+
+ // dataClone consists of:
+ // - Any IPv6 header bytes after the first 40 (i.e. extensions).
+ // - The transport header, if present.
+ // - Any other payload data.
+ views := [8]buffer.View{}
+ dataClone := pkt.Data.Clone(views[:])
+ dataClone.TrimFront(header.IPv6MinimumSize)
+ it := header.MakeIPv6PayloadIterator(header.IPv6ExtensionHeaderIdentifier(ipHdr.NextHeader()), dataClone)
+
+ // Iterate over the IPv6 extensions to find their length.
+ var nextHdr tcpip.TransportProtocolNumber
+ var extensionsSize int
+
+traverseExtensions:
+ for {
+ extHdr, done, err := it.Next()
+ if err != nil {
+ break
+ }
+
+ // If we exhaust the extension list, the entire packet is the IPv6 header
+ // and (possibly) extensions.
+ if done {
+ extensionsSize = dataClone.Size()
+ break
+ }
+
+ switch extHdr := extHdr.(type) {
+ case header.IPv6FragmentExtHdr:
+ if fragID == 0 && fragOffset == 0 && !fragMore {
+ fragID = extHdr.ID()
+ fragOffset = extHdr.FragmentOffset()
+ fragMore = extHdr.More()
+ }
+
+ case header.IPv6RawPayloadHeader:
+ // We've found the payload after any extensions.
+ extensionsSize = dataClone.Size() - extHdr.Buf.Size()
+ nextHdr = tcpip.TransportProtocolNumber(extHdr.Identifier)
+ break traverseExtensions
+
+ default:
+ // Any other extension is a no-op, keep looping until we find the payload.
+ }
+ }
+
+ // Put the IPv6 header with extensions in pkt.NetworkHeader().
+ hdr, ok = pkt.NetworkHeader().Consume(header.IPv6MinimumSize + extensionsSize)
+ if !ok {
+ panic(fmt.Sprintf("pkt.Data should have at least %d bytes, but only has %d.", header.IPv6MinimumSize+extensionsSize, pkt.Data.Size()))
+ }
+ ipHdr = header.IPv6(hdr)
+ pkt.Data.CapLength(int(ipHdr.PayloadLength()))
+ pkt.NetworkProtocolNumber = header.IPv6ProtocolNumber
+
+ return nextHdr, fragID, fragOffset, fragMore, true
+}
+
+// UDP parses a UDP packet found in pkt.Data and populates pkt's transport
+// header with the UDP header.
+//
+// Returns true if the header was successfully parsed.
+func UDP(pkt *stack.PacketBuffer) bool {
+ _, ok := pkt.TransportHeader().Consume(header.UDPMinimumSize)
+ return ok
+}
+
+// TCP parses a TCP packet found in pkt.Data and populates pkt's transport
+// header with the TCP header.
+//
+// Returns true if the header was successfully parsed.
+func TCP(pkt *stack.PacketBuffer) bool {
+ // TCP header is variable length, peek at it first.
+ hdrLen := header.TCPMinimumSize
+ hdr, ok := pkt.Data.PullUp(hdrLen)
+ if !ok {
+ return false
+ }
+
+ // If the header has options, pull those up as well.
+ if offset := int(header.TCP(hdr).DataOffset()); offset > header.TCPMinimumSize && offset <= pkt.Data.Size() {
+ // TODO(gvisor.dev/issue/2404): Figure out whether to reject this kind of
+ // packets.
+ hdrLen = offset
+ }
+
+ _, ok = pkt.TransportHeader().Consume(hdrLen)
+ return ok
+}
diff --git a/pkg/tcpip/link/sniffer/BUILD b/pkg/tcpip/link/sniffer/BUILD
index 7cbc305e7..4aac12a8c 100644
--- a/pkg/tcpip/link/sniffer/BUILD
+++ b/pkg/tcpip/link/sniffer/BUILD
@@ -14,6 +14,7 @@ go_library(
"//pkg/tcpip",
"//pkg/tcpip/buffer",
"//pkg/tcpip/header",
+ "//pkg/tcpip/header/parse",
"//pkg/tcpip/link/nested",
"//pkg/tcpip/stack",
],
diff --git a/pkg/tcpip/link/sniffer/sniffer.go b/pkg/tcpip/link/sniffer/sniffer.go
index 4fb127978..560477926 100644
--- a/pkg/tcpip/link/sniffer/sniffer.go
+++ b/pkg/tcpip/link/sniffer/sniffer.go
@@ -31,6 +31,7 @@ import (
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/header/parse"
"gvisor.dev/gvisor/pkg/tcpip/link/nested"
"gvisor.dev/gvisor/pkg/tcpip/stack"
)
@@ -195,49 +196,52 @@ func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, pkt *stack.P
var transProto uint8
src := tcpip.Address("unknown")
dst := tcpip.Address("unknown")
- id := 0
- size := uint16(0)
+ var size uint16
+ var id uint32
var fragmentOffset uint16
var moreFragments bool
- // Examine the packet using a new VV. Backing storage must not be written.
- vv := buffer.NewVectorisedView(pkt.Size(), pkt.Views())
-
+ // Clone the packet buffer to not modify the original.
+ //
+ // We don't clone the original packet buffer so that the new packet buffer
+ // does not have any of its headers set.
+ pkt = stack.NewPacketBuffer(stack.PacketBufferOptions{Data: buffer.NewVectorisedView(pkt.Size(), pkt.Views())})
switch protocol {
case header.IPv4ProtocolNumber:
- hdr, ok := vv.PullUp(header.IPv4MinimumSize)
- if !ok {
+ if ok := parse.IPv4(pkt); !ok {
return
}
- ipv4 := header.IPv4(hdr)
+
+ ipv4 := header.IPv4(pkt.NetworkHeader().View())
fragmentOffset = ipv4.FragmentOffset()
moreFragments = ipv4.Flags()&header.IPv4FlagMoreFragments == header.IPv4FlagMoreFragments
src = ipv4.SourceAddress()
dst = ipv4.DestinationAddress()
transProto = ipv4.Protocol()
size = ipv4.TotalLength() - uint16(ipv4.HeaderLength())
- vv.TrimFront(int(ipv4.HeaderLength()))
- id = int(ipv4.ID())
+ id = uint32(ipv4.ID())
case header.IPv6ProtocolNumber:
- hdr, ok := vv.PullUp(header.IPv6MinimumSize)
+ proto, fragID, fragOffset, fragMore, ok := parse.IPv6(pkt)
if !ok {
return
}
- ipv6 := header.IPv6(hdr)
+
+ ipv6 := header.IPv6(pkt.NetworkHeader().View())
src = ipv6.SourceAddress()
dst = ipv6.DestinationAddress()
- transProto = ipv6.NextHeader()
+ transProto = uint8(proto)
size = ipv6.PayloadLength()
- vv.TrimFront(header.IPv6MinimumSize)
+ id = fragID
+ moreFragments = fragMore
+ fragmentOffset = fragOffset
case header.ARPProtocolNumber:
- hdr, ok := vv.PullUp(header.ARPSize)
- if !ok {
+ if parse.ARP(pkt) {
return
}
- vv.TrimFront(header.ARPSize)
- arp := header.ARP(hdr)
+
+ arp := header.ARP(pkt.NetworkHeader().View())
log.Infof(
"%s arp %s (%s) -> %s (%s) valid:%t",
prefix,
@@ -259,7 +263,7 @@ func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, pkt *stack.P
switch tcpip.TransportProtocolNumber(transProto) {
case header.ICMPv4ProtocolNumber:
transName = "icmp"
- hdr, ok := vv.PullUp(header.ICMPv4MinimumSize)
+ hdr, ok := pkt.Data.PullUp(header.ICMPv4MinimumSize)
if !ok {
break
}
@@ -296,7 +300,7 @@ func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, pkt *stack.P
case header.ICMPv6ProtocolNumber:
transName = "icmp"
- hdr, ok := vv.PullUp(header.ICMPv6MinimumSize)
+ hdr, ok := pkt.Data.PullUp(header.ICMPv6MinimumSize)
if !ok {
break
}
@@ -331,11 +335,11 @@ func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, pkt *stack.P
case header.UDPProtocolNumber:
transName = "udp"
- hdr, ok := vv.PullUp(header.UDPMinimumSize)
- if !ok {
+ if ok := parse.UDP(pkt); !ok {
break
}
- udp := header.UDP(hdr)
+
+ udp := header.UDP(pkt.TransportHeader().View())
if fragmentOffset == 0 {
srcPort = udp.SourcePort()
dstPort = udp.DestinationPort()
@@ -345,19 +349,19 @@ func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, pkt *stack.P
case header.TCPProtocolNumber:
transName = "tcp"
- hdr, ok := vv.PullUp(header.TCPMinimumSize)
- if !ok {
+ if ok := parse.TCP(pkt); !ok {
break
}
- tcp := header.TCP(hdr)
+
+ tcp := header.TCP(pkt.TransportHeader().View())
if fragmentOffset == 0 {
offset := int(tcp.DataOffset())
if offset < header.TCPMinimumSize {
details += fmt.Sprintf("invalid packet: tcp data offset too small %d", offset)
break
}
- if offset > vv.Size() && !moreFragments {
- details += fmt.Sprintf("invalid packet: tcp data offset %d larger than packet buffer length %d", offset, vv.Size())
+ if size := pkt.Data.Size() + len(tcp); offset > size && !moreFragments {
+ details += fmt.Sprintf("invalid packet: tcp data offset %d larger than tcp packet length %d", offset, size)
break
}
diff --git a/pkg/tcpip/network/arp/BUILD b/pkg/tcpip/network/arp/BUILD
index 82c073e32..b40dde96b 100644
--- a/pkg/tcpip/network/arp/BUILD
+++ b/pkg/tcpip/network/arp/BUILD
@@ -10,6 +10,7 @@ go_library(
"//pkg/tcpip",
"//pkg/tcpip/buffer",
"//pkg/tcpip/header",
+ "//pkg/tcpip/header/parse",
"//pkg/tcpip/stack",
],
)
diff --git a/pkg/tcpip/network/arp/arp.go b/pkg/tcpip/network/arp/arp.go
index 7aaee08c4..cb9225bd7 100644
--- a/pkg/tcpip/network/arp/arp.go
+++ b/pkg/tcpip/network/arp/arp.go
@@ -29,6 +29,7 @@ import (
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/header/parse"
"gvisor.dev/gvisor/pkg/tcpip/stack"
)
@@ -234,11 +235,7 @@ func (*protocol) Wait() {}
// Parse implements stack.NetworkProtocol.Parse.
func (*protocol) Parse(pkt *stack.PacketBuffer) (proto tcpip.TransportProtocolNumber, hasTransportHdr bool, ok bool) {
- _, ok = pkt.NetworkHeader().Consume(header.ARPSize)
- if !ok {
- return 0, false, false
- }
- return 0, false, true
+ return 0, false, parse.ARP(pkt)
}
// NewProtocol returns an ARP network protocol.
diff --git a/pkg/tcpip/network/ipv4/BUILD b/pkg/tcpip/network/ipv4/BUILD
index c82593e71..f9c2aa980 100644
--- a/pkg/tcpip/network/ipv4/BUILD
+++ b/pkg/tcpip/network/ipv4/BUILD
@@ -13,6 +13,7 @@ go_library(
"//pkg/tcpip",
"//pkg/tcpip/buffer",
"//pkg/tcpip/header",
+ "//pkg/tcpip/header/parse",
"//pkg/tcpip/network/fragmentation",
"//pkg/tcpip/network/hash",
"//pkg/tcpip/stack",
diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go
index f4394749d..b14b356d6 100644
--- a/pkg/tcpip/network/ipv4/ipv4.go
+++ b/pkg/tcpip/network/ipv4/ipv4.go
@@ -26,6 +26,7 @@ import (
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/header/parse"
"gvisor.dev/gvisor/pkg/tcpip/network/fragmentation"
"gvisor.dev/gvisor/pkg/tcpip/network/hash"
"gvisor.dev/gvisor/pkg/tcpip/stack"
@@ -235,14 +236,17 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.Netw
ipt := e.stack.IPTables()
if ok := ipt.Check(stack.Output, pkt, gso, r, "", nicName); !ok {
// iptables is telling us to drop the packet.
+ r.Stats().IP.IPTablesOutputDropped.Increment()
return nil
}
- // If the packet is manipulated as per NAT Ouput rules, handle packet
- // based on destination address and do not send the packet to link layer.
- // TODO(gvisor.dev/issue/170): We should do this for every packet, rather than
- // only NATted packets, but removing this check short circuits broadcasts
- // before they are sent out to other hosts.
+ // If the packet is manipulated as per NAT Output rules, handle packet
+ // based on destination address and do not send the packet to link
+ // layer.
+ //
+ // TODO(gvisor.dev/issue/170): We should do this for every
+ // packet, rather than only NATted packets, but removing this check
+ // short circuits broadcasts before they are sent out to other hosts.
if pkt.NatDone {
netHeader := header.IPv4(pkt.NetworkHeader().View())
ep, err := e.stack.FindNetworkEndpoint(header.IPv4ProtocolNumber, netHeader.DestinationAddress())
@@ -297,8 +301,9 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe
r.Stats().IP.PacketsSent.IncrementBy(uint64(n))
return n, err
}
+ r.Stats().IP.IPTablesOutputDropped.IncrementBy(uint64(len(dropped)))
- // Slow Path as we are dropping some packets in the batch degrade to
+ // Slow path as we are dropping some packets in the batch degrade to
// emitting one packet at a time.
n := 0
for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() {
@@ -318,12 +323,15 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe
}
if err := e.linkEP.WritePacket(r, gso, ProtocolNumber, pkt); err != nil {
r.Stats().IP.PacketsSent.IncrementBy(uint64(n))
- return n, err
+ // Dropped packets aren't errors, so include them in
+ // the return value.
+ return n + len(dropped), err
}
n++
}
r.Stats().IP.PacketsSent.IncrementBy(uint64(n))
- return n, nil
+ // Dropped packets aren't errors, so include them in the return value.
+ return n + len(dropped), nil
}
// WriteHeaderIncludedPacket writes a packet already containing a network
@@ -392,6 +400,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) {
ipt := e.stack.IPTables()
if ok := ipt.Check(stack.Input, pkt, nil, nil, "", ""); !ok {
// iptables is telling us to drop the packet.
+ r.Stats().IP.IPTablesInputDropped.Increment()
return
}
@@ -527,37 +536,14 @@ func (*protocol) Close() {}
// Wait implements stack.TransportProtocol.Wait.
func (*protocol) Wait() {}
-// Parse implements stack.TransportProtocol.Parse.
+// Parse implements stack.NetworkProtocol.Parse.
func (*protocol) Parse(pkt *stack.PacketBuffer) (proto tcpip.TransportProtocolNumber, hasTransportHdr bool, ok bool) {
- hdr, ok := pkt.Data.PullUp(header.IPv4MinimumSize)
- if !ok {
- return 0, false, false
- }
- ipHdr := header.IPv4(hdr)
-
- // Header may have options, determine the true header length.
- headerLen := int(ipHdr.HeaderLength())
- if headerLen < header.IPv4MinimumSize {
- // TODO(gvisor.dev/issue/2404): Per RFC 791, IHL needs to be at least 5 in
- // order for the packet to be valid. Figure out if we want to reject this
- // case.
- headerLen = header.IPv4MinimumSize
- }
- hdr, ok = pkt.NetworkHeader().Consume(headerLen)
- if !ok {
+ if ok := parse.IPv4(pkt); !ok {
return 0, false, false
}
- ipHdr = header.IPv4(hdr)
-
- // If this is a fragment, don't bother parsing the transport header.
- parseTransportHeader := true
- if ipHdr.More() || ipHdr.FragmentOffset() != 0 {
- parseTransportHeader = false
- }
- pkt.NetworkProtocolNumber = header.IPv4ProtocolNumber
- pkt.Data.CapLength(int(ipHdr.TotalLength()) - len(hdr))
- return ipHdr.TransportProtocol(), parseTransportHeader, true
+ ipHdr := header.IPv4(pkt.NetworkHeader().View())
+ return ipHdr.TransportProtocol(), !ipHdr.More() && ipHdr.FragmentOffset() == 0, true
}
// calculateMTU calculates the network-layer payload MTU based on the link-layer
diff --git a/pkg/tcpip/network/ipv4/ipv4_test.go b/pkg/tcpip/network/ipv4/ipv4_test.go
index 5e50558e8..b14bc98e8 100644
--- a/pkg/tcpip/network/ipv4/ipv4_test.go
+++ b/pkg/tcpip/network/ipv4/ipv4_test.go
@@ -1046,3 +1046,252 @@ func TestReceiveFragments(t *testing.T) {
})
}
}
+
+func TestWriteStats(t *testing.T) {
+ const nPackets = 3
+ tests := []struct {
+ name string
+ setup func(*testing.T, *stack.Stack)
+ linkEP func() stack.LinkEndpoint
+ expectSent int
+ expectDropped int
+ expectWritten int
+ }{
+ {
+ name: "Accept all",
+ // No setup needed, tables accept everything by default.
+ setup: func(*testing.T, *stack.Stack) {},
+ linkEP: func() stack.LinkEndpoint { return &limitedEP{nPackets} },
+ expectSent: nPackets,
+ expectDropped: 0,
+ expectWritten: nPackets,
+ }, {
+ name: "Accept all with error",
+ // No setup needed, tables accept everything by default.
+ setup: func(*testing.T, *stack.Stack) {},
+ linkEP: func() stack.LinkEndpoint { return &limitedEP{nPackets - 1} },
+ expectSent: nPackets - 1,
+ expectDropped: 0,
+ expectWritten: nPackets - 1,
+ }, {
+ name: "Drop all",
+ setup: func(t *testing.T, stk *stack.Stack) {
+ // Install Output DROP rule.
+ t.Helper()
+ ipt := stk.IPTables()
+ filter, ok := ipt.GetTable(stack.FilterTable, false /* ipv6 */)
+ if !ok {
+ t.Fatalf("failed to find filter table")
+ }
+ ruleIdx := filter.BuiltinChains[stack.Output]
+ filter.Rules[ruleIdx].Target = stack.DropTarget{}
+ if err := ipt.ReplaceTable(stack.FilterTable, filter, false /* ipv6 */); err != nil {
+ t.Fatalf("failed to replace table: %v", err)
+ }
+ },
+ linkEP: func() stack.LinkEndpoint { return &limitedEP{nPackets} },
+ expectSent: 0,
+ expectDropped: nPackets,
+ expectWritten: nPackets,
+ }, {
+ name: "Drop some",
+ setup: func(t *testing.T, stk *stack.Stack) {
+ // Install Output DROP rule that matches only 1
+ // of the 3 packets.
+ t.Helper()
+ ipt := stk.IPTables()
+ filter, ok := ipt.GetTable(stack.FilterTable, false /* ipv6 */)
+ if !ok {
+ t.Fatalf("failed to find filter table")
+ }
+ // We'll match and DROP the last packet.
+ ruleIdx := filter.BuiltinChains[stack.Output]
+ filter.Rules[ruleIdx].Target = stack.DropTarget{}
+ filter.Rules[ruleIdx].Matchers = []stack.Matcher{&limitedMatcher{nPackets - 1}}
+ // Make sure the next rule is ACCEPT.
+ filter.Rules[ruleIdx+1].Target = stack.AcceptTarget{}
+ if err := ipt.ReplaceTable(stack.FilterTable, filter, false /* ipv6 */); err != nil {
+ t.Fatalf("failed to replace table: %v", err)
+ }
+ },
+ linkEP: func() stack.LinkEndpoint { return &limitedEP{nPackets} },
+ expectSent: nPackets - 1,
+ expectDropped: 1,
+ expectWritten: nPackets,
+ },
+ }
+
+ // Parameterize the tests to run with both WritePacket and WritePackets.
+ writers := []struct {
+ name string
+ writePackets func(*stack.Route, stack.PacketBufferList) (int, *tcpip.Error)
+ }{
+ {
+ name: "WritePacket",
+ writePackets: func(rt *stack.Route, pkts stack.PacketBufferList) (int, *tcpip.Error) {
+ nWritten := 0
+ for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() {
+ if err := rt.WritePacket(nil, stack.NetworkHeaderParams{}, pkt); err != nil {
+ return nWritten, err
+ }
+ nWritten++
+ }
+ return nWritten, nil
+ },
+ }, {
+ name: "WritePackets",
+ writePackets: func(rt *stack.Route, pkts stack.PacketBufferList) (int, *tcpip.Error) {
+ return rt.WritePackets(nil, pkts, stack.NetworkHeaderParams{})
+ },
+ },
+ }
+
+ for _, writer := range writers {
+ t.Run(writer.name, func(t *testing.T) {
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ rt := buildRoute(t, nil, test.linkEP())
+
+ var pkts stack.PacketBufferList
+ for i := 0; i < nPackets; i++ {
+ pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
+ ReserveHeaderBytes: header.UDPMinimumSize + int(rt.MaxHeaderLength()),
+ Data: buffer.NewView(0).ToVectorisedView(),
+ })
+ pkt.TransportHeader().Push(header.UDPMinimumSize)
+ pkts.PushBack(pkt)
+ }
+
+ test.setup(t, rt.Stack())
+
+ nWritten, _ := writer.writePackets(&rt, pkts)
+
+ if got := int(rt.Stats().IP.PacketsSent.Value()); got != test.expectSent {
+ t.Errorf("sent %d packets, but expected to send %d", got, test.expectSent)
+ }
+ if got := int(rt.Stats().IP.IPTablesOutputDropped.Value()); got != test.expectDropped {
+ t.Errorf("dropped %d packets, but expected to drop %d", got, test.expectDropped)
+ }
+ if nWritten != test.expectWritten {
+ t.Errorf("wrote %d packets, but expected WritePackets to return %d", nWritten, test.expectWritten)
+ }
+ })
+ }
+ })
+ }
+}
+
+func buildRoute(t *testing.T, packetCollectorErrors []*tcpip.Error, linkEP stack.LinkEndpoint) stack.Route {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol()},
+ })
+ s.CreateNIC(1, linkEP)
+ const (
+ src = "\x10\x00\x00\x01"
+ dst = "\x10\x00\x00\x02"
+ )
+ s.AddAddress(1, ipv4.ProtocolNumber, src)
+ {
+ subnet, err := tcpip.NewSubnet(dst, tcpip.AddressMask(header.IPv4Broadcast))
+ if err != nil {
+ t.Fatal(err)
+ }
+ s.SetRouteTable([]tcpip.Route{{
+ Destination: subnet,
+ NIC: 1,
+ }})
+ }
+ rt, err := s.FindRoute(0, src, dst, ipv4.ProtocolNumber, false /* multicastLoop */)
+ if err != nil {
+ t.Fatalf("s.FindRoute got %v, want %v", err, nil)
+ }
+ return rt
+}
+
+// limitedEP is a link endpoint that writes up to a certain number of packets
+// before returning errors.
+type limitedEP struct {
+ limit int
+}
+
+// MTU implements LinkEndpoint.MTU.
+func (*limitedEP) MTU() uint32 {
+ // Give an MTU that won't cause fragmentation for IPv4+UDP.
+ return header.IPv4MinimumSize + header.UDPMinimumSize
+}
+
+// Capabilities implements LinkEndpoint.Capabilities.
+func (*limitedEP) Capabilities() stack.LinkEndpointCapabilities { return 0 }
+
+// MaxHeaderLength implements LinkEndpoint.MaxHeaderLength.
+func (*limitedEP) MaxHeaderLength() uint16 { return 0 }
+
+// LinkAddress implements LinkEndpoint.LinkAddress.
+func (*limitedEP) LinkAddress() tcpip.LinkAddress { return "" }
+
+// WritePacket implements LinkEndpoint.WritePacket.
+func (ep *limitedEP) WritePacket(*stack.Route, *stack.GSO, tcpip.NetworkProtocolNumber, *stack.PacketBuffer) *tcpip.Error {
+ if ep.limit == 0 {
+ return tcpip.ErrInvalidEndpointState
+ }
+ ep.limit--
+ return nil
+}
+
+// WritePackets implements LinkEndpoint.WritePackets.
+func (ep *limitedEP) WritePackets(_ *stack.Route, _ *stack.GSO, pkts stack.PacketBufferList, _ tcpip.NetworkProtocolNumber) (int, *tcpip.Error) {
+ if ep.limit == 0 {
+ return 0, tcpip.ErrInvalidEndpointState
+ }
+ nWritten := ep.limit
+ if nWritten > pkts.Len() {
+ nWritten = pkts.Len()
+ }
+ ep.limit -= nWritten
+ return nWritten, nil
+}
+
+// WriteRawPacket implements LinkEndpoint.WriteRawPacket.
+func (ep *limitedEP) WriteRawPacket(_ buffer.VectorisedView) *tcpip.Error {
+ if ep.limit == 0 {
+ return tcpip.ErrInvalidEndpointState
+ }
+ ep.limit--
+ return nil
+}
+
+// Attach implements LinkEndpoint.Attach.
+func (*limitedEP) Attach(_ stack.NetworkDispatcher) {}
+
+// IsAttached implements LinkEndpoint.IsAttached.
+func (*limitedEP) IsAttached() bool { return false }
+
+// Wait implements LinkEndpoint.Wait.
+func (*limitedEP) Wait() {}
+
+// ARPHardwareType implements LinkEndpoint.ARPHardwareType.
+func (*limitedEP) ARPHardwareType() header.ARPHardwareType { return header.ARPHardwareEther }
+
+// AddHeader implements LinkEndpoint.AddHeader.
+func (*limitedEP) AddHeader(_, _ tcpip.LinkAddress, _ tcpip.NetworkProtocolNumber, _ *stack.PacketBuffer) {
+}
+
+// limitedMatcher is an iptables matcher that matches after a certain number of
+// packets are checked against it.
+type limitedMatcher struct {
+ limit int
+}
+
+// Name implements Matcher.Name.
+func (*limitedMatcher) Name() string {
+ return "limitedMatcher"
+}
+
+// Match implements Matcher.Match.
+func (lm *limitedMatcher) Match(stack.Hook, *stack.PacketBuffer, string) (bool, bool) {
+ if lm.limit == 0 {
+ return true, false
+ }
+ lm.limit--
+ return false, false
+}
diff --git a/pkg/tcpip/network/ipv6/BUILD b/pkg/tcpip/network/ipv6/BUILD
index bcc64994e..cd5fe3ea8 100644
--- a/pkg/tcpip/network/ipv6/BUILD
+++ b/pkg/tcpip/network/ipv6/BUILD
@@ -13,6 +13,7 @@ go_library(
"//pkg/tcpip",
"//pkg/tcpip/buffer",
"//pkg/tcpip/header",
+ "//pkg/tcpip/header/parse",
"//pkg/tcpip/network/fragmentation",
"//pkg/tcpip/stack",
],
diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go
index e821a8bff..ee64d92d8 100644
--- a/pkg/tcpip/network/ipv6/ipv6.go
+++ b/pkg/tcpip/network/ipv6/ipv6.go
@@ -27,6 +27,7 @@ import (
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/header/parse"
"gvisor.dev/gvisor/pkg/tcpip/network/fragmentation"
"gvisor.dev/gvisor/pkg/tcpip/stack"
)
@@ -107,6 +108,32 @@ func (e *endpoint) addIPHeader(r *stack.Route, pkt *stack.PacketBuffer, params s
func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.NetworkHeaderParams, pkt *stack.PacketBuffer) *tcpip.Error {
e.addIPHeader(r, pkt, params)
+ // iptables filtering. All packets that reach here are locally
+ // generated.
+ nicName := e.stack.FindNICNameFromID(e.NICID())
+ ipt := e.stack.IPTables()
+ if ok := ipt.Check(stack.Output, pkt, gso, r, "", nicName); !ok {
+ // iptables is telling us to drop the packet.
+ r.Stats().IP.IPTablesOutputDropped.Increment()
+ return nil
+ }
+
+ // If the packet is manipulated as per NAT Output rules, handle packet
+ // based on destination address and do not send the packet to link
+ // layer.
+ //
+ // TODO(gvisor.dev/issue/170): We should do this for every
+ // packet, rather than only NATted packets, but removing this check
+ // short circuits broadcasts before they are sent out to other hosts.
+ if pkt.NatDone {
+ netHeader := header.IPv6(pkt.NetworkHeader().View())
+ if ep, err := e.stack.FindNetworkEndpoint(header.IPv6ProtocolNumber, netHeader.DestinationAddress()); err == nil {
+ route := r.ReverseRoute(netHeader.SourceAddress(), netHeader.DestinationAddress())
+ ep.HandlePacket(&route, pkt)
+ return nil
+ }
+ }
+
if r.Loop&stack.PacketLoop != 0 {
loopedR := r.MakeLoopedRoute()
@@ -121,8 +148,11 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.Netw
return nil
}
+ if err := e.linkEP.WritePacket(r, gso, ProtocolNumber, pkt); err != nil {
+ return err
+ }
r.Stats().IP.PacketsSent.Increment()
- return e.linkEP.WritePacket(r, gso, ProtocolNumber, pkt)
+ return nil
}
// WritePackets implements stack.LinkEndpoint.WritePackets.
@@ -138,9 +168,50 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe
e.addIPHeader(r, pb, params)
}
- n, err := e.linkEP.WritePackets(r, gso, pkts, ProtocolNumber)
+ // iptables filtering. All packets that reach here are locally
+ // generated.
+ nicName := e.stack.FindNICNameFromID(e.NICID())
+ ipt := e.stack.IPTables()
+ dropped, natPkts := ipt.CheckPackets(stack.Output, pkts, gso, r, nicName)
+ if len(dropped) == 0 && len(natPkts) == 0 {
+ // Fast path: If no packets are to be dropped then we can just invoke the
+ // faster WritePackets API directly.
+ n, err := e.linkEP.WritePackets(r, gso, pkts, ProtocolNumber)
+ r.Stats().IP.PacketsSent.IncrementBy(uint64(n))
+ return n, err
+ }
+ r.Stats().IP.IPTablesOutputDropped.IncrementBy(uint64(len(dropped)))
+
+ // Slow path as we are dropping some packets in the batch degrade to
+ // emitting one packet at a time.
+ n := 0
+ for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() {
+ if _, ok := dropped[pkt]; ok {
+ continue
+ }
+ if _, ok := natPkts[pkt]; ok {
+ netHeader := header.IPv6(pkt.NetworkHeader().View())
+ if ep, err := e.stack.FindNetworkEndpoint(header.IPv6ProtocolNumber, netHeader.DestinationAddress()); err == nil {
+ src := netHeader.SourceAddress()
+ dst := netHeader.DestinationAddress()
+ route := r.ReverseRoute(src, dst)
+ ep.HandlePacket(&route, pkt)
+ n++
+ continue
+ }
+ }
+ if err := e.linkEP.WritePacket(r, gso, ProtocolNumber, pkt); err != nil {
+ r.Stats().IP.PacketsSent.IncrementBy(uint64(n))
+ // Dropped packets aren't errors, so include them in
+ // the return value.
+ return n + len(dropped), err
+ }
+ n++
+ }
+
r.Stats().IP.PacketsSent.IncrementBy(uint64(n))
- return n, err
+ // Dropped packets aren't errors, so include them in the return value.
+ return n + len(dropped), nil
}
// WriteHeaderIncludedPacker implements stack.NetworkEndpoint. It is not yet
@@ -169,6 +240,15 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) {
it := header.MakeIPv6PayloadIterator(header.IPv6ExtensionHeaderIdentifier(h.NextHeader()), vv)
hasFragmentHeader := false
+ // iptables filtering. All packets that reach here are intended for
+ // this machine and will not be forwarded.
+ ipt := e.stack.IPTables()
+ if ok := ipt.Check(stack.Input, pkt, nil, nil, "", ""); !ok {
+ // iptables is telling us to drop the packet.
+ r.Stats().IP.IPTablesInputDropped.Increment()
+ return
+ }
+
for firstHeader := true; ; firstHeader = false {
extHdr, done, err := it.Next()
if err != nil {
@@ -504,75 +584,14 @@ func (*protocol) Close() {}
// Wait implements stack.TransportProtocol.Wait.
func (*protocol) Wait() {}
-// Parse implements stack.TransportProtocol.Parse.
+// Parse implements stack.NetworkProtocol.Parse.
func (*protocol) Parse(pkt *stack.PacketBuffer) (proto tcpip.TransportProtocolNumber, hasTransportHdr bool, ok bool) {
- hdr, ok := pkt.Data.PullUp(header.IPv6MinimumSize)
+ proto, _, fragOffset, fragMore, ok := parse.IPv6(pkt)
if !ok {
return 0, false, false
}
- ipHdr := header.IPv6(hdr)
-
- // dataClone consists of:
- // - Any IPv6 header bytes after the first 40 (i.e. extensions).
- // - The transport header, if present.
- // - Any other payload data.
- views := [8]buffer.View{}
- dataClone := pkt.Data.Clone(views[:])
- dataClone.TrimFront(header.IPv6MinimumSize)
- it := header.MakeIPv6PayloadIterator(header.IPv6ExtensionHeaderIdentifier(ipHdr.NextHeader()), dataClone)
-
- // Iterate over the IPv6 extensions to find their length.
- //
- // Parsing occurs again in HandlePacket because we don't track the
- // extensions in PacketBuffer. Unfortunately, that means HandlePacket
- // has to do the parsing work again.
- var nextHdr tcpip.TransportProtocolNumber
- foundNext := true
- extensionsSize := 0
-traverseExtensions:
- for extHdr, done, err := it.Next(); ; extHdr, done, err = it.Next() {
- if err != nil {
- break
- }
- // If we exhaust the extension list, the entire packet is the IPv6 header
- // and (possibly) extensions.
- if done {
- extensionsSize = dataClone.Size()
- foundNext = false
- break
- }
-
- switch extHdr := extHdr.(type) {
- case header.IPv6FragmentExtHdr:
- // If this is an atomic fragment, we don't have to treat it specially.
- if !extHdr.More() && extHdr.FragmentOffset() == 0 {
- continue
- }
- // This is a non-atomic fragment and has to be re-assembled before we can
- // examine the payload for a transport header.
- foundNext = false
-
- case header.IPv6RawPayloadHeader:
- // We've found the payload after any extensions.
- extensionsSize = dataClone.Size() - extHdr.Buf.Size()
- nextHdr = tcpip.TransportProtocolNumber(extHdr.Identifier)
- break traverseExtensions
-
- default:
- // Any other extension is a no-op, keep looping until we find the payload.
- }
- }
-
- // Put the IPv6 header with extensions in pkt.NetworkHeader().
- hdr, ok = pkt.NetworkHeader().Consume(header.IPv6MinimumSize + extensionsSize)
- if !ok {
- panic(fmt.Sprintf("pkt.Data should have at least %d bytes, but only has %d.", header.IPv6MinimumSize+extensionsSize, pkt.Data.Size()))
- }
- ipHdr = header.IPv6(hdr)
- pkt.Data.CapLength(int(ipHdr.PayloadLength()))
- pkt.NetworkProtocolNumber = header.IPv6ProtocolNumber
- return nextHdr, foundNext, true
+ return proto, !fragMore && fragOffset == 0, true
}
// calculateMTU calculates the network-layer payload MTU based on the link-layer
diff --git a/pkg/tcpip/network/ipv6/ipv6_test.go b/pkg/tcpip/network/ipv6/ipv6_test.go
index 5f9822c49..9eea1de8d 100644
--- a/pkg/tcpip/network/ipv6/ipv6_test.go
+++ b/pkg/tcpip/network/ipv6/ipv6_test.go
@@ -1709,3 +1709,250 @@ func TestInvalidIPv6Fragments(t *testing.T) {
})
}
}
+
+func TestWriteStats(t *testing.T) {
+ const nPackets = 3
+ tests := []struct {
+ name string
+ setup func(*testing.T, *stack.Stack)
+ linkEP func() stack.LinkEndpoint
+ expectSent int
+ expectDropped int
+ expectWritten int
+ }{
+ {
+ name: "Accept all",
+ // No setup needed, tables accept everything by default.
+ setup: func(*testing.T, *stack.Stack) {},
+ linkEP: func() stack.LinkEndpoint { return &limitedEP{nPackets} },
+ expectSent: nPackets,
+ expectDropped: 0,
+ expectWritten: nPackets,
+ }, {
+ name: "Accept all with error",
+ // No setup needed, tables accept everything by default.
+ setup: func(*testing.T, *stack.Stack) {},
+ linkEP: func() stack.LinkEndpoint { return &limitedEP{nPackets - 1} },
+ expectSent: nPackets - 1,
+ expectDropped: 0,
+ expectWritten: nPackets - 1,
+ }, {
+ name: "Drop all",
+ setup: func(t *testing.T, stk *stack.Stack) {
+ // Install Output DROP rule.
+ t.Helper()
+ ipt := stk.IPTables()
+ filter, ok := ipt.GetTable(stack.FilterTable, true /* ipv6 */)
+ if !ok {
+ t.Fatalf("failed to find filter table")
+ }
+ ruleIdx := filter.BuiltinChains[stack.Output]
+ filter.Rules[ruleIdx].Target = stack.DropTarget{}
+ if err := ipt.ReplaceTable(stack.FilterTable, filter, true /* ipv6 */); err != nil {
+ t.Fatalf("failed to replace table: %v", err)
+ }
+ },
+ linkEP: func() stack.LinkEndpoint { return &limitedEP{nPackets} },
+ expectSent: 0,
+ expectDropped: nPackets,
+ expectWritten: nPackets,
+ }, {
+ name: "Drop some",
+ setup: func(t *testing.T, stk *stack.Stack) {
+ // Install Output DROP rule that matches only 1
+ // of the 3 packets.
+ t.Helper()
+ ipt := stk.IPTables()
+ filter, ok := ipt.GetTable(stack.FilterTable, true /* ipv6 */)
+ if !ok {
+ t.Fatalf("failed to find filter table")
+ }
+ // We'll match and DROP the last packet.
+ ruleIdx := filter.BuiltinChains[stack.Output]
+ filter.Rules[ruleIdx].Target = stack.DropTarget{}
+ filter.Rules[ruleIdx].Matchers = []stack.Matcher{&limitedMatcher{nPackets - 1}}
+ // Make sure the next rule is ACCEPT.
+ filter.Rules[ruleIdx+1].Target = stack.AcceptTarget{}
+ if err := ipt.ReplaceTable(stack.FilterTable, filter, true /* ipv6 */); err != nil {
+ t.Fatalf("failed to replace table: %v", err)
+ }
+ },
+ linkEP: func() stack.LinkEndpoint { return &limitedEP{nPackets} },
+ expectSent: nPackets - 1,
+ expectDropped: 1,
+ expectWritten: nPackets,
+ },
+ }
+
+ writers := []struct {
+ name string
+ writePackets func(*stack.Route, stack.PacketBufferList) (int, *tcpip.Error)
+ }{
+ {
+ name: "WritePacket",
+ writePackets: func(rt *stack.Route, pkts stack.PacketBufferList) (int, *tcpip.Error) {
+ nWritten := 0
+ for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() {
+ if err := rt.WritePacket(nil, stack.NetworkHeaderParams{}, pkt); err != nil {
+ return nWritten, err
+ }
+ nWritten++
+ }
+ return nWritten, nil
+ },
+ }, {
+ name: "WritePackets",
+ writePackets: func(rt *stack.Route, pkts stack.PacketBufferList) (int, *tcpip.Error) {
+ return rt.WritePackets(nil, pkts, stack.NetworkHeaderParams{})
+ },
+ },
+ }
+
+ for _, writer := range writers {
+ t.Run(writer.name, func(t *testing.T) {
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ rt := buildRoute(t, nil, test.linkEP())
+
+ var pkts stack.PacketBufferList
+ for i := 0; i < nPackets; i++ {
+ pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
+ ReserveHeaderBytes: header.UDPMinimumSize + int(rt.MaxHeaderLength()),
+ Data: buffer.NewView(0).ToVectorisedView(),
+ })
+ pkt.TransportHeader().Push(header.UDPMinimumSize)
+ pkts.PushBack(pkt)
+ }
+
+ test.setup(t, rt.Stack())
+
+ nWritten, _ := writer.writePackets(&rt, pkts)
+
+ if got := int(rt.Stats().IP.PacketsSent.Value()); got != test.expectSent {
+ t.Errorf("sent %d packets, but expected to send %d", got, test.expectSent)
+ }
+ if got := int(rt.Stats().IP.IPTablesOutputDropped.Value()); got != test.expectDropped {
+ t.Errorf("dropped %d packets, but expected to drop %d", got, test.expectDropped)
+ }
+ if nWritten != test.expectWritten {
+ t.Errorf("wrote %d packets, but expected WritePackets to return %d", nWritten, test.expectWritten)
+ }
+ })
+ }
+ })
+ }
+}
+
+func buildRoute(t *testing.T, packetCollectorErrors []*tcpip.Error, linkEP stack.LinkEndpoint) stack.Route {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{NewProtocol()},
+ })
+ s.CreateNIC(1, linkEP)
+ const (
+ src = "\xfc\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01"
+ dst = "\xfc\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02"
+ )
+ s.AddAddress(1, ProtocolNumber, src)
+ {
+ subnet, err := tcpip.NewSubnet(dst, tcpip.AddressMask("\xfc\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff"))
+ if err != nil {
+ t.Fatal(err)
+ }
+ s.SetRouteTable([]tcpip.Route{{
+ Destination: subnet,
+ NIC: 1,
+ }})
+ }
+ rt, err := s.FindRoute(0, src, dst, ProtocolNumber, false /* multicastLoop */)
+ if err != nil {
+ t.Fatalf("s.FindRoute got %v, want %v", err, nil)
+ }
+ return rt
+}
+
+// limitedEP is a link endpoint that writes up to a certain number of packets
+// before returning errors.
+type limitedEP struct {
+ limit int
+}
+
+// MTU implements LinkEndpoint.MTU.
+func (*limitedEP) MTU() uint32 {
+ return header.IPv6MinimumMTU
+}
+
+// Capabilities implements LinkEndpoint.Capabilities.
+func (*limitedEP) Capabilities() stack.LinkEndpointCapabilities { return 0 }
+
+// MaxHeaderLength implements LinkEndpoint.MaxHeaderLength.
+func (*limitedEP) MaxHeaderLength() uint16 { return 0 }
+
+// LinkAddress implements LinkEndpoint.LinkAddress.
+func (*limitedEP) LinkAddress() tcpip.LinkAddress { return "" }
+
+// WritePacket implements LinkEndpoint.WritePacket.
+func (ep *limitedEP) WritePacket(*stack.Route, *stack.GSO, tcpip.NetworkProtocolNumber, *stack.PacketBuffer) *tcpip.Error {
+ if ep.limit == 0 {
+ return tcpip.ErrInvalidEndpointState
+ }
+ ep.limit--
+ return nil
+}
+
+// WritePackets implements LinkEndpoint.WritePackets.
+func (ep *limitedEP) WritePackets(_ *stack.Route, _ *stack.GSO, pkts stack.PacketBufferList, _ tcpip.NetworkProtocolNumber) (int, *tcpip.Error) {
+ if ep.limit == 0 {
+ return 0, tcpip.ErrInvalidEndpointState
+ }
+ nWritten := ep.limit
+ if nWritten > pkts.Len() {
+ nWritten = pkts.Len()
+ }
+ ep.limit -= nWritten
+ return nWritten, nil
+}
+
+// WriteRawPacket implements LinkEndpoint.WriteRawPacket.
+func (ep *limitedEP) WriteRawPacket(_ buffer.VectorisedView) *tcpip.Error {
+ if ep.limit == 0 {
+ return tcpip.ErrInvalidEndpointState
+ }
+ ep.limit--
+ return nil
+}
+
+// Attach implements LinkEndpoint.Attach.
+func (*limitedEP) Attach(_ stack.NetworkDispatcher) {}
+
+// IsAttached implements LinkEndpoint.IsAttached.
+func (*limitedEP) IsAttached() bool { return false }
+
+// Wait implements LinkEndpoint.Wait.
+func (*limitedEP) Wait() {}
+
+// ARPHardwareType implements LinkEndpoint.ARPHardwareType.
+func (*limitedEP) ARPHardwareType() header.ARPHardwareType { return header.ARPHardwareEther }
+
+// AddHeader implements LinkEndpoint.AddHeader.
+func (*limitedEP) AddHeader(_, _ tcpip.LinkAddress, _ tcpip.NetworkProtocolNumber, _ *stack.PacketBuffer) {
+}
+
+// limitedMatcher is an iptables matcher that matches after a certain number of
+// packets are checked against it.
+type limitedMatcher struct {
+ limit int
+}
+
+// Name implements Matcher.Name.
+func (*limitedMatcher) Name() string {
+ return "limitedMatcher"
+}
+
+// Match implements Matcher.Match.
+func (lm *limitedMatcher) Match(stack.Hook, *stack.PacketBuffer, string) (bool, bool) {
+ if lm.limit == 0 {
+ return true, false
+ }
+ lm.limit--
+ return false, false
+}
diff --git a/pkg/tcpip/stack/iptables.go b/pkg/tcpip/stack/iptables.go
index 0e33cbe92..4a521eca9 100644
--- a/pkg/tcpip/stack/iptables.go
+++ b/pkg/tcpip/stack/iptables.go
@@ -57,7 +57,72 @@ const reaperDelay = 5 * time.Second
// all packets.
func DefaultTables() *IPTables {
return &IPTables{
- tables: [numTables]Table{
+ v4Tables: [numTables]Table{
+ natID: Table{
+ Rules: []Rule{
+ Rule{Target: AcceptTarget{}},
+ Rule{Target: AcceptTarget{}},
+ Rule{Target: AcceptTarget{}},
+ Rule{Target: AcceptTarget{}},
+ Rule{Target: ErrorTarget{}},
+ },
+ BuiltinChains: [NumHooks]int{
+ Prerouting: 0,
+ Input: 1,
+ Forward: HookUnset,
+ Output: 2,
+ Postrouting: 3,
+ },
+ Underflows: [NumHooks]int{
+ Prerouting: 0,
+ Input: 1,
+ Forward: HookUnset,
+ Output: 2,
+ Postrouting: 3,
+ },
+ },
+ mangleID: Table{
+ Rules: []Rule{
+ Rule{Target: AcceptTarget{}},
+ Rule{Target: AcceptTarget{}},
+ Rule{Target: ErrorTarget{}},
+ },
+ BuiltinChains: [NumHooks]int{
+ Prerouting: 0,
+ Output: 1,
+ },
+ Underflows: [NumHooks]int{
+ Prerouting: 0,
+ Input: HookUnset,
+ Forward: HookUnset,
+ Output: 1,
+ Postrouting: HookUnset,
+ },
+ },
+ filterID: Table{
+ Rules: []Rule{
+ Rule{Target: AcceptTarget{}},
+ Rule{Target: AcceptTarget{}},
+ Rule{Target: AcceptTarget{}},
+ Rule{Target: ErrorTarget{}},
+ },
+ BuiltinChains: [NumHooks]int{
+ Prerouting: HookUnset,
+ Input: 0,
+ Forward: 1,
+ Output: 2,
+ Postrouting: HookUnset,
+ },
+ Underflows: [NumHooks]int{
+ Prerouting: HookUnset,
+ Input: 0,
+ Forward: 1,
+ Output: 2,
+ Postrouting: HookUnset,
+ },
+ },
+ },
+ v6Tables: [numTables]Table{
natID: Table{
Rules: []Rule{
Rule{Target: AcceptTarget{}},
@@ -166,25 +231,20 @@ func EmptyNATTable() Table {
// GetTable returns a table by name.
func (it *IPTables) GetTable(name string, ipv6 bool) (Table, bool) {
- // TODO(gvisor.dev/issue/3549): Enable IPv6.
- if ipv6 {
- return Table{}, false
- }
id, ok := nameToID[name]
if !ok {
return Table{}, false
}
it.mu.RLock()
defer it.mu.RUnlock()
- return it.tables[id], true
+ if ipv6 {
+ return it.v6Tables[id], true
+ }
+ return it.v4Tables[id], true
}
// ReplaceTable replaces or inserts table by name.
func (it *IPTables) ReplaceTable(name string, table Table, ipv6 bool) *tcpip.Error {
- // TODO(gvisor.dev/issue/3549): Enable IPv6.
- if ipv6 {
- return tcpip.ErrInvalidOptionValue
- }
id, ok := nameToID[name]
if !ok {
return tcpip.ErrInvalidOptionValue
@@ -198,7 +258,11 @@ func (it *IPTables) ReplaceTable(name string, table Table, ipv6 bool) *tcpip.Err
it.startReaper(reaperDelay)
}
it.modified = true
- it.tables[id] = table
+ if ipv6 {
+ it.v6Tables[id] = table
+ } else {
+ it.v4Tables[id] = table
+ }
return nil
}
@@ -221,8 +285,15 @@ const (
// should continue traversing the network stack and false when it should be
// dropped.
//
+// TODO(gvisor.dev/issue/170): PacketBuffer should hold the GSO and route, from
+// which address and nicName can be gathered. Currently, address is only
+// needed for prerouting and nicName is only needed for output.
+//
// Precondition: pkt.NetworkHeader is set.
-func (it *IPTables) Check(hook Hook, pkt *PacketBuffer, gso *GSO, r *Route, address tcpip.Address, nicName string) bool {
+func (it *IPTables) Check(hook Hook, pkt *PacketBuffer, gso *GSO, r *Route, preroutingAddr tcpip.Address, nicName string) bool {
+ if pkt.NetworkProtocolNumber != header.IPv4ProtocolNumber && pkt.NetworkProtocolNumber != header.IPv6ProtocolNumber {
+ return true
+ }
// Many users never configure iptables. Spare them the cost of rule
// traversal if rules have never been set.
it.mu.RLock()
@@ -243,9 +314,14 @@ func (it *IPTables) Check(hook Hook, pkt *PacketBuffer, gso *GSO, r *Route, addr
if tableID == natID && pkt.NatDone {
continue
}
- table := it.tables[tableID]
+ var table Table
+ if pkt.NetworkProtocolNumber == header.IPv6ProtocolNumber {
+ table = it.v6Tables[tableID]
+ } else {
+ table = it.v4Tables[tableID]
+ }
ruleIdx := table.BuiltinChains[hook]
- switch verdict := it.checkChain(hook, pkt, table, ruleIdx, gso, r, address, nicName); verdict {
+ switch verdict := it.checkChain(hook, pkt, table, ruleIdx, gso, r, preroutingAddr, nicName); verdict {
// If the table returns Accept, move on to the next table.
case chainAccept:
continue
@@ -256,7 +332,7 @@ func (it *IPTables) Check(hook Hook, pkt *PacketBuffer, gso *GSO, r *Route, addr
// Any Return from a built-in chain means we have to
// call the underflow.
underflow := table.Rules[table.Underflows[hook]]
- switch v, _ := underflow.Target.Action(pkt, &it.connections, hook, gso, r, address); v {
+ switch v, _ := underflow.Target.Action(pkt, &it.connections, hook, gso, r, preroutingAddr); v {
case RuleAccept:
continue
case RuleDrop:
@@ -351,11 +427,11 @@ func (it *IPTables) CheckPackets(hook Hook, pkts PacketBufferList, gso *GSO, r *
// Preconditions:
// * pkt is a IPv4 packet of at least length header.IPv4MinimumSize.
// * pkt.NetworkHeader is not nil.
-func (it *IPTables) checkChain(hook Hook, pkt *PacketBuffer, table Table, ruleIdx int, gso *GSO, r *Route, address tcpip.Address, nicName string) chainVerdict {
+func (it *IPTables) checkChain(hook Hook, pkt *PacketBuffer, table Table, ruleIdx int, gso *GSO, r *Route, preroutingAddr tcpip.Address, nicName string) chainVerdict {
// Start from ruleIdx and walk the list of rules until a rule gives us
// a verdict.
for ruleIdx < len(table.Rules) {
- switch verdict, jumpTo := it.checkRule(hook, pkt, table, ruleIdx, gso, r, address, nicName); verdict {
+ switch verdict, jumpTo := it.checkRule(hook, pkt, table, ruleIdx, gso, r, preroutingAddr, nicName); verdict {
case RuleAccept:
return chainAccept
@@ -372,7 +448,7 @@ func (it *IPTables) checkChain(hook Hook, pkt *PacketBuffer, table Table, ruleId
ruleIdx++
continue
}
- switch verdict := it.checkChain(hook, pkt, table, jumpTo, gso, r, address, nicName); verdict {
+ switch verdict := it.checkChain(hook, pkt, table, jumpTo, gso, r, preroutingAddr, nicName); verdict {
case chainAccept:
return chainAccept
case chainDrop:
@@ -398,11 +474,11 @@ func (it *IPTables) checkChain(hook Hook, pkt *PacketBuffer, table Table, ruleId
// Preconditions:
// * pkt is a IPv4 packet of at least length header.IPv4MinimumSize.
// * pkt.NetworkHeader is not nil.
-func (it *IPTables) checkRule(hook Hook, pkt *PacketBuffer, table Table, ruleIdx int, gso *GSO, r *Route, address tcpip.Address, nicName string) (RuleVerdict, int) {
+func (it *IPTables) checkRule(hook Hook, pkt *PacketBuffer, table Table, ruleIdx int, gso *GSO, r *Route, preroutingAddr tcpip.Address, nicName string) (RuleVerdict, int) {
rule := table.Rules[ruleIdx]
// Check whether the packet matches the IP header filter.
- if !rule.Filter.match(header.IPv4(pkt.NetworkHeader().View()), hook, nicName) {
+ if !rule.Filter.match(pkt, hook, nicName) {
// Continue on to the next rule.
return RuleJump, ruleIdx + 1
}
@@ -421,7 +497,7 @@ func (it *IPTables) checkRule(hook Hook, pkt *PacketBuffer, table Table, ruleIdx
}
// All the matchers matched, so run the target.
- return rule.Target.Action(pkt, &it.connections, hook, gso, r, address)
+ return rule.Target.Action(pkt, &it.connections, hook, gso, r, preroutingAddr)
}
// OriginalDst returns the original destination of redirected connections. It
diff --git a/pkg/tcpip/stack/iptables_types.go b/pkg/tcpip/stack/iptables_types.go
index fbbd2f50f..093ee6881 100644
--- a/pkg/tcpip/stack/iptables_types.go
+++ b/pkg/tcpip/stack/iptables_types.go
@@ -15,6 +15,7 @@
package stack
import (
+ "fmt"
"strings"
"sync"
@@ -81,26 +82,25 @@ const (
//
// +stateify savable
type IPTables struct {
- // mu protects tables, priorities, and modified.
+ // mu protects v4Tables, v6Tables, and modified.
mu sync.RWMutex
-
- // tables maps tableIDs to tables. Holds builtin tables only, not user
- // tables. mu must be locked for accessing.
- tables [numTables]Table
-
- // priorities maps each hook to a list of table names. The order of the
- // list is the order in which each table should be visited for that
- // hook. mu needs to be locked for accessing.
- priorities [NumHooks][]tableID
-
+ // v4Tables and v6tables map tableIDs to tables. They hold builtin
+ // tables only, not user tables. mu must be locked for accessing.
+ v4Tables [numTables]Table
+ v6Tables [numTables]Table
// modified is whether tables have been modified at least once. It is
// used to elide the iptables performance overhead for workloads that
// don't utilize iptables.
modified bool
+ // priorities maps each hook to a list of table names. The order of the
+ // list is the order in which each table should be visited for that
+ // hook. It is immutable.
+ priorities [NumHooks][]tableID
+
connections ConnTrack
- // reaperDone can be signalled to stop the reaper goroutine.
+ // reaperDone can be signaled to stop the reaper goroutine.
reaperDone chan struct{}
}
@@ -148,7 +148,7 @@ type Rule struct {
Target Target
}
-// IPHeaderFilter holds basic IP filtering data common to every rule.
+// IPHeaderFilter performs basic IP header matching common to every rule.
//
// +stateify savable
type IPHeaderFilter struct {
@@ -196,16 +196,43 @@ type IPHeaderFilter struct {
OutputInterfaceInvert bool
}
-// match returns whether hdr matches the filter.
-func (fl IPHeaderFilter) match(hdr header.IPv4, hook Hook, nicName string) bool {
- // TODO(gvisor.dev/issue/170): Support other fields of the filter.
+// match returns whether pkt matches the filter.
+//
+// Preconditions: pkt.NetworkHeader is set and is at least of the minimal IPv4
+// or IPv6 header length.
+func (fl IPHeaderFilter) match(pkt *PacketBuffer, hook Hook, nicName string) bool {
+ // Extract header fields.
+ var (
+ // TODO(gvisor.dev/issue/170): Support other filter fields.
+ transProto tcpip.TransportProtocolNumber
+ dstAddr tcpip.Address
+ srcAddr tcpip.Address
+ )
+ switch proto := pkt.NetworkProtocolNumber; proto {
+ case header.IPv4ProtocolNumber:
+ hdr := header.IPv4(pkt.NetworkHeader().View())
+ transProto = hdr.TransportProtocol()
+ dstAddr = hdr.DestinationAddress()
+ srcAddr = hdr.SourceAddress()
+
+ case header.IPv6ProtocolNumber:
+ hdr := header.IPv6(pkt.NetworkHeader().View())
+ transProto = hdr.TransportProtocol()
+ dstAddr = hdr.DestinationAddress()
+ srcAddr = hdr.SourceAddress()
+
+ default:
+ panic(fmt.Sprintf("unknown network protocol with EtherType: %d", proto))
+ }
+
// Check the transport protocol.
- if fl.Protocol != 0 && fl.Protocol != hdr.TransportProtocol() {
+ if fl.CheckProtocol && fl.Protocol != transProto {
return false
}
- // Check the source and destination IPs.
- if !filterAddress(hdr.DestinationAddress(), fl.DstMask, fl.Dst, fl.DstInvert) || !filterAddress(hdr.SourceAddress(), fl.SrcMask, fl.Src, fl.SrcInvert) {
+ // Check the addresses.
+ if !filterAddress(dstAddr, fl.DstMask, fl.Dst, fl.DstInvert) ||
+ !filterAddress(srcAddr, fl.SrcMask, fl.Src, fl.SrcInvert) {
return false
}
diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go
index 1f1a1426b..204bfc433 100644
--- a/pkg/tcpip/stack/nic.go
+++ b/pkg/tcpip/stack/nic.go
@@ -1282,14 +1282,14 @@ func (n *NIC) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcp
return
}
- // TODO(gvisor.dev/issue/170): Not supporting iptables for IPv6 yet.
// Loopback traffic skips the prerouting chain.
- if protocol == header.IPv4ProtocolNumber && !n.isLoopback() {
+ if !n.isLoopback() {
// iptables filtering.
ipt := n.stack.IPTables()
address := n.primaryAddress(protocol)
if ok := ipt.Check(Prerouting, pkt, nil, nil, address.Address, ""); !ok {
// iptables is telling us to drop the packet.
+ n.stack.stats.IP.IPTablesPreroutingDropped.Increment()
return
}
}
diff --git a/pkg/tcpip/stack/packet_buffer.go b/pkg/tcpip/stack/packet_buffer.go
index 17b8beebb..1932aaeb7 100644
--- a/pkg/tcpip/stack/packet_buffer.go
+++ b/pkg/tcpip/stack/packet_buffer.go
@@ -80,7 +80,7 @@ type PacketBuffer struct {
// data are held in the same underlying buffer storage.
header buffer.Prependable
- // NetworkProtocol is only valid when NetworkHeader is set.
+ // NetworkProtocolNumber is only valid when NetworkHeader is set.
// TODO(gvisor.dev/issue/3574): Remove the separately passed protocol
// numbers in registration APIs that take a PacketBuffer.
NetworkProtocolNumber tcpip.NetworkProtocolNumber
diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go
index b2ddb24ec..464608dee 100644
--- a/pkg/tcpip/tcpip.go
+++ b/pkg/tcpip/tcpip.go
@@ -1474,6 +1474,18 @@ type IPStats struct {
// MalformedFragmentsReceived is the total number of IP Fragments that were
// dropped due to the fragment failing validation checks.
MalformedFragmentsReceived *StatCounter
+
+ // IPTablesPreroutingDropped is the total number of IP packets dropped
+ // in the Prerouting chain.
+ IPTablesPreroutingDropped *StatCounter
+
+ // IPTablesInputDropped is the total number of IP packets dropped in
+ // the Input chain.
+ IPTablesInputDropped *StatCounter
+
+ // IPTablesOutputDropped is the total number of IP packets dropped in
+ // the Output chain.
+ IPTablesOutputDropped *StatCounter
}
// TCPStats collects TCP-specific stats.
diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go
index ad71ff3b6..31116309e 100644
--- a/pkg/tcpip/transport/icmp/endpoint.go
+++ b/pkg/tcpip/transport/icmp/endpoint.go
@@ -74,6 +74,8 @@ type endpoint struct {
route stack.Route `state:"manual"`
ttl uint8
stats tcpip.TransportEndpointStats `state:"nosave"`
+ // linger is used for SO_LINGER socket option.
+ linger tcpip.LingerOption
// owner is used to get uid and gid of the packet.
owner tcpip.PacketOwner
@@ -344,9 +346,14 @@ func (e *endpoint) Peek([][]byte) (int64, tcpip.ControlMessages, *tcpip.Error) {
// SetSockOpt sets a socket option.
func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) *tcpip.Error {
- switch opt.(type) {
+ switch v := opt.(type) {
case *tcpip.SocketDetachFilterOption:
return nil
+
+ case *tcpip.LingerOption:
+ e.mu.Lock()
+ e.linger = *v
+ e.mu.Unlock()
}
return nil
}
@@ -415,8 +422,17 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) {
}
// GetSockOpt implements tcpip.Endpoint.GetSockOpt.
-func (*endpoint) GetSockOpt(tcpip.GettableSocketOption) *tcpip.Error {
- return tcpip.ErrUnknownProtocolOption
+func (e *endpoint) GetSockOpt(opt tcpip.GettableSocketOption) *tcpip.Error {
+ switch o := opt.(type) {
+ case *tcpip.LingerOption:
+ e.mu.Lock()
+ *o = e.linger
+ e.mu.Unlock()
+ return nil
+
+ default:
+ return tcpip.ErrUnknownProtocolOption
+ }
}
func send4(r *stack.Route, ident uint16, data buffer.View, ttl uint8, owner tcpip.PacketOwner) *tcpip.Error {
diff --git a/pkg/tcpip/transport/packet/endpoint.go b/pkg/tcpip/transport/packet/endpoint.go
index 8bd4e5e37..072601d2d 100644
--- a/pkg/tcpip/transport/packet/endpoint.go
+++ b/pkg/tcpip/transport/packet/endpoint.go
@@ -83,6 +83,8 @@ type endpoint struct {
stats tcpip.TransportEndpointStats `state:"nosave"`
bound bool
boundNIC tcpip.NICID
+ // linger is used for SO_LINGER socket option.
+ linger tcpip.LingerOption
// lastErrorMu protects lastError.
lastErrorMu sync.Mutex `state:"nosave"`
@@ -298,10 +300,16 @@ func (ep *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask {
// used with SetSockOpt, and this function always returns
// tcpip.ErrNotSupported.
func (ep *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) *tcpip.Error {
- switch opt.(type) {
+ switch v := opt.(type) {
case *tcpip.SocketDetachFilterOption:
return nil
+ case *tcpip.LingerOption:
+ ep.mu.Lock()
+ ep.linger = *v
+ ep.mu.Unlock()
+ return nil
+
default:
return tcpip.ErrUnknownProtocolOption
}
@@ -366,8 +374,17 @@ func (ep *endpoint) LastError() *tcpip.Error {
}
// GetSockOpt implements tcpip.Endpoint.GetSockOpt.
-func (*endpoint) GetSockOpt(tcpip.GettableSocketOption) *tcpip.Error {
- return tcpip.ErrNotSupported
+func (ep *endpoint) GetSockOpt(opt tcpip.GettableSocketOption) *tcpip.Error {
+ switch o := opt.(type) {
+ case *tcpip.LingerOption:
+ ep.mu.Lock()
+ *o = ep.linger
+ ep.mu.Unlock()
+ return nil
+
+ default:
+ return tcpip.ErrNotSupported
+ }
}
// GetSockOptBool implements tcpip.Endpoint.GetSockOptBool.
diff --git a/pkg/tcpip/transport/raw/endpoint.go b/pkg/tcpip/transport/raw/endpoint.go
index fb03e6047..e37c00523 100644
--- a/pkg/tcpip/transport/raw/endpoint.go
+++ b/pkg/tcpip/transport/raw/endpoint.go
@@ -84,6 +84,8 @@ type endpoint struct {
// Connect(), and is valid only when conneted is true.
route stack.Route `state:"manual"`
stats tcpip.TransportEndpointStats `state:"nosave"`
+ // linger is used for SO_LINGER socket option.
+ linger tcpip.LingerOption
// owner is used to get uid and gid of the packet.
owner tcpip.PacketOwner
@@ -511,10 +513,16 @@ func (e *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask {
// SetSockOpt implements tcpip.Endpoint.SetSockOpt.
func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) *tcpip.Error {
- switch opt.(type) {
+ switch v := opt.(type) {
case *tcpip.SocketDetachFilterOption:
return nil
+ case *tcpip.LingerOption:
+ e.mu.Lock()
+ e.linger = *v
+ e.mu.Unlock()
+ return nil
+
default:
return tcpip.ErrUnknownProtocolOption
}
@@ -577,8 +585,17 @@ func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error {
}
// GetSockOpt implements tcpip.Endpoint.GetSockOpt.
-func (*endpoint) GetSockOpt(tcpip.GettableSocketOption) *tcpip.Error {
- return tcpip.ErrUnknownProtocolOption
+func (e *endpoint) GetSockOpt(opt tcpip.GettableSocketOption) *tcpip.Error {
+ switch o := opt.(type) {
+ case *tcpip.LingerOption:
+ e.mu.Lock()
+ *o = e.linger
+ e.mu.Unlock()
+ return nil
+
+ default:
+ return tcpip.ErrUnknownProtocolOption
+ }
}
// GetSockOptBool implements tcpip.Endpoint.GetSockOptBool.
diff --git a/pkg/tcpip/transport/tcp/BUILD b/pkg/tcpip/transport/tcp/BUILD
index 234fb95ce..4778e7b1c 100644
--- a/pkg/tcpip/transport/tcp/BUILD
+++ b/pkg/tcpip/transport/tcp/BUILD
@@ -69,6 +69,7 @@ go_library(
"//pkg/tcpip/buffer",
"//pkg/tcpip/hash/jenkins",
"//pkg/tcpip/header",
+ "//pkg/tcpip/header/parse",
"//pkg/tcpip/ports",
"//pkg/tcpip/seqnum",
"//pkg/tcpip/stack",
diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go
index faea7f2bb..120483838 100644
--- a/pkg/tcpip/transport/tcp/endpoint.go
+++ b/pkg/tcpip/transport/tcp/endpoint.go
@@ -1317,14 +1317,17 @@ func (e *endpoint) readLocked() (buffer.View, *tcpip.Error) {
// indicating the reason why it's not writable.
// Caller must hold e.mu and e.sndBufMu
func (e *endpoint) isEndpointWritableLocked() (int, *tcpip.Error) {
- // The endpoint cannot be written to if it's not connected.
- if !e.EndpointState().connected() {
- switch e.EndpointState() {
- case StateError:
- return 0, e.HardError
- default:
- return 0, tcpip.ErrClosedForSend
- }
+ switch s := e.EndpointState(); {
+ case s == StateError:
+ return 0, e.HardError
+ case !s.connecting() && !s.connected():
+ return 0, tcpip.ErrClosedForSend
+ case s.connecting():
+ // As per RFC793, page 56, a send request arriving when in connecting
+ // state, can be queued to be completed after the state becomes
+ // connected. Return an error code for the caller of endpoint Write to
+ // try again, until the connection handshake is complete.
+ return 0, tcpip.ErrWouldBlock
}
// Check if the connection has already been closed for sends.
diff --git a/pkg/tcpip/transport/tcp/protocol.go b/pkg/tcpip/transport/tcp/protocol.go
index 63ec12be8..74a17af79 100644
--- a/pkg/tcpip/transport/tcp/protocol.go
+++ b/pkg/tcpip/transport/tcp/protocol.go
@@ -29,6 +29,7 @@ import (
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/header/parse"
"gvisor.dev/gvisor/pkg/tcpip/seqnum"
"gvisor.dev/gvisor/pkg/tcpip/stack"
"gvisor.dev/gvisor/pkg/tcpip/transport/raw"
@@ -506,22 +507,7 @@ func (p *protocol) SynRcvdCounter() *synRcvdCounter {
// Parse implements stack.TransportProtocol.Parse.
func (*protocol) Parse(pkt *stack.PacketBuffer) bool {
- // TCP header is variable length, peek at it first.
- hdrLen := header.TCPMinimumSize
- hdr, ok := pkt.Data.PullUp(hdrLen)
- if !ok {
- return false
- }
-
- // If the header has options, pull those up as well.
- if offset := int(header.TCP(hdr).DataOffset()); offset > header.TCPMinimumSize && offset <= pkt.Data.Size() {
- // TODO(gvisor.dev/issue/2404): Figure out whether to reject this kind of
- // packets.
- hdrLen = offset
- }
-
- _, ok = pkt.TransportHeader().Consume(hdrLen)
- return ok
+ return parse.TCP(pkt)
}
// NewProtocol returns a TCP transport protocol.
diff --git a/pkg/tcpip/transport/udp/BUILD b/pkg/tcpip/transport/udp/BUILD
index b5d2d0ba6..c78549424 100644
--- a/pkg/tcpip/transport/udp/BUILD
+++ b/pkg/tcpip/transport/udp/BUILD
@@ -32,6 +32,7 @@ go_library(
"//pkg/tcpip",
"//pkg/tcpip/buffer",
"//pkg/tcpip/header",
+ "//pkg/tcpip/header/parse",
"//pkg/tcpip/ports",
"//pkg/tcpip/stack",
"//pkg/tcpip/transport/raw",
diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go
index b572c39db..518f636f0 100644
--- a/pkg/tcpip/transport/udp/endpoint.go
+++ b/pkg/tcpip/transport/udp/endpoint.go
@@ -154,6 +154,9 @@ type endpoint struct {
// owner is used to get uid and gid of the packet.
owner tcpip.PacketOwner
+
+ // linger is used for SO_LINGER socket option.
+ linger tcpip.LingerOption
}
// +stateify savable
@@ -810,6 +813,11 @@ func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) *tcpip.Error {
case *tcpip.SocketDetachFilterOption:
return nil
+
+ case *tcpip.LingerOption:
+ e.mu.Lock()
+ e.linger = *v
+ e.mu.Unlock()
}
return nil
}
@@ -966,6 +974,11 @@ func (e *endpoint) GetSockOpt(opt tcpip.GettableSocketOption) *tcpip.Error {
*o = tcpip.BindToDeviceOption(e.bindToDevice)
e.mu.RUnlock()
+ case *tcpip.LingerOption:
+ e.mu.RLock()
+ *o = e.linger
+ e.mu.RUnlock()
+
default:
return tcpip.ErrUnknownProtocolOption
}
diff --git a/pkg/tcpip/transport/udp/protocol.go b/pkg/tcpip/transport/udp/protocol.go
index 3f87e8057..7d6b91a75 100644
--- a/pkg/tcpip/transport/udp/protocol.go
+++ b/pkg/tcpip/transport/udp/protocol.go
@@ -24,6 +24,7 @@ import (
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/header/parse"
"gvisor.dev/gvisor/pkg/tcpip/stack"
"gvisor.dev/gvisor/pkg/tcpip/transport/raw"
"gvisor.dev/gvisor/pkg/waiter"
@@ -219,8 +220,7 @@ func (*protocol) Wait() {}
// Parse implements stack.TransportProtocol.Parse.
func (*protocol) Parse(pkt *stack.PacketBuffer) bool {
- _, ok := pkt.TransportHeader().Consume(header.UDPMinimumSize)
- return ok
+ return parse.UDP(pkt)
}
// NewProtocol returns a UDP transport protocol.