diff options
-rw-r--r-- | pkg/tcpip/stack/nic.go | 69 | ||||
-rw-r--r-- | pkg/tcpip/stack/stack.go | 57 | ||||
-rw-r--r-- | pkg/tcpip/stack/stack_test.go | 97 |
3 files changed, 214 insertions, 9 deletions
diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go index 77134c42a..7aa960096 100644 --- a/pkg/tcpip/stack/nic.go +++ b/pkg/tcpip/stack/nic.go @@ -82,6 +82,52 @@ func (n *NIC) setSpoofing(enable bool) { n.mu.Unlock() } +func (n *NIC) getMainNICAddress(protocol tcpip.NetworkProtocolNumber) (tcpip.Address, tcpip.Subnet, *tcpip.Error) { + n.mu.RLock() + defer n.mu.RUnlock() + + var r *referencedNetworkEndpoint + + // Check for a primary endpoint. + if list, ok := n.primary[protocol]; ok { + for e := list.Front(); e != nil; e = e.Next() { + ref := e.(*referencedNetworkEndpoint) + if ref.holdsInsertRef && ref.tryIncRef() { + r = ref + break + } + } + + } + + // If no primary endpoints then check for other endpoints. + if r == nil { + for _, ref := range n.endpoints { + if ref.holdsInsertRef && ref.tryIncRef() { + r = ref + break + } + } + } + + if r == nil { + return "", tcpip.Subnet{}, tcpip.ErrNoLinkAddress + } + + address := r.ep.ID().LocalAddress + r.decRef() + + // Find the least-constrained matching subnet for the address, if one + // exists, and return it. + var subnet tcpip.Subnet + for _, s := range n.subnets { + if s.Contains(address) && !subnet.Contains(s.ID()) { + subnet = s + } + } + return address, subnet, nil +} + // primaryEndpoint returns the primary endpoint of n for the given network // protocol. func (n *NIC) primaryEndpoint(protocol tcpip.NetworkProtocolNumber) *referencedNetworkEndpoint { @@ -216,6 +262,29 @@ func (n *NIC) AddSubnet(protocol tcpip.NetworkProtocolNumber, subnet tcpip.Subne n.mu.Unlock() } +// RemoveSubnet removes the given subnet from n. +func (n *NIC) RemoveSubnet(subnet tcpip.Subnet) { + n.mu.Lock() + + for i, sub := range n.subnets { + if sub == subnet { + n.subnets = append(n.subnets[:i], n.subnets[i+1:]...) + } + } + + n.mu.Unlock() +} + +// ContainsSubnet reports whether this NIC contains the given subnet. +func (n *NIC) ContainsSubnet(subnet tcpip.Subnet) bool { + for _, s := range n.Subnets() { + if s == subnet { + return true + } + } + return false +} + // Subnets returns the Subnets associated with this NIC. func (n *NIC) Subnets() []tcpip.Subnet { n.mu.RLock() diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go index 2c8c4aa31..675ccc6fa 100644 --- a/pkg/tcpip/stack/stack.go +++ b/pkg/tcpip/stack/stack.go @@ -623,13 +623,38 @@ func (s *Stack) AddSubnet(id tcpip.NICID, protocol tcpip.NetworkProtocolNumber, s.mu.RLock() defer s.mu.RUnlock() - nic := s.nics[id] - if nic == nil { - return tcpip.ErrUnknownNICID + if nic, ok := s.nics[id]; ok { + nic.AddSubnet(protocol, subnet) + return nil } - nic.AddSubnet(protocol, subnet) - return nil + return tcpip.ErrUnknownNICID +} + +// RemoveSubnet removes the subnet range from the specified NIC. +func (s *Stack) RemoveSubnet(id tcpip.NICID, subnet tcpip.Subnet) *tcpip.Error { + s.mu.RLock() + defer s.mu.RUnlock() + + if nic, ok := s.nics[id]; ok { + nic.RemoveSubnet(subnet) + return nil + } + + return tcpip.ErrUnknownNICID +} + +// ContainsSubnet reports whether the specified NIC contains the specified +// subnet. +func (s *Stack) ContainsSubnet(id tcpip.NICID, subnet tcpip.Subnet) (bool, *tcpip.Error) { + s.mu.RLock() + defer s.mu.RUnlock() + + if nic, ok := s.nics[id]; ok { + return nic.ContainsSubnet(subnet), nil + } + + return false, tcpip.ErrUnknownNICID } // RemoveAddress removes an existing network-layer address from the specified @@ -638,12 +663,26 @@ func (s *Stack) RemoveAddress(id tcpip.NICID, addr tcpip.Address) *tcpip.Error { s.mu.RLock() defer s.mu.RUnlock() - nic := s.nics[id] - if nic == nil { - return tcpip.ErrUnknownNICID + if nic, ok := s.nics[id]; ok { + return nic.RemoveAddress(addr) + } + + return tcpip.ErrUnknownNICID +} + +// GetMainNICAddress returns the first primary address (and the subnet that +// contains it) for the given NIC and protocol. Returns an arbitrary endpoint's +// address if no primary addresses exist. Returns an error if the NIC doesn't +// exist or has no endpoints. +func (s *Stack) GetMainNICAddress(id tcpip.NICID, protocol tcpip.NetworkProtocolNumber) (tcpip.Address, tcpip.Subnet, *tcpip.Error) { + s.mu.RLock() + defer s.mu.RUnlock() + + if nic, ok := s.nics[id]; ok { + return nic.getMainNICAddress(protocol) } - return nic.RemoveAddress(addr) + return "", tcpip.Subnet{}, tcpip.ErrUnknownNICID } // FindRoute creates a route to the given destination address, leaving through diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go index 57de5b93a..c46e91241 100644 --- a/pkg/tcpip/stack/stack_test.go +++ b/pkg/tcpip/stack/stack_test.go @@ -19,6 +19,7 @@ package stack_test import ( "math" + "strings" "testing" "gvisor.googlesource.com/gvisor/pkg/tcpip" @@ -763,6 +764,102 @@ func TestNetworkOptions(t *testing.T) { } } +func TestSubnetAddRemove(t *testing.T) { + s := stack.New([]string{"fakeNet"}, nil, stack.Options{}) + id, _ := channel.New(10, defaultMTU, "") + if err := s.CreateNIC(1, id); err != nil { + t.Fatalf("CreateNIC failed: %v", err) + } + + addr := tcpip.Address("\x01\x01\x01\x01") + mask := tcpip.AddressMask(strings.Repeat("\xff", len(addr))) + subnet, err := tcpip.NewSubnet(addr, mask) + if err != nil { + t.Fatalf("NewSubnet failed: %v", err) + } + + if contained, err := s.ContainsSubnet(1, subnet); err != nil { + t.Fatalf("ContainsSubnet failed: %v", err) + } else if contained { + t.Fatal("got s.ContainsSubnet(...) = true, want = false") + } + + if err := s.AddSubnet(1, fakeNetNumber, subnet); err != nil { + t.Fatalf("AddSubnet failed: %v", err) + } + + if contained, err := s.ContainsSubnet(1, subnet); err != nil { + t.Fatalf("ContainsSubnet failed: %v", err) + } else if !contained { + t.Fatal("got s.ContainsSubnet(...) = false, want = true") + } + + if err := s.RemoveSubnet(1, subnet); err != nil { + t.Fatalf("RemoveSubnet failed: %v", err) + } + + if contained, err := s.ContainsSubnet(1, subnet); err != nil { + t.Fatalf("ContainsSubnet failed: %v", err) + } else if contained { + t.Fatal("got s.ContainsSubnet(...) = true, want = false") + } +} + +func TestGetMainNICAddress(t *testing.T) { + s := stack.New([]string{"fakeNet"}, nil, stack.Options{}) + id, _ := channel.New(10, defaultMTU, "") + if err := s.CreateNIC(1, id); err != nil { + t.Fatalf("CreateNIC failed: %v", err) + } + + for _, tc := range []struct { + name string + address tcpip.Address + }{ + {"IPv4", "\x01\x01\x01\x01"}, + {"IPv6", "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01"}, + } { + t.Run(tc.name, func(t *testing.T) { + address := tc.address + mask := tcpip.AddressMask(strings.Repeat("\xff", len(address))) + subnet, err := tcpip.NewSubnet(address, mask) + if err != nil { + t.Fatalf("NewSubnet failed: %v", err) + } + + if err := s.AddAddress(1, fakeNetNumber, address); err != nil { + t.Fatalf("AddAddress failed: %v", err) + } + + if err := s.AddSubnet(1, fakeNetNumber, subnet); err != nil { + t.Fatalf("AddSubnet failed: %v", err) + } + + // Check that we get the right initial address and subnet. + if gotAddress, gotSubnet, err := s.GetMainNICAddress(1, fakeNetNumber); err != nil { + t.Fatalf("GetMainNICAddress failed: %v", err) + } else if gotAddress != address { + t.Fatalf("got GetMainNICAddress = (%v, ...), want = (%v, ...)", gotAddress, address) + } else if gotSubnet != subnet { + t.Fatalf("got GetMainNICAddress = (..., %v), want = (..., %v)", gotSubnet, subnet) + } + + if err := s.RemoveSubnet(1, subnet); err != nil { + t.Fatalf("RemoveSubnet failed: %v", err) + } + + if err := s.RemoveAddress(1, address); err != nil { + t.Fatalf("RemoveAddress failed: %v", err) + } + + // Check that we get an error after removal. + if _, _, err := s.GetMainNICAddress(1, fakeNetNumber); err != tcpip.ErrNoLinkAddress { + t.Fatalf("got s.GetMainNICAddress(...) = %v, want = %v", err, tcpip.ErrNoLinkAddress) + } + }) + } +} + func init() { stack.RegisterNetworkProtocolFactory("fakeNet", func() stack.NetworkProtocol { return &fakeNetworkProtocol{} |