summaryrefslogtreecommitdiffhomepage
path: root/pkg/tcpip
diff options
context:
space:
mode:
Diffstat (limited to 'pkg/tcpip')
-rw-r--r--pkg/tcpip/BUILD1
-rw-r--r--pkg/tcpip/header/checksum.go62
-rw-r--r--pkg/tcpip/header/checksum_test.go203
-rw-r--r--pkg/tcpip/header/interfaces.go38
-rw-r--r--pkg/tcpip/header/ipv4.go12
-rw-r--r--pkg/tcpip/header/ndp_options.go150
-rw-r--r--pkg/tcpip/header/ndp_router_advert.go94
-rw-r--r--pkg/tcpip/header/ndp_test.go339
-rw-r--r--pkg/tcpip/header/tcp.go29
-rw-r--r--pkg/tcpip/header/udp.go29
-rw-r--r--pkg/tcpip/link/fdbased/BUILD1
-rw-r--r--pkg/tcpip/link/fdbased/endpoint.go221
-rw-r--r--pkg/tcpip/link/rawfile/rawfile_unsafe.go61
-rw-r--r--pkg/tcpip/link/tun/BUILD1
-rw-r--r--pkg/tcpip/link/tun/device.go29
-rw-r--r--pkg/tcpip/network/internal/ip/generic_multicast_protocol_test.go10
-rw-r--r--pkg/tcpip/network/ip_test.go40
-rw-r--r--pkg/tcpip/network/ipv4/icmp.go5
-rw-r--r--pkg/tcpip/network/ipv4/ipv4.go32
-rw-r--r--pkg/tcpip/network/ipv6/icmp.go4
-rw-r--r--pkg/tcpip/network/ipv6/icmp_test.go4
-rw-r--r--pkg/tcpip/network/ipv6/ipv6.go39
-rw-r--r--pkg/tcpip/network/ipv6/ndp.go287
-rw-r--r--pkg/tcpip/network/ipv6/ndp_test.go13
-rw-r--r--pkg/tcpip/ports/BUILD1
-rw-r--r--pkg/tcpip/ports/ports.go40
-rw-r--r--pkg/tcpip/stack/BUILD2
-rw-r--r--pkg/tcpip/stack/addressable_endpoint_state.go2
-rw-r--r--pkg/tcpip/stack/conntrack.go67
-rw-r--r--pkg/tcpip/stack/iptables.go4
-rw-r--r--pkg/tcpip/stack/iptables_targets.go105
-rw-r--r--pkg/tcpip/stack/iptables_types.go2
-rw-r--r--pkg/tcpip/stack/ndp_test.go1778
-rw-r--r--pkg/tcpip/stack/nic.go21
-rw-r--r--pkg/tcpip/stack/packet_buffer.go2
-rw-r--r--pkg/tcpip/stack/registration.go5
-rw-r--r--pkg/tcpip/stack/stack.go7
-rw-r--r--pkg/tcpip/stack/tcp.go3
-rw-r--r--pkg/tcpip/stack/transport_demuxer.go23
-rw-r--r--pkg/tcpip/stack/transport_demuxer_test.go50
-rw-r--r--pkg/tcpip/tcpip.go8
-rw-r--r--pkg/tcpip/tests/integration/iptables_test.go285
-rw-r--r--pkg/tcpip/tests/integration/loopback_test.go11
-rw-r--r--pkg/tcpip/transport/icmp/endpoint.go9
-rw-r--r--pkg/tcpip/transport/icmp/protocol.go2
-rw-r--r--pkg/tcpip/transport/packet/endpoint.go5
-rw-r--r--pkg/tcpip/transport/raw/endpoint.go20
-rw-r--r--pkg/tcpip/transport/tcp/accept.go23
-rw-r--r--pkg/tcpip/transport/tcp/connect.go141
-rw-r--r--pkg/tcpip/transport/tcp/dispatcher.go2
-rw-r--r--pkg/tcpip/transport/tcp/endpoint.go171
-rw-r--r--pkg/tcpip/transport/tcp/forwarder.go5
-rw-r--r--pkg/tcpip/transport/tcp/rcv.go1
-rw-r--r--pkg/tcpip/transport/tcp/tcp_test.go157
-rw-r--r--pkg/tcpip/transport/udp/endpoint.go150
55 files changed, 3056 insertions, 1750 deletions
diff --git a/pkg/tcpip/BUILD b/pkg/tcpip/BUILD
index ed4d7e958..f00cfd0f5 100644
--- a/pkg/tcpip/BUILD
+++ b/pkg/tcpip/BUILD
@@ -46,7 +46,6 @@ deps_test(
"//pkg/gohacks",
"//pkg/goid",
"//pkg/ilist",
- "//pkg/iovec",
"//pkg/linewriter",
"//pkg/log",
"//pkg/rand",
diff --git a/pkg/tcpip/header/checksum.go b/pkg/tcpip/header/checksum.go
index 6aa9acfa8..e2c85e220 100644
--- a/pkg/tcpip/header/checksum.go
+++ b/pkg/tcpip/header/checksum.go
@@ -18,6 +18,7 @@ package header
import (
"encoding/binary"
+ "fmt"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
@@ -234,3 +235,64 @@ func PseudoHeaderChecksum(protocol tcpip.TransportProtocolNumber, srcAddr tcpip.
return Checksum([]byte{0, uint8(protocol)}, xsum)
}
+
+// checksumUpdate2ByteAlignedUint16 updates a uint16 value in a calculated
+// checksum.
+//
+// The value MUST begin at a 2-byte boundary in the original buffer.
+func checksumUpdate2ByteAlignedUint16(xsum, old, new uint16) uint16 {
+ // As per RFC 1071 page 4,
+ // (4) Incremental Update
+ //
+ // ...
+ //
+ // To update the checksum, simply add the differences of the
+ // sixteen bit integers that have been changed. To see why this
+ // works, observe that every 16-bit integer has an additive inverse
+ // and that addition is associative. From this it follows that
+ // given the original value m, the new value m', and the old
+ // checksum C, the new checksum C' is:
+ //
+ // C' = C + (-m) + m' = C + (m' - m)
+ return ChecksumCombine(xsum, ChecksumCombine(new, ^old))
+}
+
+// checksumUpdate2ByteAlignedAddress updates an address in a calculated
+// checksum.
+//
+// The addresses must have the same length and must contain an even number
+// of bytes. The address MUST begin at a 2-byte boundary in the original buffer.
+func checksumUpdate2ByteAlignedAddress(xsum uint16, old, new tcpip.Address) uint16 {
+ const uint16Bytes = 2
+
+ if len(old) != len(new) {
+ panic(fmt.Sprintf("buffer lengths are different; old = %d, new = %d", len(old), len(new)))
+ }
+
+ if len(old)%uint16Bytes != 0 {
+ panic(fmt.Sprintf("buffer has an odd number of bytes; got = %d", len(old)))
+ }
+
+ // As per RFC 1071 page 4,
+ // (4) Incremental Update
+ //
+ // ...
+ //
+ // To update the checksum, simply add the differences of the
+ // sixteen bit integers that have been changed. To see why this
+ // works, observe that every 16-bit integer has an additive inverse
+ // and that addition is associative. From this it follows that
+ // given the original value m, the new value m', and the old
+ // checksum C, the new checksum C' is:
+ //
+ // C' = C + (-m) + m' = C + (m' - m)
+ for len(old) != 0 {
+ // Convert the 2 byte sequences to uint16 values then apply the increment
+ // update.
+ xsum = checksumUpdate2ByteAlignedUint16(xsum, (uint16(old[0])<<8)+uint16(old[1]), (uint16(new[0])<<8)+uint16(new[1]))
+ old = old[uint16Bytes:]
+ new = new[uint16Bytes:]
+ }
+
+ return xsum
+}
diff --git a/pkg/tcpip/header/checksum_test.go b/pkg/tcpip/header/checksum_test.go
index d267dabd0..3445511f4 100644
--- a/pkg/tcpip/header/checksum_test.go
+++ b/pkg/tcpip/header/checksum_test.go
@@ -23,6 +23,7 @@ import (
"sync"
"testing"
+ "gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
)
@@ -256,3 +257,205 @@ func TestICMPv6Checksum(t *testing.T) {
})
}, want, fmt.Sprintf("header: {% x} data {% x}", h, vv.ToView()))
}
+
+func randomAddress(size int) tcpip.Address {
+ s := make([]byte, size)
+ for i := 0; i < size; i++ {
+ s[i] = byte(rand.Uint32())
+ }
+ return tcpip.Address(s)
+}
+
+func TestChecksummableNetworkUpdateAddress(t *testing.T) {
+ tests := []struct {
+ name string
+ update func(header.IPv4, tcpip.Address)
+ }{
+ {
+ name: "SetSourceAddressWithChecksumUpdate",
+ update: header.IPv4.SetSourceAddressWithChecksumUpdate,
+ },
+ {
+ name: "SetDestinationAddressWithChecksumUpdate",
+ update: header.IPv4.SetDestinationAddressWithChecksumUpdate,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ for i := 0; i < 1000; i++ {
+ var origBytes [header.IPv4MinimumSize]byte
+ header.IPv4(origBytes[:]).Encode(&header.IPv4Fields{
+ TOS: 1,
+ TotalLength: header.IPv4MinimumSize,
+ ID: 2,
+ Flags: 3,
+ FragmentOffset: 4,
+ TTL: 5,
+ Protocol: 6,
+ Checksum: 0,
+ SrcAddr: randomAddress(header.IPv4AddressSize),
+ DstAddr: randomAddress(header.IPv4AddressSize),
+ })
+
+ addr := randomAddress(header.IPv4AddressSize)
+
+ bytesCopy := origBytes
+ h := header.IPv4(bytesCopy[:])
+ origXSum := h.CalculateChecksum()
+ h.SetChecksum(^origXSum)
+
+ test.update(h, addr)
+ got := ^h.Checksum()
+ h.SetChecksum(0)
+ want := h.CalculateChecksum()
+ if got != want {
+ t.Errorf("got h.Checksum() = 0x%x, want = 0x%x; originalBytes = 0x%x, new addr = %s", got, want, origBytes, addr)
+ }
+ }
+ })
+ }
+}
+
+func TestChecksummableTransportUpdatePort(t *testing.T) {
+ // The fields in the pseudo header is not tested here so we just use 0.
+ const pseudoHeaderXSum = 0
+
+ tests := []struct {
+ name string
+ transportHdr func(_, _ uint16) (header.ChecksummableTransport, func(uint16) uint16)
+ proto tcpip.TransportProtocolNumber
+ }{
+ {
+ name: "TCP",
+ transportHdr: func(src, dst uint16) (header.ChecksummableTransport, func(uint16) uint16) {
+ h := header.TCP(make([]byte, header.TCPMinimumSize))
+ h.Encode(&header.TCPFields{
+ SrcPort: src,
+ DstPort: dst,
+ SeqNum: 1,
+ AckNum: 2,
+ DataOffset: header.TCPMinimumSize,
+ Flags: 3,
+ WindowSize: 4,
+ Checksum: 0,
+ UrgentPointer: 5,
+ })
+ h.SetChecksum(^h.CalculateChecksum(pseudoHeaderXSum))
+ return h, h.CalculateChecksum
+ },
+ proto: header.TCPProtocolNumber,
+ },
+ {
+ name: "UDP",
+ transportHdr: func(src, dst uint16) (header.ChecksummableTransport, func(uint16) uint16) {
+ h := header.UDP(make([]byte, header.UDPMinimumSize))
+ h.Encode(&header.UDPFields{
+ SrcPort: src,
+ DstPort: dst,
+ Length: 0,
+ Checksum: 0,
+ })
+ h.SetChecksum(^h.CalculateChecksum(pseudoHeaderXSum))
+ return h, h.CalculateChecksum
+ },
+ proto: header.UDPProtocolNumber,
+ },
+ }
+
+ for i := 0; i < 1000; i++ {
+ origSrcPort := uint16(rand.Uint32())
+ origDstPort := uint16(rand.Uint32())
+ newPort := uint16(rand.Uint32())
+
+ t.Run(fmt.Sprintf("OrigSrcPort=%d,OrigDstPort=%d,NewPort=%d", origSrcPort, origDstPort, newPort), func(*testing.T) {
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ for _, subTest := range []struct {
+ name string
+ update func(header.ChecksummableTransport)
+ }{
+ {
+ name: "Source port",
+ update: func(h header.ChecksummableTransport) { h.SetSourcePortWithChecksumUpdate(newPort) },
+ },
+ {
+ name: "Destination port",
+ update: func(h header.ChecksummableTransport) { h.SetDestinationPortWithChecksumUpdate(newPort) },
+ },
+ } {
+ t.Run(subTest.name, func(t *testing.T) {
+ h, calcXSum := test.transportHdr(origSrcPort, origDstPort)
+ subTest.update(h)
+ // TCP and UDP hold the 1s complement of the fully calculated
+ // checksum.
+ got := ^h.Checksum()
+ h.SetChecksum(0)
+
+ if want := calcXSum(pseudoHeaderXSum); got != want {
+ h, _ := test.transportHdr(origSrcPort, origDstPort)
+ t.Errorf("got Checksum() = 0x%x, want = 0x%x; originalBytes = %#v, new port = %d", got, want, h, newPort)
+ }
+ })
+ }
+ })
+ }
+ })
+ }
+}
+
+func TestChecksummableTransportUpdatePseudoHeaderAddress(t *testing.T) {
+ const addressSize = 6
+
+ tests := []struct {
+ name string
+ transportHdr func() header.ChecksummableTransport
+ proto tcpip.TransportProtocolNumber
+ }{
+ {
+ name: "TCP",
+ transportHdr: func() header.ChecksummableTransport { return header.TCP(make([]byte, header.TCPMinimumSize)) },
+ proto: header.TCPProtocolNumber,
+ },
+ {
+ name: "UDP",
+ transportHdr: func() header.ChecksummableTransport { return header.UDP(make([]byte, header.UDPMinimumSize)) },
+ proto: header.UDPProtocolNumber,
+ },
+ }
+
+ for i := 0; i < 1000; i++ {
+ permanent := randomAddress(addressSize)
+ old := randomAddress(addressSize)
+ new := randomAddress(addressSize)
+
+ t.Run(fmt.Sprintf("Permanent=%q,Old=%q,New=%q", permanent, old, new), func(t *testing.T) {
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ for _, fullChecksum := range []bool{true, false} {
+ t.Run(fmt.Sprintf("FullChecksum=%t", fullChecksum), func(t *testing.T) {
+ initialXSum := header.PseudoHeaderChecksum(test.proto, permanent, old, 0)
+ if fullChecksum {
+ // TCP and UDP hold the 1s complement of the fully calculated
+ // checksum.
+ initialXSum = ^initialXSum
+ }
+
+ h := test.transportHdr()
+ h.SetChecksum(initialXSum)
+ h.UpdateChecksumPseudoHeaderAddress(old, new, fullChecksum)
+
+ got := h.Checksum()
+ if fullChecksum {
+ got = ^got
+ }
+ if want := header.PseudoHeaderChecksum(test.proto, permanent, new, 0); got != want {
+ t.Errorf("got Checksum() = 0x%x, want = 0x%x; h = %#v", got, want, h)
+ }
+ })
+ }
+ })
+ }
+ })
+ }
+}
diff --git a/pkg/tcpip/header/interfaces.go b/pkg/tcpip/header/interfaces.go
index 861cbbb70..3a41adfc4 100644
--- a/pkg/tcpip/header/interfaces.go
+++ b/pkg/tcpip/header/interfaces.go
@@ -53,6 +53,31 @@ type Transport interface {
Payload() []byte
}
+// ChecksummableTransport is a Transport that supports checksumming.
+type ChecksummableTransport interface {
+ Transport
+
+ // SetSourcePortWithChecksumUpdate sets the source port and updates
+ // the checksum.
+ //
+ // The receiver's checksum must be a fully calculated checksum.
+ SetSourcePortWithChecksumUpdate(port uint16)
+
+ // SetDestinationPortWithChecksumUpdate sets the destination port and updates
+ // the checksum.
+ //
+ // The receiver's checksum must be a fully calculated checksum.
+ SetDestinationPortWithChecksumUpdate(port uint16)
+
+ // UpdateChecksumPseudoHeaderAddress updates the checksum to reflect an
+ // updated address in the pseudo header.
+ //
+ // If fullChecksum is true, the receiver's checksum field is assumed to hold a
+ // fully calculated checksum. Otherwise, it is assumed to hold a partially
+ // calculated checksum which only reflects the pseudo header.
+ UpdateChecksumPseudoHeaderAddress(old, new tcpip.Address, fullChecksum bool)
+}
+
// Network offers generic methods to query and/or update the fields of the
// header of a network protocol buffer.
type Network interface {
@@ -90,3 +115,16 @@ type Network interface {
// SetTOS sets the values of the "type of service" and "flow label" fields.
SetTOS(t uint8, l uint32)
}
+
+// ChecksummableNetwork is a Network that supports checksumming.
+type ChecksummableNetwork interface {
+ Network
+
+ // SetSourceAddressAndChecksum sets the source address and updates the
+ // checksum to reflect the new address.
+ SetSourceAddressWithChecksumUpdate(tcpip.Address)
+
+ // SetDestinationAddressAndChecksum sets the destination address and
+ // updates the checksum to reflect the new address.
+ SetDestinationAddressWithChecksumUpdate(tcpip.Address)
+}
diff --git a/pkg/tcpip/header/ipv4.go b/pkg/tcpip/header/ipv4.go
index e9abbb709..dcc549c7b 100644
--- a/pkg/tcpip/header/ipv4.go
+++ b/pkg/tcpip/header/ipv4.go
@@ -305,6 +305,18 @@ func (b IPv4) DestinationAddress() tcpip.Address {
return tcpip.Address(b[dstAddr : dstAddr+IPv4AddressSize])
}
+// SetSourceAddressWithChecksumUpdate implements ChecksummableNetwork.
+func (b IPv4) SetSourceAddressWithChecksumUpdate(new tcpip.Address) {
+ b.SetChecksum(^checksumUpdate2ByteAlignedAddress(^b.Checksum(), b.SourceAddress(), new))
+ b.SetSourceAddress(new)
+}
+
+// SetDestinationAddressWithChecksumUpdate implements ChecksummableNetwork.
+func (b IPv4) SetDestinationAddressWithChecksumUpdate(new tcpip.Address) {
+ b.SetChecksum(^checksumUpdate2ByteAlignedAddress(^b.Checksum(), b.DestinationAddress(), new))
+ b.SetDestinationAddress(new)
+}
+
// padIPv4OptionsLength returns the total length for IPv4 options of length l
// after applying padding according to RFC 791:
// The internet header padding is used to ensure that the internet
diff --git a/pkg/tcpip/header/ndp_options.go b/pkg/tcpip/header/ndp_options.go
index d6cad3a94..a647ea968 100644
--- a/pkg/tcpip/header/ndp_options.go
+++ b/pkg/tcpip/header/ndp_options.go
@@ -148,15 +148,10 @@ const (
// NDP option. That is, the length field for NDP options is in units of
// 8 octets, as per RFC 4861 section 4.6.
lengthByteUnits = 8
-)
-var (
// NDPInfiniteLifetime is a value that represents infinity for the
// 4-byte lifetime fields found in various NDP options. Its value is
// (2^32 - 1)s = 4294967295s.
- //
- // This is a variable instead of a constant so that tests can change
- // this value to a smaller value. It should only be modified by tests.
NDPInfiniteLifetime = time.Second * math.MaxUint32
)
@@ -238,6 +233,17 @@ func (i *NDPOptionIterator) Next() (NDPOption, bool, error) {
case ndpNonceOptionType:
return NDPNonceOption(body), false, nil
+ case ndpRouteInformationType:
+ if numBodyBytes > ndpRouteInformationMaxLength {
+ return nil, true, fmt.Errorf("got %d bytes for NDP Route Information option's body, expected at max %d bytes: %w", numBodyBytes, ndpRouteInformationMaxLength, ErrNDPOptMalformedBody)
+ }
+ opt := NDPRouteInformation(body)
+ if err := opt.hasError(); err != nil {
+ return nil, true, err
+ }
+
+ return opt, false, nil
+
case ndpPrefixInformationType:
// Make sure the length of a Prefix Information option
// body is ndpPrefixInformationLength, as per RFC 4861
@@ -935,3 +941,137 @@ func isUpperLetter(b byte) bool {
func isDigit(b byte) bool {
return b >= '0' && b <= '9'
}
+
+// As per RFC 4191 section 2.3,
+//
+// 2.3. Route Information Option
+//
+// 0 1 2 3
+// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
+// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
+// | Type | Length | Prefix Length |Resvd|Prf|Resvd|
+// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
+// | Route Lifetime |
+// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
+// | Prefix (Variable Length) |
+// . .
+// . .
+// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
+//
+// Fields:
+//
+// Type 24
+//
+//
+// Length 8-bit unsigned integer. The length of the option
+// (including the Type and Length fields) in units of 8
+// octets. The Length field is 1, 2, or 3 depending on the
+// Prefix Length. If Prefix Length is greater than 64, then
+// Length must be 3. If Prefix Length is greater than 0,
+// then Length must be 2 or 3. If Prefix Length is zero,
+// then Length must be 1, 2, or 3.
+const (
+ ndpRouteInformationType = ndpOptionIdentifier(24)
+ ndpRouteInformationMaxLength = 22
+
+ ndpRouteInformationPrefixLengthIdx = 0
+ ndpRouteInformationFlagsIdx = 1
+ ndpRouteInformationPrfShift = 3
+ ndpRouteInformationPrfMask = 3 << ndpRouteInformationPrfShift
+ ndpRouteInformationRouteLifetimeIdx = 2
+ ndpRouteInformationRoutePrefixIdx = 6
+)
+
+// NDPRouteInformation is the NDP Router Information option, as defined by
+// RFC 4191 section 2.3.
+type NDPRouteInformation []byte
+
+func (NDPRouteInformation) kind() ndpOptionIdentifier {
+ return ndpRouteInformationType
+}
+
+func (o NDPRouteInformation) length() int {
+ return len(o)
+}
+
+func (o NDPRouteInformation) serializeInto(b []byte) int {
+ return copy(b, o)
+}
+
+// String implements fmt.Stringer.
+func (o NDPRouteInformation) String() string {
+ return fmt.Sprintf("%T", o)
+}
+
+// PrefixLength returns the length of the prefix.
+func (o NDPRouteInformation) PrefixLength() uint8 {
+ return o[ndpRouteInformationPrefixLengthIdx]
+}
+
+// RoutePreference returns the preference of the route over other routes to the
+// same destination but through a different router.
+func (o NDPRouteInformation) RoutePreference() NDPRoutePreference {
+ return NDPRoutePreference((o[ndpRouteInformationFlagsIdx] & ndpRouteInformationPrfMask) >> ndpRouteInformationPrfShift)
+}
+
+// RouteLifetime returns the lifetime of the route.
+//
+// Note, a value of 0 implies the route is now invalid and a value of
+// infinity/forever is represented by NDPInfiniteLifetime.
+func (o NDPRouteInformation) RouteLifetime() time.Duration {
+ return time.Second * time.Duration(binary.BigEndian.Uint32(o[ndpRouteInformationRouteLifetimeIdx:]))
+}
+
+// Prefix returns the prefix of the destination subnet this route is for.
+func (o NDPRouteInformation) Prefix() (tcpip.Subnet, error) {
+ prefixLength := int(o.PrefixLength())
+ if max := IPv6AddressSize * 8; prefixLength > max {
+ return tcpip.Subnet{}, fmt.Errorf("got prefix length = %d, want <= %d", prefixLength, max)
+ }
+
+ prefix := o[ndpRouteInformationRoutePrefixIdx:]
+ var addrBytes [IPv6AddressSize]byte
+ if n := copy(addrBytes[:], prefix); n != len(prefix) {
+ panic(fmt.Sprintf("got copy(addrBytes, prefix) = %d, want = %d", n, len(prefix)))
+ }
+
+ return tcpip.AddressWithPrefix{
+ Address: tcpip.Address(addrBytes[:]),
+ PrefixLen: prefixLength,
+ }.Subnet(), nil
+}
+
+func (o NDPRouteInformation) hasError() error {
+ l := len(o)
+ if l < ndpRouteInformationRoutePrefixIdx {
+ return fmt.Errorf("%T too small, got = %d bytes: %w", o, l, ErrNDPOptMalformedBody)
+ }
+
+ prefixLength := int(o.PrefixLength())
+ if max := IPv6AddressSize * 8; prefixLength > max {
+ return fmt.Errorf("got prefix length = %d, want <= %d: %w", prefixLength, max, ErrNDPOptMalformedBody)
+ }
+
+ // Length 8-bit unsigned integer. The length of the option
+ // (including the Type and Length fields) in units of 8
+ // octets. The Length field is 1, 2, or 3 depending on the
+ // Prefix Length. If Prefix Length is greater than 64, then
+ // Length must be 3. If Prefix Length is greater than 0,
+ // then Length must be 2 or 3. If Prefix Length is zero,
+ // then Length must be 1, 2, or 3.
+ l += 2 // Add 2 bytes for the type and length bytes.
+ lengthField := l / lengthByteUnits
+ if prefixLength > 64 {
+ if lengthField != 3 {
+ return fmt.Errorf("Length field must be 3 when Prefix Length (%d) is > 64 (got = %d): %w", prefixLength, lengthField, ErrNDPOptMalformedBody)
+ }
+ } else if prefixLength > 0 {
+ if lengthField != 2 && lengthField != 3 {
+ return fmt.Errorf("Length field must be 2 or 3 when Prefix Length (%d) is between 0 and 64 (got = %d): %w", prefixLength, lengthField, ErrNDPOptMalformedBody)
+ }
+ } else if lengthField == 0 || lengthField > 3 {
+ return fmt.Errorf("Length field must be 1, 2, or 3 when Prefix Length is zero (got = %d): %w", lengthField, ErrNDPOptMalformedBody)
+ }
+
+ return nil
+}
diff --git a/pkg/tcpip/header/ndp_router_advert.go b/pkg/tcpip/header/ndp_router_advert.go
index bf7610863..7d6efa083 100644
--- a/pkg/tcpip/header/ndp_router_advert.go
+++ b/pkg/tcpip/header/ndp_router_advert.go
@@ -16,15 +16,94 @@ package header
import (
"encoding/binary"
+ "fmt"
"time"
)
+var _ fmt.Stringer = NDPRoutePreference(0)
+
+// NDPRoutePreference is the preference values for default routers or
+// more-specific routes.
+//
+// As per RFC 4191 section 2.1,
+//
+// Default router preferences and preferences for more-specific routes
+// are encoded the same way.
+//
+// Preference values are encoded as a two-bit signed integer, as
+// follows:
+//
+// 01 High
+// 00 Medium (default)
+// 11 Low
+// 10 Reserved - MUST NOT be sent
+//
+// Note that implementations can treat the value as a two-bit signed
+// integer.
+//
+// Having just three values reinforces that they are not metrics and
+// more values do not appear to be necessary for reasonable scenarios.
+type NDPRoutePreference uint8
+
+const (
+ // HighRoutePreference indicates a high preference, as per
+ // RFC 4191 section 2.1.
+ HighRoutePreference NDPRoutePreference = 0b01
+
+ // MediumRoutePreference indicates a medium preference, as per
+ // RFC 4191 section 2.1.
+ //
+ // This is the default preference value.
+ MediumRoutePreference = 0b00
+
+ // LowRoutePreference indicates a low preference, as per
+ // RFC 4191 section 2.1.
+ LowRoutePreference = 0b11
+
+ // ReservedRoutePreference is a reserved preference value, as per
+ // RFC 4191 section 2.1.
+ //
+ // It MUST NOT be sent.
+ ReservedRoutePreference = 0b10
+)
+
+// String implements fmt.Stringer.
+func (p NDPRoutePreference) String() string {
+ switch p {
+ case HighRoutePreference:
+ return "HighRoutePreference"
+ case MediumRoutePreference:
+ return "MediumRoutePreference"
+ case LowRoutePreference:
+ return "LowRoutePreference"
+ case ReservedRoutePreference:
+ return "ReservedRoutePreference"
+ default:
+ return fmt.Sprintf("NDPRoutePreference(%d)", p)
+ }
+}
+
// NDPRouterAdvert is an NDP Router Advertisement message. It will only contain
// the body of an ICMPv6 packet.
//
-// See RFC 4861 section 4.2 for more details.
+// See RFC 4861 section 4.2 and RFC 4191 section 2.2 for more details.
type NDPRouterAdvert []byte
+// As per RFC 4191 section 2.2,
+//
+// 0 1 2 3
+// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
+// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
+// | Type | Code | Checksum |
+// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
+// | Cur Hop Limit |M|O|H|Prf|Resvd| Router Lifetime |
+// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
+// | Reachable Time |
+// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
+// | Retrans Timer |
+// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
+// | Options ...
+// +-+-+-+-+-+-+-+-+-+-+-+-
const (
// NDPRAMinimumSize is the minimum size of a valid NDP Router
// Advertisement message (body of an ICMPv6 packet).
@@ -47,6 +126,14 @@ const (
// within the bit-field/flags byte of an NDPRouterAdvert.
ndpRAOtherConfFlagMask = (1 << 6)
+ // ndpDefaultRouterPreferenceShift is the shift of the Prf (Default Router
+ // Preference) field within the flags byte of an NDPRouterAdvert.
+ ndpDefaultRouterPreferenceShift = 3
+
+ // ndpDefaultRouterPreferenceMask is the mask of the Prf (Default Router
+ // Preference) field within the flags byte of an NDPRouterAdvert.
+ ndpDefaultRouterPreferenceMask = (0b11 << ndpDefaultRouterPreferenceShift)
+
// ndpRARouterLifetimeOffset is the start of the 2-byte Router Lifetime
// field within an NDPRouterAdvert.
ndpRARouterLifetimeOffset = 2
@@ -80,6 +167,11 @@ func (b NDPRouterAdvert) OtherConfFlag() bool {
return b[ndpRAFlagsOffset]&ndpRAOtherConfFlagMask != 0
}
+// DefaultRouterPreference returns the Default Router Preference field.
+func (b NDPRouterAdvert) DefaultRouterPreference() NDPRoutePreference {
+ return NDPRoutePreference((b[ndpRAFlagsOffset] & ndpDefaultRouterPreferenceMask) >> ndpDefaultRouterPreferenceShift)
+}
+
// RouterLifetime returns the lifetime associated with the default router. A
// value of 0 means the source of the Router Advertisement is not a default
// router and SHOULD NOT appear on the default router list. Note, a value of 0
diff --git a/pkg/tcpip/header/ndp_test.go b/pkg/tcpip/header/ndp_test.go
index 1b5093e58..2a897e938 100644
--- a/pkg/tcpip/header/ndp_test.go
+++ b/pkg/tcpip/header/ndp_test.go
@@ -21,6 +21,7 @@ import (
"fmt"
"io"
"regexp"
+ "strings"
"testing"
"time"
@@ -58,6 +59,224 @@ func TestNDPNeighborSolicit(t *testing.T) {
}
}
+func TestNDPRouteInformationOption(t *testing.T) {
+ tests := []struct {
+ name string
+
+ length uint8
+ prefixLength uint8
+ prf NDPRoutePreference
+ lifetimeS uint32
+ prefixBytes []byte
+ expectedPrefix tcpip.Subnet
+
+ expectedErr error
+ }{
+ {
+ name: "Length=1 with Prefix Length = 0",
+ length: 1,
+ prefixLength: 0,
+ prf: MediumRoutePreference,
+ lifetimeS: 1,
+ prefixBytes: nil,
+ expectedPrefix: IPv6EmptySubnet,
+ },
+ {
+ name: "Length=1 but Prefix Length > 0",
+ length: 1,
+ prefixLength: 1,
+ prf: MediumRoutePreference,
+ lifetimeS: 1,
+ prefixBytes: nil,
+ expectedErr: ErrNDPOptMalformedBody,
+ },
+ {
+ name: "Length=2 with Prefix Length = 0",
+ length: 2,
+ prefixLength: 0,
+ prf: MediumRoutePreference,
+ lifetimeS: 1,
+ prefixBytes: nil,
+ expectedPrefix: IPv6EmptySubnet,
+ },
+ {
+ name: "Length=2 with Prefix Length in [1, 64] (1)",
+ length: 2,
+ prefixLength: 1,
+ prf: LowRoutePreference,
+ lifetimeS: 1,
+ prefixBytes: nil,
+ expectedPrefix: tcpip.AddressWithPrefix{
+ Address: tcpip.Address(strings.Repeat("\x00", IPv6AddressSize)),
+ PrefixLen: 1,
+ }.Subnet(),
+ },
+ {
+ name: "Length=2 with Prefix Length in [1, 64] (64)",
+ length: 2,
+ prefixLength: 64,
+ prf: HighRoutePreference,
+ lifetimeS: 1,
+ prefixBytes: nil,
+ expectedPrefix: tcpip.AddressWithPrefix{
+ Address: tcpip.Address(strings.Repeat("\x00", IPv6AddressSize)),
+ PrefixLen: 64,
+ }.Subnet(),
+ },
+ {
+ name: "Length=2 with Prefix Length > 64",
+ length: 2,
+ prefixLength: 65,
+ prf: HighRoutePreference,
+ lifetimeS: 1,
+ prefixBytes: nil,
+ expectedErr: ErrNDPOptMalformedBody,
+ },
+ {
+ name: "Length=3 with Prefix Length = 0",
+ length: 3,
+ prefixLength: 0,
+ prf: MediumRoutePreference,
+ lifetimeS: 1,
+ prefixBytes: nil,
+ expectedPrefix: IPv6EmptySubnet,
+ },
+ {
+ name: "Length=3 with Prefix Length in [1, 64] (1)",
+ length: 3,
+ prefixLength: 1,
+ prf: LowRoutePreference,
+ lifetimeS: 1,
+ prefixBytes: nil,
+ expectedPrefix: tcpip.AddressWithPrefix{
+ Address: tcpip.Address(strings.Repeat("\x00", IPv6AddressSize)),
+ PrefixLen: 1,
+ }.Subnet(),
+ },
+ {
+ name: "Length=3 with Prefix Length in [1, 64] (64)",
+ length: 3,
+ prefixLength: 64,
+ prf: HighRoutePreference,
+ lifetimeS: 1,
+ prefixBytes: nil,
+ expectedPrefix: tcpip.AddressWithPrefix{
+ Address: tcpip.Address(strings.Repeat("\x00", IPv6AddressSize)),
+ PrefixLen: 64,
+ }.Subnet(),
+ },
+ {
+ name: "Length=3 with Prefix Length in [65, 128] (65)",
+ length: 3,
+ prefixLength: 65,
+ prf: HighRoutePreference,
+ lifetimeS: 1,
+ prefixBytes: nil,
+ expectedPrefix: tcpip.AddressWithPrefix{
+ Address: tcpip.Address(strings.Repeat("\x00", IPv6AddressSize)),
+ PrefixLen: 65,
+ }.Subnet(),
+ },
+ {
+ name: "Length=3 with Prefix Length in [65, 128] (128)",
+ length: 3,
+ prefixLength: 128,
+ prf: HighRoutePreference,
+ lifetimeS: 1,
+ prefixBytes: nil,
+ expectedPrefix: tcpip.AddressWithPrefix{
+ Address: tcpip.Address(strings.Repeat("\x00", IPv6AddressSize)),
+ PrefixLen: 128,
+ }.Subnet(),
+ },
+ {
+ name: "Length=3 with (invalid) Prefix Length > 128",
+ length: 3,
+ prefixLength: 129,
+ prf: HighRoutePreference,
+ lifetimeS: 1,
+ prefixBytes: nil,
+ expectedErr: ErrNDPOptMalformedBody,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ expectedRouteInformationBytes := [...]byte{
+ // Type, Length
+ 24, test.length,
+
+ // Prefix Length, Prf
+ uint8(test.prefixLength), uint8(test.prf) << 3,
+
+ // Route Lifetime
+ 0, 0, 0, 0,
+
+ 0, 0, 0, 0,
+ 0, 0, 0, 0,
+ 0, 0, 0, 0,
+ 0, 0, 0, 0,
+ }
+ binary.BigEndian.PutUint32(expectedRouteInformationBytes[4:], test.lifetimeS)
+ _ = copy(expectedRouteInformationBytes[8:], test.prefixBytes)
+
+ opts := NDPOptions(expectedRouteInformationBytes[:test.length*lengthByteUnits])
+ it, err := opts.Iter(false)
+ if err != nil {
+ t.Fatalf("got Iter(false) = (_, %s), want = (_, nil)", err)
+ }
+ opt, done, err := it.Next()
+ if !errors.Is(err, test.expectedErr) {
+ t.Fatalf("got Next() = (_, _, %s), want = (_, _, %s)", err, test.expectedErr)
+ }
+ if want := test.expectedErr != nil; done != want {
+ t.Fatalf("got Next() = (_, %t, _), want = (_, %t, _)", done, want)
+ }
+ if test.expectedErr != nil {
+ return
+ }
+
+ if got := opt.kind(); got != ndpRouteInformationType {
+ t.Errorf("got kind() = %d, want = %d", got, ndpRouteInformationType)
+ }
+
+ ri, ok := opt.(NDPRouteInformation)
+ if !ok {
+ t.Fatalf("got opt = %T, want = NDPRouteInformation", opt)
+ }
+
+ if got := ri.PrefixLength(); got != test.prefixLength {
+ t.Errorf("got PrefixLength() = %d, want = %d", got, test.prefixLength)
+ }
+ if got := ri.RoutePreference(); got != test.prf {
+ t.Errorf("got RoutePreference() = %d, want = %d", got, test.prf)
+ }
+ if got, want := ri.RouteLifetime(), time.Duration(test.lifetimeS)*time.Second; got != want {
+ t.Errorf("got RouteLifetime() = %s, want = %s", got, want)
+ }
+ if got, err := ri.Prefix(); err != nil {
+ t.Errorf("Prefix(): %s", err)
+ } else if got != test.expectedPrefix {
+ t.Errorf("got Prefix() = %s, want = %s", got, test.expectedPrefix)
+ }
+
+ // Iterator should not return anything else.
+ {
+ next, done, err := it.Next()
+ if err != nil {
+ t.Errorf("got Next() = (_, _, %s), want = (_, _, nil)", err)
+ }
+ if !done {
+ t.Error("got Next() = (_, false, _), want = (_, true, _)")
+ }
+ if next != nil {
+ t.Errorf("got Next() = (%x, _, _), want = (nil, _, _)", next)
+ }
+ }
+ })
+ }
+}
+
// TestNDPNeighborAdvert tests the functions of NDPNeighborAdvert.
func TestNDPNeighborAdvert(t *testing.T) {
b := []byte{
@@ -126,36 +345,83 @@ func TestNDPNeighborAdvert(t *testing.T) {
}
func TestNDPRouterAdvert(t *testing.T) {
- b := []byte{
- 64, 128, 1, 2,
- 3, 4, 5, 6,
- 7, 8, 9, 10,
+ tests := []struct {
+ hopLimit uint8
+ managedFlag, otherConfFlag bool
+ prf NDPRoutePreference
+ routerLifetimeS uint16
+ reachableTimeMS, retransTimerMS uint32
+ }{
+ {
+ hopLimit: 1,
+ managedFlag: false,
+ otherConfFlag: true,
+ prf: HighRoutePreference,
+ routerLifetimeS: 2,
+ reachableTimeMS: 3,
+ retransTimerMS: 4,
+ },
+ {
+ hopLimit: 64,
+ managedFlag: true,
+ otherConfFlag: false,
+ prf: LowRoutePreference,
+ routerLifetimeS: 258,
+ reachableTimeMS: 78492,
+ retransTimerMS: 13213,
+ },
}
- ra := NDPRouterAdvert(b)
+ for i, test := range tests {
+ t.Run(fmt.Sprintf("%d", i), func(t *testing.T) {
+ flags := uint8(0)
+ if test.managedFlag {
+ flags |= 1 << 7
+ }
+ if test.otherConfFlag {
+ flags |= 1 << 6
+ }
+ flags |= uint8(test.prf) << 3
+
+ b := []byte{
+ test.hopLimit, flags, 1, 2,
+ 3, 4, 5, 6,
+ 7, 8, 9, 10,
+ }
+ binary.BigEndian.PutUint16(b[2:], test.routerLifetimeS)
+ binary.BigEndian.PutUint32(b[4:], test.reachableTimeMS)
+ binary.BigEndian.PutUint32(b[8:], test.retransTimerMS)
- if got := ra.CurrHopLimit(); got != 64 {
- t.Errorf("got ra.CurrHopLimit = %d, want = 64", got)
- }
+ ra := NDPRouterAdvert(b)
- if got := ra.ManagedAddrConfFlag(); !got {
- t.Errorf("got ManagedAddrConfFlag = false, want = true")
- }
+ if got := ra.CurrHopLimit(); got != test.hopLimit {
+ t.Errorf("got ra.CurrHopLimit() = %d, want = %d", got, test.hopLimit)
+ }
- if got := ra.OtherConfFlag(); got {
- t.Errorf("got OtherConfFlag = true, want = false")
- }
+ if got := ra.ManagedAddrConfFlag(); got != test.managedFlag {
+ t.Errorf("got ManagedAddrConfFlag() = %t, want = %t", got, test.managedFlag)
+ }
- if got, want := ra.RouterLifetime(), time.Second*258; got != want {
- t.Errorf("got ra.RouterLifetime = %d, want = %d", got, want)
- }
+ if got := ra.OtherConfFlag(); got != test.otherConfFlag {
+ t.Errorf("got OtherConfFlag() = %t, want = %t", got, test.otherConfFlag)
+ }
- if got, want := ra.ReachableTime(), time.Millisecond*50595078; got != want {
- t.Errorf("got ra.ReachableTime = %d, want = %d", got, want)
- }
+ if got := ra.DefaultRouterPreference(); got != test.prf {
+ t.Errorf("got DefaultRouterPreference() = %d, want = %d", got, test.prf)
+ }
+
+ if got, want := ra.RouterLifetime(), time.Second*time.Duration(test.routerLifetimeS); got != want {
+ t.Errorf("got ra.RouterLifetime() = %d, want = %d", got, want)
+ }
+
+ if got, want := ra.ReachableTime(), time.Millisecond*time.Duration(test.reachableTimeMS); got != want {
+ t.Errorf("got ra.ReachableTime() = %d, want = %d", got, want)
+ }
- if got, want := ra.RetransTimer(), time.Millisecond*117967114; got != want {
- t.Errorf("got ra.RetransTimer = %d, want = %d", got, want)
+ if got, want := ra.RetransTimer(), time.Millisecond*time.Duration(test.retransTimerMS); got != want {
+ t.Errorf("got ra.RetransTimer() = %d, want = %d", got, want)
+ }
+ })
}
}
@@ -1451,3 +1717,32 @@ func TestNDPOptionsIter(t *testing.T) {
t.Errorf("got Next = (%x, _, _), want = (nil, _, _)", next)
}
}
+
+func TestNDPRoutePreferenceStringer(t *testing.T) {
+ p := NDPRoutePreference(0)
+ for {
+ var wantStr string
+ switch p {
+ case 0b01:
+ wantStr = "HighRoutePreference"
+ case 0b00:
+ wantStr = "MediumRoutePreference"
+ case 0b11:
+ wantStr = "LowRoutePreference"
+ case 0b10:
+ wantStr = "ReservedRoutePreference"
+ default:
+ wantStr = fmt.Sprintf("NDPRoutePreference(%d)", p)
+ }
+
+ if gotStr := p.String(); gotStr != wantStr {
+ t.Errorf("got NDPRoutePreference(%d).String() = %s, want = %s", p, gotStr, wantStr)
+ }
+
+ p++
+ if p == 0 {
+ // Overflowed, we hit all values.
+ break
+ }
+ }
+}
diff --git a/pkg/tcpip/header/tcp.go b/pkg/tcpip/header/tcp.go
index 8dabe3354..a75e51a28 100644
--- a/pkg/tcpip/header/tcp.go
+++ b/pkg/tcpip/header/tcp.go
@@ -390,6 +390,35 @@ func (b TCP) EncodePartial(partialChecksum, length uint16, seqnum, acknum uint32
b.SetChecksum(^checksum)
}
+// SetSourcePortWithChecksumUpdate implements ChecksummableTransport.
+func (b TCP) SetSourcePortWithChecksumUpdate(new uint16) {
+ old := b.SourcePort()
+ b.SetSourcePort(new)
+ b.SetChecksum(^checksumUpdate2ByteAlignedUint16(^b.Checksum(), old, new))
+}
+
+// SetDestinationPortWithChecksumUpdate implements ChecksummableTransport.
+func (b TCP) SetDestinationPortWithChecksumUpdate(new uint16) {
+ old := b.DestinationPort()
+ b.SetDestinationPort(new)
+ b.SetChecksum(^checksumUpdate2ByteAlignedUint16(^b.Checksum(), old, new))
+}
+
+// UpdateChecksumPseudoHeaderAddress implements ChecksummableTransport.
+func (b TCP) UpdateChecksumPseudoHeaderAddress(old, new tcpip.Address, fullChecksum bool) {
+ xsum := b.Checksum()
+ if fullChecksum {
+ xsum = ^xsum
+ }
+
+ xsum = checksumUpdate2ByteAlignedAddress(xsum, old, new)
+ if fullChecksum {
+ xsum = ^xsum
+ }
+
+ b.SetChecksum(xsum)
+}
+
// ParseSynOptions parses the options received in a SYN segment and returns the
// relevant ones. opts should point to the option part of the TCP header.
func ParseSynOptions(opts []byte, isAck bool) TCPSynOptions {
diff --git a/pkg/tcpip/header/udp.go b/pkg/tcpip/header/udp.go
index ae9d167ff..f69d53314 100644
--- a/pkg/tcpip/header/udp.go
+++ b/pkg/tcpip/header/udp.go
@@ -130,3 +130,32 @@ func (b UDP) Encode(u *UDPFields) {
binary.BigEndian.PutUint16(b[udpLength:], u.Length)
binary.BigEndian.PutUint16(b[udpChecksum:], u.Checksum)
}
+
+// SetSourcePortWithChecksumUpdate implements ChecksummableTransport.
+func (b UDP) SetSourcePortWithChecksumUpdate(new uint16) {
+ old := b.SourcePort()
+ b.SetSourcePort(new)
+ b.SetChecksum(^checksumUpdate2ByteAlignedUint16(^b.Checksum(), old, new))
+}
+
+// SetDestinationPortWithChecksumUpdate implements ChecksummableTransport.
+func (b UDP) SetDestinationPortWithChecksumUpdate(new uint16) {
+ old := b.DestinationPort()
+ b.SetDestinationPort(new)
+ b.SetChecksum(^checksumUpdate2ByteAlignedUint16(^b.Checksum(), old, new))
+}
+
+// UpdateChecksumPseudoHeaderAddress implements ChecksummableTransport.
+func (b UDP) UpdateChecksumPseudoHeaderAddress(old, new tcpip.Address, fullChecksum bool) {
+ xsum := b.Checksum()
+ if fullChecksum {
+ xsum = ^xsum
+ }
+
+ xsum = checksumUpdate2ByteAlignedAddress(xsum, old, new)
+ if fullChecksum {
+ xsum = ^xsum
+ }
+
+ b.SetChecksum(xsum)
+}
diff --git a/pkg/tcpip/link/fdbased/BUILD b/pkg/tcpip/link/fdbased/BUILD
index d971194e6..1d0163823 100644
--- a/pkg/tcpip/link/fdbased/BUILD
+++ b/pkg/tcpip/link/fdbased/BUILD
@@ -14,7 +14,6 @@ go_library(
],
visibility = ["//visibility:public"],
deps = [
- "//pkg/iovec",
"//pkg/sync",
"//pkg/tcpip",
"//pkg/tcpip/buffer",
diff --git a/pkg/tcpip/link/fdbased/endpoint.go b/pkg/tcpip/link/fdbased/endpoint.go
index bddb1d0a2..1b56d2b72 100644
--- a/pkg/tcpip/link/fdbased/endpoint.go
+++ b/pkg/tcpip/link/fdbased/endpoint.go
@@ -41,11 +41,9 @@ package fdbased
import (
"fmt"
- "math"
"sync/atomic"
"golang.org/x/sys/unix"
- "gvisor.dev/gvisor/pkg/iovec"
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
@@ -139,6 +137,20 @@ type endpoint struct {
// gsoKind is the supported kind of GSO.
gsoKind stack.SupportedGSO
+
+ // maxSyscallHeaderBytes has the same meaning as
+ // Options.MaxSyscallHeaderBytes.
+ maxSyscallHeaderBytes uintptr
+
+ // writevMaxIovs is the maximum number of iovecs that may be passed to
+ // rawfile.NonBlockingWriteIovec, as possibly limited by
+ // maxSyscallHeaderBytes. (No analogous limit is defined for
+ // rawfile.NonBlockingSendMMsg, since in that case the maximum number of
+ // iovecs also depends on the number of mmsghdrs. Instead, if sendBatch
+ // encounters a packet whose iovec count is limited by
+ // maxSyscallHeaderBytes, it falls back to writing the packet using writev
+ // via WritePacket.)
+ writevMaxIovs int
}
// Options specify the details about the fd-based endpoint to be created.
@@ -187,6 +199,11 @@ type Options struct {
// RXChecksumOffload if true, indicates that this endpoints capability
// set should include CapabilityRXChecksumOffload.
RXChecksumOffload bool
+
+ // If MaxSyscallHeaderBytes is non-zero, it is the maximum number of bytes
+ // of struct iovec, msghdr, and mmsghdr that may be passed by each host
+ // system call.
+ MaxSyscallHeaderBytes int
}
// fanoutID is used for AF_PACKET based endpoints to enable PACKET_FANOUT
@@ -196,8 +213,12 @@ type Options struct {
// option for an FD with a fanoutID already in use by another FD for a different
// NIC will return an EINVAL.
//
+// Since fanoutID must be unique within the network namespace, we start with
+// the PID to avoid collisions. The only way to be sure of avoiding collisions
+// is to run in a new network namespace.
+//
// Must be accessed using atomic operations.
-var fanoutID int32 = 0
+var fanoutID int32 = int32(unix.Getpid())
// New creates a new fd-based endpoint.
//
@@ -232,14 +253,25 @@ func New(opts *Options) (stack.LinkEndpoint, error) {
return nil, fmt.Errorf("opts.FD is empty, at least one FD must be specified")
}
+ if opts.MaxSyscallHeaderBytes < 0 {
+ return nil, fmt.Errorf("opts.MaxSyscallHeaderBytes is negative")
+ }
+
e := &endpoint{
- fds: opts.FDs,
- mtu: opts.MTU,
- caps: caps,
- closed: opts.ClosedFunc,
- addr: opts.Address,
- hdrSize: hdrSize,
- packetDispatchMode: opts.PacketDispatchMode,
+ fds: opts.FDs,
+ mtu: opts.MTU,
+ caps: caps,
+ closed: opts.ClosedFunc,
+ addr: opts.Address,
+ hdrSize: hdrSize,
+ packetDispatchMode: opts.PacketDispatchMode,
+ maxSyscallHeaderBytes: uintptr(opts.MaxSyscallHeaderBytes),
+ writevMaxIovs: rawfile.MaxIovs,
+ }
+ if e.maxSyscallHeaderBytes != 0 {
+ if max := int(e.maxSyscallHeaderBytes / rawfile.SizeofIovec); max < e.writevMaxIovs {
+ e.writevMaxIovs = max
+ }
}
// Increment fanoutID to ensure that we don't re-use the same fanoutID for
@@ -292,11 +324,6 @@ func createInboundDispatcher(e *endpoint, fd int, isSocket bool, fID int32) (lin
}
switch sa.(type) {
case *unix.SockaddrLinklayer:
- // See: PACKET_FANOUT_MAX in net/packet/internal.h
- const packetFanoutMax = 1 << 16
- if fID > packetFanoutMax {
- return nil, fmt.Errorf("host fanoutID limit exceeded, fanoutID must be <= %d", math.MaxUint16)
- }
// Enable PACKET_FANOUT mode if the underlying socket is of type
// AF_PACKET. We do not enable PACKET_FANOUT_FLAG_DEFRAG as that will
// prevent gvisor from receiving fragmented packets and the host does the
@@ -317,7 +344,7 @@ func createInboundDispatcher(e *endpoint, fd int, isSocket bool, fID int32) (lin
//
// See: https://github.com/torvalds/linux/blob/7acac4b3196caee5e21fb5ea53f8bc124e6a16fc/net/packet/af_packet.c#L3881
const fanoutType = unix.PACKET_FANOUT_HASH
- fanoutArg := int(fID) | fanoutType<<16
+ fanoutArg := (int(fID) & 0xffff) | fanoutType<<16
if err := unix.SetsockoptInt(fd, unix.SOL_PACKET, unix.PACKET_FANOUT, fanoutArg); err != nil {
return nil, fmt.Errorf("failed to enable PACKET_FANOUT option: %v", err)
}
@@ -472,9 +499,8 @@ func (e *endpoint) WritePacket(r stack.RouteInfo, protocol tcpip.NetworkProtocol
e.AddHeader(r.LocalLinkAddress, r.RemoteLinkAddress, protocol, pkt)
}
- var builder iovec.Builder
-
fd := e.fds[pkt.Hash%uint32(len(e.fds))]
+ var vnetHdrBuf []byte
if e.gsoKind == stack.HWGSOSupported {
vnetHdr := virtioNetHdr{}
if pkt.GSOOptions.Type != stack.GSONone {
@@ -496,71 +522,123 @@ func (e *endpoint) WritePacket(r stack.RouteInfo, protocol tcpip.NetworkProtocol
vnetHdr.gsoSize = pkt.GSOOptions.MSS
}
}
+ vnetHdrBuf = vnetHdr.marshal()
+ }
- vnetHdrBuf := vnetHdr.marshal()
- builder.Add(vnetHdrBuf)
+ views := pkt.Views()
+ numIovecs := len(views)
+ if len(vnetHdrBuf) != 0 {
+ numIovecs++
+ }
+ if numIovecs > e.writevMaxIovs {
+ numIovecs = e.writevMaxIovs
}
- for _, v := range pkt.Views() {
- builder.Add(v)
+ // Allocate small iovec arrays on the stack.
+ var iovecsArr [8]unix.Iovec
+ iovecs := iovecsArr[:0]
+ if numIovecs > len(iovecsArr) {
+ iovecs = make([]unix.Iovec, 0, numIovecs)
}
- return rawfile.NonBlockingWriteIovec(fd, builder.Build())
+ iovecs = rawfile.AppendIovecFromBytes(iovecs, vnetHdrBuf, numIovecs)
+ for _, v := range views {
+ iovecs = rawfile.AppendIovecFromBytes(iovecs, v, numIovecs)
+ }
+ return rawfile.NonBlockingWriteIovec(fd, iovecs)
}
-func (e *endpoint) sendBatch(batchFD int, batch []*stack.PacketBuffer) (int, tcpip.Error) {
+func (e *endpoint) sendBatch(batchFD int, pkts []*stack.PacketBuffer) (int, tcpip.Error) {
// Send a batch of packets through batchFD.
- mmsgHdrs := make([]rawfile.MMsgHdr, 0, len(batch))
- for _, pkt := range batch {
- if e.hdrSize > 0 {
- e.AddHeader(pkt.EgressRoute.LocalLinkAddress, pkt.EgressRoute.RemoteLinkAddress, pkt.NetworkProtocolNumber, pkt)
- }
+ mmsgHdrsStorage := make([]rawfile.MMsgHdr, 0, len(pkts))
+ packets := 0
+ for packets < len(pkts) {
+ mmsgHdrs := mmsgHdrsStorage
+ batch := pkts[packets:]
+ syscallHeaderBytes := uintptr(0)
+ for _, pkt := range batch {
+ if e.hdrSize > 0 {
+ e.AddHeader(pkt.EgressRoute.LocalLinkAddress, pkt.EgressRoute.RemoteLinkAddress, pkt.NetworkProtocolNumber, pkt)
+ }
- var vnetHdrBuf []byte
- if e.gsoKind == stack.HWGSOSupported {
- vnetHdr := virtioNetHdr{}
- if pkt.GSOOptions.Type != stack.GSONone {
- vnetHdr.hdrLen = uint16(pkt.HeaderSize())
- if pkt.GSOOptions.NeedsCsum {
- vnetHdr.flags = _VIRTIO_NET_HDR_F_NEEDS_CSUM
- vnetHdr.csumStart = header.EthernetMinimumSize + pkt.GSOOptions.L3HdrLen
- vnetHdr.csumOffset = pkt.GSOOptions.CsumOffset
- }
- if pkt.GSOOptions.Type != stack.GSONone && uint16(pkt.Data().Size()) > pkt.GSOOptions.MSS {
- switch pkt.GSOOptions.Type {
- case stack.GSOTCPv4:
- vnetHdr.gsoType = _VIRTIO_NET_HDR_GSO_TCPV4
- case stack.GSOTCPv6:
- vnetHdr.gsoType = _VIRTIO_NET_HDR_GSO_TCPV6
- default:
- panic(fmt.Sprintf("Unknown gso type: %v", pkt.GSOOptions.Type))
+ var vnetHdrBuf []byte
+ if e.gsoKind == stack.HWGSOSupported {
+ vnetHdr := virtioNetHdr{}
+ if pkt.GSOOptions.Type != stack.GSONone {
+ vnetHdr.hdrLen = uint16(pkt.HeaderSize())
+ if pkt.GSOOptions.NeedsCsum {
+ vnetHdr.flags = _VIRTIO_NET_HDR_F_NEEDS_CSUM
+ vnetHdr.csumStart = header.EthernetMinimumSize + pkt.GSOOptions.L3HdrLen
+ vnetHdr.csumOffset = pkt.GSOOptions.CsumOffset
+ }
+ if pkt.GSOOptions.Type != stack.GSONone && uint16(pkt.Data().Size()) > pkt.GSOOptions.MSS {
+ switch pkt.GSOOptions.Type {
+ case stack.GSOTCPv4:
+ vnetHdr.gsoType = _VIRTIO_NET_HDR_GSO_TCPV4
+ case stack.GSOTCPv6:
+ vnetHdr.gsoType = _VIRTIO_NET_HDR_GSO_TCPV6
+ default:
+ panic(fmt.Sprintf("Unknown gso type: %v", pkt.GSOOptions.Type))
+ }
+ vnetHdr.gsoSize = pkt.GSOOptions.MSS
}
- vnetHdr.gsoSize = pkt.GSOOptions.MSS
}
+ vnetHdrBuf = vnetHdr.marshal()
}
- vnetHdrBuf = vnetHdr.marshal()
- }
- var builder iovec.Builder
- builder.Add(vnetHdrBuf)
- for _, v := range pkt.Views() {
- builder.Add(v)
- }
- iovecs := builder.Build()
+ views := pkt.Views()
+ numIovecs := len(views)
+ if len(vnetHdrBuf) != 0 {
+ numIovecs++
+ }
+ if numIovecs > rawfile.MaxIovs {
+ numIovecs = rawfile.MaxIovs
+ }
+ if e.maxSyscallHeaderBytes != 0 {
+ syscallHeaderBytes += rawfile.SizeofMMsgHdr + uintptr(numIovecs)*rawfile.SizeofIovec
+ if syscallHeaderBytes > e.maxSyscallHeaderBytes {
+ // We can't fit this packet into this call to sendmmsg().
+ // We could potentially do so if we reduced numIovecs
+ // further, but this might incur considerable extra
+ // copying. Leave it to the next batch instead.
+ break
+ }
+ }
- var mmsgHdr rawfile.MMsgHdr
- mmsgHdr.Msg.Iov = &iovecs[0]
- mmsgHdr.Msg.SetIovlen((len(iovecs)))
- mmsgHdrs = append(mmsgHdrs, mmsgHdr)
- }
+ // We can't easily allocate iovec arrays on the stack here since
+ // they will escape this loop iteration via mmsgHdrs.
+ iovecs := make([]unix.Iovec, 0, numIovecs)
+ iovecs = rawfile.AppendIovecFromBytes(iovecs, vnetHdrBuf, numIovecs)
+ for _, v := range views {
+ iovecs = rawfile.AppendIovecFromBytes(iovecs, v, numIovecs)
+ }
- packets := 0
- for len(mmsgHdrs) > 0 {
- sent, err := rawfile.NonBlockingSendMMsg(batchFD, mmsgHdrs)
- if err != nil {
- return packets, err
+ var mmsgHdr rawfile.MMsgHdr
+ mmsgHdr.Msg.Iov = &iovecs[0]
+ mmsgHdr.Msg.SetIovlen(len(iovecs))
+ mmsgHdrs = append(mmsgHdrs, mmsgHdr)
+ }
+
+ if len(mmsgHdrs) == 0 {
+ // We can't fit batch[0] into a mmsghdr while staying under
+ // e.maxSyscallHeaderBytes. Use WritePacket, which will avoid the
+ // mmsghdr (by using writev) and re-buffer iovecs more aggressively
+ // if necessary (by using e.writevMaxIovs instead of
+ // rawfile.MaxIovs).
+ pkt := batch[0]
+ if err := e.WritePacket(pkt.EgressRoute, pkt.NetworkProtocolNumber, pkt); err != nil {
+ return packets, err
+ }
+ packets++
+ } else {
+ for len(mmsgHdrs) > 0 {
+ sent, err := rawfile.NonBlockingSendMMsg(batchFD, mmsgHdrs)
+ if err != nil {
+ return packets, err
+ }
+ packets += sent
+ mmsgHdrs = mmsgHdrs[sent:]
+ }
}
- packets += sent
- mmsgHdrs = mmsgHdrs[sent:]
}
return packets, nil
@@ -678,8 +756,9 @@ func NewInjectable(fd int, mtu uint32, capabilities stack.LinkEndpointCapabiliti
unix.SetNonblock(fd, true)
return &InjectableEndpoint{endpoint: endpoint{
- fds: []int{fd},
- mtu: mtu,
- caps: capabilities,
+ fds: []int{fd},
+ mtu: mtu,
+ caps: capabilities,
+ writevMaxIovs: rawfile.MaxIovs,
}}
}
diff --git a/pkg/tcpip/link/rawfile/rawfile_unsafe.go b/pkg/tcpip/link/rawfile/rawfile_unsafe.go
index ba92aedbc..43fe57830 100644
--- a/pkg/tcpip/link/rawfile/rawfile_unsafe.go
+++ b/pkg/tcpip/link/rawfile/rawfile_unsafe.go
@@ -19,12 +19,66 @@
package rawfile
import (
+ "reflect"
"unsafe"
"golang.org/x/sys/unix"
"gvisor.dev/gvisor/pkg/tcpip"
)
+// SizeofIovec is the size of a unix.Iovec in bytes.
+const SizeofIovec = unsafe.Sizeof(unix.Iovec{})
+
+// MaxIovs is UIO_MAXIOV, the maximum number of iovecs that may be passed to a
+// host system call in a single array.
+const MaxIovs = 1024
+
+// IovecFromBytes returns a unix.Iovec representing bs.
+//
+// Preconditions: len(bs) > 0.
+func IovecFromBytes(bs []byte) unix.Iovec {
+ iov := unix.Iovec{
+ Base: &bs[0],
+ }
+ iov.SetLen(len(bs))
+ return iov
+}
+
+func bytesFromIovec(iov unix.Iovec) (bs []byte) {
+ sh := (*reflect.SliceHeader)(unsafe.Pointer(&bs))
+ sh.Data = uintptr(unsafe.Pointer(iov.Base))
+ sh.Len = int(iov.Len)
+ sh.Cap = int(iov.Len)
+ return
+}
+
+// AppendIovecFromBytes returns append(iovs, IovecFromBytes(bs)). If len(bs) ==
+// 0, AppendIovecFromBytes returns iovs without modification. If len(iovs) >=
+// max, AppendIovecFromBytes replaces the final iovec in iovs with one that
+// also includes the contents of bs. Note that this implies that
+// AppendIovecFromBytes is only usable when the returned iovec slice is used as
+// the source of a write.
+func AppendIovecFromBytes(iovs []unix.Iovec, bs []byte, max int) []unix.Iovec {
+ if len(bs) == 0 {
+ return iovs
+ }
+ if len(iovs) < max {
+ return append(iovs, IovecFromBytes(bs))
+ }
+ iovs[len(iovs)-1] = IovecFromBytes(append(bytesFromIovec(iovs[len(iovs)-1]), bs...))
+ return iovs
+}
+
+// MMsgHdr represents the mmsg_hdr structure required by recvmmsg() on linux.
+type MMsgHdr struct {
+ Msg unix.Msghdr
+ Len uint32
+ _ [4]byte
+}
+
+// SizeofMMsgHdr is the size of a MMsgHdr in bytes.
+const SizeofMMsgHdr = unsafe.Sizeof(MMsgHdr{})
+
// GetMTU determines the MTU of a network interface device.
func GetMTU(name string) (uint32, error) {
fd, err := unix.Socket(unix.AF_UNIX, unix.SOCK_DGRAM, 0)
@@ -137,13 +191,6 @@ func BlockingReadv(fd int, iovecs []unix.Iovec) (int, tcpip.Error) {
}
}
-// MMsgHdr represents the mmsg_hdr structure required by recvmmsg() on linux.
-type MMsgHdr struct {
- Msg unix.Msghdr
- Len uint32
- _ [4]byte
-}
-
// BlockingRecvMMsg reads from a file descriptor that is set up as non-blocking
// and stores the received messages in a slice of MMsgHdr structures. If no data
// is available, it will block in a poll() syscall until the file descriptor
diff --git a/pkg/tcpip/link/tun/BUILD b/pkg/tcpip/link/tun/BUILD
index 7656cca6a..4758a99ad 100644
--- a/pkg/tcpip/link/tun/BUILD
+++ b/pkg/tcpip/link/tun/BUILD
@@ -26,6 +26,7 @@ go_library(
deps = [
"//pkg/abi/linux",
"//pkg/context",
+ "//pkg/errors/linuxerr",
"//pkg/log",
"//pkg/refs",
"//pkg/refsvfs2",
diff --git a/pkg/tcpip/link/tun/device.go b/pkg/tcpip/link/tun/device.go
index 36af2a029..f3444e8b5 100644
--- a/pkg/tcpip/link/tun/device.go
+++ b/pkg/tcpip/link/tun/device.go
@@ -18,6 +18,7 @@ import (
"fmt"
"gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/errors/linuxerr"
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/tcpip"
@@ -88,12 +89,12 @@ func (d *Device) SetIff(s *stack.Stack, name string, flags Flags) error {
defer d.mu.Unlock()
if d.endpoint != nil {
- return syserror.EINVAL
+ return linuxerr.EINVAL
}
// Input validation.
if flags.TAP && flags.TUN || !flags.TAP && !flags.TUN {
- return syserror.EINVAL
+ return linuxerr.EINVAL
}
prefix := "tun"
@@ -108,7 +109,7 @@ func (d *Device) SetIff(s *stack.Stack, name string, flags Flags) error {
endpoint, err := attachOrCreateNIC(s, name, prefix, linkCaps)
if err != nil {
- return syserror.EINVAL
+ return linuxerr.EINVAL
}
d.endpoint = endpoint
@@ -159,7 +160,7 @@ func attachOrCreateNIC(s *stack.Stack, name, prefix string, linkCaps stack.LinkE
// Race detected: A NIC has been created in between.
continue
default:
- return nil, syserror.EINVAL
+ return nil, linuxerr.EINVAL
}
}
}
@@ -170,7 +171,7 @@ func (d *Device) Write(data []byte) (int64, error) {
endpoint := d.endpoint
d.mu.RUnlock()
if endpoint == nil {
- return 0, syserror.EBADFD
+ return 0, linuxerr.EBADFD
}
if !endpoint.IsAttached() {
return 0, syserror.EIO
@@ -207,6 +208,15 @@ func (d *Device) Write(data []byte) (int64, error) {
protocol = pktInfoHdr.Protocol()
case ethHdr != nil:
protocol = ethHdr.Type()
+ case d.flags.TUN:
+ // TUN interface with IFF_NO_PI enabled, thus
+ // we need to determine protocol from version field
+ version := data[0] >> 4
+ if version == 4 {
+ protocol = header.IPv4ProtocolNumber
+ } else if version == 6 {
+ protocol = header.IPv6ProtocolNumber
+ }
}
// Try to determine remote link address, default zero.
@@ -233,7 +243,7 @@ func (d *Device) Read() ([]byte, error) {
endpoint := d.endpoint
d.mu.RUnlock()
if endpoint == nil {
- return nil, syserror.EBADFD
+ return nil, linuxerr.EBADFD
}
for {
@@ -264,13 +274,6 @@ func (d *Device) encodePkt(info *channel.PacketInfo) (buffer.View, bool) {
vv.AppendView(buffer.View(hdr))
}
- // If the packet does not already have link layer header, and the route
- // does not exist, we can't compute it. This is possibly a raw packet, tun
- // device doesn't support this at the moment.
- if info.Pkt.LinkHeader().View().IsEmpty() && len(info.Route.RemoteLinkAddress) == 0 {
- return nil, false
- }
-
// Ethernet header (TAP only).
if d.flags.TAP {
// Add ethernet header if not provided.
diff --git a/pkg/tcpip/network/internal/ip/generic_multicast_protocol_test.go b/pkg/tcpip/network/internal/ip/generic_multicast_protocol_test.go
index 0b51563cd..1261ad414 100644
--- a/pkg/tcpip/network/internal/ip/generic_multicast_protocol_test.go
+++ b/pkg/tcpip/network/internal/ip/generic_multicast_protocol_test.go
@@ -126,7 +126,7 @@ func (m *mockMulticastGroupProtocol) sendQueuedReports() {
// Precondition: m.mu must be read locked.
func (m *mockMulticastGroupProtocol) Enabled() bool {
if m.mu.TryLock() {
- m.mu.Unlock()
+ m.mu.Unlock() // +checklocksforce: TryLock.
m.t.Fatal("got write lock, expected to not take the lock; generic multicast protocol must take the read or write lock before calling Enabled")
}
@@ -138,11 +138,11 @@ func (m *mockMulticastGroupProtocol) Enabled() bool {
// Precondition: m.mu must be locked.
func (m *mockMulticastGroupProtocol) SendReport(groupAddress tcpip.Address) (bool, tcpip.Error) {
if m.mu.TryLock() {
- m.mu.Unlock()
+ m.mu.Unlock() // +checklocksforce: TryLock.
m.t.Fatalf("got write lock, expected to not take the lock; generic multicast protocol must take the write lock before sending report for %s", groupAddress)
}
if m.mu.TryRLock() {
- m.mu.RUnlock()
+ m.mu.RUnlock() // +checklocksforce: TryLock.
m.t.Fatalf("got read lock, expected to not take the lock; generic multicast protocol must take the write lock before sending report for %s", groupAddress)
}
@@ -155,11 +155,11 @@ func (m *mockMulticastGroupProtocol) SendReport(groupAddress tcpip.Address) (boo
// Precondition: m.mu must be locked.
func (m *mockMulticastGroupProtocol) SendLeave(groupAddress tcpip.Address) tcpip.Error {
if m.mu.TryLock() {
- m.mu.Unlock()
+ m.mu.Unlock() // +checklocksforce: TryLock.
m.t.Fatalf("got write lock, expected to not take the lock; generic multicast protocol must take the write lock before sending leave for %s", groupAddress)
}
if m.mu.TryRLock() {
- m.mu.RUnlock()
+ m.mu.RUnlock() // +checklocksforce: TryLock.
m.t.Fatalf("got read lock, expected to not take the lock; generic multicast protocol must take the write lock before sending leave for %s", groupAddress)
}
diff --git a/pkg/tcpip/network/ip_test.go b/pkg/tcpip/network/ip_test.go
index bd63e0289..771b9173a 100644
--- a/pkg/tcpip/network/ip_test.go
+++ b/pkg/tcpip/network/ip_test.go
@@ -88,6 +88,7 @@ type testObject struct {
dataCalls int
controlCalls int
+ rawCalls int
}
// checkValues verifies that the transport protocol, data contents, src & dst
@@ -148,6 +149,10 @@ func (t *testObject) DeliverTransportError(local, remote tcpip.Address, net tcpi
t.controlCalls++
}
+func (t *testObject) DeliverRawPacket(tcpip.TransportProtocolNumber, *stack.PacketBuffer) {
+ t.rawCalls++
+}
+
// Attach is only implemented to satisfy the LinkEndpoint interface.
func (*testObject) Attach(stack.NetworkDispatcher) {}
@@ -717,7 +722,10 @@ func TestReceive(t *testing.T) {
}
test.handlePacket(t, ep, &nic)
if nic.testObject.dataCalls != 1 {
- t.Errorf("Bad number of data calls: got %x, want 1", nic.testObject.dataCalls)
+ t.Errorf("Bad number of data calls: got %d, want 1", nic.testObject.dataCalls)
+ }
+ if nic.testObject.rawCalls != 1 {
+ t.Errorf("Bad number of raw calls: got %d, want 1", nic.testObject.rawCalls)
}
if got := stat.Value(); got != 1 {
t.Errorf("got s.Stats().IP.PacketsReceived.Value() = %d, want = 1", got)
@@ -968,7 +976,10 @@ func TestIPv4FragmentationReceive(t *testing.T) {
ep.HandlePacket(pkt)
if nic.testObject.dataCalls != 0 {
- t.Fatalf("Bad number of data calls: got %x, want 0", nic.testObject.dataCalls)
+ t.Fatalf("Bad number of data calls: got %d, want 0", nic.testObject.dataCalls)
+ }
+ if nic.testObject.rawCalls != 0 {
+ t.Errorf("Bad number of raw calls: got %d, want 0", nic.testObject.rawCalls)
}
// Send second segment.
@@ -977,7 +988,10 @@ func TestIPv4FragmentationReceive(t *testing.T) {
})
ep.HandlePacket(pkt)
if nic.testObject.dataCalls != 1 {
- t.Fatalf("Bad number of data calls: got %x, want 1", nic.testObject.dataCalls)
+ t.Fatalf("Bad number of data calls: got %d, want 1", nic.testObject.dataCalls)
+ }
+ if nic.testObject.rawCalls != 1 {
+ t.Errorf("Bad number of raw calls: got %d, want 1", nic.testObject.rawCalls)
}
}
@@ -1310,7 +1324,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) {
Protocol: transportProto,
TTL: ipv4.DefaultTTL,
SrcAddr: src,
- DstAddr: header.IPv4Any,
+ DstAddr: remoteIPv4Addr,
})
return hdr.View().ToVectorisedView()
},
@@ -1351,7 +1365,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) {
Protocol: transportProto,
TTL: ipv4.DefaultTTL,
SrcAddr: src,
- DstAddr: header.IPv4Any,
+ DstAddr: remoteIPv4Addr,
})
ip.SetHeaderLength(header.IPv4MinimumSize - 1)
return hdr.View().ToVectorisedView()
@@ -1370,7 +1384,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) {
Protocol: transportProto,
TTL: ipv4.DefaultTTL,
SrcAddr: src,
- DstAddr: header.IPv4Any,
+ DstAddr: remoteIPv4Addr,
})
return buffer.View(ip[:len(ip)-1]).ToVectorisedView()
},
@@ -1388,7 +1402,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) {
Protocol: transportProto,
TTL: ipv4.DefaultTTL,
SrcAddr: src,
- DstAddr: header.IPv4Any,
+ DstAddr: remoteIPv4Addr,
})
return buffer.View(ip).ToVectorisedView()
},
@@ -1430,7 +1444,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) {
Protocol: transportProto,
TTL: ipv4.DefaultTTL,
SrcAddr: src,
- DstAddr: header.IPv4Any,
+ DstAddr: remoteIPv4Addr,
Options: ipv4Options,
})
return hdr.View().ToVectorisedView()
@@ -1469,7 +1483,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) {
Protocol: transportProto,
TTL: ipv4.DefaultTTL,
SrcAddr: src,
- DstAddr: header.IPv4Any,
+ DstAddr: remoteIPv4Addr,
Options: ipv4Options,
})
vv := buffer.View(ip).ToVectorisedView()
@@ -1515,7 +1529,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) {
TransportProtocol: transportProto,
HopLimit: ipv6.DefaultTTL,
SrcAddr: src,
- DstAddr: header.IPv4Any,
+ DstAddr: remoteIPv6Addr,
})
return hdr.View().ToVectorisedView()
},
@@ -1560,7 +1574,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) {
TransportProtocol: tcpip.TransportProtocolNumber(header.IPv6FragmentExtHdrIdentifier),
HopLimit: ipv6.DefaultTTL,
SrcAddr: src,
- DstAddr: header.IPv4Any,
+ DstAddr: remoteIPv6Addr,
})
return hdr.View().ToVectorisedView()
},
@@ -1595,7 +1609,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) {
TransportProtocol: transportProto,
HopLimit: ipv6.DefaultTTL,
SrcAddr: src,
- DstAddr: header.IPv4Any,
+ DstAddr: remoteIPv6Addr,
})
return buffer.View(ip).ToVectorisedView()
},
@@ -1630,7 +1644,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) {
TransportProtocol: transportProto,
HopLimit: ipv6.DefaultTTL,
SrcAddr: src,
- DstAddr: header.IPv4Any,
+ DstAddr: remoteIPv4Addr,
})
return buffer.View(ip[:len(ip)-1]).ToVectorisedView()
},
diff --git a/pkg/tcpip/network/ipv4/icmp.go b/pkg/tcpip/network/ipv4/icmp.go
index 5f6b0c6af..2aa38eb98 100644
--- a/pkg/tcpip/network/ipv4/icmp.go
+++ b/pkg/tcpip/network/ipv4/icmp.go
@@ -173,9 +173,8 @@ func (e *endpoint) handleControl(errInfo stack.TransportError, pkt *stack.Packet
func (e *endpoint) handleICMP(pkt *stack.PacketBuffer) {
received := e.stats.icmp.packetsReceived
- // TODO(gvisor.dev/issue/170): ICMP packets don't have their
- // TransportHeader fields set. See icmp/protocol.go:protocol.Parse for a
- // full explanation.
+ // ICMP packets don't have their TransportHeader fields set. See
+ // icmp/protocol.go:protocol.Parse for a full explanation.
v, ok := pkt.Data().PullUp(header.ICMPv4MinimumSize)
if !ok {
received.invalid.Increment()
diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go
index 6bee55634..44c85bdb8 100644
--- a/pkg/tcpip/network/ipv4/ipv4.go
+++ b/pkg/tcpip/network/ipv4/ipv4.go
@@ -429,9 +429,9 @@ func (e *endpoint) WritePacket(r *stack.Route, params stack.NetworkHeaderParams,
// based on destination address and do not send the packet to link
// layer.
//
- // TODO(gvisor.dev/issue/170): We should do this for every
- // packet, rather than only NATted packets, but removing this check
- // short circuits broadcasts before they are sent out to other hosts.
+ // We should do this for every packet, rather than only NATted packets, but
+ // removing this check short circuits broadcasts before they are sent out to
+ // other hosts.
if pkt.NatDone {
netHeader := header.IPv4(pkt.NetworkHeader().View())
if ep := e.protocol.findEndpointWithAddress(netHeader.DestinationAddress()); ep != nil {
@@ -614,10 +614,6 @@ func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack.PacketBu
ipH.SetSourceAddress(r.LocalAddress())
}
- // Set the destination. If the packet already included a destination, it will
- // be part of the route anyways.
- ipH.SetDestinationAddress(r.RemoteAddress())
-
// Set the packet ID when zero.
if ipH.ID() == 0 {
// RFC 6864 section 4.3 mandates uniqueness of ID values for
@@ -720,7 +716,8 @@ func (e *endpoint) forwardPacket(pkt *stack.PacketBuffer) ip.ForwardingError {
return nil
}
- ep.handleValidatedPacket(h, pkt)
+ // The packet originally arrived on e so provide its NIC as the input NIC.
+ ep.handleValidatedPacket(h, pkt, e.nic.Name() /* inNICName */)
return nil
}
@@ -836,7 +833,7 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) {
}
}
- e.handleValidatedPacket(h, pkt)
+ e.handleValidatedPacket(h, pkt, e.nic.Name() /* inNICName */)
}
// handleLocalPacket is like HandlePacket except it does not perform the
@@ -855,10 +852,17 @@ func (e *endpoint) handleLocalPacket(pkt *stack.PacketBuffer, canSkipRXChecksum
return
}
- e.handleValidatedPacket(h, pkt)
+ e.handleValidatedPacket(h, pkt, e.nic.Name() /* inNICName */)
}
-func (e *endpoint) handleValidatedPacket(h header.IPv4, pkt *stack.PacketBuffer) {
+func (e *endpoint) handleValidatedPacket(h header.IPv4, pkt *stack.PacketBuffer, inNICName string) {
+ // Raw socket packets are delivered based solely on the transport protocol
+ // number. We only require that the packet be valid IPv4, and that they not
+ // be fragmented.
+ if !h.More() && h.FragmentOffset() == 0 {
+ e.dispatcher.DeliverRawPacket(h.TransportProtocol(), pkt)
+ }
+
pkt.NICID = e.nic.ID()
stats := e.stats
stats.ip.ValidPacketsReceived.Increment()
@@ -920,8 +924,7 @@ func (e *endpoint) handleValidatedPacket(h header.IPv4, pkt *stack.PacketBuffer)
// iptables filtering. All packets that reach here are intended for
// this machine and will not be forwarded.
- inNicName := e.protocol.stack.FindNICNameFromID(e.nic.ID())
- if ok := e.protocol.stack.IPTables().Check(stack.Input, pkt, nil, "" /* preroutingAddr */, inNicName, "" /* outNicName */); !ok {
+ if ok := e.protocol.stack.IPTables().Check(stack.Input, pkt, nil, "" /* preroutingAddr */, inNICName, "" /* outNicName */); !ok {
// iptables is telling us to drop the packet.
stats.ip.IPTablesInputDropped.Increment()
return
@@ -995,6 +998,9 @@ func (e *endpoint) handleValidatedPacket(h header.IPv4, pkt *stack.PacketBuffer)
// to do it here.
h.SetTotalLength(uint16(pkt.Data().Size() + len(h)))
h.SetFlagsFragmentOffset(0, 0)
+
+ // Now that the packet is reassembled, it can be sent to raw sockets.
+ e.dispatcher.DeliverRawPacket(h.TransportProtocol(), pkt)
}
stats.ip.PacketsDelivered.Increment()
diff --git a/pkg/tcpip/network/ipv6/icmp.go b/pkg/tcpip/network/ipv6/icmp.go
index 23fc94303..94caaae6c 100644
--- a/pkg/tcpip/network/ipv6/icmp.go
+++ b/pkg/tcpip/network/ipv6/icmp.go
@@ -285,8 +285,8 @@ func isMLDValid(pkt *stack.PacketBuffer, iph header.IPv6, routerAlert *header.IP
func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool, routerAlert *header.IPv6RouterAlertOption) {
sent := e.stats.icmp.packetsSent
received := e.stats.icmp.packetsReceived
- // TODO(gvisor.dev/issue/170): ICMP packets don't have their TransportHeader
- // fields set. See icmp/protocol.go:protocol.Parse for a full explanation.
+ // ICMP packets don't have their TransportHeader fields set. See
+ // icmp/protocol.go:protocol.Parse for a full explanation.
v, ok := pkt.Data().PullUp(header.ICMPv6HeaderSize)
if !ok {
received.invalid.Increment()
diff --git a/pkg/tcpip/network/ipv6/icmp_test.go b/pkg/tcpip/network/ipv6/icmp_test.go
index c2e9544c1..7c2a3e56b 100644
--- a/pkg/tcpip/network/ipv6/icmp_test.go
+++ b/pkg/tcpip/network/ipv6/icmp_test.go
@@ -90,6 +90,10 @@ func (*stubDispatcher) DeliverTransportPacket(tcpip.TransportProtocolNumber, *st
return stack.TransportPacketHandled
}
+func (*stubDispatcher) DeliverRawPacket(tcpip.TransportProtocolNumber, *stack.PacketBuffer) {
+ // No-op.
+}
+
var _ stack.NetworkInterface = (*testInterface)(nil)
type testInterface struct {
diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go
index 6103574f7..b1aec5312 100644
--- a/pkg/tcpip/network/ipv6/ipv6.go
+++ b/pkg/tcpip/network/ipv6/ipv6.go
@@ -344,7 +344,10 @@ func (e *endpoint) onAddressAssignedLocked(addr tcpip.Address) {
func (e *endpoint) InvalidateDefaultRouter(rtr tcpip.Address) {
e.mu.Lock()
defer e.mu.Unlock()
- e.mu.ndp.invalidateDefaultRouter(rtr)
+
+ // We represent default routers with a default (off-link) route through the
+ // router.
+ e.mu.ndp.invalidateOffLinkRoute(offLinkRoute{dest: header.IPv6EmptySubnet, router: rtr})
}
// SetNDPConfigurations implements NDPEndpoint.
@@ -755,9 +758,9 @@ func (e *endpoint) WritePacket(r *stack.Route, params stack.NetworkHeaderParams,
// based on destination address and do not send the packet to link
// layer.
//
- // TODO(gvisor.dev/issue/170): We should do this for every
- // packet, rather than only NATted packets, but removing this check
- // short circuits broadcasts before they are sent out to other hosts.
+ // We should do this for every packet, rather than only NATted packets, but
+ // removing this check short circuits broadcasts before they are sent out to
+ // other hosts.
if pkt.NatDone {
netHeader := header.IPv6(pkt.NetworkHeader().View())
if ep := e.protocol.findEndpointWithAddress(netHeader.DestinationAddress()); ep != nil {
@@ -928,10 +931,6 @@ func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack.PacketBu
ipH.SetSourceAddress(r.LocalAddress())
}
- // Set the destination. If the packet already included a destination, it will
- // be part of the route anyways.
- ipH.SetDestinationAddress(r.RemoteAddress())
-
// Populate the packet buffer's network header and don't allow an invalid
// packet to be sent.
//
@@ -991,7 +990,8 @@ func (e *endpoint) forwardPacket(pkt *stack.PacketBuffer) ip.ForwardingError {
return nil
}
- ep.handleValidatedPacket(h, pkt)
+ // The packet originally arrived on e so provide its NIC as the input NIC.
+ ep.handleValidatedPacket(h, pkt, e.nic.Name() /* inNICName */)
return nil
}
@@ -1104,7 +1104,7 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) {
}
}
- e.handleValidatedPacket(h, pkt)
+ e.handleValidatedPacket(h, pkt, e.nic.Name() /* inNICName */)
}
// handleLocalPacket is like HandlePacket except it does not perform the
@@ -1123,10 +1123,14 @@ func (e *endpoint) handleLocalPacket(pkt *stack.PacketBuffer, canSkipRXChecksum
return
}
- e.handleValidatedPacket(h, pkt)
+ e.handleValidatedPacket(h, pkt, e.nic.Name() /* inNICName */)
}
-func (e *endpoint) handleValidatedPacket(h header.IPv6, pkt *stack.PacketBuffer) {
+func (e *endpoint) handleValidatedPacket(h header.IPv6, pkt *stack.PacketBuffer, inNICName string) {
+ // Raw socket packets are delivered based solely on the transport protocol
+ // number. We only require that the packet be valid IPv6.
+ e.dispatcher.DeliverRawPacket(h.TransportProtocol(), pkt)
+
pkt.NICID = e.nic.ID()
stats := e.stats.ip
stats.ValidPacketsReceived.Increment()
@@ -1175,8 +1179,7 @@ func (e *endpoint) handleValidatedPacket(h header.IPv6, pkt *stack.PacketBuffer)
// iptables filtering. All packets that reach here are intended for
// this machine and need not be forwarded.
- inNicName := e.protocol.stack.FindNICNameFromID(e.nic.ID())
- if ok := e.protocol.stack.IPTables().Check(stack.Input, pkt, nil, "" /* preroutingAddr */, inNicName, "" /* outNicName */); !ok {
+ if ok := e.protocol.stack.IPTables().Check(stack.Input, pkt, nil, "" /* preroutingAddr */, inNICName, "" /* outNicName */); !ok {
// iptables is telling us to drop the packet.
stats.IPTablesInputDropped.Increment()
return
@@ -1627,8 +1630,8 @@ func (e *endpoint) NetworkProtocolNumber() tcpip.NetworkProtocolNumber {
func (e *endpoint) AddAndAcquirePermanentAddress(addr tcpip.AddressWithPrefix, peb stack.PrimaryEndpointBehavior, configType stack.AddressConfigType, deprecated bool) (stack.AddressEndpoint, tcpip.Error) {
// TODO(b/169350103): add checks here after making sure we no longer receive
// an empty address.
- e.mu.RLock()
- defer e.mu.RUnlock()
+ e.mu.Lock()
+ defer e.mu.Unlock()
return e.addAndAcquirePermanentAddressLocked(addr, peb, configType, deprecated)
}
@@ -1669,8 +1672,8 @@ func (e *endpoint) addAndAcquirePermanentAddressLocked(addr tcpip.AddressWithPre
// RemovePermanentAddress implements stack.AddressableEndpoint.
func (e *endpoint) RemovePermanentAddress(addr tcpip.Address) tcpip.Error {
- e.mu.RLock()
- defer e.mu.RUnlock()
+ e.mu.Lock()
+ defer e.mu.Unlock()
addressEndpoint := e.getAddressRLocked(addr)
if addressEndpoint == nil || !addressEndpoint.GetKind().IsPermanent() {
diff --git a/pkg/tcpip/network/ipv6/ndp.go b/pkg/tcpip/network/ipv6/ndp.go
index 851cd6e75..8837d66d8 100644
--- a/pkg/tcpip/network/ipv6/ndp.go
+++ b/pkg/tcpip/network/ipv6/ndp.go
@@ -54,6 +54,11 @@ const (
// Advertisements, as a host.
defaultDiscoverDefaultRouters = true
+ // defaultDiscoverMoreSpecificRoutes is the default configuration for
+ // whether or not to discover more-specific routes from incoming Router
+ // Advertisements, as a host.
+ defaultDiscoverMoreSpecificRoutes = true
+
// defaultDiscoverOnLinkPrefixes is the default configuration for
// whether or not to discover on-link prefixes from incoming Router
// Advertisements' Prefix Information option, as a host.
@@ -78,13 +83,13 @@ const (
// we cannot have a negative delay.
minimumMaxRtrSolicitationDelay = 0
- // MaxDiscoveredDefaultRouters is the maximum number of discovered
- // default routers. The stack should stop discovering new routers after
- // discovering MaxDiscoveredDefaultRouters routers.
+ // MaxDiscoveredOffLinkRoutes is the maximum number of discovered off-link
+ // routes. The stack should stop discovering new off-link routes after
+ // this limit is reached.
//
// This value MUST be at minimum 2 as per RFC 4861 section 6.3.4, and
// SHOULD be more.
- MaxDiscoveredDefaultRouters = 10
+ MaxDiscoveredOffLinkRoutes = 10
// MaxDiscoveredOnLinkPrefixes is the maximum number of discovered
// on-link prefixes. The stack should stop discovering new on-link
@@ -127,25 +132,17 @@ const (
// maxSLAACAddrLocalRegenAttempts is the maximum number of times to attempt
// SLAAC address regenerations in response to an IPv6 endpoint-local conflict.
maxSLAACAddrLocalRegenAttempts = 10
-)
-var (
// MinPrefixInformationValidLifetimeForUpdate is the minimum Valid
// Lifetime to update the valid lifetime of a generated address by
// SLAAC.
//
- // This is exported as a variable (instead of a constant) so tests
- // can update it to a smaller value.
- //
// Min = 2hrs.
MinPrefixInformationValidLifetimeForUpdate = 2 * time.Hour
// MaxDesyncFactor is the upper bound for the preferred lifetime's desync
// factor for temporary SLAAC addresses.
//
- // This is exported as a variable (instead of a constant) so tests
- // can update it to a smaller value.
- //
// Must be greater than 0.
//
// Max = 10m (from RFC 4941 section 5).
@@ -154,9 +151,6 @@ var (
// MinMaxTempAddrPreferredLifetime is the minimum value allowed for the
// maximum preferred lifetime for temporary SLAAC addresses.
//
- // This is exported as a variable (instead of a constant) so tests
- // can update it to a smaller value.
- //
// This value guarantees that a temporary address is preferred for at
// least 1hr if the SLAAC prefix is valid for at least that time.
MinMaxTempAddrPreferredLifetime = defaultRegenAdvanceDuration + MaxDesyncFactor + time.Hour
@@ -164,9 +158,6 @@ var (
// MinMaxTempAddrValidLifetime is the minimum value allowed for the
// maximum valid lifetime for temporary SLAAC addresses.
//
- // This is exported as a variable (instead of a constant) so tests
- // can update it to a smaller value.
- //
// This value guarantees that a temporary address is valid for at least
// 2hrs if the SLAAC prefix is valid for at least that time.
MinMaxTempAddrValidLifetime = 2 * time.Hour
@@ -214,28 +205,23 @@ type NDPDispatcher interface {
// is also not permitted to call into the stack.
OnDuplicateAddressDetectionResult(tcpip.NICID, tcpip.Address, stack.DADResult)
- // OnDefaultRouterDiscovered is called when a new default router is
- // discovered. Implementations must return true if the newly discovered
- // router should be remembered.
+ // OnOffLinkRouteUpdated is called when an off-link route is updated.
//
// This function is not permitted to block indefinitely. This function
// is also not permitted to call into the stack.
- OnDefaultRouterDiscovered(tcpip.NICID, tcpip.Address) bool
+ OnOffLinkRouteUpdated(tcpip.NICID, tcpip.Subnet, tcpip.Address, header.NDPRoutePreference)
- // OnDefaultRouterInvalidated is called when a discovered default router that
- // was remembered is invalidated.
+ // OnOffLinkRouteInvalidated is called when an off-link route is invalidated.
//
// This function is not permitted to block indefinitely. This function
// is also not permitted to call into the stack.
- OnDefaultRouterInvalidated(tcpip.NICID, tcpip.Address)
+ OnOffLinkRouteInvalidated(tcpip.NICID, tcpip.Subnet, tcpip.Address)
// OnOnLinkPrefixDiscovered is called when a new on-link prefix is discovered.
- // Implementations must return true if the newly discovered on-link prefix
- // should be remembered.
//
// This function is not permitted to block indefinitely. This function
// is also not permitted to call into the stack.
- OnOnLinkPrefixDiscovered(tcpip.NICID, tcpip.Subnet) bool
+ OnOnLinkPrefixDiscovered(tcpip.NICID, tcpip.Subnet)
// OnOnLinkPrefixInvalidated is called when a discovered on-link prefix that
// was remembered is invalidated.
@@ -245,13 +231,11 @@ type NDPDispatcher interface {
OnOnLinkPrefixInvalidated(tcpip.NICID, tcpip.Subnet)
// OnAutoGenAddress is called when a new prefix with its autonomous address-
- // configuration flag set is received and SLAAC was performed. Implementations
- // may prevent the stack from assigning the address to the NIC by returning
- // false.
+ // configuration flag set is received and SLAAC was performed.
//
// This function is not permitted to block indefinitely. It must not
// call functions on the stack itself.
- OnAutoGenAddress(tcpip.NICID, tcpip.AddressWithPrefix) bool
+ OnAutoGenAddress(tcpip.NICID, tcpip.AddressWithPrefix)
// OnAutoGenAddressDeprecated is called when an auto-generated address (SLAAC)
// is deprecated, but is still considered valid. Note, if an address is
@@ -373,12 +357,18 @@ type NDPConfigurations struct {
// DiscoverDefaultRouters determines whether or not default routers are
// discovered from Router Advertisements, as per RFC 4861 section 6. This
- // configuration is ignored if HandleRAs is false.
+ // configuration is ignored if RAs will not be processed (see HandleRAs).
DiscoverDefaultRouters bool
+ // DiscoverMoreSpecificRoutes determines whether or not more specific routes
+ // are discovered from Router Advertisements, as per RFC 4191. This
+ // configuration is ignored if RAs will not be processed (see HandleRAs).
+ DiscoverMoreSpecificRoutes bool
+
// DiscoverOnLinkPrefixes determines whether or not on-link prefixes are
// discovered from Router Advertisements' Prefix Information option, as per
- // RFC 4861 section 6. This configuration is ignored if HandleRAs is false.
+ // RFC 4861 section 6. This configuration is ignored if RAs will not be
+ // processed (see HandleRAs).
DiscoverOnLinkPrefixes bool
// AutoGenGlobalAddresses determines whether or not an IPv6 endpoint performs
@@ -429,6 +419,7 @@ func DefaultNDPConfigurations() NDPConfigurations {
MaxRtrSolicitationDelay: defaultMaxRtrSolicitationDelay,
HandleRAs: defaultHandleRAs,
DiscoverDefaultRouters: defaultDiscoverDefaultRouters,
+ DiscoverMoreSpecificRoutes: defaultDiscoverMoreSpecificRoutes,
DiscoverOnLinkPrefixes: defaultDiscoverOnLinkPrefixes,
AutoGenGlobalAddresses: defaultAutoGenGlobalAddresses,
AutoGenTempGlobalAddresses: defaultAutoGenTempGlobalAddresses,
@@ -469,6 +460,11 @@ type timer struct {
timer tcpip.Timer
}
+type offLinkRoute struct {
+ dest tcpip.Subnet
+ router tcpip.Address
+}
+
// ndpState is the per-Interface NDP state.
type ndpState struct {
// Do not allow overwriting this state.
@@ -483,8 +479,8 @@ type ndpState struct {
// The DAD timers to send the next NS message, or resolve the address.
dad ip.DAD
- // The default routers discovered through Router Advertisements.
- defaultRouters map[tcpip.Address]defaultRouterState
+ // The off-link routes discovered through Router Advertisements.
+ offLinkRoutes map[offLinkRoute]offLinkRouteState
// rtrSolicitTimer is the timer used to send the next router solicitation
// message.
@@ -512,10 +508,12 @@ type ndpState struct {
temporaryAddressDesyncFactor time.Duration
}
-// defaultRouterState holds data associated with a default router discovered by
+// offLinkRouteState holds data associated with an off-link route discovered by
// a Router Advertisement (RA).
-type defaultRouterState struct {
- // Job to invalidate the default router.
+type offLinkRouteState struct {
+ prf header.NDPRoutePreference
+
+ // Job to invalidate the route.
//
// Must not be nil.
invalidationJob *tcpip.Job
@@ -571,11 +569,11 @@ type slaacPrefixState struct {
// Must not be nil.
invalidationJob *tcpip.Job
- // Nonzero only when the address is not valid forever.
- validUntil tcpip.MonotonicTime
+ // nil iff the address is valid forever.
+ validUntil *tcpip.MonotonicTime
- // Nonzero only when the address is not preferred forever.
- preferredUntil tcpip.MonotonicTime
+ // nil iff the address is preferred forever.
+ preferredUntil *tcpip.MonotonicTime
// State associated with the stable address generated for the prefix.
stableAddr struct {
@@ -733,30 +731,22 @@ func (ndp *ndpState) handleRA(ip tcpip.Address, ra header.NDPRouterAdvert) {
// Is the IPv6 endpoint configured to discover default routers?
if ndp.configs.DiscoverDefaultRouters {
- rtr, ok := ndp.defaultRouters[ip]
- rl := ra.RouterLifetime()
- switch {
- case !ok && rl != 0:
- // This is a new default router we are discovering.
+ prf := ra.DefaultRouterPreference()
+ if prf == header.ReservedRoutePreference {
+ // As per RFC 4191 section 2.2,
//
- // Only remember it if we currently know about less than
- // MaxDiscoveredDefaultRouters routers.
- if len(ndp.defaultRouters) < MaxDiscoveredDefaultRouters {
- ndp.rememberDefaultRouter(ip, rl)
- }
-
- case ok && rl != 0:
- // This is an already discovered default router. Update
- // the invalidation job.
- rtr.invalidationJob.Cancel()
- rtr.invalidationJob.Schedule(rl)
- ndp.defaultRouters[ip] = rtr
-
- case ok && rl == 0:
- // We know about the router but it is no longer to be
- // used as a default router so invalidate it.
- ndp.invalidateDefaultRouter(ip)
+ // Prf (Default Router Preference)
+ //
+ // If the Reserved (10) value is received, the receiver MUST treat the
+ // value as if it were (00).
+ //
+ // Note that the value 00 is the medium (default) router preference value.
+ prf = header.MediumRoutePreference
}
+
+ // We represent default routers with a default (off-link) route through the
+ // router.
+ ndp.handleOffLinkRouteDiscovery(offLinkRoute{dest: header.IPv6EmptySubnet, router: ip}, ra.RouterLifetime(), prf)
}
// TODO(b/141556115): Do (RetransTimer, ReachableTime)) Parameter
@@ -808,61 +798,107 @@ func (ndp *ndpState) handleRA(ip tcpip.Address, ra header.NDPRouterAdvert) {
if opt.AutonomousAddressConfigurationFlag() {
ndp.handleAutonomousPrefixInformation(opt)
}
+
+ case header.NDPRouteInformation:
+ if !ndp.configs.DiscoverMoreSpecificRoutes {
+ continue
+ }
+
+ dest, err := opt.Prefix()
+ if err != nil {
+ panic(fmt.Sprintf("%T.Prefix(): %s", opt, err))
+ }
+
+ prf := opt.RoutePreference()
+ if prf == header.ReservedRoutePreference {
+ // As per RFC 4191 section 2.3,
+ //
+ // Prf (Route Preference)
+ // 2-bit signed integer. The Route Preference indicates
+ // whether to prefer the router associated with this prefix
+ // over others, when multiple identical prefixes (for
+ // different routers) have been received. If the Reserved
+ // (10) value is received, the Route Information Option MUST
+ // be ignored.
+ continue
+ }
+
+ ndp.handleOffLinkRouteDiscovery(offLinkRoute{dest: dest, router: ip}, opt.RouteLifetime(), prf)
}
// TODO(b/141556115): Do (MTU) Parameter Discovery.
}
}
-// invalidateDefaultRouter invalidates a discovered default router.
+// invalidateOffLinkRoute invalidates a discovered off-link route.
//
// The IPv6 endpoint that ndp belongs to MUST be locked.
-func (ndp *ndpState) invalidateDefaultRouter(ip tcpip.Address) {
- rtr, ok := ndp.defaultRouters[ip]
-
- // Is the router still discovered?
+func (ndp *ndpState) invalidateOffLinkRoute(route offLinkRoute) {
+ state, ok := ndp.offLinkRoutes[route]
if !ok {
- // ...Nope, do nothing further.
return
}
- rtr.invalidationJob.Cancel()
- delete(ndp.defaultRouters, ip)
+ state.invalidationJob.Cancel()
+ delete(ndp.offLinkRoutes, route)
- // Let the integrator know a discovered default router is invalidated.
+ // Let the integrator know a discovered off-link route is invalidated.
if ndpDisp := ndp.ep.protocol.options.NDPDisp; ndpDisp != nil {
- ndpDisp.OnDefaultRouterInvalidated(ndp.ep.nic.ID(), ip)
+ ndpDisp.OnOffLinkRouteInvalidated(ndp.ep.nic.ID(), route.dest, route.router)
}
}
-// rememberDefaultRouter remembers a newly discovered default router with IPv6
-// link-local address ip with lifetime rl.
+// handleOffLinkRouteDiscovery handles the discovery of an off-link route.
//
-// The router identified by ip MUST NOT already be known by the IPv6 endpoint.
-//
-// The IPv6 endpoint that ndp belongs to MUST be locked.
-func (ndp *ndpState) rememberDefaultRouter(ip tcpip.Address, rl time.Duration) {
+// Precondition: ndp.ep.mu must be locked.
+func (ndp *ndpState) handleOffLinkRouteDiscovery(route offLinkRoute, lifetime time.Duration, prf header.NDPRoutePreference) {
ndpDisp := ndp.ep.protocol.options.NDPDisp
if ndpDisp == nil {
return
}
- // Inform the integrator when we discovered a default router.
- if !ndpDisp.OnDefaultRouterDiscovered(ndp.ep.nic.ID(), ip) {
- // Informed by the integrator to not remember the router, do
- // nothing further.
- return
- }
+ state, ok := ndp.offLinkRoutes[route]
+ switch {
+ case !ok && lifetime != 0:
+ // This is a new route we are discovering.
+ //
+ // Only remember it if we currently know about less than
+ // MaxDiscoveredOffLinkRoutes routers.
+ if len(ndp.offLinkRoutes) < MaxDiscoveredOffLinkRoutes {
+ // Inform the integrator when we discovered an off-link route.
+ ndpDisp.OnOffLinkRouteUpdated(ndp.ep.nic.ID(), route.dest, route.router, prf)
+
+ state := offLinkRouteState{
+ prf: prf,
+ invalidationJob: ndp.ep.protocol.stack.NewJob(&ndp.ep.mu, func() {
+ ndp.invalidateOffLinkRoute(route)
+ }),
+ }
- state := defaultRouterState{
- invalidationJob: ndp.ep.protocol.stack.NewJob(&ndp.ep.mu, func() {
- ndp.invalidateDefaultRouter(ip)
- }),
- }
+ state.invalidationJob.Schedule(lifetime)
+
+ ndp.offLinkRoutes[route] = state
+ }
+
+ case ok && lifetime != 0:
+ // This is an already discovered off-link route. Update the lifetime.
+ state.invalidationJob.Cancel()
+ state.invalidationJob.Schedule(lifetime)
- state.invalidationJob.Schedule(rl)
+ if prf != state.prf {
+ state.prf = prf
+
+ // Inform the integrator about route preference updates.
+ ndpDisp.OnOffLinkRouteUpdated(ndp.ep.nic.ID(), route.dest, route.router, prf)
+ }
- ndp.defaultRouters[ip] = state
+ ndp.offLinkRoutes[route] = state
+
+ case ok && lifetime == 0:
+ // The already discovered off-link route is no longer considered valid so we
+ // invalidate it immediately.
+ ndp.invalidateOffLinkRoute(route)
+ }
}
// rememberOnLinkPrefix remembers a newly discovered on-link prefix with IPv6
@@ -878,11 +914,7 @@ func (ndp *ndpState) rememberOnLinkPrefix(prefix tcpip.Subnet, l time.Duration)
}
// Inform the integrator when we discovered an on-link prefix.
- if !ndpDisp.OnOnLinkPrefixDiscovered(ndp.ep.nic.ID(), prefix) {
- // Informed by the integrator to not remember the prefix, do
- // nothing further.
- return
- }
+ ndpDisp.OnOnLinkPrefixDiscovered(ndp.ep.nic.ID(), prefix)
state := onLinkPrefixState{
invalidationJob: ndp.ep.protocol.stack.NewJob(&ndp.ep.mu, func() {
@@ -1055,7 +1087,8 @@ func (ndp *ndpState) doSLAAC(prefix tcpip.Subnet, pl, vl time.Duration) {
// The time an address is preferred until is needed to properly generate the
// address.
if pl < header.NDPInfiniteLifetime {
- state.preferredUntil = now.Add(pl)
+ t := now.Add(pl)
+ state.preferredUntil = &t
}
if !ndp.generateSLAACAddr(prefix, &state) {
@@ -1073,7 +1106,8 @@ func (ndp *ndpState) doSLAAC(prefix tcpip.Subnet, pl, vl time.Duration) {
if vl < header.NDPInfiniteLifetime {
state.invalidationJob.Schedule(vl)
- state.validUntil = now.Add(vl)
+ t := now.Add(vl)
+ state.validUntil = &t
}
// If the address is assigned (DAD resolved), generate a temporary address.
@@ -1096,16 +1130,13 @@ func (ndp *ndpState) addAndAcquireSLAACAddr(addr tcpip.AddressWithPrefix, config
return nil
}
- if !ndpDisp.OnAutoGenAddress(ndp.ep.nic.ID(), addr) {
- // Informed by the integrator not to add the address.
- return nil
- }
-
addressEndpoint, err := ndp.ep.addAndAcquirePermanentAddressLocked(addr, stack.FirstPrimaryEndpoint, configType, deprecated)
if err != nil {
panic(fmt.Sprintf("ndp: error when adding SLAAC address %+v: %s", addr, err))
}
+ ndpDisp.OnAutoGenAddress(ndp.ep.nic.ID(), addr)
+
return addressEndpoint
}
@@ -1181,7 +1212,8 @@ func (ndp *ndpState) generateSLAACAddr(prefix tcpip.Subnet, state *slaacPrefixSt
state.stableAddr.localGenerationFailures++
}
- if addressEndpoint := ndp.addAndAcquireSLAACAddr(generatedAddr, stack.AddressConfigSlaac, ndp.ep.protocol.stack.Clock().NowMonotonic().Sub(state.preferredUntil) >= 0 /* deprecated */); addressEndpoint != nil {
+ deprecated := state.preferredUntil != nil && !state.preferredUntil.After(ndp.ep.protocol.stack.Clock().NowMonotonic())
+ if addressEndpoint := ndp.addAndAcquireSLAACAddr(generatedAddr, stack.AddressConfigSlaac, deprecated); addressEndpoint != nil {
state.stableAddr.addressEndpoint = addressEndpoint
state.generationAttempts++
return true
@@ -1242,7 +1274,7 @@ func (ndp *ndpState) generateTempSLAACAddr(prefix tcpip.Subnet, prefixState *sla
// address is the lower of the valid lifetime of the stable address or the
// maximum temporary address valid lifetime.
vl := ndp.configs.MaxTempAddrValidLifetime
- if prefixState.validUntil != (tcpip.MonotonicTime{}) {
+ if prefixState.validUntil != nil {
if prefixVL := prefixState.validUntil.Sub(now); vl > prefixVL {
vl = prefixVL
}
@@ -1258,7 +1290,7 @@ func (ndp *ndpState) generateTempSLAACAddr(prefix tcpip.Subnet, prefixState *sla
// maximum temporary address preferred lifetime - the temporary address desync
// factor.
pl := ndp.configs.MaxTempAddrPreferredLifetime - ndp.temporaryAddressDesyncFactor
- if prefixState.preferredUntil != (tcpip.MonotonicTime{}) {
+ if prefixState.preferredUntil != nil {
if prefixPL := prefixState.preferredUntil.Sub(now); pl > prefixPL {
// Respect the preferred lifetime of the prefix, as per RFC 4941 section
// 3.3 step 4.
@@ -1400,9 +1432,10 @@ func (ndp *ndpState) refreshSLAACPrefixLifetimes(prefix tcpip.Subnet, prefixStat
if !deprecated {
prefixState.deprecationJob.Schedule(pl)
}
- prefixState.preferredUntil = now.Add(pl)
+ t := now.Add(pl)
+ prefixState.preferredUntil = &t
} else {
- prefixState.preferredUntil = tcpip.MonotonicTime{}
+ prefixState.preferredUntil = nil
}
// As per RFC 4862 section 5.5.3.e, update the valid lifetime for prefix:
@@ -1420,14 +1453,14 @@ func (ndp *ndpState) refreshSLAACPrefixLifetimes(prefix tcpip.Subnet, prefixStat
// Handle the infinite valid lifetime separately as we do not schedule a
// job in this case.
prefixState.invalidationJob.Cancel()
- prefixState.validUntil = tcpip.MonotonicTime{}
+ prefixState.validUntil = nil
} else {
var effectiveVl time.Duration
var rl time.Duration
// If the prefix was originally set to be valid forever, assume the
// remaining time to be the maximum possible value.
- if prefixState.validUntil == (tcpip.MonotonicTime{}) {
+ if prefixState.validUntil == nil {
rl = header.NDPInfiniteLifetime
} else {
rl = prefixState.validUntil.Sub(now)
@@ -1442,7 +1475,8 @@ func (ndp *ndpState) refreshSLAACPrefixLifetimes(prefix tcpip.Subnet, prefixStat
if effectiveVl != 0 {
prefixState.invalidationJob.Cancel()
prefixState.invalidationJob.Schedule(effectiveVl)
- prefixState.validUntil = now.Add(effectiveVl)
+ t := now.Add(effectiveVl)
+ prefixState.validUntil = &t
}
}
@@ -1462,8 +1496,8 @@ func (ndp *ndpState) refreshSLAACPrefixLifetimes(prefix tcpip.Subnet, prefixStat
// maximum temporary address valid lifetime. Note, the valid lifetime of a
// temporary address is relative to the address's creation time.
validUntil := tempAddrState.createdAt.Add(ndp.configs.MaxTempAddrValidLifetime)
- if prefixState.validUntil != (tcpip.MonotonicTime{}) && validUntil.Sub(prefixState.validUntil) > 0 {
- validUntil = prefixState.validUntil
+ if prefixState.validUntil != nil && prefixState.validUntil.Before(validUntil) {
+ validUntil = *prefixState.validUntil
}
// If the address is no longer valid, invalidate it immediately. Otherwise,
@@ -1482,14 +1516,15 @@ func (ndp *ndpState) refreshSLAACPrefixLifetimes(prefix tcpip.Subnet, prefixStat
// desync factor. Note, the preferred lifetime of a temporary address is
// relative to the address's creation time.
preferredUntil := tempAddrState.createdAt.Add(ndp.configs.MaxTempAddrPreferredLifetime - ndp.temporaryAddressDesyncFactor)
- if prefixState.preferredUntil != (tcpip.MonotonicTime{}) && preferredUntil.Sub(prefixState.preferredUntil) > 0 {
- preferredUntil = prefixState.preferredUntil
+ if prefixState.preferredUntil != nil && prefixState.preferredUntil.Before(preferredUntil) {
+ preferredUntil = *prefixState.preferredUntil
}
// If the address is no longer preferred, deprecate it immediately.
// Otherwise, schedule the deprecation job again.
newPreferredLifetime := preferredUntil.Sub(now)
tempAddrState.deprecationJob.Cancel()
+
if newPreferredLifetime <= 0 {
ndp.deprecateSLAACAddress(tempAddrState.addressEndpoint)
} else {
@@ -1679,12 +1714,12 @@ func (ndp *ndpState) cleanupState() {
panic(fmt.Sprintf("ndp: still have discovered on-link prefixes after cleaning up; found = %d", got))
}
- for router := range ndp.defaultRouters {
- ndp.invalidateDefaultRouter(router)
+ for route := range ndp.offLinkRoutes {
+ ndp.invalidateOffLinkRoute(route)
}
- if got := len(ndp.defaultRouters); got != 0 {
- panic(fmt.Sprintf("ndp: still have discovered default routers after cleaning up; found = %d", got))
+ if got := len(ndp.offLinkRoutes); got != 0 {
+ panic(fmt.Sprintf("ndp: still have discovered off-link routes after cleaning up; found = %d", got))
}
ndp.dhcpv6Configuration = 0
@@ -1847,21 +1882,19 @@ func (ndp *ndpState) stopSolicitingRouters() {
}
func (ndp *ndpState) init(ep *endpoint, dadOptions ip.DADOptions) {
- if ndp.defaultRouters != nil {
+ if ndp.offLinkRoutes != nil {
panic("attempted to initialize NDP state twice")
}
ndp.ep = ep
ndp.configs = ep.protocol.options.NDPConfigs
ndp.dad.Init(&ndp.ep.mu, ep.protocol.options.DADConfigs, dadOptions)
- ndp.defaultRouters = make(map[tcpip.Address]defaultRouterState)
+ ndp.offLinkRoutes = make(map[offLinkRoute]offLinkRouteState)
ndp.onLinkPrefixes = make(map[tcpip.Subnet]onLinkPrefixState)
ndp.slaacPrefixes = make(map[tcpip.Subnet]slaacPrefixState)
header.InitialTempIID(ndp.temporaryIIDHistory[:], ndp.ep.protocol.options.TempIIDSeed, ndp.ep.nic.ID())
- if MaxDesyncFactor != 0 {
- ndp.temporaryAddressDesyncFactor = time.Duration(ep.protocol.stack.Rand().Int63n(int64(MaxDesyncFactor)))
- }
+ ndp.temporaryAddressDesyncFactor = time.Duration(ep.protocol.stack.Rand().Int63n(int64(MaxDesyncFactor)))
}
func (ndp *ndpState) SendDADMessage(addr tcpip.Address, nonce []byte) tcpip.Error {
diff --git a/pkg/tcpip/network/ipv6/ndp_test.go b/pkg/tcpip/network/ipv6/ndp_test.go
index 3438deb79..f0186c64e 100644
--- a/pkg/tcpip/network/ipv6/ndp_test.go
+++ b/pkg/tcpip/network/ipv6/ndp_test.go
@@ -42,24 +42,21 @@ type testNDPDispatcher struct {
func (*testNDPDispatcher) OnDuplicateAddressDetectionResult(tcpip.NICID, tcpip.Address, stack.DADResult) {
}
-func (t *testNDPDispatcher) OnDefaultRouterDiscovered(_ tcpip.NICID, addr tcpip.Address) bool {
+func (t *testNDPDispatcher) OnOffLinkRouteUpdated(_ tcpip.NICID, _ tcpip.Subnet, addr tcpip.Address, _ header.NDPRoutePreference) {
t.addr = addr
- return true
}
-func (t *testNDPDispatcher) OnDefaultRouterInvalidated(_ tcpip.NICID, addr tcpip.Address) {
+func (t *testNDPDispatcher) OnOffLinkRouteInvalidated(_ tcpip.NICID, _ tcpip.Subnet, addr tcpip.Address) {
t.addr = addr
}
-func (*testNDPDispatcher) OnOnLinkPrefixDiscovered(tcpip.NICID, tcpip.Subnet) bool {
- return false
+func (*testNDPDispatcher) OnOnLinkPrefixDiscovered(tcpip.NICID, tcpip.Subnet) {
}
func (*testNDPDispatcher) OnOnLinkPrefixInvalidated(tcpip.NICID, tcpip.Subnet) {
}
-func (*testNDPDispatcher) OnAutoGenAddress(tcpip.NICID, tcpip.AddressWithPrefix) bool {
- return false
+func (*testNDPDispatcher) OnAutoGenAddress(tcpip.NICID, tcpip.AddressWithPrefix) {
}
func (*testNDPDispatcher) OnAutoGenAddressDeprecated(tcpip.NICID, tcpip.AddressWithPrefix) {
@@ -96,7 +93,7 @@ func TestStackNDPEndpointInvalidateDefaultRouter(t *testing.T) {
ipv6EP := ep.(*endpoint)
ipv6EP.mu.Lock()
- ipv6EP.mu.ndp.rememberDefaultRouter(lladdr1, time.Hour)
+ ipv6EP.mu.ndp.handleOffLinkRouteDiscovery(offLinkRoute{dest: header.IPv6EmptySubnet, router: lladdr1}, time.Hour, header.MediumRoutePreference)
ipv6EP.mu.Unlock()
if ndpDisp.addr != lladdr1 {
diff --git a/pkg/tcpip/ports/BUILD b/pkg/tcpip/ports/BUILD
index b7f6d52ae..fe98a52af 100644
--- a/pkg/tcpip/ports/BUILD
+++ b/pkg/tcpip/ports/BUILD
@@ -12,6 +12,7 @@ go_library(
deps = [
"//pkg/sync",
"//pkg/tcpip",
+ "//pkg/tcpip/header",
],
)
diff --git a/pkg/tcpip/ports/ports.go b/pkg/tcpip/ports/ports.go
index 854d6a6ba..fb8ef1ee2 100644
--- a/pkg/tcpip/ports/ports.go
+++ b/pkg/tcpip/ports/ports.go
@@ -23,6 +23,7 @@ import (
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
)
const (
@@ -122,7 +123,7 @@ type deviceToDest map[tcpip.NICID]destToCounter
// If either of the port reuse flags is enabled on any of the nodes, all nodes
// sharing a port must share at least one reuse flag. This matches Linux's
// behavior.
-func (dd deviceToDest) isAvailable(res Reservation) bool {
+func (dd deviceToDest) isAvailable(res Reservation, portSpecified bool) bool {
flagBits := res.Flags.Bits()
if res.BindToDevice == 0 {
intersection := FlagMask
@@ -138,6 +139,9 @@ func (dd deviceToDest) isAvailable(res Reservation) bool {
return false
}
}
+ if !portSpecified && res.Transport == header.TCPProtocolNumber {
+ return false
+ }
return true
}
@@ -146,16 +150,26 @@ func (dd deviceToDest) isAvailable(res Reservation) bool {
if dests, ok := dd[0]; ok {
var count int
intersection, count = dests.intersectionFlags(res)
- if count > 0 && intersection&flagBits == 0 {
- return false
+ if count > 0 {
+ if intersection&flagBits == 0 {
+ return false
+ }
+ if !portSpecified && res.Transport == header.TCPProtocolNumber {
+ return false
+ }
}
}
if dests, ok := dd[res.BindToDevice]; ok {
flags, count := dests.intersectionFlags(res)
intersection &= flags
- if count > 0 && intersection&flagBits == 0 {
- return false
+ if count > 0 {
+ if intersection&flagBits == 0 {
+ return false
+ }
+ if !portSpecified && res.Transport == header.TCPProtocolNumber {
+ return false
+ }
}
}
@@ -168,12 +182,12 @@ type addrToDevice map[tcpip.Address]deviceToDest
// isAvailable checks whether an IP address is available to bind to. If the
// address is the "any" address, check all other addresses. Otherwise, just
// check against the "any" address and the provided address.
-func (ad addrToDevice) isAvailable(res Reservation) bool {
+func (ad addrToDevice) isAvailable(res Reservation, portSpecified bool) bool {
if res.Addr == anyIPAddress {
// If binding to the "any" address then check that there are no
// conflicts with all addresses.
for _, devices := range ad {
- if !devices.isAvailable(res) {
+ if !devices.isAvailable(res, portSpecified) {
return false
}
}
@@ -182,14 +196,14 @@ func (ad addrToDevice) isAvailable(res Reservation) bool {
// Check that there is no conflict with the "any" address.
if devices, ok := ad[anyIPAddress]; ok {
- if !devices.isAvailable(res) {
+ if !devices.isAvailable(res, portSpecified) {
return false
}
}
// Check that this is no conflict with the provided address.
if devices, ok := ad[res.Addr]; ok {
- if !devices.isAvailable(res) {
+ if !devices.isAvailable(res, portSpecified) {
return false
}
}
@@ -310,7 +324,7 @@ func (pm *PortManager) ReservePort(rng *rand.Rand, res Reservation, testPort Por
// If a port is specified, just try to reserve it for all network
// protocols.
if res.Port != 0 {
- if !pm.reserveSpecificPortLocked(res) {
+ if !pm.reserveSpecificPortLocked(res, true /* portSpecified */) {
return 0, &tcpip.ErrPortInUse{}
}
if testPort != nil {
@@ -330,7 +344,7 @@ func (pm *PortManager) ReservePort(rng *rand.Rand, res Reservation, testPort Por
// A port wasn't specified, so try to find one.
return pm.PickEphemeralPort(rng, func(p uint16) (bool, tcpip.Error) {
res.Port = p
- if !pm.reserveSpecificPortLocked(res) {
+ if !pm.reserveSpecificPortLocked(res, false /* portSpecified */) {
return false, nil
}
if testPort != nil {
@@ -350,12 +364,12 @@ func (pm *PortManager) ReservePort(rng *rand.Rand, res Reservation, testPort Por
// reserveSpecificPortLocked tries to reserve the given port on all given
// protocols.
-func (pm *PortManager) reserveSpecificPortLocked(res Reservation) bool {
+func (pm *PortManager) reserveSpecificPortLocked(res Reservation, portSpecified bool) bool {
// Make sure the port is available.
for _, network := range res.Networks {
desc := portDescriptor{network, res.Transport, res.Port}
if addrs, ok := pm.allocatedPorts[desc]; ok {
- if !addrs.isAvailable(res) {
+ if !addrs.isAvailable(res, portSpecified) {
return false
}
}
diff --git a/pkg/tcpip/stack/BUILD b/pkg/tcpip/stack/BUILD
index 395ff9a07..e0847e58a 100644
--- a/pkg/tcpip/stack/BUILD
+++ b/pkg/tcpip/stack/BUILD
@@ -95,7 +95,7 @@ go_library(
go_test(
name = "stack_x_test",
- size = "medium",
+ size = "small",
srcs = [
"addressable_endpoint_state_test.go",
"ndp_test.go",
diff --git a/pkg/tcpip/stack/addressable_endpoint_state.go b/pkg/tcpip/stack/addressable_endpoint_state.go
index ce9cebdaa..ae0bb4ace 100644
--- a/pkg/tcpip/stack/addressable_endpoint_state.go
+++ b/pkg/tcpip/stack/addressable_endpoint_state.go
@@ -249,7 +249,7 @@ func (a *AddressableEndpointState) addAndAcquireAddressLocked(addr tcpip.Address
// or we are adding a new temporary or permanent address.
//
// The address MUST be write locked at this point.
- defer addrState.mu.Unlock()
+ defer addrState.mu.Unlock() // +checklocksforce
if permanent {
if addrState.mu.kind.IsPermanent() {
diff --git a/pkg/tcpip/stack/conntrack.go b/pkg/tcpip/stack/conntrack.go
index f7fbcbaa7..068dab7ce 100644
--- a/pkg/tcpip/stack/conntrack.go
+++ b/pkg/tcpip/stack/conntrack.go
@@ -35,7 +35,6 @@ import (
// Currently, only TCP tracking is supported.
// Our hash table has 16K buckets.
-// TODO(gvisor.dev/issue/170): These should be tunable.
const numBuckets = 1 << 14
// Direction of the tuple.
@@ -165,8 +164,6 @@ func (cn *conn) updateLocked(tcpHeader header.TCP, hook Hook) {
// Update the state of tcb. tcb assumes it's always initialized on the
// client. However, we only need to know whether the connection is
// established or not, so the client/server distinction isn't important.
- // TODO(gvisor.dev/issue/170): Add support in tcpconntrack to handle
- // other tcp states.
if cn.tcb.IsEmpty() {
cn.tcb.Init(tcpHeader)
} else if hook == cn.tcbHook {
@@ -246,8 +243,7 @@ func (ct *ConnTrack) init() {
// connFor gets the conn for pkt if it exists, or returns nil
// if it does not. It returns an error when pkt does not contain a valid TCP
// header.
-// TODO(gvisor.dev/issue/170): Only TCP packets are supported. Need to support
-// other transport protocols.
+// TODO(gvisor.dev/issue/6168): Support UDP.
func (ct *ConnTrack) connFor(pkt *PacketBuffer) (*conn, direction) {
tid, err := packetToTupleID(pkt)
if err != nil {
@@ -367,7 +363,7 @@ func (ct *ConnTrack) insertConn(conn *conn) {
// Unlocking can happen in any order.
ct.buckets[tupleBucket].mu.Unlock()
if tupleBucket != replyBucket {
- ct.buckets[replyBucket].mu.Unlock()
+ ct.buckets[replyBucket].mu.Unlock() // +checklocksforce
}
}
@@ -385,7 +381,7 @@ func (ct *ConnTrack) handlePacket(pkt *PacketBuffer, hook Hook, r *Route) bool {
return false
}
- // TODO(gvisor.dev/issue/170): Support other transport protocols.
+ // TODO(gvisor.dev/issue/6168): Support UDP.
if pkt.Network().TransportProtocol() != header.TCPProtocolNumber {
return false
}
@@ -409,16 +405,23 @@ func (ct *ConnTrack) handlePacket(pkt *PacketBuffer, hook Hook, r *Route) bool {
// validated if checksum offloading is off. It may require IP defrag if the
// packets are fragmented.
+ var newAddr tcpip.Address
+ var newPort uint16
+
+ updateSRCFields := false
+
switch hook {
case Prerouting, Output:
if conn.manip == manipDestination {
switch dir {
case dirOriginal:
- tcpHeader.SetDestinationPort(conn.reply.srcPort)
- netHeader.SetDestinationAddress(conn.reply.srcAddr)
+ newPort = conn.reply.srcPort
+ newAddr = conn.reply.srcAddr
case dirReply:
- tcpHeader.SetSourcePort(conn.original.dstPort)
- netHeader.SetSourceAddress(conn.original.dstAddr)
+ newPort = conn.original.dstPort
+ newAddr = conn.original.dstAddr
+
+ updateSRCFields = true
}
pkt.NatDone = true
}
@@ -426,11 +429,13 @@ func (ct *ConnTrack) handlePacket(pkt *PacketBuffer, hook Hook, r *Route) bool {
if conn.manip == manipSource {
switch dir {
case dirOriginal:
- tcpHeader.SetSourcePort(conn.reply.dstPort)
- netHeader.SetSourceAddress(conn.reply.dstAddr)
+ newPort = conn.reply.dstPort
+ newAddr = conn.reply.dstAddr
+
+ updateSRCFields = true
case dirReply:
- tcpHeader.SetDestinationPort(conn.original.srcPort)
- netHeader.SetDestinationAddress(conn.original.srcAddr)
+ newPort = conn.original.srcPort
+ newAddr = conn.original.srcAddr
}
pkt.NatDone = true
}
@@ -441,33 +446,33 @@ func (ct *ConnTrack) handlePacket(pkt *PacketBuffer, hook Hook, r *Route) bool {
return false
}
+ fullChecksum := false
+ updatePseudoHeader := false
switch hook {
case Prerouting, Input:
case Output, Postrouting:
// Calculate the TCP checksum and set it.
- tcpHeader.SetChecksum(0)
- length := uint16(len(tcpHeader) + pkt.Data().Size())
- xsum := header.PseudoHeaderChecksum(header.TCPProtocolNumber, netHeader.SourceAddress(), netHeader.DestinationAddress(), length)
if pkt.GSOOptions.Type != GSONone && pkt.GSOOptions.NeedsCsum {
- tcpHeader.SetChecksum(xsum)
+ updatePseudoHeader = true
} else if r.RequiresTXTransportChecksum() {
- xsum = header.ChecksumCombine(xsum, pkt.Data().AsRange().Checksum())
- tcpHeader.SetChecksum(^tcpHeader.CalculateChecksum(xsum))
+ fullChecksum = true
+ updatePseudoHeader = true
}
default:
panic(fmt.Sprintf("unrecognized hook = %s", hook))
}
- // After modification, IPv4 packets need a valid checksum.
- if pkt.NetworkProtocolNumber == header.IPv4ProtocolNumber {
- netHeader := header.IPv4(pkt.NetworkHeader().View())
- netHeader.SetChecksum(0)
- netHeader.SetChecksum(^netHeader.CalculateChecksum())
- }
+ rewritePacket(
+ netHeader,
+ tcpHeader,
+ updateSRCFields,
+ fullChecksum,
+ updatePseudoHeader,
+ newPort,
+ newAddr,
+ )
// Update the state of tcb.
- // TODO(gvisor.dev/issue/170): Add support in tcpcontrack to handle
- // other tcp states.
conn.mu.Lock()
defer conn.mu.Unlock()
@@ -544,8 +549,6 @@ func (ct *ConnTrack) bucket(id tupleID) int {
// reapUnused returns the next bucket that should be checked and the time after
// which it should be called again.
func (ct *ConnTrack) reapUnused(start int, prevInterval time.Duration) (int, time.Duration) {
- // TODO(gvisor.dev/issue/170): This can be more finely controlled, as
- // it is in Linux via sysctl.
const fractionPerReaping = 128
const maxExpiredPct = 50
const maxFullTraversal = 60 * time.Second
@@ -623,7 +626,7 @@ func (ct *ConnTrack) reapTupleLocked(tuple *tuple, bucket int, now time.Time) bo
// Don't re-unlock if both tuples are in the same bucket.
if differentBuckets {
- ct.buckets[replyBucket].mu.Unlock()
+ ct.buckets[replyBucket].mu.Unlock() // +checklocksforce
}
return true
diff --git a/pkg/tcpip/stack/iptables.go b/pkg/tcpip/stack/iptables.go
index 0a26f6dd8..f152c0d83 100644
--- a/pkg/tcpip/stack/iptables.go
+++ b/pkg/tcpip/stack/iptables.go
@@ -268,10 +268,6 @@ const (
// should continue traversing the network stack and false when it should be
// dropped.
//
-// TODO(gvisor.dev/issue/170): PacketBuffer should hold the route, from
-// which address can be gathered. Currently, address is only needed for
-// prerouting.
-//
// Precondition: pkt.NetworkHeader is set.
func (it *IPTables) Check(hook Hook, pkt *PacketBuffer, r *Route, preroutingAddr tcpip.Address, inNicName, outNicName string) bool {
if pkt.NetworkProtocolNumber != header.IPv4ProtocolNumber && pkt.NetworkProtocolNumber != header.IPv6ProtocolNumber {
diff --git a/pkg/tcpip/stack/iptables_targets.go b/pkg/tcpip/stack/iptables_targets.go
index 2812c89aa..96cc899bb 100644
--- a/pkg/tcpip/stack/iptables_targets.go
+++ b/pkg/tcpip/stack/iptables_targets.go
@@ -87,9 +87,6 @@ func (*ReturnTarget) Action(*PacketBuffer, *ConnTrack, Hook, *Route, tcpip.Addre
// destination port/IP. Outgoing packets are redirected to the loopback device,
// and incoming packets are redirected to the incoming interface (rather than
// forwarded).
-//
-// TODO(gvisor.dev/issue/170): Other flags need to be added after we support
-// them.
type RedirectTarget struct {
// Port indicates port used to redirect. It is immutable.
Port uint16
@@ -100,9 +97,6 @@ type RedirectTarget struct {
}
// Action implements Target.Action.
-// TODO(gvisor.dev/issue/170): Parse headers without copying. The current
-// implementation only works for Prerouting and calls pkt.Clone(), neither
-// of which should be the case.
func (rt *RedirectTarget) Action(pkt *PacketBuffer, ct *ConnTrack, hook Hook, r *Route, address tcpip.Address) (RuleVerdict, int) {
// Sanity check.
if rt.NetworkProtocol != pkt.NetworkProtocolNumber {
@@ -136,34 +130,26 @@ func (rt *RedirectTarget) Action(pkt *PacketBuffer, ct *ConnTrack, hook Hook, r
panic("redirect target is supported only on output and prerouting hooks")
}
- // TODO(gvisor.dev/issue/170): Check Flags in RedirectTarget if
- // we need to change dest address (for OUTPUT chain) or ports.
switch protocol := pkt.TransportProtocolNumber; protocol {
case header.UDPProtocolNumber:
udpHeader := header.UDP(pkt.TransportHeader().View())
- udpHeader.SetDestinationPort(rt.Port)
- // Calculate UDP checksum and set it.
if hook == Output {
- udpHeader.SetChecksum(0)
- netHeader := pkt.Network()
- netHeader.SetDestinationAddress(address)
-
// Only calculate the checksum if offloading isn't supported.
- if r.RequiresTXTransportChecksum() {
- length := uint16(pkt.Size()) - uint16(len(pkt.NetworkHeader().View()))
- xsum := header.PseudoHeaderChecksum(protocol, netHeader.SourceAddress(), netHeader.DestinationAddress(), length)
- xsum = header.ChecksumCombine(xsum, pkt.Data().AsRange().Checksum())
- udpHeader.SetChecksum(^udpHeader.CalculateChecksum(xsum))
- }
+ requiresChecksum := r.RequiresTXTransportChecksum()
+ rewritePacket(
+ pkt.Network(),
+ udpHeader,
+ false, /* updateSRCFields */
+ requiresChecksum,
+ requiresChecksum,
+ rt.Port,
+ address,
+ )
+ } else {
+ udpHeader.SetDestinationPort(rt.Port)
}
- // After modification, IPv4 packets need a valid checksum.
- if pkt.NetworkProtocolNumber == header.IPv4ProtocolNumber {
- netHeader := header.IPv4(pkt.NetworkHeader().View())
- netHeader.SetChecksum(0)
- netHeader.SetChecksum(^netHeader.CalculateChecksum())
- }
pkt.NatDone = true
case header.TCPProtocolNumber:
if ct == nil {
@@ -222,26 +208,18 @@ func (st *SNATTarget) Action(pkt *PacketBuffer, ct *ConnTrack, hook Hook, r *Rou
switch protocol := pkt.TransportProtocolNumber; protocol {
case header.UDPProtocolNumber:
- udpHeader := header.UDP(pkt.TransportHeader().View())
- udpHeader.SetChecksum(0)
- udpHeader.SetSourcePort(st.Port)
- netHeader := pkt.Network()
- netHeader.SetSourceAddress(st.Addr)
-
// Only calculate the checksum if offloading isn't supported.
- if r.RequiresTXTransportChecksum() {
- length := uint16(pkt.Size()) - uint16(len(pkt.NetworkHeader().View()))
- xsum := header.PseudoHeaderChecksum(protocol, netHeader.SourceAddress(), netHeader.DestinationAddress(), length)
- xsum = header.ChecksumCombine(xsum, pkt.Data().AsRange().Checksum())
- udpHeader.SetChecksum(^udpHeader.CalculateChecksum(xsum))
- }
+ requiresChecksum := r.RequiresTXTransportChecksum()
+ rewritePacket(
+ pkt.Network(),
+ header.UDP(pkt.TransportHeader().View()),
+ true, /* updateSRCFields */
+ requiresChecksum,
+ requiresChecksum,
+ st.Port,
+ st.Addr,
+ )
- // After modification, IPv4 packets need a valid checksum.
- if pkt.NetworkProtocolNumber == header.IPv4ProtocolNumber {
- netHeader := header.IPv4(pkt.NetworkHeader().View())
- netHeader.SetChecksum(0)
- netHeader.SetChecksum(^netHeader.CalculateChecksum())
- }
pkt.NatDone = true
case header.TCPProtocolNumber:
if ct == nil {
@@ -260,3 +238,42 @@ func (st *SNATTarget) Action(pkt *PacketBuffer, ct *ConnTrack, hook Hook, r *Rou
return RuleAccept, 0
}
+
+func rewritePacket(n header.Network, t header.ChecksummableTransport, updateSRCFields, fullChecksum, updatePseudoHeader bool, newPort uint16, newAddr tcpip.Address) {
+ if updateSRCFields {
+ if fullChecksum {
+ t.SetSourcePortWithChecksumUpdate(newPort)
+ } else {
+ t.SetSourcePort(newPort)
+ }
+ } else {
+ if fullChecksum {
+ t.SetDestinationPortWithChecksumUpdate(newPort)
+ } else {
+ t.SetDestinationPort(newPort)
+ }
+ }
+
+ if updatePseudoHeader {
+ var oldAddr tcpip.Address
+ if updateSRCFields {
+ oldAddr = n.SourceAddress()
+ } else {
+ oldAddr = n.DestinationAddress()
+ }
+
+ t.UpdateChecksumPseudoHeaderAddress(oldAddr, newAddr, fullChecksum)
+ }
+
+ if checksummableNetHeader, ok := n.(header.ChecksummableNetwork); ok {
+ if updateSRCFields {
+ checksummableNetHeader.SetSourceAddressWithChecksumUpdate(newAddr)
+ } else {
+ checksummableNetHeader.SetDestinationAddressWithChecksumUpdate(newAddr)
+ }
+ } else if updateSRCFields {
+ n.SetSourceAddress(newAddr)
+ } else {
+ n.SetDestinationAddress(newAddr)
+ }
+}
diff --git a/pkg/tcpip/stack/iptables_types.go b/pkg/tcpip/stack/iptables_types.go
index 93592e7f5..66e5f22ac 100644
--- a/pkg/tcpip/stack/iptables_types.go
+++ b/pkg/tcpip/stack/iptables_types.go
@@ -242,7 +242,6 @@ type IPHeaderFilter struct {
func (fl IPHeaderFilter) match(pkt *PacketBuffer, hook Hook, inNicName, outNicName string) bool {
// Extract header fields.
var (
- // TODO(gvisor.dev/issue/170): Support other filter fields.
transProto tcpip.TransportProtocolNumber
dstAddr tcpip.Address
srcAddr tcpip.Address
@@ -291,7 +290,6 @@ func (fl IPHeaderFilter) match(pkt *PacketBuffer, hook Hook, inNicName, outNicNa
return true
case Postrouting:
- // TODO(gvisor.dev/issue/170): Add the check for POSTROUTING.
return true
default:
panic(fmt.Sprintf("unknown hook: %d", hook))
diff --git a/pkg/tcpip/stack/ndp_test.go b/pkg/tcpip/stack/ndp_test.go
index 133bacdd0..ca2250ad6 100644
--- a/pkg/tcpip/stack/ndp_test.go
+++ b/pkg/tcpip/stack/ndp_test.go
@@ -52,17 +52,6 @@ const (
linkAddr4 = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x09")
defaultPrefixLen = 128
-
- // Extra time to use when waiting for an async event to occur.
- defaultAsyncPositiveEventTimeout = 10 * time.Second
-
- // Extra time to use when waiting for an async event to not occur.
- //
- // Since a negative check is used to make sure an event did not happen, it is
- // okay to use a smaller timeout compared to the positive case since execution
- // stall in regards to the monotonic clock will not affect the expected
- // outcome.
- defaultAsyncNegativeEventTimeout = time.Second
)
var (
@@ -112,11 +101,13 @@ type ndpDADEvent struct {
res stack.DADResult
}
-type ndpRouterEvent struct {
- nicID tcpip.NICID
- addr tcpip.Address
- // true if router was discovered, false if invalidated.
- discovered bool
+type ndpOffLinkRouteEvent struct {
+ nicID tcpip.NICID
+ subnet tcpip.Subnet
+ router tcpip.Address
+ prf header.NDPRoutePreference
+ // true if route was updated, false if invalidated.
+ updated bool
}
type ndpPrefixEvent struct {
@@ -140,6 +131,10 @@ type ndpAutoGenAddrEvent struct {
eventType ndpAutoGenAddrEventType
}
+func (e ndpAutoGenAddrEvent) String() string {
+ return fmt.Sprintf("%T{nicID=%d addr=%s eventType=%d}", e, e.nicID, e.addr, e.eventType)
+}
+
type ndpRDNSS struct {
addrs []tcpip.Address
lifetime time.Duration
@@ -167,10 +162,8 @@ var _ ipv6.NDPDispatcher = (*ndpDispatcher)(nil)
// related events happen for test purposes.
type ndpDispatcher struct {
dadC chan ndpDADEvent
- routerC chan ndpRouterEvent
- rememberRouter bool
+ offLinkRouteC chan ndpOffLinkRouteEvent
prefixC chan ndpPrefixEvent
- rememberPrefix bool
autoGenAddrC chan ndpAutoGenAddrEvent
rdnssC chan ndpRDNSSEvent
dnsslC chan ndpDNSSLEvent
@@ -189,32 +182,35 @@ func (n *ndpDispatcher) OnDuplicateAddressDetectionResult(nicID tcpip.NICID, add
}
}
-// Implements ipv6.NDPDispatcher.OnDefaultRouterDiscovered.
-func (n *ndpDispatcher) OnDefaultRouterDiscovered(nicID tcpip.NICID, addr tcpip.Address) bool {
- if c := n.routerC; c != nil {
- c <- ndpRouterEvent{
+// Implements ipv6.NDPDispatcher.OnOffLinkRouteUpdated.
+func (n *ndpDispatcher) OnOffLinkRouteUpdated(nicID tcpip.NICID, subnet tcpip.Subnet, router tcpip.Address, prf header.NDPRoutePreference) {
+ if c := n.offLinkRouteC; c != nil {
+ c <- ndpOffLinkRouteEvent{
nicID,
- addr,
+ subnet,
+ router,
+ prf,
true,
}
}
-
- return n.rememberRouter
}
-// Implements ipv6.NDPDispatcher.OnDefaultRouterInvalidated.
-func (n *ndpDispatcher) OnDefaultRouterInvalidated(nicID tcpip.NICID, addr tcpip.Address) {
- if c := n.routerC; c != nil {
- c <- ndpRouterEvent{
+// Implements ipv6.NDPDispatcher.OnOffLinkRouteInvalidated.
+func (n *ndpDispatcher) OnOffLinkRouteInvalidated(nicID tcpip.NICID, subnet tcpip.Subnet, router tcpip.Address) {
+ if c := n.offLinkRouteC; c != nil {
+ var prf header.NDPRoutePreference
+ c <- ndpOffLinkRouteEvent{
nicID,
- addr,
+ subnet,
+ router,
+ prf,
false,
}
}
}
// Implements ipv6.NDPDispatcher.OnOnLinkPrefixDiscovered.
-func (n *ndpDispatcher) OnOnLinkPrefixDiscovered(nicID tcpip.NICID, prefix tcpip.Subnet) bool {
+func (n *ndpDispatcher) OnOnLinkPrefixDiscovered(nicID tcpip.NICID, prefix tcpip.Subnet) {
if c := n.prefixC; c != nil {
c <- ndpPrefixEvent{
nicID,
@@ -222,8 +218,6 @@ func (n *ndpDispatcher) OnOnLinkPrefixDiscovered(nicID tcpip.NICID, prefix tcpip
true,
}
}
-
- return n.rememberPrefix
}
// Implements ipv6.NDPDispatcher.OnOnLinkPrefixInvalidated.
@@ -237,7 +231,7 @@ func (n *ndpDispatcher) OnOnLinkPrefixInvalidated(nicID tcpip.NICID, prefix tcpi
}
}
-func (n *ndpDispatcher) OnAutoGenAddress(nicID tcpip.NICID, addr tcpip.AddressWithPrefix) bool {
+func (n *ndpDispatcher) OnAutoGenAddress(nicID tcpip.NICID, addr tcpip.AddressWithPrefix) {
if c := n.autoGenAddrC; c != nil {
c <- ndpAutoGenAddrEvent{
nicID,
@@ -245,7 +239,6 @@ func (n *ndpDispatcher) OnAutoGenAddress(nicID tcpip.NICID, addr tcpip.AddressWi
newAddr,
}
}
- return true
}
func (n *ndpDispatcher) OnAutoGenAddressDeprecated(nicID tcpip.NICID, addr tcpip.AddressWithPrefix) {
@@ -1039,9 +1032,12 @@ func TestSetNDPConfigurations(t *testing.T) {
}
}
-// raBufWithOptsAndDHCPv6 returns a valid NDP Router Advertisement with options
-// and DHCPv6 configurations specified.
-func raBufWithOptsAndDHCPv6(ip tcpip.Address, rl uint16, managedAddress, otherConfigurations bool, optSer header.NDPOptionsSerializer) *stack.PacketBuffer {
+// raBuf returns a valid NDP Router Advertisement with options, router
+// preference and DHCPv6 configurations specified.
+func raBuf(ip tcpip.Address, rl uint16, managedAddress, otherConfigurations bool, prf header.NDPRoutePreference, optSer header.NDPOptionsSerializer) *stack.PacketBuffer {
+ const flagsByte = 1
+ const routerLifetimeOffset = 2
+
icmpSize := header.ICMPv6HeaderSize + header.NDPRAMinimumSize + optSer.Length()
hdr := buffer.NewPrependable(header.IPv6MinimumSize + icmpSize)
pkt := header.ICMPv6(hdr.Prepend(icmpSize))
@@ -1050,19 +1046,19 @@ func raBufWithOptsAndDHCPv6(ip tcpip.Address, rl uint16, managedAddress, otherCo
raPayload := pkt.MessageBody()
ra := header.NDPRouterAdvert(raPayload)
// Populate the Router Lifetime.
- binary.BigEndian.PutUint16(raPayload[2:], rl)
+ binary.BigEndian.PutUint16(raPayload[routerLifetimeOffset:], rl)
// Populate the Managed Address flag field.
if managedAddress {
- // The Managed Addresses flag field is the 7th bit of byte #1 (0-indexing)
- // of the RA payload.
- raPayload[1] |= 1 << 7
+ // The Managed Addresses flag field is the 7th bit of the flags byte.
+ raPayload[flagsByte] |= 1 << 7
}
// Populate the Other Configurations flag field.
if otherConfigurations {
- // The Other Configurations flag field is the 6th bit of byte #1
- // (0-indexing) of the RA payload.
- raPayload[1] |= 1 << 6
+ // The Other Configurations flag field is the 6th bit of the flags byte.
+ raPayload[flagsByte] |= 1 << 6
}
+ // The Prf field is held in the flags byte.
+ raPayload[flagsByte] |= byte(prf) << 3
opts := ra.Options()
opts.Serialize(optSer)
pkt.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{
@@ -1090,7 +1086,7 @@ func raBufWithOptsAndDHCPv6(ip tcpip.Address, rl uint16, managedAddress, otherCo
// Note, raBufWithOpts does not populate any of the RA fields other than the
// Router Lifetime.
func raBufWithOpts(ip tcpip.Address, rl uint16, optSer header.NDPOptionsSerializer) *stack.PacketBuffer {
- return raBufWithOptsAndDHCPv6(ip, rl, false, false, optSer)
+ return raBuf(ip, rl, false /* managedAddress */, false /* otherConfigurations */, 0 /* prf */, optSer)
}
// raBufWithDHCPv6 returns a valid NDP Router Advertisement with DHCPv6 related
@@ -1098,18 +1094,26 @@ func raBufWithOpts(ip tcpip.Address, rl uint16, optSer header.NDPOptionsSerializ
//
// Note, raBufWithDHCPv6 does not populate any of the RA fields other than the
// DHCPv6 related ones.
-func raBufWithDHCPv6(ip tcpip.Address, managedAddresses, otherConfiguratiosns bool) *stack.PacketBuffer {
- return raBufWithOptsAndDHCPv6(ip, 0, managedAddresses, otherConfiguratiosns, header.NDPOptionsSerializer{})
+func raBufWithDHCPv6(ip tcpip.Address, managedAddresses, otherConfigurations bool) *stack.PacketBuffer {
+ return raBuf(ip, 0, managedAddresses, otherConfigurations, 0 /* prf */, header.NDPOptionsSerializer{})
}
// raBuf returns a valid NDP Router Advertisement.
//
// Note, raBuf does not populate any of the RA fields other than the
// Router Lifetime.
-func raBuf(ip tcpip.Address, rl uint16) *stack.PacketBuffer {
+func raBufSimple(ip tcpip.Address, rl uint16) *stack.PacketBuffer {
return raBufWithOpts(ip, rl, header.NDPOptionsSerializer{})
}
+// raBufWithPrf returns a valid NDP Router Advertisement with a preference.
+//
+// Note, raBufWithPrf does not populate any of the RA fields other than the
+// Router Lifetime and Default Router Preference fields.
+func raBufWithPrf(ip tcpip.Address, rl uint16, prf header.NDPRoutePreference) *stack.PacketBuffer {
+ return raBuf(ip, rl, false /* managedAddress */, false /* otherConfigurations */, prf, header.NDPOptionsSerializer{})
+}
+
// raBufWithPI returns a valid NDP Router Advertisement with a single Prefix
// Information option.
//
@@ -1148,6 +1152,39 @@ func raBufWithPI(ip tcpip.Address, rl uint16, prefix tcpip.AddressWithPrefix, on
})
}
+// raBufWithRIO returns a valid NDP Router Advertisement with a single Route
+// Information option.
+//
+// All fields in the RA will be zero except the RIO option.
+func raBufWithRIO(t *testing.T, ip tcpip.Address, prefix tcpip.AddressWithPrefix, lifetimeSeconds uint32, prf header.NDPRoutePreference) *stack.PacketBuffer {
+ // buf will hold the route information option after the Type and Length
+ // fields.
+ //
+ // 2.3. Route Information Option
+ //
+ // 0 1 2 3
+ // 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
+ // +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
+ // | Type | Length | Prefix Length |Resvd|Prf|Resvd|
+ // +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
+ // | Route Lifetime |
+ // +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
+ // | Prefix (Variable Length) |
+ // . .
+ // . .
+ // +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
+ var buf [22]byte
+ buf[0] = uint8(prefix.PrefixLen)
+ buf[1] = byte(prf) << 3
+ binary.BigEndian.PutUint32(buf[2:], lifetimeSeconds)
+ if n := copy(buf[6:], prefix.Address); n != len(prefix.Address) {
+ t.Fatalf("got copy(...) = %d, want = %d", n, len(prefix.Address))
+ }
+ return raBufWithOpts(ip, 0 /* router lifetime */, header.NDPOptionsSerializer{
+ header.NDPRouteInformation(buf[:]),
+ })
+}
+
func TestDynamicConfigurationsDisabled(t *testing.T) {
const (
nicID = 1
@@ -1169,7 +1206,7 @@ func TestDynamicConfigurationsDisabled(t *testing.T) {
config: func(enable bool) ipv6.NDPConfigurations {
return ipv6.NDPConfigurations{DiscoverDefaultRouters: enable}
},
- ra: raBuf(llAddr2, 1000),
+ ra: raBufSimple(llAddr2, 1000),
},
{
name: "No Prefix Discovery",
@@ -1205,9 +1242,9 @@ func TestDynamicConfigurationsDisabled(t *testing.T) {
t.Run(fmt.Sprintf("HandleRAs(%s), Forwarding(%t), Enabled(%t)", handle, forwarding, enable), func(t *testing.T) {
ndpDisp := ndpDispatcher{
- routerC: make(chan ndpRouterEvent, 1),
- prefixC: make(chan ndpPrefixEvent, 1),
- autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1),
+ offLinkRouteC: make(chan ndpOffLinkRouteEvent, 1),
+ prefixC: make(chan ndpPrefixEvent, 1),
+ autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1),
}
ndpConfigs := test.config(enable)
ndpConfigs.HandleRAs = handle
@@ -1277,8 +1314,8 @@ func TestDynamicConfigurationsDisabled(t *testing.T) {
t.Errorf("got v6Stats.UnhandledRouterAdvertisements.Value() = %d, want = %d", got, want)
}
select {
- case e := <-ndpDisp.routerC:
- t.Errorf("unexpectedly discovered a router when configured not to: %#v", e)
+ case e := <-ndpDisp.offLinkRouteC:
+ t.Errorf("unexpectedly updated an off-link route when configured not to: %#v", e)
default:
}
select {
@@ -1304,10 +1341,8 @@ func boolToUint64(v bool) uint64 {
return 0
}
-// Check e to make sure that the event is for addr on nic with ID 1, and the
-// discovered flag set to discovered.
-func checkRouterEvent(e ndpRouterEvent, addr tcpip.Address, discovered bool) string {
- return cmp.Diff(ndpRouterEvent{nicID: 1, addr: addr, discovered: discovered}, e, cmp.AllowUnexported(e))
+func checkOffLinkRouteEvent(e ndpOffLinkRouteEvent, nicID tcpip.NICID, subnet tcpip.Subnet, router tcpip.Address, prf header.NDPRoutePreference, updated bool) string {
+ return cmp.Diff(ndpOffLinkRouteEvent{nicID: nicID, subnet: subnet, router: router, prf: prf, updated: updated}, e, cmp.AllowUnexported(e))
}
func testWithRAs(t *testing.T, f func(*testing.T, ipv6.HandleRAsConfiguration, bool)) {
@@ -1340,167 +1375,176 @@ func testWithRAs(t *testing.T, f func(*testing.T, ipv6.HandleRAsConfiguration, b
}
}
-// TestRouterDiscoveryDispatcherNoRemember tests that the stack does not
-// remember a discovered router when the dispatcher asks it not to.
-func TestRouterDiscoveryDispatcherNoRemember(t *testing.T) {
- ndpDisp := ndpDispatcher{
- routerC: make(chan ndpRouterEvent, 1),
- }
- e := channel.New(0, 1280, linkAddr1)
- clock := faketime.NewManualClock()
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
- NDPConfigs: ipv6.NDPConfigurations{
- HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled,
- DiscoverDefaultRouters: true,
- },
- NDPDisp: &ndpDisp,
- })},
- Clock: clock,
- })
+func TestOffLinkRouteDiscovery(t *testing.T) {
+ const nicID = 1
- if err := s.CreateNIC(1, e); err != nil {
- t.Fatalf("CreateNIC(1) = %s", err)
- }
+ moreSpecificPrefix := tcpip.AddressWithPrefix{Address: testutil.MustParse6("a00::"), PrefixLen: 16}
+ tests := []struct {
+ name string
- // Receive an RA for a router we should not remember.
- const lifetimeSeconds = 1
- e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr2, lifetimeSeconds))
- select {
- case e := <-ndpDisp.routerC:
- if diff := checkRouterEvent(e, llAddr2, true); diff != "" {
- t.Errorf("router event mismatch (-want +got):\n%s", diff)
- }
- default:
- t.Fatal("expected router discovery event")
- }
+ discoverDefaultRouters bool
+ discoverMoreSpecificRoutes bool
- // Wait for the invalidation time plus some buffer to make sure we do
- // not actually receive any invalidation events as we should not have
- // remembered the router in the first place.
- clock.Advance(lifetimeSeconds * time.Second)
- select {
- case <-ndpDisp.routerC:
- t.Fatal("should not have received any router events")
- default:
+ dest tcpip.Subnet
+ ra func(*testing.T, tcpip.Address, uint16, header.NDPRoutePreference) *stack.PacketBuffer
+ }{
+ {
+ name: "Default router discovery",
+ discoverDefaultRouters: true,
+ discoverMoreSpecificRoutes: false,
+ dest: header.IPv6EmptySubnet,
+ ra: func(_ *testing.T, router tcpip.Address, lifetimeSeconds uint16, prf header.NDPRoutePreference) *stack.PacketBuffer {
+ return raBufWithPrf(router, lifetimeSeconds, prf)
+ },
+ },
+ {
+ name: "More-specific route discovery",
+ discoverDefaultRouters: false,
+ discoverMoreSpecificRoutes: true,
+ dest: moreSpecificPrefix.Subnet(),
+ ra: func(t *testing.T, router tcpip.Address, lifetimeSeconds uint16, prf header.NDPRoutePreference) *stack.PacketBuffer {
+ return raBufWithRIO(t, router, moreSpecificPrefix, uint32(lifetimeSeconds), prf)
+ },
+ },
}
-}
-func TestRouterDiscovery(t *testing.T) {
- testWithRAs(t, func(t *testing.T, handleRAs ipv6.HandleRAsConfiguration, forwarding bool) {
- ndpDisp := ndpDispatcher{
- routerC: make(chan ndpRouterEvent, 1),
- rememberRouter: true,
- }
- e := channel.New(0, 1280, linkAddr1)
- clock := faketime.NewManualClock()
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
- NDPConfigs: ipv6.NDPConfigurations{
- HandleRAs: handleRAs,
- DiscoverDefaultRouters: true,
- },
- NDPDisp: &ndpDisp,
- })},
- Clock: clock,
- })
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ testWithRAs(t, func(t *testing.T, handleRAs ipv6.HandleRAsConfiguration, forwarding bool) {
+ ndpDisp := ndpDispatcher{
+ offLinkRouteC: make(chan ndpOffLinkRouteEvent, 1),
+ }
+ e := channel.New(0, 1280, linkAddr1)
+ clock := faketime.NewManualClock()
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
+ NDPConfigs: ipv6.NDPConfigurations{
+ HandleRAs: handleRAs,
+ DiscoverDefaultRouters: test.discoverDefaultRouters,
+ DiscoverMoreSpecificRoutes: test.discoverMoreSpecificRoutes,
+ },
+ NDPDisp: &ndpDisp,
+ })},
+ Clock: clock,
+ })
- expectRouterEvent := func(addr tcpip.Address, discovered bool) {
- t.Helper()
+ expectOffLinkRouteEvent := func(addr tcpip.Address, prf header.NDPRoutePreference, updated bool) {
+ t.Helper()
- select {
- case e := <-ndpDisp.routerC:
- if diff := checkRouterEvent(e, addr, discovered); diff != "" {
- t.Errorf("router event mismatch (-want +got):\n%s", diff)
+ select {
+ case e := <-ndpDisp.offLinkRouteC:
+ if diff := checkOffLinkRouteEvent(e, nicID, test.dest, addr, prf, updated); diff != "" {
+ t.Errorf("off-link route event mismatch (-want +got):\n%s", diff)
+ }
+ default:
+ t.Fatal("expected router discovery event")
+ }
}
- default:
- t.Fatal("expected router discovery event")
- }
- }
- expectAsyncRouterInvalidationEvent := func(addr tcpip.Address, timeout time.Duration) {
- t.Helper()
+ expectAsyncOffLinkRouteInvalidationEvent := func(addr tcpip.Address, timeout time.Duration) {
+ t.Helper()
- clock.Advance(timeout)
- select {
- case e := <-ndpDisp.routerC:
- if diff := checkRouterEvent(e, addr, false); diff != "" {
- t.Errorf("router event mismatch (-want +got):\n%s", diff)
+ clock.Advance(timeout)
+ select {
+ case e := <-ndpDisp.offLinkRouteC:
+ var prf header.NDPRoutePreference
+ if diff := checkOffLinkRouteEvent(e, nicID, test.dest, addr, prf, false); diff != "" {
+ t.Errorf("off-link route event mismatch (-want +got):\n%s", diff)
+ }
+ default:
+ t.Fatal("timed out waiting for router discovery event")
+ }
}
- default:
- t.Fatal("timed out waiting for router discovery event")
- }
- }
- if err := s.SetForwardingDefaultAndAllNICs(ipv6.ProtocolNumber, forwarding); err != nil {
- t.Fatalf("SetForwardingDefaultAndAllNICs(%d, %t): %s", ipv6.ProtocolNumber, forwarding, err)
- }
-
- if err := s.CreateNIC(1, e); err != nil {
- t.Fatalf("CreateNIC(1) = %s", err)
- }
+ if err := s.SetForwardingDefaultAndAllNICs(ipv6.ProtocolNumber, forwarding); err != nil {
+ t.Fatalf("SetForwardingDefaultAndAllNICs(%d, %t): %s", ipv6.ProtocolNumber, forwarding, err)
+ }
- // Rx an RA from lladdr2 with zero lifetime. It should not be
- // remembered.
- e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr2, 0))
- select {
- case <-ndpDisp.routerC:
- t.Fatal("unexpectedly discovered a router with 0 lifetime")
- default:
- }
+ if err := s.CreateNIC(nicID, e); err != nil {
+ t.Fatalf("CreateNIC(%d, _): %s", nicID, err)
+ }
- // Rx an RA from lladdr2 with a huge lifetime.
- e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr2, 1000))
- expectRouterEvent(llAddr2, true)
+ // Rx an RA from lladdr2 with zero lifetime. It should not be
+ // remembered.
+ e.InjectInbound(header.IPv6ProtocolNumber, test.ra(t, llAddr2, 0, header.MediumRoutePreference))
+ select {
+ case <-ndpDisp.offLinkRouteC:
+ t.Fatal("unexpectedly updated an off-link route with 0 lifetime")
+ default:
+ }
- // Rx an RA from another router (lladdr3) with non-zero lifetime.
- const l3LifetimeSeconds = 6
- e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr3, l3LifetimeSeconds))
- expectRouterEvent(llAddr3, true)
+ // Discover an off-link route through llAddr2.
+ e.InjectInbound(header.IPv6ProtocolNumber, test.ra(t, llAddr2, 1000, header.ReservedRoutePreference))
+ if test.discoverMoreSpecificRoutes {
+ // The reserved value is considered invalid with more-specific route
+ // discovery so we inject the same packet but with the default
+ // (medium) preference value.
+ select {
+ case <-ndpDisp.offLinkRouteC:
+ t.Fatal("unexpectedly updated an off-link route with a reserved preference value")
+ default:
+ }
+ e.InjectInbound(header.IPv6ProtocolNumber, test.ra(t, llAddr2, 1000, header.MediumRoutePreference))
+ }
+ expectOffLinkRouteEvent(llAddr2, header.MediumRoutePreference, true)
+
+ // Rx an RA from another router (lladdr3) with non-zero lifetime and
+ // non-default preference value.
+ const l3LifetimeSeconds = 6
+ e.InjectInbound(header.IPv6ProtocolNumber, test.ra(t, llAddr3, l3LifetimeSeconds, header.HighRoutePreference))
+ expectOffLinkRouteEvent(llAddr3, header.HighRoutePreference, true)
+
+ // Rx an RA from lladdr2 with lesser lifetime and default (medium)
+ // preference value.
+ const l2LifetimeSeconds = 2
+ e.InjectInbound(header.IPv6ProtocolNumber, test.ra(t, llAddr2, l2LifetimeSeconds, header.MediumRoutePreference))
+ select {
+ case <-ndpDisp.offLinkRouteC:
+ t.Fatal("should not receive a off-link route event when updating lifetimes for known routers")
+ default:
+ }
- // Rx an RA from lladdr2 with lesser lifetime.
- const l2LifetimeSeconds = 2
- e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr2, l2LifetimeSeconds))
- select {
- case <-ndpDisp.routerC:
- t.Fatal("Should not receive a router event when updating lifetimes for known routers")
- default:
- }
+ // Rx an RA from lladdr2 with a different preference.
+ e.InjectInbound(header.IPv6ProtocolNumber, test.ra(t, llAddr2, l2LifetimeSeconds, header.LowRoutePreference))
+ expectOffLinkRouteEvent(llAddr2, header.LowRoutePreference, true)
- // Wait for lladdr2's router invalidation job to execute. The lifetime
- // of the router should have been updated to the most recent (smaller)
- // lifetime.
- //
- // Wait for the normal lifetime plus an extra bit for the
- // router to get invalidated. If we don't get an invalidation
- // event after this time, then something is wrong.
- expectAsyncRouterInvalidationEvent(llAddr2, l2LifetimeSeconds*time.Second)
-
- // Rx an RA from lladdr2 with huge lifetime.
- e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr2, 1000))
- expectRouterEvent(llAddr2, true)
-
- // Rx an RA from lladdr2 with zero lifetime. It should be invalidated.
- e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr2, 0))
- expectRouterEvent(llAddr2, false)
-
- // Wait for lladdr3's router invalidation job to execute. The lifetime
- // of the router should have been updated to the most recent (smaller)
- // lifetime.
- //
- // Wait for the normal lifetime plus an extra bit for the
- // router to get invalidated. If we don't get an invalidation
- // event after this time, then something is wrong.
- expectAsyncRouterInvalidationEvent(llAddr3, l3LifetimeSeconds*time.Second)
- })
+ // Wait for lladdr2's router invalidation job to execute. The lifetime
+ // of the router should have been updated to the most recent (smaller)
+ // lifetime.
+ //
+ // Wait for the normal lifetime plus an extra bit for the
+ // router to get invalidated. If we don't get an invalidation
+ // event after this time, then something is wrong.
+ expectAsyncOffLinkRouteInvalidationEvent(llAddr2, l2LifetimeSeconds*time.Second)
+
+ // Rx an RA from lladdr2 with huge lifetime.
+ e.InjectInbound(header.IPv6ProtocolNumber, test.ra(t, llAddr2, 1000, header.MediumRoutePreference))
+ expectOffLinkRouteEvent(llAddr2, header.MediumRoutePreference, true)
+
+ // Rx an RA from lladdr2 with zero lifetime. It should be invalidated.
+ e.InjectInbound(header.IPv6ProtocolNumber, test.ra(t, llAddr2, 0, header.MediumRoutePreference))
+ expectOffLinkRouteEvent(llAddr2, header.MediumRoutePreference, false)
+
+ // Wait for lladdr3's router invalidation job to execute. The lifetime
+ // of the router should have been updated to the most recent (smaller)
+ // lifetime.
+ //
+ // Wait for the normal lifetime plus an extra bit for the
+ // router to get invalidated. If we don't get an invalidation
+ // event after this time, then something is wrong.
+ expectAsyncOffLinkRouteInvalidationEvent(llAddr3, l3LifetimeSeconds*time.Second)
+ })
+ })
+ }
}
// TestRouterDiscoveryMaxRouters tests that only
-// ipv6.MaxDiscoveredDefaultRouters discovered routers are remembered.
+// ipv6.MaxDiscoveredOffLinkRoutes discovered routers are remembered.
func TestRouterDiscoveryMaxRouters(t *testing.T) {
+ const nicID = 1
+
ndpDisp := ndpDispatcher{
- routerC: make(chan ndpRouterEvent, 1),
- rememberRouter: true,
+ offLinkRouteC: make(chan ndpOffLinkRouteEvent, 1),
}
e := channel.New(0, 1280, linkAddr1)
s := stack.New(stack.Options{
@@ -1513,23 +1557,23 @@ func TestRouterDiscoveryMaxRouters(t *testing.T) {
})},
})
- if err := s.CreateNIC(1, e); err != nil {
- t.Fatalf("CreateNIC(1) = %s", err)
+ if err := s.CreateNIC(nicID, e); err != nil {
+ t.Fatalf("CreateNIC(%d, _): %s", nicID, err)
}
// Receive an RA from 2 more than the max number of discovered routers.
- for i := 1; i <= ipv6.MaxDiscoveredDefaultRouters+2; i++ {
+ for i := 1; i <= ipv6.MaxDiscoveredOffLinkRoutes+2; i++ {
linkAddr := []byte{2, 2, 3, 4, 5, 0}
linkAddr[5] = byte(i)
llAddr := header.LinkLocalAddr(tcpip.LinkAddress(linkAddr))
- e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr, 5))
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufSimple(llAddr, 5))
- if i <= ipv6.MaxDiscoveredDefaultRouters {
+ if i <= ipv6.MaxDiscoveredOffLinkRoutes {
select {
- case e := <-ndpDisp.routerC:
- if diff := checkRouterEvent(e, llAddr, true); diff != "" {
- t.Errorf("router event mismatch (-want +got):\n%s", diff)
+ case e := <-ndpDisp.offLinkRouteC:
+ if diff := checkOffLinkRouteEvent(e, nicID, header.IPv6EmptySubnet, llAddr, header.MediumRoutePreference, true); diff != "" {
+ t.Errorf("off-link route event mismatch (-want +got):\n%s", diff)
}
default:
t.Fatal("expected router discovery event")
@@ -1537,7 +1581,7 @@ func TestRouterDiscoveryMaxRouters(t *testing.T) {
} else {
select {
- case <-ndpDisp.routerC:
+ case <-ndpDisp.offLinkRouteC:
t.Fatal("should not have discovered a new router after we already discovered the max number of routers")
default:
}
@@ -1551,54 +1595,6 @@ func checkPrefixEvent(e ndpPrefixEvent, prefix tcpip.Subnet, discovered bool) st
return cmp.Diff(ndpPrefixEvent{nicID: 1, prefix: prefix, discovered: discovered}, e, cmp.AllowUnexported(e))
}
-// TestPrefixDiscoveryDispatcherNoRemember tests that the stack does not
-// remember a discovered on-link prefix when the dispatcher asks it not to.
-func TestPrefixDiscoveryDispatcherNoRemember(t *testing.T) {
- prefix, subnet, _ := prefixSubnetAddr(0, "")
-
- ndpDisp := ndpDispatcher{
- prefixC: make(chan ndpPrefixEvent, 1),
- }
- e := channel.New(0, 1280, linkAddr1)
- clock := faketime.NewManualClock()
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
- NDPConfigs: ipv6.NDPConfigurations{
- HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled,
- DiscoverOnLinkPrefixes: true,
- },
- NDPDisp: &ndpDisp,
- })},
- Clock: clock,
- })
-
- if err := s.CreateNIC(1, e); err != nil {
- t.Fatalf("CreateNIC(1) = %s", err)
- }
-
- // Receive an RA with prefix that we should not remember.
- const lifetimeSeconds = 1
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, false, lifetimeSeconds, 0))
- select {
- case e := <-ndpDisp.prefixC:
- if diff := checkPrefixEvent(e, subnet, true); diff != "" {
- t.Errorf("prefix event mismatch (-want +got):\n%s", diff)
- }
- default:
- t.Fatal("expected prefix discovery event")
- }
-
- // Wait for the invalidation time plus some buffer to make sure we do
- // not actually receive any invalidation events as we should not have
- // remembered the prefix in the first place.
- clock.Advance(lifetimeSeconds * time.Second)
- select {
- case <-ndpDisp.prefixC:
- t.Fatal("should not have received any prefix events")
- default:
- }
-}
-
func TestPrefixDiscovery(t *testing.T) {
prefix1, subnet1, _ := prefixSubnetAddr(0, "")
prefix2, subnet2, _ := prefixSubnetAddr(1, "")
@@ -1606,8 +1602,7 @@ func TestPrefixDiscovery(t *testing.T) {
testWithRAs(t, func(t *testing.T, handleRAs ipv6.HandleRAsConfiguration, forwarding bool) {
ndpDisp := ndpDispatcher{
- prefixC: make(chan ndpPrefixEvent, 1),
- rememberPrefix: true,
+ prefixC: make(chan ndpPrefixEvent, 1),
}
e := channel.New(0, 1280, linkAddr1)
clock := faketime.NewManualClock()
@@ -1697,17 +1692,6 @@ func TestPrefixDiscovery(t *testing.T) {
}
func TestPrefixDiscoveryWithInfiniteLifetime(t *testing.T) {
- // Update the infinite lifetime value to a smaller value so we can test
- // that when we receive a PI with such a lifetime value, we do not
- // invalidate the prefix.
- const testInfiniteLifetimeSeconds = 2
- const testInfiniteLifetime = testInfiniteLifetimeSeconds * time.Second
- saved := header.NDPInfiniteLifetime
- header.NDPInfiniteLifetime = testInfiniteLifetime
- defer func() {
- header.NDPInfiniteLifetime = saved
- }()
-
prefix := tcpip.AddressWithPrefix{
Address: testutil.MustParse6("102:304:506:708::"),
PrefixLen: 64,
@@ -1715,8 +1699,7 @@ func TestPrefixDiscoveryWithInfiniteLifetime(t *testing.T) {
subnet := prefix.Subnet()
ndpDisp := ndpDispatcher{
- prefixC: make(chan ndpPrefixEvent, 1),
- rememberPrefix: true,
+ prefixC: make(chan ndpPrefixEvent, 1),
}
e := channel.New(0, 1280, linkAddr1)
clock := faketime.NewManualClock()
@@ -1750,9 +1733,9 @@ func TestPrefixDiscoveryWithInfiniteLifetime(t *testing.T) {
// Receive an RA with prefix in an NDP Prefix Information option (PI)
// with infinite valid lifetime which should not get invalidated.
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, false, testInfiniteLifetimeSeconds, 0))
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, false, infiniteLifetimeSeconds, 0))
expectPrefixEvent(subnet, true)
- clock.Advance(testInfiniteLifetime)
+ clock.Advance(header.NDPInfiniteLifetime)
select {
case <-ndpDisp.prefixC:
t.Fatal("unexpectedly invalidated a prefix with infinite lifetime")
@@ -1760,9 +1743,8 @@ func TestPrefixDiscoveryWithInfiniteLifetime(t *testing.T) {
}
// Receive an RA with finite lifetime.
- // The prefix should get invalidated after 1s.
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, false, testInfiniteLifetimeSeconds-1, 0))
- clock.Advance(testInfiniteLifetime)
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, false, infiniteLifetimeSeconds-1, 0))
+ clock.Advance(header.NDPInfiniteLifetime - time.Second)
select {
case e := <-ndpDisp.prefixC:
if diff := checkPrefixEvent(e, subnet, false); diff != "" {
@@ -1773,23 +1755,13 @@ func TestPrefixDiscoveryWithInfiniteLifetime(t *testing.T) {
}
// Receive an RA with finite lifetime.
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, false, testInfiniteLifetimeSeconds-1, 0))
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, false, infiniteLifetimeSeconds-1, 0))
expectPrefixEvent(subnet, true)
// Receive an RA with prefix with an infinite lifetime.
// The prefix should not be invalidated.
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, false, testInfiniteLifetimeSeconds, 0))
- clock.Advance(testInfiniteLifetime)
- select {
- case <-ndpDisp.prefixC:
- t.Fatal("unexpectedly invalidated a prefix with infinite lifetime")
- default:
- }
-
- // Receive an RA with a prefix with a lifetime value greater than the
- // set infinite lifetime value.
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, false, testInfiniteLifetimeSeconds+1, 0))
- clock.Advance((testInfiniteLifetimeSeconds + 1) * time.Second)
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, false, infiniteLifetimeSeconds, 0))
+ clock.Advance(header.NDPInfiniteLifetime)
select {
case <-ndpDisp.prefixC:
t.Fatal("unexpectedly invalidated a prefix with infinite lifetime")
@@ -1806,8 +1778,7 @@ func TestPrefixDiscoveryWithInfiniteLifetime(t *testing.T) {
// ipv6.MaxDiscoveredOnLinkPrefixes discovered on-link prefixes are remembered.
func TestPrefixDiscoveryMaxOnLinkPrefixes(t *testing.T) {
ndpDisp := ndpDispatcher{
- prefixC: make(chan ndpPrefixEvent, ipv6.MaxDiscoveredOnLinkPrefixes+3),
- rememberPrefix: true,
+ prefixC: make(chan ndpPrefixEvent, ipv6.MaxDiscoveredOnLinkPrefixes+3),
}
e := channel.New(0, 1280, linkAddr1)
s := stack.New(stack.Options{
@@ -1884,17 +1855,12 @@ func checkAutoGenAddrEvent(e ndpAutoGenAddrEvent, addr tcpip.AddressWithPrefix,
return cmp.Diff(ndpAutoGenAddrEvent{nicID: 1, addr: addr, eventType: eventType}, e, cmp.AllowUnexported(e))
}
+const minVLSeconds = uint32(ipv6.MinPrefixInformationValidLifetimeForUpdate / time.Second)
+const infiniteLifetimeSeconds = uint32(header.NDPInfiniteLifetime / time.Second)
+
// TestAutoGenAddr tests that an address is properly generated and invalidated
// when configured to do so.
func TestAutoGenAddr(t *testing.T) {
- const newMinVL = 2
- newMinVLDuration := newMinVL * time.Second
- saved := ipv6.MinPrefixInformationValidLifetimeForUpdate
- defer func() {
- ipv6.MinPrefixInformationValidLifetimeForUpdate = saved
- }()
- ipv6.MinPrefixInformationValidLifetimeForUpdate = newMinVLDuration
-
prefix1, _, addr1 := prefixSubnetAddr(0, linkAddr1)
prefix2, _, addr2 := prefixSubnetAddr(1, linkAddr1)
@@ -1903,6 +1869,7 @@ func TestAutoGenAddr(t *testing.T) {
autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1),
}
e := channel.New(0, 1280, linkAddr1)
+ clock := faketime.NewManualClock()
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
NDPConfigs: ipv6.NDPConfigurations{
@@ -1911,6 +1878,7 @@ func TestAutoGenAddr(t *testing.T) {
},
NDPDisp: &ndpDisp,
})},
+ Clock: clock,
})
if err := s.SetForwardingDefaultAndAllNICs(ipv6.ProtocolNumber, forwarding); err != nil {
@@ -1960,8 +1928,9 @@ func TestAutoGenAddr(t *testing.T) {
default:
}
- // Receive an RA with prefix2 in a PI.
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, 100, 0))
+ // Receive an RA with prefix2 in a PI with a valid lifetime that exceeds
+ // the minimum.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, minVLSeconds+1, 0))
expectAutoGenAddrEvent(addr2, newAddr)
if !containsV6Addr(s.NICInfo()[1].ProtocolAddresses, addr1) {
t.Fatalf("Should have %s in the list of addresses", addr1)
@@ -1971,7 +1940,7 @@ func TestAutoGenAddr(t *testing.T) {
}
// Refresh valid lifetime for addr of prefix1.
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, newMinVL, 0))
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 100, 0))
select {
case <-ndpDisp.autoGenAddrC:
t.Fatal("unexpectedly auto-generated an address when we already have an address for a prefix")
@@ -1979,12 +1948,13 @@ func TestAutoGenAddr(t *testing.T) {
}
// Wait for addr of prefix1 to be invalidated.
+ clock.Advance(ipv6.MinPrefixInformationValidLifetimeForUpdate)
select {
case e := <-ndpDisp.autoGenAddrC:
if diff := checkAutoGenAddrEvent(e, addr1, invalidatedAddr); diff != "" {
t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
}
- case <-time.After(newMinVLDuration + defaultAsyncPositiveEventTimeout):
+ default:
t.Fatal("timed out waiting for addr auto gen event")
}
if containsV6Addr(s.NICInfo()[1].ProtocolAddresses, addr1) {
@@ -2014,20 +1984,7 @@ func addressCheck(addrs []tcpip.ProtocolAddress, containList, notContainList []t
// TestAutoGenTempAddr tests that temporary SLAAC addresses are generated when
// configured to do so as part of IPv6 Privacy Extensions.
func TestAutoGenTempAddr(t *testing.T) {
- const (
- nicID = 1
- newMinVL = 5
- newMinVLDuration = newMinVL * time.Second
- )
-
- savedMinPrefixInformationValidLifetimeForUpdate := ipv6.MinPrefixInformationValidLifetimeForUpdate
- savedMaxDesync := ipv6.MaxDesyncFactor
- defer func() {
- ipv6.MinPrefixInformationValidLifetimeForUpdate = savedMinPrefixInformationValidLifetimeForUpdate
- ipv6.MaxDesyncFactor = savedMaxDesync
- }()
- ipv6.MinPrefixInformationValidLifetimeForUpdate = newMinVLDuration
- ipv6.MaxDesyncFactor = time.Nanosecond
+ const nicID = 1
prefix1, _, addr1 := prefixSubnetAddr(0, linkAddr1)
prefix2, _, addr2 := prefixSubnetAddr(1, linkAddr1)
@@ -2047,218 +2004,211 @@ func TestAutoGenTempAddr(t *testing.T) {
},
}
- // This Run will not return until the parallel tests finish.
- //
- // We need this because we need to do some teardown work after the
- // parallel tests complete.
- //
- // See https://godoc.org/testing#hdr-Subtests_and_Sub_benchmarks for
- // more details.
- t.Run("group", func(t *testing.T) {
- for i, test := range tests {
- i := i
- test := test
-
- t.Run(test.name, func(t *testing.T) {
- t.Parallel()
-
- seed := []byte{uint8(i)}
- var tempIIDHistory [header.IIDSize]byte
- header.InitialTempIID(tempIIDHistory[:], seed, nicID)
- newTempAddr := func(stableAddr tcpip.Address) tcpip.AddressWithPrefix {
- return header.GenerateTempIPv6SLAACAddr(tempIIDHistory[:], stableAddr)
- }
-
- ndpDisp := ndpDispatcher{
- dadC: make(chan ndpDADEvent, 2),
- autoGenAddrC: make(chan ndpAutoGenAddrEvent, 2),
- }
- e := channel.New(0, 1280, linkAddr1)
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
- DADConfigs: stack.DADConfigurations{
- DupAddrDetectTransmits: test.dupAddrTransmits,
- RetransmitTimer: test.retransmitTimer,
- },
- NDPConfigs: ipv6.NDPConfigurations{
- HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled,
- AutoGenGlobalAddresses: true,
- AutoGenTempGlobalAddresses: true,
- },
- NDPDisp: &ndpDisp,
- TempIIDSeed: seed,
- })},
- })
-
- if err := s.CreateNIC(nicID, e); err != nil {
- t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
- }
-
- expectAutoGenAddrEvent := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) {
- t.Helper()
-
- select {
- case e := <-ndpDisp.autoGenAddrC:
- if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" {
- t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
- }
- default:
- t.Fatal("expected addr auto gen event")
- }
- }
-
- expectAutoGenAddrEventAsync := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) {
- t.Helper()
+ for i, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ seed := []byte{uint8(i)}
+ var tempIIDHistory [header.IIDSize]byte
+ header.InitialTempIID(tempIIDHistory[:], seed, nicID)
+ newTempAddr := func(stableAddr tcpip.Address) tcpip.AddressWithPrefix {
+ return header.GenerateTempIPv6SLAACAddr(tempIIDHistory[:], stableAddr)
+ }
- select {
- case e := <-ndpDisp.autoGenAddrC:
- if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" {
- t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
- }
- case <-time.After(defaultAsyncPositiveEventTimeout):
- t.Fatal("timed out waiting for addr auto gen event")
- }
- }
+ ndpDisp := ndpDispatcher{
+ dadC: make(chan ndpDADEvent, 2),
+ autoGenAddrC: make(chan ndpAutoGenAddrEvent, 2),
+ }
+ e := channel.New(0, 1280, linkAddr1)
+ clock := faketime.NewManualClock()
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
+ DADConfigs: stack.DADConfigurations{
+ DupAddrDetectTransmits: test.dupAddrTransmits,
+ RetransmitTimer: test.retransmitTimer,
+ },
+ NDPConfigs: ipv6.NDPConfigurations{
+ HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled,
+ AutoGenGlobalAddresses: true,
+ AutoGenTempGlobalAddresses: true,
+ MaxTempAddrValidLifetime: 2 * ipv6.MinPrefixInformationValidLifetimeForUpdate,
+ MaxTempAddrPreferredLifetime: 2 * ipv6.MinPrefixInformationValidLifetimeForUpdate,
+ },
+ NDPDisp: &ndpDisp,
+ TempIIDSeed: seed,
+ })},
+ Clock: clock,
+ })
- expectDADEventAsync := func(addr tcpip.Address) {
- t.Helper()
+ if err := s.CreateNIC(nicID, e); err != nil {
+ t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
+ }
- select {
- case e := <-ndpDisp.dadC:
- if diff := checkDADEvent(e, nicID, addr, &stack.DADSucceeded{}); diff != "" {
- t.Errorf("DAD event mismatch (-want +got):\n%s", diff)
- }
- case <-time.After(time.Duration(test.dupAddrTransmits)*test.retransmitTimer + defaultAsyncPositiveEventTimeout):
- t.Fatal("timed out waiting for DAD event")
- }
- }
+ expectAutoGenAddrEvent := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) {
+ t.Helper()
- // Receive an RA with prefix1 in an NDP Prefix Information option (PI)
- // with zero valid lifetime.
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 0, 0))
select {
case e := <-ndpDisp.autoGenAddrC:
- t.Fatalf("unexpectedly auto-generated an address with 0 lifetime; event = %+v", e)
+ if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" {
+ t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
+ }
default:
+ t.Fatal("expected addr auto gen event")
}
+ }
+
+ expectAutoGenAddrEventAsync := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) {
+ t.Helper()
- // Receive an RA with prefix1 in an NDP Prefix Information option (PI)
- // with non-zero valid lifetime.
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 100, 0))
- expectAutoGenAddrEvent(addr1, newAddr)
- expectDADEventAsync(addr1.Address)
+ clock.RunImmediatelyScheduledJobs()
select {
case e := <-ndpDisp.autoGenAddrC:
- t.Fatalf("unexpectedly got an auto gen addr event = %+v", e)
+ if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" {
+ t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
+ }
default:
+ t.Fatal("timed out waiting for addr auto gen event")
}
- if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr1}, nil); mismatch != "" {
- t.Fatal(mismatch)
- }
+ }
- // Receive an RA with prefix1 in an NDP Prefix Information option (PI)
- // with non-zero valid & preferred lifetimes.
- tempAddr1 := newTempAddr(addr1.Address)
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 100, 100))
- expectAutoGenAddrEvent(tempAddr1, newAddr)
- expectDADEventAsync(tempAddr1.Address)
- if mismatch := addressCheck(s.NICInfo()[1].ProtocolAddresses, []tcpip.AddressWithPrefix{addr1, tempAddr1}, nil); mismatch != "" {
- t.Fatal(mismatch)
- }
+ expectDADEventAsync := func(addr tcpip.Address) {
+ t.Helper()
- // Receive an RA with prefix2 in an NDP Prefix Information option (PI)
- // with preferred lifetime > valid lifetime
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 5, 6))
+ clock.Advance(time.Duration(test.dupAddrTransmits) * test.retransmitTimer)
select {
- case e := <-ndpDisp.autoGenAddrC:
- t.Fatalf("unexpectedly auto-generated an address with preferred lifetime > valid lifetime; event = %+v", e)
+ case e := <-ndpDisp.dadC:
+ if diff := checkDADEvent(e, nicID, addr, &stack.DADSucceeded{}); diff != "" {
+ t.Errorf("DAD event mismatch (-want +got):\n%s", diff)
+ }
default:
+ t.Fatal("timed out waiting for DAD event")
}
- if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr1, tempAddr1}, nil); mismatch != "" {
- t.Fatal(mismatch)
- }
+ }
- // Receive an RA with prefix2 in a PI w/ non-zero valid and preferred
- // lifetimes.
- tempAddr2 := newTempAddr(addr2.Address)
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, 100, 100))
- expectAutoGenAddrEvent(addr2, newAddr)
- expectDADEventAsync(addr2.Address)
- expectAutoGenAddrEventAsync(tempAddr2, newAddr)
- expectDADEventAsync(tempAddr2.Address)
- if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr1, tempAddr1, addr2, tempAddr2}, nil); mismatch != "" {
- t.Fatal(mismatch)
- }
+ // Receive an RA with prefix1 in an NDP Prefix Information option (PI)
+ // with zero valid lifetime.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 0, 0))
+ select {
+ case e := <-ndpDisp.autoGenAddrC:
+ t.Fatalf("unexpectedly auto-generated an address with 0 lifetime; event = %+v", e)
+ default:
+ }
- // Deprecate prefix1.
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 100, 0))
- expectAutoGenAddrEvent(addr1, deprecatedAddr)
- expectAutoGenAddrEvent(tempAddr1, deprecatedAddr)
- if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr1, tempAddr1, addr2, tempAddr2}, nil); mismatch != "" {
- t.Fatal(mismatch)
- }
+ // Receive an RA with prefix1 in an NDP Prefix Information option (PI)
+ // with non-zero valid lifetime.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 100, 0))
+ expectAutoGenAddrEvent(addr1, newAddr)
+ expectDADEventAsync(addr1.Address)
+ select {
+ case e := <-ndpDisp.autoGenAddrC:
+ t.Fatalf("unexpectedly got an auto gen addr event = %+v", e)
+ default:
+ }
+ if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr1}, nil); mismatch != "" {
+ t.Fatal(mismatch)
+ }
- // Refresh lifetimes for prefix1.
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 100, 100))
- if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr1, tempAddr1, addr2, tempAddr2}, nil); mismatch != "" {
- t.Fatal(mismatch)
- }
+ // Receive an RA with prefix1 in an NDP Prefix Information option (PI)
+ // with non-zero valid & preferred lifetimes.
+ tempAddr1 := newTempAddr(addr1.Address)
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 100, 100))
+ expectAutoGenAddrEvent(tempAddr1, newAddr)
+ expectDADEventAsync(tempAddr1.Address)
+ if mismatch := addressCheck(s.NICInfo()[1].ProtocolAddresses, []tcpip.AddressWithPrefix{addr1, tempAddr1}, nil); mismatch != "" {
+ t.Fatal(mismatch)
+ }
- // Reduce valid lifetime and deprecate addresses of prefix1.
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, newMinVL, 0))
- expectAutoGenAddrEvent(addr1, deprecatedAddr)
- expectAutoGenAddrEvent(tempAddr1, deprecatedAddr)
- if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr1, tempAddr1, addr2, tempAddr2}, nil); mismatch != "" {
- t.Fatal(mismatch)
- }
+ // Receive an RA with prefix2 in an NDP Prefix Information option (PI)
+ // with preferred lifetime > valid lifetime
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 5, 6))
+ select {
+ case e := <-ndpDisp.autoGenAddrC:
+ t.Fatalf("unexpectedly auto-generated an address with preferred lifetime > valid lifetime; event = %+v", e)
+ default:
+ }
+ if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr1, tempAddr1}, nil); mismatch != "" {
+ t.Fatal(mismatch)
+ }
- // Wait for addrs of prefix1 to be invalidated. They should be
- // invalidated at the same time.
- select {
- case e := <-ndpDisp.autoGenAddrC:
- var nextAddr tcpip.AddressWithPrefix
- if e.addr == addr1 {
- if diff := checkAutoGenAddrEvent(e, addr1, invalidatedAddr); diff != "" {
- t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
- }
- nextAddr = tempAddr1
- } else {
- if diff := checkAutoGenAddrEvent(e, tempAddr1, invalidatedAddr); diff != "" {
- t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
- }
- nextAddr = addr1
- }
+ // Receive an RA with prefix2 in a PI with a valid lifetime that exceeds
+ // the minimum and won't be reached in this test.
+ tempAddr2 := newTempAddr(addr2.Address)
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, 2*minVLSeconds, 2*minVLSeconds))
+ expectAutoGenAddrEvent(addr2, newAddr)
+ expectDADEventAsync(addr2.Address)
+ expectAutoGenAddrEventAsync(tempAddr2, newAddr)
+ expectDADEventAsync(tempAddr2.Address)
+ if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr1, tempAddr1, addr2, tempAddr2}, nil); mismatch != "" {
+ t.Fatal(mismatch)
+ }
- select {
- case e := <-ndpDisp.autoGenAddrC:
- if diff := checkAutoGenAddrEvent(e, nextAddr, invalidatedAddr); diff != "" {
- t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
- }
- case <-time.After(defaultAsyncPositiveEventTimeout):
- t.Fatal("timed out waiting for addr auto gen event")
+ // Deprecate prefix1.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 100, 0))
+ expectAutoGenAddrEvent(addr1, deprecatedAddr)
+ expectAutoGenAddrEvent(tempAddr1, deprecatedAddr)
+ if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr1, tempAddr1, addr2, tempAddr2}, nil); mismatch != "" {
+ t.Fatal(mismatch)
+ }
+
+ // Refresh lifetimes for prefix1.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 100, 100))
+ if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr1, tempAddr1, addr2, tempAddr2}, nil); mismatch != "" {
+ t.Fatal(mismatch)
+ }
+
+ // Reduce valid lifetime and deprecate addresses of prefix1.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, minVLSeconds, 0))
+ expectAutoGenAddrEvent(addr1, deprecatedAddr)
+ expectAutoGenAddrEvent(tempAddr1, deprecatedAddr)
+ if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr1, tempAddr1, addr2, tempAddr2}, nil); mismatch != "" {
+ t.Fatal(mismatch)
+ }
+
+ // Wait for addrs of prefix1 to be invalidated. They should be
+ // invalidated at the same time.
+ clock.Advance(ipv6.MinPrefixInformationValidLifetimeForUpdate)
+ select {
+ case e := <-ndpDisp.autoGenAddrC:
+ var nextAddr tcpip.AddressWithPrefix
+ if e.addr == addr1 {
+ if diff := checkAutoGenAddrEvent(e, addr1, invalidatedAddr); diff != "" {
+ t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
}
- case <-time.After(newMinVLDuration + defaultAsyncPositiveEventTimeout):
- t.Fatal("timed out waiting for addr auto gen event")
- }
- if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr2, tempAddr2}, []tcpip.AddressWithPrefix{addr1, tempAddr1}); mismatch != "" {
- t.Fatal(mismatch)
+ nextAddr = tempAddr1
+ } else {
+ if diff := checkAutoGenAddrEvent(e, tempAddr1, invalidatedAddr); diff != "" {
+ t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
+ }
+ nextAddr = addr1
}
- // Receive an RA with prefix2 in a PI w/ 0 lifetimes.
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, 0, 0))
- expectAutoGenAddrEvent(addr2, deprecatedAddr)
- expectAutoGenAddrEvent(tempAddr2, deprecatedAddr)
select {
case e := <-ndpDisp.autoGenAddrC:
- t.Errorf("got unexpected auto gen addr event = %+v", e)
+ if diff := checkAutoGenAddrEvent(e, nextAddr, invalidatedAddr); diff != "" {
+ t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
+ }
default:
+ t.Fatal("timed out waiting for addr auto gen event")
}
- if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr2, tempAddr2}, []tcpip.AddressWithPrefix{addr1, tempAddr1}); mismatch != "" {
- t.Fatal(mismatch)
- }
- })
- }
- })
+ default:
+ t.Fatal("timed out waiting for addr auto gen event")
+ }
+ if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr2, tempAddr2}, []tcpip.AddressWithPrefix{addr1, tempAddr1}); mismatch != "" {
+ t.Fatal(mismatch)
+ }
+
+ // Receive an RA with prefix2 in a PI w/ 0 lifetimes.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, 0, 0))
+ expectAutoGenAddrEvent(addr2, deprecatedAddr)
+ expectAutoGenAddrEvent(tempAddr2, deprecatedAddr)
+ select {
+ case e := <-ndpDisp.autoGenAddrC:
+ t.Errorf("got unexpected auto gen addr event = %+v", e)
+ default:
+ }
+ if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr2, tempAddr2}, []tcpip.AddressWithPrefix{addr1, tempAddr1}); mismatch != "" {
+ t.Fatal(mismatch)
+ }
+ })
+ }
}
// TestNoAutoGenTempAddrForLinkLocal test that temporary SLAAC addresses are not
@@ -2266,12 +2216,6 @@ func TestAutoGenTempAddr(t *testing.T) {
func TestNoAutoGenTempAddrForLinkLocal(t *testing.T) {
const nicID = 1
- savedMaxDesyncFactor := ipv6.MaxDesyncFactor
- defer func() {
- ipv6.MaxDesyncFactor = savedMaxDesyncFactor
- }()
- ipv6.MaxDesyncFactor = time.Nanosecond
-
tests := []struct {
name string
dupAddrTransmits uint8
@@ -2287,66 +2231,56 @@ func TestNoAutoGenTempAddrForLinkLocal(t *testing.T) {
},
}
- // This Run will not return until the parallel tests finish.
- //
- // We need this because we need to do some teardown work after the
- // parallel tests complete.
- //
- // See https://godoc.org/testing#hdr-Subtests_and_Sub_benchmarks for
- // more details.
- t.Run("group", func(t *testing.T) {
- for _, test := range tests {
- test := test
-
- t.Run(test.name, func(t *testing.T) {
- t.Parallel()
-
- ndpDisp := ndpDispatcher{
- dadC: make(chan ndpDADEvent, 1),
- autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1),
- }
- e := channel.New(0, 1280, linkAddr1)
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
- NDPConfigs: ipv6.NDPConfigurations{
- AutoGenTempGlobalAddresses: true,
- },
- NDPDisp: &ndpDisp,
- AutoGenLinkLocal: true,
- })},
- })
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ ndpDisp := ndpDispatcher{
+ dadC: make(chan ndpDADEvent, 1),
+ autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1),
+ }
+ e := channel.New(0, 1280, linkAddr1)
+ clock := faketime.NewManualClock()
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
+ NDPConfigs: ipv6.NDPConfigurations{
+ AutoGenTempGlobalAddresses: true,
+ },
+ NDPDisp: &ndpDisp,
+ AutoGenLinkLocal: true,
+ })},
+ Clock: clock,
+ })
- if err := s.CreateNIC(nicID, e); err != nil {
- t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
- }
+ if err := s.CreateNIC(nicID, e); err != nil {
+ t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
+ }
- // The stable link-local address should auto-generate and resolve DAD.
- select {
- case e := <-ndpDisp.autoGenAddrC:
- if diff := checkAutoGenAddrEvent(e, tcpip.AddressWithPrefix{Address: llAddr1, PrefixLen: header.IIDOffsetInIPv6Address * 8}, newAddr); diff != "" {
- t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
- }
- default:
- t.Fatal("expected addr auto gen event")
+ // The stable link-local address should auto-generate and resolve DAD.
+ select {
+ case e := <-ndpDisp.autoGenAddrC:
+ if diff := checkAutoGenAddrEvent(e, tcpip.AddressWithPrefix{Address: llAddr1, PrefixLen: header.IIDOffsetInIPv6Address * 8}, newAddr); diff != "" {
+ t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
}
- select {
- case e := <-ndpDisp.dadC:
- if diff := checkDADEvent(e, nicID, llAddr1, &stack.DADSucceeded{}); diff != "" {
- t.Errorf("DAD event mismatch (-want +got):\n%s", diff)
- }
- case <-time.After(time.Duration(test.dupAddrTransmits)*test.retransmitTimer + defaultAsyncPositiveEventTimeout):
- t.Fatal("timed out waiting for DAD event")
+ default:
+ t.Fatal("expected addr auto gen event")
+ }
+ clock.Advance(time.Duration(test.dupAddrTransmits) * test.retransmitTimer)
+ select {
+ case e := <-ndpDisp.dadC:
+ if diff := checkDADEvent(e, nicID, llAddr1, &stack.DADSucceeded{}); diff != "" {
+ t.Errorf("DAD event mismatch (-want +got):\n%s", diff)
}
+ default:
+ t.Fatal("timed out waiting for DAD event")
+ }
- // No new addresses should be generated.
- select {
- case e := <-ndpDisp.autoGenAddrC:
- t.Errorf("got unxpected auto gen addr event = %+v", e)
- case <-time.After(defaultAsyncNegativeEventTimeout):
- }
- })
- }
- })
+ // No new addresses should be generated.
+ select {
+ case e := <-ndpDisp.autoGenAddrC:
+ t.Errorf("got unxpected auto gen addr event = %+v", e)
+ default:
+ }
+ })
+ }
}
// TestNoAutoGenTempAddrWithoutStableAddr tests that a temporary SLAAC address
@@ -2359,12 +2293,6 @@ func TestNoAutoGenTempAddrWithoutStableAddr(t *testing.T) {
retransmitTimer = 2 * time.Second
)
- savedMaxDesyncFactor := ipv6.MaxDesyncFactor
- defer func() {
- ipv6.MaxDesyncFactor = savedMaxDesyncFactor
- }()
- ipv6.MaxDesyncFactor = 0
-
prefix, _, addr := prefixSubnetAddr(0, linkAddr1)
var tempIIDHistory [header.IIDSize]byte
header.InitialTempIID(tempIIDHistory[:], nil, nicID)
@@ -2375,6 +2303,7 @@ func TestNoAutoGenTempAddrWithoutStableAddr(t *testing.T) {
autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1),
}
e := channel.New(0, 1280, linkAddr1)
+ clock := faketime.NewManualClock()
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
DADConfigs: stack.DADConfigurations{
@@ -2388,6 +2317,7 @@ func TestNoAutoGenTempAddrWithoutStableAddr(t *testing.T) {
},
NDPDisp: &ndpDisp,
})},
+ Clock: clock,
})
if err := s.CreateNIC(nicID, e); err != nil {
@@ -2417,12 +2347,13 @@ func TestNoAutoGenTempAddrWithoutStableAddr(t *testing.T) {
// Wait for DAD to complete for the stable address then expect the temporary
// address to be generated.
+ clock.Advance(dadTransmits * retransmitTimer)
select {
case e := <-ndpDisp.dadC:
if diff := checkDADEvent(e, nicID, addr.Address, &stack.DADSucceeded{}); diff != "" {
t.Errorf("DAD event mismatch (-want +got):\n%s", diff)
}
- case <-time.After(dadTransmits*retransmitTimer + defaultAsyncPositiveEventTimeout):
+ default:
t.Fatal("timed out waiting for DAD event")
}
select {
@@ -2430,7 +2361,7 @@ func TestNoAutoGenTempAddrWithoutStableAddr(t *testing.T) {
if diff := checkAutoGenAddrEvent(e, tempAddr, newAddr); diff != "" {
t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
}
- case <-time.After(defaultAsyncPositiveEventTimeout):
+ default:
t.Fatal("timed out waiting for addr auto gen event")
}
}
@@ -2439,46 +2370,44 @@ func TestNoAutoGenTempAddrWithoutStableAddr(t *testing.T) {
// regenerated.
func TestAutoGenTempAddrRegen(t *testing.T) {
const (
- nicID = 1
- regenAfter = 2 * time.Second
- newMinVL = 10
- newMinVLDuration = newMinVL * time.Second
- )
+ nicID = 1
+ regenAdv = 2 * time.Second
- savedMaxDesyncFactor := ipv6.MaxDesyncFactor
- savedMinMaxTempAddrPreferredLifetime := ipv6.MinMaxTempAddrPreferredLifetime
- savedMinMaxTempAddrValidLifetime := ipv6.MinMaxTempAddrValidLifetime
- defer func() {
- ipv6.MaxDesyncFactor = savedMaxDesyncFactor
- ipv6.MinMaxTempAddrPreferredLifetime = savedMinMaxTempAddrPreferredLifetime
- ipv6.MinMaxTempAddrValidLifetime = savedMinMaxTempAddrValidLifetime
- }()
- ipv6.MaxDesyncFactor = 0
- ipv6.MinMaxTempAddrPreferredLifetime = newMinVLDuration
- ipv6.MinMaxTempAddrValidLifetime = newMinVLDuration
+ numTempAddrs = 3
+ maxTempAddrValidLifetime = numTempAddrs * ipv6.MinPrefixInformationValidLifetimeForUpdate
+ )
prefix, _, addr := prefixSubnetAddr(0, linkAddr1)
var tempIIDHistory [header.IIDSize]byte
header.InitialTempIID(tempIIDHistory[:], nil, nicID)
- tempAddr1 := header.GenerateTempIPv6SLAACAddr(tempIIDHistory[:], addr.Address)
- tempAddr2 := header.GenerateTempIPv6SLAACAddr(tempIIDHistory[:], addr.Address)
- tempAddr3 := header.GenerateTempIPv6SLAACAddr(tempIIDHistory[:], addr.Address)
+ var tempAddrs [numTempAddrs]tcpip.AddressWithPrefix
+ for i := 0; i < len(tempAddrs); i++ {
+ tempAddrs[i] = header.GenerateTempIPv6SLAACAddr(tempIIDHistory[:], addr.Address)
+ }
ndpDisp := ndpDispatcher{
autoGenAddrC: make(chan ndpAutoGenAddrEvent, 2),
}
e := channel.New(0, 1280, linkAddr1)
ndpConfigs := ipv6.NDPConfigurations{
- HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled,
- AutoGenGlobalAddresses: true,
- AutoGenTempGlobalAddresses: true,
- RegenAdvanceDuration: newMinVLDuration - regenAfter,
+ HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled,
+ AutoGenGlobalAddresses: true,
+ AutoGenTempGlobalAddresses: true,
+ RegenAdvanceDuration: regenAdv,
+ MaxTempAddrValidLifetime: maxTempAddrValidLifetime,
+ MaxTempAddrPreferredLifetime: ipv6.MinPrefixInformationValidLifetimeForUpdate,
+ }
+ clock := faketime.NewManualClock()
+ randSource := savingRandSource{
+ s: rand.NewSource(time.Now().UnixNano()),
}
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
NDPConfigs: ndpConfigs,
NDPDisp: &ndpDisp,
})},
+ Clock: clock,
+ RandSource: &randSource,
})
if err := s.CreateNIC(nicID, e); err != nil {
@@ -2501,36 +2430,43 @@ func TestAutoGenTempAddrRegen(t *testing.T) {
expectAutoGenAddrEventAsync := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType, timeout time.Duration) {
t.Helper()
+ clock.Advance(timeout)
select {
case e := <-ndpDisp.autoGenAddrC:
if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" {
t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
}
- case <-time.After(timeout):
+ default:
t.Fatal("timed out waiting for addr auto gen event")
}
}
+ tempDesyncFactor := time.Duration(randSource.lastInt63) % ipv6.MaxDesyncFactor
+ effectiveMaxTempAddrPL := ipv6.MinPrefixInformationValidLifetimeForUpdate - tempDesyncFactor
+ // The time since the last regeneration before a new temporary address is
+ // generated.
+ tempAddrRegenenerationTime := effectiveMaxTempAddrPL - regenAdv
+
// Receive an RA with prefix1 in an NDP Prefix Information option (PI)
// with non-zero valid & preferred lifetimes.
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, 100, 100))
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, minVLSeconds, minVLSeconds))
expectAutoGenAddrEvent(addr, newAddr)
- expectAutoGenAddrEvent(tempAddr1, newAddr)
- if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr, tempAddr1}, nil); mismatch != "" {
+ expectAutoGenAddrEvent(tempAddrs[0], newAddr)
+ if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr, tempAddrs[0]}, nil); mismatch != "" {
t.Fatal(mismatch)
}
// Wait for regeneration
- expectAutoGenAddrEventAsync(tempAddr2, newAddr, regenAfter+defaultAsyncPositiveEventTimeout)
- if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr, tempAddr1, tempAddr2}, nil); mismatch != "" {
+ expectAutoGenAddrEventAsync(tempAddrs[1], newAddr, tempAddrRegenenerationTime)
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, minVLSeconds, minVLSeconds))
+ if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr, tempAddrs[0], tempAddrs[1]}, nil); mismatch != "" {
t.Fatal(mismatch)
}
+ expectAutoGenAddrEventAsync(tempAddrs[0], deprecatedAddr, regenAdv)
// Wait for regeneration
- expectAutoGenAddrEventAsync(tempAddr3, newAddr, regenAfter+defaultAsyncPositiveEventTimeout)
- if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr, tempAddr1, tempAddr2, tempAddr3}, nil); mismatch != "" {
- t.Fatal(mismatch)
- }
+ expectAutoGenAddrEventAsync(tempAddrs[2], newAddr, tempAddrRegenenerationTime-regenAdv)
+ expectAutoGenAddrEventAsync(tempAddrs[1], deprecatedAddr, regenAdv)
// Stop generating temporary addresses
ndpConfigs.AutoGenTempGlobalAddresses = false
@@ -2541,45 +2477,24 @@ func TestAutoGenTempAddrRegen(t *testing.T) {
ndpEP.SetNDPConfigurations(ndpConfigs)
}
+ // Refresh lifetimes and wait for the last temporary address to be deprecated.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, minVLSeconds, minVLSeconds))
+ expectAutoGenAddrEventAsync(tempAddrs[2], deprecatedAddr, effectiveMaxTempAddrPL-regenAdv)
+
+ // Refresh lifetimes such that the prefix is valid and preferred forever.
+ //
+ // This should not affect the lifetimes of temporary addresses because they
+ // are capped by the maximum valid and preferred lifetimes for temporary
+ // addresses.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, infiniteLifetimeSeconds, infiniteLifetimeSeconds))
+
// Wait for all the temporary addresses to get invalidated.
- tempAddrs := []tcpip.AddressWithPrefix{tempAddr1, tempAddr2, tempAddr3}
- invalidateAfter := newMinVLDuration - 2*regenAfter
+ invalidateAfter := maxTempAddrValidLifetime - clock.NowMonotonic().Sub(tcpip.MonotonicTime{})
for _, addr := range tempAddrs {
- // Wait for a deprecation then invalidation event, or just an invalidation
- // event. We need to cover both cases but cannot deterministically hit both
- // cases because the deprecation and invalidation jobs could execute in any
- // order.
- select {
- case e := <-ndpDisp.autoGenAddrC:
- if diff := checkAutoGenAddrEvent(e, addr, deprecatedAddr); diff == "" {
- // If we get a deprecation event first, we should get an invalidation
- // event almost immediately after.
- select {
- case e := <-ndpDisp.autoGenAddrC:
- if diff := checkAutoGenAddrEvent(e, addr, invalidatedAddr); diff != "" {
- t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
- }
- case <-time.After(defaultAsyncPositiveEventTimeout):
- t.Fatal("timed out waiting for addr auto gen event")
- }
- } else if diff := checkAutoGenAddrEvent(e, addr, invalidatedAddr); diff == "" {
- // If we get an invalidation event first, we shouldn't get a deprecation
- // event after.
- select {
- case e := <-ndpDisp.autoGenAddrC:
- t.Fatalf("unexpectedly got an auto-generated event = %+v", e)
- case <-time.After(defaultAsyncNegativeEventTimeout):
- }
- } else {
- t.Fatalf("got unexpected auto-generated event = %+v", e)
- }
- case <-time.After(invalidateAfter + defaultAsyncPositiveEventTimeout):
- t.Fatal("timed out waiting for addr auto gen event")
- }
-
- invalidateAfter = regenAfter
+ expectAutoGenAddrEventAsync(addr, invalidatedAddr, invalidateAfter)
+ invalidateAfter = tempAddrRegenenerationTime
}
- if mismatch := addressCheck(s.NICInfo()[1].ProtocolAddresses, []tcpip.AddressWithPrefix{addr}, tempAddrs); mismatch != "" {
+ if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr}, tempAddrs[:]); mismatch != "" {
t.Fatal(mismatch)
}
}
@@ -2588,52 +2503,54 @@ func TestAutoGenTempAddrRegen(t *testing.T) {
// regeneration job gets updated when refreshing the address's lifetimes.
func TestAutoGenTempAddrRegenJobUpdates(t *testing.T) {
const (
- nicID = 1
- regenAfter = 2 * time.Second
- newMinVL = 10
- newMinVLDuration = newMinVL * time.Second
- )
+ nicID = 1
+ regenAdv = 2 * time.Second
- savedMaxDesyncFactor := ipv6.MaxDesyncFactor
- savedMinMaxTempAddrPreferredLifetime := ipv6.MinMaxTempAddrPreferredLifetime
- savedMinMaxTempAddrValidLifetime := ipv6.MinMaxTempAddrValidLifetime
- defer func() {
- ipv6.MaxDesyncFactor = savedMaxDesyncFactor
- ipv6.MinMaxTempAddrPreferredLifetime = savedMinMaxTempAddrPreferredLifetime
- ipv6.MinMaxTempAddrValidLifetime = savedMinMaxTempAddrValidLifetime
- }()
- ipv6.MaxDesyncFactor = 0
- ipv6.MinMaxTempAddrPreferredLifetime = newMinVLDuration
- ipv6.MinMaxTempAddrValidLifetime = newMinVLDuration
+ numTempAddrs = 3
+ maxTempAddrPreferredLifetime = ipv6.MinPrefixInformationValidLifetimeForUpdate
+ maxTempAddrPreferredLifetimeSeconds = uint32(maxTempAddrPreferredLifetime / time.Second)
+ )
prefix, _, addr := prefixSubnetAddr(0, linkAddr1)
var tempIIDHistory [header.IIDSize]byte
header.InitialTempIID(tempIIDHistory[:], nil, nicID)
- tempAddr1 := header.GenerateTempIPv6SLAACAddr(tempIIDHistory[:], addr.Address)
- tempAddr2 := header.GenerateTempIPv6SLAACAddr(tempIIDHistory[:], addr.Address)
- tempAddr3 := header.GenerateTempIPv6SLAACAddr(tempIIDHistory[:], addr.Address)
+ var tempAddrs [numTempAddrs]tcpip.AddressWithPrefix
+ for i := 0; i < len(tempAddrs); i++ {
+ tempAddrs[i] = header.GenerateTempIPv6SLAACAddr(tempIIDHistory[:], addr.Address)
+ }
ndpDisp := ndpDispatcher{
autoGenAddrC: make(chan ndpAutoGenAddrEvent, 2),
}
e := channel.New(0, 1280, linkAddr1)
ndpConfigs := ipv6.NDPConfigurations{
- HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled,
- AutoGenGlobalAddresses: true,
- AutoGenTempGlobalAddresses: true,
- RegenAdvanceDuration: newMinVLDuration - regenAfter,
+ HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled,
+ AutoGenGlobalAddresses: true,
+ AutoGenTempGlobalAddresses: true,
+ RegenAdvanceDuration: regenAdv,
+ MaxTempAddrPreferredLifetime: maxTempAddrPreferredLifetime,
+ MaxTempAddrValidLifetime: maxTempAddrPreferredLifetime * 2,
+ }
+ clock := faketime.NewManualClock()
+ initialTime := clock.NowMonotonic()
+ randSource := savingRandSource{
+ s: rand.NewSource(time.Now().UnixNano()),
}
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
NDPConfigs: ndpConfigs,
NDPDisp: &ndpDisp,
})},
+ Clock: clock,
+ RandSource: &randSource,
})
if err := s.CreateNIC(nicID, e); err != nil {
t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
}
+ tempDesyncFactor := time.Duration(randSource.lastInt63) % ipv6.MaxDesyncFactor
+
expectAutoGenAddrEvent := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) {
t.Helper()
@@ -2650,22 +2567,23 @@ func TestAutoGenTempAddrRegenJobUpdates(t *testing.T) {
expectAutoGenAddrEventAsync := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType, timeout time.Duration) {
t.Helper()
+ clock.Advance(timeout)
select {
case e := <-ndpDisp.autoGenAddrC:
if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" {
t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
}
- case <-time.After(timeout):
+ default:
t.Fatal("timed out waiting for addr auto gen event")
}
}
// Receive an RA with prefix1 in an NDP Prefix Information option (PI)
// with non-zero valid & preferred lifetimes.
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, 100, 100))
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, maxTempAddrPreferredLifetimeSeconds, maxTempAddrPreferredLifetimeSeconds))
expectAutoGenAddrEvent(addr, newAddr)
- expectAutoGenAddrEvent(tempAddr1, newAddr)
- if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr, tempAddr1}, nil); mismatch != "" {
+ expectAutoGenAddrEvent(tempAddrs[0], newAddr)
+ if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr, tempAddrs[0]}, nil); mismatch != "" {
t.Fatal(mismatch)
}
@@ -2673,13 +2591,27 @@ func TestAutoGenTempAddrRegenJobUpdates(t *testing.T) {
//
// A new temporary address should be generated after the regeneration
// time has passed since the prefix is deprecated.
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, 100, 0))
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, maxTempAddrPreferredLifetimeSeconds, 0))
expectAutoGenAddrEvent(addr, deprecatedAddr)
- expectAutoGenAddrEvent(tempAddr1, deprecatedAddr)
+ expectAutoGenAddrEvent(tempAddrs[0], deprecatedAddr)
select {
case e := <-ndpDisp.autoGenAddrC:
- t.Fatalf("unexpected auto gen addr event = %+v", e)
- case <-time.After(regenAfter + defaultAsyncNegativeEventTimeout):
+ t.Fatalf("unexpected auto gen addr event = %#v", e)
+ default:
+ }
+
+ effectiveMaxTempAddrPL := maxTempAddrPreferredLifetime - tempDesyncFactor
+ // The time since the last regeneration before a new temporary address is
+ // generated.
+ tempAddrRegenenerationTime := effectiveMaxTempAddrPL - regenAdv
+
+ // Advance the clock by the regeneration time but don't expect a new temporary
+ // address as the prefix is deprecated.
+ clock.Advance(tempAddrRegenenerationTime)
+ select {
+ case e := <-ndpDisp.autoGenAddrC:
+ t.Fatalf("unexpected auto gen addr event = %#v", e)
+ default:
}
// Prefer the prefix again.
@@ -2687,8 +2619,15 @@ func TestAutoGenTempAddrRegenJobUpdates(t *testing.T) {
// A new temporary address should immediately be generated since the
// regeneration time has already passed since the last address was generated
// - this regeneration does not depend on a job.
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, 100, 100))
- expectAutoGenAddrEvent(tempAddr2, newAddr)
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, maxTempAddrPreferredLifetimeSeconds, maxTempAddrPreferredLifetimeSeconds))
+ expectAutoGenAddrEvent(tempAddrs[1], newAddr)
+ // Wait for the first temporary address to be deprecated.
+ expectAutoGenAddrEventAsync(tempAddrs[0], deprecatedAddr, regenAdv)
+ select {
+ case e := <-ndpDisp.autoGenAddrC:
+ t.Fatalf("unexpected auto gen addr event = %s", e)
+ default:
+ }
// Increase the maximum lifetimes for temporary addresses to large values
// then refresh the lifetimes of the prefix.
@@ -2699,34 +2638,30 @@ func TestAutoGenTempAddrRegenJobUpdates(t *testing.T) {
// regenerate a new temporary address. Note, new addresses are only
// regenerated after the preferred lifetime - the regenerate advance duration
// as paased.
- ndpConfigs.MaxTempAddrValidLifetime = 100 * time.Second
- ndpConfigs.MaxTempAddrPreferredLifetime = 100 * time.Second
+ const largeLifetimeSeconds = minVLSeconds * 2
+ const largeLifetime = time.Duration(largeLifetimeSeconds) * time.Second
+ ndpConfigs.MaxTempAddrValidLifetime = 2 * largeLifetime
+ ndpConfigs.MaxTempAddrPreferredLifetime = largeLifetime
ipv6Ep, err := s.GetNetworkEndpoint(nicID, header.IPv6ProtocolNumber)
if err != nil {
t.Fatalf("s.GetNetworkEndpoint(%d, %d): %s", nicID, header.IPv6ProtocolNumber, err)
}
ndpEP := ipv6Ep.(ipv6.NDPEndpoint)
ndpEP.SetNDPConfigurations(ndpConfigs)
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, 100, 100))
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, largeLifetimeSeconds, largeLifetimeSeconds))
+ timeSinceInitialTime := clock.NowMonotonic().Sub(initialTime)
+ clock.Advance(largeLifetime - timeSinceInitialTime)
+ expectAutoGenAddrEvent(tempAddrs[0], deprecatedAddr)
+ // to offset the advement of time to test the first temporary address's
+ // deprecation after the second was generated
+ advLess := regenAdv
+ expectAutoGenAddrEventAsync(tempAddrs[2], newAddr, timeSinceInitialTime-advLess-(tempDesyncFactor+regenAdv))
+ expectAutoGenAddrEventAsync(tempAddrs[1], deprecatedAddr, regenAdv)
select {
case e := <-ndpDisp.autoGenAddrC:
t.Fatalf("unexpected auto gen addr event = %+v", e)
- case <-time.After(regenAfter + defaultAsyncNegativeEventTimeout):
+ default:
}
-
- // Set the maximum lifetimes for temporary addresses such that on the next
- // RA, the regeneration job gets scheduled again.
- //
- // The maximum lifetime is the sum of the minimum lifetimes for temporary
- // addresses + the time that has already passed since the last address was
- // generated so that the regeneration job is needed to generate the next
- // address.
- newLifetimes := newMinVLDuration + regenAfter + defaultAsyncNegativeEventTimeout
- ndpConfigs.MaxTempAddrValidLifetime = newLifetimes
- ndpConfigs.MaxTempAddrPreferredLifetime = newLifetimes
- ndpEP.SetNDPConfigurations(ndpConfigs)
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, 100, 100))
- expectAutoGenAddrEventAsync(tempAddr3, newAddr, regenAfter+defaultAsyncPositiveEventTimeout)
}
// TestMixedSLAACAddrConflictRegen tests SLAAC address regeneration in response
@@ -2954,13 +2889,14 @@ func TestMixedSLAACAddrConflictRegen(t *testing.T) {
// stack.Stack will have a default route through the router (llAddr3) installed
// and a static link-address (linkAddr3) added to the link address cache for the
// router.
-func stackAndNdpDispatcherWithDefaultRoute(t *testing.T, nicID tcpip.NICID) (*ndpDispatcher, *channel.Endpoint, *stack.Stack) {
+func stackAndNdpDispatcherWithDefaultRoute(t *testing.T, nicID tcpip.NICID) (*ndpDispatcher, *channel.Endpoint, *stack.Stack, *faketime.ManualClock) {
t.Helper()
ndpDisp := &ndpDispatcher{
autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1),
}
e := channel.New(0, 1280, linkAddr1)
e.LinkEPCapabilities |= stack.CapabilityResolutionRequired
+ clock := faketime.NewManualClock()
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
NDPConfigs: ipv6.NDPConfigurations{
@@ -2970,6 +2906,7 @@ func stackAndNdpDispatcherWithDefaultRoute(t *testing.T, nicID tcpip.NICID) (*nd
NDPDisp: ndpDisp,
})},
TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol},
+ Clock: clock,
})
if err := s.CreateNIC(nicID, e); err != nil {
t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
@@ -2983,7 +2920,7 @@ func stackAndNdpDispatcherWithDefaultRoute(t *testing.T, nicID tcpip.NICID) (*nd
if err := s.AddStaticNeighbor(nicID, ipv6.ProtocolNumber, llAddr3, linkAddr3); err != nil {
t.Fatalf("s.AddStaticNeighbor(%d, %d, %s, %s): %s", nicID, ipv6.ProtocolNumber, llAddr3, linkAddr3, err)
}
- return ndpDisp, e, s
+ return ndpDisp, e, s, clock
}
// addrForNewConnectionTo returns the local address used when creating a new
@@ -3057,7 +2994,7 @@ func TestAutoGenAddrDeprecateFromPI(t *testing.T) {
prefix1, _, addr1 := prefixSubnetAddr(0, linkAddr1)
prefix2, _, addr2 := prefixSubnetAddr(1, linkAddr1)
- ndpDisp, e, s := stackAndNdpDispatcherWithDefaultRoute(t, nicID)
+ ndpDisp, e, s, _ := stackAndNdpDispatcherWithDefaultRoute(t, nicID)
expectAutoGenAddrEvent := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) {
t.Helper()
@@ -3160,19 +3097,11 @@ func TestAutoGenAddrDeprecateFromPI(t *testing.T) {
// when its preferred lifetime expires.
func TestAutoGenAddrJobDeprecation(t *testing.T) {
const nicID = 1
- const newMinVL = 2
- newMinVLDuration := newMinVL * time.Second
-
- saved := ipv6.MinPrefixInformationValidLifetimeForUpdate
- defer func() {
- ipv6.MinPrefixInformationValidLifetimeForUpdate = saved
- }()
- ipv6.MinPrefixInformationValidLifetimeForUpdate = newMinVLDuration
prefix1, _, addr1 := prefixSubnetAddr(0, linkAddr1)
prefix2, _, addr2 := prefixSubnetAddr(1, linkAddr1)
- ndpDisp, e, s := stackAndNdpDispatcherWithDefaultRoute(t, nicID)
+ ndpDisp, e, s, clock := stackAndNdpDispatcherWithDefaultRoute(t, nicID)
expectAutoGenAddrEvent := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) {
t.Helper()
@@ -3190,12 +3119,13 @@ func TestAutoGenAddrJobDeprecation(t *testing.T) {
expectAutoGenAddrEventAfter := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType, timeout time.Duration) {
t.Helper()
+ clock.Advance(timeout)
select {
case e := <-ndpDisp.autoGenAddrC:
if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" {
t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
}
- case <-time.After(timeout):
+ default:
t.Fatal("timed out waiting for addr auto gen event")
}
}
@@ -3213,7 +3143,7 @@ func TestAutoGenAddrJobDeprecation(t *testing.T) {
}
// Receive PI for prefix2.
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, 100, 100))
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, infiniteLifetimeSeconds, infiniteLifetimeSeconds))
expectAutoGenAddrEvent(addr2, newAddr)
if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr2) {
t.Fatalf("should have %s in the list of addresses", addr2)
@@ -3232,7 +3162,7 @@ func TestAutoGenAddrJobDeprecation(t *testing.T) {
expectPrimaryAddr(addr1)
// Refresh lifetime for addr of prefix1.
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, newMinVL, newMinVL-1))
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, minVLSeconds, minVLSeconds-1))
select {
case <-ndpDisp.autoGenAddrC:
t.Fatal("unexpectedly got an auto-generated event")
@@ -3241,7 +3171,7 @@ func TestAutoGenAddrJobDeprecation(t *testing.T) {
expectPrimaryAddr(addr1)
// Wait for addr of prefix1 to be deprecated.
- expectAutoGenAddrEventAfter(addr1, deprecatedAddr, newMinVLDuration-time.Second+defaultAsyncPositiveEventTimeout)
+ expectAutoGenAddrEventAfter(addr1, deprecatedAddr, ipv6.MinPrefixInformationValidLifetimeForUpdate-time.Second)
if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) {
t.Fatalf("should not have %s in the list of addresses", addr1)
}
@@ -3251,6 +3181,7 @@ func TestAutoGenAddrJobDeprecation(t *testing.T) {
// addr2 should be the primary endpoint now since addr1 is deprecated but
// addr2 is not.
expectPrimaryAddr(addr2)
+
// addr1 is deprecated but if explicitly requested, it should be used.
fullAddr1 := tcpip.FullAddress{Addr: addr1.Address, NIC: nicID}
if got := addrForNewConnectionWithAddr(t, s, fullAddr1); got != addr1.Address {
@@ -3259,7 +3190,7 @@ func TestAutoGenAddrJobDeprecation(t *testing.T) {
// Refresh valid lifetime for addr of prefix1, w/ 0 preferred lifetime to make
// sure we do not get a deprecation event again.
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, newMinVL, 0))
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, minVLSeconds, 0))
select {
case <-ndpDisp.autoGenAddrC:
t.Fatal("unexpectedly got an auto-generated event")
@@ -3271,7 +3202,7 @@ func TestAutoGenAddrJobDeprecation(t *testing.T) {
}
// Refresh lifetimes for addr of prefix1.
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, newMinVL, newMinVL-1))
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, minVLSeconds, minVLSeconds-1))
select {
case <-ndpDisp.autoGenAddrC:
t.Fatal("unexpectedly got an auto-generated event")
@@ -3281,7 +3212,7 @@ func TestAutoGenAddrJobDeprecation(t *testing.T) {
expectPrimaryAddr(addr1)
// Wait for addr of prefix1 to be deprecated.
- expectAutoGenAddrEventAfter(addr1, deprecatedAddr, newMinVLDuration-time.Second+defaultAsyncPositiveEventTimeout)
+ expectAutoGenAddrEventAfter(addr1, deprecatedAddr, ipv6.MinPrefixInformationValidLifetimeForUpdate-time.Second)
if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) {
t.Fatalf("should not have %s in the list of addresses", addr1)
}
@@ -3295,7 +3226,7 @@ func TestAutoGenAddrJobDeprecation(t *testing.T) {
}
// Wait for addr of prefix1 to be invalidated.
- expectAutoGenAddrEventAfter(addr1, invalidatedAddr, time.Second+defaultAsyncPositiveEventTimeout)
+ expectAutoGenAddrEventAfter(addr1, invalidatedAddr, time.Second)
if containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) {
t.Fatalf("should not have %s in the list of addresses", addr1)
}
@@ -3305,7 +3236,7 @@ func TestAutoGenAddrJobDeprecation(t *testing.T) {
expectPrimaryAddr(addr2)
// Refresh both lifetimes for addr of prefix2 to the same value.
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, newMinVL, newMinVL))
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, minVLSeconds, minVLSeconds))
select {
case <-ndpDisp.autoGenAddrC:
t.Fatal("unexpectedly got an auto-generated event")
@@ -3317,6 +3248,17 @@ func TestAutoGenAddrJobDeprecation(t *testing.T) {
// cases because the deprecation and invalidation handlers could be handled in
// either deprecation then invalidation, or invalidation then deprecation
// (which should be cancelled by the invalidation handler).
+ //
+ // Since we're about to cause both events to fire, we need the dispatcher
+ // channel to be able to hold both.
+ if got, want := len(ndpDisp.autoGenAddrC), 0; got != want {
+ t.Fatalf("got len(ndpDisp.autoGenAddrC) = %d, want %d", got, want)
+ }
+ if got, want := cap(ndpDisp.autoGenAddrC), 1; got != want {
+ t.Fatalf("got cap(ndpDisp.autoGenAddrC) = %d, want %d", got, want)
+ }
+ ndpDisp.autoGenAddrC = make(chan ndpAutoGenAddrEvent, 2)
+ clock.Advance(ipv6.MinPrefixInformationValidLifetimeForUpdate)
select {
case e := <-ndpDisp.autoGenAddrC:
if diff := checkAutoGenAddrEvent(e, addr2, deprecatedAddr); diff == "" {
@@ -3327,21 +3269,21 @@ func TestAutoGenAddrJobDeprecation(t *testing.T) {
if diff := checkAutoGenAddrEvent(e, addr2, invalidatedAddr); diff != "" {
t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
}
- case <-time.After(defaultAsyncPositiveEventTimeout):
+ default:
t.Fatal("timed out waiting for addr auto gen event")
}
} else if diff := checkAutoGenAddrEvent(e, addr2, invalidatedAddr); diff == "" {
- // If we get an invalidation event first, we should not get a deprecation
+ // If we get an invalidation event first, we should not get a deprecation
// event after.
select {
case <-ndpDisp.autoGenAddrC:
t.Fatal("unexpectedly got an auto-generated event")
- case <-time.After(defaultAsyncNegativeEventTimeout):
+ default:
}
} else {
t.Fatalf("got unexpected auto-generated event")
}
- case <-time.After(newMinVLDuration + defaultAsyncPositiveEventTimeout):
+ default:
t.Fatal("timed out waiting for addr auto gen event")
}
if containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) {
@@ -3378,15 +3320,6 @@ func TestAutoGenAddrJobDeprecation(t *testing.T) {
// infinite values.
func TestAutoGenAddrFiniteToInfiniteToFiniteVL(t *testing.T) {
const infiniteVLSeconds = 2
- const minVLSeconds = 1
- savedIL := header.NDPInfiniteLifetime
- savedMinVL := ipv6.MinPrefixInformationValidLifetimeForUpdate
- defer func() {
- ipv6.MinPrefixInformationValidLifetimeForUpdate = savedMinVL
- header.NDPInfiniteLifetime = savedIL
- }()
- ipv6.MinPrefixInformationValidLifetimeForUpdate = minVLSeconds * time.Second
- header.NDPInfiniteLifetime = infiniteVLSeconds * time.Second
prefix, _, addr := prefixSubnetAddr(0, linkAddr1)
@@ -3410,68 +3343,58 @@ func TestAutoGenAddrFiniteToInfiniteToFiniteVL(t *testing.T) {
},
}
- // This Run will not return until the parallel tests finish.
- //
- // We need this because we need to do some teardown work after the
- // parallel tests complete.
- //
- // See https://godoc.org/testing#hdr-Subtests_and_Sub_benchmarks for
- // more details.
- t.Run("group", func(t *testing.T) {
- for _, test := range tests {
- test := test
-
- t.Run(test.name, func(t *testing.T) {
- t.Parallel()
-
- ndpDisp := ndpDispatcher{
- autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1),
- }
- e := channel.New(0, 1280, linkAddr1)
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
- NDPConfigs: ipv6.NDPConfigurations{
- HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled,
- AutoGenGlobalAddresses: true,
- },
- NDPDisp: &ndpDisp,
- })},
- })
-
- if err := s.CreateNIC(1, e); err != nil {
- t.Fatalf("CreateNIC(1) = %s", err)
- }
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ ndpDisp := ndpDispatcher{
+ autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1),
+ }
+ e := channel.New(0, 1280, linkAddr1)
+ clock := faketime.NewManualClock()
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
+ NDPConfigs: ipv6.NDPConfigurations{
+ HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled,
+ AutoGenGlobalAddresses: true,
+ },
+ NDPDisp: &ndpDisp,
+ })},
+ Clock: clock,
+ })
- // Receive an RA with finite prefix.
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, minVLSeconds, 0))
- select {
- case e := <-ndpDisp.autoGenAddrC:
- if diff := checkAutoGenAddrEvent(e, addr, newAddr); diff != "" {
- t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
- }
+ if err := s.CreateNIC(1, e); err != nil {
+ t.Fatalf("CreateNIC(1) = %s", err)
+ }
- default:
- t.Fatal("expected addr auto gen event")
+ // Receive an RA with finite prefix.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, minVLSeconds, 0))
+ select {
+ case e := <-ndpDisp.autoGenAddrC:
+ if diff := checkAutoGenAddrEvent(e, addr, newAddr); diff != "" {
+ t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
}
- // Receive an new RA with prefix with infinite VL.
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, test.infiniteVL, 0))
+ default:
+ t.Fatal("expected addr auto gen event")
+ }
- // Receive a new RA with prefix with finite VL.
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, minVLSeconds, 0))
+ // Receive an new RA with prefix with infinite VL.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, test.infiniteVL, 0))
- select {
- case e := <-ndpDisp.autoGenAddrC:
- if diff := checkAutoGenAddrEvent(e, addr, invalidatedAddr); diff != "" {
- t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
- }
+ // Receive a new RA with prefix with finite VL.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, minVLSeconds, 0))
- case <-time.After(minVLSeconds*time.Second + defaultAsyncPositiveEventTimeout):
- t.Fatal("timeout waiting for addr auto gen event")
+ clock.Advance(ipv6.MinPrefixInformationValidLifetimeForUpdate)
+ select {
+ case e := <-ndpDisp.autoGenAddrC:
+ if diff := checkAutoGenAddrEvent(e, addr, invalidatedAddr); diff != "" {
+ t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
}
- })
- }
- })
+
+ default:
+ t.Fatal("timeout waiting for addr auto gen event")
+ }
+ })
+ }
}
// TestAutoGenAddrValidLifetimeUpdates tests that the valid lifetime of an
@@ -3479,12 +3402,6 @@ func TestAutoGenAddrFiniteToInfiniteToFiniteVL(t *testing.T) {
// RFC 4862 section 5.5.3.e.
func TestAutoGenAddrValidLifetimeUpdates(t *testing.T) {
const infiniteVL = 4294967295
- const newMinVL = 4
- saved := ipv6.MinPrefixInformationValidLifetimeForUpdate
- defer func() {
- ipv6.MinPrefixInformationValidLifetimeForUpdate = saved
- }()
- ipv6.MinPrefixInformationValidLifetimeForUpdate = newMinVL * time.Second
prefix, _, addr := prefixSubnetAddr(0, linkAddr1)
@@ -3495,137 +3412,129 @@ func TestAutoGenAddrValidLifetimeUpdates(t *testing.T) {
evl uint32
}{
// Should update the VL to the minimum VL for updating if the
- // new VL is less than newMinVL but was originally greater than
+ // new VL is less than minVLSeconds but was originally greater than
// it.
{
"LargeVLToVLLessThanMinVLForUpdate",
9999,
1,
- newMinVL,
+ minVLSeconds,
},
{
"LargeVLTo0",
9999,
0,
- newMinVL,
+ minVLSeconds,
},
{
"InfiniteVLToVLLessThanMinVLForUpdate",
infiniteVL,
1,
- newMinVL,
+ minVLSeconds,
},
{
"InfiniteVLTo0",
infiniteVL,
0,
- newMinVL,
+ minVLSeconds,
},
- // Should not update VL if original VL was less than newMinVL
- // and the new VL is also less than newMinVL.
+ // Should not update VL if original VL was less than minVLSeconds
+ // and the new VL is also less than minVLSeconds.
{
"ShouldNotUpdateWhenBothOldAndNewAreLessThanMinVLForUpdate",
- newMinVL - 1,
- newMinVL - 3,
- newMinVL - 1,
+ minVLSeconds - 1,
+ minVLSeconds - 3,
+ minVLSeconds - 1,
},
// Should take the new VL if the new VL is greater than the
- // remaining time or is greater than newMinVL.
+ // remaining time or is greater than minVLSeconds.
{
"MorethanMinVLToLesserButStillMoreThanMinVLForUpdate",
- newMinVL + 5,
- newMinVL + 3,
- newMinVL + 3,
+ minVLSeconds + 5,
+ minVLSeconds + 3,
+ minVLSeconds + 3,
},
{
"SmallVLToGreaterVLButStillLessThanMinVLForUpdate",
- newMinVL - 3,
- newMinVL - 1,
- newMinVL - 1,
+ minVLSeconds - 3,
+ minVLSeconds - 1,
+ minVLSeconds - 1,
},
{
"SmallVLToGreaterVLThatIsMoreThaMinVLForUpdate",
- newMinVL - 3,
- newMinVL + 1,
- newMinVL + 1,
+ minVLSeconds - 3,
+ minVLSeconds + 1,
+ minVLSeconds + 1,
},
}
- // This Run will not return until the parallel tests finish.
- //
- // We need this because we need to do some teardown work after the
- // parallel tests complete.
- //
- // See https://godoc.org/testing#hdr-Subtests_and_Sub_benchmarks for
- // more details.
- t.Run("group", func(t *testing.T) {
- for _, test := range tests {
- test := test
-
- t.Run(test.name, func(t *testing.T) {
- t.Parallel()
-
- ndpDisp := ndpDispatcher{
- autoGenAddrC: make(chan ndpAutoGenAddrEvent, 10),
- }
- e := channel.New(10, 1280, linkAddr1)
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
- NDPConfigs: ipv6.NDPConfigurations{
- HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled,
- AutoGenGlobalAddresses: true,
- },
- NDPDisp: &ndpDisp,
- })},
- })
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ ndpDisp := ndpDispatcher{
+ autoGenAddrC: make(chan ndpAutoGenAddrEvent, 10),
+ }
+ e := channel.New(10, 1280, linkAddr1)
+ clock := faketime.NewManualClock()
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
+ NDPConfigs: ipv6.NDPConfigurations{
+ HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled,
+ AutoGenGlobalAddresses: true,
+ },
+ NDPDisp: &ndpDisp,
+ })},
+ Clock: clock,
+ })
- if err := s.CreateNIC(1, e); err != nil {
- t.Fatalf("CreateNIC(1) = %s", err)
- }
+ if err := s.CreateNIC(1, e); err != nil {
+ t.Fatalf("CreateNIC(1) = %s", err)
+ }
- // Receive an RA with prefix with initial VL,
- // test.ovl.
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, test.ovl, 0))
- select {
- case e := <-ndpDisp.autoGenAddrC:
- if diff := checkAutoGenAddrEvent(e, addr, newAddr); diff != "" {
- t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
- }
- default:
- t.Fatal("expected addr auto gen event")
+ // Receive an RA with prefix with initial VL,
+ // test.ovl.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, test.ovl, 0))
+ select {
+ case e := <-ndpDisp.autoGenAddrC:
+ if diff := checkAutoGenAddrEvent(e, addr, newAddr); diff != "" {
+ t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
}
+ default:
+ t.Fatal("expected addr auto gen event")
+ }
- // Receive an new RA with prefix with new VL,
- // test.nvl.
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, test.nvl, 0))
+ // Receive an new RA with prefix with new VL,
+ // test.nvl.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, test.nvl, 0))
- //
- // Validate that the VL for the address got set
- // to test.evl.
- //
+ //
+ // Validate that the VL for the address got set
+ // to test.evl.
+ //
- // The address should not be invalidated until the effective valid
- // lifetime has passed.
- select {
- case <-ndpDisp.autoGenAddrC:
- t.Fatal("unexpectedly received an auto gen addr event")
- case <-time.After(time.Duration(test.evl)*time.Second - defaultAsyncNegativeEventTimeout):
- }
+ // The address should not be invalidated until the effective valid
+ // lifetime has passed.
+ const delta = 1
+ clock.Advance(time.Duration(test.evl)*time.Second - delta)
+ select {
+ case <-ndpDisp.autoGenAddrC:
+ t.Fatal("unexpectedly received an auto gen addr event")
+ default:
+ }
- // Wait for the invalidation event.
- select {
- case e := <-ndpDisp.autoGenAddrC:
- if diff := checkAutoGenAddrEvent(e, addr, invalidatedAddr); diff != "" {
- t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
- }
- case <-time.After(defaultAsyncPositiveEventTimeout):
- t.Fatal("timeout waiting for addr auto gen event")
+ // Wait for the invalidation event.
+ clock.Advance(delta)
+ select {
+ case e := <-ndpDisp.autoGenAddrC:
+ if diff := checkAutoGenAddrEvent(e, addr, invalidatedAddr); diff != "" {
+ t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
}
- })
- }
- })
+ default:
+ t.Fatal("timeout waiting for addr auto gen event")
+ }
+ })
+ }
}
// TestAutoGenAddrRemoval tests that when auto-generated addresses are removed
@@ -3696,7 +3605,7 @@ func TestAutoGenAddrAfterRemoval(t *testing.T) {
prefix1, _, addr1 := prefixSubnetAddr(0, linkAddr1)
prefix2, _, addr2 := prefixSubnetAddr(1, linkAddr1)
- ndpDisp, e, s := stackAndNdpDispatcherWithDefaultRoute(t, nicID)
+ ndpDisp, e, s, _ := stackAndNdpDispatcherWithDefaultRoute(t, nicID)
expectAutoGenAddrEvent := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) {
t.Helper()
@@ -3976,13 +3885,6 @@ func TestAutoGenAddrInResponseToDADConflicts(t *testing.T) {
const maxMaxRetries = 3
const lifetimeSeconds = 10
- // Needed for the temporary address sub test.
- savedMaxDesync := ipv6.MaxDesyncFactor
- defer func() {
- ipv6.MaxDesyncFactor = savedMaxDesync
- }()
- ipv6.MaxDesyncFactor = time.Nanosecond
-
secretKey := makeSecretKey(t)
prefix, subnet, _ := prefixSubnetAddr(0, linkAddr1)
@@ -4008,22 +3910,24 @@ func TestAutoGenAddrInResponseToDADConflicts(t *testing.T) {
}
}
- expectAutoGenAddrEventAsync := func(t *testing.T, ndpDisp *ndpDispatcher, addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) {
+ expectAutoGenAddrEventAsync := func(t *testing.T, clock *faketime.ManualClock, ndpDisp *ndpDispatcher, addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) {
t.Helper()
+ clock.RunImmediatelyScheduledJobs()
select {
case e := <-ndpDisp.autoGenAddrC:
if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" {
t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
}
- case <-time.After(defaultAsyncPositiveEventTimeout):
+ default:
t.Fatal("timed out waiting for addr auto gen event")
}
}
- expectDADEvent := func(t *testing.T, ndpDisp *ndpDispatcher, addr tcpip.Address, res stack.DADResult) {
+ expectDADEvent := func(t *testing.T, clock *faketime.ManualClock, ndpDisp *ndpDispatcher, addr tcpip.Address, res stack.DADResult) {
t.Helper()
+ clock.RunImmediatelyScheduledJobs()
select {
case e := <-ndpDisp.dadC:
if diff := checkDADEvent(e, nicID, addr, res); diff != "" {
@@ -4034,15 +3938,16 @@ func TestAutoGenAddrInResponseToDADConflicts(t *testing.T) {
}
}
- expectDADEventAsync := func(t *testing.T, ndpDisp *ndpDispatcher, addr tcpip.Address, res stack.DADResult) {
+ expectDADEventAsync := func(t *testing.T, clock *faketime.ManualClock, ndpDisp *ndpDispatcher, addr tcpip.Address, res stack.DADResult) {
t.Helper()
+ clock.Advance(dadTransmits * retransmitTimer)
select {
case e := <-ndpDisp.dadC:
if diff := checkDADEvent(e, nicID, addr, res); diff != "" {
t.Errorf("DAD event mismatch (-want +got):\n%s", diff)
}
- case <-time.After(dadTransmits*retransmitTimer + defaultAsyncPositiveEventTimeout):
+ default:
t.Fatal("timed out waiting for DAD event")
}
}
@@ -4053,7 +3958,7 @@ func TestAutoGenAddrInResponseToDADConflicts(t *testing.T) {
name string
ndpConfigs ipv6.NDPConfigurations
autoGenLinkLocal bool
- prepareFn func(t *testing.T, ndpDisp *ndpDispatcher, e *channel.Endpoint, tempIIDHistory []byte) []tcpip.AddressWithPrefix
+ prepareFn func(t *testing.T, clock *faketime.ManualClock, ndpDisp *ndpDispatcher, e *channel.Endpoint, tempIIDHistory []byte) []tcpip.AddressWithPrefix
addrGenFn func(dadCounter uint8, tempIIDHistory []byte) tcpip.AddressWithPrefix
}{
{
@@ -4062,7 +3967,7 @@ func TestAutoGenAddrInResponseToDADConflicts(t *testing.T) {
HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled,
AutoGenGlobalAddresses: true,
},
- prepareFn: func(_ *testing.T, _ *ndpDispatcher, e *channel.Endpoint, _ []byte) []tcpip.AddressWithPrefix {
+ prepareFn: func(_ *testing.T, _ *faketime.ManualClock, _ *ndpDispatcher, e *channel.Endpoint, _ []byte) []tcpip.AddressWithPrefix {
// Receive an RA with prefix1 in a PI.
e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, lifetimeSeconds, lifetimeSeconds))
return nil
@@ -4076,7 +3981,7 @@ func TestAutoGenAddrInResponseToDADConflicts(t *testing.T) {
name: "LinkLocal address",
ndpConfigs: ipv6.NDPConfigurations{},
autoGenLinkLocal: true,
- prepareFn: func(*testing.T, *ndpDispatcher, *channel.Endpoint, []byte) []tcpip.AddressWithPrefix {
+ prepareFn: func(*testing.T, *faketime.ManualClock, *ndpDispatcher, *channel.Endpoint, []byte) []tcpip.AddressWithPrefix {
return nil
},
addrGenFn: func(dadCounter uint8, _ []byte) tcpip.AddressWithPrefix {
@@ -4090,14 +3995,14 @@ func TestAutoGenAddrInResponseToDADConflicts(t *testing.T) {
AutoGenGlobalAddresses: true,
AutoGenTempGlobalAddresses: true,
},
- prepareFn: func(t *testing.T, ndpDisp *ndpDispatcher, e *channel.Endpoint, tempIIDHistory []byte) []tcpip.AddressWithPrefix {
+ prepareFn: func(t *testing.T, clock *faketime.ManualClock, ndpDisp *ndpDispatcher, e *channel.Endpoint, tempIIDHistory []byte) []tcpip.AddressWithPrefix {
header.InitialTempIID(tempIIDHistory, nil, nicID)
// Generate a stable SLAAC address so temporary addresses will be
// generated.
e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, 100, 100))
expectAutoGenAddrEvent(t, ndpDisp, stableAddrForTempAddrTest, newAddr)
- expectDADEventAsync(t, ndpDisp, stableAddrForTempAddrTest.Address, &stack.DADSucceeded{})
+ expectDADEventAsync(t, clock, ndpDisp, stableAddrForTempAddrTest.Address, &stack.DADSucceeded{})
// The stable address will be assigned throughout the test.
return []tcpip.AddressWithPrefix{stableAddrForTempAddrTest}
@@ -4109,14 +4014,6 @@ func TestAutoGenAddrInResponseToDADConflicts(t *testing.T) {
}
for _, addrType := range addrTypes {
- // This Run will not return until the parallel tests finish.
- //
- // We need this because we need to do some teardown work after the parallel
- // tests complete and limit the number of parallel tests running at the same
- // time to reduce flakes.
- //
- // See https://godoc.org/testing#hdr-Subtests_and_Sub_benchmarks for
- // more details.
t.Run(addrType.name, func(t *testing.T) {
for maxRetries := uint8(0); maxRetries <= maxMaxRetries; maxRetries++ {
for numFailures := uint8(0); numFailures <= maxRetries+1; numFailures++ {
@@ -4125,8 +4022,6 @@ func TestAutoGenAddrInResponseToDADConflicts(t *testing.T) {
addrType := addrType
t.Run(fmt.Sprintf("%d max retries and %d failures", maxRetries, numFailures), func(t *testing.T) {
- t.Parallel()
-
ndpDisp := ndpDispatcher{
dadC: make(chan ndpDADEvent, 1),
autoGenAddrC: make(chan ndpAutoGenAddrEvent, 2),
@@ -4134,6 +4029,7 @@ func TestAutoGenAddrInResponseToDADConflicts(t *testing.T) {
e := channel.New(0, 1280, linkAddr1)
ndpConfigs := addrType.ndpConfigs
ndpConfigs.AutoGenAddressConflictRetries = maxRetries
+ clock := faketime.NewManualClock()
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
AutoGenLinkLocal: addrType.autoGenLinkLocal,
@@ -4150,6 +4046,7 @@ func TestAutoGenAddrInResponseToDADConflicts(t *testing.T) {
SecretKey: secretKey,
},
})},
+ Clock: clock,
})
opts := stack.NICOptions{Name: nicName}
if err := s.CreateNICWithOptions(nicID, e, opts); err != nil {
@@ -4157,12 +4054,12 @@ func TestAutoGenAddrInResponseToDADConflicts(t *testing.T) {
}
var tempIIDHistory [header.IIDSize]byte
- stableAddrs := addrType.prepareFn(t, &ndpDisp, e, tempIIDHistory[:])
+ stableAddrs := addrType.prepareFn(t, clock, &ndpDisp, e, tempIIDHistory[:])
// Simulate DAD conflicts so the address is regenerated.
for i := uint8(0); i < numFailures; i++ {
addr := addrType.addrGenFn(i, tempIIDHistory[:])
- expectAutoGenAddrEventAsync(t, &ndpDisp, addr, newAddr)
+ expectAutoGenAddrEventAsync(t, clock, &ndpDisp, addr, newAddr)
// Should not have any new addresses assigned to the NIC.
if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, stableAddrs, nil); mismatch != "" {
@@ -4172,7 +4069,7 @@ func TestAutoGenAddrInResponseToDADConflicts(t *testing.T) {
// Simulate a DAD conflict.
rxNDPSolicit(e, addr.Address)
expectAutoGenAddrEvent(t, &ndpDisp, addr, invalidatedAddr)
- expectDADEvent(t, &ndpDisp, addr.Address, &stack.DADDupAddrDetected{})
+ expectDADEvent(t, clock, &ndpDisp, addr.Address, &stack.DADDupAddrDetected{})
// Attempting to add the address manually should not fail if the
// address's state was cleaned up when DAD failed.
@@ -4182,7 +4079,7 @@ func TestAutoGenAddrInResponseToDADConflicts(t *testing.T) {
if err := s.RemoveAddress(nicID, addr.Address); err != nil {
t.Fatalf("RemoveAddress(%d, %s) = %s", nicID, addr.Address, err)
}
- expectDADEvent(t, &ndpDisp, addr.Address, &stack.DADAborted{})
+ expectDADEvent(t, clock, &ndpDisp, addr.Address, &stack.DADAborted{})
}
// Should not have any new addresses assigned to the NIC.
@@ -4194,8 +4091,8 @@ func TestAutoGenAddrInResponseToDADConflicts(t *testing.T) {
// an address after DAD resolves.
if maxRetries+1 > numFailures {
addr := addrType.addrGenFn(numFailures, tempIIDHistory[:])
- expectAutoGenAddrEventAsync(t, &ndpDisp, addr, newAddr)
- expectDADEventAsync(t, &ndpDisp, addr.Address, &stack.DADSucceeded{})
+ expectAutoGenAddrEventAsync(t, clock, &ndpDisp, addr, newAddr)
+ expectDADEventAsync(t, clock, &ndpDisp, addr.Address, &stack.DADSucceeded{})
if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, append(stableAddrs, addr), nil); mismatch != "" {
t.Fatal(mismatch)
}
@@ -4205,7 +4102,7 @@ func TestAutoGenAddrInResponseToDADConflicts(t *testing.T) {
select {
case e := <-ndpDisp.autoGenAddrC:
t.Fatalf("unexpectedly got an auto-generated address event = %+v", e)
- case <-time.After(defaultAsyncNegativeEventTimeout):
+ default:
}
})
}
@@ -4718,11 +4615,9 @@ func TestNoCleanupNDPStateWhenForwardingEnabled(t *testing.T) {
)
ndpDisp := ndpDispatcher{
- routerC: make(chan ndpRouterEvent, 1),
- rememberRouter: true,
- prefixC: make(chan ndpPrefixEvent, 1),
- rememberPrefix: true,
- autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1),
+ offLinkRouteC: make(chan ndpOffLinkRouteEvent, 1),
+ prefixC: make(chan ndpPrefixEvent, 1),
+ autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1),
}
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
@@ -4765,17 +4660,17 @@ func TestNoCleanupNDPStateWhenForwardingEnabled(t *testing.T) {
),
)
select {
- case e := <-ndpDisp.routerC:
- if diff := checkRouterEvent(e, llAddr3, true /* discovered */); diff != "" {
- t.Errorf("router event mismatch (-want +got):\n%s", diff)
+ case e := <-ndpDisp.offLinkRouteC:
+ if diff := checkOffLinkRouteEvent(e, nicID, header.IPv6EmptySubnet, llAddr3, header.MediumRoutePreference, true /* discovered */); diff != "" {
+ t.Errorf("off-link route event mismatch (-want +got):\n%s", diff)
}
default:
- t.Errorf("expected router event for %s on NIC(%d)", llAddr3, nicID)
+ t.Errorf("expected off-link route event for %s on NIC(%d)", llAddr3, nicID)
}
select {
case e := <-ndpDisp.prefixC:
if diff := checkPrefixEvent(e, subnet, true /* discovered */); diff != "" {
- t.Errorf("router event mismatch (-want +got):\n%s", diff)
+ t.Errorf("off-link route event mismatch (-want +got):\n%s", diff)
}
default:
t.Errorf("expected prefix event for %s on NIC(%d)", prefix, nicID)
@@ -4797,8 +4692,8 @@ func TestNoCleanupNDPStateWhenForwardingEnabled(t *testing.T) {
t.Fatalf("SetForwardingDefaultAndAllNICs(%d, %t): %s", ipv6.ProtocolNumber, forwarding, err)
}
select {
- case e := <-ndpDisp.routerC:
- t.Errorf("unexpected router event = %#v", e)
+ case e := <-ndpDisp.offLinkRouteC:
+ t.Errorf("unexpected off-link route event = %#v", e)
default:
}
select {
@@ -4884,11 +4779,9 @@ func TestCleanupNDPState(t *testing.T) {
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
ndpDisp := ndpDispatcher{
- routerC: make(chan ndpRouterEvent, maxRouterAndPrefixEvents),
- rememberRouter: true,
- prefixC: make(chan ndpPrefixEvent, maxRouterAndPrefixEvents),
- rememberPrefix: true,
- autoGenAddrC: make(chan ndpAutoGenAddrEvent, test.maxAutoGenAddrEvents),
+ offLinkRouteC: make(chan ndpOffLinkRouteEvent, maxRouterAndPrefixEvents),
+ prefixC: make(chan ndpPrefixEvent, maxRouterAndPrefixEvents),
+ autoGenAddrC: make(chan ndpAutoGenAddrEvent, test.maxAutoGenAddrEvents),
}
clock := faketime.NewManualClock()
s := stack.New(stack.Options{
@@ -4905,14 +4798,14 @@ func TestCleanupNDPState(t *testing.T) {
Clock: clock,
})
- expectRouterEvent := func() (bool, ndpRouterEvent) {
+ expectOffLinkRouteEvent := func() (bool, ndpOffLinkRouteEvent) {
select {
- case e := <-ndpDisp.routerC:
+ case e := <-ndpDisp.offLinkRouteC:
return true, e
default:
}
- return false, ndpRouterEvent{}
+ return false, ndpOffLinkRouteEvent{}
}
expectPrefixEvent := func() (bool, ndpPrefixEvent) {
@@ -4957,8 +4850,8 @@ func TestCleanupNDPState(t *testing.T) {
// multiple addresses.
e1.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr3, lifetimeSeconds, prefix1, true, true, lifetimeSeconds, lifetimeSeconds))
- if ok, _ := expectRouterEvent(); !ok {
- t.Errorf("expected router event for %s on NIC(%d)", llAddr3, nicID1)
+ if ok, _ := expectOffLinkRouteEvent(); !ok {
+ t.Errorf("expected off-link route event for %s on NIC(%d)", llAddr3, nicID1)
}
if ok, _ := expectPrefixEvent(); !ok {
t.Errorf("expected prefix event for %s on NIC(%d)", prefix1, nicID1)
@@ -4968,8 +4861,8 @@ func TestCleanupNDPState(t *testing.T) {
}
e1.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr4, lifetimeSeconds, prefix2, true, true, lifetimeSeconds, lifetimeSeconds))
- if ok, _ := expectRouterEvent(); !ok {
- t.Errorf("expected router event for %s on NIC(%d)", llAddr4, nicID1)
+ if ok, _ := expectOffLinkRouteEvent(); !ok {
+ t.Errorf("expected off-link route event for %s on NIC(%d)", llAddr4, nicID1)
}
if ok, _ := expectPrefixEvent(); !ok {
t.Errorf("expected prefix event for %s on NIC(%d)", prefix2, nicID1)
@@ -4979,8 +4872,8 @@ func TestCleanupNDPState(t *testing.T) {
}
e2.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr3, lifetimeSeconds, prefix1, true, true, lifetimeSeconds, lifetimeSeconds))
- if ok, _ := expectRouterEvent(); !ok {
- t.Errorf("expected router event for %s on NIC(%d)", llAddr3, nicID2)
+ if ok, _ := expectOffLinkRouteEvent(); !ok {
+ t.Errorf("expected off-link route event for %s on NIC(%d)", llAddr3, nicID2)
}
if ok, _ := expectPrefixEvent(); !ok {
t.Errorf("expected prefix event for %s on NIC(%d)", prefix1, nicID2)
@@ -4990,8 +4883,8 @@ func TestCleanupNDPState(t *testing.T) {
}
e2.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr4, lifetimeSeconds, prefix2, true, true, lifetimeSeconds, lifetimeSeconds))
- if ok, _ := expectRouterEvent(); !ok {
- t.Errorf("expected router event for %s on NIC(%d)", llAddr4, nicID2)
+ if ok, _ := expectOffLinkRouteEvent(); !ok {
+ t.Errorf("expected off-link route event for %s on NIC(%d)", llAddr4, nicID2)
}
if ok, _ := expectPrefixEvent(); !ok {
t.Errorf("expected prefix event for %s on NIC(%d)", prefix2, nicID2)
@@ -5032,14 +4925,14 @@ func TestCleanupNDPState(t *testing.T) {
test.cleanupFn(t, s)
// Collect invalidation events after having NDP state cleaned up.
- gotRouterEvents := make(map[ndpRouterEvent]int)
+ gotOffLinkRouteEvents := make(map[ndpOffLinkRouteEvent]int)
for i := 0; i < maxRouterAndPrefixEvents; i++ {
- ok, e := expectRouterEvent()
+ ok, e := expectOffLinkRouteEvent()
if !ok {
- t.Errorf("expected %d router events after becoming a router; got = %d", maxRouterAndPrefixEvents, i)
+ t.Errorf("expected %d off-link route events after becoming a router; got = %d", maxRouterAndPrefixEvents, i)
break
}
- gotRouterEvents[e]++
+ gotOffLinkRouteEvents[e]++
}
gotPrefixEvents := make(map[ndpPrefixEvent]int)
for i := 0; i < maxRouterAndPrefixEvents; i++ {
@@ -5066,14 +4959,14 @@ func TestCleanupNDPState(t *testing.T) {
t.FailNow()
}
- expectedRouterEvents := map[ndpRouterEvent]int{
- {nicID: nicID1, addr: llAddr3, discovered: false}: 1,
- {nicID: nicID1, addr: llAddr4, discovered: false}: 1,
- {nicID: nicID2, addr: llAddr3, discovered: false}: 1,
- {nicID: nicID2, addr: llAddr4, discovered: false}: 1,
+ expectedOffLinkRouteEvents := map[ndpOffLinkRouteEvent]int{
+ {nicID: nicID1, subnet: header.IPv6EmptySubnet, router: llAddr3, updated: false}: 1,
+ {nicID: nicID1, subnet: header.IPv6EmptySubnet, router: llAddr4, updated: false}: 1,
+ {nicID: nicID2, subnet: header.IPv6EmptySubnet, router: llAddr3, updated: false}: 1,
+ {nicID: nicID2, subnet: header.IPv6EmptySubnet, router: llAddr4, updated: false}: 1,
}
- if diff := cmp.Diff(expectedRouterEvents, gotRouterEvents); diff != "" {
- t.Errorf("router events mismatch (-want +got):\n%s", diff)
+ if diff := cmp.Diff(expectedOffLinkRouteEvents, gotOffLinkRouteEvents); diff != "" {
+ t.Errorf("off-link route events mismatch (-want +got):\n%s", diff)
}
expectedPrefixEvents := map[ndpPrefixEvent]int{
{nicID: nicID1, prefix: subnet1, discovered: false}: 1,
@@ -5137,8 +5030,8 @@ func TestCleanupNDPState(t *testing.T) {
// cancelled when the NDP state was cleaned up).
clock.Advance(lifetimeSeconds * time.Second)
select {
- case <-ndpDisp.routerC:
- t.Error("unexpected router event")
+ case <-ndpDisp.offLinkRouteC:
+ t.Error("unexpected off-link route event")
default:
}
select {
@@ -5163,7 +5056,6 @@ func TestDHCPv6ConfigurationFromNDPDA(t *testing.T) {
ndpDisp := ndpDispatcher{
dhcpv6ConfigurationC: make(chan ndpDHCPv6Event, 1),
- rememberRouter: true,
}
e := channel.New(0, 1280, linkAddr1)
s := stack.New(stack.Options{
diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go
index 378389db2..b854d868c 100644
--- a/pkg/tcpip/stack/nic.go
+++ b/pkg/tcpip/stack/nic.go
@@ -779,17 +779,11 @@ func (n *nic) DeliverTransportPacket(protocol tcpip.TransportProtocolNumber, pkt
transProto := state.proto
- // Raw socket packets are delivered based solely on the transport
- // protocol number. We do not inspect the payload to ensure it's
- // validly formed.
- n.stack.demux.deliverRawPacket(protocol, pkt)
-
// TransportHeader is empty only when pkt is an ICMP packet or was reassembled
// from fragments.
if pkt.TransportHeader().View().IsEmpty() {
- // TODO(gvisor.dev/issue/170): ICMP packets don't have their TransportHeader
- // fields set yet, parse it here. See icmp/protocol.go:protocol.Parse for a
- // full explanation.
+ // ICMP packets don't have their TransportHeader fields set yet, parse it
+ // here. See icmp/protocol.go:protocol.Parse for a full explanation.
if protocol == header.ICMPv4ProtocolNumber || protocol == header.ICMPv6ProtocolNumber {
// ICMP packets may be longer, but until icmp.Parse is implemented, here
// we parse it using the minimum size.
@@ -878,6 +872,17 @@ func (n *nic) DeliverTransportError(local, remote tcpip.Address, net tcpip.Netwo
}
}
+// DeliverRawPacket implements TransportDispatcher.
+func (n *nic) DeliverRawPacket(protocol tcpip.TransportProtocolNumber, pkt *PacketBuffer) {
+ // For ICMPv4 only we validate the header length for compatibility with
+ // raw(7) ICMP_FILTER. The same check is made in Linux here:
+ // https://github.com/torvalds/linux/blob/70585216/net/ipv4/raw.c#L189.
+ if protocol == header.ICMPv4ProtocolNumber && pkt.TransportHeader().View().Size()+pkt.Data().Size() < header.ICMPv4MinimumSize {
+ return
+ }
+ n.stack.demux.deliverRawPacket(protocol, pkt)
+}
+
// ID implements NetworkInterface.
func (n *nic) ID() tcpip.NICID {
return n.id
diff --git a/pkg/tcpip/stack/packet_buffer.go b/pkg/tcpip/stack/packet_buffer.go
index 4ca702121..9192d8433 100644
--- a/pkg/tcpip/stack/packet_buffer.go
+++ b/pkg/tcpip/stack/packet_buffer.go
@@ -134,7 +134,7 @@ type PacketBuffer struct {
// https://www.man7.org/linux/man-pages/man7/packet.7.html.
PktType tcpip.PacketType
- // NICID is the ID of the interface the network packet was received at.
+ // NICID is the ID of the last interface the network packet was handled at.
NICID tcpip.NICID
// RXTransportChecksumValidated indicates that transport checksum verification
diff --git a/pkg/tcpip/stack/registration.go b/pkg/tcpip/stack/registration.go
index a038389e0..dfe2c886f 100644
--- a/pkg/tcpip/stack/registration.go
+++ b/pkg/tcpip/stack/registration.go
@@ -265,6 +265,11 @@ type TransportDispatcher interface {
//
// DeliverTransportError takes ownership of the packet buffer.
DeliverTransportError(local, remote tcpip.Address, _ tcpip.NetworkProtocolNumber, _ tcpip.TransportProtocolNumber, _ TransportError, _ *PacketBuffer)
+
+ // DeliverRawPacket delivers a packet to any subscribed raw sockets.
+ //
+ // DeliverRawPacket does NOT take ownership of the packet buffer.
+ DeliverRawPacket(tcpip.TransportProtocolNumber, *PacketBuffer)
}
// PacketLooping specifies where an outbound packet should be sent.
diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go
index 40d277312..81fabe29a 100644
--- a/pkg/tcpip/stack/stack.go
+++ b/pkg/tcpip/stack/stack.go
@@ -108,7 +108,7 @@ type Stack struct {
handleLocal bool
// tables are the iptables packet filtering and manipulation rules.
- // TODO(gvisor.dev/issue/170): S/R this field.
+ // TODO(gvisor.dev/issue/4595): S/R this field.
tables *IPTables
// resumableEndpoints is a list of endpoints that need to be resumed if the
@@ -1872,9 +1872,8 @@ const (
// ParsePacketBufferTransport parses the provided packet buffer's transport
// header.
func (s *Stack) ParsePacketBufferTransport(protocol tcpip.TransportProtocolNumber, pkt *PacketBuffer) ParseResult {
- // TODO(gvisor.dev/issue/170): ICMP packets don't have their TransportHeader
- // fields set yet, parse it here. See icmp/protocol.go:protocol.Parse for a
- // full explanation.
+ // ICMP packets don't have their TransportHeader fields set yet, parse it
+ // here. See icmp/protocol.go:protocol.Parse for a full explanation.
if protocol == header.ICMPv4ProtocolNumber || protocol == header.ICMPv6ProtocolNumber {
return ParsedOK
}
diff --git a/pkg/tcpip/stack/tcp.go b/pkg/tcpip/stack/tcp.go
index e90c1a770..90a8ba6cf 100644
--- a/pkg/tcpip/stack/tcp.go
+++ b/pkg/tcpip/stack/tcp.go
@@ -380,9 +380,6 @@ type TCPSndBufState struct {
// SndClosed indicates that the endpoint has been closed for sends.
SndClosed bool
- // SndBufInQueue is the number of bytes in the send queue.
- SndBufInQueue seqnum.Size
-
// PacketTooBigCount is used to notify the main protocol routine how
// many times a "packet too big" control packet is received.
PacketTooBigCount int
diff --git a/pkg/tcpip/stack/transport_demuxer.go b/pkg/tcpip/stack/transport_demuxer.go
index 8a8454a6a..dda57e225 100644
--- a/pkg/tcpip/stack/transport_demuxer.go
+++ b/pkg/tcpip/stack/transport_demuxer.go
@@ -16,6 +16,7 @@ package stack
import (
"fmt"
+
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/hash/jenkins"
@@ -215,10 +216,17 @@ func (epsByNIC *endpointsByNIC) registerEndpoint(d *transportDemuxer, netProto t
netProto: netProto,
transProto: transProto,
}
- epsByNIC.endpoints[bindToDevice] = multiPortEp
}
- return multiPortEp.singleRegisterEndpoint(t, flags)
+ if err := multiPortEp.singleRegisterEndpoint(t, flags); err != nil {
+ return err
+ }
+ // Only add this newly created multiportEndpoint if the singleRegisterEndpoint
+ // succeeded.
+ if !ok {
+ epsByNIC.endpoints[bindToDevice] = multiPortEp
+ }
+ return nil
}
func (epsByNIC *endpointsByNIC) checkEndpoint(flags ports.Flags, bindToDevice tcpip.NICID) tcpip.Error {
@@ -405,7 +413,6 @@ func (ep *multiPortEndpoint) handlePacketAll(id TransportEndpointID, pkt *Packet
func (ep *multiPortEndpoint) singleRegisterEndpoint(t TransportEndpoint, flags ports.Flags) tcpip.Error {
ep.mu.Lock()
defer ep.mu.Unlock()
-
bits := flags.Bits() & ports.MultiBindFlagMask
if len(ep.endpoints) != 0 {
@@ -468,17 +475,21 @@ func (d *transportDemuxer) singleRegisterEndpoint(netProto tcpip.NetworkProtocol
eps.mu.Lock()
defer eps.mu.Unlock()
-
epsByNIC, ok := eps.endpoints[id]
if !ok {
epsByNIC = &endpointsByNIC{
endpoints: make(map[tcpip.NICID]*multiPortEndpoint),
seed: d.stack.Seed(),
}
+ }
+ if err := epsByNIC.registerEndpoint(d, netProto, protocol, ep, flags, bindToDevice); err != nil {
+ return err
+ }
+ // Only add this newly created epsByNIC if registerEndpoint succeeded.
+ if !ok {
eps.endpoints[id] = epsByNIC
}
-
- return epsByNIC.registerEndpoint(d, netProto, protocol, ep, flags, bindToDevice)
+ return nil
}
func (d *transportDemuxer) singleCheckEndpoint(netProto tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, flags ports.Flags, bindToDevice tcpip.NICID) tcpip.Error {
diff --git a/pkg/tcpip/stack/transport_demuxer_test.go b/pkg/tcpip/stack/transport_demuxer_test.go
index 0972c94de..45b09110d 100644
--- a/pkg/tcpip/stack/transport_demuxer_test.go
+++ b/pkg/tcpip/stack/transport_demuxer_test.go
@@ -203,6 +203,56 @@ func TestTransportDemuxerRegister(t *testing.T) {
}
}
+func TestTransportDemuxerRegisterMultiple(t *testing.T) {
+ type test struct {
+ flags ports.Flags
+ want tcpip.Error
+ }
+ for _, subtest := range []struct {
+ name string
+ tests []test
+ }{
+ {"zeroFlags", []test{
+ {ports.Flags{}, nil},
+ {ports.Flags{}, &tcpip.ErrPortInUse{}},
+ }},
+ {"multibindFlags", []test{
+ // Allow multiple registrations same TransportEndpointID with multibind flags.
+ {ports.Flags{LoadBalanced: true, MostRecent: true}, nil},
+ {ports.Flags{LoadBalanced: true, MostRecent: true}, nil},
+ // Disallow registration w/same ID for a non-multibindflag.
+ {ports.Flags{TupleOnly: true}, &tcpip.ErrPortInUse{}},
+ }},
+ } {
+ t.Run(subtest.name, func(t *testing.T) {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol},
+ TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol},
+ })
+ var eps []tcpip.Endpoint
+ for idx, test := range subtest.tests {
+ var wq waiter.Queue
+ ep, err := s.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, &wq)
+ if err != nil {
+ t.Fatal(err)
+ }
+ eps = append(eps, ep)
+ tEP, ok := ep.(stack.TransportEndpoint)
+ if !ok {
+ t.Fatalf("%T does not implement stack.TransportEndpoint", ep)
+ }
+ id := stack.TransportEndpointID{LocalPort: 1}
+ if got, want := s.RegisterTransportEndpoint([]tcpip.NetworkProtocolNumber{ipv4.ProtocolNumber}, udp.ProtocolNumber, id, tEP, test.flags, 0), test.want; got != want {
+ t.Fatalf("test index: %d, s.RegisterTransportEndpoint(ipv4.ProtocolNumber, udp.ProtocolNumber, _, _, %+v, 0) = %s, want %s", idx, test.flags, got, want)
+ }
+ }
+ for _, ep := range eps {
+ ep.Close()
+ }
+ })
+ }
+}
+
// TestBindToDeviceDistribution injects varied packets on input devices and checks that
// the distribution of packets received matches expectations.
func TestBindToDeviceDistribution(t *testing.T) {
diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go
index 91622fa4c..8f2658f64 100644
--- a/pkg/tcpip/tcpip.go
+++ b/pkg/tcpip/tcpip.go
@@ -465,11 +465,11 @@ type ControlMessages struct {
// PacketOwner is used to get UID and GID of the packet.
type PacketOwner interface {
- // UID returns UID of the packet.
- UID() uint32
+ // UID returns KUID of the packet.
+ KUID() uint32
- // GID returns GID of the packet.
- GID() uint32
+ // GID returns KGID of the packet.
+ KGID() uint32
}
// ReadOptions contains options for Endpoint.Read.
diff --git a/pkg/tcpip/tests/integration/iptables_test.go b/pkg/tcpip/tests/integration/iptables_test.go
index 07ba2b837..f9ab7d0af 100644
--- a/pkg/tcpip/tests/integration/iptables_test.go
+++ b/pkg/tcpip/tests/integration/iptables_test.go
@@ -166,7 +166,7 @@ func TestIPTablesStatsForInput(t *testing.T) {
// Make sure the packet is not dropped by the next rule.
filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{}
if err := ipt.ReplaceTable(stack.FilterID, filter, true /* ipv6 */); err != nil {
- t.Fatalf("ipt.RelaceTable(%d, _, %t): %s", stack.FilterID, true, err)
+ t.Fatalf("ipt.ReplaceTable(%d, _, %t): %s", stack.FilterID, true, err)
}
},
genPacket: genPacketV6,
@@ -187,7 +187,7 @@ func TestIPTablesStatsForInput(t *testing.T) {
filter.Rules[ruleIdx].Matchers = []stack.Matcher{&inputIfNameMatcher{nicName}}
filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{}
if err := ipt.ReplaceTable(stack.FilterID, filter, false /* ipv6 */); err != nil {
- t.Fatalf("ipt.RelaceTable(%d, _, %t): %s", stack.FilterID, false, err)
+ t.Fatalf("ipt.ReplaceTable(%d, _, %t): %s", stack.FilterID, false, err)
}
},
genPacket: genPacketV4,
@@ -207,7 +207,7 @@ func TestIPTablesStatsForInput(t *testing.T) {
filter.Rules[ruleIdx].Target = &stack.DropTarget{}
filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{}
if err := ipt.ReplaceTable(stack.FilterID, filter, true /* ipv6 */); err != nil {
- t.Fatalf("ipt.RelaceTable(%d, _, %t): %s", stack.FilterID, true, err)
+ t.Fatalf("ipt.ReplaceTable(%d, _, %t): %s", stack.FilterID, true, err)
}
},
genPacket: genPacketV6,
@@ -227,7 +227,7 @@ func TestIPTablesStatsForInput(t *testing.T) {
filter.Rules[ruleIdx].Target = &stack.DropTarget{}
filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{}
if err := ipt.ReplaceTable(stack.FilterID, filter, false /* ipv6 */); err != nil {
- t.Fatalf("ipt.RelaceTable(%d, _, %t): %s", stack.FilterID, false, err)
+ t.Fatalf("ipt.ReplaceTable(%d, _, %t): %s", stack.FilterID, false, err)
}
},
genPacket: genPacketV4,
@@ -250,7 +250,7 @@ func TestIPTablesStatsForInput(t *testing.T) {
filter.Rules[ruleIdx].Target = &stack.DropTarget{}
filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{}
if err := ipt.ReplaceTable(stack.FilterID, filter, true /* ipv6 */); err != nil {
- t.Fatalf("ipt.RelaceTable(%d, _, %t): %s", stack.FilterID, true, err)
+ t.Fatalf("ipt.ReplaceTable(%d, _, %t): %s", stack.FilterID, true, err)
}
},
genPacket: genPacketV6,
@@ -273,7 +273,7 @@ func TestIPTablesStatsForInput(t *testing.T) {
filter.Rules[ruleIdx].Target = &stack.DropTarget{}
filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{}
if err := ipt.ReplaceTable(stack.FilterID, filter, false /* ipv6 */); err != nil {
- t.Fatalf("ipt.RelaceTable(%d, _, %t): %s", stack.FilterID, false, err)
+ t.Fatalf("ipt.ReplaceTable(%d, _, %t): %s", stack.FilterID, false, err)
}
},
genPacket: genPacketV4,
@@ -293,7 +293,7 @@ func TestIPTablesStatsForInput(t *testing.T) {
filter.Rules[ruleIdx].Matchers = []stack.Matcher{&inputIfNameMatcher{anotherNicName}}
filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{}
if err := ipt.ReplaceTable(stack.FilterID, filter, true /* ipv6 */); err != nil {
- t.Fatalf("ipt.RelaceTable(%d, _, %t): %s", stack.FilterID, true, err)
+ t.Fatalf("ipt.ReplaceTable(%d, _, %t): %s", stack.FilterID, true, err)
}
},
genPacket: genPacketV6,
@@ -313,7 +313,7 @@ func TestIPTablesStatsForInput(t *testing.T) {
filter.Rules[ruleIdx].Matchers = []stack.Matcher{&inputIfNameMatcher{anotherNicName}}
filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{}
if err := ipt.ReplaceTable(stack.FilterID, filter, false /* ipv6 */); err != nil {
- t.Fatalf("ipt.RelaceTable(%d, _, %t): %s", stack.FilterID, false, err)
+ t.Fatalf("ipt.ReplaceTable(%d, _, %t): %s", stack.FilterID, false, err)
}
},
genPacket: genPacketV4,
@@ -465,7 +465,7 @@ func TestIPTableWritePackets(t *testing.T) {
}
if err := s.IPTables().ReplaceTable(stack.FilterID, table, false /* ipv4 */); err != nil {
- t.Fatalf("RelaceTable(%d, _, false): %s", stack.FilterID, err)
+ t.Fatalf("ReplaceTable(%d, _, false): %s", stack.FilterID, err)
}
},
genPacket: func(r *stack.Route) stack.PacketBufferList {
@@ -556,7 +556,7 @@ func TestIPTableWritePackets(t *testing.T) {
}
if err := s.IPTables().ReplaceTable(stack.FilterID, table, true /* ipv6 */); err != nil {
- t.Fatalf("RelaceTable(%d, _, true): %s", stack.FilterID, err)
+ t.Fatalf("ReplaceTable(%d, _, true): %s", stack.FilterID, err)
}
},
genPacket: func(r *stack.Route) stack.PacketBufferList {
@@ -681,6 +681,32 @@ func forwardedICMPv6EchoReplyChecker(t *testing.T, b []byte, src, dst tcpip.Addr
checker.ICMPv6Type(header.ICMPv6EchoReply)))
}
+func boolToInt(v bool) uint64 {
+ if v {
+ return 1
+ }
+ return 0
+}
+
+func setupDropFilter(hook stack.Hook, f stack.IPHeaderFilter) func(*testing.T, *stack.Stack, tcpip.NetworkProtocolNumber) {
+ return func(t *testing.T, s *stack.Stack, netProto tcpip.NetworkProtocolNumber) {
+ t.Helper()
+
+ ipv6 := netProto == ipv6.ProtocolNumber
+
+ ipt := s.IPTables()
+ filter := ipt.GetTable(stack.FilterID, ipv6)
+ ruleIdx := filter.BuiltinChains[hook]
+ filter.Rules[ruleIdx].Filter = f
+ filter.Rules[ruleIdx].Target = &stack.DropTarget{NetworkProtocol: netProto}
+ // Make sure the packet is not dropped by the next rule.
+ filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{NetworkProtocol: netProto}
+ if err := ipt.ReplaceTable(stack.FilterID, filter, ipv6); err != nil {
+ t.Fatalf("ipt.ReplaceTable(%d, _, %t): %s", stack.FilterID, ipv6, err)
+ }
+ }
+}
+
func TestForwardingHook(t *testing.T) {
const (
nicID1 = 1
@@ -740,32 +766,6 @@ func TestForwardingHook(t *testing.T) {
},
}
- setupDropFilter := func(f stack.IPHeaderFilter) func(*testing.T, *stack.Stack, tcpip.NetworkProtocolNumber) {
- return func(t *testing.T, s *stack.Stack, netProto tcpip.NetworkProtocolNumber) {
- t.Helper()
-
- ipv6 := netProto == ipv6.ProtocolNumber
-
- ipt := s.IPTables()
- filter := ipt.GetTable(stack.FilterID, ipv6)
- ruleIdx := filter.BuiltinChains[stack.Forward]
- filter.Rules[ruleIdx].Filter = f
- filter.Rules[ruleIdx].Target = &stack.DropTarget{NetworkProtocol: netProto}
- // Make sure the packet is not dropped by the next rule.
- filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{NetworkProtocol: netProto}
- if err := ipt.ReplaceTable(stack.FilterID, filter, ipv6); err != nil {
- t.Fatalf("ipt.RelaceTable(%d, _, %t): %s", stack.FilterID, ipv6, err)
- }
- }
- }
-
- boolToInt := func(v bool) uint64 {
- if v {
- return 1
- }
- return 0
- }
-
subTests := []struct {
name string
setupFilter func(*testing.T, *stack.Stack, tcpip.NetworkProtocolNumber)
@@ -779,59 +779,59 @@ func TestForwardingHook(t *testing.T) {
{
name: "Drop",
- setupFilter: setupDropFilter(stack.IPHeaderFilter{}),
+ setupFilter: setupDropFilter(stack.Forward, stack.IPHeaderFilter{}),
expectForward: false,
},
{
name: "Drop with input NIC filtering",
- setupFilter: setupDropFilter(stack.IPHeaderFilter{InputInterface: nic1Name}),
+ setupFilter: setupDropFilter(stack.Forward, stack.IPHeaderFilter{InputInterface: nic1Name}),
expectForward: false,
},
{
name: "Drop with output NIC filtering",
- setupFilter: setupDropFilter(stack.IPHeaderFilter{OutputInterface: nic2Name}),
+ setupFilter: setupDropFilter(stack.Forward, stack.IPHeaderFilter{OutputInterface: nic2Name}),
expectForward: false,
},
{
name: "Drop with input and output NIC filtering",
- setupFilter: setupDropFilter(stack.IPHeaderFilter{InputInterface: nic1Name, OutputInterface: nic2Name}),
+ setupFilter: setupDropFilter(stack.Forward, stack.IPHeaderFilter{InputInterface: nic1Name, OutputInterface: nic2Name}),
expectForward: false,
},
{
name: "Drop with other input NIC filtering",
- setupFilter: setupDropFilter(stack.IPHeaderFilter{InputInterface: otherNICName}),
+ setupFilter: setupDropFilter(stack.Forward, stack.IPHeaderFilter{InputInterface: otherNICName}),
expectForward: true,
},
{
name: "Drop with other output NIC filtering",
- setupFilter: setupDropFilter(stack.IPHeaderFilter{OutputInterface: otherNICName}),
+ setupFilter: setupDropFilter(stack.Forward, stack.IPHeaderFilter{OutputInterface: otherNICName}),
expectForward: true,
},
{
name: "Drop with other input and output NIC filtering",
- setupFilter: setupDropFilter(stack.IPHeaderFilter{InputInterface: otherNICName, OutputInterface: nic2Name}),
+ setupFilter: setupDropFilter(stack.Forward, stack.IPHeaderFilter{InputInterface: otherNICName, OutputInterface: nic2Name}),
expectForward: true,
},
{
name: "Drop with input and other output NIC filtering",
- setupFilter: setupDropFilter(stack.IPHeaderFilter{InputInterface: nic1Name, OutputInterface: otherNICName}),
+ setupFilter: setupDropFilter(stack.Forward, stack.IPHeaderFilter{InputInterface: nic1Name, OutputInterface: otherNICName}),
expectForward: true,
},
{
name: "Drop with other input and other output NIC filtering",
- setupFilter: setupDropFilter(stack.IPHeaderFilter{InputInterface: otherNICName, OutputInterface: otherNICName}),
+ setupFilter: setupDropFilter(stack.Forward, stack.IPHeaderFilter{InputInterface: otherNICName, OutputInterface: otherNICName}),
expectForward: true,
},
{
name: "Drop with inverted input NIC filtering",
- setupFilter: setupDropFilter(stack.IPHeaderFilter{InputInterface: nic1Name, InputInterfaceInvert: true}),
+ setupFilter: setupDropFilter(stack.Forward, stack.IPHeaderFilter{InputInterface: nic1Name, InputInterfaceInvert: true}),
expectForward: true,
},
{
name: "Drop with inverted output NIC filtering",
- setupFilter: setupDropFilter(stack.IPHeaderFilter{OutputInterface: nic2Name, OutputInterfaceInvert: true}),
+ setupFilter: setupDropFilter(stack.Forward, stack.IPHeaderFilter{OutputInterface: nic2Name, OutputInterfaceInvert: true}),
expectForward: true,
},
}
@@ -941,3 +941,194 @@ func TestForwardingHook(t *testing.T) {
})
}
}
+
+func TestInputHookWithLocalForwarding(t *testing.T) {
+ const (
+ nicID1 = 1
+ nicID2 = 2
+
+ nic1Name = "nic1"
+ nic2Name = "nic2"
+
+ otherNICName = "otherNIC"
+ )
+
+ tests := []struct {
+ name string
+ netProto tcpip.NetworkProtocolNumber
+ rx func(*channel.Endpoint)
+ checker func(*testing.T, []byte)
+ }{
+ {
+ name: "IPv4",
+ netProto: ipv4.ProtocolNumber,
+ rx: func(e *channel.Endpoint) {
+ utils.RxICMPv4EchoRequest(e, utils.RemoteIPv4Addr, utils.Ipv4Addr2.AddressWithPrefix.Address, ttl)
+ },
+ checker: func(t *testing.T, b []byte) {
+ checker.IPv4(t, b,
+ checker.SrcAddr(utils.Ipv4Addr2.AddressWithPrefix.Address),
+ checker.DstAddr(utils.RemoteIPv4Addr),
+ checker.ICMPv4(
+ checker.ICMPv4Type(header.ICMPv4EchoReply)))
+ },
+ },
+ {
+ name: "IPv6",
+ netProto: ipv6.ProtocolNumber,
+ rx: func(e *channel.Endpoint) {
+ utils.RxICMPv6EchoRequest(e, utils.RemoteIPv6Addr, utils.Ipv6Addr2.AddressWithPrefix.Address, ttl)
+ },
+ checker: func(t *testing.T, b []byte) {
+ checker.IPv6(t, b,
+ checker.SrcAddr(utils.Ipv6Addr2.AddressWithPrefix.Address),
+ checker.DstAddr(utils.RemoteIPv6Addr),
+ checker.ICMPv6(
+ checker.ICMPv6Type(header.ICMPv6EchoReply)))
+ },
+ },
+ }
+
+ subTests := []struct {
+ name string
+ setupFilter func(*testing.T, *stack.Stack, tcpip.NetworkProtocolNumber)
+ expectDrop bool
+ }{
+ {
+ name: "Accept",
+ setupFilter: func(*testing.T, *stack.Stack, tcpip.NetworkProtocolNumber) { /* no filter */ },
+ expectDrop: false,
+ },
+
+ {
+ name: "Drop",
+ setupFilter: setupDropFilter(stack.Input, stack.IPHeaderFilter{}),
+ expectDrop: true,
+ },
+ {
+ name: "Drop with input NIC filtering on arrival NIC",
+ setupFilter: setupDropFilter(stack.Input, stack.IPHeaderFilter{InputInterface: nic1Name}),
+ expectDrop: true,
+ },
+ {
+ name: "Drop with input NIC filtering on delivered NIC",
+ setupFilter: setupDropFilter(stack.Input, stack.IPHeaderFilter{InputInterface: nic2Name}),
+ expectDrop: false,
+ },
+
+ {
+ name: "Drop with input NIC filtering on other NIC",
+ setupFilter: setupDropFilter(stack.Input, stack.IPHeaderFilter{InputInterface: otherNICName}),
+ expectDrop: false,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ for _, subTest := range subTests {
+ t.Run(subTest.name, func(t *testing.T) {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol},
+ })
+
+ subTest.setupFilter(t, s, test.netProto)
+
+ e1 := channel.New(1, header.IPv6MinimumMTU, "")
+ if err := s.CreateNICWithOptions(nicID1, e1, stack.NICOptions{Name: nic1Name}); err != nil {
+ t.Fatalf("s.CreateNICWithOptions(%d, _, _): %s", nicID1, err)
+ }
+ if err := s.AddProtocolAddress(nicID1, utils.Ipv4Addr1); err != nil {
+ t.Fatalf("s.AddProtocolAddress(%d, %#v): %s", nicID1, utils.Ipv4Addr1, err)
+ }
+ if err := s.AddProtocolAddress(nicID1, utils.Ipv6Addr1); err != nil {
+ t.Fatalf("s.AddProtocolAddress(%d, %#v): %s", nicID1, utils.Ipv6Addr1, err)
+ }
+
+ e2 := channel.New(1, header.IPv6MinimumMTU, "")
+ if err := s.CreateNICWithOptions(nicID2, e2, stack.NICOptions{Name: nic2Name}); err != nil {
+ t.Fatalf("s.CreateNICWithOptions(%d, _, _): %s", nicID2, err)
+ }
+ if err := s.AddProtocolAddress(nicID2, utils.Ipv4Addr2); err != nil {
+ t.Fatalf("s.AddProtocolAddress(%d, %#v): %s", nicID2, utils.Ipv4Addr2, err)
+ }
+ if err := s.AddProtocolAddress(nicID2, utils.Ipv6Addr2); err != nil {
+ t.Fatalf("s.AddProtocolAddress(%d, %#v): %s", nicID2, utils.Ipv6Addr2, err)
+ }
+
+ if err := s.SetForwardingDefaultAndAllNICs(ipv4.ProtocolNumber, true); err != nil {
+ t.Fatalf("s.SetForwardingDefaultAndAllNICs(%d, true): %s", ipv4.ProtocolNumber, err)
+ }
+ if err := s.SetForwardingDefaultAndAllNICs(ipv6.ProtocolNumber, true); err != nil {
+ t.Fatalf("s.SetForwardingDefaultAndAllNICs(%d, true): %s", ipv6.ProtocolNumber, err)
+ }
+
+ s.SetRouteTable([]tcpip.Route{
+ {
+ Destination: header.IPv4EmptySubnet,
+ NIC: nicID1,
+ },
+ {
+ Destination: header.IPv6EmptySubnet,
+ NIC: nicID1,
+ },
+ })
+
+ test.rx(e1)
+
+ ep1, err := s.GetNetworkEndpoint(nicID1, test.netProto)
+ if err != nil {
+ t.Fatalf("s.GetNetworkEndpoint(%d, %d): %s", nicID1, test.netProto, err)
+ }
+ ep1Stats := ep1.Stats()
+ ipEP1Stats, ok := ep1Stats.(stack.IPNetworkEndpointStats)
+ if !ok {
+ t.Fatalf("got ep1Stats = %T, want = stack.IPNetworkEndpointStats", ep1Stats)
+ }
+ ip1Stats := ipEP1Stats.IPStats()
+
+ if got := ip1Stats.PacketsReceived.Value(); got != 1 {
+ t.Errorf("got ip1Stats.PacketsReceived.Value() = %d, want = 1", got)
+ }
+ if got := ip1Stats.ValidPacketsReceived.Value(); got != 1 {
+ t.Errorf("got ip1Stats.ValidPacketsReceived.Value() = %d, want = 1", got)
+ }
+ if got, want := ip1Stats.PacketsSent.Value(), boolToInt(!subTest.expectDrop); got != want {
+ t.Errorf("got ip1Stats.PacketsSent.Value() = %d, want = %d", got, want)
+ }
+
+ ep2, err := s.GetNetworkEndpoint(nicID2, test.netProto)
+ if err != nil {
+ t.Fatalf("s.GetNetworkEndpoint(%d, %d): %s", nicID2, test.netProto, err)
+ }
+ ep2Stats := ep2.Stats()
+ ipEP2Stats, ok := ep2Stats.(stack.IPNetworkEndpointStats)
+ if !ok {
+ t.Fatalf("got ep2Stats = %T, want = stack.IPNetworkEndpointStats", ep2Stats)
+ }
+ ip2Stats := ipEP2Stats.IPStats()
+ if got := ip2Stats.PacketsReceived.Value(); got != 0 {
+ t.Errorf("got ip2Stats.PacketsReceived.Value() = %d, want = 0", got)
+ }
+ if got := ip2Stats.ValidPacketsReceived.Value(); got != 1 {
+ t.Errorf("got ip2Stats.ValidPacketsReceived.Value() = %d, want = 1", got)
+ }
+ if got, want := ip2Stats.IPTablesInputDropped.Value(), boolToInt(subTest.expectDrop); got != want {
+ t.Errorf("got ip2Stats.IPTablesInputDropped.Value() = %d, want = %d", got, want)
+ }
+ if got := ip2Stats.PacketsSent.Value(); got != 0 {
+ t.Errorf("got ip2Stats.PacketsSent.Value() = %d, want = 0", got)
+ }
+
+ if p, ok := e1.Read(); ok == subTest.expectDrop {
+ t.Errorf("got e1.Read() = (%#v, %t), want = (_, %t)", p, ok, !subTest.expectDrop)
+ } else if !subTest.expectDrop {
+ test.checker(t, stack.PayloadSince(p.Pkt.NetworkHeader()))
+ }
+ if p, ok := e2.Read(); ok {
+ t.Errorf("got e1.Read() = (%#v, true), want = (_, false)", p)
+ }
+ })
+ }
+ })
+ }
+}
diff --git a/pkg/tcpip/tests/integration/loopback_test.go b/pkg/tcpip/tests/integration/loopback_test.go
index 87d36e1dd..b2008f0b2 100644
--- a/pkg/tcpip/tests/integration/loopback_test.go
+++ b/pkg/tcpip/tests/integration/loopback_test.go
@@ -44,20 +44,17 @@ type ndpDispatcher struct{}
func (*ndpDispatcher) OnDuplicateAddressDetectionResult(tcpip.NICID, tcpip.Address, stack.DADResult) {
}
-func (*ndpDispatcher) OnDefaultRouterDiscovered(tcpip.NICID, tcpip.Address) bool {
- return false
+func (*ndpDispatcher) OnOffLinkRouteUpdated(tcpip.NICID, tcpip.Subnet, tcpip.Address, header.NDPRoutePreference) {
}
-func (*ndpDispatcher) OnDefaultRouterInvalidated(tcpip.NICID, tcpip.Address) {}
+func (*ndpDispatcher) OnOffLinkRouteInvalidated(tcpip.NICID, tcpip.Subnet, tcpip.Address) {}
-func (*ndpDispatcher) OnOnLinkPrefixDiscovered(tcpip.NICID, tcpip.Subnet) bool {
- return false
+func (*ndpDispatcher) OnOnLinkPrefixDiscovered(tcpip.NICID, tcpip.Subnet) {
}
func (*ndpDispatcher) OnOnLinkPrefixInvalidated(tcpip.NICID, tcpip.Subnet) {}
-func (*ndpDispatcher) OnAutoGenAddress(tcpip.NICID, tcpip.AddressWithPrefix) bool {
- return true
+func (*ndpDispatcher) OnAutoGenAddress(tcpip.NICID, tcpip.AddressWithPrefix) {
}
func (*ndpDispatcher) OnAutoGenAddressDeprecated(tcpip.NICID, tcpip.AddressWithPrefix) {}
diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go
index fb77febcf..f9a15efb2 100644
--- a/pkg/tcpip/transport/icmp/endpoint.go
+++ b/pkg/tcpip/transport/icmp/endpoint.go
@@ -213,6 +213,7 @@ func (e *endpoint) Read(dst io.Writer, opts tcpip.ReadOptions) (tcpip.ReadResult
// reacquire the mutex in exclusive mode.
//
// Returns true for retry if preparation should be retried.
+// +checklocks:e.mu
func (e *endpoint) prepareForWrite(to *tcpip.FullAddress) (retry bool, err tcpip.Error) {
switch e.state {
case stateInitial:
@@ -229,10 +230,8 @@ func (e *endpoint) prepareForWrite(to *tcpip.FullAddress) (retry bool, err tcpip
}
e.mu.RUnlock()
- defer e.mu.RLock()
-
e.mu.Lock()
- defer e.mu.Unlock()
+ defer e.mu.DowngradeLock()
// The state changed when we released the shared locked and re-acquired
// it in exclusive mode. Try again.
@@ -758,8 +757,6 @@ func (e *endpoint) HandlePacket(id stack.TransportEndpointID, pkt *stack.PacketB
switch e.NetProto {
case header.IPv4ProtocolNumber:
h := header.ICMPv4(pkt.TransportHeader().View())
- // TODO(gvisor.dev/issue/170): Determine if len(h) check is still needed
- // after early parsing.
if len(h) < header.ICMPv4MinimumSize || h.Type() != header.ICMPv4EchoReply {
e.stack.Stats().DroppedPackets.Increment()
e.stats.ReceiveErrors.MalformedPacketsReceived.Increment()
@@ -767,8 +764,6 @@ func (e *endpoint) HandlePacket(id stack.TransportEndpointID, pkt *stack.PacketB
}
case header.IPv6ProtocolNumber:
h := header.ICMPv6(pkt.TransportHeader().View())
- // TODO(gvisor.dev/issue/170): Determine if len(h) check is still needed
- // after early parsing.
if len(h) < header.ICMPv6MinimumSize || h.Type() != header.ICMPv6EchoReply {
e.stack.Stats().DroppedPackets.Increment()
e.stats.ReceiveErrors.MalformedPacketsReceived.Increment()
diff --git a/pkg/tcpip/transport/icmp/protocol.go b/pkg/tcpip/transport/icmp/protocol.go
index 47f7dd1cb..fa82affc1 100644
--- a/pkg/tcpip/transport/icmp/protocol.go
+++ b/pkg/tcpip/transport/icmp/protocol.go
@@ -123,8 +123,6 @@ func (*protocol) Wait() {}
// Parse implements stack.TransportProtocol.Parse.
func (*protocol) Parse(pkt *stack.PacketBuffer) bool {
- // TODO(gvisor.dev/issue/170): Implement parsing of ICMP.
- //
// Right now, the Parse() method is tied to enabled protocols passed into
// stack.New. This works for UDP and TCP, but we handle ICMP traffic even
// when netstack users don't pass ICMP as a supported protocol.
diff --git a/pkg/tcpip/transport/packet/endpoint.go b/pkg/tcpip/transport/packet/endpoint.go
index cd8c99d41..8e7bb6c6e 100644
--- a/pkg/tcpip/transport/packet/endpoint.go
+++ b/pkg/tcpip/transport/packet/endpoint.go
@@ -208,7 +208,6 @@ func (ep *endpoint) Read(dst io.Writer, opts tcpip.ReadOptions) (tcpip.ReadResul
}
func (*endpoint) Write(tcpip.Payloader, tcpip.WriteOptions) (int64, tcpip.Error) {
- // TODO(gvisor.dev/issue/173): Implement.
return 0, &tcpip.ErrInvalidOptionValue{}
}
@@ -244,8 +243,6 @@ func (*endpoint) Accept(*tcpip.FullAddress) (tcpip.Endpoint, *waiter.Queue, tcpi
// Bind implements tcpip.Endpoint.Bind.
func (ep *endpoint) Bind(addr tcpip.FullAddress) tcpip.Error {
- // TODO(gvisor.dev/issue/173): Add Bind support.
-
// "By default, all packets of the specified protocol type are passed
// to a packet socket. To get packets only from a specific interface
// use bind(2) specifying an address in a struct sockaddr_ll to bind
@@ -385,7 +382,6 @@ func (ep *endpoint) HandlePacket(nicID tcpip.NICID, localAddr tcpip.LinkAddress,
// Push new packet into receive list and increment the buffer size.
var packet packet
- // TODO(gvisor.dev/issue/173): Return network protocol.
if !pkt.LinkHeader().View().IsEmpty() {
// Get info directly from the ethernet header.
hdr := header.Ethernet(pkt.LinkHeader().View())
@@ -424,7 +420,6 @@ func (ep *endpoint) HandlePacket(nicID tcpip.NICID, localAddr tcpip.LinkAddress,
default:
panic(fmt.Sprintf("unexpected PktType in pkt: %+v", pkt))
}
-
} else {
// Raw packets need their ethernet headers prepended before
// queueing.
diff --git a/pkg/tcpip/transport/raw/endpoint.go b/pkg/tcpip/transport/raw/endpoint.go
index 1bce2769a..b6687911a 100644
--- a/pkg/tcpip/transport/raw/endpoint.go
+++ b/pkg/tcpip/transport/raw/endpoint.go
@@ -286,26 +286,6 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcp
return nil, nil, nil, &tcpip.ErrBadBuffer{}
}
- // If this is an unassociated socket and callee provided a nonzero
- // destination address, route using that address.
- if e.ops.GetHeaderIncluded() {
- ip := header.IPv4(payloadBytes)
- if !ip.IsValid(len(payloadBytes)) {
- return nil, nil, nil, &tcpip.ErrInvalidOptionValue{}
- }
- dstAddr := ip.DestinationAddress()
- // Update dstAddr with the address in the IP header, unless
- // opts.To is set (e.g. if sendto specifies a specific
- // address).
- if dstAddr != tcpip.Address([]byte{0, 0, 0, 0}) && opts.To == nil {
- opts.To = &tcpip.FullAddress{
- NIC: 0, // NIC is unset.
- Addr: dstAddr, // The address from the payload.
- Port: 0, // There are no ports here.
- }
- }
- }
-
// Did the user caller provide a destination? If not, use the connected
// destination.
if opts.To == nil {
diff --git a/pkg/tcpip/transport/tcp/accept.go b/pkg/tcpip/transport/tcp/accept.go
index 2c65b737d..aa413ad05 100644
--- a/pkg/tcpip/transport/tcp/accept.go
+++ b/pkg/tcpip/transport/tcp/accept.go
@@ -330,7 +330,9 @@ func (l *listenContext) performHandshake(s *segment, opts *header.TCPSynOptions,
}
ep := h.ep
- if err := h.complete(); err != nil {
+ // N.B. the endpoint is generated above by startHandshake, and will be
+ // returned locked. This first call is forced.
+ if err := h.complete(); err != nil { // +checklocksforce
ep.stack.Stats().TCP.FailedConnectionAttempts.Increment()
ep.stats.FailedConnectionAttempts.Increment()
l.cleanupFailedHandshake(h)
@@ -364,6 +366,7 @@ func (l *listenContext) closeAllPendingEndpoints() {
}
// Precondition: h.ep.mu must be held.
+// +checklocks:h.ep.mu
func (l *listenContext) cleanupFailedHandshake(h *handshake) {
e := h.ep
e.mu.Unlock()
@@ -504,7 +507,9 @@ func (e *endpoint) handleSynSegment(ctx *listenContext, s *segment, opts *header
}
go func() {
- if err := h.complete(); err != nil {
+ // Note that startHandshake returns a locked endpoint. The
+ // force call here just makes it so.
+ if err := h.complete(); err != nil { // +checklocksforce
e.stack.Stats().TCP.FailedConnectionAttempts.Increment()
e.stats.FailedConnectionAttempts.Increment()
ctx.cleanupFailedHandshake(h)
@@ -560,6 +565,10 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) tcpip.Err
}
switch {
+ case s.flags.Contains(header.TCPFlagRst):
+ e.stack.Stats().DroppedPackets.Increment()
+ return nil
+
case s.flags == header.TCPFlagSyn:
if e.acceptQueueIsFull() {
e.stack.Stats().TCP.ListenOverflowSynDrop.Increment()
@@ -611,7 +620,7 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) tcpip.Err
e.stack.Stats().TCP.ListenOverflowSynCookieSent.Increment()
return nil
- case (s.flags & header.TCPFlagAck) != 0:
+ case s.flags.Contains(header.TCPFlagAck):
if e.acceptQueueIsFull() {
// Silently drop the ack as the application can't accept
// the connection at this point. The ack will be
@@ -736,6 +745,13 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) tcpip.Err
mss: rcvdSynOptions.MSS,
})
+ // Requeue the segment if the ACK completing the handshake has more info
+ // to be procesed by the newly established endpoint.
+ if (s.flags.Contains(header.TCPFlagFin) || s.data.Size() > 0) && n.enqueueSegment(s) {
+ s.incRef()
+ n.newSegmentWaker.Assert()
+ }
+
// Do the delivery in a separate goroutine so
// that we don't block the listen loop in case
// the application is slow to accept or stops
@@ -753,6 +769,7 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) tcpip.Err
return nil
default:
+ e.stack.Stats().DroppedPackets.Increment()
return nil
}
}
diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go
index 570e5081c..93ed161f9 100644
--- a/pkg/tcpip/transport/tcp/connect.go
+++ b/pkg/tcpip/transport/tcp/connect.go
@@ -406,11 +406,11 @@ func (h *handshake) synRcvdState(s *segment) tcpip.Error {
h.ep.transitionToStateEstablishedLocked(h)
- // If the segment has data then requeue it for the receiver
- // to process it again once main loop is started.
- if s.data.Size() > 0 {
+ // Requeue the segment if the ACK completing the handshake has more info
+ // to be procesed by the newly established endpoint.
+ if (s.flags.Contains(header.TCPFlagFin) || s.data.Size() > 0) && h.ep.enqueueSegment(s) {
s.incRef()
- h.ep.enqueueSegment(s)
+ h.ep.newSegmentWaker.Assert()
}
return nil
}
@@ -511,6 +511,7 @@ func (h *handshake) start() {
}
// complete completes the TCP 3-way handshake initiated by h.start().
+// +checklocks:h.ep.mu
func (h *handshake) complete() tcpip.Error {
// Set up the wakers.
var s sleep.Sleeper
@@ -909,30 +910,13 @@ func (e *endpoint) sendRaw(data buffer.VectorisedView, flags header.TCPFlags, se
return err
}
-func (e *endpoint) handleWrite() {
- e.sndQueueInfo.sndQueueMu.Lock()
- next := e.drainSendQueueLocked()
- e.sndQueueInfo.sndQueueMu.Unlock()
-
- e.sendData(next)
-}
-
-// Move packets from send queue to send list.
-//
-// Precondition: e.sndBufMu must be locked.
-func (e *endpoint) drainSendQueueLocked() *segment {
- first := e.sndQueueInfo.sndQueue.Front()
- if first != nil {
- e.snd.writeList.PushBackList(&e.sndQueueInfo.sndQueue)
- e.sndQueueInfo.SndBufInQueue = 0
- }
- return first
-}
-
// Precondition: e.mu must be locked.
func (e *endpoint) sendData(next *segment) {
// Initialize the next segment to write if it's currently nil.
if e.snd.writeNext == nil {
+ if next == nil {
+ return
+ }
e.snd.writeNext = next
}
@@ -940,17 +924,6 @@ func (e *endpoint) sendData(next *segment) {
e.snd.sendData()
}
-func (e *endpoint) handleClose() {
- if !e.EndpointState().connected() {
- return
- }
- // Drain the send queue.
- e.handleWrite()
-
- // Mark send side as closed.
- e.snd.Closed = true
-}
-
// resetConnectionLocked puts the endpoint in an error state with the given
// error code and sends a RST if and only if the error is not ErrConnectionReset
// indicating that the connection is being reset due to receiving a RST. This
@@ -1130,7 +1103,7 @@ func (e *endpoint) handleReset(s *segment) (ok bool, err tcpip.Error) {
func (e *endpoint) handleSegmentsLocked(fastPath bool) tcpip.Error {
checkRequeue := true
for i := 0; i < maxSegmentsPerWake; i++ {
- if e.EndpointState().closed() {
+ if state := e.EndpointState(); state.closed() || state == StateTimeWait {
return nil
}
s := e.segmentQueue.dequeue()
@@ -1311,42 +1284,45 @@ func (e *endpoint) disableKeepaliveTimer() {
e.keepalive.Unlock()
}
-// protocolMainLoop is the main loop of the TCP protocol. It runs in its own
-// goroutine and is responsible for sending segments and handling received
-// segments.
-func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{}) tcpip.Error {
- e.mu.Lock()
- var closeTimer tcpip.Timer
- var closeWaker sleep.Waker
-
- epilogue := func() {
- // e.mu is expected to be hold upon entering this section.
- if e.snd != nil {
- e.snd.resendTimer.cleanup()
- e.snd.probeTimer.cleanup()
- e.snd.reorderTimer.cleanup()
- }
+// protocolMainLoopDone is called at the end of protocolMainLoop.
+// +checklocksrelease:e.mu
+func (e *endpoint) protocolMainLoopDone(closeTimer tcpip.Timer, closeWaker *sleep.Waker) {
+ if e.snd != nil {
+ e.snd.resendTimer.cleanup()
+ e.snd.probeTimer.cleanup()
+ e.snd.reorderTimer.cleanup()
+ }
- if closeTimer != nil {
- closeTimer.Stop()
- }
+ if closeTimer != nil {
+ closeTimer.Stop()
+ }
- e.completeWorkerLocked()
+ e.completeWorkerLocked()
- if e.drainDone != nil {
- close(e.drainDone)
- }
+ if e.drainDone != nil {
+ close(e.drainDone)
+ }
- e.mu.Unlock()
+ e.mu.Unlock()
- e.drainClosingSegmentQueue()
+ e.drainClosingSegmentQueue()
- // When the protocol loop exits we should wake up our waiters.
- e.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.ReadableEvents | waiter.WritableEvents)
- }
+ // When the protocol loop exits we should wake up our waiters.
+ e.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.ReadableEvents | waiter.WritableEvents)
+}
+// protocolMainLoop is the main loop of the TCP protocol. It runs in its own
+// goroutine and is responsible for sending segments and handling received
+// segments.
+func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{}) tcpip.Error {
+ var (
+ closeTimer tcpip.Timer
+ closeWaker sleep.Waker
+ )
+
+ e.mu.Lock()
if handshake {
- if err := e.h.complete(); err != nil {
+ if err := e.h.complete(); err != nil { // +checklocksforce
e.lastErrorMu.Lock()
e.lastError = err
e.lastErrorMu.Unlock()
@@ -1355,8 +1331,7 @@ func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{
e.hardError = err
e.workerCleanup = true
- // Lock released below.
- epilogue()
+ e.protocolMainLoopDone(closeTimer, &closeWaker)
return err
}
}
@@ -1402,14 +1377,7 @@ func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{
{
w: &e.sndQueueInfo.sndWaker,
f: func() tcpip.Error {
- e.handleWrite()
- return nil
- },
- },
- {
- w: &e.sndQueueInfo.sndCloseWaker,
- f: func() tcpip.Error {
- e.handleClose()
+ e.sendData(nil /* next */)
return nil
},
},
@@ -1474,11 +1442,19 @@ func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{
return &tcpip.ErrConnectionReset{}
}
- if n&notifyClose != 0 && closeTimer == nil {
- if e.EndpointState() == StateFinWait2 && e.closed {
+ if n&notifyClose != 0 && e.closed {
+ switch e.EndpointState() {
+ case StateEstablished:
+ // Perform full shutdown if the endpoint is still
+ // established. This can occur when notifyClose
+ // was asserted just before becoming established.
+ e.shutdownLocked(tcpip.ShutdownWrite | tcpip.ShutdownRead)
+ case StateFinWait2:
// The socket has been closed and we are in FIN_WAIT2
// so start the FIN_WAIT2 timer.
- closeTimer = e.stack.Clock().AfterFunc(e.tcpLingerTimeout, closeWaker.Assert)
+ if closeTimer == nil {
+ closeTimer = e.stack.Clock().AfterFunc(e.tcpLingerTimeout, closeWaker.Assert)
+ }
}
}
@@ -1499,7 +1475,7 @@ func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{
// Only block the worker if the endpoint
// is not in closed state or error state.
close(e.drainDone)
- e.mu.Unlock()
+ e.mu.Unlock() // +checklocksforce
<-e.undrain
e.mu.Lock()
}
@@ -1560,8 +1536,6 @@ func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{
if err != nil {
e.resetConnectionLocked(err)
}
- // Lock released below.
- epilogue()
}
loop:
@@ -1585,6 +1559,7 @@ loop:
// just want to terminate the loop and cleanup the
// endpoint.
cleanupOnError(nil)
+ e.protocolMainLoopDone(closeTimer, &closeWaker)
return nil
case StateTimeWait:
fallthrough
@@ -1593,6 +1568,7 @@ loop:
default:
if err := funcs[v].f(); err != nil {
cleanupOnError(err)
+ e.protocolMainLoopDone(closeTimer, &closeWaker)
return nil
}
}
@@ -1616,13 +1592,13 @@ loop:
// Handle any StateError transition from StateTimeWait.
if e.EndpointState() == StateError {
cleanupOnError(nil)
+ e.protocolMainLoopDone(closeTimer, &closeWaker)
return nil
}
e.transitionToStateCloseLocked()
- // Lock released below.
- epilogue()
+ e.protocolMainLoopDone(closeTimer, &closeWaker)
// A new SYN was received during TIME_WAIT and we need to abort
// the timewait and redirect the segment to the listener queue
@@ -1692,6 +1668,7 @@ func (e *endpoint) handleTimeWaitSegments() (extendTimeWait bool, reuseTW func()
// should be executed after releasing the endpoint registrations. This is
// done in cases where a new SYN is received during TIME_WAIT that carries
// a sequence number larger than one see on the connection.
+// +checklocks:e.mu
func (e *endpoint) doTimeWait() (twReuse func()) {
// Trigger a 2 * MSL time wait state. During this period
// we will drop all incoming segments.
diff --git a/pkg/tcpip/transport/tcp/dispatcher.go b/pkg/tcpip/transport/tcp/dispatcher.go
index dff7cb89c..7d110516b 100644
--- a/pkg/tcpip/transport/tcp/dispatcher.go
+++ b/pkg/tcpip/transport/tcp/dispatcher.go
@@ -127,7 +127,7 @@ func (p *processor) start(wg *sync.WaitGroup) {
case !ep.segmentQueue.empty():
p.epQ.enqueue(ep)
}
- ep.mu.Unlock()
+ ep.mu.Unlock() // +checklocksforce
} else {
ep.newSegmentWaker.Assert()
}
diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go
index a27e2110b..1ed4ba419 100644
--- a/pkg/tcpip/transport/tcp/endpoint.go
+++ b/pkg/tcpip/transport/tcp/endpoint.go
@@ -293,16 +293,9 @@ type sndQueueInfo struct {
sndQueueMu sync.Mutex `state:"nosave"`
stack.TCPSndBufState
- // sndQueue holds segments that are ready to be sent.
- sndQueue segmentList `state:"wait"`
-
- // sndWaker is used to signal the protocol goroutine when segments are
- // added to the `sndQueue`.
+ // sndWaker is used to signal the protocol goroutine when there may be
+ // segments that need to be sent.
sndWaker sleep.Waker `state:"manual"`
-
- // sndCloseWaker is used to notify the protocol goroutine when the send
- // side is closed.
- sndCloseWaker sleep.Waker `state:"manual"`
}
// rcvQueueInfo contains the endpoint's rcvQueue and associated metadata.
@@ -671,6 +664,7 @@ func calculateAdvertisedMSS(userMSS uint16, r *stack.Route) uint16 {
// The assumption behind spinning here being that background packet processing
// should not be holding the lock for long and spinning reduces latency as we
// avoid an expensive sleep/wakeup of of the syscall goroutine).
+// +checklocksacquire:e.mu
func (e *endpoint) LockUser() {
for {
// Try first if the sock is locked then check if it's owned
@@ -690,7 +684,7 @@ func (e *endpoint) LockUser() {
continue
}
atomic.StoreUint32(&e.ownedByUser, 1)
- return
+ return // +checklocksforce
}
}
@@ -707,7 +701,7 @@ func (e *endpoint) LockUser() {
// protocol goroutine altogether.
//
// Precondition: e.LockUser() must have been called before calling e.UnlockUser()
-// +checklocks:e.mu
+// +checklocksrelease:e.mu
func (e *endpoint) UnlockUser() {
// Lock segment queue before checking so that we avoid a race where
// segments can be queued between the time we check if queue is empty
@@ -743,12 +737,13 @@ func (e *endpoint) UnlockUser() {
}
// StopWork halts packet processing. Only to be used in tests.
+// +checklocksacquire:e.mu
func (e *endpoint) StopWork() {
e.mu.Lock()
}
// ResumeWork resumes packet processing. Only to be used in tests.
-// +checklocks:e.mu
+// +checklocksrelease:e.mu
func (e *endpoint) ResumeWork() {
e.mu.Unlock()
}
@@ -1487,87 +1482,95 @@ func (e *endpoint) isEndpointWritableLocked() (int, tcpip.Error) {
return avail, nil
}
-// Write writes data to the endpoint's peer.
-func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcpip.Error) {
- // Linux completely ignores any address passed to sendto(2) for TCP sockets
- // (without the MSG_FASTOPEN flag). Corking is unimplemented, so opts.More
- // and opts.EndOfRecord are also ignored.
+// readFromPayloader reads a slice from the Payloader.
+// +checklocks:e.mu
+// +checklocks:e.sndQueueInfo.sndQueueMu
+func (e *endpoint) readFromPayloader(p tcpip.Payloader, opts tcpip.WriteOptions, avail int) ([]byte, tcpip.Error) {
+ // We can release locks while copying data.
+ //
+ // This is not possible if atomic is set, because we can't allow the
+ // available buffer space to be consumed by some other caller while we
+ // are copying data in.
+ if !opts.Atomic {
+ e.sndQueueInfo.sndQueueMu.Unlock()
+ defer e.sndQueueInfo.sndQueueMu.Lock()
- e.LockUser()
- defer e.UnlockUser()
+ e.UnlockUser()
+ defer e.LockUser()
+ }
- nextSeg, n, err := func() (*segment, int, tcpip.Error) {
- e.sndQueueInfo.sndQueueMu.Lock()
- defer e.sndQueueInfo.sndQueueMu.Unlock()
+ // Fetch data.
+ if l := p.Len(); l < avail {
+ avail = l
+ }
+ if avail == 0 {
+ return nil, nil
+ }
+ v := make([]byte, avail)
+ n, err := p.Read(v)
+ if err != nil && err != io.EOF {
+ return nil, &tcpip.ErrBadBuffer{}
+ }
+ return v[:n], nil
+}
+// queueSegment reads data from the payloader and returns a segment to be sent.
+// +checklocks:e.mu
+func (e *endpoint) queueSegment(p tcpip.Payloader, opts tcpip.WriteOptions) (*segment, int, tcpip.Error) {
+ e.sndQueueInfo.sndQueueMu.Lock()
+ defer e.sndQueueInfo.sndQueueMu.Unlock()
+
+ avail, err := e.isEndpointWritableLocked()
+ if err != nil {
+ e.stats.WriteErrors.WriteClosed.Increment()
+ return nil, 0, err
+ }
+
+ v, err := e.readFromPayloader(p, opts, avail)
+ if err != nil {
+ return nil, 0, err
+ }
+ if !opts.Atomic {
+ // Since we released locks in between it's possible that the
+ // endpoint transitioned to a CLOSED/ERROR states so make
+ // sure endpoint is still writable before trying to write.
avail, err := e.isEndpointWritableLocked()
if err != nil {
e.stats.WriteErrors.WriteClosed.Increment()
return nil, 0, err
}
- v, err := func() ([]byte, tcpip.Error) {
- // We can release locks while copying data.
- //
- // This is not possible if atomic is set, because we can't allow the
- // available buffer space to be consumed by some other caller while we
- // are copying data in.
- if !opts.Atomic {
- e.sndQueueInfo.sndQueueMu.Unlock()
- defer e.sndQueueInfo.sndQueueMu.Lock()
-
- e.UnlockUser()
- defer e.LockUser()
- }
-
- // Fetch data.
- if l := p.Len(); l < avail {
- avail = l
- }
- if avail == 0 {
- return nil, nil
- }
- v := make([]byte, avail)
- n, err := p.Read(v)
- if err != nil && err != io.EOF {
- return nil, &tcpip.ErrBadBuffer{}
- }
- return v[:n], nil
- }()
- if len(v) == 0 || err != nil {
- return nil, 0, err
+ // Discard any excess data copied in due to avail being reduced due
+ // to a simultaneous write call to the socket.
+ if avail < len(v) {
+ v = v[:avail]
}
+ }
- if !opts.Atomic {
- // Since we released locks in between it's possible that the
- // endpoint transitioned to a CLOSED/ERROR states so make
- // sure endpoint is still writable before trying to write.
- avail, err := e.isEndpointWritableLocked()
- if err != nil {
- e.stats.WriteErrors.WriteClosed.Increment()
- return nil, 0, err
- }
+ // Add data to the send queue.
+ s := newOutgoingSegment(e.TransportEndpointInfo.ID, e.stack.Clock(), v)
+ e.sndQueueInfo.SndBufUsed += len(v)
+ e.snd.writeList.PushBack(s)
- // Discard any excess data copied in due to avail being reduced due
- // to a simultaneous write call to the socket.
- if avail < len(v) {
- v = v[:avail]
- }
- }
+ return s, len(v), nil
+}
+
+// Write writes data to the endpoint's peer.
+func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcpip.Error) {
+ // Linux completely ignores any address passed to sendto(2) for TCP sockets
+ // (without the MSG_FASTOPEN flag). Corking is unimplemented, so opts.More
+ // and opts.EndOfRecord are also ignored.
- // Add data to the send queue.
- s := newOutgoingSegment(e.TransportEndpointInfo.ID, e.stack.Clock(), v)
- e.sndQueueInfo.SndBufUsed += len(v)
- e.sndQueueInfo.SndBufInQueue += seqnum.Size(len(v))
- e.sndQueueInfo.sndQueue.PushBack(s)
+ e.LockUser()
+ defer e.UnlockUser()
- return e.drainSendQueueLocked(), len(v), nil
- }()
// Return if either we didn't queue anything or if an error occurred while
// attempting to queue data.
+ nextSeg, n, err := e.queueSegment(p, opts)
if n == 0 || err != nil {
return 0, err
}
+
e.sendData(nextSeg)
return int64(n), nil
}
@@ -2314,7 +2317,7 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) tcp
// connection setting here.
if !handshake {
e.segmentQueue.mu.Lock()
- for _, l := range []segmentList{e.segmentQueue.list, e.sndQueueInfo.sndQueue, e.snd.writeList} {
+ for _, l := range []segmentList{e.segmentQueue.list, e.snd.writeList} {
for s := l.Front(); s != nil; s = s.Next() {
s.id = e.TransportEndpointInfo.ID
e.sndQueueInfo.sndWaker.Assert()
@@ -2372,6 +2375,9 @@ func (e *endpoint) shutdownLocked(flags tcpip.ShutdownFlags) tcpip.Error {
e.notifyProtocolGoroutine(notifyTickleWorker)
return nil
}
+ // Wake up any readers that maybe waiting for the stream to become
+ // readable.
+ e.waiterQueue.Notify(waiter.ReadableEvents)
}
// Close for write.
@@ -2388,12 +2394,20 @@ func (e *endpoint) shutdownLocked(flags tcpip.ShutdownFlags) tcpip.Error {
// Queue fin segment.
s := newOutgoingSegment(e.TransportEndpointInfo.ID, e.stack.Clock(), nil)
- e.sndQueueInfo.sndQueue.PushBack(s)
- e.sndQueueInfo.SndBufInQueue++
+ e.snd.writeList.PushBack(s)
// Mark endpoint as closed.
e.sndQueueInfo.SndClosed = true
e.sndQueueInfo.sndQueueMu.Unlock()
- e.handleClose()
+
+ // Drain the send queue.
+ e.sendData(s)
+
+ // Mark send side as closed.
+ e.snd.Closed = true
+
+ // Wake up any writers that maybe waiting for the stream to become
+ // writable.
+ e.waiterQueue.Notify(waiter.WritableEvents)
}
return nil
@@ -2501,6 +2515,7 @@ func (e *endpoint) listen(backlog int) tcpip.Error {
// startAcceptedLoop sets up required state and starts a goroutine with the
// main loop for accepted connections.
+// +checklocksrelease:e.mu
func (e *endpoint) startAcceptedLoop() {
e.workerRunning = true
e.mu.Unlock()
diff --git a/pkg/tcpip/transport/tcp/forwarder.go b/pkg/tcpip/transport/tcp/forwarder.go
index 65c86823a..2e709ed78 100644
--- a/pkg/tcpip/transport/tcp/forwarder.go
+++ b/pkg/tcpip/transport/tcp/forwarder.go
@@ -164,8 +164,9 @@ func (r *ForwarderRequest) CreateEndpoint(queue *waiter.Queue) (tcpip.Endpoint,
return nil, err
}
- // Start the protocol goroutine.
- ep.startAcceptedLoop()
+ // Start the protocol goroutine. Note that the endpoint is returned
+ // from performHandshake locked.
+ ep.startAcceptedLoop() // +checklocksforce
return ep, nil
}
diff --git a/pkg/tcpip/transport/tcp/rcv.go b/pkg/tcpip/transport/tcp/rcv.go
index 661ca604a..9ce8fcae9 100644
--- a/pkg/tcpip/transport/tcp/rcv.go
+++ b/pkg/tcpip/transport/tcp/rcv.go
@@ -559,7 +559,6 @@ func (r *receiver) handleTimeWaitSegment(s *segment) (resetTimeWait bool, newSyn
// (2) returns to TIME-WAIT state if the SYN turns out
// to be an old duplicate".
if s.flags.Contains(header.TCPFlagSyn) && r.RcvNxt.LessThan(segSeq) {
-
return false, true
}
diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go
index e7ede7662..71c4aa85d 100644
--- a/pkg/tcpip/transport/tcp/tcp_test.go
+++ b/pkg/tcpip/transport/tcp/tcp_test.go
@@ -3451,17 +3451,13 @@ loop:
for {
switch _, err := c.EP.Read(ioutil.Discard, tcpip.ReadOptions{}); err.(type) {
case *tcpip.ErrWouldBlock:
- select {
- case <-ch:
- // Expect the state to be StateError and subsequent Reads to fail with HardError.
- _, err := c.EP.Read(ioutil.Discard, tcpip.ReadOptions{})
- if d := cmp.Diff(&tcpip.ErrConnectionReset{}, err); d != "" {
- t.Fatalf("c.EP.Read() mismatch (-want +got):\n%s", d)
- }
- break loop
- case <-time.After(1 * time.Second):
- t.Fatalf("Timed out waiting for reset to arrive")
+ <-ch
+ // Expect the state to be StateError and subsequent Reads to fail with HardError.
+ _, err := c.EP.Read(ioutil.Discard, tcpip.ReadOptions{})
+ if d := cmp.Diff(&tcpip.ErrConnectionReset{}, err); d != "" {
+ t.Fatalf("c.EP.Read() mismatch (-want +got):\n%s", d)
}
+ break loop
case *tcpip.ErrConnectionReset:
break loop
default:
@@ -3472,14 +3468,27 @@ loop:
if tcp.EndpointState(c.EP.State()) != tcp.StateError {
t.Fatalf("got EP state is not StateError")
}
- if got := c.Stack().Stats().TCP.EstablishedResets.Value(); got != 1 {
- t.Errorf("got stats.TCP.EstablishedResets.Value() = %d, want = 1", got)
+
+ checkValid := func() []error {
+ var errors []error
+ if got := c.Stack().Stats().TCP.EstablishedResets.Value(); got != 1 {
+ errors = append(errors, fmt.Errorf("got stats.TCP.EstablishedResets.Value() = %d, want = 1", got))
+ }
+ if got := c.Stack().Stats().TCP.CurrentEstablished.Value(); got != 0 {
+ errors = append(errors, fmt.Errorf("got stats.TCP.CurrentEstablished.Value() = %d, want = 0", got))
+ }
+ if got := c.Stack().Stats().TCP.CurrentConnected.Value(); got != 0 {
+ errors = append(errors, fmt.Errorf("got stats.TCP.CurrentConnected.Value() = %d, want = 0", got))
+ }
+ return errors
}
- if got := c.Stack().Stats().TCP.CurrentEstablished.Value(); got != 0 {
- t.Errorf("got stats.TCP.CurrentEstablished.Value() = %d, want = 0", got)
+
+ start := time.Now()
+ for time.Since(start) < time.Minute && len(checkValid()) > 0 {
+ time.Sleep(50 * time.Millisecond)
}
- if got := c.Stack().Stats().TCP.CurrentConnected.Value(); got != 0 {
- t.Errorf("got stats.TCP.CurrentConnected.Value() = %d, want = 0", got)
+ for _, err := range checkValid() {
+ t.Error(err)
}
}
@@ -6068,6 +6077,11 @@ func TestSynRcvdBadSeqNumber(t *testing.T) {
// complete the connection to test that the large SEQ num
// did not change the state from SYN-RCVD.
+ // Get setup to be notified about connection establishment.
+ we, ch := waiter.NewChannelEntry(nil)
+ c.WQ.EventRegister(&we, waiter.ReadableEvents)
+ defer c.WQ.EventUnregister(&we)
+
// Send ACK to move to ESTABLISHED state.
c.SendPacket(nil, &context.Headers{
SrcPort: context.TestPort,
@@ -6078,32 +6092,12 @@ func TestSynRcvdBadSeqNumber(t *testing.T) {
RcvWnd: 30000,
})
+ <-ch
newEP, _, err := c.EP.Accept(nil)
- switch err.(type) {
- case nil, *tcpip.ErrWouldBlock:
- default:
+ if err != nil {
t.Fatalf("Accept failed: %s", err)
}
- if cmp.Equal(&tcpip.ErrWouldBlock{}, err) {
- // Try to accept the connections in the backlog.
- we, ch := waiter.NewChannelEntry(nil)
- c.WQ.EventRegister(&we, waiter.ReadableEvents)
- defer c.WQ.EventUnregister(&we)
-
- // Wait for connection to be established.
- select {
- case <-ch:
- newEP, _, err = c.EP.Accept(nil)
- if err != nil {
- t.Fatalf("Accept failed: %s", err)
- }
-
- case <-time.After(1 * time.Second):
- t.Fatalf("Timed out waiting for accept")
- }
- }
-
// Now verify that the TCP socket is usable and in a connected state.
data := "Don't panic"
var r strings.Reader
@@ -6209,12 +6203,26 @@ func TestPassiveFailedConnectionAttemptIncrement(t *testing.T) {
RcvWnd: 30000,
})
- time.Sleep(50 * time.Millisecond)
- if got := stats.TCP.ListenOverflowSynDrop.Value(); got != want {
- t.Errorf("got stats.TCP.ListenOverflowSynDrop.Value() = %d, want = %d", got, want)
+ checkValid := func() []error {
+ var errors []error
+ if got := stats.TCP.ListenOverflowSynDrop.Value(); got != want {
+ errors = append(errors, fmt.Errorf("got stats.TCP.ListenOverflowSynDrop.Value() = %d, want = %d", got, want))
+ }
+ if got := c.EP.Stats().(*tcp.Stats).ReceiveErrors.ListenOverflowSynDrop.Value(); got != want {
+ errors = append(errors, fmt.Errorf("got EP stats Stats.ReceiveErrors.ListenOverflowSynDrop = %d, want = %d", got, want))
+ }
+ return errors
}
- if got := c.EP.Stats().(*tcp.Stats).ReceiveErrors.ListenOverflowSynDrop.Value(); got != want {
- t.Errorf("got EP stats Stats.ReceiveErrors.ListenOverflowSynDrop = %d, want = %d", got, want)
+
+ start := time.Now()
+ for time.Since(start) < time.Minute && len(checkValid()) > 0 {
+ time.Sleep(50 * time.Millisecond)
+ }
+ for _, err := range checkValid() {
+ t.Error(err)
+ }
+ if t.Failed() {
+ t.FailNow()
}
we, ch := waiter.NewChannelEntry(nil)
@@ -6225,19 +6233,62 @@ func TestPassiveFailedConnectionAttemptIncrement(t *testing.T) {
_, _, err = c.EP.Accept(nil)
if cmp.Equal(&tcpip.ErrWouldBlock{}, err) {
// Wait for connection to be established.
- select {
- case <-ch:
- _, _, err = c.EP.Accept(nil)
- if err != nil {
- t.Fatalf("Accept failed: %s", err)
- }
-
- case <-time.After(1 * time.Second):
- t.Fatalf("Timed out waiting for accept")
+ <-ch
+ _, _, err = c.EP.Accept(nil)
+ if err != nil {
+ t.Fatalf("Accept failed: %s", err)
}
}
}
+func TestListenDropIncrement(t *testing.T) {
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+ stats := c.Stack().Stats()
+ c.Create(-1 /*epRcvBuf*/)
+
+ if err := c.EP.Bind(tcpip.FullAddress{Addr: context.StackAddr, Port: context.StackPort}); err != nil {
+ t.Fatalf("Bind failed: %s", err)
+ }
+ if err := c.EP.Listen(1 /*backlog*/); err != nil {
+ t.Fatalf("Listen failed: %s", err)
+ }
+
+ initialDropped := stats.DroppedPackets.Value()
+
+ // Send RST, FIN segments, that are expected to be dropped by the listener.
+ c.SendPacket(nil, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: context.StackPort,
+ Flags: header.TCPFlagRst,
+ })
+ c.SendPacket(nil, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: context.StackPort,
+ Flags: header.TCPFlagFin,
+ })
+
+ // To ensure that the RST, FIN sent earlier are indeed received and ignored
+ // by the listener, send a SYN and wait for the SYN to be ACKd.
+ irs := seqnum.Value(context.TestInitialSequenceNumber)
+ c.SendPacket(nil, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: context.StackPort,
+ Flags: header.TCPFlagSyn,
+ SeqNum: irs,
+ })
+ checker.IPv4(t, c.GetPacket(), checker.TCP(checker.SrcPort(context.StackPort),
+ checker.DstPort(context.TestPort),
+ checker.TCPFlags(header.TCPFlagAck|header.TCPFlagSyn),
+ checker.TCPAckNum(uint32(irs)+1),
+ ))
+
+ if got, want := stats.DroppedPackets.Value(), initialDropped+2; got != want {
+ t.Fatalf("got stats.DroppedPackets.Value() = %d, want = %d", got, want)
+ }
+}
+
func TestEndpointBindListenAcceptState(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
@@ -7435,7 +7486,7 @@ func TestTCPUserTimeout(t *testing.T) {
select {
case <-notifyCh:
case <-time.After(2 * initRTO):
- t.Fatalf("connection still alive after %s, should have been closed after :%s", 2*initRTO, userTimeout)
+ t.Fatalf("connection still alive after %s, should have been closed after %s", 2*initRTO, userTimeout)
}
// No packet should be received as the connection should be silently
diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go
index def9d7186..82a3f2287 100644
--- a/pkg/tcpip/transport/udp/endpoint.go
+++ b/pkg/tcpip/transport/udp/endpoint.go
@@ -364,6 +364,7 @@ func (e *endpoint) Read(dst io.Writer, opts tcpip.ReadOptions) (tcpip.ReadResult
// reacquire the mutex in exclusive mode.
//
// Returns true for retry if preparation should be retried.
+// +checklocks:e.mu
func (e *endpoint) prepareForWrite(to *tcpip.FullAddress) (retry bool, err tcpip.Error) {
switch e.EndpointState() {
case StateInitial:
@@ -380,10 +381,8 @@ func (e *endpoint) prepareForWrite(to *tcpip.FullAddress) (retry bool, err tcpip
}
e.mu.RUnlock()
- defer e.mu.RLock()
-
e.mu.Lock()
- defer e.mu.Unlock()
+ defer e.mu.DowngradeLock()
// The state changed when we released the shared locked and re-acquired
// it in exclusive mode. Try again.
@@ -449,37 +448,20 @@ func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcp
return n, err
}
-func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcpip.Error) {
- if err := e.LastError(); err != nil {
- return 0, err
- }
-
- // MSG_MORE is unimplemented. (This also means that MSG_EOR is a no-op.)
- if opts.More {
- return 0, &tcpip.ErrInvalidOptionValue{}
- }
-
- to := opts.To
-
+func (e *endpoint) buildUDPPacketInfo(p tcpip.Payloader, opts tcpip.WriteOptions) (udpPacketInfo, tcpip.Error) {
e.mu.RLock()
- lockReleased := false
- defer func() {
- if lockReleased {
- return
- }
- e.mu.RUnlock()
- }()
+ defer e.mu.RUnlock()
// If we've shutdown with SHUT_WR we are in an invalid state for sending.
if e.shutdownFlags&tcpip.ShutdownWrite != 0 {
- return 0, &tcpip.ErrClosedForSend{}
+ return udpPacketInfo{}, &tcpip.ErrClosedForSend{}
}
// Prepare for write.
for {
- retry, err := e.prepareForWrite(to)
+ retry, err := e.prepareForWrite(opts.To)
if err != nil {
- return 0, err
+ return udpPacketInfo{}, err
}
if !retry {
@@ -489,34 +471,34 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcp
route := e.route
dstPort := e.dstPort
- if to != nil {
+ if opts.To != nil {
// Reject destination address if it goes through a different
// NIC than the endpoint was bound to.
- nicID := to.NIC
+ nicID := opts.To.NIC
if nicID == 0 {
nicID = tcpip.NICID(e.ops.GetBindToDevice())
}
if e.BindNICID != 0 {
if nicID != 0 && nicID != e.BindNICID {
- return 0, &tcpip.ErrNoRoute{}
+ return udpPacketInfo{}, &tcpip.ErrNoRoute{}
}
nicID = e.BindNICID
}
- if to.Port == 0 {
+ if opts.To.Port == 0 {
// Port 0 is an invalid port to send to.
- return 0, &tcpip.ErrInvalidEndpointState{}
+ return udpPacketInfo{}, &tcpip.ErrInvalidEndpointState{}
}
- dst, netProto, err := e.checkV4MappedLocked(*to)
+ dst, netProto, err := e.checkV4MappedLocked(*opts.To)
if err != nil {
- return 0, err
+ return udpPacketInfo{}, err
}
r, _, err := e.connectRoute(nicID, dst, netProto)
if err != nil {
- return 0, err
+ return udpPacketInfo{}, err
}
defer r.Release()
@@ -525,12 +507,12 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcp
}
if !e.ops.GetBroadcast() && route.IsOutboundBroadcast() {
- return 0, &tcpip.ErrBroadcastDisabled{}
+ return udpPacketInfo{}, &tcpip.ErrBroadcastDisabled{}
}
v := make([]byte, p.Len())
if _, err := io.ReadFull(p, v); err != nil {
- return 0, &tcpip.ErrBadBuffer{}
+ return udpPacketInfo{}, &tcpip.ErrBadBuffer{}
}
if len(v) > header.UDPMaximumPacketSize {
// Payload can't possibly fit in a packet.
@@ -548,24 +530,39 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcp
v,
)
}
- return 0, &tcpip.ErrMessageTooLong{}
+ return udpPacketInfo{}, &tcpip.ErrMessageTooLong{}
}
ttl := e.ttl
useDefaultTTL := ttl == 0
-
if header.IsV4MulticastAddress(route.RemoteAddress()) || header.IsV6MulticastAddress(route.RemoteAddress()) {
ttl = e.multicastTTL
// Multicast allows a 0 TTL.
useDefaultTTL = false
}
- localPort := e.ID.LocalPort
- sendTOS := e.sendTOS
- owner := e.owner
- noChecksum := e.SocketOptions().GetNoChecksum()
- lockReleased = true
- e.mu.RUnlock()
+ return udpPacketInfo{
+ route: route,
+ data: buffer.View(v),
+ localPort: e.ID.LocalPort,
+ remotePort: dstPort,
+ ttl: ttl,
+ useDefaultTTL: useDefaultTTL,
+ tos: e.sendTOS,
+ owner: e.owner,
+ noChecksum: e.SocketOptions().GetNoChecksum(),
+ }, nil
+}
+
+func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcpip.Error) {
+ if err := e.LastError(); err != nil {
+ return 0, err
+ }
+
+ // MSG_MORE is unimplemented. (This also means that MSG_EOR is a no-op.)
+ if opts.More {
+ return 0, &tcpip.ErrInvalidOptionValue{}
+ }
// Do not hold lock when sending as loopback is synchronous and if the UDP
// datagram ends up generating an ICMP response then it can result in a
@@ -577,10 +574,15 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcp
//
// See: https://golang.org/pkg/sync/#RWMutex for details on why recursive read
// locking is prohibited.
- if err := sendUDP(route, buffer.View(v).ToVectorisedView(), localPort, dstPort, ttl, useDefaultTTL, sendTOS, owner, noChecksum); err != nil {
+ u, err := e.buildUDPPacketInfo(p, opts)
+ if err != nil {
return 0, err
}
- return int64(len(v)), nil
+ n, err := u.send()
+ if err != nil {
+ return 0, err
+ }
+ return int64(n), nil
}
// OnReuseAddressSet implements tcpip.SocketOptionsHandler.
@@ -817,14 +819,30 @@ func (e *endpoint) GetSockOpt(opt tcpip.GettableSocketOption) tcpip.Error {
return nil
}
-// sendUDP sends a UDP segment via the provided network endpoint and under the
-// provided identity.
-func sendUDP(r *stack.Route, data buffer.VectorisedView, localPort, remotePort uint16, ttl uint8, useDefaultTTL bool, tos uint8, owner tcpip.PacketOwner, noChecksum bool) tcpip.Error {
+// udpPacketInfo contains all information required to send a UDP packet.
+//
+// This should be used as a value-only type, which exists in order to simplify
+// return value syntax. It should not be exported or extended.
+type udpPacketInfo struct {
+ route *stack.Route
+ data buffer.View
+ localPort uint16
+ remotePort uint16
+ ttl uint8
+ useDefaultTTL bool
+ tos uint8
+ owner tcpip.PacketOwner
+ noChecksum bool
+}
+
+// send sends the given packet.
+func (u *udpPacketInfo) send() (int, tcpip.Error) {
+ vv := u.data.ToVectorisedView()
pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
- ReserveHeaderBytes: header.UDPMinimumSize + int(r.MaxHeaderLength()),
- Data: data,
+ ReserveHeaderBytes: header.UDPMinimumSize + int(u.route.MaxHeaderLength()),
+ Data: vv,
})
- pkt.Owner = owner
+ pkt.Owner = u.owner
// Initialize the UDP header.
udp := header.UDP(pkt.TransportHeader().Push(header.UDPMinimumSize))
@@ -832,8 +850,8 @@ func sendUDP(r *stack.Route, data buffer.VectorisedView, localPort, remotePort u
length := uint16(pkt.Size())
udp.Encode(&header.UDPFields{
- SrcPort: localPort,
- DstPort: remotePort,
+ SrcPort: u.localPort,
+ DstPort: u.remotePort,
Length: length,
})
@@ -841,30 +859,30 @@ func sendUDP(r *stack.Route, data buffer.VectorisedView, localPort, remotePort u
// On IPv4, UDP checksum is optional, and a zero value indicates the
// transmitter skipped the checksum generation (RFC768).
// On IPv6, UDP checksum is not optional (RFC2460 Section 8.1).
- if r.RequiresTXTransportChecksum() &&
- (!noChecksum || r.NetProto() == header.IPv6ProtocolNumber) {
- xsum := r.PseudoHeaderChecksum(ProtocolNumber, length)
- for _, v := range data.Views() {
+ if u.route.RequiresTXTransportChecksum() &&
+ (!u.noChecksum || u.route.NetProto() == header.IPv6ProtocolNumber) {
+ xsum := u.route.PseudoHeaderChecksum(ProtocolNumber, length)
+ for _, v := range vv.Views() {
xsum = header.Checksum(v, xsum)
}
udp.SetChecksum(^udp.CalculateChecksum(xsum))
}
- if useDefaultTTL {
- ttl = r.DefaultTTL()
+ if u.useDefaultTTL {
+ u.ttl = u.route.DefaultTTL()
}
- if err := r.WritePacket(stack.NetworkHeaderParams{
+ if err := u.route.WritePacket(stack.NetworkHeaderParams{
Protocol: ProtocolNumber,
- TTL: ttl,
- TOS: tos,
+ TTL: u.ttl,
+ TOS: u.tos,
}, pkt); err != nil {
- r.Stats().UDP.PacketSendErrors.Increment()
- return err
+ u.route.Stats().UDP.PacketSendErrors.Increment()
+ return 0, err
}
// Track count of packets sent.
- r.Stats().UDP.PacketsSent.Increment()
- return nil
+ u.route.Stats().UDP.PacketsSent.Increment()
+ return len(u.data), nil
}
// checkV4MappedLocked determines the effective network protocol and converts