diff options
-rw-r--r-- | pkg/tcpip/network/ip/BUILD | 1 | ||||
-rw-r--r-- | pkg/tcpip/network/ip/generic_multicast_protocol.go | 145 | ||||
-rw-r--r-- | pkg/tcpip/network/ip/generic_multicast_protocol_test.go | 412 | ||||
-rw-r--r-- | pkg/tcpip/network/ipv4/igmp.go | 88 | ||||
-rw-r--r-- | pkg/tcpip/network/ipv4/ipv4.go | 20 | ||||
-rw-r--r-- | pkg/tcpip/network/ipv6/icmp.go | 18 | ||||
-rw-r--r-- | pkg/tcpip/network/ipv6/ipv6.go | 27 | ||||
-rw-r--r-- | pkg/tcpip/network/ipv6/mld.go | 38 | ||||
-rw-r--r-- | pkg/tcpip/network/ipv6/ndp.go | 20 | ||||
-rw-r--r-- | test/packetimpact/runner/dut.go | 8 |
10 files changed, 489 insertions, 288 deletions
diff --git a/pkg/tcpip/network/ip/BUILD b/pkg/tcpip/network/ip/BUILD index 6ca200b48..ca1247c1e 100644 --- a/pkg/tcpip/network/ip/BUILD +++ b/pkg/tcpip/network/ip/BUILD @@ -18,6 +18,7 @@ go_test( srcs = ["generic_multicast_protocol_test.go"], deps = [ ":ip", + "//pkg/sync", "//pkg/tcpip", "//pkg/tcpip/faketime", "@com_github_google_go_cmp//cmp:go_default_library", diff --git a/pkg/tcpip/network/ip/generic_multicast_protocol.go b/pkg/tcpip/network/ip/generic_multicast_protocol.go index e308550c4..bc7f7f637 100644 --- a/pkg/tcpip/network/ip/generic_multicast_protocol.go +++ b/pkg/tcpip/network/ip/generic_multicast_protocol.go @@ -138,76 +138,89 @@ type MulticastGroupProtocol interface { // IPv4 and IPv6. Specifically, Generic Multicast Protocol is the core state // machine of IGMPv2 as defined by RFC 2236 and MLDv1 as defined by RFC 2710. // +// Callers must synchronize accesses to the generic multicast protocol state; +// GenericMulticastProtocolState obtains no locks in any of its methods. The +// only exception to this is GenericMulticastProtocolState's timer/job callbacks +// which will obtain the lock provided to the GenericMulticastProtocolState when +// it is initialized. +// // GenericMulticastProtocolState.Init MUST be called before calling any of // the methods on GenericMulticastProtocolState. type GenericMulticastProtocolState struct { + // Do not allow overwriting this state. + _ sync.NoCopy + opts GenericMulticastProtocolOptions - mu struct { - sync.RWMutex + // memberships holds group addresses and their associated state. + memberships map[tcpip.Address]multicastGroupState - // memberships holds group addresses and their associated state. - memberships map[tcpip.Address]multicastGroupState - } + // protocolMU is the mutex used to protect the protocol. + protocolMU *sync.RWMutex } // Init initializes the Generic Multicast Protocol state. -func (g *GenericMulticastProtocolState) Init(opts GenericMulticastProtocolOptions) { - g.mu.Lock() - defer g.mu.Unlock() +// +// Must only be called once for the lifetime of g. +// +// The GenericMulticastProtocolState will only grab the lock when timers/jobs +// fire. +func (g *GenericMulticastProtocolState) Init(protocolMU *sync.RWMutex, opts GenericMulticastProtocolOptions) { + if g.memberships != nil { + panic("attempted to initialize generic membership protocol state twice") + } + g.opts = opts - g.mu.memberships = make(map[tcpip.Address]multicastGroupState) + g.memberships = make(map[tcpip.Address]multicastGroupState) + g.protocolMU = protocolMU } -// MakeAllNonMember transitions all groups to the non-member state. +// MakeAllNonMemberLocked transitions all groups to the non-member state. // // The groups will still be considered joined locally. -func (g *GenericMulticastProtocolState) MakeAllNonMember() { +// +// Precondition: g.protocolMU must be locked. +func (g *GenericMulticastProtocolState) MakeAllNonMemberLocked() { if !g.opts.Enabled { return } - g.mu.Lock() - defer g.mu.Unlock() - - for groupAddress, info := range g.mu.memberships { + for groupAddress, info := range g.memberships { g.transitionToNonMemberLocked(groupAddress, &info) - g.mu.memberships[groupAddress] = info + g.memberships[groupAddress] = info } } -// InitializeGroups initializes each group, as if they were newly joined but -// without affecting the groups' join count. +// InitializeGroupsLocked initializes each group, as if they were newly joined +// but without affecting the groups' join count. // // Must only be called after calling MakeAllNonMember as a group should not be // initialized while it is not in the non-member state. -func (g *GenericMulticastProtocolState) InitializeGroups() { +// +// Precondition: g.protocolMU must be locked. +func (g *GenericMulticastProtocolState) InitializeGroupsLocked() { if !g.opts.Enabled { return } - g.mu.Lock() - defer g.mu.Unlock() - - for groupAddress, info := range g.mu.memberships { + for groupAddress, info := range g.memberships { g.initializeNewMemberLocked(groupAddress, &info) - g.mu.memberships[groupAddress] = info + g.memberships[groupAddress] = info } } -// JoinGroup handles joining a new group. +// JoinGroupLocked handles joining a new group. // // If dontInitialize is true, the group will be not be initialized and will be // left in the non-member state - no packets will be sent for it until it is // initialized via InitializeGroups. -func (g *GenericMulticastProtocolState) JoinGroup(groupAddress tcpip.Address, dontInitialize bool) { - g.mu.Lock() - defer g.mu.Unlock() - - if info, ok := g.mu.memberships[groupAddress]; ok { +// +// Precondition: g.protocolMU must be locked. +func (g *GenericMulticastProtocolState) JoinGroupLocked(groupAddress tcpip.Address, dontInitialize bool) { + if info, ok := g.memberships[groupAddress]; ok { // The group has already been joined. info.joins++ - g.mu.memberships[groupAddress] = info + g.memberships[groupAddress] = info return } @@ -217,15 +230,15 @@ func (g *GenericMulticastProtocolState) JoinGroup(groupAddress tcpip.Address, do // The state will be updated below, if required. state: nonMember, lastToSendReport: false, - delayedReportJob: tcpip.NewJob(g.opts.Clock, &g.mu, func() { - info, ok := g.mu.memberships[groupAddress] + delayedReportJob: tcpip.NewJob(g.opts.Clock, g.protocolMU, func() { + info, ok := g.memberships[groupAddress] if !ok { panic(fmt.Sprintf("expected to find group state for group = %s", groupAddress)) } info.lastToSendReport = g.opts.Protocol.SendReport(groupAddress) == nil info.state = idleMember - g.mu.memberships[groupAddress] = info + g.memberships[groupAddress] = info }), } @@ -233,25 +246,24 @@ func (g *GenericMulticastProtocolState) JoinGroup(groupAddress tcpip.Address, do g.initializeNewMemberLocked(groupAddress, &info) } - g.mu.memberships[groupAddress] = info + g.memberships[groupAddress] = info } -// IsLocallyJoined returns true if the group is locally joined. -func (g *GenericMulticastProtocolState) IsLocallyJoined(groupAddress tcpip.Address) bool { - g.mu.RLock() - defer g.mu.RUnlock() - _, ok := g.mu.memberships[groupAddress] +// IsLocallyJoinedRLocked returns true if the group is locally joined. +// +// Precondition: g.protocolMU must be read locked. +func (g *GenericMulticastProtocolState) IsLocallyJoinedRLocked(groupAddress tcpip.Address) bool { + _, ok := g.memberships[groupAddress] return ok } -// LeaveGroup handles leaving the group. +// LeaveGroupLocked handles leaving the group. // // Returns false if the group is not currently joined. -func (g *GenericMulticastProtocolState) LeaveGroup(groupAddress tcpip.Address) bool { - g.mu.Lock() - defer g.mu.Unlock() - - info, ok := g.mu.memberships[groupAddress] +// +// Precondition: g.protocolMU must be locked. +func (g *GenericMulticastProtocolState) LeaveGroupLocked(groupAddress tcpip.Address) bool { + info, ok := g.memberships[groupAddress] if !ok { return false } @@ -262,30 +274,30 @@ func (g *GenericMulticastProtocolState) LeaveGroup(groupAddress tcpip.Address) b info.joins-- if info.joins != 0 { // If we still have outstanding joins, then do nothing further. - g.mu.memberships[groupAddress] = info + g.memberships[groupAddress] = info return true } g.transitionToNonMemberLocked(groupAddress, &info) - delete(g.mu.memberships, groupAddress) + delete(g.memberships, groupAddress) return true } -// HandleQuery handles a query message with the specified maximum response time. +// HandleQueryLocked handles a query message with the specified maximum response +// time. // // If the group address is unspecified, then reports will be scheduled for all // joined groups. // // Report(s) will be scheduled to be sent after a random duration between 0 and // the maximum response time. -func (g *GenericMulticastProtocolState) HandleQuery(groupAddress tcpip.Address, maxResponseTime time.Duration) { +// +// Precondition: g.protocolMU must be locked. +func (g *GenericMulticastProtocolState) HandleQueryLocked(groupAddress tcpip.Address, maxResponseTime time.Duration) { if !g.opts.Enabled { return } - g.mu.Lock() - defer g.mu.Unlock() - // As per RFC 2236 section 2.4 (for IGMPv2), // // In a Membership Query message, the group address field is set to zero @@ -299,28 +311,27 @@ func (g *GenericMulticastProtocolState) HandleQuery(groupAddress tcpip.Address, // when sending a Multicast-Address-Specific Query. if groupAddress.Unspecified() { // This is a general query as the group address is unspecified. - for groupAddress, info := range g.mu.memberships { + for groupAddress, info := range g.memberships { g.setDelayTimerForAddressRLocked(groupAddress, &info, maxResponseTime) - g.mu.memberships[groupAddress] = info + g.memberships[groupAddress] = info } - } else if info, ok := g.mu.memberships[groupAddress]; ok { + } else if info, ok := g.memberships[groupAddress]; ok { g.setDelayTimerForAddressRLocked(groupAddress, &info, maxResponseTime) - g.mu.memberships[groupAddress] = info + g.memberships[groupAddress] = info } } -// HandleReport handles a report message. +// HandleReportLocked handles a report message. // // If the report is for a joined group, any active delayed report will be // cancelled and the host state for the group transitions to idle. -func (g *GenericMulticastProtocolState) HandleReport(groupAddress tcpip.Address) { +// +// Precondition: g.protocolMU must be locked. +func (g *GenericMulticastProtocolState) HandleReportLocked(groupAddress tcpip.Address) { if !g.opts.Enabled { return } - g.mu.Lock() - defer g.mu.Unlock() - // As per RFC 2236 section 3 pages 3-4 (for IGMPv2), // // If the host receives another host's Report (version 1 or 2) while it has @@ -333,17 +344,17 @@ func (g *GenericMulticastProtocolState) HandleReport(groupAddress tcpip.Address) // multicast address while it has a timer running for that same address // on that interface, it stops its timer and does not send a Report for // that address, thus suppressing duplicate reports on the link. - if info, ok := g.mu.memberships[groupAddress]; ok && info.state == delayingMember { + if info, ok := g.memberships[groupAddress]; ok && info.state == delayingMember { info.delayedReportJob.Cancel() info.lastToSendReport = false info.state = idleMember - g.mu.memberships[groupAddress] = info + g.memberships[groupAddress] = info } } // initializeNewMemberLocked initializes a new group membership. // -// Precondition: g.mu must be locked. +// Precondition: g.protocolMU must be locked. func (g *GenericMulticastProtocolState) initializeNewMemberLocked(groupAddress tcpip.Address, info *multicastGroupState) { if info.state != nonMember { panic(fmt.Sprintf("state for group %s is not non-member; state = %d", groupAddress, info.state)) @@ -465,7 +476,7 @@ func (g *GenericMulticastProtocolState) maybeSendLeave(groupAddress tcpip.Addres // transitionToNonMemberLocked transitions the given multicast group the the // non-member/listener state. // -// Precondition: e.mu must be locked. +// Precondition: g.protocolMU must be locked. func (g *GenericMulticastProtocolState) transitionToNonMemberLocked(groupAddress tcpip.Address, info *multicastGroupState) { if info.state == nonMember { return @@ -479,7 +490,7 @@ func (g *GenericMulticastProtocolState) transitionToNonMemberLocked(groupAddress // setDelayTimerForAddressRLocked sets timer to send a delay report. // -// Precondition: g.mu MUST be read locked. +// Precondition: g.protocolMU MUST be read locked. func (g *GenericMulticastProtocolState) setDelayTimerForAddressRLocked(groupAddress tcpip.Address, info *multicastGroupState, maxResponseTime time.Duration) { if info.state == nonMember { return diff --git a/pkg/tcpip/network/ip/generic_multicast_protocol_test.go b/pkg/tcpip/network/ip/generic_multicast_protocol_test.go index 670be30d4..6fd0eb9f7 100644 --- a/pkg/tcpip/network/ip/generic_multicast_protocol_test.go +++ b/pkg/tcpip/network/ip/generic_multicast_protocol_test.go @@ -20,6 +20,7 @@ import ( "time" "github.com/google/go-cmp/cmp" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/faketime" "gvisor.dev/gvisor/pkg/tcpip/network/ip" @@ -37,41 +38,86 @@ const ( var _ ip.MulticastGroupProtocol = (*mockMulticastGroupProtocol)(nil) type mockMulticastGroupProtocol struct { + t *testing.T + + mu sync.RWMutex + + // Must only be accessed with mu held. sendReportGroupAddrCount map[tcpip.Address]int - sendLeaveGroupAddrCount map[tcpip.Address]int + + // Must only be accessed with mu held. + sendLeaveGroupAddrCount map[tcpip.Address]int } func (m *mockMulticastGroupProtocol) init() { + m.mu.Lock() + defer m.mu.Unlock() + m.initLocked() +} + +func (m *mockMulticastGroupProtocol) initLocked() { m.sendReportGroupAddrCount = make(map[tcpip.Address]int) m.sendLeaveGroupAddrCount = make(map[tcpip.Address]int) } func (m *mockMulticastGroupProtocol) SendReport(groupAddress tcpip.Address) *tcpip.Error { + if m.mu.TryLock() { + m.mu.Unlock() + m.t.Fatalf("got write lock, expected to not take the lock; generic multicast protocol must take the write lock before sending report for %s", groupAddress) + } + if m.mu.TryRLock() { + m.mu.RUnlock() + m.t.Fatalf("got read lock, expected to not take the lock; generic multicast protocol must take the write lock before sending report for %s", groupAddress) + } + m.sendReportGroupAddrCount[groupAddress]++ return nil } func (m *mockMulticastGroupProtocol) SendLeave(groupAddress tcpip.Address) *tcpip.Error { + if m.mu.TryLock() { + m.mu.Unlock() + m.t.Fatalf("got write lock, expected to not take the lock; generic multicast protocol must take the write lock before sending leave for %s", groupAddress) + } + if m.mu.TryRLock() { + m.mu.RUnlock() + m.t.Fatalf("got read lock, expected to not take the lock; generic multicast protocol must take the write lock before sending leave for %s", groupAddress) + } + m.sendLeaveGroupAddrCount[groupAddress]++ return nil } -func checkProtocol(mgp *mockMulticastGroupProtocol, sendReportGroupAddresses []tcpip.Address, sendLeaveGroupAddresses []tcpip.Address) string { - sendReportGroupAddressesMap := make(map[tcpip.Address]int) +func (m *mockMulticastGroupProtocol) check(sendReportGroupAddresses []tcpip.Address, sendLeaveGroupAddresses []tcpip.Address) string { + m.mu.Lock() + defer m.mu.Unlock() + + sendReportGroupAddrCount := make(map[tcpip.Address]int) for _, a := range sendReportGroupAddresses { - sendReportGroupAddressesMap[a] = 1 + sendReportGroupAddrCount[a] = 1 } - sendLeaveGroupAddressesMap := make(map[tcpip.Address]int) + sendLeaveGroupAddrCount := make(map[tcpip.Address]int) for _, a := range sendLeaveGroupAddresses { - sendLeaveGroupAddressesMap[a] = 1 + sendLeaveGroupAddrCount[a] = 1 } - diff := cmp.Diff(mockMulticastGroupProtocol{ - sendReportGroupAddrCount: sendReportGroupAddressesMap, - sendLeaveGroupAddrCount: sendLeaveGroupAddressesMap, - }, *mgp, cmp.AllowUnexported(mockMulticastGroupProtocol{})) - mgp.init() + diff := cmp.Diff( + &mockMulticastGroupProtocol{ + sendReportGroupAddrCount: sendReportGroupAddrCount, + sendLeaveGroupAddrCount: sendLeaveGroupAddrCount, + }, + m, + cmp.AllowUnexported(mockMulticastGroupProtocol{}), + // ignore mockMulticastGroupProtocol.mu and mockMulticastGroupProtocol.t + cmp.FilterPath( + func(p cmp.Path) bool { + return p.Last().String() == ".mu" || p.Last().String() == ".t" + }, + cmp.Ignore(), + ), + ) + m.initLocked() return diff } @@ -96,10 +142,11 @@ func TestJoinGroup(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { var g ip.GenericMulticastProtocolState - var mgp mockMulticastGroupProtocol - mgp.init() + mgp := mockMulticastGroupProtocol{t: t} clock := faketime.NewManualClock() - g.Init(ip.GenericMulticastProtocolOptions{ + + mgp.init() + g.Init(&mgp.mu, ip.GenericMulticastProtocolOptions{ Enabled: true, Rand: rand.New(rand.NewSource(0)), Clock: clock, @@ -110,21 +157,24 @@ func TestJoinGroup(t *testing.T) { // Joining a group should send a report immediately and another after // a random interval between 0 and the maximum unsolicited report delay. - g.JoinGroup(test.addr, false /* dontInitialize */) + mgp.mu.Lock() + g.JoinGroupLocked(test.addr, false /* dontInitialize */) + mgp.mu.Unlock() if test.shouldSendReports { - if diff := checkProtocol(&mgp, []tcpip.Address{test.addr} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + if diff := mgp.check([]tcpip.Address{test.addr} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) } + // Generic multicast protocol timers are expected to take the job mutex. clock.Advance(maxUnsolicitedReportDelay) - if diff := checkProtocol(&mgp, []tcpip.Address{test.addr} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + if diff := mgp.check([]tcpip.Address{test.addr} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) } } // Should have no more messages to send. clock.Advance(time.Hour) - if diff := checkProtocol(&mgp, nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) } }) @@ -152,10 +202,11 @@ func TestLeaveGroup(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { var g ip.GenericMulticastProtocolState - var mgp mockMulticastGroupProtocol - mgp.init() + mgp := mockMulticastGroupProtocol{t: t} clock := faketime.NewManualClock() - g.Init(ip.GenericMulticastProtocolOptions{ + + mgp.init() + g.Init(&mgp.mu, ip.GenericMulticastProtocolOptions{ Enabled: true, Rand: rand.New(rand.NewSource(1)), Clock: clock, @@ -164,27 +215,36 @@ func TestLeaveGroup(t *testing.T) { AllNodesAddress: addr2, }) - g.JoinGroup(test.addr, false /* dontInitialize */) + mgp.mu.Lock() + g.JoinGroupLocked(test.addr, false /* dontInitialize */) + mgp.mu.Unlock() if test.shouldSendMessages { - if diff := checkProtocol(&mgp, []tcpip.Address{test.addr} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + if diff := mgp.check([]tcpip.Address{test.addr} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) } } // Leaving a group should send a leave report immediately and cancel any // delayed reports. - if !g.LeaveGroup(test.addr) { - t.Fatalf("got g.LeaveGroup(%s) = false, want = true", test.addr) + { + mgp.mu.Lock() + res := g.LeaveGroupLocked(test.addr) + mgp.mu.Unlock() + if !res { + t.Fatalf("got g.LeaveGroupLocked(%s) = false, want = true", test.addr) + } } if test.shouldSendMessages { - if diff := checkProtocol(&mgp, nil /* sendReportGroupAddresses */, []tcpip.Address{test.addr} /* sendLeaveGroupAddresses */); diff != "" { + if diff := mgp.check(nil /* sendReportGroupAddresses */, []tcpip.Address{test.addr} /* sendLeaveGroupAddresses */); diff != "" { t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) } } // Should have no more messages to send. + // + // Generic multicast protocol timers are expected to take the job mutex. clock.Advance(time.Hour) - if diff := checkProtocol(&mgp, nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) } }) @@ -227,10 +287,11 @@ func TestHandleReport(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { var g ip.GenericMulticastProtocolState - var mgp mockMulticastGroupProtocol - mgp.init() + mgp := mockMulticastGroupProtocol{t: t} clock := faketime.NewManualClock() - g.Init(ip.GenericMulticastProtocolOptions{ + + mgp.init() + g.Init(&mgp.mu, ip.GenericMulticastProtocolOptions{ Enabled: true, Rand: rand.New(rand.NewSource(2)), Clock: clock, @@ -239,32 +300,41 @@ func TestHandleReport(t *testing.T) { AllNodesAddress: addr3, }) - g.JoinGroup(addr1, false /* dontInitialize */) - if diff := checkProtocol(&mgp, []tcpip.Address{addr1} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + mgp.mu.Lock() + g.JoinGroupLocked(addr1, false /* dontInitialize */) + mgp.mu.Unlock() + if diff := mgp.check([]tcpip.Address{addr1} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) } - g.JoinGroup(addr2, false /* dontInitialize */) - if diff := checkProtocol(&mgp, []tcpip.Address{addr2} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + mgp.mu.Lock() + g.JoinGroupLocked(addr2, false /* dontInitialize */) + mgp.mu.Unlock() + if diff := mgp.check([]tcpip.Address{addr2} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) } - g.JoinGroup(addr3, false /* dontInitialize */) - if diff := checkProtocol(&mgp, nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + mgp.mu.Lock() + g.JoinGroupLocked(addr3, false /* dontInitialize */) + mgp.mu.Unlock() + if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) } // Receiving a report for a group we have a timer scheduled for should // cancel our delayed report timer for the group. - g.HandleReport(test.reportAddr) + mgp.mu.Lock() + g.HandleReportLocked(test.reportAddr) + mgp.mu.Unlock() if len(test.expectReportsFor) != 0 { + // Generic multicast protocol timers are expected to take the job mutex. clock.Advance(maxUnsolicitedReportDelay) - if diff := checkProtocol(&mgp, test.expectReportsFor /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + if diff := mgp.check(test.expectReportsFor /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) } } // Should have no more messages to send. clock.Advance(time.Hour) - if diff := checkProtocol(&mgp, nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) } }) @@ -313,10 +383,11 @@ func TestHandleQuery(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { var g ip.GenericMulticastProtocolState - var mgp mockMulticastGroupProtocol - mgp.init() + mgp := mockMulticastGroupProtocol{t: t} clock := faketime.NewManualClock() - g.Init(ip.GenericMulticastProtocolOptions{ + + mgp.init() + g.Init(&mgp.mu, ip.GenericMulticastProtocolOptions{ Enabled: true, Rand: rand.New(rand.NewSource(3)), Clock: clock, @@ -325,36 +396,45 @@ func TestHandleQuery(t *testing.T) { AllNodesAddress: addr3, }) - g.JoinGroup(addr1, false /* dontInitialize */) - if diff := checkProtocol(&mgp, []tcpip.Address{addr1} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + mgp.mu.Lock() + g.JoinGroupLocked(addr1, false /* dontInitialize */) + mgp.mu.Unlock() + if diff := mgp.check([]tcpip.Address{addr1} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) } - g.JoinGroup(addr2, false /* dontInitialize */) - if diff := checkProtocol(&mgp, []tcpip.Address{addr2} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + mgp.mu.Lock() + g.JoinGroupLocked(addr2, false /* dontInitialize */) + mgp.mu.Unlock() + if diff := mgp.check([]tcpip.Address{addr2} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) } - g.JoinGroup(addr3, false /* dontInitialize */) - if diff := checkProtocol(&mgp, nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + mgp.mu.Lock() + g.JoinGroupLocked(addr3, false /* dontInitialize */) + mgp.mu.Unlock() + if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) } + // Generic multicast protocol timers are expected to take the job mutex. clock.Advance(maxUnsolicitedReportDelay) - if diff := checkProtocol(&mgp, []tcpip.Address{addr1, addr2} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + if diff := mgp.check([]tcpip.Address{addr1, addr2} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) } // Receiving a query should make us schedule a new delayed report if it // is a query directed at us or a general query. - g.HandleQuery(test.queryAddr, test.maxDelay) + mgp.mu.Lock() + g.HandleQueryLocked(test.queryAddr, test.maxDelay) + mgp.mu.Unlock() if len(test.expectReportsFor) != 0 { clock.Advance(test.maxDelay) - if diff := checkProtocol(&mgp, test.expectReportsFor /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + if diff := mgp.check(test.expectReportsFor /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) } } // Should have no more messages to send. clock.Advance(time.Hour) - if diff := checkProtocol(&mgp, nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) } }) @@ -363,10 +443,11 @@ func TestHandleQuery(t *testing.T) { func TestJoinCount(t *testing.T) { var g ip.GenericMulticastProtocolState - var mgp mockMulticastGroupProtocol - mgp.init() + mgp := mockMulticastGroupProtocol{t: t} clock := faketime.NewManualClock() - g.Init(ip.GenericMulticastProtocolOptions{ + + mgp.init() + g.Init(&mgp.mu, ip.GenericMulticastProtocolOptions{ Enabled: true, Rand: rand.New(rand.NewSource(4)), Clock: clock, @@ -375,70 +456,110 @@ func TestJoinCount(t *testing.T) { }) // Set the join count to 2 for a group. - g.JoinGroup(addr1, false /* dontInitialize */) - if !g.IsLocallyJoined(addr1) { - t.Fatalf("got g.IsLocallyJoined(%s) = false, want = true", addr1) + { + mgp.mu.Lock() + g.JoinGroupLocked(addr1, false /* dontInitialize */) + res := g.IsLocallyJoinedRLocked(addr1) + mgp.mu.Unlock() + if !res { + t.Fatalf("got g.IsLocallyJoinedRLocked(%s) = false, want = true", addr1) + } } // Only the first join should trigger a report to be sent. - if diff := checkProtocol(&mgp, []tcpip.Address{addr1} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + if diff := mgp.check([]tcpip.Address{addr1} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) } - g.JoinGroup(addr1, false /* dontInitialize */) - if !g.IsLocallyJoined(addr1) { - t.Fatalf("got g.IsLocallyJoined(%s) = false, want = true", addr1) + { + mgp.mu.Lock() + g.JoinGroupLocked(addr1, false /* dontInitialize */) + res := g.IsLocallyJoinedRLocked(addr1) + mgp.mu.Unlock() + if !res { + t.Errorf("got g.IsLocallyJoinedRLocked(%s) = false, want = true", addr1) + } + } + if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) } - if diff := checkProtocol(&mgp, nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { - t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + if t.Failed() { + t.FailNow() } // Group should still be considered joined after leaving once. - if !g.LeaveGroup(addr1) { - t.Fatalf("got g.LeaveGroup(%s) = false, want = true", addr1) - } - if !g.IsLocallyJoined(addr1) { - t.Fatalf("got g.IsLocallyJoined(%s) = false, want = true", addr1) + { + mgp.mu.Lock() + leaveGroupRes := g.LeaveGroupLocked(addr1) + isLocallyJoined := g.IsLocallyJoinedRLocked(addr1) + mgp.mu.Unlock() + if !leaveGroupRes { + t.Errorf("got g.LeaveGroupLocked(%s) = false, want = true", addr1) + } + if !isLocallyJoined { + t.Errorf("got g.IsLocallyJoinedRLocked(%s) = false, want = true", addr1) + } } // A leave report should only be sent once the join count reaches 0. - if diff := checkProtocol(&mgp, nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { - t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } + if t.Failed() { + t.FailNow() } // Leaving once more should actually remove us from the group. - if !g.LeaveGroup(addr1) { - t.Fatalf("got g.LeaveGroup(%s) = false, want = true", addr1) + { + mgp.mu.Lock() + leaveGroupRes := g.LeaveGroupLocked(addr1) + isLocallyJoined := g.IsLocallyJoinedRLocked(addr1) + mgp.mu.Unlock() + if !leaveGroupRes { + t.Errorf("got g.LeaveGroupLocked(%s) = false, want = true", addr1) + } + if isLocallyJoined { + t.Errorf("got g.IsLocallyJoinedRLocked(%s) = true, want = false", addr1) + } } - if g.IsLocallyJoined(addr1) { - t.Fatalf("got g.IsLocallyJoined(%s) = true, want = false", addr1) + if diff := mgp.check(nil /* sendReportGroupAddresses */, []tcpip.Address{addr1} /* sendLeaveGroupAddresses */); diff != "" { + t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) } - if diff := checkProtocol(&mgp, nil /* sendReportGroupAddresses */, []tcpip.Address{addr1} /* sendLeaveGroupAddresses */); diff != "" { - t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + if t.Failed() { + t.FailNow() } // Group should no longer be joined so we should not have anything to // leave. - if g.LeaveGroup(addr1) { - t.Fatalf("got g.LeaveGroup(%s) = true, want = false", addr1) - } - if g.IsLocallyJoined(addr1) { - t.Fatalf("got g.IsLocallyJoined(%s) = true, want = false", addr1) + { + mgp.mu.Lock() + leaveGroupRes := g.LeaveGroupLocked(addr1) + isLocallyJoined := g.IsLocallyJoinedRLocked(addr1) + mgp.mu.Unlock() + if leaveGroupRes { + t.Errorf("got g.LeaveGroupLocked(%s) = true, want = false", addr1) + } + if isLocallyJoined { + t.Errorf("got g.IsLocallyJoinedRLocked(%s) = true, want = false", addr1) + } } - if diff := checkProtocol(&mgp, nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { - t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) } // Should have no more messages to send. + // + // Generic multicast protocol timers are expected to take the job mutex. clock.Advance(time.Hour) - if diff := checkProtocol(&mgp, nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) } } func TestMakeAllNonMemberAndInitialize(t *testing.T) { var g ip.GenericMulticastProtocolState - var mgp mockMulticastGroupProtocol - mgp.init() + mgp := mockMulticastGroupProtocol{t: t} clock := faketime.NewManualClock() - g.Init(ip.GenericMulticastProtocolOptions{ + + mgp.init() + g.Init(&mgp.mu, ip.GenericMulticastProtocolOptions{ Enabled: true, Rand: rand.New(rand.NewSource(3)), Clock: clock, @@ -447,48 +568,62 @@ func TestMakeAllNonMemberAndInitialize(t *testing.T) { AllNodesAddress: addr3, }) - g.JoinGroup(addr1, false /* dontInitialize */) - if diff := checkProtocol(&mgp, []tcpip.Address{addr1} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + mgp.mu.Lock() + g.JoinGroupLocked(addr1, false /* dontInitialize */) + mgp.mu.Unlock() + if diff := mgp.check([]tcpip.Address{addr1} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) } - g.JoinGroup(addr2, false /* dontInitialize */) - if diff := checkProtocol(&mgp, []tcpip.Address{addr2} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + mgp.mu.Lock() + g.JoinGroupLocked(addr2, false /* dontInitialize */) + mgp.mu.Unlock() + if diff := mgp.check([]tcpip.Address{addr2} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) } - g.JoinGroup(addr3, false /* dontInitialize */) - if diff := checkProtocol(&mgp, nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + mgp.mu.Lock() + g.JoinGroupLocked(addr3, false /* dontInitialize */) + mgp.mu.Unlock() + if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) } // Should send the leave reports for each but still consider them locally // joined. - g.MakeAllNonMember() - if diff := checkProtocol(&mgp, nil /* sendReportGroupAddresses */, []tcpip.Address{addr1, addr2} /* sendLeaveGroupAddresses */); diff != "" { + mgp.mu.Lock() + g.MakeAllNonMemberLocked() + mgp.mu.Unlock() + if diff := mgp.check(nil /* sendReportGroupAddresses */, []tcpip.Address{addr1, addr2} /* sendLeaveGroupAddresses */); diff != "" { t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) } + // Generic multicast protocol timers are expected to take the job mutex. clock.Advance(time.Hour) - if diff := checkProtocol(&mgp, nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) } for _, group := range []tcpip.Address{addr1, addr2, addr3} { - if !g.IsLocallyJoined(group) { - t.Fatalf("got g.IsLocallyJoined(%s) = false, want = true", group) + mgp.mu.RLock() + res := g.IsLocallyJoinedRLocked(group) + mgp.mu.RUnlock() + if !res { + t.Fatalf("got g.IsLocallyJoinedRLocked(%s) = false, want = true", group) } } // Should send the initial set of unsolcited reports. - g.InitializeGroups() - if diff := checkProtocol(&mgp, []tcpip.Address{addr1, addr2} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + mgp.mu.Lock() + g.InitializeGroupsLocked() + mgp.mu.Unlock() + if diff := mgp.check([]tcpip.Address{addr1, addr2} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) } clock.Advance(maxUnsolicitedReportDelay) - if diff := checkProtocol(&mgp, []tcpip.Address{addr1, addr2} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + if diff := mgp.check([]tcpip.Address{addr1, addr2} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) } // Should have no more messages to send. clock.Advance(time.Hour) - if diff := checkProtocol(&mgp, nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) } } @@ -521,10 +656,11 @@ func TestGroupStateNonMember(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { var g ip.GenericMulticastProtocolState - var mgp mockMulticastGroupProtocol - mgp.init() + mgp := mockMulticastGroupProtocol{t: t} clock := faketime.NewManualClock() - g.Init(ip.GenericMulticastProtocolOptions{ + + mgp.init() + g.Init(&mgp.mu, ip.GenericMulticastProtocolOptions{ Enabled: test.enabled, Rand: rand.New(rand.NewSource(3)), Clock: clock, @@ -532,43 +668,65 @@ func TestGroupStateNonMember(t *testing.T) { MaxUnsolicitedReportDelay: maxUnsolicitedReportDelay, }) - g.JoinGroup(addr1, test.dontInitialize) - if !g.IsLocallyJoined(addr1) { - t.Fatalf("got g.IsLocallyJoined(%s) = false, want = true", addr1) + // Joining groups should not send any reports. + { + mgp.mu.Lock() + g.JoinGroupLocked(addr1, test.dontInitialize) + res := g.IsLocallyJoinedRLocked(addr1) + mgp.mu.Unlock() + if !res { + t.Fatalf("got g.IsLocallyJoinedRLocked(%s) = false, want = true", addr1) + } } - if diff := checkProtocol(&mgp, nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) } - - g.JoinGroup(addr2, test.dontInitialize) - if !g.IsLocallyJoined(addr2) { - t.Fatalf("got g.IsLocallyJoined(%s) = false, want = true", addr2) + { + mgp.mu.Lock() + g.JoinGroupLocked(addr2, test.dontInitialize) + res := g.IsLocallyJoinedRLocked(addr2) + mgp.mu.Unlock() + if !res { + t.Fatalf("got g.IsLocallyJoinedRLocked(%s) = false, want = true", addr2) + } } - if diff := checkProtocol(&mgp, nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) } - g.HandleQuery(addr1, time.Nanosecond) + // Receiving a query should not send any reports. + mgp.mu.Lock() + g.HandleQueryLocked(addr1, time.Nanosecond) + mgp.mu.Unlock() + // Generic multicast protocol timers are expected to take the job mutex. clock.Advance(time.Nanosecond) - if diff := checkProtocol(&mgp, nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) } - if !g.LeaveGroup(addr2) { - t.Errorf("got g.LeaveGroup(%s) = false, want = true", addr2) - } - if !g.IsLocallyJoined(addr1) { - t.Fatalf("got g.IsLocallyJoined(%s) = false, want = true", addr1) - } - if g.IsLocallyJoined(addr2) { - t.Fatalf("got g.IsLocallyJoined(%s) = true, want = false", addr2) + // Leaving groups should not send any leave messages. + { + mgp.mu.Lock() + addr2LeaveRes := g.LeaveGroupLocked(addr2) + addr1IsJoined := g.IsLocallyJoinedRLocked(addr1) + addr2IsJoined := g.IsLocallyJoinedRLocked(addr2) + mgp.mu.Unlock() + if !addr2LeaveRes { + t.Errorf("got g.LeaveGroupLocked(%s) = false, want = true", addr2) + } + if !addr1IsJoined { + t.Errorf("got g.IsLocallyJoinedRLocked(%s) = false, want = true", addr1) + } + if addr2IsJoined { + t.Errorf("got g.IsLocallyJoinedRLocked(%s) = true, want = false", addr2) + } } - if diff := checkProtocol(&mgp, nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) } clock.Advance(time.Hour) - if diff := checkProtocol(&mgp, nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) } }) diff --git a/pkg/tcpip/network/ipv4/igmp.go b/pkg/tcpip/network/ipv4/igmp.go index 0134fadc0..2ee4c6445 100644 --- a/pkg/tcpip/network/ipv4/igmp.go +++ b/pkg/tcpip/network/ipv4/igmp.go @@ -16,7 +16,6 @@ package ipv4 import ( "fmt" - "sync" "sync/atomic" "time" @@ -68,8 +67,9 @@ var _ ip.MulticastGroupProtocol = (*igmpState)(nil) // igmpState.init() MUST be called after creating an IGMP state. type igmpState struct { // The IPv4 endpoint this igmpState is for. - ep *endpoint - opts IGMPOptions + ep *endpoint + + genericMulticastProtocol ip.GenericMulticastProtocolState // igmpV1Present is for maintaining compatibility with IGMPv1 Routers, from // RFC 2236 Section 4 Page 6: "The IGMPv1 router expects Version 1 @@ -84,16 +84,10 @@ type igmpState struct { // when false. igmpV1Present uint32 - mu struct { - sync.RWMutex - - genericMulticastProtocol ip.GenericMulticastProtocolState - - // igmpV1Job is scheduled when this interface receives an IGMPv1 style - // message, upon expiration the igmpV1Present flag is cleared. - // igmpV1Job may not be nil once igmpState is initialized. - igmpV1Job *tcpip.Job - } + // igmpV1Job is scheduled when this interface receives an IGMPv1 style + // message, upon expiration the igmpV1Present flag is cleared. + // igmpV1Job may not be nil once igmpState is initialized. + igmpV1Job *tcpip.Job } // SendReport implements ip.MulticastGroupProtocol. @@ -119,13 +113,12 @@ func (igmp *igmpState) SendLeave(groupAddress tcpip.Address) *tcpip.Error { // init sets up an igmpState struct, and is required to be called before using // a new igmpState. -func (igmp *igmpState) init(ep *endpoint, opts IGMPOptions) { - igmp.mu.Lock() - defer igmp.mu.Unlock() +// +// Must only be called once for the lifetime of igmp. +func (igmp *igmpState) init(ep *endpoint) { igmp.ep = ep - igmp.opts = opts - igmp.mu.genericMulticastProtocol.Init(ip.GenericMulticastProtocolOptions{ - Enabled: opts.Enabled, + igmp.genericMulticastProtocol.Init(&ep.mu.RWMutex, ip.GenericMulticastProtocolOptions{ + Enabled: ep.protocol.options.IGMP.Enabled, Rand: ep.protocol.stack.Rand(), Clock: ep.protocol.stack.Clock(), Protocol: igmp, @@ -133,11 +126,14 @@ func (igmp *igmpState) init(ep *endpoint, opts IGMPOptions) { AllNodesAddress: header.IPv4AllSystems, }) igmp.igmpV1Present = igmpV1PresentDefault - igmp.mu.igmpV1Job = igmp.ep.protocol.stack.NewJob(&igmp.mu, func() { + igmp.igmpV1Job = ep.protocol.stack.NewJob(&ep.mu, func() { igmp.setV1Present(false) }) } +// handleIGMP handles an IGMP packet. +// +// Precondition: igmp.ep.mu must be locked. func (igmp *igmpState) handleIGMP(pkt *stack.PacketBuffer) { stats := igmp.ep.protocol.stack.Stats() received := stats.IGMP.PacketsReceived @@ -207,27 +203,28 @@ func (igmp *igmpState) setV1Present(v bool) { } } +// handleMembershipQuery handles a membership query. +// +// Precondition: igmp.ep.mu must be locked. func (igmp *igmpState) handleMembershipQuery(groupAddress tcpip.Address, maxRespTime time.Duration) { - igmp.mu.Lock() - defer igmp.mu.Unlock() - // As per RFC 2236 Section 6, Page 10: If the maximum response time is zero // then change the state to note that an IGMPv1 router is present and // schedule the query received Job. - if maxRespTime == 0 && igmp.opts.Enabled { - igmp.mu.igmpV1Job.Cancel() - igmp.mu.igmpV1Job.Schedule(v1RouterPresentTimeout) + if maxRespTime == 0 && igmp.ep.protocol.options.IGMP.Enabled { + igmp.igmpV1Job.Cancel() + igmp.igmpV1Job.Schedule(v1RouterPresentTimeout) igmp.setV1Present(true) maxRespTime = v1MaxRespTime } - igmp.mu.genericMulticastProtocol.HandleQuery(groupAddress, maxRespTime) + igmp.genericMulticastProtocol.HandleQueryLocked(groupAddress, maxRespTime) } +// handleMembershipReport handles a membership report. +// +// Precondition: igmp.ep.mu must be locked. func (igmp *igmpState) handleMembershipReport(groupAddress tcpip.Address) { - igmp.mu.Lock() - defer igmp.mu.Unlock() - igmp.mu.genericMulticastProtocol.HandleReport(groupAddress) + igmp.genericMulticastProtocol.HandleReportLocked(groupAddress) } // writePacket assembles and sends an IGMP packet with the provided fields, @@ -278,28 +275,27 @@ func (igmp *igmpState) writePacket(destAddress tcpip.Address, groupAddress tcpip // // If the group already exists in the membership map, returns // tcpip.ErrDuplicateAddress. +// +// Precondition: igmp.ep.mu must be locked. func (igmp *igmpState) joinGroup(groupAddress tcpip.Address) { - igmp.mu.Lock() - defer igmp.mu.Unlock() - igmp.mu.genericMulticastProtocol.JoinGroup(groupAddress, !igmp.ep.Enabled() /* dontInitialize */) + igmp.genericMulticastProtocol.JoinGroupLocked(groupAddress, !igmp.ep.Enabled() /* dontInitialize */) } // isInGroup returns true if the specified group has been joined locally. +// +// Precondition: igmp.ep.mu must be read locked. func (igmp *igmpState) isInGroup(groupAddress tcpip.Address) bool { - igmp.mu.Lock() - defer igmp.mu.Unlock() - return igmp.mu.genericMulticastProtocol.IsLocallyJoined(groupAddress) + return igmp.genericMulticastProtocol.IsLocallyJoinedRLocked(groupAddress) } // leaveGroup handles removing the group from the membership map, cancels any // delay timers associated with that group, and sends the Leave Group message // if required. +// +// Precondition: igmp.ep.mu must be locked. func (igmp *igmpState) leaveGroup(groupAddress tcpip.Address) *tcpip.Error { - igmp.mu.Lock() - defer igmp.mu.Unlock() - // LeaveGroup returns false only if the group was not joined. - if igmp.mu.genericMulticastProtocol.LeaveGroup(groupAddress) { + if igmp.genericMulticastProtocol.LeaveGroupLocked(groupAddress) { return nil } @@ -308,16 +304,16 @@ func (igmp *igmpState) leaveGroup(groupAddress tcpip.Address) *tcpip.Error { // softLeaveAll leaves all groups from the perspective of IGMP, but remains // joined locally. +// +// Precondition: igmp.ep.mu must be locked. func (igmp *igmpState) softLeaveAll() { - igmp.mu.Lock() - defer igmp.mu.Unlock() - igmp.mu.genericMulticastProtocol.MakeAllNonMember() + igmp.genericMulticastProtocol.MakeAllNonMemberLocked() } // initializeAll attemps to initialize the IGMP state for each group that has // been joined locally. +// +// Precondition: igmp.ep.mu must be locked. func (igmp *igmpState) initializeAll() { - igmp.mu.Lock() - defer igmp.mu.Unlock() - igmp.mu.genericMulticastProtocol.InitializeGroups() + igmp.genericMulticastProtocol.InitializeGroupsLocked() } diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go index 3076185cd..c63ecca4a 100644 --- a/pkg/tcpip/network/ipv4/ipv4.go +++ b/pkg/tcpip/network/ipv4/ipv4.go @@ -72,7 +72,6 @@ type endpoint struct { nic stack.NetworkInterface dispatcher stack.TransportDispatcher protocol *protocol - igmp igmpState // enabled is set to 1 when the enpoint is enabled and 0 when it is // disabled. @@ -84,6 +83,7 @@ type endpoint struct { sync.RWMutex addressableEndpointState stack.AddressableEndpointState + igmp igmpState } } @@ -94,8 +94,10 @@ func (p *protocol) NewEndpoint(nic stack.NetworkInterface, _ stack.LinkAddressCa dispatcher: dispatcher, protocol: p, } + e.mu.Lock() e.mu.addressableEndpointState.Init(e) - e.igmp.init(e, p.options.IGMP) + e.mu.igmp.init(e) + e.mu.Unlock() return e } @@ -127,7 +129,7 @@ func (e *endpoint) Enable() *tcpip.Error { // endpoint may have left groups from the perspective of IGMP when the // endpoint was disabled. Either way, we need to let routers know to // send us multicast traffic. - e.igmp.initializeAll() + e.mu.igmp.initializeAll() // As per RFC 1122 section 3.3.7, all hosts should join the all-hosts // multicast group. Note, the IANA calls the all-hosts multicast group the @@ -181,7 +183,7 @@ func (e *endpoint) disableLocked() { // Leave groups from the perspective of IGMP so that routers know that // we are no longer interested in the group. - e.igmp.softLeaveAll() + e.mu.igmp.softLeaveAll() // The address may have already been removed. if err := e.mu.addressableEndpointState.RemovePermanentAddress(ipv4BroadcastAddr.Address); err != nil && err != tcpip.ErrBadLocalAddress { @@ -718,7 +720,9 @@ func (e *endpoint) handlePacket(pkt *stack.PacketBuffer) { return } if p == header.IGMPProtocolNumber { - e.igmp.handleIGMP(pkt) + e.mu.Lock() + e.mu.igmp.handleIGMP(pkt) + e.mu.Unlock() return } if opts := h.Options(); len(opts) != 0 { @@ -843,7 +847,7 @@ func (e *endpoint) joinGroupLocked(addr tcpip.Address) *tcpip.Error { return tcpip.ErrBadAddress } - e.igmp.joinGroup(addr) + e.mu.igmp.joinGroup(addr) return nil } @@ -858,14 +862,14 @@ func (e *endpoint) LeaveGroup(addr tcpip.Address) *tcpip.Error { // // Precondition: e.mu must be locked. func (e *endpoint) leaveGroupLocked(addr tcpip.Address) *tcpip.Error { - return e.igmp.leaveGroup(addr) + return e.mu.igmp.leaveGroup(addr) } // IsInGroup implements stack.GroupAddressableEndpoint. func (e *endpoint) IsInGroup(addr tcpip.Address) bool { e.mu.RLock() defer e.mu.RUnlock() - return e.igmp.isInGroup(addr) + return e.mu.igmp.isInGroup(addr) } var _ stack.ForwardingNetworkProtocol = (*protocol)(nil) diff --git a/pkg/tcpip/network/ipv6/icmp.go b/pkg/tcpip/network/ipv6/icmp.go index 510276b8e..6ee162713 100644 --- a/pkg/tcpip/network/ipv6/icmp.go +++ b/pkg/tcpip/network/ipv6/icmp.go @@ -645,26 +645,34 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool) { } case header.ICMPv6MulticastListenerQuery, header.ICMPv6MulticastListenerReport, header.ICMPv6MulticastListenerDone: - var handler func(header.MLD) switch icmpType { case header.ICMPv6MulticastListenerQuery: received.MulticastListenerQuery.Increment() - handler = e.mld.handleMulticastListenerQuery case header.ICMPv6MulticastListenerReport: received.MulticastListenerReport.Increment() - handler = e.mld.handleMulticastListenerReport case header.ICMPv6MulticastListenerDone: received.MulticastListenerDone.Increment() default: panic(fmt.Sprintf("unrecognized MLD message = %d", icmpType)) } + if pkt.Data.Size()-header.ICMPv6HeaderSize < header.MLDMinimumSize { received.Invalid.Increment() return } - if handler != nil { - handler(header.MLD(payload.ToView())) + switch icmpType { + case header.ICMPv6MulticastListenerQuery: + e.mu.Lock() + e.mu.mld.handleMulticastListenerQuery(header.MLD(payload.ToView())) + e.mu.Unlock() + case header.ICMPv6MulticastListenerReport: + e.mu.Lock() + e.mu.mld.handleMulticastListenerReport(header.MLD(payload.ToView())) + e.mu.Unlock() + case header.ICMPv6MulticastListenerDone: + default: + panic(fmt.Sprintf("unrecognized MLD message = %d", icmpType)) } default: diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go index 8bf84601f..7288e309c 100644 --- a/pkg/tcpip/network/ipv6/ipv6.go +++ b/pkg/tcpip/network/ipv6/ipv6.go @@ -85,9 +85,8 @@ type endpoint struct { addressableEndpointState stack.AddressableEndpointState ndp ndpState + mld mldState } - - mld mldState } // NICNameFromID is a function that returns a stable name for the specified NIC, @@ -232,7 +231,7 @@ func (e *endpoint) Enable() *tcpip.Error { // endpoint may have left groups from the perspective of MLD when the // endpoint was disabled. Either way, we need to let routers know to // send us multicast traffic. - e.mld.initializeAll() + e.mu.mld.initializeAll() // Join the IPv6 All-Nodes Multicast group if the stack is configured to // use IPv6. This is required to ensure that this node properly receives @@ -349,7 +348,7 @@ func (e *endpoint) disableLocked() { // Leave groups from the perspective of MLD so that routers know that // we are no longer interested in the group. - e.mld.softLeaveAll() + e.mu.mld.softLeaveAll() } // stopDADForPermanentAddressesLocked stops DAD for all permaneent addresses. @@ -1417,7 +1416,7 @@ func (e *endpoint) joinGroupLocked(addr tcpip.Address) *tcpip.Error { return tcpip.ErrBadAddress } - e.mld.joinGroup(addr) + e.mu.mld.joinGroup(addr) return nil } @@ -1432,14 +1431,14 @@ func (e *endpoint) LeaveGroup(addr tcpip.Address) *tcpip.Error { // // Precondition: e.mu must be locked. func (e *endpoint) leaveGroupLocked(addr tcpip.Address) *tcpip.Error { - return e.mld.leaveGroup(addr) + return e.mu.mld.leaveGroup(addr) } // IsInGroup implements stack.GroupAddressableEndpoint. func (e *endpoint) IsInGroup(addr tcpip.Address) bool { e.mu.RLock() defer e.mu.RUnlock() - return e.mld.isInGroup(addr) + return e.mu.mld.isInGroup(addr) } var _ stack.ForwardingNetworkProtocol = (*protocol)(nil) @@ -1504,17 +1503,11 @@ func (p *protocol) NewEndpoint(nic stack.NetworkInterface, linkAddrCache stack.L dispatcher: dispatcher, protocol: p, } + e.mu.Lock() e.mu.addressableEndpointState.Init(e) - e.mu.ndp = ndpState{ - ep: e, - configs: p.options.NDPConfigs, - dad: make(map[tcpip.Address]dadState), - defaultRouters: make(map[tcpip.Address]defaultRouterState), - onLinkPrefixes: make(map[tcpip.Subnet]onLinkPrefixState), - slaacPrefixes: make(map[tcpip.Subnet]slaacPrefixState), - } - e.mu.ndp.initializeTempAddrState() - e.mld.init(e, p.options.MLD) + e.mu.ndp.init(e) + e.mu.mld.init(e) + e.mu.Unlock() p.mu.Lock() defer p.mu.Unlock() diff --git a/pkg/tcpip/network/ipv6/mld.go b/pkg/tcpip/network/ipv6/mld.go index 4c06b3f0c..6face17c6 100644 --- a/pkg/tcpip/network/ipv6/mld.go +++ b/pkg/tcpip/network/ipv6/mld.go @@ -67,10 +67,12 @@ func (mld *mldState) SendLeave(groupAddress tcpip.Address) *tcpip.Error { // init sets up an mldState struct, and is required to be called before using // a new mldState. -func (mld *mldState) init(ep *endpoint, opts MLDOptions) { +// +// Must only be called once for the lifetime of mld. +func (mld *mldState) init(ep *endpoint) { mld.ep = ep - mld.genericMulticastProtocol.Init(ip.GenericMulticastProtocolOptions{ - Enabled: opts.Enabled, + mld.genericMulticastProtocol.Init(&ep.mu.RWMutex, ip.GenericMulticastProtocolOptions{ + Enabled: ep.protocol.options.MLD.Enabled, Rand: ep.protocol.stack.Rand(), Clock: ep.protocol.stack.Clock(), Protocol: mld, @@ -79,33 +81,45 @@ func (mld *mldState) init(ep *endpoint, opts MLDOptions) { }) } +// handleMulticastListenerQuery handles a query message. +// +// Precondition: mld.ep.mu must be locked. func (mld *mldState) handleMulticastListenerQuery(mldHdr header.MLD) { - mld.genericMulticastProtocol.HandleQuery(mldHdr.MulticastAddress(), mldHdr.MaximumResponseDelay()) + mld.genericMulticastProtocol.HandleQueryLocked(mldHdr.MulticastAddress(), mldHdr.MaximumResponseDelay()) } +// handleMulticastListenerReport handles a report message. +// +// Precondition: mld.ep.mu must be locked. func (mld *mldState) handleMulticastListenerReport(mldHdr header.MLD) { - mld.genericMulticastProtocol.HandleReport(mldHdr.MulticastAddress()) + mld.genericMulticastProtocol.HandleReportLocked(mldHdr.MulticastAddress()) } // joinGroup handles joining a new group and sending and scheduling the required // messages. // // If the group is already joined, returns tcpip.ErrDuplicateAddress. +// +// Precondition: mld.ep.mu must be locked. func (mld *mldState) joinGroup(groupAddress tcpip.Address) { - mld.genericMulticastProtocol.JoinGroup(groupAddress, !mld.ep.Enabled() /* dontInitialize */) + mld.genericMulticastProtocol.JoinGroupLocked(groupAddress, !mld.ep.Enabled() /* dontInitialize */) } // isInGroup returns true if the specified group has been joined locally. +// +// Precondition: mld.ep.mu must be read locked. func (mld *mldState) isInGroup(groupAddress tcpip.Address) bool { - return mld.genericMulticastProtocol.IsLocallyJoined(groupAddress) + return mld.genericMulticastProtocol.IsLocallyJoinedRLocked(groupAddress) } // leaveGroup handles removing the group from the membership map, cancels any // delay timers associated with that group, and sends the Done message, if // required. +// +// Precondition: mld.ep.mu must be locked. func (mld *mldState) leaveGroup(groupAddress tcpip.Address) *tcpip.Error { // LeaveGroup returns false only if the group was not joined. - if mld.genericMulticastProtocol.LeaveGroup(groupAddress) { + if mld.genericMulticastProtocol.LeaveGroupLocked(groupAddress) { return nil } @@ -114,14 +128,18 @@ func (mld *mldState) leaveGroup(groupAddress tcpip.Address) *tcpip.Error { // softLeaveAll leaves all groups from the perspective of MLD, but remains // joined locally. +// +// Precondition: mld.ep.mu must be locked. func (mld *mldState) softLeaveAll() { - mld.genericMulticastProtocol.MakeAllNonMember() + mld.genericMulticastProtocol.MakeAllNonMemberLocked() } // initializeAll attemps to initialize the MLD state for each group that has // been joined locally. +// +// Precondition: mld.ep.mu must be locked. func (mld *mldState) initializeAll() { - mld.genericMulticastProtocol.InitializeGroups() + mld.genericMulticastProtocol.InitializeGroupsLocked() } func (mld *mldState) writePacket(destAddress, groupAddress tcpip.Address, mldType header.ICMPv6Type) *tcpip.Error { diff --git a/pkg/tcpip/network/ipv6/ndp.go b/pkg/tcpip/network/ipv6/ndp.go index 8cb7d4dab..2f5e2e82c 100644 --- a/pkg/tcpip/network/ipv6/ndp.go +++ b/pkg/tcpip/network/ipv6/ndp.go @@ -20,6 +20,7 @@ import ( "math/rand" "time" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" @@ -459,6 +460,9 @@ func (c *NDPConfigurations) validate() { // ndpState is the per-interface NDP state. type ndpState struct { + // Do not allow overwriting this state. + _ sync.NoCopy + // The IPv6 endpoint this ndpState is for. ep *endpoint @@ -1884,11 +1888,19 @@ func (ndp *ndpState) stopSolicitingRouters() { ndp.rtrSolicitJob = nil } -// initializeTempAddrState initializes state related to temporary SLAAC -// addresses. -func (ndp *ndpState) initializeTempAddrState() { - header.InitialTempIID(ndp.temporaryIIDHistory[:], ndp.ep.protocol.options.TempIIDSeed, ndp.ep.nic.ID()) +func (ndp *ndpState) init(ep *endpoint) { + if ndp.dad != nil { + panic("attempted to initialize NDP state twice") + } + ndp.ep = ep + ndp.configs = ep.protocol.options.NDPConfigs + ndp.dad = make(map[tcpip.Address]dadState) + ndp.defaultRouters = make(map[tcpip.Address]defaultRouterState) + ndp.onLinkPrefixes = make(map[tcpip.Subnet]onLinkPrefixState) + ndp.slaacPrefixes = make(map[tcpip.Subnet]slaacPrefixState) + + header.InitialTempIID(ndp.temporaryIIDHistory[:], ndp.ep.protocol.options.TempIIDSeed, ndp.ep.nic.ID()) if MaxDesyncFactor != 0 { ndp.temporaryAddressDesyncFactor = time.Duration(rand.Int63n(int64(MaxDesyncFactor))) } diff --git a/test/packetimpact/runner/dut.go b/test/packetimpact/runner/dut.go index 8be2c6526..3e26c73cb 100644 --- a/test/packetimpact/runner/dut.go +++ b/test/packetimpact/runner/dut.go @@ -162,7 +162,7 @@ func setUpDUT(ctx context.Context, t *testing.T, id int, mkDevice func(*dockerut Image: "packetimpact", CapAdd: []string{"NET_ADMIN"}, } - if _, err := mountTempDirectory(t, &runOpts, "dut-output", testOutputDir); err != nil { + if _, err := MountTempDirectory(t, &runOpts, "dut-output", testOutputDir); err != nil { return dutInfo{}, err } @@ -228,7 +228,7 @@ func TestWithDUT(ctx context.Context, t *testing.T, mkDevice func(*dockerutil.Co Image: "packetimpact", CapAdd: []string{"NET_ADMIN"}, } - if _, err := mountTempDirectory(t, &runOpts, "testbench-output", testOutputDir); err != nil { + if _, err := MountTempDirectory(t, &runOpts, "testbench-output", testOutputDir); err != nil { t.Fatal(err) } tbb := path.Base(testbenchBinary) @@ -565,11 +565,11 @@ func StartContainer(ctx context.Context, runOpts dockerutil.RunOpts, c *dockerut return nil } -// mountTempDirectory creates a temporary directory on host with the template +// MountTempDirectory creates a temporary directory on host with the template // and then mounts it into the container under the name provided. The temporary // directory name is returned. Content in that directory will be copied to // TEST_UNDECLARED_OUTPUTS_DIR in cleanup phase. -func mountTempDirectory(t *testing.T, runOpts *dockerutil.RunOpts, hostDirTemplate, containerDir string) (string, error) { +func MountTempDirectory(t *testing.T, runOpts *dockerutil.RunOpts, hostDirTemplate, containerDir string) (string, error) { t.Helper() tmpDir, err := ioutil.TempDir("", hostDirTemplate) if err != nil { |