diff options
author | Ghanan Gowripalan <ghanan@google.com> | 2021-06-24 22:38:14 -0700 |
---|---|---|
committer | gVisor bot <gvisor-bot@google.com> | 2021-06-24 22:45:17 -0700 |
commit | 1f113b96e68fed452e40855db0cf3efa24b2b9b6 (patch) | |
tree | 7af96816bd25d99469b90b77d9c69204b3559a33 | |
parent | 49986674aaefd1aff50cc35cf1089206e174325c (diff) |
Incrementally update checksum when NAT-ing
...instead of calculating a fresh checksum to avoid re-calcalculating
a checksum on unchanged bytes.
Fixes #5340.
PiperOrigin-RevId: 381403888
-rw-r--r-- | pkg/tcpip/header/checksum.go | 62 | ||||
-rw-r--r-- | pkg/tcpip/header/checksum_test.go | 203 | ||||
-rw-r--r-- | pkg/tcpip/header/interfaces.go | 38 | ||||
-rw-r--r-- | pkg/tcpip/header/ipv4.go | 12 | ||||
-rw-r--r-- | pkg/tcpip/header/tcp.go | 29 | ||||
-rw-r--r-- | pkg/tcpip/header/udp.go | 29 | ||||
-rw-r--r-- | pkg/tcpip/stack/conntrack.go | 51 | ||||
-rw-r--r-- | pkg/tcpip/stack/iptables_targets.go | 97 |
8 files changed, 465 insertions, 56 deletions
diff --git a/pkg/tcpip/header/checksum.go b/pkg/tcpip/header/checksum.go index 6aa9acfa8..e2c85e220 100644 --- a/pkg/tcpip/header/checksum.go +++ b/pkg/tcpip/header/checksum.go @@ -18,6 +18,7 @@ package header import ( "encoding/binary" + "fmt" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" @@ -234,3 +235,64 @@ func PseudoHeaderChecksum(protocol tcpip.TransportProtocolNumber, srcAddr tcpip. return Checksum([]byte{0, uint8(protocol)}, xsum) } + +// checksumUpdate2ByteAlignedUint16 updates a uint16 value in a calculated +// checksum. +// +// The value MUST begin at a 2-byte boundary in the original buffer. +func checksumUpdate2ByteAlignedUint16(xsum, old, new uint16) uint16 { + // As per RFC 1071 page 4, + // (4) Incremental Update + // + // ... + // + // To update the checksum, simply add the differences of the + // sixteen bit integers that have been changed. To see why this + // works, observe that every 16-bit integer has an additive inverse + // and that addition is associative. From this it follows that + // given the original value m, the new value m', and the old + // checksum C, the new checksum C' is: + // + // C' = C + (-m) + m' = C + (m' - m) + return ChecksumCombine(xsum, ChecksumCombine(new, ^old)) +} + +// checksumUpdate2ByteAlignedAddress updates an address in a calculated +// checksum. +// +// The addresses must have the same length and must contain an even number +// of bytes. The address MUST begin at a 2-byte boundary in the original buffer. +func checksumUpdate2ByteAlignedAddress(xsum uint16, old, new tcpip.Address) uint16 { + const uint16Bytes = 2 + + if len(old) != len(new) { + panic(fmt.Sprintf("buffer lengths are different; old = %d, new = %d", len(old), len(new))) + } + + if len(old)%uint16Bytes != 0 { + panic(fmt.Sprintf("buffer has an odd number of bytes; got = %d", len(old))) + } + + // As per RFC 1071 page 4, + // (4) Incremental Update + // + // ... + // + // To update the checksum, simply add the differences of the + // sixteen bit integers that have been changed. To see why this + // works, observe that every 16-bit integer has an additive inverse + // and that addition is associative. From this it follows that + // given the original value m, the new value m', and the old + // checksum C, the new checksum C' is: + // + // C' = C + (-m) + m' = C + (m' - m) + for len(old) != 0 { + // Convert the 2 byte sequences to uint16 values then apply the increment + // update. + xsum = checksumUpdate2ByteAlignedUint16(xsum, (uint16(old[0])<<8)+uint16(old[1]), (uint16(new[0])<<8)+uint16(new[1])) + old = old[uint16Bytes:] + new = new[uint16Bytes:] + } + + return xsum +} diff --git a/pkg/tcpip/header/checksum_test.go b/pkg/tcpip/header/checksum_test.go index d267dabd0..3445511f4 100644 --- a/pkg/tcpip/header/checksum_test.go +++ b/pkg/tcpip/header/checksum_test.go @@ -23,6 +23,7 @@ import ( "sync" "testing" + "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" ) @@ -256,3 +257,205 @@ func TestICMPv6Checksum(t *testing.T) { }) }, want, fmt.Sprintf("header: {% x} data {% x}", h, vv.ToView())) } + +func randomAddress(size int) tcpip.Address { + s := make([]byte, size) + for i := 0; i < size; i++ { + s[i] = byte(rand.Uint32()) + } + return tcpip.Address(s) +} + +func TestChecksummableNetworkUpdateAddress(t *testing.T) { + tests := []struct { + name string + update func(header.IPv4, tcpip.Address) + }{ + { + name: "SetSourceAddressWithChecksumUpdate", + update: header.IPv4.SetSourceAddressWithChecksumUpdate, + }, + { + name: "SetDestinationAddressWithChecksumUpdate", + update: header.IPv4.SetDestinationAddressWithChecksumUpdate, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + for i := 0; i < 1000; i++ { + var origBytes [header.IPv4MinimumSize]byte + header.IPv4(origBytes[:]).Encode(&header.IPv4Fields{ + TOS: 1, + TotalLength: header.IPv4MinimumSize, + ID: 2, + Flags: 3, + FragmentOffset: 4, + TTL: 5, + Protocol: 6, + Checksum: 0, + SrcAddr: randomAddress(header.IPv4AddressSize), + DstAddr: randomAddress(header.IPv4AddressSize), + }) + + addr := randomAddress(header.IPv4AddressSize) + + bytesCopy := origBytes + h := header.IPv4(bytesCopy[:]) + origXSum := h.CalculateChecksum() + h.SetChecksum(^origXSum) + + test.update(h, addr) + got := ^h.Checksum() + h.SetChecksum(0) + want := h.CalculateChecksum() + if got != want { + t.Errorf("got h.Checksum() = 0x%x, want = 0x%x; originalBytes = 0x%x, new addr = %s", got, want, origBytes, addr) + } + } + }) + } +} + +func TestChecksummableTransportUpdatePort(t *testing.T) { + // The fields in the pseudo header is not tested here so we just use 0. + const pseudoHeaderXSum = 0 + + tests := []struct { + name string + transportHdr func(_, _ uint16) (header.ChecksummableTransport, func(uint16) uint16) + proto tcpip.TransportProtocolNumber + }{ + { + name: "TCP", + transportHdr: func(src, dst uint16) (header.ChecksummableTransport, func(uint16) uint16) { + h := header.TCP(make([]byte, header.TCPMinimumSize)) + h.Encode(&header.TCPFields{ + SrcPort: src, + DstPort: dst, + SeqNum: 1, + AckNum: 2, + DataOffset: header.TCPMinimumSize, + Flags: 3, + WindowSize: 4, + Checksum: 0, + UrgentPointer: 5, + }) + h.SetChecksum(^h.CalculateChecksum(pseudoHeaderXSum)) + return h, h.CalculateChecksum + }, + proto: header.TCPProtocolNumber, + }, + { + name: "UDP", + transportHdr: func(src, dst uint16) (header.ChecksummableTransport, func(uint16) uint16) { + h := header.UDP(make([]byte, header.UDPMinimumSize)) + h.Encode(&header.UDPFields{ + SrcPort: src, + DstPort: dst, + Length: 0, + Checksum: 0, + }) + h.SetChecksum(^h.CalculateChecksum(pseudoHeaderXSum)) + return h, h.CalculateChecksum + }, + proto: header.UDPProtocolNumber, + }, + } + + for i := 0; i < 1000; i++ { + origSrcPort := uint16(rand.Uint32()) + origDstPort := uint16(rand.Uint32()) + newPort := uint16(rand.Uint32()) + + t.Run(fmt.Sprintf("OrigSrcPort=%d,OrigDstPort=%d,NewPort=%d", origSrcPort, origDstPort, newPort), func(*testing.T) { + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + for _, subTest := range []struct { + name string + update func(header.ChecksummableTransport) + }{ + { + name: "Source port", + update: func(h header.ChecksummableTransport) { h.SetSourcePortWithChecksumUpdate(newPort) }, + }, + { + name: "Destination port", + update: func(h header.ChecksummableTransport) { h.SetDestinationPortWithChecksumUpdate(newPort) }, + }, + } { + t.Run(subTest.name, func(t *testing.T) { + h, calcXSum := test.transportHdr(origSrcPort, origDstPort) + subTest.update(h) + // TCP and UDP hold the 1s complement of the fully calculated + // checksum. + got := ^h.Checksum() + h.SetChecksum(0) + + if want := calcXSum(pseudoHeaderXSum); got != want { + h, _ := test.transportHdr(origSrcPort, origDstPort) + t.Errorf("got Checksum() = 0x%x, want = 0x%x; originalBytes = %#v, new port = %d", got, want, h, newPort) + } + }) + } + }) + } + }) + } +} + +func TestChecksummableTransportUpdatePseudoHeaderAddress(t *testing.T) { + const addressSize = 6 + + tests := []struct { + name string + transportHdr func() header.ChecksummableTransport + proto tcpip.TransportProtocolNumber + }{ + { + name: "TCP", + transportHdr: func() header.ChecksummableTransport { return header.TCP(make([]byte, header.TCPMinimumSize)) }, + proto: header.TCPProtocolNumber, + }, + { + name: "UDP", + transportHdr: func() header.ChecksummableTransport { return header.UDP(make([]byte, header.UDPMinimumSize)) }, + proto: header.UDPProtocolNumber, + }, + } + + for i := 0; i < 1000; i++ { + permanent := randomAddress(addressSize) + old := randomAddress(addressSize) + new := randomAddress(addressSize) + + t.Run(fmt.Sprintf("Permanent=%q,Old=%q,New=%q", permanent, old, new), func(t *testing.T) { + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + for _, fullChecksum := range []bool{true, false} { + t.Run(fmt.Sprintf("FullChecksum=%t", fullChecksum), func(t *testing.T) { + initialXSum := header.PseudoHeaderChecksum(test.proto, permanent, old, 0) + if fullChecksum { + // TCP and UDP hold the 1s complement of the fully calculated + // checksum. + initialXSum = ^initialXSum + } + + h := test.transportHdr() + h.SetChecksum(initialXSum) + h.UpdateChecksumPseudoHeaderAddress(old, new, fullChecksum) + + got := h.Checksum() + if fullChecksum { + got = ^got + } + if want := header.PseudoHeaderChecksum(test.proto, permanent, new, 0); got != want { + t.Errorf("got Checksum() = 0x%x, want = 0x%x; h = %#v", got, want, h) + } + }) + } + }) + } + }) + } +} diff --git a/pkg/tcpip/header/interfaces.go b/pkg/tcpip/header/interfaces.go index 861cbbb70..3a41adfc4 100644 --- a/pkg/tcpip/header/interfaces.go +++ b/pkg/tcpip/header/interfaces.go @@ -53,6 +53,31 @@ type Transport interface { Payload() []byte } +// ChecksummableTransport is a Transport that supports checksumming. +type ChecksummableTransport interface { + Transport + + // SetSourcePortWithChecksumUpdate sets the source port and updates + // the checksum. + // + // The receiver's checksum must be a fully calculated checksum. + SetSourcePortWithChecksumUpdate(port uint16) + + // SetDestinationPortWithChecksumUpdate sets the destination port and updates + // the checksum. + // + // The receiver's checksum must be a fully calculated checksum. + SetDestinationPortWithChecksumUpdate(port uint16) + + // UpdateChecksumPseudoHeaderAddress updates the checksum to reflect an + // updated address in the pseudo header. + // + // If fullChecksum is true, the receiver's checksum field is assumed to hold a + // fully calculated checksum. Otherwise, it is assumed to hold a partially + // calculated checksum which only reflects the pseudo header. + UpdateChecksumPseudoHeaderAddress(old, new tcpip.Address, fullChecksum bool) +} + // Network offers generic methods to query and/or update the fields of the // header of a network protocol buffer. type Network interface { @@ -90,3 +115,16 @@ type Network interface { // SetTOS sets the values of the "type of service" and "flow label" fields. SetTOS(t uint8, l uint32) } + +// ChecksummableNetwork is a Network that supports checksumming. +type ChecksummableNetwork interface { + Network + + // SetSourceAddressAndChecksum sets the source address and updates the + // checksum to reflect the new address. + SetSourceAddressWithChecksumUpdate(tcpip.Address) + + // SetDestinationAddressAndChecksum sets the destination address and + // updates the checksum to reflect the new address. + SetDestinationAddressWithChecksumUpdate(tcpip.Address) +} diff --git a/pkg/tcpip/header/ipv4.go b/pkg/tcpip/header/ipv4.go index e9abbb709..dcc549c7b 100644 --- a/pkg/tcpip/header/ipv4.go +++ b/pkg/tcpip/header/ipv4.go @@ -305,6 +305,18 @@ func (b IPv4) DestinationAddress() tcpip.Address { return tcpip.Address(b[dstAddr : dstAddr+IPv4AddressSize]) } +// SetSourceAddressWithChecksumUpdate implements ChecksummableNetwork. +func (b IPv4) SetSourceAddressWithChecksumUpdate(new tcpip.Address) { + b.SetChecksum(^checksumUpdate2ByteAlignedAddress(^b.Checksum(), b.SourceAddress(), new)) + b.SetSourceAddress(new) +} + +// SetDestinationAddressWithChecksumUpdate implements ChecksummableNetwork. +func (b IPv4) SetDestinationAddressWithChecksumUpdate(new tcpip.Address) { + b.SetChecksum(^checksumUpdate2ByteAlignedAddress(^b.Checksum(), b.DestinationAddress(), new)) + b.SetDestinationAddress(new) +} + // padIPv4OptionsLength returns the total length for IPv4 options of length l // after applying padding according to RFC 791: // The internet header padding is used to ensure that the internet diff --git a/pkg/tcpip/header/tcp.go b/pkg/tcpip/header/tcp.go index 8dabe3354..a75e51a28 100644 --- a/pkg/tcpip/header/tcp.go +++ b/pkg/tcpip/header/tcp.go @@ -390,6 +390,35 @@ func (b TCP) EncodePartial(partialChecksum, length uint16, seqnum, acknum uint32 b.SetChecksum(^checksum) } +// SetSourcePortWithChecksumUpdate implements ChecksummableTransport. +func (b TCP) SetSourcePortWithChecksumUpdate(new uint16) { + old := b.SourcePort() + b.SetSourcePort(new) + b.SetChecksum(^checksumUpdate2ByteAlignedUint16(^b.Checksum(), old, new)) +} + +// SetDestinationPortWithChecksumUpdate implements ChecksummableTransport. +func (b TCP) SetDestinationPortWithChecksumUpdate(new uint16) { + old := b.DestinationPort() + b.SetDestinationPort(new) + b.SetChecksum(^checksumUpdate2ByteAlignedUint16(^b.Checksum(), old, new)) +} + +// UpdateChecksumPseudoHeaderAddress implements ChecksummableTransport. +func (b TCP) UpdateChecksumPseudoHeaderAddress(old, new tcpip.Address, fullChecksum bool) { + xsum := b.Checksum() + if fullChecksum { + xsum = ^xsum + } + + xsum = checksumUpdate2ByteAlignedAddress(xsum, old, new) + if fullChecksum { + xsum = ^xsum + } + + b.SetChecksum(xsum) +} + // ParseSynOptions parses the options received in a SYN segment and returns the // relevant ones. opts should point to the option part of the TCP header. func ParseSynOptions(opts []byte, isAck bool) TCPSynOptions { diff --git a/pkg/tcpip/header/udp.go b/pkg/tcpip/header/udp.go index ae9d167ff..f69d53314 100644 --- a/pkg/tcpip/header/udp.go +++ b/pkg/tcpip/header/udp.go @@ -130,3 +130,32 @@ func (b UDP) Encode(u *UDPFields) { binary.BigEndian.PutUint16(b[udpLength:], u.Length) binary.BigEndian.PutUint16(b[udpChecksum:], u.Checksum) } + +// SetSourcePortWithChecksumUpdate implements ChecksummableTransport. +func (b UDP) SetSourcePortWithChecksumUpdate(new uint16) { + old := b.SourcePort() + b.SetSourcePort(new) + b.SetChecksum(^checksumUpdate2ByteAlignedUint16(^b.Checksum(), old, new)) +} + +// SetDestinationPortWithChecksumUpdate implements ChecksummableTransport. +func (b UDP) SetDestinationPortWithChecksumUpdate(new uint16) { + old := b.DestinationPort() + b.SetDestinationPort(new) + b.SetChecksum(^checksumUpdate2ByteAlignedUint16(^b.Checksum(), old, new)) +} + +// UpdateChecksumPseudoHeaderAddress implements ChecksummableTransport. +func (b UDP) UpdateChecksumPseudoHeaderAddress(old, new tcpip.Address, fullChecksum bool) { + xsum := b.Checksum() + if fullChecksum { + xsum = ^xsum + } + + xsum = checksumUpdate2ByteAlignedAddress(xsum, old, new) + if fullChecksum { + xsum = ^xsum + } + + b.SetChecksum(xsum) +} diff --git a/pkg/tcpip/stack/conntrack.go b/pkg/tcpip/stack/conntrack.go index 18e0d4374..782e74b24 100644 --- a/pkg/tcpip/stack/conntrack.go +++ b/pkg/tcpip/stack/conntrack.go @@ -405,16 +405,23 @@ func (ct *ConnTrack) handlePacket(pkt *PacketBuffer, hook Hook, r *Route) bool { // validated if checksum offloading is off. It may require IP defrag if the // packets are fragmented. + var newAddr tcpip.Address + var newPort uint16 + + updateSRCFields := false + switch hook { case Prerouting, Output: if conn.manip == manipDestination { switch dir { case dirOriginal: - tcpHeader.SetDestinationPort(conn.reply.srcPort) - netHeader.SetDestinationAddress(conn.reply.srcAddr) + newPort = conn.reply.srcPort + newAddr = conn.reply.srcAddr case dirReply: - tcpHeader.SetSourcePort(conn.original.dstPort) - netHeader.SetSourceAddress(conn.original.dstAddr) + newPort = conn.original.dstPort + newAddr = conn.original.dstAddr + + updateSRCFields = true } pkt.NatDone = true } @@ -422,11 +429,13 @@ func (ct *ConnTrack) handlePacket(pkt *PacketBuffer, hook Hook, r *Route) bool { if conn.manip == manipSource { switch dir { case dirOriginal: - tcpHeader.SetSourcePort(conn.reply.dstPort) - netHeader.SetSourceAddress(conn.reply.dstAddr) + newPort = conn.reply.dstPort + newAddr = conn.reply.dstAddr + + updateSRCFields = true case dirReply: - tcpHeader.SetDestinationPort(conn.original.srcPort) - netHeader.SetDestinationAddress(conn.original.srcAddr) + newPort = conn.original.srcPort + newAddr = conn.original.srcAddr } pkt.NatDone = true } @@ -437,29 +446,31 @@ func (ct *ConnTrack) handlePacket(pkt *PacketBuffer, hook Hook, r *Route) bool { return false } + fullChecksum := false + updatePseudoHeader := false switch hook { case Prerouting, Input: case Output, Postrouting: // Calculate the TCP checksum and set it. - tcpHeader.SetChecksum(0) - length := uint16(len(tcpHeader) + pkt.Data().Size()) - xsum := header.PseudoHeaderChecksum(header.TCPProtocolNumber, netHeader.SourceAddress(), netHeader.DestinationAddress(), length) if pkt.GSOOptions.Type != GSONone && pkt.GSOOptions.NeedsCsum { - tcpHeader.SetChecksum(xsum) + updatePseudoHeader = true } else if r.RequiresTXTransportChecksum() { - xsum = header.ChecksumCombine(xsum, pkt.Data().AsRange().Checksum()) - tcpHeader.SetChecksum(^tcpHeader.CalculateChecksum(xsum)) + fullChecksum = true + updatePseudoHeader = true } default: panic(fmt.Sprintf("unrecognized hook = %s", hook)) } - // After modification, IPv4 packets need a valid checksum. - if pkt.NetworkProtocolNumber == header.IPv4ProtocolNumber { - netHeader := header.IPv4(pkt.NetworkHeader().View()) - netHeader.SetChecksum(0) - netHeader.SetChecksum(^netHeader.CalculateChecksum()) - } + rewritePacket( + netHeader, + tcpHeader, + updateSRCFields, + fullChecksum, + updatePseudoHeader, + newPort, + newAddr, + ) // Update the state of tcb. conn.mu.Lock() diff --git a/pkg/tcpip/stack/iptables_targets.go b/pkg/tcpip/stack/iptables_targets.go index 91e266de8..96cc899bb 100644 --- a/pkg/tcpip/stack/iptables_targets.go +++ b/pkg/tcpip/stack/iptables_targets.go @@ -133,29 +133,23 @@ func (rt *RedirectTarget) Action(pkt *PacketBuffer, ct *ConnTrack, hook Hook, r switch protocol := pkt.TransportProtocolNumber; protocol { case header.UDPProtocolNumber: udpHeader := header.UDP(pkt.TransportHeader().View()) - udpHeader.SetDestinationPort(rt.Port) - // Calculate UDP checksum and set it. if hook == Output { - udpHeader.SetChecksum(0) - netHeader := pkt.Network() - netHeader.SetDestinationAddress(address) - // Only calculate the checksum if offloading isn't supported. - if r.RequiresTXTransportChecksum() { - length := uint16(pkt.Size()) - uint16(len(pkt.NetworkHeader().View())) - xsum := header.PseudoHeaderChecksum(protocol, netHeader.SourceAddress(), netHeader.DestinationAddress(), length) - xsum = header.ChecksumCombine(xsum, pkt.Data().AsRange().Checksum()) - udpHeader.SetChecksum(^udpHeader.CalculateChecksum(xsum)) - } + requiresChecksum := r.RequiresTXTransportChecksum() + rewritePacket( + pkt.Network(), + udpHeader, + false, /* updateSRCFields */ + requiresChecksum, + requiresChecksum, + rt.Port, + address, + ) + } else { + udpHeader.SetDestinationPort(rt.Port) } - // After modification, IPv4 packets need a valid checksum. - if pkt.NetworkProtocolNumber == header.IPv4ProtocolNumber { - netHeader := header.IPv4(pkt.NetworkHeader().View()) - netHeader.SetChecksum(0) - netHeader.SetChecksum(^netHeader.CalculateChecksum()) - } pkt.NatDone = true case header.TCPProtocolNumber: if ct == nil { @@ -214,26 +208,18 @@ func (st *SNATTarget) Action(pkt *PacketBuffer, ct *ConnTrack, hook Hook, r *Rou switch protocol := pkt.TransportProtocolNumber; protocol { case header.UDPProtocolNumber: - udpHeader := header.UDP(pkt.TransportHeader().View()) - udpHeader.SetChecksum(0) - udpHeader.SetSourcePort(st.Port) - netHeader := pkt.Network() - netHeader.SetSourceAddress(st.Addr) - // Only calculate the checksum if offloading isn't supported. - if r.RequiresTXTransportChecksum() { - length := uint16(pkt.Size()) - uint16(len(pkt.NetworkHeader().View())) - xsum := header.PseudoHeaderChecksum(protocol, netHeader.SourceAddress(), netHeader.DestinationAddress(), length) - xsum = header.ChecksumCombine(xsum, pkt.Data().AsRange().Checksum()) - udpHeader.SetChecksum(^udpHeader.CalculateChecksum(xsum)) - } + requiresChecksum := r.RequiresTXTransportChecksum() + rewritePacket( + pkt.Network(), + header.UDP(pkt.TransportHeader().View()), + true, /* updateSRCFields */ + requiresChecksum, + requiresChecksum, + st.Port, + st.Addr, + ) - // After modification, IPv4 packets need a valid checksum. - if pkt.NetworkProtocolNumber == header.IPv4ProtocolNumber { - netHeader := header.IPv4(pkt.NetworkHeader().View()) - netHeader.SetChecksum(0) - netHeader.SetChecksum(^netHeader.CalculateChecksum()) - } pkt.NatDone = true case header.TCPProtocolNumber: if ct == nil { @@ -252,3 +238,42 @@ func (st *SNATTarget) Action(pkt *PacketBuffer, ct *ConnTrack, hook Hook, r *Rou return RuleAccept, 0 } + +func rewritePacket(n header.Network, t header.ChecksummableTransport, updateSRCFields, fullChecksum, updatePseudoHeader bool, newPort uint16, newAddr tcpip.Address) { + if updateSRCFields { + if fullChecksum { + t.SetSourcePortWithChecksumUpdate(newPort) + } else { + t.SetSourcePort(newPort) + } + } else { + if fullChecksum { + t.SetDestinationPortWithChecksumUpdate(newPort) + } else { + t.SetDestinationPort(newPort) + } + } + + if updatePseudoHeader { + var oldAddr tcpip.Address + if updateSRCFields { + oldAddr = n.SourceAddress() + } else { + oldAddr = n.DestinationAddress() + } + + t.UpdateChecksumPseudoHeaderAddress(oldAddr, newAddr, fullChecksum) + } + + if checksummableNetHeader, ok := n.(header.ChecksummableNetwork); ok { + if updateSRCFields { + checksummableNetHeader.SetSourceAddressWithChecksumUpdate(newAddr) + } else { + checksummableNetHeader.SetDestinationAddressWithChecksumUpdate(newAddr) + } + } else if updateSRCFields { + n.SetSourceAddress(newAddr) + } else { + n.SetDestinationAddress(newAddr) + } +} |