diff options
Diffstat (limited to 'pkg/tcpip/stack')
30 files changed, 5012 insertions, 5877 deletions
diff --git a/pkg/tcpip/stack/BUILD b/pkg/tcpip/stack/BUILD index 1c58bed2d..2eaeab779 100644 --- a/pkg/tcpip/stack/BUILD +++ b/pkg/tcpip/stack/BUILD @@ -54,9 +54,10 @@ go_template_instance( go_library( name = "stack", srcs = [ + "addressable_endpoint_state.go", "conntrack.go", - "dhcpv6configurationfromndpra_string.go", "forwarder.go", + "headertype_string.go", "icmp_rate_limit.go", "iptables.go", "iptables_state.go", @@ -64,7 +65,6 @@ go_library( "iptables_types.go", "linkaddrcache.go", "linkaddrentry_list.go", - "ndp.go", "neighbor_cache.go", "neighbor_entry.go", "neighbor_entry_list.go", @@ -105,6 +105,7 @@ go_test( name = "stack_x_test", size = "medium", srcs = [ + "addressable_endpoint_state_test.go", "ndp_test.go", "nud_test.go", "stack_test.go", @@ -115,12 +116,14 @@ go_test( deps = [ ":stack", "//pkg/rand", + "//pkg/sync", "//pkg/tcpip", "//pkg/tcpip/buffer", "//pkg/tcpip/checker", "//pkg/tcpip/header", "//pkg/tcpip/link/channel", "//pkg/tcpip/link/loopback", + "//pkg/tcpip/network/arp", "//pkg/tcpip/network/ipv4", "//pkg/tcpip/network/ipv6", "//pkg/tcpip/ports", @@ -136,12 +139,12 @@ go_test( name = "stack_test", size = "small", srcs = [ - "fake_time_test.go", "forwarder_test.go", "linkaddrcache_test.go", "neighbor_cache_test.go", "neighbor_entry_test.go", "nic_test.go", + "packet_buffer_test.go", ], library = ":stack", deps = [ @@ -149,8 +152,8 @@ go_test( "//pkg/sync", "//pkg/tcpip", "//pkg/tcpip/buffer", + "//pkg/tcpip/faketime", "//pkg/tcpip/header", - "@com_github_dpjacques_clockwork//:go_default_library", "@com_github_google_go_cmp//cmp:go_default_library", "@com_github_google_go_cmp//cmp/cmpopts:go_default_library", ], diff --git a/pkg/tcpip/stack/addressable_endpoint_state.go b/pkg/tcpip/stack/addressable_endpoint_state.go new file mode 100644 index 000000000..db8ac1c2b --- /dev/null +++ b/pkg/tcpip/stack/addressable_endpoint_state.go @@ -0,0 +1,758 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package stack + +import ( + "fmt" + + "gvisor.dev/gvisor/pkg/sync" + "gvisor.dev/gvisor/pkg/tcpip" +) + +var _ GroupAddressableEndpoint = (*AddressableEndpointState)(nil) +var _ AddressableEndpoint = (*AddressableEndpointState)(nil) + +// AddressableEndpointState is an implementation of an AddressableEndpoint. +type AddressableEndpointState struct { + networkEndpoint NetworkEndpoint + + // Lock ordering (from outer to inner lock ordering): + // + // AddressableEndpointState.mu + // addressState.mu + mu struct { + sync.RWMutex + + endpoints map[tcpip.Address]*addressState + primary []*addressState + + // groups holds the mapping between group addresses and the number of times + // they have been joined. + groups map[tcpip.Address]uint32 + } +} + +// Init initializes the AddressableEndpointState with networkEndpoint. +// +// Must be called before calling any other function on m. +func (a *AddressableEndpointState) Init(networkEndpoint NetworkEndpoint) { + a.networkEndpoint = networkEndpoint + + a.mu.Lock() + defer a.mu.Unlock() + a.mu.endpoints = make(map[tcpip.Address]*addressState) + a.mu.groups = make(map[tcpip.Address]uint32) +} + +// ReadOnlyAddressableEndpointState provides read-only access to an +// AddressableEndpointState. +type ReadOnlyAddressableEndpointState struct { + inner *AddressableEndpointState +} + +// AddrOrMatching returns an endpoint for the passed address that is consisdered +// bound to the wrapped AddressableEndpointState. +// +// If addr is an exact match with an existing address, that address is returned. +// Otherwise, f is called with each address and the address that f returns true +// for is returned. +// +// Returns nil of no address matches. +func (m ReadOnlyAddressableEndpointState) AddrOrMatching(addr tcpip.Address, spoofingOrPrimiscuous bool, f func(AddressEndpoint) bool) AddressEndpoint { + m.inner.mu.RLock() + defer m.inner.mu.RUnlock() + + if ep, ok := m.inner.mu.endpoints[addr]; ok { + if ep.IsAssigned(spoofingOrPrimiscuous) && ep.IncRef() { + return ep + } + } + + for _, ep := range m.inner.mu.endpoints { + if ep.IsAssigned(spoofingOrPrimiscuous) && f(ep) && ep.IncRef() { + return ep + } + } + + return nil +} + +// Lookup returns the AddressEndpoint for the passed address. +// +// Returns nil if the passed address is not associated with the +// AddressableEndpointState. +func (m ReadOnlyAddressableEndpointState) Lookup(addr tcpip.Address) AddressEndpoint { + m.inner.mu.RLock() + defer m.inner.mu.RUnlock() + + ep, ok := m.inner.mu.endpoints[addr] + if !ok { + return nil + } + return ep +} + +// ForEach calls f for each address pair. +// +// If f returns false, f is no longer be called. +func (m ReadOnlyAddressableEndpointState) ForEach(f func(AddressEndpoint) bool) { + m.inner.mu.RLock() + defer m.inner.mu.RUnlock() + + for _, ep := range m.inner.mu.endpoints { + if !f(ep) { + return + } + } +} + +// ForEachPrimaryEndpoint calls f for each primary address. +// +// If f returns false, f is no longer be called. +func (m ReadOnlyAddressableEndpointState) ForEachPrimaryEndpoint(f func(AddressEndpoint)) { + m.inner.mu.RLock() + defer m.inner.mu.RUnlock() + for _, ep := range m.inner.mu.primary { + f(ep) + } +} + +// ReadOnly returns a readonly reference to a. +func (a *AddressableEndpointState) ReadOnly() ReadOnlyAddressableEndpointState { + return ReadOnlyAddressableEndpointState{inner: a} +} + +func (a *AddressableEndpointState) releaseAddressState(addrState *addressState) { + a.mu.Lock() + defer a.mu.Unlock() + a.releaseAddressStateLocked(addrState) +} + +// releaseAddressState removes addrState from s's address state (primary and endpoints list). +// +// Preconditions: a.mu must be write locked. +func (a *AddressableEndpointState) releaseAddressStateLocked(addrState *addressState) { + oldPrimary := a.mu.primary + for i, s := range a.mu.primary { + if s == addrState { + a.mu.primary = append(a.mu.primary[:i], a.mu.primary[i+1:]...) + oldPrimary[len(oldPrimary)-1] = nil + break + } + } + delete(a.mu.endpoints, addrState.addr.Address) +} + +// AddAndAcquirePermanentAddress implements AddressableEndpoint. +func (a *AddressableEndpointState) AddAndAcquirePermanentAddress(addr tcpip.AddressWithPrefix, peb PrimaryEndpointBehavior, configType AddressConfigType, deprecated bool) (AddressEndpoint, *tcpip.Error) { + a.mu.Lock() + defer a.mu.Unlock() + ep, err := a.addAndAcquireAddressLocked(addr, peb, configType, deprecated, true /* permanent */) + // From https://golang.org/doc/faq#nil_error: + // + // Under the covers, interfaces are implemented as two elements, a type T and + // a value V. + // + // An interface value is nil only if the V and T are both unset, (T=nil, V is + // not set), In particular, a nil interface will always hold a nil type. If we + // store a nil pointer of type *int inside an interface value, the inner type + // will be *int regardless of the value of the pointer: (T=*int, V=nil). Such + // an interface value will therefore be non-nil even when the pointer value V + // inside is nil. + // + // Since addAndAcquireAddressLocked returns a nil value with a non-nil type, + // we need to explicitly return nil below if ep is (a typed) nil. + if ep == nil { + return nil, err + } + return ep, err +} + +// AddAndAcquireTemporaryAddress adds a temporary address. +// +// Returns tcpip.ErrDuplicateAddress if the address exists. +// +// The temporary address's endpoint is acquired and returned. +func (a *AddressableEndpointState) AddAndAcquireTemporaryAddress(addr tcpip.AddressWithPrefix, peb PrimaryEndpointBehavior) (AddressEndpoint, *tcpip.Error) { + a.mu.Lock() + defer a.mu.Unlock() + ep, err := a.addAndAcquireAddressLocked(addr, peb, AddressConfigStatic, false /* deprecated */, false /* permanent */) + // From https://golang.org/doc/faq#nil_error: + // + // Under the covers, interfaces are implemented as two elements, a type T and + // a value V. + // + // An interface value is nil only if the V and T are both unset, (T=nil, V is + // not set), In particular, a nil interface will always hold a nil type. If we + // store a nil pointer of type *int inside an interface value, the inner type + // will be *int regardless of the value of the pointer: (T=*int, V=nil). Such + // an interface value will therefore be non-nil even when the pointer value V + // inside is nil. + // + // Since addAndAcquireAddressLocked returns a nil value with a non-nil type, + // we need to explicitly return nil below if ep is (a typed) nil. + if ep == nil { + return nil, err + } + return ep, err +} + +// addAndAcquireAddressLocked adds, acquires and returns a permanent or +// temporary address. +// +// If the addressable endpoint already has the address in a non-permanent state, +// and addAndAcquireAddressLocked is adding a permanent address, that address is +// promoted in place and its properties set to the properties provided. If the +// address already exists in any other state, then tcpip.ErrDuplicateAddress is +// returned, regardless the kind of address that is being added. +// +// Precondition: a.mu must be write locked. +func (a *AddressableEndpointState) addAndAcquireAddressLocked(addr tcpip.AddressWithPrefix, peb PrimaryEndpointBehavior, configType AddressConfigType, deprecated, permanent bool) (*addressState, *tcpip.Error) { + // attemptAddToPrimary is false when the address is already in the primary + // address list. + attemptAddToPrimary := true + addrState, ok := a.mu.endpoints[addr.Address] + if ok { + if !permanent { + // We are adding a non-permanent address but the address exists. No need + // to go any further since we can only promote existing temporary/expired + // addresses to permanent. + return nil, tcpip.ErrDuplicateAddress + } + + addrState.mu.Lock() + if addrState.mu.kind.IsPermanent() { + addrState.mu.Unlock() + // We are adding a permanent address but a permanent address already + // exists. + return nil, tcpip.ErrDuplicateAddress + } + + if addrState.mu.refs == 0 { + panic(fmt.Sprintf("found an address that should have been released (ref count == 0); address = %s", addrState.addr)) + } + + // We now promote the address. + for i, s := range a.mu.primary { + if s == addrState { + switch peb { + case CanBePrimaryEndpoint: + // The address is already in the primary address list. + attemptAddToPrimary = false + case FirstPrimaryEndpoint: + if i == 0 { + // The address is already first in the primary address list. + attemptAddToPrimary = false + } else { + a.mu.primary = append(a.mu.primary[:i], a.mu.primary[i+1:]...) + } + case NeverPrimaryEndpoint: + a.mu.primary = append(a.mu.primary[:i], a.mu.primary[i+1:]...) + default: + panic(fmt.Sprintf("unrecognized primary endpoint behaviour = %d", peb)) + } + break + } + } + } + + if addrState == nil { + addrState = &addressState{ + addressableEndpointState: a, + addr: addr, + } + a.mu.endpoints[addr.Address] = addrState + addrState.mu.Lock() + // We never promote an address to temporary - it can only be added as such. + // If we are actaully adding a permanent address, it is promoted below. + addrState.mu.kind = Temporary + } + + // At this point we have an address we are either promoting from an expired or + // temporary address to permanent, promoting an expired address to temporary, + // or we are adding a new temporary or permanent address. + // + // The address MUST be write locked at this point. + defer addrState.mu.Unlock() + + if permanent { + if addrState.mu.kind.IsPermanent() { + panic(fmt.Sprintf("only non-permanent addresses should be promoted to permanent; address = %s", addrState.addr)) + } + + // Primary addresses are biased by 1. + addrState.mu.refs++ + addrState.mu.kind = Permanent + } + // Acquire the address before returning it. + addrState.mu.refs++ + addrState.mu.deprecated = deprecated + addrState.mu.configType = configType + + if attemptAddToPrimary { + switch peb { + case NeverPrimaryEndpoint: + case CanBePrimaryEndpoint: + a.mu.primary = append(a.mu.primary, addrState) + case FirstPrimaryEndpoint: + if cap(a.mu.primary) == len(a.mu.primary) { + a.mu.primary = append([]*addressState{addrState}, a.mu.primary...) + } else { + // Shift all the endpoints by 1 to make room for the new address at the + // front. We could have just created a new slice but this saves + // allocations when the slice has capacity for the new address. + primaryCount := len(a.mu.primary) + a.mu.primary = append(a.mu.primary, nil) + if n := copy(a.mu.primary[1:], a.mu.primary); n != primaryCount { + panic(fmt.Sprintf("copied %d elements; expected = %d elements", n, primaryCount)) + } + a.mu.primary[0] = addrState + } + default: + panic(fmt.Sprintf("unrecognized primary endpoint behaviour = %d", peb)) + } + } + + return addrState, nil +} + +// RemovePermanentAddress implements AddressableEndpoint. +func (a *AddressableEndpointState) RemovePermanentAddress(addr tcpip.Address) *tcpip.Error { + a.mu.Lock() + defer a.mu.Unlock() + + if _, ok := a.mu.groups[addr]; ok { + panic(fmt.Sprintf("group address = %s must be removed with LeaveGroup", addr)) + } + + return a.removePermanentAddressLocked(addr) +} + +// removePermanentAddressLocked is like RemovePermanentAddress but with locking +// requirements. +// +// Precondition: a.mu must be write locked. +func (a *AddressableEndpointState) removePermanentAddressLocked(addr tcpip.Address) *tcpip.Error { + addrState, ok := a.mu.endpoints[addr] + if !ok { + return tcpip.ErrBadLocalAddress + } + + return a.removePermanentEndpointLocked(addrState) +} + +// RemovePermanentEndpoint removes the passed endpoint if it is associated with +// a and permanent. +func (a *AddressableEndpointState) RemovePermanentEndpoint(ep AddressEndpoint) *tcpip.Error { + addrState, ok := ep.(*addressState) + if !ok || addrState.addressableEndpointState != a { + return tcpip.ErrInvalidEndpointState + } + + return a.removePermanentEndpointLocked(addrState) +} + +// removePermanentAddressLocked is like RemovePermanentAddress but with locking +// requirements. +// +// Precondition: a.mu must be write locked. +func (a *AddressableEndpointState) removePermanentEndpointLocked(addrState *addressState) *tcpip.Error { + if !addrState.GetKind().IsPermanent() { + return tcpip.ErrBadLocalAddress + } + + addrState.SetKind(PermanentExpired) + a.decAddressRefLocked(addrState) + return nil +} + +// decAddressRef decrements the address's reference count and releases it once +// the reference count hits 0. +func (a *AddressableEndpointState) decAddressRef(addrState *addressState) { + a.mu.Lock() + defer a.mu.Unlock() + a.decAddressRefLocked(addrState) +} + +// decAddressRefLocked is like decAddressRef but with locking requirements. +// +// Precondition: a.mu must be write locked. +func (a *AddressableEndpointState) decAddressRefLocked(addrState *addressState) { + addrState.mu.Lock() + defer addrState.mu.Unlock() + + if addrState.mu.refs == 0 { + panic(fmt.Sprintf("attempted to decrease ref count for AddressEndpoint w/ addr = %s when it is already released", addrState.addr)) + } + + addrState.mu.refs-- + + if addrState.mu.refs != 0 { + return + } + + // A non-expired permanent address must not have its reference count dropped + // to 0. + if addrState.mu.kind.IsPermanent() { + panic(fmt.Sprintf("permanent addresses should be removed through the AddressableEndpoint: addr = %s, kind = %d", addrState.addr, addrState.mu.kind)) + } + + a.releaseAddressStateLocked(addrState) +} + +// MainAddress implements AddressableEndpoint. +func (a *AddressableEndpointState) MainAddress() tcpip.AddressWithPrefix { + a.mu.RLock() + defer a.mu.RUnlock() + + ep := a.acquirePrimaryAddressRLocked(func(ep *addressState) bool { + return ep.GetKind() == Permanent + }) + if ep == nil { + return tcpip.AddressWithPrefix{} + } + + addr := ep.AddressWithPrefix() + a.decAddressRefLocked(ep) + return addr +} + +// acquirePrimaryAddressRLocked returns an acquired primary address that is +// valid according to isValid. +// +// Precondition: e.mu must be read locked +func (a *AddressableEndpointState) acquirePrimaryAddressRLocked(isValid func(*addressState) bool) *addressState { + var deprecatedEndpoint *addressState + for _, ep := range a.mu.primary { + if !isValid(ep) { + continue + } + + if !ep.Deprecated() { + if ep.IncRef() { + // ep is not deprecated, so return it immediately. + // + // If we kept track of a deprecated endpoint, decrement its reference + // count since it was incremented when we decided to keep track of it. + if deprecatedEndpoint != nil { + a.decAddressRefLocked(deprecatedEndpoint) + deprecatedEndpoint = nil + } + + return ep + } + } else if deprecatedEndpoint == nil && ep.IncRef() { + // We prefer an endpoint that is not deprecated, but we keep track of + // ep in case a doesn't have any non-deprecated endpoints. + // + // If we end up finding a more preferred endpoint, ep's reference count + // will be decremented. + deprecatedEndpoint = ep + } + } + + return deprecatedEndpoint +} + +// AcquireAssignedAddress implements AddressableEndpoint. +func (a *AddressableEndpointState) AcquireAssignedAddress(localAddr tcpip.Address, allowTemp bool, tempPEB PrimaryEndpointBehavior) AddressEndpoint { + a.mu.Lock() + defer a.mu.Unlock() + + if addrState, ok := a.mu.endpoints[localAddr]; ok { + if !addrState.IsAssigned(allowTemp) { + return nil + } + + if !addrState.IncRef() { + panic(fmt.Sprintf("failed to increase the reference count for address = %s", addrState.addr)) + } + + return addrState + } + + if !allowTemp { + return nil + } + + addr := localAddr.WithPrefix() + ep, err := a.addAndAcquireAddressLocked(addr, tempPEB, AddressConfigStatic, false /* deprecated */, false /* permanent */) + if err != nil { + // addAndAcquireAddressLocked only returns an error if the address is + // already assigned but we just checked above if the address exists so we + // expect no error. + panic(fmt.Sprintf("a.addAndAcquireAddressLocked(%s, %d, %d, false, false): %s", addr, tempPEB, AddressConfigStatic, err)) + } + // From https://golang.org/doc/faq#nil_error: + // + // Under the covers, interfaces are implemented as two elements, a type T and + // a value V. + // + // An interface value is nil only if the V and T are both unset, (T=nil, V is + // not set), In particular, a nil interface will always hold a nil type. If we + // store a nil pointer of type *int inside an interface value, the inner type + // will be *int regardless of the value of the pointer: (T=*int, V=nil). Such + // an interface value will therefore be non-nil even when the pointer value V + // inside is nil. + // + // Since addAndAcquireAddressLocked returns a nil value with a non-nil type, + // we need to explicitly return nil below if ep is (a typed) nil. + if ep == nil { + return nil + } + return ep +} + +// AcquireOutgoingPrimaryAddress implements AddressableEndpoint. +func (a *AddressableEndpointState) AcquireOutgoingPrimaryAddress(remoteAddr tcpip.Address, allowExpired bool) AddressEndpoint { + a.mu.RLock() + defer a.mu.RUnlock() + + ep := a.acquirePrimaryAddressRLocked(func(ep *addressState) bool { + return ep.IsAssigned(allowExpired) + }) + + // From https://golang.org/doc/faq#nil_error: + // + // Under the covers, interfaces are implemented as two elements, a type T and + // a value V. + // + // An interface value is nil only if the V and T are both unset, (T=nil, V is + // not set), In particular, a nil interface will always hold a nil type. If we + // store a nil pointer of type *int inside an interface value, the inner type + // will be *int regardless of the value of the pointer: (T=*int, V=nil). Such + // an interface value will therefore be non-nil even when the pointer value V + // inside is nil. + // + // Since acquirePrimaryAddressRLocked returns a nil value with a non-nil type, + // we need to explicitly return nil below if ep is (a typed) nil. + if ep == nil { + return nil + } + + return ep +} + +// PrimaryAddresses implements AddressableEndpoint. +func (a *AddressableEndpointState) PrimaryAddresses() []tcpip.AddressWithPrefix { + a.mu.RLock() + defer a.mu.RUnlock() + + var addrs []tcpip.AddressWithPrefix + for _, ep := range a.mu.primary { + // Don't include tentative, expired or temporary endpoints + // to avoid confusion and prevent the caller from using + // those. + switch ep.GetKind() { + case PermanentTentative, PermanentExpired, Temporary: + continue + } + + addrs = append(addrs, ep.AddressWithPrefix()) + } + + return addrs +} + +// PermanentAddresses implements AddressableEndpoint. +func (a *AddressableEndpointState) PermanentAddresses() []tcpip.AddressWithPrefix { + a.mu.RLock() + defer a.mu.RUnlock() + + var addrs []tcpip.AddressWithPrefix + for _, ep := range a.mu.endpoints { + if !ep.GetKind().IsPermanent() { + continue + } + + addrs = append(addrs, ep.AddressWithPrefix()) + } + + return addrs +} + +// JoinGroup implements GroupAddressableEndpoint. +func (a *AddressableEndpointState) JoinGroup(group tcpip.Address) (bool, *tcpip.Error) { + a.mu.Lock() + defer a.mu.Unlock() + + joins, ok := a.mu.groups[group] + if !ok { + ep, err := a.addAndAcquireAddressLocked(group.WithPrefix(), NeverPrimaryEndpoint, AddressConfigStatic, false /* deprecated */, true /* permanent */) + if err != nil { + return false, err + } + // We have no need for the address endpoint. + a.decAddressRefLocked(ep) + } + + a.mu.groups[group] = joins + 1 + return !ok, nil +} + +// LeaveGroup implements GroupAddressableEndpoint. +func (a *AddressableEndpointState) LeaveGroup(group tcpip.Address) (bool, *tcpip.Error) { + a.mu.Lock() + defer a.mu.Unlock() + + joins, ok := a.mu.groups[group] + if !ok { + return false, tcpip.ErrBadLocalAddress + } + + if joins == 1 { + a.removeGroupAddressLocked(group) + delete(a.mu.groups, group) + return true, nil + } + + a.mu.groups[group] = joins - 1 + return false, nil +} + +// IsInGroup implements GroupAddressableEndpoint. +func (a *AddressableEndpointState) IsInGroup(group tcpip.Address) bool { + a.mu.RLock() + defer a.mu.RUnlock() + _, ok := a.mu.groups[group] + return ok +} + +func (a *AddressableEndpointState) removeGroupAddressLocked(group tcpip.Address) { + if err := a.removePermanentAddressLocked(group); err != nil { + // removePermanentEndpointLocked would only return an error if group is + // not bound to the addressable endpoint, but we know it MUST be assigned + // since we have group in our map of groups. + panic(fmt.Sprintf("error removing group address = %s: %s", group, err)) + } +} + +// Cleanup forcefully leaves all groups and removes all permanent addresses. +func (a *AddressableEndpointState) Cleanup() { + a.mu.Lock() + defer a.mu.Unlock() + + for group := range a.mu.groups { + a.removeGroupAddressLocked(group) + } + a.mu.groups = make(map[tcpip.Address]uint32) + + for _, ep := range a.mu.endpoints { + // removePermanentEndpointLocked returns tcpip.ErrBadLocalAddress if ep is + // not a permanent address. + if err := a.removePermanentEndpointLocked(ep); err != nil && err != tcpip.ErrBadLocalAddress { + panic(fmt.Sprintf("unexpected error from removePermanentEndpointLocked(%s): %s", ep.addr, err)) + } + } +} + +var _ AddressEndpoint = (*addressState)(nil) + +// addressState holds state for an address. +type addressState struct { + addressableEndpointState *AddressableEndpointState + addr tcpip.AddressWithPrefix + + // Lock ordering (from outer to inner lock ordering): + // + // AddressableEndpointState.mu + // addressState.mu + mu struct { + sync.RWMutex + + refs uint32 + kind AddressKind + configType AddressConfigType + deprecated bool + } +} + +// NetworkEndpoint implements AddressEndpoint. +func (a *addressState) NetworkEndpoint() NetworkEndpoint { + return a.addressableEndpointState.networkEndpoint +} + +// AddressWithPrefix implements AddressEndpoint. +func (a *addressState) AddressWithPrefix() tcpip.AddressWithPrefix { + return a.addr +} + +// GetKind implements AddressEndpoint. +func (a *addressState) GetKind() AddressKind { + a.mu.RLock() + defer a.mu.RUnlock() + return a.mu.kind +} + +// SetKind implements AddressEndpoint. +func (a *addressState) SetKind(kind AddressKind) { + a.mu.Lock() + defer a.mu.Unlock() + a.mu.kind = kind +} + +// IsAssigned implements AddressEndpoint. +func (a *addressState) IsAssigned(allowExpired bool) bool { + if !a.addressableEndpointState.networkEndpoint.Enabled() { + return false + } + + switch a.GetKind() { + case PermanentTentative: + return false + case PermanentExpired: + return allowExpired + default: + return true + } +} + +// IncRef implements AddressEndpoint. +func (a *addressState) IncRef() bool { + a.mu.Lock() + defer a.mu.Unlock() + if a.mu.refs == 0 { + return false + } + + a.mu.refs++ + return true +} + +// DecRef implements AddressEndpoint. +func (a *addressState) DecRef() { + a.addressableEndpointState.decAddressRef(a) +} + +// ConfigType implements AddressEndpoint. +func (a *addressState) ConfigType() AddressConfigType { + a.mu.RLock() + defer a.mu.RUnlock() + return a.mu.configType +} + +// SetDeprecated implements AddressEndpoint. +func (a *addressState) SetDeprecated(d bool) { + a.mu.Lock() + defer a.mu.Unlock() + a.mu.deprecated = d +} + +// Deprecated implements AddressEndpoint. +func (a *addressState) Deprecated() bool { + a.mu.RLock() + defer a.mu.RUnlock() + return a.mu.deprecated +} diff --git a/pkg/tcpip/stack/addressable_endpoint_state_test.go b/pkg/tcpip/stack/addressable_endpoint_state_test.go new file mode 100644 index 000000000..26787d0a3 --- /dev/null +++ b/pkg/tcpip/stack/addressable_endpoint_state_test.go @@ -0,0 +1,77 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package stack_test + +import ( + "testing" + + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/stack" +) + +// TestAddressableEndpointStateCleanup tests that cleaning up an addressable +// endpoint state removes permanent addresses and leaves groups. +func TestAddressableEndpointStateCleanup(t *testing.T) { + var ep fakeNetworkEndpoint + if err := ep.Enable(); err != nil { + t.Fatalf("ep.Enable(): %s", err) + } + + var s stack.AddressableEndpointState + s.Init(&ep) + + addr := tcpip.AddressWithPrefix{ + Address: "\x01", + PrefixLen: 8, + } + + { + ep, err := s.AddAndAcquirePermanentAddress(addr, stack.NeverPrimaryEndpoint, stack.AddressConfigStatic, false /* deprecated */) + if err != nil { + t.Fatalf("s.AddAndAcquirePermanentAddress(%s, %d, %d, false): %s", addr, stack.NeverPrimaryEndpoint, stack.AddressConfigStatic, err) + } + // We don't need the address endpoint. + ep.DecRef() + } + { + ep := s.AcquireAssignedAddress(addr.Address, false /* allowTemp */, stack.NeverPrimaryEndpoint) + if ep == nil { + t.Fatalf("got s.AcquireAssignedAddress(%s, false, NeverPrimaryEndpoint) = nil, want = non-nil", addr.Address) + } + ep.DecRef() + } + + group := tcpip.Address("\x02") + if added, err := s.JoinGroup(group); err != nil { + t.Fatalf("s.JoinGroup(%s): %s", group, err) + } else if !added { + t.Fatalf("got s.JoinGroup(%s) = false, want = true", group) + } + if !s.IsInGroup(group) { + t.Fatalf("got s.IsInGroup(%s) = false, want = true", group) + } + + s.Cleanup() + { + ep := s.AcquireAssignedAddress(addr.Address, false /* allowTemp */, stack.NeverPrimaryEndpoint) + if ep != nil { + ep.DecRef() + t.Fatalf("got s.AcquireAssignedAddress(%s, false, NeverPrimaryEndpoint) = %s, want = nil", addr.Address, ep.AddressWithPrefix()) + } + } + if s.IsInGroup(group) { + t.Fatalf("got s.IsInGroup(%s) = true, want = false", group) + } +} diff --git a/pkg/tcpip/stack/conntrack.go b/pkg/tcpip/stack/conntrack.go index 470c265aa..457f0c89b 100644 --- a/pkg/tcpip/stack/conntrack.go +++ b/pkg/tcpip/stack/conntrack.go @@ -199,12 +199,12 @@ type bucket struct { func packetToTupleID(pkt *PacketBuffer) (tupleID, *tcpip.Error) { // TODO(gvisor.dev/issue/170): Need to support for other // protocols as well. - netHeader := header.IPv4(pkt.NetworkHeader) - if netHeader == nil || netHeader.TransportProtocol() != header.TCPProtocolNumber { + netHeader := header.IPv4(pkt.NetworkHeader().View()) + if len(netHeader) < header.IPv4MinimumSize || netHeader.TransportProtocol() != header.TCPProtocolNumber { return tupleID{}, tcpip.ErrUnknownProtocol } - tcpHeader := header.TCP(pkt.TransportHeader) - if tcpHeader == nil { + tcpHeader := header.TCP(pkt.TransportHeader().View()) + if len(tcpHeader) < header.TCPMinimumSize { return tupleID{}, tcpip.ErrUnknownProtocol } @@ -268,7 +268,7 @@ func (ct *ConnTrack) connForTID(tid tupleID) (*conn, direction) { return nil, dirOriginal } -func (ct *ConnTrack) insertRedirectConn(pkt *PacketBuffer, hook Hook, rt RedirectTarget) *conn { +func (ct *ConnTrack) insertRedirectConn(pkt *PacketBuffer, hook Hook, rt *RedirectTarget) *conn { tid, err := packetToTupleID(pkt) if err != nil { return nil @@ -281,8 +281,8 @@ func (ct *ConnTrack) insertRedirectConn(pkt *PacketBuffer, hook Hook, rt Redirec // rule. This tuple will be used to manipulate the packet in // handlePacket. replyTID := tid.reply() - replyTID.srcAddr = rt.MinIP - replyTID.srcPort = rt.MinPort + replyTID.srcAddr = rt.Addr + replyTID.srcPort = rt.Port var manip manipType switch hook { case Prerouting: @@ -344,8 +344,8 @@ func handlePacketPrerouting(pkt *PacketBuffer, conn *conn, dir direction) { return } - netHeader := header.IPv4(pkt.NetworkHeader) - tcpHeader := header.TCP(pkt.TransportHeader) + netHeader := header.IPv4(pkt.NetworkHeader().View()) + tcpHeader := header.TCP(pkt.TransportHeader().View()) // For prerouting redirection, packets going in the original direction // have their destinations modified and replies have their sources @@ -377,8 +377,8 @@ func handlePacketOutput(pkt *PacketBuffer, conn *conn, gso *GSO, r *Route, dir d return } - netHeader := header.IPv4(pkt.NetworkHeader) - tcpHeader := header.TCP(pkt.TransportHeader) + netHeader := header.IPv4(pkt.NetworkHeader().View()) + tcpHeader := header.TCP(pkt.TransportHeader().View()) // For output redirection, packets going in the original direction // have their destinations modified and replies have their sources @@ -396,8 +396,7 @@ func handlePacketOutput(pkt *PacketBuffer, conn *conn, gso *GSO, r *Route, dir d // Calculate the TCP checksum and set it. tcpHeader.SetChecksum(0) - hdr := &pkt.Header - length := uint16(pkt.Data.Size()+hdr.UsedLength()) - uint16(netHeader.HeaderLength()) + length := uint16(pkt.Size()) - uint16(netHeader.HeaderLength()) xsum := r.PseudoHeaderChecksum(header.TCPProtocolNumber, length) if gso != nil && gso.NeedsCsum { tcpHeader.SetChecksum(xsum) @@ -423,7 +422,7 @@ func (ct *ConnTrack) handlePacket(pkt *PacketBuffer, hook Hook, gso *GSO, r *Rou } // TODO(gvisor.dev/issue/170): Support other transport protocols. - if pkt.NetworkHeader == nil || header.IPv4(pkt.NetworkHeader).TransportProtocol() != header.TCPProtocolNumber { + if nh := pkt.NetworkHeader().View(); nh.IsEmpty() || header.IPv4(nh).TransportProtocol() != header.TCPProtocolNumber { return false } @@ -433,8 +432,8 @@ func (ct *ConnTrack) handlePacket(pkt *PacketBuffer, hook Hook, gso *GSO, r *Rou return true } - tcpHeader := header.TCP(pkt.TransportHeader) - if tcpHeader == nil { + tcpHeader := header.TCP(pkt.TransportHeader().View()) + if len(tcpHeader) < header.TCPMinimumSize { return false } @@ -455,7 +454,7 @@ func (ct *ConnTrack) handlePacket(pkt *PacketBuffer, hook Hook, gso *GSO, r *Rou // Mark the connection as having been used recently so it isn't reaped. conn.lastUsed = time.Now() // Update connection state. - conn.updateLocked(header.TCP(pkt.TransportHeader), hook) + conn.updateLocked(header.TCP(pkt.TransportHeader().View()), hook) return false } @@ -474,7 +473,7 @@ func (ct *ConnTrack) maybeInsertNoop(pkt *PacketBuffer, hook Hook) { } // We only track TCP connections. - if pkt.NetworkHeader == nil || header.IPv4(pkt.NetworkHeader).TransportProtocol() != header.TCPProtocolNumber { + if nh := pkt.NetworkHeader().View(); nh.IsEmpty() || header.IPv4(nh).TransportProtocol() != header.TCPProtocolNumber { return } @@ -486,7 +485,7 @@ func (ct *ConnTrack) maybeInsertNoop(pkt *PacketBuffer, hook Hook) { return } conn := newConn(tid, tid.reply(), manipNone, hook) - conn.updateLocked(header.TCP(pkt.TransportHeader), hook) + conn.updateLocked(header.TCP(pkt.TransportHeader().View()), hook) ct.insertConn(conn) } @@ -573,7 +572,9 @@ func (ct *ConnTrack) reapUnused(start int, prevInterval time.Duration) (int, tim // reapTupleLocked tries to remove tuple and its reply from the table. It // returns whether the tuple's connection has timed out. // -// Preconditions: ct.mu is locked for reading and bucket is locked. +// Preconditions: +// * ct.mu is locked for reading. +// * bucket is locked. func (ct *ConnTrack) reapTupleLocked(tuple *tuple, bucket int, now time.Time) bool { if !tuple.conn.timedOut(now) { return false diff --git a/pkg/tcpip/stack/fake_time_test.go b/pkg/tcpip/stack/fake_time_test.go deleted file mode 100644 index 92c8cb534..000000000 --- a/pkg/tcpip/stack/fake_time_test.go +++ /dev/null @@ -1,209 +0,0 @@ -// Copyright 2020 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package stack - -import ( - "container/heap" - "sync" - "time" - - "github.com/dpjacques/clockwork" - "gvisor.dev/gvisor/pkg/tcpip" -) - -type fakeClock struct { - clock clockwork.FakeClock - - // mu protects the fields below. - mu sync.RWMutex - - // times is min-heap of times. A heap is used for quick retrieval of the next - // upcoming time of scheduled work. - times *timeHeap - - // waitGroups stores one WaitGroup for all work scheduled to execute at the - // same time via AfterFunc. This allows parallel execution of all functions - // passed to AfterFunc scheduled for the same time. - waitGroups map[time.Time]*sync.WaitGroup -} - -func newFakeClock() *fakeClock { - return &fakeClock{ - clock: clockwork.NewFakeClock(), - times: &timeHeap{}, - waitGroups: make(map[time.Time]*sync.WaitGroup), - } -} - -var _ tcpip.Clock = (*fakeClock)(nil) - -// NowNanoseconds implements tcpip.Clock.NowNanoseconds. -func (fc *fakeClock) NowNanoseconds() int64 { - return fc.clock.Now().UnixNano() -} - -// NowMonotonic implements tcpip.Clock.NowMonotonic. -func (fc *fakeClock) NowMonotonic() int64 { - return fc.NowNanoseconds() -} - -// AfterFunc implements tcpip.Clock.AfterFunc. -func (fc *fakeClock) AfterFunc(d time.Duration, f func()) tcpip.Timer { - until := fc.clock.Now().Add(d) - wg := fc.addWait(until) - return &fakeTimer{ - clock: fc, - until: until, - timer: fc.clock.AfterFunc(d, func() { - defer wg.Done() - f() - }), - } -} - -// addWait adds an additional wait to the WaitGroup for parallel execution of -// all work scheduled for t. Returns a reference to the WaitGroup modified. -func (fc *fakeClock) addWait(t time.Time) *sync.WaitGroup { - fc.mu.RLock() - wg, ok := fc.waitGroups[t] - fc.mu.RUnlock() - - if ok { - wg.Add(1) - return wg - } - - fc.mu.Lock() - heap.Push(fc.times, t) - fc.mu.Unlock() - - wg = &sync.WaitGroup{} - wg.Add(1) - - fc.mu.Lock() - fc.waitGroups[t] = wg - fc.mu.Unlock() - - return wg -} - -// removeWait removes a wait from the WaitGroup for parallel execution of all -// work scheduled for t. -func (fc *fakeClock) removeWait(t time.Time) { - fc.mu.RLock() - defer fc.mu.RUnlock() - - wg := fc.waitGroups[t] - wg.Done() -} - -// advance executes all work that have been scheduled to execute within d from -// the current fake time. Blocks until all work has completed execution. -func (fc *fakeClock) advance(d time.Duration) { - // Block until all the work is done - until := fc.clock.Now().Add(d) - for { - fc.mu.Lock() - if fc.times.Len() == 0 { - fc.mu.Unlock() - return - } - - t := heap.Pop(fc.times).(time.Time) - if t.After(until) { - // No work to do - heap.Push(fc.times, t) - fc.mu.Unlock() - return - } - fc.mu.Unlock() - - diff := t.Sub(fc.clock.Now()) - fc.clock.Advance(diff) - - fc.mu.RLock() - wg := fc.waitGroups[t] - fc.mu.RUnlock() - - wg.Wait() - - fc.mu.Lock() - delete(fc.waitGroups, t) - fc.mu.Unlock() - } -} - -type fakeTimer struct { - clock *fakeClock - timer clockwork.Timer - - mu sync.RWMutex - until time.Time -} - -var _ tcpip.Timer = (*fakeTimer)(nil) - -// Reset implements tcpip.Timer.Reset. -func (ft *fakeTimer) Reset(d time.Duration) { - if !ft.timer.Reset(d) { - return - } - - ft.mu.Lock() - defer ft.mu.Unlock() - - ft.clock.removeWait(ft.until) - ft.until = ft.clock.clock.Now().Add(d) - ft.clock.addWait(ft.until) -} - -// Stop implements tcpip.Timer.Stop. -func (ft *fakeTimer) Stop() bool { - if !ft.timer.Stop() { - return false - } - - ft.mu.RLock() - defer ft.mu.RUnlock() - - ft.clock.removeWait(ft.until) - return true -} - -type timeHeap []time.Time - -var _ heap.Interface = (*timeHeap)(nil) - -func (h timeHeap) Len() int { - return len(h) -} - -func (h timeHeap) Less(i, j int) bool { - return h[i].Before(h[j]) -} - -func (h timeHeap) Swap(i, j int) { - h[i], h[j] = h[j], h[i] -} - -func (h *timeHeap) Push(x interface{}) { - *h = append(*h, x.(time.Time)) -} - -func (h *timeHeap) Pop() interface{} { - last := (*h)[len(*h)-1] - *h = (*h)[:len(*h)-1] - return last -} diff --git a/pkg/tcpip/stack/forwarder_test.go b/pkg/tcpip/stack/forwarder_test.go index c962693f5..4e4b00a92 100644 --- a/pkg/tcpip/stack/forwarder_test.go +++ b/pkg/tcpip/stack/forwarder_test.go @@ -20,6 +20,7 @@ import ( "testing" "time" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" @@ -45,37 +46,37 @@ const ( // use the first three: destination address, source address, and transport // protocol. They're all one byte fields to simplify parsing. type fwdTestNetworkEndpoint struct { + AddressableEndpointState + nicID tcpip.NICID - id NetworkEndpointID - prefixLen int proto *fwdTestNetworkProtocol dispatcher TransportDispatcher ep LinkEndpoint } -func (f *fwdTestNetworkEndpoint) MTU() uint32 { - return f.ep.MTU() - uint32(f.MaxHeaderLength()) +var _ NetworkEndpoint = (*fwdTestNetworkEndpoint)(nil) + +func (*fwdTestNetworkEndpoint) Enable() *tcpip.Error { + return nil } -func (f *fwdTestNetworkEndpoint) NICID() tcpip.NICID { - return f.nicID +func (*fwdTestNetworkEndpoint) Enabled() bool { + return true } -func (f *fwdTestNetworkEndpoint) PrefixLen() int { - return f.prefixLen +func (*fwdTestNetworkEndpoint) Disable() {} + +func (f *fwdTestNetworkEndpoint) MTU() uint32 { + return f.ep.MTU() - uint32(f.MaxHeaderLength()) } func (*fwdTestNetworkEndpoint) DefaultTTL() uint8 { return 123 } -func (f *fwdTestNetworkEndpoint) ID() *NetworkEndpointID { - return &f.id -} - func (f *fwdTestNetworkEndpoint) HandlePacket(r *Route, pkt *PacketBuffer) { // Dispatch the packet to the transport protocol. - f.dispatcher.DeliverTransportPacket(r, tcpip.TransportProtocolNumber(pkt.NetworkHeader[protocolNumberOffset]), pkt) + f.dispatcher.DeliverTransportPacket(r, tcpip.TransportProtocolNumber(pkt.NetworkHeader().View()[protocolNumberOffset]), pkt) } func (f *fwdTestNetworkEndpoint) MaxHeaderLength() uint16 { @@ -86,10 +87,6 @@ func (f *fwdTestNetworkEndpoint) PseudoHeaderChecksum(protocol tcpip.TransportPr return 0 } -func (f *fwdTestNetworkEndpoint) Capabilities() LinkEndpointCapabilities { - return f.ep.Capabilities() -} - func (f *fwdTestNetworkEndpoint) NetworkProtocolNumber() tcpip.NetworkProtocolNumber { return f.proto.Number() } @@ -97,9 +94,9 @@ func (f *fwdTestNetworkEndpoint) NetworkProtocolNumber() tcpip.NetworkProtocolNu func (f *fwdTestNetworkEndpoint) WritePacket(r *Route, gso *GSO, params NetworkHeaderParams, pkt *PacketBuffer) *tcpip.Error { // Add the protocol's header to the packet and send it to the link // endpoint. - b := pkt.Header.Prepend(fwdTestNetHeaderLen) + b := pkt.NetworkHeader().Push(fwdTestNetHeaderLen) b[dstAddrOffset] = r.RemoteAddress[0] - b[srcAddrOffset] = f.id.LocalAddress[0] + b[srcAddrOffset] = r.LocalAddress[0] b[protocolNumberOffset] = byte(params.Protocol) return f.ep.WritePacket(r, gso, fwdTestNetNumber, pkt) @@ -114,17 +111,26 @@ func (*fwdTestNetworkEndpoint) WriteHeaderIncludedPacket(r *Route, pkt *PacketBu return tcpip.ErrNotSupported } -func (*fwdTestNetworkEndpoint) Close() {} +func (f *fwdTestNetworkEndpoint) Close() { + f.AddressableEndpointState.Cleanup() +} // fwdTestNetworkProtocol is a network-layer protocol that implements Address // resolution. type fwdTestNetworkProtocol struct { addrCache *linkAddrCache + neigh *neighborCache addrResolveDelay time.Duration - onLinkAddressResolved func(cache *linkAddrCache, addr tcpip.Address, _ tcpip.LinkAddress) + onLinkAddressResolved func(cache *linkAddrCache, neigh *neighborCache, addr tcpip.Address, _ tcpip.LinkAddress) onResolveStaticAddress func(tcpip.Address) (tcpip.LinkAddress, bool) + + mu struct { + sync.RWMutex + forwarding bool + } } +var _ NetworkProtocol = (*fwdTestNetworkProtocol)(nil) var _ LinkAddressResolver = (*fwdTestNetworkProtocol)(nil) func (f *fwdTestNetworkProtocol) Number() tcpip.NetworkProtocolNumber { @@ -144,42 +150,40 @@ func (*fwdTestNetworkProtocol) ParseAddresses(v buffer.View) (src, dst tcpip.Add } func (*fwdTestNetworkProtocol) Parse(pkt *PacketBuffer) (tcpip.TransportProtocolNumber, bool, bool) { - netHeader, ok := pkt.Data.PullUp(fwdTestNetHeaderLen) + netHeader, ok := pkt.NetworkHeader().Consume(fwdTestNetHeaderLen) if !ok { return 0, false, false } - pkt.NetworkHeader = netHeader - pkt.Data.TrimFront(fwdTestNetHeaderLen) - return tcpip.TransportProtocolNumber(pkt.NetworkHeader[protocolNumberOffset]), true, true + return tcpip.TransportProtocolNumber(netHeader[protocolNumberOffset]), true, true } -func (f *fwdTestNetworkProtocol) NewEndpoint(nicID tcpip.NICID, addrWithPrefix tcpip.AddressWithPrefix, linkAddrCache LinkAddressCache, dispatcher TransportDispatcher, ep LinkEndpoint, _ *Stack) (NetworkEndpoint, *tcpip.Error) { - return &fwdTestNetworkEndpoint{ - nicID: nicID, - id: NetworkEndpointID{LocalAddress: addrWithPrefix.Address}, - prefixLen: addrWithPrefix.PrefixLen, +func (f *fwdTestNetworkProtocol) NewEndpoint(nic NetworkInterface, _ LinkAddressCache, _ NUDHandler, dispatcher TransportDispatcher) NetworkEndpoint { + e := &fwdTestNetworkEndpoint{ + nicID: nic.ID(), proto: f, dispatcher: dispatcher, - ep: ep, - }, nil + ep: nic.LinkEndpoint(), + } + e.AddressableEndpointState.Init(e) + return e } -func (f *fwdTestNetworkProtocol) SetOption(option interface{}) *tcpip.Error { +func (*fwdTestNetworkProtocol) SetOption(tcpip.SettableNetworkProtocolOption) *tcpip.Error { return tcpip.ErrUnknownProtocolOption } -func (f *fwdTestNetworkProtocol) Option(option interface{}) *tcpip.Error { +func (*fwdTestNetworkProtocol) Option(tcpip.GettableNetworkProtocolOption) *tcpip.Error { return tcpip.ErrUnknownProtocolOption } -func (f *fwdTestNetworkProtocol) Close() {} +func (*fwdTestNetworkProtocol) Close() {} -func (f *fwdTestNetworkProtocol) Wait() {} +func (*fwdTestNetworkProtocol) Wait() {} func (f *fwdTestNetworkProtocol) LinkAddressRequest(addr, localAddr tcpip.Address, remoteLinkAddr tcpip.LinkAddress, linkEP LinkEndpoint) *tcpip.Error { - if f.addrCache != nil && f.onLinkAddressResolved != nil { + if f.onLinkAddressResolved != nil { time.AfterFunc(f.addrResolveDelay, func() { - f.onLinkAddressResolved(f.addrCache, addr, remoteLinkAddr) + f.onLinkAddressResolved(f.addrCache, f.neigh, addr, remoteLinkAddr) }) } return nil @@ -192,10 +196,25 @@ func (f *fwdTestNetworkProtocol) ResolveStaticAddress(addr tcpip.Address) (tcpip return "", false } -func (f *fwdTestNetworkProtocol) LinkAddressProtocol() tcpip.NetworkProtocolNumber { +func (*fwdTestNetworkProtocol) LinkAddressProtocol() tcpip.NetworkProtocolNumber { return fwdTestNetNumber } +// Forwarding implements stack.ForwardingNetworkProtocol. +func (f *fwdTestNetworkProtocol) Forwarding() bool { + f.mu.RLock() + defer f.mu.RUnlock() + return f.mu.forwarding + +} + +// SetForwarding implements stack.ForwardingNetworkProtocol. +func (f *fwdTestNetworkProtocol) SetForwarding(v bool) { + f.mu.Lock() + defer f.mu.Unlock() + f.mu.forwarding = v +} + // fwdTestPacketInfo holds all the information about an outbound packet. type fwdTestPacketInfo struct { RemoteLinkAddress tcpip.LinkAddress @@ -290,7 +309,7 @@ func (e *fwdTestLinkEndpoint) WritePackets(r *Route, gso *GSO, pkts PacketBuffer // WriteRawPacket implements stack.LinkEndpoint.WriteRawPacket. func (e *fwdTestLinkEndpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error { p := fwdTestPacketInfo{ - Pkt: &PacketBuffer{Data: vv}, + Pkt: NewPacketBuffer(PacketBufferOptions{Data: vv}), } select { @@ -314,16 +333,19 @@ func (e *fwdTestLinkEndpoint) AddHeader(local, remote tcpip.LinkAddress, protoco panic("not implemented") } -func fwdTestNetFactory(t *testing.T, proto *fwdTestNetworkProtocol) (ep1, ep2 *fwdTestLinkEndpoint) { +func fwdTestNetFactory(t *testing.T, proto *fwdTestNetworkProtocol, useNeighborCache bool) (ep1, ep2 *fwdTestLinkEndpoint) { // Create a stack with the network protocol and two NICs. s := New(Options{ - NetworkProtocols: []NetworkProtocol{proto}, + NetworkProtocols: []NetworkProtocolFactory{func(*Stack) NetworkProtocol { return proto }}, + UseNeighborCache: useNeighborCache, }) - proto.addrCache = s.linkAddrCache + if !useNeighborCache { + proto.addrCache = s.linkAddrCache + } // Enable forwarding. - s.SetForwarding(true) + s.SetForwarding(proto.Number(), true) // NIC 1 has the link address "a", and added the network address 1. ep1 = &fwdTestLinkEndpoint{ @@ -351,6 +373,15 @@ func fwdTestNetFactory(t *testing.T, proto *fwdTestNetworkProtocol) (ep1, ep2 *f t.Fatal("AddAddress #2 failed:", err) } + if useNeighborCache { + // Control the neighbor cache for NIC 2. + nic, ok := s.nics[2] + if !ok { + t.Fatal("failed to get the neighbor cache for NIC 2") + } + proto.neigh = nic.neigh + } + // Route all packets to NIC 2. { subnet, err := tcpip.NewSubnet("\x00", "\x00") @@ -364,79 +395,129 @@ func fwdTestNetFactory(t *testing.T, proto *fwdTestNetworkProtocol) (ep1, ep2 *f } func TestForwardingWithStaticResolver(t *testing.T) { - // Create a network protocol with a static resolver. - proto := &fwdTestNetworkProtocol{ - onResolveStaticAddress: - // The network address 3 is resolved to the link address "c". - func(addr tcpip.Address) (tcpip.LinkAddress, bool) { - if addr == "\x03" { - return "c", true - } - return "", false + tests := []struct { + name string + useNeighborCache bool + }{ + { + name: "linkAddrCache", + useNeighborCache: false, + }, + { + name: "neighborCache", + useNeighborCache: true, }, } - ep1, ep2 := fwdTestNetFactory(t, proto) + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + // Create a network protocol with a static resolver. + proto := &fwdTestNetworkProtocol{ + onResolveStaticAddress: + // The network address 3 is resolved to the link address "c". + func(addr tcpip.Address) (tcpip.LinkAddress, bool) { + if addr == "\x03" { + return "c", true + } + return "", false + }, + } - // Inject an inbound packet to address 3 on NIC 1, and see if it is - // forwarded to NIC 2. - buf := buffer.NewView(30) - buf[dstAddrOffset] = 3 - ep1.InjectInbound(fwdTestNetNumber, &PacketBuffer{ - Data: buf.ToVectorisedView(), - }) + ep1, ep2 := fwdTestNetFactory(t, proto, test.useNeighborCache) - var p fwdTestPacketInfo + // Inject an inbound packet to address 3 on NIC 1, and see if it is + // forwarded to NIC 2. + buf := buffer.NewView(30) + buf[dstAddrOffset] = 3 + ep1.InjectInbound(fwdTestNetNumber, NewPacketBuffer(PacketBufferOptions{ + Data: buf.ToVectorisedView(), + })) - select { - case p = <-ep2.C: - default: - t.Fatal("packet not forwarded") - } + var p fwdTestPacketInfo - // Test that the static address resolution happened correctly. - if p.RemoteLinkAddress != "c" { - t.Fatalf("got p.RemoteLinkAddress = %s, want = c", p.RemoteLinkAddress) - } - if p.LocalLinkAddress != "b" { - t.Fatalf("got p.LocalLinkAddress = %s, want = b", p.LocalLinkAddress) + select { + case p = <-ep2.C: + default: + t.Fatal("packet not forwarded") + } + + // Test that the static address resolution happened correctly. + if p.RemoteLinkAddress != "c" { + t.Fatalf("got p.RemoteLinkAddress = %s, want = c", p.RemoteLinkAddress) + } + if p.LocalLinkAddress != "b" { + t.Fatalf("got p.LocalLinkAddress = %s, want = b", p.LocalLinkAddress) + } + }) } } func TestForwardingWithFakeResolver(t *testing.T) { - // Create a network protocol with a fake resolver. - proto := &fwdTestNetworkProtocol{ - addrResolveDelay: 500 * time.Millisecond, - onLinkAddressResolved: func(cache *linkAddrCache, addr tcpip.Address, _ tcpip.LinkAddress) { - // Any address will be resolved to the link address "c". - cache.add(tcpip.FullAddress{NIC: 2, Addr: addr}, "c") + tests := []struct { + name string + useNeighborCache bool + proto *fwdTestNetworkProtocol + }{ + { + name: "linkAddrCache", + useNeighborCache: false, + proto: &fwdTestNetworkProtocol{ + addrResolveDelay: 500 * time.Millisecond, + onLinkAddressResolved: func(cache *linkAddrCache, neigh *neighborCache, addr tcpip.Address, _ tcpip.LinkAddress) { + // Any address will be resolved to the link address "c". + cache.add(tcpip.FullAddress{NIC: 2, Addr: addr}, "c") + }, + }, + }, + { + name: "neighborCache", + useNeighborCache: true, + proto: &fwdTestNetworkProtocol{ + addrResolveDelay: 500 * time.Millisecond, + onLinkAddressResolved: func(cache *linkAddrCache, neigh *neighborCache, addr tcpip.Address, remoteLinkAddr tcpip.LinkAddress) { + t.Helper() + if len(remoteLinkAddr) != 0 { + t.Fatalf("got remoteLinkAddr=%q, want unspecified", remoteLinkAddr) + } + // Any address will be resolved to the link address "c". + neigh.HandleConfirmation(addr, "c", ReachabilityConfirmationFlags{ + Solicited: true, + Override: false, + IsRouter: false, + }) + }, + }, }, } - ep1, ep2 := fwdTestNetFactory(t, proto) + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + ep1, ep2 := fwdTestNetFactory(t, test.proto, test.useNeighborCache) - // Inject an inbound packet to address 3 on NIC 1, and see if it is - // forwarded to NIC 2. - buf := buffer.NewView(30) - buf[dstAddrOffset] = 3 - ep1.InjectInbound(fwdTestNetNumber, &PacketBuffer{ - Data: buf.ToVectorisedView(), - }) + // Inject an inbound packet to address 3 on NIC 1, and see if it is + // forwarded to NIC 2. + buf := buffer.NewView(30) + buf[dstAddrOffset] = 3 + ep1.InjectInbound(fwdTestNetNumber, NewPacketBuffer(PacketBufferOptions{ + Data: buf.ToVectorisedView(), + })) - var p fwdTestPacketInfo + var p fwdTestPacketInfo - select { - case p = <-ep2.C: - case <-time.After(time.Second): - t.Fatal("packet not forwarded") - } + select { + case p = <-ep2.C: + case <-time.After(time.Second): + t.Fatal("packet not forwarded") + } - // Test that the address resolution happened correctly. - if p.RemoteLinkAddress != "c" { - t.Fatalf("got p.RemoteLinkAddress = %s, want = c", p.RemoteLinkAddress) - } - if p.LocalLinkAddress != "b" { - t.Fatalf("got p.LocalLinkAddress = %s, want = b", p.LocalLinkAddress) + // Test that the address resolution happened correctly. + if p.RemoteLinkAddress != "c" { + t.Fatalf("got p.RemoteLinkAddress = %s, want = c", p.RemoteLinkAddress) + } + if p.LocalLinkAddress != "b" { + t.Fatalf("got p.LocalLinkAddress = %s, want = b", p.LocalLinkAddress) + } + }) } } @@ -444,15 +525,17 @@ func TestForwardingWithNoResolver(t *testing.T) { // Create a network protocol without a resolver. proto := &fwdTestNetworkProtocol{} - ep1, ep2 := fwdTestNetFactory(t, proto) + // Whether or not we use the neighbor cache here does not matter since + // neither linkAddrCache nor neighborCache will be used. + ep1, ep2 := fwdTestNetFactory(t, proto, false /* useNeighborCache */) // inject an inbound packet to address 3 on NIC 1, and see if it is // forwarded to NIC 2. buf := buffer.NewView(30) buf[dstAddrOffset] = 3 - ep1.InjectInbound(fwdTestNetNumber, &PacketBuffer{ + ep1.InjectInbound(fwdTestNetNumber, NewPacketBuffer(PacketBufferOptions{ Data: buf.ToVectorisedView(), - }) + })) select { case <-ep2.C: @@ -462,202 +545,334 @@ func TestForwardingWithNoResolver(t *testing.T) { } func TestForwardingWithFakeResolverPartialTimeout(t *testing.T) { - // Create a network protocol with a fake resolver. - proto := &fwdTestNetworkProtocol{ - addrResolveDelay: 500 * time.Millisecond, - onLinkAddressResolved: func(cache *linkAddrCache, addr tcpip.Address, _ tcpip.LinkAddress) { - // Only packets to address 3 will be resolved to the - // link address "c". - if addr == "\x03" { - cache.add(tcpip.FullAddress{NIC: 2, Addr: addr}, "c") - } + tests := []struct { + name string + useNeighborCache bool + proto *fwdTestNetworkProtocol + }{ + { + name: "linkAddrCache", + useNeighborCache: false, + proto: &fwdTestNetworkProtocol{ + addrResolveDelay: 500 * time.Millisecond, + onLinkAddressResolved: func(cache *linkAddrCache, neigh *neighborCache, addr tcpip.Address, _ tcpip.LinkAddress) { + // Only packets to address 3 will be resolved to the + // link address "c". + if addr == "\x03" { + cache.add(tcpip.FullAddress{NIC: 2, Addr: addr}, "c") + } + }, + }, + }, + { + name: "neighborCache", + useNeighborCache: true, + proto: &fwdTestNetworkProtocol{ + addrResolveDelay: 500 * time.Millisecond, + onLinkAddressResolved: func(cache *linkAddrCache, neigh *neighborCache, addr tcpip.Address, remoteLinkAddr tcpip.LinkAddress) { + t.Helper() + if len(remoteLinkAddr) != 0 { + t.Fatalf("got remoteLinkAddr=%q, want unspecified", remoteLinkAddr) + } + // Only packets to address 3 will be resolved to the + // link address "c". + if addr == "\x03" { + neigh.HandleConfirmation(addr, "c", ReachabilityConfirmationFlags{ + Solicited: true, + Override: false, + IsRouter: false, + }) + } + }, + }, }, } - ep1, ep2 := fwdTestNetFactory(t, proto) - - // Inject an inbound packet to address 4 on NIC 1. This packet should - // not be forwarded. - buf := buffer.NewView(30) - buf[dstAddrOffset] = 4 - ep1.InjectInbound(fwdTestNetNumber, &PacketBuffer{ - Data: buf.ToVectorisedView(), - }) - - // Inject an inbound packet to address 3 on NIC 1, and see if it is - // forwarded to NIC 2. - buf = buffer.NewView(30) - buf[dstAddrOffset] = 3 - ep1.InjectInbound(fwdTestNetNumber, &PacketBuffer{ - Data: buf.ToVectorisedView(), - }) - - var p fwdTestPacketInfo - - select { - case p = <-ep2.C: - case <-time.After(time.Second): - t.Fatal("packet not forwarded") - } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + ep1, ep2 := fwdTestNetFactory(t, test.proto, test.useNeighborCache) + + // Inject an inbound packet to address 4 on NIC 1. This packet should + // not be forwarded. + buf := buffer.NewView(30) + buf[dstAddrOffset] = 4 + ep1.InjectInbound(fwdTestNetNumber, NewPacketBuffer(PacketBufferOptions{ + Data: buf.ToVectorisedView(), + })) + + // Inject an inbound packet to address 3 on NIC 1, and see if it is + // forwarded to NIC 2. + buf = buffer.NewView(30) + buf[dstAddrOffset] = 3 + ep1.InjectInbound(fwdTestNetNumber, NewPacketBuffer(PacketBufferOptions{ + Data: buf.ToVectorisedView(), + })) + + var p fwdTestPacketInfo + + select { + case p = <-ep2.C: + case <-time.After(time.Second): + t.Fatal("packet not forwarded") + } - if p.Pkt.NetworkHeader[dstAddrOffset] != 3 { - t.Fatalf("got p.Pkt.NetworkHeader[dstAddrOffset] = %d, want = 3", p.Pkt.NetworkHeader[dstAddrOffset]) - } + if nh := PayloadSince(p.Pkt.NetworkHeader()); nh[dstAddrOffset] != 3 { + t.Fatalf("got p.Pkt.NetworkHeader[dstAddrOffset] = %d, want = 3", nh[dstAddrOffset]) + } - // Test that the address resolution happened correctly. - if p.RemoteLinkAddress != "c" { - t.Fatalf("got p.RemoteLinkAddress = %s, want = c", p.RemoteLinkAddress) - } - if p.LocalLinkAddress != "b" { - t.Fatalf("got p.LocalLinkAddress = %s, want = b", p.LocalLinkAddress) + // Test that the address resolution happened correctly. + if p.RemoteLinkAddress != "c" { + t.Fatalf("got p.RemoteLinkAddress = %s, want = c", p.RemoteLinkAddress) + } + if p.LocalLinkAddress != "b" { + t.Fatalf("got p.LocalLinkAddress = %s, want = b", p.LocalLinkAddress) + } + }) } } func TestForwardingWithFakeResolverTwoPackets(t *testing.T) { - // Create a network protocol with a fake resolver. - proto := &fwdTestNetworkProtocol{ - addrResolveDelay: 500 * time.Millisecond, - onLinkAddressResolved: func(cache *linkAddrCache, addr tcpip.Address, _ tcpip.LinkAddress) { - // Any packets will be resolved to the link address "c". - cache.add(tcpip.FullAddress{NIC: 2, Addr: addr}, "c") + tests := []struct { + name string + useNeighborCache bool + proto *fwdTestNetworkProtocol + }{ + { + name: "linkAddrCache", + useNeighborCache: false, + proto: &fwdTestNetworkProtocol{ + addrResolveDelay: 500 * time.Millisecond, + onLinkAddressResolved: func(cache *linkAddrCache, neigh *neighborCache, addr tcpip.Address, _ tcpip.LinkAddress) { + // Any packets will be resolved to the link address "c". + cache.add(tcpip.FullAddress{NIC: 2, Addr: addr}, "c") + }, + }, + }, + { + name: "neighborCache", + useNeighborCache: true, + proto: &fwdTestNetworkProtocol{ + addrResolveDelay: 500 * time.Millisecond, + onLinkAddressResolved: func(cache *linkAddrCache, neigh *neighborCache, addr tcpip.Address, remoteLinkAddr tcpip.LinkAddress) { + t.Helper() + if len(remoteLinkAddr) != 0 { + t.Fatalf("got remoteLinkAddr=%q, want unspecified", remoteLinkAddr) + } + // Any packets will be resolved to the link address "c". + neigh.HandleConfirmation(addr, "c", ReachabilityConfirmationFlags{ + Solicited: true, + Override: false, + IsRouter: false, + }) + }, + }, }, } - ep1, ep2 := fwdTestNetFactory(t, proto) - - // Inject two inbound packets to address 3 on NIC 1. - for i := 0; i < 2; i++ { - buf := buffer.NewView(30) - buf[dstAddrOffset] = 3 - ep1.InjectInbound(fwdTestNetNumber, &PacketBuffer{ - Data: buf.ToVectorisedView(), - }) - } - - for i := 0; i < 2; i++ { - var p fwdTestPacketInfo + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + ep1, ep2 := fwdTestNetFactory(t, test.proto, test.useNeighborCache) - select { - case p = <-ep2.C: - case <-time.After(time.Second): - t.Fatal("packet not forwarded") - } - - if p.Pkt.NetworkHeader[dstAddrOffset] != 3 { - t.Fatalf("got p.Pkt.NetworkHeader[dstAddrOffset] = %d, want = 3", p.Pkt.NetworkHeader[dstAddrOffset]) - } + // Inject two inbound packets to address 3 on NIC 1. + for i := 0; i < 2; i++ { + buf := buffer.NewView(30) + buf[dstAddrOffset] = 3 + ep1.InjectInbound(fwdTestNetNumber, NewPacketBuffer(PacketBufferOptions{ + Data: buf.ToVectorisedView(), + })) + } - // Test that the address resolution happened correctly. - if p.RemoteLinkAddress != "c" { - t.Fatalf("got p.RemoteLinkAddress = %s, want = c", p.RemoteLinkAddress) - } - if p.LocalLinkAddress != "b" { - t.Fatalf("got p.LocalLinkAddress = %s, want = b", p.LocalLinkAddress) - } + for i := 0; i < 2; i++ { + var p fwdTestPacketInfo + + select { + case p = <-ep2.C: + case <-time.After(time.Second): + t.Fatal("packet not forwarded") + } + + if nh := PayloadSince(p.Pkt.NetworkHeader()); nh[dstAddrOffset] != 3 { + t.Fatalf("got p.Pkt.NetworkHeader[dstAddrOffset] = %d, want = 3", nh[dstAddrOffset]) + } + + // Test that the address resolution happened correctly. + if p.RemoteLinkAddress != "c" { + t.Fatalf("got p.RemoteLinkAddress = %s, want = c", p.RemoteLinkAddress) + } + if p.LocalLinkAddress != "b" { + t.Fatalf("got p.LocalLinkAddress = %s, want = b", p.LocalLinkAddress) + } + } + }) } } func TestForwardingWithFakeResolverManyPackets(t *testing.T) { - // Create a network protocol with a fake resolver. - proto := &fwdTestNetworkProtocol{ - addrResolveDelay: 500 * time.Millisecond, - onLinkAddressResolved: func(cache *linkAddrCache, addr tcpip.Address, _ tcpip.LinkAddress) { - // Any packets will be resolved to the link address "c". - cache.add(tcpip.FullAddress{NIC: 2, Addr: addr}, "c") + tests := []struct { + name string + useNeighborCache bool + proto *fwdTestNetworkProtocol + }{ + { + name: "linkAddrCache", + useNeighborCache: false, + proto: &fwdTestNetworkProtocol{ + addrResolveDelay: 500 * time.Millisecond, + onLinkAddressResolved: func(cache *linkAddrCache, neigh *neighborCache, addr tcpip.Address, _ tcpip.LinkAddress) { + // Any packets will be resolved to the link address "c". + cache.add(tcpip.FullAddress{NIC: 2, Addr: addr}, "c") + }, + }, + }, + { + name: "neighborCache", + useNeighborCache: true, + proto: &fwdTestNetworkProtocol{ + addrResolveDelay: 500 * time.Millisecond, + onLinkAddressResolved: func(cache *linkAddrCache, neigh *neighborCache, addr tcpip.Address, remoteLinkAddr tcpip.LinkAddress) { + t.Helper() + if len(remoteLinkAddr) != 0 { + t.Fatalf("got remoteLinkAddr=%q, want unspecified", remoteLinkAddr) + } + // Any packets will be resolved to the link address "c". + neigh.HandleConfirmation(addr, "c", ReachabilityConfirmationFlags{ + Solicited: true, + Override: false, + IsRouter: false, + }) + }, + }, }, } - ep1, ep2 := fwdTestNetFactory(t, proto) - - for i := 0; i < maxPendingPacketsPerResolution+5; i++ { - // Inject inbound 'maxPendingPacketsPerResolution + 5' packets on NIC 1. - buf := buffer.NewView(30) - buf[dstAddrOffset] = 3 - // Set the packet sequence number. - binary.BigEndian.PutUint16(buf[fwdTestNetHeaderLen:], uint16(i)) - ep1.InjectInbound(fwdTestNetNumber, &PacketBuffer{ - Data: buf.ToVectorisedView(), - }) - } - - for i := 0; i < maxPendingPacketsPerResolution; i++ { - var p fwdTestPacketInfo - - select { - case p = <-ep2.C: - case <-time.After(time.Second): - t.Fatal("packet not forwarded") - } - - if b := p.Pkt.Header.View(); b[dstAddrOffset] != 3 { - t.Fatalf("got b[dstAddrOffset] = %d, want = 3", b[dstAddrOffset]) - } - seqNumBuf, ok := p.Pkt.Data.PullUp(2) // The sequence number is a uint16 (2 bytes). - if !ok { - t.Fatalf("p.Pkt.Data is too short to hold a sequence number: %d", p.Pkt.Data.Size()) - } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + ep1, ep2 := fwdTestNetFactory(t, test.proto, test.useNeighborCache) - // The first 5 packets should not be forwarded so the sequence number should - // start with 5. - want := uint16(i + 5) - if n := binary.BigEndian.Uint16(seqNumBuf); n != want { - t.Fatalf("got the packet #%d, want = #%d", n, want) - } + for i := 0; i < maxPendingPacketsPerResolution+5; i++ { + // Inject inbound 'maxPendingPacketsPerResolution + 5' packets on NIC 1. + buf := buffer.NewView(30) + buf[dstAddrOffset] = 3 + // Set the packet sequence number. + binary.BigEndian.PutUint16(buf[fwdTestNetHeaderLen:], uint16(i)) + ep1.InjectInbound(fwdTestNetNumber, NewPacketBuffer(PacketBufferOptions{ + Data: buf.ToVectorisedView(), + })) + } - // Test that the address resolution happened correctly. - if p.RemoteLinkAddress != "c" { - t.Fatalf("got p.RemoteLinkAddress = %s, want = c", p.RemoteLinkAddress) - } - if p.LocalLinkAddress != "b" { - t.Fatalf("got p.LocalLinkAddress = %s, want = b", p.LocalLinkAddress) - } + for i := 0; i < maxPendingPacketsPerResolution; i++ { + var p fwdTestPacketInfo + + select { + case p = <-ep2.C: + case <-time.After(time.Second): + t.Fatal("packet not forwarded") + } + + b := PayloadSince(p.Pkt.NetworkHeader()) + if b[dstAddrOffset] != 3 { + t.Fatalf("got b[dstAddrOffset] = %d, want = 3", b[dstAddrOffset]) + } + if len(b) < fwdTestNetHeaderLen+2 { + t.Fatalf("packet is too short to hold a sequence number: len(b) = %d", b) + } + seqNumBuf := b[fwdTestNetHeaderLen:] + + // The first 5 packets should not be forwarded so the sequence number should + // start with 5. + want := uint16(i + 5) + if n := binary.BigEndian.Uint16(seqNumBuf); n != want { + t.Fatalf("got the packet #%d, want = #%d", n, want) + } + + // Test that the address resolution happened correctly. + if p.RemoteLinkAddress != "c" { + t.Fatalf("got p.RemoteLinkAddress = %s, want = c", p.RemoteLinkAddress) + } + if p.LocalLinkAddress != "b" { + t.Fatalf("got p.LocalLinkAddress = %s, want = b", p.LocalLinkAddress) + } + } + }) } } func TestForwardingWithFakeResolverManyResolutions(t *testing.T) { - // Create a network protocol with a fake resolver. - proto := &fwdTestNetworkProtocol{ - addrResolveDelay: 500 * time.Millisecond, - onLinkAddressResolved: func(cache *linkAddrCache, addr tcpip.Address, _ tcpip.LinkAddress) { - // Any packets will be resolved to the link address "c". - cache.add(tcpip.FullAddress{NIC: 2, Addr: addr}, "c") + tests := []struct { + name string + useNeighborCache bool + proto *fwdTestNetworkProtocol + }{ + { + name: "linkAddrCache", + useNeighborCache: false, + proto: &fwdTestNetworkProtocol{ + addrResolveDelay: 500 * time.Millisecond, + onLinkAddressResolved: func(cache *linkAddrCache, neigh *neighborCache, addr tcpip.Address, _ tcpip.LinkAddress) { + // Any packets will be resolved to the link address "c". + cache.add(tcpip.FullAddress{NIC: 2, Addr: addr}, "c") + }, + }, + }, + { + name: "neighborCache", + useNeighborCache: true, + proto: &fwdTestNetworkProtocol{ + addrResolveDelay: 500 * time.Millisecond, + onLinkAddressResolved: func(cache *linkAddrCache, neigh *neighborCache, addr tcpip.Address, remoteLinkAddr tcpip.LinkAddress) { + t.Helper() + if len(remoteLinkAddr) != 0 { + t.Fatalf("got remoteLinkAddr=%q, want unspecified", remoteLinkAddr) + } + // Any packets will be resolved to the link address "c". + neigh.HandleConfirmation(addr, "c", ReachabilityConfirmationFlags{ + Solicited: true, + Override: false, + IsRouter: false, + }) + }, + }, }, } - ep1, ep2 := fwdTestNetFactory(t, proto) - - for i := 0; i < maxPendingResolutions+5; i++ { - // Inject inbound 'maxPendingResolutions + 5' packets on NIC 1. - // Each packet has a different destination address (3 to - // maxPendingResolutions + 7). - buf := buffer.NewView(30) - buf[dstAddrOffset] = byte(3 + i) - ep1.InjectInbound(fwdTestNetNumber, &PacketBuffer{ - Data: buf.ToVectorisedView(), - }) - } - - for i := 0; i < maxPendingResolutions; i++ { - var p fwdTestPacketInfo + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + ep1, ep2 := fwdTestNetFactory(t, test.proto, test.useNeighborCache) - select { - case p = <-ep2.C: - case <-time.After(time.Second): - t.Fatal("packet not forwarded") - } - - // The first 5 packets (address 3 to 7) should not be forwarded - // because their address resolutions are interrupted. - if p.Pkt.NetworkHeader[dstAddrOffset] < 8 { - t.Fatalf("got p.Pkt.NetworkHeader[dstAddrOffset] = %d, want p.Pkt.NetworkHeader[dstAddrOffset] >= 8", p.Pkt.NetworkHeader[dstAddrOffset]) - } + for i := 0; i < maxPendingResolutions+5; i++ { + // Inject inbound 'maxPendingResolutions + 5' packets on NIC 1. + // Each packet has a different destination address (3 to + // maxPendingResolutions + 7). + buf := buffer.NewView(30) + buf[dstAddrOffset] = byte(3 + i) + ep1.InjectInbound(fwdTestNetNumber, NewPacketBuffer(PacketBufferOptions{ + Data: buf.ToVectorisedView(), + })) + } - // Test that the address resolution happened correctly. - if p.RemoteLinkAddress != "c" { - t.Fatalf("got p.RemoteLinkAddress = %s, want = c", p.RemoteLinkAddress) - } - if p.LocalLinkAddress != "b" { - t.Fatalf("got p.LocalLinkAddress = %s, want = b", p.LocalLinkAddress) - } + for i := 0; i < maxPendingResolutions; i++ { + var p fwdTestPacketInfo + + select { + case p = <-ep2.C: + case <-time.After(time.Second): + t.Fatal("packet not forwarded") + } + + // The first 5 packets (address 3 to 7) should not be forwarded + // because their address resolutions are interrupted. + if nh := PayloadSince(p.Pkt.NetworkHeader()); nh[dstAddrOffset] < 8 { + t.Fatalf("got p.Pkt.NetworkHeader[dstAddrOffset] = %d, want p.Pkt.NetworkHeader[dstAddrOffset] >= 8", nh[dstAddrOffset]) + } + + // Test that the address resolution happened correctly. + if p.RemoteLinkAddress != "c" { + t.Fatalf("got p.RemoteLinkAddress = %s, want = c", p.RemoteLinkAddress) + } + if p.LocalLinkAddress != "b" { + t.Fatalf("got p.LocalLinkAddress = %s, want = b", p.LocalLinkAddress) + } + } + }) } } diff --git a/pkg/tcpip/stack/dhcpv6configurationfromndpra_string.go b/pkg/tcpip/stack/headertype_string.go index d199ded6a..5efddfaaf 100644 --- a/pkg/tcpip/stack/dhcpv6configurationfromndpra_string.go +++ b/pkg/tcpip/stack/headertype_string.go @@ -2,8 +2,7 @@ // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// +// You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software @@ -12,7 +11,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -// Code generated by "stringer -type DHCPv6ConfigurationFromNDPRA"; DO NOT EDIT. +// Code generated by "stringer -type headerType ."; DO NOT EDIT. package stack @@ -22,19 +21,19 @@ func _() { // An "invalid array index" compiler error signifies that the constant values have changed. // Re-run the stringer command to generate them again. var x [1]struct{} - _ = x[DHCPv6NoConfiguration-1] - _ = x[DHCPv6ManagedAddress-2] - _ = x[DHCPv6OtherConfigurations-3] + _ = x[linkHeader-0] + _ = x[networkHeader-1] + _ = x[transportHeader-2] + _ = x[numHeaderType-3] } -const _DHCPv6ConfigurationFromNDPRA_name = "DHCPv6NoConfigurationDHCPv6ManagedAddressDHCPv6OtherConfigurations" +const _headerType_name = "linkHeadernetworkHeadertransportHeadernumHeaderType" -var _DHCPv6ConfigurationFromNDPRA_index = [...]uint8{0, 21, 41, 66} +var _headerType_index = [...]uint8{0, 10, 23, 38, 51} -func (i DHCPv6ConfigurationFromNDPRA) String() string { - i -= 1 - if i < 0 || i >= DHCPv6ConfigurationFromNDPRA(len(_DHCPv6ConfigurationFromNDPRA_index)-1) { - return "DHCPv6ConfigurationFromNDPRA(" + strconv.FormatInt(int64(i+1), 10) + ")" +func (i headerType) String() string { + if i < 0 || i >= headerType(len(_headerType_index)-1) { + return "headerType(" + strconv.FormatInt(int64(i), 10) + ")" } - return _DHCPv6ConfigurationFromNDPRA_name[_DHCPv6ConfigurationFromNDPRA_index[i]:_DHCPv6ConfigurationFromNDPRA_index[i+1]] + return _headerType_name[_headerType_index[i]:_headerType_index[i+1]] } diff --git a/pkg/tcpip/stack/iptables.go b/pkg/tcpip/stack/iptables.go index 110ba073d..faa503b00 100644 --- a/pkg/tcpip/stack/iptables.go +++ b/pkg/tcpip/stack/iptables.go @@ -57,14 +57,14 @@ const reaperDelay = 5 * time.Second // all packets. func DefaultTables() *IPTables { return &IPTables{ - tables: [numTables]Table{ + v4Tables: [numTables]Table{ natID: Table{ Rules: []Rule{ - Rule{Target: AcceptTarget{}}, - Rule{Target: AcceptTarget{}}, - Rule{Target: AcceptTarget{}}, - Rule{Target: AcceptTarget{}}, - Rule{Target: ErrorTarget{}}, + Rule{Target: &AcceptTarget{NetworkProtocol: header.IPv4ProtocolNumber}}, + Rule{Target: &AcceptTarget{NetworkProtocol: header.IPv4ProtocolNumber}}, + Rule{Target: &AcceptTarget{NetworkProtocol: header.IPv4ProtocolNumber}}, + Rule{Target: &AcceptTarget{NetworkProtocol: header.IPv4ProtocolNumber}}, + Rule{Target: &ErrorTarget{NetworkProtocol: header.IPv4ProtocolNumber}}, }, BuiltinChains: [NumHooks]int{ Prerouting: 0, @@ -83,9 +83,9 @@ func DefaultTables() *IPTables { }, mangleID: Table{ Rules: []Rule{ - Rule{Target: AcceptTarget{}}, - Rule{Target: AcceptTarget{}}, - Rule{Target: ErrorTarget{}}, + Rule{Target: &AcceptTarget{NetworkProtocol: header.IPv4ProtocolNumber}}, + Rule{Target: &AcceptTarget{NetworkProtocol: header.IPv4ProtocolNumber}}, + Rule{Target: &ErrorTarget{NetworkProtocol: header.IPv4ProtocolNumber}}, }, BuiltinChains: [NumHooks]int{ Prerouting: 0, @@ -101,10 +101,75 @@ func DefaultTables() *IPTables { }, filterID: Table{ Rules: []Rule{ - Rule{Target: AcceptTarget{}}, - Rule{Target: AcceptTarget{}}, - Rule{Target: AcceptTarget{}}, - Rule{Target: ErrorTarget{}}, + Rule{Target: &AcceptTarget{NetworkProtocol: header.IPv4ProtocolNumber}}, + Rule{Target: &AcceptTarget{NetworkProtocol: header.IPv4ProtocolNumber}}, + Rule{Target: &AcceptTarget{NetworkProtocol: header.IPv4ProtocolNumber}}, + Rule{Target: &ErrorTarget{NetworkProtocol: header.IPv4ProtocolNumber}}, + }, + BuiltinChains: [NumHooks]int{ + Prerouting: HookUnset, + Input: 0, + Forward: 1, + Output: 2, + Postrouting: HookUnset, + }, + Underflows: [NumHooks]int{ + Prerouting: HookUnset, + Input: 0, + Forward: 1, + Output: 2, + Postrouting: HookUnset, + }, + }, + }, + v6Tables: [numTables]Table{ + natID: Table{ + Rules: []Rule{ + Rule{Target: &AcceptTarget{NetworkProtocol: header.IPv6ProtocolNumber}}, + Rule{Target: &AcceptTarget{NetworkProtocol: header.IPv6ProtocolNumber}}, + Rule{Target: &AcceptTarget{NetworkProtocol: header.IPv6ProtocolNumber}}, + Rule{Target: &AcceptTarget{NetworkProtocol: header.IPv6ProtocolNumber}}, + Rule{Target: &ErrorTarget{NetworkProtocol: header.IPv6ProtocolNumber}}, + }, + BuiltinChains: [NumHooks]int{ + Prerouting: 0, + Input: 1, + Forward: HookUnset, + Output: 2, + Postrouting: 3, + }, + Underflows: [NumHooks]int{ + Prerouting: 0, + Input: 1, + Forward: HookUnset, + Output: 2, + Postrouting: 3, + }, + }, + mangleID: Table{ + Rules: []Rule{ + Rule{Target: &AcceptTarget{NetworkProtocol: header.IPv6ProtocolNumber}}, + Rule{Target: &AcceptTarget{NetworkProtocol: header.IPv6ProtocolNumber}}, + Rule{Target: &ErrorTarget{NetworkProtocol: header.IPv6ProtocolNumber}}, + }, + BuiltinChains: [NumHooks]int{ + Prerouting: 0, + Output: 1, + }, + Underflows: [NumHooks]int{ + Prerouting: 0, + Input: HookUnset, + Forward: HookUnset, + Output: 1, + Postrouting: HookUnset, + }, + }, + filterID: Table{ + Rules: []Rule{ + Rule{Target: &AcceptTarget{NetworkProtocol: header.IPv6ProtocolNumber}}, + Rule{Target: &AcceptTarget{NetworkProtocol: header.IPv6ProtocolNumber}}, + Rule{Target: &AcceptTarget{NetworkProtocol: header.IPv6ProtocolNumber}}, + Rule{Target: &ErrorTarget{NetworkProtocol: header.IPv6ProtocolNumber}}, }, BuiltinChains: [NumHooks]int{ Prerouting: HookUnset, @@ -165,18 +230,21 @@ func EmptyNATTable() Table { } // GetTable returns a table by name. -func (it *IPTables) GetTable(name string) (Table, bool) { +func (it *IPTables) GetTable(name string, ipv6 bool) (Table, bool) { id, ok := nameToID[name] if !ok { return Table{}, false } it.mu.RLock() defer it.mu.RUnlock() - return it.tables[id], true + if ipv6 { + return it.v6Tables[id], true + } + return it.v4Tables[id], true } // ReplaceTable replaces or inserts table by name. -func (it *IPTables) ReplaceTable(name string, table Table) *tcpip.Error { +func (it *IPTables) ReplaceTable(name string, table Table, ipv6 bool) *tcpip.Error { id, ok := nameToID[name] if !ok { return tcpip.ErrInvalidOptionValue @@ -190,7 +258,11 @@ func (it *IPTables) ReplaceTable(name string, table Table) *tcpip.Error { it.startReaper(reaperDelay) } it.modified = true - it.tables[id] = table + if ipv6 { + it.v6Tables[id] = table + } else { + it.v4Tables[id] = table + } return nil } @@ -213,8 +285,15 @@ const ( // should continue traversing the network stack and false when it should be // dropped. // +// TODO(gvisor.dev/issue/170): PacketBuffer should hold the GSO and route, from +// which address and nicName can be gathered. Currently, address is only +// needed for prerouting and nicName is only needed for output. +// // Precondition: pkt.NetworkHeader is set. -func (it *IPTables) Check(hook Hook, pkt *PacketBuffer, gso *GSO, r *Route, address tcpip.Address, nicName string) bool { +func (it *IPTables) Check(hook Hook, pkt *PacketBuffer, gso *GSO, r *Route, preroutingAddr tcpip.Address, nicName string) bool { + if pkt.NetworkProtocolNumber != header.IPv4ProtocolNumber && pkt.NetworkProtocolNumber != header.IPv6ProtocolNumber { + return true + } // Many users never configure iptables. Spare them the cost of rule // traversal if rules have never been set. it.mu.RLock() @@ -235,9 +314,14 @@ func (it *IPTables) Check(hook Hook, pkt *PacketBuffer, gso *GSO, r *Route, addr if tableID == natID && pkt.NatDone { continue } - table := it.tables[tableID] + var table Table + if pkt.NetworkProtocolNumber == header.IPv6ProtocolNumber { + table = it.v6Tables[tableID] + } else { + table = it.v4Tables[tableID] + } ruleIdx := table.BuiltinChains[hook] - switch verdict := it.checkChain(hook, pkt, table, ruleIdx, gso, r, address, nicName); verdict { + switch verdict := it.checkChain(hook, pkt, table, ruleIdx, gso, r, preroutingAddr, nicName); verdict { // If the table returns Accept, move on to the next table. case chainAccept: continue @@ -248,7 +332,7 @@ func (it *IPTables) Check(hook Hook, pkt *PacketBuffer, gso *GSO, r *Route, addr // Any Return from a built-in chain means we have to // call the underflow. underflow := table.Rules[table.Underflows[hook]] - switch v, _ := underflow.Target.Action(pkt, &it.connections, hook, gso, r, address); v { + switch v, _ := underflow.Target.Action(pkt, &it.connections, hook, gso, r, preroutingAddr); v { case RuleAccept: continue case RuleDrop: @@ -315,8 +399,8 @@ func (it *IPTables) startReaper(interval time.Duration) { // should not go forward. // // Preconditions: -// - pkt is a IPv4 packet of at least length header.IPv4MinimumSize. -// - pkt.NetworkHeader is not nil. +// * pkt is a IPv4 packet of at least length header.IPv4MinimumSize. +// * pkt.NetworkHeader is not nil. // // NOTE: unlike the Check API the returned map contains packets that should be // dropped. @@ -341,13 +425,13 @@ func (it *IPTables) CheckPackets(hook Hook, pkts PacketBufferList, gso *GSO, r * } // Preconditions: -// - pkt is a IPv4 packet of at least length header.IPv4MinimumSize. -// - pkt.NetworkHeader is not nil. -func (it *IPTables) checkChain(hook Hook, pkt *PacketBuffer, table Table, ruleIdx int, gso *GSO, r *Route, address tcpip.Address, nicName string) chainVerdict { +// * pkt is a IPv4 packet of at least length header.IPv4MinimumSize. +// * pkt.NetworkHeader is not nil. +func (it *IPTables) checkChain(hook Hook, pkt *PacketBuffer, table Table, ruleIdx int, gso *GSO, r *Route, preroutingAddr tcpip.Address, nicName string) chainVerdict { // Start from ruleIdx and walk the list of rules until a rule gives us // a verdict. for ruleIdx < len(table.Rules) { - switch verdict, jumpTo := it.checkRule(hook, pkt, table, ruleIdx, gso, r, address, nicName); verdict { + switch verdict, jumpTo := it.checkRule(hook, pkt, table, ruleIdx, gso, r, preroutingAddr, nicName); verdict { case RuleAccept: return chainAccept @@ -364,7 +448,7 @@ func (it *IPTables) checkChain(hook Hook, pkt *PacketBuffer, table Table, ruleId ruleIdx++ continue } - switch verdict := it.checkChain(hook, pkt, table, jumpTo, gso, r, address, nicName); verdict { + switch verdict := it.checkChain(hook, pkt, table, jumpTo, gso, r, preroutingAddr, nicName); verdict { case chainAccept: return chainAccept case chainDrop: @@ -388,13 +472,13 @@ func (it *IPTables) checkChain(hook Hook, pkt *PacketBuffer, table Table, ruleId } // Preconditions: -// - pkt is a IPv4 packet of at least length header.IPv4MinimumSize. -// - pkt.NetworkHeader is not nil. -func (it *IPTables) checkRule(hook Hook, pkt *PacketBuffer, table Table, ruleIdx int, gso *GSO, r *Route, address tcpip.Address, nicName string) (RuleVerdict, int) { +// * pkt is a IPv4 packet of at least length header.IPv4MinimumSize. +// * pkt.NetworkHeader is not nil. +func (it *IPTables) checkRule(hook Hook, pkt *PacketBuffer, table Table, ruleIdx int, gso *GSO, r *Route, preroutingAddr tcpip.Address, nicName string) (RuleVerdict, int) { rule := table.Rules[ruleIdx] // Check whether the packet matches the IP header filter. - if !rule.Filter.match(header.IPv4(pkt.NetworkHeader), hook, nicName) { + if !rule.Filter.match(pkt, hook, nicName) { // Continue on to the next rule. return RuleJump, ruleIdx + 1 } @@ -413,11 +497,16 @@ func (it *IPTables) checkRule(hook Hook, pkt *PacketBuffer, table Table, ruleIdx } // All the matchers matched, so run the target. - return rule.Target.Action(pkt, &it.connections, hook, gso, r, address) + return rule.Target.Action(pkt, &it.connections, hook, gso, r, preroutingAddr) } // OriginalDst returns the original destination of redirected connections. It // returns an error if the connection doesn't exist or isn't redirected. func (it *IPTables) OriginalDst(epID TransportEndpointID) (tcpip.Address, uint16, *tcpip.Error) { + it.mu.RLock() + defer it.mu.RUnlock() + if !it.modified { + return "", 0, tcpip.ErrNotConnected + } return it.connections.originalDst(epID) } diff --git a/pkg/tcpip/stack/iptables_targets.go b/pkg/tcpip/stack/iptables_targets.go index dc88033c7..611564b08 100644 --- a/pkg/tcpip/stack/iptables_targets.go +++ b/pkg/tcpip/stack/iptables_targets.go @@ -21,85 +21,146 @@ import ( ) // AcceptTarget accepts packets. -type AcceptTarget struct{} +type AcceptTarget struct { + // NetworkProtocol is the network protocol the target is used with. + NetworkProtocol tcpip.NetworkProtocolNumber +} + +// ID implements Target.ID. +func (at *AcceptTarget) ID() TargetID { + return TargetID{ + NetworkProtocol: at.NetworkProtocol, + } +} // Action implements Target.Action. -func (AcceptTarget) Action(*PacketBuffer, *ConnTrack, Hook, *GSO, *Route, tcpip.Address) (RuleVerdict, int) { +func (*AcceptTarget) Action(*PacketBuffer, *ConnTrack, Hook, *GSO, *Route, tcpip.Address) (RuleVerdict, int) { return RuleAccept, 0 } // DropTarget drops packets. -type DropTarget struct{} +type DropTarget struct { + // NetworkProtocol is the network protocol the target is used with. + NetworkProtocol tcpip.NetworkProtocolNumber +} + +// ID implements Target.ID. +func (dt *DropTarget) ID() TargetID { + return TargetID{ + NetworkProtocol: dt.NetworkProtocol, + } +} // Action implements Target.Action. -func (DropTarget) Action(*PacketBuffer, *ConnTrack, Hook, *GSO, *Route, tcpip.Address) (RuleVerdict, int) { +func (*DropTarget) Action(*PacketBuffer, *ConnTrack, Hook, *GSO, *Route, tcpip.Address) (RuleVerdict, int) { return RuleDrop, 0 } +// ErrorTargetName is used to mark targets as error targets. Error targets +// shouldn't be reached - an error has occurred if we fall through to one. +const ErrorTargetName = "ERROR" + // ErrorTarget logs an error and drops the packet. It represents a target that // should be unreachable. -type ErrorTarget struct{} +type ErrorTarget struct { + // NetworkProtocol is the network protocol the target is used with. + NetworkProtocol tcpip.NetworkProtocolNumber +} + +// ID implements Target.ID. +func (et *ErrorTarget) ID() TargetID { + return TargetID{ + Name: ErrorTargetName, + NetworkProtocol: et.NetworkProtocol, + } +} // Action implements Target.Action. -func (ErrorTarget) Action(*PacketBuffer, *ConnTrack, Hook, *GSO, *Route, tcpip.Address) (RuleVerdict, int) { +func (*ErrorTarget) Action(*PacketBuffer, *ConnTrack, Hook, *GSO, *Route, tcpip.Address) (RuleVerdict, int) { log.Debugf("ErrorTarget triggered.") return RuleDrop, 0 } // UserChainTarget marks a rule as the beginning of a user chain. type UserChainTarget struct { + // Name is the chain name. Name string + + // NetworkProtocol is the network protocol the target is used with. + NetworkProtocol tcpip.NetworkProtocolNumber +} + +// ID implements Target.ID. +func (uc *UserChainTarget) ID() TargetID { + return TargetID{ + Name: ErrorTargetName, + NetworkProtocol: uc.NetworkProtocol, + } } // Action implements Target.Action. -func (UserChainTarget) Action(*PacketBuffer, *ConnTrack, Hook, *GSO, *Route, tcpip.Address) (RuleVerdict, int) { +func (*UserChainTarget) Action(*PacketBuffer, *ConnTrack, Hook, *GSO, *Route, tcpip.Address) (RuleVerdict, int) { panic("UserChainTarget should never be called.") } // ReturnTarget returns from the current chain. If the chain is a built-in, the // hook's underflow should be called. -type ReturnTarget struct{} +type ReturnTarget struct { + // NetworkProtocol is the network protocol the target is used with. + NetworkProtocol tcpip.NetworkProtocolNumber +} + +// ID implements Target.ID. +func (rt *ReturnTarget) ID() TargetID { + return TargetID{ + NetworkProtocol: rt.NetworkProtocol, + } +} // Action implements Target.Action. -func (ReturnTarget) Action(*PacketBuffer, *ConnTrack, Hook, *GSO, *Route, tcpip.Address) (RuleVerdict, int) { +func (*ReturnTarget) Action(*PacketBuffer, *ConnTrack, Hook, *GSO, *Route, tcpip.Address) (RuleVerdict, int) { return RuleReturn, 0 } +// RedirectTargetName is used to mark targets as redirect targets. Redirect +// targets should be reached for only NAT and Mangle tables. These targets will +// change the destination port/destination IP for packets. +const RedirectTargetName = "REDIRECT" + // RedirectTarget redirects the packet by modifying the destination port/IP. -// Min and Max values for IP and Ports in the struct indicate the range of -// values which can be used to redirect. +// TODO(gvisor.dev/issue/170): Other flags need to be added after we support +// them. type RedirectTarget struct { - // TODO(gvisor.dev/issue/170): Other flags need to be added after - // we support them. - // RangeProtoSpecified flag indicates single port is specified to - // redirect. - RangeProtoSpecified bool + // Addr indicates address used to redirect. + Addr tcpip.Address - // MinIP indicates address used to redirect. - MinIP tcpip.Address + // Port indicates port used to redirect. + Port uint16 - // MaxIP indicates address used to redirect. - MaxIP tcpip.Address - - // MinPort indicates port used to redirect. - MinPort uint16 + // NetworkProtocol is the network protocol the target is used with. + NetworkProtocol tcpip.NetworkProtocolNumber +} - // MaxPort indicates port used to redirect. - MaxPort uint16 +// ID implements Target.ID. +func (rt *RedirectTarget) ID() TargetID { + return TargetID{ + Name: RedirectTargetName, + NetworkProtocol: rt.NetworkProtocol, + } } // Action implements Target.Action. // TODO(gvisor.dev/issue/170): Parse headers without copying. The current // implementation only works for PREROUTING and calls pkt.Clone(), neither // of which should be the case. -func (rt RedirectTarget) Action(pkt *PacketBuffer, ct *ConnTrack, hook Hook, gso *GSO, r *Route, address tcpip.Address) (RuleVerdict, int) { +func (rt *RedirectTarget) Action(pkt *PacketBuffer, ct *ConnTrack, hook Hook, gso *GSO, r *Route, address tcpip.Address) (RuleVerdict, int) { // Packet is already manipulated. if pkt.NatDone { return RuleAccept, 0 } // Drop the packet if network and transport header are not set. - if pkt.NetworkHeader == nil || pkt.TransportHeader == nil { + if pkt.NetworkHeader().View().IsEmpty() || pkt.TransportHeader().View().IsEmpty() { return RuleDrop, 0 } @@ -107,28 +168,25 @@ func (rt RedirectTarget) Action(pkt *PacketBuffer, ct *ConnTrack, hook Hook, gso // to primary address of the incoming interface in Prerouting. switch hook { case Output: - rt.MinIP = tcpip.Address([]byte{127, 0, 0, 1}) - rt.MaxIP = tcpip.Address([]byte{127, 0, 0, 1}) + rt.Addr = tcpip.Address([]byte{127, 0, 0, 1}) case Prerouting: - rt.MinIP = address - rt.MaxIP = address + rt.Addr = address default: panic("redirect target is supported only on output and prerouting hooks") } // TODO(gvisor.dev/issue/170): Check Flags in RedirectTarget if // we need to change dest address (for OUTPUT chain) or ports. - netHeader := header.IPv4(pkt.NetworkHeader) + netHeader := header.IPv4(pkt.NetworkHeader().View()) switch protocol := netHeader.TransportProtocol(); protocol { case header.UDPProtocolNumber: - udpHeader := header.UDP(pkt.TransportHeader) - udpHeader.SetDestinationPort(rt.MinPort) + udpHeader := header.UDP(pkt.TransportHeader().View()) + udpHeader.SetDestinationPort(rt.Port) // Calculate UDP checksum and set it. if hook == Output { udpHeader.SetChecksum(0) - hdr := &pkt.Header - length := uint16(pkt.Data.Size()+hdr.UsedLength()) - uint16(netHeader.HeaderLength()) + length := uint16(pkt.Size()) - uint16(netHeader.HeaderLength()) // Only calculate the checksum if offloading isn't supported. if r.Capabilities()&CapabilityTXChecksumOffload == 0 { @@ -141,7 +199,7 @@ func (rt RedirectTarget) Action(pkt *PacketBuffer, ct *ConnTrack, hook Hook, gso } } // Change destination address. - netHeader.SetDestinationAddress(rt.MinIP) + netHeader.SetDestinationAddress(rt.Addr) netHeader.SetChecksum(0) netHeader.SetChecksum(^netHeader.CalculateChecksum()) pkt.NatDone = true diff --git a/pkg/tcpip/stack/iptables_types.go b/pkg/tcpip/stack/iptables_types.go index 73274ada9..7b3f3e88b 100644 --- a/pkg/tcpip/stack/iptables_types.go +++ b/pkg/tcpip/stack/iptables_types.go @@ -15,6 +15,7 @@ package stack import ( + "fmt" "strings" "sync" @@ -81,31 +82,42 @@ const ( // // +stateify savable type IPTables struct { - // mu protects tables, priorities, and modified. + // mu protects v4Tables, v6Tables, and modified. mu sync.RWMutex - - // tables maps tableIDs to tables. Holds builtin tables only, not user - // tables. mu must be locked for accessing. - tables [numTables]Table - - // priorities maps each hook to a list of table names. The order of the - // list is the order in which each table should be visited for that - // hook. mu needs to be locked for accessing. - priorities [NumHooks][]tableID - + // v4Tables and v6tables map tableIDs to tables. They hold builtin + // tables only, not user tables. mu must be locked for accessing. + v4Tables [numTables]Table + v6Tables [numTables]Table // modified is whether tables have been modified at least once. It is // used to elide the iptables performance overhead for workloads that // don't utilize iptables. modified bool + // priorities maps each hook to a list of table names. The order of the + // list is the order in which each table should be visited for that + // hook. It is immutable. + priorities [NumHooks][]tableID + connections ConnTrack - // reaperDone can be signalled to stop the reaper goroutine. + // reaperDone can be signaled to stop the reaper goroutine. reaperDone chan struct{} } -// A Table defines a set of chains and hooks into the network stack. It is -// really just a list of rules. +// A Table defines a set of chains and hooks into the network stack. +// +// It is a list of Rules, entry points (BuiltinChains), and error handlers +// (Underflows). As packets traverse netstack, they hit hooks. When a packet +// hits a hook, iptables compares it to Rules starting from that hook's entry +// point. So if a packet hits the Input hook, we look up the corresponding +// entry point in BuiltinChains and jump to that point. +// +// If the Rule doesn't match the packet, iptables continues to the next Rule. +// If a Rule does match, it can issue a verdict on the packet (e.g. RuleAccept +// or RuleDrop) that causes the packet to stop traversing iptables. It can also +// jump to other rules or perform custom actions based on Rule.Target. +// +// Underflow Rules are invoked when a chain returns without reaching a verdict. // // +stateify savable type Table struct { @@ -148,13 +160,18 @@ type Rule struct { Target Target } -// IPHeaderFilter holds basic IP filtering data common to every rule. +// IPHeaderFilter performs basic IP header matching common to every rule. // // +stateify savable type IPHeaderFilter struct { // Protocol matches the transport protocol. Protocol tcpip.TransportProtocolNumber + // CheckProtocol determines whether the Protocol field should be + // checked during matching. + // TODO(gvisor.dev/issue/3549): Check this field during matching. + CheckProtocol bool + // Dst matches the destination IP address. Dst tcpip.Address @@ -191,16 +208,43 @@ type IPHeaderFilter struct { OutputInterfaceInvert bool } -// match returns whether hdr matches the filter. -func (fl IPHeaderFilter) match(hdr header.IPv4, hook Hook, nicName string) bool { - // TODO(gvisor.dev/issue/170): Support other fields of the filter. +// match returns whether pkt matches the filter. +// +// Preconditions: pkt.NetworkHeader is set and is at least of the minimal IPv4 +// or IPv6 header length. +func (fl IPHeaderFilter) match(pkt *PacketBuffer, hook Hook, nicName string) bool { + // Extract header fields. + var ( + // TODO(gvisor.dev/issue/170): Support other filter fields. + transProto tcpip.TransportProtocolNumber + dstAddr tcpip.Address + srcAddr tcpip.Address + ) + switch proto := pkt.NetworkProtocolNumber; proto { + case header.IPv4ProtocolNumber: + hdr := header.IPv4(pkt.NetworkHeader().View()) + transProto = hdr.TransportProtocol() + dstAddr = hdr.DestinationAddress() + srcAddr = hdr.SourceAddress() + + case header.IPv6ProtocolNumber: + hdr := header.IPv6(pkt.NetworkHeader().View()) + transProto = hdr.TransportProtocol() + dstAddr = hdr.DestinationAddress() + srcAddr = hdr.SourceAddress() + + default: + panic(fmt.Sprintf("unknown network protocol with EtherType: %d", proto)) + } + // Check the transport protocol. - if fl.Protocol != 0 && fl.Protocol != hdr.TransportProtocol() { + if fl.CheckProtocol && fl.Protocol != transProto { return false } - // Check the source and destination IPs. - if !filterAddress(hdr.DestinationAddress(), fl.DstMask, fl.Dst, fl.DstInvert) || !filterAddress(hdr.SourceAddress(), fl.SrcMask, fl.Src, fl.SrcInvert) { + // Check the addresses. + if !filterAddress(dstAddr, fl.DstMask, fl.Dst, fl.DstInvert) || + !filterAddress(srcAddr, fl.SrcMask, fl.Src, fl.SrcInvert) { return false } @@ -228,6 +272,18 @@ func (fl IPHeaderFilter) match(hdr header.IPv4, hook Hook, nicName string) bool return true } +// NetworkProtocol returns the protocol (IPv4 or IPv6) on to which the header +// applies. +func (fl IPHeaderFilter) NetworkProtocol() tcpip.NetworkProtocolNumber { + switch len(fl.Src) { + case header.IPv4AddressSize: + return header.IPv4ProtocolNumber + case header.IPv6AddressSize: + return header.IPv6ProtocolNumber + } + panic(fmt.Sprintf("invalid address in IPHeaderFilter: %s", fl.Src)) +} + // filterAddress returns whether addr matches the filter. func filterAddress(addr, mask, filterAddr tcpip.Address, invert bool) bool { matches := true @@ -253,8 +309,23 @@ type Matcher interface { Match(hook Hook, packet *PacketBuffer, interfaceName string) (matches bool, hotdrop bool) } +// A TargetID uniquely identifies a target. +type TargetID struct { + // Name is the target name as stored in the xt_entry_target struct. + Name string + + // NetworkProtocol is the protocol to which the target applies. + NetworkProtocol tcpip.NetworkProtocolNumber + + // Revision is the version of the target. + Revision uint8 +} + // A Target is the interface for taking an action for a packet. type Target interface { + // ID uniquely identifies the Target. + ID() TargetID + // Action takes an action on the packet and returns a verdict on how // traversal should (or should not) continue. If the return value is // Jump, it also returns the index of the rule to jump to. diff --git a/pkg/tcpip/stack/linkaddrcache_test.go b/pkg/tcpip/stack/linkaddrcache_test.go index b15b8d1cb..33806340e 100644 --- a/pkg/tcpip/stack/linkaddrcache_test.go +++ b/pkg/tcpip/stack/linkaddrcache_test.go @@ -16,6 +16,7 @@ package stack import ( "fmt" + "math" "sync/atomic" "testing" "time" @@ -191,7 +192,13 @@ func TestCacheReplace(t *testing.T) { } func TestCacheResolution(t *testing.T) { - c := newLinkAddrCache(1<<63-1, 250*time.Millisecond, 1) + // There is a race condition causing this test to fail when the executor + // takes longer than the resolution timeout to call linkAddrCache.get. This + // is especially common when this test is run with gotsan. + // + // Using a large resolution timeout decreases the probability of experiencing + // this race condition and does not affect how long this test takes to run. + c := newLinkAddrCache(1<<63-1, math.MaxInt64, 1) linkRes := &testLinkAddressResolver{cache: c} for i, ta := range testAddrs { got, err := getBlocking(c, ta.addr, linkRes) @@ -275,3 +282,71 @@ func TestStaticResolution(t *testing.T) { t.Errorf("c.get(%q)=%q, want %q", string(addr), string(got), string(want)) } } + +// TestCacheWaker verifies that RemoveWaker removes a waker previously added +// through get(). +func TestCacheWaker(t *testing.T) { + c := newLinkAddrCache(1<<63-1, 1*time.Second, 3) + + // First, sanity check that wakers are working. + { + linkRes := &testLinkAddressResolver{cache: c} + s := sleep.Sleeper{} + defer s.Done() + + const wakerID = 1 + w := sleep.Waker{} + s.AddWaker(&w, wakerID) + + e := testAddrs[0] + + if _, _, err := c.get(e.addr, linkRes, "", nil, &w); err != tcpip.ErrWouldBlock { + t.Fatalf("got c.get(%q, _, _, _, _) = %s, want = %s", e.addr.Addr, err, tcpip.ErrWouldBlock) + } + id, ok := s.Fetch(true /* block */) + if !ok { + t.Fatal("got s.Fetch(true) = (_, false), want = (_, true)") + } + if id != wakerID { + t.Fatalf("got s.Fetch(true) = (%d, %t), want = (%d, true)", id, ok, wakerID) + } + + if got, _, err := c.get(e.addr, linkRes, "", nil, nil); err != nil { + t.Fatalf("c.get(%q, _, _, _, _): %s", e.addr.Addr, err) + } else if got != e.linkAddr { + t.Fatalf("got c.get(%q) = %q, want = %q", e.addr.Addr, got, e.linkAddr) + } + } + + // Check that RemoveWaker works. + { + linkRes := &testLinkAddressResolver{cache: c} + s := sleep.Sleeper{} + defer s.Done() + + const wakerID = 2 // different than the ID used in the sanity check + w := sleep.Waker{} + s.AddWaker(&w, wakerID) + + e := testAddrs[1] + linkRes.onLinkAddressRequest = func() { + // Remove the waker before the linkAddrCache has the opportunity to send + // a notification. + c.removeWaker(e.addr, &w) + } + + if _, _, err := c.get(e.addr, linkRes, "", nil, &w); err != tcpip.ErrWouldBlock { + t.Fatalf("got c.get(%q, _, _, _, _) = %s, want = %s", e.addr.Addr, err, tcpip.ErrWouldBlock) + } + + if got, err := getBlocking(c, e.addr, linkRes); err != nil { + t.Fatalf("c.get(%q, _, _, _, _): %s", e.addr.Addr, err) + } else if got != e.linkAddr { + t.Fatalf("c.get(%q) = %q, want = %q", e.addr.Addr, got, e.linkAddr) + } + + if id, ok := s.Fetch(false /* block */); ok { + t.Fatalf("unexpected notification from waker with id %d", id) + } + } +} diff --git a/pkg/tcpip/stack/ndp.go b/pkg/tcpip/stack/ndp.go deleted file mode 100644 index 5174e639c..000000000 --- a/pkg/tcpip/stack/ndp.go +++ /dev/null @@ -1,1965 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package stack - -import ( - "fmt" - "log" - "math/rand" - "time" - - "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/buffer" - "gvisor.dev/gvisor/pkg/tcpip/header" -) - -const ( - // defaultDupAddrDetectTransmits is the default number of NDP Neighbor - // Solicitation messages to send when doing Duplicate Address Detection - // for a tentative address. - // - // Default = 1 (from RFC 4862 section 5.1) - defaultDupAddrDetectTransmits = 1 - - // defaultMaxRtrSolicitations is the default number of Router - // Solicitation messages to send when a NIC becomes enabled. - // - // Default = 3 (from RFC 4861 section 10). - defaultMaxRtrSolicitations = 3 - - // defaultRtrSolicitationInterval is the default amount of time between - // sending Router Solicitation messages. - // - // Default = 4s (from 4861 section 10). - defaultRtrSolicitationInterval = 4 * time.Second - - // defaultMaxRtrSolicitationDelay is the default maximum amount of time - // to wait before sending the first Router Solicitation message. - // - // Default = 1s (from 4861 section 10). - defaultMaxRtrSolicitationDelay = time.Second - - // defaultHandleRAs is the default configuration for whether or not to - // handle incoming Router Advertisements as a host. - defaultHandleRAs = true - - // defaultDiscoverDefaultRouters is the default configuration for - // whether or not to discover default routers from incoming Router - // Advertisements, as a host. - defaultDiscoverDefaultRouters = true - - // defaultDiscoverOnLinkPrefixes is the default configuration for - // whether or not to discover on-link prefixes from incoming Router - // Advertisements' Prefix Information option, as a host. - defaultDiscoverOnLinkPrefixes = true - - // defaultAutoGenGlobalAddresses is the default configuration for - // whether or not to generate global IPv6 addresses in response to - // receiving a new Prefix Information option with its Autonomous - // Address AutoConfiguration flag set, as a host. - // - // Default = true. - defaultAutoGenGlobalAddresses = true - - // minimumRtrSolicitationInterval is the minimum amount of time to wait - // between sending Router Solicitation messages. This limit is imposed - // to make sure that Router Solicitation messages are not sent all at - // once, defeating the purpose of sending the initial few messages. - minimumRtrSolicitationInterval = 500 * time.Millisecond - - // minimumMaxRtrSolicitationDelay is the minimum amount of time to wait - // before sending the first Router Solicitation message. It is 0 because - // we cannot have a negative delay. - minimumMaxRtrSolicitationDelay = 0 - - // MaxDiscoveredDefaultRouters is the maximum number of discovered - // default routers. The stack should stop discovering new routers after - // discovering MaxDiscoveredDefaultRouters routers. - // - // This value MUST be at minimum 2 as per RFC 4861 section 6.3.4, and - // SHOULD be more. - MaxDiscoveredDefaultRouters = 10 - - // MaxDiscoveredOnLinkPrefixes is the maximum number of discovered - // on-link prefixes. The stack should stop discovering new on-link - // prefixes after discovering MaxDiscoveredOnLinkPrefixes on-link - // prefixes. - MaxDiscoveredOnLinkPrefixes = 10 - - // validPrefixLenForAutoGen is the expected prefix length that an - // address can be generated for. Must be 64 bits as the interface - // identifier (IID) is 64 bits and an IPv6 address is 128 bits, so - // 128 - 64 = 64. - validPrefixLenForAutoGen = 64 - - // defaultAutoGenTempGlobalAddresses is the default configuration for whether - // or not to generate temporary SLAAC addresses. - defaultAutoGenTempGlobalAddresses = true - - // defaultMaxTempAddrValidLifetime is the default maximum valid lifetime - // for temporary SLAAC addresses generated as part of RFC 4941. - // - // Default = 7 days (from RFC 4941 section 5). - defaultMaxTempAddrValidLifetime = 7 * 24 * time.Hour - - // defaultMaxTempAddrPreferredLifetime is the default preferred lifetime - // for temporary SLAAC addresses generated as part of RFC 4941. - // - // Default = 1 day (from RFC 4941 section 5). - defaultMaxTempAddrPreferredLifetime = 24 * time.Hour - - // defaultRegenAdvanceDuration is the default duration before the deprecation - // of a temporary address when a new address will be generated. - // - // Default = 5s (from RFC 4941 section 5). - defaultRegenAdvanceDuration = 5 * time.Second - - // minRegenAdvanceDuration is the minimum duration before the deprecation - // of a temporary address when a new address will be generated. - minRegenAdvanceDuration = time.Duration(0) - - // maxSLAACAddrLocalRegenAttempts is the maximum number of times to attempt - // SLAAC address regenerations in response to a NIC-local conflict. - maxSLAACAddrLocalRegenAttempts = 10 -) - -var ( - // MinPrefixInformationValidLifetimeForUpdate is the minimum Valid - // Lifetime to update the valid lifetime of a generated address by - // SLAAC. - // - // This is exported as a variable (instead of a constant) so tests - // can update it to a smaller value. - // - // Min = 2hrs. - MinPrefixInformationValidLifetimeForUpdate = 2 * time.Hour - - // MaxDesyncFactor is the upper bound for the preferred lifetime's desync - // factor for temporary SLAAC addresses. - // - // This is exported as a variable (instead of a constant) so tests - // can update it to a smaller value. - // - // Must be greater than 0. - // - // Max = 10m (from RFC 4941 section 5). - MaxDesyncFactor = 10 * time.Minute - - // MinMaxTempAddrPreferredLifetime is the minimum value allowed for the - // maximum preferred lifetime for temporary SLAAC addresses. - // - // This is exported as a variable (instead of a constant) so tests - // can update it to a smaller value. - // - // This value guarantees that a temporary address will be preferred for at - // least 1hr if the SLAAC prefix is valid for at least that time. - MinMaxTempAddrPreferredLifetime = defaultRegenAdvanceDuration + MaxDesyncFactor + time.Hour - - // MinMaxTempAddrValidLifetime is the minimum value allowed for the - // maximum valid lifetime for temporary SLAAC addresses. - // - // This is exported as a variable (instead of a constant) so tests - // can update it to a smaller value. - // - // This value guarantees that a temporary address will be valid for at least - // 2hrs if the SLAAC prefix is valid for at least that time. - MinMaxTempAddrValidLifetime = 2 * time.Hour -) - -// DHCPv6ConfigurationFromNDPRA is a configuration available via DHCPv6 that an -// NDP Router Advertisement informed the Stack about. -type DHCPv6ConfigurationFromNDPRA int - -const ( - _ DHCPv6ConfigurationFromNDPRA = iota - - // DHCPv6NoConfiguration indicates that no configurations are available via - // DHCPv6. - DHCPv6NoConfiguration - - // DHCPv6ManagedAddress indicates that addresses are available via DHCPv6. - // - // DHCPv6ManagedAddress also implies DHCPv6OtherConfigurations because DHCPv6 - // will return all available configuration information. - DHCPv6ManagedAddress - - // DHCPv6OtherConfigurations indicates that other configuration information is - // available via DHCPv6. - // - // Other configurations are configurations other than addresses. Examples of - // other configurations are recursive DNS server list, DNS search lists and - // default gateway. - DHCPv6OtherConfigurations -) - -// NDPDispatcher is the interface integrators of netstack must implement to -// receive and handle NDP related events. -type NDPDispatcher interface { - // OnDuplicateAddressDetectionStatus will be called when the DAD process - // for an address (addr) on a NIC (with ID nicID) completes. resolved - // will be set to true if DAD completed successfully (no duplicate addr - // detected); false otherwise (addr was detected to be a duplicate on - // the link the NIC is a part of, or it was stopped for some other - // reason, such as the address being removed). If an error occured - // during DAD, err will be set and resolved must be ignored. - // - // This function is not permitted to block indefinitely. This function - // is also not permitted to call into the stack. - OnDuplicateAddressDetectionStatus(nicID tcpip.NICID, addr tcpip.Address, resolved bool, err *tcpip.Error) - - // OnDefaultRouterDiscovered will be called when a new default router is - // discovered. Implementations must return true if the newly discovered - // router should be remembered. - // - // This function is not permitted to block indefinitely. This function - // is also not permitted to call into the stack. - OnDefaultRouterDiscovered(nicID tcpip.NICID, addr tcpip.Address) bool - - // OnDefaultRouterInvalidated will be called when a discovered default - // router that was remembered is invalidated. - // - // This function is not permitted to block indefinitely. This function - // is also not permitted to call into the stack. - OnDefaultRouterInvalidated(nicID tcpip.NICID, addr tcpip.Address) - - // OnOnLinkPrefixDiscovered will be called when a new on-link prefix is - // discovered. Implementations must return true if the newly discovered - // on-link prefix should be remembered. - // - // This function is not permitted to block indefinitely. This function - // is also not permitted to call into the stack. - OnOnLinkPrefixDiscovered(nicID tcpip.NICID, prefix tcpip.Subnet) bool - - // OnOnLinkPrefixInvalidated will be called when a discovered on-link - // prefix that was remembered is invalidated. - // - // This function is not permitted to block indefinitely. This function - // is also not permitted to call into the stack. - OnOnLinkPrefixInvalidated(nicID tcpip.NICID, prefix tcpip.Subnet) - - // OnAutoGenAddress will be called when a new prefix with its - // autonomous address-configuration flag set has been received and SLAAC - // has been performed. Implementations may prevent the stack from - // assigning the address to the NIC by returning false. - // - // This function is not permitted to block indefinitely. It must not - // call functions on the stack itself. - OnAutoGenAddress(tcpip.NICID, tcpip.AddressWithPrefix) bool - - // OnAutoGenAddressDeprecated will be called when an auto-generated - // address (as part of SLAAC) has been deprecated, but is still - // considered valid. Note, if an address is invalidated at the same - // time it is deprecated, the deprecation event MAY be omitted. - // - // This function is not permitted to block indefinitely. It must not - // call functions on the stack itself. - OnAutoGenAddressDeprecated(tcpip.NICID, tcpip.AddressWithPrefix) - - // OnAutoGenAddressInvalidated will be called when an auto-generated - // address (as part of SLAAC) has been invalidated. - // - // This function is not permitted to block indefinitely. It must not - // call functions on the stack itself. - OnAutoGenAddressInvalidated(tcpip.NICID, tcpip.AddressWithPrefix) - - // OnRecursiveDNSServerOption will be called when an NDP option with - // recursive DNS servers has been received. Note, addrs may contain - // link-local addresses. - // - // It is up to the caller to use the DNS Servers only for their valid - // lifetime. OnRecursiveDNSServerOption may be called for new or - // already known DNS servers. If called with known DNS servers, their - // valid lifetimes must be refreshed to lifetime (it may be increased, - // decreased, or completely invalidated when lifetime = 0). - // - // This function is not permitted to block indefinitely. It must not - // call functions on the stack itself. - OnRecursiveDNSServerOption(nicID tcpip.NICID, addrs []tcpip.Address, lifetime time.Duration) - - // OnDNSSearchListOption will be called when an NDP option with a DNS - // search list has been received. - // - // It is up to the caller to use the domain names in the search list - // for only their valid lifetime. OnDNSSearchListOption may be called - // with new or already known domain names. If called with known domain - // names, their valid lifetimes must be refreshed to lifetime (it may - // be increased, decreased or completely invalidated when lifetime = 0. - OnDNSSearchListOption(nicID tcpip.NICID, domainNames []string, lifetime time.Duration) - - // OnDHCPv6Configuration will be called with an updated configuration that is - // available via DHCPv6 for a specified NIC. - // - // This function is not permitted to block indefinitely. It must not - // call functions on the stack itself. - OnDHCPv6Configuration(tcpip.NICID, DHCPv6ConfigurationFromNDPRA) -} - -// NDPConfigurations is the NDP configurations for the netstack. -type NDPConfigurations struct { - // The number of Neighbor Solicitation messages to send when doing - // Duplicate Address Detection for a tentative address. - // - // Note, a value of zero effectively disables DAD. - DupAddrDetectTransmits uint8 - - // The amount of time to wait between sending Neighbor solicitation - // messages. - // - // Must be greater than or equal to 1ms. - RetransmitTimer time.Duration - - // The number of Router Solicitation messages to send when the NIC - // becomes enabled. - MaxRtrSolicitations uint8 - - // The amount of time between transmitting Router Solicitation messages. - // - // Must be greater than or equal to 0.5s. - RtrSolicitationInterval time.Duration - - // The maximum amount of time before transmitting the first Router - // Solicitation message. - // - // Must be greater than or equal to 0s. - MaxRtrSolicitationDelay time.Duration - - // HandleRAs determines whether or not Router Advertisements will be - // processed. - HandleRAs bool - - // DiscoverDefaultRouters determines whether or not default routers will - // be discovered from Router Advertisements. This configuration is - // ignored if HandleRAs is false. - DiscoverDefaultRouters bool - - // DiscoverOnLinkPrefixes determines whether or not on-link prefixes - // will be discovered from Router Advertisements' Prefix Information - // option. This configuration is ignored if HandleRAs is false. - DiscoverOnLinkPrefixes bool - - // AutoGenGlobalAddresses determines whether or not global IPv6 - // addresses will be generated for a NIC in response to receiving a new - // Prefix Information option with its Autonomous Address - // AutoConfiguration flag set, as a host, as per RFC 4862 (SLAAC). - // - // Note, if an address was already generated for some unique prefix, as - // part of SLAAC, this option does not affect whether or not the - // lifetime(s) of the generated address changes; this option only - // affects the generation of new addresses as part of SLAAC. - AutoGenGlobalAddresses bool - - // AutoGenAddressConflictRetries determines how many times to attempt to retry - // generation of a permanent auto-generated address in response to DAD - // conflicts. - // - // If the method used to generate the address does not support creating - // alternative addresses (e.g. IIDs based on the modified EUI64 of a NIC's - // MAC address), then no attempt will be made to resolve the conflict. - AutoGenAddressConflictRetries uint8 - - // AutoGenTempGlobalAddresses determines whether or not temporary SLAAC - // addresses will be generated for a NIC as part of SLAAC privacy extensions, - // RFC 4941. - // - // Ignored if AutoGenGlobalAddresses is false. - AutoGenTempGlobalAddresses bool - - // MaxTempAddrValidLifetime is the maximum valid lifetime for temporary - // SLAAC addresses. - MaxTempAddrValidLifetime time.Duration - - // MaxTempAddrPreferredLifetime is the maximum preferred lifetime for - // temporary SLAAC addresses. - MaxTempAddrPreferredLifetime time.Duration - - // RegenAdvanceDuration is the duration before the deprecation of a temporary - // address when a new address will be generated. - RegenAdvanceDuration time.Duration -} - -// DefaultNDPConfigurations returns an NDPConfigurations populated with -// default values. -func DefaultNDPConfigurations() NDPConfigurations { - return NDPConfigurations{ - DupAddrDetectTransmits: defaultDupAddrDetectTransmits, - RetransmitTimer: defaultRetransmitTimer, - MaxRtrSolicitations: defaultMaxRtrSolicitations, - RtrSolicitationInterval: defaultRtrSolicitationInterval, - MaxRtrSolicitationDelay: defaultMaxRtrSolicitationDelay, - HandleRAs: defaultHandleRAs, - DiscoverDefaultRouters: defaultDiscoverDefaultRouters, - DiscoverOnLinkPrefixes: defaultDiscoverOnLinkPrefixes, - AutoGenGlobalAddresses: defaultAutoGenGlobalAddresses, - AutoGenTempGlobalAddresses: defaultAutoGenTempGlobalAddresses, - MaxTempAddrValidLifetime: defaultMaxTempAddrValidLifetime, - MaxTempAddrPreferredLifetime: defaultMaxTempAddrPreferredLifetime, - RegenAdvanceDuration: defaultRegenAdvanceDuration, - } -} - -// validate modifies an NDPConfigurations with valid values. If invalid values -// are present in c, the corresponding default values will be used instead. -func (c *NDPConfigurations) validate() { - if c.RetransmitTimer < minimumRetransmitTimer { - c.RetransmitTimer = defaultRetransmitTimer - } - - if c.RtrSolicitationInterval < minimumRtrSolicitationInterval { - c.RtrSolicitationInterval = defaultRtrSolicitationInterval - } - - if c.MaxRtrSolicitationDelay < minimumMaxRtrSolicitationDelay { - c.MaxRtrSolicitationDelay = defaultMaxRtrSolicitationDelay - } - - if c.MaxTempAddrValidLifetime < MinMaxTempAddrValidLifetime { - c.MaxTempAddrValidLifetime = MinMaxTempAddrValidLifetime - } - - if c.MaxTempAddrPreferredLifetime < MinMaxTempAddrPreferredLifetime || c.MaxTempAddrPreferredLifetime > c.MaxTempAddrValidLifetime { - c.MaxTempAddrPreferredLifetime = MinMaxTempAddrPreferredLifetime - } - - if c.RegenAdvanceDuration < minRegenAdvanceDuration { - c.RegenAdvanceDuration = minRegenAdvanceDuration - } -} - -// ndpState is the per-interface NDP state. -type ndpState struct { - // The NIC this ndpState is for. - nic *NIC - - // configs is the per-interface NDP configurations. - configs NDPConfigurations - - // The DAD state to send the next NS message, or resolve the address. - dad map[tcpip.Address]dadState - - // The default routers discovered through Router Advertisements. - defaultRouters map[tcpip.Address]defaultRouterState - - rtrSolicit struct { - // The timer used to send the next router solicitation message. - timer tcpip.Timer - - // Used to let the Router Solicitation timer know that it has been stopped. - // - // Must only be read from or written to while protected by the lock of - // the NIC this ndpState is associated with. MUST be set when the timer is - // set. - done *bool - } - - // The on-link prefixes discovered through Router Advertisements' Prefix - // Information option. - onLinkPrefixes map[tcpip.Subnet]onLinkPrefixState - - // The SLAAC prefixes discovered through Router Advertisements' Prefix - // Information option. - slaacPrefixes map[tcpip.Subnet]slaacPrefixState - - // The last learned DHCPv6 configuration from an NDP RA. - dhcpv6Configuration DHCPv6ConfigurationFromNDPRA - - // temporaryIIDHistory is the history value used to generate a new temporary - // IID. - temporaryIIDHistory [header.IIDSize]byte - - // temporaryAddressDesyncFactor is the preferred lifetime's desync factor for - // temporary SLAAC addresses. - temporaryAddressDesyncFactor time.Duration -} - -// dadState holds the Duplicate Address Detection timer and channel to signal -// to the DAD goroutine that DAD should stop. -type dadState struct { - // The DAD timer to send the next NS message, or resolve the address. - timer tcpip.Timer - - // Used to let the DAD timer know that it has been stopped. - // - // Must only be read from or written to while protected by the lock of - // the NIC this dadState is associated with. - done *bool -} - -// defaultRouterState holds data associated with a default router discovered by -// a Router Advertisement (RA). -type defaultRouterState struct { - // Job to invalidate the default router. - // - // Must not be nil. - invalidationJob *tcpip.Job -} - -// onLinkPrefixState holds data associated with an on-link prefix discovered by -// a Router Advertisement's Prefix Information option (PI) when the NDP -// configurations was configured to do so. -type onLinkPrefixState struct { - // Job to invalidate the on-link prefix. - // - // Must not be nil. - invalidationJob *tcpip.Job -} - -// tempSLAACAddrState holds state associated with a temporary SLAAC address. -type tempSLAACAddrState struct { - // Job to deprecate the temporary SLAAC address. - // - // Must not be nil. - deprecationJob *tcpip.Job - - // Job to invalidate the temporary SLAAC address. - // - // Must not be nil. - invalidationJob *tcpip.Job - - // Job to regenerate the temporary SLAAC address. - // - // Must not be nil. - regenJob *tcpip.Job - - createdAt time.Time - - // The address's endpoint. - // - // Must not be nil. - ref *referencedNetworkEndpoint - - // Has a new temporary SLAAC address already been regenerated? - regenerated bool -} - -// slaacPrefixState holds state associated with a SLAAC prefix. -type slaacPrefixState struct { - // Job to deprecate the prefix. - // - // Must not be nil. - deprecationJob *tcpip.Job - - // Job to invalidate the prefix. - // - // Must not be nil. - invalidationJob *tcpip.Job - - // Nonzero only when the address is not valid forever. - validUntil time.Time - - // Nonzero only when the address is not preferred forever. - preferredUntil time.Time - - // State associated with the stable address generated for the prefix. - stableAddr struct { - // The address's endpoint. - // - // May only be nil when the address is being (re-)generated. Otherwise, - // must not be nil as all SLAAC prefixes must have a stable address. - ref *referencedNetworkEndpoint - - // The number of times an address has been generated locally where the NIC - // already had the generated address. - localGenerationFailures uint8 - } - - // The temporary (short-lived) addresses generated for the SLAAC prefix. - tempAddrs map[tcpip.Address]tempSLAACAddrState - - // The next two fields are used by both stable and temporary addresses - // generated for a SLAAC prefix. This is safe as only 1 address will be - // in the generation and DAD process at any time. That is, no two addresses - // will be generated at the same time for a given SLAAC prefix. - - // The number of times an address has been generated and added to the NIC. - // - // Addresses may be regenerated in reseponse to a DAD conflicts. - generationAttempts uint8 - - // The maximum number of times to attempt regeneration of a SLAAC address - // in response to DAD conflicts. - maxGenerationAttempts uint8 -} - -// startDuplicateAddressDetection performs Duplicate Address Detection. -// -// This function must only be called by IPv6 addresses that are currently -// tentative. -// -// The NIC that ndp belongs to MUST be locked. -func (ndp *ndpState) startDuplicateAddressDetection(addr tcpip.Address, ref *referencedNetworkEndpoint) *tcpip.Error { - // addr must be a valid unicast IPv6 address. - if !header.IsV6UnicastAddress(addr) { - return tcpip.ErrAddressFamilyNotSupported - } - - if ref.getKind() != permanentTentative { - // The endpoint should be marked as tentative since we are starting DAD. - panic(fmt.Sprintf("ndpdad: addr %s is not tentative on NIC(%d)", addr, ndp.nic.ID())) - } - - // Should not attempt to perform DAD on an address that is currently in the - // DAD process. - if _, ok := ndp.dad[addr]; ok { - // Should never happen because we should only ever call this function for - // newly created addresses. If we attemped to "add" an address that already - // existed, we would get an error since we attempted to add a duplicate - // address, or its reference count would have been increased without doing - // the work that would have been done for an address that was brand new. - // See NIC.addAddressLocked. - panic(fmt.Sprintf("ndpdad: already performing DAD for addr %s on NIC(%d)", addr, ndp.nic.ID())) - } - - remaining := ndp.configs.DupAddrDetectTransmits - if remaining == 0 { - ref.setKind(permanent) - - // Consider DAD to have resolved even if no DAD messages were actually - // transmitted. - if ndpDisp := ndp.nic.stack.ndpDisp; ndpDisp != nil { - ndpDisp.OnDuplicateAddressDetectionStatus(ndp.nic.ID(), addr, true, nil) - } - - return nil - } - - var done bool - var timer tcpip.Timer - // We initially start a timer to fire immediately because some of the DAD work - // cannot be done while holding the NIC's lock. This is effectively the same - // as starting a goroutine but we use a timer that fires immediately so we can - // reset it for the next DAD iteration. - timer = ndp.nic.stack.Clock().AfterFunc(0, func() { - ndp.nic.mu.Lock() - defer ndp.nic.mu.Unlock() - - if done { - // If we reach this point, it means that the DAD timer fired after - // another goroutine already obtained the NIC lock and stopped DAD - // before this function obtained the NIC lock. Simply return here and do - // nothing further. - return - } - - if ref.getKind() != permanentTentative { - // The endpoint should still be marked as tentative since we are still - // performing DAD on it. - panic(fmt.Sprintf("ndpdad: addr %s is no longer tentative on NIC(%d)", addr, ndp.nic.ID())) - } - - dadDone := remaining == 0 - - var err *tcpip.Error - if !dadDone { - // Use the unspecified address as the source address when performing DAD. - ref := ndp.nic.getRefOrCreateTempLocked(header.IPv6ProtocolNumber, header.IPv6Any, NeverPrimaryEndpoint) - - // Do not hold the lock when sending packets which may be a long running - // task or may block link address resolution. We know this is safe - // because immediately after obtaining the lock again, we check if DAD - // has been stopped before doing any work with the NIC. Note, DAD would be - // stopped if the NIC was disabled or removed, or if the address was - // removed. - ndp.nic.mu.Unlock() - err = ndp.sendDADPacket(addr, ref) - ndp.nic.mu.Lock() - } - - if done { - // If we reach this point, it means that DAD was stopped after we released - // the NIC's read lock and before we obtained the write lock. - return - } - - if dadDone { - // DAD has resolved. - ref.setKind(permanent) - } else if err == nil { - // DAD is not done and we had no errors when sending the last NDP NS, - // schedule the next DAD timer. - remaining-- - timer.Reset(ndp.nic.stack.ndpConfigs.RetransmitTimer) - return - } - - // At this point we know that either DAD is done or we hit an error sending - // the last NDP NS. Either way, clean up addr's DAD state and let the - // integrator know DAD has completed. - delete(ndp.dad, addr) - - if ndpDisp := ndp.nic.stack.ndpDisp; ndpDisp != nil { - ndpDisp.OnDuplicateAddressDetectionStatus(ndp.nic.ID(), addr, dadDone, err) - } - - // If DAD resolved for a stable SLAAC address, attempt generation of a - // temporary SLAAC address. - if dadDone && ref.configType == slaac { - // Reset the generation attempts counter as we are starting the generation - // of a new address for the SLAAC prefix. - ndp.regenerateTempSLAACAddr(ref.addrWithPrefix().Subnet(), true /* resetGenAttempts */) - } - }) - - ndp.dad[addr] = dadState{ - timer: timer, - done: &done, - } - - return nil -} - -// sendDADPacket sends a NS message to see if any nodes on ndp's NIC's link owns -// addr. -// -// addr must be a tentative IPv6 address on ndp's NIC. -// -// The NIC ndp belongs to MUST NOT be locked. -func (ndp *ndpState) sendDADPacket(addr tcpip.Address, ref *referencedNetworkEndpoint) *tcpip.Error { - snmc := header.SolicitedNodeAddr(addr) - - r := makeRoute(header.IPv6ProtocolNumber, ref.ep.ID().LocalAddress, snmc, ndp.nic.linkEP.LinkAddress(), ref, false, false) - defer r.Release() - - // Route should resolve immediately since snmc is a multicast address so a - // remote link address can be calculated without a resolution process. - if c, err := r.Resolve(nil); err != nil { - // Do not consider the NIC being unknown or disabled as a fatal error. - // Since this method is required to be called when the NIC is not locked, - // the NIC could have been disabled or removed by another goroutine. - if err == tcpip.ErrUnknownNICID || err != tcpip.ErrInvalidEndpointState { - return err - } - - panic(fmt.Sprintf("ndp: error when resolving route to send NDP NS for DAD (%s -> %s on NIC(%d)): %s", header.IPv6Any, snmc, ndp.nic.ID(), err)) - } else if c != nil { - panic(fmt.Sprintf("ndp: route resolution not immediate for route to send NDP NS for DAD (%s -> %s on NIC(%d))", header.IPv6Any, snmc, ndp.nic.ID())) - } - - hdr := buffer.NewPrependable(int(r.MaxHeaderLength()) + header.ICMPv6NeighborSolicitMinimumSize) - pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6NeighborSolicitMinimumSize)) - pkt.SetType(header.ICMPv6NeighborSolicit) - ns := header.NDPNeighborSolicit(pkt.NDPPayload()) - ns.SetTargetAddress(addr) - pkt.SetChecksum(header.ICMPv6Checksum(pkt, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{})) - - sent := r.Stats().ICMP.V6PacketsSent - if err := r.WritePacket(nil, - NetworkHeaderParams{ - Protocol: header.ICMPv6ProtocolNumber, - TTL: header.NDPHopLimit, - TOS: DefaultTOS, - }, &PacketBuffer{Header: hdr}, - ); err != nil { - sent.Dropped.Increment() - return err - } - sent.NeighborSolicit.Increment() - - return nil -} - -// stopDuplicateAddressDetection ends a running Duplicate Address Detection -// process. Note, this may leave the DAD process for a tentative address in -// such a state forever, unless some other external event resolves the DAD -// process (receiving an NA from the true owner of addr, or an NS for addr -// (implying another node is attempting to use addr)). It is up to the caller -// of this function to handle such a scenario. Normally, addr will be removed -// from n right after this function returns or the address successfully -// resolved. -// -// The NIC that ndp belongs to MUST be locked. -func (ndp *ndpState) stopDuplicateAddressDetection(addr tcpip.Address) { - dad, ok := ndp.dad[addr] - if !ok { - // Not currently performing DAD on addr, just return. - return - } - - if dad.timer != nil { - dad.timer.Stop() - dad.timer = nil - - *dad.done = true - dad.done = nil - } - - delete(ndp.dad, addr) - - // Let the integrator know DAD did not resolve. - if ndpDisp := ndp.nic.stack.ndpDisp; ndpDisp != nil { - ndpDisp.OnDuplicateAddressDetectionStatus(ndp.nic.ID(), addr, false, nil) - } -} - -// handleRA handles a Router Advertisement message that arrived on the NIC -// this ndp is for. Does nothing if the NIC is configured to not handle RAs. -// -// The NIC that ndp belongs to MUST be locked. -func (ndp *ndpState) handleRA(ip tcpip.Address, ra header.NDPRouterAdvert) { - // Is the NIC configured to handle RAs at all? - // - // Currently, the stack does not determine router interface status on a - // per-interface basis; it is a stack-wide configuration, so we check - // stack's forwarding flag to determine if the NIC is a routing - // interface. - if !ndp.configs.HandleRAs || ndp.nic.stack.forwarding { - return - } - - // Only worry about the DHCPv6 configuration if we have an NDPDispatcher as we - // only inform the dispatcher on configuration changes. We do nothing else - // with the information. - if ndpDisp := ndp.nic.stack.ndpDisp; ndpDisp != nil { - var configuration DHCPv6ConfigurationFromNDPRA - switch { - case ra.ManagedAddrConfFlag(): - configuration = DHCPv6ManagedAddress - - case ra.OtherConfFlag(): - configuration = DHCPv6OtherConfigurations - - default: - configuration = DHCPv6NoConfiguration - } - - if ndp.dhcpv6Configuration != configuration { - ndp.dhcpv6Configuration = configuration - ndpDisp.OnDHCPv6Configuration(ndp.nic.ID(), configuration) - } - } - - // Is the NIC configured to discover default routers? - if ndp.configs.DiscoverDefaultRouters { - rtr, ok := ndp.defaultRouters[ip] - rl := ra.RouterLifetime() - switch { - case !ok && rl != 0: - // This is a new default router we are discovering. - // - // Only remember it if we currently know about less than - // MaxDiscoveredDefaultRouters routers. - if len(ndp.defaultRouters) < MaxDiscoveredDefaultRouters { - ndp.rememberDefaultRouter(ip, rl) - } - - case ok && rl != 0: - // This is an already discovered default router. Update - // the invalidation job. - rtr.invalidationJob.Cancel() - rtr.invalidationJob.Schedule(rl) - ndp.defaultRouters[ip] = rtr - - case ok && rl == 0: - // We know about the router but it is no longer to be - // used as a default router so invalidate it. - ndp.invalidateDefaultRouter(ip) - } - } - - // TODO(b/141556115): Do (RetransTimer, ReachableTime)) Parameter - // Discovery. - - // We know the options is valid as far as wire format is concerned since - // we got the Router Advertisement, as documented by this fn. Given this - // we do not check the iterator for errors on calls to Next. - it, _ := ra.Options().Iter(false) - for opt, done, _ := it.Next(); !done; opt, done, _ = it.Next() { - switch opt := opt.(type) { - case header.NDPRecursiveDNSServer: - if ndp.nic.stack.ndpDisp == nil { - continue - } - - addrs, _ := opt.Addresses() - ndp.nic.stack.ndpDisp.OnRecursiveDNSServerOption(ndp.nic.ID(), addrs, opt.Lifetime()) - - case header.NDPDNSSearchList: - if ndp.nic.stack.ndpDisp == nil { - continue - } - - domainNames, _ := opt.DomainNames() - ndp.nic.stack.ndpDisp.OnDNSSearchListOption(ndp.nic.ID(), domainNames, opt.Lifetime()) - - case header.NDPPrefixInformation: - prefix := opt.Subnet() - - // Is the prefix a link-local? - if header.IsV6LinkLocalAddress(prefix.ID()) { - // ...Yes, skip as per RFC 4861 section 6.3.4, - // and RFC 4862 section 5.5.3.b (for SLAAC). - continue - } - - // Is the Prefix Length 0? - if prefix.Prefix() == 0 { - // ...Yes, skip as this is an invalid prefix - // as all IPv6 addresses cannot be on-link. - continue - } - - if opt.OnLinkFlag() { - ndp.handleOnLinkPrefixInformation(opt) - } - - if opt.AutonomousAddressConfigurationFlag() { - ndp.handleAutonomousPrefixInformation(opt) - } - } - - // TODO(b/141556115): Do (MTU) Parameter Discovery. - } -} - -// invalidateDefaultRouter invalidates a discovered default router. -// -// The NIC that ndp belongs to MUST be locked. -func (ndp *ndpState) invalidateDefaultRouter(ip tcpip.Address) { - rtr, ok := ndp.defaultRouters[ip] - - // Is the router still discovered? - if !ok { - // ...Nope, do nothing further. - return - } - - rtr.invalidationJob.Cancel() - delete(ndp.defaultRouters, ip) - - // Let the integrator know a discovered default router is invalidated. - if ndpDisp := ndp.nic.stack.ndpDisp; ndpDisp != nil { - ndpDisp.OnDefaultRouterInvalidated(ndp.nic.ID(), ip) - } -} - -// rememberDefaultRouter remembers a newly discovered default router with IPv6 -// link-local address ip with lifetime rl. -// -// The router identified by ip MUST NOT already be known by the NIC. -// -// The NIC that ndp belongs to MUST be locked. -func (ndp *ndpState) rememberDefaultRouter(ip tcpip.Address, rl time.Duration) { - ndpDisp := ndp.nic.stack.ndpDisp - if ndpDisp == nil { - return - } - - // Inform the integrator when we discovered a default router. - if !ndpDisp.OnDefaultRouterDiscovered(ndp.nic.ID(), ip) { - // Informed by the integrator to not remember the router, do - // nothing further. - return - } - - state := defaultRouterState{ - invalidationJob: ndp.nic.stack.newJob(&ndp.nic.mu, func() { - ndp.invalidateDefaultRouter(ip) - }), - } - - state.invalidationJob.Schedule(rl) - - ndp.defaultRouters[ip] = state -} - -// rememberOnLinkPrefix remembers a newly discovered on-link prefix with IPv6 -// address with prefix prefix with lifetime l. -// -// The prefix identified by prefix MUST NOT already be known. -// -// The NIC that ndp belongs to MUST be locked. -func (ndp *ndpState) rememberOnLinkPrefix(prefix tcpip.Subnet, l time.Duration) { - ndpDisp := ndp.nic.stack.ndpDisp - if ndpDisp == nil { - return - } - - // Inform the integrator when we discovered an on-link prefix. - if !ndpDisp.OnOnLinkPrefixDiscovered(ndp.nic.ID(), prefix) { - // Informed by the integrator to not remember the prefix, do - // nothing further. - return - } - - state := onLinkPrefixState{ - invalidationJob: ndp.nic.stack.newJob(&ndp.nic.mu, func() { - ndp.invalidateOnLinkPrefix(prefix) - }), - } - - if l < header.NDPInfiniteLifetime { - state.invalidationJob.Schedule(l) - } - - ndp.onLinkPrefixes[prefix] = state -} - -// invalidateOnLinkPrefix invalidates a discovered on-link prefix. -// -// The NIC that ndp belongs to MUST be locked. -func (ndp *ndpState) invalidateOnLinkPrefix(prefix tcpip.Subnet) { - s, ok := ndp.onLinkPrefixes[prefix] - - // Is the on-link prefix still discovered? - if !ok { - // ...Nope, do nothing further. - return - } - - s.invalidationJob.Cancel() - delete(ndp.onLinkPrefixes, prefix) - - // Let the integrator know a discovered on-link prefix is invalidated. - if ndpDisp := ndp.nic.stack.ndpDisp; ndpDisp != nil { - ndpDisp.OnOnLinkPrefixInvalidated(ndp.nic.ID(), prefix) - } -} - -// handleOnLinkPrefixInformation handles a Prefix Information option with -// its on-link flag set, as per RFC 4861 section 6.3.4. -// -// handleOnLinkPrefixInformation assumes that the prefix this pi is for is -// not the link-local prefix and the on-link flag is set. -// -// The NIC that ndp belongs to MUST be locked. -func (ndp *ndpState) handleOnLinkPrefixInformation(pi header.NDPPrefixInformation) { - prefix := pi.Subnet() - prefixState, ok := ndp.onLinkPrefixes[prefix] - vl := pi.ValidLifetime() - - if !ok && vl == 0 { - // Don't know about this prefix but it has a zero valid - // lifetime, so just ignore. - return - } - - if !ok && vl != 0 { - // This is a new on-link prefix we are discovering - // - // Only remember it if we currently know about less than - // MaxDiscoveredOnLinkPrefixes on-link prefixes. - if ndp.configs.DiscoverOnLinkPrefixes && len(ndp.onLinkPrefixes) < MaxDiscoveredOnLinkPrefixes { - ndp.rememberOnLinkPrefix(prefix, vl) - } - return - } - - if ok && vl == 0 { - // We know about the on-link prefix, but it is - // no longer to be considered on-link, so - // invalidate it. - ndp.invalidateOnLinkPrefix(prefix) - return - } - - // This is an already discovered on-link prefix with a - // new non-zero valid lifetime. - // - // Update the invalidation job. - - prefixState.invalidationJob.Cancel() - - if vl < header.NDPInfiniteLifetime { - // Prefix is valid for a finite lifetime, schedule the job to execute after - // the new valid lifetime. - prefixState.invalidationJob.Schedule(vl) - } - - ndp.onLinkPrefixes[prefix] = prefixState -} - -// handleAutonomousPrefixInformation handles a Prefix Information option with -// its autonomous flag set, as per RFC 4862 section 5.5.3. -// -// handleAutonomousPrefixInformation assumes that the prefix this pi is for is -// not the link-local prefix and the autonomous flag is set. -// -// The NIC that ndp belongs to MUST be locked. -func (ndp *ndpState) handleAutonomousPrefixInformation(pi header.NDPPrefixInformation) { - vl := pi.ValidLifetime() - pl := pi.PreferredLifetime() - - // If the preferred lifetime is greater than the valid lifetime, - // silently ignore the Prefix Information option, as per RFC 4862 - // section 5.5.3.c. - if pl > vl { - return - } - - prefix := pi.Subnet() - - // Check if we already maintain SLAAC state for prefix. - if state, ok := ndp.slaacPrefixes[prefix]; ok { - // As per RFC 4862 section 5.5.3.e, refresh prefix's SLAAC lifetimes. - ndp.refreshSLAACPrefixLifetimes(prefix, &state, pl, vl) - ndp.slaacPrefixes[prefix] = state - return - } - - // prefix is a new SLAAC prefix. Do the work as outlined by RFC 4862 section - // 5.5.3.d if ndp is configured to auto-generate new addresses via SLAAC. - if !ndp.configs.AutoGenGlobalAddresses { - return - } - - ndp.doSLAAC(prefix, pl, vl) -} - -// doSLAAC generates a new SLAAC address with the provided lifetimes -// for prefix. -// -// pl is the new preferred lifetime. vl is the new valid lifetime. -// -// The NIC that ndp belongs to MUST be locked. -func (ndp *ndpState) doSLAAC(prefix tcpip.Subnet, pl, vl time.Duration) { - // If we do not already have an address for this prefix and the valid - // lifetime is 0, no need to do anything further, as per RFC 4862 - // section 5.5.3.d. - if vl == 0 { - return - } - - // Make sure the prefix is valid (as far as its length is concerned) to - // generate a valid IPv6 address from an interface identifier (IID), as - // per RFC 4862 sectiion 5.5.3.d. - if prefix.Prefix() != validPrefixLenForAutoGen { - return - } - - state := slaacPrefixState{ - deprecationJob: ndp.nic.stack.newJob(&ndp.nic.mu, func() { - state, ok := ndp.slaacPrefixes[prefix] - if !ok { - panic(fmt.Sprintf("ndp: must have a slaacPrefixes entry for the deprecated SLAAC prefix %s", prefix)) - } - - ndp.deprecateSLAACAddress(state.stableAddr.ref) - }), - invalidationJob: ndp.nic.stack.newJob(&ndp.nic.mu, func() { - state, ok := ndp.slaacPrefixes[prefix] - if !ok { - panic(fmt.Sprintf("ndp: must have a slaacPrefixes entry for the invalidated SLAAC prefix %s", prefix)) - } - - ndp.invalidateSLAACPrefix(prefix, state) - }), - tempAddrs: make(map[tcpip.Address]tempSLAACAddrState), - maxGenerationAttempts: ndp.configs.AutoGenAddressConflictRetries + 1, - } - - now := time.Now() - - // The time an address is preferred until is needed to properly generate the - // address. - if pl < header.NDPInfiniteLifetime { - state.preferredUntil = now.Add(pl) - } - - if !ndp.generateSLAACAddr(prefix, &state) { - // We were unable to generate an address for the prefix, we do not nothing - // further as there is no reason to maintain state or jobs for a prefix we - // do not have an address for. - return - } - - // Setup the initial jobs to deprecate and invalidate prefix. - - if pl < header.NDPInfiniteLifetime && pl != 0 { - state.deprecationJob.Schedule(pl) - } - - if vl < header.NDPInfiniteLifetime { - state.invalidationJob.Schedule(vl) - state.validUntil = now.Add(vl) - } - - // If the address is assigned (DAD resolved), generate a temporary address. - if state.stableAddr.ref.getKind() == permanent { - // Reset the generation attempts counter as we are starting the generation - // of a new address for the SLAAC prefix. - ndp.generateTempSLAACAddr(prefix, &state, true /* resetGenAttempts */) - } - - ndp.slaacPrefixes[prefix] = state -} - -// addSLAACAddr adds a SLAAC address to the NIC. -// -// The NIC that ndp belongs to MUST be locked. -func (ndp *ndpState) addSLAACAddr(addr tcpip.AddressWithPrefix, configType networkEndpointConfigType, deprecated bool) *referencedNetworkEndpoint { - // Inform the integrator that we have a new SLAAC address. - ndpDisp := ndp.nic.stack.ndpDisp - if ndpDisp == nil { - return nil - } - - if !ndpDisp.OnAutoGenAddress(ndp.nic.ID(), addr) { - // Informed by the integrator not to add the address. - return nil - } - - protocolAddr := tcpip.ProtocolAddress{ - Protocol: header.IPv6ProtocolNumber, - AddressWithPrefix: addr, - } - - ref, err := ndp.nic.addAddressLocked(protocolAddr, FirstPrimaryEndpoint, permanent, configType, deprecated) - if err != nil { - panic(fmt.Sprintf("ndp: error when adding SLAAC address %+v: %s", protocolAddr, err)) - } - - return ref -} - -// generateSLAACAddr generates a SLAAC address for prefix. -// -// Returns true if an address was successfully generated. -// -// Panics if the prefix is not a SLAAC prefix or it already has an address. -// -// The NIC that ndp belongs to MUST be locked. -func (ndp *ndpState) generateSLAACAddr(prefix tcpip.Subnet, state *slaacPrefixState) bool { - if r := state.stableAddr.ref; r != nil { - panic(fmt.Sprintf("ndp: SLAAC prefix %s already has a permenant address %s", prefix, r.addrWithPrefix())) - } - - // If we have already reached the maximum address generation attempts for the - // prefix, do not generate another address. - if state.generationAttempts == state.maxGenerationAttempts { - return false - } - - var generatedAddr tcpip.AddressWithPrefix - addrBytes := []byte(prefix.ID()) - - for i := 0; ; i++ { - // If we were unable to generate an address after the maximum SLAAC address - // local regeneration attempts, do nothing further. - if i == maxSLAACAddrLocalRegenAttempts { - return false - } - - dadCounter := state.generationAttempts + state.stableAddr.localGenerationFailures - if oIID := ndp.nic.stack.opaqueIIDOpts; oIID.NICNameFromID != nil { - addrBytes = header.AppendOpaqueInterfaceIdentifier( - addrBytes[:header.IIDOffsetInIPv6Address], - prefix, - oIID.NICNameFromID(ndp.nic.ID(), ndp.nic.name), - dadCounter, - oIID.SecretKey, - ) - } else if dadCounter == 0 { - // Modified-EUI64 based IIDs have no way to resolve DAD conflicts, so if - // the DAD counter is non-zero, we cannot use this method. - // - // Only attempt to generate an interface-specific IID if we have a valid - // link address. - // - // TODO(b/141011931): Validate a LinkEndpoint's link address (provided by - // LinkEndpoint.LinkAddress) before reaching this point. - linkAddr := ndp.nic.linkEP.LinkAddress() - if !header.IsValidUnicastEthernetAddress(linkAddr) { - return false - } - - // Generate an address within prefix from the modified EUI-64 of ndp's - // NIC's Ethernet MAC address. - header.EthernetAdddressToModifiedEUI64IntoBuf(linkAddr, addrBytes[header.IIDOffsetInIPv6Address:]) - } else { - // We have no way to regenerate an address in response to an address - // conflict when addresses are not generated with opaque IIDs. - return false - } - - generatedAddr = tcpip.AddressWithPrefix{ - Address: tcpip.Address(addrBytes), - PrefixLen: validPrefixLenForAutoGen, - } - - if !ndp.nic.hasPermanentAddrLocked(generatedAddr.Address) { - break - } - - state.stableAddr.localGenerationFailures++ - } - - if ref := ndp.addSLAACAddr(generatedAddr, slaac, time.Since(state.preferredUntil) >= 0 /* deprecated */); ref != nil { - state.stableAddr.ref = ref - state.generationAttempts++ - return true - } - - return false -} - -// regenerateSLAACAddr regenerates an address for a SLAAC prefix. -// -// If generating a new address for the prefix fails, the prefix will be -// invalidated. -// -// The NIC that ndp belongs to MUST be locked. -func (ndp *ndpState) regenerateSLAACAddr(prefix tcpip.Subnet) { - state, ok := ndp.slaacPrefixes[prefix] - if !ok { - panic(fmt.Sprintf("ndp: SLAAC prefix state not found to regenerate address for %s", prefix)) - } - - if ndp.generateSLAACAddr(prefix, &state) { - ndp.slaacPrefixes[prefix] = state - return - } - - // We were unable to generate a permanent address for the SLAAC prefix so - // invalidate the prefix as there is no reason to maintain state for a - // SLAAC prefix we do not have an address for. - ndp.invalidateSLAACPrefix(prefix, state) -} - -// generateTempSLAACAddr generates a new temporary SLAAC address. -// -// If resetGenAttempts is true, the prefix's generation counter will be reset. -// -// Returns true if a new address was generated. -func (ndp *ndpState) generateTempSLAACAddr(prefix tcpip.Subnet, prefixState *slaacPrefixState, resetGenAttempts bool) bool { - // Are we configured to auto-generate new temporary global addresses for the - // prefix? - if !ndp.configs.AutoGenTempGlobalAddresses || prefix == header.IPv6LinkLocalPrefix.Subnet() { - return false - } - - if resetGenAttempts { - prefixState.generationAttempts = 0 - prefixState.maxGenerationAttempts = ndp.configs.AutoGenAddressConflictRetries + 1 - } - - // If we have already reached the maximum address generation attempts for the - // prefix, do not generate another address. - if prefixState.generationAttempts == prefixState.maxGenerationAttempts { - return false - } - - stableAddr := prefixState.stableAddr.ref.ep.ID().LocalAddress - now := time.Now() - - // As per RFC 4941 section 3.3 step 4, the valid lifetime of a temporary - // address is the lower of the valid lifetime of the stable address or the - // maximum temporary address valid lifetime. - vl := ndp.configs.MaxTempAddrValidLifetime - if prefixState.validUntil != (time.Time{}) { - if prefixVL := prefixState.validUntil.Sub(now); vl > prefixVL { - vl = prefixVL - } - } - - if vl <= 0 { - // Cannot create an address without a valid lifetime. - return false - } - - // As per RFC 4941 section 3.3 step 4, the preferred lifetime of a temporary - // address is the lower of the preferred lifetime of the stable address or the - // maximum temporary address preferred lifetime - the temporary address desync - // factor. - pl := ndp.configs.MaxTempAddrPreferredLifetime - ndp.temporaryAddressDesyncFactor - if prefixState.preferredUntil != (time.Time{}) { - if prefixPL := prefixState.preferredUntil.Sub(now); pl > prefixPL { - // Respect the preferred lifetime of the prefix, as per RFC 4941 section - // 3.3 step 4. - pl = prefixPL - } - } - - // As per RFC 4941 section 3.3 step 5, a temporary address is created only if - // the calculated preferred lifetime is greater than the advance regeneration - // duration. In particular, we MUST NOT create a temporary address with a zero - // Preferred Lifetime. - if pl <= ndp.configs.RegenAdvanceDuration { - return false - } - - // Attempt to generate a new address that is not already assigned to the NIC. - var generatedAddr tcpip.AddressWithPrefix - for i := 0; ; i++ { - // If we were unable to generate an address after the maximum SLAAC address - // local regeneration attempts, do nothing further. - if i == maxSLAACAddrLocalRegenAttempts { - return false - } - - generatedAddr = header.GenerateTempIPv6SLAACAddr(ndp.temporaryIIDHistory[:], stableAddr) - if !ndp.nic.hasPermanentAddrLocked(generatedAddr.Address) { - break - } - } - - // As per RFC RFC 4941 section 3.3 step 5, we MUST NOT create a temporary - // address with a zero preferred lifetime. The checks above ensure this - // so we know the address is not deprecated. - ref := ndp.addSLAACAddr(generatedAddr, slaacTemp, false /* deprecated */) - if ref == nil { - return false - } - - state := tempSLAACAddrState{ - deprecationJob: ndp.nic.stack.newJob(&ndp.nic.mu, func() { - prefixState, ok := ndp.slaacPrefixes[prefix] - if !ok { - panic(fmt.Sprintf("ndp: must have a slaacPrefixes entry for %s to deprecate temporary address %s", prefix, generatedAddr)) - } - - tempAddrState, ok := prefixState.tempAddrs[generatedAddr.Address] - if !ok { - panic(fmt.Sprintf("ndp: must have a tempAddr entry to deprecate temporary address %s", generatedAddr)) - } - - ndp.deprecateSLAACAddress(tempAddrState.ref) - }), - invalidationJob: ndp.nic.stack.newJob(&ndp.nic.mu, func() { - prefixState, ok := ndp.slaacPrefixes[prefix] - if !ok { - panic(fmt.Sprintf("ndp: must have a slaacPrefixes entry for %s to invalidate temporary address %s", prefix, generatedAddr)) - } - - tempAddrState, ok := prefixState.tempAddrs[generatedAddr.Address] - if !ok { - panic(fmt.Sprintf("ndp: must have a tempAddr entry to invalidate temporary address %s", generatedAddr)) - } - - ndp.invalidateTempSLAACAddr(prefixState.tempAddrs, generatedAddr.Address, tempAddrState) - }), - regenJob: ndp.nic.stack.newJob(&ndp.nic.mu, func() { - prefixState, ok := ndp.slaacPrefixes[prefix] - if !ok { - panic(fmt.Sprintf("ndp: must have a slaacPrefixes entry for %s to regenerate temporary address after %s", prefix, generatedAddr)) - } - - tempAddrState, ok := prefixState.tempAddrs[generatedAddr.Address] - if !ok { - panic(fmt.Sprintf("ndp: must have a tempAddr entry to regenerate temporary address after %s", generatedAddr)) - } - - // If an address has already been regenerated for this address, don't - // regenerate another address. - if tempAddrState.regenerated { - return - } - - // Reset the generation attempts counter as we are starting the generation - // of a new address for the SLAAC prefix. - tempAddrState.regenerated = ndp.generateTempSLAACAddr(prefix, &prefixState, true /* resetGenAttempts */) - prefixState.tempAddrs[generatedAddr.Address] = tempAddrState - ndp.slaacPrefixes[prefix] = prefixState - }), - createdAt: now, - ref: ref, - } - - state.deprecationJob.Schedule(pl) - state.invalidationJob.Schedule(vl) - state.regenJob.Schedule(pl - ndp.configs.RegenAdvanceDuration) - - prefixState.generationAttempts++ - prefixState.tempAddrs[generatedAddr.Address] = state - - return true -} - -// regenerateTempSLAACAddr regenerates a temporary address for a SLAAC prefix. -// -// The NIC that ndp belongs to MUST be locked. -func (ndp *ndpState) regenerateTempSLAACAddr(prefix tcpip.Subnet, resetGenAttempts bool) { - state, ok := ndp.slaacPrefixes[prefix] - if !ok { - panic(fmt.Sprintf("ndp: SLAAC prefix state not found to regenerate temporary address for %s", prefix)) - } - - ndp.generateTempSLAACAddr(prefix, &state, resetGenAttempts) - ndp.slaacPrefixes[prefix] = state -} - -// refreshSLAACPrefixLifetimes refreshes the lifetimes of a SLAAC prefix. -// -// pl is the new preferred lifetime. vl is the new valid lifetime. -// -// The NIC that ndp belongs to MUST be locked. -func (ndp *ndpState) refreshSLAACPrefixLifetimes(prefix tcpip.Subnet, prefixState *slaacPrefixState, pl, vl time.Duration) { - // If the preferred lifetime is zero, then the prefix should be deprecated. - deprecated := pl == 0 - if deprecated { - ndp.deprecateSLAACAddress(prefixState.stableAddr.ref) - } else { - prefixState.stableAddr.ref.deprecated = false - } - - // If prefix was preferred for some finite lifetime before, cancel the - // deprecation job so it can be reset. - prefixState.deprecationJob.Cancel() - - now := time.Now() - - // Schedule the deprecation job if prefix has a finite preferred lifetime. - if pl < header.NDPInfiniteLifetime { - if !deprecated { - prefixState.deprecationJob.Schedule(pl) - } - prefixState.preferredUntil = now.Add(pl) - } else { - prefixState.preferredUntil = time.Time{} - } - - // As per RFC 4862 section 5.5.3.e, update the valid lifetime for prefix: - // - // 1) If the received Valid Lifetime is greater than 2 hours or greater than - // RemainingLifetime, set the valid lifetime of the prefix to the - // advertised Valid Lifetime. - // - // 2) If RemainingLifetime is less than or equal to 2 hours, ignore the - // advertised Valid Lifetime. - // - // 3) Otherwise, reset the valid lifetime of the prefix to 2 hours. - - if vl >= header.NDPInfiniteLifetime { - // Handle the infinite valid lifetime separately as we do not schedule a - // job in this case. - prefixState.invalidationJob.Cancel() - prefixState.validUntil = time.Time{} - } else { - var effectiveVl time.Duration - var rl time.Duration - - // If the prefix was originally set to be valid forever, assume the - // remaining time to be the maximum possible value. - if prefixState.validUntil == (time.Time{}) { - rl = header.NDPInfiniteLifetime - } else { - rl = time.Until(prefixState.validUntil) - } - - if vl > MinPrefixInformationValidLifetimeForUpdate || vl > rl { - effectiveVl = vl - } else if rl > MinPrefixInformationValidLifetimeForUpdate { - effectiveVl = MinPrefixInformationValidLifetimeForUpdate - } - - if effectiveVl != 0 { - prefixState.invalidationJob.Cancel() - prefixState.invalidationJob.Schedule(effectiveVl) - prefixState.validUntil = now.Add(effectiveVl) - } - } - - // If DAD is not yet complete on the stable address, there is no need to do - // work with temporary addresses. - if prefixState.stableAddr.ref.getKind() != permanent { - return - } - - // Note, we do not need to update the entries in the temporary address map - // after updating the jobs because the jobs are held as pointers. - var regenForAddr tcpip.Address - allAddressesRegenerated := true - for tempAddr, tempAddrState := range prefixState.tempAddrs { - // As per RFC 4941 section 3.3 step 4, the valid lifetime of a temporary - // address is the lower of the valid lifetime of the stable address or the - // maximum temporary address valid lifetime. Note, the valid lifetime of a - // temporary address is relative to the address's creation time. - validUntil := tempAddrState.createdAt.Add(ndp.configs.MaxTempAddrValidLifetime) - if prefixState.validUntil != (time.Time{}) && validUntil.Sub(prefixState.validUntil) > 0 { - validUntil = prefixState.validUntil - } - - // If the address is no longer valid, invalidate it immediately. Otherwise, - // reset the invalidation job. - newValidLifetime := validUntil.Sub(now) - if newValidLifetime <= 0 { - ndp.invalidateTempSLAACAddr(prefixState.tempAddrs, tempAddr, tempAddrState) - continue - } - tempAddrState.invalidationJob.Cancel() - tempAddrState.invalidationJob.Schedule(newValidLifetime) - - // As per RFC 4941 section 3.3 step 4, the preferred lifetime of a temporary - // address is the lower of the preferred lifetime of the stable address or - // the maximum temporary address preferred lifetime - the temporary address - // desync factor. Note, the preferred lifetime of a temporary address is - // relative to the address's creation time. - preferredUntil := tempAddrState.createdAt.Add(ndp.configs.MaxTempAddrPreferredLifetime - ndp.temporaryAddressDesyncFactor) - if prefixState.preferredUntil != (time.Time{}) && preferredUntil.Sub(prefixState.preferredUntil) > 0 { - preferredUntil = prefixState.preferredUntil - } - - // If the address is no longer preferred, deprecate it immediately. - // Otherwise, schedule the deprecation job again. - newPreferredLifetime := preferredUntil.Sub(now) - tempAddrState.deprecationJob.Cancel() - if newPreferredLifetime <= 0 { - ndp.deprecateSLAACAddress(tempAddrState.ref) - } else { - tempAddrState.ref.deprecated = false - tempAddrState.deprecationJob.Schedule(newPreferredLifetime) - } - - tempAddrState.regenJob.Cancel() - if tempAddrState.regenerated { - } else { - allAddressesRegenerated = false - - if newPreferredLifetime <= ndp.configs.RegenAdvanceDuration { - // The new preferred lifetime is less than the advance regeneration - // duration so regenerate an address for this temporary address - // immediately after we finish iterating over the temporary addresses. - regenForAddr = tempAddr - } else { - tempAddrState.regenJob.Schedule(newPreferredLifetime - ndp.configs.RegenAdvanceDuration) - } - } - } - - // Generate a new temporary address if all of the existing temporary addresses - // have been regenerated, or we need to immediately regenerate an address - // due to an update in preferred lifetime. - // - // If each temporay address has already been regenerated, no new temporary - // address will be generated. To ensure continuation of temporary SLAAC - // addresses, we manually try to regenerate an address here. - if len(regenForAddr) != 0 || allAddressesRegenerated { - // Reset the generation attempts counter as we are starting the generation - // of a new address for the SLAAC prefix. - if state, ok := prefixState.tempAddrs[regenForAddr]; ndp.generateTempSLAACAddr(prefix, prefixState, true /* resetGenAttempts */) && ok { - state.regenerated = true - prefixState.tempAddrs[regenForAddr] = state - } - } -} - -// deprecateSLAACAddress marks ref as deprecated and notifies the stack's NDP -// dispatcher that ref has been deprecated. -// -// deprecateSLAACAddress does nothing if ref is already deprecated. -// -// The NIC that ndp belongs to MUST be locked. -func (ndp *ndpState) deprecateSLAACAddress(ref *referencedNetworkEndpoint) { - if ref.deprecated { - return - } - - ref.deprecated = true - if ndpDisp := ndp.nic.stack.ndpDisp; ndpDisp != nil { - ndpDisp.OnAutoGenAddressDeprecated(ndp.nic.ID(), ref.addrWithPrefix()) - } -} - -// invalidateSLAACPrefix invalidates a SLAAC prefix. -// -// The NIC that ndp belongs to MUST be locked. -func (ndp *ndpState) invalidateSLAACPrefix(prefix tcpip.Subnet, state slaacPrefixState) { - if r := state.stableAddr.ref; r != nil { - // Since we are already invalidating the prefix, do not invalidate the - // prefix when removing the address. - if err := ndp.nic.removePermanentIPv6EndpointLocked(r, false /* allowSLAACInvalidation */); err != nil { - panic(fmt.Sprintf("ndp: error removing stable SLAAC address %s: %s", r.addrWithPrefix(), err)) - } - } - - ndp.cleanupSLAACPrefixResources(prefix, state) -} - -// cleanupSLAACAddrResourcesAndNotify cleans up an invalidated SLAAC address's -// resources. -// -// The NIC that ndp belongs to MUST be locked. -func (ndp *ndpState) cleanupSLAACAddrResourcesAndNotify(addr tcpip.AddressWithPrefix, invalidatePrefix bool) { - if ndpDisp := ndp.nic.stack.ndpDisp; ndpDisp != nil { - ndpDisp.OnAutoGenAddressInvalidated(ndp.nic.ID(), addr) - } - - prefix := addr.Subnet() - state, ok := ndp.slaacPrefixes[prefix] - if !ok || state.stableAddr.ref == nil || addr.Address != state.stableAddr.ref.ep.ID().LocalAddress { - return - } - - if !invalidatePrefix { - // If the prefix is not being invalidated, disassociate the address from the - // prefix and do nothing further. - state.stableAddr.ref = nil - ndp.slaacPrefixes[prefix] = state - return - } - - ndp.cleanupSLAACPrefixResources(prefix, state) -} - -// cleanupSLAACPrefixResources cleans up a SLAAC prefix's jobs and entry. -// -// Panics if the SLAAC prefix is not known. -// -// The NIC that ndp belongs to MUST be locked. -func (ndp *ndpState) cleanupSLAACPrefixResources(prefix tcpip.Subnet, state slaacPrefixState) { - // Invalidate all temporary addresses. - for tempAddr, tempAddrState := range state.tempAddrs { - ndp.invalidateTempSLAACAddr(state.tempAddrs, tempAddr, tempAddrState) - } - - state.stableAddr.ref = nil - state.deprecationJob.Cancel() - state.invalidationJob.Cancel() - delete(ndp.slaacPrefixes, prefix) -} - -// invalidateTempSLAACAddr invalidates a temporary SLAAC address. -// -// The NIC that ndp belongs to MUST be locked. -func (ndp *ndpState) invalidateTempSLAACAddr(tempAddrs map[tcpip.Address]tempSLAACAddrState, tempAddr tcpip.Address, tempAddrState tempSLAACAddrState) { - // Since we are already invalidating the address, do not invalidate the - // address when removing the address. - if err := ndp.nic.removePermanentIPv6EndpointLocked(tempAddrState.ref, false /* allowSLAACInvalidation */); err != nil { - panic(fmt.Sprintf("error removing temporary SLAAC address %s: %s", tempAddrState.ref.addrWithPrefix(), err)) - } - - ndp.cleanupTempSLAACAddrResources(tempAddrs, tempAddr, tempAddrState) -} - -// cleanupTempSLAACAddrResourcesAndNotify cleans up an invalidated temporary -// SLAAC address's resources from ndp. -// -// The NIC that ndp belongs to MUST be locked. -func (ndp *ndpState) cleanupTempSLAACAddrResourcesAndNotify(addr tcpip.AddressWithPrefix, invalidateAddr bool) { - if ndpDisp := ndp.nic.stack.ndpDisp; ndpDisp != nil { - ndpDisp.OnAutoGenAddressInvalidated(ndp.nic.ID(), addr) - } - - if !invalidateAddr { - return - } - - prefix := addr.Subnet() - state, ok := ndp.slaacPrefixes[prefix] - if !ok { - panic(fmt.Sprintf("ndp: must have a slaacPrefixes entry to clean up temp addr %s resources", addr)) - } - - tempAddrState, ok := state.tempAddrs[addr.Address] - if !ok { - panic(fmt.Sprintf("ndp: must have a tempAddr entry to clean up temp addr %s resources", addr)) - } - - ndp.cleanupTempSLAACAddrResources(state.tempAddrs, addr.Address, tempAddrState) -} - -// cleanupTempSLAACAddrResourcesAndNotify cleans up a temporary SLAAC address's -// jobs and entry. -// -// The NIC that ndp belongs to MUST be locked. -func (ndp *ndpState) cleanupTempSLAACAddrResources(tempAddrs map[tcpip.Address]tempSLAACAddrState, tempAddr tcpip.Address, tempAddrState tempSLAACAddrState) { - tempAddrState.deprecationJob.Cancel() - tempAddrState.invalidationJob.Cancel() - tempAddrState.regenJob.Cancel() - delete(tempAddrs, tempAddr) -} - -// cleanupState cleans up ndp's state. -// -// If hostOnly is true, then only host-specific state will be cleaned up. -// -// cleanupState MUST be called with hostOnly set to true when ndp's NIC is -// transitioning from a host to a router. This function will invalidate all -// discovered on-link prefixes, discovered routers, and auto-generated -// addresses. -// -// If hostOnly is true, then the link-local auto-generated address will not be -// invalidated as routers are also expected to generate a link-local address. -// -// The NIC that ndp belongs to MUST be locked. -func (ndp *ndpState) cleanupState(hostOnly bool) { - linkLocalSubnet := header.IPv6LinkLocalPrefix.Subnet() - linkLocalPrefixes := 0 - for prefix, state := range ndp.slaacPrefixes { - // RFC 4862 section 5 states that routers are also expected to generate a - // link-local address so we do not invalidate them if we are cleaning up - // host-only state. - if hostOnly && prefix == linkLocalSubnet { - linkLocalPrefixes++ - continue - } - - ndp.invalidateSLAACPrefix(prefix, state) - } - - if got := len(ndp.slaacPrefixes); got != linkLocalPrefixes { - panic(fmt.Sprintf("ndp: still have non-linklocal SLAAC prefixes after cleaning up; found = %d prefixes, of which %d are link-local", got, linkLocalPrefixes)) - } - - for prefix := range ndp.onLinkPrefixes { - ndp.invalidateOnLinkPrefix(prefix) - } - - if got := len(ndp.onLinkPrefixes); got != 0 { - panic(fmt.Sprintf("ndp: still have discovered on-link prefixes after cleaning up; found = %d", got)) - } - - for router := range ndp.defaultRouters { - ndp.invalidateDefaultRouter(router) - } - - if got := len(ndp.defaultRouters); got != 0 { - panic(fmt.Sprintf("ndp: still have discovered default routers after cleaning up; found = %d", got)) - } - - ndp.dhcpv6Configuration = 0 -} - -// startSolicitingRouters starts soliciting routers, as per RFC 4861 section -// 6.3.7. If routers are already being solicited, this function does nothing. -// -// The NIC ndp belongs to MUST be locked. -func (ndp *ndpState) startSolicitingRouters() { - if ndp.rtrSolicit.timer != nil { - // We are already soliciting routers. - return - } - - remaining := ndp.configs.MaxRtrSolicitations - if remaining == 0 { - return - } - - // Calculate the random delay before sending our first RS, as per RFC - // 4861 section 6.3.7. - var delay time.Duration - if ndp.configs.MaxRtrSolicitationDelay > 0 { - delay = time.Duration(rand.Int63n(int64(ndp.configs.MaxRtrSolicitationDelay))) - } - - var done bool - ndp.rtrSolicit.done = &done - ndp.rtrSolicit.timer = ndp.nic.stack.Clock().AfterFunc(delay, func() { - ndp.nic.mu.Lock() - if done { - // If we reach this point, it means that the RS timer fired after another - // goroutine already obtained the NIC lock and stopped solicitations. - // Simply return here and do nothing further. - ndp.nic.mu.Unlock() - return - } - - // As per RFC 4861 section 4.1, the source of the RS is an address assigned - // to the sending interface, or the unspecified address if no address is - // assigned to the sending interface. - ref := ndp.nic.primaryIPv6EndpointRLocked(header.IPv6AllRoutersMulticastAddress) - if ref == nil { - ref = ndp.nic.getRefOrCreateTempLocked(header.IPv6ProtocolNumber, header.IPv6Any, NeverPrimaryEndpoint) - } - ndp.nic.mu.Unlock() - - localAddr := ref.ep.ID().LocalAddress - r := makeRoute(header.IPv6ProtocolNumber, localAddr, header.IPv6AllRoutersMulticastAddress, ndp.nic.linkEP.LinkAddress(), ref, false, false) - defer r.Release() - - // Route should resolve immediately since - // header.IPv6AllRoutersMulticastAddress is a multicast address so a - // remote link address can be calculated without a resolution process. - if c, err := r.Resolve(nil); err != nil { - // Do not consider the NIC being unknown or disabled as a fatal error. - // Since this method is required to be called when the NIC is not locked, - // the NIC could have been disabled or removed by another goroutine. - if err == tcpip.ErrUnknownNICID || err == tcpip.ErrInvalidEndpointState { - return - } - - panic(fmt.Sprintf("ndp: error when resolving route to send NDP RS (%s -> %s on NIC(%d)): %s", header.IPv6Any, header.IPv6AllRoutersMulticastAddress, ndp.nic.ID(), err)) - } else if c != nil { - panic(fmt.Sprintf("ndp: route resolution not immediate for route to send NDP RS (%s -> %s on NIC(%d))", header.IPv6Any, header.IPv6AllRoutersMulticastAddress, ndp.nic.ID())) - } - - // As per RFC 4861 section 4.1, an NDP RS SHOULD include the source - // link-layer address option if the source address of the NDP RS is - // specified. This option MUST NOT be included if the source address is - // unspecified. - // - // TODO(b/141011931): Validate a LinkEndpoint's link address (provided by - // LinkEndpoint.LinkAddress) before reaching this point. - var optsSerializer header.NDPOptionsSerializer - if localAddr != header.IPv6Any && header.IsValidUnicastEthernetAddress(r.LocalLinkAddress) { - optsSerializer = header.NDPOptionsSerializer{ - header.NDPSourceLinkLayerAddressOption(r.LocalLinkAddress), - } - } - payloadSize := header.ICMPv6HeaderSize + header.NDPRSMinimumSize + int(optsSerializer.Length()) - hdr := buffer.NewPrependable(int(r.MaxHeaderLength()) + payloadSize) - pkt := header.ICMPv6(hdr.Prepend(payloadSize)) - pkt.SetType(header.ICMPv6RouterSolicit) - rs := header.NDPRouterSolicit(pkt.NDPPayload()) - rs.Options().Serialize(optsSerializer) - pkt.SetChecksum(header.ICMPv6Checksum(pkt, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{})) - - sent := r.Stats().ICMP.V6PacketsSent - if err := r.WritePacket(nil, - NetworkHeaderParams{ - Protocol: header.ICMPv6ProtocolNumber, - TTL: header.NDPHopLimit, - TOS: DefaultTOS, - }, &PacketBuffer{Header: hdr}, - ); err != nil { - sent.Dropped.Increment() - log.Printf("startSolicitingRouters: error writing NDP router solicit message on NIC(%d); err = %s", ndp.nic.ID(), err) - // Don't send any more messages if we had an error. - remaining = 0 - } else { - sent.RouterSolicit.Increment() - remaining-- - } - - ndp.nic.mu.Lock() - if done || remaining == 0 { - ndp.rtrSolicit.timer = nil - ndp.rtrSolicit.done = nil - } else if ndp.rtrSolicit.timer != nil { - // Note, we need to explicitly check to make sure that - // the timer field is not nil because if it was nil but - // we still reached this point, then we know the NIC - // was requested to stop soliciting routers so we don't - // need to send the next Router Solicitation message. - ndp.rtrSolicit.timer.Reset(ndp.configs.RtrSolicitationInterval) - } - ndp.nic.mu.Unlock() - }) - -} - -// stopSolicitingRouters stops soliciting routers. If routers are not currently -// being solicited, this function does nothing. -// -// The NIC ndp belongs to MUST be locked. -func (ndp *ndpState) stopSolicitingRouters() { - if ndp.rtrSolicit.timer == nil { - // Nothing to do. - return - } - - *ndp.rtrSolicit.done = true - ndp.rtrSolicit.timer.Stop() - ndp.rtrSolicit.timer = nil - ndp.rtrSolicit.done = nil -} - -// initializeTempAddrState initializes state related to temporary SLAAC -// addresses. -func (ndp *ndpState) initializeTempAddrState() { - header.InitialTempIID(ndp.temporaryIIDHistory[:], ndp.nic.stack.tempIIDSeed, ndp.nic.ID()) - - if MaxDesyncFactor != 0 { - ndp.temporaryAddressDesyncFactor = time.Duration(rand.Int63n(int64(MaxDesyncFactor))) - } -} diff --git a/pkg/tcpip/stack/ndp_test.go b/pkg/tcpip/stack/ndp_test.go index 644ba7c33..73a01c2dd 100644 --- a/pkg/tcpip/stack/ndp_test.go +++ b/pkg/tcpip/stack/ndp_test.go @@ -150,10 +150,10 @@ type ndpDNSSLEvent struct { type ndpDHCPv6Event struct { nicID tcpip.NICID - configuration stack.DHCPv6ConfigurationFromNDPRA + configuration ipv6.DHCPv6ConfigurationFromNDPRA } -var _ stack.NDPDispatcher = (*ndpDispatcher)(nil) +var _ ipv6.NDPDispatcher = (*ndpDispatcher)(nil) // ndpDispatcher implements NDPDispatcher so tests can know when various NDP // related events happen for test purposes. @@ -170,7 +170,7 @@ type ndpDispatcher struct { dhcpv6ConfigurationC chan ndpDHCPv6Event } -// Implements stack.NDPDispatcher.OnDuplicateAddressDetectionStatus. +// Implements ipv6.NDPDispatcher.OnDuplicateAddressDetectionStatus. func (n *ndpDispatcher) OnDuplicateAddressDetectionStatus(nicID tcpip.NICID, addr tcpip.Address, resolved bool, err *tcpip.Error) { if n.dadC != nil { n.dadC <- ndpDADEvent{ @@ -182,7 +182,7 @@ func (n *ndpDispatcher) OnDuplicateAddressDetectionStatus(nicID tcpip.NICID, add } } -// Implements stack.NDPDispatcher.OnDefaultRouterDiscovered. +// Implements ipv6.NDPDispatcher.OnDefaultRouterDiscovered. func (n *ndpDispatcher) OnDefaultRouterDiscovered(nicID tcpip.NICID, addr tcpip.Address) bool { if c := n.routerC; c != nil { c <- ndpRouterEvent{ @@ -195,7 +195,7 @@ func (n *ndpDispatcher) OnDefaultRouterDiscovered(nicID tcpip.NICID, addr tcpip. return n.rememberRouter } -// Implements stack.NDPDispatcher.OnDefaultRouterInvalidated. +// Implements ipv6.NDPDispatcher.OnDefaultRouterInvalidated. func (n *ndpDispatcher) OnDefaultRouterInvalidated(nicID tcpip.NICID, addr tcpip.Address) { if c := n.routerC; c != nil { c <- ndpRouterEvent{ @@ -206,7 +206,7 @@ func (n *ndpDispatcher) OnDefaultRouterInvalidated(nicID tcpip.NICID, addr tcpip } } -// Implements stack.NDPDispatcher.OnOnLinkPrefixDiscovered. +// Implements ipv6.NDPDispatcher.OnOnLinkPrefixDiscovered. func (n *ndpDispatcher) OnOnLinkPrefixDiscovered(nicID tcpip.NICID, prefix tcpip.Subnet) bool { if c := n.prefixC; c != nil { c <- ndpPrefixEvent{ @@ -219,7 +219,7 @@ func (n *ndpDispatcher) OnOnLinkPrefixDiscovered(nicID tcpip.NICID, prefix tcpip return n.rememberPrefix } -// Implements stack.NDPDispatcher.OnOnLinkPrefixInvalidated. +// Implements ipv6.NDPDispatcher.OnOnLinkPrefixInvalidated. func (n *ndpDispatcher) OnOnLinkPrefixInvalidated(nicID tcpip.NICID, prefix tcpip.Subnet) { if c := n.prefixC; c != nil { c <- ndpPrefixEvent{ @@ -261,7 +261,7 @@ func (n *ndpDispatcher) OnAutoGenAddressInvalidated(nicID tcpip.NICID, addr tcpi } } -// Implements stack.NDPDispatcher.OnRecursiveDNSServerOption. +// Implements ipv6.NDPDispatcher.OnRecursiveDNSServerOption. func (n *ndpDispatcher) OnRecursiveDNSServerOption(nicID tcpip.NICID, addrs []tcpip.Address, lifetime time.Duration) { if c := n.rdnssC; c != nil { c <- ndpRDNSSEvent{ @@ -274,7 +274,7 @@ func (n *ndpDispatcher) OnRecursiveDNSServerOption(nicID tcpip.NICID, addrs []tc } } -// Implements stack.NDPDispatcher.OnDNSSearchListOption. +// Implements ipv6.NDPDispatcher.OnDNSSearchListOption. func (n *ndpDispatcher) OnDNSSearchListOption(nicID tcpip.NICID, domainNames []string, lifetime time.Duration) { if n.dnsslC != nil { n.dnsslC <- ndpDNSSLEvent{ @@ -285,8 +285,8 @@ func (n *ndpDispatcher) OnDNSSearchListOption(nicID tcpip.NICID, domainNames []s } } -// Implements stack.NDPDispatcher.OnDHCPv6Configuration. -func (n *ndpDispatcher) OnDHCPv6Configuration(nicID tcpip.NICID, configuration stack.DHCPv6ConfigurationFromNDPRA) { +// Implements ipv6.NDPDispatcher.OnDHCPv6Configuration. +func (n *ndpDispatcher) OnDHCPv6Configuration(nicID tcpip.NICID, configuration ipv6.DHCPv6ConfigurationFromNDPRA) { if c := n.dhcpv6ConfigurationC; c != nil { c <- ndpDHCPv6Event{ nicID, @@ -319,13 +319,12 @@ func TestDADDisabled(t *testing.T) { ndpDisp := ndpDispatcher{ dadC: make(chan ndpDADEvent, 1), } - opts := stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, - NDPDisp: &ndpDisp, - } - e := channel.New(0, 1280, linkAddr1) - s := stack.New(opts) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ + NDPDisp: &ndpDisp, + })}, + }) if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) } @@ -413,19 +412,21 @@ func TestDADResolve(t *testing.T) { ndpDisp := ndpDispatcher{ dadC: make(chan ndpDADEvent), } - opts := stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, - NDPDisp: &ndpDisp, - } - opts.NDPConfigs.RetransmitTimer = test.retransTimer - opts.NDPConfigs.DupAddrDetectTransmits = test.dupAddrDetectTransmits e := channelLinkWithHeaderLength{ Endpoint: channel.New(int(test.dupAddrDetectTransmits), 1280, linkAddr1), headerLength: test.linkHeaderLen, } e.Endpoint.LinkEPCapabilities |= stack.CapabilityResolutionRequired - s := stack.New(opts) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ + NDPDisp: &ndpDisp, + NDPConfigs: ipv6.NDPConfigurations{ + RetransmitTimer: test.retransTimer, + DupAddrDetectTransmits: test.dupAddrDetectTransmits, + }, + })}, + }) if err := s.CreateNIC(nicID, &e); err != nil { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) } @@ -541,7 +542,7 @@ func TestDADResolve(t *testing.T) { // As per RFC 4861 section 4.3, a possible option is the Source Link // Layer option, but this option MUST NOT be included when the source // address of the packet is the unspecified address. - checker.IPv6(t, p.Pkt.Header.View(), + checker.IPv6(t, stack.PayloadSince(p.Pkt.NetworkHeader()), checker.SrcAddr(header.IPv6Any), checker.DstAddr(snmc), checker.TTL(header.NDPHopLimit), @@ -550,14 +551,34 @@ func TestDADResolve(t *testing.T) { checker.NDPNSOptions(nil), )) - if l, want := p.Pkt.Header.AvailableLength(), int(test.linkHeaderLen); l != want { - t.Errorf("got p.Pkt.Header.AvailableLength() = %d; want = %d", l, want) + if l, want := p.Pkt.AvailableHeaderBytes(), int(test.linkHeaderLen); l != want { + t.Errorf("got p.Pkt.AvailableHeaderBytes() = %d; want = %d", l, want) } } }) } } +func rxNDPSolicit(e *channel.Endpoint, tgt tcpip.Address) { + hdr := buffer.NewPrependable(header.IPv6MinimumSize + header.ICMPv6NeighborSolicitMinimumSize) + pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6NeighborSolicitMinimumSize)) + pkt.SetType(header.ICMPv6NeighborSolicit) + ns := header.NDPNeighborSolicit(pkt.NDPPayload()) + ns.SetTargetAddress(tgt) + snmc := header.SolicitedNodeAddr(tgt) + pkt.SetChecksum(header.ICMPv6Checksum(pkt, header.IPv6Any, snmc, buffer.VectorisedView{})) + payloadLength := hdr.UsedLength() + ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) + ip.Encode(&header.IPv6Fields{ + PayloadLength: uint16(payloadLength), + NextHeader: uint8(icmp.ProtocolNumber6), + HopLimit: 255, + SrcAddr: header.IPv6Any, + DstAddr: snmc, + }) + e.InjectInbound(header.IPv6ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{Data: hdr.View().ToVectorisedView()})) +} + // TestDADFail tests to make sure that the DAD process fails if another node is // detected to be performing DAD on the same address (receive an NS message from // a node doing DAD for the same address), or if another node is detected to own @@ -567,39 +588,19 @@ func TestDADFail(t *testing.T) { tests := []struct { name string - makeBuf func(tgt tcpip.Address) buffer.Prependable + rxPkt func(e *channel.Endpoint, tgt tcpip.Address) getStat func(s tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter }{ { - "RxSolicit", - func(tgt tcpip.Address) buffer.Prependable { - hdr := buffer.NewPrependable(header.IPv6MinimumSize + header.ICMPv6NeighborSolicitMinimumSize) - pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6NeighborSolicitMinimumSize)) - pkt.SetType(header.ICMPv6NeighborSolicit) - ns := header.NDPNeighborSolicit(pkt.NDPPayload()) - ns.SetTargetAddress(tgt) - snmc := header.SolicitedNodeAddr(tgt) - pkt.SetChecksum(header.ICMPv6Checksum(pkt, header.IPv6Any, snmc, buffer.VectorisedView{})) - payloadLength := hdr.UsedLength() - ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) - ip.Encode(&header.IPv6Fields{ - PayloadLength: uint16(payloadLength), - NextHeader: uint8(icmp.ProtocolNumber6), - HopLimit: 255, - SrcAddr: header.IPv6Any, - DstAddr: snmc, - }) - - return hdr - - }, - func(s tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { + name: "RxSolicit", + rxPkt: rxNDPSolicit, + getStat: func(s tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { return s.NeighborSolicit }, }, { - "RxAdvert", - func(tgt tcpip.Address) buffer.Prependable { + name: "RxAdvert", + rxPkt: func(e *channel.Endpoint, tgt tcpip.Address) { naSize := header.ICMPv6NeighborAdvertMinimumSize + header.NDPLinkLayerAddressSize hdr := buffer.NewPrependable(header.IPv6MinimumSize + naSize) pkt := header.ICMPv6(hdr.Prepend(naSize)) @@ -621,11 +622,9 @@ func TestDADFail(t *testing.T) { SrcAddr: tgt, DstAddr: header.IPv6AllNodesMulticastAddress, }) - - return hdr - + e.InjectInbound(header.IPv6ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{Data: hdr.View().ToVectorisedView()})) }, - func(s tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { + getStat: func(s tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { return s.NeighborAdvert }, }, @@ -636,16 +635,16 @@ func TestDADFail(t *testing.T) { ndpDisp := ndpDispatcher{ dadC: make(chan ndpDADEvent, 1), } - ndpConfigs := stack.DefaultNDPConfigurations() - opts := stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, - NDPConfigs: ndpConfigs, - NDPDisp: &ndpDisp, - } - opts.NDPConfigs.RetransmitTimer = time.Second * 2 + ndpConfigs := ipv6.DefaultNDPConfigurations() + ndpConfigs.RetransmitTimer = time.Second * 2 e := channel.New(0, 1280, linkAddr1) - s := stack.New(opts) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ + NDPDisp: &ndpDisp, + NDPConfigs: ndpConfigs, + })}, + }) if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) } @@ -664,12 +663,8 @@ func TestDADFail(t *testing.T) { t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (%s, nil), want = (%s, nil)", nicID, header.IPv6ProtocolNumber, addr, want) } - // Receive a packet to simulate multiple nodes owning or - // attempting to own the same address. - hdr := test.makeBuf(addr1) - e.InjectInbound(header.IPv6ProtocolNumber, &stack.PacketBuffer{ - Data: hdr.View().ToVectorisedView(), - }) + // Receive a packet to simulate an address conflict. + test.rxPkt(e, addr1) stat := test.getStat(s.Stats().ICMP.V6PacketsReceived) if got := stat.Value(); got != 1 { @@ -753,18 +748,19 @@ func TestDADStop(t *testing.T) { ndpDisp := ndpDispatcher{ dadC: make(chan ndpDADEvent, 1), } - ndpConfigs := stack.NDPConfigurations{ + + ndpConfigs := ipv6.NDPConfigurations{ RetransmitTimer: time.Second, DupAddrDetectTransmits: 2, } - opts := stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, - NDPDisp: &ndpDisp, - NDPConfigs: ndpConfigs, - } e := channel.New(0, 1280, linkAddr1) - s := stack.New(opts) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ + NDPDisp: &ndpDisp, + NDPConfigs: ndpConfigs, + })}, + }) if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("CreateNIC(%d, _): %s", nicID, err) } @@ -814,19 +810,6 @@ func TestDADStop(t *testing.T) { } } -// TestSetNDPConfigurationFailsForBadNICID tests to make sure we get an error if -// we attempt to update NDP configurations using an invalid NICID. -func TestSetNDPConfigurationFailsForBadNICID(t *testing.T) { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, - }) - - // No NIC with ID 1 yet. - if got := s.SetNDPConfigurations(1, stack.NDPConfigurations{}); got != tcpip.ErrUnknownNICID { - t.Fatalf("got s.SetNDPConfigurations = %v, want = %s", got, tcpip.ErrUnknownNICID) - } -} - // TestSetNDPConfigurations tests that we can update and use per-interface NDP // configurations without affecting the default NDP configurations or other // interfaces' configurations. @@ -862,8 +845,9 @@ func TestSetNDPConfigurations(t *testing.T) { } e := channel.New(0, 1280, linkAddr1) s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, - NDPDisp: &ndpDisp, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ + NDPDisp: &ndpDisp, + })}, }) expectDADEvent := func(nicID tcpip.NICID, addr tcpip.Address) { @@ -891,12 +875,15 @@ func TestSetNDPConfigurations(t *testing.T) { } // Update the NDP configurations on NIC(1) to use DAD. - configs := stack.NDPConfigurations{ + configs := ipv6.NDPConfigurations{ DupAddrDetectTransmits: test.dupAddrDetectTransmits, RetransmitTimer: test.retransmitTimer, } - if err := s.SetNDPConfigurations(nicID1, configs); err != nil { - t.Fatalf("got SetNDPConfigurations(%d, _) = %s", nicID1, err) + if ipv6Ep, err := s.GetNetworkEndpoint(nicID1, header.IPv6ProtocolNumber); err != nil { + t.Fatalf("s.GetNetworkEndpoint(%d, %d): %s", nicID1, header.IPv6ProtocolNumber, err) + } else { + ndpEP := ipv6Ep.(ipv6.NDPEndpoint) + ndpEP.SetNDPConfigurations(configs) } // Created after updating NIC(1)'s NDP configurations @@ -1024,7 +1011,9 @@ func raBufWithOptsAndDHCPv6(ip tcpip.Address, rl uint16, managedAddress, otherCo DstAddr: header.IPv6AllNodesMulticastAddress, }) - return &stack.PacketBuffer{Data: hdr.View().ToVectorisedView()} + return stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: hdr.View().ToVectorisedView(), + }) } // raBufWithOpts returns a valid NDP Router Advertisement with options. @@ -1110,14 +1099,15 @@ func TestNoRouterDiscovery(t *testing.T) { } e := channel.New(0, 1280, linkAddr1) s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, - NDPConfigs: stack.NDPConfigurations{ - HandleRAs: handle, - DiscoverDefaultRouters: discover, - }, - NDPDisp: &ndpDisp, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ + NDPConfigs: ipv6.NDPConfigurations{ + HandleRAs: handle, + DiscoverDefaultRouters: discover, + }, + NDPDisp: &ndpDisp, + })}, }) - s.SetForwarding(forwarding) + s.SetForwarding(ipv6.ProtocolNumber, forwarding) if err := s.CreateNIC(1, e); err != nil { t.Fatalf("CreateNIC(1) = %s", err) @@ -1148,12 +1138,13 @@ func TestRouterDiscoveryDispatcherNoRemember(t *testing.T) { } e := channel.New(0, 1280, linkAddr1) s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, - NDPConfigs: stack.NDPConfigurations{ - HandleRAs: true, - DiscoverDefaultRouters: true, - }, - NDPDisp: &ndpDisp, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ + NDPConfigs: ipv6.NDPConfigurations{ + HandleRAs: true, + DiscoverDefaultRouters: true, + }, + NDPDisp: &ndpDisp, + })}, }) if err := s.CreateNIC(1, e); err != nil { @@ -1189,12 +1180,13 @@ func TestRouterDiscovery(t *testing.T) { } e := channel.New(0, 1280, linkAddr1) s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, - NDPConfigs: stack.NDPConfigurations{ - HandleRAs: true, - DiscoverDefaultRouters: true, - }, - NDPDisp: &ndpDisp, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ + NDPConfigs: ipv6.NDPConfigurations{ + HandleRAs: true, + DiscoverDefaultRouters: true, + }, + NDPDisp: &ndpDisp, + })}, }) expectRouterEvent := func(addr tcpip.Address, discovered bool) { @@ -1282,7 +1274,7 @@ func TestRouterDiscovery(t *testing.T) { } // TestRouterDiscoveryMaxRouters tests that only -// stack.MaxDiscoveredDefaultRouters discovered routers are remembered. +// ipv6.MaxDiscoveredDefaultRouters discovered routers are remembered. func TestRouterDiscoveryMaxRouters(t *testing.T) { ndpDisp := ndpDispatcher{ routerC: make(chan ndpRouterEvent, 1), @@ -1290,12 +1282,13 @@ func TestRouterDiscoveryMaxRouters(t *testing.T) { } e := channel.New(0, 1280, linkAddr1) s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, - NDPConfigs: stack.NDPConfigurations{ - HandleRAs: true, - DiscoverDefaultRouters: true, - }, - NDPDisp: &ndpDisp, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ + NDPConfigs: ipv6.NDPConfigurations{ + HandleRAs: true, + DiscoverDefaultRouters: true, + }, + NDPDisp: &ndpDisp, + })}, }) if err := s.CreateNIC(1, e); err != nil { @@ -1303,14 +1296,14 @@ func TestRouterDiscoveryMaxRouters(t *testing.T) { } // Receive an RA from 2 more than the max number of discovered routers. - for i := 1; i <= stack.MaxDiscoveredDefaultRouters+2; i++ { + for i := 1; i <= ipv6.MaxDiscoveredDefaultRouters+2; i++ { linkAddr := []byte{2, 2, 3, 4, 5, 0} linkAddr[5] = byte(i) llAddr := header.LinkLocalAddr(tcpip.LinkAddress(linkAddr)) e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr, 5)) - if i <= stack.MaxDiscoveredDefaultRouters { + if i <= ipv6.MaxDiscoveredDefaultRouters { select { case e := <-ndpDisp.routerC: if diff := checkRouterEvent(e, llAddr, true); diff != "" { @@ -1355,14 +1348,15 @@ func TestNoPrefixDiscovery(t *testing.T) { } e := channel.New(0, 1280, linkAddr1) s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, - NDPConfigs: stack.NDPConfigurations{ - HandleRAs: handle, - DiscoverOnLinkPrefixes: discover, - }, - NDPDisp: &ndpDisp, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ + NDPConfigs: ipv6.NDPConfigurations{ + HandleRAs: handle, + DiscoverOnLinkPrefixes: discover, + }, + NDPDisp: &ndpDisp, + })}, }) - s.SetForwarding(forwarding) + s.SetForwarding(ipv6.ProtocolNumber, forwarding) if err := s.CreateNIC(1, e); err != nil { t.Fatalf("CreateNIC(1) = %s", err) @@ -1396,13 +1390,14 @@ func TestPrefixDiscoveryDispatcherNoRemember(t *testing.T) { } e := channel.New(0, 1280, linkAddr1) s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, - NDPConfigs: stack.NDPConfigurations{ - HandleRAs: true, - DiscoverDefaultRouters: false, - DiscoverOnLinkPrefixes: true, - }, - NDPDisp: &ndpDisp, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ + NDPConfigs: ipv6.NDPConfigurations{ + HandleRAs: true, + DiscoverDefaultRouters: false, + DiscoverOnLinkPrefixes: true, + }, + NDPDisp: &ndpDisp, + })}, }) if err := s.CreateNIC(1, e); err != nil { @@ -1442,12 +1437,13 @@ func TestPrefixDiscovery(t *testing.T) { } e := channel.New(0, 1280, linkAddr1) s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, - NDPConfigs: stack.NDPConfigurations{ - HandleRAs: true, - DiscoverOnLinkPrefixes: true, - }, - NDPDisp: &ndpDisp, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ + NDPConfigs: ipv6.NDPConfigurations{ + HandleRAs: true, + DiscoverOnLinkPrefixes: true, + }, + NDPDisp: &ndpDisp, + })}, }) if err := s.CreateNIC(1, e); err != nil { @@ -1542,12 +1538,13 @@ func TestPrefixDiscoveryWithInfiniteLifetime(t *testing.T) { } e := channel.New(0, 1280, linkAddr1) s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, - NDPConfigs: stack.NDPConfigurations{ - HandleRAs: true, - DiscoverOnLinkPrefixes: true, - }, - NDPDisp: &ndpDisp, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ + NDPConfigs: ipv6.NDPConfigurations{ + HandleRAs: true, + DiscoverOnLinkPrefixes: true, + }, + NDPDisp: &ndpDisp, + })}, }) if err := s.CreateNIC(1, e); err != nil { @@ -1618,33 +1615,34 @@ func TestPrefixDiscoveryWithInfiniteLifetime(t *testing.T) { } // TestPrefixDiscoveryMaxRouters tests that only -// stack.MaxDiscoveredOnLinkPrefixes discovered on-link prefixes are remembered. +// ipv6.MaxDiscoveredOnLinkPrefixes discovered on-link prefixes are remembered. func TestPrefixDiscoveryMaxOnLinkPrefixes(t *testing.T) { ndpDisp := ndpDispatcher{ - prefixC: make(chan ndpPrefixEvent, stack.MaxDiscoveredOnLinkPrefixes+3), + prefixC: make(chan ndpPrefixEvent, ipv6.MaxDiscoveredOnLinkPrefixes+3), rememberPrefix: true, } e := channel.New(0, 1280, linkAddr1) s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, - NDPConfigs: stack.NDPConfigurations{ - HandleRAs: true, - DiscoverDefaultRouters: false, - DiscoverOnLinkPrefixes: true, - }, - NDPDisp: &ndpDisp, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ + NDPConfigs: ipv6.NDPConfigurations{ + HandleRAs: true, + DiscoverDefaultRouters: false, + DiscoverOnLinkPrefixes: true, + }, + NDPDisp: &ndpDisp, + })}, }) if err := s.CreateNIC(1, e); err != nil { t.Fatalf("CreateNIC(1) = %s", err) } - optSer := make(header.NDPOptionsSerializer, stack.MaxDiscoveredOnLinkPrefixes+2) - prefixes := [stack.MaxDiscoveredOnLinkPrefixes + 2]tcpip.Subnet{} + optSer := make(header.NDPOptionsSerializer, ipv6.MaxDiscoveredOnLinkPrefixes+2) + prefixes := [ipv6.MaxDiscoveredOnLinkPrefixes + 2]tcpip.Subnet{} // Receive an RA with 2 more than the max number of discovered on-link // prefixes. - for i := 0; i < stack.MaxDiscoveredOnLinkPrefixes+2; i++ { + for i := 0; i < ipv6.MaxDiscoveredOnLinkPrefixes+2; i++ { prefixAddr := [16]byte{1, 2, 3, 4, 5, 6, 7, 8, 0, 0, 0, 0, 0, 0, 0, 0} prefixAddr[7] = byte(i) prefix := tcpip.AddressWithPrefix{ @@ -1662,8 +1660,8 @@ func TestPrefixDiscoveryMaxOnLinkPrefixes(t *testing.T) { } e.InjectInbound(header.IPv6ProtocolNumber, raBufWithOpts(llAddr1, 0, optSer)) - for i := 0; i < stack.MaxDiscoveredOnLinkPrefixes+2; i++ { - if i < stack.MaxDiscoveredOnLinkPrefixes { + for i := 0; i < ipv6.MaxDiscoveredOnLinkPrefixes+2; i++ { + if i < ipv6.MaxDiscoveredOnLinkPrefixes { select { case e := <-ndpDisp.prefixC: if diff := checkPrefixEvent(e, prefixes[i], true); diff != "" { @@ -1689,13 +1687,7 @@ func containsV6Addr(list []tcpip.ProtocolAddress, item tcpip.AddressWithPrefix) AddressWithPrefix: item, } - for _, i := range list { - if i == protocolAddress { - return true - } - } - - return false + return containsAddr(list, protocolAddress) } // TestNoAutoGenAddr tests that SLAAC is not performed when configured not to. @@ -1719,14 +1711,15 @@ func TestNoAutoGenAddr(t *testing.T) { } e := channel.New(0, 1280, linkAddr1) s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, - NDPConfigs: stack.NDPConfigurations{ - HandleRAs: handle, - AutoGenGlobalAddresses: autogen, - }, - NDPDisp: &ndpDisp, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ + NDPConfigs: ipv6.NDPConfigurations{ + HandleRAs: handle, + AutoGenGlobalAddresses: autogen, + }, + NDPDisp: &ndpDisp, + })}, }) - s.SetForwarding(forwarding) + s.SetForwarding(ipv6.ProtocolNumber, forwarding) if err := s.CreateNIC(1, e); err != nil { t.Fatalf("CreateNIC(1) = %s", err) @@ -1752,14 +1745,14 @@ func checkAutoGenAddrEvent(e ndpAutoGenAddrEvent, addr tcpip.AddressWithPrefix, // TestAutoGenAddr tests that an address is properly generated and invalidated // when configured to do so. -func TestAutoGenAddr(t *testing.T) { +func TestAutoGenAddr2(t *testing.T) { const newMinVL = 2 newMinVLDuration := newMinVL * time.Second - saved := stack.MinPrefixInformationValidLifetimeForUpdate + saved := ipv6.MinPrefixInformationValidLifetimeForUpdate defer func() { - stack.MinPrefixInformationValidLifetimeForUpdate = saved + ipv6.MinPrefixInformationValidLifetimeForUpdate = saved }() - stack.MinPrefixInformationValidLifetimeForUpdate = newMinVLDuration + ipv6.MinPrefixInformationValidLifetimeForUpdate = newMinVLDuration prefix1, _, addr1 := prefixSubnetAddr(0, linkAddr1) prefix2, _, addr2 := prefixSubnetAddr(1, linkAddr1) @@ -1769,12 +1762,13 @@ func TestAutoGenAddr(t *testing.T) { } e := channel.New(0, 1280, linkAddr1) s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, - NDPConfigs: stack.NDPConfigurations{ - HandleRAs: true, - AutoGenGlobalAddresses: true, - }, - NDPDisp: &ndpDisp, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ + NDPConfigs: ipv6.NDPConfigurations{ + HandleRAs: true, + AutoGenGlobalAddresses: true, + }, + NDPDisp: &ndpDisp, + })}, }) if err := s.CreateNIC(1, e); err != nil { @@ -1879,14 +1873,14 @@ func TestAutoGenTempAddr(t *testing.T) { newMinVLDuration = newMinVL * time.Second ) - savedMinPrefixInformationValidLifetimeForUpdate := stack.MinPrefixInformationValidLifetimeForUpdate - savedMaxDesync := stack.MaxDesyncFactor + savedMinPrefixInformationValidLifetimeForUpdate := ipv6.MinPrefixInformationValidLifetimeForUpdate + savedMaxDesync := ipv6.MaxDesyncFactor defer func() { - stack.MinPrefixInformationValidLifetimeForUpdate = savedMinPrefixInformationValidLifetimeForUpdate - stack.MaxDesyncFactor = savedMaxDesync + ipv6.MinPrefixInformationValidLifetimeForUpdate = savedMinPrefixInformationValidLifetimeForUpdate + ipv6.MaxDesyncFactor = savedMaxDesync }() - stack.MinPrefixInformationValidLifetimeForUpdate = newMinVLDuration - stack.MaxDesyncFactor = time.Nanosecond + ipv6.MinPrefixInformationValidLifetimeForUpdate = newMinVLDuration + ipv6.MaxDesyncFactor = time.Nanosecond prefix1, _, addr1 := prefixSubnetAddr(0, linkAddr1) prefix2, _, addr2 := prefixSubnetAddr(1, linkAddr1) @@ -1934,16 +1928,17 @@ func TestAutoGenTempAddr(t *testing.T) { } e := channel.New(0, 1280, linkAddr1) s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, - NDPConfigs: stack.NDPConfigurations{ - DupAddrDetectTransmits: test.dupAddrTransmits, - RetransmitTimer: test.retransmitTimer, - HandleRAs: true, - AutoGenGlobalAddresses: true, - AutoGenTempGlobalAddresses: true, - }, - NDPDisp: &ndpDisp, - TempIIDSeed: seed, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ + NDPConfigs: ipv6.NDPConfigurations{ + DupAddrDetectTransmits: test.dupAddrTransmits, + RetransmitTimer: test.retransmitTimer, + HandleRAs: true, + AutoGenGlobalAddresses: true, + AutoGenTempGlobalAddresses: true, + }, + NDPDisp: &ndpDisp, + TempIIDSeed: seed, + })}, }) if err := s.CreateNIC(nicID, e); err != nil { @@ -2122,11 +2117,11 @@ func TestAutoGenTempAddr(t *testing.T) { func TestNoAutoGenTempAddrForLinkLocal(t *testing.T) { const nicID = 1 - savedMaxDesyncFactor := stack.MaxDesyncFactor + savedMaxDesyncFactor := ipv6.MaxDesyncFactor defer func() { - stack.MaxDesyncFactor = savedMaxDesyncFactor + ipv6.MaxDesyncFactor = savedMaxDesyncFactor }() - stack.MaxDesyncFactor = time.Nanosecond + ipv6.MaxDesyncFactor = time.Nanosecond tests := []struct { name string @@ -2163,12 +2158,13 @@ func TestNoAutoGenTempAddrForLinkLocal(t *testing.T) { } e := channel.New(0, 1280, linkAddr1) s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, - NDPConfigs: stack.NDPConfigurations{ - AutoGenTempGlobalAddresses: true, - }, - NDPDisp: &ndpDisp, - AutoGenIPv6LinkLocal: true, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ + NDPConfigs: ipv6.NDPConfigurations{ + AutoGenTempGlobalAddresses: true, + }, + NDPDisp: &ndpDisp, + AutoGenIPv6LinkLocal: true, + })}, }) if err := s.CreateNIC(nicID, e); err != nil { @@ -2214,11 +2210,11 @@ func TestNoAutoGenTempAddrWithoutStableAddr(t *testing.T) { retransmitTimer = 2 * time.Second ) - savedMaxDesyncFactor := stack.MaxDesyncFactor + savedMaxDesyncFactor := ipv6.MaxDesyncFactor defer func() { - stack.MaxDesyncFactor = savedMaxDesyncFactor + ipv6.MaxDesyncFactor = savedMaxDesyncFactor }() - stack.MaxDesyncFactor = 0 + ipv6.MaxDesyncFactor = 0 prefix, _, addr := prefixSubnetAddr(0, linkAddr1) var tempIIDHistory [header.IIDSize]byte @@ -2231,15 +2227,16 @@ func TestNoAutoGenTempAddrWithoutStableAddr(t *testing.T) { } e := channel.New(0, 1280, linkAddr1) s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, - NDPConfigs: stack.NDPConfigurations{ - DupAddrDetectTransmits: dadTransmits, - RetransmitTimer: retransmitTimer, - HandleRAs: true, - AutoGenGlobalAddresses: true, - AutoGenTempGlobalAddresses: true, - }, - NDPDisp: &ndpDisp, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ + NDPConfigs: ipv6.NDPConfigurations{ + DupAddrDetectTransmits: dadTransmits, + RetransmitTimer: retransmitTimer, + HandleRAs: true, + AutoGenGlobalAddresses: true, + AutoGenTempGlobalAddresses: true, + }, + NDPDisp: &ndpDisp, + })}, }) if err := s.CreateNIC(nicID, e); err != nil { @@ -2297,17 +2294,17 @@ func TestAutoGenTempAddrRegen(t *testing.T) { newMinVLDuration = newMinVL * time.Second ) - savedMaxDesyncFactor := stack.MaxDesyncFactor - savedMinMaxTempAddrPreferredLifetime := stack.MinMaxTempAddrPreferredLifetime - savedMinMaxTempAddrValidLifetime := stack.MinMaxTempAddrValidLifetime + savedMaxDesyncFactor := ipv6.MaxDesyncFactor + savedMinMaxTempAddrPreferredLifetime := ipv6.MinMaxTempAddrPreferredLifetime + savedMinMaxTempAddrValidLifetime := ipv6.MinMaxTempAddrValidLifetime defer func() { - stack.MaxDesyncFactor = savedMaxDesyncFactor - stack.MinMaxTempAddrPreferredLifetime = savedMinMaxTempAddrPreferredLifetime - stack.MinMaxTempAddrValidLifetime = savedMinMaxTempAddrValidLifetime + ipv6.MaxDesyncFactor = savedMaxDesyncFactor + ipv6.MinMaxTempAddrPreferredLifetime = savedMinMaxTempAddrPreferredLifetime + ipv6.MinMaxTempAddrValidLifetime = savedMinMaxTempAddrValidLifetime }() - stack.MaxDesyncFactor = 0 - stack.MinMaxTempAddrPreferredLifetime = newMinVLDuration - stack.MinMaxTempAddrValidLifetime = newMinVLDuration + ipv6.MaxDesyncFactor = 0 + ipv6.MinMaxTempAddrPreferredLifetime = newMinVLDuration + ipv6.MinMaxTempAddrValidLifetime = newMinVLDuration prefix, _, addr := prefixSubnetAddr(0, linkAddr1) var tempIIDHistory [header.IIDSize]byte @@ -2320,16 +2317,17 @@ func TestAutoGenTempAddrRegen(t *testing.T) { autoGenAddrC: make(chan ndpAutoGenAddrEvent, 2), } e := channel.New(0, 1280, linkAddr1) - ndpConfigs := stack.NDPConfigurations{ + ndpConfigs := ipv6.NDPConfigurations{ HandleRAs: true, AutoGenGlobalAddresses: true, AutoGenTempGlobalAddresses: true, RegenAdvanceDuration: newMinVLDuration - regenAfter, } s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, - NDPConfigs: ndpConfigs, - NDPDisp: &ndpDisp, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ + NDPConfigs: ndpConfigs, + NDPDisp: &ndpDisp, + })}, }) if err := s.CreateNIC(nicID, e); err != nil { @@ -2385,8 +2383,11 @@ func TestAutoGenTempAddrRegen(t *testing.T) { // Stop generating temporary addresses ndpConfigs.AutoGenTempGlobalAddresses = false - if err := s.SetNDPConfigurations(nicID, ndpConfigs); err != nil { - t.Fatalf("s.SetNDPConfigurations(%d, _): %s", nicID, err) + if ipv6Ep, err := s.GetNetworkEndpoint(nicID, header.IPv6ProtocolNumber); err != nil { + t.Fatalf("s.GetNetworkEndpoint(%d, %d): %s", nicID, header.IPv6ProtocolNumber, err) + } else { + ndpEP := ipv6Ep.(ipv6.NDPEndpoint) + ndpEP.SetNDPConfigurations(ndpConfigs) } // Wait for all the temporary addresses to get invalidated. @@ -2442,17 +2443,17 @@ func TestAutoGenTempAddrRegenJobUpdates(t *testing.T) { newMinVLDuration = newMinVL * time.Second ) - savedMaxDesyncFactor := stack.MaxDesyncFactor - savedMinMaxTempAddrPreferredLifetime := stack.MinMaxTempAddrPreferredLifetime - savedMinMaxTempAddrValidLifetime := stack.MinMaxTempAddrValidLifetime + savedMaxDesyncFactor := ipv6.MaxDesyncFactor + savedMinMaxTempAddrPreferredLifetime := ipv6.MinMaxTempAddrPreferredLifetime + savedMinMaxTempAddrValidLifetime := ipv6.MinMaxTempAddrValidLifetime defer func() { - stack.MaxDesyncFactor = savedMaxDesyncFactor - stack.MinMaxTempAddrPreferredLifetime = savedMinMaxTempAddrPreferredLifetime - stack.MinMaxTempAddrValidLifetime = savedMinMaxTempAddrValidLifetime + ipv6.MaxDesyncFactor = savedMaxDesyncFactor + ipv6.MinMaxTempAddrPreferredLifetime = savedMinMaxTempAddrPreferredLifetime + ipv6.MinMaxTempAddrValidLifetime = savedMinMaxTempAddrValidLifetime }() - stack.MaxDesyncFactor = 0 - stack.MinMaxTempAddrPreferredLifetime = newMinVLDuration - stack.MinMaxTempAddrValidLifetime = newMinVLDuration + ipv6.MaxDesyncFactor = 0 + ipv6.MinMaxTempAddrPreferredLifetime = newMinVLDuration + ipv6.MinMaxTempAddrValidLifetime = newMinVLDuration prefix, _, addr := prefixSubnetAddr(0, linkAddr1) var tempIIDHistory [header.IIDSize]byte @@ -2465,16 +2466,17 @@ func TestAutoGenTempAddrRegenJobUpdates(t *testing.T) { autoGenAddrC: make(chan ndpAutoGenAddrEvent, 2), } e := channel.New(0, 1280, linkAddr1) - ndpConfigs := stack.NDPConfigurations{ + ndpConfigs := ipv6.NDPConfigurations{ HandleRAs: true, AutoGenGlobalAddresses: true, AutoGenTempGlobalAddresses: true, RegenAdvanceDuration: newMinVLDuration - regenAfter, } s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, - NDPConfigs: ndpConfigs, - NDPDisp: &ndpDisp, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ + NDPConfigs: ndpConfigs, + NDPDisp: &ndpDisp, + })}, }) if err := s.CreateNIC(nicID, e); err != nil { @@ -2548,9 +2550,12 @@ func TestAutoGenTempAddrRegenJobUpdates(t *testing.T) { // as paased. ndpConfigs.MaxTempAddrValidLifetime = 100 * time.Second ndpConfigs.MaxTempAddrPreferredLifetime = 100 * time.Second - if err := s.SetNDPConfigurations(nicID, ndpConfigs); err != nil { - t.Fatalf("s.SetNDPConfigurations(%d, _): %s", nicID, err) + ipv6Ep, err := s.GetNetworkEndpoint(nicID, header.IPv6ProtocolNumber) + if err != nil { + t.Fatalf("s.GetNetworkEndpoint(%d, %d): %s", nicID, header.IPv6ProtocolNumber, err) } + ndpEP := ipv6Ep.(ipv6.NDPEndpoint) + ndpEP.SetNDPConfigurations(ndpConfigs) e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, 100, 100)) select { case e := <-ndpDisp.autoGenAddrC: @@ -2568,9 +2573,7 @@ func TestAutoGenTempAddrRegenJobUpdates(t *testing.T) { newLifetimes := newMinVLDuration + regenAfter + defaultAsyncNegativeEventTimeout ndpConfigs.MaxTempAddrValidLifetime = newLifetimes ndpConfigs.MaxTempAddrPreferredLifetime = newLifetimes - if err := s.SetNDPConfigurations(nicID, ndpConfigs); err != nil { - t.Fatalf("s.SetNDPConfigurations(%d, _): %s", nicID, err) - } + ndpEP.SetNDPConfigurations(ndpConfigs) e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, 100, 100)) expectAutoGenAddrEventAsync(tempAddr3, newAddr, regenAfter+defaultAsyncPositiveEventTimeout) } @@ -2658,20 +2661,21 @@ func TestMixedSLAACAddrConflictRegen(t *testing.T) { autoGenAddrC: make(chan ndpAutoGenAddrEvent, 2), } e := channel.New(0, 1280, linkAddr1) - ndpConfigs := stack.NDPConfigurations{ + ndpConfigs := ipv6.NDPConfigurations{ HandleRAs: true, AutoGenGlobalAddresses: true, AutoGenTempGlobalAddresses: test.tempAddrs, AutoGenAddressConflictRetries: 1, } s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, - TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()}, - NDPConfigs: ndpConfigs, - NDPDisp: &ndpDisp, - OpaqueIIDOpts: stack.OpaqueInterfaceIdentifierOptions{ - NICNameFromID: test.nicNameFromID, - }, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ + NDPConfigs: ndpConfigs, + NDPDisp: &ndpDisp, + OpaqueIIDOpts: ipv6.OpaqueInterfaceIdentifierOptions{ + NICNameFromID: test.nicNameFromID, + }, + })}, + TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}, }) s.SetRouteTable([]tcpip.Route{{ @@ -2742,8 +2746,11 @@ func TestMixedSLAACAddrConflictRegen(t *testing.T) { ndpDisp.dadC = make(chan ndpDADEvent, 2) ndpConfigs.DupAddrDetectTransmits = dupAddrTransmits ndpConfigs.RetransmitTimer = retransmitTimer - if err := s.SetNDPConfigurations(nicID, ndpConfigs); err != nil { - t.Fatalf("s.SetNDPConfigurations(%d, _): %s", nicID, err) + if ipv6Ep, err := s.GetNetworkEndpoint(nicID, header.IPv6ProtocolNumber); err != nil { + t.Fatalf("s.GetNetworkEndpoint(%d, %d): %s", nicID, header.IPv6ProtocolNumber, err) + } else { + ndpEP := ipv6Ep.(ipv6.NDPEndpoint) + ndpEP.SetNDPConfigurations(ndpConfigs) } // Do SLAAC for prefix. @@ -2757,9 +2764,7 @@ func TestMixedSLAACAddrConflictRegen(t *testing.T) { // DAD failure to restart the local generation process. addr := test.addrs[maxSLAACAddrLocalRegenAttempts-1] expectAutoGenAddrAsyncEvent(addr, newAddr) - if err := s.DupTentativeAddrDetected(nicID, addr.Address); err != nil { - t.Fatalf("s.DupTentativeAddrDetected(%d, %s): %s", nicID, addr.Address, err) - } + rxNDPSolicit(e, addr.Address) select { case e := <-ndpDisp.dadC: if diff := checkDADEvent(e, nicID, addr.Address, false, nil); diff != "" { @@ -2790,20 +2795,22 @@ func TestMixedSLAACAddrConflictRegen(t *testing.T) { // stack.Stack will have a default route through the router (llAddr3) installed // and a static link-address (linkAddr3) added to the link address cache for the // router. -func stackAndNdpDispatcherWithDefaultRoute(t *testing.T, nicID tcpip.NICID) (*ndpDispatcher, *channel.Endpoint, *stack.Stack) { +func stackAndNdpDispatcherWithDefaultRoute(t *testing.T, nicID tcpip.NICID, useNeighborCache bool) (*ndpDispatcher, *channel.Endpoint, *stack.Stack) { t.Helper() ndpDisp := &ndpDispatcher{ autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1), } e := channel.New(0, 1280, linkAddr1) s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, - TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()}, - NDPConfigs: stack.NDPConfigurations{ - HandleRAs: true, - AutoGenGlobalAddresses: true, - }, - NDPDisp: ndpDisp, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ + NDPConfigs: ipv6.NDPConfigurations{ + HandleRAs: true, + AutoGenGlobalAddresses: true, + }, + NDPDisp: ndpDisp, + })}, + TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}, + UseNeighborCache: useNeighborCache, }) if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) @@ -2813,7 +2820,11 @@ func stackAndNdpDispatcherWithDefaultRoute(t *testing.T, nicID tcpip.NICID) (*nd Gateway: llAddr3, NIC: nicID, }}) - s.AddLinkAddress(nicID, llAddr3, linkAddr3) + if useNeighborCache { + s.AddStaticNeighbor(nicID, llAddr3, linkAddr3) + } else { + s.AddLinkAddress(nicID, llAddr3, linkAddr3) + } return ndpDisp, e, s } @@ -2887,110 +2898,128 @@ func addrForNewConnectionWithAddr(t *testing.T, s *stack.Stack, addr tcpip.FullA // TestAutoGenAddrDeprecateFromPI tests deprecating a SLAAC address when // receiving a PI with 0 preferred lifetime. func TestAutoGenAddrDeprecateFromPI(t *testing.T) { - const nicID = 1 + stacks := []struct { + name string + useNeighborCache bool + }{ + { + name: "linkAddrCache", + useNeighborCache: false, + }, + { + name: "neighborCache", + useNeighborCache: true, + }, + } - prefix1, _, addr1 := prefixSubnetAddr(0, linkAddr1) - prefix2, _, addr2 := prefixSubnetAddr(1, linkAddr1) + for _, stackTyp := range stacks { + t.Run(stackTyp.name, func(t *testing.T) { + const nicID = 1 - ndpDisp, e, s := stackAndNdpDispatcherWithDefaultRoute(t, nicID) + prefix1, _, addr1 := prefixSubnetAddr(0, linkAddr1) + prefix2, _, addr2 := prefixSubnetAddr(1, linkAddr1) - expectAutoGenAddrEvent := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) { - t.Helper() + ndpDisp, e, s := stackAndNdpDispatcherWithDefaultRoute(t, nicID, stackTyp.useNeighborCache) - select { - case e := <-ndpDisp.autoGenAddrC: - if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" { - t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) + expectAutoGenAddrEvent := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) { + t.Helper() + + select { + case e := <-ndpDisp.autoGenAddrC: + if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" { + t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) + } + default: + t.Fatal("expected addr auto gen event") + } } - default: - t.Fatal("expected addr auto gen event") - } - } - expectPrimaryAddr := func(addr tcpip.AddressWithPrefix) { - t.Helper() + expectPrimaryAddr := func(addr tcpip.AddressWithPrefix) { + t.Helper() - if got, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber); err != nil { - t.Fatalf("s.GetMainNICAddress(%d, %d): %s", nicID, header.IPv6ProtocolNumber, err) - } else if got != addr { - t.Errorf("got s.GetMainNICAddress(%d, %d) = %s, want = %s", nicID, header.IPv6ProtocolNumber, got, addr) - } + if got, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber); err != nil { + t.Fatalf("s.GetMainNICAddress(%d, %d): %s", nicID, header.IPv6ProtocolNumber, err) + } else if got != addr { + t.Errorf("got s.GetMainNICAddress(%d, %d) = %s, want = %s", nicID, header.IPv6ProtocolNumber, got, addr) + } - if got := addrForNewConnection(t, s); got != addr.Address { - t.Errorf("got addrForNewConnection = %s, want = %s", got, addr.Address) - } - } + if got := addrForNewConnection(t, s); got != addr.Address { + t.Errorf("got addrForNewConnection = %s, want = %s", got, addr.Address) + } + } - // Receive PI for prefix1. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 100, 100)) - expectAutoGenAddrEvent(addr1, newAddr) - if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) { - t.Fatalf("should have %s in the list of addresses", addr1) - } - expectPrimaryAddr(addr1) + // Receive PI for prefix1. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 100, 100)) + expectAutoGenAddrEvent(addr1, newAddr) + if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) { + t.Fatalf("should have %s in the list of addresses", addr1) + } + expectPrimaryAddr(addr1) - // Deprecate addr for prefix1 immedaitely. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 100, 0)) - expectAutoGenAddrEvent(addr1, deprecatedAddr) - if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) { - t.Fatalf("should have %s in the list of addresses", addr1) - } - // addr should still be the primary endpoint as there are no other addresses. - expectPrimaryAddr(addr1) + // Deprecate addr for prefix1 immedaitely. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 100, 0)) + expectAutoGenAddrEvent(addr1, deprecatedAddr) + if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) { + t.Fatalf("should have %s in the list of addresses", addr1) + } + // addr should still be the primary endpoint as there are no other addresses. + expectPrimaryAddr(addr1) - // Refresh lifetimes of addr generated from prefix1. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 100, 100)) - select { - case <-ndpDisp.autoGenAddrC: - t.Fatal("unexpectedly got an auto-generated event") - default: - } - expectPrimaryAddr(addr1) + // Refresh lifetimes of addr generated from prefix1. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 100, 100)) + select { + case <-ndpDisp.autoGenAddrC: + t.Fatal("unexpectedly got an auto-generated event") + default: + } + expectPrimaryAddr(addr1) - // Receive PI for prefix2. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, 100, 100)) - expectAutoGenAddrEvent(addr2, newAddr) - if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr2) { - t.Fatalf("should have %s in the list of addresses", addr2) - } - expectPrimaryAddr(addr2) + // Receive PI for prefix2. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, 100, 100)) + expectAutoGenAddrEvent(addr2, newAddr) + if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr2) { + t.Fatalf("should have %s in the list of addresses", addr2) + } + expectPrimaryAddr(addr2) - // Deprecate addr for prefix2 immedaitely. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, 100, 0)) - expectAutoGenAddrEvent(addr2, deprecatedAddr) - if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr2) { - t.Fatalf("should have %s in the list of addresses", addr2) - } - // addr1 should be the primary endpoint now since addr2 is deprecated but - // addr1 is not. - expectPrimaryAddr(addr1) - // addr2 is deprecated but if explicitly requested, it should be used. - fullAddr2 := tcpip.FullAddress{Addr: addr2.Address, NIC: nicID} - if got := addrForNewConnectionWithAddr(t, s, fullAddr2); got != addr2.Address { - t.Errorf("got addrForNewConnectionWithAddr(_, _, %+v) = %s, want = %s", fullAddr2, got, addr2.Address) - } + // Deprecate addr for prefix2 immedaitely. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, 100, 0)) + expectAutoGenAddrEvent(addr2, deprecatedAddr) + if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr2) { + t.Fatalf("should have %s in the list of addresses", addr2) + } + // addr1 should be the primary endpoint now since addr2 is deprecated but + // addr1 is not. + expectPrimaryAddr(addr1) + // addr2 is deprecated but if explicitly requested, it should be used. + fullAddr2 := tcpip.FullAddress{Addr: addr2.Address, NIC: nicID} + if got := addrForNewConnectionWithAddr(t, s, fullAddr2); got != addr2.Address { + t.Errorf("got addrForNewConnectionWithAddr(_, _, %+v) = %s, want = %s", fullAddr2, got, addr2.Address) + } - // Another PI w/ 0 preferred lifetime should not result in a deprecation - // event. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, 100, 0)) - select { - case <-ndpDisp.autoGenAddrC: - t.Fatal("unexpectedly got an auto-generated event") - default: - } - expectPrimaryAddr(addr1) - if got := addrForNewConnectionWithAddr(t, s, fullAddr2); got != addr2.Address { - t.Errorf("got addrForNewConnectionWithAddr(_, _, %+v) = %s, want = %s", fullAddr2, got, addr2.Address) - } + // Another PI w/ 0 preferred lifetime should not result in a deprecation + // event. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, 100, 0)) + select { + case <-ndpDisp.autoGenAddrC: + t.Fatal("unexpectedly got an auto-generated event") + default: + } + expectPrimaryAddr(addr1) + if got := addrForNewConnectionWithAddr(t, s, fullAddr2); got != addr2.Address { + t.Errorf("got addrForNewConnectionWithAddr(_, _, %+v) = %s, want = %s", fullAddr2, got, addr2.Address) + } - // Refresh lifetimes of addr generated from prefix2. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, 100, 100)) - select { - case <-ndpDisp.autoGenAddrC: - t.Fatal("unexpectedly got an auto-generated event") - default: + // Refresh lifetimes of addr generated from prefix2. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, 100, 100)) + select { + case <-ndpDisp.autoGenAddrC: + t.Fatal("unexpectedly got an auto-generated event") + default: + } + expectPrimaryAddr(addr2) + }) } - expectPrimaryAddr(addr2) } // TestAutoGenAddrJobDeprecation tests that an address is properly deprecated @@ -2999,217 +3028,236 @@ func TestAutoGenAddrJobDeprecation(t *testing.T) { const nicID = 1 const newMinVL = 2 newMinVLDuration := newMinVL * time.Second - saved := stack.MinPrefixInformationValidLifetimeForUpdate - defer func() { - stack.MinPrefixInformationValidLifetimeForUpdate = saved - }() - stack.MinPrefixInformationValidLifetimeForUpdate = newMinVLDuration - prefix1, _, addr1 := prefixSubnetAddr(0, linkAddr1) - prefix2, _, addr2 := prefixSubnetAddr(1, linkAddr1) + stacks := []struct { + name string + useNeighborCache bool + }{ + { + name: "linkAddrCache", + useNeighborCache: false, + }, + { + name: "neighborCache", + useNeighborCache: true, + }, + } - ndpDisp, e, s := stackAndNdpDispatcherWithDefaultRoute(t, nicID) + for _, stackTyp := range stacks { + t.Run(stackTyp.name, func(t *testing.T) { + saved := ipv6.MinPrefixInformationValidLifetimeForUpdate + defer func() { + ipv6.MinPrefixInformationValidLifetimeForUpdate = saved + }() + ipv6.MinPrefixInformationValidLifetimeForUpdate = newMinVLDuration - expectAutoGenAddrEvent := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) { - t.Helper() + prefix1, _, addr1 := prefixSubnetAddr(0, linkAddr1) + prefix2, _, addr2 := prefixSubnetAddr(1, linkAddr1) - select { - case e := <-ndpDisp.autoGenAddrC: - if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" { - t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) + ndpDisp, e, s := stackAndNdpDispatcherWithDefaultRoute(t, nicID, stackTyp.useNeighborCache) + + expectAutoGenAddrEvent := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) { + t.Helper() + + select { + case e := <-ndpDisp.autoGenAddrC: + if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" { + t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) + } + default: + t.Fatal("expected addr auto gen event") + } } - default: - t.Fatal("expected addr auto gen event") - } - } - expectAutoGenAddrEventAfter := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType, timeout time.Duration) { - t.Helper() + expectAutoGenAddrEventAfter := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType, timeout time.Duration) { + t.Helper() - select { - case e := <-ndpDisp.autoGenAddrC: - if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" { - t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) + select { + case e := <-ndpDisp.autoGenAddrC: + if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" { + t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) + } + case <-time.After(timeout): + t.Fatal("timed out waiting for addr auto gen event") + } } - case <-time.After(timeout): - t.Fatal("timed out waiting for addr auto gen event") - } - } - expectPrimaryAddr := func(addr tcpip.AddressWithPrefix) { - t.Helper() + expectPrimaryAddr := func(addr tcpip.AddressWithPrefix) { + t.Helper() - if got, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber); err != nil { - t.Fatalf("s.GetMainNICAddress(%d, %d): %s", nicID, header.IPv6ProtocolNumber, err) - } else if got != addr { - t.Errorf("got s.GetMainNICAddress(%d, %d) = %s, want = %s", nicID, header.IPv6ProtocolNumber, got, addr) - } + if got, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber); err != nil { + t.Fatalf("s.GetMainNICAddress(%d, %d): %s", nicID, header.IPv6ProtocolNumber, err) + } else if got != addr { + t.Errorf("got s.GetMainNICAddress(%d, %d) = %s, want = %s", nicID, header.IPv6ProtocolNumber, got, addr) + } - if got := addrForNewConnection(t, s); got != addr.Address { - t.Errorf("got addrForNewConnection = %s, want = %s", got, addr.Address) - } - } + if got := addrForNewConnection(t, s); got != addr.Address { + t.Errorf("got addrForNewConnection = %s, want = %s", got, addr.Address) + } + } - // Receive PI for prefix2. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, 100, 100)) - expectAutoGenAddrEvent(addr2, newAddr) - if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr2) { - t.Fatalf("should have %s in the list of addresses", addr2) - } - expectPrimaryAddr(addr2) + // Receive PI for prefix2. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, 100, 100)) + expectAutoGenAddrEvent(addr2, newAddr) + if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr2) { + t.Fatalf("should have %s in the list of addresses", addr2) + } + expectPrimaryAddr(addr2) - // Receive a PI for prefix1. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 100, 90)) - expectAutoGenAddrEvent(addr1, newAddr) - if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) { - t.Fatalf("should have %s in the list of addresses", addr1) - } - if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr2) { - t.Fatalf("should have %s in the list of addresses", addr2) - } - expectPrimaryAddr(addr1) + // Receive a PI for prefix1. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 100, 90)) + expectAutoGenAddrEvent(addr1, newAddr) + if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) { + t.Fatalf("should have %s in the list of addresses", addr1) + } + if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr2) { + t.Fatalf("should have %s in the list of addresses", addr2) + } + expectPrimaryAddr(addr1) - // Refresh lifetime for addr of prefix1. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, newMinVL, newMinVL-1)) - select { - case <-ndpDisp.autoGenAddrC: - t.Fatal("unexpectedly got an auto-generated event") - default: - } - expectPrimaryAddr(addr1) + // Refresh lifetime for addr of prefix1. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, newMinVL, newMinVL-1)) + select { + case <-ndpDisp.autoGenAddrC: + t.Fatal("unexpectedly got an auto-generated event") + default: + } + expectPrimaryAddr(addr1) - // Wait for addr of prefix1 to be deprecated. - expectAutoGenAddrEventAfter(addr1, deprecatedAddr, newMinVLDuration-time.Second+defaultAsyncPositiveEventTimeout) - if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) { - t.Fatalf("should not have %s in the list of addresses", addr1) - } - if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr2) { - t.Fatalf("should have %s in the list of addresses", addr2) - } - // addr2 should be the primary endpoint now since addr1 is deprecated but - // addr2 is not. - expectPrimaryAddr(addr2) - // addr1 is deprecated but if explicitly requested, it should be used. - fullAddr1 := tcpip.FullAddress{Addr: addr1.Address, NIC: nicID} - if got := addrForNewConnectionWithAddr(t, s, fullAddr1); got != addr1.Address { - t.Errorf("got addrForNewConnectionWithAddr(_, _, %+v) = %s, want = %s", fullAddr1, got, addr1.Address) - } + // Wait for addr of prefix1 to be deprecated. + expectAutoGenAddrEventAfter(addr1, deprecatedAddr, newMinVLDuration-time.Second+defaultAsyncPositiveEventTimeout) + if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) { + t.Fatalf("should not have %s in the list of addresses", addr1) + } + if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr2) { + t.Fatalf("should have %s in the list of addresses", addr2) + } + // addr2 should be the primary endpoint now since addr1 is deprecated but + // addr2 is not. + expectPrimaryAddr(addr2) + // addr1 is deprecated but if explicitly requested, it should be used. + fullAddr1 := tcpip.FullAddress{Addr: addr1.Address, NIC: nicID} + if got := addrForNewConnectionWithAddr(t, s, fullAddr1); got != addr1.Address { + t.Errorf("got addrForNewConnectionWithAddr(_, _, %+v) = %s, want = %s", fullAddr1, got, addr1.Address) + } - // Refresh valid lifetime for addr of prefix1, w/ 0 preferred lifetime to make - // sure we do not get a deprecation event again. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, newMinVL, 0)) - select { - case <-ndpDisp.autoGenAddrC: - t.Fatal("unexpectedly got an auto-generated event") - default: - } - expectPrimaryAddr(addr2) - if got := addrForNewConnectionWithAddr(t, s, fullAddr1); got != addr1.Address { - t.Errorf("got addrForNewConnectionWithAddr(_, _, %+v) = %s, want = %s", fullAddr1, got, addr1.Address) - } + // Refresh valid lifetime for addr of prefix1, w/ 0 preferred lifetime to make + // sure we do not get a deprecation event again. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, newMinVL, 0)) + select { + case <-ndpDisp.autoGenAddrC: + t.Fatal("unexpectedly got an auto-generated event") + default: + } + expectPrimaryAddr(addr2) + if got := addrForNewConnectionWithAddr(t, s, fullAddr1); got != addr1.Address { + t.Errorf("got addrForNewConnectionWithAddr(_, _, %+v) = %s, want = %s", fullAddr1, got, addr1.Address) + } - // Refresh lifetimes for addr of prefix1. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, newMinVL, newMinVL-1)) - select { - case <-ndpDisp.autoGenAddrC: - t.Fatal("unexpectedly got an auto-generated event") - default: - } - // addr1 is the primary endpoint again since it is non-deprecated now. - expectPrimaryAddr(addr1) + // Refresh lifetimes for addr of prefix1. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, newMinVL, newMinVL-1)) + select { + case <-ndpDisp.autoGenAddrC: + t.Fatal("unexpectedly got an auto-generated event") + default: + } + // addr1 is the primary endpoint again since it is non-deprecated now. + expectPrimaryAddr(addr1) - // Wait for addr of prefix1 to be deprecated. - expectAutoGenAddrEventAfter(addr1, deprecatedAddr, newMinVLDuration-time.Second+defaultAsyncPositiveEventTimeout) - if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) { - t.Fatalf("should not have %s in the list of addresses", addr1) - } - if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr2) { - t.Fatalf("should have %s in the list of addresses", addr2) - } - // addr2 should be the primary endpoint now since it is not deprecated. - expectPrimaryAddr(addr2) - if got := addrForNewConnectionWithAddr(t, s, fullAddr1); got != addr1.Address { - t.Errorf("got addrForNewConnectionWithAddr(_, _, %+v) = %s, want = %s", fullAddr1, got, addr1.Address) - } + // Wait for addr of prefix1 to be deprecated. + expectAutoGenAddrEventAfter(addr1, deprecatedAddr, newMinVLDuration-time.Second+defaultAsyncPositiveEventTimeout) + if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) { + t.Fatalf("should not have %s in the list of addresses", addr1) + } + if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr2) { + t.Fatalf("should have %s in the list of addresses", addr2) + } + // addr2 should be the primary endpoint now since it is not deprecated. + expectPrimaryAddr(addr2) + if got := addrForNewConnectionWithAddr(t, s, fullAddr1); got != addr1.Address { + t.Errorf("got addrForNewConnectionWithAddr(_, _, %+v) = %s, want = %s", fullAddr1, got, addr1.Address) + } - // Wait for addr of prefix1 to be invalidated. - expectAutoGenAddrEventAfter(addr1, invalidatedAddr, time.Second+defaultAsyncPositiveEventTimeout) - if containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) { - t.Fatalf("should not have %s in the list of addresses", addr1) - } - if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr2) { - t.Fatalf("should have %s in the list of addresses", addr2) - } - expectPrimaryAddr(addr2) + // Wait for addr of prefix1 to be invalidated. + expectAutoGenAddrEventAfter(addr1, invalidatedAddr, time.Second+defaultAsyncPositiveEventTimeout) + if containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) { + t.Fatalf("should not have %s in the list of addresses", addr1) + } + if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr2) { + t.Fatalf("should have %s in the list of addresses", addr2) + } + expectPrimaryAddr(addr2) - // Refresh both lifetimes for addr of prefix2 to the same value. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, newMinVL, newMinVL)) - select { - case <-ndpDisp.autoGenAddrC: - t.Fatal("unexpectedly got an auto-generated event") - default: - } + // Refresh both lifetimes for addr of prefix2 to the same value. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, newMinVL, newMinVL)) + select { + case <-ndpDisp.autoGenAddrC: + t.Fatal("unexpectedly got an auto-generated event") + default: + } - // Wait for a deprecation then invalidation events, or just an invalidation - // event. We need to cover both cases but cannot deterministically hit both - // cases because the deprecation and invalidation handlers could be handled in - // either deprecation then invalidation, or invalidation then deprecation - // (which should be cancelled by the invalidation handler). - select { - case e := <-ndpDisp.autoGenAddrC: - if diff := checkAutoGenAddrEvent(e, addr2, deprecatedAddr); diff == "" { - // If we get a deprecation event first, we should get an invalidation - // event almost immediately after. + // Wait for a deprecation then invalidation events, or just an invalidation + // event. We need to cover both cases but cannot deterministically hit both + // cases because the deprecation and invalidation handlers could be handled in + // either deprecation then invalidation, or invalidation then deprecation + // (which should be cancelled by the invalidation handler). select { case e := <-ndpDisp.autoGenAddrC: - if diff := checkAutoGenAddrEvent(e, addr2, invalidatedAddr); diff != "" { - t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) + if diff := checkAutoGenAddrEvent(e, addr2, deprecatedAddr); diff == "" { + // If we get a deprecation event first, we should get an invalidation + // event almost immediately after. + select { + case e := <-ndpDisp.autoGenAddrC: + if diff := checkAutoGenAddrEvent(e, addr2, invalidatedAddr); diff != "" { + t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) + } + case <-time.After(defaultAsyncPositiveEventTimeout): + t.Fatal("timed out waiting for addr auto gen event") + } + } else if diff := checkAutoGenAddrEvent(e, addr2, invalidatedAddr); diff == "" { + // If we get an invalidation event first, we should not get a deprecation + // event after. + select { + case <-ndpDisp.autoGenAddrC: + t.Fatal("unexpectedly got an auto-generated event") + case <-time.After(defaultAsyncNegativeEventTimeout): + } + } else { + t.Fatalf("got unexpected auto-generated event") } - case <-time.After(defaultAsyncPositiveEventTimeout): + case <-time.After(newMinVLDuration + defaultAsyncPositiveEventTimeout): t.Fatal("timed out waiting for addr auto gen event") } - } else if diff := checkAutoGenAddrEvent(e, addr2, invalidatedAddr); diff == "" { - // If we get an invalidation event first, we should not get a deprecation - // event after. - select { - case <-ndpDisp.autoGenAddrC: - t.Fatal("unexpectedly got an auto-generated event") - case <-time.After(defaultAsyncNegativeEventTimeout): + if containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) { + t.Fatalf("should not have %s in the list of addresses", addr1) + } + if containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr2) { + t.Fatalf("should not have %s in the list of addresses", addr2) + } + // Should not have any primary endpoints. + if got, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber); err != nil { + t.Fatalf("s.GetMainNICAddress(%d, %d): %s", nicID, header.IPv6ProtocolNumber, err) + } else if want := (tcpip.AddressWithPrefix{}); got != want { + t.Errorf("got s.GetMainNICAddress(%d, %d) = %s, want = %s", nicID, header.IPv6ProtocolNumber, got, want) + } + wq := waiter.Queue{} + we, ch := waiter.NewChannelEntry(nil) + wq.EventRegister(&we, waiter.EventIn) + defer wq.EventUnregister(&we) + defer close(ch) + ep, err := s.NewEndpoint(header.UDPProtocolNumber, header.IPv6ProtocolNumber, &wq) + if err != nil { + t.Fatalf("s.NewEndpoint(%d, %d, _): %s", header.UDPProtocolNumber, header.IPv6ProtocolNumber, err) + } + defer ep.Close() + if err := ep.SetSockOptBool(tcpip.V6OnlyOption, true); err != nil { + t.Fatalf("SetSockOpt(tcpip.V6OnlyOption, true): %s", err) } - } else { - t.Fatalf("got unexpected auto-generated event") - } - case <-time.After(newMinVLDuration + defaultAsyncPositiveEventTimeout): - t.Fatal("timed out waiting for addr auto gen event") - } - if containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) { - t.Fatalf("should not have %s in the list of addresses", addr1) - } - if containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr2) { - t.Fatalf("should not have %s in the list of addresses", addr2) - } - // Should not have any primary endpoints. - if got, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber); err != nil { - t.Fatalf("s.GetMainNICAddress(%d, %d): %s", nicID, header.IPv6ProtocolNumber, err) - } else if want := (tcpip.AddressWithPrefix{}); got != want { - t.Errorf("got s.GetMainNICAddress(%d, %d) = %s, want = %s", nicID, header.IPv6ProtocolNumber, got, want) - } - wq := waiter.Queue{} - we, ch := waiter.NewChannelEntry(nil) - wq.EventRegister(&we, waiter.EventIn) - defer wq.EventUnregister(&we) - defer close(ch) - ep, err := s.NewEndpoint(header.UDPProtocolNumber, header.IPv6ProtocolNumber, &wq) - if err != nil { - t.Fatalf("s.NewEndpoint(%d, %d, _): %s", header.UDPProtocolNumber, header.IPv6ProtocolNumber, err) - } - defer ep.Close() - if err := ep.SetSockOptBool(tcpip.V6OnlyOption, true); err != nil { - t.Fatalf("SetSockOpt(tcpip.V6OnlyOption, true): %s", err) - } - if err := ep.Connect(dstAddr); err != tcpip.ErrNoRoute { - t.Errorf("got ep.Connect(%+v) = %v, want = %s", dstAddr, err, tcpip.ErrNoRoute) + if err := ep.Connect(dstAddr); err != tcpip.ErrNoRoute { + t.Errorf("got ep.Connect(%+v) = %s, want = %s", dstAddr, err, tcpip.ErrNoRoute) + } + }) } } @@ -3219,12 +3267,12 @@ func TestAutoGenAddrFiniteToInfiniteToFiniteVL(t *testing.T) { const infiniteVLSeconds = 2 const minVLSeconds = 1 savedIL := header.NDPInfiniteLifetime - savedMinVL := stack.MinPrefixInformationValidLifetimeForUpdate + savedMinVL := ipv6.MinPrefixInformationValidLifetimeForUpdate defer func() { - stack.MinPrefixInformationValidLifetimeForUpdate = savedMinVL + ipv6.MinPrefixInformationValidLifetimeForUpdate = savedMinVL header.NDPInfiniteLifetime = savedIL }() - stack.MinPrefixInformationValidLifetimeForUpdate = minVLSeconds * time.Second + ipv6.MinPrefixInformationValidLifetimeForUpdate = minVLSeconds * time.Second header.NDPInfiniteLifetime = infiniteVLSeconds * time.Second prefix, _, addr := prefixSubnetAddr(0, linkAddr1) @@ -3268,12 +3316,13 @@ func TestAutoGenAddrFiniteToInfiniteToFiniteVL(t *testing.T) { } e := channel.New(0, 1280, linkAddr1) s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, - NDPConfigs: stack.NDPConfigurations{ - HandleRAs: true, - AutoGenGlobalAddresses: true, - }, - NDPDisp: &ndpDisp, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ + NDPConfigs: ipv6.NDPConfigurations{ + HandleRAs: true, + AutoGenGlobalAddresses: true, + }, + NDPDisp: &ndpDisp, + })}, }) if err := s.CreateNIC(1, e); err != nil { @@ -3318,11 +3367,11 @@ func TestAutoGenAddrFiniteToInfiniteToFiniteVL(t *testing.T) { func TestAutoGenAddrValidLifetimeUpdates(t *testing.T) { const infiniteVL = 4294967295 const newMinVL = 4 - saved := stack.MinPrefixInformationValidLifetimeForUpdate + saved := ipv6.MinPrefixInformationValidLifetimeForUpdate defer func() { - stack.MinPrefixInformationValidLifetimeForUpdate = saved + ipv6.MinPrefixInformationValidLifetimeForUpdate = saved }() - stack.MinPrefixInformationValidLifetimeForUpdate = newMinVL * time.Second + ipv6.MinPrefixInformationValidLifetimeForUpdate = newMinVL * time.Second prefix, _, addr := prefixSubnetAddr(0, linkAddr1) @@ -3410,12 +3459,13 @@ func TestAutoGenAddrValidLifetimeUpdates(t *testing.T) { } e := channel.New(10, 1280, linkAddr1) s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, - NDPConfigs: stack.NDPConfigurations{ - HandleRAs: true, - AutoGenGlobalAddresses: true, - }, - NDPDisp: &ndpDisp, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ + NDPConfigs: ipv6.NDPConfigurations{ + HandleRAs: true, + AutoGenGlobalAddresses: true, + }, + NDPDisp: &ndpDisp, + })}, }) if err := s.CreateNIC(1, e); err != nil { @@ -3476,12 +3526,13 @@ func TestAutoGenAddrRemoval(t *testing.T) { } e := channel.New(0, 1280, linkAddr1) s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, - NDPConfigs: stack.NDPConfigurations{ - HandleRAs: true, - AutoGenGlobalAddresses: true, - }, - NDPDisp: &ndpDisp, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ + NDPConfigs: ipv6.NDPConfigurations{ + HandleRAs: true, + AutoGenGlobalAddresses: true, + }, + NDPDisp: &ndpDisp, + })}, }) if err := s.CreateNIC(1, e); err != nil { @@ -3527,110 +3578,128 @@ func TestAutoGenAddrRemoval(t *testing.T) { func TestAutoGenAddrAfterRemoval(t *testing.T) { const nicID = 1 - prefix1, _, addr1 := prefixSubnetAddr(0, linkAddr1) - prefix2, _, addr2 := prefixSubnetAddr(1, linkAddr1) - ndpDisp, e, s := stackAndNdpDispatcherWithDefaultRoute(t, nicID) - - expectAutoGenAddrEvent := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) { - t.Helper() - - select { - case e := <-ndpDisp.autoGenAddrC: - if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" { - t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) - } - default: - t.Fatal("expected addr auto gen event") - } + stacks := []struct { + name string + useNeighborCache bool + }{ + { + name: "linkAddrCache", + useNeighborCache: false, + }, + { + name: "neighborCache", + useNeighborCache: true, + }, } - expectPrimaryAddr := func(addr tcpip.AddressWithPrefix) { - t.Helper() - - if got, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber); err != nil { - t.Fatalf("s.GetMainNICAddress(%d, %d): %s", nicID, header.IPv6ProtocolNumber, err) - } else if got != addr { - t.Errorf("got s.GetMainNICAddress(%d, %d) = %s, want = %s", nicID, header.IPv6ProtocolNumber, got, addr) - } + for _, stackTyp := range stacks { + t.Run(stackTyp.name, func(t *testing.T) { + prefix1, _, addr1 := prefixSubnetAddr(0, linkAddr1) + prefix2, _, addr2 := prefixSubnetAddr(1, linkAddr1) + ndpDisp, e, s := stackAndNdpDispatcherWithDefaultRoute(t, nicID, stackTyp.useNeighborCache) - if got := addrForNewConnection(t, s); got != addr.Address { - t.Errorf("got addrForNewConnection = %s, want = %s", got, addr.Address) - } - } + expectAutoGenAddrEvent := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) { + t.Helper() - // Receive a PI to auto-generate addr1 with a large valid and preferred - // lifetime. - const largeLifetimeSeconds = 999 - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr3, 0, prefix1, true, true, largeLifetimeSeconds, largeLifetimeSeconds)) - expectAutoGenAddrEvent(addr1, newAddr) - expectPrimaryAddr(addr1) + select { + case e := <-ndpDisp.autoGenAddrC: + if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" { + t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) + } + default: + t.Fatal("expected addr auto gen event") + } + } - // Add addr2 as a static address. - protoAddr2 := tcpip.ProtocolAddress{ - Protocol: header.IPv6ProtocolNumber, - AddressWithPrefix: addr2, - } - if err := s.AddProtocolAddressWithOptions(nicID, protoAddr2, stack.FirstPrimaryEndpoint); err != nil { - t.Fatalf("AddProtocolAddressWithOptions(%d, %+v, %d) = %s", nicID, protoAddr2, stack.FirstPrimaryEndpoint, err) - } - // addr2 should be more preferred now since it is at the front of the primary - // list. - expectPrimaryAddr(addr2) + expectPrimaryAddr := func(addr tcpip.AddressWithPrefix) { + t.Helper() - // Get a route using addr2 to increment its reference count then remove it - // to leave it in the permanentExpired state. - r, err := s.FindRoute(nicID, addr2.Address, addr3, header.IPv6ProtocolNumber, false) - if err != nil { - t.Fatalf("FindRoute(%d, %s, %s, %d, false): %s", nicID, addr2.Address, addr3, header.IPv6ProtocolNumber, err) - } - defer r.Release() - if err := s.RemoveAddress(nicID, addr2.Address); err != nil { - t.Fatalf("s.RemoveAddress(%d, %s): %s", nicID, addr2.Address, err) - } - // addr1 should be preferred again since addr2 is in the expired state. - expectPrimaryAddr(addr1) + if got, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber); err != nil { + t.Fatalf("s.GetMainNICAddress(%d, %d): %s", nicID, header.IPv6ProtocolNumber, err) + } else if got != addr { + t.Errorf("got s.GetMainNICAddress(%d, %d) = %s, want = %s", nicID, header.IPv6ProtocolNumber, got, addr) + } - // Receive a PI to auto-generate addr2 as valid and preferred. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr3, 0, prefix2, true, true, largeLifetimeSeconds, largeLifetimeSeconds)) - expectAutoGenAddrEvent(addr2, newAddr) - // addr2 should be more preferred now that it is closer to the front of the - // primary list and not deprecated. - expectPrimaryAddr(addr2) + if got := addrForNewConnection(t, s); got != addr.Address { + t.Errorf("got addrForNewConnection = %s, want = %s", got, addr.Address) + } + } - // Removing the address should result in an invalidation event immediately. - // It should still be in the permanentExpired state because r is still held. - // - // We remove addr2 here to make sure addr2 was marked as a SLAAC address - // (it was previously marked as a static address). - if err := s.RemoveAddress(1, addr2.Address); err != nil { - t.Fatalf("RemoveAddress(_, %s) = %s", addr2.Address, err) - } - expectAutoGenAddrEvent(addr2, invalidatedAddr) - // addr1 should be more preferred since addr2 is in the expired state. - expectPrimaryAddr(addr1) + // Receive a PI to auto-generate addr1 with a large valid and preferred + // lifetime. + const largeLifetimeSeconds = 999 + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr3, 0, prefix1, true, true, largeLifetimeSeconds, largeLifetimeSeconds)) + expectAutoGenAddrEvent(addr1, newAddr) + expectPrimaryAddr(addr1) - // Receive a PI to auto-generate addr2 as valid and deprecated. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr3, 0, prefix2, true, true, largeLifetimeSeconds, 0)) - expectAutoGenAddrEvent(addr2, newAddr) - // addr1 should still be more preferred since addr2 is deprecated, even though - // it is closer to the front of the primary list. - expectPrimaryAddr(addr1) + // Add addr2 as a static address. + protoAddr2 := tcpip.ProtocolAddress{ + Protocol: header.IPv6ProtocolNumber, + AddressWithPrefix: addr2, + } + if err := s.AddProtocolAddressWithOptions(nicID, protoAddr2, stack.FirstPrimaryEndpoint); err != nil { + t.Fatalf("AddProtocolAddressWithOptions(%d, %+v, %d) = %s", nicID, protoAddr2, stack.FirstPrimaryEndpoint, err) + } + // addr2 should be more preferred now since it is at the front of the primary + // list. + expectPrimaryAddr(addr2) - // Receive a PI to refresh addr2's preferred lifetime. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr3, 0, prefix2, true, true, largeLifetimeSeconds, largeLifetimeSeconds)) - select { - case <-ndpDisp.autoGenAddrC: - t.Fatal("unexpectedly got an auto gen addr event") - default: - } - // addr2 should be more preferred now that it is not deprecated. - expectPrimaryAddr(addr2) + // Get a route using addr2 to increment its reference count then remove it + // to leave it in the permanentExpired state. + r, err := s.FindRoute(nicID, addr2.Address, addr3, header.IPv6ProtocolNumber, false) + if err != nil { + t.Fatalf("FindRoute(%d, %s, %s, %d, false): %s", nicID, addr2.Address, addr3, header.IPv6ProtocolNumber, err) + } + defer r.Release() + if err := s.RemoveAddress(nicID, addr2.Address); err != nil { + t.Fatalf("s.RemoveAddress(%d, %s): %s", nicID, addr2.Address, err) + } + // addr1 should be preferred again since addr2 is in the expired state. + expectPrimaryAddr(addr1) + + // Receive a PI to auto-generate addr2 as valid and preferred. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr3, 0, prefix2, true, true, largeLifetimeSeconds, largeLifetimeSeconds)) + expectAutoGenAddrEvent(addr2, newAddr) + // addr2 should be more preferred now that it is closer to the front of the + // primary list and not deprecated. + expectPrimaryAddr(addr2) + + // Removing the address should result in an invalidation event immediately. + // It should still be in the permanentExpired state because r is still held. + // + // We remove addr2 here to make sure addr2 was marked as a SLAAC address + // (it was previously marked as a static address). + if err := s.RemoveAddress(1, addr2.Address); err != nil { + t.Fatalf("RemoveAddress(_, %s) = %s", addr2.Address, err) + } + expectAutoGenAddrEvent(addr2, invalidatedAddr) + // addr1 should be more preferred since addr2 is in the expired state. + expectPrimaryAddr(addr1) + + // Receive a PI to auto-generate addr2 as valid and deprecated. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr3, 0, prefix2, true, true, largeLifetimeSeconds, 0)) + expectAutoGenAddrEvent(addr2, newAddr) + // addr1 should still be more preferred since addr2 is deprecated, even though + // it is closer to the front of the primary list. + expectPrimaryAddr(addr1) + + // Receive a PI to refresh addr2's preferred lifetime. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr3, 0, prefix2, true, true, largeLifetimeSeconds, largeLifetimeSeconds)) + select { + case <-ndpDisp.autoGenAddrC: + t.Fatal("unexpectedly got an auto gen addr event") + default: + } + // addr2 should be more preferred now that it is not deprecated. + expectPrimaryAddr(addr2) - if err := s.RemoveAddress(1, addr2.Address); err != nil { - t.Fatalf("RemoveAddress(_, %s) = %s", addr2.Address, err) + if err := s.RemoveAddress(1, addr2.Address); err != nil { + t.Fatalf("RemoveAddress(_, %s) = %s", addr2.Address, err) + } + expectAutoGenAddrEvent(addr2, invalidatedAddr) + expectPrimaryAddr(addr1) + }) } - expectAutoGenAddrEvent(addr2, invalidatedAddr) - expectPrimaryAddr(addr1) } // TestAutoGenAddrStaticConflict tests that if SLAAC generates an address that @@ -3643,12 +3712,13 @@ func TestAutoGenAddrStaticConflict(t *testing.T) { } e := channel.New(0, 1280, linkAddr1) s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, - NDPConfigs: stack.NDPConfigurations{ - HandleRAs: true, - AutoGenGlobalAddresses: true, - }, - NDPDisp: &ndpDisp, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ + NDPConfigs: ipv6.NDPConfigurations{ + HandleRAs: true, + AutoGenGlobalAddresses: true, + }, + NDPDisp: &ndpDisp, + })}, }) if err := s.CreateNIC(1, e); err != nil { @@ -3724,18 +3794,19 @@ func TestAutoGenAddrWithOpaqueIID(t *testing.T) { } e := channel.New(0, 1280, linkAddr1) s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, - NDPConfigs: stack.NDPConfigurations{ - HandleRAs: true, - AutoGenGlobalAddresses: true, - }, - NDPDisp: &ndpDisp, - OpaqueIIDOpts: stack.OpaqueInterfaceIdentifierOptions{ - NICNameFromID: func(_ tcpip.NICID, nicName string) string { - return nicName + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ + NDPConfigs: ipv6.NDPConfigurations{ + HandleRAs: true, + AutoGenGlobalAddresses: true, }, - SecretKey: secretKey, - }, + NDPDisp: &ndpDisp, + OpaqueIIDOpts: ipv6.OpaqueInterfaceIdentifierOptions{ + NICNameFromID: func(_ tcpip.NICID, nicName string) string { + return nicName + }, + SecretKey: secretKey, + }, + })}, }) opts := stack.NICOptions{Name: nicName} if err := s.CreateNICWithOptions(nicID, e, opts); err != nil { @@ -3799,11 +3870,11 @@ func TestAutoGenAddrInResponseToDADConflicts(t *testing.T) { const lifetimeSeconds = 10 // Needed for the temporary address sub test. - savedMaxDesync := stack.MaxDesyncFactor + savedMaxDesync := ipv6.MaxDesyncFactor defer func() { - stack.MaxDesyncFactor = savedMaxDesync + ipv6.MaxDesyncFactor = savedMaxDesync }() - stack.MaxDesyncFactor = time.Nanosecond + ipv6.MaxDesyncFactor = time.Nanosecond var secretKeyBuf [header.OpaqueIIDSecretKeyMinBytes]byte secretKey := secretKeyBuf[:] @@ -3881,14 +3952,14 @@ func TestAutoGenAddrInResponseToDADConflicts(t *testing.T) { addrTypes := []struct { name string - ndpConfigs stack.NDPConfigurations + ndpConfigs ipv6.NDPConfigurations autoGenLinkLocal bool prepareFn func(t *testing.T, ndpDisp *ndpDispatcher, e *channel.Endpoint, tempIIDHistory []byte) []tcpip.AddressWithPrefix addrGenFn func(dadCounter uint8, tempIIDHistory []byte) tcpip.AddressWithPrefix }{ { name: "Global address", - ndpConfigs: stack.NDPConfigurations{ + ndpConfigs: ipv6.NDPConfigurations{ DupAddrDetectTransmits: dadTransmits, RetransmitTimer: retransmitTimer, HandleRAs: true, @@ -3906,7 +3977,7 @@ func TestAutoGenAddrInResponseToDADConflicts(t *testing.T) { }, { name: "LinkLocal address", - ndpConfigs: stack.NDPConfigurations{ + ndpConfigs: ipv6.NDPConfigurations{ DupAddrDetectTransmits: dadTransmits, RetransmitTimer: retransmitTimer, }, @@ -3920,7 +3991,7 @@ func TestAutoGenAddrInResponseToDADConflicts(t *testing.T) { }, { name: "Temporary address", - ndpConfigs: stack.NDPConfigurations{ + ndpConfigs: ipv6.NDPConfigurations{ DupAddrDetectTransmits: dadTransmits, RetransmitTimer: retransmitTimer, HandleRAs: true, @@ -3972,16 +4043,17 @@ func TestAutoGenAddrInResponseToDADConflicts(t *testing.T) { ndpConfigs := addrType.ndpConfigs ndpConfigs.AutoGenAddressConflictRetries = maxRetries s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, - AutoGenIPv6LinkLocal: addrType.autoGenLinkLocal, - NDPConfigs: ndpConfigs, - NDPDisp: &ndpDisp, - OpaqueIIDOpts: stack.OpaqueInterfaceIdentifierOptions{ - NICNameFromID: func(_ tcpip.NICID, nicName string) string { - return nicName + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ + AutoGenIPv6LinkLocal: addrType.autoGenLinkLocal, + NDPConfigs: ndpConfigs, + NDPDisp: &ndpDisp, + OpaqueIIDOpts: ipv6.OpaqueInterfaceIdentifierOptions{ + NICNameFromID: func(_ tcpip.NICID, nicName string) string { + return nicName + }, + SecretKey: secretKey, }, - SecretKey: secretKey, - }, + })}, }) opts := stack.NICOptions{Name: nicName} if err := s.CreateNICWithOptions(nicID, e, opts); err != nil { @@ -4002,9 +4074,7 @@ func TestAutoGenAddrInResponseToDADConflicts(t *testing.T) { } // Simulate a DAD conflict. - if err := s.DupTentativeAddrDetected(nicID, addr.Address); err != nil { - t.Fatalf("s.DupTentativeAddrDetected(%d, %s): %s", nicID, addr.Address, err) - } + rxNDPSolicit(e, addr.Address) expectAutoGenAddrEvent(t, &ndpDisp, addr, invalidatedAddr) expectDADEvent(t, &ndpDisp, addr.Address, false) @@ -4062,14 +4132,14 @@ func TestAutoGenAddrWithEUI64IIDNoDADRetries(t *testing.T) { addrTypes := []struct { name string - ndpConfigs stack.NDPConfigurations + ndpConfigs ipv6.NDPConfigurations autoGenLinkLocal bool subnet tcpip.Subnet triggerSLAACFn func(e *channel.Endpoint) }{ { name: "Global address", - ndpConfigs: stack.NDPConfigurations{ + ndpConfigs: ipv6.NDPConfigurations{ DupAddrDetectTransmits: dadTransmits, RetransmitTimer: retransmitTimer, HandleRAs: true, @@ -4085,7 +4155,7 @@ func TestAutoGenAddrWithEUI64IIDNoDADRetries(t *testing.T) { }, { name: "LinkLocal address", - ndpConfigs: stack.NDPConfigurations{ + ndpConfigs: ipv6.NDPConfigurations{ DupAddrDetectTransmits: dadTransmits, RetransmitTimer: retransmitTimer, AutoGenAddressConflictRetries: maxRetries, @@ -4108,10 +4178,11 @@ func TestAutoGenAddrWithEUI64IIDNoDADRetries(t *testing.T) { } e := channel.New(0, 1280, linkAddr1) s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, - AutoGenIPv6LinkLocal: addrType.autoGenLinkLocal, - NDPConfigs: addrType.ndpConfigs, - NDPDisp: &ndpDisp, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ + AutoGenIPv6LinkLocal: addrType.autoGenLinkLocal, + NDPConfigs: addrType.ndpConfigs, + NDPDisp: &ndpDisp, + })}, }) if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) @@ -4141,9 +4212,7 @@ func TestAutoGenAddrWithEUI64IIDNoDADRetries(t *testing.T) { expectAutoGenAddrEvent(addr, newAddr) // Simulate a DAD conflict. - if err := s.DupTentativeAddrDetected(nicID, addr.Address); err != nil { - t.Fatalf("s.DupTentativeAddrDetected(%d, %s): %s", nicID, addr.Address, err) - } + rxNDPSolicit(e, addr.Address) expectAutoGenAddrEvent(addr, invalidatedAddr) select { case e := <-ndpDisp.dadC: @@ -4193,21 +4262,22 @@ func TestAutoGenAddrContinuesLifetimesAfterRetry(t *testing.T) { } e := channel.New(0, 1280, linkAddr1) s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, - NDPConfigs: stack.NDPConfigurations{ - DupAddrDetectTransmits: dadTransmits, - RetransmitTimer: retransmitTimer, - HandleRAs: true, - AutoGenGlobalAddresses: true, - AutoGenAddressConflictRetries: maxRetries, - }, - NDPDisp: &ndpDisp, - OpaqueIIDOpts: stack.OpaqueInterfaceIdentifierOptions{ - NICNameFromID: func(_ tcpip.NICID, nicName string) string { - return nicName + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ + NDPConfigs: ipv6.NDPConfigurations{ + DupAddrDetectTransmits: dadTransmits, + RetransmitTimer: retransmitTimer, + HandleRAs: true, + AutoGenGlobalAddresses: true, + AutoGenAddressConflictRetries: maxRetries, }, - SecretKey: secretKey, - }, + NDPDisp: &ndpDisp, + OpaqueIIDOpts: ipv6.OpaqueInterfaceIdentifierOptions{ + NICNameFromID: func(_ tcpip.NICID, nicName string) string { + return nicName + }, + SecretKey: secretKey, + }, + })}, }) opts := stack.NICOptions{Name: nicName} if err := s.CreateNICWithOptions(nicID, e, opts); err != nil { @@ -4239,9 +4309,7 @@ func TestAutoGenAddrContinuesLifetimesAfterRetry(t *testing.T) { // Simulate a DAD conflict after some time has passed. time.Sleep(failureTimer) - if err := s.DupTentativeAddrDetected(nicID, addr.Address); err != nil { - t.Fatalf("s.DupTentativeAddrDetected(%d, %s): %s", nicID, addr.Address, err) - } + rxNDPSolicit(e, addr.Address) expectAutoGenAddrEvent(addr, invalidatedAddr) select { case e := <-ndpDisp.dadC: @@ -4402,11 +4470,12 @@ func TestNDPRecursiveDNSServerDispatch(t *testing.T) { } e := channel.New(0, 1280, linkAddr1) s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, - NDPConfigs: stack.NDPConfigurations{ - HandleRAs: true, - }, - NDPDisp: &ndpDisp, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ + NDPConfigs: ipv6.NDPConfigurations{ + HandleRAs: true, + }, + NDPDisp: &ndpDisp, + })}, }) if err := s.CreateNIC(1, e); err != nil { t.Fatalf("CreateNIC(1) = %s", err) @@ -4452,11 +4521,12 @@ func TestNDPDNSSearchListDispatch(t *testing.T) { } e := channel.New(0, 1280, linkAddr1) s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, - NDPConfigs: stack.NDPConfigurations{ - HandleRAs: true, - }, - NDPDisp: &ndpDisp, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ + NDPConfigs: ipv6.NDPConfigurations{ + HandleRAs: true, + }, + NDPDisp: &ndpDisp, + })}, }) if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) @@ -4583,7 +4653,7 @@ func TestCleanupNDPState(t *testing.T) { name: "Enable forwarding", cleanupFn: func(t *testing.T, s *stack.Stack) { t.Helper() - s.SetForwarding(true) + s.SetForwarding(ipv6.ProtocolNumber, true) }, keepAutoGenLinkLocal: true, maxAutoGenAddrEvents: 4, @@ -4637,15 +4707,16 @@ func TestCleanupNDPState(t *testing.T) { autoGenAddrC: make(chan ndpAutoGenAddrEvent, test.maxAutoGenAddrEvents), } s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, - AutoGenIPv6LinkLocal: true, - NDPConfigs: stack.NDPConfigurations{ - HandleRAs: true, - DiscoverDefaultRouters: true, - DiscoverOnLinkPrefixes: true, - AutoGenGlobalAddresses: true, - }, - NDPDisp: &ndpDisp, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ + AutoGenIPv6LinkLocal: true, + NDPConfigs: ipv6.NDPConfigurations{ + HandleRAs: true, + DiscoverDefaultRouters: true, + DiscoverOnLinkPrefixes: true, + AutoGenGlobalAddresses: true, + }, + NDPDisp: &ndpDisp, + })}, }) expectRouterEvent := func() (bool, ndpRouterEvent) { @@ -4910,18 +4981,19 @@ func TestDHCPv6ConfigurationFromNDPDA(t *testing.T) { } e := channel.New(0, 1280, linkAddr1) s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, - NDPConfigs: stack.NDPConfigurations{ - HandleRAs: true, - }, - NDPDisp: &ndpDisp, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ + NDPConfigs: ipv6.NDPConfigurations{ + HandleRAs: true, + }, + NDPDisp: &ndpDisp, + })}, }) if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) } - expectDHCPv6Event := func(configuration stack.DHCPv6ConfigurationFromNDPRA) { + expectDHCPv6Event := func(configuration ipv6.DHCPv6ConfigurationFromNDPRA) { t.Helper() select { case e := <-ndpDisp.dhcpv6ConfigurationC: @@ -4945,7 +5017,7 @@ func TestDHCPv6ConfigurationFromNDPDA(t *testing.T) { // Even if the first RA reports no DHCPv6 configurations are available, the // dispatcher should get an event. e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, false, false)) - expectDHCPv6Event(stack.DHCPv6NoConfiguration) + expectDHCPv6Event(ipv6.DHCPv6NoConfiguration) // Receiving the same update again should not result in an event to the // dispatcher. e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, false, false)) @@ -4954,19 +5026,19 @@ func TestDHCPv6ConfigurationFromNDPDA(t *testing.T) { // Receive an RA that updates the DHCPv6 configuration to Other // Configurations. e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, false, true)) - expectDHCPv6Event(stack.DHCPv6OtherConfigurations) + expectDHCPv6Event(ipv6.DHCPv6OtherConfigurations) e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, false, true)) expectNoDHCPv6Event() // Receive an RA that updates the DHCPv6 configuration to Managed Address. e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, true, false)) - expectDHCPv6Event(stack.DHCPv6ManagedAddress) + expectDHCPv6Event(ipv6.DHCPv6ManagedAddress) e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, true, false)) expectNoDHCPv6Event() // Receive an RA that updates the DHCPv6 configuration to none. e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, false, false)) - expectDHCPv6Event(stack.DHCPv6NoConfiguration) + expectDHCPv6Event(ipv6.DHCPv6NoConfiguration) e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, false, false)) expectNoDHCPv6Event() @@ -4974,7 +5046,7 @@ func TestDHCPv6ConfigurationFromNDPDA(t *testing.T) { // // Note, when the M flag is set, the O flag is redundant. e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, true, true)) - expectDHCPv6Event(stack.DHCPv6ManagedAddress) + expectDHCPv6Event(ipv6.DHCPv6ManagedAddress) e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, true, true)) expectNoDHCPv6Event() // Even though the DHCPv6 flags are different, the effective configuration is @@ -4987,7 +5059,7 @@ func TestDHCPv6ConfigurationFromNDPDA(t *testing.T) { // Receive an RA that updates the DHCPv6 configuration to Other // Configurations. e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, false, true)) - expectDHCPv6Event(stack.DHCPv6OtherConfigurations) + expectDHCPv6Event(ipv6.DHCPv6OtherConfigurations) e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, false, true)) expectNoDHCPv6Event() @@ -5002,7 +5074,7 @@ func TestDHCPv6ConfigurationFromNDPDA(t *testing.T) { // Receive an RA that updates the DHCPv6 configuration to Other // Configurations. e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, false, true)) - expectDHCPv6Event(stack.DHCPv6OtherConfigurations) + expectDHCPv6Event(ipv6.DHCPv6OtherConfigurations) e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, false, true)) expectNoDHCPv6Event() } @@ -5140,16 +5212,15 @@ func TestRouterSolicitation(t *testing.T) { t.Errorf("got remote link address = %s, want = %s", p.Route.RemoteLinkAddress, want) } - checker.IPv6(t, - p.Pkt.Header.View(), + checker.IPv6(t, stack.PayloadSince(p.Pkt.NetworkHeader()), checker.SrcAddr(test.expectedSrcAddr), checker.DstAddr(header.IPv6AllRoutersMulticastAddress), checker.TTL(header.NDPHopLimit), checker.NDPRS(checker.NDPRSOptions(test.expectedNDPOpts)), ) - if l, want := p.Pkt.Header.AvailableLength(), int(test.linkHeaderLen); l != want { - t.Errorf("got p.Pkt.Header.AvailableLength() = %d; want = %d", l, want) + if l, want := p.Pkt.AvailableHeaderBytes(), int(test.linkHeaderLen); l != want { + t.Errorf("got p.Pkt.AvailableHeaderBytes() = %d; want = %d", l, want) } } waitForNothing := func(timeout time.Duration) { @@ -5161,12 +5232,13 @@ func TestRouterSolicitation(t *testing.T) { } } s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, - NDPConfigs: stack.NDPConfigurations{ - MaxRtrSolicitations: test.maxRtrSolicit, - RtrSolicitationInterval: test.rtrSolicitInt, - MaxRtrSolicitationDelay: test.maxRtrSolicitDelay, - }, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ + NDPConfigs: ipv6.NDPConfigurations{ + MaxRtrSolicitations: test.maxRtrSolicit, + RtrSolicitationInterval: test.rtrSolicitInt, + MaxRtrSolicitationDelay: test.maxRtrSolicitDelay, + }, + })}, }) if err := s.CreateNIC(nicID, &e); err != nil { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) @@ -5230,11 +5302,11 @@ func TestStopStartSolicitingRouters(t *testing.T) { name: "Enable and disable forwarding", startFn: func(t *testing.T, s *stack.Stack) { t.Helper() - s.SetForwarding(false) + s.SetForwarding(ipv6.ProtocolNumber, false) }, stopFn: func(t *testing.T, s *stack.Stack, _ bool) { t.Helper() - s.SetForwarding(true) + s.SetForwarding(ipv6.ProtocolNumber, true) }, }, @@ -5294,19 +5366,20 @@ func TestStopStartSolicitingRouters(t *testing.T) { if p.Proto != header.IPv6ProtocolNumber { t.Fatalf("got Proto = %d, want = %d", p.Proto, header.IPv6ProtocolNumber) } - checker.IPv6(t, p.Pkt.Header.View(), + checker.IPv6(t, stack.PayloadSince(p.Pkt.NetworkHeader()), checker.SrcAddr(header.IPv6Any), checker.DstAddr(header.IPv6AllRoutersMulticastAddress), checker.TTL(header.NDPHopLimit), checker.NDPRS()) } s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, - NDPConfigs: stack.NDPConfigurations{ - MaxRtrSolicitations: maxRtrSolicitations, - RtrSolicitationInterval: interval, - MaxRtrSolicitationDelay: delay, - }, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ + NDPConfigs: ipv6.NDPConfigurations{ + MaxRtrSolicitations: maxRtrSolicitations, + RtrSolicitationInterval: interval, + MaxRtrSolicitationDelay: delay, + }, + })}, }) if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) diff --git a/pkg/tcpip/stack/neighbor_cache.go b/pkg/tcpip/stack/neighbor_cache.go index 1d37716c2..27e1feec0 100644 --- a/pkg/tcpip/stack/neighbor_cache.go +++ b/pkg/tcpip/stack/neighbor_cache.go @@ -115,17 +115,15 @@ func (n *neighborCache) getOrCreateEntry(remoteAddr, localAddr tcpip.Address, li // channel is returned for the top level caller to block. Channel is closed // once address resolution is complete (success or not). func (n *neighborCache) entry(remoteAddr, localAddr tcpip.Address, linkRes LinkAddressResolver, w *sleep.Waker) (NeighborEntry, <-chan struct{}, *tcpip.Error) { - if linkRes != nil { - if linkAddr, ok := linkRes.ResolveStaticAddress(remoteAddr); ok { - e := NeighborEntry{ - Addr: remoteAddr, - LocalAddr: localAddr, - LinkAddr: linkAddr, - State: Static, - UpdatedAt: time.Now(), - } - return e, nil, nil + if linkAddr, ok := linkRes.ResolveStaticAddress(remoteAddr); ok { + e := NeighborEntry{ + Addr: remoteAddr, + LocalAddr: localAddr, + LinkAddr: linkAddr, + State: Static, + UpdatedAt: time.Now(), } + return e, nil, nil } entry := n.getOrCreateEntry(remoteAddr, localAddr, linkRes) @@ -289,8 +287,8 @@ func (n *neighborCache) setConfig(config NUDConfigurations) { // HandleProbe implements NUDHandler.HandleProbe by following the logic defined // in RFC 4861 section 7.2.3. Validation of the probe is expected to be handled // by the caller. -func (n *neighborCache) HandleProbe(remoteAddr, localAddr tcpip.Address, protocol tcpip.NetworkProtocolNumber, remoteLinkAddr tcpip.LinkAddress) { - entry := n.getOrCreateEntry(remoteAddr, localAddr, nil) +func (n *neighborCache) HandleProbe(remoteAddr, localAddr tcpip.Address, protocol tcpip.NetworkProtocolNumber, remoteLinkAddr tcpip.LinkAddress, linkRes LinkAddressResolver) { + entry := n.getOrCreateEntry(remoteAddr, localAddr, linkRes) entry.mu.Lock() entry.handleProbeLocked(remoteLinkAddr) entry.mu.Unlock() diff --git a/pkg/tcpip/stack/neighbor_cache_test.go b/pkg/tcpip/stack/neighbor_cache_test.go index 4cb2c9c6b..a0b7da5cd 100644 --- a/pkg/tcpip/stack/neighbor_cache_test.go +++ b/pkg/tcpip/stack/neighbor_cache_test.go @@ -30,6 +30,7 @@ import ( "github.com/google/go-cmp/cmp/cmpopts" "gvisor.dev/gvisor/pkg/sleep" "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/faketime" ) const ( @@ -239,7 +240,7 @@ type entryEvent struct { func TestNeighborCacheGetConfig(t *testing.T) { nudDisp := testNUDDispatcher{} c := DefaultNUDConfigurations() - clock := newFakeClock() + clock := faketime.NewManualClock() neigh := newTestNeighborCache(&nudDisp, c, clock) if got, want := neigh.config(), c; got != want { @@ -257,7 +258,7 @@ func TestNeighborCacheGetConfig(t *testing.T) { func TestNeighborCacheSetConfig(t *testing.T) { nudDisp := testNUDDispatcher{} c := DefaultNUDConfigurations() - clock := newFakeClock() + clock := faketime.NewManualClock() neigh := newTestNeighborCache(&nudDisp, c, clock) c.MinRandomFactor = 1 @@ -279,7 +280,7 @@ func TestNeighborCacheSetConfig(t *testing.T) { func TestNeighborCacheEntry(t *testing.T) { c := DefaultNUDConfigurations() nudDisp := testNUDDispatcher{} - clock := newFakeClock() + clock := faketime.NewManualClock() neigh := newTestNeighborCache(&nudDisp, c, clock) store := newTestEntryStore() linkRes := &testNeighborResolver{ @@ -298,7 +299,7 @@ func TestNeighborCacheEntry(t *testing.T) { t.Errorf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock) } - clock.advance(typicalLatency) + clock.Advance(typicalLatency) wantEvents := []testEntryEventInfo{ { @@ -335,37 +336,11 @@ func TestNeighborCacheEntry(t *testing.T) { } } -// TestNeighborCacheEntryNoLinkAddress verifies calling entry() without a -// LinkAddressResolver returns ErrNoLinkAddress. -func TestNeighborCacheEntryNoLinkAddress(t *testing.T) { - nudDisp := testNUDDispatcher{} - c := DefaultNUDConfigurations() - clock := newFakeClock() - neigh := newTestNeighborCache(&nudDisp, c, clock) - store := newTestEntryStore() - - entry, ok := store.entry(0) - if !ok { - t.Fatalf("store.entry(0) not found") - } - _, _, err := neigh.entry(entry.Addr, entry.LocalAddr, nil, nil) - if err != tcpip.ErrNoLinkAddress { - t.Errorf("got neigh.entry(%s, %s, nil, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrNoLinkAddress) - } - - // No events should have been dispatched. - nudDisp.mu.Lock() - defer nudDisp.mu.Unlock() - if diff := cmp.Diff(nudDisp.events, []testEntryEventInfo(nil)); diff != "" { - t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) - } -} - func TestNeighborCacheRemoveEntry(t *testing.T) { config := DefaultNUDConfigurations() nudDisp := testNUDDispatcher{} - clock := newFakeClock() + clock := faketime.NewManualClock() neigh := newTestNeighborCache(&nudDisp, config, clock) store := newTestEntryStore() linkRes := &testNeighborResolver{ @@ -384,7 +359,7 @@ func TestNeighborCacheRemoveEntry(t *testing.T) { t.Errorf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock) } - clock.advance(typicalLatency) + clock.Advance(typicalLatency) wantEvents := []testEntryEventInfo{ { @@ -435,7 +410,7 @@ func TestNeighborCacheRemoveEntry(t *testing.T) { } type testContext struct { - clock *fakeClock + clock *faketime.ManualClock neigh *neighborCache store *testEntryStore linkRes *testNeighborResolver @@ -444,7 +419,7 @@ type testContext struct { func newTestContext(c NUDConfigurations) testContext { nudDisp := &testNUDDispatcher{} - clock := newFakeClock() + clock := faketime.NewManualClock() neigh := newTestNeighborCache(nudDisp, c, clock) store := newTestEntryStore() linkRes := &testNeighborResolver{ @@ -480,7 +455,7 @@ func (c *testContext) overflowCache(opts overflowOptions) error { if _, _, err := c.neigh.entry(entry.Addr, entry.LocalAddr, c.linkRes, nil); err != tcpip.ErrWouldBlock { return fmt.Errorf("got c.neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock) } - c.clock.advance(c.neigh.config().RetransmitTimer) + c.clock.Advance(c.neigh.config().RetransmitTimer) var wantEvents []testEntryEventInfo @@ -593,7 +568,7 @@ func TestNeighborCacheRemoveEntryThenOverflow(t *testing.T) { if err != tcpip.ErrWouldBlock { t.Errorf("got c.neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock) } - c.clock.advance(c.neigh.config().RetransmitTimer) + c.clock.Advance(c.neigh.config().RetransmitTimer) wantEvents := []testEntryEventInfo{ { EventType: entryTestAdded, @@ -829,7 +804,7 @@ func TestNeighborCacheOverwriteWithStaticEntryThenOverflow(t *testing.T) { if err != tcpip.ErrWouldBlock { t.Errorf("got c.neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock) } - c.clock.advance(typicalLatency) + c.clock.Advance(typicalLatency) wantEvents := []testEntryEventInfo{ { EventType: entryTestAdded, @@ -902,7 +877,7 @@ func TestNeighborCacheNotifiesWaker(t *testing.T) { config := DefaultNUDConfigurations() nudDisp := testNUDDispatcher{} - clock := newFakeClock() + clock := faketime.NewManualClock() neigh := newTestNeighborCache(&nudDisp, config, clock) store := newTestEntryStore() linkRes := &testNeighborResolver{ @@ -928,7 +903,7 @@ func TestNeighborCacheNotifiesWaker(t *testing.T) { if doneCh == nil { t.Fatalf("expected done channel from neigh.entry(%s, %s, _, _)", entry.Addr, entry.LocalAddr) } - clock.advance(typicalLatency) + clock.Advance(typicalLatency) select { case <-doneCh: @@ -970,7 +945,7 @@ func TestNeighborCacheRemoveWaker(t *testing.T) { config := DefaultNUDConfigurations() nudDisp := testNUDDispatcher{} - clock := newFakeClock() + clock := faketime.NewManualClock() neigh := newTestNeighborCache(&nudDisp, config, clock) store := newTestEntryStore() linkRes := &testNeighborResolver{ @@ -1000,7 +975,7 @@ func TestNeighborCacheRemoveWaker(t *testing.T) { // Remove the waker before the neighbor cache has the opportunity to send a // notification. neigh.removeWaker(entry.Addr, &w) - clock.advance(typicalLatency) + clock.Advance(typicalLatency) select { case <-doneCh: @@ -1048,9 +1023,9 @@ func TestNeighborCacheAddStaticEntryThenOverflow(t *testing.T) { t.Fatalf("c.store.entry(0) not found") } c.neigh.addStaticEntry(entry.Addr, entry.LinkAddr) - e, _, err := c.neigh.entry(entry.Addr, "", nil, nil) + e, _, err := c.neigh.entry(entry.Addr, "", c.linkRes, nil) if err != nil { - t.Errorf("unexpected error from c.neigh.entry(%s, \"\", nil nil): %s", entry.Addr, err) + t.Errorf("unexpected error from c.neigh.entry(%s, \"\", _, nil): %s", entry.Addr, err) } want := NeighborEntry{ Addr: entry.Addr, @@ -1059,7 +1034,7 @@ func TestNeighborCacheAddStaticEntryThenOverflow(t *testing.T) { State: Static, } if diff := cmp.Diff(e, want, entryDiffOpts()...); diff != "" { - t.Errorf("c.neigh.entry(%s, \"\", nil, nil) mismatch (-got, +want):\n%s", entry.Addr, diff) + t.Errorf("c.neigh.entry(%s, \"\", _, nil) mismatch (-got, +want):\n%s", entry.Addr, diff) } wantEvents := []testEntryEventInfo{ @@ -1099,7 +1074,7 @@ func TestNeighborCacheClear(t *testing.T) { config := DefaultNUDConfigurations() nudDisp := testNUDDispatcher{} - clock := newFakeClock() + clock := faketime.NewManualClock() neigh := newTestNeighborCache(&nudDisp, config, clock) store := newTestEntryStore() linkRes := &testNeighborResolver{ @@ -1118,7 +1093,7 @@ func TestNeighborCacheClear(t *testing.T) { if err != tcpip.ErrWouldBlock { t.Errorf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock) } - clock.advance(typicalLatency) + clock.Advance(typicalLatency) wantEvents := []testEntryEventInfo{ { @@ -1214,7 +1189,7 @@ func TestNeighborCacheClearThenOverflow(t *testing.T) { if err != tcpip.ErrWouldBlock { t.Errorf("got c.neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock) } - c.clock.advance(typicalLatency) + c.clock.Advance(typicalLatency) wantEvents := []testEntryEventInfo{ { EventType: entryTestAdded, @@ -1275,7 +1250,7 @@ func TestNeighborCacheKeepFrequentlyUsed(t *testing.T) { config.MaxRandomFactor = 1 nudDisp := testNUDDispatcher{} - clock := newFakeClock() + clock := faketime.NewManualClock() neigh := newTestNeighborCache(&nudDisp, config, clock) store := newTestEntryStore() linkRes := &testNeighborResolver{ @@ -1303,7 +1278,7 @@ func TestNeighborCacheKeepFrequentlyUsed(t *testing.T) { if err != tcpip.ErrWouldBlock { t.Errorf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock) } - clock.advance(typicalLatency) + clock.Advance(typicalLatency) select { case <-doneCh: default: @@ -1351,7 +1326,7 @@ func TestNeighborCacheKeepFrequentlyUsed(t *testing.T) { if err != tcpip.ErrWouldBlock { t.Errorf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock) } - clock.advance(typicalLatency) + clock.Advance(typicalLatency) select { case <-doneCh: default: @@ -1438,7 +1413,7 @@ func TestNeighborCacheConcurrent(t *testing.T) { config := DefaultNUDConfigurations() nudDisp := testNUDDispatcher{} - clock := newFakeClock() + clock := faketime.NewManualClock() neigh := newTestNeighborCache(&nudDisp, config, clock) store := newTestEntryStore() linkRes := &testNeighborResolver{ @@ -1466,7 +1441,7 @@ func TestNeighborCacheConcurrent(t *testing.T) { wg.Wait() // Process all the requests for a single entry concurrently - clock.advance(typicalLatency) + clock.Advance(typicalLatency) } // All goroutines add in the same order and add more values than can fit in @@ -1498,7 +1473,7 @@ func TestNeighborCacheReplace(t *testing.T) { config := DefaultNUDConfigurations() nudDisp := testNUDDispatcher{} - clock := newFakeClock() + clock := faketime.NewManualClock() neigh := newTestNeighborCache(&nudDisp, config, clock) store := newTestEntryStore() linkRes := &testNeighborResolver{ @@ -1517,7 +1492,7 @@ func TestNeighborCacheReplace(t *testing.T) { if err != tcpip.ErrWouldBlock { t.Fatalf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock) } - clock.advance(typicalLatency) + clock.Advance(typicalLatency) select { case <-doneCh: default: @@ -1567,7 +1542,7 @@ func TestNeighborCacheReplace(t *testing.T) { if err != tcpip.ErrWouldBlock { t.Fatalf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock) } - clock.advance(config.DelayFirstProbeTime + typicalLatency) + clock.Advance(config.DelayFirstProbeTime + typicalLatency) select { case <-doneCh: default: @@ -1578,7 +1553,7 @@ func TestNeighborCacheReplace(t *testing.T) { // Verify the entry's new link address { e, _, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil) - clock.advance(typicalLatency) + clock.Advance(typicalLatency) if err != nil { t.Errorf("unexpected error from neigh.entry(%s, %s, _, nil): %s", entry.Addr, entry.LocalAddr, err) } @@ -1598,7 +1573,7 @@ func TestNeighborCacheResolutionFailed(t *testing.T) { config := DefaultNUDConfigurations() nudDisp := testNUDDispatcher{} - clock := newFakeClock() + clock := faketime.NewManualClock() neigh := newTestNeighborCache(&nudDisp, config, clock) store := newTestEntryStore() @@ -1621,7 +1596,7 @@ func TestNeighborCacheResolutionFailed(t *testing.T) { if _, _, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil); err != tcpip.ErrWouldBlock { t.Fatalf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock) } - clock.advance(typicalLatency) + clock.Advance(typicalLatency) got, _, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil) if err != nil { t.Fatalf("unexpected error from neigh.entry(%s, %s, _, nil): %s", entry.Addr, entry.LocalAddr, err) @@ -1644,7 +1619,7 @@ func TestNeighborCacheResolutionFailed(t *testing.T) { t.Fatalf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock) } waitFor := config.DelayFirstProbeTime + typicalLatency*time.Duration(config.MaxMulticastProbes) - clock.advance(waitFor) + clock.Advance(waitFor) if _, _, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil); err != tcpip.ErrNoLinkAddress { t.Fatalf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrNoLinkAddress) } @@ -1662,7 +1637,7 @@ func TestNeighborCacheResolutionTimeout(t *testing.T) { config := DefaultNUDConfigurations() config.RetransmitTimer = time.Millisecond // small enough to cause timeout - clock := newFakeClock() + clock := faketime.NewManualClock() neigh := newTestNeighborCache(nil, config, clock) store := newTestEntryStore() linkRes := &testNeighborResolver{ @@ -1680,7 +1655,7 @@ func TestNeighborCacheResolutionTimeout(t *testing.T) { t.Fatalf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock) } waitFor := config.RetransmitTimer * time.Duration(config.MaxMulticastProbes) - clock.advance(waitFor) + clock.Advance(waitFor) if _, _, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil); err != tcpip.ErrNoLinkAddress { t.Fatalf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrNoLinkAddress) } @@ -1690,7 +1665,7 @@ func TestNeighborCacheResolutionTimeout(t *testing.T) { // resolved immediately and don't send resolution requests. func TestNeighborCacheStaticResolution(t *testing.T) { config := DefaultNUDConfigurations() - clock := newFakeClock() + clock := faketime.NewManualClock() neigh := newTestNeighborCache(nil, config, clock) store := newTestEntryStore() linkRes := &testNeighborResolver{ diff --git a/pkg/tcpip/stack/neighbor_entry.go b/pkg/tcpip/stack/neighbor_entry.go index 0068cacb8..9a72bec79 100644 --- a/pkg/tcpip/stack/neighbor_entry.go +++ b/pkg/tcpip/stack/neighbor_entry.go @@ -21,6 +21,7 @@ import ( "gvisor.dev/gvisor/pkg/sleep" "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/header" ) // NeighborEntry describes a neighboring device in the local network. @@ -73,8 +74,7 @@ const ( type neighborEntry struct { neighborEntryEntry - nic *NIC - protocol tcpip.NetworkProtocolNumber + nic *NIC // linkRes provides the functionality to send reachability probes, used in // Neighbor Unreachability Detection. @@ -440,7 +440,7 @@ func (e *neighborEntry) handleConfirmationLocked(linkAddr tcpip.LinkAddress, fla e.notifyWakersLocked() } - if e.isRouter && !flags.IsRouter { + if e.isRouter && !flags.IsRouter && header.IsV6UnicastAddress(e.neigh.Addr) { // "In those cases where the IsRouter flag changes from TRUE to FALSE as // a result of this update, the node MUST remove that router from the // Default Router List and update the Destination Cache entries for all @@ -448,9 +448,17 @@ func (e *neighborEntry) handleConfirmationLocked(linkAddr tcpip.LinkAddress, fla // 7.3.3. This is needed to detect when a node that is used as a router // stops forwarding packets due to being configured as a host." // - RFC 4861 section 7.2.5 - e.nic.mu.Lock() - e.nic.mu.ndp.invalidateDefaultRouter(e.neigh.Addr) - e.nic.mu.Unlock() + // + // TODO(gvisor.dev/issue/4085): Remove the special casing we do for IPv6 + // here. + ep, ok := e.nic.networkEndpoints[header.IPv6ProtocolNumber] + if !ok { + panic(fmt.Sprintf("have a neighbor entry for an IPv6 router but no IPv6 network endpoint")) + } + + if ndpEP, ok := ep.(NDPEndpoint); ok { + ndpEP.InvalidateDefaultRouter(e.neigh.Addr) + } } e.isRouter = flags.IsRouter diff --git a/pkg/tcpip/stack/neighbor_entry_test.go b/pkg/tcpip/stack/neighbor_entry_test.go index 08c9ccd25..a265fff0a 100644 --- a/pkg/tcpip/stack/neighbor_entry_test.go +++ b/pkg/tcpip/stack/neighbor_entry_test.go @@ -27,6 +27,8 @@ import ( "github.com/google/go-cmp/cmp/cmpopts" "gvisor.dev/gvisor/pkg/sleep" "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/faketime" + "gvisor.dev/gvisor/pkg/tcpip/header" ) const ( @@ -221,8 +223,8 @@ func (r *entryTestLinkResolver) LinkAddressProtocol() tcpip.NetworkProtocolNumbe return entryTestNetNumber } -func entryTestSetup(c NUDConfigurations) (*neighborEntry, *testNUDDispatcher, *entryTestLinkResolver, *fakeClock) { - clock := newFakeClock() +func entryTestSetup(c NUDConfigurations) (*neighborEntry, *testNUDDispatcher, *entryTestLinkResolver, *faketime.ManualClock) { + clock := faketime.NewManualClock() disp := testNUDDispatcher{} nic := NIC{ id: entryTestNICID, @@ -232,17 +234,14 @@ func entryTestSetup(c NUDConfigurations) (*neighborEntry, *testNUDDispatcher, *e nudDisp: &disp, }, } + nic.networkEndpoints = map[tcpip.NetworkProtocolNumber]NetworkEndpoint{ + header.IPv6ProtocolNumber: (&testIPv6Protocol{}).NewEndpoint(&nic, nil, nil, nil), + } rng := rand.New(rand.NewSource(time.Now().UnixNano())) nudState := NewNUDState(c, rng) linkRes := entryTestLinkResolver{} - entry := newNeighborEntry(&nic, entryTestAddr1, entryTestAddr2, nudState, &linkRes) - - // Stub out ndpState to verify modification of default routers. - nic.mu.ndp = ndpState{ - nic: &nic, - defaultRouters: make(map[tcpip.Address]defaultRouterState), - } + entry := newNeighborEntry(&nic, entryTestAddr1 /* remoteAddr */, entryTestAddr2 /* localAddr */, nudState, &linkRes) // Stub out the neighbor cache to verify deletion from the cache. nic.neigh = &neighborCache{ @@ -267,7 +266,7 @@ func TestEntryInitiallyUnknown(t *testing.T) { } e.mu.Unlock() - clock.advance(c.RetransmitTimer) + clock.Advance(c.RetransmitTimer) // No probes should have been sent. linkRes.mu.Lock() @@ -300,7 +299,7 @@ func TestEntryUnknownToUnknownWhenConfirmationWithUnknownAddress(t *testing.T) { } e.mu.Unlock() - clock.advance(time.Hour) + clock.Advance(time.Hour) // No probes should have been sent. linkRes.mu.Lock() @@ -410,7 +409,7 @@ func TestEntryIncompleteToIncompleteDoesNotChangeUpdatedAt(t *testing.T) { updatedAt := e.neigh.UpdatedAt e.mu.Unlock() - clock.advance(c.RetransmitTimer) + clock.Advance(c.RetransmitTimer) // UpdatedAt should remain the same during address resolution. wantProbes := []entryTestProbeInfo{ @@ -439,7 +438,7 @@ func TestEntryIncompleteToIncompleteDoesNotChangeUpdatedAt(t *testing.T) { } e.mu.Unlock() - clock.advance(c.RetransmitTimer) + clock.Advance(c.RetransmitTimer) // UpdatedAt should change after failing address resolution. Timing out after // sending the last probe transitions the entry to Failed. @@ -459,7 +458,7 @@ func TestEntryIncompleteToIncompleteDoesNotChangeUpdatedAt(t *testing.T) { } } - clock.advance(c.RetransmitTimer) + clock.Advance(c.RetransmitTimer) wantEvents := []testEntryEventInfo{ { @@ -748,7 +747,7 @@ func TestEntryIncompleteToFailed(t *testing.T) { e.mu.Unlock() waitFor := c.RetransmitTimer * time.Duration(c.MaxMulticastProbes) - clock.advance(waitFor) + clock.Advance(waitFor) wantProbes := []entryTestProbeInfo{ // The Incomplete-to-Incomplete state transition is tested here by @@ -816,6 +815,8 @@ func TestEntryStaysReachableWhenConfirmationWithRouterFlag(t *testing.T) { c := DefaultNUDConfigurations() e, nudDisp, linkRes, _ := entryTestSetup(c) + ipv6EP := e.nic.networkEndpoints[header.IPv6ProtocolNumber].(*testIPv6Endpoint) + e.mu.Lock() e.handlePacketQueuedLocked() e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ @@ -829,9 +830,7 @@ func TestEntryStaysReachableWhenConfirmationWithRouterFlag(t *testing.T) { if got, want := e.isRouter, true; got != want { t.Errorf("got e.isRouter = %t, want = %t", got, want) } - e.nic.mu.ndp.defaultRouters[entryTestAddr1] = defaultRouterState{ - invalidationJob: e.nic.stack.newJob(&testLocker{}, func() {}), - } + e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ Solicited: false, Override: false, @@ -840,8 +839,8 @@ func TestEntryStaysReachableWhenConfirmationWithRouterFlag(t *testing.T) { if got, want := e.isRouter, false; got != want { t.Errorf("got e.isRouter = %t, want = %t", got, want) } - if _, ok := e.nic.mu.ndp.defaultRouters[entryTestAddr1]; ok { - t.Errorf("unexpected defaultRouter for %s", entryTestAddr1) + if ipv6EP.invalidatedRtr != e.neigh.Addr { + t.Errorf("got ipv6EP.invalidatedRtr = %s, want = %s", ipv6EP.invalidatedRtr, e.neigh.Addr) } e.mu.Unlock() @@ -983,7 +982,7 @@ func TestEntryReachableToStaleWhenTimeout(t *testing.T) { t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) } - clock.advance(c.BaseReachableTime) + clock.Advance(c.BaseReachableTime) wantEvents := []testEntryEventInfo{ { @@ -1612,7 +1611,7 @@ func TestEntryDelayToReachableWhenUpperLevelConfirmation(t *testing.T) { t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) } - clock.advance(c.BaseReachableTime) + clock.Advance(c.BaseReachableTime) wantEvents := []testEntryEventInfo{ { @@ -1706,7 +1705,7 @@ func TestEntryDelayToReachableWhenSolicitedOverrideConfirmation(t *testing.T) { t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) } - clock.advance(c.BaseReachableTime) + clock.Advance(c.BaseReachableTime) wantEvents := []testEntryEventInfo{ { @@ -1989,7 +1988,7 @@ func TestEntryDelayToProbe(t *testing.T) { } e.mu.Unlock() - clock.advance(c.DelayFirstProbeTime) + clock.Advance(c.DelayFirstProbeTime) wantProbes := []entryTestProbeInfo{ // The first probe is caused by the Unknown-to-Incomplete transition. @@ -2069,7 +2068,7 @@ func TestEntryProbeToStaleWhenProbeWithDifferentAddress(t *testing.T) { e.handlePacketQueuedLocked() e.mu.Unlock() - clock.advance(c.DelayFirstProbeTime) + clock.Advance(c.DelayFirstProbeTime) wantProbes := []entryTestProbeInfo{ // The first probe is caused by the Unknown-to-Incomplete transition. @@ -2166,7 +2165,7 @@ func TestEntryProbeToStaleWhenConfirmationWithDifferentAddress(t *testing.T) { e.handlePacketQueuedLocked() e.mu.Unlock() - clock.advance(c.DelayFirstProbeTime) + clock.Advance(c.DelayFirstProbeTime) wantProbes := []entryTestProbeInfo{ // The first probe is caused by the Unknown-to-Incomplete transition. @@ -2267,7 +2266,7 @@ func TestEntryStaysProbeWhenOverrideConfirmationWithSameAddress(t *testing.T) { e.handlePacketQueuedLocked() e.mu.Unlock() - clock.advance(c.DelayFirstProbeTime) + clock.Advance(c.DelayFirstProbeTime) wantProbes := []entryTestProbeInfo{ // The first probe is caused by the Unknown-to-Incomplete transition. @@ -2344,6 +2343,106 @@ func TestEntryStaysProbeWhenOverrideConfirmationWithSameAddress(t *testing.T) { nudDisp.mu.Unlock() } +// TestEntryUnknownToStaleToProbeToReachable exercises the following scenario: +// 1. Probe is received +// 2. Entry is created in Stale +// 3. Packet is queued on the entry +// 4. Entry transitions to Delay then Probe +// 5. Probe is sent +func TestEntryUnknownToStaleToProbeToReachable(t *testing.T) { + c := DefaultNUDConfigurations() + // Eliminate random factors from ReachableTime computation so the transition + // from Probe to Reachable will only take BaseReachableTime duration. + c.MinRandomFactor = 1 + c.MaxRandomFactor = 1 + + e, nudDisp, linkRes, clock := entryTestSetup(c) + + e.mu.Lock() + e.handleProbeLocked(entryTestLinkAddr1) + e.handlePacketQueuedLocked() + e.mu.Unlock() + + clock.Advance(c.DelayFirstProbeTime) + + wantProbes := []entryTestProbeInfo{ + // Probe caused by the Delay-to-Probe transition + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: entryTestLinkAddr1, + LocalAddress: entryTestAddr2, + }, + } + linkRes.mu.Lock() + diff := cmp.Diff(linkRes.probes, wantProbes) + linkRes.mu.Unlock() + if diff != "" { + t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + } + + e.mu.Lock() + if got, want := e.neigh.State, Probe; got != want { + t.Errorf("got e.neigh.State = %q, want = %q", got, want) + } + e.handleConfirmationLocked(entryTestLinkAddr2, ReachabilityConfirmationFlags{ + Solicited: true, + Override: true, + IsRouter: false, + }) + if got, want := e.neigh.State, Reachable; got != want { + t.Errorf("got e.neigh.State = %q, want = %q", got, want) + } + if got, want := e.neigh.LinkAddr, entryTestLinkAddr2; got != want { + t.Errorf("got e.neigh.LinkAddr = %q, want = %q", got, want) + } + e.mu.Unlock() + + clock.Advance(c.BaseReachableTime) + + wantEvents := []testEntryEventInfo{ + { + EventType: entryTestAdded, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Delay, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Probe, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr2, + State: Reachable, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr2, + State: Stale, + }, + } + nudDisp.mu.Lock() + if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + } + nudDisp.mu.Unlock() +} + func TestEntryProbeToReachableWhenSolicitedOverrideConfirmation(t *testing.T) { c := DefaultNUDConfigurations() // Eliminate random factors from ReachableTime computation so the transition @@ -2363,7 +2462,7 @@ func TestEntryProbeToReachableWhenSolicitedOverrideConfirmation(t *testing.T) { e.handlePacketQueuedLocked() e.mu.Unlock() - clock.advance(c.DelayFirstProbeTime) + clock.Advance(c.DelayFirstProbeTime) wantProbes := []entryTestProbeInfo{ // The first probe is caused by the Unknown-to-Incomplete transition. @@ -2403,7 +2502,7 @@ func TestEntryProbeToReachableWhenSolicitedOverrideConfirmation(t *testing.T) { } e.mu.Unlock() - clock.advance(c.BaseReachableTime) + clock.Advance(c.BaseReachableTime) wantEvents := []testEntryEventInfo{ { @@ -2475,7 +2574,7 @@ func TestEntryProbeToReachableWhenSolicitedConfirmationWithSameAddress(t *testin e.handlePacketQueuedLocked() e.mu.Unlock() - clock.advance(c.DelayFirstProbeTime) + clock.Advance(c.DelayFirstProbeTime) wantProbes := []entryTestProbeInfo{ // The first probe is caused by the Unknown-to-Incomplete transition. @@ -2512,7 +2611,7 @@ func TestEntryProbeToReachableWhenSolicitedConfirmationWithSameAddress(t *testin } e.mu.Unlock() - clock.advance(c.BaseReachableTime) + clock.Advance(c.BaseReachableTime) wantEvents := []testEntryEventInfo{ { @@ -2582,7 +2681,7 @@ func TestEntryProbeToFailed(t *testing.T) { e.mu.Unlock() waitFor := c.DelayFirstProbeTime + c.RetransmitTimer*time.Duration(c.MaxUnicastProbes) - clock.advance(waitFor) + clock.Advance(waitFor) wantProbes := []entryTestProbeInfo{ // The first probe is caused by the Unknown-to-Incomplete transition. @@ -2687,7 +2786,7 @@ func TestEntryFailedGetsDeleted(t *testing.T) { e.mu.Unlock() waitFor := c.DelayFirstProbeTime + c.RetransmitTimer*time.Duration(c.MaxUnicastProbes) + c.UnreachableTime - clock.advance(waitFor) + clock.Advance(waitFor) wantProbes := []entryTestProbeInfo{ // The first probe is caused by the Unknown-to-Incomplete transition. diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go index f21066fce..06824843a 100644 --- a/pkg/tcpip/stack/nic.go +++ b/pkg/tcpip/stack/nic.go @@ -18,23 +18,16 @@ import ( "fmt" "math/rand" "reflect" - "sort" - "strings" "sync/atomic" + "gvisor.dev/gvisor/pkg/sleep" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" ) -var ipv4BroadcastAddr = tcpip.ProtocolAddress{ - Protocol: header.IPv4ProtocolNumber, - AddressWithPrefix: tcpip.AddressWithPrefix{ - Address: header.IPv4Broadcast, - PrefixLen: 8 * header.IPv4AddressSize, - }, -} +var _ NetworkInterface = (*NIC)(nil) // NIC represents a "network interface card" to which the networking stack is // attached. @@ -45,22 +38,22 @@ type NIC struct { linkEP LinkEndpoint context NICContext - stats NICStats - neigh *neighborCache + stats NICStats + neigh *neighborCache + networkEndpoints map[tcpip.NetworkProtocolNumber]NetworkEndpoint + + // enabled is set to 1 when the NIC is enabled and 0 when it is disabled. + // + // Must be accessed using atomic operations. + enabled uint32 mu struct { sync.RWMutex - enabled bool - spoofing bool - promiscuous bool - primary map[tcpip.NetworkProtocolNumber][]*referencedNetworkEndpoint - endpoints map[NetworkEndpointID]*referencedNetworkEndpoint - addressRanges []tcpip.Subnet - mcastJoins map[NetworkEndpointID]uint32 + spoofing bool + promiscuous bool // packetEPs is protected by mu, but the contained PacketEndpoint // values are not. packetEPs map[tcpip.NetworkProtocolNumber][]PacketEndpoint - ndp ndpState } } @@ -84,25 +77,6 @@ type DirectionStats struct { Bytes *tcpip.StatCounter } -// PrimaryEndpointBehavior is an enumeration of an endpoint's primacy behavior. -type PrimaryEndpointBehavior int - -const ( - // CanBePrimaryEndpoint indicates the endpoint can be used as a primary - // endpoint for new connections with no local address. This is the - // default when calling NIC.AddAddress. - CanBePrimaryEndpoint PrimaryEndpointBehavior = iota - - // FirstPrimaryEndpoint indicates the endpoint should be the first - // primary endpoint considered. If there are multiple endpoints with - // this behavior, the most recently-added one will be first. - FirstPrimaryEndpoint - - // NeverPrimaryEndpoint indicates the endpoint should never be a - // primary endpoint. - NeverPrimaryEndpoint -) - // newNIC returns a new NIC using the default NDP configurations from stack. func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint, ctx NICContext) *NIC { // TODO(b/141011931): Validate a LinkEndpoint (ep) is valid. For @@ -114,43 +88,43 @@ func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint, ctx NICC // of IPv6 is supported on this endpoint's LinkEndpoint. nic := &NIC{ - stack: stack, - id: id, - name: name, - linkEP: ep, - context: ctx, - stats: makeNICStats(), + stack: stack, + id: id, + name: name, + linkEP: ep, + context: ctx, + stats: makeNICStats(), + networkEndpoints: make(map[tcpip.NetworkProtocolNumber]NetworkEndpoint), } - nic.mu.primary = make(map[tcpip.NetworkProtocolNumber][]*referencedNetworkEndpoint) - nic.mu.endpoints = make(map[NetworkEndpointID]*referencedNetworkEndpoint) - nic.mu.mcastJoins = make(map[NetworkEndpointID]uint32) nic.mu.packetEPs = make(map[tcpip.NetworkProtocolNumber][]PacketEndpoint) - nic.mu.ndp = ndpState{ - nic: nic, - configs: stack.ndpConfigs, - dad: make(map[tcpip.Address]dadState), - defaultRouters: make(map[tcpip.Address]defaultRouterState), - onLinkPrefixes: make(map[tcpip.Subnet]onLinkPrefixState), - slaacPrefixes: make(map[tcpip.Subnet]slaacPrefixState), - } - nic.mu.ndp.initializeTempAddrState() - - // Register supported packet endpoint protocols. - for _, netProto := range header.Ethertypes { - nic.mu.packetEPs[netProto] = []PacketEndpoint{} - } - for _, netProto := range stack.networkProtocols { - nic.mu.packetEPs[netProto.Number()] = []PacketEndpoint{} - } // Check for Neighbor Unreachability Detection support. - if ep.Capabilities()&CapabilityResolutionRequired != 0 && len(stack.linkAddrResolvers) != 0 { + var nud NUDHandler + if ep.Capabilities()&CapabilityResolutionRequired != 0 && len(stack.linkAddrResolvers) != 0 && stack.useNeighborCache { rng := rand.New(rand.NewSource(stack.clock.NowNanoseconds())) nic.neigh = &neighborCache{ nic: nic, state: NewNUDState(stack.nudConfigs, rng), cache: make(map[tcpip.Address]*neighborEntry, neighborCacheSize), } + + // An interface value that holds a nil pointer but non-nil type is not the + // same as the nil interface. Because of this, nud must only be assignd if + // nic.neigh is non-nil since a nil reference to a neighborCache is not + // valid. + // + // See https://golang.org/doc/faq#nil_error for more information. + nud = nic.neigh + } + + // Register supported packet and network endpoint protocols. + for _, netProto := range header.Ethertypes { + nic.mu.packetEPs[netProto] = []PacketEndpoint{} + } + for _, netProto := range stack.networkProtocols { + netNum := netProto.Number() + nic.mu.packetEPs[netNum] = nil + nic.networkEndpoints[netNum] = netProto.NewEndpoint(nic, stack, nud, nic) } nic.linkEP.Attach(nic) @@ -158,29 +132,28 @@ func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint, ctx NICC return nic } -// enabled returns true if n is enabled. -func (n *NIC) enabled() bool { - n.mu.RLock() - enabled := n.mu.enabled - n.mu.RUnlock() - return enabled +// Enabled implements NetworkInterface. +func (n *NIC) Enabled() bool { + return atomic.LoadUint32(&n.enabled) == 1 } -// disable disables n. +// setEnabled sets the enabled status for the NIC. // -// It undoes the work done by enable. -func (n *NIC) disable() *tcpip.Error { - n.mu.RLock() - enabled := n.mu.enabled - n.mu.RUnlock() - if !enabled { - return nil +// Returns true if the enabled status was updated. +func (n *NIC) setEnabled(v bool) bool { + if v { + return atomic.SwapUint32(&n.enabled, 1) == 0 } + return atomic.SwapUint32(&n.enabled, 0) == 1 +} +// disable disables n. +// +// It undoes the work done by enable. +func (n *NIC) disable() { n.mu.Lock() - err := n.disableLocked() + n.disableLocked() n.mu.Unlock() - return err } // disableLocked disables n. @@ -188,9 +161,9 @@ func (n *NIC) disable() *tcpip.Error { // It undoes the work done by enable. // // n MUST be locked. -func (n *NIC) disableLocked() *tcpip.Error { - if !n.mu.enabled { - return nil +func (n *NIC) disableLocked() { + if !n.setEnabled(false) { + return } // TODO(gvisor.dev/issue/1491): Should Routes that are currently bound to n be @@ -198,33 +171,9 @@ func (n *NIC) disableLocked() *tcpip.Error { // again, and applications may not know that the underlying NIC was ever // disabled. - if _, ok := n.stack.networkProtocols[header.IPv6ProtocolNumber]; ok { - n.mu.ndp.stopSolicitingRouters() - n.mu.ndp.cleanupState(false /* hostOnly */) - - // Stop DAD for all the unicast IPv6 endpoints that are in the - // permanentTentative state. - for _, r := range n.mu.endpoints { - if addr := r.ep.ID().LocalAddress; r.getKind() == permanentTentative && header.IsV6UnicastAddress(addr) { - n.mu.ndp.stopDuplicateAddressDetection(addr) - } - } - - // The NIC may have already left the multicast group. - if err := n.leaveGroupLocked(header.IPv6AllNodesMulticastAddress, false /* force */); err != nil && err != tcpip.ErrBadLocalAddress { - return err - } - } - - if _, ok := n.stack.networkProtocols[header.IPv4ProtocolNumber]; ok { - // The address may have already been removed. - if err := n.removePermanentAddressLocked(ipv4BroadcastAddr.AddressWithPrefix.Address); err != nil && err != tcpip.ErrBadLocalAddress { - return err - } + for _, ep := range n.networkEndpoints { + ep.Disable() } - - n.mu.enabled = false - return nil } // enable enables n. @@ -234,150 +183,39 @@ func (n *NIC) disableLocked() *tcpip.Error { // routers if the stack is not operating as a router. If the stack is also // configured to auto-generate a link-local address, one will be generated. func (n *NIC) enable() *tcpip.Error { - n.mu.RLock() - enabled := n.mu.enabled - n.mu.RUnlock() - if enabled { - return nil - } - n.mu.Lock() defer n.mu.Unlock() - if n.mu.enabled { + if !n.setEnabled(true) { return nil } - n.mu.enabled = true - - // Create an endpoint to receive broadcast packets on this interface. - if _, ok := n.stack.networkProtocols[header.IPv4ProtocolNumber]; ok { - if _, err := n.addAddressLocked(ipv4BroadcastAddr, NeverPrimaryEndpoint, permanent, static, false /* deprecated */); err != nil { + for _, ep := range n.networkEndpoints { + if err := ep.Enable(); err != nil { return err } } - // Join the IPv6 All-Nodes Multicast group if the stack is configured to - // use IPv6. This is required to ensure that this node properly receives - // and responds to the various NDP messages that are destined to the - // all-nodes multicast address. An example is the Neighbor Advertisement - // when we perform Duplicate Address Detection, or Router Advertisement - // when we do Router Discovery. See RFC 4862, section 5.4.2 and RFC 4861 - // section 4.2 for more information. - // - // Also auto-generate an IPv6 link-local address based on the NIC's - // link address if it is configured to do so. Note, each interface is - // required to have IPv6 link-local unicast address, as per RFC 4291 - // section 2.1. - _, ok := n.stack.networkProtocols[header.IPv6ProtocolNumber] - if !ok { - return nil - } - - // Join the All-Nodes multicast group before starting DAD as responses to DAD - // (NDP NS) messages may be sent to the All-Nodes multicast group if the - // source address of the NDP NS is the unspecified address, as per RFC 4861 - // section 7.2.4. - if err := n.joinGroupLocked(header.IPv6ProtocolNumber, header.IPv6AllNodesMulticastAddress); err != nil { - return err - } - - // Perform DAD on the all the unicast IPv6 endpoints that are in the permanent - // state. - // - // Addresses may have aleady completed DAD but in the time since the NIC was - // last enabled, other devices may have acquired the same addresses. - for _, r := range n.mu.endpoints { - addr := r.ep.ID().LocalAddress - if k := r.getKind(); (k != permanent && k != permanentTentative) || !header.IsV6UnicastAddress(addr) { - continue - } - - r.setKind(permanentTentative) - if err := n.mu.ndp.startDuplicateAddressDetection(addr, r); err != nil { - return err - } - } - - // Do not auto-generate an IPv6 link-local address for loopback devices. - if n.stack.autoGenIPv6LinkLocal && !n.isLoopback() { - // The valid and preferred lifetime is infinite for the auto-generated - // link-local address. - n.mu.ndp.doSLAAC(header.IPv6LinkLocalPrefix.Subnet(), header.NDPInfiniteLifetime, header.NDPInfiniteLifetime) - } - - // If we are operating as a router, then do not solicit routers since we - // won't process the RAs anyways. - // - // Routers do not process Router Advertisements (RA) the same way a host - // does. That is, routers do not learn from RAs (e.g. on-link prefixes - // and default routers). Therefore, soliciting RAs from other routers on - // a link is unnecessary for routers. - if !n.stack.forwarding { - n.mu.ndp.startSolicitingRouters() - } - return nil } -// remove detaches NIC from the link endpoint, and marks existing referenced -// network endpoints expired. This guarantees no packets between this NIC and -// the network stack. +// remove detaches NIC from the link endpoint and releases network endpoint +// resources. This guarantees no packets between this NIC and the network +// stack. func (n *NIC) remove() *tcpip.Error { n.mu.Lock() defer n.mu.Unlock() n.disableLocked() - // TODO(b/151378115): come up with a better way to pick an error than the - // first one. - var err *tcpip.Error - - // Forcefully leave multicast groups. - for nid := range n.mu.mcastJoins { - if tempErr := n.leaveGroupLocked(nid.LocalAddress, true /* force */); tempErr != nil && err == nil { - err = tempErr - } - } - - // Remove permanent and permanentTentative addresses, so no packet goes out. - for nid, ref := range n.mu.endpoints { - switch ref.getKind() { - case permanentTentative, permanent: - if tempErr := n.removePermanentAddressLocked(nid.LocalAddress); tempErr != nil && err == nil { - err = tempErr - } - } + for _, ep := range n.networkEndpoints { + ep.Close() } + n.networkEndpoints = nil // Detach from link endpoint, so no packet comes in. n.linkEP.Attach(nil) - - return err -} - -// becomeIPv6Router transitions n into an IPv6 router. -// -// When transitioning into an IPv6 router, host-only state (NDP discovered -// routers, discovered on-link prefixes, and auto-generated addresses) will -// be cleaned up/invalidated and NDP router solicitations will be stopped. -func (n *NIC) becomeIPv6Router() { - n.mu.Lock() - defer n.mu.Unlock() - - n.mu.ndp.cleanupState(true /* hostOnly */) - n.mu.ndp.stopSolicitingRouters() -} - -// becomeIPv6Host transitions n into an IPv6 host. -// -// When transitioning into an IPv6 host, NDP router solicitations will be -// started. -func (n *NIC) becomeIPv6Host() { - n.mu.Lock() - defer n.mu.Unlock() - - n.mu.ndp.startSolicitingRouters() + return nil } // setPromiscuousMode enables or disables promiscuous mode. @@ -394,7 +232,8 @@ func (n *NIC) isPromiscuousMode() bool { return rv } -func (n *NIC) isLoopback() bool { +// IsLoopback implements NetworkInterface. +func (n *NIC) IsLoopback() bool { return n.linkEP.Capabilities()&CapabilityLoopback != 0 } @@ -405,213 +244,53 @@ func (n *NIC) setSpoofing(enable bool) { n.mu.Unlock() } -// primaryEndpoint will return the first non-deprecated endpoint if such an -// endpoint exists for the given protocol and remoteAddr. If no non-deprecated -// endpoint exists, the first deprecated endpoint will be returned. -// -// If an IPv6 primary endpoint is requested, Source Address Selection (as -// defined by RFC 6724 section 5) will be performed. -func (n *NIC) primaryEndpoint(protocol tcpip.NetworkProtocolNumber, remoteAddr tcpip.Address) *referencedNetworkEndpoint { - if protocol == header.IPv6ProtocolNumber && remoteAddr != "" { - return n.primaryIPv6Endpoint(remoteAddr) - } - +// primaryAddress returns an address that can be used to communicate with +// remoteAddr. +func (n *NIC) primaryEndpoint(protocol tcpip.NetworkProtocolNumber, remoteAddr tcpip.Address) AssignableAddressEndpoint { n.mu.RLock() - defer n.mu.RUnlock() - - var deprecatedEndpoint *referencedNetworkEndpoint - for _, r := range n.mu.primary[protocol] { - if !r.isValidForOutgoingRLocked() { - continue - } - - if !r.deprecated { - if r.tryIncRef() { - // r is not deprecated, so return it immediately. - // - // If we kept track of a deprecated endpoint, decrement its reference - // count since it was incremented when we decided to keep track of it. - if deprecatedEndpoint != nil { - deprecatedEndpoint.decRefLocked() - deprecatedEndpoint = nil - } - - return r - } - } else if deprecatedEndpoint == nil && r.tryIncRef() { - // We prefer an endpoint that is not deprecated, but we keep track of r in - // case n doesn't have any non-deprecated endpoints. - // - // If we end up finding a more preferred endpoint, r's reference count - // will be decremented when such an endpoint is found. - deprecatedEndpoint = r - } - } - - // n doesn't have any valid non-deprecated endpoints, so return - // deprecatedEndpoint (which may be nil if n doesn't have any valid deprecated - // endpoints either). - return deprecatedEndpoint -} - -// ipv6AddrCandidate is an IPv6 candidate for Source Address Selection (RFC -// 6724 section 5). -type ipv6AddrCandidate struct { - ref *referencedNetworkEndpoint - scope header.IPv6AddressScope -} - -// primaryIPv6Endpoint returns an IPv6 endpoint following Source Address -// Selection (RFC 6724 section 5). -// -// Note, only rules 1-3 and 7 are followed. -// -// remoteAddr must be a valid IPv6 address. -func (n *NIC) primaryIPv6Endpoint(remoteAddr tcpip.Address) *referencedNetworkEndpoint { - n.mu.RLock() - ref := n.primaryIPv6EndpointRLocked(remoteAddr) + spoofing := n.mu.spoofing n.mu.RUnlock() - return ref -} - -// primaryIPv6EndpointLocked returns an IPv6 endpoint following Source Address -// Selection (RFC 6724 section 5). -// -// Note, only rules 1-3 and 7 are followed. -// -// remoteAddr must be a valid IPv6 address. -// -// n.mu MUST be read locked. -func (n *NIC) primaryIPv6EndpointRLocked(remoteAddr tcpip.Address) *referencedNetworkEndpoint { - primaryAddrs := n.mu.primary[header.IPv6ProtocolNumber] - - if len(primaryAddrs) == 0 { - return nil - } - - // Create a candidate set of available addresses we can potentially use as a - // source address. - cs := make([]ipv6AddrCandidate, 0, len(primaryAddrs)) - for _, r := range primaryAddrs { - // If r is not valid for outgoing connections, it is not a valid endpoint. - if !r.isValidForOutgoingRLocked() { - continue - } - - addr := r.ep.ID().LocalAddress - scope, err := header.ScopeForIPv6Address(addr) - if err != nil { - // Should never happen as we got r from the primary IPv6 endpoint list and - // ScopeForIPv6Address only returns an error if addr is not an IPv6 - // address. - panic(fmt.Sprintf("header.ScopeForIPv6Address(%s): %s", addr, err)) - } - - cs = append(cs, ipv6AddrCandidate{ - ref: r, - scope: scope, - }) - } - - remoteScope, err := header.ScopeForIPv6Address(remoteAddr) - if err != nil { - // primaryIPv6Endpoint should never be called with an invalid IPv6 address. - panic(fmt.Sprintf("header.ScopeForIPv6Address(%s): %s", remoteAddr, err)) - } - - // Sort the addresses as per RFC 6724 section 5 rules 1-3. - // - // TODO(b/146021396): Implement rules 4-8 of RFC 6724 section 5. - sort.Slice(cs, func(i, j int) bool { - sa := cs[i] - sb := cs[j] - - // Prefer same address as per RFC 6724 section 5 rule 1. - if sa.ref.ep.ID().LocalAddress == remoteAddr { - return true - } - if sb.ref.ep.ID().LocalAddress == remoteAddr { - return false - } - - // Prefer appropriate scope as per RFC 6724 section 5 rule 2. - if sa.scope < sb.scope { - return sa.scope >= remoteScope - } else if sb.scope < sa.scope { - return sb.scope < remoteScope - } - - // Avoid deprecated addresses as per RFC 6724 section 5 rule 3. - if saDep, sbDep := sa.ref.deprecated, sb.ref.deprecated; saDep != sbDep { - // If sa is not deprecated, it is preferred over sb. - return sbDep - } - - // Prefer temporary addresses as per RFC 6724 section 5 rule 7. - if saTemp, sbTemp := sa.ref.configType == slaacTemp, sb.ref.configType == slaacTemp; saTemp != sbTemp { - return saTemp - } - - // sa and sb are equal, return the endpoint that is closest to the front of - // the primary endpoint list. - return i < j - }) - - // Return the most preferred address that can have its reference count - // incremented. - for _, c := range cs { - if r := c.ref; r.tryIncRef() { - return r - } - } - - return nil -} - -// hasPermanentAddrLocked returns true if n has a permanent (including currently -// tentative) address, addr. -func (n *NIC) hasPermanentAddrLocked(addr tcpip.Address) bool { - ref, ok := n.mu.endpoints[NetworkEndpointID{addr}] + ep, ok := n.networkEndpoints[protocol] if !ok { - return false + return nil } - kind := ref.getKind() - - return kind == permanent || kind == permanentTentative + return ep.AcquireOutgoingPrimaryAddress(remoteAddr, spoofing) } -type getRefBehaviour int +type getAddressBehaviour int const ( // spoofing indicates that the NIC's spoofing flag should be observed when - // getting a NIC's referenced network endpoint. - spoofing getRefBehaviour = iota + // getting a NIC's address endpoint. + spoofing getAddressBehaviour = iota // promiscuous indicates that the NIC's promiscuous flag should be observed - // when getting a NIC's referenced network endpoint. + // when getting a NIC's address endpoint. promiscuous ) -func (n *NIC) getRef(protocol tcpip.NetworkProtocolNumber, dst tcpip.Address) *referencedNetworkEndpoint { - return n.getRefOrCreateTemp(protocol, dst, CanBePrimaryEndpoint, promiscuous) +func (n *NIC) getAddress(protocol tcpip.NetworkProtocolNumber, dst tcpip.Address) AssignableAddressEndpoint { + return n.getAddressOrCreateTemp(protocol, dst, CanBePrimaryEndpoint, promiscuous) } // findEndpoint finds the endpoint, if any, with the given address. -func (n *NIC) findEndpoint(protocol tcpip.NetworkProtocolNumber, address tcpip.Address, peb PrimaryEndpointBehavior) *referencedNetworkEndpoint { - return n.getRefOrCreateTemp(protocol, address, peb, spoofing) +func (n *NIC) findEndpoint(protocol tcpip.NetworkProtocolNumber, address tcpip.Address, peb PrimaryEndpointBehavior) AssignableAddressEndpoint { + return n.getAddressOrCreateTemp(protocol, address, peb, spoofing) } -// getRefEpOrCreateTemp returns the referenced network endpoint for the given -// protocol and address. +// getAddressEpOrCreateTemp returns the address endpoint for the given protocol +// and address. // // If none exists a temporary one may be created if we are in promiscuous mode // or spoofing. Promiscuous mode will only be checked if promiscuous is true. // Similarly, spoofing will only be checked if spoofing is true. -func (n *NIC) getRefOrCreateTemp(protocol tcpip.NetworkProtocolNumber, address tcpip.Address, peb PrimaryEndpointBehavior, tempRef getRefBehaviour) *referencedNetworkEndpoint { +// +// If the address is the IPv4 broadcast address for an endpoint's network, that +// endpoint will be returned. +func (n *NIC) getAddressOrCreateTemp(protocol tcpip.NetworkProtocolNumber, address tcpip.Address, peb PrimaryEndpointBehavior, tempRef getAddressBehaviour) AssignableAddressEndpoint { n.mu.RLock() - var spoofingOrPromiscuous bool switch tempRef { case spoofing: @@ -619,267 +298,54 @@ func (n *NIC) getRefOrCreateTemp(protocol tcpip.NetworkProtocolNumber, address t case promiscuous: spoofingOrPromiscuous = n.mu.promiscuous } - - if ref, ok := n.mu.endpoints[NetworkEndpointID{address}]; ok { - // An endpoint with this id exists, check if it can be used and return it. - if !ref.isAssignedRLocked(spoofingOrPromiscuous) { - n.mu.RUnlock() - return nil - } - - if ref.tryIncRef() { - n.mu.RUnlock() - return ref - } - } - - // A usable reference was not found, create a temporary one if requested by - // the caller or if the address is found in the NIC's subnets. - createTempEP := spoofingOrPromiscuous - if !createTempEP { - for _, sn := range n.mu.addressRanges { - // Skip the subnet address. - if address == sn.ID() { - continue - } - // For now just skip the broadcast address, until we support it. - // FIXME(b/137608825): Add support for sending/receiving directed - // (subnet) broadcast. - if address == sn.Broadcast() { - continue - } - if sn.Contains(address) { - createTempEP = true - break - } - } - } - n.mu.RUnlock() - - if !createTempEP { - return nil - } - - // Try again with the lock in exclusive mode. If we still can't get the - // endpoint, create a new "temporary" endpoint. It will only exist while - // there's a route through it. - n.mu.Lock() - ref := n.getRefOrCreateTempLocked(protocol, address, peb) - n.mu.Unlock() - return ref + return n.getAddressOrCreateTempInner(protocol, address, spoofingOrPromiscuous, peb) } -/// getRefOrCreateTempLocked returns an existing endpoint for address or creates -/// and returns a temporary endpoint. -func (n *NIC) getRefOrCreateTempLocked(protocol tcpip.NetworkProtocolNumber, address tcpip.Address, peb PrimaryEndpointBehavior) *referencedNetworkEndpoint { - if ref, ok := n.mu.endpoints[NetworkEndpointID{address}]; ok { - // No need to check the type as we are ok with expired endpoints at this - // point. - if ref.tryIncRef() { - return ref - } - // tryIncRef failing means the endpoint is scheduled to be removed once the - // lock is released. Remove it here so we can create a new (temporary) one. - // The removal logic waiting for the lock handles this case. - n.removeEndpointLocked(ref) +// getAddressOrCreateTempInner is like getAddressEpOrCreateTemp except a boolean +// is passed to indicate whether or not we should generate temporary endpoints. +func (n *NIC) getAddressOrCreateTempInner(protocol tcpip.NetworkProtocolNumber, address tcpip.Address, createTemp bool, peb PrimaryEndpointBehavior) AssignableAddressEndpoint { + if ep, ok := n.networkEndpoints[protocol]; ok { + return ep.AcquireAssignedAddress(address, createTemp, peb) } - // Add a new temporary endpoint. - netProto, ok := n.stack.networkProtocols[protocol] - if !ok { - return nil - } - ref, _ := n.addAddressLocked(tcpip.ProtocolAddress{ - Protocol: protocol, - AddressWithPrefix: tcpip.AddressWithPrefix{ - Address: address, - PrefixLen: netProto.DefaultPrefixLen(), - }, - }, peb, temporary, static, false) - return ref + return nil } -// addAddressLocked adds a new protocolAddress to n. -// -// If n already has the address in a non-permanent state, and the kind given is -// permanent, that address will be promoted in place and its properties set to -// the properties provided. Otherwise, it returns tcpip.ErrDuplicateAddress. -func (n *NIC) addAddressLocked(protocolAddress tcpip.ProtocolAddress, peb PrimaryEndpointBehavior, kind networkEndpointKind, configType networkEndpointConfigType, deprecated bool) (*referencedNetworkEndpoint, *tcpip.Error) { - // TODO(b/141022673): Validate IP addresses before adding them. - - // Sanity check. - id := NetworkEndpointID{LocalAddress: protocolAddress.AddressWithPrefix.Address} - if ref, ok := n.mu.endpoints[id]; ok { - // Endpoint already exists. - if kind != permanent { - return nil, tcpip.ErrDuplicateAddress - } - switch ref.getKind() { - case permanentTentative, permanent: - // The NIC already have a permanent endpoint with that address. - return nil, tcpip.ErrDuplicateAddress - case permanentExpired, temporary: - // Promote the endpoint to become permanent and respect the new peb, - // configType and deprecated status. - if ref.tryIncRef() { - // TODO(b/147748385): Perform Duplicate Address Detection when promoting - // an IPv6 endpoint to permanent. - ref.setKind(permanent) - ref.deprecated = deprecated - ref.configType = configType - - refs := n.mu.primary[ref.protocol] - for i, r := range refs { - if r == ref { - switch peb { - case CanBePrimaryEndpoint: - return ref, nil - case FirstPrimaryEndpoint: - if i == 0 { - return ref, nil - } - n.mu.primary[r.protocol] = append(refs[:i], refs[i+1:]...) - case NeverPrimaryEndpoint: - n.mu.primary[r.protocol] = append(refs[:i], refs[i+1:]...) - return ref, nil - } - } - } - - n.insertPrimaryEndpointLocked(ref, peb) - - return ref, nil - } - // tryIncRef failing means the endpoint is scheduled to be removed once - // the lock is released. Remove it here so we can create a new - // (permanent) one. The removal logic waiting for the lock handles this - // case. - n.removeEndpointLocked(ref) - } - } - - netProto, ok := n.stack.networkProtocols[protocolAddress.Protocol] +// addAddress adds a new address to n, so that it starts accepting packets +// targeted at the given address (and network protocol). +func (n *NIC) addAddress(protocolAddress tcpip.ProtocolAddress, peb PrimaryEndpointBehavior) *tcpip.Error { + ep, ok := n.networkEndpoints[protocolAddress.Protocol] if !ok { - return nil, tcpip.ErrUnknownProtocol - } - - // Create the new network endpoint. - ep, err := netProto.NewEndpoint(n.id, protocolAddress.AddressWithPrefix, n.stack, n, n.linkEP, n.stack) - if err != nil { - return nil, err - } - - isIPv6Unicast := protocolAddress.Protocol == header.IPv6ProtocolNumber && header.IsV6UnicastAddress(protocolAddress.AddressWithPrefix.Address) - - // If the address is an IPv6 address and it is a permanent address, - // mark it as tentative so it goes through the DAD process if the NIC is - // enabled. If the NIC is not enabled, DAD will be started when the NIC is - // enabled. - if isIPv6Unicast && kind == permanent { - kind = permanentTentative - } - - ref := &referencedNetworkEndpoint{ - refs: 1, - ep: ep, - nic: n, - protocol: protocolAddress.Protocol, - kind: kind, - configType: configType, - deprecated: deprecated, - } - - // Set up cache if link address resolution exists for this protocol. - if n.linkEP.Capabilities()&CapabilityResolutionRequired != 0 { - if _, ok := n.stack.linkAddrResolvers[protocolAddress.Protocol]; ok { - ref.linkCache = n.stack - } - } - - // If we are adding an IPv6 unicast address, join the solicited-node - // multicast address. - if isIPv6Unicast { - snmc := header.SolicitedNodeAddr(protocolAddress.AddressWithPrefix.Address) - if err := n.joinGroupLocked(protocolAddress.Protocol, snmc); err != nil { - return nil, err - } + return tcpip.ErrUnknownProtocol } - n.mu.endpoints[id] = ref - - n.insertPrimaryEndpointLocked(ref, peb) - - // If we are adding a tentative IPv6 address, start DAD if the NIC is enabled. - if isIPv6Unicast && kind == permanentTentative && n.mu.enabled { - if err := n.mu.ndp.startDuplicateAddressDetection(protocolAddress.AddressWithPrefix.Address, ref); err != nil { - return nil, err - } + addressEndpoint, err := ep.AddAndAcquirePermanentAddress(protocolAddress.AddressWithPrefix, peb, AddressConfigStatic, false /* deprecated */) + if err == nil { + // We have no need for the address endpoint. + addressEndpoint.DecRef() } - - return ref, nil -} - -// AddAddress adds a new address to n, so that it starts accepting packets -// targeted at the given address (and network protocol). -func (n *NIC) AddAddress(protocolAddress tcpip.ProtocolAddress, peb PrimaryEndpointBehavior) *tcpip.Error { - // Add the endpoint. - n.mu.Lock() - _, err := n.addAddressLocked(protocolAddress, peb, permanent, static, false /* deprecated */) - n.mu.Unlock() - return err } -// AllAddresses returns all addresses (primary and non-primary) associated with +// allPermanentAddresses returns all permanent addresses associated with // this NIC. -func (n *NIC) AllAddresses() []tcpip.ProtocolAddress { - n.mu.RLock() - defer n.mu.RUnlock() - - addrs := make([]tcpip.ProtocolAddress, 0, len(n.mu.endpoints)) - for nid, ref := range n.mu.endpoints { - // Don't include tentative, expired or temporary endpoints to - // avoid confusion and prevent the caller from using those. - switch ref.getKind() { - case permanentExpired, temporary: - continue +func (n *NIC) allPermanentAddresses() []tcpip.ProtocolAddress { + var addrs []tcpip.ProtocolAddress + for p, ep := range n.networkEndpoints { + for _, a := range ep.PermanentAddresses() { + addrs = append(addrs, tcpip.ProtocolAddress{Protocol: p, AddressWithPrefix: a}) } - - addrs = append(addrs, tcpip.ProtocolAddress{ - Protocol: ref.protocol, - AddressWithPrefix: tcpip.AddressWithPrefix{ - Address: nid.LocalAddress, - PrefixLen: ref.ep.PrefixLen(), - }, - }) } return addrs } -// PrimaryAddresses returns the primary addresses associated with this NIC. -func (n *NIC) PrimaryAddresses() []tcpip.ProtocolAddress { - n.mu.RLock() - defer n.mu.RUnlock() - +// primaryAddresses returns the primary addresses associated with this NIC. +func (n *NIC) primaryAddresses() []tcpip.ProtocolAddress { var addrs []tcpip.ProtocolAddress - for proto, list := range n.mu.primary { - for _, ref := range list { - // Don't include tentative, expired or tempory endpoints - // to avoid confusion and prevent the caller from using - // those. - switch ref.getKind() { - case permanentTentative, permanentExpired, temporary: - continue - } - - addrs = append(addrs, tcpip.ProtocolAddress{ - Protocol: proto, - AddressWithPrefix: tcpip.AddressWithPrefix{ - Address: ref.ep.ID().LocalAddress, - PrefixLen: ref.ep.PrefixLen(), - }, - }) + for p, ep := range n.networkEndpoints { + for _, a := range ep.PrimaryAddresses() { + addrs = append(addrs, tcpip.ProtocolAddress{Protocol: p, AddressWithPrefix: a}) } } return addrs @@ -891,289 +357,135 @@ func (n *NIC) PrimaryAddresses() []tcpip.ProtocolAddress { // address exists. If no non-deprecated address exists, the first deprecated // address will be returned. func (n *NIC) primaryAddress(proto tcpip.NetworkProtocolNumber) tcpip.AddressWithPrefix { - n.mu.RLock() - defer n.mu.RUnlock() - - list, ok := n.mu.primary[proto] + ep, ok := n.networkEndpoints[proto] if !ok { return tcpip.AddressWithPrefix{} } - var deprecatedEndpoint *referencedNetworkEndpoint - for _, ref := range list { - // Don't include tentative, expired or tempory endpoints to avoid confusion - // and prevent the caller from using those. - switch ref.getKind() { - case permanentTentative, permanentExpired, temporary: - continue - } - - if !ref.deprecated { - return tcpip.AddressWithPrefix{ - Address: ref.ep.ID().LocalAddress, - PrefixLen: ref.ep.PrefixLen(), - } - } - - if deprecatedEndpoint == nil { - deprecatedEndpoint = ref - } - } - - if deprecatedEndpoint != nil { - return tcpip.AddressWithPrefix{ - Address: deprecatedEndpoint.ep.ID().LocalAddress, - PrefixLen: deprecatedEndpoint.ep.PrefixLen(), - } - } - - return tcpip.AddressWithPrefix{} -} - -// AddAddressRange adds a range of addresses to n, so that it starts accepting -// packets targeted at the given addresses and network protocol. The range is -// given by a subnet address, and all addresses contained in the subnet are -// used except for the subnet address itself and the subnet's broadcast -// address. -func (n *NIC) AddAddressRange(protocol tcpip.NetworkProtocolNumber, subnet tcpip.Subnet) { - n.mu.Lock() - n.mu.addressRanges = append(n.mu.addressRanges, subnet) - n.mu.Unlock() + return ep.MainAddress() } -// RemoveAddressRange removes the given address range from n. -func (n *NIC) RemoveAddressRange(subnet tcpip.Subnet) { - n.mu.Lock() - - // Use the same underlying array. - tmp := n.mu.addressRanges[:0] - for _, sub := range n.mu.addressRanges { - if sub != subnet { - tmp = append(tmp, sub) +// removeAddress removes an address from n. +func (n *NIC) removeAddress(addr tcpip.Address) *tcpip.Error { + for _, ep := range n.networkEndpoints { + if err := ep.RemovePermanentAddress(addr); err == tcpip.ErrBadLocalAddress { + continue + } else { + return err } } - n.mu.addressRanges = tmp - n.mu.Unlock() + return tcpip.ErrBadLocalAddress } -// AddressRanges returns the Subnets associated with this NIC. -func (n *NIC) AddressRanges() []tcpip.Subnet { - n.mu.RLock() - defer n.mu.RUnlock() - sns := make([]tcpip.Subnet, 0, len(n.mu.addressRanges)+len(n.mu.endpoints)) - for nid := range n.mu.endpoints { - sn, err := tcpip.NewSubnet(nid.LocalAddress, tcpip.AddressMask(strings.Repeat("\xff", len(nid.LocalAddress)))) - if err != nil { - // This should never happen as the mask has been carefully crafted to - // match the address. - panic("Invalid endpoint subnet: " + err.Error()) - } - sns = append(sns, sn) +func (n *NIC) neighbors() ([]NeighborEntry, *tcpip.Error) { + if n.neigh == nil { + return nil, tcpip.ErrNotSupported } - return append(sns, n.mu.addressRanges...) -} -// insertPrimaryEndpointLocked adds r to n's primary endpoint list as required -// by peb. -// -// n MUST be locked. -func (n *NIC) insertPrimaryEndpointLocked(r *referencedNetworkEndpoint, peb PrimaryEndpointBehavior) { - switch peb { - case CanBePrimaryEndpoint: - n.mu.primary[r.protocol] = append(n.mu.primary[r.protocol], r) - case FirstPrimaryEndpoint: - n.mu.primary[r.protocol] = append([]*referencedNetworkEndpoint{r}, n.mu.primary[r.protocol]...) - } + return n.neigh.entries(), nil } -func (n *NIC) removeEndpointLocked(r *referencedNetworkEndpoint) { - id := *r.ep.ID() - - // Nothing to do if the reference has already been replaced with a different - // one. This happens in the case where 1) this endpoint's ref count hit zero - // and was waiting (on the lock) to be removed and 2) the same address was - // re-added in the meantime by removing this endpoint from the list and - // adding a new one. - if n.mu.endpoints[id] != r { +func (n *NIC) removeWaker(addr tcpip.Address, w *sleep.Waker) { + if n.neigh == nil { return } - if r.getKind() == permanent { - panic("Reference count dropped to zero before being removed") - } + n.neigh.removeWaker(addr, w) +} - delete(n.mu.endpoints, id) - refs := n.mu.primary[r.protocol] - for i, ref := range refs { - if ref == r { - n.mu.primary[r.protocol] = append(refs[:i], refs[i+1:]...) - refs[len(refs)-1] = nil - break - } +func (n *NIC) addStaticNeighbor(addr tcpip.Address, linkAddress tcpip.LinkAddress) *tcpip.Error { + if n.neigh == nil { + return tcpip.ErrNotSupported } - r.ep.Close() -} - -func (n *NIC) removeEndpoint(r *referencedNetworkEndpoint) { - n.mu.Lock() - n.removeEndpointLocked(r) - n.mu.Unlock() + n.neigh.addStaticEntry(addr, linkAddress) + return nil } -func (n *NIC) removePermanentAddressLocked(addr tcpip.Address) *tcpip.Error { - r, ok := n.mu.endpoints[NetworkEndpointID{addr}] - if !ok { - return tcpip.ErrBadLocalAddress - } - - kind := r.getKind() - if kind != permanent && kind != permanentTentative { - return tcpip.ErrBadLocalAddress +func (n *NIC) removeNeighbor(addr tcpip.Address) *tcpip.Error { + if n.neigh == nil { + return tcpip.ErrNotSupported } - switch r.protocol { - case header.IPv6ProtocolNumber: - return n.removePermanentIPv6EndpointLocked(r, true /* allowSLAACInvalidation */) - default: - r.expireLocked() - return nil + if !n.neigh.removeEntry(addr) { + return tcpip.ErrBadAddress } + return nil } -func (n *NIC) removePermanentIPv6EndpointLocked(r *referencedNetworkEndpoint, allowSLAACInvalidation bool) *tcpip.Error { - addr := r.addrWithPrefix() - - isIPv6Unicast := header.IsV6UnicastAddress(addr.Address) - - if isIPv6Unicast { - n.mu.ndp.stopDuplicateAddressDetection(addr.Address) - - // If we are removing an address generated via SLAAC, cleanup - // its SLAAC resources and notify the integrator. - switch r.configType { - case slaac: - n.mu.ndp.cleanupSLAACAddrResourcesAndNotify(addr, allowSLAACInvalidation) - case slaacTemp: - n.mu.ndp.cleanupTempSLAACAddrResourcesAndNotify(addr, allowSLAACInvalidation) - } - } - - r.expireLocked() - - // At this point the endpoint is deleted. - - // If we are removing an IPv6 unicast address, leave the solicited-node - // multicast address. - // - // We ignore the tcpip.ErrBadLocalAddress error because the solicited-node - // multicast group may be left by user action. - if isIPv6Unicast { - snmc := header.SolicitedNodeAddr(addr.Address) - if err := n.leaveGroupLocked(snmc, false /* force */); err != nil && err != tcpip.ErrBadLocalAddress { - return err - } +func (n *NIC) clearNeighbors() *tcpip.Error { + if n.neigh == nil { + return tcpip.ErrNotSupported } + n.neigh.clear() return nil } -// RemoveAddress removes an address from n. -func (n *NIC) RemoveAddress(addr tcpip.Address) *tcpip.Error { - n.mu.Lock() - defer n.mu.Unlock() - return n.removePermanentAddressLocked(addr) -} - // joinGroup adds a new endpoint for the given multicast address, if none // exists yet. Otherwise it just increments its count. func (n *NIC) joinGroup(protocol tcpip.NetworkProtocolNumber, addr tcpip.Address) *tcpip.Error { - n.mu.Lock() - defer n.mu.Unlock() - - return n.joinGroupLocked(protocol, addr) -} - -// joinGroupLocked adds a new endpoint for the given multicast address, if none -// exists yet. Otherwise it just increments its count. n MUST be locked before -// joinGroupLocked is called. -func (n *NIC) joinGroupLocked(protocol tcpip.NetworkProtocolNumber, addr tcpip.Address) *tcpip.Error { // TODO(b/143102137): When implementing MLD, make sure MLD packets are // not sent unless a valid link-local address is available for use on n // as an MLD packet's source address must be a link-local address as // outlined in RFC 3810 section 5. - id := NetworkEndpointID{addr} - joins := n.mu.mcastJoins[id] - if joins == 0 { - netProto, ok := n.stack.networkProtocols[protocol] - if !ok { - return tcpip.ErrUnknownProtocol - } - if _, err := n.addAddressLocked(tcpip.ProtocolAddress{ - Protocol: protocol, - AddressWithPrefix: tcpip.AddressWithPrefix{ - Address: addr, - PrefixLen: netProto.DefaultPrefixLen(), - }, - }, NeverPrimaryEndpoint, permanent, static, false /* deprecated */); err != nil { - return err - } + ep, ok := n.networkEndpoints[protocol] + if !ok { + return tcpip.ErrNotSupported } - n.mu.mcastJoins[id] = joins + 1 - return nil + + gep, ok := ep.(GroupAddressableEndpoint) + if !ok { + return tcpip.ErrNotSupported + } + + _, err := gep.JoinGroup(addr) + return err } // leaveGroup decrements the count for the given multicast address, and when it // reaches zero removes the endpoint for this address. -func (n *NIC) leaveGroup(addr tcpip.Address) *tcpip.Error { - n.mu.Lock() - defer n.mu.Unlock() - - return n.leaveGroupLocked(addr, false /* force */) -} +func (n *NIC) leaveGroup(protocol tcpip.NetworkProtocolNumber, addr tcpip.Address) *tcpip.Error { + ep, ok := n.networkEndpoints[protocol] + if !ok { + return tcpip.ErrNotSupported + } -// leaveGroupLocked decrements the count for the given multicast address, and -// when it reaches zero removes the endpoint for this address. n MUST be locked -// before leaveGroupLocked is called. -// -// If force is true, then the count for the multicast addres is ignored and the -// endpoint will be removed immediately. -func (n *NIC) leaveGroupLocked(addr tcpip.Address, force bool) *tcpip.Error { - id := NetworkEndpointID{addr} - joins, ok := n.mu.mcastJoins[id] + gep, ok := ep.(GroupAddressableEndpoint) if !ok { - // There are no joins with this address on this NIC. - return tcpip.ErrBadLocalAddress + return tcpip.ErrNotSupported } - joins-- - if force || joins == 0 { - // There are no outstanding joins or we are forced to leave, clean up. - delete(n.mu.mcastJoins, id) - return n.removePermanentAddressLocked(addr) + if _, err := gep.LeaveGroup(addr); err != nil { + return err } - n.mu.mcastJoins[id] = joins return nil } // isInGroup returns true if n has joined the multicast group addr. func (n *NIC) isInGroup(addr tcpip.Address) bool { - n.mu.RLock() - joins := n.mu.mcastJoins[NetworkEndpointID{addr}] - n.mu.RUnlock() + for _, ep := range n.networkEndpoints { + gep, ok := ep.(GroupAddressableEndpoint) + if !ok { + continue + } - return joins != 0 + if gep.IsInGroup(addr) { + return true + } + } + + return false } -func handlePacket(protocol tcpip.NetworkProtocolNumber, dst, src tcpip.Address, localLinkAddr, remotelinkAddr tcpip.LinkAddress, ref *referencedNetworkEndpoint, pkt *PacketBuffer) { - r := makeRoute(protocol, dst, src, localLinkAddr, ref, false /* handleLocal */, false /* multicastLoop */) +func (n *NIC) handlePacket(protocol tcpip.NetworkProtocolNumber, dst, src tcpip.Address, remotelinkAddr tcpip.LinkAddress, addressEndpoint AssignableAddressEndpoint, pkt *PacketBuffer) { + r := makeRoute(protocol, dst, src, n, addressEndpoint, false /* handleLocal */, false /* multicastLoop */) r.RemoteLinkAddress = remotelinkAddr - - ref.ep.HandlePacket(&r, pkt) - ref.decRef() + addressEndpoint.NetworkEndpoint().HandlePacket(&r, pkt) + addressEndpoint.DecRef() } // DeliverNetworkPacket finds the appropriate network protocol endpoint and @@ -1184,7 +496,7 @@ func handlePacket(protocol tcpip.NetworkProtocolNumber, dst, src tcpip.Address, // the ownership of the items is not retained by the caller. func (n *NIC) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) { n.mu.RLock() - enabled := n.mu.enabled + enabled := n.Enabled() // If the NIC is not yet enabled, don't receive any packets. if !enabled { n.mu.RUnlock() @@ -1210,9 +522,9 @@ func (n *NIC) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcp local = n.linkEP.LinkAddress() } - // Are any packet sockets listening for this network protocol? + // Are any packet type sockets listening for this network protocol? packetEPs := n.mu.packetEPs[protocol] - // Add any other packet sockets that maybe listening for all protocols. + // Add any other packet type sockets that may be listening for all protocols. packetEPs = append(packetEPs, n.mu.packetEPs[header.EthernetProtocolAll]...) n.mu.RUnlock() for _, ep := range packetEPs { @@ -1233,37 +545,42 @@ func (n *NIC) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcp return } if hasTransportHdr { + pkt.TransportProtocolNumber = transProtoNum // Parse the transport header if present. if state, ok := n.stack.transportProtocols[transProtoNum]; ok { state.proto.Parse(pkt) } } - src, dst := netProto.ParseAddresses(pkt.NetworkHeader) + src, dst := netProto.ParseAddresses(pkt.NetworkHeader().View()) - if n.stack.handleLocal && !n.isLoopback() && n.getRef(protocol, src) != nil { - // The source address is one of our own, so we never should have gotten a - // packet like this unless handleLocal is false. Loopback also calls this - // function even though the packets didn't come from the physical interface - // so don't drop those. - n.stack.stats.IP.InvalidSourceAddressesReceived.Increment() - return + if n.stack.handleLocal && !n.IsLoopback() { + if r := n.getAddress(protocol, src); r != nil { + r.DecRef() + + // The source address is one of our own, so we never should have gotten a + // packet like this unless handleLocal is false. Loopback also calls this + // function even though the packets didn't come from the physical interface + // so don't drop those. + n.stack.stats.IP.InvalidSourceAddressesReceived.Increment() + return + } } - // TODO(gvisor.dev/issue/170): Not supporting iptables for IPv6 yet. // Loopback traffic skips the prerouting chain. - if protocol == header.IPv4ProtocolNumber && !n.isLoopback() { + if !n.IsLoopback() { // iptables filtering. ipt := n.stack.IPTables() address := n.primaryAddress(protocol) if ok := ipt.Check(Prerouting, pkt, nil, nil, address.Address, ""); !ok { // iptables is telling us to drop the packet. + n.stack.stats.IP.IPTablesPreroutingDropped.Increment() return } } - if ref := n.getRef(protocol, dst); ref != nil { - handlePacket(protocol, dst, src, n.linkEP.LinkAddress(), remote, ref, pkt) + if addressEndpoint := n.getAddress(protocol, dst); addressEndpoint != nil { + n.handlePacket(protocol, dst, src, remote, addressEndpoint, pkt) return } @@ -1271,7 +588,7 @@ func (n *NIC) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcp // packet and forward it to the NIC. // // TODO: Should we be forwarding the packet even if promiscuous? - if n.stack.Forwarding() { + if n.stack.Forwarding(protocol) { r, err := n.stack.FindRoute(0, "", dst, protocol, false /* multicastLoop */) if err != nil { n.stack.stats.IP.InvalidDestinationAddressesReceived.Increment() @@ -1279,25 +596,26 @@ func (n *NIC) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcp } // Found a NIC. - n := r.ref.nic - n.mu.RLock() - ref, ok := n.mu.endpoints[NetworkEndpointID{dst}] - ok = ok && ref.isValidForOutgoingRLocked() && ref.tryIncRef() - n.mu.RUnlock() - if ok { - r.LocalLinkAddress = n.linkEP.LinkAddress() - r.RemoteLinkAddress = remote - r.RemoteAddress = src - // TODO(b/123449044): Update the source NIC as well. - ref.ep.HandlePacket(&r, pkt) - ref.decRef() - r.Release() - return + n := r.nic + if addressEndpoint := n.getAddressOrCreateTempInner(protocol, dst, false, NeverPrimaryEndpoint); addressEndpoint != nil { + if n.isValidForOutgoing(addressEndpoint) { + r.LocalLinkAddress = n.linkEP.LinkAddress() + r.RemoteLinkAddress = remote + r.RemoteAddress = src + // TODO(b/123449044): Update the source NIC as well. + addressEndpoint.NetworkEndpoint().HandlePacket(&r, pkt) + addressEndpoint.DecRef() + r.Release() + return + } + + addressEndpoint.DecRef() } // n doesn't have a destination endpoint. // Send the packet out of n. // TODO(b/128629022): move this logic to route.WritePacket. + // TODO(gvisor.dev/issue/1085): According to the RFC, we must decrease the TTL field for ipv4/ipv6. if ch, err := r.Resolve(nil); err != nil { if err == tcpip.ErrWouldBlock { n.stack.forwarder.enqueue(ch, n, &r, protocol, pkt) @@ -1341,24 +659,19 @@ func (n *NIC) DeliverOutboundPacket(remote, local tcpip.LinkAddress, protocol tc func (n *NIC) forwardPacket(r *Route, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) { // TODO(b/143425874) Decrease the TTL field in forwarded packets. - // TODO(b/151227689): Avoid copying the packet when forwarding. We can do this - // by having lower layers explicity write each header instead of just - // pkt.Header. - - // pkt may have set its NetworkHeader and TransportHeader. If we're - // forwarding, we'll have to copy them into pkt.Header. - pkt.Header = buffer.NewPrependable(int(n.linkEP.MaxHeaderLength()) + len(pkt.NetworkHeader) + len(pkt.TransportHeader)) - if n := copy(pkt.Header.Prepend(len(pkt.TransportHeader)), pkt.TransportHeader); n != len(pkt.TransportHeader) { - panic(fmt.Sprintf("copied %d bytes, expected %d", n, len(pkt.TransportHeader))) - } - if n := copy(pkt.Header.Prepend(len(pkt.NetworkHeader)), pkt.NetworkHeader); n != len(pkt.NetworkHeader) { - panic(fmt.Sprintf("copied %d bytes, expected %d", n, len(pkt.NetworkHeader))) - } - // WritePacket takes ownership of pkt, calculate numBytes first. - numBytes := pkt.Header.UsedLength() + pkt.Data.Size() + // pkt may have set its header and may not have enough headroom for link-layer + // header for the other link to prepend. Here we create a new packet to + // forward. + fwdPkt := NewPacketBuffer(PacketBufferOptions{ + ReserveHeaderBytes: int(n.linkEP.MaxHeaderLength()), + Data: buffer.NewVectorisedView(pkt.Size(), pkt.Views()), + }) + + // WritePacket takes ownership of fwdPkt, calculate numBytes first. + numBytes := fwdPkt.Size() - if err := n.linkEP.WritePacket(r, nil /* gso */, protocol, pkt); err != nil { + if err := n.linkEP.WritePacket(r, nil /* gso */, protocol, fwdPkt); err != nil { r.Stats().IP.OutgoingPacketErrors.Increment() return } @@ -1369,11 +682,11 @@ func (n *NIC) forwardPacket(r *Route, protocol tcpip.NetworkProtocolNumber, pkt // DeliverTransportPacket delivers the packets to the appropriate transport // protocol endpoint. -func (n *NIC) DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolNumber, pkt *PacketBuffer) { +func (n *NIC) DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolNumber, pkt *PacketBuffer) TransportPacketDisposition { state, ok := n.stack.transportProtocols[protocol] if !ok { n.stack.stats.UnknownProtocolRcvdPackets.Increment() - return + return TransportPacketProtocolUnreachable } transProto := state.proto @@ -1383,55 +696,58 @@ func (n *NIC) DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolN // validly formed. n.stack.demux.deliverRawPacket(r, protocol, pkt) - // TransportHeader is nil only when pkt is an ICMP packet or was reassembled + // TransportHeader is empty only when pkt is an ICMP packet or was reassembled // from fragments. - if pkt.TransportHeader == nil { + if pkt.TransportHeader().View().IsEmpty() { // TODO(gvisor.dev/issue/170): ICMP packets don't have their TransportHeader // fields set yet, parse it here. See icmp/protocol.go:protocol.Parse for a // full explanation. if protocol == header.ICMPv4ProtocolNumber || protocol == header.ICMPv6ProtocolNumber { // ICMP packets may be longer, but until icmp.Parse is implemented, here // we parse it using the minimum size. - transHeader, ok := pkt.Data.PullUp(transProto.MinimumPacketSize()) - if !ok { + if _, ok := pkt.TransportHeader().Consume(transProto.MinimumPacketSize()); !ok { n.stack.stats.MalformedRcvdPackets.Increment() - return + // We consider a malformed transport packet handled because there is + // nothing the caller can do. + return TransportPacketHandled } - pkt.TransportHeader = transHeader - pkt.Data.TrimFront(len(pkt.TransportHeader)) - } else { - // This is either a bad packet or was re-assembled from fragments. - transProto.Parse(pkt) + } else if !transProto.Parse(pkt) { + n.stack.stats.MalformedRcvdPackets.Increment() + return TransportPacketHandled } } - if len(pkt.TransportHeader) < transProto.MinimumPacketSize() { - n.stack.stats.MalformedRcvdPackets.Increment() - return - } - - srcPort, dstPort, err := transProto.ParsePorts(pkt.TransportHeader) + srcPort, dstPort, err := transProto.ParsePorts(pkt.TransportHeader().View()) if err != nil { n.stack.stats.MalformedRcvdPackets.Increment() - return + return TransportPacketHandled } id := TransportEndpointID{dstPort, r.LocalAddress, srcPort, r.RemoteAddress} if n.stack.demux.deliverPacket(r, protocol, pkt, id) { - return + return TransportPacketHandled } // Try to deliver to per-stack default handler. if state.defaultHandler != nil { if state.defaultHandler(r, id, pkt) { - return + return TransportPacketHandled } } - // We could not find an appropriate destination for this packet, so - // deliver it to the global handler. - if !transProto.HandleUnknownDestinationPacket(r, id, pkt) { + // We could not find an appropriate destination for this packet so + // give the protocol specific error handler a chance to handle it. + // If it doesn't handle it then we should do so. + switch res := transProto.HandleUnknownDestinationPacket(r, id, pkt); res { + case UnknownDestinationPacketMalformed: n.stack.stats.MalformedRcvdPackets.Increment() + return TransportPacketHandled + case UnknownDestinationPacketUnhandled: + return TransportPacketDestinationPortUnreachable + case UnknownDestinationPacketHandled: + return TransportPacketHandled + default: + panic(fmt.Sprintf("unrecognized result from HandleUnknownDestinationPacket = %d", res)) } } @@ -1464,96 +780,23 @@ func (n *NIC) DeliverTransportControlPacket(local, remote tcpip.Address, net tcp } } -// ID returns the identifier of n. +// ID implements NetworkInterface. func (n *NIC) ID() tcpip.NICID { return n.id } -// Name returns the name of n. +// Name implements NetworkInterface. func (n *NIC) Name() string { return n.name } -// Stack returns the instance of the Stack that owns this NIC. -func (n *NIC) Stack() *Stack { - return n.stack -} - -// LinkEndpoint returns the link endpoint of n. +// LinkEndpoint implements NetworkInterface. func (n *NIC) LinkEndpoint() LinkEndpoint { return n.linkEP } -// isAddrTentative returns true if addr is tentative on n. -// -// Note that if addr is not associated with n, then this function will return -// false. It will only return true if the address is associated with the NIC -// AND it is tentative. -func (n *NIC) isAddrTentative(addr tcpip.Address) bool { - n.mu.RLock() - defer n.mu.RUnlock() - - ref, ok := n.mu.endpoints[NetworkEndpointID{addr}] - if !ok { - return false - } - - return ref.getKind() == permanentTentative -} - -// dupTentativeAddrDetected attempts to inform n that a tentative addr is a -// duplicate on a link. -// -// dupTentativeAddrDetected will remove the tentative address if it exists. If -// the address was generated via SLAAC, an attempt will be made to generate a -// new address. -func (n *NIC) dupTentativeAddrDetected(addr tcpip.Address) *tcpip.Error { - n.mu.Lock() - defer n.mu.Unlock() - - ref, ok := n.mu.endpoints[NetworkEndpointID{addr}] - if !ok { - return tcpip.ErrBadAddress - } - - if ref.getKind() != permanentTentative { - return tcpip.ErrInvalidEndpointState - } - - // If the address is a SLAAC address, do not invalidate its SLAAC prefix as a - // new address will be generated for it. - if err := n.removePermanentIPv6EndpointLocked(ref, false /* allowSLAACInvalidation */); err != nil { - return err - } - - prefix := ref.addrWithPrefix().Subnet() - - switch ref.configType { - case slaac: - n.mu.ndp.regenerateSLAACAddr(prefix) - case slaacTemp: - // Do not reset the generation attempts counter for the prefix as the - // temporary address is being regenerated in response to a DAD conflict. - n.mu.ndp.regenerateTempSLAACAddr(prefix, false /* resetGenAttempts */) - } - - return nil -} - -// setNDPConfigs sets the NDP configurations for n. -// -// Note, if c contains invalid NDP configuration values, it will be fixed to -// use default values for the erroneous values. -func (n *NIC) setNDPConfigs(c NDPConfigurations) { - c.validate() - - n.mu.Lock() - n.mu.ndp.configs = c - n.mu.Unlock() -} - -// NUDConfigs gets the NUD configurations for n. -func (n *NIC) NUDConfigs() (NUDConfigurations, *tcpip.Error) { +// nudConfigs gets the NUD configurations for n. +func (n *NIC) nudConfigs() (NUDConfigurations, *tcpip.Error) { if n.neigh == nil { return NUDConfigurations{}, tcpip.ErrNotSupported } @@ -1573,49 +816,6 @@ func (n *NIC) setNUDConfigs(c NUDConfigurations) *tcpip.Error { return nil } -// handleNDPRA handles an NDP Router Advertisement message that arrived on n. -func (n *NIC) handleNDPRA(ip tcpip.Address, ra header.NDPRouterAdvert) { - n.mu.Lock() - defer n.mu.Unlock() - - n.mu.ndp.handleRA(ip, ra) -} - -type networkEndpointKind int32 - -const ( - // A permanentTentative endpoint is a permanent address that is not yet - // considered to be fully bound to an interface in the traditional - // sense. That is, the address is associated with a NIC, but packets - // destined to the address MUST NOT be accepted and MUST be silently - // dropped, and the address MUST NOT be used as a source address for - // outgoing packets. For IPv6, addresses will be of this kind until - // NDP's Duplicate Address Detection has resolved, or be deleted if - // the process results in detecting a duplicate address. - permanentTentative networkEndpointKind = iota - - // A permanent endpoint is created by adding a permanent address (vs. a - // temporary one) to the NIC. Its reference count is biased by 1 to avoid - // removal when no route holds a reference to it. It is removed by explicitly - // removing the permanent address from the NIC. - permanent - - // An expired permanent endpoint is a permanent endpoint that had its address - // removed from the NIC, and it is waiting to be removed once no more routes - // hold a reference to it. This is achieved by decreasing its reference count - // by 1. If its address is re-added before the endpoint is removed, its type - // changes back to permanent and its reference count increases by 1 again. - permanentExpired - - // A temporary endpoint is created for spoofing outgoing packets, or when in - // promiscuous mode and accepting incoming packets that don't match any - // permanent endpoint. Its reference count is not biased by 1 and the - // endpoint is removed immediately when no more route holds a reference to - // it. A temporary endpoint can be promoted to permanent if its address - // is added permanently. - temporary -) - func (n *NIC) registerPacketEndpoint(netProto tcpip.NetworkProtocolNumber, ep PacketEndpoint) *tcpip.Error { n.mu.Lock() defer n.mu.Unlock() @@ -1646,147 +846,12 @@ func (n *NIC) unregisterPacketEndpoint(netProto tcpip.NetworkProtocolNumber, ep } } -type networkEndpointConfigType int32 - -const ( - // A statically configured endpoint is an address that was added by - // some user-specified action (adding an explicit address, joining a - // multicast group). - static networkEndpointConfigType = iota - - // A SLAAC configured endpoint is an IPv6 endpoint that was added by - // SLAAC as per RFC 4862 section 5.5.3. - slaac - - // A temporary SLAAC configured endpoint is an IPv6 endpoint that was added by - // SLAAC as per RFC 4941. Temporary SLAAC addresses are short-lived and are - // not expected to be valid (or preferred) forever; hence the term temporary. - slaacTemp -) - -type referencedNetworkEndpoint struct { - ep NetworkEndpoint - nic *NIC - protocol tcpip.NetworkProtocolNumber - - // linkCache is set if link address resolution is enabled for this - // protocol. Set to nil otherwise. - linkCache LinkAddressCache - - // refs is counting references held for this endpoint. When refs hits zero it - // triggers the automatic removal of the endpoint from the NIC. - refs int32 - - // networkEndpointKind must only be accessed using {get,set}Kind(). - kind networkEndpointKind - - // configType is the method that was used to configure this endpoint. - // This must never change except during endpoint creation and promotion to - // permanent. - configType networkEndpointConfigType - - // deprecated indicates whether or not the endpoint should be considered - // deprecated. That is, when deprecated is true, other endpoints that are not - // deprecated should be preferred. - deprecated bool -} - -func (r *referencedNetworkEndpoint) addrWithPrefix() tcpip.AddressWithPrefix { - return tcpip.AddressWithPrefix{ - Address: r.ep.ID().LocalAddress, - PrefixLen: r.ep.PrefixLen(), - } -} - -func (r *referencedNetworkEndpoint) getKind() networkEndpointKind { - return networkEndpointKind(atomic.LoadInt32((*int32)(&r.kind))) -} - -func (r *referencedNetworkEndpoint) setKind(kind networkEndpointKind) { - atomic.StoreInt32((*int32)(&r.kind), int32(kind)) -} - // isValidForOutgoing returns true if the endpoint can be used to send out a // packet. It requires the endpoint to not be marked expired (i.e., its address) // has been removed) unless the NIC is in spoofing mode, or temporary. -func (r *referencedNetworkEndpoint) isValidForOutgoing() bool { - r.nic.mu.RLock() - defer r.nic.mu.RUnlock() - - return r.isValidForOutgoingRLocked() -} - -// isValidForOutgoingRLocked is the same as isValidForOutgoing but requires -// r.nic.mu to be read locked. -func (r *referencedNetworkEndpoint) isValidForOutgoingRLocked() bool { - if !r.nic.mu.enabled { - return false - } - - return r.isAssignedRLocked(r.nic.mu.spoofing) -} - -// isAssignedRLocked returns true if r is considered to be assigned to the NIC. -// -// r.nic.mu must be read locked. -func (r *referencedNetworkEndpoint) isAssignedRLocked(spoofingOrPromiscuous bool) bool { - switch r.getKind() { - case permanentTentative: - return false - case permanentExpired: - return spoofingOrPromiscuous - default: - return true - } -} - -// expireLocked decrements the reference count and marks the permanent endpoint -// as expired. -func (r *referencedNetworkEndpoint) expireLocked() { - r.setKind(permanentExpired) - r.decRefLocked() -} - -// decRef decrements the ref count and cleans up the endpoint once it reaches -// zero. -func (r *referencedNetworkEndpoint) decRef() { - if atomic.AddInt32(&r.refs, -1) == 0 { - r.nic.removeEndpoint(r) - } -} - -// decRefLocked is the same as decRef but assumes that the NIC.mu mutex is -// locked. -func (r *referencedNetworkEndpoint) decRefLocked() { - if atomic.AddInt32(&r.refs, -1) == 0 { - r.nic.removeEndpointLocked(r) - } -} - -// incRef increments the ref count. It must only be called when the caller is -// known to be holding a reference to the endpoint, otherwise tryIncRef should -// be used. -func (r *referencedNetworkEndpoint) incRef() { - atomic.AddInt32(&r.refs, 1) -} - -// tryIncRef attempts to increment the ref count from n to n+1, but only if n is -// not zero. That is, it will increment the count if the endpoint is still -// alive, and do nothing if it has already been clean up. -func (r *referencedNetworkEndpoint) tryIncRef() bool { - for { - v := atomic.LoadInt32(&r.refs) - if v == 0 { - return false - } - - if atomic.CompareAndSwapInt32(&r.refs, v, v+1) { - return true - } - } -} - -// stack returns the Stack instance that owns the underlying endpoint. -func (r *referencedNetworkEndpoint) stack() *Stack { - return r.nic.stack +func (n *NIC) isValidForOutgoing(ep AssignableAddressEndpoint) bool { + n.mu.RLock() + spoofing := n.mu.spoofing + n.mu.RUnlock() + return n.Enabled() && ep.IsAssigned(spoofing) } diff --git a/pkg/tcpip/stack/nic_test.go b/pkg/tcpip/stack/nic_test.go index a70792b50..fdd49b77f 100644 --- a/pkg/tcpip/stack/nic_test.go +++ b/pkg/tcpip/stack/nic_test.go @@ -15,98 +15,40 @@ package stack import ( - "math" "testing" - "time" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" ) -var _ LinkEndpoint = (*testLinkEndpoint)(nil) +var _ AddressableEndpoint = (*testIPv6Endpoint)(nil) +var _ NetworkEndpoint = (*testIPv6Endpoint)(nil) +var _ NDPEndpoint = (*testIPv6Endpoint)(nil) -// A LinkEndpoint that throws away outgoing packets. +// An IPv6 NetworkEndpoint that throws away outgoing packets. // -// We use this instead of the channel endpoint as the channel package depends on +// We use this instead of ipv6.endpoint because the ipv6 package depends on // the stack package which this test lives in, causing a cyclic dependency. -type testLinkEndpoint struct { - dispatcher NetworkDispatcher -} - -// Attach implements LinkEndpoint.Attach. -func (e *testLinkEndpoint) Attach(dispatcher NetworkDispatcher) { - e.dispatcher = dispatcher -} - -// IsAttached implements LinkEndpoint.IsAttached. -func (e *testLinkEndpoint) IsAttached() bool { - return e.dispatcher != nil -} - -// MTU implements LinkEndpoint.MTU. -func (*testLinkEndpoint) MTU() uint32 { - return math.MaxUint16 -} - -// Capabilities implements LinkEndpoint.Capabilities. -func (*testLinkEndpoint) Capabilities() LinkEndpointCapabilities { - return CapabilityResolutionRequired -} +type testIPv6Endpoint struct { + AddressableEndpointState -// MaxHeaderLength implements LinkEndpoint.MaxHeaderLength. -func (*testLinkEndpoint) MaxHeaderLength() uint16 { - return 0 -} + nicID tcpip.NICID + linkEP LinkEndpoint + protocol *testIPv6Protocol -// LinkAddress returns the link address of this endpoint. -func (*testLinkEndpoint) LinkAddress() tcpip.LinkAddress { - return "" + invalidatedRtr tcpip.Address } -// Wait implements LinkEndpoint.Wait. -func (*testLinkEndpoint) Wait() {} - -// WritePacket implements LinkEndpoint.WritePacket. -func (e *testLinkEndpoint) WritePacket(*Route, *GSO, tcpip.NetworkProtocolNumber, *PacketBuffer) *tcpip.Error { +func (*testIPv6Endpoint) Enable() *tcpip.Error { return nil } -// WritePackets implements LinkEndpoint.WritePackets. -func (e *testLinkEndpoint) WritePackets(*Route, *GSO, PacketBufferList, tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { - // Our tests don't use this so we don't support it. - return 0, tcpip.ErrNotSupported +func (*testIPv6Endpoint) Enabled() bool { + return true } -// WriteRawPacket implements LinkEndpoint.WriteRawPacket. -func (e *testLinkEndpoint) WriteRawPacket(buffer.VectorisedView) *tcpip.Error { - // Our tests don't use this so we don't support it. - return tcpip.ErrNotSupported -} - -// ARPHardwareType implements stack.LinkEndpoint.ARPHardwareType. -func (*testLinkEndpoint) ARPHardwareType() header.ARPHardwareType { - panic("not implemented") -} - -// AddHeader implements stack.LinkEndpoint.AddHeader. -func (e *testLinkEndpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) { - panic("not implemented") -} - -var _ NetworkEndpoint = (*testIPv6Endpoint)(nil) - -// An IPv6 NetworkEndpoint that throws away outgoing packets. -// -// We use this instead of ipv6.endpoint because the ipv6 package depends on -// the stack package which this test lives in, causing a cyclic dependency. -type testIPv6Endpoint struct { - nicID tcpip.NICID - id NetworkEndpointID - prefixLen int - linkEP LinkEndpoint - protocol *testIPv6Protocol -} +func (*testIPv6Endpoint) Disable() {} // DefaultTTL implements NetworkEndpoint.DefaultTTL. func (*testIPv6Endpoint) DefaultTTL() uint8 { @@ -118,11 +60,6 @@ func (e *testIPv6Endpoint) MTU() uint32 { return e.linkEP.MTU() - header.IPv6MinimumSize } -// Capabilities implements NetworkEndpoint.Capabilities. -func (e *testIPv6Endpoint) Capabilities() LinkEndpointCapabilities { - return e.linkEP.Capabilities() -} - // MaxHeaderLength implements NetworkEndpoint.MaxHeaderLength. func (e *testIPv6Endpoint) MaxHeaderLength() uint16 { return e.linkEP.MaxHeaderLength() + header.IPv6MinimumSize @@ -146,33 +83,24 @@ func (*testIPv6Endpoint) WriteHeaderIncludedPacket(*Route, *PacketBuffer) *tcpip return tcpip.ErrNotSupported } -// ID implements NetworkEndpoint.ID. -func (e *testIPv6Endpoint) ID() *NetworkEndpointID { - return &e.id -} - -// PrefixLen implements NetworkEndpoint.PrefixLen. -func (e *testIPv6Endpoint) PrefixLen() int { - return e.prefixLen -} - -// NICID implements NetworkEndpoint.NICID. -func (e *testIPv6Endpoint) NICID() tcpip.NICID { - return e.nicID -} - // HandlePacket implements NetworkEndpoint.HandlePacket. func (*testIPv6Endpoint) HandlePacket(*Route, *PacketBuffer) { } // Close implements NetworkEndpoint.Close. -func (*testIPv6Endpoint) Close() {} +func (e *testIPv6Endpoint) Close() { + e.AddressableEndpointState.Cleanup() +} // NetworkProtocolNumber implements NetworkEndpoint.NetworkProtocolNumber. func (*testIPv6Endpoint) NetworkProtocolNumber() tcpip.NetworkProtocolNumber { return header.IPv6ProtocolNumber } +func (e *testIPv6Endpoint) InvalidateDefaultRouter(rtr tcpip.Address) { + e.invalidatedRtr = rtr +} + var _ NetworkProtocol = (*testIPv6Protocol)(nil) // An IPv6 NetworkProtocol that supports the bare minimum to make a stack @@ -204,23 +132,23 @@ func (*testIPv6Protocol) ParseAddresses(v buffer.View) (src, dst tcpip.Address) } // NewEndpoint implements NetworkProtocol.NewEndpoint. -func (p *testIPv6Protocol) NewEndpoint(nicID tcpip.NICID, addrWithPrefix tcpip.AddressWithPrefix, _ LinkAddressCache, _ TransportDispatcher, linkEP LinkEndpoint, _ *Stack) (NetworkEndpoint, *tcpip.Error) { - return &testIPv6Endpoint{ - nicID: nicID, - id: NetworkEndpointID{LocalAddress: addrWithPrefix.Address}, - prefixLen: addrWithPrefix.PrefixLen, - linkEP: linkEP, - protocol: p, - }, nil +func (p *testIPv6Protocol) NewEndpoint(nic NetworkInterface, _ LinkAddressCache, _ NUDHandler, _ TransportDispatcher) NetworkEndpoint { + e := &testIPv6Endpoint{ + nicID: nic.ID(), + linkEP: nic.LinkEndpoint(), + protocol: p, + } + e.AddressableEndpointState.Init(e) + return e } // SetOption implements NetworkProtocol.SetOption. -func (*testIPv6Protocol) SetOption(interface{}) *tcpip.Error { +func (*testIPv6Protocol) SetOption(tcpip.SettableNetworkProtocolOption) *tcpip.Error { return nil } // Option implements NetworkProtocol.Option. -func (*testIPv6Protocol) Option(interface{}) *tcpip.Error { +func (*testIPv6Protocol) Option(tcpip.GettableNetworkProtocolOption) *tcpip.Error { return nil } @@ -255,38 +183,6 @@ func (*testIPv6Protocol) ResolveStaticAddress(addr tcpip.Address) (tcpip.LinkAdd return "", false } -// Test the race condition where a NIC is removed and an RS timer fires at the -// same time. -func TestRemoveNICWhileHandlingRSTimer(t *testing.T) { - const ( - nicID = 1 - - maxRtrSolicitations = 5 - ) - - e := testLinkEndpoint{} - s := New(Options{ - NetworkProtocols: []NetworkProtocol{&testIPv6Protocol{}}, - NDPConfigs: NDPConfigurations{ - MaxRtrSolicitations: maxRtrSolicitations, - RtrSolicitationInterval: minimumRtrSolicitationInterval, - }, - }) - - if err := s.CreateNIC(nicID, &e); err != nil { - t.Fatalf("s.CreateNIC(%d, _) = %s", nicID, err) - } - - s.mu.Lock() - // Wait for the router solicitation timer to fire and block trying to obtain - // the stack lock when doing link address resolution. - time.Sleep(minimumRtrSolicitationInterval * 2) - if err := s.removeNICLocked(nicID); err != nil { - t.Fatalf("s.removeNICLocked(%d) = %s", nicID, err) - } - s.mu.Unlock() -} - func TestDisabledRxStatsWhenNICDisabled(t *testing.T) { // When the NIC is disabled, the only field that matters is the stats field. // This test is limited to stats counter checks. @@ -311,7 +207,9 @@ func TestDisabledRxStatsWhenNICDisabled(t *testing.T) { t.FailNow() } - nic.DeliverNetworkPacket("", "", 0, &PacketBuffer{Data: buffer.View([]byte{1, 2, 3, 4}).ToVectorisedView()}) + nic.DeliverNetworkPacket("", "", 0, NewPacketBuffer(PacketBufferOptions{ + Data: buffer.View([]byte{1, 2, 3, 4}).ToVectorisedView(), + })) if got := nic.stats.DisabledRx.Packets.Value(); got != 1 { t.Errorf("got DisabledRx.Packets = %d, want = 1", got) diff --git a/pkg/tcpip/stack/nud.go b/pkg/tcpip/stack/nud.go index f848d50ad..e1ec15487 100644 --- a/pkg/tcpip/stack/nud.go +++ b/pkg/tcpip/stack/nud.go @@ -177,7 +177,7 @@ type NUDHandler interface { // Neighbor Solicitation for ARP or NDP, respectively). Validation of the // probe needs to be performed before calling this function since the // Neighbor Cache doesn't have access to view the NIC's assigned addresses. - HandleProbe(remoteAddr, localAddr tcpip.Address, protocol tcpip.NetworkProtocolNumber, remoteLinkAddr tcpip.LinkAddress) + HandleProbe(remoteAddr, localAddr tcpip.Address, protocol tcpip.NetworkProtocolNumber, remoteLinkAddr tcpip.LinkAddress, linkRes LinkAddressResolver) // HandleConfirmation processes an incoming neighbor confirmation (e.g. ARP // reply or Neighbor Advertisement for ARP or NDP, respectively). diff --git a/pkg/tcpip/stack/nud_test.go b/pkg/tcpip/stack/nud_test.go index 2494ee610..8cffb9fc6 100644 --- a/pkg/tcpip/stack/nud_test.go +++ b/pkg/tcpip/stack/nud_test.go @@ -60,7 +60,8 @@ func TestSetNUDConfigurationFailsForBadNICID(t *testing.T) { // A neighbor cache is required to store NUDConfigurations. The networking // stack will only allocate neighbor caches if a protocol providing link // address resolution is specified (e.g. ARP or IPv6). - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol}, + UseNeighborCache: true, }) // No NIC with ID 1 yet. @@ -84,7 +85,8 @@ func TestNUDConfigurationFailsForNotSupported(t *testing.T) { e.LinkEPCapabilities |= stack.CapabilityResolutionRequired s := stack.New(stack.Options{ - NUDConfigs: stack.DefaultNUDConfigurations(), + NUDConfigs: stack.DefaultNUDConfigurations(), + UseNeighborCache: true, }) if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) @@ -108,7 +110,8 @@ func TestSetNUDConfigurationFailsForNotSupported(t *testing.T) { e.LinkEPCapabilities |= stack.CapabilityResolutionRequired s := stack.New(stack.Options{ - NUDConfigs: stack.DefaultNUDConfigurations(), + NUDConfigs: stack.DefaultNUDConfigurations(), + UseNeighborCache: true, }) if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) @@ -134,8 +137,9 @@ func TestDefaultNUDConfigurations(t *testing.T) { // A neighbor cache is required to store NUDConfigurations. The networking // stack will only allocate neighbor caches if a protocol providing link // address resolution is specified (e.g. ARP or IPv6). - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol}, NUDConfigs: stack.DefaultNUDConfigurations(), + UseNeighborCache: true, }) if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) @@ -188,8 +192,9 @@ func TestNUDConfigurationsBaseReachableTime(t *testing.T) { // A neighbor cache is required to store NUDConfigurations. The // networking stack will only allocate neighbor caches if a protocol // providing link address resolution is specified (e.g. ARP or IPv6). - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol}, NUDConfigs: c, + UseNeighborCache: true, }) if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) @@ -244,8 +249,9 @@ func TestNUDConfigurationsMinRandomFactor(t *testing.T) { // A neighbor cache is required to store NUDConfigurations. The // networking stack will only allocate neighbor caches if a protocol // providing link address resolution is specified (e.g. ARP or IPv6). - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol}, NUDConfigs: c, + UseNeighborCache: true, }) if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) @@ -323,8 +329,9 @@ func TestNUDConfigurationsMaxRandomFactor(t *testing.T) { // A neighbor cache is required to store NUDConfigurations. The // networking stack will only allocate neighbor caches if a protocol // providing link address resolution is specified (e.g. ARP or IPv6). - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol}, NUDConfigs: c, + UseNeighborCache: true, }) if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) @@ -384,8 +391,9 @@ func TestNUDConfigurationsRetransmitTimer(t *testing.T) { // A neighbor cache is required to store NUDConfigurations. The // networking stack will only allocate neighbor caches if a protocol // providing link address resolution is specified (e.g. ARP or IPv6). - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol}, NUDConfigs: c, + UseNeighborCache: true, }) if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) @@ -435,8 +443,9 @@ func TestNUDConfigurationsDelayFirstProbeTime(t *testing.T) { // A neighbor cache is required to store NUDConfigurations. The // networking stack will only allocate neighbor caches if a protocol // providing link address resolution is specified (e.g. ARP or IPv6). - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol}, NUDConfigs: c, + UseNeighborCache: true, }) if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) @@ -486,8 +495,9 @@ func TestNUDConfigurationsMaxMulticastProbes(t *testing.T) { // A neighbor cache is required to store NUDConfigurations. The // networking stack will only allocate neighbor caches if a protocol // providing link address resolution is specified (e.g. ARP or IPv6). - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol}, NUDConfigs: c, + UseNeighborCache: true, }) if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) @@ -537,8 +547,9 @@ func TestNUDConfigurationsMaxUnicastProbes(t *testing.T) { // A neighbor cache is required to store NUDConfigurations. The // networking stack will only allocate neighbor caches if a protocol // providing link address resolution is specified (e.g. ARP or IPv6). - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol}, NUDConfigs: c, + UseNeighborCache: true, }) if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) @@ -588,8 +599,9 @@ func TestNUDConfigurationsUnreachableTime(t *testing.T) { // A neighbor cache is required to store NUDConfigurations. The // networking stack will only allocate neighbor caches if a protocol // providing link address resolution is specified (e.g. ARP or IPv6). - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol}, NUDConfigs: c, + UseNeighborCache: true, }) if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) diff --git a/pkg/tcpip/stack/packet_buffer.go b/pkg/tcpip/stack/packet_buffer.go index 5d6865e35..a7d9d59fa 100644 --- a/pkg/tcpip/stack/packet_buffer.go +++ b/pkg/tcpip/stack/packet_buffer.go @@ -14,16 +14,43 @@ package stack import ( + "fmt" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" ) +type headerType int + +const ( + linkHeader headerType = iota + networkHeader + transportHeader + numHeaderType +) + +// PacketBufferOptions specifies options for PacketBuffer creation. +type PacketBufferOptions struct { + // ReserveHeaderBytes is the number of bytes to reserve for headers. Total + // number of bytes pushed onto the headers must not exceed this value. + ReserveHeaderBytes int + + // Data is the initial unparsed data for the new packet. If set, it will be + // owned by the new packet. + Data buffer.VectorisedView +} + // A PacketBuffer contains all the data of a network packet. // // As a PacketBuffer traverses up the stack, it may be necessary to pass it to -// multiple endpoints. Clone() should be called in such cases so that -// modifications to the Data field do not affect other copies. +// multiple endpoints. +// +// The whole packet is expected to be a series of bytes in the following order: +// LinkHeader, NetworkHeader, TransportHeader, and Data. Any of them can be +// empty. Use of PacketBuffer in any other order is unsupported. +// +// PacketBuffer must be created with NewPacketBuffer. type PacketBuffer struct { _ sync.NoCopy @@ -31,36 +58,38 @@ type PacketBuffer struct { // PacketBuffers. PacketBufferEntry - // Data holds the payload of the packet. For inbound packets, it also - // holds the headers, which are consumed as the packet moves up the - // stack. Headers are guaranteed not to be split across views. + // Data holds the payload of the packet. + // + // For inbound packets, Data is initially the whole packet. Then gets moved to + // headers via PacketHeader.Consume, when the packet is being parsed. + // + // For outbound packets, Data is the innermost layer, defined by the protocol. + // Headers are pushed in front of it via PacketHeader.Push. // - // The bytes backing Data are immutable, but Data itself may be trimmed - // or otherwise modified. + // The bytes backing Data are immutable, a.k.a. users shouldn't write to its + // backing storage. Data buffer.VectorisedView - // Header holds the headers of outbound packets. As a packet is passed - // down the stack, each layer adds to Header. Note that forwarded - // packets don't populate Headers on their way out -- their headers and - // payload are never parsed out and remain in Data. - // - // TODO(gvisor.dev/issue/170): Forwarded packets don't currently - // populate Header, but should. This will be doable once early parsing - // (https://github.com/google/gvisor/pull/1995) is supported. - Header buffer.Prependable + // headers stores metadata about each header. + headers [numHeaderType]headerInfo - // These fields are used by both inbound and outbound packets. They - // typically overlap with the Data and Header fields. - // - // The bytes backing these views are immutable. Each field may be nil - // if either it has not been set yet or no such header exists (e.g. - // packets sent via loopback may not have a link header). + // header is the internal storage for outbound packets. Headers will be pushed + // (prepended) on this storage as the packet is being constructed. // - // These fields may be Views into other slices (either Data or Header). - // SR dosen't support this, so deep copies are necessary in some cases. - LinkHeader buffer.View - NetworkHeader buffer.View - TransportHeader buffer.View + // TODO(gvisor.dev/issue/2404): Switch to an implementation that header and + // data are held in the same underlying buffer storage. + header buffer.Prependable + + // NetworkProtocolNumber is only valid when NetworkHeader().View().IsEmpty() + // returns false. + // TODO(gvisor.dev/issue/3574): Remove the separately passed protocol + // numbers in registration APIs that take a PacketBuffer. + NetworkProtocolNumber tcpip.NetworkProtocolNumber + + // TransportProtocol is only valid if it is non zero. + // TODO(gvisor.dev/issue/3810): This and the network protocol number should + // be moved into the headerinfo. This should resolve the validity issue. + TransportProtocolNumber tcpip.TransportProtocolNumber // Hash is the transport layer hash of this packet. A value of zero // indicates no valid hash has been set. @@ -72,9 +101,8 @@ type PacketBuffer struct { // The following fields are only set by the qdisc layer when the packet // is added to a queue. - EgressRoute *Route - GSOOptions *GSO - NetworkProtocolNumber tcpip.NetworkProtocolNumber + EgressRoute *Route + GSOOptions *GSO // NatDone indicates if the packet has been manipulated as per NAT // iptables rule. @@ -85,25 +113,194 @@ type PacketBuffer struct { PktType tcpip.PacketType } -// Clone makes a copy of pk. It clones the Data field, which creates a new -// VectorisedView but does not deep copy the underlying bytes. -// -// Clone also does not deep copy any of its other fields. +// NewPacketBuffer creates a new PacketBuffer with opts. +func NewPacketBuffer(opts PacketBufferOptions) *PacketBuffer { + pk := &PacketBuffer{ + Data: opts.Data, + } + if opts.ReserveHeaderBytes != 0 { + pk.header = buffer.NewPrependable(opts.ReserveHeaderBytes) + } + return pk +} + +// ReservedHeaderBytes returns the number of bytes initially reserved for +// headers. +func (pk *PacketBuffer) ReservedHeaderBytes() int { + return pk.header.UsedLength() + pk.header.AvailableLength() +} + +// AvailableHeaderBytes returns the number of bytes currently available for +// headers. This is relevant to PacketHeader.Push method only. +func (pk *PacketBuffer) AvailableHeaderBytes() int { + return pk.header.AvailableLength() +} + +// LinkHeader returns the handle to link-layer header. +func (pk *PacketBuffer) LinkHeader() PacketHeader { + return PacketHeader{ + pk: pk, + typ: linkHeader, + } +} + +// NetworkHeader returns the handle to network-layer header. +func (pk *PacketBuffer) NetworkHeader() PacketHeader { + return PacketHeader{ + pk: pk, + typ: networkHeader, + } +} + +// TransportHeader returns the handle to transport-layer header. +func (pk *PacketBuffer) TransportHeader() PacketHeader { + return PacketHeader{ + pk: pk, + typ: transportHeader, + } +} + +// HeaderSize returns the total size of all headers in bytes. +func (pk *PacketBuffer) HeaderSize() int { + // Note for inbound packets (Consume called), headers are not stored in + // pk.header. Thus, calculation of size of each header is needed. + var size int + for i := range pk.headers { + size += len(pk.headers[i].buf) + } + return size +} + +// Size returns the size of packet in bytes. +func (pk *PacketBuffer) Size() int { + return pk.HeaderSize() + pk.Data.Size() +} + +// Views returns the underlying storage of the whole packet. +func (pk *PacketBuffer) Views() []buffer.View { + // Optimization for outbound packets that headers are in pk.header. + useHeader := true + for i := range pk.headers { + if !canUseHeader(&pk.headers[i]) { + useHeader = false + break + } + } + + dataViews := pk.Data.Views() + + var vs []buffer.View + if useHeader { + vs = make([]buffer.View, 0, 1+len(dataViews)) + vs = append(vs, pk.header.View()) + } else { + vs = make([]buffer.View, 0, len(pk.headers)+len(dataViews)) + for i := range pk.headers { + if v := pk.headers[i].buf; len(v) > 0 { + vs = append(vs, v) + } + } + } + return append(vs, dataViews...) +} + +func canUseHeader(h *headerInfo) bool { + // h.offset will be negative if the header was pushed in to prependable + // portion, or doesn't matter when it's empty. + return len(h.buf) == 0 || h.offset < 0 +} + +func (pk *PacketBuffer) push(typ headerType, size int) buffer.View { + h := &pk.headers[typ] + if h.buf != nil { + panic(fmt.Sprintf("push must not be called twice: type %s", typ)) + } + h.buf = buffer.View(pk.header.Prepend(size)) + h.offset = -pk.header.UsedLength() + return h.buf +} + +func (pk *PacketBuffer) consume(typ headerType, size int) (v buffer.View, consumed bool) { + h := &pk.headers[typ] + if h.buf != nil { + panic(fmt.Sprintf("consume must not be called twice: type %s", typ)) + } + v, ok := pk.Data.PullUp(size) + if !ok { + return + } + pk.Data.TrimFront(size) + h.buf = v + return h.buf, true +} + +// Clone makes a shallow copy of pk. // -// FIXME(b/153685824): Data gets copied but not other header references. +// Clone should be called in such cases so that no modifications is done to +// underlying packet payload. func (pk *PacketBuffer) Clone() *PacketBuffer { - return &PacketBuffer{ - PacketBufferEntry: pk.PacketBufferEntry, - Data: pk.Data.Clone(nil), - Header: pk.Header, - LinkHeader: pk.LinkHeader, - NetworkHeader: pk.NetworkHeader, - TransportHeader: pk.TransportHeader, - Hash: pk.Hash, - Owner: pk.Owner, - EgressRoute: pk.EgressRoute, - GSOOptions: pk.GSOOptions, - NetworkProtocolNumber: pk.NetworkProtocolNumber, - NatDone: pk.NatDone, + newPk := &PacketBuffer{ + PacketBufferEntry: pk.PacketBufferEntry, + Data: pk.Data.Clone(nil), + headers: pk.headers, + header: pk.header, + Hash: pk.Hash, + Owner: pk.Owner, + EgressRoute: pk.EgressRoute, + GSOOptions: pk.GSOOptions, + NetworkProtocolNumber: pk.NetworkProtocolNumber, + NatDone: pk.NatDone, + TransportProtocolNumber: pk.TransportProtocolNumber, + } + return newPk +} + +// headerInfo stores metadata about a header in a packet. +type headerInfo struct { + // buf is the memorized slice for both prepended and consumed header. + // When header is prepended, buf serves as memorized value, which is a slice + // of pk.header. When header is consumed, buf is the slice pulled out from + // pk.Data, which is the only place to hold this header. + buf buffer.View + + // offset will be a negative number denoting the offset where this header is + // from the end of pk.header, if it is prepended. Otherwise, zero. + offset int +} + +// PacketHeader is a handle object to a header in the underlying packet. +type PacketHeader struct { + pk *PacketBuffer + typ headerType +} + +// View returns the underlying storage of h. +func (h PacketHeader) View() buffer.View { + return h.pk.headers[h.typ].buf +} + +// Push pushes size bytes in the front of its residing packet, and returns the +// backing storage. Callers may only call one of Push or Consume once on each +// header in the lifetime of the underlying packet. +func (h PacketHeader) Push(size int) buffer.View { + return h.pk.push(h.typ, size) +} + +// Consume moves the first size bytes of the unparsed data portion in the packet +// to h, and returns the backing storage. In the case of data is shorter than +// size, consumed will be false, and the state of h will not be affected. +// Callers may only call one of Push or Consume once on each header in the +// lifetime of the underlying packet. +func (h PacketHeader) Consume(size int) (v buffer.View, consumed bool) { + return h.pk.consume(h.typ, size) +} + +// PayloadSince returns packet payload starting from and including a particular +// header. This method isn't optimized and should be used in test only. +func PayloadSince(h PacketHeader) buffer.View { + var v buffer.View + for _, hinfo := range h.pk.headers[h.typ:] { + v = append(v, hinfo.buf...) } + return append(v, h.pk.Data.ToView()...) } diff --git a/pkg/tcpip/stack/packet_buffer_test.go b/pkg/tcpip/stack/packet_buffer_test.go new file mode 100644 index 000000000..c6fa8da5f --- /dev/null +++ b/pkg/tcpip/stack/packet_buffer_test.go @@ -0,0 +1,397 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at // +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package stack + +import ( + "bytes" + "testing" + + "gvisor.dev/gvisor/pkg/tcpip/buffer" +) + +func TestPacketHeaderPush(t *testing.T) { + for _, test := range []struct { + name string + reserved int + link []byte + network []byte + transport []byte + data []byte + }{ + { + name: "construct empty packet", + }, + { + name: "construct link header only packet", + reserved: 60, + link: makeView(10), + }, + { + name: "construct link and network header only packet", + reserved: 60, + link: makeView(10), + network: makeView(20), + }, + { + name: "construct header only packet", + reserved: 60, + link: makeView(10), + network: makeView(20), + transport: makeView(30), + }, + { + name: "construct data only packet", + data: makeView(40), + }, + { + name: "construct L3 packet", + reserved: 60, + network: makeView(20), + transport: makeView(30), + data: makeView(40), + }, + { + name: "construct L2 packet", + reserved: 60, + link: makeView(10), + network: makeView(20), + transport: makeView(30), + data: makeView(40), + }, + } { + t.Run(test.name, func(t *testing.T) { + pk := NewPacketBuffer(PacketBufferOptions{ + ReserveHeaderBytes: test.reserved, + // Make a copy of data to make sure our truth data won't be taint by + // PacketBuffer. + Data: buffer.NewViewFromBytes(test.data).ToVectorisedView(), + }) + + allHdrSize := len(test.link) + len(test.network) + len(test.transport) + + // Check the initial values for packet. + checkInitialPacketBuffer(t, pk, PacketBufferOptions{ + ReserveHeaderBytes: test.reserved, + Data: buffer.View(test.data).ToVectorisedView(), + }) + + // Push headers. + if v := test.transport; len(v) > 0 { + copy(pk.TransportHeader().Push(len(v)), v) + } + if v := test.network; len(v) > 0 { + copy(pk.NetworkHeader().Push(len(v)), v) + } + if v := test.link; len(v) > 0 { + copy(pk.LinkHeader().Push(len(v)), v) + } + + // Check the after values for packet. + if got, want := pk.ReservedHeaderBytes(), test.reserved; got != want { + t.Errorf("After pk.ReservedHeaderBytes() = %d, want %d", got, want) + } + if got, want := pk.AvailableHeaderBytes(), test.reserved-allHdrSize; got != want { + t.Errorf("After pk.AvailableHeaderBytes() = %d, want %d", got, want) + } + if got, want := pk.HeaderSize(), allHdrSize; got != want { + t.Errorf("After pk.HeaderSize() = %d, want %d", got, want) + } + if got, want := pk.Size(), allHdrSize+len(test.data); got != want { + t.Errorf("After pk.Size() = %d, want %d", got, want) + } + checkViewEqual(t, "After pk.Data.Views()", concatViews(pk.Data.Views()...), test.data) + checkViewEqual(t, "After pk.Views()", concatViews(pk.Views()...), + concatViews(test.link, test.network, test.transport, test.data)) + // Check the after values for each header. + checkPacketHeader(t, "After pk.LinkHeader", pk.LinkHeader(), test.link) + checkPacketHeader(t, "After pk.NetworkHeader", pk.NetworkHeader(), test.network) + checkPacketHeader(t, "After pk.TransportHeader", pk.TransportHeader(), test.transport) + // Check the after values for PayloadSince. + checkViewEqual(t, "After PayloadSince(LinkHeader)", + PayloadSince(pk.LinkHeader()), + concatViews(test.link, test.network, test.transport, test.data)) + checkViewEqual(t, "After PayloadSince(NetworkHeader)", + PayloadSince(pk.NetworkHeader()), + concatViews(test.network, test.transport, test.data)) + checkViewEqual(t, "After PayloadSince(TransportHeader)", + PayloadSince(pk.TransportHeader()), + concatViews(test.transport, test.data)) + }) + } +} + +func TestPacketHeaderConsume(t *testing.T) { + for _, test := range []struct { + name string + data []byte + link int + network int + transport int + }{ + { + name: "parse L2 packet", + data: concatViews(makeView(10), makeView(20), makeView(30), makeView(40)), + link: 10, + network: 20, + transport: 30, + }, + { + name: "parse L3 packet", + data: concatViews(makeView(20), makeView(30), makeView(40)), + network: 20, + transport: 30, + }, + } { + t.Run(test.name, func(t *testing.T) { + pk := NewPacketBuffer(PacketBufferOptions{ + // Make a copy of data to make sure our truth data won't be taint by + // PacketBuffer. + Data: buffer.NewViewFromBytes(test.data).ToVectorisedView(), + }) + + // Check the initial values for packet. + checkInitialPacketBuffer(t, pk, PacketBufferOptions{ + Data: buffer.View(test.data).ToVectorisedView(), + }) + + // Consume headers. + if size := test.link; size > 0 { + if _, ok := pk.LinkHeader().Consume(size); !ok { + t.Fatalf("pk.LinkHeader().Consume() = false, want true") + } + } + if size := test.network; size > 0 { + if _, ok := pk.NetworkHeader().Consume(size); !ok { + t.Fatalf("pk.NetworkHeader().Consume() = false, want true") + } + } + if size := test.transport; size > 0 { + if _, ok := pk.TransportHeader().Consume(size); !ok { + t.Fatalf("pk.TransportHeader().Consume() = false, want true") + } + } + + allHdrSize := test.link + test.network + test.transport + + // Check the after values for packet. + if got, want := pk.ReservedHeaderBytes(), 0; got != want { + t.Errorf("After pk.ReservedHeaderBytes() = %d, want %d", got, want) + } + if got, want := pk.AvailableHeaderBytes(), 0; got != want { + t.Errorf("After pk.AvailableHeaderBytes() = %d, want %d", got, want) + } + if got, want := pk.HeaderSize(), allHdrSize; got != want { + t.Errorf("After pk.HeaderSize() = %d, want %d", got, want) + } + if got, want := pk.Size(), len(test.data); got != want { + t.Errorf("After pk.Size() = %d, want %d", got, want) + } + // After state of pk. + var ( + link = test.data[:test.link] + network = test.data[test.link:][:test.network] + transport = test.data[test.link+test.network:][:test.transport] + payload = test.data[allHdrSize:] + ) + checkViewEqual(t, "After pk.Data.Views()", concatViews(pk.Data.Views()...), payload) + checkViewEqual(t, "After pk.Views()", concatViews(pk.Views()...), test.data) + // Check the after values for each header. + checkPacketHeader(t, "After pk.LinkHeader", pk.LinkHeader(), link) + checkPacketHeader(t, "After pk.NetworkHeader", pk.NetworkHeader(), network) + checkPacketHeader(t, "After pk.TransportHeader", pk.TransportHeader(), transport) + // Check the after values for PayloadSince. + checkViewEqual(t, "After PayloadSince(LinkHeader)", + PayloadSince(pk.LinkHeader()), + concatViews(link, network, transport, payload)) + checkViewEqual(t, "After PayloadSince(NetworkHeader)", + PayloadSince(pk.NetworkHeader()), + concatViews(network, transport, payload)) + checkViewEqual(t, "After PayloadSince(TransportHeader)", + PayloadSince(pk.TransportHeader()), + concatViews(transport, payload)) + }) + } +} + +func TestPacketHeaderConsumeDataTooShort(t *testing.T) { + data := makeView(10) + + pk := NewPacketBuffer(PacketBufferOptions{ + // Make a copy of data to make sure our truth data won't be taint by + // PacketBuffer. + Data: buffer.NewViewFromBytes(data).ToVectorisedView(), + }) + + // Consume should fail if pkt.Data is too short. + if _, ok := pk.LinkHeader().Consume(11); ok { + t.Fatalf("pk.LinkHeader().Consume() = _, true; want _, false") + } + if _, ok := pk.NetworkHeader().Consume(11); ok { + t.Fatalf("pk.NetworkHeader().Consume() = _, true; want _, false") + } + if _, ok := pk.TransportHeader().Consume(11); ok { + t.Fatalf("pk.TransportHeader().Consume() = _, true; want _, false") + } + + // Check packet should look the same as initial packet. + checkInitialPacketBuffer(t, pk, PacketBufferOptions{ + Data: buffer.View(data).ToVectorisedView(), + }) +} + +func TestPacketHeaderPushCalledAtMostOnce(t *testing.T) { + const headerSize = 10 + + pk := NewPacketBuffer(PacketBufferOptions{ + ReserveHeaderBytes: headerSize * int(numHeaderType), + }) + + for _, h := range []PacketHeader{ + pk.TransportHeader(), + pk.NetworkHeader(), + pk.LinkHeader(), + } { + t.Run("PushedTwice/"+h.typ.String(), func(t *testing.T) { + h.Push(headerSize) + + defer func() { recover() }() + h.Push(headerSize) + t.Fatal("Second push should have panicked") + }) + } +} + +func TestPacketHeaderConsumeCalledAtMostOnce(t *testing.T) { + const headerSize = 10 + + pk := NewPacketBuffer(PacketBufferOptions{ + Data: makeView(headerSize * int(numHeaderType)).ToVectorisedView(), + }) + + for _, h := range []PacketHeader{ + pk.LinkHeader(), + pk.NetworkHeader(), + pk.TransportHeader(), + } { + t.Run("ConsumedTwice/"+h.typ.String(), func(t *testing.T) { + if _, ok := h.Consume(headerSize); !ok { + t.Fatal("First consume should succeed") + } + + defer func() { recover() }() + h.Consume(headerSize) + t.Fatal("Second consume should have panicked") + }) + } +} + +func TestPacketHeaderPushThenConsumePanics(t *testing.T) { + const headerSize = 10 + + pk := NewPacketBuffer(PacketBufferOptions{ + ReserveHeaderBytes: headerSize * int(numHeaderType), + }) + + for _, h := range []PacketHeader{ + pk.TransportHeader(), + pk.NetworkHeader(), + pk.LinkHeader(), + } { + t.Run(h.typ.String(), func(t *testing.T) { + h.Push(headerSize) + + defer func() { recover() }() + h.Consume(headerSize) + t.Fatal("Consume should have panicked") + }) + } +} + +func TestPacketHeaderConsumeThenPushPanics(t *testing.T) { + const headerSize = 10 + + pk := NewPacketBuffer(PacketBufferOptions{ + Data: makeView(headerSize * int(numHeaderType)).ToVectorisedView(), + }) + + for _, h := range []PacketHeader{ + pk.LinkHeader(), + pk.NetworkHeader(), + pk.TransportHeader(), + } { + t.Run(h.typ.String(), func(t *testing.T) { + h.Consume(headerSize) + + defer func() { recover() }() + h.Push(headerSize) + t.Fatal("Push should have panicked") + }) + } +} + +func checkInitialPacketBuffer(t *testing.T, pk *PacketBuffer, opts PacketBufferOptions) { + t.Helper() + reserved := opts.ReserveHeaderBytes + if got, want := pk.ReservedHeaderBytes(), reserved; got != want { + t.Errorf("Initial pk.ReservedHeaderBytes() = %d, want %d", got, want) + } + if got, want := pk.AvailableHeaderBytes(), reserved; got != want { + t.Errorf("Initial pk.AvailableHeaderBytes() = %d, want %d", got, want) + } + if got, want := pk.HeaderSize(), 0; got != want { + t.Errorf("Initial pk.HeaderSize() = %d, want %d", got, want) + } + data := opts.Data.ToView() + if got, want := pk.Size(), len(data); got != want { + t.Errorf("Initial pk.Size() = %d, want %d", got, want) + } + checkViewEqual(t, "Initial pk.Data.Views()", concatViews(pk.Data.Views()...), data) + checkViewEqual(t, "Initial pk.Views()", concatViews(pk.Views()...), data) + // Check the initial values for each header. + checkPacketHeader(t, "Initial pk.LinkHeader", pk.LinkHeader(), nil) + checkPacketHeader(t, "Initial pk.NetworkHeader", pk.NetworkHeader(), nil) + checkPacketHeader(t, "Initial pk.TransportHeader", pk.TransportHeader(), nil) + // Check the initial valies for PayloadSince. + checkViewEqual(t, "Initial PayloadSince(LinkHeader)", + PayloadSince(pk.LinkHeader()), data) + checkViewEqual(t, "Initial PayloadSince(NetworkHeader)", + PayloadSince(pk.NetworkHeader()), data) + checkViewEqual(t, "Initial PayloadSince(TransportHeader)", + PayloadSince(pk.TransportHeader()), data) +} + +func checkPacketHeader(t *testing.T, name string, h PacketHeader, want []byte) { + t.Helper() + checkViewEqual(t, name+".View()", h.View(), want) +} + +func checkViewEqual(t *testing.T, what string, got, want buffer.View) { + t.Helper() + if !bytes.Equal(got, want) { + t.Errorf("%s = %x, want %x", what, got, want) + } +} + +func makeView(size int) buffer.View { + b := byte(size) + return bytes.Repeat([]byte{b}, size) +} + +func concatViews(views ...buffer.View) buffer.View { + var all buffer.View + for _, v := range views { + all = append(all, v...) + } + return all +} diff --git a/pkg/tcpip/stack/registration.go b/pkg/tcpip/stack/registration.go index 8604c4259..16f854e1f 100644 --- a/pkg/tcpip/stack/registration.go +++ b/pkg/tcpip/stack/registration.go @@ -15,6 +15,8 @@ package stack import ( + "fmt" + "gvisor.dev/gvisor/pkg/sleep" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" @@ -125,6 +127,26 @@ type PacketEndpoint interface { HandlePacket(nicID tcpip.NICID, addr tcpip.LinkAddress, netProto tcpip.NetworkProtocolNumber, pkt *PacketBuffer) } +// UnknownDestinationPacketDisposition enumerates the possible return vaues from +// HandleUnknownDestinationPacket(). +type UnknownDestinationPacketDisposition int + +const ( + // UnknownDestinationPacketMalformed denotes that the packet was malformed + // and no further processing should be attempted other than updating + // statistics. + UnknownDestinationPacketMalformed UnknownDestinationPacketDisposition = iota + + // UnknownDestinationPacketUnhandled tells the caller that the packet was + // well formed but that the issue was not handled and the stack should take + // the default action. + UnknownDestinationPacketUnhandled + + // UnknownDestinationPacketHandled tells the caller that it should do + // no further processing. + UnknownDestinationPacketHandled +) + // TransportProtocol is the interface that needs to be implemented by transport // protocols (e.g., tcp, udp) that want to be part of the networking stack. type TransportProtocol interface { @@ -132,10 +154,10 @@ type TransportProtocol interface { Number() tcpip.TransportProtocolNumber // NewEndpoint creates a new endpoint of the transport protocol. - NewEndpoint(stack *Stack, netProto tcpip.NetworkProtocolNumber, waitQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) + NewEndpoint(netProto tcpip.NetworkProtocolNumber, waitQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) // NewRawEndpoint creates a new raw endpoint of the transport protocol. - NewRawEndpoint(stack *Stack, netProto tcpip.NetworkProtocolNumber, waitQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) + NewRawEndpoint(netProto tcpip.NetworkProtocolNumber, waitQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) // MinimumPacketSize returns the minimum valid packet size of this // transport protocol. The stack automatically drops any packets smaller @@ -147,24 +169,22 @@ type TransportProtocol interface { ParsePorts(v buffer.View) (src, dst uint16, err *tcpip.Error) // HandleUnknownDestinationPacket handles packets targeted at this - // protocol but that don't match any existing endpoint. For example, - // it is targeted at a port that have no listeners. - // - // The return value indicates whether the packet was well-formed (for - // stats purposes only). + // protocol that don't match any existing endpoint. For example, + // it is targeted at a port that has no listeners. // - // HandleUnknownDestinationPacket takes ownership of pkt. - HandleUnknownDestinationPacket(r *Route, id TransportEndpointID, pkt *PacketBuffer) bool + // HandleUnknownDestinationPacket takes ownership of pkt if it handles + // the issue. + HandleUnknownDestinationPacket(r *Route, id TransportEndpointID, pkt *PacketBuffer) UnknownDestinationPacketDisposition // SetOption allows enabling/disabling protocol specific features. // SetOption returns an error if the option is not supported or the // provided option value is invalid. - SetOption(option interface{}) *tcpip.Error + SetOption(option tcpip.SettableTransportProtocolOption) *tcpip.Error // Option allows retrieving protocol specific option values. // Option returns an error if the option is not supported or the // provided option value is invalid. - Option(option interface{}) *tcpip.Error + Option(option tcpip.GettableTransportProtocolOption) *tcpip.Error // Close requests that any worker goroutines owned by the protocol // stop. @@ -179,6 +199,25 @@ type TransportProtocol interface { Parse(pkt *PacketBuffer) (ok bool) } +// TransportPacketDisposition is the result from attempting to deliver a packet +// to the transport layer. +type TransportPacketDisposition int + +const ( + // TransportPacketHandled indicates that a transport packet was handled by the + // transport layer and callers need not take any further action. + TransportPacketHandled TransportPacketDisposition = iota + + // TransportPacketProtocolUnreachable indicates that the transport + // protocol requested in the packet is not supported. + TransportPacketProtocolUnreachable + + // TransportPacketDestinationPortUnreachable indicates that there weren't any + // listeners interested in the packet and the transport protocol has no means + // to notify the sender. + TransportPacketDestinationPortUnreachable +) + // TransportDispatcher contains the methods used by the network stack to deliver // packets to the appropriate transport endpoint after it has been handled by // the network layer. @@ -189,7 +228,7 @@ type TransportDispatcher interface { // pkt.NetworkHeader must be set before calling DeliverTransportPacket. // // DeliverTransportPacket takes ownership of pkt. - DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolNumber, pkt *PacketBuffer) + DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolNumber, pkt *PacketBuffer) TransportPacketDisposition // DeliverTransportControlPacket delivers control packets to the // appropriate transport protocol endpoint. @@ -226,9 +265,257 @@ type NetworkHeaderParams struct { TOS uint8 } +// GroupAddressableEndpoint is an endpoint that supports group addressing. +// +// An endpoint is considered to support group addressing when one or more +// endpoints may associate themselves with the same identifier (group address). +type GroupAddressableEndpoint interface { + // JoinGroup joins the spcified group. + // + // Returns true if the group was newly joined. + JoinGroup(group tcpip.Address) (bool, *tcpip.Error) + + // LeaveGroup attempts to leave the specified group. + // + // Returns tcpip.ErrBadLocalAddress if the endpoint has not joined the group. + LeaveGroup(group tcpip.Address) (bool, *tcpip.Error) + + // IsInGroup returns true if the endpoint is a member of the specified group. + IsInGroup(group tcpip.Address) bool +} + +// PrimaryEndpointBehavior is an enumeration of an AddressEndpoint's primary +// behavior. +type PrimaryEndpointBehavior int + +const ( + // CanBePrimaryEndpoint indicates the endpoint can be used as a primary + // endpoint for new connections with no local address. This is the + // default when calling NIC.AddAddress. + CanBePrimaryEndpoint PrimaryEndpointBehavior = iota + + // FirstPrimaryEndpoint indicates the endpoint should be the first + // primary endpoint considered. If there are multiple endpoints with + // this behavior, they are ordered by recency. + FirstPrimaryEndpoint + + // NeverPrimaryEndpoint indicates the endpoint should never be a + // primary endpoint. + NeverPrimaryEndpoint +) + +// AddressConfigType is the method used to add an address. +type AddressConfigType int + +const ( + // AddressConfigStatic is a statically configured address endpoint that was + // added by some user-specified action (adding an explicit address, joining a + // multicast group). + AddressConfigStatic AddressConfigType = iota + + // AddressConfigSlaac is an address endpoint added by SLAAC, as per RFC 4862 + // section 5.5.3. + AddressConfigSlaac + + // AddressConfigSlaacTemp is a temporary address endpoint added by SLAAC as + // per RFC 4941. Temporary SLAAC addresses are short-lived and are not + // to be valid (or preferred) forever; hence the term temporary. + AddressConfigSlaacTemp +) + +// AssignableAddressEndpoint is a reference counted address endpoint that may be +// assigned to a NetworkEndpoint. +type AssignableAddressEndpoint interface { + // NetworkEndpoint returns the NetworkEndpoint the receiver is associated + // with. + NetworkEndpoint() NetworkEndpoint + + // AddressWithPrefix returns the endpoint's address. + AddressWithPrefix() tcpip.AddressWithPrefix + + // IsAssigned returns whether or not the endpoint is considered bound + // to its NetworkEndpoint. + IsAssigned(allowExpired bool) bool + + // IncRef increments this endpoint's reference count. + // + // Returns true if it was successfully incremented. If it returns false, then + // the endpoint is considered expired and should no longer be used. + IncRef() bool + + // DecRef decrements this endpoint's reference count. + DecRef() +} + +// AddressEndpoint is an endpoint representing an address assigned to an +// AddressableEndpoint. +type AddressEndpoint interface { + AssignableAddressEndpoint + + // GetKind returns the address kind for this endpoint. + GetKind() AddressKind + + // SetKind sets the address kind for this endpoint. + SetKind(AddressKind) + + // ConfigType returns the method used to add the address. + ConfigType() AddressConfigType + + // Deprecated returns whether or not this endpoint is deprecated. + Deprecated() bool + + // SetDeprecated sets this endpoint's deprecated status. + SetDeprecated(bool) +} + +// AddressKind is the kind of of an address. +// +// See the values of AddressKind for more details. +type AddressKind int + +const ( + // PermanentTentative is a permanent address endpoint that is not yet + // considered to be fully bound to an interface in the traditional + // sense. That is, the address is associated with a NIC, but packets + // destined to the address MUST NOT be accepted and MUST be silently + // dropped, and the address MUST NOT be used as a source address for + // outgoing packets. For IPv6, addresses are of this kind until NDP's + // Duplicate Address Detection (DAD) resolves. If DAD fails, the address + // is removed. + PermanentTentative AddressKind = iota + + // Permanent is a permanent endpoint (vs. a temporary one) assigned to the + // NIC. Its reference count is biased by 1 to avoid removal when no route + // holds a reference to it. It is removed by explicitly removing the address + // from the NIC. + Permanent + + // PermanentExpired is a permanent endpoint that had its address removed from + // the NIC, and it is waiting to be removed once no references to it are held. + // + // If the address is re-added before the endpoint is removed, its type + // changes back to Permanent. + PermanentExpired + + // Temporary is an endpoint, created on a one-off basis to temporarily + // consider the NIC bound an an address that it is not explictiy bound to + // (such as a permanent address). Its reference count must not be biased by 1 + // so that the address is removed immediately when references to it are no + // longer held. + // + // A temporary endpoint may be promoted to permanent if the address is added + // permanently. + Temporary +) + +// IsPermanent returns true if the AddressKind represents a permanent address. +func (k AddressKind) IsPermanent() bool { + switch k { + case Permanent, PermanentTentative: + return true + case Temporary, PermanentExpired: + return false + default: + panic(fmt.Sprintf("unrecognized address kind = %d", k)) + } +} + +// AddressableEndpoint is an endpoint that supports addressing. +// +// An endpoint is considered to support addressing when the endpoint may +// associate itself with an identifier (address). +type AddressableEndpoint interface { + // AddAndAcquirePermanentAddress adds the passed permanent address. + // + // Returns tcpip.ErrDuplicateAddress if the address exists. + // + // Acquires and returns the AddressEndpoint for the added address. + AddAndAcquirePermanentAddress(addr tcpip.AddressWithPrefix, peb PrimaryEndpointBehavior, configType AddressConfigType, deprecated bool) (AddressEndpoint, *tcpip.Error) + + // RemovePermanentAddress removes the passed address if it is a permanent + // address. + // + // Returns tcpip.ErrBadLocalAddress if the endpoint does not have the passed + // permanent address. + RemovePermanentAddress(addr tcpip.Address) *tcpip.Error + + // MainAddress returns the endpoint's primary permanent address. + MainAddress() tcpip.AddressWithPrefix + + // AcquireAssignedAddress returns an address endpoint for the passed address + // that is considered bound to the endpoint, optionally creating a temporary + // endpoint if requested and no existing address exists. + // + // The returned endpoint's reference count is incremented. + // + // Returns nil if the specified address is not local to this endpoint. + AcquireAssignedAddress(localAddr tcpip.Address, allowTemp bool, tempPEB PrimaryEndpointBehavior) AddressEndpoint + + // AcquireOutgoingPrimaryAddress returns a primary address that may be used as + // a source address when sending packets to the passed remote address. + // + // If allowExpired is true, expired addresses may be returned. + // + // The returned endpoint's reference count is incremented. + // + // Returns nil if a primary address is not available. + AcquireOutgoingPrimaryAddress(remoteAddr tcpip.Address, allowExpired bool) AddressEndpoint + + // PrimaryAddresses returns the primary addresses. + PrimaryAddresses() []tcpip.AddressWithPrefix + + // PermanentAddresses returns all the permanent addresses. + PermanentAddresses() []tcpip.AddressWithPrefix +} + +// NDPEndpoint is a network endpoint that supports NDP. +type NDPEndpoint interface { + NetworkEndpoint + + // InvalidateDefaultRouter invalidates a default router discovered through + // NDP. + InvalidateDefaultRouter(tcpip.Address) +} + +// NetworkInterface is a network interface. +type NetworkInterface interface { + // ID returns the interface's ID. + ID() tcpip.NICID + + // IsLoopback returns true if the interface is a loopback interface. + IsLoopback() bool + + // Name returns the name of the interface. + // + // May return an empty string if the interface is not configured with a name. + Name() string + + // Enabled returns true if the interface is enabled. + Enabled() bool + + // LinkEndpoint returns the link endpoint backing the interface. + LinkEndpoint() LinkEndpoint +} + // NetworkEndpoint is the interface that needs to be implemented by endpoints // of network layer protocols (e.g., ipv4, ipv6). type NetworkEndpoint interface { + AddressableEndpoint + + // Enable enables the endpoint. + // + // Must only be called when the stack is in a state that allows the endpoint + // to send and receive packets. + // + // Returns tcpip.ErrNotPermitted if the endpoint cannot be enabled. + Enable() *tcpip.Error + + // Enabled returns true if the endpoint is enabled. + Enabled() bool + + // Disable disables the endpoint. + Disable() + // DefaultTTL is the default time-to-live value (or hop limit, in ipv6) // for this endpoint. DefaultTTL() uint8 @@ -238,10 +525,6 @@ type NetworkEndpoint interface { // minus the network endpoint max header length. MTU() uint32 - // Capabilities returns the set of capabilities supported by the - // underlying link-layer endpoint. - Capabilities() LinkEndpointCapabilities - // MaxHeaderLength returns the maximum size the network (and lower // level layers combined) headers can have. Higher levels use this // information to reserve space in the front of the packets they're @@ -249,8 +532,8 @@ type NetworkEndpoint interface { MaxHeaderLength() uint16 // WritePacket writes a packet to the given destination address and - // protocol. It takes ownership of pkt. pkt.TransportHeader must have already - // been set. + // protocol. It takes ownership of pkt. pkt.TransportHeader must have + // already been set. WritePacket(r *Route, gso *GSO, params NetworkHeaderParams, pkt *PacketBuffer) *tcpip.Error // WritePackets writes packets to the given destination address and @@ -262,15 +545,6 @@ type NetworkEndpoint interface { // header to the given destination address. It takes ownership of pkt. WriteHeaderIncludedPacket(r *Route, pkt *PacketBuffer) *tcpip.Error - // ID returns the network protocol endpoint ID. - ID() *NetworkEndpointID - - // PrefixLen returns the network endpoint's subnet prefix length in bits. - PrefixLen() int - - // NICID returns the id of the NIC this endpoint belongs to. - NICID() tcpip.NICID - // HandlePacket is called by the link layer when new packets arrive to // this network endpoint. It sets pkt.NetworkHeader. // @@ -285,6 +559,17 @@ type NetworkEndpoint interface { NetworkProtocolNumber() tcpip.NetworkProtocolNumber } +// ForwardingNetworkProtocol is a NetworkProtocol that may forward packets. +type ForwardingNetworkProtocol interface { + NetworkProtocol + + // Forwarding returns the forwarding configuration. + Forwarding() bool + + // SetForwarding sets the forwarding configuration. + SetForwarding(bool) +} + // NetworkProtocol is the interface that needs to be implemented by network // protocols (e.g., ipv4, ipv6) that want to be part of the networking stack. type NetworkProtocol interface { @@ -304,17 +589,17 @@ type NetworkProtocol interface { ParseAddresses(v buffer.View) (src, dst tcpip.Address) // NewEndpoint creates a new endpoint of this protocol. - NewEndpoint(nicID tcpip.NICID, addrWithPrefix tcpip.AddressWithPrefix, linkAddrCache LinkAddressCache, dispatcher TransportDispatcher, sender LinkEndpoint, st *Stack) (NetworkEndpoint, *tcpip.Error) + NewEndpoint(nic NetworkInterface, linkAddrCache LinkAddressCache, nud NUDHandler, dispatcher TransportDispatcher) NetworkEndpoint // SetOption allows enabling/disabling protocol specific features. // SetOption returns an error if the option is not supported or the // provided option value is invalid. - SetOption(option interface{}) *tcpip.Error + SetOption(option tcpip.SettableNetworkProtocolOption) *tcpip.Error // Option allows retrieving protocol specific option values. // Option returns an error if the option is not supported or the // provided option value is invalid. - Option(option interface{}) *tcpip.Error + Option(option tcpip.GettableNetworkProtocolOption) *tcpip.Error // Close requests that any worker goroutines owned by the protocol // stop. @@ -433,8 +718,8 @@ type LinkEndpoint interface { // Attach attaches the data link layer endpoint to the network-layer // dispatcher of the stack. // - // Attach will be called with a nil dispatcher if the receiver's associated - // NIC is being removed. + // Attach is called with a nil dispatcher when the endpoint's NIC is being + // removed. Attach(dispatcher NetworkDispatcher) // IsAttached returns whether a NetworkDispatcher is attached to the @@ -494,7 +779,7 @@ type LinkAddressResolver interface { ResolveStaticAddress(addr tcpip.Address) (tcpip.LinkAddress, bool) // LinkAddressProtocol returns the network protocol of the - // addresses this this resolver can resolve. + // addresses this resolver can resolve. LinkAddressProtocol() tcpip.NetworkProtocolNumber } diff --git a/pkg/tcpip/stack/route.go b/pkg/tcpip/stack/route.go index 91e0110f1..5ade3c832 100644 --- a/pkg/tcpip/stack/route.go +++ b/pkg/tcpip/stack/route.go @@ -42,21 +42,27 @@ type Route struct { // NetProto is the network-layer protocol. NetProto tcpip.NetworkProtocolNumber - // ref a reference to the network endpoint through which the route - // starts. - ref *referencedNetworkEndpoint - // Loop controls where WritePacket should send packets. Loop PacketLooping - // directedBroadcast indicates whether this route is sending a directed - // broadcast packet. - directedBroadcast bool + // nic is the NIC the route goes through. + nic *NIC + + // addressEndpoint is the local address this route is associated with. + addressEndpoint AssignableAddressEndpoint + + // linkCache is set if link address resolution is enabled for this protocol on + // the route's NIC. + linkCache LinkAddressCache + + // linkRes is set if link address resolution is enabled for this protocol on + // the route's NIC. + linkRes LinkAddressResolver } // makeRoute initializes a new route. It takes ownership of the provided -// reference to a network endpoint. -func makeRoute(netProto tcpip.NetworkProtocolNumber, localAddr, remoteAddr tcpip.Address, localLinkAddr tcpip.LinkAddress, ref *referencedNetworkEndpoint, handleLocal, multicastLoop bool) Route { +// AssignableAddressEndpoint. +func makeRoute(netProto tcpip.NetworkProtocolNumber, localAddr, remoteAddr tcpip.Address, nic *NIC, addressEndpoint AssignableAddressEndpoint, handleLocal, multicastLoop bool) Route { loop := PacketOut if handleLocal && localAddr != "" && remoteAddr == localAddr { loop = PacketLoop @@ -66,29 +72,40 @@ func makeRoute(netProto tcpip.NetworkProtocolNumber, localAddr, remoteAddr tcpip loop |= PacketLoop } - return Route{ + linkEP := nic.LinkEndpoint() + r := Route{ NetProto: netProto, LocalAddress: localAddr, - LocalLinkAddress: localLinkAddr, + LocalLinkAddress: linkEP.LinkAddress(), RemoteAddress: remoteAddr, - ref: ref, + addressEndpoint: addressEndpoint, + nic: nic, Loop: loop, } + + if nic := r.nic; linkEP.Capabilities()&CapabilityResolutionRequired != 0 { + if linkRes, ok := nic.stack.linkAddrResolvers[r.NetProto]; ok { + r.linkRes = linkRes + r.linkCache = nic.stack + } + } + + return r } // NICID returns the id of the NIC from which this route originates. func (r *Route) NICID() tcpip.NICID { - return r.ref.ep.NICID() + return r.nic.ID() } // MaxHeaderLength forwards the call to the network endpoint's implementation. func (r *Route) MaxHeaderLength() uint16 { - return r.ref.ep.MaxHeaderLength() + return r.addressEndpoint.NetworkEndpoint().MaxHeaderLength() } // Stats returns a mutable copy of current stats. func (r *Route) Stats() tcpip.Stats { - return r.ref.nic.stack.Stats() + return r.nic.stack.Stats() } // PseudoHeaderChecksum forwards the call to the network endpoint's @@ -99,17 +116,23 @@ func (r *Route) PseudoHeaderChecksum(protocol tcpip.TransportProtocolNumber, tot // Capabilities returns the link-layer capabilities of the route. func (r *Route) Capabilities() LinkEndpointCapabilities { - return r.ref.ep.Capabilities() + return r.nic.LinkEndpoint().Capabilities() } // GSOMaxSize returns the maximum GSO packet size. func (r *Route) GSOMaxSize() uint32 { - if gso, ok := r.ref.ep.(GSOEndpoint); ok { + if gso, ok := r.addressEndpoint.NetworkEndpoint().(GSOEndpoint); ok { return gso.GSOMaxSize() } return 0 } +// ResolveWith immediately resolves a route with the specified remote link +// address. +func (r *Route) ResolveWith(addr tcpip.LinkAddress) { + r.RemoteLinkAddress = addr +} + // Resolve attempts to resolve the link address if necessary. Returns ErrWouldBlock in // case address resolution requires blocking, e.g. wait for ARP reply. Waker is // notified when address resolution is complete (success or not). @@ -135,7 +158,17 @@ func (r *Route) Resolve(waker *sleep.Waker) (<-chan struct{}, *tcpip.Error) { } nextAddr = r.RemoteAddress } - linkAddr, ch, err := r.ref.linkCache.GetLinkAddress(r.ref.nic.ID(), nextAddr, r.LocalAddress, r.NetProto, waker) + + if neigh := r.nic.neigh; neigh != nil { + entry, ch, err := neigh.entry(nextAddr, r.LocalAddress, r.linkRes, waker) + if err != nil { + return ch, err + } + r.RemoteLinkAddress = entry.LinkAddr + return nil, nil + } + + linkAddr, ch, err := r.linkCache.GetLinkAddress(r.nic.ID(), nextAddr, r.LocalAddress, r.NetProto, waker) if err != nil { return ch, err } @@ -149,7 +182,13 @@ func (r *Route) RemoveWaker(waker *sleep.Waker) { if nextAddr == "" { nextAddr = r.RemoteAddress } - r.ref.linkCache.RemoveWaker(r.ref.nic.ID(), nextAddr, waker) + + if neigh := r.nic.neigh; neigh != nil { + neigh.removeWaker(nextAddr, waker) + return + } + + r.linkCache.RemoveWaker(r.nic.ID(), nextAddr, waker) } // IsResolutionRequired returns true if Resolve() must be called to resolve @@ -157,24 +196,27 @@ func (r *Route) RemoveWaker(waker *sleep.Waker) { // // The NIC r uses must not be locked. func (r *Route) IsResolutionRequired() bool { - return r.ref.isValidForOutgoing() && r.ref.linkCache != nil && r.RemoteLinkAddress == "" + if r.nic.neigh != nil { + return r.nic.isValidForOutgoing(r.addressEndpoint) && r.linkRes != nil && r.RemoteLinkAddress == "" + } + return r.nic.isValidForOutgoing(r.addressEndpoint) && r.linkCache != nil && r.RemoteLinkAddress == "" } // WritePacket writes the packet through the given route. func (r *Route) WritePacket(gso *GSO, params NetworkHeaderParams, pkt *PacketBuffer) *tcpip.Error { - if !r.ref.isValidForOutgoing() { + if !r.nic.isValidForOutgoing(r.addressEndpoint) { return tcpip.ErrInvalidEndpointState } // WritePacket takes ownership of pkt, calculate numBytes first. - numBytes := pkt.Header.UsedLength() + pkt.Data.Size() + numBytes := pkt.Size() - err := r.ref.ep.WritePacket(r, gso, params, pkt) + err := r.addressEndpoint.NetworkEndpoint().WritePacket(r, gso, params, pkt) if err != nil { r.Stats().IP.OutgoingPacketErrors.Increment() } else { - r.ref.nic.stats.Tx.Packets.Increment() - r.ref.nic.stats.Tx.Bytes.IncrementBy(uint64(numBytes)) + r.nic.stats.Tx.Packets.Increment() + r.nic.stats.Tx.Bytes.IncrementBy(uint64(numBytes)) } return err } @@ -182,77 +224,75 @@ func (r *Route) WritePacket(gso *GSO, params NetworkHeaderParams, pkt *PacketBuf // WritePackets writes a list of n packets through the given route and returns // the number of packets written. func (r *Route) WritePackets(gso *GSO, pkts PacketBufferList, params NetworkHeaderParams) (int, *tcpip.Error) { - if !r.ref.isValidForOutgoing() { + if !r.nic.isValidForOutgoing(r.addressEndpoint) { return 0, tcpip.ErrInvalidEndpointState } // WritePackets takes ownership of pkt, calculate length first. numPkts := pkts.Len() - n, err := r.ref.ep.WritePackets(r, gso, pkts, params) + n, err := r.addressEndpoint.NetworkEndpoint().WritePackets(r, gso, pkts, params) if err != nil { r.Stats().IP.OutgoingPacketErrors.IncrementBy(uint64(numPkts - n)) } - r.ref.nic.stats.Tx.Packets.IncrementBy(uint64(n)) + r.nic.stats.Tx.Packets.IncrementBy(uint64(n)) writtenBytes := 0 for i, pb := 0, pkts.Front(); i < n && pb != nil; i, pb = i+1, pb.Next() { - writtenBytes += pb.Header.UsedLength() - writtenBytes += pb.Data.Size() + writtenBytes += pb.Size() } - r.ref.nic.stats.Tx.Bytes.IncrementBy(uint64(writtenBytes)) + r.nic.stats.Tx.Bytes.IncrementBy(uint64(writtenBytes)) return n, err } // WriteHeaderIncludedPacket writes a packet already containing a network // header through the given route. func (r *Route) WriteHeaderIncludedPacket(pkt *PacketBuffer) *tcpip.Error { - if !r.ref.isValidForOutgoing() { + if !r.nic.isValidForOutgoing(r.addressEndpoint) { return tcpip.ErrInvalidEndpointState } // WriteHeaderIncludedPacket takes ownership of pkt, calculate numBytes first. numBytes := pkt.Data.Size() - if err := r.ref.ep.WriteHeaderIncludedPacket(r, pkt); err != nil { + if err := r.addressEndpoint.NetworkEndpoint().WriteHeaderIncludedPacket(r, pkt); err != nil { r.Stats().IP.OutgoingPacketErrors.Increment() return err } - r.ref.nic.stats.Tx.Packets.Increment() - r.ref.nic.stats.Tx.Bytes.IncrementBy(uint64(numBytes)) + r.nic.stats.Tx.Packets.Increment() + r.nic.stats.Tx.Bytes.IncrementBy(uint64(numBytes)) return nil } // DefaultTTL returns the default TTL of the underlying network endpoint. func (r *Route) DefaultTTL() uint8 { - return r.ref.ep.DefaultTTL() + return r.addressEndpoint.NetworkEndpoint().DefaultTTL() } // MTU returns the MTU of the underlying network endpoint. func (r *Route) MTU() uint32 { - return r.ref.ep.MTU() + return r.addressEndpoint.NetworkEndpoint().MTU() } // NetworkProtocolNumber returns the NetworkProtocolNumber of the underlying // network endpoint. func (r *Route) NetworkProtocolNumber() tcpip.NetworkProtocolNumber { - return r.ref.ep.NetworkProtocolNumber() + return r.addressEndpoint.NetworkEndpoint().NetworkProtocolNumber() } // Release frees all resources associated with the route. func (r *Route) Release() { - if r.ref != nil { - r.ref.decRef() - r.ref = nil + if r.addressEndpoint != nil { + r.addressEndpoint.DecRef() + r.addressEndpoint = nil } } -// Clone Clone a route such that the original one can be released and the new -// one will remain valid. +// Clone clones the route. func (r *Route) Clone() Route { - if r.ref != nil { - r.ref.incRef() + if r.addressEndpoint != nil { + _ = r.addressEndpoint.IncRef() } return *r } @@ -276,13 +316,30 @@ func (r *Route) MakeLoopedRoute() Route { // Stack returns the instance of the Stack that owns this route. func (r *Route) Stack() *Stack { - return r.ref.stack() + return r.nic.stack +} + +func (r *Route) isV4Broadcast(addr tcpip.Address) bool { + if addr == header.IPv4Broadcast { + return true + } + + subnet := r.addressEndpoint.AddressWithPrefix().Subnet() + return subnet.IsBroadcast(addr) +} + +// IsOutboundBroadcast returns true if the route is for an outbound broadcast +// packet. +func (r *Route) IsOutboundBroadcast() bool { + // Only IPv4 has a notion of broadcast. + return r.isV4Broadcast(r.RemoteAddress) } -// IsBroadcast returns true if the route is to send a broadcast packet. -func (r *Route) IsBroadcast() bool { +// IsInboundBroadcast returns true if the route is for an inbound broadcast +// packet. +func (r *Route) IsInboundBroadcast() bool { // Only IPv4 has a notion of broadcast. - return r.directedBroadcast || r.RemoteAddress == header.IPv4Broadcast + return r.isV4Broadcast(r.LocalAddress) } // ReverseRoute returns new route with given source and destination address. @@ -293,7 +350,10 @@ func (r *Route) ReverseRoute(src tcpip.Address, dst tcpip.Address) Route { LocalLinkAddress: r.RemoteLinkAddress, RemoteAddress: src, RemoteLinkAddress: r.LocalLinkAddress, - ref: r.ref, Loop: r.Loop, + addressEndpoint: r.addressEndpoint, + nic: r.nic, + linkCache: r.linkCache, + linkRes: r.linkRes, } } diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go index 3f07e4159..57d8e79e0 100644 --- a/pkg/tcpip/stack/stack.go +++ b/pkg/tcpip/stack/stack.go @@ -73,6 +73,16 @@ type TCPCubicState struct { WEst float64 } +// TCPRACKState is used to hold a copy of the internal RACK state when the +// TCPProbeFunc is invoked. +type TCPRACKState struct { + XmitTime time.Time + EndSequence seqnum.Value + FACK seqnum.Value + RTT time.Duration + Reord bool +} + // TCPEndpointID is the unique 4 tuple that identifies a given endpoint. type TCPEndpointID struct { // LocalPort is the local port associated with the endpoint. @@ -134,10 +144,7 @@ type TCPReceiverState struct { // PendingBufUsed is the number of bytes pending in the receive // queue. - PendingBufUsed seqnum.Size - - // PendingBufSize is the size of the socket receive buffer. - PendingBufSize seqnum.Size + PendingBufUsed int } // TCPSenderState holds a copy of the internal state of the sender for @@ -212,6 +219,9 @@ type TCPSenderState struct { // Cubic holds the state related to CUBIC congestion control. Cubic TCPCubicState + + // RACKState holds the state related to RACK loss detection algorithm. + RACKState TCPRACKState } // TCPSACKInfo holds TCP SACK related information for a given TCP endpoint. @@ -235,7 +245,7 @@ type RcvBufAutoTuneParams struct { // was started. MeasureTime time.Time - // CopiedBytes is the number of bytes copied to userspace since + // CopiedBytes is the number of bytes copied to user space since // this measure began. CopiedBytes int @@ -353,38 +363,6 @@ func (u *uniqueIDGenerator) UniqueID() uint64 { return atomic.AddUint64((*uint64)(u), 1) } -// NICNameFromID is a function that returns a stable name for the specified NIC, -// even if different NIC IDs are used to refer to the same NIC in different -// program runs. It is used when generating opaque interface identifiers (IIDs). -// If the NIC was created with a name, it will be passed to NICNameFromID. -// -// NICNameFromID SHOULD return unique NIC names so unique opaque IIDs are -// generated for the same prefix on differnt NICs. -type NICNameFromID func(tcpip.NICID, string) string - -// OpaqueInterfaceIdentifierOptions holds the options related to the generation -// of opaque interface indentifiers (IIDs) as defined by RFC 7217. -type OpaqueInterfaceIdentifierOptions struct { - // NICNameFromID is a function that returns a stable name for a specified NIC, - // even if the NIC ID changes over time. - // - // Must be specified to generate the opaque IID. - NICNameFromID NICNameFromID - - // SecretKey is a pseudo-random number used as the secret key when generating - // opaque IIDs as defined by RFC 7217. The key SHOULD be at least - // header.OpaqueIIDSecretKeyMinBytes bytes and MUST follow minimum randomness - // requirements for security as outlined by RFC 4086. SecretKey MUST NOT - // change between program runs, unless explicitly changed. - // - // OpaqueInterfaceIdentifierOptions takes ownership of SecretKey. SecretKey - // MUST NOT be modified after Stack is created. - // - // May be nil, but a nil value is highly discouraged to maintain - // some level of randomness between nodes. - SecretKey []byte -} - // Stack is a networking stack, with all supported protocols, NICs, and route // table. type Stack struct { @@ -402,10 +380,12 @@ type Stack struct { linkAddrCache *linkAddrCache - mu sync.RWMutex - nics map[tcpip.NICID]*NIC - forwarding bool - cleanupEndpoints map[TransportEndpoint]struct{} + mu sync.RWMutex + nics map[tcpip.NICID]*NIC + + // cleanupEndpointsMu protects cleanupEndpoints. + cleanupEndpointsMu sync.Mutex + cleanupEndpoints map[TransportEndpoint]struct{} // route is the route table passed in by the user via SetRouteTable(), // it is used by FindRoute() to build a route for a specific @@ -416,7 +396,7 @@ type Stack struct { // If not nil, then any new endpoints will have this probe function // invoked everytime they receive a TCP segment. - tcpProbeFunc TCPProbeFunc + tcpProbeFunc atomic.Value // TCPProbeFunc // clock is used to generate user-visible times. clock tcpip.Clock @@ -442,20 +422,12 @@ type Stack struct { // TODO(gvisor.dev/issue/940): S/R this field. seed uint32 - // ndpConfigs is the default NDP configurations used by interfaces. - ndpConfigs NDPConfigurations - // nudConfigs is the default NUD configurations used by interfaces. nudConfigs NUDConfigurations - // autoGenIPv6LinkLocal determines whether or not the stack will attempt - // to auto-generate an IPv6 link-local address for newly enabled non-loopback - // NICs. See the AutoGenIPv6LinkLocal field of Options for more details. - autoGenIPv6LinkLocal bool - - // ndpDisp is the NDP event dispatcher that is used to send the netstack - // integrator NDP related events. - ndpDisp NDPDispatcher + // useNeighborCache indicates whether ARP and NDP packets should be handled + // by the NIC's neighborCache instead of linkAddrCache. + useNeighborCache bool // nudDisp is the NUD event dispatcher that is used to send the netstack // integrator NUD related events. @@ -464,14 +436,6 @@ type Stack struct { // uniqueIDGenerator is a generator of unique identifiers. uniqueIDGenerator UniqueID - // opaqueIIDOpts hold the options for generating opaque interface identifiers - // (IIDs) as outlined by RFC 7217. - opaqueIIDOpts OpaqueInterfaceIdentifierOptions - - // tempIIDSeed is used to seed the initial temporary interface identifier - // history value used to generate IIDs for temporary SLAAC addresses. - tempIIDSeed []byte - // forwarder holds the packets that wait for their link-address resolutions // to complete, and forwards them when each resolution is done. forwarder *forwardQueue @@ -494,13 +458,25 @@ type UniqueID interface { UniqueID() uint64 } +// NetworkProtocolFactory instantiates a network protocol. +// +// NetworkProtocolFactory must not attempt to modify the stack, it may only +// query the stack. +type NetworkProtocolFactory func(*Stack) NetworkProtocol + +// TransportProtocolFactory instantiates a transport protocol. +// +// TransportProtocolFactory must not attempt to modify the stack, it may only +// query the stack. +type TransportProtocolFactory func(*Stack) TransportProtocol + // Options contains optional Stack configuration. type Options struct { // NetworkProtocols lists the network protocols to enable. - NetworkProtocols []NetworkProtocol + NetworkProtocols []NetworkProtocolFactory // TransportProtocols lists the transport protocols to enable. - TransportProtocols []TransportProtocol + TransportProtocols []TransportProtocolFactory // Clock is an optional clock source used for timestampping packets. // @@ -518,33 +494,15 @@ type Options struct { // UniqueID is an optional generator of unique identifiers. UniqueID UniqueID - // NDPConfigs is the default NDP configurations used by interfaces. - // - // By default, NDPConfigs will have a zero value for its - // DupAddrDetectTransmits field, implying that DAD will not be performed - // before assigning an address to a NIC. - NDPConfigs NDPConfigurations - // NUDConfigs is the default NUD configurations used by interfaces. NUDConfigs NUDConfigurations - // AutoGenIPv6LinkLocal determines whether or not the stack will attempt to - // auto-generate an IPv6 link-local address for newly enabled non-loopback - // NICs. - // - // Note, setting this to true does not mean that a link-local address - // will be assigned right away, or at all. If Duplicate Address Detection - // is enabled, an address will only be assigned if it successfully resolves. - // If it fails, no further attempt will be made to auto-generate an IPv6 - // link-local address. - // - // The generated link-local address will follow RFC 4291 Appendix A - // guidelines. - AutoGenIPv6LinkLocal bool - - // NDPDisp is the NDP event dispatcher that an integrator can provide to - // receive NDP related events. - NDPDisp NDPDispatcher + // UseNeighborCache indicates whether ARP and NDP packets should be handled + // by the Neighbor Unreachability Detection (NUD) state machine. This flag + // also enables the APIs for inspecting and modifying the neighbor table via + // NUDDispatcher and the following Stack methods: Neighbors, RemoveNeighbor, + // and ClearNeighbors. + UseNeighborCache bool // NUDDisp is the NUD event dispatcher that an integrator can provide to // receive NUD related events. @@ -554,31 +512,12 @@ type Options struct { // this is non-nil. RawFactory RawFactory - // OpaqueIIDOpts hold the options for generating opaque interface - // identifiers (IIDs) as outlined by RFC 7217. - OpaqueIIDOpts OpaqueInterfaceIdentifierOptions - // RandSource is an optional source to use to generate random // numbers. If omitted it defaults to a Source seeded by the data // returned by rand.Read(). // // RandSource must be thread-safe. RandSource mathrand.Source - - // TempIIDSeed is used to seed the initial temporary interface identifier - // history value used to generate IIDs for temporary SLAAC addresses. - // - // Temporary SLAAC adresses are short-lived addresses which are unpredictable - // and random from the perspective of other nodes on the network. It is - // recommended that the seed be a random byte buffer of at least - // header.IIDSize bytes to make sure that temporary SLAAC addresses are - // sufficiently random. It should follow minimum randomness requirements for - // security as outlined by RFC 4086. - // - // Note: using a nil value, the same seed across netstack program runs, or a - // seed that is too small would reduce randomness and increase predictability, - // defeating the purpose of temporary SLAAC addresses. - TempIIDSeed []byte } // TransportEndpointInfo holds useful information about a transport endpoint @@ -681,35 +620,28 @@ func New(opts Options) *Stack { randSrc = &lockedRandomSource{src: mathrand.NewSource(generateRandInt64())} } - // Make sure opts.NDPConfigs contains valid values only. - opts.NDPConfigs.validate() - opts.NUDConfigs.resetInvalidFields() s := &Stack{ - transportProtocols: make(map[tcpip.TransportProtocolNumber]*transportProtocolState), - networkProtocols: make(map[tcpip.NetworkProtocolNumber]NetworkProtocol), - linkAddrResolvers: make(map[tcpip.NetworkProtocolNumber]LinkAddressResolver), - nics: make(map[tcpip.NICID]*NIC), - cleanupEndpoints: make(map[TransportEndpoint]struct{}), - linkAddrCache: newLinkAddrCache(ageLimit, resolutionTimeout, resolutionAttempts), - PortManager: ports.NewPortManager(), - clock: clock, - stats: opts.Stats.FillIn(), - handleLocal: opts.HandleLocal, - tables: DefaultTables(), - icmpRateLimiter: NewICMPRateLimiter(), - seed: generateRandUint32(), - ndpConfigs: opts.NDPConfigs, - nudConfigs: opts.NUDConfigs, - autoGenIPv6LinkLocal: opts.AutoGenIPv6LinkLocal, - uniqueIDGenerator: opts.UniqueID, - ndpDisp: opts.NDPDisp, - nudDisp: opts.NUDDisp, - opaqueIIDOpts: opts.OpaqueIIDOpts, - tempIIDSeed: opts.TempIIDSeed, - forwarder: newForwardQueue(), - randomGenerator: mathrand.New(randSrc), + transportProtocols: make(map[tcpip.TransportProtocolNumber]*transportProtocolState), + networkProtocols: make(map[tcpip.NetworkProtocolNumber]NetworkProtocol), + linkAddrResolvers: make(map[tcpip.NetworkProtocolNumber]LinkAddressResolver), + nics: make(map[tcpip.NICID]*NIC), + cleanupEndpoints: make(map[TransportEndpoint]struct{}), + linkAddrCache: newLinkAddrCache(ageLimit, resolutionTimeout, resolutionAttempts), + PortManager: ports.NewPortManager(), + clock: clock, + stats: opts.Stats.FillIn(), + handleLocal: opts.HandleLocal, + tables: DefaultTables(), + icmpRateLimiter: NewICMPRateLimiter(), + seed: generateRandUint32(), + nudConfigs: opts.NUDConfigs, + useNeighborCache: opts.UseNeighborCache, + uniqueIDGenerator: opts.UniqueID, + nudDisp: opts.NUDDisp, + forwarder: newForwardQueue(), + randomGenerator: mathrand.New(randSrc), sendBufferSize: SendBufferSizeOption{ Min: MinBufferSize, Default: DefaultBufferSize, @@ -723,7 +655,8 @@ func New(opts Options) *Stack { } // Add specified network protocols. - for _, netProto := range opts.NetworkProtocols { + for _, netProtoFactory := range opts.NetworkProtocols { + netProto := netProtoFactory(s) s.networkProtocols[netProto.Number()] = netProto if r, ok := netProto.(LinkAddressResolver); ok { s.linkAddrResolvers[r.LinkAddressProtocol()] = r @@ -731,7 +664,8 @@ func New(opts Options) *Stack { } // Add specified transport protocols. - for _, transProto := range opts.TransportProtocols { + for _, transProtoFactory := range opts.TransportProtocols { + transProto := transProtoFactory(s) s.transportProtocols[transProto.Number()] = &transportProtocolState{ proto: transProto, } @@ -760,7 +694,7 @@ func (s *Stack) UniqueID() uint64 { // options. This method returns an error if the protocol is not supported or // option is not supported by the protocol implementation or the provided value // is incorrect. -func (s *Stack) SetNetworkProtocolOption(network tcpip.NetworkProtocolNumber, option interface{}) *tcpip.Error { +func (s *Stack) SetNetworkProtocolOption(network tcpip.NetworkProtocolNumber, option tcpip.SettableNetworkProtocolOption) *tcpip.Error { netProto, ok := s.networkProtocols[network] if !ok { return tcpip.ErrUnknownProtocol @@ -777,7 +711,7 @@ func (s *Stack) SetNetworkProtocolOption(network tcpip.NetworkProtocolNumber, op // if err != nil { // ... // } -func (s *Stack) NetworkProtocolOption(network tcpip.NetworkProtocolNumber, option interface{}) *tcpip.Error { +func (s *Stack) NetworkProtocolOption(network tcpip.NetworkProtocolNumber, option tcpip.GettableNetworkProtocolOption) *tcpip.Error { netProto, ok := s.networkProtocols[network] if !ok { return tcpip.ErrUnknownProtocol @@ -789,7 +723,7 @@ func (s *Stack) NetworkProtocolOption(network tcpip.NetworkProtocolNumber, optio // options. This method returns an error if the protocol is not supported or // option is not supported by the protocol implementation or the provided value // is incorrect. -func (s *Stack) SetTransportProtocolOption(transport tcpip.TransportProtocolNumber, option interface{}) *tcpip.Error { +func (s *Stack) SetTransportProtocolOption(transport tcpip.TransportProtocolNumber, option tcpip.SettableTransportProtocolOption) *tcpip.Error { transProtoState, ok := s.transportProtocols[transport] if !ok { return tcpip.ErrUnknownProtocol @@ -804,7 +738,7 @@ func (s *Stack) SetTransportProtocolOption(transport tcpip.TransportProtocolNumb // if err := s.TransportProtocolOption(tcpip.TCPProtocolNumber, &v); err != nil { // ... // } -func (s *Stack) TransportProtocolOption(transport tcpip.TransportProtocolNumber, option interface{}) *tcpip.Error { +func (s *Stack) TransportProtocolOption(transport tcpip.TransportProtocolNumber, option tcpip.GettableTransportProtocolOption) *tcpip.Error { transProtoState, ok := s.transportProtocols[transport] if !ok { return tcpip.ErrUnknownProtocol @@ -838,46 +772,37 @@ func (s *Stack) Stats() tcpip.Stats { return s.stats } -// SetForwarding enables or disables the packet forwarding between NICs. -// -// When forwarding becomes enabled, any host-only state on all NICs will be -// cleaned up and if IPv6 is enabled, NDP Router Solicitations will be started. -// When forwarding becomes disabled and if IPv6 is enabled, NDP Router -// Solicitations will be stopped. -func (s *Stack) SetForwarding(enable bool) { - // TODO(igudger, bgeffon): Expose via /proc/sys/net/ipv4/ip_forward. - s.mu.Lock() - defer s.mu.Unlock() +// SetForwarding enables or disables packet forwarding between NICs for the +// passed protocol. +func (s *Stack) SetForwarding(protocolNum tcpip.NetworkProtocolNumber, enable bool) *tcpip.Error { + protocol, ok := s.networkProtocols[protocolNum] + if !ok { + return tcpip.ErrUnknownProtocol + } - // If forwarding status didn't change, do nothing further. - if s.forwarding == enable { - return + forwardingProtocol, ok := protocol.(ForwardingNetworkProtocol) + if !ok { + return tcpip.ErrNotSupported } - s.forwarding = enable + forwardingProtocol.SetForwarding(enable) + return nil +} - // If this stack does not support IPv6, do nothing further. - if _, ok := s.networkProtocols[header.IPv6ProtocolNumber]; !ok { - return +// Forwarding returns true if packet forwarding between NICs is enabled for the +// passed protocol. +func (s *Stack) Forwarding(protocolNum tcpip.NetworkProtocolNumber) bool { + protocol, ok := s.networkProtocols[protocolNum] + if !ok { + return false } - if enable { - for _, nic := range s.nics { - nic.becomeIPv6Router() - } - } else { - for _, nic := range s.nics { - nic.becomeIPv6Host() - } + forwardingProtocol, ok := protocol.(ForwardingNetworkProtocol) + if !ok { + return false } -} -// Forwarding returns if the packet forwarding between NICs is enabled. -func (s *Stack) Forwarding() bool { - // TODO(igudger, bgeffon): Expose via /proc/sys/net/ipv4/ip_forward. - s.mu.RLock() - defer s.mu.RUnlock() - return s.forwarding + return forwardingProtocol.Forwarding() } // SetRouteTable assigns the route table to be used by this stack. It @@ -912,7 +837,7 @@ func (s *Stack) NewEndpoint(transport tcpip.TransportProtocolNumber, network tcp return nil, tcpip.ErrUnknownProtocol } - return t.proto.NewEndpoint(s, network, waiterQueue) + return t.proto.NewEndpoint(network, waiterQueue) } // NewRawEndpoint creates a new raw transport layer endpoint of the given @@ -932,7 +857,7 @@ func (s *Stack) NewRawEndpoint(transport tcpip.TransportProtocolNumber, network return nil, tcpip.ErrUnknownProtocol } - return t.proto.NewRawEndpoint(s, network, waiterQueue) + return t.proto.NewRawEndpoint(network, waiterQueue) } // NewPacketEndpoint creates a new packet endpoint listening for the given @@ -1039,7 +964,8 @@ func (s *Stack) DisableNIC(id tcpip.NICID) *tcpip.Error { return tcpip.ErrUnknownNICID } - return nic.disable() + nic.disable() + return nil } // CheckNIC checks if a NIC is usable. @@ -1052,7 +978,7 @@ func (s *Stack) CheckNIC(id tcpip.NICID) bool { return false } - return nic.enabled() + return nic.Enabled() } // RemoveNIC removes NIC and all related routes from the network stack. @@ -1089,19 +1015,6 @@ func (s *Stack) removeNICLocked(id tcpip.NICID) *tcpip.Error { return nic.remove() } -// NICAddressRanges returns a map of NICIDs to their associated subnets. -func (s *Stack) NICAddressRanges() map[tcpip.NICID][]tcpip.Subnet { - s.mu.RLock() - defer s.mu.RUnlock() - - nics := map[tcpip.NICID][]tcpip.Subnet{} - - for id, nic := range s.nics { - nics[id] = append(nics[id], nic.AddressRanges()...) - } - return nics -} - // NICInfo captures the name and addresses assigned to a NIC. type NICInfo struct { Name string @@ -1143,14 +1056,14 @@ func (s *Stack) NICInfo() map[tcpip.NICID]NICInfo { for id, nic := range s.nics { flags := NICStateFlags{ Up: true, // Netstack interfaces are always up. - Running: nic.enabled(), + Running: nic.Enabled(), Promiscuous: nic.isPromiscuousMode(), - Loopback: nic.isLoopback(), + Loopback: nic.IsLoopback(), } nics[id] = NICInfo{ Name: nic.name, LinkAddress: nic.linkEP.LinkAddress(), - ProtocolAddresses: nic.PrimaryAddresses(), + ProtocolAddresses: nic.primaryAddresses(), Flags: flags, MTU: nic.linkEP.MTU(), Stats: nic.stats, @@ -1209,41 +1122,12 @@ func (s *Stack) AddProtocolAddressWithOptions(id tcpip.NICID, protocolAddress tc s.mu.RLock() defer s.mu.RUnlock() - nic := s.nics[id] - if nic == nil { + nic, ok := s.nics[id] + if !ok { return tcpip.ErrUnknownNICID } - return nic.AddAddress(protocolAddress, peb) -} - -// AddAddressRange adds a range of addresses to the specified NIC. The range is -// given by a subnet address, and all addresses contained in the subnet are -// used except for the subnet address itself and the subnet's broadcast -// address. -func (s *Stack) AddAddressRange(id tcpip.NICID, protocol tcpip.NetworkProtocolNumber, subnet tcpip.Subnet) *tcpip.Error { - s.mu.RLock() - defer s.mu.RUnlock() - - if nic, ok := s.nics[id]; ok { - nic.AddAddressRange(protocol, subnet) - return nil - } - - return tcpip.ErrUnknownNICID -} - -// RemoveAddressRange removes the range of addresses from the specified NIC. -func (s *Stack) RemoveAddressRange(id tcpip.NICID, subnet tcpip.Subnet) *tcpip.Error { - s.mu.RLock() - defer s.mu.RUnlock() - - if nic, ok := s.nics[id]; ok { - nic.RemoveAddressRange(subnet) - return nil - } - - return tcpip.ErrUnknownNICID + return nic.addAddress(protocolAddress, peb) } // RemoveAddress removes an existing network-layer address from the specified @@ -1253,7 +1137,7 @@ func (s *Stack) RemoveAddress(id tcpip.NICID, addr tcpip.Address) *tcpip.Error { defer s.mu.RUnlock() if nic, ok := s.nics[id]; ok { - return nic.RemoveAddress(addr) + return nic.removeAddress(addr) } return tcpip.ErrUnknownNICID @@ -1267,7 +1151,7 @@ func (s *Stack) AllAddresses() map[tcpip.NICID][]tcpip.ProtocolAddress { nics := make(map[tcpip.NICID][]tcpip.ProtocolAddress) for id, nic := range s.nics { - nics[id] = nic.AllAddresses() + nics[id] = nic.allPermanentAddresses() } return nics } @@ -1289,7 +1173,7 @@ func (s *Stack) GetMainNICAddress(id tcpip.NICID, protocol tcpip.NetworkProtocol return nic.primaryAddress(protocol), nil } -func (s *Stack) getRefEP(nic *NIC, localAddr, remoteAddr tcpip.Address, netProto tcpip.NetworkProtocolNumber) (ref *referencedNetworkEndpoint) { +func (s *Stack) getAddressEP(nic *NIC, localAddr, remoteAddr tcpip.Address, netProto tcpip.NetworkProtocolNumber) AssignableAddressEndpoint { if len(localAddr) == 0 { return nic.primaryEndpoint(netProto, remoteAddr) } @@ -1306,9 +1190,9 @@ func (s *Stack) FindRoute(id tcpip.NICID, localAddr, remoteAddr tcpip.Address, n isMulticast := header.IsV4MulticastAddress(remoteAddr) || header.IsV6MulticastAddress(remoteAddr) needRoute := !(isLocalBroadcast || isMulticast || header.IsV6LinkLocalAddress(remoteAddr)) if id != 0 && !needRoute { - if nic, ok := s.nics[id]; ok && nic.enabled() { - if ref := s.getRefEP(nic, localAddr, remoteAddr, netProto); ref != nil { - return makeRoute(netProto, ref.ep.ID().LocalAddress, remoteAddr, nic.linkEP.LinkAddress(), ref, s.handleLocal && !nic.isLoopback(), multicastLoop && !nic.isLoopback()), nil + if nic, ok := s.nics[id]; ok && nic.Enabled() { + if addressEndpoint := s.getAddressEP(nic, localAddr, remoteAddr, netProto); addressEndpoint != nil { + return makeRoute(netProto, addressEndpoint.AddressWithPrefix().Address, remoteAddr, nic, addressEndpoint, s.handleLocal && !nic.IsLoopback(), multicastLoop && !nic.IsLoopback()), nil } } } else { @@ -1316,22 +1200,20 @@ func (s *Stack) FindRoute(id tcpip.NICID, localAddr, remoteAddr tcpip.Address, n if (id != 0 && id != route.NIC) || (len(remoteAddr) != 0 && !route.Destination.Contains(remoteAddr)) { continue } - if nic, ok := s.nics[route.NIC]; ok && nic.enabled() { - if ref := s.getRefEP(nic, localAddr, remoteAddr, netProto); ref != nil { + if nic, ok := s.nics[route.NIC]; ok && nic.Enabled() { + if addressEndpoint := s.getAddressEP(nic, localAddr, remoteAddr, netProto); addressEndpoint != nil { if len(remoteAddr) == 0 { // If no remote address was provided, then the route // provided will refer to the link local address. - remoteAddr = ref.ep.ID().LocalAddress + remoteAddr = addressEndpoint.AddressWithPrefix().Address } - r := makeRoute(netProto, ref.ep.ID().LocalAddress, remoteAddr, nic.linkEP.LinkAddress(), ref, s.handleLocal && !nic.isLoopback(), multicastLoop && !nic.isLoopback()) - r.directedBroadcast = route.Destination.IsBroadcast(remoteAddr) - + r := makeRoute(netProto, addressEndpoint.AddressWithPrefix().Address, remoteAddr, nic, addressEndpoint, s.handleLocal && !nic.IsLoopback(), multicastLoop && !nic.IsLoopback()) if len(route.Gateway) > 0 { if needRoute { r.NextHop = route.Gateway } - } else if r.directedBroadcast { + } else if subnet := addressEndpoint.AddressWithPrefix().Subnet(); subnet.IsBroadcast(remoteAddr) { r.RemoteLinkAddress = header.EthernetBroadcastAddress } @@ -1364,26 +1246,25 @@ func (s *Stack) CheckLocalAddress(nicID tcpip.NICID, protocol tcpip.NetworkProto // If a NIC is specified, we try to find the address there only. if nicID != 0 { - nic := s.nics[nicID] - if nic == nil { + nic, ok := s.nics[nicID] + if !ok { return 0 } - ref := nic.findEndpoint(protocol, addr, CanBePrimaryEndpoint) - if ref == nil { + addressEndpoint := nic.findEndpoint(protocol, addr, CanBePrimaryEndpoint) + if addressEndpoint == nil { return 0 } - ref.decRef() + addressEndpoint.DecRef() return nic.id } // Go through all the NICs. for _, nic := range s.nics { - ref := nic.findEndpoint(protocol, addr, CanBePrimaryEndpoint) - if ref != nil { - ref.decRef() + if addressEndpoint := nic.findEndpoint(protocol, addr, CanBePrimaryEndpoint); addressEndpoint != nil { + addressEndpoint.DecRef() return nic.id } } @@ -1396,8 +1277,8 @@ func (s *Stack) SetPromiscuousMode(nicID tcpip.NICID, enable bool) *tcpip.Error s.mu.RLock() defer s.mu.RUnlock() - nic := s.nics[nicID] - if nic == nil { + nic, ok := s.nics[nicID] + if !ok { return tcpip.ErrUnknownNICID } @@ -1412,8 +1293,8 @@ func (s *Stack) SetSpoofing(nicID tcpip.NICID, enable bool) *tcpip.Error { s.mu.RLock() defer s.mu.RUnlock() - nic := s.nics[nicID] - if nic == nil { + nic, ok := s.nics[nicID] + if !ok { return tcpip.ErrUnknownNICID } @@ -1445,8 +1326,33 @@ func (s *Stack) GetLinkAddress(nicID tcpip.NICID, addr, localAddr tcpip.Address, return s.linkAddrCache.get(fullAddr, linkRes, localAddr, nic.linkEP, waker) } -// RemoveWaker implements LinkAddressCache.RemoveWaker. +// Neighbors returns all IP to MAC address associations. +func (s *Stack) Neighbors(nicID tcpip.NICID) ([]NeighborEntry, *tcpip.Error) { + s.mu.RLock() + nic, ok := s.nics[nicID] + s.mu.RUnlock() + + if !ok { + return nil, tcpip.ErrUnknownNICID + } + + return nic.neighbors() +} + +// RemoveWaker removes a waker that has been added when link resolution for +// addr was requested. func (s *Stack) RemoveWaker(nicID tcpip.NICID, addr tcpip.Address, waker *sleep.Waker) { + if s.useNeighborCache { + s.mu.RLock() + nic, ok := s.nics[nicID] + s.mu.RUnlock() + + if ok { + nic.removeWaker(addr, waker) + } + return + } + s.mu.RLock() defer s.mu.RUnlock() @@ -1456,6 +1362,47 @@ func (s *Stack) RemoveWaker(nicID tcpip.NICID, addr tcpip.Address, waker *sleep. } } +// AddStaticNeighbor statically associates an IP address to a MAC address. +func (s *Stack) AddStaticNeighbor(nicID tcpip.NICID, addr tcpip.Address, linkAddr tcpip.LinkAddress) *tcpip.Error { + s.mu.RLock() + nic, ok := s.nics[nicID] + s.mu.RUnlock() + + if !ok { + return tcpip.ErrUnknownNICID + } + + return nic.addStaticNeighbor(addr, linkAddr) +} + +// RemoveNeighbor removes an IP to MAC address association previously created +// either automically or by AddStaticNeighbor. Returns ErrBadAddress if there +// is no association with the provided address. +func (s *Stack) RemoveNeighbor(nicID tcpip.NICID, addr tcpip.Address) *tcpip.Error { + s.mu.RLock() + nic, ok := s.nics[nicID] + s.mu.RUnlock() + + if !ok { + return tcpip.ErrUnknownNICID + } + + return nic.removeNeighbor(addr) +} + +// ClearNeighbors removes all IP to MAC address associations. +func (s *Stack) ClearNeighbors(nicID tcpip.NICID) *tcpip.Error { + s.mu.RLock() + nic, ok := s.nics[nicID] + s.mu.RUnlock() + + if !ok { + return tcpip.ErrUnknownNICID + } + + return nic.clearNeighbors() +} + // RegisterTransportEndpoint registers the given endpoint with the stack // transport dispatcher. Received packets that match the provided id will be // delivered to the given endpoint; specifying a nic is optional, but @@ -1479,10 +1426,9 @@ func (s *Stack) UnregisterTransportEndpoint(nicID tcpip.NICID, netProtos []tcpip // StartTransportEndpointCleanup removes the endpoint with the given id from // the stack transport dispatcher. It also transitions it to the cleanup stage. func (s *Stack) StartTransportEndpointCleanup(nicID tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, flags ports.Flags, bindToDevice tcpip.NICID) { - s.mu.Lock() - defer s.mu.Unlock() - + s.cleanupEndpointsMu.Lock() s.cleanupEndpoints[ep] = struct{}{} + s.cleanupEndpointsMu.Unlock() s.demux.unregisterEndpoint(netProtos, protocol, id, ep, flags, bindToDevice) } @@ -1490,9 +1436,9 @@ func (s *Stack) StartTransportEndpointCleanup(nicID tcpip.NICID, netProtos []tcp // CompleteTransportEndpointCleanup removes the endpoint from the cleanup // stage. func (s *Stack) CompleteTransportEndpointCleanup(ep TransportEndpoint) { - s.mu.Lock() + s.cleanupEndpointsMu.Lock() delete(s.cleanupEndpoints, ep) - s.mu.Unlock() + s.cleanupEndpointsMu.Unlock() } // FindTransportEndpoint finds an endpoint that most closely matches the provided @@ -1535,23 +1481,23 @@ func (s *Stack) RegisteredEndpoints() []TransportEndpoint { // CleanupEndpoints returns endpoints currently in the cleanup state. func (s *Stack) CleanupEndpoints() []TransportEndpoint { - s.mu.Lock() + s.cleanupEndpointsMu.Lock() es := make([]TransportEndpoint, 0, len(s.cleanupEndpoints)) for e := range s.cleanupEndpoints { es = append(es, e) } - s.mu.Unlock() + s.cleanupEndpointsMu.Unlock() return es } // RestoreCleanupEndpoints adds endpoints to cleanup tracking. This is useful // for restoring a stack after a save. func (s *Stack) RestoreCleanupEndpoints(es []TransportEndpoint) { - s.mu.Lock() + s.cleanupEndpointsMu.Lock() for _, e := range es { s.cleanupEndpoints[e] = struct{}{} } - s.mu.Unlock() + s.cleanupEndpointsMu.Unlock() } // Close closes all currently registered transport endpoints. @@ -1746,18 +1692,17 @@ func (s *Stack) TransportProtocolInstance(num tcpip.TransportProtocolNumber) Tra // guarantee provided on which probe will be invoked. Ideally this should only // be called once per stack. func (s *Stack) AddTCPProbe(probe TCPProbeFunc) { - s.mu.Lock() - s.tcpProbeFunc = probe - s.mu.Unlock() + s.tcpProbeFunc.Store(probe) } // GetTCPProbe returns the TCPProbeFunc if installed with AddTCPProbe, nil // otherwise. func (s *Stack) GetTCPProbe() TCPProbeFunc { - s.mu.Lock() - p := s.tcpProbeFunc - s.mu.Unlock() - return p + p := s.tcpProbeFunc.Load() + if p == nil { + return nil + } + return p.(TCPProbeFunc) } // RemoveTCPProbe removes an installed TCP probe. @@ -1766,9 +1711,8 @@ func (s *Stack) GetTCPProbe() TCPProbeFunc { // have a probe attached. Endpoints already created will continue to invoke // TCP probe. func (s *Stack) RemoveTCPProbe() { - s.mu.Lock() - s.tcpProbeFunc = nil - s.mu.Unlock() + // This must be TCPProbeFunc(nil) because atomic.Value.Store(nil) panics. + s.tcpProbeFunc.Store(TCPProbeFunc(nil)) } // JoinGroup joins the given multicast group on the given NIC. @@ -1789,7 +1733,7 @@ func (s *Stack) LeaveGroup(protocol tcpip.NetworkProtocolNumber, nicID tcpip.NIC defer s.mu.RUnlock() if nic, ok := s.nics[nicID]; ok { - return nic.leaveGroup(multicastAddr) + return nic.leaveGroup(protocol, multicastAddr) } return tcpip.ErrUnknownNICID } @@ -1841,53 +1785,18 @@ func (s *Stack) AllowICMPMessage() bool { return s.icmpRateLimiter.Allow() } -// IsAddrTentative returns true if addr is tentative on the NIC with ID id. -// -// Note that if addr is not associated with a NIC with id ID, then this -// function will return false. It will only return true if the address is -// associated with the NIC AND it is tentative. -func (s *Stack) IsAddrTentative(id tcpip.NICID, addr tcpip.Address) (bool, *tcpip.Error) { - s.mu.RLock() - defer s.mu.RUnlock() - - nic, ok := s.nics[id] - if !ok { - return false, tcpip.ErrUnknownNICID - } - - return nic.isAddrTentative(addr), nil -} - -// DupTentativeAddrDetected attempts to inform the NIC with ID id that a -// tentative addr on it is a duplicate on a link. -func (s *Stack) DupTentativeAddrDetected(id tcpip.NICID, addr tcpip.Address) *tcpip.Error { - s.mu.Lock() - defer s.mu.Unlock() - - nic, ok := s.nics[id] - if !ok { - return tcpip.ErrUnknownNICID - } - - return nic.dupTentativeAddrDetected(addr) -} - -// SetNDPConfigurations sets the per-interface NDP configurations on the NIC -// with ID id to c. -// -// Note, if c contains invalid NDP configuration values, it will be fixed to -// use default values for the erroneous values. -func (s *Stack) SetNDPConfigurations(id tcpip.NICID, c NDPConfigurations) *tcpip.Error { +// GetNetworkEndpoint returns the NetworkEndpoint with the specified protocol +// number installed on the specified NIC. +func (s *Stack) GetNetworkEndpoint(nicID tcpip.NICID, proto tcpip.NetworkProtocolNumber) (NetworkEndpoint, *tcpip.Error) { s.mu.Lock() defer s.mu.Unlock() - nic, ok := s.nics[id] + nic, ok := s.nics[nicID] if !ok { - return tcpip.ErrUnknownNICID + return nil, tcpip.ErrUnknownNICID } - nic.setNDPConfigs(c) - return nil + return nic.networkEndpoints[proto], nil } // NUDConfigurations gets the per-interface NUD configurations. @@ -1900,7 +1809,7 @@ func (s *Stack) NUDConfigurations(id tcpip.NICID) (NUDConfigurations, *tcpip.Err return NUDConfigurations{}, tcpip.ErrUnknownNICID } - return nic.NUDConfigs() + return nic.nudConfigs() } // SetNUDConfigurations sets the per-interface NUD configurations. @@ -1919,22 +1828,6 @@ func (s *Stack) SetNUDConfigurations(id tcpip.NICID, c NUDConfigurations) *tcpip return nic.setNUDConfigs(c) } -// HandleNDPRA provides a NIC with ID id a validated NDP Router Advertisement -// message that it needs to handle. -func (s *Stack) HandleNDPRA(id tcpip.NICID, ip tcpip.Address, ra header.NDPRouterAdvert) *tcpip.Error { - s.mu.Lock() - defer s.mu.Unlock() - - nic, ok := s.nics[id] - if !ok { - return tcpip.ErrUnknownNICID - } - - nic.handleNDPRA(ip, ra) - - return nil -} - // Seed returns a 32 bit value that can be used as a seed value for port // picking, ISN generation etc. // @@ -1972,28 +1865,26 @@ func generateRandInt64() int64 { // FindNetworkEndpoint returns the network endpoint for the given address. func (s *Stack) FindNetworkEndpoint(netProto tcpip.NetworkProtocolNumber, address tcpip.Address) (NetworkEndpoint, *tcpip.Error) { - s.mu.Lock() - defer s.mu.Unlock() + s.mu.RLock() + defer s.mu.RUnlock() for _, nic := range s.nics { - id := NetworkEndpointID{address} - - if ref, ok := nic.mu.endpoints[id]; ok { - nic.mu.RLock() - defer nic.mu.RUnlock() - - // An endpoint with this id exists, check if it can be - // used and return it. - return ref.ep, nil + addressEndpoint := nic.getAddressOrCreateTempInner(netProto, address, false /* createTemp */, NeverPrimaryEndpoint) + if addressEndpoint == nil { + continue } + + ep := addressEndpoint.NetworkEndpoint() + addressEndpoint.DecRef() + return ep, nil } return nil, tcpip.ErrBadAddress } -// FindNICNameFromID returns the name of the nic for the given NICID. +// FindNICNameFromID returns the name of the NIC for the given NICID. func (s *Stack) FindNICNameFromID(id tcpip.NICID) string { - s.mu.Lock() - defer s.mu.Unlock() + s.mu.RLock() + defer s.mu.RUnlock() nic, ok := s.nics[id] if !ok { @@ -2002,3 +1893,8 @@ func (s *Stack) FindNICNameFromID(id tcpip.NICID) string { return nic.Name() } + +// NewJob returns a new tcpip.Job using the stack's clock. +func (s *Stack) NewJob(l sync.Locker, f func()) *tcpip.Job { + return tcpip.NewJob(s.clock, l, f) +} diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go index f22062889..aa20f750b 100644 --- a/pkg/tcpip/stack/stack_test.go +++ b/pkg/tcpip/stack/stack_test.go @@ -21,19 +21,21 @@ import ( "bytes" "fmt" "math" + "net" "sort" - "strings" "testing" "time" "github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp/cmpopts" "gvisor.dev/gvisor/pkg/rand" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/link/channel" "gvisor.dev/gvisor/pkg/tcpip/link/loopback" + "gvisor.dev/gvisor/pkg/tcpip/network/arp" "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" "gvisor.dev/gvisor/pkg/tcpip/stack" @@ -67,40 +69,53 @@ const ( // use the first three: destination address, source address, and transport // protocol. They're all one byte fields to simplify parsing. type fakeNetworkEndpoint struct { + stack.AddressableEndpointState + + mu struct { + sync.RWMutex + + enabled bool + } + nicID tcpip.NICID - id stack.NetworkEndpointID - prefixLen int proto *fakeNetworkProtocol dispatcher stack.TransportDispatcher ep stack.LinkEndpoint } -func (f *fakeNetworkEndpoint) MTU() uint32 { - return f.ep.MTU() - uint32(f.MaxHeaderLength()) +func (f *fakeNetworkEndpoint) Enable() *tcpip.Error { + f.mu.Lock() + defer f.mu.Unlock() + f.mu.enabled = true + return nil } -func (f *fakeNetworkEndpoint) NICID() tcpip.NICID { - return f.nicID +func (f *fakeNetworkEndpoint) Enabled() bool { + f.mu.RLock() + defer f.mu.RUnlock() + return f.mu.enabled } -func (f *fakeNetworkEndpoint) PrefixLen() int { - return f.prefixLen +func (f *fakeNetworkEndpoint) Disable() { + f.mu.Lock() + defer f.mu.Unlock() + f.mu.enabled = false } -func (*fakeNetworkEndpoint) DefaultTTL() uint8 { - return 123 +func (f *fakeNetworkEndpoint) MTU() uint32 { + return f.ep.MTU() - uint32(f.MaxHeaderLength()) } -func (f *fakeNetworkEndpoint) ID() *stack.NetworkEndpointID { - return &f.id +func (*fakeNetworkEndpoint) DefaultTTL() uint8 { + return 123 } func (f *fakeNetworkEndpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { // Increment the received packet count in the protocol descriptor. - f.proto.packetCount[int(f.id.LocalAddress[0])%len(f.proto.packetCount)]++ + f.proto.packetCount[int(r.LocalAddress[0])%len(f.proto.packetCount)]++ // Handle control packets. - if pkt.NetworkHeader[protocolNumberOffset] == uint8(fakeControlProtocol) { + if pkt.NetworkHeader().View()[protocolNumberOffset] == uint8(fakeControlProtocol) { nb, ok := pkt.Data.PullUp(fakeNetHeaderLen) if !ok { return @@ -116,7 +131,7 @@ func (f *fakeNetworkEndpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuff } // Dispatch the packet to the transport protocol. - f.dispatcher.DeliverTransportPacket(r, tcpip.TransportProtocolNumber(pkt.NetworkHeader[protocolNumberOffset]), pkt) + f.dispatcher.DeliverTransportPacket(r, tcpip.TransportProtocolNumber(pkt.NetworkHeader().View()[protocolNumberOffset]), pkt) } func (f *fakeNetworkEndpoint) MaxHeaderLength() uint16 { @@ -127,10 +142,6 @@ func (f *fakeNetworkEndpoint) PseudoHeaderChecksum(protocol tcpip.TransportProto return 0 } -func (f *fakeNetworkEndpoint) Capabilities() stack.LinkEndpointCapabilities { - return f.ep.Capabilities() -} - func (f *fakeNetworkEndpoint) NetworkProtocolNumber() tcpip.NetworkProtocolNumber { return f.proto.Number() } @@ -141,10 +152,10 @@ func (f *fakeNetworkEndpoint) WritePacket(r *stack.Route, gso *stack.GSO, params // Add the protocol's header to the packet and send it to the link // endpoint. - pkt.NetworkHeader = pkt.Header.Prepend(fakeNetHeaderLen) - pkt.NetworkHeader[dstAddrOffset] = r.RemoteAddress[0] - pkt.NetworkHeader[srcAddrOffset] = f.id.LocalAddress[0] - pkt.NetworkHeader[protocolNumberOffset] = byte(params.Protocol) + hdr := pkt.NetworkHeader().Push(fakeNetHeaderLen) + hdr[dstAddrOffset] = r.RemoteAddress[0] + hdr[srcAddrOffset] = r.LocalAddress[0] + hdr[protocolNumberOffset] = byte(params.Protocol) if r.Loop&stack.PacketLoop != 0 { f.HandlePacket(r, pkt) @@ -165,16 +176,8 @@ func (*fakeNetworkEndpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack return tcpip.ErrNotSupported } -func (*fakeNetworkEndpoint) Close() {} - -type fakeNetGoodOption bool - -type fakeNetBadOption bool - -type fakeNetInvalidValueOption int - -type fakeNetOptions struct { - good bool +func (f *fakeNetworkEndpoint) Close() { + f.AddressableEndpointState.Cleanup() } // fakeNetworkProtocol is a network-layer protocol descriptor. It aggregates the @@ -183,7 +186,12 @@ type fakeNetOptions struct { type fakeNetworkProtocol struct { packetCount [10]int sendPacketCount [10]int - opts fakeNetOptions + defaultTTL uint8 + + mu struct { + sync.RWMutex + forwarding bool + } } func (f *fakeNetworkProtocol) Number() tcpip.NetworkProtocolNumber { @@ -206,57 +214,67 @@ func (*fakeNetworkProtocol) ParseAddresses(v buffer.View) (src, dst tcpip.Addres return tcpip.Address(v[srcAddrOffset : srcAddrOffset+1]), tcpip.Address(v[dstAddrOffset : dstAddrOffset+1]) } -func (f *fakeNetworkProtocol) NewEndpoint(nicID tcpip.NICID, addrWithPrefix tcpip.AddressWithPrefix, linkAddrCache stack.LinkAddressCache, dispatcher stack.TransportDispatcher, ep stack.LinkEndpoint, _ *stack.Stack) (stack.NetworkEndpoint, *tcpip.Error) { - return &fakeNetworkEndpoint{ - nicID: nicID, - id: stack.NetworkEndpointID{LocalAddress: addrWithPrefix.Address}, - prefixLen: addrWithPrefix.PrefixLen, +func (f *fakeNetworkProtocol) NewEndpoint(nic stack.NetworkInterface, _ stack.LinkAddressCache, _ stack.NUDHandler, dispatcher stack.TransportDispatcher) stack.NetworkEndpoint { + e := &fakeNetworkEndpoint{ + nicID: nic.ID(), proto: f, dispatcher: dispatcher, - ep: ep, - }, nil + ep: nic.LinkEndpoint(), + } + e.AddressableEndpointState.Init(e) + return e } -func (f *fakeNetworkProtocol) SetOption(option interface{}) *tcpip.Error { +func (f *fakeNetworkProtocol) SetOption(option tcpip.SettableNetworkProtocolOption) *tcpip.Error { switch v := option.(type) { - case fakeNetGoodOption: - f.opts.good = bool(v) + case *tcpip.DefaultTTLOption: + f.defaultTTL = uint8(*v) return nil - case fakeNetInvalidValueOption: - return tcpip.ErrInvalidOptionValue default: return tcpip.ErrUnknownProtocolOption } } -func (f *fakeNetworkProtocol) Option(option interface{}) *tcpip.Error { +func (f *fakeNetworkProtocol) Option(option tcpip.GettableNetworkProtocolOption) *tcpip.Error { switch v := option.(type) { - case *fakeNetGoodOption: - *v = fakeNetGoodOption(f.opts.good) + case *tcpip.DefaultTTLOption: + *v = tcpip.DefaultTTLOption(f.defaultTTL) return nil default: return tcpip.ErrUnknownProtocolOption } } -// Close implements TransportProtocol.Close. +// Close implements NetworkProtocol.Close. func (*fakeNetworkProtocol) Close() {} -// Wait implements TransportProtocol.Wait. +// Wait implements NetworkProtocol.Wait. func (*fakeNetworkProtocol) Wait() {} -// Parse implements TransportProtocol.Parse. +// Parse implements NetworkProtocol.Parse. func (*fakeNetworkProtocol) Parse(pkt *stack.PacketBuffer) (tcpip.TransportProtocolNumber, bool, bool) { - hdr, ok := pkt.Data.PullUp(fakeNetHeaderLen) + hdr, ok := pkt.NetworkHeader().Consume(fakeNetHeaderLen) if !ok { return 0, false, false } - pkt.NetworkHeader = hdr - pkt.Data.TrimFront(fakeNetHeaderLen) return tcpip.TransportProtocolNumber(hdr[protocolNumberOffset]), true, true } -func fakeNetFactory() stack.NetworkProtocol { +// Forwarding implements stack.ForwardingNetworkProtocol. +func (f *fakeNetworkProtocol) Forwarding() bool { + f.mu.RLock() + defer f.mu.RUnlock() + return f.mu.forwarding +} + +// SetForwarding implements stack.ForwardingNetworkProtocol. +func (f *fakeNetworkProtocol) SetForwarding(v bool) { + f.mu.Lock() + defer f.mu.Unlock() + f.mu.forwarding = v +} + +func fakeNetFactory(*stack.Stack) stack.NetworkProtocol { return &fakeNetworkProtocol{} } @@ -277,12 +295,23 @@ func (l *linkEPWithMockedAttach) isAttached() bool { return l.attached } +// Checks to see if list contains an address. +func containsAddr(list []tcpip.ProtocolAddress, item tcpip.ProtocolAddress) bool { + for _, i := range list { + if i == item { + return true + } + } + + return false +} + func TestNetworkReceive(t *testing.T) { // Create a stack with the fake network protocol, one nic, and two // addresses attached to it: 1 & 2. ep := channel.New(10, defaultMTU, "") s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, }) if err := s.CreateNIC(1, ep); err != nil { t.Fatal("CreateNIC failed:", err) @@ -302,9 +331,9 @@ func TestNetworkReceive(t *testing.T) { // Make sure packet with wrong address is not delivered. buf[dstAddrOffset] = 3 - ep.InjectInbound(fakeNetNumber, &stack.PacketBuffer{ + ep.InjectInbound(fakeNetNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: buf.ToVectorisedView(), - }) + })) if fakeNet.packetCount[1] != 0 { t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 0) } @@ -314,9 +343,9 @@ func TestNetworkReceive(t *testing.T) { // Make sure packet is delivered to first endpoint. buf[dstAddrOffset] = 1 - ep.InjectInbound(fakeNetNumber, &stack.PacketBuffer{ + ep.InjectInbound(fakeNetNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: buf.ToVectorisedView(), - }) + })) if fakeNet.packetCount[1] != 1 { t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 1) } @@ -326,9 +355,9 @@ func TestNetworkReceive(t *testing.T) { // Make sure packet is delivered to second endpoint. buf[dstAddrOffset] = 2 - ep.InjectInbound(fakeNetNumber, &stack.PacketBuffer{ + ep.InjectInbound(fakeNetNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: buf.ToVectorisedView(), - }) + })) if fakeNet.packetCount[1] != 1 { t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 1) } @@ -337,9 +366,9 @@ func TestNetworkReceive(t *testing.T) { } // Make sure packet is not delivered if protocol number is wrong. - ep.InjectInbound(fakeNetNumber-1, &stack.PacketBuffer{ + ep.InjectInbound(fakeNetNumber-1, stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: buf.ToVectorisedView(), - }) + })) if fakeNet.packetCount[1] != 1 { t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 1) } @@ -349,9 +378,9 @@ func TestNetworkReceive(t *testing.T) { // Make sure packet that is too small is dropped. buf.CapLength(2) - ep.InjectInbound(fakeNetNumber, &stack.PacketBuffer{ + ep.InjectInbound(fakeNetNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: buf.ToVectorisedView(), - }) + })) if fakeNet.packetCount[1] != 1 { t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 1) } @@ -370,11 +399,10 @@ func sendTo(s *stack.Stack, addr tcpip.Address, payload buffer.View) *tcpip.Erro } func send(r stack.Route, payload buffer.View) *tcpip.Error { - hdr := buffer.NewPrependable(int(r.MaxHeaderLength())) - return r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: fakeTransNumber, TTL: 123, TOS: stack.DefaultTOS}, &stack.PacketBuffer{ - Header: hdr, - Data: payload.ToVectorisedView(), - }) + return r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: fakeTransNumber, TTL: 123, TOS: stack.DefaultTOS}, stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: int(r.MaxHeaderLength()), + Data: payload.ToVectorisedView(), + })) } func testSendTo(t *testing.T, s *stack.Stack, addr tcpip.Address, ep *channel.Endpoint, payload buffer.View) { @@ -429,9 +457,9 @@ func testFailingRecv(t *testing.T, fakeNet *fakeNetworkProtocol, localAddrByte b func testRecvInternal(t *testing.T, fakeNet *fakeNetworkProtocol, localAddrByte byte, ep *channel.Endpoint, buf buffer.View, want int) { t.Helper() - ep.InjectInbound(fakeNetNumber, &stack.PacketBuffer{ + ep.InjectInbound(fakeNetNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: buf.ToVectorisedView(), - }) + })) if got := fakeNet.PacketCount(localAddrByte); got != want { t.Errorf("receive packet count: got = %d, want %d", got, want) } @@ -443,7 +471,7 @@ func TestNetworkSend(t *testing.T) { // existing nic. ep := channel.New(10, defaultMTU, "") s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, }) if err := s.CreateNIC(1, ep); err != nil { t.Fatal("NewNIC failed:", err) @@ -470,7 +498,7 @@ func TestNetworkSendMultiRoute(t *testing.T) { // addresses per nic, the first nic has odd address, the second one has // even addresses. s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, }) ep1 := channel.New(10, defaultMTU, "") @@ -570,7 +598,7 @@ func TestAttachToLinkEndpointImmediately(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, }) e := linkEPWithMockedAttach{ @@ -589,7 +617,7 @@ func TestAttachToLinkEndpointImmediately(t *testing.T) { func TestDisableUnknownNIC(t *testing.T) { s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, }) if err := s.DisableNIC(1); err != tcpip.ErrUnknownNICID { @@ -601,7 +629,7 @@ func TestDisabledNICsNICInfoAndCheckNIC(t *testing.T) { const nicID = 1 s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, }) e := loopback.New() @@ -648,7 +676,7 @@ func TestDisabledNICsNICInfoAndCheckNIC(t *testing.T) { func TestRemoveUnknownNIC(t *testing.T) { s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, }) if err := s.RemoveNIC(1); err != tcpip.ErrUnknownNICID { @@ -660,7 +688,7 @@ func TestRemoveNIC(t *testing.T) { const nicID = 1 s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, }) e := linkEPWithMockedAttach{ @@ -721,7 +749,7 @@ func TestRouteWithDownNIC(t *testing.T) { setup := func(t *testing.T) (*stack.Stack, *channel.Endpoint, *channel.Endpoint) { s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, }) ep1 := channel.New(1, defaultMTU, "") @@ -887,7 +915,7 @@ func TestRoutes(t *testing.T) { // addresses per nic, the first nic has odd address, the second one has // even addresses. s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, }) ep1 := channel.New(10, defaultMTU, "") @@ -967,7 +995,7 @@ func TestAddressRemoval(t *testing.T) { remoteAddr := tcpip.Address("\x02") s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, }) ep := channel.New(10, defaultMTU, "") @@ -1014,7 +1042,7 @@ func TestAddressRemovalWithRouteHeld(t *testing.T) { remoteAddr := tcpip.Address("\x02") s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, }) ep := channel.New(10, defaultMTU, "") @@ -1105,7 +1133,7 @@ func TestEndpointExpiration(t *testing.T) { for _, spoofing := range []bool{true, false} { t.Run(fmt.Sprintf("promiscuous=%t spoofing=%t", promiscuous, spoofing), func(t *testing.T) { s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, }) ep := channel.New(10, defaultMTU, "") @@ -1263,7 +1291,7 @@ func TestEndpointExpiration(t *testing.T) { func TestPromiscuousMode(t *testing.T) { s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, }) ep := channel.New(10, defaultMTU, "") @@ -1315,7 +1343,7 @@ func TestSpoofingWithAddress(t *testing.T) { dstAddr := tcpip.Address("\x03") s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, }) ep := channel.New(10, defaultMTU, "") @@ -1381,7 +1409,7 @@ func TestSpoofingNoAddress(t *testing.T) { dstAddr := tcpip.Address("\x02") s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, }) ep := channel.New(10, defaultMTU, "") @@ -1444,7 +1472,7 @@ func verifyRoute(gotRoute, wantRoute stack.Route) error { func TestOutgoingBroadcastWithEmptyRouteTable(t *testing.T) { s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, }) ep := channel.New(10, defaultMTU, "") @@ -1487,7 +1515,7 @@ func TestOutgoingBroadcastWithRouteTable(t *testing.T) { // Create a new stack with two NICs. s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, }) ep := channel.New(10, defaultMTU, "") if err := s.CreateNIC(1, ep); err != nil { @@ -1588,7 +1616,7 @@ func TestMulticastOrIPv6LinkLocalNeedsNoRoute(t *testing.T) { } { t.Run(tc.name, func(t *testing.T) { s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, }) ep := channel.New(10, defaultMTU, "") @@ -1643,239 +1671,24 @@ func TestMulticastOrIPv6LinkLocalNeedsNoRoute(t *testing.T) { } } -// Add a range of addresses, then check that a packet is delivered. -func TestAddressRangeAcceptsMatchingPacket(t *testing.T) { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, - }) - - ep := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(1, ep); err != nil { - t.Fatal("CreateNIC failed:", err) - } - - { - subnet, err := tcpip.NewSubnet("\x00", "\x00") - if err != nil { - t.Fatal(err) - } - s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: 1}}) - } - - fakeNet := s.NetworkProtocolInstance(fakeNetNumber).(*fakeNetworkProtocol) - - buf := buffer.NewView(30) - - const localAddrByte byte = 0x01 - buf[dstAddrOffset] = localAddrByte - subnet, err := tcpip.NewSubnet(tcpip.Address("\x00"), tcpip.AddressMask("\xF0")) - if err != nil { - t.Fatal("NewSubnet failed:", err) - } - if err := s.AddAddressRange(1, fakeNetNumber, subnet); err != nil { - t.Fatal("AddAddressRange failed:", err) - } - - testRecv(t, fakeNet, localAddrByte, ep, buf) -} - -func testNicForAddressRange(t *testing.T, nicID tcpip.NICID, s *stack.Stack, subnet tcpip.Subnet, rangeExists bool) { - t.Helper() - - // Loop over all addresses and check them. - numOfAddresses := 1 << uint(8-subnet.Prefix()) - if numOfAddresses < 1 || numOfAddresses > 255 { - t.Fatalf("got numOfAddresses = %d, want = [1 .. 255] (subnet=%s)", numOfAddresses, subnet) - } - - addrBytes := []byte(subnet.ID()) - for i := 0; i < numOfAddresses; i++ { - addr := tcpip.Address(addrBytes) - wantNicID := nicID - // The subnet and broadcast addresses are skipped. - if !rangeExists || addr == subnet.ID() || addr == subnet.Broadcast() { - wantNicID = 0 - } - if gotNicID := s.CheckLocalAddress(0, fakeNetNumber, addr); gotNicID != wantNicID { - t.Errorf("got CheckLocalAddress(0, %d, %s) = %d, want = %d", fakeNetNumber, addr, gotNicID, wantNicID) - } - addrBytes[0]++ - } - - // Trying the next address should always fail since it is outside the range. - if gotNicID := s.CheckLocalAddress(0, fakeNetNumber, tcpip.Address(addrBytes)); gotNicID != 0 { - t.Errorf("got CheckLocalAddress(0, %d, %s) = %d, want = %d", fakeNetNumber, tcpip.Address(addrBytes), gotNicID, 0) - } -} - -// Set a range of addresses, then remove it again, and check at each step that -// CheckLocalAddress returns the correct NIC for each address or zero if not -// existent. -func TestCheckLocalAddressForSubnet(t *testing.T) { - const nicID tcpip.NICID = 1 - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, - }) - - ep := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(nicID, ep); err != nil { - t.Fatal("CreateNIC failed:", err) - } - - { - subnet, err := tcpip.NewSubnet("\x00", "\x00") - if err != nil { - t.Fatal(err) - } - s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: nicID}}) - } - - subnet, err := tcpip.NewSubnet(tcpip.Address("\xa0"), tcpip.AddressMask("\xf0")) - if err != nil { - t.Fatal("NewSubnet failed:", err) - } - - testNicForAddressRange(t, nicID, s, subnet, false /* rangeExists */) - - if err := s.AddAddressRange(nicID, fakeNetNumber, subnet); err != nil { - t.Fatal("AddAddressRange failed:", err) - } - - testNicForAddressRange(t, nicID, s, subnet, true /* rangeExists */) - - if err := s.RemoveAddressRange(nicID, subnet); err != nil { - t.Fatal("RemoveAddressRange failed:", err) - } - - testNicForAddressRange(t, nicID, s, subnet, false /* rangeExists */) -} - -// Set a range of addresses, then send a packet to a destination outside the -// range and then check it doesn't get delivered. -func TestAddressRangeRejectsNonmatchingPacket(t *testing.T) { +func TestNetworkOption(t *testing.T) { s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, + TransportProtocols: []stack.TransportProtocolFactory{}, }) - ep := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(1, ep); err != nil { - t.Fatal("CreateNIC failed:", err) + opt := tcpip.DefaultTTLOption(5) + if err := s.SetNetworkProtocolOption(fakeNetNumber, &opt); err != nil { + t.Fatalf("s.SetNetworkProtocolOption(%d, &%T(%d)): %s", fakeNetNumber, opt, opt, err) } - { - subnet, err := tcpip.NewSubnet("\x00", "\x00") - if err != nil { - t.Fatal(err) - } - s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: 1}}) + var optGot tcpip.DefaultTTLOption + if err := s.NetworkProtocolOption(fakeNetNumber, &optGot); err != nil { + t.Fatalf("s.NetworkProtocolOption(%d, &%T): %s", fakeNetNumber, optGot, err) } - fakeNet := s.NetworkProtocolInstance(fakeNetNumber).(*fakeNetworkProtocol) - - buf := buffer.NewView(30) - - const localAddrByte byte = 0x01 - buf[dstAddrOffset] = localAddrByte - subnet, err := tcpip.NewSubnet(tcpip.Address("\x10"), tcpip.AddressMask("\xF0")) - if err != nil { - t.Fatal("NewSubnet failed:", err) - } - if err := s.AddAddressRange(1, fakeNetNumber, subnet); err != nil { - t.Fatal("AddAddressRange failed:", err) - } - testFailingRecv(t, fakeNet, localAddrByte, ep, buf) -} - -func TestNetworkOptions(t *testing.T) { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, - TransportProtocols: []stack.TransportProtocol{}, - }) - - // Try an unsupported network protocol. - if err := s.SetNetworkProtocolOption(tcpip.NetworkProtocolNumber(99999), fakeNetGoodOption(false)); err != tcpip.ErrUnknownProtocol { - t.Fatalf("SetNetworkProtocolOption(fakeNet2, blah, false) = %v, want = tcpip.ErrUnknownProtocol", err) - } - - testCases := []struct { - option interface{} - wantErr *tcpip.Error - verifier func(t *testing.T, p stack.NetworkProtocol) - }{ - {fakeNetGoodOption(true), nil, func(t *testing.T, p stack.NetworkProtocol) { - t.Helper() - fakeNet := p.(*fakeNetworkProtocol) - if fakeNet.opts.good != true { - t.Fatalf("fakeNet.opts.good = false, want = true") - } - var v fakeNetGoodOption - if err := s.NetworkProtocolOption(fakeNetNumber, &v); err != nil { - t.Fatalf("s.NetworkProtocolOption(fakeNetNumber, &v) = %v, want = nil, where v is option %T", v, err) - } - if v != true { - t.Fatalf("s.NetworkProtocolOption(fakeNetNumber, &v) returned v = %v, want = true", v) - } - }}, - {fakeNetBadOption(true), tcpip.ErrUnknownProtocolOption, nil}, - {fakeNetInvalidValueOption(1), tcpip.ErrInvalidOptionValue, nil}, - } - for _, tc := range testCases { - if got := s.SetNetworkProtocolOption(fakeNetNumber, tc.option); got != tc.wantErr { - t.Errorf("s.SetNetworkProtocolOption(fakeNet, %v) = %v, want = %v", tc.option, got, tc.wantErr) - } - if tc.verifier != nil { - tc.verifier(t, s.NetworkProtocolInstance(fakeNetNumber)) - } - } -} - -func stackContainsAddressRange(s *stack.Stack, id tcpip.NICID, addrRange tcpip.Subnet) bool { - ranges, ok := s.NICAddressRanges()[id] - if !ok { - return false - } - for _, r := range ranges { - if r == addrRange { - return true - } - } - return false -} - -func TestAddresRangeAddRemove(t *testing.T) { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, - }) - ep := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(1, ep); err != nil { - t.Fatal("CreateNIC failed:", err) - } - - addr := tcpip.Address("\x01\x01\x01\x01") - mask := tcpip.AddressMask(strings.Repeat("\xff", len(addr))) - addrRange, err := tcpip.NewSubnet(addr, mask) - if err != nil { - t.Fatal("NewSubnet failed:", err) - } - - if got, want := stackContainsAddressRange(s, 1, addrRange), false; got != want { - t.Fatalf("got stackContainsAddressRange(...) = %t, want = %t", got, want) - } - - if err := s.AddAddressRange(1, fakeNetNumber, addrRange); err != nil { - t.Fatal("AddAddressRange failed:", err) - } - - if got, want := stackContainsAddressRange(s, 1, addrRange), true; got != want { - t.Fatalf("got stackContainsAddressRange(...) = %t, want = %t", got, want) - } - - if err := s.RemoveAddressRange(1, addrRange); err != nil { - t.Fatal("RemoveAddressRange failed:", err) - } - - if got, want := stackContainsAddressRange(s, 1, addrRange), false; got != want { - t.Fatalf("got stackContainsAddressRange(...) = %t, want = %t", got, want) + if opt != optGot { + t.Errorf("got optGot = %d, want = %d", optGot, opt) } } @@ -1887,7 +1700,7 @@ func TestGetMainNICAddressAddPrimaryNonPrimary(t *testing.T) { for never := 0; never < 3; never++ { t.Run(fmt.Sprintf("never=%d", never), func(t *testing.T) { s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, }) ep := channel.New(10, defaultMTU, "") if err := s.CreateNIC(1, ep); err != nil { @@ -1954,7 +1767,7 @@ func TestGetMainNICAddressAddPrimaryNonPrimary(t *testing.T) { func TestGetMainNICAddressAddRemove(t *testing.T) { s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, }) ep := channel.New(10, defaultMTU, "") if err := s.CreateNIC(1, ep); err != nil { @@ -2039,7 +1852,7 @@ func verifyAddresses(t *testing.T, expectedAddresses, gotAddresses []tcpip.Proto func TestAddAddress(t *testing.T) { const nicID = 1 s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, }) ep := channel.New(10, defaultMTU, "") if err := s.CreateNIC(nicID, ep); err != nil { @@ -2066,7 +1879,7 @@ func TestAddAddress(t *testing.T) { func TestAddProtocolAddress(t *testing.T) { const nicID = 1 s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, }) ep := channel.New(10, defaultMTU, "") if err := s.CreateNIC(nicID, ep); err != nil { @@ -2100,7 +1913,7 @@ func TestAddProtocolAddress(t *testing.T) { func TestAddAddressWithOptions(t *testing.T) { const nicID = 1 s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, }) ep := channel.New(10, defaultMTU, "") if err := s.CreateNIC(nicID, ep); err != nil { @@ -2131,7 +1944,7 @@ func TestAddAddressWithOptions(t *testing.T) { func TestAddProtocolAddressWithOptions(t *testing.T) { const nicID = 1 s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, }) ep := channel.New(10, defaultMTU, "") if err := s.CreateNIC(nicID, ep); err != nil { @@ -2252,7 +2065,7 @@ func TestCreateNICWithOptions(t *testing.T) { func TestNICStats(t *testing.T) { s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, }) ep1 := channel.New(10, defaultMTU, "") if err := s.CreateNIC(1, ep1); err != nil { @@ -2272,9 +2085,9 @@ func TestNICStats(t *testing.T) { // Send a packet to address 1. buf := buffer.NewView(30) - ep1.InjectInbound(fakeNetNumber, &stack.PacketBuffer{ + ep1.InjectInbound(fakeNetNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: buf.ToVectorisedView(), - }) + })) if got, want := s.NICInfo()[1].Stats.Rx.Packets.Value(), uint64(1); got != want { t.Errorf("got Rx.Packets.Value() = %d, want = %d", got, want) } @@ -2319,9 +2132,9 @@ func TestNICForwarding(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, }) - s.SetForwarding(true) + s.SetForwarding(fakeNetNumber, true) ep1 := channel.New(10, defaultMTU, "") if err := s.CreateNIC(nicID1, ep1); err != nil { @@ -2354,9 +2167,9 @@ func TestNICForwarding(t *testing.T) { // Send a packet to dstAddr. buf := buffer.NewView(30) buf[dstAddrOffset] = dstAddr[0] - ep1.InjectInbound(fakeNetNumber, &stack.PacketBuffer{ + ep1.InjectInbound(fakeNetNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: buf.ToVectorisedView(), - }) + })) pkt, ok := ep2.Read() if !ok { @@ -2364,8 +2177,8 @@ func TestNICForwarding(t *testing.T) { } // Test that the link's MaxHeaderLength is honoured. - if capacity, want := pkt.Pkt.Header.AvailableLength(), int(test.headerLen); capacity != want { - t.Errorf("got Header.AvailableLength() = %d, want = %d", capacity, want) + if capacity, want := pkt.Pkt.AvailableHeaderBytes(), int(test.headerLen); capacity != want { + t.Errorf("got LinkHeader.AvailableLength() = %d, want = %d", capacity, want) } // Test that forwarding increments Tx stats correctly. @@ -2443,7 +2256,7 @@ func TestNICAutoGenLinkLocalAddr(t *testing.T) { nicName string autoGen bool linkAddr tcpip.LinkAddress - iidOpts stack.OpaqueInterfaceIdentifierOptions + iidOpts ipv6.OpaqueInterfaceIdentifierOptions shouldGen bool expectedAddr tcpip.Address }{ @@ -2459,7 +2272,7 @@ func TestNICAutoGenLinkLocalAddr(t *testing.T) { nicName: "nic1", autoGen: false, linkAddr: linkAddr1, - iidOpts: stack.OpaqueInterfaceIdentifierOptions{ + iidOpts: ipv6.OpaqueInterfaceIdentifierOptions{ NICNameFromID: nicNameFunc, SecretKey: secretKey[:], }, @@ -2504,7 +2317,7 @@ func TestNICAutoGenLinkLocalAddr(t *testing.T) { nicName: "nic1", autoGen: true, linkAddr: linkAddr1, - iidOpts: stack.OpaqueInterfaceIdentifierOptions{ + iidOpts: ipv6.OpaqueInterfaceIdentifierOptions{ NICNameFromID: nicNameFunc, SecretKey: secretKey[:], }, @@ -2516,7 +2329,7 @@ func TestNICAutoGenLinkLocalAddr(t *testing.T) { { name: "OIID Empty MAC and empty nicName", autoGen: true, - iidOpts: stack.OpaqueInterfaceIdentifierOptions{ + iidOpts: ipv6.OpaqueInterfaceIdentifierOptions{ NICNameFromID: nicNameFunc, SecretKey: secretKey[:1], }, @@ -2528,7 +2341,7 @@ func TestNICAutoGenLinkLocalAddr(t *testing.T) { nicName: "test", autoGen: true, linkAddr: "\x01\x02\x03", - iidOpts: stack.OpaqueInterfaceIdentifierOptions{ + iidOpts: ipv6.OpaqueInterfaceIdentifierOptions{ NICNameFromID: nicNameFunc, SecretKey: secretKey[:2], }, @@ -2540,7 +2353,7 @@ func TestNICAutoGenLinkLocalAddr(t *testing.T) { nicName: "test2", autoGen: true, linkAddr: "\x01\x02\x03\x04\x05\x06", - iidOpts: stack.OpaqueInterfaceIdentifierOptions{ + iidOpts: ipv6.OpaqueInterfaceIdentifierOptions{ NICNameFromID: nicNameFunc, SecretKey: secretKey[:3], }, @@ -2552,7 +2365,7 @@ func TestNICAutoGenLinkLocalAddr(t *testing.T) { nicName: "test3", autoGen: true, linkAddr: "\x00\x00\x00\x00\x00\x00", - iidOpts: stack.OpaqueInterfaceIdentifierOptions{ + iidOpts: ipv6.OpaqueInterfaceIdentifierOptions{ NICNameFromID: nicNameFunc, }, shouldGen: true, @@ -2566,10 +2379,11 @@ func TestNICAutoGenLinkLocalAddr(t *testing.T) { autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1), } opts := stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, - AutoGenIPv6LinkLocal: test.autoGen, - NDPDisp: &ndpDisp, - OpaqueIIDOpts: test.iidOpts, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ + AutoGenIPv6LinkLocal: test.autoGen, + NDPDisp: &ndpDisp, + OpaqueIIDOpts: test.iidOpts, + })}, } e := channel.New(0, 1280, test.linkAddr) @@ -2641,15 +2455,15 @@ func TestNoLinkLocalAutoGenForLoopbackNIC(t *testing.T) { tests := []struct { name string - opaqueIIDOpts stack.OpaqueInterfaceIdentifierOptions + opaqueIIDOpts ipv6.OpaqueInterfaceIdentifierOptions }{ { name: "IID From MAC", - opaqueIIDOpts: stack.OpaqueInterfaceIdentifierOptions{}, + opaqueIIDOpts: ipv6.OpaqueInterfaceIdentifierOptions{}, }, { name: "Opaque IID", - opaqueIIDOpts: stack.OpaqueInterfaceIdentifierOptions{ + opaqueIIDOpts: ipv6.OpaqueInterfaceIdentifierOptions{ NICNameFromID: func(_ tcpip.NICID, nicName string) string { return nicName }, @@ -2660,9 +2474,10 @@ func TestNoLinkLocalAutoGenForLoopbackNIC(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { opts := stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, - AutoGenIPv6LinkLocal: true, - OpaqueIIDOpts: test.opaqueIIDOpts, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ + AutoGenIPv6LinkLocal: true, + OpaqueIIDOpts: test.opaqueIIDOpts, + })}, } e := loopback.New() @@ -2691,12 +2506,13 @@ func TestNICAutoGenAddrDoesDAD(t *testing.T) { ndpDisp := ndpDispatcher{ dadC: make(chan ndpDADEvent), } - ndpConfigs := stack.DefaultNDPConfigurations() + ndpConfigs := ipv6.DefaultNDPConfigurations() opts := stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, - NDPConfigs: ndpConfigs, - AutoGenIPv6LinkLocal: true, - NDPDisp: &ndpDisp, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ + NDPConfigs: ndpConfigs, + AutoGenIPv6LinkLocal: true, + NDPDisp: &ndpDisp, + })}, } e := channel.New(int(ndpConfigs.DupAddrDetectTransmits), 1280, linkAddr1) @@ -2752,7 +2568,7 @@ func TestNewPEBOnPromotionToPermanent(t *testing.T) { for _, ps := range pebs { t.Run(fmt.Sprintf("%d-to-%d", pi, ps), func(t *testing.T) { s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, }) ep1 := channel.New(10, defaultMTU, "") if err := s.CreateNIC(1, ep1); err != nil { @@ -3043,14 +2859,15 @@ func TestIPv6SourceAddressSelectionScopeAndSameAddress(t *testing.T) { t.Run(test.name, func(t *testing.T) { e := channel.New(0, 1280, linkAddr1) s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, - TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()}, - NDPConfigs: stack.NDPConfigurations{ - HandleRAs: true, - AutoGenGlobalAddresses: true, - AutoGenTempGlobalAddresses: true, - }, - NDPDisp: &ndpDispatcher{}, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ + NDPConfigs: ipv6.NDPConfigurations{ + HandleRAs: true, + AutoGenGlobalAddresses: true, + AutoGenTempGlobalAddresses: true, + }, + NDPDisp: &ndpDispatcher{}, + })}, + TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}, }) if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) @@ -3089,59 +2906,58 @@ func TestIPv6SourceAddressSelectionScopeAndSameAddress(t *testing.T) { func TestAddRemoveIPv4BroadcastAddressOnNICEnableDisable(t *testing.T) { const nicID = 1 + broadcastAddr := tcpip.ProtocolAddress{ + Protocol: header.IPv4ProtocolNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: header.IPv4Broadcast, + PrefixLen: 32, + }, + } e := loopback.New() s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol()}, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol}, }) nicOpts := stack.NICOptions{Disabled: true} if err := s.CreateNICWithOptions(nicID, e, nicOpts); err != nil { t.Fatalf("CreateNIC(%d, _, %+v) = %s", nicID, nicOpts, err) } - allStackAddrs := s.AllAddresses() - allNICAddrs, ok := allStackAddrs[nicID] - if !ok { - t.Fatalf("entry for %d missing from allStackAddrs = %+v", nicID, allStackAddrs) - } - if l := len(allNICAddrs); l != 0 { - t.Fatalf("got len(allNICAddrs) = %d, want = 0", l) + { + allStackAddrs := s.AllAddresses() + if allNICAddrs, ok := allStackAddrs[nicID]; !ok { + t.Fatalf("entry for %d missing from allStackAddrs = %+v", nicID, allStackAddrs) + } else if containsAddr(allNICAddrs, broadcastAddr) { + t.Fatalf("got allNICAddrs = %+v, don't want = %+v", allNICAddrs, broadcastAddr) + } } // Enabling the NIC should add the IPv4 broadcast address. if err := s.EnableNIC(nicID); err != nil { t.Fatalf("s.EnableNIC(%d): %s", nicID, err) } - allStackAddrs = s.AllAddresses() - allNICAddrs, ok = allStackAddrs[nicID] - if !ok { - t.Fatalf("entry for %d missing from allStackAddrs = %+v", nicID, allStackAddrs) - } - if l := len(allNICAddrs); l != 1 { - t.Fatalf("got len(allNICAddrs) = %d, want = 1", l) - } - want := tcpip.ProtocolAddress{ - Protocol: header.IPv4ProtocolNumber, - AddressWithPrefix: tcpip.AddressWithPrefix{ - Address: header.IPv4Broadcast, - PrefixLen: 32, - }, - } - if allNICAddrs[0] != want { - t.Fatalf("got allNICAddrs[0] = %+v, want = %+v", allNICAddrs[0], want) + + { + allStackAddrs := s.AllAddresses() + if allNICAddrs, ok := allStackAddrs[nicID]; !ok { + t.Fatalf("entry for %d missing from allStackAddrs = %+v", nicID, allStackAddrs) + } else if !containsAddr(allNICAddrs, broadcastAddr) { + t.Fatalf("got allNICAddrs = %+v, want = %+v", allNICAddrs, broadcastAddr) + } } // Disabling the NIC should remove the IPv4 broadcast address. if err := s.DisableNIC(nicID); err != nil { t.Fatalf("s.DisableNIC(%d): %s", nicID, err) } - allStackAddrs = s.AllAddresses() - allNICAddrs, ok = allStackAddrs[nicID] - if !ok { - t.Fatalf("entry for %d missing from allStackAddrs = %+v", nicID, allStackAddrs) - } - if l := len(allNICAddrs); l != 0 { - t.Fatalf("got len(allNICAddrs) = %d, want = 0", l) + + { + allStackAddrs := s.AllAddresses() + if allNICAddrs, ok := allStackAddrs[nicID]; !ok { + t.Fatalf("entry for %d missing from allStackAddrs = %+v", nicID, allStackAddrs) + } else if containsAddr(allNICAddrs, broadcastAddr) { + t.Fatalf("got allNICAddrs = %+v, don't want = %+v", allNICAddrs, broadcastAddr) + } } } @@ -3152,7 +2968,7 @@ func TestLeaveIPv6SolicitedNodeAddrBeforeAddrRemoval(t *testing.T) { const nicID = 1 s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol}, }) e := channel.New(10, 1280, linkAddr1) if err := s.CreateNIC(1, e); err != nil { @@ -3189,50 +3005,93 @@ func TestLeaveIPv6SolicitedNodeAddrBeforeAddrRemoval(t *testing.T) { } } -func TestJoinLeaveAllNodesMulticastOnNICEnableDisable(t *testing.T) { +func TestJoinLeaveMulticastOnNICEnableDisable(t *testing.T) { const nicID = 1 - e := loopback.New() - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, - }) - nicOpts := stack.NICOptions{Disabled: true} - if err := s.CreateNICWithOptions(nicID, e, nicOpts); err != nil { - t.Fatalf("CreateNIC(%d, _, %+v) = %s", nicID, nicOpts, err) + tests := []struct { + name string + proto tcpip.NetworkProtocolNumber + addr tcpip.Address + }{ + { + name: "IPv6 All-Nodes", + proto: header.IPv6ProtocolNumber, + addr: header.IPv6AllNodesMulticastAddress, + }, + { + name: "IPv4 All-Systems", + proto: header.IPv4ProtocolNumber, + addr: header.IPv4AllSystems, + }, } - // Should not be in the IPv6 all-nodes multicast group yet because the NIC has - // not been enabled yet. - isInGroup, err := s.IsInGroup(nicID, header.IPv6AllNodesMulticastAddress) - if err != nil { - t.Fatalf("IsInGroup(%d, %s): %s", nicID, header.IPv6AllNodesMulticastAddress, err) - } - if isInGroup { - t.Fatalf("got IsInGroup(%d, %s) = true, want = false", nicID, header.IPv6AllNodesMulticastAddress) - } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + e := loopback.New() + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol}, + }) + nicOpts := stack.NICOptions{Disabled: true} + if err := s.CreateNICWithOptions(nicID, e, nicOpts); err != nil { + t.Fatalf("CreateNIC(%d, _, %+v) = %s", nicID, nicOpts, err) + } - // The all-nodes multicast group should be joined when the NIC is enabled. - if err := s.EnableNIC(nicID); err != nil { - t.Fatalf("s.EnableNIC(%d): %s", nicID, err) - } - isInGroup, err = s.IsInGroup(nicID, header.IPv6AllNodesMulticastAddress) - if err != nil { - t.Fatalf("IsInGroup(%d, %s): %s", nicID, header.IPv6AllNodesMulticastAddress, err) - } - if !isInGroup { - t.Fatalf("got IsInGroup(%d, %s) = false, want = true", nicID, header.IPv6AllNodesMulticastAddress) - } + // Should not be in the multicast group yet because the NIC has not been + // enabled yet. + if isInGroup, err := s.IsInGroup(nicID, test.addr); err != nil { + t.Fatalf("IsInGroup(%d, %s): %s", nicID, test.addr, err) + } else if isInGroup { + t.Fatalf("got IsInGroup(%d, %s) = true, want = false", nicID, test.addr) + } - // The all-nodes multicast group should be left when the NIC is disabled. - if err := s.DisableNIC(nicID); err != nil { - t.Fatalf("s.DisableNIC(%d): %s", nicID, err) - } - isInGroup, err = s.IsInGroup(nicID, header.IPv6AllNodesMulticastAddress) - if err != nil { - t.Fatalf("IsInGroup(%d, %s): %s", nicID, header.IPv6AllNodesMulticastAddress, err) - } - if isInGroup { - t.Fatalf("got IsInGroup(%d, %s) = true, want = false", nicID, header.IPv6AllNodesMulticastAddress) + // The all-nodes multicast group should be joined when the NIC is enabled. + if err := s.EnableNIC(nicID); err != nil { + t.Fatalf("s.EnableNIC(%d): %s", nicID, err) + } + + if isInGroup, err := s.IsInGroup(nicID, test.addr); err != nil { + t.Fatalf("IsInGroup(%d, %s): %s", nicID, test.addr, err) + } else if !isInGroup { + t.Fatalf("got IsInGroup(%d, %s) = false, want = true", nicID, test.addr) + } + + // The multicast group should be left when the NIC is disabled. + if err := s.DisableNIC(nicID); err != nil { + t.Fatalf("s.DisableNIC(%d): %s", nicID, err) + } + + if isInGroup, err := s.IsInGroup(nicID, test.addr); err != nil { + t.Fatalf("IsInGroup(%d, %s): %s", nicID, test.addr, err) + } else if isInGroup { + t.Fatalf("got IsInGroup(%d, %s) = true, want = false", nicID, test.addr) + } + + // The all-nodes multicast group should be joined when the NIC is enabled. + if err := s.EnableNIC(nicID); err != nil { + t.Fatalf("s.EnableNIC(%d): %s", nicID, err) + } + + if isInGroup, err := s.IsInGroup(nicID, test.addr); err != nil { + t.Fatalf("IsInGroup(%d, %s): %s", nicID, test.addr, err) + } else if !isInGroup { + t.Fatalf("got IsInGroup(%d, %s) = false, want = true", nicID, test.addr) + } + + // Leaving the group before disabling the NIC should not cause an error. + if err := s.LeaveGroup(test.proto, nicID, test.addr); err != nil { + t.Fatalf("s.LeaveGroup(%d, %d, %s): %s", test.proto, nicID, test.addr, err) + } + + if err := s.DisableNIC(nicID); err != nil { + t.Fatalf("s.DisableNIC(%d): %s", nicID, err) + } + + if isInGroup, err := s.IsInGroup(nicID, test.addr); err != nil { + t.Fatalf("IsInGroup(%d, %s): %s", nicID, test.addr, err) + } else if isInGroup { + t.Fatalf("got IsInGroup(%d, %s) = true, want = false", nicID, test.addr) + } + }) } } @@ -3247,12 +3106,13 @@ func TestDoDADWhenNICEnabled(t *testing.T) { dadC: make(chan ndpDADEvent), } opts := stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, - NDPConfigs: stack.NDPConfigurations{ - DupAddrDetectTransmits: dadTransmits, - RetransmitTimer: retransmitTimer, - }, - NDPDisp: &ndpDisp, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ + NDPConfigs: ipv6.NDPConfigurations{ + DupAddrDetectTransmits: dadTransmits, + RetransmitTimer: retransmitTimer, + }, + NDPDisp: &ndpDisp, + })}, } e := channel.New(dadTransmits, 1280, linkAddr1) @@ -3611,7 +3471,7 @@ func TestOutgoingSubnetBroadcast(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol(), ipv6.NewProtocol()}, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol}, }) ep := channel.New(0, defaultMTU, "") if err := s.CreateNIC(nicID1, ep); err != nil { @@ -3641,3 +3501,177 @@ func TestOutgoingSubnetBroadcast(t *testing.T) { }) } } + +func TestResolveWith(t *testing.T) { + const ( + unspecifiedNICID = 0 + nicID = 1 + ) + + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, arp.NewProtocol}, + }) + ep := channel.New(0, defaultMTU, "") + ep.LinkEPCapabilities |= stack.CapabilityResolutionRequired + if err := s.CreateNIC(nicID, ep); err != nil { + t.Fatalf("CreateNIC(%d, _): %s", nicID, err) + } + addr := tcpip.ProtocolAddress{ + Protocol: header.IPv4ProtocolNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: tcpip.Address(net.ParseIP("192.168.1.58").To4()), + PrefixLen: 24, + }, + } + if err := s.AddProtocolAddress(nicID, addr); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v): %s", nicID, addr, err) + } + + s.SetRouteTable([]tcpip.Route{{Destination: header.IPv4EmptySubnet, NIC: nicID}}) + + remoteAddr := tcpip.Address(net.ParseIP("192.168.1.59").To4()) + r, err := s.FindRoute(unspecifiedNICID, "" /* localAddr */, remoteAddr, header.IPv4ProtocolNumber, false /* multicastLoop */) + if err != nil { + t.Fatalf("FindRoute(%d, '', %s, %d): %s", unspecifiedNICID, remoteAddr, header.IPv4ProtocolNumber, err) + } + defer r.Release() + + // Should initially require resolution. + if !r.IsResolutionRequired() { + t.Fatal("got r.IsResolutionRequired() = false, want = true") + } + + // Manually resolving the route should no longer require resolution. + r.ResolveWith("\x01") + if r.IsResolutionRequired() { + t.Fatal("got r.IsResolutionRequired() = true, want = false") + } +} + +// TestRouteReleaseAfterAddrRemoval tests that releasing a Route after its +// associated address is removed should not cause a panic. +func TestRouteReleaseAfterAddrRemoval(t *testing.T) { + const ( + nicID = 1 + localAddr = tcpip.Address("\x01") + remoteAddr = tcpip.Address("\x02") + ) + + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, + }) + + ep := channel.New(0, defaultMTU, "") + if err := s.CreateNIC(nicID, ep); err != nil { + t.Fatalf("CreateNIC(%d, _): %s", nicID, err) + } + if err := s.AddAddress(nicID, fakeNetNumber, localAddr); err != nil { + t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID, fakeNetNumber, localAddr, err) + } + { + subnet, err := tcpip.NewSubnet("\x00", "\x00") + if err != nil { + t.Fatal(err) + } + s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: 1}}) + } + + r, err := s.FindRoute(nicID, localAddr, remoteAddr, fakeNetNumber, false /* multicastLoop */) + if err != nil { + t.Fatalf("s.FindRoute(%d, %s, %s, %d, false): %s", nicID, localAddr, remoteAddr, fakeNetNumber, err) + } + // Should not panic. + defer r.Release() + + // Check that removing the same address fails. + if err := s.RemoveAddress(nicID, localAddr); err != nil { + t.Fatalf("s.RemoveAddress(%d, %s): %s", nicID, localAddr, err) + } +} + +func TestGetNetworkEndpoint(t *testing.T) { + const nicID = 1 + + tests := []struct { + name string + protoFactory stack.NetworkProtocolFactory + protoNum tcpip.NetworkProtocolNumber + }{ + { + name: "IPv4", + protoFactory: ipv4.NewProtocol, + protoNum: ipv4.ProtocolNumber, + }, + { + name: "IPv6", + protoFactory: ipv6.NewProtocol, + protoNum: ipv6.ProtocolNumber, + }, + } + + factories := make([]stack.NetworkProtocolFactory, 0, len(tests)) + for _, test := range tests { + factories = append(factories, test.protoFactory) + } + + s := stack.New(stack.Options{ + NetworkProtocols: factories, + }) + + if err := s.CreateNIC(nicID, channel.New(0, defaultMTU, "")); err != nil { + t.Fatalf("CreateNIC(%d, _): %s", nicID, err) + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + ep, err := s.GetNetworkEndpoint(nicID, test.protoNum) + if err != nil { + t.Fatalf("s.GetNetworkEndpoint(%d, %d): %s", nicID, test.protoNum, err) + } + + if got := ep.NetworkProtocolNumber(); got != test.protoNum { + t.Fatalf("got ep.NetworkProtocolNumber() = %d, want = %d", got, test.protoNum) + } + }) + } +} + +func TestGetMainNICAddressWhenNICDisabled(t *testing.T) { + const nicID = 1 + + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, + }) + + if err := s.CreateNIC(nicID, channel.New(0, defaultMTU, "")); err != nil { + t.Fatalf("CreateNIC(%d, _): %s", nicID, err) + } + + protocolAddress := tcpip.ProtocolAddress{ + Protocol: fakeNetNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: "\x01", + PrefixLen: 8, + }, + } + if err := s.AddProtocolAddress(nicID, protocolAddress); err != nil { + t.Fatalf("AddProtocolAddress(%d, %#v): %s", nicID, protocolAddress, err) + } + + // Check that we get the right initial address and prefix length. + if gotAddr, err := s.GetMainNICAddress(nicID, fakeNetNumber); err != nil { + t.Fatalf("GetMainNICAddress(%d, %d): %s", nicID, fakeNetNumber, err) + } else if gotAddr != protocolAddress.AddressWithPrefix { + t.Fatalf("got GetMainNICAddress(%d, %d) = %s, want = %s", nicID, fakeNetNumber, gotAddr, protocolAddress.AddressWithPrefix) + } + + // Should still get the address when the NIC is diabled. + if err := s.DisableNIC(nicID); err != nil { + t.Fatalf("DisableNIC(%d): %s", nicID, err) + } + if gotAddr, err := s.GetMainNICAddress(nicID, fakeNetNumber); err != nil { + t.Fatalf("GetMainNICAddress(%d, %d): %s", nicID, fakeNetNumber, err) + } else if gotAddr != protocolAddress.AddressWithPrefix { + t.Fatalf("got GetMainNICAddress(%d, %d) = %s, want = %s", nicID, fakeNetNumber, gotAddr, protocolAddress.AddressWithPrefix) + } +} diff --git a/pkg/tcpip/stack/transport_demuxer.go b/pkg/tcpip/stack/transport_demuxer.go index b902c6ca9..35e5b1a2e 100644 --- a/pkg/tcpip/stack/transport_demuxer.go +++ b/pkg/tcpip/stack/transport_demuxer.go @@ -155,7 +155,7 @@ func (epsByNIC *endpointsByNIC) transportEndpoints() []TransportEndpoint { func (epsByNIC *endpointsByNIC) handlePacket(r *Route, id TransportEndpointID, pkt *PacketBuffer) { epsByNIC.mu.RLock() - mpep, ok := epsByNIC.endpoints[r.ref.nic.ID()] + mpep, ok := epsByNIC.endpoints[r.nic.ID()] if !ok { if mpep, ok = epsByNIC.endpoints[0]; !ok { epsByNIC.mu.RUnlock() // Don't use defer for performance reasons. @@ -165,7 +165,7 @@ func (epsByNIC *endpointsByNIC) handlePacket(r *Route, id TransportEndpointID, p // If this is a broadcast or multicast datagram, deliver the datagram to all // endpoints bound to the right device. - if isMulticastOrBroadcast(id.LocalAddress) { + if isInboundMulticastOrBroadcast(r) { mpep.handlePacketAll(r, id, pkt) epsByNIC.mu.RUnlock() // Don't use defer for performance reasons. return @@ -526,7 +526,7 @@ func (d *transportDemuxer) deliverPacket(r *Route, protocol tcpip.TransportProto // If the packet is a UDP broadcast or multicast, then find all matching // transport endpoints. - if protocol == header.UDPProtocolNumber && isMulticastOrBroadcast(id.LocalAddress) { + if protocol == header.UDPProtocolNumber && isInboundMulticastOrBroadcast(r) { eps.mu.RLock() destEPs := eps.findAllEndpointsLocked(id) eps.mu.RUnlock() @@ -544,9 +544,11 @@ func (d *transportDemuxer) deliverPacket(r *Route, protocol tcpip.TransportProto return true } - // If the packet is a TCP packet with a non-unicast source or destination - // address, then do nothing further and instruct the caller to do the same. - if protocol == header.TCPProtocolNumber && (!isUnicast(r.LocalAddress) || !isUnicast(r.RemoteAddress)) { + // If the packet is a TCP packet with a unspecified source or non-unicast + // destination address, then do nothing further and instruct the caller to do + // the same. The network layer handles address validation for specified source + // addresses. + if protocol == header.TCPProtocolNumber && (!isSpecified(r.LocalAddress) || !isSpecified(r.RemoteAddress) || isInboundMulticastOrBroadcast(r)) { // TCP can only be used to communicate between a single source and a // single destination; the addresses must be unicast. r.Stats().TCP.InvalidSegmentsReceived.Increment() @@ -626,7 +628,7 @@ func (d *transportDemuxer) findTransportEndpoint(netProto tcpip.NetworkProtocolN epsByNIC.mu.RLock() eps.mu.RUnlock() - mpep, ok := epsByNIC.endpoints[r.ref.nic.ID()] + mpep, ok := epsByNIC.endpoints[r.nic.ID()] if !ok { if mpep, ok = epsByNIC.endpoints[0]; !ok { epsByNIC.mu.RUnlock() // Don't use defer for performance reasons. @@ -677,10 +679,10 @@ func (d *transportDemuxer) unregisterRawEndpoint(netProto tcpip.NetworkProtocolN eps.mu.Unlock() } -func isMulticastOrBroadcast(addr tcpip.Address) bool { - return addr == header.IPv4Broadcast || header.IsV4MulticastAddress(addr) || header.IsV6MulticastAddress(addr) +func isInboundMulticastOrBroadcast(r *Route) bool { + return r.IsInboundBroadcast() || header.IsV4MulticastAddress(r.LocalAddress) || header.IsV6MulticastAddress(r.LocalAddress) } -func isUnicast(addr tcpip.Address) bool { - return addr != header.IPv4Any && addr != header.IPv6Any && !isMulticastOrBroadcast(addr) +func isSpecified(addr tcpip.Address) bool { + return addr != header.IPv4Any && addr != header.IPv6Any } diff --git a/pkg/tcpip/stack/transport_demuxer_test.go b/pkg/tcpip/stack/transport_demuxer_test.go index 73dada928..698c8609e 100644 --- a/pkg/tcpip/stack/transport_demuxer_test.go +++ b/pkg/tcpip/stack/transport_demuxer_test.go @@ -51,8 +51,8 @@ type testContext struct { // newDualTestContextMultiNIC creates the testing context and also linkEpIDs NICs. func newDualTestContextMultiNIC(t *testing.T, mtu uint32, linkEpIDs []tcpip.NICID) *testContext { s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol(), ipv6.NewProtocol()}, - TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()}, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}, }) linkEps := make(map[tcpip.NICID]*channel.Endpoint) for _, linkEpID := range linkEpIDs { @@ -128,11 +128,10 @@ func (c *testContext) sendV4Packet(payload []byte, h *headers, linkEpID tcpip.NI u.SetChecksum(^u.CalculateChecksum(xsum)) // Inject packet. - c.linkEps[linkEpID].InjectInbound(ipv4.ProtocolNumber, &stack.PacketBuffer{ - Data: buf.ToVectorisedView(), - NetworkHeader: buffer.View(ip), - TransportHeader: buffer.View(u), + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: buf.ToVectorisedView(), }) + c.linkEps[linkEpID].InjectInbound(ipv4.ProtocolNumber, pkt) } func (c *testContext) sendV6Packet(payload []byte, h *headers, linkEpID tcpip.NICID) { @@ -166,11 +165,10 @@ func (c *testContext) sendV6Packet(payload []byte, h *headers, linkEpID tcpip.NI u.SetChecksum(^u.CalculateChecksum(xsum)) // Inject packet. - c.linkEps[linkEpID].InjectInbound(ipv6.ProtocolNumber, &stack.PacketBuffer{ - Data: buf.ToVectorisedView(), - NetworkHeader: buffer.View(ip), - TransportHeader: buffer.View(u), + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: buf.ToVectorisedView(), }) + c.linkEps[linkEpID].InjectInbound(ipv6.ProtocolNumber, pkt) } func TestTransportDemuxerRegister(t *testing.T) { @@ -184,8 +182,8 @@ func TestTransportDemuxerRegister(t *testing.T) { } { t.Run(test.name, func(t *testing.T) { s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol()}, - TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()}, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}, }) var wq waiter.Queue ep, err := s.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, &wq) @@ -314,8 +312,8 @@ func TestBindToDeviceDistribution(t *testing.T) { t.Fatalf("SetSockOptBool(ReusePortOption, %t) on endpoint %d failed: %s", endpoint.reuse, i, err) } bindToDeviceOption := tcpip.BindToDeviceOption(endpoint.bindToDevice) - if err := ep.SetSockOpt(bindToDeviceOption); err != nil { - t.Fatalf("SetSockOpt(%#v) on endpoint %d failed: %s", bindToDeviceOption, i, err) + if err := ep.SetSockOpt(&bindToDeviceOption); err != nil { + t.Fatalf("SetSockOpt(&%T(%d)) on endpoint %d failed: %s", bindToDeviceOption, bindToDeviceOption, i, err) } var dstAddr tcpip.Address diff --git a/pkg/tcpip/stack/transport_test.go b/pkg/tcpip/stack/transport_test.go index 7e8b84867..62ab6d92f 100644 --- a/pkg/tcpip/stack/transport_test.go +++ b/pkg/tcpip/stack/transport_test.go @@ -39,7 +39,7 @@ const ( // use it. type fakeTransportEndpoint struct { stack.TransportEndpointInfo - stack *stack.Stack + proto *fakeTransportProtocol peerAddr tcpip.Address route stack.Route @@ -53,14 +53,14 @@ func (f *fakeTransportEndpoint) Info() tcpip.EndpointInfo { return &f.TransportEndpointInfo } -func (f *fakeTransportEndpoint) Stats() tcpip.EndpointStats { +func (*fakeTransportEndpoint) Stats() tcpip.EndpointStats { return nil } -func (f *fakeTransportEndpoint) SetOwner(owner tcpip.PacketOwner) {} +func (*fakeTransportEndpoint) SetOwner(owner tcpip.PacketOwner) {} -func newFakeTransportEndpoint(s *stack.Stack, proto *fakeTransportProtocol, netProto tcpip.NetworkProtocolNumber, uniqueID uint64) tcpip.Endpoint { - return &fakeTransportEndpoint{stack: s, TransportEndpointInfo: stack.TransportEndpointInfo{NetProto: netProto}, proto: proto, uniqueID: uniqueID} +func newFakeTransportEndpoint(proto *fakeTransportProtocol, netProto tcpip.NetworkProtocolNumber, uniqueID uint64) tcpip.Endpoint { + return &fakeTransportEndpoint{TransportEndpointInfo: stack.TransportEndpointInfo{NetProto: netProto}, proto: proto, uniqueID: uniqueID} } func (f *fakeTransportEndpoint) Abort() { @@ -84,28 +84,28 @@ func (f *fakeTransportEndpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions return 0, nil, tcpip.ErrNoRoute } - hdr := buffer.NewPrependable(int(f.route.MaxHeaderLength()) + fakeTransHeaderLen) - hdr.Prepend(fakeTransHeaderLen) v, err := p.FullPayload() if err != nil { return 0, nil, err } - if err := f.route.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: fakeTransNumber, TTL: 123, TOS: stack.DefaultTOS}, &stack.PacketBuffer{ - Header: hdr, - Data: buffer.View(v).ToVectorisedView(), - }); err != nil { + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: int(f.route.MaxHeaderLength()) + fakeTransHeaderLen, + Data: buffer.View(v).ToVectorisedView(), + }) + _ = pkt.TransportHeader().Push(fakeTransHeaderLen) + if err := f.route.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: fakeTransNumber, TTL: 123, TOS: stack.DefaultTOS}, pkt); err != nil { return 0, nil, err } return int64(len(v)), nil, nil } -func (f *fakeTransportEndpoint) Peek([][]byte) (int64, tcpip.ControlMessages, *tcpip.Error) { +func (*fakeTransportEndpoint) Peek([][]byte) (int64, tcpip.ControlMessages, *tcpip.Error) { return 0, tcpip.ControlMessages{}, nil } // SetSockOpt sets a socket option. Currently not supported. -func (*fakeTransportEndpoint) SetSockOpt(interface{}) *tcpip.Error { +func (*fakeTransportEndpoint) SetSockOpt(tcpip.SettableSocketOption) *tcpip.Error { return tcpip.ErrInvalidEndpointState } @@ -130,11 +130,7 @@ func (*fakeTransportEndpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.E } // GetSockOpt implements tcpip.Endpoint.GetSockOpt. -func (*fakeTransportEndpoint) GetSockOpt(opt interface{}) *tcpip.Error { - switch opt.(type) { - case tcpip.ErrorOption: - return nil - } +func (*fakeTransportEndpoint) GetSockOpt(tcpip.GettableSocketOption) *tcpip.Error { return tcpip.ErrInvalidEndpointState } @@ -147,7 +143,7 @@ func (f *fakeTransportEndpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { f.peerAddr = addr.Addr // Find the route. - r, err := f.stack.FindRoute(addr.NIC, "", addr.Addr, fakeNetNumber, false /* multicastLoop */) + r, err := f.proto.stack.FindRoute(addr.NIC, "", addr.Addr, fakeNetNumber, false /* multicastLoop */) if err != nil { return tcpip.ErrNoRoute } @@ -155,7 +151,7 @@ func (f *fakeTransportEndpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { // Try to register so that we can start receiving packets. f.ID.RemoteAddress = addr.Addr - err = f.stack.RegisterTransportEndpoint(0, []tcpip.NetworkProtocolNumber{fakeNetNumber}, fakeTransNumber, f.ID, f, ports.Flags{}, 0 /* bindToDevice */) + err = f.proto.stack.RegisterTransportEndpoint(0, []tcpip.NetworkProtocolNumber{fakeNetNumber}, fakeTransNumber, f.ID, f, ports.Flags{}, 0 /* bindToDevice */) if err != nil { return err } @@ -169,7 +165,7 @@ func (f *fakeTransportEndpoint) UniqueID() uint64 { return f.uniqueID } -func (f *fakeTransportEndpoint) ConnectEndpoint(e tcpip.Endpoint) *tcpip.Error { +func (*fakeTransportEndpoint) ConnectEndpoint(e tcpip.Endpoint) *tcpip.Error { return nil } @@ -184,7 +180,7 @@ func (*fakeTransportEndpoint) Listen(int) *tcpip.Error { return nil } -func (f *fakeTransportEndpoint) Accept() (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) { +func (f *fakeTransportEndpoint) Accept(*tcpip.FullAddress) (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) { if len(f.acceptQueue) == 0 { return nil, nil, nil } @@ -194,7 +190,7 @@ func (f *fakeTransportEndpoint) Accept() (tcpip.Endpoint, *waiter.Queue, *tcpip. } func (f *fakeTransportEndpoint) Bind(a tcpip.FullAddress) *tcpip.Error { - if err := f.stack.RegisterTransportEndpoint( + if err := f.proto.stack.RegisterTransportEndpoint( a.NIC, []tcpip.NetworkProtocolNumber{fakeNetNumber}, fakeTransNumber, @@ -222,7 +218,6 @@ func (f *fakeTransportEndpoint) HandlePacket(r *stack.Route, id stack.TransportE f.proto.packetCount++ if f.acceptQueue != nil { f.acceptQueue = append(f.acceptQueue, fakeTransportEndpoint{ - stack: f.stack, TransportEndpointInfo: stack.TransportEndpointInfo{ ID: f.ID, NetProto: f.NetProto, @@ -239,19 +234,19 @@ func (f *fakeTransportEndpoint) HandleControlPacket(stack.TransportEndpointID, s f.proto.controlCount++ } -func (f *fakeTransportEndpoint) State() uint32 { +func (*fakeTransportEndpoint) State() uint32 { return 0 } -func (f *fakeTransportEndpoint) ModerateRecvBuf(copied int) {} +func (*fakeTransportEndpoint) ModerateRecvBuf(copied int) {} -func (f *fakeTransportEndpoint) IPTables() (stack.IPTables, error) { - return stack.IPTables{}, nil -} +func (*fakeTransportEndpoint) Resume(*stack.Stack) {} -func (f *fakeTransportEndpoint) Resume(*stack.Stack) {} +func (*fakeTransportEndpoint) Wait() {} -func (f *fakeTransportEndpoint) Wait() {} +func (*fakeTransportEndpoint) LastError() *tcpip.Error { + return nil +} type fakeTransportGoodOption bool @@ -266,6 +261,8 @@ type fakeTransportProtocolOptions struct { // fakeTransportProtocol is a transport-layer protocol descriptor. It // aggregates the number of packets received via endpoints of this protocol. type fakeTransportProtocol struct { + stack *stack.Stack + packetCount int controlCount int opts fakeTransportProtocolOptions @@ -275,11 +272,11 @@ func (*fakeTransportProtocol) Number() tcpip.TransportProtocolNumber { return fakeTransNumber } -func (f *fakeTransportProtocol) NewEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, _ *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { - return newFakeTransportEndpoint(stack, f, netProto, stack.UniqueID()), nil +func (f *fakeTransportProtocol) NewEndpoint(netProto tcpip.NetworkProtocolNumber, _ *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { + return newFakeTransportEndpoint(f, netProto, f.stack.UniqueID()), nil } -func (*fakeTransportProtocol) NewRawEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, _ *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { +func (*fakeTransportProtocol) NewRawEndpoint(tcpip.NetworkProtocolNumber, *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { return nil, tcpip.ErrUnknownProtocol } @@ -291,26 +288,24 @@ func (*fakeTransportProtocol) ParsePorts(buffer.View) (src, dst uint16, err *tcp return 0, 0, nil } -func (*fakeTransportProtocol) HandleUnknownDestinationPacket(*stack.Route, stack.TransportEndpointID, *stack.PacketBuffer) bool { - return true +func (*fakeTransportProtocol) HandleUnknownDestinationPacket(*stack.Route, stack.TransportEndpointID, *stack.PacketBuffer) stack.UnknownDestinationPacketDisposition { + return stack.UnknownDestinationPacketHandled } -func (f *fakeTransportProtocol) SetOption(option interface{}) *tcpip.Error { +func (f *fakeTransportProtocol) SetOption(option tcpip.SettableTransportProtocolOption) *tcpip.Error { switch v := option.(type) { - case fakeTransportGoodOption: - f.opts.good = bool(v) + case *tcpip.TCPModerateReceiveBufferOption: + f.opts.good = bool(*v) return nil - case fakeTransportInvalidValueOption: - return tcpip.ErrInvalidOptionValue default: return tcpip.ErrUnknownProtocolOption } } -func (f *fakeTransportProtocol) Option(option interface{}) *tcpip.Error { +func (f *fakeTransportProtocol) Option(option tcpip.GettableTransportProtocolOption) *tcpip.Error { switch v := option.(type) { - case *fakeTransportGoodOption: - *v = fakeTransportGoodOption(f.opts.good) + case *tcpip.TCPModerateReceiveBufferOption: + *v = tcpip.TCPModerateReceiveBufferOption(f.opts.good) return nil default: return tcpip.ErrUnknownProtocolOption @@ -328,24 +323,19 @@ func (*fakeTransportProtocol) Wait() {} // Parse implements TransportProtocol.Parse. func (*fakeTransportProtocol) Parse(pkt *stack.PacketBuffer) bool { - hdr, ok := pkt.Data.PullUp(fakeTransHeaderLen) - if !ok { - return false - } - pkt.TransportHeader = hdr - pkt.Data.TrimFront(fakeTransHeaderLen) - return true + _, ok := pkt.TransportHeader().Consume(fakeTransHeaderLen) + return ok } -func fakeTransFactory() stack.TransportProtocol { - return &fakeTransportProtocol{} +func fakeTransFactory(s *stack.Stack) stack.TransportProtocol { + return &fakeTransportProtocol{stack: s} } func TestTransportReceive(t *testing.T) { linkEP := channel.New(10, defaultMTU, "") s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, - TransportProtocols: []stack.TransportProtocol{fakeTransFactory()}, + NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, + TransportProtocols: []stack.TransportProtocolFactory{fakeTransFactory}, }) if err := s.CreateNIC(1, linkEP); err != nil { t.Fatalf("CreateNIC failed: %v", err) @@ -382,9 +372,9 @@ func TestTransportReceive(t *testing.T) { // Make sure packet with wrong protocol is not delivered. buf[0] = 1 buf[2] = 0 - linkEP.InjectInbound(fakeNetNumber, &stack.PacketBuffer{ + linkEP.InjectInbound(fakeNetNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: buf.ToVectorisedView(), - }) + })) if fakeTrans.packetCount != 0 { t.Errorf("packetCount = %d, want %d", fakeTrans.packetCount, 0) } @@ -393,9 +383,9 @@ func TestTransportReceive(t *testing.T) { buf[0] = 1 buf[1] = 3 buf[2] = byte(fakeTransNumber) - linkEP.InjectInbound(fakeNetNumber, &stack.PacketBuffer{ + linkEP.InjectInbound(fakeNetNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: buf.ToVectorisedView(), - }) + })) if fakeTrans.packetCount != 0 { t.Errorf("packetCount = %d, want %d", fakeTrans.packetCount, 0) } @@ -404,9 +394,9 @@ func TestTransportReceive(t *testing.T) { buf[0] = 1 buf[1] = 2 buf[2] = byte(fakeTransNumber) - linkEP.InjectInbound(fakeNetNumber, &stack.PacketBuffer{ + linkEP.InjectInbound(fakeNetNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: buf.ToVectorisedView(), - }) + })) if fakeTrans.packetCount != 1 { t.Errorf("packetCount = %d, want %d", fakeTrans.packetCount, 1) } @@ -415,8 +405,8 @@ func TestTransportReceive(t *testing.T) { func TestTransportControlReceive(t *testing.T) { linkEP := channel.New(10, defaultMTU, "") s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, - TransportProtocols: []stack.TransportProtocol{fakeTransFactory()}, + NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, + TransportProtocols: []stack.TransportProtocolFactory{fakeTransFactory}, }) if err := s.CreateNIC(1, linkEP); err != nil { t.Fatalf("CreateNIC failed: %v", err) @@ -459,9 +449,9 @@ func TestTransportControlReceive(t *testing.T) { buf[fakeNetHeaderLen+0] = 0 buf[fakeNetHeaderLen+1] = 1 buf[fakeNetHeaderLen+2] = 0 - linkEP.InjectInbound(fakeNetNumber, &stack.PacketBuffer{ + linkEP.InjectInbound(fakeNetNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: buf.ToVectorisedView(), - }) + })) if fakeTrans.controlCount != 0 { t.Errorf("controlCount = %d, want %d", fakeTrans.controlCount, 0) } @@ -470,9 +460,9 @@ func TestTransportControlReceive(t *testing.T) { buf[fakeNetHeaderLen+0] = 3 buf[fakeNetHeaderLen+1] = 1 buf[fakeNetHeaderLen+2] = byte(fakeTransNumber) - linkEP.InjectInbound(fakeNetNumber, &stack.PacketBuffer{ + linkEP.InjectInbound(fakeNetNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: buf.ToVectorisedView(), - }) + })) if fakeTrans.controlCount != 0 { t.Errorf("controlCount = %d, want %d", fakeTrans.controlCount, 0) } @@ -481,9 +471,9 @@ func TestTransportControlReceive(t *testing.T) { buf[fakeNetHeaderLen+0] = 2 buf[fakeNetHeaderLen+1] = 1 buf[fakeNetHeaderLen+2] = byte(fakeTransNumber) - linkEP.InjectInbound(fakeNetNumber, &stack.PacketBuffer{ + linkEP.InjectInbound(fakeNetNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: buf.ToVectorisedView(), - }) + })) if fakeTrans.controlCount != 1 { t.Errorf("controlCount = %d, want %d", fakeTrans.controlCount, 1) } @@ -492,8 +482,8 @@ func TestTransportControlReceive(t *testing.T) { func TestTransportSend(t *testing.T) { linkEP := channel.New(10, defaultMTU, "") s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, - TransportProtocols: []stack.TransportProtocol{fakeTransFactory()}, + NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, + TransportProtocols: []stack.TransportProtocolFactory{fakeTransFactory}, }) if err := s.CreateNIC(1, linkEP); err != nil { t.Fatalf("CreateNIC failed: %v", err) @@ -538,54 +528,29 @@ func TestTransportSend(t *testing.T) { func TestTransportOptions(t *testing.T) { s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, - TransportProtocols: []stack.TransportProtocol{fakeTransFactory()}, + NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, + TransportProtocols: []stack.TransportProtocolFactory{fakeTransFactory}, }) - // Try an unsupported transport protocol. - if err := s.SetTransportProtocolOption(tcpip.TransportProtocolNumber(99999), fakeTransportGoodOption(false)); err != tcpip.ErrUnknownProtocol { - t.Fatalf("SetTransportProtocolOption(fakeTrans2, blah, false) = %v, want = tcpip.ErrUnknownProtocol", err) - } - - testCases := []struct { - option interface{} - wantErr *tcpip.Error - verifier func(t *testing.T, p stack.TransportProtocol) - }{ - {fakeTransportGoodOption(true), nil, func(t *testing.T, p stack.TransportProtocol) { - t.Helper() - fakeTrans := p.(*fakeTransportProtocol) - if fakeTrans.opts.good != true { - t.Fatalf("fakeTrans.opts.good = false, want = true") - } - var v fakeTransportGoodOption - if err := s.TransportProtocolOption(fakeTransNumber, &v); err != nil { - t.Fatalf("s.TransportProtocolOption(fakeTransNumber, &v) = %v, want = nil, where v is option %T", v, err) - } - if v != true { - t.Fatalf("s.TransportProtocolOption(fakeTransNumber, &v) returned v = %v, want = true", v) - } - - }}, - {fakeTransportBadOption(true), tcpip.ErrUnknownProtocolOption, nil}, - {fakeTransportInvalidValueOption(1), tcpip.ErrInvalidOptionValue, nil}, - } - for _, tc := range testCases { - if got := s.SetTransportProtocolOption(fakeTransNumber, tc.option); got != tc.wantErr { - t.Errorf("s.SetTransportProtocolOption(fakeTrans, %v) = %v, want = %v", tc.option, got, tc.wantErr) - } - if tc.verifier != nil { - tc.verifier(t, s.TransportProtocolInstance(fakeTransNumber)) - } + v := tcpip.TCPModerateReceiveBufferOption(true) + if err := s.SetTransportProtocolOption(fakeTransNumber, &v); err != nil { + t.Errorf("s.SetTransportProtocolOption(fakeTrans, &%T(%t)): %s", v, v, err) + } + v = false + if err := s.TransportProtocolOption(fakeTransNumber, &v); err != nil { + t.Fatalf("s.TransportProtocolOption(fakeTransNumber, &%T): %s", v, err) + } + if !v { + t.Fatalf("got tcpip.TCPModerateReceiveBufferOption = false, want = true") } } func TestTransportForwarding(t *testing.T) { s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, - TransportProtocols: []stack.TransportProtocol{fakeTransFactory()}, + NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, + TransportProtocols: []stack.TransportProtocolFactory{fakeTransFactory}, }) - s.SetForwarding(true) + s.SetForwarding(fakeNetNumber, true) // TODO(b/123449044): Change this to a channel NIC. ep1 := loopback.New() @@ -636,11 +601,11 @@ func TestTransportForwarding(t *testing.T) { req[0] = 1 req[1] = 3 req[2] = byte(fakeTransNumber) - ep2.InjectInbound(fakeNetNumber, &stack.PacketBuffer{ + ep2.InjectInbound(fakeNetNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: req.ToVectorisedView(), - }) + })) - aep, _, err := ep.Accept() + aep, _, err := ep.Accept(nil) if err != nil || aep == nil { t.Fatalf("Accept failed: %v, %v", aep, err) } @@ -655,10 +620,11 @@ func TestTransportForwarding(t *testing.T) { t.Fatal("Response packet not forwarded") } - if dst := p.Pkt.NetworkHeader[0]; dst != 3 { + nh := stack.PayloadSince(p.Pkt.NetworkHeader()) + if dst := nh[0]; dst != 3 { t.Errorf("Response packet has incorrect destination addresss: got = %d, want = 3", dst) } - if src := p.Pkt.NetworkHeader[1]; src != 1 { + if src := nh[1]; src != 1 { t.Errorf("Response packet has incorrect source addresss: got = %d, want = 3", src) } } |