diff options
Diffstat (limited to 'pkg/tcpip/tests/integration')
-rw-r--r-- | pkg/tcpip/tests/integration/BUILD | 2 | ||||
-rw-r--r-- | pkg/tcpip/tests/integration/forward_test.go | 65 | ||||
-rw-r--r-- | pkg/tcpip/tests/integration/link_resolution_test.go | 41 | ||||
-rw-r--r-- | pkg/tcpip/tests/integration/loopback_test.go | 31 | ||||
-rw-r--r-- | pkg/tcpip/tests/integration/multicast_broadcast_test.go | 532 | ||||
-rw-r--r-- | pkg/tcpip/tests/integration/route_test.go | 69 |
6 files changed, 521 insertions, 219 deletions
diff --git a/pkg/tcpip/tests/integration/BUILD b/pkg/tcpip/tests/integration/BUILD index 800025fb9..ca1e88e99 100644 --- a/pkg/tcpip/tests/integration/BUILD +++ b/pkg/tcpip/tests/integration/BUILD @@ -15,10 +15,12 @@ go_test( deps = [ "//pkg/tcpip", "//pkg/tcpip/buffer", + "//pkg/tcpip/checker", "//pkg/tcpip/header", "//pkg/tcpip/link/channel", "//pkg/tcpip/link/ethernet", "//pkg/tcpip/link/loopback", + "//pkg/tcpip/link/nested", "//pkg/tcpip/link/pipe", "//pkg/tcpip/network/arp", "//pkg/tcpip/network/ipv4", diff --git a/pkg/tcpip/tests/integration/forward_test.go b/pkg/tcpip/tests/integration/forward_test.go index 39343b966..60054d6ef 100644 --- a/pkg/tcpip/tests/integration/forward_test.go +++ b/pkg/tcpip/tests/integration/forward_test.go @@ -15,13 +15,16 @@ package integration_test import ( + "bytes" "net" "testing" "github.com/google/go-cmp/cmp" "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/checker" + "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/link/ethernet" + "gvisor.dev/gvisor/pkg/tcpip/link/nested" "gvisor.dev/gvisor/pkg/tcpip/link/pipe" "gvisor.dev/gvisor/pkg/tcpip/network/arp" "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" @@ -31,6 +34,33 @@ import ( "gvisor.dev/gvisor/pkg/waiter" ) +var _ stack.NetworkDispatcher = (*endpointWithDestinationCheck)(nil) +var _ stack.LinkEndpoint = (*endpointWithDestinationCheck)(nil) + +// newEthernetEndpoint returns an ethernet link endpoint that wraps an inner +// link endpoint and checks the destination link address before delivering +// network packets to the network dispatcher. +// +// See ethernet.Endpoint for more details. +func newEthernetEndpoint(ep stack.LinkEndpoint) *endpointWithDestinationCheck { + var e endpointWithDestinationCheck + e.Endpoint.Init(ethernet.New(ep), &e) + return &e +} + +// endpointWithDestinationCheck is a link endpoint that checks the destination +// link address before delivering network packets to the network dispatcher. +type endpointWithDestinationCheck struct { + nested.Endpoint +} + +// DeliverNetworkPacket implements stack.NetworkDispatcher. +func (e *endpointWithDestinationCheck) DeliverNetworkPacket(src, dst tcpip.LinkAddress, proto tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { + if dst == e.Endpoint.LinkAddress() || dst == header.EthernetBroadcastAddress || header.IsMulticastEthernetAddress(dst) { + e.Endpoint.DeliverNetworkPacket(src, dst, proto, pkt) + } +} + func TestForwarding(t *testing.T) { const ( host1NICID = 1 @@ -209,16 +239,16 @@ func TestForwarding(t *testing.T) { host1NIC, routerNIC1 := pipe.New(linkAddr1, linkAddr2) routerNIC2, host2NIC := pipe.New(linkAddr3, linkAddr4) - if err := host1Stack.CreateNIC(host1NICID, ethernet.New(host1NIC)); err != nil { + if err := host1Stack.CreateNIC(host1NICID, newEthernetEndpoint(host1NIC)); err != nil { t.Fatalf("host1Stack.CreateNIC(%d, _): %s", host1NICID, err) } - if err := routerStack.CreateNIC(routerNICID1, ethernet.New(routerNIC1)); err != nil { + if err := routerStack.CreateNIC(routerNICID1, newEthernetEndpoint(routerNIC1)); err != nil { t.Fatalf("routerStack.CreateNIC(%d, _): %s", routerNICID1, err) } - if err := routerStack.CreateNIC(routerNICID2, ethernet.New(routerNIC2)); err != nil { + if err := routerStack.CreateNIC(routerNICID2, newEthernetEndpoint(routerNIC2)); err != nil { t.Fatalf("routerStack.CreateNIC(%d, _): %s", routerNICID2, err) } - if err := host2Stack.CreateNIC(host2NICID, ethernet.New(host2NIC)); err != nil { + if err := host2Stack.CreateNIC(host2NICID, newEthernetEndpoint(host2NIC)); err != nil { t.Fatalf("host2Stack.CreateNIC(%d, _): %s", host2NICID, err) } @@ -353,24 +383,33 @@ func TestForwarding(t *testing.T) { // Wait for the endpoint to be readable. <-ch - var addr tcpip.FullAddress - v, _, err := ep.Read(&addr) + var buf bytes.Buffer + opts := tcpip.ReadOptions{NeedRemoteAddr: true} + res, err := ep.Read(&buf, len(data), opts) if err != nil { - t.Fatalf("ep.Read(_): %s", err) + t.Fatalf("ep.Read(_, %d, %#v): %s", len(data), opts, err) } - if diff := cmp.Diff(v, buffer.View(data)); diff != "" { - t.Errorf("received data mismatch (-want +got):\n%s", diff) + if diff := cmp.Diff(tcpip.ReadResult{ + Count: len(data), + Total: len(data), + RemoteAddr: tcpip.FullAddress{Addr: expectedFrom}, + }, res, checker.IgnoreCmpPath( + "ControlMessages", + "RemoteAddr.NIC", + "RemoteAddr.Port", + )); diff != "" { + t.Errorf("ep.Read: unexpected result (-want +got):\n%s", diff) } - if addr.Addr != expectedFrom { - t.Errorf("got addr.Addr = %s, want = %s", addr.Addr, expectedFrom) + if diff := cmp.Diff(buf.Bytes(), data); diff != "" { + t.Errorf("received data mismatch (-want +got):\n%s", diff) } if t.Failed() { t.FailNow() } - return addr + return res.RemoteAddr } addr := read(epsAndAddrs.serverReadableCH, epsAndAddrs.serverEP, data, epsAndAddrs.clientAddr) diff --git a/pkg/tcpip/tests/integration/link_resolution_test.go b/pkg/tcpip/tests/integration/link_resolution_test.go index bf8a1241f..209da3903 100644 --- a/pkg/tcpip/tests/integration/link_resolution_test.go +++ b/pkg/tcpip/tests/integration/link_resolution_test.go @@ -15,14 +15,14 @@ package integration_test import ( + "bytes" "net" "testing" "github.com/google/go-cmp/cmp" "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/checker" "gvisor.dev/gvisor/pkg/tcpip/header" - "gvisor.dev/gvisor/pkg/tcpip/link/ethernet" "gvisor.dev/gvisor/pkg/tcpip/link/pipe" "gvisor.dev/gvisor/pkg/tcpip/network/arp" "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" @@ -87,21 +87,21 @@ func TestPing(t *testing.T) { transProto tcpip.TransportProtocolNumber netProto tcpip.NetworkProtocolNumber remoteAddr tcpip.Address - icmpBuf func(*testing.T) buffer.View + icmpBuf func(*testing.T) []byte }{ { name: "IPv4 Ping", transProto: icmp.ProtocolNumber4, netProto: ipv4.ProtocolNumber, remoteAddr: ipv4Addr2.AddressWithPrefix.Address, - icmpBuf: func(t *testing.T) buffer.View { + icmpBuf: func(t *testing.T) []byte { data := [8]byte{1, 2, 3, 4, 5, 6, 7, 8} hdr := header.ICMPv4(make([]byte, header.ICMPv4MinimumSize+len(data))) hdr.SetType(header.ICMPv4Echo) if n := copy(hdr.Payload(), data[:]); n != len(data) { t.Fatalf("copied %d bytes but expected to copy %d bytes", n, len(data)) } - return buffer.View(hdr) + return hdr }, }, { @@ -109,14 +109,14 @@ func TestPing(t *testing.T) { transProto: icmp.ProtocolNumber6, netProto: ipv6.ProtocolNumber, remoteAddr: ipv6Addr2.AddressWithPrefix.Address, - icmpBuf: func(t *testing.T) buffer.View { + icmpBuf: func(t *testing.T) []byte { data := [8]byte{1, 2, 3, 4, 5, 6, 7, 8} hdr := header.ICMPv6(make([]byte, header.ICMPv6MinimumSize+len(data))) hdr.SetType(header.ICMPv6EchoRequest) if n := copy(hdr.Payload(), data[:]); n != len(data) { t.Fatalf("copied %d bytes but expected to copy %d bytes", n, len(data)) } - return buffer.View(hdr) + return hdr }, }, } @@ -133,10 +133,10 @@ func TestPing(t *testing.T) { host1NIC, host2NIC := pipe.New(linkAddr1, linkAddr2) - if err := host1Stack.CreateNIC(host1NICID, ethernet.New(host1NIC)); err != nil { + if err := host1Stack.CreateNIC(host1NICID, newEthernetEndpoint(host1NIC)); err != nil { t.Fatalf("host1Stack.CreateNIC(%d, _): %s", host1NICID, err) } - if err := host2Stack.CreateNIC(host2NICID, ethernet.New(host2NIC)); err != nil { + if err := host2Stack.CreateNIC(host2NICID, newEthernetEndpoint(host2NIC)); err != nil { t.Fatalf("host2Stack.CreateNIC(%d, _): %s", host2NICID, err) } @@ -201,16 +201,25 @@ func TestPing(t *testing.T) { // Wait for the endpoint to be readable. <-waiterCH - var addr tcpip.FullAddress - v, _, err := ep.Read(&addr) + var buf bytes.Buffer + opts := tcpip.ReadOptions{NeedRemoteAddr: true} + res, err := ep.Read(&buf, len(icmpBuf), opts) if err != nil { - t.Fatalf("ep.Read(_): %s", err) + t.Fatalf("ep.Read(_, %d, %#v): %s", len(icmpBuf), opts, err) } - if diff := cmp.Diff(v[icmpDataOffset:], icmpBuf[icmpDataOffset:]); diff != "" { - t.Errorf("received data mismatch (-want +got):\n%s", diff) + if diff := cmp.Diff(tcpip.ReadResult{ + Count: buf.Len(), + Total: buf.Len(), + RemoteAddr: tcpip.FullAddress{Addr: test.remoteAddr}, + }, res, checker.IgnoreCmpPath( + "ControlMessages", + "RemoteAddr.NIC", + "RemoteAddr.Port", + )); diff != "" { + t.Errorf("ep.Read: unexpected result (-want +got):\n%s", diff) } - if addr.Addr != test.remoteAddr { - t.Errorf("got addr.Addr = %s, want = %s", addr.Addr, test.remoteAddr) + if diff := cmp.Diff(buf.Bytes()[icmpDataOffset:], icmpBuf[icmpDataOffset:]); diff != "" { + t.Errorf("received data mismatch (-want +got):\n%s", diff) } }) } diff --git a/pkg/tcpip/tests/integration/loopback_test.go b/pkg/tcpip/tests/integration/loopback_test.go index baaa741cd..cf9e86c3c 100644 --- a/pkg/tcpip/tests/integration/loopback_test.go +++ b/pkg/tcpip/tests/integration/loopback_test.go @@ -15,12 +15,14 @@ package integration_test import ( + "bytes" "testing" "time" "github.com/google/go-cmp/cmp" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/checker" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/link/loopback" "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" @@ -238,21 +240,28 @@ func TestLoopbackAcceptAllInSubnetUDP(t *testing.T) { t.Fatalf("got sep.Write(_, _) = (%d, _, nil), want = (%d, _, nil)", n, want) } - var addr tcpip.FullAddress - if gotPayload, _, err := rep.Read(&addr); test.expectRx { + var buf bytes.Buffer + opts := tcpip.ReadOptions{NeedRemoteAddr: true} + if res, err := rep.Read(&buf, len(data), opts); test.expectRx { if err != nil { - t.Fatalf("reep.Read(_): %s", err) - } - if diff := cmp.Diff(buffer.View(data), gotPayload); diff != "" { - t.Errorf("got UDP payload mismatch (-want +got):\n%s", diff) + t.Fatalf("rep.Read(_, %d, %#v): %s", len(data), opts, err) } - if addr.Addr != test.addAddress.AddressWithPrefix.Address { - t.Errorf("got addr.Addr = %s, want = %s", addr.Addr, test.addAddress.AddressWithPrefix.Address) + if diff := cmp.Diff(tcpip.ReadResult{ + Count: buf.Len(), + Total: buf.Len(), + RemoteAddr: tcpip.FullAddress{ + Addr: test.addAddress.AddressWithPrefix.Address, + }, + }, res, + checker.IgnoreCmpPath("ControlMessages", "RemoteAddr.NIC", "RemoteAddr.Port"), + ); diff != "" { + t.Errorf("rep.Read: unexpected result (-want +got):\n%s", diff) } - } else { - if err != tcpip.ErrWouldBlock { - t.Fatalf("got rep.Read(nil) = (%x, _, %s), want = (_, _, %s)", gotPayload, err, tcpip.ErrWouldBlock) + if diff := cmp.Diff(data, buf.Bytes()); diff != "" { + t.Errorf("got UDP payload mismatch (-want +got):\n%s", diff) } + } else if err != tcpip.ErrWouldBlock { + t.Fatalf("got rep.Read = (%v, %s) [with data %x], want = (_, %s)", res, err, buf.Bytes(), tcpip.ErrWouldBlock) } }) } diff --git a/pkg/tcpip/tests/integration/multicast_broadcast_test.go b/pkg/tcpip/tests/integration/multicast_broadcast_test.go index 8be791a00..fae6c256a 100644 --- a/pkg/tcpip/tests/integration/multicast_broadcast_test.go +++ b/pkg/tcpip/tests/integration/multicast_broadcast_test.go @@ -15,12 +15,14 @@ package integration_test import ( + "bytes" "net" "testing" "github.com/google/go-cmp/cmp" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/checker" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/link/channel" "gvisor.dev/gvisor/pkg/tcpip/link/loopback" @@ -35,6 +37,9 @@ import ( const ( defaultMTU = 1280 ttl = 255 + + remotePort = 5555 + localPort = 80 ) var ( @@ -96,11 +101,11 @@ func TestPingMulticastBroadcast(t *testing.T) { pkt.SetChecksum(header.ICMPv6Checksum(pkt, remoteIPv6Addr, dst, buffer.VectorisedView{})) ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) ip.Encode(&header.IPv6Fields{ - PayloadLength: header.ICMPv6MinimumSize, - NextHeader: uint8(icmp.ProtocolNumber6), - HopLimit: ttl, - SrcAddr: remoteIPv6Addr, - DstAddr: dst, + PayloadLength: header.ICMPv6MinimumSize, + TransportProtocol: icmp.ProtocolNumber6, + HopLimit: ttl, + SrcAddr: remoteIPv6Addr, + DstAddr: dst, }) e.InjectInbound(header.IPv6ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ @@ -151,11 +156,11 @@ func TestPingMulticastBroadcast(t *testing.T) { } ipv4ProtoAddr := tcpip.ProtocolAddress{Protocol: header.IPv4ProtocolNumber, AddressWithPrefix: ipv4Addr} if err := s.AddProtocolAddress(nicID, ipv4ProtoAddr); err != nil { - t.Fatalf("AddProtocolAddress(%d, %+v): %s", nicID, ipv4ProtoAddr, err) + t.Fatalf("AddProtocolAddress(%d, %#v): %s", nicID, ipv4ProtoAddr, err) } ipv6ProtoAddr := tcpip.ProtocolAddress{Protocol: header.IPv6ProtocolNumber, AddressWithPrefix: ipv6Addr} if err := s.AddProtocolAddress(nicID, ipv6ProtoAddr); err != nil { - t.Fatalf("AddProtocolAddress(%d, %+v): %s", nicID, ipv6ProtoAddr, err) + t.Fatalf("AddProtocolAddress(%d, %#v): %s", nicID, ipv6ProtoAddr, err) } // Default routes for IPv4 and IPv6 so ICMP can find a route to the remote @@ -215,163 +220,219 @@ func TestPingMulticastBroadcast(t *testing.T) { } +func rxIPv4UDP(e *channel.Endpoint, src, dst tcpip.Address, data []byte) { + payloadLen := header.UDPMinimumSize + len(data) + totalLen := header.IPv4MinimumSize + payloadLen + hdr := buffer.NewPrependable(totalLen) + u := header.UDP(hdr.Prepend(payloadLen)) + u.Encode(&header.UDPFields{ + SrcPort: remotePort, + DstPort: localPort, + Length: uint16(payloadLen), + }) + copy(u.Payload(), data) + sum := header.PseudoHeaderChecksum(udp.ProtocolNumber, src, dst, uint16(payloadLen)) + sum = header.Checksum(data, sum) + u.SetChecksum(^u.CalculateChecksum(sum)) + + ip := header.IPv4(hdr.Prepend(header.IPv4MinimumSize)) + ip.Encode(&header.IPv4Fields{ + TotalLength: uint16(totalLen), + Protocol: uint8(udp.ProtocolNumber), + TTL: ttl, + SrcAddr: src, + DstAddr: dst, + }) + ip.SetChecksum(^ip.CalculateChecksum()) + + e.InjectInbound(header.IPv4ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: hdr.View().ToVectorisedView(), + })) +} + +func rxIPv6UDP(e *channel.Endpoint, src, dst tcpip.Address, data []byte) { + payloadLen := header.UDPMinimumSize + len(data) + hdr := buffer.NewPrependable(header.IPv6MinimumSize + payloadLen) + u := header.UDP(hdr.Prepend(payloadLen)) + u.Encode(&header.UDPFields{ + SrcPort: remotePort, + DstPort: localPort, + Length: uint16(payloadLen), + }) + copy(u.Payload(), data) + sum := header.PseudoHeaderChecksum(udp.ProtocolNumber, src, dst, uint16(payloadLen)) + sum = header.Checksum(data, sum) + u.SetChecksum(^u.CalculateChecksum(sum)) + + ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) + ip.Encode(&header.IPv6Fields{ + PayloadLength: uint16(payloadLen), + TransportProtocol: udp.ProtocolNumber, + HopLimit: ttl, + SrcAddr: src, + DstAddr: dst, + }) + + e.InjectInbound(header.IPv6ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: hdr.View().ToVectorisedView(), + })) +} + // TestIncomingMulticastAndBroadcast tests receiving a packet destined to some // multicast or broadcast address. func TestIncomingMulticastAndBroadcast(t *testing.T) { - const ( - nicID = 1 - remotePort = 5555 - localPort = 80 - ) + const nicID = 1 data := []byte{1, 2, 3, 4} - rxIPv4UDP := func(e *channel.Endpoint, dst tcpip.Address) { - payloadLen := header.UDPMinimumSize + len(data) - totalLen := header.IPv4MinimumSize + payloadLen - hdr := buffer.NewPrependable(totalLen) - u := header.UDP(hdr.Prepend(payloadLen)) - u.Encode(&header.UDPFields{ - SrcPort: remotePort, - DstPort: localPort, - Length: uint16(payloadLen), - }) - copy(u.Payload(), data) - sum := header.PseudoHeaderChecksum(udp.ProtocolNumber, remoteIPv4Addr, dst, uint16(payloadLen)) - sum = header.Checksum(data, sum) - u.SetChecksum(^u.CalculateChecksum(sum)) - - ip := header.IPv4(hdr.Prepend(header.IPv4MinimumSize)) - ip.Encode(&header.IPv4Fields{ - TotalLength: uint16(totalLen), - Protocol: uint8(udp.ProtocolNumber), - TTL: ttl, - SrcAddr: remoteIPv4Addr, - DstAddr: dst, - }) - ip.SetChecksum(^ip.CalculateChecksum()) - - e.InjectInbound(header.IPv4ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: hdr.View().ToVectorisedView(), - })) - } - - rxIPv6UDP := func(e *channel.Endpoint, dst tcpip.Address) { - payloadLen := header.UDPMinimumSize + len(data) - hdr := buffer.NewPrependable(header.IPv6MinimumSize + payloadLen) - u := header.UDP(hdr.Prepend(payloadLen)) - u.Encode(&header.UDPFields{ - SrcPort: remotePort, - DstPort: localPort, - Length: uint16(payloadLen), - }) - copy(u.Payload(), data) - sum := header.PseudoHeaderChecksum(udp.ProtocolNumber, remoteIPv6Addr, dst, uint16(payloadLen)) - sum = header.Checksum(data, sum) - u.SetChecksum(^u.CalculateChecksum(sum)) - - ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) - ip.Encode(&header.IPv6Fields{ - PayloadLength: uint16(payloadLen), - NextHeader: uint8(udp.ProtocolNumber), - HopLimit: ttl, - SrcAddr: remoteIPv6Addr, - DstAddr: dst, - }) - - e.InjectInbound(header.IPv6ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: hdr.View().ToVectorisedView(), - })) - } - tests := []struct { - name string - bindAddr tcpip.Address - dstAddr tcpip.Address - expectRx bool + name string + proto tcpip.NetworkProtocolNumber + remoteAddr tcpip.Address + localAddr tcpip.AddressWithPrefix + rxUDP func(*channel.Endpoint, tcpip.Address, tcpip.Address, []byte) + bindAddr tcpip.Address + dstAddr tcpip.Address + expectRx bool }{ { - name: "IPv4 unicast binding to unicast", - bindAddr: ipv4Addr.Address, - dstAddr: ipv4Addr.Address, - expectRx: true, + name: "IPv4 unicast binding to unicast", + proto: header.IPv4ProtocolNumber, + remoteAddr: remoteIPv4Addr, + localAddr: ipv4Addr, + rxUDP: rxIPv4UDP, + bindAddr: ipv4Addr.Address, + dstAddr: ipv4Addr.Address, + expectRx: true, }, { - name: "IPv4 unicast binding to broadcast", - bindAddr: header.IPv4Broadcast, - dstAddr: ipv4Addr.Address, - expectRx: false, + name: "IPv4 unicast binding to broadcast", + proto: header.IPv4ProtocolNumber, + remoteAddr: remoteIPv4Addr, + localAddr: ipv4Addr, + rxUDP: rxIPv4UDP, + bindAddr: header.IPv4Broadcast, + dstAddr: ipv4Addr.Address, + expectRx: false, }, { - name: "IPv4 unicast binding to wildcard", - dstAddr: ipv4Addr.Address, - expectRx: true, + name: "IPv4 unicast binding to wildcard", + proto: header.IPv4ProtocolNumber, + remoteAddr: remoteIPv4Addr, + localAddr: ipv4Addr, + rxUDP: rxIPv4UDP, + dstAddr: ipv4Addr.Address, + expectRx: true, }, { - name: "IPv4 directed broadcast binding to subnet broadcast", - bindAddr: ipv4SubnetBcast, - dstAddr: ipv4SubnetBcast, - expectRx: true, + name: "IPv4 directed broadcast binding to subnet broadcast", + proto: header.IPv4ProtocolNumber, + remoteAddr: remoteIPv4Addr, + localAddr: ipv4Addr, + rxUDP: rxIPv4UDP, + bindAddr: ipv4SubnetBcast, + dstAddr: ipv4SubnetBcast, + expectRx: true, }, { - name: "IPv4 directed broadcast binding to broadcast", - bindAddr: header.IPv4Broadcast, - dstAddr: ipv4SubnetBcast, - expectRx: false, + name: "IPv4 directed broadcast binding to broadcast", + proto: header.IPv4ProtocolNumber, + remoteAddr: remoteIPv4Addr, + localAddr: ipv4Addr, + rxUDP: rxIPv4UDP, + bindAddr: header.IPv4Broadcast, + dstAddr: ipv4SubnetBcast, + expectRx: false, }, { - name: "IPv4 directed broadcast binding to wildcard", - dstAddr: ipv4SubnetBcast, - expectRx: true, + name: "IPv4 directed broadcast binding to wildcard", + proto: header.IPv4ProtocolNumber, + remoteAddr: remoteIPv4Addr, + localAddr: ipv4Addr, + rxUDP: rxIPv4UDP, + dstAddr: ipv4SubnetBcast, + expectRx: true, }, { - name: "IPv4 broadcast binding to broadcast", - bindAddr: header.IPv4Broadcast, - dstAddr: header.IPv4Broadcast, - expectRx: true, + name: "IPv4 broadcast binding to broadcast", + proto: header.IPv4ProtocolNumber, + remoteAddr: remoteIPv4Addr, + localAddr: ipv4Addr, + rxUDP: rxIPv4UDP, + bindAddr: header.IPv4Broadcast, + dstAddr: header.IPv4Broadcast, + expectRx: true, }, { - name: "IPv4 broadcast binding to subnet broadcast", - bindAddr: ipv4SubnetBcast, - dstAddr: header.IPv4Broadcast, - expectRx: false, + name: "IPv4 broadcast binding to subnet broadcast", + proto: header.IPv4ProtocolNumber, + remoteAddr: remoteIPv4Addr, + localAddr: ipv4Addr, + rxUDP: rxIPv4UDP, + bindAddr: ipv4SubnetBcast, + dstAddr: header.IPv4Broadcast, + expectRx: false, }, { - name: "IPv4 broadcast binding to wildcard", - dstAddr: ipv4SubnetBcast, - expectRx: true, + name: "IPv4 broadcast binding to wildcard", + proto: header.IPv4ProtocolNumber, + remoteAddr: remoteIPv4Addr, + localAddr: ipv4Addr, + rxUDP: rxIPv4UDP, + dstAddr: ipv4SubnetBcast, + expectRx: true, }, { - name: "IPv4 all-systems multicast binding to all-systems multicast", - bindAddr: header.IPv4AllSystems, - dstAddr: header.IPv4AllSystems, - expectRx: true, + name: "IPv4 all-systems multicast binding to all-systems multicast", + proto: header.IPv4ProtocolNumber, + remoteAddr: remoteIPv4Addr, + localAddr: ipv4Addr, + rxUDP: rxIPv4UDP, + bindAddr: header.IPv4AllSystems, + dstAddr: header.IPv4AllSystems, + expectRx: true, }, { - name: "IPv4 all-systems multicast binding to wildcard", - dstAddr: header.IPv4AllSystems, - expectRx: true, + name: "IPv4 all-systems multicast binding to wildcard", + proto: header.IPv4ProtocolNumber, + remoteAddr: remoteIPv4Addr, + localAddr: ipv4Addr, + rxUDP: rxIPv4UDP, + dstAddr: header.IPv4AllSystems, + expectRx: true, }, { - name: "IPv4 all-systems multicast binding to unicast", - bindAddr: ipv4Addr.Address, - dstAddr: header.IPv4AllSystems, - expectRx: false, + name: "IPv4 all-systems multicast binding to unicast", + proto: header.IPv4ProtocolNumber, + remoteAddr: remoteIPv4Addr, + localAddr: ipv4Addr, + rxUDP: rxIPv4UDP, + bindAddr: ipv4Addr.Address, + dstAddr: header.IPv4AllSystems, + expectRx: false, }, // IPv6 has no notion of a broadcast. { - name: "IPv6 unicast binding to wildcard", - dstAddr: ipv6Addr.Address, - expectRx: true, + name: "IPv6 unicast binding to wildcard", + dstAddr: ipv6Addr.Address, + proto: header.IPv6ProtocolNumber, + remoteAddr: remoteIPv6Addr, + localAddr: ipv6Addr, + rxUDP: rxIPv6UDP, + expectRx: true, }, { - name: "IPv6 broadcast-like address binding to wildcard", - dstAddr: ipv6SubnetBcast, - expectRx: false, + name: "IPv6 broadcast-like address binding to wildcard", + dstAddr: ipv6SubnetBcast, + proto: header.IPv6ProtocolNumber, + remoteAddr: remoteIPv6Addr, + localAddr: ipv6Addr, + rxUDP: rxIPv6UDP, + expectRx: false, }, } @@ -385,52 +446,41 @@ func TestIncomingMulticastAndBroadcast(t *testing.T) { if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("CreateNIC(%d, _): %s", nicID, err) } - ipv4ProtoAddr := tcpip.ProtocolAddress{Protocol: header.IPv4ProtocolNumber, AddressWithPrefix: ipv4Addr} - if err := s.AddProtocolAddress(nicID, ipv4ProtoAddr); err != nil { - t.Fatalf("AddProtocolAddress(%d, %+v): %s", nicID, ipv4ProtoAddr, err) - } - ipv6ProtoAddr := tcpip.ProtocolAddress{Protocol: header.IPv6ProtocolNumber, AddressWithPrefix: ipv6Addr} - if err := s.AddProtocolAddress(nicID, ipv6ProtoAddr); err != nil { - t.Fatalf("AddProtocolAddress(%d, %+v): %s", nicID, ipv6ProtoAddr, err) - } - - var netproto tcpip.NetworkProtocolNumber - var rxUDP func(*channel.Endpoint, tcpip.Address) - switch l := len(test.dstAddr); l { - case header.IPv4AddressSize: - netproto = header.IPv4ProtocolNumber - rxUDP = rxIPv4UDP - case header.IPv6AddressSize: - netproto = header.IPv6ProtocolNumber - rxUDP = rxIPv6UDP - default: - t.Fatalf("got unexpected address length = %d bytes", l) + protoAddr := tcpip.ProtocolAddress{Protocol: test.proto, AddressWithPrefix: test.localAddr} + if err := s.AddProtocolAddress(nicID, protoAddr); err != nil { + t.Fatalf("AddProtocolAddress(%d, %#v): %s", nicID, protoAddr, err) } var wq waiter.Queue - ep, err := s.NewEndpoint(udp.ProtocolNumber, netproto, &wq) + ep, err := s.NewEndpoint(udp.ProtocolNumber, test.proto, &wq) if err != nil { - t.Fatalf("NewEndpoint(%d, %d, _): %s", udp.ProtocolNumber, netproto, err) + t.Fatalf("NewEndpoint(%d, %d, _): %s", udp.ProtocolNumber, test.proto, err) } defer ep.Close() bindAddr := tcpip.FullAddress{Addr: test.bindAddr, Port: localPort} if err := ep.Bind(bindAddr); err != nil { - t.Fatalf("ep.Bind(%+v): %s", bindAddr, err) + t.Fatalf("ep.Bind(%#v): %s", bindAddr, err) } - rxUDP(e, test.dstAddr) - if gotPayload, _, err := ep.Read(nil); test.expectRx { + test.rxUDP(e, test.remoteAddr, test.dstAddr, data) + var buf bytes.Buffer + var opts tcpip.ReadOptions + if res, err := ep.Read(&buf, len(data), opts); test.expectRx { if err != nil { - t.Fatalf("Read(nil): %s", err) + t.Fatalf("ep.Read(_, %d, %#v): %s", len(data), opts, err) } - if diff := cmp.Diff(buffer.View(data), gotPayload); diff != "" { - t.Errorf("got UDP payload mismatch (-want +got):\n%s", diff) + if diff := cmp.Diff(tcpip.ReadResult{ + Count: buf.Len(), + Total: buf.Len(), + }, res, checker.IgnoreCmpPath("ControlMessages")); diff != "" { + t.Errorf("ep.Read: unexpected result (-want +got):\n%s", diff) } - } else { - if err != tcpip.ErrWouldBlock { - t.Fatalf("got Read(nil) = (%x, _, %s), want = (_, _, %s)", gotPayload, err, tcpip.ErrWouldBlock) + if diff := cmp.Diff(data, buf.Bytes()); diff != "" { + t.Errorf("got UDP payload mismatch (-want +got):\n%s", diff) } + } else if err != tcpip.ErrWouldBlock { + t.Fatalf("got Read = (%v, %s) [with data %x], want = (_, %s)", res, err, buf.Bytes(), tcpip.ErrWouldBlock) } }) } @@ -476,7 +526,7 @@ func TestReuseAddrAndBroadcast(t *testing.T) { }, } if err := s.AddProtocolAddress(nicID, protoAddr); err != nil { - t.Fatalf("AddProtocolAddress(%d, %+v): %s", nicID, protoAddr, err) + t.Fatalf("AddProtocolAddress(%d, %#v): %s", nicID, protoAddr, err) } s.SetRouteTable([]tcpip.Route{ @@ -516,12 +566,12 @@ func TestReuseAddrAndBroadcast(t *testing.T) { bindAddr := tcpip.FullAddress{Port: localPort} if bindWildcard { if err := ep.Bind(bindAddr); err != nil { - t.Fatalf("eps[%d].Bind(%+v): %s", len(eps), bindAddr, err) + t.Fatalf("eps[%d].Bind(%#v): %s", len(eps), bindAddr, err) } } else { bindAddr.Addr = test.broadcastAddr if err := ep.Bind(bindAddr); err != nil { - t.Fatalf("eps[%d].Bind(%+v): %s", len(eps), bindAddr, err) + t.Fatalf("eps[%d].Bind(%#v): %s", len(eps), bindAddr, err) } } @@ -547,9 +597,19 @@ func TestReuseAddrAndBroadcast(t *testing.T) { // Wait for the endpoint to become readable. <-rep.ch - if gotPayload, _, err := rep.ep.Read(nil); err != nil { - t.Errorf("(eps[%d] write) eps[%d].Read(nil): %s", i, j, err) - } else if diff := cmp.Diff(buffer.View(data), gotPayload); diff != "" { + var buf bytes.Buffer + result, err := rep.ep.Read(&buf, len(data), tcpip.ReadOptions{}) + if err != nil { + t.Errorf("(eps[%d] write) eps[%d].Read: %s", i, j, err) + continue + } + if diff := cmp.Diff(tcpip.ReadResult{ + Count: buf.Len(), + Total: buf.Len(), + }, result, checker.IgnoreCmpPath("ControlMessages")); diff != "" { + t.Errorf("(eps[%d] write) eps[%d].Read: unexpected result (-want +got):\n%s", i, j, diff) + } + if diff := cmp.Diff([]byte(data), buf.Bytes()); diff != "" { t.Errorf("(eps[%d] write) got UDP payload from eps[%d] mismatch (-want +got):\n%s", i, j, diff) } } @@ -557,3 +617,153 @@ func TestReuseAddrAndBroadcast(t *testing.T) { }) } } + +func TestUDPAddRemoveMembershipSocketOption(t *testing.T) { + const ( + nicID = 1 + ) + + data := []byte{1, 2, 3, 4} + + tests := []struct { + name string + proto tcpip.NetworkProtocolNumber + remoteAddr tcpip.Address + localAddr tcpip.AddressWithPrefix + rxUDP func(*channel.Endpoint, tcpip.Address, tcpip.Address, []byte) + multicastAddr tcpip.Address + }{ + { + name: "IPv4 unicast binding to unicast", + multicastAddr: "\xe0\x01\x02\x03", + proto: header.IPv4ProtocolNumber, + remoteAddr: remoteIPv4Addr, + localAddr: ipv4Addr, + rxUDP: rxIPv4UDP, + }, + { + name: "IPv6 broadcast-like address binding to wildcard", + multicastAddr: "\xff\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x02\x03\x04", + proto: header.IPv6ProtocolNumber, + remoteAddr: remoteIPv6Addr, + localAddr: ipv6Addr, + rxUDP: rxIPv6UDP, + }, + } + + subTests := []struct { + name string + specifyNICID bool + specifyNICAddr bool + }{ + { + name: "Specify NIC ID and NIC address", + specifyNICID: true, + specifyNICAddr: true, + }, + { + name: "Don't specify NIC ID or NIC address", + specifyNICID: false, + specifyNICAddr: false, + }, + { + name: "Specify NIC ID but don't specify NIC address", + specifyNICID: true, + specifyNICAddr: false, + }, + { + name: "Don't specify NIC ID but specify NIC address", + specifyNICID: false, + specifyNICAddr: true, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + for _, subTest := range subTests { + t.Run(subTest.name, func(t *testing.T) { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}, + }) + e := channel.New(0, defaultMTU, "") + if err := s.CreateNIC(nicID, e); err != nil { + t.Fatalf("CreateNIC(%d, _): %s", nicID, err) + } + protoAddr := tcpip.ProtocolAddress{Protocol: test.proto, AddressWithPrefix: test.localAddr} + if err := s.AddProtocolAddress(nicID, protoAddr); err != nil { + t.Fatalf("AddProtocolAddress(%d, %#v): %s", nicID, protoAddr, err) + } + + // Set the route table so that UDP can find a NIC that is + // routable to the multicast address when the NIC isn't specified. + if !subTest.specifyNICID && !subTest.specifyNICAddr { + s.SetRouteTable([]tcpip.Route{ + tcpip.Route{ + Destination: header.IPv6EmptySubnet, + NIC: nicID, + }, + tcpip.Route{ + Destination: header.IPv4EmptySubnet, + NIC: nicID, + }, + }) + } + + var wq waiter.Queue + ep, err := s.NewEndpoint(udp.ProtocolNumber, test.proto, &wq) + if err != nil { + t.Fatalf("NewEndpoint(%d, %d, _): %s", udp.ProtocolNumber, test.proto, err) + } + defer ep.Close() + + bindAddr := tcpip.FullAddress{Port: localPort} + if err := ep.Bind(bindAddr); err != nil { + t.Fatalf("ep.Bind(%#v): %s", bindAddr, err) + } + + memOpt := tcpip.MembershipOption{MulticastAddr: test.multicastAddr} + if subTest.specifyNICID { + memOpt.NIC = nicID + } + if subTest.specifyNICAddr { + memOpt.InterfaceAddr = test.localAddr.Address + } + + // We should receive UDP packets to the group once we join the + // multicast group. + addOpt := tcpip.AddMembershipOption(memOpt) + if err := ep.SetSockOpt(&addOpt); err != nil { + t.Fatalf("ep.SetSockOpt(&%#v): %s", addOpt, err) + } + test.rxUDP(e, test.remoteAddr, test.multicastAddr, data) + var buf bytes.Buffer + result, err := ep.Read(&buf, len(data), tcpip.ReadOptions{}) + if err != nil { + t.Fatalf("ep.Read: %s", err) + } else { + if diff := cmp.Diff(tcpip.ReadResult{ + Count: buf.Len(), + Total: buf.Len(), + }, result, checker.IgnoreCmpPath("ControlMessages")); diff != "" { + t.Errorf("ep.Read: unexpected result (-want +got):\n%s", diff) + } + if diff := cmp.Diff(data, buf.Bytes()); diff != "" { + t.Errorf("got UDP payload mismatch (-want +got):\n%s", diff) + } + } + + // We should not receive UDP packets to the group once we leave + // the multicast group. + removeOpt := tcpip.RemoveMembershipOption(memOpt) + if err := ep.SetSockOpt(&removeOpt); err != nil { + t.Fatalf("ep.SetSockOpt(&%#v): %s", removeOpt, err) + } + if _, err := ep.Read(&buf, 1, tcpip.ReadOptions{}); err != tcpip.ErrWouldBlock { + t.Fatalf("got ep.Read = (_, %s), want = (_, %s)", err, tcpip.ErrWouldBlock) + } + }) + } + }) + } +} diff --git a/pkg/tcpip/tests/integration/route_test.go b/pkg/tcpip/tests/integration/route_test.go index 02fc47015..52cf89b54 100644 --- a/pkg/tcpip/tests/integration/route_test.go +++ b/pkg/tcpip/tests/integration/route_test.go @@ -15,11 +15,14 @@ package integration_test import ( + "bytes" + "math" "testing" "github.com/google/go-cmp/cmp" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/checker" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/link/channel" "gvisor.dev/gvisor/pkg/tcpip/link/loopback" @@ -203,16 +206,25 @@ func TestLocalPing(t *testing.T) { // Wait for the endpoint to become readable. <-ch - var addr tcpip.FullAddress - v, _, err := ep.Read(&addr) + var buf bytes.Buffer + opts := tcpip.ReadOptions{NeedRemoteAddr: true} + res, err := ep.Read(&buf, math.MaxUint16, opts) if err != nil { - t.Fatalf("ep.Read(_): %s", err) + t.Fatalf("ep.Read(_, %d, %#v): %s", math.MaxUint16, opts, err) } - if diff := cmp.Diff(v[icmpDataOffset:], buffer.View(payload[icmpDataOffset:])); diff != "" { - t.Errorf("received data mismatch (-want +got):\n%s", diff) + if diff := cmp.Diff(tcpip.ReadResult{ + Count: buf.Len(), + Total: buf.Len(), + RemoteAddr: tcpip.FullAddress{Addr: test.localAddr}, + }, res, checker.IgnoreCmpPath( + "ControlMessages", + "RemoteAddr.NIC", + "RemoteAddr.Port", + )); diff != "" { + t.Errorf("ep.Read: unexpected result (-want +got):\n%s", diff) } - if addr.Addr != test.localAddr { - t.Errorf("got addr.Addr = %s, want = %s", addr.Addr, test.localAddr) + if diff := cmp.Diff(buf.Bytes()[icmpDataOffset:], []byte(payload[icmpDataOffset:])); diff != "" { + t.Errorf("received data mismatch (-want +got):\n%s", diff) } test.checkLinkEndpoint(t, e) @@ -338,14 +350,27 @@ func TestLocalUDP(t *testing.T) { <-serverCH var clientAddr tcpip.FullAddress - if v, _, err := server.Read(&clientAddr); err != nil { + var readBuf bytes.Buffer + if read, err := server.Read(&readBuf, math.MaxUint16, tcpip.ReadOptions{NeedRemoteAddr: true}); err != nil { t.Fatalf("server.Read(_): %s", err) } else { - if diff := cmp.Diff(buffer.View(clientPayload), v); diff != "" { - t.Errorf("server read clientPayload mismatch (-want +got):\n%s", diff) + clientAddr = read.RemoteAddr + + if diff := cmp.Diff(tcpip.ReadResult{ + Count: readBuf.Len(), + Total: readBuf.Len(), + RemoteAddr: tcpip.FullAddress{ + Addr: test.canBePrimaryAddr.AddressWithPrefix.Address, + }, + }, read, checker.IgnoreCmpPath( + "ControlMessages", + "RemoteAddr.NIC", + "RemoteAddr.Port", + )); diff != "" { + t.Errorf("server.Read: unexpected result (-want +got):\n%s", diff) } - if clientAddr.Addr != test.canBePrimaryAddr.AddressWithPrefix.Address { - t.Errorf("got clientAddr.Addr = %s, want = %s", clientAddr.Addr, test.canBePrimaryAddr.AddressWithPrefix.Address) + if diff := cmp.Diff(buffer.View(clientPayload), buffer.View(readBuf.Bytes())); diff != "" { + t.Errorf("server read clientPayload mismatch (-want +got):\n%s", diff) } if t.Failed() { t.FailNow() @@ -367,15 +392,23 @@ func TestLocalUDP(t *testing.T) { // Wait for the client endpoint to become readable. <-clientCH - var gotServerAddr tcpip.FullAddress - if v, _, err := client.Read(&gotServerAddr); err != nil { + readBuf.Reset() + if read, err := client.Read(&readBuf, math.MaxUint16, tcpip.ReadOptions{NeedRemoteAddr: true}); err != nil { t.Fatalf("client.Read(_): %s", err) } else { - if diff := cmp.Diff(buffer.View(serverPayload), v); diff != "" { - t.Errorf("client read serverPayload mismatch (-want +got):\n%s", diff) + if diff := cmp.Diff(tcpip.ReadResult{ + Count: readBuf.Len(), + Total: readBuf.Len(), + RemoteAddr: tcpip.FullAddress{Addr: serverAddr.Addr}, + }, read, checker.IgnoreCmpPath( + "ControlMessages", + "RemoteAddr.NIC", + "RemoteAddr.Port", + )); diff != "" { + t.Errorf("client.Read: unexpected result (-want +got):\n%s", diff) } - if gotServerAddr.Addr != serverAddr.Addr { - t.Errorf("got gotServerAddr.Addr = %s, want = %s", gotServerAddr.Addr, serverAddr.Addr) + if diff := cmp.Diff(buffer.View(serverPayload), buffer.View(readBuf.Bytes())); diff != "" { + t.Errorf("client read serverPayload mismatch (-want +got):\n%s", diff) } if t.Failed() { t.FailNow() |