diff options
Diffstat (limited to 'test')
22 files changed, 187 insertions, 77 deletions
diff --git a/test/packetimpact/testbench/BUILD b/test/packetimpact/testbench/BUILD index 3ceceb9d7..fed51006f 100644 --- a/test/packetimpact/testbench/BUILD +++ b/test/packetimpact/testbench/BUILD @@ -13,6 +13,7 @@ go_library( "dut_client.go", "layers.go", "rawsockets.go", + "testbench.go", ], deps = [ "//pkg/tcpip", diff --git a/test/packetimpact/testbench/connections.go b/test/packetimpact/testbench/connections.go index 28c841612..463fd0556 100644 --- a/test/packetimpact/testbench/connections.go +++ b/test/packetimpact/testbench/connections.go @@ -17,7 +17,6 @@ package testbench import ( - "flag" "fmt" "math/rand" "net" @@ -32,13 +31,6 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/seqnum" ) -var localIPv4 = flag.String("local_ipv4", "", "local IPv4 address for test packets") -var remoteIPv4 = flag.String("remote_ipv4", "", "remote IPv4 address for test packets") -var localIPv6 = flag.String("local_ipv6", "", "local IPv6 address for test packets") -var remoteIPv6 = flag.String("remote_ipv6", "", "remote IPv6 address for test packets") -var localMAC = flag.String("local_mac", "", "local mac address for test packets") -var remoteMAC = flag.String("remote_mac", "", "remote mac address for test packets") - func portFromSockaddr(sa unix.Sockaddr) (uint16, error) { switch sa := sa.(type) { case *unix.SockaddrInet4: @@ -64,11 +56,11 @@ func pickPort(domain, typ int) (fd int, sa unix.Sockaddr, err error) { switch domain { case unix.AF_INET: var sa4 unix.SockaddrInet4 - copy(sa4.Addr[:], net.ParseIP(*localIPv4).To4()) + copy(sa4.Addr[:], net.ParseIP(LocalIPv4).To4()) sa = &sa4 case unix.AF_INET6: var sa6 unix.SockaddrInet6 - copy(sa6.Addr[:], net.ParseIP(*localIPv6).To16()) + copy(sa6.Addr[:], net.ParseIP(LocalIPv6).To16()) sa = &sa6 default: return -1, nil, fmt.Errorf("invalid domain %d, it should be one of unix.AF_INET or unix.AF_INET6", domain) @@ -120,12 +112,12 @@ var _ layerState = (*etherState)(nil) // newEtherState creates a new etherState. func newEtherState(out, in Ether) (*etherState, error) { - lMAC, err := tcpip.ParseMACAddress(*localMAC) + lMAC, err := tcpip.ParseMACAddress(LocalMAC) if err != nil { return nil, err } - rMAC, err := tcpip.ParseMACAddress(*remoteMAC) + rMAC, err := tcpip.ParseMACAddress(RemoteMAC) if err != nil { return nil, err } @@ -172,8 +164,8 @@ var _ layerState = (*ipv4State)(nil) // newIPv4State creates a new ipv4State. func newIPv4State(out, in IPv4) (*ipv4State, error) { - lIP := tcpip.Address(net.ParseIP(*localIPv4).To4()) - rIP := tcpip.Address(net.ParseIP(*remoteIPv4).To4()) + lIP := tcpip.Address(net.ParseIP(LocalIPv4).To4()) + rIP := tcpip.Address(net.ParseIP(RemoteIPv4).To4()) s := ipv4State{ out: IPv4{SrcAddr: &lIP, DstAddr: &rIP}, in: IPv4{SrcAddr: &rIP, DstAddr: &lIP}, @@ -217,8 +209,8 @@ var _ layerState = (*ipv6State)(nil) // newIPv6State creates a new ipv6State. func newIPv6State(out, in IPv6) (*ipv6State, error) { - lIP := tcpip.Address(net.ParseIP(*localIPv6).To16()) - rIP := tcpip.Address(net.ParseIP(*remoteIPv6).To16()) + lIP := tcpip.Address(net.ParseIP(LocalIPv6).To16()) + rIP := tcpip.Address(net.ParseIP(RemoteIPv6).To16()) s := ipv6State{ out: IPv6{SrcAddr: &lIP, DstAddr: &rIP}, in: IPv6{SrcAddr: &rIP, DstAddr: &lIP}, diff --git a/test/packetimpact/testbench/dut.go b/test/packetimpact/testbench/dut.go index f68d9d62b..a78b7d7ee 100644 --- a/test/packetimpact/testbench/dut.go +++ b/test/packetimpact/testbench/dut.go @@ -16,12 +16,10 @@ package testbench import ( "context" - "flag" "net" "strconv" "syscall" "testing" - "time" pb "gvisor.dev/gvisor/test/packetimpact/proto/posix_server_go_proto" @@ -30,29 +28,21 @@ import ( "google.golang.org/grpc/keepalive" ) -var ( - posixServerIP = flag.String("posix_server_ip", "", "ip address to listen to for UDP commands") - posixServerPort = flag.Int("posix_server_port", 40000, "port to listen to for UDP commands") - rpcTimeout = flag.Duration("rpc_timeout", 100*time.Millisecond, "gRPC timeout") - rpcKeepalive = flag.Duration("rpc_keepalive", 10*time.Second, "gRPC keepalive") -) - // DUT communicates with the DUT to force it to make POSIX calls. type DUT struct { t *testing.T conn *grpc.ClientConn - posixServer PosixClient + posixServer POSIXClient } // NewDUT creates a new connection with the DUT over gRPC. func NewDUT(t *testing.T) DUT { - flag.Parse() - posixServerAddress := *posixServerIP + ":" + strconv.Itoa(*posixServerPort) - conn, err := grpc.Dial(posixServerAddress, grpc.WithInsecure(), grpc.WithKeepaliveParams(keepalive.ClientParameters{Timeout: *rpcKeepalive})) + posixServerAddress := POSIXServerIP + ":" + strconv.Itoa(POSIXServerPort) + conn, err := grpc.Dial(posixServerAddress, grpc.WithInsecure(), grpc.WithKeepaliveParams(keepalive.ClientParameters{Timeout: RPCKeepalive})) if err != nil { t.Fatalf("failed to grpc.Dial(%s): %s", posixServerAddress, err) } - posixServer := NewPosixClient(conn) + posixServer := NewPOSIXClient(conn) return DUT{ t: t, conn: conn, @@ -149,12 +139,12 @@ func (dut *DUT) CreateBoundSocket(typ, proto int32, addr net.IP) (int32, uint16) // CreateListener makes a new TCP connection. If it fails, the test ends. func (dut *DUT) CreateListener(typ, proto, backlog int32) (int32, uint16) { - fd, remotePort := dut.CreateBoundSocket(typ, proto, net.ParseIP(*remoteIPv4)) + fd, remotePort := dut.CreateBoundSocket(typ, proto, net.ParseIP(RemoteIPv4)) dut.Listen(fd, backlog) return fd, remotePort } -// All the functions that make gRPC calls to the Posix service are below, sorted +// All the functions that make gRPC calls to the POSIX service are below, sorted // alphabetically. // Accept calls accept on the DUT and causes a fatal test failure if it doesn't @@ -162,7 +152,7 @@ func (dut *DUT) CreateListener(typ, proto, backlog int32) (int32, uint16) { // AcceptWithErrno. func (dut *DUT) Accept(sockfd int32) (int32, unix.Sockaddr) { dut.t.Helper() - ctx, cancel := context.WithTimeout(context.Background(), *rpcTimeout) + ctx, cancel := context.WithTimeout(context.Background(), RPCTimeout) defer cancel() fd, sa, err := dut.AcceptWithErrno(ctx, sockfd) if fd < 0 { @@ -189,7 +179,7 @@ func (dut *DUT) AcceptWithErrno(ctx context.Context, sockfd int32) (int32, unix. // needed, use BindWithErrno. func (dut *DUT) Bind(fd int32, sa unix.Sockaddr) { dut.t.Helper() - ctx, cancel := context.WithTimeout(context.Background(), *rpcTimeout) + ctx, cancel := context.WithTimeout(context.Background(), RPCTimeout) defer cancel() ret, err := dut.BindWithErrno(ctx, fd, sa) if ret != 0 { @@ -216,7 +206,7 @@ func (dut *DUT) BindWithErrno(ctx context.Context, fd int32, sa unix.Sockaddr) ( // CloseWithErrno. func (dut *DUT) Close(fd int32) { dut.t.Helper() - ctx, cancel := context.WithTimeout(context.Background(), *rpcTimeout) + ctx, cancel := context.WithTimeout(context.Background(), RPCTimeout) defer cancel() ret, err := dut.CloseWithErrno(ctx, fd) if ret != 0 { @@ -242,7 +232,7 @@ func (dut *DUT) CloseWithErrno(ctx context.Context, fd int32) (int32, error) { // needed, use ConnectWithErrno. func (dut *DUT) Connect(fd int32, sa unix.Sockaddr) { dut.t.Helper() - ctx, cancel := context.WithTimeout(context.Background(), *rpcTimeout) + ctx, cancel := context.WithTimeout(context.Background(), RPCTimeout) defer cancel() ret, err := dut.ConnectWithErrno(ctx, fd, sa) if ret != 0 { @@ -269,7 +259,7 @@ func (dut *DUT) ConnectWithErrno(ctx context.Context, fd int32, sa unix.Sockaddr // needed, use GetSockNameWithErrno. func (dut *DUT) GetSockName(sockfd int32) unix.Sockaddr { dut.t.Helper() - ctx, cancel := context.WithTimeout(context.Background(), *rpcTimeout) + ctx, cancel := context.WithTimeout(context.Background(), RPCTimeout) defer cancel() ret, sa, err := dut.GetSockNameWithErrno(ctx, sockfd) if ret != 0 { @@ -318,7 +308,7 @@ func (dut *DUT) getSockOpt(ctx context.Context, sockfd, level, optname, optlen i // more specific GetSockOptXxx function. func (dut *DUT) GetSockOpt(sockfd, level, optname, optlen int32) []byte { dut.t.Helper() - ctx, cancel := context.WithTimeout(context.Background(), *rpcTimeout) + ctx, cancel := context.WithTimeout(context.Background(), RPCTimeout) defer cancel() ret, optval, err := dut.GetSockOptWithErrno(ctx, sockfd, level, optname, optlen) if ret != 0 { @@ -345,7 +335,7 @@ func (dut *DUT) GetSockOptWithErrno(ctx context.Context, sockfd, level, optname, // is needed, use GetSockOptIntWithErrno. func (dut *DUT) GetSockOptInt(sockfd, level, optname int32) int32 { dut.t.Helper() - ctx, cancel := context.WithTimeout(context.Background(), *rpcTimeout) + ctx, cancel := context.WithTimeout(context.Background(), RPCTimeout) defer cancel() ret, intval, err := dut.GetSockOptIntWithErrno(ctx, sockfd, level, optname) if ret != 0 { @@ -370,7 +360,7 @@ func (dut *DUT) GetSockOptIntWithErrno(ctx context.Context, sockfd, level, optna // needed, use GetSockOptTimevalWithErrno. func (dut *DUT) GetSockOptTimeval(sockfd, level, optname int32) unix.Timeval { dut.t.Helper() - ctx, cancel := context.WithTimeout(context.Background(), *rpcTimeout) + ctx, cancel := context.WithTimeout(context.Background(), RPCTimeout) defer cancel() ret, timeval, err := dut.GetSockOptTimevalWithErrno(ctx, sockfd, level, optname) if ret != 0 { @@ -399,7 +389,7 @@ func (dut *DUT) GetSockOptTimevalWithErrno(ctx context.Context, sockfd, level, o // ListenWithErrno. func (dut *DUT) Listen(sockfd, backlog int32) { dut.t.Helper() - ctx, cancel := context.WithTimeout(context.Background(), *rpcTimeout) + ctx, cancel := context.WithTimeout(context.Background(), RPCTimeout) defer cancel() ret, err := dut.ListenWithErrno(ctx, sockfd, backlog) if ret != 0 { @@ -426,7 +416,7 @@ func (dut *DUT) ListenWithErrno(ctx context.Context, sockfd, backlog int32) (int // SendWithErrno. func (dut *DUT) Send(sockfd int32, buf []byte, flags int32) int32 { dut.t.Helper() - ctx, cancel := context.WithTimeout(context.Background(), *rpcTimeout) + ctx, cancel := context.WithTimeout(context.Background(), RPCTimeout) defer cancel() ret, err := dut.SendWithErrno(ctx, sockfd, buf, flags) if ret == -1 { @@ -455,7 +445,7 @@ func (dut *DUT) SendWithErrno(ctx context.Context, sockfd int32, buf []byte, fla // SendToWithErrno. func (dut *DUT) SendTo(sockfd int32, buf []byte, flags int32, destAddr unix.Sockaddr) int32 { dut.t.Helper() - ctx, cancel := context.WithTimeout(context.Background(), *rpcTimeout) + ctx, cancel := context.WithTimeout(context.Background(), RPCTimeout) defer cancel() ret, err := dut.SendToWithErrno(ctx, sockfd, buf, flags, destAddr) if ret == -1 { @@ -502,7 +492,7 @@ func (dut *DUT) setSockOpt(ctx context.Context, sockfd, level, optname int32, op // more specific SetSockOptXxx function. func (dut *DUT) SetSockOpt(sockfd, level, optname int32, optval []byte) { dut.t.Helper() - ctx, cancel := context.WithTimeout(context.Background(), *rpcTimeout) + ctx, cancel := context.WithTimeout(context.Background(), RPCTimeout) defer cancel() ret, err := dut.SetSockOptWithErrno(ctx, sockfd, level, optname, optval) if ret != 0 { @@ -523,7 +513,7 @@ func (dut *DUT) SetSockOptWithErrno(ctx context.Context, sockfd, level, optname // is needed, use SetSockOptIntWithErrno. func (dut *DUT) SetSockOptInt(sockfd, level, optname, optval int32) { dut.t.Helper() - ctx, cancel := context.WithTimeout(context.Background(), *rpcTimeout) + ctx, cancel := context.WithTimeout(context.Background(), RPCTimeout) defer cancel() ret, err := dut.SetSockOptIntWithErrno(ctx, sockfd, level, optname, optval) if ret != 0 { @@ -542,7 +532,7 @@ func (dut *DUT) SetSockOptIntWithErrno(ctx context.Context, sockfd, level, optna // needed, use SetSockOptTimevalWithErrno. func (dut *DUT) SetSockOptTimeval(sockfd, level, optname int32, tv *unix.Timeval) { dut.t.Helper() - ctx, cancel := context.WithTimeout(context.Background(), *rpcTimeout) + ctx, cancel := context.WithTimeout(context.Background(), RPCTimeout) defer cancel() ret, err := dut.SetSockOptTimevalWithErrno(ctx, sockfd, level, optname, tv) if ret != 0 { @@ -593,7 +583,7 @@ func (dut *DUT) SocketWithErrno(domain, typ, proto int32) (int32, error) { // RecvWithErrno. func (dut *DUT) Recv(sockfd, len, flags int32) []byte { dut.t.Helper() - ctx, cancel := context.WithTimeout(context.Background(), *rpcTimeout) + ctx, cancel := context.WithTimeout(context.Background(), RPCTimeout) defer cancel() ret, buf, err := dut.RecvWithErrno(ctx, sockfd, len, flags) if ret == -1 { diff --git a/test/packetimpact/testbench/dut_client.go b/test/packetimpact/testbench/dut_client.go index b130a33a2..d0e68c5da 100644 --- a/test/packetimpact/testbench/dut_client.go +++ b/test/packetimpact/testbench/dut_client.go @@ -20,9 +20,9 @@ import ( ) // PosixClient is a gRPC client for the Posix service. -type PosixClient pb.PosixClient +type POSIXClient pb.PosixClient -// NewPosixClient makes a new gRPC client for the Posix service. -func NewPosixClient(c grpc.ClientConnInterface) PosixClient { +// NewPOSIXClient makes a new gRPC client for the POSIX service. +func NewPOSIXClient(c grpc.ClientConnInterface) POSIXClient { return pb.NewPosixClient(c) } diff --git a/test/packetimpact/testbench/rawsockets.go b/test/packetimpact/testbench/rawsockets.go index a9ad72b63..4665f60b2 100644 --- a/test/packetimpact/testbench/rawsockets.go +++ b/test/packetimpact/testbench/rawsockets.go @@ -27,8 +27,6 @@ import ( "gvisor.dev/gvisor/pkg/usermem" ) -var device = flag.String("device", "", "local device for test packets") - // Sniffer can sniff raw packets on the wire. type Sniffer struct { t *testing.T @@ -139,7 +137,7 @@ type Injector struct { // NewInjector creates a new injector on *device. func NewInjector(t *testing.T) (Injector, error) { flag.Parse() - ifInfo, err := net.InterfaceByName(*device) + ifInfo, err := net.InterfaceByName(Device) if err != nil { return Injector{}, err } diff --git a/test/packetimpact/testbench/testbench.go b/test/packetimpact/testbench/testbench.go new file mode 100644 index 000000000..a1242b189 --- /dev/null +++ b/test/packetimpact/testbench/testbench.go @@ -0,0 +1,63 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package testbench + +import ( + "flag" + "time" +) + +var ( + // Device is the local device on the test network. + Device = "" + // LocalIPv4 is the local IPv4 address on the test network. + LocalIPv4 = "" + // LocalIPv6 is the local IPv6 address on the test network. + LocalIPv6 = "" + // LocalMAC is the local MAC address on the test network. + LocalMAC = "" + // POSIXServerIP is the POSIX server's IP address on the control network. + POSIXServerIP = "" + // POSIXServerPort is the UDP port the POSIX server is bound to on the + // control network. + POSIXServerPort = 40000 + // RemoteIPv4 is the DUT's IPv4 address on the test network. + RemoteIPv4 = "" + // RemoteIPv6 is the DUT's IPv6 address on the test network. + RemoteIPv6 = "" + // RemoteMAC is the DUT's MAC address on the test network. + RemoteMAC = "" + // RPCKeepalive is the gRPC keepalive. + RPCKeepalive = 10 * time.Second + // RPCTimeout is the gRPC timeout. + RPCTimeout = 100 * time.Millisecond +) + +// RegisterFlags defines flags and associates them with the package-level +// exported variables above. It should be called by tests in their init +// functions. +func RegisterFlags(fs *flag.FlagSet) { + fs.StringVar(&POSIXServerIP, "posix_server_ip", POSIXServerIP, "ip address to listen to for UDP commands") + fs.IntVar(&POSIXServerPort, "posix_server_port", POSIXServerPort, "port to listen to for UDP commands") + fs.DurationVar(&RPCTimeout, "rpc_timeout", RPCTimeout, "gRPC timeout") + fs.DurationVar(&RPCKeepalive, "rpc_keepalive", RPCKeepalive, "gRPC keepalive") + fs.StringVar(&LocalIPv4, "local_ipv4", LocalIPv4, "local IPv4 address for test packets") + fs.StringVar(&RemoteIPv4, "remote_ipv4", RemoteIPv4, "remote IPv4 address for test packets") + fs.StringVar(&LocalIPv6, "local_ipv6", LocalIPv6, "local IPv6 address for test packets") + fs.StringVar(&RemoteIPv6, "remote_ipv6", RemoteIPv6, "remote IPv6 address for test packets") + fs.StringVar(&LocalMAC, "local_mac", LocalMAC, "local mac address for test packets") + fs.StringVar(&RemoteMAC, "remote_mac", RemoteMAC, "remote mac address for test packets") + fs.StringVar(&Device, "device", Device, "local device for test packets") +} diff --git a/test/packetimpact/tests/fin_wait2_timeout_test.go b/test/packetimpact/tests/fin_wait2_timeout_test.go index 99dc77f9a..c26ab78d9 100644 --- a/test/packetimpact/tests/fin_wait2_timeout_test.go +++ b/test/packetimpact/tests/fin_wait2_timeout_test.go @@ -15,6 +15,7 @@ package fin_wait2_timeout_test import ( + "flag" "testing" "time" @@ -23,6 +24,10 @@ import ( tb "gvisor.dev/gvisor/test/packetimpact/testbench" ) +func init() { + tb.RegisterFlags(flag.CommandLine) +} + func TestFinWait2Timeout(t *testing.T) { for _, tt := range []struct { description string diff --git a/test/packetimpact/tests/icmpv6_param_problem_test.go b/test/packetimpact/tests/icmpv6_param_problem_test.go index b48e55df4..bb1fc26fc 100644 --- a/test/packetimpact/tests/icmpv6_param_problem_test.go +++ b/test/packetimpact/tests/icmpv6_param_problem_test.go @@ -16,6 +16,7 @@ package icmpv6_param_problem_test import ( "encoding/binary" + "flag" "testing" "time" @@ -23,6 +24,10 @@ import ( tb "gvisor.dev/gvisor/test/packetimpact/testbench" ) +func init() { + tb.RegisterFlags(flag.CommandLine) +} + // TestICMPv6ParamProblemTest sends a packet with a bad next header. The DUT // should respond with an ICMPv6 Parameter Problem message. func TestICMPv6ParamProblemTest(t *testing.T) { diff --git a/test/packetimpact/tests/tcp_close_wait_ack_test.go b/test/packetimpact/tests/tcp_close_wait_ack_test.go index 153ce285b..70a22a2db 100644 --- a/test/packetimpact/tests/tcp_close_wait_ack_test.go +++ b/test/packetimpact/tests/tcp_close_wait_ack_test.go @@ -15,6 +15,7 @@ package tcp_close_wait_ack_test import ( + "flag" "fmt" "testing" "time" @@ -25,6 +26,10 @@ import ( tb "gvisor.dev/gvisor/test/packetimpact/testbench" ) +func init() { + tb.RegisterFlags(flag.CommandLine) +} + func TestCloseWaitAck(t *testing.T) { for _, tt := range []struct { description string diff --git a/test/packetimpact/tests/tcp_noaccept_close_rst_test.go b/test/packetimpact/tests/tcp_noaccept_close_rst_test.go index 7ebdd1950..2c1ec27d3 100644 --- a/test/packetimpact/tests/tcp_noaccept_close_rst_test.go +++ b/test/packetimpact/tests/tcp_noaccept_close_rst_test.go @@ -15,6 +15,7 @@ package tcp_noaccept_close_rst_test import ( + "flag" "testing" "time" @@ -23,6 +24,10 @@ import ( tb "gvisor.dev/gvisor/test/packetimpact/testbench" ) +func init() { + tb.RegisterFlags(flag.CommandLine) +} + func TestTcpNoAcceptCloseReset(t *testing.T) { dut := tb.NewDUT(t) defer dut.TearDown() diff --git a/test/packetimpact/tests/tcp_outside_the_window_test.go b/test/packetimpact/tests/tcp_outside_the_window_test.go index db3d3273b..351df193e 100644 --- a/test/packetimpact/tests/tcp_outside_the_window_test.go +++ b/test/packetimpact/tests/tcp_outside_the_window_test.go @@ -15,6 +15,7 @@ package tcp_outside_the_window_test import ( + "flag" "fmt" "testing" "time" @@ -25,6 +26,10 @@ import ( tb "gvisor.dev/gvisor/test/packetimpact/testbench" ) +func init() { + tb.RegisterFlags(flag.CommandLine) +} + // TestTCPOutsideTheWindows tests the behavior of the DUT when packets arrive // that are inside or outside the TCP window. Packets that are outside the // window should force an extra ACK, as described in RFC793 page 69: diff --git a/test/packetimpact/tests/tcp_should_piggyback_test.go b/test/packetimpact/tests/tcp_should_piggyback_test.go index b0be6ba23..0240dc2f9 100644 --- a/test/packetimpact/tests/tcp_should_piggyback_test.go +++ b/test/packetimpact/tests/tcp_should_piggyback_test.go @@ -15,6 +15,7 @@ package tcp_should_piggyback_test import ( + "flag" "testing" "time" @@ -23,6 +24,10 @@ import ( tb "gvisor.dev/gvisor/test/packetimpact/testbench" ) +func init() { + tb.RegisterFlags(flag.CommandLine) +} + func TestPiggyback(t *testing.T) { dut := tb.NewDUT(t) defer dut.TearDown() diff --git a/test/packetimpact/tests/tcp_user_timeout_test.go b/test/packetimpact/tests/tcp_user_timeout_test.go index 3cf82badb..ce31917ee 100644 --- a/test/packetimpact/tests/tcp_user_timeout_test.go +++ b/test/packetimpact/tests/tcp_user_timeout_test.go @@ -15,6 +15,7 @@ package tcp_user_timeout_test import ( + "flag" "fmt" "testing" "time" @@ -24,6 +25,10 @@ import ( tb "gvisor.dev/gvisor/test/packetimpact/testbench" ) +func init() { + tb.RegisterFlags(flag.CommandLine) +} + func sendPayload(conn *tb.TCPIPv4, dut *tb.DUT, fd int32) error { sampleData := make([]byte, 100) for i := range sampleData { diff --git a/test/packetimpact/tests/tcp_window_shrink_test.go b/test/packetimpact/tests/tcp_window_shrink_test.go index c9354074e..58ec1d740 100644 --- a/test/packetimpact/tests/tcp_window_shrink_test.go +++ b/test/packetimpact/tests/tcp_window_shrink_test.go @@ -15,6 +15,7 @@ package tcp_window_shrink_test import ( + "flag" "testing" "time" @@ -23,6 +24,10 @@ import ( tb "gvisor.dev/gvisor/test/packetimpact/testbench" ) +func init() { + tb.RegisterFlags(flag.CommandLine) +} + func TestWindowShrink(t *testing.T) { dut := tb.NewDUT(t) defer dut.TearDown() diff --git a/test/packetimpact/tests/tcp_zero_window_probe_retransmit_test.go b/test/packetimpact/tests/tcp_zero_window_probe_retransmit_test.go index 864e5a634..dd43a24db 100644 --- a/test/packetimpact/tests/tcp_zero_window_probe_retransmit_test.go +++ b/test/packetimpact/tests/tcp_zero_window_probe_retransmit_test.go @@ -15,6 +15,7 @@ package tcp_zero_window_probe_retransmit_test import ( + "flag" "testing" "time" @@ -23,6 +24,10 @@ import ( tb "gvisor.dev/gvisor/test/packetimpact/testbench" ) +func init() { + tb.RegisterFlags(flag.CommandLine) +} + // TestZeroWindowProbeRetransmit tests retransmits of zero window probes // to be sent at exponentially inreasing time intervals. func TestZeroWindowProbeRetransmit(t *testing.T) { diff --git a/test/packetimpact/tests/tcp_zero_window_probe_test.go b/test/packetimpact/tests/tcp_zero_window_probe_test.go index 4fa3d0cd4..6c453505d 100644 --- a/test/packetimpact/tests/tcp_zero_window_probe_test.go +++ b/test/packetimpact/tests/tcp_zero_window_probe_test.go @@ -15,6 +15,7 @@ package tcp_zero_window_probe_test import ( + "flag" "testing" "time" @@ -23,6 +24,10 @@ import ( tb "gvisor.dev/gvisor/test/packetimpact/testbench" ) +func init() { + tb.RegisterFlags(flag.CommandLine) +} + // TestZeroWindowProbe tests few cases of zero window probing over the // same connection. func TestZeroWindowProbe(t *testing.T) { diff --git a/test/packetimpact/tests/tcp_zero_window_probe_usertimeout_test.go b/test/packetimpact/tests/tcp_zero_window_probe_usertimeout_test.go index 7d81c276c..193427fb9 100644 --- a/test/packetimpact/tests/tcp_zero_window_probe_usertimeout_test.go +++ b/test/packetimpact/tests/tcp_zero_window_probe_usertimeout_test.go @@ -15,6 +15,7 @@ package tcp_zero_window_probe_usertimeout_test import ( + "flag" "testing" "time" @@ -23,6 +24,10 @@ import ( tb "gvisor.dev/gvisor/test/packetimpact/testbench" ) +func init() { + tb.RegisterFlags(flag.CommandLine) +} + // TestZeroWindowProbeUserTimeout sanity tests user timeout when we are // retransmitting zero window probes. func TestZeroWindowProbeUserTimeout(t *testing.T) { diff --git a/test/packetimpact/tests/udp_icmp_error_propagation_test.go b/test/packetimpact/tests/udp_icmp_error_propagation_test.go index 30dcb336e..ca4df2ab0 100644 --- a/test/packetimpact/tests/udp_icmp_error_propagation_test.go +++ b/test/packetimpact/tests/udp_icmp_error_propagation_test.go @@ -16,8 +16,10 @@ package udp_icmp_error_propagation_test import ( "context" + "flag" "fmt" "net" + "sync" "syscall" "testing" "time" @@ -27,6 +29,10 @@ import ( tb "gvisor.dev/gvisor/test/packetimpact/testbench" ) +func init() { + tb.RegisterFlags(flag.CommandLine) +} + type connectionMode bool func (c connectionMode) String() string { @@ -300,19 +306,22 @@ func TestICMPErrorDuringUDPRecv(t *testing.T) { t.Fatalf("did not receive message from DUT: %s", err) } - c := make(chan error) - go func(c chan error) { + var wg sync.WaitGroup + wg.Add(2) + go func() { + defer wg.Done() + if wantErrno != syscall.Errno(0) { ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() ret, _, err := dut.RecvWithErrno(ctx, remoteFD, 100, 0) if ret != -1 { - c <- fmt.Errorf("recv during ICMP error succeeded unexpectedly, expected (%[1]d) %[1]v", wantErrno) + t.Errorf("recv during ICMP error succeeded unexpectedly, expected (%[1]d) %[1]v", wantErrno) return } if err != wantErrno { - c <- fmt.Errorf("recv during ICMP error resulted in error (%[1]d) %[1]v, expected (%[2]d) %[2]v", err, wantErrno) + t.Errorf("recv during ICMP error resulted in error (%[1]d) %[1]v, expected (%[2]d) %[2]v", err, wantErrno) return } } @@ -321,23 +330,20 @@ func TestICMPErrorDuringUDPRecv(t *testing.T) { defer cancel() if ret, _, err := dut.RecvWithErrno(ctx, remoteFD, 100, 0); ret == -1 { - c <- fmt.Errorf("recv after ICMP error failed with (%[1]d) %[1]", err) - return + t.Errorf("recv after ICMP error failed with (%[1]d) %[1]", err) } - c <- nil - }(c) + }() + + go func() { + defer wg.Done() - cleanChan := make(chan error) - go func(c chan error) { ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() if ret, _, err := dut.RecvWithErrno(ctx, cleanFD, 100, 0); ret == -1 { - c <- fmt.Errorf("recv on clean socket failed with (%[1]d) %[1]", err) - return + t.Errorf("recv on clean socket failed with (%[1]d) %[1]", err) } - c <- nil - }(cleanChan) + }() // TODO(b/155684889) This sleep is to allow time for the DUT to // actually call recv since we want the ICMP error to arrive during the @@ -351,14 +357,7 @@ func TestICMPErrorDuringUDPRecv(t *testing.T) { conn.Send(tb.UDP{DstPort: &cleanPort}) conn.Send(tb.UDP{}) - - err, errClean := <-c, <-cleanChan - if errClean != nil { - t.Error(err) - } - if err != nil { - t.Fatal(err) - } + wg.Wait() }) } } diff --git a/test/packetimpact/tests/udp_recv_multicast_test.go b/test/packetimpact/tests/udp_recv_multicast_test.go index 61fd17050..0bae18ba3 100644 --- a/test/packetimpact/tests/udp_recv_multicast_test.go +++ b/test/packetimpact/tests/udp_recv_multicast_test.go @@ -15,6 +15,7 @@ package udp_recv_multicast_test import ( + "flag" "net" "testing" @@ -23,6 +24,10 @@ import ( tb "gvisor.dev/gvisor/test/packetimpact/testbench" ) +func init() { + tb.RegisterFlags(flag.CommandLine) +} + func TestUDPRecvMulticast(t *testing.T) { dut := tb.NewDUT(t) defer dut.TearDown() diff --git a/test/packetimpact/tests/udp_send_recv_dgram_test.go b/test/packetimpact/tests/udp_send_recv_dgram_test.go index 1c682831c..350875a6f 100644 --- a/test/packetimpact/tests/udp_send_recv_dgram_test.go +++ b/test/packetimpact/tests/udp_send_recv_dgram_test.go @@ -15,6 +15,7 @@ package udp_send_recv_dgram_test import ( + "flag" "math/rand" "net" "testing" @@ -24,6 +25,10 @@ import ( tb "gvisor.dev/gvisor/test/packetimpact/testbench" ) +func init() { + tb.RegisterFlags(flag.CommandLine) +} + func generateRandomPayload(t *testing.T, n int) string { t.Helper() buf := make([]byte, n) diff --git a/test/syscalls/linux/BUILD b/test/syscalls/linux/BUILD index 837e56042..adf259bba 100644 --- a/test/syscalls/linux/BUILD +++ b/test/syscalls/linux/BUILD @@ -2620,6 +2620,7 @@ cc_binary( ":socket_bind_to_device_util", ":socket_test_util", "//test/util:capability_util", + "@com_google_absl//absl/container:node_hash_map", gtest, "//test/util:test_main", "//test/util:test_util", diff --git a/test/syscalls/linux/socket_bind_to_device_sequence.cc b/test/syscalls/linux/socket_bind_to_device_sequence.cc index 637d1151a..1967329ee 100644 --- a/test/syscalls/linux/socket_bind_to_device_sequence.cc +++ b/test/syscalls/linux/socket_bind_to_device_sequence.cc @@ -33,6 +33,7 @@ #include "gmock/gmock.h" #include "gtest/gtest.h" +#include "absl/container/node_hash_map.h" #include "test/syscalls/linux/ip_socket_test_util.h" #include "test/syscalls/linux/socket_bind_to_device_util.h" #include "test/syscalls/linux/socket_test_util.h" @@ -192,8 +193,8 @@ class BindToDeviceSequenceTest : public ::testing::TestWithParam<SocketKind> { in_port_t port_ = 0; // sockets_to_close_ is a map from action index to the socket that was // created. - std::unordered_map<int, - std::unique_ptr<gvisor::testing::FileDescriptor>> + absl::node_hash_map<int, + std::unique_ptr<gvisor::testing::FileDescriptor>> sockets_to_close_; int next_socket_id_ = 0; }; |