summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--pkg/sentry/socket/netfilter/netfilter.go14
-rw-r--r--pkg/sentry/socket/netfilter/targets.go3
-rw-r--r--pkg/sentry/socket/netfilter/tcp_matcher.go17
-rw-r--r--pkg/sentry/socket/netfilter/udp_matcher.go17
-rw-r--r--pkg/tcpip/header/tcp.go17
-rw-r--r--pkg/tcpip/network/ipv4/ipv4.go46
-rw-r--r--pkg/tcpip/stack/BUILD2
-rw-r--r--pkg/tcpip/stack/conntrack.go480
-rw-r--r--pkg/tcpip/stack/iptables.go51
-rw-r--r--pkg/tcpip/stack/iptables_targets.go115
-rw-r--r--pkg/tcpip/stack/iptables_types.go4
-rw-r--r--pkg/tcpip/stack/nic.go4
-rw-r--r--pkg/tcpip/stack/packet_buffer.go4
-rw-r--r--pkg/tcpip/stack/route.go13
-rw-r--r--pkg/tcpip/stack/stack.go19
-rw-r--r--pkg/tcpip/transport/tcp/BUILD2
-rw-r--r--pkg/tcpip/transport/tcp/rcv.go19
-rw-r--r--pkg/tcpip/transport/tcp/rcv_test.go6
-rw-r--r--pkg/tcpip/transport/tcpconntrack/BUILD1
-rw-r--r--pkg/tcpip/transport/tcpconntrack/tcp_conntrack.go16
-rw-r--r--test/iptables/filter_output.go2
-rw-r--r--test/iptables/iptables_test.go66
-rw-r--r--test/iptables/iptables_util.go2
-rw-r--r--test/iptables/nat.go103
24 files changed, 858 insertions, 165 deletions
diff --git a/pkg/sentry/socket/netfilter/netfilter.go b/pkg/sentry/socket/netfilter/netfilter.go
index 72d093aa8..40736fb38 100644
--- a/pkg/sentry/socket/netfilter/netfilter.go
+++ b/pkg/sentry/socket/netfilter/netfilter.go
@@ -251,7 +251,7 @@ func marshalTarget(target stack.Target) []byte {
case stack.ReturnTarget:
return marshalStandardTarget(stack.RuleReturn)
case stack.RedirectTarget:
- return marshalRedirectTarget()
+ return marshalRedirectTarget(tg)
case JumpTarget:
return marshalJumpTarget(tg)
default:
@@ -288,7 +288,7 @@ func marshalErrorTarget(errorName string) []byte {
return binary.Marshal(ret, usermem.ByteOrder, target)
}
-func marshalRedirectTarget() []byte {
+func marshalRedirectTarget(rt stack.RedirectTarget) []byte {
// This is a redirect target named redirect
target := linux.XTRedirectTarget{
Target: linux.XTEntryTarget{
@@ -298,6 +298,16 @@ func marshalRedirectTarget() []byte {
copy(target.Target.Name[:], redirectTargetName)
ret := make([]byte, 0, linux.SizeOfXTRedirectTarget)
+ target.NfRange.RangeSize = 1
+ if rt.RangeProtoSpecified {
+ target.NfRange.RangeIPV4.Flags |= linux.NF_NAT_RANGE_PROTO_SPECIFIED
+ }
+ // Convert port from little endian to big endian.
+ port := make([]byte, 2)
+ binary.LittleEndian.PutUint16(port, rt.MinPort)
+ target.NfRange.RangeIPV4.MinPort = binary.BigEndian.Uint16(port)
+ binary.LittleEndian.PutUint16(port, rt.MaxPort)
+ target.NfRange.RangeIPV4.MaxPort = binary.BigEndian.Uint16(port)
return binary.Marshal(ret, usermem.ByteOrder, target)
}
diff --git a/pkg/sentry/socket/netfilter/targets.go b/pkg/sentry/socket/netfilter/targets.go
index c948de876..84abe8d29 100644
--- a/pkg/sentry/socket/netfilter/targets.go
+++ b/pkg/sentry/socket/netfilter/targets.go
@@ -15,6 +15,7 @@
package netfilter
import (
+ "gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/stack"
)
@@ -29,6 +30,6 @@ type JumpTarget struct {
}
// Action implements stack.Target.Action.
-func (jt JumpTarget) Action(stack.PacketBuffer) (stack.RuleVerdict, int) {
+func (jt JumpTarget) Action(*stack.PacketBuffer, *stack.ConnTrackTable, stack.Hook, *stack.GSO, *stack.Route, tcpip.Address) (stack.RuleVerdict, int) {
return stack.RuleJump, jt.RuleNum
}
diff --git a/pkg/sentry/socket/netfilter/tcp_matcher.go b/pkg/sentry/socket/netfilter/tcp_matcher.go
index 55c0f04f3..57a1e1c12 100644
--- a/pkg/sentry/socket/netfilter/tcp_matcher.go
+++ b/pkg/sentry/socket/netfilter/tcp_matcher.go
@@ -120,14 +120,27 @@ func (tm *TCPMatcher) Match(hook stack.Hook, pkt stack.PacketBuffer, interfaceNa
if pkt.TransportHeader != nil {
tcpHeader = header.TCP(pkt.TransportHeader)
} else {
+ var length int
+ if hook == stack.Prerouting {
+ // The network header hasn't been parsed yet. We have to do it here.
+ hdr, ok := pkt.Data.PullUp(header.IPv4MinimumSize)
+ if !ok {
+ // There's no valid TCP header here, so we hotdrop the
+ // packet.
+ return false, true
+ }
+ h := header.IPv4(hdr)
+ pkt.NetworkHeader = hdr
+ length = int(h.HeaderLength())
+ }
// The TCP header hasn't been parsed yet. We have to do it here.
- hdr, ok := pkt.Data.PullUp(header.TCPMinimumSize)
+ hdr, ok := pkt.Data.PullUp(length + header.TCPMinimumSize)
if !ok {
// There's no valid TCP header here, so we hotdrop the
// packet.
return false, true
}
- tcpHeader = header.TCP(hdr)
+ tcpHeader = header.TCP(hdr[length:])
}
// Check whether the source and destination ports are within the
diff --git a/pkg/sentry/socket/netfilter/udp_matcher.go b/pkg/sentry/socket/netfilter/udp_matcher.go
index 04d03d494..cfa9e621d 100644
--- a/pkg/sentry/socket/netfilter/udp_matcher.go
+++ b/pkg/sentry/socket/netfilter/udp_matcher.go
@@ -119,14 +119,27 @@ func (um *UDPMatcher) Match(hook stack.Hook, pkt stack.PacketBuffer, interfaceNa
if pkt.TransportHeader != nil {
udpHeader = header.UDP(pkt.TransportHeader)
} else {
+ var length int
+ if hook == stack.Prerouting {
+ // The network header hasn't been parsed yet. We have to do it here.
+ hdr, ok := pkt.Data.PullUp(header.IPv4MinimumSize)
+ if !ok {
+ // There's no valid UDP header here, so we hotdrop the
+ // packet.
+ return false, true
+ }
+ h := header.IPv4(hdr)
+ pkt.NetworkHeader = hdr
+ length = int(h.HeaderLength())
+ }
// The UDP header hasn't been parsed yet. We have to do it here.
- hdr, ok := pkt.Data.PullUp(header.UDPMinimumSize)
+ hdr, ok := pkt.Data.PullUp(length + header.UDPMinimumSize)
if !ok {
// There's no valid UDP header here, so we hotdrop the
// packet.
return false, true
}
- udpHeader = header.UDP(hdr)
+ udpHeader = header.UDP(hdr[length:])
}
// Check whether the source and destination ports are within the
diff --git a/pkg/tcpip/header/tcp.go b/pkg/tcpip/header/tcp.go
index 13480687d..21581257b 100644
--- a/pkg/tcpip/header/tcp.go
+++ b/pkg/tcpip/header/tcp.go
@@ -594,3 +594,20 @@ func AddTCPOptionPadding(options []byte, offset int) int {
}
return paddingToAdd
}
+
+// Acceptable checks if a segment that starts at segSeq and has length segLen is
+// "acceptable" for arriving in a receive window that starts at rcvNxt and ends
+// before rcvAcc, according to the table on page 26 and 69 of RFC 793.
+func Acceptable(segSeq seqnum.Value, segLen seqnum.Size, rcvNxt, rcvAcc seqnum.Value) bool {
+ if rcvNxt == rcvAcc {
+ return segLen == 0 && segSeq == rcvNxt
+ }
+ if segLen == 0 {
+ // rcvWnd is incremented by 1 because that is Linux's behavior despite the
+ // RFC.
+ return segSeq.InRange(rcvNxt, rcvAcc.Add(1))
+ }
+ // Page 70 of RFC 793 allows packets that can be made "acceptable" by trimming
+ // the payload, so we'll accept any payload that overlaps the receieve window.
+ return rcvNxt.LessThan(segSeq.Add(segLen)) && segSeq.LessThan(rcvAcc)
+}
diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go
index 1d61fddad..9db42b2a4 100644
--- a/pkg/tcpip/network/ipv4/ipv4.go
+++ b/pkg/tcpip/network/ipv4/ipv4.go
@@ -252,11 +252,31 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.Netw
// iptables filtering. All packets that reach here are locally
// generated.
ipt := e.stack.IPTables()
- if ok := ipt.Check(stack.Output, pkt); !ok {
+ if ok := ipt.Check(stack.Output, &pkt, gso, r, ""); !ok {
// iptables is telling us to drop the packet.
return nil
}
+ if pkt.NatDone {
+ // If the packet is manipulated as per NAT Ouput rules, handle packet
+ // based on destination address and do not send the packet to link layer.
+ netHeader := header.IPv4(pkt.NetworkHeader)
+ ep, err := e.stack.FindNetworkEndpoint(header.IPv4ProtocolNumber, netHeader.DestinationAddress())
+ if err == nil {
+ src := netHeader.SourceAddress()
+ dst := netHeader.DestinationAddress()
+ route := r.ReverseRoute(src, dst)
+
+ views := make([]buffer.View, 1, 1+len(pkt.Data.Views()))
+ views[0] = pkt.Header.View()
+ views = append(views, pkt.Data.Views()...)
+ packet := stack.PacketBuffer{
+ Data: buffer.NewVectorisedView(len(views[0])+pkt.Data.Size(), views)}
+ ep.HandlePacket(&route, packet)
+ return nil
+ }
+ }
+
if r.Loop&stack.PacketLoop != 0 {
// The inbound path expects the network header to still be in
// the PacketBuffer's Data field.
@@ -302,8 +322,8 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe
// iptables filtering. All packets that reach here are locally
// generated.
ipt := e.stack.IPTables()
- dropped := ipt.CheckPackets(stack.Output, pkts)
- if len(dropped) == 0 {
+ dropped, natPkts := ipt.CheckPackets(stack.Output, pkts, gso, r)
+ if len(dropped) == 0 && len(natPkts) == 0 {
// Fast path: If no packets are to be dropped then we can just invoke the
// faster WritePackets API directly.
n, err := e.linkEP.WritePackets(r, gso, pkts, ProtocolNumber)
@@ -318,6 +338,24 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe
if _, ok := dropped[pkt]; ok {
continue
}
+ if _, ok := natPkts[pkt]; ok {
+ netHeader := header.IPv4(pkt.NetworkHeader)
+ ep, err := e.stack.FindNetworkEndpoint(header.IPv4ProtocolNumber, netHeader.DestinationAddress())
+ if err == nil {
+ src := netHeader.SourceAddress()
+ dst := netHeader.DestinationAddress()
+ route := r.ReverseRoute(src, dst)
+
+ views := make([]buffer.View, 1, 1+len(pkt.Data.Views()))
+ views[0] = pkt.Header.View()
+ views = append(views, pkt.Data.Views()...)
+ packet := stack.PacketBuffer{
+ Data: buffer.NewVectorisedView(len(views[0])+pkt.Data.Size(), views)}
+ ep.HandlePacket(&route, packet)
+ n++
+ continue
+ }
+ }
if err := e.linkEP.WritePacket(r, gso, ProtocolNumber, *pkt); err != nil {
r.Stats().IP.PacketsSent.IncrementBy(uint64(n))
return n, err
@@ -407,7 +445,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt stack.PacketBuffer) {
// iptables filtering. All packets that reach here are intended for
// this machine and will not be forwarded.
ipt := e.stack.IPTables()
- if ok := ipt.Check(stack.Input, pkt); !ok {
+ if ok := ipt.Check(stack.Input, &pkt, nil, nil, ""); !ok {
// iptables is telling us to drop the packet.
return
}
diff --git a/pkg/tcpip/stack/BUILD b/pkg/tcpip/stack/BUILD
index 5e963a4af..f71073207 100644
--- a/pkg/tcpip/stack/BUILD
+++ b/pkg/tcpip/stack/BUILD
@@ -30,6 +30,7 @@ go_template_instance(
go_library(
name = "stack",
srcs = [
+ "conntrack.go",
"dhcpv6configurationfromndpra_string.go",
"forwarder.go",
"icmp_rate_limit.go",
@@ -62,6 +63,7 @@ go_library(
"//pkg/tcpip/header",
"//pkg/tcpip/ports",
"//pkg/tcpip/seqnum",
+ "//pkg/tcpip/transport/tcpconntrack",
"//pkg/waiter",
"@org_golang_x_time//rate:go_default_library",
],
diff --git a/pkg/tcpip/stack/conntrack.go b/pkg/tcpip/stack/conntrack.go
new file mode 100644
index 000000000..7d1ede1f2
--- /dev/null
+++ b/pkg/tcpip/stack/conntrack.go
@@ -0,0 +1,480 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package stack
+
+import (
+ "encoding/binary"
+ "sync"
+ "time"
+
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/hash/jenkins"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/transport/tcpconntrack"
+)
+
+// Connection tracking is used to track and manipulate packets for NAT rules.
+// The connection is created for a packet if it does not exist. Every connection
+// contains two tuples (original and reply). The tuples are manipulated if there
+// is a matching NAT rule. The packet is modified by looking at the tuples in the
+// Prerouting and Output hooks.
+
+// Direction of the tuple.
+type ctDirection int
+
+const (
+ dirOriginal ctDirection = iota
+ dirReply
+)
+
+// Status of connection.
+// TODO(gvisor.dev/issue/170): Add other states of connection.
+type connStatus int
+
+const (
+ connNew connStatus = iota
+ connEstablished
+)
+
+// Manipulation type for the connection.
+type manipType int
+
+const (
+ manipDstPrerouting manipType = iota
+ manipDstOutput
+)
+
+// connTrackMutable is the manipulatable part of the tuple.
+type connTrackMutable struct {
+ // addr is source address of the tuple.
+ addr tcpip.Address
+
+ // port is source port of the tuple.
+ port uint16
+
+ // protocol is network layer protocol.
+ protocol tcpip.NetworkProtocolNumber
+}
+
+// connTrackImmutable is the non-manipulatable part of the tuple.
+type connTrackImmutable struct {
+ // addr is destination address of the tuple.
+ addr tcpip.Address
+
+ // direction is direction (original or reply) of the tuple.
+ direction ctDirection
+
+ // port is destination port of the tuple.
+ port uint16
+
+ // protocol is transport layer protocol.
+ protocol tcpip.TransportProtocolNumber
+}
+
+// connTrackTuple represents the tuple which is created from the
+// packet.
+type connTrackTuple struct {
+ // dst is non-manipulatable part of the tuple.
+ dst connTrackImmutable
+
+ // src is manipulatable part of the tuple.
+ src connTrackMutable
+}
+
+// connTrackTupleHolder is the container of tuple and connection.
+type ConnTrackTupleHolder struct {
+ // conn is pointer to the connection tracking entry.
+ conn *connTrack
+
+ // tuple is original or reply tuple.
+ tuple connTrackTuple
+}
+
+// connTrack is the connection.
+type connTrack struct {
+ // originalTupleHolder contains tuple in original direction.
+ originalTupleHolder ConnTrackTupleHolder
+
+ // replyTupleHolder contains tuple in reply direction.
+ replyTupleHolder ConnTrackTupleHolder
+
+ // status indicates connection is new or established.
+ status connStatus
+
+ // timeout indicates the time connection should be active.
+ timeout time.Duration
+
+ // manip indicates if the packet should be manipulated.
+ manip manipType
+
+ // tcb is TCB control block. It is used to keep track of states
+ // of tcp connection.
+ tcb tcpconntrack.TCB
+
+ // tcbHook indicates if the packet is inbound or outbound to
+ // update the state of tcb.
+ tcbHook Hook
+}
+
+// ConnTrackTable contains a map of all existing connections created for
+// NAT rules.
+type ConnTrackTable struct {
+ // connMu protects connTrackTable.
+ connMu sync.RWMutex
+
+ // connTrackTable maintains a map of tuples needed for connection tracking
+ // for iptables NAT rules. The key for the map is an integer calculated
+ // using seed, source address, destination address, source port and
+ // destination port.
+ CtMap map[uint32]ConnTrackTupleHolder
+
+ // seed is a one-time random value initialized at stack startup
+ // and is used in calculation of hash key for connection tracking
+ // table.
+ Seed uint32
+}
+
+// parseHeaders sets headers in the packet.
+func parseHeaders(pkt *PacketBuffer) {
+ newPkt := pkt.Clone()
+
+ // Set network header.
+ hdr, ok := newPkt.Data.PullUp(header.IPv4MinimumSize)
+ if !ok {
+ return
+ }
+ netHeader := header.IPv4(hdr)
+ newPkt.NetworkHeader = hdr
+ length := int(netHeader.HeaderLength())
+
+ // TODO(gvisor.dev/issue/170): Need to support for other
+ // protocols as well.
+ // Set transport header.
+ switch protocol := netHeader.TransportProtocol(); protocol {
+ case header.UDPProtocolNumber:
+ if newPkt.TransportHeader == nil {
+ h, ok := newPkt.Data.PullUp(length + header.UDPMinimumSize)
+ if !ok {
+ return
+ }
+ newPkt.TransportHeader = buffer.View(header.UDP(h[length:]))
+ }
+ case header.TCPProtocolNumber:
+ if newPkt.TransportHeader == nil {
+ h, ok := newPkt.Data.PullUp(length + header.TCPMinimumSize)
+ if !ok {
+ return
+ }
+ newPkt.TransportHeader = buffer.View(header.TCP(h[length:]))
+ }
+ }
+ pkt.NetworkHeader = newPkt.NetworkHeader
+ pkt.TransportHeader = newPkt.TransportHeader
+}
+
+// packetToTuple converts packet to a tuple in original direction.
+func packetToTuple(pkt PacketBuffer, hook Hook) (connTrackTuple, *tcpip.Error) {
+ var tuple connTrackTuple
+
+ netHeader := header.IPv4(pkt.NetworkHeader)
+ // TODO(gvisor.dev/issue/170): Need to support for other
+ // protocols as well.
+ if netHeader == nil || netHeader.TransportProtocol() != header.TCPProtocolNumber {
+ return tuple, tcpip.ErrUnknownProtocol
+ }
+ tcpHeader := header.TCP(pkt.TransportHeader)
+ if tcpHeader == nil {
+ return tuple, tcpip.ErrUnknownProtocol
+ }
+
+ tuple.src.addr = netHeader.SourceAddress()
+ tuple.src.port = tcpHeader.SourcePort()
+ tuple.src.protocol = header.IPv4ProtocolNumber
+
+ tuple.dst.addr = netHeader.DestinationAddress()
+ tuple.dst.port = tcpHeader.DestinationPort()
+ tuple.dst.protocol = netHeader.TransportProtocol()
+
+ return tuple, nil
+}
+
+// getReplyTuple creates reply tuple for the given tuple.
+func getReplyTuple(tuple connTrackTuple) connTrackTuple {
+ var replyTuple connTrackTuple
+ replyTuple.src.addr = tuple.dst.addr
+ replyTuple.src.port = tuple.dst.port
+ replyTuple.src.protocol = tuple.src.protocol
+ replyTuple.dst.addr = tuple.src.addr
+ replyTuple.dst.port = tuple.src.port
+ replyTuple.dst.protocol = tuple.dst.protocol
+ replyTuple.dst.direction = dirReply
+
+ return replyTuple
+}
+
+// makeNewConn creates new connection.
+func makeNewConn(tuple, replyTuple connTrackTuple) connTrack {
+ var conn connTrack
+ conn.status = connNew
+ conn.originalTupleHolder.tuple = tuple
+ conn.originalTupleHolder.conn = &conn
+ conn.replyTupleHolder.tuple = replyTuple
+ conn.replyTupleHolder.conn = &conn
+
+ return conn
+}
+
+// getTupleHash returns hash of the tuple. The fields used for
+// generating hash are seed (generated once for stack), source address,
+// destination address, source port and destination ports.
+func (ct *ConnTrackTable) getTupleHash(tuple connTrackTuple) uint32 {
+ h := jenkins.Sum32(ct.Seed)
+ h.Write([]byte(tuple.src.addr))
+ h.Write([]byte(tuple.dst.addr))
+ portBuf := make([]byte, 2)
+ binary.LittleEndian.PutUint16(portBuf, tuple.src.port)
+ h.Write([]byte(portBuf))
+ binary.LittleEndian.PutUint16(portBuf, tuple.dst.port)
+ h.Write([]byte(portBuf))
+
+ return h.Sum32()
+}
+
+// connTrackForPacket returns connTrack for packet.
+// TODO(gvisor.dev/issue/170): Only TCP packets are supported. Need to support other
+// transport protocols.
+func (ct *ConnTrackTable) connTrackForPacket(pkt *PacketBuffer, hook Hook, createConn bool) (*connTrack, ctDirection) {
+ if hook == Prerouting {
+ // Headers will not be set in Prerouting.
+ // TODO(gvisor.dev/issue/170): Change this after parsing headers
+ // code is added.
+ parseHeaders(pkt)
+ }
+
+ var dir ctDirection
+ tuple, err := packetToTuple(*pkt, hook)
+ if err != nil {
+ return nil, dir
+ }
+
+ ct.connMu.Lock()
+ defer ct.connMu.Unlock()
+
+ connTrackTable := ct.CtMap
+ hash := ct.getTupleHash(tuple)
+
+ var conn *connTrack
+ switch createConn {
+ case true:
+ // If connection does not exist for the hash, create a new
+ // connection.
+ replyTuple := getReplyTuple(tuple)
+ replyHash := ct.getTupleHash(replyTuple)
+ newConn := makeNewConn(tuple, replyTuple)
+ conn = &newConn
+
+ // Add tupleHolders to the map.
+ // TODO(gvisor.dev/issue/170): Need to support collisions using linked list.
+ ct.CtMap[hash] = conn.originalTupleHolder
+ ct.CtMap[replyHash] = conn.replyTupleHolder
+ default:
+ tupleHolder, ok := connTrackTable[hash]
+ if !ok {
+ return nil, dir
+ }
+
+ // If this is the reply of new connection, set the connection
+ // status as ESTABLISHED.
+ conn = tupleHolder.conn
+ if conn.status == connNew && tupleHolder.tuple.dst.direction == dirReply {
+ conn.status = connEstablished
+ }
+ if tupleHolder.conn == nil {
+ panic("tupleHolder has null connection tracking entry")
+ }
+
+ dir = tupleHolder.tuple.dst.direction
+ }
+ return conn, dir
+}
+
+// SetNatInfo will manipulate the tuples according to iptables NAT rules.
+func (ct *ConnTrackTable) SetNatInfo(pkt *PacketBuffer, rt RedirectTarget, hook Hook) {
+ // Get the connection. Connection is always created before this
+ // function is called.
+ conn, _ := ct.connTrackForPacket(pkt, hook, false)
+ if conn == nil {
+ panic("connection should be created to manipulate tuples.")
+ }
+ replyTuple := conn.replyTupleHolder.tuple
+ replyHash := ct.getTupleHash(replyTuple)
+
+ // TODO(gvisor.dev/issue/170): Support only redirect of ports. Need to
+ // support changing of address for Prerouting.
+
+ // Change the port as per the iptables rule. This tuple will be used
+ // to manipulate the packet in HandlePacket.
+ conn.replyTupleHolder.tuple.src.addr = rt.MinIP
+ conn.replyTupleHolder.tuple.src.port = rt.MinPort
+ newHash := ct.getTupleHash(conn.replyTupleHolder.tuple)
+
+ // Add the changed tuple to the map.
+ ct.connMu.Lock()
+ defer ct.connMu.Unlock()
+ ct.CtMap[newHash] = conn.replyTupleHolder
+ if hook == Output {
+ conn.replyTupleHolder.conn.manip = manipDstOutput
+ }
+
+ // Delete the old tuple.
+ delete(ct.CtMap, replyHash)
+}
+
+// handlePacketPrerouting manipulates ports for packets in Prerouting hook.
+// TODO(gvisor.dev/issue/170): Change address for Prerouting hook..
+func handlePacketPrerouting(pkt *PacketBuffer, conn *connTrack, dir ctDirection) {
+ netHeader := header.IPv4(pkt.NetworkHeader)
+ tcpHeader := header.TCP(pkt.TransportHeader)
+
+ // For prerouting redirection, packets going in the original direction
+ // have their destinations modified and replies have their sources
+ // modified.
+ switch dir {
+ case dirOriginal:
+ port := conn.replyTupleHolder.tuple.src.port
+ tcpHeader.SetDestinationPort(port)
+ netHeader.SetDestinationAddress(conn.replyTupleHolder.tuple.src.addr)
+ case dirReply:
+ port := conn.originalTupleHolder.tuple.dst.port
+ tcpHeader.SetSourcePort(port)
+ netHeader.SetSourceAddress(conn.originalTupleHolder.tuple.dst.addr)
+ }
+
+ netHeader.SetChecksum(0)
+ netHeader.SetChecksum(^netHeader.CalculateChecksum())
+}
+
+// handlePacketOutput manipulates ports for packets in Output hook.
+func handlePacketOutput(pkt *PacketBuffer, conn *connTrack, gso *GSO, r *Route, dir ctDirection) {
+ netHeader := header.IPv4(pkt.NetworkHeader)
+ tcpHeader := header.TCP(pkt.TransportHeader)
+
+ // For output redirection, packets going in the original direction
+ // have their destinations modified and replies have their sources
+ // modified. For prerouting redirection, we only reach this point
+ // when replying, so packet sources are modified.
+ if conn.manip == manipDstOutput && dir == dirOriginal {
+ port := conn.replyTupleHolder.tuple.src.port
+ tcpHeader.SetDestinationPort(port)
+ netHeader.SetDestinationAddress(conn.replyTupleHolder.tuple.src.addr)
+ } else {
+ port := conn.originalTupleHolder.tuple.dst.port
+ tcpHeader.SetSourcePort(port)
+ netHeader.SetSourceAddress(conn.originalTupleHolder.tuple.dst.addr)
+ }
+
+ // Calculate the TCP checksum and set it.
+ tcpHeader.SetChecksum(0)
+ hdr := &pkt.Header
+ length := uint16(pkt.Data.Size()+hdr.UsedLength()) - uint16(netHeader.HeaderLength())
+ xsum := r.PseudoHeaderChecksum(header.TCPProtocolNumber, length)
+ if gso != nil && gso.NeedsCsum {
+ tcpHeader.SetChecksum(xsum)
+ } else if r.Capabilities()&CapabilityTXChecksumOffload == 0 {
+ xsum = header.ChecksumVVWithOffset(pkt.Data, xsum, int(tcpHeader.DataOffset()), pkt.Data.Size())
+ tcpHeader.SetChecksum(^tcpHeader.CalculateChecksum(xsum))
+ }
+
+ netHeader.SetChecksum(0)
+ netHeader.SetChecksum(^netHeader.CalculateChecksum())
+}
+
+// HandlePacket will manipulate the port and address of the packet if the
+// connection exists.
+func (ct *ConnTrackTable) HandlePacket(pkt *PacketBuffer, hook Hook, gso *GSO, r *Route) {
+ if pkt.NatDone {
+ return
+ }
+
+ if hook != Prerouting && hook != Output {
+ return
+ }
+
+ conn, dir := ct.connTrackForPacket(pkt, hook, false)
+ // Connection or Rule not found for the packet.
+ if conn == nil {
+ return
+ }
+
+ netHeader := header.IPv4(pkt.NetworkHeader)
+ // TODO(gvisor.dev/issue/170): Need to support for other transport
+ // protocols as well.
+ if netHeader == nil || netHeader.TransportProtocol() != header.TCPProtocolNumber {
+ return
+ }
+
+ tcpHeader := header.TCP(pkt.TransportHeader)
+ if tcpHeader == nil {
+ return
+ }
+
+ switch hook {
+ case Prerouting:
+ handlePacketPrerouting(pkt, conn, dir)
+ case Output:
+ handlePacketOutput(pkt, conn, gso, r, dir)
+ }
+ pkt.NatDone = true
+
+ // Update the state of tcb.
+ // TODO(gvisor.dev/issue/170): Add support in tcpcontrack to handle
+ // other tcp states.
+ var st tcpconntrack.Result
+ if conn.tcb.IsEmpty() {
+ conn.tcb.Init(tcpHeader)
+ conn.tcbHook = hook
+ } else {
+ switch hook {
+ case conn.tcbHook:
+ st = conn.tcb.UpdateStateOutbound(tcpHeader)
+ default:
+ st = conn.tcb.UpdateStateInbound(tcpHeader)
+ }
+ }
+
+ // Delete conntrack if tcp connection is closed.
+ if st == tcpconntrack.ResultClosedByPeer || st == tcpconntrack.ResultClosedBySelf || st == tcpconntrack.ResultReset {
+ ct.deleteConnTrack(conn)
+ }
+}
+
+// deleteConnTrack deletes the connection.
+func (ct *ConnTrackTable) deleteConnTrack(conn *connTrack) {
+ if conn == nil {
+ return
+ }
+
+ tuple := conn.originalTupleHolder.tuple
+ hash := ct.getTupleHash(tuple)
+ replyTuple := conn.replyTupleHolder.tuple
+ replyHash := ct.getTupleHash(replyTuple)
+
+ ct.connMu.Lock()
+ defer ct.connMu.Unlock()
+
+ delete(ct.CtMap, hash)
+ delete(ct.CtMap, replyHash)
+}
diff --git a/pkg/tcpip/stack/iptables.go b/pkg/tcpip/stack/iptables.go
index 6b91159d4..7c3c47d50 100644
--- a/pkg/tcpip/stack/iptables.go
+++ b/pkg/tcpip/stack/iptables.go
@@ -17,6 +17,7 @@ package stack
import (
"fmt"
+ "gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/header"
)
@@ -110,6 +111,10 @@ func DefaultTables() IPTables {
Prerouting: []string{TablenameMangle, TablenameNat},
Output: []string{TablenameMangle, TablenameNat, TablenameFilter},
},
+ connections: ConnTrackTable{
+ CtMap: make(map[uint32]ConnTrackTupleHolder),
+ Seed: generateRandUint32(),
+ },
}
}
@@ -173,12 +178,16 @@ const (
// dropped.
//
// Precondition: pkt.NetworkHeader is set.
-func (it *IPTables) Check(hook Hook, pkt PacketBuffer) bool {
+func (it *IPTables) Check(hook Hook, pkt *PacketBuffer, gso *GSO, r *Route, address tcpip.Address) bool {
+ // Packets are manipulated only if connection and matching
+ // NAT rule exists.
+ it.connections.HandlePacket(pkt, hook, gso, r)
+
// Go through each table containing the hook.
for _, tablename := range it.Priorities[hook] {
table := it.Tables[tablename]
ruleIdx := table.BuiltinChains[hook]
- switch verdict := it.checkChain(hook, pkt, table, ruleIdx); verdict {
+ switch verdict := it.checkChain(hook, pkt, table, ruleIdx, gso, r, address); verdict {
// If the table returns Accept, move on to the next table.
case chainAccept:
continue
@@ -189,7 +198,7 @@ func (it *IPTables) Check(hook Hook, pkt PacketBuffer) bool {
// Any Return from a built-in chain means we have to
// call the underflow.
underflow := table.Rules[table.Underflows[hook]]
- switch v, _ := underflow.Target.Action(pkt); v {
+ switch v, _ := underflow.Target.Action(pkt, &it.connections, hook, gso, r, address); v {
case RuleAccept:
continue
case RuleDrop:
@@ -219,26 +228,34 @@ func (it *IPTables) Check(hook Hook, pkt PacketBuffer) bool {
//
// NOTE: unlike the Check API the returned map contains packets that should be
// dropped.
-func (it *IPTables) CheckPackets(hook Hook, pkts PacketBufferList) (drop map[*PacketBuffer]struct{}) {
+func (it *IPTables) CheckPackets(hook Hook, pkts PacketBufferList, gso *GSO, r *Route) (drop map[*PacketBuffer]struct{}, natPkts map[*PacketBuffer]struct{}) {
for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() {
- if ok := it.Check(hook, *pkt); !ok {
- if drop == nil {
- drop = make(map[*PacketBuffer]struct{})
+ if !pkt.NatDone {
+ if ok := it.Check(hook, pkt, gso, r, ""); !ok {
+ if drop == nil {
+ drop = make(map[*PacketBuffer]struct{})
+ }
+ drop[pkt] = struct{}{}
+ }
+ if pkt.NatDone {
+ if natPkts == nil {
+ natPkts = make(map[*PacketBuffer]struct{})
+ }
+ natPkts[pkt] = struct{}{}
}
- drop[pkt] = struct{}{}
}
}
- return drop
+ return drop, natPkts
}
// Precondition: pkt is a IPv4 packet of at least length header.IPv4MinimumSize.
-// TODO(gvisor.dev/issue/170): pk.NetworkHeader will always be set as a
+// TODO(gvisor.dev/issue/170): pkt.NetworkHeader will always be set as a
// precondition.
-func (it *IPTables) checkChain(hook Hook, pkt PacketBuffer, table Table, ruleIdx int) chainVerdict {
+func (it *IPTables) checkChain(hook Hook, pkt *PacketBuffer, table Table, ruleIdx int, gso *GSO, r *Route, address tcpip.Address) chainVerdict {
// Start from ruleIdx and walk the list of rules until a rule gives us
// a verdict.
for ruleIdx < len(table.Rules) {
- switch verdict, jumpTo := it.checkRule(hook, pkt, table, ruleIdx); verdict {
+ switch verdict, jumpTo := it.checkRule(hook, pkt, table, ruleIdx, gso, r, address); verdict {
case RuleAccept:
return chainAccept
@@ -255,7 +272,7 @@ func (it *IPTables) checkChain(hook Hook, pkt PacketBuffer, table Table, ruleIdx
ruleIdx++
continue
}
- switch verdict := it.checkChain(hook, pkt, table, jumpTo); verdict {
+ switch verdict := it.checkChain(hook, pkt, table, jumpTo, gso, r, address); verdict {
case chainAccept:
return chainAccept
case chainDrop:
@@ -279,9 +296,9 @@ func (it *IPTables) checkChain(hook Hook, pkt PacketBuffer, table Table, ruleIdx
}
// Precondition: pkt is a IPv4 packet of at least length header.IPv4MinimumSize.
-// TODO(gvisor.dev/issue/170): pk.NetworkHeader will always be set as a
+// TODO(gvisor.dev/issue/170): pkt.NetworkHeader will always be set as a
// precondition.
-func (it *IPTables) checkRule(hook Hook, pkt PacketBuffer, table Table, ruleIdx int) (RuleVerdict, int) {
+func (it *IPTables) checkRule(hook Hook, pkt *PacketBuffer, table Table, ruleIdx int, gso *GSO, r *Route, address tcpip.Address) (RuleVerdict, int) {
rule := table.Rules[ruleIdx]
// If pkt.NetworkHeader hasn't been set yet, it will be contained in
@@ -304,7 +321,7 @@ func (it *IPTables) checkRule(hook Hook, pkt PacketBuffer, table Table, ruleIdx
// Go through each rule matcher. If they all match, run
// the rule target.
for _, matcher := range rule.Matchers {
- matches, hotdrop := matcher.Match(hook, pkt, "")
+ matches, hotdrop := matcher.Match(hook, *pkt, "")
if hotdrop {
return RuleDrop, 0
}
@@ -315,7 +332,7 @@ func (it *IPTables) checkRule(hook Hook, pkt PacketBuffer, table Table, ruleIdx
}
// All the matchers matched, so run the target.
- return rule.Target.Action(pkt)
+ return rule.Target.Action(pkt, &it.connections, hook, gso, r, address)
}
func filterMatch(filter IPHeaderFilter, hdr header.IPv4) bool {
diff --git a/pkg/tcpip/stack/iptables_targets.go b/pkg/tcpip/stack/iptables_targets.go
index 8be61f4b1..36cc6275d 100644
--- a/pkg/tcpip/stack/iptables_targets.go
+++ b/pkg/tcpip/stack/iptables_targets.go
@@ -24,7 +24,7 @@ import (
type AcceptTarget struct{}
// Action implements Target.Action.
-func (AcceptTarget) Action(packet PacketBuffer) (RuleVerdict, int) {
+func (AcceptTarget) Action(*PacketBuffer, *ConnTrackTable, Hook, *GSO, *Route, tcpip.Address) (RuleVerdict, int) {
return RuleAccept, 0
}
@@ -32,7 +32,7 @@ func (AcceptTarget) Action(packet PacketBuffer) (RuleVerdict, int) {
type DropTarget struct{}
// Action implements Target.Action.
-func (DropTarget) Action(packet PacketBuffer) (RuleVerdict, int) {
+func (DropTarget) Action(*PacketBuffer, *ConnTrackTable, Hook, *GSO, *Route, tcpip.Address) (RuleVerdict, int) {
return RuleDrop, 0
}
@@ -41,7 +41,7 @@ func (DropTarget) Action(packet PacketBuffer) (RuleVerdict, int) {
type ErrorTarget struct{}
// Action implements Target.Action.
-func (ErrorTarget) Action(packet PacketBuffer) (RuleVerdict, int) {
+func (ErrorTarget) Action(*PacketBuffer, *ConnTrackTable, Hook, *GSO, *Route, tcpip.Address) (RuleVerdict, int) {
log.Debugf("ErrorTarget triggered.")
return RuleDrop, 0
}
@@ -52,7 +52,7 @@ type UserChainTarget struct {
}
// Action implements Target.Action.
-func (UserChainTarget) Action(PacketBuffer) (RuleVerdict, int) {
+func (UserChainTarget) Action(*PacketBuffer, *ConnTrackTable, Hook, *GSO, *Route, tcpip.Address) (RuleVerdict, int) {
panic("UserChainTarget should never be called.")
}
@@ -61,7 +61,7 @@ func (UserChainTarget) Action(PacketBuffer) (RuleVerdict, int) {
type ReturnTarget struct{}
// Action implements Target.Action.
-func (ReturnTarget) Action(PacketBuffer) (RuleVerdict, int) {
+func (ReturnTarget) Action(*PacketBuffer, *ConnTrackTable, Hook, *GSO, *Route, tcpip.Address) (RuleVerdict, int) {
return RuleReturn, 0
}
@@ -75,16 +75,16 @@ type RedirectTarget struct {
// redirect.
RangeProtoSpecified bool
- // Min address used to redirect.
+ // MinIP indicates address used to redirect.
MinIP tcpip.Address
- // Max address used to redirect.
+ // MaxIP indicates address used to redirect.
MaxIP tcpip.Address
- // Min port used to redirect.
+ // MinPort indicates port used to redirect.
MinPort uint16
- // Max port used to redirect.
+ // MaxPort indicates port used to redirect.
MaxPort uint16
}
@@ -92,61 +92,76 @@ type RedirectTarget struct {
// TODO(gvisor.dev/issue/170): Parse headers without copying. The current
// implementation only works for PREROUTING and calls pkt.Clone(), neither
// of which should be the case.
-func (rt RedirectTarget) Action(pkt PacketBuffer) (RuleVerdict, int) {
- newPkt := pkt.Clone()
+func (rt RedirectTarget) Action(pkt *PacketBuffer, ct *ConnTrackTable, hook Hook, gso *GSO, r *Route, address tcpip.Address) (RuleVerdict, int) {
+ // Packet is already manipulated.
+ if pkt.NatDone {
+ return RuleAccept, 0
+ }
// Set network header.
- headerView, ok := newPkt.Data.PullUp(header.IPv4MinimumSize)
- if !ok {
- return RuleDrop, 0
+ if hook == Prerouting {
+ parseHeaders(pkt)
}
- netHeader := header.IPv4(headerView)
- newPkt.NetworkHeader = headerView
- hlen := int(netHeader.HeaderLength())
- tlen := int(netHeader.TotalLength())
- newPkt.Data.TrimFront(hlen)
- newPkt.Data.CapLength(tlen - hlen)
+ // Drop the packet if network and transport header are not set.
+ if pkt.NetworkHeader == nil || pkt.TransportHeader == nil {
+ return RuleDrop, 0
+ }
- // TODO(gvisor.dev/issue/170): Change destination address to
- // loopback or interface address on which the packet was
- // received.
+ // Change the address to localhost (127.0.0.1) in Output and
+ // to primary address of the incoming interface in Prerouting.
+ switch hook {
+ case Output:
+ rt.MinIP = tcpip.Address([]byte{127, 0, 0, 1})
+ rt.MaxIP = tcpip.Address([]byte{127, 0, 0, 1})
+ case Prerouting:
+ rt.MinIP = address
+ rt.MaxIP = address
+ default:
+ panic("redirect target is supported only on output and prerouting hooks")
+ }
// TODO(gvisor.dev/issue/170): Check Flags in RedirectTarget if
// we need to change dest address (for OUTPUT chain) or ports.
+ netHeader := header.IPv4(pkt.NetworkHeader)
switch protocol := netHeader.TransportProtocol(); protocol {
case header.UDPProtocolNumber:
- var udpHeader header.UDP
- if newPkt.TransportHeader != nil {
- udpHeader = header.UDP(newPkt.TransportHeader)
- } else {
- if pkt.Data.Size() < header.UDPMinimumSize {
- return RuleDrop, 0
- }
- hdr, ok := newPkt.Data.PullUp(header.UDPMinimumSize)
- if !ok {
- return RuleDrop, 0
+ udpHeader := header.UDP(pkt.TransportHeader)
+ udpHeader.SetDestinationPort(rt.MinPort)
+
+ // Calculate UDP checksum and set it.
+ if hook == Output {
+ udpHeader.SetChecksum(0)
+ hdr := &pkt.Header
+ length := uint16(pkt.Data.Size()+hdr.UsedLength()) - uint16(netHeader.HeaderLength())
+
+ // Only calculate the checksum if offloading isn't supported.
+ if r.Capabilities()&CapabilityTXChecksumOffload == 0 {
+ xsum := r.PseudoHeaderChecksum(protocol, length)
+ for _, v := range pkt.Data.Views() {
+ xsum = header.Checksum(v, xsum)
+ }
+ udpHeader.SetChecksum(0)
+ udpHeader.SetChecksum(^udpHeader.CalculateChecksum(xsum))
}
- udpHeader = header.UDP(hdr)
}
- udpHeader.SetDestinationPort(rt.MinPort)
+ // Change destination address.
+ netHeader.SetDestinationAddress(rt.MinIP)
+ netHeader.SetChecksum(0)
+ netHeader.SetChecksum(^netHeader.CalculateChecksum())
+ pkt.NatDone = true
case header.TCPProtocolNumber:
- var tcpHeader header.TCP
- if newPkt.TransportHeader != nil {
- tcpHeader = header.TCP(newPkt.TransportHeader)
- } else {
- if pkt.Data.Size() < header.TCPMinimumSize {
- return RuleDrop, 0
- }
- hdr, ok := newPkt.Data.PullUp(header.TCPMinimumSize)
- if !ok {
- return RuleDrop, 0
- }
- tcpHeader = header.TCP(hdr)
+ if ct == nil {
+ return RuleAccept, 0
+ }
+
+ // Set up conection for matching NAT rule.
+ // Only the first packet of the connection comes here.
+ // Other packets will be manipulated in connection tracking.
+ if conn, _ := ct.connTrackForPacket(pkt, hook, true); conn != nil {
+ ct.SetNatInfo(pkt, rt, hook)
+ ct.HandlePacket(pkt, hook, gso, r)
}
- // TODO(gvisor.dev/issue/170): Need to recompute checksum
- // and implement nat connection tracking to support TCP.
- tcpHeader.SetDestinationPort(rt.MinPort)
default:
return RuleDrop, 0
}
diff --git a/pkg/tcpip/stack/iptables_types.go b/pkg/tcpip/stack/iptables_types.go
index 2ffb55f2a..1bb0ba1bd 100644
--- a/pkg/tcpip/stack/iptables_types.go
+++ b/pkg/tcpip/stack/iptables_types.go
@@ -82,6 +82,8 @@ type IPTables struct {
// list is the order in which each table should be visited for that
// hook.
Priorities map[Hook][]string
+
+ connections ConnTrackTable
}
// A Table defines a set of chains and hooks into the network stack. It is
@@ -176,5 +178,5 @@ type Target interface {
// Action takes an action on the packet and returns a verdict on how
// traversal should (or should not) continue. If the return value is
// Jump, it also returns the index of the rule to jump to.
- Action(packet PacketBuffer) (RuleVerdict, int)
+ Action(packet *PacketBuffer, connections *ConnTrackTable, hook Hook, gso *GSO, r *Route, address tcpip.Address) (RuleVerdict, int)
}
diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go
index 7b54919bb..8f4c1fe42 100644
--- a/pkg/tcpip/stack/nic.go
+++ b/pkg/tcpip/stack/nic.go
@@ -1230,8 +1230,10 @@ func (n *NIC) DeliverNetworkPacket(linkEP LinkEndpoint, remote, local tcpip.Link
// TODO(gvisor.dev/issue/170): Not supporting iptables for IPv6 yet.
if protocol == header.IPv4ProtocolNumber {
+ // iptables filtering.
ipt := n.stack.IPTables()
- if ok := ipt.Check(Prerouting, pkt); !ok {
+ address := n.primaryAddress(protocol)
+ if ok := ipt.Check(Prerouting, &pkt, nil, nil, address.Address); !ok {
// iptables is telling us to drop the packet.
return
}
diff --git a/pkg/tcpip/stack/packet_buffer.go b/pkg/tcpip/stack/packet_buffer.go
index 9ff80ab24..926df4d7b 100644
--- a/pkg/tcpip/stack/packet_buffer.go
+++ b/pkg/tcpip/stack/packet_buffer.go
@@ -72,6 +72,10 @@ type PacketBuffer struct {
EgressRoute *Route
GSOOptions *GSO
NetworkProtocolNumber tcpip.NetworkProtocolNumber
+
+ // NatDone indicates if the packet has been manipulated as per NAT
+ // iptables rule.
+ NatDone bool
}
// Clone makes a copy of pk. It clones the Data field, which creates a new
diff --git a/pkg/tcpip/stack/route.go b/pkg/tcpip/stack/route.go
index 53148dc03..150297ab9 100644
--- a/pkg/tcpip/stack/route.go
+++ b/pkg/tcpip/stack/route.go
@@ -261,3 +261,16 @@ func (r *Route) MakeLoopedRoute() Route {
func (r *Route) Stack() *Stack {
return r.ref.stack()
}
+
+// ReverseRoute returns new route with given source and destination address.
+func (r *Route) ReverseRoute(src tcpip.Address, dst tcpip.Address) Route {
+ return Route{
+ NetProto: r.NetProto,
+ LocalAddress: dst,
+ LocalLinkAddress: r.RemoteLinkAddress,
+ RemoteAddress: src,
+ RemoteLinkAddress: r.LocalLinkAddress,
+ ref: r.ref,
+ Loop: r.Loop,
+ }
+}
diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go
index 4a2dc3dc6..e33fae4eb 100644
--- a/pkg/tcpip/stack/stack.go
+++ b/pkg/tcpip/stack/stack.go
@@ -1885,3 +1885,22 @@ func generateRandInt64() int64 {
}
return v
}
+
+// FindNetworkEndpoint returns the network endpoint for the given address.
+func (s *Stack) FindNetworkEndpoint(netProto tcpip.NetworkProtocolNumber, address tcpip.Address) (NetworkEndpoint, *tcpip.Error) {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+
+ for _, nic := range s.nics {
+ id := NetworkEndpointID{address}
+
+ if ref, ok := nic.mu.endpoints[id]; ok {
+ nic.mu.RLock()
+ defer nic.mu.RUnlock()
+
+ // An endpoint with this id exists, check if it can be used and return it.
+ return ref.ep, nil
+ }
+ }
+ return nil, tcpip.ErrBadAddress
+}
diff --git a/pkg/tcpip/transport/tcp/BUILD b/pkg/tcpip/transport/tcp/BUILD
index f2aa69069..f38eb6833 100644
--- a/pkg/tcpip/transport/tcp/BUILD
+++ b/pkg/tcpip/transport/tcp/BUILD
@@ -115,7 +115,7 @@ go_test(
size = "small",
srcs = ["rcv_test.go"],
deps = [
- ":tcp",
+ "//pkg/tcpip/header",
"//pkg/tcpip/seqnum",
],
)
diff --git a/pkg/tcpip/transport/tcp/rcv.go b/pkg/tcpip/transport/tcp/rcv.go
index a4b73b588..6fe97fefd 100644
--- a/pkg/tcpip/transport/tcp/rcv.go
+++ b/pkg/tcpip/transport/tcp/rcv.go
@@ -70,24 +70,7 @@ func newReceiver(ep *endpoint, irs seqnum.Value, rcvWnd seqnum.Size, rcvWndScale
// acceptable checks if the segment sequence number range is acceptable
// according to the table on page 26 of RFC 793.
func (r *receiver) acceptable(segSeq seqnum.Value, segLen seqnum.Size) bool {
- return Acceptable(segSeq, segLen, r.rcvNxt, r.rcvAcc)
-}
-
-// Acceptable checks if a segment that starts at segSeq and has length segLen is
-// "acceptable" for arriving in a receive window that starts at rcvNxt and ends
-// before rcvAcc, according to the table on page 26 and 69 of RFC 793.
-func Acceptable(segSeq seqnum.Value, segLen seqnum.Size, rcvNxt, rcvAcc seqnum.Value) bool {
- if rcvNxt == rcvAcc {
- return segLen == 0 && segSeq == rcvNxt
- }
- if segLen == 0 {
- // rcvWnd is incremented by 1 because that is Linux's behavior despite the
- // RFC.
- return segSeq.InRange(rcvNxt, rcvAcc.Add(1))
- }
- // Page 70 of RFC 793 allows packets that can be made "acceptable" by trimming
- // the payload, so we'll accept any payload that overlaps the receieve window.
- return rcvNxt.LessThan(segSeq.Add(segLen)) && segSeq.LessThan(rcvAcc)
+ return header.Acceptable(segSeq, segLen, r.rcvNxt, r.rcvAcc)
}
// getSendParams returns the parameters needed by the sender when building
diff --git a/pkg/tcpip/transport/tcp/rcv_test.go b/pkg/tcpip/transport/tcp/rcv_test.go
index dc02729ce..c9eeff935 100644
--- a/pkg/tcpip/transport/tcp/rcv_test.go
+++ b/pkg/tcpip/transport/tcp/rcv_test.go
@@ -17,8 +17,8 @@ package rcv_test
import (
"testing"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/seqnum"
- "gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
)
func TestAcceptable(t *testing.T) {
@@ -67,8 +67,8 @@ func TestAcceptable(t *testing.T) {
{105, 2, 108, 108, false},
{105, 2, 109, 109, false},
} {
- if got := tcp.Acceptable(tt.segSeq, tt.segLen, tt.rcvNxt, tt.rcvAcc); got != tt.want {
- t.Errorf("tcp.Acceptable(%d, %d, %d, %d) = %t, want %t", tt.segSeq, tt.segLen, tt.rcvNxt, tt.rcvAcc, got, tt.want)
+ if got := header.Acceptable(tt.segSeq, tt.segLen, tt.rcvNxt, tt.rcvAcc); got != tt.want {
+ t.Errorf("header.Acceptable(%d, %d, %d, %d) = %t, want %t", tt.segSeq, tt.segLen, tt.rcvNxt, tt.rcvAcc, got, tt.want)
}
}
}
diff --git a/pkg/tcpip/transport/tcpconntrack/BUILD b/pkg/tcpip/transport/tcpconntrack/BUILD
index 2025ff757..3ad6994a7 100644
--- a/pkg/tcpip/transport/tcpconntrack/BUILD
+++ b/pkg/tcpip/transport/tcpconntrack/BUILD
@@ -9,7 +9,6 @@ go_library(
deps = [
"//pkg/tcpip/header",
"//pkg/tcpip/seqnum",
- "//pkg/tcpip/transport/tcp",
],
)
diff --git a/pkg/tcpip/transport/tcpconntrack/tcp_conntrack.go b/pkg/tcpip/transport/tcpconntrack/tcp_conntrack.go
index 30d05200f..12bc1b5b5 100644
--- a/pkg/tcpip/transport/tcpconntrack/tcp_conntrack.go
+++ b/pkg/tcpip/transport/tcpconntrack/tcp_conntrack.go
@@ -20,7 +20,6 @@ package tcpconntrack
import (
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/seqnum"
- "gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
)
// Result is returned when the state of a TCB is updated in response to an
@@ -312,7 +311,7 @@ type stream struct {
// the window is zero, if it's a packet with no payload and sequence number
// equal to una.
func (s *stream) acceptable(segSeq seqnum.Value, segLen seqnum.Size) bool {
- return tcp.Acceptable(segSeq, segLen, s.una, s.end)
+ return header.Acceptable(segSeq, segLen, s.una, s.end)
}
// closed determines if the stream has already been closed. This happens when
@@ -338,3 +337,16 @@ func logicalLen(tcp header.TCP) seqnum.Size {
}
return l
}
+
+// IsEmpty returns true if tcb is not initialized.
+func (t *TCB) IsEmpty() bool {
+ if t.inbound != (stream{}) || t.outbound != (stream{}) {
+ return false
+ }
+
+ if t.firstFin != nil || t.state != ResultDrop {
+ return false
+ }
+
+ return true
+}
diff --git a/test/iptables/filter_output.go b/test/iptables/filter_output.go
index f6d974b85..b1382d353 100644
--- a/test/iptables/filter_output.go
+++ b/test/iptables/filter_output.go
@@ -42,7 +42,7 @@ func (FilterOutputDropTCPDestPort) Name() string {
// ContainerAction implements TestCase.ContainerAction.
func (FilterOutputDropTCPDestPort) ContainerAction(ip net.IP) error {
- if err := filterTable("-A", "OUTPUT", "-p", "tcp", "-m", "tcp", "--dport", fmt.Sprintf("%d", dropPort), "-j", "DROP"); err != nil {
+ if err := filterTable("-A", "OUTPUT", "-p", "tcp", "-m", "tcp", "--dport", "1024:65535", "-j", "DROP"); err != nil {
return err
}
diff --git a/test/iptables/iptables_test.go b/test/iptables/iptables_test.go
index 334d8e676..63a862d35 100644
--- a/test/iptables/iptables_test.go
+++ b/test/iptables/iptables_test.go
@@ -115,30 +115,6 @@ func TestFilterInputDropOnlyUDP(t *testing.T) {
singleTest(t, FilterInputDropOnlyUDP{})
}
-func TestNATRedirectUDPPort(t *testing.T) {
- // TODO(gvisor.dev/issue/170): Enable when supported.
- t.Skip("NAT isn't supported yet (gvisor.dev/issue/170).")
- singleTest(t, NATRedirectUDPPort{})
-}
-
-func TestNATRedirectTCPPort(t *testing.T) {
- // TODO(gvisor.dev/issue/170): Enable when supported.
- t.Skip("NAT isn't supported yet (gvisor.dev/issue/170).")
- singleTest(t, NATRedirectTCPPort{})
-}
-
-func TestNATDropUDP(t *testing.T) {
- // TODO(gvisor.dev/issue/170): Enable when supported.
- t.Skip("NAT isn't supported yet (gvisor.dev/issue/170).")
- singleTest(t, NATDropUDP{})
-}
-
-func TestNATAcceptAll(t *testing.T) {
- // TODO(gvisor.dev/issue/170): Enable when supported.
- t.Skip("NAT isn't supported yet (gvisor.dev/issue/170).")
- singleTest(t, NATAcceptAll{})
-}
-
func TestFilterInputDropTCPDestPort(t *testing.T) {
singleTest(t, FilterInputDropTCPDestPort{})
}
@@ -164,14 +140,10 @@ func TestFilterInputReturnUnderflow(t *testing.T) {
}
func TestFilterOutputDropTCPDestPort(t *testing.T) {
- // TODO(gvisor.dev/issue/170): Enable when supported.
- t.Skip("filter OUTPUT isn't supported yet (gvisor.dev/issue/170).")
singleTest(t, FilterOutputDropTCPDestPort{})
}
func TestFilterOutputDropTCPSrcPort(t *testing.T) {
- // TODO(gvisor.dev/issue/170): Enable when supported.
- t.Skip("filter OUTPUT isn't supported yet (gvisor.dev/issue/170).")
singleTest(t, FilterOutputDropTCPSrcPort{})
}
@@ -235,44 +207,54 @@ func TestOutputInvertDestination(t *testing.T) {
singleTest(t, FilterOutputInvertDestination{})
}
+func TestNATPreRedirectUDPPort(t *testing.T) {
+ singleTest(t, NATPreRedirectUDPPort{})
+}
+
+func TestNATPreRedirectTCPPort(t *testing.T) {
+ singleTest(t, NATPreRedirectTCPPort{})
+}
+
+func TestNATOutRedirectUDPPort(t *testing.T) {
+ singleTest(t, NATOutRedirectUDPPort{})
+}
+
+func TestNATOutRedirectTCPPort(t *testing.T) {
+ singleTest(t, NATOutRedirectTCPPort{})
+}
+
+func TestNATDropUDP(t *testing.T) {
+ singleTest(t, NATDropUDP{})
+}
+
+func TestNATAcceptAll(t *testing.T) {
+ singleTest(t, NATAcceptAll{})
+}
+
func TestNATOutRedirectIP(t *testing.T) {
- // TODO(gvisor.dev/issue/170): Enable when supported.
- t.Skip("NAT isn't supported yet (gvisor.dev/issue/170).")
singleTest(t, NATOutRedirectIP{})
}
func TestNATOutDontRedirectIP(t *testing.T) {
- // TODO(gvisor.dev/issue/170): Enable when supported.
- t.Skip("NAT isn't supported yet (gvisor.dev/issue/170).")
singleTest(t, NATOutDontRedirectIP{})
}
func TestNATOutRedirectInvert(t *testing.T) {
- // TODO(gvisor.dev/issue/170): Enable when supported.
- t.Skip("NAT isn't supported yet (gvisor.dev/issue/170).")
singleTest(t, NATOutRedirectInvert{})
}
func TestNATPreRedirectIP(t *testing.T) {
- // TODO(gvisor.dev/issue/170): Enable when supported.
- t.Skip("NAT isn't supported yet (gvisor.dev/issue/170).")
singleTest(t, NATPreRedirectIP{})
}
func TestNATPreDontRedirectIP(t *testing.T) {
- // TODO(gvisor.dev/issue/170): Enable when supported.
- t.Skip("NAT isn't supported yet (gvisor.dev/issue/170).")
singleTest(t, NATPreDontRedirectIP{})
}
func TestNATPreRedirectInvert(t *testing.T) {
- // TODO(gvisor.dev/issue/170): Enable when supported.
- t.Skip("NAT isn't supported yet (gvisor.dev/issue/170).")
singleTest(t, NATPreRedirectInvert{})
}
func TestNATRedirectRequiresProtocol(t *testing.T) {
- // TODO(gvisor.dev/issue/170): Enable when supported.
- t.Skip("NAT isn't supported yet (gvisor.dev/issue/170).")
singleTest(t, NATRedirectRequiresProtocol{})
}
diff --git a/test/iptables/iptables_util.go b/test/iptables/iptables_util.go
index 2a00677be..2f988cd18 100644
--- a/test/iptables/iptables_util.go
+++ b/test/iptables/iptables_util.go
@@ -151,7 +151,7 @@ func connectTCP(ip net.IP, port int, timeout time.Duration) error {
return err
}
if err := testutil.Poll(callback, timeout); err != nil {
- return fmt.Errorf("timed out waiting to connect IP, most recent error: %v", err)
+ return fmt.Errorf("timed out waiting to connect IP on port %v, most recent error: %v", port, err)
}
return nil
diff --git a/test/iptables/nat.go b/test/iptables/nat.go
index 40096901c..0a10ce7fe 100644
--- a/test/iptables/nat.go
+++ b/test/iptables/nat.go
@@ -26,8 +26,10 @@ const (
)
func init() {
- RegisterTestCase(NATRedirectUDPPort{})
- RegisterTestCase(NATRedirectTCPPort{})
+ RegisterTestCase(NATPreRedirectUDPPort{})
+ RegisterTestCase(NATPreRedirectTCPPort{})
+ RegisterTestCase(NATOutRedirectUDPPort{})
+ RegisterTestCase(NATOutRedirectTCPPort{})
RegisterTestCase(NATDropUDP{})
RegisterTestCase(NATAcceptAll{})
RegisterTestCase(NATPreRedirectIP{})
@@ -39,16 +41,16 @@ func init() {
RegisterTestCase(NATRedirectRequiresProtocol{})
}
-// NATRedirectUDPPort tests that packets are redirected to different port.
-type NATRedirectUDPPort struct{}
+// NATPreRedirectUDPPort tests that packets are redirected to different port.
+type NATPreRedirectUDPPort struct{}
// Name implements TestCase.Name.
-func (NATRedirectUDPPort) Name() string {
- return "NATRedirectUDPPort"
+func (NATPreRedirectUDPPort) Name() string {
+ return "NATPreRedirectUDPPort"
}
// ContainerAction implements TestCase.ContainerAction.
-func (NATRedirectUDPPort) ContainerAction(ip net.IP) error {
+func (NATPreRedirectUDPPort) ContainerAction(ip net.IP) error {
if err := natTable("-A", "PREROUTING", "-p", "udp", "-j", "REDIRECT", "--to-ports", fmt.Sprintf("%d", redirectPort)); err != nil {
return err
}
@@ -61,33 +63,53 @@ func (NATRedirectUDPPort) ContainerAction(ip net.IP) error {
}
// LocalAction implements TestCase.LocalAction.
-func (NATRedirectUDPPort) LocalAction(ip net.IP) error {
+func (NATPreRedirectUDPPort) LocalAction(ip net.IP) error {
return sendUDPLoop(ip, acceptPort, sendloopDuration)
}
-// NATRedirectTCPPort tests that connections are redirected on specified ports.
-type NATRedirectTCPPort struct{}
+// NATPreRedirectTCPPort tests that connections are redirected on specified ports.
+type NATPreRedirectTCPPort struct{}
// Name implements TestCase.Name.
-func (NATRedirectTCPPort) Name() string {
- return "NATRedirectTCPPort"
+func (NATPreRedirectTCPPort) Name() string {
+ return "NATPreRedirectTCPPort"
}
// ContainerAction implements TestCase.ContainerAction.
-func (NATRedirectTCPPort) ContainerAction(ip net.IP) error {
- if err := natTable("-A", "PREROUTING", "-p", "tcp", "-m", "tcp", "--dport", fmt.Sprintf("%d", dropPort), "-j", "REDIRECT", "--to-ports", fmt.Sprintf("%d", redirectPort)); err != nil {
+func (NATPreRedirectTCPPort) ContainerAction(ip net.IP) error {
+ if err := natTable("-A", "PREROUTING", "-p", "tcp", "-m", "tcp", "--dport", fmt.Sprintf("%d", dropPort), "-j", "REDIRECT", "--to-ports", fmt.Sprintf("%d", acceptPort)); err != nil {
return err
}
// Listen for TCP packets on redirect port.
- return listenTCP(redirectPort, sendloopDuration)
+ return listenTCP(acceptPort, sendloopDuration)
}
// LocalAction implements TestCase.LocalAction.
-func (NATRedirectTCPPort) LocalAction(ip net.IP) error {
+func (NATPreRedirectTCPPort) LocalAction(ip net.IP) error {
return connectTCP(ip, dropPort, sendloopDuration)
}
+// NATOutRedirectUDPPort tests that packets are redirected to different port.
+type NATOutRedirectUDPPort struct{}
+
+// Name implements TestCase.Name.
+func (NATOutRedirectUDPPort) Name() string {
+ return "NATOutRedirectUDPPort"
+}
+
+// ContainerAction implements TestCase.ContainerAction.
+func (NATOutRedirectUDPPort) ContainerAction(ip net.IP) error {
+ dest := []byte{200, 0, 0, 1}
+ return loopbackTest(dest, "-A", "OUTPUT", "-p", "udp", "-j", "REDIRECT", "--to-ports", fmt.Sprintf("%d", acceptPort))
+}
+
+// LocalAction implements TestCase.LocalAction.
+func (NATOutRedirectUDPPort) LocalAction(ip net.IP) error {
+ // No-op.
+ return nil
+}
+
// NATDropUDP tests that packets are not received in ports other than redirect
// port.
type NATDropUDP struct{}
@@ -329,3 +351,52 @@ func loopbackTest(dest net.IP, args ...string) error {
// sendCh will always take the full sendloop time.
return <-sendCh
}
+
+// NATOutRedirectTCPPort tests that connections are redirected on specified ports.
+type NATOutRedirectTCPPort struct{}
+
+// Name implements TestCase.Name.
+func (NATOutRedirectTCPPort) Name() string {
+ return "NATOutRedirectTCPPort"
+}
+
+// ContainerAction implements TestCase.ContainerAction.
+func (NATOutRedirectTCPPort) ContainerAction(ip net.IP) error {
+ if err := natTable("-A", "OUTPUT", "-p", "tcp", "-m", "tcp", "--dport", fmt.Sprintf("%d", dropPort), "-j", "REDIRECT", "--to-ports", fmt.Sprintf("%d", acceptPort)); err != nil {
+ return err
+ }
+
+ timeout := 20 * time.Second
+ dest := []byte{127, 0, 0, 1}
+ localAddr := net.TCPAddr{
+ IP: dest,
+ Port: acceptPort,
+ }
+
+ // Starts listening on port.
+ lConn, err := net.ListenTCP("tcp", &localAddr)
+ if err != nil {
+ return err
+ }
+ defer lConn.Close()
+
+ // Accept connections on port.
+ lConn.SetDeadline(time.Now().Add(timeout))
+ err = connectTCP(ip, dropPort, timeout)
+ if err != nil {
+ return err
+ }
+
+ conn, err := lConn.AcceptTCP()
+ if err != nil {
+ return err
+ }
+ conn.Close()
+
+ return nil
+}
+
+// LocalAction implements TestCase.LocalAction.
+func (NATOutRedirectTCPPort) LocalAction(ip net.IP) error {
+ return nil
+}