diff options
Diffstat (limited to 'pkg/tcpip/stack')
-rw-r--r-- | pkg/tcpip/stack/addressable_endpoint_state.go | 11 | ||||
-rw-r--r-- | pkg/tcpip/stack/addressable_endpoint_state_test.go | 28 |
2 files changed, 39 insertions, 0 deletions
diff --git a/pkg/tcpip/stack/addressable_endpoint_state.go b/pkg/tcpip/stack/addressable_endpoint_state.go index adeebfe37..6e855d815 100644 --- a/pkg/tcpip/stack/addressable_endpoint_state.go +++ b/pkg/tcpip/stack/addressable_endpoint_state.go @@ -625,6 +625,17 @@ func (a *AddressableEndpointState) IsInGroup(group tcpip.Address) bool { return ok } +// JoinedGroups returns a list of groups the endpoint is a member of. +func (a *AddressableEndpointState) JoinedGroups() []tcpip.Address { + a.mu.RLock() + defer a.mu.RUnlock() + groups := make([]tcpip.Address, 0, len(a.mu.groups)) + for g := range a.mu.groups { + groups = append(groups, g) + } + return groups +} + // Cleanup forcefully leaves all groups and removes all permanent addresses. func (a *AddressableEndpointState) Cleanup() { a.mu.Lock() diff --git a/pkg/tcpip/stack/addressable_endpoint_state_test.go b/pkg/tcpip/stack/addressable_endpoint_state_test.go index 26787d0a3..0c8040c67 100644 --- a/pkg/tcpip/stack/addressable_endpoint_state_test.go +++ b/pkg/tcpip/stack/addressable_endpoint_state_test.go @@ -15,12 +15,40 @@ package stack_test import ( + "sort" "testing" + "github.com/google/go-cmp/cmp" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/stack" ) +func TestJoinedGroups(t *testing.T) { + const addr1 = tcpip.Address("\x01") + const addr2 = tcpip.Address("\x02") + + var ep fakeNetworkEndpoint + var s stack.AddressableEndpointState + s.Init(&ep) + + if joined, err := s.JoinGroup(addr1); err != nil { + t.Fatalf("JoinGroup(%s): %s", addr1, err) + } else if !joined { + t.Errorf("got JoinGroup(%s) = false, want = true", addr1) + } + if joined, err := s.JoinGroup(addr2); err != nil { + t.Fatalf("JoinGroup(%s): %s", addr2, err) + } else if !joined { + t.Errorf("got JoinGroup(%s) = false, want = true", addr2) + } + + joinedGroups := s.JoinedGroups() + sort.Slice(joinedGroups, func(i, j int) bool { return joinedGroups[i][0] < joinedGroups[j][0] }) + if diff := cmp.Diff([]tcpip.Address{addr1, addr2}, joinedGroups); diff != "" { + t.Errorf("joined groups mismatch (-want +got):\n%s", diff) + } +} + // TestAddressableEndpointStateCleanup tests that cleaning up an addressable // endpoint state removes permanent addresses and leaves groups. func TestAddressableEndpointStateCleanup(t *testing.T) { |