summaryrefslogtreecommitdiffhomepage
path: root/pkg/tcpip/transport/udp/endpoint.go
diff options
context:
space:
mode:
Diffstat (limited to 'pkg/tcpip/transport/udp/endpoint.go')
-rw-r--r--pkg/tcpip/transport/udp/endpoint.go35
1 files changed, 14 insertions, 21 deletions
diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go
index 283379a28..d091a6196 100644
--- a/pkg/tcpip/transport/udp/endpoint.go
+++ b/pkg/tcpip/transport/udp/endpoint.go
@@ -132,6 +132,7 @@ func (e *endpoint) Close() {
switch e.state {
case stateBound, stateConnected:
e.stack.UnregisterTransportEndpoint(e.regNICID, e.effectiveNetProtos, ProtocolNumber, e.id)
+ e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.id.LocalAddress, e.id.LocalPort)
}
// Close the receive list and drain it.
@@ -496,7 +497,7 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
defer e.mu.Unlock()
nicid := addr.NIC
- localPort := uint16(0)
+ var localPort uint16
switch e.state {
case stateInitial:
case stateBound, stateConnected:
@@ -537,7 +538,7 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
// packets on a different network protocol, so we register both even if
// v6only is set to false and this is an ipv6 endpoint.
netProtos := []tcpip.NetworkProtocolNumber{netProto}
- if e.netProto == header.IPv6ProtocolNumber && !e.v6only {
+ if netProto == header.IPv6ProtocolNumber && !e.v6only {
netProtos = []tcpip.NetworkProtocolNumber{
header.IPv4ProtocolNumber,
header.IPv6ProtocolNumber,
@@ -611,27 +612,18 @@ func (*endpoint) Accept() (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) {
}
func (e *endpoint) registerWithStack(nicid tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, id stack.TransportEndpointID) (stack.TransportEndpointID, *tcpip.Error) {
- if id.LocalPort != 0 {
- // The endpoint already has a local port, just attempt to
- // register it.
- err := e.stack.RegisterTransportEndpoint(nicid, netProtos, ProtocolNumber, id, e)
- return id, err
- }
-
- // We need to find a port for the endpoint.
- _, err := e.stack.PickEphemeralPort(func(p uint16) (bool, *tcpip.Error) {
- id.LocalPort = p
- err := e.stack.RegisterTransportEndpoint(nicid, netProtos, ProtocolNumber, id, e)
- switch err {
- case nil:
- return true, nil
- case tcpip.ErrPortInUse:
- return false, nil
- default:
- return false, err
+ if e.id.LocalPort == 0 {
+ port, err := e.stack.ReservePort(netProtos, ProtocolNumber, id.LocalAddress, id.LocalPort)
+ if err != nil {
+ return id, err
}
- })
+ id.LocalPort = port
+ }
+ err := e.stack.RegisterTransportEndpoint(nicid, netProtos, ProtocolNumber, id, e)
+ if err != nil {
+ e.stack.ReleasePort(netProtos, ProtocolNumber, id.LocalAddress, id.LocalPort)
+ }
return id, err
}
@@ -677,6 +669,7 @@ func (e *endpoint) bindLocked(addr tcpip.FullAddress, commit func() *tcpip.Error
if err := commit(); err != nil {
// Unregister, the commit failed.
e.stack.UnregisterTransportEndpoint(addr.NIC, netProtos, ProtocolNumber, id)
+ e.stack.ReleasePort(netProtos, ProtocolNumber, id.LocalAddress, id.LocalPort)
return err
}
}