summaryrefslogtreecommitdiffhomepage
path: root/pkg/tcpip/stack
diff options
context:
space:
mode:
Diffstat (limited to 'pkg/tcpip/stack')
-rw-r--r--pkg/tcpip/stack/ndp.go122
-rw-r--r--pkg/tcpip/stack/ndp_test.go315
-rw-r--r--pkg/tcpip/stack/nic.go30
-rw-r--r--pkg/tcpip/stack/registration.go33
-rw-r--r--pkg/tcpip/stack/stack.go151
-rw-r--r--pkg/tcpip/stack/stack_test.go46
-rw-r--r--pkg/tcpip/stack/transport_demuxer.go95
-rw-r--r--pkg/tcpip/stack/transport_test.go19
8 files changed, 696 insertions, 115 deletions
diff --git a/pkg/tcpip/stack/ndp.go b/pkg/tcpip/stack/ndp.go
index 921d1c9c7..03ddebdbd 100644
--- a/pkg/tcpip/stack/ndp.go
+++ b/pkg/tcpip/stack/ndp.go
@@ -51,6 +51,22 @@ const (
minimumRetransmitTimer = time.Millisecond
)
+// NDPDispatcher is the interface integrators of netstack must implement to
+// receive and handle NDP related events.
+type NDPDispatcher interface {
+ // OnDuplicateAddressDetectionStatus will be called when the DAD process
+ // for an address (addr) on a NIC (with ID nicid) completes. resolved
+ // will be set to true if DAD completed successfully (no duplicate addr
+ // detected); false otherwise (addr was detected to be a duplicate on
+ // the link the NIC is a part of, or it was stopped for some other
+ // reason, such as the address being removed). If an error occured
+ // during DAD, err will be set and resolved must be ignored.
+ //
+ // This function is permitted to block indefinitely without interfering
+ // with the stack's operation.
+ OnDuplicateAddressDetectionStatus(nicid tcpip.NICID, addr tcpip.Address, resolved bool, err *tcpip.Error)
+}
+
// NDPConfigurations is the NDP configurations for the netstack.
type NDPConfigurations struct {
// The number of Neighbor Solicitation messages to send when doing
@@ -88,6 +104,12 @@ func (c *NDPConfigurations) validate() {
// ndpState is the per-interface NDP state.
type ndpState struct {
+ // The NIC this ndpState is for.
+ nic *NIC
+
+ // configs is the per-interface NDP configurations.
+ configs NDPConfigurations
+
// The DAD state to send the next NS message, or resolve the address.
dad map[tcpip.Address]dadState
}
@@ -110,8 +132,8 @@ type dadState struct {
// This function must only be called by IPv6 addresses that are currently
// tentative.
//
-// The NIC that ndp belongs to (n) MUST be locked.
-func (ndp *ndpState) startDuplicateAddressDetection(n *NIC, addr tcpip.Address, ref *referencedNetworkEndpoint) *tcpip.Error {
+// The NIC that ndp belongs to MUST be locked.
+func (ndp *ndpState) startDuplicateAddressDetection(addr tcpip.Address, ref *referencedNetworkEndpoint) *tcpip.Error {
// addr must be a valid unicast IPv6 address.
if !header.IsV6UnicastAddress(addr) {
return tcpip.ErrAddressFamilyNotSupported
@@ -127,13 +149,13 @@ func (ndp *ndpState) startDuplicateAddressDetection(n *NIC, addr tcpip.Address,
// reference count would have been increased without doing the
// work that would have been done for an address that was brand
// new. See NIC.addPermanentAddressLocked.
- panic(fmt.Sprintf("ndpdad: already performing DAD for addr %s on NIC(%d)", addr, n.ID()))
+ panic(fmt.Sprintf("ndpdad: already performing DAD for addr %s on NIC(%d)", addr, ndp.nic.ID()))
}
- remaining := n.stack.ndpConfigs.DupAddrDetectTransmits
+ remaining := ndp.configs.DupAddrDetectTransmits
{
- done, err := ndp.doDuplicateAddressDetection(n, addr, remaining, ref)
+ done, err := ndp.doDuplicateAddressDetection(addr, remaining, ref)
if err != nil {
return err
}
@@ -146,42 +168,59 @@ func (ndp *ndpState) startDuplicateAddressDetection(n *NIC, addr tcpip.Address,
var done bool
var timer *time.Timer
- timer = time.AfterFunc(n.stack.ndpConfigs.RetransmitTimer, func() {
- n.mu.Lock()
- defer n.mu.Unlock()
+ timer = time.AfterFunc(ndp.configs.RetransmitTimer, func() {
+ var d bool
+ var err *tcpip.Error
+
+ // doDadIteration does a single iteration of the DAD loop.
+ //
+ // Returns true if the integrator needs to be informed of DAD
+ // completing.
+ doDadIteration := func() bool {
+ ndp.nic.mu.Lock()
+ defer ndp.nic.mu.Unlock()
+
+ if done {
+ // If we reach this point, it means that the DAD
+ // timer fired after another goroutine already
+ // obtained the NIC lock and stopped DAD before
+ // this function obtained the NIC lock. Simply
+ // return here and do nothing further.
+ return false
+ }
- if done {
- // If we reach this point, it means that the DAD timer
- // fired after another goroutine already obtained the
- // NIC lock and stopped DAD before it this function
- // obtained the NIC lock. Simply return here and do
- // nothing further.
- return
- }
+ ref, ok := ndp.nic.endpoints[NetworkEndpointID{addr}]
+ if !ok {
+ // This should never happen.
+ // We should have an endpoint for addr since we
+ // are still performing DAD on it. If the
+ // endpoint does not exist, but we are doing DAD
+ // on it, then we started DAD at some point, but
+ // forgot to stop it when the endpoint was
+ // deleted.
+ panic(fmt.Sprintf("ndpdad: unrecognized addr %s for NIC(%d)", addr, ndp.nic.ID()))
+ }
- ref, ok := n.endpoints[NetworkEndpointID{addr}]
- if !ok {
- // This should never happen.
- // We should have an endpoint for addr since we are
- // still performing DAD on it. If the endpoint does not
- // exist, but we are doing DAD on it, then we started
- // DAD at some point, but forgot to stop it when the
- // endpoint was deleted.
- panic(fmt.Sprintf("ndpdad: unrecognized addr %s for NIC(%d)", addr, n.ID()))
- }
+ d, err = ndp.doDuplicateAddressDetection(addr, remaining, ref)
+ if err != nil || d {
+ delete(ndp.dad, addr)
- if done, err := ndp.doDuplicateAddressDetection(n, addr, remaining, ref); err != nil || done {
- if err != nil {
- log.Printf("ndpdad: Error occured during DAD iteration for addr (%s) on NIC(%d); err = %s", addr, n.ID(), err)
+ if err != nil {
+ log.Printf("ndpdad: Error occured during DAD iteration for addr (%s) on NIC(%d); err = %s", addr, ndp.nic.ID(), err)
+ }
+
+ // Let the integrator know DAD has completed.
+ return true
}
- ndp.stopDuplicateAddressDetection(addr)
- return
+ remaining--
+ timer.Reset(ndp.nic.stack.ndpConfigs.RetransmitTimer)
+ return false
}
- timer.Reset(n.stack.ndpConfigs.RetransmitTimer)
- remaining--
-
+ if doDadIteration() && ndp.nic.stack.ndpDisp != nil {
+ ndp.nic.stack.ndpDisp.OnDuplicateAddressDetectionStatus(ndp.nic.ID(), addr, d, err)
+ }
})
ndp.dad[addr] = dadState{
@@ -204,11 +243,11 @@ func (ndp *ndpState) startDuplicateAddressDetection(n *NIC, addr tcpip.Address,
// The NIC that ndp belongs to (n) MUST be locked.
//
// Returns true if DAD has resolved; false if DAD is still ongoing.
-func (ndp *ndpState) doDuplicateAddressDetection(n *NIC, addr tcpip.Address, remaining uint8, ref *referencedNetworkEndpoint) (bool, *tcpip.Error) {
+func (ndp *ndpState) doDuplicateAddressDetection(addr tcpip.Address, remaining uint8, ref *referencedNetworkEndpoint) (bool, *tcpip.Error) {
if ref.getKind() != permanentTentative {
// The endpoint should still be marked as tentative
// since we are still performing DAD on it.
- panic(fmt.Sprintf("ndpdad: addr %s is not tentative on NIC(%d)", addr, n.ID()))
+ panic(fmt.Sprintf("ndpdad: addr %s is not tentative on NIC(%d)", addr, ndp.nic.ID()))
}
if remaining == 0 {
@@ -219,17 +258,17 @@ func (ndp *ndpState) doDuplicateAddressDetection(n *NIC, addr tcpip.Address, rem
// Send a new NS.
snmc := header.SolicitedNodeAddr(addr)
- snmcRef, ok := n.endpoints[NetworkEndpointID{snmc}]
+ snmcRef, ok := ndp.nic.endpoints[NetworkEndpointID{snmc}]
if !ok {
// This should never happen as if we have the
// address, we should have the solicited-node
// address.
- panic(fmt.Sprintf("ndpdad: NIC(%d) is not in the solicited-node multicast group (%s) but it has addr %s", n.ID(), snmc, addr))
+ panic(fmt.Sprintf("ndpdad: NIC(%d) is not in the solicited-node multicast group (%s) but it has addr %s", ndp.nic.ID(), snmc, addr))
}
// Use the unspecified address as the source address when performing
// DAD.
- r := makeRoute(header.IPv6ProtocolNumber, header.IPv6Any, snmc, n.linkEP.LinkAddress(), snmcRef, false, false)
+ r := makeRoute(header.IPv6ProtocolNumber, header.IPv6Any, snmc, ndp.nic.linkEP.LinkAddress(), snmcRef, false, false)
hdr := buffer.NewPrependable(int(r.MaxHeaderLength()) + header.ICMPv6NeighborSolicitMinimumSize)
pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6NeighborSolicitMinimumSize))
@@ -275,5 +314,8 @@ func (ndp *ndpState) stopDuplicateAddressDetection(addr tcpip.Address) {
delete(ndp.dad, addr)
- return
+ // Let the integrator know DAD did not resolve.
+ if ndp.nic.stack.ndpDisp != nil {
+ go ndp.nic.stack.ndpDisp.OnDuplicateAddressDetectionStatus(ndp.nic.ID(), addr, false, nil)
+ }
}
diff --git a/pkg/tcpip/stack/ndp_test.go b/pkg/tcpip/stack/ndp_test.go
index 8995fbfc3..525a25218 100644
--- a/pkg/tcpip/stack/ndp_test.go
+++ b/pkg/tcpip/stack/ndp_test.go
@@ -31,6 +31,7 @@ import (
const (
addr1 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01"
addr2 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02"
+ addr3 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x03"
linkAddr1 = "\x02\x02\x03\x04\x05\x06"
)
@@ -67,6 +68,35 @@ func TestDADDisabled(t *testing.T) {
}
}
+// ndpDADEvent is a set of parameters that was passed to
+// ndpDispatcher.OnDuplicateAddressDetectionStatus.
+type ndpDADEvent struct {
+ nicid tcpip.NICID
+ addr tcpip.Address
+ resolved bool
+ err *tcpip.Error
+}
+
+var _ stack.NDPDispatcher = (*ndpDispatcher)(nil)
+
+// ndpDispatcher implements NDPDispatcher so tests can know when various NDP
+// related events happen for test purposes.
+type ndpDispatcher struct {
+ dadC chan ndpDADEvent
+}
+
+// Implements stack.NDPDispatcher.OnDuplicateAddressDetectionStatus.
+//
+// If the DAD event matches what we are expecting, send signal on n.dadC.
+func (n *ndpDispatcher) OnDuplicateAddressDetectionStatus(nicid tcpip.NICID, addr tcpip.Address, resolved bool, err *tcpip.Error) {
+ n.dadC <- ndpDADEvent{
+ nicid,
+ addr,
+ resolved,
+ err,
+ }
+}
+
// TestDADResolve tests that an address successfully resolves after performing
// DAD for various values of DupAddrDetectTransmits and RetransmitTimer.
// Included in the subtests is a test to make sure that an invalid
@@ -88,8 +118,12 @@ func TestDADResolve(t *testing.T) {
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
+ ndpDisp := ndpDispatcher{
+ dadC: make(chan ndpDADEvent),
+ }
opts := stack.Options{
NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
+ NDPDisp: &ndpDisp,
}
opts.NDPConfigs.RetransmitTimer = test.retransTimer
opts.NDPConfigs.DupAddrDetectTransmits = test.dupAddrDetectTransmits
@@ -106,8 +140,7 @@ func TestDADResolve(t *testing.T) {
stat := s.Stats().ICMP.V6PacketsSent.NeighborSolicit
- // Should have sent an NDP NS almost immediately.
- time.Sleep(100 * time.Millisecond)
+ // Should have sent an NDP NS immediately.
if got := stat.Value(); got != 1 {
t.Fatalf("got NeighborSolicit = %d, want = 1", got)
@@ -123,16 +156,10 @@ func TestDADResolve(t *testing.T) {
t.Fatalf("got stack.GetMainNICAddress(_, _) = (%s, nil), want = (%s, nil)", addr, want)
}
- // Wait for the remaining time - 500ms, to make sure
- // the address is still not resolved. Note, we subtract
- // 600ms because we already waited for 100ms earlier,
- // so our remaining time is 100ms less than the expected
- // time.
- // (X - 100ms) - 500ms = X - 600ms
- //
- // TODO(b/140896005): Use events from the netstack to
- // be signalled before DAD resolves.
- time.Sleep(test.expectedRetransmitTimer*time.Duration(test.dupAddrDetectTransmits) - 600*time.Millisecond)
+ // Wait for the remaining time - some delta (500ms), to
+ // make sure the address is still not resolved.
+ const delta = 500 * time.Millisecond
+ time.Sleep(test.expectedRetransmitTimer*time.Duration(test.dupAddrDetectTransmits) - delta)
addr, err = s.GetMainNICAddress(1, header.IPv6ProtocolNumber)
if err != nil {
t.Fatalf("got stack.GetMainNICAddress(_, _) = (_, %v), want = (_, nil)", err)
@@ -141,13 +168,30 @@ func TestDADResolve(t *testing.T) {
t.Fatalf("got stack.GetMainNICAddress(_, _) = (%s, nil), want = (%s, nil)", addr, want)
}
- // Wait for the remaining time + 250ms, at which point
- // the address should be resolved. Note, the remaining
- // time is 500ms. See above comments.
- //
- // TODO(b/140896005): Use events from the netstack to
- // know immediately when DAD completes.
- time.Sleep(750 * time.Millisecond)
+ // Wait for DAD to resolve.
+ select {
+ case <-time.After(2 * delta):
+ // We should get a resolution event after 500ms
+ // (delta) since we wait for 500ms less than the
+ // expected resolution time above to make sure
+ // that the address did not yet resolve. Waiting
+ // for 1s (2x delta) without a resolution event
+ // means something is wrong.
+ t.Fatal("timed out waiting for DAD resolution")
+ case e := <-ndpDisp.dadC:
+ if e.err != nil {
+ t.Fatal("got DAD error: ", e.err)
+ }
+ if e.nicid != 1 {
+ t.Fatalf("got DAD event w/ nicid = %d, want = 1", e.nicid)
+ }
+ if e.addr != addr1 {
+ t.Fatalf("got DAD event w/ addr = %s, want = %s", addr, addr1)
+ }
+ if !e.resolved {
+ t.Fatal("got DAD event w/ resolved = false, want = true")
+ }
+ }
addr, err = s.GetMainNICAddress(1, header.IPv6ProtocolNumber)
if err != nil {
t.Fatalf("stack.GetMainNICAddress(_, _) err = %s", err)
@@ -250,9 +294,14 @@ func TestDADFail(t *testing.T) {
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
+ ndpDisp := ndpDispatcher{
+ dadC: make(chan ndpDADEvent),
+ }
+ ndpConfigs := stack.DefaultNDPConfigurations()
opts := stack.Options{
NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
- NDPConfigs: stack.DefaultNDPConfigurations(),
+ NDPConfigs: ndpConfigs,
+ NDPDisp: &ndpDisp,
}
opts.NDPConfigs.RetransmitTimer = time.Second * 2
@@ -286,8 +335,28 @@ func TestDADFail(t *testing.T) {
t.Fatalf("got stat = %d, want = 1", got)
}
- // Wait 3 seconds to make sure that DAD did not resolve
- time.Sleep(3 * time.Second)
+ // Wait for DAD to fail and make sure the address did
+ // not get resolved.
+ select {
+ case <-time.After(time.Duration(ndpConfigs.DupAddrDetectTransmits)*ndpConfigs.RetransmitTimer + time.Second):
+ // If we don't get a failure event after the
+ // expected resolution time + extra 1s buffer,
+ // something is wrong.
+ t.Fatal("timed out waiting for DAD failure")
+ case e := <-ndpDisp.dadC:
+ if e.err != nil {
+ t.Fatal("got DAD error: ", e.err)
+ }
+ if e.nicid != 1 {
+ t.Fatalf("got DAD event w/ nicid = %d, want = 1", e.nicid)
+ }
+ if e.addr != addr1 {
+ t.Fatalf("got DAD event w/ addr = %s, want = %s", addr, addr1)
+ }
+ if e.resolved {
+ t.Fatal("got DAD event w/ resolved = true, want = false")
+ }
+ }
addr, err = s.GetMainNICAddress(1, header.IPv6ProtocolNumber)
if err != nil {
t.Fatalf("got stack.GetMainNICAddress(_, _) = (_, %v), want = (_, nil)", err)
@@ -302,11 +371,18 @@ func TestDADFail(t *testing.T) {
// TestDADStop tests to make sure that the DAD process stops when an address is
// removed.
func TestDADStop(t *testing.T) {
+ ndpDisp := ndpDispatcher{
+ dadC: make(chan ndpDADEvent),
+ }
+ ndpConfigs := stack.NDPConfigurations{
+ RetransmitTimer: time.Second,
+ DupAddrDetectTransmits: 2,
+ }
opts := stack.Options{
NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
+ NDPDisp: &ndpDisp,
+ NDPConfigs: ndpConfigs,
}
- opts.NDPConfigs.RetransmitTimer = time.Second
- opts.NDPConfigs.DupAddrDetectTransmits = 2
e := channel.New(10, 1280, linkAddr1)
s := stack.New(opts)
@@ -332,11 +408,27 @@ func TestDADStop(t *testing.T) {
t.Fatalf("RemoveAddress(_, %s) = %s", addr1, err)
}
- // Wait for the time to normally resolve
- // DupAddrDetectTransmits(2) * RetransmitTimer(1s) = 2s.
- // An extra 250ms is added to make sure that if DAD was still running
- // it resolves and the check below fails.
- time.Sleep(2*time.Second + 250*time.Millisecond)
+ // Wait for DAD to fail (since the address was removed during DAD).
+ select {
+ case <-time.After(time.Duration(ndpConfigs.DupAddrDetectTransmits)*ndpConfigs.RetransmitTimer + time.Second):
+ // If we don't get a failure event after the expected resolution
+ // time + extra 1s buffer, something is wrong.
+ t.Fatal("timed out waiting for DAD failure")
+ case e := <-ndpDisp.dadC:
+ if e.err != nil {
+ t.Fatal("got DAD error: ", e.err)
+ }
+ if e.nicid != 1 {
+ t.Fatalf("got DAD event w/ nicid = %d, want = 1", e.nicid)
+ }
+ if e.addr != addr1 {
+ t.Fatalf("got DAD event w/ addr = %s, want = %s", addr, addr1)
+ }
+ if e.resolved {
+ t.Fatal("got DAD event w/ resolved = true, want = false")
+ }
+
+ }
addr, err = s.GetMainNICAddress(1, header.IPv6ProtocolNumber)
if err != nil {
t.Fatalf("got stack.GetMainNICAddress(_, _) = (_, %v), want = (_, nil)", err)
@@ -350,3 +442,168 @@ func TestDADStop(t *testing.T) {
t.Fatalf("got NeighborSolicit = %d, want <= 1", got)
}
}
+
+// TestSetNDPConfigurationFailsForBadNICID tests to make sure we get an error if
+// we attempt to update NDP configurations using an invalid NICID.
+func TestSetNDPConfigurationFailsForBadNICID(t *testing.T) {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
+ })
+
+ // No NIC with ID 1 yet.
+ if got := s.SetNDPConfigurations(1, stack.NDPConfigurations{}); got != tcpip.ErrUnknownNICID {
+ t.Fatalf("got s.SetNDPConfigurations = %v, want = %s", got, tcpip.ErrUnknownNICID)
+ }
+}
+
+// TestSetNDPConfigurations tests that we can update and use per-interface NDP
+// configurations without affecting the default NDP configurations or other
+// interfaces' configurations.
+func TestSetNDPConfigurations(t *testing.T) {
+ tests := []struct {
+ name string
+ dupAddrDetectTransmits uint8
+ retransmitTimer time.Duration
+ expectedRetransmitTimer time.Duration
+ }{
+ {
+ "OK",
+ 1,
+ time.Second,
+ time.Second,
+ },
+ {
+ "Invalid Retransmit Timer",
+ 1,
+ 0,
+ time.Second,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ ndpDisp := ndpDispatcher{
+ dadC: make(chan ndpDADEvent),
+ }
+ e := channel.New(10, 1280, linkAddr1)
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
+ NDPDisp: &ndpDisp,
+ })
+
+ // This NIC(1)'s NDP configurations will be updated to
+ // be different from the default.
+ if err := s.CreateNIC(1, e); err != nil {
+ t.Fatalf("CreateNIC(1) = %s", err)
+ }
+
+ // Created before updating NIC(1)'s NDP configurations
+ // but updating NIC(1)'s NDP configurations should not
+ // affect other existing NICs.
+ if err := s.CreateNIC(2, e); err != nil {
+ t.Fatalf("CreateNIC(2) = %s", err)
+ }
+
+ // Update the NDP configurations on NIC(1) to use DAD.
+ configs := stack.NDPConfigurations{
+ DupAddrDetectTransmits: test.dupAddrDetectTransmits,
+ RetransmitTimer: test.retransmitTimer,
+ }
+ if err := s.SetNDPConfigurations(1, configs); err != nil {
+ t.Fatalf("got SetNDPConfigurations(1, _) = %s", err)
+ }
+
+ // Created after updating NIC(1)'s NDP configurations
+ // but the stack's default NDP configurations should not
+ // have been updated.
+ if err := s.CreateNIC(3, e); err != nil {
+ t.Fatalf("CreateNIC(3) = %s", err)
+ }
+
+ // Add addresses for each NIC.
+ if err := s.AddAddress(1, header.IPv6ProtocolNumber, addr1); err != nil {
+ t.Fatalf("AddAddress(1, %d, %s) = %s", header.IPv6ProtocolNumber, addr1, err)
+ }
+ if err := s.AddAddress(2, header.IPv6ProtocolNumber, addr2); err != nil {
+ t.Fatalf("AddAddress(2, %d, %s) = %s", header.IPv6ProtocolNumber, addr2, err)
+ }
+ if err := s.AddAddress(3, header.IPv6ProtocolNumber, addr3); err != nil {
+ t.Fatalf("AddAddress(3, %d, %s) = %s", header.IPv6ProtocolNumber, addr3, err)
+ }
+
+ // Address should not be considered bound to NIC(1) yet
+ // (DAD ongoing).
+ addr, err := s.GetMainNICAddress(1, header.IPv6ProtocolNumber)
+ if err != nil {
+ t.Fatalf("got stack.GetMainNICAddress(_, _) = (_, %v), want = (_, nil)", err)
+ }
+ if want := (tcpip.AddressWithPrefix{}); addr != want {
+ t.Fatalf("got stack.GetMainNICAddress(_, _) = (%s, nil), want = (%s, nil)", addr, want)
+ }
+
+ // Should get the address on NIC(2) and NIC(3)
+ // immediately since we should not have performed DAD on
+ // it as the stack was configured to not do DAD by
+ // default and we only updated the NDP configurations on
+ // NIC(1).
+ addr, err = s.GetMainNICAddress(2, header.IPv6ProtocolNumber)
+ if err != nil {
+ t.Fatalf("stack.GetMainNICAddress(2, _) err = %s", err)
+ }
+ if addr.Address != addr2 {
+ t.Fatalf("got stack.GetMainNICAddress(2, _) = %s, want = %s", addr, addr2)
+ }
+ addr, err = s.GetMainNICAddress(3, header.IPv6ProtocolNumber)
+ if err != nil {
+ t.Fatalf("stack.GetMainNICAddress(3, _) err = %s", err)
+ }
+ if addr.Address != addr3 {
+ t.Fatalf("got stack.GetMainNICAddress(3, _) = %s, want = %s", addr, addr3)
+ }
+
+ // Sleep until right (500ms before) before resolution to
+ // make sure the address didn't resolve on NIC(1) yet.
+ const delta = 500 * time.Millisecond
+ time.Sleep(time.Duration(test.dupAddrDetectTransmits)*test.expectedRetransmitTimer - delta)
+ addr, err = s.GetMainNICAddress(1, header.IPv6ProtocolNumber)
+ if err != nil {
+ t.Fatalf("got stack.GetMainNICAddress(_, _) = (_, %v), want = (_, nil)", err)
+ }
+ if want := (tcpip.AddressWithPrefix{}); addr != want {
+ t.Fatalf("got stack.GetMainNICAddress(_, _) = (%s, nil), want = (%s, nil)", addr, want)
+ }
+
+ // Wait for DAD to resolve.
+ select {
+ case <-time.After(2 * delta):
+ // We should get a resolution event after 500ms
+ // (delta) since we wait for 500ms less than the
+ // expected resolution time above to make sure
+ // that the address did not yet resolve. Waiting
+ // for 1s (2x delta) without a resolution event
+ // means something is wrong.
+ t.Fatal("timed out waiting for DAD resolution")
+ case e := <-ndpDisp.dadC:
+ if e.err != nil {
+ t.Fatal("got DAD error: ", e.err)
+ }
+ if e.nicid != 1 {
+ t.Fatalf("got DAD event w/ nicid = %d, want = 1", e.nicid)
+ }
+ if e.addr != addr1 {
+ t.Fatalf("got DAD event w/ addr = %s, want = %s", addr, addr1)
+ }
+ if !e.resolved {
+ t.Fatal("got DAD event w/ resolved = false, want = true")
+ }
+ }
+ addr, err = s.GetMainNICAddress(1, header.IPv6ProtocolNumber)
+ if err != nil {
+ t.Fatalf("stack.GetMainNICAddress(1, _) err = %s", err)
+ }
+ if addr.Address != addr1 {
+ t.Fatalf("got stack.GetMainNICAddress(1, _) = %s, want = %s", addr, addr1)
+ }
+ })
+ }
+}
diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go
index e456e05f4..fe8f83d58 100644
--- a/pkg/tcpip/stack/nic.go
+++ b/pkg/tcpip/stack/nic.go
@@ -46,6 +46,10 @@ type NIC struct {
stats NICStats
+ // ndp is the NDP related state for NIC.
+ //
+ // Note, read and write operations on ndp require that the NIC is
+ // appropriately locked.
ndp ndpState
}
@@ -80,10 +84,16 @@ const (
NeverPrimaryEndpoint
)
+// newNIC returns a new NIC using the default NDP configurations from stack.
func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint, loopback bool) *NIC {
// TODO(b/141011931): Validate a LinkEndpoint (ep) is valid. For
// example, make sure that the link address it provides is a valid
// unicast ethernet address.
+
+ // TODO(b/143357959): RFC 8200 section 5 requires that IPv6 endpoints
+ // observe an MTU of at least 1280 bytes. Ensure that this requirement
+ // of IPv6 is supported on this endpoint's LinkEndpoint.
+
nic := &NIC{
stack: stack,
id: id,
@@ -105,9 +115,11 @@ func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint, loopback
},
},
ndp: ndpState{
- dad: make(map[tcpip.Address]dadState),
+ configs: stack.ndpConfigs,
+ dad: make(map[tcpip.Address]dadState),
},
}
+ nic.ndp.nic = nic
// Register supported packet endpoint protocols.
for _, netProto := range header.Ethertypes {
@@ -432,7 +444,7 @@ func (n *NIC) addAddressLocked(protocolAddress tcpip.ProtocolAddress, peb Primar
// If we are adding a tentative IPv6 address, start DAD.
if isIPv6Unicast && kind == permanentTentative {
- if err := n.ndp.startDuplicateAddressDetection(n, protocolAddress.AddressWithPrefix.Address, ref); err != nil {
+ if err := n.ndp.startDuplicateAddressDetection(protocolAddress.AddressWithPrefix.Address, ref); err != nil {
return nil, err
}
}
@@ -750,7 +762,7 @@ func (n *NIC) DeliverNetworkPacket(linkEP LinkEndpoint, remote, local tcpip.Link
}
n.mu.RUnlock()
for _, ep := range packetEPs {
- ep.HandlePacket(n.id, local, protocol, vv, linkHeader)
+ ep.HandlePacket(n.id, local, protocol, vv.Clone(nil), linkHeader)
}
if netProto.Number() == header.IPv4ProtocolNumber || netProto.Number() == header.IPv6ProtocolNumber {
@@ -936,6 +948,18 @@ func (n *NIC) dupTentativeAddrDetected(addr tcpip.Address) *tcpip.Error {
return n.removePermanentAddressLocked(addr)
}
+// setNDPConfigs sets the NDP configurations for n.
+//
+// Note, if c contains invalid NDP configuration values, it will be fixed to
+// use default values for the erroneous values.
+func (n *NIC) setNDPConfigs(c NDPConfigurations) {
+ c.validate()
+
+ n.mu.Lock()
+ n.ndp.configs = c
+ n.mu.Unlock()
+}
+
type networkEndpointKind int32
const (
diff --git a/pkg/tcpip/stack/registration.go b/pkg/tcpip/stack/registration.go
index 0869fb084..d7c124e81 100644
--- a/pkg/tcpip/stack/registration.go
+++ b/pkg/tcpip/stack/registration.go
@@ -60,13 +60,34 @@ const (
// TransportEndpoint is the interface that needs to be implemented by transport
// protocol (e.g., tcp, udp) endpoints that can handle packets.
type TransportEndpoint interface {
+ // UniqueID returns an unique ID for this transport endpoint.
+ UniqueID() uint64
+
// HandlePacket is called by the stack when new packets arrive to
// this transport endpoint.
+ //
+ // HandlePacket takes ownership of vv.
HandlePacket(r *Route, id TransportEndpointID, vv buffer.VectorisedView)
// HandleControlPacket is called by the stack when new control (e.g.,
// ICMP) packets arrive to this transport endpoint.
+ //
+ // HandleControlPacket takes ownership of vv.
HandleControlPacket(id TransportEndpointID, typ ControlType, extra uint32, vv buffer.VectorisedView)
+
+ // Close puts the endpoint in a closed state and frees all resources
+ // associated with it. This cleanup may happen asynchronously. Wait can
+ // be used to block on this asynchronous cleanup.
+ Close()
+
+ // Wait waits for any worker goroutines owned by the endpoint to stop.
+ //
+ // An endpoint can be requested to stop its worker goroutines by calling
+ // its Close method.
+ //
+ // Wait will not block if the endpoint hasn't started any goroutines
+ // yet, even if it might later.
+ Wait()
}
// RawTransportEndpoint is the interface that needs to be implemented by raw
@@ -77,6 +98,8 @@ type RawTransportEndpoint interface {
// HandlePacket is called by the stack when new packets arrive to
// this transport endpoint. The packet contains all data from the link
// layer up.
+ //
+ // HandlePacket takes ownership of packet and netHeader.
HandlePacket(r *Route, netHeader buffer.View, packet buffer.VectorisedView)
}
@@ -93,6 +116,8 @@ type PacketEndpoint interface {
//
// linkHeader may have a length of 0, in which case the PacketEndpoint
// should construct its own ethernet header for applications.
+ //
+ // HandlePacket takes ownership of packet and linkHeader.
HandlePacket(nicid tcpip.NICID, addr tcpip.LinkAddress, netProto tcpip.NetworkProtocolNumber, packet buffer.VectorisedView, linkHeader buffer.View)
}
@@ -143,10 +168,14 @@ type TransportDispatcher interface {
// DeliverTransportPacket delivers packets to the appropriate
// transport protocol endpoint. It also returns the network layer
// header for the enpoint to inspect or pass up the stack.
+ //
+ // DeliverTransportPacket takes ownership of vv and netHeader.
DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolNumber, netHeader buffer.View, vv buffer.VectorisedView)
// DeliverTransportControlPacket delivers control packets to the
// appropriate transport protocol endpoint.
+ //
+ // DeliverTransportControlPacket takes ownership of vv.
DeliverTransportControlPacket(local, remote tcpip.Address, net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, typ ControlType, extra uint32, vv buffer.VectorisedView)
}
@@ -220,6 +249,8 @@ type NetworkEndpoint interface {
// HandlePacket is called by the link layer when new packets arrive to
// this network endpoint.
+ //
+ // HandlePacket takes ownership of vv.
HandlePacket(r *Route, vv buffer.VectorisedView)
// Close is called when the endpoint is reomved from a stack.
@@ -265,6 +296,8 @@ type NetworkDispatcher interface {
// DeliverNetworkPacket finds the appropriate network protocol endpoint
// and hands the packet over for further processing. linkHeader may have
// length 0 when the caller does not have ethernet data.
+ //
+ // DeliverNetworkPacket takes ownership of vv and linkHeader.
DeliverNetworkPacket(linkEP LinkEndpoint, remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, vv buffer.VectorisedView, linkHeader buffer.View)
}
diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go
index 284280917..115a6fcb8 100644
--- a/pkg/tcpip/stack/stack.go
+++ b/pkg/tcpip/stack/stack.go
@@ -22,6 +22,7 @@ package stack
import (
"encoding/binary"
"sync"
+ "sync/atomic"
"time"
"golang.org/x/time/rate"
@@ -344,6 +345,13 @@ type ResumableEndpoint interface {
Resume(*Stack)
}
+// uniqueIDGenerator is a default unique ID generator.
+type uniqueIDGenerator uint64
+
+func (u *uniqueIDGenerator) UniqueID() uint64 {
+ return atomic.AddUint64((*uint64)(u), 1)
+}
+
// Stack is a networking stack, with all supported protocols, NICs, and route
// table.
type Stack struct {
@@ -361,9 +369,10 @@ type Stack struct {
linkAddrCache *linkAddrCache
- mu sync.RWMutex
- nics map[tcpip.NICID]*NIC
- forwarding bool
+ mu sync.RWMutex
+ nics map[tcpip.NICID]*NIC
+ forwarding bool
+ cleanupEndpoints map[TransportEndpoint]struct{}
// route is the route table passed in by the user via SetRouteTable(),
// it is used by FindRoute() to build a route for a specific
@@ -399,13 +408,25 @@ type Stack struct {
// TODO(gvisor.dev/issue/940): S/R this field.
portSeed uint32
- // ndpConfigs is the NDP configurations used by interfaces.
+ // ndpConfigs is the default NDP configurations used by interfaces.
ndpConfigs NDPConfigurations
// autoGenIPv6LinkLocal determines whether or not the stack will attempt
// to auto-generate an IPv6 link-local address for newly enabled NICs.
// See the AutoGenIPv6LinkLocal field of Options for more details.
autoGenIPv6LinkLocal bool
+
+ // ndpDisp is the NDP event dispatcher that is used to send the netstack
+ // integrator NDP related events.
+ ndpDisp NDPDispatcher
+
+ // uniqueIDGenerator is a generator of unique identifiers.
+ uniqueIDGenerator UniqueID
+}
+
+// UniqueID is an abstract generator of unique identifiers.
+type UniqueID interface {
+ UniqueID() uint64
}
// Options contains optional Stack configuration.
@@ -429,7 +450,10 @@ type Options struct {
// stack (false).
HandleLocal bool
- // NDPConfigs is the NDP configurations used by interfaces.
+ // UniqueID is an optional generator of unique identifiers.
+ UniqueID UniqueID
+
+ // NDPConfigs is the default NDP configurations used by interfaces.
//
// By default, NDPConfigs will have a zero value for its
// DupAddrDetectTransmits field, implying that DAD will not be performed
@@ -448,6 +472,10 @@ type Options struct {
// guidelines.
AutoGenIPv6LinkLocal bool
+ // NDPDisp is the NDP event dispatcher that an integrator can provide to
+ // receive NDP related events.
+ NDPDisp NDPDispatcher
+
// RawFactory produces raw endpoints. Raw endpoints are enabled only if
// this is non-nil.
RawFactory RawFactory
@@ -497,6 +525,10 @@ func New(opts Options) *Stack {
clock = &tcpip.StdClock{}
}
+ if opts.UniqueID == nil {
+ opts.UniqueID = new(uniqueIDGenerator)
+ }
+
// Make sure opts.NDPConfigs contains valid values only.
opts.NDPConfigs.validate()
@@ -505,6 +537,7 @@ func New(opts Options) *Stack {
networkProtocols: make(map[tcpip.NetworkProtocolNumber]NetworkProtocol),
linkAddrResolvers: make(map[tcpip.NetworkProtocolNumber]LinkAddressResolver),
nics: make(map[tcpip.NICID]*NIC),
+ cleanupEndpoints: make(map[TransportEndpoint]struct{}),
linkAddrCache: newLinkAddrCache(ageLimit, resolutionTimeout, resolutionAttempts),
PortManager: ports.NewPortManager(),
clock: clock,
@@ -514,6 +547,8 @@ func New(opts Options) *Stack {
portSeed: generateRandUint32(),
ndpConfigs: opts.NDPConfigs,
autoGenIPv6LinkLocal: opts.AutoGenIPv6LinkLocal,
+ uniqueIDGenerator: opts.UniqueID,
+ ndpDisp: opts.NDPDisp,
}
// Add specified network protocols.
@@ -540,6 +575,11 @@ func New(opts Options) *Stack {
return s
}
+// UniqueID returns a unique identifier.
+func (s *Stack) UniqueID() uint64 {
+ return s.uniqueIDGenerator.UniqueID()
+}
+
// SetNetworkProtocolOption allows configuring individual protocol level
// options. This method returns an error if the protocol is not supported or
// option is not supported by the protocol implementation or the provided value
@@ -1127,6 +1167,25 @@ func (s *Stack) UnregisterTransportEndpoint(nicID tcpip.NICID, netProtos []tcpip
s.demux.unregisterEndpoint(netProtos, protocol, id, ep, bindToDevice)
}
+// StartTransportEndpointCleanup removes the endpoint with the given id from
+// the stack transport dispatcher. It also transitions it to the cleanup stage.
+func (s *Stack) StartTransportEndpointCleanup(nicID tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, bindToDevice tcpip.NICID) {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+
+ s.cleanupEndpoints[ep] = struct{}{}
+
+ s.demux.unregisterEndpoint(netProtos, protocol, id, ep, bindToDevice)
+}
+
+// CompleteTransportEndpointCleanup removes the endpoint from the cleanup
+// stage.
+func (s *Stack) CompleteTransportEndpointCleanup(ep TransportEndpoint) {
+ s.mu.Lock()
+ delete(s.cleanupEndpoints, ep)
+ s.mu.Unlock()
+}
+
// RegisterRawTransportEndpoint registers the given endpoint with the stack
// transport dispatcher. Received packets that match the provided transport
// protocol will be delivered to the given endpoint.
@@ -1148,6 +1207,69 @@ func (s *Stack) RegisterRestoredEndpoint(e ResumableEndpoint) {
s.mu.Unlock()
}
+// RegisteredEndpoints returns all endpoints which are currently registered.
+func (s *Stack) RegisteredEndpoints() []TransportEndpoint {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+ var es []TransportEndpoint
+ for _, e := range s.demux.protocol {
+ es = append(es, e.transportEndpoints()...)
+ }
+ return es
+}
+
+// CleanupEndpoints returns endpoints currently in the cleanup state.
+func (s *Stack) CleanupEndpoints() []TransportEndpoint {
+ s.mu.Lock()
+ es := make([]TransportEndpoint, 0, len(s.cleanupEndpoints))
+ for e := range s.cleanupEndpoints {
+ es = append(es, e)
+ }
+ s.mu.Unlock()
+ return es
+}
+
+// RestoreCleanupEndpoints adds endpoints to cleanup tracking. This is useful
+// for restoring a stack after a save.
+func (s *Stack) RestoreCleanupEndpoints(es []TransportEndpoint) {
+ s.mu.Lock()
+ for _, e := range es {
+ s.cleanupEndpoints[e] = struct{}{}
+ }
+ s.mu.Unlock()
+}
+
+// Close closes all currently registered transport endpoints.
+//
+// Endpoints created or modified during this call may not get closed.
+func (s *Stack) Close() {
+ for _, e := range s.RegisteredEndpoints() {
+ e.Close()
+ }
+}
+
+// Wait waits for all transport and link endpoints to halt their worker
+// goroutines.
+//
+// Endpoints created or modified during this call may not get waited on.
+//
+// Note that link endpoints must be stopped via an implementation specific
+// mechanism.
+func (s *Stack) Wait() {
+ for _, e := range s.RegisteredEndpoints() {
+ e.Wait()
+ }
+ for _, e := range s.CleanupEndpoints() {
+ e.Wait()
+ }
+
+ s.mu.RLock()
+ defer s.mu.RUnlock()
+ for _, n := range s.nics {
+ n.linkEP.Wait()
+ }
+}
+
// Resume restarts the stack after a restore. This must be called after the
// entire system has been restored.
func (s *Stack) Resume() {
@@ -1416,6 +1538,25 @@ func (s *Stack) DupTentativeAddrDetected(id tcpip.NICID, addr tcpip.Address) *tc
return nic.dupTentativeAddrDetected(addr)
}
+// SetNDPConfigurations sets the per-interface NDP configurations on the NIC
+// with ID id to c.
+//
+// Note, if c contains invalid NDP configuration values, it will be fixed to
+// use default values for the erroneous values.
+func (s *Stack) SetNDPConfigurations(id tcpip.NICID, c NDPConfigurations) *tcpip.Error {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+
+ nic, ok := s.nics[id]
+ if !ok {
+ return tcpip.ErrUnknownNICID
+ }
+
+ nic.setNDPConfigs(c)
+
+ return nil
+}
+
// PortSeed returns a 32 bit value that can be used as a seed value for port
// picking.
//
diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go
index 9a8906a0d..9dae853d0 100644
--- a/pkg/tcpip/stack/stack_test.go
+++ b/pkg/tcpip/stack/stack_test.go
@@ -1971,13 +1971,15 @@ func TestNICAutoGenAddr(t *testing.T) {
// TestNICAutoGenAddrDoesDAD tests that the successful auto-generation of IPv6
// link-local addresses will only be assigned after the DAD process resolves.
func TestNICAutoGenAddrDoesDAD(t *testing.T) {
+ ndpDisp := ndpDispatcher{
+ dadC: make(chan ndpDADEvent),
+ }
+ ndpConfigs := stack.DefaultNDPConfigurations()
opts := stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
- NDPConfigs: stack.NDPConfigurations{
- RetransmitTimer: time.Second,
- DupAddrDetectTransmits: 1,
- },
+ NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
+ NDPConfigs: ndpConfigs,
AutoGenIPv6LinkLocal: true,
+ NDPDisp: &ndpDisp,
}
e := channel.New(10, 1280, linkAddr1)
@@ -1996,21 +1998,35 @@ func TestNICAutoGenAddrDoesDAD(t *testing.T) {
t.Fatalf("got stack.GetMainNICAddress(_, _) = (%s, nil), want = (%s, nil)", addr, want)
}
- // Wait for the address to resolve (an extra
- // 250ms to make sure the address resolves).
- //
- // TODO(b/140896005): Use events from the
- // netstack to know immediately when DAD
- // completes.
- time.Sleep(time.Second + 250*time.Millisecond)
+ linkLocalAddr := header.LinkLocalAddr(linkAddr1)
- // Should have auto-generated an address and
- // resolved (if DAD).
+ // Wait for DAD to resolve.
+ select {
+ case <-time.After(time.Duration(ndpConfigs.DupAddrDetectTransmits)*ndpConfigs.RetransmitTimer + time.Second):
+ // We should get a resolution event after 1s (default time to
+ // resolve as per default NDP configurations). Waiting for that
+ // resolution time + an extra 1s without a resolution event
+ // means something is wrong.
+ t.Fatal("timed out waiting for DAD resolution")
+ case e := <-ndpDisp.dadC:
+ if e.err != nil {
+ t.Fatal("got DAD error: ", e.err)
+ }
+ if e.nicid != 1 {
+ t.Fatalf("got DAD event w/ nicid = %d, want = 1", e.nicid)
+ }
+ if e.addr != linkLocalAddr {
+ t.Fatalf("got DAD event w/ addr = %s, want = %s", addr, linkLocalAddr)
+ }
+ if !e.resolved {
+ t.Fatal("got DAD event w/ resolved = false, want = true")
+ }
+ }
addr, err = s.GetMainNICAddress(1, header.IPv6ProtocolNumber)
if err != nil {
t.Fatalf("stack.GetMainNICAddress(_, _) err = %s", err)
}
- if want := (tcpip.AddressWithPrefix{Address: header.LinkLocalAddr(linkAddr1), PrefixLen: header.IPv6LinkLocalPrefix.PrefixLen}); addr != want {
+ if want := (tcpip.AddressWithPrefix{Address: linkLocalAddr, PrefixLen: header.IPv6LinkLocalPrefix.PrefixLen}); addr != want {
t.Fatalf("got stack.GetMainNICAddress(_, _) = %s, want = %s", addr, want)
}
}
diff --git a/pkg/tcpip/stack/transport_demuxer.go b/pkg/tcpip/stack/transport_demuxer.go
index 97a1aec4b..ccd3d030e 100644
--- a/pkg/tcpip/stack/transport_demuxer.go
+++ b/pkg/tcpip/stack/transport_demuxer.go
@@ -17,6 +17,7 @@ package stack
import (
"fmt"
"math/rand"
+ "sort"
"sync"
"gvisor.dev/gvisor/pkg/tcpip"
@@ -41,6 +42,31 @@ type transportEndpoints struct {
rawEndpoints []RawTransportEndpoint
}
+// unregisterEndpoint unregisters the endpoint with the given id such that it
+// won't receive any more packets.
+func (eps *transportEndpoints) unregisterEndpoint(id TransportEndpointID, ep TransportEndpoint, bindToDevice tcpip.NICID) {
+ eps.mu.Lock()
+ defer eps.mu.Unlock()
+ epsByNic, ok := eps.endpoints[id]
+ if !ok {
+ return
+ }
+ if !epsByNic.unregisterEndpoint(bindToDevice, ep) {
+ return
+ }
+ delete(eps.endpoints, id)
+}
+
+func (eps *transportEndpoints) transportEndpoints() []TransportEndpoint {
+ eps.mu.RLock()
+ defer eps.mu.RUnlock()
+ es := make([]TransportEndpoint, 0, len(eps.endpoints))
+ for _, e := range eps.endpoints {
+ es = append(es, e.transportEndpoints()...)
+ }
+ return es
+}
+
type endpointsByNic struct {
mu sync.RWMutex
endpoints map[tcpip.NICID]*multiPortEndpoint
@@ -48,6 +74,16 @@ type endpointsByNic struct {
seed uint32
}
+func (epsByNic *endpointsByNic) transportEndpoints() []TransportEndpoint {
+ epsByNic.mu.RLock()
+ defer epsByNic.mu.RUnlock()
+ var eps []TransportEndpoint
+ for _, ep := range epsByNic.endpoints {
+ eps = append(eps, ep.transportEndpoints()...)
+ }
+ return eps
+}
+
// HandlePacket is called by the stack when new packets arrive to this transport
// endpoint.
func (epsByNic *endpointsByNic) handlePacket(r *Route, id TransportEndpointID, vv buffer.VectorisedView) {
@@ -127,21 +163,6 @@ func (epsByNic *endpointsByNic) unregisterEndpoint(bindToDevice tcpip.NICID, t T
return len(epsByNic.endpoints) == 0
}
-// unregisterEndpoint unregisters the endpoint with the given id such that it
-// won't receive any more packets.
-func (eps *transportEndpoints) unregisterEndpoint(id TransportEndpointID, ep TransportEndpoint, bindToDevice tcpip.NICID) {
- eps.mu.Lock()
- defer eps.mu.Unlock()
- epsByNic, ok := eps.endpoints[id]
- if !ok {
- return
- }
- if !epsByNic.unregisterEndpoint(bindToDevice, ep) {
- return
- }
- delete(eps.endpoints, id)
-}
-
// transportDemuxer demultiplexes packets targeted at a transport endpoint
// (i.e., after they've been parsed by the network layer). It does two levels
// of demultiplexing: first based on the network and transport protocols, then
@@ -183,14 +204,27 @@ func (d *transportDemuxer) registerEndpoint(netProtos []tcpip.NetworkProtocolNum
// multiPortEndpoint is a container for TransportEndpoints which are bound to
// the same pair of address and port. endpointsArr always has at least one
// element.
+//
+// FIXME(gvisor.dev/issue/873): Restore this properly. Currently, we just save
+// this to ensure that the underlying endpoints get saved/restored, but not not
+// use the restored copy.
+//
+// +stateify savable
type multiPortEndpoint struct {
- mu sync.RWMutex
+ mu sync.RWMutex `state:"nosave"`
endpointsArr []TransportEndpoint
endpointsMap map[TransportEndpoint]int
// reuse indicates if more than one endpoint is allowed.
reuse bool
}
+func (ep *multiPortEndpoint) transportEndpoints() []TransportEndpoint {
+ ep.mu.RLock()
+ eps := append([]TransportEndpoint(nil), ep.endpointsArr...)
+ ep.mu.RUnlock()
+ return eps
+}
+
// reciprocalScale scales a value into range [0, n).
//
// This is similar to val % n, but faster.
@@ -240,6 +274,26 @@ func (ep *multiPortEndpoint) handlePacketAll(r *Route, id TransportEndpointID, v
ep.mu.RUnlock() // Don't use defer for performance reasons.
}
+// Close implements stack.TransportEndpoint.Close.
+func (ep *multiPortEndpoint) Close() {
+ ep.mu.RLock()
+ eps := append([]TransportEndpoint(nil), ep.endpointsArr...)
+ ep.mu.RUnlock()
+ for _, e := range eps {
+ e.Close()
+ }
+}
+
+// Wait implements stack.TransportEndpoint.Wait.
+func (ep *multiPortEndpoint) Wait() {
+ ep.mu.RLock()
+ eps := append([]TransportEndpoint(nil), ep.endpointsArr...)
+ ep.mu.RUnlock()
+ for _, e := range eps {
+ e.Wait()
+ }
+}
+
// singleRegisterEndpoint tries to add an endpoint to the multiPortEndpoint
// list. The list might be empty already.
func (ep *multiPortEndpoint) singleRegisterEndpoint(t TransportEndpoint, reusePort bool) *tcpip.Error {
@@ -257,6 +311,15 @@ func (ep *multiPortEndpoint) singleRegisterEndpoint(t TransportEndpoint, reusePo
// endpointsMap. This will allow us to remove endpoint from the array fast.
ep.endpointsMap[t] = len(ep.endpointsArr)
ep.endpointsArr = append(ep.endpointsArr, t)
+
+ // ep.endpointsArr is sorted by endpoint unique IDs, so that endpoints
+ // can be restored in the same order.
+ sort.Slice(ep.endpointsArr, func(i, j int) bool {
+ return ep.endpointsArr[i].UniqueID() < ep.endpointsArr[j].UniqueID()
+ })
+ for i, e := range ep.endpointsArr {
+ ep.endpointsMap[e] = i
+ }
return nil
}
diff --git a/pkg/tcpip/stack/transport_test.go b/pkg/tcpip/stack/transport_test.go
index 86c62be25..203e79f56 100644
--- a/pkg/tcpip/stack/transport_test.go
+++ b/pkg/tcpip/stack/transport_test.go
@@ -43,6 +43,7 @@ type fakeTransportEndpoint struct {
proto *fakeTransportProtocol
peerAddr tcpip.Address
route stack.Route
+ uniqueID uint64
// acceptQueue is non-nil iff bound.
acceptQueue []fakeTransportEndpoint
@@ -56,8 +57,8 @@ func (f *fakeTransportEndpoint) Stats() tcpip.EndpointStats {
return nil
}
-func newFakeTransportEndpoint(s *stack.Stack, proto *fakeTransportProtocol, netProto tcpip.NetworkProtocolNumber) tcpip.Endpoint {
- return &fakeTransportEndpoint{stack: s, TransportEndpointInfo: stack.TransportEndpointInfo{NetProto: netProto}, proto: proto}
+func newFakeTransportEndpoint(s *stack.Stack, proto *fakeTransportProtocol, netProto tcpip.NetworkProtocolNumber, uniqueID uint64) tcpip.Endpoint {
+ return &fakeTransportEndpoint{stack: s, TransportEndpointInfo: stack.TransportEndpointInfo{NetProto: netProto}, proto: proto, uniqueID: uniqueID}
}
func (f *fakeTransportEndpoint) Close() {
@@ -144,6 +145,10 @@ func (f *fakeTransportEndpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
return nil
}
+func (f *fakeTransportEndpoint) UniqueID() uint64 {
+ return f.uniqueID
+}
+
func (f *fakeTransportEndpoint) ConnectEndpoint(e tcpip.Endpoint) *tcpip.Error {
return nil
}
@@ -218,15 +223,15 @@ func (f *fakeTransportEndpoint) State() uint32 {
return 0
}
-func (f *fakeTransportEndpoint) ModerateRecvBuf(copied int) {
-}
+func (f *fakeTransportEndpoint) ModerateRecvBuf(copied int) {}
func (f *fakeTransportEndpoint) IPTables() (iptables.IPTables, error) {
return iptables.IPTables{}, nil
}
-func (f *fakeTransportEndpoint) Resume(*stack.Stack) {
-}
+func (f *fakeTransportEndpoint) Resume(*stack.Stack) {}
+
+func (f *fakeTransportEndpoint) Wait() {}
type fakeTransportGoodOption bool
@@ -251,7 +256,7 @@ func (*fakeTransportProtocol) Number() tcpip.TransportProtocolNumber {
}
func (f *fakeTransportProtocol) NewEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, _ *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
- return newFakeTransportEndpoint(stack, f, netProto), nil
+ return newFakeTransportEndpoint(stack, f, netProto, stack.UniqueID()), nil
}
func (f *fakeTransportProtocol) NewRawEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, _ *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {