diff options
Diffstat (limited to 'pkg/tcpip/ports')
-rw-r--r-- | pkg/tcpip/ports/ports.go | 206 | ||||
-rw-r--r-- | pkg/tcpip/ports/ports_test.go | 58 |
2 files changed, 221 insertions, 43 deletions
diff --git a/pkg/tcpip/ports/ports.go b/pkg/tcpip/ports/ports.go index edc29ad27..f6d592eb5 100644 --- a/pkg/tcpip/ports/ports.go +++ b/pkg/tcpip/ports/ports.go @@ -52,6 +52,9 @@ type Flags struct { // // LoadBalanced takes precidence over MostRecent. LoadBalanced bool + + // TupleOnly represents TCP SO_REUSEADDR. + TupleOnly bool } // Bits converts the Flags to their bitset form. @@ -63,6 +66,9 @@ func (f Flags) Bits() BitFlags { if f.LoadBalanced { rf |= LoadBalancedFlag } + if f.TupleOnly { + rf |= TupleOnlyFlag + } return rf } @@ -98,6 +104,9 @@ const ( // LoadBalancedFlag represents Flags.LoadBalanced. LoadBalancedFlag + // TupleOnlyFlag represents Flags.TupleOnly. + TupleOnlyFlag + // nextFlag is the value that the next added flag will have. // // It is used to calculate FlagMask below. It is also the number of @@ -106,6 +115,10 @@ const ( // FlagMask is a bit mask for BitFlags. FlagMask = nextFlag - 1 + + // MultiBindFlagMask contains the flags that allow binding the same + // tuple multiple times. + MultiBindFlagMask = MostRecentFlag | LoadBalancedFlag ) // ToFlags converts the bitset into a Flags struct. @@ -113,6 +126,7 @@ func (f BitFlags) ToFlags() Flags { return Flags{ MostRecent: f&MostRecentFlag != 0, LoadBalanced: f&LoadBalancedFlag != 0, + TupleOnly: f&TupleOnlyFlag != 0, } } @@ -175,9 +189,54 @@ func (c FlagCounter) IntersectionRefs() BitFlags { return intersection } +type destination struct { + addr tcpip.Address + port uint16 +} + +func makeDestination(a tcpip.FullAddress) destination { + return destination{ + a.Addr, + a.Port, + } +} + +// portNode is never empty. When it has no elements, it is removed from the +// map that references it. +type portNode map[destination]FlagCounter + +// intersectionRefs calculates the intersection of flag bit values which affect +// the specified destination. +// +// If no destinations are present, all flag values are returned as there are no +// entries to limit possible flag values of a new entry. +// +// In addition to the intersection, the number of intersecting refs is +// returned. +func (p portNode) intersectionRefs(dst destination) (BitFlags, int) { + intersection := FlagMask + var count int + + for d, f := range p { + if d == dst { + intersection &= f.IntersectionRefs() + count++ + continue + } + // Wildcard destinations affect all destinations for TupleOnly. + if d.addr == anyIPAddress || dst.addr == anyIPAddress { + // Only bitwise and the TupleOnlyFlag. + intersection &= ((^TupleOnlyFlag) | f.IntersectionRefs()) + count++ + } + } + + return intersection, count +} + // deviceNode is never empty. When it has no elements, it is removed from the // map that references it. -type deviceNode map[tcpip.NICID]FlagCounter +type deviceNode map[tcpip.NICID]portNode // isAvailable checks whether binding is possible by device. If not binding to a // device, check against all FlagCounters. If binding to a specific device, check @@ -186,17 +245,15 @@ type deviceNode map[tcpip.NICID]FlagCounter // If either of the port reuse flags is enabled on any of the nodes, all nodes // sharing a port must share at least one reuse flag. This matches Linux's // behavior. -func (d deviceNode) isAvailable(flags Flags, bindToDevice tcpip.NICID) bool { +func (d deviceNode) isAvailable(flags Flags, bindToDevice tcpip.NICID, dst destination) bool { flagBits := flags.Bits() if bindToDevice == 0 { - // Trying to binding all devices. - if flagBits == 0 { - // Can't bind because the (addr,port) is already bound. - return false - } intersection := FlagMask for _, p := range d { - i := p.IntersectionRefs() + i, c := p.intersectionRefs(dst) + if c == 0 { + continue + } intersection &= i if intersection&flagBits == 0 { // Can't bind because the (addr,port) was @@ -210,16 +267,17 @@ func (d deviceNode) isAvailable(flags Flags, bindToDevice tcpip.NICID) bool { intersection := FlagMask if p, ok := d[0]; ok { - intersection = p.IntersectionRefs() - if intersection&flagBits == 0 { + var c int + intersection, c = p.intersectionRefs(dst) + if c > 0 && intersection&flagBits == 0 { return false } } if p, ok := d[bindToDevice]; ok { - i := p.IntersectionRefs() + i, c := p.intersectionRefs(dst) intersection &= i - if intersection&flagBits == 0 { + if c > 0 && intersection&flagBits == 0 { return false } } @@ -233,12 +291,12 @@ type bindAddresses map[tcpip.Address]deviceNode // isAvailable checks whether an IP address is available to bind to. If the // address is the "any" address, check all other addresses. Otherwise, just // check against the "any" address and the provided address. -func (b bindAddresses) isAvailable(addr tcpip.Address, flags Flags, bindToDevice tcpip.NICID) bool { +func (b bindAddresses) isAvailable(addr tcpip.Address, flags Flags, bindToDevice tcpip.NICID, dst destination) bool { if addr == anyIPAddress { // If binding to the "any" address then check that there are no conflicts // with all addresses. for _, d := range b { - if !d.isAvailable(flags, bindToDevice) { + if !d.isAvailable(flags, bindToDevice, dst) { return false } } @@ -247,14 +305,14 @@ func (b bindAddresses) isAvailable(addr tcpip.Address, flags Flags, bindToDevice // Check that there is no conflict with the "any" address. if d, ok := b[anyIPAddress]; ok { - if !d.isAvailable(flags, bindToDevice) { + if !d.isAvailable(flags, bindToDevice, dst) { return false } } // Check that this is no conflict with the provided address. if d, ok := b[addr]; ok { - if !d.isAvailable(flags, bindToDevice) { + if !d.isAvailable(flags, bindToDevice, dst) { return false } } @@ -320,17 +378,17 @@ func (s *PortManager) pickEphemeralPort(offset, count uint32, testPort func(p ui } // IsPortAvailable tests if the given port is available on all given protocols. -func (s *PortManager) IsPortAvailable(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, flags Flags, bindToDevice tcpip.NICID) bool { +func (s *PortManager) IsPortAvailable(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, flags Flags, bindToDevice tcpip.NICID, dest tcpip.FullAddress) bool { s.mu.Lock() defer s.mu.Unlock() - return s.isPortAvailableLocked(networks, transport, addr, port, flags, bindToDevice) + return s.isPortAvailableLocked(networks, transport, addr, port, flags, bindToDevice, makeDestination(dest)) } -func (s *PortManager) isPortAvailableLocked(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, flags Flags, bindToDevice tcpip.NICID) bool { +func (s *PortManager) isPortAvailableLocked(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, flags Flags, bindToDevice tcpip.NICID, dst destination) bool { for _, network := range networks { desc := portDescriptor{network, transport, port} if addrs, ok := s.allocatedPorts[desc]; ok { - if !addrs.isAvailable(addr, flags, bindToDevice) { + if !addrs.isAvailable(addr, flags, bindToDevice, dst) { return false } } @@ -342,14 +400,16 @@ func (s *PortManager) isPortAvailableLocked(networks []tcpip.NetworkProtocolNumb // reserved by another endpoint. If port is zero, ReservePort will search for // an unreserved ephemeral port and reserve it, returning its value in the // "port" return value. -func (s *PortManager) ReservePort(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, flags Flags, bindToDevice tcpip.NICID) (reservedPort uint16, err *tcpip.Error) { +func (s *PortManager) ReservePort(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, flags Flags, bindToDevice tcpip.NICID, dest tcpip.FullAddress) (reservedPort uint16, err *tcpip.Error) { s.mu.Lock() defer s.mu.Unlock() + dst := makeDestination(dest) + // If a port is specified, just try to reserve it for all network // protocols. if port != 0 { - if !s.reserveSpecificPort(networks, transport, addr, port, flags, bindToDevice) { + if !s.reserveSpecificPort(networks, transport, addr, port, flags, bindToDevice, dst) { return 0, tcpip.ErrPortInUse } return port, nil @@ -357,15 +417,16 @@ func (s *PortManager) ReservePort(networks []tcpip.NetworkProtocolNumber, transp // A port wasn't specified, so try to find one. return s.PickEphemeralPort(func(p uint16) (bool, *tcpip.Error) { - return s.reserveSpecificPort(networks, transport, addr, p, flags, bindToDevice), nil + return s.reserveSpecificPort(networks, transport, addr, p, flags, bindToDevice, dst), nil }) } // reserveSpecificPort tries to reserve the given port on all given protocols. -func (s *PortManager) reserveSpecificPort(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, flags Flags, bindToDevice tcpip.NICID) bool { - if !s.isPortAvailableLocked(networks, transport, addr, port, flags, bindToDevice) { +func (s *PortManager) reserveSpecificPort(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, flags Flags, bindToDevice tcpip.NICID, dst destination) bool { + if !s.isPortAvailableLocked(networks, transport, addr, port, flags, bindToDevice, dst) { return false } + flagBits := flags.Bits() // Reserve port on all network protocols. @@ -381,9 +442,65 @@ func (s *PortManager) reserveSpecificPort(networks []tcpip.NetworkProtocolNumber d = make(deviceNode) m[addr] = d } - n := d[bindToDevice] + p := d[bindToDevice] + if p == nil { + p = make(portNode) + } + n := p[dst] n.AddRef(flagBits) - d[bindToDevice] = n + p[dst] = n + d[bindToDevice] = p + } + + return true +} + +// ReserveTuple adds a port reservation for the tuple on all given protocol. +func (s *PortManager) ReserveTuple(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, flags Flags, bindToDevice tcpip.NICID, dest tcpip.FullAddress) bool { + flagBits := flags.Bits() + dst := makeDestination(dest) + + s.mu.Lock() + defer s.mu.Unlock() + + // It is easier to undo the entire reservation, so if we find that the + // tuple can't be fully added, finish and undo the whole thing. + undo := false + + // Reserve port on all network protocols. + for _, network := range networks { + desc := portDescriptor{network, transport, port} + m, ok := s.allocatedPorts[desc] + if !ok { + m = make(bindAddresses) + s.allocatedPorts[desc] = m + } + d, ok := m[addr] + if !ok { + d = make(deviceNode) + m[addr] = d + } + p := d[bindToDevice] + if p == nil { + p = make(portNode) + } + + n := p[dst] + if n.TotalRefs() != 0 && n.IntersectionRefs()&flagBits == 0 { + // Tuple already exists. + undo = true + } + n.AddRef(flagBits) + p[dst] = n + d[bindToDevice] = p + } + + if undo { + // releasePortLocked decrements the counts (rather than setting + // them to zero), so it will undo the incorrect incrementing + // above. + s.releasePortLocked(networks, transport, addr, port, flagBits, bindToDevice, dst) + return false } return true @@ -391,12 +508,14 @@ func (s *PortManager) reserveSpecificPort(networks []tcpip.NetworkProtocolNumber // ReleasePort releases the reservation on a port/IP combination so that it can // be reserved by other endpoints. -func (s *PortManager) ReleasePort(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, flags Flags, bindToDevice tcpip.NICID) { +func (s *PortManager) ReleasePort(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, flags Flags, bindToDevice tcpip.NICID, dest tcpip.FullAddress) { s.mu.Lock() defer s.mu.Unlock() - flagBits := flags.Bits() + s.releasePortLocked(networks, transport, addr, port, flags.Bits(), bindToDevice, makeDestination(dest)) +} +func (s *PortManager) releasePortLocked(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, flags BitFlags, bindToDevice tcpip.NICID, dst destination) { for _, network := range networks { desc := portDescriptor{network, transport, port} if m, ok := s.allocatedPorts[desc]; ok { @@ -404,21 +523,32 @@ func (s *PortManager) ReleasePort(networks []tcpip.NetworkProtocolNumber, transp if !ok { continue } - n, ok := d[bindToDevice] + p, ok := d[bindToDevice] if !ok { continue } - n.refs[flagBits]-- - d[bindToDevice] = n - if n.TotalRefs() == 0 { - delete(d, bindToDevice) + n, ok := p[dst] + if !ok { + continue + } + n.DropRef(flags) + if n.TotalRefs() > 0 { + p[dst] = n + continue } - if len(d) == 0 { - delete(m, addr) + delete(p, dst) + if len(p) > 0 { + continue + } + delete(d, bindToDevice) + if len(d) > 0 { + continue } - if len(m) == 0 { - delete(s.allocatedPorts, desc) + delete(m, addr) + if len(m) > 0 { + continue } + delete(s.allocatedPorts, desc) } } } diff --git a/pkg/tcpip/ports/ports_test.go b/pkg/tcpip/ports/ports_test.go index d6969d050..58db5868c 100644 --- a/pkg/tcpip/ports/ports_test.go +++ b/pkg/tcpip/ports/ports_test.go @@ -36,6 +36,7 @@ type portReserveTestAction struct { flags Flags release bool device tcpip.NICID + dest tcpip.FullAddress } func TestPortReservation(t *testing.T) { @@ -272,6 +273,54 @@ func TestPortReservation(t *testing.T) { {port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true, LoadBalanced: true}, want: nil}, {port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true}, want: tcpip.ErrPortInUse}, }, + }, { + tname: "bind tuple with reuseaddr, and then wildcard with reuseaddr", + actions: []portReserveTestAction{ + {port: 24, ip: fakeIPAddress, flags: Flags{TupleOnly: true}, dest: tcpip.FullAddress{Addr: fakeIPAddress, Port: 24}, want: nil}, + {port: 24, ip: fakeIPAddress, flags: Flags{TupleOnly: true}, dest: tcpip.FullAddress{}, want: nil}, + }, + }, { + tname: "bind tuple with reuseaddr, and then wildcard", + actions: []portReserveTestAction{ + {port: 24, ip: fakeIPAddress, flags: Flags{TupleOnly: true}, dest: tcpip.FullAddress{Addr: fakeIPAddress, Port: 24}, want: nil}, + {port: 24, ip: fakeIPAddress, want: tcpip.ErrPortInUse}, + }, + }, { + tname: "bind wildcard with reuseaddr, and then tuple with reuseaddr", + actions: []portReserveTestAction{ + {port: 24, ip: fakeIPAddress, flags: Flags{TupleOnly: true}, dest: tcpip.FullAddress{}, want: nil}, + {port: 24, ip: fakeIPAddress, flags: Flags{TupleOnly: true}, dest: tcpip.FullAddress{Addr: fakeIPAddress, Port: 24}, want: nil}, + }, + }, { + tname: "bind tuple with reuseaddr, and then wildcard", + actions: []portReserveTestAction{ + {port: 24, ip: fakeIPAddress, want: nil}, + {port: 24, ip: fakeIPAddress, flags: Flags{TupleOnly: true}, dest: tcpip.FullAddress{Addr: fakeIPAddress, Port: 24}, want: tcpip.ErrPortInUse}, + }, + }, { + tname: "bind two tuples with reuseaddr", + actions: []portReserveTestAction{ + {port: 24, ip: fakeIPAddress, flags: Flags{TupleOnly: true}, dest: tcpip.FullAddress{Addr: fakeIPAddress, Port: 24}, want: nil}, + {port: 24, ip: fakeIPAddress, flags: Flags{TupleOnly: true}, dest: tcpip.FullAddress{Addr: fakeIPAddress, Port: 25}, want: nil}, + }, + }, { + tname: "bind two tuples", + actions: []portReserveTestAction{ + {port: 24, ip: fakeIPAddress, dest: tcpip.FullAddress{Addr: fakeIPAddress, Port: 24}, want: nil}, + {port: 24, ip: fakeIPAddress, dest: tcpip.FullAddress{Addr: fakeIPAddress, Port: 25}, want: nil}, + }, + }, { + tname: "bind wildcard, and then tuple with reuseaddr", + actions: []portReserveTestAction{ + {port: 24, ip: fakeIPAddress, dest: tcpip.FullAddress{}, want: nil}, + {port: 24, ip: fakeIPAddress, flags: Flags{TupleOnly: true}, dest: tcpip.FullAddress{Addr: fakeIPAddress, Port: 24}, want: tcpip.ErrPortInUse}, + }, + }, { + tname: "bind wildcard twice with reuseaddr", + actions: []portReserveTestAction{ + {port: 24, ip: anyIPAddress, flags: Flags{TupleOnly: true}, want: nil}, + {port: 24, ip: anyIPAddress, flags: Flags{TupleOnly: true}, want: nil}, + }, }, } { t.Run(test.tname, func(t *testing.T) { @@ -280,19 +329,18 @@ func TestPortReservation(t *testing.T) { for _, test := range test.actions { if test.release { - pm.ReleasePort(net, fakeTransNumber, test.ip, test.port, test.flags, test.device) + pm.ReleasePort(net, fakeTransNumber, test.ip, test.port, test.flags, test.device, test.dest) continue } - gotPort, err := pm.ReservePort(net, fakeTransNumber, test.ip, test.port, test.flags, test.device) + gotPort, err := pm.ReservePort(net, fakeTransNumber, test.ip, test.port, test.flags, test.device, test.dest) if err != test.want { - t.Fatalf("ReservePort(.., .., %s, %d, %+v, %d) = %v, want %v", test.ip, test.port, test.flags, test.device, err, test.want) + t.Fatalf("ReservePort(.., .., %s, %d, %+v, %d, %v) = %v, want %v", test.ip, test.port, test.flags, test.device, test.dest, err, test.want) } if test.port == 0 && (gotPort == 0 || gotPort < FirstEphemeral) { - t.Fatalf("ReservePort(.., .., .., 0) = %d, want port number >= %d to be picked", gotPort, FirstEphemeral) + t.Fatalf("ReservePort(.., .., .., 0, ..) = %d, want port number >= %d to be picked", gotPort, FirstEphemeral) } } }) - } } |