diff options
Diffstat (limited to 'pkg/tcpip/header')
-rw-r--r-- | pkg/tcpip/header/checksum.go | 62 | ||||
-rw-r--r-- | pkg/tcpip/header/checksum_test.go | 203 | ||||
-rw-r--r-- | pkg/tcpip/header/interfaces.go | 38 | ||||
-rw-r--r-- | pkg/tcpip/header/ipv4.go | 12 | ||||
-rw-r--r-- | pkg/tcpip/header/ndp_options.go | 145 | ||||
-rw-r--r-- | pkg/tcpip/header/ndp_router_advert.go | 19 | ||||
-rw-r--r-- | pkg/tcpip/header/ndp_test.go | 248 | ||||
-rw-r--r-- | pkg/tcpip/header/tcp.go | 29 | ||||
-rw-r--r-- | pkg/tcpip/header/udp.go | 29 |
9 files changed, 785 insertions, 0 deletions
diff --git a/pkg/tcpip/header/checksum.go b/pkg/tcpip/header/checksum.go index 6aa9acfa8..e2c85e220 100644 --- a/pkg/tcpip/header/checksum.go +++ b/pkg/tcpip/header/checksum.go @@ -18,6 +18,7 @@ package header import ( "encoding/binary" + "fmt" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" @@ -234,3 +235,64 @@ func PseudoHeaderChecksum(protocol tcpip.TransportProtocolNumber, srcAddr tcpip. return Checksum([]byte{0, uint8(protocol)}, xsum) } + +// checksumUpdate2ByteAlignedUint16 updates a uint16 value in a calculated +// checksum. +// +// The value MUST begin at a 2-byte boundary in the original buffer. +func checksumUpdate2ByteAlignedUint16(xsum, old, new uint16) uint16 { + // As per RFC 1071 page 4, + // (4) Incremental Update + // + // ... + // + // To update the checksum, simply add the differences of the + // sixteen bit integers that have been changed. To see why this + // works, observe that every 16-bit integer has an additive inverse + // and that addition is associative. From this it follows that + // given the original value m, the new value m', and the old + // checksum C, the new checksum C' is: + // + // C' = C + (-m) + m' = C + (m' - m) + return ChecksumCombine(xsum, ChecksumCombine(new, ^old)) +} + +// checksumUpdate2ByteAlignedAddress updates an address in a calculated +// checksum. +// +// The addresses must have the same length and must contain an even number +// of bytes. The address MUST begin at a 2-byte boundary in the original buffer. +func checksumUpdate2ByteAlignedAddress(xsum uint16, old, new tcpip.Address) uint16 { + const uint16Bytes = 2 + + if len(old) != len(new) { + panic(fmt.Sprintf("buffer lengths are different; old = %d, new = %d", len(old), len(new))) + } + + if len(old)%uint16Bytes != 0 { + panic(fmt.Sprintf("buffer has an odd number of bytes; got = %d", len(old))) + } + + // As per RFC 1071 page 4, + // (4) Incremental Update + // + // ... + // + // To update the checksum, simply add the differences of the + // sixteen bit integers that have been changed. To see why this + // works, observe that every 16-bit integer has an additive inverse + // and that addition is associative. From this it follows that + // given the original value m, the new value m', and the old + // checksum C, the new checksum C' is: + // + // C' = C + (-m) + m' = C + (m' - m) + for len(old) != 0 { + // Convert the 2 byte sequences to uint16 values then apply the increment + // update. + xsum = checksumUpdate2ByteAlignedUint16(xsum, (uint16(old[0])<<8)+uint16(old[1]), (uint16(new[0])<<8)+uint16(new[1])) + old = old[uint16Bytes:] + new = new[uint16Bytes:] + } + + return xsum +} diff --git a/pkg/tcpip/header/checksum_test.go b/pkg/tcpip/header/checksum_test.go index d267dabd0..3445511f4 100644 --- a/pkg/tcpip/header/checksum_test.go +++ b/pkg/tcpip/header/checksum_test.go @@ -23,6 +23,7 @@ import ( "sync" "testing" + "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" ) @@ -256,3 +257,205 @@ func TestICMPv6Checksum(t *testing.T) { }) }, want, fmt.Sprintf("header: {% x} data {% x}", h, vv.ToView())) } + +func randomAddress(size int) tcpip.Address { + s := make([]byte, size) + for i := 0; i < size; i++ { + s[i] = byte(rand.Uint32()) + } + return tcpip.Address(s) +} + +func TestChecksummableNetworkUpdateAddress(t *testing.T) { + tests := []struct { + name string + update func(header.IPv4, tcpip.Address) + }{ + { + name: "SetSourceAddressWithChecksumUpdate", + update: header.IPv4.SetSourceAddressWithChecksumUpdate, + }, + { + name: "SetDestinationAddressWithChecksumUpdate", + update: header.IPv4.SetDestinationAddressWithChecksumUpdate, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + for i := 0; i < 1000; i++ { + var origBytes [header.IPv4MinimumSize]byte + header.IPv4(origBytes[:]).Encode(&header.IPv4Fields{ + TOS: 1, + TotalLength: header.IPv4MinimumSize, + ID: 2, + Flags: 3, + FragmentOffset: 4, + TTL: 5, + Protocol: 6, + Checksum: 0, + SrcAddr: randomAddress(header.IPv4AddressSize), + DstAddr: randomAddress(header.IPv4AddressSize), + }) + + addr := randomAddress(header.IPv4AddressSize) + + bytesCopy := origBytes + h := header.IPv4(bytesCopy[:]) + origXSum := h.CalculateChecksum() + h.SetChecksum(^origXSum) + + test.update(h, addr) + got := ^h.Checksum() + h.SetChecksum(0) + want := h.CalculateChecksum() + if got != want { + t.Errorf("got h.Checksum() = 0x%x, want = 0x%x; originalBytes = 0x%x, new addr = %s", got, want, origBytes, addr) + } + } + }) + } +} + +func TestChecksummableTransportUpdatePort(t *testing.T) { + // The fields in the pseudo header is not tested here so we just use 0. + const pseudoHeaderXSum = 0 + + tests := []struct { + name string + transportHdr func(_, _ uint16) (header.ChecksummableTransport, func(uint16) uint16) + proto tcpip.TransportProtocolNumber + }{ + { + name: "TCP", + transportHdr: func(src, dst uint16) (header.ChecksummableTransport, func(uint16) uint16) { + h := header.TCP(make([]byte, header.TCPMinimumSize)) + h.Encode(&header.TCPFields{ + SrcPort: src, + DstPort: dst, + SeqNum: 1, + AckNum: 2, + DataOffset: header.TCPMinimumSize, + Flags: 3, + WindowSize: 4, + Checksum: 0, + UrgentPointer: 5, + }) + h.SetChecksum(^h.CalculateChecksum(pseudoHeaderXSum)) + return h, h.CalculateChecksum + }, + proto: header.TCPProtocolNumber, + }, + { + name: "UDP", + transportHdr: func(src, dst uint16) (header.ChecksummableTransport, func(uint16) uint16) { + h := header.UDP(make([]byte, header.UDPMinimumSize)) + h.Encode(&header.UDPFields{ + SrcPort: src, + DstPort: dst, + Length: 0, + Checksum: 0, + }) + h.SetChecksum(^h.CalculateChecksum(pseudoHeaderXSum)) + return h, h.CalculateChecksum + }, + proto: header.UDPProtocolNumber, + }, + } + + for i := 0; i < 1000; i++ { + origSrcPort := uint16(rand.Uint32()) + origDstPort := uint16(rand.Uint32()) + newPort := uint16(rand.Uint32()) + + t.Run(fmt.Sprintf("OrigSrcPort=%d,OrigDstPort=%d,NewPort=%d", origSrcPort, origDstPort, newPort), func(*testing.T) { + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + for _, subTest := range []struct { + name string + update func(header.ChecksummableTransport) + }{ + { + name: "Source port", + update: func(h header.ChecksummableTransport) { h.SetSourcePortWithChecksumUpdate(newPort) }, + }, + { + name: "Destination port", + update: func(h header.ChecksummableTransport) { h.SetDestinationPortWithChecksumUpdate(newPort) }, + }, + } { + t.Run(subTest.name, func(t *testing.T) { + h, calcXSum := test.transportHdr(origSrcPort, origDstPort) + subTest.update(h) + // TCP and UDP hold the 1s complement of the fully calculated + // checksum. + got := ^h.Checksum() + h.SetChecksum(0) + + if want := calcXSum(pseudoHeaderXSum); got != want { + h, _ := test.transportHdr(origSrcPort, origDstPort) + t.Errorf("got Checksum() = 0x%x, want = 0x%x; originalBytes = %#v, new port = %d", got, want, h, newPort) + } + }) + } + }) + } + }) + } +} + +func TestChecksummableTransportUpdatePseudoHeaderAddress(t *testing.T) { + const addressSize = 6 + + tests := []struct { + name string + transportHdr func() header.ChecksummableTransport + proto tcpip.TransportProtocolNumber + }{ + { + name: "TCP", + transportHdr: func() header.ChecksummableTransport { return header.TCP(make([]byte, header.TCPMinimumSize)) }, + proto: header.TCPProtocolNumber, + }, + { + name: "UDP", + transportHdr: func() header.ChecksummableTransport { return header.UDP(make([]byte, header.UDPMinimumSize)) }, + proto: header.UDPProtocolNumber, + }, + } + + for i := 0; i < 1000; i++ { + permanent := randomAddress(addressSize) + old := randomAddress(addressSize) + new := randomAddress(addressSize) + + t.Run(fmt.Sprintf("Permanent=%q,Old=%q,New=%q", permanent, old, new), func(t *testing.T) { + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + for _, fullChecksum := range []bool{true, false} { + t.Run(fmt.Sprintf("FullChecksum=%t", fullChecksum), func(t *testing.T) { + initialXSum := header.PseudoHeaderChecksum(test.proto, permanent, old, 0) + if fullChecksum { + // TCP and UDP hold the 1s complement of the fully calculated + // checksum. + initialXSum = ^initialXSum + } + + h := test.transportHdr() + h.SetChecksum(initialXSum) + h.UpdateChecksumPseudoHeaderAddress(old, new, fullChecksum) + + got := h.Checksum() + if fullChecksum { + got = ^got + } + if want := header.PseudoHeaderChecksum(test.proto, permanent, new, 0); got != want { + t.Errorf("got Checksum() = 0x%x, want = 0x%x; h = %#v", got, want, h) + } + }) + } + }) + } + }) + } +} diff --git a/pkg/tcpip/header/interfaces.go b/pkg/tcpip/header/interfaces.go index 861cbbb70..3a41adfc4 100644 --- a/pkg/tcpip/header/interfaces.go +++ b/pkg/tcpip/header/interfaces.go @@ -53,6 +53,31 @@ type Transport interface { Payload() []byte } +// ChecksummableTransport is a Transport that supports checksumming. +type ChecksummableTransport interface { + Transport + + // SetSourcePortWithChecksumUpdate sets the source port and updates + // the checksum. + // + // The receiver's checksum must be a fully calculated checksum. + SetSourcePortWithChecksumUpdate(port uint16) + + // SetDestinationPortWithChecksumUpdate sets the destination port and updates + // the checksum. + // + // The receiver's checksum must be a fully calculated checksum. + SetDestinationPortWithChecksumUpdate(port uint16) + + // UpdateChecksumPseudoHeaderAddress updates the checksum to reflect an + // updated address in the pseudo header. + // + // If fullChecksum is true, the receiver's checksum field is assumed to hold a + // fully calculated checksum. Otherwise, it is assumed to hold a partially + // calculated checksum which only reflects the pseudo header. + UpdateChecksumPseudoHeaderAddress(old, new tcpip.Address, fullChecksum bool) +} + // Network offers generic methods to query and/or update the fields of the // header of a network protocol buffer. type Network interface { @@ -90,3 +115,16 @@ type Network interface { // SetTOS sets the values of the "type of service" and "flow label" fields. SetTOS(t uint8, l uint32) } + +// ChecksummableNetwork is a Network that supports checksumming. +type ChecksummableNetwork interface { + Network + + // SetSourceAddressAndChecksum sets the source address and updates the + // checksum to reflect the new address. + SetSourceAddressWithChecksumUpdate(tcpip.Address) + + // SetDestinationAddressAndChecksum sets the destination address and + // updates the checksum to reflect the new address. + SetDestinationAddressWithChecksumUpdate(tcpip.Address) +} diff --git a/pkg/tcpip/header/ipv4.go b/pkg/tcpip/header/ipv4.go index e9abbb709..dcc549c7b 100644 --- a/pkg/tcpip/header/ipv4.go +++ b/pkg/tcpip/header/ipv4.go @@ -305,6 +305,18 @@ func (b IPv4) DestinationAddress() tcpip.Address { return tcpip.Address(b[dstAddr : dstAddr+IPv4AddressSize]) } +// SetSourceAddressWithChecksumUpdate implements ChecksummableNetwork. +func (b IPv4) SetSourceAddressWithChecksumUpdate(new tcpip.Address) { + b.SetChecksum(^checksumUpdate2ByteAlignedAddress(^b.Checksum(), b.SourceAddress(), new)) + b.SetSourceAddress(new) +} + +// SetDestinationAddressWithChecksumUpdate implements ChecksummableNetwork. +func (b IPv4) SetDestinationAddressWithChecksumUpdate(new tcpip.Address) { + b.SetChecksum(^checksumUpdate2ByteAlignedAddress(^b.Checksum(), b.DestinationAddress(), new)) + b.SetDestinationAddress(new) +} + // padIPv4OptionsLength returns the total length for IPv4 options of length l // after applying padding according to RFC 791: // The internet header padding is used to ensure that the internet diff --git a/pkg/tcpip/header/ndp_options.go b/pkg/tcpip/header/ndp_options.go index b1f39e6e6..a647ea968 100644 --- a/pkg/tcpip/header/ndp_options.go +++ b/pkg/tcpip/header/ndp_options.go @@ -233,6 +233,17 @@ func (i *NDPOptionIterator) Next() (NDPOption, bool, error) { case ndpNonceOptionType: return NDPNonceOption(body), false, nil + case ndpRouteInformationType: + if numBodyBytes > ndpRouteInformationMaxLength { + return nil, true, fmt.Errorf("got %d bytes for NDP Route Information option's body, expected at max %d bytes: %w", numBodyBytes, ndpRouteInformationMaxLength, ErrNDPOptMalformedBody) + } + opt := NDPRouteInformation(body) + if err := opt.hasError(); err != nil { + return nil, true, err + } + + return opt, false, nil + case ndpPrefixInformationType: // Make sure the length of a Prefix Information option // body is ndpPrefixInformationLength, as per RFC 4861 @@ -930,3 +941,137 @@ func isUpperLetter(b byte) bool { func isDigit(b byte) bool { return b >= '0' && b <= '9' } + +// As per RFC 4191 section 2.3, +// +// 2.3. Route Information Option +// +// 0 1 2 3 +// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | Type | Length | Prefix Length |Resvd|Prf|Resvd| +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | Route Lifetime | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | Prefix (Variable Length) | +// . . +// . . +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// +// Fields: +// +// Type 24 +// +// +// Length 8-bit unsigned integer. The length of the option +// (including the Type and Length fields) in units of 8 +// octets. The Length field is 1, 2, or 3 depending on the +// Prefix Length. If Prefix Length is greater than 64, then +// Length must be 3. If Prefix Length is greater than 0, +// then Length must be 2 or 3. If Prefix Length is zero, +// then Length must be 1, 2, or 3. +const ( + ndpRouteInformationType = ndpOptionIdentifier(24) + ndpRouteInformationMaxLength = 22 + + ndpRouteInformationPrefixLengthIdx = 0 + ndpRouteInformationFlagsIdx = 1 + ndpRouteInformationPrfShift = 3 + ndpRouteInformationPrfMask = 3 << ndpRouteInformationPrfShift + ndpRouteInformationRouteLifetimeIdx = 2 + ndpRouteInformationRoutePrefixIdx = 6 +) + +// NDPRouteInformation is the NDP Router Information option, as defined by +// RFC 4191 section 2.3. +type NDPRouteInformation []byte + +func (NDPRouteInformation) kind() ndpOptionIdentifier { + return ndpRouteInformationType +} + +func (o NDPRouteInformation) length() int { + return len(o) +} + +func (o NDPRouteInformation) serializeInto(b []byte) int { + return copy(b, o) +} + +// String implements fmt.Stringer. +func (o NDPRouteInformation) String() string { + return fmt.Sprintf("%T", o) +} + +// PrefixLength returns the length of the prefix. +func (o NDPRouteInformation) PrefixLength() uint8 { + return o[ndpRouteInformationPrefixLengthIdx] +} + +// RoutePreference returns the preference of the route over other routes to the +// same destination but through a different router. +func (o NDPRouteInformation) RoutePreference() NDPRoutePreference { + return NDPRoutePreference((o[ndpRouteInformationFlagsIdx] & ndpRouteInformationPrfMask) >> ndpRouteInformationPrfShift) +} + +// RouteLifetime returns the lifetime of the route. +// +// Note, a value of 0 implies the route is now invalid and a value of +// infinity/forever is represented by NDPInfiniteLifetime. +func (o NDPRouteInformation) RouteLifetime() time.Duration { + return time.Second * time.Duration(binary.BigEndian.Uint32(o[ndpRouteInformationRouteLifetimeIdx:])) +} + +// Prefix returns the prefix of the destination subnet this route is for. +func (o NDPRouteInformation) Prefix() (tcpip.Subnet, error) { + prefixLength := int(o.PrefixLength()) + if max := IPv6AddressSize * 8; prefixLength > max { + return tcpip.Subnet{}, fmt.Errorf("got prefix length = %d, want <= %d", prefixLength, max) + } + + prefix := o[ndpRouteInformationRoutePrefixIdx:] + var addrBytes [IPv6AddressSize]byte + if n := copy(addrBytes[:], prefix); n != len(prefix) { + panic(fmt.Sprintf("got copy(addrBytes, prefix) = %d, want = %d", n, len(prefix))) + } + + return tcpip.AddressWithPrefix{ + Address: tcpip.Address(addrBytes[:]), + PrefixLen: prefixLength, + }.Subnet(), nil +} + +func (o NDPRouteInformation) hasError() error { + l := len(o) + if l < ndpRouteInformationRoutePrefixIdx { + return fmt.Errorf("%T too small, got = %d bytes: %w", o, l, ErrNDPOptMalformedBody) + } + + prefixLength := int(o.PrefixLength()) + if max := IPv6AddressSize * 8; prefixLength > max { + return fmt.Errorf("got prefix length = %d, want <= %d: %w", prefixLength, max, ErrNDPOptMalformedBody) + } + + // Length 8-bit unsigned integer. The length of the option + // (including the Type and Length fields) in units of 8 + // octets. The Length field is 1, 2, or 3 depending on the + // Prefix Length. If Prefix Length is greater than 64, then + // Length must be 3. If Prefix Length is greater than 0, + // then Length must be 2 or 3. If Prefix Length is zero, + // then Length must be 1, 2, or 3. + l += 2 // Add 2 bytes for the type and length bytes. + lengthField := l / lengthByteUnits + if prefixLength > 64 { + if lengthField != 3 { + return fmt.Errorf("Length field must be 3 when Prefix Length (%d) is > 64 (got = %d): %w", prefixLength, lengthField, ErrNDPOptMalformedBody) + } + } else if prefixLength > 0 { + if lengthField != 2 && lengthField != 3 { + return fmt.Errorf("Length field must be 2 or 3 when Prefix Length (%d) is between 0 and 64 (got = %d): %w", prefixLength, lengthField, ErrNDPOptMalformedBody) + } + } else if lengthField == 0 || lengthField > 3 { + return fmt.Errorf("Length field must be 1, 2, or 3 when Prefix Length is zero (got = %d): %w", lengthField, ErrNDPOptMalformedBody) + } + + return nil +} diff --git a/pkg/tcpip/header/ndp_router_advert.go b/pkg/tcpip/header/ndp_router_advert.go index 7e2f0c797..7d6efa083 100644 --- a/pkg/tcpip/header/ndp_router_advert.go +++ b/pkg/tcpip/header/ndp_router_advert.go @@ -16,9 +16,12 @@ package header import ( "encoding/binary" + "fmt" "time" ) +var _ fmt.Stringer = NDPRoutePreference(0) + // NDPRoutePreference is the preference values for default routers or // more-specific routes. // @@ -64,6 +67,22 @@ const ( ReservedRoutePreference = 0b10 ) +// String implements fmt.Stringer. +func (p NDPRoutePreference) String() string { + switch p { + case HighRoutePreference: + return "HighRoutePreference" + case MediumRoutePreference: + return "MediumRoutePreference" + case LowRoutePreference: + return "LowRoutePreference" + case ReservedRoutePreference: + return "ReservedRoutePreference" + default: + return fmt.Sprintf("NDPRoutePreference(%d)", p) + } +} + // NDPRouterAdvert is an NDP Router Advertisement message. It will only contain // the body of an ICMPv6 packet. // diff --git a/pkg/tcpip/header/ndp_test.go b/pkg/tcpip/header/ndp_test.go index 8fd1f7d13..2a897e938 100644 --- a/pkg/tcpip/header/ndp_test.go +++ b/pkg/tcpip/header/ndp_test.go @@ -21,6 +21,7 @@ import ( "fmt" "io" "regexp" + "strings" "testing" "time" @@ -58,6 +59,224 @@ func TestNDPNeighborSolicit(t *testing.T) { } } +func TestNDPRouteInformationOption(t *testing.T) { + tests := []struct { + name string + + length uint8 + prefixLength uint8 + prf NDPRoutePreference + lifetimeS uint32 + prefixBytes []byte + expectedPrefix tcpip.Subnet + + expectedErr error + }{ + { + name: "Length=1 with Prefix Length = 0", + length: 1, + prefixLength: 0, + prf: MediumRoutePreference, + lifetimeS: 1, + prefixBytes: nil, + expectedPrefix: IPv6EmptySubnet, + }, + { + name: "Length=1 but Prefix Length > 0", + length: 1, + prefixLength: 1, + prf: MediumRoutePreference, + lifetimeS: 1, + prefixBytes: nil, + expectedErr: ErrNDPOptMalformedBody, + }, + { + name: "Length=2 with Prefix Length = 0", + length: 2, + prefixLength: 0, + prf: MediumRoutePreference, + lifetimeS: 1, + prefixBytes: nil, + expectedPrefix: IPv6EmptySubnet, + }, + { + name: "Length=2 with Prefix Length in [1, 64] (1)", + length: 2, + prefixLength: 1, + prf: LowRoutePreference, + lifetimeS: 1, + prefixBytes: nil, + expectedPrefix: tcpip.AddressWithPrefix{ + Address: tcpip.Address(strings.Repeat("\x00", IPv6AddressSize)), + PrefixLen: 1, + }.Subnet(), + }, + { + name: "Length=2 with Prefix Length in [1, 64] (64)", + length: 2, + prefixLength: 64, + prf: HighRoutePreference, + lifetimeS: 1, + prefixBytes: nil, + expectedPrefix: tcpip.AddressWithPrefix{ + Address: tcpip.Address(strings.Repeat("\x00", IPv6AddressSize)), + PrefixLen: 64, + }.Subnet(), + }, + { + name: "Length=2 with Prefix Length > 64", + length: 2, + prefixLength: 65, + prf: HighRoutePreference, + lifetimeS: 1, + prefixBytes: nil, + expectedErr: ErrNDPOptMalformedBody, + }, + { + name: "Length=3 with Prefix Length = 0", + length: 3, + prefixLength: 0, + prf: MediumRoutePreference, + lifetimeS: 1, + prefixBytes: nil, + expectedPrefix: IPv6EmptySubnet, + }, + { + name: "Length=3 with Prefix Length in [1, 64] (1)", + length: 3, + prefixLength: 1, + prf: LowRoutePreference, + lifetimeS: 1, + prefixBytes: nil, + expectedPrefix: tcpip.AddressWithPrefix{ + Address: tcpip.Address(strings.Repeat("\x00", IPv6AddressSize)), + PrefixLen: 1, + }.Subnet(), + }, + { + name: "Length=3 with Prefix Length in [1, 64] (64)", + length: 3, + prefixLength: 64, + prf: HighRoutePreference, + lifetimeS: 1, + prefixBytes: nil, + expectedPrefix: tcpip.AddressWithPrefix{ + Address: tcpip.Address(strings.Repeat("\x00", IPv6AddressSize)), + PrefixLen: 64, + }.Subnet(), + }, + { + name: "Length=3 with Prefix Length in [65, 128] (65)", + length: 3, + prefixLength: 65, + prf: HighRoutePreference, + lifetimeS: 1, + prefixBytes: nil, + expectedPrefix: tcpip.AddressWithPrefix{ + Address: tcpip.Address(strings.Repeat("\x00", IPv6AddressSize)), + PrefixLen: 65, + }.Subnet(), + }, + { + name: "Length=3 with Prefix Length in [65, 128] (128)", + length: 3, + prefixLength: 128, + prf: HighRoutePreference, + lifetimeS: 1, + prefixBytes: nil, + expectedPrefix: tcpip.AddressWithPrefix{ + Address: tcpip.Address(strings.Repeat("\x00", IPv6AddressSize)), + PrefixLen: 128, + }.Subnet(), + }, + { + name: "Length=3 with (invalid) Prefix Length > 128", + length: 3, + prefixLength: 129, + prf: HighRoutePreference, + lifetimeS: 1, + prefixBytes: nil, + expectedErr: ErrNDPOptMalformedBody, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + expectedRouteInformationBytes := [...]byte{ + // Type, Length + 24, test.length, + + // Prefix Length, Prf + uint8(test.prefixLength), uint8(test.prf) << 3, + + // Route Lifetime + 0, 0, 0, 0, + + 0, 0, 0, 0, + 0, 0, 0, 0, + 0, 0, 0, 0, + 0, 0, 0, 0, + } + binary.BigEndian.PutUint32(expectedRouteInformationBytes[4:], test.lifetimeS) + _ = copy(expectedRouteInformationBytes[8:], test.prefixBytes) + + opts := NDPOptions(expectedRouteInformationBytes[:test.length*lengthByteUnits]) + it, err := opts.Iter(false) + if err != nil { + t.Fatalf("got Iter(false) = (_, %s), want = (_, nil)", err) + } + opt, done, err := it.Next() + if !errors.Is(err, test.expectedErr) { + t.Fatalf("got Next() = (_, _, %s), want = (_, _, %s)", err, test.expectedErr) + } + if want := test.expectedErr != nil; done != want { + t.Fatalf("got Next() = (_, %t, _), want = (_, %t, _)", done, want) + } + if test.expectedErr != nil { + return + } + + if got := opt.kind(); got != ndpRouteInformationType { + t.Errorf("got kind() = %d, want = %d", got, ndpRouteInformationType) + } + + ri, ok := opt.(NDPRouteInformation) + if !ok { + t.Fatalf("got opt = %T, want = NDPRouteInformation", opt) + } + + if got := ri.PrefixLength(); got != test.prefixLength { + t.Errorf("got PrefixLength() = %d, want = %d", got, test.prefixLength) + } + if got := ri.RoutePreference(); got != test.prf { + t.Errorf("got RoutePreference() = %d, want = %d", got, test.prf) + } + if got, want := ri.RouteLifetime(), time.Duration(test.lifetimeS)*time.Second; got != want { + t.Errorf("got RouteLifetime() = %s, want = %s", got, want) + } + if got, err := ri.Prefix(); err != nil { + t.Errorf("Prefix(): %s", err) + } else if got != test.expectedPrefix { + t.Errorf("got Prefix() = %s, want = %s", got, test.expectedPrefix) + } + + // Iterator should not return anything else. + { + next, done, err := it.Next() + if err != nil { + t.Errorf("got Next() = (_, _, %s), want = (_, _, nil)", err) + } + if !done { + t.Error("got Next() = (_, false, _), want = (_, true, _)") + } + if next != nil { + t.Errorf("got Next() = (%x, _, _), want = (nil, _, _)", next) + } + } + }) + } +} + // TestNDPNeighborAdvert tests the functions of NDPNeighborAdvert. func TestNDPNeighborAdvert(t *testing.T) { b := []byte{ @@ -1498,3 +1717,32 @@ func TestNDPOptionsIter(t *testing.T) { t.Errorf("got Next = (%x, _, _), want = (nil, _, _)", next) } } + +func TestNDPRoutePreferenceStringer(t *testing.T) { + p := NDPRoutePreference(0) + for { + var wantStr string + switch p { + case 0b01: + wantStr = "HighRoutePreference" + case 0b00: + wantStr = "MediumRoutePreference" + case 0b11: + wantStr = "LowRoutePreference" + case 0b10: + wantStr = "ReservedRoutePreference" + default: + wantStr = fmt.Sprintf("NDPRoutePreference(%d)", p) + } + + if gotStr := p.String(); gotStr != wantStr { + t.Errorf("got NDPRoutePreference(%d).String() = %s, want = %s", p, gotStr, wantStr) + } + + p++ + if p == 0 { + // Overflowed, we hit all values. + break + } + } +} diff --git a/pkg/tcpip/header/tcp.go b/pkg/tcpip/header/tcp.go index 8dabe3354..a75e51a28 100644 --- a/pkg/tcpip/header/tcp.go +++ b/pkg/tcpip/header/tcp.go @@ -390,6 +390,35 @@ func (b TCP) EncodePartial(partialChecksum, length uint16, seqnum, acknum uint32 b.SetChecksum(^checksum) } +// SetSourcePortWithChecksumUpdate implements ChecksummableTransport. +func (b TCP) SetSourcePortWithChecksumUpdate(new uint16) { + old := b.SourcePort() + b.SetSourcePort(new) + b.SetChecksum(^checksumUpdate2ByteAlignedUint16(^b.Checksum(), old, new)) +} + +// SetDestinationPortWithChecksumUpdate implements ChecksummableTransport. +func (b TCP) SetDestinationPortWithChecksumUpdate(new uint16) { + old := b.DestinationPort() + b.SetDestinationPort(new) + b.SetChecksum(^checksumUpdate2ByteAlignedUint16(^b.Checksum(), old, new)) +} + +// UpdateChecksumPseudoHeaderAddress implements ChecksummableTransport. +func (b TCP) UpdateChecksumPseudoHeaderAddress(old, new tcpip.Address, fullChecksum bool) { + xsum := b.Checksum() + if fullChecksum { + xsum = ^xsum + } + + xsum = checksumUpdate2ByteAlignedAddress(xsum, old, new) + if fullChecksum { + xsum = ^xsum + } + + b.SetChecksum(xsum) +} + // ParseSynOptions parses the options received in a SYN segment and returns the // relevant ones. opts should point to the option part of the TCP header. func ParseSynOptions(opts []byte, isAck bool) TCPSynOptions { diff --git a/pkg/tcpip/header/udp.go b/pkg/tcpip/header/udp.go index ae9d167ff..f69d53314 100644 --- a/pkg/tcpip/header/udp.go +++ b/pkg/tcpip/header/udp.go @@ -130,3 +130,32 @@ func (b UDP) Encode(u *UDPFields) { binary.BigEndian.PutUint16(b[udpLength:], u.Length) binary.BigEndian.PutUint16(b[udpChecksum:], u.Checksum) } + +// SetSourcePortWithChecksumUpdate implements ChecksummableTransport. +func (b UDP) SetSourcePortWithChecksumUpdate(new uint16) { + old := b.SourcePort() + b.SetSourcePort(new) + b.SetChecksum(^checksumUpdate2ByteAlignedUint16(^b.Checksum(), old, new)) +} + +// SetDestinationPortWithChecksumUpdate implements ChecksummableTransport. +func (b UDP) SetDestinationPortWithChecksumUpdate(new uint16) { + old := b.DestinationPort() + b.SetDestinationPort(new) + b.SetChecksum(^checksumUpdate2ByteAlignedUint16(^b.Checksum(), old, new)) +} + +// UpdateChecksumPseudoHeaderAddress implements ChecksummableTransport. +func (b UDP) UpdateChecksumPseudoHeaderAddress(old, new tcpip.Address, fullChecksum bool) { + xsum := b.Checksum() + if fullChecksum { + xsum = ^xsum + } + + xsum = checksumUpdate2ByteAlignedAddress(xsum, old, new) + if fullChecksum { + xsum = ^xsum + } + + b.SetChecksum(xsum) +} |