summaryrefslogtreecommitdiffhomepage
path: root/pkg/tcpip/ports
diff options
context:
space:
mode:
authorKevin Krakauer <krakauer@google.com>2021-03-11 21:03:54 -0800
committergVisor bot <gvisor-bot@google.com>2021-03-11 21:05:32 -0800
commit82d7fb2cb0cf37ceeb44de665cde1ac7d72230f1 (patch)
treef88632143c3ab9c2f0067e9d4de8c6be47ea1f09 /pkg/tcpip/ports
parent192318a2316d84a3de9d28c29fbc73aae3e75206 (diff)
improve readability of ports package
Lots of small changes: - simplify package API via Reservation type - rename some single-letter variable names that were hard to follow - rename some types PiperOrigin-RevId: 362442366
Diffstat (limited to 'pkg/tcpip/ports')
-rw-r--r--pkg/tcpip/ports/BUILD5
-rw-r--r--pkg/tcpip/ports/flags.go150
-rw-r--r--pkg/tcpip/ports/ports.go603
-rw-r--r--pkg/tcpip/ports/ports_test.go26
4 files changed, 431 insertions, 353 deletions
diff --git a/pkg/tcpip/ports/BUILD b/pkg/tcpip/ports/BUILD
index 57abec5c9..210262703 100644
--- a/pkg/tcpip/ports/BUILD
+++ b/pkg/tcpip/ports/BUILD
@@ -4,7 +4,10 @@ package(licenses = ["notice"])
go_library(
name = "ports",
- srcs = ["ports.go"],
+ srcs = [
+ "flags.go",
+ "ports.go",
+ ],
visibility = ["//visibility:public"],
deps = [
"//pkg/sync",
diff --git a/pkg/tcpip/ports/flags.go b/pkg/tcpip/ports/flags.go
new file mode 100644
index 000000000..a8d7bff25
--- /dev/null
+++ b/pkg/tcpip/ports/flags.go
@@ -0,0 +1,150 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package ports
+
+// Flags represents the type of port reservation.
+//
+// +stateify savable
+type Flags struct {
+ // MostRecent represents UDP SO_REUSEADDR.
+ MostRecent bool
+
+ // LoadBalanced indicates SO_REUSEPORT.
+ //
+ // LoadBalanced takes precidence over MostRecent.
+ LoadBalanced bool
+
+ // TupleOnly represents TCP SO_REUSEADDR.
+ TupleOnly bool
+}
+
+// Bits converts the Flags to their bitset form.
+func (f Flags) Bits() BitFlags {
+ var rf BitFlags
+ if f.MostRecent {
+ rf |= MostRecentFlag
+ }
+ if f.LoadBalanced {
+ rf |= LoadBalancedFlag
+ }
+ if f.TupleOnly {
+ rf |= TupleOnlyFlag
+ }
+ return rf
+}
+
+// Effective returns the effective behavior of a flag config.
+func (f Flags) Effective() Flags {
+ e := f
+ if e.LoadBalanced && e.MostRecent {
+ e.MostRecent = false
+ }
+ return e
+}
+
+// BitFlags is a bitset representation of Flags.
+type BitFlags uint32
+
+const (
+ // MostRecentFlag represents Flags.MostRecent.
+ MostRecentFlag BitFlags = 1 << iota
+
+ // LoadBalancedFlag represents Flags.LoadBalanced.
+ LoadBalancedFlag
+
+ // TupleOnlyFlag represents Flags.TupleOnly.
+ TupleOnlyFlag
+
+ // nextFlag is the value that the next added flag will have.
+ //
+ // It is used to calculate FlagMask below. It is also the number of
+ // valid flag states.
+ nextFlag
+
+ // FlagMask is a bit mask for BitFlags.
+ FlagMask = nextFlag - 1
+
+ // MultiBindFlagMask contains the flags that allow binding the same
+ // tuple multiple times.
+ MultiBindFlagMask = MostRecentFlag | LoadBalancedFlag
+)
+
+// ToFlags converts the bitset into a Flags struct.
+func (f BitFlags) ToFlags() Flags {
+ return Flags{
+ MostRecent: f&MostRecentFlag != 0,
+ LoadBalanced: f&LoadBalancedFlag != 0,
+ TupleOnly: f&TupleOnlyFlag != 0,
+ }
+}
+
+// FlagCounter counts how many references each flag combination has.
+type FlagCounter struct {
+ // refs stores the count for each possible flag combination, (0 though
+ // FlagMask).
+ refs [nextFlag]int
+}
+
+// AddRef increases the reference count for a specific flag combination.
+func (c *FlagCounter) AddRef(flags BitFlags) {
+ c.refs[flags]++
+}
+
+// DropRef decreases the reference count for a specific flag combination.
+func (c *FlagCounter) DropRef(flags BitFlags) {
+ c.refs[flags]--
+}
+
+// TotalRefs calculates the total number of references for all flag
+// combinations.
+func (c FlagCounter) TotalRefs() int {
+ var total int
+ for _, r := range c.refs {
+ total += r
+ }
+ return total
+}
+
+// FlagRefs returns the number of references with all specified flags.
+func (c FlagCounter) FlagRefs(flags BitFlags) int {
+ var total int
+ for i, r := range c.refs {
+ if BitFlags(i)&flags == flags {
+ total += r
+ }
+ }
+ return total
+}
+
+// AllRefsHave returns if all references have all specified flags.
+func (c FlagCounter) AllRefsHave(flags BitFlags) bool {
+ for i, r := range c.refs {
+ if BitFlags(i)&flags != flags && r > 0 {
+ return false
+ }
+ }
+ return true
+}
+
+// SharedFlags returns the set of flags shared by all references.
+func (c FlagCounter) SharedFlags() BitFlags {
+ intersection := FlagMask
+ for i, r := range c.refs {
+ if r > 0 {
+ intersection &= BitFlags(i)
+ }
+ }
+ return intersection
+}
diff --git a/pkg/tcpip/ports/ports.go b/pkg/tcpip/ports/ports.go
index 101872b47..678199371 100644
--- a/pkg/tcpip/ports/ports.go
+++ b/pkg/tcpip/ports/ports.go
@@ -12,7 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-// Package ports provides PortManager that manages allocating, reserving and releasing ports.
+// Package ports provides PortManager that manages allocating, reserving and
+// releasing ports.
package ports
import (
@@ -25,165 +26,42 @@ import (
const anyIPAddress tcpip.Address = ""
-type portDescriptor struct {
- network tcpip.NetworkProtocolNumber
- transport tcpip.TransportProtocolNumber
- port uint16
-}
-
-// Flags represents the type of port reservation.
-//
-// +stateify savable
-type Flags struct {
- // MostRecent represents UDP SO_REUSEADDR.
- MostRecent bool
-
- // LoadBalanced indicates SO_REUSEPORT.
- //
- // LoadBalanced takes precidence over MostRecent.
- LoadBalanced bool
-
- // TupleOnly represents TCP SO_REUSEADDR.
- TupleOnly bool
-}
-
-// Bits converts the Flags to their bitset form.
-func (f Flags) Bits() BitFlags {
- var rf BitFlags
- if f.MostRecent {
- rf |= MostRecentFlag
- }
- if f.LoadBalanced {
- rf |= LoadBalancedFlag
- }
- if f.TupleOnly {
- rf |= TupleOnlyFlag
- }
- return rf
-}
-
-// Effective returns the effective behavior of a flag config.
-func (f Flags) Effective() Flags {
- e := f
- if e.LoadBalanced && e.MostRecent {
- e.MostRecent = false
- }
- return e
-}
-
-// PortManager manages allocating, reserving and releasing ports.
-type PortManager struct {
- // mu protects allocatedPorts.
- // LOCK ORDERING: mu > ephemeralMu.
- mu sync.RWMutex
- allocatedPorts map[portDescriptor]bindAddresses
-
- // ephemeralMu protects firstEphemeral and numEphemeral.
- ephemeralMu sync.RWMutex
- firstEphemeral uint16
- numEphemeral uint16
-
- // hint is used to pick ports ephemeral ports in a stable order for
- // a given port offset.
- //
- // hint must be accessed using the portHint/incPortHint helpers.
- // TODO(gvisor.dev/issue/940): S/R this field.
- hint uint32
-}
-
-// BitFlags is a bitset representation of Flags.
-type BitFlags uint32
+// Reservation describes a port reservation.
+type Reservation struct {
+ // Networks is a list of network protocols to which the reservation
+ // applies. Can be IPv4, IPv6, or both.
+ Networks []tcpip.NetworkProtocolNumber
-const (
- // MostRecentFlag represents Flags.MostRecent.
- MostRecentFlag BitFlags = 1 << iota
+ // Transport is the transport protocol to which the reservation applies.
+ Transport tcpip.TransportProtocolNumber
- // LoadBalancedFlag represents Flags.LoadBalanced.
- LoadBalancedFlag
+ // Addr is the address of the local endpoint.
+ Addr tcpip.Address
- // TupleOnlyFlag represents Flags.TupleOnly.
- TupleOnlyFlag
+ // Port is the local port number.
+ Port uint16
- // nextFlag is the value that the next added flag will have.
- //
- // It is used to calculate FlagMask below. It is also the number of
- // valid flag states.
- nextFlag
-
- // FlagMask is a bit mask for BitFlags.
- FlagMask = nextFlag - 1
-
- // MultiBindFlagMask contains the flags that allow binding the same
- // tuple multiple times.
- MultiBindFlagMask = MostRecentFlag | LoadBalancedFlag
-)
-
-// ToFlags converts the bitset into a Flags struct.
-func (f BitFlags) ToFlags() Flags {
- return Flags{
- MostRecent: f&MostRecentFlag != 0,
- LoadBalanced: f&LoadBalancedFlag != 0,
- TupleOnly: f&TupleOnlyFlag != 0,
- }
-}
+ // Flags describe features of the reservation.
+ Flags Flags
-// FlagCounter counts how many references each flag combination has.
-type FlagCounter struct {
- // refs stores the count for each possible flag combination, (0 though
- // FlagMask).
- refs [nextFlag]int
-}
-
-// AddRef increases the reference count for a specific flag combination.
-func (c *FlagCounter) AddRef(flags BitFlags) {
- c.refs[flags]++
-}
+ // BindToDevice is the NIC to which the reservation applies.
+ BindToDevice tcpip.NICID
-// DropRef decreases the reference count for a specific flag combination.
-func (c *FlagCounter) DropRef(flags BitFlags) {
- c.refs[flags]--
+ // Dest is the destination address.
+ Dest tcpip.FullAddress
}
-// TotalRefs calculates the total number of references for all flag
-// combinations.
-func (c FlagCounter) TotalRefs() int {
- var total int
- for _, r := range c.refs {
- total += r
- }
- return total
-}
-
-// FlagRefs returns the number of references with all specified flags.
-func (c FlagCounter) FlagRefs(flags BitFlags) int {
- var total int
- for i, r := range c.refs {
- if BitFlags(i)&flags == flags {
- total += r
- }
- }
- return total
-}
-
-// AllRefsHave returns if all references have all specified flags.
-func (c FlagCounter) AllRefsHave(flags BitFlags) bool {
- for i, r := range c.refs {
- if BitFlags(i)&flags != flags && r > 0 {
- return false
- }
+func (rs Reservation) dst() destination {
+ return destination{
+ rs.Dest.Addr,
+ rs.Dest.Port,
}
- return true
}
-// IntersectionRefs returns the set of flags shared by all references.
-func (c FlagCounter) IntersectionRefs() BitFlags {
- intersection := FlagMask
- for i, r := range c.refs {
- if r > 0 {
- intersection &= BitFlags(i)
- }
- }
- return intersection
+type portDescriptor struct {
+ network tcpip.NetworkProtocolNumber
+ transport tcpip.TransportProtocolNumber
+ port uint16
}
type destination struct {
@@ -191,18 +69,14 @@ type destination struct {
port uint16
}
-func makeDestination(a tcpip.FullAddress) destination {
- return destination{
- a.Addr,
- a.Port,
- }
-}
-
-// portNode is never empty. When it has no elements, it is removed from the
-// map that references it.
-type portNode map[destination]FlagCounter
+// destToCounter maps each destination to the FlagCounter that represents
+// endpoints to that destination.
+//
+// destToCounter is never empty. When it has no elements, it is removed from
+// the map that references it.
+type destToCounter map[destination]FlagCounter
-// intersectionRefs calculates the intersection of flag bit values which affect
+// intersectionFlags calculates the intersection of flag bit values which affect
// the specified destination.
//
// If no destinations are present, all flag values are returned as there are no
@@ -210,20 +84,20 @@ type portNode map[destination]FlagCounter
//
// In addition to the intersection, the number of intersecting refs is
// returned.
-func (p portNode) intersectionRefs(dst destination) (BitFlags, int) {
+func (dc destToCounter) intersectionFlags(res Reservation) (BitFlags, int) {
intersection := FlagMask
var count int
- for d, f := range p {
- if d == dst {
- intersection &= f.IntersectionRefs()
+ for dest, counter := range dc {
+ if dest == res.dst() {
+ intersection &= counter.SharedFlags()
count++
continue
}
// Wildcard destinations affect all destinations for TupleOnly.
- if d.addr == anyIPAddress || dst.addr == anyIPAddress {
+ if dest.addr == anyIPAddress || res.Dest.Addr == anyIPAddress {
// Only bitwise and the TupleOnlyFlag.
- intersection &= ((^TupleOnlyFlag) | f.IntersectionRefs())
+ intersection &= ((^TupleOnlyFlag) | counter.SharedFlags())
count++
}
}
@@ -231,27 +105,29 @@ func (p portNode) intersectionRefs(dst destination) (BitFlags, int) {
return intersection, count
}
-// deviceNode is never empty. When it has no elements, it is removed from the
+// deviceToDest maps NICs to destinations for which there are port reservations.
+//
+// deviceToDest is never empty. When it has no elements, it is removed from the
// map that references it.
-type deviceNode map[tcpip.NICID]portNode
+type deviceToDest map[tcpip.NICID]destToCounter
-// isAvailable checks whether binding is possible by device. If not binding to a
-// device, check against all FlagCounters. If binding to a specific device, check
-// against the unspecified device and the provided device.
+// isAvailable checks whether binding is possible by device. If not binding to
+// a device, check against all FlagCounters. If binding to a specific device,
+// check against the unspecified device and the provided device.
//
// If either of the port reuse flags is enabled on any of the nodes, all nodes
// sharing a port must share at least one reuse flag. This matches Linux's
// behavior.
-func (d deviceNode) isAvailable(flags Flags, bindToDevice tcpip.NICID, dst destination) bool {
- flagBits := flags.Bits()
- if bindToDevice == 0 {
+func (dd deviceToDest) isAvailable(res Reservation) bool {
+ flagBits := res.Flags.Bits()
+ if res.BindToDevice == 0 {
intersection := FlagMask
- for _, p := range d {
- i, c := p.intersectionRefs(dst)
- if c == 0 {
+ for _, dest := range dd {
+ flags, count := dest.intersectionFlags(res)
+ if count == 0 {
continue
}
- intersection &= i
+ intersection &= flags
if intersection&flagBits == 0 {
// Can't bind because the (addr,port) was
// previously bound without reuse.
@@ -263,18 +139,18 @@ func (d deviceNode) isAvailable(flags Flags, bindToDevice tcpip.NICID, dst desti
intersection := FlagMask
- if p, ok := d[0]; ok {
- var c int
- intersection, c = p.intersectionRefs(dst)
- if c > 0 && intersection&flagBits == 0 {
+ if dests, ok := dd[0]; ok {
+ var count int
+ intersection, count = dests.intersectionFlags(res)
+ if count > 0 && intersection&flagBits == 0 {
return false
}
}
- if p, ok := d[bindToDevice]; ok {
- i, c := p.intersectionRefs(dst)
- intersection &= i
- if c > 0 && intersection&flagBits == 0 {
+ if dests, ok := dd[res.BindToDevice]; ok {
+ flags, count := dests.intersectionFlags(res)
+ intersection &= flags
+ if count > 0 && intersection&flagBits == 0 {
return false
}
}
@@ -282,18 +158,18 @@ func (d deviceNode) isAvailable(flags Flags, bindToDevice tcpip.NICID, dst desti
return true
}
-// bindAddresses is a set of IP addresses.
-type bindAddresses map[tcpip.Address]deviceNode
+// addrToDevice maps IP addresses to NICs that have port reservations.
+type addrToDevice map[tcpip.Address]deviceToDest
// isAvailable checks whether an IP address is available to bind to. If the
// address is the "any" address, check all other addresses. Otherwise, just
// check against the "any" address and the provided address.
-func (b bindAddresses) isAvailable(addr tcpip.Address, flags Flags, bindToDevice tcpip.NICID, dst destination) bool {
- if addr == anyIPAddress {
- // If binding to the "any" address then check that there are no conflicts
- // with all addresses.
- for _, d := range b {
- if !d.isAvailable(flags, bindToDevice, dst) {
+func (ad addrToDevice) isAvailable(res Reservation) bool {
+ if res.Addr == anyIPAddress {
+ // If binding to the "any" address then check that there are no
+ // conflicts with all addresses.
+ for _, devices := range ad {
+ if !devices.isAvailable(res) {
return false
}
}
@@ -301,15 +177,15 @@ func (b bindAddresses) isAvailable(addr tcpip.Address, flags Flags, bindToDevice
}
// Check that there is no conflict with the "any" address.
- if d, ok := b[anyIPAddress]; ok {
- if !d.isAvailable(flags, bindToDevice, dst) {
+ if devices, ok := ad[anyIPAddress]; ok {
+ if !devices.isAvailable(res) {
return false
}
}
// Check that this is no conflict with the provided address.
- if d, ok := b[addr]; ok {
- if !d.isAvailable(flags, bindToDevice, dst) {
+ if devices, ok := ad[res.Addr]; ok {
+ if !devices.isAvailable(res) {
return false
}
}
@@ -317,10 +193,33 @@ func (b bindAddresses) isAvailable(addr tcpip.Address, flags Flags, bindToDevice
return true
}
+// PortManager manages allocating, reserving and releasing ports.
+type PortManager struct {
+ // mu protects allocatedPorts.
+ // LOCK ORDERING: mu > ephemeralMu.
+ mu sync.RWMutex
+ // allocatedPorts is a nesting of maps that ultimately map Reservations
+ // to FlagCounters describing whether the Reservation is valid and can
+ // be reused.
+ allocatedPorts map[portDescriptor]addrToDevice
+
+ // ephemeralMu protects firstEphemeral and numEphemeral.
+ ephemeralMu sync.RWMutex
+ firstEphemeral uint16
+ numEphemeral uint16
+
+ // hint is used to pick ports ephemeral ports in a stable order for
+ // a given port offset.
+ //
+ // hint must be accessed using the portHint/incPortHint helpers.
+ // TODO(gvisor.dev/issue/940): S/R this field.
+ hint uint32
+}
+
// NewPortManager creates new PortManager.
func NewPortManager() *PortManager {
return &PortManager{
- allocatedPorts: make(map[portDescriptor]bindAddresses),
+ allocatedPorts: make(map[portDescriptor]addrToDevice),
// Match Linux's default ephemeral range. See:
// https://github.com/torvalds/linux/blob/e54937963fa249595824439dc839c948188dea83/net/ipv4/af_inet.c#L1842
firstEphemeral: 32768,
@@ -328,53 +227,57 @@ func NewPortManager() *PortManager {
}
}
+// PortTester indicates whether the passed in port is suitable. Returning an
+// error causes the function to which the PortTester is passed to return that
+// error.
+type PortTester func(port uint16) (good bool, err tcpip.Error)
+
// PickEphemeralPort randomly chooses a starting point and iterates over all
// possible ephemeral ports, allowing the caller to decide whether a given port
// is suitable for its needs, and stopping when a port is found or an error
// occurs.
-func (s *PortManager) PickEphemeralPort(testPort func(p uint16) (bool, tcpip.Error)) (port uint16, err tcpip.Error) {
- s.ephemeralMu.RLock()
- firstEphemeral := s.firstEphemeral
- numEphemeral := s.numEphemeral
- s.ephemeralMu.RUnlock()
+func (pm *PortManager) PickEphemeralPort(testPort PortTester) (port uint16, err tcpip.Error) {
+ pm.ephemeralMu.RLock()
+ firstEphemeral := pm.firstEphemeral
+ numEphemeral := pm.numEphemeral
+ pm.ephemeralMu.RUnlock()
offset := uint16(rand.Int31n(int32(numEphemeral)))
return pickEphemeralPort(offset, firstEphemeral, numEphemeral, testPort)
}
-// portHint atomically reads and returns the s.hint value.
-func (s *PortManager) portHint() uint16 {
- return uint16(atomic.LoadUint32(&s.hint))
+// portHint atomically reads and returns the pm.hint value.
+func (pm *PortManager) portHint() uint16 {
+ return uint16(atomic.LoadUint32(&pm.hint))
}
-// incPortHint atomically increments s.hint by 1.
-func (s *PortManager) incPortHint() {
- atomic.AddUint32(&s.hint, 1)
+// incPortHint atomically increments pm.hint by 1.
+func (pm *PortManager) incPortHint() {
+ atomic.AddUint32(&pm.hint, 1)
}
-// PickEphemeralPortStable starts at the specified offset + s.portHint and
+// PickEphemeralPortStable starts at the specified offset + pm.portHint and
// iterates over all ephemeral ports, allowing the caller to decide whether a
// given port is suitable for its needs and stopping when a port is found or an
// error occurs.
-func (s *PortManager) PickEphemeralPortStable(offset uint16, testPort func(p uint16) (bool, tcpip.Error)) (port uint16, err tcpip.Error) {
- s.ephemeralMu.RLock()
- firstEphemeral := s.firstEphemeral
- numEphemeral := s.numEphemeral
- s.ephemeralMu.RUnlock()
+func (pm *PortManager) PickEphemeralPortStable(offset uint16, testPort PortTester) (port uint16, err tcpip.Error) {
+ pm.ephemeralMu.RLock()
+ firstEphemeral := pm.firstEphemeral
+ numEphemeral := pm.numEphemeral
+ pm.ephemeralMu.RUnlock()
- p, err := pickEphemeralPort(s.portHint()+offset, firstEphemeral, numEphemeral, testPort)
+ p, err := pickEphemeralPort(pm.portHint()+offset, firstEphemeral, numEphemeral, testPort)
if err == nil {
- s.incPortHint()
+ pm.incPortHint()
}
return p, err
-
}
// pickEphemeralPort starts at the offset specified from the FirstEphemeral port
// and iterates over the number of ports specified by count and allows the
// caller to decide whether a given port is suitable for its needs, and stopping
// when a port is found or an error occurs.
-func pickEphemeralPort(offset, first, count uint16, testPort func(p uint16) (bool, tcpip.Error)) (port uint16, err tcpip.Error) {
+func pickEphemeralPort(offset, first, count uint16, testPort PortTester) (port uint16, err tcpip.Error) {
for i := uint16(0); i < count; i++ {
port = first + (offset+i)%count
ok, err := testPort(port)
@@ -390,144 +293,145 @@ func pickEphemeralPort(offset, first, count uint16, testPort func(p uint16) (boo
return 0, &tcpip.ErrNoPortAvailable{}
}
-// IsPortAvailable tests if the given port is available on all given protocols.
-func (s *PortManager) IsPortAvailable(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, flags Flags, bindToDevice tcpip.NICID, dest tcpip.FullAddress) bool {
- s.mu.Lock()
- defer s.mu.Unlock()
- return s.isPortAvailableLocked(networks, transport, addr, port, flags, bindToDevice, makeDestination(dest))
-}
-
-func (s *PortManager) isPortAvailableLocked(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, flags Flags, bindToDevice tcpip.NICID, dst destination) bool {
- for _, network := range networks {
- desc := portDescriptor{network, transport, port}
- if addrs, ok := s.allocatedPorts[desc]; ok {
- if !addrs.isAvailable(addr, flags, bindToDevice, dst) {
- return false
- }
- }
- }
- return true
-}
-
// ReservePort marks a port/IP combination as reserved so that it cannot be
// reserved by another endpoint. If port is zero, ReservePort will search for
// an unreserved ephemeral port and reserve it, returning its value in the
// "port" return value.
//
-// An optional testPort closure can be passed in which if provided will be used
-// to test if the picked port can be used. The function should return true if
-// the port is safe to use, false otherwise.
-func (s *PortManager) ReservePort(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, flags Flags, bindToDevice tcpip.NICID, dest tcpip.FullAddress, testPort func(port uint16) bool) (reservedPort uint16, err tcpip.Error) {
- s.mu.Lock()
- defer s.mu.Unlock()
-
- dst := makeDestination(dest)
+// An optional PortTester can be passed in which if provided will be used to
+// test if the picked port can be used. The function should return true if the
+// port is safe to use, false otherwise.
+func (pm *PortManager) ReservePort(res Reservation, testPort PortTester) (reservedPort uint16, err tcpip.Error) {
+ pm.mu.Lock()
+ defer pm.mu.Unlock()
// If a port is specified, just try to reserve it for all network
// protocols.
- if port != 0 {
- if !s.reserveSpecificPort(networks, transport, addr, port, flags, bindToDevice, dst) {
+ if res.Port != 0 {
+ if !pm.reserveSpecificPortLocked(res) {
return 0, &tcpip.ErrPortInUse{}
}
- if testPort != nil && !testPort(port) {
- s.releasePortLocked(networks, transport, addr, port, flags.Bits(), bindToDevice, dst)
- return 0, &tcpip.ErrPortInUse{}
+ if testPort != nil {
+ ok, err := testPort(res.Port)
+ if err != nil {
+ pm.releasePortLocked(res)
+ return 0, err
+ }
+ if !ok {
+ pm.releasePortLocked(res)
+ return 0, &tcpip.ErrPortInUse{}
+ }
}
- return port, nil
+ return res.Port, nil
}
// A port wasn't specified, so try to find one.
- return s.PickEphemeralPort(func(p uint16) (bool, tcpip.Error) {
- if !s.reserveSpecificPort(networks, transport, addr, p, flags, bindToDevice, dst) {
+ return pm.PickEphemeralPort(func(p uint16) (bool, tcpip.Error) {
+ res.Port = p
+ if !pm.reserveSpecificPortLocked(res) {
return false, nil
}
- if testPort != nil && !testPort(p) {
- s.releasePortLocked(networks, transport, addr, p, flags.Bits(), bindToDevice, dst)
- return false, nil
+ if testPort != nil {
+ ok, err := testPort(p)
+ if err != nil {
+ pm.releasePortLocked(res)
+ return false, err
+ }
+ if !ok {
+ pm.releasePortLocked(res)
+ return false, nil
+ }
}
return true, nil
})
}
-// reserveSpecificPort tries to reserve the given port on all given protocols.
-func (s *PortManager) reserveSpecificPort(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, flags Flags, bindToDevice tcpip.NICID, dst destination) bool {
- if !s.isPortAvailableLocked(networks, transport, addr, port, flags, bindToDevice, dst) {
- return false
+// reserveSpecificPortLocked tries to reserve the given port on all given
+// protocols.
+func (pm *PortManager) reserveSpecificPortLocked(res Reservation) bool {
+ // Make sure the port is available.
+ for _, network := range res.Networks {
+ desc := portDescriptor{network, res.Transport, res.Port}
+ if addrs, ok := pm.allocatedPorts[desc]; ok {
+ if !addrs.isAvailable(res) {
+ return false
+ }
+ }
}
- flagBits := flags.Bits()
-
// Reserve port on all network protocols.
- for _, network := range networks {
- desc := portDescriptor{network, transport, port}
- m, ok := s.allocatedPorts[desc]
+ flagBits := res.Flags.Bits()
+ dst := res.dst()
+ for _, network := range res.Networks {
+ desc := portDescriptor{network, res.Transport, res.Port}
+ addrToDev, ok := pm.allocatedPorts[desc]
if !ok {
- m = make(bindAddresses)
- s.allocatedPorts[desc] = m
+ addrToDev = make(addrToDevice)
+ pm.allocatedPorts[desc] = addrToDev
}
- d, ok := m[addr]
+ devToDest, ok := addrToDev[res.Addr]
if !ok {
- d = make(deviceNode)
- m[addr] = d
+ devToDest = make(deviceToDest)
+ addrToDev[res.Addr] = devToDest
}
- p := d[bindToDevice]
- if p == nil {
- p = make(portNode)
+ destToCntr := devToDest[res.BindToDevice]
+ if destToCntr == nil {
+ destToCntr = make(destToCounter)
}
- n := p[dst]
- n.AddRef(flagBits)
- p[dst] = n
- d[bindToDevice] = p
+ counter := destToCntr[dst]
+ counter.AddRef(flagBits)
+ destToCntr[dst] = counter
+ devToDest[res.BindToDevice] = destToCntr
}
return true
}
// ReserveTuple adds a port reservation for the tuple on all given protocol.
-func (s *PortManager) ReserveTuple(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, flags Flags, bindToDevice tcpip.NICID, dest tcpip.FullAddress) bool {
- flagBits := flags.Bits()
- dst := makeDestination(dest)
+func (pm *PortManager) ReserveTuple(res Reservation) bool {
+ flagBits := res.Flags.Bits()
+ dst := res.dst()
- s.mu.Lock()
- defer s.mu.Unlock()
+ pm.mu.Lock()
+ defer pm.mu.Unlock()
// It is easier to undo the entire reservation, so if we find that the
// tuple can't be fully added, finish and undo the whole thing.
undo := false
// Reserve port on all network protocols.
- for _, network := range networks {
- desc := portDescriptor{network, transport, port}
- m, ok := s.allocatedPorts[desc]
+ for _, network := range res.Networks {
+ desc := portDescriptor{network, res.Transport, res.Port}
+ addrToDev, ok := pm.allocatedPorts[desc]
if !ok {
- m = make(bindAddresses)
- s.allocatedPorts[desc] = m
+ addrToDev = make(addrToDevice)
+ pm.allocatedPorts[desc] = addrToDev
}
- d, ok := m[addr]
+ devToDest, ok := addrToDev[res.Addr]
if !ok {
- d = make(deviceNode)
- m[addr] = d
+ devToDest = make(deviceToDest)
+ addrToDev[res.Addr] = devToDest
}
- p := d[bindToDevice]
- if p == nil {
- p = make(portNode)
+ destToCntr := devToDest[res.BindToDevice]
+ if destToCntr == nil {
+ destToCntr = make(destToCounter)
}
- n := p[dst]
- if n.TotalRefs() != 0 && n.IntersectionRefs()&flagBits == 0 {
+ counter := destToCntr[dst]
+ if counter.TotalRefs() != 0 && counter.SharedFlags()&flagBits == 0 {
// Tuple already exists.
undo = true
}
- n.AddRef(flagBits)
- p[dst] = n
- d[bindToDevice] = p
+ counter.AddRef(flagBits)
+ destToCntr[dst] = counter
+ devToDest[res.BindToDevice] = destToCntr
}
if undo {
// releasePortLocked decrements the counts (rather than setting
// them to zero), so it will undo the incorrect incrementing
// above.
- s.releasePortLocked(networks, transport, addr, port, flagBits, bindToDevice, dst)
+ pm.releasePortLocked(res)
return false
}
@@ -536,68 +440,71 @@ func (s *PortManager) ReserveTuple(networks []tcpip.NetworkProtocolNumber, trans
// ReleasePort releases the reservation on a port/IP combination so that it can
// be reserved by other endpoints.
-func (s *PortManager) ReleasePort(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, flags Flags, bindToDevice tcpip.NICID, dest tcpip.FullAddress) {
- s.mu.Lock()
- defer s.mu.Unlock()
+func (pm *PortManager) ReleasePort(res Reservation) {
+ pm.mu.Lock()
+ defer pm.mu.Unlock()
- s.releasePortLocked(networks, transport, addr, port, flags.Bits(), bindToDevice, makeDestination(dest))
+ pm.releasePortLocked(res)
}
-func (s *PortManager) releasePortLocked(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, flags BitFlags, bindToDevice tcpip.NICID, dst destination) {
- for _, network := range networks {
- desc := portDescriptor{network, transport, port}
- if m, ok := s.allocatedPorts[desc]; ok {
- d, ok := m[addr]
- if !ok {
- continue
- }
- p, ok := d[bindToDevice]
- if !ok {
- continue
- }
- n, ok := p[dst]
- if !ok {
- continue
- }
- n.DropRef(flags)
- if n.TotalRefs() > 0 {
- p[dst] = n
- continue
- }
- delete(p, dst)
- if len(p) > 0 {
- continue
- }
- delete(d, bindToDevice)
- if len(d) > 0 {
- continue
- }
- delete(m, addr)
- if len(m) > 0 {
- continue
- }
- delete(s.allocatedPorts, desc)
+func (pm *PortManager) releasePortLocked(res Reservation) {
+ dst := res.dst()
+ for _, network := range res.Networks {
+ desc := portDescriptor{network, res.Transport, res.Port}
+ addrToDev, ok := pm.allocatedPorts[desc]
+ if !ok {
+ continue
+ }
+ devToDest, ok := addrToDev[res.Addr]
+ if !ok {
+ continue
+ }
+ destToCounter, ok := devToDest[res.BindToDevice]
+ if !ok {
+ continue
+ }
+ counter, ok := destToCounter[dst]
+ if !ok {
+ continue
+ }
+ counter.DropRef(res.Flags.Bits())
+ if counter.TotalRefs() > 0 {
+ destToCounter[dst] = counter
+ continue
+ }
+ delete(destToCounter, dst)
+ if len(destToCounter) > 0 {
+ continue
+ }
+ delete(devToDest, res.BindToDevice)
+ if len(devToDest) > 0 {
+ continue
+ }
+ delete(addrToDev, res.Addr)
+ if len(addrToDev) > 0 {
+ continue
}
+ delete(pm.allocatedPorts, desc)
}
}
// PortRange returns the UDP and TCP inclusive range of ephemeral ports used in
// both IPv4 and IPv6.
-func (s *PortManager) PortRange() (uint16, uint16) {
- s.ephemeralMu.RLock()
- defer s.ephemeralMu.RUnlock()
- return s.firstEphemeral, s.firstEphemeral + s.numEphemeral - 1
+func (pm *PortManager) PortRange() (uint16, uint16) {
+ pm.ephemeralMu.RLock()
+ defer pm.ephemeralMu.RUnlock()
+ return pm.firstEphemeral, pm.firstEphemeral + pm.numEphemeral - 1
}
// SetPortRange sets the UDP and TCP IPv4 and IPv6 ephemeral port range
// (inclusive).
-func (s *PortManager) SetPortRange(start uint16, end uint16) tcpip.Error {
+func (pm *PortManager) SetPortRange(start uint16, end uint16) tcpip.Error {
if start > end {
return &tcpip.ErrInvalidPortRange{}
}
- s.ephemeralMu.Lock()
- defer s.ephemeralMu.Unlock()
- s.firstEphemeral = start
- s.numEphemeral = end - start + 1
+ pm.ephemeralMu.Lock()
+ defer pm.ephemeralMu.Unlock()
+ pm.firstEphemeral = start
+ pm.numEphemeral = end - start + 1
return nil
}
diff --git a/pkg/tcpip/ports/ports_test.go b/pkg/tcpip/ports/ports_test.go
index 6cfac04b1..0f43dc8f8 100644
--- a/pkg/tcpip/ports/ports_test.go
+++ b/pkg/tcpip/ports/ports_test.go
@@ -331,15 +331,33 @@ func TestPortReservation(t *testing.T) {
for _, test := range test.actions {
first, _ := pm.PortRange()
if test.release {
- pm.ReleasePort(net, fakeTransNumber, test.ip, test.port, test.flags, test.device, test.dest)
+ portRes := Reservation{
+ Networks: net,
+ Transport: fakeTransNumber,
+ Addr: test.ip,
+ Port: test.port,
+ Flags: test.flags,
+ BindToDevice: test.device,
+ Dest: test.dest,
+ }
+ pm.ReleasePort(portRes)
continue
}
- gotPort, err := pm.ReservePort(net, fakeTransNumber, test.ip, test.port, test.flags, test.device, test.dest, nil /* testPort */)
+ portRes := Reservation{
+ Networks: net,
+ Transport: fakeTransNumber,
+ Addr: test.ip,
+ Port: test.port,
+ Flags: test.flags,
+ BindToDevice: test.device,
+ Dest: test.dest,
+ }
+ gotPort, err := pm.ReservePort(portRes, nil /* testPort */)
if diff := cmp.Diff(test.want, err); diff != "" {
- t.Fatalf("unexpected error from ReservePort(.., .., %s, %d, %+v, %d, %v), (-want, +got):\n%s", test.ip, test.port, test.flags, test.device, test.dest, diff)
+ t.Fatalf("unexpected error from ReservePort(%+v, _), (-want, +got):\n%s", portRes, diff)
}
if test.port == 0 && (gotPort == 0 || gotPort < first) {
- t.Fatalf("ReservePort(.., .., .., 0, ..) = %d, want port number >= %d to be picked", gotPort, first)
+ t.Fatalf("ReservePort(%+v, _) = %d, want port number >= %d to be picked", portRes, gotPort, first)
}
}
})