summaryrefslogtreecommitdiffhomepage
path: root/pkg
diff options
context:
space:
mode:
Diffstat (limited to 'pkg')
-rw-r--r--pkg/tcpip/stack/conntrack.go156
1 files changed, 137 insertions, 19 deletions
diff --git a/pkg/tcpip/stack/conntrack.go b/pkg/tcpip/stack/conntrack.go
index a3f403855..4a28be585 100644
--- a/pkg/tcpip/stack/conntrack.go
+++ b/pkg/tcpip/stack/conntrack.go
@@ -209,19 +209,120 @@ type bucket struct {
tuples tupleList
}
-func getTransportHeader(pkt *PacketBuffer) (header.ChecksummableTransport, bool) {
+func getHeaders(pkt *PacketBuffer) (netHdr header.Network, transHdr header.ChecksummableTransport, isICMPError bool, ok bool) {
switch pkt.TransportProtocolNumber {
case header.TCPProtocolNumber:
if tcpHeader := header.TCP(pkt.TransportHeader().View()); len(tcpHeader) >= header.TCPMinimumSize {
- return tcpHeader, true
+ return pkt.Network(), tcpHeader, false, true
}
case header.UDPProtocolNumber:
if udpHeader := header.UDP(pkt.TransportHeader().View()); len(udpHeader) >= header.UDPMinimumSize {
- return udpHeader, true
+ return pkt.Network(), udpHeader, false, true
+ }
+ case header.ICMPv4ProtocolNumber:
+ h, ok := pkt.Data().PullUp(header.IPv4MinimumSize)
+ if !ok {
+ panic(fmt.Sprintf("should have a valid IPv4 packet; only have %d bytes, want at least %d bytes", pkt.Data().Size(), header.IPv4MinimumSize))
+ }
+
+ ipv4 := header.IPv4(h)
+ if ipv4.HeaderLength() > header.IPv4MinimumSize {
+ // TODO(https://gvisor.dev/issue/6765): Handle IPv4 options.
+ panic("should have dropped packets with IPv4 options")
+ }
+
+ switch pkt.tuple.id().transProto {
+ case header.TCPProtocolNumber:
+ // TODO(https://gvisor.dev/issue/6765): Handle IPv4 options.
+ netAndTransHeader, ok := pkt.Data().PullUp(header.IPv4MinimumSize + header.TCPMinimumSize)
+ if !ok {
+ return nil, nil, false, false
+ }
+ netHeader := header.IPv4(netAndTransHeader)
+ return netHeader, header.TCP(netHeader.Payload()), true, true
+ case header.UDPProtocolNumber:
+ // TODO(https://gvisor.dev/issue/6765): Handle IPv4 options.
+ netAndTransHeader, ok := pkt.Data().PullUp(header.IPv4MinimumSize + header.UDPMinimumSize)
+ if !ok {
+ return nil, nil, false, false
+ }
+ netHeader := header.IPv4(netAndTransHeader)
+ return netHeader, header.UDP(netHeader.Payload()), true, true
}
}
- return nil, false
+ return nil, nil, false, false
+}
+
+func getTupleIDForRegularPacket(netHdr header.Network, netProto tcpip.NetworkProtocolNumber, transHdr header.Transport, transProto tcpip.TransportProtocolNumber) tupleID {
+ return tupleID{
+ srcAddr: netHdr.SourceAddress(),
+ srcPort: transHdr.SourcePort(),
+ dstAddr: netHdr.DestinationAddress(),
+ dstPort: transHdr.DestinationPort(),
+ transProto: transProto,
+ netProto: netProto,
+ }
+}
+
+func getTupleIDForPacketInICMPError(netHdr header.Network, netProto tcpip.NetworkProtocolNumber, transHdr header.Transport, transProto tcpip.TransportProtocolNumber) tupleID {
+ return tupleID{
+ srcAddr: netHdr.DestinationAddress(),
+ srcPort: transHdr.DestinationPort(),
+ dstAddr: netHdr.SourceAddress(),
+ dstPort: transHdr.SourcePort(),
+ transProto: transProto,
+ netProto: netProto,
+ }
+}
+
+func getTupleID(pkt *PacketBuffer) (tid tupleID, isICMPError bool, ok bool) {
+ switch pkt.TransportProtocolNumber {
+ case header.TCPProtocolNumber:
+ if transHeader := header.TCP(pkt.TransportHeader().View()); len(transHeader) >= header.TCPMinimumSize {
+ return getTupleIDForRegularPacket(pkt.Network(), pkt.NetworkProtocolNumber, transHeader, pkt.TransportProtocolNumber), false, true
+ }
+ case header.UDPProtocolNumber:
+ if transHeader := header.UDP(pkt.TransportHeader().View()); len(transHeader) >= header.UDPMinimumSize {
+ return getTupleIDForRegularPacket(pkt.Network(), pkt.NetworkProtocolNumber, transHeader, pkt.TransportProtocolNumber), false, true
+ }
+ case header.ICMPv4ProtocolNumber:
+ icmp := header.ICMPv4(pkt.TransportHeader().View())
+ if len(icmp) < header.ICMPv4MinimumSize {
+ return tupleID{}, false, false
+ }
+
+ switch icmp.Type() {
+ case header.ICMPv4DstUnreachable, header.ICMPv4TimeExceeded, header.ICMPv4ParamProblem:
+ default:
+ return tupleID{}, false, false
+ }
+
+ h, ok := pkt.Data().PullUp(header.IPv4MinimumSize)
+ if !ok {
+ return tupleID{}, false, false
+ }
+
+ ipv4 := header.IPv4(h)
+ if ipv4.HeaderLength() > header.IPv4MinimumSize {
+ // TODO(https://gvisor.dev/issue/6765): Handle IPv4 options.
+ return tupleID{}, false, false
+ }
+ switch ipv4.TransportProtocol() {
+ case header.TCPProtocolNumber:
+ if netAndTransHeader, ok := pkt.Data().PullUp(header.IPv4MinimumSize + header.TCPMinimumSize); ok {
+ netHdr := header.IPv4(netAndTransHeader)
+ return getTupleIDForPacketInICMPError(netHdr, header.IPv4ProtocolNumber, header.TCP(netHdr.Payload()), header.TCPProtocolNumber), true, true
+ }
+ case header.UDPProtocolNumber:
+ if netAndTransHeader, ok := pkt.Data().PullUp(header.IPv4MinimumSize + header.UDPMinimumSize); ok {
+ netHdr := header.IPv4(netAndTransHeader)
+ return getTupleIDForPacketInICMPError(netHdr, header.IPv4ProtocolNumber, header.UDP(netHdr.Payload()), header.UDPProtocolNumber), true, true
+ }
+ }
+ }
+
+ return tupleID{}, false, false
}
func (ct *ConnTrack) init() {
@@ -231,21 +332,11 @@ func (ct *ConnTrack) init() {
}
func (ct *ConnTrack) getConnOrMaybeInsertNoop(pkt *PacketBuffer) *tuple {
- netHeader := pkt.Network()
- transportHeader, ok := getTransportHeader(pkt)
+ tid, isICMPError, ok := getTupleID(pkt)
if !ok {
return nil
}
- tid := tupleID{
- srcAddr: netHeader.SourceAddress(),
- srcPort: transportHeader.SourcePort(),
- dstAddr: netHeader.DestinationAddress(),
- dstPort: transportHeader.DestinationPort(),
- transProto: pkt.TransportProtocolNumber,
- netProto: pkt.NetworkProtocolNumber,
- }
-
bktID := ct.bucket(tid)
ct.mu.RLock()
@@ -257,6 +348,11 @@ func (ct *ConnTrack) getConnOrMaybeInsertNoop(pkt *PacketBuffer) *tuple {
return t
}
+ if isICMPError {
+ // Do not create a noop entry in response to an ICMP error.
+ return nil
+ }
+
bkt.mu.Lock()
defer bkt.mu.Unlock()
@@ -407,7 +503,7 @@ func (cn *conn) performNATIfNoop(port uint16, address tcpip.Address, dnat bool)
//
// Returns true if the packet can skip the NAT table.
func (cn *conn) handlePacket(pkt *PacketBuffer, hook Hook, rt *Route) bool {
- transportHeader, ok := getTransportHeader(pkt)
+ netHdr, transHdr, isICMPError, ok := getHeaders(pkt)
if !ok {
return false
}
@@ -498,9 +594,9 @@ func (cn *conn) handlePacket(pkt *PacketBuffer, hook Hook, rt *Route) bool {
}
rewritePacket(
- pkt.Network(),
- transportHeader,
- !dnat,
+ netHdr,
+ transHdr,
+ !dnat != isICMPError,
fullChecksum,
updatePseudoHeader,
newPort,
@@ -508,6 +604,28 @@ func (cn *conn) handlePacket(pkt *PacketBuffer, hook Hook, rt *Route) bool {
)
*natDone = true
+
+ if !isICMPError {
+ return true
+ }
+
+ // We performed NAT on (erroneous) packet that triggered an ICMP response, but
+ // not the ICMP packet itself.
+ switch pkt.TransportProtocolNumber {
+ case header.ICMPv4ProtocolNumber:
+ icmp := header.ICMPv4(pkt.TransportHeader().View())
+ // TODO(https://gvisor.dev/issue/6788): Incrementally update ICMP checksum.
+ icmp.SetChecksum(0)
+ icmp.SetChecksum(header.ICMPv4Checksum(icmp, pkt.Data().AsRange().Checksum()))
+
+ network := header.IPv4(pkt.NetworkHeader().View())
+ if dnat {
+ network.SetDestinationAddressWithChecksumUpdate(tid.srcAddr)
+ } else {
+ network.SetSourceAddressWithChecksumUpdate(tid.dstAddr)
+ }
+ }
+
return true
}