summaryrefslogtreecommitdiffhomepage
path: root/pkg/tcpip
diff options
context:
space:
mode:
Diffstat (limited to 'pkg/tcpip')
-rw-r--r--pkg/tcpip/BUILD1
-rw-r--r--pkg/tcpip/header/checksum.go62
-rw-r--r--pkg/tcpip/header/checksum_test.go203
-rw-r--r--pkg/tcpip/header/interfaces.go38
-rw-r--r--pkg/tcpip/header/ipv4.go12
-rw-r--r--pkg/tcpip/header/ndp_options.go145
-rw-r--r--pkg/tcpip/header/ndp_router_advert.go19
-rw-r--r--pkg/tcpip/header/ndp_test.go248
-rw-r--r--pkg/tcpip/header/tcp.go29
-rw-r--r--pkg/tcpip/header/udp.go29
-rw-r--r--pkg/tcpip/link/fdbased/BUILD1
-rw-r--r--pkg/tcpip/link/fdbased/endpoint.go208
-rw-r--r--pkg/tcpip/link/fdbased/endpoint_test.go1
-rw-r--r--pkg/tcpip/link/fdbased/endpoint_unsafe.go1
-rw-r--r--pkg/tcpip/link/fdbased/mmap.go1
-rw-r--r--pkg/tcpip/link/fdbased/mmap_stub.go1
-rw-r--r--pkg/tcpip/link/fdbased/mmap_unsafe.go1
-rw-r--r--pkg/tcpip/link/fdbased/packet_dispatchers.go1
-rw-r--r--pkg/tcpip/link/rawfile/blockingpoll_noyield_unsafe.go1
-rw-r--r--pkg/tcpip/link/rawfile/blockingpoll_yield_unsafe.go1
-rw-r--r--pkg/tcpip/link/rawfile/errors.go1
-rw-r--r--pkg/tcpip/link/rawfile/errors_test.go1
-rw-r--r--pkg/tcpip/link/rawfile/rawfile_unsafe.go62
-rw-r--r--pkg/tcpip/link/sharedmem/rx.go1
-rw-r--r--pkg/tcpip/link/sharedmem/sharedmem.go1
-rw-r--r--pkg/tcpip/link/sharedmem/sharedmem_test.go1
-rw-r--r--pkg/tcpip/link/sniffer/pcap.go5
-rw-r--r--pkg/tcpip/link/sniffer/sniffer.go39
-rw-r--r--pkg/tcpip/link/tun/BUILD1
-rw-r--r--pkg/tcpip/link/tun/device.go31
-rw-r--r--pkg/tcpip/link/tun/tun_unsafe.go1
-rw-r--r--pkg/tcpip/network/internal/ip/generic_multicast_protocol_test.go10
-rw-r--r--pkg/tcpip/network/ipv6/ipv6.go5
-rw-r--r--pkg/tcpip/network/ipv6/ndp.go187
-rw-r--r--pkg/tcpip/network/ipv6/ndp_test.go2
-rw-r--r--pkg/tcpip/ports/BUILD1
-rw-r--r--pkg/tcpip/ports/ports.go40
-rw-r--r--pkg/tcpip/sample/tun_tcp_connect/main.go1
-rw-r--r--pkg/tcpip/sample/tun_tcp_echo/main.go1
-rw-r--r--pkg/tcpip/socketops.go111
-rw-r--r--pkg/tcpip/stack/addressable_endpoint_state.go2
-rw-r--r--pkg/tcpip/stack/conntrack.go55
-rw-r--r--pkg/tcpip/stack/iptables_targets.go97
-rw-r--r--pkg/tcpip/stack/ndp_test.go307
-rw-r--r--pkg/tcpip/stack/tcp.go3
-rw-r--r--pkg/tcpip/transport/icmp/endpoint.go5
-rw-r--r--pkg/tcpip/transport/raw/endpoint.go2
-rw-r--r--pkg/tcpip/transport/tcp/accept.go9
-rw-r--r--pkg/tcpip/transport/tcp/connect.go117
-rw-r--r--pkg/tcpip/transport/tcp/dispatcher.go2
-rw-r--r--pkg/tcpip/transport/tcp/endpoint.go180
-rw-r--r--pkg/tcpip/transport/tcp/forwarder.go5
-rw-r--r--pkg/tcpip/transport/tcp/tcp_noracedetector_test.go1
-rw-r--r--pkg/tcpip/transport/tcp/tcp_test.go201
-rw-r--r--pkg/tcpip/transport/tcp/testing/context/context.go2
-rw-r--r--pkg/tcpip/transport/udp/endpoint.go150
56 files changed, 1859 insertions, 784 deletions
diff --git a/pkg/tcpip/BUILD b/pkg/tcpip/BUILD
index ed4d7e958..f00cfd0f5 100644
--- a/pkg/tcpip/BUILD
+++ b/pkg/tcpip/BUILD
@@ -46,7 +46,6 @@ deps_test(
"//pkg/gohacks",
"//pkg/goid",
"//pkg/ilist",
- "//pkg/iovec",
"//pkg/linewriter",
"//pkg/log",
"//pkg/rand",
diff --git a/pkg/tcpip/header/checksum.go b/pkg/tcpip/header/checksum.go
index 6aa9acfa8..e2c85e220 100644
--- a/pkg/tcpip/header/checksum.go
+++ b/pkg/tcpip/header/checksum.go
@@ -18,6 +18,7 @@ package header
import (
"encoding/binary"
+ "fmt"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
@@ -234,3 +235,64 @@ func PseudoHeaderChecksum(protocol tcpip.TransportProtocolNumber, srcAddr tcpip.
return Checksum([]byte{0, uint8(protocol)}, xsum)
}
+
+// checksumUpdate2ByteAlignedUint16 updates a uint16 value in a calculated
+// checksum.
+//
+// The value MUST begin at a 2-byte boundary in the original buffer.
+func checksumUpdate2ByteAlignedUint16(xsum, old, new uint16) uint16 {
+ // As per RFC 1071 page 4,
+ // (4) Incremental Update
+ //
+ // ...
+ //
+ // To update the checksum, simply add the differences of the
+ // sixteen bit integers that have been changed. To see why this
+ // works, observe that every 16-bit integer has an additive inverse
+ // and that addition is associative. From this it follows that
+ // given the original value m, the new value m', and the old
+ // checksum C, the new checksum C' is:
+ //
+ // C' = C + (-m) + m' = C + (m' - m)
+ return ChecksumCombine(xsum, ChecksumCombine(new, ^old))
+}
+
+// checksumUpdate2ByteAlignedAddress updates an address in a calculated
+// checksum.
+//
+// The addresses must have the same length and must contain an even number
+// of bytes. The address MUST begin at a 2-byte boundary in the original buffer.
+func checksumUpdate2ByteAlignedAddress(xsum uint16, old, new tcpip.Address) uint16 {
+ const uint16Bytes = 2
+
+ if len(old) != len(new) {
+ panic(fmt.Sprintf("buffer lengths are different; old = %d, new = %d", len(old), len(new)))
+ }
+
+ if len(old)%uint16Bytes != 0 {
+ panic(fmt.Sprintf("buffer has an odd number of bytes; got = %d", len(old)))
+ }
+
+ // As per RFC 1071 page 4,
+ // (4) Incremental Update
+ //
+ // ...
+ //
+ // To update the checksum, simply add the differences of the
+ // sixteen bit integers that have been changed. To see why this
+ // works, observe that every 16-bit integer has an additive inverse
+ // and that addition is associative. From this it follows that
+ // given the original value m, the new value m', and the old
+ // checksum C, the new checksum C' is:
+ //
+ // C' = C + (-m) + m' = C + (m' - m)
+ for len(old) != 0 {
+ // Convert the 2 byte sequences to uint16 values then apply the increment
+ // update.
+ xsum = checksumUpdate2ByteAlignedUint16(xsum, (uint16(old[0])<<8)+uint16(old[1]), (uint16(new[0])<<8)+uint16(new[1]))
+ old = old[uint16Bytes:]
+ new = new[uint16Bytes:]
+ }
+
+ return xsum
+}
diff --git a/pkg/tcpip/header/checksum_test.go b/pkg/tcpip/header/checksum_test.go
index d267dabd0..3445511f4 100644
--- a/pkg/tcpip/header/checksum_test.go
+++ b/pkg/tcpip/header/checksum_test.go
@@ -23,6 +23,7 @@ import (
"sync"
"testing"
+ "gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
)
@@ -256,3 +257,205 @@ func TestICMPv6Checksum(t *testing.T) {
})
}, want, fmt.Sprintf("header: {% x} data {% x}", h, vv.ToView()))
}
+
+func randomAddress(size int) tcpip.Address {
+ s := make([]byte, size)
+ for i := 0; i < size; i++ {
+ s[i] = byte(rand.Uint32())
+ }
+ return tcpip.Address(s)
+}
+
+func TestChecksummableNetworkUpdateAddress(t *testing.T) {
+ tests := []struct {
+ name string
+ update func(header.IPv4, tcpip.Address)
+ }{
+ {
+ name: "SetSourceAddressWithChecksumUpdate",
+ update: header.IPv4.SetSourceAddressWithChecksumUpdate,
+ },
+ {
+ name: "SetDestinationAddressWithChecksumUpdate",
+ update: header.IPv4.SetDestinationAddressWithChecksumUpdate,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ for i := 0; i < 1000; i++ {
+ var origBytes [header.IPv4MinimumSize]byte
+ header.IPv4(origBytes[:]).Encode(&header.IPv4Fields{
+ TOS: 1,
+ TotalLength: header.IPv4MinimumSize,
+ ID: 2,
+ Flags: 3,
+ FragmentOffset: 4,
+ TTL: 5,
+ Protocol: 6,
+ Checksum: 0,
+ SrcAddr: randomAddress(header.IPv4AddressSize),
+ DstAddr: randomAddress(header.IPv4AddressSize),
+ })
+
+ addr := randomAddress(header.IPv4AddressSize)
+
+ bytesCopy := origBytes
+ h := header.IPv4(bytesCopy[:])
+ origXSum := h.CalculateChecksum()
+ h.SetChecksum(^origXSum)
+
+ test.update(h, addr)
+ got := ^h.Checksum()
+ h.SetChecksum(0)
+ want := h.CalculateChecksum()
+ if got != want {
+ t.Errorf("got h.Checksum() = 0x%x, want = 0x%x; originalBytes = 0x%x, new addr = %s", got, want, origBytes, addr)
+ }
+ }
+ })
+ }
+}
+
+func TestChecksummableTransportUpdatePort(t *testing.T) {
+ // The fields in the pseudo header is not tested here so we just use 0.
+ const pseudoHeaderXSum = 0
+
+ tests := []struct {
+ name string
+ transportHdr func(_, _ uint16) (header.ChecksummableTransport, func(uint16) uint16)
+ proto tcpip.TransportProtocolNumber
+ }{
+ {
+ name: "TCP",
+ transportHdr: func(src, dst uint16) (header.ChecksummableTransport, func(uint16) uint16) {
+ h := header.TCP(make([]byte, header.TCPMinimumSize))
+ h.Encode(&header.TCPFields{
+ SrcPort: src,
+ DstPort: dst,
+ SeqNum: 1,
+ AckNum: 2,
+ DataOffset: header.TCPMinimumSize,
+ Flags: 3,
+ WindowSize: 4,
+ Checksum: 0,
+ UrgentPointer: 5,
+ })
+ h.SetChecksum(^h.CalculateChecksum(pseudoHeaderXSum))
+ return h, h.CalculateChecksum
+ },
+ proto: header.TCPProtocolNumber,
+ },
+ {
+ name: "UDP",
+ transportHdr: func(src, dst uint16) (header.ChecksummableTransport, func(uint16) uint16) {
+ h := header.UDP(make([]byte, header.UDPMinimumSize))
+ h.Encode(&header.UDPFields{
+ SrcPort: src,
+ DstPort: dst,
+ Length: 0,
+ Checksum: 0,
+ })
+ h.SetChecksum(^h.CalculateChecksum(pseudoHeaderXSum))
+ return h, h.CalculateChecksum
+ },
+ proto: header.UDPProtocolNumber,
+ },
+ }
+
+ for i := 0; i < 1000; i++ {
+ origSrcPort := uint16(rand.Uint32())
+ origDstPort := uint16(rand.Uint32())
+ newPort := uint16(rand.Uint32())
+
+ t.Run(fmt.Sprintf("OrigSrcPort=%d,OrigDstPort=%d,NewPort=%d", origSrcPort, origDstPort, newPort), func(*testing.T) {
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ for _, subTest := range []struct {
+ name string
+ update func(header.ChecksummableTransport)
+ }{
+ {
+ name: "Source port",
+ update: func(h header.ChecksummableTransport) { h.SetSourcePortWithChecksumUpdate(newPort) },
+ },
+ {
+ name: "Destination port",
+ update: func(h header.ChecksummableTransport) { h.SetDestinationPortWithChecksumUpdate(newPort) },
+ },
+ } {
+ t.Run(subTest.name, func(t *testing.T) {
+ h, calcXSum := test.transportHdr(origSrcPort, origDstPort)
+ subTest.update(h)
+ // TCP and UDP hold the 1s complement of the fully calculated
+ // checksum.
+ got := ^h.Checksum()
+ h.SetChecksum(0)
+
+ if want := calcXSum(pseudoHeaderXSum); got != want {
+ h, _ := test.transportHdr(origSrcPort, origDstPort)
+ t.Errorf("got Checksum() = 0x%x, want = 0x%x; originalBytes = %#v, new port = %d", got, want, h, newPort)
+ }
+ })
+ }
+ })
+ }
+ })
+ }
+}
+
+func TestChecksummableTransportUpdatePseudoHeaderAddress(t *testing.T) {
+ const addressSize = 6
+
+ tests := []struct {
+ name string
+ transportHdr func() header.ChecksummableTransport
+ proto tcpip.TransportProtocolNumber
+ }{
+ {
+ name: "TCP",
+ transportHdr: func() header.ChecksummableTransport { return header.TCP(make([]byte, header.TCPMinimumSize)) },
+ proto: header.TCPProtocolNumber,
+ },
+ {
+ name: "UDP",
+ transportHdr: func() header.ChecksummableTransport { return header.UDP(make([]byte, header.UDPMinimumSize)) },
+ proto: header.UDPProtocolNumber,
+ },
+ }
+
+ for i := 0; i < 1000; i++ {
+ permanent := randomAddress(addressSize)
+ old := randomAddress(addressSize)
+ new := randomAddress(addressSize)
+
+ t.Run(fmt.Sprintf("Permanent=%q,Old=%q,New=%q", permanent, old, new), func(t *testing.T) {
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ for _, fullChecksum := range []bool{true, false} {
+ t.Run(fmt.Sprintf("FullChecksum=%t", fullChecksum), func(t *testing.T) {
+ initialXSum := header.PseudoHeaderChecksum(test.proto, permanent, old, 0)
+ if fullChecksum {
+ // TCP and UDP hold the 1s complement of the fully calculated
+ // checksum.
+ initialXSum = ^initialXSum
+ }
+
+ h := test.transportHdr()
+ h.SetChecksum(initialXSum)
+ h.UpdateChecksumPseudoHeaderAddress(old, new, fullChecksum)
+
+ got := h.Checksum()
+ if fullChecksum {
+ got = ^got
+ }
+ if want := header.PseudoHeaderChecksum(test.proto, permanent, new, 0); got != want {
+ t.Errorf("got Checksum() = 0x%x, want = 0x%x; h = %#v", got, want, h)
+ }
+ })
+ }
+ })
+ }
+ })
+ }
+}
diff --git a/pkg/tcpip/header/interfaces.go b/pkg/tcpip/header/interfaces.go
index 861cbbb70..3a41adfc4 100644
--- a/pkg/tcpip/header/interfaces.go
+++ b/pkg/tcpip/header/interfaces.go
@@ -53,6 +53,31 @@ type Transport interface {
Payload() []byte
}
+// ChecksummableTransport is a Transport that supports checksumming.
+type ChecksummableTransport interface {
+ Transport
+
+ // SetSourcePortWithChecksumUpdate sets the source port and updates
+ // the checksum.
+ //
+ // The receiver's checksum must be a fully calculated checksum.
+ SetSourcePortWithChecksumUpdate(port uint16)
+
+ // SetDestinationPortWithChecksumUpdate sets the destination port and updates
+ // the checksum.
+ //
+ // The receiver's checksum must be a fully calculated checksum.
+ SetDestinationPortWithChecksumUpdate(port uint16)
+
+ // UpdateChecksumPseudoHeaderAddress updates the checksum to reflect an
+ // updated address in the pseudo header.
+ //
+ // If fullChecksum is true, the receiver's checksum field is assumed to hold a
+ // fully calculated checksum. Otherwise, it is assumed to hold a partially
+ // calculated checksum which only reflects the pseudo header.
+ UpdateChecksumPseudoHeaderAddress(old, new tcpip.Address, fullChecksum bool)
+}
+
// Network offers generic methods to query and/or update the fields of the
// header of a network protocol buffer.
type Network interface {
@@ -90,3 +115,16 @@ type Network interface {
// SetTOS sets the values of the "type of service" and "flow label" fields.
SetTOS(t uint8, l uint32)
}
+
+// ChecksummableNetwork is a Network that supports checksumming.
+type ChecksummableNetwork interface {
+ Network
+
+ // SetSourceAddressAndChecksum sets the source address and updates the
+ // checksum to reflect the new address.
+ SetSourceAddressWithChecksumUpdate(tcpip.Address)
+
+ // SetDestinationAddressAndChecksum sets the destination address and
+ // updates the checksum to reflect the new address.
+ SetDestinationAddressWithChecksumUpdate(tcpip.Address)
+}
diff --git a/pkg/tcpip/header/ipv4.go b/pkg/tcpip/header/ipv4.go
index e9abbb709..dcc549c7b 100644
--- a/pkg/tcpip/header/ipv4.go
+++ b/pkg/tcpip/header/ipv4.go
@@ -305,6 +305,18 @@ func (b IPv4) DestinationAddress() tcpip.Address {
return tcpip.Address(b[dstAddr : dstAddr+IPv4AddressSize])
}
+// SetSourceAddressWithChecksumUpdate implements ChecksummableNetwork.
+func (b IPv4) SetSourceAddressWithChecksumUpdate(new tcpip.Address) {
+ b.SetChecksum(^checksumUpdate2ByteAlignedAddress(^b.Checksum(), b.SourceAddress(), new))
+ b.SetSourceAddress(new)
+}
+
+// SetDestinationAddressWithChecksumUpdate implements ChecksummableNetwork.
+func (b IPv4) SetDestinationAddressWithChecksumUpdate(new tcpip.Address) {
+ b.SetChecksum(^checksumUpdate2ByteAlignedAddress(^b.Checksum(), b.DestinationAddress(), new))
+ b.SetDestinationAddress(new)
+}
+
// padIPv4OptionsLength returns the total length for IPv4 options of length l
// after applying padding according to RFC 791:
// The internet header padding is used to ensure that the internet
diff --git a/pkg/tcpip/header/ndp_options.go b/pkg/tcpip/header/ndp_options.go
index b1f39e6e6..a647ea968 100644
--- a/pkg/tcpip/header/ndp_options.go
+++ b/pkg/tcpip/header/ndp_options.go
@@ -233,6 +233,17 @@ func (i *NDPOptionIterator) Next() (NDPOption, bool, error) {
case ndpNonceOptionType:
return NDPNonceOption(body), false, nil
+ case ndpRouteInformationType:
+ if numBodyBytes > ndpRouteInformationMaxLength {
+ return nil, true, fmt.Errorf("got %d bytes for NDP Route Information option's body, expected at max %d bytes: %w", numBodyBytes, ndpRouteInformationMaxLength, ErrNDPOptMalformedBody)
+ }
+ opt := NDPRouteInformation(body)
+ if err := opt.hasError(); err != nil {
+ return nil, true, err
+ }
+
+ return opt, false, nil
+
case ndpPrefixInformationType:
// Make sure the length of a Prefix Information option
// body is ndpPrefixInformationLength, as per RFC 4861
@@ -930,3 +941,137 @@ func isUpperLetter(b byte) bool {
func isDigit(b byte) bool {
return b >= '0' && b <= '9'
}
+
+// As per RFC 4191 section 2.3,
+//
+// 2.3. Route Information Option
+//
+// 0 1 2 3
+// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
+// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
+// | Type | Length | Prefix Length |Resvd|Prf|Resvd|
+// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
+// | Route Lifetime |
+// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
+// | Prefix (Variable Length) |
+// . .
+// . .
+// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
+//
+// Fields:
+//
+// Type 24
+//
+//
+// Length 8-bit unsigned integer. The length of the option
+// (including the Type and Length fields) in units of 8
+// octets. The Length field is 1, 2, or 3 depending on the
+// Prefix Length. If Prefix Length is greater than 64, then
+// Length must be 3. If Prefix Length is greater than 0,
+// then Length must be 2 or 3. If Prefix Length is zero,
+// then Length must be 1, 2, or 3.
+const (
+ ndpRouteInformationType = ndpOptionIdentifier(24)
+ ndpRouteInformationMaxLength = 22
+
+ ndpRouteInformationPrefixLengthIdx = 0
+ ndpRouteInformationFlagsIdx = 1
+ ndpRouteInformationPrfShift = 3
+ ndpRouteInformationPrfMask = 3 << ndpRouteInformationPrfShift
+ ndpRouteInformationRouteLifetimeIdx = 2
+ ndpRouteInformationRoutePrefixIdx = 6
+)
+
+// NDPRouteInformation is the NDP Router Information option, as defined by
+// RFC 4191 section 2.3.
+type NDPRouteInformation []byte
+
+func (NDPRouteInformation) kind() ndpOptionIdentifier {
+ return ndpRouteInformationType
+}
+
+func (o NDPRouteInformation) length() int {
+ return len(o)
+}
+
+func (o NDPRouteInformation) serializeInto(b []byte) int {
+ return copy(b, o)
+}
+
+// String implements fmt.Stringer.
+func (o NDPRouteInformation) String() string {
+ return fmt.Sprintf("%T", o)
+}
+
+// PrefixLength returns the length of the prefix.
+func (o NDPRouteInformation) PrefixLength() uint8 {
+ return o[ndpRouteInformationPrefixLengthIdx]
+}
+
+// RoutePreference returns the preference of the route over other routes to the
+// same destination but through a different router.
+func (o NDPRouteInformation) RoutePreference() NDPRoutePreference {
+ return NDPRoutePreference((o[ndpRouteInformationFlagsIdx] & ndpRouteInformationPrfMask) >> ndpRouteInformationPrfShift)
+}
+
+// RouteLifetime returns the lifetime of the route.
+//
+// Note, a value of 0 implies the route is now invalid and a value of
+// infinity/forever is represented by NDPInfiniteLifetime.
+func (o NDPRouteInformation) RouteLifetime() time.Duration {
+ return time.Second * time.Duration(binary.BigEndian.Uint32(o[ndpRouteInformationRouteLifetimeIdx:]))
+}
+
+// Prefix returns the prefix of the destination subnet this route is for.
+func (o NDPRouteInformation) Prefix() (tcpip.Subnet, error) {
+ prefixLength := int(o.PrefixLength())
+ if max := IPv6AddressSize * 8; prefixLength > max {
+ return tcpip.Subnet{}, fmt.Errorf("got prefix length = %d, want <= %d", prefixLength, max)
+ }
+
+ prefix := o[ndpRouteInformationRoutePrefixIdx:]
+ var addrBytes [IPv6AddressSize]byte
+ if n := copy(addrBytes[:], prefix); n != len(prefix) {
+ panic(fmt.Sprintf("got copy(addrBytes, prefix) = %d, want = %d", n, len(prefix)))
+ }
+
+ return tcpip.AddressWithPrefix{
+ Address: tcpip.Address(addrBytes[:]),
+ PrefixLen: prefixLength,
+ }.Subnet(), nil
+}
+
+func (o NDPRouteInformation) hasError() error {
+ l := len(o)
+ if l < ndpRouteInformationRoutePrefixIdx {
+ return fmt.Errorf("%T too small, got = %d bytes: %w", o, l, ErrNDPOptMalformedBody)
+ }
+
+ prefixLength := int(o.PrefixLength())
+ if max := IPv6AddressSize * 8; prefixLength > max {
+ return fmt.Errorf("got prefix length = %d, want <= %d: %w", prefixLength, max, ErrNDPOptMalformedBody)
+ }
+
+ // Length 8-bit unsigned integer. The length of the option
+ // (including the Type and Length fields) in units of 8
+ // octets. The Length field is 1, 2, or 3 depending on the
+ // Prefix Length. If Prefix Length is greater than 64, then
+ // Length must be 3. If Prefix Length is greater than 0,
+ // then Length must be 2 or 3. If Prefix Length is zero,
+ // then Length must be 1, 2, or 3.
+ l += 2 // Add 2 bytes for the type and length bytes.
+ lengthField := l / lengthByteUnits
+ if prefixLength > 64 {
+ if lengthField != 3 {
+ return fmt.Errorf("Length field must be 3 when Prefix Length (%d) is > 64 (got = %d): %w", prefixLength, lengthField, ErrNDPOptMalformedBody)
+ }
+ } else if prefixLength > 0 {
+ if lengthField != 2 && lengthField != 3 {
+ return fmt.Errorf("Length field must be 2 or 3 when Prefix Length (%d) is between 0 and 64 (got = %d): %w", prefixLength, lengthField, ErrNDPOptMalformedBody)
+ }
+ } else if lengthField == 0 || lengthField > 3 {
+ return fmt.Errorf("Length field must be 1, 2, or 3 when Prefix Length is zero (got = %d): %w", lengthField, ErrNDPOptMalformedBody)
+ }
+
+ return nil
+}
diff --git a/pkg/tcpip/header/ndp_router_advert.go b/pkg/tcpip/header/ndp_router_advert.go
index 7e2f0c797..7d6efa083 100644
--- a/pkg/tcpip/header/ndp_router_advert.go
+++ b/pkg/tcpip/header/ndp_router_advert.go
@@ -16,9 +16,12 @@ package header
import (
"encoding/binary"
+ "fmt"
"time"
)
+var _ fmt.Stringer = NDPRoutePreference(0)
+
// NDPRoutePreference is the preference values for default routers or
// more-specific routes.
//
@@ -64,6 +67,22 @@ const (
ReservedRoutePreference = 0b10
)
+// String implements fmt.Stringer.
+func (p NDPRoutePreference) String() string {
+ switch p {
+ case HighRoutePreference:
+ return "HighRoutePreference"
+ case MediumRoutePreference:
+ return "MediumRoutePreference"
+ case LowRoutePreference:
+ return "LowRoutePreference"
+ case ReservedRoutePreference:
+ return "ReservedRoutePreference"
+ default:
+ return fmt.Sprintf("NDPRoutePreference(%d)", p)
+ }
+}
+
// NDPRouterAdvert is an NDP Router Advertisement message. It will only contain
// the body of an ICMPv6 packet.
//
diff --git a/pkg/tcpip/header/ndp_test.go b/pkg/tcpip/header/ndp_test.go
index 8fd1f7d13..2a897e938 100644
--- a/pkg/tcpip/header/ndp_test.go
+++ b/pkg/tcpip/header/ndp_test.go
@@ -21,6 +21,7 @@ import (
"fmt"
"io"
"regexp"
+ "strings"
"testing"
"time"
@@ -58,6 +59,224 @@ func TestNDPNeighborSolicit(t *testing.T) {
}
}
+func TestNDPRouteInformationOption(t *testing.T) {
+ tests := []struct {
+ name string
+
+ length uint8
+ prefixLength uint8
+ prf NDPRoutePreference
+ lifetimeS uint32
+ prefixBytes []byte
+ expectedPrefix tcpip.Subnet
+
+ expectedErr error
+ }{
+ {
+ name: "Length=1 with Prefix Length = 0",
+ length: 1,
+ prefixLength: 0,
+ prf: MediumRoutePreference,
+ lifetimeS: 1,
+ prefixBytes: nil,
+ expectedPrefix: IPv6EmptySubnet,
+ },
+ {
+ name: "Length=1 but Prefix Length > 0",
+ length: 1,
+ prefixLength: 1,
+ prf: MediumRoutePreference,
+ lifetimeS: 1,
+ prefixBytes: nil,
+ expectedErr: ErrNDPOptMalformedBody,
+ },
+ {
+ name: "Length=2 with Prefix Length = 0",
+ length: 2,
+ prefixLength: 0,
+ prf: MediumRoutePreference,
+ lifetimeS: 1,
+ prefixBytes: nil,
+ expectedPrefix: IPv6EmptySubnet,
+ },
+ {
+ name: "Length=2 with Prefix Length in [1, 64] (1)",
+ length: 2,
+ prefixLength: 1,
+ prf: LowRoutePreference,
+ lifetimeS: 1,
+ prefixBytes: nil,
+ expectedPrefix: tcpip.AddressWithPrefix{
+ Address: tcpip.Address(strings.Repeat("\x00", IPv6AddressSize)),
+ PrefixLen: 1,
+ }.Subnet(),
+ },
+ {
+ name: "Length=2 with Prefix Length in [1, 64] (64)",
+ length: 2,
+ prefixLength: 64,
+ prf: HighRoutePreference,
+ lifetimeS: 1,
+ prefixBytes: nil,
+ expectedPrefix: tcpip.AddressWithPrefix{
+ Address: tcpip.Address(strings.Repeat("\x00", IPv6AddressSize)),
+ PrefixLen: 64,
+ }.Subnet(),
+ },
+ {
+ name: "Length=2 with Prefix Length > 64",
+ length: 2,
+ prefixLength: 65,
+ prf: HighRoutePreference,
+ lifetimeS: 1,
+ prefixBytes: nil,
+ expectedErr: ErrNDPOptMalformedBody,
+ },
+ {
+ name: "Length=3 with Prefix Length = 0",
+ length: 3,
+ prefixLength: 0,
+ prf: MediumRoutePreference,
+ lifetimeS: 1,
+ prefixBytes: nil,
+ expectedPrefix: IPv6EmptySubnet,
+ },
+ {
+ name: "Length=3 with Prefix Length in [1, 64] (1)",
+ length: 3,
+ prefixLength: 1,
+ prf: LowRoutePreference,
+ lifetimeS: 1,
+ prefixBytes: nil,
+ expectedPrefix: tcpip.AddressWithPrefix{
+ Address: tcpip.Address(strings.Repeat("\x00", IPv6AddressSize)),
+ PrefixLen: 1,
+ }.Subnet(),
+ },
+ {
+ name: "Length=3 with Prefix Length in [1, 64] (64)",
+ length: 3,
+ prefixLength: 64,
+ prf: HighRoutePreference,
+ lifetimeS: 1,
+ prefixBytes: nil,
+ expectedPrefix: tcpip.AddressWithPrefix{
+ Address: tcpip.Address(strings.Repeat("\x00", IPv6AddressSize)),
+ PrefixLen: 64,
+ }.Subnet(),
+ },
+ {
+ name: "Length=3 with Prefix Length in [65, 128] (65)",
+ length: 3,
+ prefixLength: 65,
+ prf: HighRoutePreference,
+ lifetimeS: 1,
+ prefixBytes: nil,
+ expectedPrefix: tcpip.AddressWithPrefix{
+ Address: tcpip.Address(strings.Repeat("\x00", IPv6AddressSize)),
+ PrefixLen: 65,
+ }.Subnet(),
+ },
+ {
+ name: "Length=3 with Prefix Length in [65, 128] (128)",
+ length: 3,
+ prefixLength: 128,
+ prf: HighRoutePreference,
+ lifetimeS: 1,
+ prefixBytes: nil,
+ expectedPrefix: tcpip.AddressWithPrefix{
+ Address: tcpip.Address(strings.Repeat("\x00", IPv6AddressSize)),
+ PrefixLen: 128,
+ }.Subnet(),
+ },
+ {
+ name: "Length=3 with (invalid) Prefix Length > 128",
+ length: 3,
+ prefixLength: 129,
+ prf: HighRoutePreference,
+ lifetimeS: 1,
+ prefixBytes: nil,
+ expectedErr: ErrNDPOptMalformedBody,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ expectedRouteInformationBytes := [...]byte{
+ // Type, Length
+ 24, test.length,
+
+ // Prefix Length, Prf
+ uint8(test.prefixLength), uint8(test.prf) << 3,
+
+ // Route Lifetime
+ 0, 0, 0, 0,
+
+ 0, 0, 0, 0,
+ 0, 0, 0, 0,
+ 0, 0, 0, 0,
+ 0, 0, 0, 0,
+ }
+ binary.BigEndian.PutUint32(expectedRouteInformationBytes[4:], test.lifetimeS)
+ _ = copy(expectedRouteInformationBytes[8:], test.prefixBytes)
+
+ opts := NDPOptions(expectedRouteInformationBytes[:test.length*lengthByteUnits])
+ it, err := opts.Iter(false)
+ if err != nil {
+ t.Fatalf("got Iter(false) = (_, %s), want = (_, nil)", err)
+ }
+ opt, done, err := it.Next()
+ if !errors.Is(err, test.expectedErr) {
+ t.Fatalf("got Next() = (_, _, %s), want = (_, _, %s)", err, test.expectedErr)
+ }
+ if want := test.expectedErr != nil; done != want {
+ t.Fatalf("got Next() = (_, %t, _), want = (_, %t, _)", done, want)
+ }
+ if test.expectedErr != nil {
+ return
+ }
+
+ if got := opt.kind(); got != ndpRouteInformationType {
+ t.Errorf("got kind() = %d, want = %d", got, ndpRouteInformationType)
+ }
+
+ ri, ok := opt.(NDPRouteInformation)
+ if !ok {
+ t.Fatalf("got opt = %T, want = NDPRouteInformation", opt)
+ }
+
+ if got := ri.PrefixLength(); got != test.prefixLength {
+ t.Errorf("got PrefixLength() = %d, want = %d", got, test.prefixLength)
+ }
+ if got := ri.RoutePreference(); got != test.prf {
+ t.Errorf("got RoutePreference() = %d, want = %d", got, test.prf)
+ }
+ if got, want := ri.RouteLifetime(), time.Duration(test.lifetimeS)*time.Second; got != want {
+ t.Errorf("got RouteLifetime() = %s, want = %s", got, want)
+ }
+ if got, err := ri.Prefix(); err != nil {
+ t.Errorf("Prefix(): %s", err)
+ } else if got != test.expectedPrefix {
+ t.Errorf("got Prefix() = %s, want = %s", got, test.expectedPrefix)
+ }
+
+ // Iterator should not return anything else.
+ {
+ next, done, err := it.Next()
+ if err != nil {
+ t.Errorf("got Next() = (_, _, %s), want = (_, _, nil)", err)
+ }
+ if !done {
+ t.Error("got Next() = (_, false, _), want = (_, true, _)")
+ }
+ if next != nil {
+ t.Errorf("got Next() = (%x, _, _), want = (nil, _, _)", next)
+ }
+ }
+ })
+ }
+}
+
// TestNDPNeighborAdvert tests the functions of NDPNeighborAdvert.
func TestNDPNeighborAdvert(t *testing.T) {
b := []byte{
@@ -1498,3 +1717,32 @@ func TestNDPOptionsIter(t *testing.T) {
t.Errorf("got Next = (%x, _, _), want = (nil, _, _)", next)
}
}
+
+func TestNDPRoutePreferenceStringer(t *testing.T) {
+ p := NDPRoutePreference(0)
+ for {
+ var wantStr string
+ switch p {
+ case 0b01:
+ wantStr = "HighRoutePreference"
+ case 0b00:
+ wantStr = "MediumRoutePreference"
+ case 0b11:
+ wantStr = "LowRoutePreference"
+ case 0b10:
+ wantStr = "ReservedRoutePreference"
+ default:
+ wantStr = fmt.Sprintf("NDPRoutePreference(%d)", p)
+ }
+
+ if gotStr := p.String(); gotStr != wantStr {
+ t.Errorf("got NDPRoutePreference(%d).String() = %s, want = %s", p, gotStr, wantStr)
+ }
+
+ p++
+ if p == 0 {
+ // Overflowed, we hit all values.
+ break
+ }
+ }
+}
diff --git a/pkg/tcpip/header/tcp.go b/pkg/tcpip/header/tcp.go
index 8dabe3354..a75e51a28 100644
--- a/pkg/tcpip/header/tcp.go
+++ b/pkg/tcpip/header/tcp.go
@@ -390,6 +390,35 @@ func (b TCP) EncodePartial(partialChecksum, length uint16, seqnum, acknum uint32
b.SetChecksum(^checksum)
}
+// SetSourcePortWithChecksumUpdate implements ChecksummableTransport.
+func (b TCP) SetSourcePortWithChecksumUpdate(new uint16) {
+ old := b.SourcePort()
+ b.SetSourcePort(new)
+ b.SetChecksum(^checksumUpdate2ByteAlignedUint16(^b.Checksum(), old, new))
+}
+
+// SetDestinationPortWithChecksumUpdate implements ChecksummableTransport.
+func (b TCP) SetDestinationPortWithChecksumUpdate(new uint16) {
+ old := b.DestinationPort()
+ b.SetDestinationPort(new)
+ b.SetChecksum(^checksumUpdate2ByteAlignedUint16(^b.Checksum(), old, new))
+}
+
+// UpdateChecksumPseudoHeaderAddress implements ChecksummableTransport.
+func (b TCP) UpdateChecksumPseudoHeaderAddress(old, new tcpip.Address, fullChecksum bool) {
+ xsum := b.Checksum()
+ if fullChecksum {
+ xsum = ^xsum
+ }
+
+ xsum = checksumUpdate2ByteAlignedAddress(xsum, old, new)
+ if fullChecksum {
+ xsum = ^xsum
+ }
+
+ b.SetChecksum(xsum)
+}
+
// ParseSynOptions parses the options received in a SYN segment and returns the
// relevant ones. opts should point to the option part of the TCP header.
func ParseSynOptions(opts []byte, isAck bool) TCPSynOptions {
diff --git a/pkg/tcpip/header/udp.go b/pkg/tcpip/header/udp.go
index ae9d167ff..f69d53314 100644
--- a/pkg/tcpip/header/udp.go
+++ b/pkg/tcpip/header/udp.go
@@ -130,3 +130,32 @@ func (b UDP) Encode(u *UDPFields) {
binary.BigEndian.PutUint16(b[udpLength:], u.Length)
binary.BigEndian.PutUint16(b[udpChecksum:], u.Checksum)
}
+
+// SetSourcePortWithChecksumUpdate implements ChecksummableTransport.
+func (b UDP) SetSourcePortWithChecksumUpdate(new uint16) {
+ old := b.SourcePort()
+ b.SetSourcePort(new)
+ b.SetChecksum(^checksumUpdate2ByteAlignedUint16(^b.Checksum(), old, new))
+}
+
+// SetDestinationPortWithChecksumUpdate implements ChecksummableTransport.
+func (b UDP) SetDestinationPortWithChecksumUpdate(new uint16) {
+ old := b.DestinationPort()
+ b.SetDestinationPort(new)
+ b.SetChecksum(^checksumUpdate2ByteAlignedUint16(^b.Checksum(), old, new))
+}
+
+// UpdateChecksumPseudoHeaderAddress implements ChecksummableTransport.
+func (b UDP) UpdateChecksumPseudoHeaderAddress(old, new tcpip.Address, fullChecksum bool) {
+ xsum := b.Checksum()
+ if fullChecksum {
+ xsum = ^xsum
+ }
+
+ xsum = checksumUpdate2ByteAlignedAddress(xsum, old, new)
+ if fullChecksum {
+ xsum = ^xsum
+ }
+
+ b.SetChecksum(xsum)
+}
diff --git a/pkg/tcpip/link/fdbased/BUILD b/pkg/tcpip/link/fdbased/BUILD
index d971194e6..1d0163823 100644
--- a/pkg/tcpip/link/fdbased/BUILD
+++ b/pkg/tcpip/link/fdbased/BUILD
@@ -14,7 +14,6 @@ go_library(
],
visibility = ["//visibility:public"],
deps = [
- "//pkg/iovec",
"//pkg/sync",
"//pkg/tcpip",
"//pkg/tcpip/buffer",
diff --git a/pkg/tcpip/link/fdbased/endpoint.go b/pkg/tcpip/link/fdbased/endpoint.go
index 735c28da1..e8e716db0 100644
--- a/pkg/tcpip/link/fdbased/endpoint.go
+++ b/pkg/tcpip/link/fdbased/endpoint.go
@@ -12,6 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
+//go:build linux
// +build linux
// Package fdbased provides the implemention of data-link layer endpoints
@@ -44,7 +45,6 @@ import (
"sync/atomic"
"golang.org/x/sys/unix"
- "gvisor.dev/gvisor/pkg/iovec"
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
@@ -138,6 +138,20 @@ type endpoint struct {
// gsoKind is the supported kind of GSO.
gsoKind stack.SupportedGSO
+
+ // maxSyscallHeaderBytes has the same meaning as
+ // Options.MaxSyscallHeaderBytes.
+ maxSyscallHeaderBytes uintptr
+
+ // writevMaxIovs is the maximum number of iovecs that may be passed to
+ // rawfile.NonBlockingWriteIovec, as possibly limited by
+ // maxSyscallHeaderBytes. (No analogous limit is defined for
+ // rawfile.NonBlockingSendMMsg, since in that case the maximum number of
+ // iovecs also depends on the number of mmsghdrs. Instead, if sendBatch
+ // encounters a packet whose iovec count is limited by
+ // maxSyscallHeaderBytes, it falls back to writing the packet using writev
+ // via WritePacket.)
+ writevMaxIovs int
}
// Options specify the details about the fd-based endpoint to be created.
@@ -186,6 +200,11 @@ type Options struct {
// RXChecksumOffload if true, indicates that this endpoints capability
// set should include CapabilityRXChecksumOffload.
RXChecksumOffload bool
+
+ // If MaxSyscallHeaderBytes is non-zero, it is the maximum number of bytes
+ // of struct iovec, msghdr, and mmsghdr that may be passed by each host
+ // system call.
+ MaxSyscallHeaderBytes int
}
// fanoutID is used for AF_PACKET based endpoints to enable PACKET_FANOUT
@@ -235,14 +254,25 @@ func New(opts *Options) (stack.LinkEndpoint, error) {
return nil, fmt.Errorf("opts.FD is empty, at least one FD must be specified")
}
+ if opts.MaxSyscallHeaderBytes < 0 {
+ return nil, fmt.Errorf("opts.MaxSyscallHeaderBytes is negative")
+ }
+
e := &endpoint{
- fds: opts.FDs,
- mtu: opts.MTU,
- caps: caps,
- closed: opts.ClosedFunc,
- addr: opts.Address,
- hdrSize: hdrSize,
- packetDispatchMode: opts.PacketDispatchMode,
+ fds: opts.FDs,
+ mtu: opts.MTU,
+ caps: caps,
+ closed: opts.ClosedFunc,
+ addr: opts.Address,
+ hdrSize: hdrSize,
+ packetDispatchMode: opts.PacketDispatchMode,
+ maxSyscallHeaderBytes: uintptr(opts.MaxSyscallHeaderBytes),
+ writevMaxIovs: rawfile.MaxIovs,
+ }
+ if e.maxSyscallHeaderBytes != 0 {
+ if max := int(e.maxSyscallHeaderBytes / rawfile.SizeofIovec); max < e.writevMaxIovs {
+ e.writevMaxIovs = max
+ }
}
// Increment fanoutID to ensure that we don't re-use the same fanoutID for
@@ -470,9 +500,8 @@ func (e *endpoint) WritePacket(r stack.RouteInfo, protocol tcpip.NetworkProtocol
e.AddHeader(r.LocalLinkAddress, r.RemoteLinkAddress, protocol, pkt)
}
- var builder iovec.Builder
-
fd := e.fds[pkt.Hash%uint32(len(e.fds))]
+ var vnetHdrBuf []byte
if e.gsoKind == stack.HWGSOSupported {
vnetHdr := virtioNetHdr{}
if pkt.GSOOptions.Type != stack.GSONone {
@@ -494,71 +523,123 @@ func (e *endpoint) WritePacket(r stack.RouteInfo, protocol tcpip.NetworkProtocol
vnetHdr.gsoSize = pkt.GSOOptions.MSS
}
}
+ vnetHdrBuf = vnetHdr.marshal()
+ }
- vnetHdrBuf := vnetHdr.marshal()
- builder.Add(vnetHdrBuf)
+ views := pkt.Views()
+ numIovecs := len(views)
+ if len(vnetHdrBuf) != 0 {
+ numIovecs++
+ }
+ if numIovecs > e.writevMaxIovs {
+ numIovecs = e.writevMaxIovs
}
- for _, v := range pkt.Views() {
- builder.Add(v)
+ // Allocate small iovec arrays on the stack.
+ var iovecsArr [8]unix.Iovec
+ iovecs := iovecsArr[:0]
+ if numIovecs > len(iovecsArr) {
+ iovecs = make([]unix.Iovec, 0, numIovecs)
+ }
+ iovecs = rawfile.AppendIovecFromBytes(iovecs, vnetHdrBuf, numIovecs)
+ for _, v := range views {
+ iovecs = rawfile.AppendIovecFromBytes(iovecs, v, numIovecs)
}
- return rawfile.NonBlockingWriteIovec(fd, builder.Build())
+ return rawfile.NonBlockingWriteIovec(fd, iovecs)
}
-func (e *endpoint) sendBatch(batchFD int, batch []*stack.PacketBuffer) (int, tcpip.Error) {
+func (e *endpoint) sendBatch(batchFD int, pkts []*stack.PacketBuffer) (int, tcpip.Error) {
// Send a batch of packets through batchFD.
- mmsgHdrs := make([]rawfile.MMsgHdr, 0, len(batch))
- for _, pkt := range batch {
- if e.hdrSize > 0 {
- e.AddHeader(pkt.EgressRoute.LocalLinkAddress, pkt.EgressRoute.RemoteLinkAddress, pkt.NetworkProtocolNumber, pkt)
- }
+ mmsgHdrsStorage := make([]rawfile.MMsgHdr, 0, len(pkts))
+ packets := 0
+ for packets < len(pkts) {
+ mmsgHdrs := mmsgHdrsStorage
+ batch := pkts[packets:]
+ syscallHeaderBytes := uintptr(0)
+ for _, pkt := range batch {
+ if e.hdrSize > 0 {
+ e.AddHeader(pkt.EgressRoute.LocalLinkAddress, pkt.EgressRoute.RemoteLinkAddress, pkt.NetworkProtocolNumber, pkt)
+ }
- var vnetHdrBuf []byte
- if e.gsoKind == stack.HWGSOSupported {
- vnetHdr := virtioNetHdr{}
- if pkt.GSOOptions.Type != stack.GSONone {
- vnetHdr.hdrLen = uint16(pkt.HeaderSize())
- if pkt.GSOOptions.NeedsCsum {
- vnetHdr.flags = _VIRTIO_NET_HDR_F_NEEDS_CSUM
- vnetHdr.csumStart = header.EthernetMinimumSize + pkt.GSOOptions.L3HdrLen
- vnetHdr.csumOffset = pkt.GSOOptions.CsumOffset
- }
- if pkt.GSOOptions.Type != stack.GSONone && uint16(pkt.Data().Size()) > pkt.GSOOptions.MSS {
- switch pkt.GSOOptions.Type {
- case stack.GSOTCPv4:
- vnetHdr.gsoType = _VIRTIO_NET_HDR_GSO_TCPV4
- case stack.GSOTCPv6:
- vnetHdr.gsoType = _VIRTIO_NET_HDR_GSO_TCPV6
- default:
- panic(fmt.Sprintf("Unknown gso type: %v", pkt.GSOOptions.Type))
+ var vnetHdrBuf []byte
+ if e.gsoKind == stack.HWGSOSupported {
+ vnetHdr := virtioNetHdr{}
+ if pkt.GSOOptions.Type != stack.GSONone {
+ vnetHdr.hdrLen = uint16(pkt.HeaderSize())
+ if pkt.GSOOptions.NeedsCsum {
+ vnetHdr.flags = _VIRTIO_NET_HDR_F_NEEDS_CSUM
+ vnetHdr.csumStart = header.EthernetMinimumSize + pkt.GSOOptions.L3HdrLen
+ vnetHdr.csumOffset = pkt.GSOOptions.CsumOffset
+ }
+ if pkt.GSOOptions.Type != stack.GSONone && uint16(pkt.Data().Size()) > pkt.GSOOptions.MSS {
+ switch pkt.GSOOptions.Type {
+ case stack.GSOTCPv4:
+ vnetHdr.gsoType = _VIRTIO_NET_HDR_GSO_TCPV4
+ case stack.GSOTCPv6:
+ vnetHdr.gsoType = _VIRTIO_NET_HDR_GSO_TCPV6
+ default:
+ panic(fmt.Sprintf("Unknown gso type: %v", pkt.GSOOptions.Type))
+ }
+ vnetHdr.gsoSize = pkt.GSOOptions.MSS
}
- vnetHdr.gsoSize = pkt.GSOOptions.MSS
}
+ vnetHdrBuf = vnetHdr.marshal()
}
- vnetHdrBuf = vnetHdr.marshal()
- }
- var builder iovec.Builder
- builder.Add(vnetHdrBuf)
- for _, v := range pkt.Views() {
- builder.Add(v)
- }
- iovecs := builder.Build()
+ views := pkt.Views()
+ numIovecs := len(views)
+ if len(vnetHdrBuf) != 0 {
+ numIovecs++
+ }
+ if numIovecs > rawfile.MaxIovs {
+ numIovecs = rawfile.MaxIovs
+ }
+ if e.maxSyscallHeaderBytes != 0 {
+ syscallHeaderBytes += rawfile.SizeofMMsgHdr + uintptr(numIovecs)*rawfile.SizeofIovec
+ if syscallHeaderBytes > e.maxSyscallHeaderBytes {
+ // We can't fit this packet into this call to sendmmsg().
+ // We could potentially do so if we reduced numIovecs
+ // further, but this might incur considerable extra
+ // copying. Leave it to the next batch instead.
+ break
+ }
+ }
- var mmsgHdr rawfile.MMsgHdr
- mmsgHdr.Msg.Iov = &iovecs[0]
- mmsgHdr.Msg.SetIovlen((len(iovecs)))
- mmsgHdrs = append(mmsgHdrs, mmsgHdr)
- }
+ // We can't easily allocate iovec arrays on the stack here since
+ // they will escape this loop iteration via mmsgHdrs.
+ iovecs := make([]unix.Iovec, 0, numIovecs)
+ iovecs = rawfile.AppendIovecFromBytes(iovecs, vnetHdrBuf, numIovecs)
+ for _, v := range views {
+ iovecs = rawfile.AppendIovecFromBytes(iovecs, v, numIovecs)
+ }
- packets := 0
- for len(mmsgHdrs) > 0 {
- sent, err := rawfile.NonBlockingSendMMsg(batchFD, mmsgHdrs)
- if err != nil {
- return packets, err
+ var mmsgHdr rawfile.MMsgHdr
+ mmsgHdr.Msg.Iov = &iovecs[0]
+ mmsgHdr.Msg.SetIovlen(len(iovecs))
+ mmsgHdrs = append(mmsgHdrs, mmsgHdr)
+ }
+
+ if len(mmsgHdrs) == 0 {
+ // We can't fit batch[0] into a mmsghdr while staying under
+ // e.maxSyscallHeaderBytes. Use WritePacket, which will avoid the
+ // mmsghdr (by using writev) and re-buffer iovecs more aggressively
+ // if necessary (by using e.writevMaxIovs instead of
+ // rawfile.MaxIovs).
+ pkt := batch[0]
+ if err := e.WritePacket(pkt.EgressRoute, pkt.NetworkProtocolNumber, pkt); err != nil {
+ return packets, err
+ }
+ packets++
+ } else {
+ for len(mmsgHdrs) > 0 {
+ sent, err := rawfile.NonBlockingSendMMsg(batchFD, mmsgHdrs)
+ if err != nil {
+ return packets, err
+ }
+ packets += sent
+ mmsgHdrs = mmsgHdrs[sent:]
+ }
}
- packets += sent
- mmsgHdrs = mmsgHdrs[sent:]
}
return packets, nil
@@ -676,8 +757,9 @@ func NewInjectable(fd int, mtu uint32, capabilities stack.LinkEndpointCapabiliti
unix.SetNonblock(fd, true)
return &InjectableEndpoint{endpoint: endpoint{
- fds: []int{fd},
- mtu: mtu,
- caps: capabilities,
+ fds: []int{fd},
+ mtu: mtu,
+ caps: capabilities,
+ writevMaxIovs: rawfile.MaxIovs,
}}
}
diff --git a/pkg/tcpip/link/fdbased/endpoint_test.go b/pkg/tcpip/link/fdbased/endpoint_test.go
index 8aad338b6..eccd21579 100644
--- a/pkg/tcpip/link/fdbased/endpoint_test.go
+++ b/pkg/tcpip/link/fdbased/endpoint_test.go
@@ -12,6 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
+//go:build linux
// +build linux
package fdbased
diff --git a/pkg/tcpip/link/fdbased/endpoint_unsafe.go b/pkg/tcpip/link/fdbased/endpoint_unsafe.go
index df14eaad1..904393faa 100644
--- a/pkg/tcpip/link/fdbased/endpoint_unsafe.go
+++ b/pkg/tcpip/link/fdbased/endpoint_unsafe.go
@@ -12,6 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
+//go:build linux
// +build linux
package fdbased
diff --git a/pkg/tcpip/link/fdbased/mmap.go b/pkg/tcpip/link/fdbased/mmap.go
index 5d698a5e9..bfae34ab9 100644
--- a/pkg/tcpip/link/fdbased/mmap.go
+++ b/pkg/tcpip/link/fdbased/mmap.go
@@ -12,6 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
+//go:build (linux && amd64) || (linux && arm64)
// +build linux,amd64 linux,arm64
package fdbased
diff --git a/pkg/tcpip/link/fdbased/mmap_stub.go b/pkg/tcpip/link/fdbased/mmap_stub.go
index 67be52d67..9d8679502 100644
--- a/pkg/tcpip/link/fdbased/mmap_stub.go
+++ b/pkg/tcpip/link/fdbased/mmap_stub.go
@@ -12,6 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
+//go:build !linux || (!amd64 && !arm64)
// +build !linux !amd64,!arm64
package fdbased
diff --git a/pkg/tcpip/link/fdbased/mmap_unsafe.go b/pkg/tcpip/link/fdbased/mmap_unsafe.go
index 1293f68a2..58d5dfeef 100644
--- a/pkg/tcpip/link/fdbased/mmap_unsafe.go
+++ b/pkg/tcpip/link/fdbased/mmap_unsafe.go
@@ -12,6 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
+//go:build (linux && amd64) || (linux && arm64)
// +build linux,amd64 linux,arm64
package fdbased
diff --git a/pkg/tcpip/link/fdbased/packet_dispatchers.go b/pkg/tcpip/link/fdbased/packet_dispatchers.go
index 4b7ef3aac..ab2855a63 100644
--- a/pkg/tcpip/link/fdbased/packet_dispatchers.go
+++ b/pkg/tcpip/link/fdbased/packet_dispatchers.go
@@ -12,6 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
+//go:build linux
// +build linux
package fdbased
diff --git a/pkg/tcpip/link/rawfile/blockingpoll_noyield_unsafe.go b/pkg/tcpip/link/rawfile/blockingpoll_noyield_unsafe.go
index 2206fe0e6..c1438da21 100644
--- a/pkg/tcpip/link/rawfile/blockingpoll_noyield_unsafe.go
+++ b/pkg/tcpip/link/rawfile/blockingpoll_noyield_unsafe.go
@@ -12,6 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
+//go:build linux && !amd64 && !arm64
// +build linux,!amd64,!arm64
package rawfile
diff --git a/pkg/tcpip/link/rawfile/blockingpoll_yield_unsafe.go b/pkg/tcpip/link/rawfile/blockingpoll_yield_unsafe.go
index 5002245a1..da900c24b 100644
--- a/pkg/tcpip/link/rawfile/blockingpoll_yield_unsafe.go
+++ b/pkg/tcpip/link/rawfile/blockingpoll_yield_unsafe.go
@@ -12,6 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
+//go:build ((linux && amd64) || (linux && arm64)) && go1.12 && !go1.18
// +build linux,amd64 linux,arm64
// +build go1.12
// +build !go1.18
diff --git a/pkg/tcpip/link/rawfile/errors.go b/pkg/tcpip/link/rawfile/errors.go
index 9743e70ea..7e21a78d4 100644
--- a/pkg/tcpip/link/rawfile/errors.go
+++ b/pkg/tcpip/link/rawfile/errors.go
@@ -12,6 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
+//go:build linux
// +build linux
package rawfile
diff --git a/pkg/tcpip/link/rawfile/errors_test.go b/pkg/tcpip/link/rawfile/errors_test.go
index 8f4bd60da..1b88c309b 100644
--- a/pkg/tcpip/link/rawfile/errors_test.go
+++ b/pkg/tcpip/link/rawfile/errors_test.go
@@ -12,6 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
+//go:build linux
// +build linux
package rawfile
diff --git a/pkg/tcpip/link/rawfile/rawfile_unsafe.go b/pkg/tcpip/link/rawfile/rawfile_unsafe.go
index ba92aedbc..53448a641 100644
--- a/pkg/tcpip/link/rawfile/rawfile_unsafe.go
+++ b/pkg/tcpip/link/rawfile/rawfile_unsafe.go
@@ -12,6 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
+//go:build linux
// +build linux
// Package rawfile contains utilities for using the netstack with raw host
@@ -19,12 +20,66 @@
package rawfile
import (
+ "reflect"
"unsafe"
"golang.org/x/sys/unix"
"gvisor.dev/gvisor/pkg/tcpip"
)
+// SizeofIovec is the size of a unix.Iovec in bytes.
+const SizeofIovec = unsafe.Sizeof(unix.Iovec{})
+
+// MaxIovs is UIO_MAXIOV, the maximum number of iovecs that may be passed to a
+// host system call in a single array.
+const MaxIovs = 1024
+
+// IovecFromBytes returns a unix.Iovec representing bs.
+//
+// Preconditions: len(bs) > 0.
+func IovecFromBytes(bs []byte) unix.Iovec {
+ iov := unix.Iovec{
+ Base: &bs[0],
+ }
+ iov.SetLen(len(bs))
+ return iov
+}
+
+func bytesFromIovec(iov unix.Iovec) (bs []byte) {
+ sh := (*reflect.SliceHeader)(unsafe.Pointer(&bs))
+ sh.Data = uintptr(unsafe.Pointer(iov.Base))
+ sh.Len = int(iov.Len)
+ sh.Cap = int(iov.Len)
+ return
+}
+
+// AppendIovecFromBytes returns append(iovs, IovecFromBytes(bs)). If len(bs) ==
+// 0, AppendIovecFromBytes returns iovs without modification. If len(iovs) >=
+// max, AppendIovecFromBytes replaces the final iovec in iovs with one that
+// also includes the contents of bs. Note that this implies that
+// AppendIovecFromBytes is only usable when the returned iovec slice is used as
+// the source of a write.
+func AppendIovecFromBytes(iovs []unix.Iovec, bs []byte, max int) []unix.Iovec {
+ if len(bs) == 0 {
+ return iovs
+ }
+ if len(iovs) < max {
+ return append(iovs, IovecFromBytes(bs))
+ }
+ iovs[len(iovs)-1] = IovecFromBytes(append(bytesFromIovec(iovs[len(iovs)-1]), bs...))
+ return iovs
+}
+
+// MMsgHdr represents the mmsg_hdr structure required by recvmmsg() on linux.
+type MMsgHdr struct {
+ Msg unix.Msghdr
+ Len uint32
+ _ [4]byte
+}
+
+// SizeofMMsgHdr is the size of a MMsgHdr in bytes.
+const SizeofMMsgHdr = unsafe.Sizeof(MMsgHdr{})
+
// GetMTU determines the MTU of a network interface device.
func GetMTU(name string) (uint32, error) {
fd, err := unix.Socket(unix.AF_UNIX, unix.SOCK_DGRAM, 0)
@@ -137,13 +192,6 @@ func BlockingReadv(fd int, iovecs []unix.Iovec) (int, tcpip.Error) {
}
}
-// MMsgHdr represents the mmsg_hdr structure required by recvmmsg() on linux.
-type MMsgHdr struct {
- Msg unix.Msghdr
- Len uint32
- _ [4]byte
-}
-
// BlockingRecvMMsg reads from a file descriptor that is set up as non-blocking
// and stores the received messages in a slice of MMsgHdr structures. If no data
// is available, it will block in a poll() syscall until the file descriptor
diff --git a/pkg/tcpip/link/sharedmem/rx.go b/pkg/tcpip/link/sharedmem/rx.go
index 8e6f3e5e3..e882a128c 100644
--- a/pkg/tcpip/link/sharedmem/rx.go
+++ b/pkg/tcpip/link/sharedmem/rx.go
@@ -12,6 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
+//go:build linux
// +build linux
package sharedmem
diff --git a/pkg/tcpip/link/sharedmem/sharedmem.go b/pkg/tcpip/link/sharedmem/sharedmem.go
index df9a0b90a..30cf659b8 100644
--- a/pkg/tcpip/link/sharedmem/sharedmem.go
+++ b/pkg/tcpip/link/sharedmem/sharedmem.go
@@ -12,6 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
+//go:build linux
// +build linux
// Package sharedmem provides the implemention of data-link layer endpoints
diff --git a/pkg/tcpip/link/sharedmem/sharedmem_test.go b/pkg/tcpip/link/sharedmem/sharedmem_test.go
index 0f72d4e95..d6d953085 100644
--- a/pkg/tcpip/link/sharedmem/sharedmem_test.go
+++ b/pkg/tcpip/link/sharedmem/sharedmem_test.go
@@ -12,6 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
+//go:build linux
// +build linux
package sharedmem
diff --git a/pkg/tcpip/link/sniffer/pcap.go b/pkg/tcpip/link/sniffer/pcap.go
index c16c19647..3bb864ed2 100644
--- a/pkg/tcpip/link/sniffer/pcap.go
+++ b/pkg/tcpip/link/sniffer/pcap.go
@@ -39,8 +39,6 @@ type pcapHeader struct {
Network uint32
}
-const pcapPacketHeaderLen = 16
-
type pcapPacketHeader struct {
// Seconds is the timestamp seconds.
Seconds uint32
@@ -55,8 +53,7 @@ type pcapPacketHeader struct {
OriginalLength uint32
}
-func newPCAPPacketHeader(incLen, orgLen uint32) pcapPacketHeader {
- now := time.Now()
+func newPCAPPacketHeader(now time.Time, incLen, orgLen uint32) pcapPacketHeader {
return pcapPacketHeader{
Seconds: uint32(now.Unix()),
Microseconds: uint32(now.Nanosecond() / 1000),
diff --git a/pkg/tcpip/link/sniffer/sniffer.go b/pkg/tcpip/link/sniffer/sniffer.go
index 2d6a3a833..3df826f3c 100644
--- a/pkg/tcpip/link/sniffer/sniffer.go
+++ b/pkg/tcpip/link/sniffer/sniffer.go
@@ -87,11 +87,7 @@ func NewWithPrefix(lower stack.LinkEndpoint, logPrefix string) stack.LinkEndpoin
}
func zoneOffset() (int32, error) {
- loc, err := time.LoadLocation("Local")
- if err != nil {
- return 0, err
- }
- date := time.Date(0, 0, 0, 0, 0, 0, 0, loc)
+ date := time.Date(0, 0, 0, 0, 0, 0, 0, time.Local)
_, offset := date.Zone()
return int32(offset), nil
}
@@ -117,8 +113,9 @@ func writePCAPHeader(w io.Writer, maxLen uint32) error {
// NewWithWriter creates a new sniffer link-layer endpoint. It wraps around
// another endpoint and logs packets as they traverse the endpoint.
//
-// Packets are logged to writer in the pcap format. A sniffer created with this
-// function will not emit packets using the standard log package.
+// Each packet is written to writer in the pcap format in a single Write call
+// without synchronization. A sniffer created with this function will not emit
+// packets using the standard log package.
//
// snapLen is the maximum amount of a packet to be saved. Packets with a length
// less than or equal to snapLen will be saved in their entirety. Longer
@@ -159,27 +156,29 @@ func (e *endpoint) dumpPacket(dir direction, protocol tcpip.NetworkProtocolNumbe
if max := int(e.maxPCAPLen); length > max {
length = max
}
- if err := binary.Write(writer, binary.BigEndian, newPCAPPacketHeader(uint32(length), uint32(totalLength))); err != nil {
- panic(err)
- }
- write := func(b []byte) {
- if len(b) > length {
- b = b[:length]
+ packetHeader := newPCAPPacketHeader(time.Now(), uint32(length), uint32(totalLength))
+ packet := make([]byte, binary.Size(packetHeader)+length)
+ {
+ writer := tcpip.SliceWriter(packet)
+ if err := binary.Write(&writer, binary.BigEndian, packetHeader); err != nil {
+ panic(err)
}
- for len(b) != 0 {
+ for _, b := range pkt.Views() {
+ if length == 0 {
+ break
+ }
+ if len(b) > length {
+ b = b[:length]
+ }
n, err := writer.Write(b)
if err != nil {
panic(err)
}
- b = b[n:]
length -= n
}
}
- for _, v := range pkt.Views() {
- if length == 0 {
- break
- }
- write(v)
+ if _, err := writer.Write(packet); err != nil {
+ panic(err)
}
}
}
diff --git a/pkg/tcpip/link/tun/BUILD b/pkg/tcpip/link/tun/BUILD
index 7656cca6a..4758a99ad 100644
--- a/pkg/tcpip/link/tun/BUILD
+++ b/pkg/tcpip/link/tun/BUILD
@@ -26,6 +26,7 @@ go_library(
deps = [
"//pkg/abi/linux",
"//pkg/context",
+ "//pkg/errors/linuxerr",
"//pkg/log",
"//pkg/refs",
"//pkg/refsvfs2",
diff --git a/pkg/tcpip/link/tun/device.go b/pkg/tcpip/link/tun/device.go
index 36af2a029..d23210503 100644
--- a/pkg/tcpip/link/tun/device.go
+++ b/pkg/tcpip/link/tun/device.go
@@ -18,6 +18,7 @@ import (
"fmt"
"gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/errors/linuxerr"
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/tcpip"
@@ -88,12 +89,12 @@ func (d *Device) SetIff(s *stack.Stack, name string, flags Flags) error {
defer d.mu.Unlock()
if d.endpoint != nil {
- return syserror.EINVAL
+ return linuxerr.EINVAL
}
// Input validation.
if flags.TAP && flags.TUN || !flags.TAP && !flags.TUN {
- return syserror.EINVAL
+ return linuxerr.EINVAL
}
prefix := "tun"
@@ -108,7 +109,7 @@ func (d *Device) SetIff(s *stack.Stack, name string, flags Flags) error {
endpoint, err := attachOrCreateNIC(s, name, prefix, linkCaps)
if err != nil {
- return syserror.EINVAL
+ return linuxerr.EINVAL
}
d.endpoint = endpoint
@@ -125,7 +126,7 @@ func attachOrCreateNIC(s *stack.Stack, name, prefix string, linkCaps stack.LinkE
endpoint, ok := linkEP.(*tunEndpoint)
if !ok {
// Not a NIC created by tun device.
- return nil, syserror.EOPNOTSUPP
+ return nil, linuxerr.EOPNOTSUPP
}
if !endpoint.TryIncRef() {
// Race detected: NIC got deleted in between.
@@ -159,7 +160,7 @@ func attachOrCreateNIC(s *stack.Stack, name, prefix string, linkCaps stack.LinkE
// Race detected: A NIC has been created in between.
continue
default:
- return nil, syserror.EINVAL
+ return nil, linuxerr.EINVAL
}
}
}
@@ -170,7 +171,7 @@ func (d *Device) Write(data []byte) (int64, error) {
endpoint := d.endpoint
d.mu.RUnlock()
if endpoint == nil {
- return 0, syserror.EBADFD
+ return 0, linuxerr.EBADFD
}
if !endpoint.IsAttached() {
return 0, syserror.EIO
@@ -207,6 +208,15 @@ func (d *Device) Write(data []byte) (int64, error) {
protocol = pktInfoHdr.Protocol()
case ethHdr != nil:
protocol = ethHdr.Type()
+ case d.flags.TUN:
+ // TUN interface with IFF_NO_PI enabled, thus
+ // we need to determine protocol from version field
+ version := data[0] >> 4
+ if version == 4 {
+ protocol = header.IPv4ProtocolNumber
+ } else if version == 6 {
+ protocol = header.IPv6ProtocolNumber
+ }
}
// Try to determine remote link address, default zero.
@@ -233,7 +243,7 @@ func (d *Device) Read() ([]byte, error) {
endpoint := d.endpoint
d.mu.RUnlock()
if endpoint == nil {
- return nil, syserror.EBADFD
+ return nil, linuxerr.EBADFD
}
for {
@@ -264,13 +274,6 @@ func (d *Device) encodePkt(info *channel.PacketInfo) (buffer.View, bool) {
vv.AppendView(buffer.View(hdr))
}
- // If the packet does not already have link layer header, and the route
- // does not exist, we can't compute it. This is possibly a raw packet, tun
- // device doesn't support this at the moment.
- if info.Pkt.LinkHeader().View().IsEmpty() && len(info.Route.RemoteLinkAddress) == 0 {
- return nil, false
- }
-
// Ethernet header (TAP only).
if d.flags.TAP {
// Add ethernet header if not provided.
diff --git a/pkg/tcpip/link/tun/tun_unsafe.go b/pkg/tcpip/link/tun/tun_unsafe.go
index 0591fbd63..db4338e79 100644
--- a/pkg/tcpip/link/tun/tun_unsafe.go
+++ b/pkg/tcpip/link/tun/tun_unsafe.go
@@ -12,6 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
+//go:build linux
// +build linux
// Package tun contains methods to open TAP and TUN devices.
diff --git a/pkg/tcpip/network/internal/ip/generic_multicast_protocol_test.go b/pkg/tcpip/network/internal/ip/generic_multicast_protocol_test.go
index 0b51563cd..1261ad414 100644
--- a/pkg/tcpip/network/internal/ip/generic_multicast_protocol_test.go
+++ b/pkg/tcpip/network/internal/ip/generic_multicast_protocol_test.go
@@ -126,7 +126,7 @@ func (m *mockMulticastGroupProtocol) sendQueuedReports() {
// Precondition: m.mu must be read locked.
func (m *mockMulticastGroupProtocol) Enabled() bool {
if m.mu.TryLock() {
- m.mu.Unlock()
+ m.mu.Unlock() // +checklocksforce: TryLock.
m.t.Fatal("got write lock, expected to not take the lock; generic multicast protocol must take the read or write lock before calling Enabled")
}
@@ -138,11 +138,11 @@ func (m *mockMulticastGroupProtocol) Enabled() bool {
// Precondition: m.mu must be locked.
func (m *mockMulticastGroupProtocol) SendReport(groupAddress tcpip.Address) (bool, tcpip.Error) {
if m.mu.TryLock() {
- m.mu.Unlock()
+ m.mu.Unlock() // +checklocksforce: TryLock.
m.t.Fatalf("got write lock, expected to not take the lock; generic multicast protocol must take the write lock before sending report for %s", groupAddress)
}
if m.mu.TryRLock() {
- m.mu.RUnlock()
+ m.mu.RUnlock() // +checklocksforce: TryLock.
m.t.Fatalf("got read lock, expected to not take the lock; generic multicast protocol must take the write lock before sending report for %s", groupAddress)
}
@@ -155,11 +155,11 @@ func (m *mockMulticastGroupProtocol) SendReport(groupAddress tcpip.Address) (boo
// Precondition: m.mu must be locked.
func (m *mockMulticastGroupProtocol) SendLeave(groupAddress tcpip.Address) tcpip.Error {
if m.mu.TryLock() {
- m.mu.Unlock()
+ m.mu.Unlock() // +checklocksforce: TryLock.
m.t.Fatalf("got write lock, expected to not take the lock; generic multicast protocol must take the write lock before sending leave for %s", groupAddress)
}
if m.mu.TryRLock() {
- m.mu.RUnlock()
+ m.mu.RUnlock() // +checklocksforce: TryLock.
m.t.Fatalf("got read lock, expected to not take the lock; generic multicast protocol must take the write lock before sending leave for %s", groupAddress)
}
diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go
index f5693defe..b1aec5312 100644
--- a/pkg/tcpip/network/ipv6/ipv6.go
+++ b/pkg/tcpip/network/ipv6/ipv6.go
@@ -344,7 +344,10 @@ func (e *endpoint) onAddressAssignedLocked(addr tcpip.Address) {
func (e *endpoint) InvalidateDefaultRouter(rtr tcpip.Address) {
e.mu.Lock()
defer e.mu.Unlock()
- e.mu.ndp.invalidateDefaultRouter(rtr)
+
+ // We represent default routers with a default (off-link) route through the
+ // router.
+ e.mu.ndp.invalidateOffLinkRoute(offLinkRoute{dest: header.IPv6EmptySubnet, router: rtr})
}
// SetNDPConfigurations implements NDPEndpoint.
diff --git a/pkg/tcpip/network/ipv6/ndp.go b/pkg/tcpip/network/ipv6/ndp.go
index c44e4ac4e..8837d66d8 100644
--- a/pkg/tcpip/network/ipv6/ndp.go
+++ b/pkg/tcpip/network/ipv6/ndp.go
@@ -54,6 +54,11 @@ const (
// Advertisements, as a host.
defaultDiscoverDefaultRouters = true
+ // defaultDiscoverMoreSpecificRoutes is the default configuration for
+ // whether or not to discover more-specific routes from incoming Router
+ // Advertisements, as a host.
+ defaultDiscoverMoreSpecificRoutes = true
+
// defaultDiscoverOnLinkPrefixes is the default configuration for
// whether or not to discover on-link prefixes from incoming Router
// Advertisements' Prefix Information option, as a host.
@@ -78,13 +83,13 @@ const (
// we cannot have a negative delay.
minimumMaxRtrSolicitationDelay = 0
- // MaxDiscoveredDefaultRouters is the maximum number of discovered
- // default routers. The stack should stop discovering new routers after
- // discovering MaxDiscoveredDefaultRouters routers.
+ // MaxDiscoveredOffLinkRoutes is the maximum number of discovered off-link
+ // routes. The stack should stop discovering new off-link routes after
+ // this limit is reached.
//
// This value MUST be at minimum 2 as per RFC 4861 section 6.3.4, and
// SHOULD be more.
- MaxDiscoveredDefaultRouters = 10
+ MaxDiscoveredOffLinkRoutes = 10
// MaxDiscoveredOnLinkPrefixes is the maximum number of discovered
// on-link prefixes. The stack should stop discovering new on-link
@@ -352,12 +357,18 @@ type NDPConfigurations struct {
// DiscoverDefaultRouters determines whether or not default routers are
// discovered from Router Advertisements, as per RFC 4861 section 6. This
- // configuration is ignored if HandleRAs is false.
+ // configuration is ignored if RAs will not be processed (see HandleRAs).
DiscoverDefaultRouters bool
+ // DiscoverMoreSpecificRoutes determines whether or not more specific routes
+ // are discovered from Router Advertisements, as per RFC 4191. This
+ // configuration is ignored if RAs will not be processed (see HandleRAs).
+ DiscoverMoreSpecificRoutes bool
+
// DiscoverOnLinkPrefixes determines whether or not on-link prefixes are
// discovered from Router Advertisements' Prefix Information option, as per
- // RFC 4861 section 6. This configuration is ignored if HandleRAs is false.
+ // RFC 4861 section 6. This configuration is ignored if RAs will not be
+ // processed (see HandleRAs).
DiscoverOnLinkPrefixes bool
// AutoGenGlobalAddresses determines whether or not an IPv6 endpoint performs
@@ -408,6 +419,7 @@ func DefaultNDPConfigurations() NDPConfigurations {
MaxRtrSolicitationDelay: defaultMaxRtrSolicitationDelay,
HandleRAs: defaultHandleRAs,
DiscoverDefaultRouters: defaultDiscoverDefaultRouters,
+ DiscoverMoreSpecificRoutes: defaultDiscoverMoreSpecificRoutes,
DiscoverOnLinkPrefixes: defaultDiscoverOnLinkPrefixes,
AutoGenGlobalAddresses: defaultAutoGenGlobalAddresses,
AutoGenTempGlobalAddresses: defaultAutoGenTempGlobalAddresses,
@@ -448,6 +460,11 @@ type timer struct {
timer tcpip.Timer
}
+type offLinkRoute struct {
+ dest tcpip.Subnet
+ router tcpip.Address
+}
+
// ndpState is the per-Interface NDP state.
type ndpState struct {
// Do not allow overwriting this state.
@@ -462,8 +479,8 @@ type ndpState struct {
// The DAD timers to send the next NS message, or resolve the address.
dad ip.DAD
- // The default routers discovered through Router Advertisements.
- defaultRouters map[tcpip.Address]defaultRouterState
+ // The off-link routes discovered through Router Advertisements.
+ offLinkRoutes map[offLinkRoute]offLinkRouteState
// rtrSolicitTimer is the timer used to send the next router solicitation
// message.
@@ -491,12 +508,12 @@ type ndpState struct {
temporaryAddressDesyncFactor time.Duration
}
-// defaultRouterState holds data associated with a default router discovered by
+// offLinkRouteState holds data associated with an off-link route discovered by
// a Router Advertisement (RA).
-type defaultRouterState struct {
+type offLinkRouteState struct {
prf header.NDPRoutePreference
- // Job to invalidate the default router.
+ // Job to invalidate the route.
//
// Must not be nil.
invalidationJob *tcpip.Job
@@ -727,38 +744,9 @@ func (ndp *ndpState) handleRA(ip tcpip.Address, ra header.NDPRouterAdvert) {
prf = header.MediumRoutePreference
}
- rtr, ok := ndp.defaultRouters[ip]
- rl := ra.RouterLifetime()
- switch {
- case !ok && rl != 0:
- // This is a new default router we are discovering.
- //
- // Only remember it if we currently know about less than
- // MaxDiscoveredDefaultRouters routers.
- if len(ndp.defaultRouters) < MaxDiscoveredDefaultRouters {
- ndp.rememberDefaultRouter(ip, rl, prf)
- }
-
- case ok && rl != 0:
- // This is an already discovered default router. Update
- // the invalidation job.
- rtr.invalidationJob.Cancel()
- rtr.invalidationJob.Schedule(rl)
-
- if prf != rtr.prf {
- rtr.prf = prf
-
- // Inform the integrator about router preference updates.
- ndp.ep.protocol.options.NDPDisp.OnOffLinkRouteUpdated(ndp.ep.nic.ID(), header.IPv6EmptySubnet, ip, prf)
- }
-
- ndp.defaultRouters[ip] = rtr
-
- case ok && rl == 0:
- // We know about the router but it is no longer to be
- // used as a default router so invalidate it.
- ndp.invalidateDefaultRouter(ip)
- }
+ // We represent default routers with a default (off-link) route through the
+ // router.
+ ndp.handleOffLinkRouteDiscovery(offLinkRoute{dest: header.IPv6EmptySubnet, router: ip}, ra.RouterLifetime(), prf)
}
// TODO(b/141556115): Do (RetransTimer, ReachableTime)) Parameter
@@ -810,58 +798,107 @@ func (ndp *ndpState) handleRA(ip tcpip.Address, ra header.NDPRouterAdvert) {
if opt.AutonomousAddressConfigurationFlag() {
ndp.handleAutonomousPrefixInformation(opt)
}
+
+ case header.NDPRouteInformation:
+ if !ndp.configs.DiscoverMoreSpecificRoutes {
+ continue
+ }
+
+ dest, err := opt.Prefix()
+ if err != nil {
+ panic(fmt.Sprintf("%T.Prefix(): %s", opt, err))
+ }
+
+ prf := opt.RoutePreference()
+ if prf == header.ReservedRoutePreference {
+ // As per RFC 4191 section 2.3,
+ //
+ // Prf (Route Preference)
+ // 2-bit signed integer. The Route Preference indicates
+ // whether to prefer the router associated with this prefix
+ // over others, when multiple identical prefixes (for
+ // different routers) have been received. If the Reserved
+ // (10) value is received, the Route Information Option MUST
+ // be ignored.
+ continue
+ }
+
+ ndp.handleOffLinkRouteDiscovery(offLinkRoute{dest: dest, router: ip}, opt.RouteLifetime(), prf)
}
// TODO(b/141556115): Do (MTU) Parameter Discovery.
}
}
-// invalidateDefaultRouter invalidates a discovered default router.
+// invalidateOffLinkRoute invalidates a discovered off-link route.
//
// The IPv6 endpoint that ndp belongs to MUST be locked.
-func (ndp *ndpState) invalidateDefaultRouter(ip tcpip.Address) {
- rtr, ok := ndp.defaultRouters[ip]
-
- // Is the router still discovered?
+func (ndp *ndpState) invalidateOffLinkRoute(route offLinkRoute) {
+ state, ok := ndp.offLinkRoutes[route]
if !ok {
- // ...Nope, do nothing further.
return
}
- rtr.invalidationJob.Cancel()
- delete(ndp.defaultRouters, ip)
+ state.invalidationJob.Cancel()
+ delete(ndp.offLinkRoutes, route)
- // Let the integrator know a discovered default router is invalidated.
+ // Let the integrator know a discovered off-link route is invalidated.
if ndpDisp := ndp.ep.protocol.options.NDPDisp; ndpDisp != nil {
- ndpDisp.OnOffLinkRouteInvalidated(ndp.ep.nic.ID(), header.IPv6EmptySubnet, ip)
+ ndpDisp.OnOffLinkRouteInvalidated(ndp.ep.nic.ID(), route.dest, route.router)
}
}
-// rememberDefaultRouter remembers a newly discovered default router with IPv6
-// link-local address ip with lifetime rl.
-//
-// The router identified by ip MUST NOT already be known by the IPv6 endpoint.
+// handleOffLinkRouteDiscovery handles the discovery of an off-link route.
//
-// The IPv6 endpoint that ndp belongs to MUST be locked.
-func (ndp *ndpState) rememberDefaultRouter(ip tcpip.Address, rl time.Duration, prf header.NDPRoutePreference) {
+// Precondition: ndp.ep.mu must be locked.
+func (ndp *ndpState) handleOffLinkRouteDiscovery(route offLinkRoute, lifetime time.Duration, prf header.NDPRoutePreference) {
ndpDisp := ndp.ep.protocol.options.NDPDisp
if ndpDisp == nil {
return
}
- // Inform the integrator when we discovered a default router.
- ndpDisp.OnOffLinkRouteUpdated(ndp.ep.nic.ID(), header.IPv6EmptySubnet, ip, prf)
+ state, ok := ndp.offLinkRoutes[route]
+ switch {
+ case !ok && lifetime != 0:
+ // This is a new route we are discovering.
+ //
+ // Only remember it if we currently know about less than
+ // MaxDiscoveredOffLinkRoutes routers.
+ if len(ndp.offLinkRoutes) < MaxDiscoveredOffLinkRoutes {
+ // Inform the integrator when we discovered an off-link route.
+ ndpDisp.OnOffLinkRouteUpdated(ndp.ep.nic.ID(), route.dest, route.router, prf)
+
+ state := offLinkRouteState{
+ prf: prf,
+ invalidationJob: ndp.ep.protocol.stack.NewJob(&ndp.ep.mu, func() {
+ ndp.invalidateOffLinkRoute(route)
+ }),
+ }
- state := defaultRouterState{
- prf: prf,
- invalidationJob: ndp.ep.protocol.stack.NewJob(&ndp.ep.mu, func() {
- ndp.invalidateDefaultRouter(ip)
- }),
- }
+ state.invalidationJob.Schedule(lifetime)
+
+ ndp.offLinkRoutes[route] = state
+ }
- state.invalidationJob.Schedule(rl)
+ case ok && lifetime != 0:
+ // This is an already discovered off-link route. Update the lifetime.
+ state.invalidationJob.Cancel()
+ state.invalidationJob.Schedule(lifetime)
- ndp.defaultRouters[ip] = state
+ if prf != state.prf {
+ state.prf = prf
+
+ // Inform the integrator about route preference updates.
+ ndpDisp.OnOffLinkRouteUpdated(ndp.ep.nic.ID(), route.dest, route.router, prf)
+ }
+
+ ndp.offLinkRoutes[route] = state
+
+ case ok && lifetime == 0:
+ // The already discovered off-link route is no longer considered valid so we
+ // invalidate it immediately.
+ ndp.invalidateOffLinkRoute(route)
+ }
}
// rememberOnLinkPrefix remembers a newly discovered on-link prefix with IPv6
@@ -1677,12 +1714,12 @@ func (ndp *ndpState) cleanupState() {
panic(fmt.Sprintf("ndp: still have discovered on-link prefixes after cleaning up; found = %d", got))
}
- for router := range ndp.defaultRouters {
- ndp.invalidateDefaultRouter(router)
+ for route := range ndp.offLinkRoutes {
+ ndp.invalidateOffLinkRoute(route)
}
- if got := len(ndp.defaultRouters); got != 0 {
- panic(fmt.Sprintf("ndp: still have discovered default routers after cleaning up; found = %d", got))
+ if got := len(ndp.offLinkRoutes); got != 0 {
+ panic(fmt.Sprintf("ndp: still have discovered off-link routes after cleaning up; found = %d", got))
}
ndp.dhcpv6Configuration = 0
@@ -1845,14 +1882,14 @@ func (ndp *ndpState) stopSolicitingRouters() {
}
func (ndp *ndpState) init(ep *endpoint, dadOptions ip.DADOptions) {
- if ndp.defaultRouters != nil {
+ if ndp.offLinkRoutes != nil {
panic("attempted to initialize NDP state twice")
}
ndp.ep = ep
ndp.configs = ep.protocol.options.NDPConfigs
ndp.dad.Init(&ndp.ep.mu, ep.protocol.options.DADConfigs, dadOptions)
- ndp.defaultRouters = make(map[tcpip.Address]defaultRouterState)
+ ndp.offLinkRoutes = make(map[offLinkRoute]offLinkRouteState)
ndp.onLinkPrefixes = make(map[tcpip.Subnet]onLinkPrefixState)
ndp.slaacPrefixes = make(map[tcpip.Subnet]slaacPrefixState)
diff --git a/pkg/tcpip/network/ipv6/ndp_test.go b/pkg/tcpip/network/ipv6/ndp_test.go
index 130438f3b..f0186c64e 100644
--- a/pkg/tcpip/network/ipv6/ndp_test.go
+++ b/pkg/tcpip/network/ipv6/ndp_test.go
@@ -93,7 +93,7 @@ func TestStackNDPEndpointInvalidateDefaultRouter(t *testing.T) {
ipv6EP := ep.(*endpoint)
ipv6EP.mu.Lock()
- ipv6EP.mu.ndp.rememberDefaultRouter(lladdr1, time.Hour, header.MediumRoutePreference)
+ ipv6EP.mu.ndp.handleOffLinkRouteDiscovery(offLinkRoute{dest: header.IPv6EmptySubnet, router: lladdr1}, time.Hour, header.MediumRoutePreference)
ipv6EP.mu.Unlock()
if ndpDisp.addr != lladdr1 {
diff --git a/pkg/tcpip/ports/BUILD b/pkg/tcpip/ports/BUILD
index b7f6d52ae..fe98a52af 100644
--- a/pkg/tcpip/ports/BUILD
+++ b/pkg/tcpip/ports/BUILD
@@ -12,6 +12,7 @@ go_library(
deps = [
"//pkg/sync",
"//pkg/tcpip",
+ "//pkg/tcpip/header",
],
)
diff --git a/pkg/tcpip/ports/ports.go b/pkg/tcpip/ports/ports.go
index 854d6a6ba..fb8ef1ee2 100644
--- a/pkg/tcpip/ports/ports.go
+++ b/pkg/tcpip/ports/ports.go
@@ -23,6 +23,7 @@ import (
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
)
const (
@@ -122,7 +123,7 @@ type deviceToDest map[tcpip.NICID]destToCounter
// If either of the port reuse flags is enabled on any of the nodes, all nodes
// sharing a port must share at least one reuse flag. This matches Linux's
// behavior.
-func (dd deviceToDest) isAvailable(res Reservation) bool {
+func (dd deviceToDest) isAvailable(res Reservation, portSpecified bool) bool {
flagBits := res.Flags.Bits()
if res.BindToDevice == 0 {
intersection := FlagMask
@@ -138,6 +139,9 @@ func (dd deviceToDest) isAvailable(res Reservation) bool {
return false
}
}
+ if !portSpecified && res.Transport == header.TCPProtocolNumber {
+ return false
+ }
return true
}
@@ -146,16 +150,26 @@ func (dd deviceToDest) isAvailable(res Reservation) bool {
if dests, ok := dd[0]; ok {
var count int
intersection, count = dests.intersectionFlags(res)
- if count > 0 && intersection&flagBits == 0 {
- return false
+ if count > 0 {
+ if intersection&flagBits == 0 {
+ return false
+ }
+ if !portSpecified && res.Transport == header.TCPProtocolNumber {
+ return false
+ }
}
}
if dests, ok := dd[res.BindToDevice]; ok {
flags, count := dests.intersectionFlags(res)
intersection &= flags
- if count > 0 && intersection&flagBits == 0 {
- return false
+ if count > 0 {
+ if intersection&flagBits == 0 {
+ return false
+ }
+ if !portSpecified && res.Transport == header.TCPProtocolNumber {
+ return false
+ }
}
}
@@ -168,12 +182,12 @@ type addrToDevice map[tcpip.Address]deviceToDest
// isAvailable checks whether an IP address is available to bind to. If the
// address is the "any" address, check all other addresses. Otherwise, just
// check against the "any" address and the provided address.
-func (ad addrToDevice) isAvailable(res Reservation) bool {
+func (ad addrToDevice) isAvailable(res Reservation, portSpecified bool) bool {
if res.Addr == anyIPAddress {
// If binding to the "any" address then check that there are no
// conflicts with all addresses.
for _, devices := range ad {
- if !devices.isAvailable(res) {
+ if !devices.isAvailable(res, portSpecified) {
return false
}
}
@@ -182,14 +196,14 @@ func (ad addrToDevice) isAvailable(res Reservation) bool {
// Check that there is no conflict with the "any" address.
if devices, ok := ad[anyIPAddress]; ok {
- if !devices.isAvailable(res) {
+ if !devices.isAvailable(res, portSpecified) {
return false
}
}
// Check that this is no conflict with the provided address.
if devices, ok := ad[res.Addr]; ok {
- if !devices.isAvailable(res) {
+ if !devices.isAvailable(res, portSpecified) {
return false
}
}
@@ -310,7 +324,7 @@ func (pm *PortManager) ReservePort(rng *rand.Rand, res Reservation, testPort Por
// If a port is specified, just try to reserve it for all network
// protocols.
if res.Port != 0 {
- if !pm.reserveSpecificPortLocked(res) {
+ if !pm.reserveSpecificPortLocked(res, true /* portSpecified */) {
return 0, &tcpip.ErrPortInUse{}
}
if testPort != nil {
@@ -330,7 +344,7 @@ func (pm *PortManager) ReservePort(rng *rand.Rand, res Reservation, testPort Por
// A port wasn't specified, so try to find one.
return pm.PickEphemeralPort(rng, func(p uint16) (bool, tcpip.Error) {
res.Port = p
- if !pm.reserveSpecificPortLocked(res) {
+ if !pm.reserveSpecificPortLocked(res, false /* portSpecified */) {
return false, nil
}
if testPort != nil {
@@ -350,12 +364,12 @@ func (pm *PortManager) ReservePort(rng *rand.Rand, res Reservation, testPort Por
// reserveSpecificPortLocked tries to reserve the given port on all given
// protocols.
-func (pm *PortManager) reserveSpecificPortLocked(res Reservation) bool {
+func (pm *PortManager) reserveSpecificPortLocked(res Reservation, portSpecified bool) bool {
// Make sure the port is available.
for _, network := range res.Networks {
desc := portDescriptor{network, res.Transport, res.Port}
if addrs, ok := pm.allocatedPorts[desc]; ok {
- if !addrs.isAvailable(res) {
+ if !addrs.isAvailable(res, portSpecified) {
return false
}
}
diff --git a/pkg/tcpip/sample/tun_tcp_connect/main.go b/pkg/tcpip/sample/tun_tcp_connect/main.go
index b9a24ff56..009cab643 100644
--- a/pkg/tcpip/sample/tun_tcp_connect/main.go
+++ b/pkg/tcpip/sample/tun_tcp_connect/main.go
@@ -12,6 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
+//go:build linux
// +build linux
// This sample creates a stack with TCP and IPv4 protocols on top of a TUN
diff --git a/pkg/tcpip/sample/tun_tcp_echo/main.go b/pkg/tcpip/sample/tun_tcp_echo/main.go
index ef1bfc186..c10b19aa0 100644
--- a/pkg/tcpip/sample/tun_tcp_echo/main.go
+++ b/pkg/tcpip/sample/tun_tcp_echo/main.go
@@ -12,6 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
+//go:build linux
// +build linux
// This sample creates a stack with TCP and IPv4 protocols on top of a TUN
diff --git a/pkg/tcpip/socketops.go b/pkg/tcpip/socketops.go
index 0ea85f9ed..5642c86f8 100644
--- a/pkg/tcpip/socketops.go
+++ b/pkg/tcpip/socketops.go
@@ -15,17 +15,11 @@
package tcpip
import (
- "math"
"sync/atomic"
- "gvisor.dev/gvisor/pkg/atomicbitops"
"gvisor.dev/gvisor/pkg/sync"
)
-// PacketOverheadFactor is used to multiply the value provided by the user on a
-// SetSockOpt for setting the send/receive buffer sizes sockets.
-const PacketOverheadFactor = 2
-
// SocketOptionsHandler holds methods that help define endpoint specific
// behavior for socket level socket options. These must be implemented by
// endpoints to get notified when socket level options are set.
@@ -60,7 +54,7 @@ type SocketOptionsHandler interface {
// buffer size. It also returns the newly set value.
OnSetSendBufferSize(v int64) (newSz int64)
- // OnSetReceiveBufferSize is invoked to set the SO_RCVBUFSIZE.
+ // OnSetReceiveBufferSize is invoked by SO_RCVBUF and SO_RCVBUFFORCE.
OnSetReceiveBufferSize(v, oldSz int64) (newSz int64)
}
@@ -213,16 +207,24 @@ type SocketOptions struct {
// will not change.
getSendBufferLimits GetSendBufferLimits `state:"manual"`
+ // sendBufSizeMu protects sendBufferSize and calls to
+ // handler.OnSetSendBufferSize.
+ sendBufSizeMu sync.Mutex `state:"nosave"`
+
// sendBufferSize determines the send buffer size for this socket.
- sendBufferSize atomicbitops.AlignedAtomicInt64
+ sendBufferSize int64
// getReceiveBufferLimits provides the handler to get the min, default and
// max size for receive buffer. It is initialized at the creation time and
// will not change.
getReceiveBufferLimits GetReceiveBufferLimits `state:"manual"`
+ // receiveBufSizeMu protects receiveBufferSize and calls to
+ // handler.OnSetReceiveBufferSize.
+ receiveBufSizeMu sync.Mutex `state:"nosave"`
+
// receiveBufferSize determines the receive buffer size for this socket.
- receiveBufferSize atomicbitops.AlignedAtomicInt64
+ receiveBufferSize int64
// mu protects the access to the below fields.
mu sync.Mutex `state:"nosave"`
@@ -612,81 +614,52 @@ func (so *SocketOptions) SetBindToDevice(bindToDevice int32) Error {
return nil
}
+// SendBufferLimits returns the [min, max) range of allowable send buffer
+// sizes.
+func (so *SocketOptions) SendBufferLimits() (min, max int64) {
+ limits := so.getSendBufferLimits(so.stackHandler)
+ return int64(limits.Min), int64(limits.Max)
+}
+
// GetSendBufferSize gets value for SO_SNDBUF option.
func (so *SocketOptions) GetSendBufferSize() int64 {
- return so.sendBufferSize.Load()
+ so.sendBufSizeMu.Lock()
+ defer so.sendBufSizeMu.Unlock()
+ return so.sendBufferSize
}
// SetSendBufferSize sets value for SO_SNDBUF option. notify indicates if the
// stack handler should be invoked to set the send buffer size.
func (so *SocketOptions) SetSendBufferSize(sendBufferSize int64, notify bool) {
- v := sendBufferSize
-
- if !notify {
- so.sendBufferSize.Store(v)
- return
- }
-
- // Make sure the send buffer size is within the min and max
- // allowed.
- ss := so.getSendBufferLimits(so.stackHandler)
- min := int64(ss.Min)
- max := int64(ss.Max)
- // Validate the send buffer size with min and max values.
- // Multiply it by factor of 2.
- if v > max {
- v = max
- }
-
- if v < math.MaxInt32/PacketOverheadFactor {
- v *= PacketOverheadFactor
- if v < min {
- v = min
- }
- } else {
- v = math.MaxInt32
+ so.sendBufSizeMu.Lock()
+ defer so.sendBufSizeMu.Unlock()
+ if notify {
+ sendBufferSize = so.handler.OnSetSendBufferSize(sendBufferSize)
}
+ so.sendBufferSize = sendBufferSize
+}
- // Notify endpoint about change in buffer size.
- newSz := so.handler.OnSetSendBufferSize(v)
- so.sendBufferSize.Store(newSz)
+// ReceiveBufferLimits returns the [min, max) range of allowable receive buffer
+// sizes.
+func (so *SocketOptions) ReceiveBufferLimits() (min, max int64) {
+ limits := so.getReceiveBufferLimits(so.stackHandler)
+ return int64(limits.Min), int64(limits.Max)
}
// GetReceiveBufferSize gets value for SO_RCVBUF option.
func (so *SocketOptions) GetReceiveBufferSize() int64 {
- return so.receiveBufferSize.Load()
+ so.receiveBufSizeMu.Lock()
+ defer so.receiveBufSizeMu.Unlock()
+ return so.receiveBufferSize
}
-// SetReceiveBufferSize sets value for SO_RCVBUF option.
+// SetReceiveBufferSize sets the value of the SO_RCVBUF option, optionally
+// notifying the owning endpoint.
func (so *SocketOptions) SetReceiveBufferSize(receiveBufferSize int64, notify bool) {
- if !notify {
- so.receiveBufferSize.Store(receiveBufferSize)
- return
- }
-
- // Make sure the send buffer size is within the min and max
- // allowed.
- v := receiveBufferSize
- ss := so.getReceiveBufferLimits(so.stackHandler)
- min := int64(ss.Min)
- max := int64(ss.Max)
- // Validate the send buffer size with min and max values.
- if v > max {
- v = max
- }
-
- // Multiply it by factor of 2.
- if v < math.MaxInt32/PacketOverheadFactor {
- v *= PacketOverheadFactor
- if v < min {
- v = min
- }
- } else {
- v = math.MaxInt32
+ so.receiveBufSizeMu.Lock()
+ defer so.receiveBufSizeMu.Unlock()
+ if notify {
+ receiveBufferSize = so.handler.OnSetReceiveBufferSize(receiveBufferSize, so.receiveBufferSize)
}
-
- oldSz := so.receiveBufferSize.Load()
- // Notify endpoint about change in buffer size.
- newSz := so.handler.OnSetReceiveBufferSize(v, oldSz)
- so.receiveBufferSize.Store(newSz)
+ so.receiveBufferSize = receiveBufferSize
}
diff --git a/pkg/tcpip/stack/addressable_endpoint_state.go b/pkg/tcpip/stack/addressable_endpoint_state.go
index ce9cebdaa..ae0bb4ace 100644
--- a/pkg/tcpip/stack/addressable_endpoint_state.go
+++ b/pkg/tcpip/stack/addressable_endpoint_state.go
@@ -249,7 +249,7 @@ func (a *AddressableEndpointState) addAndAcquireAddressLocked(addr tcpip.Address
// or we are adding a new temporary or permanent address.
//
// The address MUST be write locked at this point.
- defer addrState.mu.Unlock()
+ defer addrState.mu.Unlock() // +checklocksforce
if permanent {
if addrState.mu.kind.IsPermanent() {
diff --git a/pkg/tcpip/stack/conntrack.go b/pkg/tcpip/stack/conntrack.go
index 18e0d4374..068dab7ce 100644
--- a/pkg/tcpip/stack/conntrack.go
+++ b/pkg/tcpip/stack/conntrack.go
@@ -363,7 +363,7 @@ func (ct *ConnTrack) insertConn(conn *conn) {
// Unlocking can happen in any order.
ct.buckets[tupleBucket].mu.Unlock()
if tupleBucket != replyBucket {
- ct.buckets[replyBucket].mu.Unlock()
+ ct.buckets[replyBucket].mu.Unlock() // +checklocksforce
}
}
@@ -405,16 +405,23 @@ func (ct *ConnTrack) handlePacket(pkt *PacketBuffer, hook Hook, r *Route) bool {
// validated if checksum offloading is off. It may require IP defrag if the
// packets are fragmented.
+ var newAddr tcpip.Address
+ var newPort uint16
+
+ updateSRCFields := false
+
switch hook {
case Prerouting, Output:
if conn.manip == manipDestination {
switch dir {
case dirOriginal:
- tcpHeader.SetDestinationPort(conn.reply.srcPort)
- netHeader.SetDestinationAddress(conn.reply.srcAddr)
+ newPort = conn.reply.srcPort
+ newAddr = conn.reply.srcAddr
case dirReply:
- tcpHeader.SetSourcePort(conn.original.dstPort)
- netHeader.SetSourceAddress(conn.original.dstAddr)
+ newPort = conn.original.dstPort
+ newAddr = conn.original.dstAddr
+
+ updateSRCFields = true
}
pkt.NatDone = true
}
@@ -422,11 +429,13 @@ func (ct *ConnTrack) handlePacket(pkt *PacketBuffer, hook Hook, r *Route) bool {
if conn.manip == manipSource {
switch dir {
case dirOriginal:
- tcpHeader.SetSourcePort(conn.reply.dstPort)
- netHeader.SetSourceAddress(conn.reply.dstAddr)
+ newPort = conn.reply.dstPort
+ newAddr = conn.reply.dstAddr
+
+ updateSRCFields = true
case dirReply:
- tcpHeader.SetDestinationPort(conn.original.srcPort)
- netHeader.SetDestinationAddress(conn.original.srcAddr)
+ newPort = conn.original.srcPort
+ newAddr = conn.original.srcAddr
}
pkt.NatDone = true
}
@@ -437,29 +446,31 @@ func (ct *ConnTrack) handlePacket(pkt *PacketBuffer, hook Hook, r *Route) bool {
return false
}
+ fullChecksum := false
+ updatePseudoHeader := false
switch hook {
case Prerouting, Input:
case Output, Postrouting:
// Calculate the TCP checksum and set it.
- tcpHeader.SetChecksum(0)
- length := uint16(len(tcpHeader) + pkt.Data().Size())
- xsum := header.PseudoHeaderChecksum(header.TCPProtocolNumber, netHeader.SourceAddress(), netHeader.DestinationAddress(), length)
if pkt.GSOOptions.Type != GSONone && pkt.GSOOptions.NeedsCsum {
- tcpHeader.SetChecksum(xsum)
+ updatePseudoHeader = true
} else if r.RequiresTXTransportChecksum() {
- xsum = header.ChecksumCombine(xsum, pkt.Data().AsRange().Checksum())
- tcpHeader.SetChecksum(^tcpHeader.CalculateChecksum(xsum))
+ fullChecksum = true
+ updatePseudoHeader = true
}
default:
panic(fmt.Sprintf("unrecognized hook = %s", hook))
}
- // After modification, IPv4 packets need a valid checksum.
- if pkt.NetworkProtocolNumber == header.IPv4ProtocolNumber {
- netHeader := header.IPv4(pkt.NetworkHeader().View())
- netHeader.SetChecksum(0)
- netHeader.SetChecksum(^netHeader.CalculateChecksum())
- }
+ rewritePacket(
+ netHeader,
+ tcpHeader,
+ updateSRCFields,
+ fullChecksum,
+ updatePseudoHeader,
+ newPort,
+ newAddr,
+ )
// Update the state of tcb.
conn.mu.Lock()
@@ -615,7 +626,7 @@ func (ct *ConnTrack) reapTupleLocked(tuple *tuple, bucket int, now time.Time) bo
// Don't re-unlock if both tuples are in the same bucket.
if differentBuckets {
- ct.buckets[replyBucket].mu.Unlock()
+ ct.buckets[replyBucket].mu.Unlock() // +checklocksforce
}
return true
diff --git a/pkg/tcpip/stack/iptables_targets.go b/pkg/tcpip/stack/iptables_targets.go
index 91e266de8..96cc899bb 100644
--- a/pkg/tcpip/stack/iptables_targets.go
+++ b/pkg/tcpip/stack/iptables_targets.go
@@ -133,29 +133,23 @@ func (rt *RedirectTarget) Action(pkt *PacketBuffer, ct *ConnTrack, hook Hook, r
switch protocol := pkt.TransportProtocolNumber; protocol {
case header.UDPProtocolNumber:
udpHeader := header.UDP(pkt.TransportHeader().View())
- udpHeader.SetDestinationPort(rt.Port)
- // Calculate UDP checksum and set it.
if hook == Output {
- udpHeader.SetChecksum(0)
- netHeader := pkt.Network()
- netHeader.SetDestinationAddress(address)
-
// Only calculate the checksum if offloading isn't supported.
- if r.RequiresTXTransportChecksum() {
- length := uint16(pkt.Size()) - uint16(len(pkt.NetworkHeader().View()))
- xsum := header.PseudoHeaderChecksum(protocol, netHeader.SourceAddress(), netHeader.DestinationAddress(), length)
- xsum = header.ChecksumCombine(xsum, pkt.Data().AsRange().Checksum())
- udpHeader.SetChecksum(^udpHeader.CalculateChecksum(xsum))
- }
+ requiresChecksum := r.RequiresTXTransportChecksum()
+ rewritePacket(
+ pkt.Network(),
+ udpHeader,
+ false, /* updateSRCFields */
+ requiresChecksum,
+ requiresChecksum,
+ rt.Port,
+ address,
+ )
+ } else {
+ udpHeader.SetDestinationPort(rt.Port)
}
- // After modification, IPv4 packets need a valid checksum.
- if pkt.NetworkProtocolNumber == header.IPv4ProtocolNumber {
- netHeader := header.IPv4(pkt.NetworkHeader().View())
- netHeader.SetChecksum(0)
- netHeader.SetChecksum(^netHeader.CalculateChecksum())
- }
pkt.NatDone = true
case header.TCPProtocolNumber:
if ct == nil {
@@ -214,26 +208,18 @@ func (st *SNATTarget) Action(pkt *PacketBuffer, ct *ConnTrack, hook Hook, r *Rou
switch protocol := pkt.TransportProtocolNumber; protocol {
case header.UDPProtocolNumber:
- udpHeader := header.UDP(pkt.TransportHeader().View())
- udpHeader.SetChecksum(0)
- udpHeader.SetSourcePort(st.Port)
- netHeader := pkt.Network()
- netHeader.SetSourceAddress(st.Addr)
-
// Only calculate the checksum if offloading isn't supported.
- if r.RequiresTXTransportChecksum() {
- length := uint16(pkt.Size()) - uint16(len(pkt.NetworkHeader().View()))
- xsum := header.PseudoHeaderChecksum(protocol, netHeader.SourceAddress(), netHeader.DestinationAddress(), length)
- xsum = header.ChecksumCombine(xsum, pkt.Data().AsRange().Checksum())
- udpHeader.SetChecksum(^udpHeader.CalculateChecksum(xsum))
- }
+ requiresChecksum := r.RequiresTXTransportChecksum()
+ rewritePacket(
+ pkt.Network(),
+ header.UDP(pkt.TransportHeader().View()),
+ true, /* updateSRCFields */
+ requiresChecksum,
+ requiresChecksum,
+ st.Port,
+ st.Addr,
+ )
- // After modification, IPv4 packets need a valid checksum.
- if pkt.NetworkProtocolNumber == header.IPv4ProtocolNumber {
- netHeader := header.IPv4(pkt.NetworkHeader().View())
- netHeader.SetChecksum(0)
- netHeader.SetChecksum(^netHeader.CalculateChecksum())
- }
pkt.NatDone = true
case header.TCPProtocolNumber:
if ct == nil {
@@ -252,3 +238,42 @@ func (st *SNATTarget) Action(pkt *PacketBuffer, ct *ConnTrack, hook Hook, r *Rou
return RuleAccept, 0
}
+
+func rewritePacket(n header.Network, t header.ChecksummableTransport, updateSRCFields, fullChecksum, updatePseudoHeader bool, newPort uint16, newAddr tcpip.Address) {
+ if updateSRCFields {
+ if fullChecksum {
+ t.SetSourcePortWithChecksumUpdate(newPort)
+ } else {
+ t.SetSourcePort(newPort)
+ }
+ } else {
+ if fullChecksum {
+ t.SetDestinationPortWithChecksumUpdate(newPort)
+ } else {
+ t.SetDestinationPort(newPort)
+ }
+ }
+
+ if updatePseudoHeader {
+ var oldAddr tcpip.Address
+ if updateSRCFields {
+ oldAddr = n.SourceAddress()
+ } else {
+ oldAddr = n.DestinationAddress()
+ }
+
+ t.UpdateChecksumPseudoHeaderAddress(oldAddr, newAddr, fullChecksum)
+ }
+
+ if checksummableNetHeader, ok := n.(header.ChecksummableNetwork); ok {
+ if updateSRCFields {
+ checksummableNetHeader.SetSourceAddressWithChecksumUpdate(newAddr)
+ } else {
+ checksummableNetHeader.SetDestinationAddressWithChecksumUpdate(newAddr)
+ }
+ } else if updateSRCFields {
+ n.SetSourceAddress(newAddr)
+ } else {
+ n.SetDestinationAddress(newAddr)
+ }
+}
diff --git a/pkg/tcpip/stack/ndp_test.go b/pkg/tcpip/stack/ndp_test.go
index bd512d312..4d5431da1 100644
--- a/pkg/tcpip/stack/ndp_test.go
+++ b/pkg/tcpip/stack/ndp_test.go
@@ -1152,6 +1152,39 @@ func raBufWithPI(ip tcpip.Address, rl uint16, prefix tcpip.AddressWithPrefix, on
})
}
+// raBufWithRIO returns a valid NDP Router Advertisement with a single Route
+// Information option.
+//
+// All fields in the RA will be zero except the RIO option.
+func raBufWithRIO(t *testing.T, ip tcpip.Address, prefix tcpip.AddressWithPrefix, lifetimeSeconds uint32, prf header.NDPRoutePreference) *stack.PacketBuffer {
+ // buf will hold the route information option after the Type and Length
+ // fields.
+ //
+ // 2.3. Route Information Option
+ //
+ // 0 1 2 3
+ // 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
+ // +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
+ // | Type | Length | Prefix Length |Resvd|Prf|Resvd|
+ // +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
+ // | Route Lifetime |
+ // +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
+ // | Prefix (Variable Length) |
+ // . .
+ // . .
+ // +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
+ var buf [22]byte
+ buf[0] = uint8(prefix.PrefixLen)
+ buf[1] = byte(prf) << 3
+ binary.BigEndian.PutUint32(buf[2:], lifetimeSeconds)
+ if n := copy(buf[6:], prefix.Address); n != len(prefix.Address) {
+ t.Fatalf("got copy(...) = %d, want = %d", n, len(prefix.Address))
+ }
+ return raBufWithOpts(ip, 0 /* router lifetime */, header.NDPOptionsSerializer{
+ header.NDPRouteInformation(buf[:]),
+ })
+}
+
func TestDynamicConfigurationsDisabled(t *testing.T) {
const (
nicID = 1
@@ -1308,8 +1341,8 @@ func boolToUint64(v bool) uint64 {
return 0
}
-func checkOffLinkRouteEvent(e ndpOffLinkRouteEvent, nicID tcpip.NICID, router tcpip.Address, prf header.NDPRoutePreference, updated bool) string {
- return cmp.Diff(ndpOffLinkRouteEvent{nicID: nicID, subnet: header.IPv6EmptySubnet, router: router, prf: prf, updated: updated}, e, cmp.AllowUnexported(e))
+func checkOffLinkRouteEvent(e ndpOffLinkRouteEvent, nicID tcpip.NICID, subnet tcpip.Subnet, router tcpip.Address, prf header.NDPRoutePreference, updated bool) string {
+ return cmp.Diff(ndpOffLinkRouteEvent{nicID: nicID, subnet: subnet, router: router, prf: prf, updated: updated}, e, cmp.AllowUnexported(e))
}
func testWithRAs(t *testing.T, f func(*testing.T, ipv6.HandleRAsConfiguration, bool)) {
@@ -1342,126 +1375,171 @@ func testWithRAs(t *testing.T, f func(*testing.T, ipv6.HandleRAsConfiguration, b
}
}
-func TestRouterDiscovery(t *testing.T) {
+func TestOffLinkRouteDiscovery(t *testing.T) {
const nicID = 1
- testWithRAs(t, func(t *testing.T, handleRAs ipv6.HandleRAsConfiguration, forwarding bool) {
- ndpDisp := ndpDispatcher{
- offLinkRouteC: make(chan ndpOffLinkRouteEvent, 1),
- }
- e := channel.New(0, 1280, linkAddr1)
- clock := faketime.NewManualClock()
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
- NDPConfigs: ipv6.NDPConfigurations{
- HandleRAs: handleRAs,
- DiscoverDefaultRouters: true,
- },
- NDPDisp: &ndpDisp,
- })},
- Clock: clock,
- })
+ moreSpecificPrefix := tcpip.AddressWithPrefix{Address: testutil.MustParse6("a00::"), PrefixLen: 16}
+ tests := []struct {
+ name string
- expectOffLinkRouteEvent := func(addr tcpip.Address, prf header.NDPRoutePreference, updated bool) {
- t.Helper()
+ discoverDefaultRouters bool
+ discoverMoreSpecificRoutes bool
- select {
- case e := <-ndpDisp.offLinkRouteC:
- if diff := checkOffLinkRouteEvent(e, nicID, addr, prf, updated); diff != "" {
- t.Errorf("off-link route event mismatch (-want +got):\n%s", diff)
+ dest tcpip.Subnet
+ ra func(*testing.T, tcpip.Address, uint16, header.NDPRoutePreference) *stack.PacketBuffer
+ }{
+ {
+ name: "Default router discovery",
+ discoverDefaultRouters: true,
+ discoverMoreSpecificRoutes: false,
+ dest: header.IPv6EmptySubnet,
+ ra: func(_ *testing.T, router tcpip.Address, lifetimeSeconds uint16, prf header.NDPRoutePreference) *stack.PacketBuffer {
+ return raBufWithPrf(router, lifetimeSeconds, prf)
+ },
+ },
+ {
+ name: "More-specific route discovery",
+ discoverDefaultRouters: false,
+ discoverMoreSpecificRoutes: true,
+ dest: moreSpecificPrefix.Subnet(),
+ ra: func(t *testing.T, router tcpip.Address, lifetimeSeconds uint16, prf header.NDPRoutePreference) *stack.PacketBuffer {
+ return raBufWithRIO(t, router, moreSpecificPrefix, uint32(lifetimeSeconds), prf)
+ },
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ testWithRAs(t, func(t *testing.T, handleRAs ipv6.HandleRAsConfiguration, forwarding bool) {
+ ndpDisp := ndpDispatcher{
+ offLinkRouteC: make(chan ndpOffLinkRouteEvent, 1),
}
- default:
- t.Fatal("expected router discovery event")
- }
- }
+ e := channel.New(0, 1280, linkAddr1)
+ clock := faketime.NewManualClock()
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
+ NDPConfigs: ipv6.NDPConfigurations{
+ HandleRAs: handleRAs,
+ DiscoverDefaultRouters: test.discoverDefaultRouters,
+ DiscoverMoreSpecificRoutes: test.discoverMoreSpecificRoutes,
+ },
+ NDPDisp: &ndpDisp,
+ })},
+ Clock: clock,
+ })
- expectAsyncOffLinkRouteInvalidationEvent := func(addr tcpip.Address, timeout time.Duration) {
- t.Helper()
+ expectOffLinkRouteEvent := func(addr tcpip.Address, prf header.NDPRoutePreference, updated bool) {
+ t.Helper()
- clock.Advance(timeout)
- select {
- case e := <-ndpDisp.offLinkRouteC:
- var prf header.NDPRoutePreference
- if diff := checkOffLinkRouteEvent(e, nicID, addr, prf, false); diff != "" {
- t.Errorf("off-link route event mismatch (-want +got):\n%s", diff)
+ select {
+ case e := <-ndpDisp.offLinkRouteC:
+ if diff := checkOffLinkRouteEvent(e, nicID, test.dest, addr, prf, updated); diff != "" {
+ t.Errorf("off-link route event mismatch (-want +got):\n%s", diff)
+ }
+ default:
+ t.Fatal("expected router discovery event")
+ }
}
- default:
- t.Fatal("timed out waiting for router discovery event")
- }
- }
- if err := s.SetForwardingDefaultAndAllNICs(ipv6.ProtocolNumber, forwarding); err != nil {
- t.Fatalf("SetForwardingDefaultAndAllNICs(%d, %t): %s", ipv6.ProtocolNumber, forwarding, err)
- }
+ expectAsyncOffLinkRouteInvalidationEvent := func(addr tcpip.Address, timeout time.Duration) {
+ t.Helper()
- if err := s.CreateNIC(nicID, e); err != nil {
- t.Fatalf("CreateNIC(%d, _): %s", nicID, err)
- }
+ clock.Advance(timeout)
+ select {
+ case e := <-ndpDisp.offLinkRouteC:
+ var prf header.NDPRoutePreference
+ if diff := checkOffLinkRouteEvent(e, nicID, test.dest, addr, prf, false); diff != "" {
+ t.Errorf("off-link route event mismatch (-want +got):\n%s", diff)
+ }
+ default:
+ t.Fatal("timed out waiting for router discovery event")
+ }
+ }
- // Rx an RA from lladdr2 with zero lifetime. It should not be
- // remembered.
- e.InjectInbound(header.IPv6ProtocolNumber, raBufSimple(llAddr2, 0))
- select {
- case <-ndpDisp.offLinkRouteC:
- t.Fatal("unexpectedly updated an off-link route with 0 lifetime")
- default:
- }
+ if err := s.SetForwardingDefaultAndAllNICs(ipv6.ProtocolNumber, forwarding); err != nil {
+ t.Fatalf("SetForwardingDefaultAndAllNICs(%d, %t): %s", ipv6.ProtocolNumber, forwarding, err)
+ }
- // Rx an RA from lladdr2 with a huge lifetime and reserved preference value
- // (which should be interpreted as the default (medium) preference value).
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPrf(llAddr2, 1000, header.ReservedRoutePreference))
- expectOffLinkRouteEvent(llAddr2, header.MediumRoutePreference, true)
-
- // Rx an RA from another router (lladdr3) with non-zero lifetime and
- // non-default preference value.
- const l3LifetimeSeconds = 6
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPrf(llAddr3, l3LifetimeSeconds, header.HighRoutePreference))
- expectOffLinkRouteEvent(llAddr3, header.HighRoutePreference, true)
-
- // Rx an RA from lladdr2 with lesser lifetime and default (medium)
- // preference value.
- const l2LifetimeSeconds = 2
- e.InjectInbound(header.IPv6ProtocolNumber, raBufSimple(llAddr2, l2LifetimeSeconds))
- select {
- case <-ndpDisp.offLinkRouteC:
- t.Fatal("should not receive a off-link route event when updating lifetimes for known routers")
- default:
- }
+ if err := s.CreateNIC(nicID, e); err != nil {
+ t.Fatalf("CreateNIC(%d, _): %s", nicID, err)
+ }
- // Rx an RA from lladdr2 with a different preference.
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPrf(llAddr2, l2LifetimeSeconds, header.LowRoutePreference))
- expectOffLinkRouteEvent(llAddr2, header.LowRoutePreference, true)
-
- // Wait for lladdr2's router invalidation job to execute. The lifetime
- // of the router should have been updated to the most recent (smaller)
- // lifetime.
- //
- // Wait for the normal lifetime plus an extra bit for the
- // router to get invalidated. If we don't get an invalidation
- // event after this time, then something is wrong.
- expectAsyncOffLinkRouteInvalidationEvent(llAddr2, l2LifetimeSeconds*time.Second)
-
- // Rx an RA from lladdr2 with huge lifetime.
- e.InjectInbound(header.IPv6ProtocolNumber, raBufSimple(llAddr2, 1000))
- expectOffLinkRouteEvent(llAddr2, header.MediumRoutePreference, true)
-
- // Rx an RA from lladdr2 with zero lifetime. It should be invalidated.
- e.InjectInbound(header.IPv6ProtocolNumber, raBufSimple(llAddr2, 0))
- expectOffLinkRouteEvent(llAddr2, header.MediumRoutePreference, false)
-
- // Wait for lladdr3's router invalidation job to execute. The lifetime
- // of the router should have been updated to the most recent (smaller)
- // lifetime.
- //
- // Wait for the normal lifetime plus an extra bit for the
- // router to get invalidated. If we don't get an invalidation
- // event after this time, then something is wrong.
- expectAsyncOffLinkRouteInvalidationEvent(llAddr3, l3LifetimeSeconds*time.Second)
- })
+ // Rx an RA from lladdr2 with zero lifetime. It should not be
+ // remembered.
+ e.InjectInbound(header.IPv6ProtocolNumber, test.ra(t, llAddr2, 0, header.MediumRoutePreference))
+ select {
+ case <-ndpDisp.offLinkRouteC:
+ t.Fatal("unexpectedly updated an off-link route with 0 lifetime")
+ default:
+ }
+
+ // Discover an off-link route through llAddr2.
+ e.InjectInbound(header.IPv6ProtocolNumber, test.ra(t, llAddr2, 1000, header.ReservedRoutePreference))
+ if test.discoverMoreSpecificRoutes {
+ // The reserved value is considered invalid with more-specific route
+ // discovery so we inject the same packet but with the default
+ // (medium) preference value.
+ select {
+ case <-ndpDisp.offLinkRouteC:
+ t.Fatal("unexpectedly updated an off-link route with a reserved preference value")
+ default:
+ }
+ e.InjectInbound(header.IPv6ProtocolNumber, test.ra(t, llAddr2, 1000, header.MediumRoutePreference))
+ }
+ expectOffLinkRouteEvent(llAddr2, header.MediumRoutePreference, true)
+
+ // Rx an RA from another router (lladdr3) with non-zero lifetime and
+ // non-default preference value.
+ const l3LifetimeSeconds = 6
+ e.InjectInbound(header.IPv6ProtocolNumber, test.ra(t, llAddr3, l3LifetimeSeconds, header.HighRoutePreference))
+ expectOffLinkRouteEvent(llAddr3, header.HighRoutePreference, true)
+
+ // Rx an RA from lladdr2 with lesser lifetime and default (medium)
+ // preference value.
+ const l2LifetimeSeconds = 2
+ e.InjectInbound(header.IPv6ProtocolNumber, test.ra(t, llAddr2, l2LifetimeSeconds, header.MediumRoutePreference))
+ select {
+ case <-ndpDisp.offLinkRouteC:
+ t.Fatal("should not receive a off-link route event when updating lifetimes for known routers")
+ default:
+ }
+
+ // Rx an RA from lladdr2 with a different preference.
+ e.InjectInbound(header.IPv6ProtocolNumber, test.ra(t, llAddr2, l2LifetimeSeconds, header.LowRoutePreference))
+ expectOffLinkRouteEvent(llAddr2, header.LowRoutePreference, true)
+
+ // Wait for lladdr2's router invalidation job to execute. The lifetime
+ // of the router should have been updated to the most recent (smaller)
+ // lifetime.
+ //
+ // Wait for the normal lifetime plus an extra bit for the
+ // router to get invalidated. If we don't get an invalidation
+ // event after this time, then something is wrong.
+ expectAsyncOffLinkRouteInvalidationEvent(llAddr2, l2LifetimeSeconds*time.Second)
+
+ // Rx an RA from lladdr2 with huge lifetime.
+ e.InjectInbound(header.IPv6ProtocolNumber, test.ra(t, llAddr2, 1000, header.MediumRoutePreference))
+ expectOffLinkRouteEvent(llAddr2, header.MediumRoutePreference, true)
+
+ // Rx an RA from lladdr2 with zero lifetime. It should be invalidated.
+ e.InjectInbound(header.IPv6ProtocolNumber, test.ra(t, llAddr2, 0, header.MediumRoutePreference))
+ expectOffLinkRouteEvent(llAddr2, header.MediumRoutePreference, false)
+
+ // Wait for lladdr3's router invalidation job to execute. The lifetime
+ // of the router should have been updated to the most recent (smaller)
+ // lifetime.
+ //
+ // Wait for the normal lifetime plus an extra bit for the
+ // router to get invalidated. If we don't get an invalidation
+ // event after this time, then something is wrong.
+ expectAsyncOffLinkRouteInvalidationEvent(llAddr3, l3LifetimeSeconds*time.Second)
+ })
+ })
+ }
}
// TestRouterDiscoveryMaxRouters tests that only
-// ipv6.MaxDiscoveredDefaultRouters discovered routers are remembered.
+// ipv6.MaxDiscoveredOffLinkRoutes discovered routers are remembered.
func TestRouterDiscoveryMaxRouters(t *testing.T) {
const nicID = 1
@@ -1484,17 +1562,17 @@ func TestRouterDiscoveryMaxRouters(t *testing.T) {
}
// Receive an RA from 2 more than the max number of discovered routers.
- for i := 1; i <= ipv6.MaxDiscoveredDefaultRouters+2; i++ {
+ for i := 1; i <= ipv6.MaxDiscoveredOffLinkRoutes+2; i++ {
linkAddr := []byte{2, 2, 3, 4, 5, 0}
linkAddr[5] = byte(i)
llAddr := header.LinkLocalAddr(tcpip.LinkAddress(linkAddr))
e.InjectInbound(header.IPv6ProtocolNumber, raBufSimple(llAddr, 5))
- if i <= ipv6.MaxDiscoveredDefaultRouters {
+ if i <= ipv6.MaxDiscoveredOffLinkRoutes {
select {
case e := <-ndpDisp.offLinkRouteC:
- if diff := checkOffLinkRouteEvent(e, nicID, llAddr, header.MediumRoutePreference, true); diff != "" {
+ if diff := checkOffLinkRouteEvent(e, nicID, header.IPv6EmptySubnet, llAddr, header.MediumRoutePreference, true); diff != "" {
t.Errorf("off-link route event mismatch (-want +got):\n%s", diff)
}
default:
@@ -4583,7 +4661,7 @@ func TestNoCleanupNDPStateWhenForwardingEnabled(t *testing.T) {
)
select {
case e := <-ndpDisp.offLinkRouteC:
- if diff := checkOffLinkRouteEvent(e, nicID, llAddr3, header.MediumRoutePreference, true /* discovered */); diff != "" {
+ if diff := checkOffLinkRouteEvent(e, nicID, header.IPv6EmptySubnet, llAddr3, header.MediumRoutePreference, true /* discovered */); diff != "" {
t.Errorf("off-link route event mismatch (-want +got):\n%s", diff)
}
default:
@@ -5278,8 +5356,9 @@ func TestRouterSolicitation(t *testing.T) {
RandSource: &randSource,
})
- if err := s.CreateNIC(nicID, &e); err != nil {
- t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
+ opts := stack.NICOptions{Disabled: true}
+ if err := s.CreateNICWithOptions(nicID, &e, opts); err != nil {
+ t.Fatalf("CreateNICWithOptions(%d, _, %#v) = %s", nicID, opts, err)
}
if addr := test.nicAddr; addr != "" {
@@ -5288,6 +5367,10 @@ func TestRouterSolicitation(t *testing.T) {
}
}
+ if err := s.EnableNIC(nicID); err != nil {
+ t.Fatalf("EnableNIC(%d): %s", nicID, err)
+ }
+
// Make sure each RS is sent at the right time.
remaining := test.maxRtrSolicit
if remaining != 0 {
diff --git a/pkg/tcpip/stack/tcp.go b/pkg/tcpip/stack/tcp.go
index e90c1a770..90a8ba6cf 100644
--- a/pkg/tcpip/stack/tcp.go
+++ b/pkg/tcpip/stack/tcp.go
@@ -380,9 +380,6 @@ type TCPSndBufState struct {
// SndClosed indicates that the endpoint has been closed for sends.
SndClosed bool
- // SndBufInQueue is the number of bytes in the send queue.
- SndBufInQueue seqnum.Size
-
// PacketTooBigCount is used to notify the main protocol routine how
// many times a "packet too big" control packet is received.
PacketTooBigCount int
diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go
index cb316d27a..f9a15efb2 100644
--- a/pkg/tcpip/transport/icmp/endpoint.go
+++ b/pkg/tcpip/transport/icmp/endpoint.go
@@ -213,6 +213,7 @@ func (e *endpoint) Read(dst io.Writer, opts tcpip.ReadOptions) (tcpip.ReadResult
// reacquire the mutex in exclusive mode.
//
// Returns true for retry if preparation should be retried.
+// +checklocks:e.mu
func (e *endpoint) prepareForWrite(to *tcpip.FullAddress) (retry bool, err tcpip.Error) {
switch e.state {
case stateInitial:
@@ -229,10 +230,8 @@ func (e *endpoint) prepareForWrite(to *tcpip.FullAddress) (retry bool, err tcpip
}
e.mu.RUnlock()
- defer e.mu.RLock()
-
e.mu.Lock()
- defer e.mu.Unlock()
+ defer e.mu.DowngradeLock()
// The state changed when we released the shared locked and re-acquired
// it in exclusive mode. Try again.
diff --git a/pkg/tcpip/transport/raw/endpoint.go b/pkg/tcpip/transport/raw/endpoint.go
index b6687911a..ab5da987a 100644
--- a/pkg/tcpip/transport/raw/endpoint.go
+++ b/pkg/tcpip/transport/raw/endpoint.go
@@ -132,7 +132,7 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProt
// headers included. Because they're write-only, We don't need to
// register with the stack.
if !associated {
- e.ops.SetReceiveBufferSize(0, false)
+ e.ops.SetReceiveBufferSize(0, false /* notify */)
e.waiterQueue = nil
return e, nil
}
diff --git a/pkg/tcpip/transport/tcp/accept.go b/pkg/tcpip/transport/tcp/accept.go
index d807b13b7..aa413ad05 100644
--- a/pkg/tcpip/transport/tcp/accept.go
+++ b/pkg/tcpip/transport/tcp/accept.go
@@ -330,7 +330,9 @@ func (l *listenContext) performHandshake(s *segment, opts *header.TCPSynOptions,
}
ep := h.ep
- if err := h.complete(); err != nil {
+ // N.B. the endpoint is generated above by startHandshake, and will be
+ // returned locked. This first call is forced.
+ if err := h.complete(); err != nil { // +checklocksforce
ep.stack.Stats().TCP.FailedConnectionAttempts.Increment()
ep.stats.FailedConnectionAttempts.Increment()
l.cleanupFailedHandshake(h)
@@ -364,6 +366,7 @@ func (l *listenContext) closeAllPendingEndpoints() {
}
// Precondition: h.ep.mu must be held.
+// +checklocks:h.ep.mu
func (l *listenContext) cleanupFailedHandshake(h *handshake) {
e := h.ep
e.mu.Unlock()
@@ -504,7 +507,9 @@ func (e *endpoint) handleSynSegment(ctx *listenContext, s *segment, opts *header
}
go func() {
- if err := h.complete(); err != nil {
+ // Note that startHandshake returns a locked endpoint. The
+ // force call here just makes it so.
+ if err := h.complete(); err != nil { // +checklocksforce
e.stack.Stats().TCP.FailedConnectionAttempts.Increment()
e.stats.FailedConnectionAttempts.Increment()
ctx.cleanupFailedHandshake(h)
diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go
index f86450298..93ed161f9 100644
--- a/pkg/tcpip/transport/tcp/connect.go
+++ b/pkg/tcpip/transport/tcp/connect.go
@@ -511,6 +511,7 @@ func (h *handshake) start() {
}
// complete completes the TCP 3-way handshake initiated by h.start().
+// +checklocks:h.ep.mu
func (h *handshake) complete() tcpip.Error {
// Set up the wakers.
var s sleep.Sleeper
@@ -909,30 +910,13 @@ func (e *endpoint) sendRaw(data buffer.VectorisedView, flags header.TCPFlags, se
return err
}
-func (e *endpoint) handleWrite() {
- e.sndQueueInfo.sndQueueMu.Lock()
- next := e.drainSendQueueLocked()
- e.sndQueueInfo.sndQueueMu.Unlock()
-
- e.sendData(next)
-}
-
-// Move packets from send queue to send list.
-//
-// Precondition: e.sndBufMu must be locked.
-func (e *endpoint) drainSendQueueLocked() *segment {
- first := e.sndQueueInfo.sndQueue.Front()
- if first != nil {
- e.snd.writeList.PushBackList(&e.sndQueueInfo.sndQueue)
- e.sndQueueInfo.SndBufInQueue = 0
- }
- return first
-}
-
// Precondition: e.mu must be locked.
func (e *endpoint) sendData(next *segment) {
// Initialize the next segment to write if it's currently nil.
if e.snd.writeNext == nil {
+ if next == nil {
+ return
+ }
e.snd.writeNext = next
}
@@ -940,17 +924,6 @@ func (e *endpoint) sendData(next *segment) {
e.snd.sendData()
}
-func (e *endpoint) handleClose() {
- if !e.EndpointState().connected() {
- return
- }
- // Drain the send queue.
- e.handleWrite()
-
- // Mark send side as closed.
- e.snd.Closed = true
-}
-
// resetConnectionLocked puts the endpoint in an error state with the given
// error code and sends a RST if and only if the error is not ErrConnectionReset
// indicating that the connection is being reset due to receiving a RST. This
@@ -1311,42 +1284,45 @@ func (e *endpoint) disableKeepaliveTimer() {
e.keepalive.Unlock()
}
-// protocolMainLoop is the main loop of the TCP protocol. It runs in its own
-// goroutine and is responsible for sending segments and handling received
-// segments.
-func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{}) tcpip.Error {
- e.mu.Lock()
- var closeTimer tcpip.Timer
- var closeWaker sleep.Waker
-
- epilogue := func() {
- // e.mu is expected to be hold upon entering this section.
- if e.snd != nil {
- e.snd.resendTimer.cleanup()
- e.snd.probeTimer.cleanup()
- e.snd.reorderTimer.cleanup()
- }
+// protocolMainLoopDone is called at the end of protocolMainLoop.
+// +checklocksrelease:e.mu
+func (e *endpoint) protocolMainLoopDone(closeTimer tcpip.Timer, closeWaker *sleep.Waker) {
+ if e.snd != nil {
+ e.snd.resendTimer.cleanup()
+ e.snd.probeTimer.cleanup()
+ e.snd.reorderTimer.cleanup()
+ }
- if closeTimer != nil {
- closeTimer.Stop()
- }
+ if closeTimer != nil {
+ closeTimer.Stop()
+ }
- e.completeWorkerLocked()
+ e.completeWorkerLocked()
- if e.drainDone != nil {
- close(e.drainDone)
- }
+ if e.drainDone != nil {
+ close(e.drainDone)
+ }
- e.mu.Unlock()
+ e.mu.Unlock()
- e.drainClosingSegmentQueue()
+ e.drainClosingSegmentQueue()
- // When the protocol loop exits we should wake up our waiters.
- e.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.ReadableEvents | waiter.WritableEvents)
- }
+ // When the protocol loop exits we should wake up our waiters.
+ e.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.ReadableEvents | waiter.WritableEvents)
+}
+
+// protocolMainLoop is the main loop of the TCP protocol. It runs in its own
+// goroutine and is responsible for sending segments and handling received
+// segments.
+func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{}) tcpip.Error {
+ var (
+ closeTimer tcpip.Timer
+ closeWaker sleep.Waker
+ )
+ e.mu.Lock()
if handshake {
- if err := e.h.complete(); err != nil {
+ if err := e.h.complete(); err != nil { // +checklocksforce
e.lastErrorMu.Lock()
e.lastError = err
e.lastErrorMu.Unlock()
@@ -1355,8 +1331,7 @@ func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{
e.hardError = err
e.workerCleanup = true
- // Lock released below.
- epilogue()
+ e.protocolMainLoopDone(closeTimer, &closeWaker)
return err
}
}
@@ -1402,14 +1377,7 @@ func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{
{
w: &e.sndQueueInfo.sndWaker,
f: func() tcpip.Error {
- e.handleWrite()
- return nil
- },
- },
- {
- w: &e.sndQueueInfo.sndCloseWaker,
- f: func() tcpip.Error {
- e.handleClose()
+ e.sendData(nil /* next */)
return nil
},
},
@@ -1507,7 +1475,7 @@ func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{
// Only block the worker if the endpoint
// is not in closed state or error state.
close(e.drainDone)
- e.mu.Unlock()
+ e.mu.Unlock() // +checklocksforce
<-e.undrain
e.mu.Lock()
}
@@ -1568,8 +1536,6 @@ func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{
if err != nil {
e.resetConnectionLocked(err)
}
- // Lock released below.
- epilogue()
}
loop:
@@ -1593,6 +1559,7 @@ loop:
// just want to terminate the loop and cleanup the
// endpoint.
cleanupOnError(nil)
+ e.protocolMainLoopDone(closeTimer, &closeWaker)
return nil
case StateTimeWait:
fallthrough
@@ -1601,6 +1568,7 @@ loop:
default:
if err := funcs[v].f(); err != nil {
cleanupOnError(err)
+ e.protocolMainLoopDone(closeTimer, &closeWaker)
return nil
}
}
@@ -1624,13 +1592,13 @@ loop:
// Handle any StateError transition from StateTimeWait.
if e.EndpointState() == StateError {
cleanupOnError(nil)
+ e.protocolMainLoopDone(closeTimer, &closeWaker)
return nil
}
e.transitionToStateCloseLocked()
- // Lock released below.
- epilogue()
+ e.protocolMainLoopDone(closeTimer, &closeWaker)
// A new SYN was received during TIME_WAIT and we need to abort
// the timewait and redirect the segment to the listener queue
@@ -1700,6 +1668,7 @@ func (e *endpoint) handleTimeWaitSegments() (extendTimeWait bool, reuseTW func()
// should be executed after releasing the endpoint registrations. This is
// done in cases where a new SYN is received during TIME_WAIT that carries
// a sequence number larger than one see on the connection.
+// +checklocks:e.mu
func (e *endpoint) doTimeWait() (twReuse func()) {
// Trigger a 2 * MSL time wait state. During this period
// we will drop all incoming segments.
diff --git a/pkg/tcpip/transport/tcp/dispatcher.go b/pkg/tcpip/transport/tcp/dispatcher.go
index dff7cb89c..7d110516b 100644
--- a/pkg/tcpip/transport/tcp/dispatcher.go
+++ b/pkg/tcpip/transport/tcp/dispatcher.go
@@ -127,7 +127,7 @@ func (p *processor) start(wg *sync.WaitGroup) {
case !ep.segmentQueue.empty():
p.epQ.enqueue(ep)
}
- ep.mu.Unlock()
+ ep.mu.Unlock() // +checklocksforce
} else {
ep.newSegmentWaker.Assert()
}
diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go
index a27e2110b..ebc88d6c3 100644
--- a/pkg/tcpip/transport/tcp/endpoint.go
+++ b/pkg/tcpip/transport/tcp/endpoint.go
@@ -293,16 +293,9 @@ type sndQueueInfo struct {
sndQueueMu sync.Mutex `state:"nosave"`
stack.TCPSndBufState
- // sndQueue holds segments that are ready to be sent.
- sndQueue segmentList `state:"wait"`
-
- // sndWaker is used to signal the protocol goroutine when segments are
- // added to the `sndQueue`.
+ // sndWaker is used to signal the protocol goroutine when there may be
+ // segments that need to be sent.
sndWaker sleep.Waker `state:"manual"`
-
- // sndCloseWaker is used to notify the protocol goroutine when the send
- // side is closed.
- sndCloseWaker sleep.Waker `state:"manual"`
}
// rcvQueueInfo contains the endpoint's rcvQueue and associated metadata.
@@ -671,6 +664,7 @@ func calculateAdvertisedMSS(userMSS uint16, r *stack.Route) uint16 {
// The assumption behind spinning here being that background packet processing
// should not be holding the lock for long and spinning reduces latency as we
// avoid an expensive sleep/wakeup of of the syscall goroutine).
+// +checklocksacquire:e.mu
func (e *endpoint) LockUser() {
for {
// Try first if the sock is locked then check if it's owned
@@ -690,7 +684,7 @@ func (e *endpoint) LockUser() {
continue
}
atomic.StoreUint32(&e.ownedByUser, 1)
- return
+ return // +checklocksforce
}
}
@@ -707,7 +701,7 @@ func (e *endpoint) LockUser() {
// protocol goroutine altogether.
//
// Precondition: e.LockUser() must have been called before calling e.UnlockUser()
-// +checklocks:e.mu
+// +checklocksrelease:e.mu
func (e *endpoint) UnlockUser() {
// Lock segment queue before checking so that we avoid a race where
// segments can be queued between the time we check if queue is empty
@@ -743,12 +737,13 @@ func (e *endpoint) UnlockUser() {
}
// StopWork halts packet processing. Only to be used in tests.
+// +checklocksacquire:e.mu
func (e *endpoint) StopWork() {
e.mu.Lock()
}
// ResumeWork resumes packet processing. Only to be used in tests.
-// +checklocks:e.mu
+// +checklocksrelease:e.mu
func (e *endpoint) ResumeWork() {
e.mu.Unlock()
}
@@ -759,7 +754,7 @@ func (e *endpoint) ResumeWork() {
//
// Precondition: e.mu must be held to call this method.
func (e *endpoint) setEndpointState(state EndpointState) {
- oldstate := EndpointState(atomic.LoadUint32(&e.state))
+ oldstate := EndpointState(atomic.SwapUint32(&e.state, uint32(state)))
switch state {
case StateEstablished:
e.stack.Stats().TCP.CurrentEstablished.Increment()
@@ -776,7 +771,6 @@ func (e *endpoint) setEndpointState(state EndpointState) {
e.stack.Stats().TCP.CurrentEstablished.Decrement()
}
}
- atomic.StoreUint32(&e.state, uint32(state))
}
// EndpointState returns the current state of the endpoint.
@@ -1487,87 +1481,101 @@ func (e *endpoint) isEndpointWritableLocked() (int, tcpip.Error) {
return avail, nil
}
-// Write writes data to the endpoint's peer.
-func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcpip.Error) {
- // Linux completely ignores any address passed to sendto(2) for TCP sockets
- // (without the MSG_FASTOPEN flag). Corking is unimplemented, so opts.More
- // and opts.EndOfRecord are also ignored.
+// readFromPayloader reads a slice from the Payloader.
+// +checklocks:e.mu
+// +checklocks:e.sndQueueInfo.sndQueueMu
+func (e *endpoint) readFromPayloader(p tcpip.Payloader, opts tcpip.WriteOptions, avail int) ([]byte, tcpip.Error) {
+ // We can release locks while copying data.
+ //
+ // This is not possible if atomic is set, because we can't allow the
+ // available buffer space to be consumed by some other caller while we
+ // are copying data in.
+ if !opts.Atomic {
+ e.sndQueueInfo.sndQueueMu.Unlock()
+ defer e.sndQueueInfo.sndQueueMu.Lock()
- e.LockUser()
- defer e.UnlockUser()
+ e.UnlockUser()
+ defer e.LockUser()
+ }
- nextSeg, n, err := func() (*segment, int, tcpip.Error) {
- e.sndQueueInfo.sndQueueMu.Lock()
- defer e.sndQueueInfo.sndQueueMu.Unlock()
+ // Fetch data.
+ if l := p.Len(); l < avail {
+ avail = l
+ }
+ if avail == 0 {
+ return nil, nil
+ }
+ v := make([]byte, avail)
+ n, err := p.Read(v)
+ if err != nil && err != io.EOF {
+ return nil, &tcpip.ErrBadBuffer{}
+ }
+ return v[:n], nil
+}
+
+// queueSegment reads data from the payloader and returns a segment to be sent.
+// +checklocks:e.mu
+func (e *endpoint) queueSegment(p tcpip.Payloader, opts tcpip.WriteOptions) (*segment, int, tcpip.Error) {
+ e.sndQueueInfo.sndQueueMu.Lock()
+ defer e.sndQueueInfo.sndQueueMu.Unlock()
+
+ avail, err := e.isEndpointWritableLocked()
+ if err != nil {
+ e.stats.WriteErrors.WriteClosed.Increment()
+ return nil, 0, err
+ }
+
+ v, err := e.readFromPayloader(p, opts, avail)
+ if err != nil {
+ return nil, 0, err
+ }
+
+ // Do not queue zero length segments.
+ if len(v) == 0 {
+ return nil, 0, nil
+ }
+ if !opts.Atomic {
+ // Since we released locks in between it's possible that the
+ // endpoint transitioned to a CLOSED/ERROR states so make
+ // sure endpoint is still writable before trying to write.
avail, err := e.isEndpointWritableLocked()
if err != nil {
e.stats.WriteErrors.WriteClosed.Increment()
return nil, 0, err
}
- v, err := func() ([]byte, tcpip.Error) {
- // We can release locks while copying data.
- //
- // This is not possible if atomic is set, because we can't allow the
- // available buffer space to be consumed by some other caller while we
- // are copying data in.
- if !opts.Atomic {
- e.sndQueueInfo.sndQueueMu.Unlock()
- defer e.sndQueueInfo.sndQueueMu.Lock()
-
- e.UnlockUser()
- defer e.LockUser()
- }
-
- // Fetch data.
- if l := p.Len(); l < avail {
- avail = l
- }
- if avail == 0 {
- return nil, nil
- }
- v := make([]byte, avail)
- n, err := p.Read(v)
- if err != nil && err != io.EOF {
- return nil, &tcpip.ErrBadBuffer{}
- }
- return v[:n], nil
- }()
- if len(v) == 0 || err != nil {
- return nil, 0, err
+ // Discard any excess data copied in due to avail being reduced due
+ // to a simultaneous write call to the socket.
+ if avail < len(v) {
+ v = v[:avail]
}
+ }
- if !opts.Atomic {
- // Since we released locks in between it's possible that the
- // endpoint transitioned to a CLOSED/ERROR states so make
- // sure endpoint is still writable before trying to write.
- avail, err := e.isEndpointWritableLocked()
- if err != nil {
- e.stats.WriteErrors.WriteClosed.Increment()
- return nil, 0, err
- }
+ // Add data to the send queue.
+ s := newOutgoingSegment(e.TransportEndpointInfo.ID, e.stack.Clock(), v)
+ e.sndQueueInfo.SndBufUsed += len(v)
+ e.snd.writeList.PushBack(s)
- // Discard any excess data copied in due to avail being reduced due
- // to a simultaneous write call to the socket.
- if avail < len(v) {
- v = v[:avail]
- }
- }
+ return s, len(v), nil
+}
- // Add data to the send queue.
- s := newOutgoingSegment(e.TransportEndpointInfo.ID, e.stack.Clock(), v)
- e.sndQueueInfo.SndBufUsed += len(v)
- e.sndQueueInfo.SndBufInQueue += seqnum.Size(len(v))
- e.sndQueueInfo.sndQueue.PushBack(s)
+// Write writes data to the endpoint's peer.
+func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcpip.Error) {
+ // Linux completely ignores any address passed to sendto(2) for TCP sockets
+ // (without the MSG_FASTOPEN flag). Corking is unimplemented, so opts.More
+ // and opts.EndOfRecord are also ignored.
+
+ e.LockUser()
+ defer e.UnlockUser()
- return e.drainSendQueueLocked(), len(v), nil
- }()
// Return if either we didn't queue anything or if an error occurred while
// attempting to queue data.
+ nextSeg, n, err := e.queueSegment(p, opts)
if n == 0 || err != nil {
return 0, err
}
+
e.sendData(nextSeg)
return int64(n), nil
}
@@ -2314,7 +2322,7 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) tcp
// connection setting here.
if !handshake {
e.segmentQueue.mu.Lock()
- for _, l := range []segmentList{e.segmentQueue.list, e.sndQueueInfo.sndQueue, e.snd.writeList} {
+ for _, l := range []segmentList{e.segmentQueue.list, e.snd.writeList} {
for s := l.Front(); s != nil; s = s.Next() {
s.id = e.TransportEndpointInfo.ID
e.sndQueueInfo.sndWaker.Assert()
@@ -2372,6 +2380,9 @@ func (e *endpoint) shutdownLocked(flags tcpip.ShutdownFlags) tcpip.Error {
e.notifyProtocolGoroutine(notifyTickleWorker)
return nil
}
+ // Wake up any readers that maybe waiting for the stream to become
+ // readable.
+ e.waiterQueue.Notify(waiter.ReadableEvents)
}
// Close for write.
@@ -2388,12 +2399,20 @@ func (e *endpoint) shutdownLocked(flags tcpip.ShutdownFlags) tcpip.Error {
// Queue fin segment.
s := newOutgoingSegment(e.TransportEndpointInfo.ID, e.stack.Clock(), nil)
- e.sndQueueInfo.sndQueue.PushBack(s)
- e.sndQueueInfo.SndBufInQueue++
+ e.snd.writeList.PushBack(s)
// Mark endpoint as closed.
e.sndQueueInfo.SndClosed = true
e.sndQueueInfo.sndQueueMu.Unlock()
- e.handleClose()
+
+ // Drain the send queue.
+ e.sendData(s)
+
+ // Mark send side as closed.
+ e.snd.Closed = true
+
+ // Wake up any writers that maybe waiting for the stream to become
+ // writable.
+ e.waiterQueue.Notify(waiter.WritableEvents)
}
return nil
@@ -2501,6 +2520,7 @@ func (e *endpoint) listen(backlog int) tcpip.Error {
// startAcceptedLoop sets up required state and starts a goroutine with the
// main loop for accepted connections.
+// +checklocksrelease:e.mu
func (e *endpoint) startAcceptedLoop() {
e.workerRunning = true
e.mu.Unlock()
diff --git a/pkg/tcpip/transport/tcp/forwarder.go b/pkg/tcpip/transport/tcp/forwarder.go
index 65c86823a..2e709ed78 100644
--- a/pkg/tcpip/transport/tcp/forwarder.go
+++ b/pkg/tcpip/transport/tcp/forwarder.go
@@ -164,8 +164,9 @@ func (r *ForwarderRequest) CreateEndpoint(queue *waiter.Queue) (tcpip.Endpoint,
return nil, err
}
- // Start the protocol goroutine.
- ep.startAcceptedLoop()
+ // Start the protocol goroutine. Note that the endpoint is returned
+ // from performHandshake locked.
+ ep.startAcceptedLoop() // +checklocksforce
return ep, nil
}
diff --git a/pkg/tcpip/transport/tcp/tcp_noracedetector_test.go b/pkg/tcpip/transport/tcp/tcp_noracedetector_test.go
index ced3a9c58..84fb1c416 100644
--- a/pkg/tcpip/transport/tcp/tcp_noracedetector_test.go
+++ b/pkg/tcpip/transport/tcp/tcp_noracedetector_test.go
@@ -16,6 +16,7 @@
// iterations taking long enough that the retransmit timer can kick in causing
// the congestion window measurements to fail due to extra packets etc.
//
+//go:build !race
// +build !race
package tcp_test
diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go
index 9bbe9bc3e..031f01357 100644
--- a/pkg/tcpip/transport/tcp/tcp_test.go
+++ b/pkg/tcpip/transport/tcp/tcp_test.go
@@ -2147,7 +2147,7 @@ func TestSmallSegReceiveWindowAdvertisement(t *testing.T) {
// Bump up the receive buffer size such that, when the receive window grows,
// the scaled window exceeds maxUint16.
- c.EP.SocketOptions().SetReceiveBufferSize(int64(opt.Max), true)
+ c.EP.SocketOptions().SetReceiveBufferSize(int64(opt.Max)*2, true /* notify */)
// Keep the payload size < segment overhead and such that it is a multiple
// of the window scaled value. This enables the test to perform equality
@@ -2267,7 +2267,7 @@ func TestNoWindowShrinking(t *testing.T) {
initialWnd := header.TCP(header.IPv4(pkt).Payload()).WindowSize() << c.RcvdWindowScale
initialLastAcceptableSeq := iss.Add(seqnum.Size(initialWnd))
// Now shrink the receive buffer to half its original size.
- c.EP.SocketOptions().SetReceiveBufferSize(int64(rcvBufSize/2), true)
+ c.EP.SocketOptions().SetReceiveBufferSize(int64(rcvBufSize), true /* notify */)
data := generateRandomPayload(t, rcvBufSize)
// Send a payload of half the size of rcvBufSize.
@@ -2523,7 +2523,7 @@ func TestScaledWindowAccept(t *testing.T) {
defer ep.Close()
// Set the window size greater than the maximum non-scaled window.
- ep.SocketOptions().SetReceiveBufferSize(65535*3, true)
+ ep.SocketOptions().SetReceiveBufferSize(65535*6, true /* notify */)
if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
t.Fatalf("Bind failed: %s", err)
@@ -2595,7 +2595,7 @@ func TestNonScaledWindowAccept(t *testing.T) {
defer ep.Close()
// Set the window size greater than the maximum non-scaled window.
- ep.SocketOptions().SetReceiveBufferSize(65535*3, true)
+ ep.SocketOptions().SetReceiveBufferSize(65535*6, true /* notify */)
if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
t.Fatalf("Bind failed: %s", err)
@@ -3188,7 +3188,7 @@ func TestPassiveSendMSSLessThanMTU(t *testing.T) {
// Set the buffer size to a deterministic size so that we can check the
// window scaling option.
const rcvBufferSize = 0x20000
- ep.SocketOptions().SetReceiveBufferSize(rcvBufferSize, true)
+ ep.SocketOptions().SetReceiveBufferSize(rcvBufferSize*2, true /* notify */)
if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
t.Fatalf("Bind failed: %s", err)
@@ -3327,7 +3327,7 @@ func TestSynOptionsOnActiveConnect(t *testing.T) {
// window scaling option.
const rcvBufferSize = 0x20000
const wndScale = 3
- c.EP.SocketOptions().SetReceiveBufferSize(rcvBufferSize, true)
+ c.EP.SocketOptions().SetReceiveBufferSize(rcvBufferSize*2, true /* notify */)
// Start connection attempt.
we, ch := waiter.NewChannelEntry(nil)
@@ -3451,17 +3451,13 @@ loop:
for {
switch _, err := c.EP.Read(ioutil.Discard, tcpip.ReadOptions{}); err.(type) {
case *tcpip.ErrWouldBlock:
- select {
- case <-ch:
- // Expect the state to be StateError and subsequent Reads to fail with HardError.
- _, err := c.EP.Read(ioutil.Discard, tcpip.ReadOptions{})
- if d := cmp.Diff(&tcpip.ErrConnectionReset{}, err); d != "" {
- t.Fatalf("c.EP.Read() mismatch (-want +got):\n%s", d)
- }
- break loop
- case <-time.After(1 * time.Second):
- t.Fatalf("Timed out waiting for reset to arrive")
+ <-ch
+ // Expect the state to be StateError and subsequent Reads to fail with HardError.
+ _, err := c.EP.Read(ioutil.Discard, tcpip.ReadOptions{})
+ if d := cmp.Diff(&tcpip.ErrConnectionReset{}, err); d != "" {
+ t.Fatalf("c.EP.Read() mismatch (-want +got):\n%s", d)
}
+ break loop
case *tcpip.ErrConnectionReset:
break loop
default:
@@ -3472,14 +3468,27 @@ loop:
if tcp.EndpointState(c.EP.State()) != tcp.StateError {
t.Fatalf("got EP state is not StateError")
}
- if got := c.Stack().Stats().TCP.EstablishedResets.Value(); got != 1 {
- t.Errorf("got stats.TCP.EstablishedResets.Value() = %d, want = 1", got)
+
+ checkValid := func() []error {
+ var errors []error
+ if got := c.Stack().Stats().TCP.EstablishedResets.Value(); got != 1 {
+ errors = append(errors, fmt.Errorf("got stats.TCP.EstablishedResets.Value() = %d, want = 1", got))
+ }
+ if got := c.Stack().Stats().TCP.CurrentEstablished.Value(); got != 0 {
+ errors = append(errors, fmt.Errorf("got stats.TCP.CurrentEstablished.Value() = %d, want = 0", got))
+ }
+ if got := c.Stack().Stats().TCP.CurrentConnected.Value(); got != 0 {
+ errors = append(errors, fmt.Errorf("got stats.TCP.CurrentConnected.Value() = %d, want = 0", got))
+ }
+ return errors
}
- if got := c.Stack().Stats().TCP.CurrentEstablished.Value(); got != 0 {
- t.Errorf("got stats.TCP.CurrentEstablished.Value() = %d, want = 0", got)
+
+ start := time.Now()
+ for time.Since(start) < time.Minute && len(checkValid()) > 0 {
+ time.Sleep(50 * time.Millisecond)
}
- if got := c.Stack().Stats().TCP.CurrentConnected.Value(); got != 0 {
- t.Errorf("got stats.TCP.CurrentConnected.Value() = %d, want = 0", got)
+ for _, err := range checkValid() {
+ t.Error(err)
}
}
@@ -3615,6 +3624,38 @@ func TestMaxRTO(t *testing.T) {
}
}
+// TestZeroSizedWriteRetransmit tests that a zero sized write should not
+// result in a panic on an RTO as no segment should have been queued for
+// a zero sized write.
+func TestZeroSizedWriteRetransmit(t *testing.T) {
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+ c.CreateConnected(context.TestInitialSequenceNumber, 30000 /* rcvWnd */, -1 /* epRcvBuf */)
+
+ var r bytes.Reader
+ _, err := c.EP.Write(&r, tcpip.WriteOptions{})
+ if err != nil {
+ t.Fatalf("Write failed: %s", err)
+ }
+ // Now do a non-zero sized write to trigger actual sending of data.
+ r.Reset(make([]byte, 1))
+ _, err = c.EP.Write(&r, tcpip.WriteOptions{})
+ if err != nil {
+ t.Fatalf("Write failed: %s", err)
+ }
+ // Do not ACK the packet and expect an original transmit and a
+ // retransmit. This should not cause a panic.
+ for i := 0; i < 2; i++ {
+ checker.IPv4(t, c.GetPacket(),
+ checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh),
+ ),
+ )
+ }
+}
+
// TestRetransmitIPv4IDUniqueness tests that the IPv4 Identification field is
// unique on retransmits.
func TestRetransmitIPv4IDUniqueness(t *testing.T) {
@@ -4628,52 +4669,6 @@ func TestDefaultBufferSizes(t *testing.T) {
checkRecvBufferSize(t, ep, tcp.DefaultReceiveBufferSize*3)
}
-func TestMinMaxBufferSizes(t *testing.T) {
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol},
- TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol},
- })
-
- // Check the default values.
- ep, err := s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{})
- if err != nil {
- t.Fatalf("NewEndpoint failed; %s", err)
- }
- defer ep.Close()
-
- // Change the min/max values for send/receive
- {
- opt := tcpip.TCPReceiveBufferSizeRangeOption{Min: 200, Default: tcp.DefaultReceiveBufferSize * 2, Max: tcp.DefaultReceiveBufferSize * 20}
- if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil {
- t.Fatalf("SetTransportProtocolOption(%d, &%#v): %s", tcp.ProtocolNumber, opt, err)
- }
- }
-
- {
- opt := tcpip.TCPSendBufferSizeRangeOption{Min: 300, Default: tcp.DefaultSendBufferSize * 3, Max: tcp.DefaultSendBufferSize * 30}
- if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil {
- t.Fatalf("SetTransportProtocolOption(%d, &%#v): %s", tcp.ProtocolNumber, opt, err)
- }
- }
-
- // Set values below the min/2.
- ep.SocketOptions().SetReceiveBufferSize(99, true)
- checkRecvBufferSize(t, ep, 200)
-
- ep.SocketOptions().SetSendBufferSize(149, true)
-
- checkSendBufferSize(t, ep, 300)
-
- // Set values above the max.
- ep.SocketOptions().SetReceiveBufferSize(1+tcp.DefaultReceiveBufferSize*20, true)
- // Values above max are capped at max and then doubled.
- checkRecvBufferSize(t, ep, tcp.DefaultReceiveBufferSize*20*2)
-
- ep.SocketOptions().SetSendBufferSize(1+tcp.DefaultSendBufferSize*30, true)
- // Values above max are capped at max and then doubled.
- checkSendBufferSize(t, ep, tcp.DefaultSendBufferSize*30*2)
-}
-
func TestBindToDeviceOption(t *testing.T) {
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol},
@@ -6068,6 +6063,11 @@ func TestSynRcvdBadSeqNumber(t *testing.T) {
// complete the connection to test that the large SEQ num
// did not change the state from SYN-RCVD.
+ // Get setup to be notified about connection establishment.
+ we, ch := waiter.NewChannelEntry(nil)
+ c.WQ.EventRegister(&we, waiter.ReadableEvents)
+ defer c.WQ.EventUnregister(&we)
+
// Send ACK to move to ESTABLISHED state.
c.SendPacket(nil, &context.Headers{
SrcPort: context.TestPort,
@@ -6078,32 +6078,12 @@ func TestSynRcvdBadSeqNumber(t *testing.T) {
RcvWnd: 30000,
})
+ <-ch
newEP, _, err := c.EP.Accept(nil)
- switch err.(type) {
- case nil, *tcpip.ErrWouldBlock:
- default:
+ if err != nil {
t.Fatalf("Accept failed: %s", err)
}
- if cmp.Equal(&tcpip.ErrWouldBlock{}, err) {
- // Try to accept the connections in the backlog.
- we, ch := waiter.NewChannelEntry(nil)
- c.WQ.EventRegister(&we, waiter.ReadableEvents)
- defer c.WQ.EventUnregister(&we)
-
- // Wait for connection to be established.
- select {
- case <-ch:
- newEP, _, err = c.EP.Accept(nil)
- if err != nil {
- t.Fatalf("Accept failed: %s", err)
- }
-
- case <-time.After(1 * time.Second):
- t.Fatalf("Timed out waiting for accept")
- }
- }
-
// Now verify that the TCP socket is usable and in a connected state.
data := "Don't panic"
var r strings.Reader
@@ -6209,12 +6189,26 @@ func TestPassiveFailedConnectionAttemptIncrement(t *testing.T) {
RcvWnd: 30000,
})
- time.Sleep(50 * time.Millisecond)
- if got := stats.TCP.ListenOverflowSynDrop.Value(); got != want {
- t.Errorf("got stats.TCP.ListenOverflowSynDrop.Value() = %d, want = %d", got, want)
+ checkValid := func() []error {
+ var errors []error
+ if got := stats.TCP.ListenOverflowSynDrop.Value(); got != want {
+ errors = append(errors, fmt.Errorf("got stats.TCP.ListenOverflowSynDrop.Value() = %d, want = %d", got, want))
+ }
+ if got := c.EP.Stats().(*tcp.Stats).ReceiveErrors.ListenOverflowSynDrop.Value(); got != want {
+ errors = append(errors, fmt.Errorf("got EP stats Stats.ReceiveErrors.ListenOverflowSynDrop = %d, want = %d", got, want))
+ }
+ return errors
}
- if got := c.EP.Stats().(*tcp.Stats).ReceiveErrors.ListenOverflowSynDrop.Value(); got != want {
- t.Errorf("got EP stats Stats.ReceiveErrors.ListenOverflowSynDrop = %d, want = %d", got, want)
+
+ start := time.Now()
+ for time.Since(start) < time.Minute && len(checkValid()) > 0 {
+ time.Sleep(50 * time.Millisecond)
+ }
+ for _, err := range checkValid() {
+ t.Error(err)
+ }
+ if t.Failed() {
+ t.FailNow()
}
we, ch := waiter.NewChannelEntry(nil)
@@ -6225,15 +6219,10 @@ func TestPassiveFailedConnectionAttemptIncrement(t *testing.T) {
_, _, err = c.EP.Accept(nil)
if cmp.Equal(&tcpip.ErrWouldBlock{}, err) {
// Wait for connection to be established.
- select {
- case <-ch:
- _, _, err = c.EP.Accept(nil)
- if err != nil {
- t.Fatalf("Accept failed: %s", err)
- }
-
- case <-time.After(1 * time.Second):
- t.Fatalf("Timed out waiting for accept")
+ <-ch
+ _, _, err = c.EP.Accept(nil)
+ if err != nil {
+ t.Fatalf("Accept failed: %s", err)
}
}
}
@@ -7483,7 +7472,7 @@ func TestTCPUserTimeout(t *testing.T) {
select {
case <-notifyCh:
case <-time.After(2 * initRTO):
- t.Fatalf("connection still alive after %s, should have been closed after :%s", 2*initRTO, userTimeout)
+ t.Fatalf("connection still alive after %s, should have been closed after %s", 2*initRTO, userTimeout)
}
// No packet should be received as the connection should be silently
@@ -7717,7 +7706,7 @@ func TestIncreaseWindowOnBufferResize(t *testing.T) {
// Increasing the buffer from should generate an ACK,
// since window grew from small value to larger equal MSS
- c.EP.SocketOptions().SetReceiveBufferSize(rcvBuf*2, true)
+ c.EP.SocketOptions().SetReceiveBufferSize(rcvBuf*4, true /* notify */)
checker.IPv4(t, c.GetPacket(),
checker.PayloadLen(header.TCPMinimumSize),
checker.TCP(
diff --git a/pkg/tcpip/transport/tcp/testing/context/context.go b/pkg/tcpip/transport/tcp/testing/context/context.go
index 53efecc5a..96e4849d2 100644
--- a/pkg/tcpip/transport/tcp/testing/context/context.go
+++ b/pkg/tcpip/transport/tcp/testing/context/context.go
@@ -757,7 +757,7 @@ func (c *Context) Create(epRcvBuf int) {
}
if epRcvBuf != -1 {
- c.EP.SocketOptions().SetReceiveBufferSize(int64(epRcvBuf), true /* notify */)
+ c.EP.SocketOptions().SetReceiveBufferSize(int64(epRcvBuf)*2, true /* notify */)
}
}
diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go
index def9d7186..82a3f2287 100644
--- a/pkg/tcpip/transport/udp/endpoint.go
+++ b/pkg/tcpip/transport/udp/endpoint.go
@@ -364,6 +364,7 @@ func (e *endpoint) Read(dst io.Writer, opts tcpip.ReadOptions) (tcpip.ReadResult
// reacquire the mutex in exclusive mode.
//
// Returns true for retry if preparation should be retried.
+// +checklocks:e.mu
func (e *endpoint) prepareForWrite(to *tcpip.FullAddress) (retry bool, err tcpip.Error) {
switch e.EndpointState() {
case StateInitial:
@@ -380,10 +381,8 @@ func (e *endpoint) prepareForWrite(to *tcpip.FullAddress) (retry bool, err tcpip
}
e.mu.RUnlock()
- defer e.mu.RLock()
-
e.mu.Lock()
- defer e.mu.Unlock()
+ defer e.mu.DowngradeLock()
// The state changed when we released the shared locked and re-acquired
// it in exclusive mode. Try again.
@@ -449,37 +448,20 @@ func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcp
return n, err
}
-func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcpip.Error) {
- if err := e.LastError(); err != nil {
- return 0, err
- }
-
- // MSG_MORE is unimplemented. (This also means that MSG_EOR is a no-op.)
- if opts.More {
- return 0, &tcpip.ErrInvalidOptionValue{}
- }
-
- to := opts.To
-
+func (e *endpoint) buildUDPPacketInfo(p tcpip.Payloader, opts tcpip.WriteOptions) (udpPacketInfo, tcpip.Error) {
e.mu.RLock()
- lockReleased := false
- defer func() {
- if lockReleased {
- return
- }
- e.mu.RUnlock()
- }()
+ defer e.mu.RUnlock()
// If we've shutdown with SHUT_WR we are in an invalid state for sending.
if e.shutdownFlags&tcpip.ShutdownWrite != 0 {
- return 0, &tcpip.ErrClosedForSend{}
+ return udpPacketInfo{}, &tcpip.ErrClosedForSend{}
}
// Prepare for write.
for {
- retry, err := e.prepareForWrite(to)
+ retry, err := e.prepareForWrite(opts.To)
if err != nil {
- return 0, err
+ return udpPacketInfo{}, err
}
if !retry {
@@ -489,34 +471,34 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcp
route := e.route
dstPort := e.dstPort
- if to != nil {
+ if opts.To != nil {
// Reject destination address if it goes through a different
// NIC than the endpoint was bound to.
- nicID := to.NIC
+ nicID := opts.To.NIC
if nicID == 0 {
nicID = tcpip.NICID(e.ops.GetBindToDevice())
}
if e.BindNICID != 0 {
if nicID != 0 && nicID != e.BindNICID {
- return 0, &tcpip.ErrNoRoute{}
+ return udpPacketInfo{}, &tcpip.ErrNoRoute{}
}
nicID = e.BindNICID
}
- if to.Port == 0 {
+ if opts.To.Port == 0 {
// Port 0 is an invalid port to send to.
- return 0, &tcpip.ErrInvalidEndpointState{}
+ return udpPacketInfo{}, &tcpip.ErrInvalidEndpointState{}
}
- dst, netProto, err := e.checkV4MappedLocked(*to)
+ dst, netProto, err := e.checkV4MappedLocked(*opts.To)
if err != nil {
- return 0, err
+ return udpPacketInfo{}, err
}
r, _, err := e.connectRoute(nicID, dst, netProto)
if err != nil {
- return 0, err
+ return udpPacketInfo{}, err
}
defer r.Release()
@@ -525,12 +507,12 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcp
}
if !e.ops.GetBroadcast() && route.IsOutboundBroadcast() {
- return 0, &tcpip.ErrBroadcastDisabled{}
+ return udpPacketInfo{}, &tcpip.ErrBroadcastDisabled{}
}
v := make([]byte, p.Len())
if _, err := io.ReadFull(p, v); err != nil {
- return 0, &tcpip.ErrBadBuffer{}
+ return udpPacketInfo{}, &tcpip.ErrBadBuffer{}
}
if len(v) > header.UDPMaximumPacketSize {
// Payload can't possibly fit in a packet.
@@ -548,24 +530,39 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcp
v,
)
}
- return 0, &tcpip.ErrMessageTooLong{}
+ return udpPacketInfo{}, &tcpip.ErrMessageTooLong{}
}
ttl := e.ttl
useDefaultTTL := ttl == 0
-
if header.IsV4MulticastAddress(route.RemoteAddress()) || header.IsV6MulticastAddress(route.RemoteAddress()) {
ttl = e.multicastTTL
// Multicast allows a 0 TTL.
useDefaultTTL = false
}
- localPort := e.ID.LocalPort
- sendTOS := e.sendTOS
- owner := e.owner
- noChecksum := e.SocketOptions().GetNoChecksum()
- lockReleased = true
- e.mu.RUnlock()
+ return udpPacketInfo{
+ route: route,
+ data: buffer.View(v),
+ localPort: e.ID.LocalPort,
+ remotePort: dstPort,
+ ttl: ttl,
+ useDefaultTTL: useDefaultTTL,
+ tos: e.sendTOS,
+ owner: e.owner,
+ noChecksum: e.SocketOptions().GetNoChecksum(),
+ }, nil
+}
+
+func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcpip.Error) {
+ if err := e.LastError(); err != nil {
+ return 0, err
+ }
+
+ // MSG_MORE is unimplemented. (This also means that MSG_EOR is a no-op.)
+ if opts.More {
+ return 0, &tcpip.ErrInvalidOptionValue{}
+ }
// Do not hold lock when sending as loopback is synchronous and if the UDP
// datagram ends up generating an ICMP response then it can result in a
@@ -577,10 +574,15 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcp
//
// See: https://golang.org/pkg/sync/#RWMutex for details on why recursive read
// locking is prohibited.
- if err := sendUDP(route, buffer.View(v).ToVectorisedView(), localPort, dstPort, ttl, useDefaultTTL, sendTOS, owner, noChecksum); err != nil {
+ u, err := e.buildUDPPacketInfo(p, opts)
+ if err != nil {
return 0, err
}
- return int64(len(v)), nil
+ n, err := u.send()
+ if err != nil {
+ return 0, err
+ }
+ return int64(n), nil
}
// OnReuseAddressSet implements tcpip.SocketOptionsHandler.
@@ -817,14 +819,30 @@ func (e *endpoint) GetSockOpt(opt tcpip.GettableSocketOption) tcpip.Error {
return nil
}
-// sendUDP sends a UDP segment via the provided network endpoint and under the
-// provided identity.
-func sendUDP(r *stack.Route, data buffer.VectorisedView, localPort, remotePort uint16, ttl uint8, useDefaultTTL bool, tos uint8, owner tcpip.PacketOwner, noChecksum bool) tcpip.Error {
+// udpPacketInfo contains all information required to send a UDP packet.
+//
+// This should be used as a value-only type, which exists in order to simplify
+// return value syntax. It should not be exported or extended.
+type udpPacketInfo struct {
+ route *stack.Route
+ data buffer.View
+ localPort uint16
+ remotePort uint16
+ ttl uint8
+ useDefaultTTL bool
+ tos uint8
+ owner tcpip.PacketOwner
+ noChecksum bool
+}
+
+// send sends the given packet.
+func (u *udpPacketInfo) send() (int, tcpip.Error) {
+ vv := u.data.ToVectorisedView()
pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
- ReserveHeaderBytes: header.UDPMinimumSize + int(r.MaxHeaderLength()),
- Data: data,
+ ReserveHeaderBytes: header.UDPMinimumSize + int(u.route.MaxHeaderLength()),
+ Data: vv,
})
- pkt.Owner = owner
+ pkt.Owner = u.owner
// Initialize the UDP header.
udp := header.UDP(pkt.TransportHeader().Push(header.UDPMinimumSize))
@@ -832,8 +850,8 @@ func sendUDP(r *stack.Route, data buffer.VectorisedView, localPort, remotePort u
length := uint16(pkt.Size())
udp.Encode(&header.UDPFields{
- SrcPort: localPort,
- DstPort: remotePort,
+ SrcPort: u.localPort,
+ DstPort: u.remotePort,
Length: length,
})
@@ -841,30 +859,30 @@ func sendUDP(r *stack.Route, data buffer.VectorisedView, localPort, remotePort u
// On IPv4, UDP checksum is optional, and a zero value indicates the
// transmitter skipped the checksum generation (RFC768).
// On IPv6, UDP checksum is not optional (RFC2460 Section 8.1).
- if r.RequiresTXTransportChecksum() &&
- (!noChecksum || r.NetProto() == header.IPv6ProtocolNumber) {
- xsum := r.PseudoHeaderChecksum(ProtocolNumber, length)
- for _, v := range data.Views() {
+ if u.route.RequiresTXTransportChecksum() &&
+ (!u.noChecksum || u.route.NetProto() == header.IPv6ProtocolNumber) {
+ xsum := u.route.PseudoHeaderChecksum(ProtocolNumber, length)
+ for _, v := range vv.Views() {
xsum = header.Checksum(v, xsum)
}
udp.SetChecksum(^udp.CalculateChecksum(xsum))
}
- if useDefaultTTL {
- ttl = r.DefaultTTL()
+ if u.useDefaultTTL {
+ u.ttl = u.route.DefaultTTL()
}
- if err := r.WritePacket(stack.NetworkHeaderParams{
+ if err := u.route.WritePacket(stack.NetworkHeaderParams{
Protocol: ProtocolNumber,
- TTL: ttl,
- TOS: tos,
+ TTL: u.ttl,
+ TOS: u.tos,
}, pkt); err != nil {
- r.Stats().UDP.PacketSendErrors.Increment()
- return err
+ u.route.Stats().UDP.PacketSendErrors.Increment()
+ return 0, err
}
// Track count of packets sent.
- r.Stats().UDP.PacketsSent.Increment()
- return nil
+ u.route.Stats().UDP.PacketsSent.Increment()
+ return len(u.data), nil
}
// checkV4MappedLocked determines the effective network protocol and converts