diff options
Diffstat (limited to 'pkg/tcpip/network/ipv4')
-rw-r--r-- | pkg/tcpip/network/ipv4/igmp.go | 57 | ||||
-rw-r--r-- | pkg/tcpip/network/ipv4/igmp_test.go | 355 | ||||
-rw-r--r-- | pkg/tcpip/network/ipv4/ipv4.go | 85 |
3 files changed, 105 insertions, 392 deletions
diff --git a/pkg/tcpip/network/ipv4/igmp.go b/pkg/tcpip/network/ipv4/igmp.go index c9bf117de..37f1822ca 100644 --- a/pkg/tcpip/network/ipv4/igmp.go +++ b/pkg/tcpip/network/ipv4/igmp.go @@ -51,6 +51,16 @@ const ( UnsolicitedReportIntervalMax = 10 * time.Second ) +// IGMPOptions holds options for IGMP. +type IGMPOptions struct { + // Enabled indicates whether IGMP will be performed. + // + // When enabled, IGMP may transmit IGMP report and leave messages when + // joining and leaving multicast groups respectively, and handle incoming + // IGMP packets. + Enabled bool +} + var _ ip.MulticastGroupProtocol = (*igmpState)(nil) // igmpState is the per-interface IGMP state. @@ -58,7 +68,8 @@ var _ ip.MulticastGroupProtocol = (*igmpState)(nil) // igmpState.init() MUST be called after creating an IGMP state. type igmpState struct { // The IPv4 endpoint this igmpState is for. - ep *endpoint + ep *endpoint + opts IGMPOptions // igmpV1Present is for maintaining compatibility with IGMPv1 Routers, from // RFC 2236 Section 4 Page 6: "The IGMPv1 router expects Version 1 @@ -108,10 +119,11 @@ func (igmp *igmpState) SendLeave(groupAddress tcpip.Address) *tcpip.Error { // init sets up an igmpState struct, and is required to be called before using // a new igmpState. -func (igmp *igmpState) init(ep *endpoint) { +func (igmp *igmpState) init(ep *endpoint, opts IGMPOptions) { igmp.mu.Lock() defer igmp.mu.Unlock() igmp.ep = ep + igmp.opts = opts igmp.mu.genericMulticastProtocol.Init(ep.protocol.stack.Rand(), ep.protocol.stack.Clock(), igmp, UnsolicitedReportIntervalMax) igmp.igmpV1Present = igmpV1PresentDefault igmp.mu.igmpV1Job = igmp.ep.protocol.stack.NewJob(&igmp.mu, func() { @@ -189,6 +201,10 @@ func (igmp *igmpState) setV1Present(v bool) { } func (igmp *igmpState) handleMembershipQuery(groupAddress tcpip.Address, maxRespTime time.Duration) { + if !igmp.opts.Enabled { + return + } + igmp.mu.Lock() defer igmp.mu.Unlock() @@ -206,6 +222,10 @@ func (igmp *igmpState) handleMembershipQuery(groupAddress tcpip.Address, maxResp } func (igmp *igmpState) handleMembershipReport(groupAddress tcpip.Address) { + if !igmp.opts.Enabled { + return + } + igmp.mu.Lock() defer igmp.mu.Unlock() igmp.mu.genericMulticastProtocol.HandleReport(groupAddress) @@ -226,11 +246,8 @@ func (igmp *igmpState) writePacket(destAddress tcpip.Address, groupAddress tcpip // TODO(gvisor.dev/issue/4888): We should not use the unspecified address, // rather we should select an appropriate local address. - r := stack.Route{ - LocalAddress: header.IPv4Any, - RemoteAddress: destAddress, - } - igmp.ep.addIPHeader(&r, pkt, stack.NetworkHeaderParams{ + localAddr := header.IPv4Any + igmp.ep.addIPHeader(localAddr, destAddress, pkt, stack.NetworkHeaderParams{ Protocol: header.IGMPProtocolNumber, TTL: header.IGMPTTL, TOS: stack.DefaultTOS, @@ -239,7 +256,7 @@ func (igmp *igmpState) writePacket(destAddress tcpip.Address, groupAddress tcpip // TODO(b/162198658): set the ROUTER_ALERT option when sending Host // Membership Reports. sent := igmp.ep.protocol.stack.Stats().IGMP.PacketsSent - if err := igmp.ep.nic.WritePacketToRemote(header.EthernetAddressFromMulticastIPv4Address(destAddress), nil /* gso */, header.IPv4ProtocolNumber, pkt); err != nil { + if err := igmp.ep.nic.WritePacketToRemote(header.EthernetAddressFromMulticastIPv4Address(destAddress), nil /* gso */, ProtocolNumber, pkt); err != nil { sent.Dropped.Increment() return err } @@ -263,6 +280,26 @@ func (igmp *igmpState) writePacket(destAddress tcpip.Address, groupAddress tcpip // If the group already exists in the membership map, returns // tcpip.ErrDuplicateAddress. func (igmp *igmpState) joinGroup(groupAddress tcpip.Address) *tcpip.Error { + if !igmp.opts.Enabled { + return nil + } + + // As per RFC 2236 section 6 page 10, + // + // The all-systems group (address 224.0.0.1) is handled as a special + // case. The host starts in Idle Member state for that group on every + // interface, never transitions to another state, and never sends a + // report for that group. + // + // This is equivalent to not performing IGMP for the all-systems multicast + // address. Simply not performing IGMP when the group is added will prevent + // any work from being done on the all-systems multicast group when leaving + // the group or when query or report messages are received for it since the + // MGP state will not know about it. + if groupAddress == header.IPv4AllSystems { + return nil + } + igmp.mu.Lock() defer igmp.mu.Unlock() @@ -280,6 +317,10 @@ func (igmp *igmpState) joinGroup(groupAddress tcpip.Address) *tcpip.Error { // If the group does not exist in the membership map, this function will // silently return. func (igmp *igmpState) leaveGroup(groupAddress tcpip.Address) { + if !igmp.opts.Enabled { + return + } + igmp.mu.Lock() defer igmp.mu.Unlock() igmp.mu.genericMulticastProtocol.LeaveGroup(groupAddress) diff --git a/pkg/tcpip/network/ipv4/igmp_test.go b/pkg/tcpip/network/ipv4/igmp_test.go index 4873a336f..d83b6c4a4 100644 --- a/pkg/tcpip/network/ipv4/igmp_test.go +++ b/pkg/tcpip/network/ipv4/igmp_test.go @@ -15,9 +15,7 @@ package ipv4_test import ( - "fmt" "testing" - "time" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" @@ -30,25 +28,11 @@ import ( ) const ( - linkAddr = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x06") - // endpointAddr = tcpip.Address("\x0a\x00\x00\x02") + linkAddr = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x06") multicastAddr = tcpip.Address("\xe0\x00\x00\x03") nicID = 1 ) -var ( - // unsolicitedReportIntervalMaxTenthSec is the maximum amount of time the NIC - // will wait before sending an unsolicited report after joining a multicast - // group, in deciseconds. - unsolicitedReportIntervalMaxTenthSec = func() uint8 { - const decisecond = time.Second / 10 - if ipv4.UnsolicitedReportIntervalMax%decisecond != 0 { - panic(fmt.Sprintf("UnsolicitedReportIntervalMax of %d is a lossy conversion to deciseconds", ipv4.UnsolicitedReportIntervalMax)) - } - return uint8(ipv4.UnsolicitedReportIntervalMax / decisecond) - }() -) - // validateIgmpPacket checks that a passed PacketInfo is an IPv4 IGMP packet // sent to the provided address with the passed fields set. Raises a t.Error if // any field does not match. @@ -75,7 +59,9 @@ func createStack(t *testing.T, igmpEnabled bool) (*channel.Endpoint, *stack.Stac clock := faketime.NewManualClock() s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocolWithOptions(ipv4.Options{ - IGMPEnabled: igmpEnabled, + IGMP: ipv4.IGMPOptions{ + Enabled: igmpEnabled, + }, })}, Clock: clock, }) @@ -110,339 +96,6 @@ func createAndInjectIGMPPacket(e *channel.Endpoint, igmpType header.IGMPType, ma }) } -// TestIgmpDisabled tests that IGMP is not enabled with a default -// stack.Options. This also tests that this NIC does not send the IGMP Join -// Group for the All Hosts group it automatically joins when created. -func TestIgmpDisabled(t *testing.T) { - e, s, _ := createStack(t, false) - - // This NIC will join the All Hosts group when created. Verify that does not - // send a report. - if got := s.Stats().IGMP.PacketsSent.V2MembershipReport.Value(); got != 0 { - t.Fatalf("got V2MembershipReport messages sent = %d, want = 0", got) - } - p, ok := e.Read() - if ok { - t.Fatalf("sent unexpected packet, stack with disabled IGMP sent packet = %+v", p.Pkt) - } - - // Test joining a specific group explicitly and verify that no reports are - // sent. - if err := s.JoinGroup(ipv4.ProtocolNumber, nicID, multicastAddr); err != nil { - t.Fatalf("JoinGroup(ipv4.ProtocolNumber, %d, %s) = %s", nicID, multicastAddr, err) - } - - if got := s.Stats().IGMP.PacketsSent.V2MembershipReport.Value(); got != 0 { - t.Fatalf("got V2MembershipReport messages sent = %d, want = 0", got) - } - p, ok = e.Read() - if ok { - t.Fatalf("sent unexpected packet, stack with disabled IGMP sent packet = %+v", p.Pkt) - } - - // Inject a General Membership Query, which is an IGMP Membership Query with - // a zeroed Group Address (IPv4Any) to verify that it does not reach the - // handler. - createAndInjectIGMPPacket(e, header.IGMPMembershipQuery, unsolicitedReportIntervalMaxTenthSec, header.IPv4Any) - - if got := s.Stats().IGMP.PacketsReceived.MembershipQuery.Value(); got != 0 { - t.Fatalf("got Membership Queries received = %d, want = 0", got) - } - p, ok = e.Read() - if ok { - t.Fatalf("sent unexpected packet, stack with disabled IGMP sent packet = %+v", p.Pkt) - } -} - -// TestIgmpReceivesIGMPMessages tests that the IGMP stack increments packet -// counters when it receives properly formatted Membership Queries, Membership -// Reports, and LeaveGroup Messages sent to this address. Note: test includes -// IGMP header fields that are not explicitly tested in order to inject proper -// IGMP packets. -func TestIgmpReceivesIGMPMessages(t *testing.T) { - tests := []struct { - name string - headerType header.IGMPType - maxRespTime byte - groupAddress tcpip.Address - statCounter func(tcpip.IGMPReceivedPacketStats) *tcpip.StatCounter - }{ - { - name: "General Membership Query", - headerType: header.IGMPMembershipQuery, - maxRespTime: unsolicitedReportIntervalMaxTenthSec, - groupAddress: header.IPv4Any, - statCounter: func(stats tcpip.IGMPReceivedPacketStats) *tcpip.StatCounter { - return stats.MembershipQuery - }, - }, - { - name: "IGMPv1 Membership Report", - headerType: header.IGMPv1MembershipReport, - maxRespTime: 0, - groupAddress: header.IPv4AllSystems, - statCounter: func(stats tcpip.IGMPReceivedPacketStats) *tcpip.StatCounter { - return stats.V1MembershipReport - }, - }, - { - name: "IGMPv2 Membership Report", - headerType: header.IGMPv2MembershipReport, - maxRespTime: 0, - groupAddress: header.IPv4AllSystems, - statCounter: func(stats tcpip.IGMPReceivedPacketStats) *tcpip.StatCounter { - return stats.V2MembershipReport - }, - }, - { - name: "Leave Group", - headerType: header.IGMPLeaveGroup, - maxRespTime: 0, - groupAddress: header.IPv4AllRoutersGroup, - statCounter: func(stats tcpip.IGMPReceivedPacketStats) *tcpip.StatCounter { - return stats.LeaveGroup - }, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - e, s, _ := createStack(t, true) - - createAndInjectIGMPPacket(e, test.headerType, test.maxRespTime, test.groupAddress) - - if got := test.statCounter(s.Stats().IGMP.PacketsReceived).Value(); got != 1 { - t.Fatalf("got %s received = %d, want = 1", test.name, got) - } - }) - } -} - -// TestIgmpJoinGroup tests that when explicitly joining a multicast group, the -// IGMP stack schedules and sends correct Membership Reports. -func TestIgmpJoinGroup(t *testing.T) { - e, s, clock := createStack(t, true) - - // Test joining a specific address explicitly and verify a Membership Report - // is sent immediately. - if err := s.JoinGroup(ipv4.ProtocolNumber, nicID, multicastAddr); err != nil { - t.Fatalf("JoinGroup(ipv4, nic, %s) = %s", multicastAddr, err) - } - - p, ok := e.Read() - if !ok { - t.Fatal("unable to Read IGMP packet, expected V2MembershipReport") - } - if got := s.Stats().IGMP.PacketsSent.V2MembershipReport.Value(); got != 1 { - t.Fatalf("got V2MembershipReport messages sent = %d, want = 1", got) - } - - validateIgmpPacket(t, p, multicastAddr, header.IGMPv2MembershipReport, 0, multicastAddr) - if t.Failed() { - t.FailNow() - } - - // Verify the second Membership Report is sent after a random interval up to - // the maximum unsolicited report interval. - p, ok = e.Read() - if ok { - t.Fatalf("sent unexpected packet, expected V2MembershipReport only after advancing the clock = %+v", p.Pkt) - } - clock.Advance(ipv4.UnsolicitedReportIntervalMax) - p, ok = e.Read() - if !ok { - t.Fatal("unable to Read IGMP packet, expected V2MembershipReport") - } - if got := s.Stats().IGMP.PacketsSent.V2MembershipReport.Value(); got != 2 { - t.Fatalf("got V2MembershipReport messages sent = %d, want = 2", got) - } - validateIgmpPacket(t, p, multicastAddr, header.IGMPv2MembershipReport, 0, multicastAddr) -} - -// TestIgmpLeaveGroup tests that when leaving a previously joined multicast -// group the IGMP enabled NIC sends the appropriate message. -func TestIgmpLeaveGroup(t *testing.T) { - e, s, clock := createStack(t, true) - - // Join a group so that it can be left, validate the immediate Membership - // Report is sent only to the multicast address joined. - if err := s.JoinGroup(ipv4.ProtocolNumber, nicID, multicastAddr); err != nil { - t.Fatalf("JoinGroup(ipv4, nic, %s) = %s", multicastAddr, err) - } - p, ok := e.Read() - if !ok { - t.Fatal("unable to Read IGMP packet, expected V2MembershipReport") - } - if got := s.Stats().IGMP.PacketsSent.V2MembershipReport.Value(); got != 1 { - t.Fatalf("got V2MembershipReport messages sent = %d, want = 1", got) - } - validateIgmpPacket(t, p, multicastAddr, header.IGMPv2MembershipReport, 0, multicastAddr) - if t.Failed() { - t.FailNow() - } - - // Verify the second Membership Report is sent after a random interval up to - // the maximum unsolicited report interval, and is sent to the multicast - // address being joined. - p, ok = e.Read() - if ok { - t.Fatalf("sent unexpected packet, expected V2MembershipReport only after advancing the clock = %+v", p.Pkt) - } - clock.Advance(ipv4.UnsolicitedReportIntervalMax) - p, ok = e.Read() - if !ok { - t.Fatal("unable to Read IGMP packet, expected V2MembershipReport") - } - if got := s.Stats().IGMP.PacketsSent.V2MembershipReport.Value(); got != 2 { - t.Fatalf("got V2MembershipReport messages sent = %d, want = 2", got) - } - validateIgmpPacket(t, p, multicastAddr, header.IGMPv2MembershipReport, 0, multicastAddr) - if t.Failed() { - t.FailNow() - } - - // Now that there are no packets queued and none scheduled to be sent, leave - // the group. - if err := s.LeaveGroup(ipv4.ProtocolNumber, nicID, multicastAddr); err != nil { - t.Fatalf("LeaveGroup(ipv4, nic, %s) = %s", multicastAddr, err) - } - - // Observe the Leave Group Message to verify that the Leave Group message is - // sent to the All Routers group but that the message itself has the - // multicast address being left. - p, ok = e.Read() - if !ok { - t.Fatal("unable to Read IGMP packet, expected LeaveGroup") - } - if got := s.Stats().IGMP.PacketsSent.LeaveGroup.Value(); got != 1 { - t.Fatalf("got LeaveGroup messages sent = %d, want = 1", got) - } - validateIgmpPacket(t, p, header.IPv4AllRoutersGroup, header.IGMPLeaveGroup, 0, multicastAddr) -} - -// TestIgmpJoinLeaveGroup tests that when leaving a previously joined multicast -// group before the Unsolicited Report Interval cancels the second membership -// report. -func TestIgmpJoinLeaveGroup(t *testing.T) { - _, s, clock := createStack(t, true) - - if err := s.JoinGroup(ipv4.ProtocolNumber, nicID, multicastAddr); err != nil { - t.Fatalf("JoinGroup(ipv4, nic, %s) = %s", multicastAddr, err) - } - - // Verify that this NIC sent a Membership Report for only the group just - // joined. - if got := s.Stats().IGMP.PacketsSent.V2MembershipReport.Value(); got != 1 { - t.Fatalf("got V2MembershipReport messages sent = %d, want = 1", got) - } - - if err := s.LeaveGroup(ipv4.ProtocolNumber, nicID, multicastAddr); err != nil { - t.Fatalf("LeaveGroup(ipv4, nic, %s) = %s", multicastAddr, err) - } - - // Wait for the standard IGMP Unsolicited Report Interval duration before - // verifying that the unsolicited Membership Report was sent after leaving - // the group. - clock.Advance(ipv4.UnsolicitedReportIntervalMax) - if got := s.Stats().IGMP.PacketsSent.V2MembershipReport.Value(); got != 1 { - t.Fatalf("got V2MembershipReport messages sent = %d, want = 1", got) - } -} - -// TestIgmpMembershipQueryReport tests the handling of both incoming IGMP -// Membership Queries and outgoing Membership Reports. -func TestIgmpMembershipQueryReport(t *testing.T) { - e, s, clock := createStack(t, true) - - if err := s.JoinGroup(ipv4.ProtocolNumber, nicID, multicastAddr); err != nil { - t.Fatalf("JoinGroup(ipv4, nic, %s) = %s", multicastAddr, err) - } - - p, ok := e.Read() - if !ok { - t.Fatal("unable to Read IGMP packet, expected V2MembershipReport") - } - if got := s.Stats().IGMP.PacketsSent.V2MembershipReport.Value(); got != 1 { - t.Fatalf("got V2MembershipReport messages sent = %d, want = 1", got) - } - validateIgmpPacket(t, p, multicastAddr, header.IGMPv2MembershipReport, 0, multicastAddr) - if t.Failed() { - t.FailNow() - } - - p, ok = e.Read() - if ok { - t.Fatalf("sent unexpected packet, expected V2MembershipReport only after advancing the clock = %+v", p.Pkt) - } - clock.Advance(ipv4.UnsolicitedReportIntervalMax) - p, ok = e.Read() - if !ok { - t.Fatal("unable to Read IGMP packet, expected V2MembershipReport") - } - if got := s.Stats().IGMP.PacketsSent.V2MembershipReport.Value(); got != 2 { - t.Fatalf("got V2MembershipReport messages sent = %d, want = 2", got) - } - validateIgmpPacket(t, p, multicastAddr, header.IGMPv2MembershipReport, 0, multicastAddr) - - // Inject a General Membership Query, which is an IGMP Membership Query with - // a zeroed Group Address (IPv4Any) with the shortened Max Response Time. - const maxRespTimeDS = 10 - createAndInjectIGMPPacket(e, header.IGMPMembershipQuery, maxRespTimeDS, header.IPv4Any) - - p, ok = e.Read() - if ok { - t.Fatalf("sent unexpected packet, expected V2MembershipReport only after advancing the clock = %+v", p.Pkt) - } - clock.Advance(header.DecisecondToDuration(maxRespTimeDS)) - p, ok = e.Read() - if !ok { - t.Fatal("unable to Read IGMP packet, expected V2MembershipReport") - } - if got := s.Stats().IGMP.PacketsSent.V2MembershipReport.Value(); got != 3 { - t.Fatalf("got V2MembershipReport messages sent = %d, want = 3", got) - } - validateIgmpPacket(t, p, multicastAddr, header.IGMPv2MembershipReport, 0, multicastAddr) -} - -// TestIgmpMultipleHosts tests the handling of IGMP Leave when we are not the -// most recent IGMP host to join a multicast network. -func TestIgmpMultipleHosts(t *testing.T) { - e, s, clock := createStack(t, true) - - if err := s.JoinGroup(ipv4.ProtocolNumber, nicID, multicastAddr); err != nil { - t.Fatalf("JoinGroup(ipv4, nic, %s) = %s", multicastAddr, err) - } - - p, ok := e.Read() - if !ok { - t.Fatal("unable to Read IGMP packet, expected V2MembershipReport") - } - if got := s.Stats().IGMP.PacketsSent.V2MembershipReport.Value(); got != 1 { - t.Fatalf("got V2MembershipReport messages sent = %d, want = 1", got) - } - validateIgmpPacket(t, p, multicastAddr, header.IGMPv2MembershipReport, 0, multicastAddr) - if t.Failed() { - t.FailNow() - } - - // Inject another Host's Join Group message so that this host is not the - // latest to send the report. Set Max Response Time to 0 for Membership - // Reports. - createAndInjectIGMPPacket(e, header.IGMPv2MembershipReport, 0, multicastAddr) - - if err := s.LeaveGroup(ipv4.ProtocolNumber, nicID, multicastAddr); err != nil { - t.Fatalf("LeaveGroup(ipv4, nic, %s) = %s", multicastAddr, err) - } - - // Wait to be sure that no Leave Group messages were sent up to the max - // unsolicited report interval since it was not the last host to join this - // group. - clock.Advance(ipv4.UnsolicitedReportIntervalMax) - if got := s.Stats().IGMP.PacketsSent.LeaveGroup.Value(); got != 0 { - t.Fatalf("got LeaveGroup messages sent = %d, want = 0", got) - } -} - // TestIgmpV1Present tests the handling of the case where an IGMPv1 router is // present on the network. The IGMP stack will then send IGMPv1 Membership // reports for backwards compatibility. diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go index 7c759be9a..be9c8e2f9 100644 --- a/pkg/tcpip/network/ipv4/ipv4.go +++ b/pkg/tcpip/network/ipv4/ipv4.go @@ -95,7 +95,7 @@ func (p *protocol) NewEndpoint(nic stack.NetworkInterface, _ stack.LinkAddressCa protocol: p, } e.mu.addressableEndpointState.Init(e) - e.igmp.init(e) + e.igmp.init(e, p.options.IGMP) return e } @@ -126,7 +126,7 @@ func (e *endpoint) Enable() *tcpip.Error { // As per RFC 1122 section 3.3.7, all hosts should join the all-hosts // multicast group. Note, the IANA calls the all-hosts multicast group the // all-systems multicast group. - _, err = e.mu.addressableEndpointState.JoinGroup(header.IPv4AllSystems) + _, err = e.joinGroupLocked(header.IPv4AllSystems) return err } @@ -164,7 +164,7 @@ func (e *endpoint) disableLocked() { } // The endpoint may have already left the multicast group. - if _, err := e.mu.addressableEndpointState.LeaveGroup(header.IPv4AllSystems); err != nil && err != tcpip.ErrBadLocalAddress { + if _, err := e.leaveGroupLocked(header.IPv4AllSystems); err != nil && err != tcpip.ErrBadLocalAddress { panic(fmt.Sprintf("unexpected error when leaving group = %s: %s", header.IPv4AllSystems, err)) } @@ -200,7 +200,7 @@ func (e *endpoint) NetworkProtocolNumber() tcpip.NetworkProtocolNumber { return e.protocol.Number() } -func (e *endpoint) addIPHeader(r *stack.Route, pkt *stack.PacketBuffer, params stack.NetworkHeaderParams) { +func (e *endpoint) addIPHeader(srcAddr, dstAddr tcpip.Address, pkt *stack.PacketBuffer, params stack.NetworkHeaderParams) { hdrLen := header.IPv4MinimumSize var opts header.IPv4Options if params.Options != nil { @@ -221,15 +221,15 @@ func (e *endpoint) addIPHeader(r *stack.Route, pkt *stack.PacketBuffer, params s // RFC 6864 section 4.3 mandates uniqueness of ID values for non-atomic // datagrams. Since the DF bit is never being set here, all datagrams // are non-atomic and need an ID. - id := atomic.AddUint32(&e.protocol.ids[hashRoute(r, params.Protocol, e.protocol.hashIV)%buckets], 1) + id := atomic.AddUint32(&e.protocol.ids[hashRoute(srcAddr, dstAddr, params.Protocol, e.protocol.hashIV)%buckets], 1) ip.Encode(&header.IPv4Fields{ TotalLength: length, ID: uint16(id), TTL: params.TTL, TOS: params.TOS, Protocol: uint8(params.Protocol), - SrcAddr: r.LocalAddress, - DstAddr: r.RemoteAddress, + SrcAddr: srcAddr, + DstAddr: dstAddr, Options: opts, }) ip.SetChecksum(^ip.CalculateChecksum()) @@ -261,7 +261,7 @@ func (e *endpoint) handleFragments(r *stack.Route, gso *stack.GSO, networkMTU ui // WritePacket writes a packet to the given destination address and protocol. func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.NetworkHeaderParams, pkt *stack.PacketBuffer) *tcpip.Error { - e.addIPHeader(r, pkt, params) + e.addIPHeader(r.LocalAddress, r.RemoteAddress, pkt, params) // iptables filtering. All packets that reach here are locally // generated. @@ -349,7 +349,7 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe } for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() { - e.addIPHeader(r, pkt, params) + e.addIPHeader(r.LocalAddress, r.RemoteAddress, pkt, params) networkMTU, err := calculateNetworkMTU(e.nic.MTU(), uint32(pkt.NetworkHeader().View().Size())) if err != nil { r.Stats().IP.OutgoingPacketErrors.IncrementBy(uint64(pkts.Len())) @@ -463,7 +463,7 @@ func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack.PacketBu // non-atomic datagrams, so assign an ID to all such datagrams // according to the definition given in RFC 6864 section 4. if ip.Flags()&header.IPv4FlagDontFragment == 0 || ip.Flags()&header.IPv4FlagMoreFragments != 0 || ip.FragmentOffset() > 0 { - ip.SetID(uint16(atomic.AddUint32(&e.protocol.ids[hashRoute(r, 0 /* protocol */, e.protocol.hashIV)%buckets], 1))) + ip.SetID(uint16(atomic.AddUint32(&e.protocol.ids[hashRoute(r.LocalAddress, r.RemoteAddress, 0 /* protocol */, e.protocol.hashIV)%buckets], 1))) } } @@ -706,10 +706,7 @@ func (e *endpoint) handlePacket(pkt *stack.PacketBuffer) { return } if p == header.IGMPProtocolNumber { - if e.protocol.options.IGMPEnabled { - e.igmp.handleIGMP(pkt) - } - // Nothing further to do with an IGMP packet, even if IGMP is not enabled. + e.igmp.handleIGMP(pkt) return } if opts := h.Options(); len(opts) != 0 { @@ -837,32 +834,55 @@ func (e *endpoint) PermanentAddresses() []tcpip.AddressWithPrefix { // JoinGroup implements stack.GroupAddressableEndpoint. func (e *endpoint) JoinGroup(addr tcpip.Address) (bool, *tcpip.Error) { + e.mu.Lock() + defer e.mu.Unlock() + return e.joinGroupLocked(addr) +} + +// joinGroupLocked is like JoinGroup, but with locking requirements. +// +// Precondition: e.mu must be locked. +func (e *endpoint) joinGroupLocked(addr tcpip.Address) (bool, *tcpip.Error) { if !header.IsV4MulticastAddress(addr) { return false, tcpip.ErrBadAddress } + // TODO(gvisor.dev/issue/4916): Keep track of join count and IGMP state in a + // single type. + joined, err := e.mu.addressableEndpointState.JoinGroup(addr) + if err != nil || !joined { + return joined, err + } - e.mu.Lock() - defer e.mu.Unlock() - - joinedGroup, err := e.mu.addressableEndpointState.JoinGroup(addr) - if err == nil && joinedGroup && e.protocol.options.IGMPEnabled { - _ = e.igmp.joinGroup(addr) + // joinGroup only returns an error if we try to join a group twice, but we + // checked above to make sure that the group was newly joined. + if err := e.igmp.joinGroup(addr); err != nil { + panic(fmt.Sprintf("e.igmp.joinGroup(%s): %s", addr, err)) } - return joinedGroup, err + return true, nil } // LeaveGroup implements stack.GroupAddressableEndpoint. func (e *endpoint) LeaveGroup(addr tcpip.Address) (bool, *tcpip.Error) { e.mu.Lock() defer e.mu.Unlock() + return e.leaveGroupLocked(addr) +} + +// leaveGroupLocked is like LeaveGroup, but with locking requirements. +// +// Precondition: e.mu must be locked. +func (e *endpoint) leaveGroupLocked(addr tcpip.Address) (bool, *tcpip.Error) { + left, err := e.mu.addressableEndpointState.LeaveGroup(addr) + if err != nil { + return left, err + } - leftGroup, err := e.mu.addressableEndpointState.LeaveGroup(addr) - if err == nil && leftGroup && e.protocol.options.IGMPEnabled { + if left { e.igmp.leaveGroup(addr) } - return leftGroup, err + return left, nil } // IsInGroup implements stack.GroupAddressableEndpoint. @@ -1021,20 +1041,19 @@ func addressToUint32(addr tcpip.Address) uint32 { return uint32(addr[0]) | uint32(addr[1])<<8 | uint32(addr[2])<<16 | uint32(addr[3])<<24 } -// hashRoute calculates a hash value for the given route. It uses the source & -// destination address, the transport protocol number and a 32-bit number to -// generate the hash. -func hashRoute(r *stack.Route, protocol tcpip.TransportProtocolNumber, hashIV uint32) uint32 { - a := addressToUint32(r.LocalAddress) - b := addressToUint32(r.RemoteAddress) +// hashRoute calculates a hash value for the given source/destination pair using +// the addresses, transport protocol number and a 32-bit number to generate the +// hash. +func hashRoute(srcAddr, dstAddr tcpip.Address, protocol tcpip.TransportProtocolNumber, hashIV uint32) uint32 { + a := addressToUint32(srcAddr) + b := addressToUint32(dstAddr) return hash.Hash3Words(a, b, uint32(protocol), hashIV) } // Options holds options to configure a new protocol. type Options struct { - // IGMPEnabled indicates whether incoming IGMP packets will be handled and if - // this endpoint will transmit IGMP packets on IGMP related events. - IGMPEnabled bool + // IGMP holds options for IGMP. + IGMP IGMPOptions } // NewProtocolWithOptions returns an IPv4 network protocol. |