diff options
Diffstat (limited to 'pkg/tcpip/network/ipv4/ipv4.go')
-rw-r--r-- | pkg/tcpip/network/ipv4/ipv4.go | 85 |
1 files changed, 52 insertions, 33 deletions
diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go index 7c759be9a..be9c8e2f9 100644 --- a/pkg/tcpip/network/ipv4/ipv4.go +++ b/pkg/tcpip/network/ipv4/ipv4.go @@ -95,7 +95,7 @@ func (p *protocol) NewEndpoint(nic stack.NetworkInterface, _ stack.LinkAddressCa protocol: p, } e.mu.addressableEndpointState.Init(e) - e.igmp.init(e) + e.igmp.init(e, p.options.IGMP) return e } @@ -126,7 +126,7 @@ func (e *endpoint) Enable() *tcpip.Error { // As per RFC 1122 section 3.3.7, all hosts should join the all-hosts // multicast group. Note, the IANA calls the all-hosts multicast group the // all-systems multicast group. - _, err = e.mu.addressableEndpointState.JoinGroup(header.IPv4AllSystems) + _, err = e.joinGroupLocked(header.IPv4AllSystems) return err } @@ -164,7 +164,7 @@ func (e *endpoint) disableLocked() { } // The endpoint may have already left the multicast group. - if _, err := e.mu.addressableEndpointState.LeaveGroup(header.IPv4AllSystems); err != nil && err != tcpip.ErrBadLocalAddress { + if _, err := e.leaveGroupLocked(header.IPv4AllSystems); err != nil && err != tcpip.ErrBadLocalAddress { panic(fmt.Sprintf("unexpected error when leaving group = %s: %s", header.IPv4AllSystems, err)) } @@ -200,7 +200,7 @@ func (e *endpoint) NetworkProtocolNumber() tcpip.NetworkProtocolNumber { return e.protocol.Number() } -func (e *endpoint) addIPHeader(r *stack.Route, pkt *stack.PacketBuffer, params stack.NetworkHeaderParams) { +func (e *endpoint) addIPHeader(srcAddr, dstAddr tcpip.Address, pkt *stack.PacketBuffer, params stack.NetworkHeaderParams) { hdrLen := header.IPv4MinimumSize var opts header.IPv4Options if params.Options != nil { @@ -221,15 +221,15 @@ func (e *endpoint) addIPHeader(r *stack.Route, pkt *stack.PacketBuffer, params s // RFC 6864 section 4.3 mandates uniqueness of ID values for non-atomic // datagrams. Since the DF bit is never being set here, all datagrams // are non-atomic and need an ID. - id := atomic.AddUint32(&e.protocol.ids[hashRoute(r, params.Protocol, e.protocol.hashIV)%buckets], 1) + id := atomic.AddUint32(&e.protocol.ids[hashRoute(srcAddr, dstAddr, params.Protocol, e.protocol.hashIV)%buckets], 1) ip.Encode(&header.IPv4Fields{ TotalLength: length, ID: uint16(id), TTL: params.TTL, TOS: params.TOS, Protocol: uint8(params.Protocol), - SrcAddr: r.LocalAddress, - DstAddr: r.RemoteAddress, + SrcAddr: srcAddr, + DstAddr: dstAddr, Options: opts, }) ip.SetChecksum(^ip.CalculateChecksum()) @@ -261,7 +261,7 @@ func (e *endpoint) handleFragments(r *stack.Route, gso *stack.GSO, networkMTU ui // WritePacket writes a packet to the given destination address and protocol. func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.NetworkHeaderParams, pkt *stack.PacketBuffer) *tcpip.Error { - e.addIPHeader(r, pkt, params) + e.addIPHeader(r.LocalAddress, r.RemoteAddress, pkt, params) // iptables filtering. All packets that reach here are locally // generated. @@ -349,7 +349,7 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe } for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() { - e.addIPHeader(r, pkt, params) + e.addIPHeader(r.LocalAddress, r.RemoteAddress, pkt, params) networkMTU, err := calculateNetworkMTU(e.nic.MTU(), uint32(pkt.NetworkHeader().View().Size())) if err != nil { r.Stats().IP.OutgoingPacketErrors.IncrementBy(uint64(pkts.Len())) @@ -463,7 +463,7 @@ func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack.PacketBu // non-atomic datagrams, so assign an ID to all such datagrams // according to the definition given in RFC 6864 section 4. if ip.Flags()&header.IPv4FlagDontFragment == 0 || ip.Flags()&header.IPv4FlagMoreFragments != 0 || ip.FragmentOffset() > 0 { - ip.SetID(uint16(atomic.AddUint32(&e.protocol.ids[hashRoute(r, 0 /* protocol */, e.protocol.hashIV)%buckets], 1))) + ip.SetID(uint16(atomic.AddUint32(&e.protocol.ids[hashRoute(r.LocalAddress, r.RemoteAddress, 0 /* protocol */, e.protocol.hashIV)%buckets], 1))) } } @@ -706,10 +706,7 @@ func (e *endpoint) handlePacket(pkt *stack.PacketBuffer) { return } if p == header.IGMPProtocolNumber { - if e.protocol.options.IGMPEnabled { - e.igmp.handleIGMP(pkt) - } - // Nothing further to do with an IGMP packet, even if IGMP is not enabled. + e.igmp.handleIGMP(pkt) return } if opts := h.Options(); len(opts) != 0 { @@ -837,32 +834,55 @@ func (e *endpoint) PermanentAddresses() []tcpip.AddressWithPrefix { // JoinGroup implements stack.GroupAddressableEndpoint. func (e *endpoint) JoinGroup(addr tcpip.Address) (bool, *tcpip.Error) { + e.mu.Lock() + defer e.mu.Unlock() + return e.joinGroupLocked(addr) +} + +// joinGroupLocked is like JoinGroup, but with locking requirements. +// +// Precondition: e.mu must be locked. +func (e *endpoint) joinGroupLocked(addr tcpip.Address) (bool, *tcpip.Error) { if !header.IsV4MulticastAddress(addr) { return false, tcpip.ErrBadAddress } + // TODO(gvisor.dev/issue/4916): Keep track of join count and IGMP state in a + // single type. + joined, err := e.mu.addressableEndpointState.JoinGroup(addr) + if err != nil || !joined { + return joined, err + } - e.mu.Lock() - defer e.mu.Unlock() - - joinedGroup, err := e.mu.addressableEndpointState.JoinGroup(addr) - if err == nil && joinedGroup && e.protocol.options.IGMPEnabled { - _ = e.igmp.joinGroup(addr) + // joinGroup only returns an error if we try to join a group twice, but we + // checked above to make sure that the group was newly joined. + if err := e.igmp.joinGroup(addr); err != nil { + panic(fmt.Sprintf("e.igmp.joinGroup(%s): %s", addr, err)) } - return joinedGroup, err + return true, nil } // LeaveGroup implements stack.GroupAddressableEndpoint. func (e *endpoint) LeaveGroup(addr tcpip.Address) (bool, *tcpip.Error) { e.mu.Lock() defer e.mu.Unlock() + return e.leaveGroupLocked(addr) +} + +// leaveGroupLocked is like LeaveGroup, but with locking requirements. +// +// Precondition: e.mu must be locked. +func (e *endpoint) leaveGroupLocked(addr tcpip.Address) (bool, *tcpip.Error) { + left, err := e.mu.addressableEndpointState.LeaveGroup(addr) + if err != nil { + return left, err + } - leftGroup, err := e.mu.addressableEndpointState.LeaveGroup(addr) - if err == nil && leftGroup && e.protocol.options.IGMPEnabled { + if left { e.igmp.leaveGroup(addr) } - return leftGroup, err + return left, nil } // IsInGroup implements stack.GroupAddressableEndpoint. @@ -1021,20 +1041,19 @@ func addressToUint32(addr tcpip.Address) uint32 { return uint32(addr[0]) | uint32(addr[1])<<8 | uint32(addr[2])<<16 | uint32(addr[3])<<24 } -// hashRoute calculates a hash value for the given route. It uses the source & -// destination address, the transport protocol number and a 32-bit number to -// generate the hash. -func hashRoute(r *stack.Route, protocol tcpip.TransportProtocolNumber, hashIV uint32) uint32 { - a := addressToUint32(r.LocalAddress) - b := addressToUint32(r.RemoteAddress) +// hashRoute calculates a hash value for the given source/destination pair using +// the addresses, transport protocol number and a 32-bit number to generate the +// hash. +func hashRoute(srcAddr, dstAddr tcpip.Address, protocol tcpip.TransportProtocolNumber, hashIV uint32) uint32 { + a := addressToUint32(srcAddr) + b := addressToUint32(dstAddr) return hash.Hash3Words(a, b, uint32(protocol), hashIV) } // Options holds options to configure a new protocol. type Options struct { - // IGMPEnabled indicates whether incoming IGMP packets will be handled and if - // this endpoint will transmit IGMP packets on IGMP related events. - IGMPEnabled bool + // IGMP holds options for IGMP. + IGMP IGMPOptions } // NewProtocolWithOptions returns an IPv4 network protocol. |