summaryrefslogtreecommitdiffhomepage
path: root/pkg/tcpip/header
diff options
context:
space:
mode:
Diffstat (limited to 'pkg/tcpip/header')
-rw-r--r--pkg/tcpip/header/checksum.go62
-rw-r--r--pkg/tcpip/header/checksum_test.go203
-rw-r--r--pkg/tcpip/header/interfaces.go38
-rw-r--r--pkg/tcpip/header/ipv4.go12
-rw-r--r--pkg/tcpip/header/ndp_options.go145
-rw-r--r--pkg/tcpip/header/ndp_router_advert.go19
-rw-r--r--pkg/tcpip/header/ndp_test.go248
-rw-r--r--pkg/tcpip/header/tcp.go29
-rw-r--r--pkg/tcpip/header/udp.go29
9 files changed, 785 insertions, 0 deletions
diff --git a/pkg/tcpip/header/checksum.go b/pkg/tcpip/header/checksum.go
index 6aa9acfa8..e2c85e220 100644
--- a/pkg/tcpip/header/checksum.go
+++ b/pkg/tcpip/header/checksum.go
@@ -18,6 +18,7 @@ package header
import (
"encoding/binary"
+ "fmt"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
@@ -234,3 +235,64 @@ func PseudoHeaderChecksum(protocol tcpip.TransportProtocolNumber, srcAddr tcpip.
return Checksum([]byte{0, uint8(protocol)}, xsum)
}
+
+// checksumUpdate2ByteAlignedUint16 updates a uint16 value in a calculated
+// checksum.
+//
+// The value MUST begin at a 2-byte boundary in the original buffer.
+func checksumUpdate2ByteAlignedUint16(xsum, old, new uint16) uint16 {
+ // As per RFC 1071 page 4,
+ // (4) Incremental Update
+ //
+ // ...
+ //
+ // To update the checksum, simply add the differences of the
+ // sixteen bit integers that have been changed. To see why this
+ // works, observe that every 16-bit integer has an additive inverse
+ // and that addition is associative. From this it follows that
+ // given the original value m, the new value m', and the old
+ // checksum C, the new checksum C' is:
+ //
+ // C' = C + (-m) + m' = C + (m' - m)
+ return ChecksumCombine(xsum, ChecksumCombine(new, ^old))
+}
+
+// checksumUpdate2ByteAlignedAddress updates an address in a calculated
+// checksum.
+//
+// The addresses must have the same length and must contain an even number
+// of bytes. The address MUST begin at a 2-byte boundary in the original buffer.
+func checksumUpdate2ByteAlignedAddress(xsum uint16, old, new tcpip.Address) uint16 {
+ const uint16Bytes = 2
+
+ if len(old) != len(new) {
+ panic(fmt.Sprintf("buffer lengths are different; old = %d, new = %d", len(old), len(new)))
+ }
+
+ if len(old)%uint16Bytes != 0 {
+ panic(fmt.Sprintf("buffer has an odd number of bytes; got = %d", len(old)))
+ }
+
+ // As per RFC 1071 page 4,
+ // (4) Incremental Update
+ //
+ // ...
+ //
+ // To update the checksum, simply add the differences of the
+ // sixteen bit integers that have been changed. To see why this
+ // works, observe that every 16-bit integer has an additive inverse
+ // and that addition is associative. From this it follows that
+ // given the original value m, the new value m', and the old
+ // checksum C, the new checksum C' is:
+ //
+ // C' = C + (-m) + m' = C + (m' - m)
+ for len(old) != 0 {
+ // Convert the 2 byte sequences to uint16 values then apply the increment
+ // update.
+ xsum = checksumUpdate2ByteAlignedUint16(xsum, (uint16(old[0])<<8)+uint16(old[1]), (uint16(new[0])<<8)+uint16(new[1]))
+ old = old[uint16Bytes:]
+ new = new[uint16Bytes:]
+ }
+
+ return xsum
+}
diff --git a/pkg/tcpip/header/checksum_test.go b/pkg/tcpip/header/checksum_test.go
index d267dabd0..3445511f4 100644
--- a/pkg/tcpip/header/checksum_test.go
+++ b/pkg/tcpip/header/checksum_test.go
@@ -23,6 +23,7 @@ import (
"sync"
"testing"
+ "gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
)
@@ -256,3 +257,205 @@ func TestICMPv6Checksum(t *testing.T) {
})
}, want, fmt.Sprintf("header: {% x} data {% x}", h, vv.ToView()))
}
+
+func randomAddress(size int) tcpip.Address {
+ s := make([]byte, size)
+ for i := 0; i < size; i++ {
+ s[i] = byte(rand.Uint32())
+ }
+ return tcpip.Address(s)
+}
+
+func TestChecksummableNetworkUpdateAddress(t *testing.T) {
+ tests := []struct {
+ name string
+ update func(header.IPv4, tcpip.Address)
+ }{
+ {
+ name: "SetSourceAddressWithChecksumUpdate",
+ update: header.IPv4.SetSourceAddressWithChecksumUpdate,
+ },
+ {
+ name: "SetDestinationAddressWithChecksumUpdate",
+ update: header.IPv4.SetDestinationAddressWithChecksumUpdate,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ for i := 0; i < 1000; i++ {
+ var origBytes [header.IPv4MinimumSize]byte
+ header.IPv4(origBytes[:]).Encode(&header.IPv4Fields{
+ TOS: 1,
+ TotalLength: header.IPv4MinimumSize,
+ ID: 2,
+ Flags: 3,
+ FragmentOffset: 4,
+ TTL: 5,
+ Protocol: 6,
+ Checksum: 0,
+ SrcAddr: randomAddress(header.IPv4AddressSize),
+ DstAddr: randomAddress(header.IPv4AddressSize),
+ })
+
+ addr := randomAddress(header.IPv4AddressSize)
+
+ bytesCopy := origBytes
+ h := header.IPv4(bytesCopy[:])
+ origXSum := h.CalculateChecksum()
+ h.SetChecksum(^origXSum)
+
+ test.update(h, addr)
+ got := ^h.Checksum()
+ h.SetChecksum(0)
+ want := h.CalculateChecksum()
+ if got != want {
+ t.Errorf("got h.Checksum() = 0x%x, want = 0x%x; originalBytes = 0x%x, new addr = %s", got, want, origBytes, addr)
+ }
+ }
+ })
+ }
+}
+
+func TestChecksummableTransportUpdatePort(t *testing.T) {
+ // The fields in the pseudo header is not tested here so we just use 0.
+ const pseudoHeaderXSum = 0
+
+ tests := []struct {
+ name string
+ transportHdr func(_, _ uint16) (header.ChecksummableTransport, func(uint16) uint16)
+ proto tcpip.TransportProtocolNumber
+ }{
+ {
+ name: "TCP",
+ transportHdr: func(src, dst uint16) (header.ChecksummableTransport, func(uint16) uint16) {
+ h := header.TCP(make([]byte, header.TCPMinimumSize))
+ h.Encode(&header.TCPFields{
+ SrcPort: src,
+ DstPort: dst,
+ SeqNum: 1,
+ AckNum: 2,
+ DataOffset: header.TCPMinimumSize,
+ Flags: 3,
+ WindowSize: 4,
+ Checksum: 0,
+ UrgentPointer: 5,
+ })
+ h.SetChecksum(^h.CalculateChecksum(pseudoHeaderXSum))
+ return h, h.CalculateChecksum
+ },
+ proto: header.TCPProtocolNumber,
+ },
+ {
+ name: "UDP",
+ transportHdr: func(src, dst uint16) (header.ChecksummableTransport, func(uint16) uint16) {
+ h := header.UDP(make([]byte, header.UDPMinimumSize))
+ h.Encode(&header.UDPFields{
+ SrcPort: src,
+ DstPort: dst,
+ Length: 0,
+ Checksum: 0,
+ })
+ h.SetChecksum(^h.CalculateChecksum(pseudoHeaderXSum))
+ return h, h.CalculateChecksum
+ },
+ proto: header.UDPProtocolNumber,
+ },
+ }
+
+ for i := 0; i < 1000; i++ {
+ origSrcPort := uint16(rand.Uint32())
+ origDstPort := uint16(rand.Uint32())
+ newPort := uint16(rand.Uint32())
+
+ t.Run(fmt.Sprintf("OrigSrcPort=%d,OrigDstPort=%d,NewPort=%d", origSrcPort, origDstPort, newPort), func(*testing.T) {
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ for _, subTest := range []struct {
+ name string
+ update func(header.ChecksummableTransport)
+ }{
+ {
+ name: "Source port",
+ update: func(h header.ChecksummableTransport) { h.SetSourcePortWithChecksumUpdate(newPort) },
+ },
+ {
+ name: "Destination port",
+ update: func(h header.ChecksummableTransport) { h.SetDestinationPortWithChecksumUpdate(newPort) },
+ },
+ } {
+ t.Run(subTest.name, func(t *testing.T) {
+ h, calcXSum := test.transportHdr(origSrcPort, origDstPort)
+ subTest.update(h)
+ // TCP and UDP hold the 1s complement of the fully calculated
+ // checksum.
+ got := ^h.Checksum()
+ h.SetChecksum(0)
+
+ if want := calcXSum(pseudoHeaderXSum); got != want {
+ h, _ := test.transportHdr(origSrcPort, origDstPort)
+ t.Errorf("got Checksum() = 0x%x, want = 0x%x; originalBytes = %#v, new port = %d", got, want, h, newPort)
+ }
+ })
+ }
+ })
+ }
+ })
+ }
+}
+
+func TestChecksummableTransportUpdatePseudoHeaderAddress(t *testing.T) {
+ const addressSize = 6
+
+ tests := []struct {
+ name string
+ transportHdr func() header.ChecksummableTransport
+ proto tcpip.TransportProtocolNumber
+ }{
+ {
+ name: "TCP",
+ transportHdr: func() header.ChecksummableTransport { return header.TCP(make([]byte, header.TCPMinimumSize)) },
+ proto: header.TCPProtocolNumber,
+ },
+ {
+ name: "UDP",
+ transportHdr: func() header.ChecksummableTransport { return header.UDP(make([]byte, header.UDPMinimumSize)) },
+ proto: header.UDPProtocolNumber,
+ },
+ }
+
+ for i := 0; i < 1000; i++ {
+ permanent := randomAddress(addressSize)
+ old := randomAddress(addressSize)
+ new := randomAddress(addressSize)
+
+ t.Run(fmt.Sprintf("Permanent=%q,Old=%q,New=%q", permanent, old, new), func(t *testing.T) {
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ for _, fullChecksum := range []bool{true, false} {
+ t.Run(fmt.Sprintf("FullChecksum=%t", fullChecksum), func(t *testing.T) {
+ initialXSum := header.PseudoHeaderChecksum(test.proto, permanent, old, 0)
+ if fullChecksum {
+ // TCP and UDP hold the 1s complement of the fully calculated
+ // checksum.
+ initialXSum = ^initialXSum
+ }
+
+ h := test.transportHdr()
+ h.SetChecksum(initialXSum)
+ h.UpdateChecksumPseudoHeaderAddress(old, new, fullChecksum)
+
+ got := h.Checksum()
+ if fullChecksum {
+ got = ^got
+ }
+ if want := header.PseudoHeaderChecksum(test.proto, permanent, new, 0); got != want {
+ t.Errorf("got Checksum() = 0x%x, want = 0x%x; h = %#v", got, want, h)
+ }
+ })
+ }
+ })
+ }
+ })
+ }
+}
diff --git a/pkg/tcpip/header/interfaces.go b/pkg/tcpip/header/interfaces.go
index 861cbbb70..3a41adfc4 100644
--- a/pkg/tcpip/header/interfaces.go
+++ b/pkg/tcpip/header/interfaces.go
@@ -53,6 +53,31 @@ type Transport interface {
Payload() []byte
}
+// ChecksummableTransport is a Transport that supports checksumming.
+type ChecksummableTransport interface {
+ Transport
+
+ // SetSourcePortWithChecksumUpdate sets the source port and updates
+ // the checksum.
+ //
+ // The receiver's checksum must be a fully calculated checksum.
+ SetSourcePortWithChecksumUpdate(port uint16)
+
+ // SetDestinationPortWithChecksumUpdate sets the destination port and updates
+ // the checksum.
+ //
+ // The receiver's checksum must be a fully calculated checksum.
+ SetDestinationPortWithChecksumUpdate(port uint16)
+
+ // UpdateChecksumPseudoHeaderAddress updates the checksum to reflect an
+ // updated address in the pseudo header.
+ //
+ // If fullChecksum is true, the receiver's checksum field is assumed to hold a
+ // fully calculated checksum. Otherwise, it is assumed to hold a partially
+ // calculated checksum which only reflects the pseudo header.
+ UpdateChecksumPseudoHeaderAddress(old, new tcpip.Address, fullChecksum bool)
+}
+
// Network offers generic methods to query and/or update the fields of the
// header of a network protocol buffer.
type Network interface {
@@ -90,3 +115,16 @@ type Network interface {
// SetTOS sets the values of the "type of service" and "flow label" fields.
SetTOS(t uint8, l uint32)
}
+
+// ChecksummableNetwork is a Network that supports checksumming.
+type ChecksummableNetwork interface {
+ Network
+
+ // SetSourceAddressAndChecksum sets the source address and updates the
+ // checksum to reflect the new address.
+ SetSourceAddressWithChecksumUpdate(tcpip.Address)
+
+ // SetDestinationAddressAndChecksum sets the destination address and
+ // updates the checksum to reflect the new address.
+ SetDestinationAddressWithChecksumUpdate(tcpip.Address)
+}
diff --git a/pkg/tcpip/header/ipv4.go b/pkg/tcpip/header/ipv4.go
index e9abbb709..dcc549c7b 100644
--- a/pkg/tcpip/header/ipv4.go
+++ b/pkg/tcpip/header/ipv4.go
@@ -305,6 +305,18 @@ func (b IPv4) DestinationAddress() tcpip.Address {
return tcpip.Address(b[dstAddr : dstAddr+IPv4AddressSize])
}
+// SetSourceAddressWithChecksumUpdate implements ChecksummableNetwork.
+func (b IPv4) SetSourceAddressWithChecksumUpdate(new tcpip.Address) {
+ b.SetChecksum(^checksumUpdate2ByteAlignedAddress(^b.Checksum(), b.SourceAddress(), new))
+ b.SetSourceAddress(new)
+}
+
+// SetDestinationAddressWithChecksumUpdate implements ChecksummableNetwork.
+func (b IPv4) SetDestinationAddressWithChecksumUpdate(new tcpip.Address) {
+ b.SetChecksum(^checksumUpdate2ByteAlignedAddress(^b.Checksum(), b.DestinationAddress(), new))
+ b.SetDestinationAddress(new)
+}
+
// padIPv4OptionsLength returns the total length for IPv4 options of length l
// after applying padding according to RFC 791:
// The internet header padding is used to ensure that the internet
diff --git a/pkg/tcpip/header/ndp_options.go b/pkg/tcpip/header/ndp_options.go
index b1f39e6e6..a647ea968 100644
--- a/pkg/tcpip/header/ndp_options.go
+++ b/pkg/tcpip/header/ndp_options.go
@@ -233,6 +233,17 @@ func (i *NDPOptionIterator) Next() (NDPOption, bool, error) {
case ndpNonceOptionType:
return NDPNonceOption(body), false, nil
+ case ndpRouteInformationType:
+ if numBodyBytes > ndpRouteInformationMaxLength {
+ return nil, true, fmt.Errorf("got %d bytes for NDP Route Information option's body, expected at max %d bytes: %w", numBodyBytes, ndpRouteInformationMaxLength, ErrNDPOptMalformedBody)
+ }
+ opt := NDPRouteInformation(body)
+ if err := opt.hasError(); err != nil {
+ return nil, true, err
+ }
+
+ return opt, false, nil
+
case ndpPrefixInformationType:
// Make sure the length of a Prefix Information option
// body is ndpPrefixInformationLength, as per RFC 4861
@@ -930,3 +941,137 @@ func isUpperLetter(b byte) bool {
func isDigit(b byte) bool {
return b >= '0' && b <= '9'
}
+
+// As per RFC 4191 section 2.3,
+//
+// 2.3. Route Information Option
+//
+// 0 1 2 3
+// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
+// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
+// | Type | Length | Prefix Length |Resvd|Prf|Resvd|
+// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
+// | Route Lifetime |
+// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
+// | Prefix (Variable Length) |
+// . .
+// . .
+// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
+//
+// Fields:
+//
+// Type 24
+//
+//
+// Length 8-bit unsigned integer. The length of the option
+// (including the Type and Length fields) in units of 8
+// octets. The Length field is 1, 2, or 3 depending on the
+// Prefix Length. If Prefix Length is greater than 64, then
+// Length must be 3. If Prefix Length is greater than 0,
+// then Length must be 2 or 3. If Prefix Length is zero,
+// then Length must be 1, 2, or 3.
+const (
+ ndpRouteInformationType = ndpOptionIdentifier(24)
+ ndpRouteInformationMaxLength = 22
+
+ ndpRouteInformationPrefixLengthIdx = 0
+ ndpRouteInformationFlagsIdx = 1
+ ndpRouteInformationPrfShift = 3
+ ndpRouteInformationPrfMask = 3 << ndpRouteInformationPrfShift
+ ndpRouteInformationRouteLifetimeIdx = 2
+ ndpRouteInformationRoutePrefixIdx = 6
+)
+
+// NDPRouteInformation is the NDP Router Information option, as defined by
+// RFC 4191 section 2.3.
+type NDPRouteInformation []byte
+
+func (NDPRouteInformation) kind() ndpOptionIdentifier {
+ return ndpRouteInformationType
+}
+
+func (o NDPRouteInformation) length() int {
+ return len(o)
+}
+
+func (o NDPRouteInformation) serializeInto(b []byte) int {
+ return copy(b, o)
+}
+
+// String implements fmt.Stringer.
+func (o NDPRouteInformation) String() string {
+ return fmt.Sprintf("%T", o)
+}
+
+// PrefixLength returns the length of the prefix.
+func (o NDPRouteInformation) PrefixLength() uint8 {
+ return o[ndpRouteInformationPrefixLengthIdx]
+}
+
+// RoutePreference returns the preference of the route over other routes to the
+// same destination but through a different router.
+func (o NDPRouteInformation) RoutePreference() NDPRoutePreference {
+ return NDPRoutePreference((o[ndpRouteInformationFlagsIdx] & ndpRouteInformationPrfMask) >> ndpRouteInformationPrfShift)
+}
+
+// RouteLifetime returns the lifetime of the route.
+//
+// Note, a value of 0 implies the route is now invalid and a value of
+// infinity/forever is represented by NDPInfiniteLifetime.
+func (o NDPRouteInformation) RouteLifetime() time.Duration {
+ return time.Second * time.Duration(binary.BigEndian.Uint32(o[ndpRouteInformationRouteLifetimeIdx:]))
+}
+
+// Prefix returns the prefix of the destination subnet this route is for.
+func (o NDPRouteInformation) Prefix() (tcpip.Subnet, error) {
+ prefixLength := int(o.PrefixLength())
+ if max := IPv6AddressSize * 8; prefixLength > max {
+ return tcpip.Subnet{}, fmt.Errorf("got prefix length = %d, want <= %d", prefixLength, max)
+ }
+
+ prefix := o[ndpRouteInformationRoutePrefixIdx:]
+ var addrBytes [IPv6AddressSize]byte
+ if n := copy(addrBytes[:], prefix); n != len(prefix) {
+ panic(fmt.Sprintf("got copy(addrBytes, prefix) = %d, want = %d", n, len(prefix)))
+ }
+
+ return tcpip.AddressWithPrefix{
+ Address: tcpip.Address(addrBytes[:]),
+ PrefixLen: prefixLength,
+ }.Subnet(), nil
+}
+
+func (o NDPRouteInformation) hasError() error {
+ l := len(o)
+ if l < ndpRouteInformationRoutePrefixIdx {
+ return fmt.Errorf("%T too small, got = %d bytes: %w", o, l, ErrNDPOptMalformedBody)
+ }
+
+ prefixLength := int(o.PrefixLength())
+ if max := IPv6AddressSize * 8; prefixLength > max {
+ return fmt.Errorf("got prefix length = %d, want <= %d: %w", prefixLength, max, ErrNDPOptMalformedBody)
+ }
+
+ // Length 8-bit unsigned integer. The length of the option
+ // (including the Type and Length fields) in units of 8
+ // octets. The Length field is 1, 2, or 3 depending on the
+ // Prefix Length. If Prefix Length is greater than 64, then
+ // Length must be 3. If Prefix Length is greater than 0,
+ // then Length must be 2 or 3. If Prefix Length is zero,
+ // then Length must be 1, 2, or 3.
+ l += 2 // Add 2 bytes for the type and length bytes.
+ lengthField := l / lengthByteUnits
+ if prefixLength > 64 {
+ if lengthField != 3 {
+ return fmt.Errorf("Length field must be 3 when Prefix Length (%d) is > 64 (got = %d): %w", prefixLength, lengthField, ErrNDPOptMalformedBody)
+ }
+ } else if prefixLength > 0 {
+ if lengthField != 2 && lengthField != 3 {
+ return fmt.Errorf("Length field must be 2 or 3 when Prefix Length (%d) is between 0 and 64 (got = %d): %w", prefixLength, lengthField, ErrNDPOptMalformedBody)
+ }
+ } else if lengthField == 0 || lengthField > 3 {
+ return fmt.Errorf("Length field must be 1, 2, or 3 when Prefix Length is zero (got = %d): %w", lengthField, ErrNDPOptMalformedBody)
+ }
+
+ return nil
+}
diff --git a/pkg/tcpip/header/ndp_router_advert.go b/pkg/tcpip/header/ndp_router_advert.go
index 7e2f0c797..7d6efa083 100644
--- a/pkg/tcpip/header/ndp_router_advert.go
+++ b/pkg/tcpip/header/ndp_router_advert.go
@@ -16,9 +16,12 @@ package header
import (
"encoding/binary"
+ "fmt"
"time"
)
+var _ fmt.Stringer = NDPRoutePreference(0)
+
// NDPRoutePreference is the preference values for default routers or
// more-specific routes.
//
@@ -64,6 +67,22 @@ const (
ReservedRoutePreference = 0b10
)
+// String implements fmt.Stringer.
+func (p NDPRoutePreference) String() string {
+ switch p {
+ case HighRoutePreference:
+ return "HighRoutePreference"
+ case MediumRoutePreference:
+ return "MediumRoutePreference"
+ case LowRoutePreference:
+ return "LowRoutePreference"
+ case ReservedRoutePreference:
+ return "ReservedRoutePreference"
+ default:
+ return fmt.Sprintf("NDPRoutePreference(%d)", p)
+ }
+}
+
// NDPRouterAdvert is an NDP Router Advertisement message. It will only contain
// the body of an ICMPv6 packet.
//
diff --git a/pkg/tcpip/header/ndp_test.go b/pkg/tcpip/header/ndp_test.go
index 8fd1f7d13..2a897e938 100644
--- a/pkg/tcpip/header/ndp_test.go
+++ b/pkg/tcpip/header/ndp_test.go
@@ -21,6 +21,7 @@ import (
"fmt"
"io"
"regexp"
+ "strings"
"testing"
"time"
@@ -58,6 +59,224 @@ func TestNDPNeighborSolicit(t *testing.T) {
}
}
+func TestNDPRouteInformationOption(t *testing.T) {
+ tests := []struct {
+ name string
+
+ length uint8
+ prefixLength uint8
+ prf NDPRoutePreference
+ lifetimeS uint32
+ prefixBytes []byte
+ expectedPrefix tcpip.Subnet
+
+ expectedErr error
+ }{
+ {
+ name: "Length=1 with Prefix Length = 0",
+ length: 1,
+ prefixLength: 0,
+ prf: MediumRoutePreference,
+ lifetimeS: 1,
+ prefixBytes: nil,
+ expectedPrefix: IPv6EmptySubnet,
+ },
+ {
+ name: "Length=1 but Prefix Length > 0",
+ length: 1,
+ prefixLength: 1,
+ prf: MediumRoutePreference,
+ lifetimeS: 1,
+ prefixBytes: nil,
+ expectedErr: ErrNDPOptMalformedBody,
+ },
+ {
+ name: "Length=2 with Prefix Length = 0",
+ length: 2,
+ prefixLength: 0,
+ prf: MediumRoutePreference,
+ lifetimeS: 1,
+ prefixBytes: nil,
+ expectedPrefix: IPv6EmptySubnet,
+ },
+ {
+ name: "Length=2 with Prefix Length in [1, 64] (1)",
+ length: 2,
+ prefixLength: 1,
+ prf: LowRoutePreference,
+ lifetimeS: 1,
+ prefixBytes: nil,
+ expectedPrefix: tcpip.AddressWithPrefix{
+ Address: tcpip.Address(strings.Repeat("\x00", IPv6AddressSize)),
+ PrefixLen: 1,
+ }.Subnet(),
+ },
+ {
+ name: "Length=2 with Prefix Length in [1, 64] (64)",
+ length: 2,
+ prefixLength: 64,
+ prf: HighRoutePreference,
+ lifetimeS: 1,
+ prefixBytes: nil,
+ expectedPrefix: tcpip.AddressWithPrefix{
+ Address: tcpip.Address(strings.Repeat("\x00", IPv6AddressSize)),
+ PrefixLen: 64,
+ }.Subnet(),
+ },
+ {
+ name: "Length=2 with Prefix Length > 64",
+ length: 2,
+ prefixLength: 65,
+ prf: HighRoutePreference,
+ lifetimeS: 1,
+ prefixBytes: nil,
+ expectedErr: ErrNDPOptMalformedBody,
+ },
+ {
+ name: "Length=3 with Prefix Length = 0",
+ length: 3,
+ prefixLength: 0,
+ prf: MediumRoutePreference,
+ lifetimeS: 1,
+ prefixBytes: nil,
+ expectedPrefix: IPv6EmptySubnet,
+ },
+ {
+ name: "Length=3 with Prefix Length in [1, 64] (1)",
+ length: 3,
+ prefixLength: 1,
+ prf: LowRoutePreference,
+ lifetimeS: 1,
+ prefixBytes: nil,
+ expectedPrefix: tcpip.AddressWithPrefix{
+ Address: tcpip.Address(strings.Repeat("\x00", IPv6AddressSize)),
+ PrefixLen: 1,
+ }.Subnet(),
+ },
+ {
+ name: "Length=3 with Prefix Length in [1, 64] (64)",
+ length: 3,
+ prefixLength: 64,
+ prf: HighRoutePreference,
+ lifetimeS: 1,
+ prefixBytes: nil,
+ expectedPrefix: tcpip.AddressWithPrefix{
+ Address: tcpip.Address(strings.Repeat("\x00", IPv6AddressSize)),
+ PrefixLen: 64,
+ }.Subnet(),
+ },
+ {
+ name: "Length=3 with Prefix Length in [65, 128] (65)",
+ length: 3,
+ prefixLength: 65,
+ prf: HighRoutePreference,
+ lifetimeS: 1,
+ prefixBytes: nil,
+ expectedPrefix: tcpip.AddressWithPrefix{
+ Address: tcpip.Address(strings.Repeat("\x00", IPv6AddressSize)),
+ PrefixLen: 65,
+ }.Subnet(),
+ },
+ {
+ name: "Length=3 with Prefix Length in [65, 128] (128)",
+ length: 3,
+ prefixLength: 128,
+ prf: HighRoutePreference,
+ lifetimeS: 1,
+ prefixBytes: nil,
+ expectedPrefix: tcpip.AddressWithPrefix{
+ Address: tcpip.Address(strings.Repeat("\x00", IPv6AddressSize)),
+ PrefixLen: 128,
+ }.Subnet(),
+ },
+ {
+ name: "Length=3 with (invalid) Prefix Length > 128",
+ length: 3,
+ prefixLength: 129,
+ prf: HighRoutePreference,
+ lifetimeS: 1,
+ prefixBytes: nil,
+ expectedErr: ErrNDPOptMalformedBody,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ expectedRouteInformationBytes := [...]byte{
+ // Type, Length
+ 24, test.length,
+
+ // Prefix Length, Prf
+ uint8(test.prefixLength), uint8(test.prf) << 3,
+
+ // Route Lifetime
+ 0, 0, 0, 0,
+
+ 0, 0, 0, 0,
+ 0, 0, 0, 0,
+ 0, 0, 0, 0,
+ 0, 0, 0, 0,
+ }
+ binary.BigEndian.PutUint32(expectedRouteInformationBytes[4:], test.lifetimeS)
+ _ = copy(expectedRouteInformationBytes[8:], test.prefixBytes)
+
+ opts := NDPOptions(expectedRouteInformationBytes[:test.length*lengthByteUnits])
+ it, err := opts.Iter(false)
+ if err != nil {
+ t.Fatalf("got Iter(false) = (_, %s), want = (_, nil)", err)
+ }
+ opt, done, err := it.Next()
+ if !errors.Is(err, test.expectedErr) {
+ t.Fatalf("got Next() = (_, _, %s), want = (_, _, %s)", err, test.expectedErr)
+ }
+ if want := test.expectedErr != nil; done != want {
+ t.Fatalf("got Next() = (_, %t, _), want = (_, %t, _)", done, want)
+ }
+ if test.expectedErr != nil {
+ return
+ }
+
+ if got := opt.kind(); got != ndpRouteInformationType {
+ t.Errorf("got kind() = %d, want = %d", got, ndpRouteInformationType)
+ }
+
+ ri, ok := opt.(NDPRouteInformation)
+ if !ok {
+ t.Fatalf("got opt = %T, want = NDPRouteInformation", opt)
+ }
+
+ if got := ri.PrefixLength(); got != test.prefixLength {
+ t.Errorf("got PrefixLength() = %d, want = %d", got, test.prefixLength)
+ }
+ if got := ri.RoutePreference(); got != test.prf {
+ t.Errorf("got RoutePreference() = %d, want = %d", got, test.prf)
+ }
+ if got, want := ri.RouteLifetime(), time.Duration(test.lifetimeS)*time.Second; got != want {
+ t.Errorf("got RouteLifetime() = %s, want = %s", got, want)
+ }
+ if got, err := ri.Prefix(); err != nil {
+ t.Errorf("Prefix(): %s", err)
+ } else if got != test.expectedPrefix {
+ t.Errorf("got Prefix() = %s, want = %s", got, test.expectedPrefix)
+ }
+
+ // Iterator should not return anything else.
+ {
+ next, done, err := it.Next()
+ if err != nil {
+ t.Errorf("got Next() = (_, _, %s), want = (_, _, nil)", err)
+ }
+ if !done {
+ t.Error("got Next() = (_, false, _), want = (_, true, _)")
+ }
+ if next != nil {
+ t.Errorf("got Next() = (%x, _, _), want = (nil, _, _)", next)
+ }
+ }
+ })
+ }
+}
+
// TestNDPNeighborAdvert tests the functions of NDPNeighborAdvert.
func TestNDPNeighborAdvert(t *testing.T) {
b := []byte{
@@ -1498,3 +1717,32 @@ func TestNDPOptionsIter(t *testing.T) {
t.Errorf("got Next = (%x, _, _), want = (nil, _, _)", next)
}
}
+
+func TestNDPRoutePreferenceStringer(t *testing.T) {
+ p := NDPRoutePreference(0)
+ for {
+ var wantStr string
+ switch p {
+ case 0b01:
+ wantStr = "HighRoutePreference"
+ case 0b00:
+ wantStr = "MediumRoutePreference"
+ case 0b11:
+ wantStr = "LowRoutePreference"
+ case 0b10:
+ wantStr = "ReservedRoutePreference"
+ default:
+ wantStr = fmt.Sprintf("NDPRoutePreference(%d)", p)
+ }
+
+ if gotStr := p.String(); gotStr != wantStr {
+ t.Errorf("got NDPRoutePreference(%d).String() = %s, want = %s", p, gotStr, wantStr)
+ }
+
+ p++
+ if p == 0 {
+ // Overflowed, we hit all values.
+ break
+ }
+ }
+}
diff --git a/pkg/tcpip/header/tcp.go b/pkg/tcpip/header/tcp.go
index 8dabe3354..a75e51a28 100644
--- a/pkg/tcpip/header/tcp.go
+++ b/pkg/tcpip/header/tcp.go
@@ -390,6 +390,35 @@ func (b TCP) EncodePartial(partialChecksum, length uint16, seqnum, acknum uint32
b.SetChecksum(^checksum)
}
+// SetSourcePortWithChecksumUpdate implements ChecksummableTransport.
+func (b TCP) SetSourcePortWithChecksumUpdate(new uint16) {
+ old := b.SourcePort()
+ b.SetSourcePort(new)
+ b.SetChecksum(^checksumUpdate2ByteAlignedUint16(^b.Checksum(), old, new))
+}
+
+// SetDestinationPortWithChecksumUpdate implements ChecksummableTransport.
+func (b TCP) SetDestinationPortWithChecksumUpdate(new uint16) {
+ old := b.DestinationPort()
+ b.SetDestinationPort(new)
+ b.SetChecksum(^checksumUpdate2ByteAlignedUint16(^b.Checksum(), old, new))
+}
+
+// UpdateChecksumPseudoHeaderAddress implements ChecksummableTransport.
+func (b TCP) UpdateChecksumPseudoHeaderAddress(old, new tcpip.Address, fullChecksum bool) {
+ xsum := b.Checksum()
+ if fullChecksum {
+ xsum = ^xsum
+ }
+
+ xsum = checksumUpdate2ByteAlignedAddress(xsum, old, new)
+ if fullChecksum {
+ xsum = ^xsum
+ }
+
+ b.SetChecksum(xsum)
+}
+
// ParseSynOptions parses the options received in a SYN segment and returns the
// relevant ones. opts should point to the option part of the TCP header.
func ParseSynOptions(opts []byte, isAck bool) TCPSynOptions {
diff --git a/pkg/tcpip/header/udp.go b/pkg/tcpip/header/udp.go
index ae9d167ff..f69d53314 100644
--- a/pkg/tcpip/header/udp.go
+++ b/pkg/tcpip/header/udp.go
@@ -130,3 +130,32 @@ func (b UDP) Encode(u *UDPFields) {
binary.BigEndian.PutUint16(b[udpLength:], u.Length)
binary.BigEndian.PutUint16(b[udpChecksum:], u.Checksum)
}
+
+// SetSourcePortWithChecksumUpdate implements ChecksummableTransport.
+func (b UDP) SetSourcePortWithChecksumUpdate(new uint16) {
+ old := b.SourcePort()
+ b.SetSourcePort(new)
+ b.SetChecksum(^checksumUpdate2ByteAlignedUint16(^b.Checksum(), old, new))
+}
+
+// SetDestinationPortWithChecksumUpdate implements ChecksummableTransport.
+func (b UDP) SetDestinationPortWithChecksumUpdate(new uint16) {
+ old := b.DestinationPort()
+ b.SetDestinationPort(new)
+ b.SetChecksum(^checksumUpdate2ByteAlignedUint16(^b.Checksum(), old, new))
+}
+
+// UpdateChecksumPseudoHeaderAddress implements ChecksummableTransport.
+func (b UDP) UpdateChecksumPseudoHeaderAddress(old, new tcpip.Address, fullChecksum bool) {
+ xsum := b.Checksum()
+ if fullChecksum {
+ xsum = ^xsum
+ }
+
+ xsum = checksumUpdate2ByteAlignedAddress(xsum, old, new)
+ if fullChecksum {
+ xsum = ^xsum
+ }
+
+ b.SetChecksum(xsum)
+}