diff options
Diffstat (limited to 'pkg/tcpip/header')
-rw-r--r-- | pkg/tcpip/header/checksum.go | 53 | ||||
-rw-r--r-- | pkg/tcpip/header/checksum_test.go | 113 | ||||
-rw-r--r-- | pkg/tcpip/header/icmpv4.go | 5 | ||||
-rw-r--r-- | pkg/tcpip/header/icmpv6.go | 17 | ||||
-rw-r--r-- | pkg/tcpip/header/parse/parse.go | 30 |
5 files changed, 104 insertions, 114 deletions
diff --git a/pkg/tcpip/header/checksum.go b/pkg/tcpip/header/checksum.go index 14a4b2b44..6aa9acfa8 100644 --- a/pkg/tcpip/header/checksum.go +++ b/pkg/tcpip/header/checksum.go @@ -186,42 +186,29 @@ func Checksum(buf []byte, initial uint16) uint16 { // // The initial checksum must have been computed on an even number of bytes. func ChecksumVV(vv buffer.VectorisedView, initial uint16) uint16 { - return ChecksumVVWithOffset(vv, initial, 0, vv.Size()) + var c Checksumer + for _, v := range vv.Views() { + c.Add([]byte(v)) + } + return ChecksumCombine(initial, c.Checksum()) } -// ChecksumVVWithOffset calculates the checksum (as defined in RFC 1071) of the -// bytes in the given VectorizedView. -// -// The initial checksum must have been computed on an even number of bytes. -func ChecksumVVWithOffset(vv buffer.VectorisedView, initial uint16, off int, size int) uint16 { - odd := false - sum := initial - for _, v := range vv.Views() { - if len(v) == 0 { - continue - } - - if off >= len(v) { - off -= len(v) - continue - } - v = v[off:] - - l := len(v) - if l > size { - l = size - } - v = v[:l] - - sum, odd = unrolledCalculateChecksum(v, odd, uint32(sum)) - - size -= len(v) - if size == 0 { - break - } - off = 0 +// Checksumer calculates checksum defined in RFC 1071. +type Checksumer struct { + sum uint16 + odd bool +} + +// Add adds b to checksum. +func (c *Checksumer) Add(b []byte) { + if len(b) > 0 { + c.sum, c.odd = unrolledCalculateChecksum(b, c.odd, uint32(c.sum)) } - return sum +} + +// Checksum returns the latest checksum value. +func (c *Checksumer) Checksum() uint16 { + return c.sum } // ChecksumCombine combines the two uint16 to form their checksum. This is done diff --git a/pkg/tcpip/header/checksum_test.go b/pkg/tcpip/header/checksum_test.go index 5ab20ee86..d267dabd0 100644 --- a/pkg/tcpip/header/checksum_test.go +++ b/pkg/tcpip/header/checksum_test.go @@ -17,6 +17,7 @@ package header_test import ( + "bytes" "fmt" "math/rand" "sync" @@ -26,86 +27,72 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/header" ) -func TestChecksumVVWithOffset(t *testing.T) { +func TestChecksumer(t *testing.T) { testCases := []struct { - name string - vv buffer.VectorisedView - off, size int - initial uint16 - want uint16 + name string + data [][]byte + want uint16 }{ { name: "empty", - vv: buffer.NewVectorisedView(0, []buffer.View{ - buffer.NewViewFromBytes([]byte{1, 9, 0, 5, 4}), - }), - off: 0, - size: 0, want: 0, }, { - name: "OneView", - vv: buffer.NewVectorisedView(0, []buffer.View{ - buffer.NewViewFromBytes([]byte{1, 9, 0, 5, 4}), - }), - off: 0, - size: 5, + name: "OneOddView", + data: [][]byte{ + []byte{1, 9, 0, 5, 4}, + }, want: 1294, }, { - name: "TwoViews", - vv: buffer.NewVectorisedView(0, []buffer.View{ - buffer.NewViewFromBytes([]byte{1, 9, 0, 5, 4}), - buffer.NewViewFromBytes([]byte{4, 3, 7, 1, 2, 123}), - }), - off: 0, - size: 11, + name: "TwoOddViews", + data: [][]byte{ + []byte{1, 9, 0, 5, 4}, + []byte{4, 3, 7, 1, 2, 123}, + }, want: 33819, }, { - name: "TwoViewsWithOffset", - vv: buffer.NewVectorisedView(0, []buffer.View{ - buffer.NewViewFromBytes([]byte{98, 1, 9, 0, 5, 4}), - buffer.NewViewFromBytes([]byte{4, 3, 7, 1, 2, 123}), - }), - off: 1, - size: 11, - want: 33819, + name: "OneEvenView", + data: [][]byte{ + []byte{1, 9, 0, 5}, + }, + want: 270, }, { - name: "ThreeViewsWithOffset", - vv: buffer.NewVectorisedView(0, []buffer.View{ - buffer.NewViewFromBytes([]byte{98, 1, 9, 0, 5, 4}), - buffer.NewViewFromBytes([]byte{98, 1, 9, 0, 5, 4}), - buffer.NewViewFromBytes([]byte{4, 3, 7, 1, 2, 123}), - }), - off: 7, - size: 11, - want: 33819, + name: "TwoEvenViews", + data: [][]byte{ + buffer.NewViewFromBytes([]byte{98, 1, 9, 0}), + buffer.NewViewFromBytes([]byte{9, 0, 5, 4}), + }, + want: 30981, }, { - name: "ThreeViewsWithInitial", - vv: buffer.NewVectorisedView(0, []buffer.View{ - buffer.NewViewFromBytes([]byte{77, 11, 33, 0, 55, 44}), - buffer.NewViewFromBytes([]byte{98, 1, 9, 0, 5, 4}), - buffer.NewViewFromBytes([]byte{4, 3, 7, 1, 2, 123, 99}), - }), - initial: 77, - off: 7, - size: 11, - want: 33896, + name: "ThreeViews", + data: [][]byte{ + []byte{77, 11, 33, 0, 55, 44}, + []byte{98, 1, 9, 0, 5, 4}, + []byte{4, 3, 7, 1, 2, 123, 99}, + }, + want: 34236, }, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { - if got, want := header.ChecksumVVWithOffset(tc.vv, tc.initial, tc.off, tc.size), tc.want; got != want { - t.Errorf("header.ChecksumVVWithOffset(%v) = %v, want: %v", tc, got, tc.want) + var all bytes.Buffer + var c header.Checksumer + for _, b := range tc.data { + c.Add(b) + // Append to the buffer. We will check the checksum as a whole later. + if _, err := all.Write(b); err != nil { + t.Fatalf("all.Write(b) = _, %s; want _, nil", err) + } + } + if got, want := c.Checksum(), tc.want; got != want { + t.Errorf("c.Checksum() = %d, want %d", got, want) } - v := tc.vv.ToView() - v.TrimFront(tc.off) - v.CapLength(tc.size) - if got, want := header.Checksum(v, tc.initial), tc.want; got != want { - t.Errorf("header.Checksum(%v) = %v, want: %v", tc, got, tc.want) + if got, want := header.Checksum(all.Bytes(), 0 /* initial */), tc.want; got != want { + t.Errorf("Checksum(flatten tc.data) = %d, want %d", got, want) } }) } @@ -228,7 +215,7 @@ func TestICMPv4Checksum(t *testing.T) { h.SetChecksum(want) testICMPChecksum(t, h.Checksum, func() uint16 { - return header.ICMPv4Checksum(h, vv) + return header.ICMPv4Checksum(h, header.ChecksumVV(vv, 0)) }, want, fmt.Sprintf("header: {% x} data {% x}", h, vv.ToView())) } @@ -260,6 +247,12 @@ func TestICMPv6Checksum(t *testing.T) { h.SetChecksum(want) testICMPChecksum(t, h.Checksum, func() uint16 { - return header.ICMPv6Checksum(h, src, dst, vv) + return header.ICMPv6Checksum(header.ICMPv6ChecksumParams{ + Header: h, + Src: src, + Dst: dst, + PayloadCsum: header.ChecksumVV(vv, 0), + PayloadLen: vv.Size(), + }) }, want, fmt.Sprintf("header: {% x} data {% x}", h, vv.ToView())) } diff --git a/pkg/tcpip/header/icmpv4.go b/pkg/tcpip/header/icmpv4.go index f840a4322..91c1c3cd2 100644 --- a/pkg/tcpip/header/icmpv4.go +++ b/pkg/tcpip/header/icmpv4.go @@ -18,7 +18,6 @@ import ( "encoding/binary" "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/buffer" ) // ICMPv4 represents an ICMPv4 header stored in a byte array. @@ -198,8 +197,8 @@ func (b ICMPv4) SetSequence(sequence uint16) { // ICMPv4Checksum calculates the ICMP checksum over the provided ICMP header, // and payload. -func ICMPv4Checksum(h ICMPv4, vv buffer.VectorisedView) uint16 { - xsum := ChecksumVV(vv, 0) +func ICMPv4Checksum(h ICMPv4, payloadCsum uint16) uint16 { + xsum := payloadCsum // h[2:4] is the checksum itself, skip it to avoid checksumming the checksum. xsum = Checksum(h[:2], xsum) diff --git a/pkg/tcpip/header/icmpv6.go b/pkg/tcpip/header/icmpv6.go index eca9750ab..668da623a 100644 --- a/pkg/tcpip/header/icmpv6.go +++ b/pkg/tcpip/header/icmpv6.go @@ -18,7 +18,6 @@ import ( "encoding/binary" "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/buffer" ) // ICMPv6 represents an ICMPv6 header stored in a byte array. @@ -262,12 +261,22 @@ func (b ICMPv6) Payload() []byte { return b[ICMPv6PayloadOffset:] } +// ICMPv6ChecksumParams contains parameters to calculate ICMPv6 checksum. +type ICMPv6ChecksumParams struct { + Header ICMPv6 + Src tcpip.Address + Dst tcpip.Address + PayloadCsum uint16 + PayloadLen int +} + // ICMPv6Checksum calculates the ICMP checksum over the provided ICMPv6 header, // IPv6 src/dst addresses and the payload. -func ICMPv6Checksum(h ICMPv6, src, dst tcpip.Address, vv buffer.VectorisedView) uint16 { - xsum := PseudoHeaderChecksum(ICMPv6ProtocolNumber, src, dst, uint16(len(h)+vv.Size())) +func ICMPv6Checksum(params ICMPv6ChecksumParams) uint16 { + h := params.Header - xsum = ChecksumVV(vv, xsum) + xsum := PseudoHeaderChecksum(ICMPv6ProtocolNumber, params.Src, params.Dst, uint16(len(h)+params.PayloadLen)) + xsum = ChecksumCombine(xsum, params.PayloadCsum) // h[2:4] is the checksum itself, skip it to avoid checksumming the checksum. xsum = Checksum(h[:2], xsum) diff --git a/pkg/tcpip/header/parse/parse.go b/pkg/tcpip/header/parse/parse.go index 2042f214a..ebb4b2c1d 100644 --- a/pkg/tcpip/header/parse/parse.go +++ b/pkg/tcpip/header/parse/parse.go @@ -41,7 +41,7 @@ func ARP(pkt *stack.PacketBuffer) bool { // // Returns true if the header was successfully parsed. func IPv4(pkt *stack.PacketBuffer) bool { - hdr, ok := pkt.Data.PullUp(header.IPv4MinimumSize) + hdr, ok := pkt.Data().PullUp(header.IPv4MinimumSize) if !ok { return false } @@ -62,27 +62,29 @@ func IPv4(pkt *stack.PacketBuffer) bool { ipHdr = header.IPv4(hdr) pkt.NetworkProtocolNumber = header.IPv4ProtocolNumber - pkt.Data.CapLength(int(ipHdr.TotalLength()) - len(hdr)) + pkt.Data().CapLength(int(ipHdr.TotalLength()) - len(hdr)) return true } // IPv6 parses an IPv6 packet found in pkt.Data and populates pkt's network // header with the IPv6 header. func IPv6(pkt *stack.PacketBuffer) (proto tcpip.TransportProtocolNumber, fragID uint32, fragOffset uint16, fragMore bool, ok bool) { - hdr, ok := pkt.Data.PullUp(header.IPv6MinimumSize) + hdr, ok := pkt.Data().PullUp(header.IPv6MinimumSize) if !ok { return 0, 0, 0, false, false } ipHdr := header.IPv6(hdr) - // dataClone consists of: + // Create a VV to parse the packet. We don't plan to modify anything here. + // dataVV consists of: // - Any IPv6 header bytes after the first 40 (i.e. extensions). // - The transport header, if present. // - Any other payload data. views := [8]buffer.View{} - dataClone := pkt.Data.Clone(views[:]) - dataClone.TrimFront(header.IPv6MinimumSize) - it := header.MakeIPv6PayloadIterator(header.IPv6ExtensionHeaderIdentifier(ipHdr.NextHeader()), dataClone) + dataVV := buffer.NewVectorisedView(0, views[:0]) + dataVV.AppendViews(pkt.Data().Views()) + dataVV.TrimFront(header.IPv6MinimumSize) + it := header.MakeIPv6PayloadIterator(header.IPv6ExtensionHeaderIdentifier(ipHdr.NextHeader()), dataVV) // Iterate over the IPv6 extensions to find their length. var nextHdr tcpip.TransportProtocolNumber @@ -98,7 +100,7 @@ traverseExtensions: // If we exhaust the extension list, the entire packet is the IPv6 header // and (possibly) extensions. if done { - extensionsSize = dataClone.Size() + extensionsSize = dataVV.Size() break } @@ -110,12 +112,12 @@ traverseExtensions: fragMore = extHdr.More() } rawPayload := it.AsRawHeader(true /* consume */) - extensionsSize = dataClone.Size() - rawPayload.Buf.Size() + extensionsSize = dataVV.Size() - rawPayload.Buf.Size() break traverseExtensions case header.IPv6RawPayloadHeader: // We've found the payload after any extensions. - extensionsSize = dataClone.Size() - extHdr.Buf.Size() + extensionsSize = dataVV.Size() - extHdr.Buf.Size() nextHdr = tcpip.TransportProtocolNumber(extHdr.Identifier) break traverseExtensions @@ -127,10 +129,10 @@ traverseExtensions: // Put the IPv6 header with extensions in pkt.NetworkHeader(). hdr, ok = pkt.NetworkHeader().Consume(header.IPv6MinimumSize + extensionsSize) if !ok { - panic(fmt.Sprintf("pkt.Data should have at least %d bytes, but only has %d.", header.IPv6MinimumSize+extensionsSize, pkt.Data.Size())) + panic(fmt.Sprintf("pkt.Data should have at least %d bytes, but only has %d.", header.IPv6MinimumSize+extensionsSize, pkt.Data().Size())) } ipHdr = header.IPv6(hdr) - pkt.Data.CapLength(int(ipHdr.PayloadLength())) + pkt.Data().CapLength(int(ipHdr.PayloadLength())) pkt.NetworkProtocolNumber = header.IPv6ProtocolNumber return nextHdr, fragID, fragOffset, fragMore, true @@ -153,13 +155,13 @@ func UDP(pkt *stack.PacketBuffer) bool { func TCP(pkt *stack.PacketBuffer) bool { // TCP header is variable length, peek at it first. hdrLen := header.TCPMinimumSize - hdr, ok := pkt.Data.PullUp(hdrLen) + hdr, ok := pkt.Data().PullUp(hdrLen) if !ok { return false } // If the header has options, pull those up as well. - if offset := int(header.TCP(hdr).DataOffset()); offset > header.TCPMinimumSize && offset <= pkt.Data.Size() { + if offset := int(header.TCP(hdr).DataOffset()); offset > header.TCPMinimumSize && offset <= pkt.Data().Size() { // TODO(gvisor.dev/issue/2404): Figure out whether to reject this kind of // packets. hdrLen = offset |