From abe9d9f67f2c2c696ef26690fa8518dfc4e28728 Mon Sep 17 00:00:00 2001 From: Ghanan Gowripalan Date: Wed, 6 Jan 2021 11:39:14 -0800 Subject: Support add/remove IPv6 multicast group sock opt IPv4 was always supported but UDP never supported joining/leaving IPv6 multicast groups via socket options. Add: IPPROTO_IPV6, IPV6_JOIN_GROUP/IPV6_ADD_MEMBERSHIP Remove: IPPROTO_IPV6, IPV6_LEAVE_GROUP/IPV6_DROP_MEMBERSHIP Test: integration_test.TestUDPAddRemoveMembershipSocketOption PiperOrigin-RevId: 350396072 --- .../tests/integration/multicast_broadcast_test.go | 474 ++++++++++++++------- 1 file changed, 328 insertions(+), 146 deletions(-) (limited to 'pkg/tcpip/tests/integration') diff --git a/pkg/tcpip/tests/integration/multicast_broadcast_test.go b/pkg/tcpip/tests/integration/multicast_broadcast_test.go index 2e59f6a42..20f8a7e6c 100644 --- a/pkg/tcpip/tests/integration/multicast_broadcast_test.go +++ b/pkg/tcpip/tests/integration/multicast_broadcast_test.go @@ -35,6 +35,9 @@ import ( const ( defaultMTU = 1280 ttl = 255 + + remotePort = 5555 + localPort = 80 ) var ( @@ -151,11 +154,11 @@ func TestPingMulticastBroadcast(t *testing.T) { } ipv4ProtoAddr := tcpip.ProtocolAddress{Protocol: header.IPv4ProtocolNumber, AddressWithPrefix: ipv4Addr} if err := s.AddProtocolAddress(nicID, ipv4ProtoAddr); err != nil { - t.Fatalf("AddProtocolAddress(%d, %+v): %s", nicID, ipv4ProtoAddr, err) + t.Fatalf("AddProtocolAddress(%d, %#v): %s", nicID, ipv4ProtoAddr, err) } ipv6ProtoAddr := tcpip.ProtocolAddress{Protocol: header.IPv6ProtocolNumber, AddressWithPrefix: ipv6Addr} if err := s.AddProtocolAddress(nicID, ipv6ProtoAddr); err != nil { - t.Fatalf("AddProtocolAddress(%d, %+v): %s", nicID, ipv6ProtoAddr, err) + t.Fatalf("AddProtocolAddress(%d, %#v): %s", nicID, ipv6ProtoAddr, err) } // Default routes for IPv4 and IPv6 so ICMP can find a route to the remote @@ -215,163 +218,219 @@ func TestPingMulticastBroadcast(t *testing.T) { } +func rxIPv4UDP(e *channel.Endpoint, src, dst tcpip.Address, data []byte) { + payloadLen := header.UDPMinimumSize + len(data) + totalLen := header.IPv4MinimumSize + payloadLen + hdr := buffer.NewPrependable(totalLen) + u := header.UDP(hdr.Prepend(payloadLen)) + u.Encode(&header.UDPFields{ + SrcPort: remotePort, + DstPort: localPort, + Length: uint16(payloadLen), + }) + copy(u.Payload(), data) + sum := header.PseudoHeaderChecksum(udp.ProtocolNumber, src, dst, uint16(payloadLen)) + sum = header.Checksum(data, sum) + u.SetChecksum(^u.CalculateChecksum(sum)) + + ip := header.IPv4(hdr.Prepend(header.IPv4MinimumSize)) + ip.Encode(&header.IPv4Fields{ + TotalLength: uint16(totalLen), + Protocol: uint8(udp.ProtocolNumber), + TTL: ttl, + SrcAddr: src, + DstAddr: dst, + }) + ip.SetChecksum(^ip.CalculateChecksum()) + + e.InjectInbound(header.IPv4ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: hdr.View().ToVectorisedView(), + })) +} + +func rxIPv6UDP(e *channel.Endpoint, src, dst tcpip.Address, data []byte) { + payloadLen := header.UDPMinimumSize + len(data) + hdr := buffer.NewPrependable(header.IPv6MinimumSize + payloadLen) + u := header.UDP(hdr.Prepend(payloadLen)) + u.Encode(&header.UDPFields{ + SrcPort: remotePort, + DstPort: localPort, + Length: uint16(payloadLen), + }) + copy(u.Payload(), data) + sum := header.PseudoHeaderChecksum(udp.ProtocolNumber, src, dst, uint16(payloadLen)) + sum = header.Checksum(data, sum) + u.SetChecksum(^u.CalculateChecksum(sum)) + + ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) + ip.Encode(&header.IPv6Fields{ + PayloadLength: uint16(payloadLen), + TransportProtocol: udp.ProtocolNumber, + HopLimit: ttl, + SrcAddr: src, + DstAddr: dst, + }) + + e.InjectInbound(header.IPv6ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: hdr.View().ToVectorisedView(), + })) +} + // TestIncomingMulticastAndBroadcast tests receiving a packet destined to some // multicast or broadcast address. func TestIncomingMulticastAndBroadcast(t *testing.T) { - const ( - nicID = 1 - remotePort = 5555 - localPort = 80 - ) + const nicID = 1 data := []byte{1, 2, 3, 4} - rxIPv4UDP := func(e *channel.Endpoint, dst tcpip.Address) { - payloadLen := header.UDPMinimumSize + len(data) - totalLen := header.IPv4MinimumSize + payloadLen - hdr := buffer.NewPrependable(totalLen) - u := header.UDP(hdr.Prepend(payloadLen)) - u.Encode(&header.UDPFields{ - SrcPort: remotePort, - DstPort: localPort, - Length: uint16(payloadLen), - }) - copy(u.Payload(), data) - sum := header.PseudoHeaderChecksum(udp.ProtocolNumber, remoteIPv4Addr, dst, uint16(payloadLen)) - sum = header.Checksum(data, sum) - u.SetChecksum(^u.CalculateChecksum(sum)) - - ip := header.IPv4(hdr.Prepend(header.IPv4MinimumSize)) - ip.Encode(&header.IPv4Fields{ - TotalLength: uint16(totalLen), - Protocol: uint8(udp.ProtocolNumber), - TTL: ttl, - SrcAddr: remoteIPv4Addr, - DstAddr: dst, - }) - ip.SetChecksum(^ip.CalculateChecksum()) - - e.InjectInbound(header.IPv4ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: hdr.View().ToVectorisedView(), - })) - } - - rxIPv6UDP := func(e *channel.Endpoint, dst tcpip.Address) { - payloadLen := header.UDPMinimumSize + len(data) - hdr := buffer.NewPrependable(header.IPv6MinimumSize + payloadLen) - u := header.UDP(hdr.Prepend(payloadLen)) - u.Encode(&header.UDPFields{ - SrcPort: remotePort, - DstPort: localPort, - Length: uint16(payloadLen), - }) - copy(u.Payload(), data) - sum := header.PseudoHeaderChecksum(udp.ProtocolNumber, remoteIPv6Addr, dst, uint16(payloadLen)) - sum = header.Checksum(data, sum) - u.SetChecksum(^u.CalculateChecksum(sum)) - - ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) - ip.Encode(&header.IPv6Fields{ - PayloadLength: uint16(payloadLen), - TransportProtocol: udp.ProtocolNumber, - HopLimit: ttl, - SrcAddr: remoteIPv6Addr, - DstAddr: dst, - }) - - e.InjectInbound(header.IPv6ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: hdr.View().ToVectorisedView(), - })) - } - tests := []struct { - name string - bindAddr tcpip.Address - dstAddr tcpip.Address - expectRx bool + name string + proto tcpip.NetworkProtocolNumber + remoteAddr tcpip.Address + localAddr tcpip.AddressWithPrefix + rxUDP func(*channel.Endpoint, tcpip.Address, tcpip.Address, []byte) + bindAddr tcpip.Address + dstAddr tcpip.Address + expectRx bool }{ { - name: "IPv4 unicast binding to unicast", - bindAddr: ipv4Addr.Address, - dstAddr: ipv4Addr.Address, - expectRx: true, + name: "IPv4 unicast binding to unicast", + proto: header.IPv4ProtocolNumber, + remoteAddr: remoteIPv4Addr, + localAddr: ipv4Addr, + rxUDP: rxIPv4UDP, + bindAddr: ipv4Addr.Address, + dstAddr: ipv4Addr.Address, + expectRx: true, }, { - name: "IPv4 unicast binding to broadcast", - bindAddr: header.IPv4Broadcast, - dstAddr: ipv4Addr.Address, - expectRx: false, + name: "IPv4 unicast binding to broadcast", + proto: header.IPv4ProtocolNumber, + remoteAddr: remoteIPv4Addr, + localAddr: ipv4Addr, + rxUDP: rxIPv4UDP, + bindAddr: header.IPv4Broadcast, + dstAddr: ipv4Addr.Address, + expectRx: false, }, { - name: "IPv4 unicast binding to wildcard", - dstAddr: ipv4Addr.Address, - expectRx: true, + name: "IPv4 unicast binding to wildcard", + proto: header.IPv4ProtocolNumber, + remoteAddr: remoteIPv4Addr, + localAddr: ipv4Addr, + rxUDP: rxIPv4UDP, + dstAddr: ipv4Addr.Address, + expectRx: true, }, { - name: "IPv4 directed broadcast binding to subnet broadcast", - bindAddr: ipv4SubnetBcast, - dstAddr: ipv4SubnetBcast, - expectRx: true, + name: "IPv4 directed broadcast binding to subnet broadcast", + proto: header.IPv4ProtocolNumber, + remoteAddr: remoteIPv4Addr, + localAddr: ipv4Addr, + rxUDP: rxIPv4UDP, + bindAddr: ipv4SubnetBcast, + dstAddr: ipv4SubnetBcast, + expectRx: true, }, { - name: "IPv4 directed broadcast binding to broadcast", - bindAddr: header.IPv4Broadcast, - dstAddr: ipv4SubnetBcast, - expectRx: false, + name: "IPv4 directed broadcast binding to broadcast", + proto: header.IPv4ProtocolNumber, + remoteAddr: remoteIPv4Addr, + localAddr: ipv4Addr, + rxUDP: rxIPv4UDP, + bindAddr: header.IPv4Broadcast, + dstAddr: ipv4SubnetBcast, + expectRx: false, }, { - name: "IPv4 directed broadcast binding to wildcard", - dstAddr: ipv4SubnetBcast, - expectRx: true, + name: "IPv4 directed broadcast binding to wildcard", + proto: header.IPv4ProtocolNumber, + remoteAddr: remoteIPv4Addr, + localAddr: ipv4Addr, + rxUDP: rxIPv4UDP, + dstAddr: ipv4SubnetBcast, + expectRx: true, }, { - name: "IPv4 broadcast binding to broadcast", - bindAddr: header.IPv4Broadcast, - dstAddr: header.IPv4Broadcast, - expectRx: true, + name: "IPv4 broadcast binding to broadcast", + proto: header.IPv4ProtocolNumber, + remoteAddr: remoteIPv4Addr, + localAddr: ipv4Addr, + rxUDP: rxIPv4UDP, + bindAddr: header.IPv4Broadcast, + dstAddr: header.IPv4Broadcast, + expectRx: true, }, { - name: "IPv4 broadcast binding to subnet broadcast", - bindAddr: ipv4SubnetBcast, - dstAddr: header.IPv4Broadcast, - expectRx: false, + name: "IPv4 broadcast binding to subnet broadcast", + proto: header.IPv4ProtocolNumber, + remoteAddr: remoteIPv4Addr, + localAddr: ipv4Addr, + rxUDP: rxIPv4UDP, + bindAddr: ipv4SubnetBcast, + dstAddr: header.IPv4Broadcast, + expectRx: false, }, { - name: "IPv4 broadcast binding to wildcard", - dstAddr: ipv4SubnetBcast, - expectRx: true, + name: "IPv4 broadcast binding to wildcard", + proto: header.IPv4ProtocolNumber, + remoteAddr: remoteIPv4Addr, + localAddr: ipv4Addr, + rxUDP: rxIPv4UDP, + dstAddr: ipv4SubnetBcast, + expectRx: true, }, { - name: "IPv4 all-systems multicast binding to all-systems multicast", - bindAddr: header.IPv4AllSystems, - dstAddr: header.IPv4AllSystems, - expectRx: true, + name: "IPv4 all-systems multicast binding to all-systems multicast", + proto: header.IPv4ProtocolNumber, + remoteAddr: remoteIPv4Addr, + localAddr: ipv4Addr, + rxUDP: rxIPv4UDP, + bindAddr: header.IPv4AllSystems, + dstAddr: header.IPv4AllSystems, + expectRx: true, }, { - name: "IPv4 all-systems multicast binding to wildcard", - dstAddr: header.IPv4AllSystems, - expectRx: true, + name: "IPv4 all-systems multicast binding to wildcard", + proto: header.IPv4ProtocolNumber, + remoteAddr: remoteIPv4Addr, + localAddr: ipv4Addr, + rxUDP: rxIPv4UDP, + dstAddr: header.IPv4AllSystems, + expectRx: true, }, { - name: "IPv4 all-systems multicast binding to unicast", - bindAddr: ipv4Addr.Address, - dstAddr: header.IPv4AllSystems, - expectRx: false, + name: "IPv4 all-systems multicast binding to unicast", + proto: header.IPv4ProtocolNumber, + remoteAddr: remoteIPv4Addr, + localAddr: ipv4Addr, + rxUDP: rxIPv4UDP, + bindAddr: ipv4Addr.Address, + dstAddr: header.IPv4AllSystems, + expectRx: false, }, // IPv6 has no notion of a broadcast. { - name: "IPv6 unicast binding to wildcard", - dstAddr: ipv6Addr.Address, - expectRx: true, + name: "IPv6 unicast binding to wildcard", + dstAddr: ipv6Addr.Address, + proto: header.IPv6ProtocolNumber, + remoteAddr: remoteIPv6Addr, + localAddr: ipv6Addr, + rxUDP: rxIPv6UDP, + expectRx: true, }, { - name: "IPv6 broadcast-like address binding to wildcard", - dstAddr: ipv6SubnetBcast, - expectRx: false, + name: "IPv6 broadcast-like address binding to wildcard", + dstAddr: ipv6SubnetBcast, + proto: header.IPv6ProtocolNumber, + remoteAddr: remoteIPv6Addr, + localAddr: ipv6Addr, + rxUDP: rxIPv6UDP, + expectRx: false, }, } @@ -385,41 +444,24 @@ func TestIncomingMulticastAndBroadcast(t *testing.T) { if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("CreateNIC(%d, _): %s", nicID, err) } - ipv4ProtoAddr := tcpip.ProtocolAddress{Protocol: header.IPv4ProtocolNumber, AddressWithPrefix: ipv4Addr} - if err := s.AddProtocolAddress(nicID, ipv4ProtoAddr); err != nil { - t.Fatalf("AddProtocolAddress(%d, %+v): %s", nicID, ipv4ProtoAddr, err) - } - ipv6ProtoAddr := tcpip.ProtocolAddress{Protocol: header.IPv6ProtocolNumber, AddressWithPrefix: ipv6Addr} - if err := s.AddProtocolAddress(nicID, ipv6ProtoAddr); err != nil { - t.Fatalf("AddProtocolAddress(%d, %+v): %s", nicID, ipv6ProtoAddr, err) - } - - var netproto tcpip.NetworkProtocolNumber - var rxUDP func(*channel.Endpoint, tcpip.Address) - switch l := len(test.dstAddr); l { - case header.IPv4AddressSize: - netproto = header.IPv4ProtocolNumber - rxUDP = rxIPv4UDP - case header.IPv6AddressSize: - netproto = header.IPv6ProtocolNumber - rxUDP = rxIPv6UDP - default: - t.Fatalf("got unexpected address length = %d bytes", l) + protoAddr := tcpip.ProtocolAddress{Protocol: test.proto, AddressWithPrefix: test.localAddr} + if err := s.AddProtocolAddress(nicID, protoAddr); err != nil { + t.Fatalf("AddProtocolAddress(%d, %#v): %s", nicID, protoAddr, err) } var wq waiter.Queue - ep, err := s.NewEndpoint(udp.ProtocolNumber, netproto, &wq) + ep, err := s.NewEndpoint(udp.ProtocolNumber, test.proto, &wq) if err != nil { - t.Fatalf("NewEndpoint(%d, %d, _): %s", udp.ProtocolNumber, netproto, err) + t.Fatalf("NewEndpoint(%d, %d, _): %s", udp.ProtocolNumber, test.proto, err) } defer ep.Close() bindAddr := tcpip.FullAddress{Addr: test.bindAddr, Port: localPort} if err := ep.Bind(bindAddr); err != nil { - t.Fatalf("ep.Bind(%+v): %s", bindAddr, err) + t.Fatalf("ep.Bind(%#v): %s", bindAddr, err) } - rxUDP(e, test.dstAddr) + test.rxUDP(e, test.remoteAddr, test.dstAddr, data) if gotPayload, _, err := ep.Read(nil); test.expectRx { if err != nil { t.Fatalf("Read(nil): %s", err) @@ -476,7 +518,7 @@ func TestReuseAddrAndBroadcast(t *testing.T) { }, } if err := s.AddProtocolAddress(nicID, protoAddr); err != nil { - t.Fatalf("AddProtocolAddress(%d, %+v): %s", nicID, protoAddr, err) + t.Fatalf("AddProtocolAddress(%d, %#v): %s", nicID, protoAddr, err) } s.SetRouteTable([]tcpip.Route{ @@ -516,12 +558,12 @@ func TestReuseAddrAndBroadcast(t *testing.T) { bindAddr := tcpip.FullAddress{Port: localPort} if bindWildcard { if err := ep.Bind(bindAddr); err != nil { - t.Fatalf("eps[%d].Bind(%+v): %s", len(eps), bindAddr, err) + t.Fatalf("eps[%d].Bind(%#v): %s", len(eps), bindAddr, err) } } else { bindAddr.Addr = test.broadcastAddr if err := ep.Bind(bindAddr); err != nil { - t.Fatalf("eps[%d].Bind(%+v): %s", len(eps), bindAddr, err) + t.Fatalf("eps[%d].Bind(%#v): %s", len(eps), bindAddr, err) } } @@ -557,3 +599,143 @@ func TestReuseAddrAndBroadcast(t *testing.T) { }) } } + +func TestUDPAddRemoveMembershipSocketOption(t *testing.T) { + const ( + nicID = 1 + ) + + data := []byte{1, 2, 3, 4} + + tests := []struct { + name string + proto tcpip.NetworkProtocolNumber + remoteAddr tcpip.Address + localAddr tcpip.AddressWithPrefix + rxUDP func(*channel.Endpoint, tcpip.Address, tcpip.Address, []byte) + multicastAddr tcpip.Address + }{ + { + name: "IPv4 unicast binding to unicast", + multicastAddr: "\xe0\x01\x02\x03", + proto: header.IPv4ProtocolNumber, + remoteAddr: remoteIPv4Addr, + localAddr: ipv4Addr, + rxUDP: rxIPv4UDP, + }, + { + name: "IPv6 broadcast-like address binding to wildcard", + multicastAddr: "\xff\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x02\x03\x04", + proto: header.IPv6ProtocolNumber, + remoteAddr: remoteIPv6Addr, + localAddr: ipv6Addr, + rxUDP: rxIPv6UDP, + }, + } + + subTests := []struct { + name string + specifyNICID bool + specifyNICAddr bool + }{ + { + name: "Specify NIC ID and NIC address", + specifyNICID: true, + specifyNICAddr: true, + }, + { + name: "Don't specify NIC ID or NIC address", + specifyNICID: false, + specifyNICAddr: false, + }, + { + name: "Specify NIC ID but don't specify NIC address", + specifyNICID: true, + specifyNICAddr: false, + }, + { + name: "Don't specify NIC ID but specify NIC address", + specifyNICID: false, + specifyNICAddr: true, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + for _, subTest := range subTests { + t.Run(subTest.name, func(t *testing.T) { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}, + }) + e := channel.New(0, defaultMTU, "") + if err := s.CreateNIC(nicID, e); err != nil { + t.Fatalf("CreateNIC(%d, _): %s", nicID, err) + } + protoAddr := tcpip.ProtocolAddress{Protocol: test.proto, AddressWithPrefix: test.localAddr} + if err := s.AddProtocolAddress(nicID, protoAddr); err != nil { + t.Fatalf("AddProtocolAddress(%d, %#v): %s", nicID, protoAddr, err) + } + + // Set the route table so that UDP can find a NIC that is + // routable to the multicast address when the NIC isn't specified. + if !subTest.specifyNICID && !subTest.specifyNICAddr { + s.SetRouteTable([]tcpip.Route{ + tcpip.Route{ + Destination: header.IPv6EmptySubnet, + NIC: nicID, + }, + tcpip.Route{ + Destination: header.IPv4EmptySubnet, + NIC: nicID, + }, + }) + } + + var wq waiter.Queue + ep, err := s.NewEndpoint(udp.ProtocolNumber, test.proto, &wq) + if err != nil { + t.Fatalf("NewEndpoint(%d, %d, _): %s", udp.ProtocolNumber, test.proto, err) + } + defer ep.Close() + + bindAddr := tcpip.FullAddress{Port: localPort} + if err := ep.Bind(bindAddr); err != nil { + t.Fatalf("ep.Bind(%#v): %s", bindAddr, err) + } + + memOpt := tcpip.MembershipOption{MulticastAddr: test.multicastAddr} + if subTest.specifyNICID { + memOpt.NIC = nicID + } + if subTest.specifyNICAddr { + memOpt.InterfaceAddr = test.localAddr.Address + } + + // We should receive UDP packets to the group once we join the + // multicast group. + addOpt := tcpip.AddMembershipOption(memOpt) + if err := ep.SetSockOpt(&addOpt); err != nil { + t.Fatalf("ep.SetSockOpt(&%#v): %s", addOpt, err) + } + test.rxUDP(e, test.remoteAddr, test.multicastAddr, data) + if gotPayload, _, err := ep.Read(nil); err != nil { + t.Fatalf("ep.Read(nil): %s", err) + } else if diff := cmp.Diff(buffer.View(data), gotPayload); diff != "" { + t.Errorf("got UDP payload mismatch (-want +got):\n%s", diff) + } + + // We should not receive UDP packets to the group once we leave + // the multicast group. + removeOpt := tcpip.RemoveMembershipOption(memOpt) + if err := ep.SetSockOpt(&removeOpt); err != nil { + t.Fatalf("ep.SetSockOpt(&%#v): %s", removeOpt, err) + } + if gotPayload, _, err := ep.Read(nil); err != tcpip.ErrWouldBlock { + t.Fatalf("got ep.Read(nil) = (%x, _, %s), want = (nil, _, %s)", gotPayload, err, tcpip.ErrWouldBlock) + } + }) + } + }) + } +} -- cgit v1.2.3