summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--pkg/tcpip/header/checksum.go17
-rw-r--r--pkg/tcpip/header/tcp.go15
-rw-r--r--pkg/tcpip/header/udp.go14
-rw-r--r--pkg/tcpip/stack/route.go4
-rw-r--r--pkg/tcpip/transport/tcp/connect.go4
-rw-r--r--pkg/tcpip/transport/tcp/testing/context/context.go6
-rw-r--r--pkg/tcpip/transport/udp/endpoint.go4
-rw-r--r--pkg/tcpip/transport/udp/udp_test.go6
8 files changed, 30 insertions, 40 deletions
diff --git a/pkg/tcpip/header/checksum.go b/pkg/tcpip/header/checksum.go
index 2c2fcbd9b..2e8c65fac 100644
--- a/pkg/tcpip/header/checksum.go
+++ b/pkg/tcpip/header/checksum.go
@@ -17,6 +17,8 @@
package header
import (
+ "encoding/binary"
+
"gvisor.googlesource.com/gvisor/pkg/tcpip"
"gvisor.googlesource.com/gvisor/pkg/tcpip/buffer"
)
@@ -76,12 +78,17 @@ func ChecksumCombine(a, b uint16) uint16 {
return uint16(v + v>>16)
}
-// PseudoHeaderChecksum calculates the pseudo-header checksum for the
-// given destination protocol and network address, ignoring the length
-// field. Pseudo-headers are needed by transport layers when calculating
-// their own checksum.
-func PseudoHeaderChecksum(protocol tcpip.TransportProtocolNumber, srcAddr tcpip.Address, dstAddr tcpip.Address) uint16 {
+// PseudoHeaderChecksum calculates the pseudo-header checksum for the given
+// destination protocol and network address. Pseudo-headers are needed by
+// transport layers when calculating their own checksum.
+func PseudoHeaderChecksum(protocol tcpip.TransportProtocolNumber, srcAddr tcpip.Address, dstAddr tcpip.Address, totalLen uint16) uint16 {
xsum := Checksum([]byte(srcAddr), 0)
xsum = Checksum([]byte(dstAddr), xsum)
+
+ // Add the length portion of the checksum to the pseudo-checksum.
+ tmp := make([]byte, 2)
+ binary.BigEndian.PutUint16(tmp, totalLen)
+ xsum = Checksum(tmp, xsum)
+
return Checksum([]byte{0, uint8(protocol)}, xsum)
}
diff --git a/pkg/tcpip/header/tcp.go b/pkg/tcpip/header/tcp.go
index 207046b36..d37d624fe 100644
--- a/pkg/tcpip/header/tcp.go
+++ b/pkg/tcpip/header/tcp.go
@@ -231,19 +231,12 @@ func (b TCP) SetChecksum(checksum uint16) {
binary.BigEndian.PutUint16(b[tcpChecksum:], checksum)
}
-// CalculateChecksum calculates the checksum of the tcp segment given
-// the totalLen and partialChecksum(descriptions below)
-// totalLen is the total length of the segment
+// CalculateChecksum calculates the checksum of the tcp segment.
// partialChecksum is the checksum of the network-layer pseudo-header
-// (excluding the total length) and the checksum of the segment data.
-func (b TCP) CalculateChecksum(partialChecksum uint16, totalLen uint16) uint16 {
- // Add the length portion of the checksum to the pseudo-checksum.
- tmp := make([]byte, 2)
- binary.BigEndian.PutUint16(tmp, totalLen)
- checksum := Checksum(tmp, partialChecksum)
-
+// and the checksum of the segment data.
+func (b TCP) CalculateChecksum(partialChecksum uint16) uint16 {
// Calculate the rest of the checksum.
- return Checksum(b[:b.DataOffset()], checksum)
+ return Checksum(b[:b.DataOffset()], partialChecksum)
}
// Options returns a slice that holds the unparsed TCP options in the segment.
diff --git a/pkg/tcpip/header/udp.go b/pkg/tcpip/header/udp.go
index 31c8ef456..e8c860436 100644
--- a/pkg/tcpip/header/udp.go
+++ b/pkg/tcpip/header/udp.go
@@ -94,17 +94,11 @@ func (b UDP) SetChecksum(checksum uint16) {
binary.BigEndian.PutUint16(b[udpChecksum:], checksum)
}
-// CalculateChecksum calculates the checksum of the udp packet, given the total
-// length of the packet and the checksum of the network-layer pseudo-header
-// (excluding the total length) and the checksum of the payload.
-func (b UDP) CalculateChecksum(partialChecksum uint16, totalLen uint16) uint16 {
- // Add the length portion of the checksum to the pseudo-checksum.
- tmp := make([]byte, 2)
- binary.BigEndian.PutUint16(tmp, totalLen)
- checksum := Checksum(tmp, partialChecksum)
-
+// CalculateChecksum calculates the checksum of the udp packet, given the
+// checksum of the network-layer pseudo-header and the checksum of the payload.
+func (b UDP) CalculateChecksum(partialChecksum uint16) uint16 {
// Calculate the rest of the checksum.
- return Checksum(b[:UDPMinimumSize], checksum)
+ return Checksum(b[:UDPMinimumSize], partialChecksum)
}
// Encode encodes all the fields of the udp header.
diff --git a/pkg/tcpip/stack/route.go b/pkg/tcpip/stack/route.go
index 3f2264d16..ee860eafe 100644
--- a/pkg/tcpip/stack/route.go
+++ b/pkg/tcpip/stack/route.go
@@ -88,8 +88,8 @@ func (r *Route) Stats() tcpip.Stats {
// PseudoHeaderChecksum forwards the call to the network endpoint's
// implementation.
-func (r *Route) PseudoHeaderChecksum(protocol tcpip.TransportProtocolNumber) uint16 {
- return header.PseudoHeaderChecksum(protocol, r.LocalAddress, r.RemoteAddress)
+func (r *Route) PseudoHeaderChecksum(protocol tcpip.TransportProtocolNumber, totalLen uint16) uint16 {
+ return header.PseudoHeaderChecksum(protocol, r.LocalAddress, r.RemoteAddress, totalLen)
}
// Capabilities returns the link-layer capabilities of the route.
diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go
index 4d352b23c..c4353718e 100644
--- a/pkg/tcpip/transport/tcp/connect.go
+++ b/pkg/tcpip/transport/tcp/connect.go
@@ -589,10 +589,10 @@ func sendTCP(r *stack.Route, id stack.TransportEndpointID, data buffer.Vectorise
// Only calculate the checksum if offloading isn't supported.
if r.Capabilities()&stack.CapabilityChecksumOffload == 0 {
length := uint16(hdr.UsedLength() + data.Size())
- xsum := r.PseudoHeaderChecksum(ProtocolNumber)
+ xsum := r.PseudoHeaderChecksum(ProtocolNumber, length)
xsum = header.ChecksumVV(data, xsum)
- tcp.SetChecksum(^tcp.CalculateChecksum(xsum, length))
+ tcp.SetChecksum(^tcp.CalculateChecksum(xsum))
}
r.Stats().TCP.SegmentsSent.Increment()
diff --git a/pkg/tcpip/transport/tcp/testing/context/context.go b/pkg/tcpip/transport/tcp/testing/context/context.go
index fb4ae4a1b..aa2a73829 100644
--- a/pkg/tcpip/transport/tcp/testing/context/context.go
+++ b/pkg/tcpip/transport/tcp/testing/context/context.go
@@ -332,9 +332,8 @@ func (c *Context) BuildSegment(payload []byte, h *Headers) buffer.VectorisedView
xsum = header.Checksum([]byte{0, uint8(tcp.ProtocolNumber)}, xsum)
// Calculate the TCP checksum and set it.
- length := uint16(header.TCPMinimumSize + len(h.TCPOpts) + len(payload))
xsum = header.Checksum(payload, xsum)
- t.SetChecksum(^t.CalculateChecksum(xsum, length))
+ t.SetChecksum(^t.CalculateChecksum(xsum))
// Inject packet.
return buf.ToVectorisedView()
@@ -487,9 +486,8 @@ func (c *Context) SendV6Packet(payload []byte, h *Headers) {
xsum = header.Checksum([]byte{0, uint8(tcp.ProtocolNumber)}, xsum)
// Calculate the TCP checksum and set it.
- length := uint16(header.TCPMinimumSize + len(payload))
xsum = header.Checksum(payload, xsum)
- t.SetChecksum(^t.CalculateChecksum(xsum, length))
+ t.SetChecksum(^t.CalculateChecksum(xsum))
// Inject packet.
c.linkEP.Inject(ipv6.ProtocolNumber, buf.ToVectorisedView())
diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go
index 82d7f80a3..b68ed8561 100644
--- a/pkg/tcpip/transport/udp/endpoint.go
+++ b/pkg/tcpip/transport/udp/endpoint.go
@@ -641,11 +641,11 @@ func sendUDP(r *stack.Route, data buffer.VectorisedView, localPort, remotePort u
// Only calculate the checksum if offloading isn't supported.
if r.Capabilities()&stack.CapabilityChecksumOffload == 0 {
- xsum := r.PseudoHeaderChecksum(ProtocolNumber)
+ xsum := r.PseudoHeaderChecksum(ProtocolNumber, length)
for _, v := range data.Views() {
xsum = header.Checksum(v, xsum)
}
- udp.SetChecksum(^udp.CalculateChecksum(xsum, length))
+ udp.SetChecksum(^udp.CalculateChecksum(xsum))
}
// Track count of packets sent.
diff --git a/pkg/tcpip/transport/udp/udp_test.go b/pkg/tcpip/transport/udp/udp_test.go
index 884a76b04..0d5871615 100644
--- a/pkg/tcpip/transport/udp/udp_test.go
+++ b/pkg/tcpip/transport/udp/udp_test.go
@@ -200,9 +200,8 @@ func (c *testContext) sendV6Packet(payload []byte, h *headers) {
xsum = header.Checksum([]byte{0, uint8(udp.ProtocolNumber)}, xsum)
// Calculate the UDP checksum and set it.
- length := uint16(header.UDPMinimumSize + len(payload))
xsum = header.Checksum(payload, xsum)
- u.SetChecksum(^u.CalculateChecksum(xsum, length))
+ u.SetChecksum(^u.CalculateChecksum(xsum))
// Inject packet.
c.linkEP.Inject(ipv6.ProtocolNumber, buf.ToVectorisedView())
@@ -239,9 +238,8 @@ func (c *testContext) sendPacket(payload []byte, h *headers) {
xsum = header.Checksum([]byte{0, uint8(udp.ProtocolNumber)}, xsum)
// Calculate the UDP checksum and set it.
- length := uint16(header.UDPMinimumSize + len(payload))
xsum = header.Checksum(payload, xsum)
- u.SetChecksum(^u.CalculateChecksum(xsum, length))
+ u.SetChecksum(^u.CalculateChecksum(xsum))
// Inject packet.
c.linkEP.Inject(ipv4.ProtocolNumber, buf.ToVectorisedView())