diff options
Diffstat (limited to 'pkg/tcpip')
57 files changed, 2714 insertions, 724 deletions
diff --git a/pkg/tcpip/BUILD b/pkg/tcpip/BUILD index e07ebd153..ebc8d0209 100644 --- a/pkg/tcpip/BUILD +++ b/pkg/tcpip/BUILD @@ -15,6 +15,7 @@ go_library( importpath = "gvisor.dev/gvisor/pkg/tcpip", visibility = ["//visibility:public"], deps = [ + "//pkg/sync", "//pkg/tcpip/buffer", "//pkg/tcpip/iptables", "//pkg/waiter", @@ -29,7 +30,7 @@ go_test( ) go_test( - name = "timer_test", + name = "tcpip_x_test", size = "small", srcs = ["timer_test.go"], deps = [":tcpip"], diff --git a/pkg/tcpip/adapters/gonet/BUILD b/pkg/tcpip/adapters/gonet/BUILD index 78df5a0b1..3df7d18d3 100644 --- a/pkg/tcpip/adapters/gonet/BUILD +++ b/pkg/tcpip/adapters/gonet/BUILD @@ -9,6 +9,7 @@ go_library( importpath = "gvisor.dev/gvisor/pkg/tcpip/adapters/gonet", visibility = ["//visibility:public"], deps = [ + "//pkg/sync", "//pkg/tcpip", "//pkg/tcpip/buffer", "//pkg/tcpip/stack", diff --git a/pkg/tcpip/adapters/gonet/gonet.go b/pkg/tcpip/adapters/gonet/gonet.go index cd6ce930a..a2f44b496 100644 --- a/pkg/tcpip/adapters/gonet/gonet.go +++ b/pkg/tcpip/adapters/gonet/gonet.go @@ -20,9 +20,9 @@ import ( "errors" "io" "net" - "sync" "time" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/stack" diff --git a/pkg/tcpip/checker/checker.go b/pkg/tcpip/checker/checker.go index 2f15bf1f1..885d773b0 100644 --- a/pkg/tcpip/checker/checker.go +++ b/pkg/tcpip/checker/checker.go @@ -33,6 +33,9 @@ type NetworkChecker func(*testing.T, []header.Network) // TransportChecker is a function to check a property of a transport packet. type TransportChecker func(*testing.T, header.Transport) +// ControlMessagesChecker is a function to check a property of ancillary data. +type ControlMessagesChecker func(*testing.T, tcpip.ControlMessages) + // IPv4 checks the validity and properties of the given IPv4 packet. It is // expected to be used in conjunction with other network checkers for specific // properties. For example, to check the source and destination address, one @@ -158,6 +161,19 @@ func FragmentFlags(flags uint8) NetworkChecker { } } +// ReceiveTOS creates a checker that checks the TOS field in ControlMessages. +func ReceiveTOS(want uint8) ControlMessagesChecker { + return func(t *testing.T, cm tcpip.ControlMessages) { + t.Helper() + if !cm.HasTOS { + t.Fatalf("got cm.HasTOS = %t, want cm.TOS = %d", cm.HasTOS, want) + } + if got := cm.TOS; got != want { + t.Fatalf("got cm.TOS = %d, want %d", got, want) + } + } +} + // TOS creates a checker that checks the TOS field. func TOS(tos uint8, label uint32) NetworkChecker { return func(t *testing.T, h []header.Network) { @@ -754,3 +770,9 @@ func NDPNSTargetAddress(want tcpip.Address) TransportChecker { } } } + +// NDPRS creates a checker that checks that the packet contains a valid NDP +// Router Solicitation message (as per the raw wire format). +func NDPRS() NetworkChecker { + return NDP(header.ICMPv6RouterSolicit, header.NDPRSMinimumSize) +} diff --git a/pkg/tcpip/header/BUILD b/pkg/tcpip/header/BUILD index f2061c778..cd747d100 100644 --- a/pkg/tcpip/header/BUILD +++ b/pkg/tcpip/header/BUILD @@ -20,6 +20,7 @@ go_library( "ndp_neighbor_solicit.go", "ndp_options.go", "ndp_router_advert.go", + "ndp_router_solicit.go", "tcp.go", "udp.go", ], diff --git a/pkg/tcpip/header/ipv6.go b/pkg/tcpip/header/ipv6.go index 135a60b12..70e6ce095 100644 --- a/pkg/tcpip/header/ipv6.go +++ b/pkg/tcpip/header/ipv6.go @@ -84,6 +84,13 @@ const ( // The address is ff02::1. IPv6AllNodesMulticastAddress tcpip.Address = "\xff\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01" + // IPv6AllRoutersMulticastAddress is a link-local multicast group that + // all IPv6 routers MUST join, as per RFC 4291, section 2.8. Packets + // destined to this address will reach all routers on a link. + // + // The address is ff02::2. + IPv6AllRoutersMulticastAddress tcpip.Address = "\xff\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02" + // IPv6MinimumMTU is the minimum MTU required by IPv6, per RFC 2460, // section 5. IPv6MinimumMTU = 1280 @@ -333,6 +340,17 @@ func IsV6LinkLocalAddress(addr tcpip.Address) bool { return addr[0] == 0xfe && (addr[1]&0xc0) == 0x80 } +// IsV6UniqueLocalAddress determines if the provided address is an IPv6 +// unique-local address (within the prefix FC00::/7). +func IsV6UniqueLocalAddress(addr tcpip.Address) bool { + if len(addr) != IPv6AddressSize { + return false + } + // According to RFC 4193 section 3.1, a unique local address has the prefix + // FC00::/7. + return (addr[0] & 0xfe) == 0xfc +} + // AppendOpaqueInterfaceIdentifier appends a 64 bit opaque interface identifier // (IID) to buf as outlined by RFC 7217 and returns the extended buffer. // @@ -371,3 +389,35 @@ func LinkLocalAddrWithOpaqueIID(nicName string, dadCounter uint8, secretKey []by return tcpip.Address(AppendOpaqueInterfaceIdentifier(lladdrb[:IIDOffsetInIPv6Address], IPv6LinkLocalPrefix.Subnet(), nicName, dadCounter, secretKey)) } + +// IPv6AddressScope is the scope of an IPv6 address. +type IPv6AddressScope int + +const ( + // LinkLocalScope indicates a link-local address. + LinkLocalScope IPv6AddressScope = iota + + // UniqueLocalScope indicates a unique-local address. + UniqueLocalScope + + // GlobalScope indicates a global address. + GlobalScope +) + +// ScopeForIPv6Address returns the scope for an IPv6 address. +func ScopeForIPv6Address(addr tcpip.Address) (IPv6AddressScope, *tcpip.Error) { + if len(addr) != IPv6AddressSize { + return GlobalScope, tcpip.ErrBadAddress + } + + switch { + case IsV6LinkLocalAddress(addr): + return LinkLocalScope, nil + + case IsV6UniqueLocalAddress(addr): + return UniqueLocalScope, nil + + default: + return GlobalScope, nil + } +} diff --git a/pkg/tcpip/header/ipv6_test.go b/pkg/tcpip/header/ipv6_test.go index 1994003ed..29f54bc57 100644 --- a/pkg/tcpip/header/ipv6_test.go +++ b/pkg/tcpip/header/ipv6_test.go @@ -25,7 +25,13 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/header" ) -const linkAddr = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x06") +const ( + linkAddr = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x06") + linkLocalAddr = tcpip.Address("\xfe\x80\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01") + uniqueLocalAddr1 = tcpip.Address("\xfc\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01") + uniqueLocalAddr2 = tcpip.Address("\xfd\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02") + globalAddr = tcpip.Address("\xa0\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01") +) func TestEthernetAdddressToModifiedEUI64(t *testing.T) { expectedIID := [header.IIDSize]byte{0, 2, 3, 255, 254, 4, 5, 6} @@ -206,3 +212,91 @@ func TestLinkLocalAddrWithOpaqueIID(t *testing.T) { }) } } + +func TestIsV6UniqueLocalAddress(t *testing.T) { + tests := []struct { + name string + addr tcpip.Address + expected bool + }{ + { + name: "Valid Unique 1", + addr: uniqueLocalAddr1, + expected: true, + }, + { + name: "Valid Unique 2", + addr: uniqueLocalAddr1, + expected: true, + }, + { + name: "Link Local", + addr: linkLocalAddr, + expected: false, + }, + { + name: "Global", + addr: globalAddr, + expected: false, + }, + { + name: "IPv4", + addr: "\x01\x02\x03\x04", + expected: false, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + if got := header.IsV6UniqueLocalAddress(test.addr); got != test.expected { + t.Errorf("got header.IsV6UniqueLocalAddress(%s) = %t, want = %t", test.addr, got, test.expected) + } + }) + } +} + +func TestScopeForIPv6Address(t *testing.T) { + tests := []struct { + name string + addr tcpip.Address + scope header.IPv6AddressScope + err *tcpip.Error + }{ + { + name: "Unique Local", + addr: uniqueLocalAddr1, + scope: header.UniqueLocalScope, + err: nil, + }, + { + name: "Link Local", + addr: linkLocalAddr, + scope: header.LinkLocalScope, + err: nil, + }, + { + name: "Global", + addr: globalAddr, + scope: header.GlobalScope, + err: nil, + }, + { + name: "IPv4", + addr: "\x01\x02\x03\x04", + scope: header.GlobalScope, + err: tcpip.ErrBadAddress, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + got, err := header.ScopeForIPv6Address(test.addr) + if err != test.err { + t.Errorf("got header.IsV6UniqueLocalAddress(%s) = (_, %v), want = (_, %v)", test.addr, err, test.err) + } + if got != test.scope { + t.Errorf("got header.IsV6UniqueLocalAddress(%s) = (%d, _), want = (%d, _)", test.addr, got, test.scope) + } + }) + } +} diff --git a/pkg/tcpip/header/ndp_router_solicit.go b/pkg/tcpip/header/ndp_router_solicit.go new file mode 100644 index 000000000..9e67ba95d --- /dev/null +++ b/pkg/tcpip/header/ndp_router_solicit.go @@ -0,0 +1,36 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package header + +// NDPRouterSolicit is an NDP Router Solicitation message. It will only contain +// the body of an ICMPv6 packet. +// +// See RFC 4861 section 4.1 for more details. +type NDPRouterSolicit []byte + +const ( + // NDPRSMinimumSize is the minimum size of a valid NDP Router + // Solicitation message (body of an ICMPv6 packet). + NDPRSMinimumSize = 4 + + // ndpRSOptionsOffset is the start of the NDP options in an + // NDPRouterSolicit. + ndpRSOptionsOffset = 4 +) + +// Options returns an NDPOptions of the the options body. +func (b NDPRouterSolicit) Options() NDPOptions { + return NDPOptions(b[ndpRSOptionsOffset:]) +} diff --git a/pkg/tcpip/iptables/BUILD b/pkg/tcpip/iptables/BUILD index cc5f531e2..64769c333 100644 --- a/pkg/tcpip/iptables/BUILD +++ b/pkg/tcpip/iptables/BUILD @@ -11,5 +11,8 @@ go_library( ], importpath = "gvisor.dev/gvisor/pkg/tcpip/iptables", visibility = ["//visibility:public"], - deps = ["//pkg/tcpip/buffer"], + deps = [ + "//pkg/log", + "//pkg/tcpip/buffer", + ], ) diff --git a/pkg/tcpip/iptables/iptables.go b/pkg/tcpip/iptables/iptables.go index 68c68d4aa..647970133 100644 --- a/pkg/tcpip/iptables/iptables.go +++ b/pkg/tcpip/iptables/iptables.go @@ -16,66 +16,114 @@ // tool. package iptables +// Table names. const ( - tablenameNat = "nat" - tablenameMangle = "mangle" + TablenameNat = "nat" + TablenameMangle = "mangle" + TablenameFilter = "filter" ) // Chain names as defined by net/ipv4/netfilter/ip_tables.c. const ( - chainNamePrerouting = "PREROUTING" - chainNameInput = "INPUT" - chainNameForward = "FORWARD" - chainNameOutput = "OUTPUT" - chainNamePostrouting = "POSTROUTING" + ChainNamePrerouting = "PREROUTING" + ChainNameInput = "INPUT" + ChainNameForward = "FORWARD" + ChainNameOutput = "OUTPUT" + ChainNamePostrouting = "POSTROUTING" ) +// HookUnset indicates that there is no hook set for an entrypoint or +// underflow. +const HookUnset = -1 + // DefaultTables returns a default set of tables. Each chain is set to accept // all packets. func DefaultTables() IPTables { + // TODO(gvisor.dev/issue/170): We may be able to swap out some strings for + // iotas. return IPTables{ Tables: map[string]Table{ - tablenameNat: Table{ - BuiltinChains: map[Hook]Chain{ - Prerouting: unconditionalAcceptChain(chainNamePrerouting), - Input: unconditionalAcceptChain(chainNameInput), - Output: unconditionalAcceptChain(chainNameOutput), - Postrouting: unconditionalAcceptChain(chainNamePostrouting), + TablenameNat: Table{ + Rules: []Rule{ + Rule{Target: UnconditionalAcceptTarget{}}, + Rule{Target: UnconditionalAcceptTarget{}}, + Rule{Target: UnconditionalAcceptTarget{}}, + Rule{Target: UnconditionalAcceptTarget{}}, + Rule{Target: ErrorTarget{}}, + }, + BuiltinChains: map[Hook]int{ + Prerouting: 0, + Input: 1, + Output: 2, + Postrouting: 3, }, - DefaultTargets: map[Hook]Target{ - Prerouting: UnconditionalAcceptTarget{}, - Input: UnconditionalAcceptTarget{}, - Output: UnconditionalAcceptTarget{}, - Postrouting: UnconditionalAcceptTarget{}, + Underflows: map[Hook]int{ + Prerouting: 0, + Input: 1, + Output: 2, + Postrouting: 3, }, - UserChains: map[string]Chain{}, + UserChains: map[string]int{}, }, - tablenameMangle: Table{ - BuiltinChains: map[Hook]Chain{ - Prerouting: unconditionalAcceptChain(chainNamePrerouting), - Output: unconditionalAcceptChain(chainNameOutput), + TablenameMangle: Table{ + Rules: []Rule{ + Rule{Target: UnconditionalAcceptTarget{}}, + Rule{Target: UnconditionalAcceptTarget{}}, + Rule{Target: ErrorTarget{}}, + }, + BuiltinChains: map[Hook]int{ + Prerouting: 0, + Output: 1, }, - DefaultTargets: map[Hook]Target{ - Prerouting: UnconditionalAcceptTarget{}, - Output: UnconditionalAcceptTarget{}, + Underflows: map[Hook]int{ + Prerouting: 0, + Output: 1, }, - UserChains: map[string]Chain{}, + UserChains: map[string]int{}, + }, + TablenameFilter: Table{ + Rules: []Rule{ + Rule{Target: UnconditionalAcceptTarget{}}, + Rule{Target: UnconditionalAcceptTarget{}}, + Rule{Target: UnconditionalAcceptTarget{}}, + Rule{Target: ErrorTarget{}}, + }, + BuiltinChains: map[Hook]int{ + Input: 0, + Forward: 1, + Output: 2, + }, + Underflows: map[Hook]int{ + Input: 0, + Forward: 1, + Output: 2, + }, + UserChains: map[string]int{}, }, }, Priorities: map[Hook][]string{ - Prerouting: []string{tablenameMangle, tablenameNat}, - Output: []string{tablenameMangle, tablenameNat}, + Input: []string{TablenameNat, TablenameFilter}, + Prerouting: []string{TablenameMangle, TablenameNat}, + Output: []string{TablenameMangle, TablenameNat, TablenameFilter}, }, } } -func unconditionalAcceptChain(name string) Chain { - return Chain{ - Name: name, - Rules: []Rule{ - Rule{ - Target: UnconditionalAcceptTarget{}, - }, +// EmptyFilterTable returns a Table with no rules and the filter table chains +// mapped to HookUnset. +func EmptyFilterTable() Table { + return Table{ + Rules: []Rule{}, + BuiltinChains: map[Hook]int{ + Input: HookUnset, + Forward: HookUnset, + Output: HookUnset, + }, + Underflows: map[Hook]int{ + Input: HookUnset, + Forward: HookUnset, + Output: HookUnset, }, + UserChains: map[string]int{}, } } diff --git a/pkg/tcpip/iptables/targets.go b/pkg/tcpip/iptables/targets.go index 19a7f77e3..b94a4c941 100644 --- a/pkg/tcpip/iptables/targets.go +++ b/pkg/tcpip/iptables/targets.go @@ -16,7 +16,10 @@ package iptables -import "gvisor.dev/gvisor/pkg/tcpip/buffer" +import ( + "gvisor.dev/gvisor/pkg/log" + "gvisor.dev/gvisor/pkg/tcpip/buffer" +) // UnconditionalAcceptTarget accepts all packets. type UnconditionalAcceptTarget struct{} @@ -33,3 +36,14 @@ type UnconditionalDropTarget struct{} func (UnconditionalDropTarget) Action(packet buffer.VectorisedView) (Verdict, string) { return Drop, "" } + +// ErrorTarget logs an error and drops the packet. It represents a target that +// should be unreachable. +type ErrorTarget struct{} + +// Action implements Target.Action. +func (ErrorTarget) Action(packet buffer.VectorisedView) (Verdict, string) { + log.Warningf("ErrorTarget triggered.") + return Drop, "" + +} diff --git a/pkg/tcpip/iptables/types.go b/pkg/tcpip/iptables/types.go index 42a79ef9f..540f8c0b4 100644 --- a/pkg/tcpip/iptables/types.go +++ b/pkg/tcpip/iptables/types.go @@ -61,9 +61,12 @@ const ( type Verdict int const ( + // Invalid indicates an unkonwn or erroneous verdict. + Invalid Verdict = iota + // Accept indicates the packet should continue traversing netstack as // normal. - Accept Verdict = iota + Accept // Drop inicates the packet should be dropped, stopping traversing // netstack. @@ -104,29 +107,22 @@ type IPTables struct { Priorities map[Hook][]string } -// A Table defines a set of chains and hooks into the network stack. The -// currently supported tables are: -// * nat -// * mangle +// A Table defines a set of chains and hooks into the network stack. It is +// really just a list of rules with some metadata for entrypoints and such. type Table struct { - // BuiltinChains holds the un-deletable chains built into netstack. If - // a hook isn't present in the map, this table doesn't utilize that - // hook. - BuiltinChains map[Hook]Chain + // Rules holds the rules that make up the table. + Rules []Rule - // DefaultTargets holds a target for each hook that will be executed if - // chain traversal doesn't yield a verdict. - DefaultTargets map[Hook]Target + // BuiltinChains maps builtin chains to their entrypoint rule in Rules. + BuiltinChains map[Hook]int + + // Underflows maps builtin chains to their underflow rule in Rules + // (i.e. the rule to execute if the chain returns without a verdict). + Underflows map[Hook]int // UserChains holds user-defined chains for the keyed by name. Users // can give their chains arbitrary names. - UserChains map[string]Chain - - // Chains maps names to chains for both builtin and user-defined chains. - // Its entries point to Chains already either in BuiltinChains or - // UserChains, and its purpose is to make looking up tables by name - // fast. - Chains map[string]*Chain + UserChains map[string]int // Metadata holds information about the Table that is useful to users // of IPTables, but not to the netstack IPTables code itself. @@ -152,21 +148,6 @@ func (table *Table) SetMetadata(metadata interface{}) { table.metadata = metadata } -// A Chain defines a list of rules for packet processing. When a packet -// traverses a chain, it is checked against each rule until either a rule -// returns a verdict or the chain ends. -// -// By convention, builtin chains end with a rule that matches everything and -// returns either Accept or Drop. User-defined chains end with Return. These -// aren't strictly necessary here, but the iptables tool writes tables this way. -type Chain struct { - // Name is the chain name. - Name string - - // Rules is the list of rules to traverse. - Rules []Rule -} - // A Rule is a packet processing rule. It consists of two pieces. First it // contains zero or more matchers, each of which is a specification of which // packets this rule applies to. If there are no matchers in the rule, it diff --git a/pkg/tcpip/link/fdbased/BUILD b/pkg/tcpip/link/fdbased/BUILD index 897c94821..66cc53ed4 100644 --- a/pkg/tcpip/link/fdbased/BUILD +++ b/pkg/tcpip/link/fdbased/BUILD @@ -16,6 +16,7 @@ go_library( importpath = "gvisor.dev/gvisor/pkg/tcpip/link/fdbased", visibility = ["//visibility:public"], deps = [ + "//pkg/sync", "//pkg/tcpip", "//pkg/tcpip/buffer", "//pkg/tcpip/header", diff --git a/pkg/tcpip/link/fdbased/endpoint.go b/pkg/tcpip/link/fdbased/endpoint.go index fa8a703d9..b7f60178e 100644 --- a/pkg/tcpip/link/fdbased/endpoint.go +++ b/pkg/tcpip/link/fdbased/endpoint.go @@ -41,10 +41,10 @@ package fdbased import ( "fmt" - "sync" "syscall" "golang.org/x/sys/unix" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" diff --git a/pkg/tcpip/link/sharedmem/BUILD b/pkg/tcpip/link/sharedmem/BUILD index a4f9cdd69..09165dd4c 100644 --- a/pkg/tcpip/link/sharedmem/BUILD +++ b/pkg/tcpip/link/sharedmem/BUILD @@ -15,6 +15,7 @@ go_library( visibility = ["//visibility:public"], deps = [ "//pkg/log", + "//pkg/sync", "//pkg/tcpip", "//pkg/tcpip/buffer", "//pkg/tcpip/header", @@ -31,6 +32,7 @@ go_test( ], embed = [":sharedmem"], deps = [ + "//pkg/sync", "//pkg/tcpip", "//pkg/tcpip/buffer", "//pkg/tcpip/header", diff --git a/pkg/tcpip/link/sharedmem/pipe/BUILD b/pkg/tcpip/link/sharedmem/pipe/BUILD index 6b5bc542c..a0d4ad0be 100644 --- a/pkg/tcpip/link/sharedmem/pipe/BUILD +++ b/pkg/tcpip/link/sharedmem/pipe/BUILD @@ -21,4 +21,5 @@ go_test( "pipe_test.go", ], embed = [":pipe"], + deps = ["//pkg/sync"], ) diff --git a/pkg/tcpip/link/sharedmem/pipe/pipe_test.go b/pkg/tcpip/link/sharedmem/pipe/pipe_test.go index 59ef69a8b..dc239a0d0 100644 --- a/pkg/tcpip/link/sharedmem/pipe/pipe_test.go +++ b/pkg/tcpip/link/sharedmem/pipe/pipe_test.go @@ -18,8 +18,9 @@ import ( "math/rand" "reflect" "runtime" - "sync" "testing" + + "gvisor.dev/gvisor/pkg/sync" ) func TestSimpleReadWrite(t *testing.T) { diff --git a/pkg/tcpip/link/sharedmem/sharedmem.go b/pkg/tcpip/link/sharedmem/sharedmem.go index 080f9d667..655e537c4 100644 --- a/pkg/tcpip/link/sharedmem/sharedmem.go +++ b/pkg/tcpip/link/sharedmem/sharedmem.go @@ -23,11 +23,11 @@ package sharedmem import ( - "sync" "sync/atomic" "syscall" "gvisor.dev/gvisor/pkg/log" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" diff --git a/pkg/tcpip/link/sharedmem/sharedmem_test.go b/pkg/tcpip/link/sharedmem/sharedmem_test.go index 89603c48f..5c729a439 100644 --- a/pkg/tcpip/link/sharedmem/sharedmem_test.go +++ b/pkg/tcpip/link/sharedmem/sharedmem_test.go @@ -22,11 +22,11 @@ import ( "math/rand" "os" "strings" - "sync" "syscall" "testing" "time" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" diff --git a/pkg/tcpip/network/fragmentation/BUILD b/pkg/tcpip/network/fragmentation/BUILD index acf1e022c..ed16076fd 100644 --- a/pkg/tcpip/network/fragmentation/BUILD +++ b/pkg/tcpip/network/fragmentation/BUILD @@ -28,6 +28,7 @@ go_library( visibility = ["//visibility:public"], deps = [ "//pkg/log", + "//pkg/sync", "//pkg/tcpip", "//pkg/tcpip/buffer", ], diff --git a/pkg/tcpip/network/fragmentation/fragmentation.go b/pkg/tcpip/network/fragmentation/fragmentation.go index 6da5238ec..92f2aa13a 100644 --- a/pkg/tcpip/network/fragmentation/fragmentation.go +++ b/pkg/tcpip/network/fragmentation/fragmentation.go @@ -19,9 +19,9 @@ package fragmentation import ( "fmt" "log" - "sync" "time" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip/buffer" ) diff --git a/pkg/tcpip/network/fragmentation/reassembler.go b/pkg/tcpip/network/fragmentation/reassembler.go index 9e002e396..0a83d81f2 100644 --- a/pkg/tcpip/network/fragmentation/reassembler.go +++ b/pkg/tcpip/network/fragmentation/reassembler.go @@ -18,9 +18,9 @@ import ( "container/heap" "fmt" "math" - "sync" "time" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip/buffer" ) diff --git a/pkg/tcpip/ports/BUILD b/pkg/tcpip/ports/BUILD index e156b01f6..a6ef3bdcc 100644 --- a/pkg/tcpip/ports/BUILD +++ b/pkg/tcpip/ports/BUILD @@ -9,6 +9,7 @@ go_library( importpath = "gvisor.dev/gvisor/pkg/tcpip/ports", visibility = ["//visibility:public"], deps = [ + "//pkg/sync", "//pkg/tcpip", ], ) diff --git a/pkg/tcpip/ports/ports.go b/pkg/tcpip/ports/ports.go index 6c5e19e8f..b937cb84b 100644 --- a/pkg/tcpip/ports/ports.go +++ b/pkg/tcpip/ports/ports.go @@ -18,9 +18,9 @@ package ports import ( "math" "math/rand" - "sync" "sync/atomic" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" ) diff --git a/pkg/tcpip/stack/BUILD b/pkg/tcpip/stack/BUILD index b8f9517d0..783351a69 100644 --- a/pkg/tcpip/stack/BUILD +++ b/pkg/tcpip/stack/BUILD @@ -36,6 +36,7 @@ go_library( "//pkg/ilist", "//pkg/rand", "//pkg/sleep", + "//pkg/sync", "//pkg/tcpip", "//pkg/tcpip/buffer", "//pkg/tcpip/hash/jenkins", @@ -50,7 +51,7 @@ go_library( go_test( name = "stack_x_test", - size = "small", + size = "medium", srcs = [ "ndp_test.go", "stack_test.go", @@ -83,6 +84,7 @@ go_test( embed = [":stack"], deps = [ "//pkg/sleep", + "//pkg/sync", "//pkg/tcpip", ], ) diff --git a/pkg/tcpip/stack/linkaddrcache.go b/pkg/tcpip/stack/linkaddrcache.go index 267df60d1..403557fd7 100644 --- a/pkg/tcpip/stack/linkaddrcache.go +++ b/pkg/tcpip/stack/linkaddrcache.go @@ -16,10 +16,10 @@ package stack import ( "fmt" - "sync" "time" "gvisor.dev/gvisor/pkg/sleep" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" ) diff --git a/pkg/tcpip/stack/linkaddrcache_test.go b/pkg/tcpip/stack/linkaddrcache_test.go index 9946b8fe8..1baa498d0 100644 --- a/pkg/tcpip/stack/linkaddrcache_test.go +++ b/pkg/tcpip/stack/linkaddrcache_test.go @@ -16,12 +16,12 @@ package stack import ( "fmt" - "sync" "sync/atomic" "testing" "time" "gvisor.dev/gvisor/pkg/sleep" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" ) diff --git a/pkg/tcpip/stack/ndp.go b/pkg/tcpip/stack/ndp.go index 35825ebf7..c99d387d5 100644 --- a/pkg/tcpip/stack/ndp.go +++ b/pkg/tcpip/stack/ndp.go @@ -17,6 +17,7 @@ package stack import ( "fmt" "log" + "math/rand" "time" "gvisor.dev/gvisor/pkg/tcpip" @@ -38,24 +39,36 @@ const ( // Default = 1s (from RFC 4861 section 10). defaultRetransmitTimer = time.Second + // defaultMaxRtrSolicitations is the default number of Router + // Solicitation messages to send when a NIC becomes enabled. + // + // Default = 3 (from RFC 4861 section 10). + defaultMaxRtrSolicitations = 3 + + // defaultRtrSolicitationInterval is the default amount of time between + // sending Router Solicitation messages. + // + // Default = 4s (from 4861 section 10). + defaultRtrSolicitationInterval = 4 * time.Second + + // defaultMaxRtrSolicitationDelay is the default maximum amount of time + // to wait before sending the first Router Solicitation message. + // + // Default = 1s (from 4861 section 10). + defaultMaxRtrSolicitationDelay = time.Second + // defaultHandleRAs is the default configuration for whether or not to // handle incoming Router Advertisements as a host. - // - // Default = true. defaultHandleRAs = true // defaultDiscoverDefaultRouters is the default configuration for // whether or not to discover default routers from incoming Router // Advertisements, as a host. - // - // Default = true. defaultDiscoverDefaultRouters = true // defaultDiscoverOnLinkPrefixes is the default configuration for // whether or not to discover on-link prefixes from incoming Router // Advertisements' Prefix Information option, as a host. - // - // Default = true. defaultDiscoverOnLinkPrefixes = true // defaultAutoGenGlobalAddresses is the default configuration for @@ -74,26 +87,31 @@ const ( // value of 0 means unspecified, so the smallest valid value is 1. // Note, the unit of the RetransmitTimer field in the Router // Advertisement is milliseconds. - // - // Min = 1ms. minimumRetransmitTimer = time.Millisecond + // minimumRtrSolicitationInterval is the minimum amount of time to wait + // between sending Router Solicitation messages. This limit is imposed + // to make sure that Router Solicitation messages are not sent all at + // once, defeating the purpose of sending the initial few messages. + minimumRtrSolicitationInterval = 500 * time.Millisecond + + // minimumMaxRtrSolicitationDelay is the minimum amount of time to wait + // before sending the first Router Solicitation message. It is 0 because + // we cannot have a negative delay. + minimumMaxRtrSolicitationDelay = 0 + // MaxDiscoveredDefaultRouters is the maximum number of discovered // default routers. The stack should stop discovering new routers after // discovering MaxDiscoveredDefaultRouters routers. // // This value MUST be at minimum 2 as per RFC 4861 section 6.3.4, and // SHOULD be more. - // - // Max = 10. MaxDiscoveredDefaultRouters = 10 // MaxDiscoveredOnLinkPrefixes is the maximum number of discovered // on-link prefixes. The stack should stop discovering new on-link // prefixes after discovering MaxDiscoveredOnLinkPrefixes on-link // prefixes. - // - // Max = 10. MaxDiscoveredOnLinkPrefixes = 10 // validPrefixLenForAutoGen is the expected prefix length that an @@ -115,6 +133,30 @@ var ( MinPrefixInformationValidLifetimeForUpdate = 2 * time.Hour ) +// DHCPv6ConfigurationFromNDPRA is a configuration available via DHCPv6 that an +// NDP Router Advertisement informed the Stack about. +type DHCPv6ConfigurationFromNDPRA int + +const ( + // DHCPv6NoConfiguration indicates that no configurations are available via + // DHCPv6. + DHCPv6NoConfiguration DHCPv6ConfigurationFromNDPRA = iota + + // DHCPv6ManagedAddress indicates that addresses are available via DHCPv6. + // + // DHCPv6ManagedAddress also implies DHCPv6OtherConfigurations because DHCPv6 + // will return all available configuration information. + DHCPv6ManagedAddress + + // DHCPv6OtherConfigurations indicates that other configuration information is + // available via DHCPv6. + // + // Other configurations are configurations other than addresses. Examples of + // other configurations are recursive DNS server list, DNS search lists and + // default gateway. + DHCPv6OtherConfigurations +) + // NDPDispatcher is the interface integrators of netstack must implement to // receive and handle NDP related events. type NDPDispatcher interface { @@ -194,7 +236,20 @@ type NDPDispatcher interface { // already known DNS servers. If called with known DNS servers, their // valid lifetimes must be refreshed to lifetime (it may be increased, // decreased, or completely invalidated when lifetime = 0). + // + // This function is not permitted to block indefinitely. It must not + // call functions on the stack itself. OnRecursiveDNSServerOption(nicID tcpip.NICID, addrs []tcpip.Address, lifetime time.Duration) + + // OnDHCPv6Configuration will be called with an updated configuration that is + // available via DHCPv6 for a specified NIC. + // + // NDPDispatcher assumes that the initial configuration available by DHCPv6 is + // DHCPv6NoConfiguration. + // + // This function is not permitted to block indefinitely. It must not + // call functions on the stack itself. + OnDHCPv6Configuration(tcpip.NICID, DHCPv6ConfigurationFromNDPRA) } // NDPConfigurations is the NDP configurations for the netstack. @@ -208,9 +263,24 @@ type NDPConfigurations struct { // The amount of time to wait between sending Neighbor solicitation // messages. // - // Must be greater than 0.5s. + // Must be greater than or equal to 1ms. RetransmitTimer time.Duration + // The number of Router Solicitation messages to send when the NIC + // becomes enabled. + MaxRtrSolicitations uint8 + + // The amount of time between transmitting Router Solicitation messages. + // + // Must be greater than or equal to 0.5s. + RtrSolicitationInterval time.Duration + + // The maximum amount of time before transmitting the first Router + // Solicitation message. + // + // Must be greater than or equal to 0s. + MaxRtrSolicitationDelay time.Duration + // HandleRAs determines whether or not Router Advertisements will be // processed. HandleRAs bool @@ -241,12 +311,15 @@ type NDPConfigurations struct { // default values. func DefaultNDPConfigurations() NDPConfigurations { return NDPConfigurations{ - DupAddrDetectTransmits: defaultDupAddrDetectTransmits, - RetransmitTimer: defaultRetransmitTimer, - HandleRAs: defaultHandleRAs, - DiscoverDefaultRouters: defaultDiscoverDefaultRouters, - DiscoverOnLinkPrefixes: defaultDiscoverOnLinkPrefixes, - AutoGenGlobalAddresses: defaultAutoGenGlobalAddresses, + DupAddrDetectTransmits: defaultDupAddrDetectTransmits, + RetransmitTimer: defaultRetransmitTimer, + MaxRtrSolicitations: defaultMaxRtrSolicitations, + RtrSolicitationInterval: defaultRtrSolicitationInterval, + MaxRtrSolicitationDelay: defaultMaxRtrSolicitationDelay, + HandleRAs: defaultHandleRAs, + DiscoverDefaultRouters: defaultDiscoverDefaultRouters, + DiscoverOnLinkPrefixes: defaultDiscoverOnLinkPrefixes, + AutoGenGlobalAddresses: defaultAutoGenGlobalAddresses, } } @@ -255,10 +328,24 @@ func DefaultNDPConfigurations() NDPConfigurations { // // If RetransmitTimer is less than minimumRetransmitTimer, then a value of // defaultRetransmitTimer will be used. +// +// If RtrSolicitationInterval is less than minimumRtrSolicitationInterval, then +// a value of defaultRtrSolicitationInterval will be used. +// +// If MaxRtrSolicitationDelay is less than minimumMaxRtrSolicitationDelay, then +// a value of defaultMaxRtrSolicitationDelay will be used. func (c *NDPConfigurations) validate() { if c.RetransmitTimer < minimumRetransmitTimer { c.RetransmitTimer = defaultRetransmitTimer } + + if c.RtrSolicitationInterval < minimumRtrSolicitationInterval { + c.RtrSolicitationInterval = defaultRtrSolicitationInterval + } + + if c.MaxRtrSolicitationDelay < minimumMaxRtrSolicitationDelay { + c.MaxRtrSolicitationDelay = defaultMaxRtrSolicitationDelay + } } // ndpState is the per-interface NDP state. @@ -279,8 +366,15 @@ type ndpState struct { // Information option. onLinkPrefixes map[tcpip.Subnet]onLinkPrefixState + // The timer used to send the next router solicitation message. + // If routers are being solicited, rtrSolicitTimer MUST NOT be nil. + rtrSolicitTimer *time.Timer + // The addresses generated by SLAAC. autoGenAddresses map[tcpip.Address]autoGenAddressState + + // The last learned DHCPv6 configuration from an NDP RA. + dhcpv6Configuration DHCPv6ConfigurationFromNDPRA } // dadState holds the Duplicate Address Detection timer and channel to signal @@ -461,10 +555,12 @@ func (ndp *ndpState) doDuplicateAddressDetection(addr tcpip.Address, remaining u // address. panic(fmt.Sprintf("ndpdad: NIC(%d) is not in the solicited-node multicast group (%s) but it has addr %s", ndp.nic.ID(), snmc, addr)) } + snmcRef.incRef() // Use the unspecified address as the source address when performing // DAD. r := makeRoute(header.IPv6ProtocolNumber, header.IPv6Any, snmc, ndp.nic.linkEP.LinkAddress(), snmcRef, false, false) + defer r.Release() hdr := buffer.NewPrependable(int(r.MaxHeaderLength()) + header.ICMPv6NeighborSolicitMinimumSize) pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6NeighborSolicitMinimumSize)) @@ -533,6 +629,28 @@ func (ndp *ndpState) handleRA(ip tcpip.Address, ra header.NDPRouterAdvert) { return } + // Only worry about the DHCPv6 configuration if we have an NDPDispatcher as we + // only inform the dispatcher on configuration changes. We do nothing else + // with the information. + if ndpDisp := ndp.nic.stack.ndpDisp; ndpDisp != nil { + var configuration DHCPv6ConfigurationFromNDPRA + switch { + case ra.ManagedAddrConfFlag(): + configuration = DHCPv6ManagedAddress + + case ra.OtherConfFlag(): + configuration = DHCPv6OtherConfigurations + + default: + configuration = DHCPv6NoConfiguration + } + + if ndp.dhcpv6Configuration != configuration { + ndp.dhcpv6Configuration = configuration + ndpDisp.OnDHCPv6Configuration(ndp.nic.ID(), configuration) + } + } + // Is the NIC configured to discover default routers? if ndp.configs.DiscoverDefaultRouters { rtr, ok := ndp.defaultRouters[ip] @@ -876,7 +994,7 @@ func (ndp *ndpState) newAutoGenAddress(prefix tcpip.Subnet, pl, vl time.Duration // If the preferred lifetime is zero, then the address should be considered // deprecated. deprecated := pl == 0 - ref, err := ndp.nic.addAddressLocked(protocolAddr, FirstPrimaryEndpoint, permanent, slaac, deprecated) + ref, err := ndp.nic.addPermanentAddressLocked(protocolAddr, FirstPrimaryEndpoint, slaac, deprecated) if err != nil { log.Fatalf("ndp: error when adding address %s: %s", protocolAddr, err) } @@ -1070,3 +1188,84 @@ func (ndp *ndpState) cleanupHostOnlyState() { log.Fatalf("ndp: still have discovered default routers after cleaning up, found = %d", got) } } + +// startSolicitingRouters starts soliciting routers, as per RFC 4861 section +// 6.3.7. If routers are already being solicited, this function does nothing. +// +// The NIC ndp belongs to MUST be locked. +func (ndp *ndpState) startSolicitingRouters() { + if ndp.rtrSolicitTimer != nil { + // We are already soliciting routers. + return + } + + remaining := ndp.configs.MaxRtrSolicitations + if remaining == 0 { + return + } + + // Calculate the random delay before sending our first RS, as per RFC + // 4861 section 6.3.7. + var delay time.Duration + if ndp.configs.MaxRtrSolicitationDelay > 0 { + delay = time.Duration(rand.Int63n(int64(ndp.configs.MaxRtrSolicitationDelay))) + } + + ndp.rtrSolicitTimer = time.AfterFunc(delay, func() { + // Send an RS message with the unspecified source address. + ref := ndp.nic.getRefOrCreateTemp(header.IPv6ProtocolNumber, header.IPv6Any, NeverPrimaryEndpoint, true) + r := makeRoute(header.IPv6ProtocolNumber, header.IPv6Any, header.IPv6AllRoutersMulticastAddress, ndp.nic.linkEP.LinkAddress(), ref, false, false) + defer r.Release() + + payloadSize := header.ICMPv6HeaderSize + header.NDPRSMinimumSize + hdr := buffer.NewPrependable(header.IPv6MinimumSize + payloadSize) + pkt := header.ICMPv6(hdr.Prepend(payloadSize)) + pkt.SetType(header.ICMPv6RouterSolicit) + pkt.SetChecksum(header.ICMPv6Checksum(pkt, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{})) + + sent := r.Stats().ICMP.V6PacketsSent + if err := r.WritePacket(nil, + NetworkHeaderParams{ + Protocol: header.ICMPv6ProtocolNumber, + TTL: header.NDPHopLimit, + TOS: DefaultTOS, + }, tcpip.PacketBuffer{Header: hdr}, + ); err != nil { + sent.Dropped.Increment() + log.Printf("startSolicitingRouters: error writing NDP router solicit message on NIC(%d); err = %s", ndp.nic.ID(), err) + // Don't send any more messages if we had an error. + remaining = 0 + } else { + sent.RouterSolicit.Increment() + remaining-- + } + + ndp.nic.mu.Lock() + defer ndp.nic.mu.Unlock() + if remaining == 0 { + ndp.rtrSolicitTimer = nil + } else if ndp.rtrSolicitTimer != nil { + // Note, we need to explicitly check to make sure that + // the timer field is not nil because if it was nil but + // we still reached this point, then we know the NIC + // was requested to stop soliciting routers so we don't + // need to send the next Router Solicitation message. + ndp.rtrSolicitTimer.Reset(ndp.configs.RtrSolicitationInterval) + } + }) + +} + +// stopSolicitingRouters stops soliciting routers. If routers are not currently +// being solicited, this function does nothing. +// +// The NIC ndp belongs to MUST be locked. +func (ndp *ndpState) stopSolicitingRouters() { + if ndp.rtrSolicitTimer == nil { + // Nothing to do. + return + } + + ndp.rtrSolicitTimer.Stop() + ndp.rtrSolicitTimer = nil +} diff --git a/pkg/tcpip/stack/ndp_test.go b/pkg/tcpip/stack/ndp_test.go index d334af289..1a52e0e68 100644 --- a/pkg/tcpip/stack/ndp_test.go +++ b/pkg/tcpip/stack/ndp_test.go @@ -35,12 +35,12 @@ import ( ) const ( - addr1 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01" - addr2 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02" - addr3 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x03" - linkAddr1 = "\x02\x02\x03\x04\x05\x06" - linkAddr2 = "\x02\x02\x03\x04\x05\x07" - linkAddr3 = "\x02\x02\x03\x04\x05\x08" + addr1 = tcpip.Address("\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01") + addr2 = tcpip.Address("\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02") + addr3 = tcpip.Address("\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x03") + linkAddr1 = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x06") + linkAddr2 = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x07") + linkAddr3 = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x08") defaultTimeout = 100 * time.Millisecond ) @@ -162,18 +162,24 @@ type ndpRDNSSEvent struct { rdnss ndpRDNSS } +type ndpDHCPv6Event struct { + nicID tcpip.NICID + configuration stack.DHCPv6ConfigurationFromNDPRA +} + var _ stack.NDPDispatcher = (*ndpDispatcher)(nil) // ndpDispatcher implements NDPDispatcher so tests can know when various NDP // related events happen for test purposes. type ndpDispatcher struct { - dadC chan ndpDADEvent - routerC chan ndpRouterEvent - rememberRouter bool - prefixC chan ndpPrefixEvent - rememberPrefix bool - autoGenAddrC chan ndpAutoGenAddrEvent - rdnssC chan ndpRDNSSEvent + dadC chan ndpDADEvent + routerC chan ndpRouterEvent + rememberRouter bool + prefixC chan ndpPrefixEvent + rememberPrefix bool + autoGenAddrC chan ndpAutoGenAddrEvent + rdnssC chan ndpRDNSSEvent + dhcpv6ConfigurationC chan ndpDHCPv6Event } // Implements stack.NDPDispatcher.OnDuplicateAddressDetectionStatus. @@ -280,6 +286,16 @@ func (n *ndpDispatcher) OnRecursiveDNSServerOption(nicID tcpip.NICID, addrs []tc } } +// Implements stack.NDPDispatcher.OnDHCPv6Configuration. +func (n *ndpDispatcher) OnDHCPv6Configuration(nicID tcpip.NICID, configuration stack.DHCPv6ConfigurationFromNDPRA) { + if c := n.dhcpv6ConfigurationC; c != nil { + c <- ndpDHCPv6Event{ + nicID, + configuration, + } + } +} + // TestDADResolve tests that an address successfully resolves after performing // DAD for various values of DupAddrDetectTransmits and RetransmitTimer. // Included in the subtests is a test to make sure that an invalid @@ -797,21 +813,32 @@ func TestSetNDPConfigurations(t *testing.T) { } } -// raBufWithOpts returns a valid NDP Router Advertisement with options. -// -// Note, raBufWithOpts does not populate any of the RA fields other than the -// Router Lifetime. -func raBufWithOpts(ip tcpip.Address, rl uint16, optSer header.NDPOptionsSerializer) tcpip.PacketBuffer { +// raBufWithOptsAndDHCPv6 returns a valid NDP Router Advertisement with options +// and DHCPv6 configurations specified. +func raBufWithOptsAndDHCPv6(ip tcpip.Address, rl uint16, managedAddress, otherConfigurations bool, optSer header.NDPOptionsSerializer) tcpip.PacketBuffer { icmpSize := header.ICMPv6HeaderSize + header.NDPRAMinimumSize + int(optSer.Length()) hdr := buffer.NewPrependable(header.IPv6MinimumSize + icmpSize) pkt := header.ICMPv6(hdr.Prepend(icmpSize)) pkt.SetType(header.ICMPv6RouterAdvert) pkt.SetCode(0) - ra := header.NDPRouterAdvert(pkt.NDPPayload()) + raPayload := pkt.NDPPayload() + ra := header.NDPRouterAdvert(raPayload) + // Populate the Router Lifetime. + binary.BigEndian.PutUint16(raPayload[2:], rl) + // Populate the Managed Address flag field. + if managedAddress { + // The Managed Addresses flag field is the 7th bit of byte #1 (0-indexing) + // of the RA payload. + raPayload[1] |= (1 << 7) + } + // Populate the Other Configurations flag field. + if otherConfigurations { + // The Other Configurations flag field is the 6th bit of byte #1 + // (0-indexing) of the RA payload. + raPayload[1] |= (1 << 6) + } opts := ra.Options() opts.Serialize(optSer) - // Populate the Router Lifetime. - binary.BigEndian.PutUint16(pkt.NDPPayload()[2:], rl) pkt.SetChecksum(header.ICMPv6Checksum(pkt, ip, header.IPv6AllNodesMulticastAddress, buffer.VectorisedView{})) payloadLength := hdr.UsedLength() iph := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) @@ -826,6 +853,23 @@ func raBufWithOpts(ip tcpip.Address, rl uint16, optSer header.NDPOptionsSerializ return tcpip.PacketBuffer{Data: hdr.View().ToVectorisedView()} } +// raBufWithOpts returns a valid NDP Router Advertisement with options. +// +// Note, raBufWithOpts does not populate any of the RA fields other than the +// Router Lifetime. +func raBufWithOpts(ip tcpip.Address, rl uint16, optSer header.NDPOptionsSerializer) tcpip.PacketBuffer { + return raBufWithOptsAndDHCPv6(ip, rl, false, false, optSer) +} + +// raBufWithDHCPv6 returns a valid NDP Router Advertisement with DHCPv6 related +// fields set. +// +// Note, raBufWithDHCPv6 does not populate any of the RA fields other than the +// DHCPv6 related ones. +func raBufWithDHCPv6(ip tcpip.Address, managedAddresses, otherConfiguratiosns bool) tcpip.PacketBuffer { + return raBufWithOptsAndDHCPv6(ip, 0, managedAddresses, otherConfiguratiosns, header.NDPOptionsSerializer{}) +} + // raBuf returns a valid NDP Router Advertisement. // // Note, raBuf does not populate any of the RA fields other than the @@ -1688,9 +1732,11 @@ func stackAndNdpDispatcherWithDefaultRoute(t *testing.T, nicID tcpip.NICID) (*nd return ndpDisp, e, s } -// addrForNewConnection returns the local address used when creating a new -// connection. -func addrForNewConnection(t *testing.T, s *stack.Stack) tcpip.Address { +// addrForNewConnectionTo returns the local address used when creating a new +// connection to addr. +func addrForNewConnectionTo(t *testing.T, s *stack.Stack, addr tcpip.FullAddress) tcpip.Address { + t.Helper() + wq := waiter.Queue{} we, ch := waiter.NewChannelEntry(nil) wq.EventRegister(&we, waiter.EventIn) @@ -1704,8 +1750,8 @@ func addrForNewConnection(t *testing.T, s *stack.Stack) tcpip.Address { if err := ep.SetSockOptBool(tcpip.V6OnlyOption, true); err != nil { t.Fatalf("SetSockOpt(tcpip.V6OnlyOption, true): %s", err) } - if err := ep.Connect(dstAddr); err != nil { - t.Fatalf("ep.Connect(%+v): %s", dstAddr, err) + if err := ep.Connect(addr); err != nil { + t.Fatalf("ep.Connect(%+v): %s", addr, err) } got, err := ep.GetLocalAddress() if err != nil { @@ -1714,9 +1760,19 @@ func addrForNewConnection(t *testing.T, s *stack.Stack) tcpip.Address { return got.Addr } +// addrForNewConnection returns the local address used when creating a new +// connection. +func addrForNewConnection(t *testing.T, s *stack.Stack) tcpip.Address { + t.Helper() + + return addrForNewConnectionTo(t, s, dstAddr) +} + // addrForNewConnectionWithAddr returns the local address used when creating a // new connection with a specific local address. func addrForNewConnectionWithAddr(t *testing.T, s *stack.Stack, addr tcpip.FullAddress) tcpip.Address { + t.Helper() + wq := waiter.Queue{} we, ch := waiter.NewChannelEntry(nil) wq.EventRegister(&we, waiter.EventIn) @@ -2389,6 +2445,119 @@ func TestAutoGenAddrRemoval(t *testing.T) { } } +// TestAutoGenAddrAfterRemoval tests adding a SLAAC address that was previously +// assigned to the NIC but is in the permanentExpired state. +func TestAutoGenAddrAfterRemoval(t *testing.T) { + t.Parallel() + + const nicID = 1 + + prefix1, _, addr1 := prefixSubnetAddr(0, linkAddr1) + prefix2, _, addr2 := prefixSubnetAddr(1, linkAddr1) + ndpDisp, e, s := stackAndNdpDispatcherWithDefaultRoute(t, nicID) + + expectAutoGenAddrEvent := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) { + t.Helper() + + select { + case e := <-ndpDisp.autoGenAddrC: + if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" { + t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) + } + default: + t.Fatal("expected addr auto gen event") + } + } + + expectPrimaryAddr := func(addr tcpip.AddressWithPrefix) { + t.Helper() + + if got, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber); err != nil { + t.Fatalf("s.GetMainNICAddress(%d, %d): %s", nicID, header.IPv6ProtocolNumber, err) + } else if got != addr { + t.Errorf("got s.GetMainNICAddress(%d, %d) = %s, want = %s", nicID, header.IPv6ProtocolNumber, got, addr) + } + + if got := addrForNewConnection(t, s); got != addr.Address { + t.Errorf("got addrForNewConnection = %s, want = %s", got, addr.Address) + } + } + + // Receive a PI to auto-generate addr1 with a large valid and preferred + // lifetime. + const largeLifetimeSeconds = 999 + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr3, 0, prefix1, true, true, largeLifetimeSeconds, largeLifetimeSeconds)) + expectAutoGenAddrEvent(addr1, newAddr) + expectPrimaryAddr(addr1) + + // Add addr2 as a static address. + protoAddr2 := tcpip.ProtocolAddress{ + Protocol: header.IPv6ProtocolNumber, + AddressWithPrefix: addr2, + } + if err := s.AddProtocolAddressWithOptions(nicID, protoAddr2, stack.FirstPrimaryEndpoint); err != nil { + t.Fatalf("AddProtocolAddressWithOptions(%d, %+v, %d, %s) = %s", nicID, protoAddr2, stack.FirstPrimaryEndpoint, err) + } + // addr2 should be more preferred now since it is at the front of the primary + // list. + expectPrimaryAddr(addr2) + + // Get a route using addr2 to increment its reference count then remove it + // to leave it in the permanentExpired state. + r, err := s.FindRoute(nicID, addr2.Address, addr3, header.IPv6ProtocolNumber, false) + if err != nil { + t.Fatalf("FindRoute(%d, %s, %s, %d, false): %s", nicID, addr2.Address, addr3, header.IPv6ProtocolNumber, err) + } + defer r.Release() + if err := s.RemoveAddress(nicID, addr2.Address); err != nil { + t.Fatalf("s.RemoveAddress(%d, %s): %s", nicID, addr2.Address, err) + } + // addr1 should be preferred again since addr2 is in the expired state. + expectPrimaryAddr(addr1) + + // Receive a PI to auto-generate addr2 as valid and preferred. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr3, 0, prefix2, true, true, largeLifetimeSeconds, largeLifetimeSeconds)) + expectAutoGenAddrEvent(addr2, newAddr) + // addr2 should be more preferred now that it is closer to the front of the + // primary list and not deprecated. + expectPrimaryAddr(addr2) + + // Removing the address should result in an invalidation event immediately. + // It should still be in the permanentExpired state because r is still held. + // + // We remove addr2 here to make sure addr2 was marked as a SLAAC address + // (it was previously marked as a static address). + if err := s.RemoveAddress(1, addr2.Address); err != nil { + t.Fatalf("RemoveAddress(_, %s) = %s", addr2.Address, err) + } + expectAutoGenAddrEvent(addr2, invalidatedAddr) + // addr1 should be more preferred since addr2 is in the expired state. + expectPrimaryAddr(addr1) + + // Receive a PI to auto-generate addr2 as valid and deprecated. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr3, 0, prefix2, true, true, largeLifetimeSeconds, 0)) + expectAutoGenAddrEvent(addr2, newAddr) + // addr1 should still be more preferred since addr2 is deprecated, even though + // it is closer to the front of the primary list. + expectPrimaryAddr(addr1) + + // Receive a PI to refresh addr2's preferred lifetime. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr3, 0, prefix2, true, true, largeLifetimeSeconds, largeLifetimeSeconds)) + select { + case <-ndpDisp.autoGenAddrC: + t.Fatal("unexpectedly got an auto gen addr event") + default: + } + // addr2 should be more preferred now that it is not deprecated. + expectPrimaryAddr(addr2) + + if err := s.RemoveAddress(1, addr2.Address); err != nil { + t.Fatalf("RemoveAddress(_, %s) = %s", addr2.Address, err) + } + expectAutoGenAddrEvent(addr2, invalidatedAddr) + expectPrimaryAddr(addr1) +} + // TestAutoGenAddrStaticConflict tests that if SLAAC generates an address that // is already assigned to the NIC, the static address remains. func TestAutoGenAddrStaticConflict(t *testing.T) { @@ -2951,3 +3120,318 @@ func TestCleanupHostOnlyStateOnBecomingRouter(t *testing.T) { default: } } + +// TestDHCPv6ConfigurationFromNDPDA tests that the NDPDispatcher is properly +// informed when new information about what configurations are available via +// DHCPv6 is learned. +func TestDHCPv6ConfigurationFromNDPDA(t *testing.T) { + const nicID = 1 + + ndpDisp := ndpDispatcher{ + dhcpv6ConfigurationC: make(chan ndpDHCPv6Event, 1), + rememberRouter: true, + } + e := channel.New(0, 1280, linkAddr1) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NDPConfigs: stack.NDPConfigurations{ + HandleRAs: true, + }, + NDPDisp: &ndpDisp, + }) + + if err := s.CreateNIC(nicID, e); err != nil { + t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) + } + + expectDHCPv6Event := func(configuration stack.DHCPv6ConfigurationFromNDPRA) { + t.Helper() + select { + case e := <-ndpDisp.dhcpv6ConfigurationC: + if diff := cmp.Diff(ndpDHCPv6Event{nicID: nicID, configuration: configuration}, e, cmp.AllowUnexported(e)); diff != "" { + t.Errorf("dhcpv6 event mismatch (-want +got):\n%s", diff) + } + default: + t.Fatal("expected DHCPv6 configuration event") + } + } + + expectNoDHCPv6Event := func() { + t.Helper() + select { + case <-ndpDisp.dhcpv6ConfigurationC: + t.Fatal("unexpected DHCPv6 configuration event") + default: + } + } + + // The initial DHCPv6 configuration should be stack.DHCPv6NoConfiguration. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, false, false)) + expectNoDHCPv6Event() + + // Receive an RA that updates the DHCPv6 configuration to Other + // Configurations. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, false, true)) + expectDHCPv6Event(stack.DHCPv6OtherConfigurations) + // Receiving the same update again should not result in an event to the + // NDPDispatcher. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, false, true)) + expectNoDHCPv6Event() + + // Receive an RA that updates the DHCPv6 configuration to Managed Address. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, true, false)) + expectDHCPv6Event(stack.DHCPv6ManagedAddress) + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, true, false)) + expectNoDHCPv6Event() + + // Receive an RA that updates the DHCPv6 configuration to none. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, false, false)) + expectDHCPv6Event(stack.DHCPv6NoConfiguration) + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, false, false)) + expectNoDHCPv6Event() + + // Receive an RA that updates the DHCPv6 configuration to Managed Address. + // + // Note, when the M flag is set, the O flag is redundant. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, true, true)) + expectDHCPv6Event(stack.DHCPv6ManagedAddress) + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, true, true)) + expectNoDHCPv6Event() + // Even though the DHCPv6 flags are different, the effective configuration is + // the same so we should not receive a new event. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, true, false)) + expectNoDHCPv6Event() + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, true, true)) + expectNoDHCPv6Event() + + // Receive an RA that updates the DHCPv6 configuration to Other + // Configurations. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, false, true)) + expectDHCPv6Event(stack.DHCPv6OtherConfigurations) + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, false, true)) + expectNoDHCPv6Event() +} + +// TestRouterSolicitation tests the initial Router Solicitations that are sent +// when a NIC newly becomes enabled. +func TestRouterSolicitation(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + maxRtrSolicit uint8 + rtrSolicitInt time.Duration + effectiveRtrSolicitInt time.Duration + maxRtrSolicitDelay time.Duration + effectiveMaxRtrSolicitDelay time.Duration + }{ + { + name: "Single RS with delay", + maxRtrSolicit: 1, + rtrSolicitInt: time.Second, + effectiveRtrSolicitInt: time.Second, + maxRtrSolicitDelay: time.Second, + effectiveMaxRtrSolicitDelay: time.Second, + }, + { + name: "Two RS with delay", + maxRtrSolicit: 2, + rtrSolicitInt: time.Second, + effectiveRtrSolicitInt: time.Second, + maxRtrSolicitDelay: 500 * time.Millisecond, + effectiveMaxRtrSolicitDelay: 500 * time.Millisecond, + }, + { + name: "Single RS without delay", + maxRtrSolicit: 1, + rtrSolicitInt: time.Second, + effectiveRtrSolicitInt: time.Second, + maxRtrSolicitDelay: 0, + effectiveMaxRtrSolicitDelay: 0, + }, + { + name: "Two RS without delay and invalid zero interval", + maxRtrSolicit: 2, + rtrSolicitInt: 0, + effectiveRtrSolicitInt: 4 * time.Second, + maxRtrSolicitDelay: 0, + effectiveMaxRtrSolicitDelay: 0, + }, + { + name: "Three RS without delay", + maxRtrSolicit: 3, + rtrSolicitInt: 500 * time.Millisecond, + effectiveRtrSolicitInt: 500 * time.Millisecond, + maxRtrSolicitDelay: 0, + effectiveMaxRtrSolicitDelay: 0, + }, + { + name: "Two RS with invalid negative delay", + maxRtrSolicit: 2, + rtrSolicitInt: time.Second, + effectiveRtrSolicitInt: time.Second, + maxRtrSolicitDelay: -3 * time.Second, + effectiveMaxRtrSolicitDelay: time.Second, + }, + } + + // This Run will not return until the parallel tests finish. + // + // We need this because we need to do some teardown work after the + // parallel tests complete. + // + // See https://godoc.org/testing#hdr-Subtests_and_Sub_benchmarks for + // more details. + t.Run("group", func(t *testing.T) { + for _, test := range tests { + test := test + + t.Run(test.name, func(t *testing.T) { + t.Parallel() + e := channel.New(int(test.maxRtrSolicit), 1280, linkAddr1) + waitForPkt := func(timeout time.Duration) { + t.Helper() + select { + case p := <-e.C: + if p.Proto != header.IPv6ProtocolNumber { + t.Fatalf("got Proto = %d, want = %d", p.Proto, header.IPv6ProtocolNumber) + } + checker.IPv6(t, + p.Pkt.Header.View(), + checker.SrcAddr(header.IPv6Any), + checker.DstAddr(header.IPv6AllRoutersMulticastAddress), + checker.TTL(header.NDPHopLimit), + checker.NDPRS(), + ) + + case <-time.After(timeout): + t.Fatal("timed out waiting for packet") + } + } + waitForNothing := func(timeout time.Duration) { + t.Helper() + select { + case <-e.C: + t.Fatal("unexpectedly got a packet") + case <-time.After(timeout): + } + } + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NDPConfigs: stack.NDPConfigurations{ + MaxRtrSolicitations: test.maxRtrSolicit, + RtrSolicitationInterval: test.rtrSolicitInt, + MaxRtrSolicitationDelay: test.maxRtrSolicitDelay, + }, + }) + if err := s.CreateNIC(1, e); err != nil { + t.Fatalf("CreateNIC(1) = %s", err) + } + + // Make sure each RS got sent at the right + // times. + remaining := test.maxRtrSolicit + if remaining > 0 { + waitForPkt(test.effectiveMaxRtrSolicitDelay + defaultTimeout) + remaining-- + } + for ; remaining > 0; remaining-- { + waitForNothing(test.effectiveRtrSolicitInt - defaultTimeout) + waitForPkt(2 * defaultTimeout) + } + + // Make sure no more RS. + if test.effectiveRtrSolicitInt > test.effectiveMaxRtrSolicitDelay { + waitForNothing(test.effectiveRtrSolicitInt + defaultTimeout) + } else { + waitForNothing(test.effectiveMaxRtrSolicitDelay + defaultTimeout) + } + + // Make sure the counter got properly + // incremented. + if got, want := s.Stats().ICMP.V6PacketsSent.RouterSolicit.Value(), uint64(test.maxRtrSolicit); got != want { + t.Fatalf("got sent RouterSolicit = %d, want = %d", got, want) + } + }) + } + }) +} + +// TestStopStartSolicitingRouters tests that when forwarding is enabled or +// disabled, router solicitations are stopped or started, respecitively. +func TestStopStartSolicitingRouters(t *testing.T) { + t.Parallel() + + const interval = 500 * time.Millisecond + const delay = time.Second + const maxRtrSolicitations = 3 + e := channel.New(maxRtrSolicitations, 1280, linkAddr1) + waitForPkt := func(timeout time.Duration) { + t.Helper() + select { + case p := <-e.C: + if p.Proto != header.IPv6ProtocolNumber { + t.Fatalf("got Proto = %d, want = %d", p.Proto, header.IPv6ProtocolNumber) + } + checker.IPv6(t, p.Pkt.Header.View(), + checker.SrcAddr(header.IPv6Any), + checker.DstAddr(header.IPv6AllRoutersMulticastAddress), + checker.TTL(header.NDPHopLimit), + checker.NDPRS()) + + case <-time.After(timeout): + t.Fatal("timed out waiting for packet") + } + } + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NDPConfigs: stack.NDPConfigurations{ + MaxRtrSolicitations: maxRtrSolicitations, + RtrSolicitationInterval: interval, + MaxRtrSolicitationDelay: delay, + }, + }) + if err := s.CreateNIC(1, e); err != nil { + t.Fatalf("CreateNIC(1) = %s", err) + } + + // Enable forwarding which should stop router solicitations. + s.SetForwarding(true) + select { + case <-e.C: + // A single RS may have been sent before forwarding was enabled. + select { + case <-e.C: + t.Fatal("Should not have sent more than one RS message") + case <-time.After(interval + defaultTimeout): + } + case <-time.After(delay + defaultTimeout): + } + + // Enabling forwarding again should do nothing. + s.SetForwarding(true) + select { + case <-e.C: + t.Fatal("unexpectedly got a packet after becoming a router") + case <-time.After(delay + defaultTimeout): + } + + // Disable forwarding which should start router solicitations. + s.SetForwarding(false) + waitForPkt(delay + defaultTimeout) + waitForPkt(interval + defaultTimeout) + waitForPkt(interval + defaultTimeout) + select { + case <-e.C: + t.Fatal("unexpectedly got an extra packet after sending out the expected RSs") + case <-time.After(interval + defaultTimeout): + } + + // Disabling forwarding again should do nothing. + s.SetForwarding(false) + select { + case <-e.C: + t.Fatal("unexpectedly got a packet after becoming a router") + case <-time.After(delay + defaultTimeout): + } +} diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go index 4144d5d0f..4452a1302 100644 --- a/pkg/tcpip/stack/nic.go +++ b/pkg/tcpip/stack/nic.go @@ -15,10 +15,12 @@ package stack import ( + "log" + "sort" "strings" - "sync" "sync/atomic" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" @@ -27,10 +29,11 @@ import ( // NIC represents a "network interface card" to which the networking stack is // attached. type NIC struct { - stack *Stack - id tcpip.NICID - name string - linkEP LinkEndpoint + stack *Stack + id tcpip.NICID + name string + linkEP LinkEndpoint + context NICContext mu sync.RWMutex spoofing bool @@ -84,7 +87,7 @@ const ( ) // newNIC returns a new NIC using the default NDP configurations from stack. -func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint) *NIC { +func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint, ctx NICContext) *NIC { // TODO(b/141011931): Validate a LinkEndpoint (ep) is valid. For // example, make sure that the link address it provides is a valid // unicast ethernet address. @@ -98,6 +101,7 @@ func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint) *NIC { id: id, name: name, linkEP: ep, + context: ctx, primary: make(map[tcpip.NetworkProtocolNumber][]*referencedNetworkEndpoint), endpoints: make(map[NetworkEndpointID]*referencedNetworkEndpoint), mcastJoins: make(map[NetworkEndpointID]int32), @@ -173,49 +177,72 @@ func (n *NIC) enable() *tcpip.Error { } // Do not auto-generate an IPv6 link-local address for loopback devices. - if !n.stack.autoGenIPv6LinkLocal || n.isLoopback() { - return nil - } + if n.stack.autoGenIPv6LinkLocal && !n.isLoopback() { + var addr tcpip.Address + if oIID := n.stack.opaqueIIDOpts; oIID.NICNameFromID != nil { + addr = header.LinkLocalAddrWithOpaqueIID(oIID.NICNameFromID(n.ID(), n.name), 0, oIID.SecretKey) + } else { + l2addr := n.linkEP.LinkAddress() - var addr tcpip.Address - if oIID := n.stack.opaqueIIDOpts; oIID.NICNameFromID != nil { - addr = header.LinkLocalAddrWithOpaqueIID(oIID.NICNameFromID(n.ID(), n.name), 0, oIID.SecretKey) - } else { - l2addr := n.linkEP.LinkAddress() + // Only attempt to generate the link-local address if we have a valid MAC + // address. + // + // TODO(b/141011931): Validate a LinkEndpoint's link address (provided by + // LinkEndpoint.LinkAddress) before reaching this point. + if !header.IsValidUnicastEthernetAddress(l2addr) { + return nil + } - // Only attempt to generate the link-local address if we have a valid MAC - // address. - // - // TODO(b/141011931): Validate a LinkEndpoint's link address (provided by - // LinkEndpoint.LinkAddress) before reaching this point. - if !header.IsValidUnicastEthernetAddress(l2addr) { - return nil + addr = header.LinkLocalAddr(l2addr) } - addr = header.LinkLocalAddr(l2addr) + if _, err := n.addPermanentAddressLocked(tcpip.ProtocolAddress{ + Protocol: header.IPv6ProtocolNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: addr, + PrefixLen: header.IPv6LinkLocalPrefix.PrefixLen, + }, + }, CanBePrimaryEndpoint, static, false /* deprecated */); err != nil { + return err + } } - _, err := n.addPermanentAddressLocked(tcpip.ProtocolAddress{ - Protocol: header.IPv6ProtocolNumber, - AddressWithPrefix: tcpip.AddressWithPrefix{ - Address: addr, - PrefixLen: header.IPv6LinkLocalPrefix.PrefixLen, - }, - }, CanBePrimaryEndpoint) + // If we are operating as a router, then do not solicit routers since we + // won't process the RAs anyways. + // + // Routers do not process Router Advertisements (RA) the same way a host + // does. That is, routers do not learn from RAs (e.g. on-link prefixes + // and default routers). Therefore, soliciting RAs from other routers on + // a link is unnecessary for routers. + if !n.stack.forwarding { + n.ndp.startSolicitingRouters() + } - return err + return nil } // becomeIPv6Router transitions n into an IPv6 router. // // When transitioning into an IPv6 router, host-only state (NDP discovered // routers, discovered on-link prefixes, and auto-generated addresses) will -// be cleaned up/invalidated. +// be cleaned up/invalidated and NDP router solicitations will be stopped. func (n *NIC) becomeIPv6Router() { n.mu.Lock() defer n.mu.Unlock() n.ndp.cleanupHostOnlyState() + n.ndp.stopSolicitingRouters() +} + +// becomeIPv6Host transitions n into an IPv6 host. +// +// When transitioning into an IPv6 host, NDP router solicitations will be +// started. +func (n *NIC) becomeIPv6Host() { + n.mu.Lock() + defer n.mu.Unlock() + + n.ndp.startSolicitingRouters() } // attachLinkEndpoint attaches the NIC to the endpoint, which will enable it @@ -249,13 +276,17 @@ func (n *NIC) setSpoofing(enable bool) { n.mu.Unlock() } -// primaryEndpoint returns the primary endpoint of n for the given network -// protocol. -// // primaryEndpoint will return the first non-deprecated endpoint if such an -// endpoint exists. If no non-deprecated endpoint exists, the first deprecated -// endpoint will be returned. -func (n *NIC) primaryEndpoint(protocol tcpip.NetworkProtocolNumber) *referencedNetworkEndpoint { +// endpoint exists for the given protocol and remoteAddr. If no non-deprecated +// endpoint exists, the first deprecated endpoint will be returned. +// +// If an IPv6 primary endpoint is requested, Source Address Selection (as +// defined by RFC 6724 section 5) will be performed. +func (n *NIC) primaryEndpoint(protocol tcpip.NetworkProtocolNumber, remoteAddr tcpip.Address) *referencedNetworkEndpoint { + if protocol == header.IPv6ProtocolNumber && remoteAddr != "" { + return n.primaryIPv6Endpoint(remoteAddr) + } + n.mu.RLock() defer n.mu.RUnlock() @@ -294,6 +325,103 @@ func (n *NIC) primaryEndpoint(protocol tcpip.NetworkProtocolNumber) *referencedN return deprecatedEndpoint } +// ipv6AddrCandidate is an IPv6 candidate for Source Address Selection (RFC +// 6724 section 5). +type ipv6AddrCandidate struct { + ref *referencedNetworkEndpoint + scope header.IPv6AddressScope +} + +// primaryIPv6Endpoint returns an IPv6 endpoint following Source Address +// Selection (RFC 6724 section 5). +// +// Note, only rules 1-3 are followed. +// +// remoteAddr must be a valid IPv6 address. +func (n *NIC) primaryIPv6Endpoint(remoteAddr tcpip.Address) *referencedNetworkEndpoint { + n.mu.RLock() + defer n.mu.RUnlock() + + primaryAddrs := n.primary[header.IPv6ProtocolNumber] + + if len(primaryAddrs) == 0 { + return nil + } + + // Create a candidate set of available addresses we can potentially use as a + // source address. + cs := make([]ipv6AddrCandidate, 0, len(primaryAddrs)) + for _, r := range primaryAddrs { + // If r is not valid for outgoing connections, it is not a valid endpoint. + if !r.isValidForOutgoing() { + continue + } + + addr := r.ep.ID().LocalAddress + scope, err := header.ScopeForIPv6Address(addr) + if err != nil { + // Should never happen as we got r from the primary IPv6 endpoint list and + // ScopeForIPv6Address only returns an error if addr is not an IPv6 + // address. + log.Fatalf("header.ScopeForIPv6Address(%s): %s", addr, err) + } + + cs = append(cs, ipv6AddrCandidate{ + ref: r, + scope: scope, + }) + } + + remoteScope, err := header.ScopeForIPv6Address(remoteAddr) + if err != nil { + // primaryIPv6Endpoint should never be called with an invalid IPv6 address. + log.Fatalf("header.ScopeForIPv6Address(%s): %s", remoteAddr, err) + } + + // Sort the addresses as per RFC 6724 section 5 rules 1-3. + // + // TODO(b/146021396): Implement rules 4-8 of RFC 6724 section 5. + sort.Slice(cs, func(i, j int) bool { + sa := cs[i] + sb := cs[j] + + // Prefer same address as per RFC 6724 section 5 rule 1. + if sa.ref.ep.ID().LocalAddress == remoteAddr { + return true + } + if sb.ref.ep.ID().LocalAddress == remoteAddr { + return false + } + + // Prefer appropriate scope as per RFC 6724 section 5 rule 2. + if sa.scope < sb.scope { + return sa.scope >= remoteScope + } else if sb.scope < sa.scope { + return sb.scope < remoteScope + } + + // Avoid deprecated addresses as per RFC 6724 section 5 rule 3. + if saDep, sbDep := sa.ref.deprecated, sb.ref.deprecated; saDep != sbDep { + // If sa is not deprecated, it is preferred over sb. + return sbDep + } + + // sa and sb are equal, return the endpoint that is closest to the front of + // the primary endpoint list. + return i < j + }) + + // Return the most preferred address that can have its reference count + // incremented. + for _, c := range cs { + if r := c.ref; r.tryIncRef() { + return r + } + } + + return nil +} + // hasPermanentAddrLocked returns true if n has a permanent (including currently // tentative) address, addr. func (n *NIC) hasPermanentAddrLocked(addr tcpip.Address) bool { @@ -405,7 +533,12 @@ func (n *NIC) getRefOrCreateTemp(protocol tcpip.NetworkProtocolNumber, address t return ref } -func (n *NIC) addPermanentAddressLocked(protocolAddress tcpip.ProtocolAddress, peb PrimaryEndpointBehavior) (*referencedNetworkEndpoint, *tcpip.Error) { +// addPermanentAddressLocked adds a permanent address to n. +// +// If n already has the address in a non-permanent state, +// addPermanentAddressLocked will promote it to permanent and update the +// endpoint with the properties provided. +func (n *NIC) addPermanentAddressLocked(protocolAddress tcpip.ProtocolAddress, peb PrimaryEndpointBehavior, configType networkEndpointConfigType, deprecated bool) (*referencedNetworkEndpoint, *tcpip.Error) { id := NetworkEndpointID{protocolAddress.AddressWithPrefix.Address} if ref, ok := n.endpoints[id]; ok { switch ref.getKind() { @@ -413,10 +546,14 @@ func (n *NIC) addPermanentAddressLocked(protocolAddress tcpip.ProtocolAddress, p // The NIC already have a permanent endpoint with that address. return nil, tcpip.ErrDuplicateAddress case permanentExpired, temporary: - // Promote the endpoint to become permanent and respect - // the new peb. + // Promote the endpoint to become permanent and respect the new peb, + // configType and deprecated status. if ref.tryIncRef() { + // TODO(b/147748385): Perform Duplicate Address Detection when promoting + // an IPv6 endpoint to permanent. ref.setKind(permanent) + ref.deprecated = deprecated + ref.configType = configType refs := n.primary[ref.protocol] for i, r := range refs { @@ -448,9 +585,13 @@ func (n *NIC) addPermanentAddressLocked(protocolAddress tcpip.ProtocolAddress, p } } - return n.addAddressLocked(protocolAddress, peb, permanent, static, false) + return n.addAddressLocked(protocolAddress, peb, permanent, configType, deprecated) } +// addAddressLocked adds a new protocolAddress to n. +// +// If the address is already known by n (irrespective of the state it is in), +// addAddressLocked does nothing and returns tcpip.ErrDuplicateAddress. func (n *NIC) addAddressLocked(protocolAddress tcpip.ProtocolAddress, peb PrimaryEndpointBehavior, kind networkEndpointKind, configType networkEndpointConfigType, deprecated bool) (*referencedNetworkEndpoint, *tcpip.Error) { // TODO(b/141022673): Validate IP address before adding them. @@ -525,7 +666,7 @@ func (n *NIC) addAddressLocked(protocolAddress tcpip.ProtocolAddress, peb Primar func (n *NIC) AddAddress(protocolAddress tcpip.ProtocolAddress, peb PrimaryEndpointBehavior) *tcpip.Error { // Add the endpoint. n.mu.Lock() - _, err := n.addPermanentAddressLocked(protocolAddress, peb) + _, err := n.addPermanentAddressLocked(protocolAddress, peb, static, false /* deprecated */) n.mu.Unlock() return err @@ -658,7 +799,7 @@ func (n *NIC) RemoveAddressRange(subnet tcpip.Subnet) { n.mu.Unlock() } -// Subnets returns the Subnets associated with this NIC. +// AddressRanges returns the Subnets associated with this NIC. func (n *NIC) AddressRanges() []tcpip.Subnet { n.mu.RLock() defer n.mu.RUnlock() @@ -807,7 +948,7 @@ func (n *NIC) joinGroupLocked(protocol tcpip.NetworkProtocolNumber, addr tcpip.A Address: addr, PrefixLen: netProto.DefaultPrefixLen(), }, - }, NeverPrimaryEndpoint); err != nil { + }, NeverPrimaryEndpoint, static, false /* deprecated */); err != nil { return err } } @@ -1185,7 +1326,8 @@ type referencedNetworkEndpoint struct { kind networkEndpointKind // configType is the method that was used to configure this endpoint. - // This must never change after the endpoint is added to a NIC. + // This must never change except during endpoint creation and promotion to + // permanent. configType networkEndpointConfigType // deprecated indicates whether or not the endpoint should be considered diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go index fb7ac409e..fc56a6d79 100644 --- a/pkg/tcpip/stack/stack.go +++ b/pkg/tcpip/stack/stack.go @@ -21,13 +21,13 @@ package stack import ( "encoding/binary" - "sync" "sync/atomic" "time" "golang.org/x/time/rate" "gvisor.dev/gvisor/pkg/rand" "gvisor.dev/gvisor/pkg/sleep" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" @@ -547,6 +547,49 @@ type TransportEndpointInfo struct { RegisterNICID tcpip.NICID } +// AddrNetProto unwraps the specified address if it is a V4-mapped V6 address +// and returns the network protocol number to be used to communicate with the +// specified address. It returns an error if the passed address is incompatible +// with the receiver. +func (e *TransportEndpointInfo) AddrNetProto(addr tcpip.FullAddress, v6only bool) (tcpip.FullAddress, tcpip.NetworkProtocolNumber, *tcpip.Error) { + netProto := e.NetProto + switch len(addr.Addr) { + case header.IPv4AddressSize: + netProto = header.IPv4ProtocolNumber + case header.IPv6AddressSize: + if header.IsV4MappedAddress(addr.Addr) { + netProto = header.IPv4ProtocolNumber + addr.Addr = addr.Addr[header.IPv6AddressSize-header.IPv4AddressSize:] + if addr.Addr == header.IPv4Any { + addr.Addr = "" + } + } + } + + switch len(e.ID.LocalAddress) { + case header.IPv4AddressSize: + if len(addr.Addr) == header.IPv6AddressSize { + return tcpip.FullAddress{}, 0, tcpip.ErrInvalidEndpointState + } + case header.IPv6AddressSize: + if len(addr.Addr) == header.IPv4AddressSize { + return tcpip.FullAddress{}, 0, tcpip.ErrNetworkUnreachable + } + } + + switch { + case netProto == e.NetProto: + case netProto == header.IPv4ProtocolNumber && e.NetProto == header.IPv6ProtocolNumber: + if v6only { + return tcpip.FullAddress{}, 0, tcpip.ErrNoRoute + } + default: + return tcpip.FullAddress{}, 0, tcpip.ErrInvalidEndpointState + } + + return addr, netProto, nil +} + // IsEndpointInfo is an empty method to implement the tcpip.EndpointInfo // marker interface. func (*TransportEndpointInfo) IsEndpointInfo() {} @@ -707,7 +750,9 @@ func (s *Stack) Stats() tcpip.Stats { // SetForwarding enables or disables the packet forwarding between NICs. // // When forwarding becomes enabled, any host-only state on all NICs will be -// cleaned up. +// cleaned up and if IPv6 is enabled, NDP Router Solicitations will be started. +// When forwarding becomes disabled and if IPv6 is enabled, NDP Router +// Solicitations will be stopped. func (s *Stack) SetForwarding(enable bool) { // TODO(igudger, bgeffon): Expose via /proc/sys/net/ipv4/ip_forward. s.mu.Lock() @@ -729,6 +774,10 @@ func (s *Stack) SetForwarding(enable bool) { for _, nic := range s.nics { nic.becomeIPv6Router() } + } else { + for _, nic := range s.nics { + nic.becomeIPv6Host() + } } } @@ -796,6 +845,9 @@ func (s *Stack) NewPacketEndpoint(cooked bool, netProto tcpip.NetworkProtocolNum return s.rawFactory.NewPacketEndpoint(s, cooked, netProto, waiterQueue) } +// NICContext is an opaque pointer used to store client-supplied NIC metadata. +type NICContext interface{} + // NICOptions specifies the configuration of a NIC as it is being created. // The zero value creates an enabled, unnamed NIC. type NICOptions struct { @@ -805,6 +857,12 @@ type NICOptions struct { // Disabled specifies whether to avoid calling Attach on the passed // LinkEndpoint. Disabled bool + + // Context specifies user-defined data that will be returned in stack.NICInfo + // for the NIC. Clients of this library can use it to add metadata that + // should be tracked alongside a NIC, to avoid having to keep a + // map[tcpip.NICID]metadata mirroring stack.Stack's nic map. + Context NICContext } // CreateNICWithOptions creates a NIC with the provided id, LinkEndpoint, and @@ -819,7 +877,7 @@ func (s *Stack) CreateNICWithOptions(id tcpip.NICID, ep LinkEndpoint, opts NICOp return tcpip.ErrDuplicateNICID } - n := newNIC(s, id, opts.Name, ep) + n := newNIC(s, id, opts.Name, ep, opts.Context) s.nics[id] = n if !opts.Disabled { @@ -860,7 +918,7 @@ func (s *Stack) CheckNIC(id tcpip.NICID) bool { return false } -// NICSubnets returns a map of NICIDs to their associated subnets. +// NICAddressRanges returns a map of NICIDs to their associated subnets. func (s *Stack) NICAddressRanges() map[tcpip.NICID][]tcpip.Subnet { s.mu.RLock() defer s.mu.RUnlock() @@ -886,6 +944,18 @@ type NICInfo struct { MTU uint32 Stats NICStats + + // Context is user-supplied data optionally supplied in CreateNICWithOptions. + // See type NICOptions for more details. + Context NICContext +} + +// HasNIC returns true if the NICID is defined in the stack. +func (s *Stack) HasNIC(id tcpip.NICID) bool { + s.mu.RLock() + _, ok := s.nics[id] + s.mu.RUnlock() + return ok } // NICInfo returns a map of NICIDs to their associated information. @@ -908,6 +978,7 @@ func (s *Stack) NICInfo() map[tcpip.NICID]NICInfo { Flags: flags, MTU: nic.linkEP.MTU(), Stats: nic.stats, + Context: nic.context, } } return nics @@ -1041,9 +1112,9 @@ func (s *Stack) GetMainNICAddress(id tcpip.NICID, protocol tcpip.NetworkProtocol return nic.primaryAddress(protocol), nil } -func (s *Stack) getRefEP(nic *NIC, localAddr tcpip.Address, netProto tcpip.NetworkProtocolNumber) (ref *referencedNetworkEndpoint) { +func (s *Stack) getRefEP(nic *NIC, localAddr, remoteAddr tcpip.Address, netProto tcpip.NetworkProtocolNumber) (ref *referencedNetworkEndpoint) { if len(localAddr) == 0 { - return nic.primaryEndpoint(netProto) + return nic.primaryEndpoint(netProto, remoteAddr) } return nic.findEndpoint(netProto, localAddr, CanBePrimaryEndpoint) } @@ -1059,7 +1130,7 @@ func (s *Stack) FindRoute(id tcpip.NICID, localAddr, remoteAddr tcpip.Address, n needRoute := !(isBroadcast || isMulticast || header.IsV6LinkLocalAddress(remoteAddr)) if id != 0 && !needRoute { if nic, ok := s.nics[id]; ok { - if ref := s.getRefEP(nic, localAddr, netProto); ref != nil { + if ref := s.getRefEP(nic, localAddr, remoteAddr, netProto); ref != nil { return makeRoute(netProto, ref.ep.ID().LocalAddress, remoteAddr, nic.linkEP.LinkAddress(), ref, s.handleLocal && !nic.isLoopback(), multicastLoop && !nic.isLoopback()), nil } } @@ -1069,7 +1140,7 @@ func (s *Stack) FindRoute(id tcpip.NICID, localAddr, remoteAddr tcpip.Address, n continue } if nic, ok := s.nics[route.NIC]; ok { - if ref := s.getRefEP(nic, localAddr, netProto); ref != nil { + if ref := s.getRefEP(nic, localAddr, remoteAddr, netProto); ref != nil { if len(remoteAddr) == 0 { // If no remote address was provided, then the route // provided will refer to the link local address. diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go index 9ac50bb23..4b3d18f1b 100644 --- a/pkg/tcpip/stack/stack_test.go +++ b/pkg/tcpip/stack/stack_test.go @@ -35,6 +35,7 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/link/loopback" "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" "gvisor.dev/gvisor/pkg/tcpip/stack" + "gvisor.dev/gvisor/pkg/tcpip/transport/udp" ) const ( @@ -2001,6 +2002,46 @@ func TestNICAutoGenAddr(t *testing.T) { } } +// TestNICContextPreservation tests that you can read out via stack.NICInfo the +// Context data you pass via NICContext.Context in stack.CreateNICWithOptions. +func TestNICContextPreservation(t *testing.T) { + var ctx *int + tests := []struct { + name string + opts stack.NICOptions + want stack.NICContext + }{ + { + "context_set", + stack.NICOptions{Context: ctx}, + ctx, + }, + { + "context_not_set", + stack.NICOptions{}, + nil, + }, + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + s := stack.New(stack.Options{}) + id := tcpip.NICID(1) + ep := channel.New(0, 0, tcpip.LinkAddress("\x00\x00\x00\x00\x00\x00")) + if err := s.CreateNICWithOptions(id, ep, test.opts); err != nil { + t.Fatalf("got stack.CreateNICWithOptions(%d, %+v, %+v) = %s, want nil", id, ep, test.opts, err) + } + nicinfos := s.NICInfo() + nicinfo, ok := nicinfos[id] + if !ok { + t.Fatalf("got nicinfos[%d] = _, %t, want _, true; nicinfos = %+v", id, ok, nicinfos) + } + if got, want := nicinfo.Context == test.want, true; got != want { + t.Fatal("got nicinfo.Context == ctx = %t, want %t; nicinfo.Context = %p, ctx = %p", got, want, nicinfo.Context, test.want) + } + }) + } +} + // TestNICAutoGenAddrWithOpaque tests the auto-generation of IPv6 link-local // addresses with opaque interface identifiers. Link Local addresses should // always be generated with opaque IIDs if configured to use them, even if the @@ -2371,3 +2412,154 @@ func TestNewPEBOnPromotionToPermanent(t *testing.T) { } } } + +func TestIPv6SourceAddressSelectionScopeAndSameAddress(t *testing.T) { + const ( + linkLocalAddr1 = tcpip.Address("\xfe\x80\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01") + linkLocalAddr2 = tcpip.Address("\xfe\x80\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02") + uniqueLocalAddr1 = tcpip.Address("\xfc\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01") + uniqueLocalAddr2 = tcpip.Address("\xfd\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02") + globalAddr1 = tcpip.Address("\xa0\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01") + globalAddr2 = tcpip.Address("\xa0\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02") + nicID = 1 + ) + + // Rule 3 is not tested here, and is instead tested by NDP's AutoGenAddr test. + tests := []struct { + name string + nicAddrs []tcpip.Address + connectAddr tcpip.Address + expectedLocalAddr tcpip.Address + }{ + // Test Rule 1 of RFC 6724 section 5. + { + name: "Same Global most preferred (last address)", + nicAddrs: []tcpip.Address{linkLocalAddr1, uniqueLocalAddr1, globalAddr1}, + connectAddr: globalAddr1, + expectedLocalAddr: globalAddr1, + }, + { + name: "Same Global most preferred (first address)", + nicAddrs: []tcpip.Address{globalAddr1, linkLocalAddr1, uniqueLocalAddr1}, + connectAddr: globalAddr1, + expectedLocalAddr: globalAddr1, + }, + { + name: "Same Link Local most preferred (last address)", + nicAddrs: []tcpip.Address{globalAddr1, uniqueLocalAddr1, linkLocalAddr1}, + connectAddr: linkLocalAddr1, + expectedLocalAddr: linkLocalAddr1, + }, + { + name: "Same Link Local most preferred (first address)", + nicAddrs: []tcpip.Address{linkLocalAddr1, uniqueLocalAddr1, globalAddr1}, + connectAddr: linkLocalAddr1, + expectedLocalAddr: linkLocalAddr1, + }, + { + name: "Same Unique Local most preferred (last address)", + nicAddrs: []tcpip.Address{uniqueLocalAddr1, globalAddr1, linkLocalAddr1}, + connectAddr: uniqueLocalAddr1, + expectedLocalAddr: uniqueLocalAddr1, + }, + { + name: "Same Unique Local most preferred (first address)", + nicAddrs: []tcpip.Address{globalAddr1, linkLocalAddr1, uniqueLocalAddr1}, + connectAddr: uniqueLocalAddr1, + expectedLocalAddr: uniqueLocalAddr1, + }, + + // Test Rule 2 of RFC 6724 section 5. + { + name: "Global most preferred (last address)", + nicAddrs: []tcpip.Address{linkLocalAddr1, uniqueLocalAddr1, globalAddr1}, + connectAddr: globalAddr2, + expectedLocalAddr: globalAddr1, + }, + { + name: "Global most preferred (first address)", + nicAddrs: []tcpip.Address{globalAddr1, linkLocalAddr1, uniqueLocalAddr1}, + connectAddr: globalAddr2, + expectedLocalAddr: globalAddr1, + }, + { + name: "Link Local most preferred (last address)", + nicAddrs: []tcpip.Address{globalAddr1, uniqueLocalAddr1, linkLocalAddr1}, + connectAddr: linkLocalAddr2, + expectedLocalAddr: linkLocalAddr1, + }, + { + name: "Link Local most preferred (first address)", + nicAddrs: []tcpip.Address{linkLocalAddr1, uniqueLocalAddr1, globalAddr1}, + connectAddr: linkLocalAddr2, + expectedLocalAddr: linkLocalAddr1, + }, + { + name: "Unique Local most preferred (last address)", + nicAddrs: []tcpip.Address{uniqueLocalAddr1, globalAddr1, linkLocalAddr1}, + connectAddr: uniqueLocalAddr2, + expectedLocalAddr: uniqueLocalAddr1, + }, + { + name: "Unique Local most preferred (first address)", + nicAddrs: []tcpip.Address{globalAddr1, linkLocalAddr1, uniqueLocalAddr1}, + connectAddr: uniqueLocalAddr2, + expectedLocalAddr: uniqueLocalAddr1, + }, + + // Test returning the endpoint that is closest to the front when + // candidate addresses are "equal" from the perspective of RFC 6724 + // section 5. + { + name: "Unique Local for Global", + nicAddrs: []tcpip.Address{linkLocalAddr1, uniqueLocalAddr1, uniqueLocalAddr2}, + connectAddr: globalAddr2, + expectedLocalAddr: uniqueLocalAddr1, + }, + { + name: "Link Local for Global", + nicAddrs: []tcpip.Address{linkLocalAddr1, linkLocalAddr2}, + connectAddr: globalAddr2, + expectedLocalAddr: linkLocalAddr1, + }, + { + name: "Link Local for Unique Local", + nicAddrs: []tcpip.Address{linkLocalAddr1, linkLocalAddr2}, + connectAddr: uniqueLocalAddr2, + expectedLocalAddr: linkLocalAddr1, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + e := channel.New(0, 1280, linkAddr1) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()}, + }) + if err := s.CreateNIC(nicID, e); err != nil { + t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) + } + s.SetRouteTable([]tcpip.Route{{ + Destination: header.IPv6EmptySubnet, + Gateway: llAddr3, + NIC: nicID, + }}) + s.AddLinkAddress(nicID, llAddr3, linkAddr3) + + for _, a := range test.nicAddrs { + if err := s.AddAddress(nicID, ipv6.ProtocolNumber, a); err != nil { + t.Errorf("s.AddAddress(%d, %d, %s): %s", nicID, ipv6.ProtocolNumber, a, err) + } + } + + if t.Failed() { + t.FailNow() + } + + if got := addrForNewConnectionTo(t, s, tcpip.FullAddress{Addr: test.connectAddr, NIC: nicID, Port: 1234}); got != test.expectedLocalAddr { + t.Errorf("got local address = %s, want = %s", got, test.expectedLocalAddr) + } + }) + } +} diff --git a/pkg/tcpip/stack/transport_demuxer.go b/pkg/tcpip/stack/transport_demuxer.go index 67c21be42..d686e6eb8 100644 --- a/pkg/tcpip/stack/transport_demuxer.go +++ b/pkg/tcpip/stack/transport_demuxer.go @@ -18,8 +18,8 @@ import ( "fmt" "math/rand" "sort" - "sync" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/hash/jenkins" "gvisor.dev/gvisor/pkg/tcpip/header" @@ -104,7 +104,14 @@ func (epsByNic *endpointsByNic) handlePacket(r *Route, id TransportEndpointID, p return } // multiPortEndpoints are guaranteed to have at least one element. - selectEndpoint(id, mpep, epsByNic.seed).HandlePacket(r, id, pkt) + transEP := selectEndpoint(id, mpep, epsByNic.seed) + if queuedProtocol, mustQueue := mpep.demux.queuedProtocols[protocolIDs{mpep.netProto, mpep.transProto}]; mustQueue { + queuedProtocol.QueuePacket(r, transEP, id, pkt) + epsByNic.mu.RUnlock() + return + } + + transEP.HandlePacket(r, id, pkt) epsByNic.mu.RUnlock() // Don't use defer for performance reasons. } @@ -130,7 +137,7 @@ func (epsByNic *endpointsByNic) handleControlPacket(n *NIC, id TransportEndpoint // registerEndpoint returns true if it succeeds. It fails and returns // false if ep already has an element with the same key. -func (epsByNic *endpointsByNic) registerEndpoint(t TransportEndpoint, reusePort bool, bindToDevice tcpip.NICID) *tcpip.Error { +func (epsByNic *endpointsByNic) registerEndpoint(d *transportDemuxer, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, t TransportEndpoint, reusePort bool, bindToDevice tcpip.NICID) *tcpip.Error { epsByNic.mu.Lock() defer epsByNic.mu.Unlock() @@ -140,7 +147,7 @@ func (epsByNic *endpointsByNic) registerEndpoint(t TransportEndpoint, reusePort } // This is a new binding. - multiPortEp := &multiPortEndpoint{} + multiPortEp := &multiPortEndpoint{demux: d, netProto: netProto, transProto: transProto} multiPortEp.endpointsMap = make(map[TransportEndpoint]int) multiPortEp.reuse = reusePort epsByNic.endpoints[bindToDevice] = multiPortEp @@ -168,18 +175,34 @@ func (epsByNic *endpointsByNic) unregisterEndpoint(bindToDevice tcpip.NICID, t T // newTransportDemuxer. type transportDemuxer struct { // protocol is immutable. - protocol map[protocolIDs]*transportEndpoints + protocol map[protocolIDs]*transportEndpoints + queuedProtocols map[protocolIDs]queuedTransportProtocol +} + +// queuedTransportProtocol if supported by a protocol implementation will cause +// the dispatcher to delivery packets to the QueuePacket method instead of +// calling HandlePacket directly on the endpoint. +type queuedTransportProtocol interface { + QueuePacket(r *Route, ep TransportEndpoint, id TransportEndpointID, pkt tcpip.PacketBuffer) } func newTransportDemuxer(stack *Stack) *transportDemuxer { - d := &transportDemuxer{protocol: make(map[protocolIDs]*transportEndpoints)} + d := &transportDemuxer{ + protocol: make(map[protocolIDs]*transportEndpoints), + queuedProtocols: make(map[protocolIDs]queuedTransportProtocol), + } // Add each network and transport pair to the demuxer. for netProto := range stack.networkProtocols { for proto := range stack.transportProtocols { - d.protocol[protocolIDs{netProto, proto}] = &transportEndpoints{ + protoIDs := protocolIDs{netProto, proto} + d.protocol[protoIDs] = &transportEndpoints{ endpoints: make(map[TransportEndpointID]*endpointsByNic), } + qTransProto, isQueued := (stack.transportProtocols[proto].proto).(queuedTransportProtocol) + if isQueued { + d.queuedProtocols[protoIDs] = qTransProto + } } } @@ -209,7 +232,11 @@ func (d *transportDemuxer) registerEndpoint(netProtos []tcpip.NetworkProtocolNum // // +stateify savable type multiPortEndpoint struct { - mu sync.RWMutex `state:"nosave"` + mu sync.RWMutex `state:"nosave"` + demux *transportDemuxer + netProto tcpip.NetworkProtocolNumber + transProto tcpip.TransportProtocolNumber + endpointsArr []TransportEndpoint endpointsMap map[TransportEndpoint]int // reuse indicates if more than one endpoint is allowed. @@ -258,13 +285,22 @@ func selectEndpoint(id TransportEndpointID, mpep *multiPortEndpoint, seed uint32 func (ep *multiPortEndpoint) handlePacketAll(r *Route, id TransportEndpointID, pkt tcpip.PacketBuffer) { ep.mu.RLock() + queuedProtocol, mustQueue := ep.demux.queuedProtocols[protocolIDs{ep.netProto, ep.transProto}] for i, endpoint := range ep.endpointsArr { // HandlePacket takes ownership of pkt, so each endpoint needs // its own copy except for the final one. if i == len(ep.endpointsArr)-1 { + if mustQueue { + queuedProtocol.QueuePacket(r, endpoint, id, pkt) + break + } endpoint.HandlePacket(r, id, pkt) break } + if mustQueue { + queuedProtocol.QueuePacket(r, endpoint, id, pkt.Clone()) + continue + } endpoint.HandlePacket(r, id, pkt.Clone()) } ep.mu.RUnlock() // Don't use defer for performance reasons. @@ -357,7 +393,7 @@ func (d *transportDemuxer) singleRegisterEndpoint(netProto tcpip.NetworkProtocol if epsByNic, ok := eps.endpoints[id]; ok { // There was already a binding. - return epsByNic.registerEndpoint(ep, reusePort, bindToDevice) + return epsByNic.registerEndpoint(d, netProto, protocol, ep, reusePort, bindToDevice) } // This is a new binding. @@ -367,7 +403,7 @@ func (d *transportDemuxer) singleRegisterEndpoint(netProto tcpip.NetworkProtocol } eps.endpoints[id] = epsByNic - return epsByNic.registerEndpoint(ep, reusePort, bindToDevice) + return epsByNic.registerEndpoint(d, netProto, protocol, ep, reusePort, bindToDevice) } // unregisterEndpoint unregisters the endpoint with the given id such that it diff --git a/pkg/tcpip/stack/transport_demuxer_test.go b/pkg/tcpip/stack/transport_demuxer_test.go index df5ced887..5e9237de9 100644 --- a/pkg/tcpip/stack/transport_demuxer_test.go +++ b/pkg/tcpip/stack/transport_demuxer_test.go @@ -41,7 +41,7 @@ const ( type testContext struct { t *testing.T - linkEPs map[string]*channel.Endpoint + linkEps map[tcpip.NICID]*channel.Endpoint s *stack.Stack ep tcpip.Endpoint @@ -66,27 +66,24 @@ func (c *testContext) createV6Endpoint(v6only bool) { } } -// newDualTestContextMultiNic creates the testing context and also linkEpNames -// named NICs. -func newDualTestContextMultiNic(t *testing.T, mtu uint32, linkEpNames []string) *testContext { +// newDualTestContextMultiNIC creates the testing context and also linkEpIDs NICs. +func newDualTestContextMultiNIC(t *testing.T, mtu uint32, linkEpIDs []tcpip.NICID) *testContext { s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol(), ipv6.NewProtocol()}, TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()}}) - linkEPs := make(map[string]*channel.Endpoint) - for i, linkEpName := range linkEpNames { - channelEP := channel.New(256, mtu, "") - nicID := tcpip.NICID(i + 1) - opts := stack.NICOptions{Name: linkEpName} - if err := s.CreateNICWithOptions(nicID, channelEP, opts); err != nil { - t.Fatalf("CreateNICWithOptions(_, _, %+v) failed: %v", opts, err) + linkEps := make(map[tcpip.NICID]*channel.Endpoint) + for _, linkEpID := range linkEpIDs { + channelEp := channel.New(256, mtu, "") + if err := s.CreateNIC(linkEpID, channelEp); err != nil { + t.Fatalf("CreateNIC failed: %v", err) } - linkEPs[linkEpName] = channelEP + linkEps[linkEpID] = channelEp - if err := s.AddAddress(nicID, ipv4.ProtocolNumber, stackAddr); err != nil { + if err := s.AddAddress(linkEpID, ipv4.ProtocolNumber, stackAddr); err != nil { t.Fatalf("AddAddress IPv4 failed: %v", err) } - if err := s.AddAddress(nicID, ipv6.ProtocolNumber, stackV6Addr); err != nil { + if err := s.AddAddress(linkEpID, ipv6.ProtocolNumber, stackV6Addr); err != nil { t.Fatalf("AddAddress IPv6 failed: %v", err) } } @@ -105,7 +102,7 @@ func newDualTestContextMultiNic(t *testing.T, mtu uint32, linkEpNames []string) return &testContext{ t: t, s: s, - linkEPs: linkEPs, + linkEps: linkEps, } } @@ -122,7 +119,7 @@ func newPayload() []byte { return b } -func (c *testContext) sendV6Packet(payload []byte, h *headers, linkEpName string) { +func (c *testContext) sendV6Packet(payload []byte, h *headers, linkEpID tcpip.NICID) { // Allocate a buffer for data and headers. buf := buffer.NewView(header.UDPMinimumSize + header.IPv6MinimumSize + len(payload)) copy(buf[len(buf)-len(payload):], payload) @@ -153,7 +150,7 @@ func (c *testContext) sendV6Packet(payload []byte, h *headers, linkEpName string u.SetChecksum(^u.CalculateChecksum(xsum)) // Inject packet. - c.linkEPs[linkEpName].InjectInbound(ipv6.ProtocolNumber, tcpip.PacketBuffer{ + c.linkEps[linkEpID].InjectInbound(ipv6.ProtocolNumber, tcpip.PacketBuffer{ Data: buf.ToVectorisedView(), }) } @@ -183,7 +180,7 @@ func TestTransportDemuxerRegister(t *testing.T) { func TestDistribution(t *testing.T) { type endpointSockopts struct { reuse int - bindToDevice string + bindToDevice tcpip.NICID } for _, test := range []struct { name string @@ -191,71 +188,71 @@ func TestDistribution(t *testing.T) { endpoints []endpointSockopts // wantedDistribution is the wanted ratio of packets received on each // endpoint for each NIC on which packets are injected. - wantedDistributions map[string][]float64 + wantedDistributions map[tcpip.NICID][]float64 }{ { "BindPortReuse", // 5 endpoints that all have reuse set. []endpointSockopts{ - {1, ""}, - {1, ""}, - {1, ""}, - {1, ""}, - {1, ""}, + {1, 0}, + {1, 0}, + {1, 0}, + {1, 0}, + {1, 0}, }, - map[string][]float64{ + map[tcpip.NICID][]float64{ // Injected packets on dev0 get distributed evenly. - "dev0": {0.2, 0.2, 0.2, 0.2, 0.2}, + 1: {0.2, 0.2, 0.2, 0.2, 0.2}, }, }, { "BindToDevice", // 3 endpoints with various bindings. []endpointSockopts{ - {0, "dev0"}, - {0, "dev1"}, - {0, "dev2"}, + {0, 1}, + {0, 2}, + {0, 3}, }, - map[string][]float64{ + map[tcpip.NICID][]float64{ // Injected packets on dev0 go only to the endpoint bound to dev0. - "dev0": {1, 0, 0}, + 1: {1, 0, 0}, // Injected packets on dev1 go only to the endpoint bound to dev1. - "dev1": {0, 1, 0}, + 2: {0, 1, 0}, // Injected packets on dev2 go only to the endpoint bound to dev2. - "dev2": {0, 0, 1}, + 3: {0, 0, 1}, }, }, { "ReuseAndBindToDevice", // 6 endpoints with various bindings. []endpointSockopts{ - {1, "dev0"}, - {1, "dev0"}, - {1, "dev1"}, - {1, "dev1"}, - {1, "dev1"}, - {1, ""}, + {1, 1}, + {1, 1}, + {1, 2}, + {1, 2}, + {1, 2}, + {1, 0}, }, - map[string][]float64{ + map[tcpip.NICID][]float64{ // Injected packets on dev0 get distributed among endpoints bound to // dev0. - "dev0": {0.5, 0.5, 0, 0, 0, 0}, + 1: {0.5, 0.5, 0, 0, 0, 0}, // Injected packets on dev1 get distributed among endpoints bound to // dev1 or unbound. - "dev1": {0, 0, 1. / 3, 1. / 3, 1. / 3, 0}, + 2: {0, 0, 1. / 3, 1. / 3, 1. / 3, 0}, // Injected packets on dev999 go only to the unbound. - "dev999": {0, 0, 0, 0, 0, 1}, + 1000: {0, 0, 0, 0, 0, 1}, }, }, } { t.Run(test.name, func(t *testing.T) { for device, wantedDistribution := range test.wantedDistributions { - t.Run(device, func(t *testing.T) { - var devices []string + t.Run(string(device), func(t *testing.T) { + var devices []tcpip.NICID for d := range test.wantedDistributions { devices = append(devices, d) } - c := newDualTestContextMultiNic(t, defaultMTU, devices) + c := newDualTestContextMultiNIC(t, defaultMTU, devices) defer c.cleanup() c.createV6Endpoint(false) diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go index 1eca76c30..b7813cbc0 100644 --- a/pkg/tcpip/tcpip.go +++ b/pkg/tcpip/tcpip.go @@ -35,10 +35,10 @@ import ( "reflect" "strconv" "strings" - "sync" "sync/atomic" "time" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/iptables" "gvisor.dev/gvisor/pkg/waiter" @@ -322,7 +322,7 @@ type ControlMessages struct { HasTOS bool // TOS is the IPv4 type of service of the associated packet. - TOS int8 + TOS uint8 // HasTClass indicates whether Tclass is valid/set. HasTClass bool @@ -500,9 +500,13 @@ type WriteOptions struct { type SockOptBool int const ( + // ReceiveTOSOption is used by SetSockOpt/GetSockOpt to specify if the TOS + // ancillary message is passed with incoming packets. + ReceiveTOSOption SockOptBool = iota + // V6OnlyOption is used by {G,S}etSockOptBool to specify whether an IPv6 // socket is to be restricted to sending and receiving IPv6 packets only. - V6OnlyOption SockOptBool = iota + V6OnlyOption ) // SockOptInt represents socket options which values have the int type. @@ -552,7 +556,7 @@ type ReusePortOption int // BindToDeviceOption is used by SetSockOpt/GetSockOpt to specify that sockets // should bind only on a specific NIC. -type BindToDeviceOption string +type BindToDeviceOption NICID // QuickAckOption is stubbed out in SetSockOpt/GetSockOpt. type QuickAckOption int diff --git a/pkg/tcpip/timer_test.go b/pkg/tcpip/timer_test.go index 1f735d735..2d20f7ef3 100644 --- a/pkg/tcpip/timer_test.go +++ b/pkg/tcpip/timer_test.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package timer_test +package tcpip_test import ( "sync" diff --git a/pkg/tcpip/transport/icmp/BUILD b/pkg/tcpip/transport/icmp/BUILD index d8c5b5058..3aa23d529 100644 --- a/pkg/tcpip/transport/icmp/BUILD +++ b/pkg/tcpip/transport/icmp/BUILD @@ -28,6 +28,7 @@ go_library( visibility = ["//visibility:public"], deps = [ "//pkg/sleep", + "//pkg/sync", "//pkg/tcpip", "//pkg/tcpip/buffer", "//pkg/tcpip/header", diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go index c7ce74cdd..42afb3f5b 100644 --- a/pkg/tcpip/transport/icmp/endpoint.go +++ b/pkg/tcpip/transport/icmp/endpoint.go @@ -15,8 +15,7 @@ package icmp import ( - "sync" - + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" @@ -289,7 +288,7 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c toCopy := *to to = &toCopy - netProto, err := e.checkV4Mapped(to, true) + netProto, err := e.checkV4Mapped(to) if err != nil { return 0, nil, err } @@ -476,18 +475,12 @@ func send6(r *stack.Route, ident uint16, data buffer.View, ttl uint8) *tcpip.Err }) } -func (e *endpoint) checkV4Mapped(addr *tcpip.FullAddress, allowMismatch bool) (tcpip.NetworkProtocolNumber, *tcpip.Error) { - netProto := e.NetProto - if header.IsV4MappedAddress(addr.Addr) { - return 0, tcpip.ErrNoRoute - } - - // Fail if we're bound to an address length different from the one we're - // checking. - if l := len(e.ID.LocalAddress); !allowMismatch && l != 0 && l != len(addr.Addr) { - return 0, tcpip.ErrInvalidEndpointState +func (e *endpoint) checkV4Mapped(addr *tcpip.FullAddress) (tcpip.NetworkProtocolNumber, *tcpip.Error) { + unwrapped, netProto, err := e.TransportEndpointInfo.AddrNetProto(*addr, false /* v6only */) + if err != nil { + return 0, err } - + *addr = unwrapped return netProto, nil } @@ -519,7 +512,7 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { return tcpip.ErrInvalidEndpointState } - netProto, err := e.checkV4Mapped(&addr, false) + netProto, err := e.checkV4Mapped(&addr) if err != nil { return err } @@ -632,7 +625,7 @@ func (e *endpoint) bindLocked(addr tcpip.FullAddress) *tcpip.Error { return tcpip.ErrInvalidEndpointState } - netProto, err := e.checkV4Mapped(&addr, false) + netProto, err := e.checkV4Mapped(&addr) if err != nil { return err } diff --git a/pkg/tcpip/transport/packet/BUILD b/pkg/tcpip/transport/packet/BUILD index 44b58ff6b..4858d150c 100644 --- a/pkg/tcpip/transport/packet/BUILD +++ b/pkg/tcpip/transport/packet/BUILD @@ -28,6 +28,7 @@ go_library( deps = [ "//pkg/log", "//pkg/sleep", + "//pkg/sync", "//pkg/tcpip", "//pkg/tcpip/buffer", "//pkg/tcpip/header", diff --git a/pkg/tcpip/transport/packet/endpoint.go b/pkg/tcpip/transport/packet/endpoint.go index 07ffa8aba..fc5bc69fa 100644 --- a/pkg/tcpip/transport/packet/endpoint.go +++ b/pkg/tcpip/transport/packet/endpoint.go @@ -25,8 +25,7 @@ package packet import ( - "sync" - + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" diff --git a/pkg/tcpip/transport/raw/BUILD b/pkg/tcpip/transport/raw/BUILD index 00991ac8e..2f2131ff7 100644 --- a/pkg/tcpip/transport/raw/BUILD +++ b/pkg/tcpip/transport/raw/BUILD @@ -29,6 +29,7 @@ go_library( deps = [ "//pkg/log", "//pkg/sleep", + "//pkg/sync", "//pkg/tcpip", "//pkg/tcpip/buffer", "//pkg/tcpip/header", diff --git a/pkg/tcpip/transport/raw/endpoint.go b/pkg/tcpip/transport/raw/endpoint.go index 85f7eb76b..ee9c4c58b 100644 --- a/pkg/tcpip/transport/raw/endpoint.go +++ b/pkg/tcpip/transport/raw/endpoint.go @@ -26,8 +26,7 @@ package raw import ( - "sync" - + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" diff --git a/pkg/tcpip/transport/tcp/BUILD b/pkg/tcpip/transport/tcp/BUILD index 3b353d56c..0e3ab05ad 100644 --- a/pkg/tcpip/transport/tcp/BUILD +++ b/pkg/tcpip/transport/tcp/BUILD @@ -16,6 +16,18 @@ go_template_instance( }, ) +go_template_instance( + name = "tcp_endpoint_list", + out = "tcp_endpoint_list.go", + package = "tcp", + prefix = "endpoint", + template = "//pkg/ilist:generic_list", + types = { + "Element": "*endpoint", + "Linker": "*endpoint", + }, +) + go_library( name = "tcp", srcs = [ @@ -23,6 +35,7 @@ go_library( "connect.go", "cubic.go", "cubic_state.go", + "dispatcher.go", "endpoint.go", "endpoint_state.go", "forwarder.go", @@ -38,6 +51,7 @@ go_library( "segment_state.go", "snd.go", "snd_state.go", + "tcp_endpoint_list.go", "tcp_segment_list.go", "timer.go", ], @@ -45,9 +59,9 @@ go_library( imports = ["gvisor.dev/gvisor/pkg/tcpip/buffer"], visibility = ["//visibility:public"], deps = [ - "//pkg/log", "//pkg/rand", "//pkg/sleep", + "//pkg/sync", "//pkg/tcpip", "//pkg/tcpip/buffer", "//pkg/tcpip/hash/jenkins", diff --git a/pkg/tcpip/transport/tcp/accept.go b/pkg/tcpip/transport/tcp/accept.go index 5422ae80c..1a2e3efa9 100644 --- a/pkg/tcpip/transport/tcp/accept.go +++ b/pkg/tcpip/transport/tcp/accept.go @@ -19,11 +19,11 @@ import ( "encoding/binary" "hash" "io" - "sync" "time" "gvisor.dev/gvisor/pkg/rand" "gvisor.dev/gvisor/pkg/sleep" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" @@ -285,7 +285,7 @@ func (l *listenContext) createEndpointAndPerformHandshake(s *segment, opts *head // listenEP is nil when listenContext is used by tcp.Forwarder. if l.listenEP != nil { l.listenEP.mu.Lock() - if l.listenEP.state != StateListen { + if l.listenEP.EndpointState() != StateListen { l.listenEP.mu.Unlock() return nil, tcpip.ErrConnectionAborted } @@ -344,11 +344,12 @@ func (l *listenContext) closeAllPendingEndpoints() { // instead. func (e *endpoint) deliverAccepted(n *endpoint) { e.mu.Lock() - state := e.state + state := e.EndpointState() e.pendingAccepted.Add(1) defer e.pendingAccepted.Done() acceptedChan := e.acceptedChan e.mu.Unlock() + if state == StateListen { acceptedChan <- n e.waiterQueue.Notify(waiter.EventIn) @@ -562,8 +563,8 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) { // We do not use transitionToStateEstablishedLocked here as there is // no handshake state available when doing a SYN cookie based accept. n.stack.Stats().TCP.CurrentEstablished.Increment() - n.state = StateEstablished n.isConnectNotified = true + n.setEndpointState(StateEstablished) // Do the delivery in a separate goroutine so // that we don't block the listen loop in case @@ -596,7 +597,7 @@ func (e *endpoint) protocolListenLoop(rcvWnd seqnum.Size) *tcpip.Error { // handleSynSegment() from attempting to queue new connections // to the endpoint. e.mu.Lock() - e.state = StateClose + e.setEndpointState(StateClose) // close any endpoints in SYN-RCVD state. ctx.closeAllPendingEndpoints() diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go index cdd69f360..a2f384384 100644 --- a/pkg/tcpip/transport/tcp/connect.go +++ b/pkg/tcpip/transport/tcp/connect.go @@ -16,11 +16,11 @@ package tcp import ( "encoding/binary" - "sync" "time" "gvisor.dev/gvisor/pkg/rand" "gvisor.dev/gvisor/pkg/sleep" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/hash/jenkins" @@ -190,7 +190,7 @@ func (h *handshake) resetToSynRcvd(iss seqnum.Value, irs seqnum.Value, opts *hea h.mss = opts.MSS h.sndWndScale = opts.WS h.ep.mu.Lock() - h.ep.state = StateSynRecv + h.ep.setEndpointState(StateSynRecv) h.ep.mu.Unlock() } @@ -274,14 +274,14 @@ func (h *handshake) synSentState(s *segment) *tcpip.Error { // SYN-RCVD state. h.state = handshakeSynRcvd h.ep.mu.Lock() - h.ep.state = StateSynRecv ttl := h.ep.ttl + h.ep.setEndpointState(StateSynRecv) h.ep.mu.Unlock() synOpts := header.TCPSynOptions{ WS: int(h.effectiveRcvWndScale()), TS: rcvSynOpts.TS, TSVal: h.ep.timestamp(), - TSEcr: h.ep.recentTS, + TSEcr: h.ep.recentTimestamp(), // We only send SACKPermitted if the other side indicated it // permits SACK. This is not explicitly defined in the RFC but @@ -341,7 +341,7 @@ func (h *handshake) synRcvdState(s *segment) *tcpip.Error { WS: h.rcvWndScale, TS: h.ep.sendTSOk, TSVal: h.ep.timestamp(), - TSEcr: h.ep.recentTS, + TSEcr: h.ep.recentTimestamp(), SACKPermitted: h.ep.sackPermitted, MSS: h.ep.amss, } @@ -501,7 +501,7 @@ func (h *handshake) execute() *tcpip.Error { WS: h.rcvWndScale, TS: true, TSVal: h.ep.timestamp(), - TSEcr: h.ep.recentTS, + TSEcr: h.ep.recentTimestamp(), SACKPermitted: bool(sackEnabled), MSS: h.ep.amss, } @@ -792,7 +792,7 @@ func (e *endpoint) makeOptions(sackBlocks []header.SACKBlock) []byte { // Ref: https://tools.ietf.org/html/rfc7323#section-5.4. offset += header.EncodeNOP(options[offset:]) offset += header.EncodeNOP(options[offset:]) - offset += header.EncodeTSOption(e.timestamp(), uint32(e.recentTS), options[offset:]) + offset += header.EncodeTSOption(e.timestamp(), e.recentTimestamp(), options[offset:]) } if e.sackPermitted && len(sackBlocks) > 0 { offset += header.EncodeNOP(options[offset:]) @@ -811,7 +811,7 @@ func (e *endpoint) makeOptions(sackBlocks []header.SACKBlock) []byte { // sendRaw sends a TCP segment to the endpoint's peer. func (e *endpoint) sendRaw(data buffer.VectorisedView, flags byte, seq, ack seqnum.Value, rcvWnd seqnum.Size) *tcpip.Error { var sackBlocks []header.SACKBlock - if e.state == StateEstablished && e.rcv.pendingBufSize > 0 && (flags&header.TCPFlagAck != 0) { + if e.EndpointState() == StateEstablished && e.rcv.pendingBufSize > 0 && (flags&header.TCPFlagAck != 0) { sackBlocks = e.sack.Blocks[:e.sack.NumBlocks] } options := e.makeOptions(sackBlocks) @@ -848,6 +848,9 @@ func (e *endpoint) handleWrite() *tcpip.Error { } func (e *endpoint) handleClose() *tcpip.Error { + if !e.EndpointState().connected() { + return nil + } // Drain the send queue. e.handleWrite() @@ -864,11 +867,7 @@ func (e *endpoint) handleClose() *tcpip.Error { func (e *endpoint) resetConnectionLocked(err *tcpip.Error) { // Only send a reset if the connection is being aborted for a reason // other than receiving a reset. - if e.state == StateEstablished || e.state == StateCloseWait { - e.stack.Stats().TCP.EstablishedResets.Increment() - e.stack.Stats().TCP.CurrentEstablished.Decrement() - } - e.state = StateError + e.setEndpointState(StateError) e.HardError = err if err != tcpip.ErrConnectionReset && err != tcpip.ErrTimeout { // The exact sequence number to be used for the RST is the same as the @@ -888,9 +887,12 @@ func (e *endpoint) resetConnectionLocked(err *tcpip.Error) { } // completeWorkerLocked is called by the worker goroutine when it's about to -// exit. It marks the worker as completed and performs cleanup work if requested -// by Close(). +// exit. func (e *endpoint) completeWorkerLocked() { + // Worker is terminating(either due to moving to + // CLOSED or ERROR state, ensure we release all + // registrations port reservations even if the socket + // itself is not yet closed by the application. e.workerRunning = false if e.workerCleanup { e.cleanupLocked() @@ -917,8 +919,7 @@ func (e *endpoint) transitionToStateEstablishedLocked(h *handshake) { e.rcvAutoParams.prevCopied = int(h.rcvWnd) e.rcvListMu.Unlock() } - h.ep.stack.Stats().TCP.CurrentEstablished.Increment() - e.state = StateEstablished + e.setEndpointState(StateEstablished) } // transitionToStateCloseLocked ensures that the endpoint is @@ -927,11 +928,12 @@ func (e *endpoint) transitionToStateEstablishedLocked(h *handshake) { // delivered to this endpoint from the demuxer when the endpoint // is transitioned to StateClose. func (e *endpoint) transitionToStateCloseLocked() { - if e.state == StateClose { + if e.EndpointState() == StateClose { return } + // Mark the endpoint as fully closed for reads/writes. e.cleanupLocked() - e.state = StateClose + e.setEndpointState(StateClose) e.stack.Stats().TCP.EstablishedClosed.Increment() } @@ -946,7 +948,9 @@ func (e *endpoint) tryDeliverSegmentFromClosedEndpoint(s *segment) { s.decRef() return } - ep.(*endpoint).enqueueSegment(s) + if ep.(*endpoint).enqueueSegment(s) { + ep.(*endpoint).newSegmentWaker.Assert() + } } func (e *endpoint) handleReset(s *segment) (ok bool, err *tcpip.Error) { @@ -955,9 +959,8 @@ func (e *endpoint) handleReset(s *segment) (ok bool, err *tcpip.Error) { // except SYN-SENT, all reset (RST) segments are // validated by checking their SEQ-fields." So // we only process it if it's acceptable. - s.decRef() e.mu.Lock() - switch e.state { + switch e.EndpointState() { // In case of a RST in CLOSE-WAIT linux moves // the socket to closed state with an error set // to indicate EPIPE. @@ -981,103 +984,53 @@ func (e *endpoint) handleReset(s *segment) (ok bool, err *tcpip.Error) { e.transitionToStateCloseLocked() e.HardError = tcpip.ErrAborted e.mu.Unlock() + e.notifyProtocolGoroutine(notifyTickleWorker) return false, nil default: e.mu.Unlock() + // RFC 793, page 37 states that "in all states + // except SYN-SENT, all reset (RST) segments are + // validated by checking their SEQ-fields." So + // we only process it if it's acceptable. + + // Notify protocol goroutine. This is required when + // handleSegment is invoked from the processor goroutine + // rather than the worker goroutine. + e.notifyProtocolGoroutine(notifyResetByPeer) return false, tcpip.ErrConnectionReset } } return true, nil } -// handleSegments pulls segments from the queue and processes them. It returns -// no error if the protocol loop should continue, an error otherwise. -func (e *endpoint) handleSegments() *tcpip.Error { +// handleSegments processes all inbound segments. +func (e *endpoint) handleSegments(fastPath bool) *tcpip.Error { checkRequeue := true for i := 0; i < maxSegmentsPerWake; i++ { + if e.EndpointState() == StateClose || e.EndpointState() == StateError { + return nil + } s := e.segmentQueue.dequeue() if s == nil { checkRequeue = false break } - // Invoke the tcp probe if installed. - if e.probe != nil { - e.probe(e.completeState()) + cont, err := e.handleSegment(s) + if err != nil { + s.decRef() + return err } - - if s.flagIsSet(header.TCPFlagRst) { - if ok, err := e.handleReset(s); !ok { - return err - } - } else if s.flagIsSet(header.TCPFlagSyn) { - // See: https://tools.ietf.org/html/rfc5961#section-4.1 - // 1) If the SYN bit is set, irrespective of the sequence number, TCP - // MUST send an ACK (also referred to as challenge ACK) to the remote - // peer: - // - // <SEQ=SND.NXT><ACK=RCV.NXT><CTL=ACK> - // - // After sending the acknowledgment, TCP MUST drop the unacceptable - // segment and stop processing further. - // - // By sending an ACK, the remote peer is challenged to confirm the loss - // of the previous connection and the request to start a new connection. - // A legitimate peer, after restart, would not have a TCB in the - // synchronized state. Thus, when the ACK arrives, the peer should send - // a RST segment back with the sequence number derived from the ACK - // field that caused the RST. - - // This RST will confirm that the remote peer has indeed closed the - // previous connection. Upon receipt of a valid RST, the local TCP - // endpoint MUST terminate its connection. The local TCP endpoint - // should then rely on SYN retransmission from the remote end to - // re-establish the connection. - - e.snd.sendAck() - } else if s.flagIsSet(header.TCPFlagAck) { - // Patch the window size in the segment according to the - // send window scale. - s.window <<= e.snd.sndWndScale - - // RFC 793, page 41 states that "once in the ESTABLISHED - // state all segments must carry current acknowledgment - // information." - drop, err := e.rcv.handleRcvdSegment(s) - if err != nil { - s.decRef() - return err - } - if drop { - s.decRef() - continue - } - - // Now check if the received segment has caused us to transition - // to a CLOSED state, if yes then terminate processing and do - // not invoke the sender. - e.mu.RLock() - state := e.state - e.mu.RUnlock() - if state == StateClose { - // When we get into StateClose while processing from the queue, - // return immediately and let the protocolMainloop handle it. - // - // We can reach StateClose only while processing a previous segment - // or a notification from the protocolMainLoop (caller goroutine). - // This means that with this return, the segment dequeue below can - // never occur on a closed endpoint. - s.decRef() - return nil - } - e.snd.handleRcvdSegment(s) + if !cont { + s.decRef() + return nil } - s.decRef() } - // If the queue is not empty, make sure we'll wake up in the next - // iteration. - if checkRequeue && !e.segmentQueue.empty() { + // When fastPath is true we don't want to wake up the worker + // goroutine. If the endpoint has more segments to process the + // dispatcher will call handleSegments again anyway. + if !fastPath && checkRequeue && !e.segmentQueue.empty() { e.newSegmentWaker.Assert() } @@ -1086,11 +1039,88 @@ func (e *endpoint) handleSegments() *tcpip.Error { e.snd.sendAck() } - e.resetKeepaliveTimer(true) + e.resetKeepaliveTimer(true /* receivedData */) return nil } +// handleSegment handles a given segment and notifies the worker goroutine if +// if the connection should be terminated. +func (e *endpoint) handleSegment(s *segment) (cont bool, err *tcpip.Error) { + // Invoke the tcp probe if installed. + if e.probe != nil { + e.probe(e.completeState()) + } + + if s.flagIsSet(header.TCPFlagRst) { + if ok, err := e.handleReset(s); !ok { + return false, err + } + } else if s.flagIsSet(header.TCPFlagSyn) { + // See: https://tools.ietf.org/html/rfc5961#section-4.1 + // 1) If the SYN bit is set, irrespective of the sequence number, TCP + // MUST send an ACK (also referred to as challenge ACK) to the remote + // peer: + // + // <SEQ=SND.NXT><ACK=RCV.NXT><CTL=ACK> + // + // After sending the acknowledgment, TCP MUST drop the unacceptable + // segment and stop processing further. + // + // By sending an ACK, the remote peer is challenged to confirm the loss + // of the previous connection and the request to start a new connection. + // A legitimate peer, after restart, would not have a TCB in the + // synchronized state. Thus, when the ACK arrives, the peer should send + // a RST segment back with the sequence number derived from the ACK + // field that caused the RST. + + // This RST will confirm that the remote peer has indeed closed the + // previous connection. Upon receipt of a valid RST, the local TCP + // endpoint MUST terminate its connection. The local TCP endpoint + // should then rely on SYN retransmission from the remote end to + // re-establish the connection. + + e.snd.sendAck() + } else if s.flagIsSet(header.TCPFlagAck) { + // Patch the window size in the segment according to the + // send window scale. + s.window <<= e.snd.sndWndScale + + // RFC 793, page 41 states that "once in the ESTABLISHED + // state all segments must carry current acknowledgment + // information." + drop, err := e.rcv.handleRcvdSegment(s) + if err != nil { + return false, err + } + if drop { + return true, nil + } + + // Now check if the received segment has caused us to transition + // to a CLOSED state, if yes then terminate processing and do + // not invoke the sender. + e.mu.RLock() + state := e.state + e.mu.RUnlock() + if state == StateClose { + // When we get into StateClose while processing from the queue, + // return immediately and let the protocolMainloop handle it. + // + // We can reach StateClose only while processing a previous segment + // or a notification from the protocolMainLoop (caller goroutine). + // This means that with this return, the segment dequeue below can + // never occur on a closed endpoint. + s.decRef() + return false, nil + } + + e.snd.handleRcvdSegment(s) + } + + return true, nil +} + // keepaliveTimerExpired is called when the keepaliveTimer fires. We send TCP // keepalive packets periodically when the connection is idle. If we don't hear // from the other side after a number of tries, we terminate the connection. @@ -1160,7 +1190,7 @@ func (e *endpoint) disableKeepaliveTimer() { // protocolMainLoop is the main loop of the TCP protocol. It runs in its own // goroutine and is responsible for sending segments and handling received // segments. -func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error { +func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{}) *tcpip.Error { var closeTimer *time.Timer var closeWaker sleep.Waker @@ -1182,6 +1212,7 @@ func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error { } e.mu.Unlock() + e.workMu.Unlock() // When the protocol loop exits we should wake up our waiters. e.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.EventIn | waiter.EventOut) } @@ -1193,7 +1224,7 @@ func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error { initialRcvWnd := e.initialReceiveWindow() h := newHandshake(e, seqnum.Size(initialRcvWnd)) e.mu.Lock() - h.ep.state = StateSynSent + h.ep.setEndpointState(StateSynSent) e.mu.Unlock() if err := h.execute(); err != nil { @@ -1202,12 +1233,11 @@ func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error { e.lastErrorMu.Unlock() e.mu.Lock() - e.state = StateError + e.setEndpointState(StateError) e.HardError = err // Lock released below. epilogue() - return err } } @@ -1215,7 +1245,6 @@ func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error { e.keepalive.timer.init(&e.keepalive.waker) defer e.keepalive.timer.cleanup() - // Tell waiters that the endpoint is connected and writable. e.mu.Lock() drained := e.drainDone != nil e.mu.Unlock() @@ -1224,8 +1253,6 @@ func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error { <-e.undrain } - e.waiterQueue.Notify(waiter.EventOut) - // Set up the functions that will be called when the main protocol loop // wakes up. funcs := []struct { @@ -1241,17 +1268,14 @@ func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error { f: e.handleClose, }, { - w: &e.newSegmentWaker, - f: e.handleSegments, - }, - { w: &closeWaker, f: func() *tcpip.Error { // This means the socket is being closed due - // to the TCP_FIN_WAIT2 timeout was hit. Just + // to the TCP-FIN-WAIT2 timeout was hit. Just // mark the socket as closed. e.mu.Lock() e.transitionToStateCloseLocked() + e.workerCleanup = true e.mu.Unlock() return nil }, @@ -1267,6 +1291,12 @@ func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error { }, }, { + w: &e.newSegmentWaker, + f: func() *tcpip.Error { + return e.handleSegments(false /* fastPath */) + }, + }, + { w: &e.keepalive.waker, f: e.keepaliveTimerExpired, }, @@ -1293,14 +1323,16 @@ func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error { } if n¬ifyReset != 0 { - e.mu.Lock() - e.resetConnectionLocked(tcpip.ErrConnectionAborted) - e.mu.Unlock() + return tcpip.ErrConnectionAborted + } + + if n¬ifyResetByPeer != 0 { + return tcpip.ErrConnectionReset } if n¬ifyClose != 0 && closeTimer == nil { e.mu.Lock() - if e.state == StateFinWait2 && e.closed { + if e.EndpointState() == StateFinWait2 && e.closed { // The socket has been closed and we are in FIN_WAIT2 // so start the FIN_WAIT2 timer. closeTimer = time.AfterFunc(e.tcpLingerTimeout, func() { @@ -1320,11 +1352,11 @@ func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error { if n¬ifyDrain != 0 { for !e.segmentQueue.empty() { - if err := e.handleSegments(); err != nil { + if err := e.handleSegments(false /* fastPath */); err != nil { return err } } - if e.state != StateClose && e.state != StateError { + if e.EndpointState() != StateClose && e.EndpointState() != StateError { // Only block the worker if the endpoint // is not in closed state or error state. close(e.drainDone) @@ -1349,14 +1381,21 @@ func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error { s.AddWaker(funcs[i].w, i) } + // Notify the caller that the waker initialization is complete and the + // endpoint is ready. + if wakerInitDone != nil { + close(wakerInitDone) + } + + // Tell waiters that the endpoint is connected and writable. + e.waiterQueue.Notify(waiter.EventOut) + // The following assertions and notifications are needed for restored // endpoints. Fresh newly created endpoints have empty states and should // not invoke any. - e.segmentQueue.mu.Lock() - if !e.segmentQueue.list.Empty() { + if !e.segmentQueue.empty() { e.newSegmentWaker.Assert() } - e.segmentQueue.mu.Unlock() e.rcvListMu.Lock() if !e.rcvList.Empty() { @@ -1371,28 +1410,53 @@ func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error { // Main loop. Handle segments until both send and receive ends of the // connection have completed. + cleanupOnError := func(err *tcpip.Error) { + e.mu.Lock() + e.workerCleanup = true + if err != nil { + e.resetConnectionLocked(err) + } + // Lock released below. + epilogue() + } - for e.state != StateTimeWait && e.state != StateClose && e.state != StateError { +loop: + for e.EndpointState() != StateTimeWait && e.EndpointState() != StateClose && e.EndpointState() != StateError { e.mu.Unlock() e.workMu.Unlock() v, _ := s.Fetch(true) e.workMu.Lock() - if err := funcs[v].f(); err != nil { - e.mu.Lock() - // Ensure we release all endpoint registration and route - // references as the connection is now in an error - // state. - e.workerCleanup = true - e.resetConnectionLocked(err) - // Lock released below. - epilogue() + // We need to double check here because the notification maybe + // stale by the time we got around to processing it. + // + // NOTE: since we now hold the workMu the processors cannot + // change the state of the endpoint so it's safe to proceed + // after this check. + switch e.EndpointState() { + case StateError: + // If the endpoint has already transitioned to an ERROR + // state just pass nil here as any reset that may need + // to be sent etc should already have been done and we + // just want to terminate the loop and cleanup the + // endpoint. + cleanupOnError(nil) return nil + case StateTimeWait: + fallthrough + case StateClose: + e.mu.Lock() + break loop + default: + if err := funcs[v].f(); err != nil { + cleanupOnError(err) + return nil + } + e.mu.Lock() } - e.mu.Lock() } - state := e.state + state := e.EndpointState() e.mu.Unlock() var reuseTW func() if state == StateTimeWait { @@ -1405,13 +1469,15 @@ func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error { s.Done() // Wake up any waiters before we enter TIME_WAIT. e.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.EventIn | waiter.EventOut) + e.mu.Lock() + e.workerCleanup = true + e.mu.Unlock() reuseTW = e.doTimeWait() } // Mark endpoint as closed. e.mu.Lock() - if e.state != StateError { - e.stack.Stats().TCP.CurrentEstablished.Decrement() + if e.EndpointState() != StateError { e.transitionToStateCloseLocked() } @@ -1468,7 +1534,11 @@ func (e *endpoint) handleTimeWaitSegments() (extendTimeWait bool, reuseTW func() tcpEP := listenEP.(*endpoint) if EndpointState(tcpEP.State()) == StateListen { reuseTW = func() { - tcpEP.enqueueSegment(s) + if !tcpEP.enqueueSegment(s) { + s.decRef() + return + } + tcpEP.newSegmentWaker.Assert() } // We explicitly do not decRef // the segment as it's still diff --git a/pkg/tcpip/transport/tcp/dispatcher.go b/pkg/tcpip/transport/tcp/dispatcher.go new file mode 100644 index 000000000..e18012ac0 --- /dev/null +++ b/pkg/tcpip/transport/tcp/dispatcher.go @@ -0,0 +1,224 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tcp + +import ( + "gvisor.dev/gvisor/pkg/rand" + "gvisor.dev/gvisor/pkg/sleep" + "gvisor.dev/gvisor/pkg/sync" + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/hash/jenkins" + "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/stack" +) + +// epQueue is a queue of endpoints. +type epQueue struct { + mu sync.Mutex + list endpointList +} + +// enqueue adds e to the queue if the endpoint is not already on the queue. +func (q *epQueue) enqueue(e *endpoint) { + q.mu.Lock() + if e.pendingProcessing { + q.mu.Unlock() + return + } + q.list.PushBack(e) + e.pendingProcessing = true + q.mu.Unlock() +} + +// dequeue removes and returns the first element from the queue if available, +// returns nil otherwise. +func (q *epQueue) dequeue() *endpoint { + q.mu.Lock() + if e := q.list.Front(); e != nil { + q.list.Remove(e) + e.pendingProcessing = false + q.mu.Unlock() + return e + } + q.mu.Unlock() + return nil +} + +// empty returns true if the queue is empty, false otherwise. +func (q *epQueue) empty() bool { + q.mu.Lock() + v := q.list.Empty() + q.mu.Unlock() + return v +} + +// processor is responsible for processing packets queued to a tcp endpoint. +type processor struct { + epQ epQueue + newEndpointWaker sleep.Waker + id int +} + +func newProcessor(id int) *processor { + p := &processor{ + id: id, + } + go p.handleSegments() + return p +} + +func (p *processor) queueEndpoint(ep *endpoint) { + // Queue an endpoint for processing by the processor goroutine. + p.epQ.enqueue(ep) + p.newEndpointWaker.Assert() +} + +func (p *processor) handleSegments() { + const newEndpointWaker = 1 + s := sleep.Sleeper{} + s.AddWaker(&p.newEndpointWaker, newEndpointWaker) + defer s.Done() + for { + s.Fetch(true) + for ep := p.epQ.dequeue(); ep != nil; ep = p.epQ.dequeue() { + if ep.segmentQueue.empty() { + continue + } + + // If socket has transitioned out of connected state + // then just let the worker handle the packet. + // + // NOTE: We read this outside of e.mu lock which means + // that by the time we get to handleSegments the + // endpoint may not be in ESTABLISHED. But this should + // be fine as all normal shutdown states are handled by + // handleSegments and if the endpoint moves to a + // CLOSED/ERROR state then handleSegments is a noop. + if ep.EndpointState() != StateEstablished { + ep.newSegmentWaker.Assert() + continue + } + + if !ep.workMu.TryLock() { + ep.newSegmentWaker.Assert() + continue + } + // If the endpoint is in a connected state then we do + // direct delivery to ensure low latency and avoid + // scheduler interactions. + if err := ep.handleSegments(true /* fastPath */); err != nil || ep.EndpointState() == StateClose { + // Send any active resets if required. + if err != nil { + ep.mu.Lock() + ep.resetConnectionLocked(err) + ep.mu.Unlock() + } + ep.notifyProtocolGoroutine(notifyTickleWorker) + ep.workMu.Unlock() + continue + } + + if !ep.segmentQueue.empty() { + p.epQ.enqueue(ep) + } + + ep.workMu.Unlock() + } + } +} + +// dispatcher manages a pool of TCP endpoint processors which are responsible +// for the processing of inbound segments. This fixed pool of processor +// goroutines do full tcp processing. The processor is selected based on the +// hash of the endpoint id to ensure that delivery for the same endpoint happens +// in-order. +type dispatcher struct { + processors []*processor + seed uint32 +} + +func newDispatcher(nProcessors int) *dispatcher { + processors := []*processor{} + for i := 0; i < nProcessors; i++ { + processors = append(processors, newProcessor(i)) + } + return &dispatcher{ + processors: processors, + seed: generateRandUint32(), + } +} + +func (d *dispatcher) queuePacket(r *stack.Route, stackEP stack.TransportEndpoint, id stack.TransportEndpointID, pkt tcpip.PacketBuffer) { + ep := stackEP.(*endpoint) + s := newSegment(r, id, pkt) + if !s.parse() { + ep.stack.Stats().MalformedRcvdPackets.Increment() + ep.stack.Stats().TCP.InvalidSegmentsReceived.Increment() + ep.stats.ReceiveErrors.MalformedPacketsReceived.Increment() + s.decRef() + return + } + + if !s.csumValid { + ep.stack.Stats().MalformedRcvdPackets.Increment() + ep.stack.Stats().TCP.ChecksumErrors.Increment() + ep.stats.ReceiveErrors.ChecksumErrors.Increment() + s.decRef() + return + } + + ep.stack.Stats().TCP.ValidSegmentsReceived.Increment() + ep.stats.SegmentsReceived.Increment() + if (s.flags & header.TCPFlagRst) != 0 { + ep.stack.Stats().TCP.ResetsReceived.Increment() + } + + if !ep.enqueueSegment(s) { + s.decRef() + return + } + + // For sockets not in established state let the worker goroutine + // handle the packets. + if ep.EndpointState() != StateEstablished { + ep.newSegmentWaker.Assert() + return + } + + d.selectProcessor(id).queueEndpoint(ep) +} + +func generateRandUint32() uint32 { + b := make([]byte, 4) + if _, err := rand.Read(b); err != nil { + panic(err) + } + return uint32(b[0]) | uint32(b[1])<<8 | uint32(b[2])<<16 | uint32(b[3])<<24 +} + +func (d *dispatcher) selectProcessor(id stack.TransportEndpointID) *processor { + payload := []byte{ + byte(id.LocalPort), + byte(id.LocalPort >> 8), + byte(id.RemotePort), + byte(id.RemotePort >> 8)} + + h := jenkins.Sum32(d.seed) + h.Write(payload) + h.Write([]byte(id.LocalAddress)) + h.Write([]byte(id.RemoteAddress)) + + return d.processors[h.Sum32()%uint32(len(d.processors))] +} diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go index 2ac1b6877..4797f11d1 100644 --- a/pkg/tcpip/transport/tcp/endpoint.go +++ b/pkg/tcpip/transport/tcp/endpoint.go @@ -19,12 +19,12 @@ import ( "fmt" "math" "strings" - "sync" "sync/atomic" "time" "gvisor.dev/gvisor/pkg/rand" "gvisor.dev/gvisor/pkg/sleep" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/hash/jenkins" @@ -120,6 +120,7 @@ const ( notifyMTUChanged notifyDrain notifyReset + notifyResetByPeer notifyKeepaliveChanged notifyMSSChanged // notifyTickleWorker is used to tickle the protocol main loop during a @@ -127,6 +128,7 @@ const ( // ensures the loop terminates if the final state of the endpoint is // say TIME_WAIT. notifyTickleWorker + notifyError ) // SACKInfo holds TCP SACK related information for a given endpoint. @@ -283,6 +285,18 @@ func (*EndpointInfo) IsEndpointInfo() {} type endpoint struct { EndpointInfo + // endpointEntry is used to queue endpoints for processing to the + // a given tcp processor goroutine. + // + // Precondition: epQueue.mu must be held to read/write this field.. + endpointEntry `state:"nosave"` + + // pendingProcessing is true if this endpoint is queued for processing + // to a TCP processor. + // + // Precondition: epQueue.mu must be held to read/write this field.. + pendingProcessing bool `state:"nosave"` + // workMu is used to arbitrate which goroutine may perform protocol // work. Only the main protocol goroutine is expected to call Lock() on // it, but other goroutines (e.g., send) may call TryLock() to eagerly @@ -324,6 +338,7 @@ type endpoint struct { // The following fields are protected by the mutex. mu sync.RWMutex `state:"nosave"` + // state must be read/set using the EndpointState()/setEndpointState() methods. state EndpointState `state:".(EndpointState)"` // origEndpointState is only used during a restore phase to save the @@ -359,7 +374,7 @@ type endpoint struct { workerRunning bool // workerCleanup specifies if the worker goroutine must perform cleanup - // before exitting. This can only be set to true when workerRunning is + // before exiting. This can only be set to true when workerRunning is // also true, and they're both protected by the mutex. workerCleanup bool @@ -371,6 +386,8 @@ type endpoint struct { // recentTS is the timestamp that should be sent in the TSEcr field of // the timestamp for future segments sent by the endpoint. This field is // updated if required when a new segment is received by this endpoint. + // + // recentTS must be read/written atomically. recentTS uint32 // tsOffset is a randomized offset added to the value of the @@ -567,6 +584,47 @@ func (e *endpoint) ResumeWork() { e.workMu.Unlock() } +// setEndpointState updates the state of the endpoint to state atomically. This +// method is unexported as the only place we should update the state is in this +// package but we allow the state to be read freely without holding e.mu. +// +// Precondition: e.mu must be held to call this method. +func (e *endpoint) setEndpointState(state EndpointState) { + oldstate := EndpointState(atomic.LoadUint32((*uint32)(&e.state))) + switch state { + case StateEstablished: + e.stack.Stats().TCP.CurrentEstablished.Increment() + case StateError: + fallthrough + case StateClose: + if oldstate == StateCloseWait || oldstate == StateEstablished { + e.stack.Stats().TCP.EstablishedResets.Increment() + } + fallthrough + default: + if oldstate == StateEstablished { + e.stack.Stats().TCP.CurrentEstablished.Decrement() + } + } + atomic.StoreUint32((*uint32)(&e.state), uint32(state)) +} + +// EndpointState returns the current state of the endpoint. +func (e *endpoint) EndpointState() EndpointState { + return EndpointState(atomic.LoadUint32((*uint32)(&e.state))) +} + +// setRecentTimestamp atomically sets the recentTS field to the +// provided value. +func (e *endpoint) setRecentTimestamp(recentTS uint32) { + atomic.StoreUint32(&e.recentTS, recentTS) +} + +// recentTimestamp atomically reads and returns the value of the recentTS field. +func (e *endpoint) recentTimestamp() uint32 { + return atomic.LoadUint32(&e.recentTS) +} + // keepalive is a synchronization wrapper used to appease stateify. See the // comment in endpoint, where it is used. // @@ -656,7 +714,7 @@ func (e *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask { e.mu.RLock() defer e.mu.RUnlock() - switch e.state { + switch e.EndpointState() { case StateInitial, StateBound, StateConnecting, StateSynSent, StateSynRecv: // Ready for nothing. @@ -672,7 +730,7 @@ func (e *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask { } } } - if e.state.connected() { + if e.EndpointState().connected() { // Determine if the endpoint is writable if requested. if (mask & waiter.EventOut) != 0 { e.sndBufMu.Lock() @@ -733,14 +791,20 @@ func (e *endpoint) Close() { // Issue a shutdown so that the peer knows we won't send any more data // if we're connected, or stop accepting if we're listening. e.Shutdown(tcpip.ShutdownWrite | tcpip.ShutdownRead) + e.closeNoShutdown() +} +// closeNoShutdown closes the endpoint without doing a full shutdown. This is +// used when a connection needs to be aborted with a RST and we want to skip +// a full 4 way TCP shutdown. +func (e *endpoint) closeNoShutdown() { e.mu.Lock() // For listening sockets, we always release ports inline so that they // are immediately available for reuse after Close() is called. If also // registered, we unregister as well otherwise the next user would fail // in Listen() when trying to register. - if e.state == StateListen && e.isPortReserved { + if e.EndpointState() == StateListen && e.isPortReserved { if e.isRegistered { e.stack.StartTransportEndpointCleanup(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.boundBindToDevice) e.isRegistered = false @@ -780,6 +844,8 @@ func (e *endpoint) closePendingAcceptableConnectionsLocked() { defer close(done) for n := range e.acceptedChan { n.notifyProtocolGoroutine(notifyReset) + // close all connections that have completed but + // not accepted by the application. n.Close() } }() @@ -797,11 +863,13 @@ func (e *endpoint) closePendingAcceptableConnectionsLocked() { // after Close() is called and the worker goroutine (if any) is done with its // work. func (e *endpoint) cleanupLocked() { + // Close all endpoints that might have been accepted by TCP but not by // the client. if e.acceptedChan != nil { e.closePendingAcceptableConnectionsLocked() } + e.workerCleanup = false if e.isRegistered { @@ -885,8 +953,14 @@ func (e *endpoint) ModerateRecvBuf(copied int) { // reject valid data that might already be in flight as the // acceptable window will shrink. if rcvWnd > e.rcvBufSize { + availBefore := e.receiveBufferAvailableLocked() e.rcvBufSize = rcvWnd - e.notifyProtocolGoroutine(notifyReceiveWindowChanged) + availAfter := e.receiveBufferAvailableLocked() + mask := uint32(notifyReceiveWindowChanged) + if crossed, above := e.windowCrossedACKThreshold(availAfter - availBefore); crossed && above { + mask |= notifyNonZeroReceiveWindow + } + e.notifyProtocolGoroutine(mask) } // We only update prevCopied when we grow the buffer because in cases @@ -914,7 +988,7 @@ func (e *endpoint) Read(*tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, // reads to proceed before returning a ECONNRESET. e.rcvListMu.Lock() bufUsed := e.rcvBufUsed - if s := e.state; !s.connected() && s != StateClose && bufUsed == 0 { + if s := e.EndpointState(); !s.connected() && s != StateClose && bufUsed == 0 { e.rcvListMu.Unlock() he := e.HardError e.mu.RUnlock() @@ -938,7 +1012,7 @@ func (e *endpoint) Read(*tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, func (e *endpoint) readLocked() (buffer.View, *tcpip.Error) { if e.rcvBufUsed == 0 { - if e.rcvClosed || !e.state.connected() { + if e.rcvClosed || !e.EndpointState().connected() { return buffer.View{}, tcpip.ErrClosedForReceive } return buffer.View{}, tcpip.ErrWouldBlock @@ -955,11 +1029,12 @@ func (e *endpoint) readLocked() (buffer.View, *tcpip.Error) { } e.rcvBufUsed -= len(v) - // If the window was zero before this read and if the read freed up - // enough buffer space for the scaled window to be non-zero then notify - // the protocol goroutine to send a window update. - if e.zeroWindow && !e.zeroReceiveWindow(e.rcv.rcvWndScale) { - e.zeroWindow = false + + // If the window was small before this read and if the read freed up + // enough buffer space, to either fit an aMSS or half a receive buffer + // (whichever smaller), then notify the protocol goroutine to send a + // window update. + if crossed, above := e.windowCrossedACKThreshold(len(v)); crossed && above { e.notifyProtocolGoroutine(notifyNonZeroReceiveWindow) } @@ -973,8 +1048,8 @@ func (e *endpoint) readLocked() (buffer.View, *tcpip.Error) { // Caller must hold e.mu and e.sndBufMu func (e *endpoint) isEndpointWritableLocked() (int, *tcpip.Error) { // The endpoint cannot be written to if it's not connected. - if !e.state.connected() { - switch e.state { + if !e.EndpointState().connected() { + switch e.EndpointState() { case StateError: return 0, e.HardError default: @@ -1032,42 +1107,86 @@ func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c return 0, nil, perr } - if !opts.Atomic { // See above. - e.mu.RLock() - e.sndBufMu.Lock() + if opts.Atomic { + // Add data to the send queue. + s := newSegmentFromView(&e.route, e.ID, v) + e.sndBufUsed += len(v) + e.sndBufInQueue += seqnum.Size(len(v)) + e.sndQueue.PushBack(s) + e.sndBufMu.Unlock() + // Release the endpoint lock to prevent deadlocks due to lock + // order inversion when acquiring workMu. + e.mu.RUnlock() + } - // Because we released the lock before copying, check state again - // to make sure the endpoint is still in a valid state for a write. - avail, err = e.isEndpointWritableLocked() - if err != nil { + if e.workMu.TryLock() { + // Since we released locks in between it's possible that the + // endpoint transitioned to a CLOSED/ERROR states so make + // sure endpoint is still writable before trying to write. + if !opts.Atomic { // See above. + e.mu.RLock() + e.sndBufMu.Lock() + + // Because we released the lock before copying, check state again + // to make sure the endpoint is still in a valid state for a write. + avail, err = e.isEndpointWritableLocked() + if err != nil { + e.sndBufMu.Unlock() + e.mu.RUnlock() + e.stats.WriteErrors.WriteClosed.Increment() + return 0, nil, err + } + + // Discard any excess data copied in due to avail being reduced due + // to a simultaneous write call to the socket. + if avail < len(v) { + v = v[:avail] + } + // Add data to the send queue. + s := newSegmentFromView(&e.route, e.ID, v) + e.sndBufUsed += len(v) + e.sndBufInQueue += seqnum.Size(len(v)) + e.sndQueue.PushBack(s) e.sndBufMu.Unlock() + // Release the endpoint lock to prevent deadlocks due to lock + // order inversion when acquiring workMu. e.mu.RUnlock() - e.stats.WriteErrors.WriteClosed.Increment() - return 0, nil, err - } - // Discard any excess data copied in due to avail being reduced due - // to a simultaneous write call to the socket. - if avail < len(v) { - v = v[:avail] } - } - - // Add data to the send queue. - s := newSegmentFromView(&e.route, e.ID, v) - e.sndBufUsed += len(v) - e.sndBufInQueue += seqnum.Size(len(v)) - e.sndQueue.PushBack(s) - e.sndBufMu.Unlock() - // Release the endpoint lock to prevent deadlocks due to lock - // order inversion when acquiring workMu. - e.mu.RUnlock() - - if e.workMu.TryLock() { // Do the work inline. e.handleWrite() e.workMu.Unlock() } else { + if !opts.Atomic { // See above. + e.mu.RLock() + e.sndBufMu.Lock() + + // Because we released the lock before copying, check state again + // to make sure the endpoint is still in a valid state for a write. + avail, err = e.isEndpointWritableLocked() + if err != nil { + e.sndBufMu.Unlock() + e.mu.RUnlock() + e.stats.WriteErrors.WriteClosed.Increment() + return 0, nil, err + } + + // Discard any excess data copied in due to avail being reduced due + // to a simultaneous write call to the socket. + if avail < len(v) { + v = v[:avail] + } + // Add data to the send queue. + s := newSegmentFromView(&e.route, e.ID, v) + e.sndBufUsed += len(v) + e.sndBufInQueue += seqnum.Size(len(v)) + e.sndQueue.PushBack(s) + e.sndBufMu.Unlock() + // Release the endpoint lock to prevent deadlocks due to lock + // order inversion when acquiring workMu. + e.mu.RUnlock() + + } // Let the protocol goroutine do the work. e.sndWaker.Assert() } @@ -1084,7 +1203,7 @@ func (e *endpoint) Peek(vec [][]byte) (int64, tcpip.ControlMessages, *tcpip.Erro // The endpoint can be read if it's connected, or if it's already closed // but has some pending unread data. - if s := e.state; !s.connected() && s != StateClose { + if s := e.EndpointState(); !s.connected() && s != StateClose { if s == StateError { return 0, tcpip.ControlMessages{}, e.HardError } @@ -1096,7 +1215,7 @@ func (e *endpoint) Peek(vec [][]byte) (int64, tcpip.ControlMessages, *tcpip.Erro defer e.rcvListMu.Unlock() if e.rcvBufUsed == 0 { - if e.rcvClosed || !e.state.connected() { + if e.rcvClosed || !e.EndpointState().connected() { e.stats.ReadErrors.ReadClosed.Increment() return 0, tcpip.ControlMessages{}, tcpip.ErrClosedForReceive } @@ -1133,16 +1252,38 @@ func (e *endpoint) Peek(vec [][]byte) (int64, tcpip.ControlMessages, *tcpip.Erro return num, tcpip.ControlMessages{}, nil } -// zeroReceiveWindow checks if the receive window to be announced now would be -// zero, based on the amount of available buffer and the receive window scaling. +// windowCrossedACKThreshold checks if the receive window to be announced now +// would be under aMSS or under half receive buffer, whichever smaller. This is +// useful as a receive side silly window syndrome prevention mechanism. If +// window grows to reasonable value, we should send ACK to the sender to inform +// the rx space is now large. We also want ensure a series of small read()'s +// won't trigger a flood of spurious tiny ACK's. // -// It must be called with rcvListMu held. -func (e *endpoint) zeroReceiveWindow(scale uint8) bool { - if e.rcvBufUsed >= e.rcvBufSize { - return true +// For large receive buffers, the threshold is aMSS - once reader reads more +// than aMSS we'll send ACK. For tiny receive buffers, the threshold is half of +// receive buffer size. This is chosen arbitrairly. +// crossed will be true if the window size crossed the ACK threshold. +// above will be true if the new window is >= ACK threshold and false +// otherwise. +func (e *endpoint) windowCrossedACKThreshold(deltaBefore int) (crossed bool, above bool) { + newAvail := e.receiveBufferAvailableLocked() + oldAvail := newAvail - deltaBefore + if oldAvail < 0 { + oldAvail = 0 } - return ((e.rcvBufSize - e.rcvBufUsed) >> scale) == 0 + threshold := int(e.amss) + if threshold > e.rcvBufSize/2 { + threshold = e.rcvBufSize / 2 + } + + switch { + case oldAvail < threshold && newAvail >= threshold: + return true, true + case oldAvail >= threshold && newAvail < threshold: + return true, false + } + return false, false } // SetSockOptBool sets a socket option. @@ -1158,7 +1299,7 @@ func (e *endpoint) SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error { defer e.mu.Unlock() // We only allow this to be set when we're in the initial state. - if e.state != StateInitial { + if e.EndpointState() != StateInitial { return tcpip.ErrInvalidEndpointState } @@ -1204,10 +1345,16 @@ func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error { size = math.MaxInt32 / 2 } + availBefore := e.receiveBufferAvailableLocked() e.rcvBufSize = size + availAfter := e.receiveBufferAvailableLocked() + e.rcvAutoParams.disabled = true - if e.zeroWindow && !e.zeroReceiveWindow(scale) { - e.zeroWindow = false + + // Immediately send an ACK to uncork the sender silly window + // syndrome prevetion, when our available space grows above aMSS + // or half receive buffer, whichever smaller. + if crossed, above := e.windowCrossedACKThreshold(availAfter - availBefore); crossed && above { mask |= notifyNonZeroReceiveWindow } e.rcvListMu.Unlock() @@ -1279,19 +1426,14 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { return nil case tcpip.BindToDeviceOption: - e.mu.Lock() - defer e.mu.Unlock() - if v == "" { - e.bindToDevice = 0 - return nil - } - for nicID, nic := range e.stack.NICInfo() { - if nic.Name == string(v) { - e.bindToDevice = nicID - return nil - } + id := tcpip.NICID(v) + if id != 0 && !e.stack.HasNIC(id) { + return tcpip.ErrUnknownDevice } - return tcpip.ErrUnknownDevice + e.mu.Lock() + e.bindToDevice = id + e.mu.Unlock() + return nil case tcpip.QuickAckOption: if v == 0 { @@ -1372,14 +1514,14 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { // Acquire the work mutex as we may need to // reinitialize the congestion control state. e.mu.Lock() - state := e.state + state := e.EndpointState() e.cc = v e.mu.Unlock() switch state { case StateEstablished: e.workMu.Lock() e.mu.Lock() - if e.state == state { + if e.EndpointState() == state { e.snd.cc = e.snd.initCongestionControl(e.cc) } e.mu.Unlock() @@ -1442,7 +1584,7 @@ func (e *endpoint) readyReceiveSize() (int, *tcpip.Error) { defer e.mu.RUnlock() // The endpoint cannot be in listen state. - if e.state == StateListen { + if e.EndpointState() == StateListen { return 0, tcpip.ErrInvalidEndpointState } @@ -1550,12 +1692,8 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { case *tcpip.BindToDeviceOption: e.mu.RLock() - defer e.mu.RUnlock() - if nic, ok := e.stack.NICInfo()[e.bindToDevice]; ok { - *o = tcpip.BindToDeviceOption(nic.Name) - return nil - } - *o = "" + *o = tcpip.BindToDeviceOption(e.bindToDevice) + e.mu.RUnlock() return nil case *tcpip.QuickAckOption: @@ -1665,26 +1803,11 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { } func (e *endpoint) checkV4Mapped(addr *tcpip.FullAddress) (tcpip.NetworkProtocolNumber, *tcpip.Error) { - netProto := e.NetProto - if header.IsV4MappedAddress(addr.Addr) { - // Fail if using a v4 mapped address on a v6only endpoint. - if e.v6only { - return 0, tcpip.ErrNoRoute - } - - netProto = header.IPv4ProtocolNumber - addr.Addr = addr.Addr[header.IPv6AddressSize-header.IPv4AddressSize:] - if addr.Addr == header.IPv4Any { - addr.Addr = "" - } - } - - // Fail if we're bound to an address length different from the one we're - // checking. - if l := len(e.ID.LocalAddress); l != 0 && len(addr.Addr) != 0 && l != len(addr.Addr) { - return 0, tcpip.ErrInvalidEndpointState + unwrapped, netProto, err := e.TransportEndpointInfo.AddrNetProto(*addr, e.v6only) + if err != nil { + return 0, err } - + *addr = unwrapped return netProto, nil } @@ -1720,7 +1843,7 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) *tc return err } - if e.state.connected() { + if e.EndpointState().connected() { // The endpoint is already connected. If caller hasn't been // notified yet, return success. if !e.isConnectNotified { @@ -1732,7 +1855,7 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) *tc } nicID := addr.NIC - switch e.state { + switch e.EndpointState() { case StateBound: // If we're already bound to a NIC but the caller is requesting // that we use a different one now, we cannot proceed. @@ -1839,7 +1962,7 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) *tc } e.isRegistered = true - e.state = StateConnecting + e.setEndpointState(StateConnecting) e.route = r.Clone() e.boundNICID = nicID e.effectiveNetProtos = netProtos @@ -1860,14 +1983,13 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) *tc } e.segmentQueue.mu.Unlock() e.snd.updateMaxPayloadSize(int(e.route.MTU()), 0) - e.state = StateEstablished - e.stack.Stats().TCP.CurrentEstablished.Increment() + e.setEndpointState(StateEstablished) } if run { e.workerRunning = true e.stack.Stats().TCP.ActiveConnectionOpenings.Increment() - go e.protocolMainLoop(handshake) // S/R-SAFE: will be drained before save. + go e.protocolMainLoop(handshake, nil) // S/R-SAFE: will be drained before save. } return tcpip.ErrConnectStarted @@ -1885,7 +2007,7 @@ func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error { e.shutdownFlags |= flags finQueued := false switch { - case e.state.connected(): + case e.EndpointState().connected(): // Close for read. if (e.shutdownFlags & tcpip.ShutdownRead) != 0 { // Mark read side as closed. @@ -1897,8 +2019,18 @@ func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error { // If we're fully closed and we have unread data we need to abort // the connection with a RST. if (e.shutdownFlags&tcpip.ShutdownWrite) != 0 && rcvBufUsed > 0 { - e.notifyProtocolGoroutine(notifyReset) e.mu.Unlock() + // Try to send an active reset immediately if the + // work mutex is available. + if e.workMu.TryLock() { + e.mu.Lock() + e.resetConnectionLocked(tcpip.ErrConnectionAborted) + e.notifyProtocolGoroutine(notifyTickleWorker) + e.mu.Unlock() + e.workMu.Unlock() + } else { + e.notifyProtocolGoroutine(notifyReset) + } return nil } } @@ -1920,11 +2052,10 @@ func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error { finQueued = true // Mark endpoint as closed. e.sndClosed = true - e.sndBufMu.Unlock() } - case e.state == StateListen: + case e.EndpointState() == StateListen: // Tell protocolListenLoop to stop. if flags&tcpip.ShutdownRead != 0 { e.notifyProtocolGoroutine(notifyClose) @@ -1965,7 +2096,7 @@ func (e *endpoint) listen(backlog int) *tcpip.Error { // When the endpoint shuts down, it sets workerCleanup to true, and from // that point onward, acceptedChan is the responsibility of the cleanup() // method (and should not be touched anywhere else, including here). - if e.state == StateListen && !e.workerCleanup { + if e.EndpointState() == StateListen && !e.workerCleanup { // Adjust the size of the channel iff we can fix existing // pending connections into the new one. if len(e.acceptedChan) > backlog { @@ -1983,7 +2114,7 @@ func (e *endpoint) listen(backlog int) *tcpip.Error { return nil } - if e.state == StateInitial { + if e.EndpointState() == StateInitial { // The listen is called on an unbound socket, the socket is // automatically bound to a random free port with the local // address set to INADDR_ANY. @@ -1993,7 +2124,7 @@ func (e *endpoint) listen(backlog int) *tcpip.Error { } // Endpoint must be bound before it can transition to listen mode. - if e.state != StateBound { + if e.EndpointState() != StateBound { e.stats.ReadErrors.InvalidEndpointState.Increment() return tcpip.ErrInvalidEndpointState } @@ -2004,24 +2135,27 @@ func (e *endpoint) listen(backlog int) *tcpip.Error { } e.isRegistered = true - e.state = StateListen + e.setEndpointState(StateListen) + if e.acceptedChan == nil { e.acceptedChan = make(chan *endpoint, backlog) } e.workerRunning = true - go e.protocolListenLoop( // S/R-SAFE: drained on save. seqnum.Size(e.receiveBufferAvailable())) - return nil } // startAcceptedLoop sets up required state and starts a goroutine with the // main loop for accepted connections. func (e *endpoint) startAcceptedLoop(waiterQueue *waiter.Queue) { + e.mu.Lock() e.waiterQueue = waiterQueue e.workerRunning = true - go e.protocolMainLoop(false) // S/R-SAFE: drained on save. + e.mu.Unlock() + wakerInitDone := make(chan struct{}) + go e.protocolMainLoop(false, wakerInitDone) // S/R-SAFE: drained on save. + <-wakerInitDone } // Accept returns a new endpoint if a peer has established a connection @@ -2031,7 +2165,7 @@ func (e *endpoint) Accept() (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) { defer e.mu.RUnlock() // Endpoint must be in listen state before it can accept connections. - if e.state != StateListen { + if e.EndpointState() != StateListen { return nil, nil, tcpip.ErrInvalidEndpointState } @@ -2058,7 +2192,7 @@ func (e *endpoint) bindLocked(addr tcpip.FullAddress) (err *tcpip.Error) { // Don't allow binding once endpoint is not in the initial state // anymore. This is because once the endpoint goes into a connected or // listen state, it is already bound. - if e.state != StateInitial { + if e.EndpointState() != StateInitial { return tcpip.ErrAlreadyBound } @@ -2120,7 +2254,7 @@ func (e *endpoint) bindLocked(addr tcpip.FullAddress) (err *tcpip.Error) { } // Mark endpoint as bound. - e.state = StateBound + e.setEndpointState(StateBound) return nil } @@ -2142,7 +2276,7 @@ func (e *endpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Error) { e.mu.RLock() defer e.mu.RUnlock() - if !e.state.connected() { + if !e.EndpointState().connected() { return tcpip.FullAddress{}, tcpip.ErrNotConnected } @@ -2153,45 +2287,22 @@ func (e *endpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Error) { }, nil } -// HandlePacket is called by the stack when new packets arrive to this transport -// endpoint. func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pkt tcpip.PacketBuffer) { - s := newSegment(r, id, pkt) - if !s.parse() { - e.stack.Stats().MalformedRcvdPackets.Increment() - e.stack.Stats().TCP.InvalidSegmentsReceived.Increment() - e.stats.ReceiveErrors.MalformedPacketsReceived.Increment() - s.decRef() - return - } - - if !s.csumValid { - e.stack.Stats().MalformedRcvdPackets.Increment() - e.stack.Stats().TCP.ChecksumErrors.Increment() - e.stats.ReceiveErrors.ChecksumErrors.Increment() - s.decRef() - return - } - - e.stack.Stats().TCP.ValidSegmentsReceived.Increment() - e.stats.SegmentsReceived.Increment() - if (s.flags & header.TCPFlagRst) != 0 { - e.stack.Stats().TCP.ResetsReceived.Increment() - } - - e.enqueueSegment(s) + // TCP HandlePacket is not required anymore as inbound packets first + // land at the Dispatcher which then can either delivery using the + // worker go routine or directly do the invoke the tcp processing inline + // based on the state of the endpoint. } -func (e *endpoint) enqueueSegment(s *segment) { +func (e *endpoint) enqueueSegment(s *segment) bool { // Send packet to worker goroutine. - if e.segmentQueue.enqueue(s) { - e.newSegmentWaker.Assert() - } else { + if !e.segmentQueue.enqueue(s) { // The queue is full, so we drop the segment. e.stack.Stats().DroppedPackets.Increment() e.stats.ReceiveErrors.SegmentQueueDropped.Increment() - s.decRef() + return false } + return true } // HandleControlPacket implements stack.TransportEndpoint.HandleControlPacket. @@ -2234,13 +2345,10 @@ func (e *endpoint) readyToRead(s *segment) { if s != nil { s.incRef() e.rcvBufUsed += s.data.Size() - // Check if the receive window is now closed. If so make sure - // we set the zero window before we deliver the segment to ensure - // that a subsequent read of the segment will correctly trigger - // a non-zero notification. - if avail := e.receiveBufferAvailableLocked(); avail>>e.rcv.rcvWndScale == 0 { + // Increase counter if the receive window falls down below MSS + // or half receive buffer size, whichever smaller. + if crossed, above := e.windowCrossedACKThreshold(-s.data.Size()); crossed && !above { e.stats.ReceiveErrors.ZeroRcvWindowState.Increment() - e.zeroWindow = true } e.rcvList.PushBack(s) } else { @@ -2311,8 +2419,8 @@ func (e *endpoint) rcvWndScaleForHandshake() int { // updateRecentTimestamp updates the recent timestamp using the algorithm // described in https://tools.ietf.org/html/rfc7323#section-4.3 func (e *endpoint) updateRecentTimestamp(tsVal uint32, maxSentAck seqnum.Value, segSeq seqnum.Value) { - if e.sendTSOk && seqnum.Value(e.recentTS).LessThan(seqnum.Value(tsVal)) && segSeq.LessThanEq(maxSentAck) { - e.recentTS = tsVal + if e.sendTSOk && seqnum.Value(e.recentTimestamp()).LessThan(seqnum.Value(tsVal)) && segSeq.LessThanEq(maxSentAck) { + e.setRecentTimestamp(tsVal) } } @@ -2322,7 +2430,7 @@ func (e *endpoint) updateRecentTimestamp(tsVal uint32, maxSentAck seqnum.Value, func (e *endpoint) maybeEnableTimestamp(synOpts *header.TCPSynOptions) { if synOpts.TS { e.sendTSOk = true - e.recentTS = synOpts.TSVal + e.setRecentTimestamp(synOpts.TSVal) } } @@ -2411,7 +2519,7 @@ func (e *endpoint) completeState() stack.TCPEndpointState { // Endpoint TCP Option state. s.SendTSOk = e.sendTSOk - s.RecentTS = e.recentTS + s.RecentTS = e.recentTimestamp() s.TSOffset = e.tsOffset s.SACKPermitted = e.sackPermitted s.SACK.Blocks = make([]header.SACKBlock, e.sack.NumBlocks) @@ -2518,9 +2626,7 @@ func (e *endpoint) initGSO() { // State implements tcpip.Endpoint.State. It exports the endpoint's protocol // state for diagnostics. func (e *endpoint) State() uint32 { - e.mu.Lock() - defer e.mu.Unlock() - return uint32(e.state) + return uint32(e.EndpointState()) } // Info returns a copy of the endpoint info. diff --git a/pkg/tcpip/transport/tcp/endpoint_state.go b/pkg/tcpip/transport/tcp/endpoint_state.go index 7aa4c3f0e..4a46f0ec5 100644 --- a/pkg/tcpip/transport/tcp/endpoint_state.go +++ b/pkg/tcpip/transport/tcp/endpoint_state.go @@ -16,9 +16,10 @@ package tcp import ( "fmt" - "sync" + "sync/atomic" "time" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/stack" @@ -48,7 +49,7 @@ func (e *endpoint) beforeSave() { e.mu.Lock() defer e.mu.Unlock() - switch e.state { + switch e.EndpointState() { case StateInitial, StateBound: // TODO(b/138137272): this enumeration duplicates // EndpointState.connected. remove it. @@ -70,31 +71,30 @@ func (e *endpoint) beforeSave() { fallthrough case StateListen, StateConnecting: e.drainSegmentLocked() - if e.state != StateClose && e.state != StateError { + if e.EndpointState() != StateClose && e.EndpointState() != StateError { if !e.workerRunning { panic("endpoint has no worker running in listen, connecting, or connected state") } break } - fallthrough case StateError, StateClose: - for (e.state == StateError || e.state == StateClose) && e.workerRunning { + for e.workerRunning { e.mu.Unlock() time.Sleep(100 * time.Millisecond) e.mu.Lock() } if e.workerRunning { - panic("endpoint still has worker running in closed or error state") + panic(fmt.Sprintf("endpoint: %+v still has worker running in closed or error state", e.ID)) } default: - panic(fmt.Sprintf("endpoint in unknown state %v", e.state)) + panic(fmt.Sprintf("endpoint in unknown state %v", e.EndpointState())) } if e.waiterQueue != nil && !e.waiterQueue.IsEmpty() { panic("endpoint still has waiters upon save") } - if e.state != StateClose && !((e.state == StateBound || e.state == StateListen) == e.isPortReserved) { + if e.EndpointState() != StateClose && !((e.EndpointState() == StateBound || e.EndpointState() == StateListen) == e.isPortReserved) { panic("endpoints which are not in the closed state must have a reserved port IFF they are in bound or listen state") } } @@ -135,7 +135,7 @@ func (e *endpoint) loadAcceptedChan(acceptedEndpoints []*endpoint) { // saveState is invoked by stateify. func (e *endpoint) saveState() EndpointState { - return e.state + return e.EndpointState() } // Endpoint loading must be done in the following ordering by their state, to @@ -151,7 +151,8 @@ var connectingLoading sync.WaitGroup func (e *endpoint) loadState(state EndpointState) { // This is to ensure that the loading wait groups include all applicable // endpoints before any asynchronous calls to the Wait() methods. - if state.connected() { + // For restore purposes we treat TimeWait like a connected endpoint. + if state.connected() || state == StateTimeWait { connectedLoading.Add(1) } switch state { @@ -160,13 +161,14 @@ func (e *endpoint) loadState(state EndpointState) { case StateConnecting, StateSynSent, StateSynRecv: connectingLoading.Add(1) } - e.state = state + // Directly update the state here rather than using e.setEndpointState + // as the endpoint is still being loaded and the stack reference to increment + // metrics is not yet initialized. + atomic.StoreUint32((*uint32)(&e.state), uint32(state)) } // afterLoad is invoked by stateify. func (e *endpoint) afterLoad() { - // Freeze segment queue before registering to prevent any segments - // from being delivered while it is being restored. e.origEndpointState = e.state // Restore the endpoint to InitialState as it will be moved to // its origEndpointState during Resume. @@ -180,7 +182,6 @@ func (e *endpoint) Resume(s *stack.Stack) { e.segmentQueue.setLimit(MaxUnprocessedSegments) e.workMu.Init() state := e.origEndpointState - switch state { case StateInitial, StateBound, StateListen, StateConnecting, StateEstablished: var ss SendBufferSizeOption @@ -276,7 +277,7 @@ func (e *endpoint) Resume(s *stack.Stack) { listenLoading.Wait() connectingLoading.Wait() bind() - e.state = StateClose + e.setEndpointState(StateClose) tcpip.AsyncLoading.Done() }() } @@ -288,6 +289,7 @@ func (e *endpoint) Resume(s *stack.Stack) { e.stack.CompleteTransportEndpointCleanup(e) tcpip.DeleteDanglingEndpoint(e) } + } // saveLastError is invoked by stateify. diff --git a/pkg/tcpip/transport/tcp/forwarder.go b/pkg/tcpip/transport/tcp/forwarder.go index 4983bca81..7eb613be5 100644 --- a/pkg/tcpip/transport/tcp/forwarder.go +++ b/pkg/tcpip/transport/tcp/forwarder.go @@ -15,8 +15,7 @@ package tcp import ( - "sync" - + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/seqnum" diff --git a/pkg/tcpip/transport/tcp/protocol.go b/pkg/tcpip/transport/tcp/protocol.go index bc718064c..958c06fa7 100644 --- a/pkg/tcpip/transport/tcp/protocol.go +++ b/pkg/tcpip/transport/tcp/protocol.go @@ -21,10 +21,11 @@ package tcp import ( + "runtime" "strings" - "sync" "time" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" @@ -104,6 +105,7 @@ type protocol struct { moderateReceiveBuffer bool tcpLingerTimeout time.Duration tcpTimeWaitTimeout time.Duration + dispatcher *dispatcher } // Number returns the tcp protocol number. @@ -134,6 +136,14 @@ func (*protocol) ParsePorts(v buffer.View) (src, dst uint16, err *tcpip.Error) { return h.SourcePort(), h.DestinationPort(), nil } +// QueuePacket queues packets targeted at an endpoint after hashing the packet +// to a specific processing queue. Each queue is serviced by its own processor +// goroutine which is responsible for dequeuing and doing full TCP dispatch of +// the packet. +func (p *protocol) QueuePacket(r *stack.Route, ep stack.TransportEndpoint, id stack.TransportEndpointID, pkt tcpip.PacketBuffer) { + p.dispatcher.queuePacket(r, ep, id, pkt) +} + // HandleUnknownDestinationPacket handles packets targeted at this protocol but // that don't match any existing endpoint. // @@ -330,5 +340,6 @@ func NewProtocol() stack.TransportProtocol { availableCongestionControl: []string{ccReno, ccCubic}, tcpLingerTimeout: DefaultTCPLingerTimeout, tcpTimeWaitTimeout: DefaultTCPTimeWaitTimeout, + dispatcher: newDispatcher(runtime.GOMAXPROCS(0)), } } diff --git a/pkg/tcpip/transport/tcp/rcv.go b/pkg/tcpip/transport/tcp/rcv.go index 0a5534959..958f03ac1 100644 --- a/pkg/tcpip/transport/tcp/rcv.go +++ b/pkg/tcpip/transport/tcp/rcv.go @@ -98,12 +98,6 @@ func (r *receiver) getSendParams() (rcvNxt seqnum.Value, rcvWnd seqnum.Size) { // in such cases we may need to send an ack to indicate to our peer that it can // resume sending data. func (r *receiver) nonZeroWindow() { - if (r.rcvAcc-r.rcvNxt)>>r.rcvWndScale != 0 { - // We never got around to announcing a zero window size, so we - // don't need to immediately announce a nonzero one. - return - } - // Immediately send an ack. r.ep.snd.sendAck() } @@ -175,19 +169,19 @@ func (r *receiver) consumeSegment(s *segment, segSeq seqnum.Value, segLen seqnum // We just received a FIN, our next state depends on whether we sent a // FIN already or not. r.ep.mu.Lock() - switch r.ep.state { + switch r.ep.EndpointState() { case StateEstablished: - r.ep.state = StateCloseWait + r.ep.setEndpointState(StateCloseWait) case StateFinWait1: if s.flagIsSet(header.TCPFlagAck) { // FIN-ACK, transition to TIME-WAIT. - r.ep.state = StateTimeWait + r.ep.setEndpointState(StateTimeWait) } else { // Simultaneous close, expecting a final ACK. - r.ep.state = StateClosing + r.ep.setEndpointState(StateClosing) } case StateFinWait2: - r.ep.state = StateTimeWait + r.ep.setEndpointState(StateTimeWait) } r.ep.mu.Unlock() @@ -211,16 +205,16 @@ func (r *receiver) consumeSegment(s *segment, segSeq seqnum.Value, segLen seqnum // shutdown states. if s.flagIsSet(header.TCPFlagAck) && s.ackNumber == r.ep.snd.sndNxt { r.ep.mu.Lock() - switch r.ep.state { + switch r.ep.EndpointState() { case StateFinWait1: - r.ep.state = StateFinWait2 + r.ep.setEndpointState(StateFinWait2) // Notify protocol goroutine that we have received an // ACK to our FIN so that it can start the FIN_WAIT2 // timer to abort connection if the other side does // not close within 2MSL. r.ep.notifyProtocolGoroutine(notifyClose) case StateClosing: - r.ep.state = StateTimeWait + r.ep.setEndpointState(StateTimeWait) case StateLastAck: r.ep.transitionToStateCloseLocked() } @@ -273,7 +267,6 @@ func (r *receiver) handleRcvdSegmentClosing(s *segment, state EndpointState, clo switch state { case StateCloseWait, StateClosing, StateLastAck: if !s.sequenceNumber.LessThanEq(r.rcvNxt) { - s.decRef() // Just drop the segment as we have // already received a FIN and this // segment is after the sequence number @@ -290,7 +283,6 @@ func (r *receiver) handleRcvdSegmentClosing(s *segment, state EndpointState, clo // trigger a RST. endDataSeq := s.sequenceNumber.Add(seqnum.Size(s.data.Size())) if rcvClosed && r.rcvNxt.LessThan(endDataSeq) { - s.decRef() return true, tcpip.ErrConnectionAborted } if state == StateFinWait1 { @@ -320,7 +312,6 @@ func (r *receiver) handleRcvdSegmentClosing(s *segment, state EndpointState, clo // the last actual data octet in a segment in // which it occurs. if closed && (!s.flagIsSet(header.TCPFlagFin) || s.sequenceNumber.Add(s.logicalLen()) != r.rcvNxt+1) { - s.decRef() return true, tcpip.ErrConnectionAborted } } @@ -342,7 +333,7 @@ func (r *receiver) handleRcvdSegmentClosing(s *segment, state EndpointState, clo // r as they arrive. It is called by the protocol main loop. func (r *receiver) handleRcvdSegment(s *segment) (drop bool, err *tcpip.Error) { r.ep.mu.RLock() - state := r.ep.state + state := r.ep.EndpointState() closed := r.ep.closed r.ep.mu.RUnlock() diff --git a/pkg/tcpip/transport/tcp/segment_queue.go b/pkg/tcpip/transport/tcp/segment_queue.go index e0759225e..bd20a7ee9 100644 --- a/pkg/tcpip/transport/tcp/segment_queue.go +++ b/pkg/tcpip/transport/tcp/segment_queue.go @@ -15,7 +15,7 @@ package tcp import ( - "sync" + "gvisor.dev/gvisor/pkg/sync" ) // segmentQueue is a bounded, thread-safe queue of TCP segments. diff --git a/pkg/tcpip/transport/tcp/snd.go b/pkg/tcpip/transport/tcp/snd.go index 8a947dc66..b74b61e7d 100644 --- a/pkg/tcpip/transport/tcp/snd.go +++ b/pkg/tcpip/transport/tcp/snd.go @@ -16,11 +16,11 @@ package tcp import ( "math" - "sync" "sync/atomic" "time" "gvisor.dev/gvisor/pkg/sleep" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" @@ -442,6 +442,13 @@ func (s *sender) retransmitTimerExpired() bool { return true } + // TODO(b/147297758): Band-aid fix, retransmitTimer can fire in some edge cases + // when writeList is empty. Remove this once we have a proper fix for this + // issue. + if s.writeList.Front() == nil { + return true + } + s.ep.stack.Stats().TCP.Timeouts.Increment() s.ep.stats.SendErrors.Timeouts.Increment() @@ -698,17 +705,15 @@ func (s *sender) maybeSendSegment(seg *segment, limit int, end seqnum.Value) (se } seg.flags = header.TCPFlagAck | header.TCPFlagFin segEnd = seg.sequenceNumber.Add(1) - // Transition to FIN-WAIT1 state since we're initiating an active close. - s.ep.mu.Lock() - switch s.ep.state { + // Update the state to reflect that we have now + // queued a FIN. + switch s.ep.EndpointState() { case StateCloseWait: - // We've already received a FIN and are now sending our own. The - // sender is now awaiting a final ACK for this FIN. - s.ep.state = StateLastAck + s.ep.setEndpointState(StateLastAck) default: - s.ep.state = StateFinWait1 + s.ep.setEndpointState(StateFinWait1) } - s.ep.mu.Unlock() + } else { // We're sending a non-FIN segment. if seg.flags&header.TCPFlagFin != 0 { diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go index 15745ebd4..a9dfbe857 100644 --- a/pkg/tcpip/transport/tcp/tcp_test.go +++ b/pkg/tcpip/transport/tcp/tcp_test.go @@ -293,7 +293,6 @@ func TestTCPResetSentForACKWhenNotUsingSynCookies(t *testing.T) { checker.SeqNum(uint32(c.IRS+1)), checker.AckNum(uint32(iss)+1), checker.TCPFlags(header.TCPFlagFin|header.TCPFlagAck))) - finHeaders := &context.Headers{ SrcPort: context.TestPort, DstPort: context.StackPort, @@ -459,6 +458,9 @@ func TestConnectResetAfterClose(t *testing.T) { checker.IPv4(t, b, checker.TCP( checker.DstPort(context.TestPort), + // RST is always generated with sndNxt which if the FIN + // has been sent will be 1 higher than the sequence number + // of the FIN itself. checker.SeqNum(uint32(c.IRS)+2), checker.AckNum(0), checker.TCPFlags(header.TCPFlagRst), @@ -1083,12 +1085,12 @@ func TestTrafficClassV6(t *testing.T) { func TestConnectBindToDevice(t *testing.T) { for _, test := range []struct { name string - device string + device tcpip.NICID want tcp.EndpointState }{ - {"RightDevice", "nic1", tcp.StateEstablished}, - {"WrongDevice", "nic2", tcp.StateSynSent}, - {"AnyDevice", "", tcp.StateEstablished}, + {"RightDevice", 1, tcp.StateEstablished}, + {"WrongDevice", 2, tcp.StateSynSent}, + {"AnyDevice", 0, tcp.StateEstablished}, } { t.Run(test.name, func(t *testing.T) { c := context.New(t, defaultMTU) @@ -1500,6 +1502,9 @@ func TestRstOnCloseWithUnreadDataFinConvertRst(t *testing.T) { checker.TCP( checker.DstPort(context.TestPort), checker.TCPFlags(header.TCPFlagAck|header.TCPFlagRst), + // RST is always generated with sndNxt which if the FIN + // has been sent will be 1 higher than the sequence + // number of the FIN itself. checker.SeqNum(uint32(c.IRS)+2), )) // The RST puts the endpoint into an error state. @@ -2091,10 +2096,14 @@ func TestZeroScaledWindowReceive(t *testing.T) { ) } - // Read some data. An ack should be sent in response to that. - v, _, err := c.EP.Read(nil) - if err != nil { - t.Fatalf("Read failed: %v", err) + // Read at least 1MSS of data. An ack should be sent in response to that. + sz := 0 + for sz < defaultMTU { + v, _, err := c.EP.Read(nil) + if err != nil { + t.Fatalf("Read failed: %v", err) + } + sz += len(v) } checker.IPv4(t, c.GetPacket(), @@ -2103,7 +2112,7 @@ func TestZeroScaledWindowReceive(t *testing.T) { checker.DstPort(context.TestPort), checker.SeqNum(uint32(c.IRS)+1), checker.AckNum(uint32(790+sent)), - checker.Window(uint16(len(v)>>ws)), + checker.Window(uint16(sz>>ws)), checker.TCPFlags(header.TCPFlagAck), ), ) @@ -3794,47 +3803,41 @@ func TestBindToDeviceOption(t *testing.T) { } defer ep.Close() - opts := stack.NICOptions{Name: "my_device"} - if err := s.CreateNICWithOptions(321, loopback.New(), opts); err != nil { - t.Errorf("CreateNICWithOptions(_, _, %+v) failed: %v", opts, err) - } - - // Make an nameless NIC. - if err := s.CreateNIC(54321, loopback.New()); err != nil { + if err := s.CreateNIC(321, loopback.New()); err != nil { t.Errorf("CreateNIC failed: %v", err) } - // strPtr is used instead of taking the address of string literals, which is + // nicIDPtr is used instead of taking the address of NICID literals, which is // a compiler error. - strPtr := func(s string) *string { + nicIDPtr := func(s tcpip.NICID) *tcpip.NICID { return &s } testActions := []struct { name string - setBindToDevice *string + setBindToDevice *tcpip.NICID setBindToDeviceError *tcpip.Error getBindToDevice tcpip.BindToDeviceOption }{ - {"GetDefaultValue", nil, nil, ""}, - {"BindToNonExistent", strPtr("non_existent_device"), tcpip.ErrUnknownDevice, ""}, - {"BindToExistent", strPtr("my_device"), nil, "my_device"}, - {"UnbindToDevice", strPtr(""), nil, ""}, + {"GetDefaultValue", nil, nil, 0}, + {"BindToNonExistent", nicIDPtr(999), tcpip.ErrUnknownDevice, 0}, + {"BindToExistent", nicIDPtr(321), nil, 321}, + {"UnbindToDevice", nicIDPtr(0), nil, 0}, } for _, testAction := range testActions { t.Run(testAction.name, func(t *testing.T) { if testAction.setBindToDevice != nil { bindToDevice := tcpip.BindToDeviceOption(*testAction.setBindToDevice) - if got, want := ep.SetSockOpt(bindToDevice), testAction.setBindToDeviceError; got != want { - t.Errorf("SetSockOpt(%v) got %v, want %v", bindToDevice, got, want) + if gotErr, wantErr := ep.SetSockOpt(bindToDevice), testAction.setBindToDeviceError; gotErr != wantErr { + t.Errorf("SetSockOpt(%v) got %v, want %v", bindToDevice, gotErr, wantErr) } } - bindToDevice := tcpip.BindToDeviceOption("to be modified by GetSockOpt") - if ep.GetSockOpt(&bindToDevice) != nil { - t.Errorf("GetSockOpt got %v, want %v", ep.GetSockOpt(&bindToDevice), nil) + bindToDevice := tcpip.BindToDeviceOption(88888) + if err := ep.GetSockOpt(&bindToDevice); err != nil { + t.Errorf("GetSockOpt got %v, want %v", err, nil) } if got, want := bindToDevice, testAction.getBindToDevice; got != want { - t.Errorf("bindToDevice got %q, want %q", got, want) + t.Errorf("bindToDevice got %d, want %d", got, want) } }) } @@ -5443,6 +5446,7 @@ func TestReceiveBufferAutoTuningApplicationLimited(t *testing.T) { rawEP.SendPacketWithTS(b[start:start+mss], tsVal) packetsSent++ } + // Resume the worker so that it only sees the packets once all of them // are waiting to be read. worker.ResumeWork() @@ -5510,7 +5514,7 @@ func TestReceiveBufferAutoTuning(t *testing.T) { stk := c.Stack() // Set lower limits for auto-tuning tests. This is required because the // test stops the worker which can cause packets to be dropped because - // the segment queue holding unprocessed packets is limited to 500. + // the segment queue holding unprocessed packets is limited to 300. const receiveBufferSize = 80 << 10 // 80KB. const maxReceiveBufferSize = receiveBufferSize * 10 if err := stk.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.ReceiveBufferSizeOption{1, receiveBufferSize, maxReceiveBufferSize}); err != nil { @@ -5565,6 +5569,7 @@ func TestReceiveBufferAutoTuning(t *testing.T) { totalSent += mss packetsSent++ } + // Resume it so that it only sees the packets once all of them // are waiting to be read. worker.ResumeWork() @@ -6562,3 +6567,140 @@ func TestKeepaliveWithUserTimeout(t *testing.T) { t.Errorf("got c.Stack().Stats().TCP.EstablishedTimedout = %v, want = %v", got, want) } } + +func TestIncreaseWindowOnReceive(t *testing.T) { + // This test ensures that the endpoint sends an ack, + // after recv() when the window grows to more than 1 MSS. + c := context.New(t, defaultMTU) + defer c.Cleanup() + + const rcvBuf = 65535 * 10 + c.CreateConnected(789, 30000, rcvBuf) + + // Write chunks of ~30000 bytes. It's important that two + // payloads make it equal or longer than MSS. + remain := rcvBuf + sent := 0 + data := make([]byte, defaultMTU/2) + lastWnd := uint16(0) + + for remain > len(data) { + c.SendPacket(data, &context.Headers{ + SrcPort: context.TestPort, + DstPort: c.Port, + Flags: header.TCPFlagAck, + SeqNum: seqnum.Value(790 + sent), + AckNum: c.IRS.Add(1), + RcvWnd: 30000, + }) + sent += len(data) + remain -= len(data) + + lastWnd = uint16(remain) + if remain > 0xffff { + lastWnd = 0xffff + } + checker.IPv4(t, c.GetPacket(), + checker.PayloadLen(header.TCPMinimumSize), + checker.TCP( + checker.DstPort(context.TestPort), + checker.SeqNum(uint32(c.IRS)+1), + checker.AckNum(uint32(790+sent)), + checker.Window(lastWnd), + checker.TCPFlags(header.TCPFlagAck), + ), + ) + } + + if lastWnd == 0xffff || lastWnd == 0 { + t.Fatalf("expected small, non-zero window: %d", lastWnd) + } + + // We now have < 1 MSS in the buffer space. Read the data! An + // ack should be sent in response to that. The window was not + // zero, but it grew to larger than MSS. + if _, _, err := c.EP.Read(nil); err != nil { + t.Fatalf("Read failed: %v", err) + } + + if _, _, err := c.EP.Read(nil); err != nil { + t.Fatalf("Read failed: %v", err) + } + + // After reading two packets, we surely crossed MSS. See the ack: + checker.IPv4(t, c.GetPacket(), + checker.PayloadLen(header.TCPMinimumSize), + checker.TCP( + checker.DstPort(context.TestPort), + checker.SeqNum(uint32(c.IRS)+1), + checker.AckNum(uint32(790+sent)), + checker.Window(uint16(0xffff)), + checker.TCPFlags(header.TCPFlagAck), + ), + ) +} + +func TestIncreaseWindowOnBufferResize(t *testing.T) { + // This test ensures that the endpoint sends an ack, + // after available recv buffer grows to more than 1 MSS. + c := context.New(t, defaultMTU) + defer c.Cleanup() + + const rcvBuf = 65535 * 10 + c.CreateConnected(789, 30000, rcvBuf) + + // Write chunks of ~30000 bytes. It's important that two + // payloads make it equal or longer than MSS. + remain := rcvBuf + sent := 0 + data := make([]byte, defaultMTU/2) + lastWnd := uint16(0) + + for remain > len(data) { + c.SendPacket(data, &context.Headers{ + SrcPort: context.TestPort, + DstPort: c.Port, + Flags: header.TCPFlagAck, + SeqNum: seqnum.Value(790 + sent), + AckNum: c.IRS.Add(1), + RcvWnd: 30000, + }) + sent += len(data) + remain -= len(data) + + lastWnd = uint16(remain) + if remain > 0xffff { + lastWnd = 0xffff + } + checker.IPv4(t, c.GetPacket(), + checker.PayloadLen(header.TCPMinimumSize), + checker.TCP( + checker.DstPort(context.TestPort), + checker.SeqNum(uint32(c.IRS)+1), + checker.AckNum(uint32(790+sent)), + checker.Window(lastWnd), + checker.TCPFlags(header.TCPFlagAck), + ), + ) + } + + if lastWnd == 0xffff || lastWnd == 0 { + t.Fatalf("expected small, non-zero window: %d", lastWnd) + } + + // Increasing the buffer from should generate an ACK, + // since window grew from small value to larger equal MSS + c.EP.SetSockOptInt(tcpip.ReceiveBufferSizeOption, rcvBuf*2) + + // After reading two packets, we surely crossed MSS. See the ack: + checker.IPv4(t, c.GetPacket(), + checker.PayloadLen(header.TCPMinimumSize), + checker.TCP( + checker.DstPort(context.TestPort), + checker.SeqNum(uint32(c.IRS)+1), + checker.AckNum(uint32(790+sent)), + checker.Window(uint16(0xffff)), + checker.TCPFlags(header.TCPFlagAck), + ), + ) +} diff --git a/pkg/tcpip/transport/udp/BUILD b/pkg/tcpip/transport/udp/BUILD index 97e4d5825..57ff123e3 100644 --- a/pkg/tcpip/transport/udp/BUILD +++ b/pkg/tcpip/transport/udp/BUILD @@ -30,6 +30,7 @@ go_library( visibility = ["//visibility:public"], deps = [ "//pkg/sleep", + "//pkg/sync", "//pkg/tcpip", "//pkg/tcpip/buffer", "//pkg/tcpip/header", diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go index 1a5ee6317..c9cbed8f4 100644 --- a/pkg/tcpip/transport/udp/endpoint.go +++ b/pkg/tcpip/transport/udp/endpoint.go @@ -15,8 +15,7 @@ package udp import ( - "sync" - + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" @@ -32,6 +31,7 @@ type udpPacket struct { senderAddress tcpip.FullAddress data buffer.VectorisedView `state:".(buffer.VectorisedView)"` timestamp int64 + tos uint8 } // EndpointState represents the state of a UDP endpoint. @@ -114,6 +114,10 @@ type endpoint struct { // applied while sending packets. Defaults to 0 as on Linux. sendTOS uint8 + // receiveTOS determines if the incoming IPv4 TOS header field is passed + // as ancillary data to ControlMessages on Read. + receiveTOS bool + // shutdownFlags represent the current shutdown state of the endpoint. shutdownFlags tcpip.ShutdownFlags @@ -244,7 +248,18 @@ func (e *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMess *addr = p.senderAddress } - return p.data.ToView(), tcpip.ControlMessages{HasTimestamp: true, Timestamp: p.timestamp}, nil + cm := tcpip.ControlMessages{ + HasTimestamp: true, + Timestamp: p.timestamp, + } + e.mu.RLock() + receiveTOS := e.receiveTOS + e.mu.RUnlock() + if receiveTOS { + cm.HasTOS = true + cm.TOS = p.tos + } + return p.data.ToView(), cm, nil } // prepareForWrite prepares the endpoint for sending data. In particular, it @@ -403,7 +418,7 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c return 0, nil, tcpip.ErrBroadcastDisabled } - netProto, err := e.checkV4Mapped(to, false) + netProto, err := e.checkV4Mapped(to) if err != nil { return 0, nil, err } @@ -459,6 +474,12 @@ func (e *endpoint) Peek([][]byte) (int64, tcpip.ControlMessages, *tcpip.Error) { // SetSockOptBool implements tcpip.Endpoint.SetSockOptBool. func (e *endpoint) SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error { switch opt { + case tcpip.ReceiveTOSOption: + e.mu.Lock() + e.receiveTOS = v + e.mu.Unlock() + return nil + case tcpip.V6OnlyOption: // We only recognize this option on v6 endpoints. if e.NetProto != header.IPv6ProtocolNumber { @@ -502,7 +523,7 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { defer e.mu.Unlock() fa := tcpip.FullAddress{Addr: v.InterfaceAddr} - netProto, err := e.checkV4Mapped(&fa, false) + netProto, err := e.checkV4Mapped(&fa) if err != nil { return err } @@ -631,19 +652,14 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { e.mu.Unlock() case tcpip.BindToDeviceOption: - e.mu.Lock() - defer e.mu.Unlock() - if v == "" { - e.bindToDevice = 0 - return nil - } - for nicID, nic := range e.stack.NICInfo() { - if nic.Name == string(v) { - e.bindToDevice = nicID - return nil - } + id := tcpip.NICID(v) + if id != 0 && !e.stack.HasNIC(id) { + return tcpip.ErrUnknownDevice } - return tcpip.ErrUnknownDevice + e.mu.Lock() + e.bindToDevice = id + e.mu.Unlock() + return nil case tcpip.BroadcastOption: e.mu.Lock() @@ -670,15 +686,21 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { // GetSockOptBool implements tcpip.Endpoint.GetSockOptBool. func (e *endpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) { switch opt { + case tcpip.ReceiveTOSOption: + e.mu.RLock() + v := e.receiveTOS + e.mu.RUnlock() + return v, nil + case tcpip.V6OnlyOption: // We only recognize this option on v6 endpoints. if e.NetProto != header.IPv6ProtocolNumber { return false, tcpip.ErrUnknownProtocolOption } - e.mu.Lock() + e.mu.RLock() v := e.v6only - e.mu.Unlock() + e.mu.RUnlock() return v, nil } @@ -767,12 +789,8 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { case *tcpip.BindToDeviceOption: e.mu.RLock() - defer e.mu.RUnlock() - if nic, ok := e.stack.NICInfo()[e.bindToDevice]; ok { - *o = tcpip.BindToDeviceOption(nic.Name) - return nil - } - *o = tcpip.BindToDeviceOption("") + *o = tcpip.BindToDeviceOption(e.bindToDevice) + e.mu.RUnlock() return nil case *tcpip.KeepaliveEnabledOption: @@ -849,35 +867,12 @@ func sendUDP(r *stack.Route, data buffer.VectorisedView, localPort, remotePort u return nil } -func (e *endpoint) checkV4Mapped(addr *tcpip.FullAddress, allowMismatch bool) (tcpip.NetworkProtocolNumber, *tcpip.Error) { - netProto := e.NetProto - if len(addr.Addr) == 0 { - return netProto, nil - } - if header.IsV4MappedAddress(addr.Addr) { - // Fail if using a v4 mapped address on a v6only endpoint. - if e.v6only { - return 0, tcpip.ErrNoRoute - } - - netProto = header.IPv4ProtocolNumber - addr.Addr = addr.Addr[header.IPv6AddressSize-header.IPv4AddressSize:] - if addr.Addr == header.IPv4Any { - addr.Addr = "" - } - - // Fail if we are bound to an IPv6 address. - if !allowMismatch && len(e.ID.LocalAddress) == 16 { - return 0, tcpip.ErrNetworkUnreachable - } - } - - // Fail if we're bound to an address length different from the one we're - // checking. - if l := len(e.ID.LocalAddress); l != 0 && l != len(addr.Addr) { - return 0, tcpip.ErrInvalidEndpointState +func (e *endpoint) checkV4Mapped(addr *tcpip.FullAddress) (tcpip.NetworkProtocolNumber, *tcpip.Error) { + unwrapped, netProto, err := e.TransportEndpointInfo.AddrNetProto(*addr, e.v6only) + if err != nil { + return 0, err } - + *addr = unwrapped return netProto, nil } @@ -926,7 +921,7 @@ func (e *endpoint) Disconnect() *tcpip.Error { // Connect connects the endpoint to its peer. Specifying a NIC is optional. func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { - netProto, err := e.checkV4Mapped(&addr, false) + netProto, err := e.checkV4Mapped(&addr) if err != nil { return err } @@ -1084,7 +1079,7 @@ func (e *endpoint) bindLocked(addr tcpip.FullAddress) *tcpip.Error { return tcpip.ErrInvalidEndpointState } - netProto, err := e.checkV4Mapped(&addr, true) + netProto, err := e.checkV4Mapped(&addr) if err != nil { return err } @@ -1248,6 +1243,12 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pk e.rcvList.PushBack(packet) e.rcvBufSize += pkt.Data.Size() + // Save any useful information from the network header to the packet. + switch r.NetProto { + case header.IPv4ProtocolNumber: + packet.tos, _ = header.IPv4(pkt.NetworkHeader).TOS() + } + packet.timestamp = e.stack.NowNanoseconds() e.rcvMu.Unlock() diff --git a/pkg/tcpip/transport/udp/udp_test.go b/pkg/tcpip/transport/udp/udp_test.go index 149fff999..ee9d10555 100644 --- a/pkg/tcpip/transport/udp/udp_test.go +++ b/pkg/tcpip/transport/udp/udp_test.go @@ -56,6 +56,7 @@ const ( multicastAddr = "\xe8\x2b\xd3\xea" multicastV6Addr = "\xff\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00" broadcastAddr = header.IPv4Broadcast + testTOS = 0x80 // defaultMTU is the MTU, in bytes, used throughout the tests, except // where another value is explicitly used. It is chosen to match the MTU @@ -453,6 +454,7 @@ func (c *testContext) injectV4Packet(payload []byte, h *header4Tuple, valid bool ip := header.IPv4(buf) ip.Encode(&header.IPv4Fields{ IHL: header.IPv4MinimumSize, + TOS: testTOS, TotalLength: uint16(len(buf)), TTL: 65, Protocol: uint8(udp.ProtocolNumber), @@ -513,42 +515,37 @@ func TestBindToDeviceOption(t *testing.T) { t.Errorf("CreateNICWithOptions(_, _, %+v) failed: %v", opts, err) } - // Make an nameless NIC. - if err := s.CreateNIC(54321, loopback.New()); err != nil { - t.Errorf("CreateNIC failed: %v", err) - } - - // strPtr is used instead of taking the address of string literals, which is + // nicIDPtr is used instead of taking the address of NICID literals, which is // a compiler error. - strPtr := func(s string) *string { + nicIDPtr := func(s tcpip.NICID) *tcpip.NICID { return &s } testActions := []struct { name string - setBindToDevice *string + setBindToDevice *tcpip.NICID setBindToDeviceError *tcpip.Error getBindToDevice tcpip.BindToDeviceOption }{ - {"GetDefaultValue", nil, nil, ""}, - {"BindToNonExistent", strPtr("non_existent_device"), tcpip.ErrUnknownDevice, ""}, - {"BindToExistent", strPtr("my_device"), nil, "my_device"}, - {"UnbindToDevice", strPtr(""), nil, ""}, + {"GetDefaultValue", nil, nil, 0}, + {"BindToNonExistent", nicIDPtr(999), tcpip.ErrUnknownDevice, 0}, + {"BindToExistent", nicIDPtr(321), nil, 321}, + {"UnbindToDevice", nicIDPtr(0), nil, 0}, } for _, testAction := range testActions { t.Run(testAction.name, func(t *testing.T) { if testAction.setBindToDevice != nil { bindToDevice := tcpip.BindToDeviceOption(*testAction.setBindToDevice) - if got, want := ep.SetSockOpt(bindToDevice), testAction.setBindToDeviceError; got != want { - t.Errorf("SetSockOpt(%v) got %v, want %v", bindToDevice, got, want) + if gotErr, wantErr := ep.SetSockOpt(bindToDevice), testAction.setBindToDeviceError; gotErr != wantErr { + t.Errorf("SetSockOpt(%v) got %v, want %v", bindToDevice, gotErr, wantErr) } } - bindToDevice := tcpip.BindToDeviceOption("to be modified by GetSockOpt") - if ep.GetSockOpt(&bindToDevice) != nil { - t.Errorf("GetSockOpt got %v, want %v", ep.GetSockOpt(&bindToDevice), nil) + bindToDevice := tcpip.BindToDeviceOption(88888) + if err := ep.GetSockOpt(&bindToDevice); err != nil { + t.Errorf("GetSockOpt got %v, want %v", err, nil) } if got, want := bindToDevice, testAction.getBindToDevice; got != want { - t.Errorf("bindToDevice got %q, want %q", got, want) + t.Errorf("bindToDevice got %d, want %d", got, want) } }) } @@ -557,8 +554,8 @@ func TestBindToDeviceOption(t *testing.T) { // testReadInternal sends a packet of the given test flow into the stack by // injecting it into the link endpoint. It then attempts to read it from the // UDP endpoint and depending on if this was expected to succeed verifies its -// correctness. -func testReadInternal(c *testContext, flow testFlow, packetShouldBeDropped, expectReadError bool) { +// correctness including any additional checker functions provided. +func testReadInternal(c *testContext, flow testFlow, packetShouldBeDropped, expectReadError bool, checkers ...checker.ControlMessagesChecker) { c.t.Helper() payload := newPayload() @@ -573,12 +570,12 @@ func testReadInternal(c *testContext, flow testFlow, packetShouldBeDropped, expe epstats := c.ep.Stats().(*tcpip.TransportEndpointStats).Clone() var addr tcpip.FullAddress - v, _, err := c.ep.Read(&addr) + v, cm, err := c.ep.Read(&addr) if err == tcpip.ErrWouldBlock { // Wait for data to become available. select { case <-ch: - v, _, err = c.ep.Read(&addr) + v, cm, err = c.ep.Read(&addr) case <-time.After(300 * time.Millisecond): if packetShouldBeDropped { @@ -611,15 +608,21 @@ func testReadInternal(c *testContext, flow testFlow, packetShouldBeDropped, expe if !bytes.Equal(payload, v) { c.t.Fatalf("bad payload: got %x, want %x", v, payload) } + + // Run any checkers against the ControlMessages. + for _, f := range checkers { + f(c.t, cm) + } + c.checkEndpointReadStats(1, epstats, err) } // testRead sends a packet of the given test flow into the stack by injecting it // into the link endpoint. It then reads it from the UDP endpoint and verifies -// its correctness. -func testRead(c *testContext, flow testFlow) { +// its correctness including any additional checker functions provided. +func testRead(c *testContext, flow testFlow, checkers ...checker.ControlMessagesChecker) { c.t.Helper() - testReadInternal(c, flow, false /* packetShouldBeDropped */, false /* expectReadError */) + testReadInternal(c, flow, false /* packetShouldBeDropped */, false /* expectReadError */, checkers...) } // testFailingRead sends a packet of the given test flow into the stack by @@ -1287,7 +1290,7 @@ func TestTOSV4(t *testing.T) { c.createEndpointForFlow(flow) - const tos = 0xC0 + const tos = testTOS var v tcpip.IPv4TOSOption if err := c.ep.GetSockOpt(&v); err != nil { c.t.Errorf("GetSockopt failed: %s", err) @@ -1322,7 +1325,7 @@ func TestTOSV6(t *testing.T) { c.createEndpointForFlow(flow) - const tos = 0xC0 + const tos = testTOS var v tcpip.IPv6TrafficClassOption if err := c.ep.GetSockOpt(&v); err != nil { c.t.Errorf("GetSockopt failed: %s", err) @@ -1349,6 +1352,47 @@ func TestTOSV6(t *testing.T) { } } +func TestReceiveTOSV4(t *testing.T) { + for _, flow := range []testFlow{unicastV4, broadcast} { + t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) { + c := newDualTestContext(t, defaultMTU) + defer c.cleanup() + + c.createEndpointForFlow(flow) + + // Verify that setting and reading the option works. + v, err := c.ep.GetSockOptBool(tcpip.ReceiveTOSOption) + if err != nil { + c.t.Fatal("GetSockOptBool(tcpip.ReceiveTOSOption) failed:", err) + } + // Test for expected default value. + if v != false { + c.t.Errorf("got GetSockOptBool(tcpip.ReceiveTOSOption) = %t, want = %t", v, false) + } + + want := true + if err := c.ep.SetSockOptBool(tcpip.ReceiveTOSOption, want); err != nil { + c.t.Fatalf("SetSockOptBool(tcpip.ReceiveTOSOption, %t) failed: %s", want, err) + } + + got, err := c.ep.GetSockOptBool(tcpip.ReceiveTOSOption) + if err != nil { + c.t.Fatal("GetSockOptBool(tcpip.ReceiveTOSOption) failed:", err) + } + if got != want { + c.t.Fatalf("got GetSockOptBool(tcpip.ReceiveTOSOption) = %t, want = %t", got, want) + } + + // Verify that the correct received TOS is handed through as + // ancillary data to the ControlMessages struct. + if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil { + c.t.Fatal("Bind failed:", err) + } + testRead(c, flow, checker.ReceiveTOS(testTOS)) + }) + } +} + func TestMulticastInterfaceOption(t *testing.T) { for _, flow := range []testFlow{multicastV4, multicastV4in6, multicastV6, multicastV6Only} { t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) { |