From a085e562d0592bccc99e9e0380706a8025f70d53 Mon Sep 17 00:00:00 2001
From: Ian Gudger <igudger@google.com>
Date: Wed, 10 Jun 2020 23:48:03 -0700
Subject: Add support for SO_REUSEADDR to UDP sockets/endpoints.

On UDP sockets, SO_REUSEADDR allows multiple sockets to bind to the same
address, but only delivers packets to the most recently bound socket. This
differs from the behavior of SO_REUSEADDR on TCP sockets. SO_REUSEADDR for TCP
sockets will likely need an almost completely independent implementation.

SO_REUSEADDR has some odd interactions with the similar SO_REUSEPORT. These
interactions are tested fairly extensively and all but one particularly odd
one (that honestly seems like a bug) behave the same on gVisor and Linux.

PiperOrigin-RevId: 315844832
---
 pkg/tcpip/ports/ports.go | 116 ++++++++++++++++++++++++++++++++---------------
 1 file changed, 79 insertions(+), 37 deletions(-)

(limited to 'pkg/tcpip/ports')

diff --git a/pkg/tcpip/ports/ports.go b/pkg/tcpip/ports/ports.go
index b937cb84b..edc29ad27 100644
--- a/pkg/tcpip/ports/ports.go
+++ b/pkg/tcpip/ports/ports.go
@@ -54,17 +54,27 @@ type Flags struct {
 	LoadBalanced bool
 }
 
-func (f Flags) bits() reuseFlag {
-	var rf reuseFlag
+// Bits converts the Flags to their bitset form.
+func (f Flags) Bits() BitFlags {
+	var rf BitFlags
 	if f.MostRecent {
-		rf |= mostRecentFlag
+		rf |= MostRecentFlag
 	}
 	if f.LoadBalanced {
-		rf |= loadBalancedFlag
+		rf |= LoadBalancedFlag
 	}
 	return rf
 }
 
+// Effective returns the effective behavior of a flag config.
+func (f Flags) Effective() Flags {
+	e := f
+	if e.LoadBalanced && e.MostRecent {
+		e.MostRecent = false
+	}
+	return e
+}
+
 // PortManager manages allocating, reserving and releasing ports.
 type PortManager struct {
 	mu             sync.RWMutex
@@ -78,56 +88,88 @@ type PortManager struct {
 	hint uint32
 }
 
-type reuseFlag int
+// BitFlags is a bitset representation of Flags.
+type BitFlags uint32
 
 const (
-	mostRecentFlag reuseFlag = 1 << iota
-	loadBalancedFlag
+	// MostRecentFlag represents Flags.MostRecent.
+	MostRecentFlag BitFlags = 1 << iota
+
+	// LoadBalancedFlag represents Flags.LoadBalanced.
+	LoadBalancedFlag
+
+	// nextFlag is the value that the next added flag will have.
+	//
+	// It is used to calculate FlagMask below. It is also the number of
+	// valid flag states.
 	nextFlag
 
-	flagMask = nextFlag - 1
+	// FlagMask is a bit mask for BitFlags.
+	FlagMask = nextFlag - 1
 )
 
-type portNode struct {
-	// refs stores the count for each possible flag combination.
+// ToFlags converts the bitset into a Flags struct.
+func (f BitFlags) ToFlags() Flags {
+	return Flags{
+		MostRecent:   f&MostRecentFlag != 0,
+		LoadBalanced: f&LoadBalancedFlag != 0,
+	}
+}
+
+// FlagCounter counts how many references each flag combination has.
+type FlagCounter struct {
+	// refs stores the count for each possible flag combination, (0 though
+	// FlagMask).
 	refs [nextFlag]int
 }
 
-func (p portNode) totalRefs() int {
+// AddRef increases the reference count for a specific flag combination.
+func (c *FlagCounter) AddRef(flags BitFlags) {
+	c.refs[flags]++
+}
+
+// DropRef decreases the reference count for a specific flag combination.
+func (c *FlagCounter) DropRef(flags BitFlags) {
+	c.refs[flags]--
+}
+
+// TotalRefs calculates the total number of references for all flag
+// combinations.
+func (c FlagCounter) TotalRefs() int {
 	var total int
-	for _, r := range p.refs {
+	for _, r := range c.refs {
 		total += r
 	}
 	return total
 }
 
-// flagRefs returns the number of references with all specified flags.
-func (p portNode) flagRefs(flags reuseFlag) int {
+// FlagRefs returns the number of references with all specified flags.
+func (c FlagCounter) FlagRefs(flags BitFlags) int {
 	var total int
-	for i, r := range p.refs {
-		if reuseFlag(i)&flags == flags {
+	for i, r := range c.refs {
+		if BitFlags(i)&flags == flags {
 			total += r
 		}
 	}
 	return total
 }
 
-// allRefsHave returns if all references have all specified flags.
-func (p portNode) allRefsHave(flags reuseFlag) bool {
-	for i, r := range p.refs {
-		if reuseFlag(i)&flags == flags && r > 0 {
+// AllRefsHave returns if all references have all specified flags.
+func (c FlagCounter) AllRefsHave(flags BitFlags) bool {
+	for i, r := range c.refs {
+		if BitFlags(i)&flags != flags && r > 0 {
 			return false
 		}
 	}
 	return true
 }
 
-// intersectionRefs returns the set of flags shared by all references.
-func (p portNode) intersectionRefs() reuseFlag {
-	intersection := flagMask
-	for i, r := range p.refs {
+// IntersectionRefs returns the set of flags shared by all references.
+func (c FlagCounter) IntersectionRefs() BitFlags {
+	intersection := FlagMask
+	for i, r := range c.refs {
 		if r > 0 {
-			intersection &= reuseFlag(i)
+			intersection &= BitFlags(i)
 		}
 	}
 	return intersection
@@ -135,26 +177,26 @@ func (p portNode) intersectionRefs() reuseFlag {
 
 // deviceNode is never empty. When it has no elements, it is removed from the
 // map that references it.
-type deviceNode map[tcpip.NICID]portNode
+type deviceNode map[tcpip.NICID]FlagCounter
 
 // isAvailable checks whether binding is possible by device. If not binding to a
-// device, check against all portNodes. If binding to a specific device, check
+// device, check against all FlagCounters. If binding to a specific device, check
 // against the unspecified device and the provided device.
 //
 // If either of the port reuse flags is enabled on any of the nodes, all nodes
 // sharing a port must share at least one reuse flag. This matches Linux's
 // behavior.
 func (d deviceNode) isAvailable(flags Flags, bindToDevice tcpip.NICID) bool {
-	flagBits := flags.bits()
+	flagBits := flags.Bits()
 	if bindToDevice == 0 {
 		// Trying to binding all devices.
 		if flagBits == 0 {
 			// Can't bind because the (addr,port) is already bound.
 			return false
 		}
-		intersection := flagMask
+		intersection := FlagMask
 		for _, p := range d {
-			i := p.intersectionRefs()
+			i := p.IntersectionRefs()
 			intersection &= i
 			if intersection&flagBits == 0 {
 				// Can't bind because the (addr,port) was
@@ -165,17 +207,17 @@ func (d deviceNode) isAvailable(flags Flags, bindToDevice tcpip.NICID) bool {
 		return true
 	}
 
-	intersection := flagMask
+	intersection := FlagMask
 
 	if p, ok := d[0]; ok {
-		intersection = p.intersectionRefs()
+		intersection = p.IntersectionRefs()
 		if intersection&flagBits == 0 {
 			return false
 		}
 	}
 
 	if p, ok := d[bindToDevice]; ok {
-		i := p.intersectionRefs()
+		i := p.IntersectionRefs()
 		intersection &= i
 		if intersection&flagBits == 0 {
 			return false
@@ -324,7 +366,7 @@ func (s *PortManager) reserveSpecificPort(networks []tcpip.NetworkProtocolNumber
 	if !s.isPortAvailableLocked(networks, transport, addr, port, flags, bindToDevice) {
 		return false
 	}
-	flagBits := flags.bits()
+	flagBits := flags.Bits()
 
 	// Reserve port on all network protocols.
 	for _, network := range networks {
@@ -340,7 +382,7 @@ func (s *PortManager) reserveSpecificPort(networks []tcpip.NetworkProtocolNumber
 			m[addr] = d
 		}
 		n := d[bindToDevice]
-		n.refs[flagBits]++
+		n.AddRef(flagBits)
 		d[bindToDevice] = n
 	}
 
@@ -353,7 +395,7 @@ func (s *PortManager) ReleasePort(networks []tcpip.NetworkProtocolNumber, transp
 	s.mu.Lock()
 	defer s.mu.Unlock()
 
-	flagBits := flags.bits()
+	flagBits := flags.Bits()
 
 	for _, network := range networks {
 		desc := portDescriptor{network, transport, port}
@@ -368,7 +410,7 @@ func (s *PortManager) ReleasePort(networks []tcpip.NetworkProtocolNumber, transp
 			}
 			n.refs[flagBits]--
 			d[bindToDevice] = n
-			if n.refs == [nextFlag]int{} {
+			if n.TotalRefs() == 0 {
 				delete(d, bindToDevice)
 			}
 			if len(d) == 0 {
-- 
cgit v1.2.3