summaryrefslogtreecommitdiffhomepage
path: root/test
diff options
context:
space:
mode:
Diffstat (limited to 'test')
-rw-r--r--test/iptables/filter_output.go170
-rw-r--r--test/iptables/iptables_test.go24
-rw-r--r--test/iptables/iptables_util.go15
-rw-r--r--test/packetimpact/testbench/BUILD1
-rw-r--r--test/packetimpact/testbench/connections.go37
-rw-r--r--test/packetimpact/testbench/dut.go52
-rw-r--r--test/packetimpact/testbench/dut_client.go6
-rw-r--r--test/packetimpact/testbench/layers.go5
-rw-r--r--test/packetimpact/testbench/rawsockets.go6
-rw-r--r--test/packetimpact/testbench/testbench.go63
-rw-r--r--test/packetimpact/tests/BUILD11
-rw-r--r--test/packetimpact/tests/fin_wait2_timeout_test.go5
-rw-r--r--test/packetimpact/tests/icmpv6_param_problem_test.go5
-rw-r--r--test/packetimpact/tests/tcp_close_wait_ack_test.go5
-rw-r--r--test/packetimpact/tests/tcp_noaccept_close_rst_test.go5
-rw-r--r--test/packetimpact/tests/tcp_outside_the_window_test.go5
-rw-r--r--test/packetimpact/tests/tcp_should_piggyback_test.go5
-rw-r--r--test/packetimpact/tests/tcp_user_timeout_test.go5
-rw-r--r--test/packetimpact/tests/tcp_window_shrink_test.go5
-rw-r--r--test/packetimpact/tests/tcp_zero_window_probe_retransmit_test.go5
-rw-r--r--test/packetimpact/tests/tcp_zero_window_probe_test.go5
-rw-r--r--test/packetimpact/tests/tcp_zero_window_probe_usertimeout_test.go5
-rw-r--r--test/packetimpact/tests/udp_icmp_error_propagation_test.go21
-rw-r--r--test/packetimpact/tests/udp_recv_multicast_test.go5
-rw-r--r--test/packetimpact/tests/udp_send_recv_dgram_test.go101
-rw-r--r--test/syscalls/linux/BUILD1
-rw-r--r--test/syscalls/linux/socket_bind_to_device_sequence.cc5
27 files changed, 510 insertions, 68 deletions
diff --git a/test/iptables/filter_output.go b/test/iptables/filter_output.go
index b1382d353..c145bd1e9 100644
--- a/test/iptables/filter_output.go
+++ b/test/iptables/filter_output.go
@@ -29,6 +29,12 @@ func init() {
RegisterTestCase(FilterOutputAcceptUDPOwner{})
RegisterTestCase(FilterOutputDropUDPOwner{})
RegisterTestCase(FilterOutputOwnerFail{})
+ RegisterTestCase(FilterOutputInterfaceAccept{})
+ RegisterTestCase(FilterOutputInterfaceDrop{})
+ RegisterTestCase(FilterOutputInterface{})
+ RegisterTestCase(FilterOutputInterfaceBeginsWith{})
+ RegisterTestCase(FilterOutputInterfaceInvertDrop{})
+ RegisterTestCase(FilterOutputInterfaceInvertAccept{})
}
// FilterOutputDropTCPDestPort tests that connections are not accepted on
@@ -286,3 +292,167 @@ func (FilterOutputInvertDestination) ContainerAction(ip net.IP) error {
func (FilterOutputInvertDestination) LocalAction(ip net.IP) error {
return listenUDP(acceptPort, sendloopDuration)
}
+
+// FilterOutputInterfaceAccept tests that packets are sent via interface
+// matching the iptables rule.
+type FilterOutputInterfaceAccept struct{}
+
+// Name implements TestCase.Name.
+func (FilterOutputInterfaceAccept) Name() string {
+ return "FilterOutputInterfaceAccept"
+}
+
+// ContainerAction implements TestCase.ContainerAction.
+func (FilterOutputInterfaceAccept) ContainerAction(ip net.IP) error {
+ ifname, ok := getInterfaceName()
+ if !ok {
+ return fmt.Errorf("no interface is present, except loopback")
+ }
+ if err := filterTable("-A", "OUTPUT", "-p", "udp", "-o", ifname, "-j", "ACCEPT"); err != nil {
+ return err
+ }
+
+ return sendUDPLoop(ip, acceptPort, sendloopDuration)
+}
+
+// LocalAction implements TestCase.LocalAction.
+func (FilterOutputInterfaceAccept) LocalAction(ip net.IP) error {
+ return listenUDP(acceptPort, sendloopDuration)
+}
+
+// FilterOutputInterfaceDrop tests that packets are not sent via interface
+// matching the iptables rule.
+type FilterOutputInterfaceDrop struct{}
+
+// Name implements TestCase.Name.
+func (FilterOutputInterfaceDrop) Name() string {
+ return "FilterOutputInterfaceDrop"
+}
+
+// ContainerAction implements TestCase.ContainerAction.
+func (FilterOutputInterfaceDrop) ContainerAction(ip net.IP) error {
+ ifname, ok := getInterfaceName()
+ if !ok {
+ return fmt.Errorf("no interface is present, except loopback")
+ }
+ if err := filterTable("-A", "OUTPUT", "-p", "udp", "-o", ifname, "-j", "DROP"); err != nil {
+ return err
+ }
+
+ return sendUDPLoop(ip, acceptPort, sendloopDuration)
+}
+
+// LocalAction implements TestCase.LocalAction.
+func (FilterOutputInterfaceDrop) LocalAction(ip net.IP) error {
+ if err := listenUDP(acceptPort, sendloopDuration); err == nil {
+ return fmt.Errorf("packets should not be received on port %v, but are received", acceptPort)
+ }
+
+ return nil
+}
+
+// FilterOutputInterface tests that packets are sent via interface which is
+// not matching the interface name in the iptables rule.
+type FilterOutputInterface struct{}
+
+// Name implements TestCase.Name.
+func (FilterOutputInterface) Name() string {
+ return "FilterOutputInterface"
+}
+
+// ContainerAction implements TestCase.ContainerAction.
+func (FilterOutputInterface) ContainerAction(ip net.IP) error {
+ if err := filterTable("-A", "OUTPUT", "-p", "udp", "-o", "lo", "-j", "DROP"); err != nil {
+ return err
+ }
+
+ return sendUDPLoop(ip, acceptPort, sendloopDuration)
+}
+
+// LocalAction implements TestCase.LocalAction.
+func (FilterOutputInterface) LocalAction(ip net.IP) error {
+ return listenUDP(acceptPort, sendloopDuration)
+}
+
+// FilterOutputInterfaceBeginsWith tests that packets are not sent via an
+// interface which begins with the given interface name.
+type FilterOutputInterfaceBeginsWith struct{}
+
+// Name implements TestCase.Name.
+func (FilterOutputInterfaceBeginsWith) Name() string {
+ return "FilterOutputInterfaceBeginsWith"
+}
+
+// ContainerAction implements TestCase.ContainerAction.
+func (FilterOutputInterfaceBeginsWith) ContainerAction(ip net.IP) error {
+ if err := filterTable("-A", "OUTPUT", "-p", "udp", "-o", "e+", "-j", "DROP"); err != nil {
+ return err
+ }
+
+ return sendUDPLoop(ip, acceptPort, sendloopDuration)
+}
+
+// LocalAction implements TestCase.LocalAction.
+func (FilterOutputInterfaceBeginsWith) LocalAction(ip net.IP) error {
+ if err := listenUDP(acceptPort, sendloopDuration); err == nil {
+ return fmt.Errorf("packets should not be received on port %v, but are received", acceptPort)
+ }
+
+ return nil
+}
+
+// FilterOutputInterfaceInvertDrop tests that we selectively do not send
+// packets via interface not matching the interface name.
+type FilterOutputInterfaceInvertDrop struct{}
+
+// Name implements TestCase.Name.
+func (FilterOutputInterfaceInvertDrop) Name() string {
+ return "FilterOutputInterfaceInvertDrop"
+}
+
+// ContainerAction implements TestCase.ContainerAction.
+func (FilterOutputInterfaceInvertDrop) ContainerAction(ip net.IP) error {
+ if err := filterTable("-A", "OUTPUT", "-p", "tcp", "!", "-o", "lo", "-j", "DROP"); err != nil {
+ return err
+ }
+
+ // Listen for TCP packets on accept port.
+ if err := listenTCP(acceptPort, sendloopDuration); err == nil {
+ return fmt.Errorf("connection on port %d should not be accepted, but got accepted", acceptPort)
+ }
+
+ return nil
+}
+
+// LocalAction implements TestCase.LocalAction.
+func (FilterOutputInterfaceInvertDrop) LocalAction(ip net.IP) error {
+ if err := connectTCP(ip, acceptPort, sendloopDuration); err == nil {
+ return fmt.Errorf("connection destined to port %d should not be accepted, but got accepted", acceptPort)
+ }
+
+ return nil
+}
+
+// FilterOutputInterfaceInvertAccept tests that we can selectively send packets
+// not matching the specific outgoing interface.
+type FilterOutputInterfaceInvertAccept struct{}
+
+// Name implements TestCase.Name.
+func (FilterOutputInterfaceInvertAccept) Name() string {
+ return "FilterOutputInterfaceInvertAccept"
+}
+
+// ContainerAction implements TestCase.ContainerAction.
+func (FilterOutputInterfaceInvertAccept) ContainerAction(ip net.IP) error {
+ if err := filterTable("-A", "OUTPUT", "-p", "tcp", "!", "-o", "lo", "-j", "ACCEPT"); err != nil {
+ return err
+ }
+
+ // Listen for TCP packets on accept port.
+ return listenTCP(acceptPort, sendloopDuration)
+}
+
+// LocalAction implements TestCase.LocalAction.
+func (FilterOutputInterfaceInvertAccept) LocalAction(ip net.IP) error {
+ return connectTCP(ip, acceptPort, sendloopDuration)
+}
diff --git a/test/iptables/iptables_test.go b/test/iptables/iptables_test.go
index 63a862d35..84eb75a40 100644
--- a/test/iptables/iptables_test.go
+++ b/test/iptables/iptables_test.go
@@ -167,6 +167,30 @@ func TestFilterOutputOwnerFail(t *testing.T) {
singleTest(t, FilterOutputOwnerFail{})
}
+func TestFilterOutputInterfaceAccept(t *testing.T) {
+ singleTest(t, FilterOutputInterfaceAccept{})
+}
+
+func TestFilterOutputInterfaceDrop(t *testing.T) {
+ singleTest(t, FilterOutputInterfaceDrop{})
+}
+
+func TestFilterOutputInterface(t *testing.T) {
+ singleTest(t, FilterOutputInterface{})
+}
+
+func TestFilterOutputInterfaceBeginsWith(t *testing.T) {
+ singleTest(t, FilterOutputInterfaceBeginsWith{})
+}
+
+func TestFilterOutputInterfaceInvertDrop(t *testing.T) {
+ singleTest(t, FilterOutputInterfaceInvertDrop{})
+}
+
+func TestFilterOutputInterfaceInvertAccept(t *testing.T) {
+ singleTest(t, FilterOutputInterfaceInvertAccept{})
+}
+
func TestJumpSerialize(t *testing.T) {
singleTest(t, FilterInputSerializeJump{})
}
diff --git a/test/iptables/iptables_util.go b/test/iptables/iptables_util.go
index 2f988cd18..7146edbb9 100644
--- a/test/iptables/iptables_util.go
+++ b/test/iptables/iptables_util.go
@@ -169,3 +169,18 @@ func localAddrs() ([]string, error) {
}
return addrStrs, nil
}
+
+// getInterfaceName returns the name of the interface other than loopback.
+func getInterfaceName() (string, bool) {
+ var ifname string
+ if interfaces, err := net.Interfaces(); err == nil {
+ for _, intf := range interfaces {
+ if intf.Name != "lo" {
+ ifname = intf.Name
+ break
+ }
+ }
+ }
+
+ return ifname, ifname != ""
+}
diff --git a/test/packetimpact/testbench/BUILD b/test/packetimpact/testbench/BUILD
index 3ceceb9d7..fed51006f 100644
--- a/test/packetimpact/testbench/BUILD
+++ b/test/packetimpact/testbench/BUILD
@@ -13,6 +13,7 @@ go_library(
"dut_client.go",
"layers.go",
"rawsockets.go",
+ "testbench.go",
],
deps = [
"//pkg/tcpip",
diff --git a/test/packetimpact/testbench/connections.go b/test/packetimpact/testbench/connections.go
index 56ac3fa54..463fd0556 100644
--- a/test/packetimpact/testbench/connections.go
+++ b/test/packetimpact/testbench/connections.go
@@ -17,7 +17,6 @@
package testbench
import (
- "flag"
"fmt"
"math/rand"
"net"
@@ -32,13 +31,6 @@ import (
"gvisor.dev/gvisor/pkg/tcpip/seqnum"
)
-var localIPv4 = flag.String("local_ipv4", "", "local IPv4 address for test packets")
-var remoteIPv4 = flag.String("remote_ipv4", "", "remote IPv4 address for test packets")
-var localIPv6 = flag.String("local_ipv6", "", "local IPv6 address for test packets")
-var remoteIPv6 = flag.String("remote_ipv6", "", "remote IPv6 address for test packets")
-var localMAC = flag.String("local_mac", "", "local mac address for test packets")
-var remoteMAC = flag.String("remote_mac", "", "remote mac address for test packets")
-
func portFromSockaddr(sa unix.Sockaddr) (uint16, error) {
switch sa := sa.(type) {
case *unix.SockaddrInet4:
@@ -64,11 +56,11 @@ func pickPort(domain, typ int) (fd int, sa unix.Sockaddr, err error) {
switch domain {
case unix.AF_INET:
var sa4 unix.SockaddrInet4
- copy(sa4.Addr[:], net.ParseIP(*localIPv4).To4())
+ copy(sa4.Addr[:], net.ParseIP(LocalIPv4).To4())
sa = &sa4
case unix.AF_INET6:
var sa6 unix.SockaddrInet6
- copy(sa6.Addr[:], net.ParseIP(*localIPv6).To16())
+ copy(sa6.Addr[:], net.ParseIP(LocalIPv6).To16())
sa = &sa6
default:
return -1, nil, fmt.Errorf("invalid domain %d, it should be one of unix.AF_INET or unix.AF_INET6", domain)
@@ -120,12 +112,12 @@ var _ layerState = (*etherState)(nil)
// newEtherState creates a new etherState.
func newEtherState(out, in Ether) (*etherState, error) {
- lMAC, err := tcpip.ParseMACAddress(*localMAC)
+ lMAC, err := tcpip.ParseMACAddress(LocalMAC)
if err != nil {
return nil, err
}
- rMAC, err := tcpip.ParseMACAddress(*remoteMAC)
+ rMAC, err := tcpip.ParseMACAddress(RemoteMAC)
if err != nil {
return nil, err
}
@@ -172,8 +164,8 @@ var _ layerState = (*ipv4State)(nil)
// newIPv4State creates a new ipv4State.
func newIPv4State(out, in IPv4) (*ipv4State, error) {
- lIP := tcpip.Address(net.ParseIP(*localIPv4).To4())
- rIP := tcpip.Address(net.ParseIP(*remoteIPv4).To4())
+ lIP := tcpip.Address(net.ParseIP(LocalIPv4).To4())
+ rIP := tcpip.Address(net.ParseIP(RemoteIPv4).To4())
s := ipv4State{
out: IPv4{SrcAddr: &lIP, DstAddr: &rIP},
in: IPv4{SrcAddr: &rIP, DstAddr: &lIP},
@@ -217,8 +209,8 @@ var _ layerState = (*ipv6State)(nil)
// newIPv6State creates a new ipv6State.
func newIPv6State(out, in IPv6) (*ipv6State, error) {
- lIP := tcpip.Address(net.ParseIP(*localIPv6).To16())
- rIP := tcpip.Address(net.ParseIP(*remoteIPv6).To16())
+ lIP := tcpip.Address(net.ParseIP(LocalIPv6).To16())
+ rIP := tcpip.Address(net.ParseIP(RemoteIPv6).To16())
s := ipv6State{
out: IPv6{SrcAddr: &lIP, DstAddr: &rIP},
in: IPv6{SrcAddr: &rIP, DstAddr: &lIP},
@@ -841,6 +833,7 @@ func (conn *UDPIPv4) SendIP(additionalLayers ...Layer) {
// Expect expects a frame with the UDP layer matching the provided UDP within
// the timeout specified. If it doesn't arrive in time, an error is returned.
func (conn *UDPIPv4) Expect(udp UDP, timeout time.Duration) (*UDP, error) {
+ conn.t.Helper()
layer, err := (*Connection)(conn).Expect(&udp, timeout)
if layer == nil {
return nil, err
@@ -852,6 +845,18 @@ func (conn *UDPIPv4) Expect(udp UDP, timeout time.Duration) (*UDP, error) {
return gotUDP, err
}
+// ExpectData is a convenient method that expects a Layer and the Layer after
+// it. If it doens't arrive in time, it returns nil.
+func (conn *UDPIPv4) ExpectData(udp UDP, payload Payload, timeout time.Duration) (Layers, error) {
+ conn.t.Helper()
+ expected := make([]Layer, len(conn.layerStates))
+ expected[len(expected)-1] = &udp
+ if payload.length() != 0 {
+ expected = append(expected, &payload)
+ }
+ return (*Connection)(conn).ExpectFrame(expected, timeout)
+}
+
// Close frees associated resources held by the UDPIPv4 connection.
func (conn *UDPIPv4) Close() {
(*Connection)(conn).Close()
diff --git a/test/packetimpact/testbench/dut.go b/test/packetimpact/testbench/dut.go
index f68d9d62b..a78b7d7ee 100644
--- a/test/packetimpact/testbench/dut.go
+++ b/test/packetimpact/testbench/dut.go
@@ -16,12 +16,10 @@ package testbench
import (
"context"
- "flag"
"net"
"strconv"
"syscall"
"testing"
- "time"
pb "gvisor.dev/gvisor/test/packetimpact/proto/posix_server_go_proto"
@@ -30,29 +28,21 @@ import (
"google.golang.org/grpc/keepalive"
)
-var (
- posixServerIP = flag.String("posix_server_ip", "", "ip address to listen to for UDP commands")
- posixServerPort = flag.Int("posix_server_port", 40000, "port to listen to for UDP commands")
- rpcTimeout = flag.Duration("rpc_timeout", 100*time.Millisecond, "gRPC timeout")
- rpcKeepalive = flag.Duration("rpc_keepalive", 10*time.Second, "gRPC keepalive")
-)
-
// DUT communicates with the DUT to force it to make POSIX calls.
type DUT struct {
t *testing.T
conn *grpc.ClientConn
- posixServer PosixClient
+ posixServer POSIXClient
}
// NewDUT creates a new connection with the DUT over gRPC.
func NewDUT(t *testing.T) DUT {
- flag.Parse()
- posixServerAddress := *posixServerIP + ":" + strconv.Itoa(*posixServerPort)
- conn, err := grpc.Dial(posixServerAddress, grpc.WithInsecure(), grpc.WithKeepaliveParams(keepalive.ClientParameters{Timeout: *rpcKeepalive}))
+ posixServerAddress := POSIXServerIP + ":" + strconv.Itoa(POSIXServerPort)
+ conn, err := grpc.Dial(posixServerAddress, grpc.WithInsecure(), grpc.WithKeepaliveParams(keepalive.ClientParameters{Timeout: RPCKeepalive}))
if err != nil {
t.Fatalf("failed to grpc.Dial(%s): %s", posixServerAddress, err)
}
- posixServer := NewPosixClient(conn)
+ posixServer := NewPOSIXClient(conn)
return DUT{
t: t,
conn: conn,
@@ -149,12 +139,12 @@ func (dut *DUT) CreateBoundSocket(typ, proto int32, addr net.IP) (int32, uint16)
// CreateListener makes a new TCP connection. If it fails, the test ends.
func (dut *DUT) CreateListener(typ, proto, backlog int32) (int32, uint16) {
- fd, remotePort := dut.CreateBoundSocket(typ, proto, net.ParseIP(*remoteIPv4))
+ fd, remotePort := dut.CreateBoundSocket(typ, proto, net.ParseIP(RemoteIPv4))
dut.Listen(fd, backlog)
return fd, remotePort
}
-// All the functions that make gRPC calls to the Posix service are below, sorted
+// All the functions that make gRPC calls to the POSIX service are below, sorted
// alphabetically.
// Accept calls accept on the DUT and causes a fatal test failure if it doesn't
@@ -162,7 +152,7 @@ func (dut *DUT) CreateListener(typ, proto, backlog int32) (int32, uint16) {
// AcceptWithErrno.
func (dut *DUT) Accept(sockfd int32) (int32, unix.Sockaddr) {
dut.t.Helper()
- ctx, cancel := context.WithTimeout(context.Background(), *rpcTimeout)
+ ctx, cancel := context.WithTimeout(context.Background(), RPCTimeout)
defer cancel()
fd, sa, err := dut.AcceptWithErrno(ctx, sockfd)
if fd < 0 {
@@ -189,7 +179,7 @@ func (dut *DUT) AcceptWithErrno(ctx context.Context, sockfd int32) (int32, unix.
// needed, use BindWithErrno.
func (dut *DUT) Bind(fd int32, sa unix.Sockaddr) {
dut.t.Helper()
- ctx, cancel := context.WithTimeout(context.Background(), *rpcTimeout)
+ ctx, cancel := context.WithTimeout(context.Background(), RPCTimeout)
defer cancel()
ret, err := dut.BindWithErrno(ctx, fd, sa)
if ret != 0 {
@@ -216,7 +206,7 @@ func (dut *DUT) BindWithErrno(ctx context.Context, fd int32, sa unix.Sockaddr) (
// CloseWithErrno.
func (dut *DUT) Close(fd int32) {
dut.t.Helper()
- ctx, cancel := context.WithTimeout(context.Background(), *rpcTimeout)
+ ctx, cancel := context.WithTimeout(context.Background(), RPCTimeout)
defer cancel()
ret, err := dut.CloseWithErrno(ctx, fd)
if ret != 0 {
@@ -242,7 +232,7 @@ func (dut *DUT) CloseWithErrno(ctx context.Context, fd int32) (int32, error) {
// needed, use ConnectWithErrno.
func (dut *DUT) Connect(fd int32, sa unix.Sockaddr) {
dut.t.Helper()
- ctx, cancel := context.WithTimeout(context.Background(), *rpcTimeout)
+ ctx, cancel := context.WithTimeout(context.Background(), RPCTimeout)
defer cancel()
ret, err := dut.ConnectWithErrno(ctx, fd, sa)
if ret != 0 {
@@ -269,7 +259,7 @@ func (dut *DUT) ConnectWithErrno(ctx context.Context, fd int32, sa unix.Sockaddr
// needed, use GetSockNameWithErrno.
func (dut *DUT) GetSockName(sockfd int32) unix.Sockaddr {
dut.t.Helper()
- ctx, cancel := context.WithTimeout(context.Background(), *rpcTimeout)
+ ctx, cancel := context.WithTimeout(context.Background(), RPCTimeout)
defer cancel()
ret, sa, err := dut.GetSockNameWithErrno(ctx, sockfd)
if ret != 0 {
@@ -318,7 +308,7 @@ func (dut *DUT) getSockOpt(ctx context.Context, sockfd, level, optname, optlen i
// more specific GetSockOptXxx function.
func (dut *DUT) GetSockOpt(sockfd, level, optname, optlen int32) []byte {
dut.t.Helper()
- ctx, cancel := context.WithTimeout(context.Background(), *rpcTimeout)
+ ctx, cancel := context.WithTimeout(context.Background(), RPCTimeout)
defer cancel()
ret, optval, err := dut.GetSockOptWithErrno(ctx, sockfd, level, optname, optlen)
if ret != 0 {
@@ -345,7 +335,7 @@ func (dut *DUT) GetSockOptWithErrno(ctx context.Context, sockfd, level, optname,
// is needed, use GetSockOptIntWithErrno.
func (dut *DUT) GetSockOptInt(sockfd, level, optname int32) int32 {
dut.t.Helper()
- ctx, cancel := context.WithTimeout(context.Background(), *rpcTimeout)
+ ctx, cancel := context.WithTimeout(context.Background(), RPCTimeout)
defer cancel()
ret, intval, err := dut.GetSockOptIntWithErrno(ctx, sockfd, level, optname)
if ret != 0 {
@@ -370,7 +360,7 @@ func (dut *DUT) GetSockOptIntWithErrno(ctx context.Context, sockfd, level, optna
// needed, use GetSockOptTimevalWithErrno.
func (dut *DUT) GetSockOptTimeval(sockfd, level, optname int32) unix.Timeval {
dut.t.Helper()
- ctx, cancel := context.WithTimeout(context.Background(), *rpcTimeout)
+ ctx, cancel := context.WithTimeout(context.Background(), RPCTimeout)
defer cancel()
ret, timeval, err := dut.GetSockOptTimevalWithErrno(ctx, sockfd, level, optname)
if ret != 0 {
@@ -399,7 +389,7 @@ func (dut *DUT) GetSockOptTimevalWithErrno(ctx context.Context, sockfd, level, o
// ListenWithErrno.
func (dut *DUT) Listen(sockfd, backlog int32) {
dut.t.Helper()
- ctx, cancel := context.WithTimeout(context.Background(), *rpcTimeout)
+ ctx, cancel := context.WithTimeout(context.Background(), RPCTimeout)
defer cancel()
ret, err := dut.ListenWithErrno(ctx, sockfd, backlog)
if ret != 0 {
@@ -426,7 +416,7 @@ func (dut *DUT) ListenWithErrno(ctx context.Context, sockfd, backlog int32) (int
// SendWithErrno.
func (dut *DUT) Send(sockfd int32, buf []byte, flags int32) int32 {
dut.t.Helper()
- ctx, cancel := context.WithTimeout(context.Background(), *rpcTimeout)
+ ctx, cancel := context.WithTimeout(context.Background(), RPCTimeout)
defer cancel()
ret, err := dut.SendWithErrno(ctx, sockfd, buf, flags)
if ret == -1 {
@@ -455,7 +445,7 @@ func (dut *DUT) SendWithErrno(ctx context.Context, sockfd int32, buf []byte, fla
// SendToWithErrno.
func (dut *DUT) SendTo(sockfd int32, buf []byte, flags int32, destAddr unix.Sockaddr) int32 {
dut.t.Helper()
- ctx, cancel := context.WithTimeout(context.Background(), *rpcTimeout)
+ ctx, cancel := context.WithTimeout(context.Background(), RPCTimeout)
defer cancel()
ret, err := dut.SendToWithErrno(ctx, sockfd, buf, flags, destAddr)
if ret == -1 {
@@ -502,7 +492,7 @@ func (dut *DUT) setSockOpt(ctx context.Context, sockfd, level, optname int32, op
// more specific SetSockOptXxx function.
func (dut *DUT) SetSockOpt(sockfd, level, optname int32, optval []byte) {
dut.t.Helper()
- ctx, cancel := context.WithTimeout(context.Background(), *rpcTimeout)
+ ctx, cancel := context.WithTimeout(context.Background(), RPCTimeout)
defer cancel()
ret, err := dut.SetSockOptWithErrno(ctx, sockfd, level, optname, optval)
if ret != 0 {
@@ -523,7 +513,7 @@ func (dut *DUT) SetSockOptWithErrno(ctx context.Context, sockfd, level, optname
// is needed, use SetSockOptIntWithErrno.
func (dut *DUT) SetSockOptInt(sockfd, level, optname, optval int32) {
dut.t.Helper()
- ctx, cancel := context.WithTimeout(context.Background(), *rpcTimeout)
+ ctx, cancel := context.WithTimeout(context.Background(), RPCTimeout)
defer cancel()
ret, err := dut.SetSockOptIntWithErrno(ctx, sockfd, level, optname, optval)
if ret != 0 {
@@ -542,7 +532,7 @@ func (dut *DUT) SetSockOptIntWithErrno(ctx context.Context, sockfd, level, optna
// needed, use SetSockOptTimevalWithErrno.
func (dut *DUT) SetSockOptTimeval(sockfd, level, optname int32, tv *unix.Timeval) {
dut.t.Helper()
- ctx, cancel := context.WithTimeout(context.Background(), *rpcTimeout)
+ ctx, cancel := context.WithTimeout(context.Background(), RPCTimeout)
defer cancel()
ret, err := dut.SetSockOptTimevalWithErrno(ctx, sockfd, level, optname, tv)
if ret != 0 {
@@ -593,7 +583,7 @@ func (dut *DUT) SocketWithErrno(domain, typ, proto int32) (int32, error) {
// RecvWithErrno.
func (dut *DUT) Recv(sockfd, len, flags int32) []byte {
dut.t.Helper()
- ctx, cancel := context.WithTimeout(context.Background(), *rpcTimeout)
+ ctx, cancel := context.WithTimeout(context.Background(), RPCTimeout)
defer cancel()
ret, buf, err := dut.RecvWithErrno(ctx, sockfd, len, flags)
if ret == -1 {
diff --git a/test/packetimpact/testbench/dut_client.go b/test/packetimpact/testbench/dut_client.go
index b130a33a2..d0e68c5da 100644
--- a/test/packetimpact/testbench/dut_client.go
+++ b/test/packetimpact/testbench/dut_client.go
@@ -20,9 +20,9 @@ import (
)
// PosixClient is a gRPC client for the Posix service.
-type PosixClient pb.PosixClient
+type POSIXClient pb.PosixClient
-// NewPosixClient makes a new gRPC client for the Posix service.
-func NewPosixClient(c grpc.ClientConnInterface) PosixClient {
+// NewPOSIXClient makes a new gRPC client for the POSIX service.
+func NewPOSIXClient(c grpc.ClientConnInterface) POSIXClient {
return pb.NewPosixClient(c)
}
diff --git a/test/packetimpact/testbench/layers.go b/test/packetimpact/testbench/layers.go
index 165f62d3b..49370377d 100644
--- a/test/packetimpact/testbench/layers.go
+++ b/test/packetimpact/testbench/layers.go
@@ -898,10 +898,7 @@ func (l *UDP) match(other Layer) bool {
}
func (l *UDP) length() int {
- if l.Length == nil {
- return header.UDPMinimumSize
- }
- return int(*l.Length)
+ return header.UDPMinimumSize
}
// merge implements Layer.merge.
diff --git a/test/packetimpact/testbench/rawsockets.go b/test/packetimpact/testbench/rawsockets.go
index ff722d4a6..4665f60b2 100644
--- a/test/packetimpact/testbench/rawsockets.go
+++ b/test/packetimpact/testbench/rawsockets.go
@@ -27,8 +27,6 @@ import (
"gvisor.dev/gvisor/pkg/usermem"
)
-var device = flag.String("device", "", "local device for test packets")
-
// Sniffer can sniff raw packets on the wire.
type Sniffer struct {
t *testing.T
@@ -139,7 +137,7 @@ type Injector struct {
// NewInjector creates a new injector on *device.
func NewInjector(t *testing.T) (Injector, error) {
flag.Parse()
- ifInfo, err := net.InterfaceByName(*device)
+ ifInfo, err := net.InterfaceByName(Device)
if err != nil {
return Injector{}, err
}
@@ -169,7 +167,7 @@ func NewInjector(t *testing.T) (Injector, error) {
// Send a raw frame.
func (i *Injector) Send(b []byte) {
if _, err := unix.Write(i.fd, b); err != nil {
- i.t.Fatalf("can't write: %s", err)
+ i.t.Fatalf("can't write: %s of len %d", err, len(b))
}
}
diff --git a/test/packetimpact/testbench/testbench.go b/test/packetimpact/testbench/testbench.go
new file mode 100644
index 000000000..a1242b189
--- /dev/null
+++ b/test/packetimpact/testbench/testbench.go
@@ -0,0 +1,63 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package testbench
+
+import (
+ "flag"
+ "time"
+)
+
+var (
+ // Device is the local device on the test network.
+ Device = ""
+ // LocalIPv4 is the local IPv4 address on the test network.
+ LocalIPv4 = ""
+ // LocalIPv6 is the local IPv6 address on the test network.
+ LocalIPv6 = ""
+ // LocalMAC is the local MAC address on the test network.
+ LocalMAC = ""
+ // POSIXServerIP is the POSIX server's IP address on the control network.
+ POSIXServerIP = ""
+ // POSIXServerPort is the UDP port the POSIX server is bound to on the
+ // control network.
+ POSIXServerPort = 40000
+ // RemoteIPv4 is the DUT's IPv4 address on the test network.
+ RemoteIPv4 = ""
+ // RemoteIPv6 is the DUT's IPv6 address on the test network.
+ RemoteIPv6 = ""
+ // RemoteMAC is the DUT's MAC address on the test network.
+ RemoteMAC = ""
+ // RPCKeepalive is the gRPC keepalive.
+ RPCKeepalive = 10 * time.Second
+ // RPCTimeout is the gRPC timeout.
+ RPCTimeout = 100 * time.Millisecond
+)
+
+// RegisterFlags defines flags and associates them with the package-level
+// exported variables above. It should be called by tests in their init
+// functions.
+func RegisterFlags(fs *flag.FlagSet) {
+ fs.StringVar(&POSIXServerIP, "posix_server_ip", POSIXServerIP, "ip address to listen to for UDP commands")
+ fs.IntVar(&POSIXServerPort, "posix_server_port", POSIXServerPort, "port to listen to for UDP commands")
+ fs.DurationVar(&RPCTimeout, "rpc_timeout", RPCTimeout, "gRPC timeout")
+ fs.DurationVar(&RPCKeepalive, "rpc_keepalive", RPCKeepalive, "gRPC keepalive")
+ fs.StringVar(&LocalIPv4, "local_ipv4", LocalIPv4, "local IPv4 address for test packets")
+ fs.StringVar(&RemoteIPv4, "remote_ipv4", RemoteIPv4, "remote IPv4 address for test packets")
+ fs.StringVar(&LocalIPv6, "local_ipv6", LocalIPv6, "local IPv6 address for test packets")
+ fs.StringVar(&RemoteIPv6, "remote_ipv6", RemoteIPv6, "remote IPv6 address for test packets")
+ fs.StringVar(&LocalMAC, "local_mac", LocalMAC, "local mac address for test packets")
+ fs.StringVar(&RemoteMAC, "remote_mac", RemoteMAC, "remote mac address for test packets")
+ fs.StringVar(&Device, "device", Device, "local device for test packets")
+}
diff --git a/test/packetimpact/tests/BUILD b/test/packetimpact/tests/BUILD
index e4ced444b..c25b3b8c1 100644
--- a/test/packetimpact/tests/BUILD
+++ b/test/packetimpact/tests/BUILD
@@ -117,8 +117,6 @@ packetimpact_go_test(
packetimpact_go_test(
name = "tcp_close_wait_ack",
srcs = ["tcp_close_wait_ack_test.go"],
- # TODO(b/153574037): Fix netstack then remove the line below.
- netstack = False,
deps = [
"//pkg/tcpip/header",
"//pkg/tcpip/seqnum",
@@ -150,6 +148,15 @@ packetimpact_go_test(
],
)
+packetimpact_go_test(
+ name = "udp_send_recv_dgram",
+ srcs = ["udp_send_recv_dgram_test.go"],
+ deps = [
+ "//test/packetimpact/testbench",
+ "@org_golang_x_sys//unix:go_default_library",
+ ],
+)
+
sh_binary(
name = "test_runner",
srcs = ["test_runner.sh"],
diff --git a/test/packetimpact/tests/fin_wait2_timeout_test.go b/test/packetimpact/tests/fin_wait2_timeout_test.go
index 99dc77f9a..c26ab78d9 100644
--- a/test/packetimpact/tests/fin_wait2_timeout_test.go
+++ b/test/packetimpact/tests/fin_wait2_timeout_test.go
@@ -15,6 +15,7 @@
package fin_wait2_timeout_test
import (
+ "flag"
"testing"
"time"
@@ -23,6 +24,10 @@ import (
tb "gvisor.dev/gvisor/test/packetimpact/testbench"
)
+func init() {
+ tb.RegisterFlags(flag.CommandLine)
+}
+
func TestFinWait2Timeout(t *testing.T) {
for _, tt := range []struct {
description string
diff --git a/test/packetimpact/tests/icmpv6_param_problem_test.go b/test/packetimpact/tests/icmpv6_param_problem_test.go
index b48e55df4..bb1fc26fc 100644
--- a/test/packetimpact/tests/icmpv6_param_problem_test.go
+++ b/test/packetimpact/tests/icmpv6_param_problem_test.go
@@ -16,6 +16,7 @@ package icmpv6_param_problem_test
import (
"encoding/binary"
+ "flag"
"testing"
"time"
@@ -23,6 +24,10 @@ import (
tb "gvisor.dev/gvisor/test/packetimpact/testbench"
)
+func init() {
+ tb.RegisterFlags(flag.CommandLine)
+}
+
// TestICMPv6ParamProblemTest sends a packet with a bad next header. The DUT
// should respond with an ICMPv6 Parameter Problem message.
func TestICMPv6ParamProblemTest(t *testing.T) {
diff --git a/test/packetimpact/tests/tcp_close_wait_ack_test.go b/test/packetimpact/tests/tcp_close_wait_ack_test.go
index 153ce285b..70a22a2db 100644
--- a/test/packetimpact/tests/tcp_close_wait_ack_test.go
+++ b/test/packetimpact/tests/tcp_close_wait_ack_test.go
@@ -15,6 +15,7 @@
package tcp_close_wait_ack_test
import (
+ "flag"
"fmt"
"testing"
"time"
@@ -25,6 +26,10 @@ import (
tb "gvisor.dev/gvisor/test/packetimpact/testbench"
)
+func init() {
+ tb.RegisterFlags(flag.CommandLine)
+}
+
func TestCloseWaitAck(t *testing.T) {
for _, tt := range []struct {
description string
diff --git a/test/packetimpact/tests/tcp_noaccept_close_rst_test.go b/test/packetimpact/tests/tcp_noaccept_close_rst_test.go
index 7ebdd1950..2c1ec27d3 100644
--- a/test/packetimpact/tests/tcp_noaccept_close_rst_test.go
+++ b/test/packetimpact/tests/tcp_noaccept_close_rst_test.go
@@ -15,6 +15,7 @@
package tcp_noaccept_close_rst_test
import (
+ "flag"
"testing"
"time"
@@ -23,6 +24,10 @@ import (
tb "gvisor.dev/gvisor/test/packetimpact/testbench"
)
+func init() {
+ tb.RegisterFlags(flag.CommandLine)
+}
+
func TestTcpNoAcceptCloseReset(t *testing.T) {
dut := tb.NewDUT(t)
defer dut.TearDown()
diff --git a/test/packetimpact/tests/tcp_outside_the_window_test.go b/test/packetimpact/tests/tcp_outside_the_window_test.go
index db3d3273b..351df193e 100644
--- a/test/packetimpact/tests/tcp_outside_the_window_test.go
+++ b/test/packetimpact/tests/tcp_outside_the_window_test.go
@@ -15,6 +15,7 @@
package tcp_outside_the_window_test
import (
+ "flag"
"fmt"
"testing"
"time"
@@ -25,6 +26,10 @@ import (
tb "gvisor.dev/gvisor/test/packetimpact/testbench"
)
+func init() {
+ tb.RegisterFlags(flag.CommandLine)
+}
+
// TestTCPOutsideTheWindows tests the behavior of the DUT when packets arrive
// that are inside or outside the TCP window. Packets that are outside the
// window should force an extra ACK, as described in RFC793 page 69:
diff --git a/test/packetimpact/tests/tcp_should_piggyback_test.go b/test/packetimpact/tests/tcp_should_piggyback_test.go
index b0be6ba23..0240dc2f9 100644
--- a/test/packetimpact/tests/tcp_should_piggyback_test.go
+++ b/test/packetimpact/tests/tcp_should_piggyback_test.go
@@ -15,6 +15,7 @@
package tcp_should_piggyback_test
import (
+ "flag"
"testing"
"time"
@@ -23,6 +24,10 @@ import (
tb "gvisor.dev/gvisor/test/packetimpact/testbench"
)
+func init() {
+ tb.RegisterFlags(flag.CommandLine)
+}
+
func TestPiggyback(t *testing.T) {
dut := tb.NewDUT(t)
defer dut.TearDown()
diff --git a/test/packetimpact/tests/tcp_user_timeout_test.go b/test/packetimpact/tests/tcp_user_timeout_test.go
index 3cf82badb..ce31917ee 100644
--- a/test/packetimpact/tests/tcp_user_timeout_test.go
+++ b/test/packetimpact/tests/tcp_user_timeout_test.go
@@ -15,6 +15,7 @@
package tcp_user_timeout_test
import (
+ "flag"
"fmt"
"testing"
"time"
@@ -24,6 +25,10 @@ import (
tb "gvisor.dev/gvisor/test/packetimpact/testbench"
)
+func init() {
+ tb.RegisterFlags(flag.CommandLine)
+}
+
func sendPayload(conn *tb.TCPIPv4, dut *tb.DUT, fd int32) error {
sampleData := make([]byte, 100)
for i := range sampleData {
diff --git a/test/packetimpact/tests/tcp_window_shrink_test.go b/test/packetimpact/tests/tcp_window_shrink_test.go
index c9354074e..58ec1d740 100644
--- a/test/packetimpact/tests/tcp_window_shrink_test.go
+++ b/test/packetimpact/tests/tcp_window_shrink_test.go
@@ -15,6 +15,7 @@
package tcp_window_shrink_test
import (
+ "flag"
"testing"
"time"
@@ -23,6 +24,10 @@ import (
tb "gvisor.dev/gvisor/test/packetimpact/testbench"
)
+func init() {
+ tb.RegisterFlags(flag.CommandLine)
+}
+
func TestWindowShrink(t *testing.T) {
dut := tb.NewDUT(t)
defer dut.TearDown()
diff --git a/test/packetimpact/tests/tcp_zero_window_probe_retransmit_test.go b/test/packetimpact/tests/tcp_zero_window_probe_retransmit_test.go
index 864e5a634..dd43a24db 100644
--- a/test/packetimpact/tests/tcp_zero_window_probe_retransmit_test.go
+++ b/test/packetimpact/tests/tcp_zero_window_probe_retransmit_test.go
@@ -15,6 +15,7 @@
package tcp_zero_window_probe_retransmit_test
import (
+ "flag"
"testing"
"time"
@@ -23,6 +24,10 @@ import (
tb "gvisor.dev/gvisor/test/packetimpact/testbench"
)
+func init() {
+ tb.RegisterFlags(flag.CommandLine)
+}
+
// TestZeroWindowProbeRetransmit tests retransmits of zero window probes
// to be sent at exponentially inreasing time intervals.
func TestZeroWindowProbeRetransmit(t *testing.T) {
diff --git a/test/packetimpact/tests/tcp_zero_window_probe_test.go b/test/packetimpact/tests/tcp_zero_window_probe_test.go
index 4fa3d0cd4..6c453505d 100644
--- a/test/packetimpact/tests/tcp_zero_window_probe_test.go
+++ b/test/packetimpact/tests/tcp_zero_window_probe_test.go
@@ -15,6 +15,7 @@
package tcp_zero_window_probe_test
import (
+ "flag"
"testing"
"time"
@@ -23,6 +24,10 @@ import (
tb "gvisor.dev/gvisor/test/packetimpact/testbench"
)
+func init() {
+ tb.RegisterFlags(flag.CommandLine)
+}
+
// TestZeroWindowProbe tests few cases of zero window probing over the
// same connection.
func TestZeroWindowProbe(t *testing.T) {
diff --git a/test/packetimpact/tests/tcp_zero_window_probe_usertimeout_test.go b/test/packetimpact/tests/tcp_zero_window_probe_usertimeout_test.go
index 7d81c276c..193427fb9 100644
--- a/test/packetimpact/tests/tcp_zero_window_probe_usertimeout_test.go
+++ b/test/packetimpact/tests/tcp_zero_window_probe_usertimeout_test.go
@@ -15,6 +15,7 @@
package tcp_zero_window_probe_usertimeout_test
import (
+ "flag"
"testing"
"time"
@@ -23,6 +24,10 @@ import (
tb "gvisor.dev/gvisor/test/packetimpact/testbench"
)
+func init() {
+ tb.RegisterFlags(flag.CommandLine)
+}
+
// TestZeroWindowProbeUserTimeout sanity tests user timeout when we are
// retransmitting zero window probes.
func TestZeroWindowProbeUserTimeout(t *testing.T) {
diff --git a/test/packetimpact/tests/udp_icmp_error_propagation_test.go b/test/packetimpact/tests/udp_icmp_error_propagation_test.go
index c47af9a3e..ca4df2ab0 100644
--- a/test/packetimpact/tests/udp_icmp_error_propagation_test.go
+++ b/test/packetimpact/tests/udp_icmp_error_propagation_test.go
@@ -16,6 +16,7 @@ package udp_icmp_error_propagation_test
import (
"context"
+ "flag"
"fmt"
"net"
"sync"
@@ -28,6 +29,10 @@ import (
tb "gvisor.dev/gvisor/test/packetimpact/testbench"
)
+func init() {
+ tb.RegisterFlags(flag.CommandLine)
+}
+
type connectionMode bool
func (c connectionMode) String() string {
@@ -304,16 +309,20 @@ func TestICMPErrorDuringUDPRecv(t *testing.T) {
var wg sync.WaitGroup
wg.Add(2)
go func() {
+ defer wg.Done()
+
if wantErrno != syscall.Errno(0) {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
ret, _, err := dut.RecvWithErrno(ctx, remoteFD, 100, 0)
if ret != -1 {
- t.Fatalf("recv during ICMP error succeeded unexpectedly, expected (%[1]d) %[1]v", wantErrno)
+ t.Errorf("recv during ICMP error succeeded unexpectedly, expected (%[1]d) %[1]v", wantErrno)
+ return
}
if err != wantErrno {
- t.Fatalf("recv during ICMP error resulted in error (%[1]d) %[1]v, expected (%[2]d) %[2]v", err, wantErrno)
+ t.Errorf("recv during ICMP error resulted in error (%[1]d) %[1]v, expected (%[2]d) %[2]v", err, wantErrno)
+ return
}
}
@@ -321,19 +330,19 @@ func TestICMPErrorDuringUDPRecv(t *testing.T) {
defer cancel()
if ret, _, err := dut.RecvWithErrno(ctx, remoteFD, 100, 0); ret == -1 {
- t.Fatalf("recv after ICMP error failed with (%[1]d) %[1]", err)
+ t.Errorf("recv after ICMP error failed with (%[1]d) %[1]", err)
}
- wg.Done()
}()
go func() {
+ defer wg.Done()
+
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
if ret, _, err := dut.RecvWithErrno(ctx, cleanFD, 100, 0); ret == -1 {
- t.Fatalf("recv on clean socket failed with (%[1]d) %[1]", err)
+ t.Errorf("recv on clean socket failed with (%[1]d) %[1]", err)
}
- wg.Done()
}()
// TODO(b/155684889) This sleep is to allow time for the DUT to
diff --git a/test/packetimpact/tests/udp_recv_multicast_test.go b/test/packetimpact/tests/udp_recv_multicast_test.go
index 61fd17050..0bae18ba3 100644
--- a/test/packetimpact/tests/udp_recv_multicast_test.go
+++ b/test/packetimpact/tests/udp_recv_multicast_test.go
@@ -15,6 +15,7 @@
package udp_recv_multicast_test
import (
+ "flag"
"net"
"testing"
@@ -23,6 +24,10 @@ import (
tb "gvisor.dev/gvisor/test/packetimpact/testbench"
)
+func init() {
+ tb.RegisterFlags(flag.CommandLine)
+}
+
func TestUDPRecvMulticast(t *testing.T) {
dut := tb.NewDUT(t)
defer dut.TearDown()
diff --git a/test/packetimpact/tests/udp_send_recv_dgram_test.go b/test/packetimpact/tests/udp_send_recv_dgram_test.go
new file mode 100644
index 000000000..350875a6f
--- /dev/null
+++ b/test/packetimpact/tests/udp_send_recv_dgram_test.go
@@ -0,0 +1,101 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package udp_send_recv_dgram_test
+
+import (
+ "flag"
+ "math/rand"
+ "net"
+ "testing"
+ "time"
+
+ "golang.org/x/sys/unix"
+ tb "gvisor.dev/gvisor/test/packetimpact/testbench"
+)
+
+func init() {
+ tb.RegisterFlags(flag.CommandLine)
+}
+
+func generateRandomPayload(t *testing.T, n int) string {
+ t.Helper()
+ buf := make([]byte, n)
+ if _, err := rand.Read(buf); err != nil {
+ t.Fatalf("rand.Read(buf) failed: %s", err)
+ }
+ return string(buf)
+}
+
+func TestUDPRecv(t *testing.T) {
+ dut := tb.NewDUT(t)
+ defer dut.TearDown()
+ boundFD, remotePort := dut.CreateBoundSocket(unix.SOCK_DGRAM, unix.IPPROTO_UDP, net.ParseIP("0.0.0.0"))
+ defer dut.Close(boundFD)
+ conn := tb.NewUDPIPv4(t, tb.UDP{DstPort: &remotePort}, tb.UDP{SrcPort: &remotePort})
+ defer conn.Close()
+
+ testCases := []struct {
+ name string
+ payload string
+ }{
+ {"emptypayload", ""},
+ {"small payload", "hello world"},
+ {"1kPayload", generateRandomPayload(t, 1<<10)},
+ // Even though UDP allows larger dgrams we don't test it here as
+ // they need to be fragmented and written out as individual
+ // frames.
+ }
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ frame := conn.CreateFrame(&tb.UDP{}, &tb.Payload{Bytes: []byte(tc.payload)})
+ conn.SendFrame(frame)
+ if got, want := string(dut.Recv(boundFD, int32(len(tc.payload)), 0)), tc.payload; got != want {
+ t.Fatalf("received payload does not match sent payload got: %s, want: %s", got, want)
+ }
+ })
+ }
+}
+
+func TestUDPSend(t *testing.T) {
+ dut := tb.NewDUT(t)
+ defer dut.TearDown()
+ boundFD, remotePort := dut.CreateBoundSocket(unix.SOCK_DGRAM, unix.IPPROTO_UDP, net.ParseIP("0.0.0.0"))
+ defer dut.Close(boundFD)
+ conn := tb.NewUDPIPv4(t, tb.UDP{DstPort: &remotePort}, tb.UDP{SrcPort: &remotePort})
+ defer conn.Close()
+
+ testCases := []struct {
+ name string
+ payload string
+ }{
+ {"emptypayload", ""},
+ {"small payload", "hello world"},
+ {"1kPayload", generateRandomPayload(t, 1<<10)},
+ // Even though UDP allows larger dgrams we don't test it here as
+ // they need to be fragmented and written out as individual
+ // frames.
+ }
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ conn.Drain()
+ if got, want := int(dut.SendTo(boundFD, []byte(tc.payload), 0, conn.LocalAddr())), len(tc.payload); got != want {
+ t.Fatalf("short write got: %d, want: %d", got, want)
+ }
+ if _, err := conn.ExpectData(tb.UDP{SrcPort: &remotePort}, tb.Payload{Bytes: []byte(tc.payload)}, 1*time.Second); err != nil {
+ t.Fatal(err)
+ }
+ })
+ }
+}
diff --git a/test/syscalls/linux/BUILD b/test/syscalls/linux/BUILD
index 837e56042..adf259bba 100644
--- a/test/syscalls/linux/BUILD
+++ b/test/syscalls/linux/BUILD
@@ -2620,6 +2620,7 @@ cc_binary(
":socket_bind_to_device_util",
":socket_test_util",
"//test/util:capability_util",
+ "@com_google_absl//absl/container:node_hash_map",
gtest,
"//test/util:test_main",
"//test/util:test_util",
diff --git a/test/syscalls/linux/socket_bind_to_device_sequence.cc b/test/syscalls/linux/socket_bind_to_device_sequence.cc
index 637d1151a..1967329ee 100644
--- a/test/syscalls/linux/socket_bind_to_device_sequence.cc
+++ b/test/syscalls/linux/socket_bind_to_device_sequence.cc
@@ -33,6 +33,7 @@
#include "gmock/gmock.h"
#include "gtest/gtest.h"
+#include "absl/container/node_hash_map.h"
#include "test/syscalls/linux/ip_socket_test_util.h"
#include "test/syscalls/linux/socket_bind_to_device_util.h"
#include "test/syscalls/linux/socket_test_util.h"
@@ -192,8 +193,8 @@ class BindToDeviceSequenceTest : public ::testing::TestWithParam<SocketKind> {
in_port_t port_ = 0;
// sockets_to_close_ is a map from action index to the socket that was
// created.
- std::unordered_map<int,
- std::unique_ptr<gvisor::testing::FileDescriptor>>
+ absl::node_hash_map<int,
+ std::unique_ptr<gvisor::testing::FileDescriptor>>
sockets_to_close_;
int next_socket_id_ = 0;
};