diff options
Diffstat (limited to 'pkg/tcpip/transport/tcp/endpoint.go')
-rw-r--r-- | pkg/tcpip/transport/tcp/endpoint.go | 60 |
1 files changed, 29 insertions, 31 deletions
diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go index 21a4b6e2f..9df22ac84 100644 --- a/pkg/tcpip/transport/tcp/endpoint.go +++ b/pkg/tcpip/transport/tcp/endpoint.go @@ -2169,7 +2169,7 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) *tc if sameAddr && p == e.ID.RemotePort { return false, nil } - if _, err := e.stack.ReservePort(netProtos, ProtocolNumber, e.ID.LocalAddress, p, e.portFlags, e.bindToDevice, addr); err != nil { + if _, err := e.stack.ReservePort(netProtos, ProtocolNumber, e.ID.LocalAddress, p, e.portFlags, e.bindToDevice, addr, nil /* testPort */); err != nil { if err != tcpip.ErrPortInUse || !reuse { return false, nil } @@ -2207,7 +2207,7 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) *tc tcpEP.notifyProtocolGoroutine(notifyAbort) tcpEP.UnlockUser() // Now try and Reserve again if it fails then we skip. - if _, err := e.stack.ReservePort(netProtos, ProtocolNumber, e.ID.LocalAddress, p, e.portFlags, e.bindToDevice, addr); err != nil { + if _, err := e.stack.ReservePort(netProtos, ProtocolNumber, e.ID.LocalAddress, p, e.portFlags, e.bindToDevice, addr, nil /* testPort */); err != nil { return false, nil } } @@ -2505,47 +2505,45 @@ func (e *endpoint) bindLocked(addr tcpip.FullAddress) (err *tcpip.Error) { } } - port, err := e.stack.ReservePort(netProtos, ProtocolNumber, addr.Addr, addr.Port, e.portFlags, e.bindToDevice, tcpip.FullAddress{}) - if err != nil { - return err - } - - e.boundBindToDevice = e.bindToDevice - e.boundPortFlags = e.portFlags - e.isPortReserved = true - e.effectiveNetProtos = netProtos - e.ID.LocalPort = port - - // Any failures beyond this point must remove the port registration. - defer func(portFlags ports.Flags, bindToDevice tcpip.NICID) { - if err != nil { - e.stack.ReleasePort(netProtos, ProtocolNumber, addr.Addr, port, portFlags, bindToDevice, tcpip.FullAddress{}) - e.isPortReserved = false - e.effectiveNetProtos = nil - e.ID.LocalPort = 0 - e.ID.LocalAddress = "" - e.boundNICID = 0 - e.boundBindToDevice = 0 - e.boundPortFlags = ports.Flags{} - } - }(e.boundPortFlags, e.boundBindToDevice) - + var nic tcpip.NICID // If an address is specified, we must ensure that it's one of our // local addresses. if len(addr.Addr) != 0 { - nic := e.stack.CheckLocalAddress(addr.NIC, netProto, addr.Addr) + nic = e.stack.CheckLocalAddress(addr.NIC, netProto, addr.Addr) if nic == 0 { return tcpip.ErrBadLocalAddress } - - e.boundNICID = nic e.ID.LocalAddress = addr.Addr } - if err := e.stack.CheckRegisterTransportEndpoint(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e.boundPortFlags, e.boundBindToDevice); err != nil { + port, err := e.stack.ReservePort(netProtos, ProtocolNumber, addr.Addr, addr.Port, e.portFlags, e.bindToDevice, tcpip.FullAddress{}, func(p uint16) bool { + id := e.ID + id.LocalPort = p + // CheckRegisterTransportEndpoint should only return an error if there is a + // listening endpoint bound with the same id and portFlags and bindToDevice + // options. + // + // NOTE: Only listening and connected endpoint register with + // demuxer. Further connected endpoints always have a remote + // address/port. Hence this will only return an error if there is a matching + // listening endpoint. + if err := e.stack.CheckRegisterTransportEndpoint(nic, netProtos, ProtocolNumber, id, e.portFlags, e.bindToDevice); err != nil { + return false + } + return true + }) + if err != nil { return err } + e.boundBindToDevice = e.bindToDevice + e.boundPortFlags = e.portFlags + // TODO(gvisor.dev/issue/3691): Add test to verify boundNICID is correct. + e.boundNICID = nic + e.isPortReserved = true + e.effectiveNetProtos = netProtos + e.ID.LocalPort = port + // Mark endpoint as bound. e.setEndpointState(StateBound) |