summaryrefslogtreecommitdiffhomepage
path: root/pkg
diff options
context:
space:
mode:
authorGhanan Gowripalan <ghanan@google.com>2021-09-14 18:35:55 -0700
committergVisor bot <gvisor-bot@google.com>2021-09-14 18:38:46 -0700
commit0bec34a8e29e1099827a025009e08aa190ff4441 (patch)
tree6e87a5813e8e55b6d171914f65c9931e2ed43104 /pkg
parent39470428dd3c5fef966cee138ae00ab8b9059983 (diff)
Compose raw IP with datagram-based endpoint
A raw IP endpoint's write and socket option get/set path can use the datagram-based endpoint. This change extracts tests from UDP that may also run on Raw IP sockets. Updates #6565. Test: Raw IP + datagram-based socket syscall tests. PiperOrigin-RevId: 396729727
Diffstat (limited to 'pkg')
-rw-r--r--pkg/tcpip/transport/internal/network/BUILD1
-rw-r--r--pkg/tcpip/transport/raw/BUILD2
-rw-r--r--pkg/tcpip/transport/raw/endpoint.go295
-rw-r--r--pkg/tcpip/transport/raw/endpoint_state.go30
4 files changed, 97 insertions, 231 deletions
diff --git a/pkg/tcpip/transport/internal/network/BUILD b/pkg/tcpip/transport/internal/network/BUILD
index d6d3f52a3..b1edce39b 100644
--- a/pkg/tcpip/transport/internal/network/BUILD
+++ b/pkg/tcpip/transport/internal/network/BUILD
@@ -9,6 +9,7 @@ go_library(
"endpoint_state.go",
],
visibility = [
+ "//pkg/tcpip/transport/raw:__pkg__",
"//pkg/tcpip/transport/udp:__pkg__",
],
deps = [
diff --git a/pkg/tcpip/transport/raw/BUILD b/pkg/tcpip/transport/raw/BUILD
index 2eab09088..b7e97e218 100644
--- a/pkg/tcpip/transport/raw/BUILD
+++ b/pkg/tcpip/transport/raw/BUILD
@@ -33,6 +33,8 @@ go_library(
"//pkg/tcpip/buffer",
"//pkg/tcpip/header",
"//pkg/tcpip/stack",
+ "//pkg/tcpip/transport",
+ "//pkg/tcpip/transport/internal/network",
"//pkg/tcpip/transport/packet",
"//pkg/waiter",
],
diff --git a/pkg/tcpip/transport/raw/endpoint.go b/pkg/tcpip/transport/raw/endpoint.go
index 264d29c7a..3040a445b 100644
--- a/pkg/tcpip/transport/raw/endpoint.go
+++ b/pkg/tcpip/transport/raw/endpoint.go
@@ -26,6 +26,7 @@
package raw
import (
+ "fmt"
"io"
"time"
@@ -34,6 +35,8 @@ import (
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/stack"
+ "gvisor.dev/gvisor/pkg/tcpip/transport"
+ "gvisor.dev/gvisor/pkg/tcpip/transport/internal/network"
"gvisor.dev/gvisor/pkg/waiter"
)
@@ -57,15 +60,19 @@ type rawPacket struct {
//
// +stateify savable
type endpoint struct {
- stack.TransportEndpointInfo
tcpip.DefaultSocketOptionsHandler
// The following fields are initialized at creation time and are
// immutable.
stack *stack.Stack `state:"manual"`
+ transProto tcpip.TransportProtocolNumber
waiterQueue *waiter.Queue
associated bool
+ net network.Endpoint
+ stats tcpip.TransportEndpointStats `state:"nosave"`
+ ops tcpip.SocketOptions
+
// The following fields are used to manage the receive queue and are
// protected by rcvMu.
rcvMu sync.Mutex `state:"nosave"`
@@ -74,20 +81,7 @@ type endpoint struct {
rcvClosed bool
// The following fields are protected by mu.
- mu sync.RWMutex `state:"nosave"`
- closed bool
- connected bool
- bound bool
- // route is the route to a remote network endpoint. It is set via
- // Connect(), and is valid only when conneted is true.
- route *stack.Route `state:"manual"`
- stats tcpip.TransportEndpointStats `state:"nosave"`
- // owner is used to get uid and gid of the packet.
- owner tcpip.PacketOwner
-
- // ops is used to get socket level options.
- ops tcpip.SocketOptions
-
+ mu sync.RWMutex `state:"nosave"`
// frozen indicates if the packets should be delivered to the endpoint
// during restore.
frozen bool
@@ -99,16 +93,9 @@ func NewEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, trans
}
func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, waiterQueue *waiter.Queue, associated bool) (tcpip.Endpoint, tcpip.Error) {
- if netProto != header.IPv4ProtocolNumber && netProto != header.IPv6ProtocolNumber {
- return nil, &tcpip.ErrUnknownProtocol{}
- }
-
e := &endpoint{
- stack: s,
- TransportEndpointInfo: stack.TransportEndpointInfo{
- NetProto: netProto,
- TransProto: transProto,
- },
+ stack: s,
+ transProto: transProto,
waiterQueue: waiterQueue,
associated: associated,
}
@@ -116,6 +103,7 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProt
e.ops.SetHeaderIncluded(!associated)
e.ops.SetSendBufferSize(32*1024, false /* notify */)
e.ops.SetReceiveBufferSize(32*1024, false /* notify */)
+ e.net.Init(s, netProto, transProto, &e.ops)
// Override with stack defaults.
var ss tcpip.SendBufferSizeOption
@@ -137,7 +125,7 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProt
return e, nil
}
- if err := e.stack.RegisterRawTransportEndpoint(e.NetProto, e.TransProto, e); err != nil {
+ if err := e.stack.RegisterRawTransportEndpoint(netProto, e.transProto, e); err != nil {
return nil, err
}
@@ -154,11 +142,17 @@ func (e *endpoint) Close() {
e.mu.Lock()
defer e.mu.Unlock()
- if e.closed || !e.associated {
+ if e.net.State() == transport.DatagramEndpointStateClosed {
+ return
+ }
+
+ e.net.Close()
+
+ if !e.associated {
return
}
- e.stack.UnregisterRawTransportEndpoint(e.NetProto, e.TransProto, e)
+ e.stack.UnregisterRawTransportEndpoint(e.net.NetProto(), e.transProto, e)
e.rcvMu.Lock()
defer e.rcvMu.Unlock()
@@ -170,15 +164,6 @@ func (e *endpoint) Close() {
e.rcvList.Remove(e.rcvList.Front())
}
- e.connected = false
-
- if e.route != nil {
- e.route.Release()
- e.route = nil
- }
-
- e.closed = true
-
e.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.ReadableEvents | waiter.WritableEvents)
}
@@ -186,9 +171,7 @@ func (e *endpoint) Close() {
func (*endpoint) ModerateRecvBuf(int) {}
func (e *endpoint) SetOwner(owner tcpip.PacketOwner) {
- e.mu.Lock()
- defer e.mu.Unlock()
- e.owner = owner
+ e.net.SetOwner(owner)
}
// Read implements tcpip.Endpoint.Read.
@@ -236,14 +219,15 @@ func (e *endpoint) Read(dst io.Writer, opts tcpip.ReadOptions) (tcpip.ReadResult
// Write implements tcpip.Endpoint.Write.
func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcpip.Error) {
+ netProto := e.net.NetProto()
// We can create, but not write to, unassociated IPv6 endpoints.
- if !e.associated && e.TransportEndpointInfo.NetProto == header.IPv6ProtocolNumber {
+ if !e.associated && netProto == header.IPv6ProtocolNumber {
return 0, &tcpip.ErrInvalidOptionValue{}
}
if opts.To != nil {
// Raw sockets do not support sending to a IPv4 address on a IPv6 endpoint.
- if e.TransportEndpointInfo.NetProto == header.IPv6ProtocolNumber && len(opts.To.Addr) != header.IPv6AddressSize {
+ if netProto == header.IPv6ProtocolNumber && len(opts.To.Addr) != header.IPv6AddressSize {
return 0, &tcpip.ErrInvalidOptionValue{}
}
}
@@ -269,79 +253,26 @@ func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcp
}
func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcpip.Error) {
- // MSG_MORE is unimplemented. This also means that MSG_EOR is a no-op.
- if opts.More {
- return 0, &tcpip.ErrInvalidOptionValue{}
- }
- payloadBytes, route, owner, err := func() ([]byte, *stack.Route, tcpip.PacketOwner, tcpip.Error) {
- e.mu.RLock()
- defer e.mu.RUnlock()
-
- if e.closed {
- return nil, nil, nil, &tcpip.ErrInvalidEndpointState{}
- }
-
- // TODO(https://gvisor.dev/issue/6538): Avoid this allocation.
- payloadBytes := make([]byte, p.Len())
- if _, err := io.ReadFull(p, payloadBytes); err != nil {
- return nil, nil, nil, &tcpip.ErrBadBuffer{}
- }
-
- // Did the user caller provide a destination? If not, use the connected
- // destination.
- if opts.To == nil {
- // If the user doesn't specify a destination, they should have
- // connected to another address.
- if !e.connected {
- return nil, nil, nil, &tcpip.ErrDestinationRequired{}
- }
-
- e.route.Acquire()
-
- return payloadBytes, e.route, e.owner, nil
- }
-
- // The caller provided a destination. Reject destination address if it
- // goes through a different NIC than the endpoint was bound to.
- nic := opts.To.NIC
- if e.bound && nic != 0 && nic != e.BindNICID {
- return nil, nil, nil, &tcpip.ErrNoRoute{}
- }
-
- // Find the route to the destination. If BindAddress is 0,
- // FindRoute will choose an appropriate source address.
- route, err := e.stack.FindRoute(nic, e.BindAddr, opts.To.Addr, e.NetProto, false)
- if err != nil {
- return nil, nil, nil, err
- }
-
- return payloadBytes, route, e.owner, nil
- }()
+ ctx, err := e.net.AcquireContextForWrite(opts)
if err != nil {
return 0, err
}
- defer route.Release()
+
+ // TODO(https://gvisor.dev/issue/6538): Avoid this allocation.
+ payloadBytes := make([]byte, p.Len())
+ if _, err := io.ReadFull(p, payloadBytes); err != nil {
+ return 0, &tcpip.ErrBadBuffer{}
+ }
pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
- ReserveHeaderBytes: int(route.MaxHeaderLength()),
+ ReserveHeaderBytes: int(ctx.PacketInfo().MaxHeaderLength),
Data: buffer.View(payloadBytes).ToVectorisedView(),
})
- pkt.Owner = owner
- if e.ops.GetHeaderIncluded() {
- if err := route.WriteHeaderIncludedPacket(pkt); err != nil {
- return 0, err
- }
- return int64(len(payloadBytes)), nil
- }
-
- if err := route.WritePacket(stack.NetworkHeaderParams{
- Protocol: e.TransProto,
- TTL: route.DefaultTTL(),
- TOS: stack.DefaultTOS,
- }, pkt); err != nil {
+ if err := ctx.WritePacket(pkt, e.ops.GetHeaderIncluded()); err != nil {
return 0, err
}
+
return int64(len(payloadBytes)), nil
}
@@ -352,66 +283,29 @@ func (*endpoint) Disconnect() tcpip.Error {
// Connect implements tcpip.Endpoint.Connect.
func (e *endpoint) Connect(addr tcpip.FullAddress) tcpip.Error {
+ netProto := e.net.NetProto()
+
// Raw sockets do not support connecting to a IPv4 address on a IPv6 endpoint.
- if e.TransportEndpointInfo.NetProto == header.IPv6ProtocolNumber && len(addr.Addr) != header.IPv6AddressSize {
+ if netProto == header.IPv6ProtocolNumber && len(addr.Addr) != header.IPv6AddressSize {
return &tcpip.ErrAddressFamilyNotSupported{}
}
- e.mu.Lock()
- defer e.mu.Unlock()
-
- if e.closed {
- return &tcpip.ErrInvalidEndpointState{}
- }
-
- nic := addr.NIC
- if e.bound {
- if e.BindNICID == 0 {
- // If we're bound, but not to a specific NIC, the NIC
- // in addr will be used. Nothing to do here.
- } else if addr.NIC == 0 {
- // If we're bound to a specific NIC, but addr doesn't
- // specify a NIC, use the bound NIC.
- nic = e.BindNICID
- } else if addr.NIC != e.BindNICID {
- // We're bound and addr specifies a NIC. They must be
- // the same.
- return &tcpip.ErrInvalidEndpointState{}
- }
- }
-
- // Find a route to the destination.
- route, err := e.stack.FindRoute(nic, "", addr.Addr, e.NetProto, false)
- if err != nil {
- return err
- }
-
- if e.associated {
- // Re-register the endpoint with the appropriate NIC.
- if err := e.stack.RegisterRawTransportEndpoint(e.NetProto, e.TransProto, e); err != nil {
- route.Release()
- return err
+ return e.net.ConnectAndThen(addr, func(_ tcpip.NetworkProtocolNumber, _, _ stack.TransportEndpointID) tcpip.Error {
+ if e.associated {
+ // Re-register the endpoint with the appropriate NIC.
+ if err := e.stack.RegisterRawTransportEndpoint(netProto, e.transProto, e); err != nil {
+ return err
+ }
+ e.stack.UnregisterRawTransportEndpoint(netProto, e.transProto, e)
}
- e.stack.UnregisterRawTransportEndpoint(e.NetProto, e.TransProto, e)
- e.RegisterNICID = nic
- }
- if e.route != nil {
- // If the endpoint was previously connected then release any previous route.
- e.route.Release()
- }
- e.route = route
- e.connected = true
-
- return nil
+ return nil
+ })
}
// Shutdown implements tcpip.Endpoint.Shutdown. It's a noop for raw sockets.
func (e *endpoint) Shutdown(tcpip.ShutdownFlags) tcpip.Error {
- e.mu.Lock()
- defer e.mu.Unlock()
-
- if !e.connected {
+ if e.net.State() != transport.DatagramEndpointStateConnected {
return &tcpip.ErrNotConnected{}
}
return nil
@@ -429,46 +323,26 @@ func (*endpoint) Accept(*tcpip.FullAddress) (tcpip.Endpoint, *waiter.Queue, tcpi
// Bind implements tcpip.Endpoint.Bind.
func (e *endpoint) Bind(addr tcpip.FullAddress) tcpip.Error {
- e.mu.Lock()
- defer e.mu.Unlock()
-
- // If a local address was specified, verify that it's valid.
- if len(addr.Addr) != 0 && e.stack.CheckLocalAddress(e.RegisterNICID, e.NetProto, addr.Addr) == 0 {
- return &tcpip.ErrBadLocalAddress{}
- }
+ return e.net.BindAndThen(addr, func(netProto tcpip.NetworkProtocolNumber, _ tcpip.Address) tcpip.Error {
+ if !e.associated {
+ return nil
+ }
- if e.associated {
// Re-register the endpoint with the appropriate NIC.
- if err := e.stack.RegisterRawTransportEndpoint(e.NetProto, e.TransProto, e); err != nil {
+ if err := e.stack.RegisterRawTransportEndpoint(netProto, e.transProto, e); err != nil {
return err
}
- e.stack.UnregisterRawTransportEndpoint(e.NetProto, e.TransProto, e)
- e.RegisterNICID = addr.NIC
- e.BindNICID = addr.NIC
- }
-
- e.BindAddr = addr.Addr
- e.bound = true
-
- return nil
+ e.stack.UnregisterRawTransportEndpoint(netProto, e.transProto, e)
+ return nil
+ })
}
// GetLocalAddress implements tcpip.Endpoint.GetLocalAddress.
func (e *endpoint) GetLocalAddress() (tcpip.FullAddress, tcpip.Error) {
- e.mu.RLock()
- defer e.mu.RUnlock()
-
- addr := e.BindAddr
- if e.connected {
- addr = e.route.LocalAddress()
- }
-
- return tcpip.FullAddress{
- NIC: e.RegisterNICID,
- Addr: addr,
- // Linux returns the protocol in the port field.
- Port: uint16(e.TransProto),
- }, nil
+ a := e.net.GetLocalAddress()
+ // Linux returns the protocol in the port field.
+ a.Port = uint16(e.transProto)
+ return a, nil
}
// GetRemoteAddress implements tcpip.Endpoint.GetRemoteAddress.
@@ -501,17 +375,17 @@ func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) tcpip.Error {
return nil
default:
- return &tcpip.ErrUnknownProtocolOption{}
+ return e.net.SetSockOpt(opt)
}
}
-func (*endpoint) SetSockOptInt(tcpip.SockOptInt, int) tcpip.Error {
- return &tcpip.ErrUnknownProtocolOption{}
+func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) tcpip.Error {
+ return e.net.SetSockOptInt(opt, v)
}
// GetSockOpt implements tcpip.Endpoint.GetSockOpt.
-func (*endpoint) GetSockOpt(tcpip.GettableSocketOption) tcpip.Error {
- return &tcpip.ErrUnknownProtocolOption{}
+func (e *endpoint) GetSockOpt(opt tcpip.GettableSocketOption) tcpip.Error {
+ return e.net.GetSockOpt(opt)
}
// GetSockOptInt implements tcpip.Endpoint.GetSockOptInt.
@@ -528,7 +402,7 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, tcpip.Error) {
return v, nil
default:
- return -1, &tcpip.ErrUnknownProtocolOption{}
+ return e.net.GetSockOptInt(opt)
}
}
@@ -561,23 +435,33 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) {
return false
}
- if e.bound {
+ srcAddr := pkt.Network().SourceAddress()
+ info := e.net.Info()
+
+ switch state := e.net.State(); state {
+ case transport.DatagramEndpointStateInitial:
+ case transport.DatagramEndpointStateConnected:
+ // If connected, only accept packets from the remote address we
+ // connected to.
+ if info.ID.RemoteAddress != srcAddr {
+ return false
+ }
+
+ // Connected sockets may also have been bound to a specific
+ // address/NIC.
+ fallthrough
+ case transport.DatagramEndpointStateBound:
// If bound to a NIC, only accept data for that NIC.
- if e.BindNICID != 0 && e.BindNICID != pkt.NICID {
+ if info.BindNICID != 0 && info.BindNICID != pkt.NICID {
return false
}
// If bound to an address, only accept data for that address.
- if e.BindAddr != "" && e.BindAddr != pkt.Network().DestinationAddress() {
+ if info.BindAddr != "" && info.BindAddr != pkt.Network().DestinationAddress() {
return false
}
- }
-
- srcAddr := pkt.Network().SourceAddress()
- // If connected, only accept packets from the remote address we
- // connected to.
- if e.connected && e.route.RemoteAddress() != srcAddr {
- return false
+ default:
+ panic(fmt.Sprintf("unhandled state = %s", state))
}
wasEmpty := e.rcvBufSize == 0
@@ -598,7 +482,7 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) {
// TODO(https://gvisor.dev/issue/6517): Avoid the copy once S/R supports
// overlapping slices.
var combinedVV buffer.VectorisedView
- if e.TransportEndpointInfo.NetProto == header.IPv4ProtocolNumber {
+ if info.NetProto == header.IPv4ProtocolNumber {
network, transport := pkt.NetworkHeader().View(), pkt.TransportHeader().View()
headers := make(buffer.View, 0, len(network)+len(transport))
headers = append(headers, network...)
@@ -631,10 +515,7 @@ func (e *endpoint) State() uint32 {
// Info returns a copy of the endpoint info.
func (e *endpoint) Info() tcpip.EndpointInfo {
- e.mu.RLock()
- // Make a copy of the endpoint info.
- ret := e.TransportEndpointInfo
- e.mu.RUnlock()
+ ret := e.net.Info()
return &ret
}
diff --git a/pkg/tcpip/transport/raw/endpoint_state.go b/pkg/tcpip/transport/raw/endpoint_state.go
index 39669b445..e74713064 100644
--- a/pkg/tcpip/transport/raw/endpoint_state.go
+++ b/pkg/tcpip/transport/raw/endpoint_state.go
@@ -15,6 +15,7 @@
package raw
import (
+ "fmt"
"time"
"gvisor.dev/gvisor/pkg/tcpip"
@@ -60,35 +61,16 @@ func (e *endpoint) beforeSave() {
// Resume implements tcpip.ResumableEndpoint.Resume.
func (e *endpoint) Resume(s *stack.Stack) {
+ e.net.Resume(s)
+
e.thaw()
e.stack = s
e.ops.InitHandler(e, e.stack, tcpip.GetStackSendBufferLimits, tcpip.GetStackReceiveBufferLimits)
- // If the endpoint is connected, re-connect.
- if e.connected {
- var err tcpip.Error
- // TODO(gvisor.dev/issue/4906): Properly restore the route with the right
- // remote address. We used to pass e.remote.RemoteAddress which was
- // effectively the empty address but since moving e.route to hold a pointer
- // to a route instead of the route by value, we pass the empty address
- // directly. Obviously this was always wrong since we should provide the
- // remote address we were connected to, to properly restore the route.
- e.route, err = e.stack.FindRoute(e.RegisterNICID, e.BindAddr, "", e.NetProto, false)
- if err != nil {
- panic(err)
- }
- }
-
- // If the endpoint is bound, re-bind.
- if e.bound {
- if e.stack.CheckLocalAddress(e.RegisterNICID, e.NetProto, e.BindAddr) == 0 {
- panic(&tcpip.ErrBadLocalAddress{})
- }
- }
-
if e.associated {
- if err := e.stack.RegisterRawTransportEndpoint(e.NetProto, e.TransProto, e); err != nil {
- panic(err)
+ netProto := e.net.NetProto()
+ if err := e.stack.RegisterRawTransportEndpoint(netProto, e.transProto, e); err != nil {
+ panic(fmt.Sprintf("e.stack.RegisterRawTransportEndpoint(%d, %d, _): %s", netProto, e.transProto, err))
}
}
}