summaryrefslogtreecommitdiffhomepage
path: root/pkg/tcpip/stack
diff options
context:
space:
mode:
Diffstat (limited to 'pkg/tcpip/stack')
-rw-r--r--pkg/tcpip/stack/BUILD15
-rw-r--r--pkg/tcpip/stack/ndp.go534
-rw-r--r--pkg/tcpip/stack/ndp_test.go1013
-rw-r--r--pkg/tcpip/stack/nic.go421
-rw-r--r--pkg/tcpip/stack/registration.go177
-rw-r--r--pkg/tcpip/stack/route.go63
-rw-r--r--pkg/tcpip/stack/stack.go527
-rw-r--r--pkg/tcpip/stack/stack_test.go563
-rw-r--r--pkg/tcpip/stack/transport_demuxer.go373
-rw-r--r--pkg/tcpip/stack/transport_demuxer_test.go354
-rw-r--r--pkg/tcpip/stack/transport_test.go81
11 files changed, 3683 insertions, 438 deletions
diff --git a/pkg/tcpip/stack/BUILD b/pkg/tcpip/stack/BUILD
index 28c49e8ff..460db3cf8 100644
--- a/pkg/tcpip/stack/BUILD
+++ b/pkg/tcpip/stack/BUILD
@@ -1,10 +1,9 @@
load("@io_bazel_rules_go//go:def.bzl", "go_test")
-
-package(licenses = ["notice"])
-
load("//tools/go_generics:defs.bzl", "go_template_instance")
load("//tools/go_stateify:defs.bzl", "go_library")
+package(licenses = ["notice"])
+
go_template_instance(
name = "linkaddrentry_list",
out = "linkaddrentry_list.go",
@@ -23,6 +22,7 @@ go_library(
"icmp_rate_limit.go",
"linkaddrcache.go",
"linkaddrentry_list.go",
+ "ndp.go",
"nic.go",
"registration.go",
"route.go",
@@ -36,6 +36,7 @@ go_library(
],
deps = [
"//pkg/ilist",
+ "//pkg/rand",
"//pkg/sleep",
"//pkg/tcpip",
"//pkg/tcpip/buffer",
@@ -53,18 +54,26 @@ go_test(
name = "stack_x_test",
size = "small",
srcs = [
+ "ndp_test.go",
"stack_test.go",
+ "transport_demuxer_test.go",
"transport_test.go",
],
deps = [
":stack",
"//pkg/tcpip",
"//pkg/tcpip/buffer",
+ "//pkg/tcpip/checker",
"//pkg/tcpip/header",
"//pkg/tcpip/iptables",
"//pkg/tcpip/link/channel",
"//pkg/tcpip/link/loopback",
+ "//pkg/tcpip/network/ipv4",
+ "//pkg/tcpip/network/ipv6",
+ "//pkg/tcpip/transport/icmp",
+ "//pkg/tcpip/transport/udp",
"//pkg/waiter",
+ "@com_github_google_go-cmp//cmp:go_default_library",
],
)
diff --git a/pkg/tcpip/stack/ndp.go b/pkg/tcpip/stack/ndp.go
new file mode 100644
index 000000000..8e49f7a56
--- /dev/null
+++ b/pkg/tcpip/stack/ndp.go
@@ -0,0 +1,534 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package stack
+
+import (
+ "fmt"
+ "log"
+ "time"
+
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+)
+
+const (
+ // defaultDupAddrDetectTransmits is the default number of NDP Neighbor
+ // Solicitation messages to send when doing Duplicate Address Detection
+ // for a tentative address.
+ //
+ // Default = 1 (from RFC 4862 section 5.1)
+ defaultDupAddrDetectTransmits = 1
+
+ // defaultRetransmitTimer is the default amount of time to wait between
+ // sending NDP Neighbor solicitation messages.
+ //
+ // Default = 1s (from RFC 4861 section 10).
+ defaultRetransmitTimer = time.Second
+
+ // defaultHandleRAs is the default configuration for whether or not to
+ // handle incoming Router Advertisements as a host.
+ //
+ // Default = true.
+ defaultHandleRAs = true
+
+ // defaultDiscoverDefaultRouters is the default configuration for
+ // whether or not to discover default routers from incoming Router
+ // Advertisements as a host.
+ //
+ // Default = true.
+ defaultDiscoverDefaultRouters = true
+
+ // minimumRetransmitTimer is the minimum amount of time to wait between
+ // sending NDP Neighbor solicitation messages. Note, RFC 4861 does
+ // not impose a minimum Retransmit Timer, but we do here to make sure
+ // the messages are not sent all at once. We also come to this value
+ // because in the RetransmitTimer field of a Router Advertisement, a
+ // value of 0 means unspecified, so the smallest valid value is 1.
+ // Note, the unit of the RetransmitTimer field in the Router
+ // Advertisement is milliseconds.
+ //
+ // Min = 1ms.
+ minimumRetransmitTimer = time.Millisecond
+
+ // MaxDiscoveredDefaultRouters is the maximum number of discovered
+ // default routers. The stack should stop discovering new routers after
+ // discovering MaxDiscoveredDefaultRouters routers.
+ //
+ // This value MUST be at minimum 2 as per RFC 4861 section 6.3.4, and
+ // SHOULD be more.
+ //
+ // Max = 10.
+ MaxDiscoveredDefaultRouters = 10
+)
+
+// NDPDispatcher is the interface integrators of netstack must implement to
+// receive and handle NDP related events.
+type NDPDispatcher interface {
+ // OnDuplicateAddressDetectionStatus will be called when the DAD process
+ // for an address (addr) on a NIC (with ID nicID) completes. resolved
+ // will be set to true if DAD completed successfully (no duplicate addr
+ // detected); false otherwise (addr was detected to be a duplicate on
+ // the link the NIC is a part of, or it was stopped for some other
+ // reason, such as the address being removed). If an error occured
+ // during DAD, err will be set and resolved must be ignored.
+ //
+ // This function is permitted to block indefinitely without interfering
+ // with the stack's operation.
+ OnDuplicateAddressDetectionStatus(nicID tcpip.NICID, addr tcpip.Address, resolved bool, err *tcpip.Error)
+
+ // OnDefaultRouterDiscovered will be called when a new default router is
+ // discovered. Implementations must return true along with a new valid
+ // route table if the newly discovered router should be remembered. If
+ // an implementation returns false, the second return value will be
+ // ignored.
+ //
+ // This function is not permitted to block indefinitely. This function
+ // is also not permitted to call into the stack.
+ OnDefaultRouterDiscovered(nicID tcpip.NICID, addr tcpip.Address) (bool, []tcpip.Route)
+
+ // OnDefaultRouterInvalidated will be called when a discovered default
+ // router is invalidated. Implementers must return a new valid route
+ // table.
+ //
+ // This function is not permitted to block indefinitely. This function
+ // is also not permitted to call into the stack.
+ OnDefaultRouterInvalidated(nicID tcpip.NICID, addr tcpip.Address) []tcpip.Route
+}
+
+// NDPConfigurations is the NDP configurations for the netstack.
+type NDPConfigurations struct {
+ // The number of Neighbor Solicitation messages to send when doing
+ // Duplicate Address Detection for a tentative address.
+ //
+ // Note, a value of zero effectively disables DAD.
+ DupAddrDetectTransmits uint8
+
+ // The amount of time to wait between sending Neighbor solicitation
+ // messages.
+ //
+ // Must be greater than 0.5s.
+ RetransmitTimer time.Duration
+
+ // HandleRAs determines whether or not Router Advertisements will be
+ // processed.
+ HandleRAs bool
+
+ // DiscoverDefaultRouters determines whether or not default routers will
+ // be discovered from Router Advertisements. This configuration is
+ // ignored if HandleRAs is false.
+ DiscoverDefaultRouters bool
+}
+
+// DefaultNDPConfigurations returns an NDPConfigurations populated with
+// default values.
+func DefaultNDPConfigurations() NDPConfigurations {
+ return NDPConfigurations{
+ DupAddrDetectTransmits: defaultDupAddrDetectTransmits,
+ RetransmitTimer: defaultRetransmitTimer,
+ HandleRAs: defaultHandleRAs,
+ DiscoverDefaultRouters: defaultDiscoverDefaultRouters,
+ }
+}
+
+// validate modifies an NDPConfigurations with valid values. If invalid values
+// are present in c, the corresponding default values will be used instead.
+//
+// If RetransmitTimer is less than minimumRetransmitTimer, then a value of
+// defaultRetransmitTimer will be used.
+func (c *NDPConfigurations) validate() {
+ if c.RetransmitTimer < minimumRetransmitTimer {
+ c.RetransmitTimer = defaultRetransmitTimer
+ }
+}
+
+// ndpState is the per-interface NDP state.
+type ndpState struct {
+ // The NIC this ndpState is for.
+ nic *NIC
+
+ // configs is the per-interface NDP configurations.
+ configs NDPConfigurations
+
+ // The DAD state to send the next NS message, or resolve the address.
+ dad map[tcpip.Address]dadState
+
+ // The default routers discovered through Router Advertisements.
+ defaultRouters map[tcpip.Address]defaultRouterState
+}
+
+// dadState holds the Duplicate Address Detection timer and channel to signal
+// to the DAD goroutine that DAD should stop.
+type dadState struct {
+ // The DAD timer to send the next NS message, or resolve the address.
+ timer *time.Timer
+
+ // Used to let the DAD timer know that it has been stopped.
+ //
+ // Must only be read from or written to while protected by the lock of
+ // the NIC this dadState is associated with.
+ done *bool
+}
+
+// defaultRouterState holds data associated with a default router discovered by
+// a Router Advertisement.
+type defaultRouterState struct {
+ invalidationTimer *time.Timer
+
+ // Used to signal the timer not to invalidate the default router (R) in
+ // a race condition (T1 is a goroutine that handles an RA from R and T2
+ // is the goroutine that handles R's invalidation timer firing):
+ // T1: Receive a new RA from R
+ // T1: Obtain the NIC's lock before processing the RA
+ // T2: R's invalidation timer fires, and gets blocked on obtaining the
+ // NIC's lock
+ // T1: Refreshes/extends R's lifetime & releases NIC's lock
+ // T2: Obtains NIC's lock & invalidates R immediately
+ //
+ // To resolve this, T1 will check to see if the timer already fired, and
+ // signal the timer using this channel to not invalidate R, so that once
+ // T2 obtains the lock, it will see that there is an event on this
+ // channel and do nothing further.
+ doNotInvalidateC chan struct{}
+}
+
+// startDuplicateAddressDetection performs Duplicate Address Detection.
+//
+// This function must only be called by IPv6 addresses that are currently
+// tentative.
+//
+// The NIC that ndp belongs to MUST be locked.
+func (ndp *ndpState) startDuplicateAddressDetection(addr tcpip.Address, ref *referencedNetworkEndpoint) *tcpip.Error {
+ // addr must be a valid unicast IPv6 address.
+ if !header.IsV6UnicastAddress(addr) {
+ return tcpip.ErrAddressFamilyNotSupported
+ }
+
+ // Should not attempt to perform DAD on an address that is currently in
+ // the DAD process.
+ if _, ok := ndp.dad[addr]; ok {
+ // Should never happen because we should only ever call this
+ // function for newly created addresses. If we attemped to
+ // "add" an address that already existed, we would returned an
+ // error since we attempted to add a duplicate address, or its
+ // reference count would have been increased without doing the
+ // work that would have been done for an address that was brand
+ // new. See NIC.addPermanentAddressLocked.
+ panic(fmt.Sprintf("ndpdad: already performing DAD for addr %s on NIC(%d)", addr, ndp.nic.ID()))
+ }
+
+ remaining := ndp.configs.DupAddrDetectTransmits
+
+ {
+ done, err := ndp.doDuplicateAddressDetection(addr, remaining, ref)
+ if err != nil {
+ return err
+ }
+ if done {
+ return nil
+ }
+ }
+
+ remaining--
+
+ var done bool
+ var timer *time.Timer
+ timer = time.AfterFunc(ndp.configs.RetransmitTimer, func() {
+ var d bool
+ var err *tcpip.Error
+
+ // doDadIteration does a single iteration of the DAD loop.
+ //
+ // Returns true if the integrator needs to be informed of DAD
+ // completing.
+ doDadIteration := func() bool {
+ ndp.nic.mu.Lock()
+ defer ndp.nic.mu.Unlock()
+
+ if done {
+ // If we reach this point, it means that the DAD
+ // timer fired after another goroutine already
+ // obtained the NIC lock and stopped DAD before
+ // this function obtained the NIC lock. Simply
+ // return here and do nothing further.
+ return false
+ }
+
+ ref, ok := ndp.nic.endpoints[NetworkEndpointID{addr}]
+ if !ok {
+ // This should never happen.
+ // We should have an endpoint for addr since we
+ // are still performing DAD on it. If the
+ // endpoint does not exist, but we are doing DAD
+ // on it, then we started DAD at some point, but
+ // forgot to stop it when the endpoint was
+ // deleted.
+ panic(fmt.Sprintf("ndpdad: unrecognized addr %s for NIC(%d)", addr, ndp.nic.ID()))
+ }
+
+ d, err = ndp.doDuplicateAddressDetection(addr, remaining, ref)
+ if err != nil || d {
+ delete(ndp.dad, addr)
+
+ if err != nil {
+ log.Printf("ndpdad: Error occured during DAD iteration for addr (%s) on NIC(%d); err = %s", addr, ndp.nic.ID(), err)
+ }
+
+ // Let the integrator know DAD has completed.
+ return true
+ }
+
+ remaining--
+ timer.Reset(ndp.nic.stack.ndpConfigs.RetransmitTimer)
+ return false
+ }
+
+ if doDadIteration() && ndp.nic.stack.ndpDisp != nil {
+ ndp.nic.stack.ndpDisp.OnDuplicateAddressDetectionStatus(ndp.nic.ID(), addr, d, err)
+ }
+ })
+
+ ndp.dad[addr] = dadState{
+ timer: timer,
+ done: &done,
+ }
+
+ return nil
+}
+
+// doDuplicateAddressDetection is called on every iteration of the timer, and
+// when DAD starts.
+//
+// It handles resolving the address (if there are no more NS to send), or
+// sending the next NS if there are more NS to send.
+//
+// This function must only be called by IPv6 addresses that are currently
+// tentative.
+//
+// The NIC that ndp belongs to (n) MUST be locked.
+//
+// Returns true if DAD has resolved; false if DAD is still ongoing.
+func (ndp *ndpState) doDuplicateAddressDetection(addr tcpip.Address, remaining uint8, ref *referencedNetworkEndpoint) (bool, *tcpip.Error) {
+ if ref.getKind() != permanentTentative {
+ // The endpoint should still be marked as tentative
+ // since we are still performing DAD on it.
+ panic(fmt.Sprintf("ndpdad: addr %s is not tentative on NIC(%d)", addr, ndp.nic.ID()))
+ }
+
+ if remaining == 0 {
+ // DAD has resolved.
+ ref.setKind(permanent)
+ return true, nil
+ }
+
+ // Send a new NS.
+ snmc := header.SolicitedNodeAddr(addr)
+ snmcRef, ok := ndp.nic.endpoints[NetworkEndpointID{snmc}]
+ if !ok {
+ // This should never happen as if we have the
+ // address, we should have the solicited-node
+ // address.
+ panic(fmt.Sprintf("ndpdad: NIC(%d) is not in the solicited-node multicast group (%s) but it has addr %s", ndp.nic.ID(), snmc, addr))
+ }
+
+ // Use the unspecified address as the source address when performing
+ // DAD.
+ r := makeRoute(header.IPv6ProtocolNumber, header.IPv6Any, snmc, ndp.nic.linkEP.LinkAddress(), snmcRef, false, false)
+
+ hdr := buffer.NewPrependable(int(r.MaxHeaderLength()) + header.ICMPv6NeighborSolicitMinimumSize)
+ pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6NeighborSolicitMinimumSize))
+ pkt.SetType(header.ICMPv6NeighborSolicit)
+ ns := header.NDPNeighborSolicit(pkt.NDPPayload())
+ ns.SetTargetAddress(addr)
+ pkt.SetChecksum(header.ICMPv6Checksum(pkt, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{}))
+
+ sent := r.Stats().ICMP.V6PacketsSent
+ if err := r.WritePacket(nil, hdr, buffer.VectorisedView{}, NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: header.NDPHopLimit, TOS: DefaultTOS}); err != nil {
+ sent.Dropped.Increment()
+ return false, err
+ }
+ sent.NeighborSolicit.Increment()
+
+ return false, nil
+}
+
+// stopDuplicateAddressDetection ends a running Duplicate Address Detection
+// process. Note, this may leave the DAD process for a tentative address in
+// such a state forever, unless some other external event resolves the DAD
+// process (receiving an NA from the true owner of addr, or an NS for addr
+// (implying another node is attempting to use addr)). It is up to the caller
+// of this function to handle such a scenario. Normally, addr will be removed
+// from n right after this function returns or the address successfully
+// resolved.
+//
+// The NIC that ndp belongs to MUST be locked.
+func (ndp *ndpState) stopDuplicateAddressDetection(addr tcpip.Address) {
+ dad, ok := ndp.dad[addr]
+ if !ok {
+ // Not currently performing DAD on addr, just return.
+ return
+ }
+
+ if dad.timer != nil {
+ dad.timer.Stop()
+ dad.timer = nil
+
+ *dad.done = true
+ dad.done = nil
+ }
+
+ delete(ndp.dad, addr)
+
+ // Let the integrator know DAD did not resolve.
+ if ndp.nic.stack.ndpDisp != nil {
+ go ndp.nic.stack.ndpDisp.OnDuplicateAddressDetectionStatus(ndp.nic.ID(), addr, false, nil)
+ }
+}
+
+// handleRA handles a Router Advertisement message that arrived on the NIC
+// this ndp is for. Does nothing if the NIC is configured to not handle RAs.
+//
+// The NIC that ndp belongs to and its associated stack MUST be locked.
+func (ndp *ndpState) handleRA(ip tcpip.Address, ra header.NDPRouterAdvert) {
+ // Is the NIC configured to handle RAs at all?
+ //
+ // Currently, the stack does not determine router interface status on a
+ // per-interface basis; it is a stack-wide configuration, so we check
+ // stack's forwarding flag to determine if the NIC is a routing
+ // interface.
+ if !ndp.configs.HandleRAs || ndp.nic.stack.forwarding {
+ return
+ }
+
+ // Is the NIC configured to discover default routers?
+ if ndp.configs.DiscoverDefaultRouters {
+ rtr, ok := ndp.defaultRouters[ip]
+ rl := ra.RouterLifetime()
+ switch {
+ case !ok && rl != 0:
+ // This is a new default router we are discovering.
+ //
+ // Only remember it if we currently know about less than
+ // MaxDiscoveredDefaultRouters routers.
+ if len(ndp.defaultRouters) < MaxDiscoveredDefaultRouters {
+ ndp.rememberDefaultRouter(ip, rl)
+ }
+
+ case ok && rl != 0:
+ // This is an already discovered default router. Update
+ // the invalidation timer.
+ timer := rtr.invalidationTimer
+
+ // We should ALWAYS have an invalidation timer for a
+ // discovered router.
+ if timer == nil {
+ panic("ndphandlera: RA invalidation timer should not be nil")
+ }
+
+ if !timer.Stop() {
+ // If we reach this point, then we know the
+ // timer fired after we already took the NIC
+ // lock. Signal the timer so that once it
+ // obtains the lock, it doesn't actually
+ // invalidate the router as we just got a new
+ // RA that refreshes its lifetime to a non-zero
+ // value. See
+ // defaultRouterState.doNotInvalidateC for more
+ // details.
+ rtr.doNotInvalidateC <- struct{}{}
+ }
+
+ timer.Reset(rl)
+
+ case ok && rl == 0:
+ // We know about the router but it is no longer to be
+ // used as a default router so invalidate it.
+ ndp.invalidateDefaultRouter(ip)
+ }
+ }
+
+ // TODO(b/140948104): Do Prefix Discovery.
+ // TODO(b/141556115): Do Parameter Discovery.
+}
+
+// invalidateDefaultRouter invalidates a discovered default router.
+//
+// The NIC that ndp belongs to and its associated stack MUST be locked.
+func (ndp *ndpState) invalidateDefaultRouter(ip tcpip.Address) {
+ rtr, ok := ndp.defaultRouters[ip]
+
+ // Is the router still discovered?
+ if !ok {
+ // ...Nope, do nothing further.
+ return
+ }
+
+ rtr.invalidationTimer.Stop()
+ rtr.invalidationTimer = nil
+ close(rtr.doNotInvalidateC)
+ rtr.doNotInvalidateC = nil
+
+ delete(ndp.defaultRouters, ip)
+
+ // Let the integrator know a discovered default router is invalidated.
+ if ndp.nic.stack.ndpDisp != nil {
+ ndp.nic.stack.routeTable = ndp.nic.stack.ndpDisp.OnDefaultRouterInvalidated(ndp.nic.ID(), ip)
+ }
+}
+
+// rememberDefaultRouter remembers a newly discovered default router with IPv6
+// link-local address ip with lifetime rl.
+//
+// The router identified by ip MUST NOT already be known by the NIC.
+//
+// The NIC that ndp belongs to and its associated stack MUST be locked.
+func (ndp *ndpState) rememberDefaultRouter(ip tcpip.Address, rl time.Duration) {
+ if ndp.nic.stack.ndpDisp == nil {
+ return
+ }
+
+ // Inform the integrator when we discovered a default router.
+ remember, routeTable := ndp.nic.stack.ndpDisp.OnDefaultRouterDiscovered(ndp.nic.ID(), ip)
+ if !remember {
+ // Informed by the integrator to not remember the router, do
+ // nothing further.
+ return
+ }
+
+ // Used to signal the timer not to invalidate the default router (R) in
+ // a race condition. See defaultRouterState.doNotInvalidateC for more
+ // details.
+ doNotInvalidateC := make(chan struct{}, 1)
+
+ ndp.defaultRouters[ip] = defaultRouterState{
+ invalidationTimer: time.AfterFunc(rl, func() {
+ ndp.nic.stack.mu.Lock()
+ defer ndp.nic.stack.mu.Unlock()
+ ndp.nic.mu.Lock()
+ defer ndp.nic.mu.Unlock()
+
+ select {
+ case <-doNotInvalidateC:
+ return
+ default:
+ }
+
+ ndp.invalidateDefaultRouter(ip)
+ }),
+ doNotInvalidateC: doNotInvalidateC,
+ }
+
+ ndp.nic.stack.routeTable = routeTable
+}
diff --git a/pkg/tcpip/stack/ndp_test.go b/pkg/tcpip/stack/ndp_test.go
new file mode 100644
index 000000000..50ce1bbfa
--- /dev/null
+++ b/pkg/tcpip/stack/ndp_test.go
@@ -0,0 +1,1013 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package stack_test
+
+import (
+ "encoding/binary"
+ "fmt"
+ "testing"
+ "time"
+
+ "github.com/google/go-cmp/cmp"
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/checker"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/link/channel"
+ "gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
+ "gvisor.dev/gvisor/pkg/tcpip/transport/icmp"
+)
+
+const (
+ addr1 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01"
+ addr2 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02"
+ addr3 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x03"
+ linkAddr1 = "\x02\x02\x03\x04\x05\x06"
+ linkAddr2 = "\x02\x02\x03\x04\x05\x07"
+ linkAddr3 = "\x02\x02\x03\x04\x05\x08"
+ defaultTimeout = 250 * time.Millisecond
+)
+
+var (
+ llAddr1 = header.LinkLocalAddr(linkAddr1)
+ llAddr2 = header.LinkLocalAddr(linkAddr2)
+ llAddr3 = header.LinkLocalAddr(linkAddr3)
+)
+
+// TestDADDisabled tests that an address successfully resolves immediately
+// when DAD is not enabled (the default for an empty stack.Options).
+func TestDADDisabled(t *testing.T) {
+ opts := stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
+ }
+
+ e := channel.New(10, 1280, linkAddr1)
+ s := stack.New(opts)
+ if err := s.CreateNIC(1, e); err != nil {
+ t.Fatalf("CreateNIC(_) = %s", err)
+ }
+
+ if err := s.AddAddress(1, header.IPv6ProtocolNumber, addr1); err != nil {
+ t.Fatalf("AddAddress(_, %d, %s) = %s", header.IPv6ProtocolNumber, addr1, err)
+ }
+
+ // Should get the address immediately since we should not have performed
+ // DAD on it.
+ addr, err := s.GetMainNICAddress(1, header.IPv6ProtocolNumber)
+ if err != nil {
+ t.Fatalf("stack.GetMainNICAddress(_, _) err = %s", err)
+ }
+ if addr.Address != addr1 {
+ t.Fatalf("got stack.GetMainNICAddress(_, _) = %s, want = %s", addr, addr1)
+ }
+
+ // We should not have sent any NDP NS messages.
+ if got := s.Stats().ICMP.V6PacketsSent.NeighborSolicit.Value(); got != 0 {
+ t.Fatalf("got NeighborSolicit = %d, want = 0", got)
+ }
+}
+
+// ndpDADEvent is a set of parameters that was passed to
+// ndpDispatcher.OnDuplicateAddressDetectionStatus.
+type ndpDADEvent struct {
+ nicID tcpip.NICID
+ addr tcpip.Address
+ resolved bool
+ err *tcpip.Error
+}
+
+type ndpRouterEvent struct {
+ nicID tcpip.NICID
+ addr tcpip.Address
+ // true if router was discovered, false if invalidated.
+ discovered bool
+}
+
+var _ stack.NDPDispatcher = (*ndpDispatcher)(nil)
+
+// ndpDispatcher implements NDPDispatcher so tests can know when various NDP
+// related events happen for test purposes.
+type ndpDispatcher struct {
+ dadC chan ndpDADEvent
+ routerC chan ndpRouterEvent
+ rememberRouter bool
+ routeTable []tcpip.Route
+}
+
+// Implements stack.NDPDispatcher.OnDuplicateAddressDetectionStatus.
+func (n *ndpDispatcher) OnDuplicateAddressDetectionStatus(nicID tcpip.NICID, addr tcpip.Address, resolved bool, err *tcpip.Error) {
+ if n.dadC != nil {
+ n.dadC <- ndpDADEvent{
+ nicID,
+ addr,
+ resolved,
+ err,
+ }
+ }
+}
+
+// Implements stack.NDPDispatcher.OnDefaultRouterDiscovered.
+func (n *ndpDispatcher) OnDefaultRouterDiscovered(nicID tcpip.NICID, addr tcpip.Address) (bool, []tcpip.Route) {
+ if n.routerC != nil {
+ n.routerC <- ndpRouterEvent{
+ nicID,
+ addr,
+ true,
+ }
+ }
+
+ if !n.rememberRouter {
+ return false, nil
+ }
+
+ rt := append([]tcpip.Route(nil), n.routeTable...)
+ rt = append(rt, tcpip.Route{
+ Destination: header.IPv6EmptySubnet,
+ Gateway: addr,
+ NIC: nicID,
+ })
+ n.routeTable = rt
+ return true, rt
+}
+
+// Implements stack.NDPDispatcher.OnDefaultRouterInvalidated.
+func (n *ndpDispatcher) OnDefaultRouterInvalidated(nicID tcpip.NICID, addr tcpip.Address) []tcpip.Route {
+ if n.routerC != nil {
+ n.routerC <- ndpRouterEvent{
+ nicID,
+ addr,
+ false,
+ }
+ }
+
+ var rt []tcpip.Route
+ exclude := tcpip.Route{
+ Destination: header.IPv6EmptySubnet,
+ Gateway: addr,
+ NIC: nicID,
+ }
+
+ for _, r := range n.routeTable {
+ if r != exclude {
+ rt = append(rt, r)
+ }
+ }
+ n.routeTable = rt
+ return rt
+}
+
+// TestDADResolve tests that an address successfully resolves after performing
+// DAD for various values of DupAddrDetectTransmits and RetransmitTimer.
+// Included in the subtests is a test to make sure that an invalid
+// RetransmitTimer (<1ms) values get fixed to the default RetransmitTimer of 1s.
+func TestDADResolve(t *testing.T) {
+ tests := []struct {
+ name string
+ dupAddrDetectTransmits uint8
+ retransTimer time.Duration
+ expectedRetransmitTimer time.Duration
+ }{
+ {"1:1s:1s", 1, time.Second, time.Second},
+ {"2:1s:1s", 2, time.Second, time.Second},
+ {"1:2s:2s", 1, 2 * time.Second, 2 * time.Second},
+ // 0s is an invalid RetransmitTimer timer and will be fixed to
+ // the default RetransmitTimer value of 1s.
+ {"1:0s:1s", 1, 0, time.Second},
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ ndpDisp := ndpDispatcher{
+ dadC: make(chan ndpDADEvent),
+ }
+ opts := stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
+ NDPDisp: &ndpDisp,
+ }
+ opts.NDPConfigs.RetransmitTimer = test.retransTimer
+ opts.NDPConfigs.DupAddrDetectTransmits = test.dupAddrDetectTransmits
+
+ e := channel.New(10, 1280, linkAddr1)
+ s := stack.New(opts)
+ if err := s.CreateNIC(1, e); err != nil {
+ t.Fatalf("CreateNIC(_) = %s", err)
+ }
+
+ if err := s.AddAddress(1, header.IPv6ProtocolNumber, addr1); err != nil {
+ t.Fatalf("AddAddress(_, %d, %s) = %s", header.IPv6ProtocolNumber, addr1, err)
+ }
+
+ stat := s.Stats().ICMP.V6PacketsSent.NeighborSolicit
+
+ // Should have sent an NDP NS immediately.
+ if got := stat.Value(); got != 1 {
+ t.Fatalf("got NeighborSolicit = %d, want = 1", got)
+
+ }
+
+ // Address should not be considered bound to the NIC yet
+ // (DAD ongoing).
+ addr, err := s.GetMainNICAddress(1, header.IPv6ProtocolNumber)
+ if err != nil {
+ t.Fatalf("got stack.GetMainNICAddress(_, _) = (_, %v), want = (_, nil)", err)
+ }
+ if want := (tcpip.AddressWithPrefix{}); addr != want {
+ t.Fatalf("got stack.GetMainNICAddress(_, _) = (%s, nil), want = (%s, nil)", addr, want)
+ }
+
+ // Wait for the remaining time - some delta (500ms), to
+ // make sure the address is still not resolved.
+ const delta = 500 * time.Millisecond
+ time.Sleep(test.expectedRetransmitTimer*time.Duration(test.dupAddrDetectTransmits) - delta)
+ addr, err = s.GetMainNICAddress(1, header.IPv6ProtocolNumber)
+ if err != nil {
+ t.Fatalf("got stack.GetMainNICAddress(_, _) = (_, %v), want = (_, nil)", err)
+ }
+ if want := (tcpip.AddressWithPrefix{}); addr != want {
+ t.Fatalf("got stack.GetMainNICAddress(_, _) = (%s, nil), want = (%s, nil)", addr, want)
+ }
+
+ // Wait for DAD to resolve.
+ select {
+ case <-time.After(2 * delta):
+ // We should get a resolution event after 500ms
+ // (delta) since we wait for 500ms less than the
+ // expected resolution time above to make sure
+ // that the address did not yet resolve. Waiting
+ // for 1s (2x delta) without a resolution event
+ // means something is wrong.
+ t.Fatal("timed out waiting for DAD resolution")
+ case e := <-ndpDisp.dadC:
+ if e.err != nil {
+ t.Fatal("got DAD error: ", e.err)
+ }
+ if e.nicID != 1 {
+ t.Fatalf("got DAD event w/ nicID = %d, want = 1", e.nicID)
+ }
+ if e.addr != addr1 {
+ t.Fatalf("got DAD event w/ addr = %s, want = %s", addr, addr1)
+ }
+ if !e.resolved {
+ t.Fatal("got DAD event w/ resolved = false, want = true")
+ }
+ }
+ addr, err = s.GetMainNICAddress(1, header.IPv6ProtocolNumber)
+ if err != nil {
+ t.Fatalf("stack.GetMainNICAddress(_, _) err = %s", err)
+ }
+ if addr.Address != addr1 {
+ t.Fatalf("got stack.GetMainNICAddress(_, _) = %s, want = %s", addr, addr1)
+ }
+
+ // Should not have sent any more NS messages.
+ if got := stat.Value(); got != uint64(test.dupAddrDetectTransmits) {
+ t.Fatalf("got NeighborSolicit = %d, want = %d", got, test.dupAddrDetectTransmits)
+ }
+
+ // Validate the sent Neighbor Solicitation messages.
+ for i := uint8(0); i < test.dupAddrDetectTransmits; i++ {
+ p := <-e.C
+
+ // Make sure its an IPv6 packet.
+ if p.Proto != header.IPv6ProtocolNumber {
+ t.Fatalf("got Proto = %d, want = %d", p.Proto, header.IPv6ProtocolNumber)
+ }
+
+ // Check NDP packet.
+ checker.IPv6(t, p.Header.ToVectorisedView().First(),
+ checker.TTL(header.NDPHopLimit),
+ checker.NDPNS(
+ checker.NDPNSTargetAddress(addr1)))
+ }
+ })
+ }
+
+}
+
+// TestDADFail tests to make sure that the DAD process fails if another node is
+// detected to be performing DAD on the same address (receive an NS message from
+// a node doing DAD for the same address), or if another node is detected to own
+// the address already (receive an NA message for the tentative address).
+func TestDADFail(t *testing.T) {
+ tests := []struct {
+ name string
+ makeBuf func(tgt tcpip.Address) buffer.Prependable
+ getStat func(s tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter
+ }{
+ {
+ "RxSolicit",
+ func(tgt tcpip.Address) buffer.Prependable {
+ hdr := buffer.NewPrependable(header.IPv6MinimumSize + header.ICMPv6NeighborSolicitMinimumSize)
+ pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6NeighborSolicitMinimumSize))
+ pkt.SetType(header.ICMPv6NeighborSolicit)
+ ns := header.NDPNeighborSolicit(pkt.NDPPayload())
+ ns.SetTargetAddress(tgt)
+ snmc := header.SolicitedNodeAddr(tgt)
+ pkt.SetChecksum(header.ICMPv6Checksum(pkt, header.IPv6Any, snmc, buffer.VectorisedView{}))
+ payloadLength := hdr.UsedLength()
+ ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
+ ip.Encode(&header.IPv6Fields{
+ PayloadLength: uint16(payloadLength),
+ NextHeader: uint8(icmp.ProtocolNumber6),
+ HopLimit: 255,
+ SrcAddr: header.IPv6Any,
+ DstAddr: snmc,
+ })
+
+ return hdr
+
+ },
+ func(s tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
+ return s.NeighborSolicit
+ },
+ },
+ {
+ "RxAdvert",
+ func(tgt tcpip.Address) buffer.Prependable {
+ hdr := buffer.NewPrependable(header.IPv6MinimumSize + header.ICMPv6NeighborAdvertSize)
+ pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6NeighborAdvertSize))
+ pkt.SetType(header.ICMPv6NeighborAdvert)
+ na := header.NDPNeighborAdvert(pkt.NDPPayload())
+ na.SetSolicitedFlag(true)
+ na.SetOverrideFlag(true)
+ na.SetTargetAddress(tgt)
+ pkt.SetChecksum(header.ICMPv6Checksum(pkt, tgt, header.IPv6AllNodesMulticastAddress, buffer.VectorisedView{}))
+ payloadLength := hdr.UsedLength()
+ ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
+ ip.Encode(&header.IPv6Fields{
+ PayloadLength: uint16(payloadLength),
+ NextHeader: uint8(icmp.ProtocolNumber6),
+ HopLimit: 255,
+ SrcAddr: tgt,
+ DstAddr: header.IPv6AllNodesMulticastAddress,
+ })
+
+ return hdr
+
+ },
+ func(s tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
+ return s.NeighborAdvert
+ },
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ ndpDisp := ndpDispatcher{
+ dadC: make(chan ndpDADEvent),
+ }
+ ndpConfigs := stack.DefaultNDPConfigurations()
+ opts := stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
+ NDPConfigs: ndpConfigs,
+ NDPDisp: &ndpDisp,
+ }
+ opts.NDPConfigs.RetransmitTimer = time.Second * 2
+
+ e := channel.New(10, 1280, linkAddr1)
+ s := stack.New(opts)
+ if err := s.CreateNIC(1, e); err != nil {
+ t.Fatalf("CreateNIC(_) = %s", err)
+ }
+
+ if err := s.AddAddress(1, header.IPv6ProtocolNumber, addr1); err != nil {
+ t.Fatalf("AddAddress(_, %d, %s) = %s", header.IPv6ProtocolNumber, addr1, err)
+ }
+
+ // Address should not be considered bound to the NIC yet
+ // (DAD ongoing).
+ addr, err := s.GetMainNICAddress(1, header.IPv6ProtocolNumber)
+ if err != nil {
+ t.Fatalf("got stack.GetMainNICAddress(_, _) = (_, %v), want = (_, nil)", err)
+ }
+ if want := (tcpip.AddressWithPrefix{}); addr != want {
+ t.Fatalf("got stack.GetMainNICAddress(_, _) = (%s, nil), want = (%s, nil)", addr, want)
+ }
+
+ // Receive a packet to simulate multiple nodes owning or
+ // attempting to own the same address.
+ hdr := test.makeBuf(addr1)
+ e.InjectInbound(header.IPv6ProtocolNumber, tcpip.PacketBuffer{
+ Data: hdr.View().ToVectorisedView(),
+ })
+
+ stat := test.getStat(s.Stats().ICMP.V6PacketsReceived)
+ if got := stat.Value(); got != 1 {
+ t.Fatalf("got stat = %d, want = 1", got)
+ }
+
+ // Wait for DAD to fail and make sure the address did
+ // not get resolved.
+ select {
+ case <-time.After(time.Duration(ndpConfigs.DupAddrDetectTransmits)*ndpConfigs.RetransmitTimer + time.Second):
+ // If we don't get a failure event after the
+ // expected resolution time + extra 1s buffer,
+ // something is wrong.
+ t.Fatal("timed out waiting for DAD failure")
+ case e := <-ndpDisp.dadC:
+ if e.err != nil {
+ t.Fatal("got DAD error: ", e.err)
+ }
+ if e.nicID != 1 {
+ t.Fatalf("got DAD event w/ nicID = %d, want = 1", e.nicID)
+ }
+ if e.addr != addr1 {
+ t.Fatalf("got DAD event w/ addr = %s, want = %s", addr, addr1)
+ }
+ if e.resolved {
+ t.Fatal("got DAD event w/ resolved = true, want = false")
+ }
+ }
+ addr, err = s.GetMainNICAddress(1, header.IPv6ProtocolNumber)
+ if err != nil {
+ t.Fatalf("got stack.GetMainNICAddress(_, _) = (_, %v), want = (_, nil)", err)
+ }
+ if want := (tcpip.AddressWithPrefix{}); addr != want {
+ t.Fatalf("got stack.GetMainNICAddress(_, _) = (%s, nil), want = (%s, nil)", addr, want)
+ }
+ })
+ }
+}
+
+// TestDADStop tests to make sure that the DAD process stops when an address is
+// removed.
+func TestDADStop(t *testing.T) {
+ ndpDisp := ndpDispatcher{
+ dadC: make(chan ndpDADEvent),
+ }
+ ndpConfigs := stack.NDPConfigurations{
+ RetransmitTimer: time.Second,
+ DupAddrDetectTransmits: 2,
+ }
+ opts := stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
+ NDPDisp: &ndpDisp,
+ NDPConfigs: ndpConfigs,
+ }
+
+ e := channel.New(10, 1280, linkAddr1)
+ s := stack.New(opts)
+ if err := s.CreateNIC(1, e); err != nil {
+ t.Fatalf("CreateNIC(_) = %s", err)
+ }
+
+ if err := s.AddAddress(1, header.IPv6ProtocolNumber, addr1); err != nil {
+ t.Fatalf("AddAddress(_, %d, %s) = %s", header.IPv6ProtocolNumber, addr1, err)
+ }
+
+ // Address should not be considered bound to the NIC yet (DAD ongoing).
+ addr, err := s.GetMainNICAddress(1, header.IPv6ProtocolNumber)
+ if err != nil {
+ t.Fatalf("got stack.GetMainNICAddress(_, _) = (_, %v), want = (_, nil)", err)
+ }
+ if want := (tcpip.AddressWithPrefix{}); addr != want {
+ t.Fatalf("got stack.GetMainNICAddress(_, _) = (%s, nil), want = (%s, nil)", addr, want)
+ }
+
+ // Remove the address. This should stop DAD.
+ if err := s.RemoveAddress(1, addr1); err != nil {
+ t.Fatalf("RemoveAddress(_, %s) = %s", addr1, err)
+ }
+
+ // Wait for DAD to fail (since the address was removed during DAD).
+ select {
+ case <-time.After(time.Duration(ndpConfigs.DupAddrDetectTransmits)*ndpConfigs.RetransmitTimer + time.Second):
+ // If we don't get a failure event after the expected resolution
+ // time + extra 1s buffer, something is wrong.
+ t.Fatal("timed out waiting for DAD failure")
+ case e := <-ndpDisp.dadC:
+ if e.err != nil {
+ t.Fatal("got DAD error: ", e.err)
+ }
+ if e.nicID != 1 {
+ t.Fatalf("got DAD event w/ nicID = %d, want = 1", e.nicID)
+ }
+ if e.addr != addr1 {
+ t.Fatalf("got DAD event w/ addr = %s, want = %s", addr, addr1)
+ }
+ if e.resolved {
+ t.Fatal("got DAD event w/ resolved = true, want = false")
+ }
+
+ }
+ addr, err = s.GetMainNICAddress(1, header.IPv6ProtocolNumber)
+ if err != nil {
+ t.Fatalf("got stack.GetMainNICAddress(_, _) = (_, %v), want = (_, nil)", err)
+ }
+ if want := (tcpip.AddressWithPrefix{}); addr != want {
+ t.Fatalf("got stack.GetMainNICAddress(_, _) = (%s, nil), want = (%s, nil)", addr, want)
+ }
+
+ // Should not have sent more than 1 NS message.
+ if got := s.Stats().ICMP.V6PacketsSent.NeighborSolicit.Value(); got > 1 {
+ t.Fatalf("got NeighborSolicit = %d, want <= 1", got)
+ }
+}
+
+// TestSetNDPConfigurationFailsForBadNICID tests to make sure we get an error if
+// we attempt to update NDP configurations using an invalid NICID.
+func TestSetNDPConfigurationFailsForBadNICID(t *testing.T) {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
+ })
+
+ // No NIC with ID 1 yet.
+ if got := s.SetNDPConfigurations(1, stack.NDPConfigurations{}); got != tcpip.ErrUnknownNICID {
+ t.Fatalf("got s.SetNDPConfigurations = %v, want = %s", got, tcpip.ErrUnknownNICID)
+ }
+}
+
+// TestSetNDPConfigurations tests that we can update and use per-interface NDP
+// configurations without affecting the default NDP configurations or other
+// interfaces' configurations.
+func TestSetNDPConfigurations(t *testing.T) {
+ tests := []struct {
+ name string
+ dupAddrDetectTransmits uint8
+ retransmitTimer time.Duration
+ expectedRetransmitTimer time.Duration
+ }{
+ {
+ "OK",
+ 1,
+ time.Second,
+ time.Second,
+ },
+ {
+ "Invalid Retransmit Timer",
+ 1,
+ 0,
+ time.Second,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ ndpDisp := ndpDispatcher{
+ dadC: make(chan ndpDADEvent),
+ }
+ e := channel.New(10, 1280, linkAddr1)
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
+ NDPDisp: &ndpDisp,
+ })
+
+ // This NIC(1)'s NDP configurations will be updated to
+ // be different from the default.
+ if err := s.CreateNIC(1, e); err != nil {
+ t.Fatalf("CreateNIC(1) = %s", err)
+ }
+
+ // Created before updating NIC(1)'s NDP configurations
+ // but updating NIC(1)'s NDP configurations should not
+ // affect other existing NICs.
+ if err := s.CreateNIC(2, e); err != nil {
+ t.Fatalf("CreateNIC(2) = %s", err)
+ }
+
+ // Update the NDP configurations on NIC(1) to use DAD.
+ configs := stack.NDPConfigurations{
+ DupAddrDetectTransmits: test.dupAddrDetectTransmits,
+ RetransmitTimer: test.retransmitTimer,
+ }
+ if err := s.SetNDPConfigurations(1, configs); err != nil {
+ t.Fatalf("got SetNDPConfigurations(1, _) = %s", err)
+ }
+
+ // Created after updating NIC(1)'s NDP configurations
+ // but the stack's default NDP configurations should not
+ // have been updated.
+ if err := s.CreateNIC(3, e); err != nil {
+ t.Fatalf("CreateNIC(3) = %s", err)
+ }
+
+ // Add addresses for each NIC.
+ if err := s.AddAddress(1, header.IPv6ProtocolNumber, addr1); err != nil {
+ t.Fatalf("AddAddress(1, %d, %s) = %s", header.IPv6ProtocolNumber, addr1, err)
+ }
+ if err := s.AddAddress(2, header.IPv6ProtocolNumber, addr2); err != nil {
+ t.Fatalf("AddAddress(2, %d, %s) = %s", header.IPv6ProtocolNumber, addr2, err)
+ }
+ if err := s.AddAddress(3, header.IPv6ProtocolNumber, addr3); err != nil {
+ t.Fatalf("AddAddress(3, %d, %s) = %s", header.IPv6ProtocolNumber, addr3, err)
+ }
+
+ // Address should not be considered bound to NIC(1) yet
+ // (DAD ongoing).
+ addr, err := s.GetMainNICAddress(1, header.IPv6ProtocolNumber)
+ if err != nil {
+ t.Fatalf("got stack.GetMainNICAddress(_, _) = (_, %v), want = (_, nil)", err)
+ }
+ if want := (tcpip.AddressWithPrefix{}); addr != want {
+ t.Fatalf("got stack.GetMainNICAddress(_, _) = (%s, nil), want = (%s, nil)", addr, want)
+ }
+
+ // Should get the address on NIC(2) and NIC(3)
+ // immediately since we should not have performed DAD on
+ // it as the stack was configured to not do DAD by
+ // default and we only updated the NDP configurations on
+ // NIC(1).
+ addr, err = s.GetMainNICAddress(2, header.IPv6ProtocolNumber)
+ if err != nil {
+ t.Fatalf("stack.GetMainNICAddress(2, _) err = %s", err)
+ }
+ if addr.Address != addr2 {
+ t.Fatalf("got stack.GetMainNICAddress(2, _) = %s, want = %s", addr, addr2)
+ }
+ addr, err = s.GetMainNICAddress(3, header.IPv6ProtocolNumber)
+ if err != nil {
+ t.Fatalf("stack.GetMainNICAddress(3, _) err = %s", err)
+ }
+ if addr.Address != addr3 {
+ t.Fatalf("got stack.GetMainNICAddress(3, _) = %s, want = %s", addr, addr3)
+ }
+
+ // Sleep until right (500ms before) before resolution to
+ // make sure the address didn't resolve on NIC(1) yet.
+ const delta = 500 * time.Millisecond
+ time.Sleep(time.Duration(test.dupAddrDetectTransmits)*test.expectedRetransmitTimer - delta)
+ addr, err = s.GetMainNICAddress(1, header.IPv6ProtocolNumber)
+ if err != nil {
+ t.Fatalf("got stack.GetMainNICAddress(_, _) = (_, %v), want = (_, nil)", err)
+ }
+ if want := (tcpip.AddressWithPrefix{}); addr != want {
+ t.Fatalf("got stack.GetMainNICAddress(_, _) = (%s, nil), want = (%s, nil)", addr, want)
+ }
+
+ // Wait for DAD to resolve.
+ select {
+ case <-time.After(2 * delta):
+ // We should get a resolution event after 500ms
+ // (delta) since we wait for 500ms less than the
+ // expected resolution time above to make sure
+ // that the address did not yet resolve. Waiting
+ // for 1s (2x delta) without a resolution event
+ // means something is wrong.
+ t.Fatal("timed out waiting for DAD resolution")
+ case e := <-ndpDisp.dadC:
+ if e.err != nil {
+ t.Fatal("got DAD error: ", e.err)
+ }
+ if e.nicID != 1 {
+ t.Fatalf("got DAD event w/ nicID = %d, want = 1", e.nicID)
+ }
+ if e.addr != addr1 {
+ t.Fatalf("got DAD event w/ addr = %s, want = %s", addr, addr1)
+ }
+ if !e.resolved {
+ t.Fatal("got DAD event w/ resolved = false, want = true")
+ }
+ }
+ addr, err = s.GetMainNICAddress(1, header.IPv6ProtocolNumber)
+ if err != nil {
+ t.Fatalf("stack.GetMainNICAddress(1, _) err = %s", err)
+ }
+ if addr.Address != addr1 {
+ t.Fatalf("got stack.GetMainNICAddress(1, _) = %s, want = %s", addr, addr1)
+ }
+ })
+ }
+}
+
+// raBuf returns a valid NDP Router Advertisement.
+//
+// Note, raBuf does not populate any of the RA fields other than the
+// Router Lifetime.
+func raBuf(ip tcpip.Address, rl uint16) tcpip.PacketBuffer {
+ icmpSize := header.ICMPv6HeaderSize + header.NDPRAMinimumSize
+ hdr := buffer.NewPrependable(header.IPv6MinimumSize + icmpSize)
+ pkt := header.ICMPv6(hdr.Prepend(icmpSize))
+ pkt.SetType(header.ICMPv6RouterAdvert)
+ pkt.SetCode(0)
+ // Populate the Router Lifetime.
+ binary.BigEndian.PutUint16(pkt.NDPPayload()[2:], rl)
+ pkt.SetChecksum(header.ICMPv6Checksum(pkt, ip, header.IPv6AllNodesMulticastAddress, buffer.VectorisedView{}))
+ payloadLength := hdr.UsedLength()
+ iph := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
+ iph.Encode(&header.IPv6Fields{
+ PayloadLength: uint16(payloadLength),
+ NextHeader: uint8(icmp.ProtocolNumber6),
+ HopLimit: header.NDPHopLimit,
+ SrcAddr: ip,
+ DstAddr: header.IPv6AllNodesMulticastAddress,
+ })
+
+ return tcpip.PacketBuffer{Data: hdr.View().ToVectorisedView()}
+}
+
+// TestNoRouterDiscovery tests that router discovery will not be performed if
+// configured not to.
+func TestNoRouterDiscovery(t *testing.T) {
+ // Being configured to discover routers means handle and
+ // discover are set to true and forwarding is set to false.
+ // This tests all possible combinations of the configurations,
+ // except for the configuration where handle = true, discover =
+ // true and forwarding = false (the required configuration to do
+ // router discovery) - that will done in other tests.
+ for i := 0; i < 7; i++ {
+ handle := i&1 != 0
+ discover := i&2 != 0
+ forwarding := i&4 == 0
+
+ t.Run(fmt.Sprintf("HandleRAs(%t), DiscoverDefaultRouters(%t), Forwarding(%t)", handle, discover, forwarding), func(t *testing.T) {
+ ndpDisp := ndpDispatcher{
+ routerC: make(chan ndpRouterEvent, 10),
+ }
+ e := channel.New(10, 1280, linkAddr1)
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
+ NDPConfigs: stack.NDPConfigurations{
+ HandleRAs: handle,
+ DiscoverDefaultRouters: discover,
+ },
+ NDPDisp: &ndpDisp,
+ })
+ s.SetForwarding(forwarding)
+
+ if err := s.CreateNIC(1, e); err != nil {
+ t.Fatalf("CreateNIC(1) = %s", err)
+ }
+
+ // Rx an RA with non-zero lifetime.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr2, 1000))
+ select {
+ case <-ndpDisp.routerC:
+ t.Fatal("unexpectedly discovered a router when configured not to")
+ case <-time.After(defaultTimeout):
+ }
+ })
+ }
+}
+
+// TestRouterDiscoveryDispatcherNoRemember tests that the stack does not
+// remember a discovered router when the dispatcher asks it not to.
+func TestRouterDiscoveryDispatcherNoRemember(t *testing.T) {
+ ndpDisp := ndpDispatcher{
+ routerC: make(chan ndpRouterEvent, 10),
+ }
+ e := channel.New(10, 1280, linkAddr1)
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
+ NDPConfigs: stack.NDPConfigurations{
+ HandleRAs: true,
+ DiscoverDefaultRouters: true,
+ },
+ NDPDisp: &ndpDisp,
+ })
+
+ if err := s.CreateNIC(1, e); err != nil {
+ t.Fatalf("CreateNIC(1) = %s", err)
+ }
+
+ routeTable := []tcpip.Route{
+ {
+ header.IPv6EmptySubnet,
+ llAddr3,
+ 1,
+ },
+ }
+ s.SetRouteTable(routeTable)
+
+ // Rx an RA with short lifetime.
+ lifetime := time.Duration(1)
+ e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr2, uint16(lifetime)))
+ select {
+ case r := <-ndpDisp.routerC:
+ if r.nicID != 1 {
+ t.Fatalf("got r.nicID = %d, want = 1", r.nicID)
+ }
+ if r.addr != llAddr2 {
+ t.Fatalf("got r.addr = %s, want = %s", r.addr, llAddr2)
+ }
+ if !r.discovered {
+ t.Fatal("got r.discovered = false, want = true")
+ }
+ case <-time.After(defaultTimeout):
+ t.Fatal("timeout waiting for router discovery event")
+ }
+
+ // Original route table should not have been modified.
+ if got := s.GetRouteTable(); !cmp.Equal(got, routeTable) {
+ t.Fatalf("got GetRouteTable = %v, want = %v", got, routeTable)
+ }
+
+ // Wait for the normal invalidation time plus an extra second to
+ // make sure we do not actually receive any invalidation events as
+ // we should not have remembered the router in the first place.
+ select {
+ case <-ndpDisp.routerC:
+ t.Fatal("should not have received any router events")
+ case <-time.After(lifetime*time.Second + defaultTimeout):
+ }
+
+ // Original route table should not have been modified.
+ if got := s.GetRouteTable(); !cmp.Equal(got, routeTable) {
+ t.Fatalf("got GetRouteTable = %v, want = %v", got, routeTable)
+ }
+}
+
+func TestRouterDiscovery(t *testing.T) {
+ ndpDisp := ndpDispatcher{
+ routerC: make(chan ndpRouterEvent, 10),
+ rememberRouter: true,
+ }
+ e := channel.New(10, 1280, linkAddr1)
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
+ NDPConfigs: stack.NDPConfigurations{
+ HandleRAs: true,
+ DiscoverDefaultRouters: true,
+ },
+ NDPDisp: &ndpDisp,
+ })
+
+ waitForEvent := func(addr tcpip.Address, discovered bool, timeout time.Duration) {
+ t.Helper()
+
+ select {
+ case r := <-ndpDisp.routerC:
+ if r.nicID != 1 {
+ t.Fatalf("got r.nicID = %d, want = 1", r.nicID)
+ }
+ if r.addr != addr {
+ t.Fatalf("got r.addr = %s, want = %s", r.addr, addr)
+ }
+ if r.discovered != discovered {
+ t.Fatalf("got r.discovered = %t, want = %t", r.discovered, discovered)
+ }
+ case <-time.After(timeout):
+ t.Fatal("timeout waiting for router discovery event")
+ }
+ }
+
+ if err := s.CreateNIC(1, e); err != nil {
+ t.Fatalf("CreateNIC(1) = %s", err)
+ }
+
+ // Rx an RA from lladdr2 with zero lifetime. It should not be
+ // remembered.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr2, 0))
+ select {
+ case <-ndpDisp.routerC:
+ t.Fatal("unexpectedly discovered a router with 0 lifetime")
+ case <-time.After(defaultTimeout):
+ }
+
+ // Rx an RA from lladdr2 with a huge lifetime.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr2, 1000))
+ waitForEvent(llAddr2, true, defaultTimeout)
+
+ // Should have a default route through the discovered router.
+ if got, want := s.GetRouteTable(), []tcpip.Route{{header.IPv6EmptySubnet, llAddr2, 1}}; !cmp.Equal(got, want) {
+ t.Fatalf("got GetRouteTable = %v, want = %v", got, want)
+ }
+
+ // Rx an RA from another router (lladdr3) with non-zero lifetime.
+ l3Lifetime := time.Duration(6)
+ e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr3, uint16(l3Lifetime)))
+ waitForEvent(llAddr3, true, defaultTimeout)
+
+ // Should have default routes through the discovered routers.
+ if got, want := s.GetRouteTable(), []tcpip.Route{{header.IPv6EmptySubnet, llAddr2, 1}, {header.IPv6EmptySubnet, llAddr3, 1}}; !cmp.Equal(got, want) {
+ t.Fatalf("got GetRouteTable = %v, want = %v", got, want)
+ }
+
+ // Rx an RA from lladdr2 with lesser lifetime.
+ l2Lifetime := time.Duration(2)
+ e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr2, uint16(l2Lifetime)))
+ select {
+ case <-ndpDisp.routerC:
+ t.Fatal("Should not receive a router event when updating lifetimes for known routers")
+ case <-time.After(defaultTimeout):
+ }
+
+ // Should still have a default route through the discovered routers.
+ if got, want := s.GetRouteTable(), []tcpip.Route{{header.IPv6EmptySubnet, llAddr2, 1}, {header.IPv6EmptySubnet, llAddr3, 1}}; !cmp.Equal(got, want) {
+ t.Fatalf("got GetRouteTable = %v, want = %v", got, want)
+ }
+
+ // Wait for lladdr2's router invalidation timer to fire. The lifetime
+ // of the router should have been updated to the most recent (smaller)
+ // lifetime.
+ //
+ // Wait for the normal lifetime plus an extra bit for the
+ // router to get invalidated. If we don't get an invalidation
+ // event after this time, then something is wrong.
+ waitForEvent(llAddr2, false, l2Lifetime*time.Second+defaultTimeout)
+
+ // Should no longer have the default route through lladdr2.
+ if got, want := s.GetRouteTable(), []tcpip.Route{{header.IPv6EmptySubnet, llAddr3, 1}}; !cmp.Equal(got, want) {
+ t.Fatalf("got GetRouteTable = %v, want = %v", got, want)
+ }
+
+ // Rx an RA from lladdr2 with huge lifetime.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr2, 1000))
+ waitForEvent(llAddr2, true, defaultTimeout)
+
+ // Should have a default route through the discovered routers.
+ if got, want := s.GetRouteTable(), []tcpip.Route{{header.IPv6EmptySubnet, llAddr3, 1}, {header.IPv6EmptySubnet, llAddr2, 1}}; !cmp.Equal(got, want) {
+ t.Fatalf("got GetRouteTable = %v, want = %v", got, want)
+ }
+
+ // Rx an RA from lladdr2 with zero lifetime. It should be invalidated.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr2, 0))
+ waitForEvent(llAddr2, false, defaultTimeout)
+
+ // Should have deleted the default route through the router that just
+ // got invalidated.
+ if got, want := s.GetRouteTable(), []tcpip.Route{{header.IPv6EmptySubnet, llAddr3, 1}}; !cmp.Equal(got, want) {
+ t.Fatalf("got GetRouteTable = %v, want = %v", got, want)
+ }
+
+ // Wait for lladdr3's router invalidation timer to fire. The lifetime
+ // of the router should have been updated to the most recent (smaller)
+ // lifetime.
+ //
+ // Wait for the normal lifetime plus an extra bit for the
+ // router to get invalidated. If we don't get an invalidation
+ // event after this time, then something is wrong.
+ waitForEvent(llAddr3, false, l3Lifetime*time.Second+defaultTimeout)
+
+ // Should not have any routes now that all discovered routers have been
+ // invalidated.
+ if got := len(s.GetRouteTable()); got != 0 {
+ t.Fatalf("got len(s.GetRouteTable()) = %d, want = 0", got)
+ }
+}
+
+// TestRouterDiscoveryMaxRouters tests that only
+// stack.MaxDiscoveredDefaultRouters discovered routers are remembered.
+func TestRouterDiscoveryMaxRouters(t *testing.T) {
+ ndpDisp := ndpDispatcher{
+ routerC: make(chan ndpRouterEvent, 10),
+ rememberRouter: true,
+ }
+ e := channel.New(10, 1280, linkAddr1)
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
+ NDPConfigs: stack.NDPConfigurations{
+ HandleRAs: true,
+ DiscoverDefaultRouters: true,
+ },
+ NDPDisp: &ndpDisp,
+ })
+
+ if err := s.CreateNIC(1, e); err != nil {
+ t.Fatalf("CreateNIC(1) = %s", err)
+ }
+
+ expectedRt := [stack.MaxDiscoveredDefaultRouters]tcpip.Route{}
+
+ // Receive an RA from 2 more than the max number of discovered routers.
+ for i := 1; i <= stack.MaxDiscoveredDefaultRouters+2; i++ {
+ linkAddr := []byte{2, 2, 3, 4, 5, 0}
+ linkAddr[5] = byte(i)
+ llAddr := header.LinkLocalAddr(tcpip.LinkAddress(linkAddr))
+
+ e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr, 5))
+
+ if i <= stack.MaxDiscoveredDefaultRouters {
+ expectedRt[i-1] = tcpip.Route{header.IPv6EmptySubnet, llAddr, 1}
+ select {
+ case r := <-ndpDisp.routerC:
+ if r.nicID != 1 {
+ t.Fatalf("got r.nicID = %d, want = 1", r.nicID)
+ }
+ if r.addr != llAddr {
+ t.Fatalf("got r.addr = %s, want = %s", r.addr, llAddr)
+ }
+ if !r.discovered {
+ t.Fatal("got r.discovered = false, want = true")
+ }
+ case <-time.After(defaultTimeout):
+ t.Fatal("timeout waiting for router discovery event")
+ }
+
+ } else {
+ select {
+ case <-ndpDisp.routerC:
+ t.Fatal("should not have discovered a new router after we already discovered the max number of routers")
+ case <-time.After(defaultTimeout):
+ }
+ }
+ }
+
+ // Should only have default routes for the first
+ // stack.MaxDiscoveredDefaultRouters discovered routers.
+ if got := s.GetRouteTable(); !cmp.Equal(got, expectedRt[:]) {
+ t.Fatalf("got GetRouteTable = %v, want = %v", got, expectedRt)
+ }
+}
diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go
index 0e8a23f00..28a28ae6e 100644
--- a/pkg/tcpip/stack/nic.go
+++ b/pkg/tcpip/stack/nic.go
@@ -19,7 +19,6 @@ import (
"sync"
"sync/atomic"
- "gvisor.dev/gvisor/pkg/ilist"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
@@ -34,17 +33,24 @@ type NIC struct {
linkEP LinkEndpoint
loopback bool
- demux *transportDemuxer
-
mu sync.RWMutex
spoofing bool
promiscuous bool
- primary map[tcpip.NetworkProtocolNumber]*ilist.List
+ primary map[tcpip.NetworkProtocolNumber][]*referencedNetworkEndpoint
endpoints map[NetworkEndpointID]*referencedNetworkEndpoint
addressRanges []tcpip.Subnet
mcastJoins map[NetworkEndpointID]int32
+ // packetEPs is protected by mu, but the contained PacketEndpoint
+ // values are not.
+ packetEPs map[tcpip.NetworkProtocolNumber][]PacketEndpoint
stats NICStats
+
+ // ndp is the NDP related state for NIC.
+ //
+ // Note, read and write operations on ndp require that the NIC is
+ // appropriately locked.
+ ndp ndpState
}
// NICStats includes transmitted and received stats.
@@ -78,17 +84,26 @@ const (
NeverPrimaryEndpoint
)
+// newNIC returns a new NIC using the default NDP configurations from stack.
func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint, loopback bool) *NIC {
- return &NIC{
+ // TODO(b/141011931): Validate a LinkEndpoint (ep) is valid. For
+ // example, make sure that the link address it provides is a valid
+ // unicast ethernet address.
+
+ // TODO(b/143357959): RFC 8200 section 5 requires that IPv6 endpoints
+ // observe an MTU of at least 1280 bytes. Ensure that this requirement
+ // of IPv6 is supported on this endpoint's LinkEndpoint.
+
+ nic := &NIC{
stack: stack,
id: id,
name: name,
linkEP: ep,
loopback: loopback,
- demux: newTransportDemuxer(stack),
- primary: make(map[tcpip.NetworkProtocolNumber]*ilist.List),
+ primary: make(map[tcpip.NetworkProtocolNumber][]*referencedNetworkEndpoint),
endpoints: make(map[NetworkEndpointID]*referencedNetworkEndpoint),
mcastJoins: make(map[NetworkEndpointID]int32),
+ packetEPs: make(map[tcpip.NetworkProtocolNumber][]PacketEndpoint),
stats: NICStats{
Tx: DirectionStats{
Packets: &tcpip.StatCounter{},
@@ -99,7 +114,23 @@ func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint, loopback
Bytes: &tcpip.StatCounter{},
},
},
+ ndp: ndpState{
+ configs: stack.ndpConfigs,
+ dad: make(map[tcpip.Address]dadState),
+ defaultRouters: make(map[tcpip.Address]defaultRouterState),
+ },
+ }
+ nic.ndp.nic = nic
+
+ // Register supported packet endpoint protocols.
+ for _, netProto := range header.Ethertypes {
+ nic.packetEPs[netProto] = []PacketEndpoint{}
}
+ for _, netProto := range stack.networkProtocols {
+ nic.packetEPs[netProto.Number()] = []PacketEndpoint{}
+ }
+
+ return nic
}
// enable enables the NIC. enable will attach the link to its LinkEndpoint and
@@ -107,6 +138,16 @@ func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint, loopback
func (n *NIC) enable() *tcpip.Error {
n.attachLinkEndpoint()
+ // Create an endpoint to receive broadcast packets on this interface.
+ if _, ok := n.stack.networkProtocols[header.IPv4ProtocolNumber]; ok {
+ if err := n.AddAddress(tcpip.ProtocolAddress{
+ Protocol: header.IPv4ProtocolNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{header.IPv4Broadcast, 8 * header.IPv4AddressSize},
+ }, NeverPrimaryEndpoint); err != nil {
+ return err
+ }
+ }
+
// Join the IPv6 All-Nodes Multicast group if the stack is configured to
// use IPv6. This is required to ensure that this node properly receives
// and responds to the various NDP messages that are destined to the
@@ -114,11 +155,50 @@ func (n *NIC) enable() *tcpip.Error {
// when we perform Duplicate Address Detection, or Router Advertisement
// when we do Router Discovery. See RFC 4862, section 5.4.2 and RFC 4861
// section 4.2 for more information.
- if _, ok := n.stack.networkProtocols[header.IPv6ProtocolNumber]; ok {
- return n.joinGroup(header.IPv6ProtocolNumber, header.IPv6AllNodesMulticastAddress)
+ //
+ // Also auto-generate an IPv6 link-local address based on the NIC's
+ // link address if it is configured to do so. Note, each interface is
+ // required to have IPv6 link-local unicast address, as per RFC 4291
+ // section 2.1.
+ _, ok := n.stack.networkProtocols[header.IPv6ProtocolNumber]
+ if !ok {
+ return nil
}
- return nil
+ n.mu.Lock()
+ defer n.mu.Unlock()
+
+ if err := n.joinGroupLocked(header.IPv6ProtocolNumber, header.IPv6AllNodesMulticastAddress); err != nil {
+ return err
+ }
+
+ if !n.stack.autoGenIPv6LinkLocal {
+ return nil
+ }
+
+ l2addr := n.linkEP.LinkAddress()
+
+ // Only attempt to generate the link-local address if we have a
+ // valid MAC address.
+ //
+ // TODO(b/141011931): Validate a LinkEndpoint's link address
+ // (provided by LinkEndpoint.LinkAddress) before reaching this
+ // point.
+ if !header.IsValidUnicastEthernetAddress(l2addr) {
+ return nil
+ }
+
+ addr := header.LinkLocalAddr(l2addr)
+
+ _, err := n.addPermanentAddressLocked(tcpip.ProtocolAddress{
+ Protocol: header.IPv6ProtocolNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: addr,
+ PrefixLen: header.IPv6LinkLocalPrefix.PrefixLen,
+ },
+ }, CanBePrimaryEndpoint)
+
+ return err
}
// attachLinkEndpoint attaches the NIC to the endpoint, which will enable it
@@ -154,18 +234,7 @@ func (n *NIC) primaryEndpoint(protocol tcpip.NetworkProtocolNumber) *referencedN
n.mu.RLock()
defer n.mu.RUnlock()
- list := n.primary[protocol]
- if list == nil {
- return nil
- }
-
- for e := list.Front(); e != nil; e = e.Next() {
- r := e.(*referencedNetworkEndpoint)
- // TODO(crawshaw): allow broadcast address when SO_BROADCAST is set.
- switch r.ep.ID().LocalAddress {
- case header.IPv4Broadcast, header.IPv4Any:
- continue
- }
+ for _, r := range n.primary[protocol] {
if r.isValidForOutgoing() && r.tryIncRef() {
return r
}
@@ -275,13 +344,35 @@ func (n *NIC) addPermanentAddressLocked(protocolAddress tcpip.ProtocolAddress, p
id := NetworkEndpointID{protocolAddress.AddressWithPrefix.Address}
if ref, ok := n.endpoints[id]; ok {
switch ref.getKind() {
- case permanent:
+ case permanentTentative, permanent:
// The NIC already have a permanent endpoint with that address.
return nil, tcpip.ErrDuplicateAddress
case permanentExpired, temporary:
- // Promote the endpoint to become permanent.
+ // Promote the endpoint to become permanent and respect
+ // the new peb.
if ref.tryIncRef() {
ref.setKind(permanent)
+
+ refs := n.primary[ref.protocol]
+ for i, r := range refs {
+ if r == ref {
+ switch peb {
+ case CanBePrimaryEndpoint:
+ return ref, nil
+ case FirstPrimaryEndpoint:
+ if i == 0 {
+ return ref, nil
+ }
+ n.primary[r.protocol] = append(refs[:i], refs[i+1:]...)
+ case NeverPrimaryEndpoint:
+ n.primary[r.protocol] = append(refs[:i], refs[i+1:]...)
+ return ref, nil
+ }
+ }
+ }
+
+ n.insertPrimaryEndpointLocked(ref, peb)
+
return ref, nil
}
// tryIncRef failing means the endpoint is scheduled to be removed once
@@ -291,6 +382,7 @@ func (n *NIC) addPermanentAddressLocked(protocolAddress tcpip.ProtocolAddress, p
n.removeEndpointLocked(ref)
}
}
+
return n.addAddressLocked(protocolAddress, peb, permanent)
}
@@ -314,6 +406,15 @@ func (n *NIC) addAddressLocked(protocolAddress tcpip.ProtocolAddress, peb Primar
if err != nil {
return nil, err
}
+
+ isIPv6Unicast := protocolAddress.Protocol == header.IPv6ProtocolNumber && header.IsV6UnicastAddress(protocolAddress.AddressWithPrefix.Address)
+
+ // If the address is an IPv6 address and it is a permanent address,
+ // mark it as tentative so it goes through the DAD process.
+ if isIPv6Unicast && kind == permanent {
+ kind = permanentTentative
+ }
+
ref := &referencedNetworkEndpoint{
refs: 1,
ep: ep,
@@ -331,7 +432,7 @@ func (n *NIC) addAddressLocked(protocolAddress tcpip.ProtocolAddress, peb Primar
// If we are adding an IPv6 unicast address, join the solicited-node
// multicast address.
- if protocolAddress.Protocol == header.IPv6ProtocolNumber && header.IsV6UnicastAddress(protocolAddress.AddressWithPrefix.Address) {
+ if isIPv6Unicast {
snmc := header.SolicitedNodeAddr(protocolAddress.AddressWithPrefix.Address)
if err := n.joinGroupLocked(protocolAddress.Protocol, snmc); err != nil {
return nil, err
@@ -340,17 +441,13 @@ func (n *NIC) addAddressLocked(protocolAddress tcpip.ProtocolAddress, peb Primar
n.endpoints[id] = ref
- l, ok := n.primary[protocolAddress.Protocol]
- if !ok {
- l = &ilist.List{}
- n.primary[protocolAddress.Protocol] = l
- }
+ n.insertPrimaryEndpointLocked(ref, peb)
- switch peb {
- case CanBePrimaryEndpoint:
- l.PushBack(ref)
- case FirstPrimaryEndpoint:
- l.PushFront(ref)
+ // If we are adding a tentative IPv6 address, start DAD.
+ if isIPv6Unicast && kind == permanentTentative {
+ if err := n.ndp.startDuplicateAddressDetection(protocolAddress.AddressWithPrefix.Address, ref); err != nil {
+ return nil, err
+ }
}
return ref, nil
@@ -375,10 +472,12 @@ func (n *NIC) AllAddresses() []tcpip.ProtocolAddress {
addrs := make([]tcpip.ProtocolAddress, 0, len(n.endpoints))
for nid, ref := range n.endpoints {
- // Don't include expired or tempory endpoints to avoid confusion and
- // prevent the caller from using those.
+ // Don't include tentative, expired or temporary endpoints to
+ // avoid confusion and prevent the caller from using those.
switch ref.getKind() {
- case permanentExpired, temporary:
+ case permanentTentative, permanentExpired, temporary:
+ // TODO(b/140898488): Should tentative addresses be
+ // returned?
continue
}
addrs = append(addrs, tcpip.ProtocolAddress{
@@ -399,12 +498,12 @@ func (n *NIC) PrimaryAddresses() []tcpip.ProtocolAddress {
var addrs []tcpip.ProtocolAddress
for proto, list := range n.primary {
- for e := list.Front(); e != nil; e = e.Next() {
- ref := e.(*referencedNetworkEndpoint)
- // Don't include expired or tempory endpoints to avoid confusion and
- // prevent the caller from using those.
+ for _, ref := range list {
+ // Don't include tentative, expired or tempory endpoints
+ // to avoid confusion and prevent the caller from using
+ // those.
switch ref.getKind() {
- case permanentExpired, temporary:
+ case permanentTentative, permanentExpired, temporary:
continue
}
@@ -464,6 +563,19 @@ func (n *NIC) AddressRanges() []tcpip.Subnet {
return append(sns, n.addressRanges...)
}
+// insertPrimaryEndpointLocked adds r to n's primary endpoint list as required
+// by peb.
+//
+// n MUST be locked.
+func (n *NIC) insertPrimaryEndpointLocked(r *referencedNetworkEndpoint, peb PrimaryEndpointBehavior) {
+ switch peb {
+ case CanBePrimaryEndpoint:
+ n.primary[r.protocol] = append(n.primary[r.protocol], r)
+ case FirstPrimaryEndpoint:
+ n.primary[r.protocol] = append([]*referencedNetworkEndpoint{r}, n.primary[r.protocol]...)
+ }
+}
+
func (n *NIC) removeEndpointLocked(r *referencedNetworkEndpoint) {
id := *r.ep.ID()
@@ -481,9 +593,12 @@ func (n *NIC) removeEndpointLocked(r *referencedNetworkEndpoint) {
}
delete(n.endpoints, id)
- wasInList := r.Next() != nil || r.Prev() != nil || r == n.primary[r.protocol].Front()
- if wasInList {
- n.primary[r.protocol].Remove(r)
+ refs := n.primary[r.protocol]
+ for i, ref := range refs {
+ if ref == r {
+ n.primary[r.protocol] = append(refs[:i], refs[i+1:]...)
+ break
+ }
}
r.ep.Close()
@@ -497,10 +612,22 @@ func (n *NIC) removeEndpoint(r *referencedNetworkEndpoint) {
func (n *NIC) removePermanentAddressLocked(addr tcpip.Address) *tcpip.Error {
r, ok := n.endpoints[NetworkEndpointID{addr}]
- if !ok || r.getKind() != permanent {
+ if !ok {
+ return tcpip.ErrBadLocalAddress
+ }
+
+ kind := r.getKind()
+ if kind != permanent && kind != permanentTentative {
return tcpip.ErrBadLocalAddress
}
+ isIPv6Unicast := r.protocol == header.IPv6ProtocolNumber && header.IsV6UnicastAddress(addr)
+
+ // If we are removing a tentative IPv6 unicast address, stop DAD.
+ if isIPv6Unicast && kind == permanentTentative {
+ n.ndp.stopDuplicateAddressDetection(addr)
+ }
+
r.setKind(permanentExpired)
if !r.decRefLocked() {
// The endpoint still has references to it.
@@ -511,7 +638,7 @@ func (n *NIC) removePermanentAddressLocked(addr tcpip.Address) *tcpip.Error {
// If we are removing an IPv6 unicast address, leave the solicited-node
// multicast address.
- if r.protocol == header.IPv6ProtocolNumber && header.IsV6UnicastAddress(addr) {
+ if isIPv6Unicast {
snmc := header.SolicitedNodeAddr(addr)
if err := n.leaveGroupLocked(snmc); err != nil {
return err
@@ -541,6 +668,11 @@ func (n *NIC) joinGroup(protocol tcpip.NetworkProtocolNumber, addr tcpip.Address
// exists yet. Otherwise it just increments its count. n MUST be locked before
// joinGroupLocked is called.
func (n *NIC) joinGroupLocked(protocol tcpip.NetworkProtocolNumber, addr tcpip.Address) *tcpip.Error {
+ // TODO(b/143102137): When implementing MLD, make sure MLD packets are
+ // not sent unless a valid link-local address is available for use on n
+ // as an MLD packet's source address must be a link-local address as
+ // outlined in RFC 3810 section 5.
+
id := NetworkEndpointID{addr}
joins := n.mcastJoins[id]
if joins == 0 {
@@ -591,10 +723,10 @@ func (n *NIC) leaveGroupLocked(addr tcpip.Address) *tcpip.Error {
return nil
}
-func handlePacket(protocol tcpip.NetworkProtocolNumber, dst, src tcpip.Address, localLinkAddr, remotelinkAddr tcpip.LinkAddress, ref *referencedNetworkEndpoint, vv buffer.VectorisedView) {
+func handlePacket(protocol tcpip.NetworkProtocolNumber, dst, src tcpip.Address, localLinkAddr, remotelinkAddr tcpip.LinkAddress, ref *referencedNetworkEndpoint, pkt tcpip.PacketBuffer) {
r := makeRoute(protocol, dst, src, localLinkAddr, ref, false /* handleLocal */, false /* multicastLoop */)
r.RemoteLinkAddress = remotelinkAddr
- ref.ep.HandlePacket(&r, vv)
+ ref.ep.HandlePacket(&r, pkt)
ref.decRef()
}
@@ -604,9 +736,9 @@ func handlePacket(protocol tcpip.NetworkProtocolNumber, dst, src tcpip.Address,
// Note that the ownership of the slice backing vv is retained by the caller.
// This rule applies only to the slice itself, not to the items of the slice;
// the ownership of the items is not retained by the caller.
-func (n *NIC) DeliverNetworkPacket(linkEP LinkEndpoint, remote, _ tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, vv buffer.VectorisedView) {
+func (n *NIC) DeliverNetworkPacket(linkEP LinkEndpoint, remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) {
n.stats.Rx.Packets.Increment()
- n.stats.Rx.Bytes.IncrementBy(uint64(vv.Size()))
+ n.stats.Rx.Bytes.IncrementBy(uint64(pkt.Data.Size()))
netProto, ok := n.stack.networkProtocols[protocol]
if !ok {
@@ -614,36 +746,39 @@ func (n *NIC) DeliverNetworkPacket(linkEP LinkEndpoint, remote, _ tcpip.LinkAddr
return
}
+ // If no local link layer address is provided, assume it was sent
+ // directly to this NIC.
+ if local == "" {
+ local = n.linkEP.LinkAddress()
+ }
+
+ // Are any packet sockets listening for this network protocol?
+ n.mu.RLock()
+ packetEPs := n.packetEPs[protocol]
+ // Check whether there are packet sockets listening for every protocol.
+ // If we received a packet with protocol EthernetProtocolAll, then the
+ // previous for loop will have handled it.
+ if protocol != header.EthernetProtocolAll {
+ packetEPs = append(packetEPs, n.packetEPs[header.EthernetProtocolAll]...)
+ }
+ n.mu.RUnlock()
+ for _, ep := range packetEPs {
+ ep.HandlePacket(n.id, local, protocol, pkt.Clone())
+ }
+
if netProto.Number() == header.IPv4ProtocolNumber || netProto.Number() == header.IPv6ProtocolNumber {
n.stack.stats.IP.PacketsReceived.Increment()
}
- if len(vv.First()) < netProto.MinimumPacketSize() {
+ if len(pkt.Data.First()) < netProto.MinimumPacketSize() {
n.stack.stats.MalformedRcvdPackets.Increment()
return
}
- src, dst := netProto.ParseAddresses(vv.First())
-
- n.stack.AddLinkAddress(n.id, src, remote)
-
- // If the packet is destined to the IPv4 Broadcast address, then make a
- // route to each IPv4 network endpoint and let each endpoint handle the
- // packet.
- if dst == header.IPv4Broadcast {
- // n.endpoints is mutex protected so acquire lock.
- n.mu.RLock()
- for _, ref := range n.endpoints {
- if ref.isValidForIncoming() && ref.protocol == header.IPv4ProtocolNumber && ref.tryIncRef() {
- handlePacket(protocol, dst, src, linkEP.LinkAddress(), remote, ref, vv)
- }
- }
- n.mu.RUnlock()
- return
- }
+ src, dst := netProto.ParseAddresses(pkt.Data.First())
if ref := n.getRef(protocol, dst); ref != nil {
- handlePacket(protocol, dst, src, linkEP.LinkAddress(), remote, ref, vv)
+ handlePacket(protocol, dst, src, linkEP.LinkAddress(), remote, ref, pkt)
return
}
@@ -671,31 +806,34 @@ func (n *NIC) DeliverNetworkPacket(linkEP LinkEndpoint, remote, _ tcpip.LinkAddr
if ok {
r.RemoteAddress = src
// TODO(b/123449044): Update the source NIC as well.
- ref.ep.HandlePacket(&r, vv)
+ ref.ep.HandlePacket(&r, pkt)
ref.decRef()
} else {
// n doesn't have a destination endpoint.
// Send the packet out of n.
- hdr := buffer.NewPrependableFromView(vv.First())
- vv.RemoveFirst()
+ hdr := buffer.NewPrependableFromView(pkt.Data.First())
+ pkt.Data.RemoveFirst()
// TODO(b/128629022): use route.WritePacket.
- if err := n.linkEP.WritePacket(&r, nil /* gso */, hdr, vv, protocol); err != nil {
+ if err := n.linkEP.WritePacket(&r, nil /* gso */, hdr, pkt.Data, protocol); err != nil {
r.Stats().IP.OutgoingPacketErrors.Increment()
} else {
n.stats.Tx.Packets.Increment()
- n.stats.Tx.Bytes.IncrementBy(uint64(hdr.UsedLength() + vv.Size()))
+ n.stats.Tx.Bytes.IncrementBy(uint64(hdr.UsedLength() + pkt.Data.Size()))
}
}
return
}
- n.stack.stats.IP.InvalidAddressesReceived.Increment()
+ // If a packet socket handled the packet, don't treat it as invalid.
+ if len(packetEPs) == 0 {
+ n.stack.stats.IP.InvalidAddressesReceived.Increment()
+ }
}
// DeliverTransportPacket delivers the packets to the appropriate transport
// protocol endpoint.
-func (n *NIC) DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolNumber, netHeader buffer.View, vv buffer.VectorisedView) {
+func (n *NIC) DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolNumber, pkt tcpip.PacketBuffer) {
state, ok := n.stack.transportProtocols[protocol]
if !ok {
n.stack.stats.UnknownProtocolRcvdPackets.Increment()
@@ -707,46 +845,41 @@ func (n *NIC) DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolN
// Raw socket packets are delivered based solely on the transport
// protocol number. We do not inspect the payload to ensure it's
// validly formed.
- if !n.demux.deliverRawPacket(r, protocol, netHeader, vv) {
- n.stack.demux.deliverRawPacket(r, protocol, netHeader, vv)
- }
+ n.stack.demux.deliverRawPacket(r, protocol, pkt)
- if len(vv.First()) < transProto.MinimumPacketSize() {
+ if len(pkt.Data.First()) < transProto.MinimumPacketSize() {
n.stack.stats.MalformedRcvdPackets.Increment()
return
}
- srcPort, dstPort, err := transProto.ParsePorts(vv.First())
+ srcPort, dstPort, err := transProto.ParsePorts(pkt.Data.First())
if err != nil {
n.stack.stats.MalformedRcvdPackets.Increment()
return
}
id := TransportEndpointID{dstPort, r.LocalAddress, srcPort, r.RemoteAddress}
- if n.demux.deliverPacket(r, protocol, netHeader, vv, id) {
- return
- }
- if n.stack.demux.deliverPacket(r, protocol, netHeader, vv, id) {
+ if n.stack.demux.deliverPacket(r, protocol, pkt, id) {
return
}
// Try to deliver to per-stack default handler.
if state.defaultHandler != nil {
- if state.defaultHandler(r, id, netHeader, vv) {
+ if state.defaultHandler(r, id, pkt) {
return
}
}
// We could not find an appropriate destination for this packet, so
// deliver it to the global handler.
- if !transProto.HandleUnknownDestinationPacket(r, id, netHeader, vv) {
+ if !transProto.HandleUnknownDestinationPacket(r, id, pkt) {
n.stack.stats.MalformedRcvdPackets.Increment()
}
}
// DeliverTransportControlPacket delivers control packets to the appropriate
// transport protocol endpoint.
-func (n *NIC) DeliverTransportControlPacket(local, remote tcpip.Address, net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, typ ControlType, extra uint32, vv buffer.VectorisedView) {
+func (n *NIC) DeliverTransportControlPacket(local, remote tcpip.Address, net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, typ ControlType, extra uint32, pkt tcpip.PacketBuffer) {
state, ok := n.stack.transportProtocols[trans]
if !ok {
return
@@ -757,20 +890,17 @@ func (n *NIC) DeliverTransportControlPacket(local, remote tcpip.Address, net tcp
// ICMPv4 only guarantees that 8 bytes of the transport protocol will
// be present in the payload. We know that the ports are within the
// first 8 bytes for all known transport protocols.
- if len(vv.First()) < 8 {
+ if len(pkt.Data.First()) < 8 {
return
}
- srcPort, dstPort, err := transProto.ParsePorts(vv.First())
+ srcPort, dstPort, err := transProto.ParsePorts(pkt.Data.First())
if err != nil {
return
}
id := TransportEndpointID{srcPort, local, dstPort, remote}
- if n.demux.deliverControlPacket(net, trans, typ, extra, vv, id) {
- return
- }
- if n.stack.demux.deliverControlPacket(net, trans, typ, extra, vv, id) {
+ if n.stack.demux.deliverControlPacket(n, net, trans, typ, extra, pkt, id) {
return
}
}
@@ -785,14 +915,78 @@ func (n *NIC) Stack() *Stack {
return n.stack
}
+// isAddrTentative returns true if addr is tentative on n.
+//
+// Note that if addr is not associated with n, then this function will return
+// false. It will only return true if the address is associated with the NIC
+// AND it is tentative.
+func (n *NIC) isAddrTentative(addr tcpip.Address) bool {
+ ref, ok := n.endpoints[NetworkEndpointID{addr}]
+ if !ok {
+ return false
+ }
+
+ return ref.getKind() == permanentTentative
+}
+
+// dupTentativeAddrDetected attempts to inform n that a tentative addr
+// is a duplicate on a link.
+//
+// dupTentativeAddrDetected will delete the tentative address if it exists.
+func (n *NIC) dupTentativeAddrDetected(addr tcpip.Address) *tcpip.Error {
+ n.mu.Lock()
+ defer n.mu.Unlock()
+
+ ref, ok := n.endpoints[NetworkEndpointID{addr}]
+ if !ok {
+ return tcpip.ErrBadAddress
+ }
+
+ if ref.getKind() != permanentTentative {
+ return tcpip.ErrInvalidEndpointState
+ }
+
+ return n.removePermanentAddressLocked(addr)
+}
+
+// setNDPConfigs sets the NDP configurations for n.
+//
+// Note, if c contains invalid NDP configuration values, it will be fixed to
+// use default values for the erroneous values.
+func (n *NIC) setNDPConfigs(c NDPConfigurations) {
+ c.validate()
+
+ n.mu.Lock()
+ n.ndp.configs = c
+ n.mu.Unlock()
+}
+
+// handleNDPRA handles an NDP Router Advertisement message that arrived on n.
+func (n *NIC) handleNDPRA(ip tcpip.Address, ra header.NDPRouterAdvert) {
+ n.mu.Lock()
+ defer n.mu.Unlock()
+
+ n.ndp.handleRA(ip, ra)
+}
+
type networkEndpointKind int32
const (
+ // A permanentTentative endpoint is a permanent address that is not yet
+ // considered to be fully bound to an interface in the traditional
+ // sense. That is, the address is associated with a NIC, but packets
+ // destined to the address MUST NOT be accepted and MUST be silently
+ // dropped, and the address MUST NOT be used as a source address for
+ // outgoing packets. For IPv6, addresses will be of this kind until
+ // NDP's Duplicate Address Detection has resolved, or be deleted if
+ // the process results in detecting a duplicate address.
+ permanentTentative networkEndpointKind = iota
+
// A permanent endpoint is created by adding a permanent address (vs. a
// temporary one) to the NIC. Its reference count is biased by 1 to avoid
// removal when no route holds a reference to it. It is removed by explicitly
// removing the permanent address from the NIC.
- permanent networkEndpointKind = iota
+ permanent
// An expired permanent endoint is a permanent endoint that had its address
// removed from the NIC, and it is waiting to be removed once no more routes
@@ -810,8 +1004,37 @@ const (
temporary
)
+func (n *NIC) registerPacketEndpoint(netProto tcpip.NetworkProtocolNumber, ep PacketEndpoint) *tcpip.Error {
+ n.mu.Lock()
+ defer n.mu.Unlock()
+
+ eps, ok := n.packetEPs[netProto]
+ if !ok {
+ return tcpip.ErrNotSupported
+ }
+ n.packetEPs[netProto] = append(eps, ep)
+
+ return nil
+}
+
+func (n *NIC) unregisterPacketEndpoint(netProto tcpip.NetworkProtocolNumber, ep PacketEndpoint) {
+ n.mu.Lock()
+ defer n.mu.Unlock()
+
+ eps, ok := n.packetEPs[netProto]
+ if !ok {
+ return
+ }
+
+ for i, epOther := range eps {
+ if epOther == ep {
+ n.packetEPs[netProto] = append(eps[:i], eps[i+1:]...)
+ return
+ }
+ }
+}
+
type referencedNetworkEndpoint struct {
- ilist.Entry
ep NetworkEndpoint
nic *NIC
protocol tcpip.NetworkProtocolNumber
diff --git a/pkg/tcpip/stack/registration.go b/pkg/tcpip/stack/registration.go
index 80101d4bb..c0026f5a3 100644
--- a/pkg/tcpip/stack/registration.go
+++ b/pkg/tcpip/stack/registration.go
@@ -60,24 +60,64 @@ const (
// TransportEndpoint is the interface that needs to be implemented by transport
// protocol (e.g., tcp, udp) endpoints that can handle packets.
type TransportEndpoint interface {
+ // UniqueID returns an unique ID for this transport endpoint.
+ UniqueID() uint64
+
// HandlePacket is called by the stack when new packets arrive to
- // this transport endpoint.
- HandlePacket(r *Route, id TransportEndpointID, vv buffer.VectorisedView)
+ // this transport endpoint. It sets pkt.TransportHeader.
+ //
+ // HandlePacket takes ownership of pkt.
+ HandlePacket(r *Route, id TransportEndpointID, pkt tcpip.PacketBuffer)
- // HandleControlPacket is called by the stack when new control (e.g.,
+ // HandleControlPacket is called by the stack when new control (e.g.
// ICMP) packets arrive to this transport endpoint.
- HandleControlPacket(id TransportEndpointID, typ ControlType, extra uint32, vv buffer.VectorisedView)
+ // HandleControlPacket takes ownership of pkt.
+ HandleControlPacket(id TransportEndpointID, typ ControlType, extra uint32, pkt tcpip.PacketBuffer)
+
+ // Close puts the endpoint in a closed state and frees all resources
+ // associated with it. This cleanup may happen asynchronously. Wait can
+ // be used to block on this asynchronous cleanup.
+ Close()
+
+ // Wait waits for any worker goroutines owned by the endpoint to stop.
+ //
+ // An endpoint can be requested to stop its worker goroutines by calling
+ // its Close method.
+ //
+ // Wait will not block if the endpoint hasn't started any goroutines
+ // yet, even if it might later.
+ Wait()
}
// RawTransportEndpoint is the interface that needs to be implemented by raw
// transport protocol endpoints. RawTransportEndpoints receive the entire
-// packet - including the link, network, and transport headers - as delivered
-// to netstack.
+// packet - including the network and transport headers - as delivered to
+// netstack.
type RawTransportEndpoint interface {
// HandlePacket is called by the stack when new packets arrive to
// this transport endpoint. The packet contains all data from the link
// layer up.
- HandlePacket(r *Route, netHeader buffer.View, packet buffer.VectorisedView)
+ //
+ // HandlePacket takes ownership of pkt.
+ HandlePacket(r *Route, pkt tcpip.PacketBuffer)
+}
+
+// PacketEndpoint is the interface that needs to be implemented by packet
+// transport protocol endpoints. These endpoints receive link layer headers in
+// addition to whatever they contain (usually network and transport layer
+// headers and a payload).
+type PacketEndpoint interface {
+ // HandlePacket is called by the stack when new packets arrive that
+ // match the endpoint.
+ //
+ // Implementers should treat packet as immutable and should copy it
+ // before before modification.
+ //
+ // linkHeader may have a length of 0, in which case the PacketEndpoint
+ // should construct its own ethernet header for applications.
+ //
+ // HandlePacket takes ownership of pkt.
+ HandlePacket(nicID tcpip.NICID, addr tcpip.LinkAddress, netProto tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer)
}
// TransportProtocol is the interface that needs to be implemented by transport
@@ -107,7 +147,9 @@ type TransportProtocol interface {
//
// The return value indicates whether the packet was well-formed (for
// stats purposes only).
- HandleUnknownDestinationPacket(r *Route, id TransportEndpointID, netHeader buffer.View, vv buffer.VectorisedView) bool
+ //
+ // HandleUnknownDestinationPacket takes ownership of pkt.
+ HandleUnknownDestinationPacket(r *Route, id TransportEndpointID, pkt tcpip.PacketBuffer) bool
// SetOption allows enabling/disabling protocol specific features.
// SetOption returns an error if the option is not supported or the
@@ -125,13 +167,21 @@ type TransportProtocol interface {
// the network layer.
type TransportDispatcher interface {
// DeliverTransportPacket delivers packets to the appropriate
- // transport protocol endpoint. It also returns the network layer
- // header for the enpoint to inspect or pass up the stack.
- DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolNumber, netHeader buffer.View, vv buffer.VectorisedView)
+ // transport protocol endpoint.
+ //
+ // pkt.NetworkHeader must be set before calling DeliverTransportPacket.
+ //
+ // DeliverTransportPacket takes ownership of pkt.
+ DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolNumber, pkt tcpip.PacketBuffer)
// DeliverTransportControlPacket delivers control packets to the
// appropriate transport protocol endpoint.
- DeliverTransportControlPacket(local, remote tcpip.Address, net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, typ ControlType, extra uint32, vv buffer.VectorisedView)
+ //
+ // pkt.NetworkHeader must be set before calling
+ // DeliverTransportControlPacket.
+ //
+ // DeliverTransportControlPacket takes ownership of pkt.
+ DeliverTransportControlPacket(local, remote tcpip.Address, net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, typ ControlType, extra uint32, pkt tcpip.PacketBuffer)
}
// PacketLooping specifies where an outbound packet should be sent.
@@ -146,6 +196,19 @@ const (
PacketLoop
)
+// NetworkHeaderParams are the header parameters given as input by the
+// transport endpoint to the network.
+type NetworkHeaderParams struct {
+ // Protocol refers to the transport protocol number.
+ Protocol tcpip.TransportProtocolNumber
+
+ // TTL refers to Time To Live field of the IP-header.
+ TTL uint8
+
+ // TOS refers to TypeOfService or TrafficClass field of the IP-header.
+ TOS uint8
+}
+
// NetworkEndpoint is the interface that needs to be implemented by endpoints
// of network layer protocols (e.g., ipv4, ipv6).
type NetworkEndpoint interface {
@@ -170,7 +233,11 @@ type NetworkEndpoint interface {
// WritePacket writes a packet to the given destination address and
// protocol.
- WritePacket(r *Route, gso *GSO, hdr buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.TransportProtocolNumber, ttl uint8, loop PacketLooping) *tcpip.Error
+ WritePacket(r *Route, gso *GSO, hdr buffer.Prependable, payload buffer.VectorisedView, params NetworkHeaderParams, loop PacketLooping) *tcpip.Error
+
+ // WritePackets writes packets to the given destination address and
+ // protocol.
+ WritePackets(r *Route, gso *GSO, hdrs []PacketDescriptor, payload buffer.VectorisedView, params NetworkHeaderParams, loop PacketLooping) (int, *tcpip.Error)
// WriteHeaderIncludedPacket writes a packet that includes a network
// header to the given destination address.
@@ -186,8 +253,10 @@ type NetworkEndpoint interface {
NICID() tcpip.NICID
// HandlePacket is called by the link layer when new packets arrive to
- // this network endpoint.
- HandlePacket(r *Route, vv buffer.VectorisedView)
+ // this network endpoint. It sets pkt.NetworkHeader.
+ //
+ // HandlePacket takes ownership of pkt.
+ HandlePacket(r *Route, pkt tcpip.PacketBuffer)
// Close is called when the endpoint is reomved from a stack.
Close()
@@ -212,7 +281,7 @@ type NetworkProtocol interface {
ParseAddresses(v buffer.View) (src, dst tcpip.Address)
// NewEndpoint creates a new endpoint of this protocol.
- NewEndpoint(nicid tcpip.NICID, addrWithPrefix tcpip.AddressWithPrefix, linkAddrCache LinkAddressCache, dispatcher TransportDispatcher, sender LinkEndpoint) (NetworkEndpoint, *tcpip.Error)
+ NewEndpoint(nicID tcpip.NICID, addrWithPrefix tcpip.AddressWithPrefix, linkAddrCache LinkAddressCache, dispatcher TransportDispatcher, sender LinkEndpoint) (NetworkEndpoint, *tcpip.Error)
// SetOption allows enabling/disabling protocol specific features.
// SetOption returns an error if the option is not supported or the
@@ -229,9 +298,15 @@ type NetworkProtocol interface {
// packets to the appropriate network endpoint after it has been handled by
// the data link layer.
type NetworkDispatcher interface {
- // DeliverNetworkPacket finds the appropriate network protocol
- // endpoint and hands the packet over for further processing.
- DeliverNetworkPacket(linkEP LinkEndpoint, remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, vv buffer.VectorisedView)
+ // DeliverNetworkPacket finds the appropriate network protocol endpoint
+ // and hands the packet over for further processing.
+ //
+ // pkt.LinkHeader may or may not be set before calling
+ // DeliverNetworkPacket. Some packets do not have link headers (e.g.
+ // packets sent via loopback), and won't have the field set.
+ //
+ // DeliverNetworkPacket takes ownership of pkt.
+ DeliverNetworkPacket(linkEP LinkEndpoint, remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer)
}
// LinkEndpointCapabilities is the type associated with the capabilities
@@ -253,12 +328,18 @@ const (
CapabilitySaveRestore
CapabilityDisconnectOk
CapabilityLoopback
- CapabilityGSO
+ CapabilityHardwareGSO
+
+ // CapabilitySoftwareGSO indicates the link endpoint supports of sending
+ // multiple packets using a single call (LinkEndpoint.WritePackets).
+ CapabilitySoftwareGSO
)
// LinkEndpoint is the interface implemented by data link layer protocols (e.g.,
// ethernet, loopback, raw) and used by network layer protocols to send packets
-// out through the implementer's data link endpoint.
+// out through the implementer's data link endpoint. When a link header exists,
+// it sets each tcpip.PacketBuffer's LinkHeader field before passing it up the
+// stack.
type LinkEndpoint interface {
// MTU is the maximum transmission unit for this endpoint. This is
// usually dictated by the backing physical network; when such a
@@ -288,6 +369,18 @@ type LinkEndpoint interface {
// r.LocalLinkAddress if it is provided.
WritePacket(r *Route, gso *GSO, hdr buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.NetworkProtocolNumber) *tcpip.Error
+ // WritePackets writes packets with the given protocol through the
+ // given route.
+ //
+ // Right now, WritePackets is used only when the software segmentation
+ // offload is enabled. If it will be used for something else, it may
+ // require to change syscall filters.
+ WritePackets(r *Route, gso *GSO, hdrs []PacketDescriptor, payload buffer.VectorisedView, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error)
+
+ // WriteRawPacket writes a packet directly to the link. The packet
+ // should already have an ethernet header.
+ WriteRawPacket(packet buffer.VectorisedView) *tcpip.Error
+
// Attach attaches the data link layer endpoint to the network-layer
// dispatcher of the stack.
Attach(dispatcher NetworkDispatcher)
@@ -311,13 +404,14 @@ type LinkEndpoint interface {
type InjectableLinkEndpoint interface {
LinkEndpoint
- // Inject injects an inbound packet.
- Inject(protocol tcpip.NetworkProtocolNumber, vv buffer.VectorisedView)
+ // InjectInbound injects an inbound packet.
+ InjectInbound(protocol tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer)
- // WriteRawPacket writes a fully formed outbound packet directly to the link.
+ // InjectOutbound writes a fully formed outbound packet directly to the
+ // link.
//
// dest is used by endpoints with multiple raw destinations.
- WriteRawPacket(dest tcpip.Address, packet []byte) *tcpip.Error
+ InjectOutbound(dest tcpip.Address, packet []byte) *tcpip.Error
}
// A LinkAddressResolver is an extension to a NetworkProtocol that
@@ -346,10 +440,10 @@ type LinkAddressResolver interface {
type LinkAddressCache interface {
// CheckLocalAddress determines if the given local address exists, and if it
// does not exist.
- CheckLocalAddress(nicid tcpip.NICID, protocol tcpip.NetworkProtocolNumber, addr tcpip.Address) tcpip.NICID
+ CheckLocalAddress(nicID tcpip.NICID, protocol tcpip.NetworkProtocolNumber, addr tcpip.Address) tcpip.NICID
// AddLinkAddress adds a link address to the cache.
- AddLinkAddress(nicid tcpip.NICID, addr tcpip.Address, linkAddr tcpip.LinkAddress)
+ AddLinkAddress(nicID tcpip.NICID, addr tcpip.Address, linkAddr tcpip.LinkAddress)
// GetLinkAddress looks up the cache to translate address to link address (e.g. IP -> MAC).
// If the LinkEndpoint requests address resolution and there is a LinkAddressResolver
@@ -360,17 +454,22 @@ type LinkAddressCache interface {
// If address resolution is required, ErrNoLinkAddress and a notification channel is
// returned for the top level caller to block. Channel is closed once address resolution
// is complete (success or not).
- GetLinkAddress(nicid tcpip.NICID, addr, localAddr tcpip.Address, protocol tcpip.NetworkProtocolNumber, w *sleep.Waker) (tcpip.LinkAddress, <-chan struct{}, *tcpip.Error)
+ GetLinkAddress(nicID tcpip.NICID, addr, localAddr tcpip.Address, protocol tcpip.NetworkProtocolNumber, w *sleep.Waker) (tcpip.LinkAddress, <-chan struct{}, *tcpip.Error)
// RemoveWaker removes a waker that has been added in GetLinkAddress().
- RemoveWaker(nicid tcpip.NICID, addr tcpip.Address, waker *sleep.Waker)
+ RemoveWaker(nicID tcpip.NICID, addr tcpip.Address, waker *sleep.Waker)
}
-// UnassociatedEndpointFactory produces endpoints for writing packets not
-// associated with a particular transport protocol. Such endpoints can be used
-// to write arbitrary packets that include the IP header.
-type UnassociatedEndpointFactory interface {
- NewUnassociatedRawEndpoint(stack *Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error)
+// RawFactory produces endpoints for writing various types of raw packets.
+type RawFactory interface {
+ // NewUnassociatedEndpoint produces endpoints for writing packets not
+ // associated with a particular transport protocol. Such endpoints can
+ // be used to write arbitrary packets that include the network header.
+ NewUnassociatedEndpoint(stack *Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error)
+
+ // NewPacketEndpoint produces endpoints for reading and writing packets
+ // that include network and (when cooked is false) link layer headers.
+ NewPacketEndpoint(stack *Stack, cooked bool, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error)
}
// GSOType is the type of GSO segments.
@@ -381,8 +480,14 @@ type GSOType int
// Types of gso segments.
const (
GSONone GSOType = iota
+
+ // Hardware GSO types:
GSOTCPv4
GSOTCPv6
+
+ // GSOSW is used for software GSO segments which have to be sent by
+ // endpoint.WritePackets.
+ GSOSW
)
// GSO contains generic segmentation offload properties.
@@ -410,3 +515,7 @@ type GSOEndpoint interface {
// GSOMaxSize returns the maximum GSO packet size.
GSOMaxSize() uint32
}
+
+// SoftwareGSOMaxSize is a maximum allowed size of a software GSO segment.
+// This isn't a hard limit, because it is never set into packet headers.
+const SoftwareGSOMaxSize = (1 << 16)
diff --git a/pkg/tcpip/stack/route.go b/pkg/tcpip/stack/route.go
index 5c8b7977a..1a0a51b57 100644
--- a/pkg/tcpip/stack/route.go
+++ b/pkg/tcpip/stack/route.go
@@ -47,8 +47,8 @@ type Route struct {
// starts.
ref *referencedNetworkEndpoint
- // loop controls where WritePacket should send packets.
- loop PacketLooping
+ // Loop controls where WritePacket should send packets.
+ Loop PacketLooping
}
// makeRoute initializes a new route. It takes ownership of the provided
@@ -59,6 +59,8 @@ func makeRoute(netProto tcpip.NetworkProtocolNumber, localAddr, remoteAddr tcpip
loop = PacketLoop
} else if multicastLoop && (header.IsV4MulticastAddress(remoteAddr) || header.IsV6MulticastAddress(remoteAddr)) {
loop |= PacketLoop
+ } else if remoteAddr == header.IPv4Broadcast {
+ loop |= PacketLoop
}
return Route{
@@ -67,7 +69,7 @@ func makeRoute(netProto tcpip.NetworkProtocolNumber, localAddr, remoteAddr tcpip
LocalLinkAddress: localLinkAddr,
RemoteAddress: remoteAddr,
ref: ref,
- loop: loop,
+ Loop: loop,
}
}
@@ -152,12 +154,12 @@ func (r *Route) IsResolutionRequired() bool {
}
// WritePacket writes the packet through the given route.
-func (r *Route) WritePacket(gso *GSO, hdr buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.TransportProtocolNumber, ttl uint8) *tcpip.Error {
+func (r *Route) WritePacket(gso *GSO, hdr buffer.Prependable, payload buffer.VectorisedView, params NetworkHeaderParams) *tcpip.Error {
if !r.ref.isValidForOutgoing() {
return tcpip.ErrInvalidEndpointState
}
- err := r.ref.ep.WritePacket(r, gso, hdr, payload, protocol, ttl, r.loop)
+ err := r.ref.ep.WritePacket(r, gso, hdr, payload, params, r.Loop)
if err != nil {
r.Stats().IP.OutgoingPacketErrors.Increment()
} else {
@@ -167,6 +169,44 @@ func (r *Route) WritePacket(gso *GSO, hdr buffer.Prependable, payload buffer.Vec
return err
}
+// PacketDescriptor is a packet descriptor which contains a packet header and
+// offset and size of packet data in a payload view.
+type PacketDescriptor struct {
+ Hdr buffer.Prependable
+ Off int
+ Size int
+}
+
+// NewPacketDescriptors allocates a set of packet descriptors.
+func NewPacketDescriptors(n int, hdrSize int) []PacketDescriptor {
+ buf := make([]byte, n*hdrSize)
+ hdrs := make([]PacketDescriptor, n)
+ for i := range hdrs {
+ hdrs[i].Hdr = buffer.NewEmptyPrependableFromView(buf[i*hdrSize:][:hdrSize])
+ }
+ return hdrs
+}
+
+// WritePackets writes the set of packets through the given route.
+func (r *Route) WritePackets(gso *GSO, hdrs []PacketDescriptor, payload buffer.VectorisedView, params NetworkHeaderParams) (int, *tcpip.Error) {
+ if !r.ref.isValidForOutgoing() {
+ return 0, tcpip.ErrInvalidEndpointState
+ }
+
+ n, err := r.ref.ep.WritePackets(r, gso, hdrs, payload, params, r.Loop)
+ if err != nil {
+ r.Stats().IP.OutgoingPacketErrors.IncrementBy(uint64(len(hdrs) - n))
+ }
+ r.ref.nic.stats.Tx.Packets.IncrementBy(uint64(n))
+ payloadSize := 0
+ for i := 0; i < n; i++ {
+ r.ref.nic.stats.Tx.Bytes.IncrementBy(uint64(hdrs[i].Hdr.UsedLength()))
+ payloadSize += hdrs[i].Size
+ }
+ r.ref.nic.stats.Tx.Bytes.IncrementBy(uint64(payloadSize))
+ return n, err
+}
+
// WriteHeaderIncludedPacket writes a packet already containing a network
// header through the given route.
func (r *Route) WriteHeaderIncludedPacket(payload buffer.VectorisedView) *tcpip.Error {
@@ -174,7 +214,7 @@ func (r *Route) WriteHeaderIncludedPacket(payload buffer.VectorisedView) *tcpip.
return tcpip.ErrInvalidEndpointState
}
- if err := r.ref.ep.WriteHeaderIncludedPacket(r, payload, r.loop); err != nil {
+ if err := r.ref.ep.WriteHeaderIncludedPacket(r, payload, r.Loop); err != nil {
r.Stats().IP.OutgoingPacketErrors.Increment()
return err
}
@@ -208,10 +248,17 @@ func (r *Route) Clone() Route {
return *r
}
-// MakeLoopedRoute duplicates the given route and tweaks it in case of multicast.
+// MakeLoopedRoute duplicates the given route with special handling for routes
+// used for sending multicast or broadcast packets. In those cases the
+// multicast/broadcast address is the remote address when sending out, but for
+// incoming (looped) packets it becomes the local address. Similarly, the local
+// interface address that was the local address going out becomes the remote
+// address coming in. This is different to unicast routes where local and
+// remote addresses remain the same as they identify location (local vs remote)
+// not direction (source vs destination).
func (r *Route) MakeLoopedRoute() Route {
l := r.Clone()
- if header.IsV4MulticastAddress(r.RemoteAddress) || header.IsV6MulticastAddress(r.RemoteAddress) {
+ if r.RemoteAddress == header.IPv4Broadcast || header.IsV4MulticastAddress(r.RemoteAddress) || header.IsV6MulticastAddress(r.RemoteAddress) {
l.RemoteAddress, l.LocalAddress = l.LocalAddress, l.RemoteAddress
l.RemoteLinkAddress = l.LocalLinkAddress
}
diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go
index 18d1704a5..2f8d8e822 100644
--- a/pkg/tcpip/stack/stack.go
+++ b/pkg/tcpip/stack/stack.go
@@ -20,10 +20,13 @@
package stack
import (
+ "encoding/binary"
"sync"
+ "sync/atomic"
"time"
"golang.org/x/time/rate"
+ "gvisor.dev/gvisor/pkg/rand"
"gvisor.dev/gvisor/pkg/sleep"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
@@ -41,11 +44,14 @@ const (
resolutionTimeout = 1 * time.Second
// resolutionAttempts is set to the same ARP retries used in Linux.
resolutionAttempts = 3
+
+ // DefaultTOS is the default type of service value for network endpoints.
+ DefaultTOS = 0
)
type transportProtocolState struct {
proto TransportProtocol
- defaultHandler func(r *Route, id TransportEndpointID, netHeader buffer.View, vv buffer.VectorisedView) bool
+ defaultHandler func(r *Route, id TransportEndpointID, pkt tcpip.PacketBuffer) bool
}
// TCPProbeFunc is the expected function type for a TCP probe function to be
@@ -339,6 +345,13 @@ type ResumableEndpoint interface {
Resume(*Stack)
}
+// uniqueIDGenerator is a default unique ID generator.
+type uniqueIDGenerator uint64
+
+func (u *uniqueIDGenerator) UniqueID() uint64 {
+ return atomic.AddUint64((*uint64)(u), 1)
+}
+
// Stack is a networking stack, with all supported protocols, NICs, and route
// table.
type Stack struct {
@@ -346,10 +359,9 @@ type Stack struct {
networkProtocols map[tcpip.NetworkProtocolNumber]NetworkProtocol
linkAddrResolvers map[tcpip.NetworkProtocolNumber]LinkAddressResolver
- // unassociatedFactory creates unassociated endpoints. If nil, raw
- // endpoints are disabled. It is set during Stack creation and is
- // immutable.
- unassociatedFactory UnassociatedEndpointFactory
+ // rawFactory creates raw endpoints. If nil, raw endpoints are
+ // disabled. It is set during Stack creation and is immutable.
+ rawFactory RawFactory
demux *transportDemuxer
@@ -357,9 +369,10 @@ type Stack struct {
linkAddrCache *linkAddrCache
- mu sync.RWMutex
- nics map[tcpip.NICID]*NIC
- forwarding bool
+ mu sync.RWMutex
+ nics map[tcpip.NICID]*NIC
+ forwarding bool
+ cleanupEndpoints map[TransportEndpoint]struct{}
// route is the route table passed in by the user via SetRouteTable(),
// it is used by FindRoute() to build a route for a specific
@@ -388,6 +401,32 @@ type Stack struct {
// icmpRateLimiter is a global rate limiter for all ICMP messages generated
// by the stack.
icmpRateLimiter *ICMPRateLimiter
+
+ // seed is a one-time random value initialized at stack startup
+ // and is used to seed the TCP port picking on active connections
+ //
+ // TODO(gvisor.dev/issue/940): S/R this field.
+ seed uint32
+
+ // ndpConfigs is the default NDP configurations used by interfaces.
+ ndpConfigs NDPConfigurations
+
+ // autoGenIPv6LinkLocal determines whether or not the stack will attempt
+ // to auto-generate an IPv6 link-local address for newly enabled NICs.
+ // See the AutoGenIPv6LinkLocal field of Options for more details.
+ autoGenIPv6LinkLocal bool
+
+ // ndpDisp is the NDP event dispatcher that is used to send the netstack
+ // integrator NDP related events.
+ ndpDisp NDPDispatcher
+
+ // uniqueIDGenerator is a generator of unique identifiers.
+ uniqueIDGenerator UniqueID
+}
+
+// UniqueID is an abstract generator of unique identifiers.
+type UniqueID interface {
+ UniqueID() uint64
}
// Options contains optional Stack configuration.
@@ -411,14 +450,71 @@ type Options struct {
// stack (false).
HandleLocal bool
- // UnassociatedFactory produces unassociated endpoints raw endpoints.
- // Raw endpoints are enabled only if this is non-nil.
- UnassociatedFactory UnassociatedEndpointFactory
+ // UniqueID is an optional generator of unique identifiers.
+ UniqueID UniqueID
+
+ // NDPConfigs is the default NDP configurations used by interfaces.
+ //
+ // By default, NDPConfigs will have a zero value for its
+ // DupAddrDetectTransmits field, implying that DAD will not be performed
+ // before assigning an address to a NIC.
+ NDPConfigs NDPConfigurations
+
+ // AutoGenIPv6LinkLocal determins whether or not the stack will attempt
+ // to auto-generate an IPv6 link-local address for newly enabled NICs.
+ // Note, setting this to true does not mean that a link-local address
+ // will be assigned right away, or at all. If Duplicate Address
+ // Detection is enabled, an address will only be assigned if it
+ // successfully resolves. If it fails, no further attempt will be made
+ // to auto-generate an IPv6 link-local address.
+ //
+ // The generated link-local address will follow RFC 4291 Appendix A
+ // guidelines.
+ AutoGenIPv6LinkLocal bool
+
+ // NDPDisp is the NDP event dispatcher that an integrator can provide to
+ // receive NDP related events.
+ NDPDisp NDPDispatcher
+
+ // RawFactory produces raw endpoints. Raw endpoints are enabled only if
+ // this is non-nil.
+ RawFactory RawFactory
+}
+
+// TransportEndpointInfo holds useful information about a transport endpoint
+// which can be queried by monitoring tools.
+//
+// +stateify savable
+type TransportEndpointInfo struct {
+ // The following fields are initialized at creation time and are
+ // immutable.
+
+ NetProto tcpip.NetworkProtocolNumber
+ TransProto tcpip.TransportProtocolNumber
+
+ // The following fields are protected by endpoint mu.
+
+ ID TransportEndpointID
+ // BindNICID and bindAddr are set via calls to Bind(). They are used to
+ // reject attempts to send data or connect via a different NIC or
+ // address
+ BindNICID tcpip.NICID
+ BindAddr tcpip.Address
+ // RegisterNICID is the default NICID registered as a side-effect of
+ // connect or datagram write.
+ RegisterNICID tcpip.NICID
}
+// IsEndpointInfo is an empty method to implement the tcpip.EndpointInfo
+// marker interface.
+func (*TransportEndpointInfo) IsEndpointInfo() {}
+
// New allocates a new networking stack with only the requested networking and
// transport protocols configured with default options.
//
+// Note, NDPConfigurations will be fixed before being used by the Stack. That
+// is, if an invalid value was provided, it will be reset to the default value.
+//
// Protocol options can be changed by calling the
// SetNetworkProtocolOption/SetTransportProtocolOption methods provided by the
// stack. Please refer to individual protocol implementations as to what options
@@ -429,17 +525,30 @@ func New(opts Options) *Stack {
clock = &tcpip.StdClock{}
}
+ if opts.UniqueID == nil {
+ opts.UniqueID = new(uniqueIDGenerator)
+ }
+
+ // Make sure opts.NDPConfigs contains valid values only.
+ opts.NDPConfigs.validate()
+
s := &Stack{
- transportProtocols: make(map[tcpip.TransportProtocolNumber]*transportProtocolState),
- networkProtocols: make(map[tcpip.NetworkProtocolNumber]NetworkProtocol),
- linkAddrResolvers: make(map[tcpip.NetworkProtocolNumber]LinkAddressResolver),
- nics: make(map[tcpip.NICID]*NIC),
- linkAddrCache: newLinkAddrCache(ageLimit, resolutionTimeout, resolutionAttempts),
- PortManager: ports.NewPortManager(),
- clock: clock,
- stats: opts.Stats.FillIn(),
- handleLocal: opts.HandleLocal,
- icmpRateLimiter: NewICMPRateLimiter(),
+ transportProtocols: make(map[tcpip.TransportProtocolNumber]*transportProtocolState),
+ networkProtocols: make(map[tcpip.NetworkProtocolNumber]NetworkProtocol),
+ linkAddrResolvers: make(map[tcpip.NetworkProtocolNumber]LinkAddressResolver),
+ nics: make(map[tcpip.NICID]*NIC),
+ cleanupEndpoints: make(map[TransportEndpoint]struct{}),
+ linkAddrCache: newLinkAddrCache(ageLimit, resolutionTimeout, resolutionAttempts),
+ PortManager: ports.NewPortManager(),
+ clock: clock,
+ stats: opts.Stats.FillIn(),
+ handleLocal: opts.HandleLocal,
+ icmpRateLimiter: NewICMPRateLimiter(),
+ seed: generateRandUint32(),
+ ndpConfigs: opts.NDPConfigs,
+ autoGenIPv6LinkLocal: opts.AutoGenIPv6LinkLocal,
+ uniqueIDGenerator: opts.UniqueID,
+ ndpDisp: opts.NDPDisp,
}
// Add specified network protocols.
@@ -457,8 +566,8 @@ func New(opts Options) *Stack {
}
}
- // Add the factory for unassociated endpoints, if present.
- s.unassociatedFactory = opts.UnassociatedFactory
+ // Add the factory for raw endpoints, if present.
+ s.rawFactory = opts.RawFactory
// Create the global transport demuxer.
s.demux = newTransportDemuxer(s)
@@ -466,6 +575,11 @@ func New(opts Options) *Stack {
return s
}
+// UniqueID returns a unique identifier.
+func (s *Stack) UniqueID() uint64 {
+ return s.uniqueIDGenerator.UniqueID()
+}
+
// SetNetworkProtocolOption allows configuring individual protocol level
// options. This method returns an error if the protocol is not supported or
// option is not supported by the protocol implementation or the provided value
@@ -527,7 +641,7 @@ func (s *Stack) TransportProtocolOption(transport tcpip.TransportProtocolNumber,
//
// It must be called only during initialization of the stack. Changing it as the
// stack is operating is not supported.
-func (s *Stack) SetTransportProtocolHandler(p tcpip.TransportProtocolNumber, h func(*Route, TransportEndpointID, buffer.View, buffer.VectorisedView) bool) {
+func (s *Stack) SetTransportProtocolHandler(p tcpip.TransportProtocolNumber, h func(*Route, TransportEndpointID, tcpip.PacketBuffer) bool) {
state := s.transportProtocols[p]
if state != nil {
state.defaultHandler = h
@@ -593,12 +707,12 @@ func (s *Stack) NewEndpoint(transport tcpip.TransportProtocolNumber, network tcp
// protocol. Raw endpoints receive all traffic for a given protocol regardless
// of address.
func (s *Stack) NewRawEndpoint(transport tcpip.TransportProtocolNumber, network tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue, associated bool) (tcpip.Endpoint, *tcpip.Error) {
- if s.unassociatedFactory == nil {
+ if s.rawFactory == nil {
return nil, tcpip.ErrNotPermitted
}
if !associated {
- return s.unassociatedFactory.NewUnassociatedRawEndpoint(s, network, transport, waiterQueue)
+ return s.rawFactory.NewUnassociatedEndpoint(s, network, transport, waiterQueue)
}
t, ok := s.transportProtocols[transport]
@@ -609,6 +723,16 @@ func (s *Stack) NewRawEndpoint(transport tcpip.TransportProtocolNumber, network
return t.proto.NewRawEndpoint(s, network, waiterQueue)
}
+// NewPacketEndpoint creates a new packet endpoint listening for the given
+// netProto.
+func (s *Stack) NewPacketEndpoint(cooked bool, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
+ if s.rawFactory == nil {
+ return nil, tcpip.ErrNotPermitted
+ }
+
+ return s.rawFactory.NewPacketEndpoint(s, cooked, netProto, waiterQueue)
+}
+
// createNIC creates a NIC with the provided id and link-layer endpoint, and
// optionally enable it.
func (s *Stack) createNIC(id tcpip.NICID, name string, ep LinkEndpoint, enabled, loopback bool) *tcpip.Error {
@@ -893,7 +1017,7 @@ func (s *Stack) FindRoute(id tcpip.NICID, localAddr, remoteAddr tcpip.Address, n
}
} else {
for _, route := range s.routeTable {
- if (id != 0 && id != route.NIC) || (len(remoteAddr) != 0 && !isBroadcast && !route.Destination.Contains(remoteAddr)) {
+ if (id != 0 && id != route.NIC) || (len(remoteAddr) != 0 && !route.Destination.Contains(remoteAddr)) {
continue
}
if nic, ok := s.nics[route.NIC]; ok {
@@ -931,13 +1055,13 @@ func (s *Stack) CheckNetworkProtocol(protocol tcpip.NetworkProtocolNumber) bool
// CheckLocalAddress determines if the given local address exists, and if it
// does, returns the id of the NIC it's bound to. Returns 0 if the address
// does not exist.
-func (s *Stack) CheckLocalAddress(nicid tcpip.NICID, protocol tcpip.NetworkProtocolNumber, addr tcpip.Address) tcpip.NICID {
+func (s *Stack) CheckLocalAddress(nicID tcpip.NICID, protocol tcpip.NetworkProtocolNumber, addr tcpip.Address) tcpip.NICID {
s.mu.RLock()
defer s.mu.RUnlock()
// If a NIC is specified, we try to find the address there only.
- if nicid != 0 {
- nic := s.nics[nicid]
+ if nicID != 0 {
+ nic := s.nics[nicID]
if nic == nil {
return 0
}
@@ -996,35 +1120,35 @@ func (s *Stack) SetSpoofing(nicID tcpip.NICID, enable bool) *tcpip.Error {
}
// AddLinkAddress adds a link address to the stack link cache.
-func (s *Stack) AddLinkAddress(nicid tcpip.NICID, addr tcpip.Address, linkAddr tcpip.LinkAddress) {
- fullAddr := tcpip.FullAddress{NIC: nicid, Addr: addr}
+func (s *Stack) AddLinkAddress(nicID tcpip.NICID, addr tcpip.Address, linkAddr tcpip.LinkAddress) {
+ fullAddr := tcpip.FullAddress{NIC: nicID, Addr: addr}
s.linkAddrCache.add(fullAddr, linkAddr)
// TODO: provide a way for a transport endpoint to receive a signal
// that AddLinkAddress for a particular address has been called.
}
// GetLinkAddress implements LinkAddressCache.GetLinkAddress.
-func (s *Stack) GetLinkAddress(nicid tcpip.NICID, addr, localAddr tcpip.Address, protocol tcpip.NetworkProtocolNumber, waker *sleep.Waker) (tcpip.LinkAddress, <-chan struct{}, *tcpip.Error) {
+func (s *Stack) GetLinkAddress(nicID tcpip.NICID, addr, localAddr tcpip.Address, protocol tcpip.NetworkProtocolNumber, waker *sleep.Waker) (tcpip.LinkAddress, <-chan struct{}, *tcpip.Error) {
s.mu.RLock()
- nic := s.nics[nicid]
+ nic := s.nics[nicID]
if nic == nil {
s.mu.RUnlock()
return "", nil, tcpip.ErrUnknownNICID
}
s.mu.RUnlock()
- fullAddr := tcpip.FullAddress{NIC: nicid, Addr: addr}
+ fullAddr := tcpip.FullAddress{NIC: nicID, Addr: addr}
linkRes := s.linkAddrResolvers[protocol]
return s.linkAddrCache.get(fullAddr, linkRes, localAddr, nic.linkEP, waker)
}
// RemoveWaker implements LinkAddressCache.RemoveWaker.
-func (s *Stack) RemoveWaker(nicid tcpip.NICID, addr tcpip.Address, waker *sleep.Waker) {
+func (s *Stack) RemoveWaker(nicID tcpip.NICID, addr tcpip.Address, waker *sleep.Waker) {
s.mu.RLock()
defer s.mu.RUnlock()
- if nic := s.nics[nicid]; nic == nil {
- fullAddr := tcpip.FullAddress{NIC: nicid, Addr: addr}
+ if nic := s.nics[nicID]; nic == nil {
+ fullAddr := tcpip.FullAddress{NIC: nicID, Addr: addr}
s.linkAddrCache.removeWaker(fullAddr, waker)
}
}
@@ -1033,73 +1157,52 @@ func (s *Stack) RemoveWaker(nicid tcpip.NICID, addr tcpip.Address, waker *sleep.
// transport dispatcher. Received packets that match the provided id will be
// delivered to the given endpoint; specifying a nic is optional, but
// nic-specific IDs have precedence over global ones.
-func (s *Stack) RegisterTransportEndpoint(nicID tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, reusePort bool) *tcpip.Error {
- if nicID == 0 {
- return s.demux.registerEndpoint(netProtos, protocol, id, ep, reusePort)
- }
-
- s.mu.RLock()
- defer s.mu.RUnlock()
-
- nic := s.nics[nicID]
- if nic == nil {
- return tcpip.ErrUnknownNICID
- }
-
- return nic.demux.registerEndpoint(netProtos, protocol, id, ep, reusePort)
+func (s *Stack) RegisterTransportEndpoint(nicID tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, reusePort bool, bindToDevice tcpip.NICID) *tcpip.Error {
+ return s.demux.registerEndpoint(netProtos, protocol, id, ep, reusePort, bindToDevice)
}
// UnregisterTransportEndpoint removes the endpoint with the given id from the
// stack transport dispatcher.
-func (s *Stack) UnregisterTransportEndpoint(nicID tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint) {
- if nicID == 0 {
- s.demux.unregisterEndpoint(netProtos, protocol, id, ep)
- return
- }
+func (s *Stack) UnregisterTransportEndpoint(nicID tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, bindToDevice tcpip.NICID) {
+ s.demux.unregisterEndpoint(netProtos, protocol, id, ep, bindToDevice)
+}
- s.mu.RLock()
- defer s.mu.RUnlock()
+// StartTransportEndpointCleanup removes the endpoint with the given id from
+// the stack transport dispatcher. It also transitions it to the cleanup stage.
+func (s *Stack) StartTransportEndpointCleanup(nicID tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, bindToDevice tcpip.NICID) {
+ s.mu.Lock()
+ defer s.mu.Unlock()
- nic := s.nics[nicID]
- if nic != nil {
- nic.demux.unregisterEndpoint(netProtos, protocol, id, ep)
- }
+ s.cleanupEndpoints[ep] = struct{}{}
+
+ s.demux.unregisterEndpoint(netProtos, protocol, id, ep, bindToDevice)
+}
+
+// CompleteTransportEndpointCleanup removes the endpoint from the cleanup
+// stage.
+func (s *Stack) CompleteTransportEndpointCleanup(ep TransportEndpoint) {
+ s.mu.Lock()
+ delete(s.cleanupEndpoints, ep)
+ s.mu.Unlock()
+}
+
+// FindTransportEndpoint finds an endpoint that most closely matches the provided
+// id. If no endpoint is found it returns nil.
+func (s *Stack) FindTransportEndpoint(netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, id TransportEndpointID, r *Route) TransportEndpoint {
+ return s.demux.findTransportEndpoint(netProto, transProto, id, r)
}
// RegisterRawTransportEndpoint registers the given endpoint with the stack
// transport dispatcher. Received packets that match the provided transport
// protocol will be delivered to the given endpoint.
func (s *Stack) RegisterRawTransportEndpoint(nicID tcpip.NICID, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, ep RawTransportEndpoint) *tcpip.Error {
- if nicID == 0 {
- return s.demux.registerRawEndpoint(netProto, transProto, ep)
- }
-
- s.mu.RLock()
- defer s.mu.RUnlock()
-
- nic := s.nics[nicID]
- if nic == nil {
- return tcpip.ErrUnknownNICID
- }
-
- return nic.demux.registerRawEndpoint(netProto, transProto, ep)
+ return s.demux.registerRawEndpoint(netProto, transProto, ep)
}
// UnregisterRawTransportEndpoint removes the endpoint for the transport
// protocol from the stack transport dispatcher.
func (s *Stack) UnregisterRawTransportEndpoint(nicID tcpip.NICID, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, ep RawTransportEndpoint) {
- if nicID == 0 {
- s.demux.unregisterRawEndpoint(netProto, transProto, ep)
- return
- }
-
- s.mu.RLock()
- defer s.mu.RUnlock()
-
- nic := s.nics[nicID]
- if nic != nil {
- nic.demux.unregisterRawEndpoint(netProto, transProto, ep)
- }
+ s.demux.unregisterRawEndpoint(netProto, transProto, ep)
}
// RegisterRestoredEndpoint records e as an endpoint that has been restored on
@@ -1110,6 +1213,69 @@ func (s *Stack) RegisterRestoredEndpoint(e ResumableEndpoint) {
s.mu.Unlock()
}
+// RegisteredEndpoints returns all endpoints which are currently registered.
+func (s *Stack) RegisteredEndpoints() []TransportEndpoint {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+ var es []TransportEndpoint
+ for _, e := range s.demux.protocol {
+ es = append(es, e.transportEndpoints()...)
+ }
+ return es
+}
+
+// CleanupEndpoints returns endpoints currently in the cleanup state.
+func (s *Stack) CleanupEndpoints() []TransportEndpoint {
+ s.mu.Lock()
+ es := make([]TransportEndpoint, 0, len(s.cleanupEndpoints))
+ for e := range s.cleanupEndpoints {
+ es = append(es, e)
+ }
+ s.mu.Unlock()
+ return es
+}
+
+// RestoreCleanupEndpoints adds endpoints to cleanup tracking. This is useful
+// for restoring a stack after a save.
+func (s *Stack) RestoreCleanupEndpoints(es []TransportEndpoint) {
+ s.mu.Lock()
+ for _, e := range es {
+ s.cleanupEndpoints[e] = struct{}{}
+ }
+ s.mu.Unlock()
+}
+
+// Close closes all currently registered transport endpoints.
+//
+// Endpoints created or modified during this call may not get closed.
+func (s *Stack) Close() {
+ for _, e := range s.RegisteredEndpoints() {
+ e.Close()
+ }
+}
+
+// Wait waits for all transport and link endpoints to halt their worker
+// goroutines.
+//
+// Endpoints created or modified during this call may not get waited on.
+//
+// Note that link endpoints must be stopped via an implementation specific
+// mechanism.
+func (s *Stack) Wait() {
+ for _, e := range s.RegisteredEndpoints() {
+ e.Wait()
+ }
+ for _, e := range s.CleanupEndpoints() {
+ e.Wait()
+ }
+
+ s.mu.RLock()
+ defer s.mu.RUnlock()
+ for _, n := range s.nics {
+ n.linkEP.Wait()
+ }
+}
+
// Resume restarts the stack after a restore. This must be called after the
// entire system has been restored.
func (s *Stack) Resume() {
@@ -1124,6 +1290,109 @@ func (s *Stack) Resume() {
}
}
+// RegisterPacketEndpoint registers ep with the stack, causing it to receive
+// all traffic of the specified netProto on the given NIC. If nicID is 0, it
+// receives traffic from every NIC.
+func (s *Stack) RegisterPacketEndpoint(nicID tcpip.NICID, netProto tcpip.NetworkProtocolNumber, ep PacketEndpoint) *tcpip.Error {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+
+ // If no NIC is specified, capture on all devices.
+ if nicID == 0 {
+ // Register with each NIC.
+ for _, nic := range s.nics {
+ if err := nic.registerPacketEndpoint(netProto, ep); err != nil {
+ s.unregisterPacketEndpointLocked(0, netProto, ep)
+ return err
+ }
+ }
+ return nil
+ }
+
+ // Capture on a specific device.
+ nic, ok := s.nics[nicID]
+ if !ok {
+ return tcpip.ErrUnknownNICID
+ }
+ if err := nic.registerPacketEndpoint(netProto, ep); err != nil {
+ return err
+ }
+
+ return nil
+}
+
+// UnregisterPacketEndpoint unregisters ep for packets of the specified
+// netProto from the specified NIC. If nicID is 0, ep is unregistered from all
+// NICs.
+func (s *Stack) UnregisterPacketEndpoint(nicID tcpip.NICID, netProto tcpip.NetworkProtocolNumber, ep PacketEndpoint) {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+ s.unregisterPacketEndpointLocked(nicID, netProto, ep)
+}
+
+func (s *Stack) unregisterPacketEndpointLocked(nicID tcpip.NICID, netProto tcpip.NetworkProtocolNumber, ep PacketEndpoint) {
+ // If no NIC is specified, unregister on all devices.
+ if nicID == 0 {
+ // Unregister with each NIC.
+ for _, nic := range s.nics {
+ nic.unregisterPacketEndpoint(netProto, ep)
+ }
+ return
+ }
+
+ // Unregister in a single device.
+ nic, ok := s.nics[nicID]
+ if !ok {
+ return
+ }
+ nic.unregisterPacketEndpoint(netProto, ep)
+}
+
+// WritePacket writes data directly to the specified NIC. It adds an ethernet
+// header based on the arguments.
+func (s *Stack) WritePacket(nicID tcpip.NICID, dst tcpip.LinkAddress, netProto tcpip.NetworkProtocolNumber, payload buffer.VectorisedView) *tcpip.Error {
+ s.mu.Lock()
+ nic, ok := s.nics[nicID]
+ s.mu.Unlock()
+ if !ok {
+ return tcpip.ErrUnknownDevice
+ }
+
+ // Add our own fake ethernet header.
+ ethFields := header.EthernetFields{
+ SrcAddr: nic.linkEP.LinkAddress(),
+ DstAddr: dst,
+ Type: netProto,
+ }
+ fakeHeader := make(header.Ethernet, header.EthernetMinimumSize)
+ fakeHeader.Encode(&ethFields)
+ ethHeader := buffer.View(fakeHeader).ToVectorisedView()
+ ethHeader.Append(payload)
+
+ if err := nic.linkEP.WriteRawPacket(ethHeader); err != nil {
+ return err
+ }
+
+ return nil
+}
+
+// WriteRawPacket writes data directly to the specified NIC without adding any
+// headers.
+func (s *Stack) WriteRawPacket(nicID tcpip.NICID, payload buffer.VectorisedView) *tcpip.Error {
+ s.mu.Lock()
+ nic, ok := s.nics[nicID]
+ s.mu.Unlock()
+ if !ok {
+ return tcpip.ErrUnknownDevice
+ }
+
+ if err := nic.linkEP.WriteRawPacket(payload); err != nil {
+ return err
+ }
+
+ return nil
+}
+
// NetworkProtocolInstance returns the protocol instance in the stack for the
// specified network protocol. This method is public for protocol implementers
// and tests to use.
@@ -1243,3 +1512,85 @@ func (s *Stack) SetICMPBurst(burst int) {
func (s *Stack) AllowICMPMessage() bool {
return s.icmpRateLimiter.Allow()
}
+
+// IsAddrTentative returns true if addr is tentative on the NIC with ID id.
+//
+// Note that if addr is not associated with a NIC with id ID, then this
+// function will return false. It will only return true if the address is
+// associated with the NIC AND it is tentative.
+func (s *Stack) IsAddrTentative(id tcpip.NICID, addr tcpip.Address) (bool, *tcpip.Error) {
+ s.mu.RLock()
+ defer s.mu.RUnlock()
+
+ nic, ok := s.nics[id]
+ if !ok {
+ return false, tcpip.ErrUnknownNICID
+ }
+
+ return nic.isAddrTentative(addr), nil
+}
+
+// DupTentativeAddrDetected attempts to inform the NIC with ID id that a
+// tentative addr on it is a duplicate on a link.
+func (s *Stack) DupTentativeAddrDetected(id tcpip.NICID, addr tcpip.Address) *tcpip.Error {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+
+ nic, ok := s.nics[id]
+ if !ok {
+ return tcpip.ErrUnknownNICID
+ }
+
+ return nic.dupTentativeAddrDetected(addr)
+}
+
+// SetNDPConfigurations sets the per-interface NDP configurations on the NIC
+// with ID id to c.
+//
+// Note, if c contains invalid NDP configuration values, it will be fixed to
+// use default values for the erroneous values.
+func (s *Stack) SetNDPConfigurations(id tcpip.NICID, c NDPConfigurations) *tcpip.Error {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+
+ nic, ok := s.nics[id]
+ if !ok {
+ return tcpip.ErrUnknownNICID
+ }
+
+ nic.setNDPConfigs(c)
+
+ return nil
+}
+
+// HandleNDPRA provides a NIC with ID id a validated NDP Router Advertisement
+// message that it needs to handle.
+func (s *Stack) HandleNDPRA(id tcpip.NICID, ip tcpip.Address, ra header.NDPRouterAdvert) *tcpip.Error {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+
+ nic, ok := s.nics[id]
+ if !ok {
+ return tcpip.ErrUnknownNICID
+ }
+
+ nic.handleNDPRA(ip, ra)
+
+ return nil
+}
+
+// Seed returns a 32 bit value that can be used as a seed value for port
+// picking, ISN generation etc.
+//
+// NOTE: The seed is generated once during stack initialization only.
+func (s *Stack) Seed() uint32 {
+ return s.seed
+}
+
+func generateRandUint32() uint32 {
+ b := make([]byte, 4)
+ if _, err := rand.Read(b); err != nil {
+ panic(err)
+ }
+ return binary.LittleEndian.Uint32(b)
+}
diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go
index d2dede8a9..bf1d6974c 100644
--- a/pkg/tcpip/stack/stack_test.go
+++ b/pkg/tcpip/stack/stack_test.go
@@ -24,11 +24,14 @@ import (
"sort"
"strings"
"testing"
+ "time"
+ "github.com/google/go-cmp/cmp"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/link/channel"
+ "gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
"gvisor.dev/gvisor/pkg/tcpip/stack"
)
@@ -55,7 +58,7 @@ const (
// use the first three: destination address, source address, and transport
// protocol. They're all one byte fields to simplify parsing.
type fakeNetworkEndpoint struct {
- nicid tcpip.NICID
+ nicID tcpip.NICID
id stack.NetworkEndpointID
prefixLen int
proto *fakeNetworkProtocol
@@ -68,7 +71,7 @@ func (f *fakeNetworkEndpoint) MTU() uint32 {
}
func (f *fakeNetworkEndpoint) NICID() tcpip.NICID {
- return f.nicid
+ return f.nicID
}
func (f *fakeNetworkEndpoint) PrefixLen() int {
@@ -83,28 +86,28 @@ func (f *fakeNetworkEndpoint) ID() *stack.NetworkEndpointID {
return &f.id
}
-func (f *fakeNetworkEndpoint) HandlePacket(r *stack.Route, vv buffer.VectorisedView) {
+func (f *fakeNetworkEndpoint) HandlePacket(r *stack.Route, pkt tcpip.PacketBuffer) {
// Increment the received packet count in the protocol descriptor.
f.proto.packetCount[int(f.id.LocalAddress[0])%len(f.proto.packetCount)]++
// Consume the network header.
- b := vv.First()
- vv.TrimFront(fakeNetHeaderLen)
+ b := pkt.Data.First()
+ pkt.Data.TrimFront(fakeNetHeaderLen)
// Handle control packets.
if b[2] == uint8(fakeControlProtocol) {
- nb := vv.First()
+ nb := pkt.Data.First()
if len(nb) < fakeNetHeaderLen {
return
}
- vv.TrimFront(fakeNetHeaderLen)
- f.dispatcher.DeliverTransportControlPacket(tcpip.Address(nb[1:2]), tcpip.Address(nb[0:1]), fakeNetNumber, tcpip.TransportProtocolNumber(nb[2]), stack.ControlPortUnreachable, 0, vv)
+ pkt.Data.TrimFront(fakeNetHeaderLen)
+ f.dispatcher.DeliverTransportControlPacket(tcpip.Address(nb[1:2]), tcpip.Address(nb[0:1]), fakeNetNumber, tcpip.TransportProtocolNumber(nb[2]), stack.ControlPortUnreachable, 0, pkt)
return
}
// Dispatch the packet to the transport protocol.
- f.dispatcher.DeliverTransportPacket(r, tcpip.TransportProtocolNumber(b[2]), buffer.View([]byte{}), vv)
+ f.dispatcher.DeliverTransportPacket(r, tcpip.TransportProtocolNumber(b[2]), pkt)
}
func (f *fakeNetworkEndpoint) MaxHeaderLength() uint16 {
@@ -119,7 +122,7 @@ func (f *fakeNetworkEndpoint) Capabilities() stack.LinkEndpointCapabilities {
return f.ep.Capabilities()
}
-func (f *fakeNetworkEndpoint) WritePacket(r *stack.Route, gso *stack.GSO, hdr buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.TransportProtocolNumber, _ uint8, loop stack.PacketLooping) *tcpip.Error {
+func (f *fakeNetworkEndpoint) WritePacket(r *stack.Route, gso *stack.GSO, hdr buffer.Prependable, payload buffer.VectorisedView, params stack.NetworkHeaderParams, loop stack.PacketLooping) *tcpip.Error {
// Increment the sent packet count in the protocol descriptor.
f.proto.sendPacketCount[int(r.RemoteAddress[0])%len(f.proto.sendPacketCount)]++
@@ -128,14 +131,16 @@ func (f *fakeNetworkEndpoint) WritePacket(r *stack.Route, gso *stack.GSO, hdr bu
b := hdr.Prepend(fakeNetHeaderLen)
b[0] = r.RemoteAddress[0]
b[1] = f.id.LocalAddress[0]
- b[2] = byte(protocol)
+ b[2] = byte(params.Protocol)
if loop&stack.PacketLoop != 0 {
views := make([]buffer.View, 1, 1+len(payload.Views()))
views[0] = hdr.View()
views = append(views, payload.Views()...)
vv := buffer.NewVectorisedView(len(views[0])+payload.Size(), views)
- f.HandlePacket(r, vv)
+ f.HandlePacket(r, tcpip.PacketBuffer{
+ Data: vv,
+ })
}
if loop&stack.PacketOut == 0 {
return nil
@@ -144,6 +149,11 @@ func (f *fakeNetworkEndpoint) WritePacket(r *stack.Route, gso *stack.GSO, hdr bu
return f.ep.WritePacket(r, gso, hdr, payload, fakeNetNumber)
}
+// WritePackets implements stack.LinkEndpoint.WritePackets.
+func (f *fakeNetworkEndpoint) WritePackets(r *stack.Route, gso *stack.GSO, hdrs []stack.PacketDescriptor, payload buffer.VectorisedView, params stack.NetworkHeaderParams, loop stack.PacketLooping) (int, *tcpip.Error) {
+ panic("not implemented")
+}
+
func (*fakeNetworkEndpoint) WriteHeaderIncludedPacket(r *stack.Route, payload buffer.VectorisedView, loop stack.PacketLooping) *tcpip.Error {
return tcpip.ErrNotSupported
}
@@ -189,9 +199,9 @@ func (*fakeNetworkProtocol) ParseAddresses(v buffer.View) (src, dst tcpip.Addres
return tcpip.Address(v[1:2]), tcpip.Address(v[0:1])
}
-func (f *fakeNetworkProtocol) NewEndpoint(nicid tcpip.NICID, addrWithPrefix tcpip.AddressWithPrefix, linkAddrCache stack.LinkAddressCache, dispatcher stack.TransportDispatcher, ep stack.LinkEndpoint) (stack.NetworkEndpoint, *tcpip.Error) {
+func (f *fakeNetworkProtocol) NewEndpoint(nicID tcpip.NICID, addrWithPrefix tcpip.AddressWithPrefix, linkAddrCache stack.LinkAddressCache, dispatcher stack.TransportDispatcher, ep stack.LinkEndpoint) (stack.NetworkEndpoint, *tcpip.Error) {
return &fakeNetworkEndpoint{
- nicid: nicid,
+ nicID: nicID,
id: stack.NetworkEndpointID{LocalAddress: addrWithPrefix.Address},
prefixLen: addrWithPrefix.PrefixLen,
proto: f,
@@ -251,7 +261,9 @@ func TestNetworkReceive(t *testing.T) {
// Make sure packet with wrong address is not delivered.
buf[0] = 3
- ep.Inject(fakeNetNumber, buf.ToVectorisedView())
+ ep.InjectInbound(fakeNetNumber, tcpip.PacketBuffer{
+ Data: buf.ToVectorisedView(),
+ })
if fakeNet.packetCount[1] != 0 {
t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 0)
}
@@ -261,7 +273,9 @@ func TestNetworkReceive(t *testing.T) {
// Make sure packet is delivered to first endpoint.
buf[0] = 1
- ep.Inject(fakeNetNumber, buf.ToVectorisedView())
+ ep.InjectInbound(fakeNetNumber, tcpip.PacketBuffer{
+ Data: buf.ToVectorisedView(),
+ })
if fakeNet.packetCount[1] != 1 {
t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 1)
}
@@ -271,7 +285,9 @@ func TestNetworkReceive(t *testing.T) {
// Make sure packet is delivered to second endpoint.
buf[0] = 2
- ep.Inject(fakeNetNumber, buf.ToVectorisedView())
+ ep.InjectInbound(fakeNetNumber, tcpip.PacketBuffer{
+ Data: buf.ToVectorisedView(),
+ })
if fakeNet.packetCount[1] != 1 {
t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 1)
}
@@ -280,7 +296,9 @@ func TestNetworkReceive(t *testing.T) {
}
// Make sure packet is not delivered if protocol number is wrong.
- ep.Inject(fakeNetNumber-1, buf.ToVectorisedView())
+ ep.InjectInbound(fakeNetNumber-1, tcpip.PacketBuffer{
+ Data: buf.ToVectorisedView(),
+ })
if fakeNet.packetCount[1] != 1 {
t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 1)
}
@@ -290,7 +308,9 @@ func TestNetworkReceive(t *testing.T) {
// Make sure packet that is too small is dropped.
buf.CapLength(2)
- ep.Inject(fakeNetNumber, buf.ToVectorisedView())
+ ep.InjectInbound(fakeNetNumber, tcpip.PacketBuffer{
+ Data: buf.ToVectorisedView(),
+ })
if fakeNet.packetCount[1] != 1 {
t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 1)
}
@@ -310,7 +330,7 @@ func sendTo(s *stack.Stack, addr tcpip.Address, payload buffer.View) *tcpip.Erro
func send(r stack.Route, payload buffer.View) *tcpip.Error {
hdr := buffer.NewPrependable(int(r.MaxHeaderLength()))
- return r.WritePacket(nil /* gso */, hdr, payload.ToVectorisedView(), fakeTransNumber, 123)
+ return r.WritePacket(nil /* gso */, hdr, payload.ToVectorisedView(), stack.NetworkHeaderParams{Protocol: fakeTransNumber, TTL: 123, TOS: stack.DefaultTOS})
}
func testSendTo(t *testing.T, s *stack.Stack, addr tcpip.Address, ep *channel.Endpoint, payload buffer.View) {
@@ -365,7 +385,9 @@ func testFailingRecv(t *testing.T, fakeNet *fakeNetworkProtocol, localAddrByte b
func testRecvInternal(t *testing.T, fakeNet *fakeNetworkProtocol, localAddrByte byte, ep *channel.Endpoint, buf buffer.View, want int) {
t.Helper()
- ep.Inject(fakeNetNumber, buf.ToVectorisedView())
+ ep.InjectInbound(fakeNetNumber, tcpip.PacketBuffer{
+ Data: buf.ToVectorisedView(),
+ })
if got := fakeNet.PacketCount(localAddrByte); got != want {
t.Errorf("receive packet count: got = %d, want %d", got, want)
}
@@ -660,11 +682,11 @@ func TestAddressRemovalWithRouteHeld(t *testing.T) {
}
}
-func verifyAddress(t *testing.T, s *stack.Stack, nicid tcpip.NICID, addr tcpip.Address) {
+func verifyAddress(t *testing.T, s *stack.Stack, nicID tcpip.NICID, addr tcpip.Address) {
t.Helper()
- info, ok := s.NICInfo()[nicid]
+ info, ok := s.NICInfo()[nicID]
if !ok {
- t.Fatalf("NICInfo() failed to find nicid=%d", nicid)
+ t.Fatalf("NICInfo() failed to find nicID=%d", nicID)
}
if len(addr) == 0 {
// No address given, verify that there is no address assigned to the NIC.
@@ -697,7 +719,7 @@ func TestEndpointExpiration(t *testing.T) {
localAddrByte byte = 0x01
remoteAddr tcpip.Address = "\x03"
noAddr tcpip.Address = ""
- nicid tcpip.NICID = 1
+ nicID tcpip.NICID = 1
)
localAddr := tcpip.Address([]byte{localAddrByte})
@@ -709,7 +731,7 @@ func TestEndpointExpiration(t *testing.T) {
})
ep := channel.New(10, defaultMTU, "")
- if err := s.CreateNIC(nicid, ep); err != nil {
+ if err := s.CreateNIC(nicID, ep); err != nil {
t.Fatal("CreateNIC failed:", err)
}
@@ -726,13 +748,13 @@ func TestEndpointExpiration(t *testing.T) {
buf[0] = localAddrByte
if promiscuous {
- if err := s.SetPromiscuousMode(nicid, true); err != nil {
+ if err := s.SetPromiscuousMode(nicID, true); err != nil {
t.Fatal("SetPromiscuousMode failed:", err)
}
}
if spoofing {
- if err := s.SetSpoofing(nicid, true); err != nil {
+ if err := s.SetSpoofing(nicID, true); err != nil {
t.Fatal("SetSpoofing failed:", err)
}
}
@@ -740,7 +762,7 @@ func TestEndpointExpiration(t *testing.T) {
// 1. No Address yet, send should only work for spoofing, receive for
// promiscuous mode.
//-----------------------
- verifyAddress(t, s, nicid, noAddr)
+ verifyAddress(t, s, nicID, noAddr)
if promiscuous {
testRecv(t, fakeNet, localAddrByte, ep, buf)
} else {
@@ -755,20 +777,20 @@ func TestEndpointExpiration(t *testing.T) {
// 2. Add Address, everything should work.
//-----------------------
- if err := s.AddAddress(nicid, fakeNetNumber, localAddr); err != nil {
+ if err := s.AddAddress(nicID, fakeNetNumber, localAddr); err != nil {
t.Fatal("AddAddress failed:", err)
}
- verifyAddress(t, s, nicid, localAddr)
+ verifyAddress(t, s, nicID, localAddr)
testRecv(t, fakeNet, localAddrByte, ep, buf)
testSendTo(t, s, remoteAddr, ep, nil)
// 3. Remove the address, send should only work for spoofing, receive
// for promiscuous mode.
//-----------------------
- if err := s.RemoveAddress(nicid, localAddr); err != nil {
+ if err := s.RemoveAddress(nicID, localAddr); err != nil {
t.Fatal("RemoveAddress failed:", err)
}
- verifyAddress(t, s, nicid, noAddr)
+ verifyAddress(t, s, nicID, noAddr)
if promiscuous {
testRecv(t, fakeNet, localAddrByte, ep, buf)
} else {
@@ -783,10 +805,10 @@ func TestEndpointExpiration(t *testing.T) {
// 4. Add Address back, everything should work again.
//-----------------------
- if err := s.AddAddress(nicid, fakeNetNumber, localAddr); err != nil {
+ if err := s.AddAddress(nicID, fakeNetNumber, localAddr); err != nil {
t.Fatal("AddAddress failed:", err)
}
- verifyAddress(t, s, nicid, localAddr)
+ verifyAddress(t, s, nicID, localAddr)
testRecv(t, fakeNet, localAddrByte, ep, buf)
testSendTo(t, s, remoteAddr, ep, nil)
@@ -804,10 +826,10 @@ func TestEndpointExpiration(t *testing.T) {
// 6. Remove the address. Send should only work for spoofing, receive
// for promiscuous mode.
//-----------------------
- if err := s.RemoveAddress(nicid, localAddr); err != nil {
+ if err := s.RemoveAddress(nicID, localAddr); err != nil {
t.Fatal("RemoveAddress failed:", err)
}
- verifyAddress(t, s, nicid, noAddr)
+ verifyAddress(t, s, nicID, noAddr)
if promiscuous {
testRecv(t, fakeNet, localAddrByte, ep, buf)
} else {
@@ -823,10 +845,10 @@ func TestEndpointExpiration(t *testing.T) {
// 7. Add Address back, everything should work again.
//-----------------------
- if err := s.AddAddress(nicid, fakeNetNumber, localAddr); err != nil {
+ if err := s.AddAddress(nicID, fakeNetNumber, localAddr); err != nil {
t.Fatal("AddAddress failed:", err)
}
- verifyAddress(t, s, nicid, localAddr)
+ verifyAddress(t, s, nicID, localAddr)
testRecv(t, fakeNet, localAddrByte, ep, buf)
testSendTo(t, s, remoteAddr, ep, nil)
testSend(t, r, ep, nil)
@@ -834,17 +856,17 @@ func TestEndpointExpiration(t *testing.T) {
// 8. Remove the route, sendTo/recv should still work.
//-----------------------
r.Release()
- verifyAddress(t, s, nicid, localAddr)
+ verifyAddress(t, s, nicID, localAddr)
testRecv(t, fakeNet, localAddrByte, ep, buf)
testSendTo(t, s, remoteAddr, ep, nil)
// 9. Remove the address. Send should only work for spoofing, receive
// for promiscuous mode.
//-----------------------
- if err := s.RemoveAddress(nicid, localAddr); err != nil {
+ if err := s.RemoveAddress(nicID, localAddr); err != nil {
t.Fatal("RemoveAddress failed:", err)
}
- verifyAddress(t, s, nicid, noAddr)
+ verifyAddress(t, s, nicID, noAddr)
if promiscuous {
testRecv(t, fakeNet, localAddrByte, ep, buf)
} else {
@@ -952,10 +974,10 @@ func TestSpoofingWithAddress(t *testing.T) {
t.Fatal("FindRoute failed:", err)
}
if r.LocalAddress != nonExistentLocalAddr {
- t.Errorf("Route has wrong local address: got %s, want %s", r.LocalAddress, nonExistentLocalAddr)
+ t.Errorf("got Route.LocalAddress = %s, want = %s", r.LocalAddress, nonExistentLocalAddr)
}
if r.RemoteAddress != dstAddr {
- t.Errorf("Route has wrong remote address: got %s, want %s", r.RemoteAddress, dstAddr)
+ t.Errorf("got Route.RemoteAddress = %s, want = %s", r.RemoteAddress, dstAddr)
}
// Sending a packet works.
testSendTo(t, s, dstAddr, ep, nil)
@@ -967,10 +989,10 @@ func TestSpoofingWithAddress(t *testing.T) {
t.Fatal("FindRoute failed:", err)
}
if r.LocalAddress != localAddr {
- t.Errorf("Route has wrong local address: got %s, want %s", r.LocalAddress, nonExistentLocalAddr)
+ t.Errorf("got Route.LocalAddress = %s, want = %s", r.LocalAddress, nonExistentLocalAddr)
}
if r.RemoteAddress != dstAddr {
- t.Errorf("Route has wrong remote address: got %s, want %s", r.RemoteAddress, dstAddr)
+ t.Errorf("got Route.RemoteAddress = %s, want = %s", r.RemoteAddress, dstAddr)
}
// Sending a packet using the route works.
testSend(t, r, ep, nil)
@@ -1016,17 +1038,33 @@ func TestSpoofingNoAddress(t *testing.T) {
t.Fatal("FindRoute failed:", err)
}
if r.LocalAddress != nonExistentLocalAddr {
- t.Errorf("Route has wrong local address: got %s, want %s", r.LocalAddress, nonExistentLocalAddr)
+ t.Errorf("got Route.LocalAddress = %s, want = %s", r.LocalAddress, nonExistentLocalAddr)
}
if r.RemoteAddress != dstAddr {
- t.Errorf("Route has wrong remote address: got %s, want %s", r.RemoteAddress, dstAddr)
+ t.Errorf("got Route.RemoteAddress = %s, want = %s", r.RemoteAddress, dstAddr)
}
// Sending a packet works.
// FIXME(b/139841518):Spoofing doesn't work if there is no primary address.
// testSendTo(t, s, remoteAddr, ep, nil)
}
-func TestBroadcastNeedsNoRoute(t *testing.T) {
+func verifyRoute(gotRoute, wantRoute stack.Route) error {
+ if gotRoute.LocalAddress != wantRoute.LocalAddress {
+ return fmt.Errorf("bad local address: got %s, want = %s", gotRoute.LocalAddress, wantRoute.LocalAddress)
+ }
+ if gotRoute.RemoteAddress != wantRoute.RemoteAddress {
+ return fmt.Errorf("bad remote address: got %s, want = %s", gotRoute.RemoteAddress, wantRoute.RemoteAddress)
+ }
+ if gotRoute.RemoteLinkAddress != wantRoute.RemoteLinkAddress {
+ return fmt.Errorf("bad remote link address: got %s, want = %s", gotRoute.RemoteLinkAddress, wantRoute.RemoteLinkAddress)
+ }
+ if gotRoute.NextHop != wantRoute.NextHop {
+ return fmt.Errorf("bad next-hop address: got %s, want = %s", gotRoute.NextHop, wantRoute.NextHop)
+ }
+ return nil
+}
+
+func TestOutgoingBroadcastWithEmptyRouteTable(t *testing.T) {
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
})
@@ -1039,28 +1077,99 @@ func TestBroadcastNeedsNoRoute(t *testing.T) {
// If there is no endpoint, it won't work.
if _, err := s.FindRoute(1, header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, false /* multicastLoop */); err != tcpip.ErrNetworkUnreachable {
- t.Fatalf("got FindRoute(1, %v, %v, %v) = %v, want = %s", header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, err, tcpip.ErrNetworkUnreachable)
+ t.Fatalf("got FindRoute(1, %s, %s, %d) = %s, want = %s", header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, err, tcpip.ErrNetworkUnreachable)
}
- if err := s.AddAddress(1, fakeNetNumber, header.IPv4Any); err != nil {
- t.Fatalf("AddAddress(%v, %v) failed: %s", fakeNetNumber, header.IPv4Any, err)
+ protoAddr := tcpip.ProtocolAddress{Protocol: fakeNetNumber, AddressWithPrefix: tcpip.AddressWithPrefix{header.IPv4Any, 0}}
+ if err := s.AddProtocolAddress(1, protoAddr); err != nil {
+ t.Fatalf("AddProtocolAddress(1, %s) failed: %s", protoAddr, err)
}
r, err := s.FindRoute(1, header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, false /* multicastLoop */)
if err != nil {
- t.Fatalf("FindRoute(1, %v, %v, %v) failed: %v", header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, err)
+ t.Fatalf("FindRoute(1, %s, %s, %d) failed: %s", header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, err)
+ }
+ if err := verifyRoute(r, stack.Route{LocalAddress: header.IPv4Any, RemoteAddress: header.IPv4Broadcast}); err != nil {
+ t.Errorf("FindRoute(1, %s, %s, %d) returned unexpected Route: %s)", header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, err)
+ }
+
+ // If the NIC doesn't exist, it won't work.
+ if _, err := s.FindRoute(2, header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, false /* multicastLoop */); err != tcpip.ErrNetworkUnreachable {
+ t.Fatalf("got FindRoute(2, %s, %s, %d) = %s want = %s", header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, err, tcpip.ErrNetworkUnreachable)
+ }
+}
+
+func TestOutgoingBroadcastWithRouteTable(t *testing.T) {
+ defaultAddr := tcpip.AddressWithPrefix{header.IPv4Any, 0}
+ // Local subnet on NIC1: 192.168.1.58/24, gateway 192.168.1.1.
+ nic1Addr := tcpip.AddressWithPrefix{"\xc0\xa8\x01\x3a", 24}
+ nic1Gateway := tcpip.Address("\xc0\xa8\x01\x01")
+ // Local subnet on NIC2: 10.10.10.5/24, gateway 10.10.10.1.
+ nic2Addr := tcpip.AddressWithPrefix{"\x0a\x0a\x0a\x05", 24}
+ nic2Gateway := tcpip.Address("\x0a\x0a\x0a\x01")
+
+ // Create a new stack with two NICs.
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
+ })
+ ep := channel.New(10, defaultMTU, "")
+ if err := s.CreateNIC(1, ep); err != nil {
+ t.Fatalf("CreateNIC failed: %s", err)
+ }
+ if err := s.CreateNIC(2, ep); err != nil {
+ t.Fatalf("CreateNIC failed: %s", err)
+ }
+ nic1ProtoAddr := tcpip.ProtocolAddress{fakeNetNumber, nic1Addr}
+ if err := s.AddProtocolAddress(1, nic1ProtoAddr); err != nil {
+ t.Fatalf("AddProtocolAddress(1, %s) failed: %s", nic1ProtoAddr, err)
}
- if r.LocalAddress != header.IPv4Any {
- t.Errorf("Bad local address: got %v, want = %v", r.LocalAddress, header.IPv4Any)
+ nic2ProtoAddr := tcpip.ProtocolAddress{fakeNetNumber, nic2Addr}
+ if err := s.AddProtocolAddress(2, nic2ProtoAddr); err != nil {
+ t.Fatalf("AddAddress(2, %s) failed: %s", nic2ProtoAddr, err)
}
- if r.RemoteAddress != header.IPv4Broadcast {
- t.Errorf("Bad remote address: got %v, want = %v", r.RemoteAddress, header.IPv4Broadcast)
+ // Set the initial route table.
+ rt := []tcpip.Route{
+ {Destination: nic1Addr.Subnet(), NIC: 1},
+ {Destination: nic2Addr.Subnet(), NIC: 2},
+ {Destination: defaultAddr.Subnet(), Gateway: nic2Gateway, NIC: 2},
+ {Destination: defaultAddr.Subnet(), Gateway: nic1Gateway, NIC: 1},
}
+ s.SetRouteTable(rt)
- // If the NIC doesn't exist, it won't work.
- if _, err := s.FindRoute(2, header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, false /* multicastLoop */); err != tcpip.ErrNetworkUnreachable {
- t.Fatalf("got FindRoute(2, %v, %v, %v) = %v want = %v", header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, err, tcpip.ErrNetworkUnreachable)
+ // When an interface is given, the route for a broadcast goes through it.
+ r, err := s.FindRoute(1, nic1Addr.Address, header.IPv4Broadcast, fakeNetNumber, false /* multicastLoop */)
+ if err != nil {
+ t.Fatalf("FindRoute(1, %s, %s, %d) failed: %s", nic1Addr.Address, header.IPv4Broadcast, fakeNetNumber, err)
+ }
+ if err := verifyRoute(r, stack.Route{LocalAddress: nic1Addr.Address, RemoteAddress: header.IPv4Broadcast}); err != nil {
+ t.Errorf("FindRoute(1, %s, %s, %d) returned unexpected Route: %s)", nic1Addr.Address, header.IPv4Broadcast, fakeNetNumber, err)
+ }
+
+ // When an interface is not given, it consults the route table.
+ // 1. Case: Using the default route.
+ r, err = s.FindRoute(0, "", header.IPv4Broadcast, fakeNetNumber, false /* multicastLoop */)
+ if err != nil {
+ t.Fatalf("FindRoute(0, \"\", %s, %d) failed: %s", header.IPv4Broadcast, fakeNetNumber, err)
+ }
+ if err := verifyRoute(r, stack.Route{LocalAddress: nic2Addr.Address, RemoteAddress: header.IPv4Broadcast}); err != nil {
+ t.Errorf("FindRoute(0, \"\", %s, %d) returned unexpected Route: %s)", header.IPv4Broadcast, fakeNetNumber, err)
+ }
+
+ // 2. Case: Having an explicit route for broadcast will select that one.
+ rt = append(
+ []tcpip.Route{
+ {Destination: tcpip.AddressWithPrefix{header.IPv4Broadcast, 8 * header.IPv4AddressSize}.Subnet(), NIC: 1},
+ },
+ rt...,
+ )
+ s.SetRouteTable(rt)
+ r, err = s.FindRoute(0, "", header.IPv4Broadcast, fakeNetNumber, false /* multicastLoop */)
+ if err != nil {
+ t.Fatalf("FindRoute(0, \"\", %s, %d) failed: %s", header.IPv4Broadcast, fakeNetNumber, err)
+ }
+ if err := verifyRoute(r, stack.Route{LocalAddress: nic1Addr.Address, RemoteAddress: header.IPv4Broadcast}); err != nil {
+ t.Errorf("FindRoute(0, \"\", %s, %d) returned unexpected Route: %s)", header.IPv4Broadcast, fakeNetNumber, err)
}
}
@@ -1550,12 +1659,12 @@ func verifyAddresses(t *testing.T, expectedAddresses, gotAddresses []tcpip.Proto
}
func TestAddAddress(t *testing.T) {
- const nicid = 1
+ const nicID = 1
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
})
ep := channel.New(10, defaultMTU, "")
- if err := s.CreateNIC(nicid, ep); err != nil {
+ if err := s.CreateNIC(nicID, ep); err != nil {
t.Fatal("CreateNIC failed:", err)
}
@@ -1563,7 +1672,7 @@ func TestAddAddress(t *testing.T) {
expectedAddresses := make([]tcpip.ProtocolAddress, 0, 2)
for _, addrLen := range []int{4, 16} {
address := addrGen.next(addrLen)
- if err := s.AddAddress(nicid, fakeNetNumber, address); err != nil {
+ if err := s.AddAddress(nicID, fakeNetNumber, address); err != nil {
t.Fatalf("AddAddress(address=%s) failed: %s", address, err)
}
expectedAddresses = append(expectedAddresses, tcpip.ProtocolAddress{
@@ -1572,17 +1681,17 @@ func TestAddAddress(t *testing.T) {
})
}
- gotAddresses := s.AllAddresses()[nicid]
+ gotAddresses := s.AllAddresses()[nicID]
verifyAddresses(t, expectedAddresses, gotAddresses)
}
func TestAddProtocolAddress(t *testing.T) {
- const nicid = 1
+ const nicID = 1
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
})
ep := channel.New(10, defaultMTU, "")
- if err := s.CreateNIC(nicid, ep); err != nil {
+ if err := s.CreateNIC(nicID, ep); err != nil {
t.Fatal("CreateNIC failed:", err)
}
@@ -1599,24 +1708,24 @@ func TestAddProtocolAddress(t *testing.T) {
PrefixLen: prefixLen,
},
}
- if err := s.AddProtocolAddress(nicid, protocolAddress); err != nil {
+ if err := s.AddProtocolAddress(nicID, protocolAddress); err != nil {
t.Errorf("AddProtocolAddress(%+v) failed: %s", protocolAddress, err)
}
expectedAddresses = append(expectedAddresses, protocolAddress)
}
}
- gotAddresses := s.AllAddresses()[nicid]
+ gotAddresses := s.AllAddresses()[nicID]
verifyAddresses(t, expectedAddresses, gotAddresses)
}
func TestAddAddressWithOptions(t *testing.T) {
- const nicid = 1
+ const nicID = 1
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
})
ep := channel.New(10, defaultMTU, "")
- if err := s.CreateNIC(nicid, ep); err != nil {
+ if err := s.CreateNIC(nicID, ep); err != nil {
t.Fatal("CreateNIC failed:", err)
}
@@ -1627,7 +1736,7 @@ func TestAddAddressWithOptions(t *testing.T) {
for _, addrLen := range addrLenRange {
for _, behavior := range behaviorRange {
address := addrGen.next(addrLen)
- if err := s.AddAddressWithOptions(nicid, fakeNetNumber, address, behavior); err != nil {
+ if err := s.AddAddressWithOptions(nicID, fakeNetNumber, address, behavior); err != nil {
t.Fatalf("AddAddressWithOptions(address=%s, behavior=%d) failed: %s", address, behavior, err)
}
expectedAddresses = append(expectedAddresses, tcpip.ProtocolAddress{
@@ -1637,17 +1746,17 @@ func TestAddAddressWithOptions(t *testing.T) {
}
}
- gotAddresses := s.AllAddresses()[nicid]
+ gotAddresses := s.AllAddresses()[nicID]
verifyAddresses(t, expectedAddresses, gotAddresses)
}
func TestAddProtocolAddressWithOptions(t *testing.T) {
- const nicid = 1
+ const nicID = 1
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
})
ep := channel.New(10, defaultMTU, "")
- if err := s.CreateNIC(nicid, ep); err != nil {
+ if err := s.CreateNIC(nicID, ep); err != nil {
t.Fatal("CreateNIC failed:", err)
}
@@ -1666,7 +1775,7 @@ func TestAddProtocolAddressWithOptions(t *testing.T) {
PrefixLen: prefixLen,
},
}
- if err := s.AddProtocolAddressWithOptions(nicid, protocolAddress, behavior); err != nil {
+ if err := s.AddProtocolAddressWithOptions(nicID, protocolAddress, behavior); err != nil {
t.Fatalf("AddProtocolAddressWithOptions(%+v, %d) failed: %s", protocolAddress, behavior, err)
}
expectedAddresses = append(expectedAddresses, protocolAddress)
@@ -1674,7 +1783,7 @@ func TestAddProtocolAddressWithOptions(t *testing.T) {
}
}
- gotAddresses := s.AllAddresses()[nicid]
+ gotAddresses := s.AllAddresses()[nicID]
verifyAddresses(t, expectedAddresses, gotAddresses)
}
@@ -1700,7 +1809,9 @@ func TestNICStats(t *testing.T) {
// Send a packet to address 1.
buf := buffer.NewView(30)
- ep1.Inject(fakeNetNumber, buf.ToVectorisedView())
+ ep1.InjectInbound(fakeNetNumber, tcpip.PacketBuffer{
+ Data: buf.ToVectorisedView(),
+ })
if got, want := s.NICInfo()[1].Stats.Rx.Packets.Value(), uint64(1); got != want {
t.Errorf("got Rx.Packets.Value() = %d, want = %d", got, want)
}
@@ -1760,7 +1871,9 @@ func TestNICForwarding(t *testing.T) {
// Send a packet to address 3.
buf := buffer.NewView(30)
buf[0] = 3
- ep1.Inject(fakeNetNumber, buf.ToVectorisedView())
+ ep1.InjectInbound(fakeNetNumber, tcpip.PacketBuffer{
+ Data: buf.ToVectorisedView(),
+ })
select {
case <-ep2.C:
@@ -1777,3 +1890,297 @@ func TestNICForwarding(t *testing.T) {
t.Errorf("got Tx.Bytes.Value() = %d, want = %d", got, want)
}
}
+
+// TestNICAutoGenAddr tests the auto-generation of IPv6 link-local addresses
+// (or lack there-of if disabled (default)). Note, DAD will be disabled in
+// these tests.
+func TestNICAutoGenAddr(t *testing.T) {
+ tests := []struct {
+ name string
+ autoGen bool
+ linkAddr tcpip.LinkAddress
+ shouldGen bool
+ }{
+ {
+ "Disabled",
+ false,
+ linkAddr1,
+ false,
+ },
+ {
+ "Enabled",
+ true,
+ linkAddr1,
+ true,
+ },
+ {
+ "Nil MAC",
+ true,
+ tcpip.LinkAddress([]byte(nil)),
+ false,
+ },
+ {
+ "Empty MAC",
+ true,
+ tcpip.LinkAddress(""),
+ false,
+ },
+ {
+ "Invalid MAC",
+ true,
+ tcpip.LinkAddress("\x01\x02\x03"),
+ false,
+ },
+ {
+ "Multicast MAC",
+ true,
+ tcpip.LinkAddress("\x01\x02\x03\x04\x05\x06"),
+ false,
+ },
+ {
+ "Unspecified MAC",
+ true,
+ tcpip.LinkAddress("\x00\x00\x00\x00\x00\x00"),
+ false,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ opts := stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
+ }
+
+ if test.autoGen {
+ // Only set opts.AutoGenIPv6LinkLocal when
+ // test.autoGen is true because
+ // opts.AutoGenIPv6LinkLocal should be false by
+ // default.
+ opts.AutoGenIPv6LinkLocal = true
+ }
+
+ e := channel.New(10, 1280, test.linkAddr)
+ s := stack.New(opts)
+ if err := s.CreateNIC(1, e); err != nil {
+ t.Fatalf("CreateNIC(_) = %s", err)
+ }
+
+ addr, err := s.GetMainNICAddress(1, header.IPv6ProtocolNumber)
+ if err != nil {
+ t.Fatalf("stack.GetMainNICAddress(_, _) err = %s", err)
+ }
+
+ if test.shouldGen {
+ // Should have auto-generated an address and
+ // resolved immediately (DAD is disabled).
+ if want := (tcpip.AddressWithPrefix{Address: header.LinkLocalAddr(test.linkAddr), PrefixLen: header.IPv6LinkLocalPrefix.PrefixLen}); addr != want {
+ t.Fatalf("got stack.GetMainNICAddress(_, _) = %s, want = %s", addr, want)
+ }
+ } else {
+ // Should not have auto-generated an address.
+ if want := (tcpip.AddressWithPrefix{}); addr != want {
+ t.Fatalf("got stack.GetMainNICAddress(_, _) = (%s, nil), want = (%s, nil)", addr, want)
+ }
+ }
+ })
+ }
+}
+
+// TestNICAutoGenAddrDoesDAD tests that the successful auto-generation of IPv6
+// link-local addresses will only be assigned after the DAD process resolves.
+func TestNICAutoGenAddrDoesDAD(t *testing.T) {
+ ndpDisp := ndpDispatcher{
+ dadC: make(chan ndpDADEvent),
+ }
+ ndpConfigs := stack.DefaultNDPConfigurations()
+ opts := stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
+ NDPConfigs: ndpConfigs,
+ AutoGenIPv6LinkLocal: true,
+ NDPDisp: &ndpDisp,
+ }
+
+ e := channel.New(10, 1280, linkAddr1)
+ s := stack.New(opts)
+ if err := s.CreateNIC(1, e); err != nil {
+ t.Fatalf("CreateNIC(_) = %s", err)
+ }
+
+ // Address should not be considered bound to the
+ // NIC yet (DAD ongoing).
+ addr, err := s.GetMainNICAddress(1, header.IPv6ProtocolNumber)
+ if err != nil {
+ t.Fatalf("got stack.GetMainNICAddress(_, _) = (_, %v), want = (_, nil)", err)
+ }
+ if want := (tcpip.AddressWithPrefix{}); addr != want {
+ t.Fatalf("got stack.GetMainNICAddress(_, _) = (%s, nil), want = (%s, nil)", addr, want)
+ }
+
+ linkLocalAddr := header.LinkLocalAddr(linkAddr1)
+
+ // Wait for DAD to resolve.
+ select {
+ case <-time.After(time.Duration(ndpConfigs.DupAddrDetectTransmits)*ndpConfigs.RetransmitTimer + time.Second):
+ // We should get a resolution event after 1s (default time to
+ // resolve as per default NDP configurations). Waiting for that
+ // resolution time + an extra 1s without a resolution event
+ // means something is wrong.
+ t.Fatal("timed out waiting for DAD resolution")
+ case e := <-ndpDisp.dadC:
+ if e.err != nil {
+ t.Fatal("got DAD error: ", e.err)
+ }
+ if e.nicID != 1 {
+ t.Fatalf("got DAD event w/ nicID = %d, want = 1", e.nicID)
+ }
+ if e.addr != linkLocalAddr {
+ t.Fatalf("got DAD event w/ addr = %s, want = %s", addr, linkLocalAddr)
+ }
+ if !e.resolved {
+ t.Fatal("got DAD event w/ resolved = false, want = true")
+ }
+ }
+ addr, err = s.GetMainNICAddress(1, header.IPv6ProtocolNumber)
+ if err != nil {
+ t.Fatalf("stack.GetMainNICAddress(_, _) err = %s", err)
+ }
+ if want := (tcpip.AddressWithPrefix{Address: linkLocalAddr, PrefixLen: header.IPv6LinkLocalPrefix.PrefixLen}); addr != want {
+ t.Fatalf("got stack.GetMainNICAddress(_, _) = %s, want = %s", addr, want)
+ }
+}
+
+// TestNewPEB tests that a new PrimaryEndpointBehavior value (peb) is respected
+// when an address's kind gets "promoted" to permanent from permanentExpired.
+func TestNewPEBOnPromotionToPermanent(t *testing.T) {
+ pebs := []stack.PrimaryEndpointBehavior{
+ stack.NeverPrimaryEndpoint,
+ stack.CanBePrimaryEndpoint,
+ stack.FirstPrimaryEndpoint,
+ }
+
+ for _, pi := range pebs {
+ for _, ps := range pebs {
+ t.Run(fmt.Sprintf("%d-to-%d", pi, ps), func(t *testing.T) {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
+ })
+ ep1 := channel.New(10, defaultMTU, "")
+ if err := s.CreateNIC(1, ep1); err != nil {
+ t.Fatal("CreateNIC failed:", err)
+ }
+
+ // Add a permanent address with initial
+ // PrimaryEndpointBehavior (peb), pi. If pi is
+ // NeverPrimaryEndpoint, the address should not
+ // be returned by a call to GetMainNICAddress;
+ // else, it should.
+ if err := s.AddAddressWithOptions(1, fakeNetNumber, "\x01", pi); err != nil {
+ t.Fatal("AddAddressWithOptions failed:", err)
+ }
+ addr, err := s.GetMainNICAddress(1, fakeNetNumber)
+ if err != nil {
+ t.Fatal("s.GetMainNICAddress failed:", err)
+ }
+ if pi == stack.NeverPrimaryEndpoint {
+ if want := (tcpip.AddressWithPrefix{}); addr != want {
+ t.Fatalf("got GetMainNICAddress = %s, want = %s", addr, want)
+
+ }
+ } else if addr.Address != "\x01" {
+ t.Fatalf("got GetMainNICAddress = %s, want = 1", addr.Address)
+ }
+
+ {
+ subnet, err := tcpip.NewSubnet("\x00", "\x00")
+ if err != nil {
+ t.Fatalf("NewSubnet failed:", err)
+ }
+ s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: 1}})
+ }
+
+ // Take a route through the address so its ref
+ // count gets incremented and does not actually
+ // get deleted when RemoveAddress is called
+ // below. This is because we want to test that a
+ // new peb is respected when an address gets
+ // "promoted" to permanent from a
+ // permanentExpired kind.
+ r, err := s.FindRoute(1, "\x01", "\x02", fakeNetNumber, false)
+ if err != nil {
+ t.Fatal("FindRoute failed:", err)
+ }
+ defer r.Release()
+ if err := s.RemoveAddress(1, "\x01"); err != nil {
+ t.Fatalf("RemoveAddress failed:", err)
+ }
+
+ //
+ // At this point, the address should still be
+ // known by the NIC, but have its
+ // kind = permanentExpired.
+ //
+
+ // Add some other address with peb set to
+ // FirstPrimaryEndpoint.
+ if err := s.AddAddressWithOptions(1, fakeNetNumber, "\x03", stack.FirstPrimaryEndpoint); err != nil {
+ t.Fatal("AddAddressWithOptions failed:", err)
+
+ }
+
+ // Add back the address we removed earlier and
+ // make sure the new peb was respected.
+ // (The address should just be promoted now).
+ if err := s.AddAddressWithOptions(1, fakeNetNumber, "\x01", ps); err != nil {
+ t.Fatal("AddAddressWithOptions failed:", err)
+ }
+ var primaryAddrs []tcpip.Address
+ for _, pa := range s.NICInfo()[1].ProtocolAddresses {
+ primaryAddrs = append(primaryAddrs, pa.AddressWithPrefix.Address)
+ }
+ var expectedList []tcpip.Address
+ switch ps {
+ case stack.FirstPrimaryEndpoint:
+ expectedList = []tcpip.Address{
+ "\x01",
+ "\x03",
+ }
+ case stack.CanBePrimaryEndpoint:
+ expectedList = []tcpip.Address{
+ "\x03",
+ "\x01",
+ }
+ case stack.NeverPrimaryEndpoint:
+ expectedList = []tcpip.Address{
+ "\x03",
+ }
+ }
+ if !cmp.Equal(primaryAddrs, expectedList) {
+ t.Fatalf("got NIC's primary addresses = %v, want = %v", primaryAddrs, expectedList)
+ }
+
+ // Once we remove the other address, if the new
+ // peb, ps, was NeverPrimaryEndpoint, no address
+ // should be returned by a call to
+ // GetMainNICAddress; else, our original address
+ // should be returned.
+ if err := s.RemoveAddress(1, "\x03"); err != nil {
+ t.Fatalf("RemoveAddress failed:", err)
+ }
+ addr, err = s.GetMainNICAddress(1, fakeNetNumber)
+ if err != nil {
+ t.Fatal("s.GetMainNICAddress failed:", err)
+ }
+ if ps == stack.NeverPrimaryEndpoint {
+ if want := (tcpip.AddressWithPrefix{}); addr != want {
+ t.Fatalf("got GetMainNICAddress = %s, want = %s", addr, want)
+
+ }
+ } else {
+ if addr.Address != "\x01" {
+ t.Fatalf("got GetMainNICAddress = %s, want = 1", addr.Address)
+ }
+ }
+ })
+ }
+ }
+}
diff --git a/pkg/tcpip/stack/transport_demuxer.go b/pkg/tcpip/stack/transport_demuxer.go
index cf8a6d129..cb805522b 100644
--- a/pkg/tcpip/stack/transport_demuxer.go
+++ b/pkg/tcpip/stack/transport_demuxer.go
@@ -17,10 +17,10 @@ package stack
import (
"fmt"
"math/rand"
+ "sort"
"sync"
"gvisor.dev/gvisor/pkg/tcpip"
- "gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/hash/jenkins"
"gvisor.dev/gvisor/pkg/tcpip/header"
)
@@ -35,7 +35,7 @@ type protocolIDs struct {
type transportEndpoints struct {
// mu protects all fields of the transportEndpoints.
mu sync.RWMutex
- endpoints map[TransportEndpointID]TransportEndpoint
+ endpoints map[TransportEndpointID]*endpointsByNic
// rawEndpoints contains endpoints for raw sockets, which receive all
// traffic of a given protocol regardless of port.
rawEndpoints []RawTransportEndpoint
@@ -43,19 +43,122 @@ type transportEndpoints struct {
// unregisterEndpoint unregisters the endpoint with the given id such that it
// won't receive any more packets.
-func (eps *transportEndpoints) unregisterEndpoint(id TransportEndpointID, ep TransportEndpoint) {
+func (eps *transportEndpoints) unregisterEndpoint(id TransportEndpointID, ep TransportEndpoint, bindToDevice tcpip.NICID) {
eps.mu.Lock()
defer eps.mu.Unlock()
- e, ok := eps.endpoints[id]
+ epsByNic, ok := eps.endpoints[id]
if !ok {
return
}
- if multiPortEp, ok := e.(*multiPortEndpoint); ok {
- if !multiPortEp.unregisterEndpoint(ep) {
+ if !epsByNic.unregisterEndpoint(bindToDevice, ep) {
+ return
+ }
+ delete(eps.endpoints, id)
+}
+
+func (eps *transportEndpoints) transportEndpoints() []TransportEndpoint {
+ eps.mu.RLock()
+ defer eps.mu.RUnlock()
+ es := make([]TransportEndpoint, 0, len(eps.endpoints))
+ for _, e := range eps.endpoints {
+ es = append(es, e.transportEndpoints()...)
+ }
+ return es
+}
+
+type endpointsByNic struct {
+ mu sync.RWMutex
+ endpoints map[tcpip.NICID]*multiPortEndpoint
+ // seed is a random secret for a jenkins hash.
+ seed uint32
+}
+
+func (epsByNic *endpointsByNic) transportEndpoints() []TransportEndpoint {
+ epsByNic.mu.RLock()
+ defer epsByNic.mu.RUnlock()
+ var eps []TransportEndpoint
+ for _, ep := range epsByNic.endpoints {
+ eps = append(eps, ep.transportEndpoints()...)
+ }
+ return eps
+}
+
+// HandlePacket is called by the stack when new packets arrive to this transport
+// endpoint.
+func (epsByNic *endpointsByNic) handlePacket(r *Route, id TransportEndpointID, pkt tcpip.PacketBuffer) {
+ epsByNic.mu.RLock()
+
+ mpep, ok := epsByNic.endpoints[r.ref.nic.ID()]
+ if !ok {
+ if mpep, ok = epsByNic.endpoints[0]; !ok {
+ epsByNic.mu.RUnlock() // Don't use defer for performance reasons.
return
}
}
- delete(eps.endpoints, id)
+
+ // If this is a broadcast or multicast datagram, deliver the datagram to all
+ // endpoints bound to the right device.
+ if isMulticastOrBroadcast(id.LocalAddress) {
+ mpep.handlePacketAll(r, id, pkt)
+ epsByNic.mu.RUnlock() // Don't use defer for performance reasons.
+ return
+ }
+ // multiPortEndpoints are guaranteed to have at least one element.
+ selectEndpoint(id, mpep, epsByNic.seed).HandlePacket(r, id, pkt)
+ epsByNic.mu.RUnlock() // Don't use defer for performance reasons.
+}
+
+// HandleControlPacket implements stack.TransportEndpoint.HandleControlPacket.
+func (epsByNic *endpointsByNic) handleControlPacket(n *NIC, id TransportEndpointID, typ ControlType, extra uint32, pkt tcpip.PacketBuffer) {
+ epsByNic.mu.RLock()
+ defer epsByNic.mu.RUnlock()
+
+ mpep, ok := epsByNic.endpoints[n.ID()]
+ if !ok {
+ mpep, ok = epsByNic.endpoints[0]
+ }
+ if !ok {
+ return
+ }
+
+ // TODO(eyalsoha): Why don't we look at id to see if this packet needs to
+ // broadcast like we are doing with handlePacket above?
+
+ // multiPortEndpoints are guaranteed to have at least one element.
+ selectEndpoint(id, mpep, epsByNic.seed).HandleControlPacket(id, typ, extra, pkt)
+}
+
+// registerEndpoint returns true if it succeeds. It fails and returns
+// false if ep already has an element with the same key.
+func (epsByNic *endpointsByNic) registerEndpoint(t TransportEndpoint, reusePort bool, bindToDevice tcpip.NICID) *tcpip.Error {
+ epsByNic.mu.Lock()
+ defer epsByNic.mu.Unlock()
+
+ if multiPortEp, ok := epsByNic.endpoints[bindToDevice]; ok {
+ // There was already a bind.
+ return multiPortEp.singleRegisterEndpoint(t, reusePort)
+ }
+
+ // This is a new binding.
+ multiPortEp := &multiPortEndpoint{}
+ multiPortEp.endpointsMap = make(map[TransportEndpoint]int)
+ multiPortEp.reuse = reusePort
+ epsByNic.endpoints[bindToDevice] = multiPortEp
+ return multiPortEp.singleRegisterEndpoint(t, reusePort)
+}
+
+// unregisterEndpoint returns true if endpointsByNic has to be unregistered.
+func (epsByNic *endpointsByNic) unregisterEndpoint(bindToDevice tcpip.NICID, t TransportEndpoint) bool {
+ epsByNic.mu.Lock()
+ defer epsByNic.mu.Unlock()
+ multiPortEp, ok := epsByNic.endpoints[bindToDevice]
+ if !ok {
+ return false
+ }
+ if multiPortEp.unregisterEndpoint(t) {
+ delete(epsByNic.endpoints, bindToDevice)
+ }
+ return len(epsByNic.endpoints) == 0
}
// transportDemuxer demultiplexes packets targeted at a transport endpoint
@@ -75,7 +178,7 @@ func newTransportDemuxer(stack *Stack) *transportDemuxer {
for netProto := range stack.networkProtocols {
for proto := range stack.transportProtocols {
d.protocol[protocolIDs{netProto, proto}] = &transportEndpoints{
- endpoints: make(map[TransportEndpointID]TransportEndpoint),
+ endpoints: make(map[TransportEndpointID]*endpointsByNic),
}
}
}
@@ -85,10 +188,10 @@ func newTransportDemuxer(stack *Stack) *transportDemuxer {
// registerEndpoint registers the given endpoint with the dispatcher such that
// packets that match the endpoint ID are delivered to it.
-func (d *transportDemuxer) registerEndpoint(netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, reusePort bool) *tcpip.Error {
+func (d *transportDemuxer) registerEndpoint(netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, reusePort bool, bindToDevice tcpip.NICID) *tcpip.Error {
for i, n := range netProtos {
- if err := d.singleRegisterEndpoint(n, protocol, id, ep, reusePort); err != nil {
- d.unregisterEndpoint(netProtos[:i], protocol, id, ep)
+ if err := d.singleRegisterEndpoint(n, protocol, id, ep, reusePort, bindToDevice); err != nil {
+ d.unregisterEndpoint(netProtos[:i], protocol, id, ep, bindToDevice)
return err
}
}
@@ -97,13 +200,27 @@ func (d *transportDemuxer) registerEndpoint(netProtos []tcpip.NetworkProtocolNum
}
// multiPortEndpoint is a container for TransportEndpoints which are bound to
-// the same pair of address and port.
+// the same pair of address and port. endpointsArr always has at least one
+// element.
+//
+// FIXME(gvisor.dev/issue/873): Restore this properly. Currently, we just save
+// this to ensure that the underlying endpoints get saved/restored, but not not
+// use the restored copy.
+//
+// +stateify savable
type multiPortEndpoint struct {
- mu sync.RWMutex
+ mu sync.RWMutex `state:"nosave"`
endpointsArr []TransportEndpoint
endpointsMap map[TransportEndpoint]int
- // seed is a random secret for a jenkins hash.
- seed uint32
+ // reuse indicates if more than one endpoint is allowed.
+ reuse bool
+}
+
+func (ep *multiPortEndpoint) transportEndpoints() []TransportEndpoint {
+ ep.mu.RLock()
+ eps := append([]TransportEndpoint(nil), ep.endpointsArr...)
+ ep.mu.RUnlock()
+ return eps
}
// reciprocalScale scales a value into range [0, n).
@@ -117,9 +234,10 @@ func reciprocalScale(val, n uint32) uint32 {
// selectEndpoint calculates a hash of destination and source addresses and
// ports then uses it to select a socket. In this case, all packets from one
// address will be sent to same endpoint.
-func (ep *multiPortEndpoint) selectEndpoint(id TransportEndpointID) TransportEndpoint {
- ep.mu.RLock()
- defer ep.mu.RUnlock()
+func selectEndpoint(id TransportEndpointID, mpep *multiPortEndpoint, seed uint32) TransportEndpoint {
+ if len(mpep.endpointsArr) == 1 {
+ return mpep.endpointsArr[0]
+ }
payload := []byte{
byte(id.LocalPort),
@@ -128,51 +246,77 @@ func (ep *multiPortEndpoint) selectEndpoint(id TransportEndpointID) TransportEnd
byte(id.RemotePort >> 8),
}
- h := jenkins.Sum32(ep.seed)
+ h := jenkins.Sum32(seed)
h.Write(payload)
h.Write([]byte(id.LocalAddress))
h.Write([]byte(id.RemoteAddress))
hash := h.Sum32()
- idx := reciprocalScale(hash, uint32(len(ep.endpointsArr)))
- return ep.endpointsArr[idx]
+ idx := reciprocalScale(hash, uint32(len(mpep.endpointsArr)))
+ return mpep.endpointsArr[idx]
}
-// HandlePacket is called by the stack when new packets arrive to this transport
-// endpoint.
-func (ep *multiPortEndpoint) HandlePacket(r *Route, id TransportEndpointID, vv buffer.VectorisedView) {
- // If this is a broadcast or multicast datagram, deliver the datagram to all
- // endpoints managed by ep.
- if id.LocalAddress == header.IPv4Broadcast || header.IsV4MulticastAddress(id.LocalAddress) || header.IsV6MulticastAddress(id.LocalAddress) {
- for i, endpoint := range ep.endpointsArr {
- // HandlePacket modifies vv, so each endpoint needs its own copy.
- if i == len(ep.endpointsArr)-1 {
- endpoint.HandlePacket(r, id, vv)
- break
- }
- vvCopy := buffer.NewView(vv.Size())
- copy(vvCopy, vv.ToView())
- endpoint.HandlePacket(r, id, vvCopy.ToVectorisedView())
+func (ep *multiPortEndpoint) handlePacketAll(r *Route, id TransportEndpointID, pkt tcpip.PacketBuffer) {
+ ep.mu.RLock()
+ for i, endpoint := range ep.endpointsArr {
+ // HandlePacket takes ownership of pkt, so each endpoint needs
+ // its own copy except for the final one.
+ if i == len(ep.endpointsArr)-1 {
+ endpoint.HandlePacket(r, id, pkt)
+ break
}
- } else {
- ep.selectEndpoint(id).HandlePacket(r, id, vv)
+ endpoint.HandlePacket(r, id, pkt.Clone())
}
+ ep.mu.RUnlock() // Don't use defer for performance reasons.
}
-// HandleControlPacket implements stack.TransportEndpoint.HandleControlPacket.
-func (ep *multiPortEndpoint) HandleControlPacket(id TransportEndpointID, typ ControlType, extra uint32, vv buffer.VectorisedView) {
- ep.selectEndpoint(id).HandleControlPacket(id, typ, extra, vv)
+// Close implements stack.TransportEndpoint.Close.
+func (ep *multiPortEndpoint) Close() {
+ ep.mu.RLock()
+ eps := append([]TransportEndpoint(nil), ep.endpointsArr...)
+ ep.mu.RUnlock()
+ for _, e := range eps {
+ e.Close()
+ }
}
-func (ep *multiPortEndpoint) singleRegisterEndpoint(t TransportEndpoint) {
+// Wait implements stack.TransportEndpoint.Wait.
+func (ep *multiPortEndpoint) Wait() {
+ ep.mu.RLock()
+ eps := append([]TransportEndpoint(nil), ep.endpointsArr...)
+ ep.mu.RUnlock()
+ for _, e := range eps {
+ e.Wait()
+ }
+}
+
+// singleRegisterEndpoint tries to add an endpoint to the multiPortEndpoint
+// list. The list might be empty already.
+func (ep *multiPortEndpoint) singleRegisterEndpoint(t TransportEndpoint, reusePort bool) *tcpip.Error {
ep.mu.Lock()
defer ep.mu.Unlock()
- // A new endpoint is added into endpointsArr and its index there is
- // saved in endpointsMap. This will allows to remove endpoint from
- // the array fast.
+ if len(ep.endpointsArr) > 0 {
+ // If it was previously bound, we need to check if we can bind again.
+ if !ep.reuse || !reusePort {
+ return tcpip.ErrPortInUse
+ }
+ }
+
+ // A new endpoint is added into endpointsArr and its index there is saved in
+ // endpointsMap. This will allow us to remove endpoint from the array fast.
ep.endpointsMap[t] = len(ep.endpointsArr)
ep.endpointsArr = append(ep.endpointsArr, t)
+
+ // ep.endpointsArr is sorted by endpoint unique IDs, so that endpoints
+ // can be restored in the same order.
+ sort.Slice(ep.endpointsArr, func(i, j int) bool {
+ return ep.endpointsArr[i].UniqueID() < ep.endpointsArr[j].UniqueID()
+ })
+ for i, e := range ep.endpointsArr {
+ ep.endpointsMap[e] = i
+ }
+ return nil
}
// unregisterEndpoint returns true if multiPortEndpoint has to be unregistered.
@@ -197,53 +341,41 @@ func (ep *multiPortEndpoint) unregisterEndpoint(t TransportEndpoint) bool {
return true
}
-func (d *transportDemuxer) singleRegisterEndpoint(netProto tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, reusePort bool) *tcpip.Error {
+func (d *transportDemuxer) singleRegisterEndpoint(netProto tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, reusePort bool, bindToDevice tcpip.NICID) *tcpip.Error {
if id.RemotePort != 0 {
+ // TODO(eyalsoha): Why?
reusePort = false
}
eps, ok := d.protocol[protocolIDs{netProto, protocol}]
if !ok {
- return nil
+ return tcpip.ErrUnknownProtocol
}
eps.mu.Lock()
defer eps.mu.Unlock()
- var multiPortEp *multiPortEndpoint
- if _, ok := eps.endpoints[id]; ok {
- if !reusePort {
- return tcpip.ErrPortInUse
- }
- multiPortEp, ok = eps.endpoints[id].(*multiPortEndpoint)
- if !ok {
- return tcpip.ErrPortInUse
- }
+ if epsByNic, ok := eps.endpoints[id]; ok {
+ // There was already a binding.
+ return epsByNic.registerEndpoint(ep, reusePort, bindToDevice)
}
- if reusePort {
- if multiPortEp == nil {
- multiPortEp = &multiPortEndpoint{}
- multiPortEp.endpointsMap = make(map[TransportEndpoint]int)
- multiPortEp.seed = rand.Uint32()
- eps.endpoints[id] = multiPortEp
- }
-
- multiPortEp.singleRegisterEndpoint(ep)
-
- return nil
+ // This is a new binding.
+ epsByNic := &endpointsByNic{
+ endpoints: make(map[tcpip.NICID]*multiPortEndpoint),
+ seed: rand.Uint32(),
}
- eps.endpoints[id] = ep
+ eps.endpoints[id] = epsByNic
- return nil
+ return epsByNic.registerEndpoint(ep, reusePort, bindToDevice)
}
// unregisterEndpoint unregisters the endpoint with the given id such that it
// won't receive any more packets.
-func (d *transportDemuxer) unregisterEndpoint(netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint) {
+func (d *transportDemuxer) unregisterEndpoint(netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, bindToDevice tcpip.NICID) {
for _, n := range netProtos {
if eps, ok := d.protocol[protocolIDs{n, protocol}]; ok {
- eps.unregisterEndpoint(id, ep)
+ eps.unregisterEndpoint(id, ep, bindToDevice)
}
}
}
@@ -259,30 +391,21 @@ var loopbackSubnet = func() tcpip.Subnet {
// deliverPacket attempts to find one or more matching transport endpoints, and
// then, if matches are found, delivers the packet to them. Returns true if it
// found one or more endpoints, false otherwise.
-func (d *transportDemuxer) deliverPacket(r *Route, protocol tcpip.TransportProtocolNumber, netHeader buffer.View, vv buffer.VectorisedView, id TransportEndpointID) bool {
+func (d *transportDemuxer) deliverPacket(r *Route, protocol tcpip.TransportProtocolNumber, pkt tcpip.PacketBuffer, id TransportEndpointID) bool {
eps, ok := d.protocol[protocolIDs{r.NetProto, protocol}]
if !ok {
return false
}
- // If a sender bound to the Loopback interface sends a broadcast,
- // that broadcast must not be delivered to the sender.
- if loopbackSubnet.Contains(r.RemoteAddress) && r.LocalAddress == header.IPv4Broadcast && id.LocalPort == id.RemotePort {
- return false
- }
-
- // If the packet is a broadcast, then find all matching transport endpoints.
- // Otherwise, try to find a single matching transport endpoint.
- destEps := make([]TransportEndpoint, 0, 1)
eps.mu.RLock()
- if protocol == header.UDPProtocolNumber && id.LocalAddress == header.IPv4Broadcast {
- for epID, endpoint := range eps.endpoints {
- if epID.LocalPort == id.LocalPort {
- destEps = append(destEps, endpoint)
- }
- }
- } else if ep := d.findEndpointLocked(eps, vv, id); ep != nil {
+ // Determine which transport endpoint or endpoints to deliver this packet to.
+ // If the packet is a broadcast or multicast, then find all matching
+ // transport endpoints.
+ var destEps []*endpointsByNic
+ if protocol == header.UDPProtocolNumber && isMulticastOrBroadcast(id.LocalAddress) {
+ destEps = d.findAllEndpointsLocked(eps, id)
+ } else if ep := d.findEndpointLocked(eps, id); ep != nil {
destEps = append(destEps, ep)
}
@@ -297,17 +420,19 @@ func (d *transportDemuxer) deliverPacket(r *Route, protocol tcpip.TransportProto
return false
}
- // Deliver the packet.
- for _, ep := range destEps {
- ep.HandlePacket(r, id, vv)
+ // HandlePacket takes ownership of pkt, so each endpoint needs its own
+ // copy except for the final one.
+ for _, ep := range destEps[:len(destEps)-1] {
+ ep.handlePacket(r, id, pkt.Clone())
}
+ destEps[len(destEps)-1].handlePacket(r, id, pkt)
return true
}
// deliverRawPacket attempts to deliver the given packet and returns whether it
// was delivered successfully.
-func (d *transportDemuxer) deliverRawPacket(r *Route, protocol tcpip.TransportProtocolNumber, netHeader buffer.View, vv buffer.VectorisedView) bool {
+func (d *transportDemuxer) deliverRawPacket(r *Route, protocol tcpip.TransportProtocolNumber, pkt tcpip.PacketBuffer) bool {
eps, ok := d.protocol[protocolIDs{r.NetProto, protocol}]
if !ok {
return false
@@ -321,7 +446,7 @@ func (d *transportDemuxer) deliverRawPacket(r *Route, protocol tcpip.TransportPr
for _, rawEP := range eps.rawEndpoints {
// Each endpoint gets its own copy of the packet for the sake
// of save/restore.
- rawEP.HandlePacket(r, buffer.NewViewFromBytes(netHeader), vv.ToView().ToVectorisedView())
+ rawEP.HandlePacket(r, pkt)
foundRaw = true
}
eps.mu.RUnlock()
@@ -331,7 +456,7 @@ func (d *transportDemuxer) deliverRawPacket(r *Route, protocol tcpip.TransportPr
// deliverControlPacket attempts to deliver the given control packet. Returns
// true if it found an endpoint, false otherwise.
-func (d *transportDemuxer) deliverControlPacket(net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, typ ControlType, extra uint32, vv buffer.VectorisedView, id TransportEndpointID) bool {
+func (d *transportDemuxer) deliverControlPacket(n *NIC, net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, typ ControlType, extra uint32, pkt tcpip.PacketBuffer, id TransportEndpointID) bool {
eps, ok := d.protocol[protocolIDs{net, trans}]
if !ok {
return false
@@ -339,7 +464,7 @@ func (d *transportDemuxer) deliverControlPacket(net tcpip.NetworkProtocolNumber,
// Try to find the endpoint.
eps.mu.RLock()
- ep := d.findEndpointLocked(eps, vv, id)
+ ep := d.findEndpointLocked(eps, id)
eps.mu.RUnlock()
// Fail if we didn't find one.
@@ -348,15 +473,16 @@ func (d *transportDemuxer) deliverControlPacket(net tcpip.NetworkProtocolNumber,
}
// Deliver the packet.
- ep.HandleControlPacket(id, typ, extra, vv)
+ ep.handleControlPacket(n, id, typ, extra, pkt)
return true
}
-func (d *transportDemuxer) findEndpointLocked(eps *transportEndpoints, vv buffer.VectorisedView, id TransportEndpointID) TransportEndpoint {
+func (d *transportDemuxer) findAllEndpointsLocked(eps *transportEndpoints, id TransportEndpointID) []*endpointsByNic {
+ var matchedEPs []*endpointsByNic
// Try to find a match with the id as provided.
if ep, ok := eps.endpoints[id]; ok {
- return ep
+ matchedEPs = append(matchedEPs, ep)
}
// Try to find a match with the id minus the local address.
@@ -364,7 +490,7 @@ func (d *transportDemuxer) findEndpointLocked(eps *transportEndpoints, vv buffer
nid.LocalAddress = ""
if ep, ok := eps.endpoints[nid]; ok {
- return ep
+ matchedEPs = append(matchedEPs, ep)
}
// Try to find a match with the id minus the remote part.
@@ -372,15 +498,54 @@ func (d *transportDemuxer) findEndpointLocked(eps *transportEndpoints, vv buffer
nid.RemoteAddress = ""
nid.RemotePort = 0
if ep, ok := eps.endpoints[nid]; ok {
- return ep
+ matchedEPs = append(matchedEPs, ep)
}
// Try to find a match with only the local port.
nid.LocalAddress = ""
if ep, ok := eps.endpoints[nid]; ok {
- return ep
+ matchedEPs = append(matchedEPs, ep)
+ }
+ return matchedEPs
+}
+
+// findTransportEndpoint find a single endpoint that most closely matches the provided id.
+func (d *transportDemuxer) findTransportEndpoint(netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, id TransportEndpointID, r *Route) TransportEndpoint {
+ eps, ok := d.protocol[protocolIDs{netProto, transProto}]
+ if !ok {
+ return nil
+ }
+ // Try to find the endpoint.
+ eps.mu.RLock()
+ epsByNic := d.findEndpointLocked(eps, id)
+ // Fail if we didn't find one.
+ if epsByNic == nil {
+ eps.mu.RUnlock()
+ return nil
+ }
+
+ epsByNic.mu.RLock()
+ eps.mu.RUnlock()
+
+ mpep, ok := epsByNic.endpoints[r.ref.nic.ID()]
+ if !ok {
+ if mpep, ok = epsByNic.endpoints[0]; !ok {
+ epsByNic.mu.RUnlock() // Don't use defer for performance reasons.
+ return nil
+ }
}
+ ep := selectEndpoint(id, mpep, epsByNic.seed)
+ epsByNic.mu.RUnlock()
+ return ep
+}
+
+// findEndpointLocked returns the endpoint that most closely matches the given
+// id.
+func (d *transportDemuxer) findEndpointLocked(eps *transportEndpoints, id TransportEndpointID) *endpointsByNic {
+ if matchedEPs := d.findAllEndpointsLocked(eps, id); len(matchedEPs) > 0 {
+ return matchedEPs[0]
+ }
return nil
}
@@ -391,7 +556,7 @@ func (d *transportDemuxer) findEndpointLocked(eps *transportEndpoints, vv buffer
func (d *transportDemuxer) registerRawEndpoint(netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, ep RawTransportEndpoint) *tcpip.Error {
eps, ok := d.protocol[protocolIDs{netProto, transProto}]
if !ok {
- return nil
+ return tcpip.ErrNotSupported
}
eps.mu.Lock()
@@ -418,3 +583,7 @@ func (d *transportDemuxer) unregisterRawEndpoint(netProto tcpip.NetworkProtocolN
}
}
}
+
+func isMulticastOrBroadcast(addr tcpip.Address) bool {
+ return addr == header.IPv4Broadcast || header.IsV4MulticastAddress(addr) || header.IsV6MulticastAddress(addr)
+}
diff --git a/pkg/tcpip/stack/transport_demuxer_test.go b/pkg/tcpip/stack/transport_demuxer_test.go
new file mode 100644
index 000000000..3b28b06d0
--- /dev/null
+++ b/pkg/tcpip/stack/transport_demuxer_test.go
@@ -0,0 +1,354 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package stack_test
+
+import (
+ "math"
+ "math/rand"
+ "testing"
+
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/link/channel"
+ "gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
+ "gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
+ "gvisor.dev/gvisor/pkg/tcpip/transport/udp"
+ "gvisor.dev/gvisor/pkg/waiter"
+)
+
+const (
+ stackV6Addr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01"
+ testV6Addr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02"
+
+ stackAddr = "\x0a\x00\x00\x01"
+ stackPort = 1234
+ testPort = 4096
+)
+
+type testContext struct {
+ t *testing.T
+ linkEPs map[string]*channel.Endpoint
+ s *stack.Stack
+
+ ep tcpip.Endpoint
+ wq waiter.Queue
+}
+
+func (c *testContext) cleanup() {
+ if c.ep != nil {
+ c.ep.Close()
+ }
+}
+
+func (c *testContext) createV6Endpoint(v6only bool) {
+ var err *tcpip.Error
+ c.ep, err = c.s.NewEndpoint(udp.ProtocolNumber, ipv6.ProtocolNumber, &c.wq)
+ if err != nil {
+ c.t.Fatalf("NewEndpoint failed: %v", err)
+ }
+
+ var v tcpip.V6OnlyOption
+ if v6only {
+ v = 1
+ }
+ if err := c.ep.SetSockOpt(v); err != nil {
+ c.t.Fatalf("SetSockOpt failed: %v", err)
+ }
+}
+
+// newDualTestContextMultiNic creates the testing context and also linkEpNames
+// named NICs.
+func newDualTestContextMultiNic(t *testing.T, mtu uint32, linkEpNames []string) *testContext {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol(), ipv6.NewProtocol()},
+ TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()}})
+ linkEPs := make(map[string]*channel.Endpoint)
+ for i, linkEpName := range linkEpNames {
+ channelEP := channel.New(256, mtu, "")
+ nicID := tcpip.NICID(i + 1)
+ if err := s.CreateNamedNIC(nicID, linkEpName, channelEP); err != nil {
+ t.Fatalf("CreateNIC failed: %v", err)
+ }
+ linkEPs[linkEpName] = channelEP
+
+ if err := s.AddAddress(nicID, ipv4.ProtocolNumber, stackAddr); err != nil {
+ t.Fatalf("AddAddress IPv4 failed: %v", err)
+ }
+
+ if err := s.AddAddress(nicID, ipv6.ProtocolNumber, stackV6Addr); err != nil {
+ t.Fatalf("AddAddress IPv6 failed: %v", err)
+ }
+ }
+
+ s.SetRouteTable([]tcpip.Route{
+ {
+ Destination: header.IPv4EmptySubnet,
+ NIC: 1,
+ },
+ {
+ Destination: header.IPv6EmptySubnet,
+ NIC: 1,
+ },
+ })
+
+ return &testContext{
+ t: t,
+ s: s,
+ linkEPs: linkEPs,
+ }
+}
+
+type headers struct {
+ srcPort uint16
+ dstPort uint16
+}
+
+func newPayload() []byte {
+ b := make([]byte, 30+rand.Intn(100))
+ for i := range b {
+ b[i] = byte(rand.Intn(256))
+ }
+ return b
+}
+
+func (c *testContext) sendV6Packet(payload []byte, h *headers, linkEpName string) {
+ // Allocate a buffer for data and headers.
+ buf := buffer.NewView(header.UDPMinimumSize + header.IPv6MinimumSize + len(payload))
+ copy(buf[len(buf)-len(payload):], payload)
+
+ // Initialize the IP header.
+ ip := header.IPv6(buf)
+ ip.Encode(&header.IPv6Fields{
+ PayloadLength: uint16(header.UDPMinimumSize + len(payload)),
+ NextHeader: uint8(udp.ProtocolNumber),
+ HopLimit: 65,
+ SrcAddr: testV6Addr,
+ DstAddr: stackV6Addr,
+ })
+
+ // Initialize the UDP header.
+ u := header.UDP(buf[header.IPv6MinimumSize:])
+ u.Encode(&header.UDPFields{
+ SrcPort: h.srcPort,
+ DstPort: h.dstPort,
+ Length: uint16(header.UDPMinimumSize + len(payload)),
+ })
+
+ // Calculate the UDP pseudo-header checksum.
+ xsum := header.PseudoHeaderChecksum(udp.ProtocolNumber, testV6Addr, stackV6Addr, uint16(len(u)))
+
+ // Calculate the UDP checksum and set it.
+ xsum = header.Checksum(payload, xsum)
+ u.SetChecksum(^u.CalculateChecksum(xsum))
+
+ // Inject packet.
+ c.linkEPs[linkEpName].InjectInbound(ipv6.ProtocolNumber, tcpip.PacketBuffer{
+ Data: buf.ToVectorisedView(),
+ })
+}
+
+func TestTransportDemuxerRegister(t *testing.T) {
+ for _, test := range []struct {
+ name string
+ proto tcpip.NetworkProtocolNumber
+ want *tcpip.Error
+ }{
+ {"failure", ipv6.ProtocolNumber, tcpip.ErrUnknownProtocol},
+ {"success", ipv4.ProtocolNumber, nil},
+ } {
+ t.Run(test.name, func(t *testing.T) {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol()},
+ TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()}})
+ if got, want := s.RegisterTransportEndpoint(0, []tcpip.NetworkProtocolNumber{test.proto}, udp.ProtocolNumber, stack.TransportEndpointID{}, nil, false, 0), test.want; got != want {
+ t.Fatalf("s.RegisterTransportEndpoint(...) = %v, want %v", got, want)
+ }
+ })
+ }
+}
+
+// TestReuseBindToDevice injects varied packets on input devices and checks that
+// the distribution of packets received matches expectations.
+func TestDistribution(t *testing.T) {
+ type endpointSockopts struct {
+ reuse int
+ bindToDevice string
+ }
+ for _, test := range []struct {
+ name string
+ // endpoints will received the inject packets.
+ endpoints []endpointSockopts
+ // wantedDistribution is the wanted ratio of packets received on each
+ // endpoint for each NIC on which packets are injected.
+ wantedDistributions map[string][]float64
+ }{
+ {
+ "BindPortReuse",
+ // 5 endpoints that all have reuse set.
+ []endpointSockopts{
+ endpointSockopts{1, ""},
+ endpointSockopts{1, ""},
+ endpointSockopts{1, ""},
+ endpointSockopts{1, ""},
+ endpointSockopts{1, ""},
+ },
+ map[string][]float64{
+ // Injected packets on dev0 get distributed evenly.
+ "dev0": []float64{0.2, 0.2, 0.2, 0.2, 0.2},
+ },
+ },
+ {
+ "BindToDevice",
+ // 3 endpoints with various bindings.
+ []endpointSockopts{
+ endpointSockopts{0, "dev0"},
+ endpointSockopts{0, "dev1"},
+ endpointSockopts{0, "dev2"},
+ },
+ map[string][]float64{
+ // Injected packets on dev0 go only to the endpoint bound to dev0.
+ "dev0": []float64{1, 0, 0},
+ // Injected packets on dev1 go only to the endpoint bound to dev1.
+ "dev1": []float64{0, 1, 0},
+ // Injected packets on dev2 go only to the endpoint bound to dev2.
+ "dev2": []float64{0, 0, 1},
+ },
+ },
+ {
+ "ReuseAndBindToDevice",
+ // 6 endpoints with various bindings.
+ []endpointSockopts{
+ endpointSockopts{1, "dev0"},
+ endpointSockopts{1, "dev0"},
+ endpointSockopts{1, "dev1"},
+ endpointSockopts{1, "dev1"},
+ endpointSockopts{1, "dev1"},
+ endpointSockopts{1, ""},
+ },
+ map[string][]float64{
+ // Injected packets on dev0 get distributed among endpoints bound to
+ // dev0.
+ "dev0": []float64{0.5, 0.5, 0, 0, 0, 0},
+ // Injected packets on dev1 get distributed among endpoints bound to
+ // dev1 or unbound.
+ "dev1": []float64{0, 0, 1. / 3, 1. / 3, 1. / 3, 0},
+ // Injected packets on dev999 go only to the unbound.
+ "dev999": []float64{0, 0, 0, 0, 0, 1},
+ },
+ },
+ } {
+ t.Run(test.name, func(t *testing.T) {
+ for device, wantedDistribution := range test.wantedDistributions {
+ t.Run(device, func(t *testing.T) {
+ var devices []string
+ for d := range test.wantedDistributions {
+ devices = append(devices, d)
+ }
+ c := newDualTestContextMultiNic(t, defaultMTU, devices)
+ defer c.cleanup()
+
+ c.createV6Endpoint(false)
+
+ eps := make(map[tcpip.Endpoint]int)
+
+ pollChannel := make(chan tcpip.Endpoint)
+ for i, endpoint := range test.endpoints {
+ // Try to receive the data.
+ wq := waiter.Queue{}
+ we, ch := waiter.NewChannelEntry(nil)
+ wq.EventRegister(&we, waiter.EventIn)
+ defer wq.EventUnregister(&we)
+ defer close(ch)
+
+ var err *tcpip.Error
+ ep, err := c.s.NewEndpoint(udp.ProtocolNumber, ipv6.ProtocolNumber, &wq)
+ if err != nil {
+ c.t.Fatalf("NewEndpoint failed: %v", err)
+ }
+ eps[ep] = i
+
+ go func(ep tcpip.Endpoint) {
+ for range ch {
+ pollChannel <- ep
+ }
+ }(ep)
+
+ defer ep.Close()
+ reusePortOption := tcpip.ReusePortOption(endpoint.reuse)
+ if err := ep.SetSockOpt(reusePortOption); err != nil {
+ c.t.Fatalf("SetSockOpt(%#v) on endpoint %d failed: %v", reusePortOption, i, err)
+ }
+ bindToDeviceOption := tcpip.BindToDeviceOption(endpoint.bindToDevice)
+ if err := ep.SetSockOpt(bindToDeviceOption); err != nil {
+ c.t.Fatalf("SetSockOpt(%#v) on endpoint %d failed: %v", bindToDeviceOption, i, err)
+ }
+ if err := ep.Bind(tcpip.FullAddress{Addr: stackV6Addr, Port: stackPort}); err != nil {
+ t.Fatalf("ep.Bind(...) on endpoint %d failed: %v", i, err)
+ }
+ }
+
+ npackets := 100000
+ nports := 10000
+ if got, want := len(test.endpoints), len(wantedDistribution); got != want {
+ t.Fatalf("got len(test.endpoints) = %d, want %d", got, want)
+ }
+ ports := make(map[uint16]tcpip.Endpoint)
+ stats := make(map[tcpip.Endpoint]int)
+ for i := 0; i < npackets; i++ {
+ // Send a packet.
+ port := uint16(i % nports)
+ payload := newPayload()
+ c.sendV6Packet(payload,
+ &headers{
+ srcPort: testPort + port,
+ dstPort: stackPort},
+ device)
+
+ var addr tcpip.FullAddress
+ ep := <-pollChannel
+ _, _, err := ep.Read(&addr)
+ if err != nil {
+ c.t.Fatalf("Read on endpoint %d failed: %v", eps[ep], err)
+ }
+ stats[ep]++
+ if i < nports {
+ ports[uint16(i)] = ep
+ } else {
+ // Check that all packets from one client are handled by the same
+ // socket.
+ if want, got := ports[port], ep; want != got {
+ t.Fatalf("Packet sent on port %d expected on endpoint %d but received on endpoint %d", port, eps[want], eps[got])
+ }
+ }
+ }
+
+ // Check that a packet distribution is as expected.
+ for ep, i := range eps {
+ wantedRatio := wantedDistribution[i]
+ wantedRecv := wantedRatio * float64(npackets)
+ actualRecv := stats[ep]
+ actualRatio := float64(stats[ep]) / float64(npackets)
+ // The deviation is less than 10%.
+ if math.Abs(actualRatio-wantedRatio) > 0.05 {
+ t.Errorf("wanted about %.0f%% (%.0f of %d) packets to arrive on endpoint %d, got %.0f%% (%d of %d)", wantedRatio*100, wantedRecv, npackets, i, actualRatio*100, actualRecv, npackets)
+ }
+ }
+ })
+ }
+ })
+ }
+}
diff --git a/pkg/tcpip/stack/transport_test.go b/pkg/tcpip/stack/transport_test.go
index 56e8a5d9b..2cacea99a 100644
--- a/pkg/tcpip/stack/transport_test.go
+++ b/pkg/tcpip/stack/transport_test.go
@@ -38,19 +38,27 @@ const (
// Headers of this protocol are fakeTransHeaderLen bytes, but we currently don't
// use it.
type fakeTransportEndpoint struct {
- id stack.TransportEndpointID
+ stack.TransportEndpointInfo
stack *stack.Stack
- netProto tcpip.NetworkProtocolNumber
proto *fakeTransportProtocol
peerAddr tcpip.Address
route stack.Route
+ uniqueID uint64
// acceptQueue is non-nil iff bound.
acceptQueue []fakeTransportEndpoint
}
-func newFakeTransportEndpoint(stack *stack.Stack, proto *fakeTransportProtocol, netProto tcpip.NetworkProtocolNumber) tcpip.Endpoint {
- return &fakeTransportEndpoint{stack: stack, netProto: netProto, proto: proto}
+func (f *fakeTransportEndpoint) Info() tcpip.EndpointInfo {
+ return &f.TransportEndpointInfo
+}
+
+func (f *fakeTransportEndpoint) Stats() tcpip.EndpointStats {
+ return nil
+}
+
+func newFakeTransportEndpoint(s *stack.Stack, proto *fakeTransportProtocol, netProto tcpip.NetworkProtocolNumber, uniqueID uint64) tcpip.Endpoint {
+ return &fakeTransportEndpoint{stack: s, TransportEndpointInfo: stack.TransportEndpointInfo{NetProto: netProto}, proto: proto, uniqueID: uniqueID}
}
func (f *fakeTransportEndpoint) Close() {
@@ -75,7 +83,7 @@ func (f *fakeTransportEndpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions
if err != nil {
return 0, nil, err
}
- if err := f.route.WritePacket(nil /* gso */, hdr, buffer.View(v).ToVectorisedView(), fakeTransNumber, 123); err != nil {
+ if err := f.route.WritePacket(nil /* gso */, hdr, buffer.View(v).ToVectorisedView(), stack.NetworkHeaderParams{Protocol: fakeTransNumber, TTL: 123, TOS: stack.DefaultTOS}); err != nil {
return 0, nil, err
}
@@ -126,8 +134,8 @@ func (f *fakeTransportEndpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
defer r.Release()
// Try to register so that we can start receiving packets.
- f.id.RemoteAddress = addr.Addr
- err = f.stack.RegisterTransportEndpoint(0, []tcpip.NetworkProtocolNumber{fakeNetNumber}, fakeTransNumber, f.id, f, false)
+ f.ID.RemoteAddress = addr.Addr
+ err = f.stack.RegisterTransportEndpoint(0, []tcpip.NetworkProtocolNumber{fakeNetNumber}, fakeTransNumber, f.ID, f, false /* reuse */, 0 /* bindToDevice */)
if err != nil {
return err
}
@@ -137,6 +145,10 @@ func (f *fakeTransportEndpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
return nil
}
+func (f *fakeTransportEndpoint) UniqueID() uint64 {
+ return f.uniqueID
+}
+
func (f *fakeTransportEndpoint) ConnectEndpoint(e tcpip.Endpoint) *tcpip.Error {
return nil
}
@@ -168,7 +180,8 @@ func (f *fakeTransportEndpoint) Bind(a tcpip.FullAddress) *tcpip.Error {
fakeTransNumber,
stack.TransportEndpointID{LocalAddress: a.Addr},
f,
- false,
+ false, /* reuse */
+ 0, /* bindtoDevice */
); err != nil {
return err
}
@@ -184,14 +197,16 @@ func (*fakeTransportEndpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Erro
return tcpip.FullAddress{}, nil
}
-func (f *fakeTransportEndpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, _ buffer.VectorisedView) {
+func (f *fakeTransportEndpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, _ tcpip.PacketBuffer) {
// Increment the number of received packets.
f.proto.packetCount++
if f.acceptQueue != nil {
f.acceptQueue = append(f.acceptQueue, fakeTransportEndpoint{
- id: id,
- stack: f.stack,
- netProto: f.netProto,
+ stack: f.stack,
+ TransportEndpointInfo: stack.TransportEndpointInfo{
+ ID: f.ID,
+ NetProto: f.NetProto,
+ },
proto: f.proto,
peerAddr: r.RemoteAddress,
route: r.Clone(),
@@ -199,7 +214,7 @@ func (f *fakeTransportEndpoint) HandlePacket(r *stack.Route, id stack.TransportE
}
}
-func (f *fakeTransportEndpoint) HandleControlPacket(stack.TransportEndpointID, stack.ControlType, uint32, buffer.VectorisedView) {
+func (f *fakeTransportEndpoint) HandleControlPacket(stack.TransportEndpointID, stack.ControlType, uint32, tcpip.PacketBuffer) {
// Increment the number of received control packets.
f.proto.controlCount++
}
@@ -208,15 +223,15 @@ func (f *fakeTransportEndpoint) State() uint32 {
return 0
}
-func (f *fakeTransportEndpoint) ModerateRecvBuf(copied int) {
-}
+func (f *fakeTransportEndpoint) ModerateRecvBuf(copied int) {}
func (f *fakeTransportEndpoint) IPTables() (iptables.IPTables, error) {
return iptables.IPTables{}, nil
}
-func (f *fakeTransportEndpoint) Resume(*stack.Stack) {
-}
+func (f *fakeTransportEndpoint) Resume(*stack.Stack) {}
+
+func (f *fakeTransportEndpoint) Wait() {}
type fakeTransportGoodOption bool
@@ -241,7 +256,7 @@ func (*fakeTransportProtocol) Number() tcpip.TransportProtocolNumber {
}
func (f *fakeTransportProtocol) NewEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, _ *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
- return newFakeTransportEndpoint(stack, f, netProto), nil
+ return newFakeTransportEndpoint(stack, f, netProto, stack.UniqueID()), nil
}
func (f *fakeTransportProtocol) NewRawEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, _ *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
@@ -256,7 +271,7 @@ func (*fakeTransportProtocol) ParsePorts(buffer.View) (src, dst uint16, err *tcp
return 0, 0, nil
}
-func (*fakeTransportProtocol) HandleUnknownDestinationPacket(*stack.Route, stack.TransportEndpointID, buffer.View, buffer.VectorisedView) bool {
+func (*fakeTransportProtocol) HandleUnknownDestinationPacket(*stack.Route, stack.TransportEndpointID, tcpip.PacketBuffer) bool {
return true
}
@@ -327,7 +342,9 @@ func TestTransportReceive(t *testing.T) {
// Make sure packet with wrong protocol is not delivered.
buf[0] = 1
buf[2] = 0
- linkEP.Inject(fakeNetNumber, buf.ToVectorisedView())
+ linkEP.InjectInbound(fakeNetNumber, tcpip.PacketBuffer{
+ Data: buf.ToVectorisedView(),
+ })
if fakeTrans.packetCount != 0 {
t.Errorf("packetCount = %d, want %d", fakeTrans.packetCount, 0)
}
@@ -336,7 +353,9 @@ func TestTransportReceive(t *testing.T) {
buf[0] = 1
buf[1] = 3
buf[2] = byte(fakeTransNumber)
- linkEP.Inject(fakeNetNumber, buf.ToVectorisedView())
+ linkEP.InjectInbound(fakeNetNumber, tcpip.PacketBuffer{
+ Data: buf.ToVectorisedView(),
+ })
if fakeTrans.packetCount != 0 {
t.Errorf("packetCount = %d, want %d", fakeTrans.packetCount, 0)
}
@@ -345,7 +364,9 @@ func TestTransportReceive(t *testing.T) {
buf[0] = 1
buf[1] = 2
buf[2] = byte(fakeTransNumber)
- linkEP.Inject(fakeNetNumber, buf.ToVectorisedView())
+ linkEP.InjectInbound(fakeNetNumber, tcpip.PacketBuffer{
+ Data: buf.ToVectorisedView(),
+ })
if fakeTrans.packetCount != 1 {
t.Errorf("packetCount = %d, want %d", fakeTrans.packetCount, 1)
}
@@ -398,7 +419,9 @@ func TestTransportControlReceive(t *testing.T) {
buf[fakeNetHeaderLen+0] = 0
buf[fakeNetHeaderLen+1] = 1
buf[fakeNetHeaderLen+2] = 0
- linkEP.Inject(fakeNetNumber, buf.ToVectorisedView())
+ linkEP.InjectInbound(fakeNetNumber, tcpip.PacketBuffer{
+ Data: buf.ToVectorisedView(),
+ })
if fakeTrans.controlCount != 0 {
t.Errorf("controlCount = %d, want %d", fakeTrans.controlCount, 0)
}
@@ -407,7 +430,9 @@ func TestTransportControlReceive(t *testing.T) {
buf[fakeNetHeaderLen+0] = 3
buf[fakeNetHeaderLen+1] = 1
buf[fakeNetHeaderLen+2] = byte(fakeTransNumber)
- linkEP.Inject(fakeNetNumber, buf.ToVectorisedView())
+ linkEP.InjectInbound(fakeNetNumber, tcpip.PacketBuffer{
+ Data: buf.ToVectorisedView(),
+ })
if fakeTrans.controlCount != 0 {
t.Errorf("controlCount = %d, want %d", fakeTrans.controlCount, 0)
}
@@ -416,7 +441,9 @@ func TestTransportControlReceive(t *testing.T) {
buf[fakeNetHeaderLen+0] = 2
buf[fakeNetHeaderLen+1] = 1
buf[fakeNetHeaderLen+2] = byte(fakeTransNumber)
- linkEP.Inject(fakeNetNumber, buf.ToVectorisedView())
+ linkEP.InjectInbound(fakeNetNumber, tcpip.PacketBuffer{
+ Data: buf.ToVectorisedView(),
+ })
if fakeTrans.controlCount != 1 {
t.Errorf("controlCount = %d, want %d", fakeTrans.controlCount, 1)
}
@@ -569,7 +596,9 @@ func TestTransportForwarding(t *testing.T) {
req[0] = 1
req[1] = 3
req[2] = byte(fakeTransNumber)
- ep2.Inject(fakeNetNumber, req.ToVectorisedView())
+ ep2.InjectInbound(fakeNetNumber, tcpip.PacketBuffer{
+ Data: req.ToVectorisedView(),
+ })
aep, _, err := ep.Accept()
if err != nil || aep == nil {