diff options
Diffstat (limited to 'pkg/tcpip/stack')
-rw-r--r-- | pkg/tcpip/stack/addressable_endpoint_state.go | 61 | ||||
-rw-r--r-- | pkg/tcpip/stack/addressable_endpoint_state_test.go | 50 | ||||
-rw-r--r-- | pkg/tcpip/stack/nic.go | 9 | ||||
-rw-r--r-- | pkg/tcpip/stack/registration.go | 8 |
4 files changed, 7 insertions, 121 deletions
diff --git a/pkg/tcpip/stack/addressable_endpoint_state.go b/pkg/tcpip/stack/addressable_endpoint_state.go index 6e855d815..b438e3acc 100644 --- a/pkg/tcpip/stack/addressable_endpoint_state.go +++ b/pkg/tcpip/stack/addressable_endpoint_state.go @@ -21,7 +21,6 @@ import ( "gvisor.dev/gvisor/pkg/tcpip" ) -var _ GroupAddressableEndpoint = (*AddressableEndpointState)(nil) var _ AddressableEndpoint = (*AddressableEndpointState)(nil) // AddressableEndpointState is an implementation of an AddressableEndpoint. @@ -37,10 +36,6 @@ type AddressableEndpointState struct { endpoints map[tcpip.Address]*addressState primary []*addressState - - // groups holds the mapping between group addresses and the number of times - // they have been joined. - groups map[tcpip.Address]uint32 } } @@ -53,7 +48,6 @@ func (a *AddressableEndpointState) Init(networkEndpoint NetworkEndpoint) { a.mu.Lock() defer a.mu.Unlock() a.mu.endpoints = make(map[tcpip.Address]*addressState) - a.mu.groups = make(map[tcpip.Address]uint32) } // ReadOnlyAddressableEndpointState provides read-only access to an @@ -335,11 +329,6 @@ func (a *AddressableEndpointState) addAndAcquireAddressLocked(addr tcpip.Address func (a *AddressableEndpointState) RemovePermanentAddress(addr tcpip.Address) *tcpip.Error { a.mu.Lock() defer a.mu.Unlock() - - if _, ok := a.mu.groups[addr]; ok { - panic(fmt.Sprintf("group address = %s must be removed with LeaveGroup", addr)) - } - return a.removePermanentAddressLocked(addr) } @@ -588,61 +577,11 @@ func (a *AddressableEndpointState) PermanentAddresses() []tcpip.AddressWithPrefi return addrs } -// JoinGroup implements GroupAddressableEndpoint. -func (a *AddressableEndpointState) JoinGroup(group tcpip.Address) (bool, *tcpip.Error) { - a.mu.Lock() - defer a.mu.Unlock() - - joins, ok := a.mu.groups[group] - a.mu.groups[group] = joins + 1 - return !ok, nil -} - -// LeaveGroup implements GroupAddressableEndpoint. -func (a *AddressableEndpointState) LeaveGroup(group tcpip.Address) (bool, *tcpip.Error) { - a.mu.Lock() - defer a.mu.Unlock() - - joins, ok := a.mu.groups[group] - if !ok { - return false, tcpip.ErrBadLocalAddress - } - - if joins == 1 { - delete(a.mu.groups, group) - return true, nil - } - - a.mu.groups[group] = joins - 1 - return false, nil -} - -// IsInGroup implements GroupAddressableEndpoint. -func (a *AddressableEndpointState) IsInGroup(group tcpip.Address) bool { - a.mu.RLock() - defer a.mu.RUnlock() - _, ok := a.mu.groups[group] - return ok -} - -// JoinedGroups returns a list of groups the endpoint is a member of. -func (a *AddressableEndpointState) JoinedGroups() []tcpip.Address { - a.mu.RLock() - defer a.mu.RUnlock() - groups := make([]tcpip.Address, 0, len(a.mu.groups)) - for g := range a.mu.groups { - groups = append(groups, g) - } - return groups -} - // Cleanup forcefully leaves all groups and removes all permanent addresses. func (a *AddressableEndpointState) Cleanup() { a.mu.Lock() defer a.mu.Unlock() - a.mu.groups = make(map[tcpip.Address]uint32) - for _, ep := range a.mu.endpoints { // removePermanentEndpointLocked returns tcpip.ErrBadLocalAddress if ep is // not a permanent address. diff --git a/pkg/tcpip/stack/addressable_endpoint_state_test.go b/pkg/tcpip/stack/addressable_endpoint_state_test.go index 0c8040c67..140f146f6 100644 --- a/pkg/tcpip/stack/addressable_endpoint_state_test.go +++ b/pkg/tcpip/stack/addressable_endpoint_state_test.go @@ -15,40 +15,12 @@ package stack_test import ( - "sort" "testing" - "github.com/google/go-cmp/cmp" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/stack" ) -func TestJoinedGroups(t *testing.T) { - const addr1 = tcpip.Address("\x01") - const addr2 = tcpip.Address("\x02") - - var ep fakeNetworkEndpoint - var s stack.AddressableEndpointState - s.Init(&ep) - - if joined, err := s.JoinGroup(addr1); err != nil { - t.Fatalf("JoinGroup(%s): %s", addr1, err) - } else if !joined { - t.Errorf("got JoinGroup(%s) = false, want = true", addr1) - } - if joined, err := s.JoinGroup(addr2); err != nil { - t.Fatalf("JoinGroup(%s): %s", addr2, err) - } else if !joined { - t.Errorf("got JoinGroup(%s) = false, want = true", addr2) - } - - joinedGroups := s.JoinedGroups() - sort.Slice(joinedGroups, func(i, j int) bool { return joinedGroups[i][0] < joinedGroups[j][0] }) - if diff := cmp.Diff([]tcpip.Address{addr1, addr2}, joinedGroups); diff != "" { - t.Errorf("joined groups mismatch (-want +got):\n%s", diff) - } -} - // TestAddressableEndpointStateCleanup tests that cleaning up an addressable // endpoint state removes permanent addresses and leaves groups. func TestAddressableEndpointStateCleanup(t *testing.T) { @@ -81,25 +53,9 @@ func TestAddressableEndpointStateCleanup(t *testing.T) { ep.DecRef() } - group := tcpip.Address("\x02") - if added, err := s.JoinGroup(group); err != nil { - t.Fatalf("s.JoinGroup(%s): %s", group, err) - } else if !added { - t.Fatalf("got s.JoinGroup(%s) = false, want = true", group) - } - if !s.IsInGroup(group) { - t.Fatalf("got s.IsInGroup(%s) = false, want = true", group) - } - s.Cleanup() - { - ep := s.AcquireAssignedAddress(addr.Address, false /* allowTemp */, stack.NeverPrimaryEndpoint) - if ep != nil { - ep.DecRef() - t.Fatalf("got s.AcquireAssignedAddress(%s, false, NeverPrimaryEndpoint) = %s, want = nil", addr.Address, ep.AddressWithPrefix()) - } - } - if s.IsInGroup(group) { - t.Fatalf("got s.IsInGroup(%s) = true, want = false", group) + if ep := s.AcquireAssignedAddress(addr.Address, false /* allowTemp */, stack.NeverPrimaryEndpoint); ep != nil { + ep.DecRef() + t.Fatalf("got s.AcquireAssignedAddress(%s, false, NeverPrimaryEndpoint) = %s, want = nil", addr.Address, ep.AddressWithPrefix()) } } diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go index 5f3757de0..1805a8e0a 100644 --- a/pkg/tcpip/stack/nic.go +++ b/pkg/tcpip/stack/nic.go @@ -563,8 +563,7 @@ func (n *NIC) joinGroup(protocol tcpip.NetworkProtocolNumber, addr tcpip.Address return tcpip.ErrNotSupported } - _, err := gep.JoinGroup(addr) - return err + return gep.JoinGroup(addr) } // leaveGroup decrements the count for the given multicast address, and when it @@ -580,11 +579,7 @@ func (n *NIC) leaveGroup(protocol tcpip.NetworkProtocolNumber, addr tcpip.Addres return tcpip.ErrNotSupported } - if _, err := gep.LeaveGroup(addr); err != nil { - return err - } - - return nil + return gep.LeaveGroup(addr) } // isInGroup returns true if n has joined the multicast group addr. diff --git a/pkg/tcpip/stack/registration.go b/pkg/tcpip/stack/registration.go index 43ca03ada..236e4e4c7 100644 --- a/pkg/tcpip/stack/registration.go +++ b/pkg/tcpip/stack/registration.go @@ -291,14 +291,10 @@ type NetworkHeaderParams struct { // endpoints may associate themselves with the same identifier (group address). type GroupAddressableEndpoint interface { // JoinGroup joins the specified group. - // - // Returns true if the group was newly joined. - JoinGroup(group tcpip.Address) (bool, *tcpip.Error) + JoinGroup(group tcpip.Address) *tcpip.Error // LeaveGroup attempts to leave the specified group. - // - // Returns tcpip.ErrBadLocalAddress if the endpoint has not joined the group. - LeaveGroup(group tcpip.Address) (bool, *tcpip.Error) + LeaveGroup(group tcpip.Address) *tcpip.Error // IsInGroup returns true if the endpoint is a member of the specified group. IsInGroup(group tcpip.Address) bool |