summaryrefslogtreecommitdiffhomepage
path: root/pkg/tcpip/transport/tcp/endpoint.go
diff options
context:
space:
mode:
Diffstat (limited to 'pkg/tcpip/transport/tcp/endpoint.go')
-rw-r--r--pkg/tcpip/transport/tcp/endpoint.go60
1 files changed, 29 insertions, 31 deletions
diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go
index 21a4b6e2f..9df22ac84 100644
--- a/pkg/tcpip/transport/tcp/endpoint.go
+++ b/pkg/tcpip/transport/tcp/endpoint.go
@@ -2169,7 +2169,7 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) *tc
if sameAddr && p == e.ID.RemotePort {
return false, nil
}
- if _, err := e.stack.ReservePort(netProtos, ProtocolNumber, e.ID.LocalAddress, p, e.portFlags, e.bindToDevice, addr); err != nil {
+ if _, err := e.stack.ReservePort(netProtos, ProtocolNumber, e.ID.LocalAddress, p, e.portFlags, e.bindToDevice, addr, nil /* testPort */); err != nil {
if err != tcpip.ErrPortInUse || !reuse {
return false, nil
}
@@ -2207,7 +2207,7 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) *tc
tcpEP.notifyProtocolGoroutine(notifyAbort)
tcpEP.UnlockUser()
// Now try and Reserve again if it fails then we skip.
- if _, err := e.stack.ReservePort(netProtos, ProtocolNumber, e.ID.LocalAddress, p, e.portFlags, e.bindToDevice, addr); err != nil {
+ if _, err := e.stack.ReservePort(netProtos, ProtocolNumber, e.ID.LocalAddress, p, e.portFlags, e.bindToDevice, addr, nil /* testPort */); err != nil {
return false, nil
}
}
@@ -2505,47 +2505,45 @@ func (e *endpoint) bindLocked(addr tcpip.FullAddress) (err *tcpip.Error) {
}
}
- port, err := e.stack.ReservePort(netProtos, ProtocolNumber, addr.Addr, addr.Port, e.portFlags, e.bindToDevice, tcpip.FullAddress{})
- if err != nil {
- return err
- }
-
- e.boundBindToDevice = e.bindToDevice
- e.boundPortFlags = e.portFlags
- e.isPortReserved = true
- e.effectiveNetProtos = netProtos
- e.ID.LocalPort = port
-
- // Any failures beyond this point must remove the port registration.
- defer func(portFlags ports.Flags, bindToDevice tcpip.NICID) {
- if err != nil {
- e.stack.ReleasePort(netProtos, ProtocolNumber, addr.Addr, port, portFlags, bindToDevice, tcpip.FullAddress{})
- e.isPortReserved = false
- e.effectiveNetProtos = nil
- e.ID.LocalPort = 0
- e.ID.LocalAddress = ""
- e.boundNICID = 0
- e.boundBindToDevice = 0
- e.boundPortFlags = ports.Flags{}
- }
- }(e.boundPortFlags, e.boundBindToDevice)
-
+ var nic tcpip.NICID
// If an address is specified, we must ensure that it's one of our
// local addresses.
if len(addr.Addr) != 0 {
- nic := e.stack.CheckLocalAddress(addr.NIC, netProto, addr.Addr)
+ nic = e.stack.CheckLocalAddress(addr.NIC, netProto, addr.Addr)
if nic == 0 {
return tcpip.ErrBadLocalAddress
}
-
- e.boundNICID = nic
e.ID.LocalAddress = addr.Addr
}
- if err := e.stack.CheckRegisterTransportEndpoint(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e.boundPortFlags, e.boundBindToDevice); err != nil {
+ port, err := e.stack.ReservePort(netProtos, ProtocolNumber, addr.Addr, addr.Port, e.portFlags, e.bindToDevice, tcpip.FullAddress{}, func(p uint16) bool {
+ id := e.ID
+ id.LocalPort = p
+ // CheckRegisterTransportEndpoint should only return an error if there is a
+ // listening endpoint bound with the same id and portFlags and bindToDevice
+ // options.
+ //
+ // NOTE: Only listening and connected endpoint register with
+ // demuxer. Further connected endpoints always have a remote
+ // address/port. Hence this will only return an error if there is a matching
+ // listening endpoint.
+ if err := e.stack.CheckRegisterTransportEndpoint(nic, netProtos, ProtocolNumber, id, e.portFlags, e.bindToDevice); err != nil {
+ return false
+ }
+ return true
+ })
+ if err != nil {
return err
}
+ e.boundBindToDevice = e.bindToDevice
+ e.boundPortFlags = e.portFlags
+ // TODO(gvisor.dev/issue/3691): Add test to verify boundNICID is correct.
+ e.boundNICID = nic
+ e.isPortReserved = true
+ e.effectiveNetProtos = netProtos
+ e.ID.LocalPort = port
+
// Mark endpoint as bound.
e.setEndpointState(StateBound)