summaryrefslogtreecommitdiffhomepage
path: root/pkg/tcpip/stack
diff options
context:
space:
mode:
Diffstat (limited to 'pkg/tcpip/stack')
-rw-r--r--pkg/tcpip/stack/conntrack.go65
-rw-r--r--pkg/tcpip/stack/iptables.go25
-rw-r--r--pkg/tcpip/stack/iptables_targets.go84
-rw-r--r--pkg/tcpip/stack/packet_buffer.go13
-rw-r--r--pkg/tcpip/stack/packet_buffer_test.go38
-rw-r--r--pkg/tcpip/stack/stack_test.go11
6 files changed, 127 insertions, 109 deletions
diff --git a/pkg/tcpip/stack/conntrack.go b/pkg/tcpip/stack/conntrack.go
index 4fb7e9adb..b7cb54b1d 100644
--- a/pkg/tcpip/stack/conntrack.go
+++ b/pkg/tcpip/stack/conntrack.go
@@ -388,28 +388,33 @@ func (ct *ConnTrack) insertConn(conn *conn) {
// connection exists. Returns whether, after the packet traverses the tables,
// it should create a new entry in the table.
func (ct *ConnTrack) handlePacket(pkt *PacketBuffer, hook Hook, r *Route) bool {
- if pkt.NatDone {
- return false
- }
-
switch hook {
case Prerouting, Input, Output, Postrouting:
default:
return false
}
- transportHeader, ok := getTransportHeader(pkt)
- if !ok {
+ if conn, dir := ct.connFor(pkt); conn != nil {
+ conn.handlePacket(pkt, hook, dir, r)
return false
}
- conn, dir := ct.connFor(pkt)
// Connection not found for the packet.
- if conn == nil {
- // If this is the last hook in the data path for this packet (Input if
- // incoming, Postrouting if outgoing), indicate that a connection should be
- // inserted by the end of this hook.
- return hook == Input || hook == Postrouting
+ //
+ // If this is the last hook in the data path for this packet (Input if
+ // incoming, Postrouting if outgoing), indicate that a connection should be
+ // inserted by the end of this hook.
+ return hook == Input || hook == Postrouting
+}
+
+func (cn *conn) handlePacket(pkt *PacketBuffer, hook Hook, dir direction, r *Route) {
+ if pkt.NatDone {
+ return
+ }
+
+ transportHeader, ok := getTransportHeader(pkt)
+ if !ok {
+ return
}
netHeader := pkt.Network()
@@ -425,24 +430,24 @@ func (ct *ConnTrack) handlePacket(pkt *PacketBuffer, hook Hook, r *Route) bool {
switch hook {
case Prerouting, Output:
- if conn.manip == manipDestination && dir == dirOriginal {
- newPort = conn.reply.srcPort
- newAddr = conn.reply.srcAddr
+ if cn.manip == manipDestination && dir == dirOriginal {
+ newPort = cn.reply.srcPort
+ newAddr = cn.reply.srcAddr
pkt.NatDone = true
- } else if conn.manip == manipSource && dir == dirReply {
- newPort = conn.original.srcPort
- newAddr = conn.original.srcAddr
+ } else if cn.manip == manipSource && dir == dirReply {
+ newPort = cn.original.srcPort
+ newAddr = cn.original.srcAddr
pkt.NatDone = true
}
case Input, Postrouting:
- if conn.manip == manipSource && dir == dirOriginal {
- newPort = conn.reply.dstPort
- newAddr = conn.reply.dstAddr
+ if cn.manip == manipSource && dir == dirOriginal {
+ newPort = cn.reply.dstPort
+ newAddr = cn.reply.dstAddr
updateSRCFields = true
pkt.NatDone = true
- } else if conn.manip == manipDestination && dir == dirReply {
- newPort = conn.original.dstPort
- newAddr = conn.original.dstAddr
+ } else if cn.manip == manipDestination && dir == dirReply {
+ newPort = cn.original.dstPort
+ newAddr = cn.original.dstAddr
updateSRCFields = true
pkt.NatDone = true
}
@@ -451,7 +456,7 @@ func (ct *ConnTrack) handlePacket(pkt *PacketBuffer, hook Hook, r *Route) bool {
}
if !pkt.NatDone {
- return false
+ return
}
fullChecksum := false
@@ -486,15 +491,13 @@ func (ct *ConnTrack) handlePacket(pkt *PacketBuffer, hook Hook, r *Route) bool {
)
// Update the state of tcb.
- conn.mu.Lock()
- defer conn.mu.Unlock()
+ cn.mu.Lock()
+ defer cn.mu.Unlock()
// Mark the connection as having been used recently so it isn't reaped.
- conn.lastUsed = time.Now()
+ cn.lastUsed = time.Now()
// Update connection state.
- conn.updateLocked(pkt, hook)
-
- return false
+ cn.updateLocked(pkt, hook)
}
// maybeInsertNoop tries to insert a no-op connection entry to keep connections
diff --git a/pkg/tcpip/stack/iptables.go b/pkg/tcpip/stack/iptables.go
index 74c9075b4..dcba7eba6 100644
--- a/pkg/tcpip/stack/iptables.go
+++ b/pkg/tcpip/stack/iptables.go
@@ -310,8 +310,8 @@ func (it *IPTables) CheckOutput(pkt *PacketBuffer, r *Route, outNicName string)
// must be dropped if false is returned.
//
// Precondition: The packet's network and transport header must be set.
-func (it *IPTables) CheckPostrouting(pkt *PacketBuffer, r *Route, outNicName string) bool {
- return it.check(Postrouting, pkt, r, nil /* addressEP */, "" /* inNicName */, outNicName)
+func (it *IPTables) CheckPostrouting(pkt *PacketBuffer, r *Route, addressEP AddressableEndpoint, outNicName string) bool {
+ return it.check(Postrouting, pkt, r, addressEP, "" /* inNicName */, outNicName)
}
// check runs pkt through the rules for hook. It returns true when the packet
@@ -431,7 +431,9 @@ func (it *IPTables) startReaper(interval time.Duration) {
//
// Precondition: The packets' network and transport header must be set.
func (it *IPTables) CheckOutputPackets(pkts PacketBufferList, r *Route, outNicName string) (drop map[*PacketBuffer]struct{}, natPkts map[*PacketBuffer]struct{}) {
- return it.checkPackets(Output, pkts, r, outNicName)
+ return checkPackets(pkts, func(pkt *PacketBuffer) bool {
+ return it.CheckOutput(pkt, r, outNicName)
+ })
}
// CheckPostroutingPackets performs the postrouting hook on the packets.
@@ -439,21 +441,16 @@ func (it *IPTables) CheckOutputPackets(pkts PacketBufferList, r *Route, outNicNa
// Returns a map of packets that must be dropped.
//
// Precondition: The packets' network and transport header must be set.
-func (it *IPTables) CheckPostroutingPackets(pkts PacketBufferList, r *Route, outNicName string) (drop map[*PacketBuffer]struct{}, natPkts map[*PacketBuffer]struct{}) {
- return it.checkPackets(Postrouting, pkts, r, outNicName)
+func (it *IPTables) CheckPostroutingPackets(pkts PacketBufferList, r *Route, addressEP AddressableEndpoint, outNicName string) (drop map[*PacketBuffer]struct{}, natPkts map[*PacketBuffer]struct{}) {
+ return checkPackets(pkts, func(pkt *PacketBuffer) bool {
+ return it.CheckPostrouting(pkt, r, addressEP, outNicName)
+ })
}
-// checkPackets runs pkts through the rules for hook and returns a map of
-// packets that should not go forward.
-//
-// NOTE: unlike the Check API the returned map contains packets that should be
-// dropped.
-//
-// Precondition: The packets' network and transport header must be set.
-func (it *IPTables) checkPackets(hook Hook, pkts PacketBufferList, r *Route, outNicName string) (drop map[*PacketBuffer]struct{}, natPkts map[*PacketBuffer]struct{}) {
+func checkPackets(pkts PacketBufferList, f func(*PacketBuffer) bool) (drop map[*PacketBuffer]struct{}, natPkts map[*PacketBuffer]struct{}) {
for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() {
if !pkt.NatDone {
- if ok := it.check(hook, pkt, r, nil /* addressEP */, "" /* inNicName */, outNicName); !ok {
+ if ok := f(pkt); !ok {
if drop == nil {
drop = make(map[*PacketBuffer]struct{})
}
diff --git a/pkg/tcpip/stack/iptables_targets.go b/pkg/tcpip/stack/iptables_targets.go
index e8806ebdb..949c44c9b 100644
--- a/pkg/tcpip/stack/iptables_targets.go
+++ b/pkg/tcpip/stack/iptables_targets.go
@@ -162,7 +162,7 @@ func (rt *RedirectTarget) Action(pkt *PacketBuffer, ct *ConnTrack, hook Hook, r
// packet of the connection comes here. Other packets will be
// manipulated in connection tracking.
if conn := ct.insertRedirectConn(pkt, hook, rt.Port, address); conn != nil {
- ct.handlePacket(pkt, hook, r)
+ conn.handlePacket(pkt, hook, dirOriginal, r)
}
default:
return RuleDrop, 0
@@ -181,15 +181,7 @@ type SNATTarget struct {
NetworkProtocol tcpip.NetworkProtocolNumber
}
-// Action implements Target.Action.
-func (st *SNATTarget) Action(pkt *PacketBuffer, ct *ConnTrack, hook Hook, r *Route, _ AddressableEndpoint) (RuleVerdict, int) {
- // Sanity check.
- if st.NetworkProtocol != pkt.NetworkProtocolNumber {
- panic(fmt.Sprintf(
- "SNATTarget.Action with NetworkProtocol %d called on packet with NetworkProtocolNumber %d",
- st.NetworkProtocol, pkt.NetworkProtocolNumber))
- }
-
+func snatAction(pkt *PacketBuffer, ct *ConnTrack, hook Hook, r *Route, port uint16, address tcpip.Address) (RuleVerdict, int) {
// Packet is already manipulated.
if pkt.NatDone {
return RuleAccept, 0
@@ -200,16 +192,8 @@ func (st *SNATTarget) Action(pkt *PacketBuffer, ct *ConnTrack, hook Hook, r *Rou
return RuleDrop, 0
}
- switch hook {
- case Postrouting, Input:
- case Prerouting, Output, Forward:
- panic(fmt.Sprintf("%s not supported", hook))
- default:
- panic(fmt.Sprintf("%s unrecognized", hook))
- }
-
- port := st.Port
-
+ // TODO(https://gvisor.dev/issue/5773): If the port is in use, pick a
+ // different port.
if port == 0 {
switch protocol := pkt.TransportProtocolNumber; protocol {
case header.UDPProtocolNumber:
@@ -228,13 +212,69 @@ func (st *SNATTarget) Action(pkt *PacketBuffer, ct *ConnTrack, hook Hook, r *Rou
// tracking.
//
// Does nothing if the protocol does not support connection tracking.
- if conn := ct.insertSNATConn(pkt, hook, port, st.Addr); conn != nil {
- ct.handlePacket(pkt, hook, r)
+ if conn := ct.insertSNATConn(pkt, hook, port, address); conn != nil {
+ conn.handlePacket(pkt, hook, dirOriginal, r)
}
return RuleAccept, 0
}
+// Action implements Target.Action.
+func (st *SNATTarget) Action(pkt *PacketBuffer, ct *ConnTrack, hook Hook, r *Route, _ AddressableEndpoint) (RuleVerdict, int) {
+ // Sanity check.
+ if st.NetworkProtocol != pkt.NetworkProtocolNumber {
+ panic(fmt.Sprintf(
+ "SNATTarget.Action with NetworkProtocol %d called on packet with NetworkProtocolNumber %d",
+ st.NetworkProtocol, pkt.NetworkProtocolNumber))
+ }
+
+ switch hook {
+ case Postrouting, Input:
+ case Prerouting, Output, Forward:
+ panic(fmt.Sprintf("%s not supported", hook))
+ default:
+ panic(fmt.Sprintf("%s unrecognized", hook))
+ }
+
+ return snatAction(pkt, ct, hook, r, st.Port, st.Addr)
+}
+
+// MasqueradeTarget modifies the source port/IP in the outgoing packets.
+type MasqueradeTarget struct {
+ // NetworkProtocol is the network protocol the target is used with. It
+ // is immutable.
+ NetworkProtocol tcpip.NetworkProtocolNumber
+}
+
+// Action implements Target.Action.
+func (mt *MasqueradeTarget) Action(pkt *PacketBuffer, ct *ConnTrack, hook Hook, r *Route, addressEP AddressableEndpoint) (RuleVerdict, int) {
+ // Sanity check.
+ if mt.NetworkProtocol != pkt.NetworkProtocolNumber {
+ panic(fmt.Sprintf(
+ "MasqueradeTarget.Action with NetworkProtocol %d called on packet with NetworkProtocolNumber %d",
+ mt.NetworkProtocol, pkt.NetworkProtocolNumber))
+ }
+
+ switch hook {
+ case Postrouting:
+ case Prerouting, Input, Forward, Output:
+ panic(fmt.Sprintf("masquerade target is supported only on postrouting hook; hook = %d", hook))
+ default:
+ panic(fmt.Sprintf("%s unrecognized", hook))
+ }
+
+ // addressEP is expected to be set for the postrouting hook.
+ ep := addressEP.AcquireOutgoingPrimaryAddress(pkt.Network().DestinationAddress(), false /* allowExpired */)
+ if ep == nil {
+ // No address exists that we can use as a source address.
+ return RuleDrop, 0
+ }
+
+ address := ep.AddressWithPrefix().Address
+ ep.DecRef()
+ return snatAction(pkt, ct, hook, r, 0 /* port */, address)
+}
+
func rewritePacket(n header.Network, t header.ChecksummableTransport, updateSRCFields, fullChecksum, updatePseudoHeader bool, newPort uint16, newAddr tcpip.Address) {
if updateSRCFields {
if fullChecksum {
diff --git a/pkg/tcpip/stack/packet_buffer.go b/pkg/tcpip/stack/packet_buffer.go
index bf248ef20..456b0cf80 100644
--- a/pkg/tcpip/stack/packet_buffer.go
+++ b/pkg/tcpip/stack/packet_buffer.go
@@ -425,13 +425,14 @@ func (d PacketData) PullUp(size int) (tcpipbuffer.View, bool) {
return d.pk.buf.PullUp(d.pk.dataOffset(), size)
}
-// DeleteFront removes count from the beginning of d. It panics if count >
-// d.Size(). All backing storage references after the front of the d are
-// invalidated.
-func (d PacketData) DeleteFront(count int) {
- if !d.pk.buf.Remove(d.pk.dataOffset(), count) {
- panic("count > d.Size()")
+// Consume is the same as PullUp except that is additionally consumes the
+// returned bytes. Subsequent PullUp or Consume will not return these bytes.
+func (d PacketData) Consume(size int) (tcpipbuffer.View, bool) {
+ v, ok := d.PullUp(size)
+ if ok {
+ d.pk.consumed += size
}
+ return v, ok
}
// CapLength reduces d to at most length bytes.
diff --git a/pkg/tcpip/stack/packet_buffer_test.go b/pkg/tcpip/stack/packet_buffer_test.go
index 87b023445..c376ed1a1 100644
--- a/pkg/tcpip/stack/packet_buffer_test.go
+++ b/pkg/tcpip/stack/packet_buffer_test.go
@@ -123,32 +123,6 @@ func TestPacketHeaderPush(t *testing.T) {
}
}
-func TestPacketBufferClone(t *testing.T) {
- data := concatViews(makeView(20), makeView(30), makeView(40))
- pk := NewPacketBuffer(PacketBufferOptions{
- // Make a copy of data to make sure our truth data won't be taint by
- // PacketBuffer.
- Data: buffer.NewViewFromBytes(data).ToVectorisedView(),
- })
-
- bytesToDelete := 30
- originalSize := data.Size()
-
- clonedPks := []*PacketBuffer{
- pk.Clone(),
- pk.CloneToInbound(),
- }
- pk.Data().DeleteFront(bytesToDelete)
- if got, want := pk.Data().Size(), originalSize-bytesToDelete; got != want {
- t.Errorf("original packet was not changed: size expected = %d, got = %d", want, got)
- }
- for _, clonedPk := range clonedPks {
- if got := clonedPk.Data().Size(); got != originalSize {
- t.Errorf("cloned packet should not be modified: expected size = %d, got = %d", originalSize, got)
- }
- }
-}
-
func TestPacketHeaderConsume(t *testing.T) {
for _, test := range []struct {
name string
@@ -461,11 +435,17 @@ func TestPacketBufferData(t *testing.T) {
}
})
- // DeleteFront
+ // Consume.
for _, n := range []int{1, len(tc.data)} {
- t.Run(fmt.Sprintf("DeleteFront%d", n), func(t *testing.T) {
+ t.Run(fmt.Sprintf("Consume%d", n), func(t *testing.T) {
pkt := tc.makePkt(t)
- pkt.Data().DeleteFront(n)
+ v, ok := pkt.Data().Consume(n)
+ if !ok {
+ t.Fatalf("Consume failed")
+ }
+ if want := []byte(tc.data)[:n]; !bytes.Equal(v, want) {
+ t.Fatalf("pkt.Data().Consume(n) = 0x%x, want 0x%x", v, want)
+ }
checkData(t, pkt, []byte(tc.data)[n:])
})
diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go
index cd4137794..c23e91702 100644
--- a/pkg/tcpip/stack/stack_test.go
+++ b/pkg/tcpip/stack/stack_test.go
@@ -139,18 +139,15 @@ func (f *fakeNetworkEndpoint) HandlePacket(pkt *stack.PacketBuffer) {
// Handle control packets.
if netHdr[protocolNumberOffset] == uint8(fakeControlProtocol) {
- hdr, ok := pkt.Data().PullUp(fakeNetHeaderLen)
+ hdr, ok := pkt.Data().Consume(fakeNetHeaderLen)
if !ok {
return
}
- // DeleteFront invalidates slices. Make a copy before trimming.
- nb := append([]byte(nil), hdr...)
- pkt.Data().DeleteFront(fakeNetHeaderLen)
f.dispatcher.DeliverTransportError(
- tcpip.Address(nb[srcAddrOffset:srcAddrOffset+1]),
- tcpip.Address(nb[dstAddrOffset:dstAddrOffset+1]),
+ tcpip.Address(hdr[srcAddrOffset:srcAddrOffset+1]),
+ tcpip.Address(hdr[dstAddrOffset:dstAddrOffset+1]),
fakeNetNumber,
- tcpip.TransportProtocolNumber(nb[protocolNumberOffset]),
+ tcpip.TransportProtocolNumber(hdr[protocolNumberOffset]),
// Nothing checks the error.
nil, /* transport error */
pkt,