diff options
author | Tamir Duberstein <tamird@google.com> | 2018-09-12 09:37:57 -0700 |
---|---|---|
committer | Shentubot <shentubot@google.com> | 2018-09-12 09:39:01 -0700 |
commit | cbf39804647eabafb6138714ed222dbdc4781f2e (patch) | |
tree | aee396f3870f863d2edb72cf125fe83d08255d8e /pkg/tcpip/transport/udp/endpoint.go | |
parent | b4aed01bf227bfc0b29ce3100858366f60c0647b (diff) |
Prevent UDP sockets from binding to bound ports
PiperOrigin-RevId: 212653818
Change-Id: Ib4e1d754d9cdddeaa428a066cb675e6ec44d91ad
Diffstat (limited to 'pkg/tcpip/transport/udp/endpoint.go')
-rw-r--r-- | pkg/tcpip/transport/udp/endpoint.go | 35 |
1 files changed, 14 insertions, 21 deletions
diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go index 283379a28..d091a6196 100644 --- a/pkg/tcpip/transport/udp/endpoint.go +++ b/pkg/tcpip/transport/udp/endpoint.go @@ -132,6 +132,7 @@ func (e *endpoint) Close() { switch e.state { case stateBound, stateConnected: e.stack.UnregisterTransportEndpoint(e.regNICID, e.effectiveNetProtos, ProtocolNumber, e.id) + e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.id.LocalAddress, e.id.LocalPort) } // Close the receive list and drain it. @@ -496,7 +497,7 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { defer e.mu.Unlock() nicid := addr.NIC - localPort := uint16(0) + var localPort uint16 switch e.state { case stateInitial: case stateBound, stateConnected: @@ -537,7 +538,7 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { // packets on a different network protocol, so we register both even if // v6only is set to false and this is an ipv6 endpoint. netProtos := []tcpip.NetworkProtocolNumber{netProto} - if e.netProto == header.IPv6ProtocolNumber && !e.v6only { + if netProto == header.IPv6ProtocolNumber && !e.v6only { netProtos = []tcpip.NetworkProtocolNumber{ header.IPv4ProtocolNumber, header.IPv6ProtocolNumber, @@ -611,27 +612,18 @@ func (*endpoint) Accept() (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) { } func (e *endpoint) registerWithStack(nicid tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, id stack.TransportEndpointID) (stack.TransportEndpointID, *tcpip.Error) { - if id.LocalPort != 0 { - // The endpoint already has a local port, just attempt to - // register it. - err := e.stack.RegisterTransportEndpoint(nicid, netProtos, ProtocolNumber, id, e) - return id, err - } - - // We need to find a port for the endpoint. - _, err := e.stack.PickEphemeralPort(func(p uint16) (bool, *tcpip.Error) { - id.LocalPort = p - err := e.stack.RegisterTransportEndpoint(nicid, netProtos, ProtocolNumber, id, e) - switch err { - case nil: - return true, nil - case tcpip.ErrPortInUse: - return false, nil - default: - return false, err + if e.id.LocalPort == 0 { + port, err := e.stack.ReservePort(netProtos, ProtocolNumber, id.LocalAddress, id.LocalPort) + if err != nil { + return id, err } - }) + id.LocalPort = port + } + err := e.stack.RegisterTransportEndpoint(nicid, netProtos, ProtocolNumber, id, e) + if err != nil { + e.stack.ReleasePort(netProtos, ProtocolNumber, id.LocalAddress, id.LocalPort) + } return id, err } @@ -677,6 +669,7 @@ func (e *endpoint) bindLocked(addr tcpip.FullAddress, commit func() *tcpip.Error if err := commit(); err != nil { // Unregister, the commit failed. e.stack.UnregisterTransportEndpoint(addr.NIC, netProtos, ProtocolNumber, id) + e.stack.ReleasePort(netProtos, ProtocolNumber, id.LocalAddress, id.LocalPort) return err } } |