summaryrefslogtreecommitdiffhomepage
path: root/pkg/tcpip/header
diff options
context:
space:
mode:
Diffstat (limited to 'pkg/tcpip/header')
-rw-r--r--pkg/tcpip/header/checksum.go53
-rw-r--r--pkg/tcpip/header/checksum_test.go113
-rw-r--r--pkg/tcpip/header/icmpv4.go5
-rw-r--r--pkg/tcpip/header/icmpv6.go17
-rw-r--r--pkg/tcpip/header/parse/parse.go30
5 files changed, 104 insertions, 114 deletions
diff --git a/pkg/tcpip/header/checksum.go b/pkg/tcpip/header/checksum.go
index 14a4b2b44..6aa9acfa8 100644
--- a/pkg/tcpip/header/checksum.go
+++ b/pkg/tcpip/header/checksum.go
@@ -186,42 +186,29 @@ func Checksum(buf []byte, initial uint16) uint16 {
//
// The initial checksum must have been computed on an even number of bytes.
func ChecksumVV(vv buffer.VectorisedView, initial uint16) uint16 {
- return ChecksumVVWithOffset(vv, initial, 0, vv.Size())
+ var c Checksumer
+ for _, v := range vv.Views() {
+ c.Add([]byte(v))
+ }
+ return ChecksumCombine(initial, c.Checksum())
}
-// ChecksumVVWithOffset calculates the checksum (as defined in RFC 1071) of the
-// bytes in the given VectorizedView.
-//
-// The initial checksum must have been computed on an even number of bytes.
-func ChecksumVVWithOffset(vv buffer.VectorisedView, initial uint16, off int, size int) uint16 {
- odd := false
- sum := initial
- for _, v := range vv.Views() {
- if len(v) == 0 {
- continue
- }
-
- if off >= len(v) {
- off -= len(v)
- continue
- }
- v = v[off:]
-
- l := len(v)
- if l > size {
- l = size
- }
- v = v[:l]
-
- sum, odd = unrolledCalculateChecksum(v, odd, uint32(sum))
-
- size -= len(v)
- if size == 0 {
- break
- }
- off = 0
+// Checksumer calculates checksum defined in RFC 1071.
+type Checksumer struct {
+ sum uint16
+ odd bool
+}
+
+// Add adds b to checksum.
+func (c *Checksumer) Add(b []byte) {
+ if len(b) > 0 {
+ c.sum, c.odd = unrolledCalculateChecksum(b, c.odd, uint32(c.sum))
}
- return sum
+}
+
+// Checksum returns the latest checksum value.
+func (c *Checksumer) Checksum() uint16 {
+ return c.sum
}
// ChecksumCombine combines the two uint16 to form their checksum. This is done
diff --git a/pkg/tcpip/header/checksum_test.go b/pkg/tcpip/header/checksum_test.go
index 5ab20ee86..d267dabd0 100644
--- a/pkg/tcpip/header/checksum_test.go
+++ b/pkg/tcpip/header/checksum_test.go
@@ -17,6 +17,7 @@
package header_test
import (
+ "bytes"
"fmt"
"math/rand"
"sync"
@@ -26,86 +27,72 @@ import (
"gvisor.dev/gvisor/pkg/tcpip/header"
)
-func TestChecksumVVWithOffset(t *testing.T) {
+func TestChecksumer(t *testing.T) {
testCases := []struct {
- name string
- vv buffer.VectorisedView
- off, size int
- initial uint16
- want uint16
+ name string
+ data [][]byte
+ want uint16
}{
{
name: "empty",
- vv: buffer.NewVectorisedView(0, []buffer.View{
- buffer.NewViewFromBytes([]byte{1, 9, 0, 5, 4}),
- }),
- off: 0,
- size: 0,
want: 0,
},
{
- name: "OneView",
- vv: buffer.NewVectorisedView(0, []buffer.View{
- buffer.NewViewFromBytes([]byte{1, 9, 0, 5, 4}),
- }),
- off: 0,
- size: 5,
+ name: "OneOddView",
+ data: [][]byte{
+ []byte{1, 9, 0, 5, 4},
+ },
want: 1294,
},
{
- name: "TwoViews",
- vv: buffer.NewVectorisedView(0, []buffer.View{
- buffer.NewViewFromBytes([]byte{1, 9, 0, 5, 4}),
- buffer.NewViewFromBytes([]byte{4, 3, 7, 1, 2, 123}),
- }),
- off: 0,
- size: 11,
+ name: "TwoOddViews",
+ data: [][]byte{
+ []byte{1, 9, 0, 5, 4},
+ []byte{4, 3, 7, 1, 2, 123},
+ },
want: 33819,
},
{
- name: "TwoViewsWithOffset",
- vv: buffer.NewVectorisedView(0, []buffer.View{
- buffer.NewViewFromBytes([]byte{98, 1, 9, 0, 5, 4}),
- buffer.NewViewFromBytes([]byte{4, 3, 7, 1, 2, 123}),
- }),
- off: 1,
- size: 11,
- want: 33819,
+ name: "OneEvenView",
+ data: [][]byte{
+ []byte{1, 9, 0, 5},
+ },
+ want: 270,
},
{
- name: "ThreeViewsWithOffset",
- vv: buffer.NewVectorisedView(0, []buffer.View{
- buffer.NewViewFromBytes([]byte{98, 1, 9, 0, 5, 4}),
- buffer.NewViewFromBytes([]byte{98, 1, 9, 0, 5, 4}),
- buffer.NewViewFromBytes([]byte{4, 3, 7, 1, 2, 123}),
- }),
- off: 7,
- size: 11,
- want: 33819,
+ name: "TwoEvenViews",
+ data: [][]byte{
+ buffer.NewViewFromBytes([]byte{98, 1, 9, 0}),
+ buffer.NewViewFromBytes([]byte{9, 0, 5, 4}),
+ },
+ want: 30981,
},
{
- name: "ThreeViewsWithInitial",
- vv: buffer.NewVectorisedView(0, []buffer.View{
- buffer.NewViewFromBytes([]byte{77, 11, 33, 0, 55, 44}),
- buffer.NewViewFromBytes([]byte{98, 1, 9, 0, 5, 4}),
- buffer.NewViewFromBytes([]byte{4, 3, 7, 1, 2, 123, 99}),
- }),
- initial: 77,
- off: 7,
- size: 11,
- want: 33896,
+ name: "ThreeViews",
+ data: [][]byte{
+ []byte{77, 11, 33, 0, 55, 44},
+ []byte{98, 1, 9, 0, 5, 4},
+ []byte{4, 3, 7, 1, 2, 123, 99},
+ },
+ want: 34236,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
- if got, want := header.ChecksumVVWithOffset(tc.vv, tc.initial, tc.off, tc.size), tc.want; got != want {
- t.Errorf("header.ChecksumVVWithOffset(%v) = %v, want: %v", tc, got, tc.want)
+ var all bytes.Buffer
+ var c header.Checksumer
+ for _, b := range tc.data {
+ c.Add(b)
+ // Append to the buffer. We will check the checksum as a whole later.
+ if _, err := all.Write(b); err != nil {
+ t.Fatalf("all.Write(b) = _, %s; want _, nil", err)
+ }
+ }
+ if got, want := c.Checksum(), tc.want; got != want {
+ t.Errorf("c.Checksum() = %d, want %d", got, want)
}
- v := tc.vv.ToView()
- v.TrimFront(tc.off)
- v.CapLength(tc.size)
- if got, want := header.Checksum(v, tc.initial), tc.want; got != want {
- t.Errorf("header.Checksum(%v) = %v, want: %v", tc, got, tc.want)
+ if got, want := header.Checksum(all.Bytes(), 0 /* initial */), tc.want; got != want {
+ t.Errorf("Checksum(flatten tc.data) = %d, want %d", got, want)
}
})
}
@@ -228,7 +215,7 @@ func TestICMPv4Checksum(t *testing.T) {
h.SetChecksum(want)
testICMPChecksum(t, h.Checksum, func() uint16 {
- return header.ICMPv4Checksum(h, vv)
+ return header.ICMPv4Checksum(h, header.ChecksumVV(vv, 0))
}, want, fmt.Sprintf("header: {% x} data {% x}", h, vv.ToView()))
}
@@ -260,6 +247,12 @@ func TestICMPv6Checksum(t *testing.T) {
h.SetChecksum(want)
testICMPChecksum(t, h.Checksum, func() uint16 {
- return header.ICMPv6Checksum(h, src, dst, vv)
+ return header.ICMPv6Checksum(header.ICMPv6ChecksumParams{
+ Header: h,
+ Src: src,
+ Dst: dst,
+ PayloadCsum: header.ChecksumVV(vv, 0),
+ PayloadLen: vv.Size(),
+ })
}, want, fmt.Sprintf("header: {% x} data {% x}", h, vv.ToView()))
}
diff --git a/pkg/tcpip/header/icmpv4.go b/pkg/tcpip/header/icmpv4.go
index f840a4322..91c1c3cd2 100644
--- a/pkg/tcpip/header/icmpv4.go
+++ b/pkg/tcpip/header/icmpv4.go
@@ -18,7 +18,6 @@ import (
"encoding/binary"
"gvisor.dev/gvisor/pkg/tcpip"
- "gvisor.dev/gvisor/pkg/tcpip/buffer"
)
// ICMPv4 represents an ICMPv4 header stored in a byte array.
@@ -198,8 +197,8 @@ func (b ICMPv4) SetSequence(sequence uint16) {
// ICMPv4Checksum calculates the ICMP checksum over the provided ICMP header,
// and payload.
-func ICMPv4Checksum(h ICMPv4, vv buffer.VectorisedView) uint16 {
- xsum := ChecksumVV(vv, 0)
+func ICMPv4Checksum(h ICMPv4, payloadCsum uint16) uint16 {
+ xsum := payloadCsum
// h[2:4] is the checksum itself, skip it to avoid checksumming the checksum.
xsum = Checksum(h[:2], xsum)
diff --git a/pkg/tcpip/header/icmpv6.go b/pkg/tcpip/header/icmpv6.go
index eca9750ab..668da623a 100644
--- a/pkg/tcpip/header/icmpv6.go
+++ b/pkg/tcpip/header/icmpv6.go
@@ -18,7 +18,6 @@ import (
"encoding/binary"
"gvisor.dev/gvisor/pkg/tcpip"
- "gvisor.dev/gvisor/pkg/tcpip/buffer"
)
// ICMPv6 represents an ICMPv6 header stored in a byte array.
@@ -262,12 +261,22 @@ func (b ICMPv6) Payload() []byte {
return b[ICMPv6PayloadOffset:]
}
+// ICMPv6ChecksumParams contains parameters to calculate ICMPv6 checksum.
+type ICMPv6ChecksumParams struct {
+ Header ICMPv6
+ Src tcpip.Address
+ Dst tcpip.Address
+ PayloadCsum uint16
+ PayloadLen int
+}
+
// ICMPv6Checksum calculates the ICMP checksum over the provided ICMPv6 header,
// IPv6 src/dst addresses and the payload.
-func ICMPv6Checksum(h ICMPv6, src, dst tcpip.Address, vv buffer.VectorisedView) uint16 {
- xsum := PseudoHeaderChecksum(ICMPv6ProtocolNumber, src, dst, uint16(len(h)+vv.Size()))
+func ICMPv6Checksum(params ICMPv6ChecksumParams) uint16 {
+ h := params.Header
- xsum = ChecksumVV(vv, xsum)
+ xsum := PseudoHeaderChecksum(ICMPv6ProtocolNumber, params.Src, params.Dst, uint16(len(h)+params.PayloadLen))
+ xsum = ChecksumCombine(xsum, params.PayloadCsum)
// h[2:4] is the checksum itself, skip it to avoid checksumming the checksum.
xsum = Checksum(h[:2], xsum)
diff --git a/pkg/tcpip/header/parse/parse.go b/pkg/tcpip/header/parse/parse.go
index 2042f214a..ebb4b2c1d 100644
--- a/pkg/tcpip/header/parse/parse.go
+++ b/pkg/tcpip/header/parse/parse.go
@@ -41,7 +41,7 @@ func ARP(pkt *stack.PacketBuffer) bool {
//
// Returns true if the header was successfully parsed.
func IPv4(pkt *stack.PacketBuffer) bool {
- hdr, ok := pkt.Data.PullUp(header.IPv4MinimumSize)
+ hdr, ok := pkt.Data().PullUp(header.IPv4MinimumSize)
if !ok {
return false
}
@@ -62,27 +62,29 @@ func IPv4(pkt *stack.PacketBuffer) bool {
ipHdr = header.IPv4(hdr)
pkt.NetworkProtocolNumber = header.IPv4ProtocolNumber
- pkt.Data.CapLength(int(ipHdr.TotalLength()) - len(hdr))
+ pkt.Data().CapLength(int(ipHdr.TotalLength()) - len(hdr))
return true
}
// IPv6 parses an IPv6 packet found in pkt.Data and populates pkt's network
// header with the IPv6 header.
func IPv6(pkt *stack.PacketBuffer) (proto tcpip.TransportProtocolNumber, fragID uint32, fragOffset uint16, fragMore bool, ok bool) {
- hdr, ok := pkt.Data.PullUp(header.IPv6MinimumSize)
+ hdr, ok := pkt.Data().PullUp(header.IPv6MinimumSize)
if !ok {
return 0, 0, 0, false, false
}
ipHdr := header.IPv6(hdr)
- // dataClone consists of:
+ // Create a VV to parse the packet. We don't plan to modify anything here.
+ // dataVV consists of:
// - Any IPv6 header bytes after the first 40 (i.e. extensions).
// - The transport header, if present.
// - Any other payload data.
views := [8]buffer.View{}
- dataClone := pkt.Data.Clone(views[:])
- dataClone.TrimFront(header.IPv6MinimumSize)
- it := header.MakeIPv6PayloadIterator(header.IPv6ExtensionHeaderIdentifier(ipHdr.NextHeader()), dataClone)
+ dataVV := buffer.NewVectorisedView(0, views[:0])
+ dataVV.AppendViews(pkt.Data().Views())
+ dataVV.TrimFront(header.IPv6MinimumSize)
+ it := header.MakeIPv6PayloadIterator(header.IPv6ExtensionHeaderIdentifier(ipHdr.NextHeader()), dataVV)
// Iterate over the IPv6 extensions to find their length.
var nextHdr tcpip.TransportProtocolNumber
@@ -98,7 +100,7 @@ traverseExtensions:
// If we exhaust the extension list, the entire packet is the IPv6 header
// and (possibly) extensions.
if done {
- extensionsSize = dataClone.Size()
+ extensionsSize = dataVV.Size()
break
}
@@ -110,12 +112,12 @@ traverseExtensions:
fragMore = extHdr.More()
}
rawPayload := it.AsRawHeader(true /* consume */)
- extensionsSize = dataClone.Size() - rawPayload.Buf.Size()
+ extensionsSize = dataVV.Size() - rawPayload.Buf.Size()
break traverseExtensions
case header.IPv6RawPayloadHeader:
// We've found the payload after any extensions.
- extensionsSize = dataClone.Size() - extHdr.Buf.Size()
+ extensionsSize = dataVV.Size() - extHdr.Buf.Size()
nextHdr = tcpip.TransportProtocolNumber(extHdr.Identifier)
break traverseExtensions
@@ -127,10 +129,10 @@ traverseExtensions:
// Put the IPv6 header with extensions in pkt.NetworkHeader().
hdr, ok = pkt.NetworkHeader().Consume(header.IPv6MinimumSize + extensionsSize)
if !ok {
- panic(fmt.Sprintf("pkt.Data should have at least %d bytes, but only has %d.", header.IPv6MinimumSize+extensionsSize, pkt.Data.Size()))
+ panic(fmt.Sprintf("pkt.Data should have at least %d bytes, but only has %d.", header.IPv6MinimumSize+extensionsSize, pkt.Data().Size()))
}
ipHdr = header.IPv6(hdr)
- pkt.Data.CapLength(int(ipHdr.PayloadLength()))
+ pkt.Data().CapLength(int(ipHdr.PayloadLength()))
pkt.NetworkProtocolNumber = header.IPv6ProtocolNumber
return nextHdr, fragID, fragOffset, fragMore, true
@@ -153,13 +155,13 @@ func UDP(pkt *stack.PacketBuffer) bool {
func TCP(pkt *stack.PacketBuffer) bool {
// TCP header is variable length, peek at it first.
hdrLen := header.TCPMinimumSize
- hdr, ok := pkt.Data.PullUp(hdrLen)
+ hdr, ok := pkt.Data().PullUp(hdrLen)
if !ok {
return false
}
// If the header has options, pull those up as well.
- if offset := int(header.TCP(hdr).DataOffset()); offset > header.TCPMinimumSize && offset <= pkt.Data.Size() {
+ if offset := int(header.TCP(hdr).DataOffset()); offset > header.TCPMinimumSize && offset <= pkt.Data().Size() {
// TODO(gvisor.dev/issue/2404): Figure out whether to reject this kind of
// packets.
hdrLen = offset