diff options
-rw-r--r-- | pkg/tcpip/transport/internal/network/BUILD | 1 | ||||
-rw-r--r-- | pkg/tcpip/transport/raw/BUILD | 2 | ||||
-rw-r--r-- | pkg/tcpip/transport/raw/endpoint.go | 295 | ||||
-rw-r--r-- | pkg/tcpip/transport/raw/endpoint_state.go | 30 | ||||
-rw-r--r-- | test/syscalls/BUILD | 5 | ||||
-rw-r--r-- | test/syscalls/linux/BUILD | 38 | ||||
-rw-r--r-- | test/syscalls/linux/ip_socket_test_util.cc | 8 | ||||
-rw-r--r-- | test/syscalls/linux/ip_socket_test_util.h | 4 | ||||
-rw-r--r-- | test/syscalls/linux/socket_ipv4_datagram_based_socket_unbound.cc | 299 | ||||
-rw-r--r-- | test/syscalls/linux/socket_ipv4_datagram_based_socket_unbound.h | 31 | ||||
-rw-r--r-- | test/syscalls/linux/socket_ipv4_datagram_based_socket_unbound_loopback.cc | 28 | ||||
-rw-r--r-- | test/syscalls/linux/socket_ipv4_udp_unbound.cc | 272 | ||||
-rw-r--r-- | test/syscalls/linux/socket_ipv4_udp_unbound_loopback.cc | 4 |
13 files changed, 505 insertions, 512 deletions
diff --git a/pkg/tcpip/transport/internal/network/BUILD b/pkg/tcpip/transport/internal/network/BUILD index d6d3f52a3..b1edce39b 100644 --- a/pkg/tcpip/transport/internal/network/BUILD +++ b/pkg/tcpip/transport/internal/network/BUILD @@ -9,6 +9,7 @@ go_library( "endpoint_state.go", ], visibility = [ + "//pkg/tcpip/transport/raw:__pkg__", "//pkg/tcpip/transport/udp:__pkg__", ], deps = [ diff --git a/pkg/tcpip/transport/raw/BUILD b/pkg/tcpip/transport/raw/BUILD index 2eab09088..b7e97e218 100644 --- a/pkg/tcpip/transport/raw/BUILD +++ b/pkg/tcpip/transport/raw/BUILD @@ -33,6 +33,8 @@ go_library( "//pkg/tcpip/buffer", "//pkg/tcpip/header", "//pkg/tcpip/stack", + "//pkg/tcpip/transport", + "//pkg/tcpip/transport/internal/network", "//pkg/tcpip/transport/packet", "//pkg/waiter", ], diff --git a/pkg/tcpip/transport/raw/endpoint.go b/pkg/tcpip/transport/raw/endpoint.go index 264d29c7a..3040a445b 100644 --- a/pkg/tcpip/transport/raw/endpoint.go +++ b/pkg/tcpip/transport/raw/endpoint.go @@ -26,6 +26,7 @@ package raw import ( + "fmt" "io" "time" @@ -34,6 +35,8 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/stack" + "gvisor.dev/gvisor/pkg/tcpip/transport" + "gvisor.dev/gvisor/pkg/tcpip/transport/internal/network" "gvisor.dev/gvisor/pkg/waiter" ) @@ -57,15 +60,19 @@ type rawPacket struct { // // +stateify savable type endpoint struct { - stack.TransportEndpointInfo tcpip.DefaultSocketOptionsHandler // The following fields are initialized at creation time and are // immutable. stack *stack.Stack `state:"manual"` + transProto tcpip.TransportProtocolNumber waiterQueue *waiter.Queue associated bool + net network.Endpoint + stats tcpip.TransportEndpointStats `state:"nosave"` + ops tcpip.SocketOptions + // The following fields are used to manage the receive queue and are // protected by rcvMu. rcvMu sync.Mutex `state:"nosave"` @@ -74,20 +81,7 @@ type endpoint struct { rcvClosed bool // The following fields are protected by mu. - mu sync.RWMutex `state:"nosave"` - closed bool - connected bool - bound bool - // route is the route to a remote network endpoint. It is set via - // Connect(), and is valid only when conneted is true. - route *stack.Route `state:"manual"` - stats tcpip.TransportEndpointStats `state:"nosave"` - // owner is used to get uid and gid of the packet. - owner tcpip.PacketOwner - - // ops is used to get socket level options. - ops tcpip.SocketOptions - + mu sync.RWMutex `state:"nosave"` // frozen indicates if the packets should be delivered to the endpoint // during restore. frozen bool @@ -99,16 +93,9 @@ func NewEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, trans } func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, waiterQueue *waiter.Queue, associated bool) (tcpip.Endpoint, tcpip.Error) { - if netProto != header.IPv4ProtocolNumber && netProto != header.IPv6ProtocolNumber { - return nil, &tcpip.ErrUnknownProtocol{} - } - e := &endpoint{ - stack: s, - TransportEndpointInfo: stack.TransportEndpointInfo{ - NetProto: netProto, - TransProto: transProto, - }, + stack: s, + transProto: transProto, waiterQueue: waiterQueue, associated: associated, } @@ -116,6 +103,7 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProt e.ops.SetHeaderIncluded(!associated) e.ops.SetSendBufferSize(32*1024, false /* notify */) e.ops.SetReceiveBufferSize(32*1024, false /* notify */) + e.net.Init(s, netProto, transProto, &e.ops) // Override with stack defaults. var ss tcpip.SendBufferSizeOption @@ -137,7 +125,7 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProt return e, nil } - if err := e.stack.RegisterRawTransportEndpoint(e.NetProto, e.TransProto, e); err != nil { + if err := e.stack.RegisterRawTransportEndpoint(netProto, e.transProto, e); err != nil { return nil, err } @@ -154,11 +142,17 @@ func (e *endpoint) Close() { e.mu.Lock() defer e.mu.Unlock() - if e.closed || !e.associated { + if e.net.State() == transport.DatagramEndpointStateClosed { + return + } + + e.net.Close() + + if !e.associated { return } - e.stack.UnregisterRawTransportEndpoint(e.NetProto, e.TransProto, e) + e.stack.UnregisterRawTransportEndpoint(e.net.NetProto(), e.transProto, e) e.rcvMu.Lock() defer e.rcvMu.Unlock() @@ -170,15 +164,6 @@ func (e *endpoint) Close() { e.rcvList.Remove(e.rcvList.Front()) } - e.connected = false - - if e.route != nil { - e.route.Release() - e.route = nil - } - - e.closed = true - e.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.ReadableEvents | waiter.WritableEvents) } @@ -186,9 +171,7 @@ func (e *endpoint) Close() { func (*endpoint) ModerateRecvBuf(int) {} func (e *endpoint) SetOwner(owner tcpip.PacketOwner) { - e.mu.Lock() - defer e.mu.Unlock() - e.owner = owner + e.net.SetOwner(owner) } // Read implements tcpip.Endpoint.Read. @@ -236,14 +219,15 @@ func (e *endpoint) Read(dst io.Writer, opts tcpip.ReadOptions) (tcpip.ReadResult // Write implements tcpip.Endpoint.Write. func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcpip.Error) { + netProto := e.net.NetProto() // We can create, but not write to, unassociated IPv6 endpoints. - if !e.associated && e.TransportEndpointInfo.NetProto == header.IPv6ProtocolNumber { + if !e.associated && netProto == header.IPv6ProtocolNumber { return 0, &tcpip.ErrInvalidOptionValue{} } if opts.To != nil { // Raw sockets do not support sending to a IPv4 address on a IPv6 endpoint. - if e.TransportEndpointInfo.NetProto == header.IPv6ProtocolNumber && len(opts.To.Addr) != header.IPv6AddressSize { + if netProto == header.IPv6ProtocolNumber && len(opts.To.Addr) != header.IPv6AddressSize { return 0, &tcpip.ErrInvalidOptionValue{} } } @@ -269,79 +253,26 @@ func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcp } func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcpip.Error) { - // MSG_MORE is unimplemented. This also means that MSG_EOR is a no-op. - if opts.More { - return 0, &tcpip.ErrInvalidOptionValue{} - } - payloadBytes, route, owner, err := func() ([]byte, *stack.Route, tcpip.PacketOwner, tcpip.Error) { - e.mu.RLock() - defer e.mu.RUnlock() - - if e.closed { - return nil, nil, nil, &tcpip.ErrInvalidEndpointState{} - } - - // TODO(https://gvisor.dev/issue/6538): Avoid this allocation. - payloadBytes := make([]byte, p.Len()) - if _, err := io.ReadFull(p, payloadBytes); err != nil { - return nil, nil, nil, &tcpip.ErrBadBuffer{} - } - - // Did the user caller provide a destination? If not, use the connected - // destination. - if opts.To == nil { - // If the user doesn't specify a destination, they should have - // connected to another address. - if !e.connected { - return nil, nil, nil, &tcpip.ErrDestinationRequired{} - } - - e.route.Acquire() - - return payloadBytes, e.route, e.owner, nil - } - - // The caller provided a destination. Reject destination address if it - // goes through a different NIC than the endpoint was bound to. - nic := opts.To.NIC - if e.bound && nic != 0 && nic != e.BindNICID { - return nil, nil, nil, &tcpip.ErrNoRoute{} - } - - // Find the route to the destination. If BindAddress is 0, - // FindRoute will choose an appropriate source address. - route, err := e.stack.FindRoute(nic, e.BindAddr, opts.To.Addr, e.NetProto, false) - if err != nil { - return nil, nil, nil, err - } - - return payloadBytes, route, e.owner, nil - }() + ctx, err := e.net.AcquireContextForWrite(opts) if err != nil { return 0, err } - defer route.Release() + + // TODO(https://gvisor.dev/issue/6538): Avoid this allocation. + payloadBytes := make([]byte, p.Len()) + if _, err := io.ReadFull(p, payloadBytes); err != nil { + return 0, &tcpip.ErrBadBuffer{} + } pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ - ReserveHeaderBytes: int(route.MaxHeaderLength()), + ReserveHeaderBytes: int(ctx.PacketInfo().MaxHeaderLength), Data: buffer.View(payloadBytes).ToVectorisedView(), }) - pkt.Owner = owner - if e.ops.GetHeaderIncluded() { - if err := route.WriteHeaderIncludedPacket(pkt); err != nil { - return 0, err - } - return int64(len(payloadBytes)), nil - } - - if err := route.WritePacket(stack.NetworkHeaderParams{ - Protocol: e.TransProto, - TTL: route.DefaultTTL(), - TOS: stack.DefaultTOS, - }, pkt); err != nil { + if err := ctx.WritePacket(pkt, e.ops.GetHeaderIncluded()); err != nil { return 0, err } + return int64(len(payloadBytes)), nil } @@ -352,66 +283,29 @@ func (*endpoint) Disconnect() tcpip.Error { // Connect implements tcpip.Endpoint.Connect. func (e *endpoint) Connect(addr tcpip.FullAddress) tcpip.Error { + netProto := e.net.NetProto() + // Raw sockets do not support connecting to a IPv4 address on a IPv6 endpoint. - if e.TransportEndpointInfo.NetProto == header.IPv6ProtocolNumber && len(addr.Addr) != header.IPv6AddressSize { + if netProto == header.IPv6ProtocolNumber && len(addr.Addr) != header.IPv6AddressSize { return &tcpip.ErrAddressFamilyNotSupported{} } - e.mu.Lock() - defer e.mu.Unlock() - - if e.closed { - return &tcpip.ErrInvalidEndpointState{} - } - - nic := addr.NIC - if e.bound { - if e.BindNICID == 0 { - // If we're bound, but not to a specific NIC, the NIC - // in addr will be used. Nothing to do here. - } else if addr.NIC == 0 { - // If we're bound to a specific NIC, but addr doesn't - // specify a NIC, use the bound NIC. - nic = e.BindNICID - } else if addr.NIC != e.BindNICID { - // We're bound and addr specifies a NIC. They must be - // the same. - return &tcpip.ErrInvalidEndpointState{} - } - } - - // Find a route to the destination. - route, err := e.stack.FindRoute(nic, "", addr.Addr, e.NetProto, false) - if err != nil { - return err - } - - if e.associated { - // Re-register the endpoint with the appropriate NIC. - if err := e.stack.RegisterRawTransportEndpoint(e.NetProto, e.TransProto, e); err != nil { - route.Release() - return err + return e.net.ConnectAndThen(addr, func(_ tcpip.NetworkProtocolNumber, _, _ stack.TransportEndpointID) tcpip.Error { + if e.associated { + // Re-register the endpoint with the appropriate NIC. + if err := e.stack.RegisterRawTransportEndpoint(netProto, e.transProto, e); err != nil { + return err + } + e.stack.UnregisterRawTransportEndpoint(netProto, e.transProto, e) } - e.stack.UnregisterRawTransportEndpoint(e.NetProto, e.TransProto, e) - e.RegisterNICID = nic - } - if e.route != nil { - // If the endpoint was previously connected then release any previous route. - e.route.Release() - } - e.route = route - e.connected = true - - return nil + return nil + }) } // Shutdown implements tcpip.Endpoint.Shutdown. It's a noop for raw sockets. func (e *endpoint) Shutdown(tcpip.ShutdownFlags) tcpip.Error { - e.mu.Lock() - defer e.mu.Unlock() - - if !e.connected { + if e.net.State() != transport.DatagramEndpointStateConnected { return &tcpip.ErrNotConnected{} } return nil @@ -429,46 +323,26 @@ func (*endpoint) Accept(*tcpip.FullAddress) (tcpip.Endpoint, *waiter.Queue, tcpi // Bind implements tcpip.Endpoint.Bind. func (e *endpoint) Bind(addr tcpip.FullAddress) tcpip.Error { - e.mu.Lock() - defer e.mu.Unlock() - - // If a local address was specified, verify that it's valid. - if len(addr.Addr) != 0 && e.stack.CheckLocalAddress(e.RegisterNICID, e.NetProto, addr.Addr) == 0 { - return &tcpip.ErrBadLocalAddress{} - } + return e.net.BindAndThen(addr, func(netProto tcpip.NetworkProtocolNumber, _ tcpip.Address) tcpip.Error { + if !e.associated { + return nil + } - if e.associated { // Re-register the endpoint with the appropriate NIC. - if err := e.stack.RegisterRawTransportEndpoint(e.NetProto, e.TransProto, e); err != nil { + if err := e.stack.RegisterRawTransportEndpoint(netProto, e.transProto, e); err != nil { return err } - e.stack.UnregisterRawTransportEndpoint(e.NetProto, e.TransProto, e) - e.RegisterNICID = addr.NIC - e.BindNICID = addr.NIC - } - - e.BindAddr = addr.Addr - e.bound = true - - return nil + e.stack.UnregisterRawTransportEndpoint(netProto, e.transProto, e) + return nil + }) } // GetLocalAddress implements tcpip.Endpoint.GetLocalAddress. func (e *endpoint) GetLocalAddress() (tcpip.FullAddress, tcpip.Error) { - e.mu.RLock() - defer e.mu.RUnlock() - - addr := e.BindAddr - if e.connected { - addr = e.route.LocalAddress() - } - - return tcpip.FullAddress{ - NIC: e.RegisterNICID, - Addr: addr, - // Linux returns the protocol in the port field. - Port: uint16(e.TransProto), - }, nil + a := e.net.GetLocalAddress() + // Linux returns the protocol in the port field. + a.Port = uint16(e.transProto) + return a, nil } // GetRemoteAddress implements tcpip.Endpoint.GetRemoteAddress. @@ -501,17 +375,17 @@ func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) tcpip.Error { return nil default: - return &tcpip.ErrUnknownProtocolOption{} + return e.net.SetSockOpt(opt) } } -func (*endpoint) SetSockOptInt(tcpip.SockOptInt, int) tcpip.Error { - return &tcpip.ErrUnknownProtocolOption{} +func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) tcpip.Error { + return e.net.SetSockOptInt(opt, v) } // GetSockOpt implements tcpip.Endpoint.GetSockOpt. -func (*endpoint) GetSockOpt(tcpip.GettableSocketOption) tcpip.Error { - return &tcpip.ErrUnknownProtocolOption{} +func (e *endpoint) GetSockOpt(opt tcpip.GettableSocketOption) tcpip.Error { + return e.net.GetSockOpt(opt) } // GetSockOptInt implements tcpip.Endpoint.GetSockOptInt. @@ -528,7 +402,7 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, tcpip.Error) { return v, nil default: - return -1, &tcpip.ErrUnknownProtocolOption{} + return e.net.GetSockOptInt(opt) } } @@ -561,23 +435,33 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) { return false } - if e.bound { + srcAddr := pkt.Network().SourceAddress() + info := e.net.Info() + + switch state := e.net.State(); state { + case transport.DatagramEndpointStateInitial: + case transport.DatagramEndpointStateConnected: + // If connected, only accept packets from the remote address we + // connected to. + if info.ID.RemoteAddress != srcAddr { + return false + } + + // Connected sockets may also have been bound to a specific + // address/NIC. + fallthrough + case transport.DatagramEndpointStateBound: // If bound to a NIC, only accept data for that NIC. - if e.BindNICID != 0 && e.BindNICID != pkt.NICID { + if info.BindNICID != 0 && info.BindNICID != pkt.NICID { return false } // If bound to an address, only accept data for that address. - if e.BindAddr != "" && e.BindAddr != pkt.Network().DestinationAddress() { + if info.BindAddr != "" && info.BindAddr != pkt.Network().DestinationAddress() { return false } - } - - srcAddr := pkt.Network().SourceAddress() - // If connected, only accept packets from the remote address we - // connected to. - if e.connected && e.route.RemoteAddress() != srcAddr { - return false + default: + panic(fmt.Sprintf("unhandled state = %s", state)) } wasEmpty := e.rcvBufSize == 0 @@ -598,7 +482,7 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) { // TODO(https://gvisor.dev/issue/6517): Avoid the copy once S/R supports // overlapping slices. var combinedVV buffer.VectorisedView - if e.TransportEndpointInfo.NetProto == header.IPv4ProtocolNumber { + if info.NetProto == header.IPv4ProtocolNumber { network, transport := pkt.NetworkHeader().View(), pkt.TransportHeader().View() headers := make(buffer.View, 0, len(network)+len(transport)) headers = append(headers, network...) @@ -631,10 +515,7 @@ func (e *endpoint) State() uint32 { // Info returns a copy of the endpoint info. func (e *endpoint) Info() tcpip.EndpointInfo { - e.mu.RLock() - // Make a copy of the endpoint info. - ret := e.TransportEndpointInfo - e.mu.RUnlock() + ret := e.net.Info() return &ret } diff --git a/pkg/tcpip/transport/raw/endpoint_state.go b/pkg/tcpip/transport/raw/endpoint_state.go index 39669b445..e74713064 100644 --- a/pkg/tcpip/transport/raw/endpoint_state.go +++ b/pkg/tcpip/transport/raw/endpoint_state.go @@ -15,6 +15,7 @@ package raw import ( + "fmt" "time" "gvisor.dev/gvisor/pkg/tcpip" @@ -60,35 +61,16 @@ func (e *endpoint) beforeSave() { // Resume implements tcpip.ResumableEndpoint.Resume. func (e *endpoint) Resume(s *stack.Stack) { + e.net.Resume(s) + e.thaw() e.stack = s e.ops.InitHandler(e, e.stack, tcpip.GetStackSendBufferLimits, tcpip.GetStackReceiveBufferLimits) - // If the endpoint is connected, re-connect. - if e.connected { - var err tcpip.Error - // TODO(gvisor.dev/issue/4906): Properly restore the route with the right - // remote address. We used to pass e.remote.RemoteAddress which was - // effectively the empty address but since moving e.route to hold a pointer - // to a route instead of the route by value, we pass the empty address - // directly. Obviously this was always wrong since we should provide the - // remote address we were connected to, to properly restore the route. - e.route, err = e.stack.FindRoute(e.RegisterNICID, e.BindAddr, "", e.NetProto, false) - if err != nil { - panic(err) - } - } - - // If the endpoint is bound, re-bind. - if e.bound { - if e.stack.CheckLocalAddress(e.RegisterNICID, e.NetProto, e.BindAddr) == 0 { - panic(&tcpip.ErrBadLocalAddress{}) - } - } - if e.associated { - if err := e.stack.RegisterRawTransportEndpoint(e.NetProto, e.TransProto, e); err != nil { - panic(err) + netProto := e.net.NetProto() + if err := e.stack.RegisterRawTransportEndpoint(netProto, e.transProto, e); err != nil { + panic(fmt.Sprintf("e.stack.RegisterRawTransportEndpoint(%d, %d, _): %s", netProto, e.transProto, err)) } } } diff --git a/test/syscalls/BUILD b/test/syscalls/BUILD index 16c451786..3b55112e6 100644 --- a/test/syscalls/BUILD +++ b/test/syscalls/BUILD @@ -707,6 +707,11 @@ syscall_test( syscall_test( size = "medium", + test = "//test/syscalls/linux:socket_ipv4_datagram_based_socket_unbound_loopback_test", +) + +syscall_test( + size = "medium", test = "//test/syscalls/linux:socket_ipv4_udp_unbound_loopback_test", ) diff --git a/test/syscalls/linux/BUILD b/test/syscalls/linux/BUILD index 190135d6c..a4efd9486 100644 --- a/test/syscalls/linux/BUILD +++ b/test/syscalls/linux/BUILD @@ -26,6 +26,7 @@ exports_files( "socket_ip_udp_loopback_blocking.cc", "socket_ip_udp_loopback_nonblock.cc", "socket_ip_unbound.cc", + "socket_ipv4_datagram_based_socket_unbound_loopback.cc", "socket_ipv4_udp_unbound_external_networking_test.cc", "socket_ipv6_udp_unbound_external_networking_test.cc", "socket_ipv4_udp_unbound_loopback.cc", @@ -2534,6 +2535,24 @@ cc_library( ) cc_library( + name = "socket_ipv4_datagram_based_socket_unbound_test_cases", + testonly = 1, + srcs = [ + "socket_ipv4_datagram_based_socket_unbound.cc", + ], + hdrs = [ + "socket_ipv4_datagram_based_socket_unbound.h", + ], + deps = [ + ":ip_socket_test_util", + "//test/util:capability_util", + "//test/util:socket_util", + gtest, + ], + alwayslink = 1, +) + +cc_library( name = "socket_ipv4_udp_unbound_test_cases", testonly = 1, srcs = [ @@ -2547,9 +2566,6 @@ cc_library( "//test/util:socket_util", "@com_google_absl//absl/memory", gtest, - "//test/util:posix_error", - "//test/util:save_util", - "//test/util:test_util", ], alwayslink = 1, ) @@ -2956,9 +2972,21 @@ cc_binary( deps = [ ":ip_socket_test_util", ":socket_ipv4_udp_unbound_test_cases", - "//test/util:socket_util", "//test/util:test_main", - "//test/util:test_util", + ], +) + +cc_binary( + name = "socket_ipv4_datagram_based_socket_unbound_loopback_test", + testonly = 1, + srcs = [ + "socket_ipv4_datagram_based_socket_unbound_loopback.cc", + ], + linkstatic = 1, + deps = [ + ":ip_socket_test_util", + ":socket_ipv4_datagram_based_socket_unbound_test_cases", + "//test/util:test_main", ], ) diff --git a/test/syscalls/linux/ip_socket_test_util.cc b/test/syscalls/linux/ip_socket_test_util.cc index e90a7e411..a1216d23f 100644 --- a/test/syscalls/linux/ip_socket_test_util.cc +++ b/test/syscalls/linux/ip_socket_test_util.cc @@ -156,6 +156,14 @@ SocketKind ICMPv6UnboundSocket(int type) { UnboundSocketCreator(AF_INET6, type | SOCK_DGRAM, IPPROTO_ICMPV6)}; } +SocketKind IPv4RawUDPUnboundSocket(int type) { + std::string description = + absl::StrCat(DescribeSocketType(type), "IPv4 Raw UDP socket"); + return SocketKind{ + description, AF_INET, type | SOCK_RAW, IPPROTO_UDP, + UnboundSocketCreator(AF_INET, type | SOCK_RAW, IPPROTO_UDP)}; +} + SocketKind IPv4UDPUnboundSocket(int type) { std::string description = absl::StrCat(DescribeSocketType(type), "IPv4 UDP socket"); diff --git a/test/syscalls/linux/ip_socket_test_util.h b/test/syscalls/linux/ip_socket_test_util.h index 930e26ca5..556838356 100644 --- a/test/syscalls/linux/ip_socket_test_util.h +++ b/test/syscalls/linux/ip_socket_test_util.h @@ -95,6 +95,10 @@ SocketKind ICMPUnboundSocket(int type); // created with AF_INET6, SOCK_DGRAM, IPPROTO_ICMPV6, and the given type. SocketKind ICMPv6UnboundSocket(int type); +// IPv4RawUDPUnboundSocket returns a SocketKind that represents a SimpleSocket +// created with AF_INET, SOCK_RAW, IPPROTO_UDP, and the given type. +SocketKind IPv4RawUDPUnboundSocket(int type); + // IPv4UDPUnboundSocket returns a SocketKind that represents a SimpleSocket // created with AF_INET, SOCK_DGRAM, IPPROTO_UDP, and the given type. SocketKind IPv4UDPUnboundSocket(int type); diff --git a/test/syscalls/linux/socket_ipv4_datagram_based_socket_unbound.cc b/test/syscalls/linux/socket_ipv4_datagram_based_socket_unbound.cc new file mode 100644 index 000000000..0e3e7057b --- /dev/null +++ b/test/syscalls/linux/socket_ipv4_datagram_based_socket_unbound.cc @@ -0,0 +1,299 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "test/syscalls/linux/socket_ipv4_datagram_based_socket_unbound.h" + +#include "gtest/gtest.h" +#include "test/syscalls/linux/ip_socket_test_util.h" +#include "test/util/capability_util.h" + +namespace gvisor { +namespace testing { + +void IPv4DatagramBasedUnboundSocketTest::SetUp() { + if (GetParam().type & SOCK_RAW) { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveRawIPSocketCapability())); + } +} + +// Check that dropping a group membership that does not exist fails. +TEST_P(IPv4DatagramBasedUnboundSocketTest, IpMulticastInvalidDrop) { + auto socket1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); + auto socket2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); + + // Unregister from a membership that we didn't have. + ip_mreq group = {}; + group.imr_multiaddr.s_addr = inet_addr(kMulticastAddress); + group.imr_interface.s_addr = htonl(INADDR_LOOPBACK); + EXPECT_THAT(setsockopt(socket1->get(), IPPROTO_IP, IP_DROP_MEMBERSHIP, &group, + sizeof(group)), + SyscallFailsWithErrno(EADDRNOTAVAIL)); +} + +TEST_P(IPv4DatagramBasedUnboundSocketTest, IpMulticastIfZero) { + auto socket1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); + auto socket2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); + + ip_mreqn iface = {}; + EXPECT_THAT(setsockopt(socket1->get(), IPPROTO_IP, IP_MULTICAST_IF, &iface, + sizeof(iface)), + SyscallSucceeds()); +} + +TEST_P(IPv4DatagramBasedUnboundSocketTest, IpMulticastIfInvalidNic) { + auto socket1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); + auto socket2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); + + ip_mreqn iface = {}; + iface.imr_ifindex = -1; + EXPECT_THAT(setsockopt(socket1->get(), IPPROTO_IP, IP_MULTICAST_IF, &iface, + sizeof(iface)), + SyscallFailsWithErrno(EADDRNOTAVAIL)); +} + +TEST_P(IPv4DatagramBasedUnboundSocketTest, IpMulticastIfInvalidAddr) { + auto socket1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); + auto socket2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); + + ip_mreq iface = {}; + iface.imr_interface.s_addr = inet_addr("255.255.255"); + EXPECT_THAT(setsockopt(socket1->get(), IPPROTO_IP, IP_MULTICAST_IF, &iface, + sizeof(iface)), + SyscallFailsWithErrno(EADDRNOTAVAIL)); +} + +TEST_P(IPv4DatagramBasedUnboundSocketTest, IpMulticastIfSetShort) { + auto socket1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); + auto socket2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); + + // Create a valid full-sized request. + ip_mreqn iface = {}; + iface.imr_ifindex = ASSERT_NO_ERRNO_AND_VALUE(GetLoopbackIndex()); + + // Send an optlen of 1 to check that optlen is enforced. + EXPECT_THAT( + setsockopt(socket1->get(), IPPROTO_IP, IP_MULTICAST_IF, &iface, 1), + SyscallFailsWithErrno(EINVAL)); +} + +TEST_P(IPv4DatagramBasedUnboundSocketTest, IpMulticastIfDefault) { + auto socket1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); + auto socket2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); + + in_addr get = {}; + socklen_t size = sizeof(get); + ASSERT_THAT( + getsockopt(socket1->get(), IPPROTO_IP, IP_MULTICAST_IF, &get, &size), + SyscallSucceeds()); + EXPECT_EQ(size, sizeof(get)); + EXPECT_EQ(get.s_addr, 0); +} + +TEST_P(IPv4DatagramBasedUnboundSocketTest, IpMulticastIfDefaultReqn) { + auto socket1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); + auto socket2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); + + ip_mreqn get = {}; + socklen_t size = sizeof(get); + ASSERT_THAT( + getsockopt(socket1->get(), IPPROTO_IP, IP_MULTICAST_IF, &get, &size), + SyscallSucceeds()); + + // getsockopt(IP_MULTICAST_IF) can only return an in_addr, so it treats the + // first sizeof(struct in_addr) bytes of struct ip_mreqn as a struct in_addr. + // Conveniently, this corresponds to the field ip_mreqn::imr_multiaddr. + EXPECT_EQ(size, sizeof(in_addr)); + + // getsockopt(IP_MULTICAST_IF) will only return the interface address which + // hasn't been set. + EXPECT_EQ(get.imr_multiaddr.s_addr, 0); + EXPECT_EQ(get.imr_address.s_addr, 0); + EXPECT_EQ(get.imr_ifindex, 0); +} + +TEST_P(IPv4DatagramBasedUnboundSocketTest, IpMulticastIfSetAddrGetReqn) { + auto socket1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); + auto socket2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); + + in_addr set = {}; + set.s_addr = htonl(INADDR_LOOPBACK); + ASSERT_THAT(setsockopt(socket1->get(), IPPROTO_IP, IP_MULTICAST_IF, &set, + sizeof(set)), + SyscallSucceeds()); + + ip_mreqn get = {}; + socklen_t size = sizeof(get); + ASSERT_THAT( + getsockopt(socket1->get(), IPPROTO_IP, IP_MULTICAST_IF, &get, &size), + SyscallSucceeds()); + + // getsockopt(IP_MULTICAST_IF) can only return an in_addr, so it treats the + // first sizeof(struct in_addr) bytes of struct ip_mreqn as a struct in_addr. + // Conveniently, this corresponds to the field ip_mreqn::imr_multiaddr. + EXPECT_EQ(size, sizeof(in_addr)); + EXPECT_EQ(get.imr_multiaddr.s_addr, set.s_addr); + EXPECT_EQ(get.imr_address.s_addr, 0); + EXPECT_EQ(get.imr_ifindex, 0); +} + +TEST_P(IPv4DatagramBasedUnboundSocketTest, IpMulticastIfSetReqAddrGetReqn) { + auto socket1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); + auto socket2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); + + ip_mreq set = {}; + set.imr_interface.s_addr = htonl(INADDR_LOOPBACK); + ASSERT_THAT(setsockopt(socket1->get(), IPPROTO_IP, IP_MULTICAST_IF, &set, + sizeof(set)), + SyscallSucceeds()); + + ip_mreqn get = {}; + socklen_t size = sizeof(get); + ASSERT_THAT( + getsockopt(socket1->get(), IPPROTO_IP, IP_MULTICAST_IF, &get, &size), + SyscallSucceeds()); + + // getsockopt(IP_MULTICAST_IF) can only return an in_addr, so it treats the + // first sizeof(struct in_addr) bytes of struct ip_mreqn as a struct in_addr. + // Conveniently, this corresponds to the field ip_mreqn::imr_multiaddr. + EXPECT_EQ(size, sizeof(in_addr)); + EXPECT_EQ(get.imr_multiaddr.s_addr, set.imr_interface.s_addr); + EXPECT_EQ(get.imr_address.s_addr, 0); + EXPECT_EQ(get.imr_ifindex, 0); +} + +TEST_P(IPv4DatagramBasedUnboundSocketTest, IpMulticastIfSetNicGetReqn) { + auto socket1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); + auto socket2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); + + ip_mreqn set = {}; + set.imr_ifindex = ASSERT_NO_ERRNO_AND_VALUE(GetLoopbackIndex()); + ASSERT_THAT(setsockopt(socket1->get(), IPPROTO_IP, IP_MULTICAST_IF, &set, + sizeof(set)), + SyscallSucceeds()); + + ip_mreqn get = {}; + socklen_t size = sizeof(get); + ASSERT_THAT( + getsockopt(socket1->get(), IPPROTO_IP, IP_MULTICAST_IF, &get, &size), + SyscallSucceeds()); + EXPECT_EQ(size, sizeof(in_addr)); + EXPECT_EQ(get.imr_multiaddr.s_addr, 0); + EXPECT_EQ(get.imr_address.s_addr, 0); + EXPECT_EQ(get.imr_ifindex, 0); +} + +TEST_P(IPv4DatagramBasedUnboundSocketTest, IpMulticastIfSetAddr) { + auto socket1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); + auto socket2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); + + in_addr set = {}; + set.s_addr = htonl(INADDR_LOOPBACK); + ASSERT_THAT(setsockopt(socket1->get(), IPPROTO_IP, IP_MULTICAST_IF, &set, + sizeof(set)), + SyscallSucceeds()); + + in_addr get = {}; + socklen_t size = sizeof(get); + ASSERT_THAT( + getsockopt(socket1->get(), IPPROTO_IP, IP_MULTICAST_IF, &get, &size), + SyscallSucceeds()); + + EXPECT_EQ(size, sizeof(get)); + EXPECT_EQ(get.s_addr, set.s_addr); +} + +TEST_P(IPv4DatagramBasedUnboundSocketTest, IpMulticastIfSetReqAddr) { + auto socket1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); + auto socket2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); + + ip_mreq set = {}; + set.imr_interface.s_addr = htonl(INADDR_LOOPBACK); + ASSERT_THAT(setsockopt(socket1->get(), IPPROTO_IP, IP_MULTICAST_IF, &set, + sizeof(set)), + SyscallSucceeds()); + + in_addr get = {}; + socklen_t size = sizeof(get); + ASSERT_THAT( + getsockopt(socket1->get(), IPPROTO_IP, IP_MULTICAST_IF, &get, &size), + SyscallSucceeds()); + + EXPECT_EQ(size, sizeof(get)); + EXPECT_EQ(get.s_addr, set.imr_interface.s_addr); +} + +TEST_P(IPv4DatagramBasedUnboundSocketTest, IpMulticastIfSetNic) { + auto socket1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); + auto socket2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); + + ip_mreqn set = {}; + set.imr_ifindex = ASSERT_NO_ERRNO_AND_VALUE(GetLoopbackIndex()); + ASSERT_THAT(setsockopt(socket1->get(), IPPROTO_IP, IP_MULTICAST_IF, &set, + sizeof(set)), + SyscallSucceeds()); + + in_addr get = {}; + socklen_t size = sizeof(get); + ASSERT_THAT( + getsockopt(socket1->get(), IPPROTO_IP, IP_MULTICAST_IF, &get, &size), + SyscallSucceeds()); + EXPECT_EQ(size, sizeof(get)); + EXPECT_EQ(get.s_addr, 0); +} + +TEST_P(IPv4DatagramBasedUnboundSocketTest, TestJoinGroupNoIf) { + // TODO(b/185517803): Fix for native test. + SKIP_IF(!IsRunningOnGvisor()); + auto socket1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); + auto socket2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); + + ip_mreqn group = {}; + group.imr_multiaddr.s_addr = inet_addr(kMulticastAddress); + EXPECT_THAT(setsockopt(socket1->get(), IPPROTO_IP, IP_ADD_MEMBERSHIP, &group, + sizeof(group)), + SyscallFailsWithErrno(ENODEV)); +} + +TEST_P(IPv4DatagramBasedUnboundSocketTest, TestJoinGroupInvalidIf) { + auto socket1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); + auto socket2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); + + ip_mreqn group = {}; + group.imr_address.s_addr = inet_addr("255.255.255"); + group.imr_multiaddr.s_addr = inet_addr(kMulticastAddress); + EXPECT_THAT(setsockopt(socket1->get(), IPPROTO_IP, IP_ADD_MEMBERSHIP, &group, + sizeof(group)), + SyscallFailsWithErrno(ENODEV)); +} + +// Check that multiple memberships are not allowed on the same socket. +TEST_P(IPv4DatagramBasedUnboundSocketTest, TestMultipleJoinsOnSingleSocket) { + auto socket1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); + auto socket2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); + auto fd = socket1->get(); + ip_mreqn group = {}; + group.imr_multiaddr.s_addr = inet_addr(kMulticastAddress); + group.imr_ifindex = ASSERT_NO_ERRNO_AND_VALUE(GetLoopbackIndex()); + + EXPECT_THAT( + setsockopt(fd, IPPROTO_IP, IP_ADD_MEMBERSHIP, &group, sizeof(group)), + SyscallSucceeds()); + + EXPECT_THAT( + setsockopt(fd, IPPROTO_IP, IP_ADD_MEMBERSHIP, &group, sizeof(group)), + SyscallFailsWithErrno(EADDRINUSE)); +} + +} // namespace testing +} // namespace gvisor diff --git a/test/syscalls/linux/socket_ipv4_datagram_based_socket_unbound.h b/test/syscalls/linux/socket_ipv4_datagram_based_socket_unbound.h new file mode 100644 index 000000000..7ab235cad --- /dev/null +++ b/test/syscalls/linux/socket_ipv4_datagram_based_socket_unbound.h @@ -0,0 +1,31 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef GVISOR_TEST_SYSCALLS_LINUX_SOCKET_IPV4_DATAGRAM_BASED_SOCKET_UNBOUND_H_ +#define GVISOR_TEST_SYSCALLS_LINUX_SOCKET_IPV4_DATAGRAM_BASED_SOCKET_UNBOUND_H_ + +#include "test/util/socket_util.h" + +namespace gvisor { +namespace testing { + +// Test fixture for tests that apply to IPv4 datagram-based sockets. +class IPv4DatagramBasedUnboundSocketTest : public SimpleSocketTest { + void SetUp() override; +}; + +} // namespace testing +} // namespace gvisor + +#endif // GVISOR_TEST_SYSCALLS_LINUX_SOCKET_IPV4_DATAGRAM_BASED_SOCKET_UNBOUND_H_ diff --git a/test/syscalls/linux/socket_ipv4_datagram_based_socket_unbound_loopback.cc b/test/syscalls/linux/socket_ipv4_datagram_based_socket_unbound_loopback.cc new file mode 100644 index 000000000..9ae68e015 --- /dev/null +++ b/test/syscalls/linux/socket_ipv4_datagram_based_socket_unbound_loopback.cc @@ -0,0 +1,28 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "test/syscalls/linux/ip_socket_test_util.h" +#include "test/syscalls/linux/socket_ipv4_datagram_based_socket_unbound.h" + +namespace gvisor { +namespace testing { + +INSTANTIATE_TEST_SUITE_P(IPv4Sockets, IPv4DatagramBasedUnboundSocketTest, + ::testing::ValuesIn(ApplyVecToVec<SocketKind>( + {IPv4UDPUnboundSocket, IPv4RawUDPUnboundSocket}, + AllBitwiseCombinations(List<int>{ + 0, SOCK_NONBLOCK})))); + +} // namespace testing +} // namespace gvisor diff --git a/test/syscalls/linux/socket_ipv4_udp_unbound.cc b/test/syscalls/linux/socket_ipv4_udp_unbound.cc index a4a793c86..05773a29c 100644 --- a/test/syscalls/linux/socket_ipv4_udp_unbound.cc +++ b/test/syscalls/linux/socket_ipv4_udp_unbound.cc @@ -24,10 +24,6 @@ #include "gtest/gtest.h" #include "absl/memory/memory.h" #include "test/syscalls/linux/ip_socket_test_util.h" -#include "test/util/posix_error.h" -#include "test/util/save_util.h" -#include "test/util/socket_util.h" -#include "test/util/test_util.h" namespace gvisor { namespace testing { @@ -817,20 +813,6 @@ TEST_P(IPv4UDPUnboundSocketTest, IpMulticastLoopbackIfNicSelfNoLoop) { EXPECT_EQ(0, memcmp(send_buf, recv_buf, sizeof(send_buf))); } -// Check that dropping a group membership that does not exist fails. -TEST_P(IPv4UDPUnboundSocketTest, IpMulticastInvalidDrop) { - auto socket1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); - auto socket2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); - - // Unregister from a membership that we didn't have. - ip_mreq group = {}; - group.imr_multiaddr.s_addr = inet_addr(kMulticastAddress); - group.imr_interface.s_addr = htonl(INADDR_LOOPBACK); - EXPECT_THAT(setsockopt(socket1->get(), IPPROTO_IP, IP_DROP_MEMBERSHIP, &group, - sizeof(group)), - SyscallFailsWithErrno(EADDRNOTAVAIL)); -} - // Check that dropping a group membership prevents multicast packets from being // delivered. Default send address configured by bind and group membership // interface configured by address. @@ -941,260 +923,6 @@ TEST_P(IPv4UDPUnboundSocketTest, IpMulticastDropNic) { PosixErrorIs(EAGAIN, ::testing::_)); } -TEST_P(IPv4UDPUnboundSocketTest, IpMulticastIfZero) { - auto socket1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); - auto socket2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); - - ip_mreqn iface = {}; - EXPECT_THAT(setsockopt(socket1->get(), IPPROTO_IP, IP_MULTICAST_IF, &iface, - sizeof(iface)), - SyscallSucceeds()); -} - -TEST_P(IPv4UDPUnboundSocketTest, IpMulticastIfInvalidNic) { - auto socket1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); - auto socket2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); - - ip_mreqn iface = {}; - iface.imr_ifindex = -1; - EXPECT_THAT(setsockopt(socket1->get(), IPPROTO_IP, IP_MULTICAST_IF, &iface, - sizeof(iface)), - SyscallFailsWithErrno(EADDRNOTAVAIL)); -} - -TEST_P(IPv4UDPUnboundSocketTest, IpMulticastIfInvalidAddr) { - auto socket1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); - auto socket2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); - - ip_mreq iface = {}; - iface.imr_interface.s_addr = inet_addr("255.255.255"); - EXPECT_THAT(setsockopt(socket1->get(), IPPROTO_IP, IP_MULTICAST_IF, &iface, - sizeof(iface)), - SyscallFailsWithErrno(EADDRNOTAVAIL)); -} - -TEST_P(IPv4UDPUnboundSocketTest, IpMulticastIfSetShort) { - auto socket1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); - auto socket2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); - - // Create a valid full-sized request. - ip_mreqn iface = {}; - iface.imr_ifindex = ASSERT_NO_ERRNO_AND_VALUE(GetLoopbackIndex()); - - // Send an optlen of 1 to check that optlen is enforced. - EXPECT_THAT( - setsockopt(socket1->get(), IPPROTO_IP, IP_MULTICAST_IF, &iface, 1), - SyscallFailsWithErrno(EINVAL)); -} - -TEST_P(IPv4UDPUnboundSocketTest, IpMulticastIfDefault) { - auto socket1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); - auto socket2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); - - in_addr get = {}; - socklen_t size = sizeof(get); - ASSERT_THAT( - getsockopt(socket1->get(), IPPROTO_IP, IP_MULTICAST_IF, &get, &size), - SyscallSucceeds()); - EXPECT_EQ(size, sizeof(get)); - EXPECT_EQ(get.s_addr, 0); -} - -TEST_P(IPv4UDPUnboundSocketTest, IpMulticastIfDefaultReqn) { - auto socket1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); - auto socket2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); - - ip_mreqn get = {}; - socklen_t size = sizeof(get); - ASSERT_THAT( - getsockopt(socket1->get(), IPPROTO_IP, IP_MULTICAST_IF, &get, &size), - SyscallSucceeds()); - - // getsockopt(IP_MULTICAST_IF) can only return an in_addr, so it treats the - // first sizeof(struct in_addr) bytes of struct ip_mreqn as a struct in_addr. - // Conveniently, this corresponds to the field ip_mreqn::imr_multiaddr. - EXPECT_EQ(size, sizeof(in_addr)); - - // getsockopt(IP_MULTICAST_IF) will only return the interface address which - // hasn't been set. - EXPECT_EQ(get.imr_multiaddr.s_addr, 0); - EXPECT_EQ(get.imr_address.s_addr, 0); - EXPECT_EQ(get.imr_ifindex, 0); -} - -TEST_P(IPv4UDPUnboundSocketTest, IpMulticastIfSetAddrGetReqn) { - auto socket1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); - auto socket2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); - - in_addr set = {}; - set.s_addr = htonl(INADDR_LOOPBACK); - ASSERT_THAT(setsockopt(socket1->get(), IPPROTO_IP, IP_MULTICAST_IF, &set, - sizeof(set)), - SyscallSucceeds()); - - ip_mreqn get = {}; - socklen_t size = sizeof(get); - ASSERT_THAT( - getsockopt(socket1->get(), IPPROTO_IP, IP_MULTICAST_IF, &get, &size), - SyscallSucceeds()); - - // getsockopt(IP_MULTICAST_IF) can only return an in_addr, so it treats the - // first sizeof(struct in_addr) bytes of struct ip_mreqn as a struct in_addr. - // Conveniently, this corresponds to the field ip_mreqn::imr_multiaddr. - EXPECT_EQ(size, sizeof(in_addr)); - EXPECT_EQ(get.imr_multiaddr.s_addr, set.s_addr); - EXPECT_EQ(get.imr_address.s_addr, 0); - EXPECT_EQ(get.imr_ifindex, 0); -} - -TEST_P(IPv4UDPUnboundSocketTest, IpMulticastIfSetReqAddrGetReqn) { - auto socket1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); - auto socket2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); - - ip_mreq set = {}; - set.imr_interface.s_addr = htonl(INADDR_LOOPBACK); - ASSERT_THAT(setsockopt(socket1->get(), IPPROTO_IP, IP_MULTICAST_IF, &set, - sizeof(set)), - SyscallSucceeds()); - - ip_mreqn get = {}; - socklen_t size = sizeof(get); - ASSERT_THAT( - getsockopt(socket1->get(), IPPROTO_IP, IP_MULTICAST_IF, &get, &size), - SyscallSucceeds()); - - // getsockopt(IP_MULTICAST_IF) can only return an in_addr, so it treats the - // first sizeof(struct in_addr) bytes of struct ip_mreqn as a struct in_addr. - // Conveniently, this corresponds to the field ip_mreqn::imr_multiaddr. - EXPECT_EQ(size, sizeof(in_addr)); - EXPECT_EQ(get.imr_multiaddr.s_addr, set.imr_interface.s_addr); - EXPECT_EQ(get.imr_address.s_addr, 0); - EXPECT_EQ(get.imr_ifindex, 0); -} - -TEST_P(IPv4UDPUnboundSocketTest, IpMulticastIfSetNicGetReqn) { - auto socket1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); - auto socket2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); - - ip_mreqn set = {}; - set.imr_ifindex = ASSERT_NO_ERRNO_AND_VALUE(GetLoopbackIndex()); - ASSERT_THAT(setsockopt(socket1->get(), IPPROTO_IP, IP_MULTICAST_IF, &set, - sizeof(set)), - SyscallSucceeds()); - - ip_mreqn get = {}; - socklen_t size = sizeof(get); - ASSERT_THAT( - getsockopt(socket1->get(), IPPROTO_IP, IP_MULTICAST_IF, &get, &size), - SyscallSucceeds()); - EXPECT_EQ(size, sizeof(in_addr)); - EXPECT_EQ(get.imr_multiaddr.s_addr, 0); - EXPECT_EQ(get.imr_address.s_addr, 0); - EXPECT_EQ(get.imr_ifindex, 0); -} - -TEST_P(IPv4UDPUnboundSocketTest, IpMulticastIfSetAddr) { - auto socket1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); - auto socket2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); - - in_addr set = {}; - set.s_addr = htonl(INADDR_LOOPBACK); - ASSERT_THAT(setsockopt(socket1->get(), IPPROTO_IP, IP_MULTICAST_IF, &set, - sizeof(set)), - SyscallSucceeds()); - - in_addr get = {}; - socklen_t size = sizeof(get); - ASSERT_THAT( - getsockopt(socket1->get(), IPPROTO_IP, IP_MULTICAST_IF, &get, &size), - SyscallSucceeds()); - - EXPECT_EQ(size, sizeof(get)); - EXPECT_EQ(get.s_addr, set.s_addr); -} - -TEST_P(IPv4UDPUnboundSocketTest, IpMulticastIfSetReqAddr) { - auto socket1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); - auto socket2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); - - ip_mreq set = {}; - set.imr_interface.s_addr = htonl(INADDR_LOOPBACK); - ASSERT_THAT(setsockopt(socket1->get(), IPPROTO_IP, IP_MULTICAST_IF, &set, - sizeof(set)), - SyscallSucceeds()); - - in_addr get = {}; - socklen_t size = sizeof(get); - ASSERT_THAT( - getsockopt(socket1->get(), IPPROTO_IP, IP_MULTICAST_IF, &get, &size), - SyscallSucceeds()); - - EXPECT_EQ(size, sizeof(get)); - EXPECT_EQ(get.s_addr, set.imr_interface.s_addr); -} - -TEST_P(IPv4UDPUnboundSocketTest, IpMulticastIfSetNic) { - auto socket1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); - auto socket2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); - - ip_mreqn set = {}; - set.imr_ifindex = ASSERT_NO_ERRNO_AND_VALUE(GetLoopbackIndex()); - ASSERT_THAT(setsockopt(socket1->get(), IPPROTO_IP, IP_MULTICAST_IF, &set, - sizeof(set)), - SyscallSucceeds()); - - in_addr get = {}; - socklen_t size = sizeof(get); - ASSERT_THAT( - getsockopt(socket1->get(), IPPROTO_IP, IP_MULTICAST_IF, &get, &size), - SyscallSucceeds()); - EXPECT_EQ(size, sizeof(get)); - EXPECT_EQ(get.s_addr, 0); -} - -TEST_P(IPv4UDPUnboundSocketTest, TestJoinGroupNoIf) { - // TODO(b/185517803): Fix for native test. - SKIP_IF(!IsRunningOnGvisor()); - auto socket1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); - auto socket2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); - - ip_mreqn group = {}; - group.imr_multiaddr.s_addr = inet_addr(kMulticastAddress); - EXPECT_THAT(setsockopt(socket1->get(), IPPROTO_IP, IP_ADD_MEMBERSHIP, &group, - sizeof(group)), - SyscallFailsWithErrno(ENODEV)); -} - -TEST_P(IPv4UDPUnboundSocketTest, TestJoinGroupInvalidIf) { - auto socket1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); - auto socket2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); - - ip_mreqn group = {}; - group.imr_address.s_addr = inet_addr("255.255.255"); - group.imr_multiaddr.s_addr = inet_addr(kMulticastAddress); - EXPECT_THAT(setsockopt(socket1->get(), IPPROTO_IP, IP_ADD_MEMBERSHIP, &group, - sizeof(group)), - SyscallFailsWithErrno(ENODEV)); -} - -// Check that multiple memberships are not allowed on the same socket. -TEST_P(IPv4UDPUnboundSocketTest, TestMultipleJoinsOnSingleSocket) { - auto socket1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); - auto socket2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); - auto fd = socket1->get(); - ip_mreqn group = {}; - group.imr_multiaddr.s_addr = inet_addr(kMulticastAddress); - group.imr_ifindex = ASSERT_NO_ERRNO_AND_VALUE(GetLoopbackIndex()); - - EXPECT_THAT( - setsockopt(fd, IPPROTO_IP, IP_ADD_MEMBERSHIP, &group, sizeof(group)), - SyscallSucceeds()); - - EXPECT_THAT( - setsockopt(fd, IPPROTO_IP, IP_ADD_MEMBERSHIP, &group, sizeof(group)), - SyscallFailsWithErrno(EADDRINUSE)); -} - // Check that two sockets can join the same multicast group at the same time. TEST_P(IPv4UDPUnboundSocketTest, TestTwoSocketsJoinSameMulticastGroup) { auto socket1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); diff --git a/test/syscalls/linux/socket_ipv4_udp_unbound_loopback.cc b/test/syscalls/linux/socket_ipv4_udp_unbound_loopback.cc index 00930c544..bef284530 100644 --- a/test/syscalls/linux/socket_ipv4_udp_unbound_loopback.cc +++ b/test/syscalls/linux/socket_ipv4_udp_unbound_loopback.cc @@ -12,12 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include <vector> - #include "test/syscalls/linux/ip_socket_test_util.h" #include "test/syscalls/linux/socket_ipv4_udp_unbound.h" -#include "test/util/socket_util.h" -#include "test/util/test_util.h" namespace gvisor { namespace testing { |