diff options
Diffstat (limited to 'pkg/tcpip')
56 files changed, 1859 insertions, 784 deletions
diff --git a/pkg/tcpip/BUILD b/pkg/tcpip/BUILD index ed4d7e958..f00cfd0f5 100644 --- a/pkg/tcpip/BUILD +++ b/pkg/tcpip/BUILD @@ -46,7 +46,6 @@ deps_test( "//pkg/gohacks", "//pkg/goid", "//pkg/ilist", - "//pkg/iovec", "//pkg/linewriter", "//pkg/log", "//pkg/rand", diff --git a/pkg/tcpip/header/checksum.go b/pkg/tcpip/header/checksum.go index 6aa9acfa8..e2c85e220 100644 --- a/pkg/tcpip/header/checksum.go +++ b/pkg/tcpip/header/checksum.go @@ -18,6 +18,7 @@ package header import ( "encoding/binary" + "fmt" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" @@ -234,3 +235,64 @@ func PseudoHeaderChecksum(protocol tcpip.TransportProtocolNumber, srcAddr tcpip. return Checksum([]byte{0, uint8(protocol)}, xsum) } + +// checksumUpdate2ByteAlignedUint16 updates a uint16 value in a calculated +// checksum. +// +// The value MUST begin at a 2-byte boundary in the original buffer. +func checksumUpdate2ByteAlignedUint16(xsum, old, new uint16) uint16 { + // As per RFC 1071 page 4, + // (4) Incremental Update + // + // ... + // + // To update the checksum, simply add the differences of the + // sixteen bit integers that have been changed. To see why this + // works, observe that every 16-bit integer has an additive inverse + // and that addition is associative. From this it follows that + // given the original value m, the new value m', and the old + // checksum C, the new checksum C' is: + // + // C' = C + (-m) + m' = C + (m' - m) + return ChecksumCombine(xsum, ChecksumCombine(new, ^old)) +} + +// checksumUpdate2ByteAlignedAddress updates an address in a calculated +// checksum. +// +// The addresses must have the same length and must contain an even number +// of bytes. The address MUST begin at a 2-byte boundary in the original buffer. +func checksumUpdate2ByteAlignedAddress(xsum uint16, old, new tcpip.Address) uint16 { + const uint16Bytes = 2 + + if len(old) != len(new) { + panic(fmt.Sprintf("buffer lengths are different; old = %d, new = %d", len(old), len(new))) + } + + if len(old)%uint16Bytes != 0 { + panic(fmt.Sprintf("buffer has an odd number of bytes; got = %d", len(old))) + } + + // As per RFC 1071 page 4, + // (4) Incremental Update + // + // ... + // + // To update the checksum, simply add the differences of the + // sixteen bit integers that have been changed. To see why this + // works, observe that every 16-bit integer has an additive inverse + // and that addition is associative. From this it follows that + // given the original value m, the new value m', and the old + // checksum C, the new checksum C' is: + // + // C' = C + (-m) + m' = C + (m' - m) + for len(old) != 0 { + // Convert the 2 byte sequences to uint16 values then apply the increment + // update. + xsum = checksumUpdate2ByteAlignedUint16(xsum, (uint16(old[0])<<8)+uint16(old[1]), (uint16(new[0])<<8)+uint16(new[1])) + old = old[uint16Bytes:] + new = new[uint16Bytes:] + } + + return xsum +} diff --git a/pkg/tcpip/header/checksum_test.go b/pkg/tcpip/header/checksum_test.go index d267dabd0..3445511f4 100644 --- a/pkg/tcpip/header/checksum_test.go +++ b/pkg/tcpip/header/checksum_test.go @@ -23,6 +23,7 @@ import ( "sync" "testing" + "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" ) @@ -256,3 +257,205 @@ func TestICMPv6Checksum(t *testing.T) { }) }, want, fmt.Sprintf("header: {% x} data {% x}", h, vv.ToView())) } + +func randomAddress(size int) tcpip.Address { + s := make([]byte, size) + for i := 0; i < size; i++ { + s[i] = byte(rand.Uint32()) + } + return tcpip.Address(s) +} + +func TestChecksummableNetworkUpdateAddress(t *testing.T) { + tests := []struct { + name string + update func(header.IPv4, tcpip.Address) + }{ + { + name: "SetSourceAddressWithChecksumUpdate", + update: header.IPv4.SetSourceAddressWithChecksumUpdate, + }, + { + name: "SetDestinationAddressWithChecksumUpdate", + update: header.IPv4.SetDestinationAddressWithChecksumUpdate, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + for i := 0; i < 1000; i++ { + var origBytes [header.IPv4MinimumSize]byte + header.IPv4(origBytes[:]).Encode(&header.IPv4Fields{ + TOS: 1, + TotalLength: header.IPv4MinimumSize, + ID: 2, + Flags: 3, + FragmentOffset: 4, + TTL: 5, + Protocol: 6, + Checksum: 0, + SrcAddr: randomAddress(header.IPv4AddressSize), + DstAddr: randomAddress(header.IPv4AddressSize), + }) + + addr := randomAddress(header.IPv4AddressSize) + + bytesCopy := origBytes + h := header.IPv4(bytesCopy[:]) + origXSum := h.CalculateChecksum() + h.SetChecksum(^origXSum) + + test.update(h, addr) + got := ^h.Checksum() + h.SetChecksum(0) + want := h.CalculateChecksum() + if got != want { + t.Errorf("got h.Checksum() = 0x%x, want = 0x%x; originalBytes = 0x%x, new addr = %s", got, want, origBytes, addr) + } + } + }) + } +} + +func TestChecksummableTransportUpdatePort(t *testing.T) { + // The fields in the pseudo header is not tested here so we just use 0. + const pseudoHeaderXSum = 0 + + tests := []struct { + name string + transportHdr func(_, _ uint16) (header.ChecksummableTransport, func(uint16) uint16) + proto tcpip.TransportProtocolNumber + }{ + { + name: "TCP", + transportHdr: func(src, dst uint16) (header.ChecksummableTransport, func(uint16) uint16) { + h := header.TCP(make([]byte, header.TCPMinimumSize)) + h.Encode(&header.TCPFields{ + SrcPort: src, + DstPort: dst, + SeqNum: 1, + AckNum: 2, + DataOffset: header.TCPMinimumSize, + Flags: 3, + WindowSize: 4, + Checksum: 0, + UrgentPointer: 5, + }) + h.SetChecksum(^h.CalculateChecksum(pseudoHeaderXSum)) + return h, h.CalculateChecksum + }, + proto: header.TCPProtocolNumber, + }, + { + name: "UDP", + transportHdr: func(src, dst uint16) (header.ChecksummableTransport, func(uint16) uint16) { + h := header.UDP(make([]byte, header.UDPMinimumSize)) + h.Encode(&header.UDPFields{ + SrcPort: src, + DstPort: dst, + Length: 0, + Checksum: 0, + }) + h.SetChecksum(^h.CalculateChecksum(pseudoHeaderXSum)) + return h, h.CalculateChecksum + }, + proto: header.UDPProtocolNumber, + }, + } + + for i := 0; i < 1000; i++ { + origSrcPort := uint16(rand.Uint32()) + origDstPort := uint16(rand.Uint32()) + newPort := uint16(rand.Uint32()) + + t.Run(fmt.Sprintf("OrigSrcPort=%d,OrigDstPort=%d,NewPort=%d", origSrcPort, origDstPort, newPort), func(*testing.T) { + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + for _, subTest := range []struct { + name string + update func(header.ChecksummableTransport) + }{ + { + name: "Source port", + update: func(h header.ChecksummableTransport) { h.SetSourcePortWithChecksumUpdate(newPort) }, + }, + { + name: "Destination port", + update: func(h header.ChecksummableTransport) { h.SetDestinationPortWithChecksumUpdate(newPort) }, + }, + } { + t.Run(subTest.name, func(t *testing.T) { + h, calcXSum := test.transportHdr(origSrcPort, origDstPort) + subTest.update(h) + // TCP and UDP hold the 1s complement of the fully calculated + // checksum. + got := ^h.Checksum() + h.SetChecksum(0) + + if want := calcXSum(pseudoHeaderXSum); got != want { + h, _ := test.transportHdr(origSrcPort, origDstPort) + t.Errorf("got Checksum() = 0x%x, want = 0x%x; originalBytes = %#v, new port = %d", got, want, h, newPort) + } + }) + } + }) + } + }) + } +} + +func TestChecksummableTransportUpdatePseudoHeaderAddress(t *testing.T) { + const addressSize = 6 + + tests := []struct { + name string + transportHdr func() header.ChecksummableTransport + proto tcpip.TransportProtocolNumber + }{ + { + name: "TCP", + transportHdr: func() header.ChecksummableTransport { return header.TCP(make([]byte, header.TCPMinimumSize)) }, + proto: header.TCPProtocolNumber, + }, + { + name: "UDP", + transportHdr: func() header.ChecksummableTransport { return header.UDP(make([]byte, header.UDPMinimumSize)) }, + proto: header.UDPProtocolNumber, + }, + } + + for i := 0; i < 1000; i++ { + permanent := randomAddress(addressSize) + old := randomAddress(addressSize) + new := randomAddress(addressSize) + + t.Run(fmt.Sprintf("Permanent=%q,Old=%q,New=%q", permanent, old, new), func(t *testing.T) { + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + for _, fullChecksum := range []bool{true, false} { + t.Run(fmt.Sprintf("FullChecksum=%t", fullChecksum), func(t *testing.T) { + initialXSum := header.PseudoHeaderChecksum(test.proto, permanent, old, 0) + if fullChecksum { + // TCP and UDP hold the 1s complement of the fully calculated + // checksum. + initialXSum = ^initialXSum + } + + h := test.transportHdr() + h.SetChecksum(initialXSum) + h.UpdateChecksumPseudoHeaderAddress(old, new, fullChecksum) + + got := h.Checksum() + if fullChecksum { + got = ^got + } + if want := header.PseudoHeaderChecksum(test.proto, permanent, new, 0); got != want { + t.Errorf("got Checksum() = 0x%x, want = 0x%x; h = %#v", got, want, h) + } + }) + } + }) + } + }) + } +} diff --git a/pkg/tcpip/header/interfaces.go b/pkg/tcpip/header/interfaces.go index 861cbbb70..3a41adfc4 100644 --- a/pkg/tcpip/header/interfaces.go +++ b/pkg/tcpip/header/interfaces.go @@ -53,6 +53,31 @@ type Transport interface { Payload() []byte } +// ChecksummableTransport is a Transport that supports checksumming. +type ChecksummableTransport interface { + Transport + + // SetSourcePortWithChecksumUpdate sets the source port and updates + // the checksum. + // + // The receiver's checksum must be a fully calculated checksum. + SetSourcePortWithChecksumUpdate(port uint16) + + // SetDestinationPortWithChecksumUpdate sets the destination port and updates + // the checksum. + // + // The receiver's checksum must be a fully calculated checksum. + SetDestinationPortWithChecksumUpdate(port uint16) + + // UpdateChecksumPseudoHeaderAddress updates the checksum to reflect an + // updated address in the pseudo header. + // + // If fullChecksum is true, the receiver's checksum field is assumed to hold a + // fully calculated checksum. Otherwise, it is assumed to hold a partially + // calculated checksum which only reflects the pseudo header. + UpdateChecksumPseudoHeaderAddress(old, new tcpip.Address, fullChecksum bool) +} + // Network offers generic methods to query and/or update the fields of the // header of a network protocol buffer. type Network interface { @@ -90,3 +115,16 @@ type Network interface { // SetTOS sets the values of the "type of service" and "flow label" fields. SetTOS(t uint8, l uint32) } + +// ChecksummableNetwork is a Network that supports checksumming. +type ChecksummableNetwork interface { + Network + + // SetSourceAddressAndChecksum sets the source address and updates the + // checksum to reflect the new address. + SetSourceAddressWithChecksumUpdate(tcpip.Address) + + // SetDestinationAddressAndChecksum sets the destination address and + // updates the checksum to reflect the new address. + SetDestinationAddressWithChecksumUpdate(tcpip.Address) +} diff --git a/pkg/tcpip/header/ipv4.go b/pkg/tcpip/header/ipv4.go index e9abbb709..dcc549c7b 100644 --- a/pkg/tcpip/header/ipv4.go +++ b/pkg/tcpip/header/ipv4.go @@ -305,6 +305,18 @@ func (b IPv4) DestinationAddress() tcpip.Address { return tcpip.Address(b[dstAddr : dstAddr+IPv4AddressSize]) } +// SetSourceAddressWithChecksumUpdate implements ChecksummableNetwork. +func (b IPv4) SetSourceAddressWithChecksumUpdate(new tcpip.Address) { + b.SetChecksum(^checksumUpdate2ByteAlignedAddress(^b.Checksum(), b.SourceAddress(), new)) + b.SetSourceAddress(new) +} + +// SetDestinationAddressWithChecksumUpdate implements ChecksummableNetwork. +func (b IPv4) SetDestinationAddressWithChecksumUpdate(new tcpip.Address) { + b.SetChecksum(^checksumUpdate2ByteAlignedAddress(^b.Checksum(), b.DestinationAddress(), new)) + b.SetDestinationAddress(new) +} + // padIPv4OptionsLength returns the total length for IPv4 options of length l // after applying padding according to RFC 791: // The internet header padding is used to ensure that the internet diff --git a/pkg/tcpip/header/ndp_options.go b/pkg/tcpip/header/ndp_options.go index b1f39e6e6..a647ea968 100644 --- a/pkg/tcpip/header/ndp_options.go +++ b/pkg/tcpip/header/ndp_options.go @@ -233,6 +233,17 @@ func (i *NDPOptionIterator) Next() (NDPOption, bool, error) { case ndpNonceOptionType: return NDPNonceOption(body), false, nil + case ndpRouteInformationType: + if numBodyBytes > ndpRouteInformationMaxLength { + return nil, true, fmt.Errorf("got %d bytes for NDP Route Information option's body, expected at max %d bytes: %w", numBodyBytes, ndpRouteInformationMaxLength, ErrNDPOptMalformedBody) + } + opt := NDPRouteInformation(body) + if err := opt.hasError(); err != nil { + return nil, true, err + } + + return opt, false, nil + case ndpPrefixInformationType: // Make sure the length of a Prefix Information option // body is ndpPrefixInformationLength, as per RFC 4861 @@ -930,3 +941,137 @@ func isUpperLetter(b byte) bool { func isDigit(b byte) bool { return b >= '0' && b <= '9' } + +// As per RFC 4191 section 2.3, +// +// 2.3. Route Information Option +// +// 0 1 2 3 +// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | Type | Length | Prefix Length |Resvd|Prf|Resvd| +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | Route Lifetime | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | Prefix (Variable Length) | +// . . +// . . +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// +// Fields: +// +// Type 24 +// +// +// Length 8-bit unsigned integer. The length of the option +// (including the Type and Length fields) in units of 8 +// octets. The Length field is 1, 2, or 3 depending on the +// Prefix Length. If Prefix Length is greater than 64, then +// Length must be 3. If Prefix Length is greater than 0, +// then Length must be 2 or 3. If Prefix Length is zero, +// then Length must be 1, 2, or 3. +const ( + ndpRouteInformationType = ndpOptionIdentifier(24) + ndpRouteInformationMaxLength = 22 + + ndpRouteInformationPrefixLengthIdx = 0 + ndpRouteInformationFlagsIdx = 1 + ndpRouteInformationPrfShift = 3 + ndpRouteInformationPrfMask = 3 << ndpRouteInformationPrfShift + ndpRouteInformationRouteLifetimeIdx = 2 + ndpRouteInformationRoutePrefixIdx = 6 +) + +// NDPRouteInformation is the NDP Router Information option, as defined by +// RFC 4191 section 2.3. +type NDPRouteInformation []byte + +func (NDPRouteInformation) kind() ndpOptionIdentifier { + return ndpRouteInformationType +} + +func (o NDPRouteInformation) length() int { + return len(o) +} + +func (o NDPRouteInformation) serializeInto(b []byte) int { + return copy(b, o) +} + +// String implements fmt.Stringer. +func (o NDPRouteInformation) String() string { + return fmt.Sprintf("%T", o) +} + +// PrefixLength returns the length of the prefix. +func (o NDPRouteInformation) PrefixLength() uint8 { + return o[ndpRouteInformationPrefixLengthIdx] +} + +// RoutePreference returns the preference of the route over other routes to the +// same destination but through a different router. +func (o NDPRouteInformation) RoutePreference() NDPRoutePreference { + return NDPRoutePreference((o[ndpRouteInformationFlagsIdx] & ndpRouteInformationPrfMask) >> ndpRouteInformationPrfShift) +} + +// RouteLifetime returns the lifetime of the route. +// +// Note, a value of 0 implies the route is now invalid and a value of +// infinity/forever is represented by NDPInfiniteLifetime. +func (o NDPRouteInformation) RouteLifetime() time.Duration { + return time.Second * time.Duration(binary.BigEndian.Uint32(o[ndpRouteInformationRouteLifetimeIdx:])) +} + +// Prefix returns the prefix of the destination subnet this route is for. +func (o NDPRouteInformation) Prefix() (tcpip.Subnet, error) { + prefixLength := int(o.PrefixLength()) + if max := IPv6AddressSize * 8; prefixLength > max { + return tcpip.Subnet{}, fmt.Errorf("got prefix length = %d, want <= %d", prefixLength, max) + } + + prefix := o[ndpRouteInformationRoutePrefixIdx:] + var addrBytes [IPv6AddressSize]byte + if n := copy(addrBytes[:], prefix); n != len(prefix) { + panic(fmt.Sprintf("got copy(addrBytes, prefix) = %d, want = %d", n, len(prefix))) + } + + return tcpip.AddressWithPrefix{ + Address: tcpip.Address(addrBytes[:]), + PrefixLen: prefixLength, + }.Subnet(), nil +} + +func (o NDPRouteInformation) hasError() error { + l := len(o) + if l < ndpRouteInformationRoutePrefixIdx { + return fmt.Errorf("%T too small, got = %d bytes: %w", o, l, ErrNDPOptMalformedBody) + } + + prefixLength := int(o.PrefixLength()) + if max := IPv6AddressSize * 8; prefixLength > max { + return fmt.Errorf("got prefix length = %d, want <= %d: %w", prefixLength, max, ErrNDPOptMalformedBody) + } + + // Length 8-bit unsigned integer. The length of the option + // (including the Type and Length fields) in units of 8 + // octets. The Length field is 1, 2, or 3 depending on the + // Prefix Length. If Prefix Length is greater than 64, then + // Length must be 3. If Prefix Length is greater than 0, + // then Length must be 2 or 3. If Prefix Length is zero, + // then Length must be 1, 2, or 3. + l += 2 // Add 2 bytes for the type and length bytes. + lengthField := l / lengthByteUnits + if prefixLength > 64 { + if lengthField != 3 { + return fmt.Errorf("Length field must be 3 when Prefix Length (%d) is > 64 (got = %d): %w", prefixLength, lengthField, ErrNDPOptMalformedBody) + } + } else if prefixLength > 0 { + if lengthField != 2 && lengthField != 3 { + return fmt.Errorf("Length field must be 2 or 3 when Prefix Length (%d) is between 0 and 64 (got = %d): %w", prefixLength, lengthField, ErrNDPOptMalformedBody) + } + } else if lengthField == 0 || lengthField > 3 { + return fmt.Errorf("Length field must be 1, 2, or 3 when Prefix Length is zero (got = %d): %w", lengthField, ErrNDPOptMalformedBody) + } + + return nil +} diff --git a/pkg/tcpip/header/ndp_router_advert.go b/pkg/tcpip/header/ndp_router_advert.go index 7e2f0c797..7d6efa083 100644 --- a/pkg/tcpip/header/ndp_router_advert.go +++ b/pkg/tcpip/header/ndp_router_advert.go @@ -16,9 +16,12 @@ package header import ( "encoding/binary" + "fmt" "time" ) +var _ fmt.Stringer = NDPRoutePreference(0) + // NDPRoutePreference is the preference values for default routers or // more-specific routes. // @@ -64,6 +67,22 @@ const ( ReservedRoutePreference = 0b10 ) +// String implements fmt.Stringer. +func (p NDPRoutePreference) String() string { + switch p { + case HighRoutePreference: + return "HighRoutePreference" + case MediumRoutePreference: + return "MediumRoutePreference" + case LowRoutePreference: + return "LowRoutePreference" + case ReservedRoutePreference: + return "ReservedRoutePreference" + default: + return fmt.Sprintf("NDPRoutePreference(%d)", p) + } +} + // NDPRouterAdvert is an NDP Router Advertisement message. It will only contain // the body of an ICMPv6 packet. // diff --git a/pkg/tcpip/header/ndp_test.go b/pkg/tcpip/header/ndp_test.go index 8fd1f7d13..2a897e938 100644 --- a/pkg/tcpip/header/ndp_test.go +++ b/pkg/tcpip/header/ndp_test.go @@ -21,6 +21,7 @@ import ( "fmt" "io" "regexp" + "strings" "testing" "time" @@ -58,6 +59,224 @@ func TestNDPNeighborSolicit(t *testing.T) { } } +func TestNDPRouteInformationOption(t *testing.T) { + tests := []struct { + name string + + length uint8 + prefixLength uint8 + prf NDPRoutePreference + lifetimeS uint32 + prefixBytes []byte + expectedPrefix tcpip.Subnet + + expectedErr error + }{ + { + name: "Length=1 with Prefix Length = 0", + length: 1, + prefixLength: 0, + prf: MediumRoutePreference, + lifetimeS: 1, + prefixBytes: nil, + expectedPrefix: IPv6EmptySubnet, + }, + { + name: "Length=1 but Prefix Length > 0", + length: 1, + prefixLength: 1, + prf: MediumRoutePreference, + lifetimeS: 1, + prefixBytes: nil, + expectedErr: ErrNDPOptMalformedBody, + }, + { + name: "Length=2 with Prefix Length = 0", + length: 2, + prefixLength: 0, + prf: MediumRoutePreference, + lifetimeS: 1, + prefixBytes: nil, + expectedPrefix: IPv6EmptySubnet, + }, + { + name: "Length=2 with Prefix Length in [1, 64] (1)", + length: 2, + prefixLength: 1, + prf: LowRoutePreference, + lifetimeS: 1, + prefixBytes: nil, + expectedPrefix: tcpip.AddressWithPrefix{ + Address: tcpip.Address(strings.Repeat("\x00", IPv6AddressSize)), + PrefixLen: 1, + }.Subnet(), + }, + { + name: "Length=2 with Prefix Length in [1, 64] (64)", + length: 2, + prefixLength: 64, + prf: HighRoutePreference, + lifetimeS: 1, + prefixBytes: nil, + expectedPrefix: tcpip.AddressWithPrefix{ + Address: tcpip.Address(strings.Repeat("\x00", IPv6AddressSize)), + PrefixLen: 64, + }.Subnet(), + }, + { + name: "Length=2 with Prefix Length > 64", + length: 2, + prefixLength: 65, + prf: HighRoutePreference, + lifetimeS: 1, + prefixBytes: nil, + expectedErr: ErrNDPOptMalformedBody, + }, + { + name: "Length=3 with Prefix Length = 0", + length: 3, + prefixLength: 0, + prf: MediumRoutePreference, + lifetimeS: 1, + prefixBytes: nil, + expectedPrefix: IPv6EmptySubnet, + }, + { + name: "Length=3 with Prefix Length in [1, 64] (1)", + length: 3, + prefixLength: 1, + prf: LowRoutePreference, + lifetimeS: 1, + prefixBytes: nil, + expectedPrefix: tcpip.AddressWithPrefix{ + Address: tcpip.Address(strings.Repeat("\x00", IPv6AddressSize)), + PrefixLen: 1, + }.Subnet(), + }, + { + name: "Length=3 with Prefix Length in [1, 64] (64)", + length: 3, + prefixLength: 64, + prf: HighRoutePreference, + lifetimeS: 1, + prefixBytes: nil, + expectedPrefix: tcpip.AddressWithPrefix{ + Address: tcpip.Address(strings.Repeat("\x00", IPv6AddressSize)), + PrefixLen: 64, + }.Subnet(), + }, + { + name: "Length=3 with Prefix Length in [65, 128] (65)", + length: 3, + prefixLength: 65, + prf: HighRoutePreference, + lifetimeS: 1, + prefixBytes: nil, + expectedPrefix: tcpip.AddressWithPrefix{ + Address: tcpip.Address(strings.Repeat("\x00", IPv6AddressSize)), + PrefixLen: 65, + }.Subnet(), + }, + { + name: "Length=3 with Prefix Length in [65, 128] (128)", + length: 3, + prefixLength: 128, + prf: HighRoutePreference, + lifetimeS: 1, + prefixBytes: nil, + expectedPrefix: tcpip.AddressWithPrefix{ + Address: tcpip.Address(strings.Repeat("\x00", IPv6AddressSize)), + PrefixLen: 128, + }.Subnet(), + }, + { + name: "Length=3 with (invalid) Prefix Length > 128", + length: 3, + prefixLength: 129, + prf: HighRoutePreference, + lifetimeS: 1, + prefixBytes: nil, + expectedErr: ErrNDPOptMalformedBody, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + expectedRouteInformationBytes := [...]byte{ + // Type, Length + 24, test.length, + + // Prefix Length, Prf + uint8(test.prefixLength), uint8(test.prf) << 3, + + // Route Lifetime + 0, 0, 0, 0, + + 0, 0, 0, 0, + 0, 0, 0, 0, + 0, 0, 0, 0, + 0, 0, 0, 0, + } + binary.BigEndian.PutUint32(expectedRouteInformationBytes[4:], test.lifetimeS) + _ = copy(expectedRouteInformationBytes[8:], test.prefixBytes) + + opts := NDPOptions(expectedRouteInformationBytes[:test.length*lengthByteUnits]) + it, err := opts.Iter(false) + if err != nil { + t.Fatalf("got Iter(false) = (_, %s), want = (_, nil)", err) + } + opt, done, err := it.Next() + if !errors.Is(err, test.expectedErr) { + t.Fatalf("got Next() = (_, _, %s), want = (_, _, %s)", err, test.expectedErr) + } + if want := test.expectedErr != nil; done != want { + t.Fatalf("got Next() = (_, %t, _), want = (_, %t, _)", done, want) + } + if test.expectedErr != nil { + return + } + + if got := opt.kind(); got != ndpRouteInformationType { + t.Errorf("got kind() = %d, want = %d", got, ndpRouteInformationType) + } + + ri, ok := opt.(NDPRouteInformation) + if !ok { + t.Fatalf("got opt = %T, want = NDPRouteInformation", opt) + } + + if got := ri.PrefixLength(); got != test.prefixLength { + t.Errorf("got PrefixLength() = %d, want = %d", got, test.prefixLength) + } + if got := ri.RoutePreference(); got != test.prf { + t.Errorf("got RoutePreference() = %d, want = %d", got, test.prf) + } + if got, want := ri.RouteLifetime(), time.Duration(test.lifetimeS)*time.Second; got != want { + t.Errorf("got RouteLifetime() = %s, want = %s", got, want) + } + if got, err := ri.Prefix(); err != nil { + t.Errorf("Prefix(): %s", err) + } else if got != test.expectedPrefix { + t.Errorf("got Prefix() = %s, want = %s", got, test.expectedPrefix) + } + + // Iterator should not return anything else. + { + next, done, err := it.Next() + if err != nil { + t.Errorf("got Next() = (_, _, %s), want = (_, _, nil)", err) + } + if !done { + t.Error("got Next() = (_, false, _), want = (_, true, _)") + } + if next != nil { + t.Errorf("got Next() = (%x, _, _), want = (nil, _, _)", next) + } + } + }) + } +} + // TestNDPNeighborAdvert tests the functions of NDPNeighborAdvert. func TestNDPNeighborAdvert(t *testing.T) { b := []byte{ @@ -1498,3 +1717,32 @@ func TestNDPOptionsIter(t *testing.T) { t.Errorf("got Next = (%x, _, _), want = (nil, _, _)", next) } } + +func TestNDPRoutePreferenceStringer(t *testing.T) { + p := NDPRoutePreference(0) + for { + var wantStr string + switch p { + case 0b01: + wantStr = "HighRoutePreference" + case 0b00: + wantStr = "MediumRoutePreference" + case 0b11: + wantStr = "LowRoutePreference" + case 0b10: + wantStr = "ReservedRoutePreference" + default: + wantStr = fmt.Sprintf("NDPRoutePreference(%d)", p) + } + + if gotStr := p.String(); gotStr != wantStr { + t.Errorf("got NDPRoutePreference(%d).String() = %s, want = %s", p, gotStr, wantStr) + } + + p++ + if p == 0 { + // Overflowed, we hit all values. + break + } + } +} diff --git a/pkg/tcpip/header/tcp.go b/pkg/tcpip/header/tcp.go index 8dabe3354..a75e51a28 100644 --- a/pkg/tcpip/header/tcp.go +++ b/pkg/tcpip/header/tcp.go @@ -390,6 +390,35 @@ func (b TCP) EncodePartial(partialChecksum, length uint16, seqnum, acknum uint32 b.SetChecksum(^checksum) } +// SetSourcePortWithChecksumUpdate implements ChecksummableTransport. +func (b TCP) SetSourcePortWithChecksumUpdate(new uint16) { + old := b.SourcePort() + b.SetSourcePort(new) + b.SetChecksum(^checksumUpdate2ByteAlignedUint16(^b.Checksum(), old, new)) +} + +// SetDestinationPortWithChecksumUpdate implements ChecksummableTransport. +func (b TCP) SetDestinationPortWithChecksumUpdate(new uint16) { + old := b.DestinationPort() + b.SetDestinationPort(new) + b.SetChecksum(^checksumUpdate2ByteAlignedUint16(^b.Checksum(), old, new)) +} + +// UpdateChecksumPseudoHeaderAddress implements ChecksummableTransport. +func (b TCP) UpdateChecksumPseudoHeaderAddress(old, new tcpip.Address, fullChecksum bool) { + xsum := b.Checksum() + if fullChecksum { + xsum = ^xsum + } + + xsum = checksumUpdate2ByteAlignedAddress(xsum, old, new) + if fullChecksum { + xsum = ^xsum + } + + b.SetChecksum(xsum) +} + // ParseSynOptions parses the options received in a SYN segment and returns the // relevant ones. opts should point to the option part of the TCP header. func ParseSynOptions(opts []byte, isAck bool) TCPSynOptions { diff --git a/pkg/tcpip/header/udp.go b/pkg/tcpip/header/udp.go index ae9d167ff..f69d53314 100644 --- a/pkg/tcpip/header/udp.go +++ b/pkg/tcpip/header/udp.go @@ -130,3 +130,32 @@ func (b UDP) Encode(u *UDPFields) { binary.BigEndian.PutUint16(b[udpLength:], u.Length) binary.BigEndian.PutUint16(b[udpChecksum:], u.Checksum) } + +// SetSourcePortWithChecksumUpdate implements ChecksummableTransport. +func (b UDP) SetSourcePortWithChecksumUpdate(new uint16) { + old := b.SourcePort() + b.SetSourcePort(new) + b.SetChecksum(^checksumUpdate2ByteAlignedUint16(^b.Checksum(), old, new)) +} + +// SetDestinationPortWithChecksumUpdate implements ChecksummableTransport. +func (b UDP) SetDestinationPortWithChecksumUpdate(new uint16) { + old := b.DestinationPort() + b.SetDestinationPort(new) + b.SetChecksum(^checksumUpdate2ByteAlignedUint16(^b.Checksum(), old, new)) +} + +// UpdateChecksumPseudoHeaderAddress implements ChecksummableTransport. +func (b UDP) UpdateChecksumPseudoHeaderAddress(old, new tcpip.Address, fullChecksum bool) { + xsum := b.Checksum() + if fullChecksum { + xsum = ^xsum + } + + xsum = checksumUpdate2ByteAlignedAddress(xsum, old, new) + if fullChecksum { + xsum = ^xsum + } + + b.SetChecksum(xsum) +} diff --git a/pkg/tcpip/link/fdbased/BUILD b/pkg/tcpip/link/fdbased/BUILD index d971194e6..1d0163823 100644 --- a/pkg/tcpip/link/fdbased/BUILD +++ b/pkg/tcpip/link/fdbased/BUILD @@ -14,7 +14,6 @@ go_library( ], visibility = ["//visibility:public"], deps = [ - "//pkg/iovec", "//pkg/sync", "//pkg/tcpip", "//pkg/tcpip/buffer", diff --git a/pkg/tcpip/link/fdbased/endpoint.go b/pkg/tcpip/link/fdbased/endpoint.go index 735c28da1..e8e716db0 100644 --- a/pkg/tcpip/link/fdbased/endpoint.go +++ b/pkg/tcpip/link/fdbased/endpoint.go @@ -12,6 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. +//go:build linux // +build linux // Package fdbased provides the implemention of data-link layer endpoints @@ -44,7 +45,6 @@ import ( "sync/atomic" "golang.org/x/sys/unix" - "gvisor.dev/gvisor/pkg/iovec" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" @@ -138,6 +138,20 @@ type endpoint struct { // gsoKind is the supported kind of GSO. gsoKind stack.SupportedGSO + + // maxSyscallHeaderBytes has the same meaning as + // Options.MaxSyscallHeaderBytes. + maxSyscallHeaderBytes uintptr + + // writevMaxIovs is the maximum number of iovecs that may be passed to + // rawfile.NonBlockingWriteIovec, as possibly limited by + // maxSyscallHeaderBytes. (No analogous limit is defined for + // rawfile.NonBlockingSendMMsg, since in that case the maximum number of + // iovecs also depends on the number of mmsghdrs. Instead, if sendBatch + // encounters a packet whose iovec count is limited by + // maxSyscallHeaderBytes, it falls back to writing the packet using writev + // via WritePacket.) + writevMaxIovs int } // Options specify the details about the fd-based endpoint to be created. @@ -186,6 +200,11 @@ type Options struct { // RXChecksumOffload if true, indicates that this endpoints capability // set should include CapabilityRXChecksumOffload. RXChecksumOffload bool + + // If MaxSyscallHeaderBytes is non-zero, it is the maximum number of bytes + // of struct iovec, msghdr, and mmsghdr that may be passed by each host + // system call. + MaxSyscallHeaderBytes int } // fanoutID is used for AF_PACKET based endpoints to enable PACKET_FANOUT @@ -235,14 +254,25 @@ func New(opts *Options) (stack.LinkEndpoint, error) { return nil, fmt.Errorf("opts.FD is empty, at least one FD must be specified") } + if opts.MaxSyscallHeaderBytes < 0 { + return nil, fmt.Errorf("opts.MaxSyscallHeaderBytes is negative") + } + e := &endpoint{ - fds: opts.FDs, - mtu: opts.MTU, - caps: caps, - closed: opts.ClosedFunc, - addr: opts.Address, - hdrSize: hdrSize, - packetDispatchMode: opts.PacketDispatchMode, + fds: opts.FDs, + mtu: opts.MTU, + caps: caps, + closed: opts.ClosedFunc, + addr: opts.Address, + hdrSize: hdrSize, + packetDispatchMode: opts.PacketDispatchMode, + maxSyscallHeaderBytes: uintptr(opts.MaxSyscallHeaderBytes), + writevMaxIovs: rawfile.MaxIovs, + } + if e.maxSyscallHeaderBytes != 0 { + if max := int(e.maxSyscallHeaderBytes / rawfile.SizeofIovec); max < e.writevMaxIovs { + e.writevMaxIovs = max + } } // Increment fanoutID to ensure that we don't re-use the same fanoutID for @@ -470,9 +500,8 @@ func (e *endpoint) WritePacket(r stack.RouteInfo, protocol tcpip.NetworkProtocol e.AddHeader(r.LocalLinkAddress, r.RemoteLinkAddress, protocol, pkt) } - var builder iovec.Builder - fd := e.fds[pkt.Hash%uint32(len(e.fds))] + var vnetHdrBuf []byte if e.gsoKind == stack.HWGSOSupported { vnetHdr := virtioNetHdr{} if pkt.GSOOptions.Type != stack.GSONone { @@ -494,71 +523,123 @@ func (e *endpoint) WritePacket(r stack.RouteInfo, protocol tcpip.NetworkProtocol vnetHdr.gsoSize = pkt.GSOOptions.MSS } } + vnetHdrBuf = vnetHdr.marshal() + } - vnetHdrBuf := vnetHdr.marshal() - builder.Add(vnetHdrBuf) + views := pkt.Views() + numIovecs := len(views) + if len(vnetHdrBuf) != 0 { + numIovecs++ + } + if numIovecs > e.writevMaxIovs { + numIovecs = e.writevMaxIovs } - for _, v := range pkt.Views() { - builder.Add(v) + // Allocate small iovec arrays on the stack. + var iovecsArr [8]unix.Iovec + iovecs := iovecsArr[:0] + if numIovecs > len(iovecsArr) { + iovecs = make([]unix.Iovec, 0, numIovecs) + } + iovecs = rawfile.AppendIovecFromBytes(iovecs, vnetHdrBuf, numIovecs) + for _, v := range views { + iovecs = rawfile.AppendIovecFromBytes(iovecs, v, numIovecs) } - return rawfile.NonBlockingWriteIovec(fd, builder.Build()) + return rawfile.NonBlockingWriteIovec(fd, iovecs) } -func (e *endpoint) sendBatch(batchFD int, batch []*stack.PacketBuffer) (int, tcpip.Error) { +func (e *endpoint) sendBatch(batchFD int, pkts []*stack.PacketBuffer) (int, tcpip.Error) { // Send a batch of packets through batchFD. - mmsgHdrs := make([]rawfile.MMsgHdr, 0, len(batch)) - for _, pkt := range batch { - if e.hdrSize > 0 { - e.AddHeader(pkt.EgressRoute.LocalLinkAddress, pkt.EgressRoute.RemoteLinkAddress, pkt.NetworkProtocolNumber, pkt) - } + mmsgHdrsStorage := make([]rawfile.MMsgHdr, 0, len(pkts)) + packets := 0 + for packets < len(pkts) { + mmsgHdrs := mmsgHdrsStorage + batch := pkts[packets:] + syscallHeaderBytes := uintptr(0) + for _, pkt := range batch { + if e.hdrSize > 0 { + e.AddHeader(pkt.EgressRoute.LocalLinkAddress, pkt.EgressRoute.RemoteLinkAddress, pkt.NetworkProtocolNumber, pkt) + } - var vnetHdrBuf []byte - if e.gsoKind == stack.HWGSOSupported { - vnetHdr := virtioNetHdr{} - if pkt.GSOOptions.Type != stack.GSONone { - vnetHdr.hdrLen = uint16(pkt.HeaderSize()) - if pkt.GSOOptions.NeedsCsum { - vnetHdr.flags = _VIRTIO_NET_HDR_F_NEEDS_CSUM - vnetHdr.csumStart = header.EthernetMinimumSize + pkt.GSOOptions.L3HdrLen - vnetHdr.csumOffset = pkt.GSOOptions.CsumOffset - } - if pkt.GSOOptions.Type != stack.GSONone && uint16(pkt.Data().Size()) > pkt.GSOOptions.MSS { - switch pkt.GSOOptions.Type { - case stack.GSOTCPv4: - vnetHdr.gsoType = _VIRTIO_NET_HDR_GSO_TCPV4 - case stack.GSOTCPv6: - vnetHdr.gsoType = _VIRTIO_NET_HDR_GSO_TCPV6 - default: - panic(fmt.Sprintf("Unknown gso type: %v", pkt.GSOOptions.Type)) + var vnetHdrBuf []byte + if e.gsoKind == stack.HWGSOSupported { + vnetHdr := virtioNetHdr{} + if pkt.GSOOptions.Type != stack.GSONone { + vnetHdr.hdrLen = uint16(pkt.HeaderSize()) + if pkt.GSOOptions.NeedsCsum { + vnetHdr.flags = _VIRTIO_NET_HDR_F_NEEDS_CSUM + vnetHdr.csumStart = header.EthernetMinimumSize + pkt.GSOOptions.L3HdrLen + vnetHdr.csumOffset = pkt.GSOOptions.CsumOffset + } + if pkt.GSOOptions.Type != stack.GSONone && uint16(pkt.Data().Size()) > pkt.GSOOptions.MSS { + switch pkt.GSOOptions.Type { + case stack.GSOTCPv4: + vnetHdr.gsoType = _VIRTIO_NET_HDR_GSO_TCPV4 + case stack.GSOTCPv6: + vnetHdr.gsoType = _VIRTIO_NET_HDR_GSO_TCPV6 + default: + panic(fmt.Sprintf("Unknown gso type: %v", pkt.GSOOptions.Type)) + } + vnetHdr.gsoSize = pkt.GSOOptions.MSS } - vnetHdr.gsoSize = pkt.GSOOptions.MSS } + vnetHdrBuf = vnetHdr.marshal() } - vnetHdrBuf = vnetHdr.marshal() - } - var builder iovec.Builder - builder.Add(vnetHdrBuf) - for _, v := range pkt.Views() { - builder.Add(v) - } - iovecs := builder.Build() + views := pkt.Views() + numIovecs := len(views) + if len(vnetHdrBuf) != 0 { + numIovecs++ + } + if numIovecs > rawfile.MaxIovs { + numIovecs = rawfile.MaxIovs + } + if e.maxSyscallHeaderBytes != 0 { + syscallHeaderBytes += rawfile.SizeofMMsgHdr + uintptr(numIovecs)*rawfile.SizeofIovec + if syscallHeaderBytes > e.maxSyscallHeaderBytes { + // We can't fit this packet into this call to sendmmsg(). + // We could potentially do so if we reduced numIovecs + // further, but this might incur considerable extra + // copying. Leave it to the next batch instead. + break + } + } - var mmsgHdr rawfile.MMsgHdr - mmsgHdr.Msg.Iov = &iovecs[0] - mmsgHdr.Msg.SetIovlen((len(iovecs))) - mmsgHdrs = append(mmsgHdrs, mmsgHdr) - } + // We can't easily allocate iovec arrays on the stack here since + // they will escape this loop iteration via mmsgHdrs. + iovecs := make([]unix.Iovec, 0, numIovecs) + iovecs = rawfile.AppendIovecFromBytes(iovecs, vnetHdrBuf, numIovecs) + for _, v := range views { + iovecs = rawfile.AppendIovecFromBytes(iovecs, v, numIovecs) + } - packets := 0 - for len(mmsgHdrs) > 0 { - sent, err := rawfile.NonBlockingSendMMsg(batchFD, mmsgHdrs) - if err != nil { - return packets, err + var mmsgHdr rawfile.MMsgHdr + mmsgHdr.Msg.Iov = &iovecs[0] + mmsgHdr.Msg.SetIovlen(len(iovecs)) + mmsgHdrs = append(mmsgHdrs, mmsgHdr) + } + + if len(mmsgHdrs) == 0 { + // We can't fit batch[0] into a mmsghdr while staying under + // e.maxSyscallHeaderBytes. Use WritePacket, which will avoid the + // mmsghdr (by using writev) and re-buffer iovecs more aggressively + // if necessary (by using e.writevMaxIovs instead of + // rawfile.MaxIovs). + pkt := batch[0] + if err := e.WritePacket(pkt.EgressRoute, pkt.NetworkProtocolNumber, pkt); err != nil { + return packets, err + } + packets++ + } else { + for len(mmsgHdrs) > 0 { + sent, err := rawfile.NonBlockingSendMMsg(batchFD, mmsgHdrs) + if err != nil { + return packets, err + } + packets += sent + mmsgHdrs = mmsgHdrs[sent:] + } } - packets += sent - mmsgHdrs = mmsgHdrs[sent:] } return packets, nil @@ -676,8 +757,9 @@ func NewInjectable(fd int, mtu uint32, capabilities stack.LinkEndpointCapabiliti unix.SetNonblock(fd, true) return &InjectableEndpoint{endpoint: endpoint{ - fds: []int{fd}, - mtu: mtu, - caps: capabilities, + fds: []int{fd}, + mtu: mtu, + caps: capabilities, + writevMaxIovs: rawfile.MaxIovs, }} } diff --git a/pkg/tcpip/link/fdbased/endpoint_test.go b/pkg/tcpip/link/fdbased/endpoint_test.go index 8aad338b6..eccd21579 100644 --- a/pkg/tcpip/link/fdbased/endpoint_test.go +++ b/pkg/tcpip/link/fdbased/endpoint_test.go @@ -12,6 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. +//go:build linux // +build linux package fdbased diff --git a/pkg/tcpip/link/fdbased/endpoint_unsafe.go b/pkg/tcpip/link/fdbased/endpoint_unsafe.go index df14eaad1..904393faa 100644 --- a/pkg/tcpip/link/fdbased/endpoint_unsafe.go +++ b/pkg/tcpip/link/fdbased/endpoint_unsafe.go @@ -12,6 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. +//go:build linux // +build linux package fdbased diff --git a/pkg/tcpip/link/fdbased/mmap.go b/pkg/tcpip/link/fdbased/mmap.go index 5d698a5e9..bfae34ab9 100644 --- a/pkg/tcpip/link/fdbased/mmap.go +++ b/pkg/tcpip/link/fdbased/mmap.go @@ -12,6 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. +//go:build (linux && amd64) || (linux && arm64) // +build linux,amd64 linux,arm64 package fdbased diff --git a/pkg/tcpip/link/fdbased/mmap_stub.go b/pkg/tcpip/link/fdbased/mmap_stub.go index 67be52d67..9d8679502 100644 --- a/pkg/tcpip/link/fdbased/mmap_stub.go +++ b/pkg/tcpip/link/fdbased/mmap_stub.go @@ -12,6 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. +//go:build !linux || (!amd64 && !arm64) // +build !linux !amd64,!arm64 package fdbased diff --git a/pkg/tcpip/link/fdbased/mmap_unsafe.go b/pkg/tcpip/link/fdbased/mmap_unsafe.go index 1293f68a2..58d5dfeef 100644 --- a/pkg/tcpip/link/fdbased/mmap_unsafe.go +++ b/pkg/tcpip/link/fdbased/mmap_unsafe.go @@ -12,6 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. +//go:build (linux && amd64) || (linux && arm64) // +build linux,amd64 linux,arm64 package fdbased diff --git a/pkg/tcpip/link/fdbased/packet_dispatchers.go b/pkg/tcpip/link/fdbased/packet_dispatchers.go index 4b7ef3aac..ab2855a63 100644 --- a/pkg/tcpip/link/fdbased/packet_dispatchers.go +++ b/pkg/tcpip/link/fdbased/packet_dispatchers.go @@ -12,6 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. +//go:build linux // +build linux package fdbased diff --git a/pkg/tcpip/link/rawfile/blockingpoll_noyield_unsafe.go b/pkg/tcpip/link/rawfile/blockingpoll_noyield_unsafe.go index 2206fe0e6..c1438da21 100644 --- a/pkg/tcpip/link/rawfile/blockingpoll_noyield_unsafe.go +++ b/pkg/tcpip/link/rawfile/blockingpoll_noyield_unsafe.go @@ -12,6 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. +//go:build linux && !amd64 && !arm64 // +build linux,!amd64,!arm64 package rawfile diff --git a/pkg/tcpip/link/rawfile/blockingpoll_yield_unsafe.go b/pkg/tcpip/link/rawfile/blockingpoll_yield_unsafe.go index 5002245a1..da900c24b 100644 --- a/pkg/tcpip/link/rawfile/blockingpoll_yield_unsafe.go +++ b/pkg/tcpip/link/rawfile/blockingpoll_yield_unsafe.go @@ -12,6 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. +//go:build ((linux && amd64) || (linux && arm64)) && go1.12 && !go1.18 // +build linux,amd64 linux,arm64 // +build go1.12 // +build !go1.18 diff --git a/pkg/tcpip/link/rawfile/errors.go b/pkg/tcpip/link/rawfile/errors.go index 9743e70ea..7e21a78d4 100644 --- a/pkg/tcpip/link/rawfile/errors.go +++ b/pkg/tcpip/link/rawfile/errors.go @@ -12,6 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. +//go:build linux // +build linux package rawfile diff --git a/pkg/tcpip/link/rawfile/errors_test.go b/pkg/tcpip/link/rawfile/errors_test.go index 8f4bd60da..1b88c309b 100644 --- a/pkg/tcpip/link/rawfile/errors_test.go +++ b/pkg/tcpip/link/rawfile/errors_test.go @@ -12,6 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. +//go:build linux // +build linux package rawfile diff --git a/pkg/tcpip/link/rawfile/rawfile_unsafe.go b/pkg/tcpip/link/rawfile/rawfile_unsafe.go index ba92aedbc..53448a641 100644 --- a/pkg/tcpip/link/rawfile/rawfile_unsafe.go +++ b/pkg/tcpip/link/rawfile/rawfile_unsafe.go @@ -12,6 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. +//go:build linux // +build linux // Package rawfile contains utilities for using the netstack with raw host @@ -19,12 +20,66 @@ package rawfile import ( + "reflect" "unsafe" "golang.org/x/sys/unix" "gvisor.dev/gvisor/pkg/tcpip" ) +// SizeofIovec is the size of a unix.Iovec in bytes. +const SizeofIovec = unsafe.Sizeof(unix.Iovec{}) + +// MaxIovs is UIO_MAXIOV, the maximum number of iovecs that may be passed to a +// host system call in a single array. +const MaxIovs = 1024 + +// IovecFromBytes returns a unix.Iovec representing bs. +// +// Preconditions: len(bs) > 0. +func IovecFromBytes(bs []byte) unix.Iovec { + iov := unix.Iovec{ + Base: &bs[0], + } + iov.SetLen(len(bs)) + return iov +} + +func bytesFromIovec(iov unix.Iovec) (bs []byte) { + sh := (*reflect.SliceHeader)(unsafe.Pointer(&bs)) + sh.Data = uintptr(unsafe.Pointer(iov.Base)) + sh.Len = int(iov.Len) + sh.Cap = int(iov.Len) + return +} + +// AppendIovecFromBytes returns append(iovs, IovecFromBytes(bs)). If len(bs) == +// 0, AppendIovecFromBytes returns iovs without modification. If len(iovs) >= +// max, AppendIovecFromBytes replaces the final iovec in iovs with one that +// also includes the contents of bs. Note that this implies that +// AppendIovecFromBytes is only usable when the returned iovec slice is used as +// the source of a write. +func AppendIovecFromBytes(iovs []unix.Iovec, bs []byte, max int) []unix.Iovec { + if len(bs) == 0 { + return iovs + } + if len(iovs) < max { + return append(iovs, IovecFromBytes(bs)) + } + iovs[len(iovs)-1] = IovecFromBytes(append(bytesFromIovec(iovs[len(iovs)-1]), bs...)) + return iovs +} + +// MMsgHdr represents the mmsg_hdr structure required by recvmmsg() on linux. +type MMsgHdr struct { + Msg unix.Msghdr + Len uint32 + _ [4]byte +} + +// SizeofMMsgHdr is the size of a MMsgHdr in bytes. +const SizeofMMsgHdr = unsafe.Sizeof(MMsgHdr{}) + // GetMTU determines the MTU of a network interface device. func GetMTU(name string) (uint32, error) { fd, err := unix.Socket(unix.AF_UNIX, unix.SOCK_DGRAM, 0) @@ -137,13 +192,6 @@ func BlockingReadv(fd int, iovecs []unix.Iovec) (int, tcpip.Error) { } } -// MMsgHdr represents the mmsg_hdr structure required by recvmmsg() on linux. -type MMsgHdr struct { - Msg unix.Msghdr - Len uint32 - _ [4]byte -} - // BlockingRecvMMsg reads from a file descriptor that is set up as non-blocking // and stores the received messages in a slice of MMsgHdr structures. If no data // is available, it will block in a poll() syscall until the file descriptor diff --git a/pkg/tcpip/link/sharedmem/rx.go b/pkg/tcpip/link/sharedmem/rx.go index 8e6f3e5e3..e882a128c 100644 --- a/pkg/tcpip/link/sharedmem/rx.go +++ b/pkg/tcpip/link/sharedmem/rx.go @@ -12,6 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. +//go:build linux // +build linux package sharedmem diff --git a/pkg/tcpip/link/sharedmem/sharedmem.go b/pkg/tcpip/link/sharedmem/sharedmem.go index df9a0b90a..30cf659b8 100644 --- a/pkg/tcpip/link/sharedmem/sharedmem.go +++ b/pkg/tcpip/link/sharedmem/sharedmem.go @@ -12,6 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. +//go:build linux // +build linux // Package sharedmem provides the implemention of data-link layer endpoints diff --git a/pkg/tcpip/link/sharedmem/sharedmem_test.go b/pkg/tcpip/link/sharedmem/sharedmem_test.go index 0f72d4e95..d6d953085 100644 --- a/pkg/tcpip/link/sharedmem/sharedmem_test.go +++ b/pkg/tcpip/link/sharedmem/sharedmem_test.go @@ -12,6 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. +//go:build linux // +build linux package sharedmem diff --git a/pkg/tcpip/link/sniffer/pcap.go b/pkg/tcpip/link/sniffer/pcap.go index c16c19647..3bb864ed2 100644 --- a/pkg/tcpip/link/sniffer/pcap.go +++ b/pkg/tcpip/link/sniffer/pcap.go @@ -39,8 +39,6 @@ type pcapHeader struct { Network uint32 } -const pcapPacketHeaderLen = 16 - type pcapPacketHeader struct { // Seconds is the timestamp seconds. Seconds uint32 @@ -55,8 +53,7 @@ type pcapPacketHeader struct { OriginalLength uint32 } -func newPCAPPacketHeader(incLen, orgLen uint32) pcapPacketHeader { - now := time.Now() +func newPCAPPacketHeader(now time.Time, incLen, orgLen uint32) pcapPacketHeader { return pcapPacketHeader{ Seconds: uint32(now.Unix()), Microseconds: uint32(now.Nanosecond() / 1000), diff --git a/pkg/tcpip/link/sniffer/sniffer.go b/pkg/tcpip/link/sniffer/sniffer.go index 2d6a3a833..3df826f3c 100644 --- a/pkg/tcpip/link/sniffer/sniffer.go +++ b/pkg/tcpip/link/sniffer/sniffer.go @@ -87,11 +87,7 @@ func NewWithPrefix(lower stack.LinkEndpoint, logPrefix string) stack.LinkEndpoin } func zoneOffset() (int32, error) { - loc, err := time.LoadLocation("Local") - if err != nil { - return 0, err - } - date := time.Date(0, 0, 0, 0, 0, 0, 0, loc) + date := time.Date(0, 0, 0, 0, 0, 0, 0, time.Local) _, offset := date.Zone() return int32(offset), nil } @@ -117,8 +113,9 @@ func writePCAPHeader(w io.Writer, maxLen uint32) error { // NewWithWriter creates a new sniffer link-layer endpoint. It wraps around // another endpoint and logs packets as they traverse the endpoint. // -// Packets are logged to writer in the pcap format. A sniffer created with this -// function will not emit packets using the standard log package. +// Each packet is written to writer in the pcap format in a single Write call +// without synchronization. A sniffer created with this function will not emit +// packets using the standard log package. // // snapLen is the maximum amount of a packet to be saved. Packets with a length // less than or equal to snapLen will be saved in their entirety. Longer @@ -159,27 +156,29 @@ func (e *endpoint) dumpPacket(dir direction, protocol tcpip.NetworkProtocolNumbe if max := int(e.maxPCAPLen); length > max { length = max } - if err := binary.Write(writer, binary.BigEndian, newPCAPPacketHeader(uint32(length), uint32(totalLength))); err != nil { - panic(err) - } - write := func(b []byte) { - if len(b) > length { - b = b[:length] + packetHeader := newPCAPPacketHeader(time.Now(), uint32(length), uint32(totalLength)) + packet := make([]byte, binary.Size(packetHeader)+length) + { + writer := tcpip.SliceWriter(packet) + if err := binary.Write(&writer, binary.BigEndian, packetHeader); err != nil { + panic(err) } - for len(b) != 0 { + for _, b := range pkt.Views() { + if length == 0 { + break + } + if len(b) > length { + b = b[:length] + } n, err := writer.Write(b) if err != nil { panic(err) } - b = b[n:] length -= n } } - for _, v := range pkt.Views() { - if length == 0 { - break - } - write(v) + if _, err := writer.Write(packet); err != nil { + panic(err) } } } diff --git a/pkg/tcpip/link/tun/BUILD b/pkg/tcpip/link/tun/BUILD index 7656cca6a..4758a99ad 100644 --- a/pkg/tcpip/link/tun/BUILD +++ b/pkg/tcpip/link/tun/BUILD @@ -26,6 +26,7 @@ go_library( deps = [ "//pkg/abi/linux", "//pkg/context", + "//pkg/errors/linuxerr", "//pkg/log", "//pkg/refs", "//pkg/refsvfs2", diff --git a/pkg/tcpip/link/tun/device.go b/pkg/tcpip/link/tun/device.go index 36af2a029..d23210503 100644 --- a/pkg/tcpip/link/tun/device.go +++ b/pkg/tcpip/link/tun/device.go @@ -18,6 +18,7 @@ import ( "fmt" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/tcpip" @@ -88,12 +89,12 @@ func (d *Device) SetIff(s *stack.Stack, name string, flags Flags) error { defer d.mu.Unlock() if d.endpoint != nil { - return syserror.EINVAL + return linuxerr.EINVAL } // Input validation. if flags.TAP && flags.TUN || !flags.TAP && !flags.TUN { - return syserror.EINVAL + return linuxerr.EINVAL } prefix := "tun" @@ -108,7 +109,7 @@ func (d *Device) SetIff(s *stack.Stack, name string, flags Flags) error { endpoint, err := attachOrCreateNIC(s, name, prefix, linkCaps) if err != nil { - return syserror.EINVAL + return linuxerr.EINVAL } d.endpoint = endpoint @@ -125,7 +126,7 @@ func attachOrCreateNIC(s *stack.Stack, name, prefix string, linkCaps stack.LinkE endpoint, ok := linkEP.(*tunEndpoint) if !ok { // Not a NIC created by tun device. - return nil, syserror.EOPNOTSUPP + return nil, linuxerr.EOPNOTSUPP } if !endpoint.TryIncRef() { // Race detected: NIC got deleted in between. @@ -159,7 +160,7 @@ func attachOrCreateNIC(s *stack.Stack, name, prefix string, linkCaps stack.LinkE // Race detected: A NIC has been created in between. continue default: - return nil, syserror.EINVAL + return nil, linuxerr.EINVAL } } } @@ -170,7 +171,7 @@ func (d *Device) Write(data []byte) (int64, error) { endpoint := d.endpoint d.mu.RUnlock() if endpoint == nil { - return 0, syserror.EBADFD + return 0, linuxerr.EBADFD } if !endpoint.IsAttached() { return 0, syserror.EIO @@ -207,6 +208,15 @@ func (d *Device) Write(data []byte) (int64, error) { protocol = pktInfoHdr.Protocol() case ethHdr != nil: protocol = ethHdr.Type() + case d.flags.TUN: + // TUN interface with IFF_NO_PI enabled, thus + // we need to determine protocol from version field + version := data[0] >> 4 + if version == 4 { + protocol = header.IPv4ProtocolNumber + } else if version == 6 { + protocol = header.IPv6ProtocolNumber + } } // Try to determine remote link address, default zero. @@ -233,7 +243,7 @@ func (d *Device) Read() ([]byte, error) { endpoint := d.endpoint d.mu.RUnlock() if endpoint == nil { - return nil, syserror.EBADFD + return nil, linuxerr.EBADFD } for { @@ -264,13 +274,6 @@ func (d *Device) encodePkt(info *channel.PacketInfo) (buffer.View, bool) { vv.AppendView(buffer.View(hdr)) } - // If the packet does not already have link layer header, and the route - // does not exist, we can't compute it. This is possibly a raw packet, tun - // device doesn't support this at the moment. - if info.Pkt.LinkHeader().View().IsEmpty() && len(info.Route.RemoteLinkAddress) == 0 { - return nil, false - } - // Ethernet header (TAP only). if d.flags.TAP { // Add ethernet header if not provided. diff --git a/pkg/tcpip/link/tun/tun_unsafe.go b/pkg/tcpip/link/tun/tun_unsafe.go index 0591fbd63..db4338e79 100644 --- a/pkg/tcpip/link/tun/tun_unsafe.go +++ b/pkg/tcpip/link/tun/tun_unsafe.go @@ -12,6 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. +//go:build linux // +build linux // Package tun contains methods to open TAP and TUN devices. diff --git a/pkg/tcpip/network/internal/ip/generic_multicast_protocol_test.go b/pkg/tcpip/network/internal/ip/generic_multicast_protocol_test.go index 0b51563cd..1261ad414 100644 --- a/pkg/tcpip/network/internal/ip/generic_multicast_protocol_test.go +++ b/pkg/tcpip/network/internal/ip/generic_multicast_protocol_test.go @@ -126,7 +126,7 @@ func (m *mockMulticastGroupProtocol) sendQueuedReports() { // Precondition: m.mu must be read locked. func (m *mockMulticastGroupProtocol) Enabled() bool { if m.mu.TryLock() { - m.mu.Unlock() + m.mu.Unlock() // +checklocksforce: TryLock. m.t.Fatal("got write lock, expected to not take the lock; generic multicast protocol must take the read or write lock before calling Enabled") } @@ -138,11 +138,11 @@ func (m *mockMulticastGroupProtocol) Enabled() bool { // Precondition: m.mu must be locked. func (m *mockMulticastGroupProtocol) SendReport(groupAddress tcpip.Address) (bool, tcpip.Error) { if m.mu.TryLock() { - m.mu.Unlock() + m.mu.Unlock() // +checklocksforce: TryLock. m.t.Fatalf("got write lock, expected to not take the lock; generic multicast protocol must take the write lock before sending report for %s", groupAddress) } if m.mu.TryRLock() { - m.mu.RUnlock() + m.mu.RUnlock() // +checklocksforce: TryLock. m.t.Fatalf("got read lock, expected to not take the lock; generic multicast protocol must take the write lock before sending report for %s", groupAddress) } @@ -155,11 +155,11 @@ func (m *mockMulticastGroupProtocol) SendReport(groupAddress tcpip.Address) (boo // Precondition: m.mu must be locked. func (m *mockMulticastGroupProtocol) SendLeave(groupAddress tcpip.Address) tcpip.Error { if m.mu.TryLock() { - m.mu.Unlock() + m.mu.Unlock() // +checklocksforce: TryLock. m.t.Fatalf("got write lock, expected to not take the lock; generic multicast protocol must take the write lock before sending leave for %s", groupAddress) } if m.mu.TryRLock() { - m.mu.RUnlock() + m.mu.RUnlock() // +checklocksforce: TryLock. m.t.Fatalf("got read lock, expected to not take the lock; generic multicast protocol must take the write lock before sending leave for %s", groupAddress) } diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go index f5693defe..b1aec5312 100644 --- a/pkg/tcpip/network/ipv6/ipv6.go +++ b/pkg/tcpip/network/ipv6/ipv6.go @@ -344,7 +344,10 @@ func (e *endpoint) onAddressAssignedLocked(addr tcpip.Address) { func (e *endpoint) InvalidateDefaultRouter(rtr tcpip.Address) { e.mu.Lock() defer e.mu.Unlock() - e.mu.ndp.invalidateDefaultRouter(rtr) + + // We represent default routers with a default (off-link) route through the + // router. + e.mu.ndp.invalidateOffLinkRoute(offLinkRoute{dest: header.IPv6EmptySubnet, router: rtr}) } // SetNDPConfigurations implements NDPEndpoint. diff --git a/pkg/tcpip/network/ipv6/ndp.go b/pkg/tcpip/network/ipv6/ndp.go index c44e4ac4e..8837d66d8 100644 --- a/pkg/tcpip/network/ipv6/ndp.go +++ b/pkg/tcpip/network/ipv6/ndp.go @@ -54,6 +54,11 @@ const ( // Advertisements, as a host. defaultDiscoverDefaultRouters = true + // defaultDiscoverMoreSpecificRoutes is the default configuration for + // whether or not to discover more-specific routes from incoming Router + // Advertisements, as a host. + defaultDiscoverMoreSpecificRoutes = true + // defaultDiscoverOnLinkPrefixes is the default configuration for // whether or not to discover on-link prefixes from incoming Router // Advertisements' Prefix Information option, as a host. @@ -78,13 +83,13 @@ const ( // we cannot have a negative delay. minimumMaxRtrSolicitationDelay = 0 - // MaxDiscoveredDefaultRouters is the maximum number of discovered - // default routers. The stack should stop discovering new routers after - // discovering MaxDiscoveredDefaultRouters routers. + // MaxDiscoveredOffLinkRoutes is the maximum number of discovered off-link + // routes. The stack should stop discovering new off-link routes after + // this limit is reached. // // This value MUST be at minimum 2 as per RFC 4861 section 6.3.4, and // SHOULD be more. - MaxDiscoveredDefaultRouters = 10 + MaxDiscoveredOffLinkRoutes = 10 // MaxDiscoveredOnLinkPrefixes is the maximum number of discovered // on-link prefixes. The stack should stop discovering new on-link @@ -352,12 +357,18 @@ type NDPConfigurations struct { // DiscoverDefaultRouters determines whether or not default routers are // discovered from Router Advertisements, as per RFC 4861 section 6. This - // configuration is ignored if HandleRAs is false. + // configuration is ignored if RAs will not be processed (see HandleRAs). DiscoverDefaultRouters bool + // DiscoverMoreSpecificRoutes determines whether or not more specific routes + // are discovered from Router Advertisements, as per RFC 4191. This + // configuration is ignored if RAs will not be processed (see HandleRAs). + DiscoverMoreSpecificRoutes bool + // DiscoverOnLinkPrefixes determines whether or not on-link prefixes are // discovered from Router Advertisements' Prefix Information option, as per - // RFC 4861 section 6. This configuration is ignored if HandleRAs is false. + // RFC 4861 section 6. This configuration is ignored if RAs will not be + // processed (see HandleRAs). DiscoverOnLinkPrefixes bool // AutoGenGlobalAddresses determines whether or not an IPv6 endpoint performs @@ -408,6 +419,7 @@ func DefaultNDPConfigurations() NDPConfigurations { MaxRtrSolicitationDelay: defaultMaxRtrSolicitationDelay, HandleRAs: defaultHandleRAs, DiscoverDefaultRouters: defaultDiscoverDefaultRouters, + DiscoverMoreSpecificRoutes: defaultDiscoverMoreSpecificRoutes, DiscoverOnLinkPrefixes: defaultDiscoverOnLinkPrefixes, AutoGenGlobalAddresses: defaultAutoGenGlobalAddresses, AutoGenTempGlobalAddresses: defaultAutoGenTempGlobalAddresses, @@ -448,6 +460,11 @@ type timer struct { timer tcpip.Timer } +type offLinkRoute struct { + dest tcpip.Subnet + router tcpip.Address +} + // ndpState is the per-Interface NDP state. type ndpState struct { // Do not allow overwriting this state. @@ -462,8 +479,8 @@ type ndpState struct { // The DAD timers to send the next NS message, or resolve the address. dad ip.DAD - // The default routers discovered through Router Advertisements. - defaultRouters map[tcpip.Address]defaultRouterState + // The off-link routes discovered through Router Advertisements. + offLinkRoutes map[offLinkRoute]offLinkRouteState // rtrSolicitTimer is the timer used to send the next router solicitation // message. @@ -491,12 +508,12 @@ type ndpState struct { temporaryAddressDesyncFactor time.Duration } -// defaultRouterState holds data associated with a default router discovered by +// offLinkRouteState holds data associated with an off-link route discovered by // a Router Advertisement (RA). -type defaultRouterState struct { +type offLinkRouteState struct { prf header.NDPRoutePreference - // Job to invalidate the default router. + // Job to invalidate the route. // // Must not be nil. invalidationJob *tcpip.Job @@ -727,38 +744,9 @@ func (ndp *ndpState) handleRA(ip tcpip.Address, ra header.NDPRouterAdvert) { prf = header.MediumRoutePreference } - rtr, ok := ndp.defaultRouters[ip] - rl := ra.RouterLifetime() - switch { - case !ok && rl != 0: - // This is a new default router we are discovering. - // - // Only remember it if we currently know about less than - // MaxDiscoveredDefaultRouters routers. - if len(ndp.defaultRouters) < MaxDiscoveredDefaultRouters { - ndp.rememberDefaultRouter(ip, rl, prf) - } - - case ok && rl != 0: - // This is an already discovered default router. Update - // the invalidation job. - rtr.invalidationJob.Cancel() - rtr.invalidationJob.Schedule(rl) - - if prf != rtr.prf { - rtr.prf = prf - - // Inform the integrator about router preference updates. - ndp.ep.protocol.options.NDPDisp.OnOffLinkRouteUpdated(ndp.ep.nic.ID(), header.IPv6EmptySubnet, ip, prf) - } - - ndp.defaultRouters[ip] = rtr - - case ok && rl == 0: - // We know about the router but it is no longer to be - // used as a default router so invalidate it. - ndp.invalidateDefaultRouter(ip) - } + // We represent default routers with a default (off-link) route through the + // router. + ndp.handleOffLinkRouteDiscovery(offLinkRoute{dest: header.IPv6EmptySubnet, router: ip}, ra.RouterLifetime(), prf) } // TODO(b/141556115): Do (RetransTimer, ReachableTime)) Parameter @@ -810,58 +798,107 @@ func (ndp *ndpState) handleRA(ip tcpip.Address, ra header.NDPRouterAdvert) { if opt.AutonomousAddressConfigurationFlag() { ndp.handleAutonomousPrefixInformation(opt) } + + case header.NDPRouteInformation: + if !ndp.configs.DiscoverMoreSpecificRoutes { + continue + } + + dest, err := opt.Prefix() + if err != nil { + panic(fmt.Sprintf("%T.Prefix(): %s", opt, err)) + } + + prf := opt.RoutePreference() + if prf == header.ReservedRoutePreference { + // As per RFC 4191 section 2.3, + // + // Prf (Route Preference) + // 2-bit signed integer. The Route Preference indicates + // whether to prefer the router associated with this prefix + // over others, when multiple identical prefixes (for + // different routers) have been received. If the Reserved + // (10) value is received, the Route Information Option MUST + // be ignored. + continue + } + + ndp.handleOffLinkRouteDiscovery(offLinkRoute{dest: dest, router: ip}, opt.RouteLifetime(), prf) } // TODO(b/141556115): Do (MTU) Parameter Discovery. } } -// invalidateDefaultRouter invalidates a discovered default router. +// invalidateOffLinkRoute invalidates a discovered off-link route. // // The IPv6 endpoint that ndp belongs to MUST be locked. -func (ndp *ndpState) invalidateDefaultRouter(ip tcpip.Address) { - rtr, ok := ndp.defaultRouters[ip] - - // Is the router still discovered? +func (ndp *ndpState) invalidateOffLinkRoute(route offLinkRoute) { + state, ok := ndp.offLinkRoutes[route] if !ok { - // ...Nope, do nothing further. return } - rtr.invalidationJob.Cancel() - delete(ndp.defaultRouters, ip) + state.invalidationJob.Cancel() + delete(ndp.offLinkRoutes, route) - // Let the integrator know a discovered default router is invalidated. + // Let the integrator know a discovered off-link route is invalidated. if ndpDisp := ndp.ep.protocol.options.NDPDisp; ndpDisp != nil { - ndpDisp.OnOffLinkRouteInvalidated(ndp.ep.nic.ID(), header.IPv6EmptySubnet, ip) + ndpDisp.OnOffLinkRouteInvalidated(ndp.ep.nic.ID(), route.dest, route.router) } } -// rememberDefaultRouter remembers a newly discovered default router with IPv6 -// link-local address ip with lifetime rl. -// -// The router identified by ip MUST NOT already be known by the IPv6 endpoint. +// handleOffLinkRouteDiscovery handles the discovery of an off-link route. // -// The IPv6 endpoint that ndp belongs to MUST be locked. -func (ndp *ndpState) rememberDefaultRouter(ip tcpip.Address, rl time.Duration, prf header.NDPRoutePreference) { +// Precondition: ndp.ep.mu must be locked. +func (ndp *ndpState) handleOffLinkRouteDiscovery(route offLinkRoute, lifetime time.Duration, prf header.NDPRoutePreference) { ndpDisp := ndp.ep.protocol.options.NDPDisp if ndpDisp == nil { return } - // Inform the integrator when we discovered a default router. - ndpDisp.OnOffLinkRouteUpdated(ndp.ep.nic.ID(), header.IPv6EmptySubnet, ip, prf) + state, ok := ndp.offLinkRoutes[route] + switch { + case !ok && lifetime != 0: + // This is a new route we are discovering. + // + // Only remember it if we currently know about less than + // MaxDiscoveredOffLinkRoutes routers. + if len(ndp.offLinkRoutes) < MaxDiscoveredOffLinkRoutes { + // Inform the integrator when we discovered an off-link route. + ndpDisp.OnOffLinkRouteUpdated(ndp.ep.nic.ID(), route.dest, route.router, prf) + + state := offLinkRouteState{ + prf: prf, + invalidationJob: ndp.ep.protocol.stack.NewJob(&ndp.ep.mu, func() { + ndp.invalidateOffLinkRoute(route) + }), + } - state := defaultRouterState{ - prf: prf, - invalidationJob: ndp.ep.protocol.stack.NewJob(&ndp.ep.mu, func() { - ndp.invalidateDefaultRouter(ip) - }), - } + state.invalidationJob.Schedule(lifetime) + + ndp.offLinkRoutes[route] = state + } - state.invalidationJob.Schedule(rl) + case ok && lifetime != 0: + // This is an already discovered off-link route. Update the lifetime. + state.invalidationJob.Cancel() + state.invalidationJob.Schedule(lifetime) - ndp.defaultRouters[ip] = state + if prf != state.prf { + state.prf = prf + + // Inform the integrator about route preference updates. + ndpDisp.OnOffLinkRouteUpdated(ndp.ep.nic.ID(), route.dest, route.router, prf) + } + + ndp.offLinkRoutes[route] = state + + case ok && lifetime == 0: + // The already discovered off-link route is no longer considered valid so we + // invalidate it immediately. + ndp.invalidateOffLinkRoute(route) + } } // rememberOnLinkPrefix remembers a newly discovered on-link prefix with IPv6 @@ -1677,12 +1714,12 @@ func (ndp *ndpState) cleanupState() { panic(fmt.Sprintf("ndp: still have discovered on-link prefixes after cleaning up; found = %d", got)) } - for router := range ndp.defaultRouters { - ndp.invalidateDefaultRouter(router) + for route := range ndp.offLinkRoutes { + ndp.invalidateOffLinkRoute(route) } - if got := len(ndp.defaultRouters); got != 0 { - panic(fmt.Sprintf("ndp: still have discovered default routers after cleaning up; found = %d", got)) + if got := len(ndp.offLinkRoutes); got != 0 { + panic(fmt.Sprintf("ndp: still have discovered off-link routes after cleaning up; found = %d", got)) } ndp.dhcpv6Configuration = 0 @@ -1845,14 +1882,14 @@ func (ndp *ndpState) stopSolicitingRouters() { } func (ndp *ndpState) init(ep *endpoint, dadOptions ip.DADOptions) { - if ndp.defaultRouters != nil { + if ndp.offLinkRoutes != nil { panic("attempted to initialize NDP state twice") } ndp.ep = ep ndp.configs = ep.protocol.options.NDPConfigs ndp.dad.Init(&ndp.ep.mu, ep.protocol.options.DADConfigs, dadOptions) - ndp.defaultRouters = make(map[tcpip.Address]defaultRouterState) + ndp.offLinkRoutes = make(map[offLinkRoute]offLinkRouteState) ndp.onLinkPrefixes = make(map[tcpip.Subnet]onLinkPrefixState) ndp.slaacPrefixes = make(map[tcpip.Subnet]slaacPrefixState) diff --git a/pkg/tcpip/network/ipv6/ndp_test.go b/pkg/tcpip/network/ipv6/ndp_test.go index 130438f3b..f0186c64e 100644 --- a/pkg/tcpip/network/ipv6/ndp_test.go +++ b/pkg/tcpip/network/ipv6/ndp_test.go @@ -93,7 +93,7 @@ func TestStackNDPEndpointInvalidateDefaultRouter(t *testing.T) { ipv6EP := ep.(*endpoint) ipv6EP.mu.Lock() - ipv6EP.mu.ndp.rememberDefaultRouter(lladdr1, time.Hour, header.MediumRoutePreference) + ipv6EP.mu.ndp.handleOffLinkRouteDiscovery(offLinkRoute{dest: header.IPv6EmptySubnet, router: lladdr1}, time.Hour, header.MediumRoutePreference) ipv6EP.mu.Unlock() if ndpDisp.addr != lladdr1 { diff --git a/pkg/tcpip/ports/BUILD b/pkg/tcpip/ports/BUILD index b7f6d52ae..fe98a52af 100644 --- a/pkg/tcpip/ports/BUILD +++ b/pkg/tcpip/ports/BUILD @@ -12,6 +12,7 @@ go_library( deps = [ "//pkg/sync", "//pkg/tcpip", + "//pkg/tcpip/header", ], ) diff --git a/pkg/tcpip/ports/ports.go b/pkg/tcpip/ports/ports.go index 854d6a6ba..fb8ef1ee2 100644 --- a/pkg/tcpip/ports/ports.go +++ b/pkg/tcpip/ports/ports.go @@ -23,6 +23,7 @@ import ( "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/header" ) const ( @@ -122,7 +123,7 @@ type deviceToDest map[tcpip.NICID]destToCounter // If either of the port reuse flags is enabled on any of the nodes, all nodes // sharing a port must share at least one reuse flag. This matches Linux's // behavior. -func (dd deviceToDest) isAvailable(res Reservation) bool { +func (dd deviceToDest) isAvailable(res Reservation, portSpecified bool) bool { flagBits := res.Flags.Bits() if res.BindToDevice == 0 { intersection := FlagMask @@ -138,6 +139,9 @@ func (dd deviceToDest) isAvailable(res Reservation) bool { return false } } + if !portSpecified && res.Transport == header.TCPProtocolNumber { + return false + } return true } @@ -146,16 +150,26 @@ func (dd deviceToDest) isAvailable(res Reservation) bool { if dests, ok := dd[0]; ok { var count int intersection, count = dests.intersectionFlags(res) - if count > 0 && intersection&flagBits == 0 { - return false + if count > 0 { + if intersection&flagBits == 0 { + return false + } + if !portSpecified && res.Transport == header.TCPProtocolNumber { + return false + } } } if dests, ok := dd[res.BindToDevice]; ok { flags, count := dests.intersectionFlags(res) intersection &= flags - if count > 0 && intersection&flagBits == 0 { - return false + if count > 0 { + if intersection&flagBits == 0 { + return false + } + if !portSpecified && res.Transport == header.TCPProtocolNumber { + return false + } } } @@ -168,12 +182,12 @@ type addrToDevice map[tcpip.Address]deviceToDest // isAvailable checks whether an IP address is available to bind to. If the // address is the "any" address, check all other addresses. Otherwise, just // check against the "any" address and the provided address. -func (ad addrToDevice) isAvailable(res Reservation) bool { +func (ad addrToDevice) isAvailable(res Reservation, portSpecified bool) bool { if res.Addr == anyIPAddress { // If binding to the "any" address then check that there are no // conflicts with all addresses. for _, devices := range ad { - if !devices.isAvailable(res) { + if !devices.isAvailable(res, portSpecified) { return false } } @@ -182,14 +196,14 @@ func (ad addrToDevice) isAvailable(res Reservation) bool { // Check that there is no conflict with the "any" address. if devices, ok := ad[anyIPAddress]; ok { - if !devices.isAvailable(res) { + if !devices.isAvailable(res, portSpecified) { return false } } // Check that this is no conflict with the provided address. if devices, ok := ad[res.Addr]; ok { - if !devices.isAvailable(res) { + if !devices.isAvailable(res, portSpecified) { return false } } @@ -310,7 +324,7 @@ func (pm *PortManager) ReservePort(rng *rand.Rand, res Reservation, testPort Por // If a port is specified, just try to reserve it for all network // protocols. if res.Port != 0 { - if !pm.reserveSpecificPortLocked(res) { + if !pm.reserveSpecificPortLocked(res, true /* portSpecified */) { return 0, &tcpip.ErrPortInUse{} } if testPort != nil { @@ -330,7 +344,7 @@ func (pm *PortManager) ReservePort(rng *rand.Rand, res Reservation, testPort Por // A port wasn't specified, so try to find one. return pm.PickEphemeralPort(rng, func(p uint16) (bool, tcpip.Error) { res.Port = p - if !pm.reserveSpecificPortLocked(res) { + if !pm.reserveSpecificPortLocked(res, false /* portSpecified */) { return false, nil } if testPort != nil { @@ -350,12 +364,12 @@ func (pm *PortManager) ReservePort(rng *rand.Rand, res Reservation, testPort Por // reserveSpecificPortLocked tries to reserve the given port on all given // protocols. -func (pm *PortManager) reserveSpecificPortLocked(res Reservation) bool { +func (pm *PortManager) reserveSpecificPortLocked(res Reservation, portSpecified bool) bool { // Make sure the port is available. for _, network := range res.Networks { desc := portDescriptor{network, res.Transport, res.Port} if addrs, ok := pm.allocatedPorts[desc]; ok { - if !addrs.isAvailable(res) { + if !addrs.isAvailable(res, portSpecified) { return false } } diff --git a/pkg/tcpip/sample/tun_tcp_connect/main.go b/pkg/tcpip/sample/tun_tcp_connect/main.go index b9a24ff56..009cab643 100644 --- a/pkg/tcpip/sample/tun_tcp_connect/main.go +++ b/pkg/tcpip/sample/tun_tcp_connect/main.go @@ -12,6 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. +//go:build linux // +build linux // This sample creates a stack with TCP and IPv4 protocols on top of a TUN diff --git a/pkg/tcpip/sample/tun_tcp_echo/main.go b/pkg/tcpip/sample/tun_tcp_echo/main.go index ef1bfc186..c10b19aa0 100644 --- a/pkg/tcpip/sample/tun_tcp_echo/main.go +++ b/pkg/tcpip/sample/tun_tcp_echo/main.go @@ -12,6 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. +//go:build linux // +build linux // This sample creates a stack with TCP and IPv4 protocols on top of a TUN diff --git a/pkg/tcpip/socketops.go b/pkg/tcpip/socketops.go index 0ea85f9ed..5642c86f8 100644 --- a/pkg/tcpip/socketops.go +++ b/pkg/tcpip/socketops.go @@ -15,17 +15,11 @@ package tcpip import ( - "math" "sync/atomic" - "gvisor.dev/gvisor/pkg/atomicbitops" "gvisor.dev/gvisor/pkg/sync" ) -// PacketOverheadFactor is used to multiply the value provided by the user on a -// SetSockOpt for setting the send/receive buffer sizes sockets. -const PacketOverheadFactor = 2 - // SocketOptionsHandler holds methods that help define endpoint specific // behavior for socket level socket options. These must be implemented by // endpoints to get notified when socket level options are set. @@ -60,7 +54,7 @@ type SocketOptionsHandler interface { // buffer size. It also returns the newly set value. OnSetSendBufferSize(v int64) (newSz int64) - // OnSetReceiveBufferSize is invoked to set the SO_RCVBUFSIZE. + // OnSetReceiveBufferSize is invoked by SO_RCVBUF and SO_RCVBUFFORCE. OnSetReceiveBufferSize(v, oldSz int64) (newSz int64) } @@ -213,16 +207,24 @@ type SocketOptions struct { // will not change. getSendBufferLimits GetSendBufferLimits `state:"manual"` + // sendBufSizeMu protects sendBufferSize and calls to + // handler.OnSetSendBufferSize. + sendBufSizeMu sync.Mutex `state:"nosave"` + // sendBufferSize determines the send buffer size for this socket. - sendBufferSize atomicbitops.AlignedAtomicInt64 + sendBufferSize int64 // getReceiveBufferLimits provides the handler to get the min, default and // max size for receive buffer. It is initialized at the creation time and // will not change. getReceiveBufferLimits GetReceiveBufferLimits `state:"manual"` + // receiveBufSizeMu protects receiveBufferSize and calls to + // handler.OnSetReceiveBufferSize. + receiveBufSizeMu sync.Mutex `state:"nosave"` + // receiveBufferSize determines the receive buffer size for this socket. - receiveBufferSize atomicbitops.AlignedAtomicInt64 + receiveBufferSize int64 // mu protects the access to the below fields. mu sync.Mutex `state:"nosave"` @@ -612,81 +614,52 @@ func (so *SocketOptions) SetBindToDevice(bindToDevice int32) Error { return nil } +// SendBufferLimits returns the [min, max) range of allowable send buffer +// sizes. +func (so *SocketOptions) SendBufferLimits() (min, max int64) { + limits := so.getSendBufferLimits(so.stackHandler) + return int64(limits.Min), int64(limits.Max) +} + // GetSendBufferSize gets value for SO_SNDBUF option. func (so *SocketOptions) GetSendBufferSize() int64 { - return so.sendBufferSize.Load() + so.sendBufSizeMu.Lock() + defer so.sendBufSizeMu.Unlock() + return so.sendBufferSize } // SetSendBufferSize sets value for SO_SNDBUF option. notify indicates if the // stack handler should be invoked to set the send buffer size. func (so *SocketOptions) SetSendBufferSize(sendBufferSize int64, notify bool) { - v := sendBufferSize - - if !notify { - so.sendBufferSize.Store(v) - return - } - - // Make sure the send buffer size is within the min and max - // allowed. - ss := so.getSendBufferLimits(so.stackHandler) - min := int64(ss.Min) - max := int64(ss.Max) - // Validate the send buffer size with min and max values. - // Multiply it by factor of 2. - if v > max { - v = max - } - - if v < math.MaxInt32/PacketOverheadFactor { - v *= PacketOverheadFactor - if v < min { - v = min - } - } else { - v = math.MaxInt32 + so.sendBufSizeMu.Lock() + defer so.sendBufSizeMu.Unlock() + if notify { + sendBufferSize = so.handler.OnSetSendBufferSize(sendBufferSize) } + so.sendBufferSize = sendBufferSize +} - // Notify endpoint about change in buffer size. - newSz := so.handler.OnSetSendBufferSize(v) - so.sendBufferSize.Store(newSz) +// ReceiveBufferLimits returns the [min, max) range of allowable receive buffer +// sizes. +func (so *SocketOptions) ReceiveBufferLimits() (min, max int64) { + limits := so.getReceiveBufferLimits(so.stackHandler) + return int64(limits.Min), int64(limits.Max) } // GetReceiveBufferSize gets value for SO_RCVBUF option. func (so *SocketOptions) GetReceiveBufferSize() int64 { - return so.receiveBufferSize.Load() + so.receiveBufSizeMu.Lock() + defer so.receiveBufSizeMu.Unlock() + return so.receiveBufferSize } -// SetReceiveBufferSize sets value for SO_RCVBUF option. +// SetReceiveBufferSize sets the value of the SO_RCVBUF option, optionally +// notifying the owning endpoint. func (so *SocketOptions) SetReceiveBufferSize(receiveBufferSize int64, notify bool) { - if !notify { - so.receiveBufferSize.Store(receiveBufferSize) - return - } - - // Make sure the send buffer size is within the min and max - // allowed. - v := receiveBufferSize - ss := so.getReceiveBufferLimits(so.stackHandler) - min := int64(ss.Min) - max := int64(ss.Max) - // Validate the send buffer size with min and max values. - if v > max { - v = max - } - - // Multiply it by factor of 2. - if v < math.MaxInt32/PacketOverheadFactor { - v *= PacketOverheadFactor - if v < min { - v = min - } - } else { - v = math.MaxInt32 + so.receiveBufSizeMu.Lock() + defer so.receiveBufSizeMu.Unlock() + if notify { + receiveBufferSize = so.handler.OnSetReceiveBufferSize(receiveBufferSize, so.receiveBufferSize) } - - oldSz := so.receiveBufferSize.Load() - // Notify endpoint about change in buffer size. - newSz := so.handler.OnSetReceiveBufferSize(v, oldSz) - so.receiveBufferSize.Store(newSz) + so.receiveBufferSize = receiveBufferSize } diff --git a/pkg/tcpip/stack/addressable_endpoint_state.go b/pkg/tcpip/stack/addressable_endpoint_state.go index ce9cebdaa..ae0bb4ace 100644 --- a/pkg/tcpip/stack/addressable_endpoint_state.go +++ b/pkg/tcpip/stack/addressable_endpoint_state.go @@ -249,7 +249,7 @@ func (a *AddressableEndpointState) addAndAcquireAddressLocked(addr tcpip.Address // or we are adding a new temporary or permanent address. // // The address MUST be write locked at this point. - defer addrState.mu.Unlock() + defer addrState.mu.Unlock() // +checklocksforce if permanent { if addrState.mu.kind.IsPermanent() { diff --git a/pkg/tcpip/stack/conntrack.go b/pkg/tcpip/stack/conntrack.go index 18e0d4374..068dab7ce 100644 --- a/pkg/tcpip/stack/conntrack.go +++ b/pkg/tcpip/stack/conntrack.go @@ -363,7 +363,7 @@ func (ct *ConnTrack) insertConn(conn *conn) { // Unlocking can happen in any order. ct.buckets[tupleBucket].mu.Unlock() if tupleBucket != replyBucket { - ct.buckets[replyBucket].mu.Unlock() + ct.buckets[replyBucket].mu.Unlock() // +checklocksforce } } @@ -405,16 +405,23 @@ func (ct *ConnTrack) handlePacket(pkt *PacketBuffer, hook Hook, r *Route) bool { // validated if checksum offloading is off. It may require IP defrag if the // packets are fragmented. + var newAddr tcpip.Address + var newPort uint16 + + updateSRCFields := false + switch hook { case Prerouting, Output: if conn.manip == manipDestination { switch dir { case dirOriginal: - tcpHeader.SetDestinationPort(conn.reply.srcPort) - netHeader.SetDestinationAddress(conn.reply.srcAddr) + newPort = conn.reply.srcPort + newAddr = conn.reply.srcAddr case dirReply: - tcpHeader.SetSourcePort(conn.original.dstPort) - netHeader.SetSourceAddress(conn.original.dstAddr) + newPort = conn.original.dstPort + newAddr = conn.original.dstAddr + + updateSRCFields = true } pkt.NatDone = true } @@ -422,11 +429,13 @@ func (ct *ConnTrack) handlePacket(pkt *PacketBuffer, hook Hook, r *Route) bool { if conn.manip == manipSource { switch dir { case dirOriginal: - tcpHeader.SetSourcePort(conn.reply.dstPort) - netHeader.SetSourceAddress(conn.reply.dstAddr) + newPort = conn.reply.dstPort + newAddr = conn.reply.dstAddr + + updateSRCFields = true case dirReply: - tcpHeader.SetDestinationPort(conn.original.srcPort) - netHeader.SetDestinationAddress(conn.original.srcAddr) + newPort = conn.original.srcPort + newAddr = conn.original.srcAddr } pkt.NatDone = true } @@ -437,29 +446,31 @@ func (ct *ConnTrack) handlePacket(pkt *PacketBuffer, hook Hook, r *Route) bool { return false } + fullChecksum := false + updatePseudoHeader := false switch hook { case Prerouting, Input: case Output, Postrouting: // Calculate the TCP checksum and set it. - tcpHeader.SetChecksum(0) - length := uint16(len(tcpHeader) + pkt.Data().Size()) - xsum := header.PseudoHeaderChecksum(header.TCPProtocolNumber, netHeader.SourceAddress(), netHeader.DestinationAddress(), length) if pkt.GSOOptions.Type != GSONone && pkt.GSOOptions.NeedsCsum { - tcpHeader.SetChecksum(xsum) + updatePseudoHeader = true } else if r.RequiresTXTransportChecksum() { - xsum = header.ChecksumCombine(xsum, pkt.Data().AsRange().Checksum()) - tcpHeader.SetChecksum(^tcpHeader.CalculateChecksum(xsum)) + fullChecksum = true + updatePseudoHeader = true } default: panic(fmt.Sprintf("unrecognized hook = %s", hook)) } - // After modification, IPv4 packets need a valid checksum. - if pkt.NetworkProtocolNumber == header.IPv4ProtocolNumber { - netHeader := header.IPv4(pkt.NetworkHeader().View()) - netHeader.SetChecksum(0) - netHeader.SetChecksum(^netHeader.CalculateChecksum()) - } + rewritePacket( + netHeader, + tcpHeader, + updateSRCFields, + fullChecksum, + updatePseudoHeader, + newPort, + newAddr, + ) // Update the state of tcb. conn.mu.Lock() @@ -615,7 +626,7 @@ func (ct *ConnTrack) reapTupleLocked(tuple *tuple, bucket int, now time.Time) bo // Don't re-unlock if both tuples are in the same bucket. if differentBuckets { - ct.buckets[replyBucket].mu.Unlock() + ct.buckets[replyBucket].mu.Unlock() // +checklocksforce } return true diff --git a/pkg/tcpip/stack/iptables_targets.go b/pkg/tcpip/stack/iptables_targets.go index 91e266de8..96cc899bb 100644 --- a/pkg/tcpip/stack/iptables_targets.go +++ b/pkg/tcpip/stack/iptables_targets.go @@ -133,29 +133,23 @@ func (rt *RedirectTarget) Action(pkt *PacketBuffer, ct *ConnTrack, hook Hook, r switch protocol := pkt.TransportProtocolNumber; protocol { case header.UDPProtocolNumber: udpHeader := header.UDP(pkt.TransportHeader().View()) - udpHeader.SetDestinationPort(rt.Port) - // Calculate UDP checksum and set it. if hook == Output { - udpHeader.SetChecksum(0) - netHeader := pkt.Network() - netHeader.SetDestinationAddress(address) - // Only calculate the checksum if offloading isn't supported. - if r.RequiresTXTransportChecksum() { - length := uint16(pkt.Size()) - uint16(len(pkt.NetworkHeader().View())) - xsum := header.PseudoHeaderChecksum(protocol, netHeader.SourceAddress(), netHeader.DestinationAddress(), length) - xsum = header.ChecksumCombine(xsum, pkt.Data().AsRange().Checksum()) - udpHeader.SetChecksum(^udpHeader.CalculateChecksum(xsum)) - } + requiresChecksum := r.RequiresTXTransportChecksum() + rewritePacket( + pkt.Network(), + udpHeader, + false, /* updateSRCFields */ + requiresChecksum, + requiresChecksum, + rt.Port, + address, + ) + } else { + udpHeader.SetDestinationPort(rt.Port) } - // After modification, IPv4 packets need a valid checksum. - if pkt.NetworkProtocolNumber == header.IPv4ProtocolNumber { - netHeader := header.IPv4(pkt.NetworkHeader().View()) - netHeader.SetChecksum(0) - netHeader.SetChecksum(^netHeader.CalculateChecksum()) - } pkt.NatDone = true case header.TCPProtocolNumber: if ct == nil { @@ -214,26 +208,18 @@ func (st *SNATTarget) Action(pkt *PacketBuffer, ct *ConnTrack, hook Hook, r *Rou switch protocol := pkt.TransportProtocolNumber; protocol { case header.UDPProtocolNumber: - udpHeader := header.UDP(pkt.TransportHeader().View()) - udpHeader.SetChecksum(0) - udpHeader.SetSourcePort(st.Port) - netHeader := pkt.Network() - netHeader.SetSourceAddress(st.Addr) - // Only calculate the checksum if offloading isn't supported. - if r.RequiresTXTransportChecksum() { - length := uint16(pkt.Size()) - uint16(len(pkt.NetworkHeader().View())) - xsum := header.PseudoHeaderChecksum(protocol, netHeader.SourceAddress(), netHeader.DestinationAddress(), length) - xsum = header.ChecksumCombine(xsum, pkt.Data().AsRange().Checksum()) - udpHeader.SetChecksum(^udpHeader.CalculateChecksum(xsum)) - } + requiresChecksum := r.RequiresTXTransportChecksum() + rewritePacket( + pkt.Network(), + header.UDP(pkt.TransportHeader().View()), + true, /* updateSRCFields */ + requiresChecksum, + requiresChecksum, + st.Port, + st.Addr, + ) - // After modification, IPv4 packets need a valid checksum. - if pkt.NetworkProtocolNumber == header.IPv4ProtocolNumber { - netHeader := header.IPv4(pkt.NetworkHeader().View()) - netHeader.SetChecksum(0) - netHeader.SetChecksum(^netHeader.CalculateChecksum()) - } pkt.NatDone = true case header.TCPProtocolNumber: if ct == nil { @@ -252,3 +238,42 @@ func (st *SNATTarget) Action(pkt *PacketBuffer, ct *ConnTrack, hook Hook, r *Rou return RuleAccept, 0 } + +func rewritePacket(n header.Network, t header.ChecksummableTransport, updateSRCFields, fullChecksum, updatePseudoHeader bool, newPort uint16, newAddr tcpip.Address) { + if updateSRCFields { + if fullChecksum { + t.SetSourcePortWithChecksumUpdate(newPort) + } else { + t.SetSourcePort(newPort) + } + } else { + if fullChecksum { + t.SetDestinationPortWithChecksumUpdate(newPort) + } else { + t.SetDestinationPort(newPort) + } + } + + if updatePseudoHeader { + var oldAddr tcpip.Address + if updateSRCFields { + oldAddr = n.SourceAddress() + } else { + oldAddr = n.DestinationAddress() + } + + t.UpdateChecksumPseudoHeaderAddress(oldAddr, newAddr, fullChecksum) + } + + if checksummableNetHeader, ok := n.(header.ChecksummableNetwork); ok { + if updateSRCFields { + checksummableNetHeader.SetSourceAddressWithChecksumUpdate(newAddr) + } else { + checksummableNetHeader.SetDestinationAddressWithChecksumUpdate(newAddr) + } + } else if updateSRCFields { + n.SetSourceAddress(newAddr) + } else { + n.SetDestinationAddress(newAddr) + } +} diff --git a/pkg/tcpip/stack/ndp_test.go b/pkg/tcpip/stack/ndp_test.go index bd512d312..4d5431da1 100644 --- a/pkg/tcpip/stack/ndp_test.go +++ b/pkg/tcpip/stack/ndp_test.go @@ -1152,6 +1152,39 @@ func raBufWithPI(ip tcpip.Address, rl uint16, prefix tcpip.AddressWithPrefix, on }) } +// raBufWithRIO returns a valid NDP Router Advertisement with a single Route +// Information option. +// +// All fields in the RA will be zero except the RIO option. +func raBufWithRIO(t *testing.T, ip tcpip.Address, prefix tcpip.AddressWithPrefix, lifetimeSeconds uint32, prf header.NDPRoutePreference) *stack.PacketBuffer { + // buf will hold the route information option after the Type and Length + // fields. + // + // 2.3. Route Information Option + // + // 0 1 2 3 + // 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 + // +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + // | Type | Length | Prefix Length |Resvd|Prf|Resvd| + // +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + // | Route Lifetime | + // +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + // | Prefix (Variable Length) | + // . . + // . . + // +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + var buf [22]byte + buf[0] = uint8(prefix.PrefixLen) + buf[1] = byte(prf) << 3 + binary.BigEndian.PutUint32(buf[2:], lifetimeSeconds) + if n := copy(buf[6:], prefix.Address); n != len(prefix.Address) { + t.Fatalf("got copy(...) = %d, want = %d", n, len(prefix.Address)) + } + return raBufWithOpts(ip, 0 /* router lifetime */, header.NDPOptionsSerializer{ + header.NDPRouteInformation(buf[:]), + }) +} + func TestDynamicConfigurationsDisabled(t *testing.T) { const ( nicID = 1 @@ -1308,8 +1341,8 @@ func boolToUint64(v bool) uint64 { return 0 } -func checkOffLinkRouteEvent(e ndpOffLinkRouteEvent, nicID tcpip.NICID, router tcpip.Address, prf header.NDPRoutePreference, updated bool) string { - return cmp.Diff(ndpOffLinkRouteEvent{nicID: nicID, subnet: header.IPv6EmptySubnet, router: router, prf: prf, updated: updated}, e, cmp.AllowUnexported(e)) +func checkOffLinkRouteEvent(e ndpOffLinkRouteEvent, nicID tcpip.NICID, subnet tcpip.Subnet, router tcpip.Address, prf header.NDPRoutePreference, updated bool) string { + return cmp.Diff(ndpOffLinkRouteEvent{nicID: nicID, subnet: subnet, router: router, prf: prf, updated: updated}, e, cmp.AllowUnexported(e)) } func testWithRAs(t *testing.T, f func(*testing.T, ipv6.HandleRAsConfiguration, bool)) { @@ -1342,126 +1375,171 @@ func testWithRAs(t *testing.T, f func(*testing.T, ipv6.HandleRAsConfiguration, b } } -func TestRouterDiscovery(t *testing.T) { +func TestOffLinkRouteDiscovery(t *testing.T) { const nicID = 1 - testWithRAs(t, func(t *testing.T, handleRAs ipv6.HandleRAsConfiguration, forwarding bool) { - ndpDisp := ndpDispatcher{ - offLinkRouteC: make(chan ndpOffLinkRouteEvent, 1), - } - e := channel.New(0, 1280, linkAddr1) - clock := faketime.NewManualClock() - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ - NDPConfigs: ipv6.NDPConfigurations{ - HandleRAs: handleRAs, - DiscoverDefaultRouters: true, - }, - NDPDisp: &ndpDisp, - })}, - Clock: clock, - }) + moreSpecificPrefix := tcpip.AddressWithPrefix{Address: testutil.MustParse6("a00::"), PrefixLen: 16} + tests := []struct { + name string - expectOffLinkRouteEvent := func(addr tcpip.Address, prf header.NDPRoutePreference, updated bool) { - t.Helper() + discoverDefaultRouters bool + discoverMoreSpecificRoutes bool - select { - case e := <-ndpDisp.offLinkRouteC: - if diff := checkOffLinkRouteEvent(e, nicID, addr, prf, updated); diff != "" { - t.Errorf("off-link route event mismatch (-want +got):\n%s", diff) + dest tcpip.Subnet + ra func(*testing.T, tcpip.Address, uint16, header.NDPRoutePreference) *stack.PacketBuffer + }{ + { + name: "Default router discovery", + discoverDefaultRouters: true, + discoverMoreSpecificRoutes: false, + dest: header.IPv6EmptySubnet, + ra: func(_ *testing.T, router tcpip.Address, lifetimeSeconds uint16, prf header.NDPRoutePreference) *stack.PacketBuffer { + return raBufWithPrf(router, lifetimeSeconds, prf) + }, + }, + { + name: "More-specific route discovery", + discoverDefaultRouters: false, + discoverMoreSpecificRoutes: true, + dest: moreSpecificPrefix.Subnet(), + ra: func(t *testing.T, router tcpip.Address, lifetimeSeconds uint16, prf header.NDPRoutePreference) *stack.PacketBuffer { + return raBufWithRIO(t, router, moreSpecificPrefix, uint32(lifetimeSeconds), prf) + }, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + testWithRAs(t, func(t *testing.T, handleRAs ipv6.HandleRAsConfiguration, forwarding bool) { + ndpDisp := ndpDispatcher{ + offLinkRouteC: make(chan ndpOffLinkRouteEvent, 1), } - default: - t.Fatal("expected router discovery event") - } - } + e := channel.New(0, 1280, linkAddr1) + clock := faketime.NewManualClock() + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ + NDPConfigs: ipv6.NDPConfigurations{ + HandleRAs: handleRAs, + DiscoverDefaultRouters: test.discoverDefaultRouters, + DiscoverMoreSpecificRoutes: test.discoverMoreSpecificRoutes, + }, + NDPDisp: &ndpDisp, + })}, + Clock: clock, + }) - expectAsyncOffLinkRouteInvalidationEvent := func(addr tcpip.Address, timeout time.Duration) { - t.Helper() + expectOffLinkRouteEvent := func(addr tcpip.Address, prf header.NDPRoutePreference, updated bool) { + t.Helper() - clock.Advance(timeout) - select { - case e := <-ndpDisp.offLinkRouteC: - var prf header.NDPRoutePreference - if diff := checkOffLinkRouteEvent(e, nicID, addr, prf, false); diff != "" { - t.Errorf("off-link route event mismatch (-want +got):\n%s", diff) + select { + case e := <-ndpDisp.offLinkRouteC: + if diff := checkOffLinkRouteEvent(e, nicID, test.dest, addr, prf, updated); diff != "" { + t.Errorf("off-link route event mismatch (-want +got):\n%s", diff) + } + default: + t.Fatal("expected router discovery event") + } } - default: - t.Fatal("timed out waiting for router discovery event") - } - } - if err := s.SetForwardingDefaultAndAllNICs(ipv6.ProtocolNumber, forwarding); err != nil { - t.Fatalf("SetForwardingDefaultAndAllNICs(%d, %t): %s", ipv6.ProtocolNumber, forwarding, err) - } + expectAsyncOffLinkRouteInvalidationEvent := func(addr tcpip.Address, timeout time.Duration) { + t.Helper() - if err := s.CreateNIC(nicID, e); err != nil { - t.Fatalf("CreateNIC(%d, _): %s", nicID, err) - } + clock.Advance(timeout) + select { + case e := <-ndpDisp.offLinkRouteC: + var prf header.NDPRoutePreference + if diff := checkOffLinkRouteEvent(e, nicID, test.dest, addr, prf, false); diff != "" { + t.Errorf("off-link route event mismatch (-want +got):\n%s", diff) + } + default: + t.Fatal("timed out waiting for router discovery event") + } + } - // Rx an RA from lladdr2 with zero lifetime. It should not be - // remembered. - e.InjectInbound(header.IPv6ProtocolNumber, raBufSimple(llAddr2, 0)) - select { - case <-ndpDisp.offLinkRouteC: - t.Fatal("unexpectedly updated an off-link route with 0 lifetime") - default: - } + if err := s.SetForwardingDefaultAndAllNICs(ipv6.ProtocolNumber, forwarding); err != nil { + t.Fatalf("SetForwardingDefaultAndAllNICs(%d, %t): %s", ipv6.ProtocolNumber, forwarding, err) + } - // Rx an RA from lladdr2 with a huge lifetime and reserved preference value - // (which should be interpreted as the default (medium) preference value). - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPrf(llAddr2, 1000, header.ReservedRoutePreference)) - expectOffLinkRouteEvent(llAddr2, header.MediumRoutePreference, true) - - // Rx an RA from another router (lladdr3) with non-zero lifetime and - // non-default preference value. - const l3LifetimeSeconds = 6 - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPrf(llAddr3, l3LifetimeSeconds, header.HighRoutePreference)) - expectOffLinkRouteEvent(llAddr3, header.HighRoutePreference, true) - - // Rx an RA from lladdr2 with lesser lifetime and default (medium) - // preference value. - const l2LifetimeSeconds = 2 - e.InjectInbound(header.IPv6ProtocolNumber, raBufSimple(llAddr2, l2LifetimeSeconds)) - select { - case <-ndpDisp.offLinkRouteC: - t.Fatal("should not receive a off-link route event when updating lifetimes for known routers") - default: - } + if err := s.CreateNIC(nicID, e); err != nil { + t.Fatalf("CreateNIC(%d, _): %s", nicID, err) + } - // Rx an RA from lladdr2 with a different preference. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPrf(llAddr2, l2LifetimeSeconds, header.LowRoutePreference)) - expectOffLinkRouteEvent(llAddr2, header.LowRoutePreference, true) - - // Wait for lladdr2's router invalidation job to execute. The lifetime - // of the router should have been updated to the most recent (smaller) - // lifetime. - // - // Wait for the normal lifetime plus an extra bit for the - // router to get invalidated. If we don't get an invalidation - // event after this time, then something is wrong. - expectAsyncOffLinkRouteInvalidationEvent(llAddr2, l2LifetimeSeconds*time.Second) - - // Rx an RA from lladdr2 with huge lifetime. - e.InjectInbound(header.IPv6ProtocolNumber, raBufSimple(llAddr2, 1000)) - expectOffLinkRouteEvent(llAddr2, header.MediumRoutePreference, true) - - // Rx an RA from lladdr2 with zero lifetime. It should be invalidated. - e.InjectInbound(header.IPv6ProtocolNumber, raBufSimple(llAddr2, 0)) - expectOffLinkRouteEvent(llAddr2, header.MediumRoutePreference, false) - - // Wait for lladdr3's router invalidation job to execute. The lifetime - // of the router should have been updated to the most recent (smaller) - // lifetime. - // - // Wait for the normal lifetime plus an extra bit for the - // router to get invalidated. If we don't get an invalidation - // event after this time, then something is wrong. - expectAsyncOffLinkRouteInvalidationEvent(llAddr3, l3LifetimeSeconds*time.Second) - }) + // Rx an RA from lladdr2 with zero lifetime. It should not be + // remembered. + e.InjectInbound(header.IPv6ProtocolNumber, test.ra(t, llAddr2, 0, header.MediumRoutePreference)) + select { + case <-ndpDisp.offLinkRouteC: + t.Fatal("unexpectedly updated an off-link route with 0 lifetime") + default: + } + + // Discover an off-link route through llAddr2. + e.InjectInbound(header.IPv6ProtocolNumber, test.ra(t, llAddr2, 1000, header.ReservedRoutePreference)) + if test.discoverMoreSpecificRoutes { + // The reserved value is considered invalid with more-specific route + // discovery so we inject the same packet but with the default + // (medium) preference value. + select { + case <-ndpDisp.offLinkRouteC: + t.Fatal("unexpectedly updated an off-link route with a reserved preference value") + default: + } + e.InjectInbound(header.IPv6ProtocolNumber, test.ra(t, llAddr2, 1000, header.MediumRoutePreference)) + } + expectOffLinkRouteEvent(llAddr2, header.MediumRoutePreference, true) + + // Rx an RA from another router (lladdr3) with non-zero lifetime and + // non-default preference value. + const l3LifetimeSeconds = 6 + e.InjectInbound(header.IPv6ProtocolNumber, test.ra(t, llAddr3, l3LifetimeSeconds, header.HighRoutePreference)) + expectOffLinkRouteEvent(llAddr3, header.HighRoutePreference, true) + + // Rx an RA from lladdr2 with lesser lifetime and default (medium) + // preference value. + const l2LifetimeSeconds = 2 + e.InjectInbound(header.IPv6ProtocolNumber, test.ra(t, llAddr2, l2LifetimeSeconds, header.MediumRoutePreference)) + select { + case <-ndpDisp.offLinkRouteC: + t.Fatal("should not receive a off-link route event when updating lifetimes for known routers") + default: + } + + // Rx an RA from lladdr2 with a different preference. + e.InjectInbound(header.IPv6ProtocolNumber, test.ra(t, llAddr2, l2LifetimeSeconds, header.LowRoutePreference)) + expectOffLinkRouteEvent(llAddr2, header.LowRoutePreference, true) + + // Wait for lladdr2's router invalidation job to execute. The lifetime + // of the router should have been updated to the most recent (smaller) + // lifetime. + // + // Wait for the normal lifetime plus an extra bit for the + // router to get invalidated. If we don't get an invalidation + // event after this time, then something is wrong. + expectAsyncOffLinkRouteInvalidationEvent(llAddr2, l2LifetimeSeconds*time.Second) + + // Rx an RA from lladdr2 with huge lifetime. + e.InjectInbound(header.IPv6ProtocolNumber, test.ra(t, llAddr2, 1000, header.MediumRoutePreference)) + expectOffLinkRouteEvent(llAddr2, header.MediumRoutePreference, true) + + // Rx an RA from lladdr2 with zero lifetime. It should be invalidated. + e.InjectInbound(header.IPv6ProtocolNumber, test.ra(t, llAddr2, 0, header.MediumRoutePreference)) + expectOffLinkRouteEvent(llAddr2, header.MediumRoutePreference, false) + + // Wait for lladdr3's router invalidation job to execute. The lifetime + // of the router should have been updated to the most recent (smaller) + // lifetime. + // + // Wait for the normal lifetime plus an extra bit for the + // router to get invalidated. If we don't get an invalidation + // event after this time, then something is wrong. + expectAsyncOffLinkRouteInvalidationEvent(llAddr3, l3LifetimeSeconds*time.Second) + }) + }) + } } // TestRouterDiscoveryMaxRouters tests that only -// ipv6.MaxDiscoveredDefaultRouters discovered routers are remembered. +// ipv6.MaxDiscoveredOffLinkRoutes discovered routers are remembered. func TestRouterDiscoveryMaxRouters(t *testing.T) { const nicID = 1 @@ -1484,17 +1562,17 @@ func TestRouterDiscoveryMaxRouters(t *testing.T) { } // Receive an RA from 2 more than the max number of discovered routers. - for i := 1; i <= ipv6.MaxDiscoveredDefaultRouters+2; i++ { + for i := 1; i <= ipv6.MaxDiscoveredOffLinkRoutes+2; i++ { linkAddr := []byte{2, 2, 3, 4, 5, 0} linkAddr[5] = byte(i) llAddr := header.LinkLocalAddr(tcpip.LinkAddress(linkAddr)) e.InjectInbound(header.IPv6ProtocolNumber, raBufSimple(llAddr, 5)) - if i <= ipv6.MaxDiscoveredDefaultRouters { + if i <= ipv6.MaxDiscoveredOffLinkRoutes { select { case e := <-ndpDisp.offLinkRouteC: - if diff := checkOffLinkRouteEvent(e, nicID, llAddr, header.MediumRoutePreference, true); diff != "" { + if diff := checkOffLinkRouteEvent(e, nicID, header.IPv6EmptySubnet, llAddr, header.MediumRoutePreference, true); diff != "" { t.Errorf("off-link route event mismatch (-want +got):\n%s", diff) } default: @@ -4583,7 +4661,7 @@ func TestNoCleanupNDPStateWhenForwardingEnabled(t *testing.T) { ) select { case e := <-ndpDisp.offLinkRouteC: - if diff := checkOffLinkRouteEvent(e, nicID, llAddr3, header.MediumRoutePreference, true /* discovered */); diff != "" { + if diff := checkOffLinkRouteEvent(e, nicID, header.IPv6EmptySubnet, llAddr3, header.MediumRoutePreference, true /* discovered */); diff != "" { t.Errorf("off-link route event mismatch (-want +got):\n%s", diff) } default: @@ -5278,8 +5356,9 @@ func TestRouterSolicitation(t *testing.T) { RandSource: &randSource, }) - if err := s.CreateNIC(nicID, &e); err != nil { - t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) + opts := stack.NICOptions{Disabled: true} + if err := s.CreateNICWithOptions(nicID, &e, opts); err != nil { + t.Fatalf("CreateNICWithOptions(%d, _, %#v) = %s", nicID, opts, err) } if addr := test.nicAddr; addr != "" { @@ -5288,6 +5367,10 @@ func TestRouterSolicitation(t *testing.T) { } } + if err := s.EnableNIC(nicID); err != nil { + t.Fatalf("EnableNIC(%d): %s", nicID, err) + } + // Make sure each RS is sent at the right time. remaining := test.maxRtrSolicit if remaining != 0 { diff --git a/pkg/tcpip/stack/tcp.go b/pkg/tcpip/stack/tcp.go index e90c1a770..90a8ba6cf 100644 --- a/pkg/tcpip/stack/tcp.go +++ b/pkg/tcpip/stack/tcp.go @@ -380,9 +380,6 @@ type TCPSndBufState struct { // SndClosed indicates that the endpoint has been closed for sends. SndClosed bool - // SndBufInQueue is the number of bytes in the send queue. - SndBufInQueue seqnum.Size - // PacketTooBigCount is used to notify the main protocol routine how // many times a "packet too big" control packet is received. PacketTooBigCount int diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go index cb316d27a..f9a15efb2 100644 --- a/pkg/tcpip/transport/icmp/endpoint.go +++ b/pkg/tcpip/transport/icmp/endpoint.go @@ -213,6 +213,7 @@ func (e *endpoint) Read(dst io.Writer, opts tcpip.ReadOptions) (tcpip.ReadResult // reacquire the mutex in exclusive mode. // // Returns true for retry if preparation should be retried. +// +checklocks:e.mu func (e *endpoint) prepareForWrite(to *tcpip.FullAddress) (retry bool, err tcpip.Error) { switch e.state { case stateInitial: @@ -229,10 +230,8 @@ func (e *endpoint) prepareForWrite(to *tcpip.FullAddress) (retry bool, err tcpip } e.mu.RUnlock() - defer e.mu.RLock() - e.mu.Lock() - defer e.mu.Unlock() + defer e.mu.DowngradeLock() // The state changed when we released the shared locked and re-acquired // it in exclusive mode. Try again. diff --git a/pkg/tcpip/transport/raw/endpoint.go b/pkg/tcpip/transport/raw/endpoint.go index b6687911a..ab5da987a 100644 --- a/pkg/tcpip/transport/raw/endpoint.go +++ b/pkg/tcpip/transport/raw/endpoint.go @@ -132,7 +132,7 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProt // headers included. Because they're write-only, We don't need to // register with the stack. if !associated { - e.ops.SetReceiveBufferSize(0, false) + e.ops.SetReceiveBufferSize(0, false /* notify */) e.waiterQueue = nil return e, nil } diff --git a/pkg/tcpip/transport/tcp/accept.go b/pkg/tcpip/transport/tcp/accept.go index d807b13b7..aa413ad05 100644 --- a/pkg/tcpip/transport/tcp/accept.go +++ b/pkg/tcpip/transport/tcp/accept.go @@ -330,7 +330,9 @@ func (l *listenContext) performHandshake(s *segment, opts *header.TCPSynOptions, } ep := h.ep - if err := h.complete(); err != nil { + // N.B. the endpoint is generated above by startHandshake, and will be + // returned locked. This first call is forced. + if err := h.complete(); err != nil { // +checklocksforce ep.stack.Stats().TCP.FailedConnectionAttempts.Increment() ep.stats.FailedConnectionAttempts.Increment() l.cleanupFailedHandshake(h) @@ -364,6 +366,7 @@ func (l *listenContext) closeAllPendingEndpoints() { } // Precondition: h.ep.mu must be held. +// +checklocks:h.ep.mu func (l *listenContext) cleanupFailedHandshake(h *handshake) { e := h.ep e.mu.Unlock() @@ -504,7 +507,9 @@ func (e *endpoint) handleSynSegment(ctx *listenContext, s *segment, opts *header } go func() { - if err := h.complete(); err != nil { + // Note that startHandshake returns a locked endpoint. The + // force call here just makes it so. + if err := h.complete(); err != nil { // +checklocksforce e.stack.Stats().TCP.FailedConnectionAttempts.Increment() e.stats.FailedConnectionAttempts.Increment() ctx.cleanupFailedHandshake(h) diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go index f86450298..93ed161f9 100644 --- a/pkg/tcpip/transport/tcp/connect.go +++ b/pkg/tcpip/transport/tcp/connect.go @@ -511,6 +511,7 @@ func (h *handshake) start() { } // complete completes the TCP 3-way handshake initiated by h.start(). +// +checklocks:h.ep.mu func (h *handshake) complete() tcpip.Error { // Set up the wakers. var s sleep.Sleeper @@ -909,30 +910,13 @@ func (e *endpoint) sendRaw(data buffer.VectorisedView, flags header.TCPFlags, se return err } -func (e *endpoint) handleWrite() { - e.sndQueueInfo.sndQueueMu.Lock() - next := e.drainSendQueueLocked() - e.sndQueueInfo.sndQueueMu.Unlock() - - e.sendData(next) -} - -// Move packets from send queue to send list. -// -// Precondition: e.sndBufMu must be locked. -func (e *endpoint) drainSendQueueLocked() *segment { - first := e.sndQueueInfo.sndQueue.Front() - if first != nil { - e.snd.writeList.PushBackList(&e.sndQueueInfo.sndQueue) - e.sndQueueInfo.SndBufInQueue = 0 - } - return first -} - // Precondition: e.mu must be locked. func (e *endpoint) sendData(next *segment) { // Initialize the next segment to write if it's currently nil. if e.snd.writeNext == nil { + if next == nil { + return + } e.snd.writeNext = next } @@ -940,17 +924,6 @@ func (e *endpoint) sendData(next *segment) { e.snd.sendData() } -func (e *endpoint) handleClose() { - if !e.EndpointState().connected() { - return - } - // Drain the send queue. - e.handleWrite() - - // Mark send side as closed. - e.snd.Closed = true -} - // resetConnectionLocked puts the endpoint in an error state with the given // error code and sends a RST if and only if the error is not ErrConnectionReset // indicating that the connection is being reset due to receiving a RST. This @@ -1311,42 +1284,45 @@ func (e *endpoint) disableKeepaliveTimer() { e.keepalive.Unlock() } -// protocolMainLoop is the main loop of the TCP protocol. It runs in its own -// goroutine and is responsible for sending segments and handling received -// segments. -func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{}) tcpip.Error { - e.mu.Lock() - var closeTimer tcpip.Timer - var closeWaker sleep.Waker - - epilogue := func() { - // e.mu is expected to be hold upon entering this section. - if e.snd != nil { - e.snd.resendTimer.cleanup() - e.snd.probeTimer.cleanup() - e.snd.reorderTimer.cleanup() - } +// protocolMainLoopDone is called at the end of protocolMainLoop. +// +checklocksrelease:e.mu +func (e *endpoint) protocolMainLoopDone(closeTimer tcpip.Timer, closeWaker *sleep.Waker) { + if e.snd != nil { + e.snd.resendTimer.cleanup() + e.snd.probeTimer.cleanup() + e.snd.reorderTimer.cleanup() + } - if closeTimer != nil { - closeTimer.Stop() - } + if closeTimer != nil { + closeTimer.Stop() + } - e.completeWorkerLocked() + e.completeWorkerLocked() - if e.drainDone != nil { - close(e.drainDone) - } + if e.drainDone != nil { + close(e.drainDone) + } - e.mu.Unlock() + e.mu.Unlock() - e.drainClosingSegmentQueue() + e.drainClosingSegmentQueue() - // When the protocol loop exits we should wake up our waiters. - e.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.ReadableEvents | waiter.WritableEvents) - } + // When the protocol loop exits we should wake up our waiters. + e.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.ReadableEvents | waiter.WritableEvents) +} + +// protocolMainLoop is the main loop of the TCP protocol. It runs in its own +// goroutine and is responsible for sending segments and handling received +// segments. +func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{}) tcpip.Error { + var ( + closeTimer tcpip.Timer + closeWaker sleep.Waker + ) + e.mu.Lock() if handshake { - if err := e.h.complete(); err != nil { + if err := e.h.complete(); err != nil { // +checklocksforce e.lastErrorMu.Lock() e.lastError = err e.lastErrorMu.Unlock() @@ -1355,8 +1331,7 @@ func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{ e.hardError = err e.workerCleanup = true - // Lock released below. - epilogue() + e.protocolMainLoopDone(closeTimer, &closeWaker) return err } } @@ -1402,14 +1377,7 @@ func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{ { w: &e.sndQueueInfo.sndWaker, f: func() tcpip.Error { - e.handleWrite() - return nil - }, - }, - { - w: &e.sndQueueInfo.sndCloseWaker, - f: func() tcpip.Error { - e.handleClose() + e.sendData(nil /* next */) return nil }, }, @@ -1507,7 +1475,7 @@ func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{ // Only block the worker if the endpoint // is not in closed state or error state. close(e.drainDone) - e.mu.Unlock() + e.mu.Unlock() // +checklocksforce <-e.undrain e.mu.Lock() } @@ -1568,8 +1536,6 @@ func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{ if err != nil { e.resetConnectionLocked(err) } - // Lock released below. - epilogue() } loop: @@ -1593,6 +1559,7 @@ loop: // just want to terminate the loop and cleanup the // endpoint. cleanupOnError(nil) + e.protocolMainLoopDone(closeTimer, &closeWaker) return nil case StateTimeWait: fallthrough @@ -1601,6 +1568,7 @@ loop: default: if err := funcs[v].f(); err != nil { cleanupOnError(err) + e.protocolMainLoopDone(closeTimer, &closeWaker) return nil } } @@ -1624,13 +1592,13 @@ loop: // Handle any StateError transition from StateTimeWait. if e.EndpointState() == StateError { cleanupOnError(nil) + e.protocolMainLoopDone(closeTimer, &closeWaker) return nil } e.transitionToStateCloseLocked() - // Lock released below. - epilogue() + e.protocolMainLoopDone(closeTimer, &closeWaker) // A new SYN was received during TIME_WAIT and we need to abort // the timewait and redirect the segment to the listener queue @@ -1700,6 +1668,7 @@ func (e *endpoint) handleTimeWaitSegments() (extendTimeWait bool, reuseTW func() // should be executed after releasing the endpoint registrations. This is // done in cases where a new SYN is received during TIME_WAIT that carries // a sequence number larger than one see on the connection. +// +checklocks:e.mu func (e *endpoint) doTimeWait() (twReuse func()) { // Trigger a 2 * MSL time wait state. During this period // we will drop all incoming segments. diff --git a/pkg/tcpip/transport/tcp/dispatcher.go b/pkg/tcpip/transport/tcp/dispatcher.go index dff7cb89c..7d110516b 100644 --- a/pkg/tcpip/transport/tcp/dispatcher.go +++ b/pkg/tcpip/transport/tcp/dispatcher.go @@ -127,7 +127,7 @@ func (p *processor) start(wg *sync.WaitGroup) { case !ep.segmentQueue.empty(): p.epQ.enqueue(ep) } - ep.mu.Unlock() + ep.mu.Unlock() // +checklocksforce } else { ep.newSegmentWaker.Assert() } diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go index a27e2110b..ebc88d6c3 100644 --- a/pkg/tcpip/transport/tcp/endpoint.go +++ b/pkg/tcpip/transport/tcp/endpoint.go @@ -293,16 +293,9 @@ type sndQueueInfo struct { sndQueueMu sync.Mutex `state:"nosave"` stack.TCPSndBufState - // sndQueue holds segments that are ready to be sent. - sndQueue segmentList `state:"wait"` - - // sndWaker is used to signal the protocol goroutine when segments are - // added to the `sndQueue`. + // sndWaker is used to signal the protocol goroutine when there may be + // segments that need to be sent. sndWaker sleep.Waker `state:"manual"` - - // sndCloseWaker is used to notify the protocol goroutine when the send - // side is closed. - sndCloseWaker sleep.Waker `state:"manual"` } // rcvQueueInfo contains the endpoint's rcvQueue and associated metadata. @@ -671,6 +664,7 @@ func calculateAdvertisedMSS(userMSS uint16, r *stack.Route) uint16 { // The assumption behind spinning here being that background packet processing // should not be holding the lock for long and spinning reduces latency as we // avoid an expensive sleep/wakeup of of the syscall goroutine). +// +checklocksacquire:e.mu func (e *endpoint) LockUser() { for { // Try first if the sock is locked then check if it's owned @@ -690,7 +684,7 @@ func (e *endpoint) LockUser() { continue } atomic.StoreUint32(&e.ownedByUser, 1) - return + return // +checklocksforce } } @@ -707,7 +701,7 @@ func (e *endpoint) LockUser() { // protocol goroutine altogether. // // Precondition: e.LockUser() must have been called before calling e.UnlockUser() -// +checklocks:e.mu +// +checklocksrelease:e.mu func (e *endpoint) UnlockUser() { // Lock segment queue before checking so that we avoid a race where // segments can be queued between the time we check if queue is empty @@ -743,12 +737,13 @@ func (e *endpoint) UnlockUser() { } // StopWork halts packet processing. Only to be used in tests. +// +checklocksacquire:e.mu func (e *endpoint) StopWork() { e.mu.Lock() } // ResumeWork resumes packet processing. Only to be used in tests. -// +checklocks:e.mu +// +checklocksrelease:e.mu func (e *endpoint) ResumeWork() { e.mu.Unlock() } @@ -759,7 +754,7 @@ func (e *endpoint) ResumeWork() { // // Precondition: e.mu must be held to call this method. func (e *endpoint) setEndpointState(state EndpointState) { - oldstate := EndpointState(atomic.LoadUint32(&e.state)) + oldstate := EndpointState(atomic.SwapUint32(&e.state, uint32(state))) switch state { case StateEstablished: e.stack.Stats().TCP.CurrentEstablished.Increment() @@ -776,7 +771,6 @@ func (e *endpoint) setEndpointState(state EndpointState) { e.stack.Stats().TCP.CurrentEstablished.Decrement() } } - atomic.StoreUint32(&e.state, uint32(state)) } // EndpointState returns the current state of the endpoint. @@ -1487,87 +1481,101 @@ func (e *endpoint) isEndpointWritableLocked() (int, tcpip.Error) { return avail, nil } -// Write writes data to the endpoint's peer. -func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcpip.Error) { - // Linux completely ignores any address passed to sendto(2) for TCP sockets - // (without the MSG_FASTOPEN flag). Corking is unimplemented, so opts.More - // and opts.EndOfRecord are also ignored. +// readFromPayloader reads a slice from the Payloader. +// +checklocks:e.mu +// +checklocks:e.sndQueueInfo.sndQueueMu +func (e *endpoint) readFromPayloader(p tcpip.Payloader, opts tcpip.WriteOptions, avail int) ([]byte, tcpip.Error) { + // We can release locks while copying data. + // + // This is not possible if atomic is set, because we can't allow the + // available buffer space to be consumed by some other caller while we + // are copying data in. + if !opts.Atomic { + e.sndQueueInfo.sndQueueMu.Unlock() + defer e.sndQueueInfo.sndQueueMu.Lock() - e.LockUser() - defer e.UnlockUser() + e.UnlockUser() + defer e.LockUser() + } - nextSeg, n, err := func() (*segment, int, tcpip.Error) { - e.sndQueueInfo.sndQueueMu.Lock() - defer e.sndQueueInfo.sndQueueMu.Unlock() + // Fetch data. + if l := p.Len(); l < avail { + avail = l + } + if avail == 0 { + return nil, nil + } + v := make([]byte, avail) + n, err := p.Read(v) + if err != nil && err != io.EOF { + return nil, &tcpip.ErrBadBuffer{} + } + return v[:n], nil +} + +// queueSegment reads data from the payloader and returns a segment to be sent. +// +checklocks:e.mu +func (e *endpoint) queueSegment(p tcpip.Payloader, opts tcpip.WriteOptions) (*segment, int, tcpip.Error) { + e.sndQueueInfo.sndQueueMu.Lock() + defer e.sndQueueInfo.sndQueueMu.Unlock() + + avail, err := e.isEndpointWritableLocked() + if err != nil { + e.stats.WriteErrors.WriteClosed.Increment() + return nil, 0, err + } + + v, err := e.readFromPayloader(p, opts, avail) + if err != nil { + return nil, 0, err + } + + // Do not queue zero length segments. + if len(v) == 0 { + return nil, 0, nil + } + if !opts.Atomic { + // Since we released locks in between it's possible that the + // endpoint transitioned to a CLOSED/ERROR states so make + // sure endpoint is still writable before trying to write. avail, err := e.isEndpointWritableLocked() if err != nil { e.stats.WriteErrors.WriteClosed.Increment() return nil, 0, err } - v, err := func() ([]byte, tcpip.Error) { - // We can release locks while copying data. - // - // This is not possible if atomic is set, because we can't allow the - // available buffer space to be consumed by some other caller while we - // are copying data in. - if !opts.Atomic { - e.sndQueueInfo.sndQueueMu.Unlock() - defer e.sndQueueInfo.sndQueueMu.Lock() - - e.UnlockUser() - defer e.LockUser() - } - - // Fetch data. - if l := p.Len(); l < avail { - avail = l - } - if avail == 0 { - return nil, nil - } - v := make([]byte, avail) - n, err := p.Read(v) - if err != nil && err != io.EOF { - return nil, &tcpip.ErrBadBuffer{} - } - return v[:n], nil - }() - if len(v) == 0 || err != nil { - return nil, 0, err + // Discard any excess data copied in due to avail being reduced due + // to a simultaneous write call to the socket. + if avail < len(v) { + v = v[:avail] } + } - if !opts.Atomic { - // Since we released locks in between it's possible that the - // endpoint transitioned to a CLOSED/ERROR states so make - // sure endpoint is still writable before trying to write. - avail, err := e.isEndpointWritableLocked() - if err != nil { - e.stats.WriteErrors.WriteClosed.Increment() - return nil, 0, err - } + // Add data to the send queue. + s := newOutgoingSegment(e.TransportEndpointInfo.ID, e.stack.Clock(), v) + e.sndQueueInfo.SndBufUsed += len(v) + e.snd.writeList.PushBack(s) - // Discard any excess data copied in due to avail being reduced due - // to a simultaneous write call to the socket. - if avail < len(v) { - v = v[:avail] - } - } + return s, len(v), nil +} - // Add data to the send queue. - s := newOutgoingSegment(e.TransportEndpointInfo.ID, e.stack.Clock(), v) - e.sndQueueInfo.SndBufUsed += len(v) - e.sndQueueInfo.SndBufInQueue += seqnum.Size(len(v)) - e.sndQueueInfo.sndQueue.PushBack(s) +// Write writes data to the endpoint's peer. +func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcpip.Error) { + // Linux completely ignores any address passed to sendto(2) for TCP sockets + // (without the MSG_FASTOPEN flag). Corking is unimplemented, so opts.More + // and opts.EndOfRecord are also ignored. + + e.LockUser() + defer e.UnlockUser() - return e.drainSendQueueLocked(), len(v), nil - }() // Return if either we didn't queue anything or if an error occurred while // attempting to queue data. + nextSeg, n, err := e.queueSegment(p, opts) if n == 0 || err != nil { return 0, err } + e.sendData(nextSeg) return int64(n), nil } @@ -2314,7 +2322,7 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) tcp // connection setting here. if !handshake { e.segmentQueue.mu.Lock() - for _, l := range []segmentList{e.segmentQueue.list, e.sndQueueInfo.sndQueue, e.snd.writeList} { + for _, l := range []segmentList{e.segmentQueue.list, e.snd.writeList} { for s := l.Front(); s != nil; s = s.Next() { s.id = e.TransportEndpointInfo.ID e.sndQueueInfo.sndWaker.Assert() @@ -2372,6 +2380,9 @@ func (e *endpoint) shutdownLocked(flags tcpip.ShutdownFlags) tcpip.Error { e.notifyProtocolGoroutine(notifyTickleWorker) return nil } + // Wake up any readers that maybe waiting for the stream to become + // readable. + e.waiterQueue.Notify(waiter.ReadableEvents) } // Close for write. @@ -2388,12 +2399,20 @@ func (e *endpoint) shutdownLocked(flags tcpip.ShutdownFlags) tcpip.Error { // Queue fin segment. s := newOutgoingSegment(e.TransportEndpointInfo.ID, e.stack.Clock(), nil) - e.sndQueueInfo.sndQueue.PushBack(s) - e.sndQueueInfo.SndBufInQueue++ + e.snd.writeList.PushBack(s) // Mark endpoint as closed. e.sndQueueInfo.SndClosed = true e.sndQueueInfo.sndQueueMu.Unlock() - e.handleClose() + + // Drain the send queue. + e.sendData(s) + + // Mark send side as closed. + e.snd.Closed = true + + // Wake up any writers that maybe waiting for the stream to become + // writable. + e.waiterQueue.Notify(waiter.WritableEvents) } return nil @@ -2501,6 +2520,7 @@ func (e *endpoint) listen(backlog int) tcpip.Error { // startAcceptedLoop sets up required state and starts a goroutine with the // main loop for accepted connections. +// +checklocksrelease:e.mu func (e *endpoint) startAcceptedLoop() { e.workerRunning = true e.mu.Unlock() diff --git a/pkg/tcpip/transport/tcp/forwarder.go b/pkg/tcpip/transport/tcp/forwarder.go index 65c86823a..2e709ed78 100644 --- a/pkg/tcpip/transport/tcp/forwarder.go +++ b/pkg/tcpip/transport/tcp/forwarder.go @@ -164,8 +164,9 @@ func (r *ForwarderRequest) CreateEndpoint(queue *waiter.Queue) (tcpip.Endpoint, return nil, err } - // Start the protocol goroutine. - ep.startAcceptedLoop() + // Start the protocol goroutine. Note that the endpoint is returned + // from performHandshake locked. + ep.startAcceptedLoop() // +checklocksforce return ep, nil } diff --git a/pkg/tcpip/transport/tcp/tcp_noracedetector_test.go b/pkg/tcpip/transport/tcp/tcp_noracedetector_test.go index ced3a9c58..84fb1c416 100644 --- a/pkg/tcpip/transport/tcp/tcp_noracedetector_test.go +++ b/pkg/tcpip/transport/tcp/tcp_noracedetector_test.go @@ -16,6 +16,7 @@ // iterations taking long enough that the retransmit timer can kick in causing // the congestion window measurements to fail due to extra packets etc. // +//go:build !race // +build !race package tcp_test diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go index 9bbe9bc3e..031f01357 100644 --- a/pkg/tcpip/transport/tcp/tcp_test.go +++ b/pkg/tcpip/transport/tcp/tcp_test.go @@ -2147,7 +2147,7 @@ func TestSmallSegReceiveWindowAdvertisement(t *testing.T) { // Bump up the receive buffer size such that, when the receive window grows, // the scaled window exceeds maxUint16. - c.EP.SocketOptions().SetReceiveBufferSize(int64(opt.Max), true) + c.EP.SocketOptions().SetReceiveBufferSize(int64(opt.Max)*2, true /* notify */) // Keep the payload size < segment overhead and such that it is a multiple // of the window scaled value. This enables the test to perform equality @@ -2267,7 +2267,7 @@ func TestNoWindowShrinking(t *testing.T) { initialWnd := header.TCP(header.IPv4(pkt).Payload()).WindowSize() << c.RcvdWindowScale initialLastAcceptableSeq := iss.Add(seqnum.Size(initialWnd)) // Now shrink the receive buffer to half its original size. - c.EP.SocketOptions().SetReceiveBufferSize(int64(rcvBufSize/2), true) + c.EP.SocketOptions().SetReceiveBufferSize(int64(rcvBufSize), true /* notify */) data := generateRandomPayload(t, rcvBufSize) // Send a payload of half the size of rcvBufSize. @@ -2523,7 +2523,7 @@ func TestScaledWindowAccept(t *testing.T) { defer ep.Close() // Set the window size greater than the maximum non-scaled window. - ep.SocketOptions().SetReceiveBufferSize(65535*3, true) + ep.SocketOptions().SetReceiveBufferSize(65535*6, true /* notify */) if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { t.Fatalf("Bind failed: %s", err) @@ -2595,7 +2595,7 @@ func TestNonScaledWindowAccept(t *testing.T) { defer ep.Close() // Set the window size greater than the maximum non-scaled window. - ep.SocketOptions().SetReceiveBufferSize(65535*3, true) + ep.SocketOptions().SetReceiveBufferSize(65535*6, true /* notify */) if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { t.Fatalf("Bind failed: %s", err) @@ -3188,7 +3188,7 @@ func TestPassiveSendMSSLessThanMTU(t *testing.T) { // Set the buffer size to a deterministic size so that we can check the // window scaling option. const rcvBufferSize = 0x20000 - ep.SocketOptions().SetReceiveBufferSize(rcvBufferSize, true) + ep.SocketOptions().SetReceiveBufferSize(rcvBufferSize*2, true /* notify */) if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { t.Fatalf("Bind failed: %s", err) @@ -3327,7 +3327,7 @@ func TestSynOptionsOnActiveConnect(t *testing.T) { // window scaling option. const rcvBufferSize = 0x20000 const wndScale = 3 - c.EP.SocketOptions().SetReceiveBufferSize(rcvBufferSize, true) + c.EP.SocketOptions().SetReceiveBufferSize(rcvBufferSize*2, true /* notify */) // Start connection attempt. we, ch := waiter.NewChannelEntry(nil) @@ -3451,17 +3451,13 @@ loop: for { switch _, err := c.EP.Read(ioutil.Discard, tcpip.ReadOptions{}); err.(type) { case *tcpip.ErrWouldBlock: - select { - case <-ch: - // Expect the state to be StateError and subsequent Reads to fail with HardError. - _, err := c.EP.Read(ioutil.Discard, tcpip.ReadOptions{}) - if d := cmp.Diff(&tcpip.ErrConnectionReset{}, err); d != "" { - t.Fatalf("c.EP.Read() mismatch (-want +got):\n%s", d) - } - break loop - case <-time.After(1 * time.Second): - t.Fatalf("Timed out waiting for reset to arrive") + <-ch + // Expect the state to be StateError and subsequent Reads to fail with HardError. + _, err := c.EP.Read(ioutil.Discard, tcpip.ReadOptions{}) + if d := cmp.Diff(&tcpip.ErrConnectionReset{}, err); d != "" { + t.Fatalf("c.EP.Read() mismatch (-want +got):\n%s", d) } + break loop case *tcpip.ErrConnectionReset: break loop default: @@ -3472,14 +3468,27 @@ loop: if tcp.EndpointState(c.EP.State()) != tcp.StateError { t.Fatalf("got EP state is not StateError") } - if got := c.Stack().Stats().TCP.EstablishedResets.Value(); got != 1 { - t.Errorf("got stats.TCP.EstablishedResets.Value() = %d, want = 1", got) + + checkValid := func() []error { + var errors []error + if got := c.Stack().Stats().TCP.EstablishedResets.Value(); got != 1 { + errors = append(errors, fmt.Errorf("got stats.TCP.EstablishedResets.Value() = %d, want = 1", got)) + } + if got := c.Stack().Stats().TCP.CurrentEstablished.Value(); got != 0 { + errors = append(errors, fmt.Errorf("got stats.TCP.CurrentEstablished.Value() = %d, want = 0", got)) + } + if got := c.Stack().Stats().TCP.CurrentConnected.Value(); got != 0 { + errors = append(errors, fmt.Errorf("got stats.TCP.CurrentConnected.Value() = %d, want = 0", got)) + } + return errors } - if got := c.Stack().Stats().TCP.CurrentEstablished.Value(); got != 0 { - t.Errorf("got stats.TCP.CurrentEstablished.Value() = %d, want = 0", got) + + start := time.Now() + for time.Since(start) < time.Minute && len(checkValid()) > 0 { + time.Sleep(50 * time.Millisecond) } - if got := c.Stack().Stats().TCP.CurrentConnected.Value(); got != 0 { - t.Errorf("got stats.TCP.CurrentConnected.Value() = %d, want = 0", got) + for _, err := range checkValid() { + t.Error(err) } } @@ -3615,6 +3624,38 @@ func TestMaxRTO(t *testing.T) { } } +// TestZeroSizedWriteRetransmit tests that a zero sized write should not +// result in a panic on an RTO as no segment should have been queued for +// a zero sized write. +func TestZeroSizedWriteRetransmit(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + c.CreateConnected(context.TestInitialSequenceNumber, 30000 /* rcvWnd */, -1 /* epRcvBuf */) + + var r bytes.Reader + _, err := c.EP.Write(&r, tcpip.WriteOptions{}) + if err != nil { + t.Fatalf("Write failed: %s", err) + } + // Now do a non-zero sized write to trigger actual sending of data. + r.Reset(make([]byte, 1)) + _, err = c.EP.Write(&r, tcpip.WriteOptions{}) + if err != nil { + t.Fatalf("Write failed: %s", err) + } + // Do not ACK the packet and expect an original transmit and a + // retransmit. This should not cause a panic. + for i := 0; i < 2; i++ { + checker.IPv4(t, c.GetPacket(), + checker.TCP( + checker.DstPort(context.TestPort), + checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), + ), + ) + } +} + // TestRetransmitIPv4IDUniqueness tests that the IPv4 Identification field is // unique on retransmits. func TestRetransmitIPv4IDUniqueness(t *testing.T) { @@ -4628,52 +4669,6 @@ func TestDefaultBufferSizes(t *testing.T) { checkRecvBufferSize(t, ep, tcp.DefaultReceiveBufferSize*3) } -func TestMinMaxBufferSizes(t *testing.T) { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol}, - TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol}, - }) - - // Check the default values. - ep, err := s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{}) - if err != nil { - t.Fatalf("NewEndpoint failed; %s", err) - } - defer ep.Close() - - // Change the min/max values for send/receive - { - opt := tcpip.TCPReceiveBufferSizeRangeOption{Min: 200, Default: tcp.DefaultReceiveBufferSize * 2, Max: tcp.DefaultReceiveBufferSize * 20} - if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { - t.Fatalf("SetTransportProtocolOption(%d, &%#v): %s", tcp.ProtocolNumber, opt, err) - } - } - - { - opt := tcpip.TCPSendBufferSizeRangeOption{Min: 300, Default: tcp.DefaultSendBufferSize * 3, Max: tcp.DefaultSendBufferSize * 30} - if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { - t.Fatalf("SetTransportProtocolOption(%d, &%#v): %s", tcp.ProtocolNumber, opt, err) - } - } - - // Set values below the min/2. - ep.SocketOptions().SetReceiveBufferSize(99, true) - checkRecvBufferSize(t, ep, 200) - - ep.SocketOptions().SetSendBufferSize(149, true) - - checkSendBufferSize(t, ep, 300) - - // Set values above the max. - ep.SocketOptions().SetReceiveBufferSize(1+tcp.DefaultReceiveBufferSize*20, true) - // Values above max are capped at max and then doubled. - checkRecvBufferSize(t, ep, tcp.DefaultReceiveBufferSize*20*2) - - ep.SocketOptions().SetSendBufferSize(1+tcp.DefaultSendBufferSize*30, true) - // Values above max are capped at max and then doubled. - checkSendBufferSize(t, ep, tcp.DefaultSendBufferSize*30*2) -} - func TestBindToDeviceOption(t *testing.T) { s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol}, @@ -6068,6 +6063,11 @@ func TestSynRcvdBadSeqNumber(t *testing.T) { // complete the connection to test that the large SEQ num // did not change the state from SYN-RCVD. + // Get setup to be notified about connection establishment. + we, ch := waiter.NewChannelEntry(nil) + c.WQ.EventRegister(&we, waiter.ReadableEvents) + defer c.WQ.EventUnregister(&we) + // Send ACK to move to ESTABLISHED state. c.SendPacket(nil, &context.Headers{ SrcPort: context.TestPort, @@ -6078,32 +6078,12 @@ func TestSynRcvdBadSeqNumber(t *testing.T) { RcvWnd: 30000, }) + <-ch newEP, _, err := c.EP.Accept(nil) - switch err.(type) { - case nil, *tcpip.ErrWouldBlock: - default: + if err != nil { t.Fatalf("Accept failed: %s", err) } - if cmp.Equal(&tcpip.ErrWouldBlock{}, err) { - // Try to accept the connections in the backlog. - we, ch := waiter.NewChannelEntry(nil) - c.WQ.EventRegister(&we, waiter.ReadableEvents) - defer c.WQ.EventUnregister(&we) - - // Wait for connection to be established. - select { - case <-ch: - newEP, _, err = c.EP.Accept(nil) - if err != nil { - t.Fatalf("Accept failed: %s", err) - } - - case <-time.After(1 * time.Second): - t.Fatalf("Timed out waiting for accept") - } - } - // Now verify that the TCP socket is usable and in a connected state. data := "Don't panic" var r strings.Reader @@ -6209,12 +6189,26 @@ func TestPassiveFailedConnectionAttemptIncrement(t *testing.T) { RcvWnd: 30000, }) - time.Sleep(50 * time.Millisecond) - if got := stats.TCP.ListenOverflowSynDrop.Value(); got != want { - t.Errorf("got stats.TCP.ListenOverflowSynDrop.Value() = %d, want = %d", got, want) + checkValid := func() []error { + var errors []error + if got := stats.TCP.ListenOverflowSynDrop.Value(); got != want { + errors = append(errors, fmt.Errorf("got stats.TCP.ListenOverflowSynDrop.Value() = %d, want = %d", got, want)) + } + if got := c.EP.Stats().(*tcp.Stats).ReceiveErrors.ListenOverflowSynDrop.Value(); got != want { + errors = append(errors, fmt.Errorf("got EP stats Stats.ReceiveErrors.ListenOverflowSynDrop = %d, want = %d", got, want)) + } + return errors } - if got := c.EP.Stats().(*tcp.Stats).ReceiveErrors.ListenOverflowSynDrop.Value(); got != want { - t.Errorf("got EP stats Stats.ReceiveErrors.ListenOverflowSynDrop = %d, want = %d", got, want) + + start := time.Now() + for time.Since(start) < time.Minute && len(checkValid()) > 0 { + time.Sleep(50 * time.Millisecond) + } + for _, err := range checkValid() { + t.Error(err) + } + if t.Failed() { + t.FailNow() } we, ch := waiter.NewChannelEntry(nil) @@ -6225,15 +6219,10 @@ func TestPassiveFailedConnectionAttemptIncrement(t *testing.T) { _, _, err = c.EP.Accept(nil) if cmp.Equal(&tcpip.ErrWouldBlock{}, err) { // Wait for connection to be established. - select { - case <-ch: - _, _, err = c.EP.Accept(nil) - if err != nil { - t.Fatalf("Accept failed: %s", err) - } - - case <-time.After(1 * time.Second): - t.Fatalf("Timed out waiting for accept") + <-ch + _, _, err = c.EP.Accept(nil) + if err != nil { + t.Fatalf("Accept failed: %s", err) } } } @@ -7483,7 +7472,7 @@ func TestTCPUserTimeout(t *testing.T) { select { case <-notifyCh: case <-time.After(2 * initRTO): - t.Fatalf("connection still alive after %s, should have been closed after :%s", 2*initRTO, userTimeout) + t.Fatalf("connection still alive after %s, should have been closed after %s", 2*initRTO, userTimeout) } // No packet should be received as the connection should be silently @@ -7717,7 +7706,7 @@ func TestIncreaseWindowOnBufferResize(t *testing.T) { // Increasing the buffer from should generate an ACK, // since window grew from small value to larger equal MSS - c.EP.SocketOptions().SetReceiveBufferSize(rcvBuf*2, true) + c.EP.SocketOptions().SetReceiveBufferSize(rcvBuf*4, true /* notify */) checker.IPv4(t, c.GetPacket(), checker.PayloadLen(header.TCPMinimumSize), checker.TCP( diff --git a/pkg/tcpip/transport/tcp/testing/context/context.go b/pkg/tcpip/transport/tcp/testing/context/context.go index 53efecc5a..96e4849d2 100644 --- a/pkg/tcpip/transport/tcp/testing/context/context.go +++ b/pkg/tcpip/transport/tcp/testing/context/context.go @@ -757,7 +757,7 @@ func (c *Context) Create(epRcvBuf int) { } if epRcvBuf != -1 { - c.EP.SocketOptions().SetReceiveBufferSize(int64(epRcvBuf), true /* notify */) + c.EP.SocketOptions().SetReceiveBufferSize(int64(epRcvBuf)*2, true /* notify */) } } diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go index def9d7186..82a3f2287 100644 --- a/pkg/tcpip/transport/udp/endpoint.go +++ b/pkg/tcpip/transport/udp/endpoint.go @@ -364,6 +364,7 @@ func (e *endpoint) Read(dst io.Writer, opts tcpip.ReadOptions) (tcpip.ReadResult // reacquire the mutex in exclusive mode. // // Returns true for retry if preparation should be retried. +// +checklocks:e.mu func (e *endpoint) prepareForWrite(to *tcpip.FullAddress) (retry bool, err tcpip.Error) { switch e.EndpointState() { case StateInitial: @@ -380,10 +381,8 @@ func (e *endpoint) prepareForWrite(to *tcpip.FullAddress) (retry bool, err tcpip } e.mu.RUnlock() - defer e.mu.RLock() - e.mu.Lock() - defer e.mu.Unlock() + defer e.mu.DowngradeLock() // The state changed when we released the shared locked and re-acquired // it in exclusive mode. Try again. @@ -449,37 +448,20 @@ func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcp return n, err } -func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcpip.Error) { - if err := e.LastError(); err != nil { - return 0, err - } - - // MSG_MORE is unimplemented. (This also means that MSG_EOR is a no-op.) - if opts.More { - return 0, &tcpip.ErrInvalidOptionValue{} - } - - to := opts.To - +func (e *endpoint) buildUDPPacketInfo(p tcpip.Payloader, opts tcpip.WriteOptions) (udpPacketInfo, tcpip.Error) { e.mu.RLock() - lockReleased := false - defer func() { - if lockReleased { - return - } - e.mu.RUnlock() - }() + defer e.mu.RUnlock() // If we've shutdown with SHUT_WR we are in an invalid state for sending. if e.shutdownFlags&tcpip.ShutdownWrite != 0 { - return 0, &tcpip.ErrClosedForSend{} + return udpPacketInfo{}, &tcpip.ErrClosedForSend{} } // Prepare for write. for { - retry, err := e.prepareForWrite(to) + retry, err := e.prepareForWrite(opts.To) if err != nil { - return 0, err + return udpPacketInfo{}, err } if !retry { @@ -489,34 +471,34 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcp route := e.route dstPort := e.dstPort - if to != nil { + if opts.To != nil { // Reject destination address if it goes through a different // NIC than the endpoint was bound to. - nicID := to.NIC + nicID := opts.To.NIC if nicID == 0 { nicID = tcpip.NICID(e.ops.GetBindToDevice()) } if e.BindNICID != 0 { if nicID != 0 && nicID != e.BindNICID { - return 0, &tcpip.ErrNoRoute{} + return udpPacketInfo{}, &tcpip.ErrNoRoute{} } nicID = e.BindNICID } - if to.Port == 0 { + if opts.To.Port == 0 { // Port 0 is an invalid port to send to. - return 0, &tcpip.ErrInvalidEndpointState{} + return udpPacketInfo{}, &tcpip.ErrInvalidEndpointState{} } - dst, netProto, err := e.checkV4MappedLocked(*to) + dst, netProto, err := e.checkV4MappedLocked(*opts.To) if err != nil { - return 0, err + return udpPacketInfo{}, err } r, _, err := e.connectRoute(nicID, dst, netProto) if err != nil { - return 0, err + return udpPacketInfo{}, err } defer r.Release() @@ -525,12 +507,12 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcp } if !e.ops.GetBroadcast() && route.IsOutboundBroadcast() { - return 0, &tcpip.ErrBroadcastDisabled{} + return udpPacketInfo{}, &tcpip.ErrBroadcastDisabled{} } v := make([]byte, p.Len()) if _, err := io.ReadFull(p, v); err != nil { - return 0, &tcpip.ErrBadBuffer{} + return udpPacketInfo{}, &tcpip.ErrBadBuffer{} } if len(v) > header.UDPMaximumPacketSize { // Payload can't possibly fit in a packet. @@ -548,24 +530,39 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcp v, ) } - return 0, &tcpip.ErrMessageTooLong{} + return udpPacketInfo{}, &tcpip.ErrMessageTooLong{} } ttl := e.ttl useDefaultTTL := ttl == 0 - if header.IsV4MulticastAddress(route.RemoteAddress()) || header.IsV6MulticastAddress(route.RemoteAddress()) { ttl = e.multicastTTL // Multicast allows a 0 TTL. useDefaultTTL = false } - localPort := e.ID.LocalPort - sendTOS := e.sendTOS - owner := e.owner - noChecksum := e.SocketOptions().GetNoChecksum() - lockReleased = true - e.mu.RUnlock() + return udpPacketInfo{ + route: route, + data: buffer.View(v), + localPort: e.ID.LocalPort, + remotePort: dstPort, + ttl: ttl, + useDefaultTTL: useDefaultTTL, + tos: e.sendTOS, + owner: e.owner, + noChecksum: e.SocketOptions().GetNoChecksum(), + }, nil +} + +func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcpip.Error) { + if err := e.LastError(); err != nil { + return 0, err + } + + // MSG_MORE is unimplemented. (This also means that MSG_EOR is a no-op.) + if opts.More { + return 0, &tcpip.ErrInvalidOptionValue{} + } // Do not hold lock when sending as loopback is synchronous and if the UDP // datagram ends up generating an ICMP response then it can result in a @@ -577,10 +574,15 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcp // // See: https://golang.org/pkg/sync/#RWMutex for details on why recursive read // locking is prohibited. - if err := sendUDP(route, buffer.View(v).ToVectorisedView(), localPort, dstPort, ttl, useDefaultTTL, sendTOS, owner, noChecksum); err != nil { + u, err := e.buildUDPPacketInfo(p, opts) + if err != nil { return 0, err } - return int64(len(v)), nil + n, err := u.send() + if err != nil { + return 0, err + } + return int64(n), nil } // OnReuseAddressSet implements tcpip.SocketOptionsHandler. @@ -817,14 +819,30 @@ func (e *endpoint) GetSockOpt(opt tcpip.GettableSocketOption) tcpip.Error { return nil } -// sendUDP sends a UDP segment via the provided network endpoint and under the -// provided identity. -func sendUDP(r *stack.Route, data buffer.VectorisedView, localPort, remotePort uint16, ttl uint8, useDefaultTTL bool, tos uint8, owner tcpip.PacketOwner, noChecksum bool) tcpip.Error { +// udpPacketInfo contains all information required to send a UDP packet. +// +// This should be used as a value-only type, which exists in order to simplify +// return value syntax. It should not be exported or extended. +type udpPacketInfo struct { + route *stack.Route + data buffer.View + localPort uint16 + remotePort uint16 + ttl uint8 + useDefaultTTL bool + tos uint8 + owner tcpip.PacketOwner + noChecksum bool +} + +// send sends the given packet. +func (u *udpPacketInfo) send() (int, tcpip.Error) { + vv := u.data.ToVectorisedView() pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ - ReserveHeaderBytes: header.UDPMinimumSize + int(r.MaxHeaderLength()), - Data: data, + ReserveHeaderBytes: header.UDPMinimumSize + int(u.route.MaxHeaderLength()), + Data: vv, }) - pkt.Owner = owner + pkt.Owner = u.owner // Initialize the UDP header. udp := header.UDP(pkt.TransportHeader().Push(header.UDPMinimumSize)) @@ -832,8 +850,8 @@ func sendUDP(r *stack.Route, data buffer.VectorisedView, localPort, remotePort u length := uint16(pkt.Size()) udp.Encode(&header.UDPFields{ - SrcPort: localPort, - DstPort: remotePort, + SrcPort: u.localPort, + DstPort: u.remotePort, Length: length, }) @@ -841,30 +859,30 @@ func sendUDP(r *stack.Route, data buffer.VectorisedView, localPort, remotePort u // On IPv4, UDP checksum is optional, and a zero value indicates the // transmitter skipped the checksum generation (RFC768). // On IPv6, UDP checksum is not optional (RFC2460 Section 8.1). - if r.RequiresTXTransportChecksum() && - (!noChecksum || r.NetProto() == header.IPv6ProtocolNumber) { - xsum := r.PseudoHeaderChecksum(ProtocolNumber, length) - for _, v := range data.Views() { + if u.route.RequiresTXTransportChecksum() && + (!u.noChecksum || u.route.NetProto() == header.IPv6ProtocolNumber) { + xsum := u.route.PseudoHeaderChecksum(ProtocolNumber, length) + for _, v := range vv.Views() { xsum = header.Checksum(v, xsum) } udp.SetChecksum(^udp.CalculateChecksum(xsum)) } - if useDefaultTTL { - ttl = r.DefaultTTL() + if u.useDefaultTTL { + u.ttl = u.route.DefaultTTL() } - if err := r.WritePacket(stack.NetworkHeaderParams{ + if err := u.route.WritePacket(stack.NetworkHeaderParams{ Protocol: ProtocolNumber, - TTL: ttl, - TOS: tos, + TTL: u.ttl, + TOS: u.tos, }, pkt); err != nil { - r.Stats().UDP.PacketSendErrors.Increment() - return err + u.route.Stats().UDP.PacketSendErrors.Increment() + return 0, err } // Track count of packets sent. - r.Stats().UDP.PacketsSent.Increment() - return nil + u.route.Stats().UDP.PacketsSent.Increment() + return len(u.data), nil } // checkV4MappedLocked determines the effective network protocol and converts |