summaryrefslogtreecommitdiffhomepage
path: root/pkg/tcpip/stack
diff options
context:
space:
mode:
Diffstat (limited to 'pkg/tcpip/stack')
-rw-r--r--pkg/tcpip/stack/addressable_endpoint_state.go11
-rw-r--r--pkg/tcpip/stack/addressable_endpoint_state_test.go28
2 files changed, 39 insertions, 0 deletions
diff --git a/pkg/tcpip/stack/addressable_endpoint_state.go b/pkg/tcpip/stack/addressable_endpoint_state.go
index adeebfe37..6e855d815 100644
--- a/pkg/tcpip/stack/addressable_endpoint_state.go
+++ b/pkg/tcpip/stack/addressable_endpoint_state.go
@@ -625,6 +625,17 @@ func (a *AddressableEndpointState) IsInGroup(group tcpip.Address) bool {
return ok
}
+// JoinedGroups returns a list of groups the endpoint is a member of.
+func (a *AddressableEndpointState) JoinedGroups() []tcpip.Address {
+ a.mu.RLock()
+ defer a.mu.RUnlock()
+ groups := make([]tcpip.Address, 0, len(a.mu.groups))
+ for g := range a.mu.groups {
+ groups = append(groups, g)
+ }
+ return groups
+}
+
// Cleanup forcefully leaves all groups and removes all permanent addresses.
func (a *AddressableEndpointState) Cleanup() {
a.mu.Lock()
diff --git a/pkg/tcpip/stack/addressable_endpoint_state_test.go b/pkg/tcpip/stack/addressable_endpoint_state_test.go
index 26787d0a3..0c8040c67 100644
--- a/pkg/tcpip/stack/addressable_endpoint_state_test.go
+++ b/pkg/tcpip/stack/addressable_endpoint_state_test.go
@@ -15,12 +15,40 @@
package stack_test
import (
+ "sort"
"testing"
+ "github.com/google/go-cmp/cmp"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/stack"
)
+func TestJoinedGroups(t *testing.T) {
+ const addr1 = tcpip.Address("\x01")
+ const addr2 = tcpip.Address("\x02")
+
+ var ep fakeNetworkEndpoint
+ var s stack.AddressableEndpointState
+ s.Init(&ep)
+
+ if joined, err := s.JoinGroup(addr1); err != nil {
+ t.Fatalf("JoinGroup(%s): %s", addr1, err)
+ } else if !joined {
+ t.Errorf("got JoinGroup(%s) = false, want = true", addr1)
+ }
+ if joined, err := s.JoinGroup(addr2); err != nil {
+ t.Fatalf("JoinGroup(%s): %s", addr2, err)
+ } else if !joined {
+ t.Errorf("got JoinGroup(%s) = false, want = true", addr2)
+ }
+
+ joinedGroups := s.JoinedGroups()
+ sort.Slice(joinedGroups, func(i, j int) bool { return joinedGroups[i][0] < joinedGroups[j][0] })
+ if diff := cmp.Diff([]tcpip.Address{addr1, addr2}, joinedGroups); diff != "" {
+ t.Errorf("joined groups mismatch (-want +got):\n%s", diff)
+ }
+}
+
// TestAddressableEndpointStateCleanup tests that cleaning up an addressable
// endpoint state removes permanent addresses and leaves groups.
func TestAddressableEndpointStateCleanup(t *testing.T) {