summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--pkg/tcpip/transport/udp/endpoint.go10
-rw-r--r--test/syscalls/linux/socket_ip_udp_generic.cc11
2 files changed, 11 insertions, 10 deletions
diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go
index 3693abae5..cdde6a023 100644
--- a/pkg/tcpip/transport/udp/endpoint.go
+++ b/pkg/tcpip/transport/udp/endpoint.go
@@ -459,6 +459,10 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
e.multicastAddr = addr
case tcpip.AddMembershipOption:
+ if !header.IsV4MulticastAddress(v.MulticastAddr) && !header.IsV6MulticastAddress(v.MulticastAddr) {
+ return tcpip.ErrInvalidOptionValue
+ }
+
nicID := v.NIC
if v.InterfaceAddr == header.IPv4Any {
if nicID == 0 {
@@ -475,7 +479,6 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
return tcpip.ErrUnknownDevice
}
- // TODO: check that v.MulticastAddr is a multicast address.
if err := e.stack.JoinGroup(e.netProto, nicID, v.MulticastAddr); err != nil {
return err
}
@@ -486,6 +489,10 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
e.multicastMemberships = append(e.multicastMemberships, multicastMembership{nicID, v.MulticastAddr})
case tcpip.RemoveMembershipOption:
+ if !header.IsV4MulticastAddress(v.MulticastAddr) && !header.IsV6MulticastAddress(v.MulticastAddr) {
+ return tcpip.ErrInvalidOptionValue
+ }
+
nicID := v.NIC
if v.InterfaceAddr == header.IPv4Any {
if nicID == 0 {
@@ -502,7 +509,6 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
return tcpip.ErrUnknownDevice
}
- // TODO: check that v.MulticastAddr is a multicast address.
if err := e.stack.LeaveGroup(e.netProto, nicID, v.MulticastAddr); err != nil {
return err
}
diff --git a/test/syscalls/linux/socket_ip_udp_generic.cc b/test/syscalls/linux/socket_ip_udp_generic.cc
index 58d1c846d..197783e55 100644
--- a/test/syscalls/linux/socket_ip_udp_generic.cc
+++ b/test/syscalls/linux/socket_ip_udp_generic.cc
@@ -121,14 +121,9 @@ TEST_P(UDPSocketPairTest, SetEmptyIPAddMembership) {
auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
struct ip_mreqn req = {};
- int ret = setsockopt(sockets->first_fd(), IPPROTO_IP, IP_ADD_MEMBERSHIP, &req,
- sizeof(req));
- // FIXME: gVisor returns the incorrect errno.
- if (IsRunningOnGvisor()) {
- EXPECT_THAT(ret, SyscallFails());
- } else {
- EXPECT_THAT(ret, SyscallFailsWithErrno(EINVAL));
- }
+ EXPECT_THAT(setsockopt(sockets->first_fd(), IPPROTO_IP, IP_ADD_MEMBERSHIP,
+ &req, sizeof(req)),
+ SyscallFailsWithErrno(EINVAL));
}
} // namespace testing