diff options
Diffstat (limited to 'pkg/tcpip/header/checksum.go')
-rw-r--r-- | pkg/tcpip/header/checksum.go | 50 |
1 files changed, 39 insertions, 11 deletions
diff --git a/pkg/tcpip/header/checksum.go b/pkg/tcpip/header/checksum.go index 39a4d69be..9749c7f4d 100644 --- a/pkg/tcpip/header/checksum.go +++ b/pkg/tcpip/header/checksum.go @@ -23,11 +23,17 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/buffer" ) -func calculateChecksum(buf []byte, initial uint32) uint16 { +func calculateChecksum(buf []byte, odd bool, initial uint32) (uint16, bool) { v := initial + if odd { + v += uint32(buf[0]) + buf = buf[1:] + } + l := len(buf) - if l&1 != 0 { + odd = l&1 != 0 + if odd { l-- v += uint32(buf[l]) << 8 } @@ -36,7 +42,7 @@ func calculateChecksum(buf []byte, initial uint32) uint16 { v += (uint32(buf[i]) << 8) + uint32(buf[i+1]) } - return ChecksumCombine(uint16(v), uint16(v>>16)) + return ChecksumCombine(uint16(v), uint16(v>>16)), odd } // Checksum calculates the checksum (as defined in RFC 1071) of the bytes in the @@ -44,7 +50,8 @@ func calculateChecksum(buf []byte, initial uint32) uint16 { // // The initial checksum must have been computed on an even number of bytes. func Checksum(buf []byte, initial uint16) uint16 { - return calculateChecksum(buf, uint32(initial)) + s, _ := calculateChecksum(buf, false, uint32(initial)) + return s } // ChecksumVV calculates the checksum (as defined in RFC 1071) of the bytes in @@ -52,19 +59,40 @@ func Checksum(buf []byte, initial uint16) uint16 { // // The initial checksum must have been computed on an even number of bytes. func ChecksumVV(vv buffer.VectorisedView, initial uint16) uint16 { - var odd bool + return ChecksumVVWithOffset(vv, initial, 0, vv.Size()) +} + +// ChecksumVVWithOffset calculates the checksum (as defined in RFC 1071) of the +// bytes in the given VectorizedView. +// +// The initial checksum must have been computed on an even number of bytes. +func ChecksumVVWithOffset(vv buffer.VectorisedView, initial uint16, off int, size int) uint16 { + odd := false sum := initial for _, v := range vv.Views() { if len(v) == 0 { continue } - s := uint32(sum) - if odd { - s += uint32(v[0]) - v = v[1:] + + if off >= len(v) { + off -= len(v) + continue + } + v = v[off:] + + l := len(v) + if l > size { + l = size + } + v = v[:l] + + sum, odd = calculateChecksum(v, odd, uint32(sum)) + + size -= len(v) + if size == 0 { + break } - odd = len(v)&1 != 0 - sum = calculateChecksum(v, s) + off = 0 } return sum } |