Diffstat (limited to 'pkg/tcpip')
136 files changed, 5958 insertions, 3923 deletions
diff --git a/pkg/tcpip/BUILD b/pkg/tcpip/BUILD index f979d22f0..e96ba50ae 100644 --- a/pkg/tcpip/BUILD +++ b/pkg/tcpip/BUILD @@ -1,4 +1,5 @@ load("//tools:defs.bzl", "go_library", "go_test") +load("//tools:deps.bzl", "deps_test") load("//tools/go_generics:defs.bzl", "go_template_instance") package(licenses = ["notice"]) @@ -21,8 +22,9 @@ go_library( "errors.go", "sock_err_list.go", "socketops.go", + "stdclock.go", + "stdclock_state.go", "tcpip.go", - "time_unsafe.go", "timer.go", ], visibility = ["//visibility:public"], @@ -33,6 +35,36 @@ go_library( ], ) +deps_test( + name = "netstack_deps_test", + allowed = [ + "@com_github_google_btree//:go_default_library", + "@org_golang_x_sys//unix:go_default_library", + "@org_golang_x_time//rate:go_default_library", + ], + allowed_prefixes = [ + "//", + "@org_golang_x_sys//internal/unsafeheader", + ], + targets = [ + "//pkg/tcpip", + "//pkg/tcpip/header", + "//pkg/tcpip/link/fdbased", + "//pkg/tcpip/link/loopback", + "//pkg/tcpip/link/packetsocket", + "//pkg/tcpip/link/qdisc/fifo", + "//pkg/tcpip/link/sniffer", + "//pkg/tcpip/network/arp", + "//pkg/tcpip/network/ipv4", + "//pkg/tcpip/network/ipv6", + "//pkg/tcpip/stack", + "//pkg/tcpip/transport/icmp", + "//pkg/tcpip/transport/raw", + "//pkg/tcpip/transport/tcp", + "//pkg/tcpip/transport/udp", + ], +) + go_test( name = "tcpip_test", size = "small", diff --git a/pkg/tcpip/checker/checker.go b/pkg/tcpip/checker/checker.go index fef065b05..12c39dfa3 100644 --- a/pkg/tcpip/checker/checker.go +++ b/pkg/tcpip/checker/checker.go @@ -53,9 +53,8 @@ func IPv4(t *testing.T, b []byte, checkers ...NetworkChecker) { t.Error("Not a valid IPv4 packet") } - xsum := ipv4.CalculateChecksum() - if xsum != 0 && xsum != 0xffff { - t.Errorf("Bad checksum: 0x%x, checksum in packet: 0x%x", xsum, ipv4.Checksum()) + if !ipv4.IsChecksumValid() { + t.Errorf("Bad checksum, got = %d", ipv4.Checksum()) } for _, f := range checkers { @@ -400,18 +399,11 @@ func TCP(checkers ...TransportChecker) NetworkChecker { t.Errorf("Bad protocol, got = %d, want = %d", p, header.TCPProtocolNumber) } - // Verify the checksum. tcp := header.TCP(last.Payload()) - l := uint16(len(tcp)) - - xsum := header.Checksum([]byte(first.SourceAddress()), 0) - xsum = header.Checksum([]byte(first.DestinationAddress()), xsum) - xsum = header.Checksum([]byte{0, byte(last.TransportProtocol())}, xsum) - xsum = header.Checksum([]byte{byte(l >> 8), byte(l)}, xsum) - xsum = header.Checksum(tcp, xsum) - - if xsum != 0 && xsum != 0xffff { - t.Errorf("Bad checksum: 0x%x, checksum in segment: 0x%x", xsum, tcp.Checksum()) + payload := tcp.Payload() + payloadChecksum := header.Checksum(payload, 0) + if !tcp.IsChecksumValid(first.SourceAddress(), first.DestinationAddress(), payloadChecksum, uint16(len(payload))) { + t.Errorf("Bad checksum, got = %d", tcp.Checksum()) } // Run the transport checkers. diff --git a/pkg/tcpip/hash/jenkins/jenkins.go b/pkg/tcpip/hash/jenkins/jenkins.go index 52c22230e..33ff22a7b 100644 --- a/pkg/tcpip/hash/jenkins/jenkins.go +++ b/pkg/tcpip/hash/jenkins/jenkins.go @@ -42,26 +42,26 @@ func (s *Sum32) Reset() { *s = 0 } // Sum32 returns the hash value func (s *Sum32) Sum32() uint32 { - hash := *s + sCopy := *s - hash += (hash << 3) - hash ^= hash >> 11 - hash += hash << 15 + sCopy += sCopy << 3 + sCopy ^= sCopy >> 11 + sCopy += sCopy << 15 - return uint32(hash) + return uint32(sCopy) } // Write adds more data to the running hash. // // It never returns an error. 
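Note: the checker now delegates checksum verification to the header package's validators instead of summing the pseudo-header by hand. A minimal sketch of the same pattern for a received IPv4+TCP packet (a hypothetical helper assuming gvisor's pkg/tcpip/header package; not part of this change):

package example

import "gvisor.dev/gvisor/pkg/tcpip/header"

// checkTCPChecksum mirrors the checker's new validation path: the IPv4
// header checksum is checked first, then the TCP checksum is verified
// against the pseudo-header plus a separately computed payload checksum.
func checkTCPChecksum(ipv4 header.IPv4) bool {
	if !ipv4.IsChecksumValid() {
		return false
	}
	tcp := header.TCP(ipv4.Payload())
	payload := tcp.Payload()
	payloadChecksum := header.Checksum(payload, 0)
	return tcp.IsChecksumValid(ipv4.SourceAddress(), ipv4.DestinationAddress(), payloadChecksum, uint16(len(payload)))
}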
func (s *Sum32) Write(data []byte) (int, error) { - hash := *s + sCopy := *s for _, b := range data { - hash += Sum32(b) - hash += hash << 10 - hash ^= hash >> 6 + sCopy += Sum32(b) + sCopy += sCopy << 10 + sCopy ^= sCopy >> 6 } - *s = hash + *s = sCopy return len(data), nil } diff --git a/pkg/tcpip/header/BUILD b/pkg/tcpip/header/BUILD index 0bdc12d53..01240f5d0 100644 --- a/pkg/tcpip/header/BUILD +++ b/pkg/tcpip/header/BUILD @@ -52,6 +52,7 @@ go_test( "//pkg/rand", "//pkg/tcpip", "//pkg/tcpip/buffer", + "//pkg/tcpip/testutil", "@com_github_google_go_cmp//cmp:go_default_library", ], ) @@ -69,6 +70,7 @@ go_test( deps = [ "//pkg/tcpip", "//pkg/tcpip/buffer", + "//pkg/tcpip/testutil", "@com_github_google_go_cmp//cmp:go_default_library", ], ) diff --git a/pkg/tcpip/header/eth_test.go b/pkg/tcpip/header/eth_test.go index 3bc8b2b21..bf9ccbf1a 100644 --- a/pkg/tcpip/header/eth_test.go +++ b/pkg/tcpip/header/eth_test.go @@ -18,6 +18,7 @@ import ( "testing" "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/testutil" ) func TestIsValidUnicastEthernetAddress(t *testing.T) { @@ -142,7 +143,7 @@ func TestEthernetAddressFromMulticastIPv4Address(t *testing.T) { } func TestEthernetAddressFromMulticastIPv6Address(t *testing.T) { - addr := tcpip.Address("\xff\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e\x0f\x1a") + addr := testutil.MustParse6("ff02:304:506:708:90a:b0c:d0e:f1a") if got, want := EthernetAddressFromMulticastIPv6Address(addr), tcpip.LinkAddress("\x33\x33\x0d\x0e\x0f\x1a"); got != want { t.Fatalf("got EthernetAddressFromMulticastIPv6Address(%s) = %s, want = %s", addr, got, want) } diff --git a/pkg/tcpip/header/igmp_test.go b/pkg/tcpip/header/igmp_test.go index b6126d29a..575604928 100644 --- a/pkg/tcpip/header/igmp_test.go +++ b/pkg/tcpip/header/igmp_test.go @@ -18,8 +18,8 @@ import ( "testing" "time" - "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/testutil" ) // TestIGMPHeader tests the functions within header.igmp @@ -46,7 +46,7 @@ func TestIGMPHeader(t *testing.T) { t.Errorf("got igmpHeader.Checksum() = %x, want = %x", got, want) } - if got, want := igmpHeader.GroupAddress(), tcpip.Address("\x01\x02\x03\x04"); got != want { + if got, want := igmpHeader.GroupAddress(), testutil.MustParse4("1.2.3.4"); got != want { t.Errorf("got igmpHeader.GroupAddress() = %s, want = %s", got, want) } @@ -71,7 +71,7 @@ func TestIGMPHeader(t *testing.T) { t.Errorf("got igmpHeader.Checksum() = %x, want = %x", got, checksum) } - groupAddress := tcpip.Address("\x04\x03\x02\x01") + groupAddress := testutil.MustParse4("4.3.2.1") igmpHeader.SetGroupAddress(groupAddress) if got := igmpHeader.GroupAddress(); got != groupAddress { t.Errorf("got igmpHeader.GroupAddress() = %s, want = %s", got, groupAddress) diff --git a/pkg/tcpip/header/ipv4.go b/pkg/tcpip/header/ipv4.go index f588311e0..2be21ec75 100644 --- a/pkg/tcpip/header/ipv4.go +++ b/pkg/tcpip/header/ipv4.go @@ -178,6 +178,26 @@ const ( IPv4FlagDontFragment ) +// ipv4LinkLocalUnicastSubnet is the IPv4 link local unicast subnet as defined +// by RFC 3927 section 1. +var ipv4LinkLocalUnicastSubnet = func() tcpip.Subnet { + subnet, err := tcpip.NewSubnet("\xa9\xfe\x00\x00", tcpip.AddressMask("\xff\xff\x00\x00")) + if err != nil { + panic(err) + } + return subnet +}() + +// ipv4LinkLocalMulticastSubnet is the IPv4 link local multicast subnet as +// defined by RFC 5771 section 4. 
+var ipv4LinkLocalMulticastSubnet = func() tcpip.Subnet { + subnet, err := tcpip.NewSubnet("\xe0\x00\x00\x00", tcpip.AddressMask("\xff\xff\xff\x00")) + if err != nil { + panic(err) + } + return subnet +}() + // IPv4EmptySubnet is the empty IPv4 subnet. var IPv4EmptySubnet = func() tcpip.Subnet { subnet, err := tcpip.NewSubnet(IPv4Any, tcpip.AddressMask(IPv4Any)) @@ -423,6 +443,44 @@ func (b IPv4) IsValid(pktSize int) bool { return true } +// IsV4LinkLocalUnicastAddress determines if the provided address is an IPv4 +// link-local unicast address. +func IsV4LinkLocalUnicastAddress(addr tcpip.Address) bool { + return ipv4LinkLocalUnicastSubnet.Contains(addr) +} + +// IsV4LinkLocalMulticastAddress determines if the provided address is an IPv4 +// link-local multicast address. +func IsV4LinkLocalMulticastAddress(addr tcpip.Address) bool { + return ipv4LinkLocalMulticastSubnet.Contains(addr) +} + +// IsChecksumValid returns true iff the IPv4 header's checksum is valid. +func (b IPv4) IsChecksumValid() bool { + // There has been some confusion regarding verifying checksums. We need + // just look for negative 0 (0xffff) as the checksum, as it's not possible to + // get positive 0 (0) for the checksum. Some bad implementations could get it + // when doing entry replacement in the early days of the Internet, + // however the lore that one needs to check for both persists. + // + // RFC 1624 section 1 describes the source of this confusion as: + // [the partial recalculation method described in RFC 1071] computes a + // result for certain cases that differs from the one obtained from + // scratch (one's complement of one's complement sum of the original + // fields). + // + // However RFC 1624 section 5 clarifies that if using the verification method + // "recommended by RFC 1071, it does not matter if an intermediate system + // generated a -0 instead of +0". + // + // RFC1071 page 1 specifies the verification method as: + // (3) To check a checksum, the 1's complement sum is computed over the + // same set of octets, including the checksum field. If the result + // is all 1 bits (-0 in 1's complement arithmetic), the check + // succeeds. + return b.CalculateChecksum() == 0xffff +} + // IsV4MulticastAddress determines if the provided address is an IPv4 multicast // address (range 224.0.0.0 to 239.255.255.255). The four most significant bits // will be 1110 = 0xe0. 
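Note: the RFC 1071 rule quoted above (sum the header including its checksum field and accept only -0, i.e. 0xffff) can be written out as a standalone sketch; this is illustrative arithmetic, not the netstack implementation:

package example

// onesComplementSum folds a byte slice into a 16-bit one's complement sum,
// as described in RFC 1071. An odd trailing byte, if any, is padded with a
// zero octet.
func onesComplementSum(b []byte) uint16 {
	var sum uint32
	for i := 0; i+1 < len(b); i += 2 {
		sum += uint32(b[i])<<8 | uint32(b[i+1])
	}
	if len(b)%2 != 0 {
		sum += uint32(b[len(b)-1]) << 8
	}
	// Fold carries back into the low 16 bits.
	for sum > 0xffff {
		sum = (sum >> 16) + (sum & 0xffff)
	}
	return uint16(sum)
}

// headerChecksumOK verifies an IPv4 header per RFC 1071: the sum over the
// header, including the checksum field itself, must be 0xffff (-0).
func headerChecksumOK(hdr []byte) bool {
	return onesComplementSum(hdr) == 0xffff
}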
diff --git a/pkg/tcpip/header/ipv4_test.go b/pkg/tcpip/header/ipv4_test.go index 6475cd694..c02fe898b 100644 --- a/pkg/tcpip/header/ipv4_test.go +++ b/pkg/tcpip/header/ipv4_test.go @@ -18,6 +18,7 @@ import ( "testing" "github.com/google/go-cmp/cmp" + "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" ) @@ -177,3 +178,77 @@ func TestIPv4EncodeOptions(t *testing.T) { }) } } + +func TestIsV4LinkLocalUnicastAddress(t *testing.T) { + tests := []struct { + name string + addr tcpip.Address + expected bool + }{ + { + name: "Valid (lowest)", + addr: "\xa9\xfe\x00\x00", + expected: true, + }, + { + name: "Valid (highest)", + addr: "\xa9\xfe\xff\xff", + expected: true, + }, + { + name: "Invalid (before subnet)", + addr: "\xa9\xfd\xff\xff", + expected: false, + }, + { + name: "Invalid (after subnet)", + addr: "\xa9\xff\x00\x00", + expected: false, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + if got := header.IsV4LinkLocalUnicastAddress(test.addr); got != test.expected { + t.Errorf("got header.IsV4LinkLocalUnicastAddress(%s) = %t, want = %t", test.addr, got, test.expected) + } + }) + } +} + +func TestIsV4LinkLocalMulticastAddress(t *testing.T) { + tests := []struct { + name string + addr tcpip.Address + expected bool + }{ + { + name: "Valid (lowest)", + addr: "\xe0\x00\x00\x00", + expected: true, + }, + { + name: "Valid (highest)", + addr: "\xe0\x00\x00\xff", + expected: true, + }, + { + name: "Invalid (before subnet)", + addr: "\xdf\xff\xff\xff", + expected: false, + }, + { + name: "Invalid (after subnet)", + addr: "\xe0\x00\x01\x00", + expected: false, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + if got := header.IsV4LinkLocalMulticastAddress(test.addr); got != test.expected { + t.Errorf("got header.IsV4LinkLocalMulticastAddress(%s) = %t, want = %t", test.addr, got, test.expected) + } + }) + } +} diff --git a/pkg/tcpip/header/ipv6.go b/pkg/tcpip/header/ipv6.go index f2403978c..c3a0407ac 100644 --- a/pkg/tcpip/header/ipv6.go +++ b/pkg/tcpip/header/ipv6.go @@ -98,12 +98,27 @@ const ( // The address is ff02::1. IPv6AllNodesMulticastAddress tcpip.Address = "\xff\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01" - // IPv6AllRoutersMulticastAddress is a link-local multicast group that - // all IPv6 routers MUST join, as per RFC 4291, section 2.8. Packets + // IPv6AllRoutersInterfaceLocalMulticastAddress is an interface-local + // multicast group that all IPv6 routers MUST join, as per RFC 4291, section + // 2.8. Packets destined to this address will reach the router on an + // interface. + // + // The address is ff01::2. + IPv6AllRoutersInterfaceLocalMulticastAddress tcpip.Address = "\xff\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02" + + // IPv6AllRoutersLinkLocalMulticastAddress is a link-local multicast group + // that all IPv6 routers MUST join, as per RFC 4291, section 2.8. Packets // destined to this address will reach all routers on a link. // // The address is ff02::2. - IPv6AllRoutersMulticastAddress tcpip.Address = "\xff\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02" + IPv6AllRoutersLinkLocalMulticastAddress tcpip.Address = "\xff\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02" + + // IPv6AllRoutersSiteLocalMulticastAddress is a site-local multicast group + // that all IPv6 routers MUST join, as per RFC 4291, section 2.8. Packets + // destined to this address will reach all routers in a site. 
+ // + // The address is ff05::2. + IPv6AllRoutersSiteLocalMulticastAddress tcpip.Address = "\xff\x05\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02" // IPv6MinimumMTU is the minimum MTU required by IPv6, per RFC 8200, // section 5: @@ -142,11 +157,6 @@ const ( // ipv6MulticastAddressScopeMask is the mask for the scope (scop) field, // within the byte holding the field, as per RFC 4291 section 2.7. ipv6MulticastAddressScopeMask = 0xF - - // ipv6LinkLocalMulticastScope is the value of the scope (scop) field within - // a multicast IPv6 address that indicates the address has link-local scope, - // as per RFC 4291 section 2.7. - ipv6LinkLocalMulticastScope = 2 ) // IPv6EmptySubnet is the empty IPv6 subnet. It may also be known as the @@ -381,25 +391,25 @@ func LinkLocalAddr(linkAddr tcpip.LinkAddress) tcpip.Address { return tcpip.Address(lladdrb[:]) } -// IsV6LinkLocalAddress determines if the provided address is an IPv6 -// link-local address (fe80::/10). -func IsV6LinkLocalAddress(addr tcpip.Address) bool { +// IsV6LinkLocalUnicastAddress returns true iff the provided address is an IPv6 +// link-local unicast address, as defined by RFC 4291 section 2.5.6. +func IsV6LinkLocalUnicastAddress(addr tcpip.Address) bool { if len(addr) != IPv6AddressSize { return false } return addr[0] == 0xfe && (addr[1]&0xc0) == 0x80 } -// IsV6LoopbackAddress determines if the provided address is an IPv6 loopback -// address. +// IsV6LoopbackAddress returns true iff the provided address is an IPv6 loopback +// address, as defined by RFC 4291 section 2.5.3. func IsV6LoopbackAddress(addr tcpip.Address) bool { return addr == IPv6Loopback } -// IsV6LinkLocalMulticastAddress determines if the provided address is an IPv6 -// link-local multicast address. +// IsV6LinkLocalMulticastAddress returns true iff the provided address is an +// IPv6 link-local multicast address, as defined by RFC 4291 section 2.7. func IsV6LinkLocalMulticastAddress(addr tcpip.Address) bool { - return IsV6MulticastAddress(addr) && addr[ipv6MulticastAddressScopeByteIdx]&ipv6MulticastAddressScopeMask == ipv6LinkLocalMulticastScope + return IsV6MulticastAddress(addr) && V6MulticastScope(addr) == IPv6LinkLocalMulticastScope } // AppendOpaqueInterfaceIdentifier appends a 64 bit opaque interface identifier @@ -462,7 +472,7 @@ func ScopeForIPv6Address(addr tcpip.Address) (IPv6AddressScope, tcpip.Error) { case IsV6LinkLocalMulticastAddress(addr): return LinkLocalScope, nil - case IsV6LinkLocalAddress(addr): + case IsV6LinkLocalUnicastAddress(addr): return LinkLocalScope, nil default: @@ -520,3 +530,46 @@ func GenerateTempIPv6SLAACAddr(tempIIDHistory []byte, stableAddr tcpip.Address) PrefixLen: IIDOffsetInIPv6Address * 8, } } + +// IPv6MulticastScope is the scope of a multicast IPv6 address, as defined by +// RFC 7346 section 2. 
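Note: the scope constants and V6MulticastScope defined below read the scop field out of the multicast address, which is the low nibble of the second byte of an ff00::/8 address (RFC 4291 section 2.7). An illustrative standalone sketch of the extraction:

package example

// multicastScope returns the 4-bit scope (scop) field of an IPv6 multicast
// address: the low nibble of byte 1, per RFC 4291 section 2.7. The netstack
// equivalent is header.V6MulticastScope.
func multicastScope(addr [16]byte) uint8 {
	return addr[1] & 0xf
}

// For example, ff02::1 (all-nodes) yields 2 (link-local scope) and ff05::2
// (all-routers, site-local) yields 5.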
+type IPv6MulticastScope uint8 + +// The various values for IPv6 multicast scopes, as per RFC 7346 section 2: +// +// +------+--------------------------+-------------------------+ +// | scop | NAME | REFERENCE | +// +------+--------------------------+-------------------------+ +// | 0 | Reserved | [RFC4291], RFC 7346 | +// | 1 | Interface-Local scope | [RFC4291], RFC 7346 | +// | 2 | Link-Local scope | [RFC4291], RFC 7346 | +// | 3 | Realm-Local scope | [RFC4291], RFC 7346 | +// | 4 | Admin-Local scope | [RFC4291], RFC 7346 | +// | 5 | Site-Local scope | [RFC4291], RFC 7346 | +// | 6 | Unassigned | | +// | 7 | Unassigned | | +// | 8 | Organization-Local scope | [RFC4291], RFC 7346 | +// | 9 | Unassigned | | +// | A | Unassigned | | +// | B | Unassigned | | +// | C | Unassigned | | +// | D | Unassigned | | +// | E | Global scope | [RFC4291], RFC 7346 | +// | F | Reserved | [RFC4291], RFC 7346 | +// +------+--------------------------+-------------------------+ +const ( + IPv6Reserved0MulticastScope = IPv6MulticastScope(0x0) + IPv6InterfaceLocalMulticastScope = IPv6MulticastScope(0x1) + IPv6LinkLocalMulticastScope = IPv6MulticastScope(0x2) + IPv6RealmLocalMulticastScope = IPv6MulticastScope(0x3) + IPv6AdminLocalMulticastScope = IPv6MulticastScope(0x4) + IPv6SiteLocalMulticastScope = IPv6MulticastScope(0x5) + IPv6OrganizationLocalMulticastScope = IPv6MulticastScope(0x8) + IPv6GlobalMulticastScope = IPv6MulticastScope(0xE) + IPv6ReservedFMulticastScope = IPv6MulticastScope(0xF) +) + +// V6MulticastScope returns the scope of a multicast address. +func V6MulticastScope(addr tcpip.Address) IPv6MulticastScope { + return IPv6MulticastScope(addr[ipv6MulticastAddressScopeByteIdx] & ipv6MulticastAddressScopeMask) +} diff --git a/pkg/tcpip/header/ipv6_test.go b/pkg/tcpip/header/ipv6_test.go index f10f446a6..89be84068 100644 --- a/pkg/tcpip/header/ipv6_test.go +++ b/pkg/tcpip/header/ipv6_test.go @@ -24,15 +24,17 @@ import ( "gvisor.dev/gvisor/pkg/rand" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/testutil" ) -const ( - linkAddr = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x06") - linkLocalAddr = tcpip.Address("\xfe\x80\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01") - linkLocalMulticastAddr = tcpip.Address("\xff\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01") - uniqueLocalAddr1 = tcpip.Address("\xfc\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01") - uniqueLocalAddr2 = tcpip.Address("\xfd\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02") - globalAddr = tcpip.Address("\xa0\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01") +const linkAddr = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x06") + +var ( + linkLocalAddr = testutil.MustParse6("fe80::1") + linkLocalMulticastAddr = testutil.MustParse6("ff02::1") + uniqueLocalAddr1 = testutil.MustParse6("fc00::1") + uniqueLocalAddr2 = testutil.MustParse6("fd00::2") + globalAddr = testutil.MustParse6("a000::1") ) func TestEthernetAdddressToModifiedEUI64(t *testing.T) { @@ -50,7 +52,7 @@ func TestEthernetAdddressToModifiedEUI64(t *testing.T) { } func TestLinkLocalAddr(t *testing.T) { - if got, want := header.LinkLocalAddr(linkAddr), tcpip.Address("\xfe\x80\x00\x00\x00\x00\x00\x00\x00\x02\x03\xff\xfe\x04\x05\x06"); got != want { + if got, want := header.LinkLocalAddr(linkAddr), testutil.MustParse6("fe80::2:3ff:fe04:506"); got != want { t.Errorf("got LinkLocalAddr(%s) = %s, want = %s", linkAddr, got, want) } } @@ -252,7 +254,7 @@ func 
TestIsV6LinkLocalMulticastAddress(t *testing.T) { } } -func TestIsV6LinkLocalAddress(t *testing.T) { +func TestIsV6LinkLocalUnicastAddress(t *testing.T) { tests := []struct { name string addr tcpip.Address @@ -287,8 +289,8 @@ func TestIsV6LinkLocalAddress(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { - if got := header.IsV6LinkLocalAddress(test.addr); got != test.expected { - t.Errorf("got header.IsV6LinkLocalAddress(%s) = %t, want = %t", test.addr, got, test.expected) + if got := header.IsV6LinkLocalUnicastAddress(test.addr); got != test.expected { + t.Errorf("got header.IsV6LinkLocalUnicastAddress(%s) = %t, want = %t", test.addr, got, test.expected) } }) } @@ -373,3 +375,83 @@ func TestSolicitedNodeAddr(t *testing.T) { }) } } + +func TestV6MulticastScope(t *testing.T) { + tests := []struct { + addr tcpip.Address + want header.IPv6MulticastScope + }{ + { + addr: "\xff\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01", + want: header.IPv6Reserved0MulticastScope, + }, + { + addr: "\xff\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01", + want: header.IPv6InterfaceLocalMulticastScope, + }, + { + addr: "\xff\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01", + want: header.IPv6LinkLocalMulticastScope, + }, + { + addr: "\xff\x03\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01", + want: header.IPv6RealmLocalMulticastScope, + }, + { + addr: "\xff\x04\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01", + want: header.IPv6AdminLocalMulticastScope, + }, + { + addr: "\xff\x05\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01", + want: header.IPv6SiteLocalMulticastScope, + }, + { + addr: "\xff\x06\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01", + want: header.IPv6MulticastScope(6), + }, + { + addr: "\xff\x07\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01", + want: header.IPv6MulticastScope(7), + }, + { + addr: "\xff\x08\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01", + want: header.IPv6OrganizationLocalMulticastScope, + }, + { + addr: "\xff\x09\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01", + want: header.IPv6MulticastScope(9), + }, + { + addr: "\xff\x0a\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01", + want: header.IPv6MulticastScope(10), + }, + { + addr: "\xff\x0b\x03\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01", + want: header.IPv6MulticastScope(11), + }, + { + addr: "\xff\x0c\x04\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01", + want: header.IPv6MulticastScope(12), + }, + { + addr: "\xff\x0d\x05\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01", + want: header.IPv6MulticastScope(13), + }, + { + addr: "\xff\x0e\x06\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01", + want: header.IPv6GlobalMulticastScope, + }, + { + addr: "\xff\x0f\x06\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01", + want: header.IPv6ReservedFMulticastScope, + }, + } + + for _, test := range tests { + t.Run(fmt.Sprintf("%s", test.addr), func(t *testing.T) { + if got := header.V6MulticastScope(test.addr); got != test.want { + t.Fatalf("got header.V6MulticastScope(%s) = %d, want = %d", test.addr, got, test.want) + } + }) + } +} diff --git a/pkg/tcpip/header/ndp_test.go b/pkg/tcpip/header/ndp_test.go index d0a1a2492..1b5093e58 100644 --- a/pkg/tcpip/header/ndp_test.go +++ b/pkg/tcpip/header/ndp_test.go @@ -26,6 +26,7 @@ import ( "github.com/google/go-cmp/cmp" "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/testutil" ) // TestNDPNeighborSolicit tests the 
functions of NDPNeighborSolicit. @@ -40,13 +41,13 @@ func TestNDPNeighborSolicit(t *testing.T) { // Test getting the Target Address. ns := NDPNeighborSolicit(b) - addr := tcpip.Address("\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e\x0f\x10") + addr := testutil.MustParse6("102:304:506:708:90a:b0c:d0e:f10") if got := ns.TargetAddress(); got != addr { t.Errorf("got ns.TargetAddress = %s, want %s", got, addr) } // Test updating the Target Address. - addr2 := tcpip.Address("\x11\x12\x13\x14\x15\x16\x17\x18\x19\x1a\x1b\x1c\x1d\x1e\x1f\x11") + addr2 := testutil.MustParse6("1112:1314:1516:1718:191a:1b1c:1d1e:1f11") ns.SetTargetAddress(addr2) if got := ns.TargetAddress(); got != addr2 { t.Errorf("got ns.TargetAddress = %s, want %s", got, addr2) @@ -69,7 +70,7 @@ func TestNDPNeighborAdvert(t *testing.T) { // Test getting the Target Address. na := NDPNeighborAdvert(b) - addr := tcpip.Address("\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e\x0f\x10") + addr := testutil.MustParse6("102:304:506:708:90a:b0c:d0e:f10") if got := na.TargetAddress(); got != addr { t.Errorf("got TargetAddress = %s, want %s", got, addr) } @@ -90,7 +91,7 @@ func TestNDPNeighborAdvert(t *testing.T) { } // Test updating the Target Address. - addr2 := tcpip.Address("\x11\x12\x13\x14\x15\x16\x17\x18\x19\x1a\x1b\x1c\x1d\x1e\x1f\x11") + addr2 := testutil.MustParse6("1112:1314:1516:1718:191a:1b1c:1d1e:1f11") na.SetTargetAddress(addr2) if got := na.TargetAddress(); got != addr2 { t.Errorf("got TargetAddress = %s, want %s", got, addr2) @@ -277,7 +278,7 @@ func TestOpts(t *testing.T) { } const validLifetimeSeconds = 16909060 - const address = tcpip.Address("\x09\x0a\x0b\x0c\x0d\x0e\x0f\x10\x11\x12\x13\x14\x15\x16\x17\x18") + address := testutil.MustParse6("90a:b0c:d0e:f10:1112:1314:1516:1718") expectedRDNSSBytes := [...]byte{ // Type, Length diff --git a/pkg/tcpip/header/tcp.go b/pkg/tcpip/header/tcp.go index adc835d30..0df517000 100644 --- a/pkg/tcpip/header/tcp.go +++ b/pkg/tcpip/header/tcp.go @@ -216,104 +216,104 @@ const ( TCPDefaultMSS = 536 ) -// SourcePort returns the "source port" field of the tcp header. +// SourcePort returns the "source port" field of the TCP header. func (b TCP) SourcePort() uint16 { return binary.BigEndian.Uint16(b[TCPSrcPortOffset:]) } -// DestinationPort returns the "destination port" field of the tcp header. +// DestinationPort returns the "destination port" field of the TCP header. func (b TCP) DestinationPort() uint16 { return binary.BigEndian.Uint16(b[TCPDstPortOffset:]) } -// SequenceNumber returns the "sequence number" field of the tcp header. +// SequenceNumber returns the "sequence number" field of the TCP header. func (b TCP) SequenceNumber() uint32 { return binary.BigEndian.Uint32(b[TCPSeqNumOffset:]) } -// AckNumber returns the "ack number" field of the tcp header. +// AckNumber returns the "ack number" field of the TCP header. func (b TCP) AckNumber() uint32 { return binary.BigEndian.Uint32(b[TCPAckNumOffset:]) } -// DataOffset returns the "data offset" field of the tcp header. The return +// DataOffset returns the "data offset" field of the TCP header. The return // value is the length of the TCP header in bytes. func (b TCP) DataOffset() uint8 { return (b[TCPDataOffset] >> 4) * 4 } -// Payload returns the data in the tcp packet. +// Payload returns the data in the TCP packet. func (b TCP) Payload() []byte { return b[b.DataOffset():] } -// Flags returns the flags field of the tcp header. +// Flags returns the flags field of the TCP header. 
func (b TCP) Flags() TCPFlags { return TCPFlags(b[TCPFlagsOffset]) } -// WindowSize returns the "window size" field of the tcp header. +// WindowSize returns the "window size" field of the TCP header. func (b TCP) WindowSize() uint16 { return binary.BigEndian.Uint16(b[TCPWinSizeOffset:]) } -// Checksum returns the "checksum" field of the tcp header. +// Checksum returns the "checksum" field of the TCP header. func (b TCP) Checksum() uint16 { return binary.BigEndian.Uint16(b[TCPChecksumOffset:]) } -// UrgentPointer returns the "urgent pointer" field of the tcp header. +// UrgentPointer returns the "urgent pointer" field of the TCP header. func (b TCP) UrgentPointer() uint16 { return binary.BigEndian.Uint16(b[TCPUrgentPtrOffset:]) } -// SetSourcePort sets the "source port" field of the tcp header. +// SetSourcePort sets the "source port" field of the TCP header. func (b TCP) SetSourcePort(port uint16) { binary.BigEndian.PutUint16(b[TCPSrcPortOffset:], port) } -// SetDestinationPort sets the "destination port" field of the tcp header. +// SetDestinationPort sets the "destination port" field of the TCP header. func (b TCP) SetDestinationPort(port uint16) { binary.BigEndian.PutUint16(b[TCPDstPortOffset:], port) } -// SetChecksum sets the checksum field of the tcp header. +// SetChecksum sets the checksum field of the TCP header. func (b TCP) SetChecksum(checksum uint16) { binary.BigEndian.PutUint16(b[TCPChecksumOffset:], checksum) } -// SetDataOffset sets the data offset field of the tcp header. headerLen should +// SetDataOffset sets the data offset field of the TCP header. headerLen should // be the length of the TCP header in bytes. func (b TCP) SetDataOffset(headerLen uint8) { b[TCPDataOffset] = (headerLen / 4) << 4 } -// SetSequenceNumber sets the sequence number field of the tcp header. +// SetSequenceNumber sets the sequence number field of the TCP header. func (b TCP) SetSequenceNumber(seqNum uint32) { binary.BigEndian.PutUint32(b[TCPSeqNumOffset:], seqNum) } -// SetAckNumber sets the ack number field of the tcp header. +// SetAckNumber sets the ack number field of the TCP header. func (b TCP) SetAckNumber(ackNum uint32) { binary.BigEndian.PutUint32(b[TCPAckNumOffset:], ackNum) } -// SetFlags sets the flags field of the tcp header. +// SetFlags sets the flags field of the TCP header. func (b TCP) SetFlags(flags uint8) { b[TCPFlagsOffset] = flags } -// SetWindowSize sets the window size field of the tcp header. +// SetWindowSize sets the window size field of the TCP header. func (b TCP) SetWindowSize(rcvwnd uint16) { binary.BigEndian.PutUint16(b[TCPWinSizeOffset:], rcvwnd) } -// SetUrgentPoiner sets the window size field of the tcp header. +// SetUrgentPoiner sets the window size field of the TCP header. func (b TCP) SetUrgentPoiner(urgentPointer uint16) { binary.BigEndian.PutUint16(b[TCPUrgentPtrOffset:], urgentPointer) } -// CalculateChecksum calculates the checksum of the tcp segment. +// CalculateChecksum calculates the checksum of the TCP segment. // partialChecksum is the checksum of the network-layer pseudo-header // and the checksum of the segment data. func (b TCP) CalculateChecksum(partialChecksum uint16) uint16 { @@ -321,6 +321,13 @@ func (b TCP) CalculateChecksum(partialChecksum uint16) uint16 { return Checksum(b[:b.DataOffset()], partialChecksum) } +// IsChecksumValid returns true iff the TCP header's checksum is valid. 
+func (b TCP) IsChecksumValid(src, dst tcpip.Address, payloadChecksum, payloadLength uint16) bool { + xsum := PseudoHeaderChecksum(TCPProtocolNumber, src, dst, uint16(b.DataOffset())+payloadLength) + xsum = ChecksumCombine(xsum, payloadChecksum) + return b.CalculateChecksum(xsum) == 0xffff +} + // Options returns a slice that holds the unparsed TCP options in the segment. func (b TCP) Options() []byte { return b[TCPMinimumSize:b.DataOffset()] @@ -340,7 +347,7 @@ func (b TCP) encodeSubset(seq, ack uint32, flags TCPFlags, rcvwnd uint16) { binary.BigEndian.PutUint16(b[TCPWinSizeOffset:], rcvwnd) } -// Encode encodes all the fields of the tcp header. +// Encode encodes all the fields of the TCP header. func (b TCP) Encode(t *TCPFields) { b.encodeSubset(t.SeqNum, t.AckNum, t.Flags, t.WindowSize) binary.BigEndian.PutUint16(b[TCPSrcPortOffset:], t.SrcPort) @@ -350,7 +357,7 @@ func (b TCP) Encode(t *TCPFields) { binary.BigEndian.PutUint16(b[TCPUrgentPtrOffset:], t.UrgentPointer) } -// EncodePartial updates a subset of the fields of the tcp header. It is useful +// EncodePartial updates a subset of the fields of the TCP header. It is useful // in cases when similar segments are produced. func (b TCP) EncodePartial(partialChecksum, length uint16, seqnum, acknum uint32, flags TCPFlags, rcvwnd uint16) { // Add the total length and "flags" field contributions to the checksum. @@ -374,7 +381,7 @@ func (b TCP) EncodePartial(partialChecksum, length uint16, seqnum, acknum uint32 } // ParseSynOptions parses the options received in a SYN segment and returns the -// relevant ones. opts should point to the option part of the TCP Header. +// relevant ones. opts should point to the option part of the TCP header. func ParseSynOptions(opts []byte, isAck bool) TCPSynOptions { limit := len(opts) diff --git a/pkg/tcpip/header/udp.go b/pkg/tcpip/header/udp.go index 98bdd29db..ae9d167ff 100644 --- a/pkg/tcpip/header/udp.go +++ b/pkg/tcpip/header/udp.go @@ -64,17 +64,17 @@ const ( UDPProtocolNumber tcpip.TransportProtocolNumber = 17 ) -// SourcePort returns the "source port" field of the udp header. +// SourcePort returns the "source port" field of the UDP header. func (b UDP) SourcePort() uint16 { return binary.BigEndian.Uint16(b[udpSrcPort:]) } -// DestinationPort returns the "destination port" field of the udp header. +// DestinationPort returns the "destination port" field of the UDP header. func (b UDP) DestinationPort() uint16 { return binary.BigEndian.Uint16(b[udpDstPort:]) } -// Length returns the "length" field of the udp header. +// Length returns the "length" field of the UDP header. func (b UDP) Length() uint16 { return binary.BigEndian.Uint16(b[udpLength:]) } @@ -84,39 +84,46 @@ func (b UDP) Payload() []byte { return b[UDPMinimumSize:] } -// Checksum returns the "checksum" field of the udp header. +// Checksum returns the "checksum" field of the UDP header. func (b UDP) Checksum() uint16 { return binary.BigEndian.Uint16(b[udpChecksum:]) } -// SetSourcePort sets the "source port" field of the udp header. +// SetSourcePort sets the "source port" field of the UDP header. func (b UDP) SetSourcePort(port uint16) { binary.BigEndian.PutUint16(b[udpSrcPort:], port) } -// SetDestinationPort sets the "destination port" field of the udp header. +// SetDestinationPort sets the "destination port" field of the UDP header. func (b UDP) SetDestinationPort(port uint16) { binary.BigEndian.PutUint16(b[udpDstPort:], port) } -// SetChecksum sets the "checksum" field of the udp header. 
+// SetChecksum sets the "checksum" field of the UDP header. func (b UDP) SetChecksum(checksum uint16) { binary.BigEndian.PutUint16(b[udpChecksum:], checksum) } -// SetLength sets the "length" field of the udp header. +// SetLength sets the "length" field of the UDP header. func (b UDP) SetLength(length uint16) { binary.BigEndian.PutUint16(b[udpLength:], length) } -// CalculateChecksum calculates the checksum of the udp packet, given the +// CalculateChecksum calculates the checksum of the UDP packet, given the // checksum of the network-layer pseudo-header and the checksum of the payload. func (b UDP) CalculateChecksum(partialChecksum uint16) uint16 { // Calculate the rest of the checksum. return Checksum(b[:UDPMinimumSize], partialChecksum) } -// Encode encodes all the fields of the udp header. +// IsChecksumValid returns true iff the UDP header's checksum is valid. +func (b UDP) IsChecksumValid(src, dst tcpip.Address, payloadChecksum uint16) bool { + xsum := PseudoHeaderChecksum(UDPProtocolNumber, dst, src, b.Length()) + xsum = ChecksumCombine(xsum, payloadChecksum) + return b.CalculateChecksum(xsum) == 0xffff +} + +// Encode encodes all the fields of the UDP header. func (b UDP) Encode(u *UDPFields) { binary.BigEndian.PutUint16(b[udpSrcPort:], u.SrcPort) binary.BigEndian.PutUint16(b[udpDstPort:], u.DstPort) diff --git a/pkg/tcpip/link/channel/channel.go b/pkg/tcpip/link/channel/channel.go index cd76272de..ef9126deb 100644 --- a/pkg/tcpip/link/channel/channel.go +++ b/pkg/tcpip/link/channel/channel.go @@ -30,7 +30,6 @@ import ( type PacketInfo struct { Pkt *stack.PacketBuffer Proto tcpip.NetworkProtocolNumber - GSO *stack.GSO Route stack.RouteInfo } @@ -124,6 +123,9 @@ func (q *queue) RemoveNotify(handle *NotificationHandle) { q.notify = notify } +var _ stack.LinkEndpoint = (*Endpoint)(nil) +var _ stack.GSOEndpoint = (*Endpoint)(nil) + // Endpoint is link layer endpoint that stores outbound packets in a channel // and allows injection of inbound packets. type Endpoint struct { @@ -131,6 +133,7 @@ type Endpoint struct { mtu uint32 linkAddr tcpip.LinkAddress LinkEPCapabilities stack.LinkEndpointCapabilities + SupportedGSOKind stack.SupportedGSO // Outbound packet queue. q *queue @@ -212,11 +215,16 @@ func (e *Endpoint) Capabilities() stack.LinkEndpointCapabilities { return e.LinkEPCapabilities } -// GSOMaxSize returns the maximum GSO packet size. +// GSOMaxSize implements stack.GSOEndpoint. func (*Endpoint) GSOMaxSize() uint32 { return 1 << 15 } +// SupportedGSO implements stack.GSOEndpoint. +func (e *Endpoint) SupportedGSO() stack.SupportedGSO { + return e.SupportedGSOKind +} + // MaxHeaderLength returns the maximum size of the link layer header. Given it // doesn't have a header, it just returns 0. func (*Endpoint) MaxHeaderLength() uint16 { @@ -229,11 +237,10 @@ func (e *Endpoint) LinkAddress() tcpip.LinkAddress { } // WritePacket stores outbound packets into the channel. -func (e *Endpoint) WritePacket(r stack.RouteInfo, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error { +func (e *Endpoint) WritePacket(r stack.RouteInfo, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error { p := PacketInfo{ Pkt: pkt, Proto: protocol, - GSO: gso, Route: r, } @@ -243,13 +250,12 @@ func (e *Endpoint) WritePacket(r stack.RouteInfo, gso *stack.GSO, protocol tcpip } // WritePackets stores outbound packets into the channel. 
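Note: UDP gets the same style of validator as TCP. A hedged sketch of how a receiver might use it for a UDP datagram carried in an IPv4 packet (the helper name is illustrative; the zero-checksum special case applies to UDP over IPv4 only):

package example

import "gvisor.dev/gvisor/pkg/tcpip/header"

// udpChecksumOK verifies a UDP datagram using the new
// header.UDP.IsChecksumValid helper.
func udpChecksumOK(ipv4 header.IPv4) bool {
	udp := header.UDP(ipv4.Payload())
	// A zero checksum means the sender did not compute one, which is
	// permitted for UDP over IPv4.
	if udp.Checksum() == 0 {
		return true
	}
	payloadChecksum := header.Checksum(udp.Payload(), 0)
	return udp.IsChecksumValid(ipv4.SourceAddress(), ipv4.DestinationAddress(), payloadChecksum)
}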
-func (e *Endpoint) WritePackets(r stack.RouteInfo, gso *stack.GSO, pkts stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, tcpip.Error) { +func (e *Endpoint) WritePackets(r stack.RouteInfo, pkts stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, tcpip.Error) { n := 0 for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() { p := PacketInfo{ Pkt: pkt, Proto: protocol, - GSO: gso, Route: r, } diff --git a/pkg/tcpip/link/ethernet/ethernet.go b/pkg/tcpip/link/ethernet/ethernet.go index d873766a6..b427c6170 100644 --- a/pkg/tcpip/link/ethernet/ethernet.go +++ b/pkg/tcpip/link/ethernet/ethernet.go @@ -61,20 +61,20 @@ func (e *Endpoint) Capabilities() stack.LinkEndpointCapabilities { } // WritePacket implements stack.LinkEndpoint. -func (e *Endpoint) WritePacket(r stack.RouteInfo, gso *stack.GSO, proto tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error { +func (e *Endpoint) WritePacket(r stack.RouteInfo, proto tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error { e.AddHeader(e.Endpoint.LinkAddress(), r.RemoteLinkAddress, proto, pkt) - return e.Endpoint.WritePacket(r, gso, proto, pkt) + return e.Endpoint.WritePacket(r, proto, pkt) } // WritePackets implements stack.LinkEndpoint. -func (e *Endpoint) WritePackets(r stack.RouteInfo, gso *stack.GSO, pkts stack.PacketBufferList, proto tcpip.NetworkProtocolNumber) (int, tcpip.Error) { +func (e *Endpoint) WritePackets(r stack.RouteInfo, pkts stack.PacketBufferList, proto tcpip.NetworkProtocolNumber) (int, tcpip.Error) { linkAddr := e.Endpoint.LinkAddress() for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() { e.AddHeader(linkAddr, r.RemoteLinkAddress, proto, pkt) } - return e.Endpoint.WritePackets(r, gso, pkts, proto) + return e.Endpoint.WritePackets(r, pkts, proto) } // MaxHeaderLength implements stack.LinkEndpoint. diff --git a/pkg/tcpip/link/fdbased/BUILD b/pkg/tcpip/link/fdbased/BUILD index f042df82e..d971194e6 100644 --- a/pkg/tcpip/link/fdbased/BUILD +++ b/pkg/tcpip/link/fdbased/BUILD @@ -14,7 +14,6 @@ go_library( ], visibility = ["//visibility:public"], deps = [ - "//pkg/binary", "//pkg/iovec", "//pkg/sync", "//pkg/tcpip", diff --git a/pkg/tcpip/link/fdbased/endpoint.go b/pkg/tcpip/link/fdbased/endpoint.go index 6be945116..bddb1d0a2 100644 --- a/pkg/tcpip/link/fdbased/endpoint.go +++ b/pkg/tcpip/link/fdbased/endpoint.go @@ -45,7 +45,6 @@ import ( "sync/atomic" "golang.org/x/sys/unix" - "gvisor.dev/gvisor/pkg/binary" "gvisor.dev/gvisor/pkg/iovec" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" @@ -98,6 +97,9 @@ func (p PacketDispatchMode) String() string { } } +var _ stack.LinkEndpoint = (*endpoint)(nil) +var _ stack.GSOEndpoint = (*endpoint)(nil) + type endpoint struct { // fds is the set of file descriptors each identifying one inbound/outbound // channel. The endpoint will dispatch from all inbound channels as well as @@ -134,6 +136,9 @@ type endpoint struct { // wg keeps track of running goroutines. wg sync.WaitGroup + + // gsoKind is the supported kind of GSO. + gsoKind stack.SupportedGSO } // Options specify the details about the fd-based endpoint to be created. 
@@ -255,9 +260,9 @@ func New(opts *Options) (stack.LinkEndpoint, error) { if isSocket { if opts.GSOMaxSize != 0 { if opts.SoftwareGSOEnabled { - e.caps |= stack.CapabilitySoftwareGSO + e.gsoKind = stack.SWGSOSupported } else { - e.caps |= stack.CapabilityHardwareGSO + e.gsoKind = stack.HWGSOSupported } e.gsoMaxSize = opts.GSOMaxSize } @@ -403,6 +408,35 @@ type virtioNetHdr struct { csumOffset uint16 } +// marshal serializes h to a newly-allocated byte slice, in little-endian byte +// order. +// +// Note: Virtio v1.0 onwards specifies little-endian as the byte ordering used +// for general serialization. This makes it difficult to use go-marshal for +// virtio types, as go-marshal implicitly uses the native byte ordering. +func (h *virtioNetHdr) marshal() []byte { + buf := [virtioNetHdrSize]byte{ + 0: byte(h.flags), + 1: byte(h.gsoType), + + // Manually lay out the fields in little-endian byte order. Little endian => + // least significant bit goes to the lower address. + + 2: byte(h.hdrLen), + 3: byte(h.hdrLen >> 8), + + 4: byte(h.gsoSize), + 5: byte(h.gsoSize >> 8), + + 6: byte(h.csumStart), + 7: byte(h.csumStart >> 8), + + 8: byte(h.csumOffset), + 9: byte(h.csumOffset >> 8), + } + return buf[:] +} + // These constants are declared in linux/virtio_net.h. const ( _VIRTIO_NET_HDR_F_NEEDS_CSUM = 1 @@ -433,7 +467,7 @@ func (e *endpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.Net // WritePacket writes outbound packets to the file descriptor. If it is not // currently writable, the packet is dropped. -func (e *endpoint) WritePacket(r stack.RouteInfo, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error { +func (e *endpoint) WritePacket(r stack.RouteInfo, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error { if e.hdrSize > 0 { e.AddHeader(r.LocalLinkAddress, r.RemoteLinkAddress, protocol, pkt) } @@ -441,29 +475,29 @@ func (e *endpoint) WritePacket(r stack.RouteInfo, gso *stack.GSO, protocol tcpip var builder iovec.Builder fd := e.fds[pkt.Hash%uint32(len(e.fds))] - if e.Capabilities()&stack.CapabilityHardwareGSO != 0 { + if e.gsoKind == stack.HWGSOSupported { vnetHdr := virtioNetHdr{} - if gso != nil { + if pkt.GSOOptions.Type != stack.GSONone { vnetHdr.hdrLen = uint16(pkt.HeaderSize()) - if gso.NeedsCsum { + if pkt.GSOOptions.NeedsCsum { vnetHdr.flags = _VIRTIO_NET_HDR_F_NEEDS_CSUM - vnetHdr.csumStart = header.EthernetMinimumSize + gso.L3HdrLen - vnetHdr.csumOffset = gso.CsumOffset + vnetHdr.csumStart = header.EthernetMinimumSize + pkt.GSOOptions.L3HdrLen + vnetHdr.csumOffset = pkt.GSOOptions.CsumOffset } - if gso.Type != stack.GSONone && uint16(pkt.Data().Size()) > gso.MSS { - switch gso.Type { + if pkt.GSOOptions.Type != stack.GSONone && uint16(pkt.Data().Size()) > pkt.GSOOptions.MSS { + switch pkt.GSOOptions.Type { case stack.GSOTCPv4: vnetHdr.gsoType = _VIRTIO_NET_HDR_GSO_TCPV4 case stack.GSOTCPv6: vnetHdr.gsoType = _VIRTIO_NET_HDR_GSO_TCPV6 default: - panic(fmt.Sprintf("Unknown gso type: %v", gso.Type)) + panic(fmt.Sprintf("Unknown gso type: %v", pkt.GSOOptions.Type)) } - vnetHdr.gsoSize = gso.MSS + vnetHdr.gsoSize = pkt.GSOOptions.MSS } } - vnetHdrBuf := binary.Marshal(make([]byte, 0, virtioNetHdrSize), binary.LittleEndian, vnetHdr) + vnetHdrBuf := vnetHdr.marshal() builder.Add(vnetHdrBuf) } @@ -482,9 +516,9 @@ func (e *endpoint) sendBatch(batchFD int, batch []*stack.PacketBuffer) (int, tcp } var vnetHdrBuf []byte - if e.Capabilities()&stack.CapabilityHardwareGSO != 0 { + if e.gsoKind == stack.HWGSOSupported { 
vnetHdr := virtioNetHdr{} - if pkt.GSOOptions != nil { + if pkt.GSOOptions.Type != stack.GSONone { vnetHdr.hdrLen = uint16(pkt.HeaderSize()) if pkt.GSOOptions.NeedsCsum { vnetHdr.flags = _VIRTIO_NET_HDR_F_NEEDS_CSUM @@ -503,7 +537,7 @@ func (e *endpoint) sendBatch(batchFD int, batch []*stack.PacketBuffer) (int, tcp vnetHdr.gsoSize = pkt.GSOOptions.MSS } } - vnetHdrBuf = binary.Marshal(make([]byte, 0, virtioNetHdrSize), binary.LittleEndian, vnetHdr) + vnetHdrBuf = vnetHdr.marshal() } var builder iovec.Builder @@ -540,7 +574,7 @@ func (e *endpoint) sendBatch(batchFD int, batch []*stack.PacketBuffer) (int, tcp // - pkt.EgressRoute // - pkt.GSOOptions // - pkt.NetworkProtocolNumber -func (e *endpoint) WritePackets(_ stack.RouteInfo, _ *stack.GSO, pkts stack.PacketBufferList, _ tcpip.NetworkProtocolNumber) (int, tcpip.Error) { +func (e *endpoint) WritePackets(_ stack.RouteInfo, pkts stack.PacketBufferList, _ tcpip.NetworkProtocolNumber) (int, tcpip.Error) { // Preallocate to avoid repeated reallocation as we append to batch. // batchSz is 47 because when SWGSO is in use then a single 65KB TCP // segment can get split into 46 segments of 1420 bytes and a single 216 @@ -602,11 +636,16 @@ func (e *endpoint) dispatchLoop(inboundDispatcher linkDispatcher) tcpip.Error { } } -// GSOMaxSize returns the maximum GSO packet size. +// GSOMaxSize implements stack.GSOEndpoint. func (e *endpoint) GSOMaxSize() uint32 { return e.gsoMaxSize } +// SupportsHWGSO implements stack.GSOEndpoint. +func (e *endpoint) SupportedGSO() stack.SupportedGSO { + return e.gsoKind +} + // ARPHardwareType implements stack.LinkEndpoint.ARPHardwareType. func (e *endpoint) ARPHardwareType() header.ARPHardwareType { if e.hdrSize > 0 { diff --git a/pkg/tcpip/link/fdbased/endpoint_test.go b/pkg/tcpip/link/fdbased/endpoint_test.go index 1e40f3fef..8aad338b6 100644 --- a/pkg/tcpip/link/fdbased/endpoint_test.go +++ b/pkg/tcpip/link/fdbased/endpoint_test.go @@ -207,18 +207,17 @@ func testWritePacket(t *testing.T, plen int, eth bool, gsoMaxSize uint32, hash u // Write. want := append(append(buffer.View(nil), b...), payload...) 
- var gso *stack.GSO + const l3HdrLen = header.IPv6MinimumSize if gsoMaxSize != 0 { - gso = &stack.GSO{ + pkt.GSOOptions = stack.GSO{ Type: stack.GSOTCPv6, NeedsCsum: true, CsumOffset: csumOffset, MSS: gsoMSS, - MaxSize: gsoMaxSize, - L3HdrLen: header.IPv4MaximumHeaderSize, + L3HdrLen: l3HdrLen, } } - if err := c.ep.WritePacket(r, gso, proto, pkt); err != nil { + if err := c.ep.WritePacket(r, proto, pkt); err != nil { t.Fatalf("WritePacket failed: %v", err) } @@ -235,7 +234,7 @@ func testWritePacket(t *testing.T, plen int, eth bool, gsoMaxSize uint32, hash u if vnetHdr.flags&_VIRTIO_NET_HDR_F_NEEDS_CSUM == 0 { t.Fatalf("virtioNetHdr.flags %v doesn't contain %v", vnetHdr.flags, _VIRTIO_NET_HDR_F_NEEDS_CSUM) } - csumStart := header.EthernetMinimumSize + gso.L3HdrLen + const csumStart = header.EthernetMinimumSize + l3HdrLen if vnetHdr.csumStart != csumStart { t.Fatalf("vnetHdr.csumStart = %v, want %v", vnetHdr.csumStart, csumStart) } @@ -243,7 +242,7 @@ func testWritePacket(t *testing.T, plen int, eth bool, gsoMaxSize uint32, hash u t.Fatalf("vnetHdr.csumOffset = %v, want %v", vnetHdr.csumOffset, csumOffset) } gsoType := uint8(0) - if int(gso.MSS) < plen { + if plen > gsoMSS { gsoType = _VIRTIO_NET_HDR_GSO_TCPV6 } if vnetHdr.gsoType != gsoType { @@ -333,7 +332,7 @@ func TestPreserveSrcAddress(t *testing.T) { ReserveHeaderBytes: header.EthernetMinimumSize, Data: buffer.VectorisedView{}, }) - if err := c.ep.WritePacket(r, nil /* gso */, proto, pkt); err != nil { + if err := c.ep.WritePacket(r, proto, pkt); err != nil { t.Fatalf("WritePacket failed: %v", err) } diff --git a/pkg/tcpip/link/fdbased/packet_dispatchers.go b/pkg/tcpip/link/fdbased/packet_dispatchers.go index a7adf822b..4b7ef3aac 100644 --- a/pkg/tcpip/link/fdbased/packet_dispatchers.go +++ b/pkg/tcpip/link/fdbased/packet_dispatchers.go @@ -128,7 +128,7 @@ type readVDispatcher struct { func newReadVDispatcher(fd int, e *endpoint) (linkDispatcher, error) { d := &readVDispatcher{fd: fd, e: e} - skipsVnetHdr := d.e.Capabilities()&stack.CapabilityHardwareGSO != 0 + skipsVnetHdr := d.e.gsoKind == stack.HWGSOSupported d.buf = newIovecBuffer(BufConfig, skipsVnetHdr) return d, nil } @@ -212,7 +212,7 @@ func newRecvMMsgDispatcher(fd int, e *endpoint) (linkDispatcher, error) { bufs: make([]*iovecBuffer, MaxMsgsPerRecv), msgHdrs: make([]rawfile.MMsgHdr, MaxMsgsPerRecv), } - skipsVnetHdr := d.e.Capabilities()&stack.CapabilityHardwareGSO != 0 + skipsVnetHdr := d.e.gsoKind == stack.HWGSOSupported for i := range d.bufs { d.bufs[i] = newIovecBuffer(BufConfig, skipsVnetHdr) } diff --git a/pkg/tcpip/link/loopback/loopback.go b/pkg/tcpip/link/loopback/loopback.go index 691467870..7012d8829 100644 --- a/pkg/tcpip/link/loopback/loopback.go +++ b/pkg/tcpip/link/loopback/loopback.go @@ -76,7 +76,7 @@ func (*endpoint) Wait() {} // WritePacket implements stack.LinkEndpoint.WritePacket. It delivers outbound // packets to the network-layer dispatcher. -func (e *endpoint) WritePacket(_ stack.RouteInfo, _ *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error { +func (e *endpoint) WritePacket(_ stack.RouteInfo, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error { // Construct data as the unparsed portion for the loopback packet. data := buffer.NewVectorisedView(pkt.Size(), pkt.Views()) @@ -92,7 +92,7 @@ func (e *endpoint) WritePacket(_ stack.RouteInfo, _ *stack.GSO, protocol tcpip.N } // WritePackets implements stack.LinkEndpoint.WritePackets. 
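Note: as the updated test shows, GSO parameters now travel on the PacketBuffer itself and WritePacket no longer takes a *stack.GSO argument. A minimal sketch of the new call pattern (the MSS value and helper name are placeholders, not part of the change):

package example

import (
	"gvisor.dev/gvisor/pkg/tcpip"
	"gvisor.dev/gvisor/pkg/tcpip/buffer"
	"gvisor.dev/gvisor/pkg/tcpip/header"
	"gvisor.dev/gvisor/pkg/tcpip/stack"
)

// writeWithGSO attaches GSO options to the packet buffer before handing it
// to the link endpoint; the endpoint reads pkt.GSOOptions instead of a
// separate gso argument.
func writeWithGSO(ep stack.LinkEndpoint, r stack.RouteInfo, payload buffer.View) tcpip.Error {
	pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
		ReserveHeaderBytes: int(ep.MaxHeaderLength()),
		Data:               payload.ToVectorisedView(),
	})
	pkt.GSOOptions = stack.GSO{
		Type:       stack.GSOTCPv6,
		NeedsCsum:  true,
		CsumOffset: header.TCPChecksumOffset, // offset of the TCP checksum field
		MSS:        1440,                     // placeholder MSS
		L3HdrLen:   header.IPv6MinimumSize,   // assumed network-header length
	}
	return ep.WritePacket(r, header.IPv6ProtocolNumber, pkt)
}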
-func (e *endpoint) WritePackets(stack.RouteInfo, *stack.GSO, stack.PacketBufferList, tcpip.NetworkProtocolNumber) (int, tcpip.Error) { +func (e *endpoint) WritePackets(stack.RouteInfo, stack.PacketBufferList, tcpip.NetworkProtocolNumber) (int, tcpip.Error) { panic("not implemented") } diff --git a/pkg/tcpip/link/muxed/injectable.go b/pkg/tcpip/link/muxed/injectable.go index 668f72eee..3e2a1aa94 100644 --- a/pkg/tcpip/link/muxed/injectable.go +++ b/pkg/tcpip/link/muxed/injectable.go @@ -87,20 +87,20 @@ func (m *InjectableEndpoint) InjectInbound(protocol tcpip.NetworkProtocolNumber, // WritePackets writes outbound packets to the appropriate // LinkInjectableEndpoint based on the RemoteAddress. HandleLocal only works if // r.RemoteAddress has a route registered in this endpoint. -func (m *InjectableEndpoint) WritePackets(r stack.RouteInfo, gso *stack.GSO, pkts stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, tcpip.Error) { +func (m *InjectableEndpoint) WritePackets(r stack.RouteInfo, pkts stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, tcpip.Error) { endpoint, ok := m.routes[r.RemoteAddress] if !ok { return 0, &tcpip.ErrNoRoute{} } - return endpoint.WritePackets(r, gso, pkts, protocol) + return endpoint.WritePackets(r, pkts, protocol) } // WritePacket writes outbound packets to the appropriate LinkInjectableEndpoint // based on the RemoteAddress. HandleLocal only works if r.RemoteAddress has a // route registered in this endpoint. -func (m *InjectableEndpoint) WritePacket(r stack.RouteInfo, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error { +func (m *InjectableEndpoint) WritePacket(r stack.RouteInfo, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error { if endpoint, ok := m.routes[r.RemoteAddress]; ok { - return endpoint.WritePacket(r, gso, protocol, pkt) + return endpoint.WritePacket(r, protocol, pkt) } return &tcpip.ErrNoRoute{} } diff --git a/pkg/tcpip/link/muxed/injectable_test.go b/pkg/tcpip/link/muxed/injectable_test.go index 5806f7fdf..040e3a35b 100644 --- a/pkg/tcpip/link/muxed/injectable_test.go +++ b/pkg/tcpip/link/muxed/injectable_test.go @@ -54,7 +54,7 @@ func TestInjectableEndpointDispatch(t *testing.T) { var packetRoute stack.RouteInfo packetRoute.RemoteAddress = dstIP - endpoint.WritePacket(packetRoute, nil /* gso */, ipv4.ProtocolNumber, pkt) + endpoint.WritePacket(packetRoute, ipv4.ProtocolNumber, pkt) buf := make([]byte, 6500) bytesRead, err := sock.Read(buf) @@ -76,7 +76,7 @@ func TestInjectableEndpointDispatchHdrOnly(t *testing.T) { pkt.TransportHeader().Push(1)[0] = 0xFA var packetRoute stack.RouteInfo packetRoute.RemoteAddress = dstIP - endpoint.WritePacket(packetRoute, nil /* gso */, ipv4.ProtocolNumber, pkt) + endpoint.WritePacket(packetRoute, ipv4.ProtocolNumber, pkt) buf := make([]byte, 6500) bytesRead, err := sock.Read(buf) if err != nil { diff --git a/pkg/tcpip/link/nested/nested.go b/pkg/tcpip/link/nested/nested.go index 97ad9fdd5..3e816b0c7 100644 --- a/pkg/tcpip/link/nested/nested.go +++ b/pkg/tcpip/link/nested/nested.go @@ -113,13 +113,13 @@ func (e *Endpoint) LinkAddress() tcpip.LinkAddress { } // WritePacket implements stack.LinkEndpoint. 
-func (e *Endpoint) WritePacket(r stack.RouteInfo, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error { - return e.child.WritePacket(r, gso, protocol, pkt) +func (e *Endpoint) WritePacket(r stack.RouteInfo, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error { + return e.child.WritePacket(r, protocol, pkt) } // WritePackets implements stack.LinkEndpoint. -func (e *Endpoint) WritePackets(r stack.RouteInfo, gso *stack.GSO, pkts stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, tcpip.Error) { - return e.child.WritePackets(r, gso, pkts, protocol) +func (e *Endpoint) WritePackets(r stack.RouteInfo, pkts stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, tcpip.Error) { + return e.child.WritePackets(r, pkts, protocol) } // Wait implements stack.LinkEndpoint. @@ -135,6 +135,14 @@ func (e *Endpoint) GSOMaxSize() uint32 { return 0 } +// SupportedGSO implements stack.GSOEndpoint. +func (e *Endpoint) SupportedGSO() stack.SupportedGSO { + if e, ok := e.child.(stack.GSOEndpoint); ok { + return e.SupportedGSO() + } + return stack.GSONotSupported +} + // ARPHardwareType implements stack.LinkEndpoint.ARPHardwareType func (e *Endpoint) ARPHardwareType() header.ARPHardwareType { return e.child.ARPHardwareType() diff --git a/pkg/tcpip/link/packetsocket/endpoint.go b/pkg/tcpip/link/packetsocket/endpoint.go index 6cbe18a56..e01837e2d 100644 --- a/pkg/tcpip/link/packetsocket/endpoint.go +++ b/pkg/tcpip/link/packetsocket/endpoint.go @@ -35,16 +35,16 @@ func New(lower stack.LinkEndpoint) stack.LinkEndpoint { } // WritePacket implements stack.LinkEndpoint.WritePacket. -func (e *endpoint) WritePacket(r stack.RouteInfo, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error { +func (e *endpoint) WritePacket(r stack.RouteInfo, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error { e.Endpoint.DeliverOutboundPacket(r.RemoteLinkAddress, r.LocalLinkAddress, protocol, pkt) - return e.Endpoint.WritePacket(r, gso, protocol, pkt) + return e.Endpoint.WritePacket(r, protocol, pkt) } // WritePackets implements stack.LinkEndpoint.WritePackets. -func (e *endpoint) WritePackets(r stack.RouteInfo, gso *stack.GSO, pkts stack.PacketBufferList, proto tcpip.NetworkProtocolNumber) (int, tcpip.Error) { +func (e *endpoint) WritePackets(r stack.RouteInfo, pkts stack.PacketBufferList, proto tcpip.NetworkProtocolNumber) (int, tcpip.Error) { for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() { e.Endpoint.DeliverOutboundPacket(r.RemoteLinkAddress, r.LocalLinkAddress, pkt.NetworkProtocolNumber, pkt) } - return e.Endpoint.WritePackets(r, gso, pkts, proto) + return e.Endpoint.WritePackets(r, pkts, proto) } diff --git a/pkg/tcpip/link/pipe/pipe.go b/pkg/tcpip/link/pipe/pipe.go index 21fb87757..5030b6ba1 100644 --- a/pkg/tcpip/link/pipe/pipe.go +++ b/pkg/tcpip/link/pipe/pipe.go @@ -66,7 +66,7 @@ func (e *Endpoint) deliverPackets(r stack.RouteInfo, proto tcpip.NetworkProtocol } // WritePacket implements stack.LinkEndpoint. 
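Note: wrapper endpoints (nested, qdisc/fifo) report GSO support by querying the endpoint they wrap, falling back to GSONotSupported when the wrapped endpoint does not implement stack.GSOEndpoint. A sketch of the same type-assertion pattern for a hypothetical wrapper (the remaining stack.LinkEndpoint methods are omitted):

package example

import "gvisor.dev/gvisor/pkg/tcpip/stack"

// wrapper is a hypothetical link endpoint wrapper; only the GSO-related
// methods are shown.
type wrapper struct {
	lower stack.LinkEndpoint
}

// SupportedGSO forwards to the wrapped endpoint when it implements
// stack.GSOEndpoint.
func (w *wrapper) SupportedGSO() stack.SupportedGSO {
	if gso, ok := w.lower.(stack.GSOEndpoint); ok {
		return gso.SupportedGSO()
	}
	return stack.GSONotSupported
}

// GSOMaxSize forwards to the wrapped endpoint, or reports 0 when GSO is not
// supported.
func (w *wrapper) GSOMaxSize() uint32 {
	if gso, ok := w.lower.(stack.GSOEndpoint); ok {
		return gso.GSOMaxSize()
	}
	return 0
}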
-func (e *Endpoint) WritePacket(r stack.RouteInfo, _ *stack.GSO, proto tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error { +func (e *Endpoint) WritePacket(r stack.RouteInfo, proto tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error { var pkts stack.PacketBufferList pkts.PushBack(pkt) e.deliverPackets(r, proto, pkts) @@ -74,7 +74,7 @@ func (e *Endpoint) WritePacket(r stack.RouteInfo, _ *stack.GSO, proto tcpip.Netw } // WritePackets implements stack.LinkEndpoint. -func (e *Endpoint) WritePackets(r stack.RouteInfo, _ *stack.GSO, pkts stack.PacketBufferList, proto tcpip.NetworkProtocolNumber) (int, tcpip.Error) { +func (e *Endpoint) WritePackets(r stack.RouteInfo, pkts stack.PacketBufferList, proto tcpip.NetworkProtocolNumber) (int, tcpip.Error) { n := pkts.Len() e.deliverPackets(r, proto, pkts) return n, nil diff --git a/pkg/tcpip/link/qdisc/fifo/endpoint.go b/pkg/tcpip/link/qdisc/fifo/endpoint.go index 128ef6e87..b1a28491d 100644 --- a/pkg/tcpip/link/qdisc/fifo/endpoint.go +++ b/pkg/tcpip/link/qdisc/fifo/endpoint.go @@ -25,6 +25,9 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/stack" ) +var _ stack.LinkEndpoint = (*endpoint)(nil) +var _ stack.GSOEndpoint = (*endpoint)(nil) + // endpoint represents a LinkEndpoint which implements a FIFO queue for all // outgoing packets. endpoint can have 1 or more underlying queueDispatchers. // All outgoing packets are consistenly hashed to a single underlying queue @@ -91,7 +94,7 @@ func (q *queueDispatcher) dispatchLoop() { } // We pass a protocol of zero here because each packet carries its // NetworkProtocol. - q.lower.WritePackets(stack.RouteInfo{}, nil /* gso */, batch, 0 /* protocol */) + q.lower.WritePackets(stack.RouteInfo{}, batch, 0 /* protocol */) for pkt := batch.Front(); pkt != nil; pkt = pkt.Next() { batch.Remove(pkt) } @@ -141,7 +144,7 @@ func (e *endpoint) LinkAddress() tcpip.LinkAddress { return e.lower.LinkAddress() } -// GSOMaxSize returns the maximum GSO packet size. +// GSOMaxSize implements stack.GSOEndpoint. func (e *endpoint) GSOMaxSize() uint32 { if gso, ok := e.lower.(stack.GSOEndpoint); ok { return gso.GSOMaxSize() @@ -149,13 +152,21 @@ func (e *endpoint) GSOMaxSize() uint32 { return 0 } +// SupportedGSO implements stack.GSOEndpoint. +func (e *endpoint) SupportedGSO() stack.SupportedGSO { + if gso, ok := e.lower.(stack.GSOEndpoint); ok { + return gso.SupportedGSO() + } + return stack.GSONotSupported +} + // WritePacket implements stack.LinkEndpoint.WritePacket. -func (e *endpoint) WritePacket(r stack.RouteInfo, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error { - // WritePacket caller's do not set the following fields in PacketBuffer - // so we populate them here. - pkt.EgressRoute = r - pkt.GSOOptions = gso - pkt.NetworkProtocolNumber = protocol +// +// The packet must have the following fields populated: +// - pkt.EgressRoute +// - pkt.GSOOptions +// - pkt.NetworkProtocolNumber +func (e *endpoint) WritePacket(r stack.RouteInfo, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error { d := e.dispatchers[int(pkt.Hash)%len(e.dispatchers)] if !d.q.enqueue(pkt) { return &tcpip.ErrNoBufferSpace{} @@ -166,12 +177,12 @@ func (e *endpoint) WritePacket(r stack.RouteInfo, gso *stack.GSO, protocol tcpip // WritePackets implements stack.LinkEndpoint.WritePackets. 
// -// Being a batch API, each packet in pkts should have the following -// fields populated: +// Each packet in the packet buffer list must have the following fields +// populated: // - pkt.EgressRoute // - pkt.GSOOptions // - pkt.NetworkProtocolNumber -func (e *endpoint) WritePackets(r stack.RouteInfo, gso *stack.GSO, pkts stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, tcpip.Error) { +func (e *endpoint) WritePackets(r stack.RouteInfo, pkts stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, tcpip.Error) { enqueued := 0 for pkt := pkts.Front(); pkt != nil; { d := e.dispatchers[int(pkt.Hash)%len(e.dispatchers)] diff --git a/pkg/tcpip/link/sharedmem/sharedmem.go b/pkg/tcpip/link/sharedmem/sharedmem.go index d8d0b16b2..df9a0b90a 100644 --- a/pkg/tcpip/link/sharedmem/sharedmem.go +++ b/pkg/tcpip/link/sharedmem/sharedmem.go @@ -203,7 +203,7 @@ func (e *endpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.Net // WritePacket writes outbound packets to the file descriptor. If it is not // currently writable, the packet is dropped. -func (e *endpoint) WritePacket(r stack.RouteInfo, _ *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error { +func (e *endpoint) WritePacket(r stack.RouteInfo, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error { e.AddHeader(r.LocalLinkAddress, r.RemoteLinkAddress, protocol, pkt) views := pkt.Views() @@ -220,7 +220,7 @@ func (e *endpoint) WritePacket(r stack.RouteInfo, _ *stack.GSO, protocol tcpip.N } // WritePackets implements stack.LinkEndpoint.WritePackets. -func (*endpoint) WritePackets(stack.RouteInfo, *stack.GSO, stack.PacketBufferList, tcpip.NetworkProtocolNumber) (int, tcpip.Error) { +func (*endpoint) WritePackets(stack.RouteInfo, stack.PacketBufferList, tcpip.NetworkProtocolNumber) (int, tcpip.Error) { panic("not implemented") } diff --git a/pkg/tcpip/link/sharedmem/sharedmem_test.go b/pkg/tcpip/link/sharedmem/sharedmem_test.go index d4b3ddd5c..0f72d4e95 100644 --- a/pkg/tcpip/link/sharedmem/sharedmem_test.go +++ b/pkg/tcpip/link/sharedmem/sharedmem_test.go @@ -281,7 +281,7 @@ func TestSimpleSend(t *testing.T) { copy(pkt.NetworkHeader().Push(hdrLen), hdrBuf) proto := tcpip.NetworkProtocolNumber(rand.Intn(0x10000)) - if err := c.ep.WritePacket(r, nil /* gso */, proto, pkt); err != nil { + if err := c.ep.WritePacket(r, proto, pkt); err != nil { t.Fatalf("WritePacket failed: %v", err) } @@ -351,7 +351,7 @@ func TestPreserveSrcAddressInSend(t *testing.T) { }) proto := tcpip.NetworkProtocolNumber(rand.Intn(0x10000)) - if err := c.ep.WritePacket(r, nil /* gso */, proto, pkt); err != nil { + if err := c.ep.WritePacket(r, proto, pkt); err != nil { t.Fatalf("WritePacket failed: %v", err) } @@ -407,7 +407,7 @@ func TestFillTxQueue(t *testing.T) { Data: buf.ToVectorisedView(), }) - if err := c.ep.WritePacket(r, nil /* gso */, header.IPv4ProtocolNumber, pkt); err != nil { + if err := c.ep.WritePacket(r, header.IPv4ProtocolNumber, pkt); err != nil { t.Fatalf("WritePacket failed unexpectedly: %v", err) } @@ -425,7 +425,7 @@ func TestFillTxQueue(t *testing.T) { ReserveHeaderBytes: int(c.ep.MaxHeaderLength()), Data: buf.ToVectorisedView(), }) - err := c.ep.WritePacket(r, nil /* gso */, header.IPv4ProtocolNumber, pkt) + err := c.ep.WritePacket(r, header.IPv4ProtocolNumber, pkt) if _, ok := err.(*tcpip.ErrWouldBlock); !ok { t.Fatalf("got WritePacket(...) 
= %v, want %s", err, &tcpip.ErrWouldBlock{}) } @@ -453,7 +453,7 @@ func TestFillTxQueueAfterBadCompletion(t *testing.T) { ReserveHeaderBytes: int(c.ep.MaxHeaderLength()), Data: buf.ToVectorisedView(), }) - if err := c.ep.WritePacket(r, nil /* gso */, header.IPv4ProtocolNumber, pkt); err != nil { + if err := c.ep.WritePacket(r, header.IPv4ProtocolNumber, pkt); err != nil { t.Fatalf("WritePacket failed unexpectedly: %v", err) } } @@ -476,7 +476,7 @@ func TestFillTxQueueAfterBadCompletion(t *testing.T) { ReserveHeaderBytes: int(c.ep.MaxHeaderLength()), Data: buf.ToVectorisedView(), }) - if err := c.ep.WritePacket(r, nil /* gso */, header.IPv4ProtocolNumber, pkt); err != nil { + if err := c.ep.WritePacket(r, header.IPv4ProtocolNumber, pkt); err != nil { t.Fatalf("WritePacket failed unexpectedly: %v", err) } @@ -494,7 +494,7 @@ func TestFillTxQueueAfterBadCompletion(t *testing.T) { ReserveHeaderBytes: int(c.ep.MaxHeaderLength()), Data: buf.ToVectorisedView(), }) - err := c.ep.WritePacket(r, nil /* gso */, header.IPv4ProtocolNumber, pkt) + err := c.ep.WritePacket(r, header.IPv4ProtocolNumber, pkt) if _, ok := err.(*tcpip.ErrWouldBlock); !ok { t.Fatalf("got WritePacket(...) = %v, want %s", err, &tcpip.ErrWouldBlock{}) } @@ -520,7 +520,7 @@ func TestFillTxMemory(t *testing.T) { ReserveHeaderBytes: int(c.ep.MaxHeaderLength()), Data: buf.ToVectorisedView(), }) - if err := c.ep.WritePacket(r, nil /* gso */, header.IPv4ProtocolNumber, pkt); err != nil { + if err := c.ep.WritePacket(r, header.IPv4ProtocolNumber, pkt); err != nil { t.Fatalf("WritePacket failed unexpectedly: %v", err) } @@ -539,7 +539,7 @@ func TestFillTxMemory(t *testing.T) { ReserveHeaderBytes: int(c.ep.MaxHeaderLength()), Data: buf.ToVectorisedView(), }) - err := c.ep.WritePacket(r, nil /* gso */, header.IPv4ProtocolNumber, pkt) + err := c.ep.WritePacket(r, header.IPv4ProtocolNumber, pkt) if _, ok := err.(*tcpip.ErrWouldBlock); !ok { t.Fatalf("got WritePacket(...) = %v, want %s", err, &tcpip.ErrWouldBlock{}) } @@ -566,7 +566,7 @@ func TestFillTxMemoryWithMultiBuffer(t *testing.T) { ReserveHeaderBytes: int(c.ep.MaxHeaderLength()), Data: buf.ToVectorisedView(), }) - if err := c.ep.WritePacket(r, nil /* gso */, header.IPv4ProtocolNumber, pkt); err != nil { + if err := c.ep.WritePacket(r, header.IPv4ProtocolNumber, pkt); err != nil { t.Fatalf("WritePacket failed unexpectedly: %v", err) } @@ -581,7 +581,7 @@ func TestFillTxMemoryWithMultiBuffer(t *testing.T) { ReserveHeaderBytes: int(c.ep.MaxHeaderLength()), Data: buffer.NewView(bufferSize).ToVectorisedView(), }) - err := c.ep.WritePacket(r, nil /* gso */, header.IPv4ProtocolNumber, pkt) + err := c.ep.WritePacket(r, header.IPv4ProtocolNumber, pkt) if _, ok := err.(*tcpip.ErrWouldBlock); !ok { t.Fatalf("got WritePacket(...) 
= %v, want %s", err, &tcpip.ErrWouldBlock{}) } @@ -593,7 +593,7 @@ func TestFillTxMemoryWithMultiBuffer(t *testing.T) { ReserveHeaderBytes: int(c.ep.MaxHeaderLength()), Data: buf.ToVectorisedView(), }) - if err := c.ep.WritePacket(r, nil /* gso */, header.IPv4ProtocolNumber, pkt); err != nil { + if err := c.ep.WritePacket(r, header.IPv4ProtocolNumber, pkt); err != nil { t.Fatalf("WritePacket failed unexpectedly: %v", err) } } diff --git a/pkg/tcpip/link/sniffer/sniffer.go b/pkg/tcpip/link/sniffer/sniffer.go index 7aaee3d13..2d6a3a833 100644 --- a/pkg/tcpip/link/sniffer/sniffer.go +++ b/pkg/tcpip/link/sniffer/sniffer.go @@ -139,7 +139,7 @@ func NewWithWriter(lower stack.LinkEndpoint, writer io.Writer, snapLen uint32) ( // called by the link-layer endpoint being wrapped when a packet arrives, and // logs the packet before forwarding to the actual dispatcher. func (e *endpoint) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { - e.dumpPacket(directionRecv, nil, protocol, pkt) + e.dumpPacket(directionRecv, protocol, pkt) e.Endpoint.DeliverNetworkPacket(remote, local, protocol, pkt) } @@ -148,10 +148,10 @@ func (e *endpoint) DeliverOutboundPacket(remote, local tcpip.LinkAddress, protoc e.Endpoint.DeliverOutboundPacket(remote, local, protocol, pkt) } -func (e *endpoint) dumpPacket(dir direction, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { +func (e *endpoint) dumpPacket(dir direction, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { writer := e.writer if writer == nil && atomic.LoadUint32(&LogPackets) == 1 { - logPacket(e.logPrefix, dir, protocol, pkt, gso) + logPacket(e.logPrefix, dir, protocol, pkt) } if writer != nil && atomic.LoadUint32(&LogPacketsToPCAP) == 1 { totalLength := pkt.Size() @@ -187,22 +187,22 @@ func (e *endpoint) dumpPacket(dir direction, gso *stack.GSO, protocol tcpip.Netw // WritePacket implements the stack.LinkEndpoint interface. It is called by // higher-level protocols to write packets; it just logs the packet and // forwards the request to the lower endpoint. -func (e *endpoint) WritePacket(r stack.RouteInfo, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error { - e.dumpPacket(directionSend, gso, protocol, pkt) - return e.Endpoint.WritePacket(r, gso, protocol, pkt) +func (e *endpoint) WritePacket(r stack.RouteInfo, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error { + e.dumpPacket(directionSend, protocol, pkt) + return e.Endpoint.WritePacket(r, protocol, pkt) } // WritePackets implements the stack.LinkEndpoint interface. It is called by // higher-level protocols to write packets; it just logs the packet and // forwards the request to the lower endpoint. 
-func (e *endpoint) WritePackets(r stack.RouteInfo, gso *stack.GSO, pkts stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, tcpip.Error) { +func (e *endpoint) WritePackets(r stack.RouteInfo, pkts stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, tcpip.Error) { for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() { - e.dumpPacket(directionSend, gso, protocol, pkt) + e.dumpPacket(directionSend, protocol, pkt) } - return e.Endpoint.WritePackets(r, gso, pkts, protocol) + return e.Endpoint.WritePackets(r, pkts, protocol) } -func logPacket(prefix string, dir direction, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer, gso *stack.GSO) { +func logPacket(prefix string, dir direction, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { // Figure out the network layer info. var transProto uint8 src := tcpip.Address("unknown") @@ -411,8 +411,8 @@ func logPacket(prefix string, dir direction, protocol tcpip.NetworkProtocolNumbe return } - if gso != nil { - details += fmt.Sprintf(" gso: %+v", gso) + if pkt.GSOOptions.Type != stack.GSONone { + details += fmt.Sprintf(" gso: %#v", pkt.GSOOptions) } log.Infof("%s%s %s %s:%d -> %s:%d len:%d id:%04x %s", prefix, directionPrefix, transName, src, srcPort, dst, dstPort, size, id, details) diff --git a/pkg/tcpip/link/waitable/waitable.go b/pkg/tcpip/link/waitable/waitable.go index ce5113746..a95602aa5 100644 --- a/pkg/tcpip/link/waitable/waitable.go +++ b/pkg/tcpip/link/waitable/waitable.go @@ -108,12 +108,12 @@ func (e *Endpoint) LinkAddress() tcpip.LinkAddress { // WritePacket implements stack.LinkEndpoint.WritePacket. It is called by // higher-level protocols to write packets. It only forwards packets to the // lower endpoint if Wait or WaitWrite haven't been called. -func (e *Endpoint) WritePacket(r stack.RouteInfo, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error { +func (e *Endpoint) WritePacket(r stack.RouteInfo, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error { if !e.writeGate.Enter() { return nil } - err := e.lower.WritePacket(r, gso, protocol, pkt) + err := e.lower.WritePacket(r, protocol, pkt) e.writeGate.Leave() return err } @@ -121,12 +121,12 @@ func (e *Endpoint) WritePacket(r stack.RouteInfo, gso *stack.GSO, protocol tcpip // WritePackets implements stack.LinkEndpoint.WritePackets. It is called by // higher-level protocols to write packets. It only forwards packets to the // lower endpoint if Wait or WaitWrite haven't been called. 
-func (e *Endpoint) WritePackets(r stack.RouteInfo, gso *stack.GSO, pkts stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, tcpip.Error) { +func (e *Endpoint) WritePackets(r stack.RouteInfo, pkts stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, tcpip.Error) { if !e.writeGate.Enter() { return pkts.Len(), nil } - n, err := e.lower.WritePackets(r, gso, pkts, protocol) + n, err := e.lower.WritePackets(r, pkts, protocol) e.writeGate.Leave() return n, err } diff --git a/pkg/tcpip/link/waitable/waitable_test.go b/pkg/tcpip/link/waitable/waitable_test.go index e368a9eaa..a71400ee9 100644 --- a/pkg/tcpip/link/waitable/waitable_test.go +++ b/pkg/tcpip/link/waitable/waitable_test.go @@ -69,13 +69,13 @@ func (e *countedEndpoint) LinkAddress() tcpip.LinkAddress { return e.linkAddr } -func (e *countedEndpoint) WritePacket(stack.RouteInfo, *stack.GSO, tcpip.NetworkProtocolNumber, *stack.PacketBuffer) tcpip.Error { +func (e *countedEndpoint) WritePacket(stack.RouteInfo, tcpip.NetworkProtocolNumber, *stack.PacketBuffer) tcpip.Error { e.writeCount++ return nil } // WritePackets implements stack.LinkEndpoint.WritePackets. -func (e *countedEndpoint) WritePackets(_ stack.RouteInfo, _ *stack.GSO, pkts stack.PacketBufferList, _ tcpip.NetworkProtocolNumber) (int, tcpip.Error) { +func (e *countedEndpoint) WritePackets(_ stack.RouteInfo, pkts stack.PacketBufferList, _ tcpip.NetworkProtocolNumber) (int, tcpip.Error) { e.writeCount += pkts.Len() return pkts.Len(), nil } @@ -98,21 +98,21 @@ func TestWaitWrite(t *testing.T) { wep := New(ep) // Write and check that it goes through. - wep.WritePacket(stack.RouteInfo{}, nil /* gso */, 0, stack.NewPacketBuffer(stack.PacketBufferOptions{})) + wep.WritePacket(stack.RouteInfo{}, 0, stack.NewPacketBuffer(stack.PacketBufferOptions{})) if want := 1; ep.writeCount != want { t.Fatalf("Unexpected writeCount: got=%v, want=%v", ep.writeCount, want) } // Wait on dispatches, then try to write. It must go through. wep.WaitDispatch() - wep.WritePacket(stack.RouteInfo{}, nil /* gso */, 0, stack.NewPacketBuffer(stack.PacketBufferOptions{})) + wep.WritePacket(stack.RouteInfo{}, 0, stack.NewPacketBuffer(stack.PacketBufferOptions{})) if want := 2; ep.writeCount != want { t.Fatalf("Unexpected writeCount: got=%v, want=%v", ep.writeCount, want) } // Wait on writes, then try to write. It must not go through. 
wep.WaitWrite() - wep.WritePacket(stack.RouteInfo{}, nil /* gso */, 0, stack.NewPacketBuffer(stack.PacketBufferOptions{})) + wep.WritePacket(stack.RouteInfo{}, 0, stack.NewPacketBuffer(stack.PacketBufferOptions{})) if want := 2; ep.writeCount != want { t.Fatalf("Unexpected writeCount: got=%v, want=%v", ep.writeCount, want) } diff --git a/pkg/tcpip/network/BUILD b/pkg/tcpip/network/BUILD index fa8814bac..7b1ff44f4 100644 --- a/pkg/tcpip/network/BUILD +++ b/pkg/tcpip/network/BUILD @@ -21,6 +21,7 @@ go_test( "//pkg/tcpip/network/ipv4", "//pkg/tcpip/network/ipv6", "//pkg/tcpip/stack", + "//pkg/tcpip/testutil", "//pkg/tcpip/transport/icmp", "//pkg/tcpip/transport/tcp", "//pkg/tcpip/transport/udp", diff --git a/pkg/tcpip/network/arp/BUILD b/pkg/tcpip/network/arp/BUILD index d59d678b2..6905b9ccb 100644 --- a/pkg/tcpip/network/arp/BUILD +++ b/pkg/tcpip/network/arp/BUILD @@ -33,6 +33,7 @@ go_test( "//pkg/tcpip/link/sniffer", "//pkg/tcpip/network/ipv4", "//pkg/tcpip/stack", + "//pkg/tcpip/testutil", "//pkg/tcpip/transport/icmp", "@com_github_google_go_cmp//cmp:go_default_library", "@com_github_google_go_cmp//cmp/cmpopts:go_default_library", diff --git a/pkg/tcpip/network/arp/arp.go b/pkg/tcpip/network/arp/arp.go index 7ae38d684..0efa3a926 100644 --- a/pkg/tcpip/network/arp/arp.go +++ b/pkg/tcpip/network/arp/arp.go @@ -136,7 +136,7 @@ func (e *endpoint) MaxHeaderLength() uint16 { func (*endpoint) Close() {} -func (*endpoint) WritePacket(*stack.Route, *stack.GSO, stack.NetworkHeaderParams, *stack.PacketBuffer) tcpip.Error { +func (*endpoint) WritePacket(*stack.Route, stack.NetworkHeaderParams, *stack.PacketBuffer) tcpip.Error { return &tcpip.ErrNotSupported{} } @@ -146,7 +146,7 @@ func (*endpoint) NetworkProtocolNumber() tcpip.NetworkProtocolNumber { } // WritePackets implements stack.NetworkEndpoint.WritePackets. -func (*endpoint) WritePackets(*stack.Route, *stack.GSO, stack.PacketBufferList, stack.NetworkHeaderParams) (int, tcpip.Error) { +func (*endpoint) WritePackets(*stack.Route, stack.PacketBufferList, stack.NetworkHeaderParams) (int, tcpip.Error) { return 0, &tcpip.ErrNotSupported{} } @@ -222,7 +222,7 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) { // // Send the packet to the (new) target hardware address on the same // hardware on which the request was received. 
- if err := e.nic.WritePacketToRemote(tcpip.LinkAddress(origSender), nil /* gso */, ProtocolNumber, respPkt); err != nil { + if err := e.nic.WritePacketToRemote(tcpip.LinkAddress(origSender), ProtocolNumber, respPkt); err != nil { stats.outgoingRepliesDropped.Increment() } else { stats.outgoingRepliesSent.Increment() @@ -355,7 +355,7 @@ func (e *endpoint) sendARPRequest(localAddr, targetAddr tcpip.Address, remoteLin } stats := e.stats.arp - if err := e.nic.WritePacketToRemote(remoteLinkAddr, nil /* gso */, ProtocolNumber, pkt); err != nil { + if err := e.nic.WritePacketToRemote(remoteLinkAddr, ProtocolNumber, pkt); err != nil { stats.outgoingRequestsDropped.Increment() return err } diff --git a/pkg/tcpip/network/arp/arp_test.go b/pkg/tcpip/network/arp/arp_test.go index 018d6a578..94209b026 100644 --- a/pkg/tcpip/network/arp/arp_test.go +++ b/pkg/tcpip/network/arp/arp_test.go @@ -30,20 +30,16 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/network/arp" "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" "gvisor.dev/gvisor/pkg/tcpip/stack" + "gvisor.dev/gvisor/pkg/tcpip/testutil" "gvisor.dev/gvisor/pkg/tcpip/transport/icmp" ) const ( nicID = 1 - stackAddr = tcpip.Address("\x0a\x00\x00\x01") - stackLinkAddr = tcpip.LinkAddress("\x0a\x0a\x0b\x0b\x0c\x0c") - - remoteAddr = tcpip.Address("\x0a\x00\x00\x02") + stackLinkAddr = tcpip.LinkAddress("\x0a\x0a\x0b\x0b\x0c\x0c") remoteLinkAddr = tcpip.LinkAddress("\x01\x02\x03\x04\x05\x06") - unknownAddr = tcpip.Address("\x0a\x00\x00\x03") - defaultChannelSize = 1 defaultMTU = 65536 @@ -54,6 +50,12 @@ const ( eventChanSize = 32 ) +var ( + stackAddr = testutil.MustParse4("10.0.0.1") + remoteAddr = testutil.MustParse4("10.0.0.2") + unknownAddr = testutil.MustParse4("10.0.0.3") +) + type eventType uint8 const ( @@ -449,12 +451,12 @@ type testLinkEndpoint struct { writeErr tcpip.Error } -func (t *testLinkEndpoint) WritePacket(r stack.RouteInfo, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error { +func (t *testLinkEndpoint) WritePacket(r stack.RouteInfo, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error { if t.writeErr != nil { return t.writeErr } - return t.LinkEndpoint.WritePacket(r, gso, protocol, pkt) + return t.LinkEndpoint.WritePacket(r, protocol, pkt) } func TestLinkAddressRequest(t *testing.T) { diff --git a/pkg/tcpip/network/internal/ip/BUILD b/pkg/tcpip/network/internal/ip/BUILD index d21b4c7ef..fd944ce99 100644 --- a/pkg/tcpip/network/internal/ip/BUILD +++ b/pkg/tcpip/network/internal/ip/BUILD @@ -6,6 +6,7 @@ go_library( name = "ip", srcs = [ "duplicate_address_detection.go", + "errors.go", "generic_multicast_protocol.go", "stats.go", ], diff --git a/pkg/tcpip/network/internal/ip/errors.go b/pkg/tcpip/network/internal/ip/errors.go new file mode 100644 index 000000000..50fabfd79 --- /dev/null +++ b/pkg/tcpip/network/internal/ip/errors.go @@ -0,0 +1,77 @@ +// Copyright 2021 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+
+package ip
+
+import (
+	"fmt"
+
+	"gvisor.dev/gvisor/pkg/tcpip"
+)
+
+// ForwardingError represents an error that occurred while trying to forward
+// a packet.
+type ForwardingError interface {
+	isForwardingError()
+	fmt.Stringer
+}
+
+// ErrTTLExceeded indicates that the received packet's TTL has been exceeded.
+type ErrTTLExceeded struct{}
+
+func (*ErrTTLExceeded) isForwardingError() {}
+
+func (*ErrTTLExceeded) String() string { return "ttl exceeded" }
+
+// ErrIPOptProblem indicates the received packet had a problem with an IP
+// option.
+type ErrIPOptProblem struct{}
+
+func (*ErrIPOptProblem) isForwardingError() {}
+
+func (*ErrIPOptProblem) String() string { return "ip option problem" }
+
+// ErrLinkLocalSourceAddress indicates the received packet had a link-local
+// source address.
+type ErrLinkLocalSourceAddress struct{}
+
+func (*ErrLinkLocalSourceAddress) isForwardingError() {}
+
+func (*ErrLinkLocalSourceAddress) String() string { return "link local source address" }
+
+// ErrLinkLocalDestinationAddress indicates the received packet had a link-local
+// destination address.
+type ErrLinkLocalDestinationAddress struct{}
+
+func (*ErrLinkLocalDestinationAddress) isForwardingError() {}
+
+func (*ErrLinkLocalDestinationAddress) String() string { return "link local destination address" }
+
+// ErrNoRoute indicates the netstack couldn't find a route for the
+// received packet.
+type ErrNoRoute struct{}
+
+func (*ErrNoRoute) isForwardingError() {}
+
+func (*ErrNoRoute) String() string { return "no route" }
+
+// ErrOther indicates the packet could not be forwarded for a reason
+// captured by the contained error.
+type ErrOther struct {
+	Err tcpip.Error
+}
+
+func (*ErrOther) isForwardingError() {}
+
+func (e *ErrOther) String() string { return fmt.Sprintf("other tcpip error: %s", e.Err) }
diff --git a/pkg/tcpip/network/internal/ip/generic_multicast_protocol.go b/pkg/tcpip/network/internal/ip/generic_multicast_protocol.go
index b9f129728..d22974b12 100644
--- a/pkg/tcpip/network/internal/ip/generic_multicast_protocol.go
+++ b/pkg/tcpip/network/internal/ip/generic_multicast_protocol.go
@@ -12,7 +12,6 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
-// Package ip holds IPv4/IPv6 common utilities.
 package ip
 
 import (
@@ -156,14 +155,6 @@ type GenericMulticastProtocolOptions struct {
 	//
 	// Unsolicited reports are transmitted when a group is newly joined.
 	MaxUnsolicitedReportDelay time.Duration
-
-	// AllNodesAddress is a multicast address that all nodes on a network should
-	// be a member of.
-	//
-	// This address will not have the generic multicast protocol performed on it;
-	// it will be left in the non member/listener state, and packets will never
-	// be sent for it.
-	AllNodesAddress tcpip.Address
 }
 
 // MulticastGroupProtocol is a multicast group protocol whose core state machine
@@ -188,6 +179,10 @@ type MulticastGroupProtocol interface {
 
 	// SendLeave sends a multicast leave for the specified group address.
 	SendLeave(groupAddress tcpip.Address) tcpip.Error
+
+	// ShouldPerformProtocol returns true iff the protocol should be performed for
+	// the specified group.
+ ShouldPerformProtocol(tcpip.Address) bool } // GenericMulticastProtocolState is the per interface generic multicast protocol @@ -455,20 +450,7 @@ func (g *GenericMulticastProtocolState) initializeNewMemberLocked(groupAddress t info.lastToSendReport = false - if groupAddress == g.opts.AllNodesAddress { - // As per RFC 2236 section 6 page 10 (for IGMPv2), - // - // The all-systems group (address 224.0.0.1) is handled as a special - // case. The host starts in Idle Member state for that group on every - // interface, never transitions to another state, and never sends a - // report for that group. - // - // As per RFC 2710 section 5 page 10 (for MLDv1), - // - // The link-scope all-nodes address (FF02::1) is handled as a special - // case. The node starts in Idle Listener state for that address on - // every interface, never transitions to another state, and never sends - // a Report or Done for that address. + if !g.opts.Protocol.ShouldPerformProtocol(groupAddress) { info.state = idleMember return } @@ -537,20 +519,7 @@ func (g *GenericMulticastProtocolState) maybeSendLeave(groupAddress tcpip.Addres return } - if groupAddress == g.opts.AllNodesAddress { - // As per RFC 2236 section 6 page 10 (for IGMPv2), - // - // The all-systems group (address 224.0.0.1) is handled as a special - // case. The host starts in Idle Member state for that group on every - // interface, never transitions to another state, and never sends a - // report for that group. - // - // As per RFC 2710 section 5 page 10 (for MLDv1), - // - // The link-scope all-nodes address (FF02::1) is handled as a special - // case. The node starts in Idle Listener state for that address on - // every interface, never transitions to another state, and never sends - // a Report or Done for that address. + if !g.opts.Protocol.ShouldPerformProtocol(groupAddress) { return } @@ -627,20 +596,7 @@ func (g *GenericMulticastProtocolState) setDelayTimerForAddressRLocked(groupAddr return } - if groupAddress == g.opts.AllNodesAddress { - // As per RFC 2236 section 6 page 10 (for IGMPv2), - // - // The all-systems group (address 224.0.0.1) is handled as a special - // case. The host starts in Idle Member state for that group on every - // interface, never transitions to another state, and never sends a - // report for that group. - // - // As per RFC 2710 section 5 page 10 (for MLDv1), - // - // The link-scope all-nodes address (FF02::1) is handled as a special - // case. The node starts in Idle Listener state for that address on - // every interface, never transitions to another state, and never sends - // a Report or Done for that address. + if !g.opts.Protocol.ShouldPerformProtocol(groupAddress) { return } diff --git a/pkg/tcpip/network/internal/ip/generic_multicast_protocol_test.go b/pkg/tcpip/network/internal/ip/generic_multicast_protocol_test.go index 381460c82..0b51563cd 100644 --- a/pkg/tcpip/network/internal/ip/generic_multicast_protocol_test.go +++ b/pkg/tcpip/network/internal/ip/generic_multicast_protocol_test.go @@ -43,6 +43,8 @@ type mockMulticastGroupProtocolProtectedFields struct { type mockMulticastGroupProtocol struct { t *testing.T + skipProtocolAddress tcpip.Address + mu mockMulticastGroupProtocolProtectedFields } @@ -165,6 +167,11 @@ func (m *mockMulticastGroupProtocol) SendLeave(groupAddress tcpip.Address) tcpip return nil } +// ShouldPerformProtocol implements ip.MulticastGroupProtocol. 
+func (m *mockMulticastGroupProtocol) ShouldPerformProtocol(groupAddress tcpip.Address) bool { + return groupAddress != m.skipProtocolAddress +} + func (m *mockMulticastGroupProtocol) check(sendReportGroupAddresses []tcpip.Address, sendLeaveGroupAddresses []tcpip.Address) string { m.mu.Lock() defer m.mu.Unlock() @@ -193,10 +200,11 @@ func (m *mockMulticastGroupProtocol) check(sendReportGroupAddresses []tcpip.Addr cmp.FilterPath( func(p cmp.Path) bool { switch p.Last().String() { - case ".RWMutex", ".t", ".makeQueuePackets", ".disabled", ".genericMulticastGroup": + case ".RWMutex", ".t", ".makeQueuePackets", ".disabled", ".genericMulticastGroup", ".skipProtocolAddress": return true + default: + return false } - return false }, cmp.Ignore(), ), @@ -225,14 +233,13 @@ func TestJoinGroup(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { - mgp := mockMulticastGroupProtocol{t: t} + mgp := mockMulticastGroupProtocol{t: t, skipProtocolAddress: addr2} clock := faketime.NewManualClock() mgp.init(ip.GenericMulticastProtocolOptions{ Rand: rand.New(rand.NewSource(0)), Clock: clock, MaxUnsolicitedReportDelay: maxUnsolicitedReportDelay, - AllNodesAddress: addr2, }) // Joining a group should send a report immediately and another after @@ -279,14 +286,13 @@ func TestLeaveGroup(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { - mgp := mockMulticastGroupProtocol{t: t} + mgp := mockMulticastGroupProtocol{t: t, skipProtocolAddress: addr2} clock := faketime.NewManualClock() mgp.init(ip.GenericMulticastProtocolOptions{ Rand: rand.New(rand.NewSource(1)), Clock: clock, MaxUnsolicitedReportDelay: maxUnsolicitedReportDelay, - AllNodesAddress: addr2, }) mgp.joinGroup(test.addr) @@ -356,14 +362,13 @@ func TestHandleReport(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { - mgp := mockMulticastGroupProtocol{t: t} + mgp := mockMulticastGroupProtocol{t: t, skipProtocolAddress: addr3} clock := faketime.NewManualClock() mgp.init(ip.GenericMulticastProtocolOptions{ Rand: rand.New(rand.NewSource(2)), Clock: clock, MaxUnsolicitedReportDelay: maxUnsolicitedReportDelay, - AllNodesAddress: addr3, }) mgp.joinGroup(addr1) @@ -446,14 +451,13 @@ func TestHandleQuery(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { - mgp := mockMulticastGroupProtocol{t: t} + mgp := mockMulticastGroupProtocol{t: t, skipProtocolAddress: addr3} clock := faketime.NewManualClock() mgp.init(ip.GenericMulticastProtocolOptions{ Rand: rand.New(rand.NewSource(3)), Clock: clock, MaxUnsolicitedReportDelay: maxUnsolicitedReportDelay, - AllNodesAddress: addr3, }) mgp.joinGroup(addr1) @@ -574,14 +578,13 @@ func TestJoinCount(t *testing.T) { } func TestMakeAllNonMemberAndInitialize(t *testing.T) { - mgp := mockMulticastGroupProtocol{t: t} + mgp := mockMulticastGroupProtocol{t: t, skipProtocolAddress: addr3} clock := faketime.NewManualClock() mgp.init(ip.GenericMulticastProtocolOptions{ Rand: rand.New(rand.NewSource(3)), Clock: clock, MaxUnsolicitedReportDelay: maxUnsolicitedReportDelay, - AllNodesAddress: addr3, }) mgp.joinGroup(addr1) diff --git a/pkg/tcpip/network/internal/ip/stats.go b/pkg/tcpip/network/internal/ip/stats.go index b6f39ddb1..392f0b0c7 100644 --- a/pkg/tcpip/network/internal/ip/stats.go +++ b/pkg/tcpip/network/internal/ip/stats.go @@ -18,70 +18,114 @@ import "gvisor.dev/gvisor/pkg/tcpip" // LINT.IfChange(MultiCounterIPStats) +// MultiCounterIPForwardingStats holds IP forwarding statistics. 
Each counter +// may have several versions. +type MultiCounterIPForwardingStats struct { + // Unrouteable is the number of IP packets received which were dropped + // because the netstack could not construct a route to their + // destination. + Unrouteable tcpip.MultiCounterStat + + // ExhaustedTTL is the number of IP packets received which were dropped + // because their TTL was exhausted. + ExhaustedTTL tcpip.MultiCounterStat + + // LinkLocalSource is the number of IP packets which were dropped + // because they contained a link-local source address. + LinkLocalSource tcpip.MultiCounterStat + + // LinkLocalDestination is the number of IP packets which were dropped + // because they contained a link-local destination address. + LinkLocalDestination tcpip.MultiCounterStat + + // Errors is the number of IP packets received which could not be + // successfully forwarded. + Errors tcpip.MultiCounterStat +} + // MultiCounterIPStats holds IP statistics, each counter may have several // versions. type MultiCounterIPStats struct { - // PacketsReceived is the total number of IP packets received from the link + // PacketsReceived is the number of IP packets received from the link // layer. PacketsReceived tcpip.MultiCounterStat - // DisabledPacketsReceived is the total number of IP packets received from the - // link layer when the IP layer is disabled. + // DisabledPacketsReceived is the number of IP packets received from + // the link layer when the IP layer is disabled. DisabledPacketsReceived tcpip.MultiCounterStat - // InvalidDestinationAddressesReceived is the total number of IP packets + // InvalidDestinationAddressesReceived is the number of IP packets // received with an unknown or invalid destination address. InvalidDestinationAddressesReceived tcpip.MultiCounterStat - // InvalidSourceAddressesReceived is the total number of IP packets received - // with a source address that should never have been received on the wire. + // InvalidSourceAddressesReceived is the number of IP packets received + // with a source address that should never have been received on the + // wire. InvalidSourceAddressesReceived tcpip.MultiCounterStat - // PacketsDelivered is the total number of incoming IP packets that are + // PacketsDelivered is the number of incoming IP packets that are // successfully delivered to the transport layer. PacketsDelivered tcpip.MultiCounterStat - // PacketsSent is the total number of IP packets sent via WritePacket. + // PacketsSent is the number of IP packets sent via WritePacket. PacketsSent tcpip.MultiCounterStat - // OutgoingPacketErrors is the total number of IP packets which failed to + // OutgoingPacketErrors is the number of IP packets which failed to // write to a link-layer endpoint. OutgoingPacketErrors tcpip.MultiCounterStat - // MalformedPacketsReceived is the total number of IP Packets that were + // MalformedPacketsReceived is the number of IP Packets that were // dropped due to the IP packet header failing validation checks. MalformedPacketsReceived tcpip.MultiCounterStat - // MalformedFragmentsReceived is the total number of IP Fragments that were + // MalformedFragmentsReceived is the number of IP Fragments that were // dropped due to the fragment failing validation checks. MalformedFragmentsReceived tcpip.MultiCounterStat - // IPTablesPreroutingDropped is the total number of IP packets dropped in the + // IPTablesPreroutingDropped is the number of IP packets dropped in the // Prerouting chain. 
IPTablesPreroutingDropped tcpip.MultiCounterStat - // IPTablesInputDropped is the total number of IP packets dropped in the Input - // chain. + // IPTablesInputDropped is the number of IP packets dropped in the + // Input chain. IPTablesInputDropped tcpip.MultiCounterStat - // IPTablesOutputDropped is the total number of IP packets dropped in the + // IPTablesOutputDropped is the number of IP packets dropped in the // Output chain. IPTablesOutputDropped tcpip.MultiCounterStat - // TODO(https://gvisor.dev/issues/5529): Move the IPv4-only option stats out - // of IPStats. + // IPTablesPostroutingDropped is the number of IP packets dropped in + // the Postrouting chain. + IPTablesPostroutingDropped tcpip.MultiCounterStat + + // TODO(https://gvisor.dev/issues/5529): Move the IPv4-only option + // stats out of IPStats. // OptionTimestampReceived is the number of Timestamp options seen. OptionTimestampReceived tcpip.MultiCounterStat - // OptionRecordRouteReceived is the number of Record Route options seen. + // OptionRecordRouteReceived is the number of Record Route options + // seen. OptionRecordRouteReceived tcpip.MultiCounterStat - // OptionRouterAlertReceived is the number of Router Alert options seen. + // OptionRouterAlertReceived is the number of Router Alert options + // seen. OptionRouterAlertReceived tcpip.MultiCounterStat // OptionUnknownReceived is the number of unknown IP options seen. OptionUnknownReceived tcpip.MultiCounterStat + + // Forwarding collects stats related to IP forwarding. + Forwarding MultiCounterIPForwardingStats +} + +// Init sets internal counters to track a and b counters. +func (m *MultiCounterIPForwardingStats) Init(a, b *tcpip.IPForwardingStats) { + m.Unrouteable.Init(a.Unrouteable, b.Unrouteable) + m.Errors.Init(a.Errors, b.Errors) + m.LinkLocalSource.Init(a.LinkLocalSource, b.LinkLocalSource) + m.LinkLocalDestination.Init(a.LinkLocalDestination, b.LinkLocalDestination) + m.ExhaustedTTL.Init(a.ExhaustedTTL, b.ExhaustedTTL) } // Init sets internal counters to track a and b counters. @@ -98,10 +142,12 @@ func (m *MultiCounterIPStats) Init(a, b *tcpip.IPStats) { m.IPTablesPreroutingDropped.Init(a.IPTablesPreroutingDropped, b.IPTablesPreroutingDropped) m.IPTablesInputDropped.Init(a.IPTablesInputDropped, b.IPTablesInputDropped) m.IPTablesOutputDropped.Init(a.IPTablesOutputDropped, b.IPTablesOutputDropped) + m.IPTablesPostroutingDropped.Init(a.IPTablesPostroutingDropped, b.IPTablesPostroutingDropped) m.OptionTimestampReceived.Init(a.OptionTimestampReceived, b.OptionTimestampReceived) m.OptionRecordRouteReceived.Init(a.OptionRecordRouteReceived, b.OptionRecordRouteReceived) m.OptionRouterAlertReceived.Init(a.OptionRouterAlertReceived, b.OptionRouterAlertReceived) m.OptionUnknownReceived.Init(a.OptionUnknownReceived, b.OptionUnknownReceived) + m.Forwarding.Init(&a.Forwarding, &b.Forwarding) } // LINT.ThenChange(:MultiCounterIPStats, ../../../tcpip.go:IPStats) diff --git a/pkg/tcpip/network/internal/testutil/testutil.go b/pkg/tcpip/network/internal/testutil/testutil.go index f5fa77b65..e2cf24b67 100644 --- a/pkg/tcpip/network/internal/testutil/testutil.go +++ b/pkg/tcpip/network/internal/testutil/testutil.go @@ -64,7 +64,7 @@ func (*MockLinkEndpoint) MaxHeaderLength() uint16 { return 0 } func (*MockLinkEndpoint) LinkAddress() tcpip.LinkAddress { return "" } // WritePacket implements LinkEndpoint.WritePacket. 
-func (ep *MockLinkEndpoint) WritePacket(_ stack.RouteInfo, _ *stack.GSO, _ tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error { +func (ep *MockLinkEndpoint) WritePacket(_ stack.RouteInfo, _ tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error { if ep.allowPackets == 0 { return ep.err } @@ -74,11 +74,11 @@ func (ep *MockLinkEndpoint) WritePacket(_ stack.RouteInfo, _ *stack.GSO, _ tcpip } // WritePackets implements LinkEndpoint.WritePackets. -func (ep *MockLinkEndpoint) WritePackets(r stack.RouteInfo, gso *stack.GSO, pkts stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, tcpip.Error) { +func (ep *MockLinkEndpoint) WritePackets(r stack.RouteInfo, pkts stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, tcpip.Error) { var n int for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() { - if err := ep.WritePacket(r, gso, protocol, pkt); err != nil { + if err := ep.WritePacket(r, protocol, pkt); err != nil { return n, err } n++ diff --git a/pkg/tcpip/network/ip_test.go b/pkg/tcpip/network/ip_test.go index a4edc69c7..74aad126c 100644 --- a/pkg/tcpip/network/ip_test.go +++ b/pkg/tcpip/network/ip_test.go @@ -15,6 +15,7 @@ package ip_test import ( + "fmt" "strings" "testing" @@ -29,23 +30,25 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" "gvisor.dev/gvisor/pkg/tcpip/stack" + "gvisor.dev/gvisor/pkg/tcpip/testutil" "gvisor.dev/gvisor/pkg/tcpip/transport/icmp" "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" "gvisor.dev/gvisor/pkg/tcpip/transport/udp" ) -const ( - localIPv4Addr = tcpip.Address("\x0a\x00\x00\x01") - remoteIPv4Addr = tcpip.Address("\x0a\x00\x00\x02") - ipv4SubnetAddr = tcpip.Address("\x0a\x00\x00\x00") - ipv4SubnetMask = tcpip.Address("\xff\xff\xff\x00") - ipv4Gateway = tcpip.Address("\x0a\x00\x00\x03") - localIPv6Addr = tcpip.Address("\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01") - remoteIPv6Addr = tcpip.Address("\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02") - ipv6SubnetAddr = tcpip.Address("\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00") - ipv6SubnetMask = tcpip.Address("\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\x00") - ipv6Gateway = tcpip.Address("\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x03") - nicID = 1 +const nicID = 1 + +var ( + localIPv4Addr = testutil.MustParse4("10.0.0.1") + remoteIPv4Addr = testutil.MustParse4("10.0.0.2") + ipv4SubnetAddr = testutil.MustParse4("10.0.0.0") + ipv4SubnetMask = testutil.MustParse4("255.255.255.0") + ipv4Gateway = testutil.MustParse4("10.0.0.3") + localIPv6Addr = testutil.MustParse6("a00::1") + remoteIPv6Addr = testutil.MustParse6("a00::2") + ipv6SubnetAddr = testutil.MustParse6("a00::") + ipv6SubnetMask = testutil.MustParse6("ffff:ffff:ffff:ffff:ffff:ffff:ffff:ff00") + ipv6Gateway = testutil.MustParse6("a00::3") ) var localIPv4AddrWithPrefix = tcpip.AddressWithPrefix{ @@ -180,7 +183,7 @@ func (*testObject) Wait() {} // WritePacket is called by network endpoints after producing a packet and // writing it to the link endpoint. This is used by the test object to verify // that the produced packet is as expected. 
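The address constants in these tests are now built with testutil.MustParse4 and testutil.MustParse6 instead of hand-encoded byte strings. The real helpers live in pkg/tcpip/testutil and are not shown in this excerpt; a minimal sketch of what a MustParse4-style helper has to do (name and implementation here are assumptions, for illustration only):

import (
	"fmt"
	"net"

	"gvisor.dev/gvisor/pkg/tcpip"
)

// mustParse4 is an illustrative stand-in for testutil.MustParse4: it parses a
// dotted-decimal string and panics on anything that is not an IPv4 address,
// so a bad test fixture fails loudly instead of silently carrying a bogus
// address into assertions.
func mustParse4(addr string) tcpip.Address {
	// net.ParseIP returns nil on failure and To4 returns nil for
	// non-IPv4 addresses, so a single nil check covers both cases.
	ip := net.ParseIP(addr).To4()
	if ip == nil {
		panic(fmt.Sprintf("%q is not a valid IPv4 address", addr))
	}
	return tcpip.Address(ip)
}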
-func (t *testObject) WritePacket(_ *stack.Route, _ *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error { +func (t *testObject) WritePacket(_ *stack.Route, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error { var prot tcpip.TransportProtocolNumber var srcAddr tcpip.Address var dstAddr tcpip.Address @@ -202,7 +205,7 @@ func (t *testObject) WritePacket(_ *stack.Route, _ *stack.GSO, protocol tcpip.Ne } // WritePackets implements stack.LinkEndpoint.WritePackets. -func (*testObject) WritePackets(_ *stack.Route, _ *stack.GSO, pkt stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, tcpip.Error) { +func (*testObject) WritePackets(_ *stack.Route, pkt stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, tcpip.Error) { panic("not implemented") } @@ -323,7 +326,7 @@ func (t *testInterface) setEnabled(v bool) { t.mu.disabled = !v } -func (*testInterface) WritePacketToRemote(tcpip.LinkAddress, *stack.GSO, tcpip.NetworkProtocolNumber, *stack.PacketBuffer) tcpip.Error { +func (*testInterface) WritePacketToRemote(tcpip.LinkAddress, tcpip.NetworkProtocolNumber, *stack.PacketBuffer) tcpip.Error { return &tcpip.ErrNotSupported{} } @@ -588,7 +591,7 @@ func TestIPv4Send(t *testing.T) { if err != nil { t.Fatalf("could not find route: %v", err) } - if err := ep.WritePacket(r, nil /* gso */, stack.NetworkHeaderParams{ + if err := ep.WritePacket(r, stack.NetworkHeaderParams{ Protocol: 123, TTL: 123, TOS: stack.DefaultTOS, @@ -1015,7 +1018,7 @@ func TestIPv6Send(t *testing.T) { if err != nil { t.Fatalf("could not find route: %v", err) } - if err := ep.WritePacket(r, nil /* gso */, stack.NetworkHeaderParams{ + if err := ep.WritePacket(r, stack.NetworkHeaderParams{ Protocol: 123, TTL: 123, TOS: stack.DefaultTOS, @@ -1938,3 +1941,80 @@ func TestICMPInclusionSize(t *testing.T) { }) } } + +func TestJoinLeaveAllRoutersGroup(t *testing.T) { + const nicID = 1 + + tests := []struct { + name string + netProto tcpip.NetworkProtocolNumber + protoFactory stack.NetworkProtocolFactory + allRoutersAddr tcpip.Address + }{ + { + name: "IPv4", + netProto: ipv4.ProtocolNumber, + protoFactory: ipv4.NewProtocol, + allRoutersAddr: header.IPv4AllRoutersGroup, + }, + { + name: "IPv6 Interface Local", + netProto: ipv6.ProtocolNumber, + protoFactory: ipv6.NewProtocol, + allRoutersAddr: header.IPv6AllRoutersInterfaceLocalMulticastAddress, + }, + { + name: "IPv6 Link Local", + netProto: ipv6.ProtocolNumber, + protoFactory: ipv6.NewProtocol, + allRoutersAddr: header.IPv6AllRoutersLinkLocalMulticastAddress, + }, + { + name: "IPv6 Site Local", + netProto: ipv6.ProtocolNumber, + protoFactory: ipv6.NewProtocol, + allRoutersAddr: header.IPv6AllRoutersSiteLocalMulticastAddress, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + for _, nicDisabled := range [...]bool{true, false} { + t.Run(fmt.Sprintf("NIC Disabled = %t", nicDisabled), func(t *testing.T) { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol, tcp.NewProtocol}, + }) + opts := stack.NICOptions{Disabled: nicDisabled} + if err := s.CreateNICWithOptions(nicID, channel.New(0, 0, ""), opts); err != nil { + t.Fatalf("CreateNICWithOptions(%d, _, %#v) = %s", nicID, opts, err) + } + + if got, err := s.IsInGroup(nicID, test.allRoutersAddr); err != nil { + t.Fatalf("s.IsInGroup(%d, %s): %s", nicID, test.allRoutersAddr, err) + } else if got { + 
t.Fatalf("got s.IsInGroup(%d, %s) = true, want = false", nicID, test.allRoutersAddr) + } + + if err := s.SetForwarding(test.netProto, true); err != nil { + t.Fatalf("s.SetForwarding(%d, true): %s", test.netProto, err) + } + if got, err := s.IsInGroup(nicID, test.allRoutersAddr); err != nil { + t.Fatalf("s.IsInGroup(%d, %s): %s", nicID, test.allRoutersAddr, err) + } else if !got { + t.Fatalf("got s.IsInGroup(%d, %s) = false, want = true", nicID, test.allRoutersAddr) + } + + if err := s.SetForwarding(test.netProto, false); err != nil { + t.Fatalf("s.SetForwarding(%d, false): %s", test.netProto, err) + } + if got, err := s.IsInGroup(nicID, test.allRoutersAddr); err != nil { + t.Fatalf("s.IsInGroup(%d, %s): %s", nicID, test.allRoutersAddr, err) + } else if got { + t.Fatalf("got s.IsInGroup(%d, %s) = true, want = false", nicID, test.allRoutersAddr) + } + }) + } + }) + } +} diff --git a/pkg/tcpip/network/ipv4/BUILD b/pkg/tcpip/network/ipv4/BUILD index 5e7f10f4b..7ee0495d9 100644 --- a/pkg/tcpip/network/ipv4/BUILD +++ b/pkg/tcpip/network/ipv4/BUILD @@ -45,6 +45,7 @@ go_test( "//pkg/tcpip/network/internal/testutil", "//pkg/tcpip/network/ipv4", "//pkg/tcpip/stack", + "//pkg/tcpip/testutil", "//pkg/tcpip/transport/icmp", "//pkg/tcpip/transport/raw", "//pkg/tcpip/transport/tcp", diff --git a/pkg/tcpip/network/ipv4/icmp.go b/pkg/tcpip/network/ipv4/icmp.go index 1525f15db..c8ed1ce79 100644 --- a/pkg/tcpip/network/ipv4/icmp.go +++ b/pkg/tcpip/network/ipv4/icmp.go @@ -163,10 +163,12 @@ func (e *endpoint) handleControl(errInfo stack.TransportError, pkt *stack.Packet return } - // Skip the ip header, then deliver the error. - pkt.Data().TrimFront(hlen) + // Keep needed information before trimming header. p := hdr.TransportProtocol() - e.dispatcher.DeliverTransportError(srcAddr, hdr.DestinationAddress(), ProtocolNumber, p, errInfo, pkt) + dstAddr := hdr.DestinationAddress() + // Skip the ip header, then deliver the error. + pkt.Data().DeleteFront(hlen) + e.dispatcher.DeliverTransportError(srcAddr, dstAddr, ProtocolNumber, p, errInfo, pkt) } func (e *endpoint) handleICMP(pkt *stack.PacketBuffer) { @@ -336,14 +338,16 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer) { case header.ICMPv4DstUnreachable: received.dstUnreachable.Increment() - pkt.Data().TrimFront(header.ICMPv4MinimumSize) - switch h.Code() { + mtu := h.MTU() + code := h.Code() + pkt.Data().DeleteFront(header.ICMPv4MinimumSize) + switch code { case header.ICMPv4HostUnreachable: e.handleControl(&icmpv4DestinationHostUnreachableSockError{}, pkt) case header.ICMPv4PortUnreachable: e.handleControl(&icmpv4DestinationPortUnreachableSockError{}, pkt) case header.ICMPv4FragmentationNeeded: - networkMTU, err := calculateNetworkMTU(uint32(h.MTU()), header.IPv4MinimumSize) + networkMTU, err := calculateNetworkMTU(uint32(mtu), header.IPv4MinimumSize) if err != nil { networkMTU = 0 } @@ -442,6 +446,23 @@ func (r *icmpReasonParamProblem) isForwarding() bool { return r.forwarding } +// icmpReasonNetworkUnreachable is an error in which the network specified in +// the internet destination field of the datagram is unreachable. +type icmpReasonNetworkUnreachable struct{} + +func (*icmpReasonNetworkUnreachable) isICMPReason() {} +func (*icmpReasonNetworkUnreachable) isForwarding() bool { + // If we hit a Net Unreachable error, then we know we are operating as + // a router. 
As per RFC 792 page 5, Destination Unreachable Message, + // + // If, according to the information in the gateway's routing tables, + // the network specified in the internet destination field of a + // datagram is unreachable, e.g., the distance to the network is + // infinity, the gateway may send a destination unreachable message to + // the internet source host of the datagram. + return true +} + // returnError takes an error descriptor and generates the appropriate ICMP // error packet for IPv4 and sends it back to the remote device that sent // the problematic packet. It incorporates as much of that packet as @@ -610,6 +631,10 @@ func (p *protocol) returnError(reason icmpReason, pkt *stack.PacketBuffer) tcpip icmpHdr.SetType(header.ICMPv4DstUnreachable) icmpHdr.SetCode(header.ICMPv4ProtoUnreachable) counter = sent.dstUnreachable + case *icmpReasonNetworkUnreachable: + icmpHdr.SetType(header.ICMPv4DstUnreachable) + icmpHdr.SetCode(header.ICMPv4NetUnreachable) + counter = sent.dstUnreachable case *icmpReasonTTLExceeded: icmpHdr.SetType(header.ICMPv4TimeExceeded) icmpHdr.SetCode(header.ICMPv4TTLExceeded) @@ -629,7 +654,6 @@ func (p *protocol) returnError(reason icmpReason, pkt *stack.PacketBuffer) tcpip icmpHdr.SetChecksum(header.ICMPv4Checksum(icmpHdr, icmpPkt.Data().AsRange().Checksum())) if err := route.WritePacket( - nil, /* gso */ stack.NetworkHeaderParams{ Protocol: header.ICMPv4ProtocolNumber, TTL: route.DefaultTTL(), diff --git a/pkg/tcpip/network/ipv4/igmp.go b/pkg/tcpip/network/ipv4/igmp.go index f3fc1c87e..3ce499298 100644 --- a/pkg/tcpip/network/ipv4/igmp.go +++ b/pkg/tcpip/network/ipv4/igmp.go @@ -126,6 +126,17 @@ func (igmp *igmpState) SendLeave(groupAddress tcpip.Address) tcpip.Error { return err } +// ShouldPerformProtocol implements ip.MulticastGroupProtocol. +func (igmp *igmpState) ShouldPerformProtocol(groupAddress tcpip.Address) bool { + // As per RFC 2236 section 6 page 10, + // + // The all-systems group (address 224.0.0.1) is handled as a special + // case. The host starts in Idle Member state for that group on every + // interface, never transitions to another state, and never sends a + // report for that group. + return groupAddress != header.IPv4AllSystems +} + // init sets up an igmpState struct, and is required to be called before using // a new igmpState. 
// @@ -137,7 +148,6 @@ func (igmp *igmpState) init(ep *endpoint) { Clock: ep.protocol.stack.Clock(), Protocol: igmp, MaxUnsolicitedReportDelay: UnsolicitedReportIntervalMax, - AllNodesAddress: header.IPv4AllSystems, }) igmp.igmpV1Present = igmpV1PresentDefault igmp.igmpV1Job = ep.protocol.stack.NewJob(&ep.mu, func() { @@ -331,7 +341,7 @@ func (igmp *igmpState) writePacket(destAddress tcpip.Address, groupAddress tcpip } sentStats := igmp.ep.stats.igmp.packetsSent - if err := igmp.ep.nic.WritePacketToRemote(header.EthernetAddressFromMulticastIPv4Address(destAddress), nil /* gso */, ProtocolNumber, pkt); err != nil { + if err := igmp.ep.nic.WritePacketToRemote(header.EthernetAddressFromMulticastIPv4Address(destAddress), ProtocolNumber, pkt); err != nil { sentStats.dropped.Increment() return false, err } diff --git a/pkg/tcpip/network/ipv4/igmp_test.go b/pkg/tcpip/network/ipv4/igmp_test.go index e5e1b89cc..4bd6f462e 100644 --- a/pkg/tcpip/network/ipv4/igmp_test.go +++ b/pkg/tcpip/network/ipv4/igmp_test.go @@ -26,18 +26,22 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/link/channel" "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" "gvisor.dev/gvisor/pkg/tcpip/stack" + "gvisor.dev/gvisor/pkg/tcpip/testutil" ) const ( linkAddr = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x06") - stackAddr = tcpip.Address("\x0a\x00\x00\x01") - remoteAddr = tcpip.Address("\x0a\x00\x00\x02") - multicastAddr = tcpip.Address("\xe0\x00\x00\x03") nicID = 1 defaultTTL = 1 defaultPrefixLength = 24 ) +var ( + stackAddr = testutil.MustParse4("10.0.0.1") + remoteAddr = testutil.MustParse4("10.0.0.2") + multicastAddr = testutil.MustParse4("224.0.0.3") +) + // validateIgmpPacket checks that a passed PacketInfo is an IPv4 IGMP packet // sent to the provided address with the passed fields set. Raises a t.Error if // any field does not match. 
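Only the IGMP implementation of the new ShouldPerformProtocol hook appears in this excerpt. Its MLD counterpart is presumably the mirror image, special-casing the link-scope all-nodes address described in the removed RFC 2710 comments; a sketch, assuming an mldState type analogous to igmpState:

// ShouldPerformProtocol implements ip.MulticastGroupProtocol (sketch only).
//
// As per RFC 2710 section 5, the link-scope all-nodes address (FF02::1) is
// handled as a special case: the node never transitions out of the Idle
// Listener state for it and never sends a Report or Done for it.
func (mld *mldState) ShouldPerformProtocol(groupAddress tcpip.Address) bool {
	return groupAddress != header.IPv6AllNodesMulticastAddress
}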
@@ -292,7 +296,7 @@ func TestIGMPPacketValidation(t *testing.T) { messageType: header.IGMPLeaveGroup, includeRouterAlertOption: true, stackAddresses: []tcpip.AddressWithPrefix{{Address: stackAddr, PrefixLen: 24}}, - srcAddr: tcpip.Address("\x0a\x00\x01\x02"), + srcAddr: testutil.MustParse4("10.0.1.2"), ttl: 1, expectValidIGMP: false, getMessageTypeStatValue: func(stats tcpip.Stats) uint64 { return stats.IGMP.PacketsReceived.LeaveGroup.Value() }, @@ -302,7 +306,7 @@ func TestIGMPPacketValidation(t *testing.T) { messageType: header.IGMPMembershipQuery, includeRouterAlertOption: true, stackAddresses: []tcpip.AddressWithPrefix{{Address: stackAddr, PrefixLen: 24}}, - srcAddr: tcpip.Address("\x0a\x00\x01\x02"), + srcAddr: testutil.MustParse4("10.0.1.2"), ttl: 1, expectValidIGMP: true, getMessageTypeStatValue: func(stats tcpip.Stats) uint64 { return stats.IGMP.PacketsReceived.MembershipQuery.Value() }, @@ -312,7 +316,7 @@ func TestIGMPPacketValidation(t *testing.T) { messageType: header.IGMPv1MembershipReport, includeRouterAlertOption: true, stackAddresses: []tcpip.AddressWithPrefix{{Address: stackAddr, PrefixLen: 24}}, - srcAddr: tcpip.Address("\x0a\x00\x01\x02"), + srcAddr: testutil.MustParse4("10.0.1.2"), ttl: 1, expectValidIGMP: false, getMessageTypeStatValue: func(stats tcpip.Stats) uint64 { return stats.IGMP.PacketsReceived.V1MembershipReport.Value() }, @@ -322,7 +326,7 @@ func TestIGMPPacketValidation(t *testing.T) { messageType: header.IGMPv2MembershipReport, includeRouterAlertOption: true, stackAddresses: []tcpip.AddressWithPrefix{{Address: stackAddr, PrefixLen: 24}}, - srcAddr: tcpip.Address("\x0a\x00\x01\x02"), + srcAddr: testutil.MustParse4("10.0.1.2"), ttl: 1, expectValidIGMP: false, getMessageTypeStatValue: func(stats tcpip.Stats) uint64 { return stats.IGMP.PacketsReceived.V2MembershipReport.Value() }, @@ -332,7 +336,7 @@ func TestIGMPPacketValidation(t *testing.T) { messageType: header.IGMPv2MembershipReport, includeRouterAlertOption: true, stackAddresses: []tcpip.AddressWithPrefix{ - {Address: tcpip.Address("\x0a\x00\x0f\x01"), PrefixLen: 24}, + {Address: testutil.MustParse4("10.0.15.1"), PrefixLen: 24}, {Address: stackAddr, PrefixLen: 24}, }, srcAddr: remoteAddr, diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go index 1a5661ca4..b11e56c6a 100644 --- a/pkg/tcpip/network/ipv4/ipv4.go +++ b/pkg/tcpip/network/ipv4/ipv4.go @@ -29,6 +29,7 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/header/parse" "gvisor.dev/gvisor/pkg/tcpip/network/hash" "gvisor.dev/gvisor/pkg/tcpip/network/internal/fragmentation" + "gvisor.dev/gvisor/pkg/tcpip/network/internal/ip" "gvisor.dev/gvisor/pkg/tcpip/stack" ) @@ -150,6 +151,38 @@ func (p *protocol) forgetEndpoint(nicID tcpip.NICID) { delete(p.mu.eps, nicID) } +// transitionForwarding transitions the endpoint's forwarding status to +// forwarding. +// +// Must only be called when the forwarding status changes. +func (e *endpoint) transitionForwarding(forwarding bool) { + e.mu.Lock() + defer e.mu.Unlock() + + if forwarding { + // There does not seem to be an RFC requirement for a node to join the all + // routers multicast address but + // https://www.iana.org/assignments/multicast-addresses/multicast-addresses.xhtml + // specifies the address as a group for all routers on a subnet so we join + // the group here. + if err := e.joinGroupLocked(header.IPv4AllRoutersGroup); err != nil { + // joinGroupLocked only returns an error if the group address is not a + // valid IPv4 multicast address. 
+ panic(fmt.Sprintf("e.joinGroupLocked(%s): %s", header.IPv4AllRoutersGroup, err)) + } + + return + } + + switch err := e.leaveGroupLocked(header.IPv4AllRoutersGroup).(type) { + case nil: + case *tcpip.ErrBadLocalAddress: + // The endpoint may have already left the multicast group. + default: + panic(fmt.Sprintf("e.leaveGroupLocked(%s): %s", header.IPv4AllRoutersGroup, err)) + } +} + // Enable implements stack.NetworkEndpoint. func (e *endpoint) Enable() tcpip.Error { e.mu.Lock() @@ -226,7 +259,7 @@ func (e *endpoint) disableLocked() { } // The endpoint may have already left the multicast group. - switch err := e.leaveGroupLocked(header.IPv4AllSystems); err.(type) { + switch err := e.leaveGroupLocked(header.IPv4AllSystems).(type) { case nil, *tcpip.ErrBadLocalAddress: default: panic(fmt.Sprintf("unexpected error when leaving group = %s: %s", header.IPv4AllSystems, err)) @@ -318,7 +351,7 @@ func (e *endpoint) addIPHeader(srcAddr, dstAddr tcpip.Address, pkt *stack.Packet // fragment. It returns the number of fragments handled and the number of // fragments left to be processed. The IP header must already be present in the // original packet. -func (e *endpoint) handleFragments(r *stack.Route, gso *stack.GSO, networkMTU uint32, pkt *stack.PacketBuffer, handler func(*stack.PacketBuffer) tcpip.Error) (int, int, tcpip.Error) { +func (e *endpoint) handleFragments(r *stack.Route, networkMTU uint32, pkt *stack.PacketBuffer, handler func(*stack.PacketBuffer) tcpip.Error) (int, int, tcpip.Error) { // Round the MTU down to align to 8 bytes. fragmentPayloadSize := networkMTU &^ 7 networkHeader := header.IPv4(pkt.NetworkHeader().View()) @@ -338,7 +371,7 @@ func (e *endpoint) handleFragments(r *stack.Route, gso *stack.GSO, networkMTU ui } // WritePacket writes a packet to the given destination address and protocol. -func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.NetworkHeaderParams, pkt *stack.PacketBuffer) tcpip.Error { +func (e *endpoint) WritePacket(r *stack.Route, params stack.NetworkHeaderParams, pkt *stack.PacketBuffer) tcpip.Error { if err := e.addIPHeader(r.LocalAddress(), r.RemoteAddress(), pkt, params, nil /* options */); err != nil { return err } @@ -346,7 +379,7 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.Netw // iptables filtering. All packets that reach here are locally // generated. outNicName := e.protocol.stack.FindNICNameFromID(e.nic.ID()) - if ok := e.protocol.stack.IPTables().Check(stack.Output, pkt, gso, r, "" /* preroutingAddr */, "" /* inNicName */, outNicName); !ok { + if ok := e.protocol.stack.IPTables().Check(stack.Output, pkt, r, "" /* preroutingAddr */, "" /* inNicName */, outNicName); !ok { // iptables is telling us to drop the packet. 
e.stats.ip.IPTablesOutputDropped.Increment() return nil @@ -369,10 +402,10 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.Netw } } - return e.writePacket(r, gso, pkt, false /* headerIncluded */) + return e.writePacket(r, pkt, false /* headerIncluded */) } -func (e *endpoint) writePacket(r *stack.Route, gso *stack.GSO, pkt *stack.PacketBuffer, headerIncluded bool) tcpip.Error { +func (e *endpoint) writePacket(r *stack.Route, pkt *stack.PacketBuffer, headerIncluded bool) tcpip.Error { if r.Loop()&stack.PacketLoop != 0 { // If the packet was generated by the stack (not a raw/packet endpoint // where a packet may be written with the header included), then we can @@ -383,6 +416,15 @@ func (e *endpoint) writePacket(r *stack.Route, gso *stack.GSO, pkt *stack.Packet return nil } + // Postrouting NAT can only change the source address, and does not alter the + // route or outgoing interface of the packet. + outNicName := e.protocol.stack.FindNICNameFromID(e.nic.ID()) + if ok := e.protocol.stack.IPTables().Check(stack.Postrouting, pkt, r, "" /* preroutingAddr */, "" /* inNicName */, outNicName); !ok { + // iptables is telling us to drop the packet. + e.stats.ip.IPTablesPostroutingDropped.Increment() + return nil + } + stats := e.stats.ip networkMTU, err := calculateNetworkMTU(e.nic.MTU(), uint32(pkt.NetworkHeader().View().Size())) @@ -391,20 +433,20 @@ func (e *endpoint) writePacket(r *stack.Route, gso *stack.GSO, pkt *stack.Packet return err } - if packetMustBeFragmented(pkt, networkMTU, gso) { - sent, remain, err := e.handleFragments(r, gso, networkMTU, pkt, func(fragPkt *stack.PacketBuffer) tcpip.Error { + if packetMustBeFragmented(pkt, networkMTU) { + sent, remain, err := e.handleFragments(r, networkMTU, pkt, func(fragPkt *stack.PacketBuffer) tcpip.Error { // TODO(gvisor.dev/issue/3884): Evaluate whether we want to send each // fragment one by one using WritePacket() (current strategy) or if we // want to create a PacketBufferList from the fragments and feed it to // WritePackets(). It'll be faster but cost more memory. - return e.nic.WritePacket(r, gso, ProtocolNumber, fragPkt) + return e.nic.WritePacket(r, ProtocolNumber, fragPkt) }) stats.PacketsSent.IncrementBy(uint64(sent)) stats.OutgoingPacketErrors.IncrementBy(uint64(remain)) return err } - if err := e.nic.WritePacket(r, gso, ProtocolNumber, pkt); err != nil { + if err := e.nic.WritePacket(r, ProtocolNumber, pkt); err != nil { stats.OutgoingPacketErrors.Increment() return err } @@ -413,7 +455,7 @@ func (e *endpoint) writePacket(r *stack.Route, gso *stack.GSO, pkt *stack.Packet } // WritePackets implements stack.NetworkEndpoint.WritePackets. -func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.PacketBufferList, params stack.NetworkHeaderParams) (int, tcpip.Error) { +func (e *endpoint) WritePackets(r *stack.Route, pkts stack.PacketBufferList, params stack.NetworkHeaderParams) (int, tcpip.Error) { if r.Loop()&stack.PacketLoop != 0 { panic("multiple packets in local loop") } @@ -434,11 +476,11 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe return 0, err } - if packetMustBeFragmented(pkt, networkMTU, gso) { + if packetMustBeFragmented(pkt, networkMTU) { // Keep track of the packet that is about to be fragmented so it can be // removed once the fragmentation is done. 
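The writePacket and WritePackets hunks above insert a Postrouting check between the existing Output check and the NIC write. A rough sketch of that control flow for locally generated packets, using hypothetical types (the real hooks are stack.Output and stack.Postrouting, run through IPTables().Check/CheckPackets):

type packet struct{ /* headers, payload, ... */ }

type hook func(*packet) (accept bool)

// sendLocallyGenerated shows the ordering only: Output first (may drop or
// NAT-redirect), then Postrouting (may drop or rewrite the source address),
// then the link-layer write. Drops are counted but are not returned as errors.
func sendLocallyGenerated(pkt *packet, output, postrouting hook, nicWrite func(*packet) error) error {
	if !output(pkt) {
		return nil // counted as IPTablesOutputDropped
	}
	if !postrouting(pkt) {
		return nil // counted as IPTablesPostroutingDropped
	}
	return nicWrite(pkt)
}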
originalPkt := pkt - if _, _, err := e.handleFragments(r, gso, networkMTU, pkt, func(fragPkt *stack.PacketBuffer) tcpip.Error { + if _, _, err := e.handleFragments(r, networkMTU, pkt, func(fragPkt *stack.PacketBuffer) tcpip.Error { // Modify the packet list in place with the new fragments. pkts.InsertAfter(pkt, fragPkt) pkt = fragPkt @@ -454,9 +496,9 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe outNicName := e.protocol.stack.FindNICNameFromID(e.nic.ID()) // iptables filtering. All packets that reach here are locally // generated. - dropped, natPkts := e.protocol.stack.IPTables().CheckPackets(stack.Output, pkts, gso, r, "", outNicName) - stats.IPTablesOutputDropped.IncrementBy(uint64(len(dropped))) - for pkt := range dropped { + outputDropped, natPkts := e.protocol.stack.IPTables().CheckPackets(stack.Output, pkts, r, "" /* inNicName */, outNicName) + stats.IPTablesOutputDropped.IncrementBy(uint64(len(outputDropped))) + for pkt := range outputDropped { pkts.Remove(pkt) } @@ -478,14 +520,23 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe } + // We ignore the list of NAT-ed packets here because Postrouting NAT can only + // change the source address, and does not alter the route or outgoing + // interface of the packet. + postroutingDropped, _ := e.protocol.stack.IPTables().CheckPackets(stack.Postrouting, pkts, r, "" /* inNicName */, outNicName) + stats.IPTablesPostroutingDropped.IncrementBy(uint64(len(postroutingDropped))) + for pkt := range postroutingDropped { + pkts.Remove(pkt) + } + // The rest of the packets can be delivered to the NIC as a batch. pktsLen := pkts.Len() - written, err := e.nic.WritePackets(r, gso, pkts, ProtocolNumber) + written, err := e.nic.WritePackets(r, pkts, ProtocolNumber) stats.PacketsSent.IncrementBy(uint64(written)) stats.OutgoingPacketErrors.IncrementBy(uint64(pktsLen - written)) // Dropped packets aren't errors, so include them in the return value. - return locallyDelivered + written + len(dropped), err + return locallyDelivered + written + len(outputDropped) + len(postroutingDropped), err } // WriteHeaderIncludedPacket implements stack.NetworkEndpoint. @@ -545,12 +596,31 @@ func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack.PacketBu return &tcpip.ErrMalformedHeader{} } - return e.writePacket(r, nil /* gso */, pkt, true /* headerIncluded */) + return e.writePacket(r, pkt, true /* headerIncluded */) } // forwardPacket attempts to forward a packet to its final destination. -func (e *endpoint) forwardPacket(pkt *stack.PacketBuffer) tcpip.Error { +func (e *endpoint) forwardPacket(pkt *stack.PacketBuffer) ip.ForwardingError { h := header.IPv4(pkt.NetworkHeader().View()) + + dstAddr := h.DestinationAddress() + // As per RFC 3927 section 7, + // + // A router MUST NOT forward a packet with an IPv4 Link-Local source or + // destination address, irrespective of the router's default route + // configuration or routes obtained from dynamic routing protocols. + // + // A router which receives a packet with an IPv4 Link-Local source or + // destination address MUST NOT forward the packet. This prevents + // forwarding of packets back onto the network segment from which they + // originated, or to any other segment. 
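As a rough standalone illustration of the RFC 3927 rule quoted above: 169.254.0.0/16 is the IPv4 link-local unicast block and 224.0.0.0/24 is the link-local multicast control block. The sketch below uses the standard net/netip package rather than the netstack helpers (IsV4LinkLocalUnicastAddress / IsV4LinkLocalMulticastAddress):

import "net/netip"

var (
	v4LinkLocalUnicast   = netip.MustParsePrefix("169.254.0.0/16")
	v4LinkLocalMulticast = netip.MustParsePrefix("224.0.0.0/24")
)

// mustNotForward reports whether a router is required to refuse to forward a
// packet with the given source and destination addresses.
func mustNotForward(src, dst netip.Addr) bool {
	return v4LinkLocalUnicast.Contains(src) ||
		v4LinkLocalUnicast.Contains(dst) ||
		v4LinkLocalMulticast.Contains(dst)
}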
+ if header.IsV4LinkLocalUnicastAddress(h.SourceAddress()) { + return &ip.ErrLinkLocalSourceAddress{} + } + if header.IsV4LinkLocalUnicastAddress(dstAddr) || header.IsV4LinkLocalMulticastAddress(dstAddr) { + return &ip.ErrLinkLocalDestinationAddress{} + } + ttl := h.TTL() if ttl == 0 { // As per RFC 792 page 6, Time Exceeded Message, @@ -558,7 +628,12 @@ func (e *endpoint) forwardPacket(pkt *stack.PacketBuffer) tcpip.Error { // If the gateway processing a datagram finds the time to live field // is zero it must discard the datagram. The gateway may also notify // the source host via the time exceeded message. - return e.protocol.returnError(&icmpReasonTTLExceeded{}, pkt) + // + // We return the original error rather than the result of returning + // the ICMP packet because the original error is more relevant to + // the caller. + _ = e.protocol.returnError(&icmpReasonTTLExceeded{}, pkt) + return &ip.ErrTTLExceeded{} } if opts := h.Options(); len(opts) != 0 { @@ -569,10 +644,8 @@ func (e *endpoint) forwardPacket(pkt *stack.PacketBuffer) tcpip.Error { pointer: optProblem.Pointer, forwarding: true, }, pkt) - e.protocol.stack.Stats().MalformedRcvdPackets.Increment() - e.stats.ip.MalformedPacketsReceived.Increment() } - return nil // option problems are not reported locally. + return &ip.ErrIPOptProblem{} } copied := copy(opts, newOpts) if copied != len(newOpts) { @@ -589,8 +662,6 @@ func (e *endpoint) forwardPacket(pkt *stack.PacketBuffer) tcpip.Error { } } - dstAddr := h.DestinationAddress() - // Check if the destination is owned by the stack. if ep := e.protocol.findEndpointWithAddress(dstAddr); ep != nil { ep.handleValidatedPacket(h, pkt) @@ -598,8 +669,16 @@ func (e *endpoint) forwardPacket(pkt *stack.PacketBuffer) tcpip.Error { } r, err := e.protocol.stack.FindRoute(0, "", dstAddr, ProtocolNumber, false /* multicastLoop */) - if err != nil { - return err + switch err.(type) { + case nil: + case *tcpip.ErrNoRoute, *tcpip.ErrNetworkUnreachable: + // We return the original error rather than the result of returning + // the ICMP packet because the original error is more relevant to + // the caller. + _ = e.protocol.returnError(&icmpReasonNetworkUnreachable{}, pkt) + return &ip.ErrNoRoute{} + default: + return &ip.ErrOther{Err: err} } defer r.Release() @@ -616,10 +695,13 @@ func (e *endpoint) forwardPacket(pkt *stack.PacketBuffer) tcpip.Error { // spent, the field must be decremented by 1. newHdr.SetTTL(ttl - 1) - return r.WriteHeaderIncludedPacket(stack.NewPacketBuffer(stack.PacketBufferOptions{ + if err := r.WriteHeaderIncludedPacket(stack.NewPacketBuffer(stack.PacketBufferOptions{ ReserveHeaderBytes: int(r.MaxHeaderLength()), Data: buffer.View(newHdr).ToVectorisedView(), - })) + })); err != nil { + return &ip.ErrOther{Err: err} + } + return nil } // HandlePacket is called by the link layer when new ipv4 packets arrive for @@ -668,7 +750,7 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) { // Loopback traffic skips the prerouting chain. inNicName := e.protocol.stack.FindNICNameFromID(e.nic.ID()) - if ok := e.protocol.stack.IPTables().Check(stack.Prerouting, pkt, nil, nil, e.MainAddress().Address, inNicName, "" /* outNicName */); !ok { + if ok := e.protocol.stack.IPTables().Check(stack.Prerouting, pkt, nil, e.MainAddress().Address, inNicName, "" /* outNicName */); !ok { // iptables is telling us to drop the packet. 
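forwardPacket above now returns a small set of typed errors that the caller switches on to bump per-cause counters (see the handleValidatedPacket hunk further down). A stripped-down sketch of that pattern with hypothetical local types; the real ones live in network/internal/ip and are not shown in this diff:

type forwardingError interface{ isForwardingError() }

type errTTLExceeded struct{}
type errNoRoute struct{}

func (errTTLExceeded) isForwardingError() {}
func (errNoRoute) isForwardingError()     {}

type forwardingStats struct {
	ExhaustedTTL uint64
	Unrouteable  uint64
	Errors       uint64
}

func recordForwardingResult(err forwardingError, s *forwardingStats) {
	switch err.(type) {
	case nil:
		return // forwarded successfully; nothing to count
	case errTTLExceeded:
		s.ExhaustedTTL++
	case errNoRoute:
		s.Unrouteable++
	}
	s.Errors++ // every failure also increments the aggregate counter
}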
stats.IPTablesPreroutingDropped.Increment() return @@ -734,14 +816,31 @@ func (e *endpoint) handleValidatedPacket(h header.IPv4, pkt *stack.PacketBuffer) stats.ip.InvalidDestinationAddressesReceived.Increment() return } - _ = e.forwardPacket(pkt) + switch err := e.forwardPacket(pkt); err.(type) { + case nil: + return + case *ip.ErrLinkLocalSourceAddress: + stats.ip.Forwarding.LinkLocalSource.Increment() + case *ip.ErrLinkLocalDestinationAddress: + stats.ip.Forwarding.LinkLocalDestination.Increment() + case *ip.ErrTTLExceeded: + stats.ip.Forwarding.ExhaustedTTL.Increment() + case *ip.ErrNoRoute: + stats.ip.Forwarding.Unrouteable.Increment() + case *ip.ErrIPOptProblem: + e.protocol.stack.Stats().MalformedRcvdPackets.Increment() + stats.ip.MalformedPacketsReceived.Increment() + default: + panic(fmt.Sprintf("unexpected error %s while trying to forward packet: %#v", err, pkt)) + } + stats.ip.Forwarding.Errors.Increment() return } // iptables filtering. All packets that reach here are intended for // this machine and will not be forwarded. inNicName := e.protocol.stack.FindNICNameFromID(e.nic.ID()) - if ok := e.protocol.stack.IPTables().Check(stack.Input, pkt, nil, nil, "" /* preroutingAddr */, inNicName, "" /* outNicName */); !ok { + if ok := e.protocol.stack.IPTables().Check(stack.Input, pkt, nil, "" /* preroutingAddr */, inNicName, "" /* outNicName */); !ok { // iptables is telling us to drop the packet. stats.ip.IPTablesInputDropped.Increment() return @@ -1114,28 +1213,7 @@ func (p *protocol) parseAndValidate(pkt *stack.PacketBuffer) (header.IPv4, bool) return nil, false } - // There has been some confusion regarding verifying checksums. We need - // just look for negative 0 (0xffff) as the checksum, as it's not possible to - // get positive 0 (0) for the checksum. Some bad implementations could get it - // when doing entry replacement in the early days of the Internet, - // however the lore that one needs to check for both persists. - // - // RFC 1624 section 1 describes the source of this confusion as: - // [the partial recalculation method described in RFC 1071] computes a - // result for certain cases that differs from the one obtained from - // scratch (one's complement of one's complement sum of the original - // fields). - // - // However RFC 1624 section 5 clarifies that if using the verification method - // "recommended by RFC 1071, it does not matter if an intermediate system - // generated a -0 instead of +0". - // - // RFC1071 page 1 specifies the verification method as: - // (3) To check a checksum, the 1's complement sum is computed over the - // same set of octets, including the checksum field. If the result - // is all 1 bits (-0 in 1's complement arithmetic), the check - // succeeds. - if h.CalculateChecksum() != 0xffff { + if !h.IsChecksumValid() { return nil, false } @@ -1168,12 +1246,27 @@ func (p *protocol) Forwarding() bool { return uint8(atomic.LoadUint32(&p.forwarding)) == 1 } +// setForwarding sets the forwarding status for the protocol. +// +// Returns true if the forwarding status was updated. +func (p *protocol) setForwarding(v bool) bool { + if v { + return atomic.CompareAndSwapUint32(&p.forwarding, 0 /* old */, 1 /* new */) + } + return atomic.CompareAndSwapUint32(&p.forwarding, 1 /* old */, 0 /* new */) +} + // SetForwarding implements stack.ForwardingNetworkProtocol. 
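The parseAndValidate hunk above collapses the long RFC 1071 / RFC 1624 discussion into a single IsChecksumValid call. For reference, the verification rule it relies on, written out as a standalone sketch: sum the header as 16-bit one's-complement words, including the checksum field itself, and accept only if the result is 0xffff (negative zero).

func onesComplementSum(b []byte) uint16 {
	var sum uint32
	for ; len(b) >= 2; b = b[2:] {
		sum += uint32(b[0])<<8 | uint32(b[1])
	}
	if len(b) == 1 {
		sum += uint32(b[0]) << 8 // an odd trailing byte is padded with zero
	}
	for sum > 0xffff {
		sum = (sum >> 16) + (sum & 0xffff) // fold the carries back in
	}
	return uint16(sum)
}

func headerChecksumLooksValid(hdr []byte) bool {
	return onesComplementSum(hdr) == 0xffff
}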
func (p *protocol) SetForwarding(v bool) { - if v { - atomic.StoreUint32(&p.forwarding, 1) - } else { - atomic.StoreUint32(&p.forwarding, 0) + p.mu.Lock() + defer p.mu.Unlock() + + if !p.setForwarding(v) { + return + } + + for _, ep := range p.mu.eps { + ep.transitionForwarding(v) } } @@ -1200,9 +1293,9 @@ func calculateNetworkMTU(linkMTU, networkHeaderSize uint32) (uint32, tcpip.Error return networkMTU - uint32(networkHeaderSize), nil } -func packetMustBeFragmented(pkt *stack.PacketBuffer, networkMTU uint32, gso *stack.GSO) bool { +func packetMustBeFragmented(pkt *stack.PacketBuffer, networkMTU uint32) bool { payload := pkt.TransportHeader().View().Size() + pkt.Data().Size() - return (gso == nil || gso.Type == stack.GSONone) && uint32(payload) > networkMTU + return pkt.GSOOptions.Type == stack.GSONone && uint32(payload) > networkMTU } // addressToUint32 translates an IPv4 address into its little endian uint32 diff --git a/pkg/tcpip/network/ipv4/ipv4_test.go b/pkg/tcpip/network/ipv4/ipv4_test.go index eba91c68c..7a7cad04a 100644 --- a/pkg/tcpip/network/ipv4/ipv4_test.go +++ b/pkg/tcpip/network/ipv4/ipv4_test.go @@ -39,6 +39,7 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/network/internal/testutil" "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" "gvisor.dev/gvisor/pkg/tcpip/stack" + tcptestutil "gvisor.dev/gvisor/pkg/tcpip/testutil" "gvisor.dev/gvisor/pkg/tcpip/transport/icmp" "gvisor.dev/gvisor/pkg/tcpip/transport/raw" "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" @@ -130,48 +131,69 @@ func TestForwarding(t *testing.T) { } remoteIPv4Addr1 := tcpip.Address(net.ParseIP("10.0.0.2").To4()) remoteIPv4Addr2 := tcpip.Address(net.ParseIP("11.0.0.2").To4()) + unreachableIPv4Addr := tcpip.Address(net.ParseIP("12.0.0.2").To4()) + multicastIPv4Addr := tcpip.Address(net.ParseIP("225.0.0.0").To4()) + linkLocalIPv4Addr := tcpip.Address(net.ParseIP("169.254.0.0").To4()) tests := []struct { - name string - TTL uint8 - expectErrorICMP bool - options header.IPv4Options - forwardedOptions header.IPv4Options - icmpType header.ICMPv4Type - icmpCode header.ICMPv4Code + name string + TTL uint8 + sourceAddr tcpip.Address + destAddr tcpip.Address + expectErrorICMP bool + expectPacketForwarded bool + options header.IPv4Options + forwardedOptions header.IPv4Options + icmpType header.ICMPv4Type + icmpCode header.ICMPv4Code + expectPacketUnrouteableError bool + expectLinkLocalSourceError bool + expectLinkLocalDestError bool }{ { name: "TTL of zero", TTL: 0, + sourceAddr: remoteIPv4Addr1, + destAddr: remoteIPv4Addr2, expectErrorICMP: true, icmpType: header.ICMPv4TimeExceeded, icmpCode: header.ICMPv4TTLExceeded, }, { - name: "TTL of one", - TTL: 1, - expectErrorICMP: false, + name: "TTL of one", + TTL: 1, + sourceAddr: remoteIPv4Addr1, + destAddr: remoteIPv4Addr2, + expectPacketForwarded: true, }, { - name: "TTL of two", - TTL: 2, - expectErrorICMP: false, + name: "TTL of two", + TTL: 2, + sourceAddr: remoteIPv4Addr1, + destAddr: remoteIPv4Addr2, + expectPacketForwarded: true, }, { - name: "Max TTL", - TTL: math.MaxUint8, - expectErrorICMP: false, + name: "Max TTL", + TTL: math.MaxUint8, + sourceAddr: remoteIPv4Addr1, + destAddr: remoteIPv4Addr2, + expectPacketForwarded: true, }, { - name: "four EOL options", - TTL: 2, - expectErrorICMP: false, - options: header.IPv4Options{0, 0, 0, 0}, - forwardedOptions: header.IPv4Options{0, 0, 0, 0}, + name: "four EOL options", + TTL: 2, + sourceAddr: remoteIPv4Addr1, + destAddr: remoteIPv4Addr2, + expectPacketForwarded: true, + options: header.IPv4Options{0, 0, 0, 0}, + forwardedOptions: 
header.IPv4Options{0, 0, 0, 0}, }, { - name: "TS type 1 full", - TTL: 2, + name: "TS type 1 full", + TTL: 2, + sourceAddr: remoteIPv4Addr1, + destAddr: remoteIPv4Addr2, options: header.IPv4Options{ 68, 12, 13, 0xF1, 192, 168, 1, 12, @@ -182,8 +204,10 @@ func TestForwarding(t *testing.T) { icmpCode: header.ICMPv4UnusedCode, }, { - name: "TS type 0", - TTL: 2, + name: "TS type 0", + TTL: 2, + sourceAddr: remoteIPv4Addr1, + destAddr: remoteIPv4Addr2, options: header.IPv4Options{ 68, 24, 21, 0x00, 1, 2, 3, 4, @@ -200,10 +224,13 @@ func TestForwarding(t *testing.T) { 13, 14, 15, 16, 0x00, 0xad, 0x1c, 0x40, // time we expect from fakeclock }, + expectPacketForwarded: true, }, { - name: "end of options list", - TTL: 2, + name: "end of options list", + TTL: 2, + sourceAddr: remoteIPv4Addr1, + destAddr: remoteIPv4Addr2, options: header.IPv4Options{ 68, 12, 13, 0x11, 192, 168, 1, 12, @@ -219,6 +246,37 @@ func TestForwarding(t *testing.T) { 0, 0, 0, // 7 bytes unknown option removed. 0, 0, 0, 0, }, + expectPacketForwarded: true, + }, + { + name: "Network unreachable", + TTL: 2, + sourceAddr: remoteIPv4Addr1, + destAddr: unreachableIPv4Addr, + expectErrorICMP: true, + icmpType: header.ICMPv4DstUnreachable, + icmpCode: header.ICMPv4NetUnreachable, + expectPacketUnrouteableError: true, + }, + { + name: "Multicast destination", + TTL: 2, + destAddr: multicastIPv4Addr, + expectPacketUnrouteableError: true, + }, + { + name: "Link local destination", + TTL: 2, + sourceAddr: remoteIPv4Addr1, + destAddr: linkLocalIPv4Addr, + expectLinkLocalDestError: true, + }, + { + name: "Link local source", + TTL: 2, + sourceAddr: linkLocalIPv4Addr, + destAddr: remoteIPv4Addr2, + expectLinkLocalSourceError: true, }, } for _, test := range tests { @@ -286,8 +344,8 @@ func TestForwarding(t *testing.T) { TotalLength: totalLen, Protocol: uint8(header.ICMPv4ProtocolNumber), TTL: test.TTL, - SrcAddr: remoteIPv4Addr1, - DstAddr: remoteIPv4Addr2, + SrcAddr: test.sourceAddr, + DstAddr: test.destAddr, }) if len(test.options) != 0 { ip.SetHeaderLength(uint8(ipHeaderLength)) @@ -304,15 +362,15 @@ func TestForwarding(t *testing.T) { }) e1.InjectInbound(header.IPv4ProtocolNumber, requestPkt) + reply, ok := e1.Read() if test.expectErrorICMP { - reply, ok := e1.Read() if !ok { t.Fatalf("expected ICMP packet type %d through incoming NIC", test.icmpType) } checker.IPv4(t, header.IPv4(stack.PayloadSince(reply.Pkt.NetworkHeader())), checker.SrcAddr(ipv4Addr1.Address), - checker.DstAddr(remoteIPv4Addr1), + checker.DstAddr(test.sourceAddr), checker.TTL(ipv4.DefaultTTL), checker.ICMPv4( checker.ICMPv4Checksum(), @@ -325,15 +383,19 @@ func TestForwarding(t *testing.T) { if n := e2.Drain(); n != 0 { t.Fatalf("got e2.Drain() = %d, want = 0", n) } - } else { - reply, ok := e2.Read() + } else if ok { + t.Fatalf("expected no ICMP packet through incoming NIC, instead found: %#v", reply) + } + + reply, ok = e2.Read() + if test.expectPacketForwarded { if !ok { t.Fatal("expected ICMP Echo packet through outgoing NIC") } checker.IPv4(t, header.IPv4(stack.PayloadSince(reply.Pkt.NetworkHeader())), - checker.SrcAddr(remoteIPv4Addr1), - checker.DstAddr(remoteIPv4Addr2), + checker.SrcAddr(test.sourceAddr), + checker.DstAddr(test.destAddr), checker.TTL(test.TTL-1), checker.IPv4Options(test.forwardedOptions), checker.ICMPv4( @@ -347,6 +409,39 @@ func TestForwarding(t *testing.T) { if n := e1.Drain(); n != 0 { t.Fatalf("got e1.Drain() = %d, want = 0", n) } + } else if ok { + t.Fatalf("expected no ICMP Echo packet through outgoing NIC, instead found: %#v", reply) + 
} + + boolToInt := func(val bool) uint64 { + if val { + return 1 + } + return 0 + } + + if got, want := s.Stats().IP.Forwarding.LinkLocalSource.Value(), boolToInt(test.expectLinkLocalSourceError); got != want { + t.Errorf("got s.Stats().IP.Forwarding.LinkLocalSource.Value() = %d, want = %d", got, want) + } + + if got, want := s.Stats().IP.Forwarding.LinkLocalDestination.Value(), boolToInt(test.expectLinkLocalDestError); got != want { + t.Errorf("got s.Stats().IP.Forwarding.LinkLocalDestination.Value() = %d, want = %d", got, want) + } + + if got, want := s.Stats().IP.MalformedPacketsReceived.Value(), boolToInt(test.icmpType == header.ICMPv4ParamProblem); got != want { + t.Errorf("got s.Stats().IP.MalformedPacketsReceived.Value() = %d, want = %d", got, want) + } + + if got, want := s.Stats().IP.Forwarding.ExhaustedTTL.Value(), boolToInt(test.TTL <= 0); got != want { + t.Errorf("got s.Stats().IP.Forwarding.ExhaustedTTL.Value() = %d, want = %d", got, want) + } + + if got, want := s.Stats().IP.Forwarding.Unrouteable.Value(), boolToInt(test.expectPacketUnrouteableError); got != want { + t.Errorf("got s.Stats().IP.Forwarding.Unrouteable.Value() = %d, want = %d", got, want) + } + + if got, want := s.Stats().IP.Forwarding.Errors.Value(), boolToInt(!test.expectPacketForwarded); got != want { + t.Errorf("got s.Stats().IP.Forwarding.Errors.Value() = %d, want = %d", got, want) } }) } @@ -1241,7 +1336,6 @@ type fragmentInfo struct { var fragmentationTests = []struct { description string mtu uint32 - gso *stack.GSO transportHeaderLength int payloadSize int wantFragments []fragmentInfo @@ -1249,7 +1343,6 @@ var fragmentationTests = []struct { { description: "No fragmentation", mtu: 1280, - gso: nil, transportHeaderLength: 0, payloadSize: 1000, wantFragments: []fragmentInfo{ @@ -1259,7 +1352,6 @@ var fragmentationTests = []struct { { description: "Fragmented", mtu: 1280, - gso: nil, transportHeaderLength: 0, payloadSize: 2000, wantFragments: []fragmentInfo{ @@ -1270,7 +1362,6 @@ var fragmentationTests = []struct { { description: "Fragmented with the minimum mtu", mtu: header.IPv4MinimumMTU, - gso: nil, transportHeaderLength: 0, payloadSize: 100, wantFragments: []fragmentInfo{ @@ -1282,7 +1373,6 @@ var fragmentationTests = []struct { { description: "Fragmented with mtu not a multiple of 8", mtu: header.IPv4MinimumMTU + 1, - gso: nil, transportHeaderLength: 0, payloadSize: 100, wantFragments: []fragmentInfo{ @@ -1294,7 +1384,6 @@ var fragmentationTests = []struct { { description: "No fragmentation with big header", mtu: 2000, - gso: nil, transportHeaderLength: 100, payloadSize: 1000, wantFragments: []fragmentInfo{ @@ -1302,20 +1391,8 @@ var fragmentationTests = []struct { }, }, { - description: "Fragmented with gso none", - mtu: 1280, - gso: &stack.GSO{Type: stack.GSONone}, - transportHeaderLength: 0, - payloadSize: 1400, - wantFragments: []fragmentInfo{ - {offset: 0, payloadSize: 1256, more: true}, - {offset: 1256, payloadSize: 144, more: false}, - }, - }, - { description: "Fragmented with big header", mtu: 1280, - gso: nil, transportHeaderLength: 100, payloadSize: 1200, wantFragments: []fragmentInfo{ @@ -1326,7 +1403,6 @@ var fragmentationTests = []struct { { description: "Fragmented with MTU smaller than header", mtu: 300, - gso: nil, transportHeaderLength: 1000, payloadSize: 500, wantFragments: []fragmentInfo{ @@ -1349,13 +1425,13 @@ func TestFragmentationWritePacket(t *testing.T) { r := buildRoute(t, ep) pkt := testutil.MakeRandPkt(ft.transportHeaderLength, extraHeaderReserve+header.IPv4MinimumSize, 
[]int{ft.payloadSize}, header.IPv4ProtocolNumber) source := pkt.Clone() - err := r.WritePacket(ft.gso, stack.NetworkHeaderParams{ + err := r.WritePacket(stack.NetworkHeaderParams{ Protocol: tcp.ProtocolNumber, TTL: ttl, TOS: stack.DefaultTOS, }, pkt) if err != nil { - t.Fatalf("r.WritePacket(_, _, _) = %s", err) + t.Fatalf("r.WritePacket(...): %s", err) } if got := len(ep.WrittenPackets); got != len(ft.wantFragments) { t.Errorf("got len(ep.WrittenPackets) = %d, want = %d", got, len(ft.wantFragments)) @@ -1421,7 +1497,7 @@ func TestFragmentationWritePackets(t *testing.T) { r := buildRoute(t, ep) wantTotalPackets := len(ft.wantFragments) + test.insertBefore + test.insertAfter - n, err := r.WritePackets(ft.gso, pkts, stack.NetworkHeaderParams{ + n, err := r.WritePackets(pkts, stack.NetworkHeaderParams{ Protocol: tcp.ProtocolNumber, TTL: ttl, TOS: stack.DefaultTOS, @@ -1528,7 +1604,7 @@ func TestFragmentationErrors(t *testing.T) { pkt := testutil.MakeRandPkt(ft.transportHeaderLength, extraHeaderReserve+header.IPv4MinimumSize, []int{ft.payloadSize}, header.IPv4ProtocolNumber) ep := testutil.NewMockLinkEndpoint(ft.mtu, ft.mockError, ft.allowPackets) r := buildRoute(t, ep) - err := r.WritePacket(&stack.GSO{}, stack.NetworkHeaderParams{ + err := r.WritePacket(stack.NetworkHeaderParams{ Protocol: tcp.ProtocolNumber, TTL: ttl, TOS: stack.DefaultTOS, @@ -2612,34 +2688,36 @@ func TestWriteStats(t *testing.T) { const nPackets = 3 tests := []struct { - name string - setup func(*testing.T, *stack.Stack) - allowPackets int - expectSent int - expectDropped int - expectWritten int + name string + setup func(*testing.T, *stack.Stack) + allowPackets int + expectSent int + expectOutputDropped int + expectPostroutingDropped int + expectWritten int }{ { name: "Accept all", // No setup needed, tables accept everything by default. - setup: func(*testing.T, *stack.Stack) {}, - allowPackets: math.MaxInt32, - expectSent: nPackets, - expectDropped: 0, - expectWritten: nPackets, + setup: func(*testing.T, *stack.Stack) {}, + allowPackets: math.MaxInt32, + expectSent: nPackets, + expectOutputDropped: 0, + expectPostroutingDropped: 0, + expectWritten: nPackets, }, { name: "Accept all with error", // No setup needed, tables accept everything by default. - setup: func(*testing.T, *stack.Stack) {}, - allowPackets: nPackets - 1, - expectSent: nPackets - 1, - expectDropped: 0, - expectWritten: nPackets - 1, + setup: func(*testing.T, *stack.Stack) {}, + allowPackets: nPackets - 1, + expectSent: nPackets - 1, + expectOutputDropped: 0, + expectPostroutingDropped: 0, + expectWritten: nPackets - 1, }, { - name: "Drop all", + name: "Drop all with Output chain", setup: func(t *testing.T, stk *stack.Stack) { // Install Output DROP rule. 
- t.Helper() ipt := stk.IPTables() filter := ipt.GetTable(stack.FilterID, false /* ipv6 */) ruleIdx := filter.BuiltinChains[stack.Output] @@ -2648,16 +2726,32 @@ func TestWriteStats(t *testing.T) { t.Fatalf("failed to replace table: %s", err) } }, - allowPackets: math.MaxInt32, - expectSent: 0, - expectDropped: nPackets, - expectWritten: nPackets, + allowPackets: math.MaxInt32, + expectSent: 0, + expectOutputDropped: nPackets, + expectPostroutingDropped: 0, + expectWritten: nPackets, }, { - name: "Drop some", + name: "Drop all with Postrouting chain", + setup: func(t *testing.T, stk *stack.Stack) { + ipt := stk.IPTables() + filter := ipt.GetTable(stack.NATID, false /* ipv6 */) + ruleIdx := filter.BuiltinChains[stack.Postrouting] + filter.Rules[ruleIdx].Target = &stack.DropTarget{} + if err := ipt.ReplaceTable(stack.NATID, filter, false /* ipv6 */); err != nil { + t.Fatalf("failed to replace table: %s", err) + } + }, + allowPackets: math.MaxInt32, + expectSent: 0, + expectOutputDropped: 0, + expectPostroutingDropped: nPackets, + expectWritten: nPackets, + }, { + name: "Drop some with Output chain", setup: func(t *testing.T, stk *stack.Stack) { // Install Output DROP rule that matches only 1 // of the 3 packets. - t.Helper() ipt := stk.IPTables() filter := ipt.GetTable(stack.FilterID, false /* ipv6 */) // We'll match and DROP the last packet. @@ -2670,10 +2764,33 @@ func TestWriteStats(t *testing.T) { t.Fatalf("failed to replace table: %s", err) } }, - allowPackets: math.MaxInt32, - expectSent: nPackets - 1, - expectDropped: 1, - expectWritten: nPackets, + allowPackets: math.MaxInt32, + expectSent: nPackets - 1, + expectOutputDropped: 1, + expectPostroutingDropped: 0, + expectWritten: nPackets, + }, { + name: "Drop some with Postrouting chain", + setup: func(t *testing.T, stk *stack.Stack) { + // Install Postrouting DROP rule that matches only 1 + // of the 3 packets. + ipt := stk.IPTables() + filter := ipt.GetTable(stack.NATID, false /* ipv6 */) + // We'll match and DROP the last packet. + ruleIdx := filter.BuiltinChains[stack.Postrouting] + filter.Rules[ruleIdx].Target = &stack.DropTarget{} + filter.Rules[ruleIdx].Matchers = []stack.Matcher{&limitedMatcher{nPackets - 1}} + // Make sure the next rule is ACCEPT. 
+ filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{} + if err := ipt.ReplaceTable(stack.NATID, filter, false /* ipv6 */); err != nil { + t.Fatalf("failed to replace table: %s", err) + } + }, + allowPackets: math.MaxInt32, + expectSent: nPackets - 1, + expectOutputDropped: 0, + expectPostroutingDropped: 1, + expectWritten: nPackets, }, } @@ -2687,7 +2804,7 @@ func TestWriteStats(t *testing.T) { writePackets: func(rt *stack.Route, pkts stack.PacketBufferList) (int, tcpip.Error) { nWritten := 0 for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() { - if err := rt.WritePacket(nil, stack.NetworkHeaderParams{}, pkt); err != nil { + if err := rt.WritePacket(stack.NetworkHeaderParams{}, pkt); err != nil { return nWritten, err } nWritten++ @@ -2697,7 +2814,7 @@ func TestWriteStats(t *testing.T) { }, { name: "WritePackets", writePackets: func(rt *stack.Route, pkts stack.PacketBufferList) (int, tcpip.Error) { - return rt.WritePackets(nil, pkts, stack.NetworkHeaderParams{}) + return rt.WritePackets(pkts, stack.NetworkHeaderParams{}) }, }, } @@ -2724,13 +2841,16 @@ func TestWriteStats(t *testing.T) { nWritten, _ := writer.writePackets(rt, pkts) if got := int(rt.Stats().IP.PacketsSent.Value()); got != test.expectSent { - t.Errorf("sent %d packets, but expected to send %d", got, test.expectSent) + t.Errorf("got rt.Stats().IP.PacketsSent.Value() = %d, want = %d", got, test.expectSent) } - if got := int(rt.Stats().IP.IPTablesOutputDropped.Value()); got != test.expectDropped { - t.Errorf("dropped %d packets, but expected to drop %d", got, test.expectDropped) + if got := int(rt.Stats().IP.IPTablesOutputDropped.Value()); got != test.expectOutputDropped { + t.Errorf("got rt.Stats().IP.IPTablesOutputDropped.Value() = %d, want = %d", got, test.expectOutputDropped) + } + if got := int(rt.Stats().IP.IPTablesPostroutingDropped.Value()); got != test.expectPostroutingDropped { + t.Errorf("got rt.Stats().IP.IPTablesPostroutingDropped.Value() = %d, want = %d", got, test.expectPostroutingDropped) } if nWritten != test.expectWritten { - t.Errorf("wrote %d packets, but expected WritePackets to return %d", nWritten, test.expectWritten) + t.Errorf("got nWritten = %d, want = %d", nWritten, test.expectWritten) } }) } @@ -2995,12 +3115,14 @@ func TestCloseLocking(t *testing.T) { nicID1 = 1 nicID2 = 2 - src = tcpip.Address("\x10\x00\x00\x01") - dst = tcpip.Address("\x10\x00\x00\x02") - iterations = 1000 ) + var ( + src = tcptestutil.MustParse4("16.0.0.1") + dst = tcptestutil.MustParse4("16.0.0.2") + ) + s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol}, TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}, diff --git a/pkg/tcpip/network/ipv6/BUILD b/pkg/tcpip/network/ipv6/BUILD index bb9a02ed0..db998e83e 100644 --- a/pkg/tcpip/network/ipv6/BUILD +++ b/pkg/tcpip/network/ipv6/BUILD @@ -66,5 +66,6 @@ go_test( "//pkg/tcpip/header", "//pkg/tcpip/link/channel", "//pkg/tcpip/stack", + "//pkg/tcpip/testutil", ], ) diff --git a/pkg/tcpip/network/ipv6/icmp.go b/pkg/tcpip/network/ipv6/icmp.go index a142b76c1..ebb0b73df 100644 --- a/pkg/tcpip/network/ipv6/icmp.go +++ b/pkg/tcpip/network/ipv6/icmp.go @@ -181,10 +181,13 @@ func (e *endpoint) handleControl(transErr stack.TransportError, pkt *stack.Packe return } + // Keep needed information before trimming header. + p := hdr.TransportProtocol() + dstAddr := hdr.DestinationAddress() + // Skip the IP header, then handle the fragmentation header if there // is one. 
- pkt.Data().TrimFront(header.IPv6MinimumSize) - p := hdr.TransportProtocol() + pkt.Data().DeleteFront(header.IPv6MinimumSize) if p == header.IPv6FragmentHeader { f, ok := pkt.Data().PullUp(header.IPv6FragmentHeaderSize) if !ok { @@ -196,14 +199,14 @@ func (e *endpoint) handleControl(transErr stack.TransportError, pkt *stack.Packe // because they don't have the transport headers. return } + p = fragHdr.TransportProtocol() // Skip fragmentation header and find out the actual protocol // number. - pkt.Data().TrimFront(header.IPv6FragmentHeaderSize) - p = fragHdr.TransportProtocol() + pkt.Data().DeleteFront(header.IPv6FragmentHeaderSize) } - e.dispatcher.DeliverTransportError(srcAddr, hdr.DestinationAddress(), ProtocolNumber, p, transErr, pkt) + e.dispatcher.DeliverTransportError(srcAddr, dstAddr, ProtocolNumber, p, transErr, pkt) } // getLinkAddrOption searches NDP options for a given link address option using @@ -273,7 +276,7 @@ func isMLDValid(pkt *stack.PacketBuffer, iph header.IPv6, routerAlert *header.IP if iph.HopLimit() != header.MLDHopLimit { return false } - if !header.IsV6LinkLocalAddress(iph.SourceAddress()) { + if !header.IsV6LinkLocalUnicastAddress(iph.SourceAddress()) { return false } return true @@ -327,11 +330,11 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool, r received.invalid.Increment() return } - pkt.Data().TrimFront(header.ICMPv6PacketTooBigMinimumSize) networkMTU, err := calculateNetworkMTU(header.ICMPv6(hdr).MTU(), header.IPv6MinimumSize) if err != nil { networkMTU = 0 } + pkt.Data().DeleteFront(header.ICMPv6PacketTooBigMinimumSize) e.handleControl(&icmpv6PacketTooBigSockError{mtu: networkMTU}, pkt) case header.ICMPv6DstUnreachable: @@ -341,8 +344,9 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool, r received.invalid.Increment() return } - pkt.Data().TrimFront(header.ICMPv6DstUnreachableMinimumSize) - switch header.ICMPv6(hdr).Code() { + code := header.ICMPv6(hdr).Code() + pkt.Data().DeleteFront(header.ICMPv6DstUnreachableMinimumSize) + switch code { case header.ICMPv6NetworkUnreachable: e.handleControl(&icmpv6DestinationNetworkUnreachableSockError{}, pkt) case header.ICMPv6PortUnreachable: @@ -564,7 +568,7 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool, r // // The IP Hop Limit field has a value of 255, i.e., the packet // could not possibly have been forwarded by a router. - if err := r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: header.NDPHopLimit, TOS: stack.DefaultTOS}, pkt); err != nil { + if err := r.WritePacket(stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: header.NDPHopLimit, TOS: stack.DefaultTOS}, pkt); err != nil { sent.dropped.Increment() return } @@ -704,7 +708,7 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool, r PayloadCsum: dataRange.Checksum(), PayloadLen: dataRange.Size(), })) - if err := r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{ + if err := r.WritePacket(stack.NetworkHeaderParams{ Protocol: header.ICMPv6ProtocolNumber, TTL: r.DefaultTTL(), TOS: stack.DefaultTOS, @@ -804,7 +808,7 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool, r routerAddr := srcAddr // Is the IP Source Address a link-local address? - if !header.IsV6LinkLocalAddress(routerAddr) { + if !header.IsV6LinkLocalUnicastAddress(routerAddr) { // ...No, silently drop the packet. 
received.invalid.Increment() return @@ -951,6 +955,7 @@ func (*endpoint) ResolveStaticAddress(addr tcpip.Address) (tcpip.LinkAddress, bo // icmpReason is a marker interface for IPv6 specific ICMP errors. type icmpReason interface { isICMPReason() + isForwarding() bool } // icmpReasonParameterProblem is an error during processing of extension headers @@ -982,6 +987,9 @@ type icmpReasonParameterProblem struct { } func (*icmpReasonParameterProblem) isICMPReason() {} +func (*icmpReasonParameterProblem) isForwarding() bool { + return false +} // icmpReasonPortUnreachable is an error where the transport protocol has no // listener and no alternative means to inform the sender. @@ -989,12 +997,44 @@ type icmpReasonPortUnreachable struct{} func (*icmpReasonPortUnreachable) isICMPReason() {} +func (*icmpReasonPortUnreachable) isForwarding() bool { + return false +} + +// icmpReasonNetUnreachable is an error where no route can be found to the +// network of the final destination. +type icmpReasonNetUnreachable struct{} + +func (*icmpReasonNetUnreachable) isICMPReason() {} + +func (*icmpReasonNetUnreachable) isForwarding() bool { + // If we hit a Network Unreachable error, then we also know we are + // operating as a router. As per RFC 4443 section 3.1: + // + // If the reason for the failure to deliver is lack of a matching + // entry in the forwarding node's routing table, the Code field is + // set to 0 (Network Unreachable). + return true +} + // icmpReasonHopLimitExceeded is an error where a packet's hop limit exceeded in // transit to its final destination, as per RFC 4443 section 3.3. type icmpReasonHopLimitExceeded struct{} func (*icmpReasonHopLimitExceeded) isICMPReason() {} +func (*icmpReasonHopLimitExceeded) isForwarding() bool { + // If we hit a Hop Limit Exceeded error, then we know we are operating + // as a router. As per RFC 4443 section 3.3: + // + // If a router receives a packet with a Hop Limit of zero, or if a + // router decrements a packet's Hop Limit to zero, it MUST discard + // the packet and originate an ICMPv6 Time Exceeded message with Code + // 0 to the source of the packet. This indicates either a routing + // loop or too small an initial Hop Limit value. + return true +} + // icmpReasonReassemblyTimeout is an error where insufficient fragments are // received to complete reassembly of a packet within a configured time after // the reception of the first-arriving fragment of that packet. @@ -1002,6 +1042,10 @@ type icmpReasonReassemblyTimeout struct{} func (*icmpReasonReassemblyTimeout) isICMPReason() {} +func (*icmpReasonReassemblyTimeout) isForwarding() bool { + return false +} + // returnError takes an error descriptor and generates the appropriate ICMP // error packet for IPv6 and sends it. func (p *protocol) returnError(reason icmpReason, pkt *stack.PacketBuffer) tcpip.Error { @@ -1040,15 +1084,6 @@ func (p *protocol) returnError(reason icmpReason, pkt *stack.PacketBuffer) tcpip return nil } - // If we hit a Hop Limit Exceeded error, then we know we are operating as a - // router. As per RFC 4443 section 3.3: - // - // If a router receives a packet with a Hop Limit of zero, or if a - // router decrements a packet's Hop Limit to zero, it MUST discard the - // packet and originate an ICMPv6 Time Exceeded message with Code 0 to - // the source of the packet. This indicates either a routing loop or - // too small an initial Hop Limit value. 
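The icmpReason changes above move the "was this generated while forwarding?" knowledge from a type switch in returnError onto the reason itself, so a new reason such as icmpReasonNetUnreachable cannot be forgotten by the switch. A stripped-down sketch of that refactor with hypothetical types:

type reason interface {
	isForwarding() bool
}

type hopLimitExceeded struct{} // only raised while acting as a router
type portUnreachable struct{}  // only raised for locally delivered packets

func (hopLimitExceeded) isForwarding() bool { return true }
func (portUnreachable) isForwarding() bool  { return false }

// pickResponseSource mirrors the source-address choice in returnError: a
// forwarding router must not claim the original destination address as its
// own, so it leaves the local address empty and lets route selection pick one.
func pickResponseSource(r reason, origDst string, origDstIsMulticast bool) string {
	if r.isForwarding() || origDstIsMulticast {
		return ""
	}
	return origDst
}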
- // // If we are operating as a router, do not use the packet's destination // address as the response's source address as we should not own the // destination address of a packet we are forwarding. @@ -1058,7 +1093,7 @@ func (p *protocol) returnError(reason icmpReason, pkt *stack.PacketBuffer) tcpip // packet as "multicast addresses must not be used as source addresses in IPv6 // packets", as per RFC 4291 section 2.7. localAddr := origIPHdrDst - if _, ok := reason.(*icmpReasonHopLimitExceeded); ok || isOrigDstMulticast { + if reason.isForwarding() || isOrigDstMulticast { localAddr = "" } // Even if we were able to receive a packet from some remote, we may not have @@ -1147,6 +1182,10 @@ func (p *protocol) returnError(reason icmpReason, pkt *stack.PacketBuffer) tcpip icmpHdr.SetType(header.ICMPv6DstUnreachable) icmpHdr.SetCode(header.ICMPv6PortUnreachable) counter = sent.dstUnreachable + case *icmpReasonNetUnreachable: + icmpHdr.SetType(header.ICMPv6DstUnreachable) + icmpHdr.SetCode(header.ICMPv6NetworkUnreachable) + counter = sent.dstUnreachable case *icmpReasonHopLimitExceeded: icmpHdr.SetType(header.ICMPv6TimeExceeded) icmpHdr.SetCode(header.ICMPv6HopLimitExceeded) @@ -1167,7 +1206,6 @@ func (p *protocol) returnError(reason icmpReason, pkt *stack.PacketBuffer) tcpip PayloadLen: dataRange.Size(), })) if err := route.WritePacket( - nil, /* gso */ stack.NetworkHeaderParams{ Protocol: header.ICMPv6ProtocolNumber, TTL: route.DefaultTTL(), diff --git a/pkg/tcpip/network/ipv6/icmp_test.go b/pkg/tcpip/network/ipv6/icmp_test.go index 6a7705ed1..e457be3cf 100644 --- a/pkg/tcpip/network/ipv6/icmp_test.go +++ b/pkg/tcpip/network/ipv6/icmp_test.go @@ -81,7 +81,7 @@ func (*stubLinkEndpoint) LinkAddress() tcpip.LinkAddress { return "" } -func (*stubLinkEndpoint) WritePacket(stack.RouteInfo, *stack.GSO, tcpip.NetworkProtocolNumber, *stack.PacketBuffer) tcpip.Error { +func (*stubLinkEndpoint) WritePacket(stack.RouteInfo, tcpip.NetworkProtocolNumber, *stack.PacketBuffer) tcpip.Error { return nil } @@ -130,19 +130,19 @@ func (*testInterface) Spoofing() bool { return false } -func (t *testInterface) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error { - return t.LinkEndpoint.WritePacket(r.Fields(), gso, protocol, pkt) +func (t *testInterface) WritePacket(r *stack.Route, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error { + return t.LinkEndpoint.WritePacket(r.Fields(), protocol, pkt) } -func (t *testInterface) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, tcpip.Error) { - return t.LinkEndpoint.WritePackets(r.Fields(), gso, pkts, protocol) +func (t *testInterface) WritePackets(r *stack.Route, pkts stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, tcpip.Error) { + return t.LinkEndpoint.WritePackets(r.Fields(), pkts, protocol) } -func (t *testInterface) WritePacketToRemote(remoteLinkAddr tcpip.LinkAddress, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error { +func (t *testInterface) WritePacketToRemote(remoteLinkAddr tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) tcpip.Error { var r stack.RouteInfo r.NetProto = protocol r.RemoteLinkAddress = remoteLinkAddr - return t.LinkEndpoint.WritePacket(r, gso, protocol, pkt) + return t.LinkEndpoint.WritePacket(r, protocol, pkt) } func (t *testInterface) HandleNeighborProbe(tcpip.NetworkProtocolNumber, tcpip.Address, 
tcpip.LinkAddress) tcpip.Error { diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go index c6d9d8f0d..659057fa7 100644 --- a/pkg/tcpip/network/ipv6/ipv6.go +++ b/pkg/tcpip/network/ipv6/ipv6.go @@ -314,7 +314,7 @@ func (e *endpoint) onAddressAssignedLocked(addr tcpip.Address) { // Snooping switches MUST manage multicast forwarding state based on MLD // Report and Done messages sent with the unspecified address as the // IPv6 source address. - if header.IsV6LinkLocalAddress(addr) { + if header.IsV6LinkLocalUnicastAddress(addr) { e.mu.mld.sendQueuedReports() } } @@ -410,24 +410,52 @@ func (e *endpoint) dupTentativeAddrDetected(addr tcpip.Address, holderLinkAddr t // // Must only be called when the forwarding status changes. func (e *endpoint) transitionForwarding(forwarding bool) { + allRoutersGroups := [...]tcpip.Address{ + header.IPv6AllRoutersInterfaceLocalMulticastAddress, + header.IPv6AllRoutersLinkLocalMulticastAddress, + header.IPv6AllRoutersSiteLocalMulticastAddress, + } + e.mu.Lock() defer e.mu.Unlock() - if !e.Enabled() { - return - } - if forwarding { - // When transitioning into an IPv6 router, host-only state (NDP discovered - // routers, discovered on-link prefixes, and auto-generated addresses) is - // cleaned up/invalidated and NDP router solicitations are stopped. - e.mu.ndp.stopSolicitingRouters() - e.mu.ndp.cleanupState(true /* hostOnly */) + // As per RFC 4291 section 2.8: + // + // A router is required to recognize all addresses that a host is + // required to recognize, plus the following addresses as identifying + // itself: + // + // o The All-Routers multicast addresses defined in Section 2.7.1. + // + // As per RFC 4291 section 2.7.1, + // + // All Routers Addresses: FF01:0:0:0:0:0:0:2 + // FF02:0:0:0:0:0:0:2 + // FF05:0:0:0:0:0:0:2 + // + // The above multicast addresses identify the group of all IPv6 routers, + // within scope 1 (interface-local), 2 (link-local), or 5 (site-local). + for _, g := range allRoutersGroups { + if err := e.joinGroupLocked(g); err != nil { + // joinGroupLocked only returns an error if the group address is not a + // valid IPv6 multicast address. + panic(fmt.Sprintf("e.joinGroupLocked(%s): %s", g, err)) + } + } } else { - // When transitioning into an IPv6 host, NDP router solicitations are - // started. - e.mu.ndp.startSolicitingRouters() + for _, g := range allRoutersGroups { + switch err := e.leaveGroupLocked(g).(type) { + case nil: + case *tcpip.ErrBadLocalAddress: + // The endpoint may have already left the multicast group. + default: + panic(fmt.Sprintf("e.leaveGroupLocked(%s): %s", g, err)) + } + } } + + e.mu.ndp.forwardingChanged(forwarding) } // Enable implements stack.NetworkEndpoint. @@ -509,17 +537,7 @@ func (e *endpoint) Enable() tcpip.Error { e.mu.ndp.doSLAAC(header.IPv6LinkLocalPrefix.Subnet(), header.NDPInfiniteLifetime, header.NDPInfiniteLifetime) } - // If we are operating as a router, then do not solicit routers since we - // won't process the RAs anyway. - // - // Routers do not process Router Advertisements (RA) the same way a host - // does. That is, routers do not learn from RAs (e.g. on-link prefixes - // and default routers). Therefore, soliciting RAs from other routers on - // a link is unnecessary for routers. 
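The IPv6 transitionForwarding above joins the three All-Routers groups defined in RFC 4291 section 2.7.1. For reference, the same addresses parsed with the standard net/netip package (a standalone sketch; the netstack code uses the header.IPv6AllRouters*MulticastAddress constants shown above):

import "net/netip"

var allRouters = []netip.Addr{
	netip.MustParseAddr("ff01::2"), // interface-local scope
	netip.MustParseAddr("ff02::2"), // link-local scope
	netip.MustParseAddr("ff05::2"), // site-local scope
}

// joinAllRouters calls join once per group, stopping at the first failure.
func joinAllRouters(join func(netip.Addr) error) error {
	for _, g := range allRouters {
		if err := join(g); err != nil {
			return err
		}
	}
	return nil
}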
- if !e.protocol.Forwarding() { - e.mu.ndp.startSolicitingRouters() - } - + e.mu.ndp.startSolicitingRouters() return nil } @@ -570,10 +588,10 @@ func (e *endpoint) disableLocked() { return true }) - e.mu.ndp.cleanupState(false /* hostOnly */) + e.mu.ndp.cleanupState() // The endpoint may have already left the multicast group. - switch err := e.leaveGroupLocked(header.IPv6AllNodesMulticastAddress); err.(type) { + switch err := e.leaveGroupLocked(header.IPv6AllNodesMulticastAddress).(type) { case nil, *tcpip.ErrBadLocalAddress: default: panic(fmt.Sprintf("unexpected error when leaving group = %s: %s", header.IPv6AllNodesMulticastAddress, err)) @@ -632,9 +650,9 @@ func addIPHeader(srcAddr, dstAddr tcpip.Address, pkt *stack.PacketBuffer, params return nil } -func packetMustBeFragmented(pkt *stack.PacketBuffer, networkMTU uint32, gso *stack.GSO) bool { +func packetMustBeFragmented(pkt *stack.PacketBuffer, networkMTU uint32) bool { payload := pkt.TransportHeader().View().Size() + pkt.Data().Size() - return (gso == nil || gso.Type == stack.GSONone) && uint32(payload) > networkMTU + return pkt.GSOOptions.Type == stack.GSONone && uint32(payload) > networkMTU } // handleFragments fragments pkt and calls the handler function on each @@ -642,7 +660,7 @@ func packetMustBeFragmented(pkt *stack.PacketBuffer, networkMTU uint32, gso *sta // fragments left to be processed. The IP header must already be present in the // original packet. The transport header protocol number is required to avoid // parsing the IPv6 extension headers. -func (e *endpoint) handleFragments(r *stack.Route, gso *stack.GSO, networkMTU uint32, pkt *stack.PacketBuffer, transProto tcpip.TransportProtocolNumber, handler func(*stack.PacketBuffer) tcpip.Error) (int, int, tcpip.Error) { +func (e *endpoint) handleFragments(r *stack.Route, networkMTU uint32, pkt *stack.PacketBuffer, transProto tcpip.TransportProtocolNumber, handler func(*stack.PacketBuffer) tcpip.Error) (int, int, tcpip.Error) { networkHeader := header.IPv6(pkt.NetworkHeader().View()) // TODO(gvisor.dev/issue/3912): Once the Authentication or ESP Headers are @@ -681,7 +699,7 @@ func (e *endpoint) handleFragments(r *stack.Route, gso *stack.GSO, networkMTU ui } // WritePacket writes a packet to the given destination address and protocol. -func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.NetworkHeaderParams, pkt *stack.PacketBuffer) tcpip.Error { +func (e *endpoint) WritePacket(r *stack.Route, params stack.NetworkHeaderParams, pkt *stack.PacketBuffer) tcpip.Error { if err := addIPHeader(r.LocalAddress(), r.RemoteAddress(), pkt, params, nil /* extensionHeaders */); err != nil { return err } @@ -689,7 +707,7 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.Netw // iptables filtering. All packets that reach here are locally // generated. outNicName := e.protocol.stack.FindNICNameFromID(e.nic.ID()) - if ok := e.protocol.stack.IPTables().Check(stack.Output, pkt, gso, r, "" /* preroutingAddr */, "" /* inNicName */, outNicName); !ok { + if ok := e.protocol.stack.IPTables().Check(stack.Output, pkt, r, "" /* preroutingAddr */, "" /* inNicName */, outNicName); !ok { // iptables is telling us to drop the packet. 
e.stats.ip.IPTablesOutputDropped.Increment() return nil @@ -712,10 +730,10 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.Netw } } - return e.writePacket(r, gso, pkt, params.Protocol, false /* headerIncluded */) + return e.writePacket(r, pkt, params.Protocol, false /* headerIncluded */) } -func (e *endpoint) writePacket(r *stack.Route, gso *stack.GSO, pkt *stack.PacketBuffer, protocol tcpip.TransportProtocolNumber, headerIncluded bool) tcpip.Error { +func (e *endpoint) writePacket(r *stack.Route, pkt *stack.PacketBuffer, protocol tcpip.TransportProtocolNumber, headerIncluded bool) tcpip.Error { if r.Loop()&stack.PacketLoop != 0 { // If the packet was generated by the stack (not a raw/packet endpoint // where a packet may be written with the header included), then we can @@ -726,6 +744,15 @@ func (e *endpoint) writePacket(r *stack.Route, gso *stack.GSO, pkt *stack.Packet return nil } + // Postrouting NAT can only change the source address, and does not alter the + // route or outgoing interface of the packet. + outNicName := e.protocol.stack.FindNICNameFromID(e.nic.ID()) + if ok := e.protocol.stack.IPTables().Check(stack.Postrouting, pkt, r, "" /* preroutingAddr */, "" /* inNicName */, outNicName); !ok { + // iptables is telling us to drop the packet. + e.stats.ip.IPTablesPostroutingDropped.Increment() + return nil + } + stats := e.stats.ip networkMTU, err := calculateNetworkMTU(e.nic.MTU(), uint32(pkt.NetworkHeader().View().Size())) if err != nil { @@ -733,20 +760,20 @@ func (e *endpoint) writePacket(r *stack.Route, gso *stack.GSO, pkt *stack.Packet return err } - if packetMustBeFragmented(pkt, networkMTU, gso) { - sent, remain, err := e.handleFragments(r, gso, networkMTU, pkt, protocol, func(fragPkt *stack.PacketBuffer) tcpip.Error { + if packetMustBeFragmented(pkt, networkMTU) { + sent, remain, err := e.handleFragments(r, networkMTU, pkt, protocol, func(fragPkt *stack.PacketBuffer) tcpip.Error { // TODO(gvisor.dev/issue/3884): Evaluate whether we want to send each // fragment one by one using WritePacket() (current strategy) or if we // want to create a PacketBufferList from the fragments and feed it to // WritePackets(). It'll be faster but cost more memory. - return e.nic.WritePacket(r, gso, ProtocolNumber, fragPkt) + return e.nic.WritePacket(r, ProtocolNumber, fragPkt) }) stats.PacketsSent.IncrementBy(uint64(sent)) stats.OutgoingPacketErrors.IncrementBy(uint64(remain)) return err } - if err := e.nic.WritePacket(r, gso, ProtocolNumber, pkt); err != nil { + if err := e.nic.WritePacket(r, ProtocolNumber, pkt); err != nil { stats.OutgoingPacketErrors.Increment() return err } @@ -756,7 +783,7 @@ func (e *endpoint) writePacket(r *stack.Route, gso *stack.GSO, pkt *stack.Packet } // WritePackets implements stack.NetworkEndpoint.WritePackets. 
-func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.PacketBufferList, params stack.NetworkHeaderParams) (int, tcpip.Error) { +func (e *endpoint) WritePackets(r *stack.Route, pkts stack.PacketBufferList, params stack.NetworkHeaderParams) (int, tcpip.Error) { if r.Loop()&stack.PacketLoop != 0 { panic("not implemented") } @@ -776,11 +803,11 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe stats.OutgoingPacketErrors.IncrementBy(uint64(pkts.Len())) return 0, err } - if packetMustBeFragmented(pb, networkMTU, gso) { + if packetMustBeFragmented(pb, networkMTU) { // Keep track of the packet that is about to be fragmented so it can be // removed once the fragmentation is done. originalPkt := pb - if _, _, err := e.handleFragments(r, gso, networkMTU, pb, params.Protocol, func(fragPkt *stack.PacketBuffer) tcpip.Error { + if _, _, err := e.handleFragments(r, networkMTU, pb, params.Protocol, func(fragPkt *stack.PacketBuffer) tcpip.Error { // Modify the packet list in place with the new fragments. pkts.InsertAfter(pb, fragPkt) pb = fragPkt @@ -797,9 +824,9 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe // iptables filtering. All packets that reach here are locally // generated. outNicName := e.protocol.stack.FindNICNameFromID(e.nic.ID()) - dropped, natPkts := e.protocol.stack.IPTables().CheckPackets(stack.Output, pkts, gso, r, "" /* inNicName */, outNicName) - stats.IPTablesOutputDropped.IncrementBy(uint64(len(dropped))) - for pkt := range dropped { + outputDropped, natPkts := e.protocol.stack.IPTables().CheckPackets(stack.Output, pkts, r, "" /* inNicName */, outNicName) + stats.IPTablesOutputDropped.IncrementBy(uint64(len(outputDropped))) + for pkt := range outputDropped { pkts.Remove(pkt) } @@ -820,14 +847,23 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe locallyDelivered++ } + // We ignore the list of NAT-ed packets here because Postrouting NAT can only + // change the source address, and does not alter the route or outgoing + // interface of the packet. + postroutingDropped, _ := e.protocol.stack.IPTables().CheckPackets(stack.Postrouting, pkts, r, "" /* inNicName */, outNicName) + stats.IPTablesPostroutingDropped.IncrementBy(uint64(len(postroutingDropped))) + for pkt := range postroutingDropped { + pkts.Remove(pkt) + } + // The rest of the packets can be delivered to the NIC as a batch. pktsLen := pkts.Len() - written, err := e.nic.WritePackets(r, gso, pkts, ProtocolNumber) + written, err := e.nic.WritePackets(r, pkts, ProtocolNumber) stats.PacketsSent.IncrementBy(uint64(written)) stats.OutgoingPacketErrors.IncrementBy(uint64(pktsLen - written)) // Dropped packets aren't errors, so include them in the return value. - return locallyDelivered + written + len(dropped), err + return locallyDelivered + written + len(outputDropped) + len(postroutingDropped), err } // WriteHeaderIncludedPacket implements stack.NetworkEndpoint. @@ -863,12 +899,25 @@ func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack.PacketBu return &tcpip.ErrMalformedHeader{} } - return e.writePacket(r, nil /* gso */, pkt, proto, true /* headerIncluded */) + return e.writePacket(r, pkt, proto, true /* headerIncluded */) } // forwardPacket attempts to forward a packet to its final destination. 
-func (e *endpoint) forwardPacket(pkt *stack.PacketBuffer) tcpip.Error { +func (e *endpoint) forwardPacket(pkt *stack.PacketBuffer) ip.ForwardingError { h := header.IPv6(pkt.NetworkHeader().View()) + + dstAddr := h.DestinationAddress() + // As per RFC 4291 section 2.5.6, + // + // Routers must not forward any packets with Link-Local source or + // destination addresses to other links. + if header.IsV6LinkLocalUnicastAddress(h.SourceAddress()) { + return &ip.ErrLinkLocalSourceAddress{} + } + if header.IsV6LinkLocalUnicastAddress(dstAddr) || header.IsV6LinkLocalMulticastAddress(dstAddr) { + return &ip.ErrLinkLocalDestinationAddress{} + } + hopLimit := h.HopLimit() if hopLimit <= 1 { // As per RFC 4443 section 3.3, @@ -878,11 +927,14 @@ func (e *endpoint) forwardPacket(pkt *stack.PacketBuffer) tcpip.Error { // packet and originate an ICMPv6 Time Exceeded message with Code 0 to // the source of the packet. This indicates either a routing loop or // too small an initial Hop Limit value. - return e.protocol.returnError(&icmpReasonHopLimitExceeded{}, pkt) + // + // We return the original error rather than the result of returning + // the ICMP packet because the original error is more relevant to + // the caller. + _ = e.protocol.returnError(&icmpReasonHopLimitExceeded{}, pkt) + return &ip.ErrTTLExceeded{} } - dstAddr := h.DestinationAddress() - // Check if the destination is owned by the stack. if ep := e.protocol.findEndpointWithAddress(dstAddr); ep != nil { ep.handleValidatedPacket(h, pkt) @@ -890,8 +942,16 @@ func (e *endpoint) forwardPacket(pkt *stack.PacketBuffer) tcpip.Error { } r, err := e.protocol.stack.FindRoute(0, "", dstAddr, ProtocolNumber, false /* multicastLoop */) - if err != nil { - return err + switch err.(type) { + case nil: + case *tcpip.ErrNoRoute, *tcpip.ErrNetworkUnreachable: + // We return the original error rather than the result of returning + // the ICMP packet because the original error is more relevant to + // the caller. + _ = e.protocol.returnError(&icmpReasonNetUnreachable{}, pkt) + return &ip.ErrNoRoute{} + default: + return &ip.ErrOther{Err: err} } defer r.Release() @@ -906,10 +966,13 @@ func (e *endpoint) forwardPacket(pkt *stack.PacketBuffer) tcpip.Error { // each node that forwards the packet. newHdr.SetHopLimit(hopLimit - 1) - return r.WriteHeaderIncludedPacket(stack.NewPacketBuffer(stack.PacketBufferOptions{ + if err := r.WriteHeaderIncludedPacket(stack.NewPacketBuffer(stack.PacketBufferOptions{ ReserveHeaderBytes: int(r.MaxHeaderLength()), Data: buffer.View(newHdr).ToVectorisedView(), - })) + })); err != nil { + return &ip.ErrOther{Err: err} + } + return nil } // HandlePacket is called by the link layer when new ipv6 packets arrive for @@ -958,7 +1021,7 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) { // Loopback traffic skips the prerouting chain. inNicName := e.protocol.stack.FindNICNameFromID(e.nic.ID()) - if ok := e.protocol.stack.IPTables().Check(stack.Prerouting, pkt, nil, nil, e.MainAddress().Address, inNicName, "" /* outNicName */); !ok { + if ok := e.protocol.stack.IPTables().Check(stack.Prerouting, pkt, nil, e.MainAddress().Address, inNicName, "" /* outNicName */); !ok { // iptables is telling us to drop the packet. 
stats.IPTablesPreroutingDropped.Increment() return @@ -1010,8 +1073,21 @@ func (e *endpoint) handleValidatedPacket(h header.IPv6, pkt *stack.PacketBuffer) stats.InvalidDestinationAddressesReceived.Increment() return } - - _ = e.forwardPacket(pkt) + switch err := e.forwardPacket(pkt); err.(type) { + case nil: + return + case *ip.ErrLinkLocalSourceAddress: + e.stats.ip.Forwarding.LinkLocalSource.Increment() + case *ip.ErrLinkLocalDestinationAddress: + e.stats.ip.Forwarding.LinkLocalDestination.Increment() + case *ip.ErrTTLExceeded: + e.stats.ip.Forwarding.ExhaustedTTL.Increment() + case *ip.ErrNoRoute: + e.stats.ip.Forwarding.Unrouteable.Increment() + default: + panic(fmt.Sprintf("unexpected error %s while trying to forward packet: %#v", err, pkt)) + } + e.stats.ip.Forwarding.Errors.Increment() return } @@ -1028,7 +1104,7 @@ func (e *endpoint) handleValidatedPacket(h header.IPv6, pkt *stack.PacketBuffer) // iptables filtering. All packets that reach here are intended for // this machine and need not be forwarded. inNicName := e.protocol.stack.FindNICNameFromID(e.nic.ID()) - if ok := e.protocol.stack.IPTables().Check(stack.Input, pkt, nil, nil, "" /* preroutingAddr */, inNicName, "" /* outNicName */); !ok { + if ok := e.protocol.stack.IPTables().Check(stack.Input, pkt, nil, "" /* preroutingAddr */, inNicName, "" /* outNicName */); !ok { // iptables is telling us to drop the packet. stats.IPTablesInputDropped.Increment() return @@ -1571,7 +1647,7 @@ func (e *endpoint) getLinkLocalAddressRLocked() tcpip.Address { var linkLocalAddr tcpip.Address e.mu.addressableEndpointState.ForEachPrimaryEndpoint(func(addressEndpoint stack.AddressEndpoint) bool { if addressEndpoint.IsAssigned(false /* allowExpired */) { - if addr := addressEndpoint.AddressWithPrefix().Address; header.IsV6LinkLocalAddress(addr) { + if addr := addressEndpoint.AddressWithPrefix().Address; header.IsV6LinkLocalUnicastAddress(addr) { linkLocalAddr = addr return false } @@ -1979,9 +2055,9 @@ func (p *protocol) Forwarding() bool { // Returns true if the forwarding status was updated. func (p *protocol) setForwarding(v bool) bool { if v { - return atomic.SwapUint32(&p.forwarding, 1) == 0 + return atomic.CompareAndSwapUint32(&p.forwarding, 0 /* old */, 1 /* new */) } - return atomic.SwapUint32(&p.forwarding, 0) == 1 + return atomic.CompareAndSwapUint32(&p.forwarding, 1 /* old */, 0 /* new */) } // SetForwarding implements stack.ForwardingNetworkProtocol. diff --git a/pkg/tcpip/network/ipv6/ipv6_test.go b/pkg/tcpip/network/ipv6/ipv6_test.go index c206cebeb..4fbe39528 100644 --- a/pkg/tcpip/network/ipv6/ipv6_test.go +++ b/pkg/tcpip/network/ipv6/ipv6_test.go @@ -2468,34 +2468,36 @@ func TestFragmentReassemblyTimeout(t *testing.T) { func TestWriteStats(t *testing.T) { const nPackets = 3 tests := []struct { - name string - setup func(*testing.T, *stack.Stack) - allowPackets int - expectSent int - expectDropped int - expectWritten int + name string + setup func(*testing.T, *stack.Stack) + allowPackets int + expectSent int + expectOutputDropped int + expectPostroutingDropped int + expectWritten int }{ { name: "Accept all", // No setup needed, tables accept everything by default. 
- setup: func(*testing.T, *stack.Stack) {}, - allowPackets: math.MaxInt32, - expectSent: nPackets, - expectDropped: 0, - expectWritten: nPackets, + setup: func(*testing.T, *stack.Stack) {}, + allowPackets: math.MaxInt32, + expectSent: nPackets, + expectOutputDropped: 0, + expectPostroutingDropped: 0, + expectWritten: nPackets, }, { name: "Accept all with error", // No setup needed, tables accept everything by default. - setup: func(*testing.T, *stack.Stack) {}, - allowPackets: nPackets - 1, - expectSent: nPackets - 1, - expectDropped: 0, - expectWritten: nPackets - 1, + setup: func(*testing.T, *stack.Stack) {}, + allowPackets: nPackets - 1, + expectSent: nPackets - 1, + expectOutputDropped: 0, + expectPostroutingDropped: 0, + expectWritten: nPackets - 1, }, { - name: "Drop all", + name: "Drop all with Output chain", setup: func(t *testing.T, stk *stack.Stack) { // Install Output DROP rule. - t.Helper() ipt := stk.IPTables() filter := ipt.GetTable(stack.FilterID, true /* ipv6 */) ruleIdx := filter.BuiltinChains[stack.Output] @@ -2504,16 +2506,33 @@ func TestWriteStats(t *testing.T) { t.Fatalf("failed to replace table: %v", err) } }, - allowPackets: math.MaxInt32, - expectSent: 0, - expectDropped: nPackets, - expectWritten: nPackets, + allowPackets: math.MaxInt32, + expectSent: 0, + expectOutputDropped: nPackets, + expectPostroutingDropped: 0, + expectWritten: nPackets, }, { - name: "Drop some", + name: "Drop all with Postrouting chain", + setup: func(t *testing.T, stk *stack.Stack) { + // Install Postrouting DROP rule. + ipt := stk.IPTables() + filter := ipt.GetTable(stack.NATID, true /* ipv6 */) + ruleIdx := filter.BuiltinChains[stack.Postrouting] + filter.Rules[ruleIdx].Target = &stack.DropTarget{} + if err := ipt.ReplaceTable(stack.NATID, filter, true /* ipv6 */); err != nil { + t.Fatalf("failed to replace table: %v", err) + } + }, + allowPackets: math.MaxInt32, + expectSent: 0, + expectOutputDropped: 0, + expectPostroutingDropped: nPackets, + expectWritten: nPackets, + }, { + name: "Drop some with Output chain", setup: func(t *testing.T, stk *stack.Stack) { // Install Output DROP rule that matches only 1 // of the 3 packets. - t.Helper() ipt := stk.IPTables() filter := ipt.GetTable(stack.FilterID, true /* ipv6 */) // We'll match and DROP the last packet. @@ -2526,10 +2545,33 @@ func TestWriteStats(t *testing.T) { t.Fatalf("failed to replace table: %v", err) } }, - allowPackets: math.MaxInt32, - expectSent: nPackets - 1, - expectDropped: 1, - expectWritten: nPackets, + allowPackets: math.MaxInt32, + expectSent: nPackets - 1, + expectOutputDropped: 1, + expectPostroutingDropped: 0, + expectWritten: nPackets, + }, { + name: "Drop some with Postrouting chain", + setup: func(t *testing.T, stk *stack.Stack) { + // Install Postrouting DROP rule that matches only 1 + // of the 3 packets. + ipt := stk.IPTables() + filter := ipt.GetTable(stack.NATID, true /* ipv6 */) + // We'll match and DROP the last packet. + ruleIdx := filter.BuiltinChains[stack.Postrouting] + filter.Rules[ruleIdx].Target = &stack.DropTarget{} + filter.Rules[ruleIdx].Matchers = []stack.Matcher{&limitedMatcher{nPackets - 1}} + // Make sure the next rule is ACCEPT.
+ filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{} + if err := ipt.ReplaceTable(stack.NATID, filter, true /* ipv6 */); err != nil { + t.Fatalf("failed to replace table: %v", err) + } + }, + allowPackets: math.MaxInt32, + expectSent: nPackets - 1, + expectOutputDropped: 0, + expectPostroutingDropped: 1, + expectWritten: nPackets, }, } @@ -2542,7 +2584,7 @@ func TestWriteStats(t *testing.T) { writePackets: func(rt *stack.Route, pkts stack.PacketBufferList) (int, tcpip.Error) { nWritten := 0 for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() { - if err := rt.WritePacket(nil, stack.NetworkHeaderParams{}, pkt); err != nil { + if err := rt.WritePacket(stack.NetworkHeaderParams{}, pkt); err != nil { return nWritten, err } nWritten++ @@ -2552,7 +2594,7 @@ func TestWriteStats(t *testing.T) { }, { name: "WritePackets", writePackets: func(rt *stack.Route, pkts stack.PacketBufferList) (int, tcpip.Error) { - return rt.WritePackets(nil, pkts, stack.NetworkHeaderParams{}) + return rt.WritePackets(pkts, stack.NetworkHeaderParams{}) }, }, } @@ -2578,13 +2620,16 @@ func TestWriteStats(t *testing.T) { nWritten, _ := writer.writePackets(rt, pkts) if got := int(rt.Stats().IP.PacketsSent.Value()); got != test.expectSent { - t.Errorf("sent %d packets, but expected to send %d", got, test.expectSent) + t.Errorf("got rt.Stats().IP.PacketsSent.Value() = %d, want = %d", got, test.expectSent) } - if got := int(rt.Stats().IP.IPTablesOutputDropped.Value()); got != test.expectDropped { - t.Errorf("dropped %d packets, but expected to drop %d", got, test.expectDropped) + if got := int(rt.Stats().IP.IPTablesOutputDropped.Value()); got != test.expectOutputDropped { + t.Errorf("got rt.Stats().IP.IPTablesOutputDropped.Value() = %d, want = %d", got, test.expectOutputDropped) + } + if got := int(rt.Stats().IP.IPTablesPostroutingDropped.Value()); got != test.expectPostroutingDropped { + t.Errorf("got r.Stats().IP.IPTablesPostroutingDropped.Value() = %d, want = %d", got, test.expectPostroutingDropped) } if nWritten != test.expectWritten { - t.Errorf("wrote %d packets, but expected WritePackets to return %d", nWritten, test.expectWritten) + t.Errorf("got nWritten = %d, want = %d", nWritten, test.expectWritten) } }) } @@ -2694,7 +2739,6 @@ type fragmentInfo struct { var fragmentationTests = []struct { description string mtu uint32 - gso *stack.GSO transHdrLen int payloadSize int wantFragments []fragmentInfo @@ -2702,7 +2746,6 @@ var fragmentationTests = []struct { { description: "No fragmentation", mtu: header.IPv6MinimumMTU, - gso: nil, transHdrLen: 0, payloadSize: 1000, wantFragments: []fragmentInfo{ @@ -2712,7 +2755,6 @@ var fragmentationTests = []struct { { description: "Fragmented", mtu: header.IPv6MinimumMTU, - gso: nil, transHdrLen: 0, payloadSize: 2000, wantFragments: []fragmentInfo{ @@ -2723,7 +2765,6 @@ var fragmentationTests = []struct { { description: "Fragmented with mtu not a multiple of 8", mtu: header.IPv6MinimumMTU + 1, - gso: nil, transHdrLen: 0, payloadSize: 2000, wantFragments: []fragmentInfo{ @@ -2734,7 +2775,6 @@ var fragmentationTests = []struct { { description: "No fragmentation with big header", mtu: 2000, - gso: nil, transHdrLen: 100, payloadSize: 1000, wantFragments: []fragmentInfo{ @@ -2742,20 +2782,8 @@ var fragmentationTests = []struct { }, }, { - description: "Fragmented with gso none", - mtu: header.IPv6MinimumMTU, - gso: &stack.GSO{Type: stack.GSONone}, - transHdrLen: 0, - payloadSize: 1400, - wantFragments: []fragmentInfo{ - {offset: 0, payloadSize: 1240, more: true}, - {offset: 154, 
payloadSize: 176, more: false}, - }, - }, - { description: "Fragmented with big header", mtu: header.IPv6MinimumMTU, - gso: nil, transHdrLen: 100, payloadSize: 1200, wantFragments: []fragmentInfo{ @@ -2778,7 +2806,7 @@ func TestFragmentationWritePacket(t *testing.T) { source := pkt.Clone() ep := testutil.NewMockLinkEndpoint(ft.mtu, nil, math.MaxInt32) r := buildRoute(t, ep) - err := r.WritePacket(ft.gso, stack.NetworkHeaderParams{ + err := r.WritePacket(stack.NetworkHeaderParams{ Protocol: tcp.ProtocolNumber, TTL: ttl, TOS: stack.DefaultTOS, @@ -2851,7 +2879,7 @@ func TestFragmentationWritePackets(t *testing.T) { r := buildRoute(t, ep) wantTotalPackets := len(ft.wantFragments) + test.insertBefore + test.insertAfter - n, err := r.WritePackets(ft.gso, pkts, stack.NetworkHeaderParams{ + n, err := r.WritePackets(pkts, stack.NetworkHeaderParams{ Protocol: tcp.ProtocolNumber, TTL: ttl, TOS: stack.DefaultTOS, @@ -2955,7 +2983,7 @@ func TestFragmentationErrors(t *testing.T) { pkt := testutil.MakeRandPkt(ft.transHdrLen, extraHeaderReserve+header.IPv6MinimumSize, []int{ft.payloadSize}, header.IPv6ProtocolNumber) ep := testutil.NewMockLinkEndpoint(ft.mtu, ft.mockError, ft.allowPackets) r := buildRoute(t, ep) - err := r.WritePacket(&stack.GSO{}, stack.NetworkHeaderParams{ + err := r.WritePacket(stack.NetworkHeaderParams{ Protocol: tcp.ProtocolNumber, TTL: ttl, TOS: stack.DefaultTOS, @@ -2991,36 +3019,94 @@ func TestForwarding(t *testing.T) { } remoteIPv6Addr1 := tcpip.Address(net.ParseIP("10::2").To16()) remoteIPv6Addr2 := tcpip.Address(net.ParseIP("11::2").To16()) + unreachableIPv6Addr := tcpip.Address(net.ParseIP("12::2").To16()) + multicastIPv6Addr := tcpip.Address(net.ParseIP("ff00::").To16()) + linkLocalIPv6Addr := tcpip.Address(net.ParseIP("fe80::").To16()) tests := []struct { - name string - TTL uint8 - expectErrorICMP bool + name string + TTL uint8 + expectErrorICMP bool + expectPacketForwarded bool + countUnrouteablePackets uint64 + sourceAddr tcpip.Address + destAddr tcpip.Address + icmpType header.ICMPv6Type + icmpCode header.ICMPv6Code + expectPacketUnrouteableError bool + expectLinkLocalSourceError bool + expectLinkLocalDestError bool }{ { name: "TTL of zero", TTL: 0, expectErrorICMP: true, + sourceAddr: remoteIPv6Addr1, + destAddr: remoteIPv6Addr2, + icmpType: header.ICMPv6TimeExceeded, + icmpCode: header.ICMPv6HopLimitExceeded, }, { name: "TTL of one", TTL: 1, expectErrorICMP: true, + sourceAddr: remoteIPv6Addr1, + destAddr: remoteIPv6Addr2, + icmpType: header.ICMPv6TimeExceeded, + icmpCode: header.ICMPv6HopLimitExceeded, + }, + { + name: "TTL of two", + TTL: 2, + expectPacketForwarded: true, + sourceAddr: remoteIPv6Addr1, + destAddr: remoteIPv6Addr2, + }, + { + name: "TTL of three", + TTL: 3, + expectPacketForwarded: true, + sourceAddr: remoteIPv6Addr1, + destAddr: remoteIPv6Addr2, }, { - name: "TTL of two", - TTL: 2, - expectErrorICMP: false, + name: "Max TTL", + TTL: math.MaxUint8, + expectPacketForwarded: true, + sourceAddr: remoteIPv6Addr1, + destAddr: remoteIPv6Addr2, }, { - name: "TTL of three", - TTL: 3, - expectErrorICMP: false, + name: "Network unreachable", + TTL: 2, + expectErrorICMP: true, + sourceAddr: remoteIPv6Addr1, + destAddr: unreachableIPv6Addr, + icmpType: header.ICMPv6DstUnreachable, + icmpCode: header.ICMPv6NetworkUnreachable, + expectPacketUnrouteableError: true, }, { - name: "Max TTL", - TTL: math.MaxUint8, - expectErrorICMP: false, + name: "Multicast destination", + TTL: 2, + countUnrouteablePackets: 1, + sourceAddr: remoteIPv6Addr1, + destAddr: 
multicastIPv6Addr, + expectPacketUnrouteableError: true, + }, + { + name: "Link local destination", + TTL: 2, + sourceAddr: remoteIPv6Addr1, + destAddr: linkLocalIPv6Addr, + expectLinkLocalDestError: true, + }, + { + name: "Link local source", + TTL: 2, + sourceAddr: linkLocalIPv6Addr, + destAddr: remoteIPv6Addr2, + expectLinkLocalSourceError: true, }, } @@ -3073,35 +3159,35 @@ func TestForwarding(t *testing.T) { icmp.SetChecksum(0) icmp.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{ Header: icmp, - Src: remoteIPv6Addr1, - Dst: remoteIPv6Addr2, + Src: test.sourceAddr, + Dst: test.destAddr, })) ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) ip.Encode(&header.IPv6Fields{ PayloadLength: header.ICMPv6MinimumSize, TransportProtocol: header.ICMPv6ProtocolNumber, HopLimit: test.TTL, - SrcAddr: remoteIPv6Addr1, - DstAddr: remoteIPv6Addr2, + SrcAddr: test.sourceAddr, + DstAddr: test.destAddr, }) requestPkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: hdr.View().ToVectorisedView(), }) e1.InjectInbound(ProtocolNumber, requestPkt) + reply, ok := e1.Read() if test.expectErrorICMP { - reply, ok := e1.Read() if !ok { - t.Fatal("expected ICMP Hop Limit Exceeded packet through incoming NIC") + t.Fatalf("expected ICMP packet type %d through incoming NIC", test.icmpType) } checker.IPv6(t, header.IPv6(stack.PayloadSince(reply.Pkt.NetworkHeader())), checker.SrcAddr(ipv6Addr1.Address), - checker.DstAddr(remoteIPv6Addr1), + checker.DstAddr(test.sourceAddr), checker.TTL(DefaultTTL), checker.ICMPv6( - checker.ICMPv6Type(header.ICMPv6TimeExceeded), - checker.ICMPv6Code(header.ICMPv6HopLimitExceeded), + checker.ICMPv6Type(test.icmpType), + checker.ICMPv6Code(test.icmpCode), checker.ICMPv6Payload([]byte(hdr.View())), ), ) @@ -3109,15 +3195,19 @@ func TestForwarding(t *testing.T) { if n := e2.Drain(); n != 0 { t.Fatalf("got e2.Drain() = %d, want = 0", n) } - } else { - reply, ok := e2.Read() + } else if ok { + t.Fatalf("expected no ICMP packet through incoming NIC, instead found: %#v", reply) + } + + reply, ok = e2.Read() + if test.expectPacketForwarded { if !ok { t.Fatal("expected ICMP Echo Request packet through outgoing NIC") } checker.IPv6(t, header.IPv6(stack.PayloadSince(reply.Pkt.NetworkHeader())), - checker.SrcAddr(remoteIPv6Addr1), - checker.DstAddr(remoteIPv6Addr2), + checker.SrcAddr(test.sourceAddr), + checker.DstAddr(test.destAddr), checker.TTL(test.TTL-1), checker.ICMPv6( checker.ICMPv6Type(header.ICMPv6EchoRequest), @@ -3129,6 +3219,35 @@ func TestForwarding(t *testing.T) { if n := e1.Drain(); n != 0 { t.Fatalf("got e1.Drain() = %d, want = 0", n) } + } else if ok { + t.Fatalf("expected no ICMP Echo packet through outgoing NIC, instead found: %#v", reply) + } + + boolToInt := func(val bool) uint64 { + if val { + return 1 + } + return 0 + } + + if got, want := s.Stats().IP.Forwarding.LinkLocalSource.Value(), boolToInt(test.expectLinkLocalSourceError); got != want { + t.Errorf("got s.Stats().IP.Forwarding.LinkLocalSource.Value() = %d, want = %d", got, want) + } + + if got, want := s.Stats().IP.Forwarding.LinkLocalDestination.Value(), boolToInt(test.expectLinkLocalDestError); got != want { + t.Errorf("got s.Stats().IP.Forwarding.LinkLocalDestination.Value() = %d, want = %d", got, want) + } + + if got, want := s.Stats().IP.Forwarding.ExhaustedTTL.Value(), boolToInt(test.TTL <= 1); got != want { + t.Errorf("got rt.Stats().IP.Forwarding.ExhaustedTTL.Value() = %d, want = %d", got, want) + } + + if got, want := s.Stats().IP.Forwarding.Unrouteable.Value(), 
boolToInt(test.expectPacketUnrouteableError); got != want { + t.Errorf("got s.Stats().IP.Forwarding.Unrouteable.Value() = %d, want = %d", got, want) + } + + if got, want := s.Stats().IP.Forwarding.Errors.Value(), boolToInt(!test.expectPacketForwarded); got != want { + t.Errorf("got s.Stats().IP.Forwarding.Errors.Value() = %d, want = %d", got, want) } }) } diff --git a/pkg/tcpip/network/ipv6/mld.go b/pkg/tcpip/network/ipv6/mld.go index dd153466d..bc1af193c 100644 --- a/pkg/tcpip/network/ipv6/mld.go +++ b/pkg/tcpip/network/ipv6/mld.go @@ -76,10 +76,29 @@ func (mld *mldState) SendReport(groupAddress tcpip.Address) (bool, tcpip.Error) // // Precondition: mld.ep.mu must be read locked. func (mld *mldState) SendLeave(groupAddress tcpip.Address) tcpip.Error { - _, err := mld.writePacket(header.IPv6AllRoutersMulticastAddress, groupAddress, header.ICMPv6MulticastListenerDone) + _, err := mld.writePacket(header.IPv6AllRoutersLinkLocalMulticastAddress, groupAddress, header.ICMPv6MulticastListenerDone) return err } +// ShouldPerformProtocol implements ip.MulticastGroupProtocol. +func (mld *mldState) ShouldPerformProtocol(groupAddress tcpip.Address) bool { + // As per RFC 2710 section 5 page 10, + // + // The link-scope all-nodes address (FF02::1) is handled as a special + // case. The node starts in Idle Listener state for that address on + // every interface, never transitions to another state, and never sends + // a Report or Done for that address. + // + // MLD messages are never sent for multicast addresses whose scope is 0 + // (reserved) or 1 (node-local). + if groupAddress == header.IPv6AllNodesMulticastAddress { + return false + } + + scope := header.V6MulticastScope(groupAddress) + return scope != header.IPv6Reserved0MulticastScope && scope != header.IPv6InterfaceLocalMulticastScope +} + // init sets up an mldState struct, and is required to be called before using // a new mldState. 
// @@ -91,7 +110,6 @@ func (mld *mldState) init(ep *endpoint) { Clock: ep.protocol.stack.Clock(), Protocol: mld, MaxUnsolicitedReportDelay: UnsolicitedReportIntervalMax, - AllNodesAddress: header.IPv6AllNodesMulticastAddress, }) } @@ -259,7 +277,7 @@ func (mld *mldState) writePacket(destAddress, groupAddress tcpip.Address, mldTyp }, extensionHeaders); err != nil { panic(fmt.Sprintf("failed to add IP header: %s", err)) } - if err := mld.ep.nic.WritePacketToRemote(header.EthernetAddressFromMulticastIPv6Address(destAddress), nil /* gso */, ProtocolNumber, pkt); err != nil { + if err := mld.ep.nic.WritePacketToRemote(header.EthernetAddressFromMulticastIPv6Address(destAddress), ProtocolNumber, pkt); err != nil { sentStats.dropped.Increment() return false, err } diff --git a/pkg/tcpip/network/ipv6/mld_test.go b/pkg/tcpip/network/ipv6/mld_test.go index 85a8f9944..71d1c3e28 100644 --- a/pkg/tcpip/network/ipv6/mld_test.go +++ b/pkg/tcpip/network/ipv6/mld_test.go @@ -27,15 +27,14 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/link/channel" "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" "gvisor.dev/gvisor/pkg/tcpip/stack" -) - -const ( - linkLocalAddr = "\xfe\x80\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01" - globalAddr = "\x0a\x80\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01" - globalMulticastAddr = "\xff\x05\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02" + "gvisor.dev/gvisor/pkg/tcpip/testutil" ) var ( + linkLocalAddr = testutil.MustParse6("fe80::1") + globalAddr = testutil.MustParse6("a80::1") + globalMulticastAddr = testutil.MustParse6("ff05:100::2") + linkLocalAddrSNMC = header.SolicitedNodeAddr(linkLocalAddr) globalAddrSNMC = header.SolicitedNodeAddr(globalAddr) ) @@ -93,7 +92,7 @@ func TestIPv6JoinLeaveSolicitedNodeAddressPerformsMLD(t *testing.T) { if p, ok := e.Read(); !ok { t.Fatal("expected a done message to be sent") } else { - validateMLDPacket(t, stack.PayloadSince(p.Pkt.NetworkHeader()), header.IPv6Any, header.IPv6AllRoutersMulticastAddress, header.ICMPv6MulticastListenerDone, linkLocalAddrSNMC) + validateMLDPacket(t, stack.PayloadSince(p.Pkt.NetworkHeader()), header.IPv6Any, header.IPv6AllRoutersLinkLocalMulticastAddress, header.ICMPv6MulticastListenerDone, linkLocalAddrSNMC) } } @@ -354,10 +353,8 @@ func createAndInjectMLDPacket(e *channel.Endpoint, mldType header.ICMPv6Type, ho } func TestMLDPacketValidation(t *testing.T) { - const ( - nicID = 1 - linkLocalAddr2 = tcpip.Address("\xfe\x80\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02") - ) + const nicID = 1 + linkLocalAddr2 := testutil.MustParse6("fe80::2") tests := []struct { name string @@ -464,3 +461,141 @@ func TestMLDPacketValidation(t *testing.T) { }) } } + +func TestMLDSkipProtocol(t *testing.T) { + const nicID = 1 + + tests := []struct { + name string + group tcpip.Address + expectReport bool + }{ + { + name: "Reserverd0", + group: "\xff\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x11", + expectReport: false, + }, + { + name: "Interface Local", + group: "\xff\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x11", + expectReport: false, + }, + { + name: "Link Local", + group: "\xff\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x11", + expectReport: true, + }, + { + name: "Realm Local", + group: "\xff\x03\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x11", + expectReport: true, + }, + { + name: "Admin Local", + group: "\xff\x04\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x11", + expectReport: true, + }, + { + name: "Site Local", + group: 
"\xff\x05\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x11", + expectReport: true, + }, + { + name: "Unassigned(6)", + group: "\xff\x06\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x11", + expectReport: true, + }, + { + name: "Unassigned(7)", + group: "\xff\x07\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x11", + expectReport: true, + }, + { + name: "Organization Local", + group: "\xff\x08\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x11", + expectReport: true, + }, + { + name: "Unassigned(9)", + group: "\xff\x09\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x11", + expectReport: true, + }, + { + name: "Unassigned(A)", + group: "\xff\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x11", + expectReport: true, + }, + { + name: "Unassigned(B)", + group: "\xff\x0b\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x11", + expectReport: true, + }, + { + name: "Unassigned(C)", + group: "\xff\x0c\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x11", + expectReport: true, + }, + { + name: "Unassigned(D)", + group: "\xff\x0d\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x11", + expectReport: true, + }, + { + name: "Global", + group: "\xff\x0e\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x11", + expectReport: true, + }, + { + name: "ReservedF", + group: "\xff\x0f\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x11", + expectReport: true, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ + MLD: ipv6.MLDOptions{ + Enabled: true, + }, + })}, + }) + e := channel.New(1, header.IPv6MinimumMTU, "") + if err := s.CreateNIC(nicID, e); err != nil { + t.Fatalf("CreateNIC(%d, _): %s", nicID, err) + } + if err := s.AddAddress(nicID, ipv6.ProtocolNumber, linkLocalAddr); err != nil { + t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ipv6.ProtocolNumber, linkLocalAddr, err) + } + if p, ok := e.Read(); !ok { + t.Fatal("expected a report message to be sent") + } else { + validateMLDPacket(t, stack.PayloadSince(p.Pkt.NetworkHeader()), linkLocalAddr, linkLocalAddrSNMC, header.ICMPv6MulticastListenerReport, linkLocalAddrSNMC) + } + + if err := s.JoinGroup(ipv6.ProtocolNumber, nicID, test.group); err != nil { + t.Fatalf("s.JoinGroup(%d, %d, %s): %s", ipv6.ProtocolNumber, nicID, test.group, err) + } + if isInGroup, err := s.IsInGroup(nicID, test.group); err != nil { + t.Fatalf("IsInGroup(%d, %s): %s", nicID, test.group, err) + } else if !isInGroup { + t.Fatalf("got IsInGroup(%d, %s) = false, want = true", nicID, test.group) + } + + if !test.expectReport { + if p, ok := e.Read(); ok { + t.Fatalf("got e.Read() = (%#v, true), want = (_, false)", p) + } + + return + } + + if p, ok := e.Read(); !ok { + t.Fatal("expected a report message to be sent") + } else { + validateMLDPacket(t, stack.PayloadSince(p.Pkt.NetworkHeader()), linkLocalAddr, test.group, header.ICMPv6MulticastListenerReport, test.group) + } + }) + } +} diff --git a/pkg/tcpip/network/ipv6/ndp.go b/pkg/tcpip/network/ipv6/ndp.go index 536493f87..b29fed347 100644 --- a/pkg/tcpip/network/ipv6/ndp.go +++ b/pkg/tcpip/network/ipv6/ndp.go @@ -48,7 +48,7 @@ const ( // defaultHandleRAs is the default configuration for whether or not to // handle incoming Router Advertisements as a host. 
- defaultHandleRAs = true + defaultHandleRAs = HandlingRAsEnabledWhenForwardingDisabled // defaultDiscoverDefaultRouters is the default configuration for // whether or not to discover default routers from incoming Router @@ -301,10 +301,60 @@ type NDPDispatcher interface { OnDHCPv6Configuration(tcpip.NICID, DHCPv6ConfigurationFromNDPRA) } +var _ fmt.Stringer = HandleRAsConfiguration(0) + +// HandleRAsConfiguration enumerates when RAs may be handled. +type HandleRAsConfiguration int + +const ( + // HandlingRAsDisabled indicates that Router Advertisements will not be + // handled. + HandlingRAsDisabled HandleRAsConfiguration = iota + + // HandlingRAsEnabledWhenForwardingDisabled indicates that router + // advertisements will only be handled when forwarding is disabled. + HandlingRAsEnabledWhenForwardingDisabled + + // HandlingRAsAlwaysEnabled indicates that Router Advertisements will always + // be handled, even when forwarding is enabled. + HandlingRAsAlwaysEnabled +) + +// String implements fmt.Stringer. +func (c HandleRAsConfiguration) String() string { + switch c { + case HandlingRAsDisabled: + return "HandlingRAsDisabled" + case HandlingRAsEnabledWhenForwardingDisabled: + return "HandlingRAsEnabledWhenForwardingDisabled" + case HandlingRAsAlwaysEnabled: + return "HandlingRAsAlwaysEnabled" + default: + return fmt.Sprintf("HandleRAsConfiguration(%d)", c) + } +} + +// enabled returns true iff Router Advertisements may be handled given the +// specified forwarding status. +func (c HandleRAsConfiguration) enabled(forwarding bool) bool { + switch c { + case HandlingRAsDisabled: + return false + case HandlingRAsEnabledWhenForwardingDisabled: + return !forwarding + case HandlingRAsAlwaysEnabled: + return true + default: + panic(fmt.Sprintf("unhandled HandleRAsConfiguration = %d", c)) + } +} + // NDPConfigurations is the NDP configurations for the netstack. type NDPConfigurations struct { // The number of Router Solicitation messages to send when the IPv6 endpoint // becomes enabled. + // + // Ignored unless configured to handle Router Advertisements. MaxRtrSolicitations uint8 // The amount of time between transmitting Router Solicitation messages. @@ -318,8 +368,9 @@ type NDPConfigurations struct { // Must be greater than or equal to 0s. MaxRtrSolicitationDelay time.Duration - // HandleRAs determines whether or not Router Advertisements are processed. - HandleRAs bool + // HandleRAs is the configuration for when Router Advertisements should be + // handled. + HandleRAs HandleRAsConfiguration // DiscoverDefaultRouters determines whether or not default routers are // discovered from Router Advertisements, as per RFC 4861 section 6. This @@ -654,7 +705,8 @@ func (ndp *ndpState) handleRA(ip tcpip.Address, ra header.NDPRouterAdvert) { // per-interface basis; it is a protocol-wide configuration, so we check the // protocol's forwarding flag to determine if the IPv6 endpoint is forwarding // packets. - if !ndp.configs.HandleRAs || ndp.ep.protocol.Forwarding() { + if !ndp.configs.HandleRAs.enabled(ndp.ep.protocol.Forwarding()) { + ndp.ep.stats.localStats.UnhandledRouterAdvertisements.Increment() return } @@ -737,7 +789,7 @@ func (ndp *ndpState) handleRA(ip tcpip.Address, ra header.NDPRouterAdvert) { prefix := opt.Subnet() // Is the prefix a link-local? - if header.IsV6LinkLocalAddress(prefix.ID()) { + if header.IsV6LinkLocalUnicastAddress(prefix.ID()) { // ...Yes, skip as per RFC 4861 section 6.3.4, // and RFC 4862 section 5.5.3.b (for SLAAC). 
continue @@ -1609,44 +1661,16 @@ func (ndp *ndpState) cleanupTempSLAACAddrResourcesAndNotifyInner(tempAddrs map[t delete(tempAddrs, tempAddr) } -// removeSLAACAddresses removes all SLAAC addresses. -// -// If keepLinkLocal is false, the SLAAC generated link-local address is removed. -// -// The IPv6 endpoint that ndp belongs to MUST be locked. -func (ndp *ndpState) removeSLAACAddresses(keepLinkLocal bool) { - linkLocalSubnet := header.IPv6LinkLocalPrefix.Subnet() - var linkLocalPrefixes int - for prefix, state := range ndp.slaacPrefixes { - // RFC 4862 section 5 states that routers are also expected to generate a - // link-local address so we do not invalidate them if we are cleaning up - // host-only state. - if keepLinkLocal && prefix == linkLocalSubnet { - linkLocalPrefixes++ - continue - } - - ndp.invalidateSLAACPrefix(prefix, state) - } - - if got := len(ndp.slaacPrefixes); got != linkLocalPrefixes { - panic(fmt.Sprintf("ndp: still have non-linklocal SLAAC prefixes after cleaning up; found = %d prefixes, of which %d are link-local", got, linkLocalPrefixes)) - } -} - // cleanupState cleans up ndp's state. // -// If hostOnly is true, then only host-specific state is cleaned up. -// // This function invalidates all discovered on-link prefixes, discovered // routers, and auto-generated addresses. // -// If hostOnly is true, then the link-local auto-generated address aren't -// invalidated as routers are also expected to generate a link-local address. -// // The IPv6 endpoint that ndp belongs to MUST be locked. -func (ndp *ndpState) cleanupState(hostOnly bool) { - ndp.removeSLAACAddresses(hostOnly /* keepLinkLocal */) +func (ndp *ndpState) cleanupState() { + for prefix, state := range ndp.slaacPrefixes { + ndp.invalidateSLAACPrefix(prefix, state) + } for prefix := range ndp.onLinkPrefixes { ndp.invalidateOnLinkPrefix(prefix) @@ -1670,6 +1694,10 @@ func (ndp *ndpState) cleanupState(hostOnly bool) { // startSolicitingRouters starts soliciting routers, as per RFC 4861 section // 6.3.7. If routers are already being solicited, this function does nothing. // +// If ndp is not configured to handle Router Advertisements, routers will not +// be solicited as there is no point soliciting routers if we don't handle their +// advertisements. +// // The IPv6 endpoint that ndp belongs to MUST be locked. func (ndp *ndpState) startSolicitingRouters() { if ndp.rtrSolicitTimer.timer != nil { @@ -1682,6 +1710,10 @@ func (ndp *ndpState) startSolicitingRouters() { return } + if !ndp.configs.HandleRAs.enabled(ndp.ep.protocol.Forwarding()) { + return + } + // Calculate the random delay before sending our first RS, as per RFC // 4861 section 6.3.7. var delay time.Duration @@ -1703,7 +1735,7 @@ func (ndp *ndpState) startSolicitingRouters() { // the unspecified address if no address is assigned // to the sending interface. 
localAddr := header.IPv6Any - if addressEndpoint := ndp.ep.AcquireOutgoingPrimaryAddress(header.IPv6AllRoutersMulticastAddress, false); addressEndpoint != nil { + if addressEndpoint := ndp.ep.AcquireOutgoingPrimaryAddress(header.IPv6AllRoutersLinkLocalMulticastAddress, false); addressEndpoint != nil { localAddr = addressEndpoint.AddressWithPrefix().Address addressEndpoint.DecRef() } @@ -1730,7 +1762,7 @@ func (ndp *ndpState) startSolicitingRouters() { icmpData.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{ Header: icmpData, Src: localAddr, - Dst: header.IPv6AllRoutersMulticastAddress, + Dst: header.IPv6AllRoutersLinkLocalMulticastAddress, })) pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ @@ -1739,14 +1771,14 @@ func (ndp *ndpState) startSolicitingRouters() { }) sent := ndp.ep.stats.icmp.packetsSent - if err := addIPHeader(localAddr, header.IPv6AllRoutersMulticastAddress, pkt, stack.NetworkHeaderParams{ + if err := addIPHeader(localAddr, header.IPv6AllRoutersLinkLocalMulticastAddress, pkt, stack.NetworkHeaderParams{ Protocol: header.ICMPv6ProtocolNumber, TTL: header.NDPHopLimit, }, nil /* extensionHeaders */); err != nil { panic(fmt.Sprintf("failed to add IP header: %s", err)) } - if err := ndp.ep.nic.WritePacketToRemote(header.EthernetAddressFromMulticastIPv6Address(header.IPv6AllRoutersMulticastAddress), nil /* gso */, ProtocolNumber, pkt); err != nil { + if err := ndp.ep.nic.WritePacketToRemote(header.EthernetAddressFromMulticastIPv6Address(header.IPv6AllRoutersLinkLocalMulticastAddress), ProtocolNumber, pkt); err != nil { sent.dropped.Increment() // Don't send any more messages if we had an error. remaining = 0 @@ -1774,6 +1806,32 @@ func (ndp *ndpState) startSolicitingRouters() { } } +// forwardingChanged handles a change in forwarding configuration. +// +// If transitioning to a host, router solicitation will be started. Otherwise, +// router solicitation will be stopped if NDP is not configured to handle RAs +// as a router. +// +// Precondition: ndp.ep.mu must be locked. +func (ndp *ndpState) forwardingChanged(forwarding bool) { + if forwarding { + if ndp.configs.HandleRAs.enabled(forwarding) { + return + } + + ndp.stopSolicitingRouters() + return + } + + // Solicit routers when transitioning to a host. + // + // If the endpoint is not currently enabled, routers will be solicited when + // the endpoint becomes enabled (if it is still a host). + if ndp.ep.Enabled() { + ndp.startSolicitingRouters() + } +} + // stopSolicitingRouters stops soliciting routers. If routers are not currently // being solicited, this function does nothing. // @@ -1839,7 +1897,7 @@ func (e *endpoint) sendNDPNS(srcAddr, dstAddr, targetAddr tcpip.Address, remoteL } sent := e.stats.icmp.packetsSent - err := e.nic.WritePacketToRemote(remoteLinkAddr, nil /* gso */, ProtocolNumber, pkt) + err := e.nic.WritePacketToRemote(remoteLinkAddr, ProtocolNumber, pkt) if err != nil { sent.dropped.Increment() } else { diff --git a/pkg/tcpip/network/ipv6/stats.go b/pkg/tcpip/network/ipv6/stats.go index c2758352f..2f18f60e8 100644 --- a/pkg/tcpip/network/ipv6/stats.go +++ b/pkg/tcpip/network/ipv6/stats.go @@ -29,6 +29,10 @@ type Stats struct { // ICMP holds ICMPv6 statistics. ICMP tcpip.ICMPv6Stats + + // UnhandledRouterAdvertisements is the number of Router Advertisements that + // were observed but not handled. + UnhandledRouterAdvertisements *tcpip.StatCounter } // IsNetworkEndpointStats implements stack.NetworkEndpointStats. 
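With HandleRAs now a three-state configuration rather than a bool, the old default behaviour maps to HandlingRAsEnabledWhenForwardingDisabled. A minimal sketch of opting in to RA handling even while forwarding, assuming the existing NDPConfigs field on ipv6.Options (only the HandleRAs constants added above are new):

package main

import (
	"gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
	"gvisor.dev/gvisor/pkg/tcpip/stack"
)

func main() {
	// Keep handling Router Advertisements even after forwarding is enabled;
	// the default (HandlingRAsEnabledWhenForwardingDisabled) stops RA
	// handling, and router solicitation, once forwarding is turned on.
	s := stack.New(stack.Options{
		NetworkProtocols: []stack.NetworkProtocolFactory{
			ipv6.NewProtocolWithOptions(ipv6.Options{
				NDPConfigs: ipv6.NDPConfigurations{
					HandleRAs: ipv6.HandlingRAsAlwaysEnabled,
				},
			}),
		},
	})
	_ = s
}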
diff --git a/pkg/tcpip/network/multicast_group_test.go b/pkg/tcpip/network/multicast_group_test.go index ecd5003a7..1b96b1fb8 100644 --- a/pkg/tcpip/network/multicast_group_test.go +++ b/pkg/tcpip/network/multicast_group_test.go @@ -30,22 +30,13 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" "gvisor.dev/gvisor/pkg/tcpip/stack" + "gvisor.dev/gvisor/pkg/tcpip/testutil" ) const ( linkAddr = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x06") - stackIPv4Addr = tcpip.Address("\x0a\x00\x00\x01") defaultIPv4PrefixLength = 24 - linkLocalIPv6Addr1 = tcpip.Address("\xfe\x80\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01") - linkLocalIPv6Addr2 = tcpip.Address("\xfe\x80\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02") - - ipv4MulticastAddr1 = tcpip.Address("\xe0\x00\x00\x03") - ipv4MulticastAddr2 = tcpip.Address("\xe0\x00\x00\x04") - ipv4MulticastAddr3 = tcpip.Address("\xe0\x00\x00\x05") - ipv6MulticastAddr1 = tcpip.Address("\xff\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x03") - ipv6MulticastAddr2 = tcpip.Address("\xff\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x04") - ipv6MulticastAddr3 = tcpip.Address("\xff\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x05") igmpMembershipQuery = uint8(header.IGMPMembershipQuery) igmpv1MembershipReport = uint8(header.IGMPv1MembershipReport) @@ -59,6 +50,19 @@ const ( ) var ( + stackIPv4Addr = testutil.MustParse4("10.0.0.1") + linkLocalIPv6Addr1 = testutil.MustParse6("fe80::1") + linkLocalIPv6Addr2 = testutil.MustParse6("fe80::2") + + ipv4MulticastAddr1 = testutil.MustParse4("224.0.0.3") + ipv4MulticastAddr2 = testutil.MustParse4("224.0.0.4") + ipv4MulticastAddr3 = testutil.MustParse4("224.0.0.5") + ipv6MulticastAddr1 = testutil.MustParse6("ff02::3") + ipv6MulticastAddr2 = testutil.MustParse6("ff02::4") + ipv6MulticastAddr3 = testutil.MustParse6("ff02::5") +) + +var ( // unsolicitedIGMPReportIntervalMaxTenthSec is the maximum amount of time the // NIC will wait before sending an unsolicited report after joining a // multicast group, in deciseconds. @@ -194,7 +198,7 @@ func checkInitialIPv6Groups(t *testing.T, e *channel.Endpoint, s *stack.Stack, c if p, ok := e.Read(); !ok { t.Fatal("expected a report message to be sent") } else { - validateMLDPacket(t, p, header.IPv6AllRoutersMulticastAddress, mldDone, 0, ipv6AddrSNMC) + validateMLDPacket(t, p, header.IPv6AllRoutersLinkLocalMulticastAddress, mldDone, 0, ipv6AddrSNMC) } // Should not send any more packets. 
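These tests now build addresses with testutil.MustParse4/MustParse6 instead of raw byte strings. A small sketch of the equivalence, assuming the MustParse helpers panic on malformed input (which is what lets them sit in var blocks):

package main

import (
	"fmt"

	"gvisor.dev/gvisor/pkg/tcpip"
	"gvisor.dev/gvisor/pkg/tcpip/testutil"
)

func main() {
	// The parsed forms are byte-for-byte identical to the old literals.
	oldV4 := tcpip.Address("\xe0\x00\x00\x03")
	fmt.Println(oldV4 == testutil.MustParse4("224.0.0.3")) // true

	oldV6 := tcpip.Address("\xff\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x03")
	fmt.Println(oldV6 == testutil.MustParse6("ff02::3")) // true
}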
@@ -606,7 +610,7 @@ func TestMGPLeaveGroup(t *testing.T) { validateLeave: func(t *testing.T, p channel.PacketInfo) { t.Helper() - validateMLDPacket(t, p, header.IPv6AllRoutersMulticastAddress, mldDone, 0, ipv6MulticastAddr1) + validateMLDPacket(t, p, header.IPv6AllRoutersLinkLocalMulticastAddress, mldDone, 0, ipv6MulticastAddr1) }, checkInitialGroups: checkInitialIPv6Groups, }, @@ -1014,7 +1018,7 @@ func TestMGPWithNICLifecycle(t *testing.T) { validateLeave: func(t *testing.T, p channel.PacketInfo, addr tcpip.Address) { t.Helper() - validateMLDPacket(t, p, header.IPv6AllRoutersMulticastAddress, mldDone, 0, addr) + validateMLDPacket(t, p, header.IPv6AllRoutersLinkLocalMulticastAddress, mldDone, 0, addr) }, getAndCheckGroupAddress: func(t *testing.T, seen map[tcpip.Address]bool, p channel.PacketInfo) tcpip.Address { t.Helper() diff --git a/pkg/tcpip/ports/BUILD b/pkg/tcpip/ports/BUILD index 210262703..b7f6d52ae 100644 --- a/pkg/tcpip/ports/BUILD +++ b/pkg/tcpip/ports/BUILD @@ -21,6 +21,7 @@ go_test( library = ":ports", deps = [ "//pkg/tcpip", + "//pkg/tcpip/testutil", "@com_github_google_go_cmp//cmp:go_default_library", ], ) diff --git a/pkg/tcpip/ports/ports.go b/pkg/tcpip/ports/ports.go index 678199371..b5b013b64 100644 --- a/pkg/tcpip/ports/ports.go +++ b/pkg/tcpip/ports/ports.go @@ -17,6 +17,7 @@ package ports import ( + "math" "math/rand" "sync/atomic" @@ -24,7 +25,10 @@ import ( "gvisor.dev/gvisor/pkg/tcpip" ) -const anyIPAddress tcpip.Address = "" +const ( + firstEphemeral = 16000 + anyIPAddress tcpip.Address = "" +) // Reservation describes a port reservation. type Reservation struct { @@ -220,10 +224,8 @@ type PortManager struct { func NewPortManager() *PortManager { return &PortManager{ allocatedPorts: make(map[portDescriptor]addrToDevice), - // Match Linux's default ephemeral range. See: - // https://github.com/torvalds/linux/blob/e54937963fa249595824439dc839c948188dea83/net/ipv4/af_inet.c#L1842 - firstEphemeral: 32768, - numEphemeral: 28232, + firstEphemeral: firstEphemeral, + numEphemeral: math.MaxUint16 - firstEphemeral + 1, } } @@ -242,13 +244,13 @@ func (pm *PortManager) PickEphemeralPort(testPort PortTester) (port uint16, err numEphemeral := pm.numEphemeral pm.ephemeralMu.RUnlock() - offset := uint16(rand.Int31n(int32(numEphemeral))) + offset := uint32(rand.Int31n(int32(numEphemeral))) return pickEphemeralPort(offset, firstEphemeral, numEphemeral, testPort) } // portHint atomically reads and returns the pm.hint value. -func (pm *PortManager) portHint() uint16 { - return uint16(atomic.LoadUint32(&pm.hint)) +func (pm *PortManager) portHint() uint32 { + return atomic.LoadUint32(&pm.hint) } // incPortHint atomically increments pm.hint by 1. @@ -260,7 +262,7 @@ func (pm *PortManager) incPortHint() { // iterates over all ephemeral ports, allowing the caller to decide whether a // given port is suitable for its needs and stopping when a port is found or an // error occurs. 
-func (pm *PortManager) PickEphemeralPortStable(offset uint16, testPort PortTester) (port uint16, err tcpip.Error) { +func (pm *PortManager) PickEphemeralPortStable(offset uint32, testPort PortTester) (port uint16, err tcpip.Error) { pm.ephemeralMu.RLock() firstEphemeral := pm.firstEphemeral numEphemeral := pm.numEphemeral @@ -277,9 +279,9 @@ func (pm *PortManager) PickEphemeralPortStable(offset uint16, testPort PortTeste // and iterates over the number of ports specified by count and allows the // caller to decide whether a given port is suitable for its needs, and stopping // when a port is found or an error occurs. -func pickEphemeralPort(offset, first, count uint16, testPort PortTester) (port uint16, err tcpip.Error) { - for i := uint16(0); i < count; i++ { - port = first + (offset+i)%count +func pickEphemeralPort(offset uint32, first, count uint16, testPort PortTester) (port uint16, err tcpip.Error) { + for i := uint32(0); i < uint32(count); i++ { + port := uint16(uint32(first) + (offset+i)%uint32(count)) ok, err := testPort(port) if err != nil { return 0, err diff --git a/pkg/tcpip/ports/ports_test.go b/pkg/tcpip/ports/ports_test.go index 0f43dc8f8..6c4fb8c68 100644 --- a/pkg/tcpip/ports/ports_test.go +++ b/pkg/tcpip/ports/ports_test.go @@ -15,19 +15,23 @@ package ports import ( + "math" "math/rand" "testing" "github.com/google/go-cmp/cmp" "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/testutil" ) const ( fakeTransNumber tcpip.TransportProtocolNumber = 1 fakeNetworkNumber tcpip.NetworkProtocolNumber = 2 +) - fakeIPAddress = tcpip.Address("\x08\x08\x08\x08") - fakeIPAddress1 = tcpip.Address("\x08\x08\x08\x09") +var ( + fakeIPAddress = testutil.MustParse4("8.8.8.8") + fakeIPAddress1 = testutil.MustParse4("8.8.8.9") ) type portReserveTestAction struct { @@ -479,7 +483,7 @@ func TestPickEphemeralPortStable(t *testing.T) { if err := pm.SetPortRange(firstEphemeral, firstEphemeral+numEphemeralPorts); err != nil { t.Fatalf("failed to set ephemeral port range: %s", err) } - portOffset := uint16(rand.Int31n(int32(numEphemeralPorts))) + portOffset := uint32(rand.Int31n(int32(numEphemeralPorts))) port, err := pm.PickEphemeralPortStable(portOffset, test.f) if diff := cmp.Diff(test.wantErr, err); diff != "" { t.Fatalf("unexpected error from PickEphemeralPort(..), (-want, +got):\n%s", diff) @@ -490,3 +494,29 @@ func TestPickEphemeralPortStable(t *testing.T) { }) } } + +// TestOverflow addresses b/183593432, wherein an overflowing uint16 causes a +// port allocation failure. +func TestOverflow(t *testing.T) { + // Use a small range and start at offsets that will cause an overflow. + count := uint16(50) + for offset := uint32(math.MaxUint16 - count); offset < math.MaxUint16; offset++ { + reservedPorts := make(map[uint16]struct{}) + // Ensure we can reserve everything in the allowed range. 
+ for i := uint16(0); i < count; i++ { + port, err := pickEphemeralPort(offset, firstEphemeral, count, func(port uint16) (bool, tcpip.Error) { + if _, ok := reservedPorts[port]; !ok { + reservedPorts[port] = struct{}{} + return true, nil + } + return false, nil + }) + if err != nil { + t.Fatalf("port picking failed at iteration %d, for offset %d, len(reserved): %+v", i, offset, len(reservedPorts)) + } + if port < firstEphemeral || port > firstEphemeral+count { + t.Fatalf("reserved port %d, which is not in range [%d, %d]", port, firstEphemeral, firstEphemeral+count-1) + } + } + } +} diff --git a/pkg/tcpip/socketops.go b/pkg/tcpip/socketops.go index dc37e61a4..a6c877158 100644 --- a/pkg/tcpip/socketops.go +++ b/pkg/tcpip/socketops.go @@ -58,6 +58,9 @@ type SocketOptionsHandler interface { // changed. The handler is invoked with the new value for the socket send // buffer size. It also returns the newly set value. OnSetSendBufferSize(v int64) (newSz int64) + + // OnSetReceiveBufferSize is invoked to set the SO_RCVBUFSIZE. + OnSetReceiveBufferSize(v, oldSz int64) (newSz int64) } // DefaultSocketOptionsHandler is an embeddable type that implements no-op @@ -99,6 +102,11 @@ func (*DefaultSocketOptionsHandler) OnSetSendBufferSize(v int64) (newSz int64) { return v } +// OnSetReceiveBufferSize implements SocketOptionsHandler.OnSetReceiveBufferSize. +func (*DefaultSocketOptionsHandler) OnSetReceiveBufferSize(v, oldSz int64) (newSz int64) { + return v +} + // StackHandler holds methods to access the stack options. These must be // implemented by the stack. type StackHandler interface { @@ -207,6 +215,14 @@ type SocketOptions struct { // sendBufferSize determines the send buffer size for this socket. sendBufferSize int64 + // getReceiveBufferLimits provides the handler to get the min, default and + // max size for receive buffer. It is initialized at the creation time and + // will not change. + getReceiveBufferLimits GetReceiveBufferLimits `state:"manual"` + + // receiveBufferSize determines the receive buffer size for this socket. + receiveBufferSize int64 + // mu protects the access to the below fields. mu sync.Mutex `state:"nosave"` @@ -217,10 +233,11 @@ type SocketOptions struct { // InitHandler initializes the handler. This must be called before using the // socket options utility. -func (so *SocketOptions) InitHandler(handler SocketOptionsHandler, stack StackHandler, getSendBufferLimits GetSendBufferLimits) { +func (so *SocketOptions) InitHandler(handler SocketOptionsHandler, stack StackHandler, getSendBufferLimits GetSendBufferLimits, getReceiveBufferLimits GetReceiveBufferLimits) { so.handler = handler so.stackHandler = stack so.getSendBufferLimits = getSendBufferLimits + so.getReceiveBufferLimits = getReceiveBufferLimits } func storeAtomicBool(addr *uint32, v bool) { @@ -632,3 +649,42 @@ func (so *SocketOptions) SetSendBufferSize(sendBufferSize int64, notify bool) { newSz := so.handler.OnSetSendBufferSize(v) atomic.StoreInt64(&so.sendBufferSize, newSz) } + +// GetReceiveBufferSize gets value for SO_RCVBUF option. +func (so *SocketOptions) GetReceiveBufferSize() int64 { + return atomic.LoadInt64(&so.receiveBufferSize) +} + +// SetReceiveBufferSize sets value for SO_RCVBUF option. +func (so *SocketOptions) SetReceiveBufferSize(receiveBufferSize int64, notify bool) { + if !notify { + atomic.StoreInt64(&so.receiveBufferSize, receiveBufferSize) + return + } + + // Make sure the send buffer size is within the min and max + // allowed. 
+ v := receiveBufferSize + ss := so.getReceiveBufferLimits(so.stackHandler) + min := int64(ss.Min) + max := int64(ss.Max) + // Validate the send buffer size with min and max values. + if v > max { + v = max + } + + // Multiply it by factor of 2. + if v < math.MaxInt32/PacketOverheadFactor { + v *= PacketOverheadFactor + if v < min { + v = min + } + } else { + v = math.MaxInt32 + } + + oldSz := atomic.LoadInt64(&so.receiveBufferSize) + // Notify endpoint about change in buffer size. + newSz := so.handler.OnSetReceiveBufferSize(v, oldSz) + atomic.StoreInt64(&so.receiveBufferSize, newSz) +} diff --git a/pkg/tcpip/stack/BUILD b/pkg/tcpip/stack/BUILD index 49362333a..2bd6a67f5 100644 --- a/pkg/tcpip/stack/BUILD +++ b/pkg/tcpip/stack/BUILD @@ -45,6 +45,7 @@ go_library( "addressable_endpoint_state.go", "conntrack.go", "headertype_string.go", + "hook_string.go", "icmp_rate_limit.go", "iptables.go", "iptables_state.go", @@ -66,6 +67,7 @@ go_library( "stack.go", "stack_global_state.go", "stack_options.go", + "tcp.go", "transport_demuxer.go", "tuple_list.go", ], @@ -115,6 +117,7 @@ go_test( "//pkg/tcpip/network/ipv4", "//pkg/tcpip/network/ipv6", "//pkg/tcpip/ports", + "//pkg/tcpip/testutil", "//pkg/tcpip/transport/icmp", "//pkg/tcpip/transport/udp", "//pkg/waiter", @@ -139,6 +142,7 @@ go_test( "//pkg/tcpip/buffer", "//pkg/tcpip/faketime", "//pkg/tcpip/header", + "//pkg/tcpip/testutil", "@com_github_google_go_cmp//cmp:go_default_library", "@com_github_google_go_cmp//cmp/cmpopts:go_default_library", ], diff --git a/pkg/tcpip/stack/conntrack.go b/pkg/tcpip/stack/conntrack.go index 3f083928f..5720e7543 100644 --- a/pkg/tcpip/stack/conntrack.go +++ b/pkg/tcpip/stack/conntrack.go @@ -16,6 +16,7 @@ package stack import ( "encoding/binary" + "fmt" "sync" "time" @@ -29,7 +30,7 @@ import ( // The connection is created for a packet if it does not exist. Every // connection contains two tuples (original and reply). The tuples are // manipulated if there is a matching NAT rule. The packet is modified by -// looking at the tuples in the Prerouting and Output hooks. +// looking at the tuples in each hook. // // Currently, only TCP tracking is supported. @@ -46,12 +47,14 @@ const ( ) // Manipulation type for the connection. +// TODO(gvisor.dev/issue/5696): Define this as a bit set and support SNAT and +// DNAT at the same time. type manipType int const ( manipNone manipType = iota - manipDstPrerouting - manipDstOutput + manipSource + manipDestination ) // tuple holds a connection's identifying and manipulating data in one @@ -108,6 +111,7 @@ type conn struct { reply tuple // manip indicates if the packet should be manipulated. It is immutable. + // TODO(gvisor.dev/issue/5696): Support updating manipulation type. manip manipType // tcbHook indicates if the packet is inbound or outbound to @@ -124,6 +128,18 @@ type conn struct { lastUsed time.Time `state:".(unixTime)"` } +// newConn creates new connection. +func newConn(orig, reply tupleID, manip manipType, hook Hook) *conn { + conn := conn{ + manip: manip, + tcbHook: hook, + lastUsed: time.Now(), + } + conn.original = tuple{conn: &conn, tupleID: orig} + conn.reply = tuple{conn: &conn, tupleID: reply, direction: dirReply} + return &conn +} + // timedOut returns whether the connection timed out based on its state. func (cn *conn) timedOut(now time.Time) bool { const establishedTimeout = 5 * 24 * time.Hour @@ -219,18 +235,6 @@ func packetToTupleID(pkt *PacketBuffer) (tupleID, tcpip.Error) { }, nil } -// newConn creates new connection. 
-func newConn(orig, reply tupleID, manip manipType, hook Hook) *conn { - conn := conn{ - manip: manip, - tcbHook: hook, - lastUsed: time.Now(), - } - conn.original = tuple{conn: &conn, tupleID: orig} - conn.reply = tuple{conn: &conn, tupleID: reply, direction: dirReply} - return &conn -} - func (ct *ConnTrack) init() { ct.mu.Lock() defer ct.mu.Unlock() @@ -284,20 +288,41 @@ func (ct *ConnTrack) insertRedirectConn(pkt *PacketBuffer, hook Hook, port uint1 return nil } - // Create a new connection and change the port as per the iptables - // rule. This tuple will be used to manipulate the packet in - // handlePacket. replyTID := tid.reply() replyTID.srcAddr = address replyTID.srcPort = port - var manip manipType - switch hook { - case Prerouting: - manip = manipDstPrerouting - case Output: - manip = manipDstOutput + + conn, _ := ct.connForTID(tid) + if conn != nil { + // The connection is already tracked. + // TODO(gvisor.dev/issue/5696): Support updating an existing connection. + return nil } - conn := newConn(tid, replyTID, manip, hook) + conn = newConn(tid, replyTID, manipDestination, hook) + ct.insertConn(conn) + return conn +} + +func (ct *ConnTrack) insertSNATConn(pkt *PacketBuffer, hook Hook, port uint16, address tcpip.Address) *conn { + tid, err := packetToTupleID(pkt) + if err != nil { + return nil + } + if hook != Input && hook != Postrouting { + return nil + } + + replyTID := tid.reply() + replyTID.dstAddr = address + replyTID.dstPort = port + + conn, _ := ct.connForTID(tid) + if conn != nil { + // The connection is already tracked. + // TODO(gvisor.dev/issue/5696): Support updating an existing connection. + return nil + } + conn = newConn(tid, replyTID, manipSource, hook) ct.insertConn(conn) return conn } @@ -322,6 +347,7 @@ func (ct *ConnTrack) insertConn(conn *conn) { // Now that we hold the locks, ensure the tuple hasn't been inserted by // another thread. + // TODO(gvisor.dev/issue/5773): Should check conn.reply.tupleID, too? alreadyInserted := false for other := ct.buckets[tupleBucket].tuples.Front(); other != nil; other = other.Next() { if other.tupleID == conn.original.tupleID { @@ -343,95 +369,17 @@ func (ct *ConnTrack) insertConn(conn *conn) { } } -// handlePacketPrerouting manipulates ports for packets in Prerouting hook. -// TODO(gvisor.dev/issue/170): Change address for Prerouting hook. -func handlePacketPrerouting(pkt *PacketBuffer, conn *conn, dir direction) { - // If this is a noop entry, don't do anything. - if conn.manip == manipNone { - return - } - - netHeader := pkt.Network() - tcpHeader := header.TCP(pkt.TransportHeader().View()) - - // For prerouting redirection, packets going in the original direction - // have their destinations modified and replies have their sources - // modified. - switch dir { - case dirOriginal: - port := conn.reply.srcPort - tcpHeader.SetDestinationPort(port) - netHeader.SetDestinationAddress(conn.reply.srcAddr) - case dirReply: - port := conn.original.dstPort - tcpHeader.SetSourcePort(port) - netHeader.SetSourceAddress(conn.original.dstAddr) - } - - // TODO(gvisor.dev/issue/170): TCP checksums aren't usually validated - // on inbound packets, so we don't recalculate them. However, we should - // support cases when they are validated, e.g. when we can't offload - // receive checksumming. - - // After modification, IPv4 packets need a valid checksum. 
- if pkt.NetworkProtocolNumber == header.IPv4ProtocolNumber { - netHeader := header.IPv4(pkt.NetworkHeader().View()) - netHeader.SetChecksum(0) - netHeader.SetChecksum(^netHeader.CalculateChecksum()) - } -} - -// handlePacketOutput manipulates ports for packets in Output hook. -func handlePacketOutput(pkt *PacketBuffer, conn *conn, gso *GSO, r *Route, dir direction) { - // If this is a noop entry, don't do anything. - if conn.manip == manipNone { - return - } - - netHeader := pkt.Network() - tcpHeader := header.TCP(pkt.TransportHeader().View()) - - // For output redirection, packets going in the original direction - // have their destinations modified and replies have their sources - // modified. For prerouting redirection, we only reach this point - // when replying, so packet sources are modified. - if conn.manip == manipDstOutput && dir == dirOriginal { - port := conn.reply.srcPort - tcpHeader.SetDestinationPort(port) - netHeader.SetDestinationAddress(conn.reply.srcAddr) - } else { - port := conn.original.dstPort - tcpHeader.SetSourcePort(port) - netHeader.SetSourceAddress(conn.original.dstAddr) - } - - // Calculate the TCP checksum and set it. - tcpHeader.SetChecksum(0) - length := uint16(len(tcpHeader) + pkt.Data().Size()) - xsum := header.PseudoHeaderChecksum(header.TCPProtocolNumber, netHeader.SourceAddress(), netHeader.DestinationAddress(), length) - if gso != nil && gso.NeedsCsum { - tcpHeader.SetChecksum(xsum) - } else if r.RequiresTXTransportChecksum() { - xsum = header.ChecksumCombine(xsum, pkt.Data().AsRange().Checksum()) - tcpHeader.SetChecksum(^tcpHeader.CalculateChecksum(xsum)) - } - - if pkt.NetworkProtocolNumber == header.IPv4ProtocolNumber { - netHeader := header.IPv4(pkt.NetworkHeader().View()) - netHeader.SetChecksum(0) - netHeader.SetChecksum(^netHeader.CalculateChecksum()) - } -} - // handlePacket will manipulate the port and address of the packet if the // connection exists. Returns whether, after the packet traverses the tables, // it should create a new entry in the table. -func (ct *ConnTrack) handlePacket(pkt *PacketBuffer, hook Hook, gso *GSO, r *Route) bool { +func (ct *ConnTrack) handlePacket(pkt *PacketBuffer, hook Hook, r *Route) bool { if pkt.NatDone { return false } - if hook != Prerouting && hook != Output { + switch hook { + case Prerouting, Input, Output, Postrouting: + default: return false } @@ -441,23 +389,79 @@ func (ct *ConnTrack) handlePacket(pkt *PacketBuffer, hook Hook, gso *GSO, r *Rou } conn, dir := ct.connFor(pkt) - // Connection or Rule not found for the packet. + // Connection not found for the packet. if conn == nil { - return true + // If this is the last hook in the data path for this packet (Input if + // incoming, Postrouting if outgoing), indicate that a connection should be + // inserted by the end of this hook. + return hook == Input || hook == Postrouting } + netHeader := pkt.Network() tcpHeader := header.TCP(pkt.TransportHeader().View()) if len(tcpHeader) < header.TCPMinimumSize { return false } + // TODO(gvisor.dev/issue/5748): TCP checksums on inbound packets should be + // validated if checksum offloading is off. It may require IP defrag if the + // packets are fragmented. 
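For reference, the rewrite matrix implemented by the hook switch that follows: destination manipulation rewrites the destination of original-direction packets and the source of replies, while source manipulation is the mirror image. A standalone sketch with illustrative types (not the package's own):

package main

import "fmt"

// Illustrative stand-ins for the package's tuple types.
type endpointID struct {
	addr string
	port uint16
}

type tupleID struct {
	src, dst endpointID
}

// rewriteDNAT mirrors the manipDestination arm: original-direction packets
// get their destination replaced with the reply tuple's source; replies get
// their source replaced with the original tuple's destination.
func rewriteDNAT(pkt *tupleID, original, reply tupleID, isReply bool) {
	if isReply {
		pkt.src = original.dst
	} else {
		pkt.dst = reply.src
	}
}

// rewriteSNAT mirrors the manipSource arm: original-direction packets get
// their source replaced with the reply tuple's destination; replies get
// their destination replaced with the original tuple's source.
func rewriteSNAT(pkt *tupleID, original, reply tupleID, isReply bool) {
	if isReply {
		pkt.dst = original.src
	} else {
		pkt.src = reply.dst
	}
}

func main() {
	original := tupleID{src: endpointID{"10.0.0.1", 1234}, dst: endpointID{"10.0.0.2", 80}}
	reply := tupleID{src: endpointID{"10.0.0.3", 8080}, dst: endpointID{"10.0.0.1", 1234}}

	pkt := original
	rewriteDNAT(&pkt, original, reply, false /* isReply */)
	fmt.Println(pkt.dst) // {10.0.0.3 8080}: redirected to the reply tuple's source.
}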
+ + switch hook { + case Prerouting, Output: + if conn.manip == manipDestination { + switch dir { + case dirOriginal: + tcpHeader.SetDestinationPort(conn.reply.srcPort) + netHeader.SetDestinationAddress(conn.reply.srcAddr) + case dirReply: + tcpHeader.SetSourcePort(conn.original.dstPort) + netHeader.SetSourceAddress(conn.original.dstAddr) + } + pkt.NatDone = true + } + case Input, Postrouting: + if conn.manip == manipSource { + switch dir { + case dirOriginal: + tcpHeader.SetSourcePort(conn.reply.dstPort) + netHeader.SetSourceAddress(conn.reply.dstAddr) + case dirReply: + tcpHeader.SetDestinationPort(conn.original.srcPort) + netHeader.SetDestinationAddress(conn.original.srcAddr) + } + pkt.NatDone = true + } + default: + panic(fmt.Sprintf("unrecognized hook = %s", hook)) + } + if !pkt.NatDone { + return false + } + switch hook { - case Prerouting: - handlePacketPrerouting(pkt, conn, dir) - case Output: - handlePacketOutput(pkt, conn, gso, r, dir) + case Prerouting, Input: + case Output, Postrouting: + // Calculate the TCP checksum and set it. + tcpHeader.SetChecksum(0) + length := uint16(len(tcpHeader) + pkt.Data().Size()) + xsum := header.PseudoHeaderChecksum(header.TCPProtocolNumber, netHeader.SourceAddress(), netHeader.DestinationAddress(), length) + if pkt.GSOOptions.Type != GSONone && pkt.GSOOptions.NeedsCsum { + tcpHeader.SetChecksum(xsum) + } else if r.RequiresTXTransportChecksum() { + xsum = header.ChecksumCombine(xsum, pkt.Data().AsRange().Checksum()) + tcpHeader.SetChecksum(^tcpHeader.CalculateChecksum(xsum)) + } + default: + panic(fmt.Sprintf("unrecognized hook = %s", hook)) + } + + // After modification, IPv4 packets need a valid checksum. + if pkt.NetworkProtocolNumber == header.IPv4ProtocolNumber { + netHeader := header.IPv4(pkt.NetworkHeader().View()) + netHeader.SetChecksum(0) + netHeader.SetChecksum(^netHeader.CalculateChecksum()) } - pkt.NatDone = true // Update the state of tcb. // TODO(gvisor.dev/issue/170): Add support in tcpcontrack to handle @@ -638,8 +642,8 @@ func (ct *ConnTrack) originalDst(epID TransportEndpointID, netProto tcpip.Networ if conn == nil { // Not a tracked connection. return "", 0, &tcpip.ErrNotConnected{} - } else if conn.manip == manipNone { - // Unmanipulated connection. + } else if conn.manip != manipDestination { + // Unmanipulated destination. return "", 0, &tcpip.ErrInvalidOptionValue{} } diff --git a/pkg/tcpip/stack/forwarding_test.go b/pkg/tcpip/stack/forwarding_test.go index 16ee75bc4..7d3725681 100644 --- a/pkg/tcpip/stack/forwarding_test.go +++ b/pkg/tcpip/stack/forwarding_test.go @@ -101,7 +101,7 @@ func (f *fwdTestNetworkEndpoint) HandlePacket(pkt *PacketBuffer) { ReserveHeaderBytes: int(r.MaxHeaderLength()), Data: vv.ToView().ToVectorisedView(), }) - // TODO(b/143425874) Decrease the TTL field in forwarded packets. + // TODO(gvisor.dev/issue/1085) Decrease the TTL field in forwarded packets. _ = r.WriteHeaderIncludedPacket(pkt) } @@ -117,7 +117,7 @@ func (f *fwdTestNetworkEndpoint) NetworkProtocolNumber() tcpip.NetworkProtocolNu return f.proto.Number() } -func (f *fwdTestNetworkEndpoint) WritePacket(r *Route, gso *GSO, params NetworkHeaderParams, pkt *PacketBuffer) tcpip.Error { +func (f *fwdTestNetworkEndpoint) WritePacket(r *Route, params NetworkHeaderParams, pkt *PacketBuffer) tcpip.Error { // Add the protocol's header to the packet and send it to the link // endpoint. 
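The DNAT/SNAT split in handlePacket above keys off a connection's original and reply tuples. Below is a minimal sketch of that bookkeeping using simplified illustrative types, not the stack package's conn/tupleID definitions: redirect rewrites the reply tuple's source (so replies appear to come from the NATed destination), while SNAT rewrites the reply tuple's destination.

// Minimal sketch of the tuple bookkeeping behind insertRedirectConn and
// insertSNATConn, with simplified illustrative types.
package main

import "fmt"

type tupleID struct {
	srcAddr, dstAddr string
	srcPort, dstPort uint16
}

// reply returns the tuple a response packet would carry: endpoints swapped.
func (t tupleID) reply() tupleID {
	return tupleID{
		srcAddr: t.dstAddr, srcPort: t.dstPort,
		dstAddr: t.srcAddr, dstPort: t.srcPort,
	}
}

func main() {
	orig := tupleID{srcAddr: "10.0.0.1", srcPort: 4321, dstAddr: "10.0.0.2", dstPort: 80}

	// Destination NAT (redirect): replies must come *from* the NATed
	// destination, so the reply tuple's source is rewritten.
	dnatReply := orig.reply()
	dnatReply.srcAddr, dnatReply.srcPort = "127.0.0.1", 8080

	// Source NAT: replies must go *to* the NATed source, so the reply
	// tuple's destination is rewritten.
	snatReply := orig.reply()
	snatReply.dstAddr, snatReply.dstPort = "192.168.1.1", 3000

	fmt.Printf("original:   %+v\n", orig)
	fmt.Printf("DNAT reply: %+v\n", dnatReply)
	fmt.Printf("SNAT reply: %+v\n", snatReply)
}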
b := pkt.NetworkHeader().Push(fwdTestNetHeaderLen) @@ -125,11 +125,11 @@ func (f *fwdTestNetworkEndpoint) WritePacket(r *Route, gso *GSO, params NetworkH b[srcAddrOffset] = r.LocalAddress()[0] b[protocolNumberOffset] = byte(params.Protocol) - return f.nic.WritePacket(r, gso, fwdTestNetNumber, pkt) + return f.nic.WritePacket(r, fwdTestNetNumber, pkt) } // WritePackets implements LinkEndpoint.WritePackets. -func (*fwdTestNetworkEndpoint) WritePackets(r *Route, gso *GSO, pkts PacketBufferList, params NetworkHeaderParams) (int, tcpip.Error) { +func (*fwdTestNetworkEndpoint) WritePackets(r *Route, pkts PacketBufferList, params NetworkHeaderParams) (int, tcpip.Error) { panic("not implemented") } @@ -139,7 +139,7 @@ func (f *fwdTestNetworkEndpoint) WriteHeaderIncludedPacket(r *Route, pkt *Packet return &tcpip.ErrMalformedHeader{} } - return f.nic.WritePacket(r, nil /* gso */, fwdTestNetNumber, pkt) + return f.nic.WritePacket(r, fwdTestNetNumber, pkt) } func (f *fwdTestNetworkEndpoint) Close() { @@ -264,6 +264,8 @@ type fwdTestPacketInfo struct { Pkt *PacketBuffer } +var _ LinkEndpoint = (*fwdTestLinkEndpoint)(nil) + type fwdTestLinkEndpoint struct { dispatcher NetworkDispatcher mtu uint32 @@ -306,11 +308,6 @@ func (e fwdTestLinkEndpoint) Capabilities() LinkEndpointCapabilities { return caps | CapabilityResolutionRequired } -// GSOMaxSize returns the maximum GSO packet size. -func (*fwdTestLinkEndpoint) GSOMaxSize() uint32 { - return 1 << 15 -} - // MaxHeaderLength returns the maximum size of the link layer header. Given it // doesn't have a header, it just returns 0. func (*fwdTestLinkEndpoint) MaxHeaderLength() uint16 { @@ -322,7 +319,7 @@ func (e *fwdTestLinkEndpoint) LinkAddress() tcpip.LinkAddress { return e.linkAddr } -func (e fwdTestLinkEndpoint) WritePacket(r RouteInfo, gso *GSO, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) tcpip.Error { +func (e fwdTestLinkEndpoint) WritePacket(r RouteInfo, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) tcpip.Error { p := fwdTestPacketInfo{ RemoteLinkAddress: r.RemoteLinkAddress, LocalLinkAddress: r.LocalLinkAddress, @@ -338,10 +335,10 @@ func (e fwdTestLinkEndpoint) WritePacket(r RouteInfo, gso *GSO, protocol tcpip.N } // WritePackets stores outbound packets into the channel. -func (e *fwdTestLinkEndpoint) WritePackets(r RouteInfo, gso *GSO, pkts PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, tcpip.Error) { +func (e *fwdTestLinkEndpoint) WritePackets(r RouteInfo, pkts PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, tcpip.Error) { n := 0 for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() { - e.WritePacket(r, gso, protocol, pkt) + e.WritePacket(r, protocol, pkt) n++ } diff --git a/pkg/tcpip/stack/hook_string.go b/pkg/tcpip/stack/hook_string.go new file mode 100644 index 000000000..3dc8a7b02 --- /dev/null +++ b/pkg/tcpip/stack/hook_string.go @@ -0,0 +1,41 @@ +// Copyright 2021 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at // +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +// Code generated by "stringer -type Hook ."; DO NOT EDIT. + +package stack + +import "strconv" + +func _() { + // An "invalid array index" compiler error signifies that the constant values have changed. + // Re-run the stringer command to generate them again. + var x [1]struct{} + _ = x[Prerouting-0] + _ = x[Input-1] + _ = x[Forward-2] + _ = x[Output-3] + _ = x[Postrouting-4] + _ = x[NumHooks-5] +} + +const _Hook_name = "PreroutingInputForwardOutputPostroutingNumHooks" + +var _Hook_index = [...]uint8{0, 10, 15, 22, 28, 39, 47} + +func (i Hook) String() string { + if i >= Hook(len(_Hook_index)-1) { + return "Hook(" + strconv.FormatInt(int64(i), 10) + ")" + } + return _Hook_name[_Hook_index[i]:_Hook_index[i+1]] +} diff --git a/pkg/tcpip/stack/iptables.go b/pkg/tcpip/stack/iptables.go index 52890f6eb..e2894c548 100644 --- a/pkg/tcpip/stack/iptables.go +++ b/pkg/tcpip/stack/iptables.go @@ -175,9 +175,10 @@ func DefaultTables() *IPTables { }, }, priorities: [NumHooks][]TableID{ - Prerouting: {MangleID, NATID}, - Input: {NATID, FilterID}, - Output: {MangleID, NATID, FilterID}, + Prerouting: {MangleID, NATID}, + Input: {NATID, FilterID}, + Output: {MangleID, NATID, FilterID}, + Postrouting: {MangleID, NATID}, }, connections: ConnTrack{ seed: generateRandUint32(), @@ -266,12 +267,12 @@ const ( // should continue traversing the network stack and false when it should be // dropped. // -// TODO(gvisor.dev/issue/170): PacketBuffer should hold the GSO and route, from +// TODO(gvisor.dev/issue/170): PacketBuffer should hold the route, from // which address can be gathered. Currently, address is only needed for // prerouting. // // Precondition: pkt.NetworkHeader is set. -func (it *IPTables) Check(hook Hook, pkt *PacketBuffer, gso *GSO, r *Route, preroutingAddr tcpip.Address, inNicName, outNicName string) bool { +func (it *IPTables) Check(hook Hook, pkt *PacketBuffer, r *Route, preroutingAddr tcpip.Address, inNicName, outNicName string) bool { if pkt.NetworkProtocolNumber != header.IPv4ProtocolNumber && pkt.NetworkProtocolNumber != header.IPv6ProtocolNumber { return true } @@ -285,7 +286,7 @@ func (it *IPTables) Check(hook Hook, pkt *PacketBuffer, gso *GSO, r *Route, prer // Packets are manipulated only if connection and matching // NAT rule exists. - shouldTrack := it.connections.handlePacket(pkt, hook, gso, r) + shouldTrack := it.connections.handlePacket(pkt, hook, r) // Go through each table containing the hook. priorities := it.priorities[hook] @@ -302,7 +303,7 @@ func (it *IPTables) Check(hook Hook, pkt *PacketBuffer, gso *GSO, r *Route, prer table = it.v4Tables[tableID] } ruleIdx := table.BuiltinChains[hook] - switch verdict := it.checkChain(hook, pkt, table, ruleIdx, gso, r, preroutingAddr, inNicName, outNicName); verdict { + switch verdict := it.checkChain(hook, pkt, table, ruleIdx, r, preroutingAddr, inNicName, outNicName); verdict { // If the table returns Accept, move on to the next table. case chainAccept: continue @@ -313,7 +314,7 @@ func (it *IPTables) Check(hook Hook, pkt *PacketBuffer, gso *GSO, r *Route, prer // Any Return from a built-in chain means we have to // call the underflow. 
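hook_string.go above is stringer output: all names are packed into one string with an index table, so String() is a cheap slice with no per-value allocation. The same pattern applied to a hypothetical Verdict enum, purely as an illustration:

// Sketch of the stringer pattern used by hook_string.go, applied to a
// hypothetical Verdict type: one concatenated name string plus an index
// table mapping each value to its name's byte range.
package main

import (
	"fmt"
	"strconv"
)

type Verdict int

const (
	Accept Verdict = iota
	Drop
	Return
)

const _Verdict_name = "AcceptDropReturn"

var _Verdict_index = [...]uint8{0, 6, 10, 16}

func (v Verdict) String() string {
	if v < 0 || int(v) >= len(_Verdict_index)-1 {
		return "Verdict(" + strconv.FormatInt(int64(v), 10) + ")"
	}
	return _Verdict_name[_Verdict_index[v]:_Verdict_index[v+1]]
}

func main() {
	fmt.Println(Accept, Drop, Return, Verdict(7)) // Accept Drop Return Verdict(7)
}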
underflow := table.Rules[table.Underflows[hook]] - switch v, _ := underflow.Target.Action(pkt, &it.connections, hook, gso, r, preroutingAddr); v { + switch v, _ := underflow.Target.Action(pkt, &it.connections, hook, r, preroutingAddr); v { case RuleAccept: continue case RuleDrop: @@ -385,10 +386,10 @@ func (it *IPTables) startReaper(interval time.Duration) { // // NOTE: unlike the Check API the returned map contains packets that should be // dropped. -func (it *IPTables) CheckPackets(hook Hook, pkts PacketBufferList, gso *GSO, r *Route, inNicName, outNicName string) (drop map[*PacketBuffer]struct{}, natPkts map[*PacketBuffer]struct{}) { +func (it *IPTables) CheckPackets(hook Hook, pkts PacketBufferList, r *Route, inNicName, outNicName string) (drop map[*PacketBuffer]struct{}, natPkts map[*PacketBuffer]struct{}) { for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() { if !pkt.NatDone { - if ok := it.Check(hook, pkt, gso, r, "", inNicName, outNicName); !ok { + if ok := it.Check(hook, pkt, r, "", inNicName, outNicName); !ok { if drop == nil { drop = make(map[*PacketBuffer]struct{}) } @@ -408,11 +409,11 @@ func (it *IPTables) CheckPackets(hook Hook, pkts PacketBufferList, gso *GSO, r * // Preconditions: // * pkt is a IPv4 packet of at least length header.IPv4MinimumSize. // * pkt.NetworkHeader is not nil. -func (it *IPTables) checkChain(hook Hook, pkt *PacketBuffer, table Table, ruleIdx int, gso *GSO, r *Route, preroutingAddr tcpip.Address, inNicName, outNicName string) chainVerdict { +func (it *IPTables) checkChain(hook Hook, pkt *PacketBuffer, table Table, ruleIdx int, r *Route, preroutingAddr tcpip.Address, inNicName, outNicName string) chainVerdict { // Start from ruleIdx and walk the list of rules until a rule gives us // a verdict. for ruleIdx < len(table.Rules) { - switch verdict, jumpTo := it.checkRule(hook, pkt, table, ruleIdx, gso, r, preroutingAddr, inNicName, outNicName); verdict { + switch verdict, jumpTo := it.checkRule(hook, pkt, table, ruleIdx, r, preroutingAddr, inNicName, outNicName); verdict { case RuleAccept: return chainAccept @@ -429,7 +430,7 @@ func (it *IPTables) checkChain(hook Hook, pkt *PacketBuffer, table Table, ruleId ruleIdx++ continue } - switch verdict := it.checkChain(hook, pkt, table, jumpTo, gso, r, preroutingAddr, inNicName, outNicName); verdict { + switch verdict := it.checkChain(hook, pkt, table, jumpTo, r, preroutingAddr, inNicName, outNicName); verdict { case chainAccept: return chainAccept case chainDrop: @@ -455,7 +456,7 @@ func (it *IPTables) checkChain(hook Hook, pkt *PacketBuffer, table Table, ruleId // Preconditions: // * pkt is a IPv4 packet of at least length header.IPv4MinimumSize. // * pkt.NetworkHeader is not nil. -func (it *IPTables) checkRule(hook Hook, pkt *PacketBuffer, table Table, ruleIdx int, gso *GSO, r *Route, preroutingAddr tcpip.Address, inNicName, outNicName string) (RuleVerdict, int) { +func (it *IPTables) checkRule(hook Hook, pkt *PacketBuffer, table Table, ruleIdx int, r *Route, preroutingAddr tcpip.Address, inNicName, outNicName string) (RuleVerdict, int) { rule := table.Rules[ruleIdx] // Check whether the packet matches the IP header filter. @@ -478,7 +479,7 @@ func (it *IPTables) checkRule(hook Hook, pkt *PacketBuffer, table Table, ruleIdx } // All the matchers matched, so run the target. - return rule.Target.Action(pkt, &it.connections, hook, gso, r, preroutingAddr) + return rule.Target.Action(pkt, &it.connections, hook, r, preroutingAddr) } // OriginalDst returns the original destination of redirected connections. 
It diff --git a/pkg/tcpip/stack/iptables_targets.go b/pkg/tcpip/stack/iptables_targets.go index 0e8b90c9b..2812c89aa 100644 --- a/pkg/tcpip/stack/iptables_targets.go +++ b/pkg/tcpip/stack/iptables_targets.go @@ -29,7 +29,7 @@ type AcceptTarget struct { } // Action implements Target.Action. -func (*AcceptTarget) Action(*PacketBuffer, *ConnTrack, Hook, *GSO, *Route, tcpip.Address) (RuleVerdict, int) { +func (*AcceptTarget) Action(*PacketBuffer, *ConnTrack, Hook, *Route, tcpip.Address) (RuleVerdict, int) { return RuleAccept, 0 } @@ -40,7 +40,7 @@ type DropTarget struct { } // Action implements Target.Action. -func (*DropTarget) Action(*PacketBuffer, *ConnTrack, Hook, *GSO, *Route, tcpip.Address) (RuleVerdict, int) { +func (*DropTarget) Action(*PacketBuffer, *ConnTrack, Hook, *Route, tcpip.Address) (RuleVerdict, int) { return RuleDrop, 0 } @@ -52,7 +52,7 @@ type ErrorTarget struct { } // Action implements Target.Action. -func (*ErrorTarget) Action(*PacketBuffer, *ConnTrack, Hook, *GSO, *Route, tcpip.Address) (RuleVerdict, int) { +func (*ErrorTarget) Action(*PacketBuffer, *ConnTrack, Hook, *Route, tcpip.Address) (RuleVerdict, int) { log.Debugf("ErrorTarget triggered.") return RuleDrop, 0 } @@ -67,7 +67,7 @@ type UserChainTarget struct { } // Action implements Target.Action. -func (*UserChainTarget) Action(*PacketBuffer, *ConnTrack, Hook, *GSO, *Route, tcpip.Address) (RuleVerdict, int) { +func (*UserChainTarget) Action(*PacketBuffer, *ConnTrack, Hook, *Route, tcpip.Address) (RuleVerdict, int) { panic("UserChainTarget should never be called.") } @@ -79,7 +79,7 @@ type ReturnTarget struct { } // Action implements Target.Action. -func (*ReturnTarget) Action(*PacketBuffer, *ConnTrack, Hook, *GSO, *Route, tcpip.Address) (RuleVerdict, int) { +func (*ReturnTarget) Action(*PacketBuffer, *ConnTrack, Hook, *Route, tcpip.Address) (RuleVerdict, int) { return RuleReturn, 0 } @@ -103,7 +103,7 @@ type RedirectTarget struct { // TODO(gvisor.dev/issue/170): Parse headers without copying. The current // implementation only works for Prerouting and calls pkt.Clone(), neither // of which should be the case. -func (rt *RedirectTarget) Action(pkt *PacketBuffer, ct *ConnTrack, hook Hook, gso *GSO, r *Route, address tcpip.Address) (RuleVerdict, int) { +func (rt *RedirectTarget) Action(pkt *PacketBuffer, ct *ConnTrack, hook Hook, r *Route, address tcpip.Address) (RuleVerdict, int) { // Sanity check. if rt.NetworkProtocol != pkt.NetworkProtocolNumber { panic(fmt.Sprintf( @@ -174,7 +174,85 @@ func (rt *RedirectTarget) Action(pkt *PacketBuffer, ct *ConnTrack, hook Hook, gs // packet of the connection comes here. Other packets will be // manipulated in connection tracking. if conn := ct.insertRedirectConn(pkt, hook, rt.Port, address); conn != nil { - ct.handlePacket(pkt, hook, gso, r) + ct.handlePacket(pkt, hook, r) + } + default: + return RuleDrop, 0 + } + + return RuleAccept, 0 +} + +// SNATTarget modifies the source port/IP in the outgoing packets. +type SNATTarget struct { + Addr tcpip.Address + Port uint16 + + // NetworkProtocol is the network protocol the target is used with. It + // is immutable. + NetworkProtocol tcpip.NetworkProtocolNumber +} + +// Action implements Target.Action. +func (st *SNATTarget) Action(pkt *PacketBuffer, ct *ConnTrack, hook Hook, r *Route, address tcpip.Address) (RuleVerdict, int) { + // Sanity check. 
+ if st.NetworkProtocol != pkt.NetworkProtocolNumber { + panic(fmt.Sprintf( + "SNATTarget.Action with NetworkProtocol %d called on packet with NetworkProtocolNumber %d", + st.NetworkProtocol, pkt.NetworkProtocolNumber)) + } + + // Packet is already manipulated. + if pkt.NatDone { + return RuleAccept, 0 + } + + // Drop the packet if network and transport header are not set. + if pkt.NetworkHeader().View().IsEmpty() || pkt.TransportHeader().View().IsEmpty() { + return RuleDrop, 0 + } + + switch hook { + case Postrouting, Input: + case Prerouting, Output, Forward: + panic(fmt.Sprintf("%s not supported", hook)) + default: + panic(fmt.Sprintf("%s unrecognized", hook)) + } + + switch protocol := pkt.TransportProtocolNumber; protocol { + case header.UDPProtocolNumber: + udpHeader := header.UDP(pkt.TransportHeader().View()) + udpHeader.SetChecksum(0) + udpHeader.SetSourcePort(st.Port) + netHeader := pkt.Network() + netHeader.SetSourceAddress(st.Addr) + + // Only calculate the checksum if offloading isn't supported. + if r.RequiresTXTransportChecksum() { + length := uint16(pkt.Size()) - uint16(len(pkt.NetworkHeader().View())) + xsum := header.PseudoHeaderChecksum(protocol, netHeader.SourceAddress(), netHeader.DestinationAddress(), length) + xsum = header.ChecksumCombine(xsum, pkt.Data().AsRange().Checksum()) + udpHeader.SetChecksum(^udpHeader.CalculateChecksum(xsum)) + } + + // After modification, IPv4 packets need a valid checksum. + if pkt.NetworkProtocolNumber == header.IPv4ProtocolNumber { + netHeader := header.IPv4(pkt.NetworkHeader().View()) + netHeader.SetChecksum(0) + netHeader.SetChecksum(^netHeader.CalculateChecksum()) + } + pkt.NatDone = true + case header.TCPProtocolNumber: + if ct == nil { + return RuleAccept, 0 + } + + // Set up conection for matching NAT rule. Only the first + // packet of the connection comes here. Other packets will be + // manipulated in connection tracking. + if conn := ct.insertSNATConn(pkt, hook, st.Port, st.Addr); conn != nil { + ct.handlePacket(pkt, hook, r) } default: return RuleDrop, 0 diff --git a/pkg/tcpip/stack/iptables_types.go b/pkg/tcpip/stack/iptables_types.go index b0d84befb..4631ab93f 100644 --- a/pkg/tcpip/stack/iptables_types.go +++ b/pkg/tcpip/stack/iptables_types.go @@ -345,5 +345,5 @@ type Target interface { // Action takes an action on the packet and returns a verdict on how // traversal should (or should not) continue. If the return value is // Jump, it also returns the index of the rule to jump to. 
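With the GSO argument gone, Target.Action is down to the packet, the conntrack table, the hook, the route, and the prerouting address, and it still returns a verdict plus a jump index that only matters for Jump. A hedged sketch of a custom target against a simplified stand-in interface (the types below are illustrative, not the stack package's):

// Sketch of the slimmed-down target shape after the GSO parameter removal,
// with simplified stand-in types.
package main

import "fmt"

type ruleVerdict int

const (
	ruleAccept ruleVerdict = iota
	ruleDrop
	ruleJump
)

type packetBuffer struct{ dstPort uint16 }

type target interface {
	// action returns a verdict and, when the verdict is ruleJump, the
	// index of the rule to jump to.
	action(pkt *packetBuffer) (ruleVerdict, int)
}

type acceptTarget struct{}

func (acceptTarget) action(*packetBuffer) (ruleVerdict, int) { return ruleAccept, 0 }

// dropPortTarget drops packets destined to a single port.
type dropPortTarget struct{ port uint16 }

func (t dropPortTarget) action(pkt *packetBuffer) (ruleVerdict, int) {
	if pkt.dstPort == t.port {
		return ruleDrop, 0
	}
	return ruleAccept, 0
}

func main() {
	targets := []target{dropPortTarget{port: 23}, acceptTarget{}}
	pkt := &packetBuffer{dstPort: 23}
	for i, t := range targets {
		v, _ := t.action(pkt)
		fmt.Printf("rule %d verdict: %d\n", i, v)
	}
}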
- Action(packet *PacketBuffer, connections *ConnTrack, hook Hook, gso *GSO, r *Route, address tcpip.Address) (RuleVerdict, int) + Action(*PacketBuffer, *ConnTrack, Hook, *Route, tcpip.Address) (RuleVerdict, int) } diff --git a/pkg/tcpip/stack/ndp_test.go b/pkg/tcpip/stack/ndp_test.go index 14124ae66..c585b81b2 100644 --- a/pkg/tcpip/stack/ndp_test.go +++ b/pkg/tcpip/stack/ndp_test.go @@ -33,15 +33,19 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/link/loopback" "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" "gvisor.dev/gvisor/pkg/tcpip/stack" + "gvisor.dev/gvisor/pkg/tcpip/testutil" "gvisor.dev/gvisor/pkg/tcpip/transport/icmp" "gvisor.dev/gvisor/pkg/tcpip/transport/udp" "gvisor.dev/gvisor/pkg/waiter" ) +var ( + addr1 = testutil.MustParse6("a00::1") + addr2 = testutil.MustParse6("a00::2") + addr3 = testutil.MustParse6("a00::3") +) + const ( - addr1 = tcpip.Address("\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01") - addr2 = tcpip.Address("\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02") - addr3 = tcpip.Address("\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x03") linkAddr1 = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x06") linkAddr2 = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x07") linkAddr3 = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x08") @@ -1142,57 +1146,198 @@ func raBufWithPI(ip tcpip.Address, rl uint16, prefix tcpip.AddressWithPrefix, on }) } -// TestNoRouterDiscovery tests that router discovery will not be performed if -// configured not to. -func TestNoRouterDiscovery(t *testing.T) { - // Being configured to discover routers means handle and - // discover are set to true and forwarding is set to false. - // This tests all possible combinations of the configurations, - // except for the configuration where handle = true, discover = - // true and forwarding = false (the required configuration to do - // router discovery) - that will done in other tests. - for i := 0; i < 7; i++ { - handle := i&1 != 0 - discover := i&2 != 0 - forwarding := i&4 == 0 - - t.Run(fmt.Sprintf("HandleRAs(%t), DiscoverDefaultRouters(%t), Forwarding(%t)", handle, discover, forwarding), func(t *testing.T) { - ndpDisp := ndpDispatcher{ - routerC: make(chan ndpRouterEvent, 1), - } - e := channel.New(0, 1280, linkAddr1) - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ - NDPConfigs: ipv6.NDPConfigurations{ - HandleRAs: handle, - DiscoverDefaultRouters: discover, - }, - NDPDisp: &ndpDisp, - })}, - }) - s.SetForwarding(ipv6.ProtocolNumber, forwarding) +func TestDynamicConfigurationsDisabled(t *testing.T) { + const ( + nicID = 1 + maxRtrSolicitDelay = time.Second + ) - if err := s.CreateNIC(1, e); err != nil { - t.Fatalf("CreateNIC(1) = %s", err) - } + prefix := tcpip.AddressWithPrefix{ + Address: testutil.MustParse6("102:304:506:708::"), + PrefixLen: 64, + } - // Rx an RA with non-zero lifetime. 
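The hand-encoded 16-byte address literals are replaced by testutil.MustParse6 throughout the tests. A sketch of the general shape of such a helper, assuming it wraps net.ParseIP and panics on malformed input; this is an illustration, not gVisor's implementation:

// Illustrative MustParse6-style helper: parse an IPv6 literal once at
// package init time and panic on bad input, so tests can declare addresses
// as package-level variables.
package example

import (
	"fmt"
	"net"
)

func mustParse6(s string) net.IP {
	ip := net.ParseIP(s)
	if ip == nil || ip.To4() != nil {
		panic(fmt.Sprintf("mustParse6(%q): not a valid IPv6 address", s))
	}
	return ip.To16()
}

var (
	addrA = mustParse6("a00::1")
	addrB = mustParse6("a00::2")
)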
- e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr2, 1000)) - select { - case <-ndpDisp.routerC: - t.Fatal("unexpectedly discovered a router when configured not to") - default: + tests := []struct { + name string + config func(bool) ipv6.NDPConfigurations + ra *stack.PacketBuffer + }{ + { + name: "No Router Discovery", + config: func(enable bool) ipv6.NDPConfigurations { + return ipv6.NDPConfigurations{DiscoverDefaultRouters: enable} + }, + ra: raBuf(llAddr2, 1000), + }, + { + name: "No Prefix Discovery", + config: func(enable bool) ipv6.NDPConfigurations { + return ipv6.NDPConfigurations{DiscoverOnLinkPrefixes: enable} + }, + ra: raBufWithPI(llAddr2, 0, prefix, true, false, 10, 0), + }, + { + name: "No Autogenerate Addresses", + config: func(enable bool) ipv6.NDPConfigurations { + return ipv6.NDPConfigurations{AutoGenGlobalAddresses: enable} + }, + ra: raBufWithPI(llAddr2, 0, prefix, false, true, 10, 0), + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + // Being configured to discover routers/prefixes or auto-generate + // addresses means RAs must be handled, and router/prefix discovery or + // SLAAC must be enabled. + // + // This tests all possible combinations of the configurations where + // router/prefix discovery or SLAAC are disabled. + for i := 0; i < 7; i++ { + handle := ipv6.HandlingRAsDisabled + if i&1 != 0 { + handle = ipv6.HandlingRAsEnabledWhenForwardingDisabled + } + enable := i&2 != 0 + forwarding := i&4 == 0 + + t.Run(fmt.Sprintf("HandleRAs(%s), Forwarding(%t), Enabled(%t)", handle, forwarding, enable), func(t *testing.T) { + ndpDisp := ndpDispatcher{ + routerC: make(chan ndpRouterEvent, 1), + prefixC: make(chan ndpPrefixEvent, 1), + autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1), + } + ndpConfigs := test.config(enable) + ndpConfigs.HandleRAs = handle + ndpConfigs.MaxRtrSolicitations = 1 + ndpConfigs.RtrSolicitationInterval = maxRtrSolicitDelay + ndpConfigs.MaxRtrSolicitationDelay = maxRtrSolicitDelay + clock := faketime.NewManualClock() + s := stack.New(stack.Options{ + Clock: clock, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ + NDPConfigs: ndpConfigs, + NDPDisp: &ndpDisp, + })}, + }) + if err := s.SetForwarding(ipv6.ProtocolNumber, forwarding); err != nil { + t.Fatalf("SetForwarding(%d, %t): %s", ipv6.ProtocolNumber, forwarding, err) + } + + e := channel.New(1, 1280, linkAddr1) + if err := s.CreateNIC(nicID, e); err != nil { + t.Fatalf("CreateNIC(%d, _): %s", nicID, err) + } + + handleRAsDisabled := handle == ipv6.HandlingRAsDisabled || forwarding + ep, err := s.GetNetworkEndpoint(nicID, ipv6.ProtocolNumber) + if err != nil { + t.Fatalf("s.GetNetworkEndpoint(%d, %d): %s", nicID, ipv6.ProtocolNumber, err) + } + stats := ep.Stats() + v6Stats, ok := stats.(*ipv6.Stats) + if !ok { + t.Fatalf("got v6Stats = %T, expected = %T", stats, v6Stats) + } + + // Make sure that when handling RAs are enabled, we solicit routers. 
+ clock.Advance(maxRtrSolicitDelay) + if got, want := v6Stats.ICMP.PacketsSent.RouterSolicit.Value(), boolToUint64(!handleRAsDisabled); got != want { + t.Errorf("got v6Stats.ICMP.PacketsSent.RouterSolicit.Value() = %d, want = %d", got, want) + } + if handleRAsDisabled { + if p, ok := e.Read(); ok { + t.Errorf("unexpectedly got a packet = %#v", p) + } + } else if p, ok := e.Read(); !ok { + t.Error("expected router solicitation packet") + } else if p.Proto != header.IPv6ProtocolNumber { + t.Errorf("got Proto = %d, want = %d", p.Proto, header.IPv6ProtocolNumber) + } else { + if want := header.EthernetAddressFromMulticastIPv6Address(header.IPv6AllRoutersLinkLocalMulticastAddress); p.Route.RemoteLinkAddress != want { + t.Errorf("got remote link address = %s, want = %s", p.Route.RemoteLinkAddress, want) + } + + checker.IPv6(t, stack.PayloadSince(p.Pkt.NetworkHeader()), + checker.SrcAddr(header.IPv6Any), + checker.DstAddr(header.IPv6AllRoutersLinkLocalMulticastAddress), + checker.TTL(header.NDPHopLimit), + checker.NDPRS(checker.NDPRSOptions(nil)), + ) + } + + // Make sure we do not discover any routers or prefixes, or perform + // SLAAC on reception of an RA. + e.InjectInbound(header.IPv6ProtocolNumber, test.ra.Clone()) + // Make sure that the unhandled RA stat is only incremented when + // handling RAs is disabled. + if got, want := v6Stats.UnhandledRouterAdvertisements.Value(), boolToUint64(handleRAsDisabled); got != want { + t.Errorf("got v6Stats.UnhandledRouterAdvertisements.Value() = %d, want = %d", got, want) + } + select { + case e := <-ndpDisp.routerC: + t.Errorf("unexpectedly discovered a router when configured not to: %#v", e) + default: + } + select { + case e := <-ndpDisp.prefixC: + t.Errorf("unexpectedly discovered a prefix when configured not to: %#v", e) + default: + } + select { + case e := <-ndpDisp.autoGenAddrC: + t.Errorf("unexpectedly auto-generated an address when configured not to: %#v", e) + default: + } + }) } }) } } +func boolToUint64(v bool) uint64 { + if v { + return 1 + } + return 0 +} + // Check e to make sure that the event is for addr on nic with ID 1, and the // discovered flag set to discovered. func checkRouterEvent(e ndpRouterEvent, addr tcpip.Address, discovered bool) string { return cmp.Diff(ndpRouterEvent{nicID: 1, addr: addr, discovered: discovered}, e, cmp.AllowUnexported(e)) } +func testWithRAs(t *testing.T, f func(*testing.T, ipv6.HandleRAsConfiguration, bool)) { + tests := [...]struct { + name string + handleRAs ipv6.HandleRAsConfiguration + forwarding bool + }{ + { + name: "Handle RAs when forwarding disabled", + handleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled, + forwarding: false, + }, + { + name: "Always Handle RAs with forwarding disabled", + handleRAs: ipv6.HandlingRAsAlwaysEnabled, + forwarding: false, + }, + { + name: "Always Handle RAs with forwarding enabled", + handleRAs: ipv6.HandlingRAsAlwaysEnabled, + forwarding: true, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + f(t, test.handleRAs, test.forwarding) + }) + } +} + // TestRouterDiscoveryDispatcherNoRemember tests that the stack does not // remember a discovered router when the dispatcher asks it not to. 
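testWithRAs above runs a single test body as subtests under each HandleRAs/forwarding combination, which is what lets the old hand-rolled "No Router Discovery" style tests collapse into one table. A minimal sketch of that pattern with illustrative configuration values (the string field below is a stand-in, not ipv6.HandleRAsConfiguration):

// Minimal table-driven subtest helper in the style of testWithRAs, with
// illustrative configuration values.
package example

import "testing"

type raConfig struct {
	name       string
	handleRAs  string // stand-in for the tri-state HandleRAs setting
	forwarding bool
}

func testWithConfigs(t *testing.T, f func(t *testing.T, handleRAs string, forwarding bool)) {
	configs := []raConfig{
		{name: "enabled when forwarding disabled", handleRAs: "when-forwarding-disabled", forwarding: false},
		{name: "always enabled, forwarding disabled", handleRAs: "always", forwarding: false},
		{name: "always enabled, forwarding enabled", handleRAs: "always", forwarding: true},
	}
	for _, cfg := range configs {
		cfg := cfg
		t.Run(cfg.name, func(t *testing.T) {
			f(t, cfg.handleRAs, cfg.forwarding)
		})
	}
}

// Example use:
//
//	func TestSomething(t *testing.T) {
//		testWithConfigs(t, func(t *testing.T, handleRAs string, forwarding bool) {
//			// Build the stack with these settings and exercise it.
//		})
//	}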
func TestRouterDiscoveryDispatcherNoRemember(t *testing.T) { @@ -1203,7 +1348,7 @@ func TestRouterDiscoveryDispatcherNoRemember(t *testing.T) { s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ NDPConfigs: ipv6.NDPConfigurations{ - HandleRAs: true, + HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled, DiscoverDefaultRouters: true, }, NDPDisp: &ndpDisp, @@ -1237,103 +1382,109 @@ func TestRouterDiscoveryDispatcherNoRemember(t *testing.T) { } func TestRouterDiscovery(t *testing.T) { - ndpDisp := ndpDispatcher{ - routerC: make(chan ndpRouterEvent, 1), - rememberRouter: true, - } - e := channel.New(0, 1280, linkAddr1) - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ - NDPConfigs: ipv6.NDPConfigurations{ - HandleRAs: true, - DiscoverDefaultRouters: true, - }, - NDPDisp: &ndpDisp, - })}, - }) + testWithRAs(t, func(t *testing.T, handleRAs ipv6.HandleRAsConfiguration, forwarding bool) { + ndpDisp := ndpDispatcher{ + routerC: make(chan ndpRouterEvent, 1), + rememberRouter: true, + } + e := channel.New(0, 1280, linkAddr1) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ + NDPConfigs: ipv6.NDPConfigurations{ + HandleRAs: handleRAs, + DiscoverDefaultRouters: true, + }, + NDPDisp: &ndpDisp, + })}, + }) - expectRouterEvent := func(addr tcpip.Address, discovered bool) { - t.Helper() + expectRouterEvent := func(addr tcpip.Address, discovered bool) { + t.Helper() - select { - case e := <-ndpDisp.routerC: - if diff := checkRouterEvent(e, addr, discovered); diff != "" { - t.Errorf("router event mismatch (-want +got):\n%s", diff) + select { + case e := <-ndpDisp.routerC: + if diff := checkRouterEvent(e, addr, discovered); diff != "" { + t.Errorf("router event mismatch (-want +got):\n%s", diff) + } + default: + t.Fatal("expected router discovery event") } - default: - t.Fatal("expected router discovery event") } - } - expectAsyncRouterInvalidationEvent := func(addr tcpip.Address, timeout time.Duration) { - t.Helper() + expectAsyncRouterInvalidationEvent := func(addr tcpip.Address, timeout time.Duration) { + t.Helper() - select { - case e := <-ndpDisp.routerC: - if diff := checkRouterEvent(e, addr, false); diff != "" { - t.Errorf("router event mismatch (-want +got):\n%s", diff) + select { + case e := <-ndpDisp.routerC: + if diff := checkRouterEvent(e, addr, false); diff != "" { + t.Errorf("router event mismatch (-want +got):\n%s", diff) + } + case <-time.After(timeout): + t.Fatal("timed out waiting for router discovery event") } - case <-time.After(timeout): - t.Fatal("timed out waiting for router discovery event") } - } - - if err := s.CreateNIC(1, e); err != nil { - t.Fatalf("CreateNIC(1) = %s", err) - } - // Rx an RA from lladdr2 with zero lifetime. It should not be - // remembered. - e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr2, 0)) - select { - case <-ndpDisp.routerC: - t.Fatal("unexpectedly discovered a router with 0 lifetime") - default: - } - - // Rx an RA from lladdr2 with a huge lifetime. - e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr2, 1000)) - expectRouterEvent(llAddr2, true) + if err := s.SetForwarding(ipv6.ProtocolNumber, forwarding); err != nil { + t.Fatalf("SetForwarding(%d, %t): %s", ipv6.ProtocolNumber, forwarding, err) + } - // Rx an RA from another router (lladdr3) with non-zero lifetime. 
- const l3LifetimeSeconds = 6 - e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr3, l3LifetimeSeconds)) - expectRouterEvent(llAddr3, true) + if err := s.CreateNIC(1, e); err != nil { + t.Fatalf("CreateNIC(1) = %s", err) + } - // Rx an RA from lladdr2 with lesser lifetime. - const l2LifetimeSeconds = 2 - e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr2, l2LifetimeSeconds)) - select { - case <-ndpDisp.routerC: - t.Fatal("Should not receive a router event when updating lifetimes for known routers") - default: - } + // Rx an RA from lladdr2 with zero lifetime. It should not be + // remembered. + e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr2, 0)) + select { + case <-ndpDisp.routerC: + t.Fatal("unexpectedly discovered a router with 0 lifetime") + default: + } - // Wait for lladdr2's router invalidation job to execute. The lifetime - // of the router should have been updated to the most recent (smaller) - // lifetime. - // - // Wait for the normal lifetime plus an extra bit for the - // router to get invalidated. If we don't get an invalidation - // event after this time, then something is wrong. - expectAsyncRouterInvalidationEvent(llAddr2, l2LifetimeSeconds*time.Second+defaultAsyncPositiveEventTimeout) + // Rx an RA from lladdr2 with a huge lifetime. + e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr2, 1000)) + expectRouterEvent(llAddr2, true) - // Rx an RA from lladdr2 with huge lifetime. - e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr2, 1000)) - expectRouterEvent(llAddr2, true) + // Rx an RA from another router (lladdr3) with non-zero lifetime. + const l3LifetimeSeconds = 6 + e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr3, l3LifetimeSeconds)) + expectRouterEvent(llAddr3, true) - // Rx an RA from lladdr2 with zero lifetime. It should be invalidated. - e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr2, 0)) - expectRouterEvent(llAddr2, false) + // Rx an RA from lladdr2 with lesser lifetime. + const l2LifetimeSeconds = 2 + e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr2, l2LifetimeSeconds)) + select { + case <-ndpDisp.routerC: + t.Fatal("Should not receive a router event when updating lifetimes for known routers") + default: + } - // Wait for lladdr3's router invalidation job to execute. The lifetime - // of the router should have been updated to the most recent (smaller) - // lifetime. - // - // Wait for the normal lifetime plus an extra bit for the - // router to get invalidated. If we don't get an invalidation - // event after this time, then something is wrong. - expectAsyncRouterInvalidationEvent(llAddr3, l3LifetimeSeconds*time.Second+defaultAsyncPositiveEventTimeout) + // Wait for lladdr2's router invalidation job to execute. The lifetime + // of the router should have been updated to the most recent (smaller) + // lifetime. + // + // Wait for the normal lifetime plus an extra bit for the + // router to get invalidated. If we don't get an invalidation + // event after this time, then something is wrong. + expectAsyncRouterInvalidationEvent(llAddr2, l2LifetimeSeconds*time.Second+defaultAsyncPositiveEventTimeout) + + // Rx an RA from lladdr2 with huge lifetime. + e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr2, 1000)) + expectRouterEvent(llAddr2, true) + + // Rx an RA from lladdr2 with zero lifetime. It should be invalidated. + e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr2, 0)) + expectRouterEvent(llAddr2, false) + + // Wait for lladdr3's router invalidation job to execute. 
The lifetime + // of the router should have been updated to the most recent (smaller) + // lifetime. + // + // Wait for the normal lifetime plus an extra bit for the + // router to get invalidated. If we don't get an invalidation + // event after this time, then something is wrong. + expectAsyncRouterInvalidationEvent(llAddr3, l3LifetimeSeconds*time.Second+defaultAsyncPositiveEventTimeout) + }) } // TestRouterDiscoveryMaxRouters tests that only @@ -1347,7 +1498,7 @@ func TestRouterDiscoveryMaxRouters(t *testing.T) { s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ NDPConfigs: ipv6.NDPConfigurations{ - HandleRAs: true, + HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled, DiscoverDefaultRouters: true, }, NDPDisp: &ndpDisp, @@ -1386,57 +1537,6 @@ func TestRouterDiscoveryMaxRouters(t *testing.T) { } } -// TestNoPrefixDiscovery tests that prefix discovery will not be performed if -// configured not to. -func TestNoPrefixDiscovery(t *testing.T) { - prefix := tcpip.AddressWithPrefix{ - Address: tcpip.Address("\x01\x02\x03\x04\x05\x06\x07\x08\x00\x00\x00\x00\x00\x00\x00\x00"), - PrefixLen: 64, - } - - // Being configured to discover prefixes means handle and - // discover are set to true and forwarding is set to false. - // This tests all possible combinations of the configurations, - // except for the configuration where handle = true, discover = - // true and forwarding = false (the required configuration to do - // prefix discovery) - that will done in other tests. - for i := 0; i < 7; i++ { - handle := i&1 != 0 - discover := i&2 != 0 - forwarding := i&4 == 0 - - t.Run(fmt.Sprintf("HandleRAs(%t), DiscoverOnLinkPrefixes(%t), Forwarding(%t)", handle, discover, forwarding), func(t *testing.T) { - ndpDisp := ndpDispatcher{ - prefixC: make(chan ndpPrefixEvent, 1), - } - e := channel.New(0, 1280, linkAddr1) - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ - NDPConfigs: ipv6.NDPConfigurations{ - HandleRAs: handle, - DiscoverOnLinkPrefixes: discover, - }, - NDPDisp: &ndpDisp, - })}, - }) - s.SetForwarding(ipv6.ProtocolNumber, forwarding) - - if err := s.CreateNIC(1, e); err != nil { - t.Fatalf("CreateNIC(1) = %s", err) - } - - // Rx an RA with prefix with non-zero lifetime. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, false, 10, 0)) - - select { - case <-ndpDisp.prefixC: - t.Fatal("unexpectedly discovered a prefix when configured not to") - default: - } - }) - } -} - // Check e to make sure that the event is for prefix on nic with ID 1, and the // discovered flag set to discovered. 
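The check*Event helpers compare event structs with go-cmp, passing cmp.AllowUnexported because the event fields are unexported. A small self-contained example of that comparison style with an illustrative event type:

// Sketch of the cmp.Diff/AllowUnexported comparison used by the check*Event
// helpers, with an illustrative event type (not the ndpDispatcher's real one).
package example

import (
	"testing"

	"github.com/google/go-cmp/cmp"
)

type routerEvent struct {
	nicID      int
	addr       string
	discovered bool
}

// checkEvent returns a non-empty diff when got does not match the expected
// event; cmp.AllowUnexported is required because the fields are unexported.
func checkEvent(got routerEvent, nicID int, addr string, discovered bool) string {
	want := routerEvent{nicID: nicID, addr: addr, discovered: discovered}
	return cmp.Diff(want, got, cmp.AllowUnexported(routerEvent{}))
}

func TestCheckEvent(t *testing.T) {
	got := routerEvent{nicID: 1, addr: "fe80::1", discovered: true}
	if diff := checkEvent(got, 1, "fe80::1", true); diff != "" {
		t.Errorf("event mismatch (-want +got):\n%s", diff)
	}
}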
func checkPrefixEvent(e ndpPrefixEvent, prefix tcpip.Subnet, discovered bool) string { @@ -1455,8 +1555,7 @@ func TestPrefixDiscoveryDispatcherNoRemember(t *testing.T) { s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ NDPConfigs: ipv6.NDPConfigurations{ - HandleRAs: true, - DiscoverDefaultRouters: false, + HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled, DiscoverOnLinkPrefixes: true, }, NDPDisp: &ndpDisp, @@ -1494,87 +1593,93 @@ func TestPrefixDiscovery(t *testing.T) { prefix2, subnet2, _ := prefixSubnetAddr(1, "") prefix3, subnet3, _ := prefixSubnetAddr(2, "") - ndpDisp := ndpDispatcher{ - prefixC: make(chan ndpPrefixEvent, 1), - rememberPrefix: true, - } - e := channel.New(0, 1280, linkAddr1) - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ - NDPConfigs: ipv6.NDPConfigurations{ - HandleRAs: true, - DiscoverOnLinkPrefixes: true, - }, - NDPDisp: &ndpDisp, - })}, - }) + testWithRAs(t, func(t *testing.T, handleRAs ipv6.HandleRAsConfiguration, forwarding bool) { + ndpDisp := ndpDispatcher{ + prefixC: make(chan ndpPrefixEvent, 1), + rememberPrefix: true, + } + e := channel.New(0, 1280, linkAddr1) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ + NDPConfigs: ipv6.NDPConfigurations{ + HandleRAs: handleRAs, + DiscoverOnLinkPrefixes: true, + }, + NDPDisp: &ndpDisp, + })}, + }) - if err := s.CreateNIC(1, e); err != nil { - t.Fatalf("CreateNIC(1) = %s", err) - } + if err := s.CreateNIC(1, e); err != nil { + t.Fatalf("CreateNIC(1) = %s", err) + } - expectPrefixEvent := func(prefix tcpip.Subnet, discovered bool) { - t.Helper() + expectPrefixEvent := func(prefix tcpip.Subnet, discovered bool) { + t.Helper() - select { - case e := <-ndpDisp.prefixC: - if diff := checkPrefixEvent(e, prefix, discovered); diff != "" { - t.Errorf("prefix event mismatch (-want +got):\n%s", diff) + select { + case e := <-ndpDisp.prefixC: + if diff := checkPrefixEvent(e, prefix, discovered); diff != "" { + t.Errorf("prefix event mismatch (-want +got):\n%s", diff) + } + default: + t.Fatal("expected prefix discovery event") } - default: - t.Fatal("expected prefix discovery event") } - } - // Receive an RA with prefix1 in an NDP Prefix Information option (PI) - // with zero valid lifetime. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, false, 0, 0)) - select { - case <-ndpDisp.prefixC: - t.Fatal("unexpectedly discovered a prefix with 0 lifetime") - default: - } + if err := s.SetForwarding(ipv6.ProtocolNumber, forwarding); err != nil { + t.Fatalf("SetForwarding(%d, %t): %s", ipv6.ProtocolNumber, forwarding, err) + } - // Receive an RA with prefix1 in an NDP Prefix Information option (PI) - // with non-zero lifetime. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, false, 100, 0)) - expectPrefixEvent(subnet1, true) + // Receive an RA with prefix1 in an NDP Prefix Information option (PI) + // with zero valid lifetime. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, false, 0, 0)) + select { + case <-ndpDisp.prefixC: + t.Fatal("unexpectedly discovered a prefix with 0 lifetime") + default: + } - // Receive an RA with prefix2 in a PI. 
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, false, 100, 0)) - expectPrefixEvent(subnet2, true) + // Receive an RA with prefix1 in an NDP Prefix Information option (PI) + // with non-zero lifetime. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, false, 100, 0)) + expectPrefixEvent(subnet1, true) - // Receive an RA with prefix3 in a PI. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix3, true, false, 100, 0)) - expectPrefixEvent(subnet3, true) + // Receive an RA with prefix2 in a PI. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, false, 100, 0)) + expectPrefixEvent(subnet2, true) - // Receive an RA with prefix1 in a PI with lifetime = 0. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, false, 0, 0)) - expectPrefixEvent(subnet1, false) + // Receive an RA with prefix3 in a PI. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix3, true, false, 100, 0)) + expectPrefixEvent(subnet3, true) - // Receive an RA with prefix2 in a PI with lesser lifetime. - lifetime := uint32(2) - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, false, lifetime, 0)) - select { - case <-ndpDisp.prefixC: - t.Fatal("unexpectedly received prefix event when updating lifetime") - default: - } + // Receive an RA with prefix1 in a PI with lifetime = 0. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, false, 0, 0)) + expectPrefixEvent(subnet1, false) - // Wait for prefix2's most recent invalidation job plus some buffer to - // expire. - select { - case e := <-ndpDisp.prefixC: - if diff := checkPrefixEvent(e, subnet2, false); diff != "" { - t.Errorf("prefix event mismatch (-want +got):\n%s", diff) + // Receive an RA with prefix2 in a PI with lesser lifetime. + lifetime := uint32(2) + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, false, lifetime, 0)) + select { + case <-ndpDisp.prefixC: + t.Fatal("unexpectedly received prefix event when updating lifetime") + default: } - case <-time.After(time.Duration(lifetime)*time.Second + defaultAsyncPositiveEventTimeout): - t.Fatal("timed out waiting for prefix discovery event") - } - // Receive RA to invalidate prefix3. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix3, true, false, 0, 0)) - expectPrefixEvent(subnet3, false) + // Wait for prefix2's most recent invalidation job plus some buffer to + // expire. + select { + case e := <-ndpDisp.prefixC: + if diff := checkPrefixEvent(e, subnet2, false); diff != "" { + t.Errorf("prefix event mismatch (-want +got):\n%s", diff) + } + case <-time.After(time.Duration(lifetime)*time.Second + defaultAsyncPositiveEventTimeout): + t.Fatal("timed out waiting for prefix discovery event") + } + + // Receive RA to invalidate prefix3. 
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix3, true, false, 0, 0)) + expectPrefixEvent(subnet3, false) + }) } func TestPrefixDiscoveryWithInfiniteLifetime(t *testing.T) { @@ -1590,7 +1695,7 @@ func TestPrefixDiscoveryWithInfiniteLifetime(t *testing.T) { }() prefix := tcpip.AddressWithPrefix{ - Address: tcpip.Address("\x01\x02\x03\x04\x05\x06\x07\x08\x00\x00\x00\x00\x00\x00\x00\x00"), + Address: testutil.MustParse6("102:304:506:708::"), PrefixLen: 64, } subnet := prefix.Subnet() @@ -1603,7 +1708,7 @@ func TestPrefixDiscoveryWithInfiniteLifetime(t *testing.T) { s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ NDPConfigs: ipv6.NDPConfigurations{ - HandleRAs: true, + HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled, DiscoverOnLinkPrefixes: true, }, NDPDisp: &ndpDisp, @@ -1688,7 +1793,7 @@ func TestPrefixDiscoveryMaxOnLinkPrefixes(t *testing.T) { s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ NDPConfigs: ipv6.NDPConfigurations{ - HandleRAs: true, + HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled, DiscoverDefaultRouters: false, DiscoverOnLinkPrefixes: true, }, @@ -1753,53 +1858,6 @@ func containsV6Addr(list []tcpip.ProtocolAddress, item tcpip.AddressWithPrefix) return containsAddr(list, protocolAddress) } -// TestNoAutoGenAddr tests that SLAAC is not performed when configured not to. -func TestNoAutoGenAddr(t *testing.T) { - prefix, _, _ := prefixSubnetAddr(0, "") - - // Being configured to auto-generate addresses means handle and - // autogen are set to true and forwarding is set to false. - // This tests all possible combinations of the configurations, - // except for the configuration where handle = true, autogen = - // true and forwarding = false (the required configuration to do - // SLAAC) - that will done in other tests. - for i := 0; i < 7; i++ { - handle := i&1 != 0 - autogen := i&2 != 0 - forwarding := i&4 == 0 - - t.Run(fmt.Sprintf("HandleRAs(%t), AutoGenAddr(%t), Forwarding(%t)", handle, autogen, forwarding), func(t *testing.T) { - ndpDisp := ndpDispatcher{ - autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1), - } - e := channel.New(0, 1280, linkAddr1) - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ - NDPConfigs: ipv6.NDPConfigurations{ - HandleRAs: handle, - AutoGenGlobalAddresses: autogen, - }, - NDPDisp: &ndpDisp, - })}, - }) - s.SetForwarding(ipv6.ProtocolNumber, forwarding) - - if err := s.CreateNIC(1, e); err != nil { - t.Fatalf("CreateNIC(1) = %s", err) - } - - // Rx an RA with prefix with non-zero lifetime. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, false, true, 10, 0)) - - select { - case <-ndpDisp.autoGenAddrC: - t.Fatal("unexpectedly auto-generated an address when configured not to") - default: - } - }) - } -} - // Check e to make sure that the event is for addr on nic with ID 1, and the // event type is set to eventType. func checkAutoGenAddrEvent(e ndpAutoGenAddrEvent, addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) string { @@ -1808,7 +1866,7 @@ func checkAutoGenAddrEvent(e ndpAutoGenAddrEvent, addr tcpip.AddressWithPrefix, // TestAutoGenAddr tests that an address is properly generated and invalidated // when configured to do so. 
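The auto-generated addresses these tests expect are formed from the advertised /64 prefix plus an interface identifier derived from the link address. A hedged sketch of the classic modified EUI-64 derivation (RFC 4291, Appendix A) that this scheme follows; it is an illustration, not the prefixSubnetAddr helper itself:

// Illustrative modified EUI-64 interface-identifier generation: split the
// 48-bit MAC, insert ff:fe, flip the universal/local bit, and append the
// result to the /64 prefix.
package main

import (
	"fmt"
	"net"
)

func slaacAddr(prefix net.IP, mac net.HardwareAddr) net.IP {
	addr := make(net.IP, net.IPv6len)
	copy(addr, prefix.To16()) // keep the upper 64 bits of the prefix

	iid := []byte{
		mac[0] ^ 0x02, mac[1], mac[2], // flip the universal/local bit
		0xff, 0xfe, // EUI-64 filler
		mac[3], mac[4], mac[5],
	}
	copy(addr[8:], iid)
	return addr
}

func main() {
	prefix := net.ParseIP("2001:db8::")
	mac, _ := net.ParseMAC("02:02:03:04:05:06")
	fmt.Println(slaacAddr(prefix, mac)) // 2001:db8::2:3ff:fe04:506
}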
-func TestAutoGenAddr2(t *testing.T) { +func TestAutoGenAddr(t *testing.T) { const newMinVL = 2 newMinVLDuration := newMinVL * time.Second saved := ipv6.MinPrefixInformationValidLifetimeForUpdate @@ -1820,96 +1878,102 @@ func TestAutoGenAddr2(t *testing.T) { prefix1, _, addr1 := prefixSubnetAddr(0, linkAddr1) prefix2, _, addr2 := prefixSubnetAddr(1, linkAddr1) - ndpDisp := ndpDispatcher{ - autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1), - } - e := channel.New(0, 1280, linkAddr1) - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ - NDPConfigs: ipv6.NDPConfigurations{ - HandleRAs: true, - AutoGenGlobalAddresses: true, - }, - NDPDisp: &ndpDisp, - })}, - }) + testWithRAs(t, func(t *testing.T, handleRAs ipv6.HandleRAsConfiguration, forwarding bool) { + ndpDisp := ndpDispatcher{ + autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1), + } + e := channel.New(0, 1280, linkAddr1) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ + NDPConfigs: ipv6.NDPConfigurations{ + HandleRAs: handleRAs, + AutoGenGlobalAddresses: true, + }, + NDPDisp: &ndpDisp, + })}, + }) - if err := s.CreateNIC(1, e); err != nil { - t.Fatalf("CreateNIC(1) = %s", err) - } + if err := s.SetForwarding(ipv6.ProtocolNumber, forwarding); err != nil { + t.Fatalf("SetForwarding(%d, %t): %s", ipv6.ProtocolNumber, forwarding, err) + } - expectAutoGenAddrEvent := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) { - t.Helper() + if err := s.CreateNIC(1, e); err != nil { + t.Fatalf("CreateNIC(1) = %s", err) + } - select { - case e := <-ndpDisp.autoGenAddrC: - if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" { - t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) + expectAutoGenAddrEvent := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) { + t.Helper() + + select { + case e := <-ndpDisp.autoGenAddrC: + if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" { + t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) + } + default: + t.Fatal("expected addr auto gen event") } - default: - t.Fatal("expected addr auto gen event") } - } - // Receive an RA with prefix1 in an NDP Prefix Information option (PI) - // with zero valid lifetime. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 0, 0)) - select { - case <-ndpDisp.autoGenAddrC: - t.Fatal("unexpectedly auto-generated an address with 0 lifetime") - default: - } + // Receive an RA with prefix1 in an NDP Prefix Information option (PI) + // with zero valid lifetime. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 0, 0)) + select { + case <-ndpDisp.autoGenAddrC: + t.Fatal("unexpectedly auto-generated an address with 0 lifetime") + default: + } - // Receive an RA with prefix1 in an NDP Prefix Information option (PI) - // with non-zero lifetime. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 100, 0)) - expectAutoGenAddrEvent(addr1, newAddr) - if !containsV6Addr(s.NICInfo()[1].ProtocolAddresses, addr1) { - t.Fatalf("Should have %s in the list of addresses", addr1) - } + // Receive an RA with prefix1 in an NDP Prefix Information option (PI) + // with non-zero lifetime. 
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 100, 0)) + expectAutoGenAddrEvent(addr1, newAddr) + if !containsV6Addr(s.NICInfo()[1].ProtocolAddresses, addr1) { + t.Fatalf("Should have %s in the list of addresses", addr1) + } - // Receive an RA with prefix2 in an NDP Prefix Information option (PI) - // with preferred lifetime > valid lifetime - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 5, 6)) - select { - case <-ndpDisp.autoGenAddrC: - t.Fatal("unexpectedly auto-generated an address with preferred lifetime > valid lifetime") - default: - } + // Receive an RA with prefix2 in an NDP Prefix Information option (PI) + // with preferred lifetime > valid lifetime + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 5, 6)) + select { + case <-ndpDisp.autoGenAddrC: + t.Fatal("unexpectedly auto-generated an address with preferred lifetime > valid lifetime") + default: + } - // Receive an RA with prefix2 in a PI. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, 100, 0)) - expectAutoGenAddrEvent(addr2, newAddr) - if !containsV6Addr(s.NICInfo()[1].ProtocolAddresses, addr1) { - t.Fatalf("Should have %s in the list of addresses", addr1) - } - if !containsV6Addr(s.NICInfo()[1].ProtocolAddresses, addr2) { - t.Fatalf("Should have %s in the list of addresses", addr2) - } + // Receive an RA with prefix2 in a PI. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, 100, 0)) + expectAutoGenAddrEvent(addr2, newAddr) + if !containsV6Addr(s.NICInfo()[1].ProtocolAddresses, addr1) { + t.Fatalf("Should have %s in the list of addresses", addr1) + } + if !containsV6Addr(s.NICInfo()[1].ProtocolAddresses, addr2) { + t.Fatalf("Should have %s in the list of addresses", addr2) + } - // Refresh valid lifetime for addr of prefix1. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, newMinVL, 0)) - select { - case <-ndpDisp.autoGenAddrC: - t.Fatal("unexpectedly auto-generated an address when we already have an address for a prefix") - default: - } + // Refresh valid lifetime for addr of prefix1. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, newMinVL, 0)) + select { + case <-ndpDisp.autoGenAddrC: + t.Fatal("unexpectedly auto-generated an address when we already have an address for a prefix") + default: + } - // Wait for addr of prefix1 to be invalidated. - select { - case e := <-ndpDisp.autoGenAddrC: - if diff := checkAutoGenAddrEvent(e, addr1, invalidatedAddr); diff != "" { - t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) + // Wait for addr of prefix1 to be invalidated. 
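These tests repeatedly either wait for an event on a dispatcher channel with a deadline or assert that no event is pending. A compact sketch of that select-based expectation pattern, with an illustrative string-typed event channel:

// Sketch of the channel-plus-timeout expectation pattern used throughout
// these NDP tests; the event type is illustrative.
package example

import (
	"testing"
	"time"
)

// expectEvent fails the test if the expected event does not arrive before
// the deadline.
func expectEvent(t *testing.T, ch <-chan string, want string, timeout time.Duration) {
	t.Helper()
	select {
	case got := <-ch:
		if got != want {
			t.Errorf("got event %q, want %q", got, want)
		}
	case <-time.After(timeout):
		t.Fatalf("timed out waiting for event %q", want)
	}
}

// expectNoEvent fails the test if an event is already pending.
func expectNoEvent(t *testing.T, ch <-chan string) {
	t.Helper()
	select {
	case got := <-ch:
		t.Errorf("unexpectedly got event %q", got)
	default:
	}
}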
+ select { + case e := <-ndpDisp.autoGenAddrC: + if diff := checkAutoGenAddrEvent(e, addr1, invalidatedAddr); diff != "" { + t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) + } + case <-time.After(newMinVLDuration + defaultAsyncPositiveEventTimeout): + t.Fatal("timed out waiting for addr auto gen event") } - case <-time.After(newMinVLDuration + defaultAsyncPositiveEventTimeout): - t.Fatal("timed out waiting for addr auto gen event") - } - if containsV6Addr(s.NICInfo()[1].ProtocolAddresses, addr1) { - t.Fatalf("Should not have %s in the list of addresses", addr1) - } - if !containsV6Addr(s.NICInfo()[1].ProtocolAddresses, addr2) { - t.Fatalf("Should have %s in the list of addresses", addr2) - } + if containsV6Addr(s.NICInfo()[1].ProtocolAddresses, addr1) { + t.Fatalf("Should not have %s in the list of addresses", addr1) + } + if !containsV6Addr(s.NICInfo()[1].ProtocolAddresses, addr2) { + t.Fatalf("Should have %s in the list of addresses", addr2) + } + }) } func addressCheck(addrs []tcpip.ProtocolAddress, containList, notContainList []tcpip.AddressWithPrefix) string { @@ -1997,7 +2061,7 @@ func TestAutoGenTempAddr(t *testing.T) { RetransmitTimer: test.retransmitTimer, }, NDPConfigs: ipv6.NDPConfigurations{ - HandleRAs: true, + HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled, AutoGenGlobalAddresses: true, AutoGenTempGlobalAddresses: true, }, @@ -2298,7 +2362,7 @@ func TestNoAutoGenTempAddrWithoutStableAddr(t *testing.T) { RetransmitTimer: retransmitTimer, }, NDPConfigs: ipv6.NDPConfigurations{ - HandleRAs: true, + HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled, AutoGenGlobalAddresses: true, AutoGenTempGlobalAddresses: true, }, @@ -2385,7 +2449,7 @@ func TestAutoGenTempAddrRegen(t *testing.T) { } e := channel.New(0, 1280, linkAddr1) ndpConfigs := ipv6.NDPConfigurations{ - HandleRAs: true, + HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled, AutoGenGlobalAddresses: true, AutoGenTempGlobalAddresses: true, RegenAdvanceDuration: newMinVLDuration - regenAfter, @@ -2534,7 +2598,7 @@ func TestAutoGenTempAddrRegenJobUpdates(t *testing.T) { } e := channel.New(0, 1280, linkAddr1) ndpConfigs := ipv6.NDPConfigurations{ - HandleRAs: true, + HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled, AutoGenGlobalAddresses: true, AutoGenTempGlobalAddresses: true, RegenAdvanceDuration: newMinVLDuration - regenAfter, @@ -2735,7 +2799,7 @@ func TestMixedSLAACAddrConflictRegen(t *testing.T) { Clock: clock, NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ NDPConfigs: ipv6.NDPConfigurations{ - HandleRAs: true, + HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled, AutoGenGlobalAddresses: true, AutoGenTempGlobalAddresses: test.tempAddrs, AutoGenAddressConflictRetries: 1, @@ -2880,7 +2944,7 @@ func stackAndNdpDispatcherWithDefaultRoute(t *testing.T, nicID tcpip.NICID) (*nd s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ NDPConfigs: ipv6.NDPConfigurations{ - HandleRAs: true, + HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled, AutoGenGlobalAddresses: true, }, NDPDisp: ndpDisp, @@ -3347,7 +3411,7 @@ func TestAutoGenAddrFiniteToInfiniteToFiniteVL(t *testing.T) { s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ NDPConfigs: ipv6.NDPConfigurations{ - HandleRAs: true, + HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled, AutoGenGlobalAddresses: true, }, NDPDisp: &ndpDisp, @@ 
-3490,7 +3554,7 @@ func TestAutoGenAddrValidLifetimeUpdates(t *testing.T) { s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ NDPConfigs: ipv6.NDPConfigurations{ - HandleRAs: true, + HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled, AutoGenGlobalAddresses: true, }, NDPDisp: &ndpDisp, @@ -3557,7 +3621,7 @@ func TestAutoGenAddrRemoval(t *testing.T) { s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ NDPConfigs: ipv6.NDPConfigurations{ - HandleRAs: true, + HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled, AutoGenGlobalAddresses: true, }, NDPDisp: &ndpDisp, @@ -3723,7 +3787,7 @@ func TestAutoGenAddrStaticConflict(t *testing.T) { s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ NDPConfigs: ipv6.NDPConfigurations{ - HandleRAs: true, + HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled, AutoGenGlobalAddresses: true, }, NDPDisp: &ndpDisp, @@ -3805,7 +3869,7 @@ func TestAutoGenAddrWithOpaqueIID(t *testing.T) { s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ NDPConfigs: ipv6.NDPConfigurations{ - HandleRAs: true, + HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled, AutoGenGlobalAddresses: true, }, NDPDisp: &ndpDisp, @@ -3969,7 +4033,7 @@ func TestAutoGenAddrInResponseToDADConflicts(t *testing.T) { { name: "Global address", ndpConfigs: ipv6.NDPConfigurations{ - HandleRAs: true, + HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled, AutoGenGlobalAddresses: true, }, prepareFn: func(_ *testing.T, _ *ndpDispatcher, e *channel.Endpoint, _ []byte) []tcpip.AddressWithPrefix { @@ -3996,7 +4060,7 @@ func TestAutoGenAddrInResponseToDADConflicts(t *testing.T) { { name: "Temporary address", ndpConfigs: ipv6.NDPConfigurations{ - HandleRAs: true, + HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled, AutoGenGlobalAddresses: true, AutoGenTempGlobalAddresses: true, }, @@ -4146,7 +4210,7 @@ func TestAutoGenAddrWithEUI64IIDNoDADRetries(t *testing.T) { { name: "Global address", ndpConfigs: ipv6.NDPConfigurations{ - HandleRAs: true, + HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled, AutoGenGlobalAddresses: true, AutoGenAddressConflictRetries: maxRetries, }, @@ -4274,7 +4338,7 @@ func TestAutoGenAddrContinuesLifetimesAfterRetry(t *testing.T) { RetransmitTimer: retransmitTimer, }, NDPConfigs: ipv6.NDPConfigurations{ - HandleRAs: true, + HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled, AutoGenGlobalAddresses: true, AutoGenAddressConflictRetries: maxRetries, }, @@ -4480,7 +4544,7 @@ func TestNDPRecursiveDNSServerDispatch(t *testing.T) { s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ NDPConfigs: ipv6.NDPConfigurations{ - HandleRAs: true, + HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled, }, NDPDisp: &ndpDisp, })}, @@ -4531,7 +4595,7 @@ func TestNDPDNSSearchListDispatch(t *testing.T) { s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ NDPConfigs: ipv6.NDPConfigurations{ - HandleRAs: true, + HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled, }, NDPDisp: &ndpDisp, })}, @@ -4625,8 +4689,110 @@ func TestNDPDNSSearchListDispatch(t *testing.T) { } } -// TestCleanupNDPState tests that all discovered routers and prefixes, and -// 
auto-generated addresses are invalidated when a NIC becomes a router. +func TestNoCleanupNDPStateWhenForwardingEnabled(t *testing.T) { + const ( + lifetimeSeconds = 999 + nicID = 1 + ) + + ndpDisp := ndpDispatcher{ + routerC: make(chan ndpRouterEvent, 1), + rememberRouter: true, + prefixC: make(chan ndpPrefixEvent, 1), + rememberPrefix: true, + autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1), + } + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ + AutoGenLinkLocal: true, + NDPConfigs: ipv6.NDPConfigurations{ + HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled, + DiscoverDefaultRouters: true, + DiscoverOnLinkPrefixes: true, + AutoGenGlobalAddresses: true, + }, + NDPDisp: &ndpDisp, + })}, + }) + + e1 := channel.New(0, header.IPv6MinimumMTU, linkAddr1) + if err := s.CreateNIC(nicID, e1); err != nil { + t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) + } + llAddr := tcpip.AddressWithPrefix{Address: llAddr1, PrefixLen: header.IPv6LinkLocalPrefix.PrefixLen} + select { + case e := <-ndpDisp.autoGenAddrC: + if diff := checkAutoGenAddrEvent(e, llAddr, newAddr); diff != "" { + t.Errorf("auto-gen addr mismatch (-want +got):\n%s", diff) + } + default: + t.Errorf("expected auto-gen addr event for %s on NIC(%d)", llAddr, nicID) + } + + prefix, subnet, addr := prefixSubnetAddr(0, linkAddr1) + e1.InjectInbound( + header.IPv6ProtocolNumber, + raBufWithPI( + llAddr3, + lifetimeSeconds, + prefix, + true, /* onLink */ + true, /* auto */ + lifetimeSeconds, + lifetimeSeconds, + ), + ) + select { + case e := <-ndpDisp.routerC: + if diff := checkRouterEvent(e, llAddr3, true /* discovered */); diff != "" { + t.Errorf("router event mismatch (-want +got):\n%s", diff) + } + default: + t.Errorf("expected router event for %s on NIC(%d)", llAddr3, nicID) + } + select { + case e := <-ndpDisp.prefixC: + if diff := checkPrefixEvent(e, subnet, true /* discovered */); diff != "" { + t.Errorf("router event mismatch (-want +got):\n%s", diff) + } + default: + t.Errorf("expected prefix event for %s on NIC(%d)", prefix, nicID) + } + select { + case e := <-ndpDisp.autoGenAddrC: + if diff := checkAutoGenAddrEvent(e, addr, newAddr); diff != "" { + t.Errorf("auto-gen addr mismatch (-want +got):\n%s", diff) + } + default: + t.Errorf("expected auto-gen addr event for %s on NIC(%d)", addr, nicID) + } + + // Enabling or disabling forwarding should not invalidate discovered prefixes + // or routers, or auto-generated address. + for _, forwarding := range [...]bool{true, false} { + t.Run(fmt.Sprintf("Transition forwarding to %t", forwarding), func(t *testing.T) { + if err := s.SetForwarding(ipv6.ProtocolNumber, forwarding); err != nil { + t.Fatalf("SetForwarding(%d, %t): %s", ipv6.ProtocolNumber, forwarding, err) + } + select { + case e := <-ndpDisp.routerC: + t.Errorf("unexpected router event = %#v", e) + default: + } + select { + case e := <-ndpDisp.prefixC: + t.Errorf("unexpected prefix event = %#v", e) + default: + } + select { + case e := <-ndpDisp.autoGenAddrC: + t.Errorf("unexpected auto-gen addr event = %#v", e) + default: + } + }) + } +} + func TestCleanupNDPState(t *testing.T) { const ( lifetimeSeconds = 5 @@ -4655,18 +4821,6 @@ func TestCleanupNDPState(t *testing.T) { maxAutoGenAddrEvents int skipFinalAddrCheck bool }{ - // A NIC should still keep its auto-generated link-local address when - // becoming a router. 
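// Illustrative sketch (not part of the change above). The point of this change
// is that NDPConfigurations.HandleRAs is no longer a bool: it is a
// HandleRAsConfiguration that can tie RA processing to the forwarding state or
// enable it unconditionally. Based only on identifiers visible in this diff
// (the exact surrounding option set is an assumption), a stack might choose
// between the two modes like this:
package ndpsketch

import (
	"gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
	"gvisor.dev/gvisor/pkg/tcpip/stack"
)

// newIPv6Stack builds a stack whose RA handling either stops once forwarding
// is enabled (the mode used throughout these tests) or continues regardless.
func newIPv6Stack(alwaysHandleRAs bool) *stack.Stack {
	handleRAs := ipv6.HandlingRAsEnabledWhenForwardingDisabled
	if alwaysHandleRAs {
		handleRAs = ipv6.HandlingRAsAlwaysEnabled
	}
	return stack.New(stack.Options{
		NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
			NDPConfigs: ipv6.NDPConfigurations{
				HandleRAs:              handleRAs,
				AutoGenGlobalAddresses: true,
			},
		})},
	})
}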
- { - name: "Enable forwarding", - cleanupFn: func(t *testing.T, s *stack.Stack) { - t.Helper() - s.SetForwarding(ipv6.ProtocolNumber, true) - }, - keepAutoGenLinkLocal: true, - maxAutoGenAddrEvents: 4, - }, - // A NIC should cleanup all NDP state when it is disabled. { name: "Disable NIC", @@ -4718,7 +4872,7 @@ func TestCleanupNDPState(t *testing.T) { NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ AutoGenLinkLocal: true, NDPConfigs: ipv6.NDPConfigurations{ - HandleRAs: true, + HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled, DiscoverDefaultRouters: true, DiscoverOnLinkPrefixes: true, AutoGenGlobalAddresses: true, @@ -4991,7 +5145,7 @@ func TestDHCPv6ConfigurationFromNDPDA(t *testing.T) { s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ NDPConfigs: ipv6.NDPConfigurations{ - HandleRAs: true, + HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled, }, NDPDisp: &ndpDisp, })}, @@ -5182,96 +5336,127 @@ func TestRouterSolicitation(t *testing.T) { }, } + subTests := []struct { + name string + handleRAs ipv6.HandleRAsConfiguration + afterFirstRS func(*testing.T, *stack.Stack) + }{ + { + name: "Handle RAs when forwarding disabled", + handleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled, + afterFirstRS: func(*testing.T, *stack.Stack) {}, + }, + + // Enabling forwarding when RAs are always configured to be handled + // should not stop router solicitations. + { + name: "Handle RAs always", + handleRAs: ipv6.HandlingRAsAlwaysEnabled, + afterFirstRS: func(t *testing.T, s *stack.Stack) { + if err := s.SetForwarding(ipv6.ProtocolNumber, true); err != nil { + t.Fatalf("SetForwarding(%d, true): %s", ipv6.ProtocolNumber, err) + } + }, + }, + } + for _, test := range tests { t.Run(test.name, func(t *testing.T) { - clock := faketime.NewManualClock() - e := channelLinkWithHeaderLength{ - Endpoint: channel.New(int(test.maxRtrSolicit), 1280, test.linkAddr), - headerLength: test.linkHeaderLen, - } - e.Endpoint.LinkEPCapabilities |= stack.CapabilityResolutionRequired - waitForPkt := func(timeout time.Duration) { - t.Helper() + for _, subTest := range subTests { + t.Run(subTest.name, func(t *testing.T) { + clock := faketime.NewManualClock() + e := channelLinkWithHeaderLength{ + Endpoint: channel.New(int(test.maxRtrSolicit), 1280, test.linkAddr), + headerLength: test.linkHeaderLen, + } + e.Endpoint.LinkEPCapabilities |= stack.CapabilityResolutionRequired + waitForPkt := func(timeout time.Duration) { + t.Helper() + + clock.Advance(timeout) + p, ok := e.Read() + if !ok { + t.Fatal("expected router solicitation packet") + } - clock.Advance(timeout) - p, ok := e.Read() - if !ok { - t.Fatal("expected router solicitation packet") - } + if p.Proto != header.IPv6ProtocolNumber { + t.Fatalf("got Proto = %d, want = %d", p.Proto, header.IPv6ProtocolNumber) + } - if p.Proto != header.IPv6ProtocolNumber { - t.Fatalf("got Proto = %d, want = %d", p.Proto, header.IPv6ProtocolNumber) - } + // Make sure the right remote link address is used. + if want := header.EthernetAddressFromMulticastIPv6Address(header.IPv6AllRoutersLinkLocalMulticastAddress); p.Route.RemoteLinkAddress != want { + t.Errorf("got remote link address = %s, want = %s", p.Route.RemoteLinkAddress, want) + } - // Make sure the right remote link address is used. 
- if want := header.EthernetAddressFromMulticastIPv6Address(header.IPv6AllRoutersMulticastAddress); p.Route.RemoteLinkAddress != want { - t.Errorf("got remote link address = %s, want = %s", p.Route.RemoteLinkAddress, want) - } + checker.IPv6(t, stack.PayloadSince(p.Pkt.NetworkHeader()), + checker.SrcAddr(test.expectedSrcAddr), + checker.DstAddr(header.IPv6AllRoutersLinkLocalMulticastAddress), + checker.TTL(header.NDPHopLimit), + checker.NDPRS(checker.NDPRSOptions(test.expectedNDPOpts)), + ) - checker.IPv6(t, stack.PayloadSince(p.Pkt.NetworkHeader()), - checker.SrcAddr(test.expectedSrcAddr), - checker.DstAddr(header.IPv6AllRoutersMulticastAddress), - checker.TTL(header.NDPHopLimit), - checker.NDPRS(checker.NDPRSOptions(test.expectedNDPOpts)), - ) + if l, want := p.Pkt.AvailableHeaderBytes(), int(test.linkHeaderLen); l != want { + t.Errorf("got p.Pkt.AvailableHeaderBytes() = %d; want = %d", l, want) + } + } + waitForNothing := func(timeout time.Duration) { + t.Helper() - if l, want := p.Pkt.AvailableHeaderBytes(), int(test.linkHeaderLen); l != want { - t.Errorf("got p.Pkt.AvailableHeaderBytes() = %d; want = %d", l, want) - } - } - waitForNothing := func(timeout time.Duration) { - t.Helper() + clock.Advance(timeout) + if p, ok := e.Read(); ok { + t.Fatalf("unexpectedly got a packet = %#v", p) + } + } + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ + NDPConfigs: ipv6.NDPConfigurations{ + HandleRAs: subTest.handleRAs, + MaxRtrSolicitations: test.maxRtrSolicit, + RtrSolicitationInterval: test.rtrSolicitInt, + MaxRtrSolicitationDelay: test.maxRtrSolicitDelay, + }, + })}, + Clock: clock, + }) + if err := s.CreateNIC(nicID, &e); err != nil { + t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) + } - clock.Advance(timeout) - if p, ok := e.Read(); ok { - t.Fatalf("unexpectedly got a packet = %#v", p) - } - } - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ - NDPConfigs: ipv6.NDPConfigurations{ - MaxRtrSolicitations: test.maxRtrSolicit, - RtrSolicitationInterval: test.rtrSolicitInt, - MaxRtrSolicitationDelay: test.maxRtrSolicitDelay, - }, - })}, - Clock: clock, - }) - if err := s.CreateNIC(nicID, &e); err != nil { - t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) - } + if addr := test.nicAddr; addr != "" { + if err := s.AddAddress(nicID, header.IPv6ProtocolNumber, addr); err != nil { + t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, addr, err) + } + } - if addr := test.nicAddr; addr != "" { - if err := s.AddAddress(nicID, header.IPv6ProtocolNumber, addr); err != nil { - t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, addr, err) - } - } + // Make sure each RS is sent at the right time. + remaining := test.maxRtrSolicit + if remaining > 0 { + waitForPkt(test.effectiveMaxRtrSolicitDelay) + remaining-- + } - // Make sure each RS is sent at the right time. 
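// Illustrative sketch (not part of the change above). The rewritten
// solicitation test drives every timer through a manual clock, which is what
// lets it assert the exact send schedule: the test advances the clock and then
// expects the router solicitation to already be readable on the link endpoint.
// The basic setup, assuming only the faketime and stack APIs already used in
// this file:
package clocksketch

import (
	"gvisor.dev/gvisor/pkg/tcpip/faketime"
	"gvisor.dev/gvisor/pkg/tcpip/stack"
)

// newDeterministicStack returns a stack whose timers fire only when the test
// calls Advance on the returned clock, e.g. clock.Advance(maxRtrSolicitDelay)
// before reading the expected packet off the link endpoint.
func newDeterministicStack(opts stack.Options) (*stack.Stack, *faketime.ManualClock) {
	clock := faketime.NewManualClock()
	opts.Clock = clock
	return stack.New(opts), clock
}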
- remaining := test.maxRtrSolicit - if remaining > 0 { - waitForPkt(test.effectiveMaxRtrSolicitDelay) - remaining-- - } + subTest.afterFirstRS(t, s) - for ; remaining > 0; remaining-- { - if test.effectiveRtrSolicitInt > defaultAsyncPositiveEventTimeout { - waitForNothing(test.effectiveRtrSolicitInt - time.Nanosecond) - waitForPkt(time.Nanosecond) - } else { - waitForPkt(test.effectiveRtrSolicitInt) - } - } + for ; remaining > 0; remaining-- { + if test.effectiveRtrSolicitInt > defaultAsyncPositiveEventTimeout { + waitForNothing(test.effectiveRtrSolicitInt - time.Nanosecond) + waitForPkt(time.Nanosecond) + } else { + waitForPkt(test.effectiveRtrSolicitInt) + } + } - // Make sure no more RS. - if test.effectiveRtrSolicitInt > test.effectiveMaxRtrSolicitDelay { - waitForNothing(test.effectiveRtrSolicitInt) - } else { - waitForNothing(test.effectiveMaxRtrSolicitDelay) - } + // Make sure no more RS. + if test.effectiveRtrSolicitInt > test.effectiveMaxRtrSolicitDelay { + waitForNothing(test.effectiveRtrSolicitInt) + } else { + waitForNothing(test.effectiveMaxRtrSolicitDelay) + } - if got, want := s.Stats().ICMP.V6.PacketsSent.RouterSolicit.Value(), uint64(test.maxRtrSolicit); got != want { - t.Fatalf("got sent RouterSolicit = %d, want = %d", got, want) + if got, want := s.Stats().ICMP.V6.PacketsSent.RouterSolicit.Value(), uint64(test.maxRtrSolicit); got != want { + t.Fatalf("got sent RouterSolicit = %d, want = %d", got, want) + } + }) } }) } @@ -5362,13 +5547,14 @@ func TestStopStartSolicitingRouters(t *testing.T) { } checker.IPv6(t, stack.PayloadSince(p.Pkt.NetworkHeader()), checker.SrcAddr(header.IPv6Any), - checker.DstAddr(header.IPv6AllRoutersMulticastAddress), + checker.DstAddr(header.IPv6AllRoutersLinkLocalMulticastAddress), checker.TTL(header.NDPHopLimit), checker.NDPRS()) } s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ NDPConfigs: ipv6.NDPConfigurations{ + HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled, MaxRtrSolicitations: maxRtrSolicitations, RtrSolicitationInterval: interval, MaxRtrSolicitationDelay: delay, diff --git a/pkg/tcpip/stack/neighbor_cache_test.go b/pkg/tcpip/stack/neighbor_cache_test.go index 48bb75e2f..9821a18d3 100644 --- a/pkg/tcpip/stack/neighbor_cache_test.go +++ b/pkg/tcpip/stack/neighbor_cache_test.go @@ -1556,7 +1556,7 @@ func TestNeighborCacheRetryResolution(t *testing.T) { func BenchmarkCacheClear(b *testing.B) { b.StopTimer() config := DefaultNUDConfigurations() - clock := &tcpip.StdClock{} + clock := tcpip.NewStdClock() linkRes := newTestNeighborResolver(nil, config, clock) linkRes.delay = 0 diff --git a/pkg/tcpip/stack/neighbor_entry_test.go b/pkg/tcpip/stack/neighbor_entry_test.go index bb2b2d705..1d39ee73d 100644 --- a/pkg/tcpip/stack/neighbor_entry_test.go +++ b/pkg/tcpip/stack/neighbor_entry_test.go @@ -26,14 +26,13 @@ import ( "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/faketime" "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/testutil" ) const ( entryTestNetNumber tcpip.NetworkProtocolNumber = math.MaxUint32 entryTestNICID tcpip.NICID = 1 - entryTestAddr1 = tcpip.Address("\x00\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01") - entryTestAddr2 = tcpip.Address("\x00\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02") entryTestLinkAddr1 = tcpip.LinkAddress("\x0a\x00\x00\x00\x00\x01") entryTestLinkAddr2 = tcpip.LinkAddress("\x0a\x00\x00\x00\x00\x02") @@ -44,6 +43,11 @@ const ( entryTestNetDefaultMTU = 65536 ) +var ( + 
entryTestAddr1 = testutil.MustParse6("a::1") + entryTestAddr2 = testutil.MustParse6("a::2") +) + // runImmediatelyScheduledJobs runs all jobs scheduled to run at the current // time. func runImmediatelyScheduledJobs(clock *faketime.ManualClock) { diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go index ca15c0691..8d615500f 100644 --- a/pkg/tcpip/stack/nic.go +++ b/pkg/tcpip/stack/nic.go @@ -316,30 +316,30 @@ func (n *nic) IsLoopback() bool { } // WritePacket implements NetworkLinkEndpoint. -func (n *nic) WritePacket(r *Route, gso *GSO, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) tcpip.Error { - _, err := n.enqueuePacketBuffer(r, gso, protocol, pkt) +func (n *nic) WritePacket(r *Route, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) tcpip.Error { + _, err := n.enqueuePacketBuffer(r, protocol, pkt) return err } -func (n *nic) writePacketBuffer(r RouteInfo, gso *GSO, protocol tcpip.NetworkProtocolNumber, pkt pendingPacketBuffer) (int, tcpip.Error) { +func (n *nic) writePacketBuffer(r RouteInfo, protocol tcpip.NetworkProtocolNumber, pkt pendingPacketBuffer) (int, tcpip.Error) { switch pkt := pkt.(type) { case *PacketBuffer: - if err := n.writePacket(r, gso, protocol, pkt); err != nil { + if err := n.writePacket(r, protocol, pkt); err != nil { return 0, err } return 1, nil case *PacketBufferList: - return n.writePackets(r, gso, protocol, *pkt) + return n.writePackets(r, protocol, *pkt) default: panic(fmt.Sprintf("unrecognized pending packet buffer type = %T", pkt)) } } -func (n *nic) enqueuePacketBuffer(r *Route, gso *GSO, protocol tcpip.NetworkProtocolNumber, pkt pendingPacketBuffer) (int, tcpip.Error) { +func (n *nic) enqueuePacketBuffer(r *Route, protocol tcpip.NetworkProtocolNumber, pkt pendingPacketBuffer) (int, tcpip.Error) { routeInfo, _, err := r.resolvedFields(nil) switch err.(type) { case nil: - return n.writePacketBuffer(routeInfo, gso, protocol, pkt) + return n.writePacketBuffer(routeInfo, protocol, pkt) case *tcpip.ErrWouldBlock: // As per relevant RFCs, we should queue packets while we wait for link // resolution to complete. @@ -358,28 +358,27 @@ func (n *nic) enqueuePacketBuffer(r *Route, gso *GSO, protocol tcpip.NetworkProt // SHOULD be limited to some small value. When a queue overflows, the new // arrival SHOULD replace the oldest entry. Once address resolution // completes, the node transmits any queued packets. - return n.linkResQueue.enqueue(r, gso, protocol, pkt) + return n.linkResQueue.enqueue(r, protocol, pkt) default: return 0, err } } // WritePacketToRemote implements NetworkInterface. -func (n *nic) WritePacketToRemote(remoteLinkAddr tcpip.LinkAddress, gso *GSO, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) tcpip.Error { +func (n *nic) WritePacketToRemote(remoteLinkAddr tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) tcpip.Error { var r RouteInfo r.NetProto = protocol r.RemoteLinkAddress = remoteLinkAddr - return n.writePacket(r, gso, protocol, pkt) + return n.writePacket(r, protocol, pkt) } -func (n *nic) writePacket(r RouteInfo, gso *GSO, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) tcpip.Error { +func (n *nic) writePacket(r RouteInfo, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) tcpip.Error { // WritePacket takes ownership of pkt, calculate numBytes first. 
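// Illustrative sketch (not part of the change above). The comment kept in
// enqueuePacketBuffer paraphrases the neighbor-discovery requirement (e.g.
// RFC 4861 section 7.2.2) that packets awaiting link resolution sit in a small
// per-neighbor queue, and that a new arrival replaces the oldest entry when
// the queue is full. A generic version of that overflow policy, unrelated to
// gvisor's actual pending-packets type:
package pendingsketch

// boundedQueue holds at most max queued packets; enqueue drops the oldest
// entry instead of rejecting the new arrival when the queue is full.
type boundedQueue struct {
	max   int
	items [][]byte // stand-in for queued packet buffers
}

func (q *boundedQueue) enqueue(pkt []byte) (droppedOldest bool) {
	if q.max > 0 && len(q.items) >= q.max {
		// Replace the oldest entry with the new arrival.
		q.items = q.items[1:]
		droppedOldest = true
	}
	q.items = append(q.items, pkt)
	return droppedOldest
}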
numBytes := pkt.Size() pkt.EgressRoute = r - pkt.GSOOptions = gso pkt.NetworkProtocolNumber = protocol - if err := n.LinkEndpoint.WritePacket(r, gso, protocol, pkt); err != nil { + if err := n.LinkEndpoint.WritePacket(r, protocol, pkt); err != nil { return err } @@ -389,18 +388,17 @@ func (n *nic) writePacket(r RouteInfo, gso *GSO, protocol tcpip.NetworkProtocolN } // WritePackets implements NetworkLinkEndpoint. -func (n *nic) WritePackets(r *Route, gso *GSO, pkts PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, tcpip.Error) { - return n.enqueuePacketBuffer(r, gso, protocol, &pkts) +func (n *nic) WritePackets(r *Route, pkts PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, tcpip.Error) { + return n.enqueuePacketBuffer(r, protocol, &pkts) } -func (n *nic) writePackets(r RouteInfo, gso *GSO, protocol tcpip.NetworkProtocolNumber, pkts PacketBufferList) (int, tcpip.Error) { +func (n *nic) writePackets(r RouteInfo, protocol tcpip.NetworkProtocolNumber, pkts PacketBufferList) (int, tcpip.Error) { for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() { pkt.EgressRoute = r - pkt.GSOOptions = gso pkt.NetworkProtocolNumber = protocol } - writtenPackets, err := n.LinkEndpoint.WritePackets(r, gso, pkts, protocol) + writtenPackets, err := n.LinkEndpoint.WritePackets(r, pkts, protocol) n.stats.Tx.Packets.IncrementBy(uint64(writtenPackets)) writtenBytes := 0 for i, pb := 0, pkts.Front(); i < writtenPackets && pb != nil; i, pb = i+1, pb.Next() { diff --git a/pkg/tcpip/stack/nic_test.go b/pkg/tcpip/stack/nic_test.go index c0f956e53..8a3005295 100644 --- a/pkg/tcpip/stack/nic_test.go +++ b/pkg/tcpip/stack/nic_test.go @@ -65,12 +65,12 @@ func (e *testIPv6Endpoint) MaxHeaderLength() uint16 { } // WritePacket implements NetworkEndpoint.WritePacket. -func (*testIPv6Endpoint) WritePacket(*Route, *GSO, NetworkHeaderParams, *PacketBuffer) tcpip.Error { +func (*testIPv6Endpoint) WritePacket(*Route, NetworkHeaderParams, *PacketBuffer) tcpip.Error { return nil } // WritePackets implements NetworkEndpoint.WritePackets. -func (*testIPv6Endpoint) WritePackets(*Route, *GSO, PacketBufferList, NetworkHeaderParams) (int, tcpip.Error) { +func (*testIPv6Endpoint) WritePackets(*Route, PacketBufferList, NetworkHeaderParams) (int, tcpip.Error) { // Our tests don't use this so we don't support it. return 0, &tcpip.ErrNotSupported{} } diff --git a/pkg/tcpip/stack/packet_buffer.go b/pkg/tcpip/stack/packet_buffer.go index 8f288675d..9527416cf 100644 --- a/pkg/tcpip/stack/packet_buffer.go +++ b/pkg/tcpip/stack/packet_buffer.go @@ -103,7 +103,7 @@ type PacketBuffer struct { // The following fields are only set by the qdisc layer when the packet // is added to a queue. EgressRoute RouteInfo - GSOOptions *GSO + GSOOptions GSO // NatDone indicates if the packet has been manipulated as per NAT // iptables rule. @@ -299,9 +299,18 @@ func (pk *PacketBuffer) Network() header.Network { // See PacketBuffer.Data for details about how a packet buffer holds an inbound // packet. func (pk *PacketBuffer) CloneToInbound() *PacketBuffer { - return NewPacketBuffer(PacketBufferOptions{ + newPk := NewPacketBuffer(PacketBufferOptions{ Data: buffer.NewVectorisedView(pk.Size(), pk.Views()), }) + // TODO(gvisor.dev/issue/5696): reimplement conntrack so that no need to + // maintain this flag in the packet. Currently conntrack needs this flag to + // tell if a noop connection should be inserted at Input hook. Once conntrack + // redefines the manipulation field as mutable, we won't need the special noop + // connection. 
+ if pk.NatDone { + newPk.NatDone = true + } + return newPk } // headerInfo stores metadata about a header in a packet. @@ -355,9 +364,10 @@ func (d PacketData) PullUp(size int) (buffer.View, bool) { return d.pk.data.PullUp(size) } -// TrimFront removes count from the beginning of d. It panics if count > -// d.Size(). -func (d PacketData) TrimFront(count int) { +// DeleteFront removes count from the beginning of d. It panics if count > +// d.Size(). All backing storage references after the front of the d are +// invalidated. +func (d PacketData) DeleteFront(count int) { d.pk.data.TrimFront(count) } diff --git a/pkg/tcpip/stack/packet_buffer_test.go b/pkg/tcpip/stack/packet_buffer_test.go index 6728370c3..bd4eb4fed 100644 --- a/pkg/tcpip/stack/packet_buffer_test.go +++ b/pkg/tcpip/stack/packet_buffer_test.go @@ -112,23 +112,13 @@ func TestPacketHeaderPush(t *testing.T) { if got, want := pk.Size(), allHdrSize+len(test.data); got != want { t.Errorf("After pk.Size() = %d, want %d", got, want) } - checkData(t, pk, test.data) - checkViewEqual(t, "After pk.Views()", concatViews(pk.Views()...), - concatViews(test.link, test.network, test.transport, test.data)) - // Check the after values for each header. - checkPacketHeader(t, "After pk.LinkHeader", pk.LinkHeader(), test.link) - checkPacketHeader(t, "After pk.NetworkHeader", pk.NetworkHeader(), test.network) - checkPacketHeader(t, "After pk.TransportHeader", pk.TransportHeader(), test.transport) - // Check the after values for PayloadSince. - checkViewEqual(t, "After PayloadSince(LinkHeader)", - PayloadSince(pk.LinkHeader()), - concatViews(test.link, test.network, test.transport, test.data)) - checkViewEqual(t, "After PayloadSince(NetworkHeader)", - PayloadSince(pk.NetworkHeader()), - concatViews(test.network, test.transport, test.data)) - checkViewEqual(t, "After PayloadSince(TransportHeader)", - PayloadSince(pk.TransportHeader()), - concatViews(test.transport, test.data)) + // Check the after state. + checkPacketContents(t, "After ", pk, packetContents{ + link: test.link, + network: test.network, + transport: test.transport, + data: test.data, + }) }) } } @@ -199,29 +189,13 @@ func TestPacketHeaderConsume(t *testing.T) { if got, want := pk.Size(), len(test.data); got != want { t.Errorf("After pk.Size() = %d, want %d", got, want) } - // After state of pk. - var ( - link = test.data[:test.link] - network = test.data[test.link:][:test.network] - transport = test.data[test.link+test.network:][:test.transport] - payload = test.data[allHdrSize:] - ) - checkData(t, pk, payload) - checkViewEqual(t, "After pk.Views()", concatViews(pk.Views()...), test.data) - // Check the after values for each header. - checkPacketHeader(t, "After pk.LinkHeader", pk.LinkHeader(), link) - checkPacketHeader(t, "After pk.NetworkHeader", pk.NetworkHeader(), network) - checkPacketHeader(t, "After pk.TransportHeader", pk.TransportHeader(), transport) - // Check the after values for PayloadSince. - checkViewEqual(t, "After PayloadSince(LinkHeader)", - PayloadSince(pk.LinkHeader()), - concatViews(link, network, transport, payload)) - checkViewEqual(t, "After PayloadSince(NetworkHeader)", - PayloadSince(pk.NetworkHeader()), - concatViews(network, transport, payload)) - checkViewEqual(t, "After PayloadSince(TransportHeader)", - PayloadSince(pk.TransportHeader()), - concatViews(transport, payload)) + // Check the after state of pk. 
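// Illustrative sketch (not part of the change above). TrimFront is renamed to
// DeleteFront together with the stronger contract in its doc comment: views
// previously taken from the packet data are invalid once the front is deleted.
// The safe pattern (the same one applied to the fake network endpoint later in
// this change) is to copy whatever is still needed before calling DeleteFront.
// headerLen is a hypothetical size for this sketch.
package pbsketch

import "gvisor.dev/gvisor/pkg/tcpip/stack"

const headerLen = 4

// consumeHeader returns a stable copy of the first headerLen bytes of the
// packet's data and then removes them from the packet.
func consumeHeader(pkt *stack.PacketBuffer) ([]byte, bool) {
	hdr, ok := pkt.Data().PullUp(headerLen)
	if !ok {
		return nil, false
	}
	// DeleteFront invalidates slices into the data area, so copy first.
	cp := append([]byte(nil), hdr...)
	pkt.Data().DeleteFront(headerLen)
	return cp, true
}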
+ checkPacketContents(t, "After ", pk, packetContents{ + link: test.data[:test.link], + network: test.data[test.link:][:test.network], + transport: test.data[test.link+test.network:][:test.transport], + data: test.data[allHdrSize:], + }) }) } } @@ -252,6 +226,39 @@ func TestPacketHeaderConsumeDataTooShort(t *testing.T) { }) } +// This is a very obscure use-case seen in the code that verifies packets +// before sending them out. It tries to parse the headers to verify. +// PacketHeader was initially not designed to mix Push() and Consume(), but it +// works and it's been relied upon. Include a test here. +func TestPacketHeaderPushConsumeMixed(t *testing.T) { + link := makeView(10) + network := makeView(20) + data := makeView(30) + + initData := append([]byte(nil), network...) + initData = append(initData, data...) + pk := NewPacketBuffer(PacketBufferOptions{ + ReserveHeaderBytes: len(link), + Data: buffer.NewViewFromBytes(initData).ToVectorisedView(), + }) + + // 1. Consume network header + gotNetwork, ok := pk.NetworkHeader().Consume(len(network)) + if !ok { + t.Fatalf("pk.NetworkHeader().Consume(%d) = _, false; want _, true", len(network)) + } + checkViewEqual(t, "gotNetwork", gotNetwork, network) + + // 2. Push link header + copy(pk.LinkHeader().Push(len(link)), link) + + checkPacketContents(t, "" /* prefix */, pk, packetContents{ + link: link, + network: network, + data: data, + }) +} + func TestPacketHeaderPushCalledAtMostOnce(t *testing.T) { const headerSize = 10 @@ -397,11 +404,11 @@ func TestPacketBufferData(t *testing.T) { } }) - // TrimFront + // DeleteFront for _, n := range []int{1, len(tc.data)} { - t.Run(fmt.Sprintf("TrimFront%d", n), func(t *testing.T) { + t.Run(fmt.Sprintf("DeleteFront%d", n), func(t *testing.T) { pkt := tc.makePkt(t) - pkt.Data().TrimFront(n) + pkt.Data().DeleteFront(n) checkData(t, pkt, []byte(tc.data)[n:]) }) @@ -494,6 +501,37 @@ func TestPacketBufferData(t *testing.T) { } } +type packetContents struct { + link buffer.View + network buffer.View + transport buffer.View + data buffer.View +} + +func checkPacketContents(t *testing.T, prefix string, pk *PacketBuffer, want packetContents) { + t.Helper() + // Headers. + checkPacketHeader(t, prefix+"pk.LinkHeader", pk.LinkHeader(), want.link) + checkPacketHeader(t, prefix+"pk.NetworkHeader", pk.NetworkHeader(), want.network) + checkPacketHeader(t, prefix+"pk.TransportHeader", pk.TransportHeader(), want.transport) + // Data. + checkData(t, pk, want.data) + // Whole packet. + checkViewEqual(t, prefix+"pk.Views()", + concatViews(pk.Views()...), + concatViews(want.link, want.network, want.transport, want.data)) + // PayloadSince. 
+ checkViewEqual(t, prefix+"PayloadSince(LinkHeader)", + PayloadSince(pk.LinkHeader()), + concatViews(want.link, want.network, want.transport, want.data)) + checkViewEqual(t, prefix+"PayloadSince(NetworkHeader)", + PayloadSince(pk.NetworkHeader()), + concatViews(want.network, want.transport, want.data)) + checkViewEqual(t, prefix+"PayloadSince(TransportHeader)", + PayloadSince(pk.TransportHeader()), + concatViews(want.transport, want.data)) +} + func checkInitialPacketBuffer(t *testing.T, pk *PacketBuffer, opts PacketBufferOptions) { t.Helper() reserved := opts.ReserveHeaderBytes @@ -510,19 +548,9 @@ func checkInitialPacketBuffer(t *testing.T, pk *PacketBuffer, opts PacketBufferO if got, want := pk.Size(), len(data); got != want { t.Errorf("Initial pk.Size() = %d, want %d", got, want) } - checkData(t, pk, data) - checkViewEqual(t, "Initial pk.Views()", concatViews(pk.Views()...), data) - // Check the initial values for each header. - checkPacketHeader(t, "Initial pk.LinkHeader", pk.LinkHeader(), nil) - checkPacketHeader(t, "Initial pk.NetworkHeader", pk.NetworkHeader(), nil) - checkPacketHeader(t, "Initial pk.TransportHeader", pk.TransportHeader(), nil) - // Check the initial valies for PayloadSince. - checkViewEqual(t, "Initial PayloadSince(LinkHeader)", - PayloadSince(pk.LinkHeader()), data) - checkViewEqual(t, "Initial PayloadSince(NetworkHeader)", - PayloadSince(pk.NetworkHeader()), data) - checkViewEqual(t, "Initial PayloadSince(TransportHeader)", - PayloadSince(pk.TransportHeader()), data) + checkPacketContents(t, "Initial ", pk, packetContents{ + data: data, + }) } func checkPacketHeader(t *testing.T, name string, h PacketHeader, want []byte) { diff --git a/pkg/tcpip/stack/pending_packets.go b/pkg/tcpip/stack/pending_packets.go index e936aa728..13e8907ec 100644 --- a/pkg/tcpip/stack/pending_packets.go +++ b/pkg/tcpip/stack/pending_packets.go @@ -46,7 +46,6 @@ func (p *PacketBufferList) len() int { type pendingPacket struct { routeInfo RouteInfo - gso *GSO proto tcpip.NetworkProtocolNumber pkt pendingPacketBuffer } @@ -119,7 +118,7 @@ func (f *packetsPendingLinkResolution) dequeue(ch <-chan struct{}, linkAddr tcpi // If the maximum number of pending resolutions is reached, the packets // associated with the oldest link resolution will be dequeued as if they failed // link resolution. -func (f *packetsPendingLinkResolution) enqueue(r *Route, gso *GSO, proto tcpip.NetworkProtocolNumber, pkt pendingPacketBuffer) (int, tcpip.Error) { +func (f *packetsPendingLinkResolution) enqueue(r *Route, proto tcpip.NetworkProtocolNumber, pkt pendingPacketBuffer) (int, tcpip.Error) { f.mu.Lock() // Make sure we attempt resolution while holding f's lock so that we avoid // a race where link resolution completes before we enqueue the packets. @@ -137,7 +136,7 @@ func (f *packetsPendingLinkResolution) enqueue(r *Route, gso *GSO, proto tcpip.N // The route resolved immediately, so we don't need to wait for link // resolution to send the packet. f.mu.Unlock() - return f.nic.writePacketBuffer(routeInfo, gso, proto, pkt) + return f.nic.writePacketBuffer(routeInfo, proto, pkt) case *tcpip.ErrWouldBlock: // We need to wait for link resolution to complete. 
default: @@ -150,7 +149,6 @@ func (f *packetsPendingLinkResolution) enqueue(r *Route, gso *GSO, proto tcpip.N packets, ok := f.mu.packets[ch] packets = append(packets, pendingPacket{ routeInfo: routeInfo, - gso: gso, proto: proto, pkt: pkt, }) @@ -211,7 +209,7 @@ func (f *packetsPendingLinkResolution) dequeuePackets(packets []pendingPacket, l for _, p := range packets { if err == nil { p.routeInfo.RemoteLinkAddress = linkAddr - _, _ = f.nic.writePacketBuffer(p.routeInfo, p.gso, p.proto, p.pkt) + _, _ = f.nic.writePacketBuffer(p.routeInfo, p.proto, p.pkt) } else { f.incrementOutgoingPacketErrors(p.proto, p.pkt) diff --git a/pkg/tcpip/stack/registration.go b/pkg/tcpip/stack/registration.go index ff3a385e1..e26225552 100644 --- a/pkg/tcpip/stack/registration.go +++ b/pkg/tcpip/stack/registration.go @@ -537,14 +537,14 @@ type NetworkInterface interface { CheckLocalAddress(tcpip.NetworkProtocolNumber, tcpip.Address) bool // WritePacketToRemote writes the packet to the given remote link address. - WritePacketToRemote(tcpip.LinkAddress, *GSO, tcpip.NetworkProtocolNumber, *PacketBuffer) tcpip.Error + WritePacketToRemote(tcpip.LinkAddress, tcpip.NetworkProtocolNumber, *PacketBuffer) tcpip.Error // WritePacket writes a packet with the given protocol through the given // route. // // WritePacket takes ownership of the packet buffer. The packet buffer's // network and transport header must be set. - WritePacket(*Route, *GSO, tcpip.NetworkProtocolNumber, *PacketBuffer) tcpip.Error + WritePacket(*Route, tcpip.NetworkProtocolNumber, *PacketBuffer) tcpip.Error // WritePackets writes packets with the given protocol through the given // route. Must not be called with an empty list of packet buffers. @@ -554,7 +554,7 @@ type NetworkInterface interface { // Right now, WritePackets is used only when the software segmentation // offload is enabled. If it will be used for something else, syscall filters // may need to be updated. - WritePackets(*Route, *GSO, PacketBufferList, tcpip.NetworkProtocolNumber) (int, tcpip.Error) + WritePackets(*Route, PacketBufferList, tcpip.NetworkProtocolNumber) (int, tcpip.Error) // HandleNeighborProbe processes an incoming neighbor probe (e.g. ARP // request or NDP Neighbor Solicitation). @@ -610,12 +610,12 @@ type NetworkEndpoint interface { // WritePacket writes a packet to the given destination address and // protocol. It takes ownership of pkt. pkt.TransportHeader must have // already been set. - WritePacket(r *Route, gso *GSO, params NetworkHeaderParams, pkt *PacketBuffer) tcpip.Error + WritePacket(r *Route, params NetworkHeaderParams, pkt *PacketBuffer) tcpip.Error // WritePackets writes packets to the given destination address and // protocol. pkts must not be zero length. It takes ownership of pkts and // underlying packets. - WritePackets(r *Route, gso *GSO, pkts PacketBufferList, params NetworkHeaderParams) (int, tcpip.Error) + WritePackets(r *Route, pkts PacketBufferList, params NetworkHeaderParams) (int, tcpip.Error) // WriteHeaderIncludedPacket writes a packet that includes a network // header to the given destination address. It takes ownership of pkt. @@ -756,11 +756,6 @@ const ( CapabilitySaveRestore CapabilityDisconnectOk CapabilityLoopback - CapabilityHardwareGSO - - // CapabilitySoftwareGSO indicates the link endpoint supports of sending - // multiple packets using a single call (LinkEndpoint.WritePackets). 
- CapabilitySoftwareGSO ) // NetworkLinkEndpoint is a data-link layer that supports sending network @@ -832,7 +827,7 @@ type LinkEndpoint interface { // To participate in transparent bridging, a LinkEndpoint implementation // should call eth.Encode with header.EthernetFields.SrcAddr set to // r.LocalLinkAddress if it is provided. - WritePacket(RouteInfo, *GSO, tcpip.NetworkProtocolNumber, *PacketBuffer) tcpip.Error + WritePacket(RouteInfo, tcpip.NetworkProtocolNumber, *PacketBuffer) tcpip.Error // WritePackets writes packets with the given protocol and route. Must not be // called with an empty list of packet buffers. @@ -842,7 +837,7 @@ type LinkEndpoint interface { // Right now, WritePackets is used only when the software segmentation // offload is enabled. If it will be used for something else, syscall filters // may need to be updated. - WritePackets(RouteInfo, *GSO, PacketBufferList, tcpip.NetworkProtocolNumber) (int, tcpip.Error) + WritePackets(RouteInfo, PacketBufferList, tcpip.NetworkProtocolNumber) (int, tcpip.Error) } // InjectableLinkEndpoint is a LinkEndpoint where inbound packets are @@ -1047,10 +1042,29 @@ type GSO struct { MaxSize uint32 } +// SupportedGSO returns the type of segmentation offloading supported. +type SupportedGSO int + +const ( + // GSONotSupported indicates that segmentation offloading is not supported. + GSONotSupported SupportedGSO = iota + + // HWGSOSupported indicates that segmentation offloading may be performed by + // the hardware. + HWGSOSupported + + // SWGSOSupported indicates that segmentation offloading may be performed in + // software. + SWGSOSupported +) + // GSOEndpoint provides access to GSO properties. type GSOEndpoint interface { // GSOMaxSize returns the maximum GSO packet size. GSOMaxSize() uint32 + + // SupportedGSO returns the supported segmentation offloading. + SupportedGSO() SupportedGSO } // SoftwareGSOMaxSize is a maximum allowed size of a software GSO segment. diff --git a/pkg/tcpip/stack/route.go b/pkg/tcpip/stack/route.go index 39344808d..8a044c073 100644 --- a/pkg/tcpip/stack/route.go +++ b/pkg/tcpip/stack/route.go @@ -132,7 +132,7 @@ func constructAndValidateRoute(netProto tcpip.NetworkProtocolNumber, addressEndp localAddr = addressEndpoint.AddressWithPrefix().Address } - if localAddressNIC != outgoingNIC && header.IsV6LinkLocalAddress(localAddr) { + if localAddressNIC != outgoingNIC && header.IsV6LinkLocalUnicastAddress(localAddr) { addressEndpoint.DecRef() return nil } @@ -300,12 +300,18 @@ func (r *Route) RequiresTXTransportChecksum() bool { // HasSoftwareGSOCapability returns true if the route supports software GSO. func (r *Route) HasSoftwareGSOCapability() bool { - return r.outgoingNIC.LinkEndpoint.Capabilities()&CapabilitySoftwareGSO != 0 + if gso, ok := r.outgoingNIC.LinkEndpoint.(GSOEndpoint); ok { + return gso.SupportedGSO() == SWGSOSupported + } + return false } // HasHardwareGSOCapability returns true if the route supports hardware GSO. func (r *Route) HasHardwareGSOCapability() bool { - return r.outgoingNIC.LinkEndpoint.Capabilities()&CapabilityHardwareGSO != 0 + if gso, ok := r.outgoingNIC.LinkEndpoint.(GSOEndpoint); ok { + return gso.SupportedGSO() == HWGSOSupported + } + return false } // HasSaveRestoreCapability returns true if the route supports save/restore. @@ -448,22 +454,22 @@ func (r *Route) isValidForOutgoingRLocked() bool { } // WritePacket writes the packet through the given route. 
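// Illustrative sketch (not part of the change above). The CapabilityHardwareGSO
// and CapabilitySoftwareGSO capability bits are replaced by a query on the
// optional GSOEndpoint interface, and the Route helpers above now use a type
// assertion to discover GSO support. The same pattern as a standalone helper
// over the interfaces shown in this diff:
package gsosketch

import "gvisor.dev/gvisor/pkg/tcpip/stack"

// gsoKind reports which kind of segmentation offload, if any, a link endpoint
// advertises; endpoints that do not implement stack.GSOEndpoint are treated as
// having no GSO support.
func gsoKind(ep stack.LinkEndpoint) stack.SupportedGSO {
	if gso, ok := ep.(stack.GSOEndpoint); ok {
		return gso.SupportedGSO()
	}
	return stack.GSONotSupported
}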
-func (r *Route) WritePacket(gso *GSO, params NetworkHeaderParams, pkt *PacketBuffer) tcpip.Error { +func (r *Route) WritePacket(params NetworkHeaderParams, pkt *PacketBuffer) tcpip.Error { if !r.isValidForOutgoing() { return &tcpip.ErrInvalidEndpointState{} } - return r.outgoingNIC.getNetworkEndpoint(r.NetProto()).WritePacket(r, gso, params, pkt) + return r.outgoingNIC.getNetworkEndpoint(r.NetProto()).WritePacket(r, params, pkt) } // WritePackets writes a list of n packets through the given route and returns // the number of packets written. -func (r *Route) WritePackets(gso *GSO, pkts PacketBufferList, params NetworkHeaderParams) (int, tcpip.Error) { +func (r *Route) WritePackets(pkts PacketBufferList, params NetworkHeaderParams) (int, tcpip.Error) { if !r.isValidForOutgoing() { return 0, &tcpip.ErrInvalidEndpointState{} } - return r.outgoingNIC.getNetworkEndpoint(r.NetProto()).WritePackets(r, gso, pkts, params) + return r.outgoingNIC.getNetworkEndpoint(r.NetProto()).WritePackets(r, pkts, params) } // WriteHeaderIncludedPacket writes a packet already containing a network diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go index 931a97ddc..436392f23 100644 --- a/pkg/tcpip/stack/stack.go +++ b/pkg/tcpip/stack/stack.go @@ -35,7 +35,6 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/ports" - "gvisor.dev/gvisor/pkg/tcpip/seqnum" "gvisor.dev/gvisor/pkg/waiter" ) @@ -56,306 +55,6 @@ type transportProtocolState struct { defaultHandler func(id TransportEndpointID, pkt *PacketBuffer) bool } -// TCPProbeFunc is the expected function type for a TCP probe function to be -// passed to stack.AddTCPProbe. -type TCPProbeFunc func(s TCPEndpointState) - -// TCPCubicState is used to hold a copy of the internal cubic state when the -// TCPProbeFunc is invoked. -type TCPCubicState struct { - WLastMax float64 - WMax float64 - T time.Time - TimeSinceLastCongestion time.Duration - C float64 - K float64 - Beta float64 - WC float64 - WEst float64 -} - -// TCPRACKState is used to hold a copy of the internal RACK state when the -// TCPProbeFunc is invoked. -type TCPRACKState struct { - XmitTime time.Time - EndSequence seqnum.Value - FACK seqnum.Value - RTT time.Duration - Reord bool - DSACKSeen bool - ReoWnd time.Duration - ReoWndIncr uint8 - ReoWndPersist int8 - RTTSeq seqnum.Value -} - -// TCPEndpointID is the unique 4 tuple that identifies a given endpoint. -type TCPEndpointID struct { - // LocalPort is the local port associated with the endpoint. - LocalPort uint16 - - // LocalAddress is the local [network layer] address associated with - // the endpoint. - LocalAddress tcpip.Address - - // RemotePort is the remote port associated with the endpoint. - RemotePort uint16 - - // RemoteAddress it the remote [network layer] address associated with - // the endpoint. - RemoteAddress tcpip.Address -} - -// TCPFastRecoveryState holds a copy of the internal fast recovery state of a -// TCP endpoint. -type TCPFastRecoveryState struct { - // Active if true indicates the endpoint is in fast recovery. - Active bool - - // First is the first unacknowledged sequence number being recovered. - First seqnum.Value - - // Last is the 'recover' sequence number that indicates the point at - // which we should exit recovery barring any timeouts etc. - Last seqnum.Value - - // MaxCwnd is the maximum value we are permitted to grow the congestion - // window during recovery. This is set at the time we enter recovery. 
- MaxCwnd int - - // HighRxt is the highest sequence number which has been retransmitted - // during the current loss recovery phase. - // See: RFC 6675 Section 2 for details. - HighRxt seqnum.Value - - // RescueRxt is the highest sequence number which has been - // optimistically retransmitted to prevent stalling of the ACK clock - // when there is loss at the end of the window and no new data is - // available for transmission. - // See: RFC 6675 Section 2 for details. - RescueRxt seqnum.Value -} - -// TCPReceiverState holds a copy of the internal state of the receiver for -// a given TCP endpoint. -type TCPReceiverState struct { - // RcvNxt is the TCP variable RCV.NXT. - RcvNxt seqnum.Value - - // RcvAcc is the TCP variable RCV.ACC. - RcvAcc seqnum.Value - - // RcvWndScale is the window scaling to use for inbound segments. - RcvWndScale uint8 - - // PendingBufUsed is the number of bytes pending in the receive - // queue. - PendingBufUsed int -} - -// TCPSenderState holds a copy of the internal state of the sender for -// a given TCP Endpoint. -type TCPSenderState struct { - // LastSendTime is the time at which we sent the last segment. - LastSendTime time.Time - - // DupAckCount is the number of Duplicate ACK's received. - DupAckCount int - - // SndCwnd is the size of the sending congestion window in packets. - SndCwnd int - - // Ssthresh is the slow start threshold in packets. - Ssthresh int - - // SndCAAckCount is the number of packets consumed in congestion - // avoidance mode. - SndCAAckCount int - - // Outstanding is the number of packets in flight. - Outstanding int - - // SackedOut is the number of packets which have been selectively acked. - SackedOut int - - // SndWnd is the send window size in bytes. - SndWnd seqnum.Size - - // SndUna is the next unacknowledged sequence number. - SndUna seqnum.Value - - // SndNxt is the sequence number of the next segment to be sent. - SndNxt seqnum.Value - - // RTTMeasureSeqNum is the sequence number being used for the latest RTT - // measurement. - RTTMeasureSeqNum seqnum.Value - - // RTTMeasureTime is the time when the RTTMeasureSeqNum was sent. - RTTMeasureTime time.Time - - // Closed indicates that the caller has closed the endpoint for sending. - Closed bool - - // SRTT is the smoothed round-trip time as defined in section 2 of - // RFC 6298. - SRTT time.Duration - - // RTO is the retransmit timeout as defined in section of 2 of RFC 6298. - RTO time.Duration - - // RTTVar is the round-trip time variation as defined in section 2 of - // RFC 6298. - RTTVar time.Duration - - // SRTTInited if true indicates take a valid RTT measurement has been - // completed. - SRTTInited bool - - // MaxPayloadSize is the maximum size of the payload of a given segment. - // It is initialized on demand. - MaxPayloadSize int - - // SndWndScale is the number of bits to shift left when reading the send - // window size from a segment. - SndWndScale uint8 - - // MaxSentAck is the highest acknowledgement number sent till now. - MaxSentAck seqnum.Value - - // FastRecovery holds the fast recovery state for the endpoint. - FastRecovery TCPFastRecoveryState - - // Cubic holds the state related to CUBIC congestion control. - Cubic TCPCubicState - - // RACKState holds the state related to RACK loss detection algorithm. - RACKState TCPRACKState -} - -// TCPSACKInfo holds TCP SACK related information for a given TCP endpoint. -type TCPSACKInfo struct { - // Blocks is the list of SACK Blocks that identify the out of order segments - // held by a given TCP endpoint. 
- Blocks []header.SACKBlock - - // ReceivedBlocks are the SACK blocks received by this endpoint - // from the peer endpoint. - ReceivedBlocks []header.SACKBlock - - // MaxSACKED is the highest sequence number that has been SACKED - // by the peer. - MaxSACKED seqnum.Value -} - -// RcvBufAutoTuneParams holds state related to TCP receive buffer auto-tuning. -type RcvBufAutoTuneParams struct { - // MeasureTime is the time at which the current measurement - // was started. - MeasureTime time.Time - - // CopiedBytes is the number of bytes copied to user space since - // this measure began. - CopiedBytes int - - // PrevCopiedBytes is the number of bytes copied to userspace in - // the previous RTT period. - PrevCopiedBytes int - - // RcvBufSize is the auto tuned receive buffer size. - RcvBufSize int - - // RTT is the smoothed RTT as measured by observing the time between - // when a byte is first acknowledged and the receipt of data that is at - // least one window beyond the sequence number that was acknowledged. - RTT time.Duration - - // RTTVar is the "round-trip time variation" as defined in section 2 - // of RFC6298. - RTTVar time.Duration - - // RTTMeasureSeqNumber is the highest acceptable sequence number at the - // time this RTT measurement period began. - RTTMeasureSeqNumber seqnum.Value - - // RTTMeasureTime is the absolute time at which the current RTT - // measurement period began. - RTTMeasureTime time.Time - - // Disabled is true if an explicit receive buffer is set for the - // endpoint. - Disabled bool -} - -// TCPEndpointState is a copy of the internal state of a TCP endpoint. -type TCPEndpointState struct { - // ID is a copy of the TransportEndpointID for the endpoint. - ID TCPEndpointID - - // SegTime denotes the absolute time when this segment was received. - SegTime time.Time - - // RcvBufSize is the size of the receive socket buffer for the endpoint. - RcvBufSize int - - // RcvBufUsed is the amount of bytes actually held in the receive socket - // buffer for the endpoint. - RcvBufUsed int - - // RcvBufAutoTuneParams is used to hold state variables to compute - // the auto tuned receive buffer size. - RcvAutoParams RcvBufAutoTuneParams - - // RcvClosed if true, indicates the endpoint has been closed for reading. - RcvClosed bool - - // SendTSOk is used to indicate when the TS Option has been negotiated. - // When sendTSOk is true every non-RST segment should carry a TS as per - // RFC7323#section-1.1. - SendTSOk bool - - // RecentTS is the timestamp that should be sent in the TSEcr field of - // the timestamp for future segments sent by the endpoint. This field is - // updated if required when a new segment is received by this endpoint. - RecentTS uint32 - - // TSOffset is a randomized offset added to the value of the TSVal field - // in the timestamp option. - TSOffset uint32 - - // SACKPermitted is set to true if the peer sends the TCPSACKPermitted - // option in the SYN/SYN-ACK. - SACKPermitted bool - - // SACK holds TCP SACK related information for this endpoint. - SACK TCPSACKInfo - - // SndBufSize is the size of the socket send buffer. - SndBufSize int - - // SndBufUsed is the number of bytes held in the socket send buffer. - SndBufUsed int - - // SndClosed indicates that the endpoint has been closed for sends. - SndClosed bool - - // SndBufInQueue is the number of bytes in the send queue. - SndBufInQueue seqnum.Size - - // PacketTooBigCount is used to notify the main protocol routine how - // many times a "packet too big" control packet is received. 
- PacketTooBigCount int - - // SndMTU is the smallest MTU seen in the control packets received. - SndMTU int - - // Receiver holds variables related to the TCP receiver for the endpoint. - Receiver TCPReceiverState - - // Sender holds state related to the TCP Sender for the endpoint. - Sender TCPSenderState -} - // ResumableEndpoint is an endpoint that needs to be resumed after restore. type ResumableEndpoint interface { // Resume resumes an endpoint after restore. This can be used to restart @@ -455,7 +154,7 @@ type Stack struct { // receiveBufferSize holds the min/default/max receive buffer sizes for // endpoints other than TCP. - receiveBufferSize ReceiveBufferSizeOption + receiveBufferSize tcpip.ReceiveBufferSizeOption // tcpInvalidRateLimit is the maximal rate for sending duplicate // acknowledgements in response to incoming TCP packets that are for an existing @@ -623,7 +322,7 @@ func (*TransportEndpointInfo) IsEndpointInfo() {} func New(opts Options) *Stack { clock := opts.Clock if clock == nil { - clock = &tcpip.StdClock{} + clock = tcpip.NewStdClock() } if opts.UniqueID == nil { @@ -669,7 +368,7 @@ func New(opts Options) *Stack { Default: DefaultBufferSize, Max: DefaultMaxBufferSize, }, - receiveBufferSize: ReceiveBufferSizeOption{ + receiveBufferSize: tcpip.ReceiveBufferSizeOption{ Min: MinBufferSize, Default: DefaultBufferSize, Max: DefaultMaxBufferSize, @@ -1344,7 +1043,7 @@ func (s *Stack) FindRoute(id tcpip.NICID, localAddr, remoteAddr tcpip.Address, n s.mu.RLock() defer s.mu.RUnlock() - isLinkLocal := header.IsV6LinkLocalAddress(remoteAddr) || header.IsV6LinkLocalMulticastAddress(remoteAddr) + isLinkLocal := header.IsV6LinkLocalUnicastAddress(remoteAddr) || header.IsV6LinkLocalMulticastAddress(remoteAddr) isLocalBroadcast := remoteAddr == header.IPv4Broadcast isMulticast := header.IsV4MulticastAddress(remoteAddr) || header.IsV6MulticastAddress(remoteAddr) isLoopback := header.IsV4LoopbackAddress(remoteAddr) || header.IsV6LoopbackAddress(remoteAddr) @@ -1381,7 +1080,7 @@ func (s *Stack) FindRoute(id tcpip.NICID, localAddr, remoteAddr tcpip.Address, n return nil, &tcpip.ErrNetworkUnreachable{} } - canForward := s.Forwarding(netProto) && !header.IsV6LinkLocalAddress(localAddr) && !isLinkLocal + canForward := s.Forwarding(netProto) && !header.IsV6LinkLocalUnicastAddress(localAddr) && !isLinkLocal // Find a route to the remote with the route table. var chosenRoute tcpip.Route @@ -1874,7 +1573,7 @@ func (s *Stack) WritePacketToRemote(nicID tcpip.NICID, remote tcpip.LinkAddress, ReserveHeaderBytes: int(nic.MaxHeaderLength()), Data: payload, }) - return nic.WritePacketToRemote(remote, nil, netProto, pkt) + return nic.WritePacketToRemote(remote, netProto, pkt) } // NetworkProtocolInstance returns the protocol instance in the stack for the diff --git a/pkg/tcpip/stack/stack_global_state.go b/pkg/tcpip/stack/stack_global_state.go index dfec4258a..33824afd0 100644 --- a/pkg/tcpip/stack/stack_global_state.go +++ b/pkg/tcpip/stack/stack_global_state.go @@ -14,6 +14,78 @@ package stack +import "time" + // StackFromEnv is the global stack created in restore run. // FIXME(b/36201077) var StackFromEnv *Stack + +// saveT is invoked by stateify. +func (t *TCPCubicState) saveT() unixTime { + return unixTime{t.T.Unix(), t.T.UnixNano()} +} + +// loadT is invoked by stateify. +func (t *TCPCubicState) loadT(unix unixTime) { + t.T = time.Unix(unix.second, unix.nano) +} + +// saveXmitTime is invoked by stateify. 
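// Illustrative sketch (not part of the change above). The save/load hooks added
// below exist because time.Time fields are not serialized directly by the
// checkpoint/restore state machinery; each one is flattened to integers and
// rebuilt with time.Unix on load. The shape of that pattern, using a local
// stand-in for the package's unexported unixTime type and storing whole
// seconds plus the sub-second remainder so the value round-trips exactly:
package statesketch

import "time"

// savedTime mirrors the role of the unexported unixTime helper: a time.Time
// flattened into plain integers for checkpointing.
type savedTime struct {
	second int64
	nano   int64
}

// save flattens t for serialization; load rebuilds it on restore.
func save(t time.Time) savedTime {
	return savedTime{second: t.Unix(), nano: int64(t.Nanosecond())}
}

func load(s savedTime) time.Time {
	return time.Unix(s.second, s.nano)
}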
+func (t *TCPRACKState) saveXmitTime() unixTime { + return unixTime{t.XmitTime.Unix(), t.XmitTime.UnixNano()} +} + +// loadXmitTime is invoked by stateify. +func (t *TCPRACKState) loadXmitTime(unix unixTime) { + t.XmitTime = time.Unix(unix.second, unix.nano) +} + +// saveLastSendTime is invoked by stateify. +func (t *TCPSenderState) saveLastSendTime() unixTime { + return unixTime{t.LastSendTime.Unix(), t.LastSendTime.UnixNano()} +} + +// loadLastSendTime is invoked by stateify. +func (t *TCPSenderState) loadLastSendTime(unix unixTime) { + t.LastSendTime = time.Unix(unix.second, unix.nano) +} + +// saveRTTMeasureTime is invoked by stateify. +func (t *TCPSenderState) saveRTTMeasureTime() unixTime { + return unixTime{t.RTTMeasureTime.Unix(), t.RTTMeasureTime.UnixNano()} +} + +// loadRTTMeasureTime is invoked by stateify. +func (t *TCPSenderState) loadRTTMeasureTime(unix unixTime) { + t.RTTMeasureTime = time.Unix(unix.second, unix.nano) +} + +// saveMeasureTime is invoked by stateify. +func (r *RcvBufAutoTuneParams) saveMeasureTime() unixTime { + return unixTime{r.MeasureTime.Unix(), r.MeasureTime.UnixNano()} +} + +// loadMeasureTime is invoked by stateify. +func (r *RcvBufAutoTuneParams) loadMeasureTime(unix unixTime) { + r.MeasureTime = time.Unix(unix.second, unix.nano) +} + +// saveRTTMeasureTime is invoked by stateify. +func (r *RcvBufAutoTuneParams) saveRTTMeasureTime() unixTime { + return unixTime{r.RTTMeasureTime.Unix(), r.RTTMeasureTime.UnixNano()} +} + +// loadRTTMeasureTime is invoked by stateify. +func (r *RcvBufAutoTuneParams) loadRTTMeasureTime(unix unixTime) { + r.RTTMeasureTime = time.Unix(unix.second, unix.nano) +} + +// saveSegTime is invoked by stateify. +func (t *TCPEndpointState) saveSegTime() unixTime { + return unixTime{t.SegTime.Unix(), t.SegTime.UnixNano()} +} + +// loadSegTime is invoked by stateify. +func (t *TCPEndpointState) loadSegTime(unix unixTime) { + t.SegTime = time.Unix(unix.second, unix.nano) +} diff --git a/pkg/tcpip/stack/stack_options.go b/pkg/tcpip/stack/stack_options.go index 3066f4ffd..80e8e0089 100644 --- a/pkg/tcpip/stack/stack_options.go +++ b/pkg/tcpip/stack/stack_options.go @@ -68,7 +68,7 @@ func (s *Stack) SetOption(option interface{}) tcpip.Error { s.mu.Unlock() return nil - case ReceiveBufferSizeOption: + case tcpip.ReceiveBufferSizeOption: // Make sure we don't allow lowering the buffer below minimum // required for stack to work. if v.Min < MinBufferSize { @@ -107,7 +107,7 @@ func (s *Stack) Option(option interface{}) tcpip.Error { s.mu.RUnlock() return nil - case *ReceiveBufferSizeOption: + case *tcpip.ReceiveBufferSizeOption: s.mu.RLock() *v = s.receiveBufferSize s.mu.RUnlock() diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go index 2814b94b4..d2c40cc43 100644 --- a/pkg/tcpip/stack/stack_test.go +++ b/pkg/tcpip/stack/stack_test.go @@ -39,6 +39,7 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" "gvisor.dev/gvisor/pkg/tcpip/stack" + "gvisor.dev/gvisor/pkg/tcpip/testutil" "gvisor.dev/gvisor/pkg/tcpip/transport/udp" ) @@ -137,11 +138,13 @@ func (f *fakeNetworkEndpoint) HandlePacket(pkt *stack.PacketBuffer) { // Handle control packets. if netHdr[protocolNumberOffset] == uint8(fakeControlProtocol) { - nb, ok := pkt.Data().PullUp(fakeNetHeaderLen) + hdr, ok := pkt.Data().PullUp(fakeNetHeaderLen) if !ok { return } - pkt.Data().TrimFront(fakeNetHeaderLen) + // DeleteFront invalidates slices. Make a copy before trimming. + nb := append([]byte(nil), hdr...) 
+ pkt.Data().DeleteFront(fakeNetHeaderLen) f.dispatcher.DeliverTransportError( tcpip.Address(nb[srcAddrOffset:srcAddrOffset+1]), tcpip.Address(nb[dstAddrOffset:dstAddrOffset+1]), @@ -170,7 +173,7 @@ func (f *fakeNetworkEndpoint) NetworkProtocolNumber() tcpip.NetworkProtocolNumbe return f.proto.Number() } -func (f *fakeNetworkEndpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.NetworkHeaderParams, pkt *stack.PacketBuffer) tcpip.Error { +func (f *fakeNetworkEndpoint) WritePacket(r *stack.Route, params stack.NetworkHeaderParams, pkt *stack.PacketBuffer) tcpip.Error { // Increment the sent packet count in the protocol descriptor. f.proto.sendPacketCount[int(r.RemoteAddress()[0])%len(f.proto.sendPacketCount)]++ @@ -189,11 +192,11 @@ func (f *fakeNetworkEndpoint) WritePacket(r *stack.Route, gso *stack.GSO, params return nil } - return f.nic.WritePacket(r, gso, fakeNetNumber, pkt) + return f.nic.WritePacket(r, fakeNetNumber, pkt) } // WritePackets implements stack.LinkEndpoint.WritePackets. -func (*fakeNetworkEndpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.PacketBufferList, params stack.NetworkHeaderParams) (int, tcpip.Error) { +func (*fakeNetworkEndpoint) WritePackets(r *stack.Route, pkts stack.PacketBufferList, params stack.NetworkHeaderParams) (int, tcpip.Error) { panic("not implemented") } @@ -436,7 +439,7 @@ func sendTo(s *stack.Stack, addr tcpip.Address, payload buffer.View) tcpip.Error } func send(r *stack.Route, payload buffer.View) tcpip.Error { - return r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: fakeTransNumber, TTL: 123, TOS: stack.DefaultTOS}, stack.NewPacketBuffer(stack.PacketBufferOptions{ + return r.WritePacket(stack.NetworkHeaderParams{Protocol: fakeTransNumber, TTL: 123, TOS: stack.DefaultTOS}, stack.NewPacketBuffer(stack.PacketBufferOptions{ ReserveHeaderBytes: int(r.MaxHeaderLength()), Data: payload.ToVectorisedView(), })) @@ -1461,7 +1464,7 @@ func TestExternalSendWithHandleLocal(t *testing.T) { if n := ep.Drain(); n != 0 { t.Fatalf("got ep.Drain() = %d, want = 0", n) } - if err := r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{ + if err := r.WritePacket(stack.NetworkHeaderParams{ Protocol: fakeTransNumber, TTL: 123, TOS: stack.DefaultTOS, @@ -1645,10 +1648,10 @@ func TestOutgoingBroadcastWithRouteTable(t *testing.T) { defaultAddr := tcpip.AddressWithPrefix{header.IPv4Any, 0} // Local subnet on NIC1: 192.168.1.58/24, gateway 192.168.1.1. nic1Addr := tcpip.AddressWithPrefix{"\xc0\xa8\x01\x3a", 24} - nic1Gateway := tcpip.Address("\xc0\xa8\x01\x01") + nic1Gateway := testutil.MustParse4("192.168.1.1") // Local subnet on NIC2: 10.10.10.5/24, gateway 10.10.10.1. nic2Addr := tcpip.AddressWithPrefix{"\x0a\x0a\x0a\x05", 24} - nic2Gateway := tcpip.Address("\x0a\x0a\x0a\x01") + nic2Gateway := testutil.MustParse4("10.10.10.1") // Create a new stack with two NICs. 
s := stack.New(stack.Options{ @@ -2789,25 +2792,27 @@ func TestNewPEBOnPromotionToPermanent(t *testing.T) { func TestIPv6SourceAddressSelectionScopeAndSameAddress(t *testing.T) { const ( - linkLocalAddr1 = tcpip.Address("\xfe\x80\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01") - linkLocalAddr2 = tcpip.Address("\xfe\x80\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02") - linkLocalMulticastAddr = tcpip.Address("\xff\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01") - uniqueLocalAddr1 = tcpip.Address("\xfc\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01") - uniqueLocalAddr2 = tcpip.Address("\xfd\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02") - globalAddr1 = tcpip.Address("\xa0\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01") - globalAddr2 = tcpip.Address("\xa0\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02") - globalAddr3 = tcpip.Address("\xa0\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x03") - ipv4MappedIPv6Addr1 = tcpip.Address("\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff\x00\x00\x00\x01") - ipv4MappedIPv6Addr2 = tcpip.Address("\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff\x00\x00\x00\x02") - toredoAddr1 = tcpip.Address("\x20\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01") - toredoAddr2 = tcpip.Address("\x20\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02") - ipv6ToIPv4Addr1 = tcpip.Address("\x20\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01") - ipv6ToIPv4Addr2 = tcpip.Address("\x20\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02") - nicID = 1 lifetimeSeconds = 9999 ) + var ( + linkLocalAddr1 = testutil.MustParse6("fe80::1") + linkLocalAddr2 = testutil.MustParse6("fe80::2") + linkLocalMulticastAddr = testutil.MustParse6("ff02::1") + uniqueLocalAddr1 = testutil.MustParse6("fc00::1") + uniqueLocalAddr2 = testutil.MustParse6("fd00::2") + globalAddr1 = testutil.MustParse6("a000::1") + globalAddr2 = testutil.MustParse6("a000::2") + globalAddr3 = testutil.MustParse6("a000::3") + ipv4MappedIPv6Addr1 = testutil.MustParse6("::ffff:0.0.0.1") + ipv4MappedIPv6Addr2 = testutil.MustParse6("::ffff:0.0.0.2") + toredoAddr1 = testutil.MustParse6("2001::1") + toredoAddr2 = testutil.MustParse6("2001::2") + ipv6ToIPv4Addr1 = testutil.MustParse6("2002::1") + ipv6ToIPv4Addr2 = testutil.MustParse6("2002::2") + ) + prefix1, _, stableGlobalAddr1 := prefixSubnetAddr(0, linkAddr1) prefix2, _, stableGlobalAddr2 := prefixSubnetAddr(1, linkAddr1) @@ -3017,7 +3022,7 @@ func TestIPv6SourceAddressSelectionScopeAndSameAddress(t *testing.T) { s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ NDPConfigs: ipv6.NDPConfigurations{ - HandleRAs: true, + HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled, AutoGenGlobalAddresses: true, AutoGenTempGlobalAddresses: true, }, @@ -3354,21 +3359,21 @@ func TestStackReceiveBufferSizeOption(t *testing.T) { const sMin = stack.MinBufferSize testCases := []struct { name string - rs stack.ReceiveBufferSizeOption + rs tcpip.ReceiveBufferSizeOption err tcpip.Error }{ // Invalid configurations. 
- {"min_below_zero", stack.ReceiveBufferSizeOption{Min: -1, Default: sMin, Max: sMin}, &tcpip.ErrInvalidOptionValue{}}, - {"min_zero", stack.ReceiveBufferSizeOption{Min: 0, Default: sMin, Max: sMin}, &tcpip.ErrInvalidOptionValue{}}, - {"default_below_min", stack.ReceiveBufferSizeOption{Min: sMin, Default: sMin - 1, Max: sMin - 1}, &tcpip.ErrInvalidOptionValue{}}, - {"default_above_max", stack.ReceiveBufferSizeOption{Min: sMin, Default: sMin + 1, Max: sMin}, &tcpip.ErrInvalidOptionValue{}}, - {"max_below_min", stack.ReceiveBufferSizeOption{Min: sMin, Default: sMin + 1, Max: sMin - 1}, &tcpip.ErrInvalidOptionValue{}}, + {"min_below_zero", tcpip.ReceiveBufferSizeOption{Min: -1, Default: sMin, Max: sMin}, &tcpip.ErrInvalidOptionValue{}}, + {"min_zero", tcpip.ReceiveBufferSizeOption{Min: 0, Default: sMin, Max: sMin}, &tcpip.ErrInvalidOptionValue{}}, + {"default_below_min", tcpip.ReceiveBufferSizeOption{Min: sMin, Default: sMin - 1, Max: sMin - 1}, &tcpip.ErrInvalidOptionValue{}}, + {"default_above_max", tcpip.ReceiveBufferSizeOption{Min: sMin, Default: sMin + 1, Max: sMin}, &tcpip.ErrInvalidOptionValue{}}, + {"max_below_min", tcpip.ReceiveBufferSizeOption{Min: sMin, Default: sMin + 1, Max: sMin - 1}, &tcpip.ErrInvalidOptionValue{}}, // Valid Configurations - {"in_ascending_order", stack.ReceiveBufferSizeOption{Min: sMin, Default: sMin + 1, Max: sMin + 2}, nil}, - {"all_equal", stack.ReceiveBufferSizeOption{Min: sMin, Default: sMin, Max: sMin}, nil}, - {"min_default_equal", stack.ReceiveBufferSizeOption{Min: sMin, Default: sMin, Max: sMin + 1}, nil}, - {"default_max_equal", stack.ReceiveBufferSizeOption{Min: sMin, Default: sMin + 1, Max: sMin + 1}, nil}, + {"in_ascending_order", tcpip.ReceiveBufferSizeOption{Min: sMin, Default: sMin + 1, Max: sMin + 2}, nil}, + {"all_equal", tcpip.ReceiveBufferSizeOption{Min: sMin, Default: sMin, Max: sMin}, nil}, + {"min_default_equal", tcpip.ReceiveBufferSizeOption{Min: sMin, Default: sMin, Max: sMin + 1}, nil}, + {"default_max_equal", tcpip.ReceiveBufferSizeOption{Min: sMin, Default: sMin + 1, Max: sMin + 1}, nil}, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { @@ -3377,7 +3382,7 @@ func TestStackReceiveBufferSizeOption(t *testing.T) { if err := s.SetOption(tc.rs); err != tc.err { t.Fatalf("s.SetOption(%#v) = %v, want: %v", tc.rs, err, tc.err) } - var rs stack.ReceiveBufferSizeOption + var rs tcpip.ReceiveBufferSizeOption if tc.err == nil { if err := s.Option(&rs); err != nil { t.Fatalf("s.Option(%#v) = %v, want: nil", rs, err) @@ -3448,7 +3453,7 @@ func TestOutgoingSubnetBroadcast(t *testing.T) { } ipv4Subnet := ipv4Addr.Subnet() ipv4SubnetBcast := ipv4Subnet.Broadcast() - ipv4Gateway := tcpip.Address("\xc0\xa8\x01\x01") + ipv4Gateway := testutil.MustParse4("192.168.1.1") ipv4AddrPrefix31 := tcpip.AddressWithPrefix{ Address: "\xc0\xa8\x01\x3a", PrefixLen: 31, @@ -4352,13 +4357,15 @@ func TestWritePacketToRemote(t *testing.T) { func TestClearNeighborCacheOnNICDisable(t *testing.T) { const ( - nicID = 1 - - ipv4Addr = tcpip.Address("\x01\x02\x03\x04") - ipv6Addr = tcpip.Address("\x01\x02\x03\x04\x01\x02\x03\x04\x01\x02\x03\x04\x01\x02\x03\x04") + nicID = 1 linkAddr = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x06") ) + var ( + ipv4Addr = testutil.MustParse4("1.2.3.4") + ipv6Addr = testutil.MustParse6("102:304:102:304:102:304:102:304") + ) + clock := faketime.NewManualClock() s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{arp.NewProtocol, ipv4.NewProtocol, ipv6.NewProtocol}, diff --git a/pkg/tcpip/stack/tcp.go 
b/pkg/tcpip/stack/tcp.go new file mode 100644 index 000000000..ddff6e2d6 --- /dev/null +++ b/pkg/tcpip/stack/tcp.go @@ -0,0 +1,451 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package stack + +import ( + "time" + + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/seqnum" +) + +// TCPProbeFunc is the expected function type for a TCP probe function to be +// passed to stack.AddTCPProbe. +type TCPProbeFunc func(s TCPEndpointState) + +// TCPCubicState is used to hold a copy of the internal cubic state when the +// TCPProbeFunc is invoked. +// +// +stateify savable +type TCPCubicState struct { + // WLastMax is the previous wMax value. + WLastMax float64 + + // WMax is the value of the congestion window at the time of the last + // congestion event. + WMax float64 + + // T is the time when the current congestion avoidance was entered. + T time.Time `state:".(unixTime)"` + + // TimeSinceLastCongestion denotes the time since the current + // congestion avoidance was entered. + TimeSinceLastCongestion time.Duration + + // C is the cubic constant as specified in RFC8312, page 11. + C float64 + + // K is the time period (in seconds) that the above function takes to + // increase the current window size to WMax if there are no further + // congestion events and is calculated using the following equation: + // + // K = cubic_root(WMax*(1-beta_cubic)/C) (Eq. 2, page 5) + K float64 + + // Beta is the CUBIC multiplication decrease factor. That is, when a + // congestion event is detected, CUBIC reduces its cwnd to + // WC(0)=WMax*beta_cubic. + Beta float64 + + // WC is window computed by CUBIC at time TimeSinceLastCongestion. It's + // calculated using the formula: + // + // WC(TimeSinceLastCongestion) = C*(t-K)^3 + WMax (Eq. 1) + WC float64 + + // WEst is the window computed by CUBIC at time + // TimeSinceLastCongestion+RTT i.e WC(TimeSinceLastCongestion+RTT). + WEst float64 +} + +// TCPRACKState is used to hold a copy of the internal RACK state when the +// TCPProbeFunc is invoked. +// +// +stateify savable +type TCPRACKState struct { + // XmitTime is the transmission timestamp of the most recent + // acknowledged segment. + XmitTime time.Time `state:".(unixTime)"` + + // EndSequence is the ending TCP sequence number of the most recent + // acknowledged segment. + EndSequence seqnum.Value + + // FACK is the highest selectively or cumulatively acknowledged + // sequence. + FACK seqnum.Value + + // RTT is the round trip time of the most recently delivered packet on + // the connection (either cumulatively acknowledged or selectively + // acknowledged) that was not marked invalid as a possible spurious + // retransmission. + RTT time.Duration + + // Reord is true iff reordering has been detected on this connection. + Reord bool + + // DSACKSeen is true iff the connection has seen a DSACK. + DSACKSeen bool + + // ReoWnd is the reordering window time used for recording packet + // transmission times. 
It is used to defer the moment at which RACK + // marks a packet lost. + ReoWnd time.Duration + + // ReoWndIncr is the multiplier applied to adjust reorder window. + ReoWndIncr uint8 + + // ReoWndPersist is the number of loss recoveries before resetting + // reorder window. + ReoWndPersist int8 + + // RTTSeq is the SND.NXT when RTT is updated. + RTTSeq seqnum.Value +} + +// TCPEndpointID is the unique 4 tuple that identifies a given endpoint. +// +// +stateify savable +type TCPEndpointID struct { + // LocalPort is the local port associated with the endpoint. + LocalPort uint16 + + // LocalAddress is the local [network layer] address associated with + // the endpoint. + LocalAddress tcpip.Address + + // RemotePort is the remote port associated with the endpoint. + RemotePort uint16 + + // RemoteAddress it the remote [network layer] address associated with + // the endpoint. + RemoteAddress tcpip.Address +} + +// TCPFastRecoveryState holds a copy of the internal fast recovery state of a +// TCP endpoint. +// +// +stateify savable +type TCPFastRecoveryState struct { + // Active if true indicates the endpoint is in fast recovery. The + // following fields are only meaningful when Active is true. + Active bool + + // First is the first unacknowledged sequence number being recovered. + First seqnum.Value + + // Last is the 'recover' sequence number that indicates the point at + // which we should exit recovery barring any timeouts etc. + Last seqnum.Value + + // MaxCwnd is the maximum value we are permitted to grow the congestion + // window during recovery. This is set at the time we enter recovery. + // It exists to avoid attacks where the receiver intentionally sends + // duplicate acks to artificially inflate the sender's cwnd. + MaxCwnd int + + // HighRxt is the highest sequence number which has been retransmitted + // during the current loss recovery phase. See: RFC 6675 Section 2 for + // details. + HighRxt seqnum.Value + + // RescueRxt is the highest sequence number which has been + // optimistically retransmitted to prevent stalling of the ACK clock + // when there is loss at the end of the window and no new data is + // available for transmission. See: RFC 6675 Section 2 for details. + RescueRxt seqnum.Value +} + +// TCPReceiverState holds a copy of the internal state of the receiver for a +// given TCP endpoint. +// +// +stateify savable +type TCPReceiverState struct { + // RcvNxt is the TCP variable RCV.NXT. + RcvNxt seqnum.Value + + // RcvAcc is one beyond the last acceptable sequence number. That is, + // the "largest" sequence value that the receiver has announced to its + // peer that it's willing to accept. This may be different than RcvNxt + // + (last advertised receive window) if the receive window is reduced; + // in that case we have to reduce the window as we receive more data + // instead of shrinking it. + RcvAcc seqnum.Value + + // RcvWndScale is the window scaling to use for inbound segments. + RcvWndScale uint8 + + // PendingBufUsed is the number of bytes pending in the receive queue. + PendingBufUsed int +} + +// TCPRTTState holds a copy of information about the endpoint's round trip +// time. +// +// +stateify savable +type TCPRTTState struct { + // SRTT is the smoothed round trip time defined in section 2 of RFC + // 6298. + SRTT time.Duration + + // RTTVar is the round-trip time variation as defined in section 2 of + // RFC 6298. + RTTVar time.Duration + + // SRTTInited if true indicates that a valid RTT measurement has been + // completed. 
+ SRTTInited bool +} + +// TCPSenderState holds a copy of the internal state of the sender for a given +// TCP Endpoint. +// +// +stateify savable +type TCPSenderState struct { + // LastSendTime is the timestamp at which we sent the last segment. + LastSendTime time.Time `state:".(unixTime)"` + + // DupAckCount is the number of Duplicate ACKs received. It is used for + // fast retransmit. + DupAckCount int + + // SndCwnd is the size of the sending congestion window in packets. + SndCwnd int + + // Ssthresh is the threshold between slow start and congestion + // avoidance. + Ssthresh int + + // SndCAAckCount is the number of packets acknowledged during + // congestion avoidance. When enough packets have been ack'd (typically + // cwnd packets), the congestion window is incremented by one. + SndCAAckCount int + + // Outstanding is the number of packets that have been sent but not yet + // acknowledged. + Outstanding int + + // SackedOut is the number of packets which have been selectively + // acked. + SackedOut int + + // SndWnd is the send window size in bytes. + SndWnd seqnum.Size + + // SndUna is the next unacknowledged sequence number. + SndUna seqnum.Value + + // SndNxt is the sequence number of the next segment to be sent. + SndNxt seqnum.Value + + // RTTMeasureSeqNum is the sequence number being used for the latest + // RTT measurement. + RTTMeasureSeqNum seqnum.Value + + // RTTMeasureTime is the time when the RTTMeasureSeqNum was sent. + RTTMeasureTime time.Time `state:".(unixTime)"` + + // Closed indicates that the caller has closed the endpoint for + // sending. + Closed bool + + // RTO is the retransmit timeout as defined in section of 2 of RFC + // 6298. + RTO time.Duration + + // RTTState holds information about the endpoint's round trip time. + RTTState TCPRTTState + + // MaxPayloadSize is the maximum size of the payload of a given + // segment. It is initialized on demand. + MaxPayloadSize int + + // SndWndScale is the number of bits to shift left when reading the + // send window size from a segment. + SndWndScale uint8 + + // MaxSentAck is the highest acknowledgement number sent till now. + MaxSentAck seqnum.Value + + // FastRecovery holds the fast recovery state for the endpoint. + FastRecovery TCPFastRecoveryState + + // Cubic holds the state related to CUBIC congestion control. + Cubic TCPCubicState + + // RACKState holds the state related to RACK loss detection algorithm. + RACKState TCPRACKState +} + +// TCPSACKInfo holds TCP SACK related information for a given TCP endpoint. +// +// +stateify savable +type TCPSACKInfo struct { + // Blocks is the list of SACK Blocks that identify the out of order + // segments held by a given TCP endpoint. + Blocks []header.SACKBlock + + // ReceivedBlocks are the SACK blocks received by this endpoint from + // the peer endpoint. + ReceivedBlocks []header.SACKBlock + + // MaxSACKED is the highest sequence number that has been SACKED by the + // peer. + MaxSACKED seqnum.Value +} + +// RcvBufAutoTuneParams holds state related to TCP receive buffer auto-tuning. +// +// +stateify savable +type RcvBufAutoTuneParams struct { + // MeasureTime is the time at which the current measurement was + // started. + MeasureTime time.Time `state:".(unixTime)"` + + // CopiedBytes is the number of bytes copied to user space since this + // measure began. + CopiedBytes int + + // PrevCopiedBytes is the number of bytes copied to userspace in the + // previous RTT period. + PrevCopiedBytes int + + // RcvBufSize is the auto tuned receive buffer size. 
+ RcvBufSize int + + // RTT is the smoothed RTT as measured by observing the time between + // when a byte is first acknowledged and the receipt of data that is at + // least one window beyond the sequence number that was acknowledged. + RTT time.Duration + + // RTTVar is the "round-trip time variation" as defined in section 2 of + // RFC6298. + RTTVar time.Duration + + // RTTMeasureSeqNumber is the highest acceptable sequence number at the + // time this RTT measurement period began. + RTTMeasureSeqNumber seqnum.Value + + // RTTMeasureTime is the absolute time at which the current RTT + // measurement period began. + RTTMeasureTime time.Time `state:".(unixTime)"` + + // Disabled is true if an explicit receive buffer is set for the + // endpoint. + Disabled bool +} + +// TCPRcvBufState contains information about the state of an endpoint's receive +// socket buffer. +// +// +stateify savable +type TCPRcvBufState struct { + // RcvBufUsed is the amount of bytes actually held in the receive + // socket buffer for the endpoint. + RcvBufUsed int + + // RcvBufAutoTuneParams is used to hold state variables to compute the + // auto tuned receive buffer size. + RcvAutoParams RcvBufAutoTuneParams + + // RcvClosed if true, indicates the endpoint has been closed for + // reading. + RcvClosed bool +} + +// TCPSndBufState contains information about the state of an endpoint's send +// socket buffer. +// +// +stateify savable +type TCPSndBufState struct { + // SndBufSize is the size of the socket send buffer. + SndBufSize int + + // SndBufUsed is the number of bytes held in the socket send buffer. + SndBufUsed int + + // SndClosed indicates that the endpoint has been closed for sends. + SndClosed bool + + // SndBufInQueue is the number of bytes in the send queue. + SndBufInQueue seqnum.Size + + // PacketTooBigCount is used to notify the main protocol routine how + // many times a "packet too big" control packet is received. + PacketTooBigCount int + + // SndMTU is the smallest MTU seen in the control packets received. + SndMTU int +} + +// TCPEndpointStateInner contains the members of TCPEndpointState used directly +// (that is, not within another containing struct) within the endpoint's +// internal implementation. +// +// +stateify savable +type TCPEndpointStateInner struct { + // TSOffset is a randomized offset added to the value of the TSVal + // field in the timestamp option. + TSOffset uint32 + + // SACKPermitted is set to true if the peer sends the TCPSACKPermitted + // option in the SYN/SYN-ACK. + SACKPermitted bool + + // SendTSOk is used to indicate when the TS Option has been negotiated. + // When sendTSOk is true every non-RST segment should carry a TS as per + // RFC7323#section-1.1. + SendTSOk bool + + // RecentTS is the timestamp that should be sent in the TSEcr field of + // the timestamp for future segments sent by the endpoint. This field + // is updated if required when a new segment is received by this + // endpoint. + RecentTS uint32 +} + +// TCPEndpointState is a copy of the internal state of a TCP endpoint. +// +// +stateify savable +type TCPEndpointState struct { + // TCPEndpointStateInner contains the members of TCPEndpointState used + // by the endpoint's internal implementation. + TCPEndpointStateInner + + // ID is a copy of the TransportEndpointID for the endpoint. + ID TCPEndpointID + + // SegTime denotes the absolute time when this segment was received. 
+ SegTime time.Time `state:".(unixTime)"` + + // RcvBufState contains information about the state of the endpoint's + // receive socket buffer. + RcvBufState TCPRcvBufState + + // SndBufState contains information about the state of the endpoint's + // send socket buffer. + SndBufState TCPSndBufState + + // SACK holds TCP SACK related information for this endpoint. + SACK TCPSACKInfo + + // Receiver holds variables related to the TCP receiver for the + // endpoint. + Receiver TCPReceiverState + + // Sender holds state related to the TCP Sender for the endpoint. + Sender TCPSenderState +} diff --git a/pkg/tcpip/stack/transport_demuxer.go b/pkg/tcpip/stack/transport_demuxer.go index e188efccb..80ad1a9d4 100644 --- a/pkg/tcpip/stack/transport_demuxer.go +++ b/pkg/tcpip/stack/transport_demuxer.go @@ -150,16 +150,17 @@ func (epsByNIC *endpointsByNIC) transportEndpoints() []TransportEndpoint { return eps } -// HandlePacket is called by the stack when new packets arrive to this transport -// endpoint. -func (epsByNIC *endpointsByNIC) handlePacket(id TransportEndpointID, pkt *PacketBuffer) { +// handlePacket is called by the stack when new packets arrive to this transport +// endpoint. It returns false if the packet could not be matched to any +// transport endpoint, true otherwise. +func (epsByNIC *endpointsByNIC) handlePacket(id TransportEndpointID, pkt *PacketBuffer) bool { epsByNIC.mu.RLock() mpep, ok := epsByNIC.endpoints[pkt.NICID] if !ok { if mpep, ok = epsByNIC.endpoints[0]; !ok { epsByNIC.mu.RUnlock() // Don't use defer for performance reasons. - return + return false } } @@ -168,18 +169,19 @@ func (epsByNIC *endpointsByNIC) handlePacket(id TransportEndpointID, pkt *Packet if isInboundMulticastOrBroadcast(pkt, id.LocalAddress) { mpep.handlePacketAll(id, pkt) epsByNIC.mu.RUnlock() // Don't use defer for performance reasons. - return + return true } // multiPortEndpoints are guaranteed to have at least one element. transEP := selectEndpoint(id, mpep, epsByNIC.seed) if queuedProtocol, mustQueue := mpep.demux.queuedProtocols[protocolIDs{mpep.netProto, mpep.transProto}]; mustQueue { queuedProtocol.QueuePacket(transEP, id, pkt) epsByNIC.mu.RUnlock() - return + return true } transEP.HandlePacket(id, pkt) epsByNIC.mu.RUnlock() // Don't use defer for performance reasons. + return true } // handleError delivers an error to the transport endpoint identified by id. 
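The new stack/tcp.go above defines the read-only snapshot types handed to TCP probes. As a rough sketch that is not part of this change, a probe registered through the stack.AddTCPProbe hook named in the TCPProbeFunc comment could log a few sender-side fields; the package name and the exact AddTCPProbe signature are assumptions here.

package example

import (
	"log"

	"gvisor.dev/gvisor/pkg/tcpip/stack"
)

// logProbe prints a few sender-side fields from the TCPEndpointState snapshot
// that the stack hands to registered TCP probes.
func logProbe(state stack.TCPEndpointState) {
	log.Printf("tcp :%d -> :%d cwnd=%d ssthresh=%d srtt=%s",
		state.ID.LocalPort, state.ID.RemotePort,
		state.Sender.SndCwnd, state.Sender.Ssthresh, state.Sender.RTTState.SRTT)
}

// registerProbe wires the probe into a stack. AddTCPProbe is assumed to take a
// stack.TCPProbeFunc, as the TCPProbeFunc comment suggests.
func registerProbe(s *stack.Stack) {
	s.AddTCPProbe(logProbe)
}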
@@ -567,8 +569,7 @@ func (d *transportDemuxer) deliverPacket(protocol tcpip.TransportProtocolNumber, } return false } - ep.handlePacket(id, pkt) - return true + return ep.handlePacket(id, pkt) } // deliverRawPacket attempts to deliver the given packet and returns whether it diff --git a/pkg/tcpip/stack/transport_test.go b/pkg/tcpip/stack/transport_test.go index 054cced0c..839178809 100644 --- a/pkg/tcpip/stack/transport_test.go +++ b/pkg/tcpip/stack/transport_test.go @@ -70,7 +70,7 @@ func (f *fakeTransportEndpoint) SocketOptions() *tcpip.SocketOptions { func newFakeTransportEndpoint(proto *fakeTransportProtocol, netProto tcpip.NetworkProtocolNumber, s *stack.Stack) tcpip.Endpoint { ep := &fakeTransportEndpoint{TransportEndpointInfo: stack.TransportEndpointInfo{NetProto: netProto}, proto: proto, uniqueID: s.UniqueID()} - ep.ops.InitHandler(ep, s, tcpip.GetStackSendBufferLimits) + ep.ops.InitHandler(ep, s, tcpip.GetStackSendBufferLimits, tcpip.GetStackReceiveBufferLimits) return ep } @@ -106,7 +106,7 @@ func (f *fakeTransportEndpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions Data: buffer.View(v).ToVectorisedView(), }) _ = pkt.TransportHeader().Push(fakeTransHeaderLen) - if err := f.route.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: fakeTransNumber, TTL: 123, TOS: stack.DefaultTOS}, pkt); err != nil { + if err := f.route.WritePacket(stack.NetworkHeaderParams{Protocol: fakeTransNumber, TTL: 123, TOS: stack.DefaultTOS}, pkt); err != nil { return 0, err } @@ -233,7 +233,7 @@ func (f *fakeTransportEndpoint) HandlePacket(id stack.TransportEndpointID, pkt * peerAddr: route.RemoteAddress(), route: route, } - ep.ops.InitHandler(ep, f.proto.stack, tcpip.GetStackSendBufferLimits) + ep.ops.InitHandler(ep, f.proto.stack, tcpip.GetStackSendBufferLimits, tcpip.GetStackReceiveBufferLimits) f.acceptQueue = append(f.acceptQueue, ep) } diff --git a/pkg/tcpip/stdclock.go b/pkg/tcpip/stdclock.go new file mode 100644 index 000000000..7ce43a68e --- /dev/null +++ b/pkg/tcpip/stdclock.go @@ -0,0 +1,130 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tcpip + +import ( + "fmt" + "time" + + "gvisor.dev/gvisor/pkg/sync" +) + +// stdClock implements Clock with the time package. +// +// +stateify savable +type stdClock struct { + // baseTime holds the time when the clock was constructed. + // + // This value is used to calculate the monotonic time from the time package. + // As per https://golang.org/pkg/time/#hdr-Monotonic_Clocks, + // + // Operating systems provide both a “wall clock,” which is subject to + // changes for clock synchronization, and a “monotonic clock,” which is not. + // The general rule is that the wall clock is for telling time and the + // monotonic clock is for measuring time. 
Rather than split the API, in this + // package the Time returned by time.Now contains both a wall clock reading + // and a monotonic clock reading; later time-telling operations use the wall + // clock reading, but later time-measuring operations, specifically + // comparisons and subtractions, use the monotonic clock reading. + // + // ... + // + // If Times t and u both contain monotonic clock readings, the operations + // t.After(u), t.Before(u), t.Equal(u), and t.Sub(u) are carried out using + // the monotonic clock readings alone, ignoring the wall clock readings. If + // either t or u contains no monotonic clock reading, these operations fall + // back to using the wall clock readings. + // + // Given the above, we can safely conclude that time.Since(baseTime) will + // return monotonically increasing values if we use time.Now() to set baseTime + // at the time of clock construction. + // + // Note that time.Since(t) is shorthand for time.Now().Sub(t), as per + // https://golang.org/pkg/time/#Since. + baseTime time.Time `state:"nosave"` + + // monotonicOffset is the offset applied to the calculated monotonic time. + // + // monotonicOffset is assigned maxMonotonic after restore so that the + // monotonic time will continue from where it "left off" before saving as part + // of S/R. + monotonicOffset int64 `state:"nosave"` + + // monotonicMU protects maxMonotonic. + monotonicMU sync.Mutex `state:"nosave"` + maxMonotonic int64 +} + +// NewStdClock returns an instance of a clock that uses the time package. +func NewStdClock() Clock { + return &stdClock{ + baseTime: time.Now(), + } +} + +var _ Clock = (*stdClock)(nil) + +// NowNanoseconds implements Clock.NowNanoseconds. +func (*stdClock) NowNanoseconds() int64 { + return time.Now().UnixNano() +} + +// NowMonotonic implements Clock.NowMonotonic. +func (s *stdClock) NowMonotonic() int64 { + sinceBase := time.Since(s.baseTime) + if sinceBase < 0 { + panic(fmt.Sprintf("got negative duration = %s since base time = %s", sinceBase, s.baseTime)) + } + + monotonicValue := sinceBase.Nanoseconds() + s.monotonicOffset + + s.monotonicMU.Lock() + defer s.monotonicMU.Unlock() + + // Monotonic time values must never decrease. + if monotonicValue > s.maxMonotonic { + s.maxMonotonic = monotonicValue + } + + return s.maxMonotonic +} + +// AfterFunc implements Clock.AfterFunc. +func (*stdClock) AfterFunc(d time.Duration, f func()) Timer { + return &stdTimer{ + t: time.AfterFunc(d, f), + } +} + +type stdTimer struct { + t *time.Timer +} + +var _ Timer = (*stdTimer)(nil) + +// Stop implements Timer.Stop. +func (st *stdTimer) Stop() bool { + return st.t.Stop() +} + +// Reset implements Timer.Reset. +func (st *stdTimer) Reset(d time.Duration) { + st.t.Reset(d) +} + +// NewStdTimer returns a Timer implemented with the time package. +func NewStdTimer(t *time.Timer) Timer { + return &stdTimer{t: t} +} diff --git a/pkg/tcpip/transport/tcp/cubic_state.go b/pkg/tcpip/stdclock_state.go index d0f58cfaf..795db9181 100644 --- a/pkg/tcpip/transport/tcp/cubic_state.go +++ b/pkg/tcpip/stdclock_state.go @@ -1,4 +1,4 @@ -// Copyright 2019 The gVisor Authors. +// Copyright 2021 The gVisor Authors. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -12,18 +12,15 @@ // See the License for the specific language governing permissions and // limitations under the License. -package tcp +package tcpip -import ( - "time" -) +import "time" -// saveT is invoked by stateify. 
-func (c *cubicState) saveT() unixTime {
-	return unixTime{c.t.Unix(), c.t.UnixNano()}
-}
+// afterLoad is invoked by stateify.
+func (s *stdClock) afterLoad() {
+	s.baseTime = time.Now()

-// loadT is invoked by stateify.
-func (c *cubicState) loadT(unix unixTime) {
-	c.t = time.Unix(unix.second, unix.nano)
+	s.monotonicMU.Lock()
+	defer s.monotonicMU.Unlock()
+	s.monotonicOffset = s.maxMonotonic
 }
diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go
index 87ea09a5e..d5f941c5f 100644
--- a/pkg/tcpip/tcpip.go
+++ b/pkg/tcpip/tcpip.go
@@ -73,7 +73,7 @@ type Clock interface {
 	// nanoseconds since the Unix epoch.
 	NowNanoseconds() int64

-	// NowMonotonic returns a monotonic time value.
+	// NowMonotonic returns a monotonic time value at nanosecond resolution.
 	NowMonotonic() int64

 	// AfterFunc waits for the duration to elapse and then calls f in its own
@@ -691,10 +691,6 @@ const (
 	// number of unread bytes in the input buffer should be returned.
 	ReceiveQueueSizeOption

-	// ReceiveBufferSizeOption is used by SetSockOptInt/GetSockOptInt to
-	// specify the receive buffer size option.
-	ReceiveBufferSizeOption
-
 	// SendQueueSizeOption is used in GetSockOptInt to specify that the
 	// number of unread bytes in the output buffer should be returned.
 	SendQueueSizeOption
@@ -786,6 +782,13 @@ func (*TCPRecovery) isGettableTransportProtocolOption() {}

 func (*TCPRecovery) isSettableTransportProtocolOption() {}

+// TCPAlwaysUseSynCookies indicates unconditional usage of SYN cookies.
+type TCPAlwaysUseSynCookies bool
+
+func (*TCPAlwaysUseSynCookies) isGettableTransportProtocolOption() {}
+
+func (*TCPAlwaysUseSynCookies) isSettableTransportProtocolOption() {}
+
 const (
 	// TCPRACKLossDetection indicates RACK is used for loss detection and
 	// recovery.
@@ -1020,19 +1023,6 @@ func (*TCPMaxRetriesOption) isGettableTransportProtocolOption() {}

 func (*TCPMaxRetriesOption) isSettableTransportProtocolOption() {}

-// TCPSynRcvdCountThresholdOption is used by SetSockOpt/GetSockOpt to specify
-// the number of endpoints that can be in SYN-RCVD state before the stack
-// switches to using SYN cookies.
-type TCPSynRcvdCountThresholdOption uint64
-
-func (*TCPSynRcvdCountThresholdOption) isGettableSocketOption() {}
-
-func (*TCPSynRcvdCountThresholdOption) isSettableSocketOption() {}
-
-func (*TCPSynRcvdCountThresholdOption) isGettableTransportProtocolOption() {}
-
-func (*TCPSynRcvdCountThresholdOption) isSettableTransportProtocolOption() {}
-
 // TCPSynRetriesOption is used by SetSockOpt/GetSockOpt to specify stack-wide
 // default for number of times SYN is retransmitted before aborting a connect.
 type TCPSynRetriesOption uint8
@@ -1117,6 +1107,7 @@ const (
 // LingerOption is used by SetSockOpt/GetSockOpt to set/get the
 // duration for which a socket lingers before returning from Close.
 //
+// +marshal
 // +stateify savable
 type LingerOption struct {
 	Enabled bool
@@ -1150,6 +1141,19 @@ type SendBufferSizeOption struct {
 	Max int
 }

+// ReceiveBufferSizeOption is used by (*stack.Stack).Option/SetOption to
+// get/set the default, min and max receive buffer sizes.
+type ReceiveBufferSizeOption struct {
+	// Min is the minimum size for the receive buffer.
+	Min int
+
+	// Default is the default size for the receive buffer.
+	Default int
+
+	// Max is the maximum size for the receive buffer.
+	Max int
+}
+
 // GetSendBufferLimits is used to get the send buffer size limits.
 type GetSendBufferLimits func(StackHandler) SendBufferSizeOption

@@ -1162,6 +1166,18 @@ func GetStackSendBufferLimits(so StackHandler) SendBufferSizeOption {
 	return ss
 }

+// GetReceiveBufferLimits is used to get the receive buffer size limits.
+type GetReceiveBufferLimits func(StackHandler) ReceiveBufferSizeOption
+
+// GetStackReceiveBufferLimits is used to get default, min and max receive buffer size.
+func GetStackReceiveBufferLimits(so StackHandler) ReceiveBufferSizeOption {
+	var ss ReceiveBufferSizeOption
+	if err := so.Option(&ss); err != nil {
+		panic(fmt.Sprintf("s.Option(%#v) = %s", ss, err))
+	}
+	return ss
+}
+
 // Route is a row in the routing table. It specifies through which NIC (and
 // gateway) sets of packets should be routed. A row is considered viable if the
 // masked target address matches the destination address in the row.
@@ -1218,7 +1234,7 @@ func (s *StatCounter) Decrement() {
 }

 // Value returns the current value of the counter.
-func (s *StatCounter) Value() uint64 {
+func (s *StatCounter) Value(name ...string) uint64 {
 	return atomic.LoadUint64(&s.count)
 }

@@ -1512,6 +1528,30 @@ type IGMPStats struct {
 	// LINT.ThenChange(network/ipv4/stats.go:multiCounterIGMPStats)
 }

+// IPForwardingStats collects stats related to IP forwarding (both v4 and v6).
+type IPForwardingStats struct {
+	// Unrouteable is the number of IP packets received which were dropped
+	// because the netstack could not construct a route to their
+	// destination.
+	Unrouteable *StatCounter
+
+	// ExhaustedTTL is the number of IP packets received which were dropped
+	// because their TTL was exhausted.
+	ExhaustedTTL *StatCounter
+
+	// LinkLocalSource is the number of IP packets which were dropped
+	// because they contained a link-local source address.
+	LinkLocalSource *StatCounter
+
+	// LinkLocalDestination is the number of IP packets which were dropped
+	// because they contained a link-local destination address.
+	LinkLocalDestination *StatCounter
+
+	// Errors is the number of IP packets received which could not be
+	// successfully forwarded.
+	Errors *StatCounter
+}
+
 // IPStats collects IP-specific stats (both v4 and v6).
 type IPStats struct {
 	// LINT.IfChange(IPStats)
@@ -1562,6 +1602,10 @@ type IPStats struct {
 	// chain.
 	IPTablesOutputDropped *StatCounter

+	// IPTablesPostroutingDropped is the number of IP packets dropped in the
+	// Postrouting chain.
+	IPTablesPostroutingDropped *StatCounter
+
 	// TODO(https://gvisor.dev/issues/5529): Move the IPv4-only option stats out
 	// of IPStats.
 	// OptionTimestampReceived is the number of Timestamp options seen.
@@ -1576,6 +1620,9 @@ type IPStats struct {
 	// OptionUnknownReceived is the number of unknown IP options seen.
 	OptionUnknownReceived *StatCounter

+	// Forwarding collects stats related to IP forwarding.
+	Forwarding IPForwardingStats
+
 	// LINT.ThenChange(network/internal/ip/stats.go:MultiCounterIPStats)
 }

@@ -1734,6 +1781,10 @@ type TCPStats struct {

 	// ChecksumErrors is the number of segments dropped due to bad checksums.
 	ChecksumErrors *StatCounter
+
+	// FailedPortReservations is the number of times TCP failed to reserve
+	// a port.
+	FailedPortReservations *StatCounter
 }

 // UDPStats collects UDP-specific stats.
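The IPForwardingStats counters added above hang off IPStats.Forwarding, so they can be read from a stack's Stats() snapshot like any other StatCounter. A minimal sketch, not part of this change (the package and function names are illustrative):

package example

import (
	"fmt"

	"gvisor.dev/gvisor/pkg/tcpip/stack"
)

// reportForwardingDrops prints the per-reason forwarding drop counters that
// this change adds to tcpip.IPStats.
func reportForwardingDrops(s *stack.Stack) {
	fwd := s.Stats().IP.Forwarding
	fmt.Printf("unrouteable=%d exhausted-ttl=%d link-local-src=%d link-local-dst=%d errors=%d\n",
		fwd.Unrouteable.Value(),
		fwd.ExhaustedTTL.Value(),
		fwd.LinkLocalSource.Value(),
		fwd.LinkLocalDestination.Value(),
		fwd.Errors.Value())
}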
diff --git a/pkg/tcpip/tests/integration/BUILD b/pkg/tcpip/tests/integration/BUILD index 3cc8c36f1..d4f7bb5ff 100644 --- a/pkg/tcpip/tests/integration/BUILD +++ b/pkg/tcpip/tests/integration/BUILD @@ -9,11 +9,14 @@ go_test( deps = [ "//pkg/tcpip", "//pkg/tcpip/checker", + "//pkg/tcpip/header", + "//pkg/tcpip/link/channel", "//pkg/tcpip/network/arp", "//pkg/tcpip/network/ipv4", "//pkg/tcpip/network/ipv6", "//pkg/tcpip/stack", "//pkg/tcpip/tests/utils", + "//pkg/tcpip/testutil", "//pkg/tcpip/transport/tcp", "//pkg/tcpip/transport/udp", "//pkg/waiter", @@ -78,6 +81,7 @@ go_test( "//pkg/tcpip/network/ipv6", "//pkg/tcpip/stack", "//pkg/tcpip/tests/utils", + "//pkg/tcpip/testutil", "//pkg/tcpip/transport/icmp", "//pkg/tcpip/transport/tcp", "//pkg/tcpip/transport/udp", @@ -101,6 +105,7 @@ go_test( "//pkg/tcpip/network/ipv6", "//pkg/tcpip/stack", "//pkg/tcpip/tests/utils", + "//pkg/tcpip/testutil", "//pkg/tcpip/transport/icmp", "//pkg/tcpip/transport/udp", "//pkg/waiter", @@ -123,6 +128,7 @@ go_test( "//pkg/tcpip/network/ipv6", "//pkg/tcpip/stack", "//pkg/tcpip/tests/utils", + "//pkg/tcpip/testutil", "//pkg/tcpip/transport/icmp", "//pkg/tcpip/transport/udp", "//pkg/waiter", diff --git a/pkg/tcpip/tests/integration/forward_test.go b/pkg/tcpip/tests/integration/forward_test.go index d10ae05c2..dbd279c94 100644 --- a/pkg/tcpip/tests/integration/forward_test.go +++ b/pkg/tcpip/tests/integration/forward_test.go @@ -21,11 +21,14 @@ import ( "github.com/google/go-cmp/cmp" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/checker" + "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/link/channel" "gvisor.dev/gvisor/pkg/tcpip/network/arp" "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" "gvisor.dev/gvisor/pkg/tcpip/stack" "gvisor.dev/gvisor/pkg/tcpip/tests/utils" + "gvisor.dev/gvisor/pkg/tcpip/testutil" "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" "gvisor.dev/gvisor/pkg/tcpip/transport/udp" "gvisor.dev/gvisor/pkg/waiter" @@ -312,3 +315,194 @@ func TestForwarding(t *testing.T) { }) } } + +func TestMulticastForwarding(t *testing.T) { + const ( + nicID1 = 1 + nicID2 = 2 + ttl = 64 + ) + + var ( + ipv4LinkLocalUnicastAddr = testutil.MustParse4("169.254.0.10") + ipv4LinkLocalMulticastAddr = testutil.MustParse4("224.0.0.10") + ipv4GlobalMulticastAddr = testutil.MustParse4("224.0.1.10") + + ipv6LinkLocalUnicastAddr = testutil.MustParse6("fe80::a") + ipv6LinkLocalMulticastAddr = testutil.MustParse6("ff02::a") + ipv6GlobalMulticastAddr = testutil.MustParse6("ff0e::a") + ) + + rxICMPv4EchoRequest := func(e *channel.Endpoint, src, dst tcpip.Address) { + utils.RxICMPv4EchoRequest(e, src, dst, ttl) + } + + rxICMPv6EchoRequest := func(e *channel.Endpoint, src, dst tcpip.Address) { + utils.RxICMPv6EchoRequest(e, src, dst, ttl) + } + + v4Checker := func(t *testing.T, b []byte, src, dst tcpip.Address) { + checker.IPv4(t, b, + checker.SrcAddr(src), + checker.DstAddr(dst), + checker.TTL(ttl-1), + checker.ICMPv4( + checker.ICMPv4Type(header.ICMPv4Echo))) + } + + v6Checker := func(t *testing.T, b []byte, src, dst tcpip.Address) { + checker.IPv6(t, b, + checker.SrcAddr(src), + checker.DstAddr(dst), + checker.TTL(ttl-1), + checker.ICMPv6( + checker.ICMPv6Type(header.ICMPv6EchoRequest))) + } + + tests := []struct { + name string + srcAddr, dstAddr tcpip.Address + rx func(*channel.Endpoint, tcpip.Address, tcpip.Address) + expectForward bool + checker func(*testing.T, []byte) + }{ + { + name: "IPv4 link-local multicast destination", + srcAddr: utils.RemoteIPv4Addr, + 
dstAddr: ipv4LinkLocalMulticastAddr, + rx: rxICMPv4EchoRequest, + expectForward: false, + }, + { + name: "IPv4 link-local source", + srcAddr: ipv4LinkLocalUnicastAddr, + dstAddr: utils.RemoteIPv4Addr, + rx: rxICMPv4EchoRequest, + expectForward: false, + }, + { + name: "IPv4 link-local destination", + srcAddr: utils.RemoteIPv4Addr, + dstAddr: ipv4LinkLocalUnicastAddr, + rx: rxICMPv4EchoRequest, + expectForward: false, + }, + { + name: "IPv4 non-link-local unicast", + srcAddr: utils.RemoteIPv4Addr, + dstAddr: utils.Ipv4Addr2.AddressWithPrefix.Address, + rx: rxICMPv4EchoRequest, + expectForward: true, + checker: func(t *testing.T, b []byte) { + v4Checker(t, b, utils.RemoteIPv4Addr, utils.Ipv4Addr2.AddressWithPrefix.Address) + }, + }, + { + name: "IPv4 non-link-local multicast", + srcAddr: utils.RemoteIPv4Addr, + dstAddr: ipv4GlobalMulticastAddr, + rx: rxICMPv4EchoRequest, + expectForward: true, + checker: func(t *testing.T, b []byte) { + v4Checker(t, b, utils.RemoteIPv4Addr, ipv4GlobalMulticastAddr) + }, + }, + + { + name: "IPv6 link-local multicast destination", + srcAddr: utils.RemoteIPv6Addr, + dstAddr: ipv6LinkLocalMulticastAddr, + rx: rxICMPv6EchoRequest, + expectForward: false, + }, + { + name: "IPv6 link-local source", + srcAddr: ipv6LinkLocalUnicastAddr, + dstAddr: utils.RemoteIPv6Addr, + rx: rxICMPv6EchoRequest, + expectForward: false, + }, + { + name: "IPv6 link-local destination", + srcAddr: utils.RemoteIPv6Addr, + dstAddr: ipv6LinkLocalUnicastAddr, + rx: rxICMPv6EchoRequest, + expectForward: false, + }, + { + name: "IPv6 non-link-local unicast", + srcAddr: utils.RemoteIPv6Addr, + dstAddr: utils.Ipv6Addr2.AddressWithPrefix.Address, + rx: rxICMPv6EchoRequest, + expectForward: true, + checker: func(t *testing.T, b []byte) { + v6Checker(t, b, utils.RemoteIPv6Addr, utils.Ipv6Addr2.AddressWithPrefix.Address) + }, + }, + { + name: "IPv6 non-link-local multicast", + srcAddr: utils.RemoteIPv6Addr, + dstAddr: ipv6GlobalMulticastAddr, + rx: rxICMPv6EchoRequest, + expectForward: true, + checker: func(t *testing.T, b []byte) { + v6Checker(t, b, utils.RemoteIPv6Addr, ipv6GlobalMulticastAddr) + }, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}, + }) + + e1 := channel.New(1, header.IPv6MinimumMTU, "") + if err := s.CreateNIC(nicID1, e1); err != nil { + t.Fatalf("s.CreateNIC(%d, _): %s", nicID1, err) + } + + e2 := channel.New(1, header.IPv6MinimumMTU, "") + if err := s.CreateNIC(nicID2, e2); err != nil { + t.Fatalf("s.CreateNIC(%d, _): %s", nicID2, err) + } + + if err := s.AddAddress(nicID2, ipv4.ProtocolNumber, utils.Ipv4Addr.Address); err != nil { + t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID2, ipv4.ProtocolNumber, utils.Ipv4Addr.Address, err) + } + if err := s.AddAddress(nicID2, ipv6.ProtocolNumber, utils.Ipv6Addr.Address); err != nil { + t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID2, ipv6.ProtocolNumber, utils.Ipv6Addr.Address, err) + } + + if err := s.SetForwarding(ipv4.ProtocolNumber, true); err != nil { + t.Fatalf("s.SetForwarding(%d, true): %s", ipv4.ProtocolNumber, err) + } + if err := s.SetForwarding(ipv6.ProtocolNumber, true); err != nil { + t.Fatalf("s.SetForwarding(%d, true): %s", ipv6.ProtocolNumber, err) + } + + s.SetRouteTable([]tcpip.Route{ + { + Destination: header.IPv4EmptySubnet, + NIC: nicID2, + }, + { + Destination: header.IPv6EmptySubnet, + NIC: 
nicID2, + }, + }) + + test.rx(e1, test.srcAddr, test.dstAddr) + + p, ok := e2.Read() + if ok != test.expectForward { + t.Fatalf("got e2.Read() = (%#v, %t), want = (_, %t)", p, ok, test.expectForward) + } + + if test.expectForward { + test.checker(t, stack.PayloadSince(p.Pkt.NetworkHeader())) + } + }) + } +} diff --git a/pkg/tcpip/tests/integration/iptables_test.go b/pkg/tcpip/tests/integration/iptables_test.go index 1cfd854a0..c61d4e788 100644 --- a/pkg/tcpip/tests/integration/iptables_test.go +++ b/pkg/tcpip/tests/integration/iptables_test.go @@ -347,7 +347,7 @@ type channelEndpointWithoutWritePacket struct { t *testing.T } -func (c *channelEndpointWithoutWritePacket) WritePacket(stack.RouteInfo, *stack.GSO, tcpip.NetworkProtocolNumber, *stack.PacketBuffer) tcpip.Error { +func (c *channelEndpointWithoutWritePacket) WritePacket(stack.RouteInfo, tcpip.NetworkProtocolNumber, *stack.PacketBuffer) tcpip.Error { c.t.Error("unexpectedly called WritePacket; all writes should go through WritePackets") return &tcpip.ErrNotSupported{} } @@ -627,7 +627,7 @@ func TestIPTableWritePackets(t *testing.T) { pkts := test.genPacket(r) pktsLen := pkts.Len() - if n, err := r.WritePackets(nil /* gso */, pkts, stack.NetworkHeaderParams{ + if n, err := r.WritePackets(pkts, stack.NetworkHeaderParams{ Protocol: header.UDPProtocolNumber, TTL: 64, }); err != nil { diff --git a/pkg/tcpip/tests/integration/link_resolution_test.go b/pkg/tcpip/tests/integration/link_resolution_test.go index d39809e1c..c657714ba 100644 --- a/pkg/tcpip/tests/integration/link_resolution_test.go +++ b/pkg/tcpip/tests/integration/link_resolution_test.go @@ -687,10 +687,10 @@ func TestWritePacketsLinkResolution(t *testing.T) { TOS: stack.DefaultTOS, } - if n, err := r.WritePackets(nil /* gso */, pkts, params); err != nil { - t.Fatalf("r.WritePackets(nil, %#v, _): %s", params, err) + if n, err := r.WritePackets(pkts, params); err != nil { + t.Fatalf("r.WritePackets(_, %#v): %s", params, err) } else if want := pkts.Len(); want != n { - t.Fatalf("got r.WritePackets(nil, %#v, _) = %d, want = %d", n, params, want) + t.Fatalf("got r.WritePackets(_, %#v) = %d, want = %d", params, n, want) } var writer bytes.Buffer diff --git a/pkg/tcpip/tests/integration/loopback_test.go b/pkg/tcpip/tests/integration/loopback_test.go index 2c538a43e..3df1bbd68 100644 --- a/pkg/tcpip/tests/integration/loopback_test.go +++ b/pkg/tcpip/tests/integration/loopback_test.go @@ -30,6 +30,7 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" "gvisor.dev/gvisor/pkg/tcpip/stack" "gvisor.dev/gvisor/pkg/tcpip/tests/utils" + "gvisor.dev/gvisor/pkg/tcpip/testutil" "gvisor.dev/gvisor/pkg/tcpip/transport/icmp" "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" "gvisor.dev/gvisor/pkg/tcpip/transport/udp" @@ -314,11 +315,11 @@ func TestLoopbackSubnetLifetimeBoundToAddr(t *testing.T) { TOS: stack.DefaultTOS, } data := buffer.View([]byte{1, 2, 3, 4}) - if err := r.WritePacket(nil /* gso */, params, stack.NewPacketBuffer(stack.PacketBufferOptions{ + if err := r.WritePacket(params, stack.NewPacketBuffer(stack.PacketBufferOptions{ ReserveHeaderBytes: int(r.MaxHeaderLength()), Data: data.ToVectorisedView(), })); err != nil { - t.Fatalf("r.WritePacket(nil, %#v, _): %s", params, err) + t.Fatalf("r.WritePacket(%#v, _): %s", params, err) } // Removing the address should make the endpoint invalid. 
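The hunks in this area replace hand-written byte-string addresses with the new testutil.MustParse4/MustParse6 helpers. A small sketch of the equivalence, not part of this change (testutil is marked testonly in its BUILD file, so in practice it is only imported from tests):

package main

import (
	"fmt"

	"gvisor.dev/gvisor/pkg/tcpip"
	"gvisor.dev/gvisor/pkg/tcpip/testutil"
)

func main() {
	// MustParse4 yields the same 4-byte tcpip.Address as the old literal form.
	oldStyle := tcpip.Address("\x7f\x00\x00\x01")
	loopback4 := testutil.MustParse4("127.0.0.1")
	// MustParse6 yields a 16-byte address from readable IPv6 notation.
	linkLocal6 := testutil.MustParse6("fe80::1")
	fmt.Println(loopback4 == oldStyle, len(linkLocal6)) // true 16
}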
@@ -326,12 +327,12 @@ func TestLoopbackSubnetLifetimeBoundToAddr(t *testing.T) { t.Fatalf("s.RemoveAddress(%d, %s): %s", nicID, protoAddr.AddressWithPrefix.Address, err) } { - err := r.WritePacket(nil /* gso */, params, stack.NewPacketBuffer(stack.PacketBufferOptions{ + err := r.WritePacket(params, stack.NewPacketBuffer(stack.PacketBufferOptions{ ReserveHeaderBytes: int(r.MaxHeaderLength()), Data: data.ToVectorisedView(), })) if _, ok := err.(*tcpip.ErrInvalidEndpointState); !ok { - t.Fatalf("got r.WritePacket(nil, %#v, _) = %s, want = %s", params, err, &tcpip.ErrInvalidEndpointState{}) + t.Fatalf("got r.WritePacket(%#v, _) = %s, want = %s", params, err, &tcpip.ErrInvalidEndpointState{}) } } } @@ -510,25 +511,25 @@ func TestExternalLoopbackTraffic(t *testing.T) { nicID1 = 1 nicID2 = 2 - ipv4Loopback = tcpip.Address("\x7f\x00\x00\x01") - numPackets = 1 + ttl = 64 ) + ipv4Loopback := testutil.MustParse4("127.0.0.1") loopbackSourcedICMPv4 := func(e *channel.Endpoint) { - utils.RxICMPv4EchoRequest(e, ipv4Loopback, utils.Ipv4Addr.Address) + utils.RxICMPv4EchoRequest(e, ipv4Loopback, utils.Ipv4Addr.Address, ttl) } loopbackSourcedICMPv6 := func(e *channel.Endpoint) { - utils.RxICMPv6EchoRequest(e, header.IPv6Loopback, utils.Ipv6Addr.Address) + utils.RxICMPv6EchoRequest(e, header.IPv6Loopback, utils.Ipv6Addr.Address, ttl) } loopbackDestinedICMPv4 := func(e *channel.Endpoint) { - utils.RxICMPv4EchoRequest(e, utils.RemoteIPv4Addr, ipv4Loopback) + utils.RxICMPv4EchoRequest(e, utils.RemoteIPv4Addr, ipv4Loopback, ttl) } loopbackDestinedICMPv6 := func(e *channel.Endpoint) { - utils.RxICMPv6EchoRequest(e, utils.RemoteIPv6Addr, header.IPv6Loopback) + utils.RxICMPv6EchoRequest(e, utils.RemoteIPv6Addr, header.IPv6Loopback, ttl) } invalidSrcAddrStat := func(s tcpip.IPStats) *tcpip.StatCounter { diff --git a/pkg/tcpip/tests/integration/multicast_broadcast_test.go b/pkg/tcpip/tests/integration/multicast_broadcast_test.go index c6a9c2393..2d0a6e6a7 100644 --- a/pkg/tcpip/tests/integration/multicast_broadcast_test.go +++ b/pkg/tcpip/tests/integration/multicast_broadcast_test.go @@ -29,6 +29,7 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" "gvisor.dev/gvisor/pkg/tcpip/stack" "gvisor.dev/gvisor/pkg/tcpip/tests/utils" + "gvisor.dev/gvisor/pkg/tcpip/testutil" "gvisor.dev/gvisor/pkg/tcpip/transport/icmp" "gvisor.dev/gvisor/pkg/tcpip/transport/udp" "gvisor.dev/gvisor/pkg/waiter" @@ -43,12 +44,15 @@ const ( // to a multicast or broadcast address uses a unicast source address for the // reply. func TestPingMulticastBroadcast(t *testing.T) { - const nicID = 1 + const ( + nicID = 1 + ttl = 64 + ) tests := []struct { name string protoNum tcpip.NetworkProtocolNumber - rxICMP func(*channel.Endpoint, tcpip.Address, tcpip.Address) + rxICMP func(*channel.Endpoint, tcpip.Address, tcpip.Address, uint8) srcAddr tcpip.Address dstAddr tcpip.Address expectedSrc tcpip.Address @@ -136,7 +140,7 @@ func TestPingMulticastBroadcast(t *testing.T) { }, }) - test.rxICMP(e, test.srcAddr, test.dstAddr) + test.rxICMP(e, test.srcAddr, test.dstAddr, ttl) pkt, ok := e.Read() if !ok { t.Fatal("expected ICMP response") @@ -435,10 +439,10 @@ func TestIncomingMulticastAndBroadcast(t *testing.T) { // interested endpoints. 
func TestReuseAddrAndBroadcast(t *testing.T) { const ( - nicID = 1 - localPort = 9000 - loopbackBroadcast = tcpip.Address("\x7f\xff\xff\xff") + nicID = 1 + localPort = 9000 ) + loopbackBroadcast := testutil.MustParse4("127.255.255.255") tests := []struct { name string diff --git a/pkg/tcpip/tests/integration/route_test.go b/pkg/tcpip/tests/integration/route_test.go index 78244f4eb..ac3c703d4 100644 --- a/pkg/tcpip/tests/integration/route_test.go +++ b/pkg/tcpip/tests/integration/route_test.go @@ -30,6 +30,7 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" "gvisor.dev/gvisor/pkg/tcpip/stack" "gvisor.dev/gvisor/pkg/tcpip/tests/utils" + "gvisor.dev/gvisor/pkg/tcpip/testutil" "gvisor.dev/gvisor/pkg/tcpip/transport/icmp" "gvisor.dev/gvisor/pkg/tcpip/transport/udp" "gvisor.dev/gvisor/pkg/waiter" @@ -40,13 +41,13 @@ import ( // This tests that a local route is created and packets do not leave the stack. func TestLocalPing(t *testing.T) { const ( - nicID = 1 - ipv4Loopback = tcpip.Address("\x7f\x00\x00\x01") + nicID = 1 // icmpDataOffset is the offset to the data in both ICMPv4 and ICMPv6 echo // request/reply packets. icmpDataOffset = 8 ) + ipv4Loopback := testutil.MustParse4("127.0.0.1") channelEP := func() stack.LinkEndpoint { return channel.New(1, header.IPv6MinimumMTU, "") } channelEPCheck := func(t *testing.T, e stack.LinkEndpoint) { diff --git a/pkg/tcpip/tests/utils/utils.go b/pkg/tcpip/tests/utils/utils.go index d1c9f3a94..8fd9be32b 100644 --- a/pkg/tcpip/tests/utils/utils.go +++ b/pkg/tcpip/tests/utils/utils.go @@ -48,10 +48,6 @@ const ( LinkAddr4 = tcpip.LinkAddress("\x02\x03\x03\x04\x05\x09") ) -const ( - ttl = 255 -) - // Common IP addresses used by tests. var ( Ipv4Addr = tcpip.AddressWithPrefix{ @@ -322,7 +318,7 @@ func SetupRoutedStacks(t *testing.T, host1Stack, routerStack, host2Stack *stack. // RxICMPv4EchoRequest constructs and injects an ICMPv4 echo request packet on // the provided endpoint. -func RxICMPv4EchoRequest(e *channel.Endpoint, src, dst tcpip.Address) { +func RxICMPv4EchoRequest(e *channel.Endpoint, src, dst tcpip.Address, ttl uint8) { totalLen := header.IPv4MinimumSize + header.ICMPv4MinimumSize hdr := buffer.NewPrependable(totalLen) pkt := header.ICMPv4(hdr.Prepend(header.ICMPv4MinimumSize)) @@ -347,7 +343,7 @@ func RxICMPv4EchoRequest(e *channel.Endpoint, src, dst tcpip.Address) { // RxICMPv6EchoRequest constructs and injects an ICMPv6 echo request packet on // the provided endpoint. -func RxICMPv6EchoRequest(e *channel.Endpoint, src, dst tcpip.Address) { +func RxICMPv6EchoRequest(e *channel.Endpoint, src, dst tcpip.Address, ttl uint8) { totalLen := header.IPv6MinimumSize + header.ICMPv6MinimumSize hdr := buffer.NewPrependable(totalLen) pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6MinimumSize)) diff --git a/pkg/tcpip/testutil/BUILD b/pkg/tcpip/testutil/BUILD new file mode 100644 index 000000000..472545a5d --- /dev/null +++ b/pkg/tcpip/testutil/BUILD @@ -0,0 +1,18 @@ +load("//tools:defs.bzl", "go_library", "go_test") + +package(licenses = ["notice"]) + +go_library( + name = "testutil", + testonly = True, + srcs = ["testutil.go"], + visibility = ["//visibility:public"], + deps = ["//pkg/tcpip"], +) + +go_test( + name = "testutil_test", + srcs = ["testutil_test.go"], + library = ":testutil", + deps = ["//pkg/tcpip"], +) diff --git a/pkg/tcpip/testutil/testutil.go b/pkg/tcpip/testutil/testutil.go new file mode 100644 index 000000000..1aaed590f --- /dev/null +++ b/pkg/tcpip/testutil/testutil.go @@ -0,0 +1,43 @@ +// Copyright 2021 The gVisor Authors. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package testutil provides helper functions for netstack unit tests. +package testutil + +import ( + "fmt" + "net" + + "gvisor.dev/gvisor/pkg/tcpip" +) + +// MustParse4 parses an IPv4 string (e.g. "192.168.1.1") into a tcpip.Address. +// Passing an IPv4-mapped IPv6 address will yield only the 4 IPv4 bytes. +func MustParse4(addr string) tcpip.Address { + ip := net.ParseIP(addr).To4() + if ip == nil { + panic(fmt.Sprintf("Parse4 expects IPv4 addresses, but was passed %q", addr)) + } + return tcpip.Address(ip) +} + +// MustParse6 parses an IPv6 string (e.g. "fe80::1") into a tcpip.Address. Passing +// an IPv4 address will yield an IPv4-mapped IPv6 address. +func MustParse6(addr string) tcpip.Address { + ip := net.ParseIP(addr).To16() + if ip == nil { + panic(fmt.Sprintf("Parse6 was passed malformed address %q", addr)) + } + return tcpip.Address(ip) +} diff --git a/pkg/tcpip/testutil/testutil_test.go b/pkg/tcpip/testutil/testutil_test.go new file mode 100644 index 000000000..6aad9585d --- /dev/null +++ b/pkg/tcpip/testutil/testutil_test.go @@ -0,0 +1,103 @@ +// Copyright 2021 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package testutil + +import ( + "testing" + + "gvisor.dev/gvisor/pkg/tcpip" +) + +// Who tests the testutils? + +func TestMustParse4(t *testing.T) { + tcs := []struct { + str string + addr tcpip.Address + shouldPanic bool + }{ + { + str: "127.0.0.1", + addr: "\x7f\x00\x00\x01", + }, { + str: "", + shouldPanic: true, + }, { + str: "fe80::1", + shouldPanic: true, + }, { + // In an ideal world this panics too, but net.IP + // doesn't distinguish between IPv4 and IPv4-mapped + // addresses. + str: "::ffff:0.0.0.1", + addr: "\x00\x00\x00\x01", + }, + } + + for _, tc := range tcs { + t.Run(tc.str, func(t *testing.T) { + if tc.shouldPanic { + defer func() { + if r := recover(); r == nil { + t.Errorf("panic expected, but did not occur") + } + }() + } + if got := MustParse4(tc.str); got != tc.addr { + t.Errorf("got MustParse4(%s) = %s, want = %s", tc.str, got, tc.addr) + } + }) + } +} + +func TestMustParse6(t *testing.T) { + tcs := []struct { + str string + addr tcpip.Address + shouldPanic bool + }{ + { + // In an ideal world this panics too, but net.IP + // doesn't distinguish between IPv4 and IPv4-mapped + // addresses. 
+ str: "127.0.0.1", + addr: "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff\x7f\x00\x00\x01", + }, { + str: "", + shouldPanic: true, + }, { + str: "fe80::1", + addr: "\xfe\x80\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01", + }, { + str: "::ffff:0.0.0.1", + addr: "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff\x00\x00\x00\x01", + }, + } + + for _, tc := range tcs { + t.Run(tc.str, func(t *testing.T) { + if tc.shouldPanic { + defer func() { + if r := recover(); r == nil { + t.Errorf("panic expected, but did not occur") + } + }() + } + if got := MustParse6(tc.str); got != tc.addr { + t.Errorf("got MustParse6(%s) = %s, want = %s", tc.str, got, tc.addr) + } + }) + } +} diff --git a/pkg/tcpip/time_unsafe.go b/pkg/tcpip/time_unsafe.go deleted file mode 100644 index eeea97b12..000000000 --- a/pkg/tcpip/time_unsafe.go +++ /dev/null @@ -1,75 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// +build go1.9 -// +build !go1.18 - -// Check go:linkname function signatures when updating Go version. - -package tcpip - -import ( - "time" // Used with go:linkname. - _ "unsafe" // Required for go:linkname. -) - -// StdClock implements Clock with the time package. -// -// +stateify savable -type StdClock struct{} - -var _ Clock = (*StdClock)(nil) - -//go:linkname now time.now -func now() (sec int64, nsec int32, mono int64) - -// NowNanoseconds implements Clock.NowNanoseconds. -func (*StdClock) NowNanoseconds() int64 { - sec, nsec, _ := now() - return sec*1e9 + int64(nsec) -} - -// NowMonotonic implements Clock.NowMonotonic. -func (*StdClock) NowMonotonic() int64 { - _, _, mono := now() - return mono -} - -// AfterFunc implements Clock.AfterFunc. -func (*StdClock) AfterFunc(d time.Duration, f func()) Timer { - return &stdTimer{ - t: time.AfterFunc(d, f), - } -} - -type stdTimer struct { - t *time.Timer -} - -var _ Timer = (*stdTimer)(nil) - -// Stop implements Timer.Stop. -func (st *stdTimer) Stop() bool { - return st.t.Stop() -} - -// Reset implements Timer.Reset. -func (st *stdTimer) Reset(d time.Duration) { - st.t.Reset(d) -} - -// NewStdTimer returns a Timer implemented with the time package. -func NewStdTimer(t *time.Timer) Timer { - return &stdTimer{t: t} -} diff --git a/pkg/tcpip/timer_test.go b/pkg/tcpip/timer_test.go index a82384c49..1633d0aeb 100644 --- a/pkg/tcpip/timer_test.go +++ b/pkg/tcpip/timer_test.go @@ -29,7 +29,7 @@ const ( ) func TestJobReschedule(t *testing.T) { - var clock tcpip.StdClock + clock := tcpip.NewStdClock() var wg sync.WaitGroup var lock sync.Mutex @@ -43,7 +43,7 @@ func TestJobReschedule(t *testing.T) { // that has an active timer (even if it has been stopped as a stopped // timer may be blocked on a lock before it can check if it has been // stopped while another goroutine holds the same lock). 
- job := tcpip.NewJob(&clock, &lock, func() { + job := tcpip.NewJob(clock, &lock, func() { wg.Done() }) job.Schedule(shortDuration) @@ -56,11 +56,11 @@ func TestJobReschedule(t *testing.T) { func TestJobExecution(t *testing.T) { t.Parallel() - var clock tcpip.StdClock + clock := tcpip.NewStdClock() var lock sync.Mutex ch := make(chan struct{}) - job := tcpip.NewJob(&clock, &lock, func() { + job := tcpip.NewJob(clock, &lock, func() { ch <- struct{}{} }) job.Schedule(shortDuration) @@ -83,11 +83,11 @@ func TestJobExecution(t *testing.T) { func TestCancellableTimerResetFromLongDuration(t *testing.T) { t.Parallel() - var clock tcpip.StdClock + clock := tcpip.NewStdClock() var lock sync.Mutex ch := make(chan struct{}) - job := tcpip.NewJob(&clock, &lock, func() { ch <- struct{}{} }) + job := tcpip.NewJob(clock, &lock, func() { ch <- struct{}{} }) job.Schedule(middleDuration) lock.Lock() @@ -114,12 +114,12 @@ func TestCancellableTimerResetFromLongDuration(t *testing.T) { func TestJobRescheduleFromShortDuration(t *testing.T) { t.Parallel() - var clock tcpip.StdClock + clock := tcpip.NewStdClock() var lock sync.Mutex ch := make(chan struct{}) lock.Lock() - job := tcpip.NewJob(&clock, &lock, func() { ch <- struct{}{} }) + job := tcpip.NewJob(clock, &lock, func() { ch <- struct{}{} }) job.Schedule(shortDuration) job.Cancel() lock.Unlock() @@ -151,13 +151,13 @@ func TestJobRescheduleFromShortDuration(t *testing.T) { func TestJobImmediatelyCancel(t *testing.T) { t.Parallel() - var clock tcpip.StdClock + clock := tcpip.NewStdClock() var lock sync.Mutex ch := make(chan struct{}) for i := 0; i < 1000; i++ { lock.Lock() - job := tcpip.NewJob(&clock, &lock, func() { ch <- struct{}{} }) + job := tcpip.NewJob(clock, &lock, func() { ch <- struct{}{} }) job.Schedule(shortDuration) job.Cancel() lock.Unlock() @@ -174,12 +174,12 @@ func TestJobImmediatelyCancel(t *testing.T) { func TestJobCancelledRescheduleWithoutLock(t *testing.T) { t.Parallel() - var clock tcpip.StdClock + clock := tcpip.NewStdClock() var lock sync.Mutex ch := make(chan struct{}) lock.Lock() - job := tcpip.NewJob(&clock, &lock, func() { ch <- struct{}{} }) + job := tcpip.NewJob(clock, &lock, func() { ch <- struct{}{} }) job.Schedule(shortDuration) job.Cancel() lock.Unlock() @@ -206,12 +206,12 @@ func TestJobCancelledRescheduleWithoutLock(t *testing.T) { func TestManyCancellableTimerResetAfterBlockedOnLock(t *testing.T) { t.Parallel() - var clock tcpip.StdClock + clock := tcpip.NewStdClock() var lock sync.Mutex ch := make(chan struct{}) lock.Lock() - job := tcpip.NewJob(&clock, &lock, func() { ch <- struct{}{} }) + job := tcpip.NewJob(clock, &lock, func() { ch <- struct{}{} }) job.Schedule(shortDuration) for i := 0; i < 10; i++ { // Sleep until the timer fires and gets blocked trying to take the lock. 
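// Editorial aside, not part of the patch: the timer tests above now build the
// clock with tcpip.NewStdClock() instead of declaring a zero-value
// tcpip.StdClock, and pass the resulting Clock into tcpip.NewJob. A minimal
// standalone sketch of that pattern, using only the APIs visible in this diff
// (NewStdClock, NewJob, Schedule, Cancel) and an arbitrary 10ms duration:
package main

import (
	"sync"
	"time"

	"gvisor.dev/gvisor/pkg/tcpip"
)

func main() {
	clock := tcpip.NewStdClock() // replaces `var clock tcpip.StdClock`
	var mu sync.Mutex
	done := make(chan struct{})

	// NewJob locks mu around the callback, as in the tests above.
	job := tcpip.NewJob(clock, &mu, func() { close(done) })
	job.Schedule(10 * time.Millisecond)
	<-done

	// A job scheduled far in the future can simply be cancelled.
	idle := tcpip.NewJob(clock, &mu, func() {})
	idle.Schedule(time.Hour)
	idle.Cancel()
}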
@@ -239,12 +239,12 @@ func TestManyCancellableTimerResetAfterBlockedOnLock(t *testing.T) { func TestManyJobReschedulesUnderLock(t *testing.T) { t.Parallel() - var clock tcpip.StdClock + clock := tcpip.NewStdClock() var lock sync.Mutex ch := make(chan struct{}) lock.Lock() - job := tcpip.NewJob(&clock, &lock, func() { ch <- struct{}{} }) + job := tcpip.NewJob(clock, &lock, func() { ch <- struct{}{} }) job.Schedule(shortDuration) for i := 0; i < 10; i++ { job.Cancel() diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go index 50991c3c0..8afde7fca 100644 --- a/pkg/tcpip/transport/icmp/endpoint.go +++ b/pkg/tcpip/transport/icmp/endpoint.go @@ -63,12 +63,11 @@ type endpoint struct { // The following fields are used to manage the receive queue, and are // protected by rcvMu. - rcvMu sync.Mutex `state:"nosave"` - rcvReady bool - rcvList icmpPacketList - rcvBufSizeMax int `state:".(int)"` - rcvBufSize int - rcvClosed bool + rcvMu sync.Mutex `state:"nosave"` + rcvReady bool + rcvList icmpPacketList + rcvBufSize int + rcvClosed bool // The following fields are protected by the mu mutex. mu sync.RWMutex `state:"nosave"` @@ -84,6 +83,10 @@ type endpoint struct { // ops is used to get socket level options. ops tcpip.SocketOptions + + // frozen indicates if the packets should be delivered to the endpoint + // during restore. + frozen bool } func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, tcpip.Error) { @@ -93,19 +96,23 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProt NetProto: netProto, TransProto: transProto, }, - waiterQueue: waiterQueue, - rcvBufSizeMax: 32 * 1024, - state: stateInitial, - uniqueID: s.UniqueID(), + waiterQueue: waiterQueue, + state: stateInitial, + uniqueID: s.UniqueID(), } - ep.ops.InitHandler(ep, ep.stack, tcpip.GetStackSendBufferLimits) + ep.ops.InitHandler(ep, ep.stack, tcpip.GetStackSendBufferLimits, tcpip.GetStackReceiveBufferLimits) ep.ops.SetSendBufferSize(32*1024, false /* notify */) + ep.ops.SetReceiveBufferSize(32*1024, false /* notify */) // Override with stack defaults. 
var ss tcpip.SendBufferSizeOption if err := s.Option(&ss); err == nil { ep.ops.SetSendBufferSize(int64(ss.Default), false /* notify */) } + var rs tcpip.ReceiveBufferSizeOption + if err := s.Option(&rs); err == nil { + ep.ops.SetReceiveBufferSize(int64(rs.Default), false /* notify */) + } return ep, nil } @@ -371,12 +378,6 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, tcpip.Error) { e.rcvMu.Unlock() return v, nil - case tcpip.ReceiveBufferSizeOption: - e.rcvMu.Lock() - v := e.rcvBufSizeMax - e.rcvMu.Unlock() - return v, nil - case tcpip.TTLOption: e.rcvMu.Lock() v := int(e.ttl) @@ -430,7 +431,7 @@ func send4(r *stack.Route, ident uint16, data buffer.View, ttl uint8, owner tcpi ttl = r.DefaultTTL() } - if err := r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv4ProtocolNumber, TTL: ttl, TOS: stack.DefaultTOS}, pkt); err != nil { + if err := r.WritePacket(stack.NetworkHeaderParams{Protocol: header.ICMPv4ProtocolNumber, TTL: ttl, TOS: stack.DefaultTOS}, pkt); err != nil { r.Stats().ICMP.V4.PacketsSent.Dropped.Increment() return err } @@ -477,7 +478,7 @@ func send6(r *stack.Route, ident uint16, data buffer.View, ttl uint8) tcpip.Erro ttl = r.DefaultTTL() } - if err := r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: ttl, TOS: stack.DefaultTOS}, pkt); err != nil { + if err := r.WritePacket(stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: ttl, TOS: stack.DefaultTOS}, pkt); err != nil { r.Stats().ICMP.V6.PacketsSent.Dropped.Increment() } @@ -746,8 +747,8 @@ func (e *endpoint) HandlePacket(id stack.TransportEndpointID, pkt *stack.PacketB switch e.NetProto { case header.IPv4ProtocolNumber: h := header.ICMPv4(pkt.TransportHeader().View()) - // TODO(b/129292233): Determine if len(h) check is still needed after early - // parsing. + // TODO(gvisor.dev/issue/170): Determine if len(h) check is still needed + // after early parsing. if len(h) < header.ICMPv4MinimumSize || h.Type() != header.ICMPv4EchoReply { e.stack.Stats().DroppedPackets.Increment() e.stats.ReceiveErrors.MalformedPacketsReceived.Increment() @@ -755,8 +756,8 @@ func (e *endpoint) HandlePacket(id stack.TransportEndpointID, pkt *stack.PacketB } case header.IPv6ProtocolNumber: h := header.ICMPv6(pkt.TransportHeader().View()) - // TODO(b/129292233): Determine if len(h) check is still needed after early - // parsing. + // TODO(gvisor.dev/issue/170): Determine if len(h) check is still needed + // after early parsing. if len(h) < header.ICMPv6MinimumSize || h.Type() != header.ICMPv6EchoReply { e.stack.Stats().DroppedPackets.Increment() e.stats.ReceiveErrors.MalformedPacketsReceived.Increment() @@ -774,7 +775,8 @@ func (e *endpoint) HandlePacket(id stack.TransportEndpointID, pkt *stack.PacketB return } - if e.rcvBufSize >= e.rcvBufSizeMax { + rcvBufSize := e.ops.GetReceiveBufferSize() + if e.frozen || e.rcvBufSize >= int(rcvBufSize) { e.rcvMu.Unlock() e.stack.Stats().DroppedPackets.Increment() e.stats.ReceiveErrors.ReceiveBufferOverflow.Increment() @@ -843,3 +845,18 @@ func (*endpoint) LastError() tcpip.Error { func (e *endpoint) SocketOptions() *tcpip.SocketOptions { return &e.ops } + +// freeze prevents any more packets from being delivered to the endpoint. +func (e *endpoint) freeze() { + e.mu.Lock() + e.frozen = true + e.mu.Unlock() +} + +// thaw unfreezes a previously frozen endpoint using endpoint.freeze() allows +// new packets to be delivered again. 
+func (e *endpoint) thaw() { + e.mu.Lock() + e.frozen = false + e.mu.Unlock() +} diff --git a/pkg/tcpip/transport/icmp/endpoint_state.go b/pkg/tcpip/transport/icmp/endpoint_state.go index a3c6db5a8..28a56a2d5 100644 --- a/pkg/tcpip/transport/icmp/endpoint_state.go +++ b/pkg/tcpip/transport/icmp/endpoint_state.go @@ -36,40 +36,21 @@ func (p *icmpPacket) loadData(data buffer.VectorisedView) { p.data = data } -// beforeSave is invoked by stateify. -func (e *endpoint) beforeSave() { - // Stop incoming packets from being handled (and mutate endpoint state). - // The lock will be released after savercvBufSizeMax(), which would have - // saved e.rcvBufSizeMax and set it to 0 to continue blocking incoming - // packets. - e.rcvMu.Lock() -} - -// saveRcvBufSizeMax is invoked by stateify. -func (e *endpoint) saveRcvBufSizeMax() int { - max := e.rcvBufSizeMax - // Make sure no new packets will be handled regardless of the lock. - e.rcvBufSizeMax = 0 - // Release the lock acquired in beforeSave() so regular endpoint closing - // logic can proceed after save. - e.rcvMu.Unlock() - return max -} - -// loadRcvBufSizeMax is invoked by stateify. -func (e *endpoint) loadRcvBufSizeMax(max int) { - e.rcvBufSizeMax = max -} - // afterLoad is invoked by stateify. func (e *endpoint) afterLoad() { stack.StackFromEnv.RegisterRestoredEndpoint(e) } +// beforeSave is invoked by stateify. +func (e *endpoint) beforeSave() { + e.freeze() +} + // Resume implements tcpip.ResumableEndpoint.Resume. func (e *endpoint) Resume(s *stack.Stack) { + e.thaw() e.stack = s - e.ops.InitHandler(e, e.stack, tcpip.GetStackSendBufferLimits) + e.ops.InitHandler(e, e.stack, tcpip.GetStackSendBufferLimits, tcpip.GetStackReceiveBufferLimits) if e.state != stateBound && e.state != stateConnected { return diff --git a/pkg/tcpip/transport/packet/endpoint.go b/pkg/tcpip/transport/packet/endpoint.go index 52ed9560c..496eca581 100644 --- a/pkg/tcpip/transport/packet/endpoint.go +++ b/pkg/tcpip/transport/packet/endpoint.go @@ -72,11 +72,10 @@ type endpoint struct { // The following fields are used to manage the receive queue and are // protected by rcvMu. - rcvMu sync.Mutex `state:"nosave"` - rcvList packetList - rcvBufSizeMax int `state:".(int)"` - rcvBufSize int - rcvClosed bool + rcvMu sync.Mutex `state:"nosave"` + rcvList packetList + rcvBufSize int + rcvClosed bool // The following fields are protected by mu. mu sync.RWMutex `state:"nosave"` @@ -91,6 +90,10 @@ type endpoint struct { // ops is used to get socket level options. ops tcpip.SocketOptions + + // frozen indicates if the packets should be delivered to the endpoint + // during restore. + frozen bool } // NewEndpoint returns a new packet endpoint. @@ -100,12 +103,12 @@ func NewEndpoint(s *stack.Stack, cooked bool, netProto tcpip.NetworkProtocolNumb TransportEndpointInfo: stack.TransportEndpointInfo{ NetProto: netProto, }, - cooked: cooked, - netProto: netProto, - waiterQueue: waiterQueue, - rcvBufSizeMax: 32 * 1024, + cooked: cooked, + netProto: netProto, + waiterQueue: waiterQueue, } - ep.ops.InitHandler(ep, ep.stack, tcpip.GetStackSendBufferLimits) + ep.ops.InitHandler(ep, ep.stack, tcpip.GetStackSendBufferLimits, tcpip.GetStackReceiveBufferLimits) + ep.ops.SetReceiveBufferSize(32*1024, false /* notify */) // Override with stack defaults. 
var ss tcpip.SendBufferSizeOption @@ -113,9 +116,9 @@ func NewEndpoint(s *stack.Stack, cooked bool, netProto tcpip.NetworkProtocolNumb ep.ops.SetSendBufferSize(int64(ss.Default), false /* notify */) } - var rs stack.ReceiveBufferSizeOption + var rs tcpip.ReceiveBufferSizeOption if err := s.Option(&rs); err == nil { - ep.rcvBufSizeMax = rs.Default + ep.ops.SetReceiveBufferSize(int64(rs.Default), false /* notify */) } if err := s.RegisterPacketEndpoint(0, netProto, ep); err != nil { @@ -316,28 +319,7 @@ func (ep *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) tcpip.Error { // SetSockOptInt implements tcpip.Endpoint.SetSockOptInt. func (ep *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) tcpip.Error { - switch opt { - case tcpip.ReceiveBufferSizeOption: - // Make sure the receive buffer size is within the min and max - // allowed. - var rs stack.ReceiveBufferSizeOption - if err := ep.stack.Option(&rs); err != nil { - panic(fmt.Sprintf("s.Option(%#v) = %s", rs, err)) - } - if v > rs.Max { - v = rs.Max - } - if v < rs.Min { - v = rs.Min - } - ep.rcvMu.Lock() - ep.rcvBufSizeMax = v - ep.rcvMu.Unlock() - return nil - - default: - return &tcpip.ErrUnknownProtocolOption{} - } + return &tcpip.ErrUnknownProtocolOption{} } func (ep *endpoint) LastError() tcpip.Error { @@ -374,12 +356,6 @@ func (ep *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, tcpip.Error) { ep.rcvMu.Unlock() return v, nil - case tcpip.ReceiveBufferSizeOption: - ep.rcvMu.Lock() - v := ep.rcvBufSizeMax - ep.rcvMu.Unlock() - return v, nil - default: return -1, &tcpip.ErrUnknownProtocolOption{} } @@ -397,7 +373,8 @@ func (ep *endpoint) HandlePacket(nicID tcpip.NICID, localAddr tcpip.LinkAddress, return } - if ep.rcvBufSize >= ep.rcvBufSizeMax { + rcvBufSize := ep.ops.GetReceiveBufferSize() + if ep.frozen || ep.rcvBufSize >= int(rcvBufSize) { ep.rcvMu.Unlock() ep.stack.Stats().DroppedPackets.Increment() ep.stats.ReceiveErrors.ReceiveBufferOverflow.Increment() @@ -513,3 +490,18 @@ func (ep *endpoint) SetOwner(owner tcpip.PacketOwner) {} func (ep *endpoint) SocketOptions() *tcpip.SocketOptions { return &ep.ops } + +// freeze prevents any more packets from being delivered to the endpoint. +func (ep *endpoint) freeze() { + ep.mu.Lock() + ep.frozen = true + ep.mu.Unlock() +} + +// thaw unfreezes a previously frozen endpoint using endpoint.freeze() allows +// new packets to be delivered again. +func (ep *endpoint) thaw() { + ep.mu.Lock() + ep.frozen = false + ep.mu.Unlock() +} diff --git a/pkg/tcpip/transport/packet/endpoint_state.go b/pkg/tcpip/transport/packet/endpoint_state.go index ece662c0d..5bd860d20 100644 --- a/pkg/tcpip/transport/packet/endpoint_state.go +++ b/pkg/tcpip/transport/packet/endpoint_state.go @@ -38,33 +38,14 @@ func (p *packet) loadData(data buffer.VectorisedView) { // beforeSave is invoked by stateify. func (ep *endpoint) beforeSave() { - // Stop incoming packets from being handled (and mutate endpoint state). - // The lock will be released after saveRcvBufSizeMax(), which would have - // saved ep.rcvBufSizeMax and set it to 0 to continue blocking incoming - // packets. - ep.rcvMu.Lock() -} - -// saveRcvBufSizeMax is invoked by stateify. -func (ep *endpoint) saveRcvBufSizeMax() int { - max := ep.rcvBufSizeMax - // Make sure no new packets will be handled regardless of the lock. - ep.rcvBufSizeMax = 0 - // Release the lock acquired in beforeSave() so regular endpoint closing - // logic can proceed after save. - ep.rcvMu.Unlock() - return max -} - -// loadRcvBufSizeMax is invoked by stateify. 
-func (ep *endpoint) loadRcvBufSizeMax(max int) { - ep.rcvBufSizeMax = max + ep.freeze() } // afterLoad is invoked by stateify. func (ep *endpoint) afterLoad() { + ep.thaw() ep.stack = stack.StackFromEnv - ep.ops.InitHandler(ep, ep.stack, tcpip.GetStackSendBufferLimits) + ep.ops.InitHandler(ep, ep.stack, tcpip.GetStackSendBufferLimits, tcpip.GetStackReceiveBufferLimits) // TODO(gvisor.dev/173): Once bind is supported, choose the right NIC. if err := ep.stack.RegisterPacketEndpoint(0, ep.netProto, ep); err != nil { diff --git a/pkg/tcpip/transport/raw/endpoint.go b/pkg/tcpip/transport/raw/endpoint.go index e27a249cd..bcec3d2e7 100644 --- a/pkg/tcpip/transport/raw/endpoint.go +++ b/pkg/tcpip/transport/raw/endpoint.go @@ -26,7 +26,6 @@ package raw import ( - "fmt" "io" "gvisor.dev/gvisor/pkg/sync" @@ -69,11 +68,10 @@ type endpoint struct { // The following fields are used to manage the receive queue and are // protected by rcvMu. - rcvMu sync.Mutex `state:"nosave"` - rcvList rawPacketList - rcvBufSize int - rcvBufSizeMax int `state:".(int)"` - rcvClosed bool + rcvMu sync.Mutex `state:"nosave"` + rcvList rawPacketList + rcvBufSize int + rcvClosed bool // The following fields are protected by mu. mu sync.RWMutex `state:"nosave"` @@ -89,6 +87,10 @@ type endpoint struct { // ops is used to get socket level options. ops tcpip.SocketOptions + + // frozen indicates if the packets should be delivered to the endpoint + // during restore. + frozen bool } // NewEndpoint returns a raw endpoint for the given protocols. @@ -107,13 +109,13 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProt NetProto: netProto, TransProto: transProto, }, - waiterQueue: waiterQueue, - rcvBufSizeMax: 32 * 1024, - associated: associated, + waiterQueue: waiterQueue, + associated: associated, } - e.ops.InitHandler(e, e.stack, tcpip.GetStackSendBufferLimits) + e.ops.InitHandler(e, e.stack, tcpip.GetStackSendBufferLimits, tcpip.GetStackReceiveBufferLimits) e.ops.SetHeaderIncluded(!associated) e.ops.SetSendBufferSize(32*1024, false /* notify */) + e.ops.SetReceiveBufferSize(32*1024, false /* notify */) // Override with stack defaults. var ss tcpip.SendBufferSizeOption @@ -121,16 +123,16 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProt e.ops.SetSendBufferSize(int64(ss.Default), false /* notify */) } - var rs stack.ReceiveBufferSizeOption + var rs tcpip.ReceiveBufferSizeOption if err := s.Option(&rs); err == nil { - e.rcvBufSizeMax = rs.Default + e.ops.SetReceiveBufferSize(int64(rs.Default), false /* notify */) } // Unassociated endpoints are write-only and users call Write() with IP // headers included. Because they're write-only, We don't need to // register with the stack. if !associated { - e.rcvBufSizeMax = 0 + e.ops.SetReceiveBufferSize(0, false) e.waiterQueue = nil return e, nil } @@ -352,7 +354,7 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcp Data: buffer.View(payloadBytes).ToVectorisedView(), }) pkt.Owner = owner - if err := route.WritePacket(nil /* gso */, stack.NetworkHeaderParams{ + if err := route.WritePacket(stack.NetworkHeaderParams{ Protocol: e.TransProto, TTL: route.DefaultTTL(), TOS: stack.DefaultTOS, @@ -511,30 +513,8 @@ func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) tcpip.Error { } } -// SetSockOptInt implements tcpip.Endpoint.SetSockOptInt. 
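// Editorial aside, not part of the patch: across the icmp, packet and raw
// endpoints in this change, the per-endpoint rcvBufSizeMax field and the
// ReceiveBufferSizeOption cases are removed; the limit now lives in
// tcpip.SocketOptions (SetReceiveBufferSize / GetReceiveBufferSize), and the
// HandlePacket paths additionally drop traffic while the endpoint is frozen
// for save/restore. The following is a self-contained toy model of that
// admission check, deliberately not the gvisor types:
package main

import (
	"fmt"
	"sync"
)

// rcvQueue models a byte-counted receive queue whose capacity is read from a
// settable option rather than a dedicated rcvBufSizeMax field.
type rcvQueue struct {
	mu       sync.Mutex
	frozen   bool  // set by freeze() before save, cleared by thaw() on restore
	queued   int   // bytes currently queued (rcvBufSize in the diff)
	capacity int64 // what ops.GetReceiveBufferSize() returns in the diff
}

func (q *rcvQueue) freeze() {
	q.mu.Lock()
	q.frozen = true
	q.mu.Unlock()
}

func (q *rcvQueue) thaw() {
	q.mu.Lock()
	q.frozen = false
	q.mu.Unlock()
}

// deliver mirrors `if e.frozen || e.rcvBufSize >= int(rcvBufSize) { drop }`.
func (q *rcvQueue) deliver(size int) bool {
	q.mu.Lock()
	defer q.mu.Unlock()
	if q.frozen || q.queued >= int(q.capacity) {
		return false // dropped; the endpoints count this as ReceiveBufferOverflow
	}
	q.queued += size
	return true
}

func main() {
	q := &rcvQueue{capacity: 32 * 1024} // default installed by the endpoints above
	fmt.Println(q.deliver(1500))        // true: accepted
	q.freeze()
	fmt.Println(q.deliver(1500)) // false: dropped while frozen
	q.thaw()
	fmt.Println(q.deliver(1500)) // true again after restore
}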
func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) tcpip.Error { - switch opt { - case tcpip.ReceiveBufferSizeOption: - // Make sure the receive buffer size is within the min and max - // allowed. - var rs stack.ReceiveBufferSizeOption - if err := e.stack.Option(&rs); err != nil { - panic(fmt.Sprintf("s.Option(%#v) = %s", rs, err)) - } - if v > rs.Max { - v = rs.Max - } - if v < rs.Min { - v = rs.Min - } - e.rcvMu.Lock() - e.rcvBufSizeMax = v - e.rcvMu.Unlock() - return nil - - default: - return &tcpip.ErrUnknownProtocolOption{} - } + return &tcpip.ErrUnknownProtocolOption{} } // GetSockOpt implements tcpip.Endpoint.GetSockOpt. @@ -555,12 +535,6 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, tcpip.Error) { e.rcvMu.Unlock() return v, nil - case tcpip.ReceiveBufferSizeOption: - e.rcvMu.Lock() - v := e.rcvBufSizeMax - e.rcvMu.Unlock() - return v, nil - default: return -1, &tcpip.ErrUnknownProtocolOption{} } @@ -587,7 +561,8 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) { return } - if e.rcvBufSize >= e.rcvBufSizeMax { + rcvBufSize := e.ops.GetReceiveBufferSize() + if e.frozen || e.rcvBufSize >= int(rcvBufSize) { e.rcvMu.Unlock() e.mu.RUnlock() e.stack.Stats().DroppedPackets.Increment() @@ -690,3 +665,18 @@ func (*endpoint) LastError() tcpip.Error { func (e *endpoint) SocketOptions() *tcpip.SocketOptions { return &e.ops } + +// freeze prevents any more packets from being delivered to the endpoint. +func (e *endpoint) freeze() { + e.mu.Lock() + e.frozen = true + e.mu.Unlock() +} + +// thaw unfreezes a previously frozen endpoint using endpoint.freeze() allows +// new packets to be delivered again. +func (e *endpoint) thaw() { + e.mu.Lock() + e.frozen = false + e.mu.Unlock() +} diff --git a/pkg/tcpip/transport/raw/endpoint_state.go b/pkg/tcpip/transport/raw/endpoint_state.go index 263ec5146..5d6f2709c 100644 --- a/pkg/tcpip/transport/raw/endpoint_state.go +++ b/pkg/tcpip/transport/raw/endpoint_state.go @@ -36,40 +36,21 @@ func (p *rawPacket) loadData(data buffer.VectorisedView) { p.data = data } -// beforeSave is invoked by stateify. -func (e *endpoint) beforeSave() { - // Stop incoming packets from being handled (and mutate endpoint state). - // The lock will be released after saveRcvBufSizeMax(), which would have - // saved e.rcvBufSizeMax and set it to 0 to continue blocking incoming - // packets. - e.rcvMu.Lock() -} - -// saveRcvBufSizeMax is invoked by stateify. -func (e *endpoint) saveRcvBufSizeMax() int { - max := e.rcvBufSizeMax - // Make sure no new packets will be handled regardless of the lock. - e.rcvBufSizeMax = 0 - // Release the lock acquired in beforeSave() so regular endpoint closing - // logic can proceed after save. - e.rcvMu.Unlock() - return max -} - -// loadRcvBufSizeMax is invoked by stateify. -func (e *endpoint) loadRcvBufSizeMax(max int) { - e.rcvBufSizeMax = max -} - // afterLoad is invoked by stateify. func (e *endpoint) afterLoad() { stack.StackFromEnv.RegisterRestoredEndpoint(e) } +// beforeSave is invoked by stateify. +func (e *endpoint) beforeSave() { + e.freeze() +} + // Resume implements tcpip.ResumableEndpoint.Resume. func (e *endpoint) Resume(s *stack.Stack) { + e.thaw() e.stack = s - e.ops.InitHandler(e, e.stack, tcpip.GetStackSendBufferLimits) + e.ops.InitHandler(e, e.stack, tcpip.GetStackSendBufferLimits, tcpip.GetStackReceiveBufferLimits) // If the endpoint is connected, re-connect. 
if e.connected { diff --git a/pkg/tcpip/transport/tcp/BUILD b/pkg/tcpip/transport/tcp/BUILD index a69d6624d..48417f192 100644 --- a/pkg/tcpip/transport/tcp/BUILD +++ b/pkg/tcpip/transport/tcp/BUILD @@ -34,14 +34,12 @@ go_library( "connect.go", "connect_unsafe.go", "cubic.go", - "cubic_state.go", "dispatcher.go", "endpoint.go", "endpoint_state.go", "forwarder.go", "protocol.go", "rack.go", - "rack_state.go", "rcv.go", "rcv_state.go", "reno.go", @@ -107,6 +105,7 @@ go_test( "//pkg/tcpip/network/ipv6", "//pkg/tcpip/seqnum", "//pkg/tcpip/stack", + "//pkg/tcpip/testutil", "//pkg/tcpip/transport/tcp/testing/context", "//pkg/test/testutil", "//pkg/waiter", diff --git a/pkg/tcpip/transport/tcp/accept.go b/pkg/tcpip/transport/tcp/accept.go index 025b134e2..d4bd4e80e 100644 --- a/pkg/tcpip/transport/tcp/accept.go +++ b/pkg/tcpip/transport/tcp/accept.go @@ -23,7 +23,6 @@ import ( "sync/atomic" "time" - "gvisor.dev/gvisor/pkg/rand" "gvisor.dev/gvisor/pkg/sleep" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" @@ -51,11 +50,6 @@ const ( // timestamp and the current timestamp. If the difference is greater // than maxTSDiff, the cookie is expired. maxTSDiff = 2 - - // SynRcvdCountThreshold is the default global maximum number of - // connections that are allowed to be in SYN-RCVD state before TCP - // starts using SYN cookies to accept connections. - SynRcvdCountThreshold uint64 = 1000 ) var ( @@ -80,9 +74,6 @@ func encodeMSS(mss uint16) uint32 { type listenContext struct { stack *stack.Stack - // synRcvdCount is a reference to the stack level synRcvdCount. - synRcvdCount *synRcvdCounter - // rcvWnd is the receive window that is sent by this listening context // in the initial SYN-ACK. rcvWnd seqnum.Size @@ -138,14 +129,12 @@ func newListenContext(stk *stack.Stack, listenEP *endpoint, rcvWnd seqnum.Size, listenEP: listenEP, pendingEndpoints: make(map[stack.TransportEndpointID]*endpoint), } - p, ok := stk.TransportProtocolInstance(ProtocolNumber).(*protocol) - if !ok { - panic(fmt.Sprintf("unable to get TCP protocol instance from stack: %+v", stk)) - } - l.synRcvdCount = p.SynRcvdCounter() - rand.Read(l.nonce[0][:]) - rand.Read(l.nonce[1][:]) + for i := range l.nonce { + if _, err := io.ReadFull(stk.SecureRNG(), l.nonce[i][:]); err != nil { + panic(err) + } + } return l } @@ -163,14 +152,17 @@ func (l *listenContext) cookieHash(id stack.TransportEndpointID, ts uint32, nonc // Feed everything to the hasher. l.hasherMu.Lock() l.hasher.Reset() + + // Per hash.Hash.Writer: + // + // It never returns an error. l.hasher.Write(payload[:]) l.hasher.Write(l.nonce[nonceIndex][:]) - io.WriteString(l.hasher, string(id.LocalAddress)) - io.WriteString(l.hasher, string(id.RemoteAddress)) + l.hasher.Write([]byte(id.LocalAddress)) + l.hasher.Write([]byte(id.RemoteAddress)) // Finalize the calculation of the hash and return the first 4 bytes. 
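// Editorial aside, not part of the patch: the listenContext changes in this
// file fill the SYN-cookie nonces from the stack's secure RNG, and cookieHash
// now feeds the addresses into the SHA-1 hasher with Write([]byte(...))
// instead of io.WriteString, finalizing with Sum(nil) just below. A simplified
// standalone sketch of the same idea follows: hash the connection identifiers
// plus a secret nonce and keep the first 4 bytes of the digest. The real
// cookieHash also mixes in a payload derived from the ports and a timestamp.
package main

import (
	"crypto/rand"
	"crypto/sha1"
	"encoding/binary"
	"fmt"
)

// cookieHash32 hashes the connection identifiers together with a secret nonce
// and returns the first 4 bytes of the SHA-1 digest as a 32-bit value.
func cookieHash32(localAddr, remoteAddr, nonce []byte) uint32 {
	h := sha1.New()
	// Per hash.Hash, Write never returns an error.
	h.Write(localAddr)
	h.Write(remoteAddr)
	h.Write(nonce)
	return binary.BigEndian.Uint32(h.Sum(nil)[:4])
}

func main() {
	nonce := make([]byte, 32)
	// crypto/rand stands in for the stack's SecureRNG reader used above.
	if _, err := rand.Read(nonce); err != nil {
		panic(err)
	}
	local := []byte{192, 168, 0, 1}
	remote := []byte{10, 0, 0, 2}
	fmt.Printf("cookie hash: %#08x\n", cookieHash32(local, remote, nonce))
}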
- h := make([]byte, 0, sha1.Size) - h = l.hasher.Sum(h) + h := l.hasher.Sum(nil) l.hasherMu.Unlock() return binary.BigEndian.Uint32(h[:]) @@ -199,9 +191,17 @@ func (l *listenContext) isCookieValid(id stack.TransportEndpointID, cookie seqnu return (v - l.cookieHash(id, cookieTS, 1)) & hashMask, true } +func (l *listenContext) useSynCookies() bool { + var alwaysUseSynCookies tcpip.TCPAlwaysUseSynCookies + if err := l.stack.TransportProtocolOption(header.TCPProtocolNumber, &alwaysUseSynCookies); err != nil { + panic(fmt.Sprintf("TransportProtocolOption(%d, %T) = %s", header.TCPProtocolNumber, alwaysUseSynCookies, err)) + } + return bool(alwaysUseSynCookies) || (l.listenEP != nil && l.listenEP.synRcvdBacklogFull()) +} + // createConnectingEndpoint creates a new endpoint in a connecting state, with // the connection parameters given by the arguments. -func (l *listenContext) createConnectingEndpoint(s *segment, iss seqnum.Value, irs seqnum.Value, rcvdSynOpts *header.TCPSynOptions, queue *waiter.Queue) (*endpoint, tcpip.Error) { +func (l *listenContext) createConnectingEndpoint(s *segment, rcvdSynOpts *header.TCPSynOptions, queue *waiter.Queue) (*endpoint, tcpip.Error) { // Create a new endpoint. netProto := l.netProto if netProto == 0 { @@ -215,11 +215,11 @@ func (l *listenContext) createConnectingEndpoint(s *segment, iss seqnum.Value, i n := newEndpoint(l.stack, netProto, queue) n.ops.SetV6Only(l.v6Only) - n.ID = s.id + n.TransportEndpointInfo.ID = s.id n.boundNICID = s.nicID n.route = route n.effectiveNetProtos = []tcpip.NetworkProtocolNumber{s.netProto} - n.rcvBufSize = int(l.rcvWnd) + n.ops.SetReceiveBufferSize(int64(l.rcvWnd), false /* notify */) n.amss = calculateAdvertisedMSS(n.userMSS, n.route) n.setEndpointState(StateConnecting) @@ -231,7 +231,7 @@ func (l *listenContext) createConnectingEndpoint(s *segment, iss seqnum.Value, i // Bootstrap the auto tuning algorithm. Starting at zero will result in // a large step function on the first window adjustment causing the // window to grow to a really large value. - n.rcvAutoParams.prevCopied = n.initialReceiveWindow() + n.rcvQueueInfo.RcvAutoParams.PrevCopiedBytes = n.initialReceiveWindow() return n, nil } @@ -248,7 +248,7 @@ func (l *listenContext) startHandshake(s *segment, opts *header.TCPSynOptions, q // Create new endpoint. irs := s.sequenceNumber isn := generateSecureISN(s.id, l.stack.Seed()) - ep, err := l.createConnectingEndpoint(s, isn, irs, opts, queue) + ep, err := l.createConnectingEndpoint(s, opts, queue) if err != nil { return nil, err } @@ -290,7 +290,14 @@ func (l *listenContext) startHandshake(s *segment, opts *header.TCPSynOptions, q } // Register new endpoint so that packets are routed to it. - if err := ep.stack.RegisterTransportEndpoint(ep.effectiveNetProtos, ProtocolNumber, ep.ID, ep, ep.boundPortFlags, ep.boundBindToDevice); err != nil { + if err := ep.stack.RegisterTransportEndpoint( + ep.effectiveNetProtos, + ProtocolNumber, + ep.TransportEndpointInfo.ID, + ep, + ep.boundPortFlags, + ep.boundBindToDevice, + ); err != nil { ep.mu.Unlock() ep.Close() @@ -307,6 +314,7 @@ func (l *listenContext) startHandshake(s *segment, opts *header.TCPSynOptions, q // Initialize and start the handshake. 
h := ep.newPassiveHandshake(isn, irs, opts, deferAccept) + h.listenEP = l.listenEP h.start() return h, nil } @@ -334,14 +342,14 @@ func (l *listenContext) performHandshake(s *segment, opts *header.TCPSynOptions, func (l *listenContext) addPendingEndpoint(n *endpoint) { l.pendingMu.Lock() - l.pendingEndpoints[n.ID] = n + l.pendingEndpoints[n.TransportEndpointInfo.ID] = n l.pending.Add(1) l.pendingMu.Unlock() } func (l *listenContext) removePendingEndpoint(n *endpoint) { l.pendingMu.Lock() - delete(l.pendingEndpoints, n.ID) + delete(l.pendingEndpoints, n.TransportEndpointInfo.ID) l.pending.Done() l.pendingMu.Unlock() } @@ -382,39 +390,46 @@ func (l *listenContext) cleanupCompletedHandshake(h *handshake) { // Update the receive window scaling. We can't do it before the // handshake because it's possible that the peer doesn't support window // scaling. - e.rcv.rcvWndScale = e.h.effectiveRcvWndScale() + e.rcv.RcvWndScale = e.h.effectiveRcvWndScale() // Clean up handshake state stored in the endpoint so that it can be GCed. e.h = nil } // deliverAccepted delivers the newly-accepted endpoint to the listener. If the -// endpoint has transitioned out of the listen state (acceptedChan is nil), -// the new endpoint is closed instead. +// listener has transitioned out of the listen state (accepted is the zero +// value), the new endpoint is reset instead. func (e *endpoint) deliverAccepted(n *endpoint, withSynCookie bool) { e.mu.Lock() e.pendingAccepted.Add(1) e.mu.Unlock() defer e.pendingAccepted.Done() - e.acceptMu.Lock() - for { - if e.acceptedChan == nil { - e.acceptMu.Unlock() - n.notifyProtocolGoroutine(notifyReset) - return - } - select { - case e.acceptedChan <- n: + // Drop the lock before notifying to avoid deadlock in user-specified + // callbacks. + delivered := func() bool { + e.acceptMu.Lock() + defer e.acceptMu.Unlock() + for { + if e.accepted == (accepted{}) { + return false + } + if e.accepted.endpoints.Len() == e.accepted.cap { + e.acceptCond.Wait() + continue + } + + e.accepted.endpoints.PushBack(n) if !withSynCookie { atomic.AddInt32(&e.synRcvdCount, -1) } - e.acceptMu.Unlock() - e.waiterQueue.Notify(waiter.ReadableEvents) - return - default: - e.acceptCond.Wait() + return true } + }() + if delivered { + e.waiterQueue.Notify(waiter.ReadableEvents) + } else { + n.notifyProtocolGoroutine(notifyReset) } } @@ -436,17 +451,21 @@ func (e *endpoint) propagateInheritableOptionsLocked(n *endpoint) { // * propagateInheritableOptionsLocked has been called. // * e.mu is held. 
func (e *endpoint) reserveTupleLocked() bool { - dest := tcpip.FullAddress{Addr: e.ID.RemoteAddress, Port: e.ID.RemotePort} + dest := tcpip.FullAddress{ + Addr: e.TransportEndpointInfo.ID.RemoteAddress, + Port: e.TransportEndpointInfo.ID.RemotePort, + } portRes := ports.Reservation{ Networks: e.effectiveNetProtos, Transport: ProtocolNumber, - Addr: e.ID.LocalAddress, - Port: e.ID.LocalPort, + Addr: e.TransportEndpointInfo.ID.LocalAddress, + Port: e.TransportEndpointInfo.ID.LocalPort, Flags: e.boundPortFlags, BindToDevice: e.boundBindToDevice, Dest: dest, } if !e.stack.ReserveTuple(portRes) { + e.stack.Stats().TCP.FailedPortReservations.Increment() return false } @@ -485,7 +504,6 @@ func (e *endpoint) handleSynSegment(ctx *listenContext, s *segment, opts *header } go func() { - defer ctx.synRcvdCount.dec() if err := h.complete(); err != nil { e.stack.Stats().TCP.FailedConnectionAttempts.Increment() e.stats.FailedConnectionAttempts.Increment() @@ -497,24 +515,29 @@ func (e *endpoint) handleSynSegment(ctx *listenContext, s *segment, opts *header h.ep.startAcceptedLoop() e.stack.Stats().TCP.PassiveConnectionOpenings.Increment() e.deliverAccepted(h.ep, false /*withSynCookie*/) - }() // S/R-SAFE: synRcvdCount is the barrier. + }() return nil } -func (e *endpoint) incSynRcvdCount() bool { +func (e *endpoint) synRcvdBacklogFull() bool { e.acceptMu.Lock() - canInc := int(atomic.LoadInt32(&e.synRcvdCount)) < cap(e.acceptedChan) + acceptedCap := e.accepted.cap e.acceptMu.Unlock() - if canInc { - atomic.AddInt32(&e.synRcvdCount, 1) - } - return canInc + // The capacity of the accepted queue would always be one greater than the + // listen backlog. But, the SYNRCVD connections count is always checked + // against the listen backlog value for Linux parity reason. + // https://github.com/torvalds/linux/blob/7acac4b3196/include/net/inet_connection_sock.h#L280 + // + // We maintain an equality check here as the synRcvdCount is incremented + // and compared only from a single listener context and the capacity of + // the accepted queue can only increase by a new listen call. + return int(atomic.LoadInt32(&e.synRcvdCount)) == acceptedCap-1 } func (e *endpoint) acceptQueueIsFull() bool { e.acceptMu.Lock() - full := len(e.acceptedChan)+int(atomic.LoadInt32(&e.synRcvdCount)) >= cap(e.acceptedChan) + full := e.accepted != (accepted{}) && e.accepted.endpoints.Len() == e.accepted.cap e.acceptMu.Unlock() return full } @@ -524,9 +547,9 @@ func (e *endpoint) acceptQueueIsFull() bool { // // Precondition: if ctx.listenEP != nil, ctx.listenEP.mu must be locked. func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) tcpip.Error { - e.rcvListMu.Lock() - rcvClosed := e.rcvClosed - e.rcvListMu.Unlock() + e.rcvQueueInfo.rcvQueueMu.Lock() + rcvClosed := e.rcvQueueInfo.RcvClosed + e.rcvQueueInfo.rcvQueueMu.Unlock() if rcvClosed || s.flagsAreSet(header.TCPFlagSyn|header.TCPFlagAck) { // If the endpoint is shutdown, reply with reset. // @@ -538,69 +561,55 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) tcpip.Err switch { case s.flags == header.TCPFlagSyn: - opts := parseSynSegmentOptions(s) - if ctx.synRcvdCount.inc() { - // Only handle the syn if the following conditions hold - // - accept queue is not full. - // - number of connections in synRcvd state is less than the - // backlog. 
- if !e.acceptQueueIsFull() && e.incSynRcvdCount() { - s.incRef() - _ = e.handleSynSegment(ctx, s, &opts) - return nil - } - ctx.synRcvdCount.dec() + if e.acceptQueueIsFull() { e.stack.Stats().TCP.ListenOverflowSynDrop.Increment() e.stats.ReceiveErrors.ListenOverflowSynDrop.Increment() e.stack.Stats().DroppedPackets.Increment() return nil - } else { - // If cookies are in use but the endpoint accept queue - // is full then drop the syn. - if e.acceptQueueIsFull() { - e.stack.Stats().TCP.ListenOverflowSynDrop.Increment() - e.stats.ReceiveErrors.ListenOverflowSynDrop.Increment() - e.stack.Stats().DroppedPackets.Increment() - return nil - } - cookie := ctx.createCookie(s.id, s.sequenceNumber, encodeMSS(opts.MSS)) + } - route, err := e.stack.FindRoute(s.nicID, s.dstAddr, s.srcAddr, s.netProto, false /* multicastLoop */) - if err != nil { - return err - } - defer route.Release() + opts := parseSynSegmentOptions(s) + if !ctx.useSynCookies() { + s.incRef() + atomic.AddInt32(&e.synRcvdCount, 1) + return e.handleSynSegment(ctx, s, &opts) + } + route, err := e.stack.FindRoute(s.nicID, s.dstAddr, s.srcAddr, s.netProto, false /* multicastLoop */) + if err != nil { + return err + } + defer route.Release() - // Send SYN without window scaling because we currently - // don't encode this information in the cookie. - // - // Enable Timestamp option if the original syn did have - // the timestamp option specified. - // - // Use the user supplied MSS on the listening socket for - // new connections, if available. - synOpts := header.TCPSynOptions{ - WS: -1, - TS: opts.TS, - TSVal: tcpTimeStamp(time.Now(), timeStampOffset()), - TSEcr: opts.TSVal, - MSS: calculateAdvertisedMSS(e.userMSS, route), - } - fields := tcpFields{ - id: s.id, - ttl: e.ttl, - tos: e.sendTOS, - flags: header.TCPFlagSyn | header.TCPFlagAck, - seq: cookie, - ack: s.sequenceNumber + 1, - rcvWnd: ctx.rcvWnd, - } - if err := e.sendSynTCP(route, fields, synOpts); err != nil { - return err - } - e.stack.Stats().TCP.ListenOverflowSynCookieSent.Increment() - return nil + // Send SYN without window scaling because we currently + // don't encode this information in the cookie. + // + // Enable Timestamp option if the original syn did have + // the timestamp option specified. + // + // Use the user supplied MSS on the listening socket for + // new connections, if available. + synOpts := header.TCPSynOptions{ + WS: -1, + TS: opts.TS, + TSVal: tcpTimeStamp(time.Now(), timeStampOffset()), + TSEcr: opts.TSVal, + MSS: calculateAdvertisedMSS(e.userMSS, route), + } + cookie := ctx.createCookie(s.id, s.sequenceNumber, encodeMSS(opts.MSS)) + fields := tcpFields{ + id: s.id, + ttl: e.ttl, + tos: e.sendTOS, + flags: header.TCPFlagSyn | header.TCPFlagAck, + seq: cookie, + ack: s.sequenceNumber + 1, + rcvWnd: ctx.rcvWnd, + } + if err := e.sendSynTCP(route, fields, synOpts); err != nil { + return err } + e.stack.Stats().TCP.ListenOverflowSynCookieSent.Increment() + return nil case (s.flags & header.TCPFlagAck) != 0: if e.acceptQueueIsFull() { @@ -615,25 +624,6 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) tcpip.Err return nil } - if !ctx.synRcvdCount.synCookiesInUse() { - // When not using SYN cookies, as per RFC 793, section 3.9, page 64: - // Any acknowledgment is bad if it arrives on a connection still in - // the LISTEN state. An acceptable reset segment should be formed - // for any arriving ACK-bearing segment. 
The RST should be - // formatted as follows: - // - // <SEQ=SEG.ACK><CTL=RST> - // - // Send a reset as this is an ACK for which there is no - // half open connections and we are not using cookies - // yet. - // - // The only time we should reach here when a connection - // was opened and closed really quickly and a delayed - // ACK was received from the sender. - return replyWithReset(e.stack, s, e.sendTOS, e.ttl) - } - iss := s.ackNumber - 1 irs := s.sequenceNumber - 1 @@ -651,7 +641,23 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) tcpip.Err if !ok || int(data) >= len(mssTable) { e.stack.Stats().TCP.ListenOverflowInvalidSynCookieRcvd.Increment() e.stack.Stats().DroppedPackets.Increment() - return nil + + // When not using SYN cookies, as per RFC 793, section 3.9, page 64: + // Any acknowledgment is bad if it arrives on a connection still in + // the LISTEN state. An acceptable reset segment should be formed + // for any arriving ACK-bearing segment. The RST should be + // formatted as follows: + // + // <SEQ=SEG.ACK><CTL=RST> + // + // Send a reset as this is an ACK for which there is no + // half open connections and we are not using cookies + // yet. + // + // The only time we should reach here when a connection + // was opened and closed really quickly and a delayed + // ACK was received from the sender. + return replyWithReset(e.stack, s, e.sendTOS, e.ttl) } e.stack.Stats().TCP.ListenOverflowSynCookieRcvd.Increment() // Create newly accepted endpoint and deliver it. @@ -672,7 +678,7 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) tcpip.Err rcvdSynOptions.TSEcr = s.parsedOptions.TSEcr } - n, err := ctx.createConnectingEndpoint(s, iss, irs, rcvdSynOptions, &waiter.Queue{}) + n, err := ctx.createConnectingEndpoint(s, rcvdSynOptions, &waiter.Queue{}) if err != nil { return err } @@ -693,7 +699,14 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) tcpip.Err } // Register new endpoint so that packets are routed to it. - if err := n.stack.RegisterTransportEndpoint(n.effectiveNetProtos, ProtocolNumber, n.ID, n, n.boundPortFlags, n.boundBindToDevice); err != nil { + if err := n.stack.RegisterTransportEndpoint( + n.effectiveNetProtos, + ProtocolNumber, + n.TransportEndpointInfo.ID, + n, + n.boundPortFlags, + n.boundBindToDevice, + ); err != nil { n.mu.Unlock() n.Close() @@ -708,7 +721,7 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) tcpip.Err // endpoint as the Timestamp was already // randomly offset when the original SYN-ACK was // sent above. - n.tsOffset = 0 + n.TSOffset = 0 // Switch state to connected. n.isConnectNotified = true diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go index a9e978cf6..5e03e7715 100644 --- a/pkg/tcpip/transport/tcp/connect.go +++ b/pkg/tcpip/transport/tcp/connect.go @@ -65,11 +65,12 @@ const ( // NOTE: handshake.ep.mu is held during handshake processing. It is released if // we are going to block and reacquired when we start processing an event. type handshake struct { - ep *endpoint - state handshakeState - active bool - flags header.TCPFlags - ackNum seqnum.Value + ep *endpoint + listenEP *endpoint + state handshakeState + active bool + flags header.TCPFlags + ackNum seqnum.Value // iss is the initial send sequence number, as defined in RFC 793. 
iss seqnum.Value @@ -155,7 +156,7 @@ func (h *handshake) resetState() { h.flags = header.TCPFlagSyn h.ackNum = 0 h.mss = 0 - h.iss = generateSecureISN(h.ep.ID, h.ep.stack.Seed()) + h.iss = generateSecureISN(h.ep.TransportEndpointInfo.ID, h.ep.stack.Seed()) } // generateSecureISN generates a secure Initial Sequence number based on the @@ -301,7 +302,7 @@ func (h *handshake) synSentState(s *segment) tcpip.Error { ttl = h.ep.route.DefaultTTL() } h.ep.sendSynTCP(h.ep.route, tcpFields{ - id: h.ep.ID, + id: h.ep.TransportEndpointInfo.ID, ttl: ttl, tos: h.ep.sendTOS, flags: h.flags, @@ -357,14 +358,14 @@ func (h *handshake) synRcvdState(s *segment) tcpip.Error { h.resetState() synOpts := header.TCPSynOptions{ WS: h.rcvWndScale, - TS: h.ep.sendTSOk, + TS: h.ep.SendTSOk, TSVal: h.ep.timestamp(), TSEcr: h.ep.recentTimestamp(), - SACKPermitted: h.ep.sackPermitted, + SACKPermitted: h.ep.SACKPermitted, MSS: h.ep.amss, } h.ep.sendSynTCP(h.ep.route, tcpFields{ - id: h.ep.ID, + id: h.ep.TransportEndpointInfo.ID, ttl: h.ep.ttl, tos: h.ep.sendTOS, flags: h.flags, @@ -389,13 +390,22 @@ func (h *handshake) synRcvdState(s *segment) tcpip.Error { // If the timestamp option is negotiated and the segment does // not carry a timestamp option then the segment must be dropped // as per https://tools.ietf.org/html/rfc7323#section-3.2. - if h.ep.sendTSOk && !s.parsedOptions.TS { + if h.ep.SendTSOk && !s.parsedOptions.TS { h.ep.stack.Stats().DroppedPackets.Increment() return nil } + // Drop the ACK if the accept queue is full. + // https://github.com/torvalds/linux/blob/7acac4b3196/net/ipv4/tcp_ipv4.c#L1523 + // We could abort the connection as well with a tunable as in + // https://github.com/torvalds/linux/blob/7acac4b3196/net/ipv4/tcp_minisocks.c#L788 + if listenEP := h.listenEP; listenEP != nil && listenEP.acceptQueueIsFull() { + listenEP.stack.Stats().DroppedPackets.Increment() + return nil + } + // Update timestamp if required. See RFC7323, section-4.3. - if h.ep.sendTSOk && s.parsedOptions.TS { + if h.ep.SendTSOk && s.parsedOptions.TS { h.ep.updateRecentTimestamp(s.parsedOptions.TSVal, h.ackNum, s.sequenceNumber) } h.state = handshakeCompleted @@ -485,8 +495,8 @@ func (h *handshake) start() { // start() is also called in a listen context so we want to make sure we only // send the TS/SACK option when we received the TS/SACK in the initial SYN. if h.state == handshakeSynRcvd { - synOpts.TS = h.ep.sendTSOk - synOpts.SACKPermitted = h.ep.sackPermitted && bool(sackEnabled) + synOpts.TS = h.ep.SendTSOk + synOpts.SACKPermitted = h.ep.SACKPermitted && bool(sackEnabled) if h.sndWndScale < 0 { // Disable window scaling if the peer did not send us // the window scaling option. @@ -496,7 +506,7 @@ func (h *handshake) start() { h.sendSYNOpts = synOpts h.ep.sendSynTCP(h.ep.route, tcpFields{ - id: h.ep.ID, + id: h.ep.TransportEndpointInfo.ID, ttl: h.ep.ttl, tos: h.ep.sendTOS, flags: h.flags, @@ -544,7 +554,7 @@ func (h *handshake) complete() tcpip.Error { // retransmitted on their own). if h.active || !h.acked || h.deferAccept != 0 && time.Since(h.startTime) > h.deferAccept { h.ep.sendSynTCP(h.ep.route, tcpFields{ - id: h.ep.ID, + id: h.ep.TransportEndpointInfo.ID, ttl: h.ep.ttl, tos: h.ep.sendTOS, flags: h.flags, @@ -576,8 +586,14 @@ func (h *handshake) complete() tcpip.Error { <-h.ep.undrain h.ep.mu.Lock() } + // Check for any ICMP errors notified to us. 
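// Editorial aside, not part of the patch: `n` below is a bitmask of pending
// endpoint notifications, and each notify* constant defined elsewhere in this
// package is a distinct bit, so `n&notifyError != 0` asks whether that one
// event was raised. A tiny standalone sketch of the pattern, with hypothetical
// flag names:
//
//	package main
//
//	import "fmt"
//
//	const (
//		notifyA = 1 << iota // hypothetical flags, not the real tcp constants
//		notifyB
//		notifyC
//	)
//
//	func main() {
//		n := notifyA | notifyC // pretend these two events were raised
//		if n&notifyB != 0 {
//			fmt.Println("B pending") // not printed
//		}
//		if n&notifyC != 0 {
//			fmt.Println("C pending") // printed
//		}
//	}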
if n&notifyError != 0 { - return h.ep.lastErrorLocked() + if err := h.ep.lastErrorLocked(); err != nil { + return err + } + // Flag the handshake failure as aborted if the lastError is + // cleared because of a socket layer call. + return &tcpip.ErrConnectionAborted{} } case wakerForNewSegment: if err := h.processSegments(); err != nil { @@ -711,14 +727,14 @@ type tcpFields struct { func (e *endpoint) sendSynTCP(r *stack.Route, tf tcpFields, opts header.TCPSynOptions) tcpip.Error { tf.opts = makeSynOptions(opts) // We ignore SYN send errors and let the callers re-attempt send. - if err := e.sendTCP(r, tf, buffer.VectorisedView{}, nil); err != nil { + if err := e.sendTCP(r, tf, buffer.VectorisedView{}, stack.GSO{}); err != nil { e.stats.SendErrors.SynSendToNetworkFailed.Increment() } putOptions(tf.opts) return nil } -func (e *endpoint) sendTCP(r *stack.Route, tf tcpFields, data buffer.VectorisedView, gso *stack.GSO) tcpip.Error { +func (e *endpoint) sendTCP(r *stack.Route, tf tcpFields, data buffer.VectorisedView, gso stack.GSO) tcpip.Error { tf.txHash = e.txHash if err := sendTCP(r, tf, data, gso, e.owner); err != nil { e.stats.SendErrors.SegmentSendToNetworkFailed.Increment() @@ -728,7 +744,7 @@ func (e *endpoint) sendTCP(r *stack.Route, tf tcpFields, data buffer.VectorisedV return nil } -func buildTCPHdr(r *stack.Route, tf tcpFields, pkt *stack.PacketBuffer, gso *stack.GSO) { +func buildTCPHdr(r *stack.Route, tf tcpFields, pkt *stack.PacketBuffer, gso stack.GSO) { optLen := len(tf.opts) tcp := header.TCP(pkt.TransportHeader().Push(header.TCPMinimumSize + optLen)) pkt.TransportProtocolNumber = header.TCPProtocolNumber @@ -745,7 +761,7 @@ func buildTCPHdr(r *stack.Route, tf tcpFields, pkt *stack.PacketBuffer, gso *sta xsum := r.PseudoHeaderChecksum(ProtocolNumber, uint16(pkt.Size())) // Only calculate the checksum if offloading isn't supported. - if gso != nil && gso.NeedsCsum { + if gso.Type != stack.GSONone && gso.NeedsCsum { // This is called CHECKSUM_PARTIAL in the Linux kernel. We // calculate a checksum of the pseudo-header and save it in the // TCP header, then the kernel calculate a checksum of the @@ -757,7 +773,7 @@ func buildTCPHdr(r *stack.Route, tf tcpFields, pkt *stack.PacketBuffer, gso *sta } } -func sendTCPBatch(r *stack.Route, tf tcpFields, data buffer.VectorisedView, gso *stack.GSO, owner tcpip.PacketOwner) tcpip.Error { +func sendTCPBatch(r *stack.Route, tf tcpFields, data buffer.VectorisedView, gso stack.GSO, owner tcpip.PacketOwner) tcpip.Error { // We need to shallow clone the VectorisedView here as ReadToView will // split the VectorisedView and Trim underlying views as it splits.
Not // doing the clone here will cause the underlying views of data itself @@ -789,13 +805,14 @@ func sendTCPBatch(r *stack.Route, tf tcpFields, data buffer.VectorisedView, gso pkt.Data().ReadFromVV(&data, packetSize) buildTCPHdr(r, tf, pkt, gso) tf.seq = tf.seq.Add(seqnum.Size(packetSize)) + pkt.GSOOptions = gso pkts.PushBack(pkt) } if tf.ttl == 0 { tf.ttl = r.DefaultTTL() } - sent, err := r.WritePackets(gso, pkts, stack.NetworkHeaderParams{Protocol: ProtocolNumber, TTL: tf.ttl, TOS: tf.tos}) + sent, err := r.WritePackets(pkts, stack.NetworkHeaderParams{Protocol: ProtocolNumber, TTL: tf.ttl, TOS: tf.tos}) if err != nil { r.Stats().TCP.SegmentSendErrors.IncrementBy(uint64(n - sent)) } @@ -805,13 +822,13 @@ func sendTCPBatch(r *stack.Route, tf tcpFields, data buffer.VectorisedView, gso // sendTCP sends a TCP segment with the provided options via the provided // network endpoint and under the provided identity. -func sendTCP(r *stack.Route, tf tcpFields, data buffer.VectorisedView, gso *stack.GSO, owner tcpip.PacketOwner) tcpip.Error { +func sendTCP(r *stack.Route, tf tcpFields, data buffer.VectorisedView, gso stack.GSO, owner tcpip.PacketOwner) tcpip.Error { optLen := len(tf.opts) if tf.rcvWnd > math.MaxUint16 { tf.rcvWnd = math.MaxUint16 } - if r.Loop()&stack.PacketLoop == 0 && gso != nil && gso.Type == stack.GSOSW && int(gso.MSS) < data.Size() { + if r.Loop()&stack.PacketLoop == 0 && gso.Type == stack.GSOSW && int(gso.MSS) < data.Size() { return sendTCPBatch(r, tf, data, gso, owner) } @@ -819,6 +836,7 @@ func sendTCP(r *stack.Route, tf tcpFields, data buffer.VectorisedView, gso *stac ReserveHeaderBytes: header.TCPMinimumSize + int(r.MaxHeaderLength()) + optLen, Data: data, }) + pkt.GSOOptions = gso pkt.Hash = tf.txHash pkt.Owner = owner buildTCPHdr(r, tf, pkt, gso) @@ -826,7 +844,7 @@ func sendTCP(r *stack.Route, tf tcpFields, data buffer.VectorisedView, gso *stac if tf.ttl == 0 { tf.ttl = r.DefaultTTL() } - if err := r.WritePacket(gso, stack.NetworkHeaderParams{Protocol: ProtocolNumber, TTL: tf.ttl, TOS: tf.tos}, pkt); err != nil { + if err := r.WritePacket(stack.NetworkHeaderParams{Protocol: ProtocolNumber, TTL: tf.ttl, TOS: tf.tos}, pkt); err != nil { r.Stats().TCP.SegmentSendErrors.Increment() return err } @@ -845,7 +863,7 @@ func (e *endpoint) makeOptions(sackBlocks []header.SACKBlock) []byte { // N.B. the ordering here matches the ordering used by Linux internally // and described in the raw makeOptions function. We don't include // unnecessary cases here (post connection.) - if e.sendTSOk { + if e.SendTSOk { // Embed the timestamp if timestamp has been enabled. 
// // We only use the lower 32 bits of the unix time in @@ -862,7 +880,7 @@ func (e *endpoint) makeOptions(sackBlocks []header.SACKBlock) []byte { offset += header.EncodeNOP(options[offset:]) offset += header.EncodeTSOption(e.timestamp(), e.recentTimestamp(), options[offset:]) } - if e.sackPermitted && len(sackBlocks) > 0 { + if e.SACKPermitted && len(sackBlocks) > 0 { offset += header.EncodeNOP(options[offset:]) offset += header.EncodeNOP(options[offset:]) offset += header.EncodeSACKBlocks(sackBlocks, options[offset:]) @@ -884,7 +902,7 @@ func (e *endpoint) sendRaw(data buffer.VectorisedView, flags header.TCPFlags, se } options := e.makeOptions(sackBlocks) err := e.sendTCP(e.route, tcpFields{ - id: e.ID, + id: e.TransportEndpointInfo.ID, ttl: e.ttl, tos: e.sendTOS, flags: flags, @@ -898,9 +916,9 @@ func (e *endpoint) sendRaw(data buffer.VectorisedView, flags header.TCPFlags, se } func (e *endpoint) handleWrite() { - e.sndBufMu.Lock() + e.sndQueueInfo.sndQueueMu.Lock() next := e.drainSendQueueLocked() - e.sndBufMu.Unlock() + e.sndQueueInfo.sndQueueMu.Unlock() e.sendData(next) } @@ -909,10 +927,10 @@ func (e *endpoint) handleWrite() { // // Precondition: e.sndBufMu must be locked. func (e *endpoint) drainSendQueueLocked() *segment { - first := e.sndQueue.Front() + first := e.sndQueueInfo.sndQueue.Front() if first != nil { - e.snd.writeList.PushBackList(&e.sndQueue) - e.sndBufInQueue = 0 + e.snd.writeList.PushBackList(&e.sndQueueInfo.sndQueue) + e.sndQueueInfo.SndBufInQueue = 0 } return first } @@ -936,7 +954,7 @@ func (e *endpoint) handleClose() { e.handleWrite() // Mark send side as closed. - e.snd.closed = true + e.snd.Closed = true } // resetConnectionLocked puts the endpoint in an error state with the given @@ -958,12 +976,12 @@ func (e *endpoint) resetConnectionLocked(err tcpip.Error) { // // See: https://www.snellman.net/blog/archive/2016-02-01-tcp-rst/ for more // information. - sndWndEnd := e.snd.sndUna.Add(e.snd.sndWnd) + sndWndEnd := e.snd.SndUna.Add(e.snd.SndWnd) resetSeqNum := sndWndEnd - if !sndWndEnd.LessThan(e.snd.sndNxt) || e.snd.sndNxt.Size(sndWndEnd) < (1<<e.snd.sndWndScale) { - resetSeqNum = e.snd.sndNxt + if !sndWndEnd.LessThan(e.snd.SndNxt) || e.snd.SndNxt.Size(sndWndEnd) < (1<<e.snd.SndWndScale) { + resetSeqNum = e.snd.SndNxt } - e.sendRaw(buffer.VectorisedView{}, header.TCPFlagAck|header.TCPFlagRst, resetSeqNum, e.rcv.rcvNxt, 0) + e.sendRaw(buffer.VectorisedView{}, header.TCPFlagAck|header.TCPFlagRst, resetSeqNum, e.rcv.RcvNxt, 0) } } @@ -989,13 +1007,13 @@ func (e *endpoint) transitionToStateEstablishedLocked(h *handshake) { // (indicated by a negative send window scale). e.snd = newSender(e, h.iss, h.ackNum-1, h.sndWnd, h.mss, h.sndWndScale) - e.rcvListMu.Lock() + e.rcvQueueInfo.rcvQueueMu.Lock() e.rcv = newReceiver(e, h.ackNum-1, h.rcvWnd, h.effectiveRcvWndScale()) // Bootstrap the auto tuning algorithm. Starting at zero will // result in a really large receive window after the first auto // tuning adjustment. - e.rcvAutoParams.prevCopied = int(h.rcvWnd) - e.rcvListMu.Unlock() + e.rcvQueueInfo.RcvAutoParams.PrevCopiedBytes = int(h.rcvWnd) + e.rcvQueueInfo.rcvQueueMu.Unlock() e.setEndpointState(StateEstablished) } @@ -1026,10 +1044,15 @@ func (e *endpoint) transitionToStateCloseLocked() { // only when the endpoint is in StateClose and we want to deliver the segment // to any other listening endpoint. We reply with RST if we cannot find one. 
func (e *endpoint) tryDeliverSegmentFromClosedEndpoint(s *segment) { - ep := e.stack.FindTransportEndpoint(e.NetProto, e.TransProto, e.ID, s.nicID) + ep := e.stack.FindTransportEndpoint(e.NetProto, e.TransProto, e.TransportEndpointInfo.ID, s.nicID) if ep == nil && e.NetProto == header.IPv6ProtocolNumber && e.TransportEndpointInfo.ID.LocalAddress.To4() != "" { // Dual-stack socket, try IPv4. - ep = e.stack.FindTransportEndpoint(header.IPv4ProtocolNumber, e.TransProto, e.ID, s.nicID) + ep = e.stack.FindTransportEndpoint( + header.IPv4ProtocolNumber, + e.TransProto, + e.TransportEndpointInfo.ID, + s.nicID, + ) } if ep == nil { replyWithReset(e.stack, s, stack.DefaultTOS, 0 /* ttl */) @@ -1108,7 +1131,9 @@ func (e *endpoint) handleReset(s *segment) (ok bool, err tcpip.Error) { } // handleSegments processes all inbound segments. -func (e *endpoint) handleSegments(fastPath bool) tcpip.Error { +// +// Precondition: e.mu must be held. +func (e *endpoint) handleSegmentsLocked(fastPath bool) tcpip.Error { checkRequeue := true for i := 0; i < maxSegmentsPerWake; i++ { if e.EndpointState().closed() { @@ -1120,7 +1145,7 @@ func (e *endpoint) handleSegments(fastPath bool) tcpip.Error { break } - cont, err := e.handleSegment(s) + cont, err := e.handleSegmentLocked(s) s.decRef() if err != nil { return err @@ -1138,7 +1163,7 @@ func (e *endpoint) handleSegments(fastPath bool) tcpip.Error { } // Send an ACK for all processed packets if needed. - if e.rcv.rcvNxt != e.snd.maxSentAck { + if e.rcv.RcvNxt != e.snd.MaxSentAck { e.snd.sendAck() } @@ -1147,18 +1172,21 @@ func (e *endpoint) handleSegments(fastPath bool) tcpip.Error { return nil } -func (e *endpoint) probeSegment() { - if e.probe != nil { - e.probe(e.completeState()) +// Precondition: e.mu must be held. +func (e *endpoint) probeSegmentLocked() { + if fn := e.probe; fn != nil { + fn(e.completeStateLocked()) } } // handleSegment handles a given segment and notifies the worker goroutine if // if the connection should be terminated. -func (e *endpoint) handleSegment(s *segment) (cont bool, err tcpip.Error) { +// +// Precondition: e.mu must be held. +func (e *endpoint) handleSegmentLocked(s *segment) (cont bool, err tcpip.Error) { // Invoke the tcp probe if installed. The tcp probe function will update // the TCPEndpointState after the segment is processed. - defer e.probeSegment() + defer e.probeSegmentLocked() if s.flagIsSet(header.TCPFlagRst) { if ok, err := e.handleReset(s); !ok { @@ -1191,7 +1219,7 @@ func (e *endpoint) handleSegment(s *segment) (cont bool, err tcpip.Error) { } else if s.flagIsSet(header.TCPFlagAck) { // Patch the window size in the segment according to the // send window scale. - s.window <<= e.snd.sndWndScale + s.window <<= e.snd.SndWndScale // RFC 793, page 41 states that "once in the ESTABLISHED // state all segments must carry current acknowledgment @@ -1255,7 +1283,7 @@ func (e *endpoint) keepaliveTimerExpired() tcpip.Error { // seg.seq = snd.nxt-1. e.keepalive.unacked++ e.keepalive.Unlock() - e.snd.sendSegmentFromView(buffer.VectorisedView{}, header.TCPFlagAck, e.snd.sndNxt-1) + e.snd.sendSegmentFromView(buffer.VectorisedView{}, header.TCPFlagAck, e.snd.SndNxt-1) e.resetKeepaliveTimer(false) return nil } @@ -1269,7 +1297,7 @@ func (e *endpoint) resetKeepaliveTimer(receivedData bool) { } // Start the keepalive timer IFF it's enabled and there is no pending // data to send. 
- if !e.SocketOptions().GetKeepAlive() || e.snd == nil || e.snd.sndUna != e.snd.sndNxt { + if !e.SocketOptions().GetKeepAlive() || e.snd == nil || e.snd.SndUna != e.snd.SndNxt { e.keepalive.timer.disable() e.keepalive.Unlock() return @@ -1340,8 +1368,24 @@ func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{ } // Reaching this point means that we successfully completed the 3-way - // handshake with our peer. - // + // handshake with our peer. The current endpoint state could be any state + // post ESTABLISHED, including CLOSED or ERROR if the endpoint processes a + // RST from the peer via the dispatcher fast path, before the loop is + // started. + if s := e.EndpointState(); !s.connected() { + switch s { + case StateClose, StateError: + // If the endpoint is in CLOSED/ERROR state, sender state has to be + // initialized if the endpoint was previously established. + if e.snd != nil { + break + } + fallthrough + default: + panic("endpoint was not established, current state " + s.String()) + } + } + // Completing the 3-way handshake is an indication that the route is valid + // and the remote is reachable as the only way we can complete a handshake + // is if our SYN reached the remote and their ACK reached us. @@ -1362,14 +1406,14 @@ func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{ f func() tcpip.Error }{ { - w: &e.sndWaker, + w: &e.sndQueueInfo.sndWaker, f: func() tcpip.Error { e.handleWrite() return nil }, }, { - w: &e.sndCloseWaker, + w: &e.sndQueueInfo.sndCloseWaker, f: func() tcpip.Error { e.handleClose() return nil @@ -1403,7 +1447,7 @@ func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{ { w: &e.newSegmentWaker, f: func() tcpip.Error { - return e.handleSegments(false /* fastPath */) + return e.handleSegmentsLocked(false /* fastPath */) }, }, { @@ -1419,11 +1463,11 @@ func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{ } if n&notifyMTUChanged != 0 { - e.sndBufMu.Lock() - count := e.packetTooBigCount - e.packetTooBigCount = 0 - mtu := e.sndMTU - e.sndBufMu.Unlock() + e.sndQueueInfo.sndQueueMu.Lock() + count := e.sndQueueInfo.PacketTooBigCount + e.sndQueueInfo.PacketTooBigCount = 0 + mtu := e.sndQueueInfo.SndMTU + e.sndQueueInfo.sndQueueMu.Unlock() e.snd.updateMaxPayloadSize(mtu, count) } @@ -1453,7 +1497,7 @@ func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{ if n&notifyDrain != 0 { for !e.segmentQueue.empty() { - if err := e.handleSegments(false /* fastPath */); err != nil { + if err := e.handleSegmentsLocked(false /* fastPath */); err != nil { return err } } @@ -1504,11 +1548,11 @@ func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{ e.newSegmentWaker.Assert() } - e.rcvListMu.Lock() - if !e.rcvList.Empty() { + e.rcvQueueInfo.rcvQueueMu.Lock() + if !e.rcvQueueInfo.rcvQueue.Empty() { e.waiterQueue.Notify(waiter.ReadableEvents) } - e.rcvListMu.Unlock() + e.rcvQueueInfo.rcvQueueMu.Unlock() if e.workerCleanup { e.notifyProtocolGoroutine(notifyClose) diff --git a/pkg/tcpip/transport/tcp/cubic.go b/pkg/tcpip/transport/tcp/cubic.go index 1975f1a44..962f1d687 100644 --- a/pkg/tcpip/transport/tcp/cubic.go +++ b/pkg/tcpip/transport/tcp/cubic.go @@ -17,6 +17,8 @@ package tcp import ( "math" "time" + + "gvisor.dev/gvisor/pkg/tcpip/stack" ) // cubicState stores the variables related to TCP CUBIC congestion @@ -25,47 +27,12 @@ import ( // See: https://tools.ietf.org/html/rfc8312. 
// +stateify savable type cubicState struct { - // wLastMax is the previous wMax value. - wLastMax float64 - - // wMax is the value of the congestion window at the - // time of last congestion event. - wMax float64 - - // t denotes the time when the current congestion avoidance - // was entered. - t time.Time `state:".(unixTime)"` + stack.TCPCubicState // numCongestionEvents tracks the number of congestion events since last // RTO. numCongestionEvents int - // c is the cubic constant as specified in RFC8312. It's fixed at 0.4 as - // per RFC. - c float64 - - // k is the time period that the above function takes to increase the - // current window size to W_max if there are no further congestion - // events and is calculated using the following equation: - // - // K = cubic_root(W_max*(1-beta_cubic)/C) (Eq. 2) - k float64 - - // beta is the CUBIC multiplication decrease factor. that is, when a - // congestion event is detected, CUBIC reduces its cwnd to - // W_cubic(0)=W_max*beta_cubic. - beta float64 - - // wC is window computed by CUBIC at time t. It's calculated using the - // formula: - // - // W_cubic(t) = C*(t-K)^3 + W_max (Eq. 1) - wC float64 - - // wEst is the window computed by CUBIC at time t+RTT i.e - // W_cubic(t+RTT). - wEst float64 - s *sender } @@ -73,10 +40,12 @@ type cubicState struct { // beta and c set and t set to current time. func newCubicCC(s *sender) *cubicState { return &cubicState{ - t: time.Now(), - beta: 0.7, - c: 0.4, - s: s, + TCPCubicState: stack.TCPCubicState{ + T: time.Now(), + Beta: 0.7, + C: 0.4, + }, + s: s, } } @@ -90,10 +59,10 @@ func (c *cubicState) enterCongestionAvoidance() { // See: https://tools.ietf.org/html/rfc8312#section-4.7 & // https://tools.ietf.org/html/rfc8312#section-4.8 if c.numCongestionEvents == 0 { - c.k = 0 - c.t = time.Now() - c.wLastMax = c.wMax - c.wMax = float64(c.s.sndCwnd) + c.K = 0 + c.T = time.Now() + c.WLastMax = c.WMax + c.WMax = float64(c.s.SndCwnd) } } @@ -104,16 +73,16 @@ func (c *cubicState) enterCongestionAvoidance() { func (c *cubicState) updateSlowStart(packetsAcked int) int { // Don't let the congestion window cross into the congestion // avoidance range. - newcwnd := c.s.sndCwnd + packetsAcked + newcwnd := c.s.SndCwnd + packetsAcked enterCA := false - if newcwnd >= c.s.sndSsthresh { - newcwnd = c.s.sndSsthresh - c.s.sndCAAckCount = 0 + if newcwnd >= c.s.Ssthresh { + newcwnd = c.s.Ssthresh + c.s.SndCAAckCount = 0 enterCA = true } - packetsAcked -= newcwnd - c.s.sndCwnd - c.s.sndCwnd = newcwnd + packetsAcked -= newcwnd - c.s.SndCwnd + c.s.SndCwnd = newcwnd if enterCA { c.enterCongestionAvoidance() } @@ -124,49 +93,49 @@ func (c *cubicState) updateSlowStart(packetsAcked int) int { // ACK received. // Refer: https://tools.ietf.org/html/rfc8312#section-4 func (c *cubicState) Update(packetsAcked int) { - if c.s.sndCwnd < c.s.sndSsthresh { + if c.s.SndCwnd < c.s.Ssthresh { packetsAcked = c.updateSlowStart(packetsAcked) if packetsAcked == 0 { return } } else { c.s.rtt.Lock() - srtt := c.s.rtt.srtt + srtt := c.s.rtt.TCPRTTState.SRTT c.s.rtt.Unlock() - c.s.sndCwnd = c.getCwnd(packetsAcked, c.s.sndCwnd, srtt) + c.s.SndCwnd = c.getCwnd(packetsAcked, c.s.SndCwnd, srtt) } } // cubicCwnd computes the CUBIC congestion window after t seconds from last // congestion event. func (c *cubicState) cubicCwnd(t float64) float64 { - return c.c*math.Pow(t, 3.0) + c.wMax + return c.C*math.Pow(t, 3.0) + c.WMax } // getCwnd returns the current congestion window as computed by CUBIC. 
// Refer: https://tools.ietf.org/html/rfc8312#section-4 func (c *cubicState) getCwnd(packetsAcked, sndCwnd int, srtt time.Duration) int { - elapsed := time.Since(c.t).Seconds() + elapsed := time.Since(c.T).Seconds() // Compute the window as per Cubic after 'elapsed' time // since last congestion event. - c.wC = c.cubicCwnd(elapsed - c.k) + c.WC = c.cubicCwnd(elapsed - c.K) // Compute the TCP friendly estimate of the congestion window. - c.wEst = c.wMax*c.beta + (3.0*((1.0-c.beta)/(1.0+c.beta)))*(elapsed/srtt.Seconds()) + c.WEst = c.WMax*c.Beta + (3.0*((1.0-c.Beta)/(1.0+c.Beta)))*(elapsed/srtt.Seconds()) // Make sure in the TCP friendly region CUBIC performs at least // as well as Reno. - if c.wC < c.wEst && float64(sndCwnd) < c.wEst { + if c.WC < c.WEst && float64(sndCwnd) < c.WEst { // TCP Friendly region of cubic. - return int(c.wEst) + return int(c.WEst) } // In Concave/Convex region of CUBIC, calculate what CUBIC window // will be after 1 RTT and use that to grow congestion window // for every ack. - tEst := (time.Since(c.t) + srtt).Seconds() - wtRtt := c.cubicCwnd(tEst - c.k) + tEst := (time.Since(c.T) + srtt).Seconds() + wtRtt := c.cubicCwnd(tEst - c.K) // As per 4.3 for each received ACK cwnd must be incremented // by (w_cubic(t+RTT) - cwnd)/cwnd. cwnd := float64(sndCwnd) @@ -182,9 +151,9 @@ func (c *cubicState) getCwnd(packetsAcked, sndCwnd int, srtt time.Duration) int func (c *cubicState) HandleLossDetected() { // See: https://tools.ietf.org/html/rfc8312#section-4.5 c.numCongestionEvents++ - c.t = time.Now() - c.wLastMax = c.wMax - c.wMax = float64(c.s.sndCwnd) + c.T = time.Now() + c.WLastMax = c.WMax + c.WMax = float64(c.s.SndCwnd) c.fastConvergence() c.reduceSlowStartThreshold() @@ -193,10 +162,10 @@ func (c *cubicState) HandleLossDetected() { // HandleRTOExpired implements congestionControl.HandleRTOExpired. func (c *cubicState) HandleRTOExpired() { // See: https://tools.ietf.org/html/rfc8312#section-4.6 - c.t = time.Now() + c.T = time.Now() c.numCongestionEvents = 0 - c.wLastMax = c.wMax - c.wMax = float64(c.s.sndCwnd) + c.WLastMax = c.WMax + c.WMax = float64(c.s.SndCwnd) c.fastConvergence() @@ -206,29 +175,29 @@ func (c *cubicState) HandleRTOExpired() { // Reduce the congestion window to 1, i.e., enter slow-start. Per // RFC 5681, page 7, we must use 1 regardless of the value of the // initial congestion window. - c.s.sndCwnd = 1 + c.s.SndCwnd = 1 } // fastConvergence implements the logic for Fast Convergence algorithm as // described in https://tools.ietf.org/html/rfc8312#section-4.6. func (c *cubicState) fastConvergence() { - if c.wMax < c.wLastMax { - c.wLastMax = c.wMax - c.wMax = c.wMax * (1.0 + c.beta) / 2.0 + if c.WMax < c.WLastMax { + c.WLastMax = c.WMax + c.WMax = c.WMax * (1.0 + c.Beta) / 2.0 } else { - c.wLastMax = c.wMax + c.WLastMax = c.WMax } // Recompute k as wMax may have changed. - c.k = math.Cbrt(c.wMax * (1 - c.beta) / c.c) + c.K = math.Cbrt(c.WMax * (1 - c.Beta) / c.C) } // PostRecovery implements congestionControl.PostRecovery. func (c *cubicState) PostRecovery() { - c.t = time.Now() + c.T = time.Now() } // reduceSlowStartThreshold returns new SsThresh as described in // https://tools.ietf.org/html/rfc8312#section-4.7. 
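The getCwnd and fastConvergence hunks above compute Eq. 1 and Eq. 2 of RFC 8312, W_cubic(t) = C*(t-K)^3 + W_max and K = cbrt(W_max*(1-beta)/C). A self-contained numeric sketch of just those two formulas, with an example W_max and the RFC's fixed C = 0.4 and beta = 0.7, shows the concave-then-convex curve the sender follows after a congestion event (illustration only, not part of the patch):

package main

import (
	"fmt"
	"math"
)

func main() {
	const (
		c    = 0.4 // CUBIC constant, fixed by RFC 8312
		beta = 0.7 // multiplicative decrease factor
	)
	wMax := 100.0 // example: cwnd (in packets) at the last congestion event

	// K = cubic_root(W_max*(1-beta_cubic)/C)   (Eq. 2)
	k := math.Cbrt(wMax * (1 - beta) / c)

	// W_cubic(t) = C*(t-K)^3 + W_max           (Eq. 1)
	for _, t := range []float64{0, k / 2, k, 2 * k} {
		wC := c*math.Pow(t-k, 3) + wMax
		fmt.Printf("t=%5.2fs  W_cubic=%6.1f\n", t, wC)
	}
	// At t=0 the window restarts from W_max*beta (70 here), flattens out as it
	// approaches W_max around t=K, then probes beyond it.
}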
func (c *cubicState) reduceSlowStartThreshold() { - c.s.sndSsthresh = int(math.Max(float64(c.s.sndCwnd)*c.beta, 2.0)) + c.s.Ssthresh = int(math.Max(float64(c.s.SndCwnd)*c.Beta, 2.0)) } diff --git a/pkg/tcpip/transport/tcp/dispatcher.go b/pkg/tcpip/transport/tcp/dispatcher.go index 21162f01a..512053a04 100644 --- a/pkg/tcpip/transport/tcp/dispatcher.go +++ b/pkg/tcpip/transport/tcp/dispatcher.go @@ -116,7 +116,7 @@ func (p *processor) start(wg *sync.WaitGroup) { if ep.EndpointState() == StateEstablished && ep.mu.TryLock() { // If the endpoint is in a connected state then we do direct delivery // to ensure low latency and avoid scheduler interactions. - switch err := ep.handleSegments(true /* fastPath */); { + switch err := ep.handleSegmentsLocked(true /* fastPath */); { case err != nil: // Send any active resets if required. ep.resetConnectionLocked(err) diff --git a/pkg/tcpip/transport/tcp/dual_stack_test.go b/pkg/tcpip/transport/tcp/dual_stack_test.go index f6a16f96e..f148d505d 100644 --- a/pkg/tcpip/transport/tcp/dual_stack_test.go +++ b/pkg/tcpip/transport/tcp/dual_stack_test.go @@ -19,6 +19,7 @@ import ( "testing" "time" + "github.com/google/go-cmp/cmp" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/checker" "gvisor.dev/gvisor/pkg/tcpip/header" @@ -37,8 +38,8 @@ func TestV4MappedConnectOnV6Only(t *testing.T) { // Start connection attempt, it must fail. err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestV4MappedAddr, Port: context.TestPort}) - if _, ok := err.(*tcpip.ErrNoRoute); !ok { - t.Fatalf("Unexpected return value from Connect: %v", err) + if d := cmp.Diff(&tcpip.ErrNoRoute{}, err); d != "" { + t.Fatalf("c.EP.Connect(...) mismatch (-want +got):\n%s", d) } } @@ -49,8 +50,8 @@ func testV4Connect(t *testing.T, c *context.Context, checkers ...checker.Network defer c.WQ.EventUnregister(&we) err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestV4MappedAddr, Port: context.TestPort}) - if _, ok := err.(*tcpip.ErrConnectStarted); !ok { - t.Fatalf("Unexpected return value from Connect: %v", err) + if d := cmp.Diff(&tcpip.ErrConnectStarted{}, err); d != "" { + t.Fatalf("c.EP.Connect(...) mismatch (-want +got):\n%s", d) } // Receive SYN packet. @@ -156,8 +157,8 @@ func testV6Connect(t *testing.T, c *context.Context, checkers ...checker.Network defer c.WQ.EventUnregister(&we) err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestV6Addr, Port: context.TestPort}) - if _, ok := err.(*tcpip.ErrConnectStarted); !ok { - t.Fatalf("Unexpected return value from Connect: %v", err) + if d := cmp.Diff(&tcpip.ErrConnectStarted{}, err); d != "" { + t.Fatalf("Connect(...) mismatch (-want +got):\n%s", d) } // Receive SYN packet. @@ -391,7 +392,7 @@ func testV4Accept(t *testing.T, c *context.Context) { defer c.WQ.EventUnregister(&we) nep, _, err := c.EP.Accept(nil) - if _, ok := err.(*tcpip.ErrWouldBlock); ok { + if cmp.Equal(&tcpip.ErrWouldBlock{}, err) { // Wait for connection to be established. select { case <-ch: @@ -525,7 +526,7 @@ func TestV6AcceptOnV6(t *testing.T) { defer c.WQ.EventUnregister(&we) var addr tcpip.FullAddress _, _, err := c.EP.Accept(&addr) - if _, ok := err.(*tcpip.ErrWouldBlock); ok { + if cmp.Equal(&tcpip.ErrWouldBlock{}, err) { // Wait for connection to be established. select { case <-ch: @@ -565,17 +566,15 @@ func TestV4AcceptOnV4(t *testing.T) { } func testV4ListenClose(t *testing.T, c *context.Context) { - // Set the SynRcvd threshold to zero to force a syn cookie based accept - // to happen. 
- var opt tcpip.TCPSynRcvdCountThresholdOption + opt := tcpip.TCPAlwaysUseSynCookies(true) if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { - t.Fatalf("setting TCPSynRcvdCountThresholdOption(%d, &%T(%d)): %s", tcp.ProtocolNumber, opt, opt, err) + t.Fatalf("SetTransportProtocolOption(%d, &%T(%t)): %s", tcp.ProtocolNumber, opt, opt, err) } - const n = uint16(32) + const n = 32 // Start listening. - if err := c.EP.Listen(int(tcp.SynRcvdCountThreshold + 1)); err != nil { + if err := c.EP.Listen(n); err != nil { t.Fatalf("Listen failed: %v", err) } @@ -591,9 +590,9 @@ func testV4ListenClose(t *testing.T, c *context.Context) { }) } - // Each of these ACK's will cause a syn-cookie based connection to be + // Each of these ACKs will cause a syn-cookie based connection to be // accepted and delivered to the listening endpoint. - for i := uint16(0); i < n; i++ { + for i := 0; i < n; i++ { b := c.GetPacket() tcp := header.TCP(header.IPv4(b).Payload()) iss := seqnum.Value(tcp.SequenceNumber()) @@ -613,7 +612,7 @@ func testV4ListenClose(t *testing.T, c *context.Context) { c.WQ.EventRegister(&we, waiter.ReadableEvents) defer c.WQ.EventUnregister(&we) nep, _, err := c.EP.Accept(nil) - if _, ok := err.(*tcpip.ErrWouldBlock); ok { + if cmp.Equal(&tcpip.ErrWouldBlock{}, err) { // Wait for connection to be established. select { case <-ch: diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go index c5daba232..90edcfba6 100644 --- a/pkg/tcpip/transport/tcp/endpoint.go +++ b/pkg/tcpip/transport/tcp/endpoint.go @@ -15,6 +15,7 @@ package tcp import ( + "container/list" "encoding/binary" "fmt" "io" @@ -190,42 +191,6 @@ type SACKInfo struct { NumBlocks int } -// rcvBufAutoTuneParams are used to hold state variables to compute -// the auto tuned recv buffer size. -// -// +stateify savable -type rcvBufAutoTuneParams struct { - // measureTime is the time at which the current measurement - // was started. - measureTime time.Time `state:".(unixTime)"` - - // copied is the number of bytes copied out of the receive - // buffers since this measure began. - copied int - - // prevCopied is the number of bytes copied out of the receive - // buffers in the previous RTT period. - prevCopied int - - // rtt is the non-smoothed minimum RTT as measured by observing the time - // between when a byte is first acknowledged and the receipt of data - // that is at least one window beyond the sequence number that was - // acknowledged. - rtt time.Duration - - // rttMeasureSeqNumber is the highest acceptable sequence number at the - // time this RTT measurement period began. - rttMeasureSeqNumber seqnum.Value - - // rttMeasureTime is the absolute time at which the current rtt - // measurement period began. - rttMeasureTime time.Time `state:".(unixTime)"` - - // disabled is true if an explicit receive buffer is set for the - // endpoint. - disabled bool -} - // ReceiveErrors collect segment receive errors within transport layer. type ReceiveErrors struct { tcpip.ReceiveErrors @@ -246,7 +211,7 @@ type ReceiveErrors struct { ListenOverflowAckDrop tcpip.StatCounter // ZeroRcvWindowState is the number of times we advertised - // a zero receive window when rcvList is full. + // a zero receive window when rcvQueue is full. ZeroRcvWindowState tcpip.StatCounter // WantZeroWindow is the number of times we wanted to advertise a @@ -309,18 +274,45 @@ type Stats struct { // marker interface. 
func (*Stats) IsEndpointStats() {} -// EndpointInfo holds useful information about a transport endpoint which -// can be queried by monitoring tools. This exists to allow tcp-only state to -// be exposed. +// sndQueueInfo implements a send queue. // // +stateify savable -type EndpointInfo struct { - stack.TransportEndpointInfo +type sndQueueInfo struct { + sndQueueMu sync.Mutex `state:"nosave"` + stack.TCPSndBufState + + // sndQueue holds segments that are ready to be sent. + sndQueue segmentList `state:"wait"` + + // sndWaker is used to signal the protocol goroutine when segments are + // added to the `sndQueue`. + sndWaker sleep.Waker `state:"manual"` + + // sndCloseWaker is used to notify the protocol goroutine when the send + // side is closed. + sndCloseWaker sleep.Waker `state:"manual"` } -// IsEndpointInfo is an empty method to implement the tcpip.EndpointInfo -// marker interface. -func (*EndpointInfo) IsEndpointInfo() {} +// rcvQueueInfo contains the endpoint's rcvQueue and associated metadata. +// +// +stateify savable +type rcvQueueInfo struct { + rcvQueueMu sync.Mutex `state:"nosave"` + stack.TCPRcvBufState + + // rcvQueue is the queue for ready-for-delivery segments. This struct's + // mutex must be held in order append segments to list. + rcvQueue segmentList `state:"wait"` +} + +// +stateify savable +type accepted struct { + // NB: this could be an endpointList, but ilist only permits endpoints to + // belong to one list at a time, and endpoints are already stored in the + // dispatcher's list. + endpoints list.List `state:".([]*endpoint)"` + cap int +} // endpoint represents a TCP endpoint. This struct serves as the interface // between users of the endpoint and the protocol implementation; it is legal to @@ -337,9 +329,9 @@ func (*EndpointInfo) IsEndpointInfo() {} // The following three mutexes can be acquired independent of e.mu but if // acquired with e.mu then e.mu must be acquired first. // -// e.acceptMu -> protects acceptedChan. -// e.rcvListMu -> Protects the rcvList and associated fields. -// e.sndBufMu -> Protects the sndQueue and associated fields. +// e.acceptMu -> protects accepted. +// e.rcvQueueMu -> Protects e.rcvQueue and associated fields. +// e.sndQueueMu -> Protects the e.sndQueue and associated fields. // e.lastErrorMu -> Protects the lastError field. // // LOCKING/UNLOCKING of the endpoint. The locking of an endpoint is different @@ -362,7 +354,8 @@ func (*EndpointInfo) IsEndpointInfo() {} // // +stateify savable type endpoint struct { - EndpointInfo + stack.TCPEndpointStateInner + stack.TransportEndpointInfo tcpip.DefaultSocketOptionsHandler // endpointEntry is used to queue endpoints for processing to the @@ -395,38 +388,23 @@ type endpoint struct { // rcvReadMu synchronizes calls to Read. // - // mu and rcvListMu are temporarily released during data copying. rcvReadMu + // mu and rcvQueueMu are temporarily released during data copying. rcvReadMu // must be held during each read to ensure atomicity, so that multiple reads // do not interleave. // // rcvReadMu should be held before holding mu. rcvReadMu sync.Mutex `state:"nosave"` - // rcvListMu synchronizes access to rcvList. - // - // rcvListMu can be taken after the endpoint mu below. - rcvListMu sync.Mutex `state:"nosave"` - - // rcvList is the queue for ready-for-delivery segments. - // - // rcvReadMu, mu and rcvListMu must be held, in the stated order, to read data - // and removing segments from list. 
A range of segment can be determined, then - // temporarily release mu and rcvListMu while processing the segment range. - // This allows new segments to be appended to the list while processing. - // - // rcvListMu must be held to append segments to list. - rcvList segmentList `state:"wait"` - rcvClosed bool - // rcvBufSize is the total size of the receive buffer. - rcvBufSize int - // rcvBufUsed is the actual number of payload bytes held in the receive buffer - // not counting any overheads of the segments itself. NOTE: This will always - // be strictly <= rcvMemUsed below. - rcvBufUsed int - rcvAutoParams rcvBufAutoTuneParams + // rcvQueueInfo holds the implementation of the endpoint's receive buffer. + // The data within rcvQueueInfo should only be accessed while rcvReadMu, mu, + // and rcvQueueMu are held, in that stated order. While processing the segment + // range, you can determine a range and then temporarily release mu and + // rcvQueueMu, which allows new segments to be appended to the queue while + // processing. + rcvQueueInfo rcvQueueInfo // rcvMemUsed tracks the total amount of memory in use by received segments - // held in rcvList, pendingRcvdSegments and the segment queue. This is used to + // held in rcvQueue, pendingRcvdSegments and the segment queue. This is used to // compute the window and the actual available buffer space. This is distinct // from rcvBufUsed above which is the actual number of payload bytes held in // the buffer not including any segment overheads. @@ -488,33 +466,16 @@ type endpoint struct { // also true, and they're both protected by the mutex. workerCleanup bool - // sendTSOk is used to indicate when the TS Option has been negotiated. - // When sendTSOk is true every non-RST segment should carry a TS as per - // RFC7323#section-1.1 - sendTSOk bool - - // recentTS is the timestamp that should be sent in the TSEcr field of - // the timestamp for future segments sent by the endpoint. This field is - // updated if required when a new segment is received by this endpoint. - recentTS uint32 - - // recentTSTime is the unix time when we updated recentTS last. + // recentTSTime is the unix time when we last updated + // TCPEndpointStateInner.RecentTS. recentTSTime time.Time `state:".(unixTime)"` - // tsOffset is a randomized offset added to the value of the - // TSVal field in the timestamp option. - tsOffset uint32 - // shutdownFlags represent the current shutdown state of the endpoint. shutdownFlags tcpip.ShutdownFlags // tcpRecovery is the loss deteoction algorithm used by TCP. tcpRecovery tcpip.TCPRecovery - // sackPermitted is set to true if the peer sends the TCPSACKPermitted - // option in the SYN/SYN-ACK. - sackPermitted bool - // sack holds TCP SACK related information for this endpoint. sack SACKInfo @@ -550,32 +511,13 @@ type endpoint struct { // this value. windowClamp uint32 - // The following fields are used to manage the send buffer. When - // segments are ready to be sent, they are added to sndQueue and the - // protocol goroutine is signaled via sndWaker. - // - // When the send side is closed, the protocol goroutine is notified via - // sndCloseWaker, and sndClosed is set to true. - sndBufMu sync.Mutex `state:"nosave"` - sndBufUsed int - sndClosed bool - sndBufInQueue seqnum.Size - sndQueue segmentList `state:"wait"` - sndWaker sleep.Waker `state:"manual"` - sndCloseWaker sleep.Waker `state:"manual"` + // sndQueueInfo contains the implementation of the endpoint's send queue. 
+ sndQueueInfo sndQueueInfo // cc stores the name of the Congestion Control algorithm to use for // this endpoint. cc tcpip.CongestionControlOption - // The following are used when a "packet too big" control packet is - // received. They are protected by sndBufMu. They are used to - // communicate to the main protocol goroutine how many such control - // messages have been received since the last notification was processed - // and what was the smallest MTU seen. - packetTooBigCount int - sndMTU int - // newSegmentWaker is used to indicate to the protocol goroutine that // it needs to wake up and handle new segments queued to it. newSegmentWaker sleep.Waker `state:"manual"` @@ -607,33 +549,26 @@ type endpoint struct { // listener. deferAccept time.Duration - // pendingAccepted is a synchronization primitive used to track number - // of connections that are queued up to be delivered to the accepted - // channel. We use this to ensure that all goroutines blocked on writing - // to the acceptedChan below terminate before we close acceptedChan. + // pendingAccepted tracks connections queued to be accepted. It is used to + // ensure such queued connections are terminated before the accepted queue is + // marked closed (by setting its capacity to zero). pendingAccepted sync.WaitGroup `state:"nosave"` - // acceptMu protects acceptedChan. + // acceptMu protects accepted. acceptMu sync.Mutex `state:"nosave"` // acceptCond is a condition variable that can be used to block on when - // acceptedChan is full and an endpoint is ready to be delivered. - // - // This condition variable is required because just blocking on sending - // to acceptedChan does not work in cases where endpoint.Listen is - // called twice with different backlog values. In such cases the channel - // is closed and a new one created. Any pending goroutines blocking on - // the write to the channel will panic. + // accepted is full and an endpoint is ready to be delivered. // // We use this condition variable to block/unblock goroutines which // tried to deliver an endpoint but couldn't because accept backlog was // full ( See: endpoint.deliverAccepted ). acceptCond *sync.Cond `state:"nosave"` - // acceptedChan is used by a listening endpoint protocol goroutine to + // accepted is used by a listening endpoint protocol goroutine to // send newly accepted connections to the endpoint so that they can be // read by Accept() calls. - acceptedChan chan *endpoint `state:".([]*endpoint)"` + accepted accepted // The following are only used from the protocol goroutine, and // therefore don't need locks to protect them. @@ -664,7 +599,7 @@ type endpoint struct { // applied while sending packets. Defaults to 0 as on Linux. sendTOS uint8 - gso *stack.GSO + gso stack.GSO // TODO(b/142022063): Add ability to save and restore per endpoint stats. stats Stats `state:"nosave"` @@ -779,7 +714,7 @@ func (e *endpoint) UnlockUser() { switch e.EndpointState() { case StateEstablished: - if err := e.handleSegments(true /* fastPath */); err != nil { + if err := e.handleSegmentsLocked(true /* fastPath */); err != nil { e.notifyProtocolGoroutine(notifyTickleWorker) } default: @@ -839,13 +774,13 @@ func (e *endpoint) EndpointState() EndpointState { // setRecentTimestamp sets the recentTS field to the provided value. func (e *endpoint) setRecentTimestamp(recentTS uint32) { - e.recentTS = recentTS + e.RecentTS = recentTS e.recentTSTime = time.Now() } // recentTimestamp returns the value of the recentTS field. 
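The accepted struct and the acceptCond comment earlier in this hunk replace the old acceptedChan with a capped list plus a condition variable: goroutines delivering a newly established connection wait on the condition while the backlog is full, and Accept pops from the front and signals a waiter. A minimal standalone sketch of that pattern (boundedAccepted, its methods, and the string payload are illustrative stand-ins, not gVisor API):

package main

import (
	"container/list"
	"fmt"
	"sync"
)

// boundedAccepted is a toy version of the accepted struct: a list of ready
// connections plus a capacity, guarded by a mutex and a sync.Cond.
type boundedAccepted struct {
	mu        sync.Mutex
	cond      *sync.Cond
	endpoints list.List
	cap       int
}

func newBoundedAccepted(capacity int) *boundedAccepted {
	q := &boundedAccepted{cap: capacity}
	q.cond = sync.NewCond(&q.mu)
	return q
}

// deliver blocks while the backlog is full, the role acceptCond plays for
// goroutines handing a new connection to the listening endpoint.
func (q *boundedAccepted) deliver(conn string) {
	q.mu.Lock()
	for q.endpoints.Len() >= q.cap {
		q.cond.Wait()
	}
	q.endpoints.PushBack(conn)
	q.mu.Unlock()
}

// accept pops the oldest connection and wakes one blocked deliverer; the real
// Accept returns ErrWouldBlock when the list is empty.
func (q *boundedAccepted) accept() (string, bool) {
	q.mu.Lock()
	defer q.mu.Unlock()
	front := q.endpoints.Front()
	if front == nil {
		return "", false
	}
	conn := q.endpoints.Remove(front).(string)
	q.cond.Signal()
	return conn, true
}

func main() {
	q := newBoundedAccepted(2)
	q.deliver("conn-1")
	q.deliver("conn-2")
	fmt.Println(q.accept()) // conn-1 true
	fmt.Println(q.accept()) // conn-2 true
	fmt.Println(q.accept()) // "" false
}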
func (e *endpoint) recentTimestamp() uint32 { - return e.recentTS + return e.RecentTS } // keepalive is a synchronization wrapper used to appease stateify. See the @@ -865,16 +800,17 @@ type keepalive struct { func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) *endpoint { e := &endpoint{ stack: s, - EndpointInfo: EndpointInfo{ - TransportEndpointInfo: stack.TransportEndpointInfo{ - NetProto: netProto, - TransProto: header.TCPProtocolNumber, + TransportEndpointInfo: stack.TransportEndpointInfo{ + NetProto: netProto, + TransProto: header.TCPProtocolNumber, + }, + sndQueueInfo: sndQueueInfo{ + TCPSndBufState: stack.TCPSndBufState{ + SndMTU: int(math.MaxInt32), }, }, waiterQueue: waiterQueue, state: StateInitial, - rcvBufSize: DefaultReceiveBufferSize, - sndMTU: int(math.MaxInt32), keepalive: keepalive{ // Linux defaults. idle: 2 * time.Hour, @@ -886,10 +822,11 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQue windowClamp: DefaultReceiveBufferSize, maxSynRetries: DefaultSynRetries, } - e.ops.InitHandler(e, e.stack, GetTCPSendBufferLimits) + e.ops.InitHandler(e, e.stack, GetTCPSendBufferLimits, GetTCPReceiveBufferLimits) e.ops.SetMulticastLoop(true) e.ops.SetQuickAck(true) e.ops.SetSendBufferSize(DefaultSendBufferSize, false /* notify */) + e.ops.SetReceiveBufferSize(DefaultReceiveBufferSize, false /* notify */) var ss tcpip.TCPSendBufferSizeRangeOption if err := s.TransportProtocolOption(ProtocolNumber, &ss); err == nil { @@ -898,7 +835,7 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQue var rs tcpip.TCPReceiveBufferSizeRangeOption if err := s.TransportProtocolOption(ProtocolNumber, &rs); err == nil { - e.rcvBufSize = rs.Default + e.ops.SetReceiveBufferSize(int64(rs.Default), false /* notify */) } var cs tcpip.CongestionControlOption @@ -908,7 +845,7 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQue var mrb tcpip.TCPModerateReceiveBufferOption if err := s.TransportProtocolOption(ProtocolNumber, &mrb); err == nil { - e.rcvAutoParams.disabled = !bool(mrb) + e.rcvQueueInfo.RcvAutoParams.Disabled = !bool(mrb) } var de tcpip.TCPDelayEnabled @@ -933,7 +870,7 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQue } e.segmentQueue.ep = e - e.tsOffset = timeStampOffset() + e.TSOffset = timeStampOffset() e.acceptCond = sync.NewCond(&e.acceptMu) e.keepalive.timer.init(&e.keepalive.waker) @@ -959,10 +896,10 @@ func (e *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask { result = mask case StateListen: - // Check if there's anything in the accepted channel. + // Check if there's anything in the accepted queue. if (mask & waiter.ReadableEvents) != 0 { e.acceptMu.Lock() - if len(e.acceptedChan) > 0 { + if e.accepted.endpoints.Len() != 0 { result |= waiter.ReadableEvents } e.acceptMu.Unlock() @@ -971,21 +908,21 @@ func (e *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask { if e.EndpointState().connected() { // Determine if the endpoint is writable if requested. if (mask & waiter.WritableEvents) != 0 { - e.sndBufMu.Lock() + e.sndQueueInfo.sndQueueMu.Lock() sndBufSize := e.getSendBufferSize() - if e.sndClosed || e.sndBufUsed < sndBufSize { + if e.sndQueueInfo.SndClosed || e.sndQueueInfo.SndBufUsed < sndBufSize { result |= waiter.WritableEvents } - e.sndBufMu.Unlock() + e.sndQueueInfo.sndQueueMu.Unlock() } // Determine if the endpoint is readable if requested. 
if (mask & waiter.ReadableEvents) != 0 { - e.rcvListMu.Lock() - if e.rcvBufUsed > 0 || e.rcvClosed { + e.rcvQueueInfo.rcvQueueMu.Lock() + if e.rcvQueueInfo.RcvBufUsed > 0 || e.rcvQueueInfo.RcvClosed { result |= waiter.ReadableEvents } - e.rcvListMu.Unlock() + e.rcvQueueInfo.rcvQueueMu.Unlock() } } @@ -1093,15 +1030,15 @@ func (e *endpoint) closeNoShutdownLocked() { // in Listen() when trying to register. if e.EndpointState() == StateListen && e.isPortReserved { if e.isRegistered { - e.stack.StartTransportEndpointCleanup(e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.boundPortFlags, e.boundBindToDevice) + e.stack.StartTransportEndpointCleanup(e.effectiveNetProtos, ProtocolNumber, e.TransportEndpointInfo.ID, e, e.boundPortFlags, e.boundBindToDevice) e.isRegistered = false } portRes := ports.Reservation{ Networks: e.effectiveNetProtos, Transport: ProtocolNumber, - Addr: e.ID.LocalAddress, - Port: e.ID.LocalPort, + Addr: e.TransportEndpointInfo.ID.LocalAddress, + Port: e.TransportEndpointInfo.ID.LocalPort, Flags: e.boundPortFlags, BindToDevice: e.boundBindToDevice, Dest: e.boundDest, @@ -1145,22 +1082,22 @@ func (e *endpoint) closeNoShutdownLocked() { // handshake but not yet been delivered to the application. func (e *endpoint) closePendingAcceptableConnectionsLocked() { e.acceptMu.Lock() - if e.acceptedChan == nil { - e.acceptMu.Unlock() + acceptedCopy := e.accepted + e.accepted = accepted{} + e.acceptMu.Unlock() + + if acceptedCopy == (accepted{}) { return } - close(e.acceptedChan) - ch := e.acceptedChan - e.acceptedChan = nil + e.acceptCond.Broadcast() - e.acceptMu.Unlock() // Reset all connections that are waiting to be accepted. - for n := range ch { - n.notifyProtocolGoroutine(notifyReset) + for n := acceptedCopy.endpoints.Front(); n != nil; n = n.Next() { + n.Value.(*endpoint).notifyProtocolGoroutine(notifyReset) } // Wait for reset of all endpoints that are still waiting to be delivered to - // the now closed acceptedChan. + // the now closed accepted. 
e.pendingAccepted.Wait() } @@ -1176,7 +1113,7 @@ func (e *endpoint) cleanupLocked() { e.workerCleanup = false if e.isRegistered { - e.stack.StartTransportEndpointCleanup(e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.boundPortFlags, e.boundBindToDevice) + e.stack.StartTransportEndpointCleanup(e.effectiveNetProtos, ProtocolNumber, e.TransportEndpointInfo.ID, e, e.boundPortFlags, e.boundBindToDevice) e.isRegistered = false } @@ -1184,8 +1121,8 @@ func (e *endpoint) cleanupLocked() { portRes := ports.Reservation{ Networks: e.effectiveNetProtos, Transport: ProtocolNumber, - Addr: e.ID.LocalAddress, - Port: e.ID.LocalPort, + Addr: e.TransportEndpointInfo.ID.LocalAddress, + Port: e.TransportEndpointInfo.ID.LocalPort, Flags: e.boundPortFlags, BindToDevice: e.boundBindToDevice, Dest: e.boundDest, @@ -1247,19 +1184,19 @@ func (e *endpoint) ModerateRecvBuf(copied int) { e.LockUser() defer e.UnlockUser() - e.rcvListMu.Lock() - if e.rcvAutoParams.disabled { - e.rcvListMu.Unlock() + e.rcvQueueInfo.rcvQueueMu.Lock() + if e.rcvQueueInfo.RcvAutoParams.Disabled { + e.rcvQueueInfo.rcvQueueMu.Unlock() return } now := time.Now() - if rtt := e.rcvAutoParams.rtt; rtt == 0 || now.Sub(e.rcvAutoParams.measureTime) < rtt { - e.rcvAutoParams.copied += copied - e.rcvListMu.Unlock() + if rtt := e.rcvQueueInfo.RcvAutoParams.RTT; rtt == 0 || now.Sub(e.rcvQueueInfo.RcvAutoParams.MeasureTime) < rtt { + e.rcvQueueInfo.RcvAutoParams.CopiedBytes += copied + e.rcvQueueInfo.rcvQueueMu.Unlock() return } - prevRTTCopied := e.rcvAutoParams.copied + copied - prevCopied := e.rcvAutoParams.prevCopied + prevRTTCopied := e.rcvQueueInfo.RcvAutoParams.CopiedBytes + copied + prevCopied := e.rcvQueueInfo.RcvAutoParams.PrevCopiedBytes rcvWnd := 0 if prevRTTCopied > prevCopied { // The minimal receive window based on what was copied by the app @@ -1291,24 +1228,25 @@ func (e *endpoint) ModerateRecvBuf(copied int) { // We do not adjust downwards as that can cause the receiver to // reject valid data that might already be in flight as the // acceptable window will shrink. - if rcvWnd > e.rcvBufSize { - availBefore := wndFromSpace(e.receiveBufferAvailableLocked()) - e.rcvBufSize = rcvWnd - availAfter := wndFromSpace(e.receiveBufferAvailableLocked()) - if crossed, above := e.windowCrossedACKThresholdLocked(availAfter - availBefore); crossed && above { + rcvBufSize := int(e.ops.GetReceiveBufferSize()) + if rcvWnd > rcvBufSize { + availBefore := wndFromSpace(e.receiveBufferAvailableLocked(rcvBufSize)) + e.ops.SetReceiveBufferSize(int64(rcvWnd), false /* notify */) + availAfter := wndFromSpace(e.receiveBufferAvailableLocked(rcvWnd)) + if crossed, above := e.windowCrossedACKThresholdLocked(availAfter-availBefore, rcvBufSize); crossed && above { e.notifyProtocolGoroutine(notifyNonZeroReceiveWindow) } } - // We only update prevCopied when we grow the buffer because in cases - // where prevCopied > prevRTTCopied the existing buffer is already big + // We only update PrevCopiedBytes when we grow the buffer because in cases + // where PrevCopiedBytes > prevRTTCopied the existing buffer is already big // enough to handle the current rate and we don't need to do any // adjustments. 
- e.rcvAutoParams.prevCopied = prevRTTCopied + e.rcvQueueInfo.RcvAutoParams.PrevCopiedBytes = prevRTTCopied } - e.rcvAutoParams.measureTime = now - e.rcvAutoParams.copied = 0 - e.rcvListMu.Unlock() + e.rcvQueueInfo.RcvAutoParams.MeasureTime = now + e.rcvQueueInfo.RcvAutoParams.CopiedBytes = 0 + e.rcvQueueInfo.rcvQueueMu.Unlock() } // SetOwner implements tcpip.Endpoint.SetOwner. @@ -1342,6 +1280,12 @@ func (e *endpoint) LastError() tcpip.Error { return e.lastErrorLocked() } +// LastErrorLocked reads and clears lastError with e.mu held. +// Only to be used in tests. +func (e *endpoint) LastErrorLocked() tcpip.Error { + return e.lastErrorLocked() +} + // UpdateLastError implements tcpip.SocketOptionsHandler.UpdateLastError. func (e *endpoint) UpdateLastError(err tcpip.Error) { e.LockUser() @@ -1357,7 +1301,7 @@ func (e *endpoint) Read(dst io.Writer, opts tcpip.ReadOptions) (tcpip.ReadResult defer e.rcvReadMu.Unlock() // N.B. Here we get a range of segments to be processed. It is safe to not - // hold rcvListMu when processing, since we hold rcvReadMu to ensure only we + // hold rcvQueueMu when processing, since we hold rcvReadMu to ensure only we // can remove segments from the list through commitRead(). first, last, serr := e.startRead() if serr != nil { @@ -1429,10 +1373,10 @@ func (e *endpoint) startRead() (first, last *segment, err tcpip.Error) { // but has some pending unread data. Also note that a RST being received // would cause the state to become StateError so we should allow the // reads to proceed before returning a ECONNRESET. - e.rcvListMu.Lock() - defer e.rcvListMu.Unlock() + e.rcvQueueInfo.rcvQueueMu.Lock() + defer e.rcvQueueInfo.rcvQueueMu.Unlock() - bufUsed := e.rcvBufUsed + bufUsed := e.rcvQueueInfo.RcvBufUsed if s := e.EndpointState(); !s.connected() && s != StateClose && bufUsed == 0 { if s == StateError { if err := e.hardErrorLocked(); err != nil { @@ -1444,14 +1388,14 @@ func (e *endpoint) startRead() (first, last *segment, err tcpip.Error) { return nil, nil, &tcpip.ErrNotConnected{} } - if e.rcvBufUsed == 0 { - if e.rcvClosed || !e.EndpointState().connected() { + if e.rcvQueueInfo.RcvBufUsed == 0 { + if e.rcvQueueInfo.RcvClosed || !e.EndpointState().connected() { return nil, nil, &tcpip.ErrClosedForReceive{} } return nil, nil, &tcpip.ErrWouldBlock{} } - return e.rcvList.Front(), e.rcvList.Back(), nil + return e.rcvQueueInfo.rcvQueue.Front(), e.rcvQueueInfo.rcvQueue.Back(), nil } // commitRead commits a read of done bytes and returns the next non-empty @@ -1467,39 +1411,39 @@ func (e *endpoint) startRead() (first, last *segment, err tcpip.Error) { func (e *endpoint) commitRead(done int) *segment { e.LockUser() defer e.UnlockUser() - e.rcvListMu.Lock() - defer e.rcvListMu.Unlock() + e.rcvQueueInfo.rcvQueueMu.Lock() + defer e.rcvQueueInfo.rcvQueueMu.Unlock() memDelta := 0 - s := e.rcvList.Front() + s := e.rcvQueueInfo.rcvQueue.Front() for s != nil && s.data.Size() == 0 { - e.rcvList.Remove(s) + e.rcvQueueInfo.rcvQueue.Remove(s) // Memory is only considered released when the whole segment has been // read. memDelta += s.segMemSize() s.decRef() - s = e.rcvList.Front() + s = e.rcvQueueInfo.rcvQueue.Front() } - e.rcvBufUsed -= done + e.rcvQueueInfo.RcvBufUsed -= done if memDelta > 0 { // If the window was small before this read and if the read freed up // enough buffer space, to either fit an aMSS or half a receive buffer // (whichever smaller), then notify the protocol goroutine to send a // window update. 
- if crossed, above := e.windowCrossedACKThresholdLocked(memDelta); crossed && above { + if crossed, above := e.windowCrossedACKThresholdLocked(memDelta, int(e.ops.GetReceiveBufferSize())); crossed && above { e.notifyProtocolGoroutine(notifyNonZeroReceiveWindow) } } - return e.rcvList.Front() + return e.rcvQueueInfo.rcvQueue.Front() } // isEndpointWritableLocked checks if a given endpoint is writable // and also returns the number of bytes that can be written at this // moment. If the endpoint is not writable then it returns an error // indicating the reason why it's not writable. -// Caller must hold e.mu and e.sndBufMu +// Caller must hold e.mu and e.sndQueueMu func (e *endpoint) isEndpointWritableLocked() (int, tcpip.Error) { // The endpoint cannot be written to if it's not connected. switch s := e.EndpointState(); { @@ -1519,12 +1463,12 @@ func (e *endpoint) isEndpointWritableLocked() (int, tcpip.Error) { } // Check if the connection has already been closed for sends. - if e.sndClosed { + if e.sndQueueInfo.SndClosed { return 0, &tcpip.ErrClosedForSend{} } sndBufSize := e.getSendBufferSize() - avail := sndBufSize - e.sndBufUsed + avail := sndBufSize - e.sndQueueInfo.SndBufUsed if avail <= 0 { return 0, &tcpip.ErrWouldBlock{} } @@ -1541,8 +1485,8 @@ func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcp defer e.UnlockUser() nextSeg, n, err := func() (*segment, int, tcpip.Error) { - e.sndBufMu.Lock() - defer e.sndBufMu.Unlock() + e.sndQueueInfo.sndQueueMu.Lock() + defer e.sndQueueInfo.sndQueueMu.Unlock() avail, err := e.isEndpointWritableLocked() if err != nil { @@ -1557,8 +1501,8 @@ func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcp // available buffer space to be consumed by some other caller while we // are copying data in. if !opts.Atomic { - e.sndBufMu.Unlock() - defer e.sndBufMu.Lock() + e.sndQueueInfo.sndQueueMu.Unlock() + defer e.sndQueueInfo.sndQueueMu.Lock() e.UnlockUser() defer e.LockUser() @@ -1600,10 +1544,10 @@ func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcp } // Add data to the send queue. - s := newOutgoingSegment(e.ID, v) - e.sndBufUsed += len(v) - e.sndBufInQueue += seqnum.Size(len(v)) - e.sndQueue.PushBack(s) + s := newOutgoingSegment(e.TransportEndpointInfo.ID, v) + e.sndQueueInfo.SndBufUsed += len(v) + e.sndQueueInfo.SndBufInQueue += seqnum.Size(len(v)) + e.sndQueueInfo.sndQueue.PushBack(s) return e.drainSendQueueLocked(), len(v), nil }() @@ -1618,11 +1562,11 @@ func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcp // selectWindowLocked returns the new window without checking for shrinking or scaling // applied. -// Precondition: e.mu and e.rcvListMu must be held. -func (e *endpoint) selectWindowLocked() (wnd seqnum.Size) { - wndFromAvailable := wndFromSpace(e.receiveBufferAvailableLocked()) - maxWindow := wndFromSpace(e.rcvBufSize) - wndFromUsedBytes := maxWindow - e.rcvBufUsed +// Precondition: e.mu and e.rcvQueueMu must be held. 
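commitRead above only asks the protocol goroutine for a window update once the space freed by reads crosses a threshold of min(aMSS, half the receive buffer), as its comment explains; windowCrossedACKThresholdLocked, a little further down, reports whether that threshold was crossed and in which direction. A simplified, self-contained stand-in with example numbers (wndFromSpace and per-segment overheads are ignored here, so this is only an approximation of the real check):

package main

import "fmt"

// crossedAndAbove reports whether the available receive space moved across
// min(aMSS, rcvBuf/2) and whether it ended up at or above that threshold.
func crossedAndAbove(oldAvail, newAvail, aMSS, rcvBuf int) (crossed, above bool) {
	threshold := aMSS
	if half := rcvBuf / 2; half < threshold {
		threshold = half
	}
	switch {
	case oldAvail < threshold && newAvail >= threshold:
		return true, true // grew past the threshold: worth sending a window update
	case oldAvail >= threshold && newAvail < threshold:
		return true, false // shrank below the threshold
	default:
		return false, false // no crossing, stay quiet
	}
}

func main() {
	fmt.Println(crossedAndAbove(500, 2000, 1460, 65536))  // true true
	fmt.Println(crossedAndAbove(2000, 3000, 1460, 65536)) // false false
}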
+func (e *endpoint) selectWindowLocked(rcvBufSize int) (wnd seqnum.Size) { + wndFromAvailable := wndFromSpace(e.receiveBufferAvailableLocked(rcvBufSize)) + maxWindow := wndFromSpace(rcvBufSize) + wndFromUsedBytes := maxWindow - e.rcvQueueInfo.RcvBufUsed // We take the lesser of the wndFromAvailable and wndFromUsedBytes because in // cases where we receive a lot of small segments the segment overhead is a @@ -1640,11 +1584,11 @@ func (e *endpoint) selectWindowLocked() (wnd seqnum.Size) { return seqnum.Size(newWnd) } -// selectWindow invokes selectWindowLocked after acquiring e.rcvListMu. +// selectWindow invokes selectWindowLocked after acquiring e.rcvQueueMu. func (e *endpoint) selectWindow() (wnd seqnum.Size) { - e.rcvListMu.Lock() - wnd = e.selectWindowLocked() - e.rcvListMu.Unlock() + e.rcvQueueInfo.rcvQueueMu.Lock() + wnd = e.selectWindowLocked(int(e.ops.GetReceiveBufferSize())) + e.rcvQueueInfo.rcvQueueMu.Unlock() return wnd } @@ -1662,9 +1606,9 @@ func (e *endpoint) selectWindow() (wnd seqnum.Size) { // above will be true if the new window is >= ACK threshold and false // otherwise. // -// Precondition: e.mu and e.rcvListMu must be held. -func (e *endpoint) windowCrossedACKThresholdLocked(deltaBefore int) (crossed bool, above bool) { - newAvail := int(e.selectWindowLocked()) +// Precondition: e.mu and e.rcvQueueMu must be held. +func (e *endpoint) windowCrossedACKThresholdLocked(deltaBefore int, rcvBufSize int) (crossed bool, above bool) { + newAvail := int(e.selectWindowLocked(rcvBufSize)) oldAvail := newAvail - deltaBefore if oldAvail < 0 { oldAvail = 0 @@ -1673,7 +1617,7 @@ func (e *endpoint) windowCrossedACKThresholdLocked(deltaBefore int) (crossed boo // rcvBufFraction is the inverse of the fraction of receive buffer size that // is used to decide if the available buffer space is now above it. const rcvBufFraction = 2 - if wndThreshold := wndFromSpace(e.rcvBufSize / rcvBufFraction); threshold > wndThreshold { + if wndThreshold := wndFromSpace(rcvBufSize / rcvBufFraction); threshold > wndThreshold { threshold = wndThreshold } switch { @@ -1700,7 +1644,7 @@ func (e *endpoint) OnReusePortSet(v bool) { } // OnKeepAliveSet implements tcpip.SocketOptionsHandler.OnKeepAliveSet. -func (e *endpoint) OnKeepAliveSet(v bool) { +func (e *endpoint) OnKeepAliveSet(bool) { e.notifyProtocolGoroutine(notifyKeepaliveChanged) } @@ -1708,7 +1652,7 @@ func (e *endpoint) OnKeepAliveSet(v bool) { func (e *endpoint) OnDelayOptionSet(v bool) { if !v { // Handle delayed data. - e.sndWaker.Assert() + e.sndQueueInfo.sndWaker.Assert() } } @@ -1716,7 +1660,7 @@ func (e *endpoint) OnDelayOptionSet(v bool) { func (e *endpoint) OnCorkOptionSet(v bool) { if !v { // Handle the corked data. - e.sndWaker.Assert() + e.sndQueueInfo.sndWaker.Assert() } } @@ -1724,6 +1668,37 @@ func (e *endpoint) getSendBufferSize() int { return int(e.ops.GetSendBufferSize()) } +// OnSetReceiveBufferSize implements tcpip.SocketOptionsHandler.OnSetReceiveBufferSize. +func (e *endpoint) OnSetReceiveBufferSize(rcvBufSz, oldSz int64) (newSz int64) { + e.LockUser() + e.rcvQueueInfo.rcvQueueMu.Lock() + + // Make sure the receive buffer size allows us to send a + // non-zero window size. 
+ scale := uint8(0) + if e.rcv != nil { + scale = e.rcv.RcvWndScale + } + if rcvBufSz>>scale == 0 { + rcvBufSz = 1 << scale + } + + availBefore := wndFromSpace(e.receiveBufferAvailableLocked(int(oldSz))) + availAfter := wndFromSpace(e.receiveBufferAvailableLocked(int(rcvBufSz))) + e.rcvQueueInfo.RcvAutoParams.Disabled = true + + // Immediately send an ACK to uncork the sender silly window + // syndrome prevention, when our available space grows above aMSS + // or half receive buffer, whichever smaller. + if crossed, above := e.windowCrossedACKThresholdLocked(availAfter-availBefore, int(rcvBufSz)); crossed && above { + e.notifyProtocolGoroutine(notifyNonZeroReceiveWindow) + } + + e.rcvQueueInfo.rcvQueueMu.Unlock() + e.UnlockUser() + return rcvBufSz +} + // SetSockOptInt sets a socket option. func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) tcpip.Error { // Lower 2 bits represents ECN bits. RFC 3168, section 23.1 @@ -1767,56 +1742,6 @@ func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) tcpip.Error { return &tcpip.ErrNotSupported{} } - case tcpip.ReceiveBufferSizeOption: - // Make sure the receive buffer size is within the min and max - // allowed. - var rs tcpip.TCPReceiveBufferSizeRangeOption - if err := e.stack.TransportProtocolOption(ProtocolNumber, &rs); err != nil { - panic(fmt.Sprintf("e.stack.TransportProtocolOption(%d, %#v) = %s", ProtocolNumber, &rs, err)) - } - - if v > rs.Max { - v = rs.Max - } - - if v < math.MaxInt32/SegOverheadFactor { - v *= SegOverheadFactor - if v < rs.Min { - v = rs.Min - } - } else { - v = math.MaxInt32 - } - - e.LockUser() - e.rcvListMu.Lock() - - // Make sure the receive buffer size allows us to send a - // non-zero window size. - scale := uint8(0) - if e.rcv != nil { - scale = e.rcv.rcvWndScale - } - if v>>scale == 0 { - v = 1 << scale - } - - availBefore := wndFromSpace(e.receiveBufferAvailableLocked()) - e.rcvBufSize = v - availAfter := wndFromSpace(e.receiveBufferAvailableLocked()) - - e.rcvAutoParams.disabled = true - - // Immediately send an ACK to uncork the sender silly window - // syndrome prevetion, when our available space grows above aMSS - // or half receive buffer, whichever smaller. - if crossed, above := e.windowCrossedACKThresholdLocked(availAfter - availBefore); crossed && above { - e.notifyProtocolGoroutine(notifyNonZeroReceiveWindow) - } - - e.rcvListMu.Unlock() - e.UnlockUser() - case tcpip.TTLOption: e.LockUser() e.ttl = uint8(v) @@ -1959,10 +1884,10 @@ func (e *endpoint) readyReceiveSize() (int, tcpip.Error) { return 0, &tcpip.ErrInvalidEndpointState{} } - e.rcvListMu.Lock() - defer e.rcvListMu.Unlock() + e.rcvQueueInfo.rcvQueueMu.Lock() + defer e.rcvQueueInfo.rcvQueueMu.Unlock() - return e.rcvBufUsed, nil + return e.rcvQueueInfo.RcvBufUsed, nil } // GetSockOptInt implements tcpip.Endpoint.GetSockOptInt. @@ -2002,12 +1927,6 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, tcpip.Error) { case tcpip.ReceiveQueueSizeOption: return e.readyReceiveSize() - case tcpip.ReceiveBufferSizeOption: - e.rcvListMu.Lock() - v := e.rcvBufSize - e.rcvListMu.Unlock() - return v, nil - case tcpip.TTLOption: e.LockUser() v := int(e.ttl) @@ -2043,15 +1962,15 @@ func (e *endpoint) getTCPInfo() tcpip.TCPInfoOption { // the connection did not send and receive data, then RTT will // be zero. 
snd.rtt.Lock() - info.RTT = snd.rtt.srtt - info.RTTVar = snd.rtt.rttvar + info.RTT = snd.rtt.TCPRTTState.SRTT + info.RTTVar = snd.rtt.TCPRTTState.RTTVar snd.rtt.Unlock() - info.RTO = snd.rto + info.RTO = snd.RTO info.CcState = snd.state - info.SndSsthresh = uint32(snd.sndSsthresh) - info.SndCwnd = uint32(snd.sndCwnd) - info.ReorderSeen = snd.rc.reorderSeen + info.SndSsthresh = uint32(snd.Ssthresh) + info.SndCwnd = uint32(snd.SndCwnd) + info.ReorderSeen = snd.rc.Reord } e.UnlockUser() return info @@ -2096,7 +2015,7 @@ func (e *endpoint) GetSockOpt(opt tcpip.GettableSocketOption) tcpip.Error { case *tcpip.OriginalDestinationOption: e.LockUser() ipt := e.stack.IPTables() - addr, port, err := ipt.OriginalDst(e.ID, e.NetProto) + addr, port, err := ipt.OriginalDst(e.TransportEndpointInfo.ID, e.NetProto) e.UnlockUser() if err != nil { return err @@ -2204,20 +2123,20 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) tcp } // Find a route to the desired destination. - r, err := e.stack.FindRoute(nicID, e.ID.LocalAddress, addr.Addr, netProto, false /* multicastLoop */) + r, err := e.stack.FindRoute(nicID, e.TransportEndpointInfo.ID.LocalAddress, addr.Addr, netProto, false /* multicastLoop */) if err != nil { return err } defer r.Release() netProtos := []tcpip.NetworkProtocolNumber{netProto} - e.ID.LocalAddress = r.LocalAddress() - e.ID.RemoteAddress = r.RemoteAddress() - e.ID.RemotePort = addr.Port + e.TransportEndpointInfo.ID.LocalAddress = r.LocalAddress() + e.TransportEndpointInfo.ID.RemoteAddress = r.RemoteAddress() + e.TransportEndpointInfo.ID.RemotePort = addr.Port - if e.ID.LocalPort != 0 { + if e.TransportEndpointInfo.ID.LocalPort != 0 { // The endpoint is bound to a port, attempt to register it. - err := e.stack.RegisterTransportEndpoint(netProtos, ProtocolNumber, e.ID, e, e.boundPortFlags, e.boundBindToDevice) + err := e.stack.RegisterTransportEndpoint(netProtos, ProtocolNumber, e.TransportEndpointInfo.ID, e, e.boundPortFlags, e.boundBindToDevice) if err != nil { return err } @@ -2226,19 +2145,29 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) tcp // one. Make sure that it isn't one that will result in the same // address/port for both local and remote (otherwise this // endpoint would be trying to connect to itself). - sameAddr := e.ID.LocalAddress == e.ID.RemoteAddress + sameAddr := e.TransportEndpointInfo.ID.LocalAddress == e.TransportEndpointInfo.ID.RemoteAddress // Calculate a port offset based on the destination IP/port and // src IP to ensure that for a given tuple (srcIP, destIP, // destPort) the offset used as a starting point is the same to // ensure that we can cycle through the port space effectively. - h := jenkins.Sum32(e.stack.Seed()) - h.Write([]byte(e.ID.LocalAddress)) - h.Write([]byte(e.ID.RemoteAddress)) portBuf := make([]byte, 2) binary.LittleEndian.PutUint16(portBuf, e.ID.RemotePort) - h.Write(portBuf) - portOffset := uint16(h.Sum32()) + + h := jenkins.Sum32(e.stack.Seed()) + for _, s := range [][]byte{ + []byte(e.ID.LocalAddress), + []byte(e.ID.RemoteAddress), + portBuf, + } { + // Per io.Writer.Write: + // + // Write must return a non-nil error if it returns n < len(p). 
+ if _, err := h.Write(s); err != nil { + panic(err) + } + } + portOffset := h.Sum32() var twReuse tcpip.TCPTimeWaitReuseOption if err := e.stack.TransportProtocolOption(ProtocolNumber, &twReuse); err != nil { @@ -2249,21 +2178,21 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) tcp if twReuse == tcpip.TCPTimeWaitReuseLoopbackOnly { switch netProto { case header.IPv4ProtocolNumber: - reuse = header.IsV4LoopbackAddress(e.ID.LocalAddress) && header.IsV4LoopbackAddress(e.ID.RemoteAddress) + reuse = header.IsV4LoopbackAddress(e.TransportEndpointInfo.ID.LocalAddress) && header.IsV4LoopbackAddress(e.TransportEndpointInfo.ID.RemoteAddress) case header.IPv6ProtocolNumber: - reuse = e.ID.LocalAddress == header.IPv6Loopback && e.ID.RemoteAddress == header.IPv6Loopback + reuse = e.TransportEndpointInfo.ID.LocalAddress == header.IPv6Loopback && e.TransportEndpointInfo.ID.RemoteAddress == header.IPv6Loopback } } bindToDevice := tcpip.NICID(e.ops.GetBindToDevice()) if _, err := e.stack.PickEphemeralPortStable(portOffset, func(p uint16) (bool, tcpip.Error) { - if sameAddr && p == e.ID.RemotePort { + if sameAddr && p == e.TransportEndpointInfo.ID.RemotePort { return false, nil } portRes := ports.Reservation{ Networks: netProtos, Transport: ProtocolNumber, - Addr: e.ID.LocalAddress, + Addr: e.TransportEndpointInfo.ID.LocalAddress, Port: p, Flags: e.portFlags, BindToDevice: bindToDevice, @@ -2273,7 +2202,7 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) tcp if _, ok := err.(*tcpip.ErrPortInUse); !ok || !reuse { return false, nil } - transEPID := e.ID + transEPID := e.TransportEndpointInfo.ID transEPID.LocalPort = p // Check if an endpoint is registered with demuxer in TIME-WAIT and if // we can reuse it. If we can't find a transport endpoint then we just @@ -2310,7 +2239,7 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) tcp portRes := ports.Reservation{ Networks: netProtos, Transport: ProtocolNumber, - Addr: e.ID.LocalAddress, + Addr: e.TransportEndpointInfo.ID.LocalAddress, Port: p, Flags: e.portFlags, BindToDevice: bindToDevice, @@ -2321,13 +2250,13 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) tcp } } - id := e.ID + id := e.TransportEndpointInfo.ID id.LocalPort = p if err := e.stack.RegisterTransportEndpoint(netProtos, ProtocolNumber, id, e, e.portFlags, bindToDevice); err != nil { portRes := ports.Reservation{ Networks: netProtos, Transport: ProtocolNumber, - Addr: e.ID.LocalAddress, + Addr: e.TransportEndpointInfo.ID.LocalAddress, Port: p, Flags: e.portFlags, BindToDevice: bindToDevice, @@ -2342,13 +2271,14 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) tcp // Port picking successful. Save the details of // the selected port. - e.ID = id + e.TransportEndpointInfo.ID = id e.isPortReserved = true e.boundBindToDevice = bindToDevice e.boundPortFlags = e.portFlags e.boundDest = addr return true, nil }); err != nil { + e.stack.Stats().TCP.FailedPortReservations.Increment() return err } } @@ -2367,10 +2297,10 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) tcp // connection setting here. 
if !handshake { e.segmentQueue.mu.Lock() - for _, l := range []segmentList{e.segmentQueue.list, e.sndQueue, e.snd.writeList} { + for _, l := range []segmentList{e.segmentQueue.list, e.sndQueueInfo.sndQueue, e.snd.writeList} { for s := l.Front(); s != nil; s = s.Next() { - s.id = e.ID - e.sndWaker.Assert() + s.id = e.TransportEndpointInfo.ID + e.sndQueueInfo.sndWaker.Assert() } } e.segmentQueue.mu.Unlock() @@ -2412,10 +2342,10 @@ func (e *endpoint) shutdownLocked(flags tcpip.ShutdownFlags) tcpip.Error { // Close for read. if e.shutdownFlags&tcpip.ShutdownRead != 0 { // Mark read side as closed. - e.rcvListMu.Lock() - e.rcvClosed = true - rcvBufUsed := e.rcvBufUsed - e.rcvListMu.Unlock() + e.rcvQueueInfo.rcvQueueMu.Lock() + e.rcvQueueInfo.RcvClosed = true + rcvBufUsed := e.rcvQueueInfo.RcvBufUsed + e.rcvQueueInfo.rcvQueueMu.Unlock() // If we're fully closed and we have unread data we need to abort // the connection with a RST. @@ -2429,10 +2359,10 @@ func (e *endpoint) shutdownLocked(flags tcpip.ShutdownFlags) tcpip.Error { // Close for write. if e.shutdownFlags&tcpip.ShutdownWrite != 0 { - e.sndBufMu.Lock() - if e.sndClosed { + e.sndQueueInfo.sndQueueMu.Lock() + if e.sndQueueInfo.SndClosed { // Already closed. - e.sndBufMu.Unlock() + e.sndQueueInfo.sndQueueMu.Unlock() if e.EndpointState() == StateTimeWait { return &tcpip.ErrNotConnected{} } @@ -2440,12 +2370,12 @@ func (e *endpoint) shutdownLocked(flags tcpip.ShutdownFlags) tcpip.Error { } // Queue fin segment. - s := newOutgoingSegment(e.ID, nil) - e.sndQueue.PushBack(s) - e.sndBufInQueue++ + s := newOutgoingSegment(e.TransportEndpointInfo.ID, nil) + e.sndQueueInfo.sndQueue.PushBack(s) + e.sndQueueInfo.SndBufInQueue++ // Mark endpoint as closed. - e.sndClosed = true - e.sndBufMu.Unlock() + e.sndQueueInfo.SndClosed = true + e.sndQueueInfo.sndQueueMu.Unlock() e.handleClose() } @@ -2458,9 +2388,9 @@ func (e *endpoint) shutdownLocked(flags tcpip.ShutdownFlags) tcpip.Error { // // By not removing this endpoint from the demuxer mapping, we // ensure that any other bind to the same port fails, as on Linux. - e.rcvListMu.Lock() - e.rcvClosed = true - e.rcvListMu.Unlock() + e.rcvQueueInfo.rcvQueueMu.Lock() + e.rcvQueueInfo.RcvClosed = true + e.rcvQueueInfo.rcvQueueMu.Unlock() e.closePendingAcceptableConnectionsLocked() // Notify waiters that the endpoint is shutdown. e.waiterQueue.Notify(waiter.ReadableEvents | waiter.WritableEvents | waiter.EventHUp | waiter.EventErr) @@ -2491,28 +2421,20 @@ func (e *endpoint) listen(backlog int) tcpip.Error { if e.EndpointState() == StateListen && !e.closed { e.acceptMu.Lock() defer e.acceptMu.Unlock() - if e.acceptedChan == nil { + if e.accepted == (accepted{}) { // listen is called after shutdown. - e.acceptedChan = make(chan *endpoint, backlog) + e.accepted.cap = backlog e.shutdownFlags = 0 - e.rcvListMu.Lock() - e.rcvClosed = false - e.rcvListMu.Unlock() + e.rcvQueueInfo.rcvQueueMu.Lock() + e.rcvQueueInfo.RcvClosed = false + e.rcvQueueInfo.rcvQueueMu.Unlock() } else { - // Adjust the size of the channel iff we can fix + // Adjust the size of the backlog iff we can fit // existing pending connections into the new one. 
- if len(e.acceptedChan) > backlog { + if e.accepted.endpoints.Len() > backlog { return &tcpip.ErrInvalidEndpointState{} } - if cap(e.acceptedChan) == backlog { - return nil - } - origChan := e.acceptedChan - e.acceptedChan = make(chan *endpoint, backlog) - close(origChan) - for ep := range origChan { - e.acceptedChan <- ep - } + e.accepted.cap = backlog } // Notify any blocked goroutines that they can attempt to @@ -2538,19 +2460,19 @@ func (e *endpoint) listen(backlog int) tcpip.Error { } // Register the endpoint. - if err := e.stack.RegisterTransportEndpoint(e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.boundPortFlags, e.boundBindToDevice); err != nil { + if err := e.stack.RegisterTransportEndpoint(e.effectiveNetProtos, ProtocolNumber, e.TransportEndpointInfo.ID, e, e.boundPortFlags, e.boundBindToDevice); err != nil { return err } e.isRegistered = true e.setEndpointState(StateListen) - // The channel may be non-nil when we're restoring the endpoint, and it + // The queue may be non-zero when we're restoring the endpoint, and it // may be pre-populated with some previously accepted (but not Accepted) // endpoints. e.acceptMu.Lock() - if e.acceptedChan == nil { - e.acceptedChan = make(chan *endpoint, backlog) + if e.accepted == (accepted{}) { + e.accepted.cap = backlog } e.acceptMu.Unlock() @@ -2578,24 +2500,25 @@ func (e *endpoint) Accept(peerAddr *tcpip.FullAddress) (tcpip.Endpoint, *waiter. e.LockUser() defer e.UnlockUser() - e.rcvListMu.Lock() - rcvClosed := e.rcvClosed - e.rcvListMu.Unlock() + e.rcvQueueInfo.rcvQueueMu.Lock() + rcvClosed := e.rcvQueueInfo.RcvClosed + e.rcvQueueInfo.rcvQueueMu.Unlock() // Endpoint must be in listen state before it can accept connections. if rcvClosed || e.EndpointState() != StateListen { return nil, nil, &tcpip.ErrInvalidEndpointState{} } // Get the new accepted endpoint. - e.acceptMu.Lock() - defer e.acceptMu.Unlock() var n *endpoint - select { - case n = <-e.acceptedChan: - e.acceptCond.Signal() - default: + e.acceptMu.Lock() + if element := e.accepted.endpoints.Front(); element != nil { + n = e.accepted.endpoints.Remove(element).(*endpoint) + } + e.acceptMu.Unlock() + if n == nil { return nil, nil, &tcpip.ErrWouldBlock{} } + e.acceptCond.Signal() if peerAddr != nil { *peerAddr = n.getRemoteAddress() } @@ -2645,7 +2568,7 @@ func (e *endpoint) bindLocked(addr tcpip.FullAddress) (err tcpip.Error) { if nic == 0 { return &tcpip.ErrBadLocalAddress{} } - e.ID.LocalAddress = addr.Addr + e.TransportEndpointInfo.ID.LocalAddress = addr.Addr } bindToDevice := tcpip.NICID(e.ops.GetBindToDevice()) @@ -2659,7 +2582,7 @@ func (e *endpoint) bindLocked(addr tcpip.FullAddress) (err tcpip.Error) { Dest: tcpip.FullAddress{}, } port, err := e.stack.ReservePort(portRes, func(p uint16) (bool, tcpip.Error) { - id := e.ID + id := e.TransportEndpointInfo.ID id.LocalPort = p // CheckRegisterTransportEndpoint should only return an error if there is a // listening endpoint bound with the same id and portFlags and bindToDevice @@ -2675,6 +2598,7 @@ func (e *endpoint) bindLocked(addr tcpip.FullAddress) (err tcpip.Error) { return true, nil }) if err != nil { + e.stack.Stats().TCP.FailedPortReservations.Increment() return err } @@ -2684,7 +2608,7 @@ func (e *endpoint) bindLocked(addr tcpip.FullAddress) (err tcpip.Error) { e.boundNICID = nic e.isPortReserved = true e.effectiveNetProtos = netProtos - e.ID.LocalPort = port + e.TransportEndpointInfo.ID.LocalPort = port // Mark endpoint as bound. 
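The listen and accept hunks replace the fixed-capacity acceptedChan with a list plus a separate cap, so raising or lowering the backlog is a field assignment rather than allocating a new channel and draining the old one, and Accept becomes a non-blocking pop from the front. A sketch of that shape using container/list; the real endpoint uses its own list type and error values:

package sketch

import (
	"container/list"
	"errors"
	"sync"
)

var errWouldBlock = errors.New("operation would block")

// acceptQueue holds connections that completed the handshake but have not
// been Accept()ed yet. cap can change at any time.
type acceptQueue struct {
	mu        sync.Mutex
	endpoints list.List
	cap       int
}

// setBacklog adjusts the capacity; it refuses to shrink below the number of
// connections already queued.
func (q *acceptQueue) setBacklog(backlog int) error {
	q.mu.Lock()
	defer q.mu.Unlock()
	if q.endpoints.Len() > backlog {
		return errors.New("pending connections exceed new backlog")
	}
	q.cap = backlog
	return nil
}

// push queues a newly established connection if there is room.
func (q *acceptQueue) push(ep interface{}) bool {
	q.mu.Lock()
	defer q.mu.Unlock()
	if q.endpoints.Len() >= q.cap {
		return false
	}
	q.endpoints.PushBack(ep)
	return true
}

// pop is the non-blocking half of Accept.
func (q *acceptQueue) pop() (interface{}, error) {
	q.mu.Lock()
	defer q.mu.Unlock()
	if front := q.endpoints.Front(); front != nil {
		return q.endpoints.Remove(front), nil
	}
	return nil, errWouldBlock
}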
e.setEndpointState(StateBound) @@ -2698,8 +2622,8 @@ func (e *endpoint) GetLocalAddress() (tcpip.FullAddress, tcpip.Error) { defer e.UnlockUser() return tcpip.FullAddress{ - Addr: e.ID.LocalAddress, - Port: e.ID.LocalPort, + Addr: e.TransportEndpointInfo.ID.LocalAddress, + Port: e.TransportEndpointInfo.ID.LocalPort, NIC: e.boundNICID, }, nil } @@ -2718,8 +2642,8 @@ func (e *endpoint) GetRemoteAddress() (tcpip.FullAddress, tcpip.Error) { func (e *endpoint) getRemoteAddress() tcpip.FullAddress { return tcpip.FullAddress{ - Addr: e.ID.RemoteAddress, - Port: e.ID.RemotePort, + Addr: e.TransportEndpointInfo.ID.RemoteAddress, + Port: e.TransportEndpointInfo.ID.RemotePort, NIC: e.boundNICID, } } @@ -2758,13 +2682,13 @@ func (e *endpoint) onICMPError(err tcpip.Error, transErr stack.TransportError, p Payload: pkt.Data().AsRange().ToOwnedView(), Dst: tcpip.FullAddress{ NIC: pkt.NICID, - Addr: e.ID.RemoteAddress, - Port: e.ID.RemotePort, + Addr: e.TransportEndpointInfo.ID.RemoteAddress, + Port: e.TransportEndpointInfo.ID.RemotePort, }, Offender: tcpip.FullAddress{ NIC: pkt.NICID, - Addr: e.ID.LocalAddress, - Port: e.ID.LocalPort, + Addr: e.TransportEndpointInfo.ID.LocalAddress, + Port: e.TransportEndpointInfo.ID.LocalPort, }, NetProto: pkt.NetworkProtocolNumber, }) @@ -2777,12 +2701,12 @@ func (e *endpoint) onICMPError(err tcpip.Error, transErr stack.TransportError, p // HandleError implements stack.TransportEndpoint. func (e *endpoint) HandleError(transErr stack.TransportError, pkt *stack.PacketBuffer) { handlePacketTooBig := func(mtu uint32) { - e.sndBufMu.Lock() - e.packetTooBigCount++ - if v := int(mtu); v < e.sndMTU { - e.sndMTU = v + e.sndQueueInfo.sndQueueMu.Lock() + e.sndQueueInfo.PacketTooBigCount++ + if v := int(mtu); v < e.sndQueueInfo.SndMTU { + e.sndQueueInfo.SndMTU = v } - e.sndBufMu.Unlock() + e.sndQueueInfo.sndQueueMu.Unlock() e.notifyProtocolGoroutine(notifyMTUChanged) } @@ -2801,14 +2725,14 @@ func (e *endpoint) HandleError(transErr stack.TransportError, pkt *stack.PacketB // in the send buffer. The number of newly available bytes is v. func (e *endpoint) updateSndBufferUsage(v int) { sendBufferSize := e.getSendBufferSize() - e.sndBufMu.Lock() - notify := e.sndBufUsed >= sendBufferSize>>1 - e.sndBufUsed -= v + e.sndQueueInfo.sndQueueMu.Lock() + notify := e.sndQueueInfo.SndBufUsed >= sendBufferSize>>1 + e.sndQueueInfo.SndBufUsed -= v // We only notify when there is half the sendBufferSize available after // a full buffer event occurs. This ensures that we don't wake up // writers to queue just 1-2 segments and go back to sleep. - notify = notify && e.sndBufUsed < int(sendBufferSize)>>1 - e.sndBufMu.Unlock() + notify = notify && e.sndQueueInfo.SndBufUsed < int(sendBufferSize)>>1 + e.sndQueueInfo.sndQueueMu.Unlock() if notify { e.waiterQueue.Notify(waiter.WritableEvents) @@ -2819,58 +2743,50 @@ func (e *endpoint) updateSndBufferUsage(v int) { // to be read, or when the connection is closed for receiving (in which case // s will be nil). func (e *endpoint) readyToRead(s *segment) { - e.rcvListMu.Lock() + e.rcvQueueInfo.rcvQueueMu.Lock() if s != nil { - e.rcvBufUsed += s.payloadSize() + e.rcvQueueInfo.RcvBufUsed += s.payloadSize() s.incRef() - e.rcvList.PushBack(s) + e.rcvQueueInfo.rcvQueue.PushBack(s) } else { - e.rcvClosed = true + e.rcvQueueInfo.RcvClosed = true } - e.rcvListMu.Unlock() + e.rcvQueueInfo.rcvQueueMu.Unlock() e.waiterQueue.Notify(waiter.ReadableEvents) } // receiveBufferAvailableLocked calculates how many bytes are still available // in the receive buffer. 
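updateSndBufferUsage only wakes blocked writers when the send buffer crosses from at least half full to below half full, so a stream of small ACKs does not wake a writer just to queue one or two segments and sleep again. The check in isolation, as a hypothetical helper:

package sketch

// releaseSendBuffer frees space after segments are acknowledged and reports
// whether blocked writers should be woken: only when usage drops from at
// least half of the buffer to below half.
func releaseSendBuffer(used *int, released, bufSize int) (wake bool) {
	wasAboveHalf := *used >= bufSize>>1
	*used -= released
	return wasAboveHalf && *used < bufSize>>1
}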
-// rcvListMu must be held when this function is called. -func (e *endpoint) receiveBufferAvailableLocked() int { +// rcvQueueMu must be held when this function is called. +func (e *endpoint) receiveBufferAvailableLocked(rcvBufSize int) int { // We may use more bytes than the buffer size when the receive buffer // shrinks. memUsed := e.receiveMemUsed() - if memUsed >= e.rcvBufSize { + if memUsed >= rcvBufSize { return 0 } - return e.rcvBufSize - memUsed + return rcvBufSize - memUsed } // receiveBufferAvailable calculates how many bytes are still available in the // receive buffer based on the actual memory used by all segments held in // receive buffer/pending and segment queue. func (e *endpoint) receiveBufferAvailable() int { - e.rcvListMu.Lock() - available := e.receiveBufferAvailableLocked() - e.rcvListMu.Unlock() + e.rcvQueueInfo.rcvQueueMu.Lock() + available := e.receiveBufferAvailableLocked(int(e.ops.GetReceiveBufferSize())) + e.rcvQueueInfo.rcvQueueMu.Unlock() return available } // receiveBufferUsed returns the amount of in-use receive buffer. func (e *endpoint) receiveBufferUsed() int { - e.rcvListMu.Lock() - used := e.rcvBufUsed - e.rcvListMu.Unlock() + e.rcvQueueInfo.rcvQueueMu.Lock() + used := e.rcvQueueInfo.RcvBufUsed + e.rcvQueueInfo.rcvQueueMu.Unlock() return used } -// receiveBufferSize returns the current size of the receive buffer. -func (e *endpoint) receiveBufferSize() int { - e.rcvListMu.Lock() - size := e.rcvBufSize - e.rcvListMu.Unlock() - return size -} - // receiveMemUsed returns the total memory in use by segments held by this // endpoint. func (e *endpoint) receiveMemUsed() int { @@ -2899,11 +2815,11 @@ func (e *endpoint) maxReceiveBufferSize() int { // receiveBuffer otherwise we use the max permissible receive buffer size to // compute the scale. func (e *endpoint) rcvWndScaleForHandshake() int { - bufSizeForScale := e.receiveBufferSize() + bufSizeForScale := e.ops.GetReceiveBufferSize() - e.rcvListMu.Lock() - autoTuningDisabled := e.rcvAutoParams.disabled - e.rcvListMu.Unlock() + e.rcvQueueInfo.rcvQueueMu.Lock() + autoTuningDisabled := e.rcvQueueInfo.RcvAutoParams.Disabled + e.rcvQueueInfo.rcvQueueMu.Unlock() if autoTuningDisabled { return FindWndScale(seqnum.Size(bufSizeForScale)) } @@ -2914,7 +2830,7 @@ func (e *endpoint) rcvWndScaleForHandshake() int { // updateRecentTimestamp updates the recent timestamp using the algorithm // described in https://tools.ietf.org/html/rfc7323#section-4.3 func (e *endpoint) updateRecentTimestamp(tsVal uint32, maxSentAck seqnum.Value, segSeq seqnum.Value) { - if e.sendTSOk && seqnum.Value(e.recentTimestamp()).LessThan(seqnum.Value(tsVal)) && segSeq.LessThanEq(maxSentAck) { + if e.SendTSOk && seqnum.Value(e.recentTimestamp()).LessThan(seqnum.Value(tsVal)) && segSeq.LessThanEq(maxSentAck) { e.setRecentTimestamp(tsVal) } } @@ -2924,7 +2840,7 @@ func (e *endpoint) updateRecentTimestamp(tsVal uint32, maxSentAck seqnum.Value, // initializes the recentTS with the value provided in synOpts.TSval. func (e *endpoint) maybeEnableTimestamp(synOpts *header.TCPSynOptions) { if synOpts.TS { - e.sendTSOk = true + e.SendTSOk = true e.setRecentTimestamp(synOpts.TSVal) } } @@ -2932,7 +2848,7 @@ func (e *endpoint) maybeEnableTimestamp(synOpts *header.TCPSynOptions) { // timestamp returns the timestamp value to be used in the TSVal field of the // timestamp option for outgoing TCP segments for a given endpoint. 
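rcvWndScaleForHandshake derives the advertised window scale from the receive buffer size via FindWndScale. The usual RFC 7323 computation picks the smallest shift that lets the 16-bit window field cover the buffer, capped at 14; a sketch under that assumption:

package sketch

// findWndScale picks the smallest RFC 7323 window-scale shift that lets the
// 16-bit window field cover the receive buffer, capped at the protocol
// maximum of 14.
func findWndScale(rcvBufSize int64) int {
	shift := 0
	maxWnd := int64(0xffff)
	for maxWnd < rcvBufSize && shift < 14 {
		maxWnd <<= 1
		shift++
	}
	return shift
}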
func (e *endpoint) timestamp() uint32 { - return tcpTimeStamp(time.Now(), e.tsOffset) + return tcpTimeStamp(time.Now(), e.TSOffset) } // tcpTimeStamp returns a timestamp offset by the provided offset. This is @@ -2971,7 +2887,7 @@ func (e *endpoint) maybeEnableSACKPermitted(synOpts *header.TCPSynOptions) { return } if bool(v) && synOpts.SACKPermitted { - e.sackPermitted = true + e.SACKPermitted = true } } @@ -2985,144 +2901,70 @@ func (e *endpoint) maxOptionSize() (size int) { return size } -// completeState makes a full copy of the endpoint and returns it. This is used -// before invoking the probe. The state returned may not be fully consistent if -// there are intervening syscalls when the state is being copied. -func (e *endpoint) completeState() stack.TCPEndpointState { - var s stack.TCPEndpointState - s.SegTime = time.Now() - - // Copy EndpointID. - s.ID = stack.TCPEndpointID(e.ID) - - // Copy endpoint rcv state. - e.rcvListMu.Lock() - s.RcvBufSize = e.rcvBufSize - s.RcvBufUsed = e.rcvBufUsed - s.RcvClosed = e.rcvClosed - s.RcvAutoParams.MeasureTime = e.rcvAutoParams.measureTime - s.RcvAutoParams.CopiedBytes = e.rcvAutoParams.copied - s.RcvAutoParams.PrevCopiedBytes = e.rcvAutoParams.prevCopied - s.RcvAutoParams.RTT = e.rcvAutoParams.rtt - s.RcvAutoParams.RTTMeasureSeqNumber = e.rcvAutoParams.rttMeasureSeqNumber - s.RcvAutoParams.RTTMeasureTime = e.rcvAutoParams.rttMeasureTime - s.RcvAutoParams.Disabled = e.rcvAutoParams.disabled - e.rcvListMu.Unlock() - - // Endpoint TCP Option state. - s.SendTSOk = e.sendTSOk - s.RecentTS = e.recentTimestamp() - s.TSOffset = e.tsOffset - s.SACKPermitted = e.sackPermitted +// completeStateLocked makes a full copy of the endpoint and returns it. This is +// used before invoking the probe. +// +// Precondition: e.mu must be held. +func (e *endpoint) completeStateLocked() stack.TCPEndpointState { + s := stack.TCPEndpointState{ + TCPEndpointStateInner: e.TCPEndpointStateInner, + ID: stack.TCPEndpointID(e.TransportEndpointInfo.ID), + SegTime: time.Now(), + Receiver: e.rcv.TCPReceiverState, + Sender: e.snd.TCPSenderState, + } + + sndBufSize := e.getSendBufferSize() + // Copy the send buffer atomically. + e.sndQueueInfo.sndQueueMu.Lock() + s.SndBufState = e.sndQueueInfo.TCPSndBufState + s.SndBufState.SndBufSize = sndBufSize + e.sndQueueInfo.sndQueueMu.Unlock() + + // Copy the receive buffer atomically. + e.rcvQueueInfo.rcvQueueMu.Lock() + s.RcvBufState = e.rcvQueueInfo.TCPRcvBufState + e.rcvQueueInfo.rcvQueueMu.Unlock() + + // Copy the endpoint TCP Option state. s.SACK.Blocks = make([]header.SACKBlock, e.sack.NumBlocks) copy(s.SACK.Blocks, e.sack.Blocks[:e.sack.NumBlocks]) s.SACK.ReceivedBlocks, s.SACK.MaxSACKED = e.scoreboard.Copy() - // Copy endpoint send state. - sndBufSize := e.getSendBufferSize() - e.sndBufMu.Lock() - s.SndBufSize = sndBufSize - s.SndBufUsed = e.sndBufUsed - s.SndClosed = e.sndClosed - s.SndBufInQueue = e.sndBufInQueue - s.PacketTooBigCount = e.packetTooBigCount - s.SndMTU = e.sndMTU - e.sndBufMu.Unlock() - - // Copy receiver state. - s.Receiver = stack.TCPReceiverState{ - RcvNxt: e.rcv.rcvNxt, - RcvAcc: e.rcv.rcvAcc, - RcvWndScale: e.rcv.rcvWndScale, - PendingBufUsed: e.rcv.pendingBufUsed, - } - - // Copy sender state. 
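The TSVal placed in the timestamp option is a millisecond-granularity clock plus the per-endpoint TSOffset, so connections do not share a timestamp space or reveal absolute time. A sketch of that shape; the exact granularity and offset derivation in the stack may differ:

package sketch

import "time"

// tsVal builds the TSVal for an outgoing segment: a millisecond clock
// shifted by the endpoint's random offset.
func tsVal(now time.Time, offset uint32) uint32 {
	return uint32(now.UnixNano()/int64(time.Millisecond)) + offset
}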
- s.Sender = stack.TCPSenderState{ - LastSendTime: e.snd.lastSendTime, - DupAckCount: e.snd.dupAckCount, - FastRecovery: stack.TCPFastRecoveryState{ - Active: e.snd.fr.active, - First: e.snd.fr.first, - Last: e.snd.fr.last, - MaxCwnd: e.snd.fr.maxCwnd, - HighRxt: e.snd.fr.highRxt, - RescueRxt: e.snd.fr.rescueRxt, - }, - SndCwnd: e.snd.sndCwnd, - Ssthresh: e.snd.sndSsthresh, - SndCAAckCount: e.snd.sndCAAckCount, - Outstanding: e.snd.outstanding, - SackedOut: e.snd.sackedOut, - SndWnd: e.snd.sndWnd, - SndUna: e.snd.sndUna, - SndNxt: e.snd.sndNxt, - RTTMeasureSeqNum: e.snd.rttMeasureSeqNum, - RTTMeasureTime: e.snd.rttMeasureTime, - Closed: e.snd.closed, - RTO: e.snd.rto, - MaxPayloadSize: e.snd.maxPayloadSize, - SndWndScale: e.snd.sndWndScale, - MaxSentAck: e.snd.maxSentAck, - } e.snd.rtt.Lock() - s.Sender.SRTT = e.snd.rtt.srtt - s.Sender.SRTTInited = e.snd.rtt.srttInited + s.Sender.RTTState = e.snd.rtt.TCPRTTState e.snd.rtt.Unlock() if cubic, ok := e.snd.cc.(*cubicState); ok { - s.Sender.Cubic = stack.TCPCubicState{ - WMax: cubic.wMax, - WLastMax: cubic.wLastMax, - T: cubic.t, - TimeSinceLastCongestion: time.Since(cubic.t), - C: cubic.c, - K: cubic.k, - Beta: cubic.beta, - WC: cubic.wC, - WEst: cubic.wEst, - } + s.Sender.Cubic = cubic.TCPCubicState + s.Sender.Cubic.TimeSinceLastCongestion = time.Since(s.Sender.Cubic.T) } - rc := &e.snd.rc - s.Sender.RACKState = stack.TCPRACKState{ - XmitTime: rc.xmitTime, - EndSequence: rc.endSequence, - FACK: rc.fack, - RTT: rc.rtt, - Reord: rc.reorderSeen, - DSACKSeen: rc.dsackSeen, - ReoWnd: rc.reoWnd, - ReoWndIncr: rc.reoWndIncr, - ReoWndPersist: rc.reoWndPersist, - RTTSeq: rc.rttSeq, - } + s.Sender.RACKState = e.snd.rc.TCPRACKState return s } func (e *endpoint) initHardwareGSO() { - gso := &stack.GSO{} switch e.route.NetProto() { case header.IPv4ProtocolNumber: - gso.Type = stack.GSOTCPv4 - gso.L3HdrLen = header.IPv4MinimumSize + e.gso.Type = stack.GSOTCPv4 + e.gso.L3HdrLen = header.IPv4MinimumSize case header.IPv6ProtocolNumber: - gso.Type = stack.GSOTCPv6 - gso.L3HdrLen = header.IPv6MinimumSize + e.gso.Type = stack.GSOTCPv6 + e.gso.L3HdrLen = header.IPv6MinimumSize default: panic(fmt.Sprintf("Unknown netProto: %v", e.NetProto)) } - gso.NeedsCsum = true - gso.CsumOffset = header.TCPChecksumOffset - gso.MaxSize = e.route.GSOMaxSize() - e.gso = gso + e.gso.NeedsCsum = true + e.gso.CsumOffset = header.TCPChecksumOffset + e.gso.MaxSize = e.route.GSOMaxSize() } func (e *endpoint) initGSO() { if e.route.HasHardwareGSOCapability() { e.initHardwareGSO() } else if e.route.HasSoftwareGSOCapability() { - e.gso = &stack.GSO{ + e.gso = stack.GSO{ MaxSize: e.route.GSOMaxSize(), Type: stack.GSOSW, NeedsCsum: false, @@ -3200,3 +3042,17 @@ func (e *endpoint) allowOutOfWindowAck() bool { e.lastOutOfWindowAckTime = now return true } + +// GetTCPReceiveBufferLimits is used to get send buffer size limits for TCP. 
+func GetTCPReceiveBufferLimits(s tcpip.StackHandler) tcpip.ReceiveBufferSizeOption { + var ss tcpip.TCPReceiveBufferSizeRangeOption + if err := s.TransportProtocolOption(header.TCPProtocolNumber, &ss); err != nil { + panic(fmt.Sprintf("s.TransportProtocolOption(%d, %#v) = %s", header.TCPProtocolNumber, ss, err)) + } + + return tcpip.ReceiveBufferSizeOption{ + Min: ss.Min, + Default: ss.Default, + Max: ss.Max, + } +} diff --git a/pkg/tcpip/transport/tcp/endpoint_state.go b/pkg/tcpip/transport/tcp/endpoint_state.go index a53d76917..6e9777fe4 100644 --- a/pkg/tcpip/transport/tcp/endpoint_state.go +++ b/pkg/tcpip/transport/tcp/endpoint_state.go @@ -58,7 +58,7 @@ func (e *endpoint) beforeSave() { if !e.route.HasSaveRestoreCapability() { if !e.route.HasDisconncetOkCapability() { panic(&tcpip.ErrSaveRejection{ - Err: fmt.Errorf("endpoint cannot be saved in connected state: local %s:%d, remote %s:%d", e.ID.LocalAddress, e.ID.LocalPort, e.ID.RemoteAddress, e.ID.RemotePort), + Err: fmt.Errorf("endpoint cannot be saved in connected state: local %s:%d, remote %s:%d", e.TransportEndpointInfo.ID.LocalAddress, e.TransportEndpointInfo.ID.LocalPort, e.TransportEndpointInfo.ID.RemoteAddress, e.TransportEndpointInfo.ID.RemotePort), }) } e.resetConnectionLocked(&tcpip.ErrConnectionAborted{}) @@ -67,7 +67,7 @@ func (e *endpoint) beforeSave() { e.mu.Lock() } if !e.workerRunning { - // The endpoint must be in acceptedChan or has been just + // The endpoint must be in the accepted queue or has been just // disconnected and closed. break } @@ -88,7 +88,7 @@ func (e *endpoint) beforeSave() { e.mu.Lock() } if e.workerRunning { - panic(fmt.Sprintf("endpoint: %+v still has worker running in closed or error state", e.ID)) + panic(fmt.Sprintf("endpoint: %+v still has worker running in closed or error state", e.TransportEndpointInfo.ID)) } default: panic(fmt.Sprintf("endpoint in unknown state %v", e.EndpointState())) @@ -99,37 +99,19 @@ func (e *endpoint) beforeSave() { } } -// saveAcceptedChan is invoked by stateify. -func (e *endpoint) saveAcceptedChan() []*endpoint { - if e.acceptedChan == nil { - return nil - } - acceptedEndpoints := make([]*endpoint, len(e.acceptedChan), cap(e.acceptedChan)) - for i := 0; i < len(acceptedEndpoints); i++ { - select { - case ep := <-e.acceptedChan: - acceptedEndpoints[i] = ep - default: - panic("endpoint acceptedChan buffer got consumed by background context") - } - } - for i := 0; i < len(acceptedEndpoints); i++ { - select { - case e.acceptedChan <- acceptedEndpoints[i]: - default: - panic("endpoint acceptedChan buffer got populated by background context") - } +// saveEndpoints is invoked by stateify. +func (a *accepted) saveEndpoints() []*endpoint { + acceptedEndpoints := make([]*endpoint, a.endpoints.Len()) + for i, e := 0, a.endpoints.Front(); e != nil; i, e = i+1, e.Next() { + acceptedEndpoints[i] = e.Value.(*endpoint) } return acceptedEndpoints } -// loadAcceptedChan is invoked by stateify. -func (e *endpoint) loadAcceptedChan(acceptedEndpoints []*endpoint) { - if cap(acceptedEndpoints) > 0 { - e.acceptedChan = make(chan *endpoint, cap(acceptedEndpoints)) - for _, ep := range acceptedEndpoints { - e.acceptedChan <- ep - } +// loadEndpoints is invoked by stateify. +func (a *accepted) loadEndpoints(acceptedEndpoints []*endpoint) { + for _, ep := range acceptedEndpoints { + a.endpoints.PushBack(ep) } } @@ -183,7 +165,7 @@ func (e *endpoint) afterLoad() { // Resume implements tcpip.ResumableEndpoint.Resume. 
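saveEndpoints and loadEndpoints flatten the accepted list into a slice for stateify and rebuild it on restore, the standard pattern for checkpointing container types the serializer cannot walk. The same round trip with container/list, as a generic sketch:

package sketch

import "container/list"

// snapshot flattens a queue into a slice in order, so a serializer that
// cannot walk the list type can still save it.
func snapshot(q *list.List) []interface{} {
	out := make([]interface{}, 0, q.Len())
	for e := q.Front(); e != nil; e = e.Next() {
		out = append(out, e.Value)
	}
	return out
}

// restore rebuilds the queue from a saved slice, preserving order.
func restore(q *list.List, saved []interface{}) {
	for _, v := range saved {
		q.PushBack(v)
	}
}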
func (e *endpoint) Resume(s *stack.Stack) { e.stack = s - e.ops.InitHandler(e, e.stack, GetTCPSendBufferLimits) + e.ops.InitHandler(e, e.stack, GetTCPSendBufferLimits, GetTCPReceiveBufferLimits) e.segmentQueue.thaw() epState := e.origEndpointState switch epState { @@ -198,14 +180,14 @@ func (e *endpoint) Resume(s *stack.Stack) { var rs tcpip.TCPReceiveBufferSizeRangeOption if err := e.stack.TransportProtocolOption(ProtocolNumber, &rs); err == nil { - if e.rcvBufSize < rs.Min || e.rcvBufSize > rs.Max { - panic(fmt.Sprintf("endpoint.rcvBufSize %d is outside the min and max allowed [%d, %d]", e.rcvBufSize, rs.Min, rs.Max)) + if rcvBufSize := e.ops.GetReceiveBufferSize(); rcvBufSize < int64(rs.Min) || rcvBufSize > int64(rs.Max) { + panic(fmt.Sprintf("endpoint rcvBufSize %d is outside the min and max allowed [%d, %d]", rcvBufSize, rs.Min, rs.Max)) } } } bind := func() { - addr, _, err := e.checkV4MappedLocked(tcpip.FullAddress{Addr: e.BindAddr, Port: e.ID.LocalPort}) + addr, _, err := e.checkV4MappedLocked(tcpip.FullAddress{Addr: e.BindAddr, Port: e.TransportEndpointInfo.ID.LocalPort}) if err != nil { panic("unable to parse BindAddr: " + err.String()) } @@ -231,19 +213,19 @@ func (e *endpoint) Resume(s *stack.Stack) { case epState.connected(): bind() if len(e.connectingAddress) == 0 { - e.connectingAddress = e.ID.RemoteAddress + e.connectingAddress = e.TransportEndpointInfo.ID.RemoteAddress // This endpoint is accepted by netstack but not yet by // the app. If the endpoint is IPv6 but the remote // address is IPv4, we need to connect as IPv6 so that // dual-stack mode can be properly activated. - if e.NetProto == header.IPv6ProtocolNumber && len(e.ID.RemoteAddress) != header.IPv6AddressSize { - e.connectingAddress = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff" + e.ID.RemoteAddress + if e.NetProto == header.IPv6ProtocolNumber && len(e.TransportEndpointInfo.ID.RemoteAddress) != header.IPv6AddressSize { + e.connectingAddress = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff" + e.TransportEndpointInfo.ID.RemoteAddress } } // Reset the scoreboard to reinitialize the sack information as // we do not restore SACK information. e.scoreboard.Reset() - err := e.connect(tcpip.FullAddress{NIC: e.boundNICID, Addr: e.connectingAddress, Port: e.ID.RemotePort}, false, e.workerRunning) + err := e.connect(tcpip.FullAddress{NIC: e.boundNICID, Addr: e.connectingAddress, Port: e.TransportEndpointInfo.ID.RemotePort}, false, e.workerRunning) if _, ok := err.(*tcpip.ErrConnectStarted); !ok { panic("endpoint connecting failed: " + err.String()) } @@ -263,7 +245,7 @@ func (e *endpoint) Resume(s *stack.Stack) { go func() { connectedLoading.Wait() bind() - backlog := cap(e.acceptedChan) + backlog := e.accepted.cap if err := e.Listen(backlog); err != nil { panic("endpoint listening failed: " + err.String()) } @@ -281,7 +263,7 @@ func (e *endpoint) Resume(s *stack.Stack) { connectedLoading.Wait() listenLoading.Wait() bind() - err := e.Connect(tcpip.FullAddress{NIC: e.boundNICID, Addr: e.connectingAddress, Port: e.ID.RemotePort}) + err := e.Connect(tcpip.FullAddress{NIC: e.boundNICID, Addr: e.connectingAddress, Port: e.TransportEndpointInfo.ID.RemotePort}) if _, ok := err.(*tcpip.ErrConnectStarted); !ok { panic("endpoint connecting failed: " + err.String()) } @@ -328,23 +310,3 @@ func (e *endpoint) saveLastOutOfWindowAckTime() unixTime { func (e *endpoint) loadLastOutOfWindowAckTime(unix unixTime) { e.lastOutOfWindowAckTime = time.Unix(unix.second, unix.nano) } - -// saveMeasureTime is invoked by stateify. 
-func (r *rcvBufAutoTuneParams) saveMeasureTime() unixTime { - return unixTime{r.measureTime.Unix(), r.measureTime.UnixNano()} -} - -// loadMeasureTime is invoked by stateify. -func (r *rcvBufAutoTuneParams) loadMeasureTime(unix unixTime) { - r.measureTime = time.Unix(unix.second, unix.nano) -} - -// saveRttMeasureTime is invoked by stateify. -func (r *rcvBufAutoTuneParams) saveRttMeasureTime() unixTime { - return unixTime{r.rttMeasureTime.Unix(), r.rttMeasureTime.UnixNano()} -} - -// loadRttMeasureTime is invoked by stateify. -func (r *rcvBufAutoTuneParams) loadRttMeasureTime(unix unixTime) { - r.rttMeasureTime = time.Unix(unix.second, unix.nano) -} diff --git a/pkg/tcpip/transport/tcp/protocol.go b/pkg/tcpip/transport/tcp/protocol.go index 2a4667906..a3d1aa1a3 100644 --- a/pkg/tcpip/transport/tcp/protocol.go +++ b/pkg/tcpip/transport/tcp/protocol.go @@ -75,63 +75,6 @@ const ( ccCubic = "cubic" ) -// syncRcvdCounter tracks the number of endpoints in the SYN-RCVD state. The -// value is protected by a mutex so that we can increment only when it's -// guaranteed not to go above a threshold. -type synRcvdCounter struct { - sync.Mutex - value uint64 - pending sync.WaitGroup - threshold uint64 -} - -// inc tries to increment the global number of endpoints in SYN-RCVD state. It -// succeeds if the increment doesn't make the count go beyond the threshold, and -// fails otherwise. -func (s *synRcvdCounter) inc() bool { - s.Lock() - defer s.Unlock() - if s.value >= s.threshold { - return false - } - - s.pending.Add(1) - s.value++ - - return true -} - -// dec atomically decrements the global number of endpoints in SYN-RCVD -// state. It must only be called if a previous call to inc succeeded. -func (s *synRcvdCounter) dec() { - s.Lock() - defer s.Unlock() - s.value-- - s.pending.Done() -} - -// synCookiesInUse returns true if the synRcvdCount is greater than -// SynRcvdCountThreshold. -func (s *synRcvdCounter) synCookiesInUse() bool { - s.Lock() - defer s.Unlock() - return s.value >= s.threshold -} - -// SetThreshold sets synRcvdCounter.Threshold to ths new threshold. -func (s *synRcvdCounter) SetThreshold(threshold uint64) { - s.Lock() - defer s.Unlock() - s.threshold = threshold -} - -// Threshold returns the current value of synRcvdCounter.Threhsold. -func (s *synRcvdCounter) Threshold() uint64 { - s.Lock() - defer s.Unlock() - return s.threshold -} - type protocol struct { stack *stack.Stack @@ -139,6 +82,7 @@ type protocol struct { sackEnabled bool recovery tcpip.TCPRecovery delayEnabled bool + alwaysUseSynCookies bool sendBufferSize tcpip.TCPSendBufferSizeRangeOption recvBufferSize tcpip.TCPReceiveBufferSizeRangeOption congestionControl string @@ -150,7 +94,6 @@ type protocol struct { minRTO time.Duration maxRTO time.Duration maxRetries uint32 - synRcvdCount synRcvdCounter synRetries uint8 dispatcher dispatcher } @@ -216,8 +159,8 @@ func (p *protocol) HandleUnknownDestinationPacket(id stack.TransportEndpointID, // replyWithReset replies to the given segment with a reset segment. // // If the passed TTL is 0, then the route's default TTL will be used. 
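With the SYN-RCVD threshold counter gone, the protocol keeps a single alwaysUseSynCookies flag. The listener-side decision then reduces to that flag plus the usual fallback when the queue of embryonic connections is full; a hedged sketch of such a policy, not the exact condition the stack evaluates:

package sketch

// useSynCookies is the reduced policy once the SYN-RCVD threshold is gone:
// cookies are used when the administrator forces them on, or when the
// listener has no room left to track another embryonic connection.
func useSynCookies(alwaysUseSynCookies bool, pendingConnections, backlog int) bool {
	return alwaysUseSynCookies || pendingConnections >= backlog
}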
-func replyWithReset(stack *stack.Stack, s *segment, tos, ttl uint8) tcpip.Error { - route, err := stack.FindRoute(s.nicID, s.dstAddr, s.srcAddr, s.netProto, false /* multicastLoop */) +func replyWithReset(st *stack.Stack, s *segment, tos, ttl uint8) tcpip.Error { + route, err := st.FindRoute(s.nicID, s.dstAddr, s.srcAddr, s.netProto, false /* multicastLoop */) if err != nil { return err } @@ -257,7 +200,7 @@ func replyWithReset(stack *stack.Stack, s *segment, tos, ttl uint8) tcpip.Error seq: seq, ack: ack, rcvWnd: 0, - }, buffer.VectorisedView{}, nil /* gso */, nil /* PacketOwner */) + }, buffer.VectorisedView{}, stack.GSO{}, nil /* PacketOwner */) } // SetOption implements stack.TransportProtocol.SetOption. @@ -373,9 +316,9 @@ func (p *protocol) SetOption(option tcpip.SettableTransportProtocolOption) tcpip p.mu.Unlock() return nil - case *tcpip.TCPSynRcvdCountThresholdOption: + case *tcpip.TCPAlwaysUseSynCookies: p.mu.Lock() - p.synRcvdCount.SetThreshold(uint64(*v)) + p.alwaysUseSynCookies = bool(*v) p.mu.Unlock() return nil @@ -480,9 +423,9 @@ func (p *protocol) Option(option tcpip.GettableTransportProtocolOption) tcpip.Er p.mu.RUnlock() return nil - case *tcpip.TCPSynRcvdCountThresholdOption: + case *tcpip.TCPAlwaysUseSynCookies: p.mu.RLock() - *v = tcpip.TCPSynRcvdCountThresholdOption(p.synRcvdCount.Threshold()) + *v = tcpip.TCPAlwaysUseSynCookies(p.alwaysUseSynCookies) p.mu.RUnlock() return nil @@ -507,12 +450,6 @@ func (p *protocol) Wait() { p.dispatcher.wait() } -// SynRcvdCounter returns a reference to the synRcvdCount for this protocol -// instance. -func (p *protocol) SynRcvdCounter() *synRcvdCounter { - return &p.synRcvdCount -} - // Parse implements stack.TransportProtocol.Parse. func (*protocol) Parse(pkt *stack.PacketBuffer) bool { return parse.TCP(pkt) @@ -537,7 +474,6 @@ func NewProtocol(s *stack.Stack) stack.TransportProtocol { lingerTimeout: DefaultTCPLingerTimeout, timeWaitTimeout: DefaultTCPTimeWaitTimeout, timeWaitReuse: tcpip.TCPTimeWaitReuseLoopbackOnly, - synRcvdCount: synRcvdCounter{threshold: SynRcvdCountThreshold}, synRetries: DefaultSynRetries, minRTO: MinRTO, maxRTO: MaxRTO, diff --git a/pkg/tcpip/transport/tcp/rack.go b/pkg/tcpip/transport/tcp/rack.go index 0a0d5f7a1..9e332dcf7 100644 --- a/pkg/tcpip/transport/tcp/rack.go +++ b/pkg/tcpip/transport/tcp/rack.go @@ -19,6 +19,7 @@ import ( "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/seqnum" + "gvisor.dev/gvisor/pkg/tcpip/stack" ) const ( @@ -46,54 +47,16 @@ const ( // // +stateify savable type rackControl struct { - // dsackSeen indicates if the connection has seen a DSACK. - dsackSeen bool - - // endSequence is the ending TCP sequence number of the most recent - // acknowledged segment. - endSequence seqnum.Value + stack.TCPRACKState // exitedRecovery indicates if the connection is exiting loss recovery. // This flag is set if the sender is leaving the recovery after // receiving an ACK and is reset during updating of reorder window. exitedRecovery bool - // fack is the highest selectively or cumulatively acknowledged - // sequence. - fack seqnum.Value - // minRTT is the estimated minimum RTT of the connection. minRTT time.Duration - // reorderSeen indicates if reordering has been detected on this - // connection. - reorderSeen bool - - // reoWnd is the reordering window time used for recording packet - // transmission times. It is used to defer the moment at which RACK - // marks a packet lost. - reoWnd time.Duration - - // reoWndIncr is the multiplier applied to adjust reorder window. 
- reoWndIncr uint8 - - // reoWndPersist is the number of loss recoveries before resetting - // reorder window. - reoWndPersist int8 - - // rtt is the RTT of the most recently delivered packet on the - // connection (either cumulatively acknowledged or selectively - // acknowledged) that was not marked invalid as a possible spurious - // retransmission. - rtt time.Duration - - // rttSeq is the SND.NXT when rtt is updated. - rttSeq seqnum.Value - - // xmitTime is the latest transmission timestamp of the most recent - // acknowledged segment. - xmitTime time.Time `state:".(unixTime)"` - // tlpRxtOut indicates whether there is an unacknowledged // TLP retransmission. tlpRxtOut bool @@ -108,8 +71,8 @@ type rackControl struct { // init initializes RACK specific fields. func (rc *rackControl) init(snd *sender, iss seqnum.Value) { - rc.fack = iss - rc.reoWndIncr = 1 + rc.FACK = iss + rc.ReoWndIncr = 1 rc.snd = snd } @@ -117,7 +80,7 @@ func (rc *rackControl) init(snd *sender, iss seqnum.Value) { // See: https://tools.ietf.org/html/draft-ietf-tcpm-rack-09#section-6.2 func (rc *rackControl) update(seg *segment, ackSeg *segment) { rtt := time.Now().Sub(seg.xmitTime) - tsOffset := rc.snd.ep.tsOffset + tsOffset := rc.snd.ep.TSOffset // If the ACK is for a retransmitted packet, do not update if it is a // spurious inference which is determined by below checks: @@ -138,7 +101,7 @@ func (rc *rackControl) update(seg *segment, ackSeg *segment) { } } - rc.rtt = rtt + rc.RTT = rtt // The sender can either track a simple global minimum of all RTT // measurements from the connection, or a windowed min-filtered value @@ -152,9 +115,9 @@ func (rc *rackControl) update(seg *segment, ackSeg *segment) { // ending sequence number of the packet which has been acknowledged // most recently. endSeq := seg.sequenceNumber.Add(seqnum.Size(seg.data.Size())) - if rc.xmitTime.Before(seg.xmitTime) || (seg.xmitTime.Equal(rc.xmitTime) && rc.endSequence.LessThan(endSeq)) { - rc.xmitTime = seg.xmitTime - rc.endSequence = endSeq + if rc.XmitTime.Before(seg.xmitTime) || (seg.xmitTime.Equal(rc.XmitTime) && rc.EndSequence.LessThan(endSeq)) { + rc.XmitTime = seg.xmitTime + rc.EndSequence = endSeq } } @@ -171,18 +134,18 @@ func (rc *rackControl) update(seg *segment, ackSeg *segment) { // is identified. func (rc *rackControl) detectReorder(seg *segment) { endSeq := seg.sequenceNumber.Add(seqnum.Size(seg.data.Size())) - if rc.fack.LessThan(endSeq) { - rc.fack = endSeq + if rc.FACK.LessThan(endSeq) { + rc.FACK = endSeq return } - if endSeq.LessThan(rc.fack) && seg.xmitCount == 1 { - rc.reorderSeen = true + if endSeq.LessThan(rc.FACK) && seg.xmitCount == 1 { + rc.Reord = true } } func (rc *rackControl) setDSACKSeen(dsackSeen bool) { - rc.dsackSeen = dsackSeen + rc.DSACKSeen = dsackSeen } // shouldSchedulePTO dictates whether we should schedule a PTO or not. @@ -191,7 +154,7 @@ func (s *sender) shouldSchedulePTO() bool { // Schedule PTO only if RACK loss detection is enabled. return s.ep.tcpRecovery&tcpip.TCPRACKLossDetection != 0 && // The connection supports SACK. - s.ep.sackPermitted && + s.ep.SACKPermitted && // The connection is not in loss recovery. (s.state != tcpip.RTORecovery && s.state != tcpip.SACKRecovery) && // The connection has no SACKed sequences in the SACK scoreboard. 
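rackControl.update records, per ACK, the RTT of the newest delivered segment and remembers the transmit time and end sequence of the most recently sent segment that has been delivered, which is the reference point RACK later compares other segments against. A simplified sketch that ignores sequence wraparound and the spurious-retransmission filtering shown above:

package sketch

import "time"

// rackState keeps the transmit time and end sequence of the most recently
// sent segment known to be delivered, plus the latest and minimum RTT.
type rackState struct {
	xmitTime    time.Time
	endSequence uint32
	rtt         time.Duration
	minRTT      time.Duration
}

// onSegmentAcked folds one delivered segment into the state.
func (r *rackState) onSegmentAcked(now, segXmitTime time.Time, segEnd uint32) {
	rtt := now.Sub(segXmitTime)
	r.rtt = rtt
	if r.minRTT == 0 || rtt < r.minRTT {
		r.minRTT = rtt
	}
	if segXmitTime.After(r.xmitTime) ||
		(segXmitTime.Equal(r.xmitTime) && r.endSequence < segEnd) {
		r.xmitTime = segXmitTime
		r.endSequence = segEnd
	}
}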
@@ -203,9 +166,9 @@ func (s *sender) shouldSchedulePTO() bool { func (s *sender) schedulePTO() { pto := time.Second s.rtt.Lock() - if s.rtt.srttInited && s.rtt.srtt > 0 { - pto = s.rtt.srtt * 2 - if s.outstanding == 1 { + if s.rtt.TCPRTTState.SRTTInited && s.rtt.TCPRTTState.SRTT > 0 { + pto = s.rtt.TCPRTTState.SRTT * 2 + if s.Outstanding == 1 { pto += wcDelayedACKTimeout } } @@ -230,10 +193,10 @@ func (s *sender) probeTimerExpired() tcpip.Error { } var dataSent bool - if s.writeNext != nil && s.writeNext.xmitCount == 0 && s.outstanding < s.sndCwnd { - dataSent = s.maybeSendSegment(s.writeNext, int(s.ep.scoreboard.SMSS()), s.sndUna.Add(s.sndWnd)) + if s.writeNext != nil && s.writeNext.xmitCount == 0 && s.Outstanding < s.SndCwnd { + dataSent = s.maybeSendSegment(s.writeNext, int(s.ep.scoreboard.SMSS()), s.SndUna.Add(s.SndWnd)) if dataSent { - s.outstanding += s.pCount(s.writeNext, s.maxPayloadSize) + s.Outstanding += s.pCount(s.writeNext, s.MaxPayloadSize) s.writeNext = s.writeNext.Next() } } @@ -255,10 +218,10 @@ func (s *sender) probeTimerExpired() tcpip.Error { } if highestSeqXmit != nil { - dataSent = s.maybeSendSegment(highestSeqXmit, int(s.ep.scoreboard.SMSS()), s.sndUna.Add(s.sndWnd)) + dataSent = s.maybeSendSegment(highestSeqXmit, int(s.ep.scoreboard.SMSS()), s.SndUna.Add(s.SndWnd)) if dataSent { s.rc.tlpRxtOut = true - s.rc.tlpHighRxt = s.sndNxt + s.rc.tlpHighRxt = s.SndNxt } } } @@ -274,7 +237,7 @@ func (s *sender) probeTimerExpired() tcpip.Error { // and updates TLP state accordingly. // See https://tools.ietf.org/html/draft-ietf-tcpm-rack-08#section-7.6.3. func (s *sender) detectTLPRecovery(ack seqnum.Value, rcvdSeg *segment) { - if !(s.ep.sackPermitted && s.rc.tlpRxtOut) { + if !(s.ep.SACKPermitted && s.rc.tlpRxtOut) { return } @@ -317,13 +280,13 @@ func (s *sender) detectTLPRecovery(ack seqnum.Value, rcvdSeg *segment) { // retransmit quickly, or when the number of DUPACKs exceeds the classic // DUPACKthreshold. func (rc *rackControl) updateRACKReorderWindow(ackSeg *segment) { - dsackSeen := rc.dsackSeen + dsackSeen := rc.DSACKSeen snd := rc.snd // React to DSACK once per round trip. 
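schedulePTO arms the tail loss probe for roughly two SRTTs, padded by the worst-case delayed-ACK timeout when a single segment is outstanding, and falls back to one second before any RTT sample exists. The computation in isolation:

package sketch

import "time"

// probeTimeout computes when the tail loss probe should fire.
func probeTimeout(srtt time.Duration, haveSRTT bool, outstanding int, delayedACKTimeout time.Duration) time.Duration {
	pto := time.Second
	if haveSRTT && srtt > 0 {
		pto = 2 * srtt
		if outstanding == 1 {
			pto += delayedACKTimeout
		}
	}
	return pto
}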
// If SND.UNA < RACK.rtt_seq: // RACK.dsack = false - if snd.sndUna.LessThan(rc.rttSeq) { + if snd.SndUna.LessThan(rc.RTTSeq) { dsackSeen = false } @@ -333,18 +296,18 @@ func (rc *rackControl) updateRACKReorderWindow(ackSeg *segment) { // RACK.rtt_seq = SND.NXT // RACK.reo_wnd_persist = 16 if dsackSeen { - rc.reoWndIncr++ + rc.ReoWndIncr++ dsackSeen = false - rc.rttSeq = snd.sndNxt - rc.reoWndPersist = tcpRACKRecoveryThreshold + rc.RTTSeq = snd.SndNxt + rc.ReoWndPersist = tcpRACKRecoveryThreshold } else if rc.exitedRecovery { // Else if exiting loss recovery: // RACK.reo_wnd_persist -= 1 // If RACK.reo_wnd_persist <= 0: // RACK.reo_wnd_incr = 1 - rc.reoWndPersist-- - if rc.reoWndPersist <= 0 { - rc.reoWndIncr = 1 + rc.ReoWndPersist-- + if rc.ReoWndPersist <= 0 { + rc.ReoWndIncr = 1 } rc.exitedRecovery = false } @@ -358,14 +321,14 @@ func (rc *rackControl) updateRACKReorderWindow(ackSeg *segment) { // Else if RACK.pkts_sacked >= RACK.dupthresh: // RACK.reo_wnd = 0 // return - if !rc.reorderSeen { + if !rc.Reord { if snd.state == tcpip.RTORecovery || snd.state == tcpip.SACKRecovery { - rc.reoWnd = 0 + rc.ReoWnd = 0 return } - if snd.sackedOut >= nDupAckThreshold { - rc.reoWnd = 0 + if snd.SackedOut >= nDupAckThreshold { + rc.ReoWnd = 0 return } } @@ -374,11 +337,11 @@ func (rc *rackControl) updateRACKReorderWindow(ackSeg *segment) { // RACK.reo_wnd = RACK.min_RTT / 4 * RACK.reo_wnd_incr // RACK.reo_wnd = min(RACK.reo_wnd, SRTT) snd.rtt.Lock() - srtt := snd.rtt.srtt + srtt := snd.rtt.TCPRTTState.SRTT snd.rtt.Unlock() - rc.reoWnd = time.Duration((int64(rc.minRTT) / 4) * int64(rc.reoWndIncr)) - if srtt < rc.reoWnd { - rc.reoWnd = srtt + rc.ReoWnd = time.Duration((int64(rc.minRTT) / 4) * int64(rc.ReoWndIncr)) + if srtt < rc.ReoWnd { + rc.ReoWnd = srtt } } @@ -403,8 +366,8 @@ func (rc *rackControl) detectLoss(rcvTime time.Time) int { } endSeq := seg.sequenceNumber.Add(seqnum.Size(seg.data.Size())) - if seg.xmitTime.Before(rc.xmitTime) || (seg.xmitTime.Equal(rc.xmitTime) && rc.endSequence.LessThan(endSeq)) { - timeRemaining := seg.xmitTime.Sub(rcvTime) + rc.rtt + rc.reoWnd + if seg.xmitTime.Before(rc.XmitTime) || (seg.xmitTime.Equal(rc.XmitTime) && rc.EndSequence.LessThan(endSeq)) { + timeRemaining := seg.xmitTime.Sub(rcvTime) + rc.RTT + rc.ReoWnd if timeRemaining <= 0 { seg.lost = true numLost++ @@ -435,7 +398,7 @@ func (rc *rackControl) reorderTimerExpired() tcpip.Error { } fastRetransmit := false - if !rc.snd.fr.active { + if !rc.snd.FastRecovery.Active { rc.snd.cc.HandleLossDetected() rc.snd.enterRecovery() fastRetransmit = true @@ -471,15 +434,15 @@ func (rc *rackControl) DoRecovery(_ *segment, fastRetransmit bool) { } // Check the congestion window after entering recovery. - if snd.outstanding >= snd.sndCwnd { + if snd.Outstanding >= snd.SndCwnd { break } - if sent := snd.maybeSendSegment(seg, int(snd.ep.scoreboard.SMSS()), snd.sndUna.Add(snd.sndWnd)); !sent { + if sent := snd.maybeSendSegment(seg, int(snd.ep.scoreboard.SMSS()), snd.SndUna.Add(snd.SndWnd)); !sent { break } dataSent = true - snd.outstanding += snd.pCount(seg, snd.maxPayloadSize) + snd.Outstanding += snd.pCount(seg, snd.MaxPayloadSize) } snd.postXmit(dataSent, true /* shouldScheduleProbe */) diff --git a/pkg/tcpip/transport/tcp/rack_state.go b/pkg/tcpip/transport/tcp/rack_state.go deleted file mode 100644 index c9dc7e773..000000000 --- a/pkg/tcpip/transport/tcp/rack_state.go +++ /dev/null @@ -1,29 +0,0 @@ -// Copyright 2020 The gVisor Authors. 
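updateRACKReorderWindow reacts to DSACKs at most once per round trip: each DSACK widens the reordering window multiplier and keeps it wide for the next 16 loss recoveries, and every recovery exited without a DSACK decays that persistence until the multiplier resets to one. The bookkeeping on its own, omitting the once-per-RTT gating shown above:

package sketch

// reoWndState holds the pieces of the reordering window that change at most
// once per round trip: the multiplier applied to min_RTT/4 and how many
// loss recoveries the widened window persists for.
type reoWndState struct {
	incr    uint8
	persist int8
}

// onRoundTrip widens the window when a DSACK shows a spurious
// retransmission and decays it again after enough clean recoveries.
func (s *reoWndState) onRoundTrip(dsackSeen, exitedRecovery bool) {
	const persistRecoveries = 16
	if dsackSeen {
		s.incr++
		s.persist = persistRecoveries
		return
	}
	if exitedRecovery {
		s.persist--
		if s.persist <= 0 {
			s.incr = 1
		}
	}
}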
-// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package tcp - -import ( - "time" -) - -// saveXmitTime is invoked by stateify. -func (rc *rackControl) saveXmitTime() unixTime { - return unixTime{rc.xmitTime.Unix(), rc.xmitTime.UnixNano()} -} - -// loadXmitTime is invoked by stateify. -func (rc *rackControl) loadXmitTime(unix unixTime) { - rc.xmitTime = time.Unix(unix.second, unix.nano) -} diff --git a/pkg/tcpip/transport/tcp/rcv.go b/pkg/tcpip/transport/tcp/rcv.go index bc6793fc6..ee2c08cd6 100644 --- a/pkg/tcpip/transport/tcp/rcv.go +++ b/pkg/tcpip/transport/tcp/rcv.go @@ -22,6 +22,7 @@ import ( "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/seqnum" + "gvisor.dev/gvisor/pkg/tcpip/stack" ) // receiver holds the state necessary to receive TCP segments and turn them @@ -29,26 +30,15 @@ import ( // // +stateify savable type receiver struct { + stack.TCPReceiverState ep *endpoint - rcvNxt seqnum.Value - - // rcvAcc is one beyond the last acceptable sequence number. That is, - // the "largest" sequence value that the receiver has announced to the - // its peer that it's willing to accept. This may be different than - // rcvNxt + rcvWnd if the receive window is reduced; in that case we - // have to reduce the window as we receive more data instead of - // shrinking it. - rcvAcc seqnum.Value - // rcvWnd is the non-scaled receive window last advertised to the peer. rcvWnd seqnum.Size - // rcvWUP is the rcvNxt value at the last window update sent. + // rcvWUP is the RcvNxt value at the last window update sent. rcvWUP seqnum.Value - rcvWndScale uint8 - // prevBufused is the snapshot of endpoint rcvBufUsed taken when we // advertise a receive window. prevBufUsed int @@ -58,9 +48,6 @@ type receiver struct { // pendingRcvdSegments is bounded by the receive buffer size of the // endpoint. pendingRcvdSegments segmentHeap - // pendingBufUsed tracks the total number of bytes (including segment - // overhead) currently queued in pendingRcvdSegments. - pendingBufUsed int // Time when the last ack was received. lastRcvdAckTime time.Time `state:".(unixTime)"` @@ -68,12 +55,14 @@ type receiver struct { func newReceiver(ep *endpoint, irs seqnum.Value, rcvWnd seqnum.Size, rcvWndScale uint8) *receiver { return &receiver{ - ep: ep, - rcvNxt: irs + 1, - rcvAcc: irs.Add(rcvWnd + 1), + ep: ep, + TCPReceiverState: stack.TCPReceiverState{ + RcvNxt: irs + 1, + RcvAcc: irs.Add(rcvWnd + 1), + RcvWndScale: rcvWndScale, + }, rcvWnd: rcvWnd, rcvWUP: irs + 1, - rcvWndScale: rcvWndScale, lastRcvdAckTime: time.Now(), } } @@ -84,34 +73,34 @@ func (r *receiver) acceptable(segSeq seqnum.Value, segLen seqnum.Size) bool { // r.rcvWnd could be much larger than the window size we advertised in our // outgoing packets, we should use what we have advertised for acceptability // test. - scaledWindowSize := r.rcvWnd >> r.rcvWndScale + scaledWindowSize := r.rcvWnd >> r.RcvWndScale if scaledWindowSize > math.MaxUint16 { // This is what we actually put in the Window field. 
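detectLoss marks a segment lost once it was sent before the most recently delivered segment and more than RTT plus the reordering window has elapsed since its transmission. A simplified predicate (the real check also breaks ties on end sequence):

package sketch

import "time"

// rackLost reports whether a segment should be marked lost: it was sent
// before the most recently delivered segment and its transmit time plus the
// delivered RTT and the reordering window has already passed.
func rackLost(segXmit, rackXmit, now time.Time, rtt, reoWnd time.Duration) bool {
	if !segXmit.Before(rackXmit) {
		return false
	}
	remaining := segXmit.Sub(now) + rtt + reoWnd
	return remaining <= 0
}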
scaledWindowSize = math.MaxUint16 } - advertisedWindowSize := scaledWindowSize << r.rcvWndScale - return header.Acceptable(segSeq, segLen, r.rcvNxt, r.rcvNxt.Add(advertisedWindowSize)) + advertisedWindowSize := scaledWindowSize << r.RcvWndScale + return header.Acceptable(segSeq, segLen, r.RcvNxt, r.RcvNxt.Add(advertisedWindowSize)) } // currentWindow returns the available space in the window that was advertised // last to our peer. func (r *receiver) currentWindow() (curWnd seqnum.Size) { endOfWnd := r.rcvWUP.Add(r.rcvWnd) - if endOfWnd.LessThan(r.rcvNxt) { - // return 0 if r.rcvNxt is past the end of the previously advertised window. + if endOfWnd.LessThan(r.RcvNxt) { + // return 0 if r.RcvNxt is past the end of the previously advertised window. // This can happen because we accept a large segment completely even if // accepting it causes it to partially exceed the advertised window. return 0 } - return r.rcvNxt.Size(endOfWnd) + return r.RcvNxt.Size(endOfWnd) } // getSendParams returns the parameters needed by the sender when building // segments to send. -func (r *receiver) getSendParams() (rcvNxt seqnum.Value, rcvWnd seqnum.Size) { +func (r *receiver) getSendParams() (RcvNxt seqnum.Value, rcvWnd seqnum.Size) { newWnd := r.ep.selectWindow() curWnd := r.currentWindow() - unackLen := int(r.ep.snd.maxSentAck.Size(r.rcvNxt)) + unackLen := int(r.ep.snd.MaxSentAck.Size(r.RcvNxt)) bufUsed := r.ep.receiveBufferUsed() // Grow the right edge of the window only for payloads larger than the @@ -139,18 +128,18 @@ func (r *receiver) getSendParams() (rcvNxt seqnum.Value, rcvWnd seqnum.Size) { // edge, as we are still advertising a window that we think can be serviced. toGrow := unackLen >= SegSize || bufUsed <= r.prevBufUsed - // Update rcvAcc only if new window is > previously advertised window. We + // Update RcvAcc only if new window is > previously advertised window. We // should never shrink the acceptable sequence space once it has been // advertised the peer. If we shrink the acceptable sequence space then we // would end up dropping bytes that might already be in flight. // ==================================================== sequence space. // ^ ^ ^ ^ - // rcvWUP rcvNxt rcvAcc new rcvAcc + // rcvWUP RcvNxt RcvAcc new RcvAcc // <=====curWnd ===> // <========= newWnd > curWnd ========= > - if r.rcvNxt.Add(seqnum.Size(curWnd)).LessThan(r.rcvNxt.Add(seqnum.Size(newWnd))) && toGrow { - // If the new window moves the right edge, then update rcvAcc. - r.rcvAcc = r.rcvNxt.Add(seqnum.Size(newWnd)) + if r.RcvNxt.Add(seqnum.Size(curWnd)).LessThan(r.RcvNxt.Add(seqnum.Size(newWnd))) && toGrow { + // If the new window moves the right edge, then update RcvAcc. + r.RcvAcc = r.RcvNxt.Add(seqnum.Size(newWnd)) } else { if newWnd == 0 { // newWnd is zero but we can't advertise a zero as it would cause window @@ -162,9 +151,9 @@ func (r *receiver) getSendParams() (rcvNxt seqnum.Value, rcvWnd seqnum.Size) { // Stash away the non-scaled receive window as we use it for measuring // receiver's estimated RTT. r.rcvWnd = newWnd - r.rcvWUP = r.rcvNxt + r.rcvWUP = r.RcvNxt r.prevBufUsed = bufUsed - scaledWnd := r.rcvWnd >> r.rcvWndScale + scaledWnd := r.rcvWnd >> r.RcvWndScale if scaledWnd == 0 { // Increment a metric if we are advertising an actual zero window. r.ep.stats.ReceiveErrors.ZeroRcvWindowState.Increment() @@ -177,9 +166,9 @@ func (r *receiver) getSendParams() (rcvNxt seqnum.Value, rcvWnd seqnum.Size) { // Ensure that the stashed receive window always reflects what // is being advertised. 
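acceptable() reconstructs the window the peer actually saw: the internal window is scaled down into the 16-bit header field, clamped, and scaled back up before being used for the in-window test, otherwise the two ends can disagree about which sequence numbers are acceptable. The reconstruction as a standalone helper:

package sketch

// advertisedWindow reconstructs the window the peer was actually told
// about: scale down into the 16-bit field, clamp, scale back up. Any
// in-window test has to use this value rather than the internal window.
func advertisedWindow(rcvWnd uint32, wndScale uint8) uint32 {
	scaled := rcvWnd >> wndScale
	if scaled > 0xffff {
		scaled = 0xffff
	}
	return scaled << wndScale
}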
- r.rcvWnd = scaledWnd << r.rcvWndScale + r.rcvWnd = scaledWnd << r.RcvWndScale } - return r.rcvNxt, scaledWnd + return r.RcvNxt, scaledWnd } // nonZeroWindow is called when the receive window grows from zero to nonzero; @@ -201,13 +190,13 @@ func (r *receiver) consumeSegment(s *segment, segSeq seqnum.Value, segLen seqnum // If the segment doesn't include the seqnum we're expecting to // consume now, we're missing a segment. We cannot proceed until // we receive that segment though. - if !r.rcvNxt.InWindow(segSeq, segLen) { + if !r.RcvNxt.InWindow(segSeq, segLen) { return false } // Trim segment to eliminate already acknowledged data. - if segSeq.LessThan(r.rcvNxt) { - diff := segSeq.Size(r.rcvNxt) + if segSeq.LessThan(r.RcvNxt) { + diff := segSeq.Size(r.RcvNxt) segLen -= diff segSeq.UpdateForward(diff) s.sequenceNumber.UpdateForward(diff) @@ -217,35 +206,35 @@ func (r *receiver) consumeSegment(s *segment, segSeq seqnum.Value, segLen seqnum // Move segment to ready-to-deliver list. Wakeup any waiters. r.ep.readyToRead(s) - } else if segSeq != r.rcvNxt { + } else if segSeq != r.RcvNxt { return false } // Update the segment that we're expecting to consume. - r.rcvNxt = segSeq.Add(segLen) + r.RcvNxt = segSeq.Add(segLen) // In cases of a misbehaving sender which could send more than the // advertised window, we could end up in a situation where we get a // segment that exceeds the window advertised. Instead of partially // accepting the segment and discarding bytes beyond the advertised - // window, we accept the whole segment and make sure r.rcvAcc is moved - // forward to match r.rcvNxt to indicate that the window is now closed. + // window, we accept the whole segment and make sure r.RcvAcc is moved + // forward to match r.RcvNxt to indicate that the window is now closed. // // In absence of this check the r.acceptable() check fails and accepts // segments that should be dropped because rcvWnd is calculated as - // the size of the interval (rcvNxt, rcvAcc] which becomes extremely - // large if rcvAcc is ever less than rcvNxt. - if r.rcvAcc.LessThan(r.rcvNxt) { - r.rcvAcc = r.rcvNxt + // the size of the interval (RcvNxt, RcvAcc] which becomes extremely + // large if RcvAcc is ever less than RcvNxt. + if r.RcvAcc.LessThan(r.RcvNxt) { + r.RcvAcc = r.RcvNxt } // Trim SACK Blocks to remove any SACK information that covers // sequence numbers that have been consumed. - TrimSACKBlockList(&r.ep.sack, r.rcvNxt) + TrimSACKBlockList(&r.ep.sack, r.RcvNxt) // Handle FIN or FIN-ACK. if s.flagIsSet(header.TCPFlagFin) { - r.rcvNxt++ + r.RcvNxt++ // Send ACK immediately. r.ep.snd.sendAck() @@ -260,7 +249,7 @@ func (r *receiver) consumeSegment(s *segment, segSeq seqnum.Value, segLen seqnum case StateEstablished: r.ep.setEndpointState(StateCloseWait) case StateFinWait1: - if s.flagIsSet(header.TCPFlagAck) && s.ackNumber == r.ep.snd.sndNxt { + if s.flagIsSet(header.TCPFlagAck) && s.ackNumber == r.ep.snd.SndNxt { // FIN-ACK, transition to TIME-WAIT. 
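getSendParams only ever moves RcvAcc, the right edge of the advertised window, forward: once the peer has been told it may send up to some sequence number, shrinking the window would mean dropping bytes that are legitimately in flight. A sketch of that rule using ordinary uint32 arithmetic, which wraps the same way TCP sequence numbers do:

package sketch

// growRightEdge moves the right edge of the receive window forward only.
// rcvNxt, rcvAcc and the result are sequence numbers; the int32 difference
// is the usual wraparound-safe comparison.
func growRightEdge(rcvNxt, rcvAcc, newWnd uint32) uint32 {
	if proposed := rcvNxt + newWnd; int32(proposed-rcvAcc) > 0 {
		return proposed
	}
	return rcvAcc
}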
r.ep.setEndpointState(StateTimeWait) } else { @@ -280,7 +269,7 @@ func (r *receiver) consumeSegment(s *segment, segSeq seqnum.Value, segLen seqnum } for i := first; i < len(r.pendingRcvdSegments); i++ { - r.pendingBufUsed -= r.pendingRcvdSegments[i].segMemSize() + r.PendingBufUsed -= r.pendingRcvdSegments[i].segMemSize() r.pendingRcvdSegments[i].decRef() // Note that slice truncation does not allow garbage collection of @@ -295,7 +284,7 @@ func (r *receiver) consumeSegment(s *segment, segSeq seqnum.Value, segLen seqnum // Handle ACK (not FIN-ACK, which we handled above) during one of the // shutdown states. - if s.flagIsSet(header.TCPFlagAck) && s.ackNumber == r.ep.snd.sndNxt { + if s.flagIsSet(header.TCPFlagAck) && s.ackNumber == r.ep.snd.SndNxt { switch r.ep.EndpointState() { case StateFinWait1: r.ep.setEndpointState(StateFinWait2) @@ -323,40 +312,40 @@ func (r *receiver) updateRTT() { // estimate the round-trip time by observing the time between when a byte // is first acknowledged and the receipt of data that is at least one // window beyond the sequence number that was acknowledged. - r.ep.rcvListMu.Lock() - if r.ep.rcvAutoParams.rttMeasureTime.IsZero() { + r.ep.rcvQueueInfo.rcvQueueMu.Lock() + if r.ep.rcvQueueInfo.RcvAutoParams.RTTMeasureTime.IsZero() { // New measurement. - r.ep.rcvAutoParams.rttMeasureTime = time.Now() - r.ep.rcvAutoParams.rttMeasureSeqNumber = r.rcvNxt.Add(r.rcvWnd) - r.ep.rcvListMu.Unlock() + r.ep.rcvQueueInfo.RcvAutoParams.RTTMeasureTime = time.Now() + r.ep.rcvQueueInfo.RcvAutoParams.RTTMeasureSeqNumber = r.RcvNxt.Add(r.rcvWnd) + r.ep.rcvQueueInfo.rcvQueueMu.Unlock() return } - if r.rcvNxt.LessThan(r.ep.rcvAutoParams.rttMeasureSeqNumber) { - r.ep.rcvListMu.Unlock() + if r.RcvNxt.LessThan(r.ep.rcvQueueInfo.RcvAutoParams.RTTMeasureSeqNumber) { + r.ep.rcvQueueInfo.rcvQueueMu.Unlock() return } - rtt := time.Since(r.ep.rcvAutoParams.rttMeasureTime) + rtt := time.Since(r.ep.rcvQueueInfo.RcvAutoParams.RTTMeasureTime) // We only store the minimum observed RTT here as this is only used in // absence of a SRTT available from either timestamps or a sender // measurement of RTT. - if r.ep.rcvAutoParams.rtt == 0 || rtt < r.ep.rcvAutoParams.rtt { - r.ep.rcvAutoParams.rtt = rtt + if r.ep.rcvQueueInfo.RcvAutoParams.RTT == 0 || rtt < r.ep.rcvQueueInfo.RcvAutoParams.RTT { + r.ep.rcvQueueInfo.RcvAutoParams.RTT = rtt } - r.ep.rcvAutoParams.rttMeasureTime = time.Now() - r.ep.rcvAutoParams.rttMeasureSeqNumber = r.rcvNxt.Add(r.rcvWnd) - r.ep.rcvListMu.Unlock() + r.ep.rcvQueueInfo.RcvAutoParams.RTTMeasureTime = time.Now() + r.ep.rcvQueueInfo.RcvAutoParams.RTTMeasureSeqNumber = r.RcvNxt.Add(r.rcvWnd) + r.ep.rcvQueueInfo.rcvQueueMu.Unlock() } func (r *receiver) handleRcvdSegmentClosing(s *segment, state EndpointState, closed bool) (drop bool, err tcpip.Error) { - r.ep.rcvListMu.Lock() - rcvClosed := r.ep.rcvClosed || r.closed - r.ep.rcvListMu.Unlock() + r.ep.rcvQueueInfo.rcvQueueMu.Lock() + rcvClosed := r.ep.rcvQueueInfo.RcvClosed || r.closed + r.ep.rcvQueueInfo.rcvQueueMu.Unlock() // If we are in one of the shutdown states then we need to do // additional checks before we try and process the segment. 
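updateRTT gives the receiver its own RTT estimate for buffer auto-tuning when no sender-side SRTT is available: it stamps the right edge of the current window and closes the measurement once a full window of data has arrived, keeping the minimum observed value. The same mechanism, reduced to a small struct:

package sketch

import "time"

// rcvRTT is a receiver-only RTT estimator: note the time and the sequence
// number one window ahead, and close the sample once that much data has
// been received. Only the minimum is kept.
type rcvRTT struct {
	measureTime time.Time
	measureSeq  uint32
	minRTT      time.Duration
}

func (r *rcvRTT) update(now time.Time, rcvNxt, rcvWnd uint32) {
	if r.measureTime.IsZero() {
		r.measureTime = now
		r.measureSeq = rcvNxt + rcvWnd
		return
	}
	if int32(rcvNxt-r.measureSeq) < 0 {
		return // a full window has not arrived yet
	}
	if rtt := now.Sub(r.measureTime); r.minRTT == 0 || rtt < r.minRTT {
		r.minRTT = rtt
	}
	r.measureTime = now
	r.measureSeq = rcvNxt + rcvWnd
}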
switch state { case StateCloseWait, StateClosing, StateLastAck: - if !s.sequenceNumber.LessThanEq(r.rcvNxt) { + if !s.sequenceNumber.LessThanEq(r.RcvNxt) { // Just drop the segment as we have // already received a FIN and this // segment is after the sequence number @@ -384,17 +373,17 @@ func (r *receiver) handleRcvdSegmentClosing(s *segment, state EndpointState, clo // The ESTABLISHED state processing is here where if the ACK check // fails, we ignore the packet: // https://github.com/torvalds/linux/blob/v5.8/net/ipv4/tcp_input.c#L5591 - if r.ep.snd.sndNxt.LessThan(s.ackNumber) { + if r.ep.snd.SndNxt.LessThan(s.ackNumber) { r.ep.snd.maybeSendOutOfWindowAck(s) return true, nil } // If we are closed for reads (either due to an // incoming FIN or the user calling shutdown(.., - // SHUT_RD) then any data past the rcvNxt should + // SHUT_RD) then any data past the RcvNxt should // trigger a RST. endDataSeq := s.sequenceNumber.Add(seqnum.Size(s.data.Size())) - if state != StateCloseWait && rcvClosed && r.rcvNxt.LessThan(endDataSeq) { + if state != StateCloseWait && rcvClosed && r.RcvNxt.LessThan(endDataSeq) { return true, &tcpip.ErrConnectionAborted{} } if state == StateFinWait1 { @@ -403,7 +392,7 @@ func (r *receiver) handleRcvdSegmentClosing(s *segment, state EndpointState, clo // If it's a retransmission of an old data segment // or a pure ACK then allow it. - if s.sequenceNumber.Add(s.logicalLen()).LessThanEq(r.rcvNxt) || + if s.sequenceNumber.Add(s.logicalLen()).LessThanEq(r.RcvNxt) || s.logicalLen() == 0 { break } @@ -413,7 +402,7 @@ func (r *receiver) handleRcvdSegmentClosing(s *segment, state EndpointState, clo // then the only acceptable segment is a // FIN. Since FIN can technically also carry // data we verify that the segment carrying a - // FIN ends at exactly e.rcvNxt+1. + // FIN ends at exactly e.RcvNxt+1. // // From RFC793 page 25. // @@ -423,7 +412,7 @@ func (r *receiver) handleRcvdSegmentClosing(s *segment, state EndpointState, clo // while the FIN is considered to occur after // the last actual data octet in a segment in // which it occurs. - if closed && (!s.flagIsSet(header.TCPFlagFin) || s.sequenceNumber.Add(s.logicalLen()) != r.rcvNxt+1) { + if closed && (!s.flagIsSet(header.TCPFlagFin) || s.sequenceNumber.Add(s.logicalLen()) != r.RcvNxt+1) { return true, &tcpip.ErrConnectionAborted{} } } @@ -435,7 +424,7 @@ func (r *receiver) handleRcvdSegmentClosing(s *segment, state EndpointState, clo // end has closed and the peer is yet to send a FIN. Hence we // compare only the payload. segEnd := s.sequenceNumber.Add(seqnum.Size(s.data.Size())) - if rcvClosed && !segEnd.LessThanEq(r.rcvNxt) { + if rcvClosed && !segEnd.LessThanEq(r.RcvNxt) { return true, nil } return false, nil @@ -477,13 +466,13 @@ func (r *receiver) handleRcvdSegment(s *segment) (drop bool, err tcpip.Error) { // segments. This ensures that we always leave some space for the inorder // segments to arrive allowing pending segments to be processed and // delivered to the user. 
- if r.ep.receiveBufferAvailable() > 0 && r.pendingBufUsed < r.ep.receiveBufferSize()>>2 { - r.ep.rcvListMu.Lock() - r.pendingBufUsed += s.segMemSize() - r.ep.rcvListMu.Unlock() + if rcvBufSize := r.ep.ops.GetReceiveBufferSize(); rcvBufSize > 0 && r.PendingBufUsed < int(rcvBufSize)>>2 { + r.ep.rcvQueueInfo.rcvQueueMu.Lock() + r.PendingBufUsed += s.segMemSize() + r.ep.rcvQueueInfo.rcvQueueMu.Unlock() s.incRef() heap.Push(&r.pendingRcvdSegments, s) - UpdateSACKBlocks(&r.ep.sack, segSeq, segSeq.Add(segLen), r.rcvNxt) + UpdateSACKBlocks(&r.ep.sack, segSeq, segSeq.Add(segLen), r.RcvNxt) } // Immediately send an ack so that the peer knows it may @@ -508,15 +497,15 @@ func (r *receiver) handleRcvdSegment(s *segment) (drop bool, err tcpip.Error) { segSeq := s.sequenceNumber // Skip segment altogether if it has already been acknowledged. - if !segSeq.Add(segLen-1).LessThan(r.rcvNxt) && + if !segSeq.Add(segLen-1).LessThan(r.RcvNxt) && !r.consumeSegment(s, segSeq, segLen) { break } heap.Pop(&r.pendingRcvdSegments) - r.ep.rcvListMu.Lock() - r.pendingBufUsed -= s.segMemSize() - r.ep.rcvListMu.Unlock() + r.ep.rcvQueueInfo.rcvQueueMu.Lock() + r.PendingBufUsed -= s.segMemSize() + r.ep.rcvQueueInfo.rcvQueueMu.Unlock() s.decRef() } return false, nil @@ -558,7 +547,7 @@ func (r *receiver) handleTimeWaitSegment(s *segment) (resetTimeWait bool, newSyn // (2) returns to TIME-WAIT state if the SYN turns out // to be an old duplicate". - if s.flagIsSet(header.TCPFlagSyn) && r.rcvNxt.LessThan(segSeq) { + if s.flagIsSet(header.TCPFlagSyn) && r.RcvNxt.LessThan(segSeq) { return false, true } @@ -569,11 +558,11 @@ func (r *receiver) handleTimeWaitSegment(s *segment) (resetTimeWait bool, newSyn } // Update Timestamp if required. See RFC7323, section-4.3. - if r.ep.sendTSOk && s.parsedOptions.TS { - r.ep.updateRecentTimestamp(s.parsedOptions.TSVal, r.ep.snd.maxSentAck, segSeq) + if r.ep.SendTSOk && s.parsedOptions.TS { + r.ep.updateRecentTimestamp(s.parsedOptions.TSVal, r.ep.snd.MaxSentAck, segSeq) } - if segSeq.Add(1) == r.rcvNxt && s.flagIsSet(header.TCPFlagFin) { + if segSeq.Add(1) == r.RcvNxt && s.flagIsSet(header.TCPFlagFin) { // If it's a FIN-ACK then resetTimeWait and send an ACK, as it // indicates our final ACK could have been lost. r.ep.snd.sendAck() @@ -584,8 +573,8 @@ func (r *receiver) handleTimeWaitSegment(s *segment) (resetTimeWait bool, newSyn // carries data then just send an ACK. This is according to RFC 793, // page 37. // - // NOTE: In TIME_WAIT the only acceptable sequence number is rcvNxt. - if segSeq != r.rcvNxt || segLen != 0 { + // NOTE: In TIME_WAIT the only acceptable sequence number is RcvNxt. + if segSeq != r.RcvNxt || segLen != 0 { r.ep.snd.sendAck() } return false, false diff --git a/pkg/tcpip/transport/tcp/reno.go b/pkg/tcpip/transport/tcp/reno.go index ff39780a5..063552c7f 100644 --- a/pkg/tcpip/transport/tcp/reno.go +++ b/pkg/tcpip/transport/tcp/reno.go @@ -34,14 +34,14 @@ func newRenoCC(s *sender) *renoState { func (r *renoState) updateSlowStart(packetsAcked int) int { // Don't let the congestion window cross into the congestion // avoidance range. 
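handleRcvdSegment only queues an out-of-order segment while the pending reassembly data stays under a quarter of the receive buffer, leaving room for in-order data to be delivered; admitted segments are also added to the SACK blocks so the sender learns they arrived. The admission check on its own:

package sketch

// admitOutOfOrder caps the reassembly queue at a quarter of the receive
// buffer so buffered out-of-order data can never crowd out in-order data
// waiting to be delivered to the application.
func admitOutOfOrder(pendingBytes int, rcvBufSize int64) bool {
	return rcvBufSize > 0 && pendingBytes < int(rcvBufSize)>>2
}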
- newcwnd := r.s.sndCwnd + packetsAcked - if newcwnd >= r.s.sndSsthresh { - newcwnd = r.s.sndSsthresh - r.s.sndCAAckCount = 0 + newcwnd := r.s.SndCwnd + packetsAcked + if newcwnd >= r.s.Ssthresh { + newcwnd = r.s.Ssthresh + r.s.SndCAAckCount = 0 } - packetsAcked -= newcwnd - r.s.sndCwnd - r.s.sndCwnd = newcwnd + packetsAcked -= newcwnd - r.s.SndCwnd + r.s.SndCwnd = newcwnd return packetsAcked } @@ -49,19 +49,19 @@ func (r *renoState) updateSlowStart(packetsAcked int) int { // avoidance mode as described in RFC5681 section 3.1 func (r *renoState) updateCongestionAvoidance(packetsAcked int) { // Consume the packets in congestion avoidance mode. - r.s.sndCAAckCount += packetsAcked - if r.s.sndCAAckCount >= r.s.sndCwnd { - r.s.sndCwnd += r.s.sndCAAckCount / r.s.sndCwnd - r.s.sndCAAckCount = r.s.sndCAAckCount % r.s.sndCwnd + r.s.SndCAAckCount += packetsAcked + if r.s.SndCAAckCount >= r.s.SndCwnd { + r.s.SndCwnd += r.s.SndCAAckCount / r.s.SndCwnd + r.s.SndCAAckCount = r.s.SndCAAckCount % r.s.SndCwnd } } // reduceSlowStartThreshold reduces the slow-start threshold per RFC 5681, // page 6, eq. 4. It is called when we detect congestion in the network. func (r *renoState) reduceSlowStartThreshold() { - r.s.sndSsthresh = r.s.outstanding / 2 - if r.s.sndSsthresh < 2 { - r.s.sndSsthresh = 2 + r.s.Ssthresh = r.s.Outstanding / 2 + if r.s.Ssthresh < 2 { + r.s.Ssthresh = 2 } } @@ -70,7 +70,7 @@ func (r *renoState) reduceSlowStartThreshold() { // were acknowledged. // Update implements congestionControl.Update. func (r *renoState) Update(packetsAcked int) { - if r.s.sndCwnd < r.s.sndSsthresh { + if r.s.SndCwnd < r.s.Ssthresh { packetsAcked = r.updateSlowStart(packetsAcked) if packetsAcked == 0 { return @@ -94,7 +94,7 @@ func (r *renoState) HandleRTOExpired() { // Reduce the congestion window to 1, i.e., enter slow-start. Per // RFC 5681, page 7, we must use 1 regardless of the value of the // initial congestion window. - r.s.sndCwnd = 1 + r.s.SndCwnd = 1 } // PostRecovery implements congestionControl.PostRecovery. diff --git a/pkg/tcpip/transport/tcp/reno_recovery.go b/pkg/tcpip/transport/tcp/reno_recovery.go index 2aa708e97..d368a29fc 100644 --- a/pkg/tcpip/transport/tcp/reno_recovery.go +++ b/pkg/tcpip/transport/tcp/reno_recovery.go @@ -31,25 +31,25 @@ func (rr *renoRecovery) DoRecovery(rcvdSeg *segment, fastRetransmit bool) { snd := rr.s // We are in fast recovery mode. Ignore the ack if it's out of range. - if !ack.InRange(snd.sndUna, snd.sndNxt+1) { + if !ack.InRange(snd.SndUna, snd.SndNxt+1) { return } // Don't count this as a duplicate if it is carrying data or // updating the window. - if rcvdSeg.logicalLen() != 0 || snd.sndWnd != rcvdSeg.window { + if rcvdSeg.logicalLen() != 0 || snd.SndWnd != rcvdSeg.window { return } // Inflate the congestion window if we're getting duplicate acks // for the packet we retransmitted. - if !fastRetransmit && ack == snd.fr.first { + if !fastRetransmit && ack == snd.FastRecovery.First { // We received a dup, inflate the congestion window by 1 packet // if we're not at the max yet. Only inflate the window if // regular FastRecovery is in use, RFC6675 does not require // inflating cwnd on duplicate ACKs. - if snd.sndCwnd < snd.fr.maxCwnd { - snd.sndCwnd++ + if snd.SndCwnd < snd.FastRecovery.MaxCwnd { + snd.SndCwnd++ } return } @@ -61,7 +61,7 @@ func (rr *renoRecovery) DoRecovery(rcvdSeg *segment, fastRetransmit bool) { // back onto the wire. // // N.B. The retransmit timer will be reset by the caller. 
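The Reno hunks above implement RFC 5681 section 3.1 window growth. A standalone sketch of the same slow-start and congestion-avoidance arithmetic, with names local to the example:

package main

import "fmt"

type reno struct {
	cwnd       int // congestion window, in packets
	ssthresh   int // slow-start threshold, in packets
	caAckCount int // packets acked while in congestion avoidance
}

// update grows cwnd for packetsAcked newly acknowledged packets.
func (r *reno) update(packetsAcked int) {
	if r.cwnd < r.ssthresh {
		// Slow start: cwnd grows by one per acked packet, but not past ssthresh.
		newcwnd := r.cwnd + packetsAcked
		if newcwnd >= r.ssthresh {
			newcwnd = r.ssthresh
			r.caAckCount = 0
		}
		packetsAcked -= newcwnd - r.cwnd
		r.cwnd = newcwnd
		if packetsAcked == 0 {
			return
		}
	}
	// Congestion avoidance: roughly one cwnd increment per round trip.
	r.caAckCount += packetsAcked
	if r.caAckCount >= r.cwnd {
		r.cwnd += r.caAckCount / r.cwnd
		r.caAckCount %= r.cwnd
	}
}

func main() {
	r := reno{cwnd: 10, ssthresh: 12}
	r.update(4)                       // 2 packets finish slow start, 2 count toward avoidance
	fmt.Println(r.cwnd, r.caAckCount) // 12 2
}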
- snd.fr.first = ack - snd.dupAckCount = 0 + snd.FastRecovery.First = ack + snd.DupAckCount = 0 snd.resendSegment() } diff --git a/pkg/tcpip/transport/tcp/sack_recovery.go b/pkg/tcpip/transport/tcp/sack_recovery.go index 9d406b0bc..cd860b5e8 100644 --- a/pkg/tcpip/transport/tcp/sack_recovery.go +++ b/pkg/tcpip/transport/tcp/sack_recovery.go @@ -42,14 +42,14 @@ func (sr *sackRecovery) handleSACKRecovery(limit int, end seqnum.Value) (dataSen } nextSegHint := snd.writeList.Front() - for snd.outstanding < snd.sndCwnd { + for snd.Outstanding < snd.SndCwnd { var nextSeg *segment var rescueRtx bool nextSeg, nextSegHint, rescueRtx = snd.NextSeg(nextSegHint) if nextSeg == nil { return dataSent } - if !snd.isAssignedSequenceNumber(nextSeg) || snd.sndNxt.LessThanEq(nextSeg.sequenceNumber) { + if !snd.isAssignedSequenceNumber(nextSeg) || snd.SndNxt.LessThanEq(nextSeg.sequenceNumber) { // New data being sent. // Step C.3 described below is handled by @@ -67,7 +67,7 @@ func (sr *sackRecovery) handleSACKRecovery(limit int, end seqnum.Value) (dataSen return dataSent } dataSent = true - snd.outstanding++ + snd.Outstanding++ snd.writeNext = nextSeg.Next() continue } @@ -79,7 +79,7 @@ func (sr *sackRecovery) handleSACKRecovery(limit int, end seqnum.Value) (dataSen // "The estimate of the amount of data outstanding in the network // must be updated by incrementing pipe by the number of octets // transmitted in (C.1)." - snd.outstanding++ + snd.Outstanding++ dataSent = true snd.sendSegment(nextSeg) @@ -88,7 +88,7 @@ func (sr *sackRecovery) handleSACKRecovery(limit int, end seqnum.Value) (dataSen // We do the last part of rule (4) of NextSeg here to update // RescueRxt as until this point we don't know if we are going // to use the rescue transmission. - snd.fr.rescueRxt = snd.fr.last + snd.FastRecovery.RescueRxt = snd.FastRecovery.Last } else { // RFC 6675, Step C.2 // @@ -96,7 +96,7 @@ func (sr *sackRecovery) handleSACKRecovery(limit int, end seqnum.Value) (dataSen // HighData, HighRxt MUST be set to the highest sequence // number of the retransmitted segment unless NextSeg () // rule (4) was invoked for this retransmission." - snd.fr.highRxt = segEnd - 1 + snd.FastRecovery.HighRxt = segEnd - 1 } } return dataSent @@ -109,12 +109,12 @@ func (sr *sackRecovery) DoRecovery(rcvdSeg *segment, fastRetransmit bool) { } // We are in fast recovery mode. Ignore the ack if it's out of range. - if ack := rcvdSeg.ackNumber; !ack.InRange(snd.sndUna, snd.sndNxt+1) { + if ack := rcvdSeg.ackNumber; !ack.InRange(snd.SndUna, snd.SndNxt+1) { return } // RFC 6675 recovery algorithm step C 1-5. 
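The reno_recovery.go hunk above inflates the window by one packet per duplicate ACK during fast recovery, bounded by MaxCwnd so a misbehaving receiver cannot grow the window arbitrarily. A compact illustrative sketch of that rule (identifiers are hypothetical):

package main

import "fmt"

type fastRecovery struct {
	active  bool
	first   uint32 // first sequence number being recovered
	cwnd    int
	maxCwnd int
}

// onDupAck processes one duplicate ACK carrying ackNum while recovering.
func (fr *fastRecovery) onDupAck(ackNum uint32) {
	if !fr.active || ackNum != fr.first {
		return
	}
	if fr.cwnd < fr.maxCwnd {
		fr.cwnd++ // the dup ACK means one more packet has left the network
	}
}

func main() {
	fr := fastRecovery{active: true, first: 1000, cwnd: 13, maxCwnd: 15}
	for i := 0; i < 4; i++ {
		fr.onDupAck(1000)
	}
	fmt.Println(fr.cwnd) // 15: capped at maxCwnd
}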
- end := snd.sndUna.Add(snd.sndWnd) - dataSent := sr.handleSACKRecovery(snd.maxPayloadSize, end) + end := snd.SndUna.Add(snd.SndWnd) + dataSent := sr.handleSACKRecovery(snd.MaxPayloadSize, end) snd.postXmit(dataSent, true /* shouldScheduleProbe */) } diff --git a/pkg/tcpip/transport/tcp/segment.go b/pkg/tcpip/transport/tcp/segment.go index 8edd6775b..c28641be3 100644 --- a/pkg/tcpip/transport/tcp/segment.go +++ b/pkg/tcpip/transport/tcp/segment.go @@ -236,20 +236,14 @@ func (s *segment) parse(skipChecksumValidation bool) bool { s.options = []byte(s.hdr[header.TCPMinimumSize:]) s.parsedOptions = header.ParseTCPOptions(s.options) - - verifyChecksum := true if skipChecksumValidation { s.csumValid = true - verifyChecksum = false - } - if verifyChecksum { + } else { s.csum = s.hdr.Checksum() - xsum := header.PseudoHeaderChecksum(ProtocolNumber, s.srcAddr, s.dstAddr, uint16(s.data.Size()+len(s.hdr))) - xsum = s.hdr.CalculateChecksum(xsum) - xsum = header.ChecksumVV(s.data, xsum) - s.csumValid = xsum == 0xffff + payloadChecksum := header.ChecksumVV(s.data, 0) + payloadLength := uint16(s.data.Size()) + s.csumValid = s.hdr.IsChecksumValid(s.srcAddr, s.dstAddr, payloadChecksum, payloadLength) } - s.sequenceNumber = seqnum.Value(s.hdr.SequenceNumber()) s.ackNumber = seqnum.Value(s.hdr.AckNumber()) s.flags = s.hdr.Flags() diff --git a/pkg/tcpip/transport/tcp/segment_queue.go b/pkg/tcpip/transport/tcp/segment_queue.go index 54545a1b1..d0d1b0b8a 100644 --- a/pkg/tcpip/transport/tcp/segment_queue.go +++ b/pkg/tcpip/transport/tcp/segment_queue.go @@ -52,12 +52,12 @@ func (q *segmentQueue) empty() bool { func (q *segmentQueue) enqueue(s *segment) bool { // q.ep.receiveBufferParams() must be called without holding q.mu to // avoid lock order inversion. - bufSz := q.ep.receiveBufferSize() + bufSz := q.ep.ops.GetReceiveBufferSize() used := q.ep.receiveMemUsed() q.mu.Lock() // Allow zero sized segments (ACK/FIN/RSTs etc even if the segment queue // is currently full). - allow := (used <= bufSz || s.payloadSize() == 0) && !q.frozen + allow := (used <= int(bufSz) || s.payloadSize() == 0) && !q.frozen if allow { q.list.PushBack(s) diff --git a/pkg/tcpip/transport/tcp/snd.go b/pkg/tcpip/transport/tcp/snd.go index faca35892..2b32cb7b2 100644 --- a/pkg/tcpip/transport/tcp/snd.go +++ b/pkg/tcpip/transport/tcp/snd.go @@ -26,6 +26,7 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/seqnum" + "gvisor.dev/gvisor/pkg/tcpip/stack" ) const ( @@ -85,56 +86,12 @@ type lossRecovery interface { // // +stateify savable type sender struct { + stack.TCPSenderState ep *endpoint - // lastSendTime is the timestamp when the last packet was sent. - lastSendTime time.Time `state:".(unixTime)"` - - // dupAckCount is the number of duplicated acks received. It is used for - // fast retransmit. - dupAckCount int - - // fr holds state related to fast recovery. - fr fastRecovery - // lr is the loss recovery algorithm used by the sender. lr lossRecovery - // sndCwnd is the congestion window, in packets. - sndCwnd int - - // sndSsthresh is the threshold between slow start and congestion - // avoidance. - sndSsthresh int - - // sndCAAckCount is the number of packets acknowledged during congestion - // avoidance. When enough packets have been ack'd (typically cwnd - // packets), the congestion window is incremented by one. - sndCAAckCount int - - // outstanding is the number of outstanding packets, that is, packets - // that have been sent but not yet acknowledged. 
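segment.parse above now defers to IsChecksumValid, which checks the TCP checksum over the IPv4 pseudo-header plus the TCP header and payload. A self-contained sketch of that standard Internet-checksum validation, independent of gVisor's header package:

package main

import "fmt"

// checksum folds data into a 16-bit one's-complement sum seeded with initial.
func checksum(data []byte, initial uint32) uint32 {
	sum := initial
	for i := 0; i+1 < len(data); i += 2 {
		sum += uint32(data[i])<<8 | uint32(data[i+1])
	}
	if len(data)%2 == 1 {
		sum += uint32(data[len(data)-1]) << 8
	}
	for sum>>16 != 0 {
		sum = (sum >> 16) + (sum & 0xffff)
	}
	return sum
}

// pseudoHeaderSum checksums the IPv4 pseudo-header (source, destination,
// protocol 6, TCP length) followed by the TCP header and payload.
func pseudoHeaderSum(src, dst [4]byte, seg []byte) uint32 {
	length := uint16(len(seg))
	pseudo := []byte{
		src[0], src[1], src[2], src[3],
		dst[0], dst[1], dst[2], dst[3],
		0, 6, byte(length >> 8), byte(length),
	}
	return checksum(seg, checksum(pseudo, 0))
}

func main() {
	src, dst := [4]byte{10, 0, 0, 1}, [4]byte{10, 0, 0, 2}
	seg := make([]byte, 20) // minimal TCP header; checksum field is bytes 16..17

	// Fill in the checksum the way a sender would: the one's complement of the
	// sum computed with the checksum field still zero.
	c := ^uint16(pseudoHeaderSum(src, dst, seg))
	seg[16], seg[17] = byte(c>>8), byte(c)

	// A receiver accepts the segment when everything, checksum included,
	// folds to 0xffff.
	fmt.Println(pseudoHeaderSum(src, dst, seg) == 0xffff) // true
}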
- outstanding int - - // sackedOut is the number of packets which are selectively acked. - sackedOut int - - // sndWnd is the send window size. - sndWnd seqnum.Size - - // sndUna is the next unacknowledged sequence number. - sndUna seqnum.Value - - // sndNxt is the sequence number of the next segment to be sent. - sndNxt seqnum.Value - - // rttMeasureSeqNum is the sequence number being used for the latest RTT - // measurement. - rttMeasureSeqNum seqnum.Value - - // rttMeasureTime is the time when the rttMeasureSeqNum was sent. - rttMeasureTime time.Time `state:".(unixTime)"` - // firstRetransmittedSegXmitTime is the original transmit time of // the first segment that was retransmitted due to RTO expiration. firstRetransmittedSegXmitTime time.Time `state:".(unixTime)"` @@ -147,17 +104,15 @@ type sender struct { // window probes. unackZeroWindowProbes uint32 `state:"nosave"` - closed bool writeNext *segment writeList segmentList resendTimer timer `state:"nosave"` resendWaker sleep.Waker `state:"nosave"` - // rtt.srtt, rtt.rttvar, and rto are the "smoothed round-trip time", - // "round-trip time variation" and "retransmit timeout", as defined in + // rtt.TCPRTTState.SRTT and rtt.TCPRTTState.RTTVar are the "smoothed + // round-trip time", and "round-trip time variation", as defined in // section 2 of RFC 6298. rtt rtt - rto time.Duration // minRTO is the minimum permitted value for sender.rto. minRTO time.Duration @@ -168,20 +123,9 @@ type sender struct { // maxRetries is the maximum permitted retransmissions. maxRetries uint32 - // maxPayloadSize is the maximum size of the payload of a given segment. - // It is initialized on demand. - maxPayloadSize int - // gso is set if generic segmentation offload is enabled. gso bool - // sndWndScale is the number of bits to shift left when reading the send - // window size from a segment. - sndWndScale uint8 - - // maxSentAck is the maxium acknowledgement actually sent. - maxSentAck seqnum.Value - // state is the current state of congestion control for this endpoint. state tcpip.CongestionControlState @@ -209,41 +153,7 @@ type sender struct { type rtt struct { sync.Mutex `state:"nosave"` - srtt time.Duration - rttvar time.Duration - srttInited bool -} - -// fastRecovery holds information related to fast recovery from a packet loss. -// -// +stateify savable -type fastRecovery struct { - // active whether the endpoint is in fast recovery. The following fields - // are only meaningful when active is true. - active bool - - // first and last represent the inclusive sequence number range being - // recovered. - first seqnum.Value - last seqnum.Value - - // maxCwnd is the maximum value the congestion window may be inflated to - // due to duplicate acks. This exists to avoid attacks where the - // receiver intentionally sends duplicate acks to artificially inflate - // the sender's cwnd. - maxCwnd int - - // highRxt is the highest sequence number which has been retransmitted - // during the current loss recovery phase. - // See: RFC 6675 Section 2 for details. - highRxt seqnum.Value - - // rescueRxt is the highest sequence number which has been - // optimistically retransmitted to prevent stalling of the ACK clock - // when there is loss at the end of the window and no new data is - // available for transmission. - // See: RFC 6675 Section 2 for details. 
- rescueRxt seqnum.Value + stack.TCPRTTState } func newSender(ep *endpoint, iss, irs seqnum.Value, sndWnd seqnum.Size, mss uint16, sndWndScale int) *sender { @@ -253,22 +163,24 @@ func newSender(ep *endpoint, iss, irs seqnum.Value, sndWnd seqnum.Size, mss uint maxPayloadSize := int(mss) - ep.maxOptionSize() s := &sender{ - ep: ep, - sndWnd: sndWnd, - sndUna: iss + 1, - sndNxt: iss + 1, - rto: 1 * time.Second, - rttMeasureSeqNum: iss + 1, - lastSendTime: time.Now(), - maxPayloadSize: maxPayloadSize, - maxSentAck: irs + 1, - fr: fastRecovery{ - // See: https://tools.ietf.org/html/rfc6582#section-3.2 Step 1. - last: iss, - highRxt: iss, - rescueRxt: iss, + ep: ep, + TCPSenderState: stack.TCPSenderState{ + SndWnd: sndWnd, + SndUna: iss + 1, + SndNxt: iss + 1, + RTTMeasureSeqNum: iss + 1, + LastSendTime: time.Now(), + MaxPayloadSize: maxPayloadSize, + MaxSentAck: irs + 1, + FastRecovery: stack.TCPFastRecoveryState{ + // See: https://tools.ietf.org/html/rfc6582#section-3.2 Step 1. + Last: iss, + HighRxt: iss, + RescueRxt: iss, + }, + RTO: 1 * time.Second, }, - gso: ep.gso != nil, + gso: ep.gso.Type != stack.GSONone, } if s.gso { @@ -282,7 +194,7 @@ func newSender(ep *endpoint, iss, irs seqnum.Value, sndWnd seqnum.Size, mss uint // A negative sndWndScale means that no scaling is in use, otherwise we // store the scaling value. if sndWndScale > 0 { - s.sndWndScale = uint8(sndWndScale) + s.SndWndScale = uint8(sndWndScale) } s.resendTimer.init(&s.resendWaker) @@ -294,7 +206,7 @@ func newSender(ep *endpoint, iss, irs seqnum.Value, sndWnd seqnum.Size, mss uint // Initialize SACK Scoreboard after updating max payload size as we use // the maxPayloadSize as the smss when determining if a segment is lost // etc. - s.ep.scoreboard = NewSACKScoreboard(uint16(s.maxPayloadSize), iss) + s.ep.scoreboard = NewSACKScoreboard(uint16(s.MaxPayloadSize), iss) // Get Stack wide config. var minRTO tcpip.TCPMinRTOOption @@ -322,10 +234,10 @@ func newSender(ep *endpoint, iss, irs seqnum.Value, sndWnd seqnum.Size, mss uint // returns a handle to it. It also initializes the sndCwnd and sndSsThresh to // their initial values. func (s *sender) initCongestionControl(congestionControlName tcpip.CongestionControlOption) congestionControl { - s.sndCwnd = InitialCwnd + s.SndCwnd = InitialCwnd // Set sndSsthresh to the maximum int value, which depends on the // platform. - s.sndSsthresh = int(^uint(0) >> 1) + s.Ssthresh = int(^uint(0) >> 1) switch congestionControlName { case ccCubic: @@ -339,7 +251,7 @@ func (s *sender) initCongestionControl(congestionControlName tcpip.CongestionCon // initLossRecovery initiates the loss recovery algorithm for the sender. func (s *sender) initLossRecovery() lossRecovery { - if s.ep.sackPermitted { + if s.ep.SACKPermitted { return newSACKRecovery(s) } return newRenoRecovery(s) @@ -355,7 +267,7 @@ func (s *sender) updateMaxPayloadSize(mtu, count int) { m -= s.ep.maxOptionSize() // We don't adjust up for now. - if m >= s.maxPayloadSize { + if m >= s.MaxPayloadSize { return } @@ -364,8 +276,8 @@ func (s *sender) updateMaxPayloadSize(mtu, count int) { m = 1 } - oldMSS := s.maxPayloadSize - s.maxPayloadSize = m + oldMSS := s.MaxPayloadSize + s.MaxPayloadSize = m if s.gso { s.ep.gso.MSS = uint16(m) } @@ -380,9 +292,9 @@ func (s *sender) updateMaxPayloadSize(mtu, count int) { // maxPayloadSize. 
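updateMaxPayloadSize above only ever shrinks the usable payload size and deducts the to-be-retransmitted packets from the outstanding count. A minimal sketch of that logic under illustrative names:

package main

import "fmt"

type sender struct {
	maxPayloadSize int // current MSS minus TCP option overhead
	outstanding    int // packets sent but not yet acknowledged
}

// updateMaxPayloadSize applies a new usable payload size m, where count is
// the number of in-flight packets that now exceed the MTU and will be
// retransmitted.
func (s *sender) updateMaxPayloadSize(m, count int) {
	if m >= s.maxPayloadSize {
		return // never adjust upwards here
	}
	if m <= 0 {
		m = 1 // an MTU smaller than our headers still permits 1 byte
	}
	s.maxPayloadSize = m
	s.outstanding -= count
	if s.outstanding < 0 {
		s.outstanding = 0
	}
}

func main() {
	s := sender{maxPayloadSize: 1460, outstanding: 10}
	s.updateMaxPayloadSize(1200, 3)
	fmt.Println(s.maxPayloadSize, s.outstanding) // 1200 7
}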
s.ep.scoreboard.smss = uint16(m) - s.outstanding -= count - if s.outstanding < 0 { - s.outstanding = 0 + s.Outstanding -= count + if s.Outstanding < 0 { + s.Outstanding = 0 } // Rewind writeNext to the first segment exceeding the MTU. Do nothing @@ -401,10 +313,10 @@ func (s *sender) updateMaxPayloadSize(mtu, count int) { nextSeg = seg } - if s.ep.sackPermitted && s.ep.scoreboard.IsSACKED(seg.sackBlock()) { + if s.ep.SACKPermitted && s.ep.scoreboard.IsSACKED(seg.sackBlock()) { // Update sackedOut for new maximum payload size. - s.sackedOut -= s.pCount(seg, oldMSS) - s.sackedOut += s.pCount(seg, s.maxPayloadSize) + s.SackedOut -= s.pCount(seg, oldMSS) + s.SackedOut += s.pCount(seg, s.MaxPayloadSize) } } @@ -416,32 +328,32 @@ func (s *sender) updateMaxPayloadSize(mtu, count int) { // sendAck sends an ACK segment. func (s *sender) sendAck() { - s.sendSegmentFromView(buffer.VectorisedView{}, header.TCPFlagAck, s.sndNxt) + s.sendSegmentFromView(buffer.VectorisedView{}, header.TCPFlagAck, s.SndNxt) } // updateRTO updates the retransmit timeout when a new roud-trip time is // available. This is done in accordance with section 2 of RFC 6298. func (s *sender) updateRTO(rtt time.Duration) { s.rtt.Lock() - if !s.rtt.srttInited { - s.rtt.rttvar = rtt / 2 - s.rtt.srtt = rtt - s.rtt.srttInited = true + if !s.rtt.TCPRTTState.SRTTInited { + s.rtt.TCPRTTState.RTTVar = rtt / 2 + s.rtt.TCPRTTState.SRTT = rtt + s.rtt.TCPRTTState.SRTTInited = true } else { - diff := s.rtt.srtt - rtt + diff := s.rtt.TCPRTTState.SRTT - rtt if diff < 0 { diff = -diff } - // Use RFC6298 standard algorithm to update rttvar and srtt when + // Use RFC6298 standard algorithm to update TCPRTTState.RTTVar and TCPRTTState.SRTT when // no timestamps are available. - if !s.ep.sendTSOk { - s.rtt.rttvar = (3*s.rtt.rttvar + diff) / 4 - s.rtt.srtt = (7*s.rtt.srtt + rtt) / 8 + if !s.ep.SendTSOk { + s.rtt.TCPRTTState.RTTVar = (3*s.rtt.TCPRTTState.RTTVar + diff) / 4 + s.rtt.TCPRTTState.SRTT = (7*s.rtt.TCPRTTState.SRTT + rtt) / 8 } else { // When we are taking RTT measurements of every ACK then // we need to use a modified method as specified in // https://tools.ietf.org/html/rfc7323#appendix-G - if s.outstanding == 0 { + if s.Outstanding == 0 { s.rtt.Unlock() return } @@ -449,7 +361,7 @@ func (s *sender) updateRTO(rtt time.Duration) { // terms of packets and not bytes. This is similar to // how linux also does cwnd and inflight. In practice // this approximation works as expected. - expectedSamples := math.Ceil(float64(s.outstanding) / 2) + expectedSamples := math.Ceil(float64(s.Outstanding) / 2) // alpha & beta values are the original values as recommended in // https://tools.ietf.org/html/rfc6298#section-2.3. 
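updateRTO above follows RFC 6298 section 2: SRTT and RTTVar are smoothed with alpha = 1/8 and beta = 1/4, and RTO = SRTT + 4*RTTVar, clamped to a minimum. A runnable sketch of the classic path, without the RFC 7323 appendix G per-ACK scaling; minRTO here is an assumed local constant:

package main

import (
	"fmt"
	"time"
)

const minRTO = 200 * time.Millisecond

type rttState struct {
	srtt   time.Duration
	rttvar time.Duration
	inited bool
	rto    time.Duration
}

// sample folds one RTT measurement into the smoothed estimators and
// recomputes the retransmission timeout.
func (r *rttState) sample(rtt time.Duration) {
	if !r.inited {
		r.srtt = rtt
		r.rttvar = rtt / 2
		r.inited = true
	} else {
		diff := r.srtt - rtt
		if diff < 0 {
			diff = -diff
		}
		r.rttvar = (3*r.rttvar + diff) / 4 // beta = 1/4
		r.srtt = (7*r.srtt + rtt) / 8      // alpha = 1/8
	}
	r.rto = r.srtt + 4*r.rttvar
	if r.rto < minRTO {
		r.rto = minRTO
	}
}

func main() {
	var r rttState
	for _, rtt := range []time.Duration{100 * time.Millisecond, 120 * time.Millisecond} {
		r.sample(rtt)
	}
	fmt.Println(r.srtt, r.rttvar, r.rto) // 102.5ms 42.5ms 272.5ms
}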
@@ -458,17 +370,17 @@ func (s *sender) updateRTO(rtt time.Duration) { alphaPrime := alpha / expectedSamples betaPrime := beta / expectedSamples - rttVar := (1-betaPrime)*s.rtt.rttvar.Seconds() + betaPrime*diff.Seconds() - srtt := (1-alphaPrime)*s.rtt.srtt.Seconds() + alphaPrime*rtt.Seconds() - s.rtt.rttvar = time.Duration(rttVar * float64(time.Second)) - s.rtt.srtt = time.Duration(srtt * float64(time.Second)) + rttVar := (1-betaPrime)*s.rtt.TCPRTTState.RTTVar.Seconds() + betaPrime*diff.Seconds() + srtt := (1-alphaPrime)*s.rtt.TCPRTTState.SRTT.Seconds() + alphaPrime*rtt.Seconds() + s.rtt.TCPRTTState.RTTVar = time.Duration(rttVar * float64(time.Second)) + s.rtt.TCPRTTState.SRTT = time.Duration(srtt * float64(time.Second)) } } - s.rto = s.rtt.srtt + 4*s.rtt.rttvar + s.RTO = s.rtt.TCPRTTState.SRTT + 4*s.rtt.TCPRTTState.RTTVar s.rtt.Unlock() - if s.rto < s.minRTO { - s.rto = s.minRTO + if s.RTO < s.minRTO { + s.RTO = s.minRTO } } @@ -476,20 +388,20 @@ func (s *sender) updateRTO(rtt time.Duration) { func (s *sender) resendSegment() { // Don't use any segments we already sent to measure RTT as they may // have been affected by packets being lost. - s.rttMeasureSeqNum = s.sndNxt + s.RTTMeasureSeqNum = s.SndNxt // Resend the segment. if seg := s.writeList.Front(); seg != nil { - if seg.data.Size() > s.maxPayloadSize { - s.splitSeg(seg, s.maxPayloadSize) + if seg.data.Size() > s.MaxPayloadSize { + s.splitSeg(seg, s.MaxPayloadSize) } // See: RFC 6675 section 5 Step 4.3 // // To prevent retransmission, set both the HighRXT and RescueRXT // to the highest sequence number in the retransmitted segment. - s.fr.highRxt = seg.sequenceNumber.Add(seqnum.Size(seg.data.Size())) - 1 - s.fr.rescueRxt = seg.sequenceNumber.Add(seqnum.Size(seg.data.Size())) - 1 + s.FastRecovery.HighRxt = seg.sequenceNumber.Add(seqnum.Size(seg.data.Size())) - 1 + s.FastRecovery.RescueRxt = seg.sequenceNumber.Add(seqnum.Size(seg.data.Size())) - 1 s.sendSegment(seg) s.ep.stack.Stats().TCP.FastRetransmit.Increment() s.ep.stats.SendErrors.FastRetransmit.Increment() @@ -554,15 +466,15 @@ func (s *sender) retransmitTimerExpired() bool { // Set new timeout. The timer will be restarted by the call to sendData // below. - s.rto *= 2 + s.RTO *= 2 // Cap the RTO as per RFC 1122 4.2.3.1, RFC 6298 5.5 - if s.rto > s.maxRTO { - s.rto = s.maxRTO + if s.RTO > s.maxRTO { + s.RTO = s.maxRTO } // Cap RTO to remaining time. - if s.rto > remaining { - s.rto = remaining + if s.RTO > remaining { + s.RTO = remaining } // See: https://tools.ietf.org/html/rfc6582#section-3.2 Step 4. @@ -571,9 +483,9 @@ func (s *sender) retransmitTimerExpired() bool { // After a retransmit timeout, record the highest sequence number // transmitted in the variable recover, and exit the fast recovery // procedure if applicable. - s.fr.last = s.sndNxt - 1 + s.FastRecovery.Last = s.SndNxt - 1 - if s.fr.active { + if s.FastRecovery.Active { // We were attempting fast recovery but were not successful. // Leave the state. We don't need to update ssthresh because it // has already been updated when entered fast-recovery. @@ -589,7 +501,7 @@ func (s *sender) retransmitTimerExpired() bool { // // We'll keep on transmitting (or retransmitting) as we get acks for // the data we transmit. - s.outstanding = 0 + s.Outstanding = 0 // Expunge all SACK information as per https://tools.ietf.org/html/rfc6675#section-5.1 // @@ -663,7 +575,7 @@ func (s *sender) splitSeg(seg *segment, size int) { // window space. 
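retransmitTimerExpired above doubles the RTO on each expiry and then clamps it twice: to the stack-wide maximum and to the time left before the connection gives up. A small sketch of that backoff, with the caps as parameters:

package main

import (
	"fmt"
	"time"
)

// backoffRTO returns the next retransmission timeout.
func backoffRTO(rto, maxRTO, remaining time.Duration) time.Duration {
	rto *= 2 // exponential backoff, RFC 6298 section 5.5
	if rto > maxRTO {
		rto = maxRTO // cap per RFC 1122 4.2.3.1
	}
	if rto > remaining {
		rto = remaining // never sleep past the user timeout
	}
	return rto
}

func main() {
	rto := 400 * time.Millisecond
	for i := 0; i < 4; i++ {
		rto = backoffRTO(rto, 60*time.Second, 5*time.Second)
		fmt.Println(rto) // 800ms, 1.6s, 3.2s, 5s
	}
}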
// ref: net/ipv4/tcp_output.c::tcp_write_xmit(), tcp_mss_split_point() // ref: net/ipv4/tcp_output.c::tcp_write_wakeup(), tcp_snd_wnd_test() - if seg.data.Size() > s.maxPayloadSize { + if seg.data.Size() > s.MaxPayloadSize { seg.flags ^= header.TCPFlagPsh } @@ -689,7 +601,7 @@ func (s *sender) NextSeg(nextSegHint *segment) (nextSeg, hint *segment, rescueRt // transmitted (i.e. either it has no assigned sequence number // or if it does have one, it's >= the next sequence number // to be sent [i.e. >= s.sndNxt]). - if !s.isAssignedSequenceNumber(seg) || s.sndNxt.LessThanEq(seg.sequenceNumber) { + if !s.isAssignedSequenceNumber(seg) || s.SndNxt.LessThanEq(seg.sequenceNumber) { hint = nil break } @@ -710,7 +622,7 @@ func (s *sender) NextSeg(nextSegHint *segment) (nextSeg, hint *segment, rescueRt // (1.a) S2 is greater than HighRxt // (1.b) S2 is less than highest octect covered by // any received SACK. - if s.fr.highRxt.LessThan(segSeq) && segSeq.LessThan(s.ep.scoreboard.maxSACKED) { + if s.FastRecovery.HighRxt.LessThan(segSeq) && segSeq.LessThan(s.ep.scoreboard.maxSACKED) { // NextSeg(): // (1.c) IsLost(S2) returns true. if s.ep.scoreboard.IsLost(segSeq) { @@ -743,7 +655,7 @@ func (s *sender) NextSeg(nextSegHint *segment) (nextSeg, hint *segment, rescueRt // unSACKed sequence number SHOULD be returned, and // RescueRxt set to RecoveryPoint. HighRxt MUST NOT // be updated. - if s.fr.rescueRxt.LessThan(s.sndUna - 1) { + if s.FastRecovery.RescueRxt.LessThan(s.SndUna - 1) { if s4 != nil { if s4.sequenceNumber.LessThan(segSeq) { s4 = seg @@ -763,7 +675,7 @@ func (s *sender) NextSeg(nextSegHint *segment) (nextSeg, hint *segment, rescueRt // previously unsent data starting with sequence number // HighData+1 MUST be returned." for seg := s.writeNext; seg != nil; seg = seg.Next() { - if s.isAssignedSequenceNumber(seg) && seg.sequenceNumber.LessThan(s.sndNxt) { + if s.isAssignedSequenceNumber(seg) && seg.sequenceNumber.LessThan(s.SndNxt) { continue } // We do not split the segment here to <= smss as it has @@ -788,7 +700,7 @@ func (s *sender) maybeSendSegment(seg *segment, limit int, end seqnum.Value) (se if !s.isAssignedSequenceNumber(seg) { // Merge segments if allowed. if seg.data.Size() != 0 { - available := int(s.sndNxt.Size(end)) + available := int(s.SndNxt.Size(end)) if available > limit { available = limit } @@ -816,7 +728,7 @@ func (s *sender) maybeSendSegment(seg *segment, limit int, end seqnum.Value) (se } if !nextTooBig && seg.data.Size() < available { // Segment is not full. - if s.outstanding > 0 && s.ep.ops.GetDelayOption() { + if s.Outstanding > 0 && s.ep.ops.GetDelayOption() { // Nagle's algorithm. From Wikipedia: // Nagle's algorithm works by // combining a number of small @@ -835,7 +747,7 @@ func (s *sender) maybeSendSegment(seg *segment, limit int, end seqnum.Value) (se // send space and MSS. // TODO(gvisor.dev/issue/2833): Drain the held segments after a // timeout. - if seg.data.Size() < s.maxPayloadSize && s.ep.ops.GetCorkOption() { + if seg.data.Size() < s.MaxPayloadSize && s.ep.ops.GetCorkOption() { return false } } @@ -843,7 +755,7 @@ func (s *sender) maybeSendSegment(seg *segment, limit int, end seqnum.Value) (se // Assign flags. We don't do it above so that we can merge // additional data if Nagle holds the segment. 
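maybeSendSegment above holds back small segments under Nagle's algorithm or TCP_CORK. A sketch of just that decision, with illustrative parameter names:

package main

import "fmt"

// holdSmallSegment reports whether a partially filled segment should be
// delayed rather than transmitted immediately.
func holdSmallSegment(segSize, mss, outstanding int, delay, cork bool) bool {
	if segSize >= mss {
		return false // full-sized segments always go out
	}
	if delay && outstanding > 0 {
		return true // Nagle: wait until the in-flight data is acked
	}
	if cork && segSize < mss {
		return true // cork: hold until a full segment accumulates
	}
	return false
}

func main() {
	fmt.Println(holdSmallSegment(100, 1460, 3, true, false))  // true: small write, data in flight
	fmt.Println(holdSmallSegment(100, 1460, 0, true, false))  // false: nothing outstanding
	fmt.Println(holdSmallSegment(1460, 1460, 3, true, false)) // false: full segment
}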
- seg.sequenceNumber = s.sndNxt + seg.sequenceNumber = s.SndNxt seg.flags = header.TCPFlagAck | header.TCPFlagPsh } @@ -893,12 +805,12 @@ func (s *sender) maybeSendSegment(seg *segment, limit int, end seqnum.Value) (se // the segment right here if there are no pending segments. If // there are pending segments, segment transmits are deferred to // the retransmit timer handler. - if s.sndUna != s.sndNxt { + if s.SndUna != s.SndNxt { switch { case available >= seg.data.Size(): // OK to send, the whole segments fits in the // receiver's advertised window. - case available >= s.maxPayloadSize: + case available >= s.MaxPayloadSize: // OK to send, at least 1 MSS sized segment fits // in the receiver's advertised window. default: @@ -918,8 +830,8 @@ func (s *sender) maybeSendSegment(seg *segment, limit int, end seqnum.Value) (se // If GSO is not in use then cap available to // maxPayloadSize. When GSO is in use the gVisor GSO logic or // the host GSO logic will cap the segment to the correct size. - if s.ep.gso == nil && available > s.maxPayloadSize { - available = s.maxPayloadSize + if s.ep.gso.Type == stack.GSONone && available > s.MaxPayloadSize { + available = s.MaxPayloadSize } if seg.data.Size() > available { @@ -933,8 +845,8 @@ func (s *sender) maybeSendSegment(seg *segment, limit int, end seqnum.Value) (se // Update sndNxt if we actually sent new data (as opposed to // retransmitting some previously sent data). - if s.sndNxt.LessThan(segEnd) { - s.sndNxt = segEnd + if s.SndNxt.LessThan(segEnd) { + s.SndNxt = segEnd } return true @@ -945,9 +857,9 @@ func (s *sender) sendZeroWindowProbe() { s.unackZeroWindowProbes++ // Send a zero window probe with sequence number pointing to // the last acknowledged byte. - s.ep.sendRaw(buffer.VectorisedView{}, header.TCPFlagAck, s.sndUna-1, ack, win) + s.ep.sendRaw(buffer.VectorisedView{}, header.TCPFlagAck, s.SndUna-1, ack, win) // Rearm the timer to continue probing. - s.resendTimer.enable(s.rto) + s.resendTimer.enable(s.RTO) } func (s *sender) enableZeroWindowProbing() { @@ -958,7 +870,7 @@ func (s *sender) enableZeroWindowProbing() { if s.firstRetransmittedSegXmitTime.IsZero() { s.firstRetransmittedSegXmitTime = time.Now() } - s.resendTimer.enable(s.rto) + s.resendTimer.enable(s.RTO) } func (s *sender) disableZeroWindowProbing() { @@ -978,12 +890,12 @@ func (s *sender) postXmit(dataSent bool, shouldScheduleProbe bool) { // If the sender has advertized zero receive window and we have // data to be sent out, start zero window probing to query the // the remote for it's receive window size. - if s.writeNext != nil && s.sndWnd == 0 { + if s.writeNext != nil && s.SndWnd == 0 { s.enableZeroWindowProbing() } // If we have no more pending data, start the keepalive timer. - if s.sndUna == s.sndNxt { + if s.SndUna == s.SndNxt { s.ep.resetKeepaliveTimer(false) } else { // Enable timers if we have pending data. @@ -992,10 +904,10 @@ func (s *sender) postXmit(dataSent bool, shouldScheduleProbe bool) { s.schedulePTO() } else if !s.resendTimer.enabled() { s.probeTimer.disable() - if s.outstanding > 0 { + if s.Outstanding > 0 { // Enable the resend timer if it's not enabled yet and there is // outstanding data. - s.resendTimer.enable(s.rto) + s.resendTimer.enable(s.RTO) } } } @@ -1004,29 +916,29 @@ func (s *sender) postXmit(dataSent bool, shouldScheduleProbe bool) { // sendData sends new data segments. It is called when data becomes available or // when the send window opens up. 
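The zero-window probing code above repeatedly sends a probe carrying sequence number SndUna-1, rearmed with the RTO, until the peer reopens its window. A loose, illustrative sketch of that flow (the types and timing here are simplified and are not netstack's):

package main

import (
	"fmt"
	"time"
)

type prober struct {
	sndUna     uint32        // oldest unacknowledged sequence number
	rto        time.Duration // current retransmission timeout
	probesSent int
	windowOpen bool
}

// maybeProbe sends one zero-window probe and returns the delay before the
// next one; it stops once the window has reopened.
func (p *prober) maybeProbe() (time.Duration, bool) {
	if p.windowOpen {
		return 0, false
	}
	p.probesSent++
	// A real implementation sends an ACK segment with seq = sndUna-1 here,
	// which forces the peer to report its current receive window.
	fmt.Printf("probe %d at seq %d\n", p.probesSent, p.sndUna-1)
	return p.rto, true
}

func main() {
	p := prober{sndUna: 5001, rto: 400 * time.Millisecond}
	for i := 0; i < 2; i++ {
		if delay, ok := p.maybeProbe(); ok {
			fmt.Println("next probe in", delay)
		}
	}
	p.windowOpen = true // an incoming ACK advertised a non-zero window
	_, ok := p.maybeProbe()
	fmt.Println("still probing:", ok) // false
}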
func (s *sender) sendData() { - limit := s.maxPayloadSize + limit := s.MaxPayloadSize if s.gso { limit = int(s.ep.gso.MaxSize - header.TCPHeaderMaximumSize) } - end := s.sndUna.Add(s.sndWnd) + end := s.SndUna.Add(s.SndWnd) // Reduce the congestion window to min(IW, cwnd) per RFC 5681, page 10. // "A TCP SHOULD set cwnd to no more than RW before beginning // transmission if the TCP has not sent data in the interval exceeding // the retrasmission timeout." - if !s.fr.active && s.state != tcpip.RTORecovery && time.Now().Sub(s.lastSendTime) > s.rto { - if s.sndCwnd > InitialCwnd { - s.sndCwnd = InitialCwnd + if !s.FastRecovery.Active && s.state != tcpip.RTORecovery && time.Now().Sub(s.LastSendTime) > s.RTO { + if s.SndCwnd > InitialCwnd { + s.SndCwnd = InitialCwnd } } var dataSent bool - for seg := s.writeNext; seg != nil && s.outstanding < s.sndCwnd; seg = seg.Next() { - cwndLimit := (s.sndCwnd - s.outstanding) * s.maxPayloadSize + for seg := s.writeNext; seg != nil && s.Outstanding < s.SndCwnd; seg = seg.Next() { + cwndLimit := (s.SndCwnd - s.Outstanding) * s.MaxPayloadSize if cwndLimit < limit { limit = cwndLimit } - if s.isAssignedSequenceNumber(seg) && s.ep.sackPermitted && s.ep.scoreboard.IsSACKED(seg.sackBlock()) { + if s.isAssignedSequenceNumber(seg) && s.ep.SACKPermitted && s.ep.scoreboard.IsSACKED(seg.sackBlock()) { // Move writeNext along so that we don't try and scan data that // has already been SACKED. s.writeNext = seg.Next() @@ -1036,7 +948,7 @@ func (s *sender) sendData() { break } dataSent = true - s.outstanding += s.pCount(seg, s.maxPayloadSize) + s.Outstanding += s.pCount(seg, s.MaxPayloadSize) s.writeNext = seg.Next() } @@ -1044,21 +956,21 @@ func (s *sender) sendData() { } func (s *sender) enterRecovery() { - s.fr.active = true + s.FastRecovery.Active = true // Save state to reflect we're now in fast recovery. // // See : https://tools.ietf.org/html/rfc5681#section-3.2 Step 3. // We inflate the cwnd by 3 to account for the 3 packets which triggered // the 3 duplicate ACKs and are now not in flight. - s.sndCwnd = s.sndSsthresh + 3 - s.sackedOut = 0 - s.dupAckCount = 0 - s.fr.first = s.sndUna - s.fr.last = s.sndNxt - 1 - s.fr.maxCwnd = s.sndCwnd + s.outstanding - s.fr.highRxt = s.sndUna - s.fr.rescueRxt = s.sndUna - if s.ep.sackPermitted { + s.SndCwnd = s.Ssthresh + 3 + s.SackedOut = 0 + s.DupAckCount = 0 + s.FastRecovery.First = s.SndUna + s.FastRecovery.Last = s.SndNxt - 1 + s.FastRecovery.MaxCwnd = s.SndCwnd + s.Outstanding + s.FastRecovery.HighRxt = s.SndUna + s.FastRecovery.RescueRxt = s.SndUna + if s.ep.SACKPermitted { s.state = tcpip.SACKRecovery s.ep.stack.Stats().TCP.SACKRecovery.Increment() // Set TLPRxtOut to false according to @@ -1075,12 +987,12 @@ func (s *sender) enterRecovery() { } func (s *sender) leaveRecovery() { - s.fr.active = false - s.fr.maxCwnd = 0 - s.dupAckCount = 0 + s.FastRecovery.Active = false + s.FastRecovery.MaxCwnd = 0 + s.DupAckCount = 0 // Deflate cwnd. It had been artificially inflated when new dups arrived. - s.sndCwnd = s.sndSsthresh + s.SndCwnd = s.Ssthresh s.cc.PostRecovery() } @@ -1099,7 +1011,7 @@ func (s *sender) isAssignedSequenceNumber(seg *segment) bool { func (s *sender) SetPipe() { // If SACK isn't permitted or it is permitted but recovery is not active // then ignore pipe calculations. 
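enterRecovery and leaveRecovery above set cwnd to ssthresh+3 on entry (RFC 5681 section 3.2 step 3) and deflate it back to ssthresh on exit, with MaxCwnd bounding later inflation. A sketch combining those steps with the ssthresh halving from reno.go; field names are local to this example:

package main

import "fmt"

type cc struct {
	cwnd, ssthresh, outstanding int
	inRecovery                  bool
	maxCwnd                     int
}

func (c *cc) enterRecovery() {
	c.inRecovery = true
	// Halve ssthresh based on the amount of data in flight (RFC 5681 eq. 4).
	c.ssthresh = c.outstanding / 2
	if c.ssthresh < 2 {
		c.ssthresh = 2
	}
	// Inflate by the three segments that generated the duplicate ACKs.
	c.cwnd = c.ssthresh + 3
	// Cap artificial inflation from further duplicate ACKs.
	c.maxCwnd = c.cwnd + c.outstanding
}

func (c *cc) leaveRecovery() {
	c.inRecovery = false
	c.maxCwnd = 0
	c.cwnd = c.ssthresh // deflate the window
}

func main() {
	c := cc{cwnd: 20, outstanding: 20}
	c.enterRecovery()
	fmt.Println(c.cwnd, c.ssthresh, c.maxCwnd) // 13 10 33
	c.leaveRecovery()
	fmt.Println(c.cwnd) // 10
}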
- if !s.ep.sackPermitted || !s.fr.active { + if !s.ep.SACKPermitted || !s.FastRecovery.Active { return } pipe := 0 @@ -1119,7 +1031,7 @@ func (s *sender) SetPipe() { // After initializing pipe to zero, the following steps are // taken for each octet 'S1' in the sequence space between // HighACK and HighData that has not been SACKed: - if !s1.sequenceNumber.LessThan(s.sndNxt) { + if !s1.sequenceNumber.LessThan(s.SndNxt) { break } if s.ep.scoreboard.IsSACKED(sb) { @@ -1138,20 +1050,20 @@ func (s *sender) SetPipe() { } // SetPipe(): // (b) If S1 <= HighRxt, Pipe is incremented by 1. - if s1.sequenceNumber.LessThanEq(s.fr.highRxt) { + if s1.sequenceNumber.LessThanEq(s.FastRecovery.HighRxt) { pipe++ } } } - s.outstanding = pipe + s.Outstanding = pipe } // shouldEnterRecovery returns true if the sender should enter fast recovery // based on dupAck count and sack scoreboard. // See RFC 6675 section 5. func (s *sender) shouldEnterRecovery() bool { - return s.dupAckCount >= nDupAckThreshold || - (s.ep.sackPermitted && s.ep.tcpRecovery&tcpip.TCPRACKLossDetection == 0 && s.ep.scoreboard.IsLost(s.sndUna)) + return s.DupAckCount >= nDupAckThreshold || + (s.ep.SACKPermitted && s.ep.tcpRecovery&tcpip.TCPRACKLossDetection == 0 && s.ep.scoreboard.IsLost(s.SndUna)) } // detectLoss is called when an ack is received and returns whether a loss is @@ -1163,24 +1075,24 @@ func (s *sender) detectLoss(seg *segment) (fastRetransmit bool) { // If RACK is enabled and there is no reordering we should honor the // three duplicate ACK rule to enter recovery. // See: https://tools.ietf.org/html/draft-ietf-tcpm-rack-08#section-4 - if s.ep.sackPermitted && s.ep.tcpRecovery&tcpip.TCPRACKLossDetection != 0 { - if s.rc.reorderSeen { + if s.ep.SACKPermitted && s.ep.tcpRecovery&tcpip.TCPRACKLossDetection != 0 { + if s.rc.Reord { return false } } if !s.isDupAck(seg) { - s.dupAckCount = 0 + s.DupAckCount = 0 return false } - s.dupAckCount++ + s.DupAckCount++ // Do not enter fast recovery until we reach nDupAckThreshold or the // first unacknowledged byte is considered lost as per SACK scoreboard. if !s.shouldEnterRecovery() { // RFC 6675 Step 3. - s.fr.highRxt = s.sndUna - 1 + s.FastRecovery.HighRxt = s.SndUna - 1 // Do run SetPipe() to calculate the outstanding segments. s.SetPipe() s.state = tcpip.Disorder @@ -1196,8 +1108,8 @@ func (s *sender) detectLoss(seg *segment) (fastRetransmit bool) { // Note that we only enter recovery when at least one more byte of data // beyond s.fr.last (the highest byte that was outstanding when fast // retransmit was last entered) is acked. - if !s.fr.last.LessThan(seg.ackNumber - 1) { - s.dupAckCount = 0 + if !s.FastRecovery.Last.LessThan(seg.ackNumber - 1) { + s.DupAckCount = 0 return false } s.cc.HandleLossDetected() @@ -1212,22 +1124,22 @@ func (s *sender) isDupAck(seg *segment) bool { // can leverage the SACK information to determine when an incoming ACK is a // "duplicate" (e.g., if the ACK contains previously unknown SACK // information). - if s.ep.sackPermitted && !seg.hasNewSACKInfo { + if s.ep.SACKPermitted && !seg.hasNewSACKInfo { return false } // (a) The receiver of the ACK has outstanding data. - return s.sndUna != s.sndNxt && + return s.SndUna != s.SndNxt && // (b) The incoming acknowledgment carries no data. seg.logicalLen() == 0 && // (c) The SYN and FIN bits are both off. !seg.flagIsSet(header.TCPFlagFin) && !seg.flagIsSet(header.TCPFlagSyn) && // (d) the ACK number is equal to the greatest acknowledgment received on // the given connection (TCP.UNA from RFC793). 
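isDupAck above is a direct restatement of the RFC 5681 duplicate-ACK criteria plus the RFC 6675 refinement that, with SACK enabled, only ACKs carrying new SACK information count. The same test, stated standalone with illustrative types:

package main

import "fmt"

type ackSeg struct {
	ackNumber      uint32
	window         uint16
	payloadLen     int
	syn, fin       bool
	hasNewSACKInfo bool
}

type sndState struct {
	sndUna, sndNxt uint32
	sndWnd         uint16
	sackPermitted  bool
}

func isDupAck(s sndState, seg ackSeg) bool {
	if s.sackPermitted && !seg.hasNewSACKInfo {
		return false
	}
	return s.sndUna != s.sndNxt && // (a) we have outstanding data
		seg.payloadLen == 0 && // (b) the ACK carries no data
		!seg.syn && !seg.fin && // (c) SYN and FIN are both off
		seg.ackNumber == s.sndUna && // (d) it acknowledges nothing new
		seg.window == s.sndWnd // (e) the advertised window is unchanged
}

func main() {
	s := sndState{sndUna: 100, sndNxt: 200, sndWnd: 65535}
	fmt.Println(isDupAck(s, ackSeg{ackNumber: 100, window: 65535})) // true
	fmt.Println(isDupAck(s, ackSeg{ackNumber: 150, window: 65535})) // false
}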
- seg.ackNumber == s.sndUna && + seg.ackNumber == s.SndUna && // (e) the advertised window in the incoming acknowledgment equals the // advertised window in the last incoming acknowledgment. - s.sndWnd == seg.window + s.SndWnd == seg.window } // Iterate the writeList and update RACK for each segment which is newly acked @@ -1267,7 +1179,7 @@ func (s *sender) walkSACK(rcvdSeg *segment) { s.rc.update(seg, rcvdSeg) s.rc.detectReorder(seg) seg.acked = true - s.sackedOut += s.pCount(seg, s.maxPayloadSize) + s.SackedOut += s.pCount(seg, s.MaxPayloadSize) } seg = seg.Next() } @@ -1322,18 +1234,18 @@ func checkDSACK(rcvdSeg *segment) bool { // updating the send-related state. func (s *sender) handleRcvdSegment(rcvdSeg *segment) { // Check if we can extract an RTT measurement from this ack. - if !rcvdSeg.parsedOptions.TS && s.rttMeasureSeqNum.LessThan(rcvdSeg.ackNumber) { - s.updateRTO(time.Now().Sub(s.rttMeasureTime)) - s.rttMeasureSeqNum = s.sndNxt + if !rcvdSeg.parsedOptions.TS && s.RTTMeasureSeqNum.LessThan(rcvdSeg.ackNumber) { + s.updateRTO(time.Now().Sub(s.RTTMeasureTime)) + s.RTTMeasureSeqNum = s.SndNxt } // Update Timestamp if required. See RFC7323, section-4.3. - if s.ep.sendTSOk && rcvdSeg.parsedOptions.TS { - s.ep.updateRecentTimestamp(rcvdSeg.parsedOptions.TSVal, s.maxSentAck, rcvdSeg.sequenceNumber) + if s.ep.SendTSOk && rcvdSeg.parsedOptions.TS { + s.ep.updateRecentTimestamp(rcvdSeg.parsedOptions.TSVal, s.MaxSentAck, rcvdSeg.sequenceNumber) } // Insert SACKBlock information into our scoreboard. - if s.ep.sackPermitted { + if s.ep.SACKPermitted { for _, sb := range rcvdSeg.parsedOptions.SACKBlocks { // Only insert the SACK block if the following holds // true: @@ -1347,7 +1259,7 @@ func (s *sender) handleRcvdSegment(rcvdSeg *segment) { // NOTE: This check specifically excludes DSACK blocks // which have start/end before sndUna and are used to // indicate spurious retransmissions. - if rcvdSeg.ackNumber.LessThan(sb.Start) && s.sndUna.LessThan(sb.Start) && sb.End.LessThanEq(s.sndNxt) && !s.ep.scoreboard.IsSACKED(sb) { + if rcvdSeg.ackNumber.LessThan(sb.Start) && s.SndUna.LessThan(sb.Start) && sb.End.LessThanEq(s.SndNxt) && !s.ep.scoreboard.IsSACKED(sb) { s.ep.scoreboard.Insert(sb) rcvdSeg.hasNewSACKInfo = true } @@ -1375,10 +1287,10 @@ func (s *sender) handleRcvdSegment(rcvdSeg *segment) { ack := rcvdSeg.ackNumber fastRetransmit := false // Do not leave fast recovery, if the ACK is out of range. - if s.fr.active { + if s.FastRecovery.Active { // Leave fast recovery if it acknowledges all the data covered by // this fast recovery session. - if (ack-1).InRange(s.sndUna, s.sndNxt) && s.fr.last.LessThan(ack) { + if (ack-1).InRange(s.SndUna, s.SndNxt) && s.FastRecovery.Last.LessThan(ack) { s.leaveRecovery() } } else { @@ -1392,28 +1304,28 @@ func (s *sender) handleRcvdSegment(rcvdSeg *segment) { } // Stash away the current window size. - s.sndWnd = rcvdSeg.window + s.SndWnd = rcvdSeg.window // Disable zero window probing if remote advertizes a non-zero receive // window. This can be with an ACK to the zero window probe (where the // acknumber refers to the already acknowledged byte) OR to any previously // unacknowledged segment. 
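The timestamp handling above takes an RTT sample from TSecr only when the ACK advances the left edge of the send window, at millisecond granularity (RFC 7323 section 4.1). A hedged sketch of that sampling; the clock and names are assumptions of this example:

package main

import (
	"fmt"
	"time"
)

// nowMS is the sender's 32-bit millisecond timestamp clock.
func nowMS(start time.Time) uint32 {
	return uint32(time.Since(start).Milliseconds())
}

// rttFromTSEcr returns the RTT sample for an ACK, and false if the ACK does
// not acknowledge new data or carries no echoed timestamp.
func rttFromTSEcr(acksNewData bool, tsEcr, now uint32) (time.Duration, bool) {
	if !acksNewData || tsEcr == 0 {
		return 0, false
	}
	return time.Duration(now-tsEcr) * time.Millisecond, true
}

func main() {
	start := time.Now().Add(-time.Second) // pretend the timestamp clock started 1s ago
	now := nowMS(start)
	if rtt, ok := rttFromTSEcr(true, now-42, now); ok {
		fmt.Println(rtt) // 42ms
	}
}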
if s.zeroWindowProbing && rcvdSeg.window > 0 && - (ack == s.sndUna || (ack-1).InRange(s.sndUna, s.sndNxt)) { + (ack == s.SndUna || (ack-1).InRange(s.SndUna, s.SndNxt)) { s.disableZeroWindowProbing() } // On receiving the ACK for the zero window probe, account for it and // skip trying to send any segment as we are still probing for // receive window to become non-zero. - if s.zeroWindowProbing && s.unackZeroWindowProbes > 0 && ack == s.sndUna { + if s.zeroWindowProbing && s.unackZeroWindowProbes > 0 && ack == s.SndUna { s.unackZeroWindowProbes-- return } // Ignore ack if it doesn't acknowledge any new data. - if (ack - 1).InRange(s.sndUna, s.sndNxt) { - s.dupAckCount = 0 + if (ack - 1).InRange(s.SndUna, s.SndNxt) { + s.DupAckCount = 0 // See : https://tools.ietf.org/html/rfc1323#section-3.3. // Specifically we should only update the RTO using TSEcr if the @@ -1423,7 +1335,7 @@ func (s *sender) handleRcvdSegment(rcvdSeg *segment) { // averaged RTT measurement only if the segment acknowledges // some new data, i.e., only if it advances the left edge of // the send window. - if s.ep.sendTSOk && rcvdSeg.parsedOptions.TSEcr != 0 { + if s.ep.SendTSOk && rcvdSeg.parsedOptions.TSEcr != 0 { // TSVal/Ecr values sent by Netstack are at a millisecond // granularity. elapsed := time.Duration(s.ep.timestamp()-rcvdSeg.parsedOptions.TSEcr) * time.Millisecond @@ -1438,12 +1350,12 @@ func (s *sender) handleRcvdSegment(rcvdSeg *segment) { // When an ack is received we must rearm the timer. // RFC 6298 5.3 s.probeTimer.disable() - s.resendTimer.enable(s.rto) + s.resendTimer.enable(s.RTO) } // Remove all acknowledged data from the write list. - acked := s.sndUna.Size(ack) - s.sndUna = ack + acked := s.SndUna.Size(ack) + s.SndUna = ack // The remote ACK-ing at least 1 byte is an indication that we have a // full-duplex connection to the remote as the only way we will receive an @@ -1457,7 +1369,7 @@ func (s *sender) handleRcvdSegment(rcvdSeg *segment) { } ackLeft := acked - originalOutstanding := s.outstanding + originalOutstanding := s.Outstanding for ackLeft > 0 { // We use logicalLen here because we can have FIN // segments (which are always at the end of list) that @@ -1466,10 +1378,10 @@ func (s *sender) handleRcvdSegment(rcvdSeg *segment) { datalen := seg.logicalLen() if datalen > ackLeft { - prevCount := s.pCount(seg, s.maxPayloadSize) + prevCount := s.pCount(seg, s.MaxPayloadSize) seg.data.TrimFront(int(ackLeft)) seg.sequenceNumber.UpdateForward(ackLeft) - s.outstanding -= prevCount - s.pCount(seg, s.maxPayloadSize) + s.Outstanding -= prevCount - s.pCount(seg, s.MaxPayloadSize) break } @@ -1478,7 +1390,7 @@ func (s *sender) handleRcvdSegment(rcvdSeg *segment) { } // Update the RACK fields if SACK is enabled. - if s.ep.sackPermitted && !seg.acked && s.ep.tcpRecovery&tcpip.TCPRACKLossDetection != 0 { + if s.ep.SACKPermitted && !seg.acked && s.ep.tcpRecovery&tcpip.TCPRACKLossDetection != 0 { s.rc.update(seg, rcvdSeg) s.rc.detectReorder(seg) } @@ -1488,10 +1400,10 @@ func (s *sender) handleRcvdSegment(rcvdSeg *segment) { // If SACK is enabled then only reduce outstanding if // the segment was not previously SACKED as these have // already been accounted for in SetPipe(). 
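The ACK-processing loop above pops fully acknowledged segments off the write list and trims a partially acknowledged one in place before adjusting the outstanding count. A simplified sketch of that trimming, ignoring SACK and per-segment packet counting:

package main

import "fmt"

type seg struct{ seq, length uint32 }

// ackData consumes acked bytes from the head of the queue and returns the
// remaining queue plus the number of whole segments removed.
func ackData(queue []seg, acked uint32) ([]seg, int) {
	removed := 0
	for acked > 0 && len(queue) > 0 {
		head := &queue[0]
		if head.length > acked {
			// Partially acknowledged: trim the front and stop.
			head.seq += acked
			head.length -= acked
			break
		}
		acked -= head.length
		queue = queue[1:]
		removed++
	}
	return queue, removed
}

func main() {
	q := []seg{{seq: 100, length: 100}, {seq: 200, length: 100}}
	q, removed := ackData(q, 150)
	fmt.Println(removed, q) // 1 [{250 50}]
}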
- if !s.ep.sackPermitted || !s.ep.scoreboard.IsSACKED(seg.sackBlock()) { - s.outstanding -= s.pCount(seg, s.maxPayloadSize) + if !s.ep.SACKPermitted || !s.ep.scoreboard.IsSACKED(seg.sackBlock()) { + s.Outstanding -= s.pCount(seg, s.MaxPayloadSize) } else { - s.sackedOut -= s.pCount(seg, s.maxPayloadSize) + s.SackedOut -= s.pCount(seg, s.MaxPayloadSize) } seg.decRef() ackLeft -= datalen @@ -1501,13 +1413,13 @@ func (s *sender) handleRcvdSegment(rcvdSeg *segment) { s.ep.updateSndBufferUsage(int(acked)) // Clear SACK information for all acked data. - s.ep.scoreboard.Delete(s.sndUna) + s.ep.scoreboard.Delete(s.SndUna) // If we are not in fast recovery then update the congestion // window based on the number of acknowledged packets. - if !s.fr.active { - s.cc.Update(originalOutstanding - s.outstanding) - if s.fr.last.LessThan(s.sndUna) { + if !s.FastRecovery.Active { + s.cc.Update(originalOutstanding - s.Outstanding) + if s.FastRecovery.Last.LessThan(s.SndUna) { s.state = tcpip.Open // Update RACK when we are exiting fast or RTO // recovery as described in the RFC @@ -1522,16 +1434,16 @@ func (s *sender) handleRcvdSegment(rcvdSeg *segment) { // It is possible for s.outstanding to drop below zero if we get // a retransmit timeout, reset outstanding to zero but later // get an ack that cover previously sent data. - if s.outstanding < 0 { - s.outstanding = 0 + if s.Outstanding < 0 { + s.Outstanding = 0 } s.SetPipe() // If all outstanding data was acknowledged the disable the timer. // RFC 6298 Rule 5.3 - if s.sndUna == s.sndNxt { - s.outstanding = 0 + if s.SndUna == s.SndNxt { + s.Outstanding = 0 // Reset firstRetransmittedSegXmitTime to the zero value. s.firstRetransmittedSegXmitTime = time.Time{} s.resendTimer.disable() @@ -1539,7 +1451,7 @@ func (s *sender) handleRcvdSegment(rcvdSeg *segment) { } } - if s.ep.sackPermitted && s.ep.tcpRecovery&tcpip.TCPRACKLossDetection != 0 { + if s.ep.SACKPermitted && s.ep.tcpRecovery&tcpip.TCPRACKLossDetection != 0 { // Update RACK reorder window. // See: https://tools.ietf.org/html/draft-ietf-tcpm-rack-08#section-7.2 // * Upon receiving an ACK: @@ -1549,7 +1461,7 @@ func (s *sender) handleRcvdSegment(rcvdSeg *segment) { // After the reorder window is calculated, detect any loss by checking // if the time elapsed after the segments are sent is greater than the // reorder window. - if numLost := s.rc.detectLoss(rcvdSeg.rcvdTime); numLost > 0 && !s.fr.active { + if numLost := s.rc.detectLoss(rcvdSeg.rcvdTime); numLost > 0 && !s.FastRecovery.Active { // If any segment is marked as lost by // RACK, enter recovery and retransmit // the lost segments. @@ -1558,19 +1470,19 @@ func (s *sender) handleRcvdSegment(rcvdSeg *segment) { fastRetransmit = true } - if s.fr.active { + if s.FastRecovery.Active { s.rc.DoRecovery(nil, fastRetransmit) } } // Now that we've popped all acknowledged data from the retransmit // queue, retransmit if needed. - if s.fr.active && s.ep.tcpRecovery&tcpip.TCPRACKLossDetection == 0 { + if s.FastRecovery.Active && s.ep.tcpRecovery&tcpip.TCPRACKLossDetection == 0 { s.lr.DoRecovery(rcvdSeg, fastRetransmit) // When SACK is enabled data sending is governed by steps in // RFC 6675 Section 5 recovery steps A-C. // See: https://tools.ietf.org/html/rfc6675#section-5. 
- if s.ep.sackPermitted { + if s.ep.SACKPermitted { return } } @@ -1587,7 +1499,7 @@ func (s *sender) sendSegment(seg *segment) tcpip.Error { if seg.xmitCount > 0 { s.ep.stack.Stats().TCP.Retransmits.Increment() s.ep.stats.SendErrors.Retransmits.Increment() - if s.sndCwnd < s.sndSsthresh { + if s.SndCwnd < s.Ssthresh { s.ep.stack.Stats().TCP.SlowStartRetransmits.Increment() } } @@ -1601,11 +1513,11 @@ func (s *sender) sendSegment(seg *segment) tcpip.Error { // then use the conservative timer described in RFC6675 Section 6.0, // otherwise follow the standard time described in RFC6298 Section 5.1. if err != nil && seg.data.Size() != 0 { - if s.fr.active && seg.xmitCount > 1 && s.ep.sackPermitted { - s.resendTimer.enable(s.rto) + if s.FastRecovery.Active && seg.xmitCount > 1 && s.ep.SACKPermitted { + s.resendTimer.enable(s.RTO) } else { if !s.resendTimer.enabled() { - s.resendTimer.enable(s.rto) + s.resendTimer.enable(s.RTO) } } } @@ -1616,15 +1528,15 @@ func (s *sender) sendSegment(seg *segment) tcpip.Error { // sendSegmentFromView sends a new segment containing the given payload, flags // and sequence number. func (s *sender) sendSegmentFromView(data buffer.VectorisedView, flags header.TCPFlags, seq seqnum.Value) tcpip.Error { - s.lastSendTime = time.Now() - if seq == s.rttMeasureSeqNum { - s.rttMeasureTime = s.lastSendTime + s.LastSendTime = time.Now() + if seq == s.RTTMeasureSeqNum { + s.RTTMeasureTime = s.LastSendTime } rcvNxt, rcvWnd := s.ep.rcv.getSendParams() // Remember the max sent ack. - s.maxSentAck = rcvNxt + s.MaxSentAck = rcvNxt return s.ep.sendRaw(data, flags, seq, rcvNxt, rcvWnd) } diff --git a/pkg/tcpip/transport/tcp/snd_state.go b/pkg/tcpip/transport/tcp/snd_state.go index ba41cff6d..2f805d8ce 100644 --- a/pkg/tcpip/transport/tcp/snd_state.go +++ b/pkg/tcpip/transport/tcp/snd_state.go @@ -24,26 +24,6 @@ type unixTime struct { nano int64 } -// saveLastSendTime is invoked by stateify. -func (s *sender) saveLastSendTime() unixTime { - return unixTime{s.lastSendTime.Unix(), s.lastSendTime.UnixNano()} -} - -// loadLastSendTime is invoked by stateify. -func (s *sender) loadLastSendTime(unix unixTime) { - s.lastSendTime = time.Unix(unix.second, unix.nano) -} - -// saveRttMeasureTime is invoked by stateify. -func (s *sender) saveRttMeasureTime() unixTime { - return unixTime{s.rttMeasureTime.Unix(), s.rttMeasureTime.UnixNano()} -} - -// loadRttMeasureTime is invoked by stateify. -func (s *sender) loadRttMeasureTime(unix unixTime) { - s.rttMeasureTime = time.Unix(unix.second, unix.nano) -} - // afterLoad is invoked by stateify. func (s *sender) afterLoad() { s.resendTimer.init(&s.resendWaker) diff --git a/pkg/tcpip/transport/tcp/tcp_rack_test.go b/pkg/tcpip/transport/tcp/tcp_rack_test.go index 5cdd5b588..c58361bc1 100644 --- a/pkg/tcpip/transport/tcp/tcp_rack_test.go +++ b/pkg/tcpip/transport/tcp/tcp_rack_test.go @@ -33,6 +33,7 @@ const ( tsOptionSize = 12 maxTCPOptionSize = 40 mtu = header.TCPMinimumSize + header.IPv4MinimumSize + maxTCPOptionSize + maxPayload + latency = 5 * time.Millisecond ) func setStackRACKPermitted(t *testing.T, c *context.Context) { @@ -182,6 +183,9 @@ func sendAndReceiveWithSACK(t *testing.T, c *context.Context, numPackets int, en for i := 0; i < numPackets; i++ { c.ReceiveAndCheckPacketWithOptions(data, bytesRead, maxPayload, tsOptionSize) bytesRead += maxPayload + // This delay is added to increase RTT as low RTT can cause TLP + // before sending ACK. 
+ time.Sleep(latency) } return data @@ -479,7 +483,7 @@ func TestRACKOnePacketTailLoss(t *testing.T) { }{ // #3 was retransmitted as TLP. {tcpStats.FastRetransmit, "stats.TCP.FastRetransmit", 0}, - {tcpStats.SACKRecovery, "stats.TCP.SACKRecovery", 0}, + {tcpStats.SACKRecovery, "stats.TCP.SACKRecovery", 1}, {tcpStats.TLPRecovery, "stats.TCP.TLPRecovery", 0}, // RTO should not have fired. {tcpStats.Timeouts, "stats.TCP.Timeouts", 0}, @@ -852,8 +856,8 @@ func addReorderWindowCheckerProbe(c *context.Context, numACK int, probeDone chan return } - if state.Sender.RACKState.ReoWnd == 0 || state.Sender.RACKState.ReoWnd > state.Sender.SRTT { - probeDone <- fmt.Errorf("got RACKState.ReoWnd: %v, expected it to be greater than 0 and less than %v", state.Sender.RACKState.ReoWnd, state.Sender.SRTT) + if state.Sender.RACKState.ReoWnd == 0 || state.Sender.RACKState.ReoWnd > state.Sender.RTTState.SRTT { + probeDone <- fmt.Errorf("got RACKState.ReoWnd: %d, expected it to be greater than 0 and less than %d", state.Sender.RACKState.ReoWnd, state.Sender.RTTState.SRTT) return } diff --git a/pkg/tcpip/transport/tcp/tcp_sack_test.go b/pkg/tcpip/transport/tcp/tcp_sack_test.go index 81f800cad..20c9761f2 100644 --- a/pkg/tcpip/transport/tcp/tcp_sack_test.go +++ b/pkg/tcpip/transport/tcp/tcp_sack_test.go @@ -160,12 +160,9 @@ func TestSackPermittedAccept(t *testing.T) { defer c.Cleanup() if tc.cookieEnabled { - // Set the SynRcvd threshold to - // zero to force a syn cookie - // based accept to happen. - var opt tcpip.TCPSynRcvdCountThresholdOption + opt := tcpip.TCPAlwaysUseSynCookies(true) if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { - t.Fatalf("SetTransportProtocolOption(%d, &%T(%d)): %s", tcp.ProtocolNumber, opt, opt, err) + t.Fatalf("SetTransportProtocolOption(%d, &%T(%t)): %s", tcp.ProtocolNumber, opt, opt, err) } } setStackSACKPermitted(t, c, sackEnabled) @@ -235,12 +232,9 @@ func TestSackDisabledAccept(t *testing.T) { defer c.Cleanup() if tc.cookieEnabled { - // Set the SynRcvd threshold to - // zero to force a syn cookie - // based accept to happen. - var opt tcpip.TCPSynRcvdCountThresholdOption + opt := tcpip.TCPAlwaysUseSynCookies(true) if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { - t.Fatalf("SetTransportProtocolOption(%d, &%T(%d)): %s", tcp.ProtocolNumber, opt, opt, err) + t.Fatalf("SetTransportProtocolOption(%d, &%T(%t)): %s", tcp.ProtocolNumber, opt, opt, err) } } diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go index 9c23469f2..9916182e3 100644 --- a/pkg/tcpip/transport/tcp/tcp_test.go +++ b/pkg/tcpip/transport/tcp/tcp_test.go @@ -35,6 +35,7 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" "gvisor.dev/gvisor/pkg/tcpip/seqnum" "gvisor.dev/gvisor/pkg/tcpip/stack" + tcpiptestutil "gvisor.dev/gvisor/pkg/tcpip/testutil" "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" "gvisor.dev/gvisor/pkg/tcpip/transport/tcp/testing/context" "gvisor.dev/gvisor/pkg/test/testutil" @@ -86,7 +87,7 @@ func (e *endpointTester) CheckReadFull(t *testing.T, count int, notifyRead <-cha } for w.N != 0 { _, err := e.ep.Read(&w, tcpip.ReadOptions{}) - if _, ok := err.(*tcpip.ErrWouldBlock); ok { + if cmp.Equal(&tcpip.ErrWouldBlock{}, err) { // Wait for receive to be notified. 
select { case <-notifyRead: @@ -129,8 +130,8 @@ func TestGiveUpConnect(t *testing.T) { { err := ep.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}) - if _, ok := err.(*tcpip.ErrConnectStarted); !ok { - t.Fatalf("got ep.Connect(...) = %v, want = %s", err, &tcpip.ErrConnectStarted{}) + if d := cmp.Diff(&tcpip.ErrConnectStarted{}, err); d != "" { + t.Fatalf("ep.Connect(...) mismatch (-want +got):\n%s", d) } } @@ -144,8 +145,8 @@ func TestGiveUpConnect(t *testing.T) { // and stats updates. { err := ep.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}) - if _, ok := err.(*tcpip.ErrAborted); !ok { - t.Fatalf("got ep.Connect(...) = %v, want = %s", err, &tcpip.ErrAborted{}) + if d := cmp.Diff(&tcpip.ErrAborted{}, err); d != "" { + t.Fatalf("ep.Connect(...) mismatch (-want +got):\n%s", d) } } @@ -158,6 +159,76 @@ func TestGiveUpConnect(t *testing.T) { } } +// Test for ICMP error handling without completing handshake. +func TestConnectICMPError(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + var wq waiter.Queue + ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &wq) + if err != nil { + t.Fatalf("NewEndpoint failed: %s", err) + } + + waitEntry, notifyCh := waiter.NewChannelEntry(nil) + wq.EventRegister(&waitEntry, waiter.EventHUp) + defer wq.EventUnregister(&waitEntry) + + { + err := ep.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}) + if d := cmp.Diff(&tcpip.ErrConnectStarted{}, err); d != "" { + t.Fatalf("ep.Connect(...) mismatch (-want +got):\n%s", d) + } + } + + syn := c.GetPacket() + checker.IPv4(t, syn, checker.TCP(checker.TCPFlags(header.TCPFlagSyn))) + + wep := ep.(interface { + StopWork() + ResumeWork() + LastErrorLocked() tcpip.Error + }) + + // Stop the protocol loop, ensure that the ICMP error is processed and + // the last ICMP error is read before the loop is resumed. This sanity + // tests the handshake completion logic on ICMP errors. + wep.StopWork() + + c.SendICMPPacket(header.ICMPv4DstUnreachable, header.ICMPv4HostUnreachable, nil, syn, defaultMTU) + + for { + if err := wep.LastErrorLocked(); err != nil { + if d := cmp.Diff(&tcpip.ErrNoRoute{}, err); d != "" { + t.Errorf("ep.LastErrorLocked() mismatch (-want +got):\n%s", d) + } + break + } + time.Sleep(time.Millisecond) + } + + wep.ResumeWork() + + <-notifyCh + + // The stack would have unregistered the endpoint because of the ICMP error. + // Expect a RST for any subsequent packets sent to the endpoint. + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagAck, + SeqNum: seqnum.Value(context.TestInitialSequenceNumber) + 1, + AckNum: c.IRS + 1, + }) + + checker.IPv4(t, c.GetPacket(), checker.TCP( + checker.SrcPort(context.StackPort), + checker.DstPort(context.TestPort), + checker.TCPSeqNum(uint32(c.IRS+1)), + checker.TCPAckNum(0), + checker.TCPFlags(header.TCPFlagRst))) +} + func TestConnectIncrementActiveConnection(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() @@ -201,8 +272,8 @@ func TestActiveFailedConnectionAttemptIncrement(t *testing.T) { { err := c.EP.Connect(tcpip.FullAddress{NIC: 2, Addr: context.TestAddr, Port: context.TestPort}) - if _, ok := err.(*tcpip.ErrNoRoute); !ok { - t.Errorf("got c.EP.Connect(...) = %v, want = %s", err, &tcpip.ErrNoRoute{}) + if d := cmp.Diff(&tcpip.ErrNoRoute{}, err); d != "" { + t.Errorf("c.EP.Connect(...) 
mismatch (-want +got):\n%s", d) } } @@ -392,7 +463,7 @@ func TestTCPResetSentForACKWhenNotUsingSynCookies(t *testing.T) { defer wq.EventUnregister(&we) c.EP, _, err = ep.Accept(nil) - if _, ok := err.(*tcpip.ErrWouldBlock); ok { + if cmp.Equal(&tcpip.ErrWouldBlock{}, err) { // Wait for connection to be established. select { case <-ch: @@ -929,17 +1000,14 @@ func TestUserSuppliedMSSOnConnect(t *testing.T) { } // Get expected window size. - rcvBufSize, err := c.EP.GetSockOptInt(tcpip.ReceiveBufferSizeOption) - if err != nil { - t.Fatalf("GetSockOptInt(ReceiveBufferSizeOption): %s", err) - } + rcvBufSize := c.EP.SocketOptions().GetReceiveBufferSize() ws := tcp.FindWndScale(seqnum.Size(rcvBufSize)) connectAddr := tcpip.FullAddress{Addr: ip.connectAddr, Port: context.TestPort} { err := c.EP.Connect(connectAddr) - if _, ok := err.(*tcpip.ErrConnectStarted); !ok { - t.Fatalf("Connect(%+v): %s", connectAddr, err) + if d := cmp.Diff(&tcpip.ErrConnectStarted{}, err); d != "" { + t.Fatalf("Connect(%+v) mismatch (-want +got):\n%s", connectAddr, d) } } @@ -955,11 +1023,7 @@ func TestUserSuppliedMSSOnConnect(t *testing.T) { // when completing the handshake for a new TCP connection from a TCP // listening socket. It should be present in the sent TCP SYN-ACK segment. func TestUserSuppliedMSSOnListenAccept(t *testing.T) { - const ( - nonSynCookieAccepts = 2 - totalAccepts = 4 - mtu = 5000 - ) + const mtu = 5000 ips := []struct { name string @@ -1033,12 +1097,6 @@ func TestUserSuppliedMSSOnListenAccept(t *testing.T) { ip.createEP(c) - // Set the SynRcvd threshold to force a syn cookie based accept to happen. - opt := tcpip.TCPSynRcvdCountThresholdOption(nonSynCookieAccepts) - if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { - t.Fatalf("SetTransportProtocolOption(%d, &%T(%d)): %s", tcp.ProtocolNumber, opt, opt, err) - } - if err := c.EP.SetSockOptInt(tcpip.MaxSegOption, int(test.setMSS)); err != nil { t.Fatalf("SetSockOptInt(MaxSegOption, %d): %s", test.setMSS, err) } @@ -1048,13 +1106,17 @@ func TestUserSuppliedMSSOnListenAccept(t *testing.T) { t.Fatalf("Bind(%+v): %s:", bindAddr, err) } - if err := c.EP.Listen(totalAccepts); err != nil { - t.Fatalf("Listen(%d): %s:", totalAccepts, err) + backlog := 5 + // Keep the number of client requests twice to the backlog + // such that half of the connections do not use syncookies + // and the other half does. + clientConnects := backlog * 2 + + if err := c.EP.Listen(backlog); err != nil { + t.Fatalf("Listen(%d): %s:", backlog, err) } - // The first nonSynCookieAccepts packets sent will trigger a gorooutine - // based accept. The rest will trigger a cookie based accept. - for i := 0; i < totalAccepts; i++ { + for i := 0; i < clientConnects; i++ { // Send a SYN requests. 
iss := seqnum.Value(i) srcPort := context.TestPort + uint16(i) @@ -1297,6 +1359,98 @@ func TestListenShutdown(t *testing.T) { )) } +var _ waiter.EntryCallback = (callback)(nil) + +type callback func(*waiter.Entry, waiter.EventMask) + +func (cb callback) Callback(entry *waiter.Entry, mask waiter.EventMask) { + cb(entry, mask) +} + +func TestListenerReadinessOnEvent(t *testing.T) { + s := stack.New(stack.Options{ + TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol}, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol}, + }) + { + ep := loopback.New() + if testing.Verbose() { + ep = sniffer.New(ep) + } + const id = 1 + if err := s.CreateNIC(id, ep); err != nil { + t.Fatalf("CreateNIC(%d, %T): %s", id, ep, err) + } + if err := s.AddAddress(id, ipv4.ProtocolNumber, context.StackAddr); err != nil { + t.Fatalf("AddAddress(%d, ipv4.ProtocolNumber, %s): %s", id, context.StackAddr, err) + } + s.SetRouteTable([]tcpip.Route{ + {Destination: header.IPv4EmptySubnet, NIC: id}, + }) + } + + var wq waiter.Queue + ep, err := s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &wq) + if err != nil { + t.Fatalf("NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, _): %s", err) + } + defer ep.Close() + + if err := ep.Bind(tcpip.FullAddress{Addr: context.StackAddr}); err != nil { + t.Fatalf("Bind(%s): %s", context.StackAddr, err) + } + const backlog = 1 + if err := ep.Listen(backlog); err != nil { + t.Fatalf("Listen(%d): %s", backlog, err) + } + + address, err := ep.GetLocalAddress() + if err != nil { + t.Fatalf("GetLocalAddress(): %s", err) + } + + conn, err := s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &wq) + if err != nil { + t.Fatalf("NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, _): %s", err) + } + defer conn.Close() + + events := make(chan waiter.EventMask) + // Scope `entry` to allow a binding of the same name below. + { + entry := waiter.Entry{Callback: callback(func(_ *waiter.Entry, mask waiter.EventMask) { + events <- ep.Readiness(mask) + })} + wq.EventRegister(&entry, waiter.EventIn) + defer wq.EventUnregister(&entry) + } + + entry, ch := waiter.NewChannelEntry(nil) + wq.EventRegister(&entry, waiter.EventOut) + defer wq.EventUnregister(&entry) + + switch err := conn.Connect(address).(type) { + case *tcpip.ErrConnectStarted: + default: + t.Fatalf("Connect(%#v): %v", address, err) + } + + // Read at least one event. + got := <-events + for { + select { + case event := <-events: + got |= event + continue + case <-ch: + if want := waiter.ReadableEvents; got != want { + t.Errorf("observed events = %b, want %b", got, want) + } + } + break + } +} + // TestListenCloseWhileConnect tests for the listening endpoint to // drain the accept-queue when closed. This should reset all of the // pending connections that are waiting to be accepted. @@ -1459,8 +1613,8 @@ func TestConnectBindToDevice(t *testing.T) { defer c.WQ.EventUnregister(&waitEntry) err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}) - if _, ok := err.(*tcpip.ErrConnectStarted); !ok { - t.Fatalf("unexpected return value from Connect: %s", err) + if d := cmp.Diff(&tcpip.ErrConnectStarted{}, err); d != "" { + t.Fatalf("c.EP.Connect(...) mismatch (-want +got):\n%s", d) } // Receive SYN packet. 
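The callback type above uses the same adapter idiom as net/http's HandlerFunc: a named function type whose method forwards to itself, so a bare func literal can satisfy the single-method waiter.EntryCallback interface. A small self-contained sketch of the idiom with a hypothetical Notifier interface standing in for waiter.EntryCallback:

package main

import "fmt"

// Notifier is a stand-in for a single-method interface such as
// waiter.EntryCallback.
type Notifier interface {
    Notify(event string)
}

// NotifierFunc adapts an ordinary function to the Notifier interface, the
// same way the test's callback type adapts a func to waiter.EntryCallback.
type NotifierFunc func(event string)

// Notify calls the underlying function.
func (f NotifierFunc) Notify(event string) { f(event) }

// Compile-time check that the adapter satisfies the interface.
var _ Notifier = (NotifierFunc)(nil)

func main() {
    var n Notifier = NotifierFunc(func(event string) {
        fmt.Println("got event:", event)
    })
    n.Notify("readable")
}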
@@ -1520,8 +1674,8 @@ func TestSynSent(t *testing.T) { addr := tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort} err := c.EP.Connect(addr) - if _, ok := err.(*tcpip.ErrConnectStarted); !ok { - t.Fatalf("got Connect(%+v) = %v, want %s", addr, err, &tcpip.ErrConnectStarted{}) + if d := cmp.Diff(err, &tcpip.ErrConnectStarted{}); d != "" { + t.Fatalf("Connect(...) mismatch (-want +got):\n%s", d) } // Receive SYN packet. @@ -1993,9 +2147,7 @@ func TestSmallSegReceiveWindowAdvertisement(t *testing.T) { // Bump up the receive buffer size such that, when the receive window grows, // the scaled window exceeds maxUint16. - if err := c.EP.SetSockOptInt(tcpip.ReceiveBufferSizeOption, opt.Max); err != nil { - t.Fatalf("SetSockOptInt(ReceiveBufferSizeOption, %d) failed: %s", opt.Max, err) - } + c.EP.SocketOptions().SetReceiveBufferSize(int64(opt.Max), true) // Keep the payload size < segment overhead and such that it is a multiple // of the window scaled value. This enables the test to perform equality @@ -2115,9 +2267,7 @@ func TestNoWindowShrinking(t *testing.T) { initialWnd := header.TCP(header.IPv4(pkt).Payload()).WindowSize() << c.RcvdWindowScale initialLastAcceptableSeq := iss.Add(seqnum.Size(initialWnd)) // Now shrink the receive buffer to half its original size. - if err := c.EP.SetSockOptInt(tcpip.ReceiveBufferSizeOption, rcvBufSize/2); err != nil { - t.Fatalf("SetSockOptInt(ReceiveBufferSizeOption, 5) failed: %s", err) - } + c.EP.SocketOptions().SetReceiveBufferSize(int64(rcvBufSize/2), true) data := generateRandomPayload(t, rcvBufSize) // Send a payload of half the size of rcvBufSize. @@ -2373,9 +2523,7 @@ func TestScaledWindowAccept(t *testing.T) { defer ep.Close() // Set the window size greater than the maximum non-scaled window. - if err := ep.SetSockOptInt(tcpip.ReceiveBufferSizeOption, 65535*3); err != nil { - t.Fatalf("SetSockOptInt(ReceiveBufferSizeOption, 65535*3) failed failed: %s", err) - } + ep.SocketOptions().SetReceiveBufferSize(65535*3, true) if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { t.Fatalf("Bind failed: %s", err) @@ -2395,7 +2543,7 @@ func TestScaledWindowAccept(t *testing.T) { defer wq.EventUnregister(&we) c.EP, _, err = ep.Accept(nil) - if _, ok := err.(*tcpip.ErrWouldBlock); ok { + if cmp.Equal(&tcpip.ErrWouldBlock{}, err) { // Wait for connection to be established. select { case <-ch: @@ -2447,9 +2595,7 @@ func TestNonScaledWindowAccept(t *testing.T) { defer ep.Close() // Set the window size greater than the maximum non-scaled window. - if err := ep.SetSockOptInt(tcpip.ReceiveBufferSizeOption, 65535*3); err != nil { - t.Fatalf("SetSockOptInt(ReceiveBufferSizeOption, 65535*3) failed failed: %s", err) - } + ep.SocketOptions().SetReceiveBufferSize(65535*3, true) if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { t.Fatalf("Bind failed: %s", err) @@ -2469,7 +2615,7 @@ func TestNonScaledWindowAccept(t *testing.T) { defer wq.EventUnregister(&we) c.EP, _, err = ep.Accept(nil) - if _, ok := err.(*tcpip.ErrWouldBlock); ok { + if cmp.Equal(&tcpip.ErrWouldBlock{}, err) { // Wait for connection to be established. select { case <-ch: @@ -3001,8 +3147,8 @@ func TestSetTTL(t *testing.T) { { err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}) - if _, ok := err.(*tcpip.ErrConnectStarted); !ok { - t.Fatalf("unexpected return value from Connect: %s", err) + if d := cmp.Diff(&tcpip.ErrConnectStarted{}, err); d != "" { + t.Fatalf("c.EP.Connect(...) 
mismatch (-want +got):\n%s", d) } } @@ -3042,9 +3188,7 @@ func TestPassiveSendMSSLessThanMTU(t *testing.T) { // Set the buffer size to a deterministic size so that we can check the // window scaling option. const rcvBufferSize = 0x20000 - if err := ep.SetSockOptInt(tcpip.ReceiveBufferSizeOption, rcvBufferSize); err != nil { - t.Fatalf("SetSockOptInt(ReceiveBufferSizeOption, %d) failed failed: %s", rcvBufferSize, err) - } + ep.SocketOptions().SetReceiveBufferSize(rcvBufferSize, true) if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { t.Fatalf("Bind failed: %s", err) @@ -3063,7 +3207,7 @@ func TestPassiveSendMSSLessThanMTU(t *testing.T) { defer wq.EventUnregister(&we) c.EP, _, err = ep.Accept(nil) - if _, ok := err.(*tcpip.ErrWouldBlock); ok { + if cmp.Equal(&tcpip.ErrWouldBlock{}, err) { // Wait for connection to be established. select { case <-ch: @@ -3087,11 +3231,9 @@ func TestSynCookiePassiveSendMSSLessThanMTU(t *testing.T) { c := context.New(t, mtu) defer c.Cleanup() - // Set the SynRcvd threshold to zero to force a syn cookie based accept - // to happen. - opt := tcpip.TCPSynRcvdCountThresholdOption(0) + opt := tcpip.TCPAlwaysUseSynCookies(true) if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { - t.Fatalf("SetTransportProtocolOption(%d, &%T(%d)): %s", tcp.ProtocolNumber, opt, opt, err) + t.Fatalf("SetTransportProtocolOption(%d, &%T(%t)): %s", tcp.ProtocolNumber, opt, opt, err) } // Create EP and start listening. @@ -3119,7 +3261,7 @@ func TestSynCookiePassiveSendMSSLessThanMTU(t *testing.T) { defer wq.EventUnregister(&we) c.EP, _, err = ep.Accept(nil) - if _, ok := err.(*tcpip.ErrWouldBlock); ok { + if cmp.Equal(&tcpip.ErrWouldBlock{}, err) { // Wait for connection to be established. select { case <-ch: @@ -3185,9 +3327,7 @@ func TestSynOptionsOnActiveConnect(t *testing.T) { // window scaling option. const rcvBufferSize = 0x20000 const wndScale = 3 - if err := c.EP.SetSockOptInt(tcpip.ReceiveBufferSizeOption, rcvBufferSize); err != nil { - t.Fatalf("SetSockOptInt(ReceiveBufferSizeOption, %d) failed failed: %s", rcvBufferSize, err) - } + c.EP.SocketOptions().SetReceiveBufferSize(rcvBufferSize, true) // Start connection attempt. we, ch := waiter.NewChannelEntry(nil) @@ -3196,8 +3336,8 @@ func TestSynOptionsOnActiveConnect(t *testing.T) { { err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}) - if _, ok := err.(*tcpip.ErrConnectStarted); !ok { - t.Fatalf("got c.EP.Connect(...) = %v, want = %s", err, &tcpip.ErrConnectStarted{}) + if d := cmp.Diff(&tcpip.ErrConnectStarted{}, err); d != "" { + t.Fatalf("c.EP.Connect(...) mismatch (-want +got):\n%s", d) } } @@ -3315,8 +3455,8 @@ loop: case <-ch: // Expect the state to be StateError and subsequent Reads to fail with HardError. _, err := c.EP.Read(ioutil.Discard, tcpip.ReadOptions{}) - if _, ok := err.(*tcpip.ErrConnectionReset); !ok { - t.Fatalf("got c.EP.Read() = %v, want = %s", err, &tcpip.ErrConnectionReset{}) + if d := cmp.Diff(&tcpip.ErrConnectionReset{}, err); d != "" { + t.Fatalf("c.EP.Read() mismatch (-want +got):\n%s", d) } break loop case <-time.After(1 * time.Second): @@ -3366,8 +3506,8 @@ func TestSendOnResetConnection(t *testing.T) { var r bytes.Reader r.Reset(make([]byte, 10)) _, err := c.EP.Write(&r, tcpip.WriteOptions{}) - if _, ok := err.(*tcpip.ErrConnectionReset); !ok { - t.Fatalf("got c.EP.Write(...) 
= %v, want = %s", err, &tcpip.ErrConnectionReset{}) + if d := cmp.Diff(&tcpip.ErrConnectionReset{}, err); d != "" { + t.Fatalf("c.EP.Write(...) mismatch (-want +got):\n%s", d) } } @@ -4320,8 +4460,8 @@ func TestReadAfterClosedState(t *testing.T) { var buf bytes.Buffer { _, err := c.EP.Read(&buf, tcpip.ReadOptions{Peek: true}) - if _, ok := err.(*tcpip.ErrClosedForReceive); !ok { - t.Fatalf("c.EP.Read(_, {Peek: true}) = %v, %s; want _, %s", res, err, &tcpip.ErrClosedForReceive{}) + if d := cmp.Diff(&tcpip.ErrClosedForReceive{}, err); d != "" { + t.Fatalf("c.EP.Read(_, {Peek: true}) mismatch (-want +got):\n%s", d) } } } @@ -4365,8 +4505,8 @@ func TestReusePort(t *testing.T) { } { err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}) - if _, ok := err.(*tcpip.ErrConnectStarted); !ok { - t.Fatalf("got c.EP.Connect(...) = %v, want = %s", err, &tcpip.ErrConnectStarted{}) + if d := cmp.Diff(&tcpip.ErrConnectStarted{}, err); d != "" { + t.Fatalf("c.EP.Connect(...) mismatch (-want +got):\n%s", d) } } c.EP.Close() @@ -4411,11 +4551,7 @@ func TestReusePort(t *testing.T) { func checkRecvBufferSize(t *testing.T, ep tcpip.Endpoint, v int) { t.Helper() - s, err := ep.GetSockOptInt(tcpip.ReceiveBufferSizeOption) - if err != nil { - t.Fatalf("GetSockOpt failed: %s", err) - } - + s := ep.SocketOptions().GetReceiveBufferSize() if int(s) != v { t.Fatalf("got receive buffer size = %d, want = %d", s, v) } @@ -4521,10 +4657,7 @@ func TestMinMaxBufferSizes(t *testing.T) { } // Set values below the min/2. - if err := ep.SetSockOptInt(tcpip.ReceiveBufferSizeOption, 99); err != nil { - t.Fatalf("SetSockOptInt(ReceiveBufferSizeOption, 199) failed: %s", err) - } - + ep.SocketOptions().SetReceiveBufferSize(99, true) checkRecvBufferSize(t, ep, 200) ep.SocketOptions().SetSendBufferSize(149, true) @@ -4532,15 +4665,11 @@ func TestMinMaxBufferSizes(t *testing.T) { checkSendBufferSize(t, ep, 300) // Set values above the max. - if err := ep.SetSockOptInt(tcpip.ReceiveBufferSizeOption, 1+tcp.DefaultReceiveBufferSize*20); err != nil { - t.Fatalf("SetSockOptInt(ReceiveBufferSizeOption) failed: %s", err) - } - + ep.SocketOptions().SetReceiveBufferSize(1+tcp.DefaultReceiveBufferSize*20, true) // Values above max are capped at max and then doubled. checkRecvBufferSize(t, ep, tcp.DefaultReceiveBufferSize*20*2) ep.SocketOptions().SetSendBufferSize(1+tcp.DefaultSendBufferSize*30, true) - // Values above max are capped at max and then doubled. checkSendBufferSize(t, ep, tcp.DefaultSendBufferSize*30*2) } @@ -4665,8 +4794,8 @@ func TestSelfConnect(t *testing.T) { { err := ep.Connect(tcpip.FullAddress{Addr: context.StackAddr, Port: context.StackPort}) - if _, ok := err.(*tcpip.ErrConnectStarted); !ok { - t.Fatalf("got ep.Connect(...) = %v, want = %s", err, &tcpip.ErrConnectStarted{}) + if d := cmp.Diff(&tcpip.ErrConnectStarted{}, err); d != "" { + t.Fatalf("ep.Connect(...) 
mismatch (-want +got):\n%s", d) } } @@ -4814,7 +4943,13 @@ func TestConnectAvoidsBoundPorts(t *testing.T) { t.Fatalf("unknown address type: '%s'", candidateAddressType) } - start, end := s.PortRange() + const ( + start = 16000 + end = 16050 + ) + if err := s.SetPortRange(start, end); err != nil { + t.Fatalf("got s.SetPortRange(%d, %d) = %s, want = nil", start, end, err) + } for i := start; i <= end; i++ { if makeEP(exhaustedNetwork).Bind(tcpip.FullAddress{Addr: address(t, exhaustedAddressType, isAny), Port: uint16(i)}); err != nil { t.Fatalf("Bind(%d) failed: %s", i, err) @@ -5387,7 +5522,7 @@ func TestListenBacklogFull(t *testing.T) { for i := 0; i < listenBacklog; i++ { _, _, err = c.EP.Accept(nil) - if _, ok := err.(*tcpip.ErrWouldBlock); ok { + if cmp.Equal(&tcpip.ErrWouldBlock{}, err) { // Wait for connection to be established. select { case <-ch: @@ -5404,7 +5539,7 @@ func TestListenBacklogFull(t *testing.T) { // Now verify that there are no more connections that can be accepted. _, _, err = c.EP.Accept(nil) - if _, ok := err.(*tcpip.ErrWouldBlock); !ok { + if !cmp.Equal(&tcpip.ErrWouldBlock{}, err) { select { case <-ch: t.Fatalf("unexpected endpoint delivered on Accept: %+v", c.EP) @@ -5416,7 +5551,7 @@ func TestListenBacklogFull(t *testing.T) { executeHandshake(t, c, context.TestPort+lastPortOffset, false /*synCookieInUse */) newEP, _, err := c.EP.Accept(nil) - if _, ok := err.(*tcpip.ErrWouldBlock); ok { + if cmp.Equal(&tcpip.ErrWouldBlock{}, err) { // Wait for connection to be established. select { case <-ch: @@ -5445,8 +5580,8 @@ func TestListenBacklogFull(t *testing.T) { // TestListenNoAcceptMulticastBroadcastV4 makes sure that TCP segments with a // non unicast IPv4 address are not accepted. func TestListenNoAcceptNonUnicastV4(t *testing.T) { - multicastAddr := tcpip.Address("\xe0\x00\x01\x02") - otherMulticastAddr := tcpip.Address("\xe0\x00\x01\x03") + multicastAddr := tcpiptestutil.MustParse4("224.0.1.2") + otherMulticastAddr := tcpiptestutil.MustParse4("224.0.1.3") subnet := context.StackAddrWithPrefix.Subnet() subnetBroadcastAddr := subnet.Broadcast() @@ -5557,8 +5692,8 @@ func TestListenNoAcceptNonUnicastV4(t *testing.T) { // TestListenNoAcceptMulticastBroadcastV6 makes sure that TCP segments with a // non unicast IPv6 address are not accepted. func TestListenNoAcceptNonUnicastV6(t *testing.T) { - multicastAddr := tcpip.Address("\xff\x0e\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x01") - otherMulticastAddr := tcpip.Address("\xff\x0e\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x02") + multicastAddr := tcpiptestutil.MustParse6("ff0e::101") + otherMulticastAddr := tcpiptestutil.MustParse6("ff0e::102") tests := []struct { name string @@ -5671,15 +5806,13 @@ func TestListenSynRcvdQueueFull(t *testing.T) { } // Test acceptance. - // Start listening. - listenBacklog := 1 - if err := c.EP.Listen(listenBacklog); err != nil { + if err := c.EP.Listen(1); err != nil { t.Fatalf("Listen failed: %s", err) } // Send two SYN's the first one should get a SYN-ACK, the // second one should not get any response and is dropped as - // the synRcvd count will be equal to backlog. + // the accept queue is full. irs := seqnum.Value(context.TestInitialSequenceNumber) c.SendPacket(nil, &context.Headers{ SrcPort: context.TestPort, @@ -5701,23 +5834,7 @@ func TestListenSynRcvdQueueFull(t *testing.T) { } checker.IPv4(t, b, checker.TCP(tcpCheckers...)) - // Now execute send one more SYN. The stack should not respond as the backlog - // is full at this point. 
- // - // NOTE: we did not complete the handshake for the previous one so the - // accept backlog should be empty and there should be one connection in - // synRcvd state. - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort + 1, - DstPort: context.StackPort, - Flags: header.TCPFlagSyn, - SeqNum: seqnum.Value(889), - RcvWnd: 30000, - }) - c.CheckNoPacketTimeout("unexpected packet received", 50*time.Millisecond) - - // Now complete the previous connection and verify that there is a connection - // to accept. + // Now complete the previous connection. // Send ACK. c.SendPacket(nil, &context.Headers{ SrcPort: context.TestPort, @@ -5728,13 +5845,26 @@ func TestListenSynRcvdQueueFull(t *testing.T) { RcvWnd: 30000, }) - // Try to accept the connections in the backlog. + // Verify if that is delivered to the accept queue. we, ch := waiter.NewChannelEntry(nil) c.WQ.EventRegister(&we, waiter.ReadableEvents) defer c.WQ.EventUnregister(&we) + <-ch + + // Now execute send one more SYN. The stack should not respond as the backlog + // is full at this point. + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort + 1, + DstPort: context.StackPort, + Flags: header.TCPFlagSyn, + SeqNum: seqnum.Value(889), + RcvWnd: 30000, + }) + c.CheckNoPacketTimeout("unexpected packet received", 50*time.Millisecond) + // Try to accept the connections in the backlog. newEP, _, err := c.EP.Accept(nil) - if _, ok := err.(*tcpip.ErrWouldBlock); ok { + if cmp.Equal(&tcpip.ErrWouldBlock{}, err) { // Wait for connection to be established. select { case <-ch: @@ -5764,11 +5894,6 @@ func TestListenBacklogFullSynCookieInUse(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() - opt := tcpip.TCPSynRcvdCountThresholdOption(1) - if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { - t.Fatalf("SetTransportProtocolOption(%d, &%T(%d)): %s", tcp.ProtocolNumber, opt, opt, err) - } - // Create TCP endpoint. var err tcpip.Error c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ) @@ -5781,9 +5906,8 @@ func TestListenBacklogFullSynCookieInUse(t *testing.T) { t.Fatalf("Bind failed: %s", err) } - // Start listening. - listenBacklog := 1 - if err := c.EP.Listen(listenBacklog); err != nil { + // Test for SynCookies usage after filling up the backlog. + if err := c.EP.Listen(1); err != nil { t.Fatalf("Listen failed: %s", err) } @@ -5811,7 +5935,7 @@ func TestListenBacklogFullSynCookieInUse(t *testing.T) { defer c.WQ.EventUnregister(&we) _, _, err = c.EP.Accept(nil) - if _, ok := err.(*tcpip.ErrWouldBlock); ok { + if cmp.Equal(&tcpip.ErrWouldBlock{}, err) { // Wait for connection to be established. select { case <-ch: @@ -5827,7 +5951,7 @@ func TestListenBacklogFullSynCookieInUse(t *testing.T) { // Now verify that there are no more connections that can be accepted. _, _, err = c.EP.Accept(nil) - if _, ok := err.(*tcpip.ErrWouldBlock); !ok { + if !cmp.Equal(&tcpip.ErrWouldBlock{}, err) { select { case <-ch: t.Fatalf("unexpected endpoint delivered on Accept: %+v", c.EP) @@ -5966,7 +6090,7 @@ func TestSynRcvdBadSeqNumber(t *testing.T) { t.Fatalf("Accept failed: %s", err) } - if _, ok := err.(*tcpip.ErrWouldBlock); ok { + if cmp.Equal(&tcpip.ErrWouldBlock{}, err) { // Try to accept the connections in the backlog. we, ch := waiter.NewChannelEntry(nil) c.WQ.EventRegister(&we, waiter.ReadableEvents) @@ -6034,7 +6158,7 @@ func TestPassiveConnectionAttemptIncrement(t *testing.T) { // Verify that there is only one acceptable connection at this point. 
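Several tests in this diff switch from the removed TCPSynRcvdCountThresholdOption to TCPAlwaysUseSynCookies when they need cookie-based accepts. A sketch of a helper wrapping that stack option, using only calls that appear in the diff; the forceSynCookies name is an illustrative invention:

package tcp_test

import (
    "testing"

    "gvisor.dev/gvisor/pkg/tcpip"
    "gvisor.dev/gvisor/pkg/tcpip/stack"
    "gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
)

// forceSynCookies configures the stack so every passive connection uses SYN
// cookies, mirroring the option the updated tests set inline.
func forceSynCookies(t *testing.T, s *stack.Stack) {
    t.Helper()
    opt := tcpip.TCPAlwaysUseSynCookies(true)
    if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil {
        t.Fatalf("SetTransportProtocolOption(%d, &%T(%t)): %s", tcp.ProtocolNumber, opt, opt, err)
    }
}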
_, _, err = c.EP.Accept(nil) - if _, ok := err.(*tcpip.ErrWouldBlock); ok { + if cmp.Equal(&tcpip.ErrWouldBlock{}, err) { // Wait for connection to be established. select { case <-ch: @@ -6104,7 +6228,7 @@ func TestPassiveFailedConnectionAttemptIncrement(t *testing.T) { // Now check that there is one acceptable connections. _, _, err = c.EP.Accept(nil) - if _, ok := err.(*tcpip.ErrWouldBlock); ok { + if cmp.Equal(&tcpip.ErrWouldBlock{}, err) { // Wait for connection to be established. select { case <-ch: @@ -6156,7 +6280,7 @@ func TestEndpointBindListenAcceptState(t *testing.T) { defer wq.EventUnregister(&we) aep, _, err := ep.Accept(nil) - if _, ok := err.(*tcpip.ErrWouldBlock); ok { + if cmp.Equal(&tcpip.ErrWouldBlock{}, err) { // Wait for connection to be established. select { case <-ch: @@ -6174,8 +6298,8 @@ func TestEndpointBindListenAcceptState(t *testing.T) { } { err := aep.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}) - if _, ok := err.(*tcpip.ErrAlreadyConnected); !ok { - t.Errorf("unexpected error attempting to call connect on an established endpoint, got: %v, want: %s", err, &tcpip.ErrAlreadyConnected{}) + if d := cmp.Diff(&tcpip.ErrAlreadyConnected{}, err); d != "" { + t.Errorf("Connect(...) mismatch (-want +got):\n%s", d) } } // Listening endpoint remains in listen state. @@ -6295,7 +6419,7 @@ func TestReceiveBufferAutoTuningApplicationLimited(t *testing.T) { // window increases to the full available buffer size. for { _, err := c.EP.Read(ioutil.Discard, tcpip.ReadOptions{}) - if _, ok := err.(*tcpip.ErrWouldBlock); ok { + if cmp.Equal(&tcpip.ErrWouldBlock{}, err) { break } } @@ -6426,7 +6550,7 @@ func TestReceiveBufferAutoTuning(t *testing.T) { totalCopied := 0 for { res, err := c.EP.Read(ioutil.Discard, tcpip.ReadOptions{}) - if _, ok := err.(*tcpip.ErrWouldBlock); ok { + if cmp.Equal(&tcpip.ErrWouldBlock{}, err) { break } totalCopied += res.Count @@ -6618,7 +6742,7 @@ func TestTCPTimeWaitRSTIgnored(t *testing.T) { defer wq.EventUnregister(&we) c.EP, _, err = ep.Accept(nil) - if _, ok := err.(*tcpip.ErrWouldBlock); ok { + if cmp.Equal(&tcpip.ErrWouldBlock{}, err) { // Wait for connection to be established. select { case <-ch: @@ -6737,7 +6861,7 @@ func TestTCPTimeWaitOutOfOrder(t *testing.T) { defer wq.EventUnregister(&we) c.EP, _, err = ep.Accept(nil) - if _, ok := err.(*tcpip.ErrWouldBlock); ok { + if cmp.Equal(&tcpip.ErrWouldBlock{}, err) { // Wait for connection to be established. select { case <-ch: @@ -6844,7 +6968,7 @@ func TestTCPTimeWaitNewSyn(t *testing.T) { defer wq.EventUnregister(&we) c.EP, _, err = ep.Accept(nil) - if _, ok := err.(*tcpip.ErrWouldBlock); ok { + if cmp.Equal(&tcpip.ErrWouldBlock{}, err) { // Wait for connection to be established. select { case <-ch: @@ -6934,7 +7058,7 @@ func TestTCPTimeWaitNewSyn(t *testing.T) { // Try to accept the connection. c.EP, _, err = ep.Accept(nil) - if _, ok := err.(*tcpip.ErrWouldBlock); ok { + if cmp.Equal(&tcpip.ErrWouldBlock{}, err) { // Wait for connection to be established. select { case <-ch: @@ -7008,7 +7132,7 @@ func TestTCPTimeWaitDuplicateFINExtendsTimeWait(t *testing.T) { defer wq.EventUnregister(&we) c.EP, _, err = ep.Accept(nil) - if _, ok := err.(*tcpip.ErrWouldBlock); ok { + if cmp.Equal(&tcpip.ErrWouldBlock{}, err) { // Wait for connection to be established. 
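The Accept-then-wait pattern above (try a non-blocking Accept, and on ErrWouldBlock block on the readable-event channel) recurs across many of these tests. A sketch of how it could be factored into a helper; acceptWithWait and the one-second timeout are illustrative choices, not part of the change:

package tcp_test

import (
    "testing"
    "time"

    "github.com/google/go-cmp/cmp"
    "gvisor.dev/gvisor/pkg/tcpip"
)

// acceptWithWait tries Accept once and, if the accept queue is still empty,
// waits for the listener's readable event before retrying.
func acceptWithWait(t *testing.T, ep tcpip.Endpoint, ch <-chan struct{}) tcpip.Endpoint {
    t.Helper()

    newEP, _, err := ep.Accept(nil)
    if cmp.Equal(&tcpip.ErrWouldBlock{}, err) {
        // Wait for the connection to be delivered to the accept queue.
        select {
        case <-ch:
            newEP, _, err = ep.Accept(nil)
            if err != nil {
                t.Fatalf("Accept failed: %s", err)
            }
        case <-time.After(1 * time.Second):
            t.Fatalf("timed out waiting for accept")
        }
    } else if err != nil {
        t.Fatalf("Accept failed: %s", err)
    }
    return newEP
}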
select { case <-ch: @@ -7158,7 +7282,7 @@ func TestTCPCloseWithData(t *testing.T) { defer wq.EventUnregister(&we) c.EP, _, err = ep.Accept(nil) - if _, ok := err.(*tcpip.ErrWouldBlock); ok { + if cmp.Equal(&tcpip.ErrWouldBlock{}, err) { // Wait for connection to be established. select { case <-ch: @@ -7553,8 +7677,7 @@ func TestIncreaseWindowOnBufferResize(t *testing.T) { // Increasing the buffer from should generate an ACK, // since window grew from small value to larger equal MSS - c.EP.SetSockOptInt(tcpip.ReceiveBufferSizeOption, rcvBuf*2) - + c.EP.SocketOptions().SetReceiveBufferSize(rcvBuf*2, true) checker.IPv4(t, c.GetPacket(), checker.PayloadLen(header.TCPMinimumSize), checker.TCP( @@ -7590,8 +7713,8 @@ func TestTCPDeferAccept(t *testing.T) { irs, iss := executeHandshake(t, c, context.TestPort, false /* synCookiesInUse */) _, _, err := c.EP.Accept(nil) - if _, ok := err.(*tcpip.ErrWouldBlock); !ok { - t.Fatalf("got c.EP.Accept(nil) = %v, want: %s", err, &tcpip.ErrWouldBlock{}) + if d := cmp.Diff(&tcpip.ErrWouldBlock{}, err); d != "" { + t.Fatalf("c.EP.Accept(nil) mismatch (-want +got):\n%s", d) } // Send data. This should result in an acceptable endpoint. @@ -7649,8 +7772,8 @@ func TestTCPDeferAcceptTimeout(t *testing.T) { irs, iss := executeHandshake(t, c, context.TestPort, false /* synCookiesInUse */) _, _, err := c.EP.Accept(nil) - if _, ok := err.(*tcpip.ErrWouldBlock); !ok { - t.Fatalf("got c.EP.Accept(nil) = %v, want: %s", err, &tcpip.ErrWouldBlock{}) + if d := cmp.Diff(&tcpip.ErrWouldBlock{}, err); d != "" { + t.Fatalf("c.EP.Accept(nil) mismatch (-want +got):\n%s", d) } // Sleep for a little of the tcpDeferAccept timeout. diff --git a/pkg/tcpip/transport/tcp/tcp_timestamp_test.go b/pkg/tcpip/transport/tcp/tcp_timestamp_test.go index 2949588ce..1deb1fe4d 100644 --- a/pkg/tcpip/transport/tcp/tcp_timestamp_test.go +++ b/pkg/tcpip/transport/tcp/tcp_timestamp_test.go @@ -139,9 +139,9 @@ func timeStampEnabledAccept(t *testing.T, cookieEnabled bool, wndScale int, wndS defer c.Cleanup() if cookieEnabled { - var opt tcpip.TCPSynRcvdCountThresholdOption + opt := tcpip.TCPAlwaysUseSynCookies(true) if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { - t.Fatalf("SetTransportProtocolOption(%d, &%T(%d)): %s", tcp.ProtocolNumber, opt, opt, err) + t.Fatalf("SetTransportProtocolOption(%d, &%T(%t)): %s", tcp.ProtocolNumber, opt, opt, err) } } @@ -202,9 +202,9 @@ func timeStampDisabledAccept(t *testing.T, cookieEnabled bool, wndScale int, wnd defer c.Cleanup() if cookieEnabled { - var opt tcpip.TCPSynRcvdCountThresholdOption + opt := tcpip.TCPAlwaysUseSynCookies(true) if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { - t.Fatalf("SetTransportProtocolOption(%d, &%T(%d)): %s", tcp.ProtocolNumber, opt, opt, err) + t.Fatalf("SetTransportProtocolOption(%d, &%T(%t)): %s", tcp.ProtocolNumber, opt, opt, err) } } diff --git a/pkg/tcpip/transport/tcp/testing/context/context.go b/pkg/tcpip/transport/tcp/testing/context/context.go index e73f90bb0..53efecc5a 100644 --- a/pkg/tcpip/transport/tcp/testing/context/context.go +++ b/pkg/tcpip/transport/tcp/testing/context/context.go @@ -331,8 +331,8 @@ func (c *Context) GetPacketWithTimeout(timeout time.Duration) []byte { vv := buffer.NewVectorisedView(p.Pkt.Size(), p.Pkt.Views()) b := vv.ToView() - if p.GSO != nil && p.GSO.L3HdrLen != header.IPv4MinimumSize { - c.t.Errorf("L3HdrLen %v (expected %v)", p.GSO.L3HdrLen, header.IPv4MinimumSize) + if p.Pkt.GSOOptions.Type != stack.GSONone && 
p.Pkt.GSOOptions.L3HdrLen != header.IPv4MinimumSize { + c.t.Errorf("got L3HdrLen = %d, want = %d", p.Pkt.GSOOptions.L3HdrLen, header.IPv4MinimumSize) } checker.IPv4(c.t, b, checker.SrcAddr(StackAddr), checker.DstAddr(TestAddr)) @@ -757,9 +757,7 @@ func (c *Context) Create(epRcvBuf int) { } if epRcvBuf != -1 { - if err := c.EP.SetSockOptInt(tcpip.ReceiveBufferSizeOption, epRcvBuf); err != nil { - c.t.Fatalf("SetSockOpt failed failed: %v", err) - } + c.EP.SocketOptions().SetReceiveBufferSize(int64(epRcvBuf), true /* notify */) } } @@ -1216,9 +1214,9 @@ func (c *Context) SACKEnabled() bool { // SetGSOEnabled enables or disables generic segmentation offload. func (c *Context) SetGSOEnabled(enable bool) { if enable { - c.linkEP.LinkEPCapabilities |= stack.CapabilityHardwareGSO + c.linkEP.SupportedGSOKind = stack.HWGSOSupported } else { - c.linkEP.LinkEPCapabilities &^= stack.CapabilityHardwareGSO + c.linkEP.SupportedGSOKind = stack.GSONotSupported } } diff --git a/pkg/tcpip/transport/udp/BUILD b/pkg/tcpip/transport/udp/BUILD index 153e8c950..dd5c910ae 100644 --- a/pkg/tcpip/transport/udp/BUILD +++ b/pkg/tcpip/transport/udp/BUILD @@ -56,6 +56,7 @@ go_test( "//pkg/tcpip/network/ipv4", "//pkg/tcpip/network/ipv6", "//pkg/tcpip/stack", + "//pkg/tcpip/testutil", "//pkg/tcpip/transport/icmp", "//pkg/waiter", "@com_github_google_go_cmp//cmp:go_default_library", diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go index 956da0e0c..f7dd50d35 100644 --- a/pkg/tcpip/transport/udp/endpoint.go +++ b/pkg/tcpip/transport/udp/endpoint.go @@ -15,7 +15,6 @@ package udp import ( - "fmt" "io" "sync/atomic" @@ -89,12 +88,11 @@ type endpoint struct { // The following fields are used to manage the receive queue, and are // protected by rcvMu. - rcvMu sync.Mutex `state:"nosave"` - rcvReady bool - rcvList udpPacketList - rcvBufSizeMax int `state:".(int)"` - rcvBufSize int - rcvClosed bool + rcvMu sync.Mutex `state:"nosave"` + rcvReady bool + rcvList udpPacketList + rcvBufSize int + rcvClosed bool // The following fields are protected by the mu mutex. mu sync.RWMutex `state:"nosave"` @@ -144,6 +142,10 @@ type endpoint struct { // ops is used to get socket level options. ops tcpip.SocketOptions + + // frozen indicates if the packets should be delivered to the endpoint + // during restore. + frozen bool } // +stateify savable @@ -173,14 +175,14 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQue // // Linux defaults to TTL=1. multicastTTL: 1, - rcvBufSizeMax: 32 * 1024, multicastMemberships: make(map[multicastMembership]struct{}), state: StateInitial, uniqueID: s.UniqueID(), } - e.ops.InitHandler(e, e.stack, tcpip.GetStackSendBufferLimits) + e.ops.InitHandler(e, e.stack, tcpip.GetStackSendBufferLimits, tcpip.GetStackReceiveBufferLimits) e.ops.SetMulticastLoop(true) e.ops.SetSendBufferSize(32*1024, false /* notify */) + e.ops.SetReceiveBufferSize(32*1024, false /* notify */) // Override with stack defaults. 
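The UDP endpoint now routes its receive buffer size through tcpip.SocketOptions, with stack-wide limits supplied via GetStackReceiveBufferLimits at InitHandler time, instead of the endpoint-local rcvBufSizeMax field. A brief caller-side sketch, assuming an already-created endpoint ep as in the tests above:

package example

import (
    "fmt"

    "gvisor.dev/gvisor/pkg/tcpip"
)

// adjustReceiveBuffer shows the option API this diff migrates to: the shared
// SocketOptions handler clamps the requested size against stack limits.
func adjustReceiveBuffer(ep tcpip.Endpoint) {
    // notify=true lets the owning protocol react to the change; the TCP test
    // above relies on this to trigger a window-update ACK when the buffer grows.
    ep.SocketOptions().SetReceiveBufferSize(64<<10, true /* notify */)

    // The getter returns the clamped value actually in effect.
    fmt.Println("receive buffer:", ep.SocketOptions().GetReceiveBufferSize())
}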
var ss tcpip.SendBufferSizeOption @@ -188,9 +190,9 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQue e.ops.SetSendBufferSize(int64(ss.Default), false /* notify */) } - var rs stack.ReceiveBufferSizeOption + var rs tcpip.ReceiveBufferSizeOption if err := s.Option(&rs); err == nil { - e.rcvBufSizeMax = rs.Default + e.ops.SetReceiveBufferSize(int64(rs.Default), false /* notify */) } return e @@ -622,26 +624,6 @@ func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) tcpip.Error { e.mu.Lock() e.sendTOS = uint8(v) e.mu.Unlock() - - case tcpip.ReceiveBufferSizeOption: - // Make sure the receive buffer size is within the min and max - // allowed. - var rs stack.ReceiveBufferSizeOption - if err := e.stack.Option(&rs); err != nil { - panic(fmt.Sprintf("e.stack.Option(%#v) = %s", rs, err)) - } - - if v < rs.Min { - v = rs.Min - } - if v > rs.Max { - v = rs.Max - } - - e.mu.Lock() - e.rcvBufSizeMax = v - e.mu.Unlock() - return nil } return nil @@ -802,12 +784,6 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, tcpip.Error) { e.rcvMu.Unlock() return v, nil - case tcpip.ReceiveBufferSizeOption: - e.rcvMu.Lock() - v := e.rcvBufSizeMax - e.rcvMu.Unlock() - return v, nil - case tcpip.TTLOption: e.mu.Lock() v := int(e.ttl) @@ -872,7 +848,7 @@ func sendUDP(r *stack.Route, data buffer.VectorisedView, localPort, remotePort u if useDefaultTTL { ttl = r.DefaultTTL() } - if err := r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{ + if err := r.WritePacket(stack.NetworkHeaderParams{ Protocol: ProtocolNumber, TTL: ttl, TOS: tos, @@ -1255,20 +1231,29 @@ func (e *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask { } // verifyChecksum verifies the checksum unless RX checksum offload is enabled. -// On IPv4, UDP checksum is optional, and a zero value means the transmitter -// omitted the checksum generation (RFC768). -// On IPv6, UDP checksum is not optional (RFC2460 Section 8.1). func verifyChecksum(hdr header.UDP, pkt *stack.PacketBuffer) bool { - if !pkt.RXTransportChecksumValidated && - (hdr.Checksum() != 0 || pkt.NetworkProtocolNumber == header.IPv6ProtocolNumber) { - netHdr := pkt.Network() - xsum := header.PseudoHeaderChecksum(ProtocolNumber, netHdr.DestinationAddress(), netHdr.SourceAddress(), hdr.Length()) - for _, v := range pkt.Data().Views() { - xsum = header.Checksum(v, xsum) - } - return hdr.CalculateChecksum(xsum) == 0xffff + if pkt.RXTransportChecksumValidated { + return true + } + + // On IPv4, UDP checksum is optional, and a zero value means the transmitter + // omitted the checksum generation, as per RFC 768: + // + // An all zero transmitted checksum value means that the transmitter + // generated no checksum (for debugging or for higher level protocols that + // don't care). + // + // On IPv6, UDP checksum is not optional, as per RFC 2460 Section 8.1: + // + // Unlike IPv4, when UDP packets are originated by an IPv6 node, the UDP + // checksum is not optional. + if pkt.NetworkProtocolNumber == header.IPv4ProtocolNumber && hdr.Checksum() == 0 { + return true } - return true + + netHdr := pkt.Network() + payloadChecksum := pkt.Data().AsRange().Checksum() + return hdr.IsChecksumValid(netHdr.SourceAddress(), netHdr.DestinationAddress(), payloadChecksum) } // HandlePacket is called by the stack when new packets arrive to this transport @@ -1284,7 +1269,6 @@ func (e *endpoint) HandlePacket(id stack.TransportEndpointID, pkt *stack.PacketB } if !verifyChecksum(hdr, pkt) { - // Checksum Error. 
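verifyChecksum now defers the arithmetic to hdr.IsChecksumValid after handling the RX-offload and IPv4 zero-checksum cases. For reference, a self-contained toy model of the underlying ones'-complement pseudo-header checksum; it mirrors the idea rather than gvisor's header package internals:

package main

import "fmt"

// onesComplementSum folds 16-bit words into a ones'-complement sum, the
// accumulation used by the Internet checksum.
func onesComplementSum(b []byte, sum uint32) uint32 {
    for i := 0; i+1 < len(b); i += 2 {
        sum += uint32(b[i])<<8 | uint32(b[i+1])
    }
    if len(b)%2 == 1 {
        sum += uint32(b[len(b)-1]) << 8
    }
    for sum>>16 != 0 {
        sum = (sum & 0xffff) + (sum >> 16)
    }
    return sum
}

// udpChecksumValid is a toy model of the check verifyChecksum performs: sum
// the pseudo-header (source, destination, protocol, length), the UDP header,
// and the payload; a valid packet sums to 0xffff.
func udpChecksumValid(src, dst, udpHdrAndPayload []byte) bool {
    const udpProtocol = 17
    length := len(udpHdrAndPayload)

    sum := onesComplementSum(src, 0)
    sum = onesComplementSum(dst, sum)
    sum = onesComplementSum([]byte{0, udpProtocol, byte(length >> 8), byte(length)}, sum)
    sum = onesComplementSum(udpHdrAndPayload, sum)
    return sum == 0xffff
}

func main() {
    src := []byte{192, 0, 2, 1}
    dst := []byte{192, 0, 2, 2}
    // UDP header: src port 12345, dst port 53, length 8, checksum 0 for now.
    pkt := []byte{0x30, 0x39, 0x00, 0x35, 0x00, 0x08, 0x00, 0x00}

    // Fill in the checksum field: the ones' complement of the sum computed
    // with a zero checksum field.
    sum := onesComplementSum(src, 0)
    sum = onesComplementSum(dst, sum)
    sum = onesComplementSum([]byte{0, 17, 0x00, 0x08}, sum)
    sum = onesComplementSum(pkt, sum)
    csum := ^uint16(sum)
    pkt[6], pkt[7] = byte(csum>>8), byte(csum)

    fmt.Println(udpChecksumValid(src, dst, pkt)) // true
}

The IPv4 zero-checksum early return in the diff above sits before this arithmetic: per RFC 768, an all-zero transmitted checksum means the sender generated none, so the packet is accepted without verification.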
e.stack.Stats().UDP.ChecksumErrors.Increment() e.stats.ReceiveErrors.ChecksumErrors.Increment() return @@ -1302,7 +1286,8 @@ func (e *endpoint) HandlePacket(id stack.TransportEndpointID, pkt *stack.PacketB return } - if e.rcvBufSize >= e.rcvBufSizeMax { + rcvBufSize := e.ops.GetReceiveBufferSize() + if e.frozen || e.rcvBufSize >= int(rcvBufSize) { e.rcvMu.Unlock() e.stack.Stats().UDP.ReceiveBufferErrors.Increment() e.stats.ReceiveErrors.ReceiveBufferOverflow.Increment() @@ -1436,3 +1421,18 @@ func (e *endpoint) SetOwner(owner tcpip.PacketOwner) { func (e *endpoint) SocketOptions() *tcpip.SocketOptions { return &e.ops } + +// freeze prevents any more packets from being delivered to the endpoint. +func (e *endpoint) freeze() { + e.mu.Lock() + e.frozen = true + e.mu.Unlock() +} + +// thaw unfreezes a previously frozen endpoint using endpoint.freeze() allows +// new packets to be delivered again. +func (e *endpoint) thaw() { + e.mu.Lock() + e.frozen = false + e.mu.Unlock() +} diff --git a/pkg/tcpip/transport/udp/endpoint_state.go b/pkg/tcpip/transport/udp/endpoint_state.go index 21a6aa460..4aba68b21 100644 --- a/pkg/tcpip/transport/udp/endpoint_state.go +++ b/pkg/tcpip/transport/udp/endpoint_state.go @@ -37,43 +37,25 @@ func (u *udpPacket) loadData(data buffer.VectorisedView) { u.data = data } -// beforeSave is invoked by stateify. -func (e *endpoint) beforeSave() { - // Stop incoming packets from being handled (and mutate endpoint state). - // The lock will be released after savercvBufSizeMax(), which would have - // saved e.rcvBufSizeMax and set it to 0 to continue blocking incoming - // packets. - e.rcvMu.Lock() -} - -// saveRcvBufSizeMax is invoked by stateify. -func (e *endpoint) saveRcvBufSizeMax() int { - max := e.rcvBufSizeMax - // Make sure no new packets will be handled regardless of the lock. - e.rcvBufSizeMax = 0 - // Release the lock acquired in beforeSave() so regular endpoint closing - // logic can proceed after save. - e.rcvMu.Unlock() - return max -} - -// loadRcvBufSizeMax is invoked by stateify. -func (e *endpoint) loadRcvBufSizeMax(max int) { - e.rcvBufSizeMax = max -} - // afterLoad is invoked by stateify. func (e *endpoint) afterLoad() { stack.StackFromEnv.RegisterRestoredEndpoint(e) } +// beforeSave is invoked by stateify. +func (e *endpoint) beforeSave() { + e.freeze() +} + // Resume implements tcpip.ResumableEndpoint.Resume. 
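The new freeze/thaw pair replaces the old save-time trick of zeroing rcvBufSizeMax: beforeSave sets an explicit frozen flag, HandlePacket drops packets while it is set, and Resume clears it. A small self-contained model of that gate; the names and types here are illustrative only:

package main

import (
    "fmt"
    "sync"
)

// frozenQueue models a receive path that can be frozen around save/restore.
type frozenQueue struct {
    mu      sync.Mutex
    frozen  bool
    packets []string
}

func (q *frozenQueue) freeze() { q.mu.Lock(); q.frozen = true; q.mu.Unlock() }
func (q *frozenQueue) thaw()   { q.mu.Lock(); q.frozen = false; q.mu.Unlock() }

// deliver queues a packet unless the endpoint is frozen for save/restore.
func (q *frozenQueue) deliver(pkt string) bool {
    q.mu.Lock()
    defer q.mu.Unlock()
    if q.frozen {
        return false
    }
    q.packets = append(q.packets, pkt)
    return true
}

func main() {
    q := &frozenQueue{}
    fmt.Println(q.deliver("before save"))   // true
    q.freeze()
    fmt.Println(q.deliver("during save"))   // false: dropped while frozen
    q.thaw()
    fmt.Println(q.deliver("after restore")) // true
}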
func (e *endpoint) Resume(s *stack.Stack) { + e.thaw() + e.mu.Lock() defer e.mu.Unlock() e.stack = s - e.ops.InitHandler(e, e.stack, tcpip.GetStackSendBufferLimits) + e.ops.InitHandler(e, e.stack, tcpip.GetStackSendBufferLimits, tcpip.GetStackReceiveBufferLimits) for m := range e.multicastMemberships { if err := e.stack.JoinGroup(e.NetProto, m.nicID, m.multicastAddr); err != nil { diff --git a/pkg/tcpip/transport/udp/udp_test.go b/pkg/tcpip/transport/udp/udp_test.go index 77ca70a04..dc2e3f493 100644 --- a/pkg/tcpip/transport/udp/udp_test.go +++ b/pkg/tcpip/transport/udp/udp_test.go @@ -34,6 +34,7 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" "gvisor.dev/gvisor/pkg/tcpip/stack" + "gvisor.dev/gvisor/pkg/tcpip/testutil" "gvisor.dev/gvisor/pkg/tcpip/transport/icmp" "gvisor.dev/gvisor/pkg/tcpip/transport/udp" "gvisor.dev/gvisor/pkg/waiter" @@ -2364,7 +2365,7 @@ func TestOutgoingSubnetBroadcast(t *testing.T) { } ipv4Subnet := ipv4Addr.Subnet() ipv4SubnetBcast := ipv4Subnet.Broadcast() - ipv4Gateway := tcpip.Address("\xc0\xa8\x01\x01") + ipv4Gateway := testutil.MustParse4("192.168.1.1") ipv4AddrPrefix31 := tcpip.AddressWithPrefix{ Address: "\xc0\xa8\x01\x3a", PrefixLen: 31, |
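The test addresses migrate from escaped byte strings to testutil.MustParse4/MustParse6. The helper's implementation is outside this section; a plausible sketch, relying only on tcpip.Address being a byte-string type as the replaced literals show:

package main

import (
    "fmt"
    "net"

    "gvisor.dev/gvisor/pkg/tcpip"
)

// mustParse4 sketches what a MustParse4-style helper could look like: it
// turns a dotted-quad string into the 4-byte address the tests used to spell
// out as an escaped byte string.
func mustParse4(addr string) tcpip.Address {
    ip := net.ParseIP(addr).To4()
    if ip == nil {
        panic(fmt.Sprintf("mustParse4(%q): not a valid IPv4 address", addr))
    }
    return tcpip.Address(ip)
}

func main() {
    // "192.168.1.1" and "\xc0\xa8\x01\x01" denote the same 4-byte address.
    fmt.Println(mustParse4("192.168.1.1") == tcpip.Address("\xc0\xa8\x01\x01")) // true
}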