diff options
Diffstat (limited to 'pkg/tcpip/stack/iptables.go')
-rw-r--r-- | pkg/tcpip/stack/iptables.go | 53 |
1 files changed, 20 insertions, 33 deletions
diff --git a/pkg/tcpip/stack/iptables.go b/pkg/tcpip/stack/iptables.go index 5808be685..0baa378ea 100644 --- a/pkg/tcpip/stack/iptables.go +++ b/pkg/tcpip/stack/iptables.go @@ -277,10 +277,7 @@ func (it *IPTables) CheckPrerouting(pkt *PacketBuffer, addressEP AddressableEndp return true } - if t := it.connections.getConnOrMaybeInsertNoop(pkt); t != nil { - pkt.tuple = t - t.conn.handlePacket(pkt, hook, nil /* route */) - } + pkt.tuple = it.connections.getConnOrMaybeInsertNoop(pkt) return it.check(hook, pkt, nil /* route */, addressEP, inNicName, "" /* outNicName */) } @@ -298,10 +295,6 @@ func (it *IPTables) CheckInput(pkt *PacketBuffer, inNicName string) bool { return true } - if t := pkt.tuple; t != nil { - t.conn.handlePacket(pkt, hook, nil /* route */) - } - ret := it.check(hook, pkt, nil /* route */, nil /* addressEP */, inNicName, "" /* outNicName */) if t := pkt.tuple; t != nil { t.conn.finalize() @@ -336,10 +329,7 @@ func (it *IPTables) CheckOutput(pkt *PacketBuffer, r *Route, outNicName string) return true } - if t := it.connections.getConnOrMaybeInsertNoop(pkt); t != nil { - pkt.tuple = t - t.conn.handlePacket(pkt, hook, r) - } + pkt.tuple = it.connections.getConnOrMaybeInsertNoop(pkt) return it.check(hook, pkt, r, nil /* addressEP */, "" /* inNicName */, outNicName) } @@ -357,10 +347,6 @@ func (it *IPTables) CheckPostrouting(pkt *PacketBuffer, r *Route, addressEP Addr return true } - if t := pkt.tuple; t != nil { - t.conn.handlePacket(pkt, hook, r) - } - ret := it.check(hook, pkt, r, addressEP, "" /* inNicName */, outNicName) if t := pkt.tuple; t != nil { t.conn.finalize() @@ -396,9 +382,7 @@ func (it *IPTables) check(hook Hook, pkt *PacketBuffer, r *Route, addressEP Addr // Go through each table containing the hook. priorities := it.priorities[hook] for _, tableID := range priorities { - // If handlePacket already NATed the packet, we don't need to - // check the NAT table. - if tableID == NATID && pkt.NatDone { + if t := pkt.tuple; t != nil && tableID == NATID && t.conn.handlePacket(pkt, hook, r) { continue } var table Table @@ -476,7 +460,7 @@ func (it *IPTables) startReaper(interval time.Duration) { func (it *IPTables) CheckOutputPackets(pkts PacketBufferList, r *Route, outNicName string) (drop map[*PacketBuffer]struct{}, natPkts map[*PacketBuffer]struct{}) { return checkPackets(pkts, func(pkt *PacketBuffer) bool { return it.CheckOutput(pkt, r, outNicName) - }) + }, true /* dnat */) } // CheckPostroutingPackets performs the postrouting hook on the packets. @@ -487,24 +471,27 @@ func (it *IPTables) CheckOutputPackets(pkts PacketBufferList, r *Route, outNicNa func (it *IPTables) CheckPostroutingPackets(pkts PacketBufferList, r *Route, addressEP AddressableEndpoint, outNicName string) (drop map[*PacketBuffer]struct{}, natPkts map[*PacketBuffer]struct{}) { return checkPackets(pkts, func(pkt *PacketBuffer) bool { return it.CheckPostrouting(pkt, r, addressEP, outNicName) - }) + }, false /* dnat */) } -func checkPackets(pkts PacketBufferList, f func(*PacketBuffer) bool) (drop map[*PacketBuffer]struct{}, natPkts map[*PacketBuffer]struct{}) { +func checkPackets(pkts PacketBufferList, f func(*PacketBuffer) bool, dnat bool) (drop map[*PacketBuffer]struct{}, natPkts map[*PacketBuffer]struct{}) { for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() { - if !pkt.NatDone { - if ok := f(pkt); !ok { - if drop == nil { - drop = make(map[*PacketBuffer]struct{}) - } - drop[pkt] = struct{}{} + natDone := &pkt.SNATDone + if dnat { + natDone = &pkt.DNATDone + } + + if ok := f(pkt); !ok { + if drop == nil { + drop = make(map[*PacketBuffer]struct{}) } - if pkt.NatDone { - if natPkts == nil { - natPkts = make(map[*PacketBuffer]struct{}) - } - natPkts[pkt] = struct{}{} + drop[pkt] = struct{}{} + } + if *natDone { + if natPkts == nil { + natPkts = make(map[*PacketBuffer]struct{}) } + natPkts[pkt] = struct{}{} } } return drop, natPkts |