summaryrefslogtreecommitdiffhomepage
path: root/pkg/tcpip/transport/udp
diff options
context:
space:
mode:
Diffstat (limited to 'pkg/tcpip/transport/udp')
-rw-r--r--pkg/tcpip/transport/udp/endpoint.go123
-rw-r--r--pkg/tcpip/transport/udp/udp_test.go235
2 files changed, 305 insertions, 53 deletions
diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go
index d091a6196..5de518a55 100644
--- a/pkg/tcpip/transport/udp/endpoint.go
+++ b/pkg/tcpip/transport/udp/endpoint.go
@@ -70,19 +70,24 @@ type endpoint struct {
rcvTimestamp bool
// The following fields are protected by the mu mutex.
- mu sync.RWMutex `state:"nosave"`
- sndBufSize int
- id stack.TransportEndpointID
- state endpointState
- bindNICID tcpip.NICID
- regNICID tcpip.NICID
- route stack.Route `state:"manual"`
- dstPort uint16
- v6only bool
+ mu sync.RWMutex `state:"nosave"`
+ sndBufSize int
+ id stack.TransportEndpointID
+ state endpointState
+ bindNICID tcpip.NICID
+ regNICID tcpip.NICID
+ route stack.Route `state:"manual"`
+ dstPort uint16
+ v6only bool
+ multicastTTL uint8
// shutdownFlags represent the current shutdown state of the endpoint.
shutdownFlags tcpip.ShutdownFlags
+ // multicastMemberships that need to be remvoed when the endpoint is
+ // closed. Protected by the mu mutex.
+ multicastMemberships []multicastMembership
+
// effectiveNetProtos contains the network protocols actually in use. In
// most cases it will only contain "netProto", but in cases like IPv6
// endpoints with v6only set to false, this could include multiple
@@ -92,11 +97,29 @@ type endpoint struct {
effectiveNetProtos []tcpip.NetworkProtocolNumber
}
+type multicastMembership struct {
+ nicID tcpip.NICID
+ multicastAddr tcpip.Address
+}
+
func newEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) *endpoint {
return &endpoint{
- stack: stack,
- netProto: netProto,
- waiterQueue: waiterQueue,
+ stack: stack,
+ netProto: netProto,
+ waiterQueue: waiterQueue,
+ // RFC 1075 section 5.4 recommends a TTL of 1 for membership
+ // requests.
+ //
+ // RFC 5135 4.2.1 appears to assume that IGMP messages have a
+ // TTL of 1.
+ //
+ // RFC 5135 Appendix A defines TTL=1: A multicast source that
+ // wants its traffic to not traverse a router (e.g., leave a
+ // home network) may find it useful to send traffic with IP
+ // TTL=1.
+ //
+ // Linux defaults to TTL=1.
+ multicastTTL: 1,
rcvBufSizeMax: 32 * 1024,
sndBufSize: 32 * 1024,
}
@@ -135,6 +158,11 @@ func (e *endpoint) Close() {
e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.id.LocalAddress, e.id.LocalPort)
}
+ for _, mem := range e.multicastMemberships {
+ e.stack.LeaveGroup(e.netProto, mem.nicID, mem.multicastAddr)
+ }
+ e.multicastMemberships = nil
+
// Close the receive list and drain it.
e.rcvMu.Lock()
e.rcvClosed = true
@@ -329,8 +357,13 @@ func (e *endpoint) Write(p tcpip.Payload, opts tcpip.WriteOptions) (uintptr, *tc
return 0, err
}
+ ttl := route.DefaultTTL()
+ if header.IsV4MulticastAddress(route.RemoteAddress) || header.IsV6MulticastAddress(route.RemoteAddress) {
+ ttl = e.multicastTTL
+ }
+
vv := buffer.NewVectorisedView(len(v), []buffer.View{v})
- if err := sendUDP(route, vv, e.id.LocalPort, dstPort); err != nil {
+ if err := sendUDP(route, vv, e.id.LocalPort, dstPort, ttl); err != nil {
return 0, err
}
return uintptr(len(v)), nil
@@ -365,6 +398,56 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
e.rcvMu.Lock()
e.rcvTimestamp = v != 0
e.rcvMu.Unlock()
+
+ case tcpip.MulticastTTLOption:
+ e.mu.Lock()
+ defer e.mu.Unlock()
+ e.multicastTTL = uint8(v)
+
+ case tcpip.AddMembershipOption:
+ nicID := v.NIC
+ if v.InterfaceAddr != header.IPv4Any {
+ nicID = e.stack.CheckLocalAddress(nicID, e.netProto, v.InterfaceAddr)
+ }
+ if nicID == 0 {
+ return tcpip.ErrNoRoute
+ }
+
+ // TODO: check that v.MulticastAddr is a multicast address.
+ if err := e.stack.JoinGroup(e.netProto, nicID, v.MulticastAddr); err != nil {
+ return err
+ }
+
+ e.mu.Lock()
+ defer e.mu.Unlock()
+
+ e.multicastMemberships = append(e.multicastMemberships, multicastMembership{nicID, v.MulticastAddr})
+
+ case tcpip.RemoveMembershipOption:
+ nicID := v.NIC
+ if v.InterfaceAddr != header.IPv4Any {
+ nicID = e.stack.CheckLocalAddress(nicID, e.netProto, v.InterfaceAddr)
+ }
+ if nicID == 0 {
+ return tcpip.ErrNoRoute
+ }
+
+ // TODO: check that v.MulticastAddr is a multicast address.
+ if err := e.stack.LeaveGroup(e.netProto, nicID, v.MulticastAddr); err != nil {
+ return err
+ }
+
+ e.mu.Lock()
+ defer e.mu.Unlock()
+ for i, mem := range e.multicastMemberships {
+ if mem.nicID == nicID && mem.multicastAddr == v.MulticastAddr {
+ // Only remove the first match, so that each added membership above is
+ // paired with exactly 1 removal.
+ e.multicastMemberships[i] = e.multicastMemberships[len(e.multicastMemberships)-1]
+ e.multicastMemberships = e.multicastMemberships[:len(e.multicastMemberships)-1]
+ break
+ }
+ }
}
return nil
}
@@ -421,6 +504,12 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error {
*o = 1
}
e.rcvMu.Unlock()
+
+ case *tcpip.MulticastTTLOption:
+ e.mu.Lock()
+ *o = tcpip.MulticastTTLOption(e.multicastTTL)
+ e.mu.Unlock()
+ return nil
}
return tcpip.ErrUnknownProtocolOption
@@ -428,7 +517,7 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error {
// sendUDP sends a UDP segment via the provided network endpoint and under the
// provided identity.
-func sendUDP(r *stack.Route, data buffer.VectorisedView, localPort, remotePort uint16) *tcpip.Error {
+func sendUDP(r *stack.Route, data buffer.VectorisedView, localPort, remotePort uint16, ttl uint8) *tcpip.Error {
// Allocate a buffer for the UDP header.
hdr := buffer.NewPrependable(header.UDPMinimumSize + int(r.MaxHeaderLength()))
@@ -454,7 +543,7 @@ func sendUDP(r *stack.Route, data buffer.VectorisedView, localPort, remotePort u
// Track count of packets sent.
r.Stats().UDP.PacketsSent.Increment()
- return r.WritePacket(&hdr, data, ProtocolNumber)
+ return r.WritePacket(&hdr, data, ProtocolNumber, ttl)
}
func (e *endpoint) checkV4Mapped(addr *tcpip.FullAddress, allowMismatch bool) (tcpip.NetworkProtocolNumber, *tcpip.Error) {
@@ -581,7 +670,9 @@ func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error {
e.mu.Lock()
defer e.mu.Unlock()
- if e.state != stateConnected {
+ // A socket in the bound state can still receive multicast messages,
+ // so we need to notify waiters on shutdown.
+ if e.state != stateBound && e.state != stateConnected {
return tcpip.ErrNotConnected
}
diff --git a/pkg/tcpip/transport/udp/udp_test.go b/pkg/tcpip/transport/udp/udp_test.go
index 4700193c2..6d7a737bd 100644
--- a/pkg/tcpip/transport/udp/udp_test.go
+++ b/pkg/tcpip/transport/udp/udp_test.go
@@ -34,16 +34,20 @@ import (
)
const (
- stackV6Addr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01"
- testV6Addr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02"
- stackV4MappedAddr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff" + stackAddr
- testV4MappedAddr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff" + testAddr
- V4MappedWildcardAddr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff\x00\x00\x00\x00"
-
- stackAddr = "\x0a\x00\x00\x01"
- stackPort = 1234
- testAddr = "\x0a\x00\x00\x02"
- testPort = 4096
+ stackV6Addr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01"
+ testV6Addr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02"
+ stackV4MappedAddr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff" + stackAddr
+ testV4MappedAddr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff" + testAddr
+ multicastV4MappedAddr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff" + multicastAddr
+ V4MappedWildcardAddr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff\x00\x00\x00\x00"
+
+ stackAddr = "\x0a\x00\x00\x01"
+ stackPort = 1234
+ testAddr = "\x0a\x00\x00\x02"
+ testPort = 4096
+ multicastAddr = "\xe8\x2b\xd3\xea"
+ multicastV6Addr = "\xff\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"
+ multicastPort = 1234
// defaultMTU is the MTU, in bytes, used throughout the tests, except
// where another value is explicitly used. It is chosen to match the MTU
@@ -128,37 +132,35 @@ func (c *testContext) createV6Endpoint(v6only bool) {
}
}
-func (c *testContext) getV6Packet() []byte {
+func (c *testContext) getPacket(protocolNumber tcpip.NetworkProtocolNumber, multicast bool) []byte {
select {
case p := <-c.linkEP.C:
- if p.Proto != ipv6.ProtocolNumber {
- c.t.Fatalf("Bad network protocol: got %v, wanted %v", p.Proto, ipv6.ProtocolNumber)
+ if p.Proto != protocolNumber {
+ c.t.Fatalf("Bad network protocol: got %v, wanted %v", p.Proto, protocolNumber)
}
b := make([]byte, len(p.Header)+len(p.Payload))
copy(b, p.Header)
copy(b[len(p.Header):], p.Payload)
- checker.IPv6(c.t, b, checker.SrcAddr(stackV6Addr), checker.DstAddr(testV6Addr))
- return b
-
- case <-time.After(2 * time.Second):
- c.t.Fatalf("Packet wasn't written out")
- }
-
- return nil
-}
-
-func (c *testContext) getPacket() []byte {
- select {
- case p := <-c.linkEP.C:
- if p.Proto != ipv4.ProtocolNumber {
- c.t.Fatalf("Bad network protocol: got %v, wanted %v", p.Proto, ipv4.ProtocolNumber)
+ var checkerFn func(*testing.T, []byte, ...checker.NetworkChecker)
+ var srcAddr, dstAddr tcpip.Address
+ switch protocolNumber {
+ case ipv4.ProtocolNumber:
+ checkerFn = checker.IPv4
+ srcAddr, dstAddr = stackAddr, testAddr
+ if multicast {
+ dstAddr = multicastAddr
+ }
+ case ipv6.ProtocolNumber:
+ checkerFn = checker.IPv6
+ srcAddr, dstAddr = stackV6Addr, testV6Addr
+ if multicast {
+ dstAddr = multicastV6Addr
+ }
+ default:
+ c.t.Fatalf("unknown protocol %d", protocolNumber)
}
- b := make([]byte, len(p.Header)+len(p.Payload))
- copy(b, p.Header)
- copy(b[len(p.Header):], p.Payload)
-
- checker.IPv4(c.t, b, checker.SrcAddr(stackAddr), checker.DstAddr(testAddr))
+ checkerFn(c.t, b, checker.SrcAddr(srcAddr), checker.DstAddr(dstAddr))
return b
case <-time.After(2 * time.Second):
@@ -495,7 +497,7 @@ func testV4Write(c *testContext) uint16 {
}
// Check that we received the packet.
- b := c.getPacket()
+ b := c.getPacket(ipv4.ProtocolNumber, false)
udp := header.UDP(header.IPv4(b).Payload())
checker.IPv4(c.t, b,
checker.UDP(
@@ -525,7 +527,7 @@ func testV6Write(c *testContext) uint16 {
}
// Check that we received the packet.
- b := c.getV6Packet()
+ b := c.getPacket(ipv6.ProtocolNumber, false)
udp := header.UDP(header.IPv6(b).Payload())
checker.IPv6(c.t, b,
checker.UDP(
@@ -682,7 +684,7 @@ func TestV6WriteOnConnected(t *testing.T) {
}
// Check that we received the packet.
- b := c.getV6Packet()
+ b := c.getPacket(ipv6.ProtocolNumber, false)
udp := header.UDP(header.IPv6(b).Payload())
checker.IPv6(c.t, b,
checker.UDP(
@@ -718,7 +720,7 @@ func TestV4WriteOnConnected(t *testing.T) {
}
// Check that we received the packet.
- b := c.getPacket()
+ b := c.getPacket(ipv4.ProtocolNumber, false)
udp := header.UDP(header.IPv4(b).Payload())
checker.IPv4(c.t, b,
checker.UDP(
@@ -769,3 +771,162 @@ func TestWriteIncrementsPacketsSent(t *testing.T) {
c.t.Fatalf("Write did not increment PacketsSent: got %v, want %v", got, want)
}
}
+
+func TestTTL(t *testing.T) {
+ payload := tcpip.SlicePayload(buffer.View(newPayload()))
+
+ for _, name := range []string{"v4", "v6", "dual"} {
+ t.Run(name, func(t *testing.T) {
+ var networkProtocolNumber tcpip.NetworkProtocolNumber
+ switch name {
+ case "v4":
+ networkProtocolNumber = ipv4.ProtocolNumber
+ case "v6", "dual":
+ networkProtocolNumber = ipv6.ProtocolNumber
+ default:
+ t.Fatal("unknown test variant")
+ }
+
+ var variants []string
+ switch name {
+ case "v4":
+ variants = []string{"v4"}
+ case "v6":
+ variants = []string{"v6"}
+ case "dual":
+ variants = []string{"v6", "mapped"}
+ }
+
+ for _, variant := range variants {
+ t.Run(variant, func(t *testing.T) {
+ for _, typ := range []string{"unicast", "multicast"} {
+ t.Run(typ, func(t *testing.T) {
+ var addr tcpip.Address
+ var port uint16
+ switch typ {
+ case "unicast":
+ port = testPort
+ switch variant {
+ case "v4":
+ addr = testAddr
+ case "mapped":
+ addr = testV4MappedAddr
+ case "v6":
+ addr = testV6Addr
+ default:
+ t.Fatal("unknown test variant")
+ }
+ case "multicast":
+ port = multicastPort
+ switch variant {
+ case "v4":
+ addr = multicastAddr
+ case "mapped":
+ addr = multicastV4MappedAddr
+ case "v6":
+ addr = multicastV6Addr
+ default:
+ t.Fatal("unknown test variant")
+ }
+ default:
+ t.Fatal("unknown test variant")
+ }
+
+ c := newDualTestContext(t, defaultMTU)
+ defer c.cleanup()
+
+ var err *tcpip.Error
+ c.ep, err = c.s.NewEndpoint(udp.ProtocolNumber, networkProtocolNumber, &c.wq)
+ if err != nil {
+ c.t.Fatalf("NewEndpoint failed: %v", err)
+ }
+
+ switch name {
+ case "v4":
+ case "v6":
+ if err := c.ep.SetSockOpt(tcpip.V6OnlyOption(1)); err != nil {
+ c.t.Fatalf("SetSockOpt failed: %v", err)
+ }
+ case "dual":
+ if err := c.ep.SetSockOpt(tcpip.V6OnlyOption(0)); err != nil {
+ c.t.Fatalf("SetSockOpt failed: %v", err)
+ }
+ default:
+ t.Fatal("unknown test variant")
+ }
+
+ const multicastTTL = 42
+ if err := c.ep.SetSockOpt(tcpip.MulticastTTLOption(multicastTTL)); err != nil {
+ c.t.Fatalf("SetSockOpt failed: %v", err)
+ }
+
+ n, err := c.ep.Write(payload, tcpip.WriteOptions{To: &tcpip.FullAddress{Addr: addr, Port: port}})
+ if err != nil {
+ c.t.Fatalf("Write failed: %v", err)
+ }
+ if n != uintptr(len(payload)) {
+ c.t.Fatalf("got c.ep.Write(...) = %d, want = %d", n, len(payload))
+ }
+
+ checkerFn := checker.IPv4
+ switch variant {
+ case "v4", "mapped":
+ case "v6":
+ checkerFn = checker.IPv6
+ default:
+ t.Fatal("unknown test variant")
+ }
+ var wantTTL uint8
+ var multicast bool
+ switch typ {
+ case "unicast":
+ multicast = false
+ switch variant {
+ case "v4", "mapped":
+ ep, err := ipv4.NewProtocol().NewEndpoint(0, "", nil, nil, nil)
+ if err != nil {
+ t.Fatal(err)
+ }
+ wantTTL = ep.DefaultTTL()
+ ep.Close()
+ case "v6":
+ ep, err := ipv6.NewProtocol().NewEndpoint(0, "", nil, nil, nil)
+ if err != nil {
+ t.Fatal(err)
+ }
+ wantTTL = ep.DefaultTTL()
+ ep.Close()
+ default:
+ t.Fatal("unknown test variant")
+ }
+ case "multicast":
+ wantTTL = multicastTTL
+ multicast = true
+ default:
+ t.Fatal("unknown test variant")
+ }
+
+ var networkProtocolNumber tcpip.NetworkProtocolNumber
+ switch variant {
+ case "v4", "mapped":
+ networkProtocolNumber = ipv4.ProtocolNumber
+ case "v6":
+ networkProtocolNumber = ipv6.ProtocolNumber
+ default:
+ t.Fatal("unknown test variant")
+ }
+
+ b := c.getPacket(networkProtocolNumber, multicast)
+ checkerFn(c.t, b,
+ checker.TTL(wantTTL),
+ checker.UDP(
+ checker.DstPort(port),
+ ),
+ )
+ })
+ }
+ })
+ }
+ })
+ }
+}