diff options
Diffstat (limited to 'pkg/tcpip/header')
-rw-r--r-- | pkg/tcpip/header/checksum.go | 53 | ||||
-rw-r--r-- | pkg/tcpip/header/icmpv4.go | 5 | ||||
-rw-r--r-- | pkg/tcpip/header/icmpv6.go | 17 | ||||
-rw-r--r-- | pkg/tcpip/header/parse/parse.go | 30 |
4 files changed, 51 insertions, 54 deletions
diff --git a/pkg/tcpip/header/checksum.go b/pkg/tcpip/header/checksum.go index 14a4b2b44..6aa9acfa8 100644 --- a/pkg/tcpip/header/checksum.go +++ b/pkg/tcpip/header/checksum.go @@ -186,42 +186,29 @@ func Checksum(buf []byte, initial uint16) uint16 { // // The initial checksum must have been computed on an even number of bytes. func ChecksumVV(vv buffer.VectorisedView, initial uint16) uint16 { - return ChecksumVVWithOffset(vv, initial, 0, vv.Size()) + var c Checksumer + for _, v := range vv.Views() { + c.Add([]byte(v)) + } + return ChecksumCombine(initial, c.Checksum()) } -// ChecksumVVWithOffset calculates the checksum (as defined in RFC 1071) of the -// bytes in the given VectorizedView. -// -// The initial checksum must have been computed on an even number of bytes. -func ChecksumVVWithOffset(vv buffer.VectorisedView, initial uint16, off int, size int) uint16 { - odd := false - sum := initial - for _, v := range vv.Views() { - if len(v) == 0 { - continue - } - - if off >= len(v) { - off -= len(v) - continue - } - v = v[off:] - - l := len(v) - if l > size { - l = size - } - v = v[:l] - - sum, odd = unrolledCalculateChecksum(v, odd, uint32(sum)) - - size -= len(v) - if size == 0 { - break - } - off = 0 +// Checksumer calculates checksum defined in RFC 1071. +type Checksumer struct { + sum uint16 + odd bool +} + +// Add adds b to checksum. +func (c *Checksumer) Add(b []byte) { + if len(b) > 0 { + c.sum, c.odd = unrolledCalculateChecksum(b, c.odd, uint32(c.sum)) } - return sum +} + +// Checksum returns the latest checksum value. +func (c *Checksumer) Checksum() uint16 { + return c.sum } // ChecksumCombine combines the two uint16 to form their checksum. This is done diff --git a/pkg/tcpip/header/icmpv4.go b/pkg/tcpip/header/icmpv4.go index f840a4322..91c1c3cd2 100644 --- a/pkg/tcpip/header/icmpv4.go +++ b/pkg/tcpip/header/icmpv4.go @@ -18,7 +18,6 @@ import ( "encoding/binary" "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/buffer" ) // ICMPv4 represents an ICMPv4 header stored in a byte array. @@ -198,8 +197,8 @@ func (b ICMPv4) SetSequence(sequence uint16) { // ICMPv4Checksum calculates the ICMP checksum over the provided ICMP header, // and payload. -func ICMPv4Checksum(h ICMPv4, vv buffer.VectorisedView) uint16 { - xsum := ChecksumVV(vv, 0) +func ICMPv4Checksum(h ICMPv4, payloadCsum uint16) uint16 { + xsum := payloadCsum // h[2:4] is the checksum itself, skip it to avoid checksumming the checksum. xsum = Checksum(h[:2], xsum) diff --git a/pkg/tcpip/header/icmpv6.go b/pkg/tcpip/header/icmpv6.go index eca9750ab..668da623a 100644 --- a/pkg/tcpip/header/icmpv6.go +++ b/pkg/tcpip/header/icmpv6.go @@ -18,7 +18,6 @@ import ( "encoding/binary" "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/buffer" ) // ICMPv6 represents an ICMPv6 header stored in a byte array. @@ -262,12 +261,22 @@ func (b ICMPv6) Payload() []byte { return b[ICMPv6PayloadOffset:] } +// ICMPv6ChecksumParams contains parameters to calculate ICMPv6 checksum. +type ICMPv6ChecksumParams struct { + Header ICMPv6 + Src tcpip.Address + Dst tcpip.Address + PayloadCsum uint16 + PayloadLen int +} + // ICMPv6Checksum calculates the ICMP checksum over the provided ICMPv6 header, // IPv6 src/dst addresses and the payload. -func ICMPv6Checksum(h ICMPv6, src, dst tcpip.Address, vv buffer.VectorisedView) uint16 { - xsum := PseudoHeaderChecksum(ICMPv6ProtocolNumber, src, dst, uint16(len(h)+vv.Size())) +func ICMPv6Checksum(params ICMPv6ChecksumParams) uint16 { + h := params.Header - xsum = ChecksumVV(vv, xsum) + xsum := PseudoHeaderChecksum(ICMPv6ProtocolNumber, params.Src, params.Dst, uint16(len(h)+params.PayloadLen)) + xsum = ChecksumCombine(xsum, params.PayloadCsum) // h[2:4] is the checksum itself, skip it to avoid checksumming the checksum. xsum = Checksum(h[:2], xsum) diff --git a/pkg/tcpip/header/parse/parse.go b/pkg/tcpip/header/parse/parse.go index 2042f214a..ebb4b2c1d 100644 --- a/pkg/tcpip/header/parse/parse.go +++ b/pkg/tcpip/header/parse/parse.go @@ -41,7 +41,7 @@ func ARP(pkt *stack.PacketBuffer) bool { // // Returns true if the header was successfully parsed. func IPv4(pkt *stack.PacketBuffer) bool { - hdr, ok := pkt.Data.PullUp(header.IPv4MinimumSize) + hdr, ok := pkt.Data().PullUp(header.IPv4MinimumSize) if !ok { return false } @@ -62,27 +62,29 @@ func IPv4(pkt *stack.PacketBuffer) bool { ipHdr = header.IPv4(hdr) pkt.NetworkProtocolNumber = header.IPv4ProtocolNumber - pkt.Data.CapLength(int(ipHdr.TotalLength()) - len(hdr)) + pkt.Data().CapLength(int(ipHdr.TotalLength()) - len(hdr)) return true } // IPv6 parses an IPv6 packet found in pkt.Data and populates pkt's network // header with the IPv6 header. func IPv6(pkt *stack.PacketBuffer) (proto tcpip.TransportProtocolNumber, fragID uint32, fragOffset uint16, fragMore bool, ok bool) { - hdr, ok := pkt.Data.PullUp(header.IPv6MinimumSize) + hdr, ok := pkt.Data().PullUp(header.IPv6MinimumSize) if !ok { return 0, 0, 0, false, false } ipHdr := header.IPv6(hdr) - // dataClone consists of: + // Create a VV to parse the packet. We don't plan to modify anything here. + // dataVV consists of: // - Any IPv6 header bytes after the first 40 (i.e. extensions). // - The transport header, if present. // - Any other payload data. views := [8]buffer.View{} - dataClone := pkt.Data.Clone(views[:]) - dataClone.TrimFront(header.IPv6MinimumSize) - it := header.MakeIPv6PayloadIterator(header.IPv6ExtensionHeaderIdentifier(ipHdr.NextHeader()), dataClone) + dataVV := buffer.NewVectorisedView(0, views[:0]) + dataVV.AppendViews(pkt.Data().Views()) + dataVV.TrimFront(header.IPv6MinimumSize) + it := header.MakeIPv6PayloadIterator(header.IPv6ExtensionHeaderIdentifier(ipHdr.NextHeader()), dataVV) // Iterate over the IPv6 extensions to find their length. var nextHdr tcpip.TransportProtocolNumber @@ -98,7 +100,7 @@ traverseExtensions: // If we exhaust the extension list, the entire packet is the IPv6 header // and (possibly) extensions. if done { - extensionsSize = dataClone.Size() + extensionsSize = dataVV.Size() break } @@ -110,12 +112,12 @@ traverseExtensions: fragMore = extHdr.More() } rawPayload := it.AsRawHeader(true /* consume */) - extensionsSize = dataClone.Size() - rawPayload.Buf.Size() + extensionsSize = dataVV.Size() - rawPayload.Buf.Size() break traverseExtensions case header.IPv6RawPayloadHeader: // We've found the payload after any extensions. - extensionsSize = dataClone.Size() - extHdr.Buf.Size() + extensionsSize = dataVV.Size() - extHdr.Buf.Size() nextHdr = tcpip.TransportProtocolNumber(extHdr.Identifier) break traverseExtensions @@ -127,10 +129,10 @@ traverseExtensions: // Put the IPv6 header with extensions in pkt.NetworkHeader(). hdr, ok = pkt.NetworkHeader().Consume(header.IPv6MinimumSize + extensionsSize) if !ok { - panic(fmt.Sprintf("pkt.Data should have at least %d bytes, but only has %d.", header.IPv6MinimumSize+extensionsSize, pkt.Data.Size())) + panic(fmt.Sprintf("pkt.Data should have at least %d bytes, but only has %d.", header.IPv6MinimumSize+extensionsSize, pkt.Data().Size())) } ipHdr = header.IPv6(hdr) - pkt.Data.CapLength(int(ipHdr.PayloadLength())) + pkt.Data().CapLength(int(ipHdr.PayloadLength())) pkt.NetworkProtocolNumber = header.IPv6ProtocolNumber return nextHdr, fragID, fragOffset, fragMore, true @@ -153,13 +155,13 @@ func UDP(pkt *stack.PacketBuffer) bool { func TCP(pkt *stack.PacketBuffer) bool { // TCP header is variable length, peek at it first. hdrLen := header.TCPMinimumSize - hdr, ok := pkt.Data.PullUp(hdrLen) + hdr, ok := pkt.Data().PullUp(hdrLen) if !ok { return false } // If the header has options, pull those up as well. - if offset := int(header.TCP(hdr).DataOffset()); offset > header.TCPMinimumSize && offset <= pkt.Data.Size() { + if offset := int(header.TCP(hdr).DataOffset()); offset > header.TCPMinimumSize && offset <= pkt.Data().Size() { // TODO(gvisor.dev/issue/2404): Figure out whether to reject this kind of // packets. hdrLen = offset |