diff options
Diffstat (limited to 'pkg/tcpip/transport')
-rw-r--r-- | pkg/tcpip/transport/udp/endpoint.go | 115 |
1 files changed, 87 insertions, 28 deletions
diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go index fa8f02e46..9c3881d63 100644 --- a/pkg/tcpip/transport/udp/endpoint.go +++ b/pkg/tcpip/transport/udp/endpoint.go @@ -69,17 +69,19 @@ type endpoint struct { rcvClosed bool // The following fields are protected by the mu mutex. - mu sync.RWMutex `state:"nosave"` - sndBufSize int - id stack.TransportEndpointID - state endpointState - bindNICID tcpip.NICID - regNICID tcpip.NICID - route stack.Route `state:"manual"` - dstPort uint16 - v6only bool - multicastTTL uint8 - reusePort bool + mu sync.RWMutex `state:"nosave"` + sndBufSize int + id stack.TransportEndpointID + state endpointState + bindNICID tcpip.NICID + regNICID tcpip.NICID + route stack.Route `state:"manual"` + dstPort uint16 + v6only bool + multicastTTL uint8 + multicastAddr tcpip.Address + multicastNICID tcpip.NICID + reusePort bool // shutdownFlags represent the current shutdown state of the endpoint. shutdownFlags tcpip.ShutdownFlags @@ -251,6 +253,33 @@ func (e *endpoint) prepareForWrite(to *tcpip.FullAddress) (retry bool, err *tcpi return true, nil } +// connectRoute establishes a route to the specified interface or the +// configured multicast interface if no interface is specified and the +// specified address is a multicast address. +func (e *endpoint) connectRoute(nicid tcpip.NICID, addr tcpip.FullAddress) (stack.Route, tcpip.NICID, tcpip.NetworkProtocolNumber, *tcpip.Error) { + netProto, err := e.checkV4Mapped(&addr, false) + if err != nil { + return stack.Route{}, 0, 0, err + } + + localAddr := e.id.LocalAddress + if header.IsV4MulticastAddress(addr.Addr) || header.IsV6MulticastAddress(addr.Addr) { + if nicid == 0 { + nicid = e.multicastNICID + } + if localAddr == "" { + localAddr = e.multicastAddr + } + } + + // Find a route to the desired destination. + r, err := e.stack.FindRoute(nicid, localAddr, addr.Addr, netProto) + if err != nil { + return stack.Route{}, 0, 0, err + } + return r, nicid, netProto, nil +} + // Write writes data to the endpoint's peer. This method does not block // if the data cannot be written. func (e *endpoint) Write(p tcpip.Payload, opts tcpip.WriteOptions) (uintptr, <-chan struct{}, *tcpip.Error) { @@ -318,15 +347,7 @@ func (e *endpoint) Write(p tcpip.Payload, opts tcpip.WriteOptions) (uintptr, <-c nicid = e.bindNICID } - toCopy := *to - to = &toCopy - netProto, err := e.checkV4Mapped(to, false) - if err != nil { - return 0, nil, err - } - - // Find the enpoint. - r, err := e.stack.FindRoute(nicid, e.id.LocalAddress, to.Addr, netProto) + r, _, _, err := e.connectRoute(nicid, *to) if err != nil { return 0, nil, err } @@ -394,6 +415,42 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { e.multicastTTL = uint8(v) e.mu.Unlock() + case tcpip.MulticastInterfaceOption: + e.mu.Lock() + defer e.mu.Unlock() + + fa := tcpip.FullAddress{Addr: v.InterfaceAddr} + netProto, err := e.checkV4Mapped(&fa, false) + if err != nil { + return err + } + nic := v.NIC + addr := fa.Addr + + if nic == 0 && addr == "" { + e.multicastAddr = "" + e.multicastNICID = 0 + break + } + + if nic != 0 { + if !e.stack.CheckNIC(nic) { + return tcpip.ErrBadLocalAddress + } + } else { + nic = e.stack.CheckLocalAddress(0, netProto, addr) + if nic == 0 { + return tcpip.ErrBadLocalAddress + } + } + + if e.bindNICID != 0 && e.bindNICID != nic { + return tcpip.ErrInvalidEndpointState + } + + e.multicastNICID = nic + e.multicastAddr = addr + case tcpip.AddMembershipOption: nicID := v.NIC if v.InterfaceAddr != header.IPv4Any { @@ -445,7 +502,6 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { e.mu.Lock() e.reusePort = v != 0 e.mu.Unlock() - return nil } return nil } @@ -501,6 +557,15 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { e.mu.Unlock() return nil + case *tcpip.MulticastInterfaceOption: + e.mu.Lock() + *o = tcpip.MulticastInterfaceOption{ + e.multicastNICID, + e.multicastAddr, + } + e.mu.Unlock() + return nil + case *tcpip.ReusePortOption: e.mu.RLock() v := e.reusePort @@ -610,13 +675,7 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { return tcpip.ErrInvalidEndpointState } - netProto, err := e.checkV4Mapped(&addr, false) - if err != nil { - return err - } - - // Find a route to the desired destination. - r, err := e.stack.FindRoute(nicid, e.id.LocalAddress, addr.Addr, netProto) + r, nicid, netProto, err := e.connectRoute(nicid, addr) if err != nil { return err } |