summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--pkg/tcpip/transport/udp/endpoint.go66
-rw-r--r--pkg/tcpip/transport/udp/endpoint_state.go5
-rw-r--r--test/syscalls/linux/udp_socket.cc31
3 files changed, 71 insertions, 31 deletions
diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go
index 776c1af43..912d33da8 100644
--- a/pkg/tcpip/transport/udp/endpoint.go
+++ b/pkg/tcpip/transport/udp/endpoint.go
@@ -16,6 +16,7 @@ package udp
import (
"fmt"
+ "sync/atomic"
"gvisor.dev/gvisor/pkg/sleep"
"gvisor.dev/gvisor/pkg/sync"
@@ -95,9 +96,11 @@ type endpoint struct {
rcvClosed bool
// The following fields are protected by the mu mutex.
- mu sync.RWMutex `state:"nosave"`
- sndBufSize int
- sndBufSizeMax int
+ mu sync.RWMutex `state:"nosave"`
+ sndBufSize int
+ sndBufSizeMax int
+ // state must be read/set using the EndpointState()/setEndpointState()
+ // methods.
state EndpointState
route *stack.Route `state:"manual"`
dstPort uint16
@@ -198,6 +201,20 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQue
return e
}
+// setEndpointState updates the state of the endpoint to state atomically. This
+// method is unexported as the only place we should update the state is in this
+// package but we allow the state to be read freely without holding e.mu.
+//
+// Precondition: e.mu must be held to call this method.
+func (e *endpoint) setEndpointState(state EndpointState) {
+ atomic.StoreUint32((*uint32)(&e.state), uint32(state))
+}
+
+// EndpointState() returns the current state of the endpoint.
+func (e *endpoint) EndpointState() EndpointState {
+ return EndpointState(atomic.LoadUint32((*uint32)(&e.state)))
+}
+
// UniqueID implements stack.TransportEndpoint.UniqueID.
func (e *endpoint) UniqueID() uint64 {
return e.uniqueID
@@ -223,7 +240,7 @@ func (e *endpoint) Close() {
e.mu.Lock()
e.shutdownFlags = tcpip.ShutdownRead | tcpip.ShutdownWrite
- switch e.state {
+ switch e.EndpointState() {
case StateBound, StateConnected:
e.stack.UnregisterTransportEndpoint(e.RegisterNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.boundPortFlags, e.boundBindToDevice)
e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.ID.LocalAddress, e.ID.LocalPort, e.boundPortFlags, e.boundBindToDevice, tcpip.FullAddress{})
@@ -252,7 +269,7 @@ func (e *endpoint) Close() {
}
// Update the state.
- e.state = StateClosed
+ e.setEndpointState(StateClosed)
e.mu.Unlock()
@@ -316,7 +333,7 @@ func (e *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMess
//
// Returns true for retry if preparation should be retried.
func (e *endpoint) prepareForWrite(to *tcpip.FullAddress) (retry bool, err *tcpip.Error) {
- switch e.state {
+ switch e.EndpointState() {
case StateInitial:
case StateConnected:
return false, nil
@@ -338,7 +355,7 @@ func (e *endpoint) prepareForWrite(to *tcpip.FullAddress) (retry bool, err *tcpi
// The state changed when we released the shared locked and re-acquired
// it in exclusive mode. Try again.
- if e.state != StateInitial {
+ if e.EndpointState() != StateInitial {
return true, nil
}
@@ -453,7 +470,7 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c
e.mu.Lock()
// Recheck state after lock was re-acquired.
- if e.state != StateConnected {
+ if e.EndpointState() != StateConnected {
err = tcpip.ErrInvalidEndpointState
}
if err == nil && route.IsResolutionRequired() {
@@ -464,7 +481,7 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c
e.mu.RLock()
// Recheck state after lock was re-acquired.
- if e.state != StateConnected {
+ if e.EndpointState() != StateConnected {
err = tcpip.ErrInvalidEndpointState
}
return ch, err
@@ -934,7 +951,7 @@ func (e *endpoint) Disconnect() *tcpip.Error {
e.mu.Lock()
defer e.mu.Unlock()
- if e.state != StateConnected {
+ if e.EndpointState() != StateConnected {
return nil
}
var (
@@ -957,7 +974,7 @@ func (e *endpoint) Disconnect() *tcpip.Error {
if err != nil {
return err
}
- e.state = StateBound
+ e.setEndpointState(StateBound)
boundPortFlags = e.boundPortFlags
} else {
if e.ID.LocalPort != 0 {
@@ -965,7 +982,7 @@ func (e *endpoint) Disconnect() *tcpip.Error {
e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.ID.LocalAddress, e.ID.LocalPort, boundPortFlags, e.boundBindToDevice, tcpip.FullAddress{})
e.boundPortFlags = ports.Flags{}
}
- e.state = StateInitial
+ e.setEndpointState(StateInitial)
}
e.stack.UnregisterTransportEndpoint(e.RegisterNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, boundPortFlags, e.boundBindToDevice)
@@ -990,7 +1007,7 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
nicID := addr.NIC
var localPort uint16
- switch e.state {
+ switch e.EndpointState() {
case StateInitial:
case StateBound, StateConnected:
localPort = e.ID.LocalPort
@@ -1025,7 +1042,7 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
RemoteAddress: r.RemoteAddress,
}
- if e.state == StateInitial {
+ if e.EndpointState() == StateInitial {
id.LocalAddress = r.LocalAddress
}
@@ -1059,7 +1076,7 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
e.RegisterNICID = nicID
e.effectiveNetProtos = netProtos
- e.state = StateConnected
+ e.setEndpointState(StateConnected)
e.rcvMu.Lock()
e.rcvReady = true
@@ -1081,7 +1098,7 @@ func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error {
// A socket in the bound state can still receive multicast messages,
// so we need to notify waiters on shutdown.
- if e.state != StateBound && e.state != StateConnected {
+ if state := e.EndpointState(); state != StateBound && state != StateConnected {
return tcpip.ErrNotConnected
}
@@ -1132,7 +1149,7 @@ func (e *endpoint) registerWithStack(nicID tcpip.NICID, netProtos []tcpip.Networ
func (e *endpoint) bindLocked(addr tcpip.FullAddress) *tcpip.Error {
// Don't allow binding once endpoint is not in the initial state
// anymore.
- if e.state != StateInitial {
+ if e.EndpointState() != StateInitial {
return tcpip.ErrInvalidEndpointState
}
@@ -1176,7 +1193,7 @@ func (e *endpoint) bindLocked(addr tcpip.FullAddress) *tcpip.Error {
e.effectiveNetProtos = netProtos
// Mark endpoint as bound.
- e.state = StateBound
+ e.setEndpointState(StateBound)
e.rcvMu.Lock()
e.rcvReady = true
@@ -1208,7 +1225,7 @@ func (e *endpoint) GetLocalAddress() (tcpip.FullAddress, *tcpip.Error) {
defer e.mu.RUnlock()
addr := e.ID.LocalAddress
- if e.state == StateConnected {
+ if e.EndpointState() == StateConnected {
addr = e.route.LocalAddress
}
@@ -1224,7 +1241,7 @@ func (e *endpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Error) {
e.mu.RLock()
defer e.mu.RUnlock()
- if e.state != StateConnected {
+ if e.EndpointState() != StateConnected {
return tcpip.FullAddress{}, tcpip.ErrNotConnected
}
@@ -1356,25 +1373,20 @@ func (e *endpoint) HandlePacket(id stack.TransportEndpointID, pkt *stack.PacketB
// HandleControlPacket implements stack.TransportEndpoint.HandleControlPacket.
func (e *endpoint) HandleControlPacket(id stack.TransportEndpointID, typ stack.ControlType, extra uint32, pkt *stack.PacketBuffer) {
if typ == stack.ControlPortUnreachable {
- e.mu.RLock()
- if e.state == StateConnected {
+ if e.EndpointState() == StateConnected {
e.lastErrorMu.Lock()
e.lastError = tcpip.ErrConnectionRefused
e.lastErrorMu.Unlock()
- e.mu.RUnlock()
e.waiterQueue.Notify(waiter.EventErr)
return
}
- e.mu.RUnlock()
}
}
// State implements tcpip.Endpoint.State.
func (e *endpoint) State() uint32 {
- e.mu.Lock()
- defer e.mu.Unlock()
- return uint32(e.state)
+ return uint32(e.EndpointState())
}
// Info returns a copy of the endpoint info.
diff --git a/pkg/tcpip/transport/udp/endpoint_state.go b/pkg/tcpip/transport/udp/endpoint_state.go
index 9d06035ea..13b72dc88 100644
--- a/pkg/tcpip/transport/udp/endpoint_state.go
+++ b/pkg/tcpip/transport/udp/endpoint_state.go
@@ -98,7 +98,8 @@ func (e *endpoint) Resume(s *stack.Stack) {
}
}
- if e.state != StateBound && e.state != StateConnected {
+ state := e.EndpointState()
+ if state != StateBound && state != StateConnected {
return
}
@@ -113,7 +114,7 @@ func (e *endpoint) Resume(s *stack.Stack) {
}
var err *tcpip.Error
- if e.state == StateConnected {
+ if state == StateConnected {
e.route, err = e.stack.FindRoute(e.RegisterNICID, e.ID.LocalAddress, e.ID.RemoteAddress, netProto, e.ops.GetMulticastLoop())
if err != nil {
panic(err)
diff --git a/test/syscalls/linux/udp_socket.cc b/test/syscalls/linux/udp_socket.cc
index 34255bfb8..90ef8bf21 100644
--- a/test/syscalls/linux/udp_socket.cc
+++ b/test/syscalls/linux/udp_socket.cc
@@ -375,8 +375,6 @@ TEST_P(UdpSocketTest, BindInUse) {
}
TEST_P(UdpSocketTest, ConnectWriteToInvalidPort) {
- ASSERT_NO_ERRNO(BindLoopback());
-
// Discover a free unused port by creating a new UDP socket, binding it
// recording the just bound port and closing it. This is not guaranteed as it
// can still race with other port UDP sockets trying to bind a port at the
@@ -410,6 +408,35 @@ TEST_P(UdpSocketTest, ConnectWriteToInvalidPort) {
ASSERT_EQ(optlen, sizeof(err));
}
+TEST_P(UdpSocketTest, ConnectSimultaneousWriteToInvalidPort) {
+ // Discover a free unused port by creating a new UDP socket, binding it
+ // recording the just bound port and closing it. This is not guaranteed as it
+ // can still race with other port UDP sockets trying to bind a port at the
+ // same time.
+ struct sockaddr_storage addr_storage = InetLoopbackAddr();
+ socklen_t addrlen = sizeof(addr_storage);
+ struct sockaddr* addr = reinterpret_cast<struct sockaddr*>(&addr_storage);
+ FileDescriptor s =
+ ASSERT_NO_ERRNO_AND_VALUE(Socket(GetFamily(), SOCK_DGRAM, IPPROTO_UDP));
+ ASSERT_THAT(bind(s.get(), addr, addrlen), SyscallSucceeds());
+ ASSERT_THAT(getsockname(s.get(), addr, &addrlen), SyscallSucceeds());
+ EXPECT_EQ(addrlen, addrlen_);
+ EXPECT_NE(*Port(&addr_storage), 0);
+ ASSERT_THAT(close(s.release()), SyscallSucceeds());
+
+ // Now connect to the port that we just released.
+ ScopedThread t([&] {
+ ASSERT_THAT(connect(sock_.get(), addr, addrlen_), SyscallSucceeds());
+ });
+
+ char buf[512];
+ RandomizeBuffer(buf, sizeof(buf));
+ // Send from sock_ to an unbound port.
+ ASSERT_THAT(sendto(sock_.get(), buf, sizeof(buf), 0, addr, addrlen_),
+ SyscallSucceedsWithValue(sizeof(buf)));
+ t.Join();
+}
+
TEST_P(UdpSocketTest, ReceiveAfterConnect) {
ASSERT_NO_ERRNO(BindLoopback());
ASSERT_THAT(connect(sock_.get(), bind_addr_, addrlen_), SyscallSucceeds());