diff options
Diffstat (limited to 'pkg/tcpip/stack')
-rw-r--r-- | pkg/tcpip/stack/BUILD | 2 | ||||
-rw-r--r-- | pkg/tcpip/stack/linkaddrcache.go | 2 | ||||
-rw-r--r-- | pkg/tcpip/stack/linkaddrcache_test.go | 2 | ||||
-rw-r--r-- | pkg/tcpip/stack/ndp.go | 62 | ||||
-rw-r--r-- | pkg/tcpip/stack/ndp_test.go | 187 | ||||
-rw-r--r-- | pkg/tcpip/stack/nic.go | 131 | ||||
-rw-r--r-- | pkg/tcpip/stack/stack.go | 79 | ||||
-rw-r--r-- | pkg/tcpip/stack/stack_test.go | 192 | ||||
-rw-r--r-- | pkg/tcpip/stack/transport_demuxer.go | 56 | ||||
-rw-r--r-- | pkg/tcpip/stack/transport_demuxer_test.go | 89 |
10 files changed, 704 insertions, 98 deletions
diff --git a/pkg/tcpip/stack/BUILD b/pkg/tcpip/stack/BUILD index b8f9517d0..705e984c1 100644 --- a/pkg/tcpip/stack/BUILD +++ b/pkg/tcpip/stack/BUILD @@ -36,6 +36,7 @@ go_library( "//pkg/ilist", "//pkg/rand", "//pkg/sleep", + "//pkg/sync", "//pkg/tcpip", "//pkg/tcpip/buffer", "//pkg/tcpip/hash/jenkins", @@ -83,6 +84,7 @@ go_test( embed = [":stack"], deps = [ "//pkg/sleep", + "//pkg/sync", "//pkg/tcpip", ], ) diff --git a/pkg/tcpip/stack/linkaddrcache.go b/pkg/tcpip/stack/linkaddrcache.go index 267df60d1..403557fd7 100644 --- a/pkg/tcpip/stack/linkaddrcache.go +++ b/pkg/tcpip/stack/linkaddrcache.go @@ -16,10 +16,10 @@ package stack import ( "fmt" - "sync" "time" "gvisor.dev/gvisor/pkg/sleep" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" ) diff --git a/pkg/tcpip/stack/linkaddrcache_test.go b/pkg/tcpip/stack/linkaddrcache_test.go index 9946b8fe8..1baa498d0 100644 --- a/pkg/tcpip/stack/linkaddrcache_test.go +++ b/pkg/tcpip/stack/linkaddrcache_test.go @@ -16,12 +16,12 @@ package stack import ( "fmt" - "sync" "sync/atomic" "testing" "time" "gvisor.dev/gvisor/pkg/sleep" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" ) diff --git a/pkg/tcpip/stack/ndp.go b/pkg/tcpip/stack/ndp.go index 35825ebf7..a9dd322db 100644 --- a/pkg/tcpip/stack/ndp.go +++ b/pkg/tcpip/stack/ndp.go @@ -115,6 +115,30 @@ var ( MinPrefixInformationValidLifetimeForUpdate = 2 * time.Hour ) +// DHCPv6ConfigurationFromNDPRA is a configuration available via DHCPv6 that an +// NDP Router Advertisement informed the Stack about. +type DHCPv6ConfigurationFromNDPRA int + +const ( + // DHCPv6NoConfiguration indicates that no configurations are available via + // DHCPv6. + DHCPv6NoConfiguration DHCPv6ConfigurationFromNDPRA = iota + + // DHCPv6ManagedAddress indicates that addresses are available via DHCPv6. + // + // DHCPv6ManagedAddress also implies DHCPv6OtherConfigurations because DHCPv6 + // will return all available configuration information. + DHCPv6ManagedAddress + + // DHCPv6OtherConfigurations indicates that other configuration information is + // available via DHCPv6. + // + // Other configurations are configurations other than addresses. Examples of + // other configurations are recursive DNS server list, DNS search lists and + // default gateway. + DHCPv6OtherConfigurations +) + // NDPDispatcher is the interface integrators of netstack must implement to // receive and handle NDP related events. type NDPDispatcher interface { @@ -194,7 +218,20 @@ type NDPDispatcher interface { // already known DNS servers. If called with known DNS servers, their // valid lifetimes must be refreshed to lifetime (it may be increased, // decreased, or completely invalidated when lifetime = 0). + // + // This function is not permitted to block indefinitely. It must not + // call functions on the stack itself. OnRecursiveDNSServerOption(nicID tcpip.NICID, addrs []tcpip.Address, lifetime time.Duration) + + // OnDHCPv6Configuration will be called with an updated configuration that is + // available via DHCPv6 for a specified NIC. + // + // NDPDispatcher assumes that the initial configuration available by DHCPv6 is + // DHCPv6NoConfiguration. + // + // This function is not permitted to block indefinitely. It must not + // call functions on the stack itself. + OnDHCPv6Configuration(tcpip.NICID, DHCPv6ConfigurationFromNDPRA) } // NDPConfigurations is the NDP configurations for the netstack. @@ -281,6 +318,9 @@ type ndpState struct { // The addresses generated by SLAAC. autoGenAddresses map[tcpip.Address]autoGenAddressState + + // The last learned DHCPv6 configuration from an NDP RA. + dhcpv6Configuration DHCPv6ConfigurationFromNDPRA } // dadState holds the Duplicate Address Detection timer and channel to signal @@ -533,6 +573,28 @@ func (ndp *ndpState) handleRA(ip tcpip.Address, ra header.NDPRouterAdvert) { return } + // Only worry about the DHCPv6 configuration if we have an NDPDispatcher as we + // only inform the dispatcher on configuration changes. We do nothing else + // with the information. + if ndpDisp := ndp.nic.stack.ndpDisp; ndpDisp != nil { + var configuration DHCPv6ConfigurationFromNDPRA + switch { + case ra.ManagedAddrConfFlag(): + configuration = DHCPv6ManagedAddress + + case ra.OtherConfFlag(): + configuration = DHCPv6OtherConfigurations + + default: + configuration = DHCPv6NoConfiguration + } + + if ndp.dhcpv6Configuration != configuration { + ndp.dhcpv6Configuration = configuration + ndpDisp.OnDHCPv6Configuration(ndp.nic.ID(), configuration) + } + } + // Is the NIC configured to discover default routers? if ndp.configs.DiscoverDefaultRouters { rtr, ok := ndp.defaultRouters[ip] diff --git a/pkg/tcpip/stack/ndp_test.go b/pkg/tcpip/stack/ndp_test.go index d334af289..d390c6312 100644 --- a/pkg/tcpip/stack/ndp_test.go +++ b/pkg/tcpip/stack/ndp_test.go @@ -162,18 +162,24 @@ type ndpRDNSSEvent struct { rdnss ndpRDNSS } +type ndpDHCPv6Event struct { + nicID tcpip.NICID + configuration stack.DHCPv6ConfigurationFromNDPRA +} + var _ stack.NDPDispatcher = (*ndpDispatcher)(nil) // ndpDispatcher implements NDPDispatcher so tests can know when various NDP // related events happen for test purposes. type ndpDispatcher struct { - dadC chan ndpDADEvent - routerC chan ndpRouterEvent - rememberRouter bool - prefixC chan ndpPrefixEvent - rememberPrefix bool - autoGenAddrC chan ndpAutoGenAddrEvent - rdnssC chan ndpRDNSSEvent + dadC chan ndpDADEvent + routerC chan ndpRouterEvent + rememberRouter bool + prefixC chan ndpPrefixEvent + rememberPrefix bool + autoGenAddrC chan ndpAutoGenAddrEvent + rdnssC chan ndpRDNSSEvent + dhcpv6ConfigurationC chan ndpDHCPv6Event } // Implements stack.NDPDispatcher.OnDuplicateAddressDetectionStatus. @@ -280,6 +286,16 @@ func (n *ndpDispatcher) OnRecursiveDNSServerOption(nicID tcpip.NICID, addrs []tc } } +// Implements stack.NDPDispatcher.OnDHCPv6Configuration. +func (n *ndpDispatcher) OnDHCPv6Configuration(nicID tcpip.NICID, configuration stack.DHCPv6ConfigurationFromNDPRA) { + if c := n.dhcpv6ConfigurationC; c != nil { + c <- ndpDHCPv6Event{ + nicID, + configuration, + } + } +} + // TestDADResolve tests that an address successfully resolves after performing // DAD for various values of DupAddrDetectTransmits and RetransmitTimer. // Included in the subtests is a test to make sure that an invalid @@ -797,21 +813,32 @@ func TestSetNDPConfigurations(t *testing.T) { } } -// raBufWithOpts returns a valid NDP Router Advertisement with options. -// -// Note, raBufWithOpts does not populate any of the RA fields other than the -// Router Lifetime. -func raBufWithOpts(ip tcpip.Address, rl uint16, optSer header.NDPOptionsSerializer) tcpip.PacketBuffer { +// raBufWithOptsAndDHCPv6 returns a valid NDP Router Advertisement with options +// and DHCPv6 configurations specified. +func raBufWithOptsAndDHCPv6(ip tcpip.Address, rl uint16, managedAddress, otherConfigurations bool, optSer header.NDPOptionsSerializer) tcpip.PacketBuffer { icmpSize := header.ICMPv6HeaderSize + header.NDPRAMinimumSize + int(optSer.Length()) hdr := buffer.NewPrependable(header.IPv6MinimumSize + icmpSize) pkt := header.ICMPv6(hdr.Prepend(icmpSize)) pkt.SetType(header.ICMPv6RouterAdvert) pkt.SetCode(0) - ra := header.NDPRouterAdvert(pkt.NDPPayload()) + raPayload := pkt.NDPPayload() + ra := header.NDPRouterAdvert(raPayload) + // Populate the Router Lifetime. + binary.BigEndian.PutUint16(raPayload[2:], rl) + // Populate the Managed Address flag field. + if managedAddress { + // The Managed Addresses flag field is the 7th bit of byte #1 (0-indexing) + // of the RA payload. + raPayload[1] |= (1 << 7) + } + // Populate the Other Configurations flag field. + if otherConfigurations { + // The Other Configurations flag field is the 6th bit of byte #1 + // (0-indexing) of the RA payload. + raPayload[1] |= (1 << 6) + } opts := ra.Options() opts.Serialize(optSer) - // Populate the Router Lifetime. - binary.BigEndian.PutUint16(pkt.NDPPayload()[2:], rl) pkt.SetChecksum(header.ICMPv6Checksum(pkt, ip, header.IPv6AllNodesMulticastAddress, buffer.VectorisedView{})) payloadLength := hdr.UsedLength() iph := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) @@ -826,6 +853,23 @@ func raBufWithOpts(ip tcpip.Address, rl uint16, optSer header.NDPOptionsSerializ return tcpip.PacketBuffer{Data: hdr.View().ToVectorisedView()} } +// raBufWithOpts returns a valid NDP Router Advertisement with options. +// +// Note, raBufWithOpts does not populate any of the RA fields other than the +// Router Lifetime. +func raBufWithOpts(ip tcpip.Address, rl uint16, optSer header.NDPOptionsSerializer) tcpip.PacketBuffer { + return raBufWithOptsAndDHCPv6(ip, rl, false, false, optSer) +} + +// raBufWithDHCPv6 returns a valid NDP Router Advertisement with DHCPv6 related +// fields set. +// +// Note, raBufWithDHCPv6 does not populate any of the RA fields other than the +// DHCPv6 related ones. +func raBufWithDHCPv6(ip tcpip.Address, managedAddresses, otherConfiguratiosns bool) tcpip.PacketBuffer { + return raBufWithOptsAndDHCPv6(ip, 0, managedAddresses, otherConfiguratiosns, header.NDPOptionsSerializer{}) +} + // raBuf returns a valid NDP Router Advertisement. // // Note, raBuf does not populate any of the RA fields other than the @@ -1688,9 +1732,11 @@ func stackAndNdpDispatcherWithDefaultRoute(t *testing.T, nicID tcpip.NICID) (*nd return ndpDisp, e, s } -// addrForNewConnection returns the local address used when creating a new -// connection. -func addrForNewConnection(t *testing.T, s *stack.Stack) tcpip.Address { +// addrForNewConnectionTo returns the local address used when creating a new +// connection to addr. +func addrForNewConnectionTo(t *testing.T, s *stack.Stack, addr tcpip.FullAddress) tcpip.Address { + t.Helper() + wq := waiter.Queue{} we, ch := waiter.NewChannelEntry(nil) wq.EventRegister(&we, waiter.EventIn) @@ -1704,8 +1750,8 @@ func addrForNewConnection(t *testing.T, s *stack.Stack) tcpip.Address { if err := ep.SetSockOptBool(tcpip.V6OnlyOption, true); err != nil { t.Fatalf("SetSockOpt(tcpip.V6OnlyOption, true): %s", err) } - if err := ep.Connect(dstAddr); err != nil { - t.Fatalf("ep.Connect(%+v): %s", dstAddr, err) + if err := ep.Connect(addr); err != nil { + t.Fatalf("ep.Connect(%+v): %s", addr, err) } got, err := ep.GetLocalAddress() if err != nil { @@ -1714,9 +1760,19 @@ func addrForNewConnection(t *testing.T, s *stack.Stack) tcpip.Address { return got.Addr } +// addrForNewConnection returns the local address used when creating a new +// connection. +func addrForNewConnection(t *testing.T, s *stack.Stack) tcpip.Address { + t.Helper() + + return addrForNewConnectionTo(t, s, dstAddr) +} + // addrForNewConnectionWithAddr returns the local address used when creating a // new connection with a specific local address. func addrForNewConnectionWithAddr(t *testing.T, s *stack.Stack, addr tcpip.FullAddress) tcpip.Address { + t.Helper() + wq := waiter.Queue{} we, ch := waiter.NewChannelEntry(nil) wq.EventRegister(&we, waiter.EventIn) @@ -2951,3 +3007,94 @@ func TestCleanupHostOnlyStateOnBecomingRouter(t *testing.T) { default: } } + +// TestDHCPv6ConfigurationFromNDPDA tests that the NDPDispatcher is properly +// informed when new information about what configurations are available via +// DHCPv6 is learned. +func TestDHCPv6ConfigurationFromNDPDA(t *testing.T) { + const nicID = 1 + + ndpDisp := ndpDispatcher{ + dhcpv6ConfigurationC: make(chan ndpDHCPv6Event, 1), + rememberRouter: true, + } + e := channel.New(0, 1280, linkAddr1) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NDPConfigs: stack.NDPConfigurations{ + HandleRAs: true, + }, + NDPDisp: &ndpDisp, + }) + + if err := s.CreateNIC(nicID, e); err != nil { + t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) + } + + expectDHCPv6Event := func(configuration stack.DHCPv6ConfigurationFromNDPRA) { + t.Helper() + select { + case e := <-ndpDisp.dhcpv6ConfigurationC: + if diff := cmp.Diff(ndpDHCPv6Event{nicID: nicID, configuration: configuration}, e, cmp.AllowUnexported(e)); diff != "" { + t.Errorf("dhcpv6 event mismatch (-want +got):\n%s", diff) + } + default: + t.Fatal("expected DHCPv6 configuration event") + } + } + + expectNoDHCPv6Event := func() { + t.Helper() + select { + case <-ndpDisp.dhcpv6ConfigurationC: + t.Fatal("unexpected DHCPv6 configuration event") + default: + } + } + + // The initial DHCPv6 configuration should be stack.DHCPv6NoConfiguration. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, false, false)) + expectNoDHCPv6Event() + + // Receive an RA that updates the DHCPv6 configuration to Other + // Configurations. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, false, true)) + expectDHCPv6Event(stack.DHCPv6OtherConfigurations) + // Receiving the same update again should not result in an event to the + // NDPDispatcher. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, false, true)) + expectNoDHCPv6Event() + + // Receive an RA that updates the DHCPv6 configuration to Managed Address. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, true, false)) + expectDHCPv6Event(stack.DHCPv6ManagedAddress) + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, true, false)) + expectNoDHCPv6Event() + + // Receive an RA that updates the DHCPv6 configuration to none. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, false, false)) + expectDHCPv6Event(stack.DHCPv6NoConfiguration) + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, false, false)) + expectNoDHCPv6Event() + + // Receive an RA that updates the DHCPv6 configuration to Managed Address. + // + // Note, when the M flag is set, the O flag is redundant. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, true, true)) + expectDHCPv6Event(stack.DHCPv6ManagedAddress) + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, true, true)) + expectNoDHCPv6Event() + // Even though the DHCPv6 flags are different, the effective configuration is + // the same so we should not receive a new event. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, true, false)) + expectNoDHCPv6Event() + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, true, true)) + expectNoDHCPv6Event() + + // Receive an RA that updates the DHCPv6 configuration to Other + // Configurations. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, false, true)) + expectDHCPv6Event(stack.DHCPv6OtherConfigurations) + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, false, true)) + expectNoDHCPv6Event() +} diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go index 4144d5d0f..071221d5a 100644 --- a/pkg/tcpip/stack/nic.go +++ b/pkg/tcpip/stack/nic.go @@ -15,10 +15,12 @@ package stack import ( + "log" + "sort" "strings" - "sync" "sync/atomic" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" @@ -27,10 +29,11 @@ import ( // NIC represents a "network interface card" to which the networking stack is // attached. type NIC struct { - stack *Stack - id tcpip.NICID - name string - linkEP LinkEndpoint + stack *Stack + id tcpip.NICID + name string + linkEP LinkEndpoint + context NICContext mu sync.RWMutex spoofing bool @@ -84,7 +87,7 @@ const ( ) // newNIC returns a new NIC using the default NDP configurations from stack. -func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint) *NIC { +func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint, ctx NICContext) *NIC { // TODO(b/141011931): Validate a LinkEndpoint (ep) is valid. For // example, make sure that the link address it provides is a valid // unicast ethernet address. @@ -98,6 +101,7 @@ func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint) *NIC { id: id, name: name, linkEP: ep, + context: ctx, primary: make(map[tcpip.NetworkProtocolNumber][]*referencedNetworkEndpoint), endpoints: make(map[NetworkEndpointID]*referencedNetworkEndpoint), mcastJoins: make(map[NetworkEndpointID]int32), @@ -249,13 +253,17 @@ func (n *NIC) setSpoofing(enable bool) { n.mu.Unlock() } -// primaryEndpoint returns the primary endpoint of n for the given network -// protocol. -// // primaryEndpoint will return the first non-deprecated endpoint if such an -// endpoint exists. If no non-deprecated endpoint exists, the first deprecated -// endpoint will be returned. -func (n *NIC) primaryEndpoint(protocol tcpip.NetworkProtocolNumber) *referencedNetworkEndpoint { +// endpoint exists for the given protocol and remoteAddr. If no non-deprecated +// endpoint exists, the first deprecated endpoint will be returned. +// +// If an IPv6 primary endpoint is requested, Source Address Selection (as +// defined by RFC 6724 section 5) will be performed. +func (n *NIC) primaryEndpoint(protocol tcpip.NetworkProtocolNumber, remoteAddr tcpip.Address) *referencedNetworkEndpoint { + if protocol == header.IPv6ProtocolNumber && remoteAddr != "" { + return n.primaryIPv6Endpoint(remoteAddr) + } + n.mu.RLock() defer n.mu.RUnlock() @@ -294,6 +302,103 @@ func (n *NIC) primaryEndpoint(protocol tcpip.NetworkProtocolNumber) *referencedN return deprecatedEndpoint } +// ipv6AddrCandidate is an IPv6 candidate for Source Address Selection (RFC +// 6724 section 5). +type ipv6AddrCandidate struct { + ref *referencedNetworkEndpoint + scope header.IPv6AddressScope +} + +// primaryIPv6Endpoint returns an IPv6 endpoint following Source Address +// Selection (RFC 6724 section 5). +// +// Note, only rules 1-3 are followed. +// +// remoteAddr must be a valid IPv6 address. +func (n *NIC) primaryIPv6Endpoint(remoteAddr tcpip.Address) *referencedNetworkEndpoint { + n.mu.RLock() + defer n.mu.RUnlock() + + primaryAddrs := n.primary[header.IPv6ProtocolNumber] + + if len(primaryAddrs) == 0 { + return nil + } + + // Create a candidate set of available addresses we can potentially use as a + // source address. + cs := make([]ipv6AddrCandidate, 0, len(primaryAddrs)) + for _, r := range primaryAddrs { + // If r is not valid for outgoing connections, it is not a valid endpoint. + if !r.isValidForOutgoing() { + continue + } + + addr := r.ep.ID().LocalAddress + scope, err := header.ScopeForIPv6Address(addr) + if err != nil { + // Should never happen as we got r from the primary IPv6 endpoint list and + // ScopeForIPv6Address only returns an error if addr is not an IPv6 + // address. + log.Fatalf("header.ScopeForIPv6Address(%s): %s", addr, err) + } + + cs = append(cs, ipv6AddrCandidate{ + ref: r, + scope: scope, + }) + } + + remoteScope, err := header.ScopeForIPv6Address(remoteAddr) + if err != nil { + // primaryIPv6Endpoint should never be called with an invalid IPv6 address. + log.Fatalf("header.ScopeForIPv6Address(%s): %s", remoteAddr, err) + } + + // Sort the addresses as per RFC 6724 section 5 rules 1-3. + // + // TODO(b/146021396): Implement rules 4-8 of RFC 6724 section 5. + sort.Slice(cs, func(i, j int) bool { + sa := cs[i] + sb := cs[j] + + // Prefer same address as per RFC 6724 section 5 rule 1. + if sa.ref.ep.ID().LocalAddress == remoteAddr { + return true + } + if sb.ref.ep.ID().LocalAddress == remoteAddr { + return false + } + + // Prefer appropriate scope as per RFC 6724 section 5 rule 2. + if sa.scope < sb.scope { + return sa.scope >= remoteScope + } else if sb.scope < sa.scope { + return sb.scope < remoteScope + } + + // Avoid deprecated addresses as per RFC 6724 section 5 rule 3. + if saDep, sbDep := sa.ref.deprecated, sb.ref.deprecated; saDep != sbDep { + // If sa is not deprecated, it is preferred over sb. + return sbDep + } + + // sa and sb are equal, return the endpoint that is closest to the front of + // the primary endpoint list. + return i < j + }) + + // Return the most preferred address that can have its reference count + // incremented. + for _, c := range cs { + if r := c.ref; r.tryIncRef() { + return r + } + } + + return nil +} + // hasPermanentAddrLocked returns true if n has a permanent (including currently // tentative) address, addr. func (n *NIC) hasPermanentAddrLocked(addr tcpip.Address) bool { @@ -658,7 +763,7 @@ func (n *NIC) RemoveAddressRange(subnet tcpip.Subnet) { n.mu.Unlock() } -// Subnets returns the Subnets associated with this NIC. +// AddressRanges returns the Subnets associated with this NIC. func (n *NIC) AddressRanges() []tcpip.Subnet { n.mu.RLock() defer n.mu.RUnlock() diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go index fb7ac409e..386eb6eec 100644 --- a/pkg/tcpip/stack/stack.go +++ b/pkg/tcpip/stack/stack.go @@ -21,13 +21,13 @@ package stack import ( "encoding/binary" - "sync" "sync/atomic" "time" "golang.org/x/time/rate" "gvisor.dev/gvisor/pkg/rand" "gvisor.dev/gvisor/pkg/sleep" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" @@ -547,6 +547,49 @@ type TransportEndpointInfo struct { RegisterNICID tcpip.NICID } +// AddrNetProto unwraps the specified address if it is a V4-mapped V6 address +// and returns the network protocol number to be used to communicate with the +// specified address. It returns an error if the passed address is incompatible +// with the receiver. +func (e *TransportEndpointInfo) AddrNetProto(addr tcpip.FullAddress, v6only bool) (tcpip.FullAddress, tcpip.NetworkProtocolNumber, *tcpip.Error) { + netProto := e.NetProto + switch len(addr.Addr) { + case header.IPv4AddressSize: + netProto = header.IPv4ProtocolNumber + case header.IPv6AddressSize: + if header.IsV4MappedAddress(addr.Addr) { + netProto = header.IPv4ProtocolNumber + addr.Addr = addr.Addr[header.IPv6AddressSize-header.IPv4AddressSize:] + if addr.Addr == header.IPv4Any { + addr.Addr = "" + } + } + } + + switch len(e.ID.LocalAddress) { + case header.IPv4AddressSize: + if len(addr.Addr) == header.IPv6AddressSize { + return tcpip.FullAddress{}, 0, tcpip.ErrInvalidEndpointState + } + case header.IPv6AddressSize: + if len(addr.Addr) == header.IPv4AddressSize { + return tcpip.FullAddress{}, 0, tcpip.ErrNetworkUnreachable + } + } + + switch { + case netProto == e.NetProto: + case netProto == header.IPv4ProtocolNumber && e.NetProto == header.IPv6ProtocolNumber: + if v6only { + return tcpip.FullAddress{}, 0, tcpip.ErrNoRoute + } + default: + return tcpip.FullAddress{}, 0, tcpip.ErrInvalidEndpointState + } + + return addr, netProto, nil +} + // IsEndpointInfo is an empty method to implement the tcpip.EndpointInfo // marker interface. func (*TransportEndpointInfo) IsEndpointInfo() {} @@ -796,6 +839,9 @@ func (s *Stack) NewPacketEndpoint(cooked bool, netProto tcpip.NetworkProtocolNum return s.rawFactory.NewPacketEndpoint(s, cooked, netProto, waiterQueue) } +// NICContext is an opaque pointer used to store client-supplied NIC metadata. +type NICContext interface{} + // NICOptions specifies the configuration of a NIC as it is being created. // The zero value creates an enabled, unnamed NIC. type NICOptions struct { @@ -805,6 +851,12 @@ type NICOptions struct { // Disabled specifies whether to avoid calling Attach on the passed // LinkEndpoint. Disabled bool + + // Context specifies user-defined data that will be returned in stack.NICInfo + // for the NIC. Clients of this library can use it to add metadata that + // should be tracked alongside a NIC, to avoid having to keep a + // map[tcpip.NICID]metadata mirroring stack.Stack's nic map. + Context NICContext } // CreateNICWithOptions creates a NIC with the provided id, LinkEndpoint, and @@ -819,7 +871,7 @@ func (s *Stack) CreateNICWithOptions(id tcpip.NICID, ep LinkEndpoint, opts NICOp return tcpip.ErrDuplicateNICID } - n := newNIC(s, id, opts.Name, ep) + n := newNIC(s, id, opts.Name, ep, opts.Context) s.nics[id] = n if !opts.Disabled { @@ -860,7 +912,7 @@ func (s *Stack) CheckNIC(id tcpip.NICID) bool { return false } -// NICSubnets returns a map of NICIDs to their associated subnets. +// NICAddressRanges returns a map of NICIDs to their associated subnets. func (s *Stack) NICAddressRanges() map[tcpip.NICID][]tcpip.Subnet { s.mu.RLock() defer s.mu.RUnlock() @@ -886,6 +938,18 @@ type NICInfo struct { MTU uint32 Stats NICStats + + // Context is user-supplied data optionally supplied in CreateNICWithOptions. + // See type NICOptions for more details. + Context NICContext +} + +// HasNIC returns true if the NICID is defined in the stack. +func (s *Stack) HasNIC(id tcpip.NICID) bool { + s.mu.RLock() + _, ok := s.nics[id] + s.mu.RUnlock() + return ok } // NICInfo returns a map of NICIDs to their associated information. @@ -908,6 +972,7 @@ func (s *Stack) NICInfo() map[tcpip.NICID]NICInfo { Flags: flags, MTU: nic.linkEP.MTU(), Stats: nic.stats, + Context: nic.context, } } return nics @@ -1041,9 +1106,9 @@ func (s *Stack) GetMainNICAddress(id tcpip.NICID, protocol tcpip.NetworkProtocol return nic.primaryAddress(protocol), nil } -func (s *Stack) getRefEP(nic *NIC, localAddr tcpip.Address, netProto tcpip.NetworkProtocolNumber) (ref *referencedNetworkEndpoint) { +func (s *Stack) getRefEP(nic *NIC, localAddr, remoteAddr tcpip.Address, netProto tcpip.NetworkProtocolNumber) (ref *referencedNetworkEndpoint) { if len(localAddr) == 0 { - return nic.primaryEndpoint(netProto) + return nic.primaryEndpoint(netProto, remoteAddr) } return nic.findEndpoint(netProto, localAddr, CanBePrimaryEndpoint) } @@ -1059,7 +1124,7 @@ func (s *Stack) FindRoute(id tcpip.NICID, localAddr, remoteAddr tcpip.Address, n needRoute := !(isBroadcast || isMulticast || header.IsV6LinkLocalAddress(remoteAddr)) if id != 0 && !needRoute { if nic, ok := s.nics[id]; ok { - if ref := s.getRefEP(nic, localAddr, netProto); ref != nil { + if ref := s.getRefEP(nic, localAddr, remoteAddr, netProto); ref != nil { return makeRoute(netProto, ref.ep.ID().LocalAddress, remoteAddr, nic.linkEP.LinkAddress(), ref, s.handleLocal && !nic.isLoopback(), multicastLoop && !nic.isLoopback()), nil } } @@ -1069,7 +1134,7 @@ func (s *Stack) FindRoute(id tcpip.NICID, localAddr, remoteAddr tcpip.Address, n continue } if nic, ok := s.nics[route.NIC]; ok { - if ref := s.getRefEP(nic, localAddr, netProto); ref != nil { + if ref := s.getRefEP(nic, localAddr, remoteAddr, netProto); ref != nil { if len(remoteAddr) == 0 { // If no remote address was provided, then the route // provided will refer to the link local address. diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go index 9ac50bb23..4b3d18f1b 100644 --- a/pkg/tcpip/stack/stack_test.go +++ b/pkg/tcpip/stack/stack_test.go @@ -35,6 +35,7 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/link/loopback" "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" "gvisor.dev/gvisor/pkg/tcpip/stack" + "gvisor.dev/gvisor/pkg/tcpip/transport/udp" ) const ( @@ -2001,6 +2002,46 @@ func TestNICAutoGenAddr(t *testing.T) { } } +// TestNICContextPreservation tests that you can read out via stack.NICInfo the +// Context data you pass via NICContext.Context in stack.CreateNICWithOptions. +func TestNICContextPreservation(t *testing.T) { + var ctx *int + tests := []struct { + name string + opts stack.NICOptions + want stack.NICContext + }{ + { + "context_set", + stack.NICOptions{Context: ctx}, + ctx, + }, + { + "context_not_set", + stack.NICOptions{}, + nil, + }, + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + s := stack.New(stack.Options{}) + id := tcpip.NICID(1) + ep := channel.New(0, 0, tcpip.LinkAddress("\x00\x00\x00\x00\x00\x00")) + if err := s.CreateNICWithOptions(id, ep, test.opts); err != nil { + t.Fatalf("got stack.CreateNICWithOptions(%d, %+v, %+v) = %s, want nil", id, ep, test.opts, err) + } + nicinfos := s.NICInfo() + nicinfo, ok := nicinfos[id] + if !ok { + t.Fatalf("got nicinfos[%d] = _, %t, want _, true; nicinfos = %+v", id, ok, nicinfos) + } + if got, want := nicinfo.Context == test.want, true; got != want { + t.Fatal("got nicinfo.Context == ctx = %t, want %t; nicinfo.Context = %p, ctx = %p", got, want, nicinfo.Context, test.want) + } + }) + } +} + // TestNICAutoGenAddrWithOpaque tests the auto-generation of IPv6 link-local // addresses with opaque interface identifiers. Link Local addresses should // always be generated with opaque IIDs if configured to use them, even if the @@ -2371,3 +2412,154 @@ func TestNewPEBOnPromotionToPermanent(t *testing.T) { } } } + +func TestIPv6SourceAddressSelectionScopeAndSameAddress(t *testing.T) { + const ( + linkLocalAddr1 = tcpip.Address("\xfe\x80\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01") + linkLocalAddr2 = tcpip.Address("\xfe\x80\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02") + uniqueLocalAddr1 = tcpip.Address("\xfc\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01") + uniqueLocalAddr2 = tcpip.Address("\xfd\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02") + globalAddr1 = tcpip.Address("\xa0\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01") + globalAddr2 = tcpip.Address("\xa0\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02") + nicID = 1 + ) + + // Rule 3 is not tested here, and is instead tested by NDP's AutoGenAddr test. + tests := []struct { + name string + nicAddrs []tcpip.Address + connectAddr tcpip.Address + expectedLocalAddr tcpip.Address + }{ + // Test Rule 1 of RFC 6724 section 5. + { + name: "Same Global most preferred (last address)", + nicAddrs: []tcpip.Address{linkLocalAddr1, uniqueLocalAddr1, globalAddr1}, + connectAddr: globalAddr1, + expectedLocalAddr: globalAddr1, + }, + { + name: "Same Global most preferred (first address)", + nicAddrs: []tcpip.Address{globalAddr1, linkLocalAddr1, uniqueLocalAddr1}, + connectAddr: globalAddr1, + expectedLocalAddr: globalAddr1, + }, + { + name: "Same Link Local most preferred (last address)", + nicAddrs: []tcpip.Address{globalAddr1, uniqueLocalAddr1, linkLocalAddr1}, + connectAddr: linkLocalAddr1, + expectedLocalAddr: linkLocalAddr1, + }, + { + name: "Same Link Local most preferred (first address)", + nicAddrs: []tcpip.Address{linkLocalAddr1, uniqueLocalAddr1, globalAddr1}, + connectAddr: linkLocalAddr1, + expectedLocalAddr: linkLocalAddr1, + }, + { + name: "Same Unique Local most preferred (last address)", + nicAddrs: []tcpip.Address{uniqueLocalAddr1, globalAddr1, linkLocalAddr1}, + connectAddr: uniqueLocalAddr1, + expectedLocalAddr: uniqueLocalAddr1, + }, + { + name: "Same Unique Local most preferred (first address)", + nicAddrs: []tcpip.Address{globalAddr1, linkLocalAddr1, uniqueLocalAddr1}, + connectAddr: uniqueLocalAddr1, + expectedLocalAddr: uniqueLocalAddr1, + }, + + // Test Rule 2 of RFC 6724 section 5. + { + name: "Global most preferred (last address)", + nicAddrs: []tcpip.Address{linkLocalAddr1, uniqueLocalAddr1, globalAddr1}, + connectAddr: globalAddr2, + expectedLocalAddr: globalAddr1, + }, + { + name: "Global most preferred (first address)", + nicAddrs: []tcpip.Address{globalAddr1, linkLocalAddr1, uniqueLocalAddr1}, + connectAddr: globalAddr2, + expectedLocalAddr: globalAddr1, + }, + { + name: "Link Local most preferred (last address)", + nicAddrs: []tcpip.Address{globalAddr1, uniqueLocalAddr1, linkLocalAddr1}, + connectAddr: linkLocalAddr2, + expectedLocalAddr: linkLocalAddr1, + }, + { + name: "Link Local most preferred (first address)", + nicAddrs: []tcpip.Address{linkLocalAddr1, uniqueLocalAddr1, globalAddr1}, + connectAddr: linkLocalAddr2, + expectedLocalAddr: linkLocalAddr1, + }, + { + name: "Unique Local most preferred (last address)", + nicAddrs: []tcpip.Address{uniqueLocalAddr1, globalAddr1, linkLocalAddr1}, + connectAddr: uniqueLocalAddr2, + expectedLocalAddr: uniqueLocalAddr1, + }, + { + name: "Unique Local most preferred (first address)", + nicAddrs: []tcpip.Address{globalAddr1, linkLocalAddr1, uniqueLocalAddr1}, + connectAddr: uniqueLocalAddr2, + expectedLocalAddr: uniqueLocalAddr1, + }, + + // Test returning the endpoint that is closest to the front when + // candidate addresses are "equal" from the perspective of RFC 6724 + // section 5. + { + name: "Unique Local for Global", + nicAddrs: []tcpip.Address{linkLocalAddr1, uniqueLocalAddr1, uniqueLocalAddr2}, + connectAddr: globalAddr2, + expectedLocalAddr: uniqueLocalAddr1, + }, + { + name: "Link Local for Global", + nicAddrs: []tcpip.Address{linkLocalAddr1, linkLocalAddr2}, + connectAddr: globalAddr2, + expectedLocalAddr: linkLocalAddr1, + }, + { + name: "Link Local for Unique Local", + nicAddrs: []tcpip.Address{linkLocalAddr1, linkLocalAddr2}, + connectAddr: uniqueLocalAddr2, + expectedLocalAddr: linkLocalAddr1, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + e := channel.New(0, 1280, linkAddr1) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()}, + }) + if err := s.CreateNIC(nicID, e); err != nil { + t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) + } + s.SetRouteTable([]tcpip.Route{{ + Destination: header.IPv6EmptySubnet, + Gateway: llAddr3, + NIC: nicID, + }}) + s.AddLinkAddress(nicID, llAddr3, linkAddr3) + + for _, a := range test.nicAddrs { + if err := s.AddAddress(nicID, ipv6.ProtocolNumber, a); err != nil { + t.Errorf("s.AddAddress(%d, %d, %s): %s", nicID, ipv6.ProtocolNumber, a, err) + } + } + + if t.Failed() { + t.FailNow() + } + + if got := addrForNewConnectionTo(t, s, tcpip.FullAddress{Addr: test.connectAddr, NIC: nicID, Port: 1234}); got != test.expectedLocalAddr { + t.Errorf("got local address = %s, want = %s", got, test.expectedLocalAddr) + } + }) + } +} diff --git a/pkg/tcpip/stack/transport_demuxer.go b/pkg/tcpip/stack/transport_demuxer.go index 67c21be42..d686e6eb8 100644 --- a/pkg/tcpip/stack/transport_demuxer.go +++ b/pkg/tcpip/stack/transport_demuxer.go @@ -18,8 +18,8 @@ import ( "fmt" "math/rand" "sort" - "sync" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/hash/jenkins" "gvisor.dev/gvisor/pkg/tcpip/header" @@ -104,7 +104,14 @@ func (epsByNic *endpointsByNic) handlePacket(r *Route, id TransportEndpointID, p return } // multiPortEndpoints are guaranteed to have at least one element. - selectEndpoint(id, mpep, epsByNic.seed).HandlePacket(r, id, pkt) + transEP := selectEndpoint(id, mpep, epsByNic.seed) + if queuedProtocol, mustQueue := mpep.demux.queuedProtocols[protocolIDs{mpep.netProto, mpep.transProto}]; mustQueue { + queuedProtocol.QueuePacket(r, transEP, id, pkt) + epsByNic.mu.RUnlock() + return + } + + transEP.HandlePacket(r, id, pkt) epsByNic.mu.RUnlock() // Don't use defer for performance reasons. } @@ -130,7 +137,7 @@ func (epsByNic *endpointsByNic) handleControlPacket(n *NIC, id TransportEndpoint // registerEndpoint returns true if it succeeds. It fails and returns // false if ep already has an element with the same key. -func (epsByNic *endpointsByNic) registerEndpoint(t TransportEndpoint, reusePort bool, bindToDevice tcpip.NICID) *tcpip.Error { +func (epsByNic *endpointsByNic) registerEndpoint(d *transportDemuxer, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, t TransportEndpoint, reusePort bool, bindToDevice tcpip.NICID) *tcpip.Error { epsByNic.mu.Lock() defer epsByNic.mu.Unlock() @@ -140,7 +147,7 @@ func (epsByNic *endpointsByNic) registerEndpoint(t TransportEndpoint, reusePort } // This is a new binding. - multiPortEp := &multiPortEndpoint{} + multiPortEp := &multiPortEndpoint{demux: d, netProto: netProto, transProto: transProto} multiPortEp.endpointsMap = make(map[TransportEndpoint]int) multiPortEp.reuse = reusePort epsByNic.endpoints[bindToDevice] = multiPortEp @@ -168,18 +175,34 @@ func (epsByNic *endpointsByNic) unregisterEndpoint(bindToDevice tcpip.NICID, t T // newTransportDemuxer. type transportDemuxer struct { // protocol is immutable. - protocol map[protocolIDs]*transportEndpoints + protocol map[protocolIDs]*transportEndpoints + queuedProtocols map[protocolIDs]queuedTransportProtocol +} + +// queuedTransportProtocol if supported by a protocol implementation will cause +// the dispatcher to delivery packets to the QueuePacket method instead of +// calling HandlePacket directly on the endpoint. +type queuedTransportProtocol interface { + QueuePacket(r *Route, ep TransportEndpoint, id TransportEndpointID, pkt tcpip.PacketBuffer) } func newTransportDemuxer(stack *Stack) *transportDemuxer { - d := &transportDemuxer{protocol: make(map[protocolIDs]*transportEndpoints)} + d := &transportDemuxer{ + protocol: make(map[protocolIDs]*transportEndpoints), + queuedProtocols: make(map[protocolIDs]queuedTransportProtocol), + } // Add each network and transport pair to the demuxer. for netProto := range stack.networkProtocols { for proto := range stack.transportProtocols { - d.protocol[protocolIDs{netProto, proto}] = &transportEndpoints{ + protoIDs := protocolIDs{netProto, proto} + d.protocol[protoIDs] = &transportEndpoints{ endpoints: make(map[TransportEndpointID]*endpointsByNic), } + qTransProto, isQueued := (stack.transportProtocols[proto].proto).(queuedTransportProtocol) + if isQueued { + d.queuedProtocols[protoIDs] = qTransProto + } } } @@ -209,7 +232,11 @@ func (d *transportDemuxer) registerEndpoint(netProtos []tcpip.NetworkProtocolNum // // +stateify savable type multiPortEndpoint struct { - mu sync.RWMutex `state:"nosave"` + mu sync.RWMutex `state:"nosave"` + demux *transportDemuxer + netProto tcpip.NetworkProtocolNumber + transProto tcpip.TransportProtocolNumber + endpointsArr []TransportEndpoint endpointsMap map[TransportEndpoint]int // reuse indicates if more than one endpoint is allowed. @@ -258,13 +285,22 @@ func selectEndpoint(id TransportEndpointID, mpep *multiPortEndpoint, seed uint32 func (ep *multiPortEndpoint) handlePacketAll(r *Route, id TransportEndpointID, pkt tcpip.PacketBuffer) { ep.mu.RLock() + queuedProtocol, mustQueue := ep.demux.queuedProtocols[protocolIDs{ep.netProto, ep.transProto}] for i, endpoint := range ep.endpointsArr { // HandlePacket takes ownership of pkt, so each endpoint needs // its own copy except for the final one. if i == len(ep.endpointsArr)-1 { + if mustQueue { + queuedProtocol.QueuePacket(r, endpoint, id, pkt) + break + } endpoint.HandlePacket(r, id, pkt) break } + if mustQueue { + queuedProtocol.QueuePacket(r, endpoint, id, pkt.Clone()) + continue + } endpoint.HandlePacket(r, id, pkt.Clone()) } ep.mu.RUnlock() // Don't use defer for performance reasons. @@ -357,7 +393,7 @@ func (d *transportDemuxer) singleRegisterEndpoint(netProto tcpip.NetworkProtocol if epsByNic, ok := eps.endpoints[id]; ok { // There was already a binding. - return epsByNic.registerEndpoint(ep, reusePort, bindToDevice) + return epsByNic.registerEndpoint(d, netProto, protocol, ep, reusePort, bindToDevice) } // This is a new binding. @@ -367,7 +403,7 @@ func (d *transportDemuxer) singleRegisterEndpoint(netProto tcpip.NetworkProtocol } eps.endpoints[id] = epsByNic - return epsByNic.registerEndpoint(ep, reusePort, bindToDevice) + return epsByNic.registerEndpoint(d, netProto, protocol, ep, reusePort, bindToDevice) } // unregisterEndpoint unregisters the endpoint with the given id such that it diff --git a/pkg/tcpip/stack/transport_demuxer_test.go b/pkg/tcpip/stack/transport_demuxer_test.go index df5ced887..5e9237de9 100644 --- a/pkg/tcpip/stack/transport_demuxer_test.go +++ b/pkg/tcpip/stack/transport_demuxer_test.go @@ -41,7 +41,7 @@ const ( type testContext struct { t *testing.T - linkEPs map[string]*channel.Endpoint + linkEps map[tcpip.NICID]*channel.Endpoint s *stack.Stack ep tcpip.Endpoint @@ -66,27 +66,24 @@ func (c *testContext) createV6Endpoint(v6only bool) { } } -// newDualTestContextMultiNic creates the testing context and also linkEpNames -// named NICs. -func newDualTestContextMultiNic(t *testing.T, mtu uint32, linkEpNames []string) *testContext { +// newDualTestContextMultiNIC creates the testing context and also linkEpIDs NICs. +func newDualTestContextMultiNIC(t *testing.T, mtu uint32, linkEpIDs []tcpip.NICID) *testContext { s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol(), ipv6.NewProtocol()}, TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()}}) - linkEPs := make(map[string]*channel.Endpoint) - for i, linkEpName := range linkEpNames { - channelEP := channel.New(256, mtu, "") - nicID := tcpip.NICID(i + 1) - opts := stack.NICOptions{Name: linkEpName} - if err := s.CreateNICWithOptions(nicID, channelEP, opts); err != nil { - t.Fatalf("CreateNICWithOptions(_, _, %+v) failed: %v", opts, err) + linkEps := make(map[tcpip.NICID]*channel.Endpoint) + for _, linkEpID := range linkEpIDs { + channelEp := channel.New(256, mtu, "") + if err := s.CreateNIC(linkEpID, channelEp); err != nil { + t.Fatalf("CreateNIC failed: %v", err) } - linkEPs[linkEpName] = channelEP + linkEps[linkEpID] = channelEp - if err := s.AddAddress(nicID, ipv4.ProtocolNumber, stackAddr); err != nil { + if err := s.AddAddress(linkEpID, ipv4.ProtocolNumber, stackAddr); err != nil { t.Fatalf("AddAddress IPv4 failed: %v", err) } - if err := s.AddAddress(nicID, ipv6.ProtocolNumber, stackV6Addr); err != nil { + if err := s.AddAddress(linkEpID, ipv6.ProtocolNumber, stackV6Addr); err != nil { t.Fatalf("AddAddress IPv6 failed: %v", err) } } @@ -105,7 +102,7 @@ func newDualTestContextMultiNic(t *testing.T, mtu uint32, linkEpNames []string) return &testContext{ t: t, s: s, - linkEPs: linkEPs, + linkEps: linkEps, } } @@ -122,7 +119,7 @@ func newPayload() []byte { return b } -func (c *testContext) sendV6Packet(payload []byte, h *headers, linkEpName string) { +func (c *testContext) sendV6Packet(payload []byte, h *headers, linkEpID tcpip.NICID) { // Allocate a buffer for data and headers. buf := buffer.NewView(header.UDPMinimumSize + header.IPv6MinimumSize + len(payload)) copy(buf[len(buf)-len(payload):], payload) @@ -153,7 +150,7 @@ func (c *testContext) sendV6Packet(payload []byte, h *headers, linkEpName string u.SetChecksum(^u.CalculateChecksum(xsum)) // Inject packet. - c.linkEPs[linkEpName].InjectInbound(ipv6.ProtocolNumber, tcpip.PacketBuffer{ + c.linkEps[linkEpID].InjectInbound(ipv6.ProtocolNumber, tcpip.PacketBuffer{ Data: buf.ToVectorisedView(), }) } @@ -183,7 +180,7 @@ func TestTransportDemuxerRegister(t *testing.T) { func TestDistribution(t *testing.T) { type endpointSockopts struct { reuse int - bindToDevice string + bindToDevice tcpip.NICID } for _, test := range []struct { name string @@ -191,71 +188,71 @@ func TestDistribution(t *testing.T) { endpoints []endpointSockopts // wantedDistribution is the wanted ratio of packets received on each // endpoint for each NIC on which packets are injected. - wantedDistributions map[string][]float64 + wantedDistributions map[tcpip.NICID][]float64 }{ { "BindPortReuse", // 5 endpoints that all have reuse set. []endpointSockopts{ - {1, ""}, - {1, ""}, - {1, ""}, - {1, ""}, - {1, ""}, + {1, 0}, + {1, 0}, + {1, 0}, + {1, 0}, + {1, 0}, }, - map[string][]float64{ + map[tcpip.NICID][]float64{ // Injected packets on dev0 get distributed evenly. - "dev0": {0.2, 0.2, 0.2, 0.2, 0.2}, + 1: {0.2, 0.2, 0.2, 0.2, 0.2}, }, }, { "BindToDevice", // 3 endpoints with various bindings. []endpointSockopts{ - {0, "dev0"}, - {0, "dev1"}, - {0, "dev2"}, + {0, 1}, + {0, 2}, + {0, 3}, }, - map[string][]float64{ + map[tcpip.NICID][]float64{ // Injected packets on dev0 go only to the endpoint bound to dev0. - "dev0": {1, 0, 0}, + 1: {1, 0, 0}, // Injected packets on dev1 go only to the endpoint bound to dev1. - "dev1": {0, 1, 0}, + 2: {0, 1, 0}, // Injected packets on dev2 go only to the endpoint bound to dev2. - "dev2": {0, 0, 1}, + 3: {0, 0, 1}, }, }, { "ReuseAndBindToDevice", // 6 endpoints with various bindings. []endpointSockopts{ - {1, "dev0"}, - {1, "dev0"}, - {1, "dev1"}, - {1, "dev1"}, - {1, "dev1"}, - {1, ""}, + {1, 1}, + {1, 1}, + {1, 2}, + {1, 2}, + {1, 2}, + {1, 0}, }, - map[string][]float64{ + map[tcpip.NICID][]float64{ // Injected packets on dev0 get distributed among endpoints bound to // dev0. - "dev0": {0.5, 0.5, 0, 0, 0, 0}, + 1: {0.5, 0.5, 0, 0, 0, 0}, // Injected packets on dev1 get distributed among endpoints bound to // dev1 or unbound. - "dev1": {0, 0, 1. / 3, 1. / 3, 1. / 3, 0}, + 2: {0, 0, 1. / 3, 1. / 3, 1. / 3, 0}, // Injected packets on dev999 go only to the unbound. - "dev999": {0, 0, 0, 0, 0, 1}, + 1000: {0, 0, 0, 0, 0, 1}, }, }, } { t.Run(test.name, func(t *testing.T) { for device, wantedDistribution := range test.wantedDistributions { - t.Run(device, func(t *testing.T) { - var devices []string + t.Run(string(device), func(t *testing.T) { + var devices []tcpip.NICID for d := range test.wantedDistributions { devices = append(devices, d) } - c := newDualTestContextMultiNic(t, defaultMTU, devices) + c := newDualTestContextMultiNIC(t, defaultMTU, devices) defer c.cleanup() c.createV6Endpoint(false) |