diff options
author | Ian Gudger <igudger@google.com> | 2020-06-10 23:48:03 -0700 |
---|---|---|
committer | gVisor bot <gvisor-bot@google.com> | 2020-06-10 23:49:26 -0700 |
commit | a085e562d0592bccc99e9e0380706a8025f70d53 (patch) | |
tree | 5f7bbf4180c8c898a372760cd253579891c4cd7f /pkg/tcpip | |
parent | a87c74bc548b1eebc9a118fcc192d906b9fb2e39 (diff) |
Add support for SO_REUSEADDR to UDP sockets/endpoints.
On UDP sockets, SO_REUSEADDR allows multiple sockets to bind to the same
address, but only delivers packets to the most recently bound socket. This
differs from the behavior of SO_REUSEADDR on TCP sockets. SO_REUSEADDR for TCP
sockets will likely need an almost completely independent implementation.
SO_REUSEADDR has some odd interactions with the similar SO_REUSEPORT. These
interactions are tested fairly extensively and all but one particularly odd
one (that honestly seems like a bug) behave the same on gVisor and Linux.
PiperOrigin-RevId: 315844832
Diffstat (limited to 'pkg/tcpip')
-rw-r--r-- | pkg/tcpip/ports/ports.go | 116 | ||||
-rw-r--r-- | pkg/tcpip/stack/BUILD | 1 | ||||
-rw-r--r-- | pkg/tcpip/stack/stack.go | 12 | ||||
-rw-r--r-- | pkg/tcpip/stack/transport_demuxer.go | 95 | ||||
-rw-r--r-- | pkg/tcpip/stack/transport_demuxer_test.go | 3 | ||||
-rw-r--r-- | pkg/tcpip/stack/transport_test.go | 7 | ||||
-rw-r--r-- | pkg/tcpip/transport/icmp/BUILD | 1 | ||||
-rw-r--r-- | pkg/tcpip/transport/icmp/endpoint.go | 7 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/accept.go | 4 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/endpoint.go | 19 | ||||
-rw-r--r-- | pkg/tcpip/transport/udp/endpoint.go | 42 | ||||
-rw-r--r-- | pkg/tcpip/transport/udp/forwarder.go | 3 |
12 files changed, 183 insertions, 127 deletions
diff --git a/pkg/tcpip/ports/ports.go b/pkg/tcpip/ports/ports.go index b937cb84b..edc29ad27 100644 --- a/pkg/tcpip/ports/ports.go +++ b/pkg/tcpip/ports/ports.go @@ -54,17 +54,27 @@ type Flags struct { LoadBalanced bool } -func (f Flags) bits() reuseFlag { - var rf reuseFlag +// Bits converts the Flags to their bitset form. +func (f Flags) Bits() BitFlags { + var rf BitFlags if f.MostRecent { - rf |= mostRecentFlag + rf |= MostRecentFlag } if f.LoadBalanced { - rf |= loadBalancedFlag + rf |= LoadBalancedFlag } return rf } +// Effective returns the effective behavior of a flag config. +func (f Flags) Effective() Flags { + e := f + if e.LoadBalanced && e.MostRecent { + e.MostRecent = false + } + return e +} + // PortManager manages allocating, reserving and releasing ports. type PortManager struct { mu sync.RWMutex @@ -78,56 +88,88 @@ type PortManager struct { hint uint32 } -type reuseFlag int +// BitFlags is a bitset representation of Flags. +type BitFlags uint32 const ( - mostRecentFlag reuseFlag = 1 << iota - loadBalancedFlag + // MostRecentFlag represents Flags.MostRecent. + MostRecentFlag BitFlags = 1 << iota + + // LoadBalancedFlag represents Flags.LoadBalanced. + LoadBalancedFlag + + // nextFlag is the value that the next added flag will have. + // + // It is used to calculate FlagMask below. It is also the number of + // valid flag states. nextFlag - flagMask = nextFlag - 1 + // FlagMask is a bit mask for BitFlags. + FlagMask = nextFlag - 1 ) -type portNode struct { - // refs stores the count for each possible flag combination. +// ToFlags converts the bitset into a Flags struct. +func (f BitFlags) ToFlags() Flags { + return Flags{ + MostRecent: f&MostRecentFlag != 0, + LoadBalanced: f&LoadBalancedFlag != 0, + } +} + +// FlagCounter counts how many references each flag combination has. +type FlagCounter struct { + // refs stores the count for each possible flag combination, (0 though + // FlagMask). refs [nextFlag]int } -func (p portNode) totalRefs() int { +// AddRef increases the reference count for a specific flag combination. +func (c *FlagCounter) AddRef(flags BitFlags) { + c.refs[flags]++ +} + +// DropRef decreases the reference count for a specific flag combination. +func (c *FlagCounter) DropRef(flags BitFlags) { + c.refs[flags]-- +} + +// TotalRefs calculates the total number of references for all flag +// combinations. +func (c FlagCounter) TotalRefs() int { var total int - for _, r := range p.refs { + for _, r := range c.refs { total += r } return total } -// flagRefs returns the number of references with all specified flags. -func (p portNode) flagRefs(flags reuseFlag) int { +// FlagRefs returns the number of references with all specified flags. +func (c FlagCounter) FlagRefs(flags BitFlags) int { var total int - for i, r := range p.refs { - if reuseFlag(i)&flags == flags { + for i, r := range c.refs { + if BitFlags(i)&flags == flags { total += r } } return total } -// allRefsHave returns if all references have all specified flags. -func (p portNode) allRefsHave(flags reuseFlag) bool { - for i, r := range p.refs { - if reuseFlag(i)&flags == flags && r > 0 { +// AllRefsHave returns if all references have all specified flags. +func (c FlagCounter) AllRefsHave(flags BitFlags) bool { + for i, r := range c.refs { + if BitFlags(i)&flags != flags && r > 0 { return false } } return true } -// intersectionRefs returns the set of flags shared by all references. -func (p portNode) intersectionRefs() reuseFlag { - intersection := flagMask - for i, r := range p.refs { +// IntersectionRefs returns the set of flags shared by all references. +func (c FlagCounter) IntersectionRefs() BitFlags { + intersection := FlagMask + for i, r := range c.refs { if r > 0 { - intersection &= reuseFlag(i) + intersection &= BitFlags(i) } } return intersection @@ -135,26 +177,26 @@ func (p portNode) intersectionRefs() reuseFlag { // deviceNode is never empty. When it has no elements, it is removed from the // map that references it. -type deviceNode map[tcpip.NICID]portNode +type deviceNode map[tcpip.NICID]FlagCounter // isAvailable checks whether binding is possible by device. If not binding to a -// device, check against all portNodes. If binding to a specific device, check +// device, check against all FlagCounters. If binding to a specific device, check // against the unspecified device and the provided device. // // If either of the port reuse flags is enabled on any of the nodes, all nodes // sharing a port must share at least one reuse flag. This matches Linux's // behavior. func (d deviceNode) isAvailable(flags Flags, bindToDevice tcpip.NICID) bool { - flagBits := flags.bits() + flagBits := flags.Bits() if bindToDevice == 0 { // Trying to binding all devices. if flagBits == 0 { // Can't bind because the (addr,port) is already bound. return false } - intersection := flagMask + intersection := FlagMask for _, p := range d { - i := p.intersectionRefs() + i := p.IntersectionRefs() intersection &= i if intersection&flagBits == 0 { // Can't bind because the (addr,port) was @@ -165,17 +207,17 @@ func (d deviceNode) isAvailable(flags Flags, bindToDevice tcpip.NICID) bool { return true } - intersection := flagMask + intersection := FlagMask if p, ok := d[0]; ok { - intersection = p.intersectionRefs() + intersection = p.IntersectionRefs() if intersection&flagBits == 0 { return false } } if p, ok := d[bindToDevice]; ok { - i := p.intersectionRefs() + i := p.IntersectionRefs() intersection &= i if intersection&flagBits == 0 { return false @@ -324,7 +366,7 @@ func (s *PortManager) reserveSpecificPort(networks []tcpip.NetworkProtocolNumber if !s.isPortAvailableLocked(networks, transport, addr, port, flags, bindToDevice) { return false } - flagBits := flags.bits() + flagBits := flags.Bits() // Reserve port on all network protocols. for _, network := range networks { @@ -340,7 +382,7 @@ func (s *PortManager) reserveSpecificPort(networks []tcpip.NetworkProtocolNumber m[addr] = d } n := d[bindToDevice] - n.refs[flagBits]++ + n.AddRef(flagBits) d[bindToDevice] = n } @@ -353,7 +395,7 @@ func (s *PortManager) ReleasePort(networks []tcpip.NetworkProtocolNumber, transp s.mu.Lock() defer s.mu.Unlock() - flagBits := flags.bits() + flagBits := flags.Bits() for _, network := range networks { desc := portDescriptor{network, transport, port} @@ -368,7 +410,7 @@ func (s *PortManager) ReleasePort(networks []tcpip.NetworkProtocolNumber, transp } n.refs[flagBits]-- d[bindToDevice] = n - if n.refs == [nextFlag]int{} { + if n.TotalRefs() == 0 { delete(d, bindToDevice) } if len(d) == 0 { diff --git a/pkg/tcpip/stack/BUILD b/pkg/tcpip/stack/BUILD index afca925ad..24f52b735 100644 --- a/pkg/tcpip/stack/BUILD +++ b/pkg/tcpip/stack/BUILD @@ -89,6 +89,7 @@ go_test( "//pkg/tcpip/link/loopback", "//pkg/tcpip/network/ipv4", "//pkg/tcpip/network/ipv6", + "//pkg/tcpip/ports", "//pkg/tcpip/transport/icmp", "//pkg/tcpip/transport/udp", "//pkg/waiter", diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go index 648791302..a2190341c 100644 --- a/pkg/tcpip/stack/stack.go +++ b/pkg/tcpip/stack/stack.go @@ -1404,25 +1404,25 @@ func (s *Stack) RemoveWaker(nicID tcpip.NICID, addr tcpip.Address, waker *sleep. // transport dispatcher. Received packets that match the provided id will be // delivered to the given endpoint; specifying a nic is optional, but // nic-specific IDs have precedence over global ones. -func (s *Stack) RegisterTransportEndpoint(nicID tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, reusePort bool, bindToDevice tcpip.NICID) *tcpip.Error { - return s.demux.registerEndpoint(netProtos, protocol, id, ep, reusePort, bindToDevice) +func (s *Stack) RegisterTransportEndpoint(nicID tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, flags ports.Flags, bindToDevice tcpip.NICID) *tcpip.Error { + return s.demux.registerEndpoint(netProtos, protocol, id, ep, flags, bindToDevice) } // UnregisterTransportEndpoint removes the endpoint with the given id from the // stack transport dispatcher. -func (s *Stack) UnregisterTransportEndpoint(nicID tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, bindToDevice tcpip.NICID) { - s.demux.unregisterEndpoint(netProtos, protocol, id, ep, bindToDevice) +func (s *Stack) UnregisterTransportEndpoint(nicID tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, flags ports.Flags, bindToDevice tcpip.NICID) { + s.demux.unregisterEndpoint(netProtos, protocol, id, ep, flags, bindToDevice) } // StartTransportEndpointCleanup removes the endpoint with the given id from // the stack transport dispatcher. It also transitions it to the cleanup stage. -func (s *Stack) StartTransportEndpointCleanup(nicID tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, bindToDevice tcpip.NICID) { +func (s *Stack) StartTransportEndpointCleanup(nicID tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, flags ports.Flags, bindToDevice tcpip.NICID) { s.mu.Lock() defer s.mu.Unlock() s.cleanupEndpoints[ep] = struct{}{} - s.demux.unregisterEndpoint(netProtos, protocol, id, ep, bindToDevice) + s.demux.unregisterEndpoint(netProtos, protocol, id, ep, flags, bindToDevice) } // CompleteTransportEndpointCleanup removes the endpoint from the cleanup diff --git a/pkg/tcpip/stack/transport_demuxer.go b/pkg/tcpip/stack/transport_demuxer.go index e09866405..118b449d5 100644 --- a/pkg/tcpip/stack/transport_demuxer.go +++ b/pkg/tcpip/stack/transport_demuxer.go @@ -15,7 +15,6 @@ package stack import ( - "container/heap" "fmt" "math/rand" @@ -23,6 +22,7 @@ import ( "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/hash/jenkins" "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/ports" ) type protocolIDs struct { @@ -43,14 +43,14 @@ type transportEndpoints struct { // unregisterEndpoint unregisters the endpoint with the given id such that it // won't receive any more packets. -func (eps *transportEndpoints) unregisterEndpoint(id TransportEndpointID, ep TransportEndpoint, bindToDevice tcpip.NICID) { +func (eps *transportEndpoints) unregisterEndpoint(id TransportEndpointID, ep TransportEndpoint, flags ports.Flags, bindToDevice tcpip.NICID) { eps.mu.Lock() defer eps.mu.Unlock() epsByNIC, ok := eps.endpoints[id] if !ok { return } - if !epsByNIC.unregisterEndpoint(bindToDevice, ep) { + if !epsByNIC.unregisterEndpoint(bindToDevice, ep, flags) { return } delete(eps.endpoints, id) @@ -204,7 +204,7 @@ func (epsByNIC *endpointsByNIC) handleControlPacket(n *NIC, id TransportEndpoint // registerEndpoint returns true if it succeeds. It fails and returns // false if ep already has an element with the same key. -func (epsByNIC *endpointsByNIC) registerEndpoint(d *transportDemuxer, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, t TransportEndpoint, reusePort bool, bindToDevice tcpip.NICID) *tcpip.Error { +func (epsByNIC *endpointsByNIC) registerEndpoint(d *transportDemuxer, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, t TransportEndpoint, flags ports.Flags, bindToDevice tcpip.NICID) *tcpip.Error { epsByNIC.mu.Lock() defer epsByNIC.mu.Unlock() @@ -214,23 +214,22 @@ func (epsByNIC *endpointsByNIC) registerEndpoint(d *transportDemuxer, netProto t demux: d, netProto: netProto, transProto: transProto, - reuse: reusePort, } epsByNIC.endpoints[bindToDevice] = multiPortEp } - return multiPortEp.singleRegisterEndpoint(t, reusePort) + return multiPortEp.singleRegisterEndpoint(t, flags) } // unregisterEndpoint returns true if endpointsByNIC has to be unregistered. -func (epsByNIC *endpointsByNIC) unregisterEndpoint(bindToDevice tcpip.NICID, t TransportEndpoint) bool { +func (epsByNIC *endpointsByNIC) unregisterEndpoint(bindToDevice tcpip.NICID, t TransportEndpoint, flags ports.Flags) bool { epsByNIC.mu.Lock() defer epsByNIC.mu.Unlock() multiPortEp, ok := epsByNIC.endpoints[bindToDevice] if !ok { return false } - if multiPortEp.unregisterEndpoint(t) { + if multiPortEp.unregisterEndpoint(t, flags) { delete(epsByNIC.endpoints, bindToDevice) } return len(epsByNIC.endpoints) == 0 @@ -279,10 +278,10 @@ func newTransportDemuxer(stack *Stack) *transportDemuxer { // registerEndpoint registers the given endpoint with the dispatcher such that // packets that match the endpoint ID are delivered to it. -func (d *transportDemuxer) registerEndpoint(netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, reusePort bool, bindToDevice tcpip.NICID) *tcpip.Error { +func (d *transportDemuxer) registerEndpoint(netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, flags ports.Flags, bindToDevice tcpip.NICID) *tcpip.Error { for i, n := range netProtos { - if err := d.singleRegisterEndpoint(n, protocol, id, ep, reusePort, bindToDevice); err != nil { - d.unregisterEndpoint(netProtos[:i], protocol, id, ep, bindToDevice) + if err := d.singleRegisterEndpoint(n, protocol, id, ep, flags, bindToDevice); err != nil { + d.unregisterEndpoint(netProtos[:i], protocol, id, ep, flags, bindToDevice) return err } } @@ -290,35 +289,6 @@ func (d *transportDemuxer) registerEndpoint(netProtos []tcpip.NetworkProtocolNum return nil } -type transportEndpointHeap []TransportEndpoint - -var _ heap.Interface = (*transportEndpointHeap)(nil) - -func (h *transportEndpointHeap) Len() int { - return len(*h) -} - -func (h *transportEndpointHeap) Less(i, j int) bool { - return (*h)[i].UniqueID() < (*h)[j].UniqueID() -} - -func (h *transportEndpointHeap) Swap(i, j int) { - (*h)[i], (*h)[j] = (*h)[j], (*h)[i] -} - -func (h *transportEndpointHeap) Push(x interface{}) { - *h = append(*h, x.(TransportEndpoint)) -} - -func (h *transportEndpointHeap) Pop() interface{} { - old := *h - n := len(old) - x := old[n-1] - old[n-1] = nil - *h = old[:n-1] - return x -} - // multiPortEndpoint is a container for TransportEndpoints which are bound to // the same pair of address and port. endpointsArr always has at least one // element. @@ -334,9 +304,10 @@ type multiPortEndpoint struct { netProto tcpip.NetworkProtocolNumber transProto tcpip.TransportProtocolNumber - endpoints transportEndpointHeap - // reuse indicates if more than one endpoint is allowed. - reuse bool + // endpoints stores the transport endpoints in the order in which they + // were bound. This is required for UDP SO_REUSEADDR. + endpoints []TransportEndpoint + flags ports.FlagCounter } func (ep *multiPortEndpoint) transportEndpoints() []TransportEndpoint { @@ -362,6 +333,10 @@ func selectEndpoint(id TransportEndpointID, mpep *multiPortEndpoint, seed uint32 return mpep.endpoints[0] } + if mpep.flags.IntersectionRefs().ToFlags().Effective().MostRecent { + return mpep.endpoints[len(mpep.endpoints)-1] + } + payload := []byte{ byte(id.LocalPort), byte(id.LocalPort >> 8), @@ -401,40 +376,47 @@ func (ep *multiPortEndpoint) handlePacketAll(r *Route, id TransportEndpointID, p // singleRegisterEndpoint tries to add an endpoint to the multiPortEndpoint // list. The list might be empty already. -func (ep *multiPortEndpoint) singleRegisterEndpoint(t TransportEndpoint, reusePort bool) *tcpip.Error { +func (ep *multiPortEndpoint) singleRegisterEndpoint(t TransportEndpoint, flags ports.Flags) *tcpip.Error { ep.mu.Lock() defer ep.mu.Unlock() + bits := flags.Bits() + if len(ep.endpoints) != 0 { // If it was previously bound, we need to check if we can bind again. - if !ep.reuse || !reusePort { + if ep.flags.TotalRefs() > 0 && bits&ep.flags.IntersectionRefs() == 0 { return tcpip.ErrPortInUse } } - heap.Push(&ep.endpoints, t) + ep.endpoints = append(ep.endpoints, t) + ep.flags.AddRef(bits) return nil } // unregisterEndpoint returns true if multiPortEndpoint has to be unregistered. -func (ep *multiPortEndpoint) unregisterEndpoint(t TransportEndpoint) bool { +func (ep *multiPortEndpoint) unregisterEndpoint(t TransportEndpoint, flags ports.Flags) bool { ep.mu.Lock() defer ep.mu.Unlock() for i, endpoint := range ep.endpoints { if endpoint == t { - heap.Remove(&ep.endpoints, i) + copy(ep.endpoints[i:], ep.endpoints[i+1:]) + ep.endpoints[len(ep.endpoints)-1] = nil + ep.endpoints = ep.endpoints[:len(ep.endpoints)-1] + + ep.flags.DropRef(flags.Bits()) break } } return len(ep.endpoints) == 0 } -func (d *transportDemuxer) singleRegisterEndpoint(netProto tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, reusePort bool, bindToDevice tcpip.NICID) *tcpip.Error { +func (d *transportDemuxer) singleRegisterEndpoint(netProto tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, flags ports.Flags, bindToDevice tcpip.NICID) *tcpip.Error { if id.RemotePort != 0 { - // TODO(eyalsoha): Why? - reusePort = false + // SO_REUSEPORT only applies to bound/listening endpoints. + flags.LoadBalanced = false } eps, ok := d.protocol[protocolIDs{netProto, protocol}] @@ -454,15 +436,20 @@ func (d *transportDemuxer) singleRegisterEndpoint(netProto tcpip.NetworkProtocol eps.endpoints[id] = epsByNIC } - return epsByNIC.registerEndpoint(d, netProto, protocol, ep, reusePort, bindToDevice) + return epsByNIC.registerEndpoint(d, netProto, protocol, ep, flags, bindToDevice) } // unregisterEndpoint unregisters the endpoint with the given id such that it // won't receive any more packets. -func (d *transportDemuxer) unregisterEndpoint(netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, bindToDevice tcpip.NICID) { +func (d *transportDemuxer) unregisterEndpoint(netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, flags ports.Flags, bindToDevice tcpip.NICID) { + if id.RemotePort != 0 { + // SO_REUSEPORT only applies to bound/listening endpoints. + flags.LoadBalanced = false + } + for _, n := range netProtos { if eps, ok := d.protocol[protocolIDs{n, protocol}]; ok { - eps.unregisterEndpoint(id, ep, bindToDevice) + eps.unregisterEndpoint(id, ep, flags, bindToDevice) } } } diff --git a/pkg/tcpip/stack/transport_demuxer_test.go b/pkg/tcpip/stack/transport_demuxer_test.go index 67d778137..73dada928 100644 --- a/pkg/tcpip/stack/transport_demuxer_test.go +++ b/pkg/tcpip/stack/transport_demuxer_test.go @@ -25,6 +25,7 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/link/channel" "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" + "gvisor.dev/gvisor/pkg/tcpip/ports" "gvisor.dev/gvisor/pkg/tcpip/stack" "gvisor.dev/gvisor/pkg/tcpip/transport/udp" "gvisor.dev/gvisor/pkg/waiter" @@ -195,7 +196,7 @@ func TestTransportDemuxerRegister(t *testing.T) { if !ok { t.Fatalf("%T does not implement stack.TransportEndpoint", ep) } - if got, want := s.RegisterTransportEndpoint(0, []tcpip.NetworkProtocolNumber{test.proto}, udp.ProtocolNumber, stack.TransportEndpointID{}, tEP, false, 0), test.want; got != want { + if got, want := s.RegisterTransportEndpoint(0, []tcpip.NetworkProtocolNumber{test.proto}, udp.ProtocolNumber, stack.TransportEndpointID{}, tEP, ports.Flags{}, 0), test.want; got != want { t.Fatalf("s.RegisterTransportEndpoint(...) = %s, want %s", got, want) } }) diff --git a/pkg/tcpip/stack/transport_test.go b/pkg/tcpip/stack/transport_test.go index ad61c09d6..7e8b84867 100644 --- a/pkg/tcpip/stack/transport_test.go +++ b/pkg/tcpip/stack/transport_test.go @@ -21,6 +21,7 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/link/channel" "gvisor.dev/gvisor/pkg/tcpip/link/loopback" + "gvisor.dev/gvisor/pkg/tcpip/ports" "gvisor.dev/gvisor/pkg/tcpip/stack" "gvisor.dev/gvisor/pkg/waiter" ) @@ -154,7 +155,7 @@ func (f *fakeTransportEndpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { // Try to register so that we can start receiving packets. f.ID.RemoteAddress = addr.Addr - err = f.stack.RegisterTransportEndpoint(0, []tcpip.NetworkProtocolNumber{fakeNetNumber}, fakeTransNumber, f.ID, f, false /* reuse */, 0 /* bindToDevice */) + err = f.stack.RegisterTransportEndpoint(0, []tcpip.NetworkProtocolNumber{fakeNetNumber}, fakeTransNumber, f.ID, f, ports.Flags{}, 0 /* bindToDevice */) if err != nil { return err } @@ -199,8 +200,8 @@ func (f *fakeTransportEndpoint) Bind(a tcpip.FullAddress) *tcpip.Error { fakeTransNumber, stack.TransportEndpointID{LocalAddress: a.Addr}, f, - false, /* reuse */ - 0, /* bindtoDevice */ + ports.Flags{}, + 0, /* bindtoDevice */ ); err != nil { return err } diff --git a/pkg/tcpip/transport/icmp/BUILD b/pkg/tcpip/transport/icmp/BUILD index 9ce625c17..7e5c79776 100644 --- a/pkg/tcpip/transport/icmp/BUILD +++ b/pkg/tcpip/transport/icmp/BUILD @@ -31,6 +31,7 @@ go_library( "//pkg/tcpip", "//pkg/tcpip/buffer", "//pkg/tcpip/header", + "//pkg/tcpip/ports", "//pkg/tcpip/stack", "//pkg/tcpip/transport/raw", "//pkg/tcpip/transport/tcp", diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go index 57e0a069b..8ce294002 100644 --- a/pkg/tcpip/transport/icmp/endpoint.go +++ b/pkg/tcpip/transport/icmp/endpoint.go @@ -19,6 +19,7 @@ import ( "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/ports" "gvisor.dev/gvisor/pkg/tcpip/stack" "gvisor.dev/gvisor/pkg/waiter" ) @@ -110,7 +111,7 @@ func (e *endpoint) Close() { e.shutdownFlags = tcpip.ShutdownRead | tcpip.ShutdownWrite switch e.state { case stateBound, stateConnected: - e.stack.UnregisterTransportEndpoint(e.RegisterNICID, []tcpip.NetworkProtocolNumber{e.NetProto}, e.TransProto, e.ID, e, 0 /* bindToDevice */) + e.stack.UnregisterTransportEndpoint(e.RegisterNICID, []tcpip.NetworkProtocolNumber{e.NetProto}, e.TransProto, e.ID, e, ports.Flags{}, 0 /* bindToDevice */) } // Close the receive list and drain it. @@ -607,14 +608,14 @@ func (e *endpoint) registerWithStack(nicID tcpip.NICID, netProtos []tcpip.Networ if id.LocalPort != 0 { // The endpoint already has a local port, just attempt to // register it. - err := e.stack.RegisterTransportEndpoint(nicID, netProtos, e.TransProto, id, e, false /* reuse */, 0 /* bindToDevice */) + err := e.stack.RegisterTransportEndpoint(nicID, netProtos, e.TransProto, id, e, ports.Flags{}, 0 /* bindToDevice */) return id, err } // We need to find a port for the endpoint. _, err := e.stack.PickEphemeralPort(func(p uint16) (bool, *tcpip.Error) { id.LocalPort = p - err := e.stack.RegisterTransportEndpoint(nicID, netProtos, e.TransProto, id, e, false /* reuse */, 0 /* bindtodevice */) + err := e.stack.RegisterTransportEndpoint(nicID, netProtos, e.TransProto, id, e, ports.Flags{}, 0 /* bindtodevice */) switch err { case nil: return true, nil diff --git a/pkg/tcpip/transport/tcp/accept.go b/pkg/tcpip/transport/tcp/accept.go index e6a23c978..ad197e8db 100644 --- a/pkg/tcpip/transport/tcp/accept.go +++ b/pkg/tcpip/transport/tcp/accept.go @@ -27,6 +27,7 @@ import ( "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/ports" "gvisor.dev/gvisor/pkg/tcpip/seqnum" "gvisor.dev/gvisor/pkg/tcpip/stack" "gvisor.dev/gvisor/pkg/waiter" @@ -238,13 +239,14 @@ func (l *listenContext) createConnectingEndpoint(s *segment, iss seqnum.Value, i n.mu.Lock() // Register new endpoint so that packets are routed to it. - if err := n.stack.RegisterTransportEndpoint(n.boundNICID, n.effectiveNetProtos, ProtocolNumber, n.ID, n, n.reusePort, n.boundBindToDevice); err != nil { + if err := n.stack.RegisterTransportEndpoint(n.boundNICID, n.effectiveNetProtos, ProtocolNumber, n.ID, n, ports.Flags{LoadBalanced: n.reusePort}, n.boundBindToDevice); err != nil { n.mu.Unlock() n.Close() return nil, err } n.isRegistered = true + n.registeredReusePort = n.reusePort return n, nil } diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go index 19f7bf449..6e4d607da 100644 --- a/pkg/tcpip/transport/tcp/endpoint.go +++ b/pkg/tcpip/transport/tcp/endpoint.go @@ -465,6 +465,10 @@ type endpoint struct { // reusePort is set to true if SO_REUSEPORT is enabled. reusePort bool + // registeredReusePort is set if the current endpoint registration was + // done with SO_REUSEPORT enabled. + registeredReusePort bool + // bindToDevice is set to the NIC on which to bind or disabled if 0. bindToDevice tcpip.NICID @@ -1021,8 +1025,9 @@ func (e *endpoint) closeNoShutdownLocked() { // in Listen() when trying to register. if e.EndpointState() == StateListen && e.isPortReserved { if e.isRegistered { - e.stack.StartTransportEndpointCleanup(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.boundBindToDevice) + e.stack.StartTransportEndpointCleanup(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, ports.Flags{LoadBalanced: e.registeredReusePort}, e.boundBindToDevice) e.isRegistered = false + e.registeredReusePort = false } e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.ID.LocalAddress, e.ID.LocalPort, e.boundPortFlags, e.boundBindToDevice) @@ -1086,8 +1091,9 @@ func (e *endpoint) cleanupLocked() { e.workerCleanup = false if e.isRegistered { - e.stack.StartTransportEndpointCleanup(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.boundBindToDevice) + e.stack.StartTransportEndpointCleanup(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, ports.Flags{LoadBalanced: e.registeredReusePort}, e.boundBindToDevice) e.isRegistered = false + e.registeredReusePort = false } if e.isPortReserved { @@ -2088,10 +2094,11 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) *tc if e.ID.LocalPort != 0 { // The endpoint is bound to a port, attempt to register it. - err := e.stack.RegisterTransportEndpoint(nicID, netProtos, ProtocolNumber, e.ID, e, e.reusePort, e.boundBindToDevice) + err := e.stack.RegisterTransportEndpoint(nicID, netProtos, ProtocolNumber, e.ID, e, ports.Flags{LoadBalanced: e.reusePort}, e.boundBindToDevice) if err != nil { return err } + e.registeredReusePort = e.reusePort } else { // The endpoint doesn't have a local port yet, so try to get // one. Make sure that it isn't one that will result in the same @@ -2123,12 +2130,13 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) *tc id := e.ID id.LocalPort = p - switch e.stack.RegisterTransportEndpoint(nicID, netProtos, ProtocolNumber, id, e, e.reusePort, e.bindToDevice) { + switch e.stack.RegisterTransportEndpoint(nicID, netProtos, ProtocolNumber, id, e, ports.Flags{LoadBalanced: e.reusePort}, e.bindToDevice) { case nil: // Port picking successful. Save the details of // the selected port. e.ID = id e.boundBindToDevice = e.bindToDevice + e.registeredReusePort = e.reusePort return true, nil case tcpip.ErrPortInUse: return false, nil @@ -2326,12 +2334,13 @@ func (e *endpoint) listen(backlog int) *tcpip.Error { } // Register the endpoint. - if err := e.stack.RegisterTransportEndpoint(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.reusePort, e.boundBindToDevice); err != nil { + if err := e.stack.RegisterTransportEndpoint(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, ports.Flags{LoadBalanced: e.reusePort}, e.boundBindToDevice); err != nil { return err } e.isRegistered = true e.setEndpointState(StateListen) + e.registeredReusePort = e.reusePort // The channel may be non-nil when we're restoring the endpoint, and it // may be pre-populated with some previously accepted (but not Accepted) diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go index c5e3c73ef..df5efbf6a 100644 --- a/pkg/tcpip/transport/udp/endpoint.go +++ b/pkg/tcpip/transport/udp/endpoint.go @@ -103,7 +103,7 @@ type endpoint struct { multicastAddr tcpip.Address multicastNICID tcpip.NICID multicastLoop bool - reusePort bool + portFlags ports.Flags bindToDevice tcpip.NICID broadcast bool @@ -214,7 +214,7 @@ func (e *endpoint) Close() { switch e.state { case StateBound, StateConnected: - e.stack.UnregisterTransportEndpoint(e.RegisterNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.boundBindToDevice) + e.stack.UnregisterTransportEndpoint(e.RegisterNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.boundPortFlags, e.boundBindToDevice) e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.ID.LocalAddress, e.ID.LocalPort, e.boundPortFlags, e.boundBindToDevice) e.boundBindToDevice = 0 e.boundPortFlags = ports.Flags{} @@ -558,10 +558,13 @@ func (e *endpoint) SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error { e.mu.Unlock() case tcpip.ReuseAddressOption: + e.mu.Lock() + e.portFlags.MostRecent = v + e.mu.Unlock() case tcpip.ReusePortOption: e.mu.Lock() - e.reusePort = v + e.portFlags.LoadBalanced = v e.mu.Unlock() case tcpip.V6OnlyOption: @@ -795,11 +798,15 @@ func (e *endpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) { return v, nil case tcpip.ReuseAddressOption: - return false, nil + e.mu.RLock() + v := e.portFlags.MostRecent + e.mu.RUnlock() + + return v, nil case tcpip.ReusePortOption: e.mu.RLock() - v := e.reusePort + v := e.portFlags.LoadBalanced e.mu.RUnlock() return v, nil @@ -968,6 +975,11 @@ func (e *endpoint) Disconnect() *tcpip.Error { id stack.TransportEndpointID btd tcpip.NICID ) + + // We change this value below and we need the old value to unregister + // the endpoint. + boundPortFlags := e.boundPortFlags + // Exclude ephemerally bound endpoints. if e.BindNICID != 0 || e.ID.LocalAddress == "" { var err *tcpip.Error @@ -980,16 +992,17 @@ func (e *endpoint) Disconnect() *tcpip.Error { return err } e.state = StateBound + boundPortFlags = e.boundPortFlags } else { if e.ID.LocalPort != 0 { // Release the ephemeral port. - e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.ID.LocalAddress, e.ID.LocalPort, e.boundPortFlags, e.boundBindToDevice) + e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.ID.LocalAddress, e.ID.LocalPort, boundPortFlags, e.boundBindToDevice) e.boundPortFlags = ports.Flags{} } e.state = StateInitial } - e.stack.UnregisterTransportEndpoint(e.RegisterNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.boundBindToDevice) + e.stack.UnregisterTransportEndpoint(e.RegisterNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, boundPortFlags, e.boundBindToDevice) e.ID = id e.boundBindToDevice = btd e.route.Release() @@ -1061,6 +1074,8 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { } } + oldPortFlags := e.boundPortFlags + id, btd, err := e.registerWithStack(nicID, netProtos, id) if err != nil { return err @@ -1068,7 +1083,7 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { // Remove the old registration. if e.ID.LocalPort != 0 { - e.stack.UnregisterTransportEndpoint(e.RegisterNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.boundBindToDevice) + e.stack.UnregisterTransportEndpoint(e.RegisterNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, oldPortFlags, e.boundBindToDevice) } e.ID = id @@ -1132,20 +1147,15 @@ func (*endpoint) Accept() (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) { func (e *endpoint) registerWithStack(nicID tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, id stack.TransportEndpointID) (stack.TransportEndpointID, tcpip.NICID, *tcpip.Error) { if e.ID.LocalPort == 0 { - flags := ports.Flags{ - LoadBalanced: e.reusePort, - // FIXME(b/129164367): Support SO_REUSEADDR. - MostRecent: false, - } - port, err := e.stack.ReservePort(netProtos, ProtocolNumber, id.LocalAddress, id.LocalPort, flags, e.bindToDevice) + port, err := e.stack.ReservePort(netProtos, ProtocolNumber, id.LocalAddress, id.LocalPort, e.portFlags, e.bindToDevice) if err != nil { return id, e.bindToDevice, err } - e.boundPortFlags = flags id.LocalPort = port } + e.boundPortFlags = e.portFlags - err := e.stack.RegisterTransportEndpoint(nicID, netProtos, ProtocolNumber, id, e, e.reusePort, e.bindToDevice) + err := e.stack.RegisterTransportEndpoint(nicID, netProtos, ProtocolNumber, id, e, e.boundPortFlags, e.bindToDevice) if err != nil { e.stack.ReleasePort(netProtos, ProtocolNumber, id.LocalAddress, id.LocalPort, e.boundPortFlags, e.bindToDevice) e.boundPortFlags = ports.Flags{} diff --git a/pkg/tcpip/transport/udp/forwarder.go b/pkg/tcpip/transport/udp/forwarder.go index 7abfa0ed2..c67e0ba95 100644 --- a/pkg/tcpip/transport/udp/forwarder.go +++ b/pkg/tcpip/transport/udp/forwarder.go @@ -73,7 +73,7 @@ func (r *ForwarderRequest) ID() stack.TransportEndpointID { // CreateEndpoint creates a connected UDP endpoint for the session request. func (r *ForwarderRequest) CreateEndpoint(queue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { ep := newEndpoint(r.stack, r.route.NetProto, queue) - if err := r.stack.RegisterTransportEndpoint(r.route.NICID(), []tcpip.NetworkProtocolNumber{r.route.NetProto}, ProtocolNumber, r.id, ep, ep.reusePort, ep.bindToDevice); err != nil { + if err := r.stack.RegisterTransportEndpoint(r.route.NICID(), []tcpip.NetworkProtocolNumber{r.route.NetProto}, ProtocolNumber, r.id, ep, ep.portFlags, ep.bindToDevice); err != nil { ep.Close() return nil, err } @@ -82,6 +82,7 @@ func (r *ForwarderRequest) CreateEndpoint(queue *waiter.Queue) (tcpip.Endpoint, ep.route = r.route.Clone() ep.dstPort = r.id.RemotePort ep.RegisterNICID = r.route.NICID() + ep.boundPortFlags = ep.portFlags ep.state = StateConnected |