diff options
author | Chris Kuiper <ckuiper@google.com> | 2019-07-19 09:27:33 -0700 |
---|---|---|
committer | gVisor bot <gvisor-bot@google.com> | 2019-07-19 09:29:04 -0700 |
commit | 0e040ba6e87f5fdcb23854909aad39aa1883925f (patch) | |
tree | 49e6794739e5f6c4a67219c40320a235c10e967a /pkg/tcpip/transport | |
parent | eefa817cfdb04ff07e7069396f21bd6ba2c89957 (diff) |
Handle interfaceAddr and NIC options separately for IP_MULTICAST_IF
This tweaks the handling code for IP_MULTICAST_IF to ignore the InterfaceAddr
if a NICID is given.
PiperOrigin-RevId: 258982541
Diffstat (limited to 'pkg/tcpip/transport')
-rw-r--r-- | pkg/tcpip/transport/udp/endpoint.go | 5 | ||||
-rw-r--r-- | pkg/tcpip/transport/udp/udp_test.go | 339 |
2 files changed, 215 insertions, 129 deletions
diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go index 70f4a2b8c..91f89a781 100644 --- a/pkg/tcpip/transport/udp/endpoint.go +++ b/pkg/tcpip/transport/udp/endpoint.go @@ -252,7 +252,7 @@ func (e *endpoint) connectRoute(nicid tcpip.NICID, addr tcpip.FullAddress) (stac if nicid == 0 { nicid = e.multicastNICID } - if localAddr == "" { + if localAddr == "" && nicid == 0 { localAddr = e.multicastAddr } } @@ -675,6 +675,9 @@ func sendUDP(r *stack.Route, data buffer.VectorisedView, localPort, remotePort u func (e *endpoint) checkV4Mapped(addr *tcpip.FullAddress, allowMismatch bool) (tcpip.NetworkProtocolNumber, *tcpip.Error) { netProto := e.netProto + if len(addr.Addr) == 0 { + return netProto, nil + } if header.IsV4MappedAddress(addr.Addr) { // Fail if using a v4 mapped address on a v6only endpoint. if e.v6only { diff --git a/pkg/tcpip/transport/udp/udp_test.go b/pkg/tcpip/transport/udp/udp_test.go index 75129a2ff..958d5712e 100644 --- a/pkg/tcpip/transport/udp/udp_test.go +++ b/pkg/tcpip/transport/udp/udp_test.go @@ -847,9 +847,7 @@ func TestWriteIncrementsPacketsSent(t *testing.T) { } } -func TestTTL(t *testing.T) { - payload := tcpip.SlicePayload(buffer.View(newPayload())) - +func setSockOptVariants(t *testing.T, optFunc func(*testing.T, string, tcpip.NetworkProtocolNumber, string)) { for _, name := range []string{"v4", "v6", "dual"} { t.Run(name, func(t *testing.T) { var networkProtocolNumber tcpip.NetworkProtocolNumber @@ -874,134 +872,219 @@ func TestTTL(t *testing.T) { for _, variant := range variants { t.Run(variant, func(t *testing.T) { - for _, typ := range []string{"unicast", "multicast"} { - t.Run(typ, func(t *testing.T) { - var addr tcpip.Address - var port uint16 - switch typ { - case "unicast": - port = testPort - switch variant { - case "v4": - addr = testAddr - case "mapped": - addr = testV4MappedAddr - case "v6": - addr = testV6Addr - default: - t.Fatal("unknown test variant") - } - case "multicast": - port = multicastPort - switch variant { - case "v4": - addr = multicastAddr - case "mapped": - addr = multicastV4MappedAddr - case "v6": - addr = multicastV6Addr - default: - t.Fatal("unknown test variant") - } - default: - t.Fatal("unknown test variant") - } - - c := newDualTestContext(t, defaultMTU) - defer c.cleanup() - - var err *tcpip.Error - c.ep, err = c.s.NewEndpoint(udp.ProtocolNumber, networkProtocolNumber, &c.wq) - if err != nil { - c.t.Fatalf("NewEndpoint failed: %v", err) - } - - switch name { - case "v4": - case "v6": - if err := c.ep.SetSockOpt(tcpip.V6OnlyOption(1)); err != nil { - c.t.Fatalf("SetSockOpt failed: %v", err) - } - case "dual": - if err := c.ep.SetSockOpt(tcpip.V6OnlyOption(0)); err != nil { - c.t.Fatalf("SetSockOpt failed: %v", err) - } - default: - t.Fatal("unknown test variant") - } + optFunc(t, name, networkProtocolNumber, variant) + }) + } + }) + } +} - const multicastTTL = 42 - if err := c.ep.SetSockOpt(tcpip.MulticastTTLOption(multicastTTL)); err != nil { - c.t.Fatalf("SetSockOpt failed: %v", err) - } +func TestTTL(t *testing.T) { + payload := tcpip.SlicePayload(buffer.View(newPayload())) - n, _, err := c.ep.Write(payload, tcpip.WriteOptions{To: &tcpip.FullAddress{Addr: addr, Port: port}}) - if err != nil { - c.t.Fatalf("Write failed: %v", err) - } - if n != uintptr(len(payload)) { - c.t.Fatalf("got c.ep.Write(...) = %d, want = %d", n, len(payload)) - } + setSockOptVariants(t, func(t *testing.T, name string, networkProtocolNumber tcpip.NetworkProtocolNumber, variant string) { + for _, typ := range []string{"unicast", "multicast"} { + t.Run(typ, func(t *testing.T) { + var addr tcpip.Address + var port uint16 + switch typ { + case "unicast": + port = testPort + switch variant { + case "v4": + addr = testAddr + case "mapped": + addr = testV4MappedAddr + case "v6": + addr = testV6Addr + default: + t.Fatal("unknown test variant") + } + case "multicast": + port = multicastPort + switch variant { + case "v4": + addr = multicastAddr + case "mapped": + addr = multicastV4MappedAddr + case "v6": + addr = multicastV6Addr + default: + t.Fatal("unknown test variant") + } + default: + t.Fatal("unknown test variant") + } + + c := newDualTestContext(t, defaultMTU) + defer c.cleanup() + + var err *tcpip.Error + c.ep, err = c.s.NewEndpoint(udp.ProtocolNumber, networkProtocolNumber, &c.wq) + if err != nil { + c.t.Fatalf("NewEndpoint failed: %v", err) + } + + switch name { + case "v4": + case "v6": + if err := c.ep.SetSockOpt(tcpip.V6OnlyOption(1)); err != nil { + c.t.Fatalf("SetSockOpt failed: %v", err) + } + case "dual": + if err := c.ep.SetSockOpt(tcpip.V6OnlyOption(0)); err != nil { + c.t.Fatalf("SetSockOpt failed: %v", err) + } + default: + t.Fatal("unknown test variant") + } + + const multicastTTL = 42 + if err := c.ep.SetSockOpt(tcpip.MulticastTTLOption(multicastTTL)); err != nil { + c.t.Fatalf("SetSockOpt failed: %v", err) + } + + n, _, err := c.ep.Write(payload, tcpip.WriteOptions{To: &tcpip.FullAddress{Addr: addr, Port: port}}) + if err != nil { + c.t.Fatalf("Write failed: %v", err) + } + if n != uintptr(len(payload)) { + c.t.Fatalf("got c.ep.Write(...) = %d, want = %d", n, len(payload)) + } + + checkerFn := checker.IPv4 + switch variant { + case "v4", "mapped": + case "v6": + checkerFn = checker.IPv6 + default: + t.Fatal("unknown test variant") + } + var wantTTL uint8 + var multicast bool + switch typ { + case "unicast": + multicast = false + switch variant { + case "v4", "mapped": + ep, err := ipv4.NewProtocol().NewEndpoint(0, "", nil, nil, nil) + if err != nil { + t.Fatal(err) + } + wantTTL = ep.DefaultTTL() + ep.Close() + case "v6": + ep, err := ipv6.NewProtocol().NewEndpoint(0, "", nil, nil, nil) + if err != nil { + t.Fatal(err) + } + wantTTL = ep.DefaultTTL() + ep.Close() + default: + t.Fatal("unknown test variant") + } + case "multicast": + wantTTL = multicastTTL + multicast = true + default: + t.Fatal("unknown test variant") + } + + var networkProtocolNumber tcpip.NetworkProtocolNumber + switch variant { + case "v4", "mapped": + networkProtocolNumber = ipv4.ProtocolNumber + case "v6": + networkProtocolNumber = ipv6.ProtocolNumber + default: + t.Fatal("unknown test variant") + } + + b := c.getPacket(networkProtocolNumber, multicast) + checkerFn(c.t, b, + checker.TTL(wantTTL), + checker.UDP( + checker.DstPort(port), + ), + ) + }) + } + }) +} - checkerFn := checker.IPv4 - switch variant { - case "v4", "mapped": - case "v6": - checkerFn = checker.IPv6 - default: - t.Fatal("unknown test variant") - } - var wantTTL uint8 - var multicast bool - switch typ { - case "unicast": - multicast = false - switch variant { - case "v4", "mapped": - ep, err := ipv4.NewProtocol().NewEndpoint(0, "", nil, nil, nil) - if err != nil { - t.Fatal(err) - } - wantTTL = ep.DefaultTTL() - ep.Close() - case "v6": - ep, err := ipv6.NewProtocol().NewEndpoint(0, "", nil, nil, nil) - if err != nil { - t.Fatal(err) - } - wantTTL = ep.DefaultTTL() - ep.Close() - default: - t.Fatal("unknown test variant") - } - case "multicast": - wantTTL = multicastTTL - multicast = true - default: - t.Fatal("unknown test variant") +func TestMulticastInterfaceOption(t *testing.T) { + setSockOptVariants(t, func(t *testing.T, name string, networkProtocolNumber tcpip.NetworkProtocolNumber, variant string) { + for _, bindTyp := range []string{"bound", "unbound"} { + t.Run(bindTyp, func(t *testing.T) { + for _, optTyp := range []string{"use local-addr", "use NICID", "use local-addr and NIC"} { + t.Run(optTyp, func(t *testing.T) { + var mcastAddr, localIfAddr tcpip.Address + switch variant { + case "v4": + mcastAddr = multicastAddr + localIfAddr = stackAddr + case "mapped": + mcastAddr = multicastV4MappedAddr + localIfAddr = stackAddr + case "v6": + mcastAddr = multicastV6Addr + localIfAddr = stackV6Addr + default: + t.Fatal("unknown test variant") + } + + var ifoptSet tcpip.MulticastInterfaceOption + switch optTyp { + case "use local-addr": + ifoptSet.InterfaceAddr = localIfAddr + case "use NICID": + ifoptSet.NIC = 1 + case "use local-addr and NIC": + ifoptSet.InterfaceAddr = localIfAddr + ifoptSet.NIC = 1 + default: + t.Fatal("unknown test variant") + } + + c := newDualTestContext(t, defaultMTU) + defer c.cleanup() + + var err *tcpip.Error + c.ep, err = c.s.NewEndpoint(udp.ProtocolNumber, networkProtocolNumber, &c.wq) + if err != nil { + c.t.Fatalf("NewEndpoint failed: %v", err) + } + + if bindTyp == "bound" { + // Bind the socket by connecting to the multicast address. + // This may have an influence on how the multicast interface + // is set. + addr := tcpip.FullAddress{ + Addr: mcastAddr, + Port: multicastPort, } - - var networkProtocolNumber tcpip.NetworkProtocolNumber - switch variant { - case "v4", "mapped": - networkProtocolNumber = ipv4.ProtocolNumber - case "v6": - networkProtocolNumber = ipv6.ProtocolNumber - default: - t.Fatal("unknown test variant") + if err := c.ep.Connect(addr); err != nil { + c.t.Fatalf("Connect failed: %v", err) } - - b := c.getPacket(networkProtocolNumber, multicast) - checkerFn(c.t, b, - checker.TTL(wantTTL), - checker.UDP( - checker.DstPort(port), - ), - ) - }) - } - }) - } - }) - } + } + + if err := c.ep.SetSockOpt(ifoptSet); err != nil { + c.t.Fatalf("SetSockOpt failed: %v", err) + } + + // Verify multicast interface addr and NIC were set correctly. + // Note that NIC must be 1 since this is our outgoing interface. + ifoptWant := tcpip.MulticastInterfaceOption{NIC: 1, InterfaceAddr: ifoptSet.InterfaceAddr} + var ifoptGot tcpip.MulticastInterfaceOption + if err := c.ep.GetSockOpt(&ifoptGot); err != nil { + c.t.Fatalf("GetSockOpt failed: %v", err) + } + if ifoptGot != ifoptWant { + c.t.Errorf("got GetSockOpt() = %#v, want = %#v", ifoptGot, ifoptWant) + } + }) + } + }) + } + }) } |