diff options
Diffstat (limited to 'pkg/tcpip/transport/udp/endpoint.go')
-rw-r--r-- | pkg/tcpip/transport/udp/endpoint.go | 111 |
1 files changed, 62 insertions, 49 deletions
diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go index ac5905772..0bec7e62d 100644 --- a/pkg/tcpip/transport/udp/endpoint.go +++ b/pkg/tcpip/transport/udp/endpoint.go @@ -15,7 +15,6 @@ package udp import ( - "math" "sync" "gvisor.dev/gvisor/pkg/tcpip" @@ -37,13 +36,17 @@ type udpPacket struct { views [8]buffer.View `state:"nosave"` } -type endpointState int +// EndpointState represents the state of a UDP endpoint. +type EndpointState uint32 +// Endpoint states. Note that are represented in a netstack-specific manner and +// may not be meaningful externally. Specifically, they need to be translated to +// Linux's representation for these states if presented to userspace. const ( - stateInitial endpointState = iota - stateBound - stateConnected - stateClosed + StateInitial EndpointState = iota + StateBound + StateConnected + StateClosed ) // endpoint represents a UDP endpoint. This struct serves as the interface @@ -74,7 +77,7 @@ type endpoint struct { mu sync.RWMutex `state:"nosave"` sndBufSize int id stack.TransportEndpointID - state endpointState + state EndpointState bindNICID tcpip.NICID regNICID tcpip.NICID route stack.Route `state:"manual"` @@ -140,7 +143,7 @@ func (e *endpoint) Close() { e.shutdownFlags = tcpip.ShutdownRead | tcpip.ShutdownWrite switch e.state { - case stateBound, stateConnected: + case StateBound, StateConnected: e.stack.UnregisterTransportEndpoint(e.regNICID, e.effectiveNetProtos, ProtocolNumber, e.id, e) e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.id.LocalAddress, e.id.LocalPort) } @@ -163,7 +166,7 @@ func (e *endpoint) Close() { e.route.Release() // Update the state. - e.state = stateClosed + e.state = StateClosed e.mu.Unlock() @@ -211,11 +214,11 @@ func (e *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMess // Returns true for retry if preparation should be retried. func (e *endpoint) prepareForWrite(to *tcpip.FullAddress) (retry bool, err *tcpip.Error) { switch e.state { - case stateInitial: - case stateConnected: + case StateInitial: + case StateConnected: return false, nil - case stateBound: + case StateBound: if to == nil { return false, tcpip.ErrDestinationRequired } @@ -232,7 +235,7 @@ func (e *endpoint) prepareForWrite(to *tcpip.FullAddress) (retry bool, err *tcpi // The state changed when we released the shared locked and re-acquired // it in exclusive mode. Try again. - if e.state != stateInitial { + if e.state != StateInitial { return true, nil } @@ -273,17 +276,12 @@ func (e *endpoint) connectRoute(nicid tcpip.NICID, addr tcpip.FullAddress, netPr // Write writes data to the endpoint's peer. This method does not block // if the data cannot be written. -func (e *endpoint) Write(p tcpip.Payload, opts tcpip.WriteOptions) (int64, <-chan struct{}, *tcpip.Error) { +func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-chan struct{}, *tcpip.Error) { // MSG_MORE is unimplemented. (This also means that MSG_EOR is a no-op.) if opts.More { return 0, nil, tcpip.ErrInvalidOptionValue } - if p.Size() > math.MaxUint16 { - // Payload can't possibly fit in a packet. - return 0, nil, tcpip.ErrMessageTooLong - } - to := opts.To e.mu.RLock() @@ -322,7 +320,7 @@ func (e *endpoint) Write(p tcpip.Payload, opts tcpip.WriteOptions) (int64, <-cha defer e.mu.Unlock() // Recheck state after lock was re-acquired. - if e.state != stateConnected { + if e.state != StateConnected { return 0, nil, tcpip.ErrInvalidEndpointState } } @@ -366,10 +364,14 @@ func (e *endpoint) Write(p tcpip.Payload, opts tcpip.WriteOptions) (int64, <-cha } } - v, err := p.Get(p.Size()) + v, err := p.FullPayload() if err != nil { return 0, nil, err } + if len(v) > header.UDPMaximumPacketSize { + // Payload can't possibly fit in a packet. + return 0, nil, tcpip.ErrMessageTooLong + } ttl := route.DefaultTTL() if header.IsV4MulticastAddress(route.RemoteAddress) || header.IsV6MulticastAddress(route.RemoteAddress) { @@ -387,7 +389,12 @@ func (e *endpoint) Peek([][]byte) (int64, tcpip.ControlMessages, *tcpip.Error) { return 0, tcpip.ControlMessages{}, nil } -// SetSockOpt sets a socket option. Currently not supported. +// SetSockOptInt implements tcpip.Endpoint.SetSockOptInt. +func (e *endpoint) SetSockOptInt(opt tcpip.SockOpt, v int) *tcpip.Error { + return nil +} + +// SetSockOpt implements tcpip.Endpoint.SetSockOpt. func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { switch v := opt.(type) { case tcpip.V6OnlyOption: @@ -400,7 +407,7 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { defer e.mu.Unlock() // We only allow this to be set when we're in the initial state. - if e.state != stateInitial { + if e.state != StateInitial { return tcpip.ErrInvalidEndpointState } @@ -566,7 +573,20 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOpt) (int, *tcpip.Error) { } e.rcvMu.Unlock() return v, nil + + case tcpip.SendBufferSizeOption: + e.mu.Lock() + v := e.sndBufSize + e.mu.Unlock() + return v, nil + + case tcpip.ReceiveBufferSizeOption: + e.rcvMu.Lock() + v := e.rcvBufSizeMax + e.rcvMu.Unlock() + return v, nil } + return -1, tcpip.ErrUnknownProtocolOption } @@ -576,18 +596,6 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { case tcpip.ErrorOption: return nil - case *tcpip.SendBufferSizeOption: - e.mu.Lock() - *o = tcpip.SendBufferSizeOption(e.sndBufSize) - e.mu.Unlock() - return nil - - case *tcpip.ReceiveBufferSizeOption: - e.rcvMu.Lock() - *o = tcpip.ReceiveBufferSizeOption(e.rcvBufSizeMax) - e.rcvMu.Unlock() - return nil - case *tcpip.V6OnlyOption: // We only recognize this option on v6 endpoints. if e.netProto != header.IPv6ProtocolNumber { @@ -726,7 +734,7 @@ func (e *endpoint) Disconnect() *tcpip.Error { e.mu.Lock() defer e.mu.Unlock() - if e.state != stateConnected { + if e.state != StateConnected { return nil } id := stack.TransportEndpointID{} @@ -741,9 +749,13 @@ func (e *endpoint) Disconnect() *tcpip.Error { if err != nil { return err } - e.state = stateBound + e.state = StateBound } else { - e.state = stateInitial + if e.id.LocalPort != 0 { + // Release the ephemeral port. + e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.id.LocalAddress, e.id.LocalPort) + } + e.state = StateInitial } e.stack.UnregisterTransportEndpoint(e.regNICID, e.effectiveNetProtos, ProtocolNumber, e.id, e) @@ -772,8 +784,8 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { nicid := addr.NIC var localPort uint16 switch e.state { - case stateInitial: - case stateBound, stateConnected: + case StateInitial: + case StateBound, StateConnected: localPort = e.id.LocalPort if e.bindNICID == 0 { break @@ -801,7 +813,7 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { RemoteAddress: r.RemoteAddress, } - if e.state == stateInitial { + if e.state == StateInitial { id.LocalAddress = r.LocalAddress } @@ -832,7 +844,7 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { e.regNICID = nicid e.effectiveNetProtos = netProtos - e.state = stateConnected + e.state = StateConnected e.rcvMu.Lock() e.rcvReady = true @@ -854,7 +866,7 @@ func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error { // A socket in the bound state can still receive multicast messages, // so we need to notify waiters on shutdown. - if e.state != stateBound && e.state != stateConnected { + if e.state != StateBound && e.state != StateConnected { return tcpip.ErrNotConnected } @@ -903,7 +915,7 @@ func (e *endpoint) registerWithStack(nicid tcpip.NICID, netProtos []tcpip.Networ func (e *endpoint) bindLocked(addr tcpip.FullAddress) *tcpip.Error { // Don't allow binding once endpoint is not in the initial state // anymore. - if e.state != stateInitial { + if e.state != StateInitial { return tcpip.ErrInvalidEndpointState } @@ -946,7 +958,7 @@ func (e *endpoint) bindLocked(addr tcpip.FullAddress) *tcpip.Error { e.effectiveNetProtos = netProtos // Mark endpoint as bound. - e.state = stateBound + e.state = StateBound e.rcvMu.Lock() e.rcvReady = true @@ -989,7 +1001,7 @@ func (e *endpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Error) { e.mu.RLock() defer e.mu.RUnlock() - if e.state != stateConnected { + if e.state != StateConnected { return tcpip.FullAddress{}, tcpip.ErrNotConnected } @@ -1069,10 +1081,11 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, vv func (e *endpoint) HandleControlPacket(id stack.TransportEndpointID, typ stack.ControlType, extra uint32, vv buffer.VectorisedView) { } -// State implements socket.Socket.State. +// State implements tcpip.Endpoint.State. func (e *endpoint) State() uint32 { - // TODO(b/112063468): Translate internal state to values returned by Linux. - return 0 + e.mu.Lock() + defer e.mu.Unlock() + return uint32(e.state) } func isBroadcastOrMulticast(a tcpip.Address) bool { |