diff options
Diffstat (limited to 'pkg/tcpip/stack')
-rw-r--r-- | pkg/tcpip/stack/conntrack.go | 65 | ||||
-rw-r--r-- | pkg/tcpip/stack/iptables.go | 25 | ||||
-rw-r--r-- | pkg/tcpip/stack/iptables_targets.go | 84 | ||||
-rw-r--r-- | pkg/tcpip/stack/packet_buffer.go | 13 | ||||
-rw-r--r-- | pkg/tcpip/stack/packet_buffer_test.go | 38 | ||||
-rw-r--r-- | pkg/tcpip/stack/stack_test.go | 11 |
6 files changed, 127 insertions, 109 deletions
diff --git a/pkg/tcpip/stack/conntrack.go b/pkg/tcpip/stack/conntrack.go index 4fb7e9adb..b7cb54b1d 100644 --- a/pkg/tcpip/stack/conntrack.go +++ b/pkg/tcpip/stack/conntrack.go @@ -388,28 +388,33 @@ func (ct *ConnTrack) insertConn(conn *conn) { // connection exists. Returns whether, after the packet traverses the tables, // it should create a new entry in the table. func (ct *ConnTrack) handlePacket(pkt *PacketBuffer, hook Hook, r *Route) bool { - if pkt.NatDone { - return false - } - switch hook { case Prerouting, Input, Output, Postrouting: default: return false } - transportHeader, ok := getTransportHeader(pkt) - if !ok { + if conn, dir := ct.connFor(pkt); conn != nil { + conn.handlePacket(pkt, hook, dir, r) return false } - conn, dir := ct.connFor(pkt) // Connection not found for the packet. - if conn == nil { - // If this is the last hook in the data path for this packet (Input if - // incoming, Postrouting if outgoing), indicate that a connection should be - // inserted by the end of this hook. - return hook == Input || hook == Postrouting + // + // If this is the last hook in the data path for this packet (Input if + // incoming, Postrouting if outgoing), indicate that a connection should be + // inserted by the end of this hook. + return hook == Input || hook == Postrouting +} + +func (cn *conn) handlePacket(pkt *PacketBuffer, hook Hook, dir direction, r *Route) { + if pkt.NatDone { + return + } + + transportHeader, ok := getTransportHeader(pkt) + if !ok { + return } netHeader := pkt.Network() @@ -425,24 +430,24 @@ func (ct *ConnTrack) handlePacket(pkt *PacketBuffer, hook Hook, r *Route) bool { switch hook { case Prerouting, Output: - if conn.manip == manipDestination && dir == dirOriginal { - newPort = conn.reply.srcPort - newAddr = conn.reply.srcAddr + if cn.manip == manipDestination && dir == dirOriginal { + newPort = cn.reply.srcPort + newAddr = cn.reply.srcAddr pkt.NatDone = true - } else if conn.manip == manipSource && dir == dirReply { - newPort = conn.original.srcPort - newAddr = conn.original.srcAddr + } else if cn.manip == manipSource && dir == dirReply { + newPort = cn.original.srcPort + newAddr = cn.original.srcAddr pkt.NatDone = true } case Input, Postrouting: - if conn.manip == manipSource && dir == dirOriginal { - newPort = conn.reply.dstPort - newAddr = conn.reply.dstAddr + if cn.manip == manipSource && dir == dirOriginal { + newPort = cn.reply.dstPort + newAddr = cn.reply.dstAddr updateSRCFields = true pkt.NatDone = true - } else if conn.manip == manipDestination && dir == dirReply { - newPort = conn.original.dstPort - newAddr = conn.original.dstAddr + } else if cn.manip == manipDestination && dir == dirReply { + newPort = cn.original.dstPort + newAddr = cn.original.dstAddr updateSRCFields = true pkt.NatDone = true } @@ -451,7 +456,7 @@ func (ct *ConnTrack) handlePacket(pkt *PacketBuffer, hook Hook, r *Route) bool { } if !pkt.NatDone { - return false + return } fullChecksum := false @@ -486,15 +491,13 @@ func (ct *ConnTrack) handlePacket(pkt *PacketBuffer, hook Hook, r *Route) bool { ) // Update the state of tcb. - conn.mu.Lock() - defer conn.mu.Unlock() + cn.mu.Lock() + defer cn.mu.Unlock() // Mark the connection as having been used recently so it isn't reaped. - conn.lastUsed = time.Now() + cn.lastUsed = time.Now() // Update connection state. - conn.updateLocked(pkt, hook) - - return false + cn.updateLocked(pkt, hook) } // maybeInsertNoop tries to insert a no-op connection entry to keep connections diff --git a/pkg/tcpip/stack/iptables.go b/pkg/tcpip/stack/iptables.go index 74c9075b4..dcba7eba6 100644 --- a/pkg/tcpip/stack/iptables.go +++ b/pkg/tcpip/stack/iptables.go @@ -310,8 +310,8 @@ func (it *IPTables) CheckOutput(pkt *PacketBuffer, r *Route, outNicName string) // must be dropped if false is returned. // // Precondition: The packet's network and transport header must be set. -func (it *IPTables) CheckPostrouting(pkt *PacketBuffer, r *Route, outNicName string) bool { - return it.check(Postrouting, pkt, r, nil /* addressEP */, "" /* inNicName */, outNicName) +func (it *IPTables) CheckPostrouting(pkt *PacketBuffer, r *Route, addressEP AddressableEndpoint, outNicName string) bool { + return it.check(Postrouting, pkt, r, addressEP, "" /* inNicName */, outNicName) } // check runs pkt through the rules for hook. It returns true when the packet @@ -431,7 +431,9 @@ func (it *IPTables) startReaper(interval time.Duration) { // // Precondition: The packets' network and transport header must be set. func (it *IPTables) CheckOutputPackets(pkts PacketBufferList, r *Route, outNicName string) (drop map[*PacketBuffer]struct{}, natPkts map[*PacketBuffer]struct{}) { - return it.checkPackets(Output, pkts, r, outNicName) + return checkPackets(pkts, func(pkt *PacketBuffer) bool { + return it.CheckOutput(pkt, r, outNicName) + }) } // CheckPostroutingPackets performs the postrouting hook on the packets. @@ -439,21 +441,16 @@ func (it *IPTables) CheckOutputPackets(pkts PacketBufferList, r *Route, outNicNa // Returns a map of packets that must be dropped. // // Precondition: The packets' network and transport header must be set. -func (it *IPTables) CheckPostroutingPackets(pkts PacketBufferList, r *Route, outNicName string) (drop map[*PacketBuffer]struct{}, natPkts map[*PacketBuffer]struct{}) { - return it.checkPackets(Postrouting, pkts, r, outNicName) +func (it *IPTables) CheckPostroutingPackets(pkts PacketBufferList, r *Route, addressEP AddressableEndpoint, outNicName string) (drop map[*PacketBuffer]struct{}, natPkts map[*PacketBuffer]struct{}) { + return checkPackets(pkts, func(pkt *PacketBuffer) bool { + return it.CheckPostrouting(pkt, r, addressEP, outNicName) + }) } -// checkPackets runs pkts through the rules for hook and returns a map of -// packets that should not go forward. -// -// NOTE: unlike the Check API the returned map contains packets that should be -// dropped. -// -// Precondition: The packets' network and transport header must be set. -func (it *IPTables) checkPackets(hook Hook, pkts PacketBufferList, r *Route, outNicName string) (drop map[*PacketBuffer]struct{}, natPkts map[*PacketBuffer]struct{}) { +func checkPackets(pkts PacketBufferList, f func(*PacketBuffer) bool) (drop map[*PacketBuffer]struct{}, natPkts map[*PacketBuffer]struct{}) { for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() { if !pkt.NatDone { - if ok := it.check(hook, pkt, r, nil /* addressEP */, "" /* inNicName */, outNicName); !ok { + if ok := f(pkt); !ok { if drop == nil { drop = make(map[*PacketBuffer]struct{}) } diff --git a/pkg/tcpip/stack/iptables_targets.go b/pkg/tcpip/stack/iptables_targets.go index e8806ebdb..949c44c9b 100644 --- a/pkg/tcpip/stack/iptables_targets.go +++ b/pkg/tcpip/stack/iptables_targets.go @@ -162,7 +162,7 @@ func (rt *RedirectTarget) Action(pkt *PacketBuffer, ct *ConnTrack, hook Hook, r // packet of the connection comes here. Other packets will be // manipulated in connection tracking. if conn := ct.insertRedirectConn(pkt, hook, rt.Port, address); conn != nil { - ct.handlePacket(pkt, hook, r) + conn.handlePacket(pkt, hook, dirOriginal, r) } default: return RuleDrop, 0 @@ -181,15 +181,7 @@ type SNATTarget struct { NetworkProtocol tcpip.NetworkProtocolNumber } -// Action implements Target.Action. -func (st *SNATTarget) Action(pkt *PacketBuffer, ct *ConnTrack, hook Hook, r *Route, _ AddressableEndpoint) (RuleVerdict, int) { - // Sanity check. - if st.NetworkProtocol != pkt.NetworkProtocolNumber { - panic(fmt.Sprintf( - "SNATTarget.Action with NetworkProtocol %d called on packet with NetworkProtocolNumber %d", - st.NetworkProtocol, pkt.NetworkProtocolNumber)) - } - +func snatAction(pkt *PacketBuffer, ct *ConnTrack, hook Hook, r *Route, port uint16, address tcpip.Address) (RuleVerdict, int) { // Packet is already manipulated. if pkt.NatDone { return RuleAccept, 0 @@ -200,16 +192,8 @@ func (st *SNATTarget) Action(pkt *PacketBuffer, ct *ConnTrack, hook Hook, r *Rou return RuleDrop, 0 } - switch hook { - case Postrouting, Input: - case Prerouting, Output, Forward: - panic(fmt.Sprintf("%s not supported", hook)) - default: - panic(fmt.Sprintf("%s unrecognized", hook)) - } - - port := st.Port - + // TODO(https://gvisor.dev/issue/5773): If the port is in use, pick a + // different port. if port == 0 { switch protocol := pkt.TransportProtocolNumber; protocol { case header.UDPProtocolNumber: @@ -228,13 +212,69 @@ func (st *SNATTarget) Action(pkt *PacketBuffer, ct *ConnTrack, hook Hook, r *Rou // tracking. // // Does nothing if the protocol does not support connection tracking. - if conn := ct.insertSNATConn(pkt, hook, port, st.Addr); conn != nil { - ct.handlePacket(pkt, hook, r) + if conn := ct.insertSNATConn(pkt, hook, port, address); conn != nil { + conn.handlePacket(pkt, hook, dirOriginal, r) } return RuleAccept, 0 } +// Action implements Target.Action. +func (st *SNATTarget) Action(pkt *PacketBuffer, ct *ConnTrack, hook Hook, r *Route, _ AddressableEndpoint) (RuleVerdict, int) { + // Sanity check. + if st.NetworkProtocol != pkt.NetworkProtocolNumber { + panic(fmt.Sprintf( + "SNATTarget.Action with NetworkProtocol %d called on packet with NetworkProtocolNumber %d", + st.NetworkProtocol, pkt.NetworkProtocolNumber)) + } + + switch hook { + case Postrouting, Input: + case Prerouting, Output, Forward: + panic(fmt.Sprintf("%s not supported", hook)) + default: + panic(fmt.Sprintf("%s unrecognized", hook)) + } + + return snatAction(pkt, ct, hook, r, st.Port, st.Addr) +} + +// MasqueradeTarget modifies the source port/IP in the outgoing packets. +type MasqueradeTarget struct { + // NetworkProtocol is the network protocol the target is used with. It + // is immutable. + NetworkProtocol tcpip.NetworkProtocolNumber +} + +// Action implements Target.Action. +func (mt *MasqueradeTarget) Action(pkt *PacketBuffer, ct *ConnTrack, hook Hook, r *Route, addressEP AddressableEndpoint) (RuleVerdict, int) { + // Sanity check. + if mt.NetworkProtocol != pkt.NetworkProtocolNumber { + panic(fmt.Sprintf( + "MasqueradeTarget.Action with NetworkProtocol %d called on packet with NetworkProtocolNumber %d", + mt.NetworkProtocol, pkt.NetworkProtocolNumber)) + } + + switch hook { + case Postrouting: + case Prerouting, Input, Forward, Output: + panic(fmt.Sprintf("masquerade target is supported only on postrouting hook; hook = %d", hook)) + default: + panic(fmt.Sprintf("%s unrecognized", hook)) + } + + // addressEP is expected to be set for the postrouting hook. + ep := addressEP.AcquireOutgoingPrimaryAddress(pkt.Network().DestinationAddress(), false /* allowExpired */) + if ep == nil { + // No address exists that we can use as a source address. + return RuleDrop, 0 + } + + address := ep.AddressWithPrefix().Address + ep.DecRef() + return snatAction(pkt, ct, hook, r, 0 /* port */, address) +} + func rewritePacket(n header.Network, t header.ChecksummableTransport, updateSRCFields, fullChecksum, updatePseudoHeader bool, newPort uint16, newAddr tcpip.Address) { if updateSRCFields { if fullChecksum { diff --git a/pkg/tcpip/stack/packet_buffer.go b/pkg/tcpip/stack/packet_buffer.go index bf248ef20..456b0cf80 100644 --- a/pkg/tcpip/stack/packet_buffer.go +++ b/pkg/tcpip/stack/packet_buffer.go @@ -425,13 +425,14 @@ func (d PacketData) PullUp(size int) (tcpipbuffer.View, bool) { return d.pk.buf.PullUp(d.pk.dataOffset(), size) } -// DeleteFront removes count from the beginning of d. It panics if count > -// d.Size(). All backing storage references after the front of the d are -// invalidated. -func (d PacketData) DeleteFront(count int) { - if !d.pk.buf.Remove(d.pk.dataOffset(), count) { - panic("count > d.Size()") +// Consume is the same as PullUp except that is additionally consumes the +// returned bytes. Subsequent PullUp or Consume will not return these bytes. +func (d PacketData) Consume(size int) (tcpipbuffer.View, bool) { + v, ok := d.PullUp(size) + if ok { + d.pk.consumed += size } + return v, ok } // CapLength reduces d to at most length bytes. diff --git a/pkg/tcpip/stack/packet_buffer_test.go b/pkg/tcpip/stack/packet_buffer_test.go index 87b023445..c376ed1a1 100644 --- a/pkg/tcpip/stack/packet_buffer_test.go +++ b/pkg/tcpip/stack/packet_buffer_test.go @@ -123,32 +123,6 @@ func TestPacketHeaderPush(t *testing.T) { } } -func TestPacketBufferClone(t *testing.T) { - data := concatViews(makeView(20), makeView(30), makeView(40)) - pk := NewPacketBuffer(PacketBufferOptions{ - // Make a copy of data to make sure our truth data won't be taint by - // PacketBuffer. - Data: buffer.NewViewFromBytes(data).ToVectorisedView(), - }) - - bytesToDelete := 30 - originalSize := data.Size() - - clonedPks := []*PacketBuffer{ - pk.Clone(), - pk.CloneToInbound(), - } - pk.Data().DeleteFront(bytesToDelete) - if got, want := pk.Data().Size(), originalSize-bytesToDelete; got != want { - t.Errorf("original packet was not changed: size expected = %d, got = %d", want, got) - } - for _, clonedPk := range clonedPks { - if got := clonedPk.Data().Size(); got != originalSize { - t.Errorf("cloned packet should not be modified: expected size = %d, got = %d", originalSize, got) - } - } -} - func TestPacketHeaderConsume(t *testing.T) { for _, test := range []struct { name string @@ -461,11 +435,17 @@ func TestPacketBufferData(t *testing.T) { } }) - // DeleteFront + // Consume. for _, n := range []int{1, len(tc.data)} { - t.Run(fmt.Sprintf("DeleteFront%d", n), func(t *testing.T) { + t.Run(fmt.Sprintf("Consume%d", n), func(t *testing.T) { pkt := tc.makePkt(t) - pkt.Data().DeleteFront(n) + v, ok := pkt.Data().Consume(n) + if !ok { + t.Fatalf("Consume failed") + } + if want := []byte(tc.data)[:n]; !bytes.Equal(v, want) { + t.Fatalf("pkt.Data().Consume(n) = 0x%x, want 0x%x", v, want) + } checkData(t, pkt, []byte(tc.data)[n:]) }) diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go index cd4137794..c23e91702 100644 --- a/pkg/tcpip/stack/stack_test.go +++ b/pkg/tcpip/stack/stack_test.go @@ -139,18 +139,15 @@ func (f *fakeNetworkEndpoint) HandlePacket(pkt *stack.PacketBuffer) { // Handle control packets. if netHdr[protocolNumberOffset] == uint8(fakeControlProtocol) { - hdr, ok := pkt.Data().PullUp(fakeNetHeaderLen) + hdr, ok := pkt.Data().Consume(fakeNetHeaderLen) if !ok { return } - // DeleteFront invalidates slices. Make a copy before trimming. - nb := append([]byte(nil), hdr...) - pkt.Data().DeleteFront(fakeNetHeaderLen) f.dispatcher.DeliverTransportError( - tcpip.Address(nb[srcAddrOffset:srcAddrOffset+1]), - tcpip.Address(nb[dstAddrOffset:dstAddrOffset+1]), + tcpip.Address(hdr[srcAddrOffset:srcAddrOffset+1]), + tcpip.Address(hdr[dstAddrOffset:dstAddrOffset+1]), fakeNetNumber, - tcpip.TransportProtocolNumber(nb[protocolNumberOffset]), + tcpip.TransportProtocolNumber(hdr[protocolNumberOffset]), // Nothing checks the error. nil, /* transport error */ pkt, |