summaryrefslogtreecommitdiffhomepage
path: root/pkg/tcpip/stack/conntrack.go
diff options
context:
space:
mode:
Diffstat (limited to 'pkg/tcpip/stack/conntrack.go')
-rw-r--r--pkg/tcpip/stack/conntrack.go40
1 files changed, 24 insertions, 16 deletions
diff --git a/pkg/tcpip/stack/conntrack.go b/pkg/tcpip/stack/conntrack.go
index 457f0c89b..0cd1da11f 100644
--- a/pkg/tcpip/stack/conntrack.go
+++ b/pkg/tcpip/stack/conntrack.go
@@ -196,13 +196,14 @@ type bucket struct {
// packetToTupleID converts packet to a tuple ID. It fails when pkt lacks a valid
// TCP header.
+//
+// Preconditions: pkt.NetworkHeader() is valid.
func packetToTupleID(pkt *PacketBuffer) (tupleID, *tcpip.Error) {
- // TODO(gvisor.dev/issue/170): Need to support for other
- // protocols as well.
- netHeader := header.IPv4(pkt.NetworkHeader().View())
- if len(netHeader) < header.IPv4MinimumSize || netHeader.TransportProtocol() != header.TCPProtocolNumber {
+ netHeader := pkt.Network()
+ if netHeader.TransportProtocol() != header.TCPProtocolNumber {
return tupleID{}, tcpip.ErrUnknownProtocol
}
+
tcpHeader := header.TCP(pkt.TransportHeader().View())
if len(tcpHeader) < header.TCPMinimumSize {
return tupleID{}, tcpip.ErrUnknownProtocol
@@ -214,7 +215,7 @@ func packetToTupleID(pkt *PacketBuffer) (tupleID, *tcpip.Error) {
dstAddr: netHeader.DestinationAddress(),
dstPort: tcpHeader.DestinationPort(),
transProto: netHeader.TransportProtocol(),
- netProto: header.IPv4ProtocolNumber,
+ netProto: pkt.NetworkProtocolNumber,
}, nil
}
@@ -344,7 +345,7 @@ func handlePacketPrerouting(pkt *PacketBuffer, conn *conn, dir direction) {
return
}
- netHeader := header.IPv4(pkt.NetworkHeader().View())
+ netHeader := pkt.Network()
tcpHeader := header.TCP(pkt.TransportHeader().View())
// For prerouting redirection, packets going in the original direction
@@ -366,8 +367,12 @@ func handlePacketPrerouting(pkt *PacketBuffer, conn *conn, dir direction) {
// support cases when they are validated, e.g. when we can't offload
// receive checksumming.
- netHeader.SetChecksum(0)
- netHeader.SetChecksum(^netHeader.CalculateChecksum())
+ // After modification, IPv4 packets need a valid checksum.
+ if pkt.NetworkProtocolNumber == header.IPv4ProtocolNumber {
+ netHeader := header.IPv4(pkt.NetworkHeader().View())
+ netHeader.SetChecksum(0)
+ netHeader.SetChecksum(^netHeader.CalculateChecksum())
+ }
}
// handlePacketOutput manipulates ports for packets in Output hook.
@@ -377,7 +382,7 @@ func handlePacketOutput(pkt *PacketBuffer, conn *conn, gso *GSO, r *Route, dir d
return
}
- netHeader := header.IPv4(pkt.NetworkHeader().View())
+ netHeader := pkt.Network()
tcpHeader := header.TCP(pkt.TransportHeader().View())
// For output redirection, packets going in the original direction
@@ -396,7 +401,7 @@ func handlePacketOutput(pkt *PacketBuffer, conn *conn, gso *GSO, r *Route, dir d
// Calculate the TCP checksum and set it.
tcpHeader.SetChecksum(0)
- length := uint16(pkt.Size()) - uint16(netHeader.HeaderLength())
+ length := uint16(pkt.Size()) - uint16(len(pkt.NetworkHeader().View()))
xsum := r.PseudoHeaderChecksum(header.TCPProtocolNumber, length)
if gso != nil && gso.NeedsCsum {
tcpHeader.SetChecksum(xsum)
@@ -405,8 +410,11 @@ func handlePacketOutput(pkt *PacketBuffer, conn *conn, gso *GSO, r *Route, dir d
tcpHeader.SetChecksum(^tcpHeader.CalculateChecksum(xsum))
}
- netHeader.SetChecksum(0)
- netHeader.SetChecksum(^netHeader.CalculateChecksum())
+ if pkt.NetworkProtocolNumber == header.IPv4ProtocolNumber {
+ netHeader := header.IPv4(pkt.NetworkHeader().View())
+ netHeader.SetChecksum(0)
+ netHeader.SetChecksum(^netHeader.CalculateChecksum())
+ }
}
// handlePacket will manipulate the port and address of the packet if the
@@ -422,7 +430,7 @@ func (ct *ConnTrack) handlePacket(pkt *PacketBuffer, hook Hook, gso *GSO, r *Rou
}
// TODO(gvisor.dev/issue/170): Support other transport protocols.
- if nh := pkt.NetworkHeader().View(); nh.IsEmpty() || header.IPv4(nh).TransportProtocol() != header.TCPProtocolNumber {
+ if pkt.Network().TransportProtocol() != header.TCPProtocolNumber {
return false
}
@@ -473,7 +481,7 @@ func (ct *ConnTrack) maybeInsertNoop(pkt *PacketBuffer, hook Hook) {
}
// We only track TCP connections.
- if nh := pkt.NetworkHeader().View(); nh.IsEmpty() || header.IPv4(nh).TransportProtocol() != header.TCPProtocolNumber {
+ if pkt.Network().TransportProtocol() != header.TCPProtocolNumber {
return
}
@@ -609,7 +617,7 @@ func (ct *ConnTrack) reapTupleLocked(tuple *tuple, bucket int, now time.Time) bo
return true
}
-func (ct *ConnTrack) originalDst(epID TransportEndpointID) (tcpip.Address, uint16, *tcpip.Error) {
+func (ct *ConnTrack) originalDst(epID TransportEndpointID, netProto tcpip.NetworkProtocolNumber) (tcpip.Address, uint16, *tcpip.Error) {
// Lookup the connection. The reply's original destination
// describes the original address.
tid := tupleID{
@@ -618,7 +626,7 @@ func (ct *ConnTrack) originalDst(epID TransportEndpointID) (tcpip.Address, uint1
dstAddr: epID.RemoteAddress,
dstPort: epID.RemotePort,
transProto: header.TCPProtocolNumber,
- netProto: header.IPv4ProtocolNumber,
+ netProto: netProto,
}
conn, _ := ct.connForTID(tid)
if conn == nil {