diff options
Diffstat (limited to 'pkg/tcpip/transport/udp/udp_test.go')
-rw-r--r-- | pkg/tcpip/transport/udp/udp_test.go | 235 |
1 files changed, 198 insertions, 37 deletions
diff --git a/pkg/tcpip/transport/udp/udp_test.go b/pkg/tcpip/transport/udp/udp_test.go index 4700193c2..6d7a737bd 100644 --- a/pkg/tcpip/transport/udp/udp_test.go +++ b/pkg/tcpip/transport/udp/udp_test.go @@ -34,16 +34,20 @@ import ( ) const ( - stackV6Addr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01" - testV6Addr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02" - stackV4MappedAddr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff" + stackAddr - testV4MappedAddr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff" + testAddr - V4MappedWildcardAddr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff\x00\x00\x00\x00" - - stackAddr = "\x0a\x00\x00\x01" - stackPort = 1234 - testAddr = "\x0a\x00\x00\x02" - testPort = 4096 + stackV6Addr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01" + testV6Addr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02" + stackV4MappedAddr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff" + stackAddr + testV4MappedAddr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff" + testAddr + multicastV4MappedAddr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff" + multicastAddr + V4MappedWildcardAddr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff\x00\x00\x00\x00" + + stackAddr = "\x0a\x00\x00\x01" + stackPort = 1234 + testAddr = "\x0a\x00\x00\x02" + testPort = 4096 + multicastAddr = "\xe8\x2b\xd3\xea" + multicastV6Addr = "\xff\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00" + multicastPort = 1234 // defaultMTU is the MTU, in bytes, used throughout the tests, except // where another value is explicitly used. It is chosen to match the MTU @@ -128,37 +132,35 @@ func (c *testContext) createV6Endpoint(v6only bool) { } } -func (c *testContext) getV6Packet() []byte { +func (c *testContext) getPacket(protocolNumber tcpip.NetworkProtocolNumber, multicast bool) []byte { select { case p := <-c.linkEP.C: - if p.Proto != ipv6.ProtocolNumber { - c.t.Fatalf("Bad network protocol: got %v, wanted %v", p.Proto, ipv6.ProtocolNumber) + if p.Proto != protocolNumber { + c.t.Fatalf("Bad network protocol: got %v, wanted %v", p.Proto, protocolNumber) } b := make([]byte, len(p.Header)+len(p.Payload)) copy(b, p.Header) copy(b[len(p.Header):], p.Payload) - checker.IPv6(c.t, b, checker.SrcAddr(stackV6Addr), checker.DstAddr(testV6Addr)) - return b - - case <-time.After(2 * time.Second): - c.t.Fatalf("Packet wasn't written out") - } - - return nil -} - -func (c *testContext) getPacket() []byte { - select { - case p := <-c.linkEP.C: - if p.Proto != ipv4.ProtocolNumber { - c.t.Fatalf("Bad network protocol: got %v, wanted %v", p.Proto, ipv4.ProtocolNumber) + var checkerFn func(*testing.T, []byte, ...checker.NetworkChecker) + var srcAddr, dstAddr tcpip.Address + switch protocolNumber { + case ipv4.ProtocolNumber: + checkerFn = checker.IPv4 + srcAddr, dstAddr = stackAddr, testAddr + if multicast { + dstAddr = multicastAddr + } + case ipv6.ProtocolNumber: + checkerFn = checker.IPv6 + srcAddr, dstAddr = stackV6Addr, testV6Addr + if multicast { + dstAddr = multicastV6Addr + } + default: + c.t.Fatalf("unknown protocol %d", protocolNumber) } - b := make([]byte, len(p.Header)+len(p.Payload)) - copy(b, p.Header) - copy(b[len(p.Header):], p.Payload) - - checker.IPv4(c.t, b, checker.SrcAddr(stackAddr), checker.DstAddr(testAddr)) + checkerFn(c.t, b, checker.SrcAddr(srcAddr), checker.DstAddr(dstAddr)) return b case <-time.After(2 * time.Second): @@ -495,7 +497,7 @@ func testV4Write(c *testContext) uint16 { } // Check that we received the packet. - b := c.getPacket() + b := c.getPacket(ipv4.ProtocolNumber, false) udp := header.UDP(header.IPv4(b).Payload()) checker.IPv4(c.t, b, checker.UDP( @@ -525,7 +527,7 @@ func testV6Write(c *testContext) uint16 { } // Check that we received the packet. - b := c.getV6Packet() + b := c.getPacket(ipv6.ProtocolNumber, false) udp := header.UDP(header.IPv6(b).Payload()) checker.IPv6(c.t, b, checker.UDP( @@ -682,7 +684,7 @@ func TestV6WriteOnConnected(t *testing.T) { } // Check that we received the packet. - b := c.getV6Packet() + b := c.getPacket(ipv6.ProtocolNumber, false) udp := header.UDP(header.IPv6(b).Payload()) checker.IPv6(c.t, b, checker.UDP( @@ -718,7 +720,7 @@ func TestV4WriteOnConnected(t *testing.T) { } // Check that we received the packet. - b := c.getPacket() + b := c.getPacket(ipv4.ProtocolNumber, false) udp := header.UDP(header.IPv4(b).Payload()) checker.IPv4(c.t, b, checker.UDP( @@ -769,3 +771,162 @@ func TestWriteIncrementsPacketsSent(t *testing.T) { c.t.Fatalf("Write did not increment PacketsSent: got %v, want %v", got, want) } } + +func TestTTL(t *testing.T) { + payload := tcpip.SlicePayload(buffer.View(newPayload())) + + for _, name := range []string{"v4", "v6", "dual"} { + t.Run(name, func(t *testing.T) { + var networkProtocolNumber tcpip.NetworkProtocolNumber + switch name { + case "v4": + networkProtocolNumber = ipv4.ProtocolNumber + case "v6", "dual": + networkProtocolNumber = ipv6.ProtocolNumber + default: + t.Fatal("unknown test variant") + } + + var variants []string + switch name { + case "v4": + variants = []string{"v4"} + case "v6": + variants = []string{"v6"} + case "dual": + variants = []string{"v6", "mapped"} + } + + for _, variant := range variants { + t.Run(variant, func(t *testing.T) { + for _, typ := range []string{"unicast", "multicast"} { + t.Run(typ, func(t *testing.T) { + var addr tcpip.Address + var port uint16 + switch typ { + case "unicast": + port = testPort + switch variant { + case "v4": + addr = testAddr + case "mapped": + addr = testV4MappedAddr + case "v6": + addr = testV6Addr + default: + t.Fatal("unknown test variant") + } + case "multicast": + port = multicastPort + switch variant { + case "v4": + addr = multicastAddr + case "mapped": + addr = multicastV4MappedAddr + case "v6": + addr = multicastV6Addr + default: + t.Fatal("unknown test variant") + } + default: + t.Fatal("unknown test variant") + } + + c := newDualTestContext(t, defaultMTU) + defer c.cleanup() + + var err *tcpip.Error + c.ep, err = c.s.NewEndpoint(udp.ProtocolNumber, networkProtocolNumber, &c.wq) + if err != nil { + c.t.Fatalf("NewEndpoint failed: %v", err) + } + + switch name { + case "v4": + case "v6": + if err := c.ep.SetSockOpt(tcpip.V6OnlyOption(1)); err != nil { + c.t.Fatalf("SetSockOpt failed: %v", err) + } + case "dual": + if err := c.ep.SetSockOpt(tcpip.V6OnlyOption(0)); err != nil { + c.t.Fatalf("SetSockOpt failed: %v", err) + } + default: + t.Fatal("unknown test variant") + } + + const multicastTTL = 42 + if err := c.ep.SetSockOpt(tcpip.MulticastTTLOption(multicastTTL)); err != nil { + c.t.Fatalf("SetSockOpt failed: %v", err) + } + + n, err := c.ep.Write(payload, tcpip.WriteOptions{To: &tcpip.FullAddress{Addr: addr, Port: port}}) + if err != nil { + c.t.Fatalf("Write failed: %v", err) + } + if n != uintptr(len(payload)) { + c.t.Fatalf("got c.ep.Write(...) = %d, want = %d", n, len(payload)) + } + + checkerFn := checker.IPv4 + switch variant { + case "v4", "mapped": + case "v6": + checkerFn = checker.IPv6 + default: + t.Fatal("unknown test variant") + } + var wantTTL uint8 + var multicast bool + switch typ { + case "unicast": + multicast = false + switch variant { + case "v4", "mapped": + ep, err := ipv4.NewProtocol().NewEndpoint(0, "", nil, nil, nil) + if err != nil { + t.Fatal(err) + } + wantTTL = ep.DefaultTTL() + ep.Close() + case "v6": + ep, err := ipv6.NewProtocol().NewEndpoint(0, "", nil, nil, nil) + if err != nil { + t.Fatal(err) + } + wantTTL = ep.DefaultTTL() + ep.Close() + default: + t.Fatal("unknown test variant") + } + case "multicast": + wantTTL = multicastTTL + multicast = true + default: + t.Fatal("unknown test variant") + } + + var networkProtocolNumber tcpip.NetworkProtocolNumber + switch variant { + case "v4", "mapped": + networkProtocolNumber = ipv4.ProtocolNumber + case "v6": + networkProtocolNumber = ipv6.ProtocolNumber + default: + t.Fatal("unknown test variant") + } + + b := c.getPacket(networkProtocolNumber, multicast) + checkerFn(c.t, b, + checker.TTL(wantTTL), + checker.UDP( + checker.DstPort(port), + ), + ) + }) + } + }) + } + }) + } +} |