diff options
-rw-r--r-- | pkg/tcpip/transport/udp/endpoint.go | 10 | ||||
-rw-r--r-- | test/syscalls/linux/socket_ip_udp_generic.cc | 11 |
2 files changed, 11 insertions, 10 deletions
diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go index 3693abae5..cdde6a023 100644 --- a/pkg/tcpip/transport/udp/endpoint.go +++ b/pkg/tcpip/transport/udp/endpoint.go @@ -459,6 +459,10 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { e.multicastAddr = addr case tcpip.AddMembershipOption: + if !header.IsV4MulticastAddress(v.MulticastAddr) && !header.IsV6MulticastAddress(v.MulticastAddr) { + return tcpip.ErrInvalidOptionValue + } + nicID := v.NIC if v.InterfaceAddr == header.IPv4Any { if nicID == 0 { @@ -475,7 +479,6 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { return tcpip.ErrUnknownDevice } - // TODO: check that v.MulticastAddr is a multicast address. if err := e.stack.JoinGroup(e.netProto, nicID, v.MulticastAddr); err != nil { return err } @@ -486,6 +489,10 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { e.multicastMemberships = append(e.multicastMemberships, multicastMembership{nicID, v.MulticastAddr}) case tcpip.RemoveMembershipOption: + if !header.IsV4MulticastAddress(v.MulticastAddr) && !header.IsV6MulticastAddress(v.MulticastAddr) { + return tcpip.ErrInvalidOptionValue + } + nicID := v.NIC if v.InterfaceAddr == header.IPv4Any { if nicID == 0 { @@ -502,7 +509,6 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { return tcpip.ErrUnknownDevice } - // TODO: check that v.MulticastAddr is a multicast address. if err := e.stack.LeaveGroup(e.netProto, nicID, v.MulticastAddr); err != nil { return err } diff --git a/test/syscalls/linux/socket_ip_udp_generic.cc b/test/syscalls/linux/socket_ip_udp_generic.cc index 58d1c846d..197783e55 100644 --- a/test/syscalls/linux/socket_ip_udp_generic.cc +++ b/test/syscalls/linux/socket_ip_udp_generic.cc @@ -121,14 +121,9 @@ TEST_P(UDPSocketPairTest, SetEmptyIPAddMembership) { auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); struct ip_mreqn req = {}; - int ret = setsockopt(sockets->first_fd(), IPPROTO_IP, IP_ADD_MEMBERSHIP, &req, - sizeof(req)); - // FIXME: gVisor returns the incorrect errno. - if (IsRunningOnGvisor()) { - EXPECT_THAT(ret, SyscallFails()); - } else { - EXPECT_THAT(ret, SyscallFailsWithErrno(EINVAL)); - } + EXPECT_THAT(setsockopt(sockets->first_fd(), IPPROTO_IP, IP_ADD_MEMBERSHIP, + &req, sizeof(req)), + SyscallFailsWithErrno(EINVAL)); } } // namespace testing |