summaryrefslogtreecommitdiffhomepage
path: root/pkg/tcpip
diff options
context:
space:
mode:
Diffstat (limited to 'pkg/tcpip')
-rw-r--r--pkg/tcpip/header/checksum.go62
-rw-r--r--pkg/tcpip/header/interfaces.go38
-rw-r--r--pkg/tcpip/header/ipv4.go12
-rw-r--r--pkg/tcpip/header/tcp.go29
-rw-r--r--pkg/tcpip/header/udp.go29
-rw-r--r--pkg/tcpip/network/ipv6/ipv6.go5
-rw-r--r--pkg/tcpip/network/ipv6/ndp.go145
-rw-r--r--pkg/tcpip/stack/conntrack.go51
-rw-r--r--pkg/tcpip/stack/iptables_targets.go97
9 files changed, 338 insertions, 130 deletions
diff --git a/pkg/tcpip/header/checksum.go b/pkg/tcpip/header/checksum.go
index 6aa9acfa8..e2c85e220 100644
--- a/pkg/tcpip/header/checksum.go
+++ b/pkg/tcpip/header/checksum.go
@@ -18,6 +18,7 @@ package header
import (
"encoding/binary"
+ "fmt"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
@@ -234,3 +235,64 @@ func PseudoHeaderChecksum(protocol tcpip.TransportProtocolNumber, srcAddr tcpip.
return Checksum([]byte{0, uint8(protocol)}, xsum)
}
+
+// checksumUpdate2ByteAlignedUint16 updates a uint16 value in a calculated
+// checksum.
+//
+// The value MUST begin at a 2-byte boundary in the original buffer.
+func checksumUpdate2ByteAlignedUint16(xsum, old, new uint16) uint16 {
+ // As per RFC 1071 page 4,
+ // (4) Incremental Update
+ //
+ // ...
+ //
+ // To update the checksum, simply add the differences of the
+ // sixteen bit integers that have been changed. To see why this
+ // works, observe that every 16-bit integer has an additive inverse
+ // and that addition is associative. From this it follows that
+ // given the original value m, the new value m', and the old
+ // checksum C, the new checksum C' is:
+ //
+ // C' = C + (-m) + m' = C + (m' - m)
+ return ChecksumCombine(xsum, ChecksumCombine(new, ^old))
+}
+
+// checksumUpdate2ByteAlignedAddress updates an address in a calculated
+// checksum.
+//
+// The addresses must have the same length and must contain an even number
+// of bytes. The address MUST begin at a 2-byte boundary in the original buffer.
+func checksumUpdate2ByteAlignedAddress(xsum uint16, old, new tcpip.Address) uint16 {
+ const uint16Bytes = 2
+
+ if len(old) != len(new) {
+ panic(fmt.Sprintf("buffer lengths are different; old = %d, new = %d", len(old), len(new)))
+ }
+
+ if len(old)%uint16Bytes != 0 {
+ panic(fmt.Sprintf("buffer has an odd number of bytes; got = %d", len(old)))
+ }
+
+ // As per RFC 1071 page 4,
+ // (4) Incremental Update
+ //
+ // ...
+ //
+ // To update the checksum, simply add the differences of the
+ // sixteen bit integers that have been changed. To see why this
+ // works, observe that every 16-bit integer has an additive inverse
+ // and that addition is associative. From this it follows that
+ // given the original value m, the new value m', and the old
+ // checksum C, the new checksum C' is:
+ //
+ // C' = C + (-m) + m' = C + (m' - m)
+ for len(old) != 0 {
+ // Convert the 2 byte sequences to uint16 values then apply the increment
+ // update.
+ xsum = checksumUpdate2ByteAlignedUint16(xsum, (uint16(old[0])<<8)+uint16(old[1]), (uint16(new[0])<<8)+uint16(new[1]))
+ old = old[uint16Bytes:]
+ new = new[uint16Bytes:]
+ }
+
+ return xsum
+}
diff --git a/pkg/tcpip/header/interfaces.go b/pkg/tcpip/header/interfaces.go
index 861cbbb70..3a41adfc4 100644
--- a/pkg/tcpip/header/interfaces.go
+++ b/pkg/tcpip/header/interfaces.go
@@ -53,6 +53,31 @@ type Transport interface {
Payload() []byte
}
+// ChecksummableTransport is a Transport that supports checksumming.
+type ChecksummableTransport interface {
+ Transport
+
+ // SetSourcePortWithChecksumUpdate sets the source port and updates
+ // the checksum.
+ //
+ // The receiver's checksum must be a fully calculated checksum.
+ SetSourcePortWithChecksumUpdate(port uint16)
+
+ // SetDestinationPortWithChecksumUpdate sets the destination port and updates
+ // the checksum.
+ //
+ // The receiver's checksum must be a fully calculated checksum.
+ SetDestinationPortWithChecksumUpdate(port uint16)
+
+ // UpdateChecksumPseudoHeaderAddress updates the checksum to reflect an
+ // updated address in the pseudo header.
+ //
+ // If fullChecksum is true, the receiver's checksum field is assumed to hold a
+ // fully calculated checksum. Otherwise, it is assumed to hold a partially
+ // calculated checksum which only reflects the pseudo header.
+ UpdateChecksumPseudoHeaderAddress(old, new tcpip.Address, fullChecksum bool)
+}
+
// Network offers generic methods to query and/or update the fields of the
// header of a network protocol buffer.
type Network interface {
@@ -90,3 +115,16 @@ type Network interface {
// SetTOS sets the values of the "type of service" and "flow label" fields.
SetTOS(t uint8, l uint32)
}
+
+// ChecksummableNetwork is a Network that supports checksumming.
+type ChecksummableNetwork interface {
+ Network
+
+ // SetSourceAddressAndChecksum sets the source address and updates the
+ // checksum to reflect the new address.
+ SetSourceAddressWithChecksumUpdate(tcpip.Address)
+
+ // SetDestinationAddressAndChecksum sets the destination address and
+ // updates the checksum to reflect the new address.
+ SetDestinationAddressWithChecksumUpdate(tcpip.Address)
+}
diff --git a/pkg/tcpip/header/ipv4.go b/pkg/tcpip/header/ipv4.go
index e9abbb709..dcc549c7b 100644
--- a/pkg/tcpip/header/ipv4.go
+++ b/pkg/tcpip/header/ipv4.go
@@ -305,6 +305,18 @@ func (b IPv4) DestinationAddress() tcpip.Address {
return tcpip.Address(b[dstAddr : dstAddr+IPv4AddressSize])
}
+// SetSourceAddressWithChecksumUpdate implements ChecksummableNetwork.
+func (b IPv4) SetSourceAddressWithChecksumUpdate(new tcpip.Address) {
+ b.SetChecksum(^checksumUpdate2ByteAlignedAddress(^b.Checksum(), b.SourceAddress(), new))
+ b.SetSourceAddress(new)
+}
+
+// SetDestinationAddressWithChecksumUpdate implements ChecksummableNetwork.
+func (b IPv4) SetDestinationAddressWithChecksumUpdate(new tcpip.Address) {
+ b.SetChecksum(^checksumUpdate2ByteAlignedAddress(^b.Checksum(), b.DestinationAddress(), new))
+ b.SetDestinationAddress(new)
+}
+
// padIPv4OptionsLength returns the total length for IPv4 options of length l
// after applying padding according to RFC 791:
// The internet header padding is used to ensure that the internet
diff --git a/pkg/tcpip/header/tcp.go b/pkg/tcpip/header/tcp.go
index 8dabe3354..a75e51a28 100644
--- a/pkg/tcpip/header/tcp.go
+++ b/pkg/tcpip/header/tcp.go
@@ -390,6 +390,35 @@ func (b TCP) EncodePartial(partialChecksum, length uint16, seqnum, acknum uint32
b.SetChecksum(^checksum)
}
+// SetSourcePortWithChecksumUpdate implements ChecksummableTransport.
+func (b TCP) SetSourcePortWithChecksumUpdate(new uint16) {
+ old := b.SourcePort()
+ b.SetSourcePort(new)
+ b.SetChecksum(^checksumUpdate2ByteAlignedUint16(^b.Checksum(), old, new))
+}
+
+// SetDestinationPortWithChecksumUpdate implements ChecksummableTransport.
+func (b TCP) SetDestinationPortWithChecksumUpdate(new uint16) {
+ old := b.DestinationPort()
+ b.SetDestinationPort(new)
+ b.SetChecksum(^checksumUpdate2ByteAlignedUint16(^b.Checksum(), old, new))
+}
+
+// UpdateChecksumPseudoHeaderAddress implements ChecksummableTransport.
+func (b TCP) UpdateChecksumPseudoHeaderAddress(old, new tcpip.Address, fullChecksum bool) {
+ xsum := b.Checksum()
+ if fullChecksum {
+ xsum = ^xsum
+ }
+
+ xsum = checksumUpdate2ByteAlignedAddress(xsum, old, new)
+ if fullChecksum {
+ xsum = ^xsum
+ }
+
+ b.SetChecksum(xsum)
+}
+
// ParseSynOptions parses the options received in a SYN segment and returns the
// relevant ones. opts should point to the option part of the TCP header.
func ParseSynOptions(opts []byte, isAck bool) TCPSynOptions {
diff --git a/pkg/tcpip/header/udp.go b/pkg/tcpip/header/udp.go
index ae9d167ff..f69d53314 100644
--- a/pkg/tcpip/header/udp.go
+++ b/pkg/tcpip/header/udp.go
@@ -130,3 +130,32 @@ func (b UDP) Encode(u *UDPFields) {
binary.BigEndian.PutUint16(b[udpLength:], u.Length)
binary.BigEndian.PutUint16(b[udpChecksum:], u.Checksum)
}
+
+// SetSourcePortWithChecksumUpdate implements ChecksummableTransport.
+func (b UDP) SetSourcePortWithChecksumUpdate(new uint16) {
+ old := b.SourcePort()
+ b.SetSourcePort(new)
+ b.SetChecksum(^checksumUpdate2ByteAlignedUint16(^b.Checksum(), old, new))
+}
+
+// SetDestinationPortWithChecksumUpdate implements ChecksummableTransport.
+func (b UDP) SetDestinationPortWithChecksumUpdate(new uint16) {
+ old := b.DestinationPort()
+ b.SetDestinationPort(new)
+ b.SetChecksum(^checksumUpdate2ByteAlignedUint16(^b.Checksum(), old, new))
+}
+
+// UpdateChecksumPseudoHeaderAddress implements ChecksummableTransport.
+func (b UDP) UpdateChecksumPseudoHeaderAddress(old, new tcpip.Address, fullChecksum bool) {
+ xsum := b.Checksum()
+ if fullChecksum {
+ xsum = ^xsum
+ }
+
+ xsum = checksumUpdate2ByteAlignedAddress(xsum, old, new)
+ if fullChecksum {
+ xsum = ^xsum
+ }
+
+ b.SetChecksum(xsum)
+}
diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go
index f5693defe..b1aec5312 100644
--- a/pkg/tcpip/network/ipv6/ipv6.go
+++ b/pkg/tcpip/network/ipv6/ipv6.go
@@ -344,7 +344,10 @@ func (e *endpoint) onAddressAssignedLocked(addr tcpip.Address) {
func (e *endpoint) InvalidateDefaultRouter(rtr tcpip.Address) {
e.mu.Lock()
defer e.mu.Unlock()
- e.mu.ndp.invalidateDefaultRouter(rtr)
+
+ // We represent default routers with a default (off-link) route through the
+ // router.
+ e.mu.ndp.invalidateOffLinkRoute(offLinkRoute{dest: header.IPv6EmptySubnet, router: rtr})
}
// SetNDPConfigurations implements NDPEndpoint.
diff --git a/pkg/tcpip/network/ipv6/ndp.go b/pkg/tcpip/network/ipv6/ndp.go
index c44e4ac4e..9cd283eba 100644
--- a/pkg/tcpip/network/ipv6/ndp.go
+++ b/pkg/tcpip/network/ipv6/ndp.go
@@ -78,13 +78,13 @@ const (
// we cannot have a negative delay.
minimumMaxRtrSolicitationDelay = 0
- // MaxDiscoveredDefaultRouters is the maximum number of discovered
- // default routers. The stack should stop discovering new routers after
- // discovering MaxDiscoveredDefaultRouters routers.
+ // MaxDiscoveredOffLinkRoutes is the maximum number of discovered off-link
+ // routes. The stack should stop discovering new off-link routes after
+ // this limit is reached.
//
// This value MUST be at minimum 2 as per RFC 4861 section 6.3.4, and
// SHOULD be more.
- MaxDiscoveredDefaultRouters = 10
+ MaxDiscoveredOffLinkRoutes = 10
// MaxDiscoveredOnLinkPrefixes is the maximum number of discovered
// on-link prefixes. The stack should stop discovering new on-link
@@ -448,6 +448,11 @@ type timer struct {
timer tcpip.Timer
}
+type offLinkRoute struct {
+ dest tcpip.Subnet
+ router tcpip.Address
+}
+
// ndpState is the per-Interface NDP state.
type ndpState struct {
// Do not allow overwriting this state.
@@ -462,8 +467,8 @@ type ndpState struct {
// The DAD timers to send the next NS message, or resolve the address.
dad ip.DAD
- // The default routers discovered through Router Advertisements.
- defaultRouters map[tcpip.Address]defaultRouterState
+ // The off-link routes discovered through Router Advertisements.
+ offLinkRoutes map[offLinkRoute]offLinkRouteState
// rtrSolicitTimer is the timer used to send the next router solicitation
// message.
@@ -491,12 +496,12 @@ type ndpState struct {
temporaryAddressDesyncFactor time.Duration
}
-// defaultRouterState holds data associated with a default router discovered by
+// offLinkRouteState holds data associated with an off-link route discovered by
// a Router Advertisement (RA).
-type defaultRouterState struct {
+type offLinkRouteState struct {
prf header.NDPRoutePreference
- // Job to invalidate the default router.
+ // Job to invalidate the route.
//
// Must not be nil.
invalidationJob *tcpip.Job
@@ -727,38 +732,9 @@ func (ndp *ndpState) handleRA(ip tcpip.Address, ra header.NDPRouterAdvert) {
prf = header.MediumRoutePreference
}
- rtr, ok := ndp.defaultRouters[ip]
- rl := ra.RouterLifetime()
- switch {
- case !ok && rl != 0:
- // This is a new default router we are discovering.
- //
- // Only remember it if we currently know about less than
- // MaxDiscoveredDefaultRouters routers.
- if len(ndp.defaultRouters) < MaxDiscoveredDefaultRouters {
- ndp.rememberDefaultRouter(ip, rl, prf)
- }
-
- case ok && rl != 0:
- // This is an already discovered default router. Update
- // the invalidation job.
- rtr.invalidationJob.Cancel()
- rtr.invalidationJob.Schedule(rl)
-
- if prf != rtr.prf {
- rtr.prf = prf
-
- // Inform the integrator about router preference updates.
- ndp.ep.protocol.options.NDPDisp.OnOffLinkRouteUpdated(ndp.ep.nic.ID(), header.IPv6EmptySubnet, ip, prf)
- }
-
- ndp.defaultRouters[ip] = rtr
-
- case ok && rl == 0:
- // We know about the router but it is no longer to be
- // used as a default router so invalidate it.
- ndp.invalidateDefaultRouter(ip)
- }
+ // We represent default routers with a default (off-link) route through the
+ // router.
+ ndp.handleOffLinkRouteDiscovery(offLinkRoute{dest: header.IPv6EmptySubnet, router: ip}, ra.RouterLifetime(), prf)
}
// TODO(b/141556115): Do (RetransTimer, ReachableTime)) Parameter
@@ -816,52 +792,75 @@ func (ndp *ndpState) handleRA(ip tcpip.Address, ra header.NDPRouterAdvert) {
}
}
-// invalidateDefaultRouter invalidates a discovered default router.
+// invalidateOffLinkRoute invalidates a discovered off-link route.
//
// The IPv6 endpoint that ndp belongs to MUST be locked.
-func (ndp *ndpState) invalidateDefaultRouter(ip tcpip.Address) {
- rtr, ok := ndp.defaultRouters[ip]
-
- // Is the router still discovered?
+func (ndp *ndpState) invalidateOffLinkRoute(route offLinkRoute) {
+ state, ok := ndp.offLinkRoutes[route]
if !ok {
- // ...Nope, do nothing further.
return
}
- rtr.invalidationJob.Cancel()
- delete(ndp.defaultRouters, ip)
+ state.invalidationJob.Cancel()
+ delete(ndp.offLinkRoutes, route)
- // Let the integrator know a discovered default router is invalidated.
+ // Let the integrator know a discovered off-link route is invalidated.
if ndpDisp := ndp.ep.protocol.options.NDPDisp; ndpDisp != nil {
- ndpDisp.OnOffLinkRouteInvalidated(ndp.ep.nic.ID(), header.IPv6EmptySubnet, ip)
+ ndpDisp.OnOffLinkRouteInvalidated(ndp.ep.nic.ID(), route.dest, route.router)
}
}
-// rememberDefaultRouter remembers a newly discovered default router with IPv6
-// link-local address ip with lifetime rl.
-//
-// The router identified by ip MUST NOT already be known by the IPv6 endpoint.
+// handleOffLinkRouteDiscovery handles the discovery of an off-link route.
//
-// The IPv6 endpoint that ndp belongs to MUST be locked.
-func (ndp *ndpState) rememberDefaultRouter(ip tcpip.Address, rl time.Duration, prf header.NDPRoutePreference) {
+// Precondition: ndp.ep.mu must be locked.
+func (ndp *ndpState) handleOffLinkRouteDiscovery(route offLinkRoute, lifetime time.Duration, prf header.NDPRoutePreference) {
ndpDisp := ndp.ep.protocol.options.NDPDisp
if ndpDisp == nil {
return
}
- // Inform the integrator when we discovered a default router.
- ndpDisp.OnOffLinkRouteUpdated(ndp.ep.nic.ID(), header.IPv6EmptySubnet, ip, prf)
+ state, ok := ndp.offLinkRoutes[route]
+ switch {
+ case !ok && lifetime != 0:
+ // This is a new route we are discovering.
+ //
+ // Only remember it if we currently know about less than
+ // MaxDiscoveredOffLinkRoutes routers.
+ if len(ndp.offLinkRoutes) < MaxDiscoveredOffLinkRoutes {
+ // Inform the integrator when we discovered an off-link route.
+ ndpDisp.OnOffLinkRouteUpdated(ndp.ep.nic.ID(), route.dest, route.router, prf)
+
+ state := offLinkRouteState{
+ prf: prf,
+ invalidationJob: ndp.ep.protocol.stack.NewJob(&ndp.ep.mu, func() {
+ ndp.invalidateOffLinkRoute(route)
+ }),
+ }
- state := defaultRouterState{
- prf: prf,
- invalidationJob: ndp.ep.protocol.stack.NewJob(&ndp.ep.mu, func() {
- ndp.invalidateDefaultRouter(ip)
- }),
- }
+ state.invalidationJob.Schedule(lifetime)
- state.invalidationJob.Schedule(rl)
+ ndp.offLinkRoutes[route] = state
+ }
+
+ case ok && lifetime != 0:
+ // This is an already discovered off-link route. Update the lifetime.
+ state.invalidationJob.Cancel()
+ state.invalidationJob.Schedule(lifetime)
- ndp.defaultRouters[ip] = state
+ if prf != state.prf {
+ state.prf = prf
+
+ // Inform the integrator about route preference updates.
+ ndpDisp.OnOffLinkRouteUpdated(ndp.ep.nic.ID(), route.dest, route.router, prf)
+ }
+
+ ndp.offLinkRoutes[route] = state
+
+ case ok && lifetime == 0:
+ // The already discovered off-link route is no longer considered valid so we
+ // invalidate it immediately.
+ ndp.invalidateOffLinkRoute(route)
+ }
}
// rememberOnLinkPrefix remembers a newly discovered on-link prefix with IPv6
@@ -1677,12 +1676,12 @@ func (ndp *ndpState) cleanupState() {
panic(fmt.Sprintf("ndp: still have discovered on-link prefixes after cleaning up; found = %d", got))
}
- for router := range ndp.defaultRouters {
- ndp.invalidateDefaultRouter(router)
+ for route := range ndp.offLinkRoutes {
+ ndp.invalidateOffLinkRoute(route)
}
- if got := len(ndp.defaultRouters); got != 0 {
- panic(fmt.Sprintf("ndp: still have discovered default routers after cleaning up; found = %d", got))
+ if got := len(ndp.offLinkRoutes); got != 0 {
+ panic(fmt.Sprintf("ndp: still have discovered off-link routes after cleaning up; found = %d", got))
}
ndp.dhcpv6Configuration = 0
@@ -1845,14 +1844,14 @@ func (ndp *ndpState) stopSolicitingRouters() {
}
func (ndp *ndpState) init(ep *endpoint, dadOptions ip.DADOptions) {
- if ndp.defaultRouters != nil {
+ if ndp.offLinkRoutes != nil {
panic("attempted to initialize NDP state twice")
}
ndp.ep = ep
ndp.configs = ep.protocol.options.NDPConfigs
ndp.dad.Init(&ndp.ep.mu, ep.protocol.options.DADConfigs, dadOptions)
- ndp.defaultRouters = make(map[tcpip.Address]defaultRouterState)
+ ndp.offLinkRoutes = make(map[offLinkRoute]offLinkRouteState)
ndp.onLinkPrefixes = make(map[tcpip.Subnet]onLinkPrefixState)
ndp.slaacPrefixes = make(map[tcpip.Subnet]slaacPrefixState)
diff --git a/pkg/tcpip/stack/conntrack.go b/pkg/tcpip/stack/conntrack.go
index 18e0d4374..782e74b24 100644
--- a/pkg/tcpip/stack/conntrack.go
+++ b/pkg/tcpip/stack/conntrack.go
@@ -405,16 +405,23 @@ func (ct *ConnTrack) handlePacket(pkt *PacketBuffer, hook Hook, r *Route) bool {
// validated if checksum offloading is off. It may require IP defrag if the
// packets are fragmented.
+ var newAddr tcpip.Address
+ var newPort uint16
+
+ updateSRCFields := false
+
switch hook {
case Prerouting, Output:
if conn.manip == manipDestination {
switch dir {
case dirOriginal:
- tcpHeader.SetDestinationPort(conn.reply.srcPort)
- netHeader.SetDestinationAddress(conn.reply.srcAddr)
+ newPort = conn.reply.srcPort
+ newAddr = conn.reply.srcAddr
case dirReply:
- tcpHeader.SetSourcePort(conn.original.dstPort)
- netHeader.SetSourceAddress(conn.original.dstAddr)
+ newPort = conn.original.dstPort
+ newAddr = conn.original.dstAddr
+
+ updateSRCFields = true
}
pkt.NatDone = true
}
@@ -422,11 +429,13 @@ func (ct *ConnTrack) handlePacket(pkt *PacketBuffer, hook Hook, r *Route) bool {
if conn.manip == manipSource {
switch dir {
case dirOriginal:
- tcpHeader.SetSourcePort(conn.reply.dstPort)
- netHeader.SetSourceAddress(conn.reply.dstAddr)
+ newPort = conn.reply.dstPort
+ newAddr = conn.reply.dstAddr
+
+ updateSRCFields = true
case dirReply:
- tcpHeader.SetDestinationPort(conn.original.srcPort)
- netHeader.SetDestinationAddress(conn.original.srcAddr)
+ newPort = conn.original.srcPort
+ newAddr = conn.original.srcAddr
}
pkt.NatDone = true
}
@@ -437,29 +446,31 @@ func (ct *ConnTrack) handlePacket(pkt *PacketBuffer, hook Hook, r *Route) bool {
return false
}
+ fullChecksum := false
+ updatePseudoHeader := false
switch hook {
case Prerouting, Input:
case Output, Postrouting:
// Calculate the TCP checksum and set it.
- tcpHeader.SetChecksum(0)
- length := uint16(len(tcpHeader) + pkt.Data().Size())
- xsum := header.PseudoHeaderChecksum(header.TCPProtocolNumber, netHeader.SourceAddress(), netHeader.DestinationAddress(), length)
if pkt.GSOOptions.Type != GSONone && pkt.GSOOptions.NeedsCsum {
- tcpHeader.SetChecksum(xsum)
+ updatePseudoHeader = true
} else if r.RequiresTXTransportChecksum() {
- xsum = header.ChecksumCombine(xsum, pkt.Data().AsRange().Checksum())
- tcpHeader.SetChecksum(^tcpHeader.CalculateChecksum(xsum))
+ fullChecksum = true
+ updatePseudoHeader = true
}
default:
panic(fmt.Sprintf("unrecognized hook = %s", hook))
}
- // After modification, IPv4 packets need a valid checksum.
- if pkt.NetworkProtocolNumber == header.IPv4ProtocolNumber {
- netHeader := header.IPv4(pkt.NetworkHeader().View())
- netHeader.SetChecksum(0)
- netHeader.SetChecksum(^netHeader.CalculateChecksum())
- }
+ rewritePacket(
+ netHeader,
+ tcpHeader,
+ updateSRCFields,
+ fullChecksum,
+ updatePseudoHeader,
+ newPort,
+ newAddr,
+ )
// Update the state of tcb.
conn.mu.Lock()
diff --git a/pkg/tcpip/stack/iptables_targets.go b/pkg/tcpip/stack/iptables_targets.go
index 91e266de8..96cc899bb 100644
--- a/pkg/tcpip/stack/iptables_targets.go
+++ b/pkg/tcpip/stack/iptables_targets.go
@@ -133,29 +133,23 @@ func (rt *RedirectTarget) Action(pkt *PacketBuffer, ct *ConnTrack, hook Hook, r
switch protocol := pkt.TransportProtocolNumber; protocol {
case header.UDPProtocolNumber:
udpHeader := header.UDP(pkt.TransportHeader().View())
- udpHeader.SetDestinationPort(rt.Port)
- // Calculate UDP checksum and set it.
if hook == Output {
- udpHeader.SetChecksum(0)
- netHeader := pkt.Network()
- netHeader.SetDestinationAddress(address)
-
// Only calculate the checksum if offloading isn't supported.
- if r.RequiresTXTransportChecksum() {
- length := uint16(pkt.Size()) - uint16(len(pkt.NetworkHeader().View()))
- xsum := header.PseudoHeaderChecksum(protocol, netHeader.SourceAddress(), netHeader.DestinationAddress(), length)
- xsum = header.ChecksumCombine(xsum, pkt.Data().AsRange().Checksum())
- udpHeader.SetChecksum(^udpHeader.CalculateChecksum(xsum))
- }
+ requiresChecksum := r.RequiresTXTransportChecksum()
+ rewritePacket(
+ pkt.Network(),
+ udpHeader,
+ false, /* updateSRCFields */
+ requiresChecksum,
+ requiresChecksum,
+ rt.Port,
+ address,
+ )
+ } else {
+ udpHeader.SetDestinationPort(rt.Port)
}
- // After modification, IPv4 packets need a valid checksum.
- if pkt.NetworkProtocolNumber == header.IPv4ProtocolNumber {
- netHeader := header.IPv4(pkt.NetworkHeader().View())
- netHeader.SetChecksum(0)
- netHeader.SetChecksum(^netHeader.CalculateChecksum())
- }
pkt.NatDone = true
case header.TCPProtocolNumber:
if ct == nil {
@@ -214,26 +208,18 @@ func (st *SNATTarget) Action(pkt *PacketBuffer, ct *ConnTrack, hook Hook, r *Rou
switch protocol := pkt.TransportProtocolNumber; protocol {
case header.UDPProtocolNumber:
- udpHeader := header.UDP(pkt.TransportHeader().View())
- udpHeader.SetChecksum(0)
- udpHeader.SetSourcePort(st.Port)
- netHeader := pkt.Network()
- netHeader.SetSourceAddress(st.Addr)
-
// Only calculate the checksum if offloading isn't supported.
- if r.RequiresTXTransportChecksum() {
- length := uint16(pkt.Size()) - uint16(len(pkt.NetworkHeader().View()))
- xsum := header.PseudoHeaderChecksum(protocol, netHeader.SourceAddress(), netHeader.DestinationAddress(), length)
- xsum = header.ChecksumCombine(xsum, pkt.Data().AsRange().Checksum())
- udpHeader.SetChecksum(^udpHeader.CalculateChecksum(xsum))
- }
+ requiresChecksum := r.RequiresTXTransportChecksum()
+ rewritePacket(
+ pkt.Network(),
+ header.UDP(pkt.TransportHeader().View()),
+ true, /* updateSRCFields */
+ requiresChecksum,
+ requiresChecksum,
+ st.Port,
+ st.Addr,
+ )
- // After modification, IPv4 packets need a valid checksum.
- if pkt.NetworkProtocolNumber == header.IPv4ProtocolNumber {
- netHeader := header.IPv4(pkt.NetworkHeader().View())
- netHeader.SetChecksum(0)
- netHeader.SetChecksum(^netHeader.CalculateChecksum())
- }
pkt.NatDone = true
case header.TCPProtocolNumber:
if ct == nil {
@@ -252,3 +238,42 @@ func (st *SNATTarget) Action(pkt *PacketBuffer, ct *ConnTrack, hook Hook, r *Rou
return RuleAccept, 0
}
+
+func rewritePacket(n header.Network, t header.ChecksummableTransport, updateSRCFields, fullChecksum, updatePseudoHeader bool, newPort uint16, newAddr tcpip.Address) {
+ if updateSRCFields {
+ if fullChecksum {
+ t.SetSourcePortWithChecksumUpdate(newPort)
+ } else {
+ t.SetSourcePort(newPort)
+ }
+ } else {
+ if fullChecksum {
+ t.SetDestinationPortWithChecksumUpdate(newPort)
+ } else {
+ t.SetDestinationPort(newPort)
+ }
+ }
+
+ if updatePseudoHeader {
+ var oldAddr tcpip.Address
+ if updateSRCFields {
+ oldAddr = n.SourceAddress()
+ } else {
+ oldAddr = n.DestinationAddress()
+ }
+
+ t.UpdateChecksumPseudoHeaderAddress(oldAddr, newAddr, fullChecksum)
+ }
+
+ if checksummableNetHeader, ok := n.(header.ChecksummableNetwork); ok {
+ if updateSRCFields {
+ checksummableNetHeader.SetSourceAddressWithChecksumUpdate(newAddr)
+ } else {
+ checksummableNetHeader.SetDestinationAddressWithChecksumUpdate(newAddr)
+ }
+ } else if updateSRCFields {
+ n.SetSourceAddress(newAddr)
+ } else {
+ n.SetDestinationAddress(newAddr)
+ }
+}