summaryrefslogtreecommitdiffhomepage
path: root/pkg/tcpip/stack
diff options
context:
space:
mode:
Diffstat (limited to 'pkg/tcpip/stack')
-rw-r--r--pkg/tcpip/stack/nic.go69
-rw-r--r--pkg/tcpip/stack/stack.go57
-rw-r--r--pkg/tcpip/stack/stack_test.go97
3 files changed, 214 insertions, 9 deletions
diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go
index 77134c42a..7aa960096 100644
--- a/pkg/tcpip/stack/nic.go
+++ b/pkg/tcpip/stack/nic.go
@@ -82,6 +82,52 @@ func (n *NIC) setSpoofing(enable bool) {
n.mu.Unlock()
}
+func (n *NIC) getMainNICAddress(protocol tcpip.NetworkProtocolNumber) (tcpip.Address, tcpip.Subnet, *tcpip.Error) {
+ n.mu.RLock()
+ defer n.mu.RUnlock()
+
+ var r *referencedNetworkEndpoint
+
+ // Check for a primary endpoint.
+ if list, ok := n.primary[protocol]; ok {
+ for e := list.Front(); e != nil; e = e.Next() {
+ ref := e.(*referencedNetworkEndpoint)
+ if ref.holdsInsertRef && ref.tryIncRef() {
+ r = ref
+ break
+ }
+ }
+
+ }
+
+ // If no primary endpoints then check for other endpoints.
+ if r == nil {
+ for _, ref := range n.endpoints {
+ if ref.holdsInsertRef && ref.tryIncRef() {
+ r = ref
+ break
+ }
+ }
+ }
+
+ if r == nil {
+ return "", tcpip.Subnet{}, tcpip.ErrNoLinkAddress
+ }
+
+ address := r.ep.ID().LocalAddress
+ r.decRef()
+
+ // Find the least-constrained matching subnet for the address, if one
+ // exists, and return it.
+ var subnet tcpip.Subnet
+ for _, s := range n.subnets {
+ if s.Contains(address) && !subnet.Contains(s.ID()) {
+ subnet = s
+ }
+ }
+ return address, subnet, nil
+}
+
// primaryEndpoint returns the primary endpoint of n for the given network
// protocol.
func (n *NIC) primaryEndpoint(protocol tcpip.NetworkProtocolNumber) *referencedNetworkEndpoint {
@@ -216,6 +262,29 @@ func (n *NIC) AddSubnet(protocol tcpip.NetworkProtocolNumber, subnet tcpip.Subne
n.mu.Unlock()
}
+// RemoveSubnet removes the given subnet from n.
+func (n *NIC) RemoveSubnet(subnet tcpip.Subnet) {
+ n.mu.Lock()
+
+ for i, sub := range n.subnets {
+ if sub == subnet {
+ n.subnets = append(n.subnets[:i], n.subnets[i+1:]...)
+ }
+ }
+
+ n.mu.Unlock()
+}
+
+// ContainsSubnet reports whether this NIC contains the given subnet.
+func (n *NIC) ContainsSubnet(subnet tcpip.Subnet) bool {
+ for _, s := range n.Subnets() {
+ if s == subnet {
+ return true
+ }
+ }
+ return false
+}
+
// Subnets returns the Subnets associated with this NIC.
func (n *NIC) Subnets() []tcpip.Subnet {
n.mu.RLock()
diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go
index 2c8c4aa31..675ccc6fa 100644
--- a/pkg/tcpip/stack/stack.go
+++ b/pkg/tcpip/stack/stack.go
@@ -623,13 +623,38 @@ func (s *Stack) AddSubnet(id tcpip.NICID, protocol tcpip.NetworkProtocolNumber,
s.mu.RLock()
defer s.mu.RUnlock()
- nic := s.nics[id]
- if nic == nil {
- return tcpip.ErrUnknownNICID
+ if nic, ok := s.nics[id]; ok {
+ nic.AddSubnet(protocol, subnet)
+ return nil
}
- nic.AddSubnet(protocol, subnet)
- return nil
+ return tcpip.ErrUnknownNICID
+}
+
+// RemoveSubnet removes the subnet range from the specified NIC.
+func (s *Stack) RemoveSubnet(id tcpip.NICID, subnet tcpip.Subnet) *tcpip.Error {
+ s.mu.RLock()
+ defer s.mu.RUnlock()
+
+ if nic, ok := s.nics[id]; ok {
+ nic.RemoveSubnet(subnet)
+ return nil
+ }
+
+ return tcpip.ErrUnknownNICID
+}
+
+// ContainsSubnet reports whether the specified NIC contains the specified
+// subnet.
+func (s *Stack) ContainsSubnet(id tcpip.NICID, subnet tcpip.Subnet) (bool, *tcpip.Error) {
+ s.mu.RLock()
+ defer s.mu.RUnlock()
+
+ if nic, ok := s.nics[id]; ok {
+ return nic.ContainsSubnet(subnet), nil
+ }
+
+ return false, tcpip.ErrUnknownNICID
}
// RemoveAddress removes an existing network-layer address from the specified
@@ -638,12 +663,26 @@ func (s *Stack) RemoveAddress(id tcpip.NICID, addr tcpip.Address) *tcpip.Error {
s.mu.RLock()
defer s.mu.RUnlock()
- nic := s.nics[id]
- if nic == nil {
- return tcpip.ErrUnknownNICID
+ if nic, ok := s.nics[id]; ok {
+ return nic.RemoveAddress(addr)
+ }
+
+ return tcpip.ErrUnknownNICID
+}
+
+// GetMainNICAddress returns the first primary address (and the subnet that
+// contains it) for the given NIC and protocol. Returns an arbitrary endpoint's
+// address if no primary addresses exist. Returns an error if the NIC doesn't
+// exist or has no endpoints.
+func (s *Stack) GetMainNICAddress(id tcpip.NICID, protocol tcpip.NetworkProtocolNumber) (tcpip.Address, tcpip.Subnet, *tcpip.Error) {
+ s.mu.RLock()
+ defer s.mu.RUnlock()
+
+ if nic, ok := s.nics[id]; ok {
+ return nic.getMainNICAddress(protocol)
}
- return nic.RemoveAddress(addr)
+ return "", tcpip.Subnet{}, tcpip.ErrUnknownNICID
}
// FindRoute creates a route to the given destination address, leaving through
diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go
index 57de5b93a..c46e91241 100644
--- a/pkg/tcpip/stack/stack_test.go
+++ b/pkg/tcpip/stack/stack_test.go
@@ -19,6 +19,7 @@ package stack_test
import (
"math"
+ "strings"
"testing"
"gvisor.googlesource.com/gvisor/pkg/tcpip"
@@ -763,6 +764,102 @@ func TestNetworkOptions(t *testing.T) {
}
}
+func TestSubnetAddRemove(t *testing.T) {
+ s := stack.New([]string{"fakeNet"}, nil, stack.Options{})
+ id, _ := channel.New(10, defaultMTU, "")
+ if err := s.CreateNIC(1, id); err != nil {
+ t.Fatalf("CreateNIC failed: %v", err)
+ }
+
+ addr := tcpip.Address("\x01\x01\x01\x01")
+ mask := tcpip.AddressMask(strings.Repeat("\xff", len(addr)))
+ subnet, err := tcpip.NewSubnet(addr, mask)
+ if err != nil {
+ t.Fatalf("NewSubnet failed: %v", err)
+ }
+
+ if contained, err := s.ContainsSubnet(1, subnet); err != nil {
+ t.Fatalf("ContainsSubnet failed: %v", err)
+ } else if contained {
+ t.Fatal("got s.ContainsSubnet(...) = true, want = false")
+ }
+
+ if err := s.AddSubnet(1, fakeNetNumber, subnet); err != nil {
+ t.Fatalf("AddSubnet failed: %v", err)
+ }
+
+ if contained, err := s.ContainsSubnet(1, subnet); err != nil {
+ t.Fatalf("ContainsSubnet failed: %v", err)
+ } else if !contained {
+ t.Fatal("got s.ContainsSubnet(...) = false, want = true")
+ }
+
+ if err := s.RemoveSubnet(1, subnet); err != nil {
+ t.Fatalf("RemoveSubnet failed: %v", err)
+ }
+
+ if contained, err := s.ContainsSubnet(1, subnet); err != nil {
+ t.Fatalf("ContainsSubnet failed: %v", err)
+ } else if contained {
+ t.Fatal("got s.ContainsSubnet(...) = true, want = false")
+ }
+}
+
+func TestGetMainNICAddress(t *testing.T) {
+ s := stack.New([]string{"fakeNet"}, nil, stack.Options{})
+ id, _ := channel.New(10, defaultMTU, "")
+ if err := s.CreateNIC(1, id); err != nil {
+ t.Fatalf("CreateNIC failed: %v", err)
+ }
+
+ for _, tc := range []struct {
+ name string
+ address tcpip.Address
+ }{
+ {"IPv4", "\x01\x01\x01\x01"},
+ {"IPv6", "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01"},
+ } {
+ t.Run(tc.name, func(t *testing.T) {
+ address := tc.address
+ mask := tcpip.AddressMask(strings.Repeat("\xff", len(address)))
+ subnet, err := tcpip.NewSubnet(address, mask)
+ if err != nil {
+ t.Fatalf("NewSubnet failed: %v", err)
+ }
+
+ if err := s.AddAddress(1, fakeNetNumber, address); err != nil {
+ t.Fatalf("AddAddress failed: %v", err)
+ }
+
+ if err := s.AddSubnet(1, fakeNetNumber, subnet); err != nil {
+ t.Fatalf("AddSubnet failed: %v", err)
+ }
+
+ // Check that we get the right initial address and subnet.
+ if gotAddress, gotSubnet, err := s.GetMainNICAddress(1, fakeNetNumber); err != nil {
+ t.Fatalf("GetMainNICAddress failed: %v", err)
+ } else if gotAddress != address {
+ t.Fatalf("got GetMainNICAddress = (%v, ...), want = (%v, ...)", gotAddress, address)
+ } else if gotSubnet != subnet {
+ t.Fatalf("got GetMainNICAddress = (..., %v), want = (..., %v)", gotSubnet, subnet)
+ }
+
+ if err := s.RemoveSubnet(1, subnet); err != nil {
+ t.Fatalf("RemoveSubnet failed: %v", err)
+ }
+
+ if err := s.RemoveAddress(1, address); err != nil {
+ t.Fatalf("RemoveAddress failed: %v", err)
+ }
+
+ // Check that we get an error after removal.
+ if _, _, err := s.GetMainNICAddress(1, fakeNetNumber); err != tcpip.ErrNoLinkAddress {
+ t.Fatalf("got s.GetMainNICAddress(...) = %v, want = %v", err, tcpip.ErrNoLinkAddress)
+ }
+ })
+ }
+}
+
func init() {
stack.RegisterNetworkProtocolFactory("fakeNet", func() stack.NetworkProtocol {
return &fakeNetworkProtocol{}