diff options
author | Kevin Krakauer <krakauer@google.com> | 2021-03-11 21:03:54 -0800 |
---|---|---|
committer | gVisor bot <gvisor-bot@google.com> | 2021-03-11 21:05:32 -0800 |
commit | 82d7fb2cb0cf37ceeb44de665cde1ac7d72230f1 (patch) | |
tree | f88632143c3ab9c2f0067e9d4de8c6be47ea1f09 /pkg/tcpip/ports | |
parent | 192318a2316d84a3de9d28c29fbc73aae3e75206 (diff) |
improve readability of ports package
Lots of small changes:
- simplify package API via Reservation type
- rename some single-letter variable names that were hard to follow
- rename some types
PiperOrigin-RevId: 362442366
Diffstat (limited to 'pkg/tcpip/ports')
-rw-r--r-- | pkg/tcpip/ports/BUILD | 5 | ||||
-rw-r--r-- | pkg/tcpip/ports/flags.go | 150 | ||||
-rw-r--r-- | pkg/tcpip/ports/ports.go | 603 | ||||
-rw-r--r-- | pkg/tcpip/ports/ports_test.go | 26 |
4 files changed, 431 insertions, 353 deletions
diff --git a/pkg/tcpip/ports/BUILD b/pkg/tcpip/ports/BUILD index 57abec5c9..210262703 100644 --- a/pkg/tcpip/ports/BUILD +++ b/pkg/tcpip/ports/BUILD @@ -4,7 +4,10 @@ package(licenses = ["notice"]) go_library( name = "ports", - srcs = ["ports.go"], + srcs = [ + "flags.go", + "ports.go", + ], visibility = ["//visibility:public"], deps = [ "//pkg/sync", diff --git a/pkg/tcpip/ports/flags.go b/pkg/tcpip/ports/flags.go new file mode 100644 index 000000000..a8d7bff25 --- /dev/null +++ b/pkg/tcpip/ports/flags.go @@ -0,0 +1,150 @@ +// Copyright 2021 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package ports + +// Flags represents the type of port reservation. +// +// +stateify savable +type Flags struct { + // MostRecent represents UDP SO_REUSEADDR. + MostRecent bool + + // LoadBalanced indicates SO_REUSEPORT. + // + // LoadBalanced takes precidence over MostRecent. + LoadBalanced bool + + // TupleOnly represents TCP SO_REUSEADDR. + TupleOnly bool +} + +// Bits converts the Flags to their bitset form. +func (f Flags) Bits() BitFlags { + var rf BitFlags + if f.MostRecent { + rf |= MostRecentFlag + } + if f.LoadBalanced { + rf |= LoadBalancedFlag + } + if f.TupleOnly { + rf |= TupleOnlyFlag + } + return rf +} + +// Effective returns the effective behavior of a flag config. +func (f Flags) Effective() Flags { + e := f + if e.LoadBalanced && e.MostRecent { + e.MostRecent = false + } + return e +} + +// BitFlags is a bitset representation of Flags. +type BitFlags uint32 + +const ( + // MostRecentFlag represents Flags.MostRecent. + MostRecentFlag BitFlags = 1 << iota + + // LoadBalancedFlag represents Flags.LoadBalanced. + LoadBalancedFlag + + // TupleOnlyFlag represents Flags.TupleOnly. + TupleOnlyFlag + + // nextFlag is the value that the next added flag will have. + // + // It is used to calculate FlagMask below. It is also the number of + // valid flag states. + nextFlag + + // FlagMask is a bit mask for BitFlags. + FlagMask = nextFlag - 1 + + // MultiBindFlagMask contains the flags that allow binding the same + // tuple multiple times. + MultiBindFlagMask = MostRecentFlag | LoadBalancedFlag +) + +// ToFlags converts the bitset into a Flags struct. +func (f BitFlags) ToFlags() Flags { + return Flags{ + MostRecent: f&MostRecentFlag != 0, + LoadBalanced: f&LoadBalancedFlag != 0, + TupleOnly: f&TupleOnlyFlag != 0, + } +} + +// FlagCounter counts how many references each flag combination has. +type FlagCounter struct { + // refs stores the count for each possible flag combination, (0 though + // FlagMask). + refs [nextFlag]int +} + +// AddRef increases the reference count for a specific flag combination. +func (c *FlagCounter) AddRef(flags BitFlags) { + c.refs[flags]++ +} + +// DropRef decreases the reference count for a specific flag combination. +func (c *FlagCounter) DropRef(flags BitFlags) { + c.refs[flags]-- +} + +// TotalRefs calculates the total number of references for all flag +// combinations. +func (c FlagCounter) TotalRefs() int { + var total int + for _, r := range c.refs { + total += r + } + return total +} + +// FlagRefs returns the number of references with all specified flags. +func (c FlagCounter) FlagRefs(flags BitFlags) int { + var total int + for i, r := range c.refs { + if BitFlags(i)&flags == flags { + total += r + } + } + return total +} + +// AllRefsHave returns if all references have all specified flags. +func (c FlagCounter) AllRefsHave(flags BitFlags) bool { + for i, r := range c.refs { + if BitFlags(i)&flags != flags && r > 0 { + return false + } + } + return true +} + +// SharedFlags returns the set of flags shared by all references. +func (c FlagCounter) SharedFlags() BitFlags { + intersection := FlagMask + for i, r := range c.refs { + if r > 0 { + intersection &= BitFlags(i) + } + } + return intersection +} diff --git a/pkg/tcpip/ports/ports.go b/pkg/tcpip/ports/ports.go index 101872b47..678199371 100644 --- a/pkg/tcpip/ports/ports.go +++ b/pkg/tcpip/ports/ports.go @@ -12,7 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. -// Package ports provides PortManager that manages allocating, reserving and releasing ports. +// Package ports provides PortManager that manages allocating, reserving and +// releasing ports. package ports import ( @@ -25,165 +26,42 @@ import ( const anyIPAddress tcpip.Address = "" -type portDescriptor struct { - network tcpip.NetworkProtocolNumber - transport tcpip.TransportProtocolNumber - port uint16 -} - -// Flags represents the type of port reservation. -// -// +stateify savable -type Flags struct { - // MostRecent represents UDP SO_REUSEADDR. - MostRecent bool - - // LoadBalanced indicates SO_REUSEPORT. - // - // LoadBalanced takes precidence over MostRecent. - LoadBalanced bool - - // TupleOnly represents TCP SO_REUSEADDR. - TupleOnly bool -} - -// Bits converts the Flags to their bitset form. -func (f Flags) Bits() BitFlags { - var rf BitFlags - if f.MostRecent { - rf |= MostRecentFlag - } - if f.LoadBalanced { - rf |= LoadBalancedFlag - } - if f.TupleOnly { - rf |= TupleOnlyFlag - } - return rf -} - -// Effective returns the effective behavior of a flag config. -func (f Flags) Effective() Flags { - e := f - if e.LoadBalanced && e.MostRecent { - e.MostRecent = false - } - return e -} - -// PortManager manages allocating, reserving and releasing ports. -type PortManager struct { - // mu protects allocatedPorts. - // LOCK ORDERING: mu > ephemeralMu. - mu sync.RWMutex - allocatedPorts map[portDescriptor]bindAddresses - - // ephemeralMu protects firstEphemeral and numEphemeral. - ephemeralMu sync.RWMutex - firstEphemeral uint16 - numEphemeral uint16 - - // hint is used to pick ports ephemeral ports in a stable order for - // a given port offset. - // - // hint must be accessed using the portHint/incPortHint helpers. - // TODO(gvisor.dev/issue/940): S/R this field. - hint uint32 -} - -// BitFlags is a bitset representation of Flags. -type BitFlags uint32 +// Reservation describes a port reservation. +type Reservation struct { + // Networks is a list of network protocols to which the reservation + // applies. Can be IPv4, IPv6, or both. + Networks []tcpip.NetworkProtocolNumber -const ( - // MostRecentFlag represents Flags.MostRecent. - MostRecentFlag BitFlags = 1 << iota + // Transport is the transport protocol to which the reservation applies. + Transport tcpip.TransportProtocolNumber - // LoadBalancedFlag represents Flags.LoadBalanced. - LoadBalancedFlag + // Addr is the address of the local endpoint. + Addr tcpip.Address - // TupleOnlyFlag represents Flags.TupleOnly. - TupleOnlyFlag + // Port is the local port number. + Port uint16 - // nextFlag is the value that the next added flag will have. - // - // It is used to calculate FlagMask below. It is also the number of - // valid flag states. - nextFlag - - // FlagMask is a bit mask for BitFlags. - FlagMask = nextFlag - 1 - - // MultiBindFlagMask contains the flags that allow binding the same - // tuple multiple times. - MultiBindFlagMask = MostRecentFlag | LoadBalancedFlag -) - -// ToFlags converts the bitset into a Flags struct. -func (f BitFlags) ToFlags() Flags { - return Flags{ - MostRecent: f&MostRecentFlag != 0, - LoadBalanced: f&LoadBalancedFlag != 0, - TupleOnly: f&TupleOnlyFlag != 0, - } -} + // Flags describe features of the reservation. + Flags Flags -// FlagCounter counts how many references each flag combination has. -type FlagCounter struct { - // refs stores the count for each possible flag combination, (0 though - // FlagMask). - refs [nextFlag]int -} - -// AddRef increases the reference count for a specific flag combination. -func (c *FlagCounter) AddRef(flags BitFlags) { - c.refs[flags]++ -} + // BindToDevice is the NIC to which the reservation applies. + BindToDevice tcpip.NICID -// DropRef decreases the reference count for a specific flag combination. -func (c *FlagCounter) DropRef(flags BitFlags) { - c.refs[flags]-- + // Dest is the destination address. + Dest tcpip.FullAddress } -// TotalRefs calculates the total number of references for all flag -// combinations. -func (c FlagCounter) TotalRefs() int { - var total int - for _, r := range c.refs { - total += r - } - return total -} - -// FlagRefs returns the number of references with all specified flags. -func (c FlagCounter) FlagRefs(flags BitFlags) int { - var total int - for i, r := range c.refs { - if BitFlags(i)&flags == flags { - total += r - } - } - return total -} - -// AllRefsHave returns if all references have all specified flags. -func (c FlagCounter) AllRefsHave(flags BitFlags) bool { - for i, r := range c.refs { - if BitFlags(i)&flags != flags && r > 0 { - return false - } +func (rs Reservation) dst() destination { + return destination{ + rs.Dest.Addr, + rs.Dest.Port, } - return true } -// IntersectionRefs returns the set of flags shared by all references. -func (c FlagCounter) IntersectionRefs() BitFlags { - intersection := FlagMask - for i, r := range c.refs { - if r > 0 { - intersection &= BitFlags(i) - } - } - return intersection +type portDescriptor struct { + network tcpip.NetworkProtocolNumber + transport tcpip.TransportProtocolNumber + port uint16 } type destination struct { @@ -191,18 +69,14 @@ type destination struct { port uint16 } -func makeDestination(a tcpip.FullAddress) destination { - return destination{ - a.Addr, - a.Port, - } -} - -// portNode is never empty. When it has no elements, it is removed from the -// map that references it. -type portNode map[destination]FlagCounter +// destToCounter maps each destination to the FlagCounter that represents +// endpoints to that destination. +// +// destToCounter is never empty. When it has no elements, it is removed from +// the map that references it. +type destToCounter map[destination]FlagCounter -// intersectionRefs calculates the intersection of flag bit values which affect +// intersectionFlags calculates the intersection of flag bit values which affect // the specified destination. // // If no destinations are present, all flag values are returned as there are no @@ -210,20 +84,20 @@ type portNode map[destination]FlagCounter // // In addition to the intersection, the number of intersecting refs is // returned. -func (p portNode) intersectionRefs(dst destination) (BitFlags, int) { +func (dc destToCounter) intersectionFlags(res Reservation) (BitFlags, int) { intersection := FlagMask var count int - for d, f := range p { - if d == dst { - intersection &= f.IntersectionRefs() + for dest, counter := range dc { + if dest == res.dst() { + intersection &= counter.SharedFlags() count++ continue } // Wildcard destinations affect all destinations for TupleOnly. - if d.addr == anyIPAddress || dst.addr == anyIPAddress { + if dest.addr == anyIPAddress || res.Dest.Addr == anyIPAddress { // Only bitwise and the TupleOnlyFlag. - intersection &= ((^TupleOnlyFlag) | f.IntersectionRefs()) + intersection &= ((^TupleOnlyFlag) | counter.SharedFlags()) count++ } } @@ -231,27 +105,29 @@ func (p portNode) intersectionRefs(dst destination) (BitFlags, int) { return intersection, count } -// deviceNode is never empty. When it has no elements, it is removed from the +// deviceToDest maps NICs to destinations for which there are port reservations. +// +// deviceToDest is never empty. When it has no elements, it is removed from the // map that references it. -type deviceNode map[tcpip.NICID]portNode +type deviceToDest map[tcpip.NICID]destToCounter -// isAvailable checks whether binding is possible by device. If not binding to a -// device, check against all FlagCounters. If binding to a specific device, check -// against the unspecified device and the provided device. +// isAvailable checks whether binding is possible by device. If not binding to +// a device, check against all FlagCounters. If binding to a specific device, +// check against the unspecified device and the provided device. // // If either of the port reuse flags is enabled on any of the nodes, all nodes // sharing a port must share at least one reuse flag. This matches Linux's // behavior. -func (d deviceNode) isAvailable(flags Flags, bindToDevice tcpip.NICID, dst destination) bool { - flagBits := flags.Bits() - if bindToDevice == 0 { +func (dd deviceToDest) isAvailable(res Reservation) bool { + flagBits := res.Flags.Bits() + if res.BindToDevice == 0 { intersection := FlagMask - for _, p := range d { - i, c := p.intersectionRefs(dst) - if c == 0 { + for _, dest := range dd { + flags, count := dest.intersectionFlags(res) + if count == 0 { continue } - intersection &= i + intersection &= flags if intersection&flagBits == 0 { // Can't bind because the (addr,port) was // previously bound without reuse. @@ -263,18 +139,18 @@ func (d deviceNode) isAvailable(flags Flags, bindToDevice tcpip.NICID, dst desti intersection := FlagMask - if p, ok := d[0]; ok { - var c int - intersection, c = p.intersectionRefs(dst) - if c > 0 && intersection&flagBits == 0 { + if dests, ok := dd[0]; ok { + var count int + intersection, count = dests.intersectionFlags(res) + if count > 0 && intersection&flagBits == 0 { return false } } - if p, ok := d[bindToDevice]; ok { - i, c := p.intersectionRefs(dst) - intersection &= i - if c > 0 && intersection&flagBits == 0 { + if dests, ok := dd[res.BindToDevice]; ok { + flags, count := dests.intersectionFlags(res) + intersection &= flags + if count > 0 && intersection&flagBits == 0 { return false } } @@ -282,18 +158,18 @@ func (d deviceNode) isAvailable(flags Flags, bindToDevice tcpip.NICID, dst desti return true } -// bindAddresses is a set of IP addresses. -type bindAddresses map[tcpip.Address]deviceNode +// addrToDevice maps IP addresses to NICs that have port reservations. +type addrToDevice map[tcpip.Address]deviceToDest // isAvailable checks whether an IP address is available to bind to. If the // address is the "any" address, check all other addresses. Otherwise, just // check against the "any" address and the provided address. -func (b bindAddresses) isAvailable(addr tcpip.Address, flags Flags, bindToDevice tcpip.NICID, dst destination) bool { - if addr == anyIPAddress { - // If binding to the "any" address then check that there are no conflicts - // with all addresses. - for _, d := range b { - if !d.isAvailable(flags, bindToDevice, dst) { +func (ad addrToDevice) isAvailable(res Reservation) bool { + if res.Addr == anyIPAddress { + // If binding to the "any" address then check that there are no + // conflicts with all addresses. + for _, devices := range ad { + if !devices.isAvailable(res) { return false } } @@ -301,15 +177,15 @@ func (b bindAddresses) isAvailable(addr tcpip.Address, flags Flags, bindToDevice } // Check that there is no conflict with the "any" address. - if d, ok := b[anyIPAddress]; ok { - if !d.isAvailable(flags, bindToDevice, dst) { + if devices, ok := ad[anyIPAddress]; ok { + if !devices.isAvailable(res) { return false } } // Check that this is no conflict with the provided address. - if d, ok := b[addr]; ok { - if !d.isAvailable(flags, bindToDevice, dst) { + if devices, ok := ad[res.Addr]; ok { + if !devices.isAvailable(res) { return false } } @@ -317,10 +193,33 @@ func (b bindAddresses) isAvailable(addr tcpip.Address, flags Flags, bindToDevice return true } +// PortManager manages allocating, reserving and releasing ports. +type PortManager struct { + // mu protects allocatedPorts. + // LOCK ORDERING: mu > ephemeralMu. + mu sync.RWMutex + // allocatedPorts is a nesting of maps that ultimately map Reservations + // to FlagCounters describing whether the Reservation is valid and can + // be reused. + allocatedPorts map[portDescriptor]addrToDevice + + // ephemeralMu protects firstEphemeral and numEphemeral. + ephemeralMu sync.RWMutex + firstEphemeral uint16 + numEphemeral uint16 + + // hint is used to pick ports ephemeral ports in a stable order for + // a given port offset. + // + // hint must be accessed using the portHint/incPortHint helpers. + // TODO(gvisor.dev/issue/940): S/R this field. + hint uint32 +} + // NewPortManager creates new PortManager. func NewPortManager() *PortManager { return &PortManager{ - allocatedPorts: make(map[portDescriptor]bindAddresses), + allocatedPorts: make(map[portDescriptor]addrToDevice), // Match Linux's default ephemeral range. See: // https://github.com/torvalds/linux/blob/e54937963fa249595824439dc839c948188dea83/net/ipv4/af_inet.c#L1842 firstEphemeral: 32768, @@ -328,53 +227,57 @@ func NewPortManager() *PortManager { } } +// PortTester indicates whether the passed in port is suitable. Returning an +// error causes the function to which the PortTester is passed to return that +// error. +type PortTester func(port uint16) (good bool, err tcpip.Error) + // PickEphemeralPort randomly chooses a starting point and iterates over all // possible ephemeral ports, allowing the caller to decide whether a given port // is suitable for its needs, and stopping when a port is found or an error // occurs. -func (s *PortManager) PickEphemeralPort(testPort func(p uint16) (bool, tcpip.Error)) (port uint16, err tcpip.Error) { - s.ephemeralMu.RLock() - firstEphemeral := s.firstEphemeral - numEphemeral := s.numEphemeral - s.ephemeralMu.RUnlock() +func (pm *PortManager) PickEphemeralPort(testPort PortTester) (port uint16, err tcpip.Error) { + pm.ephemeralMu.RLock() + firstEphemeral := pm.firstEphemeral + numEphemeral := pm.numEphemeral + pm.ephemeralMu.RUnlock() offset := uint16(rand.Int31n(int32(numEphemeral))) return pickEphemeralPort(offset, firstEphemeral, numEphemeral, testPort) } -// portHint atomically reads and returns the s.hint value. -func (s *PortManager) portHint() uint16 { - return uint16(atomic.LoadUint32(&s.hint)) +// portHint atomically reads and returns the pm.hint value. +func (pm *PortManager) portHint() uint16 { + return uint16(atomic.LoadUint32(&pm.hint)) } -// incPortHint atomically increments s.hint by 1. -func (s *PortManager) incPortHint() { - atomic.AddUint32(&s.hint, 1) +// incPortHint atomically increments pm.hint by 1. +func (pm *PortManager) incPortHint() { + atomic.AddUint32(&pm.hint, 1) } -// PickEphemeralPortStable starts at the specified offset + s.portHint and +// PickEphemeralPortStable starts at the specified offset + pm.portHint and // iterates over all ephemeral ports, allowing the caller to decide whether a // given port is suitable for its needs and stopping when a port is found or an // error occurs. -func (s *PortManager) PickEphemeralPortStable(offset uint16, testPort func(p uint16) (bool, tcpip.Error)) (port uint16, err tcpip.Error) { - s.ephemeralMu.RLock() - firstEphemeral := s.firstEphemeral - numEphemeral := s.numEphemeral - s.ephemeralMu.RUnlock() +func (pm *PortManager) PickEphemeralPortStable(offset uint16, testPort PortTester) (port uint16, err tcpip.Error) { + pm.ephemeralMu.RLock() + firstEphemeral := pm.firstEphemeral + numEphemeral := pm.numEphemeral + pm.ephemeralMu.RUnlock() - p, err := pickEphemeralPort(s.portHint()+offset, firstEphemeral, numEphemeral, testPort) + p, err := pickEphemeralPort(pm.portHint()+offset, firstEphemeral, numEphemeral, testPort) if err == nil { - s.incPortHint() + pm.incPortHint() } return p, err - } // pickEphemeralPort starts at the offset specified from the FirstEphemeral port // and iterates over the number of ports specified by count and allows the // caller to decide whether a given port is suitable for its needs, and stopping // when a port is found or an error occurs. -func pickEphemeralPort(offset, first, count uint16, testPort func(p uint16) (bool, tcpip.Error)) (port uint16, err tcpip.Error) { +func pickEphemeralPort(offset, first, count uint16, testPort PortTester) (port uint16, err tcpip.Error) { for i := uint16(0); i < count; i++ { port = first + (offset+i)%count ok, err := testPort(port) @@ -390,144 +293,145 @@ func pickEphemeralPort(offset, first, count uint16, testPort func(p uint16) (boo return 0, &tcpip.ErrNoPortAvailable{} } -// IsPortAvailable tests if the given port is available on all given protocols. -func (s *PortManager) IsPortAvailable(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, flags Flags, bindToDevice tcpip.NICID, dest tcpip.FullAddress) bool { - s.mu.Lock() - defer s.mu.Unlock() - return s.isPortAvailableLocked(networks, transport, addr, port, flags, bindToDevice, makeDestination(dest)) -} - -func (s *PortManager) isPortAvailableLocked(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, flags Flags, bindToDevice tcpip.NICID, dst destination) bool { - for _, network := range networks { - desc := portDescriptor{network, transport, port} - if addrs, ok := s.allocatedPorts[desc]; ok { - if !addrs.isAvailable(addr, flags, bindToDevice, dst) { - return false - } - } - } - return true -} - // ReservePort marks a port/IP combination as reserved so that it cannot be // reserved by another endpoint. If port is zero, ReservePort will search for // an unreserved ephemeral port and reserve it, returning its value in the // "port" return value. // -// An optional testPort closure can be passed in which if provided will be used -// to test if the picked port can be used. The function should return true if -// the port is safe to use, false otherwise. -func (s *PortManager) ReservePort(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, flags Flags, bindToDevice tcpip.NICID, dest tcpip.FullAddress, testPort func(port uint16) bool) (reservedPort uint16, err tcpip.Error) { - s.mu.Lock() - defer s.mu.Unlock() - - dst := makeDestination(dest) +// An optional PortTester can be passed in which if provided will be used to +// test if the picked port can be used. The function should return true if the +// port is safe to use, false otherwise. +func (pm *PortManager) ReservePort(res Reservation, testPort PortTester) (reservedPort uint16, err tcpip.Error) { + pm.mu.Lock() + defer pm.mu.Unlock() // If a port is specified, just try to reserve it for all network // protocols. - if port != 0 { - if !s.reserveSpecificPort(networks, transport, addr, port, flags, bindToDevice, dst) { + if res.Port != 0 { + if !pm.reserveSpecificPortLocked(res) { return 0, &tcpip.ErrPortInUse{} } - if testPort != nil && !testPort(port) { - s.releasePortLocked(networks, transport, addr, port, flags.Bits(), bindToDevice, dst) - return 0, &tcpip.ErrPortInUse{} + if testPort != nil { + ok, err := testPort(res.Port) + if err != nil { + pm.releasePortLocked(res) + return 0, err + } + if !ok { + pm.releasePortLocked(res) + return 0, &tcpip.ErrPortInUse{} + } } - return port, nil + return res.Port, nil } // A port wasn't specified, so try to find one. - return s.PickEphemeralPort(func(p uint16) (bool, tcpip.Error) { - if !s.reserveSpecificPort(networks, transport, addr, p, flags, bindToDevice, dst) { + return pm.PickEphemeralPort(func(p uint16) (bool, tcpip.Error) { + res.Port = p + if !pm.reserveSpecificPortLocked(res) { return false, nil } - if testPort != nil && !testPort(p) { - s.releasePortLocked(networks, transport, addr, p, flags.Bits(), bindToDevice, dst) - return false, nil + if testPort != nil { + ok, err := testPort(p) + if err != nil { + pm.releasePortLocked(res) + return false, err + } + if !ok { + pm.releasePortLocked(res) + return false, nil + } } return true, nil }) } -// reserveSpecificPort tries to reserve the given port on all given protocols. -func (s *PortManager) reserveSpecificPort(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, flags Flags, bindToDevice tcpip.NICID, dst destination) bool { - if !s.isPortAvailableLocked(networks, transport, addr, port, flags, bindToDevice, dst) { - return false +// reserveSpecificPortLocked tries to reserve the given port on all given +// protocols. +func (pm *PortManager) reserveSpecificPortLocked(res Reservation) bool { + // Make sure the port is available. + for _, network := range res.Networks { + desc := portDescriptor{network, res.Transport, res.Port} + if addrs, ok := pm.allocatedPorts[desc]; ok { + if !addrs.isAvailable(res) { + return false + } + } } - flagBits := flags.Bits() - // Reserve port on all network protocols. - for _, network := range networks { - desc := portDescriptor{network, transport, port} - m, ok := s.allocatedPorts[desc] + flagBits := res.Flags.Bits() + dst := res.dst() + for _, network := range res.Networks { + desc := portDescriptor{network, res.Transport, res.Port} + addrToDev, ok := pm.allocatedPorts[desc] if !ok { - m = make(bindAddresses) - s.allocatedPorts[desc] = m + addrToDev = make(addrToDevice) + pm.allocatedPorts[desc] = addrToDev } - d, ok := m[addr] + devToDest, ok := addrToDev[res.Addr] if !ok { - d = make(deviceNode) - m[addr] = d + devToDest = make(deviceToDest) + addrToDev[res.Addr] = devToDest } - p := d[bindToDevice] - if p == nil { - p = make(portNode) + destToCntr := devToDest[res.BindToDevice] + if destToCntr == nil { + destToCntr = make(destToCounter) } - n := p[dst] - n.AddRef(flagBits) - p[dst] = n - d[bindToDevice] = p + counter := destToCntr[dst] + counter.AddRef(flagBits) + destToCntr[dst] = counter + devToDest[res.BindToDevice] = destToCntr } return true } // ReserveTuple adds a port reservation for the tuple on all given protocol. -func (s *PortManager) ReserveTuple(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, flags Flags, bindToDevice tcpip.NICID, dest tcpip.FullAddress) bool { - flagBits := flags.Bits() - dst := makeDestination(dest) +func (pm *PortManager) ReserveTuple(res Reservation) bool { + flagBits := res.Flags.Bits() + dst := res.dst() - s.mu.Lock() - defer s.mu.Unlock() + pm.mu.Lock() + defer pm.mu.Unlock() // It is easier to undo the entire reservation, so if we find that the // tuple can't be fully added, finish and undo the whole thing. undo := false // Reserve port on all network protocols. - for _, network := range networks { - desc := portDescriptor{network, transport, port} - m, ok := s.allocatedPorts[desc] + for _, network := range res.Networks { + desc := portDescriptor{network, res.Transport, res.Port} + addrToDev, ok := pm.allocatedPorts[desc] if !ok { - m = make(bindAddresses) - s.allocatedPorts[desc] = m + addrToDev = make(addrToDevice) + pm.allocatedPorts[desc] = addrToDev } - d, ok := m[addr] + devToDest, ok := addrToDev[res.Addr] if !ok { - d = make(deviceNode) - m[addr] = d + devToDest = make(deviceToDest) + addrToDev[res.Addr] = devToDest } - p := d[bindToDevice] - if p == nil { - p = make(portNode) + destToCntr := devToDest[res.BindToDevice] + if destToCntr == nil { + destToCntr = make(destToCounter) } - n := p[dst] - if n.TotalRefs() != 0 && n.IntersectionRefs()&flagBits == 0 { + counter := destToCntr[dst] + if counter.TotalRefs() != 0 && counter.SharedFlags()&flagBits == 0 { // Tuple already exists. undo = true } - n.AddRef(flagBits) - p[dst] = n - d[bindToDevice] = p + counter.AddRef(flagBits) + destToCntr[dst] = counter + devToDest[res.BindToDevice] = destToCntr } if undo { // releasePortLocked decrements the counts (rather than setting // them to zero), so it will undo the incorrect incrementing // above. - s.releasePortLocked(networks, transport, addr, port, flagBits, bindToDevice, dst) + pm.releasePortLocked(res) return false } @@ -536,68 +440,71 @@ func (s *PortManager) ReserveTuple(networks []tcpip.NetworkProtocolNumber, trans // ReleasePort releases the reservation on a port/IP combination so that it can // be reserved by other endpoints. -func (s *PortManager) ReleasePort(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, flags Flags, bindToDevice tcpip.NICID, dest tcpip.FullAddress) { - s.mu.Lock() - defer s.mu.Unlock() +func (pm *PortManager) ReleasePort(res Reservation) { + pm.mu.Lock() + defer pm.mu.Unlock() - s.releasePortLocked(networks, transport, addr, port, flags.Bits(), bindToDevice, makeDestination(dest)) + pm.releasePortLocked(res) } -func (s *PortManager) releasePortLocked(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, flags BitFlags, bindToDevice tcpip.NICID, dst destination) { - for _, network := range networks { - desc := portDescriptor{network, transport, port} - if m, ok := s.allocatedPorts[desc]; ok { - d, ok := m[addr] - if !ok { - continue - } - p, ok := d[bindToDevice] - if !ok { - continue - } - n, ok := p[dst] - if !ok { - continue - } - n.DropRef(flags) - if n.TotalRefs() > 0 { - p[dst] = n - continue - } - delete(p, dst) - if len(p) > 0 { - continue - } - delete(d, bindToDevice) - if len(d) > 0 { - continue - } - delete(m, addr) - if len(m) > 0 { - continue - } - delete(s.allocatedPorts, desc) +func (pm *PortManager) releasePortLocked(res Reservation) { + dst := res.dst() + for _, network := range res.Networks { + desc := portDescriptor{network, res.Transport, res.Port} + addrToDev, ok := pm.allocatedPorts[desc] + if !ok { + continue + } + devToDest, ok := addrToDev[res.Addr] + if !ok { + continue + } + destToCounter, ok := devToDest[res.BindToDevice] + if !ok { + continue + } + counter, ok := destToCounter[dst] + if !ok { + continue + } + counter.DropRef(res.Flags.Bits()) + if counter.TotalRefs() > 0 { + destToCounter[dst] = counter + continue + } + delete(destToCounter, dst) + if len(destToCounter) > 0 { + continue + } + delete(devToDest, res.BindToDevice) + if len(devToDest) > 0 { + continue + } + delete(addrToDev, res.Addr) + if len(addrToDev) > 0 { + continue } + delete(pm.allocatedPorts, desc) } } // PortRange returns the UDP and TCP inclusive range of ephemeral ports used in // both IPv4 and IPv6. -func (s *PortManager) PortRange() (uint16, uint16) { - s.ephemeralMu.RLock() - defer s.ephemeralMu.RUnlock() - return s.firstEphemeral, s.firstEphemeral + s.numEphemeral - 1 +func (pm *PortManager) PortRange() (uint16, uint16) { + pm.ephemeralMu.RLock() + defer pm.ephemeralMu.RUnlock() + return pm.firstEphemeral, pm.firstEphemeral + pm.numEphemeral - 1 } // SetPortRange sets the UDP and TCP IPv4 and IPv6 ephemeral port range // (inclusive). -func (s *PortManager) SetPortRange(start uint16, end uint16) tcpip.Error { +func (pm *PortManager) SetPortRange(start uint16, end uint16) tcpip.Error { if start > end { return &tcpip.ErrInvalidPortRange{} } - s.ephemeralMu.Lock() - defer s.ephemeralMu.Unlock() - s.firstEphemeral = start - s.numEphemeral = end - start + 1 + pm.ephemeralMu.Lock() + defer pm.ephemeralMu.Unlock() + pm.firstEphemeral = start + pm.numEphemeral = end - start + 1 return nil } diff --git a/pkg/tcpip/ports/ports_test.go b/pkg/tcpip/ports/ports_test.go index 6cfac04b1..0f43dc8f8 100644 --- a/pkg/tcpip/ports/ports_test.go +++ b/pkg/tcpip/ports/ports_test.go @@ -331,15 +331,33 @@ func TestPortReservation(t *testing.T) { for _, test := range test.actions { first, _ := pm.PortRange() if test.release { - pm.ReleasePort(net, fakeTransNumber, test.ip, test.port, test.flags, test.device, test.dest) + portRes := Reservation{ + Networks: net, + Transport: fakeTransNumber, + Addr: test.ip, + Port: test.port, + Flags: test.flags, + BindToDevice: test.device, + Dest: test.dest, + } + pm.ReleasePort(portRes) continue } - gotPort, err := pm.ReservePort(net, fakeTransNumber, test.ip, test.port, test.flags, test.device, test.dest, nil /* testPort */) + portRes := Reservation{ + Networks: net, + Transport: fakeTransNumber, + Addr: test.ip, + Port: test.port, + Flags: test.flags, + BindToDevice: test.device, + Dest: test.dest, + } + gotPort, err := pm.ReservePort(portRes, nil /* testPort */) if diff := cmp.Diff(test.want, err); diff != "" { - t.Fatalf("unexpected error from ReservePort(.., .., %s, %d, %+v, %d, %v), (-want, +got):\n%s", test.ip, test.port, test.flags, test.device, test.dest, diff) + t.Fatalf("unexpected error from ReservePort(%+v, _), (-want, +got):\n%s", portRes, diff) } if test.port == 0 && (gotPort == 0 || gotPort < first) { - t.Fatalf("ReservePort(.., .., .., 0, ..) = %d, want port number >= %d to be picked", gotPort, first) + t.Fatalf("ReservePort(%+v, _) = %d, want port number >= %d to be picked", portRes, gotPort, first) } } }) |