diff options
Diffstat (limited to 'pkg/tcpip')
-rw-r--r-- | pkg/tcpip/transport/udp/endpoint.go | 20 | ||||
-rw-r--r-- | pkg/tcpip/transport/udp/endpoint_state.go | 3 |
2 files changed, 19 insertions, 4 deletions
diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go index 935ac622e..ac5905772 100644 --- a/pkg/tcpip/transport/udp/endpoint.go +++ b/pkg/tcpip/transport/udp/endpoint.go @@ -249,6 +249,11 @@ func (e *endpoint) prepareForWrite(to *tcpip.FullAddress) (retry bool, err *tcpi // specified address is a multicast address. func (e *endpoint) connectRoute(nicid tcpip.NICID, addr tcpip.FullAddress, netProto tcpip.NetworkProtocolNumber) (stack.Route, tcpip.NICID, *tcpip.Error) { localAddr := e.id.LocalAddress + if isBroadcastOrMulticast(localAddr) { + // A packet can only originate from a unicast address (i.e., an interface). + localAddr = "" + } + if header.IsV4MulticastAddress(addr.Addr) || header.IsV6MulticastAddress(addr.Addr) { if nicid == 0 { nicid = e.multicastNICID @@ -448,7 +453,12 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { } nicID := v.NIC - if v.InterfaceAddr == header.IPv4Any { + + // The interface address is considered not-set if it is empty or contains + // all-zeros. The former represent the zero-value in golang, the latter the + // same in a setsockopt(IP_ADD_MEMBERSHIP, &ip_mreqn) syscall. + allZeros := header.IPv4Any + if len(v.InterfaceAddr) == 0 || v.InterfaceAddr == allZeros { if nicID == 0 { r, err := e.stack.FindRoute(0, "", v.MulticastAddr, header.IPv4ProtocolNumber, false /* multicastLoop */) if err == nil { @@ -914,8 +924,8 @@ func (e *endpoint) bindLocked(addr tcpip.FullAddress) *tcpip.Error { } nicid := addr.NIC - if len(addr.Addr) != 0 { - // A local address was specified, verify that it's valid. + if len(addr.Addr) != 0 && !isBroadcastOrMulticast(addr.Addr) { + // A local unicast address was specified, verify that it's valid. nicid = e.stack.CheckLocalAddress(addr.NIC, netProto, addr.Addr) if nicid == 0 { return tcpip.ErrBadLocalAddress @@ -1064,3 +1074,7 @@ func (e *endpoint) State() uint32 { // TODO(b/112063468): Translate internal state to values returned by Linux. return 0 } + +func isBroadcastOrMulticast(a tcpip.Address) bool { + return a == header.IPv4Broadcast || header.IsV4MulticastAddress(a) || header.IsV6MulticastAddress(a) +} diff --git a/pkg/tcpip/transport/udp/endpoint_state.go b/pkg/tcpip/transport/udp/endpoint_state.go index 4a3c30115..5cbb56120 100644 --- a/pkg/tcpip/transport/udp/endpoint_state.go +++ b/pkg/tcpip/transport/udp/endpoint_state.go @@ -97,7 +97,8 @@ func (e *endpoint) Resume(s *stack.Stack) { if err != nil { panic(err) } - } else if len(e.id.LocalAddress) != 0 { // stateBound + } else if len(e.id.LocalAddress) != 0 && !isBroadcastOrMulticast(e.id.LocalAddress) { // stateBound + // A local unicast address is specified, verify that it's valid. if e.stack.CheckLocalAddress(e.regNICID, netProto, e.id.LocalAddress) == 0 { panic(tcpip.ErrBadLocalAddress) } |