summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--pkg/tcpip/network/ip/BUILD1
-rw-r--r--pkg/tcpip/network/ip/generic_multicast_protocol.go145
-rw-r--r--pkg/tcpip/network/ip/generic_multicast_protocol_test.go412
-rw-r--r--pkg/tcpip/network/ipv4/igmp.go88
-rw-r--r--pkg/tcpip/network/ipv4/ipv4.go20
-rw-r--r--pkg/tcpip/network/ipv6/icmp.go18
-rw-r--r--pkg/tcpip/network/ipv6/ipv6.go27
-rw-r--r--pkg/tcpip/network/ipv6/mld.go38
-rw-r--r--pkg/tcpip/network/ipv6/ndp.go20
-rw-r--r--test/packetimpact/runner/dut.go8
10 files changed, 489 insertions, 288 deletions
diff --git a/pkg/tcpip/network/ip/BUILD b/pkg/tcpip/network/ip/BUILD
index 6ca200b48..ca1247c1e 100644
--- a/pkg/tcpip/network/ip/BUILD
+++ b/pkg/tcpip/network/ip/BUILD
@@ -18,6 +18,7 @@ go_test(
srcs = ["generic_multicast_protocol_test.go"],
deps = [
":ip",
+ "//pkg/sync",
"//pkg/tcpip",
"//pkg/tcpip/faketime",
"@com_github_google_go_cmp//cmp:go_default_library",
diff --git a/pkg/tcpip/network/ip/generic_multicast_protocol.go b/pkg/tcpip/network/ip/generic_multicast_protocol.go
index e308550c4..bc7f7f637 100644
--- a/pkg/tcpip/network/ip/generic_multicast_protocol.go
+++ b/pkg/tcpip/network/ip/generic_multicast_protocol.go
@@ -138,76 +138,89 @@ type MulticastGroupProtocol interface {
// IPv4 and IPv6. Specifically, Generic Multicast Protocol is the core state
// machine of IGMPv2 as defined by RFC 2236 and MLDv1 as defined by RFC 2710.
//
+// Callers must synchronize accesses to the generic multicast protocol state;
+// GenericMulticastProtocolState obtains no locks in any of its methods. The
+// only exception to this is GenericMulticastProtocolState's timer/job callbacks
+// which will obtain the lock provided to the GenericMulticastProtocolState when
+// it is initialized.
+//
// GenericMulticastProtocolState.Init MUST be called before calling any of
// the methods on GenericMulticastProtocolState.
type GenericMulticastProtocolState struct {
+ // Do not allow overwriting this state.
+ _ sync.NoCopy
+
opts GenericMulticastProtocolOptions
- mu struct {
- sync.RWMutex
+ // memberships holds group addresses and their associated state.
+ memberships map[tcpip.Address]multicastGroupState
- // memberships holds group addresses and their associated state.
- memberships map[tcpip.Address]multicastGroupState
- }
+ // protocolMU is the mutex used to protect the protocol.
+ protocolMU *sync.RWMutex
}
// Init initializes the Generic Multicast Protocol state.
-func (g *GenericMulticastProtocolState) Init(opts GenericMulticastProtocolOptions) {
- g.mu.Lock()
- defer g.mu.Unlock()
+//
+// Must only be called once for the lifetime of g.
+//
+// The GenericMulticastProtocolState will only grab the lock when timers/jobs
+// fire.
+func (g *GenericMulticastProtocolState) Init(protocolMU *sync.RWMutex, opts GenericMulticastProtocolOptions) {
+ if g.memberships != nil {
+ panic("attempted to initialize generic membership protocol state twice")
+ }
+
g.opts = opts
- g.mu.memberships = make(map[tcpip.Address]multicastGroupState)
+ g.memberships = make(map[tcpip.Address]multicastGroupState)
+ g.protocolMU = protocolMU
}
-// MakeAllNonMember transitions all groups to the non-member state.
+// MakeAllNonMemberLocked transitions all groups to the non-member state.
//
// The groups will still be considered joined locally.
-func (g *GenericMulticastProtocolState) MakeAllNonMember() {
+//
+// Precondition: g.protocolMU must be locked.
+func (g *GenericMulticastProtocolState) MakeAllNonMemberLocked() {
if !g.opts.Enabled {
return
}
- g.mu.Lock()
- defer g.mu.Unlock()
-
- for groupAddress, info := range g.mu.memberships {
+ for groupAddress, info := range g.memberships {
g.transitionToNonMemberLocked(groupAddress, &info)
- g.mu.memberships[groupAddress] = info
+ g.memberships[groupAddress] = info
}
}
-// InitializeGroups initializes each group, as if they were newly joined but
-// without affecting the groups' join count.
+// InitializeGroupsLocked initializes each group, as if they were newly joined
+// but without affecting the groups' join count.
//
// Must only be called after calling MakeAllNonMember as a group should not be
// initialized while it is not in the non-member state.
-func (g *GenericMulticastProtocolState) InitializeGroups() {
+//
+// Precondition: g.protocolMU must be locked.
+func (g *GenericMulticastProtocolState) InitializeGroupsLocked() {
if !g.opts.Enabled {
return
}
- g.mu.Lock()
- defer g.mu.Unlock()
-
- for groupAddress, info := range g.mu.memberships {
+ for groupAddress, info := range g.memberships {
g.initializeNewMemberLocked(groupAddress, &info)
- g.mu.memberships[groupAddress] = info
+ g.memberships[groupAddress] = info
}
}
-// JoinGroup handles joining a new group.
+// JoinGroupLocked handles joining a new group.
//
// If dontInitialize is true, the group will be not be initialized and will be
// left in the non-member state - no packets will be sent for it until it is
// initialized via InitializeGroups.
-func (g *GenericMulticastProtocolState) JoinGroup(groupAddress tcpip.Address, dontInitialize bool) {
- g.mu.Lock()
- defer g.mu.Unlock()
-
- if info, ok := g.mu.memberships[groupAddress]; ok {
+//
+// Precondition: g.protocolMU must be locked.
+func (g *GenericMulticastProtocolState) JoinGroupLocked(groupAddress tcpip.Address, dontInitialize bool) {
+ if info, ok := g.memberships[groupAddress]; ok {
// The group has already been joined.
info.joins++
- g.mu.memberships[groupAddress] = info
+ g.memberships[groupAddress] = info
return
}
@@ -217,15 +230,15 @@ func (g *GenericMulticastProtocolState) JoinGroup(groupAddress tcpip.Address, do
// The state will be updated below, if required.
state: nonMember,
lastToSendReport: false,
- delayedReportJob: tcpip.NewJob(g.opts.Clock, &g.mu, func() {
- info, ok := g.mu.memberships[groupAddress]
+ delayedReportJob: tcpip.NewJob(g.opts.Clock, g.protocolMU, func() {
+ info, ok := g.memberships[groupAddress]
if !ok {
panic(fmt.Sprintf("expected to find group state for group = %s", groupAddress))
}
info.lastToSendReport = g.opts.Protocol.SendReport(groupAddress) == nil
info.state = idleMember
- g.mu.memberships[groupAddress] = info
+ g.memberships[groupAddress] = info
}),
}
@@ -233,25 +246,24 @@ func (g *GenericMulticastProtocolState) JoinGroup(groupAddress tcpip.Address, do
g.initializeNewMemberLocked(groupAddress, &info)
}
- g.mu.memberships[groupAddress] = info
+ g.memberships[groupAddress] = info
}
-// IsLocallyJoined returns true if the group is locally joined.
-func (g *GenericMulticastProtocolState) IsLocallyJoined(groupAddress tcpip.Address) bool {
- g.mu.RLock()
- defer g.mu.RUnlock()
- _, ok := g.mu.memberships[groupAddress]
+// IsLocallyJoinedRLocked returns true if the group is locally joined.
+//
+// Precondition: g.protocolMU must be read locked.
+func (g *GenericMulticastProtocolState) IsLocallyJoinedRLocked(groupAddress tcpip.Address) bool {
+ _, ok := g.memberships[groupAddress]
return ok
}
-// LeaveGroup handles leaving the group.
+// LeaveGroupLocked handles leaving the group.
//
// Returns false if the group is not currently joined.
-func (g *GenericMulticastProtocolState) LeaveGroup(groupAddress tcpip.Address) bool {
- g.mu.Lock()
- defer g.mu.Unlock()
-
- info, ok := g.mu.memberships[groupAddress]
+//
+// Precondition: g.protocolMU must be locked.
+func (g *GenericMulticastProtocolState) LeaveGroupLocked(groupAddress tcpip.Address) bool {
+ info, ok := g.memberships[groupAddress]
if !ok {
return false
}
@@ -262,30 +274,30 @@ func (g *GenericMulticastProtocolState) LeaveGroup(groupAddress tcpip.Address) b
info.joins--
if info.joins != 0 {
// If we still have outstanding joins, then do nothing further.
- g.mu.memberships[groupAddress] = info
+ g.memberships[groupAddress] = info
return true
}
g.transitionToNonMemberLocked(groupAddress, &info)
- delete(g.mu.memberships, groupAddress)
+ delete(g.memberships, groupAddress)
return true
}
-// HandleQuery handles a query message with the specified maximum response time.
+// HandleQueryLocked handles a query message with the specified maximum response
+// time.
//
// If the group address is unspecified, then reports will be scheduled for all
// joined groups.
//
// Report(s) will be scheduled to be sent after a random duration between 0 and
// the maximum response time.
-func (g *GenericMulticastProtocolState) HandleQuery(groupAddress tcpip.Address, maxResponseTime time.Duration) {
+//
+// Precondition: g.protocolMU must be locked.
+func (g *GenericMulticastProtocolState) HandleQueryLocked(groupAddress tcpip.Address, maxResponseTime time.Duration) {
if !g.opts.Enabled {
return
}
- g.mu.Lock()
- defer g.mu.Unlock()
-
// As per RFC 2236 section 2.4 (for IGMPv2),
//
// In a Membership Query message, the group address field is set to zero
@@ -299,28 +311,27 @@ func (g *GenericMulticastProtocolState) HandleQuery(groupAddress tcpip.Address,
// when sending a Multicast-Address-Specific Query.
if groupAddress.Unspecified() {
// This is a general query as the group address is unspecified.
- for groupAddress, info := range g.mu.memberships {
+ for groupAddress, info := range g.memberships {
g.setDelayTimerForAddressRLocked(groupAddress, &info, maxResponseTime)
- g.mu.memberships[groupAddress] = info
+ g.memberships[groupAddress] = info
}
- } else if info, ok := g.mu.memberships[groupAddress]; ok {
+ } else if info, ok := g.memberships[groupAddress]; ok {
g.setDelayTimerForAddressRLocked(groupAddress, &info, maxResponseTime)
- g.mu.memberships[groupAddress] = info
+ g.memberships[groupAddress] = info
}
}
-// HandleReport handles a report message.
+// HandleReportLocked handles a report message.
//
// If the report is for a joined group, any active delayed report will be
// cancelled and the host state for the group transitions to idle.
-func (g *GenericMulticastProtocolState) HandleReport(groupAddress tcpip.Address) {
+//
+// Precondition: g.protocolMU must be locked.
+func (g *GenericMulticastProtocolState) HandleReportLocked(groupAddress tcpip.Address) {
if !g.opts.Enabled {
return
}
- g.mu.Lock()
- defer g.mu.Unlock()
-
// As per RFC 2236 section 3 pages 3-4 (for IGMPv2),
//
// If the host receives another host's Report (version 1 or 2) while it has
@@ -333,17 +344,17 @@ func (g *GenericMulticastProtocolState) HandleReport(groupAddress tcpip.Address)
// multicast address while it has a timer running for that same address
// on that interface, it stops its timer and does not send a Report for
// that address, thus suppressing duplicate reports on the link.
- if info, ok := g.mu.memberships[groupAddress]; ok && info.state == delayingMember {
+ if info, ok := g.memberships[groupAddress]; ok && info.state == delayingMember {
info.delayedReportJob.Cancel()
info.lastToSendReport = false
info.state = idleMember
- g.mu.memberships[groupAddress] = info
+ g.memberships[groupAddress] = info
}
}
// initializeNewMemberLocked initializes a new group membership.
//
-// Precondition: g.mu must be locked.
+// Precondition: g.protocolMU must be locked.
func (g *GenericMulticastProtocolState) initializeNewMemberLocked(groupAddress tcpip.Address, info *multicastGroupState) {
if info.state != nonMember {
panic(fmt.Sprintf("state for group %s is not non-member; state = %d", groupAddress, info.state))
@@ -465,7 +476,7 @@ func (g *GenericMulticastProtocolState) maybeSendLeave(groupAddress tcpip.Addres
// transitionToNonMemberLocked transitions the given multicast group the the
// non-member/listener state.
//
-// Precondition: e.mu must be locked.
+// Precondition: g.protocolMU must be locked.
func (g *GenericMulticastProtocolState) transitionToNonMemberLocked(groupAddress tcpip.Address, info *multicastGroupState) {
if info.state == nonMember {
return
@@ -479,7 +490,7 @@ func (g *GenericMulticastProtocolState) transitionToNonMemberLocked(groupAddress
// setDelayTimerForAddressRLocked sets timer to send a delay report.
//
-// Precondition: g.mu MUST be read locked.
+// Precondition: g.protocolMU MUST be read locked.
func (g *GenericMulticastProtocolState) setDelayTimerForAddressRLocked(groupAddress tcpip.Address, info *multicastGroupState, maxResponseTime time.Duration) {
if info.state == nonMember {
return
diff --git a/pkg/tcpip/network/ip/generic_multicast_protocol_test.go b/pkg/tcpip/network/ip/generic_multicast_protocol_test.go
index 670be30d4..6fd0eb9f7 100644
--- a/pkg/tcpip/network/ip/generic_multicast_protocol_test.go
+++ b/pkg/tcpip/network/ip/generic_multicast_protocol_test.go
@@ -20,6 +20,7 @@ import (
"time"
"github.com/google/go-cmp/cmp"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/faketime"
"gvisor.dev/gvisor/pkg/tcpip/network/ip"
@@ -37,41 +38,86 @@ const (
var _ ip.MulticastGroupProtocol = (*mockMulticastGroupProtocol)(nil)
type mockMulticastGroupProtocol struct {
+ t *testing.T
+
+ mu sync.RWMutex
+
+ // Must only be accessed with mu held.
sendReportGroupAddrCount map[tcpip.Address]int
- sendLeaveGroupAddrCount map[tcpip.Address]int
+
+ // Must only be accessed with mu held.
+ sendLeaveGroupAddrCount map[tcpip.Address]int
}
func (m *mockMulticastGroupProtocol) init() {
+ m.mu.Lock()
+ defer m.mu.Unlock()
+ m.initLocked()
+}
+
+func (m *mockMulticastGroupProtocol) initLocked() {
m.sendReportGroupAddrCount = make(map[tcpip.Address]int)
m.sendLeaveGroupAddrCount = make(map[tcpip.Address]int)
}
func (m *mockMulticastGroupProtocol) SendReport(groupAddress tcpip.Address) *tcpip.Error {
+ if m.mu.TryLock() {
+ m.mu.Unlock()
+ m.t.Fatalf("got write lock, expected to not take the lock; generic multicast protocol must take the write lock before sending report for %s", groupAddress)
+ }
+ if m.mu.TryRLock() {
+ m.mu.RUnlock()
+ m.t.Fatalf("got read lock, expected to not take the lock; generic multicast protocol must take the write lock before sending report for %s", groupAddress)
+ }
+
m.sendReportGroupAddrCount[groupAddress]++
return nil
}
func (m *mockMulticastGroupProtocol) SendLeave(groupAddress tcpip.Address) *tcpip.Error {
+ if m.mu.TryLock() {
+ m.mu.Unlock()
+ m.t.Fatalf("got write lock, expected to not take the lock; generic multicast protocol must take the write lock before sending leave for %s", groupAddress)
+ }
+ if m.mu.TryRLock() {
+ m.mu.RUnlock()
+ m.t.Fatalf("got read lock, expected to not take the lock; generic multicast protocol must take the write lock before sending leave for %s", groupAddress)
+ }
+
m.sendLeaveGroupAddrCount[groupAddress]++
return nil
}
-func checkProtocol(mgp *mockMulticastGroupProtocol, sendReportGroupAddresses []tcpip.Address, sendLeaveGroupAddresses []tcpip.Address) string {
- sendReportGroupAddressesMap := make(map[tcpip.Address]int)
+func (m *mockMulticastGroupProtocol) check(sendReportGroupAddresses []tcpip.Address, sendLeaveGroupAddresses []tcpip.Address) string {
+ m.mu.Lock()
+ defer m.mu.Unlock()
+
+ sendReportGroupAddrCount := make(map[tcpip.Address]int)
for _, a := range sendReportGroupAddresses {
- sendReportGroupAddressesMap[a] = 1
+ sendReportGroupAddrCount[a] = 1
}
- sendLeaveGroupAddressesMap := make(map[tcpip.Address]int)
+ sendLeaveGroupAddrCount := make(map[tcpip.Address]int)
for _, a := range sendLeaveGroupAddresses {
- sendLeaveGroupAddressesMap[a] = 1
+ sendLeaveGroupAddrCount[a] = 1
}
- diff := cmp.Diff(mockMulticastGroupProtocol{
- sendReportGroupAddrCount: sendReportGroupAddressesMap,
- sendLeaveGroupAddrCount: sendLeaveGroupAddressesMap,
- }, *mgp, cmp.AllowUnexported(mockMulticastGroupProtocol{}))
- mgp.init()
+ diff := cmp.Diff(
+ &mockMulticastGroupProtocol{
+ sendReportGroupAddrCount: sendReportGroupAddrCount,
+ sendLeaveGroupAddrCount: sendLeaveGroupAddrCount,
+ },
+ m,
+ cmp.AllowUnexported(mockMulticastGroupProtocol{}),
+ // ignore mockMulticastGroupProtocol.mu and mockMulticastGroupProtocol.t
+ cmp.FilterPath(
+ func(p cmp.Path) bool {
+ return p.Last().String() == ".mu" || p.Last().String() == ".t"
+ },
+ cmp.Ignore(),
+ ),
+ )
+ m.initLocked()
return diff
}
@@ -96,10 +142,11 @@ func TestJoinGroup(t *testing.T) {
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
var g ip.GenericMulticastProtocolState
- var mgp mockMulticastGroupProtocol
- mgp.init()
+ mgp := mockMulticastGroupProtocol{t: t}
clock := faketime.NewManualClock()
- g.Init(ip.GenericMulticastProtocolOptions{
+
+ mgp.init()
+ g.Init(&mgp.mu, ip.GenericMulticastProtocolOptions{
Enabled: true,
Rand: rand.New(rand.NewSource(0)),
Clock: clock,
@@ -110,21 +157,24 @@ func TestJoinGroup(t *testing.T) {
// Joining a group should send a report immediately and another after
// a random interval between 0 and the maximum unsolicited report delay.
- g.JoinGroup(test.addr, false /* dontInitialize */)
+ mgp.mu.Lock()
+ g.JoinGroupLocked(test.addr, false /* dontInitialize */)
+ mgp.mu.Unlock()
if test.shouldSendReports {
- if diff := checkProtocol(&mgp, []tcpip.Address{test.addr} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ if diff := mgp.check([]tcpip.Address{test.addr} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
}
+ // Generic multicast protocol timers are expected to take the job mutex.
clock.Advance(maxUnsolicitedReportDelay)
- if diff := checkProtocol(&mgp, []tcpip.Address{test.addr} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ if diff := mgp.check([]tcpip.Address{test.addr} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
}
}
// Should have no more messages to send.
clock.Advance(time.Hour)
- if diff := checkProtocol(&mgp, nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
}
})
@@ -152,10 +202,11 @@ func TestLeaveGroup(t *testing.T) {
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
var g ip.GenericMulticastProtocolState
- var mgp mockMulticastGroupProtocol
- mgp.init()
+ mgp := mockMulticastGroupProtocol{t: t}
clock := faketime.NewManualClock()
- g.Init(ip.GenericMulticastProtocolOptions{
+
+ mgp.init()
+ g.Init(&mgp.mu, ip.GenericMulticastProtocolOptions{
Enabled: true,
Rand: rand.New(rand.NewSource(1)),
Clock: clock,
@@ -164,27 +215,36 @@ func TestLeaveGroup(t *testing.T) {
AllNodesAddress: addr2,
})
- g.JoinGroup(test.addr, false /* dontInitialize */)
+ mgp.mu.Lock()
+ g.JoinGroupLocked(test.addr, false /* dontInitialize */)
+ mgp.mu.Unlock()
if test.shouldSendMessages {
- if diff := checkProtocol(&mgp, []tcpip.Address{test.addr} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ if diff := mgp.check([]tcpip.Address{test.addr} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
}
}
// Leaving a group should send a leave report immediately and cancel any
// delayed reports.
- if !g.LeaveGroup(test.addr) {
- t.Fatalf("got g.LeaveGroup(%s) = false, want = true", test.addr)
+ {
+ mgp.mu.Lock()
+ res := g.LeaveGroupLocked(test.addr)
+ mgp.mu.Unlock()
+ if !res {
+ t.Fatalf("got g.LeaveGroupLocked(%s) = false, want = true", test.addr)
+ }
}
if test.shouldSendMessages {
- if diff := checkProtocol(&mgp, nil /* sendReportGroupAddresses */, []tcpip.Address{test.addr} /* sendLeaveGroupAddresses */); diff != "" {
+ if diff := mgp.check(nil /* sendReportGroupAddresses */, []tcpip.Address{test.addr} /* sendLeaveGroupAddresses */); diff != "" {
t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
}
}
// Should have no more messages to send.
+ //
+ // Generic multicast protocol timers are expected to take the job mutex.
clock.Advance(time.Hour)
- if diff := checkProtocol(&mgp, nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
}
})
@@ -227,10 +287,11 @@ func TestHandleReport(t *testing.T) {
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
var g ip.GenericMulticastProtocolState
- var mgp mockMulticastGroupProtocol
- mgp.init()
+ mgp := mockMulticastGroupProtocol{t: t}
clock := faketime.NewManualClock()
- g.Init(ip.GenericMulticastProtocolOptions{
+
+ mgp.init()
+ g.Init(&mgp.mu, ip.GenericMulticastProtocolOptions{
Enabled: true,
Rand: rand.New(rand.NewSource(2)),
Clock: clock,
@@ -239,32 +300,41 @@ func TestHandleReport(t *testing.T) {
AllNodesAddress: addr3,
})
- g.JoinGroup(addr1, false /* dontInitialize */)
- if diff := checkProtocol(&mgp, []tcpip.Address{addr1} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ mgp.mu.Lock()
+ g.JoinGroupLocked(addr1, false /* dontInitialize */)
+ mgp.mu.Unlock()
+ if diff := mgp.check([]tcpip.Address{addr1} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
}
- g.JoinGroup(addr2, false /* dontInitialize */)
- if diff := checkProtocol(&mgp, []tcpip.Address{addr2} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ mgp.mu.Lock()
+ g.JoinGroupLocked(addr2, false /* dontInitialize */)
+ mgp.mu.Unlock()
+ if diff := mgp.check([]tcpip.Address{addr2} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
}
- g.JoinGroup(addr3, false /* dontInitialize */)
- if diff := checkProtocol(&mgp, nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ mgp.mu.Lock()
+ g.JoinGroupLocked(addr3, false /* dontInitialize */)
+ mgp.mu.Unlock()
+ if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
}
// Receiving a report for a group we have a timer scheduled for should
// cancel our delayed report timer for the group.
- g.HandleReport(test.reportAddr)
+ mgp.mu.Lock()
+ g.HandleReportLocked(test.reportAddr)
+ mgp.mu.Unlock()
if len(test.expectReportsFor) != 0 {
+ // Generic multicast protocol timers are expected to take the job mutex.
clock.Advance(maxUnsolicitedReportDelay)
- if diff := checkProtocol(&mgp, test.expectReportsFor /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ if diff := mgp.check(test.expectReportsFor /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
}
}
// Should have no more messages to send.
clock.Advance(time.Hour)
- if diff := checkProtocol(&mgp, nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
}
})
@@ -313,10 +383,11 @@ func TestHandleQuery(t *testing.T) {
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
var g ip.GenericMulticastProtocolState
- var mgp mockMulticastGroupProtocol
- mgp.init()
+ mgp := mockMulticastGroupProtocol{t: t}
clock := faketime.NewManualClock()
- g.Init(ip.GenericMulticastProtocolOptions{
+
+ mgp.init()
+ g.Init(&mgp.mu, ip.GenericMulticastProtocolOptions{
Enabled: true,
Rand: rand.New(rand.NewSource(3)),
Clock: clock,
@@ -325,36 +396,45 @@ func TestHandleQuery(t *testing.T) {
AllNodesAddress: addr3,
})
- g.JoinGroup(addr1, false /* dontInitialize */)
- if diff := checkProtocol(&mgp, []tcpip.Address{addr1} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ mgp.mu.Lock()
+ g.JoinGroupLocked(addr1, false /* dontInitialize */)
+ mgp.mu.Unlock()
+ if diff := mgp.check([]tcpip.Address{addr1} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
}
- g.JoinGroup(addr2, false /* dontInitialize */)
- if diff := checkProtocol(&mgp, []tcpip.Address{addr2} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ mgp.mu.Lock()
+ g.JoinGroupLocked(addr2, false /* dontInitialize */)
+ mgp.mu.Unlock()
+ if diff := mgp.check([]tcpip.Address{addr2} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
}
- g.JoinGroup(addr3, false /* dontInitialize */)
- if diff := checkProtocol(&mgp, nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ mgp.mu.Lock()
+ g.JoinGroupLocked(addr3, false /* dontInitialize */)
+ mgp.mu.Unlock()
+ if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
}
+ // Generic multicast protocol timers are expected to take the job mutex.
clock.Advance(maxUnsolicitedReportDelay)
- if diff := checkProtocol(&mgp, []tcpip.Address{addr1, addr2} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ if diff := mgp.check([]tcpip.Address{addr1, addr2} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
}
// Receiving a query should make us schedule a new delayed report if it
// is a query directed at us or a general query.
- g.HandleQuery(test.queryAddr, test.maxDelay)
+ mgp.mu.Lock()
+ g.HandleQueryLocked(test.queryAddr, test.maxDelay)
+ mgp.mu.Unlock()
if len(test.expectReportsFor) != 0 {
clock.Advance(test.maxDelay)
- if diff := checkProtocol(&mgp, test.expectReportsFor /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ if diff := mgp.check(test.expectReportsFor /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
}
}
// Should have no more messages to send.
clock.Advance(time.Hour)
- if diff := checkProtocol(&mgp, nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
}
})
@@ -363,10 +443,11 @@ func TestHandleQuery(t *testing.T) {
func TestJoinCount(t *testing.T) {
var g ip.GenericMulticastProtocolState
- var mgp mockMulticastGroupProtocol
- mgp.init()
+ mgp := mockMulticastGroupProtocol{t: t}
clock := faketime.NewManualClock()
- g.Init(ip.GenericMulticastProtocolOptions{
+
+ mgp.init()
+ g.Init(&mgp.mu, ip.GenericMulticastProtocolOptions{
Enabled: true,
Rand: rand.New(rand.NewSource(4)),
Clock: clock,
@@ -375,70 +456,110 @@ func TestJoinCount(t *testing.T) {
})
// Set the join count to 2 for a group.
- g.JoinGroup(addr1, false /* dontInitialize */)
- if !g.IsLocallyJoined(addr1) {
- t.Fatalf("got g.IsLocallyJoined(%s) = false, want = true", addr1)
+ {
+ mgp.mu.Lock()
+ g.JoinGroupLocked(addr1, false /* dontInitialize */)
+ res := g.IsLocallyJoinedRLocked(addr1)
+ mgp.mu.Unlock()
+ if !res {
+ t.Fatalf("got g.IsLocallyJoinedRLocked(%s) = false, want = true", addr1)
+ }
}
// Only the first join should trigger a report to be sent.
- if diff := checkProtocol(&mgp, []tcpip.Address{addr1} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ if diff := mgp.check([]tcpip.Address{addr1} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
}
- g.JoinGroup(addr1, false /* dontInitialize */)
- if !g.IsLocallyJoined(addr1) {
- t.Fatalf("got g.IsLocallyJoined(%s) = false, want = true", addr1)
+ {
+ mgp.mu.Lock()
+ g.JoinGroupLocked(addr1, false /* dontInitialize */)
+ res := g.IsLocallyJoinedRLocked(addr1)
+ mgp.mu.Unlock()
+ if !res {
+ t.Errorf("got g.IsLocallyJoinedRLocked(%s) = false, want = true", addr1)
+ }
+ }
+ if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
}
- if diff := checkProtocol(&mgp, nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
- t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ if t.Failed() {
+ t.FailNow()
}
// Group should still be considered joined after leaving once.
- if !g.LeaveGroup(addr1) {
- t.Fatalf("got g.LeaveGroup(%s) = false, want = true", addr1)
- }
- if !g.IsLocallyJoined(addr1) {
- t.Fatalf("got g.IsLocallyJoined(%s) = false, want = true", addr1)
+ {
+ mgp.mu.Lock()
+ leaveGroupRes := g.LeaveGroupLocked(addr1)
+ isLocallyJoined := g.IsLocallyJoinedRLocked(addr1)
+ mgp.mu.Unlock()
+ if !leaveGroupRes {
+ t.Errorf("got g.LeaveGroupLocked(%s) = false, want = true", addr1)
+ }
+ if !isLocallyJoined {
+ t.Errorf("got g.IsLocallyJoinedRLocked(%s) = false, want = true", addr1)
+ }
}
// A leave report should only be sent once the join count reaches 0.
- if diff := checkProtocol(&mgp, nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
- t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ }
+ if t.Failed() {
+ t.FailNow()
}
// Leaving once more should actually remove us from the group.
- if !g.LeaveGroup(addr1) {
- t.Fatalf("got g.LeaveGroup(%s) = false, want = true", addr1)
+ {
+ mgp.mu.Lock()
+ leaveGroupRes := g.LeaveGroupLocked(addr1)
+ isLocallyJoined := g.IsLocallyJoinedRLocked(addr1)
+ mgp.mu.Unlock()
+ if !leaveGroupRes {
+ t.Errorf("got g.LeaveGroupLocked(%s) = false, want = true", addr1)
+ }
+ if isLocallyJoined {
+ t.Errorf("got g.IsLocallyJoinedRLocked(%s) = true, want = false", addr1)
+ }
}
- if g.IsLocallyJoined(addr1) {
- t.Fatalf("got g.IsLocallyJoined(%s) = true, want = false", addr1)
+ if diff := mgp.check(nil /* sendReportGroupAddresses */, []tcpip.Address{addr1} /* sendLeaveGroupAddresses */); diff != "" {
+ t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
}
- if diff := checkProtocol(&mgp, nil /* sendReportGroupAddresses */, []tcpip.Address{addr1} /* sendLeaveGroupAddresses */); diff != "" {
- t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ if t.Failed() {
+ t.FailNow()
}
// Group should no longer be joined so we should not have anything to
// leave.
- if g.LeaveGroup(addr1) {
- t.Fatalf("got g.LeaveGroup(%s) = true, want = false", addr1)
- }
- if g.IsLocallyJoined(addr1) {
- t.Fatalf("got g.IsLocallyJoined(%s) = true, want = false", addr1)
+ {
+ mgp.mu.Lock()
+ leaveGroupRes := g.LeaveGroupLocked(addr1)
+ isLocallyJoined := g.IsLocallyJoinedRLocked(addr1)
+ mgp.mu.Unlock()
+ if leaveGroupRes {
+ t.Errorf("got g.LeaveGroupLocked(%s) = true, want = false", addr1)
+ }
+ if isLocallyJoined {
+ t.Errorf("got g.IsLocallyJoinedRLocked(%s) = true, want = false", addr1)
+ }
}
- if diff := checkProtocol(&mgp, nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
- t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
+ if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
}
// Should have no more messages to send.
+ //
+ // Generic multicast protocol timers are expected to take the job mutex.
clock.Advance(time.Hour)
- if diff := checkProtocol(&mgp, nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
}
}
func TestMakeAllNonMemberAndInitialize(t *testing.T) {
var g ip.GenericMulticastProtocolState
- var mgp mockMulticastGroupProtocol
- mgp.init()
+ mgp := mockMulticastGroupProtocol{t: t}
clock := faketime.NewManualClock()
- g.Init(ip.GenericMulticastProtocolOptions{
+
+ mgp.init()
+ g.Init(&mgp.mu, ip.GenericMulticastProtocolOptions{
Enabled: true,
Rand: rand.New(rand.NewSource(3)),
Clock: clock,
@@ -447,48 +568,62 @@ func TestMakeAllNonMemberAndInitialize(t *testing.T) {
AllNodesAddress: addr3,
})
- g.JoinGroup(addr1, false /* dontInitialize */)
- if diff := checkProtocol(&mgp, []tcpip.Address{addr1} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ mgp.mu.Lock()
+ g.JoinGroupLocked(addr1, false /* dontInitialize */)
+ mgp.mu.Unlock()
+ if diff := mgp.check([]tcpip.Address{addr1} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
}
- g.JoinGroup(addr2, false /* dontInitialize */)
- if diff := checkProtocol(&mgp, []tcpip.Address{addr2} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ mgp.mu.Lock()
+ g.JoinGroupLocked(addr2, false /* dontInitialize */)
+ mgp.mu.Unlock()
+ if diff := mgp.check([]tcpip.Address{addr2} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
}
- g.JoinGroup(addr3, false /* dontInitialize */)
- if diff := checkProtocol(&mgp, nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ mgp.mu.Lock()
+ g.JoinGroupLocked(addr3, false /* dontInitialize */)
+ mgp.mu.Unlock()
+ if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
}
// Should send the leave reports for each but still consider them locally
// joined.
- g.MakeAllNonMember()
- if diff := checkProtocol(&mgp, nil /* sendReportGroupAddresses */, []tcpip.Address{addr1, addr2} /* sendLeaveGroupAddresses */); diff != "" {
+ mgp.mu.Lock()
+ g.MakeAllNonMemberLocked()
+ mgp.mu.Unlock()
+ if diff := mgp.check(nil /* sendReportGroupAddresses */, []tcpip.Address{addr1, addr2} /* sendLeaveGroupAddresses */); diff != "" {
t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
}
+ // Generic multicast protocol timers are expected to take the job mutex.
clock.Advance(time.Hour)
- if diff := checkProtocol(&mgp, nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
}
for _, group := range []tcpip.Address{addr1, addr2, addr3} {
- if !g.IsLocallyJoined(group) {
- t.Fatalf("got g.IsLocallyJoined(%s) = false, want = true", group)
+ mgp.mu.RLock()
+ res := g.IsLocallyJoinedRLocked(group)
+ mgp.mu.RUnlock()
+ if !res {
+ t.Fatalf("got g.IsLocallyJoinedRLocked(%s) = false, want = true", group)
}
}
// Should send the initial set of unsolcited reports.
- g.InitializeGroups()
- if diff := checkProtocol(&mgp, []tcpip.Address{addr1, addr2} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ mgp.mu.Lock()
+ g.InitializeGroupsLocked()
+ mgp.mu.Unlock()
+ if diff := mgp.check([]tcpip.Address{addr1, addr2} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
}
clock.Advance(maxUnsolicitedReportDelay)
- if diff := checkProtocol(&mgp, []tcpip.Address{addr1, addr2} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ if diff := mgp.check([]tcpip.Address{addr1, addr2} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
}
// Should have no more messages to send.
clock.Advance(time.Hour)
- if diff := checkProtocol(&mgp, nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
}
}
@@ -521,10 +656,11 @@ func TestGroupStateNonMember(t *testing.T) {
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
var g ip.GenericMulticastProtocolState
- var mgp mockMulticastGroupProtocol
- mgp.init()
+ mgp := mockMulticastGroupProtocol{t: t}
clock := faketime.NewManualClock()
- g.Init(ip.GenericMulticastProtocolOptions{
+
+ mgp.init()
+ g.Init(&mgp.mu, ip.GenericMulticastProtocolOptions{
Enabled: test.enabled,
Rand: rand.New(rand.NewSource(3)),
Clock: clock,
@@ -532,43 +668,65 @@ func TestGroupStateNonMember(t *testing.T) {
MaxUnsolicitedReportDelay: maxUnsolicitedReportDelay,
})
- g.JoinGroup(addr1, test.dontInitialize)
- if !g.IsLocallyJoined(addr1) {
- t.Fatalf("got g.IsLocallyJoined(%s) = false, want = true", addr1)
+ // Joining groups should not send any reports.
+ {
+ mgp.mu.Lock()
+ g.JoinGroupLocked(addr1, test.dontInitialize)
+ res := g.IsLocallyJoinedRLocked(addr1)
+ mgp.mu.Unlock()
+ if !res {
+ t.Fatalf("got g.IsLocallyJoinedRLocked(%s) = false, want = true", addr1)
+ }
}
- if diff := checkProtocol(&mgp, nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
}
-
- g.JoinGroup(addr2, test.dontInitialize)
- if !g.IsLocallyJoined(addr2) {
- t.Fatalf("got g.IsLocallyJoined(%s) = false, want = true", addr2)
+ {
+ mgp.mu.Lock()
+ g.JoinGroupLocked(addr2, test.dontInitialize)
+ res := g.IsLocallyJoinedRLocked(addr2)
+ mgp.mu.Unlock()
+ if !res {
+ t.Fatalf("got g.IsLocallyJoinedRLocked(%s) = false, want = true", addr2)
+ }
}
- if diff := checkProtocol(&mgp, nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
}
- g.HandleQuery(addr1, time.Nanosecond)
+ // Receiving a query should not send any reports.
+ mgp.mu.Lock()
+ g.HandleQueryLocked(addr1, time.Nanosecond)
+ mgp.mu.Unlock()
+ // Generic multicast protocol timers are expected to take the job mutex.
clock.Advance(time.Nanosecond)
- if diff := checkProtocol(&mgp, nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
}
- if !g.LeaveGroup(addr2) {
- t.Errorf("got g.LeaveGroup(%s) = false, want = true", addr2)
- }
- if !g.IsLocallyJoined(addr1) {
- t.Fatalf("got g.IsLocallyJoined(%s) = false, want = true", addr1)
- }
- if g.IsLocallyJoined(addr2) {
- t.Fatalf("got g.IsLocallyJoined(%s) = true, want = false", addr2)
+ // Leaving groups should not send any leave messages.
+ {
+ mgp.mu.Lock()
+ addr2LeaveRes := g.LeaveGroupLocked(addr2)
+ addr1IsJoined := g.IsLocallyJoinedRLocked(addr1)
+ addr2IsJoined := g.IsLocallyJoinedRLocked(addr2)
+ mgp.mu.Unlock()
+ if !addr2LeaveRes {
+ t.Errorf("got g.LeaveGroupLocked(%s) = false, want = true", addr2)
+ }
+ if !addr1IsJoined {
+ t.Errorf("got g.IsLocallyJoinedRLocked(%s) = false, want = true", addr1)
+ }
+ if addr2IsJoined {
+ t.Errorf("got g.IsLocallyJoinedRLocked(%s) = true, want = false", addr2)
+ }
}
- if diff := checkProtocol(&mgp, nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
}
clock.Advance(time.Hour)
- if diff := checkProtocol(&mgp, nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
+ if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" {
t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff)
}
})
diff --git a/pkg/tcpip/network/ipv4/igmp.go b/pkg/tcpip/network/ipv4/igmp.go
index 0134fadc0..2ee4c6445 100644
--- a/pkg/tcpip/network/ipv4/igmp.go
+++ b/pkg/tcpip/network/ipv4/igmp.go
@@ -16,7 +16,6 @@ package ipv4
import (
"fmt"
- "sync"
"sync/atomic"
"time"
@@ -68,8 +67,9 @@ var _ ip.MulticastGroupProtocol = (*igmpState)(nil)
// igmpState.init() MUST be called after creating an IGMP state.
type igmpState struct {
// The IPv4 endpoint this igmpState is for.
- ep *endpoint
- opts IGMPOptions
+ ep *endpoint
+
+ genericMulticastProtocol ip.GenericMulticastProtocolState
// igmpV1Present is for maintaining compatibility with IGMPv1 Routers, from
// RFC 2236 Section 4 Page 6: "The IGMPv1 router expects Version 1
@@ -84,16 +84,10 @@ type igmpState struct {
// when false.
igmpV1Present uint32
- mu struct {
- sync.RWMutex
-
- genericMulticastProtocol ip.GenericMulticastProtocolState
-
- // igmpV1Job is scheduled when this interface receives an IGMPv1 style
- // message, upon expiration the igmpV1Present flag is cleared.
- // igmpV1Job may not be nil once igmpState is initialized.
- igmpV1Job *tcpip.Job
- }
+ // igmpV1Job is scheduled when this interface receives an IGMPv1 style
+ // message, upon expiration the igmpV1Present flag is cleared.
+ // igmpV1Job may not be nil once igmpState is initialized.
+ igmpV1Job *tcpip.Job
}
// SendReport implements ip.MulticastGroupProtocol.
@@ -119,13 +113,12 @@ func (igmp *igmpState) SendLeave(groupAddress tcpip.Address) *tcpip.Error {
// init sets up an igmpState struct, and is required to be called before using
// a new igmpState.
-func (igmp *igmpState) init(ep *endpoint, opts IGMPOptions) {
- igmp.mu.Lock()
- defer igmp.mu.Unlock()
+//
+// Must only be called once for the lifetime of igmp.
+func (igmp *igmpState) init(ep *endpoint) {
igmp.ep = ep
- igmp.opts = opts
- igmp.mu.genericMulticastProtocol.Init(ip.GenericMulticastProtocolOptions{
- Enabled: opts.Enabled,
+ igmp.genericMulticastProtocol.Init(&ep.mu.RWMutex, ip.GenericMulticastProtocolOptions{
+ Enabled: ep.protocol.options.IGMP.Enabled,
Rand: ep.protocol.stack.Rand(),
Clock: ep.protocol.stack.Clock(),
Protocol: igmp,
@@ -133,11 +126,14 @@ func (igmp *igmpState) init(ep *endpoint, opts IGMPOptions) {
AllNodesAddress: header.IPv4AllSystems,
})
igmp.igmpV1Present = igmpV1PresentDefault
- igmp.mu.igmpV1Job = igmp.ep.protocol.stack.NewJob(&igmp.mu, func() {
+ igmp.igmpV1Job = ep.protocol.stack.NewJob(&ep.mu, func() {
igmp.setV1Present(false)
})
}
+// handleIGMP handles an IGMP packet.
+//
+// Precondition: igmp.ep.mu must be locked.
func (igmp *igmpState) handleIGMP(pkt *stack.PacketBuffer) {
stats := igmp.ep.protocol.stack.Stats()
received := stats.IGMP.PacketsReceived
@@ -207,27 +203,28 @@ func (igmp *igmpState) setV1Present(v bool) {
}
}
+// handleMembershipQuery handles a membership query.
+//
+// Precondition: igmp.ep.mu must be locked.
func (igmp *igmpState) handleMembershipQuery(groupAddress tcpip.Address, maxRespTime time.Duration) {
- igmp.mu.Lock()
- defer igmp.mu.Unlock()
-
// As per RFC 2236 Section 6, Page 10: If the maximum response time is zero
// then change the state to note that an IGMPv1 router is present and
// schedule the query received Job.
- if maxRespTime == 0 && igmp.opts.Enabled {
- igmp.mu.igmpV1Job.Cancel()
- igmp.mu.igmpV1Job.Schedule(v1RouterPresentTimeout)
+ if maxRespTime == 0 && igmp.ep.protocol.options.IGMP.Enabled {
+ igmp.igmpV1Job.Cancel()
+ igmp.igmpV1Job.Schedule(v1RouterPresentTimeout)
igmp.setV1Present(true)
maxRespTime = v1MaxRespTime
}
- igmp.mu.genericMulticastProtocol.HandleQuery(groupAddress, maxRespTime)
+ igmp.genericMulticastProtocol.HandleQueryLocked(groupAddress, maxRespTime)
}
+// handleMembershipReport handles a membership report.
+//
+// Precondition: igmp.ep.mu must be locked.
func (igmp *igmpState) handleMembershipReport(groupAddress tcpip.Address) {
- igmp.mu.Lock()
- defer igmp.mu.Unlock()
- igmp.mu.genericMulticastProtocol.HandleReport(groupAddress)
+ igmp.genericMulticastProtocol.HandleReportLocked(groupAddress)
}
// writePacket assembles and sends an IGMP packet with the provided fields,
@@ -278,28 +275,27 @@ func (igmp *igmpState) writePacket(destAddress tcpip.Address, groupAddress tcpip
//
// If the group already exists in the membership map, returns
// tcpip.ErrDuplicateAddress.
+//
+// Precondition: igmp.ep.mu must be locked.
func (igmp *igmpState) joinGroup(groupAddress tcpip.Address) {
- igmp.mu.Lock()
- defer igmp.mu.Unlock()
- igmp.mu.genericMulticastProtocol.JoinGroup(groupAddress, !igmp.ep.Enabled() /* dontInitialize */)
+ igmp.genericMulticastProtocol.JoinGroupLocked(groupAddress, !igmp.ep.Enabled() /* dontInitialize */)
}
// isInGroup returns true if the specified group has been joined locally.
+//
+// Precondition: igmp.ep.mu must be read locked.
func (igmp *igmpState) isInGroup(groupAddress tcpip.Address) bool {
- igmp.mu.Lock()
- defer igmp.mu.Unlock()
- return igmp.mu.genericMulticastProtocol.IsLocallyJoined(groupAddress)
+ return igmp.genericMulticastProtocol.IsLocallyJoinedRLocked(groupAddress)
}
// leaveGroup handles removing the group from the membership map, cancels any
// delay timers associated with that group, and sends the Leave Group message
// if required.
+//
+// Precondition: igmp.ep.mu must be locked.
func (igmp *igmpState) leaveGroup(groupAddress tcpip.Address) *tcpip.Error {
- igmp.mu.Lock()
- defer igmp.mu.Unlock()
-
// LeaveGroup returns false only if the group was not joined.
- if igmp.mu.genericMulticastProtocol.LeaveGroup(groupAddress) {
+ if igmp.genericMulticastProtocol.LeaveGroupLocked(groupAddress) {
return nil
}
@@ -308,16 +304,16 @@ func (igmp *igmpState) leaveGroup(groupAddress tcpip.Address) *tcpip.Error {
// softLeaveAll leaves all groups from the perspective of IGMP, but remains
// joined locally.
+//
+// Precondition: igmp.ep.mu must be locked.
func (igmp *igmpState) softLeaveAll() {
- igmp.mu.Lock()
- defer igmp.mu.Unlock()
- igmp.mu.genericMulticastProtocol.MakeAllNonMember()
+ igmp.genericMulticastProtocol.MakeAllNonMemberLocked()
}
// initializeAll attemps to initialize the IGMP state for each group that has
// been joined locally.
+//
+// Precondition: igmp.ep.mu must be locked.
func (igmp *igmpState) initializeAll() {
- igmp.mu.Lock()
- defer igmp.mu.Unlock()
- igmp.mu.genericMulticastProtocol.InitializeGroups()
+ igmp.genericMulticastProtocol.InitializeGroupsLocked()
}
diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go
index 3076185cd..c63ecca4a 100644
--- a/pkg/tcpip/network/ipv4/ipv4.go
+++ b/pkg/tcpip/network/ipv4/ipv4.go
@@ -72,7 +72,6 @@ type endpoint struct {
nic stack.NetworkInterface
dispatcher stack.TransportDispatcher
protocol *protocol
- igmp igmpState
// enabled is set to 1 when the enpoint is enabled and 0 when it is
// disabled.
@@ -84,6 +83,7 @@ type endpoint struct {
sync.RWMutex
addressableEndpointState stack.AddressableEndpointState
+ igmp igmpState
}
}
@@ -94,8 +94,10 @@ func (p *protocol) NewEndpoint(nic stack.NetworkInterface, _ stack.LinkAddressCa
dispatcher: dispatcher,
protocol: p,
}
+ e.mu.Lock()
e.mu.addressableEndpointState.Init(e)
- e.igmp.init(e, p.options.IGMP)
+ e.mu.igmp.init(e)
+ e.mu.Unlock()
return e
}
@@ -127,7 +129,7 @@ func (e *endpoint) Enable() *tcpip.Error {
// endpoint may have left groups from the perspective of IGMP when the
// endpoint was disabled. Either way, we need to let routers know to
// send us multicast traffic.
- e.igmp.initializeAll()
+ e.mu.igmp.initializeAll()
// As per RFC 1122 section 3.3.7, all hosts should join the all-hosts
// multicast group. Note, the IANA calls the all-hosts multicast group the
@@ -181,7 +183,7 @@ func (e *endpoint) disableLocked() {
// Leave groups from the perspective of IGMP so that routers know that
// we are no longer interested in the group.
- e.igmp.softLeaveAll()
+ e.mu.igmp.softLeaveAll()
// The address may have already been removed.
if err := e.mu.addressableEndpointState.RemovePermanentAddress(ipv4BroadcastAddr.Address); err != nil && err != tcpip.ErrBadLocalAddress {
@@ -718,7 +720,9 @@ func (e *endpoint) handlePacket(pkt *stack.PacketBuffer) {
return
}
if p == header.IGMPProtocolNumber {
- e.igmp.handleIGMP(pkt)
+ e.mu.Lock()
+ e.mu.igmp.handleIGMP(pkt)
+ e.mu.Unlock()
return
}
if opts := h.Options(); len(opts) != 0 {
@@ -843,7 +847,7 @@ func (e *endpoint) joinGroupLocked(addr tcpip.Address) *tcpip.Error {
return tcpip.ErrBadAddress
}
- e.igmp.joinGroup(addr)
+ e.mu.igmp.joinGroup(addr)
return nil
}
@@ -858,14 +862,14 @@ func (e *endpoint) LeaveGroup(addr tcpip.Address) *tcpip.Error {
//
// Precondition: e.mu must be locked.
func (e *endpoint) leaveGroupLocked(addr tcpip.Address) *tcpip.Error {
- return e.igmp.leaveGroup(addr)
+ return e.mu.igmp.leaveGroup(addr)
}
// IsInGroup implements stack.GroupAddressableEndpoint.
func (e *endpoint) IsInGroup(addr tcpip.Address) bool {
e.mu.RLock()
defer e.mu.RUnlock()
- return e.igmp.isInGroup(addr)
+ return e.mu.igmp.isInGroup(addr)
}
var _ stack.ForwardingNetworkProtocol = (*protocol)(nil)
diff --git a/pkg/tcpip/network/ipv6/icmp.go b/pkg/tcpip/network/ipv6/icmp.go
index 510276b8e..6ee162713 100644
--- a/pkg/tcpip/network/ipv6/icmp.go
+++ b/pkg/tcpip/network/ipv6/icmp.go
@@ -645,26 +645,34 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool) {
}
case header.ICMPv6MulticastListenerQuery, header.ICMPv6MulticastListenerReport, header.ICMPv6MulticastListenerDone:
- var handler func(header.MLD)
switch icmpType {
case header.ICMPv6MulticastListenerQuery:
received.MulticastListenerQuery.Increment()
- handler = e.mld.handleMulticastListenerQuery
case header.ICMPv6MulticastListenerReport:
received.MulticastListenerReport.Increment()
- handler = e.mld.handleMulticastListenerReport
case header.ICMPv6MulticastListenerDone:
received.MulticastListenerDone.Increment()
default:
panic(fmt.Sprintf("unrecognized MLD message = %d", icmpType))
}
+
if pkt.Data.Size()-header.ICMPv6HeaderSize < header.MLDMinimumSize {
received.Invalid.Increment()
return
}
- if handler != nil {
- handler(header.MLD(payload.ToView()))
+ switch icmpType {
+ case header.ICMPv6MulticastListenerQuery:
+ e.mu.Lock()
+ e.mu.mld.handleMulticastListenerQuery(header.MLD(payload.ToView()))
+ e.mu.Unlock()
+ case header.ICMPv6MulticastListenerReport:
+ e.mu.Lock()
+ e.mu.mld.handleMulticastListenerReport(header.MLD(payload.ToView()))
+ e.mu.Unlock()
+ case header.ICMPv6MulticastListenerDone:
+ default:
+ panic(fmt.Sprintf("unrecognized MLD message = %d", icmpType))
}
default:
diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go
index 8bf84601f..7288e309c 100644
--- a/pkg/tcpip/network/ipv6/ipv6.go
+++ b/pkg/tcpip/network/ipv6/ipv6.go
@@ -85,9 +85,8 @@ type endpoint struct {
addressableEndpointState stack.AddressableEndpointState
ndp ndpState
+ mld mldState
}
-
- mld mldState
}
// NICNameFromID is a function that returns a stable name for the specified NIC,
@@ -232,7 +231,7 @@ func (e *endpoint) Enable() *tcpip.Error {
// endpoint may have left groups from the perspective of MLD when the
// endpoint was disabled. Either way, we need to let routers know to
// send us multicast traffic.
- e.mld.initializeAll()
+ e.mu.mld.initializeAll()
// Join the IPv6 All-Nodes Multicast group if the stack is configured to
// use IPv6. This is required to ensure that this node properly receives
@@ -349,7 +348,7 @@ func (e *endpoint) disableLocked() {
// Leave groups from the perspective of MLD so that routers know that
// we are no longer interested in the group.
- e.mld.softLeaveAll()
+ e.mu.mld.softLeaveAll()
}
// stopDADForPermanentAddressesLocked stops DAD for all permaneent addresses.
@@ -1417,7 +1416,7 @@ func (e *endpoint) joinGroupLocked(addr tcpip.Address) *tcpip.Error {
return tcpip.ErrBadAddress
}
- e.mld.joinGroup(addr)
+ e.mu.mld.joinGroup(addr)
return nil
}
@@ -1432,14 +1431,14 @@ func (e *endpoint) LeaveGroup(addr tcpip.Address) *tcpip.Error {
//
// Precondition: e.mu must be locked.
func (e *endpoint) leaveGroupLocked(addr tcpip.Address) *tcpip.Error {
- return e.mld.leaveGroup(addr)
+ return e.mu.mld.leaveGroup(addr)
}
// IsInGroup implements stack.GroupAddressableEndpoint.
func (e *endpoint) IsInGroup(addr tcpip.Address) bool {
e.mu.RLock()
defer e.mu.RUnlock()
- return e.mld.isInGroup(addr)
+ return e.mu.mld.isInGroup(addr)
}
var _ stack.ForwardingNetworkProtocol = (*protocol)(nil)
@@ -1504,17 +1503,11 @@ func (p *protocol) NewEndpoint(nic stack.NetworkInterface, linkAddrCache stack.L
dispatcher: dispatcher,
protocol: p,
}
+ e.mu.Lock()
e.mu.addressableEndpointState.Init(e)
- e.mu.ndp = ndpState{
- ep: e,
- configs: p.options.NDPConfigs,
- dad: make(map[tcpip.Address]dadState),
- defaultRouters: make(map[tcpip.Address]defaultRouterState),
- onLinkPrefixes: make(map[tcpip.Subnet]onLinkPrefixState),
- slaacPrefixes: make(map[tcpip.Subnet]slaacPrefixState),
- }
- e.mu.ndp.initializeTempAddrState()
- e.mld.init(e, p.options.MLD)
+ e.mu.ndp.init(e)
+ e.mu.mld.init(e)
+ e.mu.Unlock()
p.mu.Lock()
defer p.mu.Unlock()
diff --git a/pkg/tcpip/network/ipv6/mld.go b/pkg/tcpip/network/ipv6/mld.go
index 4c06b3f0c..6face17c6 100644
--- a/pkg/tcpip/network/ipv6/mld.go
+++ b/pkg/tcpip/network/ipv6/mld.go
@@ -67,10 +67,12 @@ func (mld *mldState) SendLeave(groupAddress tcpip.Address) *tcpip.Error {
// init sets up an mldState struct, and is required to be called before using
// a new mldState.
-func (mld *mldState) init(ep *endpoint, opts MLDOptions) {
+//
+// Must only be called once for the lifetime of mld.
+func (mld *mldState) init(ep *endpoint) {
mld.ep = ep
- mld.genericMulticastProtocol.Init(ip.GenericMulticastProtocolOptions{
- Enabled: opts.Enabled,
+ mld.genericMulticastProtocol.Init(&ep.mu.RWMutex, ip.GenericMulticastProtocolOptions{
+ Enabled: ep.protocol.options.MLD.Enabled,
Rand: ep.protocol.stack.Rand(),
Clock: ep.protocol.stack.Clock(),
Protocol: mld,
@@ -79,33 +81,45 @@ func (mld *mldState) init(ep *endpoint, opts MLDOptions) {
})
}
+// handleMulticastListenerQuery handles a query message.
+//
+// Precondition: mld.ep.mu must be locked.
func (mld *mldState) handleMulticastListenerQuery(mldHdr header.MLD) {
- mld.genericMulticastProtocol.HandleQuery(mldHdr.MulticastAddress(), mldHdr.MaximumResponseDelay())
+ mld.genericMulticastProtocol.HandleQueryLocked(mldHdr.MulticastAddress(), mldHdr.MaximumResponseDelay())
}
+// handleMulticastListenerReport handles a report message.
+//
+// Precondition: mld.ep.mu must be locked.
func (mld *mldState) handleMulticastListenerReport(mldHdr header.MLD) {
- mld.genericMulticastProtocol.HandleReport(mldHdr.MulticastAddress())
+ mld.genericMulticastProtocol.HandleReportLocked(mldHdr.MulticastAddress())
}
// joinGroup handles joining a new group and sending and scheduling the required
// messages.
//
// If the group is already joined, returns tcpip.ErrDuplicateAddress.
+//
+// Precondition: mld.ep.mu must be locked.
func (mld *mldState) joinGroup(groupAddress tcpip.Address) {
- mld.genericMulticastProtocol.JoinGroup(groupAddress, !mld.ep.Enabled() /* dontInitialize */)
+ mld.genericMulticastProtocol.JoinGroupLocked(groupAddress, !mld.ep.Enabled() /* dontInitialize */)
}
// isInGroup returns true if the specified group has been joined locally.
+//
+// Precondition: mld.ep.mu must be read locked.
func (mld *mldState) isInGroup(groupAddress tcpip.Address) bool {
- return mld.genericMulticastProtocol.IsLocallyJoined(groupAddress)
+ return mld.genericMulticastProtocol.IsLocallyJoinedRLocked(groupAddress)
}
// leaveGroup handles removing the group from the membership map, cancels any
// delay timers associated with that group, and sends the Done message, if
// required.
+//
+// Precondition: mld.ep.mu must be locked.
func (mld *mldState) leaveGroup(groupAddress tcpip.Address) *tcpip.Error {
// LeaveGroup returns false only if the group was not joined.
- if mld.genericMulticastProtocol.LeaveGroup(groupAddress) {
+ if mld.genericMulticastProtocol.LeaveGroupLocked(groupAddress) {
return nil
}
@@ -114,14 +128,18 @@ func (mld *mldState) leaveGroup(groupAddress tcpip.Address) *tcpip.Error {
// softLeaveAll leaves all groups from the perspective of MLD, but remains
// joined locally.
+//
+// Precondition: mld.ep.mu must be locked.
func (mld *mldState) softLeaveAll() {
- mld.genericMulticastProtocol.MakeAllNonMember()
+ mld.genericMulticastProtocol.MakeAllNonMemberLocked()
}
// initializeAll attemps to initialize the MLD state for each group that has
// been joined locally.
+//
+// Precondition: mld.ep.mu must be locked.
func (mld *mldState) initializeAll() {
- mld.genericMulticastProtocol.InitializeGroups()
+ mld.genericMulticastProtocol.InitializeGroupsLocked()
}
func (mld *mldState) writePacket(destAddress, groupAddress tcpip.Address, mldType header.ICMPv6Type) *tcpip.Error {
diff --git a/pkg/tcpip/network/ipv6/ndp.go b/pkg/tcpip/network/ipv6/ndp.go
index 8cb7d4dab..2f5e2e82c 100644
--- a/pkg/tcpip/network/ipv6/ndp.go
+++ b/pkg/tcpip/network/ipv6/ndp.go
@@ -20,6 +20,7 @@ import (
"math/rand"
"time"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
@@ -459,6 +460,9 @@ func (c *NDPConfigurations) validate() {
// ndpState is the per-interface NDP state.
type ndpState struct {
+ // Do not allow overwriting this state.
+ _ sync.NoCopy
+
// The IPv6 endpoint this ndpState is for.
ep *endpoint
@@ -1884,11 +1888,19 @@ func (ndp *ndpState) stopSolicitingRouters() {
ndp.rtrSolicitJob = nil
}
-// initializeTempAddrState initializes state related to temporary SLAAC
-// addresses.
-func (ndp *ndpState) initializeTempAddrState() {
- header.InitialTempIID(ndp.temporaryIIDHistory[:], ndp.ep.protocol.options.TempIIDSeed, ndp.ep.nic.ID())
+func (ndp *ndpState) init(ep *endpoint) {
+ if ndp.dad != nil {
+ panic("attempted to initialize NDP state twice")
+ }
+ ndp.ep = ep
+ ndp.configs = ep.protocol.options.NDPConfigs
+ ndp.dad = make(map[tcpip.Address]dadState)
+ ndp.defaultRouters = make(map[tcpip.Address]defaultRouterState)
+ ndp.onLinkPrefixes = make(map[tcpip.Subnet]onLinkPrefixState)
+ ndp.slaacPrefixes = make(map[tcpip.Subnet]slaacPrefixState)
+
+ header.InitialTempIID(ndp.temporaryIIDHistory[:], ndp.ep.protocol.options.TempIIDSeed, ndp.ep.nic.ID())
if MaxDesyncFactor != 0 {
ndp.temporaryAddressDesyncFactor = time.Duration(rand.Int63n(int64(MaxDesyncFactor)))
}
diff --git a/test/packetimpact/runner/dut.go b/test/packetimpact/runner/dut.go
index 8be2c6526..3e26c73cb 100644
--- a/test/packetimpact/runner/dut.go
+++ b/test/packetimpact/runner/dut.go
@@ -162,7 +162,7 @@ func setUpDUT(ctx context.Context, t *testing.T, id int, mkDevice func(*dockerut
Image: "packetimpact",
CapAdd: []string{"NET_ADMIN"},
}
- if _, err := mountTempDirectory(t, &runOpts, "dut-output", testOutputDir); err != nil {
+ if _, err := MountTempDirectory(t, &runOpts, "dut-output", testOutputDir); err != nil {
return dutInfo{}, err
}
@@ -228,7 +228,7 @@ func TestWithDUT(ctx context.Context, t *testing.T, mkDevice func(*dockerutil.Co
Image: "packetimpact",
CapAdd: []string{"NET_ADMIN"},
}
- if _, err := mountTempDirectory(t, &runOpts, "testbench-output", testOutputDir); err != nil {
+ if _, err := MountTempDirectory(t, &runOpts, "testbench-output", testOutputDir); err != nil {
t.Fatal(err)
}
tbb := path.Base(testbenchBinary)
@@ -565,11 +565,11 @@ func StartContainer(ctx context.Context, runOpts dockerutil.RunOpts, c *dockerut
return nil
}
-// mountTempDirectory creates a temporary directory on host with the template
+// MountTempDirectory creates a temporary directory on host with the template
// and then mounts it into the container under the name provided. The temporary
// directory name is returned. Content in that directory will be copied to
// TEST_UNDECLARED_OUTPUTS_DIR in cleanup phase.
-func mountTempDirectory(t *testing.T, runOpts *dockerutil.RunOpts, hostDirTemplate, containerDir string) (string, error) {
+func MountTempDirectory(t *testing.T, runOpts *dockerutil.RunOpts, hostDirTemplate, containerDir string) (string, error) {
t.Helper()
tmpDir, err := ioutil.TempDir("", hostDirTemplate)
if err != nil {