diff options
-rw-r--r-- | pkg/sentry/socket/epsocket/epsocket.go | 34 | ||||
-rw-r--r-- | test/syscalls/linux/socket_ip_udp_generic.cc | 84 |
2 files changed, 103 insertions, 15 deletions
diff --git a/pkg/sentry/socket/epsocket/epsocket.go b/pkg/sentry/socket/epsocket/epsocket.go index f7636e056..6e95fd448 100644 --- a/pkg/sentry/socket/epsocket/epsocket.go +++ b/pkg/sentry/socket/epsocket/epsocket.go @@ -1193,24 +1193,30 @@ func copyInMulticastRequest(optVal []byte) (linux.InetMulticastRequestWithNIC, * return req, nil } -// reduceToByte ORs all of the bytes in the input. -func reduceToByte(buf []byte) byte { - var out byte - for _, b := range buf { - out |= b +// parseIntOrChar copies either a 32-bit int or an 8-bit uint out of buf. +// +// net/ipv4/ip_sockglue.c:do_ip_setsockopt does this for its socket options. +func parseIntOrChar(buf []byte) (int32, *syserr.Error) { + if len(buf) == 0 { + return 0, syserr.ErrInvalidArgument + } + + if len(buf) >= sizeOfInt32 { + return int32(usermem.ByteOrder.Uint32(buf)), nil } - return out + + return int32(buf[0]), nil } // setSockOptIP implements SetSockOpt when level is SOL_IP. func setSockOptIP(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) *syserr.Error { switch name { case linux.IP_MULTICAST_TTL: - if len(optVal) < sizeOfInt32 { - return syserr.ErrInvalidArgument + v, err := parseIntOrChar(optVal) + if err != nil { + return err } - v := int32(usermem.ByteOrder.Uint32(optVal)) if v == -1 { // Linux translates -1 to 1. v = 1 @@ -1260,15 +1266,13 @@ func setSockOptIP(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) *s })) case linux.IP_MULTICAST_LOOP: - if len(optVal) < 1 { - return syserr.ErrInvalidArgument - } - if len(optVal) > sizeOfInt32 { - optVal = optVal[:sizeOfInt32] + v, err := parseIntOrChar(optVal) + if err != nil { + return err } return syserr.TranslateNetstackError(ep.SetSockOpt( - tcpip.MulticastLoopOption(reduceToByte(optVal) != 0), + tcpip.MulticastLoopOption(v != 0), )) case linux.MCAST_JOIN_GROUP: diff --git a/test/syscalls/linux/socket_ip_udp_generic.cc b/test/syscalls/linux/socket_ip_udp_generic.cc index 197783e55..432017b12 100644 --- a/test/syscalls/linux/socket_ip_udp_generic.cc +++ b/test/syscalls/linux/socket_ip_udp_generic.cc @@ -117,6 +117,23 @@ TEST_P(UDPSocketPairTest, SetUDPMulticastTTLAboveMax) { SyscallFailsWithErrno(EINVAL)); } +TEST_P(UDPSocketPairTest, SetUDPMulticastTTLChar) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + + constexpr char kArbitrary = 6; + EXPECT_THAT(setsockopt(sockets->first_fd(), IPPROTO_IP, IP_MULTICAST_TTL, + &kArbitrary, sizeof(kArbitrary)), + SyscallSucceeds()); + + int get = -1; + socklen_t get_len = sizeof(get); + EXPECT_THAT(getsockopt(sockets->first_fd(), IPPROTO_IP, IP_MULTICAST_TTL, + &get, &get_len), + SyscallSucceedsWithValue(0)); + EXPECT_EQ(get_len, sizeof(get)); + EXPECT_EQ(get, kArbitrary); +} + TEST_P(UDPSocketPairTest, SetEmptyIPAddMembership) { auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); @@ -126,5 +143,72 @@ TEST_P(UDPSocketPairTest, SetEmptyIPAddMembership) { SyscallFailsWithErrno(EINVAL)); } +TEST_P(UDPSocketPairTest, MulticastLoopDefault) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + + int get = -1; + socklen_t get_len = sizeof(get); + EXPECT_THAT(getsockopt(sockets->first_fd(), IPPROTO_IP, IP_MULTICAST_LOOP, + &get, &get_len), + SyscallSucceedsWithValue(0)); + EXPECT_EQ(get_len, sizeof(get)); + EXPECT_EQ(get, kSockOptOn); +} + +TEST_P(UDPSocketPairTest, SetMulticastLoop) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + + ASSERT_THAT(setsockopt(sockets->first_fd(), IPPROTO_IP, IP_MULTICAST_LOOP, + &kSockOptOff, sizeof(kSockOptOff)), + SyscallSucceeds()); + + int get = -1; + socklen_t get_len = sizeof(get); + EXPECT_THAT(getsockopt(sockets->first_fd(), IPPROTO_IP, IP_MULTICAST_LOOP, + &get, &get_len), + SyscallSucceedsWithValue(0)); + EXPECT_EQ(get_len, sizeof(get)); + EXPECT_EQ(get, kSockOptOff); + + ASSERT_THAT(setsockopt(sockets->first_fd(), IPPROTO_IP, IP_MULTICAST_LOOP, + &kSockOptOn, sizeof(kSockOptOn)), + SyscallSucceeds()); + + EXPECT_THAT(getsockopt(sockets->first_fd(), IPPROTO_IP, IP_MULTICAST_LOOP, + &get, &get_len), + SyscallSucceedsWithValue(0)); + EXPECT_EQ(get_len, sizeof(get)); + EXPECT_EQ(get, kSockOptOn); +} + +TEST_P(UDPSocketPairTest, SetMulticastLoopChar) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + + constexpr char kSockOptOnChar = kSockOptOn; + constexpr char kSockOptOffChar = kSockOptOff; + + ASSERT_THAT(setsockopt(sockets->first_fd(), IPPROTO_IP, IP_MULTICAST_LOOP, + &kSockOptOffChar, sizeof(kSockOptOffChar)), + SyscallSucceeds()); + + int get = -1; + socklen_t get_len = sizeof(get); + EXPECT_THAT(getsockopt(sockets->first_fd(), IPPROTO_IP, IP_MULTICAST_LOOP, + &get, &get_len), + SyscallSucceedsWithValue(0)); + EXPECT_EQ(get_len, sizeof(get)); + EXPECT_EQ(get, kSockOptOff); + + ASSERT_THAT(setsockopt(sockets->first_fd(), IPPROTO_IP, IP_MULTICAST_LOOP, + &kSockOptOnChar, sizeof(kSockOptOnChar)), + SyscallSucceeds()); + + EXPECT_THAT(getsockopt(sockets->first_fd(), IPPROTO_IP, IP_MULTICAST_LOOP, + &get, &get_len), + SyscallSucceedsWithValue(0)); + EXPECT_EQ(get_len, sizeof(get)); + EXPECT_EQ(get, kSockOptOn); +} + } // namespace testing } // namespace gvisor |